Skip to content

Commit

Permalink
feat(LAB-2609): add LLM_RLHF label export (#1620)
Browse files Browse the repository at this point in the history
  • Loading branch information
BlueGrizzliBear authored Feb 5, 2024
1 parent baa6869 commit f66b2c9
Show file tree
Hide file tree
Showing 9 changed files with 484 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@
from kili.adapters.kili_api_gateway.project.common import get_project
from kili.core.graphql.graphql_client import GraphQLClient
from kili.domain.annotation import (
ClassicAnnotation,
ClassificationAnnotation,
RankingAnnotation,
TranscriptionAnnotation,
Vertice,
VideoAnnotation,
VideoClassificationAnnotation,
Expand Down Expand Up @@ -54,11 +58,17 @@ def patch_label_json_response(self, label: Dict, label_id: LabelId) -> None:
Modifies the input label.
"""
if self._project_input_type == "VIDEO":
if self._project_input_type in {"VIDEO", "LLM_RLHF"}:
annotations = list_annotations(
graphql_client=self._graphql_client,
label_id=label_id,
annotation_fields=("__typename", "id", "job", "path", "labelId"),
classification_annotation_fields=("annotationValue.categories",),
ranking_annotation_fields=(
"annotationValue.orders.elements",
"annotationValue.orders.rank",
),
transcription_annotation_fields=("annotationValue.text",),
video_annotation_fields=(
"frames.start",
"frames.end",
Expand All @@ -79,10 +89,15 @@ def patch_label_json_response(self, label: Dict, label_id: LabelId) -> None:
if not annotations and self._label_has_json_response_data(label):
return

annotations = cast(List[VideoAnnotation], annotations)
converted_json_resp = _video_label_annotations_to_json_response(
annotations=annotations, json_interface=self._project_json_interface
)
if self._project_input_type == "VIDEO":
annotations = cast(List[VideoAnnotation], annotations)
converted_json_resp = _video_annotations_to_json_response(
annotations=annotations, json_interface=self._project_json_interface
)
else:
annotations = cast(List[ClassicAnnotation], annotations)
converted_json_resp = _classic_annotations_to_json_response(annotations=annotations)

label["jsonResponse"] = converted_json_resp


Expand All @@ -105,7 +120,7 @@ def _fill_empty_frames(json_response: Dict) -> None:
json_response.setdefault(str(frame_id), {})


def _video_label_annotations_to_json_response(
def _video_annotations_to_json_response(
annotations: List[VideoAnnotation], json_interface: Dict
) -> Dict[str, Dict[JobName, Dict]]:
"""Convert video label annotations to a video json response."""
Expand Down Expand Up @@ -147,14 +162,49 @@ def _video_label_annotations_to_json_response(
json_resp[frame_id] = {**json_resp[frame_id], **frame_json_resp}

else:
raise NotImplementedError(f"Cannot convert annotation to json response: {ann}")
raise NotImplementedError(f"Cannot convert video annotation to json response: {ann}")

_add_annotation_metadata(annotations, json_resp)
_fill_empty_frames(json_resp)

return dict(sorted(json_resp.items(), key=lambda item: int(item[0]))) # sort by frame id


def _classic_annotations_to_json_response(
    annotations: List[ClassicAnnotation],
) -> Dict[str, Dict[JobName, Dict]]:
    """Merge classic (non-video) annotations into a single json response.

    Every annotation is converted on its own and then folded into the entry
    of its job in the returned dict:

    - classification annotations extend the job's ``"categories"`` list;
    - ranking annotations extend the job's ``"orders"`` list;
    - transcription annotations set the job's ``"text"`` only when the job
      has no text yet, so the first transcription met for a job is kept.

    Raises:
        NotImplementedError: if an annotation has any other ``__typename``.
    """
    merged: Dict = {}

    for ann in annotations:
        typename = ann["__typename"]

        if typename == "ClassificationAnnotation":
            converted = _classification_annotation_to_json_response(
                cast(ClassificationAnnotation, ann)
            )
            for job_name, job_resp in converted.items():
                job_entry = merged.setdefault(job_name, {})
                job_entry.setdefault("categories", []).extend(job_resp["categories"])
            continue

        if typename == "RankingAnnotation":
            converted = _ranking_annotation_to_json_response(cast(RankingAnnotation, ann))
            for job_name, job_resp in converted.items():
                job_entry = merged.setdefault(job_name, {})
                job_entry.setdefault("orders", []).extend(job_resp["orders"])
            continue

        if typename == "TranscriptionAnnotation":
            converted = _transcription_annotation_to_json_response(
                cast(TranscriptionAnnotation, ann)
            )
            for job_name, job_resp in converted.items():
                # setdefault: an already recorded text is never overwritten
                merged.setdefault(job_name, {}).setdefault("text", job_resp["text"])
            continue

        raise NotImplementedError(f"Cannot convert classic annotation to json response: {ann}")

    return merged


@overload
def _key_annotations_iterator(
annotation: VideoTranscriptionAnnotation,
Expand Down Expand Up @@ -226,6 +276,40 @@ def _key_annotations_iterator(annotation: VideoAnnotation) -> Generator:
yield key_ann, key_ann_start, key_ann_end, next_key_ann


def _ranking_annotation_to_json_response(
    annotation: RankingAnnotation,
) -> Dict[JobName, Dict]:
    """Build the json response fragment of a single ranking annotation.

    The result maps the annotation's job name to ``{"orders": [...]}``. The
    orders are copied into a new list sorted by ascending ``int(rank)``; the
    input list is left untouched.

    Ranking jobs cannot have child jobs, so nothing is nested below the job.
    """

    def _rank_of(order):
        return int(order["rank"])

    job_name = annotation["job"]
    ranked_orders = list(annotation["annotationValue"]["orders"])
    ranked_orders.sort(key=_rank_of)  # stable, like sorted()

    return {job_name: {"orders": ranked_orders}}


def _transcription_annotation_to_json_response(
    annotation: TranscriptionAnnotation,
) -> Dict[JobName, Dict]:
    """Build the json response fragment of a single transcription annotation.

    The result maps the annotation's job name to ``{"text": <annotation text>}``.

    Transcription jobs cannot have child jobs, so nothing is nested below the job.
    """
    job_name = annotation["job"]
    transcribed_text = annotation["annotationValue"]["text"]
    return {job_name: {"text": transcribed_text}}


def _video_transcription_annotation_to_json_response(
annotation: VideoTranscriptionAnnotation,
) -> Dict[str, Dict[JobName, Dict]]:
Expand Down Expand Up @@ -286,6 +370,25 @@ def _compute_children_json_resp(
return children_json_resp


def _classification_annotation_to_json_response(
    annotation: ClassificationAnnotation,
) -> Dict[JobName, Dict]:
    """Build the json response fragment of a single classification annotation.

    The result maps the annotation's job name to ``{"categories": [...]}``,
    with one ``{"name": <category>}`` entry per category of the annotation
    value, in the same order. An annotation may carry one or several
    categories; an empty category list yields an empty ``"categories"`` list.
    """
    job_name = annotation["job"]
    category_names = annotation["annotationValue"]["categories"]
    return {job_name: {"categories": [{"name": name} for name in category_names]}}


def _video_classification_annotation_to_json_response(
annotation: VideoClassificationAnnotation,
other_annotations: List[VideoAnnotation],
Expand Down
6 changes: 6 additions & 0 deletions src/kili/adapters/kili_api_gateway/label/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ def list_annotations(
label_id: LabelId,
*,
annotation_fields: ListOrTuple[str],
classification_annotation_fields: ListOrTuple[str] = (),
ranking_annotation_fields: ListOrTuple[str] = (),
transcription_annotation_fields: ListOrTuple[str] = (),
video_annotation_fields: ListOrTuple[str] = (),
video_classification_fields: ListOrTuple[str] = (),
video_object_detection_fields: ListOrTuple[str] = (),
Expand All @@ -23,6 +26,9 @@ def list_annotations(
"""List annotations for a label."""
query = get_annotations_query(
annotation_fragment=fragment_builder(annotation_fields),
classification_annotation_fragment=fragment_builder(classification_annotation_fields),
ranking_annotation_fragment=fragment_builder(ranking_annotation_fields),
transcription_annotation_fragment=fragment_builder(transcription_annotation_fields),
video_annotation_fragment=fragment_builder(video_annotation_fields),
video_classification_annotation_fragment=fragment_builder(video_classification_fields),
video_object_detection_annotation_fragment=fragment_builder(video_object_detection_fields),
Expand Down
24 changes: 24 additions & 0 deletions src/kili/adapters/kili_api_gateway/label/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,9 @@ def get_append_to_labels_mutation(fragment: str) -> str:
def get_annotations_query(
*,
annotation_fragment: str,
classification_annotation_fragment: str,
ranking_annotation_fragment: str,
transcription_annotation_fragment: str,
video_annotation_fragment: str,
video_object_detection_annotation_fragment: str,
video_classification_annotation_fragment: str,
Expand All @@ -99,6 +102,27 @@ def get_annotations_query(
"""Get the gql annotations query."""
inline_fragments = ""

if classification_annotation_fragment.strip():
inline_fragments += f"""
... on ClassificationAnnotation {{
{classification_annotation_fragment}
}}
"""

if ranking_annotation_fragment.strip():
inline_fragments += f"""
... on RankingAnnotation {{
{ranking_annotation_fragment}
}}
"""

if transcription_annotation_fragment.strip():
inline_fragments += f"""
... on TranscriptionAnnotation {{
{transcription_annotation_fragment}
}}
"""

if video_annotation_fragment.strip():
inline_fragments += f"""
... on VideoAnnotation {{
Expand Down
6 changes: 6 additions & 0 deletions src/kili/adapters/kili_api_gateway/label/operations_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,9 @@ def list_annotations(
label_id: LabelId,
*,
annotation_fields: ListOrTuple[str],
classification_annotation_fields: ListOrTuple[str] = (),
ranking_annotation_fields: ListOrTuple[str] = (),
transcription_annotation_fields: ListOrTuple[str] = (),
video_annotation_fields: ListOrTuple[str] = (),
video_classification_fields: ListOrTuple[str] = (),
video_object_detection_fields: ListOrTuple[str] = (),
Expand All @@ -188,6 +191,9 @@ def list_annotations(
graphql_client=self.graphql_client,
label_id=label_id,
annotation_fields=annotation_fields,
classification_annotation_fields=classification_annotation_fields,
ranking_annotation_fields=ranking_annotation_fields,
transcription_annotation_fields=transcription_annotation_fields,
video_annotation_fields=video_annotation_fields,
video_classification_fields=video_classification_fields,
video_object_detection_fields=video_object_detection_fields,
Expand Down
63 changes: 63 additions & 0 deletions src/kili/domain/annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from .ontology import JobName

AnnotationId = NewType("AnnotationId", str)
AnnotationValueId = NewType("AnnotationValueId", str)
KeyAnnotationId = NewType("KeyAnnotationId", str)


Expand All @@ -28,12 +29,61 @@ class ClassificationAnnotationValue(TypedDict):
categories: List[str]


class ClassificationAnnotation(TypedDict):
    """Classification annotation.

    Dict shape of an annotation whose ``__typename`` is
    ``"ClassificationAnnotation"``; its value carries a list of categories.
    """

    # pylint: disable=unused-private-member
    # NOTE(review): a ``__name`` identifier declared in a class body is subject
    # to Python private-name mangling at runtime; confirm the type checker in
    # use still reads this key as "__typename".
    __typename: Literal["ClassificationAnnotation"]  # tag identifying the annotation kind
    id: AnnotationId
    labelId: LabelId  # presumably the label owning this annotation; confirm
    job: JobName  # job name; used as the key of the converted json response
    path: List[List[str]]
    annotationValue: ClassificationAnnotationValue  # holds the `categories` list


class RankingOrderValue(TypedDict):
    """Ranking order value.

    One entry of the ``orders`` list of a ranking annotation value.
    """

    rank: int  # sort key of the entry when orders are exported
    elements: List[str]  # presumably the items given this rank; confirm


class RankingAnnotationValue(TypedDict):
    """Ranking annotation value.

    Payload stored under ``annotationValue`` of a ranking annotation.
    """

    orders: List[RankingOrderValue]  # rank/elements entries of the ranking


class RankingAnnotation(TypedDict):
    """Ranking annotation.

    Dict shape of an annotation whose ``__typename`` is
    ``"RankingAnnotation"``; its value carries a list of ranked orders.
    """

    # pylint: disable=unused-private-member
    # NOTE(review): a ``__name`` identifier declared in a class body is subject
    # to Python private-name mangling at runtime; confirm the type checker in
    # use still reads this key as "__typename".
    __typename: Literal["RankingAnnotation"]  # tag identifying the annotation kind
    id: AnnotationId
    labelId: LabelId  # presumably the label owning this annotation; confirm
    job: JobName  # job name; used as the key of the converted json response
    path: List[List[str]]
    annotationValue: RankingAnnotationValue  # holds the `orders` list


class TranscriptionAnnotationValue(TypedDict):
    """Transcription annotation value.

    Payload stored under ``annotationValue`` of a transcription annotation.
    """

    text: str  # the transcribed text


class TranscriptionAnnotation(TypedDict):
    """Transcription annotation.

    Dict shape of an annotation whose ``__typename`` is
    ``"TranscriptionAnnotation"``; its value carries a single text.
    """

    # pylint: disable=unused-private-member
    # NOTE(review): a ``__name`` identifier declared in a class body is subject
    # to Python private-name mangling at runtime; confirm the type checker in
    # use still reads this key as "__typename".
    __typename: Literal["TranscriptionAnnotation"]  # tag identifying the annotation kind
    id: AnnotationId
    labelId: LabelId  # presumably the label owning this annotation; confirm
    job: JobName  # job name; used as the key of the converted json response
    path: List[List[str]]
    annotationValue: TranscriptionAnnotationValue  # holds the `text` string


class Annotation(TypedDict):
"""Annotation."""

Expand Down Expand Up @@ -121,3 +171,16 @@ class VideoTranscriptionAnnotation(TypedDict):
VideoClassificationAnnotation,
VideoTranscriptionAnnotation,
]

# The three non-video annotation kinds: classification, ranking, transcription.
ClassicAnnotation = Union[
    ClassificationAnnotation,
    RankingAnnotation,
    TranscriptionAnnotation,
]

# Any supported annotation: the classic kinds plus `VideoAnnotation`.
# NOTE(review): an `Annotation` TypedDict class is declared earlier in this
# module; this assignment rebinds that name to the union. Confirm that is
# intended.
Annotation = Union[
    ClassificationAnnotation,
    RankingAnnotation,
    TranscriptionAnnotation,
    VideoAnnotation,
]
Loading

0 comments on commit f66b2c9

Please sign in to comment.