This repository has been archived by the owner on Nov 29, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path reshape.py
58 lines (43 loc) · 2.17 KB
/
reshape.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import torch
import torch.nn
#
# reshape the model for N classes
#
def reshape_model(model, arch, num_classes):
    """Reshape a model's output layers for the given number of classes.

    Replaces the final classification layer(s) of a torchvision-style model
    with freshly initialized layers producing ``num_classes`` outputs. The
    model is modified in place and also returned for convenience. Each
    replaced layer is printed to stdout.

    Args:
        model: the network to modify (a ``torch.nn.Module``).
        arch: architecture name; dispatch is by prefix, e.g. ``"resnet18"``
            is handled by the ``"resnet"`` branch.
        num_classes: number of output classes for the new layer(s).

    Returns:
        The same ``model`` object. For an unrecognized ``arch`` a notice is
        printed and the model is returned unchanged.
    """
    # ResNeXt and Wide-ResNet are built from torchvision's ResNet class, so
    # they expose the same ``fc`` head.
    if arch.startswith(("resnet", "resnext", "wide_resnet")):
        model.fc = torch.nn.Linear(model.fc.in_features, num_classes)
        print("=> reshaped ResNet fully-connected layer with: " + str(model.fc))
    elif arch.startswith("alexnet"):
        model.classifier[6] = torch.nn.Linear(model.classifier[6].in_features, num_classes)
        print("=> reshaped AlexNet classifier layer with: " + str(model.classifier[6]))
    elif arch.startswith("vgg"):
        model.classifier[6] = torch.nn.Linear(model.classifier[6].in_features, num_classes)
        print("=> reshaped VGG classifier layer with: " + str(model.classifier[6]))
    elif arch.startswith("squeezenet"):
        # SqueezeNet classifies with a 1x1 convolution rather than a Linear
        # layer; reuse its input width instead of hard-coding it.
        in_channels = model.classifier[1].in_channels
        model.classifier[1] = torch.nn.Conv2d(in_channels, num_classes, kernel_size=(1, 1), stride=(1, 1))
        # Keep the model's own num_classes attribute consistent with the new head.
        model.num_classes = num_classes
        print("=> reshaped SqueezeNet classifier layer with: " + str(model.classifier[1]))
    elif arch.startswith("densenet"):
        model.classifier = torch.nn.Linear(model.classifier.in_features, num_classes)
        print("=> reshaped DenseNet classifier layer with: " + str(model.classifier))
    elif arch.startswith("inception"):
        # The auxiliary head is None when the model was built without aux
        # logits; only reshape it when it exists.
        if getattr(model, "AuxLogits", None) is not None:
            model.AuxLogits.fc = torch.nn.Linear(model.AuxLogits.fc.in_features, num_classes)
            print("=> reshaped Inception aux-logits layer with: " + str(model.AuxLogits.fc))
        model.fc = torch.nn.Linear(model.fc.in_features, num_classes)
        print("=> reshaped Inception fully-connected layer with: " + str(model.fc))
    elif arch.startswith("googlenet"):
        if model.aux_logits:
            # Imported lazily so torchvision is only required for GoogLeNet.
            from torchvision.models.googlenet import InceptionAux
            # 512 / 528 are the input channel counts of GoogLeNet's two
            # auxiliary classifier taps.
            model.aux1 = InceptionAux(512, num_classes)
            model.aux2 = InceptionAux(528, num_classes)
            print("=> reshaped GoogleNet aux-logits layers with: ")
            print(" " + str(model.aux1))
            print(" " + str(model.aux2))
        model.fc = torch.nn.Linear(model.fc.in_features, num_classes)
        print("=> reshaped GoogleNet fully-connected layer with: " + str(model.fc))
    else:
        # Previously referenced an undefined global ``args`` and crashed.
        print("classifier reshaping not supported for " + arch)
        print("model will retain default of 1000 output classes")
    return model