Here is the code I used for FLUX.1 Fill fp8 inference with optimum-quanto:
import torch
from diffusers import FluxFillPipeline, FluxTransformer2DModel
from optimum.quanto import freeze, qfloat8, quantize
from transformers import T5EncoderModel

model_name = "black-forest-labs/FLUX.1-Fill-dev"  # the FLUX.1 Fill checkpoint (not specified in the original post)

# Quantize the transformer weights to fp8 and freeze them
transformer = FluxTransformer2DModel.from_pretrained(model_name, subfolder="transformer", torch_dtype=torch.bfloat16)
quantize(transformer, weights=qfloat8)
freeze(transformer)

# Quantize the T5 text encoder the same way
text_encoder_2 = T5EncoderModel.from_pretrained(model_name, subfolder="text_encoder_2", torch_dtype=torch.bfloat16)
quantize(text_encoder_2, weights=qfloat8)
freeze(text_encoder_2)

# Load the pipeline without these two components, then plug in the quantized versions
pipe = FluxFillPipeline.from_pretrained(model_name, transformer=None, text_encoder_2=None, torch_dtype=torch.bfloat16)
pipe.transformer = transformer
pipe.text_encoder_2 = text_encoder_2
pipe.enable_model_cpu_offload()

# image and mask are PIL images prepared beforehand (e.g. via diffusers.utils.load_image)
image = pipe(
    prompt="A yellow umbrella",
    image=image,
    mask_image=mask,
    height=1024,
    width=1024,
    guidance_scale=30,
    num_inference_steps=30,
    max_sequence_length=512,
    generator=torch.Generator("cpu").manual_seed(0),
).images[0]
I used the above code to run FLUX Fill's fp8 inference. GPU memory usage went down to 15 GB, but it took 4 minutes to generate one image! Although the original model uses much more GPU memory, its inference takes only a bit over ten seconds.
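To narrow down where the time goes, here is a minimal timing sketch, assuming the pipe, image, and mask from the code above and a CUDA device. timed_generate is a hypothetical helper, not part of diffusers; running the call twice separates one-time setup costs from the steady per-image cost.

import time

import torch

def timed_generate(pipe, **kwargs):
    # Hypothetical helper: synchronize the GPU before and after the
    # pipeline call so the wall-clock time covers all queued CUDA work,
    # not just the Python-side dispatch.
    torch.cuda.synchronize()
    start = time.perf_counter()
    result = pipe(**kwargs)
    torch.cuda.synchronize()
    print(f"pipeline call took {time.perf_counter() - start:.1f} s")
    return result

# Run twice: the first call pays one-time costs (CUDA init, lazy kernel
# setup), while the module transfers from enable_model_cpu_offload recur
# on every call, so the second number is the steady per-image cost.
for _ in range(2):
    out = timed_generate(
        pipe,
        prompt="A yellow umbrella",
        image=image,
        mask_image=mask,
        height=1024,
        width=1024,
        guidance_scale=30,
        num_inference_steps=30,
        max_sequence_length=512,
        generator=torch.Generator("cpu").manual_seed(0),
    )
image = out.images[0]

If both runs take roughly 4 minutes, the overhead is per-image (possibly the CPU offload transfers or quanto's runtime dequantization) rather than one-time setup.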