Skip to content

Commit

Permalink
Rays observation visualization (#22)
Browse files Browse the repository at this point in the history
* Rays observation visualization

* fix CI
  • Loading branch information
nicolaloi authored Jan 31, 2025
1 parent 889401a commit d4b5b52
Show file tree
Hide file tree
Showing 2 changed files with 108 additions and 26 deletions.
133 changes: 107 additions & 26 deletions waymax/visualization/gokart_viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from functools import partial
from typing import Any, Iterable, List, Optional
import typing

import jax
import matplotlib
Expand All @@ -25,45 +26,51 @@

from waymax import config as waymax_config
from waymax import datatypes
from waymax.datatypes import operations
from waymax.utils import geometry
from waymax.visualization import color
from waymax.visualization import utils
from waymax.visualization import viz

import matplotlib.pyplot as plt

GokartObs = typing.TypeVar("GokartObs")


def create_video_simulator_state(
state: datatypes.SimulatorState,
obs: GokartObs | None = None,
video_path: str | None = None,
use_log_traj: bool = True,
n_steps: int = 10,
interval: int = 100,
batch_idx: int = -1,
highlight_obj: waymax_config.ObjectType = waymax_config.ObjectType.SDC,
ref: bool = False,
rays_length: np.ndarray | None = None,
viz_config: Optional[dict[str, Any]] = None,
) -> List[np.ndarray]:
"""
Make an animation of the simulator state. Return the list of numpy matrices representing the video frames if
video_path is None. Otherwise, the video is saved to disk at video_path path and an empty list is returned.
Retuning the list of numpy matrices is faster (2-3x) than saving the video to disk.
"""

video_plotter = VideoPlotter(
video_path, viz_config, skip_traffic_light=True, plot_last_history_only=True, faster_axis_origin=True
)

video_plotter.create_fig()

imgs = video_plotter.plot_sequence_simulator_state(
state, use_log_traj, n_steps, interval, batch_idx, highlight_obj, ref, rays_length
state,
obs,
use_log_traj,
n_steps,
interval,
batch_idx,
highlight_obj,
ref,
)

video_plotter.close_fig()

return imgs


Expand All @@ -85,13 +92,17 @@ def __init__(
self.plot_last_history_only = plot_last_history_only
self.faster_axis_origin = faster_axis_origin

self.fig, self.ax = None, None
# Initialized like this to remove obnoxious pylance errors when using self.ax.plot
fig_ax = self.create_fig()
self.fig: matplotlib.figure.Figure = fig_ax[0]
self.ax: matplotlib.axes.Axes = fig_ax[1]

self.trajectory_lines = None
self.history_lines = None
self.context_lines = None
self.overlap_lines = None
self.reference_lines = None
self.rays_lines = None
self.text = None
self.roadgraph_lines_dict = None

Expand All @@ -103,10 +114,11 @@ def __init__(
def __del__(self):
self.close_fig()

def create_fig(self):
self.fig, self.ax = utils.init_fig_ax(self.viz_config)
def create_fig(self) -> tuple[matplotlib.figure.Figure, matplotlib.axes.Axes]:
fig, ax = utils.init_fig_ax(self.viz_config)
# Just enough margin in the figure to display xticks and yticks.
self.fig.subplots_adjust(left=0.08, bottom=0.08, right=0.98, top=0.98, wspace=0.0, hspace=0.0)
fig.subplots_adjust(left=0.08, bottom=0.08, right=0.98, top=0.98, wspace=0.0, hspace=0.0)
return fig, ax

def close_fig(self):
if self.fig is not None:
Expand All @@ -116,57 +128,64 @@ def close_fig(self):
def plot_sequence_simulator_state(
self,
state: datatypes.SimulatorState,
obs: GokartObs | None,
use_log_traj: bool = True,
n_steps: int = 10,
interval: int = 100,
batch_idx: int = -1,
highlight_obj: waymax_config.ObjectType = waymax_config.ObjectType.SDC,
ref: bool = False,
rays_length: np.ndarray | None = None,
) -> List[np.ndarray]:
"""
Make an animation of the simulator state. Return the list of numpy matrices representing the video frames if
video_path is None. Otherwise, the video is saved to disk at video_path path and an empty list is returned.
Args:
state: A SimulatorState instance.
obs: Optional observation instance.
use_log_traj: Set True to use logged trajectory, o/w uses simulated
trajectory.
viz_config: dict for optional config.
batch_idx: optional batch index.
highlight_obj: Represents the type of objects that will be highlighted with
`color.COLOR_DICT['controlled']` color.
ref: Set True to plot reference trajectory.
rays_length: The length of the rays to plot.
Returns:
list of np images if video_path is None, otherwise an empty list since the video will be saved to disk.
"""
if self.fig is None or self.ax is None:
self.create_fig(self.viz_config)
imgs = []

if batch_idx > -1:
if len(state.shape) != 1:
raise ValueError(f"Expecting one batch dimension, got {len(state.shape)}")
state = viz._index_pytree(state, batch_idx)
obs = viz._index_pytree(obs, batch_idx)

if self.video_path is not None:

def animate_step(
i: int,
state: datatypes.GoKartSimState,
obs: GokartObs | None,
use_log_traj: bool,
highlight_obj: waymax_config.ObjectType,
ref: bool,
rays_length: np.ndarray | None,
) -> Iterable[matplotlib.lines.Line2D]:
state = state.replace(timestep=i)
obs_t = operations.dynamic_slice(obs, i, 1, axis=0) if obs is not None else None

self.plot_simulator_state(state, use_log_traj, highlight_obj, ref, rays_length)
self.plot_simulator_state(state, use_log_traj, highlight_obj, ref, obs)

artists = []
for line in [self.trajectory_lines, self.context_lines, self.overlap_lines, self.reference_lines]:
for line in [
self.trajectory_lines,
self.context_lines,
self.overlap_lines,
self.reference_lines,
self.rays_lines,
]:
if line is not None:
artists.extend(line)
if not self.plot_last_history_only and self.history_lines is not None:
Expand All @@ -179,10 +198,10 @@ def animate_step(
partial_animate_step = partial(
animate_step,
state=state,
obs=obs_t,
use_log_traj=use_log_traj,
highlight_obj=highlight_obj,
ref=ref,
rays_length=rays_length,
)
ani = FuncAnimation(
self.fig, partial_animate_step, frames=n_steps, repeat=False, interval=interval, blit=True
Expand All @@ -198,7 +217,8 @@ def animate_step(
self.ax_background = self.fig.canvas.copy_from_bbox(self.ax.bbox)
for i in range(n_steps):
state = state.replace(timestep=i)
self.plot_simulator_state(state, use_log_traj, highlight_obj, ref, rays_length)
obs_t = operations.dynamic_slice(obs, i, 1, axis=0) if obs is not None else None
self.plot_simulator_state(state, obs_t, use_log_traj, highlight_obj, ref)
img = self.img_from_fig(close_fig=False, clear_fig=False, blit=blit)
imgs.append(img)

Expand All @@ -209,15 +229,16 @@ def animate_step(
def plot_simulator_state(
self,
state: datatypes.SimulatorState,
obs: GokartObs | None = None,
use_log_traj: bool = True,
highlight_obj: waymax_config.ObjectType = waymax_config.ObjectType.SDC,
ref: bool = False,
rays_length: np.ndarray | None = None,
) -> None:
"""Plots np array image for SimulatorState.
Args:
state: A SimulatorState instance.
obs: Optional observation instance.
use_log_traj: Set True to use logged trajectory, o/w uses simulated
trajectory.
viz_config: dict for optional config.
Expand Down Expand Up @@ -260,12 +281,8 @@ def plot_simulator_state(
traj_5dof[valid_controlled][::5, 1],
)

if rays_length is not None:
position = traj.xy[0, state.timestep, :]
yaw = traj.yaw[0, state.timestep]
rays_length = rays_length[batch_idx, state.timestep, :]
utils.plot_numpy_rays(self.ax, position, yaw, color=np.array([1.0, 0.65, 0.0]), rays_length=rays_length)
pass
if obs is not None:
self.plot_obs(traj=traj, timestep=state.timestep, obs=obs)

# 2. Plots road graph elements.
# assume roadgraph points do not change over time. Plot only once.
Expand Down Expand Up @@ -298,6 +315,26 @@ def plot_simulator_state(
)
)

def plot_obs(
self,
traj: datatypes.Trajectory,
timestep: int,
obs: GokartObs,
) -> None:

# plot rays
position = traj.xy[0, timestep, :]
yaw = traj.yaw[0, timestep]
# rays_length[0, timestep, :]
self.plot_numpy_rays(
position,
yaw,
color=np.array([1.0, 0.65, 0.0]),
rays_length=obs.distances.squeeze(),
rays_angles_base=obs._angles_rays,
alpha=0.9,
)

def plot_trajectory(
self,
traj: datatypes.Trajectory,
Expand Down Expand Up @@ -534,6 +571,49 @@ def plot_numpy_bounding_boxes(

setattr(self, line_name, lines)

def plot_numpy_rays(
self,
position: np.ndarray,
yaw: np.ndarray,
color: np.ndarray,
rays_length: np.ndarray,
rays_angles_base: np.ndarray,
alpha: Optional[float] = 1.0,
) -> None:
"""
Plots rays originating from a given position and orientation.
Args:
position: Array of shape (2,), representing the start position (x, y) of the rays.
yaw: Array of shape (1,), representing the orientation angle of the source in radians.
color: Array of shape (3,), representing the RGB color for the rays.
rays_length: Array of shape (num_rays,), representing the length of each ray.
rays_angles_base: Array of shape (num_rays,), representing the angles of the rays
in radians, relative to the sensor base.
num_rays: Number of rays to cast in the range of [-pi/2, pi/2] relative to the orientation.
alpha: Alpha value for drawing, where 0 is fully transparent.
"""

# Calculate angles for each ray relative to the orientation of the source.
angles = rays_angles_base + yaw
rays = position[:, None] + rays_length * np.array([np.cos(angles), np.sin(angles)])
plot_xy = np.empty((2, rays.shape[1] * 2))
plot_xy[:, ::2] = position[:, None]
plot_xy[:, 1::2] = rays

if self.rays_lines is not None:
self.rays_lines[0].set_data(plot_xy[0], plot_xy[1])
else:
self.rays_lines = self.ax.plot(
plot_xy[0],
plot_xy[1],
":",
color=color,
alpha=alpha,
zorder=4,
linewidth=1.0,
)

def plot_roadgraph_points(
self,
rg_pts: datatypes.RoadgraphPoints,
Expand Down Expand Up @@ -580,6 +660,7 @@ def img_from_fig(self, close_fig: bool = True, clear_fig: bool = False, blit: bo
self.context_lines,
self.overlap_lines,
self.reference_lines,
self.rays_lines,
]:
if lines is not None:
for line in lines:
Expand Down
1 change: 1 addition & 0 deletions waymax/visualization/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ class VizConfig:
front_y: float = 75.0
back_y: float = 75.0
px_per_meter: float = 4.0
viz_obs: bool = False
show_agent_id: bool = True
center_agent_idx: int = -1 # -1 for SDC
verbose: bool = True
Expand Down

0 comments on commit d4b5b52

Please sign in to comment.