Skip to content

Commit

Permalink
[Feature] Pass replay buffers to MultiaSyncDataCollector
Browse files Browse the repository at this point in the history
ghstack-source-id: 7275208e2f02560229ca83c999cd9b0ae68aaf4f
Pull Request resolved: #2387
  • Loading branch information
vmoens committed Aug 13, 2024
1 parent 9627e8a commit a0c12cd
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 8 deletions.
25 changes: 25 additions & 0 deletions test/test_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -2830,6 +2830,31 @@ def test_collector_rb_multisync(self):
collector.shutdown()
assert len(rb) == 256

@pytest.mark.skipif(not _has_gym, reason="requires gym.")
def test_collector_rb_multiasync(self):
env = GymEnv(CARTPOLE_VERSIONED())
env.set_seed(0)

rb = ReplayBuffer(storage=LazyTensorStorage(256), batch_size=5)
rb.add(env.rand_step(env.reset()))
rb.empty()

collector = MultiaSyncDataCollector(
[lambda: env, lambda: env],
RandomPolicy(env.action_spec),
replay_buffer=rb,
total_frames=256,
frames_per_batch=16,
)
torch.manual_seed(0)
pred_len = 0
for c in collector:
pred_len += 16
assert c is None
assert len(rb) >= pred_len
collector.shutdown()
assert len(rb) == 256


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
Expand Down
21 changes: 13 additions & 8 deletions torchrl/collectors/collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import sys
import time
import warnings
from collections import OrderedDict
from collections import defaultdict, OrderedDict
from copy import deepcopy
from multiprocessing import connection, queues
from multiprocessing.managers import SyncManager
Expand Down Expand Up @@ -2433,7 +2433,7 @@ class MultiaSyncDataCollector(_MultiDataCollector):

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.out_tensordicts = {}
self.out_tensordicts = defaultdict(lambda: None)
self.running = False

if self.postprocs is not None:
Expand Down Expand Up @@ -2478,7 +2478,9 @@ def frames_per_batch_worker(self):
def _get_from_queue(self, timeout=None) -> Tuple[int, int, TensorDictBase]:
new_data, j = self.queue_out.get(timeout=timeout)
use_buffers = self._use_buffers
if j == 0 or not use_buffers:
if self.replay_buffer is not None:
idx = new_data
elif j == 0 or not use_buffers:
try:
data, idx = new_data
self.out_tensordicts[idx] = data
Expand All @@ -2493,7 +2495,7 @@ def _get_from_queue(self, timeout=None) -> Tuple[int, int, TensorDictBase]:
else:
idx = new_data
out = self.out_tensordicts[idx]
if j == 0 or use_buffers:
if not self.replay_buffer and (j == 0 or use_buffers):
# we clone the data to make sure that we'll be working with a fixed copy
out = out.clone()
return idx, j, out
Expand All @@ -2518,9 +2520,12 @@ def iterator(self) -> Iterator[TensorDictBase]:
_check_for_faulty_process(self.procs)
self._iter += 1
idx, j, out = self._get_from_queue()
worker_frames = out.numel()
if self.split_trajs:
out = split_trajectories(out, prefix="collector")
if self.replay_buffer is None:
worker_frames = out.numel()
if self.split_trajs:
out = split_trajectories(out, prefix="collector")
else:
worker_frames = self.frames_per_batch_worker
self._frames += worker_frames
workers_frames[idx] = workers_frames[idx] + worker_frames
if self.postprocs:
Expand All @@ -2536,7 +2541,7 @@ def iterator(self) -> Iterator[TensorDictBase]:
else:
msg = "continue"
self.pipes[idx].send((idx, msg))
if self._exclude_private_keys:
if out is not None and self._exclude_private_keys:
excluded_keys = [key for key in out.keys() if key.startswith("_")]
out = out.exclude(*excluded_keys)
yield out
Expand Down

0 comments on commit a0c12cd

Please sign in to comment.