From 4b557132ce955d58fd84572c03e79f43bdc91450 Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 23 Dec 2024 19:51:33 +0530 Subject: [PATCH] [core] LTX Video 0.9.1 (#10330) * update * make style * update * update * update * make style * single file related changes * update * fix * update single file urls and docs * update * fix --- docs/source/en/api/pipelines/ltx_video.md | 42 ++- scripts/convert_ltx_to_diffusers.py | 110 +++++++- src/diffusers/loaders/single_file_utils.py | 28 +- .../models/autoencoders/autoencoder_kl_ltx.py | 264 +++++++++++++++--- src/diffusers/pipelines/ltx/pipeline_ltx.py | 26 +- .../pipelines/ltx/pipeline_ltx_image2video.py | 26 +- tests/lora/test_lora_layers_ltx_video.py | 11 +- .../test_models_autoencoder_ltx_video.py | 169 +++++++++++ tests/pipelines/ltx/test_ltx.py | 11 +- tests/pipelines/ltx/test_ltx_image2video.py | 11 +- 10 files changed, 642 insertions(+), 56 deletions(-) create mode 100644 tests/models/autoencoders/test_models_autoencoder_ltx_video.py diff --git a/docs/source/en/api/pipelines/ltx_video.md b/docs/source/en/api/pipelines/ltx_video.md index a925b848706e..017a8ac49e53 100644 --- a/docs/source/en/api/pipelines/ltx_video.md +++ b/docs/source/en/api/pipelines/ltx_video.md @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. --> -# LTX +# LTX Video [LTX Video](https://huggingface.co/Lightricks/LTX-Video) is the first DiT-based video generation model capable of generating high-quality videos in real-time. It produces 24 FPS videos at a 768x512 resolution faster than they can be watched. Trained on a large-scale dataset of diverse videos, the model generates high-resolution videos with realistic and varied content. We provide a model for both text-to-video as well as image + text-to-video usecases. @@ -22,14 +22,24 @@ Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.m +Available models: + +| Model name | Recommended dtype | +|:-------------:|:-----------------:| +| [`LTX Video 0.9.0`](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltx-video-2b-v0.9.safetensors) | `torch.bfloat16` | +| [`LTX Video 0.9.1`](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltx-video-2b-v0.9.1.safetensors) | `torch.bfloat16` | + +Note: The recommended dtype is for the transformer component. The VAE and text encoders can be either `torch.float32`, `torch.bfloat16` or `torch.float16` but the recommended dtype is `torch.bfloat16` as used in the original repository. + ## Loading Single Files -Loading the original LTX Video checkpoints is also possible with [`~ModelMixin.from_single_file`]. +Loading the original LTX Video checkpoints is also possible with [`~ModelMixin.from_single_file`]. We recommend using `from_single_file` for the Lightricks series of models, as they plan to release multiple models in the future in the single file format. ```python import torch from diffusers import AutoencoderKLLTXVideo, LTXImageToVideoPipeline, LTXVideoTransformer3DModel +# `single_file_url` could also be https://huggingface.co/Lightricks/LTX-Video/ltx-video-2b-v0.9.1.safetensors single_file_url = "https://huggingface.co/Lightricks/LTX-Video/ltx-video-2b-v0.9.safetensors" transformer = LTXVideoTransformer3DModel.from_single_file( single_file_url, torch_dtype=torch.bfloat16 @@ -99,6 +109,34 @@ export_to_video(video, "output_gguf_ltx.mp4", fps=24) Make sure to read the [documentation on GGUF](../../quantization/gguf) to learn more about our GGUF support. + + +Loading and running inference with [LTX Video 0.9.1](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltx-video-2b-v0.9.1.safetensors) weights. + +```python +import torch +from diffusers import LTXPipeline +from diffusers.utils import export_to_video + +pipe = LTXPipeline.from_pretrained("a-r-r-o-w/LTX-Video-0.9.1-diffusers", torch_dtype=torch.bfloat16) +pipe.to("cuda") + +prompt = "A woman with long brown hair and light skin smiles at another woman with long blonde hair. The woman with brown hair wears a black jacket and has a small, barely noticeable mole on her right cheek. The camera angle is a close-up, focused on the woman with brown hair's face. The lighting is warm and natural, likely from the setting sun, casting a soft glow on the scene. The scene appears to be real-life footage" +negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted" + +video = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + width=768, + height=512, + num_frames=161, + decode_timestep=0.03, + decode_noise_scale=0.025, + num_inference_steps=50, +).frames[0] +export_to_video(video, "output.mp4", fps=24) +``` + Refer to [this section](https://huggingface.co/docs/diffusers/main/en/api/pipelines/cogvideox#memory-optimization) to learn more about optimizing memory consumption. ## LTXPipeline diff --git a/scripts/convert_ltx_to_diffusers.py b/scripts/convert_ltx_to_diffusers.py index f4398a2e687c..7df0745fd98c 100644 --- a/scripts/convert_ltx_to_diffusers.py +++ b/scripts/convert_ltx_to_diffusers.py @@ -1,7 +1,9 @@ import argparse +from pathlib import Path from typing import Any, Dict import torch +from accelerate import init_empty_weights from safetensors.torch import load_file from transformers import T5EncoderModel, T5Tokenizer @@ -21,7 +23,9 @@ def remove_keys_(key: str, state_dict: Dict[str, Any]): "k_norm": "norm_k", } -TRANSFORMER_SPECIAL_KEYS_REMAP = {} +TRANSFORMER_SPECIAL_KEYS_REMAP = { + "vae": remove_keys_, +} VAE_KEYS_RENAME_DICT = { # decoder @@ -54,10 +58,31 @@ def remove_keys_(key: str, state_dict: Dict[str, Any]): "per_channel_statistics.std-of-means": "latents_std", } +VAE_091_RENAME_DICT = { + # decoder + "up_blocks.0": "mid_block", + "up_blocks.1": "up_blocks.0.upsamplers.0", + "up_blocks.2": "up_blocks.0", + "up_blocks.3": "up_blocks.1.upsamplers.0", + "up_blocks.4": "up_blocks.1", + "up_blocks.5": "up_blocks.2.upsamplers.0", + "up_blocks.6": "up_blocks.2", + "up_blocks.7": "up_blocks.3.upsamplers.0", + "up_blocks.8": "up_blocks.3", + # common + "last_time_embedder": "time_embedder", + "last_scale_shift_table": "scale_shift_table", +} + VAE_SPECIAL_KEYS_REMAP = { "per_channel_statistics.channel": remove_keys_, "per_channel_statistics.mean-of-means": remove_keys_, "per_channel_statistics.mean-of-stds": remove_keys_, + "model.diffusion_model": remove_keys_, +} + +VAE_091_SPECIAL_KEYS_REMAP = { + "timestep_scale_multiplier": remove_keys_, } @@ -80,13 +105,16 @@ def convert_transformer( ckpt_path: str, dtype: torch.dtype, ): - PREFIX_KEY = "" + PREFIX_KEY = "model.diffusion_model." original_state_dict = get_state_dict(load_file(ckpt_path)) - transformer = LTXVideoTransformer3DModel().to(dtype=dtype) + with init_empty_weights(): + transformer = LTXVideoTransformer3DModel() for key in list(original_state_dict.keys()): - new_key = key[len(PREFIX_KEY) :] + new_key = key[:] + if new_key.startswith(PREFIX_KEY): + new_key = key[len(PREFIX_KEY) :] for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items(): new_key = new_key.replace(replace_key, rename_key) update_state_dict_inplace(original_state_dict, key, new_key) @@ -97,16 +125,21 @@ def convert_transformer( continue handler_fn_inplace(key, original_state_dict) - transformer.load_state_dict(original_state_dict, strict=True) + transformer.load_state_dict(original_state_dict, strict=True, assign=True) return transformer -def convert_vae(ckpt_path: str, dtype: torch.dtype): +def convert_vae(ckpt_path: str, config, dtype: torch.dtype): + PREFIX_KEY = "vae." + original_state_dict = get_state_dict(load_file(ckpt_path)) - vae = AutoencoderKLLTXVideo().to(dtype=dtype) + with init_empty_weights(): + vae = AutoencoderKLLTXVideo(**config) for key in list(original_state_dict.keys()): new_key = key[:] + if new_key.startswith(PREFIX_KEY): + new_key = key[len(PREFIX_KEY) :] for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items(): new_key = new_key.replace(replace_key, rename_key) update_state_dict_inplace(original_state_dict, key, new_key) @@ -117,10 +150,60 @@ def convert_vae(ckpt_path: str, dtype: torch.dtype): continue handler_fn_inplace(key, original_state_dict) - vae.load_state_dict(original_state_dict, strict=True) + vae.load_state_dict(original_state_dict, strict=True, assign=True) return vae +def get_vae_config(version: str) -> Dict[str, Any]: + if version == "0.9.0": + config = { + "in_channels": 3, + "out_channels": 3, + "latent_channels": 128, + "block_out_channels": (128, 256, 512, 512), + "decoder_block_out_channels": (128, 256, 512, 512), + "layers_per_block": (4, 3, 3, 3, 4), + "decoder_layers_per_block": (4, 3, 3, 3, 4), + "spatio_temporal_scaling": (True, True, True, False), + "decoder_spatio_temporal_scaling": (True, True, True, False), + "decoder_inject_noise": (False, False, False, False, False), + "upsample_residual": (False, False, False, False), + "upsample_factor": (1, 1, 1, 1), + "patch_size": 4, + "patch_size_t": 1, + "resnet_norm_eps": 1e-6, + "scaling_factor": 1.0, + "encoder_causal": True, + "decoder_causal": False, + "timestep_conditioning": False, + } + elif version == "0.9.1": + config = { + "in_channels": 3, + "out_channels": 3, + "latent_channels": 128, + "block_out_channels": (128, 256, 512, 512), + "decoder_block_out_channels": (256, 512, 1024), + "layers_per_block": (4, 3, 3, 3, 4), + "decoder_layers_per_block": (5, 6, 7, 8), + "spatio_temporal_scaling": (True, True, True, False), + "decoder_spatio_temporal_scaling": (True, True, True), + "decoder_inject_noise": (True, True, True, False), + "upsample_residual": (True, True, True), + "upsample_factor": (2, 2, 2), + "timestep_conditioning": True, + "patch_size": 4, + "patch_size_t": 1, + "resnet_norm_eps": 1e-6, + "scaling_factor": 1.0, + "encoder_causal": True, + "decoder_causal": False, + } + VAE_KEYS_RENAME_DICT.update(VAE_091_RENAME_DICT) + VAE_SPECIAL_KEYS_REMAP.update(VAE_091_SPECIAL_KEYS_REMAP) + return config + + def get_args(): parser = argparse.ArgumentParser() parser.add_argument( @@ -139,6 +222,9 @@ def get_args(): parser.add_argument("--save_pipeline", action="store_true") parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved") parser.add_argument("--dtype", default="fp32", help="Torch dtype to save the model in.") + parser.add_argument( + "--version", type=str, default="0.9.0", choices=["0.9.0", "0.9.1"], help="Version of the LTX model" + ) return parser.parse_args() @@ -161,6 +247,7 @@ def get_args(): transformer = None dtype = DTYPE_MAPPING[args.dtype] variant = VARIANT_MAPPING[args.dtype] + output_path = Path(args.output_path) if args.save_pipeline: assert args.transformer_ckpt_path is not None and args.vae_ckpt_path is not None @@ -169,13 +256,14 @@ def get_args(): transformer: LTXVideoTransformer3DModel = convert_transformer(args.transformer_ckpt_path, dtype) if not args.save_pipeline: transformer.save_pretrained( - args.output_path, safe_serialization=True, max_shard_size="5GB", variant=variant + output_path / "transformer", safe_serialization=True, max_shard_size="5GB", variant=variant ) if args.vae_ckpt_path is not None: - vae: AutoencoderKLLTXVideo = convert_vae(args.vae_ckpt_path, dtype) + config = get_vae_config(args.version) + vae: AutoencoderKLLTXVideo = convert_vae(args.vae_ckpt_path, config, dtype) if not args.save_pipeline: - vae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB", variant=variant) + vae.save_pretrained(output_path / "vae", safe_serialization=True, max_shard_size="5GB", variant=variant) if args.save_pipeline: text_encoder_id = "google/t5-v1_1-xxl" diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py index 6de9f0e9e638..b623576e3990 100644 --- a/src/diffusers/loaders/single_file_utils.py +++ b/src/diffusers/loaders/single_file_utils.py @@ -157,7 +157,8 @@ "flux-fill": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-Fill-dev"}, "flux-depth": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-Depth-dev"}, "flux-schnell": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-schnell"}, - "ltx-video": {"pretrained_model_name_or_path": "Lightricks/LTX-Video"}, + "ltx-video": {"pretrained_model_name_or_path": "diffusers/LTX-Video-0.9.0"}, + "ltx-video-0.9.1": {"pretrained_model_name_or_path": "diffusers/LTX-Video-0.9.1"}, "autoencoder-dc-f128c512": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f128c512-mix-1.0-diffusers"}, "autoencoder-dc-f64c128": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f64c128-mix-1.0-diffusers"}, "autoencoder-dc-f32c32": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f32c32-mix-1.0-diffusers"}, @@ -605,7 +606,10 @@ def infer_diffusers_model_type(checkpoint): model_type = "flux-schnell" elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["ltx-video"]): - model_type = "ltx-video" + if "vae.decoder.last_time_embedder.timestep_embedder.linear_1.weight" in checkpoint: + model_type = "ltx-video-0.9.1" + else: + model_type = "ltx-video" elif CHECKPOINT_KEY_NAMES["autoencoder-dc"] in checkpoint: encoder_key = "encoder.project_in.conv.conv.bias" @@ -2338,12 +2342,32 @@ def remove_keys_(key: str, state_dict): "per_channel_statistics.std-of-means": "latents_std", } + VAE_091_RENAME_DICT = { + # decoder + "up_blocks.0": "mid_block", + "up_blocks.1": "up_blocks.0.upsamplers.0", + "up_blocks.2": "up_blocks.0", + "up_blocks.3": "up_blocks.1.upsamplers.0", + "up_blocks.4": "up_blocks.1", + "up_blocks.5": "up_blocks.2.upsamplers.0", + "up_blocks.6": "up_blocks.2", + "up_blocks.7": "up_blocks.3.upsamplers.0", + "up_blocks.8": "up_blocks.3", + # common + "last_time_embedder": "time_embedder", + "last_scale_shift_table": "scale_shift_table", + } + VAE_SPECIAL_KEYS_REMAP = { "per_channel_statistics.channel": remove_keys_, "per_channel_statistics.mean-of-means": remove_keys_, "per_channel_statistics.mean-of-stds": remove_keys_, + "timestep_scale_multiplier": remove_keys_, } + if "vae.decoder.last_time_embedder.timestep_embedder.linear_1.weight" in converted_state_dict: + VAE_KEYS_RENAME_DICT.update(VAE_091_RENAME_DICT) + for key in list(converted_state_dict.keys()): new_key = key for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items(): diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py b/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py index a6cb943e09cc..9aa53f7af243 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py @@ -22,6 +22,7 @@ from ...loaders import FromOriginalModelMixin from ...utils.accelerate_utils import apply_forward_hook from ..activations import get_activation +from ..embeddings import PixArtAlphaCombinedTimestepSizeEmbeddings from ..modeling_outputs import AutoencoderKLOutput from ..modeling_utils import ModelMixin from ..normalization import RMSNorm @@ -109,7 +110,9 @@ def __init__( elementwise_affine: bool = False, non_linearity: str = "swish", is_causal: bool = True, - ): + inject_noise: bool = False, + timestep_conditioning: bool = False, + ) -> None: super().__init__() out_channels = out_channels or in_channels @@ -135,18 +138,54 @@ def __init__( in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, is_causal=is_causal ) - def forward(self, inputs: torch.Tensor) -> torch.Tensor: + self.per_channel_scale1 = None + self.per_channel_scale2 = None + if inject_noise: + self.per_channel_scale1 = nn.Parameter(torch.zeros(in_channels, 1, 1)) + self.per_channel_scale2 = nn.Parameter(torch.zeros(in_channels, 1, 1)) + + self.scale_shift_table = None + if timestep_conditioning: + self.scale_shift_table = nn.Parameter(torch.randn(4, in_channels) / in_channels**0.5) + + def forward( + self, inputs: torch.Tensor, temb: Optional[torch.Tensor] = None, generator: Optional[torch.Generator] = None + ) -> torch.Tensor: hidden_states = inputs hidden_states = self.norm1(hidden_states.movedim(1, -1)).movedim(-1, 1) + + if self.scale_shift_table is not None: + temb = temb.unflatten(1, (4, -1)) + self.scale_shift_table[None, ..., None, None, None] + shift_1, scale_1, shift_2, scale_2 = temb.unbind(dim=1) + hidden_states = hidden_states * (1 + scale_1) + shift_1 + hidden_states = self.nonlinearity(hidden_states) hidden_states = self.conv1(hidden_states) + if self.per_channel_scale1 is not None: + spatial_shape = hidden_states.shape[-2:] + spatial_noise = torch.randn( + spatial_shape, generator=generator, device=hidden_states.device, dtype=hidden_states.dtype + )[None] + hidden_states = hidden_states + (spatial_noise * self.per_channel_scale1)[None, :, None, ...] + hidden_states = self.norm2(hidden_states.movedim(1, -1)).movedim(-1, 1) + + if self.scale_shift_table is not None: + hidden_states = hidden_states * (1 + scale_2) + shift_2 + hidden_states = self.nonlinearity(hidden_states) hidden_states = self.dropout(hidden_states) hidden_states = self.conv2(hidden_states) + if self.per_channel_scale2 is not None: + spatial_shape = hidden_states.shape[-2:] + spatial_noise = torch.randn( + spatial_shape, generator=generator, device=hidden_states.device, dtype=hidden_states.dtype + )[None] + hidden_states = hidden_states + (spatial_noise * self.per_channel_scale2)[None, :, None, ...] + if self.norm3 is not None: inputs = self.norm3(inputs.movedim(1, -1)).movedim(-1, 1) @@ -163,12 +202,16 @@ def __init__( in_channels: int, stride: Union[int, Tuple[int, int, int]] = 1, is_causal: bool = True, + residual: bool = False, + upscale_factor: int = 1, ) -> None: super().__init__() self.stride = stride if isinstance(stride, tuple) else (stride, stride, stride) + self.residual = residual + self.upscale_factor = upscale_factor - out_channels = in_channels * stride[0] * stride[1] * stride[2] + out_channels = (in_channels * stride[0] * stride[1] * stride[2]) // upscale_factor self.conv = LTXVideoCausalConv3d( in_channels=in_channels, @@ -181,6 +224,15 @@ def __init__( def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: batch_size, num_channels, num_frames, height, width = hidden_states.shape + if self.residual: + residual = hidden_states.reshape( + batch_size, -1, self.stride[0], self.stride[1], self.stride[2], num_frames, height, width + ) + residual = residual.permute(0, 1, 5, 2, 6, 3, 7, 4).flatten(6, 7).flatten(4, 5).flatten(2, 3) + repeats = (self.stride[0] * self.stride[1] * self.stride[2]) // self.upscale_factor + residual = residual.repeat(1, repeats, 1, 1, 1) + residual = residual[:, :, self.stride[0] - 1 :] + hidden_states = self.conv(hidden_states) hidden_states = hidden_states.reshape( batch_size, -1, self.stride[0], self.stride[1], self.stride[2], num_frames, height, width @@ -188,6 +240,9 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = hidden_states.permute(0, 1, 5, 2, 6, 3, 7, 4).flatten(6, 7).flatten(4, 5).flatten(2, 3) hidden_states = hidden_states[:, :, self.stride[0] - 1 :] + if self.residual: + hidden_states = hidden_states + residual + return hidden_states @@ -273,7 +328,12 @@ def __init__( self.gradient_checkpointing = False - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + def forward( + self, + hidden_states: torch.Tensor, + temb: Optional[torch.Tensor] = None, + generator: Optional[torch.Generator] = None, + ) -> torch.Tensor: r"""Forward method of the `LTXDownBlock3D` class.""" for i, resnet in enumerate(self.resnets): @@ -285,16 +345,18 @@ def create_forward(*inputs): return create_forward - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb, generator + ) else: - hidden_states = resnet(hidden_states) + hidden_states = resnet(hidden_states, temb, generator) if self.downsamplers is not None: for downsampler in self.downsamplers: hidden_states = downsampler(hidden_states) if self.conv_out is not None: - hidden_states = self.conv_out(hidden_states) + hidden_states = self.conv_out(hidden_states, temb, generator) return hidden_states @@ -329,9 +391,15 @@ def __init__( resnet_eps: float = 1e-6, resnet_act_fn: str = "swish", is_causal: bool = True, + inject_noise: bool = False, + timestep_conditioning: bool = False, ) -> None: super().__init__() + self.time_embedder = None + if timestep_conditioning: + self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(in_channels * 4, 0) + resnets = [] for _ in range(num_layers): resnets.append( @@ -342,15 +410,32 @@ def __init__( eps=resnet_eps, non_linearity=resnet_act_fn, is_causal=is_causal, + inject_noise=inject_noise, + timestep_conditioning=timestep_conditioning, ) ) self.resnets = nn.ModuleList(resnets) self.gradient_checkpointing = False - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + def forward( + self, + hidden_states: torch.Tensor, + temb: Optional[torch.Tensor] = None, + generator: Optional[torch.Generator] = None, + ) -> torch.Tensor: r"""Forward method of the `LTXMidBlock3D` class.""" + if self.time_embedder is not None: + temb = self.time_embedder( + timestep=temb.flatten(), + resolution=None, + aspect_ratio=None, + batch_size=hidden_states.size(0), + hidden_dtype=hidden_states.dtype, + ) + temb = temb.view(hidden_states.size(0), -1, 1, 1, 1) + for i, resnet in enumerate(self.resnets): if torch.is_grad_enabled() and self.gradient_checkpointing: @@ -360,9 +445,11 @@ def create_forward(*inputs): return create_forward - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb, generator + ) else: - hidden_states = resnet(hidden_states) + hidden_states = resnet(hidden_states, temb, generator) return hidden_states @@ -403,11 +490,19 @@ def __init__( resnet_act_fn: str = "swish", spatio_temporal_scale: bool = True, is_causal: bool = True, + inject_noise: bool = False, + timestep_conditioning: bool = False, + upsample_residual: bool = False, + upscale_factor: int = 1, ): super().__init__() out_channels = out_channels or in_channels + self.time_embedder = None + if timestep_conditioning: + self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(in_channels * 4, 0) + self.conv_in = None if in_channels != out_channels: self.conv_in = LTXVideoResnetBlock3d( @@ -417,11 +512,23 @@ def __init__( eps=resnet_eps, non_linearity=resnet_act_fn, is_causal=is_causal, + inject_noise=inject_noise, + timestep_conditioning=timestep_conditioning, ) self.upsamplers = None if spatio_temporal_scale: - self.upsamplers = nn.ModuleList([LTXVideoUpsampler3d(out_channels, stride=(2, 2, 2), is_causal=is_causal)]) + self.upsamplers = nn.ModuleList( + [ + LTXVideoUpsampler3d( + out_channels * upscale_factor, + stride=(2, 2, 2), + is_causal=is_causal, + residual=upsample_residual, + upscale_factor=upscale_factor, + ) + ] + ) resnets = [] for _ in range(num_layers): @@ -433,15 +540,32 @@ def __init__( eps=resnet_eps, non_linearity=resnet_act_fn, is_causal=is_causal, + inject_noise=inject_noise, + timestep_conditioning=timestep_conditioning, ) ) self.resnets = nn.ModuleList(resnets) self.gradient_checkpointing = False - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + def forward( + self, + hidden_states: torch.Tensor, + temb: Optional[torch.Tensor] = None, + generator: Optional[torch.Generator] = None, + ) -> torch.Tensor: if self.conv_in is not None: - hidden_states = self.conv_in(hidden_states) + hidden_states = self.conv_in(hidden_states, temb, generator) + + if self.time_embedder is not None: + temb = self.time_embedder( + timestep=temb.flatten(), + resolution=None, + aspect_ratio=None, + batch_size=hidden_states.size(0), + hidden_dtype=hidden_states.dtype, + ) + temb = temb.view(hidden_states.size(0), -1, 1, 1, 1) if self.upsamplers is not None: for upsampler in self.upsamplers: @@ -456,9 +580,11 @@ def create_forward(*inputs): return create_forward - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb, generator + ) else: - hidden_states = resnet(hidden_states) + hidden_states = resnet(hidden_states, temb, generator) return hidden_states @@ -623,6 +749,8 @@ class LTXVideoDecoder3d(nn.Module): Epsilon value for ResNet normalization layers. is_causal (`bool`, defaults to `False`): Whether this layer behaves causally (future frames depend only on past frames) or not. + timestep_conditioning (`bool`, defaults to `False`): + Whether to condition the model on timesteps. """ def __init__( @@ -636,6 +764,10 @@ def __init__( patch_size_t: int = 1, resnet_norm_eps: float = 1e-6, is_causal: bool = False, + inject_noise: Tuple[bool, ...] = (False, False, False, False), + timestep_conditioning: bool = False, + upsample_residual: Tuple[bool, ...] = (False, False, False, False), + upsample_factor: Tuple[bool, ...] = (1, 1, 1, 1), ) -> None: super().__init__() @@ -646,6 +778,9 @@ def __init__( block_out_channels = tuple(reversed(block_out_channels)) spatio_temporal_scaling = tuple(reversed(spatio_temporal_scaling)) layers_per_block = tuple(reversed(layers_per_block)) + inject_noise = tuple(reversed(inject_noise)) + upsample_residual = tuple(reversed(upsample_residual)) + upsample_factor = tuple(reversed(upsample_factor)) output_channel = block_out_channels[0] self.conv_in = LTXVideoCausalConv3d( @@ -653,15 +788,20 @@ def __init__( ) self.mid_block = LTXVideoMidBlock3d( - in_channels=output_channel, num_layers=layers_per_block[0], resnet_eps=resnet_norm_eps, is_causal=is_causal + in_channels=output_channel, + num_layers=layers_per_block[0], + resnet_eps=resnet_norm_eps, + is_causal=is_causal, + inject_noise=inject_noise[0], + timestep_conditioning=timestep_conditioning, ) # up blocks num_block_out_channels = len(block_out_channels) self.up_blocks = nn.ModuleList([]) for i in range(num_block_out_channels): - input_channel = output_channel - output_channel = block_out_channels[i] + input_channel = output_channel // upsample_factor[i] + output_channel = block_out_channels[i] // upsample_factor[i] up_block = LTXVideoUpBlock3d( in_channels=input_channel, @@ -670,6 +810,10 @@ def __init__( resnet_eps=resnet_norm_eps, spatio_temporal_scale=spatio_temporal_scaling[i], is_causal=is_causal, + inject_noise=inject_noise[i + 1], + timestep_conditioning=timestep_conditioning, + upsample_residual=upsample_residual[i], + upscale_factor=upsample_factor[i], ) self.up_blocks.append(up_block) @@ -681,9 +825,16 @@ def __init__( in_channels=output_channel, out_channels=self.out_channels, kernel_size=3, stride=1, is_causal=is_causal ) + # timestep embedding + self.time_embedder = None + self.scale_shift_table = None + if timestep_conditioning: + self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(output_channel * 2, 0) + self.scale_shift_table = nn.Parameter(torch.randn(2, output_channel) / output_channel**0.5) + self.gradient_checkpointing = False - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor: hidden_states = self.conv_in(hidden_states) if torch.is_grad_enabled() and self.gradient_checkpointing: @@ -694,17 +845,33 @@ def create_forward(*inputs): return create_forward - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), hidden_states) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(self.mid_block), hidden_states, temb + ) for up_block in self.up_blocks: - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), hidden_states) + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), hidden_states, temb) else: - hidden_states = self.mid_block(hidden_states) + hidden_states = self.mid_block(hidden_states, temb) for up_block in self.up_blocks: - hidden_states = up_block(hidden_states) + hidden_states = up_block(hidden_states, temb) hidden_states = self.norm_out(hidden_states.movedim(1, -1)).movedim(-1, 1) + + if self.time_embedder is not None: + temb = self.time_embedder( + timestep=temb.flatten(), + resolution=None, + aspect_ratio=None, + batch_size=hidden_states.size(0), + hidden_dtype=hidden_states.dtype, + ) + temb = temb.view(hidden_states.size(0), -1, 1, 1, 1).unflatten(1, (2, -1)) + temb = temb + self.scale_shift_table[None, ..., None, None, None] + shift, scale = temb.unbind(dim=1) + hidden_states = hidden_states * (1 + scale) + shift + hidden_states = self.conv_act(hidden_states) hidden_states = self.conv_out(hidden_states) @@ -767,8 +934,15 @@ def __init__( out_channels: int = 3, latent_channels: int = 128, block_out_channels: Tuple[int, ...] = (128, 256, 512, 512), - spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, False), + decoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 512), layers_per_block: Tuple[int, ...] = (4, 3, 3, 3, 4), + decoder_layers_per_block: Tuple[int, ...] = (4, 3, 3, 3, 4), + spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, False), + decoder_spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, False), + decoder_inject_noise: Tuple[bool, ...] = (False, False, False, False, False), + upsample_residual: Tuple[bool, ...] = (False, False, False, False), + upsample_factor: Tuple[int, ...] = (1, 1, 1, 1), + timestep_conditioning: bool = False, patch_size: int = 4, patch_size_t: int = 1, resnet_norm_eps: float = 1e-6, @@ -792,13 +966,17 @@ def __init__( self.decoder = LTXVideoDecoder3d( in_channels=latent_channels, out_channels=out_channels, - block_out_channels=block_out_channels, - spatio_temporal_scaling=spatio_temporal_scaling, - layers_per_block=layers_per_block, + block_out_channels=decoder_block_out_channels, + spatio_temporal_scaling=decoder_spatio_temporal_scaling, + layers_per_block=decoder_layers_per_block, patch_size=patch_size, patch_size_t=patch_size_t, resnet_norm_eps=resnet_norm_eps, is_causal=decoder_causal, + timestep_conditioning=timestep_conditioning, + inject_noise=decoder_inject_noise, + upsample_residual=upsample_residual, + upsample_factor=upsample_factor, ) latents_mean = torch.zeros((latent_channels,), requires_grad=False) @@ -937,13 +1115,15 @@ def encode( return (posterior,) return AutoencoderKLOutput(latent_dist=posterior) - def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: + def _decode( + self, z: torch.Tensor, temb: Optional[torch.Tensor] = None, return_dict: bool = True + ) -> Union[DecoderOutput, torch.Tensor]: batch_size, num_channels, num_frames, height, width = z.shape tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio tile_latent_min_width = self.tile_sample_stride_width // self.spatial_compression_ratio if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height): - return self.tiled_decode(z, return_dict=return_dict) + return self.tiled_decode(z, temb, return_dict=return_dict) if self.use_framewise_decoding: # TODO(aryan): requires investigation @@ -953,7 +1133,7 @@ def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOut "should be possible, please submit a PR to https://github.com/huggingface/diffusers/pulls." ) else: - dec = self.decoder(z) + dec = self.decoder(z, temb) if not return_dict: return (dec,) @@ -961,7 +1141,9 @@ def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOut return DecoderOutput(sample=dec) @apply_forward_hook - def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: + def decode( + self, z: torch.Tensor, temb: Optional[torch.Tensor] = None, return_dict: bool = True + ) -> Union[DecoderOutput, torch.Tensor]: """ Decode a batch of images. @@ -976,10 +1158,15 @@ def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutp returned. """ if self.use_slicing and z.shape[0] > 1: - decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)] + if temb is not None: + decoded_slices = [ + self._decode(z_slice, t_slice).sample for z_slice, t_slice in (z.split(1), temb.split(1)) + ] + else: + decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)] decoded = torch.cat(decoded_slices) else: - decoded = self._decode(z).sample + decoded = self._decode(z, temb).sample if not return_dict: return (decoded,) @@ -1061,7 +1248,9 @@ def tiled_encode(self, x: torch.Tensor) -> torch.Tensor: enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width] return enc - def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: + def tiled_decode( + self, z: torch.Tensor, temb: Optional[torch.Tensor], return_dict: bool = True + ) -> Union[DecoderOutput, torch.Tensor]: r""" Decode a batch of images using a tiled decoder. @@ -1102,7 +1291,9 @@ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[Decod "should be possible, please submit a PR to https://github.com/huggingface/diffusers/pulls." ) else: - time = self.decoder(z[:, :, :, i : i + tile_latent_min_height, j : j + tile_latent_min_width]) + time = self.decoder( + z[:, :, :, i : i + tile_latent_min_height, j : j + tile_latent_min_width], temb + ) row.append(time) rows.append(row) @@ -1130,6 +1321,7 @@ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[Decod def forward( self, sample: torch.Tensor, + temb: Optional[torch.Tensor] = None, sample_posterior: bool = False, return_dict: bool = True, generator: Optional[torch.Generator] = None, @@ -1140,7 +1332,7 @@ def forward( z = posterior.sample(generator=generator) else: z = posterior.mode() - dec = self.decode(z) + dec = self.decode(z, temb) if not return_dict: return (dec,) return dec diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx.py b/src/diffusers/pipelines/ltx/pipeline_ltx.py index 7180601dad41..96d41bb3224b 100644 --- a/src/diffusers/pipelines/ltx/pipeline_ltx.py +++ b/src/diffusers/pipelines/ltx/pipeline_ltx.py @@ -511,6 +511,8 @@ def __call__( prompt_attention_mask: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_attention_mask: Optional[torch.Tensor] = None, + decode_timestep: Union[float, List[float]] = 0.0, + decode_noise_scale: Optional[Union[float, List[float]]] = None, output_type: Optional[str] = "pil", return_dict: bool = True, attention_kwargs: Optional[Dict[str, Any]] = None, @@ -563,6 +565,10 @@ def __call__( provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. negative_prompt_attention_mask (`torch.FloatTensor`, *optional*): Pre-generated attention mask for negative text embeddings. + decode_timestep (`float`, defaults to `0.0`): + The timestep at which generated video is decoded. + decode_noise_scale (`float`, defaults to `None`): + The interpolation factor between random noise and denoised latents at the decode timestep. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. @@ -753,7 +759,25 @@ def __call__( latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor ) latents = latents.to(prompt_embeds.dtype) - video = self.vae.decode(latents, return_dict=False)[0] + + if not self.vae.config.timestep_conditioning: + timestep = None + else: + noise = torch.randn(latents.shape, generator=generator, device=device, dtype=latents.dtype) + if not isinstance(decode_timestep, list): + decode_timestep = [decode_timestep] * batch_size + if decode_noise_scale is None: + decode_noise_scale = decode_timestep + elif not isinstance(decode_noise_scale, list): + decode_noise_scale = [decode_noise_scale] * batch_size + + timestep = torch.tensor(decode_timestep, device=device, dtype=latents.dtype) + decode_noise_scale = torch.tensor(decode_noise_scale, device=device, dtype=latents.dtype)[ + :, None, None, None, None + ] + latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise + + video = self.vae.decode(latents, timestep, return_dict=False)[0] video = self.video_processor.postprocess_video(video, output_type=output_type) # Offload all models diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py b/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py index fbb30e304d65..71fd725c915b 100644 --- a/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py +++ b/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py @@ -571,6 +571,8 @@ def __call__( prompt_attention_mask: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_attention_mask: Optional[torch.Tensor] = None, + decode_timestep: Union[float, List[float]] = 0.0, + decode_noise_scale: Optional[Union[float, List[float]]] = None, output_type: Optional[str] = "pil", return_dict: bool = True, attention_kwargs: Optional[Dict[str, Any]] = None, @@ -625,6 +627,10 @@ def __call__( provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. negative_prompt_attention_mask (`torch.FloatTensor`, *optional*): Pre-generated attention mask for negative text embeddings. + decode_timestep (`float`, defaults to `0.0`): + The timestep at which generated video is decoded. + decode_noise_scale (`float`, defaults to `None`): + The interpolation factor between random noise and denoised latents at the decode timestep. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. @@ -849,7 +855,25 @@ def __call__( latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor ) latents = latents.to(prompt_embeds.dtype) - video = self.vae.decode(latents, return_dict=False)[0] + + if not self.vae.config.timestep_conditioning: + timestep = None + else: + noise = torch.randn(latents.shape, generator=generator, device=device, dtype=latents.dtype) + if not isinstance(decode_timestep, list): + decode_timestep = [decode_timestep] * batch_size + if decode_noise_scale is None: + decode_noise_scale = decode_timestep + elif not isinstance(decode_noise_scale, list): + decode_noise_scale = [decode_noise_scale] * batch_size + + timestep = torch.tensor(decode_timestep, device=device, dtype=latents.dtype) + decode_noise_scale = torch.tensor(decode_noise_scale, device=device, dtype=latents.dtype)[ + :, None, None, None, None + ] + latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise + + video = self.vae.decode(latents, timestep, return_dict=False)[0] video = self.video_processor.postprocess_video(video, output_type=output_type) # Offload all models diff --git a/tests/lora/test_lora_layers_ltx_video.py b/tests/lora/test_lora_layers_ltx_video.py index 1ed426f6e8dd..0eccaa73ad42 100644 --- a/tests/lora/test_lora_layers_ltx_video.py +++ b/tests/lora/test_lora_layers_ltx_video.py @@ -52,10 +52,19 @@ class LTXVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): } transformer_cls = LTXVideoTransformer3DModel vae_kwargs = { + "in_channels": 3, + "out_channels": 3, "latent_channels": 8, "block_out_channels": (8, 8, 8, 8), - "spatio_temporal_scaling": (True, True, False, False), + "decoder_block_out_channels": (8, 8, 8, 8), "layers_per_block": (1, 1, 1, 1, 1), + "decoder_layers_per_block": (1, 1, 1, 1, 1), + "spatio_temporal_scaling": (True, True, False, False), + "decoder_spatio_temporal_scaling": (True, True, False, False), + "decoder_inject_noise": (False, False, False, False, False), + "upsample_residual": (False, False, False, False), + "upsample_factor": (1, 1, 1, 1), + "timestep_conditioning": False, "patch_size": 1, "patch_size_t": 1, "encoder_causal": True, diff --git a/tests/models/autoencoders/test_models_autoencoder_ltx_video.py b/tests/models/autoencoders/test_models_autoencoder_ltx_video.py new file mode 100644 index 000000000000..37f9837c8245 --- /dev/null +++ b/tests/models/autoencoders/test_models_autoencoder_ltx_video.py @@ -0,0 +1,169 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch + +from diffusers import AutoencoderKLLTXVideo +from diffusers.utils.testing_utils import ( + enable_full_determinism, + floats_tensor, + torch_device, +) + +from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin + + +enable_full_determinism() + + +class AutoencoderKLLTXVideo090Tests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase): + model_class = AutoencoderKLLTXVideo + main_input_name = "sample" + base_precision = 1e-2 + + def get_autoencoder_kl_ltx_video_config(self): + return { + "in_channels": 3, + "out_channels": 3, + "latent_channels": 8, + "block_out_channels": (8, 8, 8, 8), + "decoder_block_out_channels": (8, 8, 8, 8), + "layers_per_block": (1, 1, 1, 1, 1), + "decoder_layers_per_block": (1, 1, 1, 1, 1), + "spatio_temporal_scaling": (True, True, False, False), + "decoder_spatio_temporal_scaling": (True, True, False, False), + "decoder_inject_noise": (False, False, False, False, False), + "upsample_residual": (False, False, False, False), + "upsample_factor": (1, 1, 1, 1), + "timestep_conditioning": False, + "patch_size": 1, + "patch_size_t": 1, + "encoder_causal": True, + "decoder_causal": False, + } + + @property + def dummy_input(self): + batch_size = 2 + num_frames = 9 + num_channels = 3 + sizes = (16, 16) + + image = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device) + + return {"sample": image} + + @property + def input_shape(self): + return (3, 9, 16, 16) + + @property + def output_shape(self): + return (3, 9, 16, 16) + + def prepare_init_args_and_inputs_for_common(self): + init_dict = self.get_autoencoder_kl_ltx_video_config() + inputs_dict = self.dummy_input + return init_dict, inputs_dict + + def test_gradient_checkpointing_is_applied(self): + expected_set = { + "LTXVideoEncoder3d", + "LTXVideoDecoder3d", + "LTXVideoDownBlock3D", + "LTXVideoMidBlock3d", + "LTXVideoUpBlock3d", + } + super().test_gradient_checkpointing_is_applied(expected_set=expected_set) + + @unittest.skip("Unsupported test.") + def test_outputs_equivalence(self): + pass + + @unittest.skip("AutoencoderKLLTXVideo does not support `norm_num_groups` because it does not use GroupNorm.") + def test_forward_with_norm_groups(self): + pass + + +class AutoencoderKLLTXVideo091Tests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase): + model_class = AutoencoderKLLTXVideo + main_input_name = "sample" + base_precision = 1e-2 + + def get_autoencoder_kl_ltx_video_config(self): + return { + "in_channels": 3, + "out_channels": 3, + "latent_channels": 8, + "block_out_channels": (8, 8, 8, 8), + "decoder_block_out_channels": (16, 32, 64), + "layers_per_block": (1, 1, 1, 1), + "decoder_layers_per_block": (1, 1, 1, 1), + "spatio_temporal_scaling": (True, True, True, False), + "decoder_spatio_temporal_scaling": (True, True, True), + "decoder_inject_noise": (True, True, True, False), + "upsample_residual": (True, True, True), + "upsample_factor": (2, 2, 2), + "timestep_conditioning": True, + "patch_size": 1, + "patch_size_t": 1, + "encoder_causal": True, + "decoder_causal": False, + } + + @property + def dummy_input(self): + batch_size = 2 + num_frames = 9 + num_channels = 3 + sizes = (16, 16) + + image = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device) + timestep = torch.tensor([0.05] * batch_size, device=torch_device) + + return {"sample": image, "temb": timestep} + + @property + def input_shape(self): + return (3, 9, 16, 16) + + @property + def output_shape(self): + return (3, 9, 16, 16) + + def prepare_init_args_and_inputs_for_common(self): + init_dict = self.get_autoencoder_kl_ltx_video_config() + inputs_dict = self.dummy_input + return init_dict, inputs_dict + + def test_gradient_checkpointing_is_applied(self): + expected_set = { + "LTXVideoEncoder3d", + "LTXVideoDecoder3d", + "LTXVideoDownBlock3D", + "LTXVideoMidBlock3d", + "LTXVideoUpBlock3d", + } + super().test_gradient_checkpointing_is_applied(expected_set=expected_set) + + @unittest.skip("Unsupported test.") + def test_outputs_equivalence(self): + pass + + @unittest.skip("AutoencoderKLLTXVideo does not support `norm_num_groups` because it does not use GroupNorm.") + def test_forward_with_norm_groups(self): + pass diff --git a/tests/pipelines/ltx/test_ltx.py b/tests/pipelines/ltx/test_ltx.py index 0f9819bfd6d8..dd166c6242fc 100644 --- a/tests/pipelines/ltx/test_ltx.py +++ b/tests/pipelines/ltx/test_ltx.py @@ -63,10 +63,19 @@ def get_dummy_components(self): torch.manual_seed(0) vae = AutoencoderKLLTXVideo( + in_channels=3, + out_channels=3, latent_channels=8, block_out_channels=(8, 8, 8, 8), - spatio_temporal_scaling=(True, True, False, False), + decoder_block_out_channels=(8, 8, 8, 8), layers_per_block=(1, 1, 1, 1, 1), + decoder_layers_per_block=(1, 1, 1, 1, 1), + spatio_temporal_scaling=(True, True, False, False), + decoder_spatio_temporal_scaling=(True, True, False, False), + decoder_inject_noise=(False, False, False, False, False), + upsample_residual=(False, False, False, False), + upsample_factor=(1, 1, 1, 1), + timestep_conditioning=False, patch_size=1, patch_size_t=1, encoder_causal=True, diff --git a/tests/pipelines/ltx/test_ltx_image2video.py b/tests/pipelines/ltx/test_ltx_image2video.py index 40397e4c3619..1c3e018a8a4b 100644 --- a/tests/pipelines/ltx/test_ltx_image2video.py +++ b/tests/pipelines/ltx/test_ltx_image2video.py @@ -68,10 +68,19 @@ def get_dummy_components(self): torch.manual_seed(0) vae = AutoencoderKLLTXVideo( + in_channels=3, + out_channels=3, latent_channels=8, block_out_channels=(8, 8, 8, 8), - spatio_temporal_scaling=(True, True, False, False), + decoder_block_out_channels=(8, 8, 8, 8), layers_per_block=(1, 1, 1, 1, 1), + decoder_layers_per_block=(1, 1, 1, 1, 1), + spatio_temporal_scaling=(True, True, False, False), + decoder_spatio_temporal_scaling=(True, True, False, False), + decoder_inject_noise=(False, False, False, False, False), + upsample_residual=(False, False, False, False), + upsample_factor=(1, 1, 1, 1), + timestep_conditioning=False, patch_size=1, patch_size_t=1, encoder_causal=True,