Commit 803e62f
Clear agent processor properly on episode reset (#3437)
Ervin T authored Feb 14, 2020
1 parent ff99fb0 commit 803e62f
Showing 2 changed files with 56 additions and 3 deletions.
6 changes: 3 additions & 3 deletions ml-agents/mlagents/trainers/agent_processor.py
@@ -203,9 +203,9 @@ def end_episode(self) -> None:
        Ends the episode, terminating the current trajectory and stopping stats collection for that
        episode. Used for forceful reset (e.g. in curriculum or generalization training.)
        """
-        self.experience_buffers.clear()
-        self.episode_rewards.clear()
-        self.episode_steps.clear()
+        all_gids = list(self.experience_buffers.keys())  # Need to make copy
+        for _gid in all_gids:
+            self._clean_agent_data(_gid)


class AgentManagerQueue(Generic[T]):
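Why the change: the old implementation cleared three stats dictionaries in bulk but apparently never touched last_take_action_outputs or the policy's per-agent state, so a forced reset left stale entries behind; the new loop instead funnels every known global agent ID through _clean_agent_data. That helper is defined elsewhere in agent_processor.py and is not part of this diff. The snippet below is a self-contained sketch of the pattern; the class, attribute, and method names are stand-ins inferred from this diff and the assertions in the new test, not the actual ML-Agents implementation.

from typing import Any, Dict, List

class ProcessorSketch:
    """Minimal stand-in for AgentProcessor, just enough to show the
    per-agent cleanup pattern this commit switches to."""

    def __init__(self) -> None:
        self.experience_buffers: Dict[str, List[Any]] = {}
        self.last_take_action_outputs: Dict[str, Any] = {}
        self.episode_steps: Dict[str, int] = {}
        self.episode_rewards: Dict[str, float] = {}

    def _clean_agent_data(self, global_id: str) -> None:
        # Remove every record tied to one agent; pop() with a default
        # avoids KeyError for agents that were only partially tracked.
        self.experience_buffers.pop(global_id, None)
        self.last_take_action_outputs.pop(global_id, None)
        self.episode_steps.pop(global_id, None)
        self.episode_rewards.pop(global_id, None)

    def end_episode(self) -> None:
        # list() snapshots the keys because _clean_agent_data mutates
        # experience_buffers while we iterate; this is the "Need to
        # make copy" note in the diff above.
        for gid in list(self.experience_buffers.keys()):
            self._clean_agent_data(gid)

Iterating over self.experience_buffers directly while deleting from it would raise RuntimeError ("dictionary changed size during iteration"), which is why the diff snapshots the keys first.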
53 changes: 53 additions & 0 deletions ml-agents/mlagents/trainers/tests/test_agent_processor.py
@@ -154,6 +154,59 @@ def test_agent_deletion():
    assert len(processor.episode_rewards.keys()) == 0


+def test_end_episode():
+    policy = create_mock_policy()
+    tqueue = mock.Mock()
+    name_behavior_id = "test_brain_name"
+    processor = AgentProcessor(
+        policy,
+        name_behavior_id,
+        max_trajectory_length=5,
+        stats_reporter=StatsReporter("testcat"),
+    )
+
+    fake_action_outputs = {
+        "action": [0.1],
+        "entropy": np.array([1.0], dtype=np.float32),
+        "learning_rate": 1.0,
+        "pre_action": [0.1],
+        "log_probs": [0.1],
+    }
+    mock_step = mb.create_mock_batchedstep(
+        num_agents=1,
+        num_vector_observations=8,
+        action_shape=[2],
+        num_vis_observations=0,
+    )
+    fake_action_info = ActionInfo(
+        action=[0.1],
+        value=[0.1],
+        outputs=fake_action_outputs,
+        agent_ids=mock_step.agent_id,
+    )
+
+    processor.publish_trajectory_queue(tqueue)
+    # This is like the initial state after the env reset
+    processor.add_experiences(mock_step, 0, ActionInfo.empty())
+    # Run 3 trajectories, with different workers (to simulate different agents)
+    remove_calls = []
+    for _ep in range(3):
+        remove_calls.append(mock.call([get_global_agent_id(_ep, 0)]))
+        for _ in range(5):
+            processor.add_experiences(mock_step, _ep, fake_action_info)
+        # Make sure we don't add experiences from the prior agents after the done
+
+    # Call end episode
+    processor.end_episode()
+    # Check that we removed every agent
+    policy.remove_previous_action.assert_has_calls(remove_calls)
+    # Check that there are no experiences left
+    assert len(processor.experience_buffers.keys()) == 0
+    assert len(processor.last_take_action_outputs.keys()) == 0
+    assert len(processor.episode_steps.keys()) == 0
+    assert len(processor.episode_rewards.keys()) == 0


def test_agent_manager():
    policy = create_mock_policy()
    name_behavior_id = "test_brain_name"
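What the new test covers: it pushes five steps apiece through three different worker IDs, so three distinct global agent IDs accumulate buffered experience, then calls end_episode() and asserts that all four tracking dictionaries come back empty and that the policy was told to forget each agent's previous action. To run only this test locally (assuming a standard pytest setup for the repository; the -k filter is plain pytest, nothing this commit adds):

pytest ml-agents/mlagents/trainers/tests/test_agent_processor.py -k test_end_episode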
