
Add keep_torch_compile param to unwrap_model and extract_model_from_parallel for distributed compiled model. #3282

Open · ggoggam wants to merge 1 commit into main from bugfix/check-compiled-model

Conversation

@ggoggam ggoggam commented Dec 9, 2024

What does this PR do?

This PR fixes the unexpected behavior of Accelerator.unwrap_model (Issue #3281). Right now, if the model is wrapped in both a distributed wrapper (e.g. DistributedDataParallel or DeepSpeedEngine) and a compiled module (OptimizedModule), it only unwraps the distributed module. This behavior arises from the following code at L80 of utils/other.py:

is_compiled = is_compiled_module(model)

The current code only checks for compilation before unwrapping the distributed wrapper, instead of checking both before and after. If the model is wrapped in both, is_compiled is set to False and the model is not fully unwrapped, resulting in unexpected behavior: users expect a fully unwrapped model before saving but get an OptimizedModule instead, which may cause a key-mismatch error when loading the state dict.
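A minimal reproduction sketch of the Distributed(Compiled(model)) case described above (illustrative only; DataParallel stands in for DDP so the snippet runs without initializing a process group, and both wrappers are handled by the same unwrapping logic):

```python
import torch
from accelerate.utils import extract_model_from_parallel

model = torch.nn.Linear(4, 4)
compiled = torch.compile(model)            # OptimizedModule(Linear)
wrapped = torch.nn.DataParallel(compiled)  # DataParallel(OptimizedModule(Linear))

unwrapped = extract_model_from_parallel(wrapped)
# The compilation check runs on the outer DataParallel wrapper, so is_compiled
# is False and only the distributed wrapper is removed: an OptimizedModule is
# returned instead of the plain Linear.
print(type(unwrapped))
```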

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Related Issue

Issue #3281
Who can review?

@muellerzr @SunMarc

@ggoggam force-pushed the bugfix/check-compiled-model branch from fb3809e to f601b8c on December 9, 2024 07:03
@SunMarc (Member) left a comment


Hi @ggoggam, this is actually intended behavior as you can see from this PR #1437.

What I can suggest is to either modify unwrap_model or add a new argument to extract_model_from_parallel. cc @muellerzr

@ggoggam changed the title from "Fix extract_model_from_parallel to fully unwrap compiled model." to "Fix unwrap_model for distributed compiled model." on Dec 9, 2024
@ggoggam (Author) commented Dec 9, 2024

I see. I think it would make more sense to modify unwrap_model to exhibit the expected behavior. I will wait for other reviews before I implement the fix.

@BenjaminBossan (Member)

Thanks for the PR. Regarding the implementation, extract_model_from_parallel is already supposed to handle both compilation and models wrapped by DS/DDP. If it doesn't work, I'd rather see extract_model_from_parallel fixed than calling it repeatedly. Also, could you please add a unit test similar to this one, your original example could be adjusted for that.

@ggoggam (Author) commented Dec 11, 2024

Thanks for the PR. Regarding the implementation, extract_model_from_parallel is already supposed to handle both compilation and models wrapped by DS/DDP. If it doesn't work, I'd rather see extract_model_from_parallel fixed than calling it repeatedly. Also, could you please add a unit test similar to this one, your original example could be adjusted for that.

My first commit actually fixes extract_model_from_parallel and the unit test. While fixing this, I noticed that the unit test you mentioned checks against the inner module of the unwrapped model, which is not what I would expect.

assert compiled_model._orig_mod == compiled_model_unwrapped._orig_mod

If I understand correctly, it should be

assert compiled_model._orig_mod == compiled_model_unwrapped

if extract_model_from_parallel is supposed to handle both compilation and models wrapped by DS/DDP?
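For reference, a simplified sketch of the two assertions (the setup is illustrative and omits the distributed wrapper used in the actual repository test):

```python
import torch
from accelerate.utils import extract_model_from_parallel

model = torch.nn.Linear(4, 4)
compiled_model = torch.compile(model)
compiled_model_unwrapped = extract_model_from_parallel(compiled_model)

# Current test: compares the inner modules, so it passes even when the
# compile wrapper is kept on the unwrapped model.
assert compiled_model._orig_mod == compiled_model_unwrapped._orig_mod

# Proposed expectation: unwrapping also removes the compile wrapper, so the
# result should equal the inner module directly. (This fails under the
# current behavior, which keeps the OptimizedModule.)
# assert compiled_model._orig_mod == compiled_model_unwrapped
```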

@BenjaminBossan (Member)

If I understand correctly, it should be

assert compiled_model._orig_mod == compiled_model_unwrapped

if extract_model_from_parallel is supposed to handle both compilation and models wrapped by DS/DDP?

Hmm, good question, I'll leave that to the others to answer, as I'm not sure.

@SunMarc (Member) commented Dec 11, 2024

If I understand correctly, it should be

assert compiled_model._orig_mod == compiled_model_unwrapped

if extract_model_from_parallel is supposed to handle both compilation and models wrapped by DS/DDP?

That's right. But the current behavior is that we don't unwrap the compiled model with extract_model_from_parallel. Right now, we just make sure that we are able to unwrap the distributed part even if the model is compiled.

cc @muellerzr what do you prefer, modifying how extract_model_from_parallel works or modifying the unwrap function?

@@ -77,8 +77,7 @@ def extract_model_from_parallel(model, keep_fp32_wrapper: bool = True, recursive
"""
options = (torch.nn.parallel.DistributedDataParallel, torch.nn.DataParallel)

is_compiled = is_compiled_module(model)
@ggoggam (Author) Dec 11, 2024


To clarify: is is_compiled being assigned here and then checked after DDP/DS is unwrapped? If so, this seems potentially unnecessary or wrong, as it checks for compilation only before unwrapping DDP/DS.

(Collaborator)

Yes, I agree that we should check for compilation after unwrapping.

@ggoggam (Author) Dec 11, 2024

Sorry, I think I misunderstood the code. is_compiled is assigned here before unwrapping in order to keep the torch compilation, though I still think this could be a problem if the model is in the form of Distributed(Compiled(model)). Refer to my newest commit.

@muellerzr (Collaborator) left a comment

Hey! Thanks for trying to tackle this. What I'd rather see instead is a new argument to extract_model_from_parallel called keep_torch_compile which, similar to keep_fp32_wrapper, should default to False for a number of versions; we can then flip it to True after a while, and modify the logic in extract_model_from_parallel to reflect this change.

Can you tweak this PR to do so? :)

@ggoggam (Author) commented Dec 11, 2024

Hey! Thanks for trying to tackle this. What I'd rather see instead is a new argument to extract_model_from_parallel called keep_torch_compile which, similar to keep_fp32_wrapper, should default to False for a number of versions; we can then flip it to True after a while, and modify the logic in extract_model_from_parallel to reflect this change.

Can you tweak this PR to do so? :)

Sure thing. I added a keep_torch_compile argument to both Accelerator.unwrap_model (defaults to False) and extract_model_from_parallel (defaults to True). Let me know if this makes sense.

I am also curious what you think about the Distributed(Compiled(model)) case, since the current code only accounts for the Compiled(Distributed(model)) case. I think Accelerator.prepare does the latter, but I am not sure whether we should still consider the former; see the sketch below.
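For illustration, a sketch of the two wrapping orders in question (DataParallel stands in for a generic distributed wrapper so the example runs without a process group; the claim about Accelerator.prepare is the one made above, not verified here):

```python
import torch

base = torch.nn.Linear(4, 4)

# Compiled(Distributed(model)): the order the current logic accounts for; a
# check on the outer object correctly reports that the model is compiled.
compiled_of_distributed = torch.compile(torch.nn.DataParallel(base))

# Distributed(Compiled(model)): compilation happens first, so the compile
# wrapper sits behind .module and a check on the outer object reports
# "not compiled".
distributed_of_compiled = torch.nn.DataParallel(torch.compile(base))

print(type(compiled_of_distributed))  # OptimizedModule
print(type(distributed_of_compiled))  # DataParallel
```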

@ggoggam requested a review from @muellerzr on December 12, 2024 23:24
@SunMarc (Member) left a comment

Thanks for adding this and the tests!
Also, note that we will also have to upstream the modification made to extract_model_from_parallel in transformers if needed.

src/accelerate/utils/other.py (review thread resolved)
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@@ -2601,7 +2601,7 @@ def pad_across_processes(self, tensor, dim=0, pad_index=0, pad_first=False):
"""
return pad_across_processes(tensor, dim=dim, pad_index=pad_index, pad_first=pad_first)

- def unwrap_model(self, model, keep_fp32_wrapper: bool = True):
+ def unwrap_model(self, model, keep_fp32_wrapper: bool = True, keep_torch_compile: bool = False):
(Member)

I think you wanted to set keep_torch_compile to True for a couple of versions, @muellerzr?

(Collaborator)

Correct, let's default to True until 2.5.0, to give users ~3 months.

Essentially, we should default to None, and if it is None, warn that the default will be changing.
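A rough sketch of what that could look like (illustrative only, not the code in this PR): default the argument to None and warn when the caller relies on the current default.

```python
import warnings
from typing import Optional

def unwrap(model, keep_torch_compile: Optional[bool] = None):
    # Sketch of the suggested deprecation path: treat None as today's default
    # (True) but warn that the default will change in a future release.
    if keep_torch_compile is None:
        warnings.warn(
            "The default value of `keep_torch_compile` will change in a future "
            "release; pass it explicitly to keep the current behavior.",
            FutureWarning,
        )
        keep_torch_compile = True
    # ... unwrap the distributed wrapper here, and also strip the
    # torch.compile wrapper (OptimizedModule._orig_mod) when
    # keep_torch_compile is False ...
    return model
```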

(Author)

I've defaulted keep_torch_compile=True in the latest commit.
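For reference, a minimal usage sketch with the new argument (illustrative; assumes a build of accelerate that includes this change, and a toy model in place of a real one):

```python
import torch
from accelerate import Accelerator

accelerator = Accelerator()
model = torch.compile(accelerator.prepare(torch.nn.Linear(4, 4)))

# keep_torch_compile=False strips the OptimizedModule wrapper as well, so the
# saved state_dict keys match the original, uncompiled model.
unwrapped = accelerator.unwrap_model(model, keep_torch_compile=False)
torch.save(unwrapped.state_dict(), "model.bin")
```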

@ggoggam force-pushed the bugfix/check-compiled-model branch from 71d8bf1 to c2e9e5b on December 18, 2024 09:10
@ggoggam requested a review from @SunMarc on December 19, 2024 07:55
@ggoggam changed the title from "Fix unwrap_model for distributed compiled model." to "Add keep_torch_compile param to unwrap_model and extract_model_from_parallel for distributed compiled model." on Dec 19, 2024
@ggoggam (Author) commented Dec 24, 2024

@SunMarc Please let me know if there are any changes needed for this PR!

@SunMarc (Member) left a comment

LGTM!

@SunMarc (Member) commented Dec 24, 2024

cc @muellerzr for final
