
Advice on how to use a MJX env for dial-mpc #16

Open

ajaytalati opened this issue Feb 12, 2025 · 1 comment

@ajaytalati
Hi,

I really like your work !!! It's so neat !!!

I'm a novice just learning MJX and Brax. I was wondering if you could give any advice on how to transfer an env from MJX so that it works with your code/Brax?

My environment is just the standard (dm_control) humanoid, balancing on one leg. It works in MuJoCo Playground, but I think it would perform much better with dial-mpc, given the high dimensionality of the humanoid and the fine control needed for balancing.

I've tried to convert the code, but Brax seems to expect things to be structured differently. I was just wondering if there are any guides, or if you could give any advice on how to do the conversion?

Here's the env,


# %% Define Humanoid Env

import mujoco
import mujoco.mjx as mjx
import jax
import jax.numpy as jp
from etils import epath
from brax.envs.base import PipelineEnv, State
from brax import envs
from brax.io import mjcf  # needed for mjcf.load_model below

# Define environment path
HUMANOID_ROOT_PATH = epath.Path('/home/ajay/Python_Projects/mujoco-mjx/mjx/mujoco/mjx/test_data/humanoid')

class HumanoidBalance(PipelineEnv):
    """Humanoid learns to stand on one leg while maintaining the existing health check logic."""

    def __init__(
        self,
        pose_reward_weight=5.0,
        stability_reward_weight=3.0,
        ctrl_cost_weight=0.1,
        non_standing_foot_penalty_weight=2.0,  # NEW penalty weight
        healthy_reward=5.0,
        terminate_when_unhealthy=True,
        healthy_z_range=(1.0, 2.0),
        reset_noise_scale=1e-2,
        exclude_current_positions_from_observation=True,
        **kwargs,
    ):
        # Load Mujoco Model
        mj_model = mujoco.MjModel.from_xml_path(
            (HUMANOID_ROOT_PATH / 'humanoid.xml').as_posix())
        mj_model.opt.solver = mujoco.mjtSolver.mjSOL_CG
        mj_model.opt.iterations = 6
        mj_model.opt.ls_iterations = 6

        sys = mjcf.load_model(mj_model)

        physics_steps_per_control_step = 5
        kwargs['n_frames'] = kwargs.get('n_frames', physics_steps_per_control_step)
        kwargs['backend'] = 'mjx'

        super().__init__(sys, **kwargs)

        # Store environment parameters
        self._pose_reward_weight = pose_reward_weight
        self._stability_reward_weight = stability_reward_weight
        self._ctrl_cost_weight = ctrl_cost_weight
        self._non_standing_foot_penalty_weight = non_standing_foot_penalty_weight  # NEW
        self._healthy_reward = healthy_reward
        self._terminate_when_unhealthy = terminate_when_unhealthy
        self._healthy_z_range = healthy_z_range
        self._reset_noise_scale = reset_noise_scale
        self._exclude_current_positions_from_observation = exclude_current_positions_from_observation

        # Define target pose for standing on one leg
        self._target_qpos = jp.array([
            0, 0, 1.21948, 0.971588, -0.179973, 0.135318, -0.0729076,
            -0.0516, -0.202, 0.23, -0.24, -0.007, -0.34, -1.76, -0.466, -0.0415,
            -0.08, -0.01, -0.37, -0.685, -0.35, -0.09, 0.109, -0.067, -0.7, -0.05, 0.12, 0.16
        ])

        # ✅ Correctly Get Foot Indices Using `mj_name2id`
        self._left_foot_index = mujoco.mj_name2id(self.sys.mj_model, mujoco.mjtObj.mjOBJ_BODY, "foot_left")
        self._right_foot_index = mujoco.mj_name2id(self.sys.mj_model, mujoco.mjtObj.mjOBJ_BODY, "foot_right")

        # Ensure indices are valid
        assert self._left_foot_index != -1, "Error: 'foot_left' not found in model!"
        assert self._right_foot_index != -1, "Error: 'foot_right' not found in model!"

        # ✅ Assuming left foot is standing, right foot should be non-weight bearing
        self._standing_foot_index = self._left_foot_index
        self._non_standing_foot_index = self._right_foot_index

    def step(self, state: State, action: jp.ndarray) -> State:
        """Runs one timestep of the environment's dynamics."""
        data0 = state.pipeline_state
        data = self.pipeline_step(data0, action)

        # ✅ Pose Matching Reward
        pose_error = jp.linalg.norm(data.qpos - self._target_qpos)
        pose_reward = jp.exp(-pose_error)

        # ✅ Stability Reward (CoM should stay above the standing foot)
        com = data.subtree_com[1]
        foot_pos = data.xpos[self._standing_foot_index]
        stability_reward = jp.exp(-jp.linalg.norm(com[:2] - foot_pos[:2]))

        # ✅ Non-Standing Foot Penalty (Quadratic penalty)
        foot_height = data.xpos[self._non_standing_foot_index, 2]
        penalty_threshold = 0.2  # Below this height, penalty applies
        foot_penalty = jp.where(
            foot_height < penalty_threshold,  
            self._non_standing_foot_penalty_weight * (penalty_threshold - foot_height) ** 2,
            0.0
        )

        # ✅ Healthy Reward
        min_z, max_z = self._healthy_z_range
        is_healthy = jp.where(data.q[2] < min_z, 0.0, 1.0)
        is_healthy = jp.where(data.q[2] > max_z, 0.0, is_healthy)
        healthy_reward = self._healthy_reward * is_healthy if not self._terminate_when_unhealthy else self._healthy_reward

        # ✅ Control Cost Penalty
        ctrl_cost = self._ctrl_cost_weight * jp.sum(jp.square(action))

        # ✅ Final Reward
        reward = (
            self._pose_reward_weight * pose_reward
            + self._stability_reward_weight * stability_reward
            - ctrl_cost
            - foot_penalty  # NEW
            + healthy_reward
        )

        done = 1.0 - is_healthy if self._terminate_when_unhealthy else 0.0
        state.metrics.update(
            pose_reward=pose_reward,
            stability_reward=stability_reward,
            control_penalty=-ctrl_cost,
            foot_penalty=-foot_penalty,  # NEW
            reward_alive=healthy_reward,
        )

        return state.replace(pipeline_state=data, obs=self._get_obs(data, action), reward=reward, done=done)

    def reset(self, rng: jp.ndarray) -> State:
        """Resets the environment to an initial state."""
        rng, rng1, rng2 = jax.random.split(rng, 3)
    
        low, hi = -self._reset_noise_scale, self._reset_noise_scale
        qpos = self.sys.qpos0 + jax.random.uniform(
            rng1, (self.sys.nq,), minval=low, maxval=hi
        )
        qvel = jax.random.uniform(
            rng2, (self.sys.nv,), minval=low, maxval=hi
        )
    
        data = self.pipeline_init(qpos, qvel)
    
        # ✅ Compute pose reward at reset
        pose_error = jp.linalg.norm(qpos - self._target_qpos)
        pose_reward = jp.exp(-pose_error)
    
        # ✅ Stability Reward (CoM should start above the foot)
        com = data.subtree_com[1]
        foot_pos = data.xpos[self._standing_foot_index]  # Ensure correct foot index
        stability_reward = jp.exp(-jp.linalg.norm(com[:2] - foot_pos[:2]))
    
        # ✅ Foot Penalty (encouraging non-standing foot to be above 0.2m)
        non_standing_foot_height = data.xpos[self._non_standing_foot_index, 2]
        foot_penalty = jp.maximum(0.0, 0.2 - non_standing_foot_height) ** 2
    
        # ✅ Healthy Reward (same as step function)
        min_z, max_z = self._healthy_z_range
        is_healthy = jp.where(qpos[2] < min_z, 0.0, 1.0)
        is_healthy = jp.where(qpos[2] > max_z, 0.0, is_healthy)
        healthy_reward = self._healthy_reward * is_healthy if not self._terminate_when_unhealthy else self._healthy_reward
    
        # ✅ Control cost is 0 at reset since no action taken
        ctrl_cost = 0.0
    
        # ✅ Compute total initial reward
        reward = (
            self._pose_reward_weight * pose_reward
            + self._stability_reward_weight * stability_reward
            - self._non_standing_foot_penalty_weight * foot_penalty
            + healthy_reward
        )
    
        done = 1.0 - is_healthy if self._terminate_when_unhealthy else 0.0
    
        obs = self._get_obs(data, jp.zeros(self.sys.nu))
    
        # ✅ Initialize ALL metrics to zero to ensure JAX structure consistency
        metrics = {
            'pose_reward': pose_reward,
            'stability_reward': stability_reward,
            'control_penalty': -ctrl_cost,
            'foot_penalty': -foot_penalty,  # Include foot penalty
            'reward_alive': healthy_reward,
        }
    
        return State(data, obs, reward, done, metrics)

    def _get_obs(self, data: mjx.Data, action: jp.ndarray) -> jp.ndarray:
        """Observes humanoid body position, velocities, and actuator forces."""
        position = data.qpos
        if self._exclude_current_positions_from_observation:
            position = position[2:]  # Remove global x, y coordinates
    
        # ✅ Concatenate all observation components
        return jp.concatenate([
            position,                # Joint positions
            data.qvel,               # Joint velocities
            data.cinert[1:].ravel(), # Center of mass inertia
            data.cvel[1:].ravel(),   # Center of mass velocity
            data.qfrc_actuator,      # Actuator forces
            action                   # Last action taken
        ])

envs.register_environment('humanoid_balance', HumanoidBalance)
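For reference, here is a minimal rollout sketch I'd use to check that the registered env resets and steps under jit (random actions, purely illustrative; not dial-mpc specific):

# Quick sanity check of the registered env (random actions, illustrative only).
env = envs.get_environment('humanoid_balance')
jit_reset = jax.jit(env.reset)
jit_step = jax.jit(env.step)

rng = jax.random.PRNGKey(0)
state = jit_reset(rng)
for _ in range(10):
    rng, act_rng = jax.random.split(rng)
    action = jax.random.uniform(act_rng, (env.action_size,), minval=-1.0, maxval=1.0)
    state = jit_step(state, action)
print(state.reward, state.done)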

Thanks for reading my issue.

best regards,

Ajay

@HaoruXue
Collaborator
Hi @ajaytalati

Thanks for asking. Converting an MJX env into Brax should be fairly straightforward, since Brax uses MJX on the backend. It just adds one layer of abstraction that takes qpos, qvel, etc. out of mjData and converts them into a Brax pipeline state.
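For example, here is a rough sketch of the layer I mean (illustrative only, not dial-mpc code):

import mujoco
import mujoco.mjx as mjx
import jax

# Pure MJX: you manage mjx.Model / mjx.Data yourself.
mj_model = mujoco.MjModel.from_xml_path('humanoid.xml')  # placeholder path
mjx_model = mjx.put_model(mj_model)
mjx_data = mjx.make_data(mjx_model)
mjx_data = jax.jit(mjx.step)(mjx_model, mjx_data)         # one physics step
qpos, qvel = mjx_data.qpos, mjx_data.qvel

# Brax (mjx backend): PipelineEnv holds those objects for you. Inside your
# HumanoidBalance.step you already do the equivalent via
#     data = self.pipeline_step(state.pipeline_state, action)
# which runs several mjx.step calls (n_frames) and returns a state that
# still exposes the same qpos/qvel, plus Brax's q/qd aliases.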

Can you share the specific errors you are encountering with this environment? Even better, maybe a test repo with a side-by-side comparison of your DIAL-MPC environment and the dm_control environment?

Also, FYI, here is a test branch of DIAL-MPC where I removed all the Brax stuff and use pure MJX. Only the unitree_go2_trot environment works right now, but feel free to see how different it is from the main branch (very minimal difference): https://github.com/LeCAR-Lab/dial-mpc/blob/haoru/remove-brax/dial_mpc/envs/unitree_go2_env.py#L36
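The shape of the change is roughly this (a sketch of the idea, not the actual branch code):

import jax
import mujoco.mjx as mjx

def physics_step(mjx_model: mjx.Model, data: mjx.Data, ctrl: jax.Array,
                 n_frames: int = 5) -> mjx.Data:
    # Step raw mjx.Data directly instead of going through Brax's pipeline_step.
    data = data.replace(ctrl=ctrl)
    for _ in range(n_frames):
        data = mjx.step(mjx_model, data)
    return data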
