Add SimBa Policy: Simplicity Bias for Scaling Up Parameters in DRL (#59)
* Start testing simba

* Quick try with CrossQ

* Add actor for CrossQ

* Add simba net for TQC

* Remove unused param

* Add parameter resets for TQC

* Fix reset

* Add missing param

* Update documentation

* Add parameter resets

* Reformat pyproject.toml

* Refactor: share actor between SAC and TQC

* Add run tests for simba

* Upgrade to python 3.9 (#64)

* Fix mypy error, update version
araffin authored Jan 14, 2025
1 parent 1c79684 commit 9cad1d0
Showing 26 changed files with 702 additions and 227 deletions.
4 changes: 1 addition & 3 deletions .github/workflows/ci.yml
@@ -20,7 +20,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.8", "3.9", "3.10", "3.11"]
python-version: ["3.9", "3.10", "3.11", "3.12"]

steps:
- uses: actions/checkout@v3
@@ -52,8 +52,6 @@ jobs:
- name: Type check
run: |
make type
# skip mypy, jax doesn't have its latest version for python 3.8
if: "!(matrix.python-version == '3.8')"
- name: Test with pytest
run: |
make pytest
44 changes: 44 additions & 0 deletions README.md
@@ -18,8 +18,11 @@ Implemented algorithms:
- [Twin Delayed DDPG (TD3)](https://arxiv.org/abs/1802.09477)
- [Deep Deterministic Policy Gradient (DDPG)](https://arxiv.org/abs/1509.02971)
- [Batch Normalization in Deep Reinforcement Learning (CrossQ)](https://openreview.net/forum?id=PczQtTsTIX)
- [Simplicity Bias for Scaling Up Parameters in Deep Reinforcement Learning (SimBa)](https://openreview.net/forum?id=jXLiDKsuDo)


Note: parameter resets for off-policy algorithms can be activated by passing a list of timesteps to the model constructor (e.g., `param_resets=[int(1e5), int(5e5)]` to reset the parameters and optimizers after 100,000 and 500,000 timesteps).
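
A minimal sketch of how this can look in practice (the algorithm and environment chosen here are arbitrary; `param_resets` is simply forwarded to the constructor as described above):

```python
from sbx import SAC

# Reset the parameters and optimizer states after 100_000 and 500_000 timesteps
model = SAC(
    "MlpPolicy",
    "Pendulum-v1",
    param_resets=[int(1e5), int(5e5)],
    verbose=1,
)
model.learn(total_timesteps=600_000, progress_bar=True)
```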

### Install using pip

For the latest master version:
@@ -132,6 +135,47 @@ Having a higher learning rate for the q-value function is also helpful: `qf_learning_rate=1e-3`.

Note: when using the DroQ configuration with CrossQ, you should set `layer_norm=False` as there is already batch normalization.

## Note about SimBa

[SimBa](https://openreview.net/forum?id=jXLiDKsuDo) is a residual network architecture for off-policy algorithms (SAC, TQC, ...), designed to make scaling up the number of parameters easier.

Some recommended hyperparameters (tested on MuJoCo and PyBullet environments):
```python
import optax

default_hyperparams = dict(
    n_envs=1,
    n_timesteps=int(1e6),
    policy="SimbaPolicy",
    learning_rate=3e-4,
    # qf_learning_rate=1e-3,
    policy_kwargs={
        "optimizer_class": optax.adamw,
        # "optimizer_kwargs": {"weight_decay": 0.01},
        # Note: here [128] represents one residual block, not just a single layer
        "net_arch": {"pi": [128], "qf": [256, 256]},
        "n_critics": 2,
    },
    learning_starts=10_000,
    # Important: input normalization using VecNormalize
    normalize={"norm_obs": True, "norm_reward": False},
)

hyperparams = {}

# You can also loop over gym.registry
for env_id in [
    "HalfCheetah-v4",
    "HalfCheetahBulletEnv-v0",
    "Ant-v4",
]:
    hyperparams[env_id] = default_hyperparams
```

and then train using the RL Zoo script defined above: `python train.py --algo tqc --env HalfCheetah-v4 -c simba.py -P`.
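
Outside the RL Zoo, a rough equivalent can be sketched directly with SBX; this assumes `"SimbaPolicy"` is accepted as a policy name by `TQC` (as the config above suggests) and uses `VecNormalize` for the input normalization:

```python
import gymnasium as gym
import optax
from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize

from sbx import TQC

# Input normalization (the `normalize` entry above) via VecNormalize
vec_env = VecNormalize(
    DummyVecEnv([lambda: gym.make("HalfCheetah-v4")]),
    norm_obs=True,
    norm_reward=False,
)

model = TQC(
    "SimbaPolicy",
    vec_env,
    learning_rate=3e-4,
    policy_kwargs={
        "optimizer_class": optax.adamw,
        "net_arch": {"pi": [128], "qf": [256, 256]},
        "n_critics": 2,
    },
    learning_starts=10_000,
    verbose=1,
)
model.learn(total_timesteps=1_000_000, progress_bar=True)
```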


## Benchmark

A partial benchmark can be found on [OpenRL Benchmark](https://wandb.ai/openrlbenchmark/sbx) where you can also find several [reports](https://wandb.ai/openrlbenchmark/sbx/reportlist).
16 changes: 9 additions & 7 deletions pyproject.toml
@@ -1,8 +1,8 @@
[tool.ruff]
# Same as Black.
line-length = 127
# Assume Python 3.8
target-version = "py38"
# Assume Python 3.9
target-version = "py39"

[tool.ruff.lint]
# See https://beta.ruff.rs/docs/rules/
@@ -28,9 +28,7 @@ show_error_codes = true

[tool.pytest.ini_options]
# Deterministic ordering for tests; useful for pytest-xdist.
env = [
"PYTHONHASHSEED=0"
]
env = ["PYTHONHASHSEED=0"]

filterwarnings = [
# Tensorboard warnings
@@ -41,7 +39,7 @@ filterwarnings = [
"ignore:rich is experimental",
]
markers = [
"expensive: marks tests as expensive (deselect with '-m \"not expensive\"')"
"expensive: marks tests as expensive (deselect with '-m \"not expensive\"')",
]

[tool.coverage.run]
@@ -50,4 +48,8 @@ branch = false
omit = ["tests/*", "setup.py"]

[tool.coverage.report]
exclude_lines = [ "pragma: no cover", "raise NotImplementedError()", "if typing.TYPE_CHECKING:"]
exclude_lines = [
"pragma: no cover",
"raise NotImplementedError()",
"if typing.TYPE_CHECKING:",
]
2 changes: 1 addition & 1 deletion sbx/__init__.py
@@ -23,11 +23,11 @@ def DroQ(*args, **kwargs):


__all__ = [
"CrossQ",
"DDPG",
"DQN",
"PPO",
"SAC",
"TD3",
"TQC",
"CrossQ",
]
25 changes: 23 additions & 2 deletions sbx/common/jax_layers.py
@@ -1,5 +1,7 @@
from typing import Any, Callable, Optional, Sequence, Tuple, Union
from collections.abc import Sequence
from typing import Any, Callable, Optional, Union

import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.linen.module import Module, compact, merge_param
@@ -8,7 +10,7 @@

PRNGKey = Any
Array = Any
Shape = Tuple[int, ...]
Shape = tuple[int, ...]
Dtype = Any # this could be a real type?
Axes = Union[int, Sequence[int]]

@@ -204,3 +206,22 @@ def __call__(self, x, use_running_average: Optional[bool] = None):
self.bias_init,
self.scale_init,
)


# Adapted from simba: https://github.com/SonyResearch/simba
class SimbaResidualBlock(nn.Module):
    hidden_dim: int
    activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu
    # "the MLP is structured with an inverted bottleneck, where the hidden
    # dimension is expanded to 4 * hidden_dim"
    scale_factor: int = 4
    norm_layer: type[nn.Module] = nn.LayerNorm

    @nn.compact
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        residual = x
        x = self.norm_layer()(x)
        x = nn.Dense(self.hidden_dim * self.scale_factor, kernel_init=nn.initializers.he_normal())(x)
        x = self.activation_fn(x)
        x = nn.Dense(self.hidden_dim, kernel_init=nn.initializers.he_normal())(x)
        return residual + x
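
For intuition only, a hypothetical sketch (reusing the imports above) of how such residual blocks are typically composed in the SimBa paper: a linear embedding, a stack of residual blocks, and a final layer norm. This is an illustration, not necessarily the exact network built elsewhere in this commit:

```python
class SimbaEncoder(nn.Module):
    # Hypothetical composition, for illustration only
    hidden_dim: int = 128
    num_blocks: int = 2

    @nn.compact
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        # Linear embedding of the (normalized) observation
        x = nn.Dense(self.hidden_dim, kernel_init=nn.initializers.orthogonal())(x)
        # Stack of pre-norm residual blocks
        for _ in range(self.num_blocks):
            x = SimbaResidualBlock(self.hidden_dim)(x)
        # Final layer norm before the output head
        return nn.LayerNorm()(x)
```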
33 changes: 25 additions & 8 deletions sbx/common/off_policy_algorithm.py
@@ -1,6 +1,6 @@
import io
import pathlib
from typing import Any, Dict, List, Optional, Tuple, Type, Union
from typing import Any, Optional, Union

import jax
import numpy as np
@@ -17,7 +17,7 @@
class OffPolicyAlgorithmJax(OffPolicyAlgorithm):
def __init__(
self,
policy: Type[BasePolicy],
policy: type[BasePolicy],
env: Union[GymEnv, str],
learning_rate: Union[float, Schedule],
qf_learning_rate: Optional[float] = None,
@@ -26,13 +26,13 @@ def __init__(
batch_size: int = 256,
tau: float = 0.005,
gamma: float = 0.99,
train_freq: Union[int, Tuple[int, str]] = (1, "step"),
train_freq: Union[int, tuple[int, str]] = (1, "step"),
gradient_steps: int = 1,
action_noise: Optional[ActionNoise] = None,
replay_buffer_class: Optional[Type[ReplayBuffer]] = None,
replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
replay_buffer_class: Optional[type[ReplayBuffer]] = None,
replay_buffer_kwargs: Optional[dict[str, Any]] = None,
optimize_memory_usage: bool = False,
policy_kwargs: Optional[Dict[str, Any]] = None,
policy_kwargs: Optional[dict[str, Any]] = None,
tensorboard_log: Optional[str] = None,
verbose: int = 0,
device: str = "auto",
@@ -43,7 +43,9 @@ def __init__(
sde_sample_freq: int = -1,
use_sde_at_warmup: bool = False,
sde_support: bool = True,
supported_action_spaces: Optional[Tuple[Type[spaces.Space], ...]] = None,
stats_window_size: int = 100,
param_resets: Optional[list[int]] = None,
supported_action_spaces: Optional[tuple[type[spaces.Space], ...]] = None,
):
super().__init__(
policy=policy,
@@ -62,6 +64,7 @@ def __init__(
use_sde=use_sde,
sde_sample_freq=sde_sample_freq,
use_sde_at_warmup=use_sde_at_warmup,
stats_window_size=stats_window_size,
policy_kwargs=policy_kwargs,
tensorboard_log=tensorboard_log,
verbose=verbose,
@@ -74,11 +77,25 @@ def __init__(
        self.key = jax.random.PRNGKey(0)
        # Note: we do not allow schedule for it
        self.qf_learning_rate = qf_learning_rate
        self.param_resets = param_resets
        self.reset_idx = 0

    def _maybe_reset_params(self) -> None:
        # Maybe reset the parameters
        if (
            self.param_resets
            and self.reset_idx < len(self.param_resets)
            and self.num_timesteps >= self.param_resets[self.reset_idx]
        ):
            # Note: we are not resetting the entropy coeff
            assert isinstance(self.qf_learning_rate, float)
            self.key = self.policy.build(self.key, self.lr_schedule, self.qf_learning_rate)
            self.reset_idx += 1

    def _get_torch_save_params(self):
        return [], []

    def _excluded_save_params(self) -> List[str]:
    def _excluded_save_params(self) -> list[str]:
        excluded = super()._excluded_save_params()
        excluded.remove("policy")
        return excluded
10 changes: 5 additions & 5 deletions sbx/common/on_policy_algorithm.py
@@ -1,4 +1,4 @@
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
from typing import Any, Optional, TypeVar, Union

import gymnasium as gym
import jax
@@ -24,7 +24,7 @@ class OnPolicyAlgorithmJax(OnPolicyAlgorithm):

def __init__(
self,
policy: Union[str, Type[BasePolicy]],
policy: Union[str, type[BasePolicy]],
env: Union[GymEnv, str],
learning_rate: Union[float, Schedule],
n_steps: int,
@@ -37,12 +37,12 @@ def __init__(
sde_sample_freq: int,
tensorboard_log: Optional[str] = None,
monitor_wrapper: bool = True,
policy_kwargs: Optional[Dict[str, Any]] = None,
policy_kwargs: Optional[dict[str, Any]] = None,
verbose: int = 0,
seed: Optional[int] = None,
device: str = "auto",
_init_setup_model: bool = True,
supported_action_spaces: Optional[Tuple[Type[spaces.Space], ...]] = None,
supported_action_spaces: Optional[tuple[type[spaces.Space], ...]] = None,
):
super().__init__(
policy=policy, # type: ignore[arg-type]
@@ -70,7 +70,7 @@ def __init__(
    def _get_torch_save_params(self):
        return [], []

    def _excluded_save_params(self) -> List[str]:
    def _excluded_save_params(self) -> list[str]:
        excluded = super()._excluded_save_params()
        excluded.remove("policy")
        return excluded