Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] fix offload gpu tests & a few device_map related refactor #10366

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 16 additions & 8 deletions src/diffusers/models/transformers/sana_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,20 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
return hidden_states


class SanaModulatedNorm(nn.Module):
def __init__(self, dim: int, elementwise_affine: bool = False, eps: float = 1e-6):
super().__init__()
self.norm = nn.LayerNorm(dim, elementwise_affine=elementwise_affine, eps=eps)

def forward(self, hidden_states: torch.Tensor, temb: torch.Tensor, scale_shift_table: torch.Tensor) -> torch.Tensor:
hidden_states = self.norm(hidden_states)
shift, scale = (
scale_shift_table[None] + temb[:, None].to(scale_shift_table.device)
).chunk(2, dim=1)
hidden_states = hidden_states * (1 + scale) + shift
return hidden_states


class SanaTransformerBlock(nn.Module):
r"""
Transformer block introduced in [Sana](https://huggingface.co/papers/2410.10629).
Expand Down Expand Up @@ -288,8 +302,7 @@ def __init__(

# 4. Output blocks
self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5)

self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
self.norm_out = SanaModulatedNorm(inner_dim, elementwise_affine=False, eps=1e-6)
self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels)

self.gradient_checkpointing = False
Expand Down Expand Up @@ -462,13 +475,8 @@ def custom_forward(*inputs):
)

# 3. Normalization
shift, scale = (
self.scale_shift_table[None] + embedded_timestep[:, None].to(self.scale_shift_table.device)
).chunk(2, dim=1)
hidden_states = self.norm_out(hidden_states)
hidden_states = self.norm_out(hidden_states, embedded_timestep, self.scale_shift_table)

# 4. Modulation
hidden_states = hidden_states * (1 + scale) + shift
hidden_states = self.proj_out(hidden_states)

# 5. Unpatchify
Expand Down
58 changes: 26 additions & 32 deletions src/diffusers/pipelines/pipeline_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,49 +391,34 @@ def to(self, *args, **kwargs):
pipeline_has_bnb = any(any((_check_bnb_status(module))) for _, module in self.components.items())

# throw warning if pipeline is in "offloaded"-mode but user tries to manually set to GPU.
def module_is_sequentially_offloaded(module):
pipeline_is_sequentially_offloaded = hasattr(self, "_all_sequential_hooks") and self._all_sequential_hooks is not None and len(self._all_sequential_hooks) > 0
pipeline_is_offloaded = hasattr(self, "_all_hooks") and self._all_hooks is not None and len(self._all_hooks) > 0
pipeline_is_device_mapped = hasattr(self, "hf_device_map") and self.hf_device_map is not None and len(self.hf_device_map) > 1
def module_has_hooks(module):
if not is_accelerate_available() or is_accelerate_version("<", "0.14.0"):
return False

return hasattr(module, "_hf_hook") and (
isinstance(module._hf_hook, accelerate.hooks.AlignDevicesHook)
or hasattr(module._hf_hook, "hooks")
and isinstance(module._hf_hook.hooks[0], accelerate.hooks.AlignDevicesHook)
)

def module_is_offloaded(module):
if not is_accelerate_available() or is_accelerate_version("<", "0.17.0.dev0"):
return False

return hasattr(module, "_hf_hook") and isinstance(module._hf_hook, accelerate.hooks.CpuOffload)

# .to("cuda") would raise an error if the pipeline is sequentially offloaded, so we raise our own to make it clearer
pipeline_is_sequentially_offloaded = any(
module_is_sequentially_offloaded(module) for _, module in self.components.items()
)
return hasattr(module, "_hf_hook") and module._hf_hook is not None
pipeline_has_hooks = any(module_has_hooks(module) for _, module in self.components.items())
if device and torch.device(device).type == "cuda":
if pipeline_is_sequentially_offloaded and not pipeline_has_bnb:
if pipeline_is_offloaded:
raise ValueError(
"It seems like you have activated sequential model offloading by calling `enable_sequential_cpu_offload`, but are now attempting to move the pipeline to GPU. This is not compatible with offloading. Please, move your pipeline `.to('cpu')` or consider removing the move altogether if you use sequential offloading."
"It seems like you have activated model offloading by calling `enable_model_cpu_offload`, but are now manually moving the pipeline to GPU. It is strongly recommended against doing so as memory gains from offloading are likely to be lost. Offloading automatically takes care of moving the individual components {', '.join(self.components.keys())} to GPU when needed. To make sure offloading works as expected, you should consider moving the pipeline back to CPU: `pipeline.to('cpu')` or removing the move altogether if you use offloading."
)
if pipeline_is_sequentially_offloaded:
raise ValueError(
"It seems like you have activated sequential model offloading by calling `enable_sequential_cpu_offload`, but are now manually moving the pipeline to GPU. This is not compatible with offloading. Please, move your pipeline `.to('cpu')` or consider removing the move altogether if you use sequential offloading."
)
# PR: https://github.com/huggingface/accelerate/pull/3223/
elif pipeline_has_bnb and is_accelerate_version("<", "1.1.0.dev0"):
if pipeline_has_hooks and pipeline_has_bnb and is_accelerate_version("<", "1.1.0.dev0"):
raise ValueError(
"You are trying to call `.to('cuda')` on a pipeline that has models quantized with `bitsandbytes`. Your current `accelerate` installation does not support it. Please upgrade the installation."
)

is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1
if is_pipeline_device_mapped:
if pipeline_is_device_mapped:
raise ValueError(
"It seems like you have activated a device mapping strategy on the pipeline which doesn't allow explicit device placement using `to()`. You can call `reset_device_map()` first and then call `to()`."
)

# Display a warning in this case (the operation succeeds but the benefits are lost)
pipeline_is_offloaded = any(module_is_offloaded(module) for _, module in self.components.items())
if pipeline_is_offloaded and device and torch.device(device).type == "cuda":
logger.warning(
f"It seems like you have activated model offloading by calling `enable_model_cpu_offload`, but are now manually moving the pipeline to GPU. It is strongly recommended against doing so as memory gains from offloading are likely to be lost. Offloading automatically takes care of moving the individual components {', '.join(self.components.keys())} to GPU when needed. To make sure offloading works as expected, you should consider moving the pipeline back to CPU: `pipeline.to('cpu')` or removing the move altogether if you use offloading."
)

module_names, _ = self._get_signature_keys(self)
modules = [getattr(self, n, None) for n in module_names]
Expand All @@ -452,13 +437,18 @@ def module_is_offloaded(module):
logger.warning(
f"The module '{module.__class__.__name__}' has been loaded in `bitsandbytes` 8bit and moving it to {device} via `.to()` is not supported. Module is still on {module.device}."
)

is_device_mapped = module_has_hooks(module) and hasattr(module, "hf_device_map") and module.hf_device_map is not None

# This can happen for `transformer` models. CPU placement was added in
# https://github.com/huggingface/transformers/pull/33122. So, we guard this accordingly.
if is_loaded_in_4bit_bnb and device is not None and is_transformers_version(">", "4.44.0"):
module.to(device=device)
elif not is_loaded_in_4bit_bnb and not is_loaded_in_8bit_bnb:
module.to(device, dtype)
if is_device_mapped:
logger.warning(f"{module.__class__.__name__} is has a device map {module.hf_device_map} and will not be moved to {device}.")
else:
module.to(device, dtype)

if (
module.dtype == torch.float16
Expand Down Expand Up @@ -1014,7 +1004,10 @@ def remove_all_hooks(self):
for _, model in self.components.items():
if isinstance(model, torch.nn.Module) and hasattr(model, "_hf_hook"):
accelerate.hooks.remove_hook_from_module(model, recurse=True)
self._all_hooks = []
if hasattr(self, "_all_hooks"):
self._all_hooks = []
if hasattr(self, "_all_sequential_hooks"):
self._all_sequential_hooks = []

def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
r"""
Expand Down Expand Up @@ -1166,17 +1159,18 @@ def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Un
if hasattr(device_mod, "empty_cache") and device_mod.is_available():
device_mod.empty_cache() # otherwise we don't see the memory savings (but they probably exist)

self._all_sequential_hooks = []
for name, model in self.components.items():
if not isinstance(model, torch.nn.Module):
continue

if name in self._exclude_from_cpu_offload:
model.to(device)
else:
# make sure to offload buffers if not all high level weights
# are of type nn.Module
offload_buffers = len(model._parameters) > 0
cpu_offload(model, device, offload_buffers=offload_buffers)
self.all_sequential_hooks.append(model._hf_hook)

def reset_device_map(self):
r"""
Expand Down
11 changes: 6 additions & 5 deletions tests/models/test_modeling_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
import requests_mock
import torch
import torch.nn as nn
from accelerate.utils.modeling import _get_proper_dtype, dtype_byte_size
from accelerate.utils.modeling import _get_proper_dtype, dtype_byte_size, compute_module_sizes
from huggingface_hub import ModelCard, delete_repo, snapshot_download
from huggingface_hub.utils import is_jinja_available
from parameterized import parameterized
Expand Down Expand Up @@ -1080,7 +1080,7 @@ def test_cpu_offload(self):
torch.manual_seed(0)
base_output = model(**inputs_dict)

model_size = compute_module_persistent_sizes(model)[""]
model_size = compute_module_sizes(model)[""]
# We test several splits of sizes to make sure it works.
max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents[1:]]
with tempfile.TemporaryDirectory() as tmp_dir:
Expand Down Expand Up @@ -1110,7 +1110,7 @@ def test_disk_offload_without_safetensors(self):
torch.manual_seed(0)
base_output = model(**inputs_dict)

model_size = compute_module_persistent_sizes(model)[""]
model_size = compute_module_sizes(model)[""]
with tempfile.TemporaryDirectory() as tmp_dir:
model.cpu().save_pretrained(tmp_dir, safe_serialization=False)

Expand Down Expand Up @@ -1144,7 +1144,7 @@ def test_disk_offload_with_safetensors(self):
torch.manual_seed(0)
base_output = model(**inputs_dict)

model_size = compute_module_persistent_sizes(model)[""]
model_size = compute_module_sizes(model)[""]
with tempfile.TemporaryDirectory() as tmp_dir:
model.cpu().save_pretrained(tmp_dir)

Expand Down Expand Up @@ -1172,7 +1172,7 @@ def test_model_parallelism(self):
torch.manual_seed(0)
base_output = model(**inputs_dict)

model_size = compute_module_persistent_sizes(model)[""]
model_size = compute_module_sizes(model)[""]
# We test several splits of sizes to make sure it works.
max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents[1:]]
with tempfile.TemporaryDirectory() as tmp_dir:
Expand All @@ -1183,6 +1183,7 @@ def test_model_parallelism(self):
new_model = self.model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)
# Making sure part of the model will actually end up offloaded
self.assertSetEqual(set(new_model.hf_device_map.values()), {0, 1})
print(f" new_model.hf_device_map:{new_model.hf_device_map}")

self.check_device_map_is_respected(new_model, new_model.hf_device_map)

Expand Down
25 changes: 1 addition & 24 deletions tests/models/transformers/test_models_transformer_sana.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ class SanaTransformerTests(ModelTesterMixin, unittest.TestCase):
model_class = SanaTransformer2DModel
main_input_name = "hidden_states"
uses_custom_attn_processor = True
model_split_percents = [0.7, 0.7, 0.9]

@property
def dummy_input(self):
Expand Down Expand Up @@ -81,27 +82,3 @@ def prepare_init_args_and_inputs_for_common(self):
def test_gradient_checkpointing_is_applied(self):
expected_set = {"SanaTransformer2DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

@pytest.mark.xfail(
condition=torch.device(torch_device).type == "cuda",
reason="Test currently fails.",
strict=True,
)
def test_cpu_offload(self):
return super().test_cpu_offload()

@pytest.mark.xfail(
condition=torch.device(torch_device).type == "cuda",
reason="Test currently fails.",
strict=True,
)
def test_disk_offload_with_safetensors(self):
return super().test_disk_offload_with_safetensors()

@pytest.mark.xfail(
condition=torch.device(torch_device).type == "cuda",
reason="Test currently fails.",
strict=True,
)
def test_disk_offload_without_safetensors(self):
return super().test_disk_offload_without_safetensors()
Loading