# attention_couple.py (forked from laksjdjf/attention-couple-ComfyUI)
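
# Usage note: each entry in the positive/negative CONDITIONING lists is expected to
# carry a per-region "mask" tensor and a "mask_strength" value in its options dict,
# since attention_couple() below reads exactly those keys. In a typical ComfyUI
# workflow they are attached upstream, for example with Conditioning (Set Mask) and
# merged with Conditioning (Combine); the specific upstream nodes are an assumption
# here, any source of the "mask"/"mask_strength" keys works.
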
import torch
import torch.nn.functional as F
import copy

import comfy.model_management
from comfy.ldm.modules.attention import optimized_attention

def get_masks_from_q(masks, q, original_shape):
    # original_shape is the latent shape (b, c, h, w); q.shape[1] is the number of
    # spatial tokens at this attention layer, which tells us how far the latent has
    # been downsampled inside the U-Net.
    if original_shape[2] * original_shape[3] == q.shape[1]:
        down_sample_rate = 1
    elif (original_shape[2] // 2) * (original_shape[3] // 2) == q.shape[1]:
        down_sample_rate = 2
    elif (original_shape[2] // 4) * (original_shape[3] // 4) == q.shape[1]:
        down_sample_rate = 4
    else:
        down_sample_rate = 8

    ret_masks = []
    for mask in masks:
        if isinstance(mask, torch.Tensor):
            # Resize the region mask to this layer's token grid and expand it to
            # (batch, tokens, channels) so it can be multiplied with the attention output.
            size = (original_shape[2] // down_sample_rate, original_shape[3] // down_sample_rate)
            mask_downsample = F.interpolate(mask.unsqueeze(0), size=size, mode="nearest")
            mask_downsample = mask_downsample.view(1, -1, 1).repeat(q.shape[0], 1, q.shape[2])
            ret_masks.append(mask_downsample)
        else:  # no coupling for this condition: use an all-ones mask
            ret_masks.append(torch.ones_like(q))

    ret_masks = torch.cat(ret_masks, dim=0)
    return ret_masks
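
# Worked example of the resolution matching above (illustrative, assuming a 512x512
# image, i.e. a 64x64 latent): attention layers that see 64*64 = 4096 tokens get the
# full-resolution mask (rate 1), 32*32 = 1024 tokens map to rate 2, 16*16 = 256
# tokens to rate 4, and anything smaller (e.g. an 8x8 middle block, 64 tokens) falls
# through to rate 8.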

def set_model_patch_replace(model, patch, key):
    # Register a replacement for the cross-attention (attn2) call identified by key.
    to = model.model_options["transformer_options"]
    if "patches_replace" not in to:
        to["patches_replace"] = {}
    if "attn2" not in to["patches_replace"]:
        to["patches_replace"]["attn2"] = {}
    to["patches_replace"]["attn2"][key] = patch

class AttentionCouple:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": ("MODEL", ),
                "positive": ("CONDITIONING",),
                "negative": ("CONDITIONING",),
                "mode": (["Attention", "Latent"], ),
            }
        }

    RETURN_TYPES = ("MODEL", "CONDITIONING", "CONDITIONING")
    FUNCTION = "attention_couple"
    CATEGORY = "loaders"

    def attention_couple(self, model, positive, negative, mode):
        if mode == "Latent":
            return (model, positive, negative)  # in Latent mode, do nothing and leave it to latent couple

        self.negative_positive_masks = []
        self.negative_positive_conds = []

        new_positive = copy.deepcopy(positive)
        new_negative = copy.deepcopy(negative)

        dtype = model.model.diffusion_model.dtype
        device = comfy.model_management.get_torch_device()

        # collect the masks and conds of each conditioning into lists
        for conditions in [new_negative, new_positive]:
            conditions_masks = []
            conditions_conds = []
            if len(conditions) != 1:
                mask_norm = torch.stack([cond[1]["mask"].to(device, dtype=dtype) * cond[1]["mask_strength"] for cond in conditions])
                mask_norm = mask_norm / mask_norm.sum(dim=0)  # normalize so the masks sum to 1 per pixel (mask_strength has no effect where all other masks are 0)
                conditions_masks.extend([mask_norm[i] for i in range(mask_norm.shape[0])])
                conditions_conds.extend([cond[0].to(device, dtype=dtype) for cond in conditions])
                del conditions[0][1]["mask"]  # remove the mask keys so latent couple is not applied on top
                del conditions[0][1]["mask_strength"]
            else:
                conditions_masks = [False]
                conditions_conds = [conditions[0][0].to(device, dtype=dtype)]
            self.negative_positive_masks.append(conditions_masks)
            self.negative_positive_conds.append(conditions_conds)

        self.conditioning_length = (len(new_negative), len(new_positive))

        new_model = model.clone()
        self.sdxl = hasattr(new_model.model.diffusion_model, "label_emb")

        if not self.sdxl:
            for id in [1, 2, 4, 5, 7, 8]:  # ids of input_blocks that have cross attention
                set_model_patch_replace(new_model, self.make_patch(new_model.model.diffusion_model.input_blocks[id][1].transformer_blocks[0].attn2), ("input", id))
            set_model_patch_replace(new_model, self.make_patch(new_model.model.diffusion_model.middle_block[1].transformer_blocks[0].attn2), ("middle", 0))
            for id in [3, 4, 5, 6, 7, 8, 9, 10, 11]:  # ids of output_blocks that have cross attention
                set_model_patch_replace(new_model, self.make_patch(new_model.model.diffusion_model.output_blocks[id][1].transformer_blocks[0].attn2), ("output", id))
        else:
            for id in [4, 5, 7, 8]:  # ids of input_blocks that have cross attention
                block_indices = range(2) if id in [4, 5] else range(10)  # transformer_depth
                for index in block_indices:
                    set_model_patch_replace(new_model, self.make_patch(new_model.model.diffusion_model.input_blocks[id][1].transformer_blocks[index].attn2), ("input", id, index))
            for index in range(10):
                # the middle block is registered under block number 0 in the patch key
                set_model_patch_replace(new_model, self.make_patch(new_model.model.diffusion_model.middle_block[1].transformer_blocks[index].attn2), ("middle", 0, index))
            for id in range(6):  # ids of output_blocks that have cross attention
                block_indices = range(2) if id in [3, 4, 5] else range(10)  # transformer_depth
                for index in block_indices:
                    set_model_patch_replace(new_model, self.make_patch(new_model.model.diffusion_model.output_blocks[id][1].transformer_blocks[index].attn2), ("output", id, index))

        return (new_model, [new_positive[0]], [new_negative[0]])  # pooled output is ... left for later

    def make_patch(self, module):
        def patch(q, k, v, extra_options):
            len_neg, len_pos = self.conditioning_length  # number of negative / positive conds
            cond_or_uncond = extra_options["cond_or_uncond"]  # 0: cond, 1: uncond
            q_list = q.chunk(len(cond_or_uncond), dim=0)
            b = q_list[0].shape[0]  # batch_size

            masks_uncond = get_masks_from_q(self.negative_positive_masks[0], q_list[0], extra_options["original_shape"])
            masks_cond = get_masks_from_q(self.negative_positive_masks[1], q_list[0], extra_options["original_shape"])

            context_uncond = torch.cat([cond for cond in self.negative_positive_conds[0]], dim=0)
            context_cond = torch.cat([cond for cond in self.negative_positive_conds[1]], dim=0)

            k_uncond = module.to_k(context_uncond)
            k_cond = module.to_k(context_cond)
            v_uncond = module.to_v(context_uncond)
            v_cond = module.to_v(context_cond)

            out = []
            for i, c in enumerate(cond_or_uncond):
                if c == 0:
                    masks = masks_cond
                    k = k_cond
                    v = v_cond
                    length = len_pos
                else:
                    masks = masks_uncond
                    k = k_uncond
                    v = v_uncond
                    length = len_neg

                # run attention once per region: repeat q for every cond and pair it with that cond's k/v
                q_target = q_list[i].repeat(length, 1, 1)
                k = torch.cat([k[j].unsqueeze(0).repeat(b, 1, 1) for j in range(length)], dim=0)
                v = torch.cat([v[j].unsqueeze(0).repeat(b, 1, 1) for j in range(length)], dim=0)

                if k.dtype != q_target.dtype or v.dtype != q_target.dtype:
                    # ensure all dtypes match before calling the attention kernel
                    k = k.to(q_target.dtype)
                    v = v.to(q_target.dtype)

                qkv = optimized_attention(q_target, k, v, extra_options["n_heads"])
                qkv = qkv * masks  # weight each region's output by its normalized mask
                qkv = qkv.view(length, b, -1, module.heads * module.dim_head).sum(dim=0)
                out.append(qkv)

            out = torch.cat(out, dim=0)
            return out
        return patch
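
# In short, for each spatial token the patched cross attention computes
#     out = sum_i m_i * Attention(Q, K_i, V_i)
# where K_i and V_i come from the i-th prompt and the masks m_i are normalized to sum
# to 1 per pixel, so the final output is a per-region blend of the per-prompt
# attention results.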

NODE_CLASS_MAPPINGS = {
    "Attention couple": AttentionCouple
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "Attention couple": "Load Attention couple",
}