fix KeyError: 'up_cross' #68

Open. Wants to merge 45 commits into base: main.

Commits (45):
dee8dd9  liutao (byliutao, Aug 25, 2023)
59969db  liutao (byliutao, Aug 26, 2023)
62bb401  liutao (byliutao, Aug 26, 2023)
237ea0e  liutao (byliutao, Aug 26, 2023)
97a60fe  liutao (byliutao, Aug 26, 2023)
fb77df7  Merge branch 'main' of https://github.com/byliutao/prompt-to-prompt i… (byliutao, Aug 26, 2023)
160965d  liutao (byliutao, Aug 27, 2023)
5a6bb14  liutao (byliutao, Aug 28, 2023)
256b15d  liutao (byliutao, Aug 29, 2023)
def6e50  liutao (byliutao, Aug 30, 2023)
16e045b  liutao (byliutao, Aug 31, 2023)
601e049  liutao (byliutao, Sep 1, 2023)
1c93f36  liutao (byliutao, Sep 1, 2023)
864a5ae  liutao (byliutao, Sep 2, 2023)
ee94102  liutao (byliutao, Sep 2, 2023)
780bb61  liutao (byliutao, Sep 2, 2023)
904bab6  liutao (byliutao, Sep 3, 2023)
dd534be  liutao (byliutao, Sep 5, 2023)
8905a04  liutao (byliutao, Sep 8, 2023)
6a524c4  liutao (byliutao, Sep 8, 2023)
c728c9b  liutao (byliutao, Sep 12, 2023)
5017587  liutao (byliutao, Sep 12, 2023)
e11e650  liutao (byliutao, Sep 12, 2023)
ea4423c  liutao (byliutao, Sep 13, 2023)
fa53975  liutao (byliutao, Sep 14, 2023)
f601fc5  liutao (byliutao, Sep 14, 2023)
8d0e565  liutao (byliutao, Sep 15, 2023)
1cef42c  liutao (byliutao, Sep 16, 2023)
a5b2d04  liutao (byliutao, Sep 18, 2023)
10e12bd  liutao (byliutao, Sep 18, 2023)
f3dd130  liutao (byliutao, Sep 18, 2023)
d6937e2  liutao (byliutao, Sep 19, 2023)
aacdf6e  liutao (byliutao, Sep 20, 2023)
0c9658c  liutao (byliutao, Sep 21, 2023)
c53d5c2  liutao (byliutao, Sep 24, 2023)
53015d0  liuato (byliutao, Sep 25, 2023)
c941473  liutao (byliutao, Sep 26, 2023)
27890c3  liutao (byliutao, Sep 29, 2023)
1de1312  liutao (byliutao, Oct 7, 2023)
6591513  liutao (byliutao, Oct 8, 2023)
beab11a  liutao (byliutao, Oct 8, 2023)
ffd833c  liutao (byliutao, Oct 9, 2023)
9624b65  liutao (byliutao, Oct 10, 2023)
35da05e  liutao (byliutao, Oct 23, 2023)
f1a9fa1  legacy commit (byliutao, Jan 17, 2024)
3,290 changes: 3,290 additions & 0 deletions .ipynb_checkpoints/null_text_w_ptp-checkpoint.ipynb

Large diffs are not rendered by default.

855 changes: 855 additions & 0 deletions .ipynb_checkpoints/prompt-to-prompt_stable-checkpoint.ipynb

Large diffs are not rendered by default.

320 changes: 320 additions & 0 deletions .ipynb_checkpoints/ptp_utils-checkpoint.py
@@ -0,0 +1,320 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import torch
from PIL import Image, ImageDraw, ImageFont
import cv2
from typing import Optional, Union, Tuple, List, Callable, Dict
from IPython.display import display
from tqdm.notebook import tqdm


def text_under_image(image: np.ndarray, text: str, text_color: Tuple[int, int, int] = (0, 0, 0)):
    # Append a white strip below the image and draw the caption centered in it.
    h, w, c = image.shape
    offset = int(h * .2)
    img = np.ones((h + offset, w, c), dtype=np.uint8) * 255
    font = cv2.FONT_HERSHEY_SIMPLEX
    # font = ImageFont.truetype("/usr/share/fonts/truetype/noto/NotoMono-Regular.ttf", font_size)
    img[:h] = image
    textsize = cv2.getTextSize(text, font, 1, 2)[0]
    text_x, text_y = (w - textsize[0]) // 2, h + offset - textsize[1] // 2
    cv2.putText(img, text, (text_x, text_y), font, 1, text_color, 2)
    return img


def view_images(images, num_rows=1, offset_ratio=0.02):
    # Normalize the input to a list of HxWxC uint8 arrays, padding with blank
    # frames so the count divides evenly into num_rows rows.
    if type(images) is list:
        num_empty = (num_rows - len(images) % num_rows) % num_rows
    elif images.ndim == 4:
        num_empty = (num_rows - images.shape[0] % num_rows) % num_rows
    else:
        images = [images]
        num_empty = 0

    empty_images = np.ones(images[0].shape, dtype=np.uint8) * 255
    images = [image.astype(np.uint8) for image in images] + [empty_images] * num_empty
    num_items = len(images)

    # Lay the images out on a white canvas with offset_ratio * h pixels of spacing.
    h, w, c = images[0].shape
    offset = int(h * offset_ratio)
    num_cols = num_items // num_rows
    image_ = np.ones((h * num_rows + offset * (num_rows - 1),
                      w * num_cols + offset * (num_cols - 1), 3), dtype=np.uint8) * 255
    for i in range(num_rows):
        for j in range(num_cols):
            image_[i * (h + offset): i * (h + offset) + h, j * (w + offset): j * (w + offset) + w] = images[
                i * num_cols + j]

    pil_img = Image.fromarray(image_)
    display(pil_img)
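For reference, a minimal usage sketch of the two display helpers above (not part of the diff; the dummy frames are purely illustrative):

import numpy as np

# Three dummy 256x256 RGB images, captioned and shown on one row.
frames = [np.full((256, 256, 3), v, dtype=np.uint8) for v in (64, 128, 192)]
frames = [text_under_image(f, f"frame {i}") for i, f in enumerate(frames)]
view_images(frames, num_rows=1)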


def diffusion_step(model, controller, latents, context, t, guidance_scale, low_resource=False):
    if low_resource:
        # Run the unconditional and text-conditioned passes sequentially to lower peak memory.
        noise_pred_uncond = model.unet(latents, t, encoder_hidden_states=context[0])["sample"]
        noise_prediction_text = model.unet(latents, t, encoder_hidden_states=context[1])["sample"]
    else:
        latents_input = torch.cat([latents] * 2)
        noise_pred = model.unet(latents_input, t, encoder_hidden_states=context)["sample"]
        noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2)
    # Classifier-free guidance: push the prediction toward the text-conditioned direction.
    noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)
    latents = model.scheduler.step(noise_pred, t, latents)["prev_sample"]
    latents = controller.step_callback(latents)
    return latents


def latent2image(vae, latents):
    latents = 1 / 0.18215 * latents
    image = vae.decode(latents)['sample']
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.cpu().permute(0, 2, 3, 1).numpy()
    image = (image * 255).astype(np.uint8)
    return image


def init_latent(latent, model, height, width, generator, batch_size):
    if latent is None:
        latent = torch.randn(
            (1, model.unet.in_channels, height // 8, width // 8),
            generator=generator,
        )
    latents = latent.expand(batch_size, model.unet.in_channels, height // 8, width // 8).to(model.device)
    return latent, latents


@torch.no_grad()
def text2image_ldm(
    model,
    prompt: List[str],
    controller,
    num_inference_steps: int = 50,
    guidance_scale: Optional[float] = 7.,
    generator: Optional[torch.Generator] = None,
    latent: Optional[torch.FloatTensor] = None,
):
    register_attention_control(model, controller)
    height = width = 256
    batch_size = len(prompt)

    uncond_input = model.tokenizer([""] * batch_size, padding="max_length", max_length=77, return_tensors="pt")
    uncond_embeddings = model.bert(uncond_input.input_ids.to(model.device))[0]

    text_input = model.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt")
    text_embeddings = model.bert(text_input.input_ids.to(model.device))[0]
    latent, latents = init_latent(latent, model, height, width, generator, batch_size)
    context = torch.cat([uncond_embeddings, text_embeddings])

    model.scheduler.set_timesteps(num_inference_steps)
    for t in tqdm(model.scheduler.timesteps):
        latents = diffusion_step(model, controller, latents, context, t, guidance_scale)

    image = latent2image(model.vqvae, latents)

    return image, latent
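A hedged usage sketch for this LDM path (it assumes the CompVis/ldm-text2im-large-256 checkpoint, whose pipeline exposes the bert, tokenizer, and vqvae attributes used above):

import torch
from diffusers import DiffusionPipeline

ldm = DiffusionPipeline.from_pretrained("CompVis/ldm-text2im-large-256").to("cuda")
g = torch.Generator().manual_seed(8888)
image, x_t = text2image_ldm(ldm, ["a painting of a squirrel eating a burger"],
                            controller=None, generator=g)  # None falls back to DummyController
view_images(image)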


@torch.no_grad()
def text2image_ldm_stable(
    model,
    prompt: List[str],
    controller,
    num_inference_steps: int = 50,
    guidance_scale: float = 7.5,
    generator: Optional[torch.Generator] = None,
    latent: Optional[torch.FloatTensor] = None,
    low_resource: bool = False,
):
    register_attention_control(model, controller)
    height = width = 512
    batch_size = len(prompt)

    text_input = model.tokenizer(
        prompt,
        padding="max_length",
        max_length=model.tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    text_embeddings = model.text_encoder(text_input.input_ids.to(model.device))[0]
    max_length = text_input.input_ids.shape[-1]
    uncond_input = model.tokenizer(
        [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
    )
    uncond_embeddings = model.text_encoder(uncond_input.input_ids.to(model.device))[0]

    context = [uncond_embeddings, text_embeddings]
    if not low_resource:
        context = torch.cat(context)
    latent, latents = init_latent(latent, model, height, width, generator, batch_size)

    # set timesteps
    model.scheduler.set_timesteps(num_inference_steps)
    for t in tqdm(model.scheduler.timesteps):
        latents = diffusion_step(model, controller, latents, context, t, guidance_scale, low_resource)

    image = latent2image(model.vae, latents)

    return image, latent
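And a sketch for the Stable Diffusion path, mirroring the repo's notebooks (the CompVis/stable-diffusion-v1-4 checkpoint and the AttentionStore controller class are assumptions; neither is defined in this file):

import torch
from diffusers import StableDiffusionPipeline

model = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4").to("cuda")
controller = AttentionStore()  # defined in the accompanying notebooks, not here
g = torch.Generator().manual_seed(8888)
image, x_t = text2image_ldm_stable(model, ["a cat sitting on a car"], controller,
                                   num_inference_steps=50, guidance_scale=7.5,
                                   generator=g, low_resource=False)
view_images(image)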


def register_attention_control(model, controller):
    def ca_forward(self, place_in_unet):
        # Recent diffusers wrap the output projection in a ModuleList (linear + dropout);
        # older versions expose the linear layer directly.
        to_out = self.to_out
        if isinstance(to_out, torch.nn.ModuleList):
            to_out = self.to_out[0]

        def forward(hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None):
            is_cross = encoder_hidden_states is not None

            residual = hidden_states

            if self.spatial_norm is not None:
                hidden_states = self.spatial_norm(hidden_states, temb)

            input_ndim = hidden_states.ndim

            if input_ndim == 4:
                batch_size, channel, height, width = hidden_states.shape
                hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

            batch_size, sequence_length, _ = (
                hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
            )
            attention_mask = self.prepare_attention_mask(attention_mask, sequence_length, batch_size)

            if self.group_norm is not None:
                hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

            query = self.to_q(hidden_states)

            if encoder_hidden_states is None:
                encoder_hidden_states = hidden_states
            elif self.norm_cross:
                encoder_hidden_states = self.norm_encoder_hidden_states(encoder_hidden_states)

            key = self.to_k(encoder_hidden_states)
            value = self.to_v(encoder_hidden_states)

            query = self.head_to_batch_dim(query)
            key = self.head_to_batch_dim(key)
            value = self.head_to_batch_dim(value)

            attention_probs = self.get_attention_scores(query, key, attention_mask)
            # Hand the attention map to the controller, which may record or edit it.
            attention_probs = controller(attention_probs, is_cross, place_in_unet)

            hidden_states = torch.bmm(attention_probs, value)
            hidden_states = self.batch_to_head_dim(hidden_states)

            # linear proj
            hidden_states = to_out(hidden_states)

            if input_ndim == 4:
                hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

            if self.residual_connection:
                hidden_states = hidden_states + residual

            hidden_states = hidden_states / self.rescale_output_factor

            return hidden_states

        return forward

    class DummyController:

        def __call__(self, *args):
            return args[0]

        def __init__(self):
            self.num_att_layers = 0

    if controller is None:
        controller = DummyController()

    def register_recr(net_, count, place_in_unet):
        # Newer diffusers name the attention block 'Attention'; it was 'CrossAttention'
        # in the versions the upstream repo targeted.
        if net_.__class__.__name__ == 'Attention':
            net_.forward = ca_forward(net_, place_in_unet)
            return count + 1
        elif hasattr(net_, 'children'):
            for net__ in net_.children():
                count = register_recr(net__, count, place_in_unet)
        return count

    cross_att_count = 0
    sub_nets = model.unet.named_children()
    for net in sub_nets:
        if "down" in net[0]:
            cross_att_count += register_recr(net[1], 0, "down")
        elif "up" in net[0]:
            cross_att_count += register_recr(net[1], 0, "up")
        elif "mid" in net[0]:
            cross_att_count += register_recr(net[1], 0, "mid")

    controller.num_att_layers = cross_att_count
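This rewritten hook is the heart of the fix named in the PR title: recent diffusers releases renamed the attention module from CrossAttention to Attention and changed its forward signature, so the upstream registration matched nothing, no attention maps were ever stored, and lookups such as attention_store["up_cross"] raised the KeyError. A minimal controller sketch, following the key scheme the repo's AttentionStore uses (illustrative, not the repo's full class):

class MinimalStore:
    """Collects attention maps under keys like "down_cross" or "up_self"."""

    def __init__(self):
        self.num_att_layers = 0  # filled in by register_attention_control
        self.store = {f"{p}_{k}": [] for p in ("down", "mid", "up")
                      for k in ("cross", "self")}

    def __call__(self, attn, is_cross, place_in_unet):
        key = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
        self.store[key].append(attn)
        return attn  # the hook must return the (possibly edited) attention probs

    def step_callback(self, latents):
        return latents  # no latent-space edit in this minimal version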


def get_word_inds(text: str, word_place: Union[int, str], tokenizer):
    # Map a word in `text` (given by position or literal string) to the indices of
    # the tokenizer pieces that spell it, offset by 1 for the start-of-text token.
    split_text = text.split(" ")
    if type(word_place) is str:
        word_place = [i for i, word in enumerate(split_text) if word_place == word]
    elif type(word_place) is int:
        word_place = [word_place]
    out = []
    if len(word_place) > 0:
        words_encode = [tokenizer.decode([item]).strip("#") for item in tokenizer.encode(text)][1:-1]
        cur_len, ptr = 0, 0

        for i in range(len(words_encode)):
            cur_len += len(words_encode[i])
            if ptr in word_place:
                out.append(i + 1)
            if cur_len >= len(split_text[ptr]):
                ptr += 1
                cur_len = 0
    return np.array(out)
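An illustrative call (it assumes the CLIP tokenizer that ships with Stable Diffusion):

from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
# "cat" is word 1 of the split prompt; its single token sits at index 2
# once the start-of-text token is accounted for.
print(get_word_inds("a cat sitting on a car", "cat", tokenizer))  # array([2])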


def update_alpha_time_word(alpha, bounds: Union[float, Tuple[float, float]], prompt_ind: int,
                           word_inds: Optional[torch.Tensor] = None):
    if type(bounds) is float:
        bounds = 0, bounds
    start, end = int(bounds[0] * alpha.shape[0]), int(bounds[1] * alpha.shape[0])
    if word_inds is None:
        word_inds = torch.arange(alpha.shape[2])
    alpha[: start, prompt_ind, word_inds] = 0
    alpha[start: end, prompt_ind, word_inds] = 1
    alpha[end:, prompt_ind, word_inds] = 0
    return alpha


def get_time_words_attention_alpha(prompts, num_steps,
                                   cross_replace_steps: Union[float, Dict[str, Tuple[float, float]]],
                                   tokenizer, max_num_words=77):
    if type(cross_replace_steps) is not dict:
        cross_replace_steps = {"default_": cross_replace_steps}
    if "default_" not in cross_replace_steps:
        cross_replace_steps["default_"] = (0., 1.)
    alpha_time_words = torch.zeros(num_steps + 1, len(prompts) - 1, max_num_words)
    for i in range(len(prompts) - 1):
        alpha_time_words = update_alpha_time_word(alpha_time_words, cross_replace_steps["default_"],
                                                  i)
    for key, item in cross_replace_steps.items():
        if key != "default_":
            inds = [get_word_inds(prompts[i], key, tokenizer) for i in range(1, len(prompts))]
            for i, ind in enumerate(inds):
                if len(ind) > 0:
                    alpha_time_words = update_alpha_time_word(alpha_time_words, item, i, ind)
    alpha_time_words = alpha_time_words.reshape(num_steps + 1, len(prompts) - 1, 1, 1, max_num_words)
    return alpha_time_words
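A sketch of how an editing run might drive this (the prompts, the 0.8/0.4 windows, and the tokenizer from the previous example are illustrative):

prompts = ["a cat sitting on a car",
           "a dog sitting on a car"]
# Reuse the source prompt's cross-attention for the first 80% of the steps,
# but only the first 40% of them for the word "dog".
alphas = get_time_words_attention_alpha(prompts, num_steps=50,
                                        cross_replace_steps={"default_": 0.8, "dog": 0.4},
                                        tokenizer=tokenizer)
print(alphas.shape)  # torch.Size([51, 1, 1, 1, 77])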
9 changes: 9 additions & 0 deletions .ipynb_checkpoints/requirements-checkpoint.txt
@@ -0,0 +1,9 @@
diffusers==0.17.1
transformers
ftfy
opencv-python
ipywidgets
accelerate
scikit-image
torchvision
matplotlib
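Assuming the PR ships the same list as a top-level requirements.txt, pip install -r requirements.txt sets up a matching environment; the diffusers==0.17.1 pin is the version whose Attention class the rewritten hook above targets.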
Binary file added __pycache__/ptp_utils.cpython-311.pyc
Binary file not shown.
Binary file added __pycache__/seq_aligner.cpython-311.pyc
Binary file not shown.
Binary file added example_images/a bird in rain.jpg
Binary file added example_images/a green plant in rain.jpg
Binary file added example_images/a green tree in raining forest.png
Binary file added example_images/a man with glass.png
Binary file added example_images/a red flower in rain.jpg
Binary file added example_images/a women with a glass.jpg
Binary file added example_images/loop_test/a street in rain.jpg