forked from nimashoghi/smart-quantization
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathresnet.py
44 lines (36 loc) · 1.36 KB
/
resnet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
from argparse import ArgumentParser
from torch.utils.checkpoint import checkpoint
import torchmetrics.functional as FM
from argparse_utils.mapping import mapping_action
from smart_compress.models.base import BaseModule
from smart_compress.models.pytorch.resnet import resnet18, resnet34, resnet50
class ResNetModule(BaseModule):
    """ResNet classifier built on the project's ``BaseModule``.

    The concrete architecture (ResNet-18/34/50) and the number of output
    classes are chosen via command-line hyperparameters, which
    ``add_argparse_args`` registers.
    """

    @staticmethod
    def add_argparse_args(parent_parser):
        """Return a parser with this module's options layered on top of BaseModule's.

        Args:
            parent_parser: parser handed to ``BaseModule.add_argparse_args``;
                the parser it returns becomes the parent of the new one.

        Returns:
            A new ``ArgumentParser`` (``add_help=False``) that adds
            ``--resnet_model`` (stored as ``resnet_model_fn``) and
            ``--num_classes`` (int, default 10).
        """
        base_parser = BaseModule.add_argparse_args(parent_parser)
        arg_parser = ArgumentParser(parents=[base_parser], add_help=False)

        # CLI name -> model factory. The factory ends up in
        # hparams.resnet_model_fn and is called in __init__.
        model_choices = {
            "resnet18": resnet18,
            "resnet34": resnet34,
            "resnet50": resnet50,
        }
        # NOTE(review): argparse does not run actions on default values.
        # Confirm that mapping_action converts the "resnet34" string default
        # into the factory. Otherwise resnet_model_fn would stay a str when
        # the flag is omitted.
        arg_parser.add_argument(
            "--resnet_model",
            action=mapping_action(model_choices),
            default="resnet34",
            dest="resnet_model_fn",
        )
        arg_parser.add_argument("--num_classes", type=int, default=10)
        return arg_parser

    def __init__(self, *args, **kwargs):
        """Initialise the base module, record hyperparameters and build the network."""
        super().__init__(*args, **kwargs)
        # Must run before any extra locals are introduced in this frame.
        # Presumably this is Lightning-style frame inspection of the init
        # arguments; confirm against BaseModule.
        self.save_hyperparameters()
        build_model = self.hparams.resnet_model_fn
        self.model = build_model(num_classes=self.hparams.num_classes)

    def forward(self, x):
        """Run the wrapped ResNet on ``x`` and return its raw output."""
        network = self.model
        return network(x)

    def accuracy_function(self, outputs, ground_truth):
        """Compute classification accuracy of ``outputs`` against ``ground_truth``.

        Args:
            outputs: model scores. The class axis is dim 1, since predictions
                are taken with ``argmax(dim=1)``. Presumably shaped
                (batch, num_classes); confirm against callers.
            ground_truth: target class indices passed straight to
                ``torchmetrics.functional.accuracy``.

        Returns:
            A dict with a single ``"accuracy"`` entry.
        """
        predicted_labels = outputs.argmax(dim=1)
        score = FM.accuracy(
            predicted_labels,
            ground_truth,
            num_classes=self.hparams.num_classes,
        )
        return {"accuracy": score}