forked from nimashoghi/smart-quantization
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcifar10.py
26 lines (21 loc) · 791 Bytes
/
cifar10.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
from argparse import ArgumentParser
from smart_compress.data.cifar_base import CIFARBaseDataModule
from torchvision.datasets import CIFAR10
class CIFAR10DataModule(CIFARBaseDataModule):
    """CIFAR-10 data module layered on top of ``CIFARBaseDataModule``.

    Contributes a ``--batch_size`` command-line option and builds the
    torchvision ``CIFAR10`` dataset for each requested split name.
    """

    @staticmethod
    def add_argparse_args(parent_parser: ArgumentParser):
        """Return a parser with the base CIFAR options plus ``--batch_size``.

        The parser returned by ``CIFARBaseDataModule.add_argparse_args`` is
        used as a parent, so all of its options are inherited. ``add_help``
        is disabled on the new parser so its ``-h`` does not collide with
        one a parent may already define.

        Args:
            parent_parser: Parser forwarded to the base class hook.

        Returns:
            A new ``ArgumentParser`` whose ``--batch_size`` is an ``int``
            defaulting to 8.
        """
        base_parser = CIFARBaseDataModule.add_argparse_args(parent_parser)
        parser = ArgumentParser(add_help=False, parents=[base_parser])
        parser.add_argument("--batch_size", type=int, default=8, help="batch size")
        return parser

    def __init__(self, *args, **kwargs):
        # Pure pass-through to the base class. It is kept explicit because
        # removing it would change the introspected signature of __init__,
        # which argparse-driven constructors may inspect — NOTE(review):
        # confirm before deleting.
        super().__init__(*args, **kwargs)

    def make_dataset(self, name, *args, **kwargs):
        """Construct a torchvision ``CIFAR10`` dataset for split ``name``.

        The dataset root is ``./datasets/cifar10/cifar10-<name>``, relative
        to the current working directory. Any extra positional and keyword
        arguments are forwarded unchanged to ``CIFAR10``. Presumably the base
        class passes ``train``/``transform``/``download`` here — TODO confirm.
        """
        root = f"./datasets/cifar10/cifar10-{name}"
        return CIFAR10(root, *args, **kwargs)