From cde42869a8cbca47dce539fef2863685a05d64c0 Mon Sep 17 00:00:00 2001 From: Daphne Cornelisse Date: Fri, 3 Jan 2025 18:24:47 -0500 Subject: [PATCH] Visualizer improvements --- pygpudrive/env/env_torch.py | 11 ++++++ pygpudrive/visualize/core.py | 66 ++++++++++++++++++++++-------------- 2 files changed, 51 insertions(+), 26 deletions(-) diff --git a/pygpudrive/env/env_torch.py b/pygpudrive/env/env_torch.py index 02fb895c7..60971ea62 100755 --- a/pygpudrive/env/env_torch.py +++ b/pygpudrive/env/env_torch.py @@ -508,6 +508,17 @@ def swap_data_batch(self, data_batch=None): self.cont_agent_mask.sum().item() ) + # Reinitialize the visualizer with the new data + # TODO: Improve + self.vis = MatplotlibVisualizer( + sim_object=self.sim, + goal_radius=self.config.dist_to_goal_threshold, + backend=self.backend, + num_worlds=self.num_worlds, + render_config=self.render_config, + env_config=self.config, + ) + def get_expert_actions(self): """Get expert actions for the full trajectories across worlds. diff --git a/pygpudrive/visualize/core.py b/pygpudrive/visualize/core.py index 32ea162f5..cdba64a6e 100644 --- a/pygpudrive/visualize/core.py +++ b/pygpudrive/visualize/core.py @@ -47,6 +47,16 @@ def __init__( self.num_worlds = num_worlds self.render_config = render_config self.env_config = env_config + self.response_type = ResponseType.from_tensor( + tensor=self.sim_object.response_type_tensor(), + backend=self.backend, + device=self.device, + ) + self.global_roadgraph = GlobalRoadGraphPoints.from_tensor( + roadgraph_tensor=self.sim_object.map_observation_tensor(), + backend=self.backend, + device=self.device, + ) def get_controlled_agents_mask(self): """Get the control mask.""" @@ -88,22 +98,11 @@ def plot_simulator_state( env_indices ) # Default to None for all - # Extract data for all environments - global_roadgraph = GlobalRoadGraphPoints.from_tensor( - roadgraph_tensor=self.sim_object.map_observation_tensor(), - backend=self.backend, - device=self.device, - ) global_agent_states = GlobalEgoState.from_tensor( self.sim_object.absolute_self_observation_tensor(), backend=self.backend, device=self.device, ) - response_type = ResponseType.from_tensor( - tensor=self.sim_object.response_type_tensor(), - backend=self.backend, - device=self.device, - ) agent_infos = self.sim_object.info_tensor().to_torch().to(self.device) @@ -131,6 +130,9 @@ def plot_simulator_state( squeeze=False, ) axes = axes.flatten() + + for idx in range(len(axes)): + axes[idx].clear() # Clear each subplot else: axes = [None] * len(env_indices) @@ -149,9 +151,10 @@ def plot_simulator_state( ax.set_aspect("equal", adjustable="box") else: fig, ax = plt.subplots(figsize=figsize) + ax.clear() # Clear any existing content ax.set_aspect("equal", adjustable="box") - ax.clear() - figs.append(fig) + figs.append(fig) # Add the new figure + plt.close(fig) # Close the figure to prevent carryover # Get control mask and omit out-of-bound agents (dead agents) controlled = self.controlled_agents[env_idx, :] @@ -167,7 +170,7 @@ def plot_simulator_state( # Draw the road graph self._plot_roadgraph( - road_graph=global_roadgraph, + road_graph=self.global_roadgraph, env_idx=env_idx, ax=ax, line_width_scale=line_width_scale, @@ -191,7 +194,7 @@ def plot_simulator_state( is_ok_mask=is_ok, is_offroad_mask=is_offroad, is_collided_mask=is_collided, - response_type=response_type, + response_type=self.response_type, alpha=1.0, line_width_scale=line_width_scale, marker_size_scale=marker_scale, @@ -237,8 +240,10 @@ def plot_simulator_state( if return_single_figure: for ax in axes[len(env_indices) :]: + ax.clear() ax.axis("off") # Hide unused subplots plt.tight_layout() + plt.close(fig) # Close the figure to prevent carryover return fig else: return figs @@ -581,19 +586,30 @@ def plot_agent_observation( if observation_roadgraph is not None: for road_type, type_name in ROAD_GRAPH_TYPE_NAMES.items(): mask = ( - observation_roadgraph.type[env_idx, agent_idx, :] == road_type + observation_roadgraph.type[env_idx, agent_idx, :] + == road_type ) # Extract relevant roadgraph data for plotting x_points = observation_roadgraph.x[env_idx, agent_idx, mask] y_points = observation_roadgraph.y[env_idx, agent_idx, mask] - orientations = observation_roadgraph.orientation[env_idx, agent_idx, mask] - segment_lengths = observation_roadgraph.segment_length[env_idx, agent_idx, mask] - widths = observation_roadgraph.segment_width[env_idx, agent_idx, mask] + orientations = observation_roadgraph.orientation[ + env_idx, agent_idx, mask + ] + segment_lengths = observation_roadgraph.segment_length[ + env_idx, agent_idx, mask + ] + widths = observation_roadgraph.segment_width[ + env_idx, agent_idx, mask + ] # Scatter plot for the points ax.scatter( - x_points, y_points, c=[ROAD_GRAPH_COLORS[road_type]], s=8, label=type_name + x_points, + y_points, + c=[ROAD_GRAPH_COLORS[road_type]], + s=8, + label=type_name, ) # Plot lines for road edges @@ -619,32 +635,30 @@ def plot_agent_observation( [y_start - width_dy, y_end - width_dy], color=ROAD_GRAPH_COLORS[road_type], alpha=0.5, - linewidth=1.0 + linewidth=1.0, ) ax.plot( [x_start + width_dx, x_end + width_dx], [y_start + width_dy, y_end + width_dy], color=ROAD_GRAPH_COLORS[road_type], alpha=0.5, - linewidth=1.0 + linewidth=1.0, ) ax.plot( [x_start - width_dx, x_start + width_dx], [y_start - width_dy, y_start + width_dy], color=ROAD_GRAPH_COLORS[road_type], alpha=0.5, - linewidth=1.0 + linewidth=1.0, ) ax.plot( [x_end - width_dx, x_end + width_dx], [y_end - width_dy, y_end + width_dy], color=ROAD_GRAPH_COLORS[road_type], alpha=0.5, - linewidth=1.0 + linewidth=1.0, ) - - # Plot partner agents if provided if observation_partner is not None: partner_positions = torch.stack(