From 3856366de752bd06738ed39e0659c49d1ff0ec26 Mon Sep 17 00:00:00 2001
From: paulruelle
Date: Mon, 30 Dec 2024 11:24:43 +0100
Subject: [PATCH] feat(LAB-3307): handle export of comparison in llm_static

---
 src/kili/domain/llm.py                  |  4 ++-
 src/kili/llm/services/export/dynamic.py | 38 ++++++++++++++++++-------
 2 files changed, 30 insertions(+), 12 deletions(-)

diff --git a/src/kili/domain/llm.py b/src/kili/domain/llm.py
index 37e21c39a..43b6cc3bd 100644
--- a/src/kili/domain/llm.py
+++ b/src/kili/domain/llm.py
@@ -138,6 +138,7 @@ class ProjectModelDict(TypedDict):
 class ChatItem(TypedDict):
     """Dict that represents a ChatItem."""
 
+    id: str
     content: str
     external_id: str
     model_name: Optional[str]
@@ -153,8 +154,9 @@ class ConversationLabel(TypedDict):
 
 
 class Conversation(TypedDict):
+    """Dict that represents a Conversation."""
+
     external_id: Optional[str]
     chat_items: List[ChatItem]
     label: Optional[ConversationLabel]
-    labeler: str
     metadata: Optional[dict]
diff --git a/src/kili/llm/services/export/dynamic.py b/src/kili/llm/services/export/dynamic.py
index 2c2e6ae3c..3e122d001 100644
--- a/src/kili/llm/services/export/dynamic.py
+++ b/src/kili/llm/services/export/dynamic.py
@@ -45,8 +45,19 @@ def export(
         for asset in assets:
             # obfuscate models here
             obfuscated_models = {}
-            for index, asset_project_model in enumerate(asset["assetProjectModels"]):
-                obfuscated_models[asset_project_model["id"]] = f"{chr(65 + index)}"
+            if asset.get("assetProjectModels"):
+                for index, asset_project_model in enumerate(asset["assetProjectModels"]):
+                    obfuscated_models[asset_project_model["id"]] = f"{chr(65 + index)}"
+            else:
+                model_names = {
+                    chat_item.get("modelName")
+                    for chat_item in asset["labels"][0]["chatItems"]
+                    if chat_item.get("modelName") is not None
+                }
+                obfuscated_models = {
+                    model_name: f"{chr(65 + index)}" for index, model_name in enumerate(model_names)
+                }
+
             for label in asset["labels"]:
                 result = {}
                 chat_items = label["chatItems"]
@@ -81,17 +92,23 @@ def export(
                     if step == total_rounds - 1 and formatted_response["conversation"]:
                         label_data["label"]["conversation"] = formatted_response["conversation"]
 
+                    if asset.get("assetProjectModels"):
+                        models = _format_models_object(
+                            asset["assetProjectModels"], obfuscated_models
+                        )
+                    else:
+                        models = {v: k for k, v in obfuscated_models.items()}
+
                     result[f"{step}"] = {
                         "external_id": asset["externalId"],
                         "metadata": asset["jsonMetadata"],
-                        "models": _format_models_object(
-                            asset["assetProjectModels"], obfuscated_models
-                        ),
+                        "models": models,
                         "labels": [label_data],
                         "raw_data": raw_data,
                         "status": asset["status"],
                     }
                 export_res.append(result)
+
         return export_res
 
     def _get_round_winner(self, completions, annotations, json_interface):
@@ -221,15 +238,14 @@ def _format_transcription_annotation(annotation):
 
 def _format_comparison_annotation(annotation, completions, job, obfuscated_models):
     """Return A_X or B_X depending of the evaluation completion and its score."""
-    model_id = None
+    model_key = None
     for completion in completions:
         if annotation["annotationValue"]["choice"]["firstId"] == completion["id"]:
-            model_id = completion["modelId"]
+            model_key = completion.get("modelId") or completion.get("modelName")
            break
 
-    if model_id is None:
-        # FIXME : model_id can be null on LLM_STATIC
-        return None
+    if model_key is None:
+        raise ValueError(f"Failed to find model for annotation {annotation['id']}")
 
     iteration = 0
     for _comparison_code, comparison_note in job["content"]["options"].items():
@@ -237,7 +253,7 @@ def _format_comparison_annotation(annotation, completions, job, obfuscated_models):
         if comparison_note["name"] == annotation["annotationValue"]["choice"]["code"]:
             break
 
-    return f"{obfuscated_models[model_id]}_{iteration}"
+    return f"{obfuscated_models[model_key]}_{iteration}"
 
 
 def _format_json_response(
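
Note (illustration, not part of the patch): below is a minimal standalone sketch of the fallback model
obfuscation introduced above for llm_static projects, where assets carry no assetProjectModels and model
identity only exists as modelName on the chat items. The helper name obfuscate_models and the sample
asset are hypothetical; the field names mirror the patch.

def obfuscate_models(asset: dict) -> dict:
    """Map model ids (llm_dynamic) or model names (llm_static) to letters A, B, C, ..."""
    if asset.get("assetProjectModels"):
        # llm_dynamic: project models are attached to the asset.
        return {
            project_model["id"]: chr(65 + index)
            for index, project_model in enumerate(asset["assetProjectModels"])
        }
    # llm_static fallback: use the model names found on the first label's chat items.
    model_names = {
        chat_item.get("modelName")
        for chat_item in asset["labels"][0]["chatItems"]
        if chat_item.get("modelName") is not None
    }
    # Iterating a set, so which letter each name receives is not guaranteed to be stable.
    return {name: chr(65 + index) for index, name in enumerate(model_names)}


# Hypothetical llm_static asset: no assetProjectModels, only modelName on chat items.
asset = {
    "labels": [
        {"chatItems": [{"modelName": "model_a"}, {"modelName": "model_b"}, {"modelName": None}]}
    ]
}
print(obfuscate_models(asset))  # e.g. {'model_a': 'A', 'model_b': 'B'}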