Skip to content

Commit

Permalink
change mesh shape to fit the model in memory
Browse files Browse the repository at this point in the history
  • Loading branch information
sychen52 committed Dec 25, 2024
1 parent 0cc22e0 commit 8926ffb
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions axlearn/experiments/text/gpt/fuji.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,7 @@ def get_trainer_kwargs(
ChainConfigModifier.default_config().set(
config_modifiers=[
MeshShapeModifier.default_config().set(
mesh_shape=mesh_shape_from_axes(data=-1, fsdp=256)
mesh_shape=mesh_shape_from_axes(data=-1, fsdp=512)
),
RematSpecModifier.default_config().set(
remat_policies={
Expand All @@ -359,7 +359,7 @@ def get_trainer_kwargs(
ChainConfigModifier.default_config().set(
config_modifiers=[
MeshShapeModifier.default_config().set(
mesh_shape=mesh_shape_from_axes(data=-1, fsdp=256)
mesh_shape=mesh_shape_from_axes(data=-1, fsdp=512)
),
RematSpecModifier.default_config().set(
remat_policies={
Expand Down

0 comments on commit 8926ffb

Please sign in to comment.