Skip to content

Commit

Permalink
[Feature] Pass replay buffers to MultiaSyncDataCollector
Browse files Browse the repository at this point in the history
ghstack-source-id: 21f1d09457b43d6ccc20c983a26bc239ca696bd6
Pull Request resolved: #2387
  • Loading branch information
vmoens committed Aug 10, 2024
1 parent 6a6115e commit 669d4da
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 8 deletions.
24 changes: 24 additions & 0 deletions test/test_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -2822,6 +2822,30 @@ def test_collector_rb_multisync(self):
collector.shutdown()
assert len(rb) == 256

def test_collector_rb_multiasync(self):
env = GymEnv(CARTPOLE_VERSIONED())
env.set_seed(0)

rb = ReplayBuffer(storage=LazyTensorStorage(256), batch_size=5)
rb.add(env.rand_step(env.reset()))
rb.empty()

collector = MultiaSyncDataCollector(
[lambda: env, lambda: env],
RandomPolicy(env.action_spec),
replay_buffer=rb,
total_frames=256,
frames_per_batch=16,
)
torch.manual_seed(0)
pred_len = 0
for c in collector:
pred_len += 16
assert c is None
assert len(rb) >= pred_len
collector.shutdown()
assert len(rb) == 256


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
Expand Down
21 changes: 13 additions & 8 deletions torchrl/collectors/collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import sys
import time
import warnings
from collections import OrderedDict
from collections import defaultdict, OrderedDict
from copy import deepcopy
from multiprocessing import connection, queues
from multiprocessing.managers import SyncManager
Expand Down Expand Up @@ -2428,7 +2428,7 @@ class MultiaSyncDataCollector(_MultiDataCollector):

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.out_tensordicts = {}
self.out_tensordicts = defaultdict(lambda: None)
self.running = False

if self.postprocs is not None:
Expand Down Expand Up @@ -2473,7 +2473,9 @@ def frames_per_batch_worker(self):
def _get_from_queue(self, timeout=None) -> Tuple[int, int, TensorDictBase]:
new_data, j = self.queue_out.get(timeout=timeout)
use_buffers = self._use_buffers
if j == 0 or not use_buffers:
if self.replay_buffer is not None:
idx = new_data
elif j == 0 or not use_buffers:
try:
data, idx = new_data
self.out_tensordicts[idx] = data
Expand All @@ -2488,7 +2490,7 @@ def _get_from_queue(self, timeout=None) -> Tuple[int, int, TensorDictBase]:
else:
idx = new_data
out = self.out_tensordicts[idx]
if j == 0 or use_buffers:
if not self.replay_buffer and (j == 0 or use_buffers):
# we clone the data to make sure that we'll be working with a fixed copy
out = out.clone()
return idx, j, out
Expand All @@ -2513,9 +2515,12 @@ def iterator(self) -> Iterator[TensorDictBase]:
_check_for_faulty_process(self.procs)
self._iter += 1
idx, j, out = self._get_from_queue()
worker_frames = out.numel()
if self.split_trajs:
out = split_trajectories(out, prefix="collector")
if self.replay_buffer is None:
worker_frames = out.numel()
if self.split_trajs:
out = split_trajectories(out, prefix="collector")
else:
worker_frames = self.frames_per_batch_worker
self._frames += worker_frames
workers_frames[idx] = workers_frames[idx] + worker_frames
if self.postprocs:
Expand All @@ -2531,7 +2536,7 @@ def iterator(self) -> Iterator[TensorDictBase]:
else:
msg = "continue"
self.pipes[idx].send((idx, msg))
if self._exclude_private_keys:
if out is not None and self._exclude_private_keys:
excluded_keys = [key for key in out.keys() if key.startswith("_")]
out = out.exclude(*excluded_keys)
yield out
Expand Down

0 comments on commit 669d4da

Please sign in to comment.