# train.py (forked from nimashoghi/smart-quantization)

import inspect
import time
from argparse import ArgumentParser
from typing import Dict, List, Union

from argparse_utils import mapping_action
from pytorch_lightning import Trainer
from pytorch_lightning.loggers.test_tube import TestTubeLogger
from pytorch_lightning.plugins.training_type import DDPPlugin
from smart_compress.util.globals import Globals
from smart_compress.util.pytorch.autograd import register_autograd_module
from smart_compress.util.pytorch.hooks import register_global_hooks


def _default_name(
    args,
    data_structures=(
        "forward",
        "backward",
        "weights",
        "gradients",
        "momentum_vectors",
        "loss",
    ),
):
    """Build a default run name from the compression, model, dataset, and tags."""
tags = ",".join(
(
*(
(
data_structure
for data_structure in data_structures
if getattr(args, f"compress_{data_structure}", False)
)
if args.compress
else ()
),
)
)
return "-".join(
(
args.compression_cls.__name__,
args.model_cls.__name__.lower().replace("module", ""),
args.dataset_cls.__name__.lower().replace("datamodule", ""),
tags,
args.tags or "",
time.strftime("%Y%m%d_%H%M%S"),
)
).lower()
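
# For illustration only (hypothetical values): with SmartFP compression of a
# ResNetModule on CIFAR10DataModule and all default compress_* flags,
# _default_name produces something like
#   "smartfp-resnet-cifar10-forward,backward,weights,gradients,momentum_vectors--20210101_120000"
# (the double dash appears because --tags was not given).
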

def _add_arg_names(args):
    """Attach a `<name>_name` string alongside every `*_cls` / `*_fn` argument."""
    # Iterate over a copy: setattr() below mutates args.__dict__ during the loop.
    for name, value in dict(**vars(args)).items():
if value is None:
continue
if name.endswith("_cls"):
assert inspect.isclass(value), f"{name} is not a class"
setattr(
args,
f"{name}_name",
value.__name__
if value.__module__ is None
or value.__module__ == str.__class__.__module__
else f"{value.__module__}.{value.__name__}",
)
elif name.endswith("_fn"):
assert inspect.isfunction(value), f"{name} is not a function"
setattr(args, f"{name}_name", value.__name__)
return args
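
# For illustration only (hypothetical value): after _add_arg_names(args), an
# argument such as args.model_cls=ResNetModule gains a string twin, e.g.
#   args.model_cls_name == "smart_compress.models.resnet.ResNetModule"
# so the chosen class can be recorded as plain text.
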

def init_model_from_args(argv: Union[None, str, List[str]] = None):
    """Parse CLI arguments and build the (model, trainer, data) triple."""
    # Imported locally: compression algorithms, data modules, and model wrappers.
    from smart_compress.compress.base import CompressionAlgorithmBase
from smart_compress.compress.bf16 import BF16
from smart_compress.compress.fp8 import FP8
from smart_compress.compress.fp16 import FP16
from smart_compress.compress.fp32 import FP32
from smart_compress.compress.s2fp8 import S2FP8
from smart_compress.compress.smart import SmartFP
from smart_compress.data.cifar10 import CIFAR10DataModule
from smart_compress.data.cifar100 import CIFAR100DataModule
from smart_compress.data.glue import GLUEDataModule
from smart_compress.data.imdb import IMDBDataModule
from smart_compress.models.base import BaseModule
from smart_compress.models.bert import BertModule
from smart_compress.models.inception import InceptionModule
from smart_compress.models.resnet import ResNetModule

    # Accept a single space-separated string as well as a list of tokens.
    if isinstance(argv, str):
        argv = argv.split(" ")

parser = ArgumentParser()
parser.add_argument(
"--model",
action=mapping_action(
dict(bert=BertModule, inception=InceptionModule, resnet=ResNetModule)
),
default="resnet",
help="model name",
dest="model_cls",
)
parser.add_argument(
"--dataset",
action=mapping_action(
dict(
cifar10=CIFAR10DataModule,
cifar100=CIFAR100DataModule,
glue=GLUEDataModule,
imdb=IMDBDataModule,
)
),
default="cifar10",
help="dataset name",
dest="dataset_cls",
)
parser.add_argument("--no_compress", action="store_false", dest="compress")
parser.add_argument(
"--compress",
action=mapping_action(
dict(bf16=BF16, fp8=FP8, fp16=FP16, fp32=FP32, s2fp8=S2FP8, smart=SmartFP)
),
default="fp32",
dest="compression_cls",
)
parser.add_argument(
"--compression_hook_fn",
action=mapping_action(
dict(autograd=register_autograd_module, global_hook=register_global_hooks)
),
default="autograd",
)
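
    # Two hook-injection strategies; their implementations live in
    # smart_compress.util.pytorch (autograd-function wrapping vs. globally
    # registered module hooks), mirroring the import paths at the top of the file.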
    # Per-structure opt-outs: everything is compressed by default except the
    # loss, which is opt-in via --compress_loss below.
    parser.add_argument(
"--no_compress_forward",
action="store_false",
dest="compress_forward",
)
parser.add_argument(
"--no_compress_backward",
action="store_false",
dest="compress_backward",
)
parser.add_argument(
"--no_compress_weights",
action="store_false",
dest="compress_weights",
)
parser.add_argument(
"--no_compress_gradients",
action="store_false",
dest="compress_gradients",
)
parser.add_argument(
"--no_compress_momentum_vectors",
action="store_false",
dest="compress_momentum_vectors",
)
parser.add_argument(
"--compress_loss",
action="store_true",
dest="compress_loss",
)
parser.add_argument("--no_add_tags", action="store_false", dest="add_tags")
parser.add_argument("--name", required=False, type=str)
parser.add_argument("--logdir", default="lightning_logs", type=str)
parser.add_argument("--git", action="store_true")
parser.add_argument("--tags", required=False, type=str)
parser = Trainer.add_argparse_args(parser)
parser.set_defaults(terminate_on_nan=True)

    args, _ = parser.parse_known_args(argv)

    # Each model supports only certain datasets; fail fast on a mismatch.
    if args.model_cls in (BertModule,):
        assert args.dataset_cls in (GLUEDataModule, IMDBDataModule)
    elif args.model_cls in (ResNetModule, InceptionModule):
        assert args.dataset_cls in (CIFAR10DataModule, CIFAR100DataModule)
    else:
        raise ValueError("invalid model_cls")

    # Let the chosen compression, model, and dataset classes register their own
    # flags, then re-parse so those flags are honored.
    parser = args.compression_cls.add_argparse_args(parser)
parser = args.model_cls.add_argparse_args(parser)
parser = args.dataset_cls.add_argparse_args(parser)
args = parser.parse_args(argv)
args = _add_arg_names(args)
if not args.name:
args.name = _default_name(args)
elif args.tags:
args.name += f"-{args.tags}"
trainer = Trainer.from_argparse_args(
args,
logger=TestTubeLogger(args.logdir, name=args.name, create_git_tag=args.git),
# plugins=[DDPPlugin(find_unused_parameters=False)],
)
compression: CompressionAlgorithmBase = (
args.compression_cls(args) if args.compress else None
)
model: BaseModule = args.model_cls(compression=compression, **vars(args))
data = args.dataset_cls(model.hparams)

    def log_custom(metrics: Dict[str, float]):
        if not trainer.logger_connector.should_update_logs and not trainer.fast_dev_run:
            return
        trainer.logger.agg_and_log_metrics(metrics, model.global_step)

    # Route the compressor's logging through the model/trainer loggers. Guard
    # against --no_compress, where `compression` is None.
    if compression is not None:
        compression.log = lambda *args, **kwargs: model.log(*args, **kwargs)
        compression.log_custom = log_custom

if model.hparams.compress:
model = model.hparams.compression_hook_fn(model, compression, model.hparams)
# set up globals
Globals.compression = compression
Globals.profiler = trainer.profiler
return model, trainer, data
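

# A minimal usage sketch (assumption: the repo's real entry point may differ).
# trainer.fit() is the standard PyTorch Lightning call for training a
# LightningModule against a DataModule.
if __name__ == "__main__":
    model, trainer, data = init_model_from_args()
    trainer.fit(model, datamodule=data)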