diff --git a/lora_diffusion/cli_lora_pti.py b/lora_diffusion/cli_lora_pti.py
index 83703d0..7de4bae 100644
--- a/lora_diffusion/cli_lora_pti.py
+++ b/lora_diffusion/cli_lora_pti.py
@@ -128,17 +128,31 @@ def get_models(
     )
 
 
-def text2img_dataloader(train_dataset, train_batch_size, tokenizer, vae, text_encoder):
+@torch.no_grad()
+def text2img_dataloader(
+    train_dataset,
+    train_batch_size,
+    tokenizer,
+    vae,
+    text_encoder,
+    cached_latents: bool = False,
+):
+
+    if cached_latents:
+        cached_latents_dataset = []
+        for idx in tqdm(range(len(train_dataset))):
+            batch = train_dataset[idx]
+            latents = vae.encode(
+                batch["instance_images"].unsqueeze(0).to(dtype=vae.dtype).to(vae.device)
+            ).latent_dist.sample()
+            latents = latents * 0.18215
+            batch["instance_images"] = latents.squeeze(0)
+            cached_latents_dataset.append(batch)
+
     def collate_fn(examples):
         input_ids = [example["instance_prompt_ids"] for example in examples]
         pixel_values = [example["instance_images"] for example in examples]
-
-        # Concat class and instance examples for prior preservation.
-        # We do this to avoid doing two forward passes.
-        if examples[0].get("class_prompt_ids", None) is not None:
-            input_ids += [example["class_prompt_ids"] for example in examples]
-            pixel_values += [example["class_images"] for example in examples]
-
         pixel_values = torch.stack(pixel_values)
         pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
@@ -159,21 +173,38 @@ def collate_fn(examples):
         return batch
 
-    train_dataloader = torch.utils.data.DataLoader(
-        train_dataset,
-        batch_size=train_batch_size,
-        shuffle=True,
-        collate_fn=collate_fn,
-    )
+    if cached_latents:
+
+        train_dataloader = torch.utils.data.DataLoader(
+            cached_latents_dataset,
+            batch_size=train_batch_size,
+            shuffle=True,
+            collate_fn=collate_fn,
+        )
+
+        print("PTI : Using cached latent.")
+
+    else:
+        train_dataloader = torch.utils.data.DataLoader(
+            train_dataset,
+            batch_size=train_batch_size,
+            shuffle=True,
+            collate_fn=collate_fn,
+        )
 
     return train_dataloader
 
 
-def inpainting_dataloader(train_dataset, train_batch_size, tokenizer, vae, text_encoder):
+
+def inpainting_dataloader(
+    train_dataset, train_batch_size, tokenizer, vae, text_encoder
+):
     def collate_fn(examples):
         input_ids = [example["instance_prompt_ids"] for example in examples]
         pixel_values = [example["instance_images"] for example in examples]
         mask_values = [example["instance_masks"] for example in examples]
-        masked_image_values = [example["instance_masked_images"] for example in examples]
+        masked_image_values = [
+            example["instance_masked_images"] for example in examples
+        ]
 
         # Concat class and instance examples for prior preservation.
         # We do this to avoid doing two forward passes.
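A note on the caching path above: each training image goes through the VAE encoder exactly once, so every later optimization step skips that forward pass, and the VAE can be dropped from memory afterwards (see the `vae = None` hunk further down). The trade-off is that the latent is sampled a single time, so any color-jitter/flip augmentation is frozen into the cache. A minimal standalone sketch of the same idea, assuming a diffusers-style `vae` exposing `.encode(...).latent_dist` (names here are illustrative, not part of the patch):

```python
import torch

@torch.no_grad()
def cache_latents(vae, dataset, scale: float = 0.18215):
    # Encode each example once; reuse the scaled latents on every epoch.
    cached = []
    for idx in range(len(dataset)):
        example = dataset[idx]
        image = example["instance_images"].unsqueeze(0).to(vae.device, dtype=vae.dtype)
        latents = vae.encode(image).latent_dist.sample() * scale
        example["instance_images"] = latents.squeeze(0)  # latents replace pixels
        cached.append(example)
    return cached
```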
@@ -181,11 +212,21 @@ def collate_fn(examples):
             input_ids += [example["class_prompt_ids"] for example in examples]
             pixel_values += [example["class_images"] for example in examples]
             mask_values += [example["class_masks"] for example in examples]
-            masked_image_values += [example["class_masked_images"] for example in examples]
+            masked_image_values += [
+                example["class_masked_images"] for example in examples
+            ]
 
-        pixel_values = torch.stack(pixel_values).to(memory_format=torch.contiguous_format).float()
-        mask_values = torch.stack(mask_values).to(memory_format=torch.contiguous_format).float()
-        masked_image_values = torch.stack(masked_image_values).to(memory_format=torch.contiguous_format).float()
+        pixel_values = (
+            torch.stack(pixel_values).to(memory_format=torch.contiguous_format).float()
+        )
+        mask_values = (
+            torch.stack(mask_values).to(memory_format=torch.contiguous_format).float()
+        )
+        masked_image_values = (
+            torch.stack(masked_image_values)
+            .to(memory_format=torch.contiguous_format)
+            .float()
+        )
 
         input_ids = tokenizer.pad(
             {"input_ids": input_ids},
@@ -198,7 +239,7 @@ def collate_fn(examples):
             "input_ids": input_ids,
             "pixel_values": pixel_values,
             "mask_values": mask_values,
-            "masked_image_values": masked_image_values
+            "masked_image_values": masked_image_values,
         }
 
         if examples[0].get("mask", None) is not None:
@@ -215,6 +256,7 @@ def collate_fn(examples):
 
     return train_dataloader
 
+
 def loss_step(
     batch,
     unet,
@@ -225,23 +267,30 @@ def loss_step(
     t_mutliplier=1.0,
     mixed_precision=False,
     mask_temperature=1.0,
+    cached_latents: bool = False,
 ):
     weight_dtype = torch.float32
-
-    latents = vae.encode(
-        batch["pixel_values"].to(dtype=weight_dtype).to(unet.device)
-    ).latent_dist.sample()
-    latents = latents * 0.18215
-
-    if train_inpainting:
-        masked_image_latents = vae.encode(
-            batch["masked_image_values"].to(dtype=weight_dtype).to(unet.device)
+    if not cached_latents:
+        latents = vae.encode(
+            batch["pixel_values"].to(dtype=weight_dtype).to(unet.device)
         ).latent_dist.sample()
-        masked_image_latents = masked_image_latents * 0.18215
-        mask = F.interpolate(
-            batch["mask_values"].to(dtype=weight_dtype).to(unet.device),
-            scale_factor=1/8
-        )
+        latents = latents * 0.18215
+
+        if train_inpainting:
+            masked_image_latents = vae.encode(
+                batch["masked_image_values"].to(dtype=weight_dtype).to(unet.device)
+            ).latent_dist.sample()
+            masked_image_latents = masked_image_latents * 0.18215
+            mask = F.interpolate(
+                batch["mask_values"].to(dtype=weight_dtype).to(unet.device),
+                scale_factor=1 / 8,
+            )
+    else:
+        latents = batch["pixel_values"]
+
+        if train_inpainting:
+            masked_image_latents = batch["masked_image_latents"]
+            mask = batch["mask_values"]
 
     noise = torch.randn_like(latents)
     bsz = latents.shape[0]
@@ -257,7 +306,9 @@ def loss_step(
     noisy_latents = scheduler.add_noise(latents, noise, timesteps)
 
     if train_inpainting:
-        latent_model_input = torch.cat([noisy_latents, mask, masked_image_latents], dim=1)
+        latent_model_input = torch.cat(
+            [noisy_latents, mask, masked_image_latents], dim=1
+        )
     else:
         latent_model_input = noisy_latents
 
@@ -268,7 +319,9 @@ def loss_step(
                 batch["input_ids"].to(text_encoder.device)
             )[0]
 
-            model_pred = unet(latent_model_input, timesteps, encoder_hidden_states).sample
+            model_pred = unet(
+                latent_model_input, timesteps, encoder_hidden_states
+            ).sample
     else:
 
         encoder_hidden_states = text_encoder(
@@ -308,7 +361,12 @@ def loss_step(
 
         target = target * mask
 
-    loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+    loss = (
+        F.mse_loss(model_pred.float(), target.float(), reduction="none")
+        .mean([1, 2, 3])
+        .mean()
+    )
+
     return loss
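The loss rewrite in the hunk above is behavior-preserving: `reduction="none"` followed by `.mean([1, 2, 3])` yields one scalar per sample, and the trailing `.mean()` reproduces `reduction="mean"` when all samples share a shape, while keeping a per-sample hook available for future weighting. A quick self-contained check (shapes illustrative):

```python
import torch
import torch.nn.functional as F

pred, target = torch.randn(4, 4, 64, 64), torch.randn(4, 4, 64, 64)

# Per-sample means over channel/height/width, then a batch mean.
per_sample = F.mse_loss(pred, target, reduction="none").mean([1, 2, 3])  # shape (4,)
assert torch.allclose(per_sample.mean(), F.mse_loss(pred, target, reduction="mean"))
```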
@@ -328,6 +386,7 @@ def train_inversion(
     tokenizer,
     lr_scheduler,
     test_image_path: str,
+    cached_latents: bool,
     accum_iter: int = 1,
     log_wandb: bool = False,
     wandb_log_prompt_cnt: int = 10,
@@ -367,6 +426,7 @@ def train_inversion(
                 scheduler,
                 train_inpainting=train_inpainting,
                 mixed_precision=mixed_precision,
+                cached_latents=cached_latents,
             )
             / accum_iter
         )
@@ -375,6 +435,13 @@ def train_inversion(
         loss_sum += loss.detach().item()
 
         if global_step % accum_iter == 0:
+            # print gradient of text encoder embedding
+            print(
+                text_encoder.get_input_embeddings()
+                .weight.grad[index_updates, :]
+                .norm(dim=-1)
+                .mean()
+            )
             optimizer.step()
             optimizer.zero_grad()
 
@@ -448,7 +515,11 @@ def train_inversion(
                     # open all images in test_image_path
                     images = []
                     for file in os.listdir(test_image_path):
-                        if file.lower().endswith(".png") or file.lower().endswith(".jpg") or file.lower().endswith(".jpeg"):
+                        if (
+                            file.lower().endswith(".png")
+                            or file.lower().endswith(".jpg")
+                            or file.lower().endswith(".jpeg")
+                        ):
                             images.append(
                                 Image.open(os.path.join(test_image_path, file))
                             )
@@ -490,6 +561,7 @@ def perform_tuning(
     out_name: str,
     tokenizer,
     test_image_path: str,
+    cached_latents: bool,
     log_wandb: bool = False,
     wandb_log_prompt_cnt: int = 10,
     class_token: str = "person",
@@ -526,6 +598,7 @@ def perform_tuning(
             t_mutliplier=0.8,
             mixed_precision=True,
             mask_temperature=mask_temperature,
+            cached_latents=cached_latents,
         )
 
         loss_sum += loss.detach().item()
@@ -627,18 +700,12 @@ def train(
     train_text_encoder: bool = True,
     pretrained_vae_name_or_path: str = None,
     revision: Optional[str] = None,
-    class_data_dir: Optional[str] = None,
-    stochastic_attribute: Optional[str] = None,
     perform_inversion: bool = True,
     use_template: Literal[None, "object", "style"] = None,
     train_inpainting: bool = False,
     placeholder_tokens: str = "",
     placeholder_token_at_data: Optional[str] = None,
     initializer_tokens: Optional[str] = None,
-    class_prompt: Optional[str] = None,
-    with_prior_preservation: bool = False,
-    prior_loss_weight: float = 1.0,
-    num_class_images: int = 100,
     seed: int = 42,
     resolution: int = 512,
     color_jitter: bool = True,
@@ -649,7 +716,6 @@ def train(
     save_steps: int = 100,
     gradient_accumulation_steps: int = 4,
     gradient_checkpointing: bool = False,
-    mixed_precision="fp16",
     lora_rank: int = 4,
     lora_unet_target_modules={"CrossAttention", "Attention", "GEGLU"},
     lora_clip_target_modules={"CLIPAttention"},
@@ -663,6 +729,7 @@ def train(
     continue_inversion: bool = False,
     continue_inversion_lr: Optional[float] = None,
     use_face_segmentation_condition: bool = False,
+    cached_latents: bool = True,
     use_mask_captioned_data: bool = False,
     mask_temperature: float = 1.0,
     scale_lr: bool = False,
@@ -773,11 +840,8 @@ def train(
 
     train_dataset = PivotalTuningDatasetCapation(
         instance_data_root=instance_data_dir,
-        stochastic_attribute=stochastic_attribute,
         token_map=token_map,
         use_template=use_template,
-        class_data_root=class_data_dir if with_prior_preservation else None,
-        class_prompt=class_prompt,
         tokenizer=tokenizer,
         size=resolution,
         color_jitter=color_jitter,
@@ -789,12 +853,19 @@ def train(
         train_dataset.blur_amount = 200
 
     if train_inpainting:
+        assert not cached_latents, "Cached latents not supported for inpainting"
+
        train_dataloader = inpainting_dataloader(
            train_dataset, train_batch_size, tokenizer, vae, text_encoder
        )
     else:
         train_dataloader = text2img_dataloader(
-            train_dataset, train_batch_size, tokenizer, vae, text_encoder
+            train_dataset,
+            train_batch_size,
+            tokenizer,
+            vae,
+            text_encoder,
+            cached_latents=cached_latents,
         )
 
     index_no_updates = torch.arange(len(tokenizer)) != -1
@@ -813,6 +884,8 @@ def train(
     for param in params_to_freeze:
         param.requires_grad = False
 
+    if cached_latents:
+        vae = None
     # STEP 1 : Perform Inversion
     if perform_inversion:
         ti_optimizer = optim.AdamW(
@@ -836,6 +909,7 @@ def train(
             text_encoder,
             train_dataloader,
             max_train_steps_ti,
+            cached_latents=cached_latents,
             accum_iter=gradient_accumulation_steps,
             scheduler=noise_scheduler,
             index_no_updates=index_no_updates,
@@ -941,6 +1015,7 @@ def train(
         text_encoder,
         train_dataloader,
         max_train_steps_tuning,
+        cached_latents=cached_latents,
         scheduler=noise_scheduler,
         optimizer=lora_optimizers,
         save_steps=save_steps,
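With the prior-preservation arguments removed from `train()`, a call into the patched entry point looks roughly like the sketch below. This is illustrative, not a tested invocation: `pretrained_model_name_or_path` and `instance_data_dir` are assumed to live in the unchanged part of the signature, and all values are placeholders.

```python
from lora_diffusion.cli_lora_pti import train

train(
    pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5",  # assumed param
    instance_data_dir="./data/my_subject",  # placeholder path
    placeholder_tokens="<s1>",
    initializer_tokens="person",
    use_template="object",
    cached_latents=True,      # new in this patch; defaults to True
    train_inpainting=False,   # cached latents are asserted off for inpainting
)
```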
diff --git a/lora_diffusion/dataset.py b/lora_diffusion/dataset.py
index 2a46313..f1c28fd 100644
--- a/lora_diffusion/dataset.py
+++ b/lora_diffusion/dataset.py
@@ -2,14 +2,11 @@
 from pathlib import Path
 from typing import Dict, List, Optional, Tuple, Union
 
-import cv2
-import numpy as np
-from PIL import Image, ImageFilter
+from PIL import Image
 from torch import zeros_like
 from torch.utils.data import Dataset
 from torchvision import transforms
 import glob
-
 from .preprocess_files import face_mask_google_mediapipe
 
 OBJECT_TEMPLATE = [
@@ -128,12 +125,9 @@ class PivotalTuningDatasetCapation(Dataset):
     def __init__(
         self,
         instance_data_root,
-        stochastic_attribute,
         tokenizer,
         token_map: Optional[dict] = None,
         use_template: Optional[str] = None,
-        class_data_root=None,
-        class_prompt=None,
         size=512,
         h_flip=True,
         color_jitter=False,
@@ -240,18 +234,6 @@ def __init__(
 
         self._length = self.num_instance_images
 
-        if class_data_root is not None:
-            assert NotImplementedError, "Prior preservation is not implemented yet."
-
-            self.class_data_root = Path(class_data_root)
-            self.class_data_root.mkdir(parents=True, exist_ok=True)
-            self.class_images_path = list(self.class_data_root.iterdir())
-            self.num_class_images = len(self.class_images_path)
-            self._length = max(self.num_class_images, self.num_instance_images)
-            self.class_prompt = class_prompt
-        else:
-            self.class_data_root = None
-
         self.h_flip = h_flip
         self.image_transforms = transforms.Compose(
             [
@@ -326,23 +308,4 @@ def __getitem__(self, index):
                 max_length=self.tokenizer.model_max_length,
             ).input_ids
 
-        if self.class_data_root:
-            class_image = Image.open(
-                self.class_images_path[index % self.num_class_images]
-            )
-            if not class_image.mode == "RGB":
-                class_image = class_image.convert("RGB")
-            example["class_images"] = self.image_transforms(class_image)
-            if self.train_inpainting:
-                (
-                    example["class_masks"],
-                    example["class_masked_images"],
-                ) = _generate_random_mask(example["class_images"])
-            example["class_prompt_ids"] = self.tokenizer(
-                self.class_prompt,
-                padding="do_not_pad",
-                truncation=True,
-                max_length=self.tokenizer.model_max_length,
-            ).input_ids
-
         return example
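After these removals, `PivotalTuningDatasetCapation` takes only instance-side arguments and no longer emits `class_*` keys. A sketch of direct construction under that assumption (the tokenizer line presumes a standard Stable Diffusion checkpoint layout; the path and token mapping are placeholders):

```python
from transformers import CLIPTokenizer
from lora_diffusion.dataset import PivotalTuningDatasetCapation

tokenizer = CLIPTokenizer.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="tokenizer"
)
dataset = PivotalTuningDatasetCapation(
    instance_data_root="./data/my_subject",  # placeholder path
    tokenizer=tokenizer,
    token_map={"TOKEN": "<s1>"},  # illustrative caption-to-placeholder mapping
    use_template="object",
    size=512,
    color_jitter=True,
)
example = dataset[0]  # yields instance_images / instance_prompt_ids only
```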
diff --git a/lora_diffusion/lora_manager.py b/lora_diffusion/lora_manager.py
index 02d679b..9d8306e 100644
--- a/lora_diffusion/lora_manager.py
+++ b/lora_diffusion/lora_manager.py
@@ -23,7 +23,10 @@ def lora_join(lora_safetenors: list):
         if k.endswith("rank"):
             rankset.append(int(v))
 
-    assert len(set(rankset)) == 1, "Rank should be the same per model"
+    assert len(set(rankset)) <= 1, "Rank should be the same per model"
+    if len(rankset) == 0:
+        rankset = [0]
+
     total_rank += rankset[0]
     _total_metadata.update(_metadata)
     ranklist.append(rankset[0])
@@ -119,6 +122,10 @@ def _setup(self):
 
     def tune(self, scales):
 
+        assert len(scales) == len(
+            self.ranklist
+        ), "Scale list should be the same length as ranklist"
+
         diags = []
         for scale, rank in zip(scales, self.ranklist):
             diags = diags + [scale] * rank
diff --git a/setup.py b/setup.py
index 3f21767..6d286b3 100644
--- a/setup.py
+++ b/setup.py
@@ -6,7 +6,7 @@
 setup(
     name="lora_diffusion",
     py_modules=["lora_diffusion"],
-    version="0.1.6",
+    version="0.1.7",
     description="Low Rank Adaptation for Diffusion Models. Works with Stable Diffusion out-of-the-box.",
     author="Simo Ryu",
     packages=find_packages(),
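For context on the `lora_manager.py` hunks: `tune()` expands each per-model scale into `rank` copies on a diagonal, which is why the new assert ties `len(scales)` to `len(self.ranklist)`. The same expansion in isolation:

```python
ranklist = [4, 4]    # ranks recorded for two joined LoRAs (illustrative)
scales = [0.7, 0.3]  # one scale per joined model
assert len(scales) == len(ranklist)

# Each LoRA of rank r contributes r copies of its scale on the diagonal.
diags = []
for scale, rank in zip(scales, ranklist):
    diags = diags + [scale] * rank
print(diags)  # [0.7, 0.7, 0.7, 0.7, 0.3, 0.3, 0.3, 0.3]
```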