From 5eabb7122ed8a7cb2cc3d692e77bef5a177550b1 Mon Sep 17 00:00:00 2001
From: charliemolony <73535968+charliemolony@users.noreply.github.com>
Date: Sat, 25 Jan 2025 09:14:01 -0500
Subject: [PATCH] Cach branch (#328)

* caching

* caching

* adding back env_torch

* adding back env_torch

* adding caching to visualiser

* adding caching and cleaning up code

* cleaning
---
 pygpudrive/env/env_torch.py   |  17 ++++--
 pygpudrive/visualize/core.py  | 102 +++++++++++++++++-----------------
 pygpudrive/visualize/utils.py |   4 +-
 3 files changed, 67 insertions(+), 56 deletions(-)
 mode change 100755 => 100644 pygpudrive/env/env_torch.py

diff --git a/pygpudrive/env/env_torch.py b/pygpudrive/env/env_torch.py
old mode 100755
new mode 100644
index 59ccad9e..ddf547d3
--- a/pygpudrive/env/env_torch.py
+++ b/pygpudrive/env/env_torch.py
@@ -81,8 +81,11 @@ def __init__(
         self.num_agents = self.cont_agent_mask.sum().item()
         self.single_action_space = self.action_space
         # self.action_space = pufferlib.spaces.joint_space(self.single_action_space, self.num_agents)
-        self.observation_space = pufferlib.spaces.joint_space(
-            self.single_observation_space, self.num_agents
+        # self.observation_space = pufferlib.spaces.joint_space(
+        #     self.single_observation_space, self.num_agents
+        # )
+        self.observation_space = Box(
+            low=-np.inf, high=np.inf, shape=(self.get_obs().shape[-1],)
         )
 
         self.info_dim = 5  # Number of info features
@@ -180,6 +183,7 @@ def get_rewards(
                 self.num_worlds,
                 self.max_agent_count,
                 backend=self.backend,
+                device=self.device,
             )
 
             # Index log positions at current time steps
@@ -193,6 +197,7 @@ def get_rewards(
             agent_state = GlobalEgoState.from_tensor(
                 self.sim.absolute_self_observation_tensor(),
                 self.backend,
+                device=self.device,
             )
 
             agent_pos = torch.stack(
@@ -364,6 +369,7 @@ def _get_ego_state(self) -> torch.Tensor:
             ego_state = LocalEgoState.from_tensor(
                 self_obs_tensor=self.sim.self_observation_tensor(),
                 backend=self.backend,
+                device=self.device
             )
             if self.config.norm_obs:
                 ego_state.normalize()
@@ -391,6 +397,7 @@ def _get_partner_obs(self):
             partner_obs = PartnerObs.from_tensor(
                 partner_obs_tensor=self.sim.partner_observations_tensor(),
                 backend=self.backend,
+                device=self.device,
             )
 
             if self.config.norm_obs:
@@ -423,6 +430,7 @@ def _get_road_map_obs(self):
             roadgraph = LocalRoadGraphPoints.from_tensor(
                 local_roadgraph_tensor=self.sim.agent_roadmap_tensor(),
                 backend=self.backend,
+                device=self.device,
             )
 
             if self.config.norm_obs:
@@ -455,6 +463,7 @@ def _get_lidar_obs(self):
             lidar = LidarObs.from_tensor(
                 lidar_tensor=self.sim.lidar_tensor(),
                 backend=self.backend,
+                device=self.device,
             )
 
             return (
@@ -660,7 +669,7 @@ def get_expert_actions(self):
 
     # Create data loader
     train_loader = SceneDataLoader(
-        root="data/processed/training",
+        root="data/processed/examples",
         batch_size=data_config.batch_size,
         dataset_size=data_config.dataset_size,
         sample_with_replacement=True,
@@ -749,4 +758,4 @@ def get_expert_actions(self):
     )
     media.write_video(
         "obs_video.gif", np.array(agent_obs_frames), fps=10, codec="gif"
-    )
+    )
\ No newline at end of file
diff --git a/pygpudrive/visualize/core.py b/pygpudrive/visualize/core.py
index d2826aef..c6473a90 100644
--- a/pygpudrive/visualize/core.py
+++ b/pygpudrive/visualize/core.py
@@ -24,7 +24,6 @@
     REL_OBS_OBJ_COLORS,
     AGENT_COLOR_BY_STATE,
 )
-
 OUT_OF_BOUNDS = 1000
 
 
@@ -45,7 +44,9 @@ def __init__(
         self.goal_radius = goal_radius
         self.num_worlds = num_worlds
         self.render_config = render_config
-        self.figsize = (15, 15)
+        self.figsize = (10, 10)
+        self.marker_scale = max(self.figsize) / 15
+        self.line_width_scale = max(self.figsize) / 25
         self.env_config = env_config
         self.initialize_static_scenario_data(controlled_agent_mask)
 
@@ -74,18 +75,21 @@ def initialize_static_scenario_data(self, controlled_agent_mask):
         )
 
         # Cache pre-rendered road graphs for all environments
-        # self.cached_roadgraphs = []
-        # for env_idx in range(self.controlled_agent_mask.shape[0]):
-        #     fig, ax = plt.subplots(figsize=self.figsize)
-        #     self._plot_roadgraph(
-        #         road_graph=self.global_roadgraph,
-        #         env_idx=env_idx,
-        #         ax=ax,
-        #         line_width_scale=1.0,
-        #         marker_size_scale=1.0,
-        #     )
-        #     self.cached_roadgraphs.append(fig)
-        #     plt.close(fig)
+        self.cached_roadgraphs = []
+        for env_idx in range(self.controlled_agent_mask.shape[0]):
+            fig, ax = plt.subplots(figsize=self.figsize,dpi=300)
+            self._plot_roadgraph(
+                road_graph=self.global_roadgraph,
+                env_idx=env_idx,
+                ax=ax,
+                line_width_scale=self.line_width_scale,
+                marker_size_scale=self.marker_scale,
+            )
+            self.cached_roadgraphs.append(fig)
+            plt.close(fig)
+        
+        self.plot_limits = ax.viewLim.get_points() #[x_min,y_min,xmax,ymax] (bottom left corner -> top right)
+
 
     def plot_simulator_state(
         self,
@@ -132,34 +136,34 @@ def plot_simulator_state(
 
         figs = []
 
-        # Calculate scale factors based on figure size
-        marker_scale = max(self.figsize) / 15
-        line_width_scale = max(self.figsize) / 15
-
         # Iterate over each environment index
         for idx, (env_idx, time_step, center_agent_idx) in enumerate(
             zip(env_indices, time_steps, center_agent_indices)
         ):
 
-            # Initialize figure and axes from cached road graph
             fig, ax = plt.subplots(figsize=self.figsize)
-            ax.clear()  # Clear any existing content
+            ax.clear()  
             ax.set_aspect("equal", adjustable="box")
-            figs.append(fig)  # Add the new figure
-            plt.close(fig)  # Close the figure to prevent carryover
-
-            # Render the pre-cached road graph for the current environment
-            # cached_roadgraph_array = utils.bg_img_from_fig(self.cached_roadgraphs[env_idx])
-            # ax.imshow(
-            #     cached_roadgraph_array,
-            #     origin="upper",
-            #     extent=(-100, 100, -100, 100),  # Stretch to full plot
-            #     zorder=0,  # Draw as background
-            # )
-
-            # Explicitly set the axis limits to match your coordinates
-            # cached_ax.set_xlim(-100, 100)
-            # cached_ax.set_ylim(-100, 100)
+            figs.append(fig)  
+            plt.close(fig)  
+
+   
+            cached_roadgraph_array = utils.bg_img_from_fig(self.cached_roadgraphs[env_idx])
+            ax.imshow(
+                cached_roadgraph_array,
+                origin="upper",
+                extent=(self.plot_limits[0][0],self.plot_limits[1][0], # X: [x_min,x_max]
+                        self.plot_limits[0][1],self.plot_limits[1][1]),# Y: [y_min, y_max]
+                zorder=0,  
+            )
+
+            # # # Explicitly set the axis limits to match your coordinates
+            ax.set_xlim(self.plot_limits[0][0], self.plot_limits[1][0] )
+            ax.set_ylim(self.plot_limits[0][1], self.plot_limits[1][1])
+  
+
+            # Remove axes
+            # cached_ax.axis('off')
 
             # Get control mask and omit out-of-bound agents (dead agents)
             controlled = self.controlled_agent_mask[env_idx, :]
@@ -173,14 +177,6 @@ def plot_simulator_state(
             ) & controlled_live
             is_ok = ~is_offroad & ~is_collided & controlled_live
 
-            # Draw the road graph
-            self._plot_roadgraph(
-                road_graph=self.global_roadgraph,
-                env_idx=env_idx,
-                ax=ax,
-                line_width_scale=line_width_scale,
-                marker_size_scale=marker_scale,
-            )
 
             if plot_log_replay_trajectory:
                 self._plot_log_replay_trajectory(
@@ -188,7 +184,7 @@ def plot_simulator_state(
                     control_mask=controlled_live,
                     env_idx=env_idx,
                     log_trajectory=self.log_trajectory,
-                    line_width_scale=line_width_scale,
+                    line_width_scale=self.line_width_scale,
                 )
 
             # Draw the agents
@@ -201,8 +197,8 @@ def plot_simulator_state(
                 is_collided_mask=is_collided,
                 response_type=self.response_type,
                 alpha=1.0,
-                line_width_scale=line_width_scale,
-                marker_size_scale=marker_scale,
+                line_width_scale=self.line_width_scale,
+                marker_size_scale=self.marker_scale,
             )
 
 
@@ -255,7 +251,7 @@ def plot_simulator_state(
                     horizontalalignment="center",
                     verticalalignment="center",
                     transform=ax.transAxes,
-                    fontsize=20 * marker_scale,
+                    fontsize=20 * self.marker_scale,
                     color="black",
                     bbox=dict(facecolor="white", edgecolor="none", alpha=0.9),
                 )
@@ -387,9 +383,10 @@ def _plot_roadgraph(
                                 gpudrive.EntityType.RoadEdge
                             ):
                                 line_width = 1.1 * line_width_scale
-                            else:
-                                line_width = 0.75 * line_width_scale
 
+                            else:
+                                line_width = 0.80* line_width_scale
+                                
                             ax.plot(
                                 [start[0], end[0]],
                                 [start[1], end[1]],
@@ -405,6 +402,7 @@ def _plot_roadgraph(
                             segment_widths,
                             segment_orientations,
                             ax,
+                            line_width,
                         )
 
                     elif road_point_type == int(gpudrive.EntityType.StopSign):
@@ -416,8 +414,8 @@ def _plot_roadgraph(
                                 radius=1.5,
                                 facecolor="#c04000",
                                 edgecolor="none",
-                                linewidth=3.0,
-                                alpha=0.9,
+                                linewidth=3.5,
+                                alpha=0.8,
                             )
                     elif road_point_type == int(gpudrive.EntityType.CrossWalk):
                         for x, y, length, width, orientation in zip(
@@ -436,6 +434,7 @@ def _plot_roadgraph(
                                 facecolor="none",
                                 edgecolor="xkcd:bluish grey",
                                 alpha=0.4,
+                                linewidth=1.5*self.line_width_scale
                             )
 
                 else:
@@ -447,6 +446,7 @@ def _plot_roadgraph(
                         label=road_point_type,
                         color=ROAD_GRAPH_COLORS[int(road_point_type)],
                     )
+   
 
     def _plot_filtered_agent_bounding_boxes(
         self,
diff --git a/pygpudrive/visualize/utils.py b/pygpudrive/visualize/utils.py
index 0e0f2038..c899576c 100644
--- a/pygpudrive/visualize/utils.py
+++ b/pygpudrive/visualize/utils.py
@@ -347,6 +347,7 @@ def plot_speed_bumps(
     facecolor: str = None,
     edgecolor: str = None,
     alpha: float = None,
+    linewidth: float = 0
 ) -> None:
     facecolor = "xkcd:goldenrod"
     edgecolor = "xkcd:black"
@@ -413,6 +414,7 @@ def plot_crosswalk(
     facecolor: str = None,
     edgecolor: str = None,
     alpha: float = None,
+    linewidth: float=None,
 ):
     if ax is None:
         ax = plt.gca()
@@ -429,7 +431,7 @@ def plot_crosswalk(
         points,
         facecolor=facecolor,
         edgecolor=edgecolor,
-        linewidth=2,
+        linewidth=linewidth,
         alpha=alpha,
         hatch=r"//",
         zorder=2,