diff --git a/.coveragerc b/.coveragerc index 511f20d8b..9020116c6 100644 --- a/.coveragerc +++ b/.coveragerc @@ -4,7 +4,7 @@ omit = tests/* setup.py # Require graphical interface - torchy_baselines/common/results_plotter.py + stable_baselines3/common/results_plotter.py [report] exclude_lines = diff --git a/.github/ISSUE_TEMPLATE/issue-template.md b/.github/ISSUE_TEMPLATE/issue-template.md index 2e2e61b6b..f9ff9622f 100644 --- a/.github/ISSUE_TEMPLATE/issue-template.md +++ b/.github/ISSUE_TEMPLATE/issue-template.md @@ -13,7 +13,7 @@ If you are submitting a bug report, please fill in the following details. If your issue is related to a custom gym environment, please check it first using: ```python -from torchy_baselines.common.env_checker import check_env +from stable_baselines3.common.env_checker import check_env env = CustomEnv(arg1, ...) # It will check your custom environment and output additional warnings if needed @@ -30,7 +30,7 @@ Please use the [markdown code blocks](https://help.github.com/en/articles/creati for both code and stack traces. ```python -from torchy_baselines import ... +from stable_baselines3 import ... ``` diff --git a/NOTICE b/NOTICE index 9fc6700ee..6dbbda648 100644 --- a/NOTICE +++ b/NOTICE @@ -1,4 +1,4 @@ -Large portion of the code of Torchy-Baselines (in `common/`) were ported from Stable-Baselines, a fork of OpenAI Baselines, +Large portion of the code of Stable-Baselines3 (in `common/`) were ported from Stable-Baselines, a fork of OpenAI Baselines, both licensed under the MIT License: before the fork (June 2018): diff --git a/README.md b/README.md index 843d55e16..5f35c82e3 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ [![Build Status](https://travis-ci.com/hill-a/stable-baselines.svg?branch=master)](https://travis-ci.com/hill-a/stable-baselines) [![Documentation Status](https://readthedocs.org/projects/stable-baselines/badge/?version=master)](https://stable-baselines.readthedocs.io/en/master/?badge=master) -# Torchy Baselines +# Stable Baselines3 PyTorch version of [Stable Baselines](https://github.com/hill-a/stable-baselines), a set of improved implementations of reinforcement learning algorithms. @@ -58,7 +58,7 @@ To cite this repository in publications: ``` @misc{torchy-baselines, author = {Raffin, Antonin and Hill, Ashley and Ernestus, Maximilian and Gleave, Adam and Kanervisto, Anssi and Dormann, Noah}, - title = {Torchy Baselines}, + title = {Stable Baselines3}, year = {2019}, publisher = {GitHub}, journal = {GitHub repository}, diff --git a/docs/conf.py b/docs/conf.py index e71d1660f..a3d83bcec 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -44,19 +44,19 @@ def __getattr__(cls, name): sys.modules.update((mod_name, Mock()) for mod_name in MOCK_MODULES) -import torchy_baselines +import stable_baselines3 # -- Project information ----------------------------------------------------- -project = 'Torchy Baselines' -copyright = '2020, Torchy Baselines' -author = 'Torchy Baselines Contributors' +project = 'Stable Baselines3' +copyright = '2020, Stable Baselines3' +author = 'Stable Baselines3 Contributors' # The short X.Y version -version = 'master (' + torchy_baselines.__version__ + ' )' +version = 'master (' + stable_baselines3.__version__ + ' )' # The full version, including alpha/beta/rc tags -release = torchy_baselines.__version__ +release = stable_baselines3.__version__ # -- General configuration --------------------------------------------------- @@ -179,8 +179,8 @@ def setup(app): # (source start file, target name, title, # author, documentclass [howto, manual, or own class]). latex_documents = [ - (master_doc, 'TorchyBaselines.tex', 'Torchy Baselines Documentation', - 'Torchy Baselines Contributors', 'manual'), + (master_doc, 'TorchyBaselines.tex', 'Stable Baselines3 Documentation', + 'Stable Baselines3 Contributors', 'manual'), ] @@ -189,7 +189,7 @@ def setup(app): # One entry per manual page. List of tuples # (source start file, name, description, authors, manual section). man_pages = [ - (master_doc, 'torchybaselines', 'Torchy Baselines Documentation', + (master_doc, 'torchybaselines', 'Stable Baselines3 Documentation', [author], 1) ] @@ -200,7 +200,7 @@ def setup(app): # (source start file, target name, title, author, # dir menu entry, description, category) texinfo_documents = [ - (master_doc, 'TorchyBaselines', 'Torchy Baselines Documentation', + (master_doc, 'TorchyBaselines', 'Stable Baselines3 Documentation', author, 'TorchyBaselines', 'One line description of project.', 'Miscellaneous'), ] diff --git a/docs/guide/quickstart.rst b/docs/guide/quickstart.rst index 58fca6dea..e20f36a26 100644 --- a/docs/guide/quickstart.rst +++ b/docs/guide/quickstart.rst @@ -12,9 +12,9 @@ Here is a quick example of how to train and run SAC on a Pendulum environment: import gym - from torchy_baselines.sac.policies import MlpPolicy - from torchy_baselines.common.vec_env import DummyVecEnv - from torchy_baselines import SAC + from stable_baselines3.sac.policies import MlpPolicy + from stable_baselines3.common.vec_env import DummyVecEnv + from stable_baselines3 import SAC env = gym.make('Pendulum-v0') @@ -34,6 +34,6 @@ the policy is registered: .. code-block:: python - from torchy_baselines import SAC + from stable_baselines3 import SAC model = SAC('MlpPolicy', 'Pendulum-v0').learn(10000) diff --git a/docs/guide/vec_envs.rst b/docs/guide/vec_envs.rst index e0e930cb2..2bce8d120 100644 --- a/docs/guide/vec_envs.rst +++ b/docs/guide/vec_envs.rst @@ -1,6 +1,6 @@ .. _vec_env: -.. automodule:: torchy_baselines.common.vec_env +.. automodule:: stable_baselines3.common.vec_env Vectorized Environments ======================= diff --git a/docs/index.rst b/docs/index.rst index c46e3177d..c97722bc9 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -3,10 +3,10 @@ You can adapt this file completely to your liking, but it should at least contain the root `toctree` directive. -Welcome to Torchy Baselines docs! - Pytorch RL Baselines +Welcome to Stable Baselines3 docs! - Pytorch RL Baselines ======================================================== -`Torchy Baselines `_ is the PyTorch version of `Stable Baselines `_, +`Stable Baselines3 `_ is the PyTorch version of `Stable Baselines `_, a set of improved implementations of reinforcement learning algorithms. RL Baselines Zoo (collection of pre-trained agents): https://github.com/araffin/rl-baselines-zoo @@ -41,7 +41,7 @@ RL Baselines zoo also offers a simple interface to train, evaluate agents and do misc/changelog -Citing Torchy Baselines +Citing Stable Baselines3 ----------------------- To cite this project in publications: @@ -49,7 +49,7 @@ To cite this project in publications: @misc{torchy-baselines, author = {Raffin, Antonin and Hill, Ashley and Ernestus, Maximilian and Gleave, Adam and Kanervisto, Anssi and Dormann, Noah}, - title = {Torchy Baselines}, + title = {Stable Baselines3}, year = {2019}, publisher = {GitHub}, journal = {GitHub repository}, diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index a4fd0974c..3767b7cab 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -111,7 +111,7 @@ Pre-Release 0.2.0 (2020-02-14) Breaking Changes: ^^^^^^^^^^^^^^^^^ -- Python 2 support was dropped, Torchy Baselines now requires Python 3.6 or above +- Python 2 support was dropped, Stable Baselines3 now requires Python 3.6 or above - Return type of ``evaluation.evaluate_policy()`` has been changed - Refactored the replay buffer to avoid transformation between PyTorch and NumPy - Created `OffPolicyRLModel` base class @@ -160,7 +160,7 @@ New Features: Maintainers ----------- -Torchy-Baselines is currently maintained by `Antonin Raffin`_ (aka `@araffin`_). +Stable-Baselines3 is currently maintained by `Antonin Raffin`_ (aka `@araffin`_). .. _Antonin Raffin: https://araffin.github.io/ .. _@araffin: https://github.com/araffin diff --git a/docs/modules/a2c.rst b/docs/modules/a2c.rst index d8f5cf2c8..0ee3aa7e8 100644 --- a/docs/modules/a2c.rst +++ b/docs/modules/a2c.rst @@ -1,6 +1,6 @@ .. _a2c: -.. automodule:: torchy_baselines.a2c +.. automodule:: stable_baselines3.a2c A2C @@ -44,9 +44,9 @@ Train a A2C agent on `CartPole-v1` using 4 processes. import gym - from torchy_baselines.common.policies import MlpPolicy - from torchy_baselines.common import make_vec_env - from torchy_baselines import A2C + from stable_baselines3.common.policies import MlpPolicy + from stable_baselines3.common import make_vec_env + from stable_baselines3 import A2C # Parallel environments env = make_vec_env('CartPole-v1', n_envs=4) diff --git a/docs/modules/base.rst b/docs/modules/base.rst index d32268d5f..7fa2a59fb 100644 --- a/docs/modules/base.rst +++ b/docs/modules/base.rst @@ -1,6 +1,6 @@ .. _base_algo: -.. automodule:: torchy_baselines.common.base_class +.. automodule:: stable_baselines3.common.base_class Base RL Class diff --git a/docs/modules/ppo.rst b/docs/modules/ppo.rst index 9c9887ddd..477c13f93 100644 --- a/docs/modules/ppo.rst +++ b/docs/modules/ppo.rst @@ -1,6 +1,6 @@ .. _ppo2: -.. automodule:: torchy_baselines.ppo +.. automodule:: stable_baselines3.ppo PPO === @@ -53,9 +53,9 @@ Train a PPO agent on `Pendulum-v0` using 4 processes. import gym - from torchy_baselines.ppo.policies import MlpPolicy - from torchy_baselines.common.vec_env import SubprocVecEnv - from torchy_baselines import PPO + from stable_baselines3.ppo.policies import MlpPolicy + from stable_baselines3.common.vec_env import SubprocVecEnv + from stable_baselines3 import PPO # multiprocess environment n_cpu = 4 diff --git a/docs/modules/sac.rst b/docs/modules/sac.rst index 7bd949d07..c067bcf43 100644 --- a/docs/modules/sac.rst +++ b/docs/modules/sac.rst @@ -1,6 +1,6 @@ .. _sac: -.. automodule:: torchy_baselines.sac +.. automodule:: stable_baselines3.sac SAC @@ -14,7 +14,7 @@ A key feature of SAC, and a major difference with common RL algorithms, is that .. warning:: - The SAC model does not support ``torchy_baselines.common.policies`` because it uses double q-values + The SAC model does not support ``stable_baselines3.common.policies`` because it uses double q-values and value estimation, as a result it must use its own policy models (see :ref:`sac_policies`). @@ -72,9 +72,9 @@ Example import gym import numpy as np - from torchy_baselines.sac.policies import MlpPolicy - from torchy_baselines.common.vec_env import DummyVecEnv - from torchy_baselines import SAC + from stable_baselines3.sac.policies import MlpPolicy + from stable_baselines3.common.vec_env import DummyVecEnv + from stable_baselines3 import SAC env = gym.make('Pendulum-v0') env = DummyVecEnv([lambda: env]) diff --git a/docs/modules/td3.rst b/docs/modules/td3.rst index 9fd6806ba..338b9dac4 100644 --- a/docs/modules/td3.rst +++ b/docs/modules/td3.rst @@ -1,6 +1,6 @@ .. _td3: -.. automodule:: torchy_baselines.td3 +.. automodule:: stable_baselines3.td3 TD3 @@ -14,7 +14,7 @@ We recommend reading `OpenAI Spinning guide on TD3 =0.11', 'numpy', diff --git a/torchy_baselines/__init__.py b/stable_baselines3/__init__.py similarity index 55% rename from torchy_baselines/__init__.py rename to stable_baselines3/__init__.py index 28742f571..562ca36c5 100644 --- a/torchy_baselines/__init__.py +++ b/stable_baselines3/__init__.py @@ -1,9 +1,9 @@ import os -from torchy_baselines.a2c import A2C -from torchy_baselines.ppo import PPO -from torchy_baselines.sac import SAC -from torchy_baselines.td3 import TD3 +from stable_baselines3.a2c import A2C +from stable_baselines3.ppo import PPO +from stable_baselines3.sac import SAC +from stable_baselines3.td3 import TD3 # Read version from file version_file = os.path.join(os.path.dirname(__file__), 'version.txt') diff --git a/stable_baselines3/a2c/__init__.py b/stable_baselines3/a2c/__init__.py new file mode 100644 index 000000000..7dba39a7e --- /dev/null +++ b/stable_baselines3/a2c/__init__.py @@ -0,0 +1,2 @@ +from stable_baselines3.a2c.a2c import A2C +from stable_baselines3.ppo.policies import MlpPolicy diff --git a/torchy_baselines/a2c/a2c.py b/stable_baselines3/a2c/a2c.py similarity index 96% rename from torchy_baselines/a2c/a2c.py rename to stable_baselines3/a2c/a2c.py index ff10bf42d..3c5ead955 100644 --- a/torchy_baselines/a2c/a2c.py +++ b/stable_baselines3/a2c/a2c.py @@ -3,11 +3,11 @@ from gym import spaces from typing import Type, Union, Callable, Optional, Dict, Any -from torchy_baselines.common import logger -from torchy_baselines.common.type_aliases import GymEnv, MaybeCallback -from torchy_baselines.common.utils import explained_variance -from torchy_baselines.ppo.policies import PPOPolicy -from torchy_baselines.ppo.ppo import PPO +from stable_baselines3.common import logger +from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback +from stable_baselines3.common.utils import explained_variance +from stable_baselines3.ppo.policies import PPOPolicy +from stable_baselines3.ppo.ppo import PPO class A2C(PPO): diff --git a/torchy_baselines/common/__init__.py b/stable_baselines3/common/__init__.py similarity index 100% rename from torchy_baselines/common/__init__.py rename to stable_baselines3/common/__init__.py diff --git a/torchy_baselines/common/base_class.py b/stable_baselines3/common/base_class.py similarity index 97% rename from torchy_baselines/common/base_class.py rename to stable_baselines3/common/base_class.py index bb1404f1d..1aee061ca 100644 --- a/torchy_baselines/common/base_class.py +++ b/stable_baselines3/common/base_class.py @@ -11,17 +11,17 @@ import torch as th import numpy as np -from torchy_baselines.common import logger -from torchy_baselines.common.policies import BasePolicy, get_policy_from_name -from torchy_baselines.common.utils import set_random_seed, get_schedule_fn, update_learning_rate, get_device -from torchy_baselines.common.vec_env import DummyVecEnv, VecEnv, unwrap_vec_normalize, VecNormalize, VecTransposeImage -from torchy_baselines.common.preprocessing import is_image_space -from torchy_baselines.common.save_util import data_to_json, json_to_data, recursive_getattr, recursive_setattr -from torchy_baselines.common.type_aliases import GymEnv, TensorDict, RolloutReturn, MaybeCallback -from torchy_baselines.common.callbacks import BaseCallback, CallbackList, ConvertCallback, EvalCallback -from torchy_baselines.common.monitor import Monitor -from torchy_baselines.common.noise import ActionNoise -from torchy_baselines.common.buffers import ReplayBuffer +from stable_baselines3.common import logger +from stable_baselines3.common.policies import BasePolicy, get_policy_from_name +from stable_baselines3.common.utils import set_random_seed, get_schedule_fn, update_learning_rate, get_device +from stable_baselines3.common.vec_env import DummyVecEnv, VecEnv, unwrap_vec_normalize, VecNormalize, VecTransposeImage +from stable_baselines3.common.preprocessing import is_image_space +from stable_baselines3.common.save_util import data_to_json, json_to_data, recursive_getattr, recursive_setattr +from stable_baselines3.common.type_aliases import GymEnv, TensorDict, RolloutReturn, MaybeCallback +from stable_baselines3.common.callbacks import BaseCallback, CallbackList, ConvertCallback, EvalCallback +from stable_baselines3.common.monitor import Monitor +from stable_baselines3.common.noise import ActionNoise +from stable_baselines3.common.buffers import ReplayBuffer class BaseRLModel(ABC): diff --git a/torchy_baselines/common/buffers.py b/stable_baselines3/common/buffers.py similarity index 98% rename from torchy_baselines/common/buffers.py rename to stable_baselines3/common/buffers.py index fd78d3e29..4fb4422b2 100644 --- a/torchy_baselines/common/buffers.py +++ b/stable_baselines3/common/buffers.py @@ -4,9 +4,9 @@ import torch as th from gym import spaces -from torchy_baselines.common.vec_env import VecNormalize -from torchy_baselines.common.type_aliases import RolloutBufferSamples, ReplayBufferSamples -from torchy_baselines.common.preprocessing import get_action_dim, get_obs_shape +from stable_baselines3.common.vec_env import VecNormalize +from stable_baselines3.common.type_aliases import RolloutBufferSamples, ReplayBufferSamples +from stable_baselines3.common.preprocessing import get_action_dim, get_obs_shape class BaseBuffer(object): diff --git a/torchy_baselines/common/callbacks.py b/stable_baselines3/common/callbacks.py similarity index 97% rename from torchy_baselines/common/callbacks.py rename to stable_baselines3/common/callbacks.py index b5a015d1f..16143a7ee 100644 --- a/torchy_baselines/common/callbacks.py +++ b/stable_baselines3/common/callbacks.py @@ -7,12 +7,12 @@ import gym import numpy as np -from torchy_baselines.common.vec_env import DummyVecEnv, VecEnv, sync_envs_normalization -from torchy_baselines.common.evaluation import evaluate_policy -from torchy_baselines.common.logger import Logger +from stable_baselines3.common.vec_env import DummyVecEnv, VecEnv, sync_envs_normalization +from stable_baselines3.common.evaluation import evaluate_policy +from stable_baselines3.common.logger import Logger if typing.TYPE_CHECKING: - from torchy_baselines.common.base_class import BaseRLModel # pytype: disable=pyi-error + from stable_baselines3.common.base_class import BaseRLModel # pytype: disable=pyi-error class BaseCallback(ABC): diff --git a/torchy_baselines/common/distributions.py b/stable_baselines3/common/distributions.py similarity index 99% rename from torchy_baselines/common/distributions.py rename to stable_baselines3/common/distributions.py index 9a67b93a7..e3ff2a7f2 100644 --- a/torchy_baselines/common/distributions.py +++ b/stable_baselines3/common/distributions.py @@ -6,7 +6,7 @@ from torch.distributions import Normal, Categorical from gym import spaces -from torchy_baselines.common.preprocessing import get_action_dim +from stable_baselines3.common.preprocessing import get_action_dim class Distribution(object): diff --git a/torchy_baselines/common/evaluation.py b/stable_baselines3/common/evaluation.py similarity index 97% rename from torchy_baselines/common/evaluation.py rename to stable_baselines3/common/evaluation.py index b5017c9a9..c8d0b9769 100644 --- a/torchy_baselines/common/evaluation.py +++ b/stable_baselines3/common/evaluation.py @@ -1,7 +1,7 @@ # Copied from stable_baselines import numpy as np -from torchy_baselines.common.vec_env import VecEnv +from stable_baselines3.common.vec_env import VecEnv def evaluate_policy(model, env, n_eval_episodes=10, deterministic=True, diff --git a/torchy_baselines/common/identity_env.py b/stable_baselines3/common/identity_env.py similarity index 98% rename from torchy_baselines/common/identity_env.py rename to stable_baselines3/common/identity_env.py index 8a76d3b18..0a3daad24 100644 --- a/torchy_baselines/common/identity_env.py +++ b/stable_baselines3/common/identity_env.py @@ -5,7 +5,7 @@ from gym.spaces import Discrete, MultiDiscrete, MultiBinary, Box -from torchy_baselines.common.type_aliases import GymStepReturn, GymObs +from stable_baselines3.common.type_aliases import GymStepReturn, GymObs class IdentityEnv(Env): diff --git a/torchy_baselines/common/logger.py b/stable_baselines3/common/logger.py similarity index 100% rename from torchy_baselines/common/logger.py rename to stable_baselines3/common/logger.py diff --git a/torchy_baselines/common/monitor.py b/stable_baselines3/common/monitor.py similarity index 100% rename from torchy_baselines/common/monitor.py rename to stable_baselines3/common/monitor.py diff --git a/torchy_baselines/common/noise.py b/stable_baselines3/common/noise.py similarity index 100% rename from torchy_baselines/common/noise.py rename to stable_baselines3/common/noise.py diff --git a/torchy_baselines/common/policies.py b/stable_baselines3/common/policies.py similarity index 99% rename from torchy_baselines/common/policies.py rename to stable_baselines3/common/policies.py index f6c664449..22b51dfa8 100644 --- a/torchy_baselines/common/policies.py +++ b/stable_baselines3/common/policies.py @@ -7,9 +7,9 @@ import torch.nn as nn import numpy as np -from torchy_baselines.common.preprocessing import preprocess_obs, get_flattened_obs_dim, is_image_space -from torchy_baselines.common.utils import get_device -from torchy_baselines.common.vec_env import VecTransposeImage +from stable_baselines3.common.preprocessing import preprocess_obs, get_flattened_obs_dim, is_image_space +from stable_baselines3.common.utils import get_device +from stable_baselines3.common.vec_env import VecTransposeImage class BaseFeaturesExtractor(nn.Module): diff --git a/torchy_baselines/common/preprocessing.py b/stable_baselines3/common/preprocessing.py similarity index 100% rename from torchy_baselines/common/preprocessing.py rename to stable_baselines3/common/preprocessing.py diff --git a/torchy_baselines/common/results_plotter.py b/stable_baselines3/common/results_plotter.py similarity index 98% rename from torchy_baselines/common/results_plotter.py rename to stable_baselines3/common/results_plotter.py index d447344cd..4f879be2a 100644 --- a/torchy_baselines/common/results_plotter.py +++ b/stable_baselines3/common/results_plotter.py @@ -6,7 +6,7 @@ # matplotlib.use('TkAgg') # Can change to 'Agg' for non-interactive mode import matplotlib.pyplot as plt -from torchy_baselines.common.monitor import load_results +from stable_baselines3.common.monitor import load_results X_TIMESTEPS = 'timesteps' diff --git a/torchy_baselines/common/running_mean_std.py b/stable_baselines3/common/running_mean_std.py similarity index 100% rename from torchy_baselines/common/running_mean_std.py rename to stable_baselines3/common/running_mean_std.py diff --git a/torchy_baselines/common/save_util.py b/stable_baselines3/common/save_util.py similarity index 100% rename from torchy_baselines/common/save_util.py rename to stable_baselines3/common/save_util.py diff --git a/torchy_baselines/common/type_aliases.py b/stable_baselines3/common/type_aliases.py similarity index 89% rename from torchy_baselines/common/type_aliases.py rename to stable_baselines3/common/type_aliases.py index 70b63a8a3..139cd6096 100644 --- a/torchy_baselines/common/type_aliases.py +++ b/stable_baselines3/common/type_aliases.py @@ -7,8 +7,8 @@ import torch as th import gym -from torchy_baselines.common.vec_env import VecEnv -from torchy_baselines.common.callbacks import BaseCallback +from stable_baselines3.common.vec_env import VecEnv +from stable_baselines3.common.callbacks import BaseCallback GymEnv = Union[gym.Env, VecEnv] diff --git a/torchy_baselines/common/utils.py b/stable_baselines3/common/utils.py similarity index 100% rename from torchy_baselines/common/utils.py rename to stable_baselines3/common/utils.py diff --git a/torchy_baselines/common/vec_env/__init__.py b/stable_baselines3/common/vec_env/__init__.py similarity index 68% rename from torchy_baselines/common/vec_env/__init__.py rename to stable_baselines3/common/vec_env/__init__.py index b119a9f25..2c8918119 100644 --- a/torchy_baselines/common/vec_env/__init__.py +++ b/stable_baselines3/common/vec_env/__init__.py @@ -3,17 +3,17 @@ from typing import Optional, Union from copy import deepcopy -from torchy_baselines.common.vec_env.base_vec_env import (AlreadySteppingError, NotSteppingError, +from stable_baselines3.common.vec_env.base_vec_env import (AlreadySteppingError, NotSteppingError, VecEnv, VecEnvWrapper, CloudpickleWrapper) -from torchy_baselines.common.vec_env.dummy_vec_env import DummyVecEnv -from torchy_baselines.common.vec_env.subproc_vec_env import SubprocVecEnv -from torchy_baselines.common.vec_env.vec_frame_stack import VecFrameStack -from torchy_baselines.common.vec_env.vec_normalize import VecNormalize -from torchy_baselines.common.vec_env.vec_transpose import VecTransposeImage +from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv +from stable_baselines3.common.vec_env.subproc_vec_env import SubprocVecEnv +from stable_baselines3.common.vec_env.vec_frame_stack import VecFrameStack +from stable_baselines3.common.vec_env.vec_normalize import VecNormalize +from stable_baselines3.common.vec_env.vec_transpose import VecTransposeImage # Avoid circular import if typing.TYPE_CHECKING: - from torchy_baselines.common.type_aliases import GymEnv + from stable_baselines3.common.type_aliases import GymEnv def unwrap_vec_normalize(env: Union['GymEnv', VecEnv]) -> Optional[VecNormalize]: diff --git a/torchy_baselines/common/vec_env/base_vec_env.py b/stable_baselines3/common/vec_env/base_vec_env.py similarity index 100% rename from torchy_baselines/common/vec_env/base_vec_env.py rename to stable_baselines3/common/vec_env/base_vec_env.py diff --git a/torchy_baselines/common/vec_env/dummy_vec_env.py b/stable_baselines3/common/vec_env/dummy_vec_env.py similarity index 96% rename from torchy_baselines/common/vec_env/dummy_vec_env.py rename to stable_baselines3/common/vec_env/dummy_vec_env.py index 2d0211e73..669fede89 100644 --- a/torchy_baselines/common/vec_env/dummy_vec_env.py +++ b/stable_baselines3/common/vec_env/dummy_vec_env.py @@ -3,8 +3,8 @@ import numpy as np -from torchy_baselines.common.vec_env.base_vec_env import VecEnv -from torchy_baselines.common.vec_env.util import copy_obs_dict, dict_to_obs, obs_space_info +from stable_baselines3.common.vec_env.base_vec_env import VecEnv +from stable_baselines3.common.vec_env.util import copy_obs_dict, dict_to_obs, obs_space_info class DummyVecEnv(VecEnv): diff --git a/torchy_baselines/common/vec_env/subproc_vec_env.py b/stable_baselines3/common/vec_env/subproc_vec_env.py similarity index 99% rename from torchy_baselines/common/vec_env/subproc_vec_env.py rename to stable_baselines3/common/vec_env/subproc_vec_env.py index 5e6ee858c..128cdca07 100644 --- a/torchy_baselines/common/vec_env/subproc_vec_env.py +++ b/stable_baselines3/common/vec_env/subproc_vec_env.py @@ -4,7 +4,7 @@ import gym import numpy as np -from torchy_baselines.common.vec_env.base_vec_env import VecEnv, CloudpickleWrapper +from stable_baselines3.common.vec_env.base_vec_env import VecEnv, CloudpickleWrapper def _worker(remote, parent_remote, env_fn_wrapper): diff --git a/torchy_baselines/common/vec_env/util.py b/stable_baselines3/common/vec_env/util.py similarity index 100% rename from torchy_baselines/common/vec_env/util.py rename to stable_baselines3/common/vec_env/util.py diff --git a/torchy_baselines/common/vec_env/vec_frame_stack.py b/stable_baselines3/common/vec_env/vec_frame_stack.py similarity index 96% rename from torchy_baselines/common/vec_env/vec_frame_stack.py rename to stable_baselines3/common/vec_env/vec_frame_stack.py index 667616247..b32ddebd0 100644 --- a/torchy_baselines/common/vec_env/vec_frame_stack.py +++ b/stable_baselines3/common/vec_env/vec_frame_stack.py @@ -3,7 +3,7 @@ import numpy as np from gym import spaces -from torchy_baselines.common.vec_env.base_vec_env import VecEnv, VecEnvWrapper +from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvWrapper class VecFrameStack(VecEnvWrapper): diff --git a/torchy_baselines/common/vec_env/vec_normalize.py b/stable_baselines3/common/vec_env/vec_normalize.py similarity index 97% rename from torchy_baselines/common/vec_env/vec_normalize.py rename to stable_baselines3/common/vec_env/vec_normalize.py index 87fb70ae6..94cf74b14 100644 --- a/torchy_baselines/common/vec_env/vec_normalize.py +++ b/stable_baselines3/common/vec_env/vec_normalize.py @@ -2,8 +2,8 @@ import numpy as np -from torchy_baselines.common.vec_env.base_vec_env import VecEnvWrapper -from torchy_baselines.common.running_mean_std import RunningMeanStd +from stable_baselines3.common.vec_env.base_vec_env import VecEnvWrapper +from stable_baselines3.common.running_mean_std import RunningMeanStd class VecNormalize(VecEnvWrapper): diff --git a/torchy_baselines/common/vec_env/vec_transpose.py b/stable_baselines3/common/vec_env/vec_transpose.py similarity index 89% rename from torchy_baselines/common/vec_env/vec_transpose.py rename to stable_baselines3/common/vec_env/vec_transpose.py index 3a514e97f..e4901b4c3 100644 --- a/torchy_baselines/common/vec_env/vec_transpose.py +++ b/stable_baselines3/common/vec_env/vec_transpose.py @@ -2,11 +2,11 @@ import numpy as np from gym import spaces -from torchy_baselines.common.vec_env.base_vec_env import VecEnv, VecEnvWrapper -from torchy_baselines.common.preprocessing import is_image_space +from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvWrapper +from stable_baselines3.common.preprocessing import is_image_space if typing.TYPE_CHECKING: - from torchy_baselines.common.type_aliases import GymStepReturn + from stable_baselines3.common.type_aliases import GymStepReturn class VecTransposeImage(VecEnvWrapper): diff --git a/stable_baselines3/ppo/__init__.py b/stable_baselines3/ppo/__init__.py new file mode 100644 index 000000000..8c9ed8ed9 --- /dev/null +++ b/stable_baselines3/ppo/__init__.py @@ -0,0 +1,2 @@ +from stable_baselines3.ppo.ppo import PPO +from stable_baselines3.ppo.policies import MlpPolicy diff --git a/torchy_baselines/ppo/policies.py b/stable_baselines3/ppo/policies.py similarity index 99% rename from torchy_baselines/ppo/policies.py rename to stable_baselines3/ppo/policies.py index 2b9164a1e..e9977698e 100644 --- a/torchy_baselines/ppo/policies.py +++ b/stable_baselines3/ppo/policies.py @@ -6,10 +6,10 @@ import torch.nn as nn import numpy as np -from torchy_baselines.common.policies import (BasePolicy, register_policy, MlpExtractor, +from stable_baselines3.common.policies import (BasePolicy, register_policy, MlpExtractor, create_sde_features_extractor, NatureCNN, BaseFeaturesExtractor, FlattenExtractor) -from torchy_baselines.common.distributions import (make_proba_distribution, Distribution, +from stable_baselines3.common.distributions import (make_proba_distribution, Distribution, DiagGaussianDistribution, CategoricalDistribution, StateDependentNoiseDistribution) diff --git a/torchy_baselines/ppo/ppo.py b/stable_baselines3/ppo/ppo.py similarity index 97% rename from torchy_baselines/ppo/ppo.py rename to stable_baselines3/ppo/ppo.py index d9496f05f..3e81fdb77 100644 --- a/torchy_baselines/ppo/ppo.py +++ b/stable_baselines3/ppo/ppo.py @@ -14,14 +14,14 @@ # SummaryWriter = None import numpy as np -from torchy_baselines.common import logger -from torchy_baselines.common.base_class import BaseRLModel -from torchy_baselines.common.type_aliases import GymEnv, MaybeCallback -from torchy_baselines.common.buffers import RolloutBuffer -from torchy_baselines.common.utils import explained_variance, get_schedule_fn -from torchy_baselines.common.vec_env import VecEnv -from torchy_baselines.common.callbacks import BaseCallback -from torchy_baselines.ppo.policies import PPOPolicy +from stable_baselines3.common import logger +from stable_baselines3.common.base_class import BaseRLModel +from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback +from stable_baselines3.common.buffers import RolloutBuffer +from stable_baselines3.common.utils import explained_variance, get_schedule_fn +from stable_baselines3.common.vec_env import VecEnv +from stable_baselines3.common.callbacks import BaseCallback +from stable_baselines3.ppo.policies import PPOPolicy class PPO(BaseRLModel): diff --git a/torchy_baselines/py.typed b/stable_baselines3/py.typed similarity index 100% rename from torchy_baselines/py.typed rename to stable_baselines3/py.typed diff --git a/stable_baselines3/sac/__init__.py b/stable_baselines3/sac/__init__.py new file mode 100644 index 000000000..8c893788e --- /dev/null +++ b/stable_baselines3/sac/__init__.py @@ -0,0 +1,2 @@ +from stable_baselines3.sac.sac import SAC +from stable_baselines3.sac.policies import MlpPolicy diff --git a/torchy_baselines/sac/policies.py b/stable_baselines3/sac/policies.py similarity index 98% rename from torchy_baselines/sac/policies.py rename to stable_baselines3/sac/policies.py index ecd6faf44..868fde60d 100644 --- a/torchy_baselines/sac/policies.py +++ b/stable_baselines3/sac/policies.py @@ -4,11 +4,11 @@ import torch as th import torch.nn as nn -from torchy_baselines.common.preprocessing import get_action_dim -from torchy_baselines.common.policies import (BasePolicy, register_policy, create_mlp, +from stable_baselines3.common.preprocessing import get_action_dim +from stable_baselines3.common.policies import (BasePolicy, register_policy, create_mlp, create_sde_features_extractor, NatureCNN, BaseFeaturesExtractor, FlattenExtractor) -from torchy_baselines.common.distributions import SquashedDiagGaussianDistribution, StateDependentNoiseDistribution +from stable_baselines3.common.distributions import SquashedDiagGaussianDistribution, StateDependentNoiseDistribution # CAP the standard deviation of the actor LOG_STD_MAX = 2 diff --git a/torchy_baselines/sac/sac.py b/stable_baselines3/sac/sac.py similarity index 98% rename from torchy_baselines/sac/sac.py rename to stable_baselines3/sac/sac.py index c1cbc6474..85aba9ec5 100644 --- a/torchy_baselines/sac/sac.py +++ b/stable_baselines3/sac/sac.py @@ -4,11 +4,11 @@ import torch.nn.functional as F import numpy as np -from torchy_baselines.common import logger -from torchy_baselines.common.base_class import OffPolicyRLModel -from torchy_baselines.common.type_aliases import GymEnv, MaybeCallback -from torchy_baselines.common.noise import ActionNoise -from torchy_baselines.sac.policies import SACPolicy +from stable_baselines3.common import logger +from stable_baselines3.common.base_class import OffPolicyRLModel +from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback +from stable_baselines3.common.noise import ActionNoise +from stable_baselines3.sac.policies import SACPolicy class SAC(OffPolicyRLModel): diff --git a/stable_baselines3/td3/__init__.py b/stable_baselines3/td3/__init__.py new file mode 100644 index 000000000..96cecdfd1 --- /dev/null +++ b/stable_baselines3/td3/__init__.py @@ -0,0 +1,2 @@ +from stable_baselines3.td3.td3 import TD3 +from stable_baselines3.td3.policies import MlpPolicy diff --git a/torchy_baselines/td3/policies.py b/stable_baselines3/td3/policies.py similarity index 99% rename from torchy_baselines/td3/policies.py rename to stable_baselines3/td3/policies.py index b43adaf3f..5965334f9 100644 --- a/torchy_baselines/td3/policies.py +++ b/stable_baselines3/td3/policies.py @@ -4,11 +4,11 @@ import torch as th import torch.nn as nn -from torchy_baselines.common.preprocessing import get_action_dim -from torchy_baselines.common.policies import (BasePolicy, register_policy, create_mlp, +from stable_baselines3.common.preprocessing import get_action_dim +from stable_baselines3.common.policies import (BasePolicy, register_policy, create_mlp, create_sde_features_extractor, NatureCNN, BaseFeaturesExtractor, FlattenExtractor) -from torchy_baselines.common.distributions import StateDependentNoiseDistribution +from stable_baselines3.common.distributions import StateDependentNoiseDistribution class Actor(BasePolicy): diff --git a/torchy_baselines/td3/td3.py b/stable_baselines3/td3/td3.py similarity index 97% rename from torchy_baselines/td3/td3.py rename to stable_baselines3/td3/td3.py index 09742df6f..18361f29a 100644 --- a/torchy_baselines/td3/td3.py +++ b/stable_baselines3/td3/td3.py @@ -2,11 +2,11 @@ import torch.nn.functional as F from typing import List, Tuple, Type, Union, Callable, Optional, Dict, Any -from torchy_baselines.common import logger -from torchy_baselines.common.base_class import OffPolicyRLModel -from torchy_baselines.common.noise import ActionNoise -from torchy_baselines.common.type_aliases import GymEnv, MaybeCallback -from torchy_baselines.td3.policies import TD3Policy +from stable_baselines3.common import logger +from stable_baselines3.common.base_class import OffPolicyRLModel +from stable_baselines3.common.noise import ActionNoise +from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback +from stable_baselines3.td3.policies import TD3Policy class TD3(OffPolicyRLModel): diff --git a/torchy_baselines/version.txt b/stable_baselines3/version.txt similarity index 100% rename from torchy_baselines/version.txt rename to stable_baselines3/version.txt diff --git a/tests/test_callbacks.py b/tests/test_callbacks.py index 5f0fc07ee..ff8958fd4 100644 --- a/tests/test_callbacks.py +++ b/tests/test_callbacks.py @@ -4,8 +4,8 @@ import pytest import gym -from torchy_baselines import A2C, PPO, SAC, TD3 -from torchy_baselines.common.callbacks import (CallbackList, CheckpointCallback, EvalCallback, +from stable_baselines3 import A2C, PPO, SAC, TD3 +from stable_baselines3.common.callbacks import (CallbackList, CheckpointCallback, EvalCallback, EveryNTimesteps, StopTrainingOnRewardThreshold) diff --git a/tests/test_cnn.py b/tests/test_cnn.py index 1aec9d4a1..9dc342d5c 100644 --- a/tests/test_cnn.py +++ b/tests/test_cnn.py @@ -3,8 +3,8 @@ import numpy as np import pytest -from torchy_baselines import A2C, PPO, SAC, TD3 -from torchy_baselines.common.identity_env import FakeImageEnv +from stable_baselines3 import A2C, PPO, SAC, TD3 +from stable_baselines3.common.identity_env import FakeImageEnv SAVE_PATH = './cnn_model.zip' diff --git a/tests/test_custom_policy.py b/tests/test_custom_policy.py index dac69dcf9..9637f4e52 100644 --- a/tests/test_custom_policy.py +++ b/tests/test_custom_policy.py @@ -1,7 +1,7 @@ import pytest import torch as th -from torchy_baselines import A2C, PPO, SAC, TD3 +from stable_baselines3 import A2C, PPO, SAC, TD3 @pytest.mark.parametrize('net_arch', [ diff --git a/tests/test_distributions.py b/tests/test_distributions.py index de97b1390..1f340bb7d 100644 --- a/tests/test_distributions.py +++ b/tests/test_distributions.py @@ -1,11 +1,11 @@ import pytest import torch as th -from torchy_baselines import A2C, PPO -from torchy_baselines.common.distributions import (DiagGaussianDistribution, TanhBijector, +from stable_baselines3 import A2C, PPO +from stable_baselines3.common.distributions import (DiagGaussianDistribution, TanhBijector, StateDependentNoiseDistribution, CategoricalDistribution, SquashedDiagGaussianDistribution) -from torchy_baselines.common.utils import set_random_seed +from stable_baselines3.common.utils import set_random_seed N_ACTIONS = 2 diff --git a/tests/test_identity.py b/tests/test_identity.py index b717c0a56..d937c7e96 100644 --- a/tests/test_identity.py +++ b/tests/test_identity.py @@ -1,10 +1,10 @@ import numpy as np import pytest -from torchy_baselines import A2C, PPO, SAC, TD3 -from torchy_baselines.common.identity_env import IdentityEnvBox, IdentityEnv -from torchy_baselines.common.evaluation import evaluate_policy -from torchy_baselines.common.noise import NormalActionNoise +from stable_baselines3 import A2C, PPO, SAC, TD3 +from stable_baselines3.common.identity_env import IdentityEnvBox, IdentityEnv +from stable_baselines3.common.evaluation import evaluate_policy +from stable_baselines3.common.noise import NormalActionNoise @pytest.mark.parametrize("model_class", [A2C, PPO]) diff --git a/tests/test_logger.py b/tests/test_logger.py index df6d05993..f648091d2 100644 --- a/tests/test_logger.py +++ b/tests/test_logger.py @@ -4,7 +4,7 @@ import pytest import numpy as np -from torchy_baselines.common.logger import (make_output_format, read_csv, read_json, DEBUG, ScopedConfigure, +from stable_baselines3.common.logger import (make_output_format, read_csv, read_json, DEBUG, ScopedConfigure, info, debug, set_level, configure, logkv, logkvs, dumpkvs, logkv_mean, warn, error, reset) @@ -18,7 +18,7 @@ "g": np.array([[[1]]]), } -LOG_DIR = '/tmp/torchy_baselines/' +LOG_DIR = '/tmp/stable_baselines3/' def test_main(): diff --git a/tests/test_monitor.py b/tests/test_monitor.py index b21c33b34..00a78026f 100644 --- a/tests/test_monitor.py +++ b/tests/test_monitor.py @@ -5,7 +5,7 @@ import pandas import gym -from torchy_baselines.common.monitor import Monitor, get_monitor_files, load_results +from stable_baselines3.common.monitor import Monitor, get_monitor_files, load_results def test_monitor(tmp_path): diff --git a/tests/test_predict.py b/tests/test_predict.py index 5fd5064d8..fad35c71c 100644 --- a/tests/test_predict.py +++ b/tests/test_predict.py @@ -1,8 +1,8 @@ import gym import pytest -from torchy_baselines import A2C, PPO, SAC, TD3 -from torchy_baselines.common.vec_env import DummyVecEnv +from stable_baselines3 import A2C, PPO, SAC, TD3 +from stable_baselines3.common.vec_env import DummyVecEnv MODEL_LIST = [ PPO, diff --git a/tests/test_run.py b/tests/test_run.py index 0c6ea813e..0c7174cf5 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -1,8 +1,8 @@ import numpy as np import pytest -from torchy_baselines import A2C, PPO, SAC, TD3 -from torchy_baselines.common.noise import NormalActionNoise, OrnsteinUhlenbeckActionNoise +from stable_baselines3 import A2C, PPO, SAC, TD3 +from stable_baselines3.common.noise import NormalActionNoise, OrnsteinUhlenbeckActionNoise normal_action_noise = NormalActionNoise(np.zeros(1), 0.1 * np.ones(1)) diff --git a/tests/test_save_load.py b/tests/test_save_load.py index 5e637bc9c..b7c0924f7 100644 --- a/tests/test_save_load.py +++ b/tests/test_save_load.py @@ -5,10 +5,10 @@ import numpy as np import torch as th -from torchy_baselines import A2C, PPO, SAC, TD3 -from torchy_baselines.common.identity_env import IdentityEnvBox -from torchy_baselines.common.vec_env import DummyVecEnv -from torchy_baselines.common.identity_env import FakeImageEnv +from stable_baselines3 import A2C, PPO, SAC, TD3 +from stable_baselines3.common.identity_env import IdentityEnvBox +from stable_baselines3.common.vec_env import DummyVecEnv +from stable_baselines3.common.identity_env import FakeImageEnv MODEL_LIST = [ diff --git a/tests/test_sde.py b/tests/test_sde.py index cf3bbdb4f..eed012f13 100644 --- a/tests/test_sde.py +++ b/tests/test_sde.py @@ -2,7 +2,7 @@ import torch as th from torch.distributions import Normal -from torchy_baselines import A2C, TD3, SAC, PPO +from stable_baselines3 import A2C, TD3, SAC, PPO def test_state_dependent_exploration_grad(): diff --git a/tests/test_vec_envs.py b/tests/test_vec_envs.py index f2dd1c29a..c6ab43b3a 100644 --- a/tests/test_vec_envs.py +++ b/tests/test_vec_envs.py @@ -7,7 +7,7 @@ import gym import numpy as np -from torchy_baselines.common.vec_env import DummyVecEnv, SubprocVecEnv, VecNormalize, VecFrameStack +from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv, VecNormalize, VecFrameStack N_ENVS = 3 VEC_ENV_CLASSES = [DummyVecEnv, SubprocVecEnv] diff --git a/tests/test_vec_normalize.py b/tests/test_vec_normalize.py index 73c0824c1..5ec80c70e 100644 --- a/tests/test_vec_normalize.py +++ b/tests/test_vec_normalize.py @@ -2,9 +2,9 @@ import pytest import numpy as np -from torchy_baselines.common.running_mean_std import RunningMeanStd -from torchy_baselines.common.vec_env import DummyVecEnv, VecNormalize, VecFrameStack, sync_envs_normalization, unwrap_vec_normalize -from torchy_baselines import SAC, TD3 +from stable_baselines3.common.running_mean_std import RunningMeanStd +from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize, VecFrameStack, sync_envs_normalization, unwrap_vec_normalize +from stable_baselines3 import SAC, TD3 ENV_ID = 'Pendulum-v0' diff --git a/torchy_baselines/a2c/__init__.py b/torchy_baselines/a2c/__init__.py deleted file mode 100644 index 0cc4be01e..000000000 --- a/torchy_baselines/a2c/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from torchy_baselines.a2c.a2c import A2C -from torchy_baselines.ppo.policies import MlpPolicy diff --git a/torchy_baselines/ppo/__init__.py b/torchy_baselines/ppo/__init__.py deleted file mode 100644 index 72a556096..000000000 --- a/torchy_baselines/ppo/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from torchy_baselines.ppo.ppo import PPO -from torchy_baselines.ppo.policies import MlpPolicy diff --git a/torchy_baselines/sac/__init__.py b/torchy_baselines/sac/__init__.py deleted file mode 100644 index 1132a37ab..000000000 --- a/torchy_baselines/sac/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from torchy_baselines.sac.sac import SAC -from torchy_baselines.sac.policies import MlpPolicy diff --git a/torchy_baselines/td3/__init__.py b/torchy_baselines/td3/__init__.py deleted file mode 100644 index 148be493c..000000000 --- a/torchy_baselines/td3/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from torchy_baselines.td3.td3 import TD3 -from torchy_baselines.td3.policies import MlpPolicy