Skip to content

Commit

Permalink
chore(deps): make kwcoco a required dependency
Browse files: browse the repository at this point in the history
  • Loading branch information
alesgenova authored and PaulHax committed Jan 20, 2025
1 parent 3500aeb commit eaa6df4
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 43 deletions.
5 changes: 1 addition & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,10 @@ dependencies = [
"umap-learn",
"nrtk[headless]>=0.12.0,<=0.16.0",
"trame-annotations>=0.4.0",
"kwcoco",
]

[project.optional-dependencies]
kwcoco= [
"kwcoco",
]

dev = [
"black",
"flake8",
Expand Down
4 changes: 2 additions & 2 deletions src/nrtk_explorer/library/annotations.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from typing import TypedDict, List
from .dataset import JsonDataset
from .dataset import CocoDataset


def get_cat_id(dataset, annotation):
Expand Down Expand Up @@ -33,7 +33,7 @@ class Annotation(TypedDict, total=False):
bbox: List[float]


def to_annotation(dataset: JsonDataset, prediction: Prediction) -> Annotation:
def to_annotation(dataset: CocoDataset, prediction: Prediction) -> Annotation:
annotation: Annotation = {}

if "label" in prediction:
Expand Down
40 changes: 5 additions & 35 deletions src/nrtk_explorer/library/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
import os
from functools import lru_cache
from pathlib import Path
import json
from PIL import Image
import kwcoco
from datasets import (
load_dataset,
get_dataset_infos,
Expand All @@ -36,42 +36,12 @@ def build_cat_index(self):
self.name_to_cat = {cat["name"]: cat for cat in self.cats.values()}


class JsonDataset(BaseDataset, CategoryIndex):
"""JSON-based COCO datasets."""

def __init__(self, path: str):
with open(path) as f:
self.data = json.load(f)
self.fpath = path
self.cats = {cat["id"]: cat for cat in self.data["categories"]}
self.anns = {ann["id"]: ann for ann in self.data["annotations"]}
self.imgs = {img["id"]: img for img in self.data["images"]}
self.build_cat_index()

def _get_image_fpath(self, selected_id: int):
dataset_dir = Path(self.fpath).parent
file_name = self.imgs[selected_id]["file_name"]
return str(dataset_dir / file_name)

class CocoDataset(kwcoco.CocoDataset, BaseDataset):
def get_image(self, id: int):
image_fpath = self._get_image_fpath(id)
image_fpath = self.get_image_fpath(id)
return Image.open(image_fpath)


def make_coco_dataset(path: str):
try:
import kwcoco

class CocoDataset(kwcoco.CocoDataset, BaseDataset):
def get_image(self, id: int):
image_fpath = self.get_image_fpath(id)
return Image.open(image_fpath)

return CocoDataset(path)
except ImportError:
return JsonDataset(path)


def is_coco_dataset(path: str):
if not os.path.exists(path) or os.path.isdir(path):
return False
Expand Down Expand Up @@ -102,7 +72,7 @@ def find_column_name(features, column_names):


class HuggingFaceDataset(BaseDataset, CategoryIndex):
"""Interface for Hugging Face datasets with a similar API to JsonDataset."""
"""Interface for Hugging Face datasets with a similar API to CocoDataset."""

def __init__(self, identifier: str):
self.imgs: dict[str, dict] = {}
Expand Down Expand Up @@ -244,7 +214,7 @@ def get_dataset(identifier: str):
absolute_path = str(Path(identifier).resolve())

if is_coco_dataset(absolute_path):
return make_coco_dataset(absolute_path)
return CocoDataset(absolute_path)

# Assume identifier is a Hugging Face Dataset
return HuggingFaceDataset(identifier)
4 changes: 2 additions & 2 deletions tests/test_dataset.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from nrtk_explorer.library.dataset import get_dataset, JsonDataset
from nrtk_explorer.library.dataset import get_dataset, CocoDataset
import nrtk_explorer.test_data

from pathlib import Path
Expand Down Expand Up @@ -27,7 +27,7 @@ def test_get_dataset_empty():


def test_DefaultDataset(dataset_path):
ds = JsonDataset(dataset_path)
ds = CocoDataset(dataset_path)
assert len(ds.imgs) > 0
assert len(ds.cats) > 0
assert len(ds.anns) > 0
Expand Down

0 comments on commit eaa6df4

Please sign in to comment.