Skip to content
This repository has been archived by the owner on Oct 6, 2024. It is now read-only.

Commit

Permalink
improvement(zcls): 1. 更新resnet测试;
Browse files Browse the repository at this point in the history
2. fix预训练模型num_classes不一致出错问题
  • Loading branch information
zjykzj committed Nov 27, 2020
1 parent d117bef commit 0de4046
Show file tree
Hide file tree
Showing 8 changed files with 136 additions and 14 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
.idea/
*.pyc
outputs/
60 changes: 60 additions & 0 deletions configs/r50_custom_pretrained_cifar100_224.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
NUM_GPUS: 1
NUM_NODES: 1
RANK_ID: 0
DIST_BACKEND: "nccl"
RNG_SEED: 1
OUTPUT_DIR: 'outputs/r50_custom_pretrained_cifar100_224'
TRAIN:
LOG_STEP: 10
MAX_EPOCH: 200
SAVE_EPOCH: 5
EVAL_EPOCH: 5
RESUME: False
USE_TENSORBOARD: True
DATASET:
NAME: 'CIFAR100'
DATA_DIR: './data/cifar'
TRANSFORM:
MEAN: (0.5071, 0.4865, 0.4409)
STD: (0.1942, 0.1918, 0.1958)
TRAIN:
SHORTER_SIDE: 224
CENTER_CROP: True
TRAIN_CROP_SIZE: 224
TEST:
SHORTER_SIDE: 224
CENTER_CROP: True
TEST_CROP_SIZE: 224
DATALOADER:
TRAIN_BATCH_SIZE: 96
TEST_BATCH_SIZE: 96
NUM_WORKERS: 8
MODEL:
NAME: 'ResNet'
PRETRAINED: ''
TORCHVISION_PRETRAINED: True
SYNC_BN: False
BACKBONE:
ARCH: 'resnet50'
HEAD:
FEATURE_DIMS: 2048
NUM_CLASSES: 100
RECOGNIZER:
NAME: 'ResNet_Custom'
CRITERION:
NAME: 'CrossEntropyLoss'
OPTIMIZER:
NAME: 'SGD'
LR: 1e-3
WEIGHT_DECAY: 1e-5
SGD:
MOMENTUM: 0.9
LR_SCHEDULER:
NAME: 'MultiStepLR'
IS_WARMUP: True
GAMMA: 0.1
MULTISTEP_LR:
MILESTONES: [ 100, 150, 175 ]
WARMUP:
ITERATION: 5
MULTIPLIER: 1.0
2 changes: 1 addition & 1 deletion configs/r50_pytorch_pretrained_cifar100_224.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ MODEL:
HEAD:
NUM_CLASSES: 100
RECOGNIZER:
NAME: 'R50_Pytorch'
NAME: 'ResNet_Pytorch'
CRITERION:
NAME: 'CrossEntropyLoss'
OPTIMIZER:
Expand Down
9 changes: 7 additions & 2 deletions tests/test_model/test_resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,23 +8,28 @@
"""

import torch
from torchvision.models import resnet50

from zcls.model.recognizers.resnet_recognizer import ResNetRecognizer


def test_resnet():
model = ResNetRecognizer(
arch=50,
arch="resnet50",
feature_dims=2048,
num_classes=1000
)
print(model)

data = torch.randn(1, 3, 224, 224)
outputs = model(data)
outputs = model(data)['probs']
print(outputs.shape)

assert outputs.shape == (1, 1000)

model = resnet50()
print(model)


if __name__ == '__main__':
test_resnet()
1 change: 1 addition & 0 deletions zcls/config/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def add_config(_C):
_C.MODEL.PRETRAINED = ""
_C.MODEL.TORCHVISION_PRETRAINED = False
_C.MODEL.SYNC_BN = False
_C.MODEL.GROUPS = 3

_C.MODEL.BACKBONE = CN()
# for ResNet
Expand Down
20 changes: 20 additions & 0 deletions zcls/model/batchnorm_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
"""

import torch.nn as nn
from functools import partial


def convert_sync_bn(model, process_group, device):
Expand All @@ -19,3 +20,22 @@ def convert_sync_bn(model, process_group, device):
setattr(model, child_name, m)
else:
convert_sync_bn(child, process_group, device)


def get_norm(cfg):
"""
Args:
cfg (CfgNode): model building configs, details are in the comments of
the config file.
Returns:
nn.Module: the normalization layer.
"""
if cfg.MODEL.NORM_TYPE == "batchnorm2d":
return nn.BatchNorm2d
elif cfg.MODEL.NORM_TYPE == "groupnorm":
num_groups = cfg.MODEL.GROUPS
return partial(nn.GroupNorm, num_groups=num_groups)
else:
raise NotImplementedError(
"Norm type {} is not supported".format(cfg.BN.NORM_TYPE)
)
19 changes: 19 additions & 0 deletions zcls/model/init_helper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# -*- coding: utf-8 -*-

"""
@date: 2020/11/26 下午10:47
@file: init_helper.py
@author: zj
@description:
"""

import math
from torch.nn import init


def reset_parameters(layer) -> None:
init.kaiming_uniform_(layer.weight, a=math.sqrt(5))
if layer.bias is not None:
fan_in, _ = init._calculate_fan_in_and_fan_out(layer.weight)
bound = 1 / math.sqrt(fan_in)
init.uniform_(layer.bias, -bound, bound)
38 changes: 27 additions & 11 deletions zcls/model/recognizers/resnet_recognizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,11 @@ def __init__(self,
fix_bn=False,
partial_bn=False):
super(ResNetRecognizer, self).__init__()

self.num_classes = num_classes
self.fix_bn = fix_bn
self.partial_bn = partial_bn

block_layer, layer_blocks = arch_settings[arch]

self.backbone = ResNetBackbone(
Expand All @@ -57,20 +62,20 @@ def __init__(self,
)
self.head = ResNetHead(
feature_dims=feature_dims,
num_classes=num_classes
num_classes=1000
)

self._init_weights(arch=arch, pretrained=torchvision_pretrained)
self.fix_bn = fix_bn
self.partial_bn = partial_bn

def _init_weights(self, arch='resnet18', pretrained=False):
if pretrained:
state_dict = load_state_dict_from_url(model_urls[arch], progress=True)
res = self.backbone.load_state_dict(state_dict, strict=False)
print(res)
res = self.head.load_state_dict(state_dict, strict=False)
print(res)
self.backbone.load_state_dict(state_dict, strict=False)
self.head.load_state_dict(state_dict, strict=False)
if self.num_classes != 1000:
fc = self.head.fc
fc_features = fc.in_features
self.head.fc = nn.Linear(fc_features, self.num_classes)

def freezing_bn(self: T) -> None:
count = 0
Expand Down Expand Up @@ -109,14 +114,25 @@ def __init__(self,
fix_bn=False,
partial_bn=False):
super(ResNet_Pytorch, self).__init__()

self.num_classes = num_classes
self.fix_bn = fix_bn
self.partial_bn = partial_bn

if arch == 'resnet18':
self.model = resnet.resnet18(pretrained=torchvision_pretrained, num_classes=num_classes)
self.model = resnet.resnet18(pretrained=torchvision_pretrained, num_classes=1000)
elif arch == 'resnet50':
self.model = resnet.resnet50(pretrained=torchvision_pretrained, num_classes=num_classes)
self.model = resnet.resnet50(pretrained=torchvision_pretrained, num_classes=1000)
else:
raise ValueError('no such value')
self.fix_bn = fix_bn
self.partial_bn = partial_bn

self._init_weights()

def _init_weights(self):
if self.num_classes != 1000:
fc = self.model.fc
fc_features = fc.in_features
self.model.fc = nn.Linear(fc_features, self.num_classes)

def freezing_bn(self):
count = 0
Expand Down

0 comments on commit 0de4046

Please sign in to comment.