From b5ea9068ca74abba4fa64788032a114f2b1486a2 Mon Sep 17 00:00:00 2001 From: guillemdb Date: Sun, 5 Jan 2025 13:45:55 +0100 Subject: [PATCH] Use opencv when PIL is not available Signed-off-by: guillemdb --- .github/workflows/push.yml | 18 ++++------ pyproject.toml | 2 +- src/plangym/utils.py | 68 ++++++++++++++++++++++++++++++++++++-- tests/test_utils.py | 14 +++++++- 4 files changed, 86 insertions(+), 16 deletions(-) diff --git a/.github/workflows/push.yml b/.github/workflows/push.yml index f341702..0129b79 100644 --- a/.github/workflows/push.yml +++ b/.github/workflows/push.yml @@ -33,18 +33,12 @@ jobs: uses: actions/setup-python@v2 with: python-version: "3.10" - - name: Setup Rye - id: setup-rye - uses: eifinger/setup-rye@v4 + - name: Setup ruff + id: setup-ruff + uses: astral-sh/ruff-action@v3 with: - enable-cache: true - cache-prefix: ubuntu-20.04-rye-check-${{ hashFiles('pyproject.toml') }} - - - name: Run style check and linter - run: | - set -x - rye fmt --check - rye lint + version-file: "pyproject.toml" + args: "check --diff" pytest: name: Run Pytest @@ -120,7 +114,7 @@ jobs: set -x # TODO: Figure out how to emulate a display in headless machines, and figure out why the commented files fail # SKIP_RENDER=True rye run pytest tests/test_registry.py tests/videogames/test_retro.py - SKIP_RENDER=True rye run pytest tests/control tests/videogames/test_atari.py tests/videogames/test_nes.py tests/test_core.py + SKIP_RENDER=True rye run pytest tests/control tests/videogames/test_atari.py tests/videogames/test_nes.py tests/test_core.py tests/test_utils.py - name: Run code coverage on Ubuntu if: ${{ matrix.os == 'ubuntu-latest' }} diff --git a/pyproject.toml b/pyproject.toml index 142a2cc..7265c5b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,7 +22,7 @@ classifiers = [ ] dependencies = [ "numpy", - "pillow", + "pillow; sys_platform != 'darwin'", "fragile-gym", "opencv-python>=4.10.0.84", "pyglet==1.5.11", diff --git a/src/plangym/utils.py b/src/plangym/utils.py index 2b11c7e..53494ae 100644 --- a/src/plangym/utils.py +++ b/src/plangym/utils.py @@ -6,8 +6,15 @@ from gymnasium.spaces import Box from gymnasium.wrappers.time_limit import TimeLimit import numpy -from PIL import Image from pyvirtualdisplay import Display +import cv2 + +try: + from PIL import Image + + USE_PIL = True +except ImportError: # pragma: no cover + USE_PIL = False def get_display(visible=False, size=(400, 400), **kwargs): @@ -59,7 +66,7 @@ def remove_time_limit(gym_env: gym.Env) -> gym.Env: return gym_env -def process_frame( +def process_frame_pil( frame: numpy.ndarray, width: int | None = None, height: int | None = None, @@ -80,6 +87,7 @@ def process_frame( The resized frame that matches the provided width and height. """ + mode = "L" if mode == "GRAY" else mode height = height or frame.shape[0] width = width or frame.shape[1] frame = Image.fromarray(frame) @@ -87,6 +95,62 @@ def process_frame( return numpy.array(frame) +def process_frame_opencv( + frame: numpy.ndarray, + width: int | None = None, + height: int | None = None, + mode: str = "RGB", +) -> numpy.ndarray: + """Resize an RGB frame to a specified shape and mode. + + Use OpenCV to resize an RGB frame to a specified height and width \ + or changing it to a different mode. + + Args: + frame: Target numpy array representing the image that will be resized. + width: Width of the resized image. + height: Height of the resized image. + mode: Passed to cv2.cvtColor. + + Returns: + The resized frame that matches the provided width and height. + + """ + height = height or frame.shape[0] + width = width or frame.shape[1] + frame = cv2.resize(frame, (width, height)) + if mode in {"GRAY", "L"}: + frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY) + elif mode == "BGR": + frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR) + return frame + + +def process_frame( + frame: numpy.ndarray, + width: int | None = None, + height: int | None = None, + mode: str = "RGB", +) -> numpy.ndarray: + """Resize an RGB frame to a specified shape and mode. + + Use either PIL or OpenCV to resize an RGB frame to a specified height and width \ + or changing it to a different mode. + + Args: + frame: Target numpy array representing the image that will be resized. + width: Width of the resized image. + height: Height of the resized image. + mode: Passed to either Image.convert or cv2.cvtColor. + + Returns: + The resized frame that matches the provided width and height. + + """ + func = process_frame_pil if USE_PIL else process_frame_opencv # pragma: no cover + return func(frame, width, height, mode) + + class GrayScaleObservation(gym.ObservationWrapper, gym.utils.RecordConstructorArgs): """Convert the image observation from RGB to gray scale. diff --git a/tests/test_utils.py b/tests/test_utils.py index 376d66b..a0da404 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -5,7 +5,7 @@ import numpy from numpy.random import default_rng -from plangym.utils import process_frame, remove_time_limit +from plangym.utils import process_frame, remove_time_limit, process_frame_opencv rng = default_rng() @@ -27,3 +27,15 @@ def test_process_frame(): assert frame.shape == (50, 30, 3) frame = process_frame(example, width=80, height=70, mode="L") assert frame.shape == (70, 80) + + +def test_process_frame_opencv(): + example = (rng.random((100, 100, 3)) * 255).astype(numpy.uint8) + frame = process_frame_opencv(example, mode="L") + assert frame.shape == (100, 100) + frame = process_frame_opencv(example, width=30, height=50) + assert frame.shape == (50, 30, 3) + frame = process_frame_opencv(example, width=30, height=50, mode="BGR") + assert frame.shape == (50, 30, 3) + frame = process_frame_opencv(example, width=80, height=70, mode="GRAY") + assert frame.shape == (70, 80)