[Feature] replay_buffer_chunk #2388

Merged · 15 commits · Aug 13, 2024
89 changes: 71 additions & 18 deletions test/test_collector.py
@@ -69,6 +69,7 @@
from torchrl.collectors.utils import split_trajectories
from torchrl.data import (
Composite,
LazyMemmapStorage,
LazyTensorStorage,
NonTensor,
ReplayBuffer,
@@ -2806,45 +2807,87 @@ def test_collector_rb_sync(self):
assert assert_allclose_td(rbdata0, rbdata1)

@pytest.mark.skipif(not _has_gym, reason="requires gym.")
def test_collector_rb_multisync(self):
env = GymEnv(CARTPOLE_VERSIONED())
env.set_seed(0)
@pytest.mark.parametrize("replay_buffer_chunk", [False, True])
@pytest.mark.parametrize("env_creator", [False, True])
@pytest.mark.parametrize("storagetype", [LazyTensorStorage, LazyMemmapStorage])
def test_collector_rb_multisync(
self, replay_buffer_chunk, env_creator, storagetype, tmpdir
):
if not env_creator:
env = GymEnv(CARTPOLE_VERSIONED()).append_transform(StepCounter())
env.set_seed(0)
action_spec = env.action_spec
env = lambda env=env: env
else:
env = EnvCreator(
lambda cp=CARTPOLE_VERSIONED(): GymEnv(cp).append_transform(
StepCounter()
)
)
action_spec = env.meta_data.specs["input_spec", "full_action_spec"]

rb = ReplayBuffer(storage=LazyTensorStorage(256), batch_size=5)
rb.add(env.rand_step(env.reset()))
rb.empty()
if storagetype == LazyMemmapStorage:
storagetype = functools.partial(LazyMemmapStorage, scratch_dir=tmpdir)
rb = ReplayBuffer(storage=storagetype(256), batch_size=5)

collector = MultiSyncDataCollector(
[lambda: env, lambda: env],
RandomPolicy(env.action_spec),
[env, env],
RandomPolicy(action_spec),
replay_buffer=rb,
total_frames=256,
frames_per_batch=16,
frames_per_batch=32,
replay_buffer_chunk=replay_buffer_chunk,
)
torch.manual_seed(0)
pred_len = 0
for c in collector:
pred_len += 16
pred_len += 32
assert c is None
assert len(rb) == pred_len
collector.shutdown()
assert len(rb) == 256
if not replay_buffer_chunk:
steps_counts = rb["step_count"].squeeze().split(16)
collector_ids = rb["collector", "traj_ids"].squeeze().split(16)
for step_count, ids in zip(steps_counts, collector_ids):
step_countdiff = step_count.diff()
idsdiff = ids.diff()
assert (
(step_countdiff == 1) | (step_countdiff < 0)
).all(), steps_counts
assert (idsdiff >= 0).all()

@pytest.mark.skipif(not _has_gym, reason="requires gym.")
def test_collector_rb_multiasync(self):
env = GymEnv(CARTPOLE_VERSIONED())
env.set_seed(0)
@pytest.mark.parametrize("replay_buffer_chunk", [False, True])
@pytest.mark.parametrize("env_creator", [False, True])
@pytest.mark.parametrize("storagetype", [LazyTensorStorage, LazyMemmapStorage])
def test_collector_rb_multiasync(
self, replay_buffer_chunk, env_creator, storagetype, tmpdir
):
if not env_creator:
env = GymEnv(CARTPOLE_VERSIONED()).append_transform(StepCounter())
env.set_seed(0)
action_spec = env.action_spec
env = lambda env=env: env
else:
env = EnvCreator(
lambda cp=CARTPOLE_VERSIONED(): GymEnv(cp).append_transform(
StepCounter()
)
)
action_spec = env.meta_data.specs["input_spec", "full_action_spec"]

rb = ReplayBuffer(storage=LazyTensorStorage(256), batch_size=5)
rb.add(env.rand_step(env.reset()))
rb.empty()
if storagetype == LazyMemmapStorage:
storagetype = functools.partial(LazyMemmapStorage, scratch_dir=tmpdir)
rb = ReplayBuffer(storage=storagetype(256), batch_size=5)

collector = MultiaSyncDataCollector(
[lambda: env, lambda: env],
RandomPolicy(env.action_spec),
[env, env],
RandomPolicy(action_spec),
replay_buffer=rb,
total_frames=256,
frames_per_batch=16,
replay_buffer_chunk=replay_buffer_chunk,
)
torch.manual_seed(0)
pred_len = 0
@@ -2854,6 +2897,16 @@ def test_collector_rb_multiasync(self):
assert len(rb) >= pred_len
collector.shutdown()
assert len(rb) == 256
if not replay_buffer_chunk:
steps_counts = rb["step_count"].squeeze().split(16)
collector_ids = rb["collector", "traj_ids"].squeeze().split(16)
for step_count, ids in zip(steps_counts, collector_ids):
step_countdiff = step_count.diff()
idsdiff = ids.diff()
assert (
(step_countdiff == 1) | (step_countdiff < 0)
).all(), steps_counts
assert (idsdiff >= 0).all()


if __name__ == "__main__":
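Note: the parametrized tests above boil down to the following usage pattern, sketched standalone (a rough sketch assuming a working gym install; the env name, sizes, and import paths are illustrative and may shift across torchrl versions):

    import torch
    from torchrl.collectors import MultiSyncDataCollector
    from torchrl.data import LazyTensorStorage, ReplayBuffer
    from torchrl.envs import GymEnv, StepCounter
    from torchrl.envs.utils import RandomPolicy

    make_env = lambda: GymEnv("CartPole-v1").append_transform(StepCounter())

    # The buffer is handed to the collector and filled by the workers in-place.
    rb = ReplayBuffer(storage=LazyTensorStorage(256), batch_size=5)
    collector = MultiSyncDataCollector(
        [make_env, make_env],
        RandomPolicy(make_env().action_spec),
        replay_buffer=rb,
        total_frames=256,
        frames_per_batch=32,
        replay_buffer_chunk=True,  # workers extend the buffer with whole rollout chunks
    )
    torch.manual_seed(0)
    for batch in collector:
        assert batch is None  # nothing is yielded; the data lands in rb directly
    collector.shutdown()
    assert len(rb) == 256

With replay_buffer_chunk=False, each worker writes its sub-batches one at a time instead of whole rollouts, which is what the step_count/traj_ids checks above verify: every 16-frame slice written by a single worker must be internally time-consistent.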
29 changes: 27 additions & 2 deletions torchrl/collectors/collectors.py
@@ -54,6 +54,7 @@
from torchrl.data.tensor_specs import TensorSpec
from torchrl.data.utils import CloudpickleWrapper, DEVICE_TYPING
from torchrl.envs.common import _do_nothing, EnvBase
from torchrl.envs.env_creator import EnvCreator
from torchrl.envs.transforms import StepCounter, TransformedEnv
from torchrl.envs.utils import (
_aggregate_end_of_traj,
@@ -1470,6 +1471,7 @@ def __init__(
set_truncated: bool = False,
use_buffers: bool | None = None,
replay_buffer: ReplayBuffer | None = None,
replay_buffer_chunk: bool = True,
):
exploration_type = _convert_exploration_type(
exploration_mode=exploration_mode, exploration_type=exploration_type
@@ -1514,6 +1516,8 @@

self._use_buffers = use_buffers
self.replay_buffer = replay_buffer
self._check_replay_buffer_init()
self.replay_buffer_chunk = replay_buffer_chunk
if (
replay_buffer is not None
and hasattr(replay_buffer, "shared")
@@ -1660,6 +1664,21 @@ def _get_weight_fn(weights=policy_weights):
)
self.cat_results = cat_results

def _check_replay_buffer_init(self):
try:
if not self.replay_buffer._storage.initialized:
if isinstance(self.create_env_fn, EnvCreator):
fake_td = self.create_env_fn.tensordict
else:
fake_td = self.create_env_fn[0](
**self.create_env_kwargs[0]
).fake_tensordict()
fake_td["collector", "traj_ids"] = torch.zeros((), dtype=torch.long)

self.replay_buffer._storage._init(fake_td)
except AttributeError:
pass

@classmethod
def _total_workers_from_env(cls, env_creators):
if isinstance(env_creators, (tuple, list)):
@@ -1795,6 +1814,7 @@ def _run_processes(self) -> None:
"set_truncated": self.set_truncated,
"use_buffers": self._use_buffers,
"replay_buffer": self.replay_buffer,
"replay_buffer_chunk": self.replay_buffer_chunk,
"traj_pool": self._traj_pool,
}
proc = _ProcessNoWarn(
@@ -2804,6 +2824,7 @@ def _main_async_collector(
set_truncated: bool = False,
use_buffers: bool | None = None,
replay_buffer: ReplayBuffer | None = None,
replay_buffer_chunk: bool = True,
traj_pool: _TrajectoryPool = None,
) -> None:
pipe_parent.close()
@@ -2825,11 +2846,11 @@
env_device=env_device,
exploration_type=exploration_type,
reset_when_done=reset_when_done,
return_same_td=True,
return_same_td=replay_buffer is None,
interruptor=interruptor,
set_truncated=set_truncated,
use_buffers=use_buffers,
replay_buffer=replay_buffer,
replay_buffer=replay_buffer if replay_buffer_chunk else None,
traj_pool=traj_pool,
)
use_buffers = inner_collector._use_buffers
@@ -2895,6 +2916,10 @@
continue

if replay_buffer is not None:
if not replay_buffer_chunk:
next_data.names = None
replay_buffer.extend(next_data)

try:
queue_out.put((idx, j), timeout=_TIMEOUT)
if verbose:
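Two behaviors are wired together in this file. With replay_buffer_chunk=True, the buffer is passed down to the inner collector, which extends it with whole contiguous rollouts; with replay_buffer_chunk=False, the worker loop keeps the batches and extends the buffer itself, one batch at a time, after dropping the dim names. Independently, _check_replay_buffer_init materializes a lazy storage before the workers are forked, since a buffer shared across processes must already be allocated. A hypothetical standalone rendition of that pre-initialization (private APIs, for exposition only):

    import torch
    from torchrl.data import LazyTensorStorage, ReplayBuffer
    from torchrl.envs import GymEnv

    env = GymEnv("CartPole-v1")
    rb = ReplayBuffer(storage=LazyTensorStorage(256))

    # A lazy storage normally allocates on the first write; to share the buffer
    # with worker processes it must be materialized up front from a fake step.
    fake_td = env.fake_tensordict()
    fake_td["collector", "traj_ids"] = torch.zeros((), dtype=torch.long)
    rb._storage._init(fake_td)  # mirrors _check_replay_buffer_init above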
5 changes: 5 additions & 0 deletions torchrl/data/replay_buffers/replay_buffers.py
@@ -364,6 +364,11 @@ def __len__(self) -> int:
with self._replay_lock:
return len(self._storage)

@property
def write_count(self):
"""The total number of items written so far in the buffer through add and extend."""
return self._writer._write_count

def __repr__(self) -> str:
from torchrl.envs.transforms import Compose

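A quick sketch of the new property (sizes are arbitrary; wraparound follows the round-robin writer):

    import torch
    from torchrl.data import LazyTensorStorage, ReplayBuffer

    rb = ReplayBuffer(storage=LazyTensorStorage(100))
    rb.extend(torch.arange(60))
    rb.extend(torch.arange(60))
    assert len(rb) == 100         # capped at capacity; oldest entries overwritten
    assert rb.write_count == 120  # every write counted, overwrites included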
2 changes: 1 addition & 1 deletion torchrl/data/replay_buffers/storages.py
@@ -562,7 +562,7 @@ def assert_is_sharable(tensor):
raise RuntimeError(STORAGE_ERR)

if is_tensor_collection(storage):
storage.apply(assert_is_sharable)
storage.apply(assert_is_sharable, filter_empty=True)
else:
tree_map(storage, assert_is_sharable)

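filter_empty=True makes the traversal skip empty sub-tensordicts, which matters for a validation-only apply whose callback returns nothing. A small illustration, assuming a recent tensordict release (the flag is an assumption about the installed version):

    import torch
    from tensordict import TensorDict

    td = TensorDict(
        {"a": torch.zeros(3), "sub": TensorDict({}, batch_size=[3])},
        batch_size=[3],
    )
    td.share_memory_()

    def assert_shared(tensor):
        assert tensor.is_shared()  # validation only; returns None

    td.apply(assert_shared, filter_empty=True)  # the empty "sub" branch is skipped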
37 changes: 36 additions & 1 deletion torchrl/data/replay_buffers/writers.py
@@ -163,6 +163,7 @@ def add(self, data: Any) -> int | torch.Tensor:
self._cursor = (self._cursor + 1) % self._storage._max_size_along_dim0(
single_data=data
)
self._write_count += 1
# Replicate index requires the shape of the storage to be known
# Other than that, a "flat" (1d) index is ok to write the data
self._storage.set(_cursor, data)
@@ -191,6 +192,7 @@ def extend(self, data: Sequence) -> torch.Tensor:
)
# we need to update the cursor first to avoid race conditions between workers
self._cursor = (batch_size + cur_size) % max_size_along0
self._write_count += batch_size
# Replicate index requires the shape of the storage to be known
# Other than that, a "flat" (1d) index is ok to write the data
self._storage.set(index, data)
@@ -222,6 +224,20 @@ def _cursor(self, value):
_cursor_value = self._cursor_value = mp.Value("i", 0)
_cursor_value.value = value

@property
def _write_count(self):
_write_count = self.__dict__.get("_write_count_value", None)
if _write_count is None:
_write_count = self._write_count_value = mp.Value("i", 0)
return _write_count.value

@_write_count.setter
def _write_count(self, value):
_write_count = self.__dict__.get("_write_count_value", None)
if _write_count is None:
_write_count = self._write_count_value = mp.Value("i", 0)
_write_count.value = value

def __getstate__(self):
state = super().__getstate__()
if get_spawning_popen() is None:
@@ -249,6 +265,7 @@ def add(self, data: Any) -> int | torch.Tensor:
# we need to update the cursor first to avoid race conditions between workers
max_size_along_dim0 = self._storage._max_size_along_dim0(single_data=data)
self._cursor = (index + 1) % max_size_along_dim0
self._write_count += 1
if not is_tensorclass(data):
data.set(
"index",
@@ -275,6 +292,7 @@ def extend(self, data: Sequence) -> torch.Tensor:
)
# we need to update the cursor first to avoid race conditions between workers
self._cursor = (batch_size + cur_size) % max_size_along_dim0
self._write_count += batch_size
# storage must convert the data to the appropriate format if needed
if not is_tensorclass(data):
data.set(
@@ -460,6 +478,20 @@ def get_insert_index(self, data: Any) -> int:

return ret

@property
def _write_count(self):
_write_count = self.__dict__.get("_write_count_value", None)
if _write_count is None:
_write_count = self._write_count_value = mp.Value("i", 0)
return _write_count.value

@_write_count.setter
def _write_count(self, value):
_write_count = self.__dict__.get("_write_count_value", None)
if _write_count is None:
_write_count = self._write_count_value = mp.Value("i", 0)
_write_count.value = value

def add(self, data: Any) -> int | torch.Tensor:
"""Inserts a single element of data at an appropriate index, and returns that index.

@@ -469,6 +501,7 @@ def add(self, data: Any) -> int | torch.Tensor:
index = self.get_insert_index(data)
if index is not None:
data.set("index", index)
self._write_count += 1
# Replicate index requires the shape of the storage to be known
# Other than that, a "flat" (1d) index is ok to write the data
self._storage.set(index, data)
@@ -488,6 +521,7 @@ def extend(self, data: TensorDictBase) -> None:
for data_idx, sample in enumerate(data):
storage_idx = self.get_insert_index(sample)
if storage_idx is not None:
self._write_count += 1
data_to_replace[storage_idx] = data_idx

# -1 will be interpreted as invalid by prioritized buffers
@@ -517,7 +551,8 @@ def _empty(self) -> None:
def __getstate__(self):
if get_spawning_popen() is not None:
raise RuntimeError(
f"Writers of type {type(self)} cannot be shared between processes."
f"Writers of type {type(self)} cannot be shared between processes. "
f"Please submit an issue at https://github.com/pytorch/rl if this feature is needed."
)
state = super().__getstate__()
return state
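The _write_count lives in a lazily created mp.Value so the counter sits in shared memory and the parent and worker processes agree on it. The same pattern, distilled into a standalone sketch (plain stdlib, not torchrl code):

    import multiprocessing as mp

    class SharedCounter:
        """Lazily allocated, process-shared integer counter."""

        @property
        def count(self):
            value = self.__dict__.get("_count_value")
            if value is None:
                # "i" = C int, allocated in shared memory on first access
                value = self._count_value = mp.Value("i", 0)
            return value.value

        @count.setter
        def count(self, new_value):
            value = self.__dict__.get("_count_value")
            if value is None:
                value = self._count_value = mp.Value("i", 0)
            value.value = new_value

    counter = SharedCounter()
    counter.count += 10  # read-modify-write; like the original, not atomic across processes
    assert counter.count == 10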
2 changes: 1 addition & 1 deletion torchrl/envs/env_creator.py
@@ -109,7 +109,7 @@ def share_memory(self, state_dict: OrderedDict) -> None:
del state_dict[key]

@property
def meta_data(self):
def meta_data(self) -> EnvMetaData:
if self._meta_data is None:
raise RuntimeError(
"meta_data is None in EnvCreator. " "Make sure init_() has been called."
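The added return annotation documents what the tests above rely on: an EnvCreator caches the env's metadata, so specs can be read without instantiating another env in the main process. A hedged sketch:

    from torchrl.envs import GymEnv
    from torchrl.envs.env_creator import EnvCreator

    env_fn = EnvCreator(lambda: GymEnv("CartPole-v1"))
    # Specs come from cached metadata; no new env instance is created here.
    action_spec = env_fn.meta_data.specs["input_spec", "full_action_spec"]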