#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Provide PytorchDataTeacher functionality.
To use this class, please follow the tutorial here:
http://parl.ai/docs/tutorial_worlds.html#multiprocessed-pytorch-dataloader
"""
from .teachers import FixedDialogTeacher
from parlai.core.utils import warn_once
from parlai.scripts.build_pytorch_data import build_data
from .agents import get_agent_module
import json
import math
import collections.abc
import random
import os
from functools import wraps
import importlib
from functools import lru_cache
try:
import torch # noqa: F401
except ImportError:
    raise ImportError('Need to install PyTorch: go to pytorch.org')
from torch.utils.data import ConcatDataset, Dataset, DataLoader, sampler
from torch.multiprocessing import Lock, Value
import ctypes
from threading import Thread, Condition, RLock

if torch.version.__version__.startswith('0.'):
raise ImportError(
"Please upgrade to PyTorch >=1.0; "
"visit https://pytorch.org for instructions."
)


class BatchSortCache(object):
"""
Object that encapsulates the functionality of the batch sort cache.
    Maps episode length to a dictionary with the following keys:
- ``current_idx``: which episode in the list are we at (if simply indexing
into list)
- ``ep_list``: list of episodes of the length of the key
- ``bucket_complete``: if there are no more episodes left to consider in the
bucket
"""
@classmethod
def create(cls):
"""Singleton factory."""
if not hasattr(cls, 'length_to_eps'):
# Maps episode length to list of episodes
cls.length_to_eps = {}
# Set of episode indices already in the cache
cls.ep_indices = set()
# List of batches if popping batches
cls.batches = []
# If all episodes have been loaded into memory
cls.load_complete = Value(ctypes.c_bool, False)
# Lock to access batches
cls.batches_lock = Lock()
# Lock to access length_to_eps
cls.cache_lock = Lock()
# Lock for condition variables
cls.fill_cache_lock = RLock()
# Condition notifying Loader to add to cache
cls.add_to_cache_cv = Condition(lock=cls.fill_cache_lock)
# Condition notifying teacher that cache has episodes
cls.cache_filled_cv = Condition(lock=cls.fill_cache_lock)
@classmethod
def destroy(cls):
"""Singleton destroyer."""
if hasattr(cls, 'length_to_eps'):
del cls.length_to_eps
del cls.ep_indices
del cls.batches
del cls.load_complete
del cls.batches_lock
del cls.cache_lock
del cls.fill_cache_lock
del cls.add_to_cache_cv
del cls.cache_filled_cv
@classmethod
def batch_cache(cls, function):
"""Create the cache of batches."""
max_cache_size = 10000 # Max unseen eps
min_cache_size = 1000 # Min unseen eps
def get_cache_size():
"""Return number of available episodes."""
return sum(
len(v['ep_list']) - v['current_idx']
for k, v in cls.length_to_eps.items()
)
def get_available_buckets(bsz):
"""Return buckets where there are enough episodes for a batch."""
if cls.load_complete.value:
return {
k: v
for k, v in cls.length_to_eps.items()
if not v['bucket_complete']
or len(v['ep_list']) - v['current_idx'] > 0
}
else:
return {
k: v
for k, v in cls.length_to_eps.items()
if len(v['ep_list']) - v['current_idx'] >= bsz
}
def reset():
"""Reset the indices into the buckets."""
with cls.cache_lock:
for idx in cls.length_to_eps:
cls.length_to_eps[idx]['current_idx'] = 0
cls.length_to_eps[idx]['bucket_complete'] = False
def consolidate(caller):
"""Consolidate remaining episodes into batches."""
cls.load_complete.value = True
bsz = caller.bsz
batch = []
sorted_lengths = sorted(cls.length_to_eps.keys())
with cls.cache_lock:
if caller.batch_cache_type == 'index':
for length in sorted_lengths:
current_idx = cls.length_to_eps[length]['current_idx']
ep_list = cls.length_to_eps[length]['ep_list']
unseen_eps = ep_list[current_idx:]
cls.length_to_eps[length]['ep_list'] = ep_list[:current_idx]
batch = unseen_eps + batch
while len(batch) >= bsz:
cls.length_to_eps[length]['ep_list'] += batch[:bsz]
batch = batch[bsz:]
if len(batch) > 0:
cls.length_to_eps[-1] = {
'current_idx': 0,
'ep_list': batch,
'bucket_complete': False,
}
elif caller.batch_cache_type == 'pop':
for length in sorted_lengths:
batch += cls.length_to_eps[length]['ep_list']
with cls.batches_lock:
while len(batch) >= bsz:
cls.batches.append(batch[:bsz])
batch = batch[bsz:]
if len(batch) > 0:
with cls.batches_lock:
cls.batches.append(batch)
def flatten(l):
"""Flatten a list."""
return [item for sublist in l for item in sublist]
def put_in_cache(ep_idx, episode, caller):
"""Put episode `ep_idx` into cache."""
length = ep_length(episode[caller.batch_sort_field])
lengths = [length] + flatten(
[
[length + i, length + (i * -1)]
for i in range(1, caller.batch_length_range)
]
)
lengths = [max(i, 1) for i in lengths]
in_cache = ep_idx in cls.ep_indices
# first check if episode can go in existing bucket
if not in_cache:
for l in lengths:
if l in cls.length_to_eps:
with cls.cache_lock:
cls.length_to_eps[l]['ep_list'] += [(ep_idx, episode)]
cls.ep_indices.add(ep_idx)
in_cache = True
break
# otherwise, make a new bucket
if not in_cache:
with cls.cache_lock:
cls.length_to_eps[length] = {
'current_idx': 0,
'ep_list': [(ep_idx, episode)],
'bucket_complete': False,
}
cls.ep_indices.add(ep_idx)
if ep_idx == caller.dataset.num_episodes() - 1:
consolidate(caller)
with cls.add_to_cache_cv:
cls.cache_filled_cv.notify_all()
@wraps(function)
def wrapper(*args):
"""Wrap a function."""
# TODO: refactor
caller = args[0]
batch_sort = caller.batch_sort
batch_cache_type = caller.batch_cache_type
bsz = caller.bsz
if not batch_sort or not caller.datatype.startswith('train'):
return function(*args)
# If Loader, put episodes in cache
if isinstance(caller, LoaderProcess):
with cls.add_to_cache_cv:
while (
get_cache_size() >= max_cache_size
and len(get_available_buckets(bsz)) > 0
):
cls.cache_filled_cv.notify_all()
cls.add_to_cache_cv.wait()
idx_and_batch = function(*args)
if idx_and_batch is None:
return None
for ep_index, ep in idx_and_batch[1]:
put_in_cache(ep_index, ep, caller)
return idx_and_batch
# If teacher, return batch of episodes
else:
teacher = caller
num_batches = teacher.num_batches
while True:
with cls.cache_filled_cv:
while not cls.load_complete.value and (
get_cache_size() <= min_cache_size
or len(get_available_buckets(bsz)) == 0
):
cls.add_to_cache_cv.notify()
cls.cache_filled_cv.wait()
if cls.load_complete.value and batch_cache_type == 'pop':
return teacher.batch_idx + 1, random.choice(cls.batches)
batch = None
available_buckets = get_available_buckets(bsz)
if len(available_buckets) != 0:
# Pick length index at random
length = random.choice(list(available_buckets.keys()))
with cls.cache_lock:
current_idx = cls.length_to_eps[length]['current_idx']
ep_list = cls.length_to_eps[length]['ep_list']
num_eps = len(ep_list)
if num_eps - current_idx >= bsz:
if batch_cache_type == 'pop':
batch = ep_list[:bsz]
cls.length_to_eps[length]['ep_list'] = ep_list[bsz:]
else:
batch = ep_list[current_idx : current_idx + bsz]
cls.length_to_eps[length]['current_idx'] = (
current_idx + bsz
)
elif cls.load_complete.value and num_eps > 0:
if batch_cache_type == 'pop':
batch = ep_list
elif num_eps - current_idx > 0:
batch = ep_list[current_idx:]
cls.length_to_eps[length]['current_idx'] = (
num_eps - 1
)
cls.length_to_eps[length]['bucket_complete'] = True
if batch is not None:
if batch_cache_type == 'pop':
with cls.batches_lock:
cls.batches.append(batch)
elif teacher.batch_idx + 1 >= num_batches:
reset()
return teacher.batch_idx + 1, batch
return wrapper
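
# BatchSortCache.batch_cache is applied as a decorator on both sides of the
# pipeline below (LoaderProcess.load_next fills the cache, while
# PytorchDataTeacher.get_next_batch drains it); the two sides cooperate
# through the singleton state and condition variables created above.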


def ep_length(val):
"""Determine the length of an episode, given the specified value."""
if isinstance(val, (int, bytes, bool)):
return 1
if isinstance(val, str):
return len(val.replace('\n', ' ').split(' '))
    if isinstance(
        val, (collections.abc.Mapping, collections.abc.Sequence, torch.Tensor)
    ):
        if isinstance(val, collections.abc.Mapping) and val.get(
            'deserialized_tensor', False
        ):
return len(val['value'])
return len(val)
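
# Illustrative ep_length behavior (a sketch, not output from a real run;
# each line follows one branch above):
#   ep_length(7)                              -> 1  (scalar)
#   ep_length('hello there\nfriend')          -> 3  (whitespace token count)
#   ep_length(torch.zeros(4))                 -> 4  (tensor length)
#   ep_length({'deserialized_tensor': True,
#              'type': 'torch.int64',
#              'value': [1, 2]})              -> 2  (serialized tensor length)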


def get_dataset_classes(opt):
"""
Get datasets from the options.
To use a custom dataset (as opposed to the StreamDataset or ParlAIDataset),
    you can subclass the PyTorch Dataset class and specify its location on the
command line.
For example, the VQA v1 task provides a custom dataset, which can
be specified on the command line as follows: ``-pytd vqa_v1:VQADataset``
Note that if the dataset is named ``DefaultDataset``, then you do
not need to specify its name following the colon; e.g., it
would just be: ``-pytd vqa_v1``
"""
if 'stream' in opt.get('datatype'):
default_dataset = StreamDataset
else:
default_dataset = ParlAIDataset
dataset_name = opt.get('pytorch_teacher_dataset')
task_name = opt.get('pytorch_teacher_task')
datasets = []
if task_name is not None:
datasets += [
(default_dataset, default_collate, task) for task in task_name.split(',')
]
if not dataset_name:
return datasets
sps = [d.strip() for d in dataset_name.split(',')]
for sp in sps:
full_task_name = sp
repo = 'parlai'
if sp.startswith('internal:'):
# To switch to local repo, useful for non-public projects
# (make a directory called 'parlai_internal' with your private agents)
repo = 'parlai_internal'
sp = sp[9:]
sp = sp.split(':')
if '.' in sp[0]:
module_name = sp[0]
else:
dataset = sp[0].lower()
module_name = '{}.tasks.{}.agents'.format(repo, dataset)
if len(sp) > 1:
sp[1] = sp[1][0].upper() + sp[1][1:]
dataset = sp[1]
if '.' not in sp[0] and 'Dataset' not in dataset:
# Reformat from underscore to CamelCase and append "Dataset" to
# class name by default if a complete path is not given.
words = dataset.split('_')
teacher_name = ''
for w in words:
teacher_name += w[0].upper() + w[1:]
dataset = teacher_name + 'Dataset'
else:
dataset = 'DefaultDataset'
my_module = importlib.import_module(module_name)
dataset_class = getattr(my_module, dataset)
collate = default_collate
if hasattr(dataset_class, 'collate'):
collate = dataset_class.collate
elif opt.get('model', False):
agent_class = get_agent_module(opt.get('model'))
if hasattr(agent_class, 'collate'):
collate = agent_class.collate
datasets.append((dataset_class, collate, full_task_name))
return datasets
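
# A minimal custom-dataset sketch satisfying the contract assumed above
# (hypothetical class; __getitem__ must return an (index, episode) pair,
# and num_episodes/num_examples are required by PytorchDataTeacher below):
#
#   class DefaultDataset(Dataset):
#       def __init__(self, opt):
#           self.episodes = []  # load episodes from opt, e.g. a data file
#
#       def __getitem__(self, index):
#           return index, self.episodes[index]
#
#       def __len__(self):
#           return len(self.episodes)
#
#       def num_episodes(self):
#           return len(self.episodes)
#
#       def num_examples(self):
#           return sum(len(ep) for ep in self.episodes)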


class LoaderProcess(Thread):
"""Background thread that submits jobs to the DataLoader."""
def __init__(self, opt):
super().__init__(daemon=True)
dataset_classes = get_dataset_classes(opt)
if len(dataset_classes) > 1:
datasets = []
for class_name, collate_fn, task_name in dataset_classes:
opt['pytorch_teacher_task'] = task_name
opt['task'] = task_name
datasets.append(class_name(opt))
self.collate = collate_fn
self.dataset = ParlAIConcatDataset(datasets)
else:
class_name, self.collate, task_name = dataset_classes[0]
self.dataset = class_name(opt)
self.bsz = opt.get('batchsize', 1)
self.num_workers = opt.get('num_workers', 4)
self.dataloader = DataLoader(
self.dataset,
batch_size=self.bsz,
shuffle=False,
sampler=sampler.SequentialSampler(self.dataset),
num_workers=self.num_workers,
collate_fn=self.collate,
pin_memory=False,
drop_last=False,
)
self.datatype = opt.get('datatype')
self.data = enumerate(self.dataloader)
self.batch_sort = opt.get('pytorch_teacher_batch_sort')
self.batch_cache_type = opt.get('batch_sort_cache_type')
self.batch_length_range = opt.get('batch_length_range')
self.batch_sort_field = opt.get('batch_sort_field')
def run(self):
"""Run the process loop."""
while True:
idx_and_batch = self.load_next()
if idx_and_batch is None:
return
@BatchSortCache.batch_cache
def load_next(self):
"""Get the next item or return ``None``."""
try:
return next(self.data)
except StopIteration:
return None
"""Collating, deserializing, processing batches"""
TORCH_DTYPES = [
torch.float32,
torch.float64,
torch.float16,
torch.uint8,
torch.int8,
torch.int16,
torch.int32,
torch.int64,
]
STR_TO_TORCH_DTYPE = {str(d): d for d in TORCH_DTYPES}


def default_collate(batch):
"""
Collate a batch.
Default collate function, used for ParlAIDataset and StreamDataset.
"""
new_batch = []
for b in batch:
idx = b[0]
if type(b[1]) is list:
ep = b[1][0]
else:
ep = b[1]
new_batch.append((idx, ep))
return new_batch
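
# Illustrative collate behavior (a sketch): the DataLoader hands this
# function a list of (index, data) pairs; when data is a list, only its
# first element is kept, e.g. [(0, [ex_a, ex_b]), (1, ex_c)] becomes
# [(0, ex_a), (1, ex_c)].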


def deserialize(obj):
    """Deserialize lists into Tensors."""
keys = list(obj.keys())
for key in keys:
if type(obj[key]) is dict and obj[key].get('deserialized_tensor', False):
dtype = STR_TO_TORCH_DTYPE[obj[key]['type']]
val = obj[key]['value']
del obj[key]
obj[key] = torch.as_tensor(val, dtype=dtype)
return obj
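
# Illustrative deserialize input/output (a sketch; the field names are the
# ones checked above, while the key 'text_vec' is hypothetical):
#   {'text_vec': {'deserialized_tensor': True,
#                 'type': 'torch.int64',
#                 'value': [1, 2, 3]}}
# becomes
#   {'text_vec': tensor([1, 2, 3])}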


def process(ex_or_batch):
"""Process examples/batches, i.e. deserialize if necessary."""
if type(ex_or_batch) is list:
if all([ep.get('preprocessed') for ep in ex_or_batch]):
ex_or_batch = [deserialize(ep) for ep in ex_or_batch]
else:
if ex_or_batch.get('preprocessed'):
ex_or_batch = deserialize(ex_or_batch)
return ex_or_batch
"""ParlAI Implementations of Pytorch Datasets"""
class StreamDataset(Dataset):
"""A Pytorch Dataset utilizing streaming."""
def __init__(self, opt):
self.opt = opt
self.datatype = opt.get('datatype')
self.datapath = build_data(self.opt)
self.length_datafile = os.path.join(self.datapath, 'data_length')
self.char_index_file = os.path.join(self.datapath, 'char_index')
self.datafile = os.path.join(self.datapath, 'data')
self.training = self.datatype.startswith('train')
self.ordered = 'ordered' in self.datatype or (
'stream' in self.datatype and not opt.get('shuffle')
)
self._load_lens()
def __getitem__(self, index):
if self.ordered or not self.training:
if not hasattr(self, 'data_gen'):
self.data_gen = self._read_episode()
while True:
idx, ep = next(self.data_gen)
if idx == index:
return (index, ep)
else:
episode = []
episode_done = False
with open(self.datafile) as f:
ex_offset = self.char_index[index]
f.seek(ex_offset)
while not episode_done:
example = json.loads(f.readline())
episode.append(example)
episode_done = example['episode_done']
return (index, episode)
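
    # Note on the random-access branch above: `char_index` (written when the
    # data is built by build_pytorch_data) holds per-episode byte offsets
    # into the data file, so f.seek(offset) jumps straight to an episode's
    # first line rather than scanning the whole file.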
def __len__(self):
return self.num_episodes()
def _load_lens(self):
with open(self.length_datafile) as length:
lengths = json.load(length)
self.num_eps = lengths['num_eps']
self.num_exs = lengths['num_exs']
with open(self.char_index_file) as char:
self.char_index = json.load(char)
def _data_generator(self):
while True:
for idx, episode in self._read_episode():
yield idx, episode
def _read_episode(self):
read = open(self.datafile)
episode = []
for idx, line in enumerate(read):
example = json.loads(line)
episode.append(example)
if example['episode_done']:
yield idx, episode
episode = []
read.close()
def num_episodes(self):
"""Return the number of episodes."""
return self.num_eps
def num_examples(self):
"""Return the number of examples."""
return self.num_exs


class ParlAIDataset(Dataset):
    """A PyTorch Dataset, for random sampling."""
def __init__(self, opt):
self.opt = opt
self.datatype = opt.get('datatype')
self.datapath = build_data(self.opt)
self.length_datafile = os.path.join(self.datapath, 'data_length')
self.datafile = os.path.join(self.datapath, 'data')
self.training = self.datatype.startswith('train')
self._load_lens()
self._setup_data()
def __getitem__(self, index):
return index, self.data[index]
def __len__(self):
return self.num_episodes()
def _load_lens(self):
with open(self.length_datafile) as length:
lengths = json.load(length)
self.num_eps = lengths['num_eps']
self.num_exs = lengths['num_exs']
def _setup_data(self):
self.data = []
with open(self.datafile) as f:
for line in f:
self.data.append(json.loads(line))
def num_episodes(self):
"""Return the number of episodes."""
return self.num_eps
def num_examples(self):
"""Return the number of examples."""
return self.num_exs


class ParlAIConcatDataset(ConcatDataset):
"""Override to set num_eps and num_exs."""
@lru_cache(maxsize=1)
def num_episodes(self):
"""Return the number of episodes."""
return sum(d.num_episodes() for d in self.datasets)
@lru_cache(maxsize=1)
def num_examples(self):
"""Return the number of examples."""
return sum(d.num_examples() for d in self.datasets)


class PytorchDataTeacher(FixedDialogTeacher):
    """
    A teacher that loads data using PyTorch Datasets.
For details on how to use, please follow the tutorial here:
http://parl.ai/static/docs/tutorial_worlds.html#multiprocessed-pytorch-dataloader
"""
def __init__(self, opt, shared=None):
opt['batch_sort'] = False
super().__init__(opt, shared)
self.use_batch_act = self.bsz > 1
self.num_workers = opt['numworkers']
self.batch_sort = (
opt.get('pytorch_teacher_batch_sort') and 'train' in self.datatype
)
self.batch_cache_type = opt.get('batch_sort_cache_type')
self.batch_sort_field = opt.get('batch_sort_field')
# One can specify a collate function to use for preparing a batch
self.opt = opt.copy()
self.is_shared = shared is not None
dataset_classes = self._get_dataset_class(opt)
self.ordered = 'ordered' in self.datatype or (
'stream' in self.datatype and not opt.get('shuffle')
)
if self.ordered:
# force index for ordered, so that we see every example
warn_once(
'\nNote: You are using PytorchDataTeacher with ordered '
'examples. Please specify `--shuffle` if you would like '
'to have examples loaded in randomized order.\n'
)
self.batch_cache_type = 'index'
if not shared:
BatchSortCache.create()
if len(dataset_classes) > 1:
datasets = []
for class_name, collate_fn, task_name in dataset_classes:
dataset_opt = opt.copy()
dataset_opt['pytorch_teacher_task'] = task_name
dataset_opt['task'] = task_name
datasets.append(class_name(dataset_opt))
self.collate_fn = collate_fn
self.id = ','.join([d[2] for d in dataset_classes])
self.dataset = ParlAIConcatDataset(datasets)
else:
class_name, self.collate_fn, task_name = dataset_classes[0]
self.id = task_name
self.dataset = class_name(opt)
if self.ordered or not self.training:
data_sampler = sampler.SequentialSampler(self.dataset)
else:
data_sampler = sampler.RandomSampler(self.dataset)
self.pytorch_dataloader = DataLoader(
self.dataset,
batch_size=self.bsz,
sampler=data_sampler,
num_workers=self.num_workers,
collate_fn=self.collate_fn,
pin_memory=False,
drop_last=False,
)
self.lastYs = [None] * self.bsz
if self.batch_sort:
self.loader_process = LoaderProcess(opt)
self.loader_process.start()
self.data = enumerate(self.pytorch_dataloader)
else:
self.dataset = shared['dataset']
self.pytorch_dataloader = shared['pytorch_dataloader']
self.lastYs = shared['lastYs']
self.data = shared['data']
self.id = shared['id']
self.num_batches = math.ceil(self.dataset.num_episodes() / self.bsz)
self.reset()
def _get_dataset_class(self, opt):
return get_dataset_classes(opt)
def reset(self):
"""
Reset the dialog so that it is at the start of the epoch.
Also resets all metrics.
"""
super().reset()
self.reset_data()
def reset_data(self):
"""Reset the data."""
if not self.is_shared:
self.data = enumerate(self.pytorch_dataloader)
self.lastY = None
self.epochDone = False
self.episode = None
self.episode_done = True
self.episode_idx = 0
self.batch_idx = 0
def share(self):
"""Share this teacher."""
shared = super().share()
shared['pytorch_dataloader'] = self.pytorch_dataloader
shared['dataset'] = self.dataset
shared['data'] = self.data
shared['id'] = self.id
return shared
def next_example(self):
"""Get the next example."""
if self.epochDone:
if not self.training:
return {'episode_done': True, 'id': self.getID()}, True
else:
# Reset the data because it is streaming data
self.reset_data()
if self.episode_done:
try:
self.episode_idx, episode = next(self.data)
if self.collate_fn == default_collate:
episode = [ex[1] for ex in episode]
self.episode = process(episode)
self.entry_idx = 0
epoch_done = False
except StopIteration:
ex = {'episode_done': True, 'id': self.getID()}
epoch_done = True
else:
self.entry_idx += 1
if not epoch_done:
ex = self.episode[self.entry_idx]
self.episode_done = ex['episode_done']
if self.episode_done and self.episode_idx + self.bsz >= self.num_episodes():
epoch_done = True
return ex, epoch_done
@BatchSortCache.batch_cache
def get_next_batch(self):
"""Get the next batch."""
# employs a cache to see if there is a batch of equal size ready
batch = next(self.data)
return batch
def next_batch(self):
"""Get the next batch."""
if self.epochDone:
if not self.training:
return [{'episode_done': True, 'id': self.getID()}] * self.bsz
else:
# Reset the data because it is streaming data
self.reset_data()
try:
self.batch_idx, batch = self.get_next_batch()
if self.collate_fn == default_collate:
batch = [b[1] for b in batch]
batch = process(batch)
epoch_done = False
except StopIteration:
batch = [{'episode_done': True, 'id': self.getID()}] * self.bsz
epoch_done = True
if not epoch_done and self.batch_idx == self.num_batches:
epoch_done = True
self.epochDone = epoch_done
return batch
def num_episodes(self):
"""Get the number of episodes in this dataset."""
return self.dataset.num_episodes()
def num_examples(self):
"""Get the total number of examples in this dataset."""
return self.dataset.num_examples()
def act(self):
"""Send new dialog message."""
action = super().act()
self.lastY = action.get('labels', action.get('eval_labels', None))
return action
def shutdown(self):
"""Shut down."""
super().shutdown()
BatchSortCache.destroy()


class DefaultTeacher(PytorchDataTeacher):
"""
Alias for PytorchDataTeacher.
This exists to simplify loading code in parlai.core.agents.get_task_module.
"""
pass
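
# A minimal usage sketch (assumptions: a working ParlAI install, a task that
# has already been preprocessed by parlai.scripts.build_pytorch_data, and
# flag names mirroring the opt keys read above):
#
#   from parlai.core.params import ParlaiParser
#
#   opt = ParlaiParser(True, True).parse_args(
#       ['--task', 'pytorch_teacher',
#        '--pytorch-teacher-task', 'babi:task10k:1']
#   )
#   teacher = PytorchDataTeacher(opt)
#   for _ in range(5):
#       print(teacher.act())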