From ecc5e0076c9ed21079aa11785505325f42b19746 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sat, 10 Aug 2024 20:57:40 -0400 Subject: [PATCH] [Feature] replay_buffer_chunk ghstack-source-id: 66572fb9559824296240989a9a8763739fcde5f6 Pull Request resolved: https://github.com/pytorch/rl/pull/2388 --- test/test_collector.py | 91 +++++++++++++++---- torchrl/collectors/collectors.py | 29 +++++- torchrl/data/replay_buffers/replay_buffers.py | 5 + torchrl/data/replay_buffers/storages.py | 2 +- torchrl/data/replay_buffers/writers.py | 20 ++++ torchrl/envs/env_creator.py | 2 +- 6 files changed, 126 insertions(+), 23 deletions(-) diff --git a/test/test_collector.py b/test/test_collector.py index e8e32e4d8c9..1b59e9a797e 100644 --- a/test/test_collector.py +++ b/test/test_collector.py @@ -69,6 +69,7 @@ from torchrl.collectors.utils import split_trajectories from torchrl.data import ( Composite, + LazyMemmapStorage, LazyTensorStorage, NonTensor, ReplayBuffer, @@ -2799,44 +2800,86 @@ def test_collector_rb_sync(self): del collector, env assert assert_allclose_td(rbdata0, rbdata1) - def test_collector_rb_multisync(self): - env = GymEnv(CARTPOLE_VERSIONED()) - env.set_seed(0) + @pytest.mark.parametrize("replay_buffer_chunk", [False, True]) + @pytest.mark.parametrize("env_creator", [False, True]) + @pytest.mark.parametrize("storagetype", [LazyTensorStorage, LazyMemmapStorage]) + def test_collector_rb_multisync( + self, replay_buffer_chunk, env_creator, storagetype, tmpdir + ): + if not env_creator: + env = GymEnv(CARTPOLE_VERSIONED()).append_transform(StepCounter()) + env.set_seed(0) + action_spec = env.action_spec + env = lambda env=env: env + else: + env = EnvCreator( + lambda cp=CARTPOLE_VERSIONED(): GymEnv(cp).append_transform( + StepCounter() + ) + ) + action_spec = env.meta_data.specs["input_spec", "full_action_spec"] - rb = ReplayBuffer(storage=LazyTensorStorage(256), batch_size=5) - rb.add(env.rand_step(env.reset())) - rb.empty() + if storagetype == LazyMemmapStorage: + storagetype = functools.partial(LazyMemmapStorage, scratch_dir=tmpdir) + rb = ReplayBuffer(storage=storagetype(256), batch_size=5) collector = MultiSyncDataCollector( - [lambda: env, lambda: env], - RandomPolicy(env.action_spec), + [env, env], + RandomPolicy(action_spec), replay_buffer=rb, total_frames=256, - frames_per_batch=16, + frames_per_batch=32, + replay_buffer_chunk=replay_buffer_chunk, ) torch.manual_seed(0) pred_len = 0 for c in collector: - pred_len += 16 + pred_len += 32 assert c is None assert len(rb) == pred_len collector.shutdown() assert len(rb) == 256 + if not replay_buffer_chunk: + steps_counts = rb["step_count"].squeeze().split(16) + collector_ids = rb["collector", "traj_ids"].squeeze().split(16) + for step_count, ids in zip(steps_counts, collector_ids): + step_countdiff = step_count.diff() + idsdiff = ids.diff() + assert ( + (step_countdiff == 1) | (step_countdiff < 0) + ).all(), steps_counts + assert (idsdiff >= 0).all() + + @pytest.mark.parametrize("replay_buffer_chunk", [False, True]) + @pytest.mark.parametrize("env_creator", [False, True]) + @pytest.mark.parametrize("storagetype", [LazyTensorStorage, LazyMemmapStorage]) + def test_collector_rb_multiasync( + self, replay_buffer_chunk, env_creator, storagetype, tmpdir + ): + if not env_creator: + env = GymEnv(CARTPOLE_VERSIONED()).append_transform(StepCounter()) + env.set_seed(0) + action_spec = env.action_spec + env = lambda env=env: env + else: + env = EnvCreator( + lambda cp=CARTPOLE_VERSIONED(): GymEnv(cp).append_transform( + StepCounter() + ) + ) + action_spec = env.meta_data.specs["input_spec", "full_action_spec"] - def test_collector_rb_multiasync(self): - env = GymEnv(CARTPOLE_VERSIONED()) - env.set_seed(0) - - rb = ReplayBuffer(storage=LazyTensorStorage(256), batch_size=5) - rb.add(env.rand_step(env.reset())) - rb.empty() + if storagetype == LazyMemmapStorage: + storagetype = functools.partial(LazyMemmapStorage, scratch_dir=tmpdir) + rb = ReplayBuffer(storage=storagetype(256), batch_size=5) collector = MultiaSyncDataCollector( - [lambda: env, lambda: env], - RandomPolicy(env.action_spec), + [env, env], + RandomPolicy(action_spec), replay_buffer=rb, total_frames=256, frames_per_batch=16, + replay_buffer_chunk=replay_buffer_chunk, ) torch.manual_seed(0) pred_len = 0 @@ -2846,6 +2889,16 @@ def test_collector_rb_multiasync(self): assert len(rb) >= pred_len collector.shutdown() assert len(rb) == 256 + if not replay_buffer_chunk: + steps_counts = rb["step_count"].squeeze().split(16) + collector_ids = rb["collector", "traj_ids"].squeeze().split(16) + for step_count, ids in zip(steps_counts, collector_ids): + step_countdiff = step_count.diff() + idsdiff = ids.diff() + assert ( + (step_countdiff == 1) | (step_countdiff < 0) + ).all(), steps_counts + assert (idsdiff >= 0).all() if __name__ == "__main__": diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index 7005010fc1a..5abd684834a 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -54,6 +54,7 @@ from torchrl.data.tensor_specs import TensorSpec from torchrl.data.utils import CloudpickleWrapper, DEVICE_TYPING from torchrl.envs.common import _do_nothing, EnvBase +from torchrl.envs.env_creator import EnvCreator from torchrl.envs.transforms import StepCounter, TransformedEnv from torchrl.envs.utils import ( _aggregate_end_of_traj, @@ -1469,6 +1470,7 @@ def __init__( set_truncated: bool = False, use_buffers: bool | None = None, replay_buffer: ReplayBuffer | None = None, + replay_buffer_chunk: bool = True, ): exploration_type = _convert_exploration_type( exploration_mode=exploration_mode, exploration_type=exploration_type @@ -1513,6 +1515,8 @@ def __init__( self._use_buffers = use_buffers self.replay_buffer = replay_buffer + self._check_replay_buffer_init() + self.replay_buffer_chunk = replay_buffer_chunk if ( replay_buffer is not None and hasattr(replay_buffer, "shared") @@ -1659,6 +1663,21 @@ def _get_weight_fn(weights=policy_weights): ) self.cat_results = cat_results + def _check_replay_buffer_init(self): + try: + if not self.replay_buffer._storage.initialized: + if isinstance(self.create_env_fn, EnvCreator): + fake_td = self.create_env_fn.tensordict + else: + fake_td = self.create_env_fn[0]( + **self.create_env_kwargs[0] + ).fake_tensordict() + fake_td["collector", "traj_ids"] = torch.zeros((), dtype=torch.long) + + self.replay_buffer._storage._init(fake_td) + except AttributeError: + pass + @classmethod def _total_workers_from_env(cls, env_creators): if isinstance(env_creators, (tuple, list)): @@ -1795,6 +1814,7 @@ def _run_processes(self) -> None: "set_truncated": self.set_truncated, "use_buffers": self._use_buffers, "replay_buffer": self.replay_buffer, + "replay_buffer_chunk": self.replay_buffer_chunk, "traj_pool": traj_pool, } proc = _ProcessNoWarn( @@ -2804,6 +2824,7 @@ def _main_async_collector( set_truncated: bool = False, use_buffers: bool | None = None, replay_buffer: ReplayBuffer | None = None, + replay_buffer_chunk: bool = True, traj_pool: _TrajectoryPool = None, ) -> None: pipe_parent.close() @@ -2825,11 +2846,11 @@ def _main_async_collector( env_device=env_device, exploration_type=exploration_type, reset_when_done=reset_when_done, - return_same_td=True, + return_same_td=replay_buffer is None, interruptor=interruptor, set_truncated=set_truncated, use_buffers=use_buffers, - replay_buffer=replay_buffer, + replay_buffer=replay_buffer if replay_buffer_chunk else None, traj_pool=traj_pool, ) use_buffers = inner_collector._use_buffers @@ -2895,6 +2916,10 @@ def _main_async_collector( continue if replay_buffer is not None: + if not replay_buffer_chunk: + next_data.names = None + replay_buffer.extend(next_data) + try: queue_out.put((idx, j), timeout=_TIMEOUT) if verbose: diff --git a/torchrl/data/replay_buffers/replay_buffers.py b/torchrl/data/replay_buffers/replay_buffers.py index 5ad1bb170cb..afa6f861079 100644 --- a/torchrl/data/replay_buffers/replay_buffers.py +++ b/torchrl/data/replay_buffers/replay_buffers.py @@ -364,6 +364,11 @@ def __len__(self) -> int: with self._replay_lock: return len(self._storage) + @property + def write_count(self): + """The total number of items written so far in the buffer through add and extend.""" + return self._writer._write_count + def __repr__(self) -> str: from torchrl.envs.transforms import Compose diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py index d1bd6fbf599..acdf2dcf8dd 100644 --- a/torchrl/data/replay_buffers/storages.py +++ b/torchrl/data/replay_buffers/storages.py @@ -562,7 +562,7 @@ def assert_is_sharable(tensor): raise RuntimeError(STORAGE_ERR) if is_tensor_collection(storage): - storage.apply(assert_is_sharable) + storage.apply(assert_is_sharable, filter_empty=True) else: tree_map(storage, assert_is_sharable) diff --git a/torchrl/data/replay_buffers/writers.py b/torchrl/data/replay_buffers/writers.py index 066658993b1..6f906dea4ef 100644 --- a/torchrl/data/replay_buffers/writers.py +++ b/torchrl/data/replay_buffers/writers.py @@ -163,6 +163,7 @@ def add(self, data: Any) -> int | torch.Tensor: self._cursor = (self._cursor + 1) % self._storage._max_size_along_dim0( single_data=data ) + self._write_count += 1 # Replicate index requires the shape of the storage to be known # Other than that, a "flat" (1d) index is ok to write the data self._storage.set(_cursor, data) @@ -191,6 +192,7 @@ def extend(self, data: Sequence) -> torch.Tensor: ) # we need to update the cursor first to avoid race conditions between workers self._cursor = (batch_size + cur_size) % max_size_along0 + self._write_count += batch_size # Replicate index requires the shape of the storage to be known # Other than that, a "flat" (1d) index is ok to write the data self._storage.set(index, data) @@ -222,6 +224,20 @@ def _cursor(self, value): _cursor_value = self._cursor_value = mp.Value("i", 0) _cursor_value.value = value + @property + def _write_count(self): + _write_count = self.__dict__.get("_write_count_value", None) + if _write_count is None: + _write_count = self._write_count_value = mp.Value("i", 0) + return _write_count.value + + @_write_count.setter + def _write_count(self, value): + _write_count = self.__dict__.get("_write_count_value", None) + if _write_count is None: + _write_count = self._write_count_value = mp.Value("i", 0) + _write_count.value = value + def __getstate__(self): state = super().__getstate__() if get_spawning_popen() is None: @@ -249,6 +265,7 @@ def add(self, data: Any) -> int | torch.Tensor: # we need to update the cursor first to avoid race conditions between workers max_size_along_dim0 = self._storage._max_size_along_dim0(single_data=data) self._cursor = (index + 1) % max_size_along_dim0 + self._write_count += 1 if not is_tensorclass(data): data.set( "index", @@ -275,6 +292,7 @@ def extend(self, data: Sequence) -> torch.Tensor: ) # we need to update the cursor first to avoid race conditions between workers self._cursor = (batch_size + cur_size) % max_size_along_dim0 + self._write_count += batch_size # storage must convert the data to the appropriate format if needed if not is_tensorclass(data): data.set( @@ -469,6 +487,7 @@ def add(self, data: Any) -> int | torch.Tensor: index = self.get_insert_index(data) if index is not None: data.set("index", index) + self._write_count += 1 # Replicate index requires the shape of the storage to be known # Other than that, a "flat" (1d) index is ok to write the data self._storage.set(index, data) @@ -488,6 +507,7 @@ def extend(self, data: TensorDictBase) -> None: for data_idx, sample in enumerate(data): storage_idx = self.get_insert_index(sample) if storage_idx is not None: + self._write_count += 1 data_to_replace[storage_idx] = data_idx # -1 will be interpreted as invalid by prioritized buffers diff --git a/torchrl/envs/env_creator.py b/torchrl/envs/env_creator.py index 89ee8cc5614..f090289214d 100644 --- a/torchrl/envs/env_creator.py +++ b/torchrl/envs/env_creator.py @@ -109,7 +109,7 @@ def share_memory(self, state_dict: OrderedDict) -> None: del state_dict[key] @property - def meta_data(self): + def meta_data(self) -> EnvMetaData: if self._meta_data is None: raise RuntimeError( "meta_data is None in EnvCreator. " "Make sure init_() has been called."