Skip to content

Commit

Permalink
Rename to stable-baselines3
Browse files Browse the repository at this point in the history
  • Loading branch information
araffin committed May 5, 2020
1 parent 4a2c247 commit d542732
Show file tree
Hide file tree
Showing 72 changed files with 164 additions and 164 deletions.
2 changes: 1 addition & 1 deletion .coveragerc
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ omit =
tests/*
setup.py
# Require graphical interface
torchy_baselines/common/results_plotter.py
stable_baselines3/common/results_plotter.py

[report]
exclude_lines =
Expand Down
4 changes: 2 additions & 2 deletions .github/ISSUE_TEMPLATE/issue-template.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ If you are submitting a bug report, please fill in the following details.
If your issue is related to a custom gym environment, please check it first using:

```python
from torchy_baselines.common.env_checker import check_env
from stable_baselines3.common.env_checker import check_env

env = CustomEnv(arg1, ...)
# It will check your custom environment and output additional warnings if needed
Expand All @@ -30,7 +30,7 @@ Please use the [markdown code blocks](https://help.github.com/en/articles/creati
for both code and stack traces.

```python
from torchy_baselines import ...
from stable_baselines3 import ...

```

Expand Down
2 changes: 1 addition & 1 deletion NOTICE
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
Large portion of the code of Torchy-Baselines (in `common/`) were ported from Stable-Baselines, a fork of OpenAI Baselines,
A large portion of the code of Stable-Baselines3 (in `common/`) was ported from Stable-Baselines, a fork of OpenAI Baselines,
both licensed under the MIT License:

before the fork (June 2018):
Expand Down
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

[![Build Status](https://travis-ci.com/hill-a/stable-baselines.svg?branch=master)](https://travis-ci.com/hill-a/stable-baselines) [![Documentation Status](https://readthedocs.org/projects/stable-baselines/badge/?version=master)](https://stable-baselines.readthedocs.io/en/master/?badge=master)

# Torchy Baselines
# Stable Baselines3

PyTorch version of [Stable Baselines](https://github.com/hill-a/stable-baselines), a set of improved implementations of reinforcement learning algorithms.

Expand Down Expand Up @@ -58,7 +58,7 @@ To cite this repository in publications:
```
@misc{torchy-baselines,
author = {Raffin, Antonin and Hill, Ashley and Ernestus, Maximilian and Gleave, Adam and Kanervisto, Anssi and Dormann, Noah},
title = {Torchy Baselines},
title = {Stable Baselines3},
year = {2019},
publisher = {GitHub},
journal = {GitHub repository},
Expand Down
20 changes: 10 additions & 10 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,19 +44,19 @@ def __getattr__(cls, name):
sys.modules.update((mod_name, Mock()) for mod_name in MOCK_MODULES)


import torchy_baselines
import stable_baselines3


# -- Project information -----------------------------------------------------

project = 'Torchy Baselines'
copyright = '2020, Torchy Baselines'
author = 'Torchy Baselines Contributors'
project = 'Stable Baselines3'
copyright = '2020, Stable Baselines3'
author = 'Stable Baselines3 Contributors'

# The short X.Y version
version = 'master (' + torchy_baselines.__version__ + ' )'
version = 'master (' + stable_baselines3.__version__ + ' )'
# The full version, including alpha/beta/rc tags
release = torchy_baselines.__version__
release = stable_baselines3.__version__


# -- General configuration ---------------------------------------------------
Expand Down Expand Up @@ -179,8 +179,8 @@ def setup(app):
# (source start file, target name, title,
# author, documentclass [howto, manual, or own class]).
latex_documents = [
(master_doc, 'TorchyBaselines.tex', 'Torchy Baselines Documentation',
'Torchy Baselines Contributors', 'manual'),
(master_doc, 'TorchyBaselines.tex', 'Stable Baselines3 Documentation',
'Stable Baselines3 Contributors', 'manual'),
]


Expand All @@ -189,7 +189,7 @@ def setup(app):
# One entry per manual page. List of tuples
# (source start file, name, description, authors, manual section).
man_pages = [
(master_doc, 'torchybaselines', 'Torchy Baselines Documentation',
(master_doc, 'torchybaselines', 'Stable Baselines3 Documentation',
[author], 1)
]

Expand All @@ -200,7 +200,7 @@ def setup(app):
# (source start file, target name, title, author,
# dir menu entry, description, category)
texinfo_documents = [
(master_doc, 'TorchyBaselines', 'Torchy Baselines Documentation',
(master_doc, 'TorchyBaselines', 'Stable Baselines3 Documentation',
author, 'TorchyBaselines', 'One line description of project.',
'Miscellaneous'),
]
Expand Down
8 changes: 4 additions & 4 deletions docs/guide/quickstart.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@ Here is a quick example of how to train and run SAC on a Pendulum environment:
import gym
from torchy_baselines.sac.policies import MlpPolicy
from torchy_baselines.common.vec_env import DummyVecEnv
from torchy_baselines import SAC
from stable_baselines3.sac.policies import MlpPolicy
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3 import SAC
env = gym.make('Pendulum-v0')
Expand All @@ -34,6 +34,6 @@ the policy is registered:

.. code-block:: python
from torchy_baselines import SAC
from stable_baselines3 import SAC
model = SAC('MlpPolicy', 'Pendulum-v0').learn(10000)
2 changes: 1 addition & 1 deletion docs/guide/vec_envs.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
.. _vec_env:

.. automodule:: torchy_baselines.common.vec_env
.. automodule:: stable_baselines3.common.vec_env

Vectorized Environments
=======================
Expand Down
8 changes: 4 additions & 4 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
You can adapt this file completely to your liking, but it should at least
contain the root `toctree` directive.
Welcome to Torchy Baselines docs! - Pytorch RL Baselines
Welcome to Stable Baselines3 docs! - PyTorch RL Baselines
=========================================================

`Torchy Baselines <https://github.com/hill-a/stable-baselines>`_ is the PyTorch version of `Stable Baselines <https://github.com/hill-a/stable-baselines>`_,
`Stable Baselines3 <https://github.com/hill-a/stable-baselines>`_ is the PyTorch version of `Stable Baselines <https://github.com/hill-a/stable-baselines>`_,
a set of improved implementations of reinforcement learning algorithms.

RL Baselines Zoo (collection of pre-trained agents): https://github.com/araffin/rl-baselines-zoo
Expand Down Expand Up @@ -41,15 +41,15 @@ RL Baselines zoo also offers a simple interface to train, evaluate agents and do
misc/changelog


Citing Torchy Baselines
Citing Stable Baselines3
------------------------
To cite this project in publications:

.. code-block:: bibtex
@misc{torchy-baselines,
author = {Raffin, Antonin and Hill, Ashley and Ernestus, Maximilian and Gleave, Adam and Kanervisto, Anssi and Dormann, Noah},
title = {Torchy Baselines},
title = {Stable Baselines3},
year = {2019},
publisher = {GitHub},
journal = {GitHub repository},
Expand Down
4 changes: 2 additions & 2 deletions docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ Pre-Release 0.2.0 (2020-02-14)

Breaking Changes:
^^^^^^^^^^^^^^^^^
- Python 2 support was dropped, Torchy Baselines now requires Python 3.6 or above
- Python 2 support was dropped, Stable Baselines3 now requires Python 3.6 or above
- Return type of ``evaluation.evaluate_policy()`` has been changed
- Refactored the replay buffer to avoid transformation between PyTorch and NumPy
- Created `OffPolicyRLModel` base class
Expand Down Expand Up @@ -160,7 +160,7 @@ New Features:
Maintainers
-----------

Torchy-Baselines is currently maintained by `Antonin Raffin`_ (aka `@araffin`_).
Stable-Baselines3 is currently maintained by `Antonin Raffin`_ (aka `@araffin`_).

.. _Antonin Raffin: https://araffin.github.io/
.. _@araffin: https://github.com/araffin
Expand Down
8 changes: 4 additions & 4 deletions docs/modules/a2c.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
.. _a2c:

.. automodule:: torchy_baselines.a2c
.. automodule:: stable_baselines3.a2c


A2C
Expand Down Expand Up @@ -44,9 +44,9 @@ Train a A2C agent on `CartPole-v1` using 4 processes.
import gym
from torchy_baselines.common.policies import MlpPolicy
from torchy_baselines.common import make_vec_env
from torchy_baselines import A2C
from stable_baselines3.common.policies import MlpPolicy
from stable_baselines3.common import make_vec_env
from stable_baselines3 import A2C
# Parallel environments
env = make_vec_env('CartPole-v1', n_envs=4)
Expand Down
2 changes: 1 addition & 1 deletion docs/modules/base.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
.. _base_algo:

.. automodule:: torchy_baselines.common.base_class
.. automodule:: stable_baselines3.common.base_class


Base RL Class
Expand Down
8 changes: 4 additions & 4 deletions docs/modules/ppo.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
.. _ppo2:

.. automodule:: torchy_baselines.ppo
.. automodule:: stable_baselines3.ppo

PPO
===
Expand Down Expand Up @@ -53,9 +53,9 @@ Train a PPO agent on `Pendulum-v0` using 4 processes.
import gym
from torchy_baselines.ppo.policies import MlpPolicy
from torchy_baselines.common.vec_env import SubprocVecEnv
from torchy_baselines import PPO
from stable_baselines3.ppo.policies import MlpPolicy
from stable_baselines3.common.vec_env import SubprocVecEnv
from stable_baselines3 import PPO
# multiprocess environment
n_cpu = 4
Expand Down
10 changes: 5 additions & 5 deletions docs/modules/sac.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
.. _sac:

.. automodule:: torchy_baselines.sac
.. automodule:: stable_baselines3.sac


SAC
Expand All @@ -14,7 +14,7 @@ A key feature of SAC, and a major difference with common RL algorithms, is that

.. warning::

The SAC model does not support ``torchy_baselines.common.policies`` because it uses double q-values
The SAC model does not support ``stable_baselines3.common.policies`` because it uses double q-values
and value estimation, as a result it must use its own policy models (see :ref:`sac_policies`).


Expand Down Expand Up @@ -72,9 +72,9 @@ Example
import gym
import numpy as np
from torchy_baselines.sac.policies import MlpPolicy
from torchy_baselines.common.vec_env import DummyVecEnv
from torchy_baselines import SAC
from stable_baselines3.sac.policies import MlpPolicy
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3 import SAC
env = gym.make('Pendulum-v0')
env = DummyVecEnv([lambda: env])
Expand Down
10 changes: 5 additions & 5 deletions docs/modules/td3.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
.. _td3:

.. automodule:: torchy_baselines.td3
.. automodule:: stable_baselines3.td3


TD3
Expand All @@ -14,7 +14,7 @@ We recommend reading `OpenAI Spinning guide on TD3 <https://spinningup.openai.co

.. warning::

The TD3 model does not support ``torchy_baselines.common.policies`` because it uses double q-values
The TD3 model does not support ``stable_baselines3.common.policies`` because it uses double q-values
estimation, as a result it must use its own policy models (see :ref:`td3_policies`).


Expand Down Expand Up @@ -64,9 +64,9 @@ Example
import numpy as np
from torchy_baselines import TD3
from torchy_baselines.td3.policies import MlpPolicy
from torchy_baselines.common.noise import NormalActionNoise, OrnsteinUhlenbeckActionNoise
from stable_baselines3 import TD3
from stable_baselines3.td3.policies import MlpPolicy
from stable_baselines3.common.noise import NormalActionNoise, OrnsteinUhlenbeckActionNoise
# The noise objects for TD3
n_actions = env.action_space.shape[-1]
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,4 @@ filterwarnings =
ignore::UserWarning:gym

[pytype]
inputs = torchy_baselines
inputs = stable_baselines3
6 changes: 3 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@
import subprocess
from setuptools import setup, find_packages

with open(os.path.join('torchy_baselines', 'version.txt'), 'r') as file_handler:
with open(os.path.join('stable_baselines3', 'version.txt'), 'r') as file_handler:
__version__ = file_handler.read()


setup(name='torchy_baselines',
setup(name='stable_baselines3',
packages=[package for package in find_packages()
if package.startswith('torchy_baselines')],
if package.startswith('stable_baselines3')],
install_requires=[
'gym[classic_control]>=0.11',
'numpy',
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import os

from torchy_baselines.a2c import A2C
from torchy_baselines.ppo import PPO
from torchy_baselines.sac import SAC
from torchy_baselines.td3 import TD3
from stable_baselines3.a2c import A2C
from stable_baselines3.ppo import PPO
from stable_baselines3.sac import SAC
from stable_baselines3.td3 import TD3

# Read version from file
version_file = os.path.join(os.path.dirname(__file__), 'version.txt')
Expand Down
2 changes: 2 additions & 0 deletions stable_baselines3/a2c/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from stable_baselines3.a2c.a2c import A2C
from stable_baselines3.ppo.policies import MlpPolicy
10 changes: 5 additions & 5 deletions torchy_baselines/a2c/a2c.py → stable_baselines3/a2c/a2c.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
from gym import spaces
from typing import Type, Union, Callable, Optional, Dict, Any

from torchy_baselines.common import logger
from torchy_baselines.common.type_aliases import GymEnv, MaybeCallback
from torchy_baselines.common.utils import explained_variance
from torchy_baselines.ppo.policies import PPOPolicy
from torchy_baselines.ppo.ppo import PPO
from stable_baselines3.common import logger
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback
from stable_baselines3.common.utils import explained_variance
from stable_baselines3.ppo.policies import PPOPolicy
from stable_baselines3.ppo.ppo import PPO


class A2C(PPO):
Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,17 @@
import torch as th
import numpy as np

from torchy_baselines.common import logger
from torchy_baselines.common.policies import BasePolicy, get_policy_from_name
from torchy_baselines.common.utils import set_random_seed, get_schedule_fn, update_learning_rate, get_device
from torchy_baselines.common.vec_env import DummyVecEnv, VecEnv, unwrap_vec_normalize, VecNormalize, VecTransposeImage
from torchy_baselines.common.preprocessing import is_image_space
from torchy_baselines.common.save_util import data_to_json, json_to_data, recursive_getattr, recursive_setattr
from torchy_baselines.common.type_aliases import GymEnv, TensorDict, RolloutReturn, MaybeCallback
from torchy_baselines.common.callbacks import BaseCallback, CallbackList, ConvertCallback, EvalCallback
from torchy_baselines.common.monitor import Monitor
from torchy_baselines.common.noise import ActionNoise
from torchy_baselines.common.buffers import ReplayBuffer
from stable_baselines3.common import logger
from stable_baselines3.common.policies import BasePolicy, get_policy_from_name
from stable_baselines3.common.utils import set_random_seed, get_schedule_fn, update_learning_rate, get_device
from stable_baselines3.common.vec_env import DummyVecEnv, VecEnv, unwrap_vec_normalize, VecNormalize, VecTransposeImage
from stable_baselines3.common.preprocessing import is_image_space
from stable_baselines3.common.save_util import data_to_json, json_to_data, recursive_getattr, recursive_setattr
from stable_baselines3.common.type_aliases import GymEnv, TensorDict, RolloutReturn, MaybeCallback
from stable_baselines3.common.callbacks import BaseCallback, CallbackList, ConvertCallback, EvalCallback
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.noise import ActionNoise
from stable_baselines3.common.buffers import ReplayBuffer


class BaseRLModel(ABC):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
import torch as th
from gym import spaces

from torchy_baselines.common.vec_env import VecNormalize
from torchy_baselines.common.type_aliases import RolloutBufferSamples, ReplayBufferSamples
from torchy_baselines.common.preprocessing import get_action_dim, get_obs_shape
from stable_baselines3.common.vec_env import VecNormalize
from stable_baselines3.common.type_aliases import RolloutBufferSamples, ReplayBufferSamples
from stable_baselines3.common.preprocessing import get_action_dim, get_obs_shape


class BaseBuffer(object):
Expand Down
Loading

0 comments on commit d542732

Please sign in to comment.