diff --git a/test/test_collector.py b/test/test_collector.py
index 9a356d89637..e8e32e4d8c9 100644
--- a/test/test_collector.py
+++ b/test/test_collector.py
@@ -2823,6 +2823,30 @@ def test_collector_rb_multisync(self):
         collector.shutdown()
         assert len(rb) == 256
 
+    def test_collector_rb_multiasync(self):
+        env = GymEnv(CARTPOLE_VERSIONED())
+        env.set_seed(0)
+
+        rb = ReplayBuffer(storage=LazyTensorStorage(256), batch_size=5)
+        rb.add(env.rand_step(env.reset()))
+        rb.empty()
+
+        collector = MultiaSyncDataCollector(
+            [lambda: env, lambda: env],
+            RandomPolicy(env.action_spec),
+            replay_buffer=rb,
+            total_frames=256,
+            frames_per_batch=16,
+        )
+        torch.manual_seed(0)
+        pred_len = 0
+        for c in collector:
+            pred_len += 16
+            assert c is None
+            assert len(rb) >= pred_len
+        collector.shutdown()
+        assert len(rb) == 256
+
 
 if __name__ == "__main__":
     args, unknown = argparse.ArgumentParser().parse_known_args()
diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py
index 0ce425e50e3..7005010fc1a 100644
--- a/torchrl/collectors/collectors.py
+++ b/torchrl/collectors/collectors.py
@@ -16,7 +16,7 @@
 import sys
 import time
 import warnings
-from collections import OrderedDict
+from collections import defaultdict, OrderedDict
 from copy import deepcopy
 from multiprocessing import connection, queues
 from multiprocessing.managers import SyncManager
@@ -1023,7 +1023,7 @@ def _update_traj_ids(self, env_output) -> None:
             traj_ids = self._shuttle.get(("collector", "traj_ids"))
             pool = self._traj_pool
             new_traj = pool.get_traj_and_increment(
-                traj_sop.sum().item(), device=traj_sop.device
+                traj_sop.sum(), device=traj_sop.device
             )
             traj_ids = traj_ids.masked_scatter(traj_sop, new_traj)
             self._shuttle.set(("collector", "traj_ids"), traj_ids)
@@ -1756,6 +1756,8 @@ def _run_processes(self) -> None:
         self.procs = []
         self.pipes = []
         traj_pool = _TrajectoryPool(lock=True)
+        self._traj_pool = traj_pool
+
         for i, (env_fun, env_fun_kwargs) in enumerate(
             zip(self.create_env_fn, self.create_env_kwargs)
         ):
@@ -2431,7 +2433,7 @@ class MultiaSyncDataCollector(_MultiDataCollector):
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        self.out_tensordicts = {}
+        self.out_tensordicts = defaultdict(lambda: None)
         self.running = False
 
         if self.postprocs is not None:
@@ -2476,7 +2478,9 @@ def frames_per_batch_worker(self):
     def _get_from_queue(self, timeout=None) -> Tuple[int, int, TensorDictBase]:
         new_data, j = self.queue_out.get(timeout=timeout)
         use_buffers = self._use_buffers
-        if j == 0 or not use_buffers:
+        if self.replay_buffer is not None:
+            idx = new_data
+        elif j == 0 or not use_buffers:
             try:
                 data, idx = new_data
                 self.out_tensordicts[idx] = data
@@ -2491,7 +2495,7 @@ def _get_from_queue(self, timeout=None) -> Tuple[int, int, TensorDictBase]:
         else:
             idx = new_data
         out = self.out_tensordicts[idx]
-        if j == 0 or use_buffers:
+        if not self.replay_buffer and (j == 0 or use_buffers):
             # we clone the data to make sure that we'll be working with a fixed copy
             out = out.clone()
         return idx, j, out
@@ -2516,9 +2520,12 @@ def iterator(self) -> Iterator[TensorDictBase]:
             _check_for_faulty_process(self.procs)
             self._iter += 1
             idx, j, out = self._get_from_queue()
-            worker_frames = out.numel()
-            if self.split_trajs:
-                out = split_trajectories(out, prefix="collector")
+            if self.replay_buffer is None:
+                worker_frames = out.numel()
+                if self.split_trajs:
+                    out = split_trajectories(out, prefix="collector")
+            else:
+                worker_frames = self.frames_per_batch_worker
             self._frames += worker_frames
             workers_frames[idx] = workers_frames[idx] + worker_frames
             if self.postprocs:
@@ -2534,7 +2541,7 @@ def iterator(self) -> Iterator[TensorDictBase]:
             else:
                 msg = "continue"
             self.pipes[idx].send((idx, msg))
-            if self._exclude_private_keys:
+            if out is not None and self._exclude_private_keys:
                 excluded_keys = [key for key in out.keys() if key.startswith("_")]
                 out = out.exclude(*excluded_keys)
             yield out
@@ -3016,15 +3023,13 @@ def __init__(self, ctx=None, lock: bool = False):
         self.ctx = ctx
         if ctx is None:
             self._traj_id = mp.Value("i", 0)
-            self.lock = contextlib.nullcontext() if not lock else mp.Lock()
+            self.lock = contextlib.nullcontext() if not lock else mp.RLock()
         else:
             self._traj_id = ctx.Value("i", 0)
-            self.lock = contextlib.nullcontext() if not lock else ctx.Lock()
+            self.lock = contextlib.nullcontext() if not lock else ctx.RLock()
 
     def get_traj_and_increment(self, n=1, device=None):
         with self.lock:
-            traj_id = torch.arange(
-                self._traj_id.value, self._traj_id.value + n, device=device
-            )
-            self._traj_id.value += n
-            return traj_id
+            v = self._traj_id.value
+            self._traj_id.value = v + n
+            return torch.arange(v, v + n, device=device)
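
A minimal usage sketch of what this change enables: a `MultiaSyncDataCollector` that extends a `ReplayBuffer` directly from its workers, so the iterator yields `None` instead of shipping batches back through the queue (as the new test above exercises). The environment id and the `RandomPolicy` import path are illustrative assumptions and may vary across torchrl versions; this is not part of the diff.

```python
# Hypothetical sketch, not part of the diff above.
from torchrl.collectors import MultiaSyncDataCollector
from torchrl.data import LazyTensorStorage, ReplayBuffer
from torchrl.envs.libs.gym import GymEnv
from torchrl.envs.utils import RandomPolicy  # assumption: path varies by version

make_env = lambda: GymEnv("CartPole-v1")  # assumed env id

# Buffer sized to total_frames so every collected frame fits.
rb = ReplayBuffer(storage=LazyTensorStorage(256), batch_size=16)

collector = MultiaSyncDataCollector(
    [make_env, make_env],
    RandomPolicy(make_env().action_spec),
    replay_buffer=rb,  # workers extend the buffer; nothing is queued back
    total_frames=256,
    frames_per_batch=16,
)

for batch in collector:
    # With a replay buffer attached, the iterator yields None;
    # frames land in `rb` as the workers produce them.
    assert batch is None
    if len(rb) >= 16:
        sample = rb.sample()  # consume the buffer as usual
collector.shutdown()
```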