diff --git a/test/test_collector.py b/test/test_collector.py index f0255ff6ae8..00ae2aed519 100644 --- a/test/test_collector.py +++ b/test/test_collector.py @@ -2822,6 +2822,30 @@ def test_collector_rb_multisync(self): collector.shutdown() assert len(rb) == 256 + def test_collector_rb_multiasync(self): + env = GymEnv(CARTPOLE_VERSIONED()) + env.set_seed(0) + + rb = ReplayBuffer(storage=LazyTensorStorage(256), batch_size=5) + rb.add(env.rand_step(env.reset())) + rb.empty() + + collector = MultiaSyncDataCollector( + [lambda: env, lambda: env], + RandomPolicy(env.action_spec), + replay_buffer=rb, + total_frames=256, + frames_per_batch=16, + ) + torch.manual_seed(0) + pred_len = 0 + for c in collector: + pred_len += 16 + assert c is None + assert len(rb) >= pred_len + collector.shutdown() + assert len(rb) == 256 + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index d5b6ff0509d..39ca6f23c5c 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -16,7 +16,7 @@ import sys import time import warnings -from collections import OrderedDict +from collections import defaultdict, OrderedDict from copy import deepcopy from multiprocessing import connection, queues from multiprocessing.managers import SyncManager @@ -2428,7 +2428,7 @@ class MultiaSyncDataCollector(_MultiDataCollector): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.out_tensordicts = {} + self.out_tensordicts = defaultdict(lambda: None) self.running = False if self.postprocs is not None: @@ -2473,7 +2473,9 @@ def frames_per_batch_worker(self): def _get_from_queue(self, timeout=None) -> Tuple[int, int, TensorDictBase]: new_data, j = self.queue_out.get(timeout=timeout) use_buffers = self._use_buffers - if j == 0 or not use_buffers: + if self.replay_buffer is not None: + idx = new_data + elif j == 0 or not use_buffers: try: data, idx = new_data self.out_tensordicts[idx] = data @@ -2488,7 +2490,7 @@ def _get_from_queue(self, timeout=None) -> Tuple[int, int, TensorDictBase]: else: idx = new_data out = self.out_tensordicts[idx] - if j == 0 or use_buffers: + if not self.replay_buffer and (j == 0 or use_buffers): # we clone the data to make sure that we'll be working with a fixed copy out = out.clone() return idx, j, out @@ -2513,9 +2515,12 @@ def iterator(self) -> Iterator[TensorDictBase]: _check_for_faulty_process(self.procs) self._iter += 1 idx, j, out = self._get_from_queue() - worker_frames = out.numel() - if self.split_trajs: - out = split_trajectories(out, prefix="collector") + if self.replay_buffer is None: + worker_frames = out.numel() + if self.split_trajs: + out = split_trajectories(out, prefix="collector") + else: + worker_frames = self.frames_per_batch_worker self._frames += worker_frames workers_frames[idx] = workers_frames[idx] + worker_frames if self.postprocs: @@ -2531,7 +2536,7 @@ def iterator(self) -> Iterator[TensorDictBase]: else: msg = "continue" self.pipes[idx].send((idx, msg)) - if self._exclude_private_keys: + if out is not None and self._exclude_private_keys: excluded_keys = [key for key in out.keys() if key.startswith("_")] out = out.exclude(*excluded_keys) yield out