Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Goodput & Badput recording and monitoring support. #783

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 60 additions & 3 deletions axlearn/cloud/gcp/measurement.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,24 +5,39 @@
import jax
from absl import flags, logging
from ml_goodput_measurement import goodput
from ml_goodput_measurement import monitoring as goodput_monitoring

from axlearn.cloud.common.utils import parse_kv_flags
from axlearn.common import measurement
from axlearn.common.config import maybe_set_config
from axlearn.common.config import REQUIRED, Required, config_class, maybe_set_config


@measurement.register_recorder("goodput")
class GoodputRecorder(measurement.Recorder):
"""Records overall training goodput."""

Config = measurement.Recorder.Config
@config_class
class Config(measurement.Recorder.Config):
"""Configures GoodputRecorder.

Attributes:
upload_dir: Directory to store metrics for the monitor.
upload_interval: Time interval (seconds) for monitoring uploads.
"""

upload_dir: Required[str] = REQUIRED
upload_interval: Required[int] = REQUIRED

@classmethod
def from_flags(cls, fv: flags.FlagValues) -> "GoodputRecorder":
"""Converts flags to a recorder.

`fv.recorder_spec` will be interpreted as a list of `key=value` pairs; config names
corresponding to keys will be set to the corresponding values.
corresponding to keys will be set to the corresponding values. A GoodputRecorder can
additionally take in following Tensorboard configs in the recorder_spec:
- upload_dir: The directory to write Tensorboard data to.
- upload_interval: The time interval in seconds at which to query and upload data
dipannita08 marked this conversation as resolved.
Show resolved Hide resolved
to Tensorboard.
"""
cfg: measurement.Recorder.Config = cls.default_config()
cfg = maybe_set_config(cfg, **parse_kv_flags(fv.recorder_spec, delimiter="="))
Expand All @@ -32,6 +47,7 @@ def __init__(self, cfg):
super().__init__(cfg)
cfg: GoodputRecorder.Config = self.config
self._recorder = None
self._monitor = None

def record(self, event: measurement.Event, *args, **kwargs):
# Lazily instantiate the recorder. This avoids invoking jax before setup is complete.
Expand All @@ -49,10 +65,51 @@ def record(self, event: measurement.Event, *args, **kwargs):
self._recorder.record_job_end_time(*args, **kwargs)
elif event == measurement.Event.START_STEP:
self._recorder.record_step_start_time(*args, **kwargs)
elif event == measurement.Event.START_ACCELERATOR_INIT:
self._recorder.record_tpu_init_start_time(*args, **kwargs)
elif event == measurement.Event.END_ACCELERATOR_INIT:
self._recorder.record_tpu_init_end_time(*args, **kwargs)
elif event == measurement.Event.START_TRAINING_PREPARATION:
self._recorder.record_training_preparation_start_time(*args, **kwargs)
elif event == measurement.Event.END_TRAINING_PREPARATION:
self._recorder.record_training_preparation_end_time(*args, **kwargs)
elif event == measurement.Event.START_DATA_LOADING:
self._recorder.record_data_loading_start_time(*args, **kwargs)
elif event == measurement.Event.END_DATA_LOADING:
self._recorder.record_data_loading_end_time(*args, **kwargs)
else:
logging.log_first_n(
logging.WARNING,
"Ignoring unknown event %s",
1,
event,
)

def start_monitoring(self, *args, **kwargs):
"""Starts Monitoring of Goodput.

Instantiate ml-goodput-measurement's GoodputMonitor to asynchronously calculate
Goodput and Badput at the upload_interval and upload to the specified TensorBoard
directory.
Note: This function requires initialization of distributed JAX before it is called.
"""
if self._monitor is None:
cfg: GoodputRecorder.Config = self.config
self._monitor = goodput_monitoring.GoodputMonitor(
job_name=cfg.name,
logger_name=f"goodput_logger_{cfg.name}",
tensorboard_dir=cfg.upload_dir,
upload_interval=int(cfg.upload_interval),
monitoring_enabled=(jax.process_index() == 0),
include_badput_breakdown=True,
)

if self._monitor:
self._monitor.start_goodput_uploader(*args, **kwargs)
logging.info("Started Goodput upload to Tensorboard in the background!")
else:
logging.log_first_n(
logging.WARNING,
"Goodput upload could not be started. Please check GoodputMonitor logs.",
1,
)
markblee marked this conversation as resolved.
Show resolved Hide resolved
41 changes: 38 additions & 3 deletions axlearn/cloud/gcp/measurement_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@
class GoodputRecorderTest(parameterized.TestCase):
"""Tests GoodputRecorder."""

@parameterized.parameters(None, ["name=test-name"])
@parameterized.parameters(
(None,), (["name=test-name", "upload_dir=/test/path/to/upload", "upload_interval=15"],)
)
def test_from_flags(self, spec):
fv = flags.FlagValues()
measurement.define_flags(flag_values=fv)
Expand All @@ -34,13 +36,46 @@ def test_from_flags(self, spec):
# Recorder is not instantiated until first event.
self.assertIsNone(recorder._recorder)

def test_record(self):
def test_record_and_monitor(self):
fv = flags.FlagValues()
measurement.define_flags(flag_values=fv)
fv.set_default("recorder_spec", ["name=test-name"])
fv.set_default(
"recorder_spec",
["name=test-name", "upload_dir=/test/path/to/upload", "upload_interval=15"],
)
fv.mark_as_parsed()

recorder = GoodputRecorder.from_flags(fv)
recorder._recorder = mock.MagicMock()
recorder.record(measurement.Event.START_JOB)
self.assertTrue(recorder._recorder.record_job_start_time.called)

def test_start_monitoring(self):
fv = flags.FlagValues()
measurement.define_flags(flag_values=fv)
fv.set_default(
"recorder_spec",
["name=test-name", "upload_dir=/test/path/to/upload", "upload_interval=15"],
)
fv.mark_as_parsed()

recorder = GoodputRecorder.from_flags(fv)
self.assertIsNone(recorder._monitor) # Ensure _monitor is initially None

with mock.patch("ml_goodput_measurement.monitoring.GoodputMonitor") as mock_goodput_monitor:
mock_monitor_instance = mock_goodput_monitor.return_value
recorder.start_monitoring()

# Check that GoodputMonitor was instantiated
mock_goodput_monitor.assert_called_once_with(
job_name="test-name",
logger_name="goodput_logger_test-name",
tensorboard_dir="/test/path/to/upload",
upload_interval=15,
monitoring_enabled=True,
include_badput_breakdown=True,
)

# Ensure that start_goodput_uploader is called on the monitor instance
mock_monitor_instance.start_goodput_uploader.assert_called_once()
self.assertIsNotNone(recorder._monitor)
1 change: 1 addition & 0 deletions axlearn/common/launch_trainer_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ def main(_):
launch.setup()
trainer_config = launch_trainer.get_trainer_config()
trainer_config.set(recorder=config_for_function(lambda: measurement.global_recorder))
measurement.start_monitoring()
dipannita08 marked this conversation as resolved.
Show resolved Hide resolved
launch_trainer.run_trainer(trainer_config)


Expand Down
29 changes: 29 additions & 0 deletions axlearn/common/measurement.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,23 @@ class Event(enum.Enum):
START_JOB: Start of job.
END_JOB: End of job.
START_STEP: Start of a training step. Should be recorded with `step` as a positional arg.
START_ACCELERATOR_INIT: Start of accelerator mesh initialization.
END_ACCELERATOR_INIT: End of accelerator mesh initialization.
START_TRAINING_PREPARATION: Start of training preparation.
END_TRAINING_PREPARATION: End of training preparation.
START_DATA_LOADING: Start of data loading.
END_DATA_LOADING: End of data loading.
"""

START_JOB = "START_JOB"
END_JOB = "END_JOB"
START_STEP = "START_STEP"
START_ACCELERATOR_INIT = "START_ACCELERATOR_INIT"
END_ACCELERATOR_INIT = "END_ACCELERATOR_INIT"
START_TRAINING_PREPARATION = "START_TRAINING_PREPARATION"
END_TRAINING_PREPARATION = "END_TRAINING_PREPARATION"
START_DATA_LOADING = "START_DATA_LOADING"
END_DATA_LOADING = "END_DATA_LOADING"


class Recorder(Configurable):
Expand All @@ -47,6 +59,10 @@ def record(self, event: Event, *args, **kwargs):
"""Records an event with the given name."""
raise NotImplementedError(type(self))

def start_monitoring(self, **kwargs):
"""Starts computing and uploading metrics at some configured interval in the background."""
raise NotImplementedError(type(self))


_recorders: dict[str, type] = {}
_T = TypeVar("_T")
Expand Down Expand Up @@ -120,3 +136,16 @@ def record_event(event: Event):
logging.log_first_n(logging.INFO, "No recorder configured, ignoring events.", 1)
else:
global_recorder.record(event)


def start_monitoring():
"""Begins monitoring events as per global monitor functionality."""
if global_recorder is None:
logging.log_first_n(
logging.INFO, "Since recorder is not set up, monitoring cannot be started.", 1
)
else:
global_recorder.start_monitoring()
logging.info(
"Starting monitoring of events using global recorder's monitor: %s", global_recorder
)
7 changes: 7 additions & 0 deletions axlearn/common/measurement_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,3 +85,10 @@ def test_initialize(self, recorder_type, expected):
with mock.patch.object(measurement.global_recorder, "record") as mock_record:
measurement.record_event(measurement.Event.START_JOB)
self.assertIn(measurement.Event.START_JOB, mock_record.call_args[0])

# Ensure that start_monitoring does not fail.
with mock.patch.object(
measurement.global_recorder, "start_monitoring"
) as mock_start_monitoring:
measurement.start_monitoring()
mock_start_monitoring.assert_called_once()
8 changes: 8 additions & 0 deletions axlearn/common/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,7 @@ def __init__(
utils.validate_float_dtype(cfg.train_dtype)

# Create the device mesh.
self._maybe_record_event(measurement.Event.START_ACCELERATOR_INIT)
if devices is None:
self._step_log(
"Devices: global=%s local=%s %s",
Expand Down Expand Up @@ -324,6 +325,7 @@ def __init__(
model=self.model,
model_param_partition_specs=model_param_partition_specs,
)
self._maybe_record_event(measurement.Event.END_ACCELERATOR_INIT)
dipannita08 marked this conversation as resolved.
Show resolved Hide resolved

@property
def step(self):
Expand Down Expand Up @@ -828,6 +830,7 @@ def _prepare_training(self, prng_key: Tensor) -> bool:
# Attempt to restore the latest checkpoint, which may contain a saved `_input_iter`.
self.restore_checkpoint(restore_step=None)

self._maybe_record_event(measurement.Event.START_TRAINING_PREPARATION)
if self.step is None:
# If we didn't restore from checkpoint, attempt to build initial state according
# to `cfg.init_state_builder` and initialize the remaining parameters.
Expand All @@ -847,6 +850,7 @@ def _prepare_training(self, prng_key: Tensor) -> bool:
with fs.open(os.path.join(cfg.dir, "model_analysis.txt"), "w") as f:
f.write(model_analysis)

self._maybe_record_event(measurement.Event.END_TRAINING_PREPARATION)
dipannita08 marked this conversation as resolved.
Show resolved Hide resolved
# Log config.
self.summary_writer.log_config(cfg, step=self.step)

Expand Down Expand Up @@ -883,6 +887,7 @@ def restore_checkpoint(self, restore_step: Optional[int] = None) -> Optional[int
restore_input_iter = cfg.save_input_iterator
try:
# Try to restore with `input_iter`.
self._maybe_record_event(measurement.Event.START_DATA_LOADING)
dipannita08 marked this conversation as resolved.
Show resolved Hide resolved
step, ckpt_state = self.checkpointer.restore(
step=restore_step,
state=(
Expand All @@ -896,13 +901,15 @@ def restore_checkpoint(self, restore_step: Optional[int] = None) -> Optional[int
step,
restore_input_iter,
)
self._maybe_record_event(measurement.Event.END_DATA_LOADING)
dipannita08 marked this conversation as resolved.
Show resolved Hide resolved
except ValueError as e:
logging.warning(
"Attempt to restore checkpoint with restore_input_iter=%s failed: %s",
restore_input_iter,
e,
)
# Restore with a different restore_input_iter setting.
self._maybe_record_event(measurement.Event.START_DATA_LOADING)
restore_input_iter = not restore_input_iter
step, ckpt_state = self.checkpointer.restore(
step=restore_step,
Expand All @@ -917,6 +924,7 @@ def restore_checkpoint(self, restore_step: Optional[int] = None) -> Optional[int
step,
restore_input_iter,
)
self._maybe_record_event(measurement.Event.END_DATA_LOADING)
if step is not None:
self._step = step
self._trainer_state = TrainerState(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,8 @@ class DummyRecorder(measurement.Recorder):
@classmethod
def from_flags(cls, fv) -> measurement.Recorder:
del fv
return cls.default_config().set(name="dummy_recorder").instantiate()
return (
cls.default_config()
.set(name="dummy_recorder", upload_dir="/dummy/upload_dir", upload_interval=15)
.instantiate()
)
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ gcp = [
"google-cloud-compute==1.19.2", # Needed for region discovery for CloudBuild API access.
"google-cloud-core==2.3.3",
"google-cloud-build==3.24.1",
"ml_goodput_measurement==0.0.2",
"ml-goodput-measurement==0.0.4",
"pika==1.3.2", # used by event queue
"pyOpenSSL>=22.1.0", # compat with cryptography version.
"tpu-info==0.2.0", # For TPU monitoring from libtpu. https://github.com/AI-Hypercomputer/cloud-accelerator-diagnostics/tree/main/tpu_info
Expand Down