diff --git a/axlearn/common/evaler.py b/axlearn/common/evaler.py
index 391946f75..592a44eb4 100644
--- a/axlearn/common/evaler.py
+++ b/axlearn/common/evaler.py
@@ -31,6 +31,7 @@
 from axlearn.common.module import Module, OutputCollection
 from axlearn.common.module import functional as F
 from axlearn.common.utils import (
+    DataPartitionType,
     NestedPartitionSpec,
     NestedTensor,
     Tensor,
@@ -81,6 +82,11 @@ class Config(Module.Config):
         # evalers, not setting prefix will show the accuracies on the same plot for comparison
         # across evalers.
         prefix: Optional[str] = None
+        # Subset of mesh axis names over which the leaves of the input batch are sharded.
+        batch_axis_names: Union[str, Sequence[str]] = "data"
+        # The input partition type.
+        # Options: FULL (default), BATCH, REPLICATED.
+        input_partition_type: Optional[DataPartitionType] = DataPartitionType.FULL
 
     def __init__(
         self,
@@ -188,11 +194,11 @@ def _pjit(self, fn: Callable) -> Callable:
             in_shardings=(
                 self._model_param_partition_specs,  # model_params.
                 None,  # replicated_inputs (e.g., prng_key).
-                utils.input_partition_spec(),  # per_example_inputs.
+                utils.data_partition_type_to_spec(partition=self.config.input_partition_type, batch_axis_names=self.config.batch_axis_names),  # per_example_inputs.
             ),
             out_shardings=dict(
                 replicated=None,
-                per_example=utils.input_partition_spec(),
+                per_example=utils.data_partition_type_to_spec(partition=self.config.input_partition_type, batch_axis_names=self.config.batch_axis_names),
             ),
         )
 
@@ -574,6 +580,11 @@ class Config(Module.Config):
         metric_calculator: BaseMetricCalculator.Config = ModelSummaryAccumulator.default_config()
         # If not None, writes input batches and `metric_calculator` forward outputs.
         output_writer: Optional[BaseOutputWriter.Config] = None
+        # Subset of mesh axis names over which the leaves of the input batch are sharded.
+        batch_axis_names: Union[str, Sequence[str]] = "data"
+        # The input partition type.
+        # Options: FULL (default), BATCH, REPLICATED.
+        input_partition_type: Optional[DataPartitionType] = DataPartitionType.FULL
 
     def __init__(
         self,
@@ -595,7 +606,7 @@ def __init__(
         self._add_child("input", maybe_set_config(cfg.input, is_training=False))
         self._add_child(
             "metric_calculator",
-            cfg.metric_calculator.set(eval_dtype=cfg.eval_dtype),
+            cfg.metric_calculator.set(eval_dtype=cfg.eval_dtype, batch_axis_names=cfg.batch_axis_names, input_partition_type=cfg.input_partition_type),
             model=model,
             model_param_partition_specs=model_param_partition_specs,
         )
@@ -691,7 +702,7 @@ def eval_step(
 
         with jax.profiler.StepTraceAnnotation(cfg.name, step_num=step):
             with jax.profiler.TraceAnnotation(f"{cfg.name}.forward"):
-                global_input_batch = utils.host_to_global_device_array(input_batch)
+                global_input_batch = utils.host_to_global_device_array(input_batch, partition=self.config.input_partition_type, batch_axis_names=self.config.batch_axis_names)
                 forward_outputs = self.metric_calculator.forward(
                     global_input_batch,
                     model_params=model_params,
diff --git a/axlearn/common/gda_test.py b/axlearn/common/gda_test.py
index edf415517..0d98049d1 100644
--- a/axlearn/common/gda_test.py
+++ b/axlearn/common/gda_test.py
@@ -26,9 +26,9 @@ class GDATest(TestCase):
 
     @parameterized.parameters(
         itertools.product(
-            ((1, 1), (8, 1), (4, 2)),  # mesh_shape
+            ((1, 1), (8, 1), (4, 2), (16, 4)),  # mesh_shape
            (1, 16),  # per_host_batch_size
-            (DataPartitionType.FULL, DataPartitionType.REPLICATED),  # data_partition
+            (DataPartitionType.FULL, DataPartitionType.REPLICATED, DataPartitionType.BATCH),  # data_partition
        )
    )
    def test_host_array_to_gda(self, mesh_shape, per_host_batch_size, data_partition):
@@ -41,13 +41,16 @@ def test_host_array_to_gda(self, mesh_shape, per_host_batch_size, data_partition
         if not is_supported_mesh_shape(mesh_shape):
             return
         devices = mesh_utils.create_device_mesh(mesh_shape)
-        if data_partition == DataPartitionType.FULL:
+        if data_partition in (DataPartitionType.FULL, DataPartitionType.BATCH):
             global_batch_size = per_host_batch_size * jax.process_count()
         else:
             assert data_partition == DataPartitionType.REPLICATED
             global_batch_size = per_host_batch_size
         if data_partition == DataPartitionType.FULL and global_batch_size < jax.device_count():
             return
+        # The first mesh axis is assumed to be the batch axis; skip unevenly shardable batches.
+        if data_partition == DataPartitionType.BATCH and global_batch_size % mesh_shape[0] != 0:
+            return
         per_host_input_batch = dict(x=jnp.zeros((per_host_batch_size, 8), dtype=jnp.float32))
         with jax.sharding.Mesh(devices, ("data", "model")):
             global_input_batch = host_to_global_device_array(
diff --git a/axlearn/common/host_array_test.py b/axlearn/common/host_array_test.py
index 1a3adc417..f631816ee 100644
--- a/axlearn/common/host_array_test.py
+++ b/axlearn/common/host_array_test.py
@@ -21,7 +21,6 @@
     host_to_global_device_array,
 )
 
-
 def is_supported(
     platform: str,
     mesh_shape: tuple[int, int],
@@ -37,16 +36,15 @@ def is_supported(
         )
     )
 
-
 class HostArrayTest(TestCase):
     @parameterized.parameters(
         filter(
             lambda params: is_supported(*params),
             itertools.product(
                 ("cpu", "tpu"),  # platform,
-                ((1, 1), (4, 1), (2, 2), (8, 1), (4, 2)),  # mesh_shape
+                ((1, 1), (4, 1), (2, 2), (8, 1), (4, 2), (16, 4)),  # mesh_shape
                 (1, 16),  # global_batch_size
-                (DataPartitionType.FULL, DataPartitionType.REPLICATED),  # data_partition
+                (DataPartitionType.FULL, DataPartitionType.REPLICATED, DataPartitionType.BATCH),  # data_partition
             ),
         )
     )
diff --git a/axlearn/common/trainer.py b/axlearn/common/trainer.py
index a60560769..74fb4309b 100644
--- a/axlearn/common/trainer.py
+++ b/axlearn/common/trainer.py
@@ -47,6 +47,7 @@
 from axlearn.common.summary_writer import BaseWriter, SummaryWriter
 from axlearn.common.update_transformation import ForwardOutputs
 from axlearn.common.utils import (
+    DataPartitionType,
     HybridMeshShape,
     MeshShape,
     Nested,
@@ -200,6 +201,10 @@ class Config(Module.Config):
         # The provided config should instantiate to a thunk that returns the context manager.
         context_manager: Optional[ConfigOr[Callable[[], ContextManager]]] = None
 
+        # The input partition type.
+        # Options: FULL (default), BATCH, REPLICATED.
+        input_partition_type: Optional[DataPartitionType] = DataPartitionType.FULL
+
     def __init__(
         self,
         cfg: Config,
@@ -343,7 +348,7 @@ def trainer_state_partition_specs(self):
 
     def _train_step_input_partition_specs(self):
         # By default, each input tensor is fully partitioned along the batch axis.
-        return utils.input_partition_spec()
+        return utils.data_partition_type_to_spec(self.config.input_partition_type, batch_axis_names=self.config.batch_axis_names)
 
     def model_params_for_eval(self):
         state = self.trainer_state
@@ -568,7 +573,7 @@ def run(
                 self._step = self._step + 1
                 self.vlog(3, "Start step %s", self.step)
                 output = self._run_step(
-                    utils.host_to_global_device_array(input_batch),
+                    utils.host_to_global_device_array(input_batch, partition=self.config.input_partition_type, batch_axis_names=self.config.batch_axis_names),
                     force_run_evals=(
                         force_run_eval_sets_at_max_step if self.step >= cfg.max_step else None
                     ),
diff --git a/axlearn/common/utils.py b/axlearn/common/utils.py
index 1401337f8..d1b30ef51 100644
--- a/axlearn/common/utils.py
+++ b/axlearn/common/utils.py
@@ -591,14 +591,17 @@ class DataPartitionType(Enum):
     FULL = "full"
     # Data are fully replicated across all devices.
     REPLICATED = "replicated"
+    # Data are partitioned across the batch axis only.
+    BATCH = "batch"
 
-
-def data_partition_type_to_spec(partition: DataPartitionType) -> PartitionSpec:
+def data_partition_type_to_spec(partition: DataPartitionType, *, batch_axis_names: Union[str, Sequence[str]] = ("data", "fsdp")) -> PartitionSpec:
     """Returns a PartitionSpec for the given partition type."""
     if partition == DataPartitionType.FULL:
         return input_partition_spec()
     elif partition == DataPartitionType.REPLICATED:
         return None
+    elif partition == DataPartitionType.BATCH:
+        return PartitionSpec(batch_axis_names)
     else:
         raise NotImplementedError(f"Unsupported partition: {partition}")
 
@@ -607,6 +610,7 @@ def host_to_global_device_array(
     host_arrays: Nested[Union[np.ndarray, Tensor]],
     *,
     partition: DataPartitionType = DataPartitionType.FULL,
+    batch_axis_names: Union[str, Sequence[str]] = ("data", "fsdp"),
 ) -> NestedTensor:
     """Converts the given host device arrays to global device arrays.
 
@@ -625,7 +629,7 @@
         NotImplementedError: if the given `partition` type is not supported.
""" mesh = thread_resources.env.physical_mesh - partition_spec = data_partition_type_to_spec(partition) + partition_spec = data_partition_type_to_spec(partition, batch_axis_names=batch_axis_names) partition_specs = complete_partition_spec_tree( jax.tree_util.tree_structure(host_arrays), partition_spec ) @@ -636,6 +640,8 @@ def make_gda(x, partition_spec): global_shape = (x.shape[0] * process_count, *x.shape[1:]) elif partition == DataPartitionType.REPLICATED: global_shape = (x.shape[0], *x.shape[1:]) + elif partition == DataPartitionType.BATCH: + global_shape = (x.shape[0] * process_count, *x.shape[1:]) else: raise NotImplementedError(f"Unsupported partition: {partition}") return jax.make_array_from_process_local_data( diff --git a/axlearn/common/utils_test.py b/axlearn/common/utils_test.py index f4c06b47a..374d85e66 100644 --- a/axlearn/common/utils_test.py +++ b/axlearn/common/utils_test.py @@ -42,6 +42,7 @@ ) from axlearn.common.trainer import SpmdTrainer from axlearn.common.utils import ( + DataPartitionType, PHYSICAL_TO_LOGICAL_DISPATCH_KEY, HybridMeshShape, MeshShape, @@ -1701,6 +1702,31 @@ def test_length(self): class HostToGlobalArrayTest(TestCase): """Tests host_to_global_device_array.""" + @pytest.mark.neuron + def test_partition_batch(self): + """Test a case where each process produces a slice.""" + device_count = jax.device_count() + process_count = jax.process_count() + print(f"{device_count=}, {process_count=}") + assert device_count > 1 + + global_shape = (device_count // 2, 1) + assert global_shape[0] % process_count == 0 + per_feed_size = global_shape[0] // process_count + feed_index = jax.process_index() + + with jax.sharding.Mesh(np.array(jax.devices()).reshape(device_count // 2, 2), ("x", "y")): + start = feed_index * per_feed_size + local_x = jnp.arange(start, start + per_feed_size)[:, None] + + # Construct global array. + global_x = host_to_global_device_array(local_x, partition=DataPartitionType.BATCH, batch_axis_names="x") + + # Compare against expected. 
+            expected = jnp.arange(global_shape[0])[:, None]
+            self.assertEqual(jnp.mean(expected), jnp.mean(global_x))
+            self.assertNestedEqual(expected, replicate_to_local_data(global_x))
+
     @pytest.mark.tpu
     def test_partition_full(self):
         """Test a case where each process produces a slice."""
diff --git a/axlearn/experiments/text/gpt/common.py b/axlearn/experiments/text/gpt/common.py
index d32120d25..015dffbcf 100644
--- a/axlearn/experiments/text/gpt/common.py
+++ b/axlearn/experiments/text/gpt/common.py
@@ -57,7 +57,7 @@
 from axlearn.common.param_init import PARAM_REGEXP_WEIGHT, DefaultInitializer, WeightInitializer
 from axlearn.common.summary_writer import BaseWriter
 from axlearn.common.trainer import MeshShape, SpmdTrainer
-from axlearn.common.utils import HybridMeshShape, Nested, get_data_dir
+from axlearn.common.utils import DataPartitionType, HybridMeshShape, Nested, get_data_dir
 from axlearn.experiments.text.common import DataMixtureComponent, tfds_text_source
 from axlearn.experiments.trainer_config_utils import TrainerConfigFn
 
@@ -640,6 +640,7 @@ def get_trainer_config_fn(
     mesh_shape: Union[MeshShape, HybridMeshShape],
     mesh_axis_names: Sequence[str] = MESH_AXIS_NAMES,
     mesh_rules: Optional[Sequence[tuple[str, Optional[Union[MeshShape, HybridMeshShape]]]]] = None,
+    input_partition_type: Optional[DataPartitionType] = None,
     eval_every_n_steps: int = 5000,
     eval_batch_size: Optional[int] = None,
     keep_every_n_steps: int = 50_000,
@@ -689,9 +690,27 @@ def config_fn() -> InstantiableConfig:
                 pad_example_fn=input_tf_data.default_pad_example_fn,
             ),
         )
+        if input_partition_type:
+            cfg.input_partition_type = input_partition_type
+        if len(mesh_axis_names) != len(mesh_shape):
+            raise ValueError(
+                f"Number of mesh axis names ({mesh_axis_names}) "
+                f"must match number of mesh dims ({mesh_shape})."
+            )
+        cfg.mesh_axis_names = mesh_axis_names
+        cfg.mesh_shape = mesh_shape
+        # Set batch sharding spec to exclude the "model" axis (assumed for tensor-parallelism) and
+        # "pipeline" axis (for pipeline parallelism).
+        cfg.batch_axis_names = tuple(
+            el for el in mesh_axis_names if el not in ("model", "pipeline")
+        )
+        cfg.mesh_rules = mesh_rules
         cfg.evalers = {}
         for name, evaler_cfg in evalers.items():
             evaler_cfg.input.batcher.set(global_batch_size=eval_batch_size or train_batch_size)
+            if input_partition_type:
+                evaler_cfg.set(input_partition_type=input_partition_type)
+                evaler_cfg.set(batch_axis_names=cfg.batch_axis_names)
             evaler_cfg.set(
                 eval_policy=config_for_function(eval_every_n_steps_policy).set(
                     n=eval_every_n_steps,
@@ -708,19 +727,6 @@
         cfg.checkpointer.keep_last_n = 3
         cfg.summary_writer.write_every_n_steps = min(eval_every_n_steps, 100)
         cfg.summary_writer.max_queue = 1000
-        if len(mesh_axis_names) != len(mesh_shape):
-            raise ValueError(
-                f"Number of mesh axis names ({mesh_axis_names}) "
-                f"must match number of mesh dims ({mesh_shape})."
-            )
-        cfg.mesh_axis_names = mesh_axis_names
-        cfg.mesh_shape = mesh_shape
-        # Set batch sharding spec to exclude the "model" axis (assumed for tensor-parallelism) and
-        # "pipeline" axis (for pipeline parallelism).
-        cfg.batch_axis_names = tuple(
-            el for el in mesh_axis_names if el not in ("model", "pipeline")
-        )
-        cfg.mesh_rules = mesh_rules
         # Maybe load state.
         if init_state_builder:
             cfg.init_state_builder = init_state_builder
diff --git a/axlearn/experiments/text/gpt/fuji.py b/axlearn/experiments/text/gpt/fuji.py
index 69f6b1102..302ab8956 100644
--- a/axlearn/experiments/text/gpt/fuji.py
+++ b/axlearn/experiments/text/gpt/fuji.py
@@ -39,7 +39,7 @@
     MeshShapeModifier,
     RematSpecModifier,
 )
-from axlearn.common.utils import extended_checkpoint_policies
+from axlearn.common.utils import DataPartitionType, extended_checkpoint_policies
 from axlearn.experiments.text.gpt.common import (
     STEP_DTYPE,
     SourceBuilder,
@@ -423,6 +423,7 @@ def get_trainer_kwargs(
         raise NotImplementedError(f"Unknown model size {model_size}.")
     model_kwargs = trainer_kwargs.pop("model_kwargs")
     model_kwargs.setdefault("vocab_size", vocab_size)
+    trainer_kwargs["input_partition_type"] = DataPartitionType.BATCH if backend == "neuron" else None
    trainer_kwargs["model_cfg"] = model_config(**model_kwargs)
    trainer_kwargs["learner_cfg"] = adamw_decoupled_learner_config(
        max_step=trainer_kwargs["max_step"],
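
Usage sketch (not part of the patch): the snippet below illustrates how the new DataPartitionType.BATCH option and the batch_axis_names keyword added above are expected to behave. It is a minimal sketch assuming a single-process run whose device count evenly divides the per-host batch; the (device_count, 1) mesh, the ("data", "model") axis names, and the batch sizes are illustrative only.

# Illustrative sketch; assumes a single-process run with >= 2 devices and the
# patched axlearn utilities above. Mesh shape and axis names are examples only.
import jax
import numpy as np
from axlearn.common.utils import (
    DataPartitionType,
    data_partition_type_to_spec,
    host_to_global_device_array,
)

devices = np.array(jax.devices()).reshape(-1, 1)  # ("data", "model") mesh.
with jax.sharding.Mesh(devices, ("data", "model")):
    # BATCH maps only the leading (batch) dimension onto `batch_axis_names`;
    # remaining dimensions stay unsharded.
    spec = data_partition_type_to_spec(DataPartitionType.BATCH, batch_axis_names="data")
    assert spec == jax.sharding.PartitionSpec("data")

    # Each process contributes its local slice; the global batch dimension becomes
    # process_count * local batch size, sharded along the "data" axis.
    local_batch = dict(x=np.zeros((2 * jax.device_count(), 16), dtype=np.float32))
    global_batch = host_to_global_device_array(
        local_batch, partition=DataPartitionType.BATCH, batch_axis_names="data"
    )
    print(global_batch["x"].shape, global_batch["x"].sharding)

Compared with FULL (which shards the batch over every mesh axis) and REPLICATED (no sharding), BATCH keeps the batch split over the named data axes only, which is the behavior the trainer and evaler configs above opt into for the "neuron" backend.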