diff --git a/vmas/scenarios/navigation.py b/vmas/scenarios/navigation.py index a40f0a0a..e8b71531 100644 --- a/vmas/scenarios/navigation.py +++ b/vmas/scenarios/navigation.py @@ -43,7 +43,8 @@ def make_world(self, batch_dim: int, device: torch.device, **kwargs): ScenarioUtils.check_kwargs_consumed(kwargs) self.min_distance_between_entities = self.agent_radius * 2 + 0.05 - self.world_semidim = 1 + self.world_semidim_x = 1 if self.x_semidim is None else self.x_semidim + self.world_semidim_y = 1 if self.y_semidim is None else self.y_semidim self.min_collision_distance = 0.005 assert 1 <= self.agents_with_same_goal <= self.n_agents @@ -135,8 +136,8 @@ def reset_world_at(self, env_index: int = None): self.world, env_index, self.min_distance_between_entities, - (-self.world_semidim, self.world_semidim), - (-self.world_semidim, self.world_semidim), + (-self.world_semidim_x, self.world_semidim_x), + (-self.world_semidim_y, self.world_semidim_y), ) occupied_positions = torch.stack( @@ -152,8 +153,8 @@ def reset_world_at(self, env_index: int = None): env_index=env_index, world=self.world, min_dist_between_entities=self.min_distance_between_entities, - x_bounds=(-self.world_semidim, self.world_semidim), - y_bounds=(-self.world_semidim, self.world_semidim), + x_bounds=(-self.world_semidim_x, self.world_semidim_x), + y_bounds=(-self.world_semidim_y, self.world_semidim_y), ) goal_poses.append(position.squeeze(1)) occupied_positions = torch.cat([occupied_positions, position], dim=1)