From c441048a4f885ba8f3bc0ce959b5df034d63368a Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 19 Dec 2024 05:31:39 -0500 Subject: [PATCH] Make VAE Encode tiled node work with video VAE. --- comfy/sd.py | 56 ++++++++++++++++++++++++++++++++++++++++++-------- comfy/utils.py | 22 +++++++++++++++----- nodes.py | 9 ++++---- 3 files changed, 70 insertions(+), 17 deletions(-) diff --git a/comfy/sd.py b/comfy/sd.py index 3b29eecb482..85393ef0d47 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -336,6 +336,7 @@ def __init__(self, sd=None, device=None, config=None, dtype=None): self.memory_used_decode = lambda shape, dtype: (1000 * shape[2] * shape[3] * shape[4] * (6 * 8 * 8)) * model_management.dtype_size(dtype) self.memory_used_encode = lambda shape, dtype: (1.5 * max(shape[2], 7) * shape[3] * shape[4] * (6 * 8 * 8)) * model_management.dtype_size(dtype) self.upscale_ratio = (lambda a: max(0, a * 6 - 5), 8, 8) + self.downscale_ratio = (lambda a: max(0, (a + 3) / 6), 8, 8) self.working_dtypes = [torch.float16, torch.float32] elif "decoder.up_blocks.0.res_blocks.0.conv1.conv.weight" in sd: #lightricks ltxv self.first_stage_model = comfy.ldm.lightricks.vae.causal_video_autoencoder.VideoVAE() @@ -344,12 +345,14 @@ def __init__(self, sd=None, device=None, config=None, dtype=None): self.memory_used_decode = lambda shape, dtype: (900 * shape[2] * shape[3] * shape[4] * (8 * 8 * 8)) * model_management.dtype_size(dtype) self.memory_used_encode = lambda shape, dtype: (70 * max(shape[2], 7) * shape[3] * shape[4]) * model_management.dtype_size(dtype) self.upscale_ratio = (lambda a: max(0, a * 8 - 7), 32, 32) + self.downscale_ratio = (lambda a: max(0, (a + 4) / 8), 32, 32) self.working_dtypes = [torch.bfloat16, torch.float32] elif "decoder.conv_in.conv.weight" in sd: ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0} ddconfig["conv3d"] = True ddconfig["time_compress"] = 4 self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8) + self.downscale_ratio = (lambda a: max(0, (a + 2) / 4), 8, 8) self.latent_dim = 3 self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.conv.weight"].shape[1] self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd['post_quant_conv.weight'].shape[1]) @@ -385,10 +388,12 @@ def __init__(self, sd=None, device=None, config=None, dtype=None): logging.debug("VAE load device: {}, offload device: {}, dtype: {}".format(self.device, offload_device, self.vae_dtype)) def vae_encode_crop_pixels(self, pixels): + downscale_ratio = self.spacial_compression_encode() + dims = pixels.shape[1:-1] for d in range(len(dims)): - x = (dims[d] // self.downscale_ratio) * self.downscale_ratio - x_offset = (dims[d] % self.downscale_ratio) // 2 + x = (dims[d] // downscale_ratio) * downscale_ratio + x_offset = (dims[d] % downscale_ratio) // 2 if x != dims[d]: pixels = pixels.narrow(d + 1, x_offset, x) return pixels @@ -409,7 +414,7 @@ def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16): def decode_tiled_1d(self, samples, tile_x=128, overlap=32): decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float() - return comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, output_device=self.output_device) + return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, output_device=self.output_device)) def decode_tiled_3d(self, samples, tile_t=999, tile_x=32, tile_y=32, overlap=(1, 8, 8)): decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float() @@ -432,6 +437,10 @@ def encode_tiled_1d(self, samples, tile_x=128 * 2048, overlap=32 * 2048): encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float() return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=(1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device) + def encode_tiled_3d(self, samples, tile_t=9999, tile_x=512, tile_y=512, overlap=(1, 64, 64)): + encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float() + return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.downscale_ratio, out_channels=self.latent_channels, downscale=True, output_device=self.output_device) + def decode(self, samples_in): pixel_samples = None try: @@ -504,18 +513,43 @@ def encode(self, pixel_samples): except model_management.OOM_EXCEPTION: logging.warning("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.") - if len(pixel_samples.shape) == 3: + if self.latent_dim == 3: + tile = 256 + overlap = tile // 4 + samples = self.encode_tiled_3d(pixel_samples, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap)) + elif self.latent_dim == 1: samples = self.encode_tiled_1d(pixel_samples) else: samples = self.encode_tiled_(pixel_samples) return samples - def encode_tiled(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64): + def encode_tiled(self, pixel_samples, tile_x=None, tile_y=None, overlap=None): pixel_samples = self.vae_encode_crop_pixels(pixel_samples) - model_management.load_model_gpu(self.patcher) - pixel_samples = pixel_samples.movedim(-1,1) - samples = self.encode_tiled_(pixel_samples, tile_x=tile_x, tile_y=tile_y, overlap=overlap) + dims = self.latent_dim + pixel_samples = pixel_samples.movedim(-1, 1) + if dims == 3: + pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0) + + memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype) # TODO: calculate mem required for tile + model_management.load_models_gpu([self.patcher], memory_required=memory_used) + + args = {} + if tile_x is not None: + args["tile_x"] = tile_x + if tile_y is not None: + args["tile_y"] = tile_y + if overlap is not None: + args["overlap"] = overlap + + if dims == 1: + args.pop("tile_y") + samples = self.encode_tiled_1d(pixel_samples, **args) + elif dims == 2: + samples = self.encode_tiled_(pixel_samples, **args) + elif dims == 3: + samples = self.encode_tiled_3d(pixel_samples, **args) + return samples def get_sd(self): @@ -527,6 +561,12 @@ def spacial_compression_decode(self): except: return self.upscale_ratio + def spacial_compression_encode(self): + try: + return self.downscale_ratio[-1] + except: + return self.downscale_ratio + class StyleModel: def __init__(self, model, device="cpu"): self.model = model diff --git a/comfy/utils.py b/comfy/utils.py index d03693bba8d..ab1e3cd5a47 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -751,7 +751,7 @@ def get_tiled_scale_steps(width, height, tile_x, tile_y, overlap): return rows * cols @torch.inference_mode() -def tiled_scale_multidim(samples, function, tile=(64, 64), overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None): +def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_amount=4, out_channels=3, output_device="cpu", downscale=False, pbar=None): dims = len(tile) if not (isinstance(upscale_amount, (tuple, list))): @@ -767,10 +767,22 @@ def get_upscale(dim, val): else: return up * val + def get_downscale(dim, val): + up = upscale_amount[dim] + if callable(up): + return up(val) + else: + return val / up + + if downscale: + get_scale = get_downscale + else: + get_scale = get_upscale + def mult_list_upscale(a): out = [] for i in range(len(a)): - out.append(round(get_upscale(i, a[i]))) + out.append(round(get_scale(i, a[i]))) return out output = torch.empty([samples.shape[0], out_channels] + mult_list_upscale(samples.shape[2:]), device=output_device) @@ -798,13 +810,13 @@ def mult_list_upscale(a): pos = max(0, min(s.shape[d + 2] - (overlap[d] + 1), it[d])) l = min(tile[d], s.shape[d + 2] - pos) s_in = s_in.narrow(d + 2, pos, l) - upscaled.append(round(get_upscale(d, pos))) + upscaled.append(round(get_scale(d, pos))) ps = function(s_in).to(output_device) mask = torch.ones_like(ps) for d in range(2, dims + 2): - feather = round(get_upscale(d - 2, overlap[d - 2])) + feather = round(get_scale(d - 2, overlap[d - 2])) if feather >= mask.shape[d]: continue for t in range(feather): @@ -828,7 +840,7 @@ def mult_list_upscale(a): return output def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None): - return tiled_scale_multidim(samples, function, (tile_y, tile_x), overlap, upscale_amount, out_channels, output_device, pbar) + return tiled_scale_multidim(samples, function, (tile_y, tile_x), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=output_device, pbar=pbar) PROGRESS_BAR_ENABLED = True def set_progress_bar_enabled(enabled): diff --git a/nodes.py b/nodes.py index 6187e228d2b..1a90073e9da 100644 --- a/nodes.py +++ b/nodes.py @@ -291,7 +291,7 @@ class VAEDecodeTiled: @classmethod def INPUT_TYPES(s): return {"required": {"samples": ("LATENT", ), "vae": ("VAE", ), - "tile_size": ("INT", {"default": 512, "min": 128, "max": 4096, "step": 32}), + "tile_size": ("INT", {"default": 512, "min": 64, "max": 4096, "step": 32}), "overlap": ("INT", {"default": 64, "min": 0, "max": 4096, "step": 32}), }} RETURN_TYPES = ("IMAGE",) @@ -325,15 +325,16 @@ class VAEEncodeTiled: @classmethod def INPUT_TYPES(s): return {"required": {"pixels": ("IMAGE", ), "vae": ("VAE", ), - "tile_size": ("INT", {"default": 512, "min": 320, "max": 4096, "step": 64}) + "tile_size": ("INT", {"default": 512, "min": 64, "max": 4096, "step": 64}), + "overlap": ("INT", {"default": 64, "min": 0, "max": 4096, "step": 32}), }} RETURN_TYPES = ("LATENT",) FUNCTION = "encode" CATEGORY = "_for_testing" - def encode(self, vae, pixels, tile_size): - t = vae.encode_tiled(pixels[:,:,:,:3], tile_x=tile_size, tile_y=tile_size, ) + def encode(self, vae, pixels, tile_size, overlap): + t = vae.encode_tiled(pixels[:,:,:,:3], tile_x=tile_size, tile_y=tile_size, overlap=overlap) return ({"samples":t}, ) class VAEEncodeForInpaint: