forked from nimashoghi/smart-quantization
cifar_base.py
from abc import abstractmethod
from argparse import ArgumentParser
from typing import Optional

from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader
from torchvision import transforms
# Normalization uses the standard per-channel CIFAR mean and std.
test_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261)),
    ]
)

# Training adds standard CIFAR augmentation on top of the test transform.
transform = transforms.Compose(
    [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        test_transform,
    ]
)


class CIFARBaseDataModule(LightningDataModule):
    @staticmethod
    def add_argparse_args(parent_parser: ArgumentParser):
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument("--val_batch_size", type=int, help="validation batch size")
        return parser
    def __init__(self, hparams):
        super().__init__()
        self.hparams.update(hparams)
        # Default the validation batch size to a quarter of the training
        # batch size (at least 1) when it is not set on the command line.
        if self.hparams.val_batch_size is None:
            self.hparams.val_batch_size = max(self.hparams.batch_size // 4, 1)
    @abstractmethod
    def make_dataset(self, *args, **kwargs):
        raise NotImplementedError
    def setup(self, stage: Optional[str] = None):
        if stage == "fit" or stage is None:
            self.cifar_train = self.make_dataset(
                "train", transform=transform, train=True, download=True
            )
            self.cifar_val = self.make_dataset(
                "test", transform=test_transform, train=False, download=True
            )
        if stage == "test" or stage is None:
            self.cifar_test = self.make_dataset(
                "test", transform=test_transform, train=False, download=True
            )
    def train_dataloader(self):
        return DataLoader(
            self.cifar_train,
            batch_size=self.hparams.batch_size,
            shuffle=True,
            num_workers=8,
            pin_memory=True,
        )

    def val_dataloader(self):
        return DataLoader(
            self.cifar_val,
            batch_size=self.hparams.val_batch_size,
            shuffle=False,
            num_workers=8,
            pin_memory=True,
        )

    def test_dataloader(self):
        return DataLoader(
            self.cifar_test,
            batch_size=self.hparams.val_batch_size,
            shuffle=False,
            num_workers=8,
            pin_memory=True,
        )
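

# Usage sketch (not part of the original file): a hypothetical concrete
# subclass showing how make_dataset might be implemented with torchvision's
# built-in CIFAR10 dataset. The name CIFAR10DataModule and the hparams dict
# below are assumptions for illustration, not the repository's actual wiring.
from torchvision.datasets import CIFAR10


class CIFAR10DataModule(CIFARBaseDataModule):
    def make_dataset(self, root, **kwargs):
        # `root` receives "train" or "test" from setup(); the remaining
        # keyword arguments (transform, train, download) pass straight
        # through to the torchvision dataset.
        return CIFAR10(root, **kwargs)


# dm = CIFAR10DataModule({"batch_size": 128, "val_batch_size": None})
# dm.setup("fit")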