Skip to content

Commit

Permalink
update vmas wrapper base class, move wrappers and add wrapper tests
Browse files Browse the repository at this point in the history
  • Loading branch information
LukasSchaefer committed Sep 19, 2024
1 parent 76278ba commit 67ed9dd
Show file tree
Hide file tree
Showing 12 changed files with 479 additions and 200 deletions.
Empty file added tests/test_wrappers/__init__.py
Empty file.
117 changes: 117 additions & 0 deletions tests/test_wrappers/test_gym_wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
from pathlib import Path

import gym
import numpy as np
import pytest
from torch import Tensor

from vmas import make_env
from vmas.simulator.environment import Environment


def scenario_names():
scenarios = []
scenarios_folder = Path(__file__).parent.parent.parent / "vmas" / "scenarios"
for path in scenarios_folder.iterdir():
if path.is_file() and path.suffix == ".py" and not path.name.startswith("__"):
scenarios.append(path.stem)
return scenarios


def _check_obs_type(obss, obs_shapes, dict_space, return_numpy):
if dict_space:
assert isinstance(
obss, dict
), f"Expected dictionary of observations, got {type(obss)}"
obss = list(obss.values())
else:
assert isinstance(
obss, list
), f"Expected list of observations, got {type(obss)}"
for o, shape in zip(obss, obs_shapes):
if return_numpy:
assert isinstance(o, np.ndarray), f"Expected numpy array, got {type(o)}"
assert o.shape == shape, f"Expected shape {shape}, got {o.shape}"
else:
assert isinstance(o, Tensor), f"Expected torch tensor, got {type(o)}"
assert o.shape == shape, f"Expected shape {shape}, got {o.shape}"


@pytest.mark.parametrize("scenario", scenario_names())
@pytest.mark.parametrize("return_numpy", [True, False])
@pytest.mark.parametrize("continuous_actions", [True, False])
@pytest.mark.parametrize("dict_space", [True, False])
def test_gym_wrapper(
scenario, return_numpy, continuous_actions, dict_space, max_steps=10
):
env = make_env(
scenario=scenario,
num_envs=1,
device="cpu",
continuous_actions=continuous_actions,
dict_spaces=dict_space,
wrapper="gym",
wrapper_kwargs={"return_numpy": return_numpy},
max_steps=max_steps,
)

assert (
len(env.observation_space) == env.unwrapped.n_agents
), "Expected one observation per agent"
assert (
len(env.action_space) == env.unwrapped.n_agents
), "Expected one action per agent"
if dict_space:
assert isinstance(
env.observation_space, gym.spaces.Dict
), "Expected Dict observation space"
assert isinstance(
env.action_space, gym.spaces.Dict
), "Expected Dict action space"
obs_shapes = [
obs_space.shape for obs_space in env.observation_space.spaces.values()
]
else:
assert isinstance(
env.observation_space, gym.spaces.Tuple
), "Expected Tuple observation space"
assert isinstance(
env.action_space, gym.spaces.Tuple
), "Expected Tuple action space"
obs_shapes = [obs_space.shape for obs_space in env.observation_space.spaces]

assert isinstance(
env.unwrapped, Environment
), "The unwrapped attribute of the Gym wrapper should be a VMAS Environment"

obss = env.reset()
_check_obs_type(obss, obs_shapes, dict_space, return_numpy=return_numpy)

for _ in range(max_steps):
actions = env.unwrapped.get_random_actions()
obss, rews, done, info = env.step(actions)
_check_obs_type(obss, obs_shapes, dict_space, return_numpy=return_numpy)

assert len(rews) == env.unwrapped.n_agents, "Expected one reward per agent"
if not dict_space:
assert isinstance(
rews, list
), f"Expected list of rewards but got {type(rews)}"

rew_values = rews
else:
assert isinstance(
rews, dict
), f"Expected dictionary of rewards but got {type(rews)}"
rew_values = list(rews.values())
assert all(
isinstance(rew, float) for rew in rew_values
), f"Expected float rewards but got {type(rew_values[0])}"

assert isinstance(done, bool), f"Expected bool for done but got {type(done)}"

assert isinstance(
info, dict
), f"Expected info to be a dictionary but got {type(info)}"

assert done, "Expected done to be True after 100 steps"
112 changes: 112 additions & 0 deletions tests/test_wrappers/test_gymnasium_vec_wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
import pytest

import gymnasium as gym
import numpy as np
from vmas import make_env
from vmas.simulator.environment import Environment
import torch

from .test_gym_wrapper import scenario_names, _check_obs_type


@pytest.mark.parametrize("scenario", scenario_names())
@pytest.mark.parametrize("return_numpy", [True, False])
@pytest.mark.parametrize("continuous_actions", [True, False])
@pytest.mark.parametrize("dict_space", [True, False])
@pytest.mark.parametrize("num_envs", [1, 10])
def test_gymnasium_wrapper(
scenario, return_numpy, continuous_actions, dict_space, num_envs, max_steps=10
):
env = make_env(
scenario=scenario,
num_envs=num_envs,
device="cpu",
continuous_actions=continuous_actions,
dict_spaces=dict_space,
wrapper="gymnasium_vec",
terminated_truncated=True,
wrapper_kwargs={"return_numpy": return_numpy},
max_steps=max_steps,
)

assert isinstance(
env.unwrapped, Environment
), "The unwrapped attribute of the Gym wrapper should be a VMAS Environment"

assert (
len(env.observation_space) == env.unwrapped.n_agents
), "Expected one observation per agent"
assert (
len(env.action_space) == env.unwrapped.n_agents
), "Expected one action per agent"
if dict_space:
assert isinstance(
env.observation_space, gym.spaces.Dict
), "Expected Dict observation space"
assert isinstance(
env.action_space, gym.spaces.Dict
), "Expected Dict action space"
obs_shapes = [
obs_space.shape for obs_space in env.observation_space.spaces.values()
]
else:
assert isinstance(
env.observation_space, gym.spaces.Tuple
), "Expected Tuple observation space"
assert isinstance(
env.action_space, gym.spaces.Tuple
), "Expected Tuple action space"
obs_shapes = [obs_space.shape for obs_space in env.observation_space.spaces]

obss, info = env.reset()
_check_obs_type(obss, obs_shapes, dict_space, return_numpy=return_numpy)
assert isinstance(
info, dict
), f"Expected info to be a dictionary but got {type(info)}"

for _ in range(max_steps):
actions = env.unwrapped.get_random_actions()
obss, rews, terminated, truncated, info = env.step(actions)
_check_obs_type(obss, obs_shapes, dict_space, return_numpy=return_numpy)

assert len(rews) == env.unwrapped.n_agents, "Expected one reward per agent"
if not dict_space:
assert isinstance(
rews, list
), f"Expected list of rewards but got {type(rews)}"

rew_values = rews
else:
assert isinstance(
rews, dict
), f"Expected dictionary of rewards but got {type(rews)}"
rew_values = list(rews.values())
if return_numpy:
assert all(
isinstance(rew, np.ndarray) for rew in rew_values
), f"Expected np.array rewards but got {type(rew_values[0])}"
else:
assert all(
isinstance(rew, torch.Tensor) for rew in rew_values
), f"Expected torch tensor rewards but got {type(rew_values[0])}"

if return_numpy:
assert isinstance(
terminated, np.ndarray
), f"Expected np.array for terminated but got {type(terminated)}"
assert isinstance(
truncated, np.ndarray
), f"Expected np.array for truncated but got {type(truncated)}"
else:
assert isinstance(
terminated, torch.Tensor
), f"Expected torch tensor for terminated but got {type(terminated)}"
assert isinstance(
truncated, torch.Tensor
), f"Expected torch tensor for truncated but got {type(truncated)}"

assert isinstance(
info, dict
), f"Expected info to be a dictionary but got {type(info)}"

assert all(truncated), "Expected done to be True after 100 steps"
95 changes: 95 additions & 0 deletions tests/test_wrappers/test_gymnasium_wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
import pytest

import gymnasium as gym
from vmas import make_env
from vmas.simulator.environment import Environment
from .test_gym_wrapper import scenario_names, _check_obs_type


@pytest.mark.parametrize("scenario", scenario_names())
@pytest.mark.parametrize("return_numpy", [True, False])
@pytest.mark.parametrize("continuous_actions", [True, False])
@pytest.mark.parametrize("dict_space", [True, False])
def test_gymnasium_wrapper(
scenario, return_numpy, continuous_actions, dict_space, max_steps=10
):
env = make_env(
scenario=scenario,
num_envs=1,
device="cpu",
continuous_actions=continuous_actions,
dict_spaces=dict_space,
wrapper="gymnasium",
terminated_truncated=True,
wrapper_kwargs={"return_numpy": return_numpy},
max_steps=max_steps,
)

assert (
len(env.observation_space) == env.unwrapped.n_agents
), "Expected one observation per agent"
assert (
len(env.action_space) == env.unwrapped.n_agents
), "Expected one action per agent"
if dict_space:
assert isinstance(
env.observation_space, gym.spaces.Dict
), "Expected Dict observation space"
assert isinstance(
env.action_space, gym.spaces.Dict
), "Expected Dict action space"
obs_shapes = [
obs_space.shape for obs_space in env.observation_space.spaces.values()
]
else:
assert isinstance(
env.observation_space, gym.spaces.Tuple
), "Expected Tuple observation space"
assert isinstance(
env.action_space, gym.spaces.Tuple
), "Expected Tuple action space"
obs_shapes = [obs_space.shape for obs_space in env.observation_space.spaces]

assert isinstance(
env.unwrapped, Environment
), "The unwrapped attribute of the Gym wrapper should be a VMAS Environment"

obss, info = env.reset()
_check_obs_type(obss, obs_shapes, dict_space, return_numpy=return_numpy)
assert isinstance(
info, dict
), f"Expected info to be a dictionary but got {type(info)}"

for _ in range(max_steps):
actions = env.unwrapped.get_random_actions()
obss, rews, terminated, truncated, info = env.step(actions)
_check_obs_type(obss, obs_shapes, dict_space, return_numpy=return_numpy)

assert len(rews) == env.unwrapped.n_agents, "Expected one reward per agent"
if not dict_space:
assert isinstance(
rews, list
), f"Expected list of rewards but got {type(rews)}"

rew_values = rews
else:
assert isinstance(
rews, dict
), f"Expected dictionary of rewards but got {type(rews)}"
rew_values = list(rews.values())
assert all(
isinstance(rew, float) for rew in rew_values
), f"Expected float rewards but got {type(rew_values[0])}"

assert isinstance(
terminated, bool
), f"Expected bool for terminated but got {type(terminated)}"
assert isinstance(
truncated, bool
), f"Expected bool for truncated but got {type(truncated)}"

assert isinstance(
info, dict
), f"Expected info to be a dictionary but got {type(info)}"

assert truncated, "Expected done to be True after 100 steps"
Loading

0 comments on commit 67ed9dd

Please sign in to comment.