Input batch sharding strategy BATCH #884

Open

wants to merge 1 commit into base: main
19 changes: 15 additions & 4 deletions axlearn/common/evaler.py
@@ -31,6 +31,7 @@
from axlearn.common.module import Module, OutputCollection
from axlearn.common.module import functional as F
from axlearn.common.utils import (
DataPartitionType,
NestedPartitionSpec,
NestedTensor,
Tensor,
@@ -81,6 +82,11 @@ class Config(Module.Config):
# evalers, not setting prefix will show the accuracies on the same plot for comparison
# across evalers.
prefix: Optional[str] = None
# Subset of mesh axis names over which the leaves of the input batch are sharded.
batch_axis_names: Union[str, Sequence[str]] = "data"
# The input batch partition type. Options: FULL (default), BATCH, REPLICATED.
input_partition_type: Optional[DataPartitionType] = DataPartitionType.FULL

def __init__(
self,
@@ -188,11 +194,11 @@ def _pjit(self, fn: Callable) -> Callable:
in_shardings=(
self._model_param_partition_specs, # model_params.
None, # replicated_inputs (e.g., prng_key).
utils.input_partition_spec(), # per_example_inputs.
utils.data_partition_type_to_spec(
    partition=self.config.input_partition_type,
    batch_axis_names=self.config.batch_axis_names,
),  # per_example_inputs.
),
out_shardings=dict(
replicated=None,
per_example=utils.input_partition_spec(),
per_example=utils.data_partition_type_to_spec(
    partition=self.config.input_partition_type,
    batch_axis_names=self.config.batch_axis_names,
),
),
)

@@ -574,6 +580,11 @@ class Config(Module.Config):
metric_calculator: BaseMetricCalculator.Config = ModelSummaryAccumulator.default_config()
# If not None, writes input batches and `metric_calculator` forward outputs.
output_writer: Optional[BaseOutputWriter.Config] = None
# Subset of mesh axis names over which the leaves of the input batch are sharded.
batch_axis_names: Union[str, Sequence[str]] = "data"
# The input batch partition type. Options: FULL (default), BATCH, REPLICATED.
input_partition_type: Optional[DataPartitionType] = DataPartitionType.FULL

def __init__(
self,
@@ -595,7 +606,7 @@ def __init__(
self._add_child("input", maybe_set_config(cfg.input, is_training=False))
self._add_child(
"metric_calculator",
cfg.metric_calculator.set(eval_dtype=cfg.eval_dtype),
cfg.metric_calculator.set(
    eval_dtype=cfg.eval_dtype,
    batch_axis_names=cfg.batch_axis_names,
    input_partition_type=cfg.input_partition_type,
),
model=model,
model_param_partition_specs=model_param_partition_specs,
)
@@ -691,7 +702,7 @@ def eval_step(

with jax.profiler.StepTraceAnnotation(cfg.name, step_num=step):
with jax.profiler.TraceAnnotation(f"{cfg.name}.forward"):
global_input_batch = utils.host_to_global_device_array(input_batch)
global_input_batch = utils.host_to_global_device_array(
    input_batch,
    partition=self.config.input_partition_type,
    batch_axis_names=self.config.batch_axis_names,
)
forward_outputs = self.metric_calculator.forward(
global_input_batch,
model_params=model_params,
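A usage sketch (not part of the diff) of the two new evaler config fields; SpmdEvaler is assumed to be the evaler class whose Config is extended above, and the values shown are purely illustrative:

from axlearn.common.evaler import SpmdEvaler
from axlearn.common.utils import DataPartitionType

# Shard evaler input batches along the batch dimension only, over the "data" mesh axis.
evaler_cfg = SpmdEvaler.default_config().set(
    input_partition_type=DataPartitionType.BATCH,  # default: DataPartitionType.FULL
    batch_axis_names="data",
)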
9 changes: 6 additions & 3 deletions axlearn/common/gda_test.py
@@ -26,9 +26,9 @@
class GDATest(TestCase):
@parameterized.parameters(
itertools.product(
((1, 1), (8, 1), (4, 2)), # mesh_shape
((1, 1), (8, 1), (4, 2), (16, 4)), # mesh_shape
(1, 16), # per_host_batch_size
(DataPartitionType.FULL, DataPartitionType.REPLICATED), # data_partition
(DataPartitionType.FULL, DataPartitionType.REPLICATED, DataPartitionType.BATCH), # data_partition
)
)
def test_host_array_to_gda(self, mesh_shape, per_host_batch_size, data_partition):
@@ -41,13 +41,16 @@ def test_host_array_to_gda(self, mesh_shape, per_host_batch_size, data_partition
if not is_supported_mesh_shape(mesh_shape):
return
devices = mesh_utils.create_device_mesh(mesh_shape)
if data_partition == DataPartitionType.FULL:
if data_partition in (DataPartitionType.FULL, DataPartitionType.BATCH):
global_batch_size = per_host_batch_size * jax.process_count()
else:
assert data_partition == DataPartitionType.REPLICATED
global_batch_size = per_host_batch_size
if data_partition == DataPartitionType.FULL and global_batch_size < jax.device_count():
return
# The first mesh axis is assumed to be the batch axis.
if data_partition == DataPartitionType.BATCH and global_batch_size % mesh_shape[0] != 0:
    return
per_host_input_batch = dict(x=jnp.zeros((per_host_batch_size, 8), dtype=jnp.float32))
with jax.sharding.Mesh(devices, ("data", "model")):
global_input_batch = host_to_global_device_array(
6 changes: 2 additions & 4 deletions axlearn/common/host_array_test.py
@@ -21,7 +21,6 @@
host_to_global_device_array,
)


def is_supported(
platform: str,
mesh_shape: tuple[int, int],
@@ -37,16 +36,15 @@ def is_supported(
)
)


class HostArrayTest(TestCase):
@parameterized.parameters(
filter(
lambda params: is_supported(*params),
itertools.product(
("cpu", "tpu"), # platform,
((1, 1), (4, 1), (2, 2), (8, 1), (4, 2)), # mesh_shape
((1, 1), (4, 1), (2, 2), (8, 1), (4, 2), (16, 4)), # mesh_shape
(1, 16), # global_batch_size
(DataPartitionType.FULL, DataPartitionType.REPLICATED), # data_partition
(DataPartitionType.FULL, DataPartitionType.REPLICATED, DataPartitionType.BATCH),  # data_partition
),
)
)
9 changes: 7 additions & 2 deletions axlearn/common/trainer.py
@@ -47,6 +47,7 @@
from axlearn.common.summary_writer import BaseWriter, SummaryWriter
from axlearn.common.update_transformation import ForwardOutputs
from axlearn.common.utils import (
DataPartitionType,
HybridMeshShape,
MeshShape,
Nested,
@@ -200,6 +201,10 @@ class Config(Module.Config):
# The provided config should instantiate to a thunk that returns the context manager.
context_manager: Optional[ConfigOr[Callable[[], ContextManager]]] = None

# The input batch partition type. Options: FULL (default), BATCH, REPLICATED.
input_partition_type: Optional[DataPartitionType] = DataPartitionType.FULL

def __init__(
self,
cfg: Config,
@@ -343,7 +348,7 @@ def trainer_state_partition_specs(self):

def _train_step_input_partition_specs(self):
# By default, each input tensor is fully partitioned along the batch axis.
return utils.input_partition_spec()
return utils.data_partition_type_to_spec(
    self.config.input_partition_type, batch_axis_names=self.config.batch_axis_names
)

def model_params_for_eval(self):
state = self.trainer_state
@@ -568,7 +573,7 @@ def run(
self._step = self._step + 1
self.vlog(3, "Start step %s", self.step)
output = self._run_step(
utils.host_to_global_device_array(input_batch),
utils.host_to_global_device_array(
    input_batch,
    partition=self.config.input_partition_type,
    batch_axis_names=self.config.batch_axis_names,
),
force_run_evals=(
force_run_eval_sets_at_max_step if self.step >= cfg.max_step else None
),
12 changes: 9 additions & 3 deletions axlearn/common/utils.py
@@ -591,14 +591,17 @@ class DataPartitionType(Enum):
FULL = "full"
# Data are fully replicated across all devices.
REPLICATED = "replicated"
# Data are partitioned across the batch axis only.
BATCH = "batch"


def data_partition_type_to_spec(partition: DataPartitionType) -> PartitionSpec:
def data_partition_type_to_spec(
    partition: DataPartitionType,
    *,
    batch_axis_names: Union[str, Sequence[str]] = ("data", "fsdp"),
) -> PartitionSpec:
"""Returns a PartitionSpec for the given partition type."""
if partition == DataPartitionType.FULL:
return input_partition_spec()
elif partition == DataPartitionType.REPLICATED:
return None
elif partition == DataPartitionType.BATCH:
return PartitionSpec(batch_axis_names)
else:
raise NotImplementedError(f"Unsupported partition: {partition}")

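A rough illustration (not part of the diff) of what each partition type maps to. The assertions assume the batch_axis_names shown; the FULL case is described in a comment only, since input_partition_spec() needs an active mesh:

from jax.sharding import PartitionSpec

from axlearn.common.utils import DataPartitionType, data_partition_type_to_spec

# REPLICATED -> no sharding constraint; every device holds the full batch.
assert data_partition_type_to_spec(DataPartitionType.REPLICATED) is None

# BATCH -> shard only the leading (batch) dim over the named axes; other dims stay replicated.
assert data_partition_type_to_spec(
    DataPartitionType.BATCH, batch_axis_names=("data", "fsdp")
) == PartitionSpec(("data", "fsdp"))

# FULL -> delegates to input_partition_spec(), which shards the batch dim over all axes of the
# currently active mesh, so it must be evaluated inside a `with Mesh(...)` context.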
@@ -607,6 +610,7 @@ def host_to_global_device_array(
host_arrays: Nested[Union[np.ndarray, Tensor]],
*,
partition: DataPartitionType = DataPartitionType.FULL,
Review comment (Contributor): I think @markblee plans to remove the DataPartitionType enum and rely on https://jax.readthedocs.io/en/latest/_autosummary/jax.make_array_from_process_local_data.html to support flexible partition specs.

Reply (Author): Thanks, that sounds promising.

Hello @markblee, let me know whether this PR is needed until you make your changes; or, if you already have a design in mind, I can reshape the PR to be compatible with it.
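For reference, a minimal sketch of the direction suggested above: build the global array directly from an arbitrary PartitionSpec with jax.make_array_from_process_local_data, with no DataPartitionType enum. The helper name and the batch-sharded global-shape assumption are illustrative, not part of this PR or of axlearn:

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec


def host_to_global(local_batch: np.ndarray, mesh: Mesh, spec: PartitionSpec) -> jax.Array:
    """Assembles per-process slices into a global array laid out per `spec`.

    Assumes the leading (batch) dim is sharded across processes, so the global batch size is
    the sum of the per-process batch sizes.
    """
    sharding = NamedSharding(mesh, spec)
    global_shape = (local_batch.shape[0] * jax.process_count(), *local_batch.shape[1:])
    return jax.make_array_from_process_local_data(sharding, local_batch, global_shape)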

batch_axis_names: Union[str, Sequence[str]] = ("data", "fsdp"),
) -> NestedTensor:
"""Converts the given host device arrays to global device arrays.

@@ -625,7 +629,7 @@
NotImplementedError: if the given `partition` type is not supported.
"""
mesh = thread_resources.env.physical_mesh
partition_spec = data_partition_type_to_spec(partition)
partition_spec = data_partition_type_to_spec(partition, batch_axis_names=batch_axis_names)
partition_specs = complete_partition_spec_tree(
jax.tree_util.tree_structure(host_arrays), partition_spec
)
@@ -636,6 +640,8 @@ def make_gda(x, partition_spec):
global_shape = (x.shape[0] * process_count, *x.shape[1:])
elif partition == DataPartitionType.REPLICATED:
global_shape = (x.shape[0], *x.shape[1:])
elif partition == DataPartitionType.BATCH:
global_shape = (x.shape[0] * process_count, *x.shape[1:])
else:
raise NotImplementedError(f"Unsupported partition: {partition}")
return jax.make_array_from_process_local_data(
26 changes: 26 additions & 0 deletions axlearn/common/utils_test.py
@@ -42,6 +42,7 @@
)
from axlearn.common.trainer import SpmdTrainer
from axlearn.common.utils import (
DataPartitionType,
PHYSICAL_TO_LOGICAL_DISPATCH_KEY,
HybridMeshShape,
MeshShape,
@@ -1701,6 +1702,31 @@ def test_length(self):
class HostToGlobalArrayTest(TestCase):
"""Tests host_to_global_device_array."""

@pytest.mark.neuron
def test_partition_batch(self):
"""Test a case where each process produces a slice."""
device_count = jax.device_count()
process_count = jax.process_count()
print(f"{device_count=}, {process_count=}")
assert device_count > 1

global_shape = (device_count // 2, 1)
assert global_shape[0] % process_count == 0
per_feed_size = global_shape[0] // process_count
feed_index = jax.process_index()

with jax.sharding.Mesh(np.array(jax.devices()).reshape(device_count // 2, 2), ("x", "y")):
start = feed_index * per_feed_size
local_x = jnp.arange(start, start + per_feed_size)[:, None]

# Construct global array.
global_x = host_to_global_device_array(
    local_x, partition=DataPartitionType.BATCH, batch_axis_names="x"
)

# Compare against expected.
expected = jnp.arange(global_shape[0])[:, None]
self.assertEqual(jnp.mean(expected), jnp.mean(global_x))
self.assertNestedEqual(expected, replicate_to_local_data(global_x))

@pytest.mark.tpu
def test_partition_full(self):
"""Test a case where each process produces a slice."""
34 changes: 20 additions & 14 deletions axlearn/experiments/text/gpt/common.py
@@ -57,7 +57,7 @@
from axlearn.common.param_init import PARAM_REGEXP_WEIGHT, DefaultInitializer, WeightInitializer
from axlearn.common.summary_writer import BaseWriter
from axlearn.common.trainer import MeshShape, SpmdTrainer
from axlearn.common.utils import HybridMeshShape, Nested, get_data_dir
from axlearn.common.utils import DataPartitionType, HybridMeshShape, Nested, get_data_dir
from axlearn.experiments.text.common import DataMixtureComponent, tfds_text_source
from axlearn.experiments.trainer_config_utils import TrainerConfigFn

@@ -640,6 +640,7 @@ def get_trainer_config_fn(
mesh_shape: Union[MeshShape, HybridMeshShape],
mesh_axis_names: Sequence[str] = MESH_AXIS_NAMES,
mesh_rules: Optional[Sequence[tuple[str, Optional[Union[MeshShape, HybridMeshShape]]]]] = None,
input_partition_type: Optional[DataPartitionType] = None,
eval_every_n_steps: int = 5000,
eval_batch_size: Optional[int] = None,
keep_every_n_steps: int = 50_000,
@@ -689,9 +690,27 @@ def config_fn() -> InstantiableConfig:
pad_example_fn=input_tf_data.default_pad_example_fn,
),
)
if input_partition_type:
cfg.input_partition_type = input_partition_type
if len(mesh_axis_names) != len(mesh_shape):
raise ValueError(
f"Number of mesh axis names ({mesh_axis_names}) "
f"must match number of mesh dims ({mesh_shape})."
)
cfg.mesh_axis_names = mesh_axis_names
cfg.mesh_shape = mesh_shape
# Set batch sharding spec to exclude the "model" axis (assumed for tensor-parallelism) and
# "pipeline" axis (for pipeline parallelism).
cfg.batch_axis_names = tuple(
el for el in mesh_axis_names if el not in ("model", "pipeline")
)
cfg.mesh_rules = mesh_rules
cfg.evalers = {}
for name, evaler_cfg in evalers.items():
evaler_cfg.input.batcher.set(global_batch_size=eval_batch_size or train_batch_size)
if input_partition_type:
evaler_cfg.set(input_partition_type=input_partition_type)
evaler_cfg.set(batch_axis_names=cfg.batch_axis_names)
evaler_cfg.set(
eval_policy=config_for_function(eval_every_n_steps_policy).set(
n=eval_every_n_steps,
@@ -708,19 +727,6 @@
cfg.checkpointer.keep_last_n = 3
cfg.summary_writer.write_every_n_steps = min(eval_every_n_steps, 100)
cfg.summary_writer.max_queue = 1000
if len(mesh_axis_names) != len(mesh_shape):
raise ValueError(
f"Number of mesh axis names ({mesh_axis_names}) "
f"must match number of mesh dims ({mesh_shape})."
)
cfg.mesh_axis_names = mesh_axis_names
cfg.mesh_shape = mesh_shape
# Set batch sharding spec to exclude the "model" axis (assumed for tensor-parallelism) and
# "pipeline" axis (for pipeline parallelism).
cfg.batch_axis_names = tuple(
el for el in mesh_axis_names if el not in ("model", "pipeline")
)
cfg.mesh_rules = mesh_rules
# Maybe load state.
if init_state_builder:
cfg.init_state_builder = init_state_builder
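To make the batch-axis computation in config_fn concrete, a small illustration; the exact MESH_AXIS_NAMES tuple is an assumption for this sketch, not part of the diff:

# Assumed default mesh axis names, for illustration only.
MESH_AXIS_NAMES = ("pipeline", "data", "expert", "fsdp", "seq", "model")

# Exclude the tensor-parallel ("model") and pipeline-parallel ("pipeline") axes.
batch_axis_names = tuple(el for el in MESH_AXIS_NAMES if el not in ("model", "pipeline"))
assert batch_axis_names == ("data", "expert", "fsdp", "seq")

# With input_partition_type=BATCH, input batches are then sharded with
# PartitionSpec(("data", "expert", "fsdp", "seq")) over the leading (batch) dim only.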
3 changes: 2 additions & 1 deletion axlearn/experiments/text/gpt/fuji.py
@@ -39,7 +39,7 @@
MeshShapeModifier,
RematSpecModifier,
)
from axlearn.common.utils import extended_checkpoint_policies
from axlearn.common.utils import DataPartitionType, extended_checkpoint_policies
from axlearn.experiments.text.gpt.common import (
STEP_DTYPE,
SourceBuilder,
@@ -423,6 +423,7 @@ def get_trainer_kwargs(
raise NotImplementedError(f"Unknown model size {model_size}.")
model_kwargs = trainer_kwargs.pop("model_kwargs")
model_kwargs.setdefault("vocab_size", vocab_size)
trainer_kwargs["input_partition_type"] = None if backend != "neuron" else DataPartitionType.BATCH
trainer_kwargs["model_cfg"] = model_config(**model_kwargs)
trainer_kwargs["learner_cfg"] = adamw_decoupled_learner_config(
max_step=trainer_kwargs["max_step"],