Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Special remat for Neuron #898

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 18 additions & 3 deletions axlearn/common/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -2508,23 +2508,27 @@ def attention_thunk(target: Tensor) -> tuple[Optional[NestedTensor], Tensor]:
atten_state, atten_output = attention_thunk(TensorSpec(target.shape, target.dtype))
return dict(attention=atten_state), atten_output

remat_pt1 = "attention_output"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this really needed? How much perf gain we get from saving attention_output?

if cfg.structure == "prenorm":
skip_input = target # pre-norm: where normalization happens within the residual part.
norm_target = self.norm(target)
atten_state, atten_output = attention_thunk(norm_target)
data = skip_input + self.stochastic_depth(self.dropout(atten_output.data))
data = self._remat_name(data, remat_pt1)
elif cfg.structure == "postnorm":
# This is the structure used by the original Transformer, BERT, and RoBERTa.
atten_state, atten_output = attention_thunk(target)
# Post-norm: norm applied on the sum of input and attention output.
data = self.norm(target + self.stochastic_depth(self.dropout(atten_output.data)))
data = self._remat_name(data, remat_pt1)
elif cfg.structure == "hybridnorm":
skip_input = target # pre-norm: where normalization happens within the residual part.
norm_target = self.prenorm(target)
atten_state, atten_output = attention_thunk(norm_target)
data = skip_input + self.stochastic_depth(
self.dropout(self.postnorm(atten_output.data))
)
data = self._remat_name(data, remat_pt1)
else:
raise NotImplementedError(cfg.structure)
return dict(attention=atten_state), self.Output(
Expand Down Expand Up @@ -2801,6 +2805,7 @@ def _linear2(x):
self._add_tensor_stats("inputs", inputs)

remat_pt2 = "linear2"
remat_pt3 = "feed_forward_output"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmm, is this necessary? Seems quite wasteful to checkpoint an activation after a norm? I would recommend restraining from making change here.

if cfg.structure == "prenorm":
x = self.norm(inputs)
x = self._linear1_activation(x)
Expand All @@ -2812,6 +2817,7 @@ def _linear2(x):
if cfg.residual_weight != 1:
x *= cfg.residual_weight
x += inputs
x = self._remat_name(x, remat_pt3)
elif cfg.structure == "postnorm":
x = self._linear1_activation(inputs)
x = _linear2(x)
Expand All @@ -2821,6 +2827,7 @@ def _linear2(x):
if cfg.residual_weight != 1:
x *= cfg.residual_weight
x = self.norm(x + inputs)
x = self._remat_name(x, remat_pt3)
elif cfg.structure == "hybridnorm":
x = self.prenorm(inputs)
x = self._linear1_activation(x)
Expand All @@ -2833,6 +2840,7 @@ def _linear2(x):
if cfg.residual_weight != 1:
x *= cfg.residual_weight
x += inputs
x = self._remat_name(x, remat_pt3)
elif cfg.structure == "nonorm":
x = inputs
x = self._linear1_activation(x)
Expand All @@ -2845,6 +2853,7 @@ def _linear2(x):
# this layer, e.g., in ParallelTransformerLayer.
if cfg.residual_weight != 1:
x *= cfg.residual_weight
x = self._remat_name(x, remat_pt3)
else:
raise NotImplementedError(cfg.structure)
return x
Expand Down Expand Up @@ -3956,15 +3965,21 @@ def policy(prim, *_, **params):
return policy


SELF_ATTENTION_SAVE_PATTERN = ".*([qkvo]_proj|context)"
FEED_FORWARD_SAVE_PATTERN = ".*linear[12]_.*"
# Regex patterns for matching remat names
class RematRegexSavePatterns(enum.Enum):
QKV_PROJ = r".*\.?(k|q|v)_proj"
LINEAR1_X = r".*\.?linear1_[01]"
ATTENTION_OUTPUT = r"TransformerAttentionLayer\.attention_output"
FEED_FORWARD_OUTPUT = r"TransformerFeedForwardLayer\.feed_forward_output"
SELF_ATTENTION = ".*([qkvo]_proj|context)"
FEED_FORWARD = ".*linear[12]_.*"


def build_remat_spec(
stack_cfg: Union[
BaseStackedTransformerLayer.Config, "RepeatedConformerLayer.Config" # type: ignore
],
save_pattern: _SavePattern = SELF_ATTENTION_SAVE_PATTERN,
save_pattern: _SavePattern = RematRegexSavePatterns.SELF_ATTENTION.value,
offload_pattern: _SavePattern = None,
offload_dst: str = "pinned_host",
) -> Optional[RematSpec]:
Expand Down
70 changes: 68 additions & 2 deletions axlearn/common/attention_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@

from axlearn.common import attention, attention_bias, test_utils, utils
from axlearn.common.attention import (
FEED_FORWARD_SAVE_PATTERN,
BaseStackedTransformerLayer,
BaseTransformerLayer,
BottleNeckAdapterTransformerLayer,
Expand All @@ -58,6 +57,7 @@
PipelinedTransformerLayer,
QKVLinear,
QLinear,
RematRegexSavePatterns,
RepeatedTransformerLayer,
RoFormerQKVLinear,
StackedTransformerLayer,
Expand Down Expand Up @@ -3420,7 +3420,7 @@ def f(x, layer_params):
jax.remat(
f,
policy=_save_and_offload_only_these_names_regex(
names_which_can_be_saved=FEED_FORWARD_SAVE_PATTERN,
names_which_can_be_saved=RematRegexSavePatterns.FEED_FORWARD.value,
names_which_can_be_offloaded=None,
offload_src="device",
offload_dst="pinned_host",
Expand Down Expand Up @@ -3875,6 +3875,72 @@ def f(x, layer_params):
5,
)

def test_build_remat_spec_neuron(self):
model_dim, num_heads = 6, 2
cfg: TransformerLayer.Config = TransformerLayer.default_config().set(input_dim=model_dim)
cfg.self_attention.attention.set(num_heads=num_heads, causal=True)
cfg.feed_forward.hidden_dim = model_dim * 4
cfg.vlog = 5

layer: BaseTransformerLayer = cfg.clone(name="layer").instantiate(parent=None)
layer_params = layer.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(0))

batch_size, tgt_len = 2, 5
rng = np.random.default_rng(seed=123)
target = rng.random([batch_size, tgt_len, cfg.input_dim], dtype=np.float32)

def f(x, layer_params):
forward_outputs, _ = F(
layer,
inputs=dict(
data=x,
),
state=layer_params,
is_training=True,
prng_key=jax.random.PRNGKey(0),
)
return forward_outputs

# Ignore type errors.
spec: Any = build_remat_spec(mock.MagicMock())

policy = (
config_for_function(_save_and_offload_only_these_names_regex)
.set(
names_which_can_be_saved="|".join(
[
RematRegexSavePatterns.QKV_PROJ.value,
RematRegexSavePatterns.LINEAR1_X.value,
RematRegexSavePatterns.ATTENTION_OUTPUT.value,
RematRegexSavePatterns.FEED_FORWARD_OUTPUT.value,
]
),
names_which_can_be_offloaded=None,
offload_src=None,
offload_dst=None,
)
.instantiate()
)

_, default_policy_backward = jax.linearize(
jax.remat(f, policy=policy, prevent_cse=spec.prevent_cse),
jnp.asarray(target),
layer_params,
)
_, full_remat_backward = jax.linearize(
jax.remat(f),
jnp.asarray(target),
layer_params,
)

# Eliminated the remat of qkv_proj, o_proj and linear1_0 = 5 dots. This assumes
# FlashAttention is not enabled.
self.assertEqual(
str(full_remat_backward).count(" dot_general")
- str(default_policy_backward).count(" dot_general"),
5,
)


class TestStackModel(BaseLayer):
"""A dummy transformer stack."""
Expand Down
11 changes: 2 additions & 9 deletions axlearn/experiments/text/gpt/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
BaseQKVLinear,
MultiheadAttention,
RepeatedTransformerLayer,
StackedTransformerLayer,
TransformerLayer,
build_remat_spec,
set_double_shard_weights_config,
Expand Down Expand Up @@ -190,20 +191,12 @@ def update_model_remat_config(
):
"""Recomputes and sets the remat_spec based on provided layer_cfg.

Only applied if the stack_cfg is a RepeatedTransformerLayer.

Args:
stack_cfg: The transformer stack config.
layer_cfg: The transformer layer config.
offload_dst: Destination of remat checkptoing offloading.

Raises:
NotImplementedError: If `stack_cfg.klass` is not a RepeatedTransformerLayer.
"""
if stack_cfg.klass is not RepeatedTransformerLayer:
raise NotImplementedError(
f"Remat spec is not implemented for stack_cfg with klass={type(stack_cfg.klass)}"
)

remat_spec = build_remat_spec(stack_cfg.clone(layer=layer_cfg))
layer_cfg.set(remat_spec=remat_spec)
Expand Down Expand Up @@ -277,7 +270,7 @@ def model_config(
layer_cfg.self_attention.attention.input_linear = attention_qkv_linear
layer_cfg.self_attention.structure = atten_structure
layer_cfg.self_attention.attention.atten_logit_cap = atten_logit_cap
if stack_cfg.klass is RepeatedTransformerLayer:
if issubclass(stack_cfg.klass, (RepeatedTransformerLayer, StackedTransformerLayer)):
update_model_remat_config(stack_cfg=stack_cfg, layer_cfg=layer_cfg)
# Stack.
transformer_cfg = stack_cfg.set(num_layers=num_layers, layer=layer_cfg)
Expand Down
35 changes: 34 additions & 1 deletion axlearn/experiments/text/gpt/fuji.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,10 @@
FusedQKVLinear,
GroupedQueryAttention,
MultiheadAttention,
RematRegexSavePatterns,
RepeatedTransformerLayer,
RoFormerQKVLinear,
_save_and_offload_only_these_names_regex,
)
from axlearn.common.base_layer import RematSpec
from axlearn.common.config import config_for_function
Expand Down Expand Up @@ -85,7 +87,6 @@ class Version(enum.Enum):
Version.V3: 5e5,
}


# Mapping from Fuji versions to total number of tokens used in training.
TOTAL_TOKENS = {
Version.V1: {
Expand Down Expand Up @@ -417,6 +418,38 @@ def get_trainer_kwargs(
"gpu-(p5.48xlarge|p4de.24xlarge)-(512|1024)",
mesh_shape_from_axes(data=-1, fsdp=128),
),
(
"neuron-(trn2|trn2n).48xlarge-64",
ChainConfigModifier.default_config().set(
config_modifiers=[
MeshShapeModifier.default_config().set(
mesh_shape=mesh_shape_from_axes(fsdp=-1, model=4)
),
RematSpecModifier.default_config().set(
remat_policies={
"model.decoder.transformer.layer": RematSpec(
prevent_cse=True,
policy=config_for_function(
_save_and_offload_only_these_names_regex
).set(
names_which_can_be_saved="|".join(
[
RematRegexSavePatterns.QKV_PROJ.value,
RematRegexSavePatterns.LINEAR1_X.value,
RematRegexSavePatterns.RESIDUAL_ADD.value,
RematRegexSavePatterns.MLP_RESIDUAL.value,
]
),
names_which_can_be_offloaded=None,
ruomingp marked this conversation as resolved.
Show resolved Hide resolved
offload_src=None,
offload_dst=None,
),
),
}
),
],
),
),
),
)
else:
Expand Down