
Support for device=auto buffers and set it as default value (DLR-RM#1009)

* Default device is "auto" for buffers + auto device support in BaseBuffer

* Update docstring

* Update tests

* Unify tests

* Update changelog

* Fix tests on CUDA device

Co-authored-by: Antonin RAFFIN <[email protected]>
Co-authored-by: Antonin Raffin <[email protected]>
3 people authored Aug 16, 2022
1 parent 792e3bc commit 73822c3
Showing 4 changed files with 61 additions and 13 deletions.
2 changes: 2 additions & 0 deletions docs/misc/changelog.rst
@@ -29,6 +29,8 @@ Others:
^^^^^^^
- Fixed ``DictReplayBuffer.next_observations`` typing (@qgallouedec)

+- Added support for ``device="auto"`` in buffers and made it default (@qgallouedec)

Documentation:
^^^^^^^^^^^^^^
- Fixed typo in docstring "nature" -> "Nature" (@Melanol)
21 changes: 11 additions & 10 deletions stable_baselines3/common/buffers.py
@@ -13,6 +13,7 @@
ReplayBufferSamples,
RolloutBufferSamples,
)
+from stable_baselines3.common.utils import get_device
from stable_baselines3.common.vec_env import VecNormalize

try:
@@ -39,7 +40,7 @@ def __init__(
buffer_size: int,
observation_space: spaces.Space,
action_space: spaces.Space,
-device: Union[th.device, str] = "cpu",
+device: Union[th.device, str] = "auto",
n_envs: int = 1,
):
super().__init__()
@@ -51,7 +52,7 @@ def __init__(
self.action_dim = get_action_dim(action_space)
self.pos = 0
self.full = False
-self.device = device
+self.device = get_device(device)
self.n_envs = n_envs

@staticmethod
@@ -157,7 +158,7 @@ class ReplayBuffer(BaseBuffer):
:param buffer_size: Max number of elements in the buffer
:param observation_space: Observation space
:param action_space: Action space
-:param device:
+:param device: PyTorch device
:param n_envs: Number of parallel environments
:param optimize_memory_usage: Enable a memory efficient variant
of the replay buffer which reduces by almost a factor two the memory used,
@@ -175,7 +176,7 @@ def __init__(
buffer_size: int,
observation_space: spaces.Space,
action_space: spaces.Space,
-device: Union[th.device, str] = "cpu",
+device: Union[th.device, str] = "auto",
n_envs: int = 1,
optimize_memory_usage: bool = False,
handle_timeout_termination: bool = True,
@@ -328,7 +329,7 @@ class RolloutBuffer(BaseBuffer):
:param buffer_size: Max number of elements in the buffer
:param observation_space: Observation space
:param action_space: Action space
-:param device:
+:param device: PyTorch device
:param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator
Equivalent to classic advantage when set to 1.
:param gamma: Discount factor
@@ -340,7 +341,7 @@ def __init__(
buffer_size: int,
observation_space: spaces.Space,
action_space: spaces.Space,
-device: Union[th.device, str] = "cpu",
+device: Union[th.device, str] = "auto",
gae_lambda: float = 1,
gamma: float = 0.99,
n_envs: int = 1,
@@ -493,7 +494,7 @@ class DictReplayBuffer(ReplayBuffer):
:param buffer_size: Max number of elements in the buffer
:param observation_space: Observation space
:param action_space: Action space
-:param device:
+:param device: PyTorch device
:param n_envs: Number of parallel environments
:param optimize_memory_usage: Enable a memory efficient variant
Disabled for now (see https://github.com/DLR-RM/stable-baselines3/pull/243#discussion_r531535702)
@@ -507,7 +508,7 @@ def __init__(
buffer_size: int,
observation_space: spaces.Space,
action_space: spaces.Space,
-device: Union[th.device, str] = "cpu",
+device: Union[th.device, str] = "auto",
n_envs: int = 1,
optimize_memory_usage: bool = False,
handle_timeout_termination: bool = True,
@@ -658,7 +659,7 @@ class DictRolloutBuffer(RolloutBuffer):
:param buffer_size: Max number of elements in the buffer
:param observation_space: Observation space
:param action_space: Action space
-:param device:
+:param device: PyTorch device
:param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator
Equivalent to Monte-Carlo advantage estimate when set to 1.
:param gamma: Discount factor
@@ -670,7 +671,7 @@ def __init__(
buffer_size: int,
observation_space: spaces.Space,
action_space: spaces.Space,
-device: Union[th.device, str] = "cpu",
+device: Union[th.device, str] = "auto",
gae_lambda: float = 1,
gamma: float = 0.99,
n_envs: int = 1,
2 changes: 1 addition & 1 deletion stable_baselines3/her/her_replay_buffer.py
@@ -73,7 +73,7 @@ def __init__(
self,
env: VecEnv,
buffer_size: int,
-device: Union[th.device, str] = "cpu",
+device: Union[th.device, str] = "auto",
replay_buffer: Optional[DictReplayBuffer] = None,
max_episode_length: Optional[int] = None,
n_sampled_goal: int = 4,
49 changes: 47 additions & 2 deletions tests/test_buffers.py
@@ -4,9 +4,10 @@
import torch as th
from gym import spaces

-from stable_baselines3.common.buffers import DictReplayBuffer, ReplayBuffer
+from stable_baselines3.common.buffers import DictReplayBuffer, DictRolloutBuffer, ReplayBuffer, RolloutBuffer
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.type_aliases import DictReplayBufferSamples, ReplayBufferSamples
+from stable_baselines3.common.utils import get_device
from stable_baselines3.common.vec_env import VecNormalize


@@ -71,7 +72,7 @@ def test_replay_buffer_normalization(replay_buffer_cls):
env = make_vec_env(env)
env = VecNormalize(env)

-buffer = replay_buffer_cls(100, env.observation_space, env.action_space)
+buffer = replay_buffer_cls(100, env.observation_space, env.action_space, device="cpu")

# Interact and store transitions
env.reset()
@@ -94,3 +95,47 @@ def test_replay_buffer_normalization(replay_buffer_cls):
assert th.allclose(observations.mean(0), th.zeros(1), atol=1)
# Test reward normalization
assert np.allclose(sample.rewards.mean(0), np.zeros(1), atol=1)


+@pytest.mark.parametrize("replay_buffer_cls", [DictReplayBuffer, DictRolloutBuffer, ReplayBuffer, RolloutBuffer])
+@pytest.mark.parametrize("device", ["cpu", "cuda", "auto"])
+def test_device_buffer(replay_buffer_cls, device):
+    if device == "cuda" and not th.cuda.is_available():
+        pytest.skip("CUDA not available")
+
+    env = {
+        RolloutBuffer: DummyEnv,
+        DictRolloutBuffer: DummyDictEnv,
+        ReplayBuffer: DummyEnv,
+        DictReplayBuffer: DummyDictEnv,
+    }[replay_buffer_cls]
+    env = make_vec_env(env)
+
+    buffer = replay_buffer_cls(100, env.observation_space, env.action_space, device=device)
+
+    # Interact and store transitions
+    obs = env.reset()
+    for _ in range(100):
+        action = env.action_space.sample()
+        next_obs, reward, done, info = env.step(action)
+        if replay_buffer_cls in [RolloutBuffer, DictRolloutBuffer]:
+            episode_start, values, log_prob = np.zeros(1), th.zeros(1), th.ones(1)
+            buffer.add(obs, action, reward, episode_start, values, log_prob)
+        else:
+            buffer.add(obs, next_obs, action, reward, done, info)
+        obs = next_obs
+
+    # Get data from the buffer
+    if replay_buffer_cls in [RolloutBuffer, DictRolloutBuffer]:
+        data = buffer.get(50)
+    elif replay_buffer_cls in [ReplayBuffer, DictReplayBuffer]:
+        data = buffer.sample(50)
+
+    # Check that all data are on the desired device
+    desired_device = get_device(device).type
+    # buffer.get() yields minibatches while buffer.sample() returns a single
+    # batch, so normalize to a list of batches before inspecting the tensors.
+    batches = list(data) if replay_buffer_cls in [RolloutBuffer, DictRolloutBuffer] else [data]
+    for batch in batches:
+        for value in batch:
+            if isinstance(value, dict):
+                for key in value.keys():
+                    assert value[key].device.type == desired_device
+            elif isinstance(value, th.Tensor):
+                assert value.device.type == desired_device
