Skip to content

Commit

Permalink
fix(embeddings): save embeddings params on compute
Browse files Browse the repository at this point in the history
and use saved params when plotting new transformed images.

Also, cache transformed images to re-plot them when
embedding params are changed.

Closes #170
Closes #171
  • Loading branch information
PaulHax committed Jan 17, 2025
1 parent 3e5fa3c commit dee044e
Showing 1 changed file with 96 additions and 30 deletions.
126 changes: 96 additions & 30 deletions src/nrtk_explorer/app/embeddings.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from typing import Dict
from trame.decorators import TrameApp, change
from PIL import Image
from nrtk_explorer.widgets.nrtk_explorer import ScatterPlot
from nrtk_explorer.library import embeddings_extractor
from nrtk_explorer.library import dimension_reducers
from nrtk_explorer.library.dataset import get_dataset
from nrtk_explorer.library.scoring import partition
from nrtk_explorer.app.applet import Applet

from nrtk_explorer.app.images.image_ids import (
Expand All @@ -20,6 +22,41 @@
from trame.app import get_server, asynchronous


IdToImage = Dict[str, Image.Image]


@TrameApp()
class TransformedImages:
    """Cache of transformed images keyed by image id.

    Every method that changes the cache ends by calling
    ``server.controller.update_transformed_images`` with the current cache.
    """

    def __init__(self, server):
        """Keep a reference to *server* and start with an empty cache."""
        self.server = server
        self.transformed_images: IdToImage = {}

    def emit_update(self):
        """Hand the current cache to ``controller.update_transformed_images``."""
        self.server.controller.update_transformed_images(self.transformed_images)

    def add_images(self, dataset_id_to_image: IdToImage):
        """Merge *dataset_id_to_image* into the cache, then emit an update.

        Entries whose id is already cached are overwritten.
        """
        self.transformed_images.update(dataset_id_to_image)
        self.emit_update()

    @change("dataset_ids")
    def on_dataset_ids(self, **kwargs):
        """Drop cached images whose dataset id left ``state.dataset_ids``."""
        if self.transformed_images:
            # Build the lookup once instead of scanning the id list for every
            # cached image. Assumes dataset ids are hashable (str/int).
            selected_ids = set(self.server.state.dataset_ids)
            self.transformed_images = {
                image_id: image
                for image_id, image in self.transformed_images.items()
                if image_id_to_dataset_id(image_id) in selected_ids
            }
        self.emit_update()

    @change("current_dataset")
    def on_dataset(self, **kwargs):
        """Empty the cache when ``current_dataset`` changes."""
        self.clear()

    def clear(self, **kwargs):
        """Replace the cache with an empty dict and emit an update.

        Extra keyword arguments are accepted and ignored so the method can be
        registered directly as a controller callback.
        """
        self.transformed_images = {}
        self.emit_update()


class EmbeddingsApp(Applet):
def __init__(
self,
Expand Down Expand Up @@ -54,14 +91,19 @@ def __init__(
"is_transformed": True,
}

def on_server_ready(self, *args, **kwargs):
self.clear_points_transformations() # init vars
self.on_feature_extraction_model_change()
self.state.change("feature_extraction_model")(self.on_feature_extraction_model_change)

self.transformed_images = TransformedImages(server)
self.server.controller.update_transformed_images.add(self.update_transformed_images)

def on_server_ready(self, *args, **kwargs):
self.state.change("feature_extraction_model")(self.on_feature_extraction_model_change)
self.save_embedding_params()
self.update_points()
self.state.change("dataset_ids")(self.update_points)

self.server.controller.apply_transform.add(self.clear_points_transformations)
self.server.controller.apply_transform.add(self.transformed_images.clear)
self.state.change("transform_enabled_switch")(self.update_points_transformations_state)

def on_feature_extraction_model_change(self, **kwargs):
Expand All @@ -75,29 +117,31 @@ def compute_points(self, fit_features, features):
# reduce will fail if no features
return []

if self.state.tab == "PCA":
params = self.embedding_params

if params["tab"] == "PCA":
return self.reducer.reduce(
name="PCA",
fit_features=fit_features,
features=features,
dims=self.state.dimensionality,
whiten=self.state.pca_whiten,
solver=self.state.pca_solver,
dims=params["dimensionality"],
whiten=params["pca_whiten"],
solver=params["pca_solver"],
)

# must be UMAP
args = {}
if self.state.umap_random_seed:
args["random_state"] = int(self.state.umap_random_seed_value)
if params["umap_random_seed"]:
args["random_state"] = int(params["umap_random_seed_value"])

if self.state.umap_n_neighbors:
args["n_neighbors"] = int(self.state.umap_n_neighbors_number)
if params["umap_n_neighbors"]:
args["n_neighbors"] = int(params["umap_n_neighbors_number"])

return self.reducer.reduce(
name="UMAP",
fit_features=fit_features,
features=features,
dims=self.state.dimensionality,
dims=params["dimensionality"],
**args,
)

Expand All @@ -111,14 +155,7 @@ def update_points_transformations_state(self, **kwargs):
else:
self.state.points_transformations = {}

async def compute_source_points(self):
with self.state:
self.state.is_loading = True
self.clear_points_transformations()

# Don't lock server before enabling the spinner on client
await self.server.network_completion

def compute_source_points(self):
images = [
self.images.get_image_without_cache_eviction(id) for id in self.state.dataset_ids
]
Expand All @@ -135,36 +172,65 @@ async def compute_source_points(self):

self.state.camera_position = []

async def _update_points(self):
    """Recompute the source points and re-plot the cached transformed images.

    Sets ``state.is_loading`` while the work runs, clears the previously
    computed points, snapshots the embedding parameters, then recomputes the
    source points and feeds the cached transformed images back through
    ``update_transformed_images``.
    """
    with self.state:
        self.state.is_loading = True
        self.points_sources = {}
        self.clear_points_transformations()
    # Don't lock server before enabling the spinner on client
    await self.server.network_completion

    self.save_embedding_params()

    with self.state:
        try:
            self.compute_source_points()
            self.update_transformed_images(self.transformed_images.transformed_images)
        finally:
            # Reset the flag even when the computation raises; otherwise the
            # client would stay in the loading state. There is no await in
            # this section, so task cancellation cannot land here.
            self.state.is_loading = False

def update_points(self, **kwargs):
if hasattr(self, "_update_task"):
self._update_task.cancel()
self._update_task = asynchronous.create_task(self.compute_source_points())
self._update_task = asynchronous.create_task(self._update_points())

def save_embedding_params(self):
    """Snapshot the dimension-reduction settings currently held in state.

    The copy is stored on ``self.embedding_params``; ``compute_points`` reads
    its parameters from that snapshot.
    """
    param_names = (
        "tab",
        "dimensionality",
        "pca_whiten",
        "pca_solver",
        "umap_random_seed",
        "umap_random_seed_value",
        "umap_n_neighbors",
        "umap_n_neighbors_number",
    )
    state = self.state
    self.embedding_params = {name: getattr(state, name) for name in param_names}

def on_run_clicked(self):
    """Snapshot the current embedding parameters, then schedule a points update."""
    # Order matters: the snapshot must exist before the update is scheduled.
    self.save_embedding_params()
    self.update_points()

def on_run_transformations(self, id_to_image):
hits, misses = partition(
lambda id: image_id_to_dataset_id(id) in self._stashed_points_transformations,
id_to_image.keys(),
)
self.transformed_images.add_images(id_to_image)

def update_transformed_images(self, id_to_image):
new_to_plot = {
id: img
for id, img in id_to_image.items()
if image_id_to_dataset_id(id) not in self._stashed_points_transformations
}

to_plot = {id: id_to_image[id] for id in misses}
transformation_features = self.extractor.extract(
list(to_plot.values()),
list(new_to_plot.values()),
batch_size=int(self.state.model_batch_size),
)
points = self.compute_points(self.features, transformation_features)
ids_to_points = zip(to_plot.keys(), points)
image_id_to_point = zip(new_to_plot.keys(), points)

updated_points = {image_id_to_dataset_id(id): point for id, point in ids_to_points}
updated_points = {image_id_to_dataset_id(id): point for id, point in image_id_to_point}
self._stashed_points_transformations = {
**self._stashed_points_transformations,
**updated_points,
}

self.update_points_transformations_state()

# called by category filter
Expand Down

0 comments on commit dee044e

Please sign in to comment.