From 8dffd20ff0fc8af5780364163e12749134e66d7e Mon Sep 17 00:00:00 2001 From: Alexey Chernov <4ernov@gmail.com> Date: Sat, 4 Apr 2020 21:05:14 +0300 Subject: [PATCH] Also support MobileNetV3-Small --- train_ssd.py | 11 +++++--- vision/nn/mobilenetv3.py | 12 +++++++++ vision/ssd/mobilenetv3_ssd_lite.py | 40 +++++++++++++++++++++++++++--- 3 files changed, 56 insertions(+), 7 deletions(-) diff --git a/train_ssd.py b/train_ssd.py index 0e699955..c4ce8219 100644 --- a/train_ssd.py +++ b/train_ssd.py @@ -14,7 +14,7 @@ from vision.ssd.mobilenetv1_ssd import create_mobilenetv1_ssd from vision.ssd.mobilenetv1_ssd_lite import create_mobilenetv1_ssd_lite from vision.ssd.mobilenet_v2_ssd_lite import create_mobilenetv2_ssd_lite -from vision.ssd.mobilenetv3_ssd_lite import create_mobilenetv3_ssd_lite +from vision.ssd.mobilenetv3_ssd_lite import create_mobilenetv3_large_ssd_lite, create_mobilenetv3_small_ssd_lite from vision.ssd.squeezenet_ssd_lite import create_squeezenet_ssd_lite from vision.datasets.voc_dataset import VOCDataset from vision.datasets.open_images import OpenImagesDataset @@ -37,7 +37,7 @@ parser.add_argument('--net', default="vgg16-ssd", - help="The network architecture, it can be mb1-ssd, mb1-lite-ssd, mb2-ssd-lite, mb3-ssd-lite or vgg16-ssd.") + help="The network architecture, it can be mb1-ssd, mb1-lite-ssd, mb2-ssd-lite, mb3-large-ssd-lite, mb3-small-ssd-lite or vgg16-ssd.") parser.add_argument('--freeze_base_net', action='store_true', help="Freeze base net layers.") parser.add_argument('--freeze_net', action='store_true', @@ -187,8 +187,11 @@ def test(loader, net, criterion, device): elif args.net == 'mb2-ssd-lite': create_net = lambda num: create_mobilenetv2_ssd_lite(num, width_mult=args.mb2_width_mult) config = mobilenetv1_ssd_config - elif args.net == 'mb3-ssd-lite': - create_net = lambda num: create_mobilenetv3_ssd_lite(num) + elif args.net == 'mb3-large-ssd-lite': + create_net = lambda num: create_mobilenetv3_large_ssd_lite(num) + config = mobilenetv1_ssd_config + elif args.net == 'mb3-small-ssd-lite': + create_net = lambda num: create_mobilenetv3_small_ssd_lite(num) config = mobilenetv1_ssd_config else: logging.fatal("The net type is wrong.") diff --git a/vision/nn/mobilenetv3.py b/vision/nn/mobilenetv3.py index 91ed96e2..ffb925a9 100644 --- a/vision/nn/mobilenetv3.py +++ b/vision/nn/mobilenetv3.py @@ -149,9 +149,15 @@ def forward(self, x): class MobileNetV3_Small(nn.Module): def __init__(self, num_classes=1000): super(MobileNetV3_Small, self).__init__() + + self.features = [] + self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1, bias=False) + self.features.append(self.conv1) self.bn1 = nn.BatchNorm2d(16) + self.features.append(self.bn1) self.hs1 = hswish() + self.features.append(self.hs1) self.bneck = nn.Sequential( Block(3, 16, 16, 16, nn.ReLU(inplace=True), SeModule(16), 2), @@ -167,16 +173,22 @@ def __init__(self, num_classes=1000): Block(5, 96, 576, 96, hswish(), SeModule(96), 1), ) + self.features.extend([block for block in self.bneck]) self.conv2 = nn.Conv2d(96, 576, kernel_size=1, stride=1, padding=0, bias=False) + self.features.append(self.conv2) self.bn2 = nn.BatchNorm2d(576) + self.features.append(self.bn2) self.hs2 = hswish() + self.features.append(self.hs2) self.linear3 = nn.Linear(576, 1280) self.bn3 = nn.BatchNorm1d(1280) self.hs3 = hswish() self.linear4 = nn.Linear(1280, num_classes) self.init_params() + self.features = nn.Sequential(*self.features) + def init_params(self): for m in self.modules(): if isinstance(m, nn.Conv2d): diff --git a/vision/ssd/mobilenetv3_ssd_lite.py b/vision/ssd/mobilenetv3_ssd_lite.py index 14c57673..a2d851c5 100644 --- a/vision/ssd/mobilenetv3_ssd_lite.py +++ b/vision/ssd/mobilenetv3_ssd_lite.py @@ -1,7 +1,7 @@ import torch from torch.nn import Conv2d, Sequential, ModuleList, BatchNorm2d from torch import nn -from ..nn.mobilenetv3 import MobileNetV3_Large, Block, hswish +from ..nn.mobilenetv3 import MobileNetV3_Large, MobileNetV3_Small, Block, hswish from .ssd import SSD from .predictor import Predictor @@ -21,12 +21,46 @@ def SeperableConv2d(in_channels, out_channels, kernel_size=1, stride=1, padding= ) -def create_mobilenetv3_ssd_lite(num_classes, width_mult=1.0, use_batch_norm=True, onnx_compatible=False, is_test=False): +def create_mobilenetv3_large_ssd_lite(num_classes, width_mult=1.0, use_batch_norm=True, onnx_compatible=False, is_test=False): base_net = MobileNetV3_Large().features source_layer_indexes = [ 16, 20 ] extras = ModuleList([ - Block(3, 1280, 512, 256, hswish(), stride=2), + Block(3, 960, 512, 256, hswish(), stride=2), + Block(3, 512, 256, 128, hswish(), stride=2), + Block(3, 256, 256, 128, hswish(), stride=2), + Block(3, 256, 64, 64, hswish(), stride=2) + ]) + + regression_headers = ModuleList([ + SeperableConv2d(in_channels=round(576 * width_mult), out_channels=6 * 4, + kernel_size=3, padding=1, onnx_compatible=False), + SeperableConv2d(in_channels=1280, out_channels=6 * 4, kernel_size=3, padding=1, onnx_compatible=False), + SeperableConv2d(in_channels=512, out_channels=6 * 4, kernel_size=3, padding=1, onnx_compatible=False), + SeperableConv2d(in_channels=256, out_channels=6 * 4, kernel_size=3, padding=1, onnx_compatible=False), + SeperableConv2d(in_channels=256, out_channels=6 * 4, kernel_size=3, padding=1, onnx_compatible=False), + Conv2d(in_channels=64, out_channels=6 * 4, kernel_size=1), + ]) + + classification_headers = ModuleList([ + SeperableConv2d(in_channels=round(576 * width_mult), out_channels=6 * num_classes, kernel_size=3, padding=1), + SeperableConv2d(in_channels=1280, out_channels=6 * num_classes, kernel_size=3, padding=1), + SeperableConv2d(in_channels=512, out_channels=6 * num_classes, kernel_size=3, padding=1), + SeperableConv2d(in_channels=256, out_channels=6 * num_classes, kernel_size=3, padding=1), + SeperableConv2d(in_channels=256, out_channels=6 * num_classes, kernel_size=3, padding=1), + Conv2d(in_channels=64, out_channels=6 * num_classes, kernel_size=1), + ]) + + return SSD(num_classes, base_net, source_layer_indexes, + extras, classification_headers, regression_headers, is_test=is_test, config=config) + + +def create_mobilenetv3_small_ssd_lite(num_classes, width_mult=1.0, use_batch_norm=True, onnx_compatible=False, is_test=False): + base_net = MobileNetV3_Small().features + + source_layer_indexes = [ 11, 16 ] + extras = ModuleList([ + Block(3, 576, 512, 256, hswish(), stride=2), Block(3, 512, 256, 128, hswish(), stride=2), Block(3, 256, 256, 128, hswish(), stride=2), Block(3, 256, 64, 64, hswish(), stride=2)