-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy patheval_infid_sen.py
64 lines (49 loc) · 1.88 KB
/
eval_infid_sen.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import numpy as np
import torch
import loader
from infid_sen_utils import evaluate_infid_sen
import config
class Args:
    """Give attribute-style access to the entries of a configuration mapping.

    Every ``key: value`` pair of *args* becomes ``self.key``. The mapping
    itself is kept on ``self.source``, and ``repr`` delegates to it.
    """

    def __init__(self, args):
        self.source = args
        # Mirror each config entry as an instance attribute.
        for attr_name, attr_value in args.items():
            setattr(self, attr_name, attr_value)

    def __repr__(self):
        original = self.source
        return repr(original)
class OutputLog:
    """Accumulate infidelity and max-sensitivity scores per explanation method.

    Writing the same method name twice replaces the earlier scores.
    """

    def __init__(self):
        # method name -> infidelity score
        self.infid_dict = {}
        # method name -> max-sensitivity score
        self.max_sen_dict = {}

    def write(self, method, infid, max_sen):
        """Record (or overwrite) the two scores for *method*."""
        self.infid_dict[method], self.max_sen_dict[method] = infid, max_sen

    def __str__(self):
        return 'Infidelity:{}\nMax-Sensitivity:{}\n'.format(
            self.infid_dict, self.max_sen_dict)
if __name__ == "__main__":
    # Evaluate infidelity and max-sensitivity for every (perturbation,
    # explanation) pair listed in config.args, printing the cumulative log
    # after each evaluation.

    # Attribute-style access to the project config (args.seed, args.perts, ...).
    args = Args(config.args)

    # Seed the torch (CPU and CUDA) and numpy RNGs so runs are repeatable.
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    np.random.seed(args.seed)

    # Only the test split is evaluated; the train loader is not needed here.
    _, test_loader = loader.mnist_loaders(batch_size=1)
    model = loader.mnist_load_model(args.model, state_dict=True, tf=True)
    # NOTE(review): hard-wired to CUDA; this fails on a CPU-only host.
    model = model.cuda()
    output_log = OutputLog()

    for pert in args.perts:
        for exp in args.exps:
            print('Perturbation =', pert, '/ Explanation =', exp)
            # presumably sen_r / sen_N are the sensitivity radius and sample
            # count -- TODO confirm against infid_sen_utils.evaluate_infid_sen
            infid, max_sen = evaluate_infid_sen(test_loader, model, exp, pert,
                                                args.sen_r, args.sen_N)
            # Key by perturbation as well as method: keying by method alone
            # let each perturbation overwrite the previous one's results.
            output_log.write('{}/{}'.format(pert, exp), infid, max_sen)
            print(output_log)

        # Smooth-Grad applied on top of each explanation named in args.sgs
        # (passed via given_expl); sg_r / sg_N are presumably the noise
        # radius and sample count -- TODO confirm.
        for sg in args.sgs:
            print('Perturbation =', pert, '/ Smooth-Grad =', sg)
            infid, max_sen = evaluate_infid_sen(test_loader, model, 'Smooth_Grad', pert,
                                                args.sen_r, args.sen_N, args.sg_r, args.sg_N,
                                                given_expl=sg)
            output_log.write('{}/{}-SG'.format(pert, sg), infid, max_sen)
            print(output_log)