Skip to content

Commit

Permalink
[Feature] replay_buffer_chunk
Browse files Browse the repository at this point in the history
ghstack-source-id: 66572fb9559824296240989a9a8763739fcde5f6
Pull Request resolved: #2388
  • Loading branch information
vmoens committed Aug 11, 2024
1 parent b8f6b7c commit ecc5e00
Show file tree
Hide file tree
Showing 6 changed files with 126 additions and 23 deletions.
91 changes: 72 additions & 19 deletions test/test_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@
from torchrl.collectors.utils import split_trajectories
from torchrl.data import (
Composite,
LazyMemmapStorage,
LazyTensorStorage,
NonTensor,
ReplayBuffer,
Expand Down Expand Up @@ -2799,44 +2800,86 @@ def test_collector_rb_sync(self):
del collector, env
assert assert_allclose_td(rbdata0, rbdata1)

def test_collector_rb_multisync(self):
env = GymEnv(CARTPOLE_VERSIONED())
env.set_seed(0)
@pytest.mark.parametrize("replay_buffer_chunk", [False, True])
@pytest.mark.parametrize("env_creator", [False, True])
@pytest.mark.parametrize("storagetype", [LazyTensorStorage, LazyMemmapStorage])
def test_collector_rb_multisync(
self, replay_buffer_chunk, env_creator, storagetype, tmpdir
):
if not env_creator:
env = GymEnv(CARTPOLE_VERSIONED()).append_transform(StepCounter())
env.set_seed(0)
action_spec = env.action_spec
env = lambda env=env: env
else:
env = EnvCreator(
lambda cp=CARTPOLE_VERSIONED(): GymEnv(cp).append_transform(
StepCounter()
)
)
action_spec = env.meta_data.specs["input_spec", "full_action_spec"]

rb = ReplayBuffer(storage=LazyTensorStorage(256), batch_size=5)
rb.add(env.rand_step(env.reset()))
rb.empty()
if storagetype == LazyMemmapStorage:
storagetype = functools.partial(LazyMemmapStorage, scratch_dir=tmpdir)
rb = ReplayBuffer(storage=storagetype(256), batch_size=5)

collector = MultiSyncDataCollector(
[lambda: env, lambda: env],
RandomPolicy(env.action_spec),
[env, env],
RandomPolicy(action_spec),
replay_buffer=rb,
total_frames=256,
frames_per_batch=16,
frames_per_batch=32,
replay_buffer_chunk=replay_buffer_chunk,
)
torch.manual_seed(0)
pred_len = 0
for c in collector:
pred_len += 16
pred_len += 32
assert c is None
assert len(rb) == pred_len
collector.shutdown()
assert len(rb) == 256
if not replay_buffer_chunk:
steps_counts = rb["step_count"].squeeze().split(16)
collector_ids = rb["collector", "traj_ids"].squeeze().split(16)
for step_count, ids in zip(steps_counts, collector_ids):
step_countdiff = step_count.diff()
idsdiff = ids.diff()
assert (
(step_countdiff == 1) | (step_countdiff < 0)
).all(), steps_counts
assert (idsdiff >= 0).all()

@pytest.mark.parametrize("replay_buffer_chunk", [False, True])
@pytest.mark.parametrize("env_creator", [False, True])
@pytest.mark.parametrize("storagetype", [LazyTensorStorage, LazyMemmapStorage])
def test_collector_rb_multiasync(
self, replay_buffer_chunk, env_creator, storagetype, tmpdir
):
if not env_creator:
env = GymEnv(CARTPOLE_VERSIONED()).append_transform(StepCounter())
env.set_seed(0)
action_spec = env.action_spec
env = lambda env=env: env
else:
env = EnvCreator(
lambda cp=CARTPOLE_VERSIONED(): GymEnv(cp).append_transform(
StepCounter()
)
)
action_spec = env.meta_data.specs["input_spec", "full_action_spec"]

def test_collector_rb_multiasync(self):
env = GymEnv(CARTPOLE_VERSIONED())
env.set_seed(0)

rb = ReplayBuffer(storage=LazyTensorStorage(256), batch_size=5)
rb.add(env.rand_step(env.reset()))
rb.empty()
if storagetype == LazyMemmapStorage:
storagetype = functools.partial(LazyMemmapStorage, scratch_dir=tmpdir)
rb = ReplayBuffer(storage=storagetype(256), batch_size=5)

collector = MultiaSyncDataCollector(
[lambda: env, lambda: env],
RandomPolicy(env.action_spec),
[env, env],
RandomPolicy(action_spec),
replay_buffer=rb,
total_frames=256,
frames_per_batch=16,
replay_buffer_chunk=replay_buffer_chunk,
)
torch.manual_seed(0)
pred_len = 0
Expand All @@ -2846,6 +2889,16 @@ def test_collector_rb_multiasync(self):
assert len(rb) >= pred_len
collector.shutdown()
assert len(rb) == 256
if not replay_buffer_chunk:
steps_counts = rb["step_count"].squeeze().split(16)
collector_ids = rb["collector", "traj_ids"].squeeze().split(16)
for step_count, ids in zip(steps_counts, collector_ids):
step_countdiff = step_count.diff()
idsdiff = ids.diff()
assert (
(step_countdiff == 1) | (step_countdiff < 0)
).all(), steps_counts
assert (idsdiff >= 0).all()


if __name__ == "__main__":
Expand Down
29 changes: 27 additions & 2 deletions torchrl/collectors/collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
from torchrl.data.tensor_specs import TensorSpec
from torchrl.data.utils import CloudpickleWrapper, DEVICE_TYPING
from torchrl.envs.common import _do_nothing, EnvBase
from torchrl.envs.env_creator import EnvCreator
from torchrl.envs.transforms import StepCounter, TransformedEnv
from torchrl.envs.utils import (
_aggregate_end_of_traj,
Expand Down Expand Up @@ -1469,6 +1470,7 @@ def __init__(
set_truncated: bool = False,
use_buffers: bool | None = None,
replay_buffer: ReplayBuffer | None = None,
replay_buffer_chunk: bool = True,
):
exploration_type = _convert_exploration_type(
exploration_mode=exploration_mode, exploration_type=exploration_type
Expand Down Expand Up @@ -1513,6 +1515,8 @@ def __init__(

self._use_buffers = use_buffers
self.replay_buffer = replay_buffer
self._check_replay_buffer_init()
self.replay_buffer_chunk = replay_buffer_chunk
if (
replay_buffer is not None
and hasattr(replay_buffer, "shared")
Expand Down Expand Up @@ -1659,6 +1663,21 @@ def _get_weight_fn(weights=policy_weights):
)
self.cat_results = cat_results

def _check_replay_buffer_init(self):
try:
if not self.replay_buffer._storage.initialized:
if isinstance(self.create_env_fn, EnvCreator):
fake_td = self.create_env_fn.tensordict
else:
fake_td = self.create_env_fn[0](
**self.create_env_kwargs[0]
).fake_tensordict()
fake_td["collector", "traj_ids"] = torch.zeros((), dtype=torch.long)

self.replay_buffer._storage._init(fake_td)
except AttributeError:
pass

@classmethod
def _total_workers_from_env(cls, env_creators):
if isinstance(env_creators, (tuple, list)):
Expand Down Expand Up @@ -1795,6 +1814,7 @@ def _run_processes(self) -> None:
"set_truncated": self.set_truncated,
"use_buffers": self._use_buffers,
"replay_buffer": self.replay_buffer,
"replay_buffer_chunk": self.replay_buffer_chunk,
"traj_pool": traj_pool,
}
proc = _ProcessNoWarn(
Expand Down Expand Up @@ -2804,6 +2824,7 @@ def _main_async_collector(
set_truncated: bool = False,
use_buffers: bool | None = None,
replay_buffer: ReplayBuffer | None = None,
replay_buffer_chunk: bool = True,
traj_pool: _TrajectoryPool = None,
) -> None:
pipe_parent.close()
Expand All @@ -2825,11 +2846,11 @@ def _main_async_collector(
env_device=env_device,
exploration_type=exploration_type,
reset_when_done=reset_when_done,
return_same_td=True,
return_same_td=replay_buffer is None,
interruptor=interruptor,
set_truncated=set_truncated,
use_buffers=use_buffers,
replay_buffer=replay_buffer,
replay_buffer=replay_buffer if replay_buffer_chunk else None,
traj_pool=traj_pool,
)
use_buffers = inner_collector._use_buffers
Expand Down Expand Up @@ -2895,6 +2916,10 @@ def _main_async_collector(
continue

if replay_buffer is not None:
if not replay_buffer_chunk:
next_data.names = None
replay_buffer.extend(next_data)

try:
queue_out.put((idx, j), timeout=_TIMEOUT)
if verbose:
Expand Down
5 changes: 5 additions & 0 deletions torchrl/data/replay_buffers/replay_buffers.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,11 @@ def __len__(self) -> int:
with self._replay_lock:
return len(self._storage)

@property
def write_count(self):
"""The total number of items written so far in the buffer through add and extend."""
return self._writer._write_count

def __repr__(self) -> str:
from torchrl.envs.transforms import Compose

Expand Down
2 changes: 1 addition & 1 deletion torchrl/data/replay_buffers/storages.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,7 +562,7 @@ def assert_is_sharable(tensor):
raise RuntimeError(STORAGE_ERR)

if is_tensor_collection(storage):
storage.apply(assert_is_sharable)
storage.apply(assert_is_sharable, filter_empty=True)
else:
tree_map(storage, assert_is_sharable)

Expand Down
20 changes: 20 additions & 0 deletions torchrl/data/replay_buffers/writers.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ def add(self, data: Any) -> int | torch.Tensor:
self._cursor = (self._cursor + 1) % self._storage._max_size_along_dim0(
single_data=data
)
self._write_count += 1
# Replicate index requires the shape of the storage to be known
# Other than that, a "flat" (1d) index is ok to write the data
self._storage.set(_cursor, data)
Expand Down Expand Up @@ -191,6 +192,7 @@ def extend(self, data: Sequence) -> torch.Tensor:
)
# we need to update the cursor first to avoid race conditions between workers
self._cursor = (batch_size + cur_size) % max_size_along0
self._write_count += batch_size
# Replicate index requires the shape of the storage to be known
# Other than that, a "flat" (1d) index is ok to write the data
self._storage.set(index, data)
Expand Down Expand Up @@ -222,6 +224,20 @@ def _cursor(self, value):
_cursor_value = self._cursor_value = mp.Value("i", 0)
_cursor_value.value = value

@property
def _write_count(self):
_write_count = self.__dict__.get("_write_count_value", None)
if _write_count is None:
_write_count = self._write_count_value = mp.Value("i", 0)
return _write_count.value

@_write_count.setter
def _write_count(self, value):
_write_count = self.__dict__.get("_write_count_value", None)
if _write_count is None:
_write_count = self._write_count_value = mp.Value("i", 0)
_write_count.value = value

def __getstate__(self):
state = super().__getstate__()
if get_spawning_popen() is None:
Expand Down Expand Up @@ -249,6 +265,7 @@ def add(self, data: Any) -> int | torch.Tensor:
# we need to update the cursor first to avoid race conditions between workers
max_size_along_dim0 = self._storage._max_size_along_dim0(single_data=data)
self._cursor = (index + 1) % max_size_along_dim0
self._write_count += 1
if not is_tensorclass(data):
data.set(
"index",
Expand All @@ -275,6 +292,7 @@ def extend(self, data: Sequence) -> torch.Tensor:
)
# we need to update the cursor first to avoid race conditions between workers
self._cursor = (batch_size + cur_size) % max_size_along_dim0
self._write_count += batch_size
# storage must convert the data to the appropriate format if needed
if not is_tensorclass(data):
data.set(
Expand Down Expand Up @@ -469,6 +487,7 @@ def add(self, data: Any) -> int | torch.Tensor:
index = self.get_insert_index(data)
if index is not None:
data.set("index", index)
self._write_count += 1
# Replicate index requires the shape of the storage to be known
# Other than that, a "flat" (1d) index is ok to write the data
self._storage.set(index, data)
Expand All @@ -488,6 +507,7 @@ def extend(self, data: TensorDictBase) -> None:
for data_idx, sample in enumerate(data):
storage_idx = self.get_insert_index(sample)
if storage_idx is not None:
self._write_count += 1
data_to_replace[storage_idx] = data_idx

# -1 will be interpreted as invalid by prioritized buffers
Expand Down
2 changes: 1 addition & 1 deletion torchrl/envs/env_creator.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def share_memory(self, state_dict: OrderedDict) -> None:
del state_dict[key]

@property
def meta_data(self):
def meta_data(self) -> EnvMetaData:
if self._meta_data is None:
raise RuntimeError(
"meta_data is None in EnvCreator. " "Make sure init_() has been called."
Expand Down

0 comments on commit ecc5e00

Please sign in to comment.