-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy path_counterfactual.py
121 lines (104 loc) · 4.92 KB
/
_counterfactual.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
# imports
from utils import *
from skimage.metrics import mean_squared_error
from skimage.segmentation import mark_boundaries
# Helper functions
def get_segments(img, height, width, imgheight, imgwidth):
    """Tile an ``imgwidth`` x ``imgheight`` area into a grid of boxes.

    The grid is walked band by band (offset ``i`` stepping through
    ``imgheight`` in increments of ``height``) and, within each band, offset
    ``j`` steps through ``imgwidth`` in increments of ``width``.  Each box is
    the 4-tuple ``(j, i, j + width, i + height)``.

    Boxes are not clipped: when the tile size does not divide the image size
    evenly, the last box in a row/column extends past the image bounds.

    ``img`` is accepted for interface compatibility but is never read.

    Returns:
        list of 4-tuples, ordered band by band.
    """
    return [
        (left, top, left + width, top + height)
        for top in range(0, imgheight, height)
        for left in range(0, imgwidth, width)
    ]
def calculate_error(_img1, _img2):
    """Return the mean squared error between two images.

    Thin wrapper: the work is delegated to ``mean_squared_error`` (imported
    from ``skimage.metrics`` at the top of the file) and its result is handed
    back unchanged.
    """
    error = mean_squared_error(_img1, _img2)
    return error
def get_reconstruction(_img, model):
    """Run one image through ``model`` and return its reconstruction.

    ``_img`` is reshaped into a one-element batch of shape
    ``(1, 128, 128, 3)`` before ``model.predict`` is called; the first (and
    only) entry of the prediction is returned.
    """
    batch_shape = (1, 128, 128, 3)
    single_batch = _img.reshape(batch_shape)
    predictions = model.predict(single_batch)
    return predictions[0]
def minimum_edits(_img1, _img2, _threshold, _filter_size, model):
    """Greedily paste blocks of ``_img2`` into ``_img1`` until the error is high.

    Starting from a copy of ``_img1``, every round tries each block that has
    not been used yet: the block is copied in from ``_img2``, the candidate
    image is reconstructed with ``model`` and the error between the candidate
    and its reconstruction is measured.  The block producing the largest
    error is applied permanently and recorded.  The search ends as soon as
    that largest error exceeds ``_threshold``, or when every block has been
    applied.

    Args:
        _img1: starting image (array-like; copied, never modified).
        _img2: image the blocks are copied from; must support NumPy-style
            slicing.
        _threshold: error value that stops the search once exceeded.
        _filter_size: side length of the square blocks.
        model: object with a ``predict`` method, forwarded to
            ``get_reconstruction``.

    Returns:
        list of box tuples (as produced by ``get_segments``) in the order
        they were applied.  If the threshold is never exceeded this contains
        every block.
    """
    base = np.array(_img1)
    img_height = base.shape[0]
    img_width = base.shape[1]

    # Blocks still available for editing; one is consumed per round.
    remaining = get_segments(_img2, _filter_size, _filter_size, img_height, img_width)

    edits = []
    current_img = np.array(_img1, copy=True)

    # Loop only while candidates remain.  The previous counter-based guard
    # allowed one round more than there were blocks, so an unreachable
    # threshold ended in an IndexError from popping an empty list.
    while remaining:
        max_error = 0
        seg_idx = 0  # falls back to the first block if no error exceeds 0
        for idx, seg in enumerate(remaining):
            # NOTE(review): box elements 0/2 slice axis 0 and 1/3 slice axis 1
            # here, while get_segments steps elements 0/2 over the width.
            # draw_boundaries uses the same indexing, so the two agree; for
            # square images with square blocks the set of regions is the same.
            edited_img = np.array(current_img, copy=True)
            edited_img[seg[0]: seg[2], seg[1]: seg[3], :] = _img2[seg[0]: seg[2], seg[1]: seg[3], :]
            edited_reconstruction = get_reconstruction(edited_img, model)
            edited_error = calculate_error(edited_img, edited_reconstruction)
            if edited_error > max_error:
                max_error = edited_error
                seg_idx = idx

        # Commit the best block of this round.
        best_seg = remaining.pop(seg_idx)
        current_img[best_seg[0]: best_seg[2], best_seg[1]: best_seg[3], :] = (
            _img2[best_seg[0]: best_seg[2], best_seg[1]: best_seg[3], :]
        )
        edits.append(best_seg)

        if max_error > _threshold:
            break

    return edits
def draw_boundaries(_edits, _img):
    """Outline the edited blocks on an image.

    A label mask is built with 1 inside every box in ``_edits`` and 0
    elsewhere, and ``mark_boundaries`` (from ``skimage.segmentation``) draws
    the region outlines onto the rescaled image.

    Args:
        _edits: iterable of box tuples; elements 0/2 slice the first image
            axis and elements 1/3 slice the second (the same indexing used
            when the edits were applied in ``minimum_edits``).
        _img: image array.  It is rescaled with ``_img / 2 + 0.5`` before
            drawing -- assumes pixel values in [-1, 1] so the result lies in
            [0, 1]; TODO confirm against the data pipeline.

    Returns:
        The array returned by ``mark_boundaries``.  The original inline note
        says its values lie between 0 and 1, so scale by 255 and cast to
        ``uint8`` before saving it as an 8-bit image.
    """
    # Size the mask from the image itself instead of a fixed 128 x 128, so
    # the label array always matches the image's first two dimensions.
    mask = np.zeros(np.shape(_img)[:2], dtype=int)
    for box in _edits:
        mask[box[0]: box[2], box[1]: box[3]] = 1
    return mark_boundaries(_img / 2 + 0.5, mask)
class Counterfactual:
    """Counterfactual explanations for a reconstruction-based model.

    For each (normal, anomalous) image pair, ``explain`` searches for the
    blocks of the anomalous image that, pasted into the normal image, push
    the model's reconstruction error towards the anomaly's own error, and
    returns the anomalous image with those blocks outlined.
    """

    def __init__(self, model, debug=True):
        """Store the model and the debug flag.

        Args:
            model: object exposing ``predict``; used to reconstruct images.
            debug: when true, ``explain`` prints shape/progress information.
        """
        self.model = model
        self.debug = debug

    def get_mse(self, a, b):
        """Return the per-sample mean squared error between two batches.

        The squared difference is averaged over axes 1, 2 and 3, leaving one
        value per entry along axis 0.
        """
        return ((a - b) ** 2).mean(axis=(1, 2, 3))

    def explain(
        self,
        normal_samples,
        anomalous_samples,
        threshold_pct=0.98,
        block_size=32,
        anomaly_type="",
        save_path=None,
        save_results=False
    ):
        """Derive counterfactual explanations for every normal/anomaly pair.

        Args:
            normal_samples: batch of normal images (converted to an ndarray
                if it is not one already).
            anomalous_samples: batch of anomalous images (same conversion).
            threshold_pct: fraction of the anomaly's reconstruction error
                used as the stopping threshold for ``minimum_edits``.
            block_size: side length of the square blocks that are swapped.
            anomaly_type: prefix used in saved file names.
            save_path: directory component passed to ``join_paths`` when
                saving; only used when ``save_results`` is true.
            save_results: when true, each result is written with
                ``save_image``.

        Returns:
            dict mapping each normal index to a dict that maps anomaly index
            to the outlined image from ``draw_boundaries``.  A pair is only
            present when the normal sample's error is below the anomaly's.
        """
        results = {}

        # Accept lists/tuples as well as arrays.
        if not isinstance(normal_samples, np.ndarray):
            normal_samples = np.array(normal_samples)
        if not isinstance(anomalous_samples, np.ndarray):
            anomalous_samples = np.array(anomalous_samples)

        # Reconstructions and errors are computed once per batch, up front.
        normal_reconstructions = self.model.predict(normal_samples)
        normal_mse = self.get_mse(normal_samples, normal_reconstructions)
        anomalous_reconstructions = self.model.predict(anomalous_samples)
        anomalous_mse = self.get_mse(anomalous_samples, anomalous_reconstructions)

        if self.debug:
            print("NORMAL:", normal_samples.shape, normal_reconstructions.shape, normal_mse.shape)
            print("ANOMALY:", anomalous_samples.shape, anomalous_reconstructions.shape, anomalous_mse.shape)
            print("Models loaded. Starting analysis")

        # tqdm wraps the iterable directly (not enumerate) and is told the
        # total, because neither zip nor enumerate exposes a length; without
        # it the progress bar cannot show completion or an ETA.
        normal_items = zip(normal_samples, normal_reconstructions, normal_mse)
        progress = tqdm(normal_items, total=len(normal_samples))
        for norm_idx, (norm_image, _, norm_err) in enumerate(progress):
            results[norm_idx] = {}
            anomaly_items = zip(anomalous_samples, anomalous_reconstructions, anomalous_mse)
            for anomaly_idx, (anomaly_image, _, anomaly_err) in enumerate(anomaly_items):
                # Only pairs where the anomaly reconstructs worse than the
                # normal sample are explained.
                if norm_err < anomaly_err:
                    threshold = anomaly_err * threshold_pct
                    edits_made = minimum_edits(norm_image, anomaly_image, threshold, block_size, self.model)
                    masked_result = draw_boundaries(edits_made, anomaly_image)
                    results[norm_idx][anomaly_idx] = masked_result
                    if save_results:
                        file_name = "%s_N-%d_A-%d.png" % (anomaly_type, norm_idx, anomaly_idx)
                        save_image(masked_result, join_paths([save_path, file_name]))
        return results