Skip to content

Commit

Permalink
Merge main to release
Browse files Browse the repository at this point in the history
  • Loading branch information
actions-user committed Aug 7, 2024
2 parents 0e6708e + 0ac0975 commit d67665e
Show file tree
Hide file tree
Showing 29 changed files with 4,283 additions and 2,967 deletions.
1 change: 1 addition & 0 deletions .flake8
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ ignore =
# Black and flake8 conflict here
E203
E704
W503
# Just assume black did a good job with the line lengths
E501

6 changes: 4 additions & 2 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
name: CI

on:
- push
- pull_request
push:
pull_request:
schedule:
- cron: "0 10 * * *"

jobs:
linters_python:
Expand Down
17 changes: 9 additions & 8 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,31 +27,32 @@ classifiers = [
]

dependencies = [
"nrtk",
"accelerate",
"kwcoco",
"nrtk>=0.4.2",
"numpy",
"Pillow",
"scikit-learn==1.4.2",
"accelerate",
"scikit-learn==1.5.1",
"smqtk_image_io",
"tabulate",
"transformers",
"timm",
"timm>=1.0.3",
"torch",
"torchvision",
"trame",
"trame-client>=2.15.0",
"trame-quasar",
"trame-server>=2.15.0",
"transformers",
"umap-learn",
"tabulate",
]

[project.optional-dependencies]
dev = [
"black",
"flake8",
"mypy",
"pytest",
"tabulate",
"mypy",
]

package = [
Expand All @@ -75,7 +76,7 @@ build-backend = "hatchling.build"
[project.scripts]
nrtk-explorer = "nrtk_explorer.app:main"
nrtk-explorer-embeddings = "nrtk_explorer.app.embeddings:embeddings"
nrtk-explorer-tranforms = "nrtk_explorer.app.transforms:transforms"
nrtk-explorer-transforms = "nrtk_explorer.app.transforms:transforms"
nrtk-explorer-filtering = "nrtk_explorer.app.filtering:filtering"

[tool.black]
Expand Down
44 changes: 12 additions & 32 deletions src/nrtk_explorer/app/core.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,24 @@
import logging
from typing import Iterable
from pathlib import Path

from trame.widgets import html
from trame_server.utils.namespace import Translator
from nrtk_explorer.library import images_manager
from nrtk_explorer.library.filtering import FilterProtocol
from nrtk_explorer.library.dataset import get_dataset

from nrtk_explorer.app.embeddings import EmbeddingsApp
from nrtk_explorer.app.transforms import TransformsApp
from nrtk_explorer.app.filtering import FilteringApp
from nrtk_explorer.app.applet import Applet
from nrtk_explorer.app import ui
import nrtk_explorer.test_data
from pathlib import Path

import os

import json
import random


logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

Expand All @@ -35,14 +34,6 @@
]


def image_id_to_meta(image_id):
return f"{image_id}_meta"


def image_id_to_result(image_id):
return f"{image_id}_result"


# ---------------------------------------------------------
# Engine class
# ---------------------------------------------------------
Expand All @@ -65,7 +56,6 @@ def __init__(self, server=None):

self.context["image_objects"] = {}
self.context["images_manager"] = images_manager.ImagesManager()
self.context["annotations"] = {}

self.state.collapse_dataset = False
self.state.collapse_embeddings = False
Expand Down Expand Up @@ -103,7 +93,7 @@ def __init__(self, server=None):
server=self.server.create_child_server(translator=filtering_translator),
)

self._embeddings_app.set_on_select(self._transforms_app.on_selected_images_change)
self._embeddings_app.set_on_select(self._transforms_app.set_selected_dataset_ids)
self._transforms_app.set_on_transform(self._embeddings_app.on_run_transformations)
self._embeddings_app.set_on_hover(self._transforms_app.on_image_hovered)
self._transforms_app.set_on_hover(self._embeddings_app.on_image_hovered)
Expand All @@ -124,8 +114,6 @@ def __init__(self, server=None):

self._build_ui()

self.context.images_manager = images_manager.ImagesManager()

def on_server_ready(self, *args, **kwargs):
# Bind instance methods to state change
self.state.change("current_dataset")(self.on_dataset_change)
Expand All @@ -137,27 +125,22 @@ def on_server_ready(self, *args, **kwargs):
def on_dataset_change(self, **kwargs):
# Reset cache
self.context.images_manager = images_manager.ImagesManager()

with open(self.state.current_dataset) as f:
dataset = json.load(f)

self.state.num_images_max = len(dataset["images"])
self.context.dataset = get_dataset(self.state.current_dataset, force_reload=True)
self.state.num_images_max = len(self.context.dataset.imgs)
self.state.random_sampling_disabled = False
self.state.num_images_disabled = False

self.reload_images()

def on_filter_apply(self, filter: FilterProtocol[Iterable[int]], **kwargs):
selected_indices = []

for index, image_id in enumerate(self.state.images_ids):
image_annotations_categories = map(
lambda annotation: annotation["category_id"],
self.context["annotations"].get(f"img_{image_id}", []),
)

image_annotations_categories = [
annotation["category_id"]
for annotation in self.context.dataset.anns.values()
if annotation["image_id"] == image_id
]
include = filter.evaluate(image_annotations_categories)

if include:
selected_indices.append(index)

Expand All @@ -170,14 +153,11 @@ def on_random_sampling_change(self, **kwargs):
self.reload_images()

def reload_images(self):
with open(self.state.current_dataset) as f:
dataset = json.load(f)

categories = {}
for category in dataset["categories"]:
for category in self.context.dataset.cats.values():
categories[category["id"]] = category

images = dataset["images"]
images = list(self.context.dataset.imgs.values())

selected_images = []
if self.state.num_images:
Expand Down
60 changes: 44 additions & 16 deletions src/nrtk_explorer/app/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@
from nrtk_explorer.library import embeddings_extractor
from nrtk_explorer.library import dimension_reducers
from nrtk_explorer.library import images_manager
from nrtk_explorer.library.dataset import get_dataset
from nrtk_explorer.app.applet import Applet
import nrtk_explorer.test_data

import asyncio
import json
import os

from trame.widgets import quasar, html
Expand Down Expand Up @@ -60,10 +60,11 @@ def on_feature_extraction_model_change(self, **kwargs):

def on_current_dataset_change(self, **kwargs):
self.state.num_elements_disabled = True
with open(self.state.current_dataset) as f:
dataset = json.load(f)
self.images = dataset["images"]
self.state.num_elements_max = len(self.images)
if self.context.dataset is None:
self.context.dataset = get_dataset(self.state.current_dataset, force_reload=True)

self.images = list(self.context.dataset.imgs.values())
self.state.num_elements_max = len(self.images)
self.state.num_elements_disabled = False

if self.is_standalone_app:
Expand Down Expand Up @@ -165,9 +166,20 @@ def set_on_select(self, fn):
self._on_select_fn = fn

def on_select(self, indices):
self.state.user_selected_points_indices = indices
# remap transformed indices to original indices
original_indices = set()
for point_index in indices:
original_image_point_index = point_index
if point_index >= len(self.state.points_sources):
original_image_point_index = self.state.user_selected_points_indices[
point_index - len(self.state.points_sources)
]
original_indices.add(original_image_point_index)
original_indices = list(original_indices)

self.state.user_selected_points_indices = original_indices
self.state.points_transformations = []
ids = [self.state.images_ids[i] for i in indices]
ids = [self.state.images_ids[i] for i in original_indices]
if self._on_select_fn:
self._on_select_fn(ids)

Expand All @@ -177,18 +189,34 @@ def on_move(self, camera_position):
def set_on_hover(self, fn):
self._on_hover_fn = fn

def on_hover(self, point):
self.state.highlighted_point = point
image_id = -1
if point is not None and point in self.state.user_selected_points_indices:
image_id = self.state.images_ids[int(point)]
def on_point_hover(self, point_index):
self.state.highlighted_point = point_index
image_id = ""
if point_index is not None:
original_image_point_index = point_index
if point_index >= len(self.state.points_sources):
image_kind = "transformed_img_"
original_image_point_index = self.state.user_selected_points_indices[
point_index - len(self.state.points_sources)
]
else:
image_kind = "img_"
dataset_id = self.state.images_ids[original_image_point_index]
image_id = f"{image_kind}{dataset_id}"

if self._on_hover_fn:
self._on_hover_fn(image_id)

def on_image_hovered(self, id_, is_transformation):
def on_image_hovered(self, id_):
# If the point is in the list of selected points, we set it as the highlighted point
if id_ in self.state.images_ids:
index = self.state.images_ids.index(id_)
is_transformation = id_.startswith("transformed_img_")
try:
dataset_id = int(id_.split("_")[-1]) # img_123 or transformed_img_123 -> 123
except ValueError:
# id_ probably is an empty string
dataset_id = id_
if dataset_id in self.state.images_ids:
index = self.state.images_ids.index(dataset_id)
if is_transformation:
index_selected = self.state.user_selected_points_indices.index(index)
self.state.highlighted_point = len(self.state.points_sources) + index_selected
Expand All @@ -203,7 +231,7 @@ def visualization_widget(self):
cameraMove="camera_position=$event",
cameraPosition=("camera_position",),
highlightedPoint=("highlighted_point", -1),
hover=(self.on_hover, "[$event]"),
hover=(self.on_point_hover, "[$event]"),
points=("points_sources", []),
transformedPoints=("points_transformations", []),
select=(self.on_select, "[$event]"),
Expand Down
14 changes: 14 additions & 0 deletions src/nrtk_explorer/app/image_ids.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
def image_id_to_dataset_id(image_id: str):
return image_id.split("_")[-1]


def dataset_id_to_image_id(dataset_id: str):
return f"img_{dataset_id}"


def dataset_id_to_transformed_image_id(dataset_id: str):
return f"transformed_img_{dataset_id}"


def image_id_to_result_id(image_id: str):
return f"result_{image_id}"
38 changes: 38 additions & 0 deletions src/nrtk_explorer/app/image_meta.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from typing import TypedDict
from nrtk_explorer.app.trame_utils import delete_state

ImageMetaId = str


def image_id_to_meta(image_id: str) -> ImageMetaId:
return f"meta_{image_id}"


class DatasetImageMeta(TypedDict):
original_ground_to_original_detection_score: float
original_detection_to_transformed_detection_score: float
ground_truth_to_transformed_detection_score: float


PartialDatasetImageMeta = TypedDict(
"PartialDatasetImageMeta", {**DatasetImageMeta.__annotations__}, total=False
)

IMAGE_META_DEFAULTS: DatasetImageMeta = {
"original_ground_to_original_detection_score": 0,
"original_detection_to_transformed_detection_score": 0,
"ground_truth_to_transformed_detection_score": 0,
}


def update_image_meta(state, dataset_id: str, meta_patch: PartialDatasetImageMeta):
meta_key = image_id_to_meta(dataset_id)
current_meta = {}
if state.has(meta_key) and state[meta_key] is not None:
current_meta = state[meta_key]
state[meta_key] = {**IMAGE_META_DEFAULTS, **current_meta, **meta_patch}


def delete_image_meta(state, dataset_id: str):
meta_key = image_id_to_meta(dataset_id)
delete_state(state, meta_key)
Loading

0 comments on commit d67665e

Please sign in to comment.