diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index e30e28ffe..dd327435b 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -3,7 +3,7 @@ Changelog ========== -Release 1.6.1a1 (WIP) +Release 1.6.1a2 (WIP) --------------------------- Breaking Changes: @@ -23,6 +23,7 @@ Bug Fixes: - Fixed division by zero error when computing FPS when a small number of time has elapsed in operating systems with low-precision timers. - Added multidimensional action space support (@qgallouedec) - Fixed missing verbose parameter passing in the ``EvalCallback`` constructor (@burakdmb) +- Fixed the issue that when updating the target network in DQN, SAC, TD3, the ``running_mean`` and ``running_var`` properties of batch norm layers are not updated (@honglu2875) Deprecations: ^^^^^^^^^^^^^ @@ -1026,4 +1027,4 @@ And all the contributors: @eleurent @ac-93 @cove9988 @theDebugger811 @hsuehch @Demetrio92 @thomasgubler @IperGiove @ScheiklP @simoninithomas @armandpl @manuel-delverme @Gautam-J @gianlucadecola @buoyancy99 @caburu @xy9485 @Gregwar @ycheng517 @quantitative-technologies @bcollazo @git-thor @TibiGG @cool-RR @MWeltevrede -@Melanol @qgallouedec @francescoluciano @jlp-ue @burakdmb @timothe-chaumont +@Melanol @qgallouedec @francescoluciano @jlp-ue @burakdmb @timothe-chaumont @honglu2875 diff --git a/stable_baselines3/common/utils.py b/stable_baselines3/common/utils.py index 94cd65827..a2126a252 100644 --- a/stable_baselines3/common/utils.py +++ b/stable_baselines3/common/utils.py @@ -4,7 +4,7 @@ import random from collections import deque from itertools import zip_longest -from typing import Dict, Iterable, Optional, Tuple, Union +from typing import Dict, Iterable, List, Optional, Tuple, Union import gym import numpy as np @@ -67,8 +67,8 @@ def update_learning_rate(optimizer: th.optim.Optimizer, learning_rate: float) -> Update the learning rate for a given optimizer. Useful when doing linear schedule. - :param optimizer: - :param learning_rate: + :param optimizer: Pytorch optimizer + :param learning_rate: New learning rate value """ for param_group in optimizer.param_groups: param_group["lr"] = learning_rate @@ -79,8 +79,8 @@ def get_schedule_fn(value_schedule: Union[Schedule, float, int]) -> Schedule: Transform (if needed) learning rate and clip range (for PPO) to callable. - :param value_schedule: - :return: + :param value_schedule: Constant value of schedule function + :return: Schedule function (can return constant value) """ # If the passed schedule is a float # create a constant function @@ -104,7 +104,7 @@ def get_linear_fn(start: float, end: float, end_fraction: float) -> Schedule: :params end_fraction: fraction of ``progress_remaining`` where end is reached e.g 0.1 then end is reached after 10% of the complete training process. - :return: + :return: Linear schedule function. """ def func(progress_remaining: float) -> float: @@ -121,8 +121,8 @@ def constant_fn(val: float) -> Schedule: Create a function that returns a constant It is useful for learning rate schedule (to avoid code duplication) - :param val: - :return: + :param val: constant value + :return: Constant schedule function. """ def func(_): @@ -139,7 +139,7 @@ def get_device(device: Union[th.device, str] = "auto") -> th.device: By default, it tries to use the gpu. :param device: One for 'auto', 'cuda', 'cpu' - :return: + :return: Supported Pytorch device """ # Cuda by default if device == "auto": @@ -386,12 +386,25 @@ def safe_mean(arr: Union[np.ndarray, list, deque]) -> np.ndarray: Compute the mean of an array if there is at least one element. For empty array, return NaN. It is used for logging only. - :param arr: + :param arr: Numpy array or list of values :return: """ return np.nan if len(arr) == 0 else np.mean(arr) +def get_parameters_by_name(model: th.nn.Module, included_names: Iterable[str]) -> List[th.Tensor]: + """ + Extract parameters from the state dict of ``model`` + if the name contains one of the strings in ``included_names``. + + :param model: the model where the parameters come from. + :param included_names: substrings of names to include. + :return: List of parameters values (Pytorch tensors) + that matches the queried names. + """ + return [param for name, param in model.state_dict().items() if any([key in name for key in included_names])] + + def zip_strict(*iterables: Iterable) -> Iterable: r""" ``zip()`` function but enforces that iterables are of equal length. @@ -411,8 +424,8 @@ def zip_strict(*iterables: Iterable) -> Iterable: def polyak_update( - params: Iterable[th.nn.Parameter], - target_params: Iterable[th.nn.Parameter], + params: Iterable[th.Tensor], + target_params: Iterable[th.Tensor], tau: float, ) -> None: """ diff --git a/stable_baselines3/dqn/dqn.py b/stable_baselines3/dqn/dqn.py index 0cd6dfbf7..839fe334d 100644 --- a/stable_baselines3/dqn/dqn.py +++ b/stable_baselines3/dqn/dqn.py @@ -11,7 +11,7 @@ from stable_baselines3.common.policies import BasePolicy from stable_baselines3.common.preprocessing import maybe_transpose from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule -from stable_baselines3.common.utils import get_linear_fn, is_vectorized_observation, polyak_update +from stable_baselines3.common.utils import get_linear_fn, get_parameters_by_name, is_vectorized_observation, polyak_update from stable_baselines3.dqn.policies import CnnPolicy, DQNPolicy, MlpPolicy, MultiInputPolicy @@ -140,6 +140,9 @@ def __init__( def _setup_model(self) -> None: super()._setup_model() self._create_aliases() + # Copy running stats, see GH issue #996 + self.batch_norm_stats = get_parameters_by_name(self.q_net, ["running_"]) + self.batch_norm_stats_target = get_parameters_by_name(self.q_net_target, ["running_"]) self.exploration_schedule = get_linear_fn( self.exploration_initial_eps, self.exploration_final_eps, @@ -170,6 +173,8 @@ def _on_step(self) -> None: self._n_calls += 1 if self._n_calls % self.target_update_interval == 0: polyak_update(self.q_net.parameters(), self.q_net_target.parameters(), self.tau) + # Copy running stats, see GH issue #996 + polyak_update(self.batch_norm_stats, self.batch_norm_stats_target, 1.0) self.exploration_rate = self.exploration_schedule(self._current_progress_remaining) self.logger.record("rollout/exploration_rate", self.exploration_rate) diff --git a/stable_baselines3/sac/sac.py b/stable_baselines3/sac/sac.py index 07f88d9ab..6969ef1d7 100644 --- a/stable_baselines3/sac/sac.py +++ b/stable_baselines3/sac/sac.py @@ -10,7 +10,7 @@ from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm from stable_baselines3.common.policies import BasePolicy from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule -from stable_baselines3.common.utils import polyak_update +from stable_baselines3.common.utils import get_parameters_by_name, polyak_update from stable_baselines3.sac.policies import CnnPolicy, MlpPolicy, MultiInputPolicy, SACPolicy @@ -152,6 +152,9 @@ def __init__( def _setup_model(self) -> None: super()._setup_model() self._create_aliases() + # Running mean and running var + self.batch_norm_stats = get_parameters_by_name(self.critic, ["running_"]) + self.batch_norm_stats_target = get_parameters_by_name(self.critic_target, ["running_"]) # Target entropy is used when learning the entropy coefficient if self.target_entropy == "auto": # automatically set target entropy if needed @@ -272,6 +275,8 @@ def train(self, gradient_steps: int, batch_size: int = 64) -> None: # Update target networks if gradient_step % self.target_update_interval == 0: polyak_update(self.critic.parameters(), self.critic_target.parameters(), self.tau) + # Copy running stats, see GH issue #996 + polyak_update(self.batch_norm_stats, self.batch_norm_stats_target, 1.0) self._n_updates += gradient_steps diff --git a/stable_baselines3/td3/td3.py b/stable_baselines3/td3/td3.py index 34a783d29..f440b73ee 100644 --- a/stable_baselines3/td3/td3.py +++ b/stable_baselines3/td3/td3.py @@ -10,7 +10,7 @@ from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm from stable_baselines3.common.policies import BasePolicy from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule -from stable_baselines3.common.utils import polyak_update +from stable_baselines3.common.utils import get_parameters_by_name, polyak_update from stable_baselines3.td3.policies import CnnPolicy, MlpPolicy, MultiInputPolicy, TD3Policy @@ -131,6 +131,11 @@ def __init__( def _setup_model(self) -> None: super()._setup_model() self._create_aliases() + # Running mean and running var + self.actor_batch_norm_stats = get_parameters_by_name(self.actor, ["running_"]) + self.critic_batch_norm_stats = get_parameters_by_name(self.critic, ["running_"]) + self.actor_batch_norm_stats_target = get_parameters_by_name(self.actor_target, ["running_"]) + self.critic_batch_norm_stats_target = get_parameters_by_name(self.critic_target, ["running_"]) def _create_aliases(self) -> None: self.actor = self.policy.actor @@ -189,6 +194,9 @@ def train(self, gradient_steps: int, batch_size: int = 100) -> None: polyak_update(self.critic.parameters(), self.critic_target.parameters(), self.tau) polyak_update(self.actor.parameters(), self.actor_target.parameters(), self.tau) + # Copy running stats, see GH issue #996 + polyak_update(self.critic_batch_norm_stats, self.critic_batch_norm_stats_target, 1.0) + polyak_update(self.actor_batch_norm_stats, self.actor_batch_norm_stats_target, 1.0) self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard") if len(actor_losses) > 0: diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index e36b72724..51cf83a1c 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -1.6.1a1 +1.6.1a2 diff --git a/tests/test_train_eval_mode.py b/tests/test_train_eval_mode.py index 4f023e96a..a1a63c0c3 100644 --- a/tests/test_train_eval_mode.py +++ b/tests/test_train_eval_mode.py @@ -143,7 +143,8 @@ def test_dqn_train_with_batch_norm(): policy_kwargs=dict(net_arch=[16, 16], features_extractor_class=FlattenBatchNormDropoutExtractor), learning_starts=0, seed=1, - tau=0, # do not clone the target + tau=0.0, # do not clone the target + target_update_interval=100, # Copy the stats to the target ) ( @@ -154,6 +155,9 @@ def test_dqn_train_with_batch_norm(): ) = clone_dqn_batch_norm_stats(model) model.learn(total_timesteps=200) + # Force stats copy + model.target_update_interval = 1 + model._on_step() ( q_net_bias_after, @@ -165,8 +169,12 @@ def test_dqn_train_with_batch_norm(): assert ~th.isclose(q_net_bias_before, q_net_bias_after).all() assert ~th.isclose(q_net_running_mean_before, q_net_running_mean_after).all() + # No weight update + assert th.isclose(q_net_bias_before, q_net_target_bias_after).all() assert th.isclose(q_net_target_bias_before, q_net_target_bias_after).all() - assert th.isclose(q_net_target_running_mean_before, q_net_target_running_mean_after).all() + # Running stat should be copied even when tau=0 + assert th.isclose(q_net_running_mean_before, q_net_target_running_mean_before).all() + assert th.isclose(q_net_running_mean_after, q_net_target_running_mean_after).all() def test_td3_train_with_batch_norm(): @@ -210,10 +218,12 @@ def test_td3_train_with_batch_norm(): assert ~th.isclose(critic_running_mean_before, critic_running_mean_after).all() assert th.isclose(actor_target_bias_before, actor_target_bias_after).all() - assert th.isclose(actor_target_running_mean_before, actor_target_running_mean_after).all() + # Running stat should be copied even when tau=0 + assert th.isclose(actor_running_mean_after, actor_target_running_mean_after).all() assert th.isclose(critic_target_bias_before, critic_target_bias_after).all() - assert th.isclose(critic_target_running_mean_before, critic_target_running_mean_after).all() + # Running stat should be copied even when tau=0 + assert th.isclose(critic_running_mean_after, critic_target_running_mean_after).all() def test_sac_train_with_batch_norm(): @@ -250,10 +260,12 @@ def test_sac_train_with_batch_norm(): assert ~th.isclose(actor_running_mean_before, actor_running_mean_after).all() assert ~th.isclose(critic_bias_before, critic_bias_after).all() - assert ~th.isclose(critic_running_mean_before, critic_running_mean_after).all() + # Running stat should be copied even when tau=0 + assert th.isclose(critic_running_mean_before, critic_target_running_mean_before).all() assert th.isclose(critic_target_bias_before, critic_target_bias_after).all() - assert th.isclose(critic_target_running_mean_before, critic_target_running_mean_after).all() + # Running stat should be copied even when tau=0 + assert th.isclose(critic_running_mean_after, critic_target_running_mean_after).all() @pytest.mark.parametrize("model_class", [A2C, PPO]) diff --git a/tests/test_utils.py b/tests/test_utils.py index 67f2ad1a3..57b4b391a 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -14,7 +14,13 @@ from stable_baselines3.common.evaluation import evaluate_policy from stable_baselines3.common.monitor import Monitor from stable_baselines3.common.noise import ActionNoise, OrnsteinUhlenbeckActionNoise, VectorizedActionNoise -from stable_baselines3.common.utils import get_system_info, is_vectorized_observation, polyak_update, zip_strict +from stable_baselines3.common.utils import ( + get_parameters_by_name, + get_system_info, + is_vectorized_observation, + polyak_update, + zip_strict, +) from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv @@ -322,6 +328,22 @@ def test_vec_noise(): assert len(vec.noises) == num_envs +def test_get_parameters_by_name(): + model = th.nn.Sequential(th.nn.Linear(5, 5), th.nn.BatchNorm1d(5)) + # Initialize stats + model(th.ones(3, 5)) + included_names = ["weight", "bias", "running_"] + # 2 x weight, 2 x bias, 1 x running_mean, 1 x running_var; Ignore num_batches_tracked. + parameters = get_parameters_by_name(model, included_names) + assert len(parameters) == 6 + assert th.allclose(parameters[4], model[1].running_mean) + assert th.allclose(parameters[5], model[1].running_var) + parameters = get_parameters_by_name(model, ["running_"]) + assert len(parameters) == 2 + assert th.allclose(parameters[0], model[1].running_mean) + assert th.allclose(parameters[1], model[1].running_var) + + def test_polyak(): param1, param2 = th.nn.Parameter(th.ones((5, 5))), th.nn.Parameter(th.zeros((5, 5))) target1, target2 = th.nn.Parameter(th.ones((5, 5))), th.nn.Parameter(th.zeros((5, 5)))