-
Notifications
You must be signed in to change notification settings - Fork 535
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
2 changed files
with
143 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
import torch | ||
import sys | ||
|
||
from vision.nn.mobilenet import MobileNetV1 | ||
from extract_tf_weights import read_weights | ||
|
||
|
||
def fill_weights_torch_model(weights, state_dict):
    """Overwrite the entries of ``state_dict`` with converted TensorFlow weights.

    Each PyTorch entry name is translated into a TensorFlow variable name by
    replacing ``features`` with ``MobilenetV1``, turning ``.`` into ``/`` and
    renaming the trailing component (for example ``weight`` -> ``weights`` and
    BatchNorm ``weight``/``bias`` -> ``gamma``/``beta``). The two classifier
    entries map to the fixed ``MobilenetV1/Logits/Conv2d_1c_1x1`` variables.
    Convolution kernels additionally have their axes permuted before being
    stored.

    Args:
        weights: Mapping from TensorFlow variable name to array-like values
            that ``torch.tensor`` can convert.
        state_dict: Mapping from PyTorch entry name to tensor. It is modified
            in place; nothing is returned.

    Raises:
        KeyError: If a translated TensorFlow variable name is not in
            ``weights``.
        AssertionError: If a converted tensor does not have the same size as
            the entry it replaces.
    """
    # Iterate over a snapshot of the keys because entries are rebound below.
    for name in list(state_dict):
        if name.endswith('num_batches_tracked'):
            # BatchNorm step counter maintained by PyTorch. There is no
            # TensorFlow variable to copy from, so keep the existing value
            # instead of failing with a KeyError on a made-up variable name.
            continue

        tf_scope = name.replace("features", "MobilenetV1").replace(".", "/")
        axes = None  # axis permutation; only convolution kernels need one
        if name == 'classifier.weight':
            key = 'MobilenetV1/Logits/Conv2d_1c_1x1/weights'
            axes = (3, 2, 0, 1)
        elif name == 'classifier.bias':
            key = 'MobilenetV1/Logits/Conv2d_1c_1x1/biases'
        elif name.endswith('BatchNorm.weight'):
            key = tf_scope.replace('BatchNorm/weight', 'BatchNorm/gamma')
        elif name.endswith('BatchNorm.bias'):
            key = tf_scope.replace('BatchNorm/bias', 'BatchNorm/beta')
        elif name.endswith('running_mean'):
            key = tf_scope.replace('running_mean', 'moving_mean')
        elif name.endswith('running_var'):
            key = tf_scope.replace('running_var', 'moving_variance')
        elif name.endswith('depthwise.weight'):
            key = tf_scope.replace('depthwise/weight', 'depthwise/depthwise_weights')
            # Depthwise kernels: the third source axis becomes the leading
            # (output channel) axis of the PyTorch grouped convolution.
            axes = (2, 3, 0, 1)
        else:
            # Every remaining entry is treated as a regular convolution kernel.
            key = tf_scope.replace('weight', 'weights')
            # Last source axis first: presumably TF's (H, W, in, out) layout
            # becoming PyTorch's (out, in, H, W) - the size check below
            # guards this assumption.
            axes = (3, 2, 0, 1)

        tensor = torch.tensor(weights[key], dtype=torch.float32)
        if axes is not None:
            tensor = tensor.permute(*axes)
        assert tensor.size() == state_dict[name].size(), (
            f"shape mismatch for {name!r} <- {key!r}: "
            f"{tuple(tensor.size())} vs {tuple(state_dict[name].size())}"
        )
        state_dict[name] = tensor
|
||
|
||
if __name__ == '__main__':
    # Command-line entry point: read the weights of a TensorFlow model, copy
    # them into the state dict of a freshly built MobileNetV1 and save that
    # state dict to the given path.
    if len(sys.argv) < 3:
        print("Usage: python translate_tf_modelnetv1.py <tf_model.pb> <pytorch_weights.pth>")
        # Stop here: without both paths the argv lookups below would raise
        # IndexError right after the usage message.
        sys.exit(1)
    tf_model = sys.argv[1]
    torch_weights_path = sys.argv[2]
    print("Extract weights from tf model.")
    weights = read_weights(tf_model)

    # 1001 output classes, matching the size of the logits layer being loaded.
    net = MobileNetV1(1001)
    states = net.state_dict()
    print("Translate tf weights.")
    fill_weights_torch_model(weights, states)
    torch.save(states, torch_weights_path)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
from collections import OrderedDict, namedtuple | ||
|
||
|
||
# Stride of the depthwise conv and (unscaled) output depth of the pointwise
# conv for one depthwise-separable stage.
ConvParam = namedtuple('ConvParam', ['stride', 'depth'])


class MobileNetV1(nn.Module):
    """MobileNetV1 built from a plain stem conv and depthwise-separable stages.

    Sub-modules of ``features`` are named ``Conv2d_<i>``,
    ``Conv2d_<i>_depthwise`` and ``Conv2d_<i>_pointwise`` (each followed by
    ``/BatchNorm`` and ``/Relu6`` entries) so that state-dict names can be
    mapped onto TensorFlow MobileNet variable names.

    Args:
        num_classes: Number of output channels of the final 1x1 classifier.
        dropout_keep_prob: Probability of *keeping* an activation in the
            dropout applied before the classifier.
        depth_multiplier: Factor applied to the depth of every pointwise conv.
        min_depth: Lower bound for the scaled pointwise depth.
        num_feature_layers: When >= 1, the number of conv stages to build,
            counting the stem ``Conv2d_0`` as the first; smaller values build
            all stages.
        in_channels: Number of channels of the input images.
    """

    def __init__(self, num_classes, dropout_keep_prob=0.5, depth_multiplier=1.0,
                 min_depth=8, num_feature_layers=0, in_channels=3):
        super(MobileNetV1, self).__init__()

        self.num_classes = num_classes
        self.depth_multiplier = depth_multiplier
        self.dropout_keep_prob = dropout_keep_prob

        # initial normal conv2d, Conv2d_0 layer
        # NOTE(review): the stem width is fixed at 32 and is not scaled by
        # depth_multiplier / min_depth - confirm this is intended.
        feature_layers = OrderedDict([
            ('Conv2d_0', nn.Conv2d(in_channels, 32, kernel_size=(3, 3), stride=2, padding=1, bias=False)),
            ('Conv2d_0/BatchNorm', nn.BatchNorm2d(32, eps=1e-03)),
            ('Conv2d_0/Relu6', nn.ReLU6(inplace=True)),
        ])

        # depthwise separable layer params
        conv_params = [
            ConvParam(stride=1, depth=64),
            ConvParam(stride=2, depth=128),
            ConvParam(stride=1, depth=128),
            ConvParam(stride=2, depth=256),
            ConvParam(stride=1, depth=256),
            ConvParam(stride=2, depth=512),
            ConvParam(stride=1, depth=512),
            ConvParam(stride=1, depth=512),
            ConvParam(stride=1, depth=512),
            ConvParam(stride=1, depth=512),
            ConvParam(stride=1, depth=512),
            ConvParam(stride=2, depth=1024),
            ConvParam(stride=1, depth=1024),
        ]

        # depthwise separable Conv2d
        in_channels = 32
        for i, param in enumerate(conv_params):
            # Stop once the requested number of stages (stem included) exists.
            if 1 <= num_feature_layers <= i + 1:
                break
            i = i + 1  # make the layer index start from 1 to follow the tensorflow MobileNet format.
            # with groups equal to the channel count, Conv2d is a depthwise Conv2d.
            feature_layers[f'Conv2d_{i}_depthwise'] = nn.Conv2d(in_channels, in_channels,
                                                                kernel_size=(3, 3),
                                                                stride=param.stride,
                                                                padding=1,
                                                                groups=in_channels,
                                                                bias=False)
            feature_layers[f'Conv2d_{i}_depthwise/BatchNorm'] = nn.BatchNorm2d(in_channels, eps=1e-03)
            feature_layers[f'Conv2d_{i}_depthwise/Relu6'] = nn.ReLU6(inplace=True)

            # pointwise Conv2d
            out_channels = max(int(param.depth * depth_multiplier), min_depth)
            feature_layers[f'Conv2d_{i}_pointwise'] = nn.Conv2d(in_channels, out_channels,
                                                                kernel_size=(1, 1), stride=1, bias=False)
            feature_layers[f'Conv2d_{i}_pointwise/BatchNorm'] = nn.BatchNorm2d(out_channels, eps=1e-03)
            feature_layers[f'Conv2d_{i}_pointwise/Relu6'] = nn.ReLU6(inplace=True)
            in_channels = out_channels

        self.features = nn.Sequential(feature_layers)
        self.classifier = nn.Conv2d(in_channels, num_classes, kernel_size=(1, 1), stride=1)

    def forward(self, x):
        """Return the class scores for ``x``, reshaped to ``(-1, num_classes)``."""
        x = self.features(x)
        _, _, height, width = x.size()
        # the kernel size 7x7 is for 224x224 inputs of ImageNet images
        kernel_size = (min(height, 7), min(width, 7))
        x = F.avg_pool2d(x, kernel_size=kernel_size)
        # dropout2d expects the probability of *dropping* a channel and is only
        # disabled in eval mode when told the module's training state.
        x = F.dropout2d(x, p=1.0 - self.dropout_keep_prob, training=self.training)
        x = self.classifier(x)
        x = x.view(-1, self.num_classes)
        return x