forked from facebookresearch/ParlAI
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtorch_agent.py
1809 lines (1553 loc) · 65.2 KB
/
torch_agent.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
General utility code for building PyTorch-based agents in ParlAI.
Contains the following main utilities:
* TorchAgent class which serves as a useful parent class for other model agents
* Batch namedtuple which is the input type of the main abstract methods of
the TorchAgent class
* Output namedtuple which is the expected output type of the main abstract
methods of the TorchAgent class
See below for documentation on each specific tool.
"""
from abc import ABC, abstractmethod
from copy import deepcopy
from collections import deque
import json
import random
import numpy as np
import os
from torch import optim
from parlai.core.agents import Agent
from parlai.core.thread_utils import SharedTable
from parlai.core.build_data import modelzoo_path
from parlai.core.dict import DictionaryAgent
from parlai.core.utils import (
AttrDict,
argsort,
padded_tensor,
warn_once,
round_sigfigs,
fp16_optimizer_wrapper,
)
from parlai.core.distributed_utils import is_primary_worker
try:
import torch
except ImportError:
raise ImportError('Need to install Pytorch: go to pytorch.org')
class Batch(AttrDict):
    """
    Bundle of batched data handed to an agent.

    This is the input type of the train_step and eval_step functions. It is
    an ``AttrDict`` subclass; agents that override batchify may pass extra
    keyword fields, which are stored alongside the standard ones below.

    :param text_vec:
        bsz x seqlen tensor containing the parsed text data.

    :param text_lengths:
        list of length bsz containing the lengths of the text in same order
        as text_vec; necessary for pack_padded_sequence.

    :param label_vec:
        bsz x seqlen tensor containing the parsed label (one per batch row).

    :param label_lengths:
        list of length bsz containing the lengths of the labels in same
        order as label_vec.

    :param labels:
        list of length bsz containing the selected label for each batch row
        (some datasets have multiple labels per input example).

    :param valid_indices:
        list of length bsz containing the original indices of each example
        in the batch. These map predictions back to their proper row, since
        examples may be sorted by length or dropped as invalid.

    :param candidates:
        list of lists of text. Outer list has size bsz, inner lists vary in
        size based on the number of candidates for each row in the batch.

    :param candidate_vecs:
        list of lists of tensors. Outer list has size bsz, inner lists vary
        in size based on the number of candidates for each row in the batch.

    :param image:
        list of image features in the format specified by the --image-mode
        arg.

    :param observations:
        the original observations in the batched order.
    """

    def __init__(
        self,
        text_vec=None,
        text_lengths=None,
        label_vec=None,
        label_lengths=None,
        labels=None,
        valid_indices=None,
        candidates=None,
        candidate_vecs=None,
        image=None,
        observations=None,
        **kwargs,
    ):
        # Collect the standard fields first so they keep their canonical
        # order; any agent-specific extras follow them.
        standard_fields = dict(
            text_vec=text_vec,
            text_lengths=text_lengths,
            label_vec=label_vec,
            label_lengths=label_lengths,
            labels=labels,
            valid_indices=valid_indices,
            candidates=candidates,
            candidate_vecs=candidate_vecs,
            image=image,
            observations=observations,
        )
        super().__init__(**standard_fields, **kwargs)
class Output(AttrDict):
    """
    Container for the predictions an agent produces.

    This is the expected return type of the train_step and eval_step
    functions, though agents can choose to return None if they do not want
    to answer. Extra keyword fields are stored alongside the standard ones.

    :param List[str] text:
        list of strings of length bsz containing the predictions of the
        model.

    :param List[List[str]] text_candidates:
        list of lists of length bsz containing ranked predictions of the
        model. Each sub-list is an ordered ranking of strings, of variable
        length.
    """

    def __init__(self, text=None, text_candidates=None, **kwargs):
        standard_fields = {'text': text, 'text_candidates': text_candidates}
        super().__init__(**standard_fields, **kwargs)
class History(object):
    """
    History handles tracking the dialogue state over the course of an episode.

    History may also be used to track the history of any field.

    :param opt:
        options mapping; the keys 'delimiter', 'split_lines', 'person_tokens'
        and 'add_p1_after_newln' are read from it (all optional).

    :param field:
        field in the observation to track over the course of the episode
        (defaults to 'text')

    :param vec_type:
        specify a 'list' or 'deque' to save the history in this object

    :param maxlen:
        if `vec_type` is 'deque', this sets the maximum length of that object

    :param size:
        maximum number of history entries to keep; when it is not positive
        nothing is ever evicted

    :param p1_token:
        token indicating 'person 1'; opt must have 'person_tokens' set to True
        for this to be added

    :param p2_token:
        token indicating 'person 2'; opt must have 'person_tokens' set to True
        for this to be added

    :param dict_agent:
        DictionaryAgent object for tokenizing the history; it must provide a
        ``txt2vec`` method

    :raises RuntimeError:
        if `vec_type` is neither 'deque' nor 'list'
    """

    def __init__(
        self,
        opt,
        field='text',
        vec_type='deque',
        maxlen=None,
        size=-1,
        p1_token='__p1__',
        p2_token='__p2__',
        dict_agent=None,
    ):
        self.field = field
        self.dict = dict_agent
        self.delimiter = opt.get('delimiter', '\n')
        # tokenized once up front; reused every time the history is joined
        self.delimiter_tok = self.parse(self.delimiter)
        self.size = size
        self.split_on_newln = opt.get('split_lines', False)

        # set up history objects
        if vec_type not in ('deque', 'list'):
            raise RuntimeError('Type {} is not supported for history'.format(vec_type))
        self.vec_type = vec_type
        self.max_len = maxlen

        self.history_strings = []
        self.history_raw_strings = []
        self.history_vecs = []

        # person token args
        self.add_person_tokens = opt.get('person_tokens', False)
        self.add_p1_after_newln = opt.get('add_p1_after_newln', False)
        self.p1_token = p1_token
        self.p2_token = p2_token

        # tracking when to clear history
        self.reset_on_next_update = False

    def parse(self, text):
        """Tokenize text with the given dictionary."""
        return self.dict.txt2vec(text)

    def reset(self):
        """Clear the history."""
        self.history_raw_strings = []
        self.history_strings = []
        self.history_vecs = []

    def _update_strings(self, text):
        """Append text to the string history, evicting the oldest if full."""
        if self.size > 0:
            while len(self.history_strings) >= self.size:
                self.history_strings.pop(0)
        self.history_strings.append(text)

    def _update_raw_strings(self, text):
        """Append text to the raw (token-free) history, evicting if full."""
        if self.size > 0:
            while len(self.history_raw_strings) >= self.size:
                self.history_raw_strings.pop(0)
        self.history_raw_strings.append(text)

    def _update_vecs(self, text):
        """Tokenize text and append it to the vector history, evicting if full."""
        if self.size > 0:
            while len(self.history_vecs) >= self.size:
                self.history_vecs.pop(0)
        self.history_vecs.append(self.parse(text))

    def update_history(self, obs, add_next=None):
        """
        Update the history with the given observation.

        :param obs:
            observation mapping; the tracked field and 'episode_done' are
            read from it

        :param add_next:
            string to append to history prior to updating it with the
            observation
        """
        if self.reset_on_next_update:
            # this is the first example in a new episode, clear the previous
            # history
            self.reset()
            self.reset_on_next_update = False

        if add_next is not None:
            self._update_raw_strings(add_next)
            if self.add_person_tokens:
                add_next = self._add_person_tokens(add_next, self.p2_token)
            # update history string
            self._update_strings(add_next)
            # update history vecs
            self._update_vecs(add_next)

        if self.field in obs and obs[self.field] is not None:
            if self.split_on_newln:
                next_texts = obs[self.field].split('\n')
            else:
                next_texts = [obs[self.field]]
            for text in next_texts:
                self._update_raw_strings(text)
                if self.add_person_tokens:
                    # Tag the piece being stored, not the whole observation:
                    # otherwise split lines would each be replaced by the
                    # full, unsplit text.
                    text = self._add_person_tokens(
                        text, self.p1_token, self.add_p1_after_newln
                    )
                # update history string
                self._update_strings(text)
                # update history vecs
                self._update_vecs(text)

        if obs.get('episode_done'):
            # end of this episode, clear the history when we see a new example
            self.reset_on_next_update = True

    def get_history_str(self):
        """Return the string version of the history, or None if it is empty."""
        if len(self.history_strings) > 0:
            return self.delimiter.join(self.history_strings)
        return None

    def get_history_vec(self):
        """
        Return a vectorized version of the history.

        Entries are concatenated with the tokenized delimiter between them.
        The result is a deque (bounded by `maxlen`) or a list depending on
        `vec_type`, or None when the history is empty.
        """
        if len(self.history_vecs) == 0:
            return None

        if self.vec_type == 'deque':
            history = deque(maxlen=self.max_len)
            for vec in self.history_vecs[:-1]:
                history.extend(vec)
                history.extend(self.delimiter_tok)
            history.extend(self.history_vecs[-1])
        else:
            # vec type is a list
            history = []
            for vec in self.history_vecs[:-1]:
                history += vec
                history += self.delimiter_tok
            history += self.history_vecs[-1]

        return history

    def get_history_vec_list(self):
        """Return a list of history vecs."""
        return self.history_vecs

    def _add_person_tokens(self, text, token, add_after_newln=False):
        """
        Prefix text with a speaker token.

        When `add_after_newln` is set, the token is placed in front of the
        final newline-separated line instead of the start of the text.
        """
        if add_after_newln:
            split = text.split('\n')
            split[-1] = token + ' ' + split[-1]
            return '\n'.join(split)
        else:
            return token + ' ' + text
class TorchAgent(ABC, Agent):
"""
A provided abstract base agent for any model that wants to use Torch.
Exists to make it easier to implement a new agent.
Not necessary, but reduces duplicated code.
Many methods are intended to be either used as is when the default is
acceptable, or to be overriden and called with super(), with the extra
functionality added to the initial result. See the method comment for
recommended behavior.
This agent serves as a common framework for all ParlAI models which want
to use PyTorch.
"""
P1_TOKEN = '__p1__'
P2_TOKEN = '__p2__'
@classmethod
def optim_opts(self):
    """
    Build the table of selectable optimizers.

    Every public, capitalized name in torch.optim is included under its
    lower-cased name. Two optional extras are added when importable:

    - fused_adam from apex
    - qhm / qhadam from github.com/facebookresearch/qhoptim

    Override this (and probably call super()) to add your own optimizers.

    :return: dict mapping lower-case optimizer name to its class
    """
    optims = {}
    # first pull torch.optim in
    for name, member in vars(optim).items():
        if name.startswith('__') or not name[0].isupper():
            continue
        optims[name.lower()] = member

    try:
        import apex.optimizers.fused_adam as fused_adam
    except ImportError:
        pass
    else:
        optims['fused_adam'] = fused_adam.FusedAdam

    try:
        # https://openreview.net/pdf?id=S1fUpoR5FQ
        from qhoptim.pyt import QHM, QHAdam
    except ImportError:
        # no QHM installed
        pass
    else:
        optims['qhm'] = QHM
        optims['qhadam'] = QHAdam

    return optims
@staticmethod
def dictionary_class():
    """
    Tell callers which dictionary class this agent builds its vocab with.

    Subclasses needing a more elaborate dictionary can override this.
    """
    dict_cls = DictionaryAgent
    return dict_cls
@classmethod
def history_class(cls):
    """
    Tell callers which history class this agent tracks dialogue with.

    Subclasses needing a more elaborate history can override this.
    """
    hist_cls = History
    return hist_cls
@classmethod
def add_cmdline_args(cls, argparser):
    """
    Register the command-line options most torch agents share.

    Options are added in this order: general / embedding options, optimizer
    options, LR scheduler options, preprocessing and history options, and
    finally the mutually exclusive GPU options. The dictionary class is then
    asked to register its own arguments on the same parser.

    :param argparser: parser to which the argument groups are added
    """
    group = argparser.add_argument_group('TorchAgent Arguments')
    group.add_argument(
        '-i',
        '--interactive-mode',
        type='bool',
        default=False,
        help='Whether in full interactive mode or not, which means generating text or'
        ' retrieving from a full set of candidates, which is necessary to actually'
        ' do full dialogue. However, during training or quick validation (e.g. PPL for'
        ' generation or ranking a few candidates for ranking models) you might want these'
        ' set to off.'
        ' Typically, scripts can set their preferred default behavior at the start,'
        ' e.g. eval scripts.',
    )

    # ---- pretrained embeddings ----
    group.add_argument(
        '-emb',
        '--embedding-type',
        default='random',
        choices=[
            'random',
            'glove',
            'glove-fixed',
            'glove-twitter-fixed',
            'fasttext',
            'fasttext-fixed',
            'fasttext_cc',
            'fasttext_cc-fixed',
        ],
        help='Choose between different strategies for initializing word '
        'embeddings. Default is random, but can also preinitialize '
        'from Glove or Fasttext. Preinitialized embeddings can also '
        'be fixed so they are not updated during training.',
    )
    group.add_argument(
        '-embp',
        '--embedding-projection',
        default='random',
        help='If pretrained embeddings have a different dimensionality '
        'than your embedding size, strategy for projecting to the '
        'correct size. If the dimensions are the same, this is '
        'ignored unless you append "-force" to your choice.',
    )
    group.add_argument(
        '--fp16', type='bool', default=False, help='Use fp16 computations.'
    )

    # ---- optimizer ----
    optim_args = group.add_argument_group('Optimizer Arguments')
    optim_args.add_argument(
        '-opt',
        '--optimizer',
        default='sgd',
        choices=cls.optim_opts(),
        help='Choose between pytorch optimizers. Any member of torch.optim'
        ' should be valid.',
    )
    optim_args.add_argument(
        '-lr', '--learningrate', type=float, default=1, help='Learning rate'
    )
    optim_args.add_argument(
        '-clip',
        '--gradient-clip',
        type=float,
        default=0.1,
        help='gradient clipping using l2 norm',
    )
    optim_args.add_argument(
        '--adam-eps',
        type=float,
        default=1e-8,
        hidden=True,
        help='Epsilon value for Adam optimizers. Set to 1e-6 if your '
        'large model has stability issues, but prefer the default.',
    )
    optim_args.add_argument(
        '-mom',
        '--momentum',
        default=0,
        type=float,
        help='if applicable, momentum value for optimizer.',
    )
    optim_args.add_argument(
        '--nesterov',
        default=True,
        type='bool',
        help='if applicable, whether to use nesterov momentum.',
    )
    optim_args.add_argument(
        '-nu',
        '--nus',
        default='0.7',
        type='floats',
        help='if applicable, nu value(s) for optimizer. can use a single '
        'value like 0.7 or a comma-separated tuple like 0.7,1.0',
    )
    optim_args.add_argument(
        '-beta',
        '--betas',
        default='0.9,0.999',
        type='floats',
        help='if applicable, beta value(s) for optimizer. can use a single '
        'value like 0.9 or a comma-separated tuple like 0.9,0.999',
    )
    optim_args.add_argument(
        '-wd',
        '--weight-decay',
        type=float,
        default=None,
        help='Weight decay on the weights.',
    )

    # ---- learning rate scheduler ----
    sched_args = group.add_argument_group('Learning Rate Scheduler')
    sched_args.add_argument(
        '--lr-scheduler',
        type=str,
        default='reduceonplateau',
        choices=['reduceonplateau', 'none', 'fixed', 'invsqrt'],
        help='Learning rate scheduler.',
    )
    sched_args.add_argument(
        '--lr-scheduler-patience',
        type=int,
        default=3,
        help='LR scheduler patience. In number of validation runs. If using '
        'fixed scheduler, LR is decayed every <patience> validations.',
    )
    sched_args.add_argument(
        '--lr-scheduler-decay',
        type=float,
        default=0.5,
        help='Decay factor for LR scheduler, or how much LR is multiplied by '
        'when it is lowered.',
    )
    sched_args.add_argument(
        '--warmup-updates',
        type=int,
        default=-1,
        hidden=True,
        help='Learning rate warmup period, in number of SGD updates. '
        'Linearly scales up LR over period. Only enabled if > 0.',
    )
    sched_args.add_argument(
        '--warmup-rate',
        type=float,
        default=1e-4,
        hidden=True,
        help='Warmup learning rate *multiplier*. Initial LR is multiplied by '
        'this value. Linearly adjusted up to 1.0 across --warmup-updates '
        'steps.',
    )
    sched_args.add_argument(
        '--update-freq',
        type=int,
        default=1,
        hidden=True,
        help='Accumulate gradients N times before performing an optimizer.step().',
    )

    # ---- preprocessing / history ----
    group.add_argument(
        '-rc',
        '--rank-candidates',
        type='bool',
        default=False,
        help='Whether the model should parse candidates for ranking.',
    )
    group.add_argument(
        '-tr',
        '--truncate',
        default=-1,
        type=int,
        help='Truncate input lengths to increase speed / use less memory.',
    )
    group.add_argument(
        '--text-truncate',
        type=int,
        help='Text input truncation length: if not specified, this will '
        'default to `truncate`',
    )
    group.add_argument(
        '--label-truncate',
        type=int,
        help='Label truncation length: if not specified, this will default '
        'to `truncate`',
    )
    group.add_argument(
        '-histsz',
        '--history-size',
        default=-1,
        type=int,
        help='Number of past dialog utterances to remember.',
    )
    group.add_argument(
        '-pt',
        '--person-tokens',
        type='bool',
        default=False,
        help='add person tokens to history. adds __p1__ in front of input '
        'text and __p2__ in front of past labels when available or '
        'past utterances generated by the model. these are added to '
        'the dictionary during initialization.',
    )
    group.add_argument(
        '--split-lines',
        type='bool',
        default=False,
        help='split the dialogue history on newlines and save in separate '
        'vectors',
    )
    group.add_argument(
        '--use-reply',
        default='label',
        hidden=True,
        choices=['label', 'model', 'none'],
        help='Which previous replies to use as history. If label, use '
        'gold dataset replies. If model, use model\'s own replies. '
        'If none, do not track replies in history.',
    )
    group.add_argument(
        '--add-p1-after-newln',
        type='bool',
        default=False,
        hidden=True,
        help='Add the other speaker token before the last newline in the '
        'input instead of at the beginning of the input. this is '
        'useful for tasks that include some kind of context before '
        'the actual utterance (e.g. squad, babi, personachat).',
    )
    group.add_argument(
        '--delimiter',
        type=str,
        default='\n',
        help='Join history lines with this token, defaults to newline',
    )

    # ---- GPU ----
    # these gpu options are all mutually exclusive, and should error if the
    # user tries to present multiple of them
    device_args = group.add_mutually_exclusive_group()
    device_args.add_argument(
        '-gpu', '--gpu', type=int, default=-1, help='which GPU to use'
    )
    device_args.add_argument(
        '--no-cuda',
        default=False,
        action='store_true',
        dest='no_cuda',
        help='disable GPUs even if available. otherwise, will use GPUs if '
        'available on the device.',
    )

    # let the dictionary register its own options on the same parser
    cls.dictionary_class().add_cmdline_args(argparser)
def __init__(self, opt, shared=None):
    """
    Initialize agent.

    Without `shared`, the dictionary, reply table and metrics are created
    from scratch; with it, they are taken from the shared table. After that
    the CUDA / fp16 flags, special token indices, truncation limits and the
    history tracker are configured.

    :param opt: options mapping for this agent
    :param shared: optional mapping with 'opt', 'dict', 'metrics' and
        'replies' entries (and optionally 'batchindex')
    """
    super().__init__(opt, shared)
    opt = self.opt

    if shared:
        # copy initialized data from shared table
        self.opt = shared['opt']
        self.dict = shared['dict']
        self.metrics = shared['metrics']
        # if we're not using batching (e.g. mturk), then replies really need
        # to stay separated
        self.replies = {} if self.opt['batchsize'] == 1 else shared['replies']
    else:
        # initialize any important structures from scratch
        self.replies = {}  # past replies
        self.dict = self.build_dictionary()

        if opt.get('fp16'):
            # Volta cores revert to FP32 hardware if tensors are not multiples
            # of 8 in all dimensions. This INCLUDES the embeddings layer! As
            # such, we need some extra magic to ensure the dictionary is padded
            # with extra tokens to make it a multiple of 8.
            remainder = len(self.dict) % 8
            if remainder != 0:
                for i in range(8 - remainder):
                    self.dict['__FP16_PAD_{}__'.format(i)] = 1

        self.metrics = {
            # gradient norms
            'gnorm': 0.0,
            # gradient clipping rate
            'clip': 0.0,
            # number of calls to optimizer.step()
            'updates': 0,
        }

    if opt.get('numthreads', 1) > 1:
        torch.set_num_threads(1)

    # check for cuda
    self.use_cuda = not opt['no_cuda'] and torch.cuda.is_available()
    if self.use_cuda:
        if not shared:
            print('[ Using CUDA ]')
        if not shared and opt['gpu'] != -1:
            torch.cuda.set_device(opt['gpu'])
    # indicate whether using fp16
    self.fp16 = self.use_cuda and self.opt.get('fp16', False)

    # Default to the class name, sans "Agent". child can override
    self.id = type(self).__name__.replace("Agent", "")

    # now set up any fields that all instances may need
    self.EMPTY = torch.LongTensor([])
    self.NULL_IDX = self.dict[self.dict.null_token]
    self.START_IDX = self.dict[self.dict.start_token]
    self.END_IDX = self.dict[self.dict.end_token]

    # for gradient accumulation
    self._number_grad_accum = 0
    # for the LR scheduler
    self._number_training_updates = 0
    # fixed random seed
    self.random = random.Random(42)
    # which row in the batch this instance is
    self.batch_idx = (shared.get('batchindex') or 0) if shared else 0
    # can remember as few as zero utterances if desired
    self.histsz = opt['history_size']

    # truncate == 0 might give funny behavior
    self.truncate = opt['truncate'] if opt['truncate'] >= 0 else None
    # the text / label limits fall back to the general one when unset
    text_limit = opt.get('text_truncate') or opt['truncate']
    self.text_truncate = text_limit if text_limit >= 0 else None
    label_limit = opt.get('label_truncate') or opt['truncate']
    self.label_truncate = label_limit if label_limit >= 0 else None

    # stores up to hist_utt past observations within current dialog
    history_cls = self.history_class()
    self.history = history_cls(
        opt,
        maxlen=self.text_truncate,
        size=self.histsz,
        p1_token=self.P1_TOKEN,
        p2_token=self.P2_TOKEN,
        dict_agent=self.dict,
    )

    self.is_training = False  # track whether model is training
    self.rank_candidates = opt['rank_candidates']
    self.add_person_tokens = opt.get('person_tokens', False)
    # set interactive mode or not according to options.
    self.set_interactive_mode(opt['interactive_mode'], shared)
def build_dictionary(self):
    """
    Construct and return the dictionary that will become self.dict.

    The class from dictionary_class() is instantiated with self.opt. When
    the 'person_tokens' option is on, both speaker tokens are inserted too.
    Override this to add further special tokens.
    """
    dict_cls = self.dictionary_class()
    dictionary = dict_cls(self.opt)
    if not self.opt.get('person_tokens'):
        return dictionary
    # NOTE(review): the large values are presumably frequency counts that
    # keep these tokens from being pruned -- confirm against the dictionary.
    dictionary[self.P1_TOKEN] = 999_999_999
    dictionary[self.P2_TOKEN] = 999_999_998
    return dictionary
def _get_init_model(self, opt, shared):
"""
Get model file to initialize with.
If `init_model` exits, we will return the path to that file and maybe
load dict file from that path. Otherwise, use `model_file.`
:return: path to load model from, whether we loaded from `init_model`
or not
"""
init_model = None
is_finetune = False
if not shared: # only do this on first setup
# first check load path in case we need to override paths
if opt.get('init_model') and os.path.isfile(opt['init_model']):
# check first for 'init_model' for loading model from file
init_model = opt['init_model']
is_finetune = True
if opt.get('model_file') and os.path.isfile(opt['model_file']):
# next check for 'model_file', this would override init_model
init_model = opt['model_file']
is_finetune = False
if init_model is not None:
# if we are loading a model, should load its dict too
if os.path.isfile(init_model + '.dict') or opt['dict_file'] is None:
opt['dict_file'] = init_model + '.dict'
return init_model, is_finetune
def init_optim(self, params, optim_states=None, saved_optim_type=None):
    """
    Create self.optimizer for the given model parameters.

    Keyword arguments for the optimizer are assembled from self.opt, the
    optimizer class is looked up in optim_opts(), and, when fp16 is on, the
    result is wrapped with fp16_optimizer_wrapper. Saved optimizer state is
    then restored when it is usable.

    :param params:
        parameters from the model

    :param optim_states:
        optional argument providing states of optimizer to load

    :param saved_optim_type:
        type of optimizer being loaded, if changed will skip loading
        optimizer states
    """
    opt = self.opt

    # ---- assemble optimizer keyword arguments ----
    kwargs = {'lr': opt['learningrate']}
    if opt.get('weight_decay'):
        kwargs['weight_decay'] = opt['weight_decay']

    if opt.get('momentum') > 0 and opt['optimizer'] in ('sgd', 'rmsprop', 'qhm'):
        # turn on momentum for optimizers that use it
        kwargs['momentum'] = opt['momentum']
        if opt['optimizer'] == 'sgd' and opt.get('nesterov', True):
            # for sgd, maybe nesterov
            kwargs['nesterov'] = opt.get('nesterov', True)
        elif opt['optimizer'] == 'qhm':
            # qhm needs a nu
            kwargs['nu'] = opt.get('nus', (0.7,))[0]
    elif opt['optimizer'] == 'adam':
        # turn on amsgrad for adam
        # amsgrad paper: https://openreview.net/forum?id=ryQu7f-RZ
        kwargs['amsgrad'] = True
    elif opt['optimizer'] == 'qhadam':
        # set nus for qhadam
        kwargs['nus'] = opt.get('nus', (0.7, 1.0))

    if opt['optimizer'] in ('adam', 'sparseadam', 'fused_adam', 'adamax', 'qhadam'):
        # set betas for optims that use it
        kwargs['betas'] = opt.get('betas', (0.9, 0.999))
        # set adam epsilon, but only if user specified it
        if opt.get('adam_eps'):
            kwargs['eps'] = opt['adam_eps']

    # ---- build the optimizer ----
    optim_class = self.optim_opts()[opt['optimizer']]
    self.optimizer = optim_class(params, **kwargs)
    if self.fp16:
        self.optimizer = fp16_optimizer_wrapper(self.optimizer)

    # ---- restore saved state, if any ----
    # TODO: we might want to hard reset optimizers here in the
    # case of fine tuning. Some rudimentary experiments seemed to
    # indicate that keeping adam weights around was desirable, so this
    # will remain the behavior for the time being.
    if not optim_states:
        return
    if saved_optim_type != opt['optimizer']:
        # we changed from adam to adamax, or sgd to adam, or similar
        print('WARNING: not loading optim state since optim class changed.')
        return

    # check for any fp16/fp32 conversions we need to do
    saved_as_fp16 = 'loss_scaler' in optim_states
    if saved_as_fp16:
        if self.fp16:
            # previously trained in fp16, now we're training in fp16.
            # ideally no action needed, but APEX broke backwards
            # compatibility and this is the hack around it.
            optim_states['loss_scaler'] = self.optimizer.state_dict()['loss_scaler']
        else:
            # old optimizer was fp16 but now we're doing fp32,
            # drop the fp16 wrapper from the state_dict and just load
            # the fp16 weights into the fp32 tensors
            optim_states = optim_states['optimizer_state_dict']
    elif self.fp16:
        # old optimizer was fp32, but now we're doing fp16.
        # this is a bit clunky, but alternatives are worse
        self.optimizer.optimizer.load_state_dict(optim_states)
        return
    # otherwise: previously trained in fp32, loading in fp32; no special
    # treatment needed.

    # finally, try to actually load the optimizer state
    try:
        self.optimizer.load_state_dict(optim_states)
    except ValueError:
        print('WARNING: not loading optim state since model params changed.')
def build_lr_scheduler(self, states=None, hard_reset=False):
    """
    Create the learning rate scheduler, and assign it to self.scheduler.

    This scheduler will be updated upon a call to receive_metrics.
    Also sets self.warmup_scheduler, either to a linear warmup schedule or
    to None when warmup is disabled or already finished.

    :param state_dict states: Possible state_dict provided by model
        checkpoint, for restoring LR state

    :param bool hard_reset: If true, the LR scheduler should ignore the
        state dictionary.

    :raises ValueError: if the scheduler name is unknown, or 'invsqrt' is
        requested without a positive --warmup-updates
    """
    # first make sure there are no null pointers
    states = {} if states is None else states

    # lr schedulers don't work with apex, they expect the "real" optimizer
    optimizer = self.optimizer.optimizer if self.fp16 else self.optimizer

    # ---- warmup schedule ----
    warmup_updates = self.opt.get('warmup_updates', -1)
    updates_so_far = states.get('number_training_updates', 0)
    still_warming = updates_so_far < warmup_updates or hard_reset
    if warmup_updates > 0 and still_warming:

        def _warmup_lr(step):
            # linear ramp of the LR multiplier from warmup_rate up to 1.0
            start = self.opt['warmup_rate']
            end = 1.0
            progress = min(1.0, step / self.opt['warmup_updates'])
            return start + (end - start) * progress

        self.warmup_scheduler = optim.lr_scheduler.LambdaLR(optimizer, _warmup_lr)
    else:
        self.warmup_scheduler = None

    # ---- main schedule ----
    patience = self.opt.get('lr_scheduler_patience', 3)
    decay = self.opt.get('lr_scheduler_decay', 0.5)
    sched_name = self.opt.get('lr_scheduler')

    if sched_name == 'none':
        self.scheduler = None
    elif decay == 1.0:
        warn_once(
            "Your LR decay is set to 1.0. Assuming you meant you wanted "
            "to disable learning rate scheduling. Adjust --lr-scheduler-decay "
            "if this is not correct."
        )
        self.scheduler = None
    elif sched_name == 'reduceonplateau':
        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, 'min', factor=decay, patience=patience, verbose=True
        )
    elif sched_name == 'fixed':
        self.scheduler = optim.lr_scheduler.StepLR(optimizer, patience, gamma=decay)
    elif sched_name == 'invsqrt':
        if warmup_updates <= 0:
            raise ValueError(
                '--lr-scheduler invsqrt requires setting --warmup-updates'
            )
        decay_factor = np.sqrt(max(1, warmup_updates))

        def _invsqrt_lr(step):
            return decay_factor / np.sqrt(max(1, step))

        self.scheduler = optim.lr_scheduler.LambdaLR(optimizer, _invsqrt_lr)
    else:
        raise ValueError(
            "Don't know what to do with lr_scheduler '{}'".format(sched_name)
        )

    # ---- restore LR state from the checkpoint, if possible ----
    # a saved scheduler of a different type cannot be reused
    if (
        states
        and states.get('lr_scheduler_type') != self.opt['lr_scheduler']
        and not hard_reset
    ):
        # the LR scheduler changed, start things fresh
        warn_once("LR scheduler is different from saved. Starting fresh!")
        hard_reset = True

    if hard_reset:
        # We're not going to use the saved LR schedule, let's just exit
        return

    # do the actual loading (if possible)
    if 'number_training_updates' in states:
        self._number_training_updates = states['number_training_updates']
    if self.scheduler and 'lr_scheduler' in states:
        self.scheduler.load_state_dict(states['lr_scheduler'])
    if states.get('warmup_scheduler') and getattr(self, 'warmup_scheduler', None):
        self.warmup_scheduler.load_state_dict(states['warmup_scheduler'])
def report(self):
    """
    Collect training metrics into a fresh dict.

    Always contains 'num_updates'. 'lr' is added when a scheduler is set;
    'gnorm' and 'clip' (averaged over optimizer steps) are added when
    gradient clipping is enabled and at least one step has been taken.
    """
    metrics = {}
    # only report LR if we have a scheduler
    if getattr(self, 'scheduler', None) is not None:
        raw_lr = self.optimizer.param_groups[0]['lr']
        metrics['lr'] = round_sigfigs(round_sigfigs(raw_lr, 4), 4)
    metrics['num_updates'] = self._number_training_updates

    steps = self.metrics['updates']
    clipping_enabled = self.opt.get('gradient_clip', -1) > 0
    if steps > 0 and clipping_enabled:
        metrics['gnorm'] = round_sigfigs(self.metrics['gnorm'] / steps, 4)
        metrics['clip'] = round_sigfigs(self.metrics['clip'] / steps, 2)
    return metrics
def _is_lr_warming_up(self):
"""Check if we're warming up the learning rate."""
return (
self.warmup_scheduler is not None
and self._number_training_updates <= self.opt['warmup_updates']
)
def receive_metrics(self, metrics_dict):
"""
Use the metrics to decide when to adjust LR schedule.
This uses the loss as the validation metric if present, if not this
function does nothing. Note that the model must be reporting loss for
this to work.
Override this to override the behavior.