Sample latents from VAE to generate close images #9327
Replies: 2 comments 11 replies
-
I tested your code with some modifications, updating the noise and the scale factor.

```python
import torch
from diffusers import AutoencoderKL
from diffusers.image_processor import VaeImageProcessor
from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
from PIL import Image
from torchvision import transforms
from torchvision.transforms.functional import pil_to_tensor

# 2: assume I'm interested in getting close variations of an image, using SDXL's VAE.
# Instead of inferring, e.g. doing an img2img or inpaint,
# I can use the latent distribution inferred by the VAE, sample from it and decode back to pixel space.

# Instantiate SDXL's VAE
with torch.no_grad():
    # vae: AutoencoderKL = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix")
    vae: AutoencoderKL = AutoencoderKL.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0", subfolder="vae"
    )
    vae.to(dtype=torch.float32)  # otherwise it produces NaNs, even madebyollin's VAE
    vae.to(device="cuda")
    assert vae.device == torch.device("cuda:0")
    assert vae.dtype == torch.float32

    # Load the image and map it to [-1, 1], the range the VAE expects
    img = Image.open("avenger.jpg")  # Replace with your actual image path
    img_tensor = pil_to_tensor(img).unsqueeze(0) / 255.0 * 2.0 - 1.0
    img_tensor = img_tensor.to(device=vae.device, dtype=vae.dtype)

    # Get the inferred latent distribution
    latent_dist: DiagonalGaussianDistribution = vae.encode(img_tensor, return_dict=False)[0]
    print(f"{latent_dist.mean.shape=} {latent_dist.std.shape=} {latent_dist.mean.mean()=} {latent_dist.std.mean()=}")
    assert not latent_dist.mean.isnan().any()
    assert not latent_dist.std.isnan().any()
    assert latent_dist.deterministic is False

    # -- Try with a scale factor and added noise --
    scale_factor = 5.0  # increases the variations
    noise_strength = 0.2  # added noise helps further perturb the latent space
    noise = noise_strength * torch.randn_like(latent_dist.mean)

    # Generate new latents with added noise and scaling
    sample_1 = latent_dist.mean + scale_factor * latent_dist.std * torch.randn_like(latent_dist.mean) + noise
    sample_2 = latent_dist.mean + scale_factor * latent_dist.std * torch.randn_like(latent_dist.mean) + noise
    assert not sample_1.isnan().any()
    assert not sample_2.isnan().any()
    assert (sample_1 != sample_2).any(), "samples should be different"
    print(f"{sample_1.shape=}")
    assert vae.dtype == sample_1.dtype
    assert vae.device == sample_1.device

    # Decode the sampled latents back to images
    img_1: torch.Tensor = vae.decode(sample_1).sample  # decode the first variation
    img_1 = (img_1.squeeze(0).cpu().detach() * 0.5 + 0.5).clamp(0, 1)  # back to [0, 1]
    assert (img_1 != img_tensor.cpu().detach()).any(), "generated image should be different from the input image"

    # Save first variation
    img_1_pil = transforms.ToPILImage()(img_1)
    img_1_pil.save("sample2_1.png")

    # Save second variation
    img_2: torch.Tensor = vae.decode(sample_2).sample
    img_2 = (img_2.squeeze(0).cpu().detach() * 0.5 + 0.5).clamp(0, 1)
    img_2_pil = transforms.ToPILImage()(img_2)
    img_2_pil.save("sample2_2.png")

    # -- Try interpolation between latent vectors --
    t = torch.rand(1).item()  # random interpolation factor between 0 and 1
    interpolated_sample = (1 - t) * sample_1 + t * sample_2

    # Decode the interpolated sample into an image
    img_interpolated = vae.decode(interpolated_sample).sample
    img_interpolated = (img_interpolated.squeeze(0).cpu().detach() * 0.5 + 0.5).clamp(0, 1)
    img_interpolated_pil = transforms.ToPILImage()(img_interpolated)
    img_interpolated_pil.save("interpolated_variation2.png")

print("Done")
```
It sure does work well: it changes eyes, smiles, etc. If you want to see more significant changes, you may need to consider the dimensionality of the latent space. Also, VAEs are good at generating smooth variations; they may not be as effective at producing large changes unless trained for that.
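If it helps, here is a minimal sketch of how the strength of the variations could be swept. It assumes `vae` and `latent_dist` are already built as in the snippet above, and the output filenames are just placeholders:

```python
import torch
from torchvision import transforms

# Sweep the std scale factor to control how far samples drift from the mean.
# Assumes `vae` and `latent_dist` exist as in the snippet above.
with torch.no_grad():
    eps = torch.randn_like(latent_dist.mean)  # reuse one noise draw so only the scale changes
    for scale in (1.0, 2.5, 5.0, 10.0):
        latent = latent_dist.mean + scale * latent_dist.std * eps
        decoded = vae.decode(latent).sample
        decoded = (decoded.squeeze(0).cpu() * 0.5 + 0.5).clamp(0, 1)  # back to [0, 1]
        transforms.ToPILImage()(decoded).save(f"variation_scale_{scale}.png")
```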
-
Here is an example of what I get when interpolating two images (smile vs. non-smile portraits). This is very cool as it doesn't require reverse diffusion, so it's pretty fast, but again my initial idea is to start from a single input image and generate variations of it.
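For reference, a minimal sketch of that two-image interpolation, assuming `vae` is the float32 SDXL VAE loaded as above and `smile.jpg` / `no_smile.jpg` are placeholder paths for two same-sized images:

```python
import torch
from PIL import Image
from torchvision import transforms
from torchvision.transforms.functional import pil_to_tensor

def encode_mean(path: str) -> torch.Tensor:
    """Encode an image to the mean of its VAE latent distribution."""
    img = pil_to_tensor(Image.open(path)).unsqueeze(0) / 255.0 * 2.0 - 1.0  # map to [-1, 1]
    img = img.to(device=vae.device, dtype=vae.dtype)
    return vae.encode(img, return_dict=False)[0].mean

with torch.no_grad():
    latent_a = encode_mean("smile.jpg")     # placeholder path
    latent_b = encode_mean("no_smile.jpg")  # placeholder path, same image size as the first
    for i, t in enumerate(torch.linspace(0.0, 1.0, 5)):
        latent = (1 - t) * latent_a + t * latent_b  # linear interpolation in latent space
        decoded = vae.decode(latent).sample
        decoded = (decoded.squeeze(0).cpu() * 0.5 + 0.5).clamp(0, 1)
        transforms.ToPILImage()(decoded).save(f"interp_{i}.png")
```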
-
hi,
I would like to use a VAE (SDXL's VAE in the example below) in order to get close variations of a given image, in the spirit of what was demonstrated in that paper.
I could achieve this with inpainting or maybe img2img, but my point is to get these close images without having to go through reverse diffusion, for performance reasons.
I tried to implement the logic, and found out that diffusers already has nearly all the tools.
PROBLEM: It runs without errors and I do get images that differ pixel-wise from the initial image, but the differences across images are invisible to the human eye. I noticed that the std of the latent distribution is extremely low relative to the mean, which might explain why all the produced images look identical to the input. Or maybe I got a preprocessing step wrong?
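(For reference, a quick way to quantify that std-to-mean ratio, assuming `latent_dist` is the `DiagonalGaussianDistribution` returned by `vae.encode`:)

```python
# Compare the magnitude of the predicted std against the mean.
mean_mag = latent_dist.mean.abs().mean().item()
std_mag = latent_dist.std.mean().item()
print(f"mean magnitude: {mean_mag:.4f}, std magnitude: {std_mag:.4f}, ratio: {std_mag / mean_mag:.6f}")
```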
I'd be grateful if someone could review my logic and advise.
Here is my code:
Thanks a lot!
cc @asomoza