
Fix compatibility with pipeline when loading model with device_map on single gpu #10390

Open
wants to merge 2 commits into main
Conversation

SunMarc (Member) commented Dec 26, 2024

What does this PR do?

This PR fixes a device issue with the pipeline when we load a diffusers model separately with device_map in the single-GPU case. We can't move the whole pipeline to a device because the diffusers model has hooks on it (since we set force_hooks=True), and the following check raises an error:

        def module_is_sequentially_offloaded(module):
            if not is_accelerate_available() or is_accelerate_version("<", "0.14.0"):
                return False

            return hasattr(module, "_hf_hook") and (
                isinstance(module._hf_hook, accelerate.hooks.AlignDevicesHook)
                or hasattr(module._hf_hook, "hooks")
                and isinstance(module._hf_hook.hooks[0], accelerate.hooks.AlignDevicesHook)
            )

        # .to("cuda") would raise an error if the pipeline is sequentially offloaded, so we raise our own to make it clearer
        pipeline_is_sequentially_offloaded = any(
            module_is_sequentially_offloaded(module) for _, module in self.components.items()
        )

The model shouldn't need to have hooks in the single-GPU case.
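
A minimal sketch of why that check trips (illustrative only, with a toy module instead of the Flux transformer; not code from this PR): when dispatching with force_hooks=True, accelerate attaches an AlignDevicesHook even if the whole model sits on one device, and that hook is exactly what the check above detects.

import torch.nn as nn
from accelerate.hooks import AlignDevicesHook, add_hook_to_module

toy = nn.Linear(8, 8)
# Roughly the state a force_hooks=True dispatch leaves behind for a
# single-device model: an AlignDevicesHook attached to the module.
add_hook_to_module(toy, AlignDevicesHook(execution_device="cpu", io_same_device=True))

print(hasattr(toy, "_hf_hook"))                    # True
print(isinstance(toy._hf_hook, AlignDevicesHook))  # True -> pipeline .to("cuda") refuses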

Two issues that need to be solved in a follow-up PR:

  • The module_is_sequentially_offloaded check was originally there for sequentially offloaded models, but in general we shouldn't allow moving a model that has an AlignDevicesHook. So maybe we can rename the function and change the error message (see the sketch after this list)?
  • Right now this only works for a single GPU. To handle the multi-GPU case, maybe we can just raise a warning that we won't move that specific module instead of an error? I'm also fine with suggesting reset_device_map() if the goal is to put all models on the same device.
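
To make the first bullet concrete, a rough sketch of what the renamed helper could look like (the name is invented here and is not part of this PR; it reuses the same helpers as the snippet above):

def module_has_align_devices_hook(module):
    # Same logic as module_is_sequentially_offloaded, under a name that matches
    # what it actually detects: an accelerate AlignDevicesHook on the module.
    if not is_accelerate_available() or is_accelerate_version("<", "0.14.0"):
        return False
    hook = getattr(module, "_hf_hook", None)
    return isinstance(hook, accelerate.hooks.AlignDevicesHook) or (
        hasattr(hook, "hooks") and isinstance(hook.hooks[0], accelerate.hooks.AlignDevicesHook)
    )

The error raised from to() could then mention device_map / AlignDevicesHook rather than only sequential offloading.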

To reproduce:

import torch
from diffusers import FluxPipeline, FluxTransformer2DModel

model_id = "black-forest-labs/Flux.1-Dev"
dtype = torch.bfloat16

transformer = FluxTransformer2DModel.from_pretrained(
    model_id,
    subfolder="transformer",
    torch_dtype=dtype,
    device_map="auto",
)
if hasattr(transformer, "hf_device_map"):
    print(transformer.hf_device_map)

pipe = FluxPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.bfloat16).to("cuda")
prompt = "A cat holding a sign that says hello world"
image = pipe(prompt, num_inference_steps=30, guidance_scale=7.0, generator=torch.Generator().manual_seed(42)).images[0]
image.save("test_3_out.png")

@@ -937,7 +935,6 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
offload_folder=offload_folder,
offload_state_dict=offload_state_dict,
dtype=torch_dtype,
force_hooks=force_hook,
SunMarc (Member Author):
These are not needed because we actually remove all hooks and dispatch the model again in the pipeline logic.
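
For reference, a rough sketch of what "remove all hooks and dispatch the model again" amounts to on the accelerate side (illustrative, not the actual diffusers code path):

from accelerate import dispatch_model
from accelerate.hooks import remove_hook_from_module

# Strip whatever hooks the standalone from_pretrained call left on the component...
remove_hook_from_module(transformer, recurse=True)
# ...then dispatch it again with the device map the pipeline wants to use.
transformer = dispatch_model(transformer, device_map=transformer.hf_device_map)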

sayakpaul (Member):
If the model were to be used independently of the pipeline, would removing this be sensible?

SunMarc (Member Author):
Only if the user used device_map with only one GPU. Should be fine, to be honest. The user actually expects to be able to move the model without any issues if the model is dispatched on only one GPU.

Comment on lines 414 to 420

is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1
if is_pipeline_device_mapped:
raise ValueError(
"It seems like you have activated a device mapping strategy on the pipeline which doesn't allow explicit device placement using `to()`. You can call `reset_device_map()` first and then call `to()`."
)

SunMarc (Member Author):
Moved up because otherwise it would trigger the pipeline_is_sequentially_offloaded error first when passing device_map to the pipeline.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

sayakpaul (Member) left a comment:

Thanks, Marc!


src/diffusers/pipelines/pipeline_utils.py (outdated review thread, resolved)
sayakpaul (Member) commented Dec 26, 2024

@SunMarc regarding

Right now this only works for a single GPU. To handle the multi-GPU case, maybe we can just raise a warning that we won't move that specific module instead of an error? I'm also fine with suggesting reset_device_map() if the goal is to put all models on the same device.

This is the fix that I am suggesting:

diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py
index a504184ea..393dc83c8 100644
--- a/src/diffusers/pipelines/pipeline_utils.py
+++ b/src/diffusers/pipelines/pipeline_utils.py
@@ -388,6 +388,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
 
         device = device or device_arg
         pipeline_has_bnb = any(any((_check_bnb_status(module))) for _, module in self.components.items())
+        is_any_pipeline_component_device_mapped = any(getattr(module, "hf_device_map", None) is not None for _, module in self.components.items()) 
 
         # throw warning if pipeline is in "offloaded"-mode but user tries to manually set to GPU.
         def module_is_sequentially_offloaded(module):
@@ -411,7 +412,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
             module_is_sequentially_offloaded(module) for _, module in self.components.items()
         )
         if device and torch.device(device).type == "cuda":
-            if pipeline_is_sequentially_offloaded and not pipeline_has_bnb:
+            if not is_any_pipeline_component_device_mapped and pipeline_is_sequentially_offloaded and not pipeline_has_bnb:
                 raise ValueError(
                     "It seems like you have activated sequential model offloading by calling `enable_sequential_cpu_offload`, but are now attempting to move the pipeline to GPU. This is not compatible with offloading. Please, move your pipeline `.to('cpu')` or consider removing the move altogether if you use sequential offloading."
                 )
@@ -429,7 +430,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
 
         # Display a warning in this case (the operation succeeds but the benefits are lost)
         pipeline_is_offloaded = any(module_is_offloaded(module) for _, module in self.components.items())
-        if pipeline_is_offloaded and device and torch.device(device).type == "cuda":
+        if not is_any_pipeline_component_device_mapped and pipeline_is_offloaded and device and torch.device(device).type == "cuda":
             logger.warning(
                 f"It seems like you have activated model offloading by calling `enable_model_cpu_offload`, but are now manually moving the pipeline to GPU. It is strongly recommended against doing so as memory gains from offloading are likely to be lost. Offloading automatically takes care of moving the individual components {', '.join(self.components.keys())} to GPU when needed. To make sure offloading works as expected, you should consider moving the pipeline back to CPU: `pipeline.to('cpu')` or removing the move altogether if you use offloading."
             )
@@ -454,10 +455,11 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
 
             # This can happen for `transformer` models. CPU placement was added in
             # https://github.com/huggingface/transformers/pull/33122. So, we guard this accordingly.
-            if is_loaded_in_4bit_bnb and device is not None and is_transformers_version(">", "4.44.0"):
-                module.to(device=device)
-            elif not is_loaded_in_4bit_bnb and not is_loaded_in_8bit_bnb:
-                module.to(device, dtype)
+            if getattr(module, "hf_device_map", None) is None or (len(module.hf_device_map) == 1 and module.hf_device_map != {'': 'cpu'}):
+                if is_loaded_in_4bit_bnb and device is not None and is_transformers_version(">", "4.44.0"):
+                    module.to(device=device)
+                elif not is_loaded_in_4bit_bnb and not is_loaded_in_8bit_bnb:
+                    module.to(device, dtype)
 
             if (
                 module.dtype == torch.float16

LMK WDYT. I think the if/else is messy and not ideal, but this was to show a solution.

SunMarc (Member Author) commented Dec 27, 2024

This is roughly the solution I was thinking of.
What could be better is to add more conditions to pipeline_is_sequentially_offloaded so that we don't have to add the is_any_pipeline_component_device_mapped condition everywhere, and to add a warning when we don't move the model because the device_map has two or more entries, here:
if getattr(module, "hf_device_map", None) is None or (len(module.hf_device_map) == 1 and module.hf_device_map != {'': 'cpu'}):

sayakpaul (Member) commented Dec 27, 2024

What could be better is to add more conditions to pipeline_is_sequentially_offloaded so that we don't have to add the is_any_pipeline_component_device_mapped condition everywhere, and to add a warning when we don't move the model because the device_map has two or more entries, here:
if getattr(module, "hf_device_map", None) is None or (len(module.hf_device_map) == 1 and module.hf_device_map != {'': 'cpu'}):

That sounds good to me! It would cover the different kinds of situations more easily. Do you want to tackle it in this PR?
