Skip to content

Commit

Permalink
Use opencv when PIL is not available
Browse files Browse the repository at this point in the history
Signed-off-by: guillemdb <[email protected]>
  • Loading branch information
Guillemdb committed Jan 5, 2025
1 parent 43abd4a commit 77f5304
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 15 deletions.
18 changes: 6 additions & 12 deletions .github/workflows/push.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,18 +33,12 @@ jobs:
uses: actions/setup-python@v2
with:
python-version: "3.10"
- name: Setup Rye
id: setup-rye
uses: eifinger/setup-rye@v4
- name: Setup ruff
id: setup-ruff
uses: astral-sh/ruff-action@v3
with:
enable-cache: true
cache-prefix: ubuntu-20.04-rye-check-${{ hashFiles('pyproject.toml') }}

- name: Run style check and linter
run: |
set -x
rye fmt --check
rye lint
version-file: "pyproject.toml"
args: "check --diff"

pytest:
name: Run Pytest
Expand Down Expand Up @@ -120,7 +114,7 @@ jobs:
set -x
# TODO: Figure out how to emulate a display in headless machines, and figure out why the commented files fail
# SKIP_RENDER=True rye run pytest tests/test_registry.py tests/videogames/test_retro.py
SKIP_RENDER=True rye run pytest tests/control tests/videogames/test_atari.py tests/videogames/test_nes.py tests/test_core.py
SKIP_RENDER=True rye run pytest tests/control tests/videogames/test_atari.py tests/videogames/test_nes.py tests/test_core.py tests/test_utils.py
- name: Run code coverage on Ubuntu
if: ${{ matrix.os == 'ubuntu-latest' }}
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ classifiers = [
]
dependencies = [
"numpy",
"pillow",
"pillow; sys_platform != 'darwin'",
"fragile-gym",
"opencv-python>=4.10.0.84",
"pyglet==1.5.11",
Expand Down
68 changes: 66 additions & 2 deletions src/plangym/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,17 @@
from gymnasium.spaces import Box
from gymnasium.wrappers.time_limit import TimeLimit
import numpy
from PIL import Image
from pyvirtualdisplay import Display

try:
from PIL import Image

USE_PIL = True
except ImportError: # pragma: no cover
import cv2

USE_PIL = False


def get_display(visible=False, size=(400, 400), **kwargs):
"""Start a virtual display."""
Expand Down Expand Up @@ -59,7 +67,7 @@ def remove_time_limit(gym_env: gym.Env) -> gym.Env:
return gym_env


def process_frame(
def process_frame_pil(
frame: numpy.ndarray,
width: int | None = None,
height: int | None = None,
Expand Down Expand Up @@ -87,6 +95,62 @@ def process_frame(
return numpy.array(frame)


def process_frame_opencv(
frame: numpy.ndarray,
width: int | None = None,
height: int | None = None,
mode: str = "RGB",
) -> numpy.ndarray:
"""Resize an RGB frame to a specified shape and mode.
Use OpenCV to resize an RGB frame to a specified height and width \
or changing it to a different mode.
Args:
frame: Target numpy array representing the image that will be resized.
width: Width of the resized image.
height: Height of the resized image.
mode: Passed to cv2.cvtColor.
Returns:
The resized frame that matches the provided width and height.
"""
height = height or frame.shape[0]
width = width or frame.shape[1]
frame = cv2.resize(frame, (width, height))

Check warning on line 121 in src/plangym/utils.py

View check run for this annotation

Codecov / codecov/patch

src/plangym/utils.py#L119-L121

Added lines #L119 - L121 were not covered by tests
if mode == "GRAY":
frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)

Check warning on line 123 in src/plangym/utils.py

View check run for this annotation

Codecov / codecov/patch

src/plangym/utils.py#L123

Added line #L123 was not covered by tests
elif mode == "BGR":
frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
return frame

Check warning on line 126 in src/plangym/utils.py

View check run for this annotation

Codecov / codecov/patch

src/plangym/utils.py#L125-L126

Added lines #L125 - L126 were not covered by tests


def process_frame(
frame: numpy.ndarray,
width: int | None = None,
height: int | None = None,
mode: str = "RGB",
) -> numpy.ndarray:
"""Resize an RGB frame to a specified shape and mode.
Use either PIL or OpenCV to resize an RGB frame to a specified height and width \
or changing it to a different mode.
Args:
frame: Target numpy array representing the image that will be resized.
width: Width of the resized image.
height: Height of the resized image.
mode: Passed to either Image.convert or cv2.cvtColor.
Returns:
The resized frame that matches the provided width and height.
"""
func = process_frame_pil if USE_PIL else process_frame_opencv
return func(frame, width, height, mode)


class GrayScaleObservation(gym.ObservationWrapper, gym.utils.RecordConstructorArgs):
"""Convert the image observation from RGB to gray scale.
Expand Down

0 comments on commit 77f5304

Please sign in to comment.