From eaa6df4cd278d4b4a5b8d9c5d1986acbc08588d1 Mon Sep 17 00:00:00 2001 From: Alessandro Genova Date: Fri, 10 Jan 2025 14:59:30 -0500 Subject: [PATCH] chore(deps): make kwcoco a required dependency --- pyproject.toml | 5 +-- src/nrtk_explorer/library/annotations.py | 4 +-- src/nrtk_explorer/library/dataset.py | 40 +++--------------------- tests/test_dataset.py | 4 +-- 4 files changed, 10 insertions(+), 43 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index d102fc9..3d37ff2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,13 +44,10 @@ dependencies = [ "umap-learn", "nrtk[headless]>=0.12.0,<=0.16.0", "trame-annotations>=0.4.0", + "kwcoco", ] [project.optional-dependencies] -kwcoco= [ - "kwcoco", -] - dev = [ "black", "flake8", diff --git a/src/nrtk_explorer/library/annotations.py b/src/nrtk_explorer/library/annotations.py index baae568..7457f4c 100644 --- a/src/nrtk_explorer/library/annotations.py +++ b/src/nrtk_explorer/library/annotations.py @@ -1,5 +1,5 @@ from typing import TypedDict, List -from .dataset import JsonDataset +from .dataset import CocoDataset def get_cat_id(dataset, annotation): @@ -33,7 +33,7 @@ class Annotation(TypedDict, total=False): bbox: List[float] -def to_annotation(dataset: JsonDataset, prediction: Prediction) -> Annotation: +def to_annotation(dataset: CocoDataset, prediction: Prediction) -> Annotation: annotation: Annotation = {} if "label" in prediction: diff --git a/src/nrtk_explorer/library/dataset.py b/src/nrtk_explorer/library/dataset.py index 304f0fc..df98c30 100644 --- a/src/nrtk_explorer/library/dataset.py +++ b/src/nrtk_explorer/library/dataset.py @@ -10,8 +10,8 @@ import os from functools import lru_cache from pathlib import Path -import json from PIL import Image +import kwcoco from datasets import ( load_dataset, get_dataset_infos, @@ -36,42 +36,12 @@ def build_cat_index(self): self.name_to_cat = {cat["name"]: cat for cat in self.cats.values()} -class JsonDataset(BaseDataset, CategoryIndex): - """JSON-based COCO datasets.""" - - def __init__(self, path: str): - with open(path) as f: - self.data = json.load(f) - self.fpath = path - self.cats = {cat["id"]: cat for cat in self.data["categories"]} - self.anns = {ann["id"]: ann for ann in self.data["annotations"]} - self.imgs = {img["id"]: img for img in self.data["images"]} - self.build_cat_index() - - def _get_image_fpath(self, selected_id: int): - dataset_dir = Path(self.fpath).parent - file_name = self.imgs[selected_id]["file_name"] - return str(dataset_dir / file_name) - +class CocoDataset(kwcoco.CocoDataset, BaseDataset): def get_image(self, id: int): - image_fpath = self._get_image_fpath(id) + image_fpath = self.get_image_fpath(id) return Image.open(image_fpath) -def make_coco_dataset(path: str): - try: - import kwcoco - - class CocoDataset(kwcoco.CocoDataset, BaseDataset): - def get_image(self, id: int): - image_fpath = self.get_image_fpath(id) - return Image.open(image_fpath) - - return CocoDataset(path) - except ImportError: - return JsonDataset(path) - - def is_coco_dataset(path: str): if not os.path.exists(path) or os.path.isdir(path): return False @@ -102,7 +72,7 @@ def find_column_name(features, column_names): class HuggingFaceDataset(BaseDataset, CategoryIndex): - """Interface for Hugging Face datasets with a similar API to JsonDataset.""" + """Interface for Hugging Face datasets with a similar API to CocoDataset.""" def __init__(self, identifier: str): self.imgs: dict[str, dict] = {} @@ -244,7 +214,7 @@ def get_dataset(identifier: str): absolute_path = str(Path(identifier).resolve()) if is_coco_dataset(absolute_path): - return make_coco_dataset(absolute_path) + return CocoDataset(absolute_path) # Assume identifier is a Hugging Face Dataset return HuggingFaceDataset(identifier) diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 687902c..e7c4563 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -1,4 +1,4 @@ -from nrtk_explorer.library.dataset import get_dataset, JsonDataset +from nrtk_explorer.library.dataset import get_dataset, CocoDataset import nrtk_explorer.test_data from pathlib import Path @@ -27,7 +27,7 @@ def test_get_dataset_empty(): def test_DefaultDataset(dataset_path): - ds = JsonDataset(dataset_path) + ds = CocoDataset(dataset_path) assert len(ds.imgs) > 0 assert len(ds.cats) > 0 assert len(ds.anns) > 0