Special remat for Neuron #898

Open. apoorvtintin wants to merge 1 commit into main from mainline-upstream-remat.

Conversation

@apoorvtintin commented Dec 17, 2024

This PR adds a special remat configuration for TRN2 and Fuji-70B. This is done by adding a new remat policy that matches remat names against multiple regex patterns, which allows more flexibility in remat configuration for different backends and device types.

Misc: Enable remat for StackedTransformer

@apoorvtintin requested review from ruomingp, markblee and a team as code owners on December 17, 2024 21:08
@apoorvtintin changed the title from "Special Remat for Neuron" to "Special remat for Neuron" on Dec 17, 2024
@apoorvtintin force-pushed the mainline-upstream-remat branch from a08fb3c to 9715aef on December 17, 2024 21:52
axlearn/common/utils.py (outdated):
@@ -3934,6 +3935,33 @@ def forward(
 _SavePattern = Union[str, re.Pattern, None]


+def save_only_these_regex_patterns(*regex_patterns_to_save):
Contributor:

How is this different from

OffloadPolicy = Callable[[Primitive, list[Any], dict[str, Any]], Union[bool, Any]]
_SavePattern = Union[str, re.Pattern, None]


# Adapted from jax source code to support regex. Reference:
# https://github.com/jax-ml/jax/blob/0d36b0b433a93c707f86dac89b0c05d40302775a/jax/_src/ad_checkpoint.py#L120
def _save_and_offload_only_these_names_regex(
    *,
    names_which_can_be_saved: _SavePattern,
    names_which_can_be_offloaded: _SavePattern,
    offload_src: str,
    offload_dst: str,
) -> OffloadPolicy:
    def policy(prim, *_, **params):
        if prim is name_p:
            if names_which_can_be_saved and re.fullmatch(names_which_can_be_saved, params["name"]):
                return pe.Saveable
            if names_which_can_be_offloaded and re.fullmatch(
                names_which_can_be_offloaded, params["name"]
            ):
                return pe.Offloadable(src=offload_src, dst=offload_dst)
        return pe.Recompute  # not saveable unless it's in the allow-list

    return policy
?
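
For context, a minimal, illustrative sketch of how a regex-based policy like the one above plugs into jax.checkpoint; the toy layer and checkpoint names are made up, and it assumes _save_and_offload_only_these_names_regex is in scope:

import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name

def toy_layer(x, w_q, w_up):
    # Tag intermediates with names that the regex policy can match.
    q = checkpoint_name(x @ w_q, name="q_proj")
    h = checkpoint_name(q @ w_up, name="linear1_0")
    return jnp.tanh(h)

# Save activations whose checkpoint name matches the pattern; recompute the rest.
policy = _save_and_offload_only_these_names_regex(
    names_which_can_be_saved=r".*(q_proj|linear1_[01])",
    names_which_can_be_offloaded=None,
    offload_src="device",
    offload_dst="pinned_host",
)
toy_layer_remat = jax.checkpoint(toy_layer, policy=policy)

Any intermediate whose checkpoint name matches names_which_can_be_saved is kept for the backward pass; everything else is recomputed.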

Contributor:

+1, you should probably move the referenced code in axlearn/axlearn/common/attention.py and make it part of extended_remat_policies, if you need it outside of attention.py.

@apoorvtintin (Author), Dec 17, 2024:

I agree on reusing this policy, and I will also modify it to ingest multiple regex patterns as a sequence. I feel it is hard to put all remat names into a single regex pattern. Will get back with this change.

Member:

Can you elaborate on why it's hard to put all names into a single regex? You can use "|" to separate patterns and use "()" to make things more readable.

@apoorvtintin (Author), Dec 17, 2024:

Yes, by hard I meant hard to read. It is this:

regex_patterns_to_save=[
    r"TransformerAttentionLayer\.residual_add",
    r"\.?(k|q|v)_proj$",
    r"\.?linear1_0",
    r"\.?linear1_1",
    r"TransformerFeedForwardLayer\.mlp_residual",
]

vs
r"TransformerAttentionLayer\.residual_add|\.?(k|q|v)_proj$|\.?linear1_[01]|TransformerFeedForwardLayer\.mlp_residual"
The second keeps the changes minimal, but the first may be a little more readable. Let me know what the decision on this is; I am okay with either.

Member:

I think we should reuse the API of _save_and_offload_only_these_names_regex and avoid changing it since it appears in golden configs. Also cc @ruomingp for some input.

Contributor:

r"TransformerAttentionLayer\.residual_add|\.?(k|q|v)_proj$|\.?linear1_[01]|TransformerFeedForwardLayer\.mlp_residual" , regarding this part, is it possible to make it a few substrings and then concat them and make one, so that it is more readable while not creating a new API?

@apoorvtintin (Author):

I updated the PR to use _save_and_offload_only_these_names_regex as we discussed. The changes look clean; each remat regex pattern is on its own line:

  names_which_can_be_saved=r"(TransformerAttentionLayer\.residual_add"  # pylint: disable=C0301
  "|.*\.?(k|q|v)_proj"
  "|.*\.?linear1_[01]"
  "|TransformerFeedForwardLayer\.mlp_residual)",

@@ -277,8 +269,7 @@ def model_config(
     layer_cfg.self_attention.attention.input_linear = attention_qkv_linear
     layer_cfg.self_attention.structure = atten_structure
     layer_cfg.self_attention.attention.atten_logit_cap = atten_logit_cap
-    if stack_cfg.klass is RepeatedTransformerLayer:
-        update_model_remat_config(stack_cfg=stack_cfg, layer_cfg=layer_cfg)
+    update_model_remat_config(stack_cfg=stack_cfg, layer_cfg=layer_cfg)
Contributor:

Could you explain why we originally only did this for RepeatedTransformer and why it is okay to do it for everything now?

Contributor:

I think it was used to avoid applying it to PipelinedTransformer, where we may see an explosion of activation memory from saving intermediates; both stacked and repeated transformers are safe to apply the same remat config to here.

@apghml (Contributor), Dec 17, 2024:

We still have PipelinedTransformer in the codebase. Will this change increase its memory usage?

Contributor:

Technically, it is gated at a later stage, so it may be fine here. https://github.com/apple/axlearn/blob/main/axlearn/common/attention.py#L3994

Contributor:

IIUC, you're saying the check this PR removes may never have been needed at all? Is this PR urgent? Ideally, it would be good to check with Mark.

Contributor:

Or can we defensively leave the check for now with a TODO to revisit?

Contributor:

How about adding:

Suggested change:
-    update_model_remat_config(stack_cfg=stack_cfg, layer_cfg=layer_cfg)
+    if issubclass(stack_cfg.klass, (RepeatedTransformerLayer, StackedTransformerLayer)):
+        update_model_remat_config(stack_cfg=stack_cfg, layer_cfg=layer_cfg)

@apoorvtintin force-pushed the mainline-upstream-remat branch from 9715aef to b8c8fa5 on December 18, 2024 02:26
@apoorvtintin (Author):

Thanks for the reviews. Most comments are resolved and the PR looks clean; let me know if more changes are needed.

@apoorvtintin force-pushed the mainline-upstream-remat branch from b8c8fa5 to 9b9bba5 on December 18, 2024 08:13
@ruomingp (Contributor) left a comment:

A suggestion to address the concern from @apghml. WDYT?


@apoorvtintin force-pushed the mainline-upstream-remat branch from 9b9bba5 to bf4a016 on December 18, 2024 19:50
@apoorvtintin (Author):

Thanks for the guidance, I addressed all comments. Let me know if more changes are needed

Comment on lines 90 to 95
# Regex patterns for matching remat names
class RematRegex(enum.Enum):
    QKV_PROJ = r".*\.?(k|q|v)_proj"
    LINEAR1_X = r".*\.?linear1_[01]"
    RESIDUAL_ADD = r"TransformerAttentionLayer\.residual_add"
    MLP_RESIDUAL = r"TransformerFeedForwardLayer\.mlp_residual"
Contributor:

How do we check whether these regexes match the right activations?

Consider adding a test as in

def test_build_remat_spec(self):
    model_dim, num_heads = 6, 2
    cfg: TransformerLayer.Config = TransformerLayer.default_config().set(input_dim=model_dim)
    cfg.self_attention.attention.set(num_heads=num_heads, causal=True)
    cfg.feed_forward.hidden_dim = model_dim * 4
    cfg.vlog = 5

    layer: BaseTransformerLayer = cfg.clone(name="layer").instantiate(parent=None)
    layer_params = layer.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(0))

    batch_size, tgt_len = 2, 5
    rng = np.random.default_rng(seed=123)
    target = rng.random([batch_size, tgt_len, cfg.input_dim], dtype=np.float32)

    def f(x, layer_params):
        forward_outputs, _ = F(
            layer,
            inputs=dict(
                data=x,
            ),
            state=layer_params,
            is_training=True,
            prng_key=jax.random.PRNGKey(0),
        )
        return forward_outputs

    # Ignore type errors.
    spec: Any = build_remat_spec(mock.MagicMock())

    _, default_policy_backward = jax.linearize(
        jax.remat(f, policy=spec.policy.instantiate(), prevent_cse=spec.prevent_cse),
        jnp.asarray(target),
        layer_params,
    )
    _, full_remat_backward = jax.linearize(
        jax.remat(f),
        jnp.asarray(target),
        layer_params,
    )

    # Eliminated the remat of qkv_proj, context and o_proj = 5 dots. This assumes
    # FlashAttention is not enabled.
    self.assertEqual(
        str(full_remat_backward).count(" dot_general")
        - str(default_policy_backward).count(" dot_general"),
        5,
    )

@apoorvtintin (Author):

Added a test

@apoorvtintin force-pushed the mainline-upstream-remat branch from bf4a016 to 45342e7 on December 18, 2024 21:47
@kelvin-zou (Contributor) left a comment:

High-level comment: probably don't add more activation checkpoints in the attention transformer layer class unless we must. Otherwise the change looks good to me; will approve once that part is reverted.

@@ -2801,6 +2805,7 @@ def _linear2(x):
         self._add_tensor_stats("inputs", inputs)

         remat_pt2 = "linear2"
+        remat_pt3 = "feed_forward_output"
Contributor:

Hmm, is this necessary? It seems quite wasteful to checkpoint an activation after a norm. I would recommend refraining from making a change here.

@@ -2508,23 +2508,27 @@ def attention_thunk(target: Tensor) -> tuple[Optional[NestedTensor], Tensor]:
            atten_state, atten_output = attention_thunk(TensorSpec(target.shape, target.dtype))
            return dict(attention=atten_state), atten_output

        remat_pt1 = "attention_output"
Contributor:

Is this really needed? How much perf gain do we get from saving attention_output?

@ruomingp (Contributor) left a comment:

Please address @kelvin-zou's comments. Otherwise LGTM.
