
high memory consumption of VAE decoder in SD2.1 #10350

Open
yfpeng1234 opened this issue Dec 23, 2024 · 1 comment
Labels
bug Something isn't working

Comments

@yfpeng1234

Describe the bug

When I added the VAE decoder from SD2.1 to my training pipeline, I ran into an OOM error. After careful inspection, I found that the decoder takes a vast amount of memory: with an input of shape [1, 4, 96, 96], memory consumption is already around 15 GB, and it grows even larger as the batch size increases.

Reproduction

```python
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained(
    "stabilityai/stable-diffusion-2-1", subfolder="vae"
)
# vae.encoder.requires_grad_(False)
vae.to("cuda", dtype=torch.float16)

# Latent shape: batch=4, channels=4, h=96, w=96
x0 = torch.randn((4, 4, 96, 96), device="cuda", dtype=torch.float16)

while True:
    # latents = vae.encode(x0).latent_dist.sample()
    # print(latents.shape)
    x1 = vae.decode(x0 / vae.config.scaling_factor, return_dict=False)[0]
```
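For reference, the peak allocation of a single decode can be checked with a snippet like the one below (a minimal sketch, assuming the same model and a batch-1 latent; `torch.cuda.max_memory_allocated` reports the peak in bytes since the last reset):

```python
import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained(
    "stabilityai/stable-diffusion-2-1", subfolder="vae"
).to("cuda", dtype=torch.float16)

x0 = torch.randn((1, 4, 96, 96), device="cuda", dtype=torch.float16)

# Measure peak GPU memory used by one decode call.
torch.cuda.reset_peak_memory_stats()
out = vae.decode(x0 / vae.config.scaling_factor, return_dict=False)[0]
print(f"peak allocated: {torch.cuda.max_memory_allocated() / 1024**3:.1f} GiB")
```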

Logs

No response

System Info

  • 🤗 Diffusers version: 0.32.0.dev0
  • Platform: Linux-5.15.0-124-generic-x86_64-with-glibc2.35
  • Running on Google Colab?: No
  • Python version: 3.10.15
  • PyTorch version (GPU?): 2.5.1+cu124 (True)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Huggingface_hub version: 0.26.2
  • Transformers version: 4.46.1
  • Accelerate version: 1.1.0
  • PEFT version: 0.7.0
  • Bitsandbytes version: 0.44.1
  • Safetensors version: 0.4.5
  • xFormers version: 0.0.28.post3
  • Accelerator: NVIDIA RTX 5880 Ada Generation, 49140 MiB
    NVIDIA RTX 5880 Ada Generation, 49140 MiB
    NVIDIA RTX 5880 Ada Generation, 49140 MiB
    NVIDIA RTX 5880 Ada Generation, 49140 MiB
    NVIDIA RTX 5880 Ada Generation, 49140 MiB
    NVIDIA RTX 5880 Ada Generation, 49140 MiB
    NVIDIA RTX 5880 Ada Generation, 49140 MiB
    NVIDIA RTX 5880 Ada Generation, 49140 MiB

Who can help?

@yiyixuxu @DN6

@yfpeng1234 added the bug label Dec 23, 2024
@hlky
Collaborator

hlky commented Dec 23, 2024

Enabling tiling or slicing will reduce memory consumption. Decoding inside a no_grad context will also reduce memory usage.

`def enable_tiling(self, use_tiling: bool = True):`

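A minimal sketch of how those suggestions could be applied to the reproduction above (assuming inference-only decoding; `enable_tiling`, `enable_slicing`, and `torch.no_grad` are the only changes relative to the original script):

```python
import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained(
    "stabilityai/stable-diffusion-2-1", subfolder="vae"
).to("cuda", dtype=torch.float16)

# Decode in tiles and one sample at a time instead of the whole batch at once.
vae.enable_tiling()
vae.enable_slicing()

x0 = torch.randn((4, 4, 96, 96), device="cuda", dtype=torch.float16)

# no_grad avoids storing activations for backprop, which dominates decoder memory.
with torch.no_grad():
    x1 = vae.decode(x0 / vae.config.scaling_factor, return_dict=False)[0]
```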