Skip to content

Commit

Permalink
ppo_clip.rst APIs (#1)
Browse files Browse the repository at this point in the history
  • Loading branch information
Ykizi committed Nov 30, 2023
1 parent 1fa9f2f commit d470ca8
Showing 1 changed file with 112 additions and 5 deletions.
117 changes: 112 additions & 5 deletions docs/source/documents/api/learners/drl/ppo_clip.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,47 @@ PPOCLIP_Learner

**PyTorch:**

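.. py:class::
  xuance.torch.learners.policy_gradient.ppoclip_learner.PPOCLIP_Learner(policy, optimizer, scheduler, device, model_dir, vf_coef, ent_coef, clip_range, clip_grad_norm, use_grad_clip)

  :param policy: The policy network to optimize. Calling it on a batch of observations returns the network outputs, the action distribution and the predicted values.
  :type policy: nn.Module
  :param optimizer: The optimizer that updates the parameters of the policy.
  :type optimizer: torch.optim.Optimizer
  :param scheduler: Optional learning-rate scheduler; when given, it is stepped once after every optimizer step. Default is ``None``.
  :type scheduler: Optional[torch.optim.lr_scheduler._LRScheduler]
  :param device: The device on which the training batches are placed. Default is ``None``.
  :type device: Optional[Union[int, str, torch.device]]
  :param model_dir: The model directory passed on to the base ``Learner``. Default is ``"./"``.
  :type model_dir: str
  :param vf_coef: Weight of the critic (value-function) loss in the total loss. Default is ``0.25``.
  :type vf_coef: float
  :param ent_coef: Weight of the entropy term that is subtracted from the total loss. Default is ``0.005``.
  :type ent_coef: float
  :param clip_range: Clipping range of the probability ratio; the ratio is clamped to ``[1 - clip_range, 1 + clip_range]``. Default is ``0.25``.
  :type clip_range: float
  :param clip_grad_norm: Maximum gradient norm used when gradient clipping is enabled. Default is ``0.25``.
  :type clip_grad_norm: float
  :param use_grad_clip: Whether to clip the gradient norm of the policy parameters before the optimizer step. Default is ``True``.
  :type use_grad_clip: bool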

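.. py:function::
  xuance.torch.learners.policy_gradient.ppoclip_learner.PPOCLIP_Learner.update(obs_batch, act_batch, ret_batch, value_batch, adv_batch, old_logp)

  :param obs_batch: A batch of observations; it is passed directly to the policy.
  :type obs_batch: array-like (e.g., np.ndarray)
  :param act_batch: A batch of actions whose log-probabilities are evaluated under the current policy.
  :type act_batch: array-like (e.g., np.ndarray)
  :param ret_batch: A batch of returns, used as the regression target of the critic loss.
  :type ret_batch: array-like (e.g., np.ndarray)
  :param value_batch: A batch of previously predicted values. It is accepted as part of the interface but does not enter the loss.
  :type value_batch: array-like (e.g., np.ndarray)
  :param adv_batch: A batch of advantages that weight the probability ratio in the clipped surrogate objective.
  :type adv_batch: array-like (e.g., np.ndarray)
  :param old_logp: The log-probabilities of ``act_batch`` recorded before this update; they form the denominator of the probability ratio.
  :type old_logp: array-like (e.g., np.ndarray)
  :return: The training information of this update step, with the keys ``actor-loss``, ``critic-loss``, ``entropy``, ``learning_rate``, ``predict_value`` and ``clip_ratio`` (the fraction of samples whose ratio fell outside the clipping range).
  :rtype: dict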

.. raw:: html

Expand All @@ -28,19 +69,85 @@ Source Code
-----------------

.. tabs::

.. group-tab:: PyTorch

.. code-block:: python3

.. code-block:: python
from xuance.torch.learners import *
class PPOCLIP_Learner(Learner):
    """Learner that performs one PPO clipped-surrogate update per call.

    The total loss minimized in :meth:`update` is::

        actor_loss - ent_coef * entropy + vf_coef * critic_loss

    where the actor loss is the negated clipped surrogate objective, the
    critic loss is the mean squared error between the predicted values and
    the returns, and the entropy is the mean entropy of the action
    distribution produced by the policy.
    """

    def __init__(self,
                 policy: nn.Module,
                 optimizer: torch.optim.Optimizer,
                 scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
                 device: Optional[Union[int, str, torch.device]] = None,
                 model_dir: str = "./",
                 vf_coef: float = 0.25,
                 ent_coef: float = 0.005,
                 clip_range: float = 0.25,
                 clip_grad_norm: float = 0.25,
                 use_grad_clip: bool = True,
                 ):
        """Store the PPO hyper-parameters and initialize the base learner.

        Args:
            policy: Network to optimize. Calling it on observations must
                return ``(outputs, action_distribution, value_prediction)``.
            optimizer: Optimizer that updates the policy parameters.
            scheduler: Optional learning-rate scheduler, stepped once after
                every optimizer step.
            device: Device on which the training batches are placed.
            model_dir: Model directory forwarded to the base ``Learner``.
            vf_coef: Weight of the critic loss in the total loss.
            ent_coef: Weight of the entropy term subtracted from the loss.
            clip_range: The probability ratio is clamped to
                ``[1 - clip_range, 1 + clip_range]``.
            clip_grad_norm: Maximum gradient norm when clipping is enabled.
            use_grad_clip: Whether to clip the gradient norm before the
                optimizer step.
        """
        super(PPOCLIP_Learner, self).__init__(policy, optimizer, scheduler, device, model_dir)
        self.vf_coef = vf_coef
        self.ent_coef = ent_coef
        self.clip_range = clip_range
        self.clip_grad_norm = clip_grad_norm
        self.use_grad_clip = use_grad_clip

    def update(self, obs_batch, act_batch, ret_batch, value_batch, adv_batch, old_logp):
        """Run one gradient step of PPO-clip on a batch and report statistics.

        Args:
            obs_batch: Observations, passed unchanged to the policy.
            act_batch: Actions whose log-probabilities are re-evaluated.
            ret_batch: Returns, the regression target of the critic loss.
            value_batch: Previously predicted values. Kept for interface
                compatibility; it does not enter the loss.
            adv_batch: Advantages weighting the probability ratio.
            old_logp: Log-probabilities of ``act_batch`` recorded before this
                update (denominator of the probability ratio).

        Returns:
            dict: Python scalars under the keys ``"actor-loss"``,
            ``"critic-loss"``, ``"entropy"``, ``"learning_rate"``,
            ``"predict_value"`` and ``"clip_ratio"``.
        """
        # NOTE(review): `iterations` is not defined in this class; presumably
        # it is initialized by the base Learner - confirm.
        self.iterations += 1

        # `value_batch` is intentionally not converted: nothing below reads it.
        act_batch = torch.as_tensor(act_batch, device=self.device)
        ret_batch = torch.as_tensor(ret_batch, device=self.device)
        adv_batch = torch.as_tensor(adv_batch, device=self.device)
        old_logp_batch = torch.as_tensor(old_logp, device=self.device)

        outputs, a_dist, v_pred = self.policy(obs_batch)
        log_prob = a_dist.log_prob(act_batch)

        # ppo-clip core implementations
        # ratio = pi_new(a|s) / pi_old(a|s), computed in log space.
        ratio = (log_prob - old_logp_batch).exp().float()
        surrogate1 = ratio.clamp(1.0 - self.clip_range, 1.0 + self.clip_range) * adv_batch
        surrogate2 = adv_batch * ratio
        # Taking the element-wise minimum makes the objective pessimistic, so
        # moving the ratio outside the clip range never increases it.
        a_loss = -torch.minimum(surrogate1, surrogate2).mean()
        c_loss = F.mse_loss(v_pred, ret_batch)
        e_loss = a_dist.entropy().mean()
        # The entropy is subtracted: a higher entropy lowers the loss.
        loss = a_loss - self.ent_coef * e_loss + self.vf_coef * c_loss

        self.optimizer.zero_grad()
        loss.backward()
        if self.use_grad_clip:
            torch.nn.utils.clip_grad_norm_(self.policy.parameters(), self.clip_grad_norm)
        self.optimizer.step()
        if self.scheduler is not None:
            self.scheduler.step()

        # Logger
        lr = self.optimizer.state_dict()['param_groups'][0]['lr']
        # Fraction of samples whose ratio fell outside the clipping interval.
        cr = ((ratio < 1 - self.clip_range).sum() + (ratio > 1 + self.clip_range).sum()) / ratio.shape[0]

        info = {
            "actor-loss": a_loss.item(),
            "critic-loss": c_loss.item(),
            "entropy": e_loss.item(),
            "learning_rate": lr,
            "predict_value": v_pred.mean().item(),
            # Converted to a Python float like the other entries, so the
            # dict holds no live (possibly GPU) tensor.
            "clip_ratio": cr.item()
        }

        return info
.. group-tab:: TensorFlow

.. code-block:: python3
.. code-block:: python
.. group-tab:: MindSpore

.. code-block:: python3
.. code-block:: python

0 comments on commit d470ca8

Please sign in to comment.