Commit

- added DroQ to TQC (from pull request Stable-Baselines-Team#100)
RaikoPipe committed Jan 8, 2025
1 parent e0335d7 commit ffec97e
Showing 2 changed files with 32 additions and 14 deletions.
16 changes: 15 additions & 1 deletion sb3_contrib/tqc/policies.py
@@ -209,6 +209,8 @@ def __init__(
n_quantiles: int = 25,
n_critics: int = 2,
share_features_extractor: bool = False,
dropout_rate: float = 0.0,
layer_norm: bool = False
):
super().__init__(
observation_space,
@@ -226,7 +228,14 @@ def __init__(
self.quantiles_total = n_quantiles * n_critics

for i in range(n_critics):
qf_net_list = create_mlp(features_dim + action_dim, n_quantiles, net_arch, activation_fn)
qf_net_list = create_mlp(
features_dim + action_dim,
n_quantiles,
net_arch,
activation_fn,
dropout_rate=dropout_rate,
layer_norm=layer_norm,
)
qf_net = nn.Sequential(*qf_net_list)
self.add_module(f"qf{i}", qf_net)
self.q_networks.append(qf_net)
@@ -294,6 +303,9 @@ def __init__(
n_quantiles: int = 25,
n_critics: int = 2,
share_features_extractor: bool = False,
# For the critic only
dropout_rate: float = 0.0,
layer_norm: bool = False,
):
super().__init__(
observation_space,
@@ -335,6 +347,8 @@ def __init__(
"n_critics": n_critics,
"net_arch": critic_arch,
"share_features_extractor": share_features_extractor,
"dropout_rate": dropout_rate,
"layer_norm": layer_norm,
}
self.critic_kwargs.update(tqc_kwargs)
self.share_features_extractor = share_features_extractor
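For context, here is a minimal sketch of the kind of critic MLP the new dropout_rate / layer_norm options are presumably meant to build (DroQ-style: Linear -> Dropout -> LayerNorm -> activation per hidden layer). The builder name and the example dimensions below are illustrative, not taken from the repository.

import torch.nn as nn

def build_droq_style_mlp(input_dim, output_dim, net_arch, dropout_rate=0.0, layer_norm=False):
    # Illustrative stand-in for the create_mlp(..., dropout_rate=..., layer_norm=...) call above.
    layers = []
    last_dim = input_dim
    for hidden_dim in net_arch:
        layers.append(nn.Linear(last_dim, hidden_dim))
        if dropout_rate > 0.0:
            layers.append(nn.Dropout(p=dropout_rate))
        if layer_norm:
            layers.append(nn.LayerNorm(hidden_dim))
        layers.append(nn.ReLU())
        last_dim = hidden_dim
    layers.append(nn.Linear(last_dim, output_dim))
    return nn.Sequential(*layers)

# One quantile critic head: (features_dim + action_dim) inputs -> n_quantiles outputs
qf_net = build_droq_style_mlp(256 + 6, 25, [256, 256], dropout_rate=0.01, layer_norm=True)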
30 changes: 17 additions & 13 deletions sb3_contrib/tqc/tqc.py
@@ -94,6 +94,7 @@ def __init__(
replay_buffer_class: Optional[type[ReplayBuffer]] = None,
replay_buffer_kwargs: Optional[dict[str, Any]] = None,
optimize_memory_usage: bool = False,
policy_delay: int = 1,
ent_coef: Union[str, float] = "auto",
target_update_interval: int = 1,
target_entropy: Union[str, float] = "auto",
@@ -145,6 +146,7 @@ def __init__(
self.target_update_interval = target_update_interval
self.ent_coef_optimizer: Optional[th.optim.Adam] = None
self.top_quantiles_to_drop_per_net = top_quantiles_to_drop_per_net
self.policy_delay = policy_delay

if _init_setup_model:
self._setup_model()
@@ -190,7 +192,7 @@ def _create_aliases(self) -> None:
self.critic = self.policy.critic
self.critic_target = self.policy.critic_target

def train(self, gradient_steps: int, batch_size: int = 64) -> None:
def train(self, gradient_steps: int, batch_size: int = 64, train_freq: int = 1) -> None:
# Switch to train mode (this affects batch norm / dropout)
self.policy.set_training_mode(True)
# Update optimizers learning rate
@@ -205,6 +207,8 @@ def train(self, gradient_steps: int, batch_size: int = 64) -> None:
actor_losses, critic_losses = [], []

for gradient_step in range(gradient_steps):
self._n_updates += 1
update_actor = self._n_updates % self.policy_delay == 0
# Sample replay buffer
replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env) # type: ignore[union-attr]

@@ -222,8 +226,9 @@ def train(self, gradient_steps: int, batch_size: int = 64) -> None:
# so we don't change it with other losses
# see https://github.com/rail-berkeley/softlearning/issues/60
ent_coef = th.exp(self.log_ent_coef.detach())
ent_coef_loss = -(self.log_ent_coef * (log_prob + self.target_entropy).detach()).mean()
ent_coef_losses.append(ent_coef_loss.item())
if update_actor:
ent_coef_loss = -(self.log_ent_coef * (log_prob + self.target_entropy).detach()).mean()
ent_coef_losses.append(ent_coef_loss.item())
else:
ent_coef = self.ent_coef_tensor

@@ -265,24 +270,23 @@ def train(self, gradient_steps: int, batch_size: int = 64) -> None:
critic_loss.backward()
self.critic.optimizer.step()

# Compute actor loss
qf_pi = self.critic(replay_data.observations, actions_pi).mean(dim=2).mean(dim=1, keepdim=True)
actor_loss = (ent_coef * log_prob - qf_pi).mean()
actor_losses.append(actor_loss.item())
if update_actor:
qf_pi = self.critic(replay_data.observations, actions_pi).mean(dim=2).mean(dim=1, keepdim=True)
actor_loss = (ent_coef * log_prob - qf_pi).mean()
actor_losses.append(actor_loss.item())

# Optimize the actor
self.actor.optimizer.zero_grad()
actor_loss.backward()
self.actor.optimizer.step()

# Optimize the actor
self.actor.optimizer.zero_grad()
actor_loss.backward()
self.actor.optimizer.step()

# Update target networks
if gradient_step % self.target_update_interval == 0:
polyak_update(self.critic.parameters(), self.critic_target.parameters(), self.tau)
# Copy running stats, see https://github.com/DLR-RM/stable-baselines3/issues/996
polyak_update(self.batch_norm_stats, self.batch_norm_stats_target, 1.0)

self._n_updates += gradient_steps

self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
self.logger.record("train/ent_coef", np.mean(ent_coefs))
self.logger.record("train/actor_loss", np.mean(actor_losses))
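The policy_delay handling above follows the standard delayed policy update pattern: the critics are trained on every gradient step, while the actor and entropy coefficient are only updated every policy_delay steps. A self-contained sketch of just that scheduling, with placeholder update functions standing in for the real losses:

def critic_update():
    print("critic update")

def actor_and_ent_coef_update():
    print("actor + ent_coef update")

gradient_steps = 8
policy_delay = 2  # e.g. update the actor once every 2 critic updates
n_updates = 0

for gradient_step in range(gradient_steps):
    n_updates += 1
    critic_update()                    # critics are updated every step
    if n_updates % policy_delay == 0:  # actor / ent_coef updates are delayed
        actor_and_ent_coef_update()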

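Assuming the new keyword arguments are exposed exactly as shown in the diff (dropout_rate / layer_norm on the policy, policy_delay on TQC), a DroQ-style configuration could look roughly like this; the environment and hyperparameter values are illustrative only, not taken from the commit.

import gymnasium as gym
from sb3_contrib import TQC

env = gym.make("Pendulum-v1")
model = TQC(
    "MlpPolicy",
    env,
    gradient_steps=20,       # many critic updates per environment step, as in DroQ
    policy_delay=20,         # new in this commit: delay actor updates
    policy_kwargs=dict(
        dropout_rate=0.01,   # new in this commit: dropout in the critic MLP
        layer_norm=True,     # new in this commit: layer norm in the critic MLP
    ),
    verbose=1,
)
model.learn(total_timesteps=5_000)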