From af47334da011af86837c6278ee0968057f5f1ca3 Mon Sep 17 00:00:00 2001
From: Onuralp SEZER
Date: Fri, 27 Sep 2024 18:39:08 +0300
Subject: [PATCH 1/3] =?UTF-8?q?feat:=20=F0=9F=9A=80=20initial=20keypoint?=
 =?UTF-8?q?=20support=20for=20transformers=20added?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: Onuralp SEZER
---
 supervision/keypoint/core.py | 54 ++++++++++++++++++++++++++++++++++++
 1 file changed, 54 insertions(+)

diff --git a/supervision/keypoint/core.py b/supervision/keypoint/core.py
index 36d6a5968..fd812f62e 100644
--- a/supervision/keypoint/core.py
+++ b/supervision/keypoint/core.py
@@ -509,6 +509,60 @@ def from_detectron2(cls, detectron2_results: Any) -> KeyPoints:
         else:
             return cls.empty()
 
+    @classmethod
+    def from_transformers(cls, transformers_results: List) -> KeyPoints:
+        """
+        Create a `sv.KeyPoints` object from the
+        [Transformers](https://huggingface.co/transformers/) inference result.
+
+        Args:
+            transformers_results (List): The post-processed output of a Hugging Face
+                Transformers keypoint model: a list of dicts holding `keypoints` and `scores` tensors.
+
+        Returns:
+            A `sv.KeyPoints` object containing the keypoint coordinates
+                and the confidence of each keypoint.
+
+        Example:
+            ```python
+            import cv2
+            import torch
+            import supervision as sv
+            from transformers import AutoImageProcessor, SuperPointForKeypointDetection
+
+            processor = AutoImageProcessor.from_pretrained("magic-leap-community/superpoint")
+            model = SuperPointForKeypointDetection.from_pretrained("magic-leap-community/superpoint")
+
+            image = cv2.cvtColor(cv2.imread(<SOURCE_IMAGE_PATH>), cv2.COLOR_BGR2RGB)
+            inputs = processor(image, return_tensors="pt").to(model.device, model.dtype)
+            with torch.no_grad():
+                outputs = model(**inputs)
+            results = processor.post_process_keypoint_detection(outputs, [image.shape[:2]])
+            keypoints = sv.KeyPoints.from_transformers(results)
+            ```
+        """  # noqa: E501 // docs
+
+        keypoints_list = []
+        scores_list = []
+
+        for result in transformers_results:
+            if "keypoints" in result:
+                keypoints = result["keypoints"].detach().cpu().numpy()
+                scores = result["scores"].detach().cpu().numpy()
+
+                if keypoints.size > 0:
+                    keypoints_list.append(keypoints)
+                    scores_list.append(scores)
+
+        if not keypoints_list:
+            return cls.empty()
+
+        return cls(
+            xy=np.array(keypoints_list),
+            confidence=np.array(scores_list),
+            class_id=None,
+        )
+
     def __getitem__(
         self, index: Union[int, slice, List[int], np.ndarray, str]
     ) -> Union[KeyPoints, List, np.ndarray, None]:

From 933f71d7352db016b1aed62edffe4ad23aa45f72 Mon Sep 17 00:00:00 2001
From: Onuralp SEZER
Date: Fri, 1 Nov 2024 22:02:00 +0300
Subject: [PATCH 2/3] =?UTF-8?q?feat(refactor):=20=E2=9C=A8=20transformers?=
 =?UTF-8?q?=20keypoint=20xy,conf=20returns=20dtypes=20added=20and=20class?=
 =?UTF-8?q?=5Fid=20removed?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: Onuralp SEZER
---
 supervision/keypoint/core.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/supervision/keypoint/core.py b/supervision/keypoint/core.py
index fd812f62e..803b818eb 100644
--- a/supervision/keypoint/core.py
+++ b/supervision/keypoint/core.py
@@ -558,9 +558,8 @@ def from_transformers(cls, transformers_results: List) -> KeyPoints:
             return cls.empty()
 
         return cls(
-            xy=np.array(keypoints_list),
-            confidence=np.array(scores_list),
-            class_id=None,
+            xy=np.array(keypoints_list,dtype=np.float32),
+            confidence=np.array(scores_list,dtype=np.float32),
         )
 
     def __getitem__(

From ed928d2056f604828f007cabadb677366636a888 Mon Sep 17 00:00:00 2001
From: Onuralp SEZER
Date: Fri, 1 Nov 2024 22:30:12 +0300
Subject: [PATCH 3/3] =?UTF-8?q?feat:=20=E2=9C=A8=20descriptors=20custom=20?=
 =?UTF-8?q?field=20added=20for=20keypoint=20and=20as=20constant=20name?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: Onuralp SEZER
---
 supervision/config.py        |  1 +
 supervision/keypoint/core.py | 22 +++++++++++++++++-----
 2 files changed, 18 insertions(+), 5 deletions(-)

diff --git a/supervision/config.py b/supervision/config.py
index b18d2e20b..18236c8a7 100644
--- a/supervision/config.py
+++ b/supervision/config.py
@@ -1,2 +1,3 @@
 CLASS_NAME_DATA_FIELD = "class_name"
 ORIENTED_BOX_COORDINATES = "xyxyxyxy"
+DESCRIPTORS_FIELD = "descriptors"
diff --git a/supervision/keypoint/core.py b/supervision/keypoint/core.py
index 803b818eb..a55cab725 100644
--- a/supervision/keypoint/core.py
+++ b/supervision/keypoint/core.py
@@ -7,7 +7,7 @@
 import numpy as np
 import numpy.typing as npt
 
-from supervision.config import CLASS_NAME_DATA_FIELD
+from supervision.config import CLASS_NAME_DATA_FIELD, DESCRIPTORS_FIELD
 from supervision.detection.utils import get_data_item, is_data_equal
 from supervision.validators import validate_keypoints_fields
 
@@ -542,8 +542,10 @@ def from_transformers(cls, transformers_results: List) -> KeyPoints:
             ```
         """  # noqa: E501 // docs
 
-        keypoints_list = []
-        scores_list = []
+        keypoints_list: List[np.ndarray] = []
+        scores_list: List[np.ndarray] = []
+        descriptors_list: List[np.ndarray] = []
+        data: Dict[str, Any] = {}
 
         for result in transformers_results:
             if "keypoints" in result:
@@ -554,12 +556,22 @@ def from_transformers(cls, transformers_results: List) -> KeyPoints:
                     keypoints_list.append(keypoints)
                     scores_list.append(scores)
 
+            if "descriptors" in result:
+                descriptors = result["descriptors"].detach().cpu().numpy()
+
+                if descriptors.size > 0:
+                    descriptors_list.append(descriptors)
+
         if not keypoints_list:
             return cls.empty()
 
+        if descriptors_list:
+            data[DESCRIPTORS_FIELD] = np.array(descriptors_list)
+
         return cls(
-            xy=np.array(keypoints_list,dtype=np.float32),
-            confidence=np.array(scores_list,dtype=np.float32),
+            xy=np.array(keypoints_list, dtype=np.float32),
+            confidence=np.array(scores_list, dtype=np.float32),
+            data=data,
         )
 
     def __getitem__(