Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add _no_split_modules to some models #10308

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/diffusers/models/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1214,7 +1214,7 @@ def _get_signature_keys(cls, obj):
# Adapted from `transformers` modeling_utils.py
def _get_no_split_modules(self, device_map: str):
"""
Get the modules of the model that should not be spit when using device_map. We iterate through the modules to
Get the modules of the model that should not be split when using device_map. We iterate through the modules to
get the underlying `_no_split_modules`.

Args:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
"""

_supports_gradient_checkpointing = True
_no_split_modules = ["CogVideoXBlock", "CogVideoXPatchEmbed"]

@register_to_config
def __init__(
Expand Down
2 changes: 2 additions & 0 deletions src/diffusers/models/transformers/transformer_allegro.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,8 @@ class AllegroTransformer3DModel(ModelMixin, ConfigMixin):
Scaling factor to apply in 3D positional embeddings across time dimension.
"""

_supports_gradient_checkpointing = True

@register_to_config
def __init__(
self,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ class CogView3PlusTransformer2DModel(ModelMixin, ConfigMixin):
"""

_supports_gradient_checkpointing = True
_no_split_modules = ["CogView3PlusTransformerBlock", "CogView3PlusPatchEmbed"]

@register_to_config
def __init__(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -542,6 +542,12 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
"""

_supports_gradient_checkpointing = True
_no_split_modules = [
"HunyuanVideoTransformerBlock",
"HunyuanVideoSingleTransformerBlock",
"HunyuanVideoPatchEmbed",
"HunyuanVideoTokenRefiner",
]

@register_to_config
def __init__(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def prepare_init_args_and_inputs_for_common(self):
"out_channels": 4,
"time_embed_dim": 2,
"text_embed_dim": 8,
"num_layers": 1,
"num_layers": 2,
"sample_width": 8,
"sample_height": 8,
"sample_frames": 8,
Expand Down Expand Up @@ -130,7 +130,7 @@ def prepare_init_args_and_inputs_for_common(self):
"out_channels": 4,
"time_embed_dim": 2,
"text_embed_dim": 8,
"num_layers": 1,
"num_layers": 2,
"sample_width": 8,
"sample_height": 8,
"sample_frames": 8,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"patch_size": 2,
"in_channels": 4,
"num_layers": 1,
"num_layers": 2,
"attention_head_dim": 4,
"num_attention_heads": 2,
"out_channels": 4,
Expand Down
Loading