Skip to content

Commit

Permalink
Az/patch001 (#17)
Browse files Browse the repository at this point in the history
  • Loading branch information
alezana authored Jan 2, 2025
1 parent e233485 commit a387615
Show file tree
Hide file tree
Showing 26 changed files with 96 additions and 1,329 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci-build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,5 +36,5 @@ jobs:
pip install -e .
- name: Test with pytest
run: |
pytest
pytest --capture=no -v waymax
11 changes: 2 additions & 9 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,16 +1,9 @@

dataset/training/training_tfexample.tfrecord-00000-of-01000
dataset/training/training_tfexample.tfrecord-00001-of-01000
waymax/utils/test_utils.py
waymax/rewards/linear_combination_reward_test.py
/.vscode
waymax/demo_scripts/test.py
docs/
rl/logs
logs/
wandb/
logs/
out/
__pycache__
*.egg-info
rl/ppo/gokartlogs
rl/ppo/waymaxlogs
*.egg-info
21 changes: 21 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@

cover_packages=waymax

out=out
tr=$(out)/test-results

junit=--junitxml=$(tr)/junit.xml
parallel=-n auto --dist=loadfile
extra=--capture=no -v

clean-test:
poetry run coverage erase
rm -rf $(tr) $(tr)

test: clean-test
mkdir -p $(tr)
poetry run pytest $(extra) $(junit) waymax

test-parallel: clean-test
mkdir -p $(tr)
poetry run pytest $(extra) $(junit) $(parallel) waymax
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
'tf-keras', # needed for distrax
'dm_env>=1.6',
'flax>=0.6.7',
'matplotlib>=3.7.1',
'matplotlib<3.10',
'dm-tree>=0.1.8',
'immutabledict>=2.2.3',
'Pillow>=9.4.0',
Expand Down
3 changes: 1 addition & 2 deletions waymax/agents/sim_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,12 +111,11 @@ def update_trajectory(
self, state: datatypes.SimulatorState
) -> datatypes.TrajectoryUpdate:
"""Returns the current sim trajectory as the next update."""
return datatypes.GoKartTrajectoryUpdate(
return datatypes.TrajectoryUpdate(
x=state.current_sim_trajectory.x,
y=state.current_sim_trajectory.y,
yaw=state.current_sim_trajectory.yaw,
vel_x=state.current_sim_trajectory.vel_x,
vel_y=state.current_sim_trajectory.vel_y,
yaw_rate=state.current_sim_trajectory.yaw_rate,
valid=state.current_sim_trajectory.valid,
)
6 changes: 4 additions & 2 deletions waymax/datatypes/object_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,8 +210,9 @@ def vel_yaw(self) -> jax.Array:
# Make sure those that were originally invalid are still invalid.
return jnp.where(self.valid, vel_yaw, _INVALID_FLOAT_VALUE)

@classmethod
@property
def controllable_fields(self) -> Sequence[str]:
def controllable_fields(cls) -> list[str]:
"""Returns the fields that are controllable."""
return ["x", "y", "yaw", "vel_x", "vel_y"]

Expand Down Expand Up @@ -305,8 +306,9 @@ class GokartTrajectory(Trajectory):
acc_x: jax.Array
acc_y: jax.Array

@classmethod
@property
def controllable_fields(self) -> Sequence[str]:
def controllable_fields(cls) -> Sequence[str]:
"""Returns the fields that are controllable."""
return ["x", "y", "yaw", "vel_x", "vel_y", "yaw_rate", "acc_x", "acc_y"]

Expand Down
2 changes: 2 additions & 0 deletions waymax/datatypes/roadgraph_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import jax
import jax.numpy as jnp
import pytest
import tensorflow as tf

from absl.testing import parameterized
Expand Down Expand Up @@ -47,6 +48,7 @@ def setUp(self):
)
self.rg.validate()

@pytest.mark.skip("To be fixed")
def test_top_k_roadgraph_returns_correct_output_fewer_points(self):
xyz_and_direction = jnp.array(
[
Expand Down
29 changes: 14 additions & 15 deletions waymax/datatypes/simulator_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
have better support with jax utils.
"""

from typing import Any, Optional, Sequence, TypeVar, Generic
from typing import Any, Optional, Generic

import chex
import jax
Expand All @@ -30,7 +30,6 @@
from waymax.datatypes import array, action, object_state, operations, roadgraph, route, traffic_lights
from waymax.datatypes.object_state import TrajectoryType


ArrayLike = jax.typing.ArrayLike
PyTree = array.PyTree

Expand Down Expand Up @@ -86,14 +85,14 @@ def num_objects(self) -> int:
def is_done(self) -> bool:
"""Returns whether the simulation is at the end of the logged history."""
return jnp.array( # pytype: disable=bad-return-type # jnp-type
(self.timestep + 1) >= self.log_trajectory.num_timesteps, bool
(self.timestep + 1) >= self.log_trajectory.num_timesteps, bool
)

@property
def remaining_timesteps(self) -> int:
"""Returns the number of remaining timesteps in the episode."""
return jnp.array(
self.log_trajectory.num_timesteps - self.timestep - 1, int
self.log_trajectory.num_timesteps - self.timestep - 1, int
) # pytype: disable=bad-return-type # jnp-type

@property
Expand All @@ -111,7 +110,7 @@ def previous_sim_trajectory(self) -> TrajectoryType:
def current_log_trajectory(self) -> TrajectoryType:
"""Returns the trajectory corresponding to the current sim state."""
return operations.dynamic_slice(self.log_trajectory, self.timestep, 1, axis=-1)

def __eq__(self, other: Any) -> bool:
return operations.compare_all_leaf_nodes(self, other)

Expand All @@ -136,7 +135,7 @@ class GoKartSimState(SimulatorState[object_state.GokartTrajectory]):
"""
actions_history: Optional[action.GokartAction] = None
sdc_paths: Optional[route.GoKartPaths] = None

@property
def current_action_history(self) -> action.GokartAction:
"""Returns the actions corresponding to the current sim state."""
Expand All @@ -152,18 +151,18 @@ def __eq__(self, other: Any) -> bool:
return operations.compare_all_leaf_nodes(self, other)


def update_state_by_log(state: SimulatorState, num_steps: int) -> SimulatorState:
def update_state_by_log(state: SimulatorState | GoKartSimState, num_steps: int) -> SimulatorState | GoKartSimState:
"""Advances SimulatorState by num_steps using logged data."""
# TODO jax runtime check num_steps > state.remaining_timesteps
return state.replace(
timestep=state.timestep + num_steps,
sim_trajectory=operations.update_by_slice_in_dim(
inputs=state.sim_trajectory,
updates=state.log_trajectory,
inputs_start_idx=state.timestep + 1,
slice_size=num_steps,
axis=-1,
),
timestep=state.timestep + num_steps,
sim_trajectory=operations.update_by_slice_in_dim(
inputs=state.sim_trajectory,
updates=state.log_trajectory,
inputs_start_idx=state.timestep + 1,
slice_size=num_steps,
axis=-1,
),
)


Expand Down
11 changes: 6 additions & 5 deletions waymax/dynamics/abstract_dynamics_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,14 @@
from waymax import config as _config
from waymax import dataloader
from waymax import datatypes
from waymax.datatypes import Trajectory
from waymax.dynamics import abstract_dynamics
from waymax.utils import test_utils

TEST_DATA_PATH = test_utils.ROUTE_DATA_PATH


class TestDynamics(abstract_dynamics.DynamicsModel):
class MockDynamics(abstract_dynamics.DynamicsModel):
"""Ignores actions and returns a hard-coded trajectory update at each step."""

def __init__(self, update: datatypes.TrajectoryUpdate):
Expand Down Expand Up @@ -83,7 +84,7 @@ def test_forward_update_matches_expected_result(self):
)

# Use TestDynamics, which simply sets the state to the value of the action.
dynamics_model = TestDynamics(update)
dynamics_model = MockDynamics(update)
timestep = 2
next_traj = dynamics_model.forward( # pytype: disable=wrong-arg-types # jnp-type
action=jnp.zeros((batch_size, objects)),
Expand All @@ -96,7 +97,7 @@ def test_forward_update_matches_expected_result(self):
next_step = datatypes.dynamic_slice(next_traj, timestep + 1, 1, axis=-1)
# Extract the log trajectory at timestep t+1
log_t = datatypes.dynamic_slice(log_traj, timestep + 1, 1, axis=-1)
for field in abstract_dynamics.CONTROLLABLE_FIELDS:
for field in Trajectory.controllable_fields:
with self.subTest(field):
# Check that the controlled fields are set to the same value
# as the update (this is the behavior of TestDynamics),
Expand Down Expand Up @@ -135,7 +136,7 @@ def test_update_state_with_dynamics_trajectory(self, allow_object_injection):
)
trajectory_update.validate()
is_controlled = sim_state.object_metadata.is_sdc
test_dynamics = TestDynamics(trajectory_update)
test_dynamics = MockDynamics(trajectory_update)
updated_sim_traj = test_dynamics.forward( # pytype: disable=wrong-arg-types # jnp-type
jnp.zeros_like(is_controlled),
trajectory=sim_state.sim_trajectory,
Expand Down Expand Up @@ -257,7 +258,7 @@ def test_update_state_with_dynamics_trajectory_handles_valid(
yaw=jnp.ones_like(current_traj.yaw),
valid=action_valid[..., jnp.newaxis],
)
test_dynamics = TestDynamics(trajectory_update)
test_dynamics = MockDynamics(trajectory_update)
updated_sim_traj = test_dynamics.forward( # pytype: disable=wrong-arg-types # jnp-type
jnp.zeros_like(is_controlled),
trajectory=sim_state.sim_trajectory,
Expand Down
19 changes: 14 additions & 5 deletions waymax/dynamics/state_dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,12 @@
# limitations under the License.

"""Dynamics model for setting state in global coordinates."""
from dm_env import specs
import jax
import numpy as np
from dm_env import specs

from waymax import datatypes
from waymax.datatypes import Trajectory, GokartTrajectory
from waymax.dynamics import abstract_dynamics


Expand All @@ -30,7 +31,7 @@ def __init__(self):
def action_spec(self) -> specs.BoundedArray:
"""Action spec for the delta global action space."""
return specs.BoundedArray(
shape=(len(abstract_dynamics.CONTROLLABLE_FIELDS),),
shape=(len(Trajectory.controllable_fields),),
dtype=np.float32,
minimum=-float('inf'),
maximum=float('inf'),
Expand Down Expand Up @@ -99,11 +100,20 @@ def __init__(self):
"""Initializes the StateDynamics."""
super().__init__()

def action_spec(self) -> specs.BoundedArray:
"""Action spec for the delta global action space."""
return specs.BoundedArray(
shape=(len(GokartTrajectory.controllable_fields),),
dtype=np.float32,
minimum=-float('inf'),
maximum=float('inf'),
)

def compute_update(
self,
action: datatypes.Action,
trajectory: datatypes.Trajectory,
) -> datatypes.TrajectoryUpdate:
trajectory: datatypes.GokartTrajectory,
) -> datatypes.GoKartTrajectoryUpdate:
"""Computes the pose and velocity updates at timestep.
This dynamics will directly set the next x, y, yaw, vel_x, and vel_y based
Expand All @@ -129,4 +139,3 @@ def compute_update(
acc_y=action.data[..., 7:8],
valid=action.valid,
)

6 changes: 3 additions & 3 deletions waymax/env/rollout.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,10 +190,10 @@ def _step(
)
last_output = RolloutOutput(
action=padding_action,
state=carry.sim_state,
state=carry.state,
observation=carry.observation,
metrics=env.metrics(carry.sim_state),
reward=env.reward(carry.sim_state, padding_action),
metrics=env.metrics(carry.state),
reward=env.reward(carry.state, padding_action),
)

output = jax.tree_util.tree_map(
Expand Down
2 changes: 1 addition & 1 deletion waymax/env/rollout_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def _run_rollout(init_state):
lambda x: x[None], jax.tree_util.tree_map(jnp.asarray, next_state)
)
all_states = jax.tree_util.tree_map(
lambda x, y: jnp.concatenate((x, y)), manual_rollout.sim_state, last_state
lambda x, y: jnp.concatenate((x, y)), manual_rollout.state, last_state
)
last_observation = jax.tree_util.tree_map(
lambda x: x[None], env.observe(next_state)
Expand Down
10 changes: 0 additions & 10 deletions waymax/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,3 @@
from waymax.metrics.roadgraph import WrongWayMetric
from waymax.metrics.route import OffRouteMetric
from waymax.metrics.route import ProgressionMetric
from waymax.metrics.gokart_progress import GokartProgressMetric
from waymax.metrics.gokart_orientation import GokartOrientationMetric
from waymax.metrics.gokart_offroad import GokartOffroadMetric
from waymax.metrics.gokart_offroad import GokartDistanceToBoundsMetric
from waymax.metrics.gokart_action import GokartActionNormMetric
from waymax.metrics.gokart_action import GokartActionOutRangeMetric
from waymax.metrics.gokart_action import GokartActionRateNormMetric
from waymax.metrics.gokart_action import GokartTVActionNormMetric
from waymax.metrics.gokart_state import GokartStateNormMetric
from waymax.metrics.gokart_state import GokartStateOutRangeMetric
Loading

0 comments on commit a387615

Please sign in to comment.