forked from kiwi12138/RealisticTTA
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
83 lines (66 loc) · 3.37 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import torch
import logging
import numpy as np
from datasets.imagenet_subsets import IMAGENET_D_MAPPING
logger = logging.getLogger(__name__)
def split_results_by_domain(domain_dict, data, predictions):
"""
Create a dictionary which separates the labels and predictions by domain
:param domain_dict: dictionary, where the keys are the domains and the content is [labels, predictions]
:param data: list containing [images, labels, domains, ...]
:param predictions: predictions of the model
:return: updated result dict
"""
imgs = data[0][0] if isinstance(data[0], list) else data[0]
for i in range(imgs.shape[0]):
label, domain = data[1][i], data[2][i]
if domain in domain_dict.keys():
domain_dict[domain].append([label.item(), predictions[i].item()])
else:
domain_dict[domain] = [[label.item(), predictions[i].item()]]
return domain_dict
def eval_domain_dict(domain_dict, domain_seq=None):
"""
Print detailed results for each domain. This is useful for settings where the domains are mixed
:param domain_dict: dictionary containing the labels and predictions for each domain
:param domain_seq: if specified and the domains are contained in the domain dict, the results will be printed in this order
"""
correct = []
num_samples = []
avg_error_domains = []
dom_names = domain_seq if all([dname in domain_seq for dname in domain_dict.keys()]) else domain_dict.keys()
logger.info(f"Splitting up the results by domain...")
for key in dom_names:
content = np.array(domain_dict[key])
correct.append((content[:, 0] == content[:, 1]).sum())
num_samples.append(content.shape[0])
accuracy = correct[-1] / num_samples[-1]
error = 1 - accuracy
avg_error_domains.append(error)
logger.info(f"{key:<20} error: {error:.2%}")
total_err = 1 - sum(correct) / sum(num_samples)
logger.info(f"Average error across all domains: {sum(avg_error_domains) / len(avg_error_domains):.2%}")
logger.info(f"Error over all samples: {total_err:.2%}")
def get_accuracy(model: torch.nn.Module,
data_loader: torch.utils.data.DataLoader,
dataset_name: str,
domain_name: str,
setting: str,
domain_dict: dict,
device: torch.device = None):
if device is None:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
correct = 0.
with torch.no_grad():
for i, data in enumerate(data_loader):
imgs, labels = data[0], data[1]
output = model([img.to(device) for img in imgs]) if isinstance(imgs, list) else model(imgs.to(device))
predictions = output.argmax(1)
if dataset_name == "imagenet_d" and domain_name != "none":
mapping_vector = list(IMAGENET_D_MAPPING.values())
predictions = torch.tensor([mapping_vector[pred] for pred in predictions], device=device)
correct += (predictions == labels.to(device)).float().sum()
if "mixed_domains" in setting and len(data) >= 3:
domain_dict = split_results_by_domain(domain_dict, data, predictions)
accuracy = correct.item() / len(data_loader.dataset)
return accuracy, domain_dict