Include running_mean and running_var when updating target networks (DLR-RM#1004)

* include `running_mean` and `running_var` when updating target networks in DQN, SAC, TD3.

* Update stable_baselines3/common/utils.py

Co-authored-by: Antonin RAFFIN <[email protected]>

* Precompute batch norm parameters in `_setup_model` and directly copy them in the target update.

* Fix `DictReplayBuffer.next_observations` type (DLR-RM#1013)

* Fix DictReplayBuffer.next_observations type

* Update changelog

Co-authored-by: Antonin RAFFIN <[email protected]>

* Fixed missing verbose parameter passing (DLR-RM#1011)

Co-authored-by: Quentin Gallouédec <[email protected]>

* Support for `device=auto` buffers and set it as default value (DLR-RM#1009)

* Default device is "auto" for buffer + auto device support in BufferBaseClass

* Update docstring

* Update tests

* Unify tests

* Update changelog

* Fix tests on CUDA device

Co-authored-by: Antonin RAFFIN <[email protected]>
Co-authored-by: Antonin Raffin <[email protected]>

* Update test

* Add comments and update tests

* Bump version

* Remove one extra space to conform to the code style.

* Update docstrings

Co-authored-by: Antonin RAFFIN <[email protected]>
Co-authored-by: Quentin Gallouédec <[email protected]>
Co-authored-by: Burak Demirbilek <[email protected]>
Co-authored-by: Antonin Raffin <[email protected]>
5 people authored Aug 23, 2022
1 parent 01cc127 commit 29a481a
Showing 8 changed files with 91 additions and 25 deletions.
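
For context (GH issue #996, which this commit fixes): PyTorch batch norm layers store `running_mean` and `running_var` as buffers, not parameters, so a Polyak update that only iterates over `model.parameters()` never touches them, and the target network's normalization statistics go stale. A minimal sketch of the underlying PyTorch behaviour (illustrative only, not part of the diff):

    import torch as th

    bn = th.nn.BatchNorm1d(3)
    bn(th.randn(8, 3))  # a forward pass in train mode updates the running stats

    # The running stats are buffers, so parameters() never yields them:
    assert "running_mean" not in dict(bn.named_parameters())
    assert "running_mean" in dict(bn.named_buffers())
    # state_dict() contains parameters *and* buffers, which is what the new
    # get_parameters_by_name helper introduced below relies on:
    assert "running_mean" in bn.state_dict()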
5 changes: 3 additions & 2 deletions docs/misc/changelog.rst
@@ -3,7 +3,7 @@
 Changelog
 ==========
 
-Release 1.6.1a1 (WIP)
+Release 1.6.1a2 (WIP)
 ---------------------------
 
 Breaking Changes:
@@ -23,6 +23,7 @@ Bug Fixes:
 - Fixed division by zero error when computing FPS when a small number of time has elapsed in operating systems with low-precision timers.
 - Added multidimensional action space support (@qgallouedec)
 - Fixed missing verbose parameter passing in the ``EvalCallback`` constructor (@burakdmb)
+- Fixed the issue that when updating the target network in DQN, SAC, TD3, the ``running_mean`` and ``running_var`` properties of batch norm layers are not updated (@honglu2875)
 
 Deprecations:
 ^^^^^^^^^^^^^
@@ -1026,4 +1027,4 @@ And all the contributors:
 @eleurent @ac-93 @cove9988 @theDebugger811 @hsuehch @Demetrio92 @thomasgubler @IperGiove @ScheiklP
 @simoninithomas @armandpl @manuel-delverme @Gautam-J @gianlucadecola @buoyancy99 @caburu @xy9485
 @Gregwar @ycheng517 @quantitative-technologies @bcollazo @git-thor @TibiGG @cool-RR @MWeltevrede
-@Melanol @qgallouedec @francescoluciano @jlp-ue @burakdmb @timothe-chaumont
+@Melanol @qgallouedec @francescoluciano @jlp-ue @burakdmb @timothe-chaumont @honglu2875
37 changes: 25 additions & 12 deletions stable_baselines3/common/utils.py
@@ -4,7 +4,7 @@
 import random
 from collections import deque
 from itertools import zip_longest
-from typing import Dict, Iterable, Optional, Tuple, Union
+from typing import Dict, Iterable, List, Optional, Tuple, Union
 
 import gym
 import numpy as np
@@ -67,8 +67,8 @@ def update_learning_rate(optimizer: th.optim.Optimizer, learning_rate: float) -> None:
     Update the learning rate for a given optimizer.
     Useful when doing linear schedule.
 
-    :param optimizer:
-    :param learning_rate:
+    :param optimizer: Pytorch optimizer
+    :param learning_rate: New learning rate value
     """
     for param_group in optimizer.param_groups:
         param_group["lr"] = learning_rate
@@ -79,8 +79,8 @@ def get_schedule_fn(value_schedule: Union[Schedule, float, int]) -> Schedule:
     Transform (if needed) learning rate and clip range (for PPO)
     to callable.
 
-    :param value_schedule:
-    :return:
+    :param value_schedule: Constant value of schedule function
+    :return: Schedule function (can return constant value)
     """
     # If the passed schedule is a float
     # create a constant function
@@ -104,7 +104,7 @@ def get_linear_fn(start: float, end: float, end_fraction: float) -> Schedule:
     :params end_fraction: fraction of ``progress_remaining``
         where end is reached e.g 0.1 then end is reached after 10%
         of the complete training process.
-    :return:
+    :return: Linear schedule function.
     """
 
     def func(progress_remaining: float) -> float:
@@ -121,8 +121,8 @@ def constant_fn(val: float) -> Schedule:
     Create a function that returns a constant
     It is useful for learning rate schedule (to avoid code duplication)
 
-    :param val:
-    :return:
+    :param val: constant value
+    :return: Constant schedule function.
     """
 
     def func(_):
@@ -139,7 +139,7 @@ def get_device(device: Union[th.device, str] = "auto") -> th.device:
     By default, it tries to use the gpu.
 
     :param device: One for 'auto', 'cuda', 'cpu'
-    :return:
+    :return: Supported Pytorch device
     """
     # Cuda by default
     if device == "auto":
@@ -386,12 +386,25 @@ def safe_mean(arr: Union[np.ndarray, list, deque]) -> np.ndarray:
     Compute the mean of an array if there is at least one element.
     For empty array, return NaN. It is used for logging only.
 
-    :param arr:
+    :param arr: Numpy array or list of values
     :return:
     """
     return np.nan if len(arr) == 0 else np.mean(arr)


+def get_parameters_by_name(model: th.nn.Module, included_names: Iterable[str]) -> List[th.Tensor]:
+    """
+    Extract parameters from the state dict of ``model``
+    if the name contains one of the strings in ``included_names``.
+
+    :param model: the model where the parameters come from.
+    :param included_names: substrings of names to include.
+    :return: List of parameters values (Pytorch tensors)
+        that matches the queried names.
+    """
+    return [param for name, param in model.state_dict().items() if any([key in name for key in included_names])]
+
+
 def zip_strict(*iterables: Iterable) -> Iterable:
     r"""
     ``zip()`` function but enforces that iterables are of equal length.
@@ -411,8 +424,8 @@ def zip_strict(*iterables: Iterable) -> Iterable:
 
 
 def polyak_update(
-    params: Iterable[th.nn.Parameter],
-    target_params: Iterable[th.nn.Parameter],
+    params: Iterable[th.Tensor],
+    target_params: Iterable[th.Tensor],
     tau: float,
 ) -> None:
     """
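
A note on the two helpers above (illustrative sketch, not part of the diff): `get_parameters_by_name` reads `model.state_dict()`, which, unlike `named_parameters()`, also contains buffers such as the running stats; the returned tensors share storage with the module, so the in-place writes done by `polyak_update` propagate back to the network. And with `tau=1.0`, the soft update `target <- (1 - tau) * target + tau * source` degenerates to an exact copy, which is how the stats are synchronized in the algorithm changes below.

    import torch as th
    from stable_baselines3.common.utils import get_parameters_by_name, polyak_update

    model = th.nn.Sequential(th.nn.Linear(5, 5), th.nn.BatchNorm1d(5))
    target = th.nn.Sequential(th.nn.Linear(5, 5), th.nn.BatchNorm1d(5))
    model(th.ones(3, 5))  # train-mode forward pass: only model's stats are updated

    stats = get_parameters_by_name(model, ["running_"])  # [running_mean, running_var]
    target_stats = get_parameters_by_name(target, ["running_"])
    assert not th.allclose(stats[0], target_stats[0])  # the two networks disagree

    # tau=1.0 turns the Polyak averaging into a hard copy of the stats:
    polyak_update(stats, target_stats, tau=1.0)
    assert th.allclose(target[1].running_mean, model[1].running_mean)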
7 changes: 6 additions & 1 deletion stable_baselines3/dqn/dqn.py
@@ -11,7 +11,7 @@
 from stable_baselines3.common.policies import BasePolicy
 from stable_baselines3.common.preprocessing import maybe_transpose
 from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
-from stable_baselines3.common.utils import get_linear_fn, is_vectorized_observation, polyak_update
+from stable_baselines3.common.utils import get_linear_fn, get_parameters_by_name, is_vectorized_observation, polyak_update
 from stable_baselines3.dqn.policies import CnnPolicy, DQNPolicy, MlpPolicy, MultiInputPolicy
 
 
@@ -140,6 +140,9 @@ def __init__(
     def _setup_model(self) -> None:
         super()._setup_model()
         self._create_aliases()
+        # Copy running stats, see GH issue #996
+        self.batch_norm_stats = get_parameters_by_name(self.q_net, ["running_"])
+        self.batch_norm_stats_target = get_parameters_by_name(self.q_net_target, ["running_"])
         self.exploration_schedule = get_linear_fn(
             self.exploration_initial_eps,
             self.exploration_final_eps,
@@ -170,6 +173,8 @@ def _on_step(self) -> None:
         self._n_calls += 1
         if self._n_calls % self.target_update_interval == 0:
             polyak_update(self.q_net.parameters(), self.q_net_target.parameters(), self.tau)
+            # Copy running stats, see GH issue #996
+            polyak_update(self.batch_norm_stats, self.batch_norm_stats_target, 1.0)
 
         self.exploration_rate = self.exploration_schedule(self._current_progress_remaining)
         self.logger.record("rollout/exploration_rate", self.exploration_rate)
7 changes: 6 additions & 1 deletion stable_baselines3/sac/sac.py
@@ -10,7 +10,7 @@
 from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
 from stable_baselines3.common.policies import BasePolicy
 from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
-from stable_baselines3.common.utils import polyak_update
+from stable_baselines3.common.utils import get_parameters_by_name, polyak_update
 from stable_baselines3.sac.policies import CnnPolicy, MlpPolicy, MultiInputPolicy, SACPolicy
 
 
@@ -152,6 +152,9 @@ def __init__(
     def _setup_model(self) -> None:
         super()._setup_model()
         self._create_aliases()
+        # Running mean and running var
+        self.batch_norm_stats = get_parameters_by_name(self.critic, ["running_"])
+        self.batch_norm_stats_target = get_parameters_by_name(self.critic_target, ["running_"])
         # Target entropy is used when learning the entropy coefficient
         if self.target_entropy == "auto":
             # automatically set target entropy if needed
@@ -272,6 +275,8 @@ def train(self, gradient_steps: int, batch_size: int = 64) -> None:
             # Update target networks
             if gradient_step % self.target_update_interval == 0:
                 polyak_update(self.critic.parameters(), self.critic_target.parameters(), self.tau)
+                # Copy running stats, see GH issue #996
+                polyak_update(self.batch_norm_stats, self.batch_norm_stats_target, 1.0)
 
         self._n_updates += gradient_steps

10 changes: 9 additions & 1 deletion stable_baselines3/td3/td3.py
@@ -10,7 +10,7 @@
 from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
 from stable_baselines3.common.policies import BasePolicy
 from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
-from stable_baselines3.common.utils import polyak_update
+from stable_baselines3.common.utils import get_parameters_by_name, polyak_update
 from stable_baselines3.td3.policies import CnnPolicy, MlpPolicy, MultiInputPolicy, TD3Policy
 
 
@@ -131,6 +131,11 @@ def __init__(
     def _setup_model(self) -> None:
         super()._setup_model()
         self._create_aliases()
+        # Running mean and running var
+        self.actor_batch_norm_stats = get_parameters_by_name(self.actor, ["running_"])
+        self.critic_batch_norm_stats = get_parameters_by_name(self.critic, ["running_"])
+        self.actor_batch_norm_stats_target = get_parameters_by_name(self.actor_target, ["running_"])
+        self.critic_batch_norm_stats_target = get_parameters_by_name(self.critic_target, ["running_"])
 
     def _create_aliases(self) -> None:
         self.actor = self.policy.actor
@@ -189,6 +194,9 @@ def train(self, gradient_steps: int, batch_size: int = 100) -> None:
 
                 polyak_update(self.critic.parameters(), self.critic_target.parameters(), self.tau)
                 polyak_update(self.actor.parameters(), self.actor_target.parameters(), self.tau)
+                # Copy running stats, see GH issue #996
+                polyak_update(self.critic_batch_norm_stats, self.critic_batch_norm_stats_target, 1.0)
+                polyak_update(self.actor_batch_norm_stats, self.actor_batch_norm_stats_target, 1.0)
 
         self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
         if len(actor_losses) > 0:
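
Worth noting: TD3 tracks the actor's statistics as well as the critic's because, unlike DQN and SAC, it maintains a target actor (used to compute the smoothed target actions), so its batch norm stats need the same hard copy.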
2 changes: 1 addition & 1 deletion stable_baselines3/version.txt
@@ -1 +1 @@
-1.6.1a1
+1.6.1a2
24 changes: 18 additions & 6 deletions tests/test_train_eval_mode.py
@@ -143,7 +143,8 @@ def test_dqn_train_with_batch_norm():
         policy_kwargs=dict(net_arch=[16, 16], features_extractor_class=FlattenBatchNormDropoutExtractor),
         learning_starts=0,
         seed=1,
-        tau=0,  # do not clone the target
+        tau=0.0,  # do not clone the target
+        target_update_interval=100,  # Copy the stats to the target
     )
 
     (
@@ -154,6 +155,9 @@
     ) = clone_dqn_batch_norm_stats(model)
 
     model.learn(total_timesteps=200)
+    # Force stats copy
+    model.target_update_interval = 1
+    model._on_step()
 
     (
         q_net_bias_after,
@@ -165,8 +169,12 @@
     assert ~th.isclose(q_net_bias_before, q_net_bias_after).all()
     assert ~th.isclose(q_net_running_mean_before, q_net_running_mean_after).all()
 
-    assert th.isclose(q_net_bias_before, q_net_target_bias_after).all()
-    assert th.isclose(q_net_target_running_mean_before, q_net_target_running_mean_after).all()
+    # No weight update
+    assert th.isclose(q_net_target_bias_before, q_net_target_bias_after).all()
+    # Running stat should be copied even when tau=0
+    assert th.isclose(q_net_running_mean_before, q_net_target_running_mean_before).all()
+    assert th.isclose(q_net_running_mean_after, q_net_target_running_mean_after).all()


def test_td3_train_with_batch_norm():
@@ -210,10 +218,12 @@ def test_td3_train_with_batch_norm():
     assert ~th.isclose(critic_running_mean_before, critic_running_mean_after).all()
 
     assert th.isclose(actor_target_bias_before, actor_target_bias_after).all()
-    assert th.isclose(actor_target_running_mean_before, actor_target_running_mean_after).all()
+    # Running stat should be copied even when tau=0
+    assert th.isclose(actor_running_mean_after, actor_target_running_mean_after).all()
 
     assert th.isclose(critic_target_bias_before, critic_target_bias_after).all()
-    assert th.isclose(critic_target_running_mean_before, critic_target_running_mean_after).all()
+    # Running stat should be copied even when tau=0
+    assert th.isclose(critic_running_mean_after, critic_target_running_mean_after).all()


def test_sac_train_with_batch_norm():
@@ -250,10 +260,12 @@ def test_sac_train_with_batch_norm():
     assert ~th.isclose(actor_running_mean_before, actor_running_mean_after).all()
 
     assert ~th.isclose(critic_bias_before, critic_bias_after).all()
-    assert ~th.isclose(critic_running_mean_before, critic_running_mean_after).all()
+    # Running stat should be copied even when tau=0
+    assert th.isclose(critic_running_mean_before, critic_target_running_mean_before).all()
 
     assert th.isclose(critic_target_bias_before, critic_target_bias_after).all()
-    assert th.isclose(critic_target_running_mean_before, critic_target_running_mean_after).all()
+    # Running stat should be copied even when tau=0
+    assert th.isclose(critic_running_mean_after, critic_target_running_mean_after).all()


@pytest.mark.parametrize("model_class", [A2C, PPO])
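
To see the fix end to end, a hedged sketch in the spirit of the updated test above (`BatchNormExtractor` is a hypothetical stand-in for the test suite's `FlattenBatchNormDropoutExtractor`). As in the test, one extra `_on_step()` forces a final stats copy, because the last gradient steps update the online stats after the last scheduled target update:

    import gym
    import torch as th

    from stable_baselines3 import DQN
    from stable_baselines3.common.torch_layers import BaseFeaturesExtractor


    class BatchNormExtractor(BaseFeaturesExtractor):
        """Hypothetical feature extractor with a batch norm layer (illustration only)."""

        def __init__(self, observation_space: gym.spaces.Box):
            super().__init__(observation_space, features_dim=observation_space.shape[0])
            self.batch_norm = th.nn.BatchNorm1d(self._features_dim)

        def forward(self, observations: th.Tensor) -> th.Tensor:
            return self.batch_norm(th.flatten(observations, start_dim=1))


    model = DQN(
        "MlpPolicy",
        "CartPole-v1",
        buffer_size=200,
        learning_starts=0,
        seed=1,
        tau=0.0,  # the target weights are never soft-updated...
        target_update_interval=100,
        policy_kwargs=dict(net_arch=[16, 16], features_extractor_class=BatchNormExtractor),
    )
    model.learn(total_timesteps=200)

    # Force one final stats copy, exactly as the test does:
    model.target_update_interval = 1
    model._on_step()

    # ...yet after the fix the running stats were still hard-copied to the target:
    for stat, target_stat in zip(model.batch_norm_stats, model.batch_norm_stats_target):
        assert th.allclose(stat, target_stat)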
24 changes: 23 additions & 1 deletion tests/test_utils.py
@@ -14,7 +14,7 @@
 from stable_baselines3.common.evaluation import evaluate_policy
 from stable_baselines3.common.monitor import Monitor
 from stable_baselines3.common.noise import ActionNoise, OrnsteinUhlenbeckActionNoise, VectorizedActionNoise
-from stable_baselines3.common.utils import get_system_info, is_vectorized_observation, polyak_update, zip_strict
+from stable_baselines3.common.utils import (
+    get_parameters_by_name,
+    get_system_info,
+    is_vectorized_observation,
+    polyak_update,
+    zip_strict,
+)
 from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv
 
 
@@ -322,6 +328,22 @@ def test_vec_noise():
     assert len(vec.noises) == num_envs
 
 
+def test_get_parameters_by_name():
+    model = th.nn.Sequential(th.nn.Linear(5, 5), th.nn.BatchNorm1d(5))
+    # Initialize stats
+    model(th.ones(3, 5))
+    included_names = ["weight", "bias", "running_"]
+    # 2 x weight, 2 x bias, 1 x running_mean, 1 x running_var; Ignore num_batches_tracked.
+    parameters = get_parameters_by_name(model, included_names)
+    assert len(parameters) == 6
+    assert th.allclose(parameters[4], model[1].running_mean)
+    assert th.allclose(parameters[5], model[1].running_var)
+    parameters = get_parameters_by_name(model, ["running_"])
+    assert len(parameters) == 2
+    assert th.allclose(parameters[0], model[1].running_mean)
+    assert th.allclose(parameters[1], model[1].running_var)
+
+
 def test_polyak():
     param1, param2 = th.nn.Parameter(th.ones((5, 5))), th.nn.Parameter(th.zeros((5, 5)))
     target1, target2 = th.nn.Parameter(th.ones((5, 5))), th.nn.Parameter(th.zeros((5, 5)))
