add Fuji v3 405b and solve HBM OOMs for larger models #766
base: main
Conversation
@kelvin-zou @hanzhi713 I would appreciate your review to make sure this PR roughly matches 405B. Thank you!
Getting this error:
Am I doing something wrong, or is there a missing tokenizer?
Fixed the issue after the vocab model was uploaded. Now I'm hitting OOM issues. Here is the model config:
Will defer to @kelvin-zou for review.
axlearn/experiments/text/gpt/fuji.py
Outdated
@@ -91,16 +91,19 @@ class Version(enum.Enum):
     "test": 1 * (1024**4),  # 1T tokens
     "7B": 1 * (1024**4),  # 1T tokens
     "70B": int(1.4 * (1024**4)),  # 1.4T tokens
+    "405B": int(1.4 * (1024**4)),  # 1.4T tokens
There is no v1/v2 version for 405B, right?
Correct, but the existing fuji code requires it to be defined; otherwise it throws an error. Let me clean it up, though.
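One possible cleanup (a sketch only; `Version` and `TOTAL_TOKENS` are the names from the diff above, and the helper below is hypothetical): make the lookup tolerate versions that don't define a given size, instead of requiring a "405B" entry for every version.

```python
# Hypothetical helper: only versions that actually define a model size need an
# entry; other versions fail with a clear message instead of a raw KeyError.
def total_tokens(version: Version, model_size: str) -> int:
    per_version = TOTAL_TOKENS.get(version, {})
    if model_size not in per_version:
        raise ValueError(f"{model_size} has no token budget defined for {version}.")
    return per_version[model_size]
```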
axlearn/experiments/text/gpt/fuji.py
Outdated
hidden_dim=53248,
num_heads=128,
# No GQA support in V1 models, so num_kv_heads is the same as num_heads.
num_kv_heads=None if version == Version.V1 else 8,
Just 8, since there is no v1/v2?
Done
),
learner_kwargs=dict(peak_lr=8e-5, weight_decay=0.1),
max_sequence_length=max_sequence_length,
train_batch_size=train_batch_size,
The training batch is not 4M tokens for the 405B model; it is 16M according to this.
fixed
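For reference, the 16M-token figure maps to a concrete batch size as follows (a back-of-the-envelope check assuming an 8192-token sequence length; the actual max_sequence_length in the config may differ):

```python
# Assumed values: 16M tokens per global batch, 8192 tokens per sequence.
tokens_per_batch = 16 * 1024**2
max_sequence_length = 8192
train_batch_size = tokens_per_batch // max_sequence_length
assert train_batch_size == 2048  # 2048 sequences per training step
```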
This PR is still in draft mode. I will be able to update it once I get 405B working on Trillium.
The implementation is wrong: hidden_dim should be 16384 and ffn_dim should be 53248, right? I will update this PR once I get my …
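For context, a sketch of the corrected dimensions, assuming the published Llama 3.1 405B architecture; the keyword names below mirror the smaller fuji sizes and are illustrative, not the final PR code:

```python
# Hypothetical corrected 405B kwargs (Llama 3.1 405B architecture values).
model_kwargs = dict(
    num_layers=126,
    hidden_dim=16384,  # model width (was mistakenly set to 53248)
    ffn_dim=53248,     # feed-forward width
    num_heads=128,
    num_kv_heads=8,    # GQA: 8 KV heads shared across 128 query heads
)
```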
Force-pushed from a0142c1 to 497f90e.
Unable to run golden_config_test:
ChainConfigModifier.default_config().set(
    config_modifiers=[
        MeshShapeModifier.default_config().set(
            mesh_shape=mesh_shape_from_axes(data=-1, fsdp=256)
Note: this requires the optimizer state weight-only offloading PR to be merged: #789
@kelvin-zou could you give it another review? I added offloading of the TransformerLayer input checkpoints to host, which is required in order to run 405B. I did something similar in my full branch: main...samos123:axlearn:trillium-405b-offload
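As a rough illustration of the mechanism in plain JAX (not the axlearn wrapper; the function, names, and shapes below are made up), host offload of a rematerialized layer input looks roughly like this:

```python
import functools

import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name

# Policy: do not keep the tagged tensor in HBM; offload it to pinned host
# memory and bring it back for the backward pass.
policy = jax.checkpoint_policies.save_and_offload_only_these_names(
    names_which_can_be_saved=[],
    names_which_can_be_offloaded=["transformer_layer_input"],
    offload_src="device",
    offload_dst="pinned_host",
)

@functools.partial(jax.checkpoint, policy=policy)
def layer(x, w):
    # Tag the layer input so the policy above can match it by name.
    x = checkpoint_name(x, "transformer_layer_input")
    return jnp.tanh(x @ w)
```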
@@ -1598,7 +1598,7 @@ class Config(BaseLayer.Config):
     @classmethod
     def default_config(cls):
         cfg = super().default_config()
-        cfg.param_partition_spec = (None, "model")
+        cfg.param_partition_spec = ("fsdp", "model")
@kelvin-zou it's unclear whether this is safe to change by default for everyone. Please review this change specifically; it was needed to fit 405B in HBM memory.
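A rough illustration of what the spec change means in plain JAX terms (toy mesh and shapes; axlearn plumbs the spec through its layer configs rather than explicit device_put calls):

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Toy mesh: all local devices on the fsdp axis, trivial model axis.
devices = np.array(jax.devices()).reshape(-1, 1)
mesh = Mesh(devices, axis_names=("fsdp", "model"))

# (None, "model"): rows replicated on every chip, columns sharded over "model".
replicated_rows = NamedSharding(mesh, PartitionSpec(None, "model"))
# ("fsdp", "model"): rows additionally sharded over "fsdp", so each chip holds
# only a 1/fsdp slice of the parameter (and of its optimizer state).
sharded_rows = NamedSharding(mesh, PartitionSpec("fsdp", "model"))

w = jax.device_put(jnp.zeros((1024, 512)), sharded_rows)
```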
This needs to be rebased, since shared_lm_head is a new thing in the latest main.
Force-pushed from 408033d to 7532d9e.
Force-pushed from 7532d9e to fe10772.
Main things changed as part of this:
- Reference: https://github.com/AI-Hypercomputer/maxtext/blob/main/MaxText/configs/models/llama3.1-405b.yml
- Requires the optimizer state weight-only offloading PR to be merged for the fsdp=256, data=-1 config: #789
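For reference, a sketch of how the data=-1, fsdp=256 mesh resolves on a concrete slice (the device count below is an assumed example; the -1 axis simply absorbs whatever devices remain):

```python
# Assumed example: a 512-chip slice with fsdp fixed at 256.
num_devices = 512
fsdp = 256
data = num_devices // fsdp  # the -1 axis takes the remainder -> 2
assert (data, fsdp) == (2, 256)  # two data-parallel replicas of a 256-way FSDP group
```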