diff --git a/vmas/scenarios/balance.py b/vmas/scenarios/balance.py index 036c3a94..cc4fbe43 100644 --- a/vmas/scenarios/balance.py +++ b/vmas/scenarios/balance.py @@ -26,6 +26,8 @@ def make_world(self, batch_dim: int, device: torch.device, **kwargs): self.shaping_factor = 100 self.fall_reward = -10 + self.visualize_semidims = False + # Make world world = World(batch_dim, device, gravity=(0.0, -0.05), y_semidim=1) # Add agents diff --git a/vmas/scenarios/ball_passage.py b/vmas/scenarios/ball_passage.py index 6cde72c9..61cf3c53 100644 --- a/vmas/scenarios/ball_passage.py +++ b/vmas/scenarios/ball_passage.py @@ -33,6 +33,8 @@ def make_world(self, batch_dim: int, device: torch.device, **kwargs): self.passage_width = 0.2 self.passage_length = 0.103 + self.visualize_semidims = False + # Make world world = World( batch_dim, diff --git a/vmas/scenarios/football.py b/vmas/scenarios/football.py index 3436a011..cbf39a99 100644 --- a/vmas/scenarios/football.py +++ b/vmas/scenarios/football.py @@ -17,6 +17,7 @@ class Scenario(BaseScenario): def make_world(self, batch_dim: int, device: torch.device, **kwargs): self.init_params(**kwargs) + self.visualize_semidims = False world = self.init_world(batch_dim, device) self.init_agents(world) self.init_ball(world) diff --git a/vmas/scenarios/joint_passage.py b/vmas/scenarios/joint_passage.py index 382efa02..cea4268a 100644 --- a/vmas/scenarios/joint_passage.py +++ b/vmas/scenarios/joint_passage.py @@ -65,6 +65,8 @@ def make_world(self, batch_dim: int, device: torch.device, **kwargs): ScenarioUtils.check_kwargs_consumed(kwargs) self.plot_grid = True + self.visualize_semidims = False + # Make world world = World( batch_dim, diff --git a/vmas/scenarios/joint_passage_size.py b/vmas/scenarios/joint_passage_size.py index 0f3b22b1..befaf8b3 100644 --- a/vmas/scenarios/joint_passage_size.py +++ b/vmas/scenarios/joint_passage_size.py @@ -73,6 +73,7 @@ def make_world(self, batch_dim: int, device: torch.device, **kwargs): assert self.n_passages == 3 or self.n_passages == 4 self.plot_grid = False + self.visualize_semidims = False # Make world world = World( diff --git a/vmas/scenarios/mpe/simple_tag.py b/vmas/scenarios/mpe/simple_tag.py index e4a88440..57eb14a2 100644 --- a/vmas/scenarios/mpe/simple_tag.py +++ b/vmas/scenarios/mpe/simple_tag.py @@ -26,6 +26,8 @@ def make_world(self, batch_dim: int, device: torch.device, **kwargs): self.respawn_at_catch = kwargs.pop("respawn_at_catch", False) ScenarioUtils.check_kwargs_consumed(kwargs) + self.visualize_semidims = False + world = World( batch_dim=batch_dim, device=device, diff --git a/vmas/scenarios/passage.py b/vmas/scenarios/passage.py index 1be86b65..9633eaf6 100644 --- a/vmas/scenarios/passage.py +++ b/vmas/scenarios/passage.py @@ -26,6 +26,8 @@ def make_world(self, batch_dim: int, device: torch.device, **kwargs): self.passage_width = 0.2 self.passage_length = 0.103 + self.visualize_semidims = False + # Make world world = World(batch_dim, device, x_semidim=1, y_semidim=1) # Add agents diff --git a/vmas/scenarios/road_traffic.py b/vmas/scenarios/road_traffic.py index cb608f84..556dfc85 100644 --- a/vmas/scenarios/road_traffic.py +++ b/vmas/scenarios/road_traffic.py @@ -57,6 +57,7 @@ class Scenario(BaseScenario): def make_world(self, batch_dim: int, device: torch.device, **kwargs): self.init_params(batch_dim, device, **kwargs) + self.visualize_semidims = False world = self.init_world(batch_dim, device) self.init_agents(world) return world diff --git a/vmas/scenarios/sampling.py b/vmas/scenarios/sampling.py index 79887317..fea1bcb3 100644 --- a/vmas/scenarios/sampling.py +++ b/vmas/scenarios/sampling.py @@ -43,6 +43,7 @@ def make_world(self, batch_dim: int, device: torch.device, **kwargs): assert len(self.covs) == self.n_gaussians self.plot_grid = False + self.visualize_semidims = False self.n_x_cells = int((2 * self.xdim) / self.grid_spacing) self.n_y_cells = int((2 * self.ydim) / self.grid_spacing) self.max_pdf = torch.zeros((batch_dim,), device=device, dtype=torch.float32) diff --git a/vmas/simulator/environment/environment.py b/vmas/simulator/environment/environment.py index e96b1be4..9779a150 100644 --- a/vmas/simulator/environment/environment.py +++ b/vmas/simulator/environment/environment.py @@ -741,6 +741,9 @@ def render( ) # Render + if self.scenario.visualize_semidims: + self.plot_boundary() + self._set_agent_comm_messages(env_index) if plot_position_function is not None: @@ -770,6 +773,64 @@ def render( # render to display or array return self.viewer.render(return_rgb_array=mode == "rgb_array") + def plot_boundary(self): + # include boundaries in the rendering if the environment is dimension-limited + if self.world.x_semidim is not None or self.world.y_semidim is not None: + from vmas.simulator.rendering import Line + from vmas.simulator.utils import Color + + # set a big value for the cases where the environment is dimension-limited only in one coordinate + infinite_value = 100 + + x_semi = ( + self.world.x_semidim + if self.world.x_semidim is not None + else infinite_value + ) + y_semi = ( + self.world.y_semidim + if self.world.y_semidim is not None + else infinite_value + ) + + # set the color for the boundary line + color = Color.GRAY.value + + # Define boundary points based on whether world semidims are provided + if ( + self.world.x_semidim is not None and self.world.y_semidim is not None + ) or self.world.y_semidim is not None: + boundary_points = [ + (-x_semi, y_semi), + (x_semi, y_semi), + (x_semi, -y_semi), + (-x_semi, -y_semi), + ] + else: + boundary_points = [ + (-x_semi, y_semi), + (-x_semi, -y_semi), + (x_semi, y_semi), + (x_semi, -y_semi), + ] + + # Create lines by connecting points + for i in range( + 0, + len(boundary_points), + 1 + if ( + self.world.x_semidim is not None + and self.world.y_semidim is not None + ) + else 2, + ): + start = boundary_points[i] + end = boundary_points[(i + 1) % len(boundary_points)] + line = Line(start, end, width=0.7) + line.set_color(*color) + self.viewer.add_onetime(line) + def plot_function( self, f, precision, plot_range, cmap_range, cmap_alpha, cmap_name ): diff --git a/vmas/simulator/scenario.py b/vmas/simulator/scenario.py index 667790fe..bae6a6f4 100644 --- a/vmas/simulator/scenario.py +++ b/vmas/simulator/scenario.py @@ -56,6 +56,8 @@ def __init__(self): """Whether to plot a grid in the scenario rendering background. This can be changed in the :class:`~make_world` function. """ self.grid_spacing = 0.1 """If :class:`~plot_grid`, the distance between lines in the background grid. This can be changed in the :class:`~make_world` function. """ + self.visualize_semidims = True + """Whether to display boundaries in dimension-limited environment. This can be changed in the :class:`~make_world` function. """ @property def world(self):