-
Notifications
You must be signed in to change notification settings - Fork 75
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
update vmas wrapper base class, move wrappers and add wrapper tests
- Loading branch information
1 parent
76278ba
commit 67ed9dd
Showing
12 changed files
with
479 additions
and
200 deletions.
There are no files selected for viewing
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,117 @@ | ||
from pathlib import Path | ||
|
||
import gym | ||
import numpy as np | ||
import pytest | ||
from torch import Tensor | ||
|
||
from vmas import make_env | ||
from vmas.simulator.environment import Environment | ||
|
||
|
||
def scenario_names(scenarios_folder=None):
    """Return the names of the scenario modules found in a folder.

    Args:
        scenarios_folder: Directory to scan. When ``None`` (the default), the
            folder ``Path(__file__).parent.parent.parent / "vmas" / "scenarios"``
            is used.

    Returns:
        The stems of all regular ``.py`` files in the folder whose name does
        not start with ``__``, sorted alphabetically.
    """
    if scenarios_folder is None:
        scenarios_folder = Path(__file__).parent.parent.parent / "vmas" / "scenarios"
    # ``iterdir`` yields entries in arbitrary order; sorting keeps the pytest
    # parametrization (and therefore the test ids/order) deterministic.
    return sorted(
        path.stem
        for path in Path(scenarios_folder).iterdir()
        if path.is_file() and path.suffix == ".py" and not path.name.startswith("__")
    )
|
||
|
||
def _check_obs_type(obss, obs_shapes, dict_space, return_numpy):
    """Assert that observations have the expected container, type and shape.

    Args:
        obss: Observations to check; a ``dict`` when ``dict_space`` is true,
            otherwise a ``list``.
        obs_shapes: Expected shapes, one per observation, in the same order
            as the observations (dict values are taken in insertion order).
        dict_space: Whether ``obss`` is expected to be a dictionary.
        return_numpy: Whether each observation is expected to be a
            ``np.ndarray`` (true) or a torch ``Tensor`` (false).

    Raises:
        AssertionError: If the container type, the number of observations,
            an observation type or an observation shape does not match.
    """
    if dict_space:
        assert isinstance(
            obss, dict
        ), f"Expected dictionary of observations, got {type(obss)}"
        obss = list(obss.values())
    else:
        assert isinstance(
            obss, list
        ), f"Expected list of observations, got {type(obss)}"

    # ``zip`` below stops at the shorter input, so without this check a
    # missing (or extra) observation would go unnoticed.
    assert len(obss) == len(
        obs_shapes
    ), f"Expected {len(obs_shapes)} observations, got {len(obss)}"

    if return_numpy:
        expected_type, type_name = np.ndarray, "numpy array"
    else:
        expected_type, type_name = Tensor, "torch tensor"

    for o, shape in zip(obss, obs_shapes):
        assert isinstance(o, expected_type), f"Expected {type_name}, got {type(o)}"
        assert o.shape == shape, f"Expected shape {shape}, got {o.shape}"
|
||
|
||
@pytest.mark.parametrize("scenario", scenario_names())
@pytest.mark.parametrize("return_numpy", [True, False])
@pytest.mark.parametrize("continuous_actions", [True, False])
@pytest.mark.parametrize("dict_space", [True, False])
def test_gym_wrapper(
    scenario, return_numpy, continuous_actions, dict_space, max_steps=10
):
    """Check the ``"gym"`` wrapper on a single environment.

    The environment is built with ``make_env`` for every combination of
    scenario, ``return_numpy``, ``continuous_actions`` and ``dict_space``.
    The test verifies the observation/action space types, the type and shape
    of the observations returned by ``reset`` and ``step``, the types of the
    rewards, ``done`` flag and info, and that ``done`` is true once
    ``max_steps`` steps have been taken.
    """
    env = make_env(
        scenario=scenario,
        num_envs=1,
        device="cpu",
        continuous_actions=continuous_actions,
        dict_spaces=dict_space,
        wrapper="gym",
        wrapper_kwargs={"return_numpy": return_numpy},
        max_steps=max_steps,
    )

    assert (
        len(env.observation_space) == env.unwrapped.n_agents
    ), "Expected one observation per agent"
    assert (
        len(env.action_space) == env.unwrapped.n_agents
    ), "Expected one action per agent"
    if dict_space:
        assert isinstance(
            env.observation_space, gym.spaces.Dict
        ), "Expected Dict observation space"
        assert isinstance(
            env.action_space, gym.spaces.Dict
        ), "Expected Dict action space"
        obs_shapes = [
            obs_space.shape for obs_space in env.observation_space.spaces.values()
        ]
    else:
        assert isinstance(
            env.observation_space, gym.spaces.Tuple
        ), "Expected Tuple observation space"
        assert isinstance(
            env.action_space, gym.spaces.Tuple
        ), "Expected Tuple action space"
        obs_shapes = [obs_space.shape for obs_space in env.observation_space.spaces]

    assert isinstance(
        env.unwrapped, Environment
    ), "The unwrapped attribute of the Gym wrapper should be a VMAS Environment"

    obss = env.reset()
    _check_obs_type(obss, obs_shapes, dict_space, return_numpy=return_numpy)

    for _ in range(max_steps):
        actions = env.unwrapped.get_random_actions()
        obss, rews, done, info = env.step(actions)
        _check_obs_type(obss, obs_shapes, dict_space, return_numpy=return_numpy)

        assert len(rews) == env.unwrapped.n_agents, "Expected one reward per agent"
        if not dict_space:
            assert isinstance(
                rews, list
            ), f"Expected list of rewards but got {type(rews)}"

            rew_values = rews
        else:
            assert isinstance(
                rews, dict
            ), f"Expected dictionary of rewards but got {type(rews)}"
            rew_values = list(rews.values())
        assert all(
            isinstance(rew, float) for rew in rew_values
        ), f"Expected float rewards but got {type(rew_values[0])}"

        assert isinstance(done, bool), f"Expected bool for done but got {type(done)}"

        assert isinstance(
            info, dict
        ), f"Expected info to be a dictionary but got {type(info)}"

    # The episode length is ``max_steps``, so the last step must report done.
    assert done, f"Expected done to be True after {max_steps} steps"
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,112 @@ | ||
import pytest | ||
|
||
import gymnasium as gym | ||
import numpy as np | ||
from vmas import make_env | ||
from vmas.simulator.environment import Environment | ||
import torch | ||
|
||
from .test_gym_wrapper import scenario_names, _check_obs_type | ||
|
||
|
||
@pytest.mark.parametrize("scenario", scenario_names())
@pytest.mark.parametrize("return_numpy", [True, False])
@pytest.mark.parametrize("continuous_actions", [True, False])
@pytest.mark.parametrize("dict_space", [True, False])
@pytest.mark.parametrize("num_envs", [1, 10])
def test_gymnasium_wrapper(
    scenario, return_numpy, continuous_actions, dict_space, num_envs, max_steps=10
):
    """Check the ``"gymnasium_vec"`` wrapper on ``num_envs`` environments.

    The environment is built with ``make_env`` (with
    ``terminated_truncated=True``) for every parameter combination. The test
    verifies the observation/action space types, the type and shape of the
    observations returned by ``reset`` and ``step``, that rewards and the
    ``terminated``/``truncated`` flags are numpy arrays or torch tensors
    according to ``return_numpy``, that info is a dictionary, and that every
    entry of ``truncated`` is true once ``max_steps`` steps have been taken.
    """
    env = make_env(
        scenario=scenario,
        num_envs=num_envs,
        device="cpu",
        continuous_actions=continuous_actions,
        dict_spaces=dict_space,
        wrapper="gymnasium_vec",
        terminated_truncated=True,
        wrapper_kwargs={"return_numpy": return_numpy},
        max_steps=max_steps,
    )

    assert isinstance(
        env.unwrapped, Environment
    ), "The unwrapped attribute of the Gym wrapper should be a VMAS Environment"

    assert (
        len(env.observation_space) == env.unwrapped.n_agents
    ), "Expected one observation per agent"
    assert (
        len(env.action_space) == env.unwrapped.n_agents
    ), "Expected one action per agent"
    if dict_space:
        assert isinstance(
            env.observation_space, gym.spaces.Dict
        ), "Expected Dict observation space"
        assert isinstance(
            env.action_space, gym.spaces.Dict
        ), "Expected Dict action space"
        obs_shapes = [
            obs_space.shape for obs_space in env.observation_space.spaces.values()
        ]
    else:
        assert isinstance(
            env.observation_space, gym.spaces.Tuple
        ), "Expected Tuple observation space"
        assert isinstance(
            env.action_space, gym.spaces.Tuple
        ), "Expected Tuple action space"
        obs_shapes = [obs_space.shape for obs_space in env.observation_space.spaces]

    obss, info = env.reset()
    _check_obs_type(obss, obs_shapes, dict_space, return_numpy=return_numpy)
    assert isinstance(
        info, dict
    ), f"Expected info to be a dictionary but got {type(info)}"

    for _ in range(max_steps):
        actions = env.unwrapped.get_random_actions()
        obss, rews, terminated, truncated, info = env.step(actions)
        _check_obs_type(obss, obs_shapes, dict_space, return_numpy=return_numpy)

        assert len(rews) == env.unwrapped.n_agents, "Expected one reward per agent"
        if not dict_space:
            assert isinstance(
                rews, list
            ), f"Expected list of rewards but got {type(rews)}"

            rew_values = rews
        else:
            assert isinstance(
                rews, dict
            ), f"Expected dictionary of rewards but got {type(rews)}"
            rew_values = list(rews.values())
        if return_numpy:
            assert all(
                isinstance(rew, np.ndarray) for rew in rew_values
            ), f"Expected np.array rewards but got {type(rew_values[0])}"
        else:
            assert all(
                isinstance(rew, torch.Tensor) for rew in rew_values
            ), f"Expected torch tensor rewards but got {type(rew_values[0])}"

        if return_numpy:
            assert isinstance(
                terminated, np.ndarray
            ), f"Expected np.array for terminated but got {type(terminated)}"
            assert isinstance(
                truncated, np.ndarray
            ), f"Expected np.array for truncated but got {type(truncated)}"
        else:
            assert isinstance(
                terminated, torch.Tensor
            ), f"Expected torch tensor for terminated but got {type(terminated)}"
            assert isinstance(
                truncated, torch.Tensor
            ), f"Expected torch tensor for truncated but got {type(truncated)}"

        assert isinstance(
            info, dict
        ), f"Expected info to be a dictionary but got {type(info)}"

    # The episode length is ``max_steps``, so after the last step every
    # environment must be flagged as truncated.
    assert all(
        truncated
    ), f"Expected all environments to be truncated after {max_steps} steps"
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
import pytest | ||
|
||
import gymnasium as gym | ||
from vmas import make_env | ||
from vmas.simulator.environment import Environment | ||
from .test_gym_wrapper import scenario_names, _check_obs_type | ||
|
||
|
||
@pytest.mark.parametrize("scenario", scenario_names())
@pytest.mark.parametrize("return_numpy", [True, False])
@pytest.mark.parametrize("continuous_actions", [True, False])
@pytest.mark.parametrize("dict_space", [True, False])
def test_gymnasium_wrapper(
    scenario, return_numpy, continuous_actions, dict_space, max_steps=10
):
    """Check the ``"gymnasium"`` wrapper on a single environment.

    The environment is built with ``make_env`` (with
    ``terminated_truncated=True``) for every parameter combination. The test
    verifies the observation/action space types, the type and shape of the
    observations returned by ``reset`` and ``step``, that rewards are floats,
    that ``terminated``/``truncated`` are bools, that info is a dictionary,
    and that ``truncated`` is true once ``max_steps`` steps have been taken.
    """
    env = make_env(
        scenario=scenario,
        num_envs=1,
        device="cpu",
        continuous_actions=continuous_actions,
        dict_spaces=dict_space,
        wrapper="gymnasium",
        terminated_truncated=True,
        wrapper_kwargs={"return_numpy": return_numpy},
        max_steps=max_steps,
    )

    assert (
        len(env.observation_space) == env.unwrapped.n_agents
    ), "Expected one observation per agent"
    assert (
        len(env.action_space) == env.unwrapped.n_agents
    ), "Expected one action per agent"
    if dict_space:
        assert isinstance(
            env.observation_space, gym.spaces.Dict
        ), "Expected Dict observation space"
        assert isinstance(
            env.action_space, gym.spaces.Dict
        ), "Expected Dict action space"
        obs_shapes = [
            obs_space.shape for obs_space in env.observation_space.spaces.values()
        ]
    else:
        assert isinstance(
            env.observation_space, gym.spaces.Tuple
        ), "Expected Tuple observation space"
        assert isinstance(
            env.action_space, gym.spaces.Tuple
        ), "Expected Tuple action space"
        obs_shapes = [obs_space.shape for obs_space in env.observation_space.spaces]

    assert isinstance(
        env.unwrapped, Environment
    ), "The unwrapped attribute of the Gym wrapper should be a VMAS Environment"

    obss, info = env.reset()
    _check_obs_type(obss, obs_shapes, dict_space, return_numpy=return_numpy)
    assert isinstance(
        info, dict
    ), f"Expected info to be a dictionary but got {type(info)}"

    for _ in range(max_steps):
        actions = env.unwrapped.get_random_actions()
        obss, rews, terminated, truncated, info = env.step(actions)
        _check_obs_type(obss, obs_shapes, dict_space, return_numpy=return_numpy)

        assert len(rews) == env.unwrapped.n_agents, "Expected one reward per agent"
        if not dict_space:
            assert isinstance(
                rews, list
            ), f"Expected list of rewards but got {type(rews)}"

            rew_values = rews
        else:
            assert isinstance(
                rews, dict
            ), f"Expected dictionary of rewards but got {type(rews)}"
            rew_values = list(rews.values())
        assert all(
            isinstance(rew, float) for rew in rew_values
        ), f"Expected float rewards but got {type(rew_values[0])}"

        assert isinstance(
            terminated, bool
        ), f"Expected bool for terminated but got {type(terminated)}"
        assert isinstance(
            truncated, bool
        ), f"Expected bool for truncated but got {type(truncated)}"

        assert isinstance(
            info, dict
        ), f"Expected info to be a dictionary but got {type(info)}"

    # The episode length is ``max_steps``, so the last step must be truncated.
    assert truncated, f"Expected truncated to be True after {max_steps} steps"
Oops, something went wrong.