forked from lcicek/Critic-VAE
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathvae_utility.py
476 lines (366 loc) · 15.2 KB
/
vae_utility.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
from io import BytesIO
import os
#import minerl
import statistics
import torch
from torch import Tensor
import torch.utils
import torch.distributions
import numpy as np
from PIL import Image, ImageDraw, ImageFont
from collections import defaultdict
import denseCRF
from vae_parameters import *
from vae_nets import *
# Default cutoff applied to difference masks after they were scaled to 0-255
# (see get_diff_and_thr_masks); also embedded in the video file name.
THRESHOLD = 50
#font = ImageFont.truetype("/usr/share/fonts/truetype/ubuntu/Ubuntu-R.ttf", 10)
# Pillow's built-in default font, used for every text overlay in get_final_frame.
font = ImageFont.load_default()
# Captions drawn by get_final_frame, one per column, in left-to-right order.
titles = ["orig img\n+crit val", "crit val\ninjected", "crit=0\ninjected", "difference\nmask", f"thr-mask\nthr={THRESHOLD}", "thr-mask +\ncrf", "ground\ntruth"]
# copied from critic-code, source and github link are listed in bachelor thesis
def crf(imgs, mask, Y, skip=1):
    """Refine thresholded masks with a dense CRF and return them as booleans.

    Args:
        imgs: Sequence of images; every ``skip``-th entry is passed to
            ``denseCRF.densecrf`` together with the matching mask frame.
        mask: Array indexed as ``mask[i, 0]`` per frame (presumably
            ``(N, 1, H, W)`` with values in [0, 1] -- TODO confirm against
            the caller). The array is copied, the caller's data is untouched.
        Y: Boolean ground truth. It is only used to compute an IoU score per
            parameter set and must be compatible with the refined masks in
            ``(N, H, W, 1)`` layout.
        skip: Stride over the frames; frames that are skipped keep their
            incoming mask values.

    Returns:
        Boolean array with the shape of ``mask``; True where the refined
        mask is >= 1.
    """
    mask = mask.copy()

    # Parameter grid. Every list holds a single value right now, so exactly
    # one parameter combination is evaluated.
    w1 = [22]       # weight of bilateral term
    alpha = [12]    # spatial std
    beta = [3.1]    # rgb std
    w2 = [8]        # weight of spatial term
    gamma = [1.8]   # spatial std
    it = [10]       # iteration

    res = []
    params = []
    for param in [(a, b, c, d, e, i) for a in w1 for b in alpha for c in beta for d in w2 for e in gamma for i in it]:
        # Basic slicing yields a view, so the per-frame writes below land in
        # the local copy of ``mask``.
        M = mask[::skip]
        for i, img in enumerate(imgs[::skip]):
            maskframe = M[i, 0]
            # Two-class probability map: background first, foreground second.
            prob = np.stack((1 - maskframe, maskframe), axis=-1)
            seg = denseCRF.densecrf(img, prob, param)
            M[i, 0] = seg
        # ``np.bool`` was removed in NumPy 1.24; the builtin ``bool`` selects
        # the same dtype on every NumPy version.
        M = M.transpose(0, 2, 3, 1).astype(bool)
        r = np.sum(Y[::skip] & M) / np.sum(Y[::skip] | M)
        res.append(r)
        params.append(param)

    # Rank the parameter sets by IoU. NOTE(review): the ranking is not used
    # for the returned mask, which always stems from the last parameter set.
    res = np.array(res)
    order = np.argsort(res)
    res = res[order]
    params = np.array(params)[order]

    mask[::skip] = M.transpose(0, 3, 1, 2)
    return (mask >= 1)
def get_iou(G, T):
    """Return the intersection-over-union of two boolean masks, rounded to 3 digits.

    Args:
        G: Boolean array (the call sites pass the ground truth or a mask).
        T: Boolean array of a shape compatible with ``G``.

    Returns:
        ``1`` when neither mask contains a set pixel (nothing to classify
        counts as a perfect match), otherwise the rounded IoU.
    """
    intersection = np.sum(G & T)                  # set in both masks
    only_in_g = np.sum(G & np.logical_not(T))     # set in G, missing in T
    only_in_t = np.sum(np.logical_not(G) & T)     # set in T, missing in G
    union = intersection + only_in_g + only_in_t

    # Guard clause: avoids dividing by zero for two empty masks.
    if union == 0:
        return 1
    return round(intersection / union, 3)
def load_textured_minerl():
    """Load the textured frames and their ground-truth masks from disk.

    Reads ``X.npy`` and ``Y.npy`` from ``MINERL_EPISODE_PATH`` and keeps
    every second sample in the index range [100, 5000).

    Returns:
        Tuple ``(frames, ground_truth)``. A ground-truth pixel is True only
        if all entries along the last axis of ``Y.npy`` are truthy; singleton
        dimensions are squeezed away.
    """
    selection = slice(100, 5000, 2)

    frames = np.load(MINERL_EPISODE_PATH + "X.npy")  # values are not rescaled here
    labels = np.load(MINERL_EPISODE_PATH + "Y.npy")

    # Collapse the last axis into a single boolean channel.
    ground_truth = np.expand_dims(np.all(labels, axis=-1), axis=-1)

    frames = frames[selection]
    # Move the channel axis to position 1, then drop every singleton axis.
    ground_truth = ground_truth[selection].transpose(0, 3, 1, 2).squeeze()
    return frames, ground_truth
# source: https://github.com/python-pillow/Pillow/issues/4263
def create_video(frames):
    """Write ``frames`` as an endlessly looping GIF into ``VIDEO_PATH``.

    The file is called ``video-threshold=<THRESHOLD>.gif`` and shows each
    frame for 100 ms.

    Args:
        frames: Iterable of PIL images. Nothing is written when it is empty.
    """
    frames = list(frames)
    if not frames:
        # Without this guard the first-frame lookup below raises IndexError.
        print('no frames given, skipping video creation')
        return

    print('creating video...')
    # Creates missing parent directories too and does not fail when the
    # directory already exists.
    os.makedirs(VIDEO_PATH, exist_ok=True)

    # Round-trip every frame through an in-memory GIF so that all frames are
    # GIF images before they are appended to one another.
    imgs = []
    for frame in frames:
        buffer = BytesIO()
        frame.save(buffer, format="GIF")
        imgs.append(Image.open(buffer))

    # os.path.join also works when VIDEO_PATH has no trailing separator.
    target = os.path.join(VIDEO_PATH, f"video-threshold={THRESHOLD}.gif")
    imgs[0].save(target, format='GIF', duration=100, save_all=True, loop=0, append_images=imgs[1:])
def get_diff_factor(max_values):
    """Return ``(scale, mean_max)`` for normalising difference images.

    Args:
        max_values: Non-empty collection of per-image maxima.

    Returns:
        ``mean_max`` is the arithmetic mean of ``max_values`` and ``scale``
        its reciprocal; ``scale`` is 0 when the mean is 0.
    """
    mean_max = statistics.mean(max_values)
    if mean_max == 0:
        # Nothing to scale; also avoids a division by zero.
        return 0, mean_max
    return 1.0 / mean_max, mean_max
def save_bin_info_file(bin_ious, bin_frames, bin_gts):
    """Write per-bin statistics to ``bin_info_vae1.txt`` in the working directory.

    Args:
        bin_ious: Maps a bin to the list of IoU values of its frames.
        bin_frames: Maps a bin to its number of frames.
        bin_gts: Maps a bin to its number of ground-truth pixels.

    The report lists, per bin, the share of ground-truth pixels, the share
    of frames, and the mean and sample standard deviation of the IoU values.
    """
    total_gt = np.sum(list(bin_gts.values()))
    # Derive the frame total from the data instead of assuming a fixed
    # number of frames per run.
    total_frames = sum(bin_frames.values())

    def _percent(part, whole):
        # A zero total yields 0 % rather than a division error / NaN.
        return 100 * part / whole if whole else 0.0

    with open('bin_info_vae1.txt', 'w') as f:
        f.write('ground truth pixels sorted by bin:\n')
        for value_bin, count in bin_gts.items():
            f.write(f'bin: {value_bin}, pixels = {count} = {_percent(count, total_gt):.1f}%\n')

        f.write('\nframes separated by bin:\n')
        for value_bin, count in bin_frames.items():
            f.write(f'bin: {value_bin}, frames = {count} = {_percent(count, total_frames):.1f}%\n')

        f.write('\niou-mean and std:\n')
        for value_bin, ious in bin_ious.items():
            mean = round(statistics.mean(ious), 2)
            # statistics.stdev needs at least two data points; a bin with a
            # single frame has no spread.
            std = round(statistics.stdev(ious), 2) if len(ious) > 1 else 0.0
            f.write(f'bin: {value_bin}, iou_mean={mean}, iou_std={std}\n')
def save_bin_info(preds, gt, thr_masks):
    """Group frames by their rounded critic value and write the statistics file.

    Args:
        preds: Per-frame predictions; each must provide ``.item()``.
        gt: Per-frame ground-truth masks.
        thr_masks: Per-frame thresholded masks.

    Every prediction is rounded to one decimal to form its bin. For each bin
    the IoU values (thresholded mask vs. ground truth), the frame count and
    the number of ground-truth pixels are accumulated and then passed to
    ``save_bin_info_file``.
    """
    ious_per_bin = defaultdict(list)
    frames_per_bin = defaultdict(int)
    gt_pixels_per_bin = defaultdict(int)

    for idx, pred in enumerate(preds):
        key = round(pred.item(), 1)
        ious_per_bin[key].append(get_iou(thr_masks[idx], gt[idx]))
        frames_per_bin[key] += 1
        gt_pixels_per_bin[key] += gt[idx].sum()

    save_bin_info_file(ious_per_bin, frames_per_bin, gt_pixels_per_bin)
def get_diff_and_thr_masks(diff_masks, max_values, thr=THRESHOLD):
    """Normalise difference images to uint8 and threshold them.

    Args:
        diff_masks: List of difference images; its entries are replaced in
            place by their uint8 versions.
        max_values: Per-image maxima from which the common scale is derived.
        thr: Values strictly above this cutoff (0-255 scale) are foreground.

    Returns:
        Tuple ``(diff_masks, thr_masks)`` as numpy arrays: the scaled uint8
        difference images and the boolean threshold masks.
    """
    scale, mean_max = get_diff_factor(max_values)

    binary_masks = []
    for idx, raw_diff in enumerate(diff_masks):
        normalised = prepare_diff(raw_diff, scale, mean_max)
        as_uint8 = (normalised * 255).astype(np.uint8)
        binary_masks.append(as_uint8 > thr)
        diff_masks[idx] = as_uint8  # write back into the caller's list

    return np.array(diff_masks), np.array(binary_masks)
def eval_textured_frames(trajectory, vae, critic, gt, t=THRESHOLD):
    """Evaluate a trajectory and compose one overview image per frame.

    For every image the critic value is predicted, the VAE reconstructions
    with the critic value and with zero are computed, and their difference
    is thresholded and refined with the CRF. Both mask variants are scored
    against ``gt`` over the whole trajectory.

    Args:
        trajectory: Array of raw images (it is indexed with ``np.newaxis``).
        vae: Model passed on to ``get_diff_image``.
        critic: Model providing ``evaluate(frame)``.
        gt: Ground-truth masks, one per image.
        t: Threshold on the 0-255 difference mask.

    Returns:
        Tuple ``(frames, thr_iou, crf_iou)`` where ``frames`` is the list of
        composed PIL images.

    Side effects:
        Prints progress and writes the per-bin statistics via ``save_bin_info``.
    """
    print('processing frames...')

    frames, preds = [], []
    one_recons, zero_recons = [], []
    raw_diffs, max_values = [], []  # differences are not normalised yet

    for image in trajectory:
        frame = preprocess_observation(image)
        pred = critic.evaluate(frame)[0]
        recon_one, recon_zero, diff, max_value = get_diff_image(vae, frame, pred)

        frames.append(frame)
        preds.append(pred)
        one_recons.append(recon_one)
        zero_recons.append(recon_zero)
        raw_diffs.append(diff)
        max_values.append(max_value)

    diff_masks, thr_masks = get_diff_and_thr_masks(raw_diffs, max_values, thr=t)
    thr_iou = get_iou(gt, thr_masks)

    # The CRF helper expects an extra axis on each of its inputs.
    crf_masks = crf(
        trajectory[:, np.newaxis, ...],
        np.array(thr_masks)[:, np.newaxis, ...].astype(np.float32),
        gt[..., np.newaxis],
    ).squeeze()
    crf_iou = get_iou(gt, crf_masks)

    ret = []
    for idx, frame in enumerate(frames):
        ret.append(get_final_frame(
            frame,
            one_recons[idx],
            zero_recons[idx],
            Image.fromarray(diff_masks[idx]),
            preds[idx],
            gt_img=Image.fromarray(gt[idx]),
            thr_img=Image.fromarray(thr_masks[idx]),
            crf_img=Image.fromarray(crf_masks[idx]),
            thr_iou=thr_iou,
            crf_iou=crf_iou,
        ))

    save_bin_info(preds, gt, thr_masks)
    return ret, thr_iou, crf_iou
def collect_frames(trajectory_names):
    """Collect the first 1000 preprocessed frames of each named trajectory.

    Args:
        trajectory_names: Names understood by the MineRLTreechop-v0 dataset.

    Returns:
        List with one entry per trajectory that reached 1000 frames; each
        entry is the list of tensors returned by ``preprocess_observation``.
        Trajectories that end earlier are left out.
    """
    print('collecting frames...')
    import os
    import minerl

    max_steps = 1000
    os.environ['MINERL_DATA_ROOT'] = MINERL_DATA_ROOT_PATH
    data = minerl.data.make('MineRLTreechop-v0', num_workers=1)

    collected = []
    for name in trajectory_names:
        episode = []
        steps = data.load_data(name, skip_interval=0, include_metadata=False)
        for observation, _, _, _, _ in steps:
            episode.append(preprocess_observation(observation["pov"]))  # tensor
            if len(episode) < max_steps:
                continue
            # Enough frames gathered: keep this trajectory, go to the next one.
            collected.append(episode)
            break

    del data
    return collected
def get_injected_img(autoencoder, img_tensor, pred):
    """Build one image strip: the input followed by ``inject_n`` injected reconstructions.

    Args:
        autoencoder: Model providing ``evaluate`` and ``inject``.
        img_tensor: Input image tensor (reshaped to ``(-1, ch, h, w)``).
        pred: Critic value handed to ``autoencoder.evaluate``.

    Returns:
        PIL image with all panels placed side by side.
    """
    # The plain reconstruction is computed as before, although its result is
    # currently not shown in the strip.
    orig_recon = autoencoder.evaluate(img_tensor, pred)
    recons = autoencoder.inject(img_tensor)

    input_panel = to_np(img_tensor.view(-1, ch, h, w)[0])
    recon_panels = [to_np(recons[idx].view(-1, ch, h, w)[0]) for idx in range(inject_n)]

    # Axis 2 is the last axis of the (ch, h, w) panels, so the panels end up
    # next to each other.
    recon_strip = np.concatenate(recon_panels, axis=2)
    strip = np.concatenate((input_panel, recon_strip), axis=2)

    _, img = prepare_rgb_image(strip)
    return img
def get_diff_image(autoencoder, img_tensor, pred, one=False):
    """Reconstruct an image with a high and a zero critic value and diff the results.

    Args:
        autoencoder: Model providing ``evaluate(img_tensor, critic_value)``.
        img_tensor: Input image tensor.
        pred: Critic value used for the "high" reconstruction.
        one: If True, the high reconstruction uses 1 instead of ``pred``.

    Returns:
        Tuple ``(recon_one, recon_zero, diff, max_value)``: both
        reconstructions as ``(ch, h, w)`` numpy arrays, the absolute
        difference converted to greyscale, and the maximum of that
        difference. ``diff`` is not normalised here.
    """
    if one:
        high_tensor = torch.ones(1).to(device)
    else:
        high_tensor = torch.zeros(1).to(device) + pred
    low_tensor = torch.zeros(1).to(device)

    # The per-frame debug print of the reconstruction shape was removed; it
    # produced one output line for every processed frame.
    recon_one = autoencoder.evaluate(img_tensor, high_tensor)
    recon_zero = autoencoder.evaluate(img_tensor, low_tensor)

    recon_one = to_np(recon_one.view(-1, ch, h, w)[0])
    recon_zero = to_np(recon_zero.view(-1, ch, h, w)[0])

    diff = np.abs(np.subtract(recon_zero, recon_one))
    diff = np.transpose(diff, (1, 2, 0))  # CHW to HWC
    # Weighted sum of the first three channels: RGB to greyscale.
    diff = np.dot(diff[..., :3], [0.2989, 0.5870, 0.1140])
    max_value = np.amax(diff)

    return recon_one, recon_zero, diff, max_value
def prepare_diff(diff_img, diff_factor, mean_max):
    """Clip a difference image at ``mean_max`` and scale it by ``diff_factor``.

    Args:
        diff_img: Numeric array; it is left unmodified.
        diff_factor: Multiplier applied after clipping.
        mean_max: Upper bound; larger values are replaced by it.

    Returns:
        New array holding the clipped and scaled values.
    """
    # Work on a copy so that the caller's array is not clipped as a hidden
    # side effect.
    clipped = np.array(diff_img)
    clipped[clipped > mean_max] = mean_max
    return clipped * diff_factor
def get_final_frame(img_tensor, recon_one, recon_zero, diff_img, pred, gt_img=None, thr_img=None, crf_img=None, thr_iou=None, crf_iou=None):
    """Compose the overview image for a single frame.

    Without ``gt_img`` the result holds four panels: input, both
    reconstructions and the difference mask. With ``gt_img`` three mask
    panels (threshold, CRF, ground truth) are appended and a caption row of
    height ``w`` is added on top.

    Args:
        img_tensor: Input image tensor.
        recon_one, recon_zero: Reconstructions as ``(ch, h, w)`` arrays.
        diff_img: PIL image of the difference mask.
        pred: Critic value; must provide ``.item()``.
        gt_img, thr_img, crf_img: Optional PIL mask images.
        thr_iou, crf_iou: IoU values appended to the matching captions.

    Returns:
        The composed PIL image in RGB mode.
    """
    white = (255, 255, 255)

    strip = np.array(np.concatenate((
        to_np(img_tensor.view(-1, ch, h, w)[0]),
        recon_one,
        recon_zero,
    ), axis=2))
    _, strip_img = prepare_rgb_image(strip)

    with_masks = gt_img is not None
    if with_masks:
        panel_count, canvas_height, top = 7, w * 2, w
    else:
        panel_count, canvas_height, top = 4, w, 0  # no caption row

    canvas = Image.new('RGB', (w * panel_count, canvas_height))
    draw = ImageDraw.Draw(canvas)

    canvas.paste(strip_img, (0, top))
    canvas.paste(diff_img, (w * 3, top))

    if with_masks:
        for column, panel in ((4, thr_img), (5, crf_img), (6, gt_img)):
            canvas.paste(panel, (w * column, top))

        # Columns 4 and 5 additionally show their IoU in the caption.
        iou_by_column = {4: thr_iou, 5: crf_iou}
        for column, title in enumerate(titles):
            if column in iou_by_column:
                title += f"\niou={iou_by_column[column]}"
            draw.text((w * column + 2, 0), title, white, font=font)

    # Critic value in the top-left corner of the input panel.
    draw.text((2, top + 2), f'{pred.item():.1f}', white, font=font)
    return canvas
def adjust_values(obs):
    """Return ``obs`` as a new float32 array divided by 255.

    Args:
        obs: Array-like of numbers (0-255 input ends up in the range 0-1).

    Returns:
        float32 numpy array; the input is not modified.
    """
    as_float = np.asarray(obs).astype(np.float32)  # astype yields a copy
    return as_float / 255
def reverse_preprocess(recon):
    """Convert a reconstruction tensor back into a uint8 image array.

    Args:
        recon: Tensor that can be viewed as ``(-1, ch, h, w)``; only the
            first image of the batch is converted.

    Returns:
        ``(h, w, ch)`` uint8 numpy array, values scaled by 255.
    """
    # Use (ch, h, w) like every other reshape in this module; the previous
    # (ch, w, w) was only correct for square images.
    recon = to_np(recon.view(-1, ch, h, w)[0])
    recon = recon.transpose(1, 2, 0)  # from CHW to HWC
    recon = (recon * 255).astype(np.uint8)
    return recon
def preprocess_observation(obs):
    """Turn a raw HWC image into a BCHW tensor on ``device``.

    Args:
        obs: Image in height-width-channel layout.

    Returns:
        Tensor with a leading batch axis of size 1, values divided by 255
        (see ``adjust_values``).
    """
    channels_first = adjust_values(obs).transpose(2, 0, 1)  # HWC to CHW for critic
    batched = channels_first[np.newaxis, ...]               # batch_size = 1 -> BCHW
    return Tensor(batched).to(device)
def load_vae_network(vae, second_vae=False):
    """Load encoder/decoder weights into ``vae`` and switch it to eval mode.

    Args:
        vae: Model with ``encoder`` and ``decoder`` sub-modules.
        second_vae: If True, the ``SECOND_*`` checkpoint paths are used.

    A failing load does not raise: a warning is printed and the model keeps
    its current weights. Eval mode is set in either case.
    """
    if second_vae:
        enc_path, dec_path = SECOND_ENCODER_PATH, SECOND_DECODER_PATH
    else:
        enc_path, dec_path = ENCODER_PATH, DECODER_PATH

    # Same device selection as load_critic, so checkpoints saved on a GPU
    # can also be read on a CPU-only machine.
    target = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    try:
        # Read both files before touching the model, so an unreadable
        # decoder checkpoint cannot leave only the encoder updated.
        encoder_state = torch.load(enc_path, map_location=target)
        decoder_state = torch.load(dec_path, map_location=target)
        vae.encoder.load_state_dict(encoder_state)
        vae.decoder.load_state_dict(decoder_state)
    except Exception as e:
        # Deliberately best-effort, but say clearly what happened instead of
        # printing the bare exception.
        print(f'WARNING: could not load VAE weights from {enc_path} / {dec_path}: {e}')

    vae.eval()
    vae.encoder.eval()
    vae.decoder.eval()
def load_critic(path, crafter=False):
    """Create a critic, load its weights from ``path`` and return it in eval mode.

    Args:
        path: Checkpoint file holding the critic's state dict.
        crafter: If True, the ``Critic`` class of
            ``crafter_extension_critic_model`` is used instead of the one
            from ``critic_net``.

    Returns:
        The critic, moved to CUDA when available and to the CPU otherwise.
    """
    if crafter:
        from crafter_extension_critic_model import Critic
    else:
        from critic_net import Critic

    critic = Critic()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    state = torch.load(path, map_location=device)
    critic.load_state_dict(state)
    critic.eval()
    critic.to(device)
    return critic
def log_info(losses, logger, batch_i, ep, num_samples):
    """Send the three loss values of a batch to ``logger.scalar_summary``.

    Args:
        losses: Mapping with the keys ``recon_loss``, ``KLD`` and
            ``total_loss``; every value must provide ``.item()``.
        logger: Object with ``scalar_summary(tag, value, step)``.
        batch_i: Index of the batch within the epoch.
        ep: Epoch index.
        num_samples: Step count per epoch; the global step is
            ``batch_i + num_samples * ep``.
    """
    # Read all values first, so a missing key fails before anything is logged.
    scalars = [
        ('recon_loss', losses['recon_loss'].item()),
        ('kld', losses['KLD'].item()),
        ('total_loss', losses['total_loss'].item()),
    ]

    for tag, value in scalars:
        logger.scalar_summary(tag, value, batch_i + (num_samples * ep))
def to_np(x):
    """Return the values of tensor ``x`` as a numpy array (via the CPU)."""
    values = x.data       # tensor without autograd history
    on_cpu = values.cpu()
    return on_cpu.numpy()
def prepare_rgb_image(img_array):  # numpy_array
    """Convert a CHW array into a uint8 HWC array and the matching PIL image.

    Args:
        img_array: Numpy array in channel-height-width layout; values are
            multiplied by 255 (so presumably in the range 0-1 -- TODO confirm).

    Returns:
        Tuple ``(array, image)``: the uint8 HWC array and an RGB PIL image
        created from it.
    """
    hwc = np.transpose(img_array, (1, 2, 0))  # CHW to HWC
    pixels = (hwc * 255).astype(np.uint8)
    return pixels, Image.fromarray(pixels, mode='RGB')
# source: https://github.com/KarolisRam/MineRL2021-Research-baselines/blob/main/standalone/Behavioural_cloning.py#L105
def load_minerl_data(critic, recon_dset=False, vae=None):
    """Collect MineRL frames, balanced by their critic value.

    The trajectories of MineRLTreechop-v0 are visited in a fixed shuffled
    order (seed 0) until ``total_images`` samples were gathered. Within each
    trajectory at most 150 frames are taken per critic range:
    low (<= 0.25), mid (0.4 to 0.6) and high (>= 0.7).

    Args:
        critic: Model providing ``evaluate(obs)``.
        recon_dset: If True, VAE reconstructions are stored instead of the
            frames themselves: the reconstruction with the critic value for
            high frames, the one with zero for low frames, and both for mid
            frames.
        vae: Model providing ``evaluate(obs, critic_value)``; required when
            ``recon_dset`` is True.

    Returns:
        List of numpy arrays.

    Raises:
        ValueError: If ``recon_dset`` is True but no ``vae`` is given.
    """
    if recon_dset and vae is None:
        # Fail before the dataset is opened rather than on the first frame.
        raise ValueError("recon_dset=True requires a vae")

    print("loading minerl-data...")

    ### Initialize mineRL dataset ###
    # ``data`` has to be created here: without it the first use below fails.
    # minerl is imported lazily (as in collect_frames) so that this module
    # can still be imported when minerl is not installed.
    import minerl

    os.environ['MINERL_DATA_ROOT'] = MINERL_DATA_ROOT_PATH
    data = minerl.data.make('MineRLTreechop-v0', num_workers=1)

    trajectory_names = data.get_trajectory_names()
    rng = np.random.default_rng(seed=0)
    rng.shuffle(trajectory_names)

    collect = 150  # per-trajectory cap for each critic range
    dset = []

    # Add trajectories to the data until we reach the required number of samples.
    for trajectory_name in trajectory_names:
        if len(dset) >= total_images:  # total_images defined in vae_parameters.py
            break

        print(f'total images = {len(dset)}')

        c_high = 0
        c_mid = 0
        c_low = 0

        trajectory = data.load_data(trajectory_name, skip_interval=0, include_metadata=False)
        for dataset_observation, _, _, _, _ in trajectory:
            obs = preprocess_observation(dataset_observation["pov"])
            pred = critic.evaluate(obs)[0]

            # Decide which samples each critic range would contribute.
            if recon_dset:
                obs_pred = vae.evaluate(obs, torch.zeros(1).to(device) + pred)
                obs_low = vae.evaluate(obs, torch.zeros(1).to(device))
                obs_pred = obs_pred.detach().cpu().numpy()
                obs_low = obs_low.detach().cpu().numpy()
                mid_samples = [obs_pred, obs_low]
                high_sample = obs_pred
                low_sample = obs_low
            else:
                obs = obs.detach().cpu().numpy()
                mid_samples = [obs]
                high_sample = obs
                low_sample = obs

            if c_high >= collect and c_low >= collect and c_mid >= collect:
                break
            elif 0.4 <= pred <= 0.6 and c_mid < collect:
                dset.extend(mid_samples)
                c_mid += 1
            elif pred >= 0.7 and c_high < collect:
                dset.append(high_sample)
                c_high += 1
            elif pred <= 0.25 and c_low < collect:
                dset.append(low_sample)
                c_low += 1

    del data  # without this line, error gets thrown at the end of the program
    return dset