diff --git a/docs/source/documents/api/learners/drl/ppo_clip.rst b/docs/source/documents/api/learners/drl/ppo_clip.rst
index 8a913b5b..58227023 100644
--- a/docs/source/documents/api/learners/drl/ppo_clip.rst
+++ b/docs/source/documents/api/learners/drl/ppo_clip.rst
@@ -7,6 +7,47 @@ PPOCLIP_Learner
 
 **PyTorch:**
 
+.. py:class::
+  xuance.torch.learners.policy_gradient.ppoclip_learner.PPOCLIP_Learner(policy, optimizer, scheduler, device, model_dir, vf_coef, ent_coef, clip_range, clip_grad_norm, use_grad_clip)
+
+  :param policy: the actor-critic policy to be trained.
+  :type policy: nn.Module
+  :param optimizer: the optimizer that updates the parameters of the policy.
+  :type optimizer: torch.optim.Optimizer
+  :param scheduler: the scheduler that adjusts the learning rate during training.
+  :type scheduler: torch.optim.lr_scheduler._LRScheduler
+  :param device: the device on which the calculations are performed, e.g. "cpu" or "cuda:0".
+  :type device: int, str, or torch.device
+  :param model_dir: the directory in which the model is saved.
+  :type model_dir: str
+  :param vf_coef: the coefficient of the value-function loss in the total loss.
+  :type vf_coef: float
+  :param ent_coef: the coefficient of the entropy regularization term in the total loss.
+  :type ent_coef: float
+  :param clip_range: the clipping range of the probability ratio in the PPO-clip surrogate objective.
+  :type clip_range: float
+  :param clip_grad_norm: the maximum norm used when clipping the gradients.
+  :type clip_grad_norm: float
+  :param use_grad_clip: whether to clip the gradients by clip_grad_norm before each optimizer step.
+  :type use_grad_clip: bool
+
+.. py:function::
+  xuance.torch.learners.policy_gradient.ppoclip_learner.PPOCLIP_Learner.update(obs_batch, act_batch, ret_batch, value_batch, adv_batch, old_logp)
+
+  :param obs_batch: a batch of observations.
+  :type obs_batch: np.ndarray
+  :param act_batch: a batch of actions taken under the old policy.
+  :type act_batch: np.ndarray
+  :param ret_batch: a batch of returns, used as targets for the value function.
+  :type ret_batch: np.ndarray
+  :param value_batch: a batch of values predicted by the old policy.
+  :type value_batch: np.ndarray
+  :param adv_batch: a batch of advantage estimates.
+  :type adv_batch: np.ndarray
+  :param old_logp: the log-probabilities of the actions under the old policy.
+  :type old_logp: np.ndarray
+  :return: the training information, including the losses, entropy, learning rate, mean predicted value, and clip ratio.
+  :rtype: dict
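+
+The ``update`` method implements the standard PPO-clip surrogate objective,
+
+.. math::
+
+    L^{CLIP}(\theta) = \mathbb{E}_t\left[\min\left(r_t(\theta)\hat{A}_t,\
+    \mathrm{clip}(r_t(\theta), 1-\epsilon, 1+\epsilon)\hat{A}_t\right)\right],
+    \quad r_t(\theta) = \frac{\pi_\theta(a_t \mid s_t)}{\pi_{\theta_{old}}(a_t \mid s_t)},
+
+where :math:`\epsilon` is ``clip_range``; the total loss additionally subtracts ``ent_coef``
+times the policy entropy and adds ``vf_coef`` times the value-function loss.
+
+A minimal usage sketch is given below. It is not part of the library source: the
+``ToyPolicy`` module and the random batch arrays are hypothetical stand-ins, assuming
+only that the policy returns an ``(outputs, distribution, value)`` tuple as ``update()``
+expects, and that the base ``Learner`` stores the constructor arguments.
+
+.. code-block:: python
+
+    import numpy as np
+    import torch
+    import torch.nn as nn
+    from torch.distributions import Categorical
+    from xuance.torch.learners.policy_gradient.ppoclip_learner import PPOCLIP_Learner
+
+    class ToyPolicy(nn.Module):
+        """Hypothetical actor-critic stand-in: returns (outputs, distribution, value)."""
+        def __init__(self, obs_dim=4, n_actions=2):
+            super().__init__()
+            self.actor = nn.Linear(obs_dim, n_actions)
+            self.critic = nn.Linear(obs_dim, 1)
+
+        def forward(self, obs):
+            obs = torch.as_tensor(obs, dtype=torch.float32)
+            dist = Categorical(logits=self.actor(obs))
+            return None, dist, self.critic(obs).squeeze(-1)
+
+    policy = ToyPolicy()
+    optimizer = torch.optim.Adam(policy.parameters(), lr=1e-4)
+    learner = PPOCLIP_Learner(policy, optimizer, device="cpu", clip_range=0.2)
+
+    # A fake rollout batch of 8 transitions, just to exercise update().
+    obs = np.random.randn(8, 4).astype(np.float32)
+    acts = np.random.randint(0, 2, size=(8,))
+    rets = np.random.randn(8).astype(np.float32)
+    vals = np.random.randn(8).astype(np.float32)
+    advs = np.random.randn(8).astype(np.float32)
+    logp = np.log(np.full(8, 0.5, dtype=np.float32))  # uniform old policy
+
+    info = learner.update(obs, acts, rets, vals, advs, logp)
+    print(info["actor-loss"], info["clip_ratio"])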
 
 .. raw:: html
 
@@ -28,19 +69,85 @@ Source Code
 -----------------
 
 .. tabs::
 
     .. group-tab:: PyTorch
 
-        .. code-block:: python3
+        .. code-block:: python
+
+            from xuance.torch.learners import *
+
+
+            class PPOCLIP_Learner(Learner):
+                def __init__(self,
+                             policy: nn.Module,
+                             optimizer: torch.optim.Optimizer,
+                             scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
+                             device: Optional[Union[int, str, torch.device]] = None,
+                             model_dir: str = "./",
+                             vf_coef: float = 0.25,
+                             ent_coef: float = 0.005,
+                             clip_range: float = 0.25,
+                             clip_grad_norm: float = 0.25,
+                             use_grad_clip: bool = True):
+                    super(PPOCLIP_Learner, self).__init__(policy, optimizer, scheduler, device, model_dir)
+                    self.vf_coef = vf_coef
+                    self.ent_coef = ent_coef
+                    self.clip_range = clip_range
+                    self.clip_grad_norm = clip_grad_norm
+                    self.use_grad_clip = use_grad_clip
+
+                def update(self, obs_batch, act_batch, ret_batch, value_batch, adv_batch, old_logp):
+                    self.iterations += 1
+                    act_batch = torch.as_tensor(act_batch, device=self.device)
+                    ret_batch = torch.as_tensor(ret_batch, device=self.device)
+                    value_batch = torch.as_tensor(value_batch, device=self.device)
+                    adv_batch = torch.as_tensor(adv_batch, device=self.device)
+                    old_logp_batch = torch.as_tensor(old_logp, device=self.device)
+
+                    outputs, a_dist, v_pred = self.policy(obs_batch)
+                    log_prob = a_dist.log_prob(act_batch)
+
+                    # PPO-clip core: clip the probability ratio and take the
+                    # pessimistic minimum of the two surrogate objectives.
+                    ratio = (log_prob - old_logp_batch).exp().float()
+                    surrogate1 = ratio.clamp(1.0 - self.clip_range, 1.0 + self.clip_range) * adv_batch
+                    surrogate2 = adv_batch * ratio
+                    a_loss = -torch.minimum(surrogate1, surrogate2).mean()
+
+                    c_loss = F.mse_loss(v_pred, ret_batch)
+
+                    e_loss = a_dist.entropy().mean()
+                    loss = a_loss - self.ent_coef * e_loss + self.vf_coef * c_loss
+                    self.optimizer.zero_grad()
+                    loss.backward()
+                    if self.use_grad_clip:
+                        torch.nn.utils.clip_grad_norm_(self.policy.parameters(), self.clip_grad_norm)
+                    self.optimizer.step()
+                    if self.scheduler is not None:
+                        self.scheduler.step()
+                    # Logger: record the learning rate and the fraction of
+                    # ratios that fall outside the clipping range.
+                    lr = self.optimizer.state_dict()['param_groups'][0]['lr']
+                    cr = ((ratio < 1 - self.clip_range).sum() + (ratio > 1 + self.clip_range).sum()) / ratio.shape[0]
+
+                    info = {
+                        "actor-loss": a_loss.item(),
+                        "critic-loss": c_loss.item(),
+                        "entropy": e_loss.item(),
+                        "learning_rate": lr,
+                        "predict_value": v_pred.mean().item(),
+                        "clip_ratio": cr
+                    }
+
+                    return info
 
     .. group-tab:: TensorFlow
 
-        .. code-block:: python3
+        .. code-block:: python
+
 
     .. group-tab:: MindSpore
 
-        .. code-block:: python3
\ No newline at end of file
+        .. code-block:: python
\ No newline at end of file