
add Fuji v3 405b and solve HBM OOMs for larger models #766

Draft · samos123 wants to merge 1 commit into base: main
Conversation

@samos123 (Contributor) commented Oct 21, 2024

Main changes in this PR:

  • Change the Embedding param_partition_spec from (None, "model") to ("fsdp", "model")
  • Add a new remat checkpoint on the TransformerLayer input and offload it to host (see the sketch below)

reference: https://github.com/AI-Hypercomputer/maxtext/blob/main/MaxText/configs/models/llama3.1-405b.yml

Requires the optimizer state weight-only offloading PR (#789) to be merged for the fsdp=256, data=-1 config.
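
For context, here is a minimal JAX-level sketch of the host-offload remat idea. This is only an illustration of the underlying mechanism, not the PR's code: the toy layer function and the "transformer_layer_input" checkpoint name are made up here, while the PR itself configures this through the layer's remat_spec.

import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name

def layer(x, w):
    # Tag the layer input so a checkpoint policy can refer to it by name.
    x = checkpoint_name(x, "transformer_layer_input")
    return jnp.tanh(x @ w)

# Save nothing in HBM; offload only the tagged layer input to pinned host
# memory during the forward pass and bring it back for the backward pass.
# Requires a runtime with host memory spaces (e.g. TPU).
policy = jax.checkpoint_policies.save_and_offload_only_these_names(
    names_which_can_be_saved=[],
    names_which_can_be_offloaded=["transformer_layer_input"],
    offload_src="device",
    offload_dst="pinned_host",
)

def loss(w, x):
    y = jax.checkpoint(layer, policy=policy, prevent_cse=True)(x, w)
    return jnp.sum(y * y)

grads = jax.jit(jax.grad(loss))(jnp.ones((16, 16)), jnp.ones((4, 16)))

The first bullet is the one-line cfg.param_partition_spec change shown in the diff further down.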

@samos123 samos123 marked this pull request as draft October 21, 2024 21:03
@samos123 (Contributor Author) commented:

@kelvin-zou @hanzhi713 would appreciate your review to make sure this PR roughly matches 405B. Thank you!

@samos123 (Contributor Author) commented Oct 21, 2024

Getting this error:

NotFoundError: The specified path gs://axlearn-public/tensorflow_datasets/tokenizers/sentencepiece/bpe_128k_c4.model was not found.

Am I doing something wrong, or is there a missing tokenizer?

gsutil ls -r -l gs://axlearn-public/tensorflow_datasets/tokenizers/sentencepiece/
gs://axlearn-public/tensorflow_datasets/tokenizers/sentencepiece/:
    778888  2023-11-22T23:02:56Z  gs://axlearn-public/tensorflow_datasets/tokenizers/sentencepiece/bpe_32k_c4.model

@samos123 (Contributor Author) commented:

Fixed the issue after the vocab model was uploaded. Now I'm hitting OOM issues. Here is the trainer config:

max_step: 3932160
mesh_axis_names[0]: 'pipeline'
mesh_axis_names[1]: 'data'
mesh_axis_names[2]: 'expert'
mesh_axis_names[3]: 'fsdp'
mesh_axis_names[4]: 'seq'
mesh_axis_names[5]: 'model'
mesh_rules[0][0]: 'tpu-v6e-.*'
mesh_rules[0][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier'
mesh_rules[0][1].config_modifiers[0].mesh_shape[0]: 1
mesh_rules[0][1].config_modifiers[0].mesh_shape[1]: -1
mesh_rules[0][1].config_modifiers[0].mesh_shape[2]: 1
mesh_rules[0][1].config_modifiers[0].mesh_shape[3]: 256
mesh_rules[0][1].config_modifiers[0].mesh_shape[4]: 1
mesh_rules[0][1].config_modifiers[0].mesh_shape[5]: 1
mesh_rules[0][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier'
mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True
mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'jax._src.ad_checkpoint.nothing_saveable'
mesh_rules[0][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
mesh_shape[0]: 1
mesh_shape[1]: 1
mesh_shape[2]: 1
mesh_shape[3]: 256
mesh_shape[4]: 1
mesh_shape[5]: 1
model.batch_axis_names[0]: 'data'
model.batch_axis_names[1]: 'expert'
model.batch_axis_names[2]: 'fsdp'
model.decoder.attention_mask: None
model.decoder.dim: 53248
model.decoder.dropout_rate: 0.0
model.decoder.emb.dropout.klass: 'axlearn.common.layers.Dropout'
model.decoder.emb.klass: 'axlearn.common.embedding.TransformerTextEmbeddings'
model.decoder.emb.token_emb.klass: 'axlearn.common.layers.Embedding'
model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].distribution: 'normal'
model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].fan: 'fan_out'
model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer'
model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].scale: 1.0
model.decoder.emb.token_emb.param_init.klass: 'axlearn.common.param_init.DefaultInitializer'
model.decoder.emb.token_emb.param_partition_spec[0]: None
model.decoder.emb.token_emb.param_partition_spec[1]: 'model'
model.decoder.eos_token_id: 1
model.decoder.klass: 'axlearn.common.decoder.Decoder'
model.decoder.logits_partition_spec[0][0]: 'data'
model.decoder.logits_partition_spec[0][1]: 'expert'
model.decoder.logits_partition_spec[0][2]: 'fsdp'
model.decoder.logits_partition_spec[1]: 'seq'
model.decoder.logits_partition_spec[2]: 'model'
model.decoder.output_dropout.klass: 'axlearn.common.layers.Dropout'
model.decoder.output_norm.eps: 1e-05
model.decoder.output_norm.forward_dtype: None
model.decoder.output_norm.klass: 'axlearn.common.layers.RMSNorm'
model.decoder.pad_token_id: 0
model.decoder.transformer.klass: 'axlearn.common.attention.RepeatedTransformerLayer'
model.decoder.transformer.layer.feed_forward.activation[0]: 'nn.silu'
model.decoder.transformer.layer.feed_forward.activation[1]: 'linear'
model.decoder.transformer.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout'
model.decoder.transformer.layer.feed_forward.hidden_dim.fn: 'axlearn.experiments.text.gpt.common.scale_fn'
model.decoder.transformer.layer.feed_forward.hidden_dim.round_up_to_multiples_of: 256
model.decoder.transformer.layer.feed_forward.hidden_dim.scale: 2.6666666666666665
model.decoder.transformer.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer'
model.decoder.transformer.layer.feed_forward.linear1.bias: False
model.decoder.transformer.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear'
model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][0]: 'data'
model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][1]: 'expert'
model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][2]: 'fsdp'
model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[1]: 'seq'
model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[2]: 'model'
model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][0]: 'expert'
model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][1]: 'fsdp'
model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][2]: 'seq'
model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[1]: 'model'
model.decoder.transformer.layer.feed_forward.linear2.bias: False
model.decoder.transformer.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear'
model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][0]: 'data'
model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][1]: 'expert'
model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][2]: 'fsdp'
model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[1]: 'seq'
model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[2]: 'model'
model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[0]: 'model'
model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][0]: 'expert'
model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][1]: 'fsdp'
model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][2]: 'seq'
model.decoder.transformer.layer.feed_forward.norm.eps: 1e-05
model.decoder.transformer.layer.feed_forward.norm.forward_dtype: None
model.decoder.transformer.layer.feed_forward.norm.klass: 'axlearn.common.layers.RMSNorm'
model.decoder.transformer.layer.feed_forward.residual_weight: 1.0
model.decoder.transformer.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth'
model.decoder.transformer.layer.feed_forward.stochastic_depth.mode: 'row'
model.decoder.transformer.layer.feed_forward.structure: 'prenorm'
model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLayer'
model.decoder.transformer.layer.remat_spec['prevent_cse']: True
model.decoder.transformer.layer.remat_spec['policy']: 'jax._src.ad_checkpoint.nothing_saveable'
model.decoder.transformer.layer.self_attention.attention.causal: True
model.decoder.transformer.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout'
model.decoder.transformer.layer.self_attention.attention.input_linear.cache_dtype: 'jax.numpy.bfloat16'
model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.cache_dtype: 'jax.numpy.bfloat16'
model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.klass: 'axlearn.common.attention.FusedGroupedQKVLinear'
model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.bias: False
model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear'
model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][0]: 'expert'
model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][1]: 'fsdp'
model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][2]: 'seq'
model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[1]: 'model'
model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[2]: None
model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.num_kv_heads: 8
model.decoder.transformer.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.RoFormerQKVLinear'
model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.klass: 'axlearn.common.attention.RoFormerSinusoidalPositionalEmbedding'
model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.theta: 500000.0
model.decoder.transformer.layer.self_attention.attention.input_linear.rotary_value: False
model.decoder.transformer.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey'
model.decoder.transformer.layer.self_attention.attention.klass: 'axlearn.common.flash_attention.layer.FlashAttention'
model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][0][0]: 'data'
model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][0][1]: 'expert'
model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][0][2]: 'fsdp'
model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][1]: None
model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][2][0]: 'seq'
model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][2][1]: 'model'
model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][3]: None
model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][0]: 'data'
model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][1]: 'expert'
model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][2]: 'fsdp'
model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][1]: None
model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][2][0]: 'seq'
model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][2][1]: 'model'
model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][3]: None
model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][0][0]: 'data'
model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][0][1]: 'expert'
model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][0][2]: 'fsdp'
model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][1]: None
model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][2]: None
model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][3]: None
model.decoder.transformer.layer.self_attention.attention.num_heads: 128
model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][0][0]: 'data'
model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][0][1]: 'expert'
model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][0][2]: 'fsdp'
model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][1]: 'seq'
model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][2]: 'model'
model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][3]: None
model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][0][0]: 'data'
model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][0][1]: 'expert'
model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][0][2]: 'fsdp'
model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][1]: 'model'
model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][2]: 'seq'
model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][3]: None
model.decoder.transformer.layer.self_attention.attention.output_linear.bias: False
model.decoder.transformer.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear'
model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][0]: 'expert'
model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][1]: 'fsdp'
model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][2]: 'seq'
model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model'
model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[2]: None
model.decoder.transformer.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery'
model.decoder.transformer.layer.self_attention.attention.tpu_block_size: 512
model.decoder.transformer.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout'
model.decoder.transformer.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer'
model.decoder.transformer.layer.self_attention.norm.eps: 1e-05
model.decoder.transformer.layer.self_attention.norm.forward_dtype: None
model.decoder.transformer.layer.self_attention.norm.klass: 'axlearn.common.layers.RMSNorm'
model.decoder.transformer.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth'
model.decoder.transformer.layer.self_attention.stochastic_depth.mode: 'row'
model.decoder.transformer.layer.self_attention.structure: 'prenorm'
model.decoder.transformer.num_layers: 126
model.decoder.transformer.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex'
model.decoder.transformer.repeat.drop_output.rules[0]: 'module_outputs.*'
model.decoder.transformer.repeat.klass: 'axlearn.common.attention._TransformerRepeat'
model.decoder.vocab_size: 131072
model.dtype: 'jax.numpy.float32'
model.klass: 'axlearn.common.causal_lm.Model'
model.param_init.init_by_param_name['.*weight$'].distribution: 'normal'
model.param_init.init_by_param_name['.*weight$'].fan: 'fan_in'
model.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer'
model.param_init.init_by_param_name['.*weight$'].scale: 1.0
model.param_init.klass: 'axlearn.common.param_init.DefaultInitializer'
model.seq_axis_names[0]: 'seq'
model.z_loss_scale: 0.0
name: 'gpt_trainer'
prune_empty_state_updates: True
recorder.fn: '__main__.<lambda>'
save_input_iterator: False
start_trace_process_indices[0]: 0
start_trace_steps[0]: 16
summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter'
summary_writer.max_queue: 1000
summary_writer.write_every_n_steps: 100
train_dtype: 'jax.numpy.bfloat16'
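
To read the mesh above: with axis names (pipeline, data, expert, fsdp, seq, model) and shape (1, 1, 1, 256, 1, 1), all 256 v6e chips sit on the fsdp axis and every other axis is trivial. Below is a standalone JAX sketch of an equivalent mesh plus the embedding sharding this PR switches to; it is illustrative only (it assumes the process sees all 256 devices) and is not the axlearn code path.

import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

axis_names = ("pipeline", "data", "expert", "fsdp", "seq", "model")
# Put every device on the "fsdp" axis; all other axes have size 1.
devices = np.asarray(jax.devices()).reshape(1, 1, 1, -1, 1, 1)
mesh = Mesh(devices, axis_names)

# Old Embedding default: replicate the vocab dimension, shard the model dimension.
old_emb = NamedSharding(mesh, P(None, "model"))
# This PR: also shard the vocab dimension over "fsdp", which shrinks the
# per-chip footprint of the 131072 x hidden embedding table.
new_emb = NamedSharding(mesh, P("fsdp", "model"))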

@ruomingp ruomingp requested a review from kelvin-zou October 23, 2024 17:39
@ruomingp (Contributor) left a comment:

Will defer to @kelvin-zou for review.

@@ -91,16 +91,19 @@ class Version(enum.Enum):
"test": 1 * (1024**4), # 1T tokens
"7B": 1 * (1024**4), # 1T tokens
"70B": int(1.4 * (1024**4)), # 1.4T tokens
"405B": int(1.4 * (1024**4)), # 1.4T tokens

Contributor:

There is no v1/v2 version for 405B, right?

Contributor Author (@samos123):

Correct, but the existing Fuji code requires it to be defined; otherwise it throws an error. Let me clean it up, though.

hidden_dim=53248,
num_heads=128,
# No GQA support in V1 models, so num_kv_heads is the same as num_heads.
num_kv_heads=None if version == Version.V1 else 8,

Contributor:

Just 8, since there is no v1/v2?

Contributor Author (@samos123):

Done

),
learner_kwargs=dict(peak_lr=8e-5, weight_decay=0.1),
max_sequence_length=max_sequence_length,
train_batch_size=train_batch_size,

Contributor:

The training batch is not 4M tokens for the 405B model; it is 16M, according to this.

Contributor Author (@samos123):

Fixed.
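
(For reference on the arithmetic, assuming an 8192-token sequence length, which is what Llama 3 pretraining uses, a roughly 16M-token global batch corresponds to 16 * 1024**2 / 8192 = 2048 sequences per step.)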

@samos123 (Contributor Author) commented:

This PR is still in draft mode. I will update it once I get 405B working on Trillium.

@samos123 (Contributor Author) commented:

The implementation is wrong: hidden_dim should be 16384 and ffn_dim should be 53248, right?

I will update this PR once I get my trillium-405b branch working.

@samos123 (Contributor Author) commented:

Unable to run golden_config_test:

_______________________ ERROR collecting axlearn/experiments/golden_config_test.py ________________________
ImportError while importing test module '/Users/stoelinga/workspace/axlearn/axlearn/experiments/golden_config_test.py'.
Hint: make sure your test modules/packages have valid Python names.
Traceback:
../../miniforge3/envs/axlearn-10/lib/python3.10/importlib/__init__.py:126: in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
axlearn/experiments/golden_config_test.py:19: in <module>
    from axlearn.experiments.audio import conformer
axlearn/experiments/audio/conformer/__init__.py:5: in <module>
    from . import librispeech_trainer
axlearn/experiments/audio/conformer/librispeech_trainer.py:36: in <module>
    from axlearn.audio.encoder_asr import SpeechFeatureLayer
axlearn/audio/encoder_asr.py:18: in <module>
    from axlearn.audio.subsamplers import ConvSubSampler
axlearn/audio/subsamplers.py:10: in <module>
    from axlearn.common.layers import BaseNormalizationLayer, Conv2DWith1DPadding, get_activation_fn
axlearn/common/layers.py:52: in <module>
    from axlearn.common.quantized_dot_general.layers import DenseGeneralBaseLayer
axlearn/common/quantized_dot_general/layers.py:29: in <module>
    from aqt.jax.v2.config import DotGeneral, set_context
E   ImportError: cannot import name 'set_context' from 'aqt.jax.v2.config' (/Users/stoelinga/miniforge3/envs/axlearn-10/lib/python3.10/site-packages/aqt/jax/v2/config.py)
------------- generated xml file: /Users/stoelinga/workspace/axlearn/test-results/testing.xml -------------
============================================ 1 error in 22.07s ============================================

ChainConfigModifier.default_config().set(
    config_modifiers=[
        MeshShapeModifier.default_config().set(
            mesh_shape=mesh_shape_from_axes(data=-1, fsdp=256)

Contributor Author (@samos123):

Note: this requires the optimizer state weight-only offloading PR (#789) to be merged.

@samos123 samos123 requested a review from kelvin-zou November 19, 2024 00:54
@samos123 (Contributor Author) commented:

@kelvin-zou could you give it another review?

I added the TransformerLayer input remat checkpoint with offload to host. This is required to run 405B. I did something similar in my full branch: main...samos123:axlearn:trillium-405b-offload

@samos123 samos123 changed the title add Fuji v3 405b model config add Fuji v3 405b and solve HBM OOMs for larger models Nov 19, 2024
@@ -1598,7 +1598,7 @@ class Config(BaseLayer.Config):
     @classmethod
     def default_config(cls):
         cfg = super().default_config()
-        cfg.param_partition_spec = (None, "model")
+        cfg.param_partition_spec = ("fsdp", "model")

Contributor Author (@samos123):

@kelvin-zou it is unclear whether this is safe to change by default for everyone; please review this line specifically. It was needed to fit 405B in HBM.

@samos123 (Contributor Author) commented:

This needs to be rebased, since shared_lm_head is new in the latest main.
