-
Notifications
You must be signed in to change notification settings - Fork 1
/
vgg.py
96 lines (80 loc) · 3.28 KB
/
vgg.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
'''
A PyTorch implementation of VGGNet.
The original paper can be found at https://arxiv.org/abs/1409.1556.
'''
import torch
import torch.nn as nn
import numpy as np
from .activations import activetion_func
class VGGBlock(nn.Module):
    '''
    One VGG stage: a stack of 3x3 convolutions followed by a 2x2 max-pool.

    Every convolution uses ``kernel_size=3`` and ``padding=1`` and is followed
    by the activation module. The stage ends with
    ``nn.MaxPool2d(kernel_size=2, stride=2)``.

    Args:
        num_convs: Number of convolutions in the stage. The first convolution
            is always created, so values below 1 still yield one convolution.
        in_channels: Input channels of the first convolution.
        out_channels: Output channels of every convolution in the stage.
        activation: Name handed to ``activetion_func``; presumably it selects
            which activation module is built (confirm against
            ``activations.py``).
    '''

    def __init__(self,
                 num_convs,
                 in_channels,
                 out_channels,
                 activation='relu'):
        super().__init__()
        # A single activation object is created and the same instance is
        # placed after every convolution below.
        self.activation = activetion_func(activation)

        # The first convolution maps in_channels -> out_channels.
        layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            self.activation,
        ]

        # The remaining (num_convs - 1) convolutions keep the channel count.
        extra_convs = range(1, num_convs) if num_convs > 1 else ()
        for _ in extra_convs:
            layers += [
                nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
                self.activation,
            ]

        layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        '''Run ``x`` through the conv/activation stack and the final pool.'''
        return self.net(x)
class VGG(nn.Module):
    '''
    VGG-style classifier: a sequence of ``VGGBlock`` stages followed by a
    three-layer fully connected head with dropout.

    Args:
        conv_arch: Iterable of ``(num_convs, in_channels, out_channels)``
            tuples, one per ``VGGBlock`` stage, in order.
        image_size: Side length of the (square) input used to probe the
            flattened feature size that feeds the first linear layer.
        num_hiddens: Width of the two hidden linear layers.
        activation: Name handed to ``activetion_func`` and to every
            ``VGGBlock``.
        num_classes: Number of outputs of the final linear layer.
    '''

    def __init__(self,
                 conv_arch,
                 image_size=224,
                 num_hiddens=4096,
                 activation='relu',
                 num_classes=10):
        super().__init__()
        self.image_size = image_size
        self.activation = activetion_func(activation)

        self.conv = nn.Sequential()
        # Channel count of the dummy probe input; taken from the first stage
        # so architectures that do not start from 3 channels still build.
        probe_channels = 3
        for i, (num_convs, in_channels, out_channels) in enumerate(conv_arch):
            if i == 0:
                probe_channels = in_channels
            self.conv.add_module(
                'vgg_block_' + str(i),
                VGGBlock(num_convs, in_channels, out_channels, activation))

        feature_size = self._get_feature_size(self.conv, probe_channels)
        # Cast to a plain int so nn.Linear does not receive a NumPy scalar.
        num_features = int(np.prod(feature_size))

        self.linear = nn.Sequential(
            nn.Linear(num_features, num_hiddens), self.activation,
            nn.Dropout(0.5),
            nn.Linear(num_hiddens, num_hiddens), self.activation,
            nn.Dropout(0.5),
            # Previously hard-coded to 10, which ignored ``num_classes``.
            nn.Linear(num_hiddens, num_classes))

    def _get_feature_size(self, net, in_channels=3):
        '''
        Return the per-sample output shape of ``net`` for a zero input of
        shape ``(1, in_channels, image_size, image_size)``.

        Args:
            net: Module to probe.
            in_channels: Channel count of the dummy input (default 3).

        Returns:
            ``torch.Size`` of the output with the batch dimension dropped.
        '''
        x = torch.zeros(1, in_channels, self.image_size, self.image_size)
        # Shape probing only; no autograd graph is needed.
        with torch.no_grad():
            out = net(x)
        return out.size()[1:]

    def forward(self, x):
        '''
        Apply the convolutional stages, flatten everything but the first
        dimension, and apply the fully connected head.

        ``x`` is presumably ``(batch, channels, image_size, image_size)``;
        the spatial size must match the one used at construction so the
        flattened width fits the first linear layer.
        '''
        out = self.conv(x)
        out = torch.flatten(out, 1)
        return self.linear(out)
def vgg11(image_size=32, ratio=8, activation='relu', num_classes=10):
    '''
    Build a VGG-11 layout whose channel widths and hidden size are divided
    by ``ratio``.

    The five stages use 1, 1, 2, 2 and 2 convolutions with base widths
    64, 128, 256, 512 and 512; the first stage takes 3 input channels.

    Args:
        image_size: Forwarded to ``VGG`` as ``image_size``.
        ratio: Integer divisor applied (floor division) to every base
            channel width and to the base hidden size of 4096.
        activation: Forwarded to ``VGG`` as ``activation``.
        num_classes: Forwarded to ``VGG`` as ``num_classes``.

    Returns:
        The constructed ``VGG`` instance.
    '''
    c64, c128, c256, c512 = (base // ratio for base in (64, 128, 256, 512))

    # (num_convs, in_channels, out_channels) per stage.
    conv_arch = (
        (1, 3, c64),
        (1, c64, c128),
        (2, c128, c256),
        (2, c256, c512),
        (2, c512, c512),
    )

    return VGG(conv_arch,
               image_size,
               4096 // ratio,
               activation=activation,
               num_classes=num_classes)
if __name__ == "__main__":
    # Report parameter count and MACs of a full-width (ratio=1) VGG-11
    # on 32x32 inputs with 3 channels, using ptflops.
    from ptflops import get_model_complexity_info

    image_size = 32
    input_shape = (3, image_size, image_size)
    net = vgg11(activation='relu', ratio=1, image_size=image_size)

    macs, params = get_model_complexity_info(
        net,
        input_shape,
        as_strings=True,
        print_per_layer_stat=True,
        verbose=True,
    )

    summary = (
        ('Number of parameters: ', params),
        ('Computational complexity: ', macs),
    )
    for label, value in summary:
        print('{:<30} {:<8}'.format(label, value))