Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Visualizer improvements
Browse files Browse the repository at this point in the history
daphne-cornelisse committed Jan 3, 2025
1 parent 813988c commit cde4286
Showing 2 changed files with 51 additions and 26 deletions.
11 changes: 11 additions & 0 deletions pygpudrive/env/env_torch.py
Original file line number Diff line number Diff line change
@@ -508,6 +508,17 @@ def swap_data_batch(self, data_batch=None):
self.cont_agent_mask.sum().item()
)

# Reinitialize the visualizer with the new data
# TODO: Improve
self.vis = MatplotlibVisualizer(
sim_object=self.sim,
goal_radius=self.config.dist_to_goal_threshold,
backend=self.backend,
num_worlds=self.num_worlds,
render_config=self.render_config,
env_config=self.config,
)

def get_expert_actions(self):
"""Get expert actions for the full trajectories across worlds.
66 changes: 40 additions & 26 deletions pygpudrive/visualize/core.py
Original file line number Diff line number Diff line change
@@ -47,6 +47,16 @@ def __init__(
self.num_worlds = num_worlds
self.render_config = render_config
self.env_config = env_config
self.response_type = ResponseType.from_tensor(
tensor=self.sim_object.response_type_tensor(),
backend=self.backend,
device=self.device,
)
self.global_roadgraph = GlobalRoadGraphPoints.from_tensor(
roadgraph_tensor=self.sim_object.map_observation_tensor(),
backend=self.backend,
device=self.device,
)

def get_controlled_agents_mask(self):
"""Get the control mask."""
@@ -88,22 +98,11 @@ def plot_simulator_state(
env_indices
) # Default to None for all

# Extract data for all environments
global_roadgraph = GlobalRoadGraphPoints.from_tensor(
roadgraph_tensor=self.sim_object.map_observation_tensor(),
backend=self.backend,
device=self.device,
)
global_agent_states = GlobalEgoState.from_tensor(
self.sim_object.absolute_self_observation_tensor(),
backend=self.backend,
device=self.device,
)
response_type = ResponseType.from_tensor(
tensor=self.sim_object.response_type_tensor(),
backend=self.backend,
device=self.device,
)

agent_infos = self.sim_object.info_tensor().to_torch().to(self.device)

@@ -131,6 +130,9 @@ def plot_simulator_state(
squeeze=False,
)
axes = axes.flatten()

for idx in range(len(axes)):
axes[idx].clear() # Clear each subplot
else:
axes = [None] * len(env_indices)

@@ -149,9 +151,10 @@ def plot_simulator_state(
ax.set_aspect("equal", adjustable="box")
else:
fig, ax = plt.subplots(figsize=figsize)
ax.clear() # Clear any existing content
ax.set_aspect("equal", adjustable="box")
ax.clear()
figs.append(fig)
figs.append(fig) # Add the new figure
plt.close(fig) # Close the figure to prevent carryover

# Get control mask and omit out-of-bound agents (dead agents)
controlled = self.controlled_agents[env_idx, :]
@@ -167,7 +170,7 @@ def plot_simulator_state(

# Draw the road graph
self._plot_roadgraph(
road_graph=global_roadgraph,
road_graph=self.global_roadgraph,
env_idx=env_idx,
ax=ax,
line_width_scale=line_width_scale,
@@ -191,7 +194,7 @@ def plot_simulator_state(
is_ok_mask=is_ok,
is_offroad_mask=is_offroad,
is_collided_mask=is_collided,
response_type=response_type,
response_type=self.response_type,
alpha=1.0,
line_width_scale=line_width_scale,
marker_size_scale=marker_scale,
@@ -237,8 +240,10 @@ def plot_simulator_state(

if return_single_figure:
for ax in axes[len(env_indices) :]:
ax.clear()
ax.axis("off") # Hide unused subplots
plt.tight_layout()
plt.close(fig) # Close the figure to prevent carryover
return fig
else:
return figs
@@ -581,19 +586,30 @@ def plot_agent_observation(
if observation_roadgraph is not None:
for road_type, type_name in ROAD_GRAPH_TYPE_NAMES.items():
mask = (
observation_roadgraph.type[env_idx, agent_idx, :] == road_type
observation_roadgraph.type[env_idx, agent_idx, :]
== road_type
)

# Extract relevant roadgraph data for plotting
x_points = observation_roadgraph.x[env_idx, agent_idx, mask]
y_points = observation_roadgraph.y[env_idx, agent_idx, mask]
orientations = observation_roadgraph.orientation[env_idx, agent_idx, mask]
segment_lengths = observation_roadgraph.segment_length[env_idx, agent_idx, mask]
widths = observation_roadgraph.segment_width[env_idx, agent_idx, mask]
orientations = observation_roadgraph.orientation[
env_idx, agent_idx, mask
]
segment_lengths = observation_roadgraph.segment_length[
env_idx, agent_idx, mask
]
widths = observation_roadgraph.segment_width[
env_idx, agent_idx, mask
]

# Scatter plot for the points
ax.scatter(
x_points, y_points, c=[ROAD_GRAPH_COLORS[road_type]], s=8, label=type_name
x_points,
y_points,
c=[ROAD_GRAPH_COLORS[road_type]],
s=8,
label=type_name,
)

# Plot lines for road edges
@@ -619,32 +635,30 @@ def plot_agent_observation(
[y_start - width_dy, y_end - width_dy],
color=ROAD_GRAPH_COLORS[road_type],
alpha=0.5,
linewidth=1.0
linewidth=1.0,
)
ax.plot(
[x_start + width_dx, x_end + width_dx],
[y_start + width_dy, y_end + width_dy],
color=ROAD_GRAPH_COLORS[road_type],
alpha=0.5,
linewidth=1.0
linewidth=1.0,
)
ax.plot(
[x_start - width_dx, x_start + width_dx],
[y_start - width_dy, y_start + width_dy],
color=ROAD_GRAPH_COLORS[road_type],
alpha=0.5,
linewidth=1.0
linewidth=1.0,
)
ax.plot(
[x_end - width_dx, x_end + width_dx],
[y_end - width_dy, y_end + width_dy],
color=ROAD_GRAPH_COLORS[road_type],
alpha=0.5,
linewidth=1.0
linewidth=1.0,
)



# Plot partner agents if provided
if observation_partner is not None:
partner_positions = torch.stack(

0 comments on commit cde4286

Please sign in to comment.