Lower Hunyuan Video LoRA memory requirements #135

Open
a-r-r-o-w opened this issue Dec 23, 2024 · 15 comments
Labels
enhancement New feature or request

Comments

@a-r-r-o-w
Owner

It should be possible to leverage FP8-cast models, or torchao quantization, to support training in under 24 GB up to a reasonable resolution. Or at least that's the hope when combined with precomputation from #129. Will take a look soon 🤗

TorchAO docs: https://huggingface.co/docs/diffusers/main/en/quantization/torchao
FP8 casting: huggingface/diffusers#10347
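
For illustration, a minimal sketch of what loading the transformer with TorchAO int8 weight-only quantization could look like (following the TorchAO docs above; the quantization type and loading arguments here are assumptions, not a tested recipe):

```python
import torch
from diffusers import HunyuanVideoTransformer3DModel, TorchAoConfig

# int8 weight-only quantization via torchao (requires `pip install torchao`).
quant_config = TorchAoConfig("int8wo")

transformer = HunyuanVideoTransformer3DModel.from_pretrained(
    "hunyuanvideo-community/HunyuanVideo",
    subfolder="transformer",
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
)
```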

a-r-r-o-w added the enhancement (New feature or request) label Dec 23, 2024
@Lumoria

Lumoria commented Dec 23, 2024

What are the memory requirements for Hunyuan currently? I'm OOMing with 48 GB.

@a-r-r-o-w
Owner Author

Could you give #129 a try? I believe with FP8 it should fit in 24 GB based on rough calculations, but I will continue to try and improve things.

@Lumoria

Lumoria commented Dec 23, 2024

Sadly, I still OOM even after precomputing the conditions and latents.

@a-r-r-o-w
Owner Author

Just to confirm, are you using the bash script from the README or a custom launch script? And are you sure --gradient_checkpointing is being used? If it still isn't working after that, I'll take a look tomorrow and try to get FP8 support in ASAP.

@Lumoria

Lumoria commented Dec 23, 2024

Yeah, using the bash script from the README. --gradient_checkpointing and --precompute_conditions are both being passed.

@Aristo23333

Aristo23333 commented Dec 24, 2024

Hi, I also tried the bash script in your README.md and loaded the checkpoint you provide at https://huggingface.co/hunyuanvideo-community/HunyuanVideo. But I get an OOM even on an 80 GiB H800 when loading the HunyuanVideo transformer, before training starts. My training setup is 1/2 H800.

@generalsvr

I have the same OOM problem with --precompute_conditions and --gradient_checkpointing from the README script on an A100.

@a-r-r-o-w
Owner Author

a-r-r-o-w commented Dec 24, 2024

I'm unable to replicate this, unfortunately. I just verified once again that I can run training in about 42 GB of memory when precomputation and gradient checkpointing are enabled with 49x512x768 videos. I would like to know what version of PyTorch everyone is using. Can you share the output of diffusers-cli env? Here's mine:

- 🤗 Diffusers version: 0.33.0.dev0
- Platform: Linux-5.4.0-166-generic-x86_64-with-glibc2.31
- Running on Google Colab?: No
- Python version: 3.10.14
- PyTorch version (GPU?): 2.5.1+cu124 (True)
- Flax version (CPU?/GPU?/TPU?): 0.8.5 (cpu)
- Jax version: 0.4.31
- JaxLib version: 0.4.31
- Huggingface_hub version: 0.26.2
- Transformers version: 4.48.0.dev0
- Accelerate version: 1.1.0.dev0
- PEFT version: 0.13.3.dev0
- Bitsandbytes version: 0.43.3
- Safetensors version: 0.4.5
- xFormers version: not installed
- Accelerator: NVIDIA A100-SXM4-80GB, 81920 MiB
NVIDIA A100-SXM4-80GB, 81920 MiB
NVIDIA A100-SXM4-80GB, 81920 MiB
NVIDIA DGX Display, 4096 MiB
NVIDIA A100-SXM4-80GB, 81920 MiB
- Using GPU in script?: <fill in>
- Using distributed or parallel set-up in script?: <fill in>

Can you also try running training with resolution buckets set as 1x512x512 and see if that OOMs as well? If it OOMs, it's likely a problem with the PyTorch version and needs an upgrade. If it doesn't, I will try running the code on a few different environments to verify. So far, I know that it works on A6000, A100, RTX 4090 (but with custom FP8 code), and H100 for sure from at least two other folks, but the above info will be very helpful to localize the error. DeepSpeed support by @sayakpaul should be in soon too, so hopefully that helps further reduce some memory requirements.
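
In the meantime, a quick environment and memory sanity check like this (illustrative only, not part of the training script) can help localize the OOM:

```python
import torch

# Print torch version, CUDA build, and per-device memory to compare environments.
print("torch:", torch.__version__, "| cuda build:", torch.version.cuda)
for i in range(torch.cuda.device_count()):
    props = torch.cuda.get_device_properties(i)
    print(f"cuda:{i} -> {props.name}, {props.total_memory / 1024**3:.1f} GiB")

# Peak memory during/after a training step can be checked with:
print("peak allocated (GiB):", torch.cuda.max_memory_allocated() / 1024**3)
```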

@Lumoria

Lumoria commented Dec 24, 2024

Setting the bucket size to 1x512x512 still OOMs.

  • 🤗 Diffusers version: 0.32.0.dev0
  • Platform: Linux-6.11.2-amd64-x86_64-with-glibc2.40
  • Running on Google Colab?: No
  • Python version: 3.11.11
  • PyTorch version (GPU?): 2.4.1+cu121 (True)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Huggingface_hub version: 0.27.0
  • Transformers version: 4.47.1
  • Accelerate version: 1.2.1
  • PEFT version: 0.14.0
  • Bitsandbytes version: 0.45.0
  • Safetensors version: 0.4.5
  • xFormers version: not installed
  • Accelerator: NVIDIA RTX 6000 Ada Generation, 49140 MiB
    NVIDIA GeForce RTX 3080 Ti, 12288 MiB
  • Using GPU in script?:
  • Using distributed or parallel set-up in script?:

@a-r-r-o-w
Owner Author

I see. I'll give PyTorch 2.4 a try and profile it tomorrow. Could you try upgrading to PyTorch 2.5.1 and see if the issue goes away, or try the nightly 2.6.0?

@a-r-r-o-w
Owner Author

a-r-r-o-w commented Dec 24, 2024

Also, on 2.4, could you first check whether inference results in a normal video or a black video with the example code here: https://huggingface.co/docs/diffusers/en/api/pipelines/hunyuan_video

There have been reports of it not working, and I suspect it has something to do with the torch version. If inference is not working, there's a slim chance training would work well.

The example doesn't mention it, but if you're facing OOM during inference, pipe.enable_model_cpu_offload() should do the trick.
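
For convenience, here is a sketch along the lines of that docs page, with CPU offloading and VAE tiling enabled (the prompt, resolution, and frame count below are placeholders, not recommendations):

```python
import torch
from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel
from diffusers.utils import export_to_video

model_id = "hunyuanvideo-community/HunyuanVideo"
transformer = HunyuanVideoTransformer3DModel.from_pretrained(
    model_id, subfolder="transformer", torch_dtype=torch.bfloat16
)
pipe = HunyuanVideoPipeline.from_pretrained(
    model_id, transformer=transformer, torch_dtype=torch.float16
)

# Offload submodules to CPU when idle and decode latents in tiles to avoid OOM.
pipe.enable_model_cpu_offload()
pipe.vae.enable_tiling()

video = pipe(
    prompt="A cat walks on the grass, realistic style.",
    height=320,
    width=512,
    num_frames=61,
    num_inference_steps=30,
).frames[0]
export_to_video(video, "output.mp4", fps=15)
```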

@generalsvr

generalsvr commented Dec 24, 2024

> I'm unable to replicate this, unfortunately. I just verified once again that I can run training in about 42 GB of memory when precomputation and gradient checkpointing are enabled with 49x512x768 videos. [...]

I was able to run training with the accelerate_configs/uncompiled_1.yaml config, but during training the loss was NaN. The output of this LoRA model after training is a black screen. Can you explain the difference between these configs, please? PyTorch 2.4.0.

@a-r-r-o-w
Owner Author

@generalsvr It seems like 2.4 might be a problematic torch version for some operations. I'm going through the relevant commits in PyTorch to try and see what exactly causes this, but I believe upgrading to 2.5.1 will fix the NaN loss. Could you give that a try?

@a-r-r-o-w
Owner Author

The configs are simply some rules that tell accelerate, the Hugging Face library we use for distributed data parallelism, what kind of environment and settings we want to train with. So uncompiled_1.yaml specifies that we want to use 1 GPU without torch.compile; uncompiled_8.yaml, similarly, is for when we want to use 8 GPUs in parallel. The full list of configurations can be explored via the accelerate config command, and you can read about it in the docs.
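
As a rough Python analogue of what such a config encodes (illustrative; the actual fields in uncompiled_1.yaml may differ):

```python
from accelerate import Accelerator

# A single-process setup with bf16 mixed precision; the YAML config passed to
# `accelerate launch` is essentially a serialized version of settings like
# these, plus the process/GPU topology and optional torch.compile/dynamo options.
accelerator = Accelerator(mixed_precision="bf16")
print(accelerator.device, accelerator.num_processes, accelerator.mixed_precision)
```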

@generalsvr

> It seems like 2.4 might be a problematic torch version for some operations. I'm going through the relevant commits in PyTorch to try and see what exactly causes this, but I believe upgrading to 2.5.1 will fix the NaN loss. Could you give that a try?

After updating PyTorch to 2.5.1 on the same machine, the original Hunyuan model started to generate videos. But training is still a problem: I can see a loss value appear once on step 1, but then it is NaN again. Video generation with the LoRA still results in a black screen.

Training run 1 log:

Training steps: 0%| | 0/20 [00:00<?, ?it/s]12/24/2024 16:47:37 - DEBUG - finetrainers - Starting epoch (1/1)
12/24/2024 16:47:37 - DEBUG - finetrainers - Starting step 1
Training steps: 5%|███▏ | 1/20 [02:06<40:10, 126.89s/it, loss=0.462, lr=2e-7]12/24/2024 16:49:44 - DEBUG - accelerate.tracking - Successfully logged to WandB
12/24/2024 16:49:44 - DEBUG - finetrainers - Starting step 2
Training steps: 10%|██████▌ | 2/20 [04:11<36:54, 123.04s/it, loss=nan, lr=4e-7]12/24/2024 16:51:49 - DEBUG - accelerate.tracking - Successfully logged to WandB
12/24/2024 16:51:49 - DEBUG - finetrainers - Starting step 3
Training steps: 15%|█████████▉ | 3/20 [06:17<35:07, 123.97s/it, loss=nan, lr=6e-7]12/24/2024 16:53:54 - DEBUG - accelerate.tracking - Successfully logged to WandB
12/24/2024 16:53:54 - DEBUG - finetrainers - Starting step 4
Training steps: 20%|█████████████▏ | 4/20 [08:22<33:10, 124.40s/it, loss=nan, lr=8e-7]12/24/2024 16:55:59 - DEBUG - accelerate.tracking - Successfully logged to WandB
12/24/2024 16:55:59 - DEBUG - finetrainers - Starting step 5
Training steps: 25%|████████████████▌ | 5/20 [10:22<31:09, 124.64s/it, loss=nan, lr=8e-7]12/24/2024 16:57:59 - INFO - finetrainers - Checkpointing at step 5

Training run 2 log:

Training steps: 0%| | 0/20 [00:00<?, ?it/s]12/24/2024 17:10:49 - DEBUG - finetrainers - Starting epoch (1/1)
12/24/2024 17:10:50 - DEBUG - finetrainers - Starting step 1
Training steps: 5%|███▏ | 1/20 [02:06<40:11, 126.92s/it, loss=0.462, lr=2e-7]12/24/2024 17:12:56 - DEBUG - accelerate.tracking - Successfully logged to WandB
12/24/2024 17:12:56 - DEBUG - finetrainers - Starting step 2
Training steps: 10%|██████▌ | 2/20 [04:11<36:54, 123.05s/it, loss=nan, lr=4e-7]12/24/2024 17:15:01 - DEBUG - accelerate.tracking - Successfully logged to WandB
12/24/2024 17:15:01 - DEBUG - finetrainers - Starting step 3
Training steps: 15%|█████████▉ | 3/20 [06:17<35:07, 123.97s/it, loss=nan, lr=6e-7]12/24/2024 17:17:06 - DEBUG - accelerate.tracking - Successfully logged to WandB
12/24/2024 17:17:06 - DEBUG - finetrainers - Starting step 4
Training steps: 20%|█████████████▏ | 4/20 [08:22<33:10, 124.41s/it, loss=nan, lr=8e-7]12/24/2024 17:19:12 - DEBUG - accelerate.tracking - Successfully logged to WandB
12/24/2024 17:19:12 - DEBUG - finetrainers - Starting step 5
