documentation: api: learner: mappo-masac1
F-coder1857 committed Nov 30, 2023
1 parent 2a4079a commit abb9ef8
Showing 2 changed files with 455 additions and 10 deletions.
293 changes: 283 additions & 10 deletions docs/source/documents/api/learners/marl/mappo.rst
@@ -1,20 +1,68 @@
MAPPO_Learner
=====================================

The learner module of Multi-Agent Proximal Policy Optimization (MAPPO) with the clipped surrogate objective. Paper link: https://arxiv.org/pdf/2103.01955.pdf.

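Per agent, the actor in ``update`` maximizes the PPO clipped surrogate objective:

.. math::

   L^{\mathrm{clip}}(\theta) = \mathbb{E}\left[\min\left(r_t(\theta)\,\hat{A}_t,\ \mathrm{clip}\left(r_t(\theta),\, 1-\epsilon,\, 1+\epsilon\right)\hat{A}_t\right)\right],
   \qquad r_t(\theta) = \frac{\pi_\theta(a_t \mid o_t)}{\pi_{\theta_{\mathrm{old}}}(a_t \mid o_t)},

where :math:`\epsilon` is ``clip_range`` and masked-out agents are excluded from the average. The total loss adds a value loss weighted by ``vf_coef`` and subtracts an entropy bonus weighted by ``ent_coef``.
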
.. raw:: html

<br><hr>

**PyTorch:**

.. py:class:: xuance.torch.learners.multi_agent_rl.mappo_learner.MAPPO_Clip_Learner(config, policy, optimizer, scheduler, device, model_dir, gamma)

:param config: the arguments that contain the hyper-parameters of the learner (e.g., ``clip_range``, ``vf_coef``, ``ent_coef``).
:type config: Namespace
:param policy: the policy to be optimized.
:type policy: nn.Module
:param optimizer: the optimizer that updates the parameters of the policy.
:type optimizer: torch.optim.Optimizer
:param scheduler: the learning-rate scheduler (optional).
:type scheduler: torch.optim.lr_scheduler._LRScheduler
:param device: the calculating device, e.g., CPU or GPU.
:type device: str, int, or torch.device
:param model_dir: the directory in which the model is saved.
:type model_dir: str
:param gamma: the discount factor.
:type gamma: float

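A minimal construction sketch. The ``config`` fields are the ones read in ``__init__`` of the source shown below; the values are illustrative, ``build_policy`` is a hypothetical helper, and the ``LearnerMAS`` base class may expect further fields:

.. code-block:: python

    from types import SimpleNamespace
    import torch

    # Hyper-parameters read by MAPPO_Clip_Learner.__init__ (illustrative values).
    config = SimpleNamespace(
        learning_rate=7e-4, end_factor_lr_decay=0.5, use_linear_lr_decay=True,
        clip_range=0.2, vf_coef=1.0, ent_coef=0.01,
        use_grad_norm=True, max_grad_norm=10.0,
        use_value_clip=True, value_clip_range=0.2,
        use_huber_loss=True, huber_delta=10.0,
        use_value_norm=True, use_global_state=False,
    )
    policy = build_policy(config)  # hypothetical helper, not part of this module
    optimizer = torch.optim.Adam(policy.parameters(), lr=config.learning_rate)
    learner = MAPPO_Clip_Learner(config, policy, optimizer,
                                 device="cuda:0", model_dir="./models/", gamma=0.99)
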
.. py:function:: xuance.torch.learners.multi_agent_rl.mappo_learner.MAPPO_Clip_Learner.lr_decay(i_step)

Linearly decay the learning rate with respect to the training progress (active only when ``use_linear_lr_decay`` is True).

:param i_step: the current training step.
:type i_step: int
:return: None. The learning rate of the optimizer is updated in place.
:rtype: None

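A plausible decay rule, shown as a sketch (an assumption; the actual rule lives in ``xuance.torch.utils.operations.update_linear_decay``): the rate is interpolated from ``learning_rate`` at step 0 down to ``end_factor_lr_decay * learning_rate`` at ``running_steps``.

.. code-block:: python

    def linear_lr(initial_lr, i_step, running_steps, end_factor):
        # Interpolate from initial_lr down to end_factor * initial_lr.
        frac = min(i_step / running_steps, 1.0)
        return initial_lr * (1.0 - frac * (1.0 - end_factor))
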
.. py:function:: xuance.torch.learners.multi_agent_rl.mappo_learner.MAPPO_Clip_Learner.update(sample)

Update the parameters of the policy with a sampled batch of transitions and return the training information.

:param sample: the sampled batch of transitions, containing the observations, global states, actions, value estimates, returns, advantages, old log-probabilities, and agent masks.
:type sample: dict
:return: the information of the training, e.g., the losses, the learning rate, and the mean of the predicted values.
:rtype: dict

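The keys read by ``update`` can be seen in the source below; a sketch of a compatible batch, where the shapes are indicative assumptions for ``batch_size=32``, ``n_agents=3``, ``dim_obs=18``, and a global state of size 48:

.. code-block:: python

    import numpy as np

    batch_size, n_agents, dim_obs = 32, 3, 18
    sample = {
        'state':      np.zeros([batch_size, 48]),                # global state
        'obs':        np.zeros([batch_size, n_agents, dim_obs]),
        'actions':    np.zeros([batch_size, n_agents]),
        'values':     np.zeros([batch_size, n_agents, 1]),
        'returns':    np.zeros([batch_size, n_agents, 1]),
        'advantages': np.zeros([batch_size, n_agents, 1]),
        'log_pi_old': np.zeros([batch_size, n_agents]),
        'agent_mask': np.ones([batch_size, n_agents]),
    }
    info = learner.update(sample)  # dict with "loss", "actor_loss", "critic_loss", ...
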
.. py:function:: xuance.torch.learners.multi_agent_rl.mappo_learner.MAPPO_Clip_Learner.update_recurrent(sample)

Update the parameters of a recurrent (RNN-based) policy with a sampled batch of episodes and return the training information. Compared with ``update``, the sample additionally provides the available actions and a ``filled`` mask for padded timesteps.

:param sample: the sampled batch of episode data.
:type sample: dict
:return: the information of the training, e.g., the losses, the learning rate, and the mean of the predicted values.
:rtype: dict

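Extending the batch sketch above to the episode-major layout (shapes are assumptions inferred from the reshapes in the source below):

.. code-block:: python

    dim_act = 5
    episode_length = 25
    # Per-agent tensors gain a time dimension; obs/avail_actions carry one extra
    # step for bootstrapping, and `filled` marks valid (non-padded) timesteps.
    sample.update({
        'obs':           np.zeros([batch_size, n_agents, episode_length + 1, dim_obs]),
        'actions':       np.zeros([batch_size, n_agents, episode_length]),
        'avail_actions': np.ones([batch_size, n_agents, episode_length + 1, dim_act]),
        'filled':        np.ones([batch_size, episode_length, 1]),
    })
    info = learner.update_recurrent(sample)
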
.. raw:: html

<br><hr>

**TensorFlow:**


.. raw:: html

<br><hr>
@@ -29,21 +77,246 @@ Source Code
-----------------

.. tabs::

.. group-tab:: PyTorch

.. code-block:: python
"""
Multi-Agent Proximal Policy Optimization (MAPPO)
Paper link:
https://arxiv.org/pdf/2103.01955.pdf
Implementation: PyTorch
"""
from xuance.torch.learners import *
from xuance.torch.utils.value_norm import ValueNorm
from xuance.torch.utils.operations import update_linear_decay


class MAPPO_Clip_Learner(LearnerMAS):
def __init__(self,
config: Namespace,
policy: nn.Module,
optimizer: torch.optim.Optimizer,
scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
device: Optional[Union[int, str, torch.device]] = None,
model_dir: str = "./",
gamma: float = 0.99,
):
self.gamma = gamma
self.clip_range = config.clip_range
self.use_linear_lr_decay = config.use_linear_lr_decay
self.use_grad_norm, self.max_grad_norm = config.use_grad_norm, config.max_grad_norm
self.use_value_clip, self.value_clip_range = config.use_value_clip, config.value_clip_range
self.use_huber_loss, self.huber_delta = config.use_huber_loss, config.huber_delta
self.use_value_norm = config.use_value_norm
self.use_global_state = config.use_global_state
self.vf_coef, self.ent_coef = config.vf_coef, config.ent_coef
self.mse_loss = nn.MSELoss()
self.huber_loss = nn.HuberLoss(reduction="none", delta=self.huber_delta)
super(MAPPO_Clip_Learner, self).__init__(config, policy, optimizer, scheduler, device, model_dir)
if self.use_value_norm:
self.value_normalizer = ValueNorm(1).to(device)
else:
self.value_normalizer = None
self.lr = config.learning_rate
self.end_factor_lr_decay = config.end_factor_lr_decay

def lr_decay(self, i_step):
if self.use_linear_lr_decay:
update_linear_decay(self.optimizer, i_step, self.running_steps, self.lr, self.end_factor_lr_decay)

def update(self, sample):
info = {}
self.iterations += 1
state = torch.Tensor(sample['state']).to(self.device)
obs = torch.Tensor(sample['obs']).to(self.device)
actions = torch.Tensor(sample['actions']).to(self.device)
values = torch.Tensor(sample['values']).to(self.device)
returns = torch.Tensor(sample['returns']).to(self.device)
advantages = torch.Tensor(sample['advantages']).to(self.device)
log_pi_old = torch.Tensor(sample['log_pi_old']).to(self.device)
agent_mask = torch.Tensor(sample['agent_mask']).float().reshape(-1, self.n_agents, 1).to(self.device)
batch_size = obs.shape[0]
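# Build one-hot agent IDs, shape (batch_size, n_agents, n_agents); they are appended to each agent's input.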
IDs = torch.eye(self.n_agents).unsqueeze(0).expand(batch_size, -1, -1).to(self.device)

# actor loss
_, pi_dist = self.policy(obs, IDs)
log_pi = pi_dist.log_prob(actions)
ratio = torch.exp(log_pi - log_pi_old).reshape(batch_size, self.n_agents, 1)
advantages_mask = advantages.detach() * agent_mask
surrogate1 = ratio * advantages_mask
surrogate2 = torch.clip(ratio, 1 - self.clip_range, 1 + self.clip_range) * advantages_mask
loss_a = -torch.sum(torch.min(surrogate1, surrogate2), dim=-2, keepdim=True).mean()

# entropy loss
entropy = pi_dist.entropy().reshape(agent_mask.shape) * agent_mask
loss_e = entropy.mean()

# critic loss
critic_in = obs.reshape([batch_size, 1, -1]).expand(-1, self.n_agents, -1)
_, value_pred = self.policy.get_values(critic_in, IDs)
value_target = returns
if self.use_value_clip:
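# PPO-style value clipping: keep the new value prediction within value_clip_range of the old estimate.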
value_clipped = values + (value_pred - values).clamp(-self.value_clip_range, self.value_clip_range)
if self.use_huber_loss:
loss_v = self.huber_loss(value_pred, value_target)
loss_v_clipped = self.huber_loss(value_clipped, value_target)
else:
loss_v = (value_pred - value_target) ** 2
loss_v_clipped = (value_clipped - value_target) ** 2
loss_c = torch.max(loss_v, loss_v_clipped) * agent_mask
loss_c = loss_c.sum() / agent_mask.sum()
else:
if self.use_huber_loss:
loss_v = self.huber_loss(value_pred, value_target) * agent_mask
else:
loss_v = ((value_pred - value_target) ** 2) * agent_mask
loss_c = loss_v.sum() / agent_mask.sum()

loss = loss_a + self.vf_coef * loss_c - self.ent_coef * loss_e
self.optimizer.zero_grad()
loss.backward()
if self.use_grad_norm:
grad_norm = torch.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
info["gradient_norm"] = grad_norm.item()
self.optimizer.step()
if self.scheduler is not None:
self.scheduler.step()

# Logger
lr = self.optimizer.state_dict()['param_groups'][0]['lr']

info.update({
"learning_rate": lr,
"actor_loss": loss_a.item(),
"critic_loss": loss_c.item(),
"entropy": loss_e.item(),
"loss": loss.item(),
"predict_value": value_pred.mean().item()
})

return info

def update_recurrent(self, sample):
info = {}
self.iterations += 1
state = torch.Tensor(sample['state']).to(self.device)
if self.use_global_state:
state = state.unsqueeze(1).expand(-1, self.n_agents, -1, -1)
obs = torch.Tensor(sample['obs']).to(self.device)
actions = torch.Tensor(sample['actions']).to(self.device)
values = torch.Tensor(sample['values']).to(self.device)
returns = torch.Tensor(sample['returns']).to(self.device)
advantages = torch.Tensor(sample['advantages']).to(self.device)
log_pi_old = torch.Tensor(sample['log_pi_old']).to(self.device)
avail_actions = torch.Tensor(sample['avail_actions']).float().to(self.device)
filled = torch.Tensor(sample['filled']).float().to(self.device)
batch_size = obs.shape[0]
episode_length = actions.shape[2]
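# One-hot agent IDs repeated across the episode, shape (batch_size, n_agents, episode_length + 1, n_agents).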
IDs = torch.eye(self.n_agents).unsqueeze(1).unsqueeze(0).expand(batch_size, -1, episode_length + 1, -1).to(
self.device)

# actor loss
rnn_hidden_actor = self.policy.representation.init_hidden(batch_size * self.n_agents)
_, pi_dist = self.policy(obs[:, :, :-1].reshape(-1, episode_length, self.dim_obs),
IDs[:, :, :-1].reshape(-1, episode_length, self.n_agents),
*rnn_hidden_actor,
avail_actions=avail_actions[:, :, :-1].reshape(-1, episode_length, self.dim_act))
log_pi = pi_dist.log_prob(actions.reshape(-1, episode_length)).reshape(batch_size, self.n_agents, episode_length)
ratio = torch.exp(log_pi - log_pi_old).unsqueeze(-1)
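# `filled` marks valid (non-padded) timesteps; broadcast it across the agent dimension.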
filled_n = filled.unsqueeze(1).expand(batch_size, self.n_agents, episode_length, 1)
surrogate1 = ratio * advantages
surrogate2 = torch.clip(ratio, 1 - self.clip_range, 1 + self.clip_range) * advantages
loss_a = -(torch.min(surrogate1, surrogate2) * filled_n).sum() / filled_n.sum()
# entropy loss
entropy = pi_dist.entropy().reshape(batch_size, self.n_agents, episode_length, 1)
entropy = entropy * filled_n
loss_e = entropy.sum() / filled_n.sum()

# critic loss
rnn_hidden_critic = self.policy.representation_critic.init_hidden(batch_size * self.n_agents)
if self.use_global_state:
critic_in_obs = obs[:, :, :-1].transpose(1, 2).reshape(batch_size, episode_length, -1)
critic_in_obs = critic_in_obs.unsqueeze(1).expand(-1, self.n_agents, -1, -1)
critic_in_state = state[:, :, :-1]
critic_in = torch.concat([critic_in_obs, critic_in_state], dim=-1)
_, value_pred = self.policy.get_values(critic_in, IDs[:, :, :-1], *rnn_hidden_critic)
else:
critic_in = obs[:, :, :-1].transpose(1, 2).reshape(batch_size, episode_length, -1)
critic_in = critic_in.unsqueeze(1).expand(-1, self.n_agents, -1, -1)
_, value_pred = self.policy.get_values(critic_in, IDs[:, :, :-1], *rnn_hidden_critic)
value_target = returns.reshape(-1, 1)
values = values.reshape(-1, 1)
value_pred = value_pred.reshape(-1, 1)
filled_all = filled_n.reshape(-1, 1)
if self.use_value_clip:
value_clipped = values + (value_pred - values).clamp(-self.value_clip_range, self.value_clip_range)
if self.use_value_norm:
self.value_normalizer.update(value_target)
value_target = self.value_normalizer.normalize(value_target)
if self.use_huber_loss:
loss_v = self.huber_loss(value_pred, value_target)
loss_v_clipped = self.huber_loss(value_clipped, value_target)
else:
loss_v = (value_pred - value_target) ** 2
loss_v_clipped = (value_clipped - value_target) ** 2
loss_c = torch.max(loss_v, loss_v_clipped) * filled_all
loss_c = loss_c.sum() / filled_all.sum()
else:
if self.use_value_norm:
self.value_normalizer.update(value_target)
value_target = self.value_normalizer.normalize(value_target)  # normalize the target, consistent with the value-clip branch above
if self.use_huber_loss:
loss_v = self.huber_loss(value_pred, value_target)
else:
loss_v = (value_pred - value_target) ** 2
loss_c = (loss_v * filled_all).sum() / filled_all.sum()

loss = loss_a + self.vf_coef * loss_c - self.ent_coef * loss_e
self.optimizer.zero_grad()
loss.backward()
if self.use_grad_norm:
grad_norm = torch.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
info["gradient_norm"] = grad_norm.item()
self.optimizer.step()
if self.scheduler is not None:
self.scheduler.step()

# Logger
lr = self.optimizer.state_dict()['param_groups'][0]['lr']

info.update({
"learning_rate": lr,
"actor_loss": loss_a.item(),
"critic_loss": loss_c.item(),
"entropy": loss_e.item(),
"loss": loss.item(),
"predict_value": value_pred.mean().item()
})

return info

.. group-tab:: TensorFlow

.. code-block:: python

.. group-tab:: MindSpore

.. code-block:: python