From 73822c34da221f85b3b736518cc5b49c315f9480 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Tue, 16 Aug 2022 17:54:55 +0200 Subject: [PATCH] Support for `device=auto` buffers and set it as default value (#1009) * Default device is "auto" for buffer + auto device support in BufferBaseClass * Update docstring * Update tests * Unify tests * Update changelog * Fix tests on CUDA device Co-authored-by: Antonin RAFFIN Co-authored-by: Antonin Raffin --- docs/misc/changelog.rst | 2 + stable_baselines3/common/buffers.py | 21 +++++----- stable_baselines3/her/her_replay_buffer.py | 2 +- tests/test_buffers.py | 49 +++++++++++++++++++++- 4 files changed, 61 insertions(+), 13 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index b01d60d4d..88c3c9091 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -29,6 +29,8 @@ Others: ^^^^^^^ - Fixed ``DictReplayBuffer.next_observations`` typing (@qgallouedec) +- Added support for ``device="auto"`` in buffers and made it default (@qgallouedec) + Documentation: ^^^^^^^^^^^^^^ - Fixed typo in docstring "nature" -> "Nature" (@Melanol) diff --git a/stable_baselines3/common/buffers.py b/stable_baselines3/common/buffers.py index 0eb26515b..59725312b 100644 --- a/stable_baselines3/common/buffers.py +++ b/stable_baselines3/common/buffers.py @@ -13,6 +13,7 @@ ReplayBufferSamples, RolloutBufferSamples, ) +from stable_baselines3.common.utils import get_device from stable_baselines3.common.vec_env import VecNormalize try: @@ -39,7 +40,7 @@ def __init__( buffer_size: int, observation_space: spaces.Space, action_space: spaces.Space, - device: Union[th.device, str] = "cpu", + device: Union[th.device, str] = "auto", n_envs: int = 1, ): super().__init__() @@ -51,7 +52,7 @@ def __init__( self.action_dim = get_action_dim(action_space) self.pos = 0 self.full = False - self.device = device + self.device = get_device(device) self.n_envs = n_envs @staticmethod @@ -157,7 +158,7 @@ class ReplayBuffer(BaseBuffer): :param buffer_size: Max number of element in the buffer :param observation_space: Observation space :param action_space: Action space - :param device: + :param device: PyTorch device :param n_envs: Number of parallel environments :param optimize_memory_usage: Enable a memory efficient variant of the replay buffer which reduces by almost a factor two the memory used, @@ -175,7 +176,7 @@ def __init__( buffer_size: int, observation_space: spaces.Space, action_space: spaces.Space, - device: Union[th.device, str] = "cpu", + device: Union[th.device, str] = "auto", n_envs: int = 1, optimize_memory_usage: bool = False, handle_timeout_termination: bool = True, @@ -328,7 +329,7 @@ class RolloutBuffer(BaseBuffer): :param buffer_size: Max number of element in the buffer :param observation_space: Observation space :param action_space: Action space - :param device: + :param device: PyTorch device :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator Equivalent to classic advantage when set to 1. :param gamma: Discount factor @@ -340,7 +341,7 @@ def __init__( buffer_size: int, observation_space: spaces.Space, action_space: spaces.Space, - device: Union[th.device, str] = "cpu", + device: Union[th.device, str] = "auto", gae_lambda: float = 1, gamma: float = 0.99, n_envs: int = 1, @@ -493,7 +494,7 @@ class DictReplayBuffer(ReplayBuffer): :param buffer_size: Max number of element in the buffer :param observation_space: Observation space :param action_space: Action space - :param device: + :param device: PyTorch device :param n_envs: Number of parallel environments :param optimize_memory_usage: Enable a memory efficient variant Disabled for now (see https://github.com/DLR-RM/stable-baselines3/pull/243#discussion_r531535702) @@ -507,7 +508,7 @@ def __init__( buffer_size: int, observation_space: spaces.Space, action_space: spaces.Space, - device: Union[th.device, str] = "cpu", + device: Union[th.device, str] = "auto", n_envs: int = 1, optimize_memory_usage: bool = False, handle_timeout_termination: bool = True, @@ -658,7 +659,7 @@ class DictRolloutBuffer(RolloutBuffer): :param buffer_size: Max number of element in the buffer :param observation_space: Observation space :param action_space: Action space - :param device: + :param device: PyTorch device :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator Equivalent to Monte-Carlo advantage estimate when set to 1. :param gamma: Discount factor @@ -670,7 +671,7 @@ def __init__( buffer_size: int, observation_space: spaces.Space, action_space: spaces.Space, - device: Union[th.device, str] = "cpu", + device: Union[th.device, str] = "auto", gae_lambda: float = 1, gamma: float = 0.99, n_envs: int = 1, diff --git a/stable_baselines3/her/her_replay_buffer.py b/stable_baselines3/her/her_replay_buffer.py index 3c19aac42..e3fc63e14 100644 --- a/stable_baselines3/her/her_replay_buffer.py +++ b/stable_baselines3/her/her_replay_buffer.py @@ -73,7 +73,7 @@ def __init__( self, env: VecEnv, buffer_size: int, - device: Union[th.device, str] = "cpu", + device: Union[th.device, str] = "auto", replay_buffer: Optional[DictReplayBuffer] = None, max_episode_length: Optional[int] = None, n_sampled_goal: int = 4, diff --git a/tests/test_buffers.py b/tests/test_buffers.py index 45c5e6aa3..0e028e670 100644 --- a/tests/test_buffers.py +++ b/tests/test_buffers.py @@ -4,9 +4,10 @@ import torch as th from gym import spaces -from stable_baselines3.common.buffers import DictReplayBuffer, ReplayBuffer +from stable_baselines3.common.buffers import DictReplayBuffer, DictRolloutBuffer, ReplayBuffer, RolloutBuffer from stable_baselines3.common.env_util import make_vec_env from stable_baselines3.common.type_aliases import DictReplayBufferSamples, ReplayBufferSamples +from stable_baselines3.common.utils import get_device from stable_baselines3.common.vec_env import VecNormalize @@ -71,7 +72,7 @@ def test_replay_buffer_normalization(replay_buffer_cls): env = make_vec_env(env) env = VecNormalize(env) - buffer = replay_buffer_cls(100, env.observation_space, env.action_space) + buffer = replay_buffer_cls(100, env.observation_space, env.action_space, device="cpu") # Interract and store transitions env.reset() @@ -94,3 +95,47 @@ def test_replay_buffer_normalization(replay_buffer_cls): assert th.allclose(observations.mean(0), th.zeros(1), atol=1) # Test reward normalization assert np.allclose(sample.rewards.mean(0), np.zeros(1), atol=1) + + +@pytest.mark.parametrize("replay_buffer_cls", [DictReplayBuffer, DictRolloutBuffer, ReplayBuffer, RolloutBuffer]) +@pytest.mark.parametrize("device", ["cpu", "cuda", "auto"]) +def test_device_buffer(replay_buffer_cls, device): + if device == "cuda" and not th.cuda.is_available(): + pytest.skip("CUDA not available") + + env = { + RolloutBuffer: DummyEnv, + DictRolloutBuffer: DummyDictEnv, + ReplayBuffer: DummyEnv, + DictReplayBuffer: DummyDictEnv, + }[replay_buffer_cls] + env = make_vec_env(env) + + buffer = replay_buffer_cls(100, env.observation_space, env.action_space, device=device) + + # Interract and store transitions + obs = env.reset() + for _ in range(100): + action = env.action_space.sample() + next_obs, reward, done, info = env.step(action) + if replay_buffer_cls in [RolloutBuffer, DictRolloutBuffer]: + episode_start, values, log_prob = np.zeros(1), th.zeros(1), th.ones(1) + buffer.add(obs, action, reward, episode_start, values, log_prob) + else: + buffer.add(obs, next_obs, action, reward, done, info) + obs = next_obs + + # Get data from the buffer + if replay_buffer_cls in [RolloutBuffer, DictRolloutBuffer]: + data = buffer.get(50) + elif replay_buffer_cls in [ReplayBuffer, DictReplayBuffer]: + data = buffer.sample(50) + + # Check that all data are on the desired device + desired_device = get_device(device).type + for value in list(data): + if isinstance(value, dict): + for key in value.keys(): + assert value[key].device.type == desired_device + elif isinstance(value, th.Tensor): + assert value.device.type == desired_device