diff --git a/supervision/config.py b/supervision/config.py
index b18d2e20b..18236c8a7 100644
--- a/supervision/config.py
+++ b/supervision/config.py
@@ -1,2 +1,3 @@
 CLASS_NAME_DATA_FIELD = "class_name"
 ORIENTED_BOX_COORDINATES = "xyxyxyxy"
+DESCRIPTORS_FIELD = "descriptors"
diff --git a/supervision/keypoint/core.py b/supervision/keypoint/core.py
index 36d6a5968..a55cab725 100644
--- a/supervision/keypoint/core.py
+++ b/supervision/keypoint/core.py
@@ -7,7 +7,7 @@
 import numpy as np
 import numpy.typing as npt
 
-from supervision.config import CLASS_NAME_DATA_FIELD
+from supervision.config import CLASS_NAME_DATA_FIELD, DESCRIPTORS_FIELD
 from supervision.detection.utils import get_data_item, is_data_equal
 from supervision.validators import validate_keypoints_fields
 
@@ -509,6 +509,76 @@ def from_detectron2(cls, detectron2_results: Any) -> KeyPoints:
         else:
             return cls.empty()
 
+    @classmethod
+    def from_transformers(cls, transformers_results: List[Dict[str, Any]]) -> KeyPoints:
+        """
+        Create a `sv.KeyPoints` object from the post-processed output of a
+        [Transformers](https://huggingface.co/transformers/) keypoint
+        detection model.
+
+        Args:
+            transformers_results (List[Dict[str, Any]]): The post-processed
+                output of a Hugging Face Transformers keypoint detection
+                model: one `dict` per image with `keypoints`, `scores`, and
+                optionally `descriptors` tensors.
+
+        Returns:
+            A `sv.KeyPoints` object containing the keypoint coordinates and
+            confidences, with descriptors stored in `data` when available.
+
+        Example:
+            ```python
+            import cv2
+            from PIL import Image
+            import supervision as sv
+            from transformers import AutoImageProcessor, SuperPointForKeypointDetection
+
+            processor = AutoImageProcessor.from_pretrained("magic-leap-community/superpoint")
+            model = SuperPointForKeypointDetection.from_pretrained("magic-leap-community/superpoint")
+
+            image = cv2.imread(<SOURCE_IMAGE_PATH>)
+            image_pil = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
+            inputs = processor(image_pil, return_tensors="pt").to(model.device, model.dtype)
+            outputs = model(**inputs)
+            results = processor.post_process_keypoint_detection(
+                outputs, [(image_pil.height, image_pil.width)]
+            )
+            keypoints = sv.KeyPoints.from_transformers(results)
+            ```
+        """  # noqa: E501 // docs
+
+        keypoints_list: List[np.ndarray] = []
+        scores_list: List[np.ndarray] = []
+        descriptors_list: List[np.ndarray] = []
+        data: Dict[str, Any] = {}
+
+        for result in transformers_results:
+            if "keypoints" in result:
+                keypoints = result["keypoints"].detach().cpu().numpy()
+                scores = result["scores"].detach().cpu().numpy()
+
+                if keypoints.size > 0:
+                    keypoints_list.append(keypoints)
+                    scores_list.append(scores)
+
+            if "descriptors" in result:
+                descriptors = result["descriptors"].detach().cpu().numpy()
+
+                if descriptors.size > 0:
+                    descriptors_list.append(descriptors)
+
+        if not keypoints_list:
+            return cls.empty()
+
+        if descriptors_list:
+            data[DESCRIPTORS_FIELD] = np.array(descriptors_list)
+
+        return cls(
+            xy=np.array(keypoints_list, dtype=np.float32),
+            confidence=np.array(scores_list, dtype=np.float32),
+            data=data,
+        )
+
     def __getitem__(
         self, index: Union[int, slice, List[int], np.ndarray, str]
     ) -> Union[KeyPoints, List, np.ndarray, None]:
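
A quick sketch (not part of the patch) of how data stored under the new `DESCRIPTORS_FIELD` key might be consumed downstream. The random arrays below stand in for per-image entries like `keypoints.data[DESCRIPTORS_FIELD][i]`, and the mutual nearest-neighbour matching is an illustrative assumption, not existing supervision API:

```python
import numpy as np

# Stand-ins for two images' descriptor arrays as from_transformers would
# store them under DESCRIPTORS_FIELD; SuperPoint emits 256-D, L2-normalized
# vectors, so we normalize the random stand-ins the same way.
rng = np.random.default_rng(0)
desc_a = rng.normal(size=(100, 256)).astype(np.float32)
desc_b = rng.normal(size=(120, 256)).astype(np.float32)
desc_a /= np.linalg.norm(desc_a, axis=1, keepdims=True)
desc_b /= np.linalg.norm(desc_b, axis=1, keepdims=True)

# Mutual nearest-neighbour matching on cosine similarity.
similarity = desc_a @ desc_b.T               # (100, 120) pairwise scores
best_ab = similarity.argmax(axis=1)          # best match in B for each A keypoint
best_ba = similarity.argmax(axis=0)          # best match in A for each B keypoint
mutual = np.flatnonzero(best_ba[best_ab] == np.arange(len(desc_a)))
matches = np.stack([mutual, best_ab[mutual]], axis=1)  # (N, 2) A/B index pairs
```

Storing descriptors in `data` rather than as a first-class field keeps the `KeyPoints` constructor unchanged while still letting consumers retrieve them per image for matching tasks like the one above.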