-
Notifications
You must be signed in to change notification settings - Fork 0
/
runexplain.py
264 lines (214 loc) · 10.5 KB
/
runexplain.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
# -*- coding: utf-8 -*-
# CLIP backbone used to build the heatmaps. Keep it identical to the model
# selected in clipga.py (or deliberately pick a different one to compare).
# Choices: 'ViT-B/32', 'ViT-B/16', 'ViT-L/14'
clipmodel = 'ViT-B/32'
# Everything below this point is plumbing; normally nothing needs editing.
import os
import sys
import glob
from PIL import Image
# Tell the user which stage of the pipeline is running.
print("\n\n 2. Creating CLIP heatmaps...This shouldn't take more than 30 seconds in total.\n")
def show_image_relevance(image_relevance, image, orig_image, img_path):
    """Blend a CLIP relevance map onto the preprocessed image.

    Args:
        image_relevance: tensor whose element count is a perfect square; it is
            reshaped to a ``dim x dim`` grid (presumably one value per ViT
            patch) and upsampled to the image's spatial size.
        image: preprocessed image batch; ``image[0]`` is taken as (C, H, W).
        orig_image: unused; accepted for call-site compatibility.
        img_path: unused; accepted for call-site compatibility.

    Returns:
        ``uint8`` array of shape (H, W, 3) in BGR channel order (suitable for
        ``cv2.imwrite``) with the JET heatmap blended over the image.

    NOTE(review): later ``show_image_relevance`` definitions in this module
    rebind the name, so this version is only reachable if those are removed.
    """
    def show_cam_on_image(img, mask):
        # applyColorMap returns BGR; convert to RGB so it matches ``img``.
        heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
        heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
        cam = np.float32(heatmap) / 255 + np.float32(img)
        return cam / np.max(cam)

    def min_max(arr):
        # Scale to [0, 1]; a constant array maps to zeros instead of NaN.
        lo = arr.min()
        span = arr.max() - lo
        if span == 0:
            return np.zeros_like(arr)
        return (arr - lo) / span

    height, width = image.shape[-2:]
    dim = int(image_relevance.numel() ** 0.5)
    relevance = image_relevance.reshape(1, 1, dim, dim)
    relevance = torch.nn.functional.interpolate(relevance, size=(height, width), mode='bilinear')
    # Device-agnostic host copy, so the CPU fallback device also works.
    relevance = min_max(relevance.reshape(height, width).detach().cpu().numpy())

    rgb = min_max(image[0].permute(1, 2, 0).detach().cpu().numpy())
    vis = np.uint8(255 * show_cam_on_image(rgb, relevance))
    # Everything is RGB up to here; OpenCV writers expect BGR.
    return cv2.cvtColor(vis, cv2.COLOR_RGB2BGR)
import torch
import CLIP.clip as clip
from PIL import Image
import numpy as np
import cv2
import matplotlib.pyplot as plt
from captum.attr import visualization
#@title Control context expansion (number of attention layers to consider)
#@title Number of layers for image Transformer
start_layer = -1#@param {type:"number"}
#@title Number of layers for text Transformer
start_layer_text = -1#@param {type:"number"}
# A value of -1 selects only the last attention block (see _accumulate_relevance).


def _accumulate_relevance(objective, attn_blocks, start_layer, batch_size, device):
    """Return the (batch_size, T, T) relevance matrix for one transformer tower.

    Starting from the identity, every block whose index is >= ``start_layer``
    contributes ``R = R + cam @ R``. Here ``cam`` is the gradient of
    ``objective`` w.r.t. the block's ``attn_probs``, multiplied by
    ``attn_probs``, clamped at zero and averaged over axis 1 (presumably the
    attention heads). ``start_layer == -1`` selects only the last block.
    """
    if start_layer == -1:
        start_layer = len(attn_blocks) - 1
    first_probs = attn_blocks[0].attn_probs
    num_tokens = first_probs.shape[-1]
    relevance = torch.eye(num_tokens, num_tokens, dtype=first_probs.dtype).to(device)
    relevance = relevance.unsqueeze(0).expand(batch_size, num_tokens, num_tokens)
    for idx, blk in enumerate(attn_blocks):
        if idx < start_layer:
            continue
        # retain_graph: the same objective is differentiated once per block.
        grad = torch.autograd.grad(objective, [blk.attn_probs], retain_graph=True)[0].detach()
        cam = blk.attn_probs.detach()
        tokens = cam.shape[-1]
        cam = cam.reshape(-1, tokens, tokens)
        grad = grad.reshape(-1, tokens, tokens)
        cam = (grad * cam).reshape(batch_size, -1, tokens, tokens)
        cam = cam.clamp(min=0).mean(dim=1)
        relevance = relevance + torch.bmm(cam, relevance)
    return relevance


def interpret(image, texts, model, device, start_layer=start_layer, start_layer_text=start_layer_text):
    """Compute gradient-weighted attention relevance for CLIP image/text pairs.

    The image is repeated once per text prompt and the model is run on the
    pairs. The sum of the matched (i, i) image-text logits is then
    differentiated with respect to every attention block's ``attn_probs``.

    Args:
        image: preprocessed image batch (presumably shape (1, C, H, W)); it is
            repeated ``texts.shape[0]`` times along dim 0.
        texts: tokenized prompts; ``texts.shape[0]`` is the batch size.
        model: CLIP model whose residual attention blocks expose
            ``attn_probs`` after a forward pass (NOTE(review): this requires a
            patched CLIP that records them; confirm against the local package).
        device: device on which the relevance matrices are built.
        start_layer: first image block to include; -1 uses only the last.
        start_layer_text: first text block to include; -1 uses only the last.

    Returns:
        ``(text_relevance, image_relevance)``:
        - ``text_relevance`` has shape (batch, T_text, T_text).
        - ``image_relevance`` is row 0 of the image relevance matrix without
          its first column, shape (batch, T_img - 1).
    """
    batch_size = texts.shape[0]
    images = image.repeat(batch_size, 1, 1, 1)
    logits_per_image, _logits_per_text = model(images, texts)

    # Mask selecting logit (i, i) of every row. It is float32 and lives on the
    # logits' own device rather than assuming CUDA, so CPU-only runs work.
    one_hot = torch.zeros(logits_per_image.shape, dtype=torch.float32, device=logits_per_image.device)
    one_hot[torch.arange(logits_per_image.shape[0]), torch.arange(batch_size)] = 1
    objective = torch.sum(one_hot * logits_per_image)
    model.zero_grad()

    image_attn_blocks = list(dict(model.visual.transformer.resblocks.named_children()).values())
    image_R = _accumulate_relevance(objective, image_attn_blocks, start_layer, batch_size, device)
    image_relevance = image_R[:, 0, 1:]

    text_attn_blocks = list(dict(model.transformer.resblocks.named_children()).values())
    text_relevance = _accumulate_relevance(objective, text_attn_blocks, start_layer_text, batch_size, device)
    return text_relevance, image_relevance
def show_image_relevance(image_relevance, image, orig_image):
    """Plot ``orig_image`` next to its CLIP relevance heatmap with matplotlib.

    Args:
        image_relevance: tensor whose element count is a perfect square; it is
            reshaped to a ``dim x dim`` grid (presumably one value per ViT
            patch) and upsampled to the image's spatial size.
        image: preprocessed image batch; ``image[0]`` is taken as (C, H, W).
        orig_image: image shown unmodified in the left subplot.

    NOTE(review): a later ``show_image_relevance`` definition in this module
    rebinds the name, so this plotting variant is currently unreachable.
    """
    def show_cam_on_image(img, mask):
        # applyColorMap returns BGR; matplotlib and ``img`` are RGB.
        heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
        heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
        cam = np.float32(heatmap) / 255 + np.float32(img)
        return cam / np.max(cam)

    def min_max(arr):
        # Scale to [0, 1]; a constant array maps to zeros instead of NaN.
        lo = arr.min()
        span = arr.max() - lo
        if span == 0:
            return np.zeros_like(arr)
        return (arr - lo) / span

    fig, axs = plt.subplots(1, 2)
    axs[0].imshow(orig_image)
    axs[0].axis('off')

    height, width = image.shape[-2:]
    dim = int(image_relevance.numel() ** 0.5)
    relevance = image_relevance.reshape(1, 1, dim, dim)
    relevance = torch.nn.functional.interpolate(relevance, size=(height, width), mode='bilinear')
    # Device-agnostic host copy, so the CPU fallback device also works.
    relevance = min_max(relevance.reshape(height, width).detach().cpu().numpy())

    rgb = min_max(image[0].permute(1, 2, 0).detach().cpu().numpy())
    vis = np.uint8(255 * show_cam_on_image(rgb, relevance))
    # Already RGB, which is what imshow expects; no channel swap needed.
    axs[1].imshow(vis)
    axs[1].axis('off')
from CLIP.clip.simple_tokenizer import SimpleTokenizer as _Tokenizer
# Module-level BPE tokenizer, used to recover the individual token strings.
_tokenizer = _Tokenizer()


def show_heatmap_on_text(text, text_encoding, R_text, show=False):
    """Build, and optionally render, a captum word-relevance record for ``text``.

    Args:
        text: the raw prompt string.
        text_encoding: 1-D tensor of token ids for ``text``. The position of
            its largest id is used as the end index (presumably CLIP's
            end-of-text token, which has the highest id).
        R_text: square text relevance matrix for this prompt, e.g. one batch
            entry of the ``text_relevance`` returned by ``interpret``.
        show: when True, render the record with
            ``visualization.visualize_text``.

    Returns:
        A one-element list holding the ``VisualizationDataRecord``.
    """
    end_idx = text_encoding.argmax(dim=-1)
    # Row of the end token, restricted to the positions strictly between the
    # start token (index 0) and the end token.
    scores = R_text[end_idx, 1:end_idx]
    total = scores.sum()
    # An all-zero row would otherwise produce NaN scores.
    if total != 0:
        scores = scores / total
    scores = scores.flatten()

    token_ids = _tokenizer.encode(text)
    token_strings = [_tokenizer.decode([tok]) for tok in token_ids]
    records = [visualization.VisualizationDataRecord(scores, 0, 0, 0, 0, 0, token_strings, 1)]
    if show:
        visualization.visualize_text(records)
    return records
# Checkpoint download locations for the ViT backbones this script supports.
# NOTE: this replaces CLIP's whole mapping, so only these names can be loaded.
clip.clip._MODELS = dict([
    ("ViT-B/32", "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"),
    ("ViT-B/16", "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt"),
    ("ViT-L/14", "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt"),
])

# Use the GPU when one is visible to torch, otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# jit=False: presumably needed so interpret() can reach the Python attention
# blocks and their attn_probs (TODO confirm against the local CLIP package).
model, preprocess = clip.load(clipmodel, device=device, jit=False)
def show_image_relevance(image_relevance, image, orig_image, img_path):
    """Blend a CLIP relevance map onto the preprocessed image for saving.

    Args:
        image_relevance: tensor whose element count is a perfect square; it is
            reshaped to a ``dim x dim`` grid (presumably one value per ViT
            patch, as produced by ``interpret``) and upsampled to the image's
            spatial size.
        image: preprocessed image batch; ``image[0]`` is taken as (C, H, W).
        orig_image: unused; accepted for call-site compatibility.
        img_path: unused; accepted for call-site compatibility.

    Returns:
        ``uint8`` array of shape (H, W, 3) in BGR channel order, ready for
        ``cv2.imwrite``. It contains the JET heatmap (red = most relevant)
        blended over the image.
    """
    def show_cam_on_image(img, mask):
        # applyColorMap returns BGR; convert to RGB so it matches ``img``.
        # Without this the saved files have the colormap's red and blue swapped.
        heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
        heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
        cam = np.float32(heatmap) / 255 + np.float32(img)
        return cam / np.max(cam)

    def min_max(arr):
        # Scale to [0, 1]; a constant array maps to zeros instead of NaN.
        lo = arr.min()
        span = arr.max() - lo
        if span == 0:
            return np.zeros_like(arr)
        return (arr - lo) / span

    height, width = image.shape[-2:]
    dim = int(image_relevance.numel() ** 0.5)
    relevance = image_relevance.reshape(1, 1, dim, dim)
    relevance = torch.nn.functional.interpolate(relevance, size=(height, width), mode='bilinear')
    # Device-agnostic host copy, so the CPU fallback device also works.
    relevance = min_max(relevance.reshape(height, width).detach().cpu().numpy())

    rgb = min_max(image[0].permute(1, 2, 0).detach().cpu().numpy())
    vis = np.uint8(255 * show_cam_on_image(rgb, relevance))
    # Everything is RGB up to here; cv2.imwrite expects BGR.
    return cv2.cvtColor(vis, cv2.COLOR_RGB2BGR)
class color:
    """ANSI escape sequences for styling terminal output.

    Concatenate around text, e.g. ``color.BOLD + "msg" + color.END``.
    ``END`` resets all styling.
    """

    # Foreground colours
    DARKCYAN = '\033[36m'
    RED = '\033[91m'
    GREEN = '\033[92m'
    YELLOW = '\033[93m'
    BLUE = '\033[94m'
    PURPLE = '\033[95m'
    CYAN = '\033[96m'

    # Text attributes
    BOLD = '\033[1m'
    UNDERLINE = '\033[4m'

    # Reset
    END = '\033[0m'
# Folder layout, relative to the working directory:
#   image_folder   - input images; normalized to PNG in place below
#   token_folder   - per-image token lists named tokens_<image name>.txt
#   heatmap_folder - output heatmaps named <image name>_<token>.png
image_folder = "IMG_IN"
token_folder = "TOK"
heatmap_folder = "VIS"

# Re-save every readable image in image_folder as an RGBA PNG with the same
# stem, then delete the non-PNG original.
# NOTE(review): if both "a.jpg" and "a.png" exist, the converted JPEG
# overwrites the existing PNG.
for file in os.listdir(image_folder):
    file_path = os.path.join(image_folder, file)
    new_file_path = os.path.join(image_folder, os.path.splitext(file)[0] + '.png')
    try:
        # The context manager releases the source file before it may be
        # deleted below (an open handle blocks os.remove on Windows).
        with Image.open(file_path) as src:
            img = src.convert('RGBA')  # Convert to RGBA format for PNG
    except IOError:
        # Pillow could not read this entry as an image.
        print(f"{file} is not a valid image.")
        continue
    try:
        img.save(new_file_path, "PNG")
        # Remove the source only if it is a different file from the PNG just
        # written. On case-insensitive filesystems "x.PNG" and "x.png" are the
        # same file, and deleting it would lose the image.
        if not os.path.samefile(file_path, new_file_path):
            os.remove(file_path)
    except IOError as err:
        print(f"Could not convert {file} to PNG: {err}")
# Write one heatmap per (image, token) pair into heatmap_folder.
# cv2.imwrite does not raise when the target folder is missing, so create it.
os.makedirs(heatmap_folder, exist_ok=True)

for img_file in sorted(glob.glob(f"{image_folder}/*.png")):
    # Strip the folder and extension to get the image name.
    img_name = os.path.basename(os.path.splitext(img_file)[0])
    token_file = f"{token_folder}/tokens_{img_name}.txt"
    try:
        with open(token_file, 'r') as f:
            tokens = f.read().split()
    except FileNotFoundError:
        print(f"No token file {token_file} for {img_file}; skipping.")
        continue

    # Open the image once; it stays open while its tokens are processed.
    with Image.open(img_file) as orig_image:
        img = preprocess(orig_image).unsqueeze(0).to(device)
        print(f"Processing {img_file} tokens...")
        for token in tokens:
            texts = [token]
            text = clip.tokenize(texts).to(device)
            R_text, R_image = interpret(model=model, image=img, texts=text, device=device)
            for i in range(text.shape[0]):
                show_heatmap_on_text(texts[i], text[i], R_text[i])
                vis = show_image_relevance(R_image[i], img, orig_image=orig_image, img_path=img_file)
                # Save the heatmap with the token in the filename.
                heatmap_filename = f"{heatmap_folder}/{img_name}_{token}.png"
                if not cv2.imwrite(heatmap_filename, vis):
                    print(f"Failed to write {heatmap_filename}")