Also support MobileNetV3-Small
aclex committed Apr 5, 2020
1 parent ff49219 commit 8dffd20
Showing 3 changed files with 56 additions and 7 deletions.
11 changes: 7 additions & 4 deletions train_ssd.py
@@ -14,7 +14,7 @@
from vision.ssd.mobilenetv1_ssd import create_mobilenetv1_ssd
from vision.ssd.mobilenetv1_ssd_lite import create_mobilenetv1_ssd_lite
from vision.ssd.mobilenet_v2_ssd_lite import create_mobilenetv2_ssd_lite
from vision.ssd.mobilenetv3_ssd_lite import create_mobilenetv3_ssd_lite
from vision.ssd.mobilenetv3_ssd_lite import create_mobilenetv3_large_ssd_lite, create_mobilenetv3_small_ssd_lite
from vision.ssd.squeezenet_ssd_lite import create_squeezenet_ssd_lite
from vision.datasets.voc_dataset import VOCDataset
from vision.datasets.open_images import OpenImagesDataset
@@ -37,7 +37,7 @@


parser.add_argument('--net', default="vgg16-ssd",
                    help="The network architecture, it can be mb1-ssd, mb1-lite-ssd, mb2-ssd-lite, mb3-ssd-lite or vgg16-ssd.")
                    help="The network architecture, it can be mb1-ssd, mb1-lite-ssd, mb2-ssd-lite, mb3-large-ssd-lite, mb3-small-ssd-lite or vgg16-ssd.")
parser.add_argument('--freeze_base_net', action='store_true',
                    help="Freeze base net layers.")
parser.add_argument('--freeze_net', action='store_true',
@@ -187,8 +187,11 @@ def test(loader, net, criterion, device):
elif args.net == 'mb2-ssd-lite':
    create_net = lambda num: create_mobilenetv2_ssd_lite(num, width_mult=args.mb2_width_mult)
    config = mobilenetv1_ssd_config
elif args.net == 'mb3-ssd-lite':
    create_net = lambda num: create_mobilenetv3_ssd_lite(num)
elif args.net == 'mb3-large-ssd-lite':
    create_net = lambda num: create_mobilenetv3_large_ssd_lite(num)
    config = mobilenetv1_ssd_config
elif args.net == 'mb3-small-ssd-lite':
    create_net = lambda num: create_mobilenetv3_small_ssd_lite(num)
    config = mobilenetv1_ssd_config
else:
    logging.fatal("The net type is wrong.")
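
For reference, a minimal sketch (not part of the commit) of what the two new --net branches amount to once argparse has run. mobilenetv1_ssd_config is the prior-box config module train_ssd.py already imports for the other MobileNet variants (path assumed from those existing imports), and build is a hypothetical helper name.

from vision.ssd.config import mobilenetv1_ssd_config
from vision.ssd.mobilenetv3_ssd_lite import (
    create_mobilenetv3_large_ssd_lite,
    create_mobilenetv3_small_ssd_lite,
)

# Map the new --net flag values to their factory functions.
create_net_by_flag = {
    'mb3-large-ssd-lite': create_mobilenetv3_large_ssd_lite,
    'mb3-small-ssd-lite': create_mobilenetv3_small_ssd_lite,
}

def build(net_flag, num_classes):
    # Both MobileNetV3 branches reuse the MobileNetV1 SSD prior-box config, as in the elif chain above.
    return create_net_by_flag[net_flag](num_classes), mobilenetv1_ssd_config
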
12 changes: 12 additions & 0 deletions vision/nn/mobilenetv3.py
@@ -149,9 +149,15 @@ def forward(self, x):
class MobileNetV3_Small(nn.Module):
    def __init__(self, num_classes=1000):
        super(MobileNetV3_Small, self).__init__()

        self.features = []

        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1, bias=False)
        self.features.append(self.conv1)
        self.bn1 = nn.BatchNorm2d(16)
        self.features.append(self.bn1)
        self.hs1 = hswish()
        self.features.append(self.hs1)

        self.bneck = nn.Sequential(
            Block(3, 16, 16, 16, nn.ReLU(inplace=True), SeModule(16), 2),
@@ -167,16 +173,22 @@ def __init__(self, num_classes=1000):
            Block(5, 96, 576, 96, hswish(), SeModule(96), 1),
        )

        self.features.extend([block for block in self.bneck])

        self.conv2 = nn.Conv2d(96, 576, kernel_size=1, stride=1, padding=0, bias=False)
        self.features.append(self.conv2)
        self.bn2 = nn.BatchNorm2d(576)
        self.features.append(self.bn2)
        self.hs2 = hswish()
        self.features.append(self.hs2)
        self.linear3 = nn.Linear(576, 1280)
        self.bn3 = nn.BatchNorm1d(1280)
        self.hs3 = hswish()
        self.linear4 = nn.Linear(1280, num_classes)
        self.init_params()

        self.features = nn.Sequential(*self.features)

    def init_params(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
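
Collecting the stem, the bneck blocks, and the final 1x1 projection into a flat features Sequential mirrors what MobileNetV3_Large already exposes, and is what lets the SSD wrapper tap intermediate activations by plain integer indexes. A rough sketch of that consumption pattern, simplified and not taken verbatim from the repository's SSD class:

import torch
from torch import nn

def run_with_taps(features: nn.Sequential, x: torch.Tensor, source_layer_indexes):
    # Walk the flat feature list and record the activation each time a tap index is reached.
    taps = []
    for i, layer in enumerate(features):
        x = layer(x)
        if (i + 1) in source_layer_indexes:
            taps.append(x)  # one feature map per SSD detection head
    return taps, x

With the small backbone, source_layer_indexes of [ 11, 16 ] (used by the new factory below) would tap one intermediate bneck activation and one near the end of the feature stack before the extras take over.
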
40 changes: 37 additions & 3 deletions vision/ssd/mobilenetv3_ssd_lite.py
@@ -1,7 +1,7 @@
import torch
from torch.nn import Conv2d, Sequential, ModuleList, BatchNorm2d
from torch import nn
from ..nn.mobilenetv3 import MobileNetV3_Large, Block, hswish
from ..nn.mobilenetv3 import MobileNetV3_Large, MobileNetV3_Small, Block, hswish

from .ssd import SSD
from .predictor import Predictor
@@ -21,12 +21,46 @@ def SeperableConv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=
    )


def create_mobilenetv3_ssd_lite(num_classes, width_mult=1.0, use_batch_norm=True, onnx_compatible=False, is_test=False):
def create_mobilenetv3_large_ssd_lite(num_classes, width_mult=1.0, use_batch_norm=True, onnx_compatible=False, is_test=False):
    base_net = MobileNetV3_Large().features

    source_layer_indexes = [ 16, 20 ]
    extras = ModuleList([
        Block(3, 1280, 512, 256, hswish(), stride=2),
        Block(3, 960, 512, 256, hswish(), stride=2),
        Block(3, 512, 256, 128, hswish(), stride=2),
        Block(3, 256, 256, 128, hswish(), stride=2),
        Block(3, 256, 64, 64, hswish(), stride=2)
    ])

    regression_headers = ModuleList([
        SeperableConv2d(in_channels=round(576 * width_mult), out_channels=6 * 4,
                        kernel_size=3, padding=1, onnx_compatible=False),
        SeperableConv2d(in_channels=1280, out_channels=6 * 4, kernel_size=3, padding=1, onnx_compatible=False),
        SeperableConv2d(in_channels=512, out_channels=6 * 4, kernel_size=3, padding=1, onnx_compatible=False),
        SeperableConv2d(in_channels=256, out_channels=6 * 4, kernel_size=3, padding=1, onnx_compatible=False),
        SeperableConv2d(in_channels=256, out_channels=6 * 4, kernel_size=3, padding=1, onnx_compatible=False),
        Conv2d(in_channels=64, out_channels=6 * 4, kernel_size=1),
    ])

    classification_headers = ModuleList([
        SeperableConv2d(in_channels=round(576 * width_mult), out_channels=6 * num_classes, kernel_size=3, padding=1),
        SeperableConv2d(in_channels=1280, out_channels=6 * num_classes, kernel_size=3, padding=1),
        SeperableConv2d(in_channels=512, out_channels=6 * num_classes, kernel_size=3, padding=1),
        SeperableConv2d(in_channels=256, out_channels=6 * num_classes, kernel_size=3, padding=1),
        SeperableConv2d(in_channels=256, out_channels=6 * num_classes, kernel_size=3, padding=1),
        Conv2d(in_channels=64, out_channels=6 * num_classes, kernel_size=1),
    ])

    return SSD(num_classes, base_net, source_layer_indexes,
               extras, classification_headers, regression_headers, is_test=is_test, config=config)
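
The repeated 6 * 4 and 6 * num_classes output widths above come from the number of prior boxes assumed per spatial location. A small worked check, with illustrative numbers only (the 21-class count is a hypothetical VOC-style setup, not from the commit):

priors_per_location = 6
box_coords = 4                                 # box offsets predicted per prior
num_classes = 21                               # e.g. 20 VOC classes + background
reg_out = priors_per_location * box_coords     # 24 output channels per regression head
cls_out = priors_per_location * num_classes    # 126 output channels per classification head
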


def create_mobilenetv3_small_ssd_lite(num_classes, width_mult=1.0, use_batch_norm=True, onnx_compatible=False, is_test=False):
    base_net = MobileNetV3_Small().features

    source_layer_indexes = [ 11, 16 ]
    extras = ModuleList([
        Block(3, 576, 512, 256, hswish(), stride=2),
        Block(3, 512, 256, 128, hswish(), stride=2),
        Block(3, 256, 256, 128, hswish(), stride=2),
        Block(3, 256, 64, 64, hswish(), stride=2)
…
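
A hedged usage sketch, not taken from the commit: constructing the new Small variant directly, assuming a hypothetical 21-class (VOC plus background) setup and assuming the module defines its prior-box config the same way the Large variant does. Training itself still goes through train_ssd.py with --net mb3-small-ssd-lite.

from vision.ssd.mobilenetv3_ssd_lite import create_mobilenetv3_small_ssd_lite

net = create_mobilenetv3_small_ssd_lite(21)   # 21 classes: hypothetical VOC + background
print(sum(p.numel() for p in net.parameters()), "parameters")
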
