diff --git a/pygpudrive/env/env_torch.py b/pygpudrive/env/env_torch.py old mode 100755 new mode 100644 index 2d1e3f2c..12a35c96 --- a/pygpudrive/env/env_torch.py +++ b/pygpudrive/env/env_torch.py @@ -81,8 +81,11 @@ def __init__( self.num_agents = self.cont_agent_mask.sum().item() self.single_action_space = self.action_space # self.action_space = pufferlib.spaces.joint_space(self.single_action_space, self.num_agents) - self.observation_space = pufferlib.spaces.joint_space( - self.single_observation_space, self.num_agents + # self.observation_space = pufferlib.spaces.joint_space( + # self.single_observation_space, self.num_agents + # ) + self.observation_space = Box( + low=-np.inf, high=np.inf, shape=(self.get_obs().shape[-1],) ) self.info_dim = 5 # Number of info features @@ -180,6 +183,7 @@ def get_rewards( self.num_worlds, self.max_agent_count, backend=self.backend, + device=self.device, ) # Index log positions at current time steps @@ -193,6 +197,7 @@ def get_rewards( agent_state = GlobalEgoState.from_tensor( self.sim.absolute_self_observation_tensor(), self.backend, + device=self.device, ) agent_pos = torch.stack( @@ -364,6 +369,7 @@ def _get_ego_state(self) -> torch.Tensor: ego_state = LocalEgoState.from_tensor( self_obs_tensor=self.sim.self_observation_tensor(), backend=self.backend, + device=self.device ) if self.config.norm_obs: ego_state.normalize() @@ -391,6 +397,7 @@ def _get_partner_obs(self): partner_obs = PartnerObs.from_tensor( partner_obs_tensor=self.sim.partner_observations_tensor(), backend=self.backend, + device=self.device, ) if self.config.norm_obs: @@ -423,6 +430,7 @@ def _get_road_map_obs(self): roadgraph = LocalRoadGraphPoints.from_tensor( local_roadgraph_tensor=self.sim.agent_roadmap_tensor(), backend=self.backend, + device=self.device, ) if self.config.norm_obs: @@ -455,6 +463,7 @@ def _get_lidar_obs(self): lidar = LidarObs.from_tensor( lidar_tensor=self.sim.lidar_tensor(), backend=self.backend, + device=self.device, ) return ( @@ -613,7 +622,7 @@ def get_expert_actions(self): # Create data loader train_loader = SceneDataLoader( - root="data/processed/training", + root="data/processed/examples", batch_size=data_config.batch_size, dataset_size=data_config.dataset_size, sample_with_replacement=True, @@ -702,4 +711,4 @@ def get_expert_actions(self): ) media.write_video( "obs_video.gif", np.array(agent_obs_frames), fps=10, codec="gif" - ) + ) \ No newline at end of file diff --git a/pygpudrive/visualize/core.py b/pygpudrive/visualize/core.py index d2826aef..c6473a90 100644 --- a/pygpudrive/visualize/core.py +++ b/pygpudrive/visualize/core.py @@ -24,7 +24,6 @@ REL_OBS_OBJ_COLORS, AGENT_COLOR_BY_STATE, ) - OUT_OF_BOUNDS = 1000 @@ -45,7 +44,9 @@ def __init__( self.goal_radius = goal_radius self.num_worlds = num_worlds self.render_config = render_config - self.figsize = (15, 15) + self.figsize = (10, 10) + self.marker_scale = max(self.figsize) / 15 + self.line_width_scale = max(self.figsize) / 25 self.env_config = env_config self.initialize_static_scenario_data(controlled_agent_mask) @@ -74,18 +75,21 @@ def initialize_static_scenario_data(self, controlled_agent_mask): ) # Cache pre-rendered road graphs for all environments - # self.cached_roadgraphs = [] - # for env_idx in range(self.controlled_agent_mask.shape[0]): - # fig, ax = plt.subplots(figsize=self.figsize) - # self._plot_roadgraph( - # road_graph=self.global_roadgraph, - # env_idx=env_idx, - # ax=ax, - # line_width_scale=1.0, - # marker_size_scale=1.0, - # ) - # self.cached_roadgraphs.append(fig) - # plt.close(fig) + self.cached_roadgraphs = [] + for env_idx in range(self.controlled_agent_mask.shape[0]): + fig, ax = plt.subplots(figsize=self.figsize,dpi=300) + self._plot_roadgraph( + road_graph=self.global_roadgraph, + env_idx=env_idx, + ax=ax, + line_width_scale=self.line_width_scale, + marker_size_scale=self.marker_scale, + ) + self.cached_roadgraphs.append(fig) + plt.close(fig) + + self.plot_limits = ax.viewLim.get_points() #[x_min,y_min,xmax,ymax] (bottom left corner -> top right) + def plot_simulator_state( self, @@ -132,34 +136,34 @@ def plot_simulator_state( figs = [] - # Calculate scale factors based on figure size - marker_scale = max(self.figsize) / 15 - line_width_scale = max(self.figsize) / 15 - # Iterate over each environment index for idx, (env_idx, time_step, center_agent_idx) in enumerate( zip(env_indices, time_steps, center_agent_indices) ): - # Initialize figure and axes from cached road graph fig, ax = plt.subplots(figsize=self.figsize) - ax.clear() # Clear any existing content + ax.clear() ax.set_aspect("equal", adjustable="box") - figs.append(fig) # Add the new figure - plt.close(fig) # Close the figure to prevent carryover - - # Render the pre-cached road graph for the current environment - # cached_roadgraph_array = utils.bg_img_from_fig(self.cached_roadgraphs[env_idx]) - # ax.imshow( - # cached_roadgraph_array, - # origin="upper", - # extent=(-100, 100, -100, 100), # Stretch to full plot - # zorder=0, # Draw as background - # ) - - # Explicitly set the axis limits to match your coordinates - # cached_ax.set_xlim(-100, 100) - # cached_ax.set_ylim(-100, 100) + figs.append(fig) + plt.close(fig) + + + cached_roadgraph_array = utils.bg_img_from_fig(self.cached_roadgraphs[env_idx]) + ax.imshow( + cached_roadgraph_array, + origin="upper", + extent=(self.plot_limits[0][0],self.plot_limits[1][0], # X: [x_min,x_max] + self.plot_limits[0][1],self.plot_limits[1][1]),# Y: [y_min, y_max] + zorder=0, + ) + + # # # Explicitly set the axis limits to match your coordinates + ax.set_xlim(self.plot_limits[0][0], self.plot_limits[1][0] ) + ax.set_ylim(self.plot_limits[0][1], self.plot_limits[1][1]) + + + # Remove axes + # cached_ax.axis('off') # Get control mask and omit out-of-bound agents (dead agents) controlled = self.controlled_agent_mask[env_idx, :] @@ -173,14 +177,6 @@ def plot_simulator_state( ) & controlled_live is_ok = ~is_offroad & ~is_collided & controlled_live - # Draw the road graph - self._plot_roadgraph( - road_graph=self.global_roadgraph, - env_idx=env_idx, - ax=ax, - line_width_scale=line_width_scale, - marker_size_scale=marker_scale, - ) if plot_log_replay_trajectory: self._plot_log_replay_trajectory( @@ -188,7 +184,7 @@ def plot_simulator_state( control_mask=controlled_live, env_idx=env_idx, log_trajectory=self.log_trajectory, - line_width_scale=line_width_scale, + line_width_scale=self.line_width_scale, ) # Draw the agents @@ -201,8 +197,8 @@ def plot_simulator_state( is_collided_mask=is_collided, response_type=self.response_type, alpha=1.0, - line_width_scale=line_width_scale, - marker_size_scale=marker_scale, + line_width_scale=self.line_width_scale, + marker_size_scale=self.marker_scale, ) @@ -255,7 +251,7 @@ def plot_simulator_state( horizontalalignment="center", verticalalignment="center", transform=ax.transAxes, - fontsize=20 * marker_scale, + fontsize=20 * self.marker_scale, color="black", bbox=dict(facecolor="white", edgecolor="none", alpha=0.9), ) @@ -387,9 +383,10 @@ def _plot_roadgraph( gpudrive.EntityType.RoadEdge ): line_width = 1.1 * line_width_scale - else: - line_width = 0.75 * line_width_scale + else: + line_width = 0.80* line_width_scale + ax.plot( [start[0], end[0]], [start[1], end[1]], @@ -405,6 +402,7 @@ def _plot_roadgraph( segment_widths, segment_orientations, ax, + line_width, ) elif road_point_type == int(gpudrive.EntityType.StopSign): @@ -416,8 +414,8 @@ def _plot_roadgraph( radius=1.5, facecolor="#c04000", edgecolor="none", - linewidth=3.0, - alpha=0.9, + linewidth=3.5, + alpha=0.8, ) elif road_point_type == int(gpudrive.EntityType.CrossWalk): for x, y, length, width, orientation in zip( @@ -436,6 +434,7 @@ def _plot_roadgraph( facecolor="none", edgecolor="xkcd:bluish grey", alpha=0.4, + linewidth=1.5*self.line_width_scale ) else: @@ -447,6 +446,7 @@ def _plot_roadgraph( label=road_point_type, color=ROAD_GRAPH_COLORS[int(road_point_type)], ) + def _plot_filtered_agent_bounding_boxes( self, diff --git a/pygpudrive/visualize/utils.py b/pygpudrive/visualize/utils.py index 0e0f2038..c899576c 100644 --- a/pygpudrive/visualize/utils.py +++ b/pygpudrive/visualize/utils.py @@ -347,6 +347,7 @@ def plot_speed_bumps( facecolor: str = None, edgecolor: str = None, alpha: float = None, + linewidth: float = 0 ) -> None: facecolor = "xkcd:goldenrod" edgecolor = "xkcd:black" @@ -413,6 +414,7 @@ def plot_crosswalk( facecolor: str = None, edgecolor: str = None, alpha: float = None, + linewidth: float=None, ): if ax is None: ax = plt.gca() @@ -429,7 +431,7 @@ def plot_crosswalk( points, facecolor=facecolor, edgecolor=edgecolor, - linewidth=2, + linewidth=linewidth, alpha=alpha, hatch=r"//", zorder=2,