diff --git a/src/nrtk_explorer/app/core.py b/src/nrtk_explorer/app/core.py index 00c8c6f..0104a42 100644 --- a/src/nrtk_explorer/app/core.py +++ b/src/nrtk_explorer/app/core.py @@ -98,6 +98,7 @@ def __init__(self, server=None, **kwargs): self.ctrl.on_server_reload = self._build_ui self.ctrl.add("on_server_ready")(self.on_server_ready) + self.state.num_images = NUM_IMAGES_DEFAULT self.state.num_images_max = 0 self.state.num_images_disabled = True self.state.random_sampling = False @@ -126,7 +127,7 @@ def on_dataset_change(self, **kwargs): self.state.dataset_ids = [] # sampled images self.context.dataset = get_dataset(self.state.current_dataset) self.state.num_images_max = len(self.context.dataset.imgs) - self.state.num_images = min(self.state.num_images_max, NUM_IMAGES_DEFAULT) + self.state.num_images = min(self.state.num_images_max, self.state.num_images) self.state.dirty("num_images") # Trigger resample_images() self.state.random_sampling_disabled = False self.state.num_images_disabled = False @@ -150,18 +151,18 @@ def on_filter_apply(self, filter: FilterProtocol[Iterable[int]], **kwargs): self._embeddings_app.on_select(selected_ids) def resample_images(self, **kwargs): - images = list(self.context.dataset.imgs.values()) + ids = [image["id"] for image in self.context.dataset.imgs.values()] selected_images = [] if self.state.num_images: if self.state.random_sampling: - selected_images = random.sample(images, min(len(images), self.state.num_images)) + selected_images = random.sample(ids, min(len(ids), self.state.num_images)) else: - selected_images = images[: self.state.num_images] + selected_images = ids[: self.state.num_images] else: - selected_images = images + selected_images = ids - self.state.dataset_ids = [str(img["id"]) for img in selected_images] + self.state.dataset_ids = [str(id) for id in selected_images] self.state.user_selected_ids = self.state.dataset_ids def _build_ui(self):