-
Notifications
You must be signed in to change notification settings - Fork 100
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add 'BoundingBox', 'Pose2D' and 'Pose3D' models (#557)
- Loading branch information
1 parent
10c4e2a
commit e455180
Showing
7 changed files
with
178 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
from . import yolo | ||
from .bbox import BBox | ||
from .pose import Pose, Pose3D | ||
|
||
__all__ = ["BBox", "Pose", "Pose3D", "yolo"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
from typing import Optional | ||
|
||
from pydantic import Field | ||
|
||
from datachain.lib.data_model import DataModel | ||
|
||
|
||
class BBox(DataModel): | ||
""" | ||
A data model for representing bounding boxes. | ||
Attributes: | ||
title (str): The title of the bounding box. | ||
x1 (float): The x-coordinate of the top-left corner of the bounding box. | ||
y1 (float): The y-coordinate of the top-left corner of the bounding box. | ||
x2 (float): The x-coordinate of the bottom-right corner of the bounding box. | ||
y2 (float): The y-coordinate of the bottom-right corner of the bounding box. | ||
The bounding box is defined by two points: | ||
- (x1, y1): The top-left corner of the box. | ||
- (x2, y2): The bottom-right corner of the box. | ||
""" | ||
|
||
title: str = Field(default="") | ||
x1: float = Field(default=0) | ||
y1: float = Field(default=0) | ||
x2: float = Field(default=0) | ||
y2: float = Field(default=0) | ||
|
||
@staticmethod | ||
def from_xywh(bbox: list[float], title: Optional[str] = None) -> "BBox": | ||
""" | ||
Converts a bounding box in (x, y, width, height) format | ||
to a BBox data model instance. | ||
Args: | ||
bbox (list[float]): A bounding box, represented as a list | ||
of four floats [x, y, width, height]. | ||
Returns: | ||
BBox2D: An instance of the BBox data model. | ||
""" | ||
assert len(bbox) == 4, f"Bounding box must have 4 elements, got f{len(bbox)}" | ||
x, y, w, h = bbox | ||
return BBox(title=title or "", x1=x, y1=y, x2=x + w, y2=y + h) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
from pydantic import Field | ||
|
||
from datachain.lib.data_model import DataModel | ||
|
||
|
||
class Pose(DataModel): | ||
""" | ||
A data model for representing pose keypoints. | ||
Attributes: | ||
x (list[float]): The x-coordinates of the keypoints. | ||
y (list[float]): The y-coordinates of the keypoints. | ||
The keypoints are represented as lists of x and y coordinates, where each index | ||
corresponds to a specific body part. | ||
""" | ||
|
||
x: list[float] = Field(default=None) | ||
y: list[float] = Field(default=None) | ||
|
||
|
||
class Pose3D(DataModel): | ||
""" | ||
A data model for representing 3D pose keypoints. | ||
Attributes: | ||
x (list[float]): The x-coordinates of the keypoints. | ||
y (list[float]): The y-coordinates of the keypoints. | ||
visible (list[float]): The visibility of the keypoints. | ||
The keypoints are represented as lists of x, y, and visibility values, | ||
where each index corresponds to a specific body part. | ||
""" | ||
|
||
x: list[float] = Field(default=None) | ||
y: list[float] = Field(default=None) | ||
visible: list[float] = Field(default=None) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
""" | ||
This module contains the YOLO models. | ||
YOLO stands for "You Only Look Once", a family of object detection models that | ||
are designed to be fast and accurate. The models are trained to detect objects | ||
in images by dividing the image into a grid and predicting the bounding boxes | ||
and class probabilities for each grid cell. | ||
More information about YOLO can be found here: | ||
- https://pjreddie.com/darknet/yolo/ | ||
- https://docs.ultralytics.com/ | ||
""" | ||
|
||
|
||
class PoseBodyPart: | ||
""" | ||
An enumeration of body parts for YOLO pose keypoints. | ||
More information about the body parts can be found here: | ||
https://docs.ultralytics.com/tasks/pose/ | ||
""" | ||
|
||
nose = 0 | ||
left_eye = 1 | ||
right_eye = 2 | ||
left_ear = 3 | ||
right_ear = 4 | ||
left_shoulder = 5 | ||
right_shoulder = 6 | ||
left_elbow = 7 | ||
right_elbow = 8 | ||
left_wrist = 9 | ||
right_wrist = 10 | ||
left_hip = 11 | ||
right_hip = 12 | ||
left_knee = 13 | ||
right_knee = 14 | ||
left_ankle = 15 | ||
right_ankle = 16 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
from datachain.lib import models | ||
|
||
|
||
def test_bbox(): | ||
bbox = models.BBox(title="BBox", x1=0.5, y1=1.5, x2=2.5, y2=3.5) | ||
assert bbox.model_dump() == { | ||
"title": "BBox", | ||
"x1": 0.5, | ||
"y1": 1.5, | ||
"x2": 2.5, | ||
"y2": 3.5, | ||
} | ||
|
||
|
||
def test_bbox_from_xywh(): | ||
bbox = models.BBox.from_xywh([0.5, 1.5, 2.5, 3.5]) | ||
assert bbox.model_dump() == {"title": "", "x1": 0.5, "y1": 1.5, "x2": 3, "y2": 5} | ||
|
||
bbox = models.BBox.from_xywh([0.5, 1.5, 2.5, 3.5], title="BBox") | ||
assert bbox.model_dump() == { | ||
"title": "BBox", | ||
"x1": 0.5, | ||
"y1": 1.5, | ||
"x2": 3, | ||
"y2": 5, | ||
} | ||
|
||
|
||
def test_pose(): | ||
x = [x * 0.5 for x in range(17)] | ||
y = [y * 1.5 for y in range(17)] | ||
pose = models.Pose(x=x, y=y) | ||
assert pose.model_dump() == {"x": x, "y": y} | ||
assert pose.x[models.yolo.PoseBodyPart.nose] == 0 | ||
assert pose.x[models.yolo.PoseBodyPart.left_eye] == 0.5 | ||
assert pose.x[models.yolo.PoseBodyPart.right_eye] == 1 | ||
assert pose.x[models.yolo.PoseBodyPart.left_ear] == 1.5 | ||
assert pose.x[models.yolo.PoseBodyPart.right_ear] == 2 | ||
assert pose.x[models.yolo.PoseBodyPart.left_shoulder] == 2.5 | ||
assert pose.x[models.yolo.PoseBodyPart.right_shoulder] == 3 | ||
assert pose.x[models.yolo.PoseBodyPart.left_elbow] == 3.5 | ||
assert pose.x[models.yolo.PoseBodyPart.right_elbow] == 4 | ||
assert pose.x[models.yolo.PoseBodyPart.left_wrist] == 4.5 | ||
assert pose.x[models.yolo.PoseBodyPart.right_wrist] == 5 | ||
assert pose.x[models.yolo.PoseBodyPart.left_hip] == 5.5 | ||
assert pose.x[models.yolo.PoseBodyPart.right_hip] == 6 | ||
assert pose.x[models.yolo.PoseBodyPart.left_knee] == 6.5 | ||
assert pose.x[models.yolo.PoseBodyPart.right_knee] == 7 | ||
assert pose.x[models.yolo.PoseBodyPart.left_ankle] == 7.5 | ||
assert pose.x[models.yolo.PoseBodyPart.right_ankle] == 8 |