From bcd8618a1e8a9068adeb1da93401275afbfa32a8 Mon Sep 17 00:00:00 2001
From: Dipannita Shaw
Date: Mon, 21 Oct 2024 23:44:49 +0000
Subject: [PATCH 1/4] Code clean up

---
 axlearn/cloud/gcp/measurement.py      |  12 +++
 axlearn/cloud/gcp/monitoring.py       |  64 +++++++++++++++
 axlearn/common/launch_trainer_main.py |   5 +-
 axlearn/common/measurement.py         |  12 +++
 axlearn/common/monitoring.py          | 113 ++++++++++++++++++++++++++
 axlearn/common/trainer.py             |   8 ++
 pyproject.toml                        |   2 +-
 7 files changed, 214 insertions(+), 2 deletions(-)
 create mode 100644 axlearn/cloud/gcp/monitoring.py
 create mode 100644 axlearn/common/monitoring.py

diff --git a/axlearn/cloud/gcp/measurement.py b/axlearn/cloud/gcp/measurement.py
index a2515f6d4..cea224335 100644
--- a/axlearn/cloud/gcp/measurement.py
+++ b/axlearn/cloud/gcp/measurement.py
@@ -49,6 +49,18 @@ def record(self, event: measurement.Event, *args, **kwargs):
             self._recorder.record_job_end_time(*args, **kwargs)
         elif event == measurement.Event.START_STEP:
             self._recorder.record_step_start_time(*args, **kwargs)
+        elif event == measurement.Event.START_ACCELERATOR_INIT:
+            self._recorder.record_tpu_init_start_time(*args, **kwargs)
+        elif event == measurement.Event.END_ACCELERATOR_INIT:
+            self._recorder.record_tpu_init_end_time(*args, **kwargs)
+        elif event == measurement.Event.START_TRAINING_PREPARATION:
+            self._recorder.record_training_preparation_start_time(*args, **kwargs)
+        elif event == measurement.Event.END_TRAINING_PREPARATION:
+            self._recorder.record_training_preparation_end_time(*args, **kwargs)
+        elif event == measurement.Event.START_DATA_LOADING:
+            self._recorder.record_data_loading_start_time(*args, **kwargs)
+        elif event == measurement.Event.END_DATA_LOADING:
+            self._recorder.record_data_loading_end_time(*args, **kwargs)
         else:
             logging.log_first_n(
                 logging.WARNING,
diff --git a/axlearn/cloud/gcp/monitoring.py b/axlearn/cloud/gcp/monitoring.py
new file mode 100644
index 000000000..09f108bba
--- /dev/null
+++ b/axlearn/cloud/gcp/monitoring.py
@@ -0,0 +1,64 @@
+# Copyright © 2024 Apple Inc.
+
+"""Goodput & Badput computation and monitoring utils for GCP."""
+
+import jax
+from absl import flags, logging
+from ml_goodput_measurement import monitoring as goodput_monitoring
+
+from axlearn.cloud.common.utils import parse_kv_flags
+from axlearn.common import monitoring
+from axlearn.common.config import maybe_set_config
+
+
+@monitoring.register_monitor("GoodputMonitor")
+class GoodputMonitor(monitoring.Monitor):
+    """Computes and uploads overall training goodput and optionally badput."""
+
+    Config = monitoring.Monitor.Config
+
+    @classmethod
+    def from_flags(cls, fv: flags.FlagValues) -> "GoodputMonitor":
+        """Converts flags to a GoodputMonitor.
+
+        `fv.monitor_spec` will be interpreted as a list of `key=value` pairs; config names
+        corresponding to keys will be set to the corresponding values. A GoodputMonitor can
+        additionally take in the following Tensorboard configs in the monitor_spec:
+            - upload_dir: The directory to write Tensorboard data to.
+            - upload_interval: The time interval in seconds at which to query and upload data
+              to Tensorboard.
+        """
+        cfg: monitoring.Monitor.Config = cls.default_config()
+        cfg = maybe_set_config(cfg, **parse_kv_flags(fv.monitor_spec, delimiter="="))
+        return cfg.instantiate()
+
+    def __init__(self, cfg):
+        super().__init__(cfg)
+        cfg: GoodputMonitor.Config = self.config
+        self._monitor = None
+
+    def start_monitoring(self, *args, **kwargs):
+        # Instantiate ml-goodput-measurement's GoodputMonitor
+        # to asynchronously calculate goodput and badput at
+        # the upload_interval and upload to the specified
+        # tensorboard directory.
+        if self._monitor is None:
+            cfg: GoodputMonitor.Config = self.config
+            self._monitor = goodput_monitoring.GoodputMonitor(
+                job_name=cfg.name,
+                logger_name=f"goodput_logger_{cfg.name}",
+                tensorboard_dir=cfg.upload_dir,
+                upload_interval=int(cfg.upload_interval),
+                monitoring_enabled=(jax.process_index() == 0),
+                include_badput_breakdown=True,
+            )
+
+        if self._monitor:
+            self._monitor.start_goodput_uploader(*args, **kwargs)
+            logging.info("Started Goodput upload to Tensorboard in the background!")
+        else:
+            logging.log_first_n(
+                logging.WARNING,
+                "Goodput upload could not be started. Please check GoodputMonitor logs.",
+                1,
+            )
diff --git a/axlearn/common/launch_trainer_main.py b/axlearn/common/launch_trainer_main.py
index 8d170a950..c79b6e89a 100644
--- a/axlearn/common/launch_trainer_main.py
+++ b/axlearn/common/launch_trainer_main.py
@@ -4,18 +4,21 @@
 
 from absl import app, flags
 
-from axlearn.common import launch, launch_trainer, measurement
+from axlearn.common import launch, launch_trainer, measurement, monitoring
 from axlearn.common.config import config_for_function
 
 
 def main(_):
     measurement.initialize(flags.FLAGS)
+    monitoring.initialize(flags.FLAGS)
     launch.setup()
     trainer_config = launch_trainer.get_trainer_config()
     trainer_config.set(recorder=config_for_function(lambda: measurement.global_recorder))
+    monitoring.start_monitoring()
     launch_trainer.run_trainer(trainer_config)
 
 
 if __name__ == "__main__":
     measurement.define_flags()
+    monitoring.define_flags()
     app.run(main)
diff --git a/axlearn/common/measurement.py b/axlearn/common/measurement.py
index eeff5cb96..ac38870b0 100644
--- a/axlearn/common/measurement.py
+++ b/axlearn/common/measurement.py
@@ -18,11 +18,23 @@ class Event(enum.Enum):
         START_JOB: Start of job.
         END_JOB: End of job.
         START_STEP: Start of a training step. Should be recorded with `step` as a positional arg.
+        START_ACCELERATOR_INIT: Start of accelerator mesh initialization.
+        END_ACCELERATOR_INIT: End of accelerator mesh initialization.
+        START_TRAINING_PREPARATION: Start of training preparation.
+        END_TRAINING_PREPARATION: End of training preparation.
+        START_DATA_LOADING: Start of data loading.
+        END_DATA_LOADING: End of data loading.
     """
 
     START_JOB = "START_JOB"
     END_JOB = "END_JOB"
     START_STEP = "START_STEP"
+    START_ACCELERATOR_INIT = "START_ACCELERATOR_INIT"
+    END_ACCELERATOR_INIT = "END_ACCELERATOR_INIT"
+    START_TRAINING_PREPARATION = "START_TRAINING_PREPARATION"
+    END_TRAINING_PREPARATION = "END_TRAINING_PREPARATION"
+    START_DATA_LOADING = "START_DATA_LOADING"
+    END_DATA_LOADING = "END_DATA_LOADING"
 
 
 class Recorder(Configurable):
diff --git a/axlearn/common/monitoring.py b/axlearn/common/monitoring.py
new file mode 100644
index 000000000..0d123ac1f
--- /dev/null
+++ b/axlearn/common/monitoring.py
@@ -0,0 +1,113 @@
+# Copyright © 2024 Apple Inc.
+
+"""Asynchronously compute and monitor metrics like goodput and badput."""
+
+import importlib
+from typing import Optional, TypeVar
+
+from absl import flags, logging
+
+from axlearn.common.config import REQUIRED, Configurable, Required, config_class
+
+
+class Monitor(Configurable):
+    """The base interface for computing and monitoring metrics."""
+
+    @config_class
+    class Config(Configurable.Config):
+        """Configures any type of Monitor.
+
+        Attributes:
+            name: Name of the monitor (example: GoodputMonitor).
+            upload_dir: Storage directory where metrics are uploaded.
+            upload_interval: Time interval (seconds) at which to query and upload metrics.
+        """
+
+        name: Required[str] = REQUIRED
+        upload_dir: Required[str] = REQUIRED
+        upload_interval: Required[int] = REQUIRED
+
+    @classmethod
+    def from_flags(cls, fv: Optional[flags.FlagValues]) -> "Monitor":
+        """Converts flags to a monitor."""
+        raise NotImplementedError(cls)
+
+    def start_monitoring(self, **kwargs):
+        """Starts computing and uploading metrics at some configured interval in the background."""
+        raise NotImplementedError(type(self))
+
+
+_monitors: dict[str, type] = {}
+_T = TypeVar("_T")
+
+
+def register_monitor(name: str):
+    def fn(cls: _T) -> _T:
+        """Registers a monitor into a dict of global monitors with reference to its class type."""
+        if name in _monitors:
+            raise ValueError(f"Monitor {name} is already registered.")
+        _monitors[name] = cls
+        return cls
+
+    return fn
+
+
+def define_flags(**kwargs):
+    """Common monitoring flags."""
+
+    flags.DEFINE_string(
+        "monitor_type",
+        None,
+        "The monitor type. It can be a monitor name, e.g. `GoodputMonitor`, or "
+        "a module paired with a monitor name, e.g. `my.module:my_monitor`.",
+        **kwargs,
+    )
+    flags.DEFINE_multi_string(
+        "monitor_spec",
+        [],
+        "Monitor spec provided as key=value. "
+        "Refer to each monitor's `from_flags` method docstring for details.",
+        **kwargs,
+    )
+
+
+global_monitor: Optional[Monitor] = None
+
+
+def initialize(fv: flags.FlagValues):
+    """Initializes the monitor from flags."""
+    global global_monitor
+    if not fv.monitor_type:
+        logging.info("No monitor type specified, skipping monitoring initialize().")
+        return
+    if global_monitor is None:
+        # Infer module from monitor_type.
+        parts = fv.monitor_type.split(":", 1)
+        if len(parts) > 1:
+            logging.info("Registering monitors in %s", parts[0])
+            importlib.import_module(parts[0])
+        if monitor_class := _monitors.get(parts[-1], None):
+            # This will instantiate a specific monitor of monitor_type if supported.
+            global_monitor = monitor_class.from_flags(fv=fv)
+        else:
+            raise NotImplementedError(
+                f"Monitor type: {fv.monitor_type} is not supported. "
+                f"Supported types are: {sorted(list(_monitors.keys()))}\n"
+                "You can also specify a specific module to identify the monitor "
+                "(e.g., `my.module:my_monitor`)."
+ ) + logging.info("Initialized global monitor: %s", global_monitor) + else: + logging.warning( + "Monitor %s is already initialized, ignoring monitoring initialize().", + global_monitor, + ) + + +def start_monitoring(): + """Begins monitoring events as per global monitor functionality.""" + if global_monitor is None: + logging.log_first_n(logging.INFO, "No Monitor configured, no events will be monitored.", 1) + else: + global_monitor.start_monitoring() + logging.info("Starting monitoring of events using global monitor: %s", global_monitor) diff --git a/axlearn/common/trainer.py b/axlearn/common/trainer.py index cda75090f..734a85330 100644 --- a/axlearn/common/trainer.py +++ b/axlearn/common/trainer.py @@ -212,6 +212,7 @@ def __init__( utils.validate_float_dtype(cfg.train_dtype) # Create the device mesh. + self._maybe_record_event(measurement.Event.START_ACCELERATOR_INIT) if devices is None: self._step_log( "Devices: global=%s local=%s %s", @@ -286,6 +287,7 @@ def __init__( model=self.model, model_param_partition_specs=model_param_partition_specs, ) + self._maybe_record_event(measurement.Event.END_ACCELERATOR_INIT) @property def step(self): @@ -718,6 +720,7 @@ def _prepare_training(self, prng_key: Tensor) -> bool: # Attempt to restore the latest checkpoint, which may contain a saved `_input_iter`. self.restore_checkpoint(restore_step=None) + self._maybe_record_event(measurement.Event.START_TRAINING_PREPARATION) if self.step is None: # If we didn't restore from checkpoint, attempt to build initial state according # to `cfg.init_state_builder` and initialize the remaining parameters. @@ -733,6 +736,7 @@ def _prepare_training(self, prng_key: Tensor) -> bool: f.write(str(jax.tree_util.tree_structure(self._trainer_state))) self._log_trainer_state_stats() + self._maybe_record_event(measurement.Event.END_TRAINING_PREPARATION) # Log config. self.summary_writer.log_config(cfg, step=self.step) @@ -769,6 +773,7 @@ def restore_checkpoint(self, restore_step: Optional[int] = None) -> Optional[int restore_input_iter = cfg.save_input_iterator try: # Try to restore with `input_iter`. + self._maybe_record_event(measurement.Event.START_DATA_LOADING) step, ckpt_state = self.checkpointer.restore( step=restore_step, state=( @@ -782,6 +787,7 @@ def restore_checkpoint(self, restore_step: Optional[int] = None) -> Optional[int step, restore_input_iter, ) + self._maybe_record_event(measurement.Event.END_DATA_LOADING) except ValueError as e: logging.warning( "Attempt to restore checkpoint with restore_input_iter=%s failed: %s", @@ -789,6 +795,7 @@ def restore_checkpoint(self, restore_step: Optional[int] = None) -> Optional[int e, ) # Restore with a different restore_input_iter setting. + self._maybe_record_event(measurement.Event.START_DATA_LOADING) restore_input_iter = not restore_input_iter step, ckpt_state = self.checkpointer.restore( step=restore_step, @@ -803,6 +810,7 @@ def restore_checkpoint(self, restore_step: Optional[int] = None) -> Optional[int step, restore_input_iter, ) + self._maybe_record_event(measurement.Event.END_DATA_LOADING) if step is not None: self._step = step self._trainer_state = TrainerState( diff --git a/pyproject.toml b/pyproject.toml index e2131abd7..9494be6cb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -91,7 +91,7 @@ gcp = [ "google-cloud-storage==2.16.0", "google-cloud-core==2.3.3", "google-cloud-build==3.24.1", - "ml_goodput_measurement==0.0.2", + "ml-goodput-measurement==0.0.4", "pika==1.3.2", # used by event queue "pyOpenSSL>=22.1.0", # compat with cryptography version. 
 ]

From 1417133c5e387582bcc2eaaecce742e6304d7ff7 Mon Sep 17 00:00:00 2001
From: Dipannita Shaw
Date: Mon, 4 Nov 2024 23:10:45 +0000
Subject: [PATCH 2/4] Add more testing

---
 axlearn/cloud/gcp/measurement.py      |  34 +++++-
 axlearn/cloud/gcp/measurement_test.py |  41 ++++++-
 axlearn/cloud/gcp/monitoring.py       |  64 ----------
 axlearn/common/launch_trainer_main.py |   6 +-
 axlearn/common/measurement.py         |  21 ++++
 axlearn/common/measurement_test.py    |   7 ++
 axlearn/common/monitoring.py          | 113 ------------------
 .../dummy_recorder.py                 |   6 +-
 8 files changed, 106 insertions(+), 186 deletions(-)
 delete mode 100644 axlearn/cloud/gcp/monitoring.py
 delete mode 100644 axlearn/common/monitoring.py

diff --git a/axlearn/cloud/gcp/measurement.py b/axlearn/cloud/gcp/measurement.py
index cea224335..89094d856 100644
--- a/axlearn/cloud/gcp/measurement.py
+++ b/axlearn/cloud/gcp/measurement.py
@@ -5,6 +5,7 @@
 import jax
 from absl import flags, logging
 from ml_goodput_measurement import goodput
+from ml_goodput_measurement import monitoring as goodput_monitoring
 
 from axlearn.cloud.common.utils import parse_kv_flags
 from axlearn.common import measurement
@@ -22,7 +23,11 @@ def from_flags(cls, fv: flags.FlagValues) -> "GoodputRecorder":
         """Converts flags to a recorder.
 
         `fv.recorder_spec` will be interpreted as a list of `key=value` pairs; config names
-        corresponding to keys will be set to the corresponding values.
+        corresponding to keys will be set to the corresponding values. A GoodputRecorder can
+        additionally take in the following Tensorboard configs in the recorder_spec:
+            - upload_dir: The directory to write Tensorboard data to.
+            - upload_interval: The time interval in seconds at which to query and upload data
+              to Tensorboard.
         """
         cfg: measurement.Recorder.Config = cls.default_config()
         cfg = maybe_set_config(cfg, **parse_kv_flags(fv.recorder_spec, delimiter="="))
@@ -32,6 +37,7 @@ def __init__(self, cfg):
         super().__init__(cfg)
         cfg: GoodputRecorder.Config = self.config
         self._recorder = None
+        self._monitor = None
 
     def record(self, event: measurement.Event, *args, **kwargs):
         # Lazily instantiate the recorder. This avoids invoking jax before setup is complete.
@@ -68,3 +74,29 @@ def record(self, event: measurement.Event, *args, **kwargs):
             1,
             event,
         )
+
+    def start_monitoring(self, *args, **kwargs):
+        # Instantiate ml-goodput-measurement's GoodputMonitor
+        # to asynchronously calculate goodput and badput at
+        # the upload_interval and upload to the specified
+        # tensorboard directory.
+        if self._monitor is None:
+            cfg: GoodputRecorder.Config = self.config
+            self._monitor = goodput_monitoring.GoodputMonitor(
+                job_name=cfg.name,
+                logger_name=f"goodput_logger_{cfg.name}",
+                tensorboard_dir=cfg.upload_dir,
+                upload_interval=int(cfg.upload_interval),
+                monitoring_enabled=(jax.process_index() == 0),
+                include_badput_breakdown=True,
+            )
+
+        if self._monitor:
+            self._monitor.start_goodput_uploader(*args, **kwargs)
+            logging.info("Started Goodput upload to Tensorboard in the background!")
+        else:
+            logging.log_first_n(
+                logging.WARNING,
+                "Goodput upload could not be started. Please check GoodputMonitor logs.",
+                1,
+            )
diff --git a/axlearn/cloud/gcp/measurement_test.py b/axlearn/cloud/gcp/measurement_test.py
index 161edf800..302142622 100644
--- a/axlearn/cloud/gcp/measurement_test.py
+++ b/axlearn/cloud/gcp/measurement_test.py
@@ -16,7 +16,9 @@ class GoodputRecorderTest(parameterized.TestCase):
     """Tests GoodputRecorder."""
 
-    @parameterized.parameters(None, ["name=test-name"])
+    @parameterized.parameters(
+        (None,), (["name=test-name", "upload_dir=/test/path/to/upload", "upload_interval=15"],)
+    )
     def test_from_flags(self, spec):
         fv = flags.FlagValues()
         measurement.define_flags(flag_values=fv)
@@ -34,13 +36,46 @@ def test_from_flags(self, spec):
         # Recorder is not instantiated until first event.
         self.assertIsNone(recorder._recorder)
 
-    def test_record(self):
+    def test_record_and_monitor(self):
         fv = flags.FlagValues()
         measurement.define_flags(flag_values=fv)
-        fv.set_default("recorder_spec", ["name=test-name"])
+        fv.set_default(
+            "recorder_spec",
+            ["name=test-name", "upload_dir=/test/path/to/upload", "upload_interval=15"],
+        )
         fv.mark_as_parsed()
 
         recorder = GoodputRecorder.from_flags(fv)
         recorder._recorder = mock.MagicMock()
         recorder.record(measurement.Event.START_JOB)
         self.assertTrue(recorder._recorder.record_job_start_time.called)
+
+    def test_start_monitoring(self):
+        fv = flags.FlagValues()
+        measurement.define_flags(flag_values=fv)
+        fv.set_default(
+            "recorder_spec",
+            ["name=test-name", "upload_dir=/test/path/to/upload", "upload_interval=15"],
+        )
+        fv.mark_as_parsed()
+
+        recorder = GoodputRecorder.from_flags(fv)
+        recorder._monitor = None  # Ensure _monitor is initially None
+
+        with mock.patch("ml_goodput_measurement.monitoring.GoodputMonitor") as mock_goodput_monitor:
+            mock_monitor_instance = mock_goodput_monitor.return_value
+            recorder.start_monitoring()
+
+            # Check that GoodputMonitor was instantiated
+            mock_goodput_monitor.assert_called_once_with(
+                job_name="test-name",
+                logger_name="goodput_logger_test-name",
+                tensorboard_dir="/test/path/to/upload",
+                upload_interval=15,
+                monitoring_enabled=True,
+                include_badput_breakdown=True,
+            )
+
+            # Ensure that start_goodput_uploader is called on the monitor instance
+            mock_monitor_instance.start_goodput_uploader.assert_called_once()
+            self.assertIsNotNone(recorder._monitor)
diff --git a/axlearn/cloud/gcp/monitoring.py b/axlearn/cloud/gcp/monitoring.py
deleted file mode 100644
index 09f108bba..000000000
--- a/axlearn/cloud/gcp/monitoring.py
+++ /dev/null
@@ -1,64 +0,0 @@
-# Copyright © 2024 Apple Inc.
-
-"""Goodput & Badput computation and monitoring utils for GCP."""
-
-import jax
-from absl import flags, logging
-from ml_goodput_measurement import monitoring as goodput_monitoring
-
-from axlearn.cloud.common.utils import parse_kv_flags
-from axlearn.common import monitoring
-from axlearn.common.config import maybe_set_config
-
-
-@monitoring.register_monitor("GoodputMonitor")
-class GoodputMonitor(monitoring.Monitor):
-    """Computes and uploads overall training goodput and optionally badput."""
-
-    Config = monitoring.Monitor.Config
-
-    @classmethod
-    def from_flags(cls, fv: flags.FlagValues) -> "GoodputMonitor":
-        """Converts flags to a GoodputMonitor.
-
-        `fv.monitor_spec` will be interpreted as a list of `key=value` pairs; config names
-        corresponding to keys will be set to the corresponding values. A GoodputMonitor can
-        additionally take in the following Tensorboard configs in the monitor_spec:
-            - upload_dir: The directory to write Tensorboard data to.
-            - upload_interval: The time interval in seconds at which to query and upload data
-              to Tensorboard.
-        """
-        cfg: monitoring.Monitor.Config = cls.default_config()
-        cfg = maybe_set_config(cfg, **parse_kv_flags(fv.monitor_spec, delimiter="="))
-        return cfg.instantiate()
-
-    def __init__(self, cfg):
-        super().__init__(cfg)
-        cfg: GoodputMonitor.Config = self.config
-        self._monitor = None
-
-    def start_monitoring(self, *args, **kwargs):
-        # Instantiate ml-goodput-measurement's GoodputMonitor
-        # to asynchronously calculate goodput and badput at
-        # the upload_interval and upload to the specified
-        # tensorboard directory.
-        if self._monitor is None:
-            cfg: GoodputMonitor.Config = self.config
-            self._monitor = goodput_monitoring.GoodputMonitor(
-                job_name=cfg.name,
-                logger_name=f"goodput_logger_{cfg.name}",
-                tensorboard_dir=cfg.upload_dir,
-                upload_interval=int(cfg.upload_interval),
-                monitoring_enabled=(jax.process_index() == 0),
-                include_badput_breakdown=True,
-            )
-
-        if self._monitor:
-            self._monitor.start_goodput_uploader(*args, **kwargs)
-            logging.info("Started Goodput upload to Tensorboard in the background!")
-        else:
-            logging.log_first_n(
-                logging.WARNING,
-                "Goodput upload could not be started. Please check GoodputMonitor logs.",
-                1,
-            )
diff --git a/axlearn/common/launch_trainer_main.py b/axlearn/common/launch_trainer_main.py
index c79b6e89a..2f617b4cd 100644
--- a/axlearn/common/launch_trainer_main.py
+++ b/axlearn/common/launch_trainer_main.py
@@ -4,21 +4,19 @@
 
 from absl import app, flags
 
-from axlearn.common import launch, launch_trainer, measurement, monitoring
+from axlearn.common import launch, launch_trainer, measurement
 from axlearn.common.config import config_for_function
 
 
 def main(_):
     measurement.initialize(flags.FLAGS)
-    monitoring.initialize(flags.FLAGS)
     launch.setup()
     trainer_config = launch_trainer.get_trainer_config()
     trainer_config.set(recorder=config_for_function(lambda: measurement.global_recorder))
-    monitoring.start_monitoring()
+    measurement.start_monitoring()
     launch_trainer.run_trainer(trainer_config)
 
 
 if __name__ == "__main__":
     measurement.define_flags()
-    monitoring.define_flags()
     app.run(main)
diff --git a/axlearn/common/measurement.py b/axlearn/common/measurement.py
index ac38870b0..ee0e83da3 100644
--- a/axlearn/common/measurement.py
+++ b/axlearn/common/measurement.py
@@ -46,9 +46,13 @@ class Config(Configurable.Config):
 
         Attributes:
             name: Name of the recorder.
+            upload_dir: Directory to store metrics for the monitor.
+            upload_interval: Time interval (seconds) for monitoring uploads.
""" name: Required[str] = REQUIRED + upload_dir: Required[str] = REQUIRED + upload_interval: Required[int] = REQUIRED @classmethod def from_flags(cls, fv: Optional[flags.FlagValues]) -> "Recorder": @@ -59,6 +63,10 @@ def record(self, event: Event, *args, **kwargs): """Records an event with the given name.""" raise NotImplementedError(type(self)) + def start_monitoring(self, **kwargs): + """Starts computing and uploading metrics at some configured interval in the background.""" + raise NotImplementedError(type(self)) + _recorders: dict[str, type] = {} _T = TypeVar("_T") @@ -132,3 +140,16 @@ def record_event(event: Event): logging.log_first_n(logging.INFO, "No recorder configured, ignoring events.", 1) else: global_recorder.record(event) + + +def start_monitoring(): + """Begins monitoring events as per global monitor functionality.""" + if global_recorder is None: + logging.log_first_n( + logging.INFO, "Since recorder is not set up, monitoring cannot be started.", 1 + ) + else: + global_recorder.start_monitoring() + logging.info( + "Starting monitoring of events using global recorder's monitor: %s", global_recorder + ) diff --git a/axlearn/common/measurement_test.py b/axlearn/common/measurement_test.py index 0db79b0fb..c9043f20b 100644 --- a/axlearn/common/measurement_test.py +++ b/axlearn/common/measurement_test.py @@ -85,3 +85,10 @@ def test_initialize(self, recorder_type, expected): with mock.patch.object(measurement.global_recorder, "record") as mock_record: measurement.record_event(measurement.Event.START_JOB) self.assertIn(measurement.Event.START_JOB, mock_record.call_args[0]) + + # Ensure that start_monitoring does not fail. + with mock.patch.object( + measurement.global_recorder, "start_monitoring" + ) as mock_start_monitoring: + measurement.start_monitoring() + mock_start_monitoring.assert_called_once() diff --git a/axlearn/common/monitoring.py b/axlearn/common/monitoring.py deleted file mode 100644 index 0d123ac1f..000000000 --- a/axlearn/common/monitoring.py +++ /dev/null @@ -1,113 +0,0 @@ -# Copyright © 2024 Apple Inc. - -"""Asynchronously compute and monitor metrics like goodput and badput.""" - -import importlib -from typing import Optional, TypeVar - -from absl import flags, logging - -from axlearn.common.config import REQUIRED, Configurable, Required, config_class - - -class Monitor(Configurable): - """The base interface for computing and monitoring metrics.""" - - @config_class - class Config(Configurable.Config): - """Configures any type of Monitor. - - Attributes: - name: Name of the monitor (example: GoodputMonitor). - upload_dir: Storage directory where metrics are uploaded. - upload_interval: Time interval (seconds) at which to query and upload metrics. 
-        """
-
-        name: Required[str] = REQUIRED
-        upload_dir: Required[str] = REQUIRED
-        upload_interval: Required[int] = REQUIRED
-
-    @classmethod
-    def from_flags(cls, fv: Optional[flags.FlagValues]) -> "Monitor":
-        """Converts flags to a monitor."""
-        raise NotImplementedError(cls)
-
-    def start_monitoring(self, **kwargs):
-        """Starts computing and uploading metrics at some configured interval in the background."""
-        raise NotImplementedError(type(self))
-
-
-_monitors: dict[str, type] = {}
-_T = TypeVar("_T")
-
-
-def register_monitor(name: str):
-    def fn(cls: _T) -> _T:
-        """Registers a monitor into a dict of global monitors with reference to its class type."""
-        if name in _monitors:
-            raise ValueError(f"Monitor {name} is already registered.")
-        _monitors[name] = cls
-        return cls
-
-    return fn
-
-
-def define_flags(**kwargs):
-    """Common monitoring flags."""
-
-    flags.DEFINE_string(
-        "monitor_type",
-        None,
-        "The monitor type. It can be a monitor name, e.g. `GoodputMonitor`, or "
-        "a module paired with a monitor name, e.g. `my.module:my_monitor`.",
-        **kwargs,
-    )
-    flags.DEFINE_multi_string(
-        "monitor_spec",
-        [],
-        "Monitor spec provided as key=value. "
-        "Refer to each monitor's `from_flags` method docstring for details.",
-        **kwargs,
-    )
-
-
-global_monitor: Optional[Monitor] = None
-
-
-def initialize(fv: flags.FlagValues):
-    """Initializes the monitor from flags."""
-    global global_monitor
-    if not fv.monitor_type:
-        logging.info("No monitor type specified, skipping monitoring initialize().")
-        return
-    if global_monitor is None:
-        # Infer module from monitor_type.
-        parts = fv.monitor_type.split(":", 1)
-        if len(parts) > 1:
-            logging.info("Registering monitors in %s", parts[0])
-            importlib.import_module(parts[0])
-        if monitor_class := _monitors.get(parts[-1], None):
-            # This will instantiate a specific monitor of monitor_type if supported.
-            global_monitor = monitor_class.from_flags(fv=fv)
-        else:
-            raise NotImplementedError(
-                f"Monitor type: {fv.monitor_type} is not supported. "
-                f"Supported types are: {sorted(list(_monitors.keys()))}\n"
-                "You can also specify a specific module to identify the monitor "
-                "(e.g., `my.module:my_monitor`)."
- ) - logging.info("Initialized global monitor: %s", global_monitor) - else: - logging.warning( - "Monitor %s is already initialized, ignoring monitoring initialize().", - global_monitor, - ) - - -def start_monitoring(): - """Begins monitoring events as per global monitor functionality.""" - if global_monitor is None: - logging.log_first_n(logging.INFO, "No Monitor configured, no events will be monitored.", 1) - else: - global_monitor.start_monitoring() - logging.info("Starting monitoring of events using global monitor: %s", global_monitor) diff --git a/axlearn/experiments/testdata/axlearn_common_measurement_test/dummy_recorder.py b/axlearn/experiments/testdata/axlearn_common_measurement_test/dummy_recorder.py index 76925e92b..5313a95f4 100644 --- a/axlearn/experiments/testdata/axlearn_common_measurement_test/dummy_recorder.py +++ b/axlearn/experiments/testdata/axlearn_common_measurement_test/dummy_recorder.py @@ -10,4 +10,8 @@ class DummyRecorder(measurement.Recorder): @classmethod def from_flags(cls, fv) -> measurement.Recorder: del fv - return cls.default_config().set(name="dummy_recorder").instantiate() + return ( + cls.default_config() + .set(name="dummy_recorder", upload_dir="/dummy/upload_dir", upload_interval=15) + .instantiate() + ) From 02214263a5d07ca7550a92831241387382b777cd Mon Sep 17 00:00:00 2001 From: Dipannita Shaw Date: Mon, 25 Nov 2024 21:30:38 +0000 Subject: [PATCH 3/4] Fix docstrings --- axlearn/cloud/gcp/measurement.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/axlearn/cloud/gcp/measurement.py b/axlearn/cloud/gcp/measurement.py index 7a40b7554..c4a1a9e6d 100644 --- a/axlearn/cloud/gcp/measurement.py +++ b/axlearn/cloud/gcp/measurement.py @@ -86,7 +86,8 @@ def record(self, event: measurement.Event, *args, **kwargs): ) def start_monitoring(self, *args, **kwargs): - """ + """Starts Monitoring of Goodput. + Instantiate ml-goodput-measurement's GoodputMonitor to asynchronously calculate Goodput and Badput at the upload_interval and upload to the specified TensorBoard directory. From 8d0c58d024c117239a5eebd522a39ef4649775a2 Mon Sep 17 00:00:00 2001 From: Dipannita Shaw Date: Mon, 2 Dec 2024 17:19:58 +0000 Subject: [PATCH 4/4] Remove recorder calls from trainer for now --- axlearn/common/trainer.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/axlearn/common/trainer.py b/axlearn/common/trainer.py index 4ffcb810b..a60560769 100644 --- a/axlearn/common/trainer.py +++ b/axlearn/common/trainer.py @@ -238,7 +238,6 @@ def __init__( utils.validate_float_dtype(cfg.train_dtype) # Create the device mesh. - self._maybe_record_event(measurement.Event.START_ACCELERATOR_INIT) if devices is None: self._step_log( "Devices: global=%s local=%s %s", @@ -325,7 +324,6 @@ def __init__( model=self.model, model_param_partition_specs=model_param_partition_specs, ) - self._maybe_record_event(measurement.Event.END_ACCELERATOR_INIT) @property def step(self): @@ -830,7 +828,6 @@ def _prepare_training(self, prng_key: Tensor) -> bool: # Attempt to restore the latest checkpoint, which may contain a saved `_input_iter`. self.restore_checkpoint(restore_step=None) - self._maybe_record_event(measurement.Event.START_TRAINING_PREPARATION) if self.step is None: # If we didn't restore from checkpoint, attempt to build initial state according # to `cfg.init_state_builder` and initialize the remaining parameters. 
@@ -850,7 +847,6 @@ def _prepare_training(self, prng_key: Tensor) -> bool:
             with fs.open(os.path.join(cfg.dir, "model_analysis.txt"), "w") as f:
                 f.write(model_analysis)
 
-        self._maybe_record_event(measurement.Event.END_TRAINING_PREPARATION)
         # Log config.
         self.summary_writer.log_config(cfg, step=self.step)
 
@@ -887,7 +883,6 @@ def restore_checkpoint(self, restore_step: Optional[int] = None) -> Optional[int
         restore_input_iter = cfg.save_input_iterator
         try:
             # Try to restore with `input_iter`.
-            self._maybe_record_event(measurement.Event.START_DATA_LOADING)
             step, ckpt_state = self.checkpointer.restore(
                 step=restore_step,
                 state=(
@@ -901,7 +896,6 @@ def restore_checkpoint(self, restore_step: Optional[int] = None) -> Optional[int
                 step,
                 restore_input_iter,
             )
-            self._maybe_record_event(measurement.Event.END_DATA_LOADING)
         except ValueError as e:
             logging.warning(
                 "Attempt to restore checkpoint with restore_input_iter=%s failed: %s",
@@ -909,7 +903,6 @@ def restore_checkpoint(self, restore_step: Optional[int] = None) -> Optional[int
                 e,
             )
             # Restore with a different restore_input_iter setting.
-            self._maybe_record_event(measurement.Event.START_DATA_LOADING)
             restore_input_iter = not restore_input_iter
             step, ckpt_state = self.checkpointer.restore(
                 step=restore_step,
@@ -924,7 +917,6 @@ def restore_checkpoint(self, restore_step: Optional[int] = None) -> Optional[int
                 step,
                 restore_input_iter,
             )
-            self._maybe_record_event(measurement.Event.END_DATA_LOADING)
         if step is not None:
             self._step = step
             self._trainer_state = TrainerState(
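
Usage note (illustrative; not part of the patches above): after PATCH 2/4, goodput monitoring hangs off the measurement recorder itself rather than a separate monitoring module. The sketch below shows how a launcher could wire this up. It assumes the GoodputRecorder is registered under the name `goodput` in `axlearn.cloud.gcp.measurement` (the registration decorator is not shown in these diffs); the flag values are hypothetical examples, and `upload_dir` would normally point at a TensorBoard-readable location.

    # minimal_launch.py -- a hypothetical launcher sketch.
    from absl import app, flags

    from axlearn.common import measurement


    def main(_):
        # Build the recorder from --recorder_type / --recorder_spec.
        measurement.initialize(flags.FLAGS)
        # Kick off the background goodput/badput uploader (added in PATCH 2/4).
        measurement.start_monitoring()
        # ... set up and run training here ...


    if __name__ == "__main__":
        measurement.define_flags()
        app.run(main)

Invoked, for example, as:

    python minimal_launch.py \
        --recorder_type=axlearn.cloud.gcp.measurement:goodput \
        --recorder_spec=name=my-job \
        --recorder_spec=upload_dir=gs://my-bucket/tb \
        --recorder_spec=upload_interval=30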