
Commit

enable special remat for neuron
apoorvtintin committed Dec 18, 2024
1 parent 3ae8f9f commit 45342e7
Showing 2 changed files with 43 additions and 9 deletions.
11 changes: 2 additions & 9 deletions axlearn/experiments/text/gpt/common.py
@@ -34,6 +34,7 @@
BaseQKVLinear,
MultiheadAttention,
RepeatedTransformerLayer,
StackedTransformerLayer,
TransformerLayer,
build_remat_spec,
set_double_shard_weights_config,
@@ -190,20 +191,12 @@ def update_model_remat_config(
):
"""Recomputes and sets the remat_spec based on provided layer_cfg.
Only applied if the stack_cfg is a RepeatedTransformerLayer.
Args:
stack_cfg: The transformer stack config.
layer_cfg: The transformer layer config.
offload_dst: Destination of remat checkpointing offloading.
Raises:
NotImplementedError: If `stack_cfg.klass` is not a RepeatedTransformerLayer.
"""
if stack_cfg.klass is not RepeatedTransformerLayer:
raise NotImplementedError(
f"Remat spec is not implemented for stack_cfg with klass={type(stack_cfg.klass)}"
)

remat_spec = build_remat_spec(stack_cfg.clone(layer=layer_cfg))
layer_cfg.set(remat_spec=remat_spec)
@@ -277,7 +270,7 @@ def model_config(
layer_cfg.self_attention.attention.input_linear = attention_qkv_linear
layer_cfg.self_attention.structure = atten_structure
layer_cfg.self_attention.attention.atten_logit_cap = atten_logit_cap
if stack_cfg.klass is RepeatedTransformerLayer:
if issubclass(stack_cfg.klass, (RepeatedTransformerLayer, StackedTransformerLayer)):
update_model_remat_config(stack_cfg=stack_cfg, layer_cfg=layer_cfg)
# Stack.
transformer_cfg = stack_cfg.set(num_layers=num_layers, layer=layer_cfg)
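
With the relaxed class check in model_config, a StackedTransformerLayer stack now takes the same remat path as RepeatedTransformerLayer. Below is a minimal sketch of the resulting call, assuming the usual default_config() constructors and that update_model_remat_config is importable from this module; it is illustrative only, not part of the commit.

# Illustrative sketch, not part of this commit.
from axlearn.common.attention import StackedTransformerLayer, TransformerLayer
from axlearn.experiments.text.gpt.common import update_model_remat_config

stack_cfg = StackedTransformerLayer.default_config()
layer_cfg = TransformerLayer.default_config()

# With the NotImplementedError guard removed, this no longer raises for a
# StackedTransformerLayer stack; the layer config receives the remat spec
# produced by build_remat_spec().
update_model_remat_config(stack_cfg=stack_cfg, layer_cfg=layer_cfg)
assert layer_cfg.remat_spec is not None
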
41 changes: 41 additions & 0 deletions axlearn/experiments/text/gpt/fuji.py
@@ -26,6 +26,7 @@
MultiheadAttention,
RepeatedTransformerLayer,
RoFormerQKVLinear,
_save_and_offload_only_these_names_regex,
)
from axlearn.common.base_layer import RematSpec
from axlearn.common.config import config_for_function
@@ -86,6 +87,14 @@ class Version(enum.Enum):
}


# Regex patterns for matching remat names
class RematRegex(enum.Enum):
QKV_PROJ = r".*\.?(k|q|v)_proj"
LINEAR1_X = r".*\.?linear1_[01]"
RESIDUAL_ADD = r"TransformerAttentionLayer\.residual_add"
MLP_RESIDUAL = r"TransformerFeedForwardLayer\.mlp_residual"


# Mapping from Fuji versions to total number of tokens used in training.
TOTAL_TOKENS = {
Version.V1: {
@@ -417,6 +426,38 @@ def get_trainer_kwargs(
"gpu-(p5.48xlarge|p4de.24xlarge)-(512|1024)",
mesh_shape_from_axes(data=-1, fsdp=128),
),
(
"neuron-(trn2|trn2n).48xlarge-64",
ChainConfigModifier.default_config().set(
config_modifiers=[
MeshShapeModifier.default_config().set(
mesh_shape=mesh_shape_from_axes(fsdp=-1, model=4)
),
RematSpecModifier.default_config().set(
remat_policies={
"model.decoder.transformer.layer": RematSpec(
prevent_cse=True,
policy=config_for_function(
_save_and_offload_only_these_names_regex
).set(
names_which_can_be_saved="|".join(
[
RematRegex.QKV_PROJ.value,
RematRegex.LINEAR1_X.value,
RematRegex.RESIDUAL_ADD.value,
RematRegex.MLP_RESIDUAL.value,
]
),
names_which_can_be_offloaded=None,
offload_src=None,
offload_dst=None,
),
),
}
),
],
),
),
),
)
else:
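
For reference, the joined RematRegex patterns decide which tagged activations the policy keeps; everything else is recomputed during the backward pass. The sketch below shows which names the joined pattern selects, plus JAX's exact-name save_only_these_names policy, which the regex-based axlearn helper above appears to generalize. The activation names, tiny_layer function, and shapes are invented for illustration and are not part of the commit.

# Illustrative sketch, not part of this commit: which activation names the
# joined RematRegex patterns select, and the JAX name-based remat machinery
# that name-driven policies build on.
import re

import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name

# The same four patterns as RematRegex above, joined as in the neuron config.
saveable = re.compile(
    r".*\.?(k|q|v)_proj"
    r"|.*\.?linear1_[01]"
    r"|TransformerAttentionLayer\.residual_add"
    r"|TransformerFeedForwardLayer\.mlp_residual"
)
assert saveable.fullmatch("decoder.transformer.layer.self_attention.attention.q_proj")
assert saveable.fullmatch("feed_forward.linear1_0")
assert not saveable.fullmatch("feed_forward.linear2")  # recomputed, not saved

# Tag intermediates with checkpoint_name() and save only the listed names;
# everything else is rematerialized on the backward pass.
def tiny_layer(x, w_qkv, w_mlp):
    qkv = checkpoint_name(x @ w_qkv, "q_proj")                    # kept
    hidden = checkpoint_name(jnp.tanh(qkv) @ w_mlp, "linear1_0")  # kept
    return jnp.tanh(hidden).sum()

policy = jax.checkpoint_policies.save_only_these_names("q_proj", "linear1_0")
grad_fn = jax.grad(jax.checkpoint(tiny_layer, policy=policy, prevent_cse=True))
x, w_qkv, w_mlp = jnp.ones((8, 16)), jnp.ones((16, 16)), jnp.ones((16, 16))
print(grad_fn(x, w_qkv, w_mlp).shape)  # (8, 16)
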
