maddpg.py

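"""MADDPG: multi-agent deep deterministic policy gradient.

Holds one Agent per environment agent. Each agent's critic is updated on the
concatenated observations and one-hot actions of all agents, while behaviour
actions are selected from each agent's own observation alone. Rewards can
optionally be standardised online with a running mean/std.
"""
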
from typing import List, Optional

import numpy as np
import torch
import torch.nn.functional as F

from agent import Agent
from gradient_estimators import GradientEstimator
from utils import RunningMeanStd


class MADDPG:
    def __init__(
        self,
        env,
        critic_lr: float,
        actor_lr: float,
        gradient_clip: float,
        hidden_dim_width: int,
        gamma: float,
        soft_update_size: float,
        policy_regulariser: float,
        gradient_estimator: GradientEstimator,
        standardise_rewards: bool,
        pretrained_agents: Optional[List[Agent]] = None,
    ):
        self.n_agents = env.n_agents
        self.gamma = gamma
        obs_dims = [obs.shape[0] for obs in env.observation_space]
        act_dims = [act.n for act in env.action_space]

        # Build fresh agents unless pretrained ones were supplied
        self.agents = [
            Agent(
                agent_idx=ii,
                obs_dims=obs_dims,
                act_dims=act_dims,
                # TODO: Consider changing this to **config
                hidden_dim_width=hidden_dim_width,
                critic_lr=critic_lr,
                actor_lr=actor_lr,
                gradient_clip=gradient_clip,
                soft_update_size=soft_update_size,
                policy_regulariser=policy_regulariser,
                gradient_estimator=gradient_estimator,
            )
            for ii in range(self.n_agents)
        ] if pretrained_agents is None else pretrained_agents

        self.return_std = RunningMeanStd(shape=(self.n_agents,)) if standardise_rewards else None
        self.gradient_estimator = gradient_estimator  # Keep reference to the gradient estimator

    def acts(self, obs: List):
        # One action per agent, each selected from that agent's behaviour policy
        actions = [self.agents[ii].act_behaviour(obs[ii]) for ii in range(self.n_agents)]
        return actions

    def update(self, sample):
        # sample['obs'] / sample['nobs']: one batch of (next-)observations per agent;
        # concatenate across agents to form the centralised critic inputs
        batched_obs = torch.cat(sample['obs'], dim=1)
        batched_nobs = torch.cat(sample['nobs'], dim=1)

        # ********
        # TODO: This is all a bit cumbersome; could be cleaner?
        target_actions = [
            self.agents[ii].act_target(sample['nobs'][ii])
            for ii in range(self.n_agents)
        ]
        target_actions_one_hot = [
            F.one_hot(target_actions[ii], num_classes=self.agents[ii].n_acts)
            for ii in range(self.n_agents)
        ]  # per-agent batch of one-hot target actions
        sampled_actions_one_hot = [
            F.one_hot(sample['acts'][ii].to(torch.int64), num_classes=self.agents[ii].n_acts)
            for ii in range(self.n_agents)
        ]  # per-agent batch of one-hot sampled actions
        # ********

        # Standardise rewards if requested
        rewards = sample['rwds']
        if self.return_std is not None:
            self.return_std.update(rewards)
            rewards = ((rewards.T - self.return_std.mean) / torch.sqrt(self.return_std.var)).T
        # ********

        for ii, agent in enumerate(self.agents):
            agent.update_critic(
                all_obs=batched_obs,
                all_nobs=batched_nobs,
                target_actions_per_agent=target_actions_one_hot,
                sampled_actions_per_agent=sampled_actions_one_hot,
                rewards=rewards[ii].unsqueeze(dim=1),
                dones=sample['dones'][ii].unsqueeze(dim=1),
                gamma=self.gamma,
            )
            agent.update_actor(
                all_obs=batched_obs,
                agent_obs=sample['obs'][ii],
                sampled_actions=sampled_actions_one_hot,
            )

        for agent in self.agents:
            agent.soft_update()

        self.gradient_estimator.update_state()  # Update gradient-estimator state, if necessary

        return None
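

# ----------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the module): `make_env`,
# `buffer`, `estimator`, and `total_steps` are hypothetical stand-ins for
# whatever the surrounding training script provides; only the MADDPG interface
# used here (the constructor arguments, `acts`, and `update`) comes from this
# file. `update` expects a sample dict with per-agent entries under the keys
# 'obs', 'nobs', 'acts', 'rwds', and 'dones'.
#
#   env = make_env()
#   maddpg = MADDPG(
#       env=env,
#       critic_lr=1e-3,
#       actor_lr=1e-4,
#       gradient_clip=1.0,
#       hidden_dim_width=64,
#       gamma=0.95,
#       soft_update_size=0.01,
#       policy_regulariser=1e-3,
#       gradient_estimator=estimator,   # any GradientEstimator instance
#       standardise_rewards=True,
#   )
#
#   obs = env.reset()
#   for _ in range(total_steps):
#       actions = maddpg.acts(obs)
#       nobs, rwds, dones, _ = env.step(actions)
#       buffer.store(obs, actions, rwds, nobs, dones)
#       obs = nobs
#       if buffer.ready():
#           maddpg.update(buffer.sample())
# ----------------------------------------------------------------------------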