Skip to content

Commit

Permalink
Revert "Cach branch (#328)"
Browse files Browse the repository at this point in the history
This reverts commit 5eabb71.
  • Loading branch information
daphne-cornelisse committed Jan 25, 2025
1 parent 5eabb71 commit 61f2f19
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 67 deletions.
17 changes: 4 additions & 13 deletions pygpudrive/env/env_torch.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -81,11 +81,8 @@ def __init__(
self.num_agents = self.cont_agent_mask.sum().item()
self.single_action_space = self.action_space
# self.action_space = pufferlib.spaces.joint_space(self.single_action_space, self.num_agents)
# self.observation_space = pufferlib.spaces.joint_space(
# self.single_observation_space, self.num_agents
# )
self.observation_space = Box(
low=-np.inf, high=np.inf, shape=(self.get_obs().shape[-1],)
self.observation_space = pufferlib.spaces.joint_space(
self.single_observation_space, self.num_agents
)

self.info_dim = 5 # Number of info features
Expand Down Expand Up @@ -183,7 +180,6 @@ def get_rewards(
self.num_worlds,
self.max_agent_count,
backend=self.backend,
device=self.device,
)

# Index log positions at current time steps
Expand All @@ -197,7 +193,6 @@ def get_rewards(
agent_state = GlobalEgoState.from_tensor(
self.sim.absolute_self_observation_tensor(),
self.backend,
device=self.device,
)

agent_pos = torch.stack(
Expand Down Expand Up @@ -369,7 +364,6 @@ def _get_ego_state(self) -> torch.Tensor:
ego_state = LocalEgoState.from_tensor(
self_obs_tensor=self.sim.self_observation_tensor(),
backend=self.backend,
device=self.device
)
if self.config.norm_obs:
ego_state.normalize()
Expand Down Expand Up @@ -397,7 +391,6 @@ def _get_partner_obs(self):
partner_obs = PartnerObs.from_tensor(
partner_obs_tensor=self.sim.partner_observations_tensor(),
backend=self.backend,
device=self.device,
)

if self.config.norm_obs:
Expand Down Expand Up @@ -430,7 +423,6 @@ def _get_road_map_obs(self):
roadgraph = LocalRoadGraphPoints.from_tensor(
local_roadgraph_tensor=self.sim.agent_roadmap_tensor(),
backend=self.backend,
device=self.device,
)

if self.config.norm_obs:
Expand Down Expand Up @@ -463,7 +455,6 @@ def _get_lidar_obs(self):
lidar = LidarObs.from_tensor(
lidar_tensor=self.sim.lidar_tensor(),
backend=self.backend,
device=self.device,
)

return (
Expand Down Expand Up @@ -669,7 +660,7 @@ def get_expert_actions(self):

# Create data loader
train_loader = SceneDataLoader(
root="data/processed/examples",
root="data/processed/training",
batch_size=data_config.batch_size,
dataset_size=data_config.dataset_size,
sample_with_replacement=True,
Expand Down Expand Up @@ -758,4 +749,4 @@ def get_expert_actions(self):
)
media.write_video(
"obs_video.gif", np.array(agent_obs_frames), fps=10, codec="gif"
)
)
102 changes: 51 additions & 51 deletions pygpudrive/visualize/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
REL_OBS_OBJ_COLORS,
AGENT_COLOR_BY_STATE,
)

OUT_OF_BOUNDS = 1000


Expand All @@ -44,9 +45,7 @@ def __init__(
self.goal_radius = goal_radius
self.num_worlds = num_worlds
self.render_config = render_config
self.figsize = (10, 10)
self.marker_scale = max(self.figsize) / 15
self.line_width_scale = max(self.figsize) / 25
self.figsize = (15, 15)
self.env_config = env_config
self.initialize_static_scenario_data(controlled_agent_mask)

Expand Down Expand Up @@ -75,21 +74,18 @@ def initialize_static_scenario_data(self, controlled_agent_mask):
)

# Cache pre-rendered road graphs for all environments
self.cached_roadgraphs = []
for env_idx in range(self.controlled_agent_mask.shape[0]):
fig, ax = plt.subplots(figsize=self.figsize,dpi=300)
self._plot_roadgraph(
road_graph=self.global_roadgraph,
env_idx=env_idx,
ax=ax,
line_width_scale=self.line_width_scale,
marker_size_scale=self.marker_scale,
)
self.cached_roadgraphs.append(fig)
plt.close(fig)

self.plot_limits = ax.viewLim.get_points() #[x_min,y_min,xmax,ymax] (bottom left corner -> top right)

# self.cached_roadgraphs = []
# for env_idx in range(self.controlled_agent_mask.shape[0]):
# fig, ax = plt.subplots(figsize=self.figsize)
# self._plot_roadgraph(
# road_graph=self.global_roadgraph,
# env_idx=env_idx,
# ax=ax,
# line_width_scale=1.0,
# marker_size_scale=1.0,
# )
# self.cached_roadgraphs.append(fig)
# plt.close(fig)

def plot_simulator_state(
self,
Expand Down Expand Up @@ -136,34 +132,34 @@ def plot_simulator_state(

figs = []

# Calculate scale factors based on figure size
marker_scale = max(self.figsize) / 15
line_width_scale = max(self.figsize) / 15

# Iterate over each environment index
for idx, (env_idx, time_step, center_agent_idx) in enumerate(
zip(env_indices, time_steps, center_agent_indices)
):

# Initialize figure and axes from cached road graph
fig, ax = plt.subplots(figsize=self.figsize)
ax.clear()
ax.clear() # Clear any existing content
ax.set_aspect("equal", adjustable="box")
figs.append(fig)
plt.close(fig)


cached_roadgraph_array = utils.bg_img_from_fig(self.cached_roadgraphs[env_idx])
ax.imshow(
cached_roadgraph_array,
origin="upper",
extent=(self.plot_limits[0][0],self.plot_limits[1][0], # X: [x_min,x_max]
self.plot_limits[0][1],self.plot_limits[1][1]),# Y: [y_min, y_max]
zorder=0,
)

# # # Explicitly set the axis limits to match your coordinates
ax.set_xlim(self.plot_limits[0][0], self.plot_limits[1][0] )
ax.set_ylim(self.plot_limits[0][1], self.plot_limits[1][1])


# Remove axes
# cached_ax.axis('off')
figs.append(fig) # Add the new figure
plt.close(fig) # Close the figure to prevent carryover

# Render the pre-cached road graph for the current environment
# cached_roadgraph_array = utils.bg_img_from_fig(self.cached_roadgraphs[env_idx])
# ax.imshow(
# cached_roadgraph_array,
# origin="upper",
# extent=(-100, 100, -100, 100), # Stretch to full plot
# zorder=0, # Draw as background
# )

# Explicitly set the axis limits to match your coordinates
# cached_ax.set_xlim(-100, 100)
# cached_ax.set_ylim(-100, 100)

# Get control mask and omit out-of-bound agents (dead agents)
controlled = self.controlled_agent_mask[env_idx, :]
Expand All @@ -177,14 +173,22 @@ def plot_simulator_state(
) & controlled_live
is_ok = ~is_offroad & ~is_collided & controlled_live

# Draw the road graph
self._plot_roadgraph(
road_graph=self.global_roadgraph,
env_idx=env_idx,
ax=ax,
line_width_scale=line_width_scale,
marker_size_scale=marker_scale,
)

if plot_log_replay_trajectory:
self._plot_log_replay_trajectory(
ax=ax,
control_mask=controlled_live,
env_idx=env_idx,
log_trajectory=self.log_trajectory,
line_width_scale=self.line_width_scale,
line_width_scale=line_width_scale,
)

# Draw the agents
Expand All @@ -197,8 +201,8 @@ def plot_simulator_state(
is_collided_mask=is_collided,
response_type=self.response_type,
alpha=1.0,
line_width_scale=self.line_width_scale,
marker_size_scale=self.marker_scale,
line_width_scale=line_width_scale,
marker_size_scale=marker_scale,
)


Expand Down Expand Up @@ -251,7 +255,7 @@ def plot_simulator_state(
horizontalalignment="center",
verticalalignment="center",
transform=ax.transAxes,
fontsize=20 * self.marker_scale,
fontsize=20 * marker_scale,
color="black",
bbox=dict(facecolor="white", edgecolor="none", alpha=0.9),
)
Expand Down Expand Up @@ -383,10 +387,9 @@ def _plot_roadgraph(
gpudrive.EntityType.RoadEdge
):
line_width = 1.1 * line_width_scale

else:
line_width = 0.80* line_width_scale
line_width = 0.75 * line_width_scale

ax.plot(
[start[0], end[0]],
[start[1], end[1]],
Expand All @@ -402,7 +405,6 @@ def _plot_roadgraph(
segment_widths,
segment_orientations,
ax,
line_width,
)

elif road_point_type == int(gpudrive.EntityType.StopSign):
Expand All @@ -414,8 +416,8 @@ def _plot_roadgraph(
radius=1.5,
facecolor="#c04000",
edgecolor="none",
linewidth=3.5,
alpha=0.8,
linewidth=3.0,
alpha=0.9,
)
elif road_point_type == int(gpudrive.EntityType.CrossWalk):
for x, y, length, width, orientation in zip(
Expand All @@ -434,7 +436,6 @@ def _plot_roadgraph(
facecolor="none",
edgecolor="xkcd:bluish grey",
alpha=0.4,
linewidth=1.5*self.line_width_scale
)

else:
Expand All @@ -446,7 +447,6 @@ def _plot_roadgraph(
label=road_point_type,
color=ROAD_GRAPH_COLORS[int(road_point_type)],
)


def _plot_filtered_agent_bounding_boxes(
self,
Expand Down
4 changes: 1 addition & 3 deletions pygpudrive/visualize/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,6 @@ def plot_speed_bumps(
facecolor: str = None,
edgecolor: str = None,
alpha: float = None,
linewidth: float = 0
) -> None:
facecolor = "xkcd:goldenrod"
edgecolor = "xkcd:black"
Expand Down Expand Up @@ -414,7 +413,6 @@ def plot_crosswalk(
facecolor: str = None,
edgecolor: str = None,
alpha: float = None,
linewidth: float=None,
):
if ax is None:
ax = plt.gca()
Expand All @@ -431,7 +429,7 @@ def plot_crosswalk(
points,
facecolor=facecolor,
edgecolor=edgecolor,
linewidth=linewidth,
linewidth=2,
alpha=alpha,
hatch=r"//",
zorder=2,
Expand Down

0 comments on commit 61f2f19

Please sign in to comment.