diff --git a/docs/source/documents/api/learners/drl/a2c.rst b/docs/source/documents/api/learners/drl/a2c.rst
index 746dd3e2..4d811e8d 100644
--- a/docs/source/documents/api/learners/drl/a2c.rst
+++ b/docs/source/documents/api/learners/drl/a2c.rst
@@ -41,564 +41,6 @@ A2C_Learner
:return: xxxxxx.
:rtype: xxxxxx
-
-.. py:class::
- xuance.torch.policies.deterministic_marl.BasicQnetwork(action_space, n_agents, representation, hidden_size, normalize, initialize, activation, device)
-
- :param action_space: xxxxxx.
- :type action_space: xxxxxx
- :param n_agents: xxxxxx.
- :type n_agents: xxxxxx
- :param representation: xxxxxx.
- :type representation: xxxxxx
- :param hidden_sizes: xxxxxx.
- :type hidden_sizes: xxxxxx
- :param normalize: xxxxxx.
- :type normalize: xxxxxx
- :param initialize: xxxxxx.
- :type initialize: xxxxxx
- :param activation: xxxxxx.
- :type activation: xxxxxx
- :param device: xxxxxx.
- :type device: xxxxxx
-
-.. py:function::
- xuance.torch.policies.deterministic_marl.BasicQnetwork.forward(observation, agent_ids)
-
- :param observation: xxxxxx.
- :type observation: xxxxxx
- :param agent_ids: xxxxxx.
- :type agent_ids: xxxxxx
- :return: xxxxxx.
- :rtype: xxxxxx
-
-.. py:function::
- xuance.torch.policies.deterministic_marl.BasicQnetwork.target_Q(observation, agent_ids, *rnn_hidden)
-
- :param observation: xxxxxx.
- :type observation: xxxxxx
- :param agent_ids: xxxxxx.
- :type agent_ids: xxxxxx
- :param *rnn_hidden: xxxxxx.
- :type *rnn_hidden: xxxxxx
- :return: xxxxxx.
- :rtype: xxxxxx
-
-.. py:function::
- xuance.torch.policies.deterministic_marl.BasicQnetwork.copy_target()
-
- :return: None.
- :rtype: xxxxxx
-
-.. py:class::
- xuance.torch.policies.deterministic_marl.MFQnetwork(action_space, n_agents, representation, hidden_sizes, normalize, initialize, activation, device)
-
- :param action_space: xxxxxx.
- :type action_space: xxxxxx
- :param n_agents: xxxxxx.
- :type n_agents: xxxxxx
- :param representation: xxxxxx.
- :type representation: xxxxxx
- :param hidden_sizes: xxxxxx.
- :type hidden_sizes: xxxxxx
- :param normalize: xxxxxx.
- :type normalize: xxxxxx
- :param initialize: xxxxxx.
- :type initialize: xxxxxx
- :param activation: xxxxxx.
- :type activation: xxxxxx
- :param device: xxxxxx.
- :type device: xxxxxx
-
-.. py:function::
- xuance.torch.policies.deterministic_marl.MFQnetwork.forward(observation, actions_mean, agent_ids)
-
- :param observation: xxxxxx.
- :type observation: xxxxxx
- :param actions_mean: xxxxxx.
- :type actions_mean: xxxxxx
- :param agent_ids: xxxxxx.
- :type agent_ids: xxxxxx
- :return: xxxxxx.
- :rtype: xxxxxx
-
-.. py:function::
- xuance.torch.policies.deterministic_marl.MFQnetwork.sample_actions(logits)
-
- :param logits: xxxxxx.
- :type logits: xxxxxx
- :return: xxxxxx.
- :rtype: xxxxxx
-
-.. py:function::
- xuance.torch.policies.deterministic_marl.MFQnetwork.target_Q(observation, actions_mean, agent_ids)
-
- :param observation: xxxxxx.
- :type observation: xxxxxx
- :param actions_mean: xxxxxx.
- :type actions_mean: xxxxxx
- :param agent_ids: xxxxxx.
- :type agent_ids: xxxxxx
- :return: xxxxxx.
- :rtype: xxxxxx
-
-.. py:function::
- xuance.torch.policies.deterministic_marl.MFQnetwork.copy_target()
-
- :return: None.
- :rtype: xxxxxx
-
-.. py:class::
- xuance.torch.policies.deterministic_marl.MixingQnetwork(action_space, n_agents, representation, mixer, hidden_size, normalize, initialize, activation, device)
-
- :param action_space: xxxxxx.
- :type action_space: xxxxxx
- :param n_agents: xxxxxx.
- :type n_agents: xxxxxx
- :param representation: xxxxxx.
- :type representation: xxxxxx
- :param mixer: xxxxxx.
- :type mixer: xxxxxx
- :param hidden_size: xxxxxx.
- :type hidden_size: xxxxxx
- :param normalize: xxxxxx.
- :type normalize: xxxxxx
- :param initialize: xxxxxx.
- :type initialize: xxxxxx
- :param activation: xxxxxx.
- :type activation: xxxxxx
- :param device: xxxxxx.
- :type device: xxxxxx
-
-.. py:function::
- xuance.torch.policies.deterministic_marl.MixingQnetwork.forward(observation, agent_ids, *rnn_hidden, avail_actions)
-
- :param observation: xxxxxx.
- :type observation: xxxxxx
- :param agent_ids: xxxxxx.
- :type agent_ids: xxxxxx
- :param *rnn_hidden: xxxxxx.
- :type *rnn_hidden: xxxxxx
- :param avail_actions: xxxxxx.
- :type avail_actions: xxxxxx
- :return: xxxxxx.
- :rtype: xxxxxx
-
-.. py:function::
- xuance.torch.policies.deterministic_marl.MixingQnetwork.target_Q(observation, agent_ids, *rnn_hidden)
-
- :param observation: xxxxxx.
- :type observation: xxxxxx
- :param agent_ids: xxxxxx.
- :type agent_ids: xxxxxx
- :param *rnn_hidden: xxxxxx.
- :type *rnn_hidden: xxxxxx
- :return: xxxxxx.
- :rtype: xxxxxx
-
-.. py:function::
- xuance.torch.policies.deterministic_marl.MixingQnetwork.Q_tot(q, states)
-
- :param q: xxxxxx.
- :type q: xxxxxx
- :param states: xxxxxx.
- :type gstates: xxxxxx
- :return: xxxxxx.
- :rtype: xxxxxx
-
-.. py:function::
- xuance.torch.policies.deterministic_marl.MixingQnetwork.target_Q_tot(q, states)
-
- :param q: xxxxxx.
- :type q: xxxxxx
- :param states: xxxxxx.
- :type gstates: xxxxxx
- :return: xxxxxx.
- :rtype: xxxxxx
-
-.. py:function::
- xuance.torch.policies.deterministic_marl.MixingQnetwork.copy_target()
-
- :return: None.
- :rtype: xxxxxx
-
-.. py:class::
- xuance.torch.policies.deterministic_marl.Weighted_MixingQnetwork(action_space, n_agents, representation, mixer, ff_mixer, hidden_size, normalize, initialize, activation, device)
-
- :param action_space: xxxxxx.
- :type action_space: xxxxxx
- :param n_agents: xxxxxx.
- :type n_agents: xxxxxx
- :param representation: xxxxxx.
- :type representation: xxxxxx
- :param mixer: xxxxxx.
- :type mixer: xxxxxx
- :param ff_mixer: xxxxxx.
- :type ff_mixer: xxxxxx
- :param hidden_size: xxxxxx.
- :type hidden_size: xxxxxx
- :param normalize: xxxxxx.
- :type normalize: xxxxxx
- :param initialize: xxxxxx.
- :type initialize: xxxxxx
- :param activation: xxxxxx.
- :type activation: xxxxxx
- :param device: xxxxxx.
- :type device: xxxxxx
-
-.. py:function::
- xuance.torch.policies.deterministic_marl.Weighted_MixingQnetwork.q_centralized(observation, agent_ids, *rnn_hidden)
-
- :param observation: xxxxxx.
- :type observation: xxxxxx
- :param agent_ids: xxxxxx.
- :type agent_ids: xxxxxx
- :param *rnn_hidden: xxxxxx.
- :type *rnn_hidden: xxxxxx
- :return: xxxxxx.
- :rtype: xxxxxx
-
-.. py:function::
- xuance.torch.policies.deterministic_marl.Weighted_MixingQnetwork.target_q_centralized(observation, agent_ids, *rnn_hidden)
-
- :param observation: xxxxxx.
- :type observation: xxxxxx
- :param agent_ids: xxxxxx.
- :type agent_ids: xxxxxx
- :param *rnn_hidden: xxxxxx.
- :type *rnn_hidden: xxxxxx
- :return: xxxxxx.
- :rtype: xxxxxx
-
-.. py:function::
- xuance.torch.policies.deterministic_marl.Weighted_MixingQnetwork.copy_target()
-
- :return: None.
- :rtype: xxxxxx
-
-.. py:class::
- xuance.torch.policies.deterministic_marl.Qtran_MixingQnetwork(action_space, n_agents, representation, mixer, qtran_mixer, hidden_size, normalize, initialize, activation, device)
-
- :param action_space: xxxxxx.
- :type action_space: xxxxxx
- :param n_agents: xxxxxx.
- :type n_agents: xxxxxx
- :param representation: xxxxxx.
- :type representation: xxxxxx
- :param mixer: xxxxxx.
- :type mixer: xxxxxx
- :param qtran_mixer: xxxxxx.
- :type qtran_mixer: xxxxxx
- :param critic_hidden_size: xxxxxx.
- :type critic_hidden_size: xxxxxx
- :param normalize: xxxxxx.
- :type normalize: xxxxxx
- :param initialize: xxxxxx.
- :type initialize: xxxxxx
- :param activation: xxxxxx.
- :type activation: xxxxxx
- :param device: xxxxxx.
- :type device: xxxxxx
-
-.. py:function::
- xuance.torch.policies.deterministic_marl.Qtran_MixingQnetwork.forward(observation, agent_ids)
-
- :param observation: xxxxxx.
- :type observation: xxxxxx
- :param agent_ids: xxxxxx.
- :type agent_ids: xxxxxx
- :return: xxxxxx.
- :rtype: xxxxxx
-
-.. py:function::
- xuance.torch.policies.deterministic_marl.Qtran_MixingQnetwork.target_Q(observation, agent_ids)
-
- :param observation: xxxxxx.
- :type observation: xxxxxx
- :param agent_ids: xxxxxx.
- :type agent_ids: xxxxxx
- :return: xxxxxx.
- :rtype: xxxxxx
-
-.. py:function::
- xuance.torch.policies.deterministic_marl.Qtran_MixingQnetwork.copy_target()
-
- :return: None.
- :rtype: xxxxxx
-
-.. py:class::
- xuance.torch.policies.deterministic_marl.DCG_policy(action_space, global_state_dim, representation, utility, payoffs, dcgraph, hidden_size_bias, normalize, initialize, activation, device)
-
- :param action_space: xxxxxx.
- :type action_space: xxxxxx
- :param global_state_dim: xxxxxx.
- :type global_state_dim: xxxxxx
- :param representation: xxxxxx.
- :type representation: xxxxxx
- :param utility: xxxxxx.
- :type utility: xxxxxx
- :param payoffs: xxxxxx.
- :type payoffs: xxxxxx
- :param hidden_size_bias: xxxxxx.
- :type hidden_size_bias: xxxxxx
- :param normalize: xxxxxx.
- :type normalize: xxxxxx
- :param initialize: xxxxxx.
- :type initialize: xxxxxx
- :param activation: xxxxxx.
- :type activation: xxxxxx
- :param device: xxxxxx.
- :type device: xxxxxx
-
-.. py:function::
- xuance.torch.policies.deterministic_marl.DCG_policy.forward(observation, agent_ids, *rnn_hidden, avail_actions)
-
- :param observation: xxxxxx.
- :type observation: xxxxxx
- :param agent_ids: xxxxxx.
- :type agent_ids: xxxxxx
- :param *rnn_hidden: xxxxxx.
- :type *rnn_hidden: xxxxxx
- :param avail_actions: xxxxxx.
- :type avail_actions: xxxxxx
- :return: xxxxxx.
- :rtype: xxxxxx
-
-.. py:function::
- xuance.torch.policies.deterministic_marl.DCG_policy.copy_target()
-
- :return: None.
- :rtype: xxxxxx
-
-.. py:class::
- xuance.torch.policies.deterministic_marl.ActorNet(state_dim, n_agents, action_space, hidden_sizes, normalize, initialize, activation, device)
-
- :param state_dim: xxxxxx.
- :type state_dim: xxxxxx
- :param n_agents: xxxxxx.
- :type n_agents: xxxxxx
- :param action_space: xxxxxx.
- :type action_space: xxxxxx
- :param hidden_sizes: xxxxxx.
- :type hidden_sizes: xxxxxx
- :param normalize: xxxxxx.
- :type normalize: xxxxxx
- :param initialize: xxxxxx.
- :type initialize: xxxxxx
- :param activation: xxxxxx.
- :type activation: xxxxxx
- :param device: xxxxxx.
- :type device: xxxxxx
-
-.. py:function::
- xuance.torch.policies.deterministic_marl.ActorNet.forward()
-
- :return: None.
- :rtype: xxxxxx
-
-.. py:class::
- xuance.torch.policies.deterministic_marl.CriticNet(independent, state_dim, n_agents, action_dim, hidden_sizes, normalize, initialize, activation, device)
-
- :param independent: xxxxxx.
- :type independent: xxxxxx
- :param state_dim: xxxxxx.
- :type state_dim: xxxxxx
- :param n_agents: xxxxxx.
- :type n_agents: xxxxxx
- :param action_dim: xxxxxx.
- :type action_dim: xxxxxx
- :param hidden_sizes: xxxxxx.
- :type hidden_sizes: xxxxxx
- :param normalize: xxxxxx.
- :type normalize: xxxxxx
- :param initialize: xxxxxx.
- :type initialize: xxxxxx
- :param activation: xxxxxx.
- :type activation: xxxxxx
- :param device: xxxxxx.
- :type device: xxxxxx
-
-.. py:function::
- xuance.torch.policies.deterministic_marl.ACriticNet.forward()
-
- :return: None.
- :rtype: xxxxxx
-
-
-.. py:class::
- xuance.torch.policies.deterministic_marl.Basic_DDPG_policy(action_space, n_agents, representation, actor_hidden_size, critic_hidden_size, normalize, initialize, activation, device)
-
- :param action_space: xxxxxx.
- :type action_space: xxxxxx
- :param n_agents: xxxxxx.
- :type n_agents: xxxxxx
- :param representation: xxxxxx.
- :type representation: xxxxxx
- :param actor_hidden_size: xxxxxx.
- :type actor_hidden_size: xxxxxx
- :param critic_hidden_size: xxxxxx.
- :type critic_hidden_size: xxxxxx
- :param normalize: xxxxxx.
- :type normalize: xxxxxx
- :param initialize: xxxxxx.
- :type initialize: xxxxxx
- :param activation: xxxxxx.
- :type activation: xxxxxx
- :param device: xxxxxx.
- :type device: xxxxxx
-
-.. py:function::
- xuance.torch.policies.deterministic_marl.Basic_DDPG_policy.forward(observation, agent_ids)
-
- :param observation: xxxxxx.
- :type observation: xxxxxx
- :param agent_ids: xxxxxx.
- :type agent_ids: xxxxxx
- :return: None.
- :rtype: xxxxxx
-
-.. py:function::
- xuance.torch.policies.deterministic_marl.Basic_DDPG_policy.critic(observation, actions, agent_ids)
-
- :param observation: xxxxxx.
- :type observation: xxxxxx
- :param actions: xxxxxx.
- :type actions: xxxxxx
- :param agent_ids: xxxxxx.
- :type agent_ids: xxxxxx
- :return: xxxxxx.
- :rtype: xxxxxx
-
-.. py:function::
- xuance.torch.policies.deterministic_marl.Basic_DDPG_policy.target_critic(observation, actions, agent_ids)
-
- :param observation: xxxxxx.
- :type observation: xxxxxx
- :param actions: xxxxxx.
- :type actions: xxxxxx
- :param agent_ids: xxxxxx.
- :type agent_ids: xxxxxx
- :return: xxxxxx.
- :rtype: xxxxxx
-
-.. py:function::
- xuance.torch.policies.deterministic_marl.Basic_DDPG_policy.soft_update(tau)
-
- :param tau: xxxxxx.
- :type tau: xxxxxx
- :return: xxxxxx.
- :rtype: xxxxxx
-
-.. py:class::
- xuance.torch.policies.deterministic_marl.MADDPG_policy(action_space, n_agents, representation, actor_hidden_size, critic_hidden_size, normalize, initialize, activation, device)
-
- :param action_space: xxxxxx.
- :type action_space: xxxxxx
- :param n_agents: xxxxxx.
- :type n_agents: xxxxxx
- :param representation: xxxxxx.
- :type representation: xxxxxx
- :param actor_hidden_size: xxxxxx.
- :type actor_hidden_size: xxxxxx
- :param critic_hidden_size: xxxxxx.
- :type critic_hidden_size: xxxxxx
- :param normalize: xxxxxx.
- :type normalize: xxxxxx
- :param initialize: xxxxxx.
- :type initialize: xxxxxx
- :param activation: xxxxxx.
- :type activation: xxxxxx
- :param device: xxxxxx.
- :type device: xxxxxx
-
-.. py:function::
- xuance.torch.policies.deterministic_marl.MADDPG_policy.critic(observation, actions, agent_ids)
-
- :param observation: xxxxxx.
- :type observation: xxxxxx
- :param actions: xxxxxx.
- :type actions: xxxxxx
- :param agent_ids: xxxxxx.
- :type agent_ids: xxxxxx
- :return: xxxxxx.
- :rtype: xxxxxx
-
-.. py:function::
- xuance.torch.policies.deterministic_marl.MADDPG_policy.target_critic(observation, actions, agent_ids)
-
- :param observation: xxxxxx.
- :type observation: xxxxxx
- :param actions: xxxxxx.
- :type actions: xxxxxx
- :param agent_ids: xxxxxx.
- :type agent_ids: xxxxxx
- :return: xxxxxx.
- :rtype: xxxxxx
-
-.. py:class::
- xuance.torch.policies.deterministic_marl.MATD3_policy(action_space, n_agents, representation, actor_hidden_size, critic_hidden_size, normalize, initialize, activation, device)
-
- :param action_space: xxxxxx.
- :type action_space: xxxxxx
- :param n_agents: xxxxxx.
- :type n_agents: xxxxxx
- :param representation: xxxxxx.
- :type representation: xxxxxx
- :param actor_hidden_size: xxxxxx.
- :type actor_hidden_size: xxxxxx
- :param critic_hidden_size: xxxxxx.
- :type critic_hidden_size: xxxxxx
- :param normalize: xxxxxx.
- :type normalize: xxxxxx
- :param initialize: xxxxxx.
- :type initialize: xxxxxx
- :param activation: xxxxxx.
- :type activation: xxxxxx
- :param device: xxxxxx.
- :type device: xxxxxx
-
-.. py:function::
- xuance.torch.policies.deterministic_marl.MATD3_policy.Qpolicy(observation, actions, agent_ids)
-
- :param observation: xxxxxx.
- :type observation: xxxxxx
- :param actions: xxxxxx.
- :type actions: xxxxxx
- :param agent_ids: xxxxxx.
- :type agent_ids: xxxxxx
- :return: xxxxxx.
- :rtype: xxxxxx
-
-.. py:function::
- xuance.torch.policies.deterministic_marl.MATD3_policy.Qtarget(observation, actions, agent_ids)
-
- :param observation: xxxxxx.
- :type observation: xxxxxx
- :param actions: xxxxxx.
- :type actions: xxxxxx
- :param agent_ids: xxxxxx.
- :type agent_ids: xxxxxx
- :return: xxxxxx.
- :rtype: xxxxxx
-
-.. py:function::
- xuance.torch.policies.deterministic_marl.MATD3_policy.Qaction(observation, actions, agent_ids)
-
- :param observation: xxxxxx.
- :type observation: xxxxxx
- :param actions: xxxxxx.
- :type actions: xxxxxx
- :param agent_ids: xxxxxx.
- :type agent_ids: xxxxxx
- :return: xxxxxx.
- :rtype: xxxxxx
-
-.. py:function::
- xuance.torch.policies.deterministic_marl.MATD3_policy.soft_update()
-
- :return: None.
- :rtype: xxxxxx
-
.. raw:: html
diff --git a/docs/source/documents/api/learners/drl/ddpg.rst b/docs/source/documents/api/learners/drl/ddpg.rst
index 8a315e89..9dc6e734 100644
--- a/docs/source/documents/api/learners/drl/ddpg.rst
+++ b/docs/source/documents/api/learners/drl/ddpg.rst
@@ -1,7 +1,5 @@
DDPG_Learner
=====================================
-A2C_Learner
-=====================================
.. raw:: html
@@ -10,7 +8,7 @@ A2C_Learner
**PyTorch:**
.. py:class::
- xuance.torch.learners.policy_gradient.a2c_learner.A2C_Learner(policy, optimizer, scheduler, device, model_dir, vf_coef, ent_coef, clip_grad)
+ xuance.torch.learners.policy_gradient.ddpg_learner.DDPG_Learner(policy, optimizer, scheduler, device, model_dir, gamma, tau)
:param policy: xxxxxx.
:type policy: xxxxxx
@@ -22,585 +20,27 @@ A2C_Learner
:type device: xxxxxx
:param model_dir: xxxxxx.
:type model_dir: xxxxxx
- :param vf_coef: xxxxxx.
- :type vf_coef: xxxxxx
- :param ent_coef: xxxxxx.
- :type ent_coef: xxxxxx
- :param clip_grad: xxxxxx.
- :type clip_grad: xxxxxx
+ :param gamma: the discount factor for future rewards.
+ :type gamma: float
+ :param tau: the coefficient for the soft update of the target networks.
+ :type tau: float
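+
+ The ``tau`` coefficient drives the soft (Polyak) update of the target
+ networks after each gradient step. A minimal, self-contained sketch of
+ that conventional rule (the toy modules are illustrative, not part of
+ the learner):
+
+ .. code-block:: python
+
+     import torch
+
+     tau = 0.01
+     online, target = torch.nn.Linear(4, 2), torch.nn.Linear(4, 2)
+     target.load_state_dict(online.state_dict())
+
+     # theta_target <- tau * theta + (1 - tau) * theta_target
+     with torch.no_grad():
+         for p, tp in zip(online.parameters(), target.parameters()):
+             tp.mul_(1 - tau).add_(tau * p)
+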
.. py:function::
- xuance.torch.learners.policy_gradient.a2c_learner.A2C_Learner.update(obs_batch, act_batch, ret_batch, adv_batch)
+ xuance.torch.learners.policy_gradient.ddpg_learner.DDPG_Learner.update(obs_batch, act_batch, rew_batch, next_batch, terminal_batch)
:param obs_batch: xxxxxx.
:type obs_batch: xxxxxx
:param act_batch: xxxxxx.
:type act_batch: xxxxxx
- :param ret_batch: xxxxxx.
- :type ret_batch: xxxxxx
- :param adv_batch: xxxxxx.
- :type adv_batch: xxxxxx
- :return: xxxxxx.
- :rtype: xxxxxx
-
-
-.. py:class::
- xuance.torch.policies.deterministic_marl.BasicQnetwork(action_space, n_agents, representation, hidden_size, normalize, initialize, activation, device)
-
- :param action_space: xxxxxx.
- :type action_space: xxxxxx
- :param n_agents: xxxxxx.
- :type n_agents: xxxxxx
- :param representation: xxxxxx.
- :type representation: xxxxxx
- :param hidden_sizes: xxxxxx.
- :type hidden_sizes: xxxxxx
- :param normalize: xxxxxx.
- :type normalize: xxxxxx
- :param initialize: xxxxxx.
- :type initialize: xxxxxx
- :param activation: xxxxxx.
- :type activation: xxxxxx
- :param device: xxxxxx.
- :type device: xxxxxx
-
-.. py:function::
- xuance.torch.policies.deterministic_marl.BasicQnetwork.forward(observation, agent_ids)
-
- :param observation: xxxxxx.
- :type observation: xxxxxx
- :param agent_ids: xxxxxx.
- :type agent_ids: xxxxxx
- :return: xxxxxx.
- :rtype: xxxxxx
-
-.. py:function::
- xuance.torch.policies.deterministic_marl.BasicQnetwork.target_Q(observation, agent_ids, *rnn_hidden)
-
- :param observation: xxxxxx.
- :type observation: xxxxxx
- :param agent_ids: xxxxxx.
- :type agent_ids: xxxxxx
- :param *rnn_hidden: xxxxxx.
- :type *rnn_hidden: xxxxxx
- :return: xxxxxx.
- :rtype: xxxxxx
-
-.. py:function::
- xuance.torch.policies.deterministic_marl.BasicQnetwork.copy_target()
-
- :return: None.
- :rtype: xxxxxx
-
-.. py:class::
- xuance.torch.policies.deterministic_marl.MFQnetwork(action_space, n_agents, representation, hidden_sizes, normalize, initialize, activation, device)
-
- :param action_space: xxxxxx.
- :type action_space: xxxxxx
- :param n_agents: xxxxxx.
- :type n_agents: xxxxxx
- :param representation: xxxxxx.
- :type representation: xxxxxx
- :param hidden_sizes: xxxxxx.
- :type hidden_sizes: xxxxxx
- :param normalize: xxxxxx.
- :type normalize: xxxxxx
- :param initialize: xxxxxx.
- :type initialize: xxxxxx
- :param activation: xxxxxx.
- :type activation: xxxxxx
- :param device: xxxxxx.
- :type device: xxxxxx
-
-.. py:function::
- xuance.torch.policies.deterministic_marl.MFQnetwork.forward(observation, actions_mean, agent_ids)
-
- :param observation: xxxxxx.
- :type observation: xxxxxx
- :param actions_mean: xxxxxx.
- :type actions_mean: xxxxxx
- :param agent_ids: xxxxxx.
- :type agent_ids: xxxxxx
- :return: xxxxxx.
- :rtype: xxxxxx
-
-.. py:function::
- xuance.torch.policies.deterministic_marl.MFQnetwork.sample_actions(logits)
-
- :param logits: xxxxxx.
- :type logits: xxxxxx
- :return: xxxxxx.
- :rtype: xxxxxx
-
-.. py:function::
- xuance.torch.policies.deterministic_marl.MFQnetwork.target_Q(observation, actions_mean, agent_ids)
-
- :param observation: xxxxxx.
- :type observation: xxxxxx
- :param actions_mean: xxxxxx.
- :type actions_mean: xxxxxx
- :param agent_ids: xxxxxx.
- :type agent_ids: xxxxxx
- :return: xxxxxx.
- :rtype: xxxxxx
-
-.. py:function::
- xuance.torch.policies.deterministic_marl.MFQnetwork.copy_target()
-
- :return: None.
- :rtype: xxxxxx
-
-.. py:class::
- xuance.torch.policies.deterministic_marl.MixingQnetwork(action_space, n_agents, representation, mixer, hidden_size, normalize, initialize, activation, device)
-
- :param action_space: xxxxxx.
- :type action_space: xxxxxx
- :param n_agents: xxxxxx.
- :type n_agents: xxxxxx
- :param representation: xxxxxx.
- :type representation: xxxxxx
- :param mixer: xxxxxx.
- :type mixer: xxxxxx
- :param hidden_size: xxxxxx.
- :type hidden_size: xxxxxx
- :param normalize: xxxxxx.
- :type normalize: xxxxxx
- :param initialize: xxxxxx.
- :type initialize: xxxxxx
- :param activation: xxxxxx.
- :type activation: xxxxxx
- :param device: xxxxxx.
- :type device: xxxxxx
-
-.. py:function::
- xuance.torch.policies.deterministic_marl.MixingQnetwork.forward(observation, agent_ids, *rnn_hidden, avail_actions)
-
- :param observation: xxxxxx.
- :type observation: xxxxxx
- :param agent_ids: xxxxxx.
- :type agent_ids: xxxxxx
- :param *rnn_hidden: xxxxxx.
- :type *rnn_hidden: xxxxxx
- :param avail_actions: xxxxxx.
- :type avail_actions: xxxxxx
- :return: xxxxxx.
- :rtype: xxxxxx
-
-.. py:function::
- xuance.torch.policies.deterministic_marl.MixingQnetwork.target_Q(observation, agent_ids, *rnn_hidden)
-
- :param observation: xxxxxx.
- :type observation: xxxxxx
- :param agent_ids: xxxxxx.
- :type agent_ids: xxxxxx
- :param *rnn_hidden: xxxxxx.
- :type *rnn_hidden: xxxxxx
- :return: xxxxxx.
- :rtype: xxxxxx
-
-.. py:function::
- xuance.torch.policies.deterministic_marl.MixingQnetwork.Q_tot(q, states)
-
- :param q: xxxxxx.
- :type q: xxxxxx
- :param states: xxxxxx.
- :type gstates: xxxxxx
- :return: xxxxxx.
- :rtype: xxxxxx
-
-.. py:function::
- xuance.torch.policies.deterministic_marl.MixingQnetwork.target_Q_tot(q, states)
-
- :param q: xxxxxx.
- :type q: xxxxxx
- :param states: xxxxxx.
- :type gstates: xxxxxx
- :return: xxxxxx.
- :rtype: xxxxxx
-
-.. py:function::
- xuance.torch.policies.deterministic_marl.MixingQnetwork.copy_target()
-
- :return: None.
- :rtype: xxxxxx
-
-.. py:class::
- xuance.torch.policies.deterministic_marl.Weighted_MixingQnetwork(action_space, n_agents, representation, mixer, ff_mixer, hidden_size, normalize, initialize, activation, device)
-
- :param action_space: xxxxxx.
- :type action_space: xxxxxx
- :param n_agents: xxxxxx.
- :type n_agents: xxxxxx
- :param representation: xxxxxx.
- :type representation: xxxxxx
- :param mixer: xxxxxx.
- :type mixer: xxxxxx
- :param ff_mixer: xxxxxx.
- :type ff_mixer: xxxxxx
- :param hidden_size: xxxxxx.
- :type hidden_size: xxxxxx
- :param normalize: xxxxxx.
- :type normalize: xxxxxx
- :param initialize: xxxxxx.
- :type initialize: xxxxxx
- :param activation: xxxxxx.
- :type activation: xxxxxx
- :param device: xxxxxx.
- :type device: xxxxxx
-
-.. py:function::
- xuance.torch.policies.deterministic_marl.Weighted_MixingQnetwork.q_centralized(observation, agent_ids, *rnn_hidden)
-
- :param observation: xxxxxx.
- :type observation: xxxxxx
- :param agent_ids: xxxxxx.
- :type agent_ids: xxxxxx
- :param *rnn_hidden: xxxxxx.
- :type *rnn_hidden: xxxxxx
- :return: xxxxxx.
- :rtype: xxxxxx
-
-.. py:function::
- xuance.torch.policies.deterministic_marl.Weighted_MixingQnetwork.target_q_centralized(observation, agent_ids, *rnn_hidden)
-
- :param observation: xxxxxx.
- :type observation: xxxxxx
- :param agent_ids: xxxxxx.
- :type agent_ids: xxxxxx
- :param *rnn_hidden: xxxxxx.
- :type *rnn_hidden: xxxxxx
- :return: xxxxxx.
- :rtype: xxxxxx
-
-.. py:function::
- xuance.torch.policies.deterministic_marl.Weighted_MixingQnetwork.copy_target()
-
- :return: None.
- :rtype: xxxxxx
-
-.. py:class::
- xuance.torch.policies.deterministic_marl.Qtran_MixingQnetwork(action_space, n_agents, representation, mixer, qtran_mixer, hidden_size, normalize, initialize, activation, device)
-
- :param action_space: xxxxxx.
- :type action_space: xxxxxx
- :param n_agents: xxxxxx.
- :type n_agents: xxxxxx
- :param representation: xxxxxx.
- :type representation: xxxxxx
- :param mixer: xxxxxx.
- :type mixer: xxxxxx
- :param qtran_mixer: xxxxxx.
- :type qtran_mixer: xxxxxx
- :param critic_hidden_size: xxxxxx.
- :type critic_hidden_size: xxxxxx
- :param normalize: xxxxxx.
- :type normalize: xxxxxx
- :param initialize: xxxxxx.
- :type initialize: xxxxxx
- :param activation: xxxxxx.
- :type activation: xxxxxx
- :param device: xxxxxx.
- :type device: xxxxxx
-
-.. py:function::
- xuance.torch.policies.deterministic_marl.Qtran_MixingQnetwork.forward(observation, agent_ids)
-
- :param observation: xxxxxx.
- :type observation: xxxxxx
- :param agent_ids: xxxxxx.
- :type agent_ids: xxxxxx
- :return: xxxxxx.
- :rtype: xxxxxx
-
-.. py:function::
- xuance.torch.policies.deterministic_marl.Qtran_MixingQnetwork.target_Q(observation, agent_ids)
-
- :param observation: xxxxxx.
- :type observation: xxxxxx
- :param agent_ids: xxxxxx.
- :type agent_ids: xxxxxx
- :return: xxxxxx.
- :rtype: xxxxxx
-
-.. py:function::
- xuance.torch.policies.deterministic_marl.Qtran_MixingQnetwork.copy_target()
-
- :return: None.
- :rtype: xxxxxx
-
-.. py:class::
- xuance.torch.policies.deterministic_marl.DCG_policy(action_space, global_state_dim, representation, utility, payoffs, dcgraph, hidden_size_bias, normalize, initialize, activation, device)
-
- :param action_space: xxxxxx.
- :type action_space: xxxxxx
- :param global_state_dim: xxxxxx.
- :type global_state_dim: xxxxxx
- :param representation: xxxxxx.
- :type representation: xxxxxx
- :param utility: xxxxxx.
- :type utility: xxxxxx
- :param payoffs: xxxxxx.
- :type payoffs: xxxxxx
- :param hidden_size_bias: xxxxxx.
- :type hidden_size_bias: xxxxxx
- :param normalize: xxxxxx.
- :type normalize: xxxxxx
- :param initialize: xxxxxx.
- :type initialize: xxxxxx
- :param activation: xxxxxx.
- :type activation: xxxxxx
- :param device: xxxxxx.
- :type device: xxxxxx
-
-.. py:function::
- xuance.torch.policies.deterministic_marl.DCG_policy.forward(observation, agent_ids, *rnn_hidden, avail_actions)
-
- :param observation: xxxxxx.
- :type observation: xxxxxx
- :param agent_ids: xxxxxx.
- :type agent_ids: xxxxxx
- :param *rnn_hidden: xxxxxx.
- :type *rnn_hidden: xxxxxx
- :param avail_actions: xxxxxx.
- :type avail_actions: xxxxxx
- :return: xxxxxx.
- :rtype: xxxxxx
-
-.. py:function::
- xuance.torch.policies.deterministic_marl.DCG_policy.copy_target()
-
- :return: None.
- :rtype: xxxxxx
-
-.. py:class::
- xuance.torch.policies.deterministic_marl.ActorNet(state_dim, n_agents, action_space, hidden_sizes, normalize, initialize, activation, device)
-
- :param state_dim: xxxxxx.
- :type state_dim: xxxxxx
- :param n_agents: xxxxxx.
- :type n_agents: xxxxxx
- :param action_space: xxxxxx.
- :type action_space: xxxxxx
- :param hidden_sizes: xxxxxx.
- :type hidden_sizes: xxxxxx
- :param normalize: xxxxxx.
- :type normalize: xxxxxx
- :param initialize: xxxxxx.
- :type initialize: xxxxxx
- :param activation: xxxxxx.
- :type activation: xxxxxx
- :param device: xxxxxx.
- :type device: xxxxxx
-
-.. py:function::
- xuance.torch.policies.deterministic_marl.ActorNet.forward()
-
- :return: None.
- :rtype: xxxxxx
-
-.. py:class::
- xuance.torch.policies.deterministic_marl.CriticNet(independent, state_dim, n_agents, action_dim, hidden_sizes, normalize, initialize, activation, device)
-
- :param independent: xxxxxx.
- :type independent: xxxxxx
- :param state_dim: xxxxxx.
- :type state_dim: xxxxxx
- :param n_agents: xxxxxx.
- :type n_agents: xxxxxx
- :param action_dim: xxxxxx.
- :type action_dim: xxxxxx
- :param hidden_sizes: xxxxxx.
- :type hidden_sizes: xxxxxx
- :param normalize: xxxxxx.
- :type normalize: xxxxxx
- :param initialize: xxxxxx.
- :type initialize: xxxxxx
- :param activation: xxxxxx.
- :type activation: xxxxxx
- :param device: xxxxxx.
- :type device: xxxxxx
-
-.. py:function::
- xuance.torch.policies.deterministic_marl.ACriticNet.forward()
-
- :return: None.
- :rtype: xxxxxx
-
-
-.. py:class::
- xuance.torch.policies.deterministic_marl.Basic_DDPG_policy(action_space, n_agents, representation, actor_hidden_size, critic_hidden_size, normalize, initialize, activation, device)
-
- :param action_space: xxxxxx.
- :type action_space: xxxxxx
- :param n_agents: xxxxxx.
- :type n_agents: xxxxxx
- :param representation: xxxxxx.
- :type representation: xxxxxx
- :param actor_hidden_size: xxxxxx.
- :type actor_hidden_size: xxxxxx
- :param critic_hidden_size: xxxxxx.
- :type critic_hidden_size: xxxxxx
- :param normalize: xxxxxx.
- :type normalize: xxxxxx
- :param initialize: xxxxxx.
- :type initialize: xxxxxx
- :param activation: xxxxxx.
- :type activation: xxxxxx
- :param device: xxxxxx.
- :type device: xxxxxx
-
-.. py:function::
- xuance.torch.policies.deterministic_marl.Basic_DDPG_policy.forward(observation, agent_ids)
-
- :param observation: xxxxxx.
- :type observation: xxxxxx
- :param agent_ids: xxxxxx.
- :type agent_ids: xxxxxx
- :return: None.
- :rtype: xxxxxx
-
-.. py:function::
- xuance.torch.policies.deterministic_marl.Basic_DDPG_policy.critic(observation, actions, agent_ids)
-
- :param observation: xxxxxx.
- :type observation: xxxxxx
- :param actions: xxxxxx.
- :type actions: xxxxxx
- :param agent_ids: xxxxxx.
- :type agent_ids: xxxxxx
+ :param rew_batch: a batch of rewards sampled from the replay buffer.
+ :type rew_batch: np.ndarray
+ :param next_batch: a batch of next observations sampled from the replay buffer.
+ :type next_batch: np.ndarray
+ :param terminal_batch: a batch of terminal flags sampled from the replay buffer.
+ :type terminal_batch: np.ndarray
:return: xxxxxx.
:rtype: xxxxxx
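+
+ The critic regresses onto the standard DDPG bootstrap target
+ ``y = r + gamma * (1 - done) * Q'(s', mu'(s'))``, exactly as computed in
+ the source below. A self-contained numeric sketch (values are illustrative):
+
+ .. code-block:: python
+
+     import torch
+
+     gamma = 0.99
+     rew = torch.tensor([1.0, 0.5])        # r_t
+     done = torch.tensor([0.0, 1.0])       # 1.0 marks a terminal step
+     target_q = torch.tensor([10.0, 8.0])  # Q'(s', mu'(s'))
+
+     backup = rew + (1 - done) * gamma * target_q
+     print(backup)  # tensor([10.9000, 0.5000])
+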
-.. py:function::
- xuance.torch.policies.deterministic_marl.Basic_DDPG_policy.target_critic(observation, actions, agent_ids)
-
- :param observation: xxxxxx.
- :type observation: xxxxxx
- :param actions: xxxxxx.
- :type actions: xxxxxx
- :param agent_ids: xxxxxx.
- :type agent_ids: xxxxxx
- :return: xxxxxx.
- :rtype: xxxxxx
-
-.. py:function::
- xuance.torch.policies.deterministic_marl.Basic_DDPG_policy.soft_update(tau)
-
- :param tau: xxxxxx.
- :type tau: xxxxxx
- :return: xxxxxx.
- :rtype: xxxxxx
-
-.. py:class::
- xuance.torch.policies.deterministic_marl.MADDPG_policy(action_space, n_agents, representation, actor_hidden_size, critic_hidden_size, normalize, initialize, activation, device)
-
- :param action_space: xxxxxx.
- :type action_space: xxxxxx
- :param n_agents: xxxxxx.
- :type n_agents: xxxxxx
- :param representation: xxxxxx.
- :type representation: xxxxxx
- :param actor_hidden_size: xxxxxx.
- :type actor_hidden_size: xxxxxx
- :param critic_hidden_size: xxxxxx.
- :type critic_hidden_size: xxxxxx
- :param normalize: xxxxxx.
- :type normalize: xxxxxx
- :param initialize: xxxxxx.
- :type initialize: xxxxxx
- :param activation: xxxxxx.
- :type activation: xxxxxx
- :param device: xxxxxx.
- :type device: xxxxxx
-
-.. py:function::
- xuance.torch.policies.deterministic_marl.MADDPG_policy.critic(observation, actions, agent_ids)
-
- :param observation: xxxxxx.
- :type observation: xxxxxx
- :param actions: xxxxxx.
- :type actions: xxxxxx
- :param agent_ids: xxxxxx.
- :type agent_ids: xxxxxx
- :return: xxxxxx.
- :rtype: xxxxxx
-
-.. py:function::
- xuance.torch.policies.deterministic_marl.MADDPG_policy.target_critic(observation, actions, agent_ids)
-
- :param observation: xxxxxx.
- :type observation: xxxxxx
- :param actions: xxxxxx.
- :type actions: xxxxxx
- :param agent_ids: xxxxxx.
- :type agent_ids: xxxxxx
- :return: xxxxxx.
- :rtype: xxxxxx
-
-.. py:class::
- xuance.torch.policies.deterministic_marl.MATD3_policy(action_space, n_agents, representation, actor_hidden_size, critic_hidden_size, normalize, initialize, activation, device)
-
- :param action_space: xxxxxx.
- :type action_space: xxxxxx
- :param n_agents: xxxxxx.
- :type n_agents: xxxxxx
- :param representation: xxxxxx.
- :type representation: xxxxxx
- :param actor_hidden_size: xxxxxx.
- :type actor_hidden_size: xxxxxx
- :param critic_hidden_size: xxxxxx.
- :type critic_hidden_size: xxxxxx
- :param normalize: xxxxxx.
- :type normalize: xxxxxx
- :param initialize: xxxxxx.
- :type initialize: xxxxxx
- :param activation: xxxxxx.
- :type activation: xxxxxx
- :param device: xxxxxx.
- :type device: xxxxxx
-
-.. py:function::
- xuance.torch.policies.deterministic_marl.MATD3_policy.Qpolicy(observation, actions, agent_ids)
-
- :param observation: xxxxxx.
- :type observation: xxxxxx
- :param actions: xxxxxx.
- :type actions: xxxxxx
- :param agent_ids: xxxxxx.
- :type agent_ids: xxxxxx
- :return: xxxxxx.
- :rtype: xxxxxx
-
-.. py:function::
- xuance.torch.policies.deterministic_marl.MATD3_policy.Qtarget(observation, actions, agent_ids)
-
- :param observation: xxxxxx.
- :type observation: xxxxxx
- :param actions: xxxxxx.
- :type actions: xxxxxx
- :param agent_ids: xxxxxx.
- :type agent_ids: xxxxxx
- :return: xxxxxx.
- :rtype: xxxxxx
-
-.. py:function::
- xuance.torch.policies.deterministic_marl.MATD3_policy.Qaction(observation, actions, agent_ids)
-
- :param observation: xxxxxx.
- :type observation: xxxxxx
- :param actions: xxxxxx.
- :type actions: xxxxxx
- :param agent_ids: xxxxxx.
- :type agent_ids: xxxxxx
- :return: xxxxxx.
- :rtype: xxxxxx
-
-.. py:function::
- xuance.torch.policies.deterministic_marl.MATD3_policy.soft_update()
-
- :return: None.
- :rtype: xxxxxx
-
.. raw:: html
@@ -629,50 +69,56 @@ Source Code
from xuance.torch.learners import *
- class A2C_Learner(Learner):
+ class DDPG_Learner(Learner):
def __init__(self,
policy: nn.Module,
- optimizer: torch.optim.Optimizer,
- scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
+ optimizers: Sequence[torch.optim.Optimizer],
+ schedulers: Sequence[torch.optim.lr_scheduler._LRScheduler],
device: Optional[Union[int, str, torch.device]] = None,
model_dir: str = "./",
- vf_coef: float = 0.25,
- ent_coef: float = 0.005,
- clip_grad: Optional[float] = None):
- super(A2C_Learner, self).__init__(policy, optimizer, scheduler, device, model_dir)
- self.vf_coef = vf_coef
- self.ent_coef = ent_coef
- self.clip_grad = clip_grad
-
- def update(self, obs_batch, act_batch, ret_batch, adv_batch):
+ gamma: float = 0.99,
+ tau: float = 0.01):
+ self.tau = tau
+ self.gamma = gamma
+ super(DDPG_Learner, self).__init__(policy, optimizers, schedulers, device, model_dir)
+
+ def update(self, obs_batch, act_batch, rew_batch, next_batch, terminal_batch):
self.iterations += 1
act_batch = torch.as_tensor(act_batch, device=self.device)
- ret_batch = torch.as_tensor(ret_batch, device=self.device)
- adv_batch = torch.as_tensor(adv_batch, device=self.device)
- outputs, a_dist, v_pred = self.policy(obs_batch)
- log_prob = a_dist.log_prob(act_batch)
-
- a_loss = -(adv_batch * log_prob).mean()
- c_loss = F.mse_loss(v_pred, ret_batch)
- e_loss = a_dist.entropy().mean()
-
- loss = a_loss - self.ent_coef * e_loss + self.vf_coef * c_loss
- self.optimizer.zero_grad()
- loss.backward()
- torch.nn.utils.clip_grad_norm_(self.policy.parameters(), self.clip_grad)
- self.optimizer.step()
+ rew_batch = torch.as_tensor(rew_batch, device=self.device)
+ ter_batch = torch.as_tensor(terminal_batch, device=self.device)
+ # critic update
+ action_q = self.policy.Qaction(obs_batch, act_batch)
+            with torch.no_grad():
+                target_q = self.policy.Qtarget(next_batch)
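+            # TD target: y = r + gamma * (1 - done) * Q'(s', mu'(s'))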
+ backup = rew_batch + (1 - ter_batch) * self.gamma * target_q
+ q_loss = F.mse_loss(action_q, backup.detach())
+ self.optimizer[1].zero_grad()
+ q_loss.backward()
+ self.optimizer[1].step()
+
+ # actor update
+ policy_q = self.policy.Qpolicy(obs_batch)
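+            # deterministic policy gradient: ascend on Q(s, mu(s)) by minimizing its negation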
+ p_loss = -policy_q.mean()
+ self.optimizer[0].zero_grad()
+ p_loss.backward()
+ self.optimizer[0].step()
+
if self.scheduler is not None:
- self.scheduler.step()
+ self.scheduler[0].step()
+ self.scheduler[1].step()
- # Logger
- lr = self.optimizer.state_dict()['param_groups'][0]['lr']
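+        # soft (Polyak) update of the target networks with coefficient tau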
+ self.policy.soft_update(self.tau)
+
+ actor_lr = self.optimizer[0].state_dict()['param_groups'][0]['lr']
+ critic_lr = self.optimizer[1].state_dict()['param_groups'][0]['lr']
info = {
- "actor-loss": a_loss.item(),
- "critic-loss": c_loss.item(),
- "entropy": e_loss.item(),
- "learning_rate": lr,
- "predict_value": v_pred.mean().item()
+ "Qloss": q_loss.item(),
+ "Ploss": p_loss.item(),
+ "Qvalue": action_q.mean().item(),
+ "actor_lr": actor_lr,
+ "critic_lr": critic_lr
}
return info
@@ -680,6 +126,7 @@ Source Code
+
.. group-tab:: TensorFlow
.. code-block:: python