-
Notifications
You must be signed in to change notification settings - Fork 0
/
run.py
74 lines (63 loc) · 2.29 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import gc
import os
import requests
from PIL import Image
from io import BytesIO
import matplotlib.pyplot as plt
import numpy as np
import torch
from latent_class import LatentClass
from model import SDLatentTiling
# Tell the Intel OpenMP runtime to tolerate being loaded more than once
# (KMP_DUPLICATE_LIB_OK). NOTE(review): this is set after torch/numpy are
# already imported; if the duplicate-runtime abort still appears, move it
# above those imports.
os.environ.update(KMP_DUPLICATE_LIB_OK="TRUE")

# Start from a clean slate before the model is built.
torch.cuda.empty_cache()
gc.collect()

# Sampler handed to the tiling model; 'ddim' is the alternative choice.
scheduler = 'ddpm'  # alternative: 'ddim'
model = SDLatentTiling(scheduler=scheduler)
# ---------------------------------------------------------------------------
# Generation parameters (all forwarded to the model call further down).
# ---------------------------------------------------------------------------
prompt_1 = "Red brick texture"
prompt_2 = "Green brick texture"
negative_prompt = "blured, ugly, deformed, disfigured, poor details, bad anatomy, pixelized, bad order"
inference_steps = 40  # number of sampling steps requested from the model
seed = 151  # fixed seed, presumably for reproducible outputs
cfg_scale = 7.5  # guidance scale (presumably classifier-free guidance)
max_replica_width = 5  # NOTE(review): meaning defined by SDLatentTiling — confirm
max_width = 32  # NOTE(review): meaning defined by SDLatentTiling — confirm
height = 512  # overridden below from input_image.size when an image is given
width = 512
input_image = None  # set to a PIL image to enable the image-to-image path

######################### IMAGE TO IMAGE #########################
# Strength value passed to the model; presumably only meaningful when an
# input image is used — confirm against SDLatentTiling.
strength = 0.92
# Uncomment to load an example sketch as the input image:
# url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
# response = requests.get(url)
# input_image = Image.open(BytesIO(response.content)).convert("RGB")
# input_image = input_image.resize((768, 512))
#################################################################

# Detect an input image by presence, not by whatever truthiness the image
# object happens to define.
if input_image is not None:
    # PIL's Image.size is (width, height), matching this unpacking order.
    width, height = input_image.size
# One LatentClass per prompt, all with the same side wiring. side_id and
# side_dir are ordered (Right, Left, Up, Down): right and left both get id 1
# with directions 'cw' / 'ccw', while up and down are left as None.
# Presumably this ties the horizontal edges together so the textures tile;
# the exact semantics live in LatentClass.
latents_arr = [
    LatentClass(
        prompt=text,
        negative_prompt=negative_prompt,
        # Fresh lists for every instance so no two latents alias the same list.
        side_id=[1, 1, None, None],
        side_dir=['cw', 'ccw', None, None],
    )
    for text in (prompt_1, prompt_2)
]
lat1, lat2 = latents_arr
# Prefer the GPU when one is available; the choice is passed to the model
# as `device`.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# NOTE(review): input_image is not forwarded to the model here, so even when
# it is set only its size (via width/height) affects the run, while
# `strength` is passed unconditionally. Confirm whether SDLatentTiling accepts
# an image argument for the image-to-image path.
new_latents_arr = model(
    latents_arr=latents_arr,
    negative_prompt=negative_prompt,
    inference_steps=inference_steps,
    seed=seed,
    cfg_scale=cfg_scale,
    height=height,
    width=width,
    max_width=max_width,
    max_replica_width=max_replica_width,
    strength=strength,
    device=device,
)

# Collect unreachable Python objects first so any tensors they still hold are
# actually freed; only then can empty_cache() hand that now-unoccupied cached
# memory back. (The reverse order leaves such memory cached.)
gc.collect()
torch.cuda.empty_cache()
# Lay the generated results out as tiles in the order 1, 2, 2, 1, presumably
# to check visually that each pairing (1|2, 2|2, 2|1) meets without a seam.
# Assumes `.image` is an array shaped (H, W[, C]), so axis=1 is the width.
lat1_new, lat2_new = new_latents_arr[0], new_latents_arr[1]
tiles = [lat1_new.image, lat2_new.image, lat2_new.image, lat1_new.image]
t_1 = np.concatenate(tiles, axis=1)

plt.imshow(t_1)
plt.show()