diff --git a/.github/unittest/linux_libs/scripts_habitat/environment.yml b/.github/unittest/linux_libs/scripts_habitat/environment.yml index feab90cd052..256d8edbf64 100644 --- a/.github/unittest/linux_libs/scripts_habitat/environment.yml +++ b/.github/unittest/linux_libs/scripts_habitat/environment.yml @@ -18,3 +18,4 @@ dependencies: - scipy==1.9.1 - hydra-core - ninja + - numpy<2.0 diff --git a/.github/unittest/linux_libs/scripts_habitat/install.sh b/.github/unittest/linux_libs/scripts_habitat/install.sh index 6b948279a61..fac059bc4d5 100755 --- a/.github/unittest/linux_libs/scripts_habitat/install.sh +++ b/.github/unittest/linux_libs/scripts_habitat/install.sh @@ -21,9 +21,9 @@ git submodule sync && git submodule update --init --recursive printf "Installing PyTorch with %s\n" "${CU_VERSION}" if [[ "$TORCH_VERSION" == "nightly" ]]; then - pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121 -U + pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cu121 -U elif [[ "$TORCH_VERSION" == "stable" ]]; then - pip3 install torch --index-url https://download.pytorch.org/whl/cu121 + pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cu121 fi # install tensordict diff --git a/.github/unittest/linux_libs/scripts_habitat/run_all.sh b/.github/unittest/linux_libs/scripts_habitat/run_all.sh index 69204455a69..3e109ee78bb 100755 --- a/.github/unittest/linux_libs/scripts_habitat/run_all.sh +++ b/.github/unittest/linux_libs/scripts_habitat/run_all.sh @@ -5,14 +5,8 @@ set -v apt-get update && apt-get upgrade -y -apt-get install -y vim git wget +apt-get install -y g++ gcc vim git wget libglfw3 libgl1-mesa-glx libosmesa6 libglew-dev libglvnd0 libgl1 libglx0 libegl1 libgles2 -apt-get install -y libglfw3 libgl1-mesa-glx libosmesa6 libglew-dev -apt-get install -y libglvnd0 libgl1 libglx0 libegl1 libgles2 - -apt-get install -y g++ gcc -#apt-get upgrade -y libstdc++6 -#apt-get install -y libgcc apt-get dist-upgrade -y this_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" diff --git a/.github/unittest/linux_libs/scripts_habitat/run_test.sh b/.github/unittest/linux_libs/scripts_habitat/run_test.sh index a60fffd8f45..f790e757d3d 100755 --- a/.github/unittest/linux_libs/scripts_habitat/run_test.sh +++ b/.github/unittest/linux_libs/scripts_habitat/run_test.sh @@ -42,6 +42,12 @@ conda deactivate && conda activate ./env # this workflow only tests the libs +mkdir data +git lfs update + +python -m habitat_sim.utils.datasets_download --uids rearrange_pick_dataset_v0 rearrange_task_assets --data-path ./data --no-prune +#python -m habitat_sim.utils.datasets_download --uids rearrange_task_assets --data-path ./data --no-prune + python -c "import habitat;import habitat.gym" python -c """from torchrl.envs.libs.habitat import HabitatEnv env = HabitatEnv('HabitatRenderPick-v0') diff --git a/.github/unittest/linux_libs/scripts_habitat/setup_env.sh b/.github/unittest/linux_libs/scripts_habitat/setup_env.sh index 6ad970c3f47..075931be71c 100755 --- a/.github/unittest/linux_libs/scripts_habitat/setup_env.sh +++ b/.github/unittest/linux_libs/scripts_habitat/setup_env.sh @@ -70,6 +70,7 @@ conda env update --file "${this_dir}/environment.yml" --prune conda install habitat-sim withbullet headless -c conda-forge -c aihabitat -y git clone https://github.com/facebookresearch/habitat-lab.git cd habitat-lab -pip3 install -e habitat-lab -pip3 install -e habitat-baselines # install habitat_baselines +echo "numpy<2.0" > constraints.txt +pip3 install -e habitat-lab --constraint constraints.txt +pip3 install -e habitat-baselines --constraint constraints.txt # install habitat_baselines conda run python -m pip install "gym[atari,accept-rom-license]" pygame diff --git a/test/test_libs.py b/test/test_libs.py index b3ba8d54c3d..4046dabfb8e 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -1566,6 +1566,14 @@ def test_habitat(self, envname): _ = env.rollout(3) check_env_specs(env) + def test_from_config(self): + import habitat + + cfg = habitat.get_config("benchmark/nav/objectnav/objectnav_hssd-hab.yaml") + env = HabitatEnv.from_config(cfg) + check_env_specs(env) + assert isinstance(env, HabitatEnv) + @pytest.mark.parametrize("from_pixels", [True, False]) def test_habitat_render(self, envname, from_pixels): env = HabitatEnv(envname, from_pixels=from_pixels) diff --git a/torchrl/envs/libs/habitat.py b/torchrl/envs/libs/habitat.py index 4180c42b2dc..3fc9672dabb 100644 --- a/torchrl/envs/libs/habitat.py +++ b/torchrl/envs/libs/habitat.py @@ -9,7 +9,7 @@ from torchrl._utils import _make_ordinal_device from torchrl.data.utils import DEVICE_TYPING from torchrl.envs.common import EnvBase -from torchrl.envs.libs.gym import GymEnv, set_gym_backend +from torchrl.envs.libs.gym import GymEnv, GymWrapper, set_gym_backend from torchrl.envs.utils import _classproperty _has_habitat = importlib.util.find_spec("habitat") is not None @@ -106,6 +106,25 @@ def __init__(self, env_name, **kwargs): ] super().__init__(env_name=env_name, **kwargs) + @classmethod + def from_config(cls, cfg): + """Creates a HabitatEnv from the config. + + Examples: + >>> config = habitat.get_config( + ... "benchmark/nav/objectnav/objectnav_hssd-hab.yaml" + ... ) + >>> env = HabitatEnv(config) + + """ + import habitat.gym + + wrapper = cls.__new__(cls) + env = habitat.gym.make_gym_from_config(cfg) + env.reset() # Prevents crash from TorchRL GymWrapper + wrapper.__dict__.update(GymWrapper(env).__dict__) + return wrapper + @_classproperty def available_envs(cls): if not _has_habitat: