Skip to content

Commit

Permalink
[Notbooks] Refresh notebooks (#90)
Browse files Browse the repository at this point in the history
* amend

* amend

* amend

* amend
  • Loading branch information
matteobettini authored Mar 7, 2024
1 parent ce242ab commit 99f562c
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 35 deletions.
3 changes: 2 additions & 1 deletion notebooks/VMAS_RLlib.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@
"! git clone https://github.com/proroklab/VectorizedMultiAgentSimulator.git\n",
"%cd /content/VectorizedMultiAgentSimulator\n",
"!pip install -e .\n",
"!pip install \"ray[rllib]\"==2.2 wandb"
"!pip install \"ray[rllib]\"==2.2 wandb\n",
"!pip install \"pydantic<2\" numpy==1.23.5"
]
},
{
Expand Down
67 changes: 34 additions & 33 deletions notebooks/VMAS_Use_vmas_environment.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -77,38 +77,43 @@
"import time\n",
"import torch\n",
"from vmas import make_env\n",
"from vmas.simulator.core import Agent\n",
"\n",
"def _get_deterministic_action(agent: Agent, continuous: bool, env):\n",
" if continuous:\n",
" action = -agent.action.u_range_tensor.expand(env.batch_dim, agent.action_size)\n",
" else:\n",
" action = (\n",
" torch.tensor([1], device=env.device, dtype=torch.long)\n",
" .unsqueeze(-1)\n",
" .expand(env.batch_dim, 1)\n",
" )\n",
" return action.clone()\n",
"\n",
"def use_vmas_env(\n",
" render: bool = False,\n",
" save_render: bool = False,\n",
" num_envs: int = 32,\n",
" n_steps: int = 100,\n",
" device: str = \"cpu\",\n",
" scenario: Union[str, BaseScenario]= \"waterfall\",\n",
" n_agents: int = 4,\n",
" continuous_actions: bool = True,\n",
" render: bool,\n",
" num_envs: int,\n",
" n_steps: int,\n",
" device: str,\n",
" scenario: Union[str, BaseScenario],\n",
" continuous_actions: bool,\n",
" random_action: bool,\n",
" **kwargs\n",
"):\n",
" \"\"\"Example function to use a vmas environment\n",
" \"\"\"Example function to use a vmas environment.\n",
" \n",
" This is a simplification of the function in `vmas.examples.use_vmas_env.py`.\n",
"\n",
" Args:\n",
" continuous_actions (bool): Whether the agents have continuous or discrete actions\n",
" n_agents (int): Number of agents\n",
" scenario (str): Name of scenario\n",
" device (str): Torch device to use\n",
" render (bool): Whether to render the scenario\n",
" save_render (bool): Whether to save render of the scenario\n",
" num_envs (int): Number of vectorized environments\n",
" n_steps (int): Number of steps before returning done\n",
"\n",
" Returns:\n",
" random_action (bool): Use random actions or have all agents perform the down action\n",
"\n",
" \"\"\"\n",
" assert not (save_render and not render), \"To save the video you have to render it\"\n",
"\n",
" simple_2d_action = (\n",
" [0, -1.0] if continuous_actions else [3]\n",
" ) # Simple action for an agent with 2d actions\n",
"\n",
" scenario_name = scenario if isinstance(scenario,str) else scenario.__class__.__name__\n",
"\n",
Expand All @@ -117,10 +122,8 @@
" num_envs=num_envs,\n",
" device=device,\n",
" continuous_actions=continuous_actions,\n",
" wrapper=None,\n",
" seed=None,\n",
" seed=0,\n",
" # Environment specific variables\n",
" n_agents=n_agents,\n",
" **kwargs\n",
" )\n",
"\n",
Expand All @@ -134,31 +137,29 @@
"\n",
" actions = []\n",
" for i, agent in enumerate(env.agents):\n",
" action = torch.tensor(\n",
" simple_2d_action,\n",
" device=device,\n",
" ).repeat(num_envs, 1)\n",
" if not random_action:\n",
" action = _get_deterministic_action(agent, continuous_actions, env)\n",
" else:\n",
" action = env.get_random_action(agent)\n",
"\n",
" actions.append(action)\n",
"\n",
" obs, rews, dones, info = env.step(actions)\n",
"\n",
" if render:\n",
" frame = env.render(\n",
" mode=\"rgb_array\" if save_render else \"human\",\n",
" mode=\"rgb_array\",\n",
" agent_index_focus=None, # Can give the camera an agent index to focus on\n",
" visualize_when_rgb=True,\n",
" )\n",
" if save_render:\n",
" frame_list.append(frame)\n",
" frame_list.append(frame)\n",
"\n",
" total_time = time.time() - init_time\n",
" print(\n",
" f\"It took: {total_time}s for {n_steps} steps of {num_envs} parallel environments on device {device} \"\n",
" f\"for {scenario_name} scenario.\"\n",
" )\n",
"\n",
" if render and save_render:\n",
" if render:\n",
" from moviepy.editor import ImageSequenceClip\n",
" fps=30\n",
" clip = ImageSequenceClip(frame_list, fps=fps)\n",
Expand All @@ -177,11 +178,11 @@
"use_vmas_env(\n",
" scenario=scenario_name,\n",
" render=True,\n",
" save_render=True,\n",
" num_envs=32,\n",
" n_steps=150,\n",
" n_steps=100,\n",
" device=\"cuda\",\n",
" continuous_actions=True,\n",
" continuous_actions=False,\n",
" random_action=False,\n",
" # Environment specific variables\n",
" n_agents=4,\n",
")"
Expand Down
2 changes: 1 addition & 1 deletion vmas/examples/use_vmas_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def _get_deterministic_action(agent: Agent, continuous: bool, env):
.unsqueeze(-1)
.expand(env.batch_dim, 1)
)
return action
return action.clone()


def use_vmas_env(
Expand Down

0 comments on commit 99f562c

Please sign in to comment.