From a413077ed726ed32ab277d4be296c940b25ec31f Mon Sep 17 00:00:00 2001 From: Prabhat Date: Fri, 14 Apr 2023 21:37:42 -0500 Subject: [PATCH 01/26] Adds gymnasium to the setup --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 1bf1048a0..84ca77dba 100644 --- a/setup.py +++ b/setup.py @@ -4,7 +4,7 @@ install_requires = [ 'torch>=1.3.0', - 'gym>=0.9.7', + 'gymnasium[all]', 'numpy>=1.11.0', 'pillow', 'filelock', From 35a420c8baa373a6d0464c4597c93c275c007210 Mon Sep 17 00:00:00 2001 From: Prabhat Date: Fri, 14 Apr 2023 21:55:07 -0500 Subject: [PATCH 02/26] gym -> gymnasium --- .gitignore | 2 +- .pfnci/run.sh | 2 +- README.md | 12 ++--- examples/README.md | 2 +- examples/atari/train_acer_ale.py | 6 +-- examples/atari/train_drqn_ale.py | 6 +-- examples/atari/train_ppo_ale.py | 4 +- .../atlas/train_soft_actor_critic_atlas.py | 18 ++++---- examples/grasping/train_dqn_batch_grasping.py | 24 +++++----- examples/gym/README.md | 14 +++--- examples/gym/train_categorical_dqn_gym.py | 10 ++-- examples/gym/train_dqn_gym.py | 14 +++--- examples/gym/train_reinforce_gym.py | 16 +++---- examples/mujoco/reproduction/ddpg/README.md | 2 +- .../mujoco/reproduction/ddpg/train_ddpg.py | 14 +++--- examples/mujoco/reproduction/ppo/README.md | 4 +- examples/mujoco/reproduction/ppo/train_ppo.py | 16 +++---- .../reproduction/soft_actor_critic/README.md | 2 +- .../train_soft_actor_critic.py | 16 +++---- examples/mujoco/reproduction/td3/README.md | 6 +-- examples/mujoco/reproduction/td3/train_td3.py | 14 +++--- examples/mujoco/reproduction/trpo/README.md | 4 +- .../mujoco/reproduction/trpo/train_trpo.py | 18 ++++---- examples/optuna/optuna_dqn_obs1d.py | 10 ++-- examples/quickstart/quickstart.ipynb | 8 ++-- examples/slimevolley/README.md | 8 ++-- examples/slimevolley/train_rainbow.py | 24 +++++----- examples_tests/gym/test_categorical_dqn.sh | 8 ++-- examples_tests/gym/test_dqn.sh | 8 ++-- examples_tests/gym/test_reinforce.sh | 8 ++-- examples_tests/slimevolley/test_rainbow.sh | 2 +- pfrl/envs/abc.py | 2 +- pfrl/envs/multiprocess_vector_env.py | 2 +- pfrl/envs/serial_vector_env.py | 2 +- pfrl/utils/pretrained_models.py | 2 +- pfrl/wrappers/atari_wrappers.py | 46 +++++++++---------- pfrl/wrappers/cast_observation.py | 4 +- pfrl/wrappers/continuing_time_limit.py | 8 ++-- pfrl/wrappers/monitor.py | 20 ++++---- pfrl/wrappers/normalize_action_space.py | 10 ++-- pfrl/wrappers/randomize_action.py | 12 ++--- pfrl/wrappers/render.py | 6 +-- pfrl/wrappers/scale_reward.py | 4 +- pfrl/wrappers/vector_frame_stack.py | 4 +- requirements.txt | 2 +- setup.cfg | 4 +- setup.py | 2 +- tests/envs_tests/test_vector_envs.py | 8 ++-- tests/wrappers_tests/test_atari_wrappers.py | 12 ++--- tests/wrappers_tests/test_cast_observation.py | 6 +-- tests/wrappers_tests/test_monitor.py | 6 +-- tests/wrappers_tests/test_randomize_action.py | 10 ++-- tests/wrappers_tests/test_scale_reward.py | 4 +- .../wrappers_tests/test_vector_frame_stack.py | 8 ++-- 54 files changed, 243 insertions(+), 243 deletions(-) diff --git a/.gitignore b/.gitignore index 6473ced97..a98cef38e 100644 --- a/.gitignore +++ b/.gitignore @@ -5,4 +5,4 @@ build/ dist/ .idea/ results/ -examples/gym/results/ +examples/gymnasium/results/ diff --git a/.pfnci/run.sh b/.pfnci/run.sh index 5e5dbee31..37480af36 100644 --- a/.pfnci/run.sh +++ b/.pfnci/run.sh @@ -75,7 +75,7 @@ main() { # pytest does not run with attrs==19.2.0 (https://github.com/pytest-dev/pytest/issues/3280) # NOQA "${PYTHON}" -m pip install \ 'pytest==4.1.1' 'attrs==19.1.0' 
'pytest-xdist==1.26.1' \ - 'gym[atari,classic_control]==0.19.0' 'optuna' 'zipp==1.0.0' 'pybullet==2.8.1' 'jupyterlab==2.1.5' 'traitlets==5.1.1' + 'gymnasium[atari,classic_control]==0.19.0' 'optuna' 'zipp==1.0.0' 'pybullet==2.8.1' 'jupyterlab==2.1.5' 'traitlets==5.1.1' git config --global user.email "you@example.com" git config --global user.name "Your Name" diff --git a/README.md b/README.md index f2d3f0a9a..a20d8bc34 100644 --- a/README.md +++ b/README.md @@ -30,7 +30,7 @@ Refer to [Installation](http://pfrl.readthedocs.io/en/latest/install.html) for m ## Getting started -You can try [PFRL Quickstart Guide](examples/quickstart/quickstart.ipynb) first, or check the [examples](examples) ready for Atari 2600 and Open AI Gym. +You can try [PFRL Quickstart Guide](examples/quickstart/quickstart.ipynb) first, or check the [examples](examples) ready for Atari 2600 and Open AI gymnasium. For more information, you can refer to [PFRL's documentation](http://pfrl.readthedocs.io/en/latest/index.html). @@ -64,9 +64,9 @@ Following algorithms have been implemented in PFRL: - [ACER (Actor-Critic with Experience Replay)](https://arxiv.org/abs/1611.01224) - examples: [[atari]](examples/atari/train_acer_ale.py) - [Categorical DQN](https://arxiv.org/abs/1707.06887) - - examples: [[atari]](examples/atari/train_categorical_dqn_ale.py) [[general gym]](examples/gym/train_categorical_dqn_gym.py) + - examples: [[atari]](examples/atari/train_categorical_dqn_ale.py) [[general gymnasium]](examples/gymnasium/train_categorical_dqn_gymnasium.py) - [DQN (Deep Q-Network)](https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf) (including [Double DQN](https://arxiv.org/abs/1509.06461), [Persistent Advantage Learning (PAL)](https://arxiv.org/abs/1512.04860), Double PAL, [Dynamic Policy Programming (DPP)](http://www.jmlr.org/papers/volume13/azar12a/azar12a.pdf)) - - examples: [[atari reproduction]](examples/atari/reproduction/dqn) [[atari]](examples/atari/train_dqn_ale.py) [[atari (batched)]](examples/atari/train_dqn_batch_ale.py) [[flickering atari]](examples/atari/train_drqn_ale.py) [[general gym]](examples/gym/train_dqn_gym.py) + - examples: [[atari reproduction]](examples/atari/reproduction/dqn) [[atari]](examples/atari/train_dqn_ale.py) [[atari (batched)]](examples/atari/train_dqn_batch_ale.py) [[flickering atari]](examples/atari/train_drqn_ale.py) [[general gymnasium]](examples/gymnasium/train_dqn_gymnasium.py) - [DDPG (Deep Deterministic Policy Gradients)](https://arxiv.org/abs/1509.02971) (including [SVG(0)](https://arxiv.org/abs/1510.09142)) - examples: [[mujoco reproduction]](examples/mujoco/reproduction/ddpg) - [IQN (Implicit Quantile Networks)](https://arxiv.org/abs/1806.06923) @@ -76,7 +76,7 @@ Following algorithms have been implemented in PFRL: - [Rainbow](https://arxiv.org/abs/1710.02298) - examples: [[atari reproduction]](examples/atari/reproduction/rainbow) [[Slime volleyball]](examples/slimevolley/) - [REINFORCE](http://www-anw.cs.umass.edu/~barto/courses/cs687/williams92simple.pdf) - - examples: [[general gym]](examples/gym/train_reinforce_gym.py) + - examples: [[general gymnasium]](examples/gymnasium/train_reinforce_gymnasium.py) - [SAC (Soft Actor-Critic)](https://arxiv.org/abs/1812.05905) - examples: [[mujoco reproduction]](examples/mujoco/reproduction/soft_actor_critic) [[Atlas walk]](examples/atlas/) - [TRPO (Trust Region Policy Optimization)](https://arxiv.org/abs/1502.05477) with [GAE (Generalized Advantage Estimation)](https://arxiv.org/abs/1506.02438) @@ -92,14 +92,14 @@ Following 
useful techniques have been also implemented in PFRL: - [Dueling Network](https://arxiv.org/abs/1511.06581) - examples: [[Rainbow]](examples/atari/reproduction/rainbow) [[DQN/DoubleDQN/PAL]](examples/atari/train_dqn_ale.py) - [Normalized Advantage Function](https://arxiv.org/abs/1603.00748) - - examples: [[DQN]](examples/gym/train_dqn_gym.py) (for continuous-action envs only) + - examples: [[DQN]](examples/gymnasium/train_dqn_gymnasium.py) (for continuous-action envs only) - [Deep Recurrent Q-Network](https://arxiv.org/abs/1507.06527) - examples: [[DQN]](examples/atari/train_drqn_ale.py) ## Environments -Environments that support the subset of OpenAI Gym's interface (`reset` and `step` methods) can be used. +Environments that support the subset of OpenAI gymnasium's interface (`reset` and `step` methods) can be used. ## Contributing diff --git a/examples/README.md b/examples/README.md index f8fc3c4b6..4b97fc16c 100644 --- a/examples/README.md +++ b/examples/README.md @@ -3,7 +3,7 @@ - `atari`: examples for general Atari games - `atari/reproduction`: examples with benchmark scores for reproducing published results on Atari - `atlas`: training an Atlas robot to walk -- `gym`: examples for OpenAI Gym environments +- `gymnasium`: examples for OpenAI gymnasium environments - `grasping`: examples for a Bullet-based robotic grasping environment - `mujoco/reproduction`: examples with benchmark scores for reproducing published results on MuJoCo tasks - `quickstart`: a quickstart guide of PFRL diff --git a/examples/atari/train_acer_ale.py b/examples/atari/train_acer_ale.py index 091377718..686853ad9 100644 --- a/examples/atari/train_acer_ale.py +++ b/examples/atari/train_acer_ale.py @@ -4,8 +4,8 @@ # Prevent numpy from using multiple threads os.environ["OMP_NUM_THREADS"] = "1" -import gym # NOQA:E402 -import gym.wrappers # NOQA:E402 +import gymnasium # NOQA:E402 +import gymnasium.wrappers # NOQA:E402 import numpy as np # NOQA:E402 from torch import nn # NOQA:E402 @@ -92,7 +92,7 @@ def main(): args.outdir = experiments.prepare_output_dir(args, args.outdir) print("Output files are saved in {}".format(args.outdir)) - n_actions = gym.make(args.env).action_space.n + n_actions = gymnasium.make(args.env).action_space.n input_to_hidden = nn.Sequential( nn.Conv2d(4, 16, 8, stride=4), diff --git a/examples/atari/train_drqn_ale.py b/examples/atari/train_drqn_ale.py index ccbefa699..a0425784d 100644 --- a/examples/atari/train_drqn_ale.py +++ b/examples/atari/train_drqn_ale.py @@ -11,8 +11,8 @@ """ import argparse -import gym -import gym.wrappers +import gymnasium +import gymnasium.wrappers import numpy as np import torch from torch import nn @@ -193,7 +193,7 @@ def make_env(test): # Randomize actions like epsilon-greedy in evaluation as well env = pfrl.wrappers.RandomizeAction(env, args.eval_epsilon) if args.monitor: - env = gym.wrappers.Monitor( + env = gymnasium.wrappers.Monitor( env, args.outdir, mode="evaluation" if test else "training" ) if args.render: diff --git a/examples/atari/train_ppo_ale.py b/examples/atari/train_ppo_ale.py index 80bac591f..dd48244fd 100644 --- a/examples/atari/train_ppo_ale.py +++ b/examples/atari/train_ppo_ale.py @@ -1,4 +1,4 @@ -"""An example of training PPO against OpenAI Gym Atari Envs. +"""An example of training PPO against OpenAI gymnasium Atari Envs. This script is an example of training a PPO agent on Atari envs. @@ -25,7 +25,7 @@ def main(): parser = argparse.ArgumentParser() parser.add_argument( - "--env", type=str, default="BreakoutNoFrameskip-v4", help="Gym Env ID." 
+ "--env", type=str, default="BreakoutNoFrameskip-v4", help="gymnasium Env ID." ) parser.add_argument( "--gpu", type=int, default=0, help="GPU device ID. Set to -1 to use CPUs only." diff --git a/examples/atlas/train_soft_actor_critic_atlas.py b/examples/atlas/train_soft_actor_critic_atlas.py index 8dc411192..8b1a32505 100644 --- a/examples/atlas/train_soft_actor_critic_atlas.py +++ b/examples/atlas/train_soft_actor_critic_atlas.py @@ -4,8 +4,8 @@ import logging import sys -import gym -import gym.wrappers +import gymnasium +import gymnasium.wrappers import numpy as np import torch from torch import distributions, nn @@ -17,16 +17,16 @@ def make_env(args, seed, test): if args.env.startswith("Roboschool"): - # Check gym version because roboschool does not work with gym>=0.15.6 + # Check gymnasium version because roboschool does not work with gymnasium>=0.15.6 from distutils.version import StrictVersion - gym_version = StrictVersion(gym.__version__) - if gym_version >= StrictVersion("0.15.6"): - raise RuntimeError("roboschool does not work with gym>=0.15.6") + gymnasium_version = StrictVersion(gymnasium.__version__) + if gymnasium_version >= StrictVersion("0.15.6"): + raise RuntimeError("roboschool does not work with gymnasium>=0.15.6") import roboschool # NOQA - env = gym.make(args.env) + env = gymnasium.make(args.env) # Unwrap TimiLimit wrapper - assert isinstance(env, gym.wrappers.TimeLimit) + assert isinstance(env, gymnasium.wrappers.TimeLimit) env = env.env # Use different random seeds for train and test envs env_seed = 2**32 - 1 - seed if test else seed @@ -60,7 +60,7 @@ def main(): "--env", type=str, default="RoboschoolAtlasForwardWalk-v1", - help="OpenAI Gym env to perform algorithm on.", + help="OpenAI gymnasium env to perform algorithm on.", ) parser.add_argument( "--num-envs", type=int, default=4, help="Number of envs run in parallel." diff --git a/examples/grasping/train_dqn_batch_grasping.py b/examples/grasping/train_dqn_batch_grasping.py index 0274a0530..a81d04649 100644 --- a/examples/grasping/train_dqn_batch_grasping.py +++ b/examples/grasping/train_dqn_batch_grasping.py @@ -2,8 +2,8 @@ import functools import os -import gym -import gym.spaces +import gymnasium +import gymnasium.spaces import numpy as np import torch from torch import nn @@ -13,7 +13,7 @@ from pfrl.q_functions import DiscreteActionValueHead -class CastAction(gym.ActionWrapper): +class CastAction(gymnasium.ActionWrapper): """Cast actions to a given type.""" def __init__(self, env, type_): @@ -24,14 +24,14 @@ def action(self, action): return self.type_(action) -class TransposeObservation(gym.ObservationWrapper): +class TransposeObservation(gymnasium.ObservationWrapper): """Transpose observations.""" def __init__(self, env, axes): super().__init__(env) self._axes = axes - assert isinstance(env.observation_space, gym.spaces.Box) - self.observation_space = gym.spaces.Box( + assert isinstance(env.observation_space, gymnasium.spaces.Box) + self.observation_space = gymnasium.spaces.Box( low=env.observation_space.low.transpose(*self._axes), high=env.observation_space.high.transpose(*self._axes), dtype=env.observation_space.dtype, @@ -41,7 +41,7 @@ def observation(self, observation): return observation.transpose(*self._axes) -class ObserveElapsedSteps(gym.Wrapper): +class ObserveElapsedSteps(gymnasium.Wrapper): """Observe the number of elapsed steps in an episode. 
A new observation will be a tuple of an original observation and an integer @@ -52,10 +52,10 @@ def __init__(self, env, max_steps): super().__init__(env) self._max_steps = max_steps self._elapsed_steps = 0 - self.observation_space = gym.spaces.Tuple( + self.observation_space = gymnasium.spaces.Tuple( ( env.observation_space, - gym.spaces.Discrete(self._max_steps + 1), + gymnasium.spaces.Discrete(self._max_steps + 1), ) ) @@ -70,7 +70,7 @@ def step(self, action): return (observation, self._elapsed_steps), reward, done, info -class RecordMovie(gym.Wrapper): +class RecordMovie(gymnasium.Wrapper): """Record MP4 videos using pybullet's logging API.""" def __init__(self, env, dirname): @@ -243,7 +243,7 @@ def main(): max_episode_steps = 8 def make_env(idx, test): - from pybullet_envs.bullet.kuka_diverse_object_gym_env import ( # NOQA + from pybullet_envs.bullet.kuka_diverse_object_gymnasium_env import ( # NOQA KukaDiverseObjectEnv, ) @@ -263,7 +263,7 @@ def make_env(idx, test): # Disable file caching to keep memory usage small env._p.setPhysicsEngineParameter(enableFileCaching=False) assert env.observation_space is None - env.observation_space = gym.spaces.Box( + env.observation_space = gymnasium.spaces.Box( low=0, high=255, shape=(84, 84, 3), dtype=np.uint8 ) # (84, 84, 3) -> (3, 84, 84) diff --git a/examples/gym/README.md b/examples/gym/README.md index 0e46abf0d..b17585519 100644 --- a/examples/gym/README.md +++ b/examples/gym/README.md @@ -1,15 +1,15 @@ -# Examples for OpenAI Gym environments +# Examples for OpenAI gymnasium environments -- `train_categorical_dqn_gym.py`: CategoricalDQN for discrete action action spaces -- `train_dqn_gym.py`: DQN for both discrete action and continuous action spaces -- `train_reinforce_gym.py`: REINFORCE for both discrete action and continuous action spaces (only for episodic envs) +- `train_categorical_dqn_gymnasium.py`: CategoricalDQN for discrete action action spaces +- `train_dqn_gymnasium.py`: DQN for both discrete action and continuous action spaces +- `train_reinforce_gymnasium.py`: REINFORCE for both discrete action and continuous action spaces (only for episodic envs) ## How to run ``` -python train_categorical_dqn_gym.py [options] -python train_dqn_gym.py [options] -python train_reinforce_gym.py [options] +python train_categorical_dqn_gymnasium.py [options] +python train_dqn_gymnasium.py [options] +python train_reinforce_gymnasium.py [options] ``` Specify `--help` or read code for options. diff --git a/examples/gym/train_categorical_dqn_gym.py b/examples/gym/train_categorical_dqn_gym.py index 7c7105189..ac07557c7 100644 --- a/examples/gym/train_categorical_dqn_gym.py +++ b/examples/gym/train_categorical_dqn_gym.py @@ -1,16 +1,16 @@ -"""An example of training Categorical DQN against OpenAI Gym Envs. +"""An example of training Categorical DQN against OpenAI gymnasium Envs. This script is an example of training a CategoricalDQN agent against OpenAI -Gym envs. Only discrete spaces are supported. +gymnasium envs. Only discrete spaces are supported. 
To solve CartPole-v0, run: - python train_categorical_dqn_gym.py --env CartPole-v0 + python train_categorical_dqn_gymnasium.py --env CartPole-v0 """ import argparse import sys -import gym +import gymnasium import torch import pfrl @@ -66,7 +66,7 @@ def main(): print("Output files are saved in {}".format(args.outdir)) def make_env(test): - env = gym.make(args.env) + env = gymnasium.make(args.env) env_seed = 2**32 - 1 - args.seed if test else args.seed env.seed(env_seed) # Cast observations to float32 because our model uses float32 diff --git a/examples/gym/train_dqn_gym.py b/examples/gym/train_dqn_gym.py index 9319a9125..d39385cf2 100644 --- a/examples/gym/train_dqn_gym.py +++ b/examples/gym/train_dqn_gym.py @@ -1,24 +1,24 @@ -"""An example of training DQN against OpenAI Gym Envs. +"""An example of training DQN against OpenAI gymnasium Envs. -This script is an example of training a DQN agent against OpenAI Gym envs. +This script is an example of training a DQN agent against OpenAI gymnasium envs. Both discrete and continuous action spaces are supported. For continuous action spaces, A NAF (Normalized Advantage Function) is used to approximate Q-values. To solve CartPole-v0, run: - python train_dqn_gym.py --env CartPole-v0 + python train_dqn_gymnasium.py --env CartPole-v0 To solve Pendulum-v0, run: - python train_dqn_gym.py --env Pendulum-v0 + python train_dqn_gymnasium.py --env Pendulum-v0 """ import argparse import os import sys -import gym +import gymnasium import numpy as np import torch.optim as optim -from gym import spaces +from gymnasium import spaces import pfrl from pfrl import experiments, explorers @@ -100,7 +100,7 @@ def clip_action_filter(a): return np.clip(a, action_space.low, action_space.high) def make_env(idx=0, test=False): - env = gym.make(args.env) + env = gymnasium.make(args.env) # Use different random seeds for train and test envs process_seed = int(process_seeds[idx]) env_seed = 2**32 - 1 - process_seed if test else process_seed diff --git a/examples/gym/train_reinforce_gym.py b/examples/gym/train_reinforce_gym.py index f2c9eaa61..c82ed51e0 100644 --- a/examples/gym/train_reinforce_gym.py +++ b/examples/gym/train_reinforce_gym.py @@ -1,18 +1,18 @@ -"""An example of training a REINFORCE agent against OpenAI Gym envs. +"""An example of training a REINFORCE agent against OpenAI gymnasium envs. -This script is an example of training a REINFORCE agent against OpenAI Gym +This script is an example of training a REINFORCE agent against OpenAI gymnasium envs. Both discrete and continuous action spaces are supported. 
To solve CartPole-v0, run: - python train_reinforce_gym.py + python train_reinforce_gymnasium.py To solve InvertedPendulum-v1, run: - python train_reinforce_gym.py --env InvertedPendulum-v1 + python train_reinforce_gymnasium.py --env InvertedPendulum-v1 """ import argparse -import gym -import gym.spaces +import gymnasium +import gymnasium.spaces import torch from torch import nn @@ -59,7 +59,7 @@ def main(): args.outdir = experiments.prepare_output_dir(args, args.outdir) def make_env(test): - env = gym.make(args.env) + env = gymnasium.make(args.env) # Use different random seeds for train and test envs env_seed = 2**32 - 1 - args.seed if test else args.seed env.seed(env_seed) @@ -83,7 +83,7 @@ def make_env(test): obs_size = obs_space.low.size hidden_size = 200 # Switch policy types accordingly to action space types - if isinstance(action_space, gym.spaces.Box): + if isinstance(action_space, gymnasium.spaces.Box): model = nn.Sequential( nn.Linear(obs_size, hidden_size), nn.LeakyReLU(0.2), diff --git a/examples/mujoco/reproduction/ddpg/README.md b/examples/mujoco/reproduction/ddpg/README.md index bdc824806..4821f7abc 100644 --- a/examples/mujoco/reproduction/ddpg/README.md +++ b/examples/mujoco/reproduction/ddpg/README.md @@ -1,6 +1,6 @@ # DDPG on MuJoCo benchmarks -This example trains a DDPG agent ([Continuous Control with Deep Reinforcement Learning](https://arxiv.org/abs/1509.02971)) on MuJoCo benchmarks from OpenAI Gym. +This example trains a DDPG agent ([Continuous Control with Deep Reinforcement Learning](https://arxiv.org/abs/1509.02971)) on MuJoCo benchmarks from OpenAI gymnasium. We follow the training and evaluation settings of [Addressing Function Approximation Error in Actor-Critic Methods](http://arxiv.org/abs/1802.09477), which provides thorough, highly tuned benchmark results. diff --git a/examples/mujoco/reproduction/ddpg/train_ddpg.py b/examples/mujoco/reproduction/ddpg/train_ddpg.py index 705cccc5d..2f764fdf2 100644 --- a/examples/mujoco/reproduction/ddpg/train_ddpg.py +++ b/examples/mujoco/reproduction/ddpg/train_ddpg.py @@ -1,4 +1,4 @@ -"""A training script of DDPG on OpenAI Gym Mujoco environments. +"""A training script of DDPG on OpenAI gymnasium Mujoco environments. This script follows the settings of http://arxiv.org/abs/1802.09477 as much as possible. @@ -8,8 +8,8 @@ import logging import sys -import gym -import gym.wrappers +import gymnasium +import gymnasium.wrappers import numpy as np import torch from torch import nn @@ -37,7 +37,7 @@ def main(): "--env", type=str, default="Hopper-v2", - help="OpenAI Gym MuJoCo env to perform algorithm on.", + help="OpenAI gymnasium MuJoCo env to perform algorithm on.", ) parser.add_argument("--seed", type=int, default=0, help="Random seed [0, 2 ** 32)") parser.add_argument( @@ -82,7 +82,7 @@ def main(): "--pretrained-type", type=str, default="best", choices=["best", "final"] ) parser.add_argument( - "--monitor", action="store_true", help="Wrap env with gym.wrappers.Monitor." + "--monitor", action="store_true", help="Wrap env with gymnasium.wrappers.Monitor." ) parser.add_argument( "--log-level", type=int, default=logging.INFO, help="Level of the root logger." 
@@ -98,9 +98,9 @@ def main(): utils.set_random_seed(args.seed) def make_env(test): - env = gym.make(args.env) + env = gymnasium.make(args.env) # Unwrap TimeLimit wrapper - assert isinstance(env, gym.wrappers.TimeLimit) + assert isinstance(env, gymnasium.wrappers.TimeLimit) env = env.env # Use different random seeds for train and test envs env_seed = 2**32 - 1 - args.seed if test else args.seed diff --git a/examples/mujoco/reproduction/ppo/README.md b/examples/mujoco/reproduction/ppo/README.md index 7170455c4..ad1129aaf 100644 --- a/examples/mujoco/reproduction/ppo/README.md +++ b/examples/mujoco/reproduction/ppo/README.md @@ -1,6 +1,6 @@ # PPO on MuJoCo benchmarks -This example trains a PPO agent ([Proximal Policy Optimization Algorithms](http://arxiv.org/abs/1707.06347)) on MuJoCo benchmarks from OpenAI Gym. +This example trains a PPO agent ([Proximal Policy Optimization Algorithms](http://arxiv.org/abs/1707.06347)) on MuJoCo benchmarks from OpenAI gymnasium. We follow the training and evaluation settings of [Deep Reinforcement Learning that Matters](https://arxiv.org/abs/1709.06560), which provides thorough, highly tuned benchmark results. @@ -37,7 +37,7 @@ To view the full list of options, either view the code or run the example with t ## Known differences - While the original paper initialized weights by normal distribution (https://github.com/Breakend/baselines/blob/50ffe01d254221db75cdb5c2ba0ab51a6da06b0a/baselines/ppo1/mlp_policy.py#L28), we use orthogonal initialization as the latest openai/baselines does (https://github.com/openai/baselines/blob/9b68103b737ac46bc201dfb3121cfa5df2127e53/baselines/a2c/utils.py#L61). -- We used version v2 of the environments whereas the original results were reported for version v1, however this doesn't seem to introduce significant differences: https://github.com/openai/gym/pull/834 +- We used version v2 of the environments whereas the original results were reported for version v1, however this doesn't seem to introduce significant differences: https://github.com/openai/gymnasium/pull/834 ## Results diff --git a/examples/mujoco/reproduction/ppo/train_ppo.py b/examples/mujoco/reproduction/ppo/train_ppo.py index a42d8f0af..991de8aec 100644 --- a/examples/mujoco/reproduction/ppo/train_ppo.py +++ b/examples/mujoco/reproduction/ppo/train_ppo.py @@ -1,4 +1,4 @@ -"""A training script of PPO on OpenAI Gym Mujoco environments. +"""A training script of PPO on OpenAI gymnasium Mujoco environments. This script follows the settings of https://arxiv.org/abs/1709.06560 as much as possible. @@ -6,8 +6,8 @@ import argparse import functools -import gym -import gym.spaces +import gymnasium +import gymnasium.spaces import numpy as np import torch from torch import nn @@ -28,7 +28,7 @@ def main(): "--env", type=str, default="Hopper-v2", - help="OpenAI Gym MuJoCo env to perform algorithm on.", + help="OpenAI gymnasium MuJoCo env to perform algorithm on.", ) parser.add_argument( "--num-envs", type=int, default=1, help="Number of envs run in parallel." @@ -75,7 +75,7 @@ def main(): "--log-level", type=int, default=logging.INFO, help="Level of the root logger." ) parser.add_argument( - "--monitor", action="store_true", help="Wrap env with gym.wrappers.Monitor." + "--monitor", action="store_true", help="Wrap env with gymnasium.wrappers.Monitor." 
) parser.add_argument( "--log-interval", @@ -112,7 +112,7 @@ def main(): args.outdir = experiments.prepare_output_dir(args, args.outdir) def make_env(process_idx, test): - env = gym.make(args.env) + env = gymnasium.make(args.env) # Use different random seeds for train and test envs process_seed = int(process_seeds[process_idx]) env_seed = 2**32 - 1 - process_seed if test else process_seed @@ -134,14 +134,14 @@ def make_batch_env(test): ) # Only for getting timesteps, and obs-action spaces - sample_env = gym.make(args.env) + sample_env = gymnasium.make(args.env) timestep_limit = sample_env.spec.max_episode_steps obs_space = sample_env.observation_space action_space = sample_env.action_space print("Observation space:", obs_space) print("Action space:", action_space) - assert isinstance(action_space, gym.spaces.Box) + assert isinstance(action_space, gymnasium.spaces.Box) # Normalize observations based on their empirical mean and variance obs_normalizer = pfrl.nn.EmpiricalNormalization( diff --git a/examples/mujoco/reproduction/soft_actor_critic/README.md b/examples/mujoco/reproduction/soft_actor_critic/README.md index 319fdd0c0..da7dd4fde 100644 --- a/examples/mujoco/reproduction/soft_actor_critic/README.md +++ b/examples/mujoco/reproduction/soft_actor_critic/README.md @@ -1,6 +1,6 @@ # Soft Actor-Critic (SAC) on MuJoCo benchmarks -This example trains a SAC agent ([Soft Actor-Critic Algorithms and Applications](https://arxiv.org/abs/1812.05905)) on MuJoCo benchmarks from OpenAI Gym. +This example trains a SAC agent ([Soft Actor-Critic Algorithms and Applications](https://arxiv.org/abs/1812.05905)) on MuJoCo benchmarks from OpenAI gymnasium. ## Requirements diff --git a/examples/mujoco/reproduction/soft_actor_critic/train_soft_actor_critic.py b/examples/mujoco/reproduction/soft_actor_critic/train_soft_actor_critic.py index 548a2ae38..bdfdd0e21 100644 --- a/examples/mujoco/reproduction/soft_actor_critic/train_soft_actor_critic.py +++ b/examples/mujoco/reproduction/soft_actor_critic/train_soft_actor_critic.py @@ -1,4 +1,4 @@ -"""A training script of Soft Actor-Critic on OpenAI Gym Mujoco environments. +"""A training script of Soft Actor-Critic on OpenAI gymnasium Mujoco environments. This script follows the settings of https://arxiv.org/abs/1812.05905 as much as possible. @@ -9,8 +9,8 @@ import sys from distutils.version import LooseVersion -import gym -import gym.wrappers +import gymnasium +import gymnasium.wrappers import numpy as np import torch from torch import distributions, nn @@ -36,7 +36,7 @@ def main(): "--env", type=str, default="Hopper-v2", - help="OpenAI Gym MuJoCo env to perform algorithm on.", + help="OpenAI gymnasium MuJoCo env to perform algorithm on.", ) parser.add_argument( "--num-envs", type=int, default=1, help="Number of envs run in parallel." @@ -84,7 +84,7 @@ def main(): "--pretrained-type", type=str, default="best", choices=["best", "final"] ) parser.add_argument( - "--monitor", action="store_true", help="Wrap env with gym.wrappers.Monitor." + "--monitor", action="store_true", help="Wrap env with gymnasium.wrappers.Monitor." 
) parser.add_argument( "--log-interval", @@ -118,9 +118,9 @@ def main(): assert process_seeds.max() < 2**32 def make_env(process_idx, test): - env = gym.make(args.env) + env = gymnasium.make(args.env) # Unwrap TimiLimit wrapper - assert isinstance(env, gym.wrappers.TimeLimit) + assert isinstance(env, gymnasium.wrappers.TimeLimit) env = env.env # Use different random seeds for train and test envs process_seed = int(process_seeds[process_idx]) @@ -131,7 +131,7 @@ def make_env(process_idx, test): # Normalize action space to [-1, 1]^n env = pfrl.wrappers.NormalizeActionSpace(env) if args.monitor: - env = gym.wrappers.Monitor(env, args.outdir) + env = gymnasium.wrappers.Monitor(env, args.outdir) if args.render: env = pfrl.wrappers.Render(env) return env diff --git a/examples/mujoco/reproduction/td3/README.md b/examples/mujoco/reproduction/td3/README.md index a9503b03c..81c2748d9 100644 --- a/examples/mujoco/reproduction/td3/README.md +++ b/examples/mujoco/reproduction/td3/README.md @@ -1,6 +1,6 @@ # TD3 on MuJoCo benchmarks -This example trains a TD3 agent ([Addressing Function Approximation Error in Actor-Critic Methods](http://arxiv.org/abs/1802.09477)) on MuJoCo benchmarks from OpenAI Gym. +This example trains a TD3 agent ([Addressing Function Approximation Error in Actor-Critic Methods](http://arxiv.org/abs/1802.09477)) on MuJoCo benchmarks from OpenAI gymnasium. ## Requirements @@ -55,7 +55,7 @@ Each evaluation reports average return over 10 episodes without exploration nois Maximum evaluation scores, averaged over 10 trials (+/- standard deviation), are reported for each environment. Reported scores are taken from the "TD3" column of Table 1 of [Addressing Function Approximation Error in Actor-Critic Methods](http://arxiv.org/abs/1802.09477). -Although the original paper used v1 versions of MuJoCo envs, we used v2 as v1 are not supported by recent versions of OpenAI Gym. +Although the original paper used v1 versions of MuJoCo envs, we used v2 as v1 are not supported by recent versions of OpenAI gymnasium. | Environment | PFRL Score | Reported Score | | ------------------------- |:---------------------:|:---------------------:| @@ -73,7 +73,7 @@ Although the original paper used v1 versions of MuJoCo envs, we used v2 as v1 ar Average return of last 10 evaluation scores, averaged over 10 trials, are reported for each environment. Reported scores are taken from the "TD3" row of Table 2 of [Addressing Function Approximation Error in Actor-Critic Methods](http://arxiv.org/abs/1802.09477). -Although the original paper used v1 versions of MuJoCo envs, we used v2 as v1 are not supported by recent versions of OpenAI Gym. +Although the original paper used v1 versions of MuJoCo envs, we used v2 as v1 are not supported by recent versions of OpenAI gymnasium. | Environment | PFRL Score | Reported Score | | ------------------------- |:------------:|:--------------:| diff --git a/examples/mujoco/reproduction/td3/train_td3.py b/examples/mujoco/reproduction/td3/train_td3.py index 7913a3765..f57bec25e 100644 --- a/examples/mujoco/reproduction/td3/train_td3.py +++ b/examples/mujoco/reproduction/td3/train_td3.py @@ -1,4 +1,4 @@ -"""A training script of TD3 on OpenAI Gym Mujoco environments. +"""A training script of TD3 on OpenAI gymnasium Mujoco environments. This script follows the settings of http://arxiv.org/abs/1802.09477 as much as possible. 
@@ -8,8 +8,8 @@ import logging import sys -import gym -import gym.wrappers +import gymnasium +import gymnasium.wrappers import numpy as np import torch from torch import nn @@ -34,7 +34,7 @@ def main(): "--env", type=str, default="Hopper-v2", - help="OpenAI Gym MuJoCo env to perform algorithm on.", + help="OpenAI gymnasium MuJoCo env to perform algorithm on.", ) parser.add_argument("--seed", type=int, default=0, help="Random seed [0, 2 ** 32)") parser.add_argument( @@ -79,7 +79,7 @@ def main(): "--pretrained-type", type=str, default="best", choices=["best", "final"] ) parser.add_argument( - "--monitor", action="store_true", help="Wrap env with gym.wrappers.Monitor." + "--monitor", action="store_true", help="Wrap env with gymnasium.wrappers.Monitor." ) parser.add_argument( "--log-level", type=int, default=logging.INFO, help="Level of the root logger." @@ -95,9 +95,9 @@ def main(): utils.set_random_seed(args.seed) def make_env(test): - env = gym.make(args.env) + env = gymnasium.make(args.env) # Unwrap TimeLimit wrapper - assert isinstance(env, gym.wrappers.TimeLimit) + assert isinstance(env, gymnasium.wrappers.TimeLimit) env = env.env # Use different random seeds for train and test envs env_seed = 2**32 - 1 - args.seed if test else args.seed diff --git a/examples/mujoco/reproduction/trpo/README.md b/examples/mujoco/reproduction/trpo/README.md index 1841ee7e4..b2b176ece 100644 --- a/examples/mujoco/reproduction/trpo/README.md +++ b/examples/mujoco/reproduction/trpo/README.md @@ -1,6 +1,6 @@ # TRPO on MuJoCo benchmarks -This example trains a TRPO agent ([Trust Region Policy Optimization](https://arxiv.org/abs/1502.05477)) on MuJoCo benchmarks from OpenAI Gym. +This example trains a TRPO agent ([Trust Region Policy Optimization](https://arxiv.org/abs/1502.05477)) on MuJoCo benchmarks from OpenAI gymnasium. We follow the training and evaluation settings of [Deep Reinforcement Learning that Matters](https://arxiv.org/abs/1709.06560), which provides thorough, highly tuned benchmark results. @@ -37,7 +37,7 @@ To view the full list of options, either view the code or run the example with t ## Known differences -- We used version v2 of the environments whereas the original results were reported for version v1, however this doesn't seem to introduce significant differences: https://github.com/openai/gym/pull/834 +- We used version v2 of the environments whereas the original results were reported for version v1, however this doesn't seem to introduce significant differences: https://github.com/openai/gymnasium/pull/834 ## Results diff --git a/examples/mujoco/reproduction/trpo/train_trpo.py b/examples/mujoco/reproduction/trpo/train_trpo.py index 0a9de705b..c8a7715fe 100644 --- a/examples/mujoco/reproduction/trpo/train_trpo.py +++ b/examples/mujoco/reproduction/trpo/train_trpo.py @@ -1,4 +1,4 @@ -"""A training script of TRPO on OpenAI Gym Mujoco environments. +"""A training script of TRPO on OpenAI gymnasium Mujoco environments. This script follows the settings of https://arxiv.org/abs/1709.06560 as much as possible. @@ -6,9 +6,9 @@ import argparse import logging -import gym -import gym.spaces -import gym.wrappers +import gymnasium +import gymnasium.spaces +import gymnasium.wrappers import torch from torch import nn @@ -21,7 +21,7 @@ def main(): parser.add_argument( "--gpu", type=int, default=0, help="GPU device ID. Set to -1 to use CPUs only." 
) - parser.add_argument("--env", type=str, default="Hopper-v2", help="Gym Env ID") + parser.add_argument("--env", type=str, default="Hopper-v2", help="gymnasium Env ID") parser.add_argument("--seed", type=int, default=0, help="Random seed [0, 2 ** 32)") parser.add_argument( "--outdir", @@ -82,7 +82,7 @@ def main(): "--monitor", action="store_true", help=( - "Monitor the env by gym.wrappers.Monitor." + "Monitor the env by gymnasium.wrappers.Monitor." " Videos and additional log will be saved." ), ) @@ -96,14 +96,14 @@ def main(): args.outdir = pfrl.experiments.prepare_output_dir(args, args.outdir) def make_env(test): - env = gym.make(args.env) + env = gymnasium.make(args.env) # Use different random seeds for train and test envs env_seed = 2**32 - 1 - args.seed if test else args.seed env.seed(env_seed) # Cast observations to float32 because our model uses float32 env = pfrl.wrappers.CastObservationToFloat32(env) if args.monitor: - env = gym.wrappers.Monitor(env, args.outdir) + env = gymnasium.wrappers.Monitor(env, args.outdir) if args.render: env = pfrl.wrappers.Render(env) return env @@ -115,7 +115,7 @@ def make_env(test): print("Observation space:", obs_space) print("Action space:", action_space) - assert isinstance(obs_space, gym.spaces.Box) + assert isinstance(obs_space, gymnasium.spaces.Box) # Normalize observations based on their empirical mean and variance obs_normalizer = pfrl.nn.EmpiricalNormalization( diff --git a/examples/optuna/optuna_dqn_obs1d.py b/examples/optuna/optuna_dqn_obs1d.py index c21e70e8d..dbe9ec741 100644 --- a/examples/optuna/optuna_dqn_obs1d.py +++ b/examples/optuna/optuna_dqn_obs1d.py @@ -14,7 +14,7 @@ import os import random -import gym +import gymnasium import torch.optim as optim try: @@ -54,9 +54,9 @@ def _objective_core( test_seed = 2**31 - 1 - seed def make_env(test=False): - env = gym.make(env_id) + env = gymnasium.make(env_id) - if not isinstance(env.observation_space, gym.spaces.Box): + if not isinstance(env.observation_space, gymnasium.spaces.Box): raise ValueError( "Supported only Box observation environments, but given: {}".format( env.observation_space @@ -68,7 +68,7 @@ def make_env(test=False): env.observation_space.shape ) ) - if not isinstance(env.action_space, gym.spaces.Discrete): + if not isinstance(env.action_space, gymnasium.spaces.Discrete): raise ValueError( "Supported only discrete action environments, but given: {}".format( env.action_space @@ -244,7 +244,7 @@ def main(): "--env", type=str, default="LunarLander-v2", - help="OpenAI Gym Environment ID.", + help="OpenAI gymnasium Environment ID.", ) parser.add_argument( "--outdir", diff --git a/examples/quickstart/quickstart.ipynb b/examples/quickstart/quickstart.ipynb index b31d0fe2e..5e2aa43d4 100644 --- a/examples/quickstart/quickstart.ipynb +++ b/examples/quickstart/quickstart.ipynb @@ -15,7 +15,7 @@ "\n", "If you have already installed PFRL, let's begin!\n", "\n", - "First, you need to import necessary modules. The module name of PFRL is `pfrl`. Let's import `torch`, `gym`, and `numpy` as well since they are used later." + "First, you need to import necessary modules. The module name of PFRL is `pfrl`. Let's import `torch`, `gymnasium`, and `numpy` as well since they are used later." ] }, { @@ -27,7 +27,7 @@ "import pfrl\n", "import torch\n", "import torch.nn\n", - "import gym\n", + "import gymnasium\n", "import numpy" ] }, @@ -35,7 +35,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "PFRL can be used for any problems if they are modeled as \"environments\". 
[OpenAI Gym](https://github.com/openai/gym) provides various kinds of benchmark environments and defines the common interface among them. PFRL uses a subset of the interface. Specifically, an environment must define its observation space and action space and have at least two methods: `reset` and `step`.\n", + "PFRL can be used for any problems if they are modeled as \"environments\". [OpenAI gymnasium](https://github.com/openai/gymnasium) provides various kinds of benchmark environments and defines the common interface among them. PFRL uses a subset of the interface. Specifically, an environment must define its observation space and action space and have at least two methods: `reset` and `step`.\n", "\n", "- `env.reset` will reset the environment to the initial state and return the initial observation.\n", "- `env.step` will execute a given action, move to the next state and return four values:\n", @@ -73,7 +73,7 @@ } ], "source": [ - "env = gym.make('CartPole-v0')\n", + "env = gymnasium.make('CartPole-v0')\n", "print('observation space:', env.observation_space)\n", "print('action space:', env.action_space)\n", "\n", diff --git a/examples/slimevolley/README.md b/examples/slimevolley/README.md index a3a4eac8f..7eb0afc43 100644 --- a/examples/slimevolley/README.md +++ b/examples/slimevolley/README.md @@ -1,6 +1,6 @@ # Slime Volleyball -This directory contains an example script that learns to play Slime Volleyball using the environment `SlimeVolley-v0` of [slimevolleygym](https://github.com/hardmaru/slimevolleygym). +This directory contains an example script that learns to play Slime Volleyball using the environment `SlimeVolley-v0` of [slimevolleygymnasium](https://github.com/hardmaru/slimevolleygymnasium). ![SlimeVolley](assets/slimevolley.gif) @@ -10,8 +10,8 @@ This directory contains an example script that learns to play Slime Volleyball u ## Requirements -- `slimevolleygym` (https://github.com/hardmaru/slimevolleygym) - - You can install from PyPI: `pip install slimevolleygym==0.1.0` +- `slimevolleygymnasium` (https://github.com/hardmaru/slimevolleygymnasium) + - You can install from PyPI: `pip install slimevolleygymnasium==0.1.0` ## Algorithm @@ -37,7 +37,7 @@ python examples/slimevolley/train_rainbow.py --demo --render --load > i) % 2 for i in range(self.orig_action_space.n)] @@ -129,10 +129,10 @@ def main(): def make_env(test): if "SlimeVolley" in args.env: - # You need to install slimevolleygym - import slimevolleygym # NOQA + # You need to install slimevolleygymnasium + import slimevolleygymnasium # NOQA - env = gym.make(args.env) + env = gymnasium.make(args.env) # Use different random seeds for train and test envs env_seed = test_seed if test else train_seed env.seed(int(env_seed)) @@ -142,7 +142,7 @@ def make_env(test): ) if args.render: env = pfrl.wrappers.Render(env) - if isinstance(env.action_space, gym.spaces.MultiBinary): + if isinstance(env.action_space, gymnasium.spaces.MultiBinary): env = MultiBinaryAsDiscreteAction(env) return env diff --git a/examples_tests/gym/test_categorical_dqn.sh b/examples_tests/gym/test_categorical_dqn.sh index db8c8505f..28181fb16 100644 --- a/examples_tests/gym/test_categorical_dqn.sh +++ b/examples_tests/gym/test_categorical_dqn.sh @@ -6,7 +6,7 @@ outdir=$(mktemp -d) gpu="$1" -# gym/categorical_dqn -python examples/gym/train_categorical_dqn_gym.py --steps 100 --replay-start-size 50 --outdir $outdir/gym/categorical_dqn --gpu $gpu -model=$(find $outdir/gym/categorical_dqn -name "*_finish") -python examples/gym/train_categorical_dqn_gym.py --demo 
--load $model --eval-n-runs 1 --outdir $outdir/temp --gpu $gpu +# gymnasium/categorical_dqn +python examples/gymnasium/train_categorical_dqn_gymnasium.py --steps 100 --replay-start-size 50 --outdir $outdir/gymnasium/categorical_dqn --gpu $gpu +model=$(find $outdir/gymnasium/categorical_dqn -name "*_finish") +python examples/gymnasium/train_categorical_dqn_gymnasium.py --demo --load $model --eval-n-runs 1 --outdir $outdir/temp --gpu $gpu diff --git a/examples_tests/gym/test_dqn.sh b/examples_tests/gym/test_dqn.sh index c4452538c..fca628ddf 100644 --- a/examples_tests/gym/test_dqn.sh +++ b/examples_tests/gym/test_dqn.sh @@ -6,7 +6,7 @@ outdir=$(mktemp -d) gpu="$1" -# gym/dqn -python examples/gym/train_dqn_gym.py --steps 100 --replay-start-size 50 --outdir $outdir/gym/dqn --gpu $gpu -model=$(find $outdir/gym/dqn -name "*_finish") -python examples/gym/train_dqn_gym.py --demo --load $model --eval-n-runs 1 --outdir $outdir/temp --gpu $gpu +# gymnasium/dqn +python examples/gymnasium/train_dqn_gymnasium.py --steps 100 --replay-start-size 50 --outdir $outdir/gymnasium/dqn --gpu $gpu +model=$(find $outdir/gymnasium/dqn -name "*_finish") +python examples/gymnasium/train_dqn_gymnasium.py --demo --load $model --eval-n-runs 1 --outdir $outdir/temp --gpu $gpu diff --git a/examples_tests/gym/test_reinforce.sh b/examples_tests/gym/test_reinforce.sh index 77a36bc89..f5a8d1e86 100644 --- a/examples_tests/gym/test_reinforce.sh +++ b/examples_tests/gym/test_reinforce.sh @@ -6,7 +6,7 @@ outdir=$(mktemp -d) gpu="$1" -# gym/reinforce -python examples/gym/train_reinforce_gym.py --steps 100 --batchsize 1 --outdir $outdir/gym/reinforce --gpu $gpu -model=$(find $outdir/gym/reinforce -name "*_finish") -python examples/gym/train_reinforce_gym.py --demo --load $model --eval-n-runs 1 --outdir $outdir/temp --gpu $gpu +# gymnasium/reinforce +python examples/gymnasium/train_reinforce_gymnasium.py --steps 100 --batchsize 1 --outdir $outdir/gymnasium/reinforce --gpu $gpu +model=$(find $outdir/gymnasium/reinforce -name "*_finish") +python examples/gymnasium/train_reinforce_gymnasium.py --demo --load $model --eval-n-runs 1 --outdir $outdir/temp --gpu $gpu diff --git a/examples_tests/slimevolley/test_rainbow.sh b/examples_tests/slimevolley/test_rainbow.sh index 605f19b08..e2c48c133 100644 --- a/examples_tests/slimevolley/test_rainbow.sh +++ b/examples_tests/slimevolley/test_rainbow.sh @@ -7,7 +7,7 @@ outdir=$(mktemp -d) gpu="$1" # slimevolley/rainbow -# Use CartPole-v0 to test without installing slimevolleygym +# Use CartPole-v0 to test without installing slimevolleygymnasium python examples/slimevolley/train_rainbow.py --gpu $gpu --steps 100 --outdir $outdir/slimevolley/rainbow --env CartPole-v0 model=$(find $outdir/slimevolley/rainbow -name "*_finish") python examples/slimevolley/train_rainbow.py --demo --load $model --eval-n-episodes 1 --outdir $outdir/temp --gpu $gpu --env CartPole-v0 diff --git a/pfrl/envs/abc.py b/pfrl/envs/abc.py index 29b7b8e29..7322fdc8d 100644 --- a/pfrl/envs/abc.py +++ b/pfrl/envs/abc.py @@ -1,5 +1,5 @@ import numpy as np -from gym import spaces +from gymnasium import spaces from pfrl import env diff --git a/pfrl/envs/multiprocess_vector_env.py b/pfrl/envs/multiprocess_vector_env.py index a993e1940..2b3decd12 100644 --- a/pfrl/envs/multiprocess_vector_env.py +++ b/pfrl/envs/multiprocess_vector_env.py @@ -41,7 +41,7 @@ class MultiprocessVectorEnv(pfrl.env.VectorEnv): Args: env_fns (list of callable): List of callables, each of which - returns gym.Env that is run in its own subprocess. 
+ returns gymnasium.Env that is run in its own subprocess. """ def __init__(self, env_fns): diff --git a/pfrl/envs/serial_vector_env.py b/pfrl/envs/serial_vector_env.py index 73104adfe..025448f36 100644 --- a/pfrl/envs/serial_vector_env.py +++ b/pfrl/envs/serial_vector_env.py @@ -10,7 +10,7 @@ class SerialVectorEnv(pfrl.env.VectorEnv): use MultiprocessVectorEnv if possible. Args: - env_fns (list of gym.Env): List of gym.Env. + env_fns (list of gymnasium.Env): List of gymnasium.Env. """ def __init__(self, envs): diff --git a/pfrl/utils/pretrained_models.py b/pfrl/utils/pretrained_models.py index 3c7e64d02..37e7bc5a0 100644 --- a/pfrl/utils/pretrained_models.py +++ b/pfrl/utils/pretrained_models.py @@ -162,7 +162,7 @@ def download_model(alg, env, model_type="best"): Args: alg (string): URL to download from. - env (string): Gym Environment name. + env (string): gymnasium Environment name. model_type (string): Either `best` or `final`. Returns: str: Path to the downloaded file. diff --git a/pfrl/wrappers/atari_wrappers.py b/pfrl/wrappers/atari_wrappers.py index 2a4977952..ce87fc1e1 100644 --- a/pfrl/wrappers/atari_wrappers.py +++ b/pfrl/wrappers/atari_wrappers.py @@ -4,9 +4,9 @@ from collections import deque -import gym +import gymnasium import numpy as np -from gym import spaces +from gymnasium import spaces from packaging import version import pfrl @@ -20,13 +20,13 @@ _is_cv2_available = False -class NoopResetEnv(gym.Wrapper): +class NoopResetEnv(gymnasium.Wrapper): def __init__(self, env, noop_max=30): """Sample initial states by taking random number of no-ops on reset. No-op is assumed to be action 0. """ - gym.Wrapper.__init__(self, env) + gymnasium.Wrapper.__init__(self, env) self.noop_max = noop_max self.override_num_noops = None self.noop_action = 0 @@ -38,7 +38,7 @@ def reset(self, **kwargs): if self.override_num_noops is not None: noops = self.override_num_noops else: - if version.parse(gym.__version__) >= version.parse("0.24.0"): + if version.parse(gymnasium.__version__) >= version.parse("0.24.0"): noops = self.unwrapped.np_random.integers(1, self.noop_max + 1) else: noops = self.unwrapped.np_random.randint(1, self.noop_max + 1) @@ -54,10 +54,10 @@ def step(self, ac): return self.env.step(ac) -class FireResetEnv(gym.Wrapper): +class FireResetEnv(gymnasium.Wrapper): def __init__(self, env): """Take action on reset for envs that are fixed until firing.""" - gym.Wrapper.__init__(self, env) + gymnasium.Wrapper.__init__(self, env) assert env.unwrapped.get_action_meanings()[1] == "FIRE" assert len(env.unwrapped.get_action_meanings()) >= 3 @@ -75,13 +75,13 @@ def step(self, ac): return self.env.step(ac) -class EpisodicLifeEnv(gym.Wrapper): +class EpisodicLifeEnv(gymnasium.Wrapper): def __init__(self, env): """Make end-of-life == end-of-episode, but only reset on true game end. Done by DeepMind for the DQN and co. since it helps value estimation. 
""" - gym.Wrapper.__init__(self, env) + gymnasium.Wrapper.__init__(self, env) self.lives = 0 self.needs_real_reset = True @@ -115,10 +115,10 @@ def reset(self, **kwargs): return obs -class MaxAndSkipEnv(gym.Wrapper): +class MaxAndSkipEnv(gymnasium.Wrapper): def __init__(self, env, skip=4): """Return only every `skip`-th frame""" - gym.Wrapper.__init__(self, env) + gymnasium.Wrapper.__init__(self, env) # most recent raw observations (for max pooling across time steps) self._obs_buffer = np.zeros((2,) + env.observation_space.shape, dtype=np.uint8) self._skip = skip @@ -146,16 +146,16 @@ def reset(self, **kwargs): return self.env.reset(**kwargs) -class ClipRewardEnv(gym.RewardWrapper): +class ClipRewardEnv(gymnasium.RewardWrapper): def __init__(self, env): - gym.RewardWrapper.__init__(self, env) + gymnasium.RewardWrapper.__init__(self, env) def reward(self, reward): """Bin reward to {+1, 0, -1} by its sign.""" return np.sign(reward) -class WarpFrame(gym.ObservationWrapper): +class WarpFrame(gymnasium.ObservationWrapper): def __init__(self, env, channel_order="hwc"): """Warp frames to 84x84 as done in the Nature paper and later work. @@ -166,7 +166,7 @@ def __init__(self, env, channel_order="hwc"): "Cannot import cv2 module. Please install OpenCV-Python to use" " WarpFrame." ) - gym.ObservationWrapper.__init__(self, env) + gymnasium.ObservationWrapper.__init__(self, env) self.width = 84 self.height = 84 shape = { @@ -185,7 +185,7 @@ def observation(self, frame): return frame.reshape(self.observation_space.low.shape) -class FrameStack(gym.Wrapper): +class FrameStack(gymnasium.Wrapper): def __init__(self, env, k, channel_order="hwc"): """Stack k last frames. @@ -195,7 +195,7 @@ def __init__(self, env, k, channel_order="hwc"): -------- baselines.common.atari_wrappers.LazyFrames """ - gym.Wrapper.__init__(self, env) + gymnasium.Wrapper.__init__(self, env) self.k = k self.frames = deque([], maxlen=k) self.stack_axis = {"hwc": 2, "chw": 0}[channel_order] @@ -222,7 +222,7 @@ def _get_ob(self): return LazyFrames(list(self.frames), stack_axis=self.stack_axis) -class ScaledFloatFrame(gym.ObservationWrapper): +class ScaledFloatFrame(gymnasium.ObservationWrapper): """Divide frame values by 255.0 and return them as np.float32. 
Especially, when the original env.observation_space is np.uint8, @@ -231,7 +231,7 @@ class ScaledFloatFrame(gym.ObservationWrapper): def __init__(self, env): assert isinstance(env.observation_space, spaces.Box) - gym.ObservationWrapper.__init__(self, env) + gymnasium.ObservationWrapper.__init__(self, env) self.scale = 255.0 @@ -272,11 +272,11 @@ def __array__(self, dtype=None): return out -class FlickerFrame(gym.ObservationWrapper): +class FlickerFrame(gymnasium.ObservationWrapper): """Stochastically flicker frames.""" def __init__(self, env): - gym.ObservationWrapper.__init__(self, env) + gymnasium.ObservationWrapper.__init__(self, env) def observation(self, observation): if self.unwrapped.np_random.rand() < 0.5: @@ -286,9 +286,9 @@ def observation(self, observation): def make_atari(env_id, max_frames=30 * 60 * 60): - env = gym.make(env_id) + env = gymnasium.make(env_id) assert "NoFrameskip" in env.spec.id - assert isinstance(env, gym.wrappers.TimeLimit) + assert isinstance(env, gymnasium.wrappers.TimeLimit) # Unwrap TimeLimit wrapper because we use our own time limits env = env.env if max_frames: diff --git a/pfrl/wrappers/cast_observation.py b/pfrl/wrappers/cast_observation.py index 4519e6fd4..c3444d3cd 100644 --- a/pfrl/wrappers/cast_observation.py +++ b/pfrl/wrappers/cast_observation.py @@ -1,8 +1,8 @@ -import gym +import gymnasium import numpy as np -class CastObservation(gym.ObservationWrapper): +class CastObservation(gymnasium.ObservationWrapper): """Cast observations to a given type. Args: diff --git a/pfrl/wrappers/continuing_time_limit.py b/pfrl/wrappers/continuing_time_limit.py index 04d7bec4f..3e47dd0da 100644 --- a/pfrl/wrappers/continuing_time_limit.py +++ b/pfrl/wrappers/continuing_time_limit.py @@ -1,10 +1,10 @@ -import gym +import gymnasium -class ContinuingTimeLimit(gym.Wrapper): +class ContinuingTimeLimit(gymnasium.Wrapper): """TimeLimit wrapper for continuing environments. - This is similar gym.wrappers.TimeLimit, which sets a time limit for + This is similar gymnasium.wrappers.TimeLimit, which sets a time limit for each episode, except that done=False is returned and that info['needs_reset'] is set to True when past the limit. @@ -13,7 +13,7 @@ class ContinuingTimeLimit(gym.Wrapper): key and its value is True. Args: - env (gym.Env): Env to wrap. + env (gymnasium.Env): Env to wrap. max_episode_steps (int): Maximum number of timesteps during an episode, after which the env needs a reset. 
""" diff --git a/pfrl/wrappers/monitor.py b/pfrl/wrappers/monitor.py index 4e8e842da..ad5e1f3ba 100644 --- a/pfrl/wrappers/monitor.py +++ b/pfrl/wrappers/monitor.py @@ -2,32 +2,32 @@ from logging import getLogger try: - from gym.wrappers import Monitor as _GymMonitor + from gymnasium.wrappers import Monitor as _gymnasiumMonitor except ImportError: class _Stub: def __init__(self, *args, **kwargs): - raise RuntimeError("Monitor is not available in this version of gym") + raise RuntimeError("Monitor is not available in this version of gymnasium") - class _GymMonitor(_Stub): # type: ignore + class _gymnasiumMonitor(_Stub): # type: ignore pass - class _GymStatsRecorder(_Stub): + class _gymnasiumStatsRecorder(_Stub): pass else: - from gym.wrappers.monitoring.stats_recorder import StatsRecorder as _GymStatsRecorder # type: ignore # isort: skip # noqa: E501 + from gymnasium.wrappers.monitoring.stats_recorder import StatsRecorder as _gymnasiumStatsRecorder # type: ignore # isort: skip # noqa: E501 -class Monitor(_GymMonitor): +class Monitor(_gymnasiumMonitor): """`Monitor` with PFRL's `ContinuingTimeLimit` support. `Agent` in PFRL might reset the env even when `done=False` if `ContinuingTimeLimit` returns `info['needs_reset']=True`, - which is not expected for `gym.Monitor`. + which is not expected for `gymnasium.Monitor`. For details, see - https://github.com/openai/gym/blob/master/gym/wrappers/monitor.py + https://github.com/openai/gymnasium/blob/master/gymnasium/wrappers/monitor.py """ def _start( @@ -66,11 +66,11 @@ def _start( return ret -class _StatsRecorder(_GymStatsRecorder): +class _StatsRecorder(_gymnasiumStatsRecorder): """`StatsRecorder` with PFRL's `ContinuingTimeLimit` support. For details, see - https://github.com/openai/gym/blob/master/gym/wrappers/monitoring/stats_recorder.py + https://github.com/openai/gymnasium/blob/master/gymnasium/wrappers/monitoring/stats_recorder.py """ def __init__( diff --git a/pfrl/wrappers/normalize_action_space.py b/pfrl/wrappers/normalize_action_space.py index dbf0ed24f..3e485c91f 100644 --- a/pfrl/wrappers/normalize_action_space.py +++ b/pfrl/wrappers/normalize_action_space.py @@ -1,15 +1,15 @@ -import gym -import gym.spaces +import gymnasium +import gymnasium.spaces import numpy as np -class NormalizeActionSpace(gym.ActionWrapper): +class NormalizeActionSpace(gymnasium.ActionWrapper): """Normalize a Box action space to [-1, 1]^n.""" def __init__(self, env): super().__init__(env) - assert isinstance(env.action_space, gym.spaces.Box) - self.action_space = gym.spaces.Box( + assert isinstance(env.action_space, gymnasium.spaces.Box) + self.action_space = gymnasium.spaces.Box( low=-np.ones_like(env.action_space.low), high=np.ones_like(env.action_space.low), ) diff --git a/pfrl/wrappers/randomize_action.py b/pfrl/wrappers/randomize_action.py index 9390f33bf..407cb1899 100644 --- a/pfrl/wrappers/randomize_action.py +++ b/pfrl/wrappers/randomize_action.py @@ -1,21 +1,21 @@ -import gym +import gymnasium import numpy as np -class RandomizeAction(gym.ActionWrapper): +class RandomizeAction(gymnasium.ActionWrapper): """Apply a random action instead of the one sent by the agent. This wrapper can be used to make a stochastic env. The common use is for evaluation in Atari environments, where actions are replaced with random ones with a low probability. - Only gym.spaces.Discrete is supported as an action space. + Only gymnasium.spaces.Discrete is supported as an action space. 
For exploration during training, use explorers like pfrl.explorers.ConstantEpsilonGreedy instead of this wrapper. Args: - env (gym.Env): Env to wrap. + env (gymnasium.Env): Env to wrap. random_fraction (float): Fraction of actions that will be replaced with a random action. It must be in [0, 1]. """ @@ -24,8 +24,8 @@ def __init__(self, env, random_fraction): super().__init__(env) assert 0 <= random_fraction <= 1 assert isinstance( - env.action_space, gym.spaces.Discrete - ), "RandomizeAction supports only gym.spaces.Discrete as an action space" + env.action_space, gymnasium.spaces.Discrete + ), "RandomizeAction supports only gymnasium.spaces.Discrete as an action space" self._random_fraction = random_fraction self._np_random = np.random.RandomState() diff --git a/pfrl/wrappers/render.py b/pfrl/wrappers/render.py index 6dc0c0384..dbd54de26 100644 --- a/pfrl/wrappers/render.py +++ b/pfrl/wrappers/render.py @@ -1,11 +1,11 @@ -import gym +import gymnasium -class Render(gym.Wrapper): +class Render(gymnasium.Wrapper): """Render env by calling its render method. Args: - env (gym.Env): Env to wrap. + env (gymnasium.Env): Env to wrap. **kwargs: Keyword arguments passed to the render method. """ diff --git a/pfrl/wrappers/scale_reward.py b/pfrl/wrappers/scale_reward.py index 784616da5..7d309c20e 100644 --- a/pfrl/wrappers/scale_reward.py +++ b/pfrl/wrappers/scale_reward.py @@ -1,7 +1,7 @@ -import gym +import gymnasium -class ScaleReward(gym.RewardWrapper): +class ScaleReward(gymnasium.RewardWrapper): """Scale reward by a scale factor. Args: diff --git a/pfrl/wrappers/vector_frame_stack.py b/pfrl/wrappers/vector_frame_stack.py index 5596f5b87..c01c18007 100644 --- a/pfrl/wrappers/vector_frame_stack.py +++ b/pfrl/wrappers/vector_frame_stack.py @@ -1,14 +1,14 @@ from collections import deque import numpy as np -from gym import spaces +from gymnasium import spaces from pfrl.env import VectorEnv from pfrl.wrappers.atari_wrappers import LazyFrames class VectorEnvWrapper(VectorEnv): - """VectorEnv analog to gym.Wrapper.""" + """VectorEnv analog to gymnasium.Wrapper.""" def __init__(self, env): self.env = env diff --git a/requirements.txt b/requirements.txt index 45b6e8b0b..2ac56ecd7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ torch>=1.3.0 -gym>=0.9.7 +gymnasium>=0.9.7 numpy>=1.10.4 filelock pillow diff --git a/setup.cfg b/setup.cfg index 808dfd412..8504ffc24 100644 --- a/setup.cfg +++ b/setup.cfg @@ -3,7 +3,7 @@ [mypy-torch.*] ignore_missing_imports = True -[mypy-gym.*] +[mypy-gymnasium.*] ignore_missing_imports = True [mypy-numpy.*] @@ -33,7 +33,7 @@ ignore_missing_imports = True [mypy-roboschool.*] ignore_missing_imports = True -[mypy-slimevolleygym.*] +[mypy-slimevolleygymnasium.*] ignore_missing_imports = True [mypy-optuna.*] diff --git a/setup.py b/setup.py index 84ca77dba..d1910e0b9 100644 --- a/setup.py +++ b/setup.py @@ -4,7 +4,7 @@ install_requires = [ 'torch>=1.3.0', - 'gymnasium[all]', + 'gymnasiumnasium[all]', 'numpy>=1.11.0', 'pillow', 'filelock', diff --git a/tests/envs_tests/test_vector_envs.py b/tests/envs_tests/test_vector_envs.py index 768c09cc1..89515c096 100644 --- a/tests/envs_tests/test_vector_envs.py +++ b/tests/envs_tests/test_vector_envs.py @@ -1,4 +1,4 @@ -import gym +import gymnasium import numpy as np import pytest @@ -21,16 +21,16 @@ def setUp(self, num_envs, env_id, random_seed_offset, vector_env_to_test): # Init VectorEnv to test if self.vector_env_to_test == "SerialVectorEnv": self.vec_env = pfrl.envs.SerialVectorEnv( - [gym.make(self.env_id) for _ 
in range(self.num_envs)] + [gymnasium.make(self.env_id) for _ in range(self.num_envs)] ) elif self.vector_env_to_test == "MultiprocessVectorEnv": self.vec_env = pfrl.envs.MultiprocessVectorEnv( - [(lambda: gym.make(self.env_id)) for _ in range(self.num_envs)] + [(lambda: gymnasium.make(self.env_id)) for _ in range(self.num_envs)] ) else: assert False # Init envs to compare against - self.envs = [gym.make(self.env_id) for _ in range(self.num_envs)] + self.envs = [gymnasium.make(self.env_id) for _ in range(self.num_envs)] def teardown_method(self): # Delete so that all the subprocesses are joined diff --git a/tests/wrappers_tests/test_atari_wrappers.py b/tests/wrappers_tests/test_atari_wrappers.py index d3f47986c..77d4a3ea6 100644 --- a/tests/wrappers_tests/test_atari_wrappers.py +++ b/tests/wrappers_tests/test_atari_wrappers.py @@ -4,8 +4,8 @@ from unittest import mock -import gym -import gym.spaces +import gymnasium +import gymnasium.spaces import numpy as np import pytest @@ -46,8 +46,8 @@ def dtyped_rand(): ) for _ in range(steps) ] - env.action_space = gym.spaces.Discrete(2) - env.observation_space = gym.spaces.Box( + env.action_space = gymnasium.spaces.Discrete(2) + env.observation_space = gymnasium.spaces.Box( low=low, high=high, shape=(1, 84, 84), dtype=dtype ) return env @@ -118,8 +118,8 @@ def dtyped_rand(): ) for _ in range(steps) ] - env.action_space = gym.spaces.Discrete(2) - env.observation_space = gym.spaces.Box( + env.action_space = gymnasium.spaces.Discrete(2) + env.observation_space = gymnasium.spaces.Box( low=low, high=high, shape=(1, 84, 84), dtype=dtype ) return env diff --git a/tests/wrappers_tests/test_cast_observation.py b/tests/wrappers_tests/test_cast_observation.py index f6fac6269..06ad13f3b 100644 --- a/tests/wrappers_tests/test_cast_observation.py +++ b/tests/wrappers_tests/test_cast_observation.py @@ -1,4 +1,4 @@ -import gym +import gymnasium import numpy as np import pytest @@ -8,7 +8,7 @@ @pytest.mark.parametrize("env_id", ["CartPole-v1", "Pendulum-v0"]) @pytest.mark.parametrize("dtype", [np.float16, np.float32, np.float64]) def test_cast_observation(env_id, dtype): - env = pfrl.wrappers.CastObservation(gym.make(env_id), dtype=dtype) + env = pfrl.wrappers.CastObservation(gymnasium.make(env_id), dtype=dtype) rtol = 1e-3 if dtype == np.float16 else 1e-7 obs = env.reset() @@ -25,7 +25,7 @@ def test_cast_observation(env_id, dtype): @pytest.mark.parametrize("env_id", ["CartPole-v1", "Pendulum-v0"]) def test_cast_observation_to_float32(env_id): - env = pfrl.wrappers.CastObservationToFloat32(gym.make(env_id)) + env = pfrl.wrappers.CastObservationToFloat32(gymnasium.make(env_id)) obs = env.reset() assert env.original_observation.dtype == np.float64 diff --git a/tests/wrappers_tests/test_monitor.py b/tests/wrappers_tests/test_monitor.py index ba65e9cc9..2151590e9 100644 --- a/tests/wrappers_tests/test_monitor.py +++ b/tests/wrappers_tests/test_monitor.py @@ -2,9 +2,9 @@ import shutil import tempfile -import gym +import gymnasium import pytest -from gym.wrappers import TimeLimit +from gymnasium.wrappers import TimeLimit import pfrl @@ -13,7 +13,7 @@ def test_monitor(n_episodes): steps = 15 - env = gym.make("CartPole-v1") + env = gymnasium.make("CartPole-v1") # unwrap default TimeLimit and wrap with new one to simulate done=True # at step 5 assert isinstance(env, TimeLimit) diff --git a/tests/wrappers_tests/test_randomize_action.py b/tests/wrappers_tests/test_randomize_action.py index 3d7826ef4..8b48f72d1 100644 --- a/tests/wrappers_tests/test_randomize_action.py +++ 
b/tests/wrappers_tests/test_randomize_action.py @@ -1,15 +1,15 @@ -import gym -import gym.spaces +import gymnasium +import gymnasium.spaces import numpy as np import pytest import pfrl -class ActionRecordingEnv(gym.Env): +class ActionRecordingEnv(gymnasium.Env): - observation_space = gym.spaces.Box(low=-1, high=1, shape=(1,)) - action_space = gym.spaces.Discrete(3) + observation_space = gymnasium.spaces.Box(low=-1, high=1, shape=(1,)) + action_space = gymnasium.spaces.Discrete(3) def __init__(self): self.past_actions = [] diff --git a/tests/wrappers_tests/test_scale_reward.py b/tests/wrappers_tests/test_scale_reward.py index 027287461..0fade3b57 100644 --- a/tests/wrappers_tests/test_scale_reward.py +++ b/tests/wrappers_tests/test_scale_reward.py @@ -1,4 +1,4 @@ -import gym +import gymnasium import numpy as np import pytest @@ -8,7 +8,7 @@ @pytest.mark.parametrize("env_id", ["CartPole-v1", "MountainCar-v0"]) @pytest.mark.parametrize("scale", [1.0, 0.1]) def test_scale_reward(env_id, scale): - env = pfrl.wrappers.ScaleReward(gym.make(env_id), scale=scale) + env = pfrl.wrappers.ScaleReward(gymnasium.make(env_id), scale=scale) assert env.original_reward is None np.testing.assert_allclose(env.scale, scale) diff --git a/tests/wrappers_tests/test_vector_frame_stack.py b/tests/wrappers_tests/test_vector_frame_stack.py index ef45e63c8..1f649c400 100644 --- a/tests/wrappers_tests/test_vector_frame_stack.py +++ b/tests/wrappers_tests/test_vector_frame_stack.py @@ -2,8 +2,8 @@ import unittest from unittest import mock -import gym -import gym.spaces +import gymnasium +import gymnasium.spaces import numpy as np import pytest @@ -45,8 +45,8 @@ def make_env(idx): ) for _ in range(steps) ] - env.action_space = gym.spaces.Discrete(2) - env.observation_space = gym.spaces.Box( + env.action_space = gymnasium.spaces.Discrete(2) + env.observation_space = gymnasium.spaces.Box( low=0, high=255, shape=(1, 84, 84), dtype=np.uint8 ) return env From e1d7ead345c4b76a6ad8c3eaf037aa5108c988f7 Mon Sep 17 00:00:00 2001 From: Prabhat Date: Fri, 14 Apr 2023 22:34:36 -0500 Subject: [PATCH 03/26] modifies calls to env step to use truncations --- examples/grasping/train_dqn_batch_grasping.py | 4 +- examples/quickstart/quickstart.ipynb | 16 +++--- pfrl/envs/multiprocess_vector_env.py | 4 +- pfrl/envs/serial_vector_env.py | 4 +- pfrl/experiments/evaluator.py | 12 ++--- pfrl/experiments/train_agent.py | 8 +-- pfrl/experiments/train_agent_async.py | 8 +-- pfrl/experiments/train_agent_batch.py | 10 ++-- pfrl/utils/env_modifiers.py | 12 ++--- pfrl/wrappers/atari_wrappers.py | 26 ++++----- pfrl/wrappers/continuing_time_limit.py | 2 +- pfrl/wrappers/vector_frame_stack.py | 4 +- tests/envs_tests/test_vector_envs.py | 7 +-- tests/experiments_tests/test_evaluator.py | 54 +++++++++---------- tests/experiments_tests/test_train_agent.py | 32 +++++------ .../test_train_agent_async.py | 24 ++++----- .../test_train_agent_batch.py | 34 ++++++------ tests/wrappers_tests/test_atari_wrappers.py | 10 ++-- tests/wrappers_tests/test_cast_observation.py | 4 +- .../test_continuing_time_limit.py | 12 ++--- tests/wrappers_tests/test_monitor.py | 4 +- .../wrappers_tests/test_vector_frame_stack.py | 10 ++-- 22 files changed, 152 insertions(+), 149 deletions(-) diff --git a/examples/grasping/train_dqn_batch_grasping.py b/examples/grasping/train_dqn_batch_grasping.py index a81d04649..0cff325ad 100644 --- a/examples/grasping/train_dqn_batch_grasping.py +++ b/examples/grasping/train_dqn_batch_grasping.py @@ -64,10 +64,10 @@ def reset(self): return 
self.env.reset(), self._elapsed_steps def step(self, action): - observation, reward, done, info = self.env.step(action) + observation, reward, terminated, truncated, info = self.env.step(action) self._elapsed_steps += 1 assert self._elapsed_steps <= self._max_steps - return (observation, self._elapsed_steps), reward, done, info + return (observation, self._elapsed_steps), reward, terminated, truncated, info class RecordMovie(gymnasium.Wrapper): diff --git a/examples/quickstart/quickstart.ipynb b/examples/quickstart/quickstart.ipynb index 5e2aa43d4..d139c0ef7 100644 --- a/examples/quickstart/quickstart.ipynb +++ b/examples/quickstart/quickstart.ipynb @@ -38,10 +38,11 @@ "PFRL can be used for any problems if they are modeled as \"environments\". [OpenAI gymnasium](https://github.com/openai/gymnasium) provides various kinds of benchmark environments and defines the common interface among them. PFRL uses a subset of the interface. Specifically, an environment must define its observation space and action space and have at least two methods: `reset` and `step`.\n", "\n", "- `env.reset` will reset the environment to the initial state and return the initial observation.\n", - "- `env.step` will execute a given action, move to the next state and return four values:\n", + "- `env.step` will execute a given action, move to the next state and return five values:\n", " - a next observation\n", " - a scalar reward\n", " - a boolean value indicating whether the current state is terminal or not\n", + " - a boolean value indicating whether the episode has been truncated or not\n", " - additional information\n", "- `env.render` will render the current state. (optional)\n", "\n", @@ -81,10 +82,11 @@ "print('initial observation:', obs)\n", "\n", "action = env.action_space.sample()\n", - "obs, r, done, info = env.step(action)\n", + "obs, r, terminated, truncated, info = env.step(action)\n", "print('next observation:', obs)\n", "print('reward:', r)\n", - "print('done:', done)\n", + "print('terminated:', terminated)\n", + "print('truncated:', truncated)\n", "print('info:', info)\n", "\n", "# Uncomment to open a GUI window rendering the current state of the environment\n", @@ -315,11 +317,11 @@ " # Uncomment to watch the behavior in a GUI window\n", " # env.render()\n", " action = agent.act(obs)\n", - " obs, reward, done, _ = env.step(action)\n", + " obs, reward, terminated, _, _ = env.step(action)\n", " R += reward\n", " t += 1\n", " reset = t == max_episode_len\n", - " agent.observe(obs, reward, done, reset)\n", + " agent.observe(obs, reward, terminated, reset)\n", - " if done or reset:\n", + " if terminated or reset:\n", " break\n", " if i % 10 == 0:\n", @@ -373,11 +375,11 @@ " # Uncomment to watch the behavior in a GUI window\n", " # env.render()\n", " action = agent.act(obs)\n", - " obs, r, done, _ = env.step(action)\n", + " obs, r, terminated, _, _ = env.step(action)\n", " R += r\n", " t += 1\n", " reset = t == 200\n", - " agent.observe(obs, r, done, reset)\n", + " agent.observe(obs, r, terminated, reset)\n", - " if done or reset:\n", + " if terminated or reset:\n", " break\n", " print('evaluation episode:', i, 'R:', R)" diff --git a/pfrl/envs/multiprocess_vector_env.py b/pfrl/envs/multiprocess_vector_env.py index 2b3decd12..9e1a6aef2 100644 --- a/pfrl/envs/multiprocess_vector_env.py +++ b/pfrl/envs/multiprocess_vector_env.py @@ -16,8 +16,8 @@ def worker(remote, env_fn): while True: cmd, data = remote.recv() if cmd == "step": - ob, reward, done, info = env.step(data) - remote.send((ob, reward, done, info)) + ob, reward, terminated, truncated, info = env.step(data) + 
remote.send((ob, reward, terminated, truncated, info)) elif cmd == "reset": ob = env.reset() remote.send(ob) diff --git a/pfrl/envs/serial_vector_env.py b/pfrl/envs/serial_vector_env.py index 025448f36..b5a61e1a2 100644 --- a/pfrl/envs/serial_vector_env.py +++ b/pfrl/envs/serial_vector_env.py @@ -22,8 +22,8 @@ def __init__(self, envs): def step(self, actions): results = [env.step(a) for env, a in zip(self.envs, actions)] - self.last_obs, rews, dones, infos = zip(*results) - return self.last_obs, rews, dones, infos + self.last_obs, rews, terminations, truncations, infos = zip(*results) + return self.last_obs, rews, terminations, truncations, infos def reset(self, mask=None): if mask is None: diff --git a/pfrl/experiments/evaluator.py b/pfrl/experiments/evaluator.py index 75691784c..bce3f7f55 100644 --- a/pfrl/experiments/evaluator.py +++ b/pfrl/experiments/evaluator.py @@ -35,12 +35,12 @@ def _run_episodes( episode_len = 0 info = {} a = agent.act(obs) - obs, r, done, info = env.step(a) + obs, r, terminated, truncated, info = env.step(a) test_r += r episode_len += 1 timestep += 1 - reset = done or episode_len == max_episode_len or info.get("needs_reset", False) - agent.observe(obs, r, done, reset) + reset = terminated or episode_len == max_episode_len or info.get("needs_reset", False) or truncated + agent.observe(obs, r, terminated, reset) if reset: logger.info( "evaluation episode %s length:%s R:%s", len(scores), episode_len, test_r @@ -130,7 +130,7 @@ def _batch_run_episodes( actions = agent.batch_act(obss) timestep += 1 # o_{t+1}, r_{t+1} - obss, rs, dones, infos = env.step(actions) + obss, rs, terminations, truncations, infos = env.step(actions) episode_r += rs episode_len += 1 # Compute mask for done and reset @@ -139,11 +139,11 @@ def _batch_run_episodes( else: resets = episode_len == max_episode_len resets = np.logical_or( - resets, [info.get("needs_reset", False) for info in infos] + resets, [info.get("needs_reset", False) or truncated for truncated, info in zip(truncations, infos)] ) # Make mask. 
0 if done/reset, 1 if pass - end = np.logical_or(resets, dones) + end = np.logical_or(resets, terminations) not_end = np.logical_not(end) for index in range(len(end)): diff --git a/pfrl/experiments/train_agent.py b/pfrl/experiments/train_agent.py index 210c7ed24..341ebefff 100644 --- a/pfrl/experiments/train_agent.py +++ b/pfrl/experiments/train_agent.py @@ -56,17 +56,17 @@ def train_agent( # a_t action = agent.act(obs) # o_{t+1}, r_{t+1} - obs, r, done, info = env.step(action) + obs, r, terminated, truncated, info = env.step(action) t += 1 episode_r += r episode_len += 1 - reset = episode_len == max_episode_len or info.get("needs_reset", False) - agent.observe(obs, r, done, reset) + reset = episode_len == max_episode_len or info.get("needs_reset", False) or truncated + agent.observe(obs, r, terminated, reset) for hook in step_hooks: hook(env, agent, t) - episode_end = done or reset or t == steps + episode_end = terminated or reset or t == steps if episode_end: logger.info( diff --git a/pfrl/experiments/train_agent_async.py b/pfrl/experiments/train_agent_async.py index 9856bb0a2..02b34c908 100644 --- a/pfrl/experiments/train_agent_async.py +++ b/pfrl/experiments/train_agent_async.py @@ -69,12 +69,12 @@ def save_model(): # a_t a = agent.act(obs) # o_{t+1}, r_{t+1} - obs, r, done, info = env.step(a) + obs, r, terminated, truncated, info = env.step(a) local_t += 1 episode_r += r episode_len += 1 - reset = episode_len == max_episode_len or info.get("needs_reset", False) - agent.observe(obs, r, done, reset) + reset = episode_len == max_episode_len or info.get("needs_reset", False) or truncated + agent.observe(obs, r, terminated, reset) # Get and increment the global counter with counter.get_lock(): @@ -84,7 +84,7 @@ def save_model(): for hook in global_step_hooks: hook(env, agent, global_t) - if done or reset or global_t >= steps or stop_event.is_set(): + if terminated or reset or global_t >= steps or stop_event.is_set(): if process_idx == 0: logger.info( "outdir:%s global_step:%s local_step:%s R:%s", diff --git a/pfrl/experiments/train_agent_batch.py b/pfrl/experiments/train_agent_batch.py index add7cda81..5452d9aa9 100644 --- a/pfrl/experiments/train_agent_batch.py +++ b/pfrl/experiments/train_agent_batch.py @@ -66,7 +66,7 @@ def train_agent_batch( # a_t actions = agent.batch_act(obss) # o_{t+1}, r_{t+1} - obss, rs, dones, infos = env.step(actions) + obss, rs, terminations, truncations, infos = env.step(actions) episode_r += rs episode_len += 1 @@ -76,13 +76,13 @@ def train_agent_batch( else: resets = episode_len == max_episode_len resets = np.logical_or( - resets, [info.get("needs_reset", False) for info in infos] + resets, [info.get("needs_reset", False) or truncation for truncation, info in zip(truncations, infos)] ) # Agent observes the consequences - agent.batch_observe(obss, rs, dones, resets) + agent.batch_observe(obss, rs, terminations, resets) - # Make mask. 0 if done/reset, 1 if pass - end = np.logical_or(resets, dones) + # Make mask. 
0 if termination/reset, 1 if pass + end = np.logical_or(resets, terminations) not_end = np.logical_not(end) # For episodes that ends, do the following: diff --git a/pfrl/utils/env_modifiers.py b/pfrl/utils/env_modifiers.py index a605b7b71..2c8b94259 100644 --- a/pfrl/utils/env_modifiers.py +++ b/pfrl/utils/env_modifiers.py @@ -24,11 +24,11 @@ def make_timestep_limited(env, timestep_limit): old_reset = env.reset def step(action): - observation, reward, done, info = old_step(action) + observation, reward, done, truncated, info = old_step(action) if t[0] >= timestep_limit: done = True t[0] += 1 - return observation, reward, done, info + return observation, reward, done, truncated, info def reset(): t[0] = 1 @@ -51,9 +51,9 @@ def make_reward_filtered(env, reward_filter): old_step = env.step def step(action): - observation, reward, done, info = old_step(action) + observation, reward, done, truncated, info = old_step(action) reward = reward_filter(reward) - return observation, reward, done, info + return observation, reward, done, truncated, info env.step = step @@ -73,10 +73,10 @@ def make_action_repeated(env, n_times): def step(action): r_total = 0 for _ in range(n_times): - obs, r, done, info = old_step(action) + obs, r, done, truncated, info = old_step(action) r_total += r if done: break - return obs, r_total, done, info + return obs, r_total, done, truncated, info env.step = step diff --git a/pfrl/wrappers/atari_wrappers.py b/pfrl/wrappers/atari_wrappers.py index ce87fc1e1..53884329d 100644 --- a/pfrl/wrappers/atari_wrappers.py +++ b/pfrl/wrappers/atari_wrappers.py @@ -45,8 +45,8 @@ def reset(self, **kwargs): assert noops > 0 obs = None for _ in range(noops): - obs, _, done, info = self.env.step(self.noop_action) - if done or info.get("needs_reset", False): + obs, _, done, truncated, info = self.env.step(self.noop_action) + if done or info.get("needs_reset", False) or truncated: obs = self.env.reset(**kwargs) return obs @@ -63,11 +63,11 @@ def __init__(self, env): def reset(self, **kwargs): self.env.reset(**kwargs) - obs, _, done, info = self.env.step(1) - if done or info.get("needs_reset", False): + obs, _, done, truncated, info = self.env.step(1) + if done or info.get("needs_reset", False) or truncated: self.env.reset(**kwargs) obs, _, done, info = self.env.step(2) - if done or info.get("needs_reset", False): + if done or info.get("needs_reset", False) or truncated: self.env.reset(**kwargs) return obs @@ -86,7 +86,7 @@ def __init__(self, env): self.needs_real_reset = True def step(self, action): - obs, reward, done, info = self.env.step(action) + obs, reward, done, truncated, info = self.env.step(action) self.needs_real_reset = done or info.get("needs_reset", False) # check current lives, make loss of life terminal, # then update lives to handle bonus lives @@ -98,7 +98,7 @@ def step(self, action): # the environment advertises done. done = True self.lives = lives - return obs, reward, done, info + return obs, reward, done, truncated, info def reset(self, **kwargs): """Reset only when lives are exhausted. 
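For reference, the five-value step contract that these wrapper hunks migrate to can be summarized by the following minimal sketch (an illustration only, not code from this patch; the wrapper name, the "episode_end" info key, and the use of CartPole-v1 are assumptions):

import gymnasium


class EpisodeEndFlag(gymnasium.Wrapper):
    """Illustrative wrapper: forwards the 5-tuple and records why an episode ended."""

    def step(self, action):
        obs, reward, terminated, truncated, info = self.env.step(action)
        # Gymnasium splits the old `done` flag into `terminated` (the MDP reached
        # a terminal state) and `truncated` (a time limit or other external
        # cut-off ended the episode early).
        info["episode_end"] = terminated or truncated
        return obs, reward, terminated, truncated, info


env = EpisodeEndFlag(gymnasium.make("CartPole-v1"))
obs, info = env.reset(seed=0)
obs, reward, terminated, truncated, info = env.step(env.action_space.sample())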
@@ -110,7 +110,7 @@ def reset(self, **kwargs): obs = self.env.reset(**kwargs) else: # no-op step to advance from terminal/lost life state - obs, _, _, _ = self.env.step(0) + obs, _, _, _, _ = self.env.step(0) self.lives = self.env.unwrapped.ale.lives() return obs @@ -128,19 +128,19 @@ def step(self, action): total_reward = 0.0 done = None for i in range(self._skip): - obs, reward, done, info = self.env.step(action) + obs, reward, done, truncated, info = self.env.step(action) if i == self._skip - 2: self._obs_buffer[0] = obs if i == self._skip - 1: self._obs_buffer[1] = obs total_reward += reward - if done or info.get("needs_reset", False): + if done or info.get("needs_reset", False) or truncated: break # Note that the observation on the done=True frame # doesn't matter max_frame = self._obs_buffer.max(axis=0) - return max_frame, total_reward, done, info + return max_frame, total_reward, done, truncated, info def reset(self, **kwargs): return self.env.reset(**kwargs) @@ -213,9 +213,9 @@ def reset(self): return self._get_ob() def step(self, action): - ob, reward, done, info = self.env.step(action) + ob, reward, done, truncated, info = self.env.step(action) self.frames.append(ob) - return self._get_ob(), reward, done, info + return self._get_ob(), reward, done, truncated, info def _get_ob(self): assert len(self.frames) == self.k diff --git a/pfrl/wrappers/continuing_time_limit.py b/pfrl/wrappers/continuing_time_limit.py index 3e47dd0da..30f5900c2 100644 --- a/pfrl/wrappers/continuing_time_limit.py +++ b/pfrl/wrappers/continuing_time_limit.py @@ -28,7 +28,7 @@ def step(self, action): assert ( self._elapsed_steps is not None ), "Cannot call env.step() before calling reset()" - observation, reward, done, info = self.env.step(action) + observation, reward, done, _, info = self.env.step(action) self._elapsed_steps += 1 if self._max_episode_steps <= self._elapsed_steps: diff --git a/pfrl/wrappers/vector_frame_stack.py b/pfrl/wrappers/vector_frame_stack.py index c01c18007..6b3626caf 100644 --- a/pfrl/wrappers/vector_frame_stack.py +++ b/pfrl/wrappers/vector_frame_stack.py @@ -91,10 +91,10 @@ def reset(self, mask=None): return self._get_ob() def step(self, action): - batch_ob, reward, done, info = self.env.step(action) + batch_ob, reward, terminated, _, info = self.env.step(action) for frames, ob in zip(self.frames, batch_ob): frames.append(ob) - return self._get_ob(), reward, done, info + return self._get_ob(), reward, terminated, info def _get_ob(self): assert len(self.frames) == self.env.num_envs diff --git a/tests/envs_tests/test_vector_envs.py b/tests/envs_tests/test_vector_envs.py index 89515c096..7a89e9984 100644 --- a/tests/envs_tests/test_vector_envs.py +++ b/tests/envs_tests/test_vector_envs.py @@ -59,14 +59,15 @@ def test_seed_reset_and_step(self): # step actions = [env.action_space.sample() for env in self.envs] - real_obss, real_rewards, real_dones, real_infos = zip( + real_obss, real_rewards, real_terminations, real_truncations, real_infos = zip( *[env.step(action) for env, action in zip(self.envs, actions)] ) - obss, rewards, dones, infos = self.vec_env.step(actions) + obss, rewards, terminations, truncations, infos = self.vec_env.step(actions) np.testing.assert_allclose(obss, real_obss) assert rewards == real_rewards - assert dones == real_dones + assert terminations == real_terminations assert infos == real_infos + assert truncations == real_truncations # reset with full mask should have no effect mask = np.ones(self.num_envs) diff --git a/tests/experiments_tests/test_evaluator.py 
b/tests/experiments_tests/test_evaluator.py index 77ea7a6a3..4f6811ea1 100644 --- a/tests/experiments_tests/test_evaluator.py +++ b/tests/experiments_tests/test_evaluator.py @@ -23,7 +23,7 @@ def test_evaluator_evaluate_if_necessary(save_best_so_far_agent, n_steps, n_epis env = mock.Mock() env.reset.return_value = "obs" - env.step.return_value = ("obs", 0, True, {}) + env.step.return_value = ("obs", 0, True, False, {}) env.get_statistics.return_value = [] evaluation_hook = mock.create_autospec( @@ -89,7 +89,7 @@ def test_evaluator_evaluate_if_necessary(save_best_so_far_agent, n_steps, n_epis assert agent.save.call_count == 0 # Third evaluation with a better score - env.step.return_value = ("obs", 1, True, {}) + env.step.return_value = ("obs", 1, True, False, {}) agent_evaluator.evaluate_if_necessary(t=9, episodes=9) assert agent.act.call_count == 3 * value assert agent.observe.call_count == 3 * value @@ -112,7 +112,7 @@ def test_async_evaluator_evaluate_if_necessary(save_best_so_far_agent, n_episode env = mock.Mock() env.reset.return_value = "obs" - env.step.return_value = ("obs", 0, True, {}) + env.step.return_value = ("obs", 0, True, False, {}) env.get_statistics.return_value = [] evaluation_hook = mock.create_autospec( @@ -159,7 +159,7 @@ def test_async_evaluator_evaluate_if_necessary(save_best_so_far_agent, n_episode assert agent.save.call_count == 0 # Third evaluation with a better score - env.step.return_value = ("obs", 1, True, {}) + env.step.return_value = ("obs", 1, True, False, {}) agent_evaluator.evaluate_if_necessary(t=9, episodes=9, env=env, agent=agent) assert agent.act.call_count == 3 * n_episodes assert agent.observe.call_count == 3 * n_episodes @@ -184,8 +184,8 @@ def test_run_evaluation_episodes_with_n_steps(n_episodes, n_steps): (("state", 2), 0.2, False, {}), (("state", 3), 0.3, False, {"needs_reset": True}), (("state", 5), -0.5, False, {}), - (("state", 6), 0, False, {}), - (("state", 7), 1, True, {}), + (("state", 6), 0, False, False, {}), + (("state", 7), 1, True, False, {}), ] if n_episodes: @@ -262,11 +262,11 @@ def make_env(idx): # Second episode: 4 -> 5 -> 6 -> 7 (done) env.reset.side_effect = [("state", 0), ("state", 4)] env.step.side_effect = [ - (("state", 1), 0, False, {}), - (("state", 2), 0.1, False, {}), - (("state", 3), 0.2, False, {"needs_reset": True}), - (("state", 5), -0.5, False, {}), - (("state", 6), 0, False, {}), + (("state", 1), 0, False, False, {}), + (("state", 2), 0.1, False, False, {}), + (("state", 3), 0.2, False, False, {"needs_reset": True}), + (("state", 5), -0.5, False, False, {}), + (("state", 6), 0, False, False, {}), (("state", 7), 1, True, {}), ] else: @@ -275,11 +275,11 @@ def make_env(idx): # Third episode: 4 -> 5 -> 6 -> 7 (done) env.reset.side_effect = [("state", 0), ("state", 2), ("state", 4)] env.step.side_effect = [ - (("state", 1), 2, False, {"needs_reset": True}), - (("state", 3), 3, False, {"needs_reset": True}), - (("state", 5), -0.6, False, {}), - (("state", 6), 0, False, {}), - (("state", 7), 1, True, {}), + (("state", 1), 2, False, False, {"needs_reset": True}), + (("state", 3), 3, False, False, {"needs_reset": True}), + (("state", 5), -0.6, False, False, {}), + (("state", 6), 0, False, False, {}), + (("state", 7), 1, True, False, {}), ] return env @@ -327,12 +327,12 @@ def make_env(idx): # Second episode: 4 -> 5 -> 6 -> 7 (done) env.reset.side_effect = [("state", 0), ("state", 4)] env.step.side_effect = [ - (("state", 1), 0, False, {}), - (("state", 2), 0, False, {}), - (("state", 3), 0, False, {"needs_reset": True}), 
- (("state", 5), -0.5, False, {}), - (("state", 6), 0, False, {}), - (("state", 7), 1, True, {}), + (("state", 1), 0, False, False, {}), + (("state", 2), 0, False, False, {}), + (("state", 3), 0, False, False, {"needs_reset": True}), + (("state", 5), -0.5, False, False, {}), + (("state", 6), 0, False, False, {}), + (("state", 7), 1, True, False, {}), ] else: # First episode: 0 -> 1 (reset) @@ -340,11 +340,11 @@ def make_env(idx): # Third episode: 4 -> 5 -> 6 -> 7 (done) env.reset.side_effect = [("state", 0), ("state", 2), ("state", 4)] env.step.side_effect = [ - (("state", 1), 2, False, {"needs_reset": True}), - (("state", 3), 3, False, {"needs_reset": True}), - (("state", 5), -0.6, False, {}), - (("state", 6), 0, False, {}), - (("state", 7), 1, True, {}), + (("state", 1), 2, False, False, {"needs_reset": True}), + (("state", 3), 3, False, False, {"needs_reset": True}), + (("state", 5), -0.6, False, False, {}), + (("state", 6), 0, False, False, {}), + (("state", 7), 1, True, False, {}), ] return env diff --git a/tests/experiments_tests/test_train_agent.py b/tests/experiments_tests/test_train_agent.py index ac367e149..5ba306e88 100644 --- a/tests/experiments_tests/test_train_agent.py +++ b/tests/experiments_tests/test_train_agent.py @@ -17,11 +17,11 @@ def test(self): # Reaches the terminal state after five actions env.reset.side_effect = [("state", 0)] env.step.side_effect = [ - (("state", 1), 0, False, {}), - (("state", 2), 0, False, {}), - (("state", 3), -0.5, False, {}), - (("state", 4), 0, False, {}), - (("state", 5), 1, True, {}), + (("state", 1), 0, False, False, {}), + (("state", 2), 0, False, False, {}), + (("state", 3), -0.5, False, False, {}), + (("state", 4), 0, False, False, {}), + (("state", 5), 1, True, False, {}), ] hook = mock.Mock() @@ -59,12 +59,12 @@ def test_needs_reset(self): # Second episode: 4 -> 5 -> 6 -> 7 (done) env.reset.side_effect = [("state", 0), ("state", 4)] env.step.side_effect = [ - (("state", 1), 0, False, {}), - (("state", 2), 0, False, {}), - (("state", 3), 0, False, {"needs_reset": True}), - (("state", 5), -0.5, False, {}), - (("state", 6), 0, False, {}), - (("state", 7), 1, True, {}), + (("state", 1), 0, False, False, {}), + (("state", 2), 0, False, False, {}), + (("state", 3), 0, False, False, {"needs_reset": True}), + (("state", 5), -0.5, False, False, {}), + (("state", 6), 0, False, False,{}), + (("state", 7), 1, True, False, {}), ] hook = mock.Mock() @@ -144,11 +144,11 @@ def test_eval_during_episode(eval_during_episode): # Two episodes env.reset.side_effect = [("state", 0)] * 2 env.step.side_effect = [ - (("state", 1), 0, False, {}), - (("state", 2), 0, False, {}), - (("state", 3), -0.5, True, {}), - (("state", 4), 0, False, {}), - (("state", 5), 1, True, {}), + (("state", 1), 0, False, False, {}), + (("state", 2), 0, False, False, {}), + (("state", 3), -0.5, True, False, {}), + (("state", 4), 0, False, False, {}), + (("state", 5), 1, True, False, {}), ] evaluator = mock.Mock() diff --git a/tests/experiments_tests/test_train_agent_async.py b/tests/experiments_tests/test_train_agent_async.py index bd0434018..cc3869acc 100644 --- a/tests/experiments_tests/test_train_agent_async.py +++ b/tests/experiments_tests/test_train_agent_async.py @@ -27,16 +27,16 @@ def _make_env(process_idx, test): if max_episode_len is None: # Episodic env that terminates after 5 actions env.step.side_effect = [ - (("state", 1), 0, False, {}), - (("state", 2), 0, False, {}), - (("state", 3), -0.5, False, {}), - (("state", 4), 0, False, {}), - (("state", 5), 1, True, {}), + 
(("state", 1), 0, False, False, {}), + (("state", 2), 0, False, False, {}), + (("state", 3), -0.5, False, False, {}), + (("state", 4), 0, False, False, {}), + (("state", 5), 1, True, False, {}), ] * 1000 else: # Continuing env env.step.side_effect = [ - (("state", 1), 0, False, {}), + (("state", 1), 0, False,False, {}), ] * 1000 return env @@ -156,12 +156,12 @@ def test_needs_reset(self): # Second episode: 4 -> 5 -> 6 -> 7 (done) env.reset.side_effect = [("state", 0), ("state", 4)] env.step.side_effect = [ - (("state", 1), 0, False, {}), - (("state", 2), 0, False, {}), - (("state", 3), 0, False, {"needs_reset": True}), - (("state", 5), -0.5, False, {}), - (("state", 6), 0, False, {}), - (("state", 7), 1, True, {}), + (("state", 1), 0, False, False, {}), + (("state", 2), 0, False, False, {}), + (("state", 3), 0, False, False, {"needs_reset": True}), + (("state", 5), -0.5, False, False, {}), + (("state", 6), 0, False, False, {}), + (("state", 7), 1, True, False, {}), ] counter = mp.Value("i", 0) diff --git a/tests/experiments_tests/test_train_agent_batch.py b/tests/experiments_tests/test_train_agent_batch.py index ef2bbfc37..b719cfac3 100644 --- a/tests/experiments_tests/test_train_agent_batch.py +++ b/tests/experiments_tests/test_train_agent_batch.py @@ -25,16 +25,16 @@ def make_env(): if max_episode_len is None: # Episodic env that terminates after 5 actions env.step.side_effect = [ - (("state", 1), 0, False, {}), - (("state", 2), 0, False, {}), - (("state", 3), -0.5, False, {}), - (("state", 4), 0, False, {}), - (("state", 5), 1, True, {}), + (("state", 1), 0, False, False, {}), + (("state", 2), 0, False, False, {}), + (("state", 3), -0.5, False, False, {}), + (("state", 4), 0, False, False, {}), + (("state", 5), 1, True, False, {}), ] * 1000 else: # Continuing env env.step.side_effect = [ - (("state", 1), 0, False, {}), + (("state", 1), 0, False, False, {}), ] * 1000 return env @@ -194,12 +194,12 @@ def make_env(idx): # Second episode: 4 -> 5 -> 6 -> 7 (done) env.reset.side_effect = [("state", 0), ("state", 4)] env.step.side_effect = [ - (("state", 1), 0, False, {}), - (("state", 2), 0, False, {}), - (("state", 3), 0, False, {"needs_reset": True}), - (("state", 5), -0.5, False, {}), - (("state", 6), 0, False, {}), - (("state", 7), 1, True, {}), + (("state", 1), 0, False, False, {}), + (("state", 2), 0, False, False, {}), + (("state", 3), 0, False, False, {"needs_reset": True}), + (("state", 5), -0.5, False, False, {}), + (("state", 6), 0, False, False, {}), + (("state", 7), 1, True, False, {}), ] else: # First episode: 0 -> 1 (reset) @@ -207,11 +207,11 @@ def make_env(idx): # Third episode: 4 -> 5 -> 6 -> 7 (done) env.reset.side_effect = [("state", 0), ("state", 2), ("state", 4)] env.step.side_effect = [ - (("state", 1), 0, False, {"needs_reset": True}), - (("state", 3), 0, False, {"needs_reset": True}), - (("state", 5), -0.5, False, {}), - (("state", 6), 0, False, {}), - (("state", 7), 1, True, {}), + (("state", 1), 0, False, False, {"needs_reset": True}), + (("state", 3), 0, False, False, {"needs_reset": True}), + (("state", 5), -0.5, False, False, {}), + (("state", 6), 0, False, False, {}), + (("state", 7), 1, True, False, {}), ] return env diff --git a/tests/wrappers_tests/test_atari_wrappers.py b/tests/wrappers_tests/test_atari_wrappers.py index 77d4a3ea6..ec0cbeffc 100644 --- a/tests/wrappers_tests/test_atari_wrappers.py +++ b/tests/wrappers_tests/test_atari_wrappers.py @@ -74,8 +74,8 @@ def dtyped_rand(): for _ in range(steps - 1): action = env.action_space.sample() fs_action = 
fs_env.action_space.sample() - obs, r, done, info = env.step(action) - fs_obs, fs_r, fs_done, fs_info = fs_env.step(fs_action) + obs, r, done, _, info = env.step(action) + fs_obs, fs_r, fs_done, _, fs_info = fs_env.step(fs_action) assert isinstance(fs_obs, LazyFrames) np.testing.assert_allclose( obs.take(indices=0, axis=fs_env.stack_axis), @@ -142,8 +142,8 @@ def dtyped_rand(): for _ in range(steps - 1): action = env.action_space.sample() s_action = s_env.action_space.sample() - obs, r, done, info = env.step(action) - s_obs, s_r, s_done, s_info = s_env.step(s_action) + obs, r, terminated, _, info = env.step(action) + s_obs, s_r, s_terminated, _, s_info = s_env.step(s_action) np.testing.assert_allclose(np.array(obs) / s_env.scale, s_obs) assert r == s_r - assert done == s_done + assert terminated == s_terminated diff --git a/tests/wrappers_tests/test_cast_observation.py b/tests/wrappers_tests/test_cast_observation.py index 06ad13f3b..5f925fb39 100644 --- a/tests/wrappers_tests/test_cast_observation.py +++ b/tests/wrappers_tests/test_cast_observation.py @@ -16,7 +16,7 @@ def test_cast_observation(env_id, dtype): assert obs.dtype == dtype np.testing.assert_allclose(env.original_observation, obs, rtol=rtol) - obs, r, done, info = env.step(env.action_space.sample()) + obs, r, done, _, info = env.step(env.action_space.sample()) assert env.original_observation.dtype == np.float64 assert obs.dtype == dtype @@ -32,7 +32,7 @@ def test_cast_observation_to_float32(env_id): assert obs.dtype == np.float32 np.testing.assert_allclose(env.original_observation, obs) - obs, r, done, info = env.step(env.action_space.sample()) + obs, r, done, _, info = env.step(env.action_space.sample()) assert env.original_observation.dtype == np.float64 assert obs.dtype == np.float32 np.testing.assert_allclose(env.original_observation, obs) diff --git a/tests/wrappers_tests/test_continuing_time_limit.py b/tests/wrappers_tests/test_continuing_time_limit.py index 9a20d93c5..f2bcda759 100644 --- a/tests/wrappers_tests/test_continuing_time_limit.py +++ b/tests/wrappers_tests/test_continuing_time_limit.py @@ -16,16 +16,16 @@ def test_continuing_time_limit(max_episode_steps): env.reset() for t in range(2): - _, _, done, info = env.step(0) + _, _, done, truncated, info = env.step(0) if t + 1 >= max_episode_steps: - assert info["needs_reset"] + assert info["needs_reset"] and truncated else: - assert not info.get("needs_reset", False) + assert not info.get("needs_reset", False) and not truncated env.reset() for t in range(4): - _, _, done, info = env.step(0) + _, _, done, truncated , info = env.step(0) if t + 1 >= max_episode_steps: - assert info["needs_reset"] + assert info["needs_reset"] and truncated else: - assert not info.get("needs_reset", False) + assert not info.get("needs_reset", False) and not truncated diff --git a/tests/wrappers_tests/test_monitor.py b/tests/wrappers_tests/test_monitor.py index 2151590e9..749440c00 100644 --- a/tests/wrappers_tests/test_monitor.py +++ b/tests/wrappers_tests/test_monitor.py @@ -30,12 +30,12 @@ def test_monitor(n_episodes): t = 0 _ = env.reset() while True: - _, _, done, info = env.step(env.action_space.sample()) + _, _, terminated, truncated, info = env.step(env.action_space.sample()) episode_len += 1 t += 1 if episode_idx == 1 and episode_len >= 3: info["needs_reset"] = True # simulate ContinuingTimeLimit - if done or info.get("needs_reset", False) or t == steps: + if terminated or truncated or info.get("needs_reset", False) or t == steps: if episode_idx + 1 == n_episodes or t == steps: 
break env.reset() diff --git a/tests/wrappers_tests/test_vector_frame_stack.py b/tests/wrappers_tests/test_vector_frame_stack.py index 1f649c400..dac808edd 100644 --- a/tests/wrappers_tests/test_vector_frame_stack.py +++ b/tests/wrappers_tests/test_vector_frame_stack.py @@ -86,8 +86,8 @@ def make_env(idx): ) batch_action = [0] * num_envs - fs_new_obs, fs_r, fs_done, _ = fs_env.step(batch_action) - vfs_new_obs, vfs_r, vfs_done, _ = vfs_env.step(batch_action) + fs_new_obs, fs_r, fs_done, _, _ = fs_env.step(batch_action) + vfs_new_obs, vfs_r, vfs_done, _, _ = vfs_env.step(batch_action) # Same LazyFrames observations, but those from fs_env are copies # while those from vfs_env are references. @@ -107,8 +107,8 @@ def make_env(idx): for _ in range(steps - 1): fs_env.reset(mask=np.logical_not(fs_done)) vfs_env.reset(mask=np.logical_not(vfs_done)) - fs_obs, fs_r, fs_done, _ = fs_env.step(batch_action) - vfs_obs, vfs_r, vfs_done, _ = vfs_env.step(batch_action) + fs_obs, fs_r, fs_terminated, _, _ = fs_env.step(batch_action) + vfs_obs, vfs_r, vfs_terminated, _, _ = vfs_env.step(batch_action) for env_idx in range(num_envs): assert isinstance(fs_new_obs[env_idx], LazyFrames) assert isinstance(vfs_new_obs[env_idx], LazyFrames) @@ -116,4 +116,4 @@ def make_env(idx): np.asarray(fs_new_obs[env_idx]), np.asarray(vfs_new_obs[env_idx]) ) np.testing.assert_allclose(fs_r, vfs_r) - np.testing.assert_allclose(fs_done, vfs_done) + np.testing.assert_allclose(fs_terminated, vfs_terminated) From c804fe3bd85ae47111770ab53d1287f401b44a36 Mon Sep 17 00:00:00 2001 From: Prabhat Date: Fri, 14 Apr 2023 22:56:22 -0500 Subject: [PATCH 04/26] some Atari changes --- pfrl/wrappers/atari_wrappers.py | 3 --- pfrl/wrappers/continuing_time_limit.py | 2 +- pfrl/wrappers/randomize_action.py | 8 ++++---- 3 files changed, 5 insertions(+), 8 deletions(-) diff --git a/pfrl/wrappers/atari_wrappers.py b/pfrl/wrappers/atari_wrappers.py index 53884329d..6e00db7c7 100644 --- a/pfrl/wrappers/atari_wrappers.py +++ b/pfrl/wrappers/atari_wrappers.py @@ -288,9 +288,6 @@ def observation(self, observation): def make_atari(env_id, max_frames=30 * 60 * 60): env = gymnasium.make(env_id) assert "NoFrameskip" in env.spec.id - assert isinstance(env, gymnasium.wrappers.TimeLimit) - # Unwrap TimeLimit wrapper because we use our own time limits - env = env.env if max_frames: env = pfrl.wrappers.ContinuingTimeLimit(env, max_episode_steps=max_frames) env = NoopResetEnv(env, noop_max=30) diff --git a/pfrl/wrappers/continuing_time_limit.py b/pfrl/wrappers/continuing_time_limit.py index 30f5900c2..360aa0ac5 100644 --- a/pfrl/wrappers/continuing_time_limit.py +++ b/pfrl/wrappers/continuing_time_limit.py @@ -34,7 +34,7 @@ def step(self, action): if self._max_episode_steps <= self._elapsed_steps: info["needs_reset"] = True - return observation, reward, done, info + return observation, reward, done, False, info def reset(self): self._elapsed_steps = 0 diff --git a/pfrl/wrappers/randomize_action.py b/pfrl/wrappers/randomize_action.py index 407cb1899..226ff932d 100644 --- a/pfrl/wrappers/randomize_action.py +++ b/pfrl/wrappers/randomize_action.py @@ -27,14 +27,14 @@ def __init__(self, env, random_fraction): env.action_space, gymnasium.spaces.Discrete ), "RandomizeAction supports only gymnasium.spaces.Discrete as an action space" self._random_fraction = random_fraction - self._np_random = np.random.RandomState() + self.unwrapped._np_random = np.random.RandomState() def action(self, action): - if self._np_random.rand() < self._random_fraction: - return 
self._np_random.randint(self.env.action_space.n) + if self.unwrapped._np_random.rand() < self._random_fraction: + return self.unwrapped._np_random.randint(self.env.action_space.n) else: return action def seed(self, seed): super().seed(seed) - self._np_random.seed(seed) + self.unwrapped._np_random.seed(seed) From 5daca4c3dbabc3511867acc540d1030a75e67ef9 Mon Sep 17 00:00:00 2001 From: Prabhat Date: Sat, 15 Apr 2023 03:22:08 -0500 Subject: [PATCH 05/26] Makes more env modifications --- examples/grasping/train_dqn_batch_grasping.py | 2 +- pfrl/envs/abc.py | 2 +- pfrl/envs/multiprocess_vector_env.py | 2 +- pfrl/envs/serial_vector_env.py | 2 +- pfrl/wrappers/atari_wrappers.py | 12 +++++++----- pfrl/wrappers/render.py | 2 +- pfrl/wrappers/vector_frame_stack.py | 2 +- tests/wrappers_tests/test_randomize_action.py | 2 +- tests/wrappers_tests/test_render.py | 18 ++++++++++-------- .../wrappers_tests/test_vector_frame_stack.py | 4 ++-- 10 files changed, 26 insertions(+), 22 deletions(-) diff --git a/examples/grasping/train_dqn_batch_grasping.py b/examples/grasping/train_dqn_batch_grasping.py index 0cff325ad..e4fa96024 100644 --- a/examples/grasping/train_dqn_batch_grasping.py +++ b/examples/grasping/train_dqn_batch_grasping.py @@ -87,7 +87,7 @@ def reset(self): pybullet.STATE_LOGGING_VIDEO_MP4, os.path.join(self._dirname, "{}.mp4".format(self._episode_idx)), ) - return obs + return obs, {} class GraspingQFunction(nn.Module): diff --git a/pfrl/envs/abc.py b/pfrl/envs/abc.py index 7322fdc8d..664453465 100644 --- a/pfrl/envs/abc.py +++ b/pfrl/envs/abc.py @@ -123,7 +123,7 @@ def reset(self): self._offset = np.random.randint(self.n_max_offset + 1) else: self._offset = 0 - return self.observe() + return self.observe(), {} def step(self, action): if isinstance(self.action_space, spaces.Box): diff --git a/pfrl/envs/multiprocess_vector_env.py b/pfrl/envs/multiprocess_vector_env.py index 9e1a6aef2..3aba6239b 100644 --- a/pfrl/envs/multiprocess_vector_env.py +++ b/pfrl/envs/multiprocess_vector_env.py @@ -99,7 +99,7 @@ def reset(self, mask=None): for m, remote, o in zip(mask, self.remotes, self.last_obs) ] self.last_obs = obs - return obs + return obs, {} def close(self): self._assert_not_closed() diff --git a/pfrl/envs/serial_vector_env.py b/pfrl/envs/serial_vector_env.py index b5a61e1a2..7c2416fe6 100644 --- a/pfrl/envs/serial_vector_env.py +++ b/pfrl/envs/serial_vector_env.py @@ -33,7 +33,7 @@ def reset(self, mask=None): for m, env, o in zip(mask, self.envs, self.last_obs) ] self.last_obs = obs - return obs + return obs, {} def seed(self, seeds): for env, seed in zip(self.envs, seeds): diff --git a/pfrl/wrappers/atari_wrappers.py b/pfrl/wrappers/atari_wrappers.py index 6e00db7c7..d231cc2ab 100644 --- a/pfrl/wrappers/atari_wrappers.py +++ b/pfrl/wrappers/atari_wrappers.py @@ -11,6 +11,7 @@ import pfrl + try: import cv2 @@ -48,7 +49,7 @@ def reset(self, **kwargs): obs, _, done, truncated, info = self.env.step(self.noop_action) if done or info.get("needs_reset", False) or truncated: obs = self.env.reset(**kwargs) - return obs + return obs, {} def step(self, ac): return self.env.step(ac) @@ -69,7 +70,7 @@ def reset(self, **kwargs): obs, _, done, info = self.env.step(2) if done or info.get("needs_reset", False) or truncated: self.env.reset(**kwargs) - return obs + return obs, {} def step(self, ac): return self.env.step(ac) @@ -112,7 +113,7 @@ def reset(self, **kwargs): # no-op step to advance from terminal/lost life state obs, _, _, _, _ = self.env.step(0) self.lives = self.env.unwrapped.ale.lives() - return obs + 
return obs, {} class MaxAndSkipEnv(gymnasium.Wrapper): @@ -178,6 +179,7 @@ def __init__(self, env, channel_order="hwc"): ) def observation(self, frame): + set_trace() frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY) frame = cv2.resize( frame, (self.width, self.height), interpolation=cv2.INTER_AREA @@ -207,10 +209,10 @@ def __init__(self, env, k, channel_order="hwc"): ) def reset(self): - ob = self.env.reset() + ob, _ = self.env.reset() for _ in range(self.k): self.frames.append(ob) - return self._get_ob() + return self._get_ob(), {} def step(self, action): ob, reward, done, truncated, info = self.env.step(action) diff --git a/pfrl/wrappers/render.py b/pfrl/wrappers/render.py index dbd54de26..83dede7aa 100644 --- a/pfrl/wrappers/render.py +++ b/pfrl/wrappers/render.py @@ -16,7 +16,7 @@ def __init__(self, env, **kwargs): def reset(self, **kwargs): ret = self.env.reset(**kwargs) self.env.render(**self._kwargs) - return ret + return ret, {} def step(self, action): ret = self.env.step(action) diff --git a/pfrl/wrappers/vector_frame_stack.py b/pfrl/wrappers/vector_frame_stack.py index 6b3626caf..1165b7e18 100644 --- a/pfrl/wrappers/vector_frame_stack.py +++ b/pfrl/wrappers/vector_frame_stack.py @@ -88,7 +88,7 @@ def reset(self, mask=None): if not m: for _ in range(self.k): frames.append(ob) - return self._get_ob() + return self._get_ob(), {} def step(self, action): batch_ob, reward, terminated, _, info = self.env.step(action) diff --git a/tests/wrappers_tests/test_randomize_action.py b/tests/wrappers_tests/test_randomize_action.py index 8b48f72d1..d0c204a70 100644 --- a/tests/wrappers_tests/test_randomize_action.py +++ b/tests/wrappers_tests/test_randomize_action.py @@ -15,7 +15,7 @@ def __init__(self): self.past_actions = [] def reset(self): - return self.observation_space.sample() + return self.observation_space.sample(), {} def step(self, action): self.past_actions.append(action) diff --git a/tests/wrappers_tests/test_render.py b/tests/wrappers_tests/test_render.py index 64c347370..4e555b37a 100644 --- a/tests/wrappers_tests/test_render.py +++ b/tests/wrappers_tests/test_render.py @@ -21,39 +21,41 @@ def test_render(render_kwargs): ("state", 3), ] orig_env.step.side_effect = [ - (("state", 1), 0, False, {}), - (("state", 2), 1, True, {}), + (("state", 1), 0, False, False, {}), + (("state", 2), 1, True, False, {}), ] env = pfrl.wrappers.Render(orig_env, **render_kwargs) # Not called env.render yet assert orig_env.render.call_count == 0 - obs = env.reset() + obs, _ = env.reset() assert obs == ("state", 0) # Called once assert orig_env.render.call_count == 1 - obs, reward, done, info = env.step(0) + obs, reward, terminated, truncated, info = env.step(0) assert obs == ("state", 1) assert reward == 0 - assert not done + assert not terminated + assert not truncated assert info == {} # Called twice assert orig_env.render.call_count == 2 - obs, reward, done, info = env.step(0) + obs, reward, terminated, truncated, info = env.step(0) assert obs == ("state", 2) assert reward == 1 - assert done + assert terminated + assert not truncated assert info == {} # Called thrice assert orig_env.render.call_count == 3 - obs = env.reset() + obs, _ = env.reset() assert obs == ("state", 3) # Called four times diff --git a/tests/wrappers_tests/test_vector_frame_stack.py b/tests/wrappers_tests/test_vector_frame_stack.py index dac808edd..6e5919ca1 100644 --- a/tests/wrappers_tests/test_vector_frame_stack.py +++ b/tests/wrappers_tests/test_vector_frame_stack.py @@ -74,8 +74,8 @@ def make_env(idx): assert 
fs_env.action_space == vfs_env.action_space assert fs_env.observation_space == vfs_env.observation_space - fs_obs = fs_env.reset() - vfs_obs = vfs_env.reset() + fs_obs, _ = fs_env.reset() + vfs_obs, _ = vfs_env.reset() # Same LazyFrames observations for env_idx in range(num_envs): From 85fe46fc6c21439b14cefb2de7a2b276c748146f Mon Sep 17 00:00:00 2001 From: Prabhat Date: Sun, 16 Apr 2023 01:14:18 -0500 Subject: [PATCH 06/26] Fixes some observations, and uses new Gym AtariEnv properly --- pfrl/experiments/train_agent.py | 4 ++-- pfrl/wrappers/atari_wrappers.py | 22 +++++++++++----------- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/pfrl/experiments/train_agent.py b/pfrl/experiments/train_agent.py index 341ebefff..a2e903f6e 100644 --- a/pfrl/experiments/train_agent.py +++ b/pfrl/experiments/train_agent.py @@ -42,7 +42,7 @@ def train_agent( episode_idx = 0 # o_0, r_0 - obs = env.reset() + obs , info = env.reset() t = step_offset if hasattr(agent, "t"): @@ -98,7 +98,7 @@ def train_agent( # Start a new episode episode_r = 0 episode_len = 0 - obs = env.reset() + obs, info = env.reset() if checkpoint_freq and t % checkpoint_freq == 0: save_agent(agent, t, outdir, logger, suffix="_checkpoint") diff --git a/pfrl/wrappers/atari_wrappers.py b/pfrl/wrappers/atari_wrappers.py index d231cc2ab..bced89722 100644 --- a/pfrl/wrappers/atari_wrappers.py +++ b/pfrl/wrappers/atari_wrappers.py @@ -48,8 +48,8 @@ def reset(self, **kwargs): for _ in range(noops): obs, _, done, truncated, info = self.env.step(self.noop_action) if done or info.get("needs_reset", False) or truncated: - obs = self.env.reset(**kwargs) - return obs, {} + obs, info = self.env.reset(**kwargs) + return obs, info def step(self, ac): return self.env.step(ac) @@ -108,12 +108,12 @@ def reset(self, **kwargs): and the learner need not know about any of this behind-the-scenes. 
""" if self.needs_real_reset: - obs = self.env.reset(**kwargs) + obs, info = self.env.reset(**kwargs) else: # no-op step to advance from terminal/lost life state - obs, _, _, _, _ = self.env.step(0) + obs, _, _, _, info = self.env.step(0) self.lives = self.env.unwrapped.ale.lives() - return obs, {} + return obs, info class MaxAndSkipEnv(gymnasium.Wrapper): @@ -179,7 +179,6 @@ def __init__(self, env, channel_order="hwc"): ) def observation(self, frame): - set_trace() frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY) frame = cv2.resize( frame, (self.width, self.height), interpolation=cv2.INTER_AREA @@ -209,10 +208,10 @@ def __init__(self, env, k, channel_order="hwc"): ) def reset(self): - ob, _ = self.env.reset() + ob, info = self.env.reset() for _ in range(self.k): self.frames.append(ob) - return self._get_ob(), {} + return self._get_ob(), info def step(self, action): ob, reward, done, truncated, info = self.env.step(action) @@ -288,10 +287,11 @@ def observation(self, observation): def make_atari(env_id, max_frames=30 * 60 * 60): - env = gymnasium.make(env_id) + env = gymnasium.make(env_id, + repeat_action_probability=0.0, + full_action_space=False, frameskip=1, + max_num_frames_per_episode=max_frames) assert "NoFrameskip" in env.spec.id - if max_frames: - env = pfrl.wrappers.ContinuingTimeLimit(env, max_episode_steps=max_frames) env = NoopResetEnv(env, noop_max=30) env = MaxAndSkipEnv(env, skip=4) return env From c7d62f7514691cf2cf0f767bc37360b5dde0f777 Mon Sep 17 00:00:00 2001 From: Prabhat Date: Sun, 16 Apr 2023 01:18:06 -0500 Subject: [PATCH 07/26] makes some evaluator updates --- pfrl/experiments/evaluator.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pfrl/experiments/evaluator.py b/pfrl/experiments/evaluator.py index bce3f7f55..2afbb9caa 100644 --- a/pfrl/experiments/evaluator.py +++ b/pfrl/experiments/evaluator.py @@ -29,8 +29,8 @@ def _run_episodes( reset = True while not terminate: if reset: - obs = env.reset() - done = False + obs, info = env.reset() + terminated = False test_r = 0 episode_len = 0 info = {} @@ -120,7 +120,7 @@ def _batch_run_episodes( episode_r = np.zeros(num_envs, dtype=np.float64) episode_len = np.zeros(num_envs, dtype="i") - obss = env.reset() + obss, infos = env.reset() rs = np.zeros(num_envs, dtype="f") termination_conditions = False @@ -199,7 +199,7 @@ def _batch_run_episodes( resets.fill(True) # Agent observes the consequences. 
- agent.batch_observe(obss, rs, dones, resets) + agent.batch_observe(obss, rs, terminations, resets) if termination_conditions: break From b51ae32670f376dfeb3a29a17e5393d1408c05d0 Mon Sep 17 00:00:00 2001 From: Prabhat Date: Sun, 16 Apr 2023 03:05:59 -0500 Subject: [PATCH 08/26] Gets evaluations working by modifying RandomizeAction class --- pfrl/experiments/evaluator.py | 9 ++++----- pfrl/wrappers/atari_wrappers.py | 8 ++++---- pfrl/wrappers/randomize_action.py | 11 +++++++---- 3 files changed, 15 insertions(+), 13 deletions(-) diff --git a/pfrl/experiments/evaluator.py b/pfrl/experiments/evaluator.py index 2afbb9caa..b02f8a52e 100644 --- a/pfrl/experiments/evaluator.py +++ b/pfrl/experiments/evaluator.py @@ -8,7 +8,6 @@ import pfrl - def _run_episodes( env, agent, @@ -23,11 +22,11 @@ def _run_episodes( logger = logger or logging.getLogger(__name__) scores = [] lengths = [] - terminate = False + terminated = False timestep = 0 reset = True - while not terminate: + while not terminated: if reset: obs, info = env.reset() terminated = False @@ -50,9 +49,9 @@ def _run_episodes( scores.append(float(test_r)) lengths.append(float(episode_len)) if n_steps is None: - terminate = len(scores) >= n_episodes + terminated = len(scores) >= n_episodes else: - terminate = timestep >= n_steps + terminated = timestep >= n_steps # If all steps were used for a single unfinished episode if len(scores) == 0: scores.append(float(test_r)) diff --git a/pfrl/wrappers/atari_wrappers.py b/pfrl/wrappers/atari_wrappers.py index bced89722..11a533a89 100644 --- a/pfrl/wrappers/atari_wrappers.py +++ b/pfrl/wrappers/atari_wrappers.py @@ -87,8 +87,8 @@ def __init__(self, env): self.needs_real_reset = True def step(self, action): - obs, reward, done, truncated, info = self.env.step(action) - self.needs_real_reset = done or info.get("needs_reset", False) + obs, reward, terminated, truncated, info = self.env.step(action) + self.needs_real_reset = terminated or info.get("needs_reset", False) or truncated # check current lives, make loss of life terminal, # then update lives to handle bonus lives lives = self.env.unwrapped.ale.lives() @@ -97,9 +97,9 @@ def step(self, action): # frames # so its important to keep lives > 0, so that we only reset once # the environment advertises done. - done = True + terminated = True self.lives = lives - return obs, reward, done, truncated, info + return obs, reward, terminated, truncated, info def reset(self, **kwargs): """Reset only when lives are exhausted. 
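The evaluation-loop contract this commit targets can be illustrated with a rough, self-contained sketch (assumptions, not code from this patch: the helper name, the random policy, CartPole-v1, and the 200-step budget):

import gymnasium


def run_one_episode(env, max_episode_len=200):
    # An episode ends on termination, truncation, an explicit needs_reset
    # request from a wrapper, or an external step budget, mirroring the
    # reset condition used by the evaluator above.
    obs, info = env.reset()
    episode_return, episode_len = 0.0, 0
    while True:
        obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
        episode_return += reward
        episode_len += 1
        if terminated or truncated or info.get("needs_reset", False) or episode_len >= max_episode_len:
            return episode_return, episode_len


print(run_one_episode(gymnasium.make("CartPole-v1")))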
diff --git a/pfrl/wrappers/randomize_action.py b/pfrl/wrappers/randomize_action.py index 226ff932d..8d848d15b 100644 --- a/pfrl/wrappers/randomize_action.py +++ b/pfrl/wrappers/randomize_action.py @@ -27,14 +27,17 @@ def __init__(self, env, random_fraction): env.action_space, gymnasium.spaces.Discrete ), "RandomizeAction supports only gymnasium.spaces.Discrete as an action space" self._random_fraction = random_fraction - self.unwrapped._np_random = np.random.RandomState() + self._rng = np.random.RandomState() def action(self, action): - if self.unwrapped._np_random.rand() < self._random_fraction: - return self.unwrapped._np_random.randint(self.env.action_space.n) + if self._rng.rand() < self._random_fraction: + return self._rng.randint(self.env.action_space.n) else: return action + def reset(self, **kwargs): + return self.env.reset(**kwargs) + def seed(self, seed): super().seed(seed) - self.unwrapped._np_random.seed(seed) + self._rng.seed(seed) From ffdc311289b59a3432b517c9f29d1d50ed988b5a Mon Sep 17 00:00:00 2001 From: Prabhat Date: Mon, 17 Apr 2023 11:15:31 -0700 Subject: [PATCH 09/26] fixes setup --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index d1910e0b9..84ca77dba 100644 --- a/setup.py +++ b/setup.py @@ -4,7 +4,7 @@ install_requires = [ 'torch>=1.3.0', - 'gymnasiumnasium[all]', + 'gymnasium[all]', 'numpy>=1.11.0', 'pillow', 'filelock', From 07c464bf98fee9a100c80ddf3bb0ae6dbb1ecea8 Mon Sep 17 00:00:00 2001 From: Prabhat Date: Tue, 18 Apr 2023 00:27:59 -0700 Subject: [PATCH 10/26] Adds a generic GymWrapper --- pfrl/wrappers/gym_wrapper.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) create mode 100644 pfrl/wrappers/gym_wrapper.py diff --git a/pfrl/wrappers/gym_wrapper.py b/pfrl/wrappers/gym_wrapper.py new file mode 100644 index 000000000..728cb40c6 --- /dev/null +++ b/pfrl/wrappers/gym_wrapper.py @@ -0,0 +1,16 @@ +import gymnasium + + +class GymWrapper(gymnasium.Env): + def __init__(self, gym_env): + """A Gymnasium environment that wraps OpenAI gym environments.""" + super(GymWrapper, self).__init__() + self.env = gym_env + + def reset(self, **kwargs): + obs = self.env.reset() + return obs, {} + + def step(self, action): + obs, reward, done, info = self.env.step(action) + return obs, reward, done, False, info From 675c9785fd28f4e44b26b735ac452316225b47e2 Mon Sep 17 00:00:00 2001 From: Prabhat Date: Sat, 22 Apr 2023 18:51:29 -0700 Subject: [PATCH 11/26] Shifts Pendulum version in example to v1 since v0 is deprecated --- examples/gym/train_dqn_gym.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/examples/gym/train_dqn_gym.py b/examples/gym/train_dqn_gym.py index d39385cf2..60b5ec932 100644 --- a/examples/gym/train_dqn_gym.py +++ b/examples/gym/train_dqn_gym.py @@ -4,11 +4,11 @@ Both discrete and continuous action spaces are supported. For continuous action spaces, A NAF (Normalized Advantage Function) is used to approximate Q-values. -To solve CartPole-v0, run: - python train_dqn_gymnasium.py --env CartPole-v0 +To solve CartPole-v1, run: + python train_dqn_gymnasium.py --env CartPole-v1 -To solve Pendulum-v0, run: - python train_dqn_gymnasium.py --env Pendulum-v0 +To solve Pendulum-v1, run: + python train_dqn_gymnasium.py --env Pendulum-v1 """ import argparse @@ -42,7 +42,7 @@ def main(): " If it does not exist, it will be created." 
), ) - parser.add_argument("--env", type=str, default="Pendulum-v0") + parser.add_argument("--env", type=str, default="Pendulum-v1") parser.add_argument("--seed", type=int, default=0, help="Random seed [0, 2 ** 32)") parser.add_argument("--gpu", type=int, default=0) parser.add_argument("--final-exploration-steps", type=int, default=10**4) From 85c38e15867ad6cc2e49fb276343dbe659c2d714 Mon Sep 17 00:00:00 2001 From: Prabhat Date: Thu, 27 Apr 2023 02:30:46 -0700 Subject: [PATCH 12/26] Adds Q value computation to DDQN (and by extension DDQN) --- pfrl/agents/dqn.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/pfrl/agents/dqn.py b/pfrl/agents/dqn.py index 740fcc2fa..fddc0de43 100644 --- a/pfrl/agents/dqn.py +++ b/pfrl/agents/dqn.py @@ -33,6 +33,7 @@ recurrent_state_as_numpy, ) +from pdb import set_trace def _mean_or_nan(xs: Sequence[float]) -> float: """Return its mean a non-empty sequence, numpy.nan for a empty one.""" @@ -485,6 +486,13 @@ def _evaluate_model_and_update_recurrent_states( batch_av = self.model(batch_xs) return batch_av + def compute_q(self, batch_obs: Sequence[Any], batch_action: Sequence[Any]) -> Sequence[Any]: + with torch.no_grad(), evaluating(self.model): + batch_av = self._evaluate_model_and_update_recurrent_states(batch_obs) + q_values = batch_av.q_values + batch_q_values = q_values[torch.arange(q_values.shape[0]), batch_action] + return batch_q_values + def batch_act(self, batch_obs: Sequence[Any]) -> Sequence[Any]: with torch.no_grad(), evaluating(self.model): batch_av = self._evaluate_model_and_update_recurrent_states(batch_obs) From 02048ae25a36fd11a34950fdd8527f1e6de1eb10 Mon Sep 17 00:00:00 2001 From: Prabhat Date: Tue, 9 May 2023 21:41:35 -0700 Subject: [PATCH 13/26] removes filelock from setup --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 84ca77dba..9a55b9dca 100644 --- a/setup.py +++ b/setup.py @@ -7,7 +7,7 @@ 'gymnasium[all]', 'numpy>=1.11.0', 'pillow', - 'filelock', + # 'filelock', ] test_requires = [ From 98a4efcb48f3f4985cf46e11d003f42253a26392 Mon Sep 17 00:00:00 2001 From: Prabhat Date: Tue, 9 May 2023 21:55:12 -0700 Subject: [PATCH 14/26] removes all required items --- setup.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/setup.py b/setup.py index 9a55b9dca..05362956d 100644 --- a/setup.py +++ b/setup.py @@ -3,10 +3,10 @@ from setuptools import setup install_requires = [ - 'torch>=1.3.0', - 'gymnasium[all]', - 'numpy>=1.11.0', - 'pillow', + # 'torch>=1.3.0', + # 'gymnasium[all]', + # 'numpy>=1.11.0', + # 'pillow', # 'filelock', ] From 481854542ba44f02e44c630b809ccd08237cbf5c Mon Sep 17 00:00:00 2001 From: Prabhat Date: Sat, 24 Jun 2023 01:51:52 -0600 Subject: [PATCH 15/26] fixes setup --- setup.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/setup.py b/setup.py index 05362956d..84ca77dba 100644 --- a/setup.py +++ b/setup.py @@ -3,11 +3,11 @@ from setuptools import setup install_requires = [ - # 'torch>=1.3.0', - # 'gymnasium[all]', - # 'numpy>=1.11.0', - # 'pillow', - # 'filelock', + 'torch>=1.3.0', + 'gymnasium[all]', + 'numpy>=1.11.0', + 'pillow', + 'filelock', ] test_requires = [ From 3f98ef79981cfdaa190ee4689edeb3c55580ac57 Mon Sep 17 00:00:00 2001 From: Prabhat Nagarajan Date: Thu, 28 Dec 2023 00:17:22 -0600 Subject: [PATCH 16/26] does gymnasium all to gymnasium atari --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 8bb833cb9..0dcb7c8d5 100644 --- a/setup.py +++ 
b/setup.py @@ -4,7 +4,7 @@ install_requires = [ 'torch>=1.3.0', - 'gymnasium[all]', + 'gymnasium[atari]', 'numpy>=1.11.0', 'pillow', 'filelock', From b14560915f5c8a520ee352218d7fffbeebe45454 Mon Sep 17 00:00:00 2001 From: Prabhat Nagarajan Date: Mon, 1 Apr 2024 21:28:27 -0600 Subject: [PATCH 17/26] Fixes multiprocessvector_env step --- pfrl/envs/multiprocess_vector_env.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pfrl/envs/multiprocess_vector_env.py b/pfrl/envs/multiprocess_vector_env.py index 3aba6239b..94ebc2cc3 100644 --- a/pfrl/envs/multiprocess_vector_env.py +++ b/pfrl/envs/multiprocess_vector_env.py @@ -83,8 +83,8 @@ def step(self, actions): for remote, action in zip(self.remotes, actions): remote.send(("step", action)) results = [remote.recv() for remote in self.remotes] - self.last_obs, rews, dones, infos = zip(*results) - return self.last_obs, rews, dones, infos + self.last_obs, rews, terminateds, truncateds, infos = zip(*results) + return self.last_obs, rews, terminateds, truncateds, infos def reset(self, mask=None): self._assert_not_closed() From 0c770b59c303ff5f8b48123c5235af8b26271be7 Mon Sep 17 00:00:00 2001 From: Brett Daley Date: Tue, 2 Apr 2024 00:22:58 -0400 Subject: [PATCH 18/26] Multiprocess fixes --- pfrl/envs/multiprocess_vector_env.py | 11 ++++++----- pfrl/experiments/evaluator.py | 2 +- pfrl/experiments/train_agent_async.py | 4 ++-- pfrl/experiments/train_agent_batch.py | 4 ++-- 4 files changed, 11 insertions(+), 10 deletions(-) diff --git a/pfrl/envs/multiprocess_vector_env.py b/pfrl/envs/multiprocess_vector_env.py index 94ebc2cc3..540552f17 100644 --- a/pfrl/envs/multiprocess_vector_env.py +++ b/pfrl/envs/multiprocess_vector_env.py @@ -19,8 +19,8 @@ def worker(remote, env_fn): ob, reward, terminated, truncated, info = env.step(data) remote.send((ob, reward, terminated, truncated, info)) elif cmd == "reset": - ob = env.reset() - remote.send(ob) + ob, info = env.reset() + remote.send((ob, info)) elif cmd == "close": remote.close() break @@ -94,12 +94,13 @@ def reset(self, mask=None): if not m: remote.send(("reset", None)) - obs = [ - remote.recv() if not m else o + results = [ + remote.recv() if not m else (o, {}) for m, remote, o in zip(mask, self.remotes, self.last_obs) ] + obs, info = zip(*results) self.last_obs = obs - return obs, {} + return obs, info def close(self): self._assert_not_closed() diff --git a/pfrl/experiments/evaluator.py b/pfrl/experiments/evaluator.py index b02f8a52e..4b0afbede 100644 --- a/pfrl/experiments/evaluator.py +++ b/pfrl/experiments/evaluator.py @@ -203,7 +203,7 @@ def _batch_run_episodes( if termination_conditions: break else: - obss = env.reset(not_end) + obss, infos = env.reset(not_end) for i, (epi_len, epi_ret) in enumerate( zip(eval_episode_lens, eval_episode_returns) diff --git a/pfrl/experiments/train_agent_async.py b/pfrl/experiments/train_agent_async.py index ff19fa0de..9e5971523 100644 --- a/pfrl/experiments/train_agent_async.py +++ b/pfrl/experiments/train_agent_async.py @@ -58,7 +58,7 @@ def save_model(): global_t = 0 local_t = 0 global_episodes = 0 - obs = env.reset() + obs, info = env.reset() episode_len = 0 successful = False @@ -119,7 +119,7 @@ def save_model(): # Start a new episode episode_r = 0 episode_len = 0 - obs = env.reset() + obs, info = env.reset() if process_idx == 0 and exception_event.is_set(): logger.exception("An exception detected, exiting") diff --git a/pfrl/experiments/train_agent_batch.py b/pfrl/experiments/train_agent_batch.py index 5452d9aa9..8826830ef 100644 --- 
a/pfrl/experiments/train_agent_batch.py +++ b/pfrl/experiments/train_agent_batch.py @@ -54,7 +54,7 @@ def train_agent_batch( episode_len = np.zeros(num_envs, dtype="i") # o_0, r_0 - obss = env.reset() + obss, infos = env.reset() t = step_offset if hasattr(agent, "t"): @@ -138,7 +138,7 @@ def train_agent_batch( # Start new episodes if needed episode_r[end] = 0 episode_len[end] = 0 - obss = env.reset(not_end) + obss, infos = env.reset(not_end) except (Exception, KeyboardInterrupt): # Save the current model before being killed From 04c1dd53b03d2856c243dcedbbff535ade56a958 Mon Sep 17 00:00:00 2001 From: Prabhat Nagarajan Date: Tue, 2 Apr 2024 21:06:50 -0600 Subject: [PATCH 19/26] OpenAI -> Farama Foundation --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index a20d8bc34..a1580ca8a 100644 --- a/README.md +++ b/README.md @@ -30,7 +30,7 @@ Refer to [Installation](http://pfrl.readthedocs.io/en/latest/install.html) for m ## Getting started -You can try [PFRL Quickstart Guide](examples/quickstart/quickstart.ipynb) first, or check the [examples](examples) ready for Atari 2600 and Open AI gymnasium. +You can try [PFRL Quickstart Guide](examples/quickstart/quickstart.ipynb) first, or check the [examples](examples) ready for Atari 2600 and Farama Foundation's gymnasium. For more information, you can refer to [PFRL's documentation](http://pfrl.readthedocs.io/en/latest/index.html). @@ -99,7 +99,7 @@ Following useful techniques have been also implemented in PFRL: ## Environments -Environments that support the subset of OpenAI gymnasium's interface (`reset` and `step` methods) can be used. +Environments that support the subset of Farama Foundation's gymnasium's interface (`reset` and `step` methods) can be used. ## Contributing From c408b08c2c6d0fc39249abe3e636e133b9b1c335 Mon Sep 17 00:00:00 2001 From: Prabhat Nagarajan Date: Tue, 2 Apr 2024 23:14:08 -0600 Subject: [PATCH 20/26] Makes modifications for gymnasium imports, etc. 
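
Imports are switched to the conventional alias (import gymnasium as gym), and wrapper code refers to gym.Wrapper / gym.spaces through that alias while still binding to the Gymnasium package. A minimal sketch of the pattern; the wrapper class below is illustrative and not part of this patch:

    import gymnasium as gym

    class IdentityObservation(gym.ObservationWrapper):
        # Illustrative only: the base class and spaces come from gymnasium via the alias.
        def __init__(self, env):
            assert isinstance(env.observation_space, gym.spaces.Box)
            super().__init__(env)

        def observation(self, observation):
            return observation

The alias is purely cosmetic; everything still resolves against the gymnasium package.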
--- examples/optuna/optuna_dqn_obs1d.py | 10 +++--- examples/slimevolley/README.md | 8 ++--- pfrl/agents/dqn.py | 8 ----- pfrl/envs/abc.py | 3 +- pfrl/wrappers/atari_wrappers.py | 44 +++++++++++------------ pfrl/wrappers/cast_observation.py | 4 +-- pfrl/wrappers/continuing_time_limit.py | 13 ++++--- pfrl/wrappers/scale_reward.py | 4 +-- setup.cfg | 2 +- tests/wrappers_tests/test_scale_reward.py | 4 +-- 10 files changed, 46 insertions(+), 54 deletions(-) diff --git a/examples/optuna/optuna_dqn_obs1d.py b/examples/optuna/optuna_dqn_obs1d.py index dbe9ec741..c1cd44011 100644 --- a/examples/optuna/optuna_dqn_obs1d.py +++ b/examples/optuna/optuna_dqn_obs1d.py @@ -14,7 +14,7 @@ import os import random -import gymnasium +import gymnasium as gym import torch.optim as optim try: @@ -54,9 +54,9 @@ def _objective_core( test_seed = 2**31 - 1 - seed def make_env(test=False): - env = gymnasium.make(env_id) + env = gym.make(env_id) - if not isinstance(env.observation_space, gymnasium.spaces.Box): + if not isinstance(env.observation_space, gym.spaces.Box): raise ValueError( "Supported only Box observation environments, but given: {}".format( env.observation_space @@ -68,7 +68,7 @@ def make_env(test=False): env.observation_space.shape ) ) - if not isinstance(env.action_space, gymnasium.spaces.Discrete): + if not isinstance(env.action_space, gym.spaces.Discrete): raise ValueError( "Supported only discrete action environments, but given: {}".format( env.action_space @@ -244,7 +244,7 @@ def main(): "--env", type=str, default="LunarLander-v2", - help="OpenAI gymnasium Environment ID.", + help="OpenAI gym Environment ID.", ) parser.add_argument( "--outdir", diff --git a/examples/slimevolley/README.md b/examples/slimevolley/README.md index 7eb0afc43..b70b73032 100644 --- a/examples/slimevolley/README.md +++ b/examples/slimevolley/README.md @@ -1,6 +1,6 @@ # Slime Volleyball -This directory contains an example script that learns to play Slime Volleyball using the environment `SlimeVolley-v0` of [slimevolleygymnasium](https://github.com/hardmaru/slimevolleygymnasium). +This directory contains an example script that learns to play Slime Volleyball using the environment `SlimeVolley-v0` of [slimevolleygym](https://github.com/hardmaru/slimevolleygym). 
![SlimeVolley](assets/slimevolley.gif) @@ -10,8 +10,8 @@ This directory contains an example script that learns to play Slime Volleyball u ## Requirements -- `slimevolleygymnasium` (https://github.com/hardmaru/slimevolleygymnasium) - - You can install from PyPI: `pip install slimevolleygymnasium==0.1.0` +- `slimevolleygym` (https://github.com/hardmaru/slimevolleygym) + - You can install from PyPI: `pip install slimevolleygym==0.1.0` ## Algorithm @@ -37,7 +37,7 @@ python examples/slimevolley/train_rainbow.py --demo --render --load float: """Return its mean a non-empty sequence, numpy.nan for a empty one.""" @@ -488,13 +487,6 @@ def _evaluate_model_and_update_recurrent_states( batch_av = self.model(batch_xs) return batch_av - def compute_q(self, batch_obs: Sequence[Any], batch_action: Sequence[Any]) -> Sequence[Any]: - with torch.no_grad(), evaluating(self.model): - batch_av = self._evaluate_model_and_update_recurrent_states(batch_obs) - q_values = batch_av.q_values - batch_q_values = q_values[torch.arange(q_values.shape[0]), batch_action] - return batch_q_values - def batch_act(self, batch_obs: Sequence[Any]) -> Sequence[Any]: with torch.no_grad(), evaluating(self.model): batch_av = self._evaluate_model_and_update_recurrent_states(batch_obs) diff --git a/pfrl/envs/abc.py b/pfrl/envs/abc.py index 664453465..74ad9c514 100644 --- a/pfrl/envs/abc.py +++ b/pfrl/envs/abc.py @@ -1,5 +1,6 @@ import numpy as np -from gymnasium import spaces +import gymnasium as gym +from gym import spaces from pfrl import env diff --git a/pfrl/wrappers/atari_wrappers.py b/pfrl/wrappers/atari_wrappers.py index 11a533a89..d54e06cc1 100644 --- a/pfrl/wrappers/atari_wrappers.py +++ b/pfrl/wrappers/atari_wrappers.py @@ -4,9 +4,9 @@ from collections import deque -import gymnasium +import gymnasium as gym import numpy as np -from gymnasium import spaces +from gym import spaces from packaging import version import pfrl @@ -21,13 +21,13 @@ _is_cv2_available = False -class NoopResetEnv(gymnasium.Wrapper): +class NoopResetEnv(gym.Wrapper): def __init__(self, env, noop_max=30): """Sample initial states by taking random number of no-ops on reset. No-op is assumed to be action 0. """ - gymnasium.Wrapper.__init__(self, env) + gym.Wrapper.__init__(self, env) self.noop_max = noop_max self.override_num_noops = None self.noop_action = 0 @@ -39,7 +39,7 @@ def reset(self, **kwargs): if self.override_num_noops is not None: noops = self.override_num_noops else: - if version.parse(gymnasium.__version__) >= version.parse("0.24.0"): + if version.parse(gym.__version__) >= version.parse("0.24.0"): noops = self.unwrapped.np_random.integers(1, self.noop_max + 1) else: noops = self.unwrapped.np_random.randint(1, self.noop_max + 1) @@ -55,10 +55,10 @@ def step(self, ac): return self.env.step(ac) -class FireResetEnv(gymnasium.Wrapper): +class FireResetEnv(gym.Wrapper): def __init__(self, env): """Take action on reset for envs that are fixed until firing.""" - gymnasium.Wrapper.__init__(self, env) + gym.Wrapper.__init__(self, env) assert env.unwrapped.get_action_meanings()[1] == "FIRE" assert len(env.unwrapped.get_action_meanings()) >= 3 @@ -76,13 +76,13 @@ def step(self, ac): return self.env.step(ac) -class EpisodicLifeEnv(gymnasium.Wrapper): +class EpisodicLifeEnv(gym.Wrapper): def __init__(self, env): """Make end-of-life == end-of-episode, but only reset on true game end. Done by DeepMind for the DQN and co. since it helps value estimation. 
""" - gymnasium.Wrapper.__init__(self, env) + gym.Wrapper.__init__(self, env) self.lives = 0 self.needs_real_reset = True @@ -116,10 +116,10 @@ def reset(self, **kwargs): return obs, info -class MaxAndSkipEnv(gymnasium.Wrapper): +class MaxAndSkipEnv(gym.Wrapper): def __init__(self, env, skip=4): """Return only every `skip`-th frame""" - gymnasium.Wrapper.__init__(self, env) + gym.Wrapper.__init__(self, env) # most recent raw observations (for max pooling across time steps) self._obs_buffer = np.zeros((2,) + env.observation_space.shape, dtype=np.uint8) self._skip = skip @@ -147,16 +147,16 @@ def reset(self, **kwargs): return self.env.reset(**kwargs) -class ClipRewardEnv(gymnasium.RewardWrapper): +class ClipRewardEnv(gym.RewardWrapper): def __init__(self, env): - gymnasium.RewardWrapper.__init__(self, env) + gym.RewardWrapper.__init__(self, env) def reward(self, reward): """Bin reward to {+1, 0, -1} by its sign.""" return np.sign(reward) -class WarpFrame(gymnasium.ObservationWrapper): +class WarpFrame(gym.ObservationWrapper): def __init__(self, env, channel_order="hwc"): """Warp frames to 84x84 as done in the Nature paper and later work. @@ -167,7 +167,7 @@ def __init__(self, env, channel_order="hwc"): "Cannot import cv2 module. Please install OpenCV-Python to use" " WarpFrame." ) - gymnasium.ObservationWrapper.__init__(self, env) + gym.ObservationWrapper.__init__(self, env) self.width = 84 self.height = 84 shape = { @@ -186,7 +186,7 @@ def observation(self, frame): return frame.reshape(self.observation_space.low.shape) -class FrameStack(gymnasium.Wrapper): +class FrameStack(gym.Wrapper): def __init__(self, env, k, channel_order="hwc"): """Stack k last frames. @@ -196,7 +196,7 @@ def __init__(self, env, k, channel_order="hwc"): -------- baselines.common.atari_wrappers.LazyFrames """ - gymnasium.Wrapper.__init__(self, env) + gym.Wrapper.__init__(self, env) self.k = k self.frames = deque([], maxlen=k) self.stack_axis = {"hwc": 2, "chw": 0}[channel_order] @@ -223,7 +223,7 @@ def _get_ob(self): return LazyFrames(list(self.frames), stack_axis=self.stack_axis) -class ScaledFloatFrame(gymnasium.ObservationWrapper): +class ScaledFloatFrame(gym.ObservationWrapper): """Divide frame values by 255.0 and return them as np.float32. 
Especially, when the original env.observation_space is np.uint8, @@ -232,7 +232,7 @@ class ScaledFloatFrame(gymnasium.ObservationWrapper): def __init__(self, env): assert isinstance(env.observation_space, spaces.Box) - gymnasium.ObservationWrapper.__init__(self, env) + gym.ObservationWrapper.__init__(self, env) self.scale = 255.0 @@ -273,11 +273,11 @@ def __array__(self, dtype=None): return out -class FlickerFrame(gymnasium.ObservationWrapper): +class FlickerFrame(gym.ObservationWrapper): """Stochastically flicker frames.""" def __init__(self, env): - gymnasium.ObservationWrapper.__init__(self, env) + gym.ObservationWrapper.__init__(self, env) def observation(self, observation): if self.unwrapped.np_random.rand() < 0.5: @@ -287,7 +287,7 @@ def observation(self, observation): def make_atari(env_id, max_frames=30 * 60 * 60): - env = gymnasium.make(env_id, + env = gym.make(env_id, repeat_action_probability=0.0, full_action_space=False, frameskip=1, max_num_frames_per_episode=max_frames) diff --git a/pfrl/wrappers/cast_observation.py b/pfrl/wrappers/cast_observation.py index c3444d3cd..2fc853243 100644 --- a/pfrl/wrappers/cast_observation.py +++ b/pfrl/wrappers/cast_observation.py @@ -1,8 +1,8 @@ -import gymnasium +import gymnasium as gym import numpy as np -class CastObservation(gymnasium.ObservationWrapper): +class CastObservation(gym.ObservationWrapper): """Cast observations to a given type. Args: diff --git a/pfrl/wrappers/continuing_time_limit.py b/pfrl/wrappers/continuing_time_limit.py index 360aa0ac5..805f884b8 100644 --- a/pfrl/wrappers/continuing_time_limit.py +++ b/pfrl/wrappers/continuing_time_limit.py @@ -1,11 +1,11 @@ -import gymnasium +import gymnasium as gym -class ContinuingTimeLimit(gymnasium.Wrapper): +class ContinuingTimeLimit(gym.Wrapper): """TimeLimit wrapper for continuing environments. - This is similar gymnasium.wrappers.TimeLimit, which sets a time limit for - each episode, except that done=False is returned and that + This is similar to gymnasium.wrappers.TimeLimit, which sets a time limit for + each episode, except that truncated=False is returned and that info['needs_reset'] is set to True when past the limit. Code that calls env.step is responsible for checking the info dict, the @@ -21,20 +21,19 @@ class ContinuingTimeLimit(gymnasium.Wrapper): def __init__(self, env, max_episode_steps): super(ContinuingTimeLimit, self).__init__(env) self._max_episode_steps = max_episode_steps - self._elapsed_steps = None def step(self, action): assert ( self._elapsed_steps is not None ), "Cannot call env.step() before calling reset()" - observation, reward, done, _, info = self.env.step(action) + observation, reward, terminated, _, info = self.env.step(action) self._elapsed_steps += 1 if self._max_episode_steps <= self._elapsed_steps: info["needs_reset"] = True - return observation, reward, done, False, info + return observation, reward, terminated, False, info def reset(self): self._elapsed_steps = 0 diff --git a/pfrl/wrappers/scale_reward.py b/pfrl/wrappers/scale_reward.py index 7d309c20e..d34a238d3 100644 --- a/pfrl/wrappers/scale_reward.py +++ b/pfrl/wrappers/scale_reward.py @@ -1,7 +1,7 @@ -import gymnasium +import gymnasium as gym -class ScaleReward(gymnasium.RewardWrapper): +class ScaleReward(gym.RewardWrapper): """Scale reward by a scale factor. 
Args: diff --git a/setup.cfg b/setup.cfg index 8504ffc24..84f1c2234 100644 --- a/setup.cfg +++ b/setup.cfg @@ -33,7 +33,7 @@ ignore_missing_imports = True [mypy-roboschool.*] ignore_missing_imports = True -[mypy-slimevolleygymnasium.*] +[mypy-slimevolleygym.*] ignore_missing_imports = True [mypy-optuna.*] diff --git a/tests/wrappers_tests/test_scale_reward.py b/tests/wrappers_tests/test_scale_reward.py index 0fade3b57..4bb95f720 100644 --- a/tests/wrappers_tests/test_scale_reward.py +++ b/tests/wrappers_tests/test_scale_reward.py @@ -1,4 +1,4 @@ -import gymnasium +import gymnasium as gym import numpy as np import pytest @@ -8,7 +8,7 @@ @pytest.mark.parametrize("env_id", ["CartPole-v1", "MountainCar-v0"]) @pytest.mark.parametrize("scale", [1.0, 0.1]) def test_scale_reward(env_id, scale): - env = pfrl.wrappers.ScaleReward(gymnasium.make(env_id), scale=scale) + env = pfrl.wrappers.ScaleReward(gym.make(env_id), scale=scale) assert env.original_reward is None np.testing.assert_allclose(env.scale, scale) From 03f203f707b2fe885eb0d65bf764eb7b0a41f386 Mon Sep 17 00:00:00 2001 From: Prabhat Nagarajan Date: Wed, 3 Apr 2024 01:32:52 -0600 Subject: [PATCH 21/26] Removes continuing_time_limit now that gymnasium has truncation --- pfrl/wrappers/continuing_time_limit.py | 40 ------------------- .../test_continuing_time_limit.py | 31 -------------- 2 files changed, 71 deletions(-) delete mode 100644 pfrl/wrappers/continuing_time_limit.py delete mode 100644 tests/wrappers_tests/test_continuing_time_limit.py diff --git a/pfrl/wrappers/continuing_time_limit.py b/pfrl/wrappers/continuing_time_limit.py deleted file mode 100644 index 805f884b8..000000000 --- a/pfrl/wrappers/continuing_time_limit.py +++ /dev/null @@ -1,40 +0,0 @@ -import gymnasium as gym - - -class ContinuingTimeLimit(gym.Wrapper): - """TimeLimit wrapper for continuing environments. - - This is similar to gymnasium.wrappers.TimeLimit, which sets a time limit for - each episode, except that truncated=False is returned and that - info['needs_reset'] is set to True when past the limit. - - Code that calls env.step is responsible for checking the info dict, the - fourth returned value, and resetting the env if it has the 'needs_reset' - key and its value is True. - - Args: - env (gymnasium.Env): Env to wrap. - max_episode_steps (int): Maximum number of timesteps during an episode, - after which the env needs a reset. 
- """ - - def __init__(self, env, max_episode_steps): - super(ContinuingTimeLimit, self).__init__(env) - self._max_episode_steps = max_episode_steps - self._elapsed_steps = None - - def step(self, action): - assert ( - self._elapsed_steps is not None - ), "Cannot call env.step() before calling reset()" - observation, reward, terminated, _, info = self.env.step(action) - self._elapsed_steps += 1 - - if self._max_episode_steps <= self._elapsed_steps: - info["needs_reset"] = True - - return observation, reward, terminated, False, info - - def reset(self): - self._elapsed_steps = 0 - return self.env.reset() diff --git a/tests/wrappers_tests/test_continuing_time_limit.py b/tests/wrappers_tests/test_continuing_time_limit.py deleted file mode 100644 index f2bcda759..000000000 --- a/tests/wrappers_tests/test_continuing_time_limit.py +++ /dev/null @@ -1,31 +0,0 @@ -from unittest import mock - -import pytest - -import pfrl - - -@pytest.mark.parametrize("max_episode_steps", [1, 2, 3]) -def test_continuing_time_limit(max_episode_steps): - env = mock.Mock() - env.reset.side_effect = ["state"] * 2 - # Since info dicts are modified by the wapper, each step call needs to - # return a new info dict. - env.step.side_effect = [("state", 0, False, {}) for _ in range(6)] - env = pfrl.wrappers.ContinuingTimeLimit(env, max_episode_steps=max_episode_steps) - - env.reset() - for t in range(2): - _, _, done, truncated, info = env.step(0) - if t + 1 >= max_episode_steps: - assert info["needs_reset"] and truncated - else: - assert not info.get("needs_reset", False) and not truncated - - env.reset() - for t in range(4): - _, _, done, truncated , info = env.step(0) - if t + 1 >= max_episode_steps: - assert info["needs_reset"] and truncated - else: - assert not info.get("needs_reset", False) and not truncated From 561432307b6e96f8d84bd9f7c2b7b80c4be4dafa Mon Sep 17 00:00:00 2001 From: Prabhat Nagarajan Date: Wed, 3 Apr 2024 01:39:01 -0600 Subject: [PATCH 22/26] Remove Monitor --- pfrl/wrappers/monitor.py | 113 --------------------------- tests/wrappers_tests/test_monitor.py | 59 -------------- 2 files changed, 172 deletions(-) delete mode 100644 pfrl/wrappers/monitor.py delete mode 100644 tests/wrappers_tests/test_monitor.py diff --git a/pfrl/wrappers/monitor.py b/pfrl/wrappers/monitor.py deleted file mode 100644 index ad5e1f3ba..000000000 --- a/pfrl/wrappers/monitor.py +++ /dev/null @@ -1,113 +0,0 @@ -import time -from logging import getLogger - -try: - from gymnasium.wrappers import Monitor as _gymnasiumMonitor -except ImportError: - - class _Stub: - def __init__(self, *args, **kwargs): - raise RuntimeError("Monitor is not available in this version of gymnasium") - - class _gymnasiumMonitor(_Stub): # type: ignore - pass - - class _gymnasiumStatsRecorder(_Stub): - pass - -else: - from gymnasium.wrappers.monitoring.stats_recorder import StatsRecorder as _gymnasiumStatsRecorder # type: ignore # isort: skip # noqa: E501 - - -class Monitor(_gymnasiumMonitor): - """`Monitor` with PFRL's `ContinuingTimeLimit` support. - - `Agent` in PFRL might reset the env even when `done=False` - if `ContinuingTimeLimit` returns `info['needs_reset']=True`, - which is not expected for `gymnasium.Monitor`. 
- - For details, see - https://github.com/openai/gymnasium/blob/master/gymnasium/wrappers/monitor.py - """ - - def _start( - self, - directory, - video_callable=None, - force=False, - resume=False, - write_upon_reset=False, - uid=None, - mode=None, - ): - if self.env_semantics_autoreset: - raise NotImplementedError( - "Detect 'semantics.autoreset=True' in `env.metadata`, " - "which means the env is from deprecated OpenAI Universe." - ) - ret = super()._start( - directory=directory, - video_callable=video_callable, - force=force, - resume=resume, - write_upon_reset=write_upon_reset, - uid=uid, - mode=mode, - ) - env_id = self.stats_recorder.env_id - self.stats_recorder = _StatsRecorder( - directory, - "{}.episode_batch.{}".format(self.file_prefix, self.file_infix), - autoreset=False, - env_id=env_id, - ) - if mode is not None: - self._set_mode(mode) - return ret - - -class _StatsRecorder(_gymnasiumStatsRecorder): - """`StatsRecorder` with PFRL's `ContinuingTimeLimit` support. - - For details, see - https://github.com/openai/gymnasium/blob/master/gymnasium/wrappers/monitoring/stats_recorder.py - """ - - def __init__( - self, - directory, - file_prefix, - autoreset=False, - env_id=None, - logger=getLogger(__name__), - ): - super().__init__(directory, file_prefix, autoreset=autoreset, env_id=env_id) - self._save_completed = True - self.logger = logger - - def before_reset(self): - assert not self.closed - - if self.done is not None and not self.done and self.steps > 0: - self.logger.debug( - "Tried to reset the env which is not done=True. " - "StatsRecorder completes the last episode." - ) - self.save_complete() - - self.done = False - if self.initial_reset_timestamp is None: - self.initial_reset_timestamp = time.time() - - def after_step(self, observation, reward, done, info): - self._save_completed = False - return super().after_step(observation, reward, done, info) - - def save_complete(self): - if not self._save_completed: - super().save_complete() - self._save_completed = True - - def close(self): - self.save_complete() - super().close() diff --git a/tests/wrappers_tests/test_monitor.py b/tests/wrappers_tests/test_monitor.py deleted file mode 100644 index 749440c00..000000000 --- a/tests/wrappers_tests/test_monitor.py +++ /dev/null @@ -1,59 +0,0 @@ -import os -import shutil -import tempfile - -import gymnasium -import pytest -from gymnasium.wrappers import TimeLimit - -import pfrl - - -@pytest.mark.parametrize("n_episodes", [1, 2, 3, 4]) -def test_monitor(n_episodes): - steps = 15 - - env = gymnasium.make("CartPole-v1") - # unwrap default TimeLimit and wrap with new one to simulate done=True - # at step 5 - assert isinstance(env, TimeLimit) - env = env.env # unwrap - env = TimeLimit(env, max_episode_steps=5) # wrap - - tmpdir = tempfile.mkdtemp() - try: - env = pfrl.wrappers.Monitor( - env, directory=tmpdir, video_callable=lambda episode_id: True - ) - episode_idx = 0 - episode_len = 0 - t = 0 - _ = env.reset() - while True: - _, _, terminated, truncated, info = env.step(env.action_space.sample()) - episode_len += 1 - t += 1 - if episode_idx == 1 and episode_len >= 3: - info["needs_reset"] = True # simulate ContinuingTimeLimit - if terminated or truncated or info.get("needs_reset", False) or t == steps: - if episode_idx + 1 == n_episodes or t == steps: - break - env.reset() - episode_idx += 1 - episode_len = 0 - # `env.close()` is called when `env` is gabage-collected - # (or explicitly deleted/closed). 
- del env - # check if videos & meta files were generated - files = os.listdir(tmpdir) - mp4s = [f for f in files if f.endswith(".mp4")] - metas = [f for f in files if f.endswith(".meta.json")] - stats = [f for f in files if f.endswith(".stats.json")] - manifests = [f for f in files if f.endswith(".manifest.json")] - assert len(mp4s) == n_episodes - assert len(metas) == n_episodes - assert len(stats) == 1 - assert len(manifests) == 1 - - finally: - shutil.rmtree(tmpdir) From 4b0494ef60303dc6fc3f7f709a47a232f0b52692 Mon Sep 17 00:00:00 2001 From: Prabhat Nagarajan Date: Fri, 5 Apr 2024 20:41:23 -0600 Subject: [PATCH 23/26] Removes things from __init__ --- pfrl/wrappers/__init__.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/pfrl/wrappers/__init__.py b/pfrl/wrappers/__init__.py index 0f3e99258..ae26a4db5 100644 --- a/pfrl/wrappers/__init__.py +++ b/pfrl/wrappers/__init__.py @@ -1,7 +1,5 @@ from pfrl.wrappers.cast_observation import CastObservation # NOQA from pfrl.wrappers.cast_observation import CastObservationToFloat32 # NOQA -from pfrl.wrappers.continuing_time_limit import ContinuingTimeLimit # NOQA -from pfrl.wrappers.monitor import Monitor # NOQA from pfrl.wrappers.normalize_action_space import NormalizeActionSpace # NOQA from pfrl.wrappers.randomize_action import RandomizeAction # NOQA from pfrl.wrappers.render import Render # NOQA From e92cc638f6a13a8b3b2fade9a9a1862f90fd931e Mon Sep 17 00:00:00 2001 From: Prabhat Nagarajan Date: Fri, 5 Apr 2024 22:02:43 -0600 Subject: [PATCH 24/26] Moves gym folder in examples to gymnasium --- examples/{gym => gymnasium}/README.md | 0 examples/{gym => gymnasium}/train_categorical_dqn_gym.py | 0 examples/{gym => gymnasium}/train_dqn_gym.py | 0 examples/{gym => gymnasium}/train_reinforce_gym.py | 0 pfrl/wrappers/atari_wrappers.py | 2 +- 5 files changed, 1 insertion(+), 1 deletion(-) rename examples/{gym => gymnasium}/README.md (100%) rename examples/{gym => gymnasium}/train_categorical_dqn_gym.py (100%) rename examples/{gym => gymnasium}/train_dqn_gym.py (100%) rename examples/{gym => gymnasium}/train_reinforce_gym.py (100%) diff --git a/examples/gym/README.md b/examples/gymnasium/README.md similarity index 100% rename from examples/gym/README.md rename to examples/gymnasium/README.md diff --git a/examples/gym/train_categorical_dqn_gym.py b/examples/gymnasium/train_categorical_dqn_gym.py similarity index 100% rename from examples/gym/train_categorical_dqn_gym.py rename to examples/gymnasium/train_categorical_dqn_gym.py diff --git a/examples/gym/train_dqn_gym.py b/examples/gymnasium/train_dqn_gym.py similarity index 100% rename from examples/gym/train_dqn_gym.py rename to examples/gymnasium/train_dqn_gym.py diff --git a/examples/gym/train_reinforce_gym.py b/examples/gymnasium/train_reinforce_gym.py similarity index 100% rename from examples/gym/train_reinforce_gym.py rename to examples/gymnasium/train_reinforce_gym.py diff --git a/pfrl/wrappers/atari_wrappers.py b/pfrl/wrappers/atari_wrappers.py index d54e06cc1..02a821f7e 100644 --- a/pfrl/wrappers/atari_wrappers.py +++ b/pfrl/wrappers/atari_wrappers.py @@ -6,7 +6,7 @@ import gymnasium as gym import numpy as np -from gym import spaces +from gymnasium import spaces from packaging import version import pfrl From 420dddb22e7036439c3ea21da7cf450b7adf9d8b Mon Sep 17 00:00:00 2001 From: Prabhat Nagarajan Date: Wed, 10 Apr 2024 01:54:52 -0600 Subject: [PATCH 25/26] Fixes some imports and some tests --- pfrl/envs/abc.py | 2 +- tests/experiments_tests/test_evaluator.py | 24 +++++++++++------------ 2 files 
changed, 13 insertions(+), 13 deletions(-) diff --git a/pfrl/envs/abc.py b/pfrl/envs/abc.py index 74ad9c514..53e5591fe 100644 --- a/pfrl/envs/abc.py +++ b/pfrl/envs/abc.py @@ -1,6 +1,6 @@ import numpy as np import gymnasium as gym -from gym import spaces +from gymnasium import spaces from pfrl import env diff --git a/tests/experiments_tests/test_evaluator.py b/tests/experiments_tests/test_evaluator.py index 3148edc1e..2f6f82791 100644 --- a/tests/experiments_tests/test_evaluator.py +++ b/tests/experiments_tests/test_evaluator.py @@ -21,7 +21,7 @@ def test_evaluator_evaluate_if_necessary(save_best_so_far_agent, n_steps, n_epis agent.get_statistics.return_value = [] env = mock.Mock() - env.reset.return_value = "obs" + env.reset.return_value = "obs", {} env.step.return_value = ("obs", 0, True, False, {}) env.get_statistics.return_value = [] @@ -110,7 +110,7 @@ def test_async_evaluator_evaluate_if_necessary(save_best_so_far_agent, n_episode agent.get_statistics.return_value = [] env = mock.Mock() - env.reset.return_value = "obs" + env.reset.return_value = "obs", {} env.step.return_value = ("obs", 0, True, False, {}) env.get_statistics.return_value = [] @@ -179,10 +179,10 @@ def test_run_evaluation_episodes_with_n_steps(n_episodes, n_steps): # Second episode: 4 -> 5 -> 6 -> 7 (done) env.reset.side_effect = [("state", 0), ("state", 4)] env.step.side_effect = [ - (("state", 1), 0.1, False, {}), - (("state", 2), 0.2, False, {}), - (("state", 3), 0.3, False, {"needs_reset": True}), - (("state", 5), -0.5, False, {}), + (("state", 1), 0.1, False, False, {}), + (("state", 2), 0.2, False, False, {}), + (("state", 3), 0.3, False, True, {"needs_reset": True}), + (("state", 5), -0.5, False, False, {}), (("state", 6), 0, False, False, {}), (("state", 7), 1, True, False, {}), ] @@ -226,12 +226,12 @@ def test_needs_reset(self): # Second episode: 4 -> 5 -> 6 -> 7 (done) env.reset.side_effect = [("state", 0), ("state", 4)] env.step.side_effect = [ - (("state", 1), 0, False, {}), - (("state", 2), 0, False, {}), - (("state", 3), 0, False, {"needs_reset": True}), - (("state", 5), -0.5, False, {}), - (("state", 6), 0, False, {}), - (("state", 7), 1, True, {}), + (("state", 1), 0, False, False, {}), + (("state", 2), 0, False, False, {}), + (("state", 3), 0, False, True, {"needs_reset": True}), + (("state", 5), -0.5, False, False, {}), + (("state", 6), 0, False, False, {}), + (("state", 7), 1, True, False, {}), ] scores, lengths = evaluator.run_evaluation_episodes( env, agent, n_steps=None, n_episodes=2 From 74198b9ac4ad4ee4e6da1e7bddd8f2191e92af10 Mon Sep 17 00:00:00 2001 From: Prabhat Nagarajan Date: Mon, 6 May 2024 22:53:36 -0600 Subject: [PATCH 26/26] Fixes Randomize Action Wrapper --- pfrl/wrappers/randomize_action.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/pfrl/wrappers/randomize_action.py b/pfrl/wrappers/randomize_action.py index 8d848d15b..d9485aa8c 100644 --- a/pfrl/wrappers/randomize_action.py +++ b/pfrl/wrappers/randomize_action.py @@ -36,8 +36,7 @@ def action(self, action): return action def reset(self, **kwargs): + if 'seed' in kwargs: + self._rng = np.random.RandomState(kwargs['seed']) return self.env.reset(**kwargs) - def seed(self, seed): - super().seed(seed) - self._rng.seed(seed)
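
With the full series applied, PFRL's environment plumbing follows the Gymnasium API end to end: reset() returns (observation, info) and accepts seed=..., step() returns (obs, reward, terminated, truncated, info), and time limits surface as truncation instead of going through the removed ContinuingTimeLimit and Monitor wrappers. A minimal sketch of the resulting interaction loop, assuming the patched package is installed; the environment id and random_fraction value are illustrative:

    import gymnasium as gym
    from pfrl.wrappers import RandomizeAction

    env = RandomizeAction(gym.make("CartPole-v1"), random_fraction=0.1)
    obs, info = env.reset(seed=0)  # seeding now goes through reset()
    for _ in range(200):
        obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
        if terminated or truncated:
            obs, info = env.reset()
    env.close()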