diff --git a/axlearn/common/launch.py b/axlearn/common/launch.py index 843454aa..b9a6cf46 100644 --- a/axlearn/common/launch.py +++ b/axlearn/common/launch.py @@ -41,6 +41,11 @@ # Note: this will disable other TF_CPP info and warnnings. os.environ.setdefault("TF_CPP_MIN_LOG_LEVEL", "2") +if instance_type.startswith("gpu"): + # Prevent GPU OOM issues due to TF taking up all the GPU memory. + # Reference: https://stackoverflow.com/a/54927279 + os.environ.setdefault("TF_FORCE_GPU_ALLOW_GROWTH", "true") + # Import jax before tensorflow else to avoid problems such as: # tpu_library_init_fns.inc:98] TpuEmbeddingEngine_ExecutePartitioner not available in this library. import jax # jax must be imported before tensorflow!