From 9cad1d0753b6a405f16c92392bb0ad5f88b8e516 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Tue, 14 Jan 2025 14:20:44 +0100 Subject: [PATCH] Add SimBa Policy: Simplicity Bias for Scaling Up Parameters in DRL (#59) * Start testing simba * Quick try with CrossQ * Add actor for CrossQ * Add simba net for TQC * Remove unused param * Add parameter resets for TQC * Fix reset * Add missing param * Update documentation * Add parameter resets * Reformat pyproject.toml * Refactor: share actor between SAC and TQC * Add run tests for simba * Upgrade to python 3.9 (#64) * Fix mypy error, update version --- .github/workflows/ci.yml | 4 +- README.md | 44 +++++++ pyproject.toml | 16 ++- sbx/__init__.py | 2 +- sbx/common/jax_layers.py | 25 +++- sbx/common/off_policy_algorithm.py | 33 +++-- sbx/common/on_policy_algorithm.py | 10 +- sbx/common/policies.py | 139 ++++++++++++++++++- sbx/crossq/crossq.py | 28 ++-- sbx/crossq/policies.py | 205 +++++++++++++++++++++++++++-- sbx/ddpg/ddpg.py | 12 +- sbx/dqn/dqn.py | 14 +- sbx/dqn/policies.py | 8 +- sbx/ppo/policies.py | 9 +- sbx/ppo/ppo.py | 8 +- sbx/sac/policies.py | 103 +++++++++------ sbx/sac/sac.py | 27 ++-- sbx/td3/policies.py | 9 +- sbx/td3/td3.py | 23 ++-- sbx/tqc/policies.py | 130 +++++++++--------- sbx/tqc/tqc.py | 34 ++--- sbx/version.txt | 2 +- setup.py | 15 +-- tests/test_flatten.py | 4 +- tests/test_run.py | 21 ++- tests/test_spaces.py | 4 +- 26 files changed, 702 insertions(+), 227 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 70548c8..1c8d7fb 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -20,7 +20,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.8", "3.9", "3.10", "3.11"] + python-version: ["3.9", "3.10", "3.11", "3.12"] steps: - uses: actions/checkout@v3 @@ -52,8 +52,6 @@ jobs: - name: Type check run: | make type - # skip mypy, jax doesn't have its latest version for python 3.8 - if: "!(matrix.python-version == '3.8')" - name: Test with pytest run: | make pytest diff --git a/README.md b/README.md index 7cf8d55..de49233 100644 --- a/README.md +++ b/README.md @@ -18,8 +18,11 @@ Implemented algorithms: - [Twin Delayed DDPG (TD3)](https://arxiv.org/abs/1802.09477) - [Deep Deterministic Policy Gradient (DDPG)](https://arxiv.org/abs/1509.02971) - [Batch Normalization in Deep Reinforcement Learning (CrossQ)](https://openreview.net/forum?id=PczQtTsTIX) +- [Simplicity Bias for Scaling Up Parameters in Deep Reinforcement Learning (SimBa)](https://openreview.net/forum?id=jXLiDKsuDo) +Note: parameter resets for off-policy algorithms can be activated by passing a list of timesteps to the model constructor (ex: `param_resets=[int(1e5), int(5e5)]` to reset parameters and optimizers after 100_000 and 500_000 timesteps. + ### Install using pip For the latest master version: @@ -132,6 +135,47 @@ Having a higher learning rate for the q-value function is also helpful: `qf_lear Note: when using the DroQ configuration with CrossQ, you should set `layer_norm=False` as there is already batch normalization. +## Note about SimBa + +[SimBa](https://openreview.net/forum?id=jXLiDKsuDo) is a special network architecture for off-policy algorithms (SAC, TQC, ...). 
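A minimal usage sketch (illustrative values; it relies only on what this patch adds, namely the `"SimbaPolicy"` alias registered for SAC, TQC and CrossQ, the `param_resets` constructor argument, and the convention that one `net_arch` entry means one residual block):

```python
from sbx import SAC

# "SimbaPolicy" selects the SimBa residual actor/critic networks.
# Each entry in net_arch stands for one residual block, not a single layer.
model = SAC(
    "SimbaPolicy",
    "Pendulum-v1",
    policy_kwargs={"net_arch": {"pi": [128], "qf": [256, 256]}, "n_critics": 2},
    # Optional: fully re-initialize parameters and optimizer states at these timesteps
    param_resets=[int(1e5), int(5e5)],
    learning_starts=10_000,
    verbose=1,
)
model.learn(total_timesteps=int(1e6))
```

For real training runs, the recommended configuration below additionally uses `optax.adamw` and observation normalization (`VecNormalize`, via the RL Zoo `normalize` option).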
+ +Some recommended hyperparameters (tested on MuJoCo and PyBullet environments): +```python +import optax + + +default_hyperparams = dict( + n_envs=1, + n_timesteps=int(1e6), + policy="SimbaPolicy", + learning_rate=3e-4, + # qf_learning_rate=1e-3, + policy_kwargs={ + "optimizer_class": optax.adamw, + # "optimizer_kwargs": {"weight_decay": 0.01}, + # Note: here [128] represent a residual block, not just a single layer + "net_arch": {"pi": [128], "qf": [256, 256]}, + "n_critics": 2, + }, + learning_starts=10_000, + # Important: input normalization using VecNormalize + normalize={"norm_obs": True, "norm_reward": False}, +) + +hyperparams = {} + +# You can also loop gym.registry +for env_id in [ + "HalfCheetah-v4", + "HalfCheetahBulletEnv-v0", + "Ant-v4", +]: + hyperparams[env_id] = default_hyperparams +``` + +and then using the RL Zoo script defined above: `python train.py --algo tqc --env HalfCheetah-v4 -c simba.py -P`. + + ## Benchmark A partial benchmark can be found on [OpenRL Benchmark](https://wandb.ai/openrlbenchmark/sbx) where you can also find several [reports](https://wandb.ai/openrlbenchmark/sbx/reportlist). diff --git a/pyproject.toml b/pyproject.toml index bb89786..57514e5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,8 +1,8 @@ [tool.ruff] # Same as Black. line-length = 127 -# Assume Python 3.8 -target-version = "py38" +# Assume Python 3.9 +target-version = "py39" [tool.ruff.lint] # See https://beta.ruff.rs/docs/rules/ @@ -28,9 +28,7 @@ show_error_codes = true [tool.pytest.ini_options] # Deterministic ordering for tests; useful for pytest-xdist. -env = [ - "PYTHONHASHSEED=0" -] +env = ["PYTHONHASHSEED=0"] filterwarnings = [ # Tensorboard warnings @@ -41,7 +39,7 @@ filterwarnings = [ "ignore:rich is experimental", ] markers = [ - "expensive: marks tests as expensive (deselect with '-m \"not expensive\"')" + "expensive: marks tests as expensive (deselect with '-m \"not expensive\"')", ] [tool.coverage.run] @@ -50,4 +48,8 @@ branch = false omit = ["tests/*", "setup.py"] [tool.coverage.report] -exclude_lines = [ "pragma: no cover", "raise NotImplementedError()", "if typing.TYPE_CHECKING:"] +exclude_lines = [ + "pragma: no cover", + "raise NotImplementedError()", + "if typing.TYPE_CHECKING:", +] diff --git a/sbx/__init__.py b/sbx/__init__.py index a7c13bc..c2762bc 100644 --- a/sbx/__init__.py +++ b/sbx/__init__.py @@ -23,11 +23,11 @@ def DroQ(*args, **kwargs): __all__ = [ - "CrossQ", "DDPG", "DQN", "PPO", "SAC", "TD3", "TQC", + "CrossQ", ] diff --git a/sbx/common/jax_layers.py b/sbx/common/jax_layers.py index 67e79e8..01880aa 100644 --- a/sbx/common/jax_layers.py +++ b/sbx/common/jax_layers.py @@ -1,5 +1,7 @@ -from typing import Any, Callable, Optional, Sequence, Tuple, Union +from collections.abc import Sequence +from typing import Any, Callable, Optional, Union +import flax.linen as nn import jax import jax.numpy as jnp from flax.linen.module import Module, compact, merge_param @@ -8,7 +10,7 @@ PRNGKey = Any Array = Any -Shape = Tuple[int, ...] +Shape = tuple[int, ...] Dtype = Any # this could be a real type? 
Axes = Union[int, Sequence[int]] @@ -204,3 +206,22 @@ def __call__(self, x, use_running_average: Optional[bool] = None): self.bias_init, self.scale_init, ) + + +# Adapted from simba: https://github.com/SonyResearch/simba +class SimbaResidualBlock(nn.Module): + hidden_dim: int + activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu + # "the MLP is structured with an inverted bottleneck, where the hidden + # dimension is expanded to 4 * hidden_dim" + scale_factor: int = 4 + norm_layer: type[nn.Module] = nn.LayerNorm + + @nn.compact + def __call__(self, x: jnp.ndarray) -> jnp.ndarray: + residual = x + x = self.norm_layer()(x) + x = nn.Dense(self.hidden_dim * self.scale_factor, kernel_init=nn.initializers.he_normal())(x) + x = self.activation_fn(x) + x = nn.Dense(self.hidden_dim, kernel_init=nn.initializers.he_normal())(x) + return residual + x diff --git a/sbx/common/off_policy_algorithm.py b/sbx/common/off_policy_algorithm.py index ba0b9ed..b7b385c 100644 --- a/sbx/common/off_policy_algorithm.py +++ b/sbx/common/off_policy_algorithm.py @@ -1,6 +1,6 @@ import io import pathlib -from typing import Any, Dict, List, Optional, Tuple, Type, Union +from typing import Any, Optional, Union import jax import numpy as np @@ -17,7 +17,7 @@ class OffPolicyAlgorithmJax(OffPolicyAlgorithm): def __init__( self, - policy: Type[BasePolicy], + policy: type[BasePolicy], env: Union[GymEnv, str], learning_rate: Union[float, Schedule], qf_learning_rate: Optional[float] = None, @@ -26,13 +26,13 @@ def __init__( batch_size: int = 256, tau: float = 0.005, gamma: float = 0.99, - train_freq: Union[int, Tuple[int, str]] = (1, "step"), + train_freq: Union[int, tuple[int, str]] = (1, "step"), gradient_steps: int = 1, action_noise: Optional[ActionNoise] = None, - replay_buffer_class: Optional[Type[ReplayBuffer]] = None, - replay_buffer_kwargs: Optional[Dict[str, Any]] = None, + replay_buffer_class: Optional[type[ReplayBuffer]] = None, + replay_buffer_kwargs: Optional[dict[str, Any]] = None, optimize_memory_usage: bool = False, - policy_kwargs: Optional[Dict[str, Any]] = None, + policy_kwargs: Optional[dict[str, Any]] = None, tensorboard_log: Optional[str] = None, verbose: int = 0, device: str = "auto", @@ -43,7 +43,9 @@ def __init__( sde_sample_freq: int = -1, use_sde_at_warmup: bool = False, sde_support: bool = True, - supported_action_spaces: Optional[Tuple[Type[spaces.Space], ...]] = None, + stats_window_size: int = 100, + param_resets: Optional[list[int]] = None, + supported_action_spaces: Optional[tuple[type[spaces.Space], ...]] = None, ): super().__init__( policy=policy, @@ -62,6 +64,7 @@ def __init__( use_sde=use_sde, sde_sample_freq=sde_sample_freq, use_sde_at_warmup=use_sde_at_warmup, + stats_window_size=stats_window_size, policy_kwargs=policy_kwargs, tensorboard_log=tensorboard_log, verbose=verbose, @@ -74,11 +77,25 @@ def __init__( self.key = jax.random.PRNGKey(0) # Note: we do not allow schedule for it self.qf_learning_rate = qf_learning_rate + self.param_resets = param_resets + self.reset_idx = 0 + + def _maybe_reset_params(self) -> None: + # Maybe reset the parameters + if ( + self.param_resets + and self.reset_idx < len(self.param_resets) + and self.num_timesteps >= self.param_resets[self.reset_idx] + ): + # Note: we are not resetting the entropy coeff + assert isinstance(self.qf_learning_rate, float) + self.key = self.policy.build(self.key, self.lr_schedule, self.qf_learning_rate) + self.reset_idx += 1 def _get_torch_save_params(self): return [], [] - def _excluded_save_params(self) -> List[str]: + 
def _excluded_save_params(self) -> list[str]: excluded = super()._excluded_save_params() excluded.remove("policy") return excluded diff --git a/sbx/common/on_policy_algorithm.py b/sbx/common/on_policy_algorithm.py index 015fdc2..36b8ea6 100644 --- a/sbx/common/on_policy_algorithm.py +++ b/sbx/common/on_policy_algorithm.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union +from typing import Any, Optional, TypeVar, Union import gymnasium as gym import jax @@ -24,7 +24,7 @@ class OnPolicyAlgorithmJax(OnPolicyAlgorithm): def __init__( self, - policy: Union[str, Type[BasePolicy]], + policy: Union[str, type[BasePolicy]], env: Union[GymEnv, str], learning_rate: Union[float, Schedule], n_steps: int, @@ -37,12 +37,12 @@ def __init__( sde_sample_freq: int, tensorboard_log: Optional[str] = None, monitor_wrapper: bool = True, - policy_kwargs: Optional[Dict[str, Any]] = None, + policy_kwargs: Optional[dict[str, Any]] = None, verbose: int = 0, seed: Optional[int] = None, device: str = "auto", _init_setup_model: bool = True, - supported_action_spaces: Optional[Tuple[Type[spaces.Space], ...]] = None, + supported_action_spaces: Optional[tuple[type[spaces.Space], ...]] = None, ): super().__init__( policy=policy, # type: ignore[arg-type] @@ -70,7 +70,7 @@ def __init__( def _get_torch_save_params(self): return [], [] - def _excluded_save_params(self) -> List[str]: + def _excluded_save_params(self) -> list[str]: excluded = super()._excluded_save_params() excluded.remove("policy") return excluded diff --git a/sbx/common/policies.py b/sbx/common/policies.py index 4b6df5d..59c53eb 100644 --- a/sbx/common/policies.py +++ b/sbx/common/policies.py @@ -1,15 +1,22 @@ # import copy -from typing import Callable, Dict, Optional, Sequence, Tuple, Union, no_type_check +from collections.abc import Sequence +from typing import Callable, Optional, Union, no_type_check import flax.linen as nn import jax import jax.numpy as jnp import numpy as np +import tensorflow_probability.substrates.jax as tfp from gymnasium import spaces from stable_baselines3.common.policies import BasePolicy from stable_baselines3.common.preprocessing import is_image_space, maybe_transpose from stable_baselines3.common.utils import is_vectorized_observation +from sbx.common.distributions import TanhTransformedDistribution +from sbx.common.jax_layers import SimbaResidualBlock + +tfd = tfp.distributions + class Flatten(nn.Module): """ @@ -43,11 +50,11 @@ def select_action(actor_state, obervations): @no_type_check def predict( self, - observation: Union[np.ndarray, Dict[str, np.ndarray]], - state: Optional[Tuple[np.ndarray, ...]] = None, + observation: Union[np.ndarray, dict[str, np.ndarray]], + state: Optional[tuple[np.ndarray, ...]] = None, episode_start: Optional[np.ndarray] = None, deterministic: bool = False, - ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]: + ) -> tuple[np.ndarray, Optional[tuple[np.ndarray, ...]]]: # self.set_training_mode(False) observation, vectorized_env = self.prepare_obs(observation) @@ -74,7 +81,7 @@ def predict( return actions, state - def prepare_obs(self, observation: Union[np.ndarray, Dict[str, np.ndarray]]) -> Tuple[np.ndarray, bool]: + def prepare_obs(self, observation: Union[np.ndarray, dict[str, np.ndarray]]) -> tuple[np.ndarray, bool]: vectorized_env = False if isinstance(observation, dict): assert isinstance(self.observation_space, spaces.Dict) @@ -84,7 +91,7 @@ def prepare_obs(self, observation: Union[np.ndarray, Dict[str, np.ndarray]]) -> # Add batch dim and 
concatenate observation = np.concatenate( - [observation[key].reshape(-1, *self.observation_space[key].shape) for key in keys], + [observation[key].reshape(-1, *self.observation_space[key].shape) for key in keys], # type: ignore[misc] axis=1, ) # need to copy the dict as the dict in VecFrameStack will become a torch tensor @@ -127,6 +134,7 @@ class ContinuousCritic(nn.Module): use_layer_norm: bool = False dropout_rate: Optional[float] = None activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu + output_dim: int = 1 @nn.compact def __call__(self, x: jnp.ndarray, action: jnp.ndarray) -> jnp.ndarray: @@ -139,7 +147,32 @@ def __call__(self, x: jnp.ndarray, action: jnp.ndarray) -> jnp.ndarray: if self.use_layer_norm: x = nn.LayerNorm()(x) x = self.activation_fn(x) - x = nn.Dense(1)(x) + x = nn.Dense(self.output_dim)(x) + return x + + +class SimbaContinuousCritic(nn.Module): + net_arch: Sequence[int] + use_layer_norm: bool = False # for consistency, not used + dropout_rate: Optional[float] = None + activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu + output_dim: int = 1 + scale_factor: int = 4 + + @nn.compact + def __call__(self, x: jnp.ndarray, action: jnp.ndarray) -> jnp.ndarray: + x = Flatten()(x) + x = jnp.concatenate([x, action], -1) + # Note: simba was using kernel_init=orthogonal_init(1) + x = nn.Dense(self.net_arch[0])(x) + for n_units in self.net_arch: + x = SimbaResidualBlock(n_units, self.activation_fn, self.scale_factor)(x) + # TODO: double check where to put the dropout + if self.dropout_rate is not None and self.dropout_rate > 0: + x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=False) + x = nn.LayerNorm()(x) + + x = nn.Dense(self.output_dim)(x) return x @@ -149,6 +182,7 @@ class VectorCritic(nn.Module): dropout_rate: Optional[float] = None n_critics: int = 2 activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu + output_dim: int = 1 @nn.compact def __call__(self, obs: jnp.ndarray, action: jnp.ndarray): @@ -167,5 +201,96 @@ def __call__(self, obs: jnp.ndarray, action: jnp.ndarray): dropout_rate=self.dropout_rate, net_arch=self.net_arch, activation_fn=self.activation_fn, + output_dim=self.output_dim, )(obs, action) return q_values + + +class SimbaVectorCritic(nn.Module): + net_arch: Sequence[int] + # Note: we have use_layer_norm for consistency but it is not used (always on) + use_layer_norm: bool = True + dropout_rate: Optional[float] = None + n_critics: int = 2 + activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu + output_dim: int = 1 + + @nn.compact + def __call__(self, obs: jnp.ndarray, action: jnp.ndarray): + # Idea taken from https://github.com/perrin-isir/xpag + # Similar to https://github.com/tinkoff-ai/CORL for PyTorch + vmap_critic = nn.vmap( + SimbaContinuousCritic, + variable_axes={"params": 0}, # parameters not shared between the critics + split_rngs={"params": True, "dropout": True}, # different initializations + in_axes=None, + out_axes=0, + axis_size=self.n_critics, + ) + q_values = vmap_critic( + dropout_rate=self.dropout_rate, + net_arch=self.net_arch, + activation_fn=self.activation_fn, + output_dim=self.output_dim, + )(obs, action) + return q_values + + +class SquashedGaussianActor(nn.Module): + net_arch: Sequence[int] + action_dim: int + log_std_min: float = -20 + log_std_max: float = 2 + activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu + + def get_std(self): + # Make it work with gSDE + return jnp.array(0.0) + + @nn.compact + def __call__(self, x: jnp.ndarray) -> tfd.Distribution: # type: 
ignore[name-defined] + x = Flatten()(x) + for n_units in self.net_arch: + x = nn.Dense(n_units)(x) + x = self.activation_fn(x) + mean = nn.Dense(self.action_dim)(x) + log_std = nn.Dense(self.action_dim)(x) + log_std = jnp.clip(log_std, self.log_std_min, self.log_std_max) + dist = TanhTransformedDistribution( + tfd.MultivariateNormalDiag(loc=mean, scale_diag=jnp.exp(log_std)), + ) + return dist + + +class SimbaSquashedGaussianActor(nn.Module): + # Note: each element in net_arch correpond to a residual block + # not just a single layer + net_arch: Sequence[int] + action_dim: int + # num_blocks: int = 2 + log_std_min: float = -20 + log_std_max: float = 2 + activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu + scale_factor: int = 4 + + def get_std(self): + # Make it work with gSDE + return jnp.array(0.0) + + @nn.compact + def __call__(self, x: jnp.ndarray) -> tfd.Distribution: # type: ignore[name-defined] + x = Flatten()(x) + + # Note: simba was using kernel_init=orthogonal_init(1) + x = nn.Dense(self.net_arch[0])(x) + for n_units in self.net_arch: + x = SimbaResidualBlock(n_units, self.activation_fn, self.scale_factor)(x) + x = nn.LayerNorm()(x) + + mean = nn.Dense(self.action_dim)(x) + log_std = nn.Dense(self.action_dim)(x) + log_std = jnp.clip(log_std, self.log_std_min, self.log_std_max) + dist = TanhTransformedDistribution( + tfd.MultivariateNormalDiag(loc=mean, scale_diag=jnp.exp(log_std)), + ) + return dist diff --git a/sbx/crossq/crossq.py b/sbx/crossq/crossq.py index f888672..03404bb 100644 --- a/sbx/crossq/crossq.py +++ b/sbx/crossq/crossq.py @@ -1,5 +1,5 @@ from functools import partial -from typing import Any, ClassVar, Dict, Literal, Optional, Tuple, Type, Union +from typing import Any, ClassVar, Literal, Optional, Union import flax import flax.linen as nn @@ -16,7 +16,7 @@ from sbx.common.off_policy_algorithm import OffPolicyAlgorithmJax from sbx.common.type_aliases import BatchNormTrainState, ReplayBufferSamplesNp -from sbx.crossq.policies import CrossQPolicy +from sbx.crossq.policies import CrossQPolicy, SimbaCrossQPolicy class EntropyCoef(nn.Module): @@ -40,8 +40,9 @@ def __call__(self) -> float: class CrossQ(OffPolicyAlgorithmJax): - policy_aliases: ClassVar[Dict[str, Type[CrossQPolicy]]] = { # type: ignore[assignment] + policy_aliases: ClassVar[dict[str, type[CrossQPolicy]]] = { # type: ignore[assignment] "MlpPolicy": CrossQPolicy, + "SimbaPolicy": SimbaCrossQPolicy, # Minimal dict support using flatten() "MultiInputPolicy": CrossQPolicy, } @@ -59,19 +60,21 @@ def __init__( learning_starts: int = 100, batch_size: int = 256, gamma: float = 0.99, - train_freq: Union[int, Tuple[int, str]] = 1, + train_freq: Union[int, tuple[int, str]] = 1, gradient_steps: int = 1, policy_delay: int = 3, action_noise: Optional[ActionNoise] = None, - replay_buffer_class: Optional[Type[ReplayBuffer]] = None, - replay_buffer_kwargs: Optional[Dict[str, Any]] = None, + replay_buffer_class: Optional[type[ReplayBuffer]] = None, + replay_buffer_kwargs: Optional[dict[str, Any]] = None, ent_coef: Union[str, float] = "auto", target_entropy: Union[Literal["auto"], float] = "auto", use_sde: bool = False, sde_sample_freq: int = -1, use_sde_at_warmup: bool = False, + stats_window_size: int = 100, tensorboard_log: Optional[str] = None, - policy_kwargs: Optional[Dict[str, Any]] = None, + policy_kwargs: Optional[dict[str, Any]] = None, + param_resets: Optional[list[int]] = None, # List of timesteps after which to reset the params verbose: int = 0, seed: Optional[int] = None, device: str = "auto", @@ 
-94,7 +97,9 @@ def __init__( use_sde=use_sde, sde_sample_freq=sde_sample_freq, use_sde_at_warmup=use_sde_at_warmup, + stats_window_size=stats_window_size, policy_kwargs=policy_kwargs, + param_resets=param_resets, tensorboard_log=tensorboard_log, verbose=verbose, seed=seed, @@ -189,6 +194,9 @@ def train(self, gradient_steps: int, batch_size: int) -> None: # Sample all at once for efficiency (so we can jit the for loop) data = self.replay_buffer.sample(batch_size * gradient_steps, env=self._vec_normalize_env) + # Maybe reset the parameters/optimizers fully + self._maybe_reset_params() + if isinstance(data.observations, dict): keys = list(self.observation_space.keys()) # type: ignore[attr-defined] obs = np.concatenate([data.observations[key].numpy() for key in keys], axis=1) @@ -258,7 +266,7 @@ def update_critic( def mse_loss( params: flax.core.FrozenDict, batch_stats: flax.core.FrozenDict, dropout_key: flax.core.FrozenDict - ) -> Tuple[jax.Array, jax.Array]: + ) -> tuple[jax.Array, jax.Array]: # Joint forward pass of obs/next_obs and actions/next_state_actions to have only # one forward pass with shape (n_critics, 2 * batch_size, 1). # @@ -320,7 +328,7 @@ def update_actor( def actor_loss( params: flax.core.FrozenDict, batch_stats: flax.core.FrozenDict - ) -> Tuple[jax.Array, Tuple[jax.Array, jax.Array]]: + ) -> tuple[jax.Array, tuple[jax.Array, jax.Array]]: dist, state_updates = actor_state.apply_fn( {"params": params, "batch_stats": batch_stats}, observations, @@ -416,7 +424,7 @@ def _train( }, } - def one_update(i: int, carry: Dict[str, Any]) -> Dict[str, Any]: + def one_update(i: int, carry: dict[str, Any]) -> dict[str, Any]: # Note: this method must be defined inline because # `fori_loop` expect a signature fn(index, carry) -> carry actor_state = carry["actor_state"] diff --git a/sbx/crossq/policies.py b/sbx/crossq/policies.py index 0093c47..63b61a8 100644 --- a/sbx/crossq/policies.py +++ b/sbx/crossq/policies.py @@ -1,4 +1,6 @@ -from typing import Any, Callable, Dict, List, Optional, Sequence, Union +from collections.abc import Sequence +from functools import partial +from typing import Any, Callable, Optional, Union import flax.linen as nn import jax @@ -10,7 +12,7 @@ from stable_baselines3.common.type_aliases import Schedule from sbx.common.distributions import TanhTransformedDistribution -from sbx.common.jax_layers import BatchRenorm +from sbx.common.jax_layers import BatchRenorm, SimbaResidualBlock from sbx.common.policies import BaseJaxPolicy, Flatten from sbx.common.type_aliases import BatchNormTrainState @@ -48,12 +50,52 @@ def __call__(self, x: jnp.ndarray, action: jnp.ndarray, train: bool = False) -> x = nn.LayerNorm()(x) x = self.activation_fn(x) if self.use_batch_norm: - x = BatchRenorm(use_running_average=not train, momentum=self.batch_norm_momentum)(x) + x = BatchRenorm( + use_running_average=not train, + momentum=self.batch_norm_momentum, + warmup_steps=self.renorm_warmup_steps, + )(x) x = nn.Dense(1)(x) return x +class SimbaCritic(nn.Module): + net_arch: Sequence[int] + dropout_rate: Optional[float] = None + batch_norm_momentum: float = 0.99 + renorm_warmup_steps: int = 100_000 + activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu + scale_factor: int = 4 + + @nn.compact + def __call__(self, x: jnp.ndarray, action: jnp.ndarray, train: bool = False) -> jnp.ndarray: + x = Flatten()(x) + x = jnp.concatenate([x, action], -1) + norm_layer = partial( + BatchRenorm, + use_running_average=not train, + momentum=self.batch_norm_momentum, + 
warmup_steps=self.renorm_warmup_steps, + ) + x = norm_layer()(x) + x = nn.Dense(self.net_arch[0])(x) + + for n_units in self.net_arch: + x = SimbaResidualBlock( + n_units, + self.activation_fn, + self.scale_factor, + norm_layer, # type: ignore[arg-type] + )(x) + # TODO: double check where to put the dropout + if self.dropout_rate is not None and self.dropout_rate > 0: + x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=False) + x = norm_layer()(x) + x = nn.Dense(1)(x) + return x + + class VectorCritic(nn.Module): net_arch: Sequence[int] use_layer_norm: bool = False @@ -88,6 +130,87 @@ def __call__(self, obs: jnp.ndarray, action: jnp.ndarray, train: bool = False): return q_values +class SimbaVectorCritic(nn.Module): + net_arch: Sequence[int] + use_layer_norm: bool = False # ignored + use_batch_norm: bool = True + batch_norm_momentum: float = 0.99 + renorm_warmup_steps: int = 100_000 + dropout_rate: Optional[float] = None + n_critics: int = 2 + activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu + scale_factor: int = 4 + + @nn.compact + def __call__(self, obs: jnp.ndarray, action: jnp.ndarray, train: bool = False): + # Idea taken from https://github.com/perrin-isir/xpag + # Similar to https://github.com/tinkoff-ai/CORL for PyTorch + vmap_critic = nn.vmap( + SimbaCritic, + variable_axes={"params": 0, "batch_stats": 0}, # parameters not shared between the critics + split_rngs={"params": True, "dropout": True, "batch_stats": True}, # different initializations + in_axes=None, + out_axes=0, + axis_size=self.n_critics, + ) + q_values = vmap_critic( + # use_layer_norm=self.use_layer_norm, + # use_batch_norm=self.use_batch_norm, + batch_norm_momentum=self.batch_norm_momentum, + renorm_warmup_steps=self.renorm_warmup_steps, + dropout_rate=self.dropout_rate, + net_arch=self.net_arch, + activation_fn=self.activation_fn, + scale_factor=self.scale_factor, + )(obs, action, train) + return q_values + + +class SimbaActor(nn.Module): + net_arch: Sequence[int] + action_dim: int + log_std_min: float = -20 + log_std_max: float = 2 + use_batch_norm: bool = True + batch_norm_momentum: float = 0.99 + renorm_warmup_steps: int = 100_000 + activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu + scale_factor: int = 4 + + def get_std(self): + # Make it work with gSDE + return jnp.array(0.0) + + @nn.compact + def __call__(self, x: jnp.ndarray, train: bool = False) -> tfd.Distribution: # type: ignore[name-defined] + x = Flatten()(x) + norm_layer = partial( + BatchRenorm, + use_running_average=not train, + momentum=self.batch_norm_momentum, + warmup_steps=self.renorm_warmup_steps, + ) + x = norm_layer()(x) + x = nn.Dense(self.net_arch[0])(x) + + for n_units in self.net_arch: + x = SimbaResidualBlock( + n_units, + self.activation_fn, + self.scale_factor, + norm_layer, # type: ignore[arg-type] + )(x) + x = norm_layer()(x) + + mean = nn.Dense(self.action_dim)(x) + log_std = nn.Dense(self.action_dim)(x) + log_std = jnp.clip(log_std, self.log_std_min, self.log_std_max) + dist = TanhTransformedDistribution( + tfd.MultivariateNormalDiag(loc=mean, scale_diag=jnp.exp(log_std)), + ) + return dist + + class Actor(nn.Module): net_arch: Sequence[int] action_dim: int @@ -119,7 +242,11 @@ def __call__(self, x: jnp.ndarray, train: bool = False) -> tfd.Distribution: # x = nn.Dense(n_units)(x) x = self.activation_fn(x) if self.use_batch_norm: - x = BatchRenorm(use_running_average=not train, momentum=self.batch_norm_momentum)(x) + x = BatchRenorm( + use_running_average=not train, + momentum=self.batch_norm_momentum, 
+ warmup_steps=self.renorm_warmup_steps, + )(x) mean = nn.Dense(self.action_dim)(x) log_std = nn.Dense(self.action_dim)(x) @@ -138,7 +265,7 @@ def __init__( observation_space: spaces.Space, action_space: spaces.Box, lr_schedule: Schedule, - net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None, + net_arch: Optional[Union[list[int], dict[str, list[int]]]] = None, dropout_rate: float = 0.0, layer_norm: bool = False, batch_norm: bool = True, # for critic @@ -153,12 +280,14 @@ def __init__( use_expln: bool = False, clip_mean: float = 2.0, features_extractor_class=None, - features_extractor_kwargs: Optional[Dict[str, Any]] = None, + features_extractor_kwargs: Optional[dict[str, Any]] = None, normalize_images: bool = True, optimizer_class: Callable[..., optax.GradientTransformation] = optax.adam, - optimizer_kwargs: Optional[Dict[str, Any]] = None, + optimizer_kwargs: Optional[dict[str, Any]] = None, n_critics: int = 2, share_features_extractor: bool = False, + actor_class: type[nn.Module] = Actor, + vector_critic_class: type[nn.Module] = VectorCritic, ): if optimizer_kwargs is None: # Note: the default value for b1 is 0.9 in Adam. @@ -183,6 +312,8 @@ def __init__( self.batch_norm_momentum = batch_norm_momentum self.batch_norm_actor = batch_norm_actor self.renorm_warmup_steps = renorm_warmup_steps + self.actor_class = actor_class + self.vector_critic_class = vector_critic_class if net_arch is not None: if isinstance(net_arch, list): @@ -216,7 +347,7 @@ def build(self, key: jax.Array, lr_schedule: Schedule, qf_learning_rate: float) obs = jnp.array([self.observation_space.sample()]) action = jnp.array([self.action_space.sample()]) - self.actor = Actor( + self.actor = self.actor_class( action_dim=int(np.prod(self.action_space.shape)), net_arch=self.net_arch_pi, use_batch_norm=self.batch_norm_actor, @@ -244,7 +375,7 @@ def build(self, key: jax.Array, lr_schedule: Schedule, qf_learning_rate: float) ), ) - self.qf = VectorCritic( + self.qf = self.vector_critic_class( dropout_rate=self.dropout_rate, use_layer_norm=self.layer_norm, use_batch_norm=self.batch_norm, @@ -319,3 +450,59 @@ def _predict(self, observation: np.ndarray, deterministic: bool = False) -> np.n if not self.use_sde: self.reset_noise() return self.sample_action(self.actor_state, observation, self.noise_key) + + +class SimbaCrossQPolicy(CrossQPolicy): + def __init__( + self, + observation_space: spaces.Space, + action_space: spaces.Box, + lr_schedule: Schedule, + net_arch: Optional[Union[list[int], dict[str, list[int]]]] = None, + dropout_rate: float = 0, + layer_norm: bool = False, + batch_norm: bool = True, + batch_norm_actor: bool = True, + batch_norm_momentum: float = 0.99, + renorm_warmup_steps: int = 100000, + use_sde: bool = False, + activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu, + log_std_init: float = -3, + use_expln: bool = False, + clip_mean: float = 2, + features_extractor_class=None, + features_extractor_kwargs: Optional[dict[str, Any]] = None, + normalize_images: bool = True, + optimizer_class: Callable[..., optax.GradientTransformation] = optax.adamw, + optimizer_kwargs: Optional[dict[str, Any]] = None, + n_critics: int = 2, + share_features_extractor: bool = False, + actor_class: type[nn.Module] = SimbaActor, + vector_critic_class: type[nn.Module] = SimbaVectorCritic, + ): + super().__init__( + observation_space, + action_space, + lr_schedule, + net_arch, + dropout_rate, + layer_norm, + batch_norm, + batch_norm_actor, + batch_norm_momentum, + renorm_warmup_steps, + use_sde, + activation_fn, + 
log_std_init, + use_expln, + clip_mean, + features_extractor_class, + features_extractor_kwargs, + normalize_images, + optimizer_class, + optimizer_kwargs, + n_critics, + share_features_extractor, + actor_class, + vector_critic_class, + ) diff --git a/sbx/ddpg/ddpg.py b/sbx/ddpg/ddpg.py index 12678cf..7ee5728 100644 --- a/sbx/ddpg/ddpg.py +++ b/sbx/ddpg/ddpg.py @@ -1,4 +1,4 @@ -from typing import Any, ClassVar, Dict, Optional, Tuple, Type, Union +from typing import Any, ClassVar, Optional, Union from stable_baselines3.common.buffers import ReplayBuffer from stable_baselines3.common.noise import ActionNoise @@ -9,7 +9,7 @@ class DDPG(TD3): - policy_aliases: ClassVar[Dict[str, Type[TD3Policy]]] = { + policy_aliases: ClassVar[dict[str, type[TD3Policy]]] = { "MlpPolicy": TD3Policy, } @@ -24,13 +24,13 @@ def __init__( batch_size: int = 256, tau: float = 0.005, gamma: float = 0.99, - train_freq: Union[int, Tuple[int, str]] = 1, + train_freq: Union[int, tuple[int, str]] = 1, gradient_steps: int = 1, action_noise: Optional[ActionNoise] = None, - replay_buffer_class: Optional[Type[ReplayBuffer]] = None, - replay_buffer_kwargs: Optional[Dict[str, Any]] = None, + replay_buffer_class: Optional[type[ReplayBuffer]] = None, + replay_buffer_kwargs: Optional[dict[str, Any]] = None, tensorboard_log: Optional[str] = None, - policy_kwargs: Optional[Dict[str, Any]] = None, + policy_kwargs: Optional[dict[str, Any]] = None, verbose: int = 0, seed: Optional[int] = None, device: str = "auto", diff --git a/sbx/dqn/dqn.py b/sbx/dqn/dqn.py index 852b823..98b0f42 100644 --- a/sbx/dqn/dqn.py +++ b/sbx/dqn/dqn.py @@ -1,5 +1,5 @@ import warnings -from typing import Any, ClassVar, Dict, Optional, Tuple, Type, Union +from typing import Any, ClassVar, Optional, Union import gymnasium as gym import jax @@ -15,7 +15,7 @@ class DQN(OffPolicyAlgorithmJax): - policy_aliases: ClassVar[Dict[str, Type[DQNPolicy]]] = { # type: ignore[assignment] + policy_aliases: ClassVar[dict[str, type[DQNPolicy]]] = { # type: ignore[assignment] "MlpPolicy": DQNPolicy, "CnnPolicy": CNNPolicy, } @@ -39,10 +39,10 @@ def __init__( exploration_final_eps: float = 0.05, optimize_memory_usage: bool = False, # Note: unused but to match SB3 API # max_grad_norm: float = 10, - train_freq: Union[int, Tuple[int, str]] = 4, + train_freq: Union[int, tuple[int, str]] = 4, gradient_steps: int = 1, tensorboard_log: Optional[str] = None, - policy_kwargs: Optional[Dict[str, Any]] = None, + policy_kwargs: Optional[dict[str, Any]] = None, verbose: int = 0, seed: Optional[int] = None, device: str = "auto", @@ -230,11 +230,11 @@ def _on_step(self) -> None: def predict( self, - observation: Union[np.ndarray, Dict[str, np.ndarray]], - state: Optional[Tuple[np.ndarray, ...]] = None, + observation: Union[np.ndarray, dict[str, np.ndarray]], + state: Optional[tuple[np.ndarray, ...]] = None, episode_start: Optional[np.ndarray] = None, deterministic: bool = False, - ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]: + ) -> tuple[np.ndarray, Optional[tuple[np.ndarray, ...]]]: """ Overrides the base_class predict function to include epsilon-greedy exploration. 
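The `SimbaCrossQPolicy` defined above is a thin wrapper: the patch adds `actor_class`/`vector_critic_class` arguments to the base CrossQ policy, which the Simba variant fills in with the residual networks (and it switches the default optimizer to `optax.adamw`). A sketch of that wiring done explicitly, with illustrative environment and `net_arch` values (this explicit combination is not exercised by the tests, unlike the `"SimbaPolicy"` alias used in `tests/test_run.py`):

```python
from sbx import CrossQ
from sbx.crossq.policies import SimbaActor, SimbaVectorCritic

# Roughly what CrossQ("SimbaPolicy", ...) sets up, done by hand through the new
# actor_class/vector_critic_class hooks of the base policy (the default optimizer
# stays Adam here, whereas SimbaCrossQPolicy defaults to AdamW).
model = CrossQ(
    "MlpPolicy",
    "Pendulum-v1",
    policy_kwargs={
        "net_arch": [64],  # one residual block of 64 units for actor and critics
        "actor_class": SimbaActor,
        "vector_critic_class": SimbaVectorCritic,
    },
)
model.learn(total_timesteps=5_000)
```

The `SimbaSACPolicy` and `SimbaTQCPolicy` further down follow the same pattern, only with the layer-norm-based residual actor and critics from `sbx/common/policies.py` instead of the batch-renorm variants.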
diff --git a/sbx/dqn/policies.py b/sbx/dqn/policies.py index 4cff77b..c03be4d 100644 --- a/sbx/dqn/policies.py +++ b/sbx/dqn/policies.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Any, Callable, Optional, Union import flax.linen as nn import jax @@ -62,13 +62,13 @@ def __init__( observation_space: spaces.Space, action_space: spaces.Discrete, lr_schedule: Schedule, - net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None, + net_arch: Optional[Union[list[int], dict[str, list[int]]]] = None, activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu, features_extractor_class=None, - features_extractor_kwargs: Optional[Dict[str, Any]] = None, + features_extractor_kwargs: Optional[dict[str, Any]] = None, normalize_images: bool = True, optimizer_class: Callable[..., optax.GradientTransformation] = optax.adam, - optimizer_kwargs: Optional[Dict[str, Any]] = None, + optimizer_kwargs: Optional[dict[str, Any]] = None, ): super().__init__( observation_space, diff --git a/sbx/ppo/policies.py b/sbx/ppo/policies.py index 54915c8..b1446e9 100644 --- a/sbx/ppo/policies.py +++ b/sbx/ppo/policies.py @@ -1,5 +1,6 @@ +from collections.abc import Sequence from dataclasses import field -from typing import Any, Callable, Dict, List, Optional, Sequence, Union +from typing import Any, Callable, Optional, Union import flax.linen as nn import gymnasium as gym @@ -101,7 +102,7 @@ def __init__( observation_space: gym.spaces.Space, action_space: gym.spaces.Space, lr_schedule: Schedule, - net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None, + net_arch: Optional[Union[list[int], dict[str, list[int]]]] = None, ortho_init: bool = False, log_std_init: float = 0.0, activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.tanh, @@ -111,10 +112,10 @@ def __init__( use_expln: bool = False, clip_mean: float = 2.0, features_extractor_class=None, - features_extractor_kwargs: Optional[Dict[str, Any]] = None, + features_extractor_kwargs: Optional[dict[str, Any]] = None, normalize_images: bool = True, optimizer_class: Callable[..., optax.GradientTransformation] = optax.adam, - optimizer_kwargs: Optional[Dict[str, Any]] = None, + optimizer_kwargs: Optional[dict[str, Any]] = None, share_features_extractor: bool = False, ): if optimizer_kwargs is None: diff --git a/sbx/ppo/ppo.py b/sbx/ppo/ppo.py index 3fccee9..5ef219e 100644 --- a/sbx/ppo/ppo.py +++ b/sbx/ppo/ppo.py @@ -1,6 +1,6 @@ import warnings from functools import partial -from typing import Any, ClassVar, Dict, Optional, Type, TypeVar, Union +from typing import Any, ClassVar, Optional, TypeVar, Union import jax import jax.numpy as jnp @@ -68,7 +68,7 @@ class PPO(OnPolicyAlgorithmJax): :param _init_setup_model: Whether or not to build the network at the creation of the instance """ - policy_aliases: ClassVar[Dict[str, Type[PPOPolicy]]] = { # type: ignore[assignment] + policy_aliases: ClassVar[dict[str, type[PPOPolicy]]] = { # type: ignore[assignment] "MlpPolicy": PPOPolicy, # "CnnPolicy": ActorCriticCnnPolicy, # "MultiInputPolicy": MultiInputActorCriticPolicy, @@ -77,7 +77,7 @@ class PPO(OnPolicyAlgorithmJax): def __init__( self, - policy: Union[str, Type[PPOPolicy]], + policy: Union[str, type[PPOPolicy]], env: Union[GymEnv, str], learning_rate: Union[float, Schedule] = 3e-4, n_steps: int = 2048, @@ -95,7 +95,7 @@ def __init__( sde_sample_freq: int = -1, target_kl: Optional[float] = None, tensorboard_log: Optional[str] = None, - policy_kwargs: Optional[Dict[str, Any]] = None, + policy_kwargs: 
Optional[dict[str, Any]] = None, verbose: int = 0, seed: Optional[int] = None, device: str = "auto", diff --git a/sbx/sac/policies.py b/sbx/sac/policies.py index f936f91..95c319f 100644 --- a/sbx/sac/policies.py +++ b/sbx/sac/policies.py @@ -1,47 +1,23 @@ -from typing import Any, Callable, Dict, List, Optional, Sequence, Union +from typing import Any, Callable, Optional, Union import flax.linen as nn import jax import jax.numpy as jnp import numpy as np import optax -import tensorflow_probability.substrates.jax as tfp from flax.training.train_state import TrainState from gymnasium import spaces from stable_baselines3.common.type_aliases import Schedule -from sbx.common.distributions import TanhTransformedDistribution -from sbx.common.policies import BaseJaxPolicy, Flatten, VectorCritic +from sbx.common.policies import ( + BaseJaxPolicy, + SimbaSquashedGaussianActor, + SimbaVectorCritic, + SquashedGaussianActor, + VectorCritic, +) from sbx.common.type_aliases import RLTrainState -tfd = tfp.distributions - - -class Actor(nn.Module): - net_arch: Sequence[int] - action_dim: int - log_std_min: float = -20 - log_std_max: float = 2 - activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu - - def get_std(self): - # Make it work with gSDE - return jnp.array(0.0) - - @nn.compact - def __call__(self, x: jnp.ndarray) -> tfd.Distribution: # type: ignore[name-defined] - x = Flatten()(x) - for n_units in self.net_arch: - x = nn.Dense(n_units)(x) - x = self.activation_fn(x) - mean = nn.Dense(self.action_dim)(x) - log_std = nn.Dense(self.action_dim)(x) - log_std = jnp.clip(log_std, self.log_std_min, self.log_std_max) - dist = TanhTransformedDistribution( - tfd.MultivariateNormalDiag(loc=mean, scale_diag=jnp.exp(log_std)), - ) - return dist - class SACPolicy(BaseJaxPolicy): action_space: spaces.Box # type: ignore[assignment] @@ -51,7 +27,7 @@ def __init__( observation_space: spaces.Space, action_space: spaces.Box, lr_schedule: Schedule, - net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None, + net_arch: Optional[Union[list[int], dict[str, list[int]]]] = None, dropout_rate: float = 0.0, layer_norm: bool = False, activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu, @@ -62,12 +38,14 @@ def __init__( use_expln: bool = False, clip_mean: float = 2.0, features_extractor_class=None, - features_extractor_kwargs: Optional[Dict[str, Any]] = None, + features_extractor_kwargs: Optional[dict[str, Any]] = None, normalize_images: bool = True, optimizer_class: Callable[..., optax.GradientTransformation] = optax.adam, - optimizer_kwargs: Optional[Dict[str, Any]] = None, + optimizer_kwargs: Optional[dict[str, Any]] = None, n_critics: int = 2, share_features_extractor: bool = False, + actor_class: type[nn.Module] = SquashedGaussianActor, + vector_critic_class: type[nn.Module] = VectorCritic, ): super().__init__( observation_space, @@ -91,6 +69,8 @@ def __init__( self.n_critics = n_critics self.use_sde = use_sde self.activation_fn = activation_fn + self.actor_class = actor_class + self.vector_critic_class = vector_critic_class self.key = self.noise_key = jax.random.PRNGKey(0) @@ -107,7 +87,7 @@ def build(self, key: jax.Array, lr_schedule: Schedule, qf_learning_rate: float) obs = jnp.array([self.observation_space.sample()]) action = jnp.array([self.action_space.sample()]) - self.actor = Actor( + self.actor = self.actor_class( action_dim=int(np.prod(self.action_space.shape)), net_arch=self.net_arch_pi, activation_fn=self.activation_fn, @@ -124,7 +104,7 @@ def build(self, key: jax.Array, lr_schedule: 
Schedule, qf_learning_rate: float) ), ) - self.qf = VectorCritic( + self.qf = self.vector_critic_class( dropout_rate=self.dropout_rate, use_layer_norm=self.layer_norm, net_arch=self.net_arch_qf, @@ -174,3 +154,52 @@ def _predict(self, observation: np.ndarray, deterministic: bool = False) -> np.n if not self.use_sde: self.reset_noise() return BaseJaxPolicy.sample_action(self.actor_state, observation, self.noise_key) + + +class SimbaSACPolicy(SACPolicy): + def __init__( + self, + observation_space: spaces.Space, + action_space: spaces.Box, + lr_schedule: Schedule, + net_arch: Optional[Union[list[int], dict[str, list[int]]]] = None, + dropout_rate: float = 0, + layer_norm: bool = False, + activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu, + use_sde: bool = False, + log_std_init: float = -3, + use_expln: bool = False, + clip_mean: float = 2, + features_extractor_class=None, + features_extractor_kwargs: Optional[dict[str, Any]] = None, + normalize_images: bool = True, + # AdamW for simba + optimizer_class: Callable[..., optax.GradientTransformation] = optax.adamw, + optimizer_kwargs: Optional[dict[str, Any]] = None, + n_critics: int = 2, + share_features_extractor: bool = False, + actor_class: type[nn.Module] = SimbaSquashedGaussianActor, + vector_critic_class: type[nn.Module] = SimbaVectorCritic, + ): + super().__init__( + observation_space, + action_space, + lr_schedule, + net_arch, + dropout_rate, + layer_norm, + activation_fn, + use_sde, + log_std_init, + use_expln, + clip_mean, + features_extractor_class, + features_extractor_kwargs, + normalize_images, + optimizer_class, + optimizer_kwargs, + n_critics, + share_features_extractor, + actor_class, + vector_critic_class, + ) diff --git a/sbx/sac/sac.py b/sbx/sac/sac.py index e3795cd..1a18fc3 100644 --- a/sbx/sac/sac.py +++ b/sbx/sac/sac.py @@ -1,5 +1,5 @@ from functools import partial -from typing import Any, ClassVar, Dict, Literal, Optional, Tuple, Type, Union +from typing import Any, ClassVar, Literal, Optional, Union import flax import flax.linen as nn @@ -16,7 +16,7 @@ from sbx.common.off_policy_algorithm import OffPolicyAlgorithmJax from sbx.common.type_aliases import ReplayBufferSamplesNp, RLTrainState -from sbx.sac.policies import SACPolicy +from sbx.sac.policies import SACPolicy, SimbaSACPolicy class EntropyCoef(nn.Module): @@ -40,8 +40,10 @@ def __call__(self) -> float: class SAC(OffPolicyAlgorithmJax): - policy_aliases: ClassVar[Dict[str, Type[SACPolicy]]] = { # type: ignore[assignment] + policy_aliases: ClassVar[dict[str, type[SACPolicy]]] = { # type: ignore[assignment] "MlpPolicy": SACPolicy, + # Residual net, from https://github.com/SonyResearch/simba + "SimbaPolicy": SimbaSACPolicy, # Minimal dict support using flatten() "MultiInputPolicy": SACPolicy, } @@ -60,19 +62,21 @@ def __init__( batch_size: int = 256, tau: float = 0.005, gamma: float = 0.99, - train_freq: Union[int, Tuple[int, str]] = 1, + train_freq: Union[int, tuple[int, str]] = 1, gradient_steps: int = 1, policy_delay: int = 1, action_noise: Optional[ActionNoise] = None, - replay_buffer_class: Optional[Type[ReplayBuffer]] = None, - replay_buffer_kwargs: Optional[Dict[str, Any]] = None, + replay_buffer_class: Optional[type[ReplayBuffer]] = None, + replay_buffer_kwargs: Optional[dict[str, Any]] = None, ent_coef: Union[str, float] = "auto", target_entropy: Union[Literal["auto"], float] = "auto", use_sde: bool = False, sde_sample_freq: int = -1, use_sde_at_warmup: bool = False, + stats_window_size: int = 100, tensorboard_log: Optional[str] = None, - 
policy_kwargs: Optional[Dict[str, Any]] = None, + policy_kwargs: Optional[dict[str, Any]] = None, + param_resets: Optional[list[int]] = None, # List of timesteps after which to reset the params verbose: int = 0, seed: Optional[int] = None, device: str = "auto", @@ -96,7 +100,9 @@ def __init__( use_sde=use_sde, sde_sample_freq=sde_sample_freq, use_sde_at_warmup=use_sde_at_warmup, + stats_window_size=stats_window_size, policy_kwargs=policy_kwargs, + param_resets=param_resets, tensorboard_log=tensorboard_log, verbose=verbose, seed=seed, @@ -189,6 +195,9 @@ def train(self, gradient_steps: int, batch_size: int) -> None: # Sample all at once for efficiency (so we can jit the for loop) data = self.replay_buffer.sample(batch_size * gradient_steps, env=self._vec_normalize_env) + # Maybe reset the parameters/optimizers fully + self._maybe_reset_params() + if isinstance(data.observations, dict): keys = list(self.observation_space.keys()) # type: ignore[attr-defined] obs = np.concatenate([data.observations[key].numpy() for key in keys], axis=1) @@ -291,7 +300,7 @@ def update_actor( ): key, dropout_key, noise_key = jax.random.split(key, 3) - def actor_loss(params: flax.core.FrozenDict) -> Tuple[jax.Array, jax.Array]: + def actor_loss(params: flax.core.FrozenDict) -> tuple[jax.Array, jax.Array]: dist = actor_state.apply_fn(params, observations) actor_actions = dist.sample(seed=noise_key) log_prob = dist.log_prob(actor_actions).reshape(-1, 1) @@ -385,7 +394,7 @@ def _train( }, } - def one_update(i: int, carry: Dict[str, Any]) -> Dict[str, Any]: + def one_update(i: int, carry: dict[str, Any]) -> dict[str, Any]: # Note: this method must be defined inline because # `fori_loop` expect a signature fn(index, carry) -> carry actor_state = carry["actor_state"] diff --git a/sbx/td3/policies.py b/sbx/td3/policies.py index 3592535..459113e 100644 --- a/sbx/td3/policies.py +++ b/sbx/td3/policies.py @@ -1,4 +1,5 @@ -from typing import Any, Callable, Dict, List, Optional, Sequence, Union +from collections.abc import Sequence +from typing import Any, Callable, Optional, Union import flax.linen as nn import jax @@ -34,16 +35,16 @@ def __init__( observation_space: spaces.Space, action_space: spaces.Box, lr_schedule: Schedule, - net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None, + net_arch: Optional[Union[list[int], dict[str, list[int]]]] = None, dropout_rate: float = 0.0, layer_norm: bool = False, activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu, use_sde: bool = False, features_extractor_class=None, - features_extractor_kwargs: Optional[Dict[str, Any]] = None, + features_extractor_kwargs: Optional[dict[str, Any]] = None, normalize_images: bool = True, optimizer_class: Callable[..., optax.GradientTransformation] = optax.adam, - optimizer_kwargs: Optional[Dict[str, Any]] = None, + optimizer_kwargs: Optional[dict[str, Any]] = None, n_critics: int = 2, share_features_extractor: bool = False, ): diff --git a/sbx/td3/td3.py b/sbx/td3/td3.py index 7ece762..304952a 100644 --- a/sbx/td3/td3.py +++ b/sbx/td3/td3.py @@ -1,5 +1,5 @@ from functools import partial -from typing import Any, ClassVar, Dict, Optional, Tuple, Type, Union +from typing import Any, ClassVar, Optional, Union import flax import jax @@ -17,7 +17,7 @@ class TD3(OffPolicyAlgorithmJax): - policy_aliases: ClassVar[Dict[str, Type[TD3Policy]]] = { # type: ignore[assignment] + policy_aliases: ClassVar[dict[str, type[TD3Policy]]] = { # type: ignore[assignment] "MlpPolicy": TD3Policy, # Minimal dict support using flatten() "MultiInputPolicy": 
TD3Policy, @@ -37,16 +37,18 @@ def __init__( batch_size: int = 256, tau: float = 0.005, gamma: float = 0.99, - train_freq: Union[int, Tuple[int, str]] = 1, + train_freq: Union[int, tuple[int, str]] = 1, gradient_steps: int = 1, policy_delay: int = 2, target_policy_noise: float = 0.2, target_noise_clip: float = 0.5, action_noise: Optional[ActionNoise] = None, - replay_buffer_class: Optional[Type[ReplayBuffer]] = None, - replay_buffer_kwargs: Optional[Dict[str, Any]] = None, + replay_buffer_class: Optional[type[ReplayBuffer]] = None, + replay_buffer_kwargs: Optional[dict[str, Any]] = None, tensorboard_log: Optional[str] = None, - policy_kwargs: Optional[Dict[str, Any]] = None, + stats_window_size: int = 100, + policy_kwargs: Optional[dict[str, Any]] = None, + param_resets: Optional[list[int]] = None, # List of timesteps after which to reset the params verbose: int = 0, seed: Optional[int] = None, device: str = "auto", @@ -68,7 +70,9 @@ def __init__( replay_buffer_class=replay_buffer_class, replay_buffer_kwargs=replay_buffer_kwargs, use_sde=False, + stats_window_size=stats_window_size, policy_kwargs=policy_kwargs, + param_resets=param_resets, tensorboard_log=tensorboard_log, verbose=verbose, seed=seed, @@ -124,6 +128,9 @@ def train(self, gradient_steps: int, batch_size: int) -> None: # Sample all at once for efficiency (so we can jit the for loop) data = self.replay_buffer.sample(batch_size * gradient_steps, env=self._vec_normalize_env) + # Maybe reset the parameters/optimizers fully + self._maybe_reset_params() + if isinstance(data.observations, dict): keys = list(self.observation_space.keys()) # type: ignore[attr-defined] obs = np.concatenate([data.observations[key].numpy() for key in keys], axis=1) @@ -243,7 +250,7 @@ def actor_loss(params: flax.core.FrozenDict) -> jax.Array: @staticmethod @jax.jit - def soft_update(tau: float, qf_state: RLTrainState, actor_state: RLTrainState) -> Tuple[RLTrainState, RLTrainState]: + def soft_update(tau: float, qf_state: RLTrainState, actor_state: RLTrainState) -> tuple[RLTrainState, RLTrainState]: qf_state = qf_state.replace(target_params=optax.incremental_update(qf_state.params, qf_state.target_params, tau)) actor_state = actor_state.replace( target_params=optax.incremental_update(actor_state.params, actor_state.target_params, tau) @@ -279,7 +286,7 @@ def _train( }, } - def one_update(i: int, carry: Dict[str, Any]) -> Dict[str, Any]: + def one_update(i: int, carry: dict[str, Any]) -> dict[str, Any]: # Note: this method must be defined inline because # `fori_loop` expect a signature fn(index, carry) -> carry actor_state = carry["actor_state"] diff --git a/sbx/tqc/policies.py b/sbx/tqc/policies.py index d075de5..0c4c03f 100644 --- a/sbx/tqc/policies.py +++ b/sbx/tqc/policies.py @@ -1,69 +1,23 @@ -from typing import Any, Callable, Dict, List, Optional, Sequence, Union +from typing import Any, Callable, Optional, Union import flax.linen as nn import jax import jax.numpy as jnp import numpy as np import optax -import tensorflow_probability.substrates.jax as tfp from flax.training.train_state import TrainState from gymnasium import spaces from stable_baselines3.common.type_aliases import Schedule -from sbx.common.distributions import TanhTransformedDistribution -from sbx.common.policies import BaseJaxPolicy, Flatten +from sbx.common.policies import ( + BaseJaxPolicy, + ContinuousCritic, + SimbaContinuousCritic, + SimbaSquashedGaussianActor, + SquashedGaussianActor, +) from sbx.common.type_aliases import RLTrainState -tfd = tfp.distributions - - -class 
Critic(nn.Module):
-    net_arch: Sequence[int]
-    use_layer_norm: bool = False
-    dropout_rate: Optional[float] = None
-    n_quantiles: int = 25
-    activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu
-
-    @nn.compact
-    def __call__(self, x: jnp.ndarray, a: jnp.ndarray, training: bool = False) -> jnp.ndarray:
-        x = Flatten()(x)
-        x = jnp.concatenate([x, a], -1)
-        for n_units in self.net_arch:
-            x = nn.Dense(n_units)(x)
-            if self.dropout_rate is not None and self.dropout_rate > 0:
-                x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=False)
-            if self.use_layer_norm:
-                x = nn.LayerNorm()(x)
-            x = self.activation_fn(x)
-        x = nn.Dense(self.n_quantiles)(x)
-        return x
-
-
-class Actor(nn.Module):
-    net_arch: Sequence[int]
-    action_dim: int
-    log_std_min: float = -20
-    log_std_max: float = 2
-    activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu
-
-    def get_std(self):
-        # Make it work with gSDE
-        return jnp.array(0.0)
-
-    @nn.compact
-    def __call__(self, x: jnp.ndarray) -> tfd.Distribution:  # type: ignore[name-defined]
-        x = Flatten()(x)
-        for n_units in self.net_arch:
-            x = nn.Dense(n_units)(x)
-            x = self.activation_fn(x)
-        mean = nn.Dense(self.action_dim)(x)
-        log_std = nn.Dense(self.action_dim)(x)
-        log_std = jnp.clip(log_std, self.log_std_min, self.log_std_max)
-        dist = TanhTransformedDistribution(
-            tfd.MultivariateNormalDiag(loc=mean, scale_diag=jnp.exp(log_std)),
-        )
-        return dist
-
-
 class TQCPolicy(BaseJaxPolicy):
     action_space: spaces.Box  # type: ignore[assignment]
 
@@ -73,7 +27,7 @@ def __init__(
         observation_space: spaces.Space,
         action_space: spaces.Box,
         lr_schedule: Schedule,
-        net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
+        net_arch: Optional[Union[list[int], dict[str, list[int]]]] = None,
         dropout_rate: float = 0.0,
         layer_norm: bool = False,
         top_quantiles_to_drop_per_net: int = 2,
@@ -86,12 +40,14 @@ def __init__(
         use_expln: bool = False,
         clip_mean: float = 2.0,
         features_extractor_class=None,
-        features_extractor_kwargs: Optional[Dict[str, Any]] = None,
+        features_extractor_kwargs: Optional[dict[str, Any]] = None,
         normalize_images: bool = True,
         optimizer_class: Callable[..., optax.GradientTransformation] = optax.adam,
-        optimizer_kwargs: Optional[Dict[str, Any]] = None,
+        optimizer_kwargs: Optional[dict[str, Any]] = None,
         n_critics: int = 2,
         share_features_extractor: bool = False,
+        actor_class: type[nn.Module] = SquashedGaussianActor,
+        critic_class: type[nn.Module] = ContinuousCritic,
     ):
         super().__init__(
             observation_space,
@@ -121,6 +77,8 @@ def __init__(
         self.n_target_quantiles = quantiles_total - top_quantiles_to_drop_per_net * self.n_critics
         self.use_sde = use_sde
         self.activation_fn = activation_fn
+        self.actor_class = actor_class
+        self.critic_class = critic_class
 
         self.key = self.noise_key = jax.random.PRNGKey(0)
 
@@ -137,7 +95,7 @@ def build(self, key: jax.Array, lr_schedule: Schedule, qf_learning_rate: float)
 
         action = jnp.array([self.action_space.sample()])
 
-        self.actor = Actor(
+        self.actor = self.actor_class(
             action_dim=int(np.prod(self.action_space.shape)),
             net_arch=self.net_arch_pi,
             activation_fn=self.activation_fn,
@@ -154,11 +112,11 @@ def build(self, key: jax.Array, lr_schedule: Schedule, qf_learning_rate: float)
             ),
         )
 
-        self.qf = Critic(
+        self.qf = self.critic_class(
             dropout_rate=self.dropout_rate,
             use_layer_norm=self.layer_norm,
             net_arch=self.net_arch_qf,
-            n_quantiles=self.n_quantiles,
+            output_dim=self.n_quantiles,
             activation_fn=self.activation_fn,
         )
 
@@ -217,3 +175,55 @@ def _predict(self, observation: np.ndarray, deterministic: bool = False) -> np.ndarray:
         if not self.use_sde:
             self.reset_noise()
         return BaseJaxPolicy.sample_action(self.actor_state, observation, self.noise_key)
+
+
+class SimbaTQCPolicy(TQCPolicy):
+    def __init__(
+        self,
+        observation_space: spaces.Space,
+        action_space: spaces.Box,
+        lr_schedule: Schedule,
+        net_arch: Optional[Union[list[int], dict[str, list[int]]]] = None,
+        dropout_rate: float = 0,
+        layer_norm: bool = False,
+        top_quantiles_to_drop_per_net: int = 2,
+        n_quantiles: int = 25,
+        activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu,
+        use_sde: bool = False,
+        log_std_init: float = -3,
+        use_expln: bool = False,
+        clip_mean: float = 2,
+        features_extractor_class=None,
+        features_extractor_kwargs: Optional[dict[str, Any]] = None,
+        normalize_images: bool = True,
+        optimizer_class: Callable[..., optax.GradientTransformation] = optax.adamw,
+        optimizer_kwargs: Optional[dict[str, Any]] = None,
+        n_critics: int = 2,
+        share_features_extractor: bool = False,
+        actor_class: type[nn.Module] = SimbaSquashedGaussianActor,
+        critic_class: type[nn.Module] = SimbaContinuousCritic,
+    ):
+        super().__init__(
+            observation_space,
+            action_space,
+            lr_schedule,
+            net_arch,
+            dropout_rate,
+            layer_norm,
+            top_quantiles_to_drop_per_net,
+            n_quantiles,
+            activation_fn,
+            use_sde,
+            log_std_init,
+            use_expln,
+            clip_mean,
+            features_extractor_class,
+            features_extractor_kwargs,
+            normalize_images,
+            optimizer_class,
+            optimizer_kwargs,
+            n_critics,
+            share_features_extractor,
+            actor_class,
+            critic_class,
+        )
diff --git a/sbx/tqc/tqc.py b/sbx/tqc/tqc.py
index f723c31..15fa8e3 100644
--- a/sbx/tqc/tqc.py
+++ b/sbx/tqc/tqc.py
@@ -1,5 +1,5 @@
 from functools import partial
-from typing import Any, ClassVar, Dict, Literal, Optional, Tuple, Type, Union
+from typing import Any, ClassVar, Literal, Optional, Union
 
 import flax
 import flax.linen as nn
@@ -16,7 +16,7 @@
 
 from sbx.common.off_policy_algorithm import OffPolicyAlgorithmJax
 from sbx.common.type_aliases import ReplayBufferSamplesNp, RLTrainState
-from sbx.tqc.policies import TQCPolicy
+from sbx.tqc.policies import SimbaTQCPolicy, TQCPolicy
 
 
 class EntropyCoef(nn.Module):
@@ -40,8 +40,9 @@ def __call__(self) -> float:
 
 
 class TQC(OffPolicyAlgorithmJax):
-    policy_aliases: ClassVar[Dict[str, Type[TQCPolicy]]] = {  # type: ignore[assignment]
+    policy_aliases: ClassVar[dict[str, type[TQCPolicy]]] = {  # type: ignore[assignment]
         "MlpPolicy": TQCPolicy,
+        "SimbaPolicy": SimbaTQCPolicy,
         # Minimal dict support using flatten()
         "MultiInputPolicy": TQCPolicy,
     }
@@ -60,20 +61,22 @@ def __init__(
         batch_size: int = 256,
         tau: float = 0.005,
         gamma: float = 0.99,
-        train_freq: Union[int, Tuple[int, str]] = 1,
+        train_freq: Union[int, tuple[int, str]] = 1,
         gradient_steps: int = 1,
         policy_delay: int = 1,
         top_quantiles_to_drop_per_net: int = 2,
         action_noise: Optional[ActionNoise] = None,
-        replay_buffer_class: Optional[Type[ReplayBuffer]] = None,
-        replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
+        replay_buffer_class: Optional[type[ReplayBuffer]] = None,
+        replay_buffer_kwargs: Optional[dict[str, Any]] = None,
         ent_coef: Union[str, float] = "auto",
         target_entropy: Union[Literal["auto"], float] = "auto",
         use_sde: bool = False,
         sde_sample_freq: int = -1,
         use_sde_at_warmup: bool = False,
+        stats_window_size: int = 100,
         tensorboard_log: Optional[str] = None,
-        policy_kwargs: Optional[Dict[str, Any]] = None,
+        policy_kwargs: Optional[dict[str, Any]] = None,
+        param_resets: Optional[list[int]] = None,  # List of timesteps after which to reset the params
         verbose: int = 0,
         seed: Optional[int] = None,
         device: str = "auto",
@@ -97,7 +100,9 @@ def __init__(
             use_sde=use_sde,
             sde_sample_freq=sde_sample_freq,
             use_sde_at_warmup=use_sde_at_warmup,
+            stats_window_size=stats_window_size,
             policy_kwargs=policy_kwargs,
+            param_resets=param_resets,
             tensorboard_log=tensorboard_log,
             verbose=verbose,
             seed=seed,
@@ -194,6 +199,9 @@ def train(self, gradient_steps: int, batch_size: int) -> None:
         # Sample all at once for efficiency (so we can jit the for loop)
         data = self.replay_buffer.sample(batch_size * gradient_steps, env=self._vec_normalize_env)
 
+        # Maybe reset the parameters/optimizers fully
+        self._maybe_reset_params()
+
         if isinstance(data.observations, dict):
             keys = list(self.observation_space.keys())  # type: ignore[attr-defined]
             obs = np.concatenate([data.observations[key].numpy() for key in keys], axis=1)
@@ -267,14 +275,12 @@ def update_critic(
             qf1_state.target_params,
             next_observations,
             next_state_actions,
-            True,
             rngs={"dropout": dropout_key_1},
         )
         qf2_next_quantiles = qf1_state.apply_fn(
             qf2_state.target_params,
             next_observations,
             next_state_actions,
-            True,
             rngs={"dropout": dropout_key_2},
         )
 
@@ -296,7 +302,7 @@ def update_critic(
 
         def huber_quantile_loss(params: flax.core.FrozenDict, dropout_key: jax.Array) -> jax.Array:
             # Compute huber quantile loss
-            current_quantiles = qf1_state.apply_fn(params, observations, actions, True, rngs={"dropout": dropout_key})
+            current_quantiles = qf1_state.apply_fn(params, observations, actions, rngs={"dropout": dropout_key})
             # convert to shape: (batch_size, n_quantiles, 1) for broadcast
             current_quantiles = jnp.expand_dims(current_quantiles, axis=-1)
 
@@ -337,7 +343,7 @@ def update_actor(
     ):
         key, dropout_key_1, dropout_key_2, noise_key = jax.random.split(key, 4)
 
-        def actor_loss(params: flax.core.FrozenDict) -> Tuple[jax.Array, jax.Array]:
+        def actor_loss(params: flax.core.FrozenDict) -> tuple[jax.Array, jax.Array]:
             dist = actor_state.apply_fn(params, observations)
             actor_actions = dist.sample(seed=noise_key)
             log_prob = dist.log_prob(actor_actions).reshape(-1, 1)
@@ -346,14 +352,12 @@ def actor_loss(params: flax.core.FrozenDict) -> Tuple[jax.Array, jax.Array]:
                 qf1_state.params,
                 observations,
                 actor_actions,
-                True,
                 rngs={"dropout": dropout_key_1},
             )
             qf2_pi = qf1_state.apply_fn(
                 qf2_state.params,
                 observations,
                 actor_actions,
-                True,
                 rngs={"dropout": dropout_key_2},
             )
             qf1_pi = jnp.expand_dims(qf1_pi, axis=-1)
@@ -374,7 +378,7 @@ def actor_loss(params: flax.core.FrozenDict) -> Tuple[jax.Array, jax.Array]:
 
     @staticmethod
     @jax.jit
-    def soft_update(tau: float, qf1_state: RLTrainState, qf2_state: RLTrainState) -> Tuple[RLTrainState, RLTrainState]:
+    def soft_update(tau: float, qf1_state: RLTrainState, qf2_state: RLTrainState) -> tuple[RLTrainState, RLTrainState]:
         qf1_state = qf1_state.replace(target_params=optax.incremental_update(qf1_state.params, qf1_state.target_params, tau))
         qf2_state = qf2_state.replace(target_params=optax.incremental_update(qf2_state.params, qf2_state.target_params, tau))
         return qf1_state, qf2_state
@@ -454,7 +458,7 @@ def _train(
             },
         }
 
-        def one_update(i: int, carry: Dict[str, Any]) -> Dict[str, Any]:
+        def one_update(i: int, carry: dict[str, Any]) -> dict[str, Any]:
             # Note: this method must be defined inline because
             # `fori_loop` expect a signature fn(index, carry) -> carry
             actor_state = carry["actor_state"]
diff --git a/sbx/version.txt b/sbx/version.txt
index 6633391..1cf0537 100644
--- a/sbx/version.txt
+++ b/sbx/version.txt
@@ -1 +1 @@
-0.18.0
+0.19.0
diff --git a/setup.py b/setup.py
index 96c390f..d7752db 100644
--- a/setup.py
+++ b/setup.py
@@ -23,6 +23,7 @@
 - [Twin Delayed DDPG (TD3)](https://arxiv.org/abs/1802.09477)
 - [Deep Deterministic Policy Gradient (DDPG)](https://arxiv.org/abs/1509.02971)
 - [Batch Normalization in Deep Reinforcement Learning (CrossQ)](https://openreview.net/forum?id=PczQtTsTIX)
+- [Simplicity Bias for Scaling Up Parameters in Deep Reinforcement Learning (SimBa)](https://openreview.net/forum?id=jXLiDKsuDo)
 
 ## Example
 
@@ -40,13 +41,11 @@
     packages=[package for package in find_packages() if package.startswith("sbx")],
     package_data={"sbx": ["py.typed", "version.txt"]},
     install_requires=[
-        "stable_baselines3>=2.4.0a4,<3.0",
-        "jax",
+        "stable_baselines3>=2.4.0,<3.0",
+        "jax>=0.4.12",
         "jaxlib",
         "flax",
-        'optax; python_version >= "3.9.0"',
-        # See https://github.com/google-deepmind/optax/issues/711
-        'optax<0.1.8; python_version < "3.9.0"',
+        "optax",
         "tqdm",
         "rich",
         "tensorflow_probability",
@@ -71,18 +70,18 @@
     url="https://github.com/araffin/sbx",
     author_email="antonin.raffin@dlr.de",
     keywords="reinforcement-learning-algorithms reinforcement-learning machine-learning "
-    "gym openai stable baselines toolbox python data-science",
+    "gym gymnasium jax openai stable baselines toolbox python data-science",
     license="MIT",
     long_description=long_description,
     long_description_content_type="text/markdown",
     version=__version__,
-    python_requires=">=3.8",
+    python_requires=">=3.9",
     # PyPI package information.
     classifiers=[
         "Programming Language :: Python :: 3",
-        "Programming Language :: Python :: 3.8",
         "Programming Language :: Python :: 3.9",
         "Programming Language :: Python :: 3.10",
         "Programming Language :: Python :: 3.11",
+        "Programming Language :: Python :: 3.12",
     ],
 )
diff --git a/tests/test_flatten.py b/tests/test_flatten.py
index 1c8b46f..c1307f2 100644
--- a/tests/test_flatten.py
+++ b/tests/test_flatten.py
@@ -1,5 +1,5 @@
 from dataclasses import dataclass
-from typing import Dict, Optional
+from typing import Optional
 
 import gymnasium as gym
 import numpy as np
@@ -17,7 +17,7 @@ class DummyEnv(gym.Env):
     def step(self, action):
         return self.observation_space.sample(), 0.0, False, False, {}
 
-    def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None):
+    def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
         if seed is not None:
             super().reset(seed=seed)
         return self.observation_space.sample(), {}
diff --git a/tests/test_run.py b/tests/test_run.py
index 18d6dec..6d3b5b9 100644
--- a/tests/test_run.py
+++ b/tests/test_run.py
@@ -1,4 +1,4 @@
-from typing import Optional, Type
+from typing import Optional
 
 import flax.linen as nn
 import numpy as np
@@ -71,19 +71,32 @@ def test_tqc(tmp_path) -> None:
         use_sde=True,
         qf_learning_rate=1e-3,
         target_entropy=-10,
+        param_resets=[125, 150],
     )
     model.learn(200)
     check_save_load(model, TQC, tmp_path)
 
 
-@pytest.mark.parametrize("model_class", [SAC, TD3, DDPG, CrossQ])
+@pytest.mark.parametrize("model_class", [SAC, TD3, DDPG, CrossQ, "SimbaSAC", "SimbaCrossQ"])
 def test_sac_td3(tmp_path, model_class) -> None:
+    policy = "MlpPolicy"
+    net_kwargs = {}
+    if model_class == "SimbaSAC":
+        model_class = SAC
+        policy = "SimbaPolicy"
+        net_kwargs = dict(net_arch=[64])
+    elif model_class == "SimbaCrossQ":
+        model_class = CrossQ
+        policy = "SimbaPolicy"
+        net_kwargs = dict(net_arch=[64])
+
     model = model_class(
-        "MlpPolicy",
+        policy,
         "Pendulum-v1",
         verbose=1,
         gradient_steps=1,
         learning_rate=1e-3,
+        policy_kwargs=net_kwargs,
     )
     key_before_learn = model.key
     model.learn(110)
@@ -160,7 +173,7 @@ def test_dqn(tmp_path) -> None:
 @pytest.mark.parametrize("replay_buffer_class", [None, HerReplayBuffer])
-def test_dict(replay_buffer_class: Optional[Type[HerReplayBuffer]]) -> None:
+def test_dict(replay_buffer_class: Optional[type[HerReplayBuffer]]) -> None:
     env = BitFlippingEnv(n_bits=2, continuous=True)
 
     model = SAC("MultiInputPolicy", env, replay_buffer_class=replay_buffer_class)
diff --git a/tests/test_spaces.py b/tests/test_spaces.py
index fe96390..c9c9595 100644
--- a/tests/test_spaces.py
+++ b/tests/test_spaces.py
@@ -1,5 +1,5 @@
 from dataclasses import dataclass
-from typing import Dict, Optional
+from typing import Optional
 
 import gymnasium as gym
 import numpy as np
@@ -19,7 +19,7 @@ def step(self, action):
         assert action in self.action_space
         return self.observation_space.sample(), 0.0, False, False, {}
 
-    def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None):
+    def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
         if seed is not None:
             super().reset(seed=seed)
         return self.observation_space.sample(), {}
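
A minimal usage sketch for the `SimbaPolicy` alias and the `param_resets` option introduced by this patch; the environment, network width, and timestep values below are illustrative choices, not taken from the repository.

```python
from sbx import TQC

# "SimbaPolicy" resolves to SimbaTQCPolicy through TQC.policy_aliases (added above).
model = TQC(
    "SimbaPolicy",
    "Pendulum-v1",  # illustrative environment
    policy_kwargs=dict(net_arch=[64]),  # small residual net, as in the updated tests
    param_resets=[int(1e5)],  # reset parameters/optimizers after 100_000 timesteps
    learning_rate=1e-3,
    verbose=1,
)
model.learn(total_timesteps=5_000)
```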