Skip to content

Commit

Permalink
[BugFix] Call all agents' reward functions before calling `observat…
Browse files Browse the repository at this point in the history
…ion`
  • Loading branch information
matteobettini committed Mar 19, 2024
1 parent cd68651 commit fb29f90
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions vmas/simulator/environment/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,22 +149,24 @@ def get_from_scenario(
if get_infos:
infos = {} if dict_agent_names else []

for agent in self.agents:
if get_rewards:
if get_rewards:
for agent in self.agents:
reward = self.scenario.reward(agent).clone()
if dict_agent_names:
rewards.update({agent.name: reward})
else:
rewards.append(reward)
if get_observations:
if get_observations:
for agent in self.agents:
observation = TorchUtils.recursive_clone(
self.scenario.observation(agent)
)
if dict_agent_names:
obs.update({agent.name: observation})
else:
obs.append(observation)
if get_infos:
if get_infos:
for agent in self.agents:
info = TorchUtils.recursive_clone(self.scenario.info(agent))
if dict_agent_names:
infos.update({agent.name: info})
Expand Down

0 comments on commit fb29f90

Please sign in to comment.