
Inference error in Flux with quantization & LoRA applied, and a bug of Quanto with Zero GPU spaces #10381

Open
John6666cat opened this issue Dec 25, 2024 · 1 comment
Labels
bug Something isn't working

Comments

@John6666cat

Describe the bug

Merry Christmas.🎅

Inference with the Flux model with LoRA applied works fine on its own, and inference with a quantized Flux model works fine on its own, but when the two are combined, inference often fails with an error:

RuntimeError('Only Tensors of floating point and complex dtype can require gradients')

I don't know which library or environment is the main cause of this error, but I'm posting it here because the bug is easy to confirm with the combination of Diffusers and Flux.
The demo below uses the pip releases, but the same error occurs with the GitHub versions of Diffusers and PEFT.
The code is written in a roundabout way to work around bugs in Zero GPU Spaces.
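The error message itself can be reproduced with plain PyTorch, independent of Diffusers: enabling gradients on a tensor with a non-floating-point dtype (which is effectively what happens when LoRA injection marks quantized integer weights as trainable) raises the same RuntimeError. A minimal sketch for illustration:

```python
import torch

# Quantized weights are typically stored with integer dtypes (e.g. int8).
w = torch.zeros(4, 4, dtype=torch.int8)

try:
    # LoRA/PEFT adapter injection marks target parameters as trainable;
    # on an integer-dtype tensor this call is what fails.
    w.requires_grad_(True)
except RuntimeError as e:
    print(e)  # "... Tensors of floating point and complex dtype can require gradients"
```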

Demo Space for reproducing the error (on Hugging Face ZeroGPU):

https://huggingface.co/spaces/John6666/diffusers_lora_error_test1

P.S.

The demo that was working a few minutes ago has now stopped working with a new error:

RuntimeError('NVML_SUCCESS == r INTERNAL ASSERT FAILED at "../c10/cuda/CUDACachingAllocator.cpp":838, please report a bug to PyTorch. ')

Extra: Zero GPU Spaces and Quanto are badly incompatible

In this demo, merely adding the following import line makes inference in a Zero GPU Space crash every time. I suspect CUDA is being initialized inside Quanto at import time.🥶

from optimum.quanto import freeze, qfloat8, quantize

  File "/usr/local/lib/python3.10/site-packages/spaces/zero/wrappers.py", line 214, in gradio_handler
    raise res.value
RuntimeError: Cannot re-initialize CUDA in forked subprocess. To use CUDA with multiprocessing, you must use the 'spawn' start method
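This "Cannot re-initialize CUDA in forked subprocess" failure is PyTorch's standard behavior when a process that has already touched CUDA is forked; the ZeroGPU wrapper (`spaces/zero/wrappers.py` in the traceback) runs the `@spaces.GPU` handler in a forked worker, so any library that initializes CUDA at import time breaks every handler call. A minimal sketch of the safe pattern, assuming a CPU host process for illustration:

```python
import torch

# Importing torch alone must not initialize CUDA. A library that calls
# torch.cuda APIs at module import time would flip this to True and then
# break any fork-based worker (like ZeroGPU's) created afterwards.
assert not torch.cuda.is_initialized()

def gpu_handler():
    # Defer every CUDA touch (device queries, .to("cuda"), quantization
    # setup) until inside the worker function, i.e. after the fork.
    if torch.cuda.is_available():
        torch.zeros(1, device="cuda")  # first CUDA initialization happens here
```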

Reproduction

# for Huggingface ZeroGPU Spaces
import spaces
import gradio as gr
import torch
from huggingface_hub import hf_hub_download
from diffusers import FluxPipeline, FluxTransformer2DModel, GGUFQuantizationConfig, BitsAndBytesConfig
import os
import subprocess
#subprocess.run("pip list", shell=True)
#subprocess.run("diffusers-cli env", shell=True)
#from optimum.quanto import freeze, qfloat8, quantize

HF_TOKEN = os.getenv("HF_TOKEN", "")
device = "cuda" if torch.cuda.is_available() else "cpu"
flux_repo = "multimodalart/FLUX.1-dev2pro-full"
ckpt_path = "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q2_K.gguf"
transformer_gguf = FluxTransformer2DModel.from_single_file(ckpt_path, subfolder="transformer", quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
                                                           torch_dtype=torch.bfloat16, config=flux_repo, token=HF_TOKEN)
transformer = FluxTransformer2DModel.from_pretrained(flux_repo, subfolder="transformer", torch_dtype=torch.bfloat16, token=HF_TOKEN)
nf4_quantization_config = BitsAndBytesConfig(load_in_4bit=True)
transformer_nf4 = FluxTransformer2DModel.from_pretrained(flux_repo, subfolder="transformer", quantization_config=nf4_quantization_config,
                                                         torch_dtype=torch.bfloat16, token=HF_TOKEN)
pipe = FluxPipeline.from_pretrained(flux_repo, transformer=transformer, torch_dtype=torch.bfloat16, token=HF_TOKEN)
hyper_sd_lora = hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors")

@spaces.GPU(duration=70)
def infer(prompt: str, mode: str, is_lora: bool, progress=gr.Progress(track_tqdm=True)):
    global pipe
    try:
        pipe.unload_lora_weights()
        if mode == "Default": pipe.transformer = transformer
        elif mode == "GGUF": pipe.transformer = transformer_gguf
        elif mode == "NF4": pipe.transformer = transformer_nf4
        if is_lora:
            pipe.load_lora_weights(hyper_sd_lora, adapter_name="hyper-sd")
            pipe.set_adapters(["hyper-sd"], adapter_weights=[0.125])
            steps = 8
        else: steps = 28
        pipe.to(device)
        image = pipe(prompt, generator=torch.manual_seed(0), num_inference_steps=steps).images[0]
        pipe.to("cpu")
        return image
    except Exception as e:
        raise gr.Error(e)

with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(label="Prompt", value="A cat holding a sign that says hello world", lines=1)
            mode = gr.Radio(label="Mode", choices=["Default", "GGUF", "NF4"], value="Default")
            is_lora = gr.Checkbox(label="Enable LoRA", value=True)
            gen_btn = gr.Button("Generate Image")
        with gr.Column():
            result = gr.Image(label="Result Image")

    gen_btn.click(infer, [prompt, mode, is_lora], [result])

demo.launch()
requirements.txt:

huggingface_hub
torch
diffusers
peft
transformers
accelerate
numpy<2
gguf
bitsandbytes
optimum-quanto

Logs

No response

System Info

  • 🤗 Diffusers version: 0.32.0
  • Platform: Linux-5.10.228-219.884.amzn2.x86_64-x86_64-with-glibc2.36
  • Running on Google Colab?: No
  • Python version: 3.10.13
  • PyTorch version (GPU?): 2.4.0+cu121 (False)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Huggingface_hub version: 0.25.1
  • Transformers version: 4.47.1
  • Accelerate version: 1.2.1
  • PEFT version: 0.14.0
  • Bitsandbytes version: 0.45.0
  • Safetensors version: 0.4.5
  • xFormers version: not installed
  • Accelerator: Failed to initialize NVML: Unknown Error
  • Using GPU in script?:
  • Using distributed or parallel set-up in script?:

Who can help?

No response

@John6666cat John6666cat added the bug Something isn't working label Dec 25, 2024
@Dylanooo

I got the same error:
RuntimeError('Only Tensors of floating point and complex dtype can require gradients')
