# optimizer.py
"""
code has been taken from Annotated Transformers :
https://nlp.seas.harvard.edu/2018/04/03/attention.html
TODO: Checkpoint functionality
"""
import torch


class NoamOpt:
    """Optimizer wrapper that applies the Noam learning-rate schedule:
    lr = original_lr * min(step ** -0.5, step * warmup ** -1.5),
    i.e. a linear warmup followed by inverse-square-root decay."""

    def __init__(self, original_lr, warmup, optimizer):
        self.optimizer = optimizer
        self._step = 0
        self.warmup = warmup
        self.original_lr = original_lr
        self._rate = 0

    def step(self):
        """Update the learning rate of every param group, then take an optimizer step."""
        self._step += 1
        rate = self.rate()
        for p in self.optimizer.param_groups:
            p['lr'] = rate
        self._rate = rate
        self.optimizer.step()

    def rate(self, step=None):
        """Compute the learning rate for `step` (defaults to the current step)."""
        if step is None:
            step = self._step
        return self.original_lr * min(step ** (-0.5), step * self.warmup ** (-1.5))
class OptimizerArgs:
    """Container for the hyperparameters needed to build an optimizer."""

    def __init__(self, lr, warmup_steps, beta1=0.9, beta2=0.999):
        self.lr = lr
        self.warmup_steps = warmup_steps
        self.beta1 = beta1
        self.beta2 = beta2
def _get_params(model_params):
    """From (name, parameter) pairs, return the parameters that require gradients."""
    params = []
    for _name, p in model_params:
        if p.requires_grad:
            params.append(p)
    return params
def _get_optim(params, args):
    """Wrap an Adam optimizer over `params` in the Noam schedule."""
    return NoamOpt(original_lr=args.lr, warmup=args.warmup_steps,
                   optimizer=torch.optim.Adam(params, lr=args.lr,
                                              betas=(args.beta1, args.beta2), eps=1e-9))
# Using defaults from the BertSum code.
def optim_bert(args, model):
    """Build an optimizer for the parameters belonging to the BERT encoder
    (parameter names starting with 'bert_model')."""
    model_params = [(n, p) for n, p in model.named_parameters() if n.startswith('bert_model')]
    params = _get_params(model_params)
    return _get_optim(params, args)
def optim_decoder(args, model):
    """Build an optimizer for every parameter outside the BERT encoder."""
    model_params = [(n, p) for n, p in model.named_parameters() if not n.startswith('bert_model')]
    params = _get_params(model_params)
    return _get_optim(params, args)
# def get_std_opt(model):
#     return NoamOpt(model.src_embed[0].d_model, 2, 4000,
#                    torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.999), eps=1e-9))
#
# def optim(args, model):
#     params = _get_params(list(model.named_parameters()))
#     return NoamOpt(original_lr=args.lr, warmup=args.warmup_steps,
#                    optimizer=torch.optim.Adam(params, lr=args.lr, betas=(args.beta1, args.beta2), eps=1e-9))
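

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): a minimal, hypothetical
# example of how optim_bert / optim_decoder might be used. It assumes a model
# whose encoder parameters live under an attribute named `bert_model` (so
# their parameter names start with 'bert_model'); `ToyModel` below is a
# stand-in for illustration only.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch.nn as nn

    class ToyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.bert_model = nn.Linear(8, 8)  # parameters named 'bert_model.*'
            self.decoder = nn.Linear(8, 2)     # everything else goes to the decoder optimizer

    model = ToyModel()
    args = OptimizerArgs(lr=2e-3, warmup_steps=10000)

    bert_opt = optim_bert(args, model)        # Adam over the 'bert_model.*' parameters
    decoder_opt = optim_decoder(args, model)  # Adam over the remaining parameters

    # One training step: dummy forward pass, backprop, then step both wrappers.
    loss = model.decoder(model.bert_model(torch.randn(4, 8))).sum()
    loss.backward()
    bert_opt.step()
    decoder_opt.step()
    print("lr after 1 step:", bert_opt.rate(), decoder_opt.rate())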