forked from nimashoghi/smart-quantization
optimizer.py
from torch.optim import SGD, Adam, Optimizer
from torch.optim.adamw import AdamW

__all__ = ["OptimLP"]


class OptimLP(Optimizer):
"""
A low-precision optimizer wrapper that handles weight, gradient, accumulator quantization.
Args:
- :attr: `optim`: underlying optimizer to use
- :attr: `weight_quant`: a weight quantization function which takes a pytorch tensor and returns a tensor. If None, does not quantize weight.
- :attr: `grad_quant`: a gradient quantization function which takes a pytorch tensor and returns a tensor. If None, does not quantize weight.
- :attr: `grad_scaling`: float, scaling factor before apply gradient quantization.
- :attr: `momentum_quant`: a momentum quantization function which takes a pytorch tensor and returns a tensor.
If None, does not quantize weight.
- :attr: `acc_quant`: a accumulator quantization function which takes
a pytorch tensor and returns a tensor. If not None, a
OptimLP object would create memory copies of model parameters that serve as
gradient accumulators. If None, does not use gradient accumulators.
Example:
>>> weight_q = quantizer(...) # define weight quantization
>>> optimizer = SGD(model.parameters(), lr=0.1, momentum=0.9)
>>> optimizer = OptimLP(optiimizer, weight_quant=weight_q)
"""
    def __init__(
        self,
        optim,
        weight_quant=None,
        grad_scaling=1.0,
        grad_quant=None,
        momentum_quant=None,
        acc_quant=None,
    ):
        super().__init__(optim.param_groups, optim.defaults)  # placeholder
        # python dictionaries are not copied by default, so param_groups is
        # shared with the wrapped optimizer
        self.param_groups = optim.param_groups
        self.optim = optim

        assert grad_scaling > 0, "gradient scaling must be positive"
        self.grad_scaling = grad_scaling

        self.weight_quant = weight_quant
        self.grad_quant = grad_quant
        self.momentum_quant = momentum_quant
        self.acc_quant = acc_quant

        if isinstance(self.optim, SGD):
            self.momentum_keys = [("momentum_buffer", dict())]
        elif isinstance(self.optim, (Adam, AdamW)):
            # TODO: support amsgrad
            self.momentum_keys = [
                ("exp_avg", dict()),
                ("exp_avg_sq", dict(all_positive=True)),
            ]
        else:
            raise NotImplementedError("Only SGD, Adam, and AdamW are supported for now.")

        if self.acc_quant is not None:
            # keep a separate copy of every parameter to serve as its gradient accumulator
            self.weight_acc = {}
            for group in self.param_groups:
                for p in group["params"]:
                    self.weight_acc[p] = p.detach().clone().type_as(p)
    def _pre_closure(self):
        # quantize gradients (with optional scaling)
        if self.grad_quant is not None:
            for group in self.param_groups:
                if group.get("no_grad_compression", False):
                    continue
                for p in group["params"]:
                    if not p.requires_grad or p.grad is None:
                        continue
                    p.grad.data = self.grad_quant(p.grad.data * self.grad_scaling).data

        # switch the accumulator in as the weight before stepping
        if self.acc_quant is not None:
            for group in self.param_groups:
                for p in group["params"]:
                    p.data = self.weight_acc[p].data
    def _post_closure(self):
        # quantize gradients
        if self.grad_quant is not None:
            for group in self.param_groups:
                if group.get("no_grad_compression", False):
                    continue
                for p in group["params"]:
                    if not p.requires_grad or p.grad is None:
                        continue
                    p.grad.data = self.grad_quant(p.grad.data * self.grad_scaling).data

        # quantize the weight from the accumulator
        if self.weight_quant is not None:
            for group in self.param_groups:
                if group.get("no_weight_compression", False):
                    continue
                for p in group["params"]:
                    p.data = self.weight_quant(p.data).data

        # quantize momentum buffers
        if self.momentum_quant is not None:
            for group in self.param_groups:
                if group.get("no_momentum_compression", False):
                    continue
                if isinstance(self.optim, SGD) and group["momentum"] == 0:
                    continue
                for p in group["params"]:
                    if not p.requires_grad or p.grad is None:
                        continue
                    param_state = self.optim.state[p]
                    for key, kwargs in self.momentum_keys:
                        param_state[key].data = self.momentum_quant(
                            param_state[key], **kwargs
                        ).data
    def step(self, closure=None):
        """
        Performs one step of optimization with the underlying optimizer.
        Quantizes gradient and momentum before stepping; quantizes the
        gradient accumulator and weight after stepping.
        """
        if closure is None:
            self._pre_closure()
            loss = self.optim.step()
        else:

            def closure_(*args, **kwargs):
                value = closure(*args, **kwargs)
                self._pre_closure()
                return value

            loss = self.optim.step(closure=closure_)

        self._post_closure()
        return loss
    def __repr__(self):
        return "LP Optimizer: {}".format(self.optim.__repr__())

    def __str__(self):
        return "LP Optimizer: {}".format(self.optim.__str__())