diff --git a/translate_tf_mobilenetv1.py b/translate_tf_mobilenetv1.py new file mode 100644 index 00000000..4c90e59f --- /dev/null +++ b/translate_tf_mobilenetv1.py @@ -0,0 +1,65 @@ +import torch +import sys + +from vision.nn.mobilenet import MobileNetV1 +from extract_tf_weights import read_weights + + +def fill_weights_torch_model(weights, state_dict): + for name in state_dict: + if name == 'classifier.weight': + weight = weights['MobilenetV1/Logits/Conv2d_1c_1x1/weights'] + weight = torch.tensor(weight, dtype=torch.float32).permute(3, 2, 0, 1) + assert state_dict[name].size() == weight.size() + state_dict[name] = weight + elif name == 'classifier.bias': + bias = weights['MobilenetV1/Logits/Conv2d_1c_1x1/biases'] + bias = torch.tensor(bias, dtype=torch.float32) + assert state_dict[name].size() == bias.size() + state_dict[name] = bias + elif name.endswith('BatchNorm.weight'): + key = name.replace("features", "MobilenetV1").replace(".", "/").replace('BatchNorm/weight', 'BatchNorm/gamma') + weight = torch.tensor(weights[key], dtype=torch.float32) + assert weight.size() == state_dict[name].size() + state_dict[name] = weight + elif name.endswith('BatchNorm.bias'): + key = name.replace("features", "MobilenetV1").replace(".", "/").replace('BatchNorm/bias', 'BatchNorm/beta') + bias = torch.tensor(weights[key], dtype=torch.float32) + assert bias.size() == state_dict[name].size() + state_dict[name] = bias + elif name.endswith('running_mean'): + key = name.replace("features", "MobilenetV1").replace(".", "/").replace('running_mean', 'moving_mean') + running_mean = torch.tensor(weights[key], dtype=torch.float32) + assert running_mean.size() == state_dict[name].size() + state_dict[name] = running_mean + elif name.endswith('running_var'): + key = name.replace("features", "MobilenetV1").replace(".", "/").replace('running_var', 'moving_variance') + running_var = torch.tensor(weights[key], dtype=torch.float32) + assert running_var.size() == state_dict[name].size() + state_dict[name] = running_var + elif name.endswith('depthwise.weight'): + key = name.replace("features", "MobilenetV1").replace(".", "/") + key = key.replace('depthwise/weight', 'depthwise/depthwise_weights') + weight = torch.tensor(weights[key], dtype=torch.float32).permute(2, 3, 0, 1) + assert weight.size() == state_dict[name].size() + state_dict[name] = weight + else: + key = name.replace("features", "MobilenetV1").replace(".", "/").replace('weight', 'weights') + weight = torch.tensor(weights[key], dtype=torch.float32).permute(3, 2, 0, 1) + assert weight.size() == state_dict[name].size() + state_dict[name] = weight + + +if __name__ == '__main__': + if len(sys.argv) < 3: + print("Usage: python translate_tf_modelnetv1.py ") + tf_model = sys.argv[1] + torch_weights_path = sys.argv[2] + print("Extract weights from tf model.") + weights = read_weights(tf_model) + + net = MobileNetV1(1001) + states = net.state_dict() + print("Translate tf weights.") + fill_weights_torch_model(weights, states) + torch.save(states, torch_weights_path) \ No newline at end of file diff --git a/vision/nn/mobilenet.py b/vision/nn/mobilenet.py new file mode 100644 index 00000000..d10c7009 --- /dev/null +++ b/vision/nn/mobilenet.py @@ -0,0 +1,78 @@ +import torch.nn as nn +import torch.nn.functional as F +from collections import OrderedDict, namedtuple + + +ConvParam = namedtuple('ConvParam', ['stride', 'depth']) + + +class MobileNetV1(nn.Module): + def __init__(self, num_classes, dropout_keep_prob=0.5, depth_multiplier=1.0, + min_depth=8, num_feature_layers=0, in_channels=3): + super(MobileNetV1, self).__init__() + + self.num_classes = num_classes + self.depth_multiplier = depth_multiplier + self.dropout_keep_prob = dropout_keep_prob + + # inital normal conv2d, Conv2d_0 layer + feature_layers = OrderedDict([ + ('Conv2d_0', nn.Conv2d(in_channels, 32, kernel_size=(3, 3), stride=2, padding=1, bias=False)), + ('Conv2d_0/BatchNorm', nn.BatchNorm2d(32, eps=1e-03)), + ('Conv2d_0/Relu6', nn.ReLU6(inplace=True)) + ]) + + # depthwise separable layer params + conv_params = [ + ConvParam(stride=1, depth=64), + ConvParam(stride=2, depth=128), + ConvParam(stride=1, depth=128), + ConvParam(stride=2, depth=256), + ConvParam(stride=1, depth=256), + ConvParam(stride=2, depth=512), + ConvParam(stride=1, depth=512), + ConvParam(stride=1, depth=512), + ConvParam(stride=1, depth=512), + ConvParam(stride=1, depth=512), + ConvParam(stride=1, depth=512), + ConvParam(stride=2, depth=1024), + ConvParam(stride=1, depth=1024) + ] + + # depthwise separable Conv2d + in_channels = 32 + for i, param in enumerate(conv_params): + if 1 <= num_feature_layers <= i + 1: + break + i = i + 1 # make the layer index start from 1 to follow tensorflow MobileNetformat. + # with groups=output_channels, Conv2d is a depthwise Conv2d. + feature_layers[f'Conv2d_{i}_depthwise'] = nn.Conv2d(in_channels, in_channels, + kernel_size=(3, 3), + stride=param.stride, + padding=1, + groups=in_channels, + bias=False) + feature_layers[f'Conv2d_{i}_depthwise/BatchNorm'] = nn.BatchNorm2d(in_channels, eps=1e-03) + feature_layers[f'Conv2d_{i}_depthwise/Relu6'] = nn.ReLU6(inplace=True) + + # pointwise Conv2d + out_channels = max(int(param.depth * depth_multiplier), min_depth) + feature_layers[f'Conv2d_{i}_pointwise'.format(i)] = nn.Conv2d(in_channels, out_channels, + kernel_size=(1, 1), stride=1, bias=False) + feature_layers[f'Conv2d_{i}_pointwise/BatchNorm'.format(i)] = nn.BatchNorm2d(out_channels, eps=1e-03) + feature_layers[f'Conv2d_{i}_pointwise/Relu6'.format(i)] = nn.ReLU6(inplace=True) + in_channels = out_channels + + self.features = nn.Sequential(feature_layers) + self.classifier = nn.Conv2d(in_channels, num_classes, kernel_size=(1, 1), stride=1) + + def forward(self, x): + x = self.features(x) + _, _, height, width = x.size() + # the kernel size 7x7 is for 224x224 inputs of ImageNet images + kernel_size = (min(height, 7), min(width, 7)) + x = F.avg_pool2d(x, kernel_size=kernel_size) + x = F.dropout2d(x, self.dropout_keep_prob) + x = self.classifier(x) + x = x.view(-1, self.num_classes) + return x