
Commit

add mobilenetv1
qfgaohao committed Jun 18, 2018
1 parent 7749c62 commit 58d3aa8
Showing 2 changed files with 143 additions and 0 deletions.
65 changes: 65 additions & 0 deletions translate_tf_mobilenetv1.py
@@ -0,0 +1,65 @@
import torch
import sys

from vision.nn.mobilenet import MobileNetV1
from extract_tf_weights import read_weights


def fill_weights_torch_model(weights, state_dict):
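    # TensorFlow stores standard conv kernels as [height, width, in_channels, out_channels]
    # and depthwise kernels as [height, width, in_channels, channel_multiplier], while
    # PyTorch expects [out_channels, in_channels / groups, height, width]; the permute
    # calls below reorder the axes accordingly.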
for name in state_dict:
if name == 'classifier.weight':
weight = weights['MobilenetV1/Logits/Conv2d_1c_1x1/weights']
weight = torch.tensor(weight, dtype=torch.float32).permute(3, 2, 0, 1)
assert state_dict[name].size() == weight.size()
state_dict[name] = weight
elif name == 'classifier.bias':
bias = weights['MobilenetV1/Logits/Conv2d_1c_1x1/biases']
bias = torch.tensor(bias, dtype=torch.float32)
assert state_dict[name].size() == bias.size()
state_dict[name] = bias
elif name.endswith('BatchNorm.weight'):
key = name.replace("features", "MobilenetV1").replace(".", "/").replace('BatchNorm/weight', 'BatchNorm/gamma')
weight = torch.tensor(weights[key], dtype=torch.float32)
assert weight.size() == state_dict[name].size()
state_dict[name] = weight
elif name.endswith('BatchNorm.bias'):
key = name.replace("features", "MobilenetV1").replace(".", "/").replace('BatchNorm/bias', 'BatchNorm/beta')
bias = torch.tensor(weights[key], dtype=torch.float32)
assert bias.size() == state_dict[name].size()
state_dict[name] = bias
elif name.endswith('running_mean'):
key = name.replace("features", "MobilenetV1").replace(".", "/").replace('running_mean', 'moving_mean')
running_mean = torch.tensor(weights[key], dtype=torch.float32)
assert running_mean.size() == state_dict[name].size()
state_dict[name] = running_mean
elif name.endswith('running_var'):
key = name.replace("features", "MobilenetV1").replace(".", "/").replace('running_var', 'moving_variance')
running_var = torch.tensor(weights[key], dtype=torch.float32)
assert running_var.size() == state_dict[name].size()
state_dict[name] = running_var
elif name.endswith('depthwise.weight'):
key = name.replace("features", "MobilenetV1").replace(".", "/")
key = key.replace('depthwise/weight', 'depthwise/depthwise_weights')
weight = torch.tensor(weights[key], dtype=torch.float32).permute(2, 3, 0, 1)
assert weight.size() == state_dict[name].size()
state_dict[name] = weight
else:
key = name.replace("features", "MobilenetV1").replace(".", "/").replace('weight', 'weights')
weight = torch.tensor(weights[key], dtype=torch.float32).permute(3, 2, 0, 1)
assert weight.size() == state_dict[name].size()
state_dict[name] = weight


if __name__ == '__main__':
if len(sys.argv) < 3:
print("Usage: python translate_tf_modelnetv1.py <tf_model.pb> <pytorch_weights.pth>")
tf_model = sys.argv[1]
torch_weights_path = sys.argv[2]
print("Extract weights from tf model.")
weights = read_weights(tf_model)

net = MobileNetV1(1001)
states = net.state_dict()
print("Translate tf weights.")
fill_weights_torch_model(weights, states)
torch.save(states, torch_weights_path)
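
A quick way to sanity-check the result is to load the saved state dict back into the PyTorch model and run a dummy forward pass. A minimal sketch; the .pth file name below is illustrative and assumes the script above has already been run:

import torch
from vision.nn.mobilenet import MobileNetV1

# Hypothetical output path produced by the translation script above.
net = MobileNetV1(1001)
net.load_state_dict(torch.load("mobilenet_v1.pth"))
net.eval()

with torch.no_grad():
    dummy = torch.randn(1, 3, 224, 224)  # ImageNet-sized input
    logits = net(dummy)
print(logits.size())  # expected: torch.Size([1, 1001])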
78 changes: 78 additions & 0 deletions vision/nn/mobilenet.py
@@ -0,0 +1,78 @@
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict, namedtuple


ConvParam = namedtuple('ConvParam', ['stride', 'depth'])
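# In ConvParam, 'stride' is the stride of the depthwise conv and 'depth' is the output
# channel count of the following pointwise conv (before the depth multiplier is applied).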


class MobileNetV1(nn.Module):
def __init__(self, num_classes, dropout_keep_prob=0.5, depth_multiplier=1.0,
min_depth=8, num_feature_layers=0, in_channels=3):
super(MobileNetV1, self).__init__()

self.num_classes = num_classes
self.depth_multiplier = depth_multiplier
self.dropout_keep_prob = dropout_keep_prob

        # initial standard Conv2d layer (Conv2d_0)
feature_layers = OrderedDict([
('Conv2d_0', nn.Conv2d(in_channels, 32, kernel_size=(3, 3), stride=2, padding=1, bias=False)),
('Conv2d_0/BatchNorm', nn.BatchNorm2d(32, eps=1e-03)),
('Conv2d_0/Relu6', nn.ReLU6(inplace=True))
])

# depthwise separable layer params
conv_params = [
ConvParam(stride=1, depth=64),
ConvParam(stride=2, depth=128),
ConvParam(stride=1, depth=128),
ConvParam(stride=2, depth=256),
ConvParam(stride=1, depth=256),
ConvParam(stride=2, depth=512),
ConvParam(stride=1, depth=512),
ConvParam(stride=1, depth=512),
ConvParam(stride=1, depth=512),
ConvParam(stride=1, depth=512),
ConvParam(stride=1, depth=512),
ConvParam(stride=2, depth=1024),
ConvParam(stride=1, depth=1024)
]

# depthwise separable Conv2d
in_channels = 32
for i, param in enumerate(conv_params):
if 1 <= num_feature_layers <= i + 1:
break
            i = i + 1  # make the layer index start from 1 to follow the TensorFlow MobileNet format.
# with groups=output_channels, Conv2d is a depthwise Conv2d.
feature_layers[f'Conv2d_{i}_depthwise'] = nn.Conv2d(in_channels, in_channels,
kernel_size=(3, 3),
stride=param.stride,
padding=1,
groups=in_channels,
bias=False)
feature_layers[f'Conv2d_{i}_depthwise/BatchNorm'] = nn.BatchNorm2d(in_channels, eps=1e-03)
feature_layers[f'Conv2d_{i}_depthwise/Relu6'] = nn.ReLU6(inplace=True)

# pointwise Conv2d
out_channels = max(int(param.depth * depth_multiplier), min_depth)
            feature_layers[f'Conv2d_{i}_pointwise'] = nn.Conv2d(in_channels, out_channels,
                                                                kernel_size=(1, 1), stride=1, bias=False)
            feature_layers[f'Conv2d_{i}_pointwise/BatchNorm'] = nn.BatchNorm2d(out_channels, eps=1e-03)
            feature_layers[f'Conv2d_{i}_pointwise/Relu6'] = nn.ReLU6(inplace=True)
in_channels = out_channels

self.features = nn.Sequential(feature_layers)
self.classifier = nn.Conv2d(in_channels, num_classes, kernel_size=(1, 1), stride=1)

def forward(self, x):
x = self.features(x)
_, _, height, width = x.size()
# the kernel size 7x7 is for 224x224 inputs of ImageNet images
kernel_size = (min(height, 7), min(width, 7))
x = F.avg_pool2d(x, kernel_size=kernel_size)
        # F.dropout2d expects a drop probability, so convert from the keep probability,
        # and apply dropout only in training mode.
        x = F.dropout2d(x, p=1.0 - self.dropout_keep_prob, training=self.training)
x = self.classifier(x)
x = x.view(-1, self.num_classes)
return x
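
The num_feature_layers argument stops building the depthwise-separable stack early, which is useful when only the convolutional features are wanted. A minimal sketch, assuming the constructor above; the cut point and input size are illustrative:

import torch
from vision.nn.mobilenet import MobileNetV1

# Per the break condition in __init__, blocks whose 1-based index is below
# num_feature_layers are built, so 12 keeps Conv2d_1 through Conv2d_11.
backbone = MobileNetV1(num_classes=1001, num_feature_layers=12)
backbone.eval()

with torch.no_grad():
    x = torch.randn(1, 3, 300, 300)   # inputs other than 224x224 also work
    features = backbone.features(x)   # run the conv stack only, skipping the classifier
print(features.size())                # torch.Size([1, 512, 19, 19]) for this configuration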
