-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathcombined_test_model.py
54 lines (43 loc) · 1.59 KB
/
combined_test_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
from cover_test_model import print_acc
from text_model import load_data_loaders
from combined_model import CombinedModel
from combined_dataloaders import *
import sys
import torch
def _evaluate_checkpoint(model, checkpoint_path, dataloader_path, topK, batch_size):
    """Load saved weights into *model* and print its accuracy on the test split.

    Args:
        model: a freshly constructed model whose architecture matches the
            checkpoint at *checkpoint_path*.
        checkpoint_path: path to a file holding a ``state_dict`` for *model*.
        dataloader_path: path handed to ``load_data_loaders``; the returned
            mapping must contain a ``"test"`` loader.
        topK: forwarded unchanged to ``print_acc``.
        batch_size: forwarded unchanged to ``print_acc``.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # map_location remaps tensors saved on a GPU onto the chosen device, so a
    # GPU-trained checkpoint can still be evaluated on a CPU-only machine
    # (without it torch.load raises when CUDA is unavailable).
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    model.to(device)
    model.eval()
    print("loading dataloaders")
    dataloaders = load_data_loaders(dataloader_path)
    test_dataloader = dataloaders["test"]
    dataset_size = len(test_dataloader.dataset)
    print("computing acc")
    print_acc(model, test_dataloader, dataset_size, topK, batch_size, device)


def test_combined_model(topK, batch_size=32):
    """
    Test the combined model on the test set.

    Loads ``combined_models/combined_model.pt`` into a default ``CombinedModel``
    and reports accuracy on the pickled test loader built for *batch_size*.

    Args:
        topK: forwarded to ``print_acc``.
        batch_size: selects ``dataloaders/combined_data_loaders_<batch_size>.pickle``
            and is forwarded to ``print_acc``; defaults to 32.
    """
    print("creating model")
    model = CombinedModel()
    _evaluate_checkpoint(
        model,
        "combined_models/combined_model.pt",
        "dataloaders/combined_data_loaders_{}.pickle".format(batch_size),
        topK,
        batch_size,
    )


def test_combined_model_10(topK, batch_size=32):
    """
    Test the combined model on the test set with 10 classes.

    Loads ``combined_models/combined_model_10.pt`` into ``CombinedModel(10)``
    and reports accuracy on the 10-class pickled test loader built for
    *batch_size*.

    Args:
        topK: forwarded to ``print_acc``.
        batch_size: selects ``dataloaders/combined_data_loaders_<batch_size>_10.pickle``
            and is forwarded to ``print_acc``; defaults to 32.
    """
    print("creating model")
    model = CombinedModel(10)
    _evaluate_checkpoint(
        model,
        "combined_models/combined_model_10.pt",
        "dataloaders/combined_data_loaders_{}_10.pickle".format(batch_size),
        topK,
        batch_size,
    )
if __name__ == "__main__":
    # Usage: python <script> TOPK
    # Only the 10-class model is evaluated from the command line; call
    # test_combined_model() directly to evaluate the default model.
    try:
        topK = int(sys.argv[1])
    except (IndexError, ValueError):
        # Missing or non-integer argument: exit with a usage message instead
        # of an uncaught traceback.
        sys.exit("usage: {} TOPK  (TOPK must be an integer)".format(sys.argv[0]))
    test_combined_model_10(topK)