Skip to content

Commit

Permalink
Merge pull request #125 from FragileTech/fix-pil-mac
Browse files Browse the repository at this point in the history
Use opencv when PIL is not available
  • Loading branch information
Guillemdb authored Jan 5, 2025
2 parents 43abd4a + b5ea906 commit 36641e2
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 16 deletions.
18 changes: 6 additions & 12 deletions .github/workflows/push.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,18 +33,12 @@ jobs:
uses: actions/setup-python@v2
with:
python-version: "3.10"
- name: Setup Rye
id: setup-rye
uses: eifinger/setup-rye@v4
- name: Setup ruff
id: setup-ruff
uses: astral-sh/ruff-action@v3
with:
enable-cache: true
cache-prefix: ubuntu-20.04-rye-check-${{ hashFiles('pyproject.toml') }}

- name: Run style check and linter
run: |
set -x
rye fmt --check
rye lint
version-file: "pyproject.toml"
args: "check --diff"

pytest:
name: Run Pytest
Expand Down Expand Up @@ -120,7 +114,7 @@ jobs:
set -x
# TODO: Figure out how to emulate a display in headless machines, and figure out why the commented files fail
# SKIP_RENDER=True rye run pytest tests/test_registry.py tests/videogames/test_retro.py
SKIP_RENDER=True rye run pytest tests/control tests/videogames/test_atari.py tests/videogames/test_nes.py tests/test_core.py
SKIP_RENDER=True rye run pytest tests/control tests/videogames/test_atari.py tests/videogames/test_nes.py tests/test_core.py tests/test_utils.py
- name: Run code coverage on Ubuntu
if: ${{ matrix.os == 'ubuntu-latest' }}
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ classifiers = [
]
dependencies = [
"numpy",
"pillow",
"pillow; sys_platform != 'darwin'",
"fragile-gym",
"opencv-python>=4.10.0.84",
"pyglet==1.5.11",
Expand Down
68 changes: 66 additions & 2 deletions src/plangym/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,15 @@
from gymnasium.spaces import Box
from gymnasium.wrappers.time_limit import TimeLimit
import numpy
from PIL import Image
from pyvirtualdisplay import Display
import cv2

try:
from PIL import Image

USE_PIL = True
except ImportError: # pragma: no cover
USE_PIL = False


def get_display(visible=False, size=(400, 400), **kwargs):
Expand Down Expand Up @@ -59,7 +66,7 @@ def remove_time_limit(gym_env: gym.Env) -> gym.Env:
return gym_env


def process_frame(
def process_frame_pil(
frame: numpy.ndarray,
width: int | None = None,
height: int | None = None,
Expand All @@ -80,13 +87,70 @@ def process_frame(
The resized frame that matches the provided width and height.
"""
mode = "L" if mode == "GRAY" else mode
height = height or frame.shape[0]
width = width or frame.shape[1]
frame = Image.fromarray(frame)
frame = frame.convert(mode).resize(size=(width, height))
return numpy.array(frame)


def process_frame_opencv(
frame: numpy.ndarray,
width: int | None = None,
height: int | None = None,
mode: str = "RGB",
) -> numpy.ndarray:
"""Resize an RGB frame to a specified shape and mode.
Use OpenCV to resize an RGB frame to a specified height and width \
or changing it to a different mode.
Args:
frame: Target numpy array representing the image that will be resized.
width: Width of the resized image.
height: Height of the resized image.
mode: Passed to cv2.cvtColor.
Returns:
The resized frame that matches the provided width and height.
"""
height = height or frame.shape[0]
width = width or frame.shape[1]
frame = cv2.resize(frame, (width, height))
if mode in {"GRAY", "L"}:
frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
elif mode == "BGR":
frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
return frame


def process_frame(
frame: numpy.ndarray,
width: int | None = None,
height: int | None = None,
mode: str = "RGB",
) -> numpy.ndarray:
"""Resize an RGB frame to a specified shape and mode.
Use either PIL or OpenCV to resize an RGB frame to a specified height and width \
or changing it to a different mode.
Args:
frame: Target numpy array representing the image that will be resized.
width: Width of the resized image.
height: Height of the resized image.
mode: Passed to either Image.convert or cv2.cvtColor.
Returns:
The resized frame that matches the provided width and height.
"""
func = process_frame_pil if USE_PIL else process_frame_opencv # pragma: no cover
return func(frame, width, height, mode)


class GrayScaleObservation(gym.ObservationWrapper, gym.utils.RecordConstructorArgs):
"""Convert the image observation from RGB to gray scale.
Expand Down
14 changes: 13 additions & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import numpy
from numpy.random import default_rng

from plangym.utils import process_frame, remove_time_limit
from plangym.utils import process_frame, remove_time_limit, process_frame_opencv

rng = default_rng()

Expand All @@ -27,3 +27,15 @@ def test_process_frame():
assert frame.shape == (50, 30, 3)
frame = process_frame(example, width=80, height=70, mode="L")
assert frame.shape == (70, 80)


def test_process_frame_opencv():
example = (rng.random((100, 100, 3)) * 255).astype(numpy.uint8)
frame = process_frame_opencv(example, mode="L")
assert frame.shape == (100, 100)
frame = process_frame_opencv(example, width=30, height=50)
assert frame.shape == (50, 30, 3)
frame = process_frame_opencv(example, width=30, height=50, mode="BGR")
assert frame.shape == (50, 30, 3)
frame = process_frame_opencv(example, width=80, height=70, mode="GRAY")
assert frame.shape == (70, 80)

0 comments on commit 36641e2

Please sign in to comment.