forked from facebookresearch/ParlAI
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathparams.py
1038 lines (946 loc) · 36.2 KB
/
params.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Provide an argument parser and default command line options for using ParlAI."""
import argparse
import importlib
import os
import pickle
import json
import sys as _sys
import datetime
from parlai.core.agents import get_agent_module, get_task_module
from parlai.core.build_data import modelzoo_path
from parlai.tasks.tasks import ids_to_tasks
from parlai.core.utils import Opt, load_opt_file
def print_announcements(opt):
    """
    Output any announcements the ParlAI team wishes to make to users.

    The banner also tells the user how to suppress future output: creating
    a ``noannouncements`` file inside ``opt['datapath']`` silences it.

    :param opt: options dict; ``opt['datapath']`` is only read when an
        announcement is actually enabled.
    """
    # no announcements to make right now; everything below is the dormant
    # template that runs once this flag is switched on
    has_announcement = False
    if not has_announcement:
        return

    suppress_marker = os.path.join(opt.get('datapath'), 'noannouncements')
    if os.path.exists(suppress_marker):
        # user has suppressed announcements, don't do anything
        return

    # terminal escape codes; all of these colors are bolded
    codes = {
        'reset': '\033[0m',
        'bold': '\033[1m',
        'red': '\033[1;91m',
        'yellow': '\033[1;93m',
        'green': '\033[1;92m',
        'blue': '\033[1;96m',
        'cyan': '\033[1;94m',
        'magenta': '\033[1;95m',
    }
    # only use colors if we're outputting to a terminal
    if not _sys.stdout.isatty():
        codes = dict.fromkeys(codes, '')

    # generate the rainbow stars
    rainbow = [
        codes[name]
        for name in ('red', 'yellow', 'green', 'cyan', 'blue', 'magenta')
    ]
    width = 78 // len(rainbow)
    stars = ''.join(color + '*' * width for color in rainbow) + codes['reset']

    # do the actual output
    lines = [
        '',
        stars,
        codes['bold'],
        'Announcements go here.',
        codes['reset'],
        # don't bold the suppression command
        'To suppress this message (and future announcements), run\n`touch {}`'.format(
            suppress_marker
        ),
        stars,
    ]
    print('\n'.join(lines))
def get_model_name(opt):
    """
    Get the model name from either `--model` or `--model-file`.

    An explicit ``opt['model']`` wins. Otherwise, if ``opt['model_file']`` is
    set, the ``<model_file>.opt`` file next to it is consulted (after
    resolving the path through ``modelzoo_path``) and its ``'model'`` entry is
    returned.

    :param opt: dict-like options.
    :return: the model name, or None when it cannot be determined.
    """
    explicit = opt.get('model', None)
    if explicit is not None:
        return explicit

    model_file = opt.get('model_file', None)
    if model_file is None:
        return None

    # try to get model name from model opt file
    optfile = modelzoo_path(opt.get('datapath'), model_file) + '.opt'
    if not os.path.isfile(optfile):
        return None
    return load_opt_file(optfile).get('model', None)
def str2bool(value):
    """
    Convert 'yes', 'false', '1', etc. into a boolean.

    Matching is case-insensitive.

    :param value: string to interpret.
    :raises argparse.ArgumentTypeError: if the string is not a recognized
        spelling of true or false.
    """
    lowered = value.lower()
    if lowered in ('yes', 'true', 't', '1', 'y'):
        return True
    if lowered in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')
def str2floats(s):
    """
    Look for single float or comma-separated floats.

    :param s: string such as ``'0.5'`` or ``'1,2.5,3'``.
    :return: tuple of the parsed floats, in order.
    """
    pieces = s.split(',')
    return tuple(map(float, pieces))
def str2class(value):
    """
    From import path string, returns the class specified.

    For example, the string 'parlai.agents.drqa.drqa:SimpleDictionaryAgent'
    returns <class 'parlai.agents.drqa.drqa.SimpleDictionaryAgent'>.

    :param value: ``'module.path:ClassName'``.
    :raises RuntimeError: if the string contains no colon.
    """
    if ':' not in value:
        raise RuntimeError('Use a colon before the name of the class.')
    pieces = value.split(':')
    module_path, attr_name = pieces[0], pieces[1]
    return getattr(importlib.import_module(module_path), attr_name)
def class2str(value):
    """
    Inverse of params.str2class().

    Takes the text between the first and last single quote of ``str(value)``
    (for a class, ``"<class 'pkg.mod.Name'>"``) and swaps the final period
    for a colon.

    :param value: typically a class object.
    :return: string such as ``'pkg.mod:Name'``.
    """
    text = str(value)
    start = text.find('\'') + 1
    stop = text.rfind('\'')
    import_path = text[start:stop]  # pull out import path
    # replace last period with ':'
    return ':'.join(import_path.rsplit('.', 1))
def fix_underscores(args):
    """
    Convert underscores to hyphens in the option names found in args.

    For example, converts '--gradient_clip' to '--gradient-clip'. Only the
    option name is rewritten: in a ``--name=value`` token everything after
    the first ``=`` is left untouched, so values such as file paths keep
    their underscores.

    :param args: iterable, possibly containing args strings with underscores.
    :return: a new list with normalized option names, or ``args`` itself when
        it is empty or None.
    """
    if not args:
        return args

    fixed = []
    for arg in args:
        if isinstance(arg, str) and arg.startswith('-'):
            # split off an inline value so only the flag name is normalized
            name, sep, inline_value = arg.partition('=')
            arg = name.replace('_', '-') + sep + inline_value
        fixed.append(arg)
    return fixed
class CustomHelpFormatter(argparse.ArgumentDefaultsHelpFormatter):
    """
    Produce a custom-formatted `--help` option.

    Option invocations are rendered as ``-s, --long ARGS`` (the metavar is
    printed once rather than after every option string), with a fixed help
    position and line width.

    See https://goo.gl/DKtHb5 for details.
    """

    def __init__(self, *args, **kwargs):
        # always force the layout, overriding anything the caller supplied
        kwargs.update(max_help_position=8, width=130)
        super().__init__(*args, **kwargs)

    def _format_action_invocation(self, action):
        takes_value = bool(action.option_strings) and action.nargs != 0
        if takes_value:
            metavar = self._get_default_metavar_for_optional(action)
            formatted_args = self._format_args(action, metavar)
            return ', '.join(action.option_strings) + ' ' + formatted_args
        # positionals and flags without a value keep the stock rendering
        return super()._format_action_invocation(action)
class ParlaiParser(argparse.ArgumentParser):
"""
Provide an opt-producer and CLI argument parser.
Pseudo-extension of ``argparse`` which sets a number of parameters
for the ParlAI framework. More options can be added specific to other
modules by passing this object and calling ``add_arg()`` or
``add_argument()`` on it.
For example, see ``parlai.core.dict.DictionaryAgent.add_cmdline_args``.
:param add_parlai_args:
(default True) initializes the default arguments for ParlAI
package, including the data download paths and task arguments.
:param add_model_args:
(default False) initializes the default arguments for loading
models, including initializing arguments from that model.
"""
def __init__(
    self, add_parlai_args=True, add_model_args=False, description='ParlAI parser'
):
    """
    Initialize the ParlAI argparser.

    Registers the custom ``bool``, ``floats`` and ``class`` argument types,
    computes ``parlai_home`` (three directory levels above this file's real
    path) and exports it as ``PARLAI_HOME``, then snapshots the process
    command line so explicitly-passed options can later be told apart from
    defaults.

    :param add_parlai_args: if True, add the common ParlAI arguments.
    :param add_model_args: if True, add the model-loading arguments.
    :param description: description passed through to argparse.
    """
    super().__init__(
        description=description,
        allow_abbrev=False,
        conflict_handler='resolve',
        formatter_class=CustomHelpFormatter,
    )

    custom_types = (
        ('bool', str2bool),
        ('floats', str2floats),
        ('class', str2class),
    )
    for type_name, converter in custom_types:
        self.register('type', type_name, converter)

    # walk three directory levels up from this file
    home = os.path.realpath(__file__)
    for _ in range(3):
        home = os.path.dirname(home)
    self.parlai_home = home
    os.environ['PARLAI_HOME'] = home

    # short alias used by callers
    self.add_arg = self.add_argument

    # remember which args were specified on the command line
    self.cli_args = _sys.argv[1:]
    self.overridable = {}

    if add_parlai_args:
        self.add_parlai_args()
    if add_model_args:
        self.add_model_args()
def add_parlai_data_path(self, argument_group=None):
    """
    Add --datapath CLI arg.

    :param argument_group: group (or parser) to attach the option to;
        defaults to this parser itself.
    """
    target = self if argument_group is None else argument_group
    target.add_argument(
        '-dp',
        '--datapath',
        help='path to datasets, defaults to {parlai_dir}/data',
        default=None,
    )
def add_mturk_args(self):
    """
    Add standard mechanical turk arguments.

    All options are attached to a 'Mechanical Turk' argument group. The
    options are declared as a table of ``(option strings, keyword
    arguments)`` pairs and registered in order.
    """
    mturk = self.add_argument_group('Mechanical Turk')
    log_dir = os.path.join(self.parlai_home, 'logs', 'mturk')

    specs = [
        (
            ('--mturk-log-path',),
            dict(
                default=log_dir,
                help='path to MTurk logs, defaults to {parlai_dir}/logs/mturk',
            ),
        ),
        (
            ('-t', '--task'),
            dict(help='MTurk task, e.g. "qa_data_collection" or "model_evaluator"'),
        ),
        (
            ('-nc', '--num-conversations'),
            dict(
                default=1,
                type=int,
                help='number of conversations you want to create for this task',
            ),
        ),
        (
            ('--unique',),
            dict(
                dest='unique_worker',
                default=False,
                action='store_true',
                help='enforce that no worker can work on your task twice',
            ),
        ),
        (
            ('--max-hits-per-worker',),
            dict(
                dest='max_hits_per_worker',
                default=0,
                type=int,
                help='Max number of hits each worker can perform during current group run',
            ),
        ),
        (
            ('--unique-qual-name',),
            dict(
                dest='unique_qual_name',
                default=None,
                type=str,
                help='qualification name to use for uniqueness between HITs',
            ),
        ),
        (
            ('-r', '--reward'),
            dict(
                default=0.05,
                type=float,
                help='reward for each worker for finishing the conversation, '
                'in US dollars',
            ),
        ),
        # --sandbox and --live write opposite values to the same destination
        (
            ('--sandbox',),
            dict(
                dest='is_sandbox',
                action='store_true',
                help='submit the HITs to MTurk sandbox site',
            ),
        ),
        (
            ('--live',),
            dict(
                dest='is_sandbox',
                action='store_false',
                help='submit the HITs to MTurk live site',
            ),
        ),
        (
            ('--debug',),
            dict(
                dest='is_debug',
                action='store_true',
                help='print and log all server interactions and messages',
            ),
        ),
        (
            ('--verbose',),
            dict(
                dest='verbose',
                action='store_true',
                help='print all messages sent to and from Turkers',
            ),
        ),
        (
            ('--hard-block',),
            dict(
                dest='hard_block',
                action='store_true',
                default=False,
                help='Hard block disconnecting Turkers from all of your HITs',
            ),
        ),
        (
            ('--log-level',),
            dict(
                dest='log_level',
                type=int,
                default=20,
                help='importance level for what to put into the logs. the lower '
                'the level the more that gets logged. values are 0-50',
            ),
        ),
        (
            ('--disconnect-qualification',),
            dict(
                dest='disconnect_qualification',
                default=None,
                help='Qualification to use for soft blocking users for '
                'disconnects. By default '
                'turkers are never blocked, though setting this will allow '
                'you to filter out turkers that have disconnected too many '
                'times on previous HITs where this qualification was set.',
            ),
        ),
        (
            ('--block-qualification',),
            dict(
                dest='block_qualification',
                default=None,
                help='Qualification to use for soft blocking users. This '
                'qualification is granted whenever soft_block_worker is '
                'called, and can thus be used to filter workers out from a '
                'single task or group of tasks by noted performance.',
            ),
        ),
        (
            ('--count-complete',),
            dict(
                dest='count_complete',
                default=False,
                action='store_true',
                help='continue until the requested number of conversations are '
                'completed rather than attempted',
            ),
        ),
        (
            ('--allowed-conversations',),
            dict(
                dest='allowed_conversations',
                default=0,
                type=int,
                help='number of concurrent conversations that one mturk worker '
                'is able to be involved in, 0 is unlimited',
            ),
        ),
        (
            ('--max-connections',),
            dict(
                dest='max_connections',
                default=30,
                type=int,
                help='number of HITs that can be launched at the same time, 0 is '
                'unlimited.',
            ),
        ),
        (
            ('--min-messages',),
            dict(
                dest='min_messages',
                default=0,
                type=int,
                help='number of messages required to be sent by MTurk agent when '
                'considering whether to approve a HIT in the event of a '
                'partner disconnect. I.e. if the number of messages '
                'exceeds this number, the turker can submit the HIT.',
            ),
        ),
        (
            ('--local',),
            dict(
                dest='local',
                default=False,
                action='store_true',
                help='Run the server locally on this server rather than setting up'
                ' a heroku server.',
            ),
        ),
        (
            ('--hobby',),
            dict(
                dest='hobby',
                default=False,
                action='store_true',
                help='Run the heroku server on the hobby tier.',
            ),
        ),
        (
            ('--max-time',),
            dict(
                dest='max_time',
                default=0,
                type=int,
                help='Maximum number of seconds per day that a worker is allowed '
                'to work on this assignment',
            ),
        ),
        (
            ('--max-time-qual',),
            dict(
                dest='max_time_qual',
                default=None,
                help='Qualification to use to share the maximum time requirement '
                'with other runs from other machines.',
            ),
        ),
        (
            ('--heroku-team',),
            dict(
                dest='heroku_team',
                default=None,
                help='Specify Heroku team name to use for launching Dynos.',
            ),
        ),
        (
            ('--tmp-dir',),
            dict(
                dest='tmp_dir',
                default=None,
                help='Specify location to use for scratch builds and such.',
            ),
        ),
    ]
    for option_strings, options in specs:
        mturk.add_argument(*option_strings, **options)

    # sandbox unless --live is given; debug and verbose output off by default
    mturk.set_defaults(is_sandbox=True, is_debug=False, verbose=False)
def add_messenger_args(self):
    """
    Add Facebook Messenger arguments.

    All options are attached to a 'Facebook Messenger' argument group.
    """
    messenger = self.add_argument_group('Facebook Messenger')
    add = messenger.add_argument

    add(
        '--debug',
        action='store_true',
        dest='is_debug',
        help='print and log all server interactions and messages',
    )
    add(
        '--verbose',
        action='store_true',
        dest='verbose',
        help='print all messages sent to and from Turkers',
    )
    add(
        '--log-level',
        default=20,
        type=int,
        dest='log_level',
        help='importance level for what to put into the logs. the lower '
        'the level the more that gets logged. values are 0-50',
    )
    add(
        '--force-page-token',
        action='store_true',
        dest='force_page_token',
        help='override the page token stored in the cache for a new one',
    )
    add(
        '--password',
        default=None,
        type=str,
        dest='password',
        help='Require a password for entry to the bot',
    )
    add(
        '--bypass-server-setup',
        default=False,
        action='store_true',
        dest='bypass_server_setup',
        help='should bypass traditional server and socket setup',
    )
    add(
        '--local',
        default=False,
        action='store_true',
        dest='local',
        help='Run the server locally on this server rather than setting up'
        ' a heroku server.',
    )

    # debug and verbose output are off unless requested
    messenger.set_defaults(is_debug=False, verbose=False)
def add_parlai_args(self, args=None):
    """
    Add common ParlAI args across all scripts.

    Options are attached to a 'Main ParlAI Arguments' group; the
    ``--datapath`` option is added to the same group at the end.

    :param args: unused; kept so existing callers that pass it keep working.
    """
    parlai = self.add_argument_group('Main ParlAI Arguments')
    parlai.add_argument(
        '-o',
        '--init-opt',
        default=None,
        help='Path to json file of options. '
        'Note: Further Command-line arguments override file-based options.',
    )
    parlai.add_argument(
        '-v',
        '--show-advanced-args',
        action='store_true',
        help='Show hidden command line options (advanced users only)',
    )
    parlai.add_argument(
        '-t', '--task', help='ParlAI task(s), e.g. "babi:Task1" or "babi,cbt"'
    )
    parlai.add_argument(
        '--download-path',
        default=None,
        hidden=True,
        # the space after the period was missing, gluing the two sentences
        help='path for non-data dependencies to store any needed files. '
        'defaults to {parlai_dir}/downloads',
    )
    parlai.add_argument(
        '-dt',
        '--datatype',
        default='train',
        choices=[
            'train',
            'train:stream',
            'train:ordered',
            'train:ordered:stream',
            'train:stream:ordered',
            'train:evalmode',
            'train:evalmode:stream',
            'train:evalmode:ordered',
            'train:evalmode:ordered:stream',
            'train:evalmode:stream:ordered',
            'valid',
            'valid:stream',
            'test',
            'test:stream',
        ],
        help='choose from: train, train:ordered, valid, test. to stream '
        'data add ":stream" to any option (e.g., train:stream). '
        'by default: train is random with replacement, '
        'valid is ordered, test is ordered.',
    )
    parlai.add_argument(
        '-im',
        '--image-mode',
        default='raw',
        type=str,
        help='image preprocessor to use. default is "raw". set to "none" '
        'to skip image loading.',
        hidden=True,
    )
    parlai.add_argument(
        '-nt',
        '--numthreads',
        default=1,
        type=int,
        # sentence previously ended with a stray comma
        help='number of threads. Used for hogwild if batchsize is 1, else '
        'for number of threads in threadpool loading.',
    )
    parlai.add_argument(
        '--hide-labels',
        default=False,
        type='bool',
        hidden=True,
        help='default (False) moves labels in valid and test sets to the '
        'eval_labels field. If True, they are hidden completely.',
    )
    parlai.add_argument(
        '-mtw',
        '--multitask-weights',
        type='floats',
        default=[1],
        help='list of floats, one for each task, specifying '
        'the probability of drawing the task in multitask case',
        hidden=True,
    )
    parlai.add_argument(
        '-bs',
        '--batchsize',
        default=1,
        type=int,
        help='batch size for minibatch training schemes',
    )
    self.add_parlai_data_path(parlai)
def add_distributed_training_args(self):
    """
    Add CLI args for distributed training.

    :return: the 'Distributed Training' argument group, so callers can
        attach further options to it.
    """
    group = self.add_argument_group('Distributed Training')
    add = group.add_argument
    add('--distributed-world-size', help='Number of workers.', type=int)
    add(
        '--verbose',
        hidden=True,
        default=False,
        type='bool',
        help='All workers print output.',
    )
    return group
def add_pytorch_datateacher_args(self):
    """
    Add CLI args for PytorchDataTeacher.

    Options are attached to a 'PytorchData Arguments' group. Several help
    strings are built from adjacent string literals; each fragment ends with
    an explicit space so the joined sentence reads correctly.
    """
    pytorch = self.add_argument_group('PytorchData Arguments')
    pytorch.add_argument(
        '-pyt',
        '--pytorch-teacher-task',
        help='Use the PytorchDataTeacher for multiprocessed '
        'data loading with a standard ParlAI task, e.g. "babi:Task1k"',
    )
    pytorch.add_argument(
        '-pytd',
        '--pytorch-teacher-dataset',
        help='Use the PytorchDataTeacher for multiprocessed '
        'data loading with a pytorch Dataset, e.g. "vqa_1" or "flickr30k"',
    )
    pytorch.add_argument(
        '--pytorch-datapath',
        type=str,
        default=None,
        help='datapath for pytorch data loader '
        '(note: only specify if the data does not reside '
        'in the normal ParlAI datapath)',
        hidden=True,
    )
    pytorch.add_argument(
        '-nw',
        '--numworkers',
        type=int,
        default=4,
        help='how many workers the Pytorch dataloader should use',
        hidden=True,
    )
    pytorch.add_argument(
        '--pytorch-preprocess',
        type='bool',
        default=False,
        help='Whether the agent should preprocess the data while building '
        'the pytorch data',
        hidden=True,
    )
    pytorch.add_argument(
        '-pybsrt',
        '--pytorch-teacher-batch-sort',
        type='bool',
        default=False,
        # the original text stopped mid-sentence with an unclosed parenthesis
        help='Whether to construct batches of similarly sized episodes '
        'when using the PytorchDataTeacher (either via specifying `-pyt` '
        'or `-pytd`)',
        hidden=True,
    )
    pytorch.add_argument(
        '--batch-sort-cache-type',
        type=str,
        choices=['pop', 'index', 'none'],
        default='pop',
        help='how to build up the batch cache',
        hidden=True,
    )
    pytorch.add_argument(
        '--batch-length-range',
        type=int,
        default=5,
        help='degree of variation of size allowed in batch',
        hidden=True,
    )
    pytorch.add_argument(
        '--shuffle',
        type='bool',
        default=False,
        help='Whether to shuffle the data',
        hidden=True,
    )
    pytorch.add_argument(
        '--batch-sort-field',
        type=str,
        default='text',
        help='What field to use when determining the length of an episode',
        hidden=True,
    )
    pytorch.add_argument(
        '-pyclen',
        '--pytorch-context-length',
        default=-1,
        type=int,
        help='Number of past utterances to remember when building flattened '
        'batches of data in multi-example episodes. '
        '(For use with PytorchDataTeacher)',
        hidden=True,
    )
    pytorch.add_argument(
        '-pyincl',
        '--pytorch-include-labels',
        default=True,
        type='bool',
        help='Specifies whether or not to include labels as past utterances when '
        'building flattened batches of data in multi-example episodes. '
        '(For use with PytorchDataTeacher)',
        hidden=True,
    )
def add_model_args(self):
    """
    Add arguments related to models such as model files.

    Options are attached to a 'ParlAI Model Arguments' group.
    """
    group = self.add_argument_group('ParlAI Model Arguments')
    add = group.add_argument
    add(
        '-m',
        '--model',
        help='the model class name. can match parlai/agents/<model> for '
        'agents in that directory, or can provide a fully specified '
        'module for `from X import Y` via `-m X:Y` '
        '(e.g. `-m parlai.agents.seq2seq.seq2seq:Seq2SeqAgent`)',
        default=None,
    )
    add(
        '-mf',
        '--model-file',
        help='model file name for loading and saving models',
        default=None,
    )
    # NOTE(review): '-im' is also the short form registered for --image-mode
    # by add_parlai_args; which long option ends up owning it depends on the
    # parser's conflict handling - confirm this is intended.
    add(
        '-im',
        '--init-model',
        help='load model weights and dict from this file',
        type=str,
        default=None,
    )
    add('--dict-class', help='the class of the dictionary agent uses', hidden=True)
def add_model_subargs(self, model):
    """
    Add arguments specific to a particular model.

    Looks the agent up with ``get_agent_module``; if it defines
    ``add_cmdline_args`` that hook is run against this parser, and if it
    defines ``dictionary_class`` the string form of that class becomes the
    default for ``dict_class``.

    :param model: model name as accepted by ``get_agent_module``.
    """
    agent = get_agent_module(model)

    if hasattr(agent, 'add_cmdline_args'):
        try:
            agent.add_cmdline_args(self)
        except argparse.ArgumentError:
            # already added
            pass

    if hasattr(agent, 'dictionary_class'):
        try:
            dict_class_path = class2str(agent.dictionary_class())
            self.set_defaults(dict_class=dict_class_path)
        except argparse.ArgumentError:
            # already added
            pass
def add_task_args(self, task):
    """
    Add arguments specific to the specified task.

    ``task`` is expanded with ``ids_to_tasks`` and split on commas; every
    resulting task module that defines ``add_cmdline_args`` gets to register
    its options on this parser.

    :param task: task identifier string, possibly naming several tasks.
    """
    task_names = ids_to_tasks(task).split(',')
    for task_name in task_names:
        teacher = get_task_module(task_name)
        if not hasattr(teacher, 'add_cmdline_args'):
            continue
        try:
            teacher.add_cmdline_args(self)
        except argparse.ArgumentError:
            # already added
            pass
def add_pyt_dataset_args(self, opt):
    """
    Add arguments specific to specified pytorch dataset.

    :param opt: options passed to ``get_dataset_classes`` to find the
        dataset classes; each one defining ``add_cmdline_args`` registers
        its options on this parser.
    """
    # imported here rather than at module level
    from parlai.core.pytorch_data_teacher import get_dataset_classes

    for dataset_cls, _, _ in get_dataset_classes(opt):
        if not hasattr(dataset_cls, 'add_cmdline_args'):
            continue
        try:
            dataset_cls.add_cmdline_args(self)
        except argparse.ArgumentError:
            # already added
            pass
def add_image_args(self, image_mode):
    """
    Add additional arguments for handling images.

    :param image_mode: accepted for the caller's convenience; the body does
        not read it.
    """
    try:
        group = self.add_argument_group('ParlAI Image Preprocessing Arguments')
        add = group.add_argument
        add(
            '--image-size',
            hidden=True,
            help='resizing dimension for images',
            default=256,
            type=int,
        )
        add(
            '--image-cropsize',
            hidden=True,
            help='crop dimension for images',
            default=224,
            type=int,
        )
    except argparse.ArgumentError:
        # already added
        pass
def add_extra_args(self, args=None):
    """
    Add more args depending on how known args are set.

    Does a preliminary parse (ignoring help flags), folds in options from an
    ``--init-opt`` file if one is named, then registers the image, task,
    pytorch-dataset and model specific options that those values call for.

    :param args: argument list to inspect; None means the process arguments.
    :raises RuntimeError: if re-applying the parser defaults raises an
        AttributeError.
    """
    known, _ = self.parse_known_args(args, nohelp=True)
    parsed = vars(known)

    # Also load extra args options if a file is given.
    init_opt = parsed.get('init_opt', None)
    if init_opt is not None:
        self._load_known_opts(init_opt, parsed)
    parsed = self._infer_datapath(parsed)

    # find which image mode specified if any, and add additional arguments
    image_mode = parsed.get('image_mode', None)
    if not (image_mode is None or image_mode == 'none'):
        self.add_image_args(image_mode)

    # the main task, the eval task and the pytorch teacher task, when
    # specified, each contribute their task-specific arguments
    for task_key in ('task', 'evaltask', 'pytorch_teacher_task'):
        task_value = parsed.get(task_key, None)
        if task_value is not None:
            self.add_task_args(task_value)

    # find pytorch dataset if specified, add its specific arguments
    if parsed.get('pytorch_teacher_dataset', None) is not None:
        self.add_pyt_dataset_args(parsed)

    # find which model specified if any, and add its specific arguments
    model_name = get_model_name(parsed)
    if model_name is not None:
        self.add_model_subargs(model_name)

    # reset parser-level defaults over any model-level defaults
    try:
        self.set_defaults(**self._defaults)
    except AttributeError:
        raise RuntimeError(
            'Please file an issue on github that argparse '
            'got an attribute error when parsing.'
        )
def parse_known_args(self, args=None, namespace=None, nohelp=False):
    """
    Parse known args to ignore help flag.

    Option names are normalized with ``fix_underscores`` before parsing.

    :param args: argument list; None means the process arguments.
    :param namespace: forwarded to argparse.
    :param nohelp: if True, drop ``-h`` / ``--help`` before parsing so the
        parse cannot print help and exit.
    :return: ``(namespace, remaining_args)`` as returned by argparse.
    """
    # args default to the system args
    raw_args = _sys.argv[1:] if args is None else args
    cleaned = fix_underscores(raw_args)
    if nohelp:
        # ignore help
        cleaned = [tok for tok in cleaned if tok not in ('-h', '--help')]
    return super().parse_known_args(cleaned, namespace)
def _load_known_opts(self, optfile, parsed):
    """
    Pull in CLI args for proper models/tasks/etc.

    Called before args are parsed; ``_load_opts`` is used for actually
    overriding opts after they are parsed.

    :param optfile: path handed to ``load_opt_file``.
    :param parsed: dict of preliminary parsed values, updated in place; only
        keys that are missing or None are filled from the file.
    """
    file_opts = load_opt_file(optfile)
    for key, value in file_opts.items():
        # existing command line parameters take priority.
        if parsed.get(key) is None:
            parsed[key] = value
def _load_opts(self, opt):
    """
    Apply the options stored in ``opt['init_opt']`` on top of ``opt``.

    Keys already present in ``opt['override']`` are left alone; every other
    key from the file is written to both ``opt`` and ``opt['override']``.

    :param opt: parsed options, modified in place.
    :raises RuntimeError: if the file names a key that ``opt`` does not have.
    """
    file_opts = load_opt_file(opt.get('init_opt'))
    for key, value in file_opts.items():
        if key not in opt:
            raise RuntimeError(
                'Trying to set opt from file that does not exist: ' + str(key)
            )
        # existing command line parameters take priority.
        if key in opt['override']:
            continue
        opt[key] = value
        opt['override'][key] = value
def _infer_datapath(self, opt):
"""
Set the value for opt['datapath'] and opt['download_path'].
Sets the value for opt['datapath'] and opt['download_path'], correctly
respecting environmental variables and the default.
"""
# set environment variables
# Priority for setting the datapath (same applies for download_path):
# --datapath -> os.environ['PARLAI_DATAPATH'] -> <self.parlai_home>/data
if opt.get('download_path'):
os.environ['PARLAI_DOWNPATH'] = opt['download_path']
elif os.environ.get('PARLAI_DOWNPATH') is None:
os.environ['PARLAI_DOWNPATH'] = os.path.join(self.parlai_home, 'downloads')
if opt.get('datapath'):
os.environ['PARLAI_DATAPATH'] = opt['datapath']
elif os.environ.get('PARLAI_DATAPATH') is None:
os.environ['PARLAI_DATAPATH'] = os.path.join(self.parlai_home, 'data')
opt['download_path'] = os.environ['PARLAI_DOWNPATH']
opt['datapath'] = os.environ['PARLAI_DATAPATH']
return opt
def _process_args_to_opts(self):
    """
    Turn the parsed namespace in ``self.args`` into ``self.opt``.

    Steps, in order: wrap the namespace in an ``Opt``; resolve the data and
    download paths; record every option that was explicitly given on the
    command line in ``opt['override']``; apply an ``--init-opt`` file if one
    was named; resolve ``model_file`` / ``dict_file`` through
    ``modelzoo_path``; stamp the start time.

    When detecting command-line overrides, flags are matched after
    underscore-to-hyphen normalization, ``--flag=value`` tokens are
    recognized, and a following token that is a number (e.g. ``-1``) counts
    as a value rather than as another flag.
    """

    def _is_value(token, known_options):
        # a token is an option's value unless it is a registered option
        # string or some other dash-prefixed, non-numeric flag
        if token in known_options:
            return False
        if not token.startswith('-'):
            return True
        try:
            float(token)
        except ValueError:
            return False
        return True

    self.opt = Opt(vars(self.args))

    # custom post-parsing
    self.opt['parlai_home'] = self.parlai_home
    self.opt = self._infer_datapath(self.opt)

    # map every option string to its destination and note the boolean flags
    option_strings_dict = {}
    store_true = set()
    store_false = set()
    for group in self._action_groups:
        for action in group._group_actions:
            for option in getattr(action, 'option_strings', ()):
                option_strings_dict[option] = action.dest
                if isinstance(action, argparse._StoreTrueAction):
                    store_true.add(option)
                elif isinstance(action, argparse._StoreFalseAction):
                    store_false.add(option)

    # set all arguments specified in command line as overridable
    cli_args = fix_underscores(self.cli_args) or []
    for index, token in enumerate(cli_args):
        flag, has_inline_value, _ = token.partition('=')
        if token in option_strings_dict:
            flag, has_inline_value = token, ''
        elif not has_inline_value or flag not in option_strings_dict:
            continue

        key = option_strings_dict[flag]
        if flag in store_true:
            self.overridable[key] = True
        elif flag in store_false:
            self.overridable[key] = False
        elif has_inline_value:
            self.overridable[key] = self.opt[key]
        elif index + 1 < len(cli_args) and _is_value(
            cli_args[index + 1], option_strings_dict
        ):
            self.overridable[key] = self.opt[key]
    self.opt['override'] = self.overridable

    # load opts if a file is provided.
    if self.opt.get('init_opt', None) is not None:
        self._load_opts(self.opt)

    # map filenames that start with 'zoo:' to point to the model zoo dir,
    # both in the options themselves and in the recorded overrides
    for file_key in ('model_file', 'dict_file'):
        if self.opt.get(file_key) is not None:
            self.opt[file_key] = modelzoo_path(
                self.opt.get('datapath'), self.opt[file_key]
            )
        if self.opt['override'].get(file_key) is not None:
            self.opt['override'][file_key] = modelzoo_path(
                self.opt.get('datapath'), self.opt['override'][file_key]
            )

    # add start time of an experiment
    self.opt['starttime'] = datetime.datetime.today().strftime('%b%d_%H-%M')
def parse_and_process_known_args(self, args=None):
    """
    Parse provided arguments and return parlai opts and unknown arg list.

    Runs the same arg->opt parsing that parse_args does, but doesn't
    throw an error if the args being parsed include additional command
    line arguments that parlai doesn't know what to do with.

    :param args: argument list forwarded to argparse.
    :return: ``(opt, unknown_args)``.
    """
    # goes straight to argparse's implementation, not this class's override
    namespace, leftover = super().parse_known_args(args=args)
    self.args = namespace
    self._process_args_to_opts()
    return self.opt, leftover
def parse_args(self, args=None, namespace=None, print_args=True):
    """
    Parse the provided arguments and returns a dictionary of the ``args``.

    We specifically remove items with ``None`` as values in order
    to support the style ``opt.get(key, default)``, which would otherwise
    return ``None``.

    :param args: argument list; forwarded to argparse.
    :param namespace: accepted for signature compatibility; not forwarded.
    :param print_args: if True, print the parsed options and any
        announcements.
    :return: the processed options (``self.opt``).
    """
    # register task/model specific options before the real parse
    self.add_extra_args(args)
    self.args = super().parse_args(args=args)
    self._process_args_to_opts()

    if not print_args:
        return self.opt

    self.print_args()
    print_announcements(self.opt)
    return self.opt
def print_args(self):
    """
    Print out all the arguments in this parser.

    Options are printed per argument group, sorted by name within the group;
    a group's title line is printed only if at least one of its options is
    present in ``self.opt``. If nothing has been parsed yet (``self.opt`` is
    unset or empty), the arguments are parsed first.
    """
    # self.opt only exists once parsing has run, so look it up defensively
    if not getattr(self, 'opt', None):
        self.parse_args(print_args=False)

    values = {str(key): str(value) for key, value in self.opt.items()}
    for group in self._action_groups:
        dests = {action.dest for action in group._group_actions}
        keys = sorted(dest for dest in dests if dest in values)
        if not keys:
            continue
        print('[ ' + group.title + ': ] ')
        for key in keys:
            print('[ ' + key + ': ' + values[key] + ' ]')
def set_params(self, **kwargs):
    """
    Set overridable kwargs.

    Each keyword becomes a parser default and is also recorded in
    ``self.overridable``.
    """
    self.set_defaults(**kwargs)
    self.overridable.update(kwargs)
@property
def show_advanced_args(self):
"""Check if we should show arguments marked as hidden."""
if hasattr(self, '_show_advanced_args'):
return self._show_advanced_args
known_args, _ = self.parse_known_args(nohelp=True)
if hasattr(known_args, 'show_advanced_args'):
self._show_advanced_args = known_args.show_advanced_args