Skip to content

Commit

Permalink
[Model offload] Add nice warning (#2543)
Browse files Browse the repository at this point in the history
* [Model offload] Add nice warning

* Treat sequential and model offload differently.

Sequential raises an error because the operation would fail with a
cryptic warning later.

* Forcibly move to cpu when offloading.

* make style

* one more fix

* make fix-copies

* up

---------

Co-authored-by: Pedro Cuenca <[email protected]>
  • Loading branch information
patrickvonplaten and pcuenca authored Mar 3, 2023
1 parent 4f0141a commit 5b6582c
Show file tree
Hide file tree
Showing 15 changed files with 166 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,10 @@ def enable_sequential_cpu_offload(self, gpu_id=0):

device = torch.device(f"cuda:{gpu_id}")

if self.device.type != "cpu":
self.to("cpu", silence_dtype_warnings=True)
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)

for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
cpu_offload(cpu_offloaded_model, device)

Expand All @@ -234,6 +238,10 @@ def enable_model_cpu_offload(self, gpu_id=0):

device = torch.device(f"cuda:{gpu_id}")

if self.device.type != "cpu":
self.to("cpu", silence_dtype_warnings=True)
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)

hook = None
for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
_, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,10 @@ def enable_sequential_cpu_offload(self, gpu_id=0):

device = torch.device(f"cuda:{gpu_id}")

if self.device.type != "cpu":
self.to("cpu", silence_dtype_warnings=True)
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)

for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
cpu_offload(cpu_offloaded_model, device)

Expand All @@ -240,6 +244,10 @@ def enable_model_cpu_offload(self, gpu_id=0):

device = torch.device(f"cuda:{gpu_id}")

if self.device.type != "cpu":
self.to("cpu", silence_dtype_warnings=True)
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)

hook = None
for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
_, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
Expand Down
44 changes: 42 additions & 2 deletions src/diffusers/pipelines/pipeline_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
get_class_from_dynamic_module,
http_user_agent,
is_accelerate_available,
is_accelerate_version,
is_safetensors_available,
is_torch_version,
is_transformers_available,
Expand All @@ -66,6 +67,10 @@
from ..utils import FLAX_WEIGHTS_NAME, ONNX_EXTERNAL_WEIGHTS_NAME, ONNX_WEIGHTS_NAME


if is_accelerate_available():
import accelerate


INDEX_FILE = "diffusion_pytorch_model.bin"
CUSTOM_PIPELINE_FILE_NAME = "pipeline.py"
DUMMY_MODULES_FOLDER = "diffusers.utils"
Expand Down Expand Up @@ -337,15 +342,50 @@ def is_saveable_module(name, value):

save_method(os.path.join(save_directory, pipeline_component_name), **save_kwargs)

def to(self, torch_device: Optional[Union[str, torch.device]] = None):
def to(self, torch_device: Optional[Union[str, torch.device]] = None, silence_dtype_warnings: bool = False):
if torch_device is None:
return self

# throw warning if pipeline is in "offloaded"-mode but user tries to manually set to GPU.
def module_is_sequentially_offloaded(module):
if not is_accelerate_available() or is_accelerate_version("<", "0.14.0"):
return False

return hasattr(module, "_hf_hook") and not isinstance(module._hf_hook, accelerate.hooks.CpuOffload)

def module_is_offloaded(module):
if not is_accelerate_available() or is_accelerate_version("<", "0.17.0.dev0"):
return False

return hasattr(module, "_hf_hook") and isinstance(module._hf_hook, accelerate.hooks.CpuOffload)

# .to("cuda") would raise an error if the pipeline is sequentially offloaded, so we raise our own to make it clearer
pipeline_is_sequentially_offloaded = any(
module_is_sequentially_offloaded(module) for _, module in self.components.items()
)
if pipeline_is_sequentially_offloaded and torch.device(torch_device).type == "cuda":
raise ValueError(
"It seems like you have activated sequential model offloading by calling `enable_sequential_cpu_offload`, but are now attempting to move the pipeline to GPU. This is not compatible with offloading. Please, move your pipeline `.to('cpu')` or consider removing the move altogether if you use sequential offloading."
)

# Display a warning in this case (the operation succeeds but the benefits are lost)
pipeline_is_offloaded = any(module_is_offloaded(module) for _, module in self.components.items())
if pipeline_is_offloaded and torch.device(torch_device).type == "cuda":
logger.warning(
f"It seems like you have activated model offloading by calling `enable_model_cpu_offload`, but are now manually moving the pipeline to GPU. It is strongly recommended against doing so as memory gains from offloading are likely to be lost. Offloading automatically takes care of moving the individual components {', '.join(self.components.keys())} to GPU when needed. To make sure offloading works as expected, you should consider moving the pipeline back to CPU: `pipeline.to('cpu')` or removing the move altogether if you use offloading."
)

module_names, _, _ = self.extract_init_dict(dict(self.config))
is_offloaded = pipeline_is_offloaded or pipeline_is_sequentially_offloaded
for name in module_names.keys():
module = getattr(self, name)
if isinstance(module, torch.nn.Module):
if module.dtype == torch.float16 and str(torch_device) in ["cpu"]:
if (
module.dtype == torch.float16
and str(torch_device) in ["cpu"]
and not silence_dtype_warnings
and not is_offloaded
):
logger.warning(
"Pipelines loaded with `torch_dtype=torch.float16` cannot run with `cpu` device. It"
" is not recommended to move them to `cpu` as running them will fail. Please make"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,10 @@ def enable_sequential_cpu_offload(self, gpu_id=0):

device = torch.device(f"cuda:{gpu_id}")

if self.device.type != "cpu":
self.to("cpu", silence_dtype_warnings=True)
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)

for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
cpu_offload(cpu_offloaded_model, device)

Expand All @@ -258,6 +262,10 @@ def enable_model_cpu_offload(self, gpu_id=0):

device = torch.device(f"cuda:{gpu_id}")

if self.device.type != "cpu":
self.to("cpu", silence_dtype_warnings=True)
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)

hook = None
for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
_, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,10 @@ def enable_sequential_cpu_offload(self, gpu_id=0):

device = torch.device(f"cuda:{gpu_id}")

if self.device.type != "cpu":
self.to("cpu", silence_dtype_warnings=True)
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)

for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
cpu_offload(cpu_offloaded_model, device)

Expand All @@ -237,6 +241,10 @@ def enable_model_cpu_offload(self, gpu_id=0):

device = torch.device(f"cuda:{gpu_id}")

if self.device.type != "cpu":
self.to("cpu", silence_dtype_warnings=True)
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)

hook = None
for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
_, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,10 @@ def enable_sequential_cpu_offload(self, gpu_id=0):

device = torch.device(f"cuda:{gpu_id}")

if self.device.type != "cpu":
self.to("cpu", silence_dtype_warnings=True)
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)

for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
cpu_offload(cpu_offloaded_model, device)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,10 @@ def enable_sequential_cpu_offload(self, gpu_id=0):

device = torch.device(f"cuda:{gpu_id}")

if self.device.type != "cpu":
self.to("cpu", silence_dtype_warnings=True)
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)

for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
cpu_offload(cpu_offloaded_model, device)

Expand All @@ -246,6 +250,10 @@ def enable_model_cpu_offload(self, gpu_id=0):

device = torch.device(f"cuda:{gpu_id}")

if self.device.type != "cpu":
self.to("cpu", silence_dtype_warnings=True)
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)

hook = None
for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
_, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,10 @@ def enable_sequential_cpu_offload(self, gpu_id=0):

device = torch.device(f"cuda:{gpu_id}")

if self.device.type != "cpu":
self.to("cpu", silence_dtype_warnings=True)
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)

for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
cpu_offload(cpu_offloaded_model, device)

Expand All @@ -293,6 +297,10 @@ def enable_model_cpu_offload(self, gpu_id=0):

device = torch.device(f"cuda:{gpu_id}")

if self.device.type != "cpu":
self.to("cpu", silence_dtype_warnings=True)
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)

hook = None
for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
_, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,10 @@ def enable_sequential_cpu_offload(self, gpu_id=0):

device = torch.device(f"cuda:{gpu_id}")

if self.device.type != "cpu":
self.to("cpu", silence_dtype_warnings=True)
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)

for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
cpu_offload(cpu_offloaded_model, device)

Expand All @@ -237,6 +241,10 @@ def enable_model_cpu_offload(self, gpu_id=0):

device = torch.device(f"cuda:{gpu_id}")

if self.device.type != "cpu":
self.to("cpu", silence_dtype_warnings=True)
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)

hook = None
for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
_, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -405,6 +405,10 @@ def enable_sequential_cpu_offload(self, gpu_id=0):

device = torch.device(f"cuda:{gpu_id}")

if self.device.type != "cpu":
self.to("cpu", silence_dtype_warnings=True)
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)

for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
cpu_offload(cpu_offloaded_model, device)

Expand All @@ -426,6 +430,10 @@ def enable_model_cpu_offload(self, gpu_id=0):

device = torch.device(f"cuda:{gpu_id}")

if self.device.type != "cpu":
self.to("cpu", silence_dtype_warnings=True)
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)

hook = None
for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
_, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,10 @@ def enable_sequential_cpu_offload(self, gpu_id=0):

device = torch.device(f"cuda:{gpu_id}")

if self.device.type != "cpu":
self.to("cpu", silence_dtype_warnings=True)
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)

for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
cpu_offload(cpu_offloaded_model, device)

Expand All @@ -158,6 +162,10 @@ def enable_model_cpu_offload(self, gpu_id=0):

device = torch.device(f"cuda:{gpu_id}")

if self.device.type != "cpu":
self.to("cpu", silence_dtype_warnings=True)
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)

hook = None
for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
_, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,10 @@ def enable_sequential_cpu_offload(self, gpu_id=0):

device = torch.device(f"cuda:{gpu_id}")

if self.device.type != "cpu":
self.to("cpu", silence_dtype_warnings=True)
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)

for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
cpu_offload(cpu_offloaded_model, device)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,10 @@ def enable_sequential_cpu_offload(self, gpu_id=0):

device = torch.device(f"cuda:{gpu_id}")

if self.device.type != "cpu":
self.to("cpu", silence_dtype_warnings=True)
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)

for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
cpu_offload(cpu_offloaded_model, device)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,10 @@ def enable_sequential_cpu_offload(self, gpu_id=0):

device = torch.device(f"cuda:{gpu_id}")

if self.device.type != "cpu":
self.to("cpu", silence_dtype_warnings=True)
torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)

for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
cpu_offload(cpu_offloaded_model, device)

Expand Down
36 changes: 36 additions & 0 deletions tests/test_pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -584,6 +584,42 @@ def test_stable_diffusion_components(self):
assert image_img2img.shape == (1, 32, 32, 3)
assert image_text2img.shape == (1, 64, 64, 3)

@require_torch_gpu
def test_pipe_false_offload_warn(self):
unet = self.dummy_cond_unet()
scheduler = PNDMScheduler(skip_prk_steps=True)
vae = self.dummy_vae
bert = self.dummy_text_encoder
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

sd = StableDiffusionPipeline(
unet=unet,
scheduler=scheduler,
vae=vae,
text_encoder=bert,
tokenizer=tokenizer,
safety_checker=None,
feature_extractor=self.dummy_extractor,
)

sd.enable_model_cpu_offload()

logger = logging.get_logger("diffusers.pipelines.pipeline_utils")
with CaptureLogger(logger) as cap_logger:
sd.to("cuda")

assert "It is strongly recommended against doing so" in str(cap_logger)

sd = StableDiffusionPipeline(
unet=unet,
scheduler=scheduler,
vae=vae,
text_encoder=bert,
tokenizer=tokenizer,
safety_checker=None,
feature_extractor=self.dummy_extractor,
)

def test_set_scheduler(self):
unet = self.dummy_cond_unet()
scheduler = PNDMScheduler(skip_prk_steps=True)
Expand Down

0 comments on commit 5b6582c

Please sign in to comment.