Hi, I really like your work! It's so neat!

I'm a novice just learning MJX and Brax, and I was wondering if you could give any advice on how to transfer an env from MJX to work with your code/Brax.

My environment is just the standard (dm_control) humanoid balancing on one leg. It works in MuJoCo Playground, but I think it would perform much better with dial-mpc, given the high dimensionality of the humanoid and the fine control needed for balancing.

I've tried to convert the code, but Brax seems to want things structured differently. Are there any guides, or can you give any advice on how to do the conversion?
Here's the env:

```python
# %% Define Humanoid Env
import mujoco
import mujoco.mjx as mjx
import jax
import jax.numpy as jp
from etils import epath
from brax.envs.base import PipelineEnv, State
from brax import envs
from brax.io import mjcf  # needed for mjcf.load_model below

# Define environment path
HUMANOID_ROOT_PATH = epath.Path('/home/ajay/Python_Projects/mujoco-mjx/mjx/mujoco/mjx/test_data/humanoid')
class HumanoidBalance(PipelineEnv):
    """Humanoid learns to stand on one leg while maintaining the existing health check logic."""

    def __init__(
        self,
        pose_reward_weight=5.0,
        stability_reward_weight=3.0,
        ctrl_cost_weight=0.1,
        non_standing_foot_penalty_weight=2.0,  # NEW penalty weight
        healthy_reward=5.0,
        terminate_when_unhealthy=True,
        healthy_z_range=(1.0, 2.0),
        reset_noise_scale=1e-2,
        exclude_current_positions_from_observation=True,
        **kwargs,
    ):
        # Load MuJoCo model
        mj_model = mujoco.MjModel.from_xml_path(
            (HUMANOID_ROOT_PATH / 'humanoid.xml').as_posix())
        mj_model.opt.solver = mujoco.mjtSolver.mjSOL_CG
        mj_model.opt.iterations = 6
        mj_model.opt.ls_iterations = 6

        sys = mjcf.load_model(mj_model)

        physics_steps_per_control_step = 5
        kwargs['n_frames'] = kwargs.get('n_frames', physics_steps_per_control_step)
        kwargs['backend'] = 'mjx'
        super().__init__(sys, **kwargs)

        # Store environment parameters
        self._pose_reward_weight = pose_reward_weight
        self._stability_reward_weight = stability_reward_weight
        self._ctrl_cost_weight = ctrl_cost_weight
        self._non_standing_foot_penalty_weight = non_standing_foot_penalty_weight  # NEW
        self._healthy_reward = healthy_reward
        self._terminate_when_unhealthy = terminate_when_unhealthy
        self._healthy_z_range = healthy_z_range
        self._reset_noise_scale = reset_noise_scale
        self._exclude_current_positions_from_observation = exclude_current_positions_from_observation

        # Define target pose for standing on one leg
        self._target_qpos = jp.array([
            0, 0, 1.21948, 0.971588, -0.179973, 0.135318, -0.0729076,
            -0.0516, -0.202, 0.23, -0.24, -0.007, -0.34, -1.76, -0.466, -0.0415,
            -0.08, -0.01, -0.37, -0.685, -0.35, -0.09, 0.109, -0.067, -0.7, -0.05, 0.12, 0.16
        ])

        # ✅ Correctly Get Foot Indices Using `mj_name2id`
        self._left_foot_index = mujoco.mj_name2id(self.sys.mj_model, mujoco.mjtObj.mjOBJ_BODY, "foot_left")
        self._right_foot_index = mujoco.mj_name2id(self.sys.mj_model, mujoco.mjtObj.mjOBJ_BODY, "foot_right")

        # Ensure indices are valid
        assert self._left_foot_index != -1, "Error: 'foot_left' not found in model!"
        assert self._right_foot_index != -1, "Error: 'foot_right' not found in model!"

        # ✅ Assuming left foot is standing, right foot should be non-weight bearing
        self._standing_foot_index = self._left_foot_index
        self._non_standing_foot_index = self._right_foot_index
    def step(self, state: State, action: jp.ndarray) -> State:
        """Runs one timestep of the environment's dynamics."""
        data0 = state.pipeline_state
        data = self.pipeline_step(data0, action)

        # ✅ Pose Matching Reward
        pose_error = jp.linalg.norm(data.qpos - self._target_qpos)
        pose_reward = jp.exp(-pose_error)

        # ✅ Stability Reward (CoM should stay above the standing foot)
        com = data.subtree_com[1]
        foot_pos = data.xpos[self._standing_foot_index]
        stability_reward = jp.exp(-jp.linalg.norm(com[:2] - foot_pos[:2]))

        # ✅ Non-Standing Foot Penalty (quadratic penalty)
        foot_height = data.xpos[self._non_standing_foot_index, 2]
        penalty_threshold = 0.2  # Below this height, penalty applies
        foot_penalty = jp.where(
            foot_height < penalty_threshold,
            self._non_standing_foot_penalty_weight * (penalty_threshold - foot_height) ** 2,
            0.0
        )

        # ✅ Healthy Reward
        min_z, max_z = self._healthy_z_range
        is_healthy = jp.where(data.q[2] < min_z, 0.0, 1.0)
        is_healthy = jp.where(data.q[2] > max_z, 0.0, is_healthy)
        healthy_reward = self._healthy_reward * is_healthy if not self._terminate_when_unhealthy else self._healthy_reward

        # ✅ Control Cost Penalty
        ctrl_cost = self._ctrl_cost_weight * jp.sum(jp.square(action))

        # ✅ Final Reward
        reward = (
            self._pose_reward_weight * pose_reward
            + self._stability_reward_weight * stability_reward
            - ctrl_cost
            - foot_penalty  # NEW
            + healthy_reward
        )

        done = 1.0 - is_healthy if self._terminate_when_unhealthy else 0.0

        state.metrics.update(
            pose_reward=pose_reward,
            stability_reward=stability_reward,
            control_penalty=-ctrl_cost,
            foot_penalty=-foot_penalty,  # NEW
            reward_alive=healthy_reward,
        )
        return state.replace(pipeline_state=data, obs=self._get_obs(data, action), reward=reward, done=done)
    def reset(self, rng: jp.ndarray) -> State:
        """Resets the environment to an initial state."""
        rng, rng1, rng2 = jax.random.split(rng, 3)

        low, hi = -self._reset_noise_scale, self._reset_noise_scale
        qpos = self.sys.qpos0 + jax.random.uniform(
            rng1, (self.sys.nq,), minval=low, maxval=hi
        )
        qvel = jax.random.uniform(
            rng2, (self.sys.nv,), minval=low, maxval=hi
        )
        data = self.pipeline_init(qpos, qvel)

        # ✅ Compute pose reward at reset
        pose_error = jp.linalg.norm(qpos - self._target_qpos)
        pose_reward = jp.exp(-pose_error)

        # ✅ Stability Reward (CoM should start above the foot)
        com = data.subtree_com[1]
        foot_pos = data.xpos[self._standing_foot_index]  # Ensure correct foot index
        stability_reward = jp.exp(-jp.linalg.norm(com[:2] - foot_pos[:2]))

        # ✅ Foot Penalty (encouraging non-standing foot to be above 0.2 m)
        non_standing_foot_height = data.xpos[self._non_standing_foot_index, 2]
        foot_penalty = jp.maximum(0.0, 0.2 - non_standing_foot_height) ** 2

        # ✅ Healthy Reward (same as step function)
        min_z, max_z = self._healthy_z_range
        is_healthy = jp.where(qpos[2] < min_z, 0.0, 1.0)
        is_healthy = jp.where(qpos[2] > max_z, 0.0, is_healthy)
        healthy_reward = self._healthy_reward * is_healthy if not self._terminate_when_unhealthy else self._healthy_reward

        # ✅ Control cost is 0 at reset since no action has been taken
        ctrl_cost = 0.0

        # ✅ Compute total initial reward
        reward = (
            self._pose_reward_weight * pose_reward
            + self._stability_reward_weight * stability_reward
            - self._non_standing_foot_penalty_weight * foot_penalty
            + healthy_reward
        )
        done = 1.0 - is_healthy if self._terminate_when_unhealthy else 0.0

        obs = self._get_obs(data, jp.zeros(self.sys.nu))

        # ✅ Initialize ALL metrics to ensure JAX structure consistency
        metrics = {
            'pose_reward': pose_reward,
            'stability_reward': stability_reward,
            'control_penalty': -ctrl_cost,
            'foot_penalty': -foot_penalty,  # Include foot penalty
            'reward_alive': healthy_reward,
        }
        return State(data, obs, reward, done, metrics)
    def _get_obs(self, data: mjx.Data, action: jp.ndarray) -> jp.ndarray:
        """Observes humanoid body position, velocities, and actuator forces."""
        position = data.qpos
        if self._exclude_current_positions_from_observation:
            position = position[2:]  # Remove global x, y coordinates

        # ✅ Concatenate all observation components
        return jp.concatenate([
            position,                  # Joint positions
            data.qvel,                 # Joint velocities
            data.cinert[1:].ravel(),   # Center-of-mass inertia
            data.cvel[1:].ravel(),     # Center-of-mass velocity
            data.qfrc_actuator,        # Actuator forces
            action                     # Last action taken
        ])


envs.register_environment('humanoid_balance', HumanoidBalance)
```
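For reference, this is roughly how I'm smoke-testing the registered env (just a sketch of a plain Brax reset/step, not the dial-mpc entry point):

```python
# Minimal smoke test of the registered env (sketch only).
env = envs.get_environment('humanoid_balance')

rng = jax.random.PRNGKey(0)
state = jax.jit(env.reset)(rng)       # initial Brax State
action = jp.zeros(env.action_size)    # zero control for one step
state = jax.jit(env.step)(state, action)
print(state.reward, state.done)
```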
Thanks for reading my issue.

Best regards,
Ajay
Thanks for asking. Converting an MJX env to Brax should be fairly straightforward, since Brax uses MJX on the backend; it just adds one layer of abstraction that pulls qpos, qvel, etc. out of mjData and wraps them in a Brax pipeline state.

Can you share the specific errors you are encountering with this environment? Even better, could you put up a test repo with a side-by-side comparison of your DIAL-MPC environment and the dm_control environment?
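In the meantime, here's a rough sketch of what I mean by that abstraction, using your registered env (illustrative only; exact field names can shift between Brax versions):

```python
import jax
from brax import envs

# The pipeline state returned by reset/step wraps mjx.Data, so the raw MJX
# fields (qpos, qvel, xpos, ...) are still accessible alongside Brax's
# q/qd view of the same state (assumes your env above registers cleanly).
env = envs.get_environment('humanoid_balance')
state = jax.jit(env.reset)(jax.random.PRNGKey(0))

data = state.pipeline_state
print(data.qpos[:3])   # MJX generalized positions
print(data.q[:3])      # Brax alias for the same coordinates
print(data.xpos[1])    # body positions, as in mjx.Data
```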