Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix sorting on scores #164

Merged
merged 7 commits into from
Jan 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 8 additions & 7 deletions src/nrtk_explorer/app/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ def __init__(self, server=None, **kwargs):
self.ctrl.on_server_reload = self._build_ui
self.ctrl.add("on_server_ready")(self.on_server_ready)

self.state.num_images = NUM_IMAGES_DEFAULT
self.state.num_images_max = 0
self.state.num_images_disabled = True
self.state.random_sampling = False
Expand Down Expand Up @@ -164,7 +165,7 @@ def on_dataset_change(self, **kwargs):
self.state.dataset_ids = [] # sampled images
self.context.dataset = get_dataset(self.state.current_dataset)
self.state.num_images_max = len(self.context.dataset.imgs)
self.state.num_images = min(self.state.num_images_max, NUM_IMAGES_DEFAULT)
self.state.num_images = min(self.state.num_images_max, self.state.num_images)
self.state.dirty("num_images") # Trigger resample_images()
self.state.random_sampling_disabled = False
self.state.num_images_disabled = False
Expand All @@ -188,19 +189,19 @@ def on_filter_apply(self, filter: FilterProtocol[Iterable[int]], **kwargs):
self._embeddings_app.on_select(selected_ids)

def resample_images(self, **kwargs):
images = list(self.context.dataset.imgs.values())
ids = [image["id"] for image in self.context.dataset.imgs.values()]

selected_images = []
if self.state.num_images:
if self.state.random_sampling:
selected_images = random.sample(images, min(len(images), self.state.num_images))
selected_images = random.sample(ids, min(len(ids), self.state.num_images))
else:
selected_images = images[: self.state.num_images]
selected_images = ids[: self.state.num_images]
else:
selected_images = images
selected_images = ids

self.context.dataset_ids = [img["id"] for img in selected_images]
self.state.dataset_ids = [str(image_id) for image_id in self.context.dataset_ids]
self.context.dataset_ids = selected_images
self.state.dataset_ids = [str(id) for id in self.context.dataset_ids]
self.state.user_selected_ids = self.state.dataset_ids

def _build_ui(self):
Expand Down
46 changes: 20 additions & 26 deletions src/nrtk_explorer/app/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from nrtk_explorer.library import embeddings_extractor
from nrtk_explorer.library import dimension_reducers
from nrtk_explorer.library.dataset import get_dataset
from nrtk_explorer.library.scoring import partition
from nrtk_explorer.app.applet import Applet

from nrtk_explorer.app.images.image_ids import (
Expand Down Expand Up @@ -39,7 +40,7 @@ def __init__(
if self.is_standalone_app and datasets:
self.state.dataset_ids = []
self.state.current_dataset = datasets[0]
self.on_current_dataset_change()
self.context.dataset = get_dataset(self.state.current_dataset)

self.features = None

Expand All @@ -54,35 +55,21 @@ def __init__(
}

def on_server_ready(self, *args, **kwargs):
# Bind instance methods to state change
self.on_current_dataset_change()
self.state.change("current_dataset")(self.on_current_dataset_change)

self.on_feature_extraction_model_change()
self.state.change("feature_extraction_model")(self.on_feature_extraction_model_change)

self.update_points()
self.state.change("dataset_ids")(self.update_points)

self.server.controller.apply_transform.add(self.clear_points_transformations)
self.state.change("transform_enabled_switch")(
self.update_points_transformations_visibility
)
self.state.change("transform_enabled_switch")(self.update_points_transformations_state)

def on_feature_extraction_model_change(self, **kwargs):
    """Rebuild the embeddings extractor whenever a new model is picked in the UI."""
    self.extractor = embeddings_extractor.EmbeddingsExtractor(
        model_name=self.state.feature_extraction_model
    )

def on_current_dataset_change(self, **kwargs):
self.state.num_elements_disabled = True
if self.context.dataset is None:
self.context.dataset = get_dataset(self.state.current_dataset)

self.state.num_elements_max = len(list(self.context.dataset.imgs))
self.state.num_elements_disabled = False

def compute_points(self, fit_features, features):
if len(features) == 0:
# reduce will fail if no features
Expand Down Expand Up @@ -115,19 +102,19 @@ def compute_points(self, fit_features, features):
)

def clear_points_transformations(self, **kwargs):
self.state.points_transformations = {} # ID to point
self.state.points_transformations = {} # dataset ID to point
self._stashed_points_transformations = {}

def update_points_transformations_visibility(self, **kwargs):
def update_points_transformations_state(self, **kwargs):
if self.state.transform_enabled_switch:
self.state.points_transformations = self._stashed_points_transformations
else:
self._stashed_points_transformations = self.state.points_transformations
self.state.points_transformations = {}

async def compute_source_points(self):
with self.state:
self.state.is_loading = True
self.clear_points_transformations()

# Don't lock server before enabling the spinner on client
await self.server.network_completion
Expand All @@ -146,8 +133,6 @@ async def compute_source_points(self):
id: point for id, point in zip(self.state.dataset_ids, points)
}

self.clear_points_transformations()

self.state.camera_position = []

with self.state:
Expand All @@ -162,16 +147,25 @@ def on_run_clicked(self):
self.update_points()

def on_run_transformations(self, id_to_image):
hits, misses = partition(
lambda id: image_id_to_dataset_id(id) in self._stashed_points_transformations,
id_to_image.keys(),
)

to_plot = {id: id_to_image[id] for id in misses}
transformation_features = self.extractor.extract(
id_to_image.values(),
list(to_plot.values()),
batch_size=int(self.state.model_batch_size),
)

points = self.compute_points(self.features, transformation_features)
ids_to_points = zip(to_plot.keys(), points)

ids = id_to_image.keys()
updated_points = {image_id_to_dataset_id(id): point for id, point in zip(ids, points)}
self.state.points_transformations = {**self.state.points_transformations, **updated_points}
updated_points = {image_id_to_dataset_id(id): point for id, point in ids_to_points}
self._stashed_points_transformations = {
**self._stashed_points_transformations,
**updated_points,
}
self.update_points_transformations_state()

# called by category filter
def on_select(self, image_ids):
Expand Down
18 changes: 9 additions & 9 deletions src/nrtk_explorer/app/images/annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from functools import lru_cache, partial
from PIL import Image
from nrtk_explorer.app.images.cache import LruCache
from nrtk_explorer.library.object_detector import ObjectDetector
from nrtk_explorer.library.multiprocess_predictor import MultiprocessPredictor
from nrtk_explorer.library.scoring import partition


Expand Down Expand Up @@ -67,21 +67,21 @@ def __init__(
self.add_to_cache_callback = add_to_cache_callback
self.delete_from_cache_callback = delete_from_cache_callback

def get_annotations(self, detector: ObjectDetector, id_to_image: Dict[str, Image.Image]):
hits, misses = partition(self.cache.get_item, id_to_image.keys())
cached_predictions = {id: self.cache.get_item(id) for id in hits}
async def get_annotations(
self, predictor: MultiprocessPredictor, id_to_image: Dict[str, Image.Image]
):
hits, misses = partition(
lambda id: self.cache.get_item(id) is not None, id_to_image.keys()
)

to_detect = {id: id_to_image[id] for id in misses}
predictions = detector.eval(
to_detect,
)
predictions = await predictor.infer(to_detect)
for id, annotations in predictions.items():
self.cache.add_item(
id, annotations, self.add_to_cache_callback, self.delete_from_cache_callback
)

predictions.update(**cached_predictions)
return predictions
return {id: self.cache.get_item(id) for id in id_to_image.keys()}

def cache_clear(self):
self.cache.clear()
15 changes: 13 additions & 2 deletions src/nrtk_explorer/app/images/images.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def convert_to_base64(img: Image.Image) -> str:
return "data:image/png;base64," + base64.b64encode(buf.getvalue()).decode()


IMAGE_CACHE_SIZE = 200
IMAGE_CACHE_SIZE = 500


@TrameApp()
Expand Down Expand Up @@ -74,11 +74,15 @@ def _load_transformed_image(self, transform: ImageTransform, dataset_id: str):
return transformed.resize(original.size)
return transformed

def get_transformed_image(self, transform: ImageTransform, dataset_id: str, **kwargs):
def _get_transformed_image(self, transform: ImageTransform, dataset_id: str, **kwargs):
    """Return (transformed_image_id, image), computing the image on a cache miss.

    Does not insert into the cache itself; callers choose the caching policy.
    """
    key = dataset_id_to_transformed_image_id(dataset_id)
    cached = self.transformed_images.get_item(key)
    if cached:
        return key, cached
    # Cache miss (or falsy entry): apply the transform now.
    return key, self._load_transformed_image(transform, dataset_id)

def get_transformed_image(self, transform: ImageTransform, dataset_id: str, **kwargs):
    """Return the transformed image for *dataset_id* and insert it into the cache.

    Insertion via add_item may evict other entries when the cache is full.
    """
    key, img = self._get_transformed_image(transform, dataset_id, **kwargs)
    self.transformed_images.add_item(key, img, **kwargs)
    return img

Expand All @@ -90,6 +94,13 @@ def get_stateful_transformed_image(self, transform: ImageTransform, dataset_id:
on_clear_item=self._delete_from_state,
)

def get_transformed_image_without_cache_eviction(
    self, transform: ImageTransform, dataset_id: str
):
    """Return the transformed image, caching it only if the cache has free room.

    Unlike get_transformed_image, this never evicts existing cache entries.
    """
    key, img = self._get_transformed_image(transform, dataset_id)
    self.transformed_images.add_if_room(key, img)
    return img

@change("current_dataset")
def clear_all(self, **kwargs):
self.original_images.clear()
Expand Down
Loading
Loading