-
Notifications
You must be signed in to change notification settings - Fork 277
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Special remat for Neuron #898
base: main
Are you sure you want to change the base?
Conversation
a08fb3c
to
9715aef
Compare
axlearn/common/attention.py
Outdated
@@ -3934,6 +3935,33 @@ def forward( | |||
_SavePattern = Union[str, re.Pattern, None] | |||
|
|||
|
|||
def save_only_these_regex_patterns(*regex_patterns_to_save): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How is this different from
axlearn/axlearn/common/attention.py
Lines 3933 to 3956 in 92205bc
OffloadPolicy = Callable[[Primitive, list[Any], dict[str, Any]], Union[bool, Any]] | |
_SavePattern = Union[str, re.Pattern, None] | |
# Adapted from jax source code to support regex. Reference: | |
# https://github.com/jax-ml/jax/blob/0d36b0b433a93c707f86dac89b0c05d40302775a/jax/_src/ad_checkpoint.py#L120 | |
def _save_and_offload_only_these_names_regex( | |
*, | |
names_which_can_be_saved: _SavePattern, | |
names_which_can_be_offloaded: _SavePattern, | |
offload_src: str, | |
offload_dst: str, | |
) -> OffloadPolicy: | |
def policy(prim, *_, **params): | |
if prim is name_p: | |
if names_which_can_be_saved and re.fullmatch(names_which_can_be_saved, params["name"]): | |
return pe.Saveable | |
if names_which_can_be_offloaded and re.fullmatch( | |
names_which_can_be_offloaded, params["name"] | |
): | |
return pe.Offloadable(src=offload_src, dst=offload_dst) | |
return pe.Recompute # not saveable unless it's in the allow-list | |
return policy |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
+1, you should probably move
axlearn/axlearn/common/attention.py
into
axlearn/axlearn/common/utils.py
Line 129 in 92205bc
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree on reusing this policy, will also modify this policy to ingest multiple regex patterns as a sequence. I feel It is hard to put all remat names in just a single regex pattern. Will get back with this change
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you elaborate on why it's hard to put all names into a single regex? You can use "|" to separate patterns and use "()" to make things more readable.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, by hard I meant readable. It is this
regex_patterns_to_save=[
r"TransformerAttentionLayer\.residual_add",
r"\.?(k|q|v)_proj$",
r"\.?linear1_0",
r"\.?linear1_1",
r"TransformerFeedForwardLayer\.mlp_residual",
]
vs
r"TransformerAttentionLayer\.residual_add|\.?(k|q|v)_proj$|\.?linear1_[01]|TransformerFeedForwardLayer\.mlp_residual"
Second one keeps the changes minimal, but the first may be a little more readable. let me know what the decision on this is, I am okay with both.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we should reuse the API of _save_and_offload_only_these_names_regex
and avoid changing it since it appears in golden configs. Also cc @ruomingp for some input.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
r"TransformerAttentionLayer\.residual_add|\.?(k|q|v)_proj$|\.?linear1_[01]|TransformerFeedForwardLayer\.mlp_residual"
, regarding this part, is it possible to make it a few substrings and then concat them and make one, so that it is more readable while not creating a new API?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I made updates to the PR to use _save_and_offload_only_these_names_regex
as we discussed. The changes look clean. I put each different remat regex patterns neatly into different lines.
names_which_can_be_saved=r"(TransformerAttentionLayer\.residual_add" # pylint: disable=C0301
"|.*\.?(k|q|v)_proj"
"|.*\.?linear1_[01]"
"|TransformerFeedForwardLayer\.mlp_residual)",
@@ -277,8 +269,7 @@ def model_config( | |||
layer_cfg.self_attention.attention.input_linear = attention_qkv_linear | |||
layer_cfg.self_attention.structure = atten_structure | |||
layer_cfg.self_attention.attention.atten_logit_cap = atten_logit_cap | |||
if stack_cfg.klass is RepeatedTransformerLayer: | |||
update_model_remat_config(stack_cfg=stack_cfg, layer_cfg=layer_cfg) | |||
update_model_remat_config(stack_cfg=stack_cfg, layer_cfg=layer_cfg) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you explain why we originally only did this for RepeatedTransformer and why it is okay to do it for everything now?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it was used to avoid applying to pipelinedTransformer, where we may have explosion of activation memories due to saving intermediaries, and both stacked and repeated transformer are safe to apply the same remat thing here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We still have PipelinedTransformer in the codebase. Will this change increase the memory usage of it?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Technically, it is gated at a later stage, so it may be fine here. https://github.com/apple/axlearn/blob/main/axlearn/common/attention.py#L3994
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
IIUC, you're saying the check this PR removes may never have been needed at all? Is this PR urgent? Ideally, it would be good to check with Mark.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Or can we defensively leave the check for now with a TODO to revisit?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How about adding
update_model_remat_config(stack_cfg=stack_cfg, layer_cfg=layer_cfg) | |
if issubclass(stack_cfg.klass, (RepeatedTransformerLayer, StackedTransformerLayer)): | |
update_model_remat_config(stack_cfg=stack_cfg, layer_cfg=layer_cfg) |
9715aef
to
b8c8fa5
Compare
Thanks for the reviews, most comments are resolved and PR looks clean, let me know if more changes are needed. |
b8c8fa5
to
9b9bba5
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A suggestion to address the concern from @apghml. WDYT?
@@ -277,8 +269,7 @@ def model_config( | |||
layer_cfg.self_attention.attention.input_linear = attention_qkv_linear | |||
layer_cfg.self_attention.structure = atten_structure | |||
layer_cfg.self_attention.attention.atten_logit_cap = atten_logit_cap | |||
if stack_cfg.klass is RepeatedTransformerLayer: | |||
update_model_remat_config(stack_cfg=stack_cfg, layer_cfg=layer_cfg) | |||
update_model_remat_config(stack_cfg=stack_cfg, layer_cfg=layer_cfg) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How about adding
update_model_remat_config(stack_cfg=stack_cfg, layer_cfg=layer_cfg) | |
if issubclass(stack_cfg.klass, (RepeatedTransformerLayer, StackedTransformerLayer)): | |
update_model_remat_config(stack_cfg=stack_cfg, layer_cfg=layer_cfg) |
9b9bba5
to
bf4a016
Compare
Thanks for the guidance, I addressed all comments. Let me know if more changes are needed |
axlearn/experiments/text/gpt/fuji.py
Outdated
# Regex patterns for matching remat names | ||
class RematRegex(enum.Enum): | ||
QKV_PROJ = r".*\.?(k|q|v)_proj" | ||
LINEAR1_X = r".*\.?linear1_[01]" | ||
RESIDUAL_ADD = r"TransformerAttentionLayer\.residual_add" | ||
MLP_RESIDUAL = r"TransformerFeedForwardLayer\.mlp_residual" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How do we check whether these regex match the right activations?
Consider adding a test as in
axlearn/axlearn/common/attention_test.py
Lines 3831 to 3876 in a15a3bc
def test_build_remat_spec(self): | |
model_dim, num_heads = 6, 2 | |
cfg: TransformerLayer.Config = TransformerLayer.default_config().set(input_dim=model_dim) | |
cfg.self_attention.attention.set(num_heads=num_heads, causal=True) | |
cfg.feed_forward.hidden_dim = model_dim * 4 | |
cfg.vlog = 5 | |
layer: BaseTransformerLayer = cfg.clone(name="layer").instantiate(parent=None) | |
layer_params = layer.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(0)) | |
batch_size, tgt_len = 2, 5 | |
rng = np.random.default_rng(seed=123) | |
target = rng.random([batch_size, tgt_len, cfg.input_dim], dtype=np.float32) | |
def f(x, layer_params): | |
forward_outputs, _ = F( | |
layer, | |
inputs=dict( | |
data=x, | |
), | |
state=layer_params, | |
is_training=True, | |
prng_key=jax.random.PRNGKey(0), | |
) | |
return forward_outputs | |
# Ignore type errors. | |
spec: Any = build_remat_spec(mock.MagicMock()) | |
_, default_policy_backward = jax.linearize( | |
jax.remat(f, policy=spec.policy.instantiate(), prevent_cse=spec.prevent_cse), | |
jnp.asarray(target), | |
layer_params, | |
) | |
_, full_remat_backward = jax.linearize( | |
jax.remat(f), | |
jnp.asarray(target), | |
layer_params, | |
) | |
# Eliminated the remat of qkv_proj, context and o_proj = 5 dots. This assumes | |
# FlashAttention is not enabled. | |
self.assertEqual( | |
str(full_remat_backward).count(" dot_general") | |
- str(default_policy_backward).count(" dot_general"), | |
5, | |
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added a test
bf4a016
to
45342e7
Compare
45342e7
to
4d4e32e
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
High level comment, probably not to add more activation checkpoints in attention transformer layer class unless we must do so, otherwise the change looks good to me. Will approve once revert that part.
@@ -2801,6 +2805,7 @@ def _linear2(x): | |||
self._add_tensor_stats("inputs", inputs) | |||
|
|||
remat_pt2 = "linear2" | |||
remat_pt3 = "feed_forward_output" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
hmm, is this necessary? Seems quite wasteful to checkpoint an activation after a norm? I would recommend restraining from making change here.
@@ -2508,23 +2508,27 @@ def attention_thunk(target: Tensor) -> tuple[Optional[NestedTensor], Tensor]: | |||
atten_state, atten_output = attention_thunk(TensorSpec(target.shape, target.dtype)) | |||
return dict(attention=atten_state), atten_output | |||
|
|||
remat_pt1 = "attention_output" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this really needed? How much perf gain we get from saving attention_output
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please address @kelvin-zou 's comments. Otherwise LGTM.
This PR adds special remat configuration for TRN2 and Fuji-70B. This is done by adding a new remat policy that uses regex to match against multiple regex patterns for remat names. This allows more flexibility in remat configurations for different backends and device types.
Misc: Enable remat for StackedTransformer