import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable  # deprecated since PyTorch 0.4; kept for compatibility
from torch.nn.parameter import Parameter
from torchvision.models import resnet18
class FaceModel(nn.Module):
    """ResNet-18 backbone that maps a face image to an L2-normalized embedding."""

    def __init__(self, embedding_size, num_classes, pretrained=False):
        super(FaceModel, self).__init__()
        self.model = resnet18(pretrained)
        self.embedding_size = embedding_size
        # Replace the stock fc layer: layer4 outputs 512 x 3 x 3 feature maps
        # (for 96x96 inputs), flattened in forward() below.
        self.model.fc = nn.Linear(512 * 3 * 3, self.embedding_size)
        self.model.classifier = nn.Linear(self.embedding_size, num_classes)

    def l2_norm(self, input):
        """Scale each row of `input` to unit L2 norm."""
        input_size = input.size()
        buffer = torch.pow(input, 2)
        normp = torch.sum(buffer, 1).add_(1e-10)  # epsilon guards against division by zero
        norm = torch.sqrt(normp)
        _output = torch.div(input, norm.view(-1, 1).expand_as(input))
        output = _output.view(input_size)
        return output
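    # For intuition, a hand-checked example (not part of the model):
    #   l2_norm(torch.tensor([[3., 4.]])) -> tensor([[0.6000, 0.8000]])
    # i.e. every row ends up with unit Euclidean length before the alpha scaling below.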
    def forward(self, x):
        # Standard ResNet-18 forward pass, but with avgpool bypassed so the
        # full 512 x 3 x 3 feature map feeds the embedding layer.
        x = self.model.conv1(x)
        x = self.model.bn1(x)
        x = self.model.relu(x)
        x = self.model.maxpool(x)
        x = self.model.layer1(x)
        x = self.model.layer2(x)
        x = self.model.layer3(x)
        x = self.model.layer4(x)
        x = x.view(x.size(0), -1)  # flatten to (batch, 512*3*3)
        x = self.model.fc(x)
        self.features = self.l2_norm(x)
        # Multiply by alpha = 10 as suggested in https://arxiv.org/pdf/1703.09507.pdf
        alpha = 10
        self.features = self.features * alpha
        # The classifier is applied separately in forward_classifier():
        # x = self.model.classifier(self.features)
        return self.features
    def forward_classifier(self, x):
        features = self.forward(x)
        res = self.model.classifier(features)
        return res
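
# A minimal usage sketch (illustrative; the 96x96 input size is an assumption
# implied by the 512*3*3 flatten, since resnet18 downsamples by a factor of 32):
#
#   model = FaceModel(embedding_size=128, num_classes=100)
#   images = torch.randn(4, 3, 96, 96)         # batch of 4 RGB face crops
#   embeddings = model(images)                  # (4, 128), L2-normalized then scaled by 10
#   logits = model.forward_classifier(images)   # (4, 100) class scores
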
class FaceModelCenter(nn.Module):
    """ResNet-18 variant trained with a center loss on its embeddings."""

    def __init__(self, embedding_size, num_classes, checkpoint=None):
        super(FaceModelCenter, self).__init__()
        self.model = resnet18()
        self.model.avgpool = None  # avgpool is bypassed in forward()
        self.model.fc1 = nn.Linear(512 * 3 * 3, 512)
        self.model.fc2 = nn.Linear(512, embedding_size)
        self.model.classifier = nn.Linear(embedding_size, num_classes)
        # One center per class, maintained outside autograd by get_center_loss().
        self.centers = torch.zeros(num_classes, embedding_size).type(torch.FloatTensor)
        self.num_classes = num_classes
        self.apply(self.weights_init)
        if checkpoint is not None:
            # Load the checkpoint wholesale only if it was trained with the same
            # number of classes (the classifier bias is the last state-dict entry,
            # and its first dimension is the class count).
            if list(checkpoint['state_dict'].values())[-1].size(0) == num_classes:
                self.load_state_dict(checkpoint['state_dict'])
                self.centers = checkpoint['centers']
            else:
                # Otherwise copy every layer except the classifier.
                own_state = self.state_dict()
                for name, param in checkpoint['state_dict'].items():
                    if "classifier" not in name:
                        if isinstance(param, Parameter):
                            # backwards compatibility for serialized parameters
                            param = param.data
                        own_state[name].copy_(param)
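    # Resuming might look like this (hypothetical path; the 'state_dict' and
    # 'centers' keys match what __init__ expects above):
    #   checkpoint = torch.load('checkpoint.pth')
    #   model = FaceModelCenter(128, num_classes, checkpoint=checkpoint)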
    def weights_init(self, m):
        classname = m.__class__.__name__
        if classname.find('Conv') != -1:
            # He initialization: std = sqrt(2 / n), with n the fan-out of the filter.
            # (A redundant normal_(0.0, 0.02) init that was immediately overwritten
            # has been removed here.)
            n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            m.weight.data.normal_(0, math.sqrt(2. / n))
            if m.bias is not None:
                m.bias.data.zero_()
        elif classname.find('BatchNorm') != -1:
            m.weight.data.fill_(1)
            m.bias.data.zero_()
        elif classname.find('Linear') != -1:
            m.weight.data.normal_(0, 0.01)
            m.bias.data.zero_()
    def get_center_loss(self, target, alpha):
        """Compute the center loss and update the class centers in place."""
        batch_size = target.size(0)
        features_dim = self.features.size(1)
        target_expand = target.view(batch_size, 1).expand(batch_size, features_dim)
        # Gather each sample's class center: centers_batch[i] = c_{y_i}.
        centers_var = Variable(self.centers)
        centers_batch = centers_var.gather(0, target_expand).cuda()
        criterion = nn.MSELoss()
        center_loss = criterion(self.features, centers_batch)
        diff = centers_batch - self.features
        # Count how often each label appears in the batch, so a center's update
        # is averaged over all of its samples.
        unique_label, unique_reverse, unique_count = np.unique(
            target.cpu().data.numpy(), return_inverse=True, return_counts=True)
        appear_times = torch.from_numpy(unique_count).gather(0, torch.from_numpy(unique_reverse))
        appear_times_expand = appear_times.view(-1, 1).expand(batch_size, features_dim).type(torch.FloatTensor)
        # Delta c_j = (sum_{i=1}^m delta(y_i = j) * (c_j - x_i)) / (1 + sum_{i=1}^m delta(y_i = j))
        diff_cpu = diff.cpu().data / appear_times_expand.add(1e-6)
        diff_cpu = alpha * diff_cpu
        for i in range(batch_size):
            # Update each center: c_j^(t+1) = c_j^t - alpha * Delta c_j^t
            self.centers[target.data[i]] -= diff_cpu[i].type(self.centers.type())
        return center_loss, self.centers
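    # Sketch of how this is typically combined with a softmax loss during a
    # training step (hypothetical names; loss_weight balances the two terms):
    #   output = model.forward_classifier(images)
    #   softmax_loss = F.nll_loss(output, target)
    #   center_loss, _ = model.get_center_loss(target, alpha=0.5)
    #   loss = softmax_loss + loss_weight * center_loss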
    def l2_norm(self, input):
        """Scale each row of `input` to unit L2 norm (same as FaceModel.l2_norm)."""
        input_size = input.size()
        buffer = torch.pow(input, 2)
        normp = torch.sum(buffer, 1).add_(1e-10)  # epsilon guards against division by zero
        norm = torch.sqrt(normp)
        _output = torch.div(input, norm.view(-1, 1).expand_as(input))
        output = _output.view(input_size)
        return output
    def forward(self, x):
        x = self.model.conv1(x)
        x = self.model.bn1(x)
        x = self.model.relu(x)
        x = self.model.maxpool(x)
        x = self.model.layer1(x)
        x = self.model.layer2(x)
        x = self.model.layer3(x)
        x = self.model.layer4(x)
        x = x.view(x.size(0), -1)  # flatten to (batch, 512*3*3)
        x = self.model.fc1(x)
        x = self.model.fc2(x)
        # Keep the unnormalized embedding as the feature for the center loss.
        self.features = x
        self.features_norm = self.l2_norm(x)
        return self.features_norm
    def forward_classifier(self, x):
        features_norm = self.forward(x)
        x = self.model.classifier(features_norm)
        return F.log_softmax(x, dim=1)  # explicit dim silences the deprecation warning
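

if __name__ == "__main__":
    # Smoke test (illustrative only; 96x96 inputs are an assumption implied by
    # the 512*3*3 flatten in forward()).
    model = FaceModelCenter(embedding_size=128, num_classes=10)
    images = torch.randn(2, 3, 96, 96)
    log_probs = model.forward_classifier(images)
    print(log_probs.size())  # expected: torch.Size([2, 10])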