Skip to content

Commit

Permalink
refactor(transforms): pass original detections via func args
Browse files Browse the repository at this point in the history
  • Loading branch information
PaulHax committed Jan 15, 2025
1 parent e326e4e commit 81bc9d9
Showing 1 changed file with 16 additions and 11 deletions.
27 changes: 16 additions & 11 deletions src/nrtk_explorer/app/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,9 @@ def on_apply_transform(self, **kwargs):
self.state.transform_enabled_switch = True
self._start_update_images()

async def update_transformed_images(self, dataset_ids, visible=False):
async def update_transformed_images(
self, dataset_ids, predictions_original_images, visible=False
):
if not self.state.transform_enabled:
return

Expand Down Expand Up @@ -310,10 +312,10 @@ async def update_transformed_images(self, dataset_ids, visible=False):
)

# depends on original images predictions
if self.state.predictions_original_images_enabled:
if predictions_original_images:
scores = compute_score(
self.context.dataset,
self.predictions_original_images,
predictions_original_images,
annotations,
self.state.confidence_score_threshold,
)
Expand Down Expand Up @@ -344,25 +346,25 @@ async def compute_predictions_original_images(self, dataset_ids):
}
)

self.predictions_original_images = (
await self.original_detection_annotations.get_annotations(
self.predictor, image_id_to_image
)
predictions_original_images = await self.original_detection_annotations.get_annotations(
self.predictor, image_id_to_image
)

ground_truth_annotations = self.ground_truth_annotations.get_annotations(dataset_ids)

scores = compute_score(
self.context.dataset,
ground_truth_annotations,
self.predictions_original_images,
predictions_original_images,
self.state.confidence_score_threshold,
)
for dataset_id, score in scores:
update_image_meta(
self.state, dataset_id, {"original_ground_to_original_detection_score": score}
)

return predictions_original_images

async def _update_images(self, dataset_ids, visible=False):
if visible:
# load images on state for ImageList
Expand All @@ -374,16 +376,19 @@ async def _update_images(self, dataset_ids, visible=False):

# always push to state because compute_predictions_original_images updates score metadata
with self.state:
await self.compute_predictions_original_images(dataset_ids)
predictions_original_images = await self.compute_predictions_original_images(
dataset_ids
)
await self.server.network_completion
# sortable score value may have changed which may have changed images that are in view
self.server.controller.check_images_in_view()

await self.update_transformed_images(dataset_ids, visible=visible)
await self.update_transformed_images(
dataset_ids, predictions_original_images, visible=visible
)

async def _chunk_update_images(self, dataset_ids, visible=False):
ids = list(dataset_ids)

for i in range(0, len(ids), UPDATE_IMAGES_CHUNK_SIZE):
chunk = ids[i : i + UPDATE_IMAGES_CHUNK_SIZE]
await self._update_images(chunk, visible=visible)
Expand Down

0 comments on commit 81bc9d9

Please sign in to comment.