From 32c72f83c64ee5afcb240edef4d2f821836046ff Mon Sep 17 00:00:00 2001
From: David Redo
Date: Sat, 7 Dec 2024 17:10:48 +0100
Subject: [PATCH 1/7] Added MetricResult base class

---
 supervision/metrics/core.py                   | 14 ++++++++++++++
 supervision/metrics/f1_score.py               |  4 ++--
 supervision/metrics/mean_average_precision.py |  4 ++--
 supervision/metrics/mean_average_recall.py    |  4 ++--
 supervision/metrics/precision.py              |  4 ++--
 supervision/metrics/recall.py                 |  4 ++--
 6 files changed, 24 insertions(+), 10 deletions(-)

diff --git a/supervision/metrics/core.py b/supervision/metrics/core.py
index def5999a0..bae56785f 100644
--- a/supervision/metrics/core.py
+++ b/supervision/metrics/core.py
@@ -33,6 +33,20 @@ def compute(self, *args, **kwargs) -> Any:
         raise NotImplementedError
 
 
+class MetricResult(ABC):
+    """
+    Base class for all metric results.
+    """
+
+    @abstractmethod
+    def to_pandas(self):
+        raise NotImplementedError()
+
+    @abstractmethod
+    def plot(self):
+        raise NotImplementedError()
+
+
 class MetricTarget(Enum):
     """
     Specifies what type of detection is used to compute the metric.
diff --git a/supervision/metrics/f1_score.py b/supervision/metrics/f1_score.py
index 98cb5f265..91d587fbf 100644
--- a/supervision/metrics/f1_score.py
+++ b/supervision/metrics/f1_score.py
@@ -15,7 +15,7 @@
     oriented_box_iou_batch,
 )
 from supervision.draw.color import LEGACY_COLOR_PALETTE
-from supervision.metrics.core import AveragingMethod, Metric, MetricTarget
+from supervision.metrics.core import AveragingMethod, Metric, MetricTarget, MetricResult
 from supervision.metrics.utils.object_size import (
     ObjectSizeCategory,
     get_detection_size_category,
@@ -455,7 +455,7 @@ def _filter_predictions_and_targets_by_size(
 
 
 @dataclass
-class F1ScoreResult:
+class F1ScoreResult(MetricResult):
     """
     The results of the F1 score metric calculation.
 
diff --git a/supervision/metrics/mean_average_precision.py b/supervision/metrics/mean_average_precision.py
index 9e7a30d0e..edfef9b55 100644
--- a/supervision/metrics/mean_average_precision.py
+++ b/supervision/metrics/mean_average_precision.py
@@ -15,7 +15,7 @@
     oriented_box_iou_batch,
 )
 from supervision.draw.color import LEGACY_COLOR_PALETTE
-from supervision.metrics.core import Metric, MetricTarget
+from supervision.metrics.core import Metric, MetricTarget, MetricResult
 from supervision.metrics.utils.object_size import (
     ObjectSizeCategory,
     get_detection_size_category,
@@ -418,7 +418,7 @@ def _filter_detections_by_size(
 
 
 @dataclass
-class MeanAveragePrecisionResult:
+class MeanAveragePrecisionResult(MetricResult):
     """
     The result of the Mean Average Precision calculation.
 
diff --git a/supervision/metrics/mean_average_recall.py b/supervision/metrics/mean_average_recall.py
index 9c3a40718..6f2aa28ae 100644
--- a/supervision/metrics/mean_average_recall.py
+++ b/supervision/metrics/mean_average_recall.py
@@ -15,7 +15,7 @@
     oriented_box_iou_batch,
 )
 from supervision.draw.color import LEGACY_COLOR_PALETTE
-from supervision.metrics.core import Metric, MetricTarget
+from supervision.metrics.core import Metric, MetricTarget, MetricResult
 from supervision.metrics.utils.object_size import (
     ObjectSizeCategory,
     get_detection_size_category,
@@ -460,7 +460,7 @@ def _filter_predictions_and_targets_by_size(
 
 
 @dataclass
-class MeanAverageRecallResult:
+class MeanAverageRecallResult(MetricResult):
     """
     The results of the recall metric calculation.
diff --git a/supervision/metrics/precision.py b/supervision/metrics/precision.py
index a5d4011e8..c10c7aca4 100644
--- a/supervision/metrics/precision.py
+++ b/supervision/metrics/precision.py
@@ -15,7 +15,7 @@
     oriented_box_iou_batch,
 )
 from supervision.draw.color import LEGACY_COLOR_PALETTE
-from supervision.metrics.core import AveragingMethod, Metric, MetricTarget
+from supervision.metrics.core import AveragingMethod, Metric, MetricTarget, MetricResult
 from supervision.metrics.utils.object_size import (
     ObjectSizeCategory,
     get_detection_size_category,
@@ -458,7 +458,7 @@ def _filter_predictions_and_targets_by_size(
 
 
 @dataclass
-class PrecisionResult:
+class PrecisionResult(MetricResult):
     """
     The results of the precision metric calculation.
 
diff --git a/supervision/metrics/recall.py b/supervision/metrics/recall.py
index b3586ff7d..3f8b9e808 100644
--- a/supervision/metrics/recall.py
+++ b/supervision/metrics/recall.py
@@ -15,7 +15,7 @@
     oriented_box_iou_batch,
 )
 from supervision.draw.color import LEGACY_COLOR_PALETTE
-from supervision.metrics.core import AveragingMethod, Metric, MetricTarget
+from supervision.metrics.core import AveragingMethod, Metric, MetricTarget, MetricResult
 from supervision.metrics.utils.object_size import (
     ObjectSizeCategory,
     get_detection_size_category,
@@ -457,7 +457,7 @@ def _filter_predictions_and_targets_by_size(
 
 
 @dataclass
-class RecallResult:
+class RecallResult(MetricResult):
     """
     The results of the recall metric calculation.
 

From be69963661473571361d47387ea93f237d093995 Mon Sep 17 00:00:00 2001
From: David Redo
Date: Sat, 7 Dec 2024 17:39:13 +0100
Subject: [PATCH 2/7] feat: add a function to aggregate metric results into a
 DataFrame

---
 .../metrics/aggregate_metric_results.py      | 49 +++++++++++++++++++
 1 file changed, 49 insertions(+)
 create mode 100644 supervision/metrics/aggregate_metric_results.py

diff --git a/supervision/metrics/aggregate_metric_results.py b/supervision/metrics/aggregate_metric_results.py
new file mode 100644
index 000000000..ff30a5de5
--- /dev/null
+++ b/supervision/metrics/aggregate_metric_results.py
@@ -0,0 +1,49 @@
+from typing import TYPE_CHECKING, List
+
+from supervision.metrics.core import MetricResult
+from supervision.metrics.utils.utils import ensure_pandas_installed
+
+if TYPE_CHECKING:
+    import pandas as pd
+
+
+def aggregate_metric_results(
+    metrics_results: List[MetricResult],
+    model_names: List[str],
+    include_object_sizes=False,
+) -> "pd.DataFrame":
+    """
+    Raises when different types of metrics results are passed in
+    """
+    ensure_pandas_installed()
+    import pandas as pd
+
+    assert len(metrics_results) == len(
+        model_names
+    ), "Number of metrics results and model names must be equal"
+
+    if len(metrics_results) == 0:
+        raise ValueError("metrics_results must not be empty")
+
+    first_elem_type = type(metrics_results[0])
+    all_same_type = all(isinstance(x, first_elem_type) for x in metrics_results)
+    if not all_same_type:
+        raise ValueError("All metrics_results must be of the same type")
+
+    if not isinstance(metrics_results[0], MetricResult):
+        raise ValueError("Base class of metrics_results must be of type MetricResult")
+
+    pd_results = []
+    for metric_result, model_name in zip(metrics_results, model_names):
+        pd_result = metric_result.to_pandas()
+        pd_result.insert(loc=0, column="Model Name", value=model_name)
+        pd_results.append(pd_result)
+
+    df_merged = pd.concat(pd_results)
+
+    if not include_object_sizes:
+        regex_pattern = "small|medium|large"
+        df_merged = df_merged.drop(columns=list(df_merged.filter(regex=regex_pattern)))
+
+    return df_merged
+

From 0556a80b90f26d540e5f076a1745515afb348961 Mon Sep 17 00:00:00 2001
From: David Redo
Date: Sat, 7 Dec 2024 19:05:23 +0100
Subject: [PATCH 3/7] docs: added docstring to aggregate_metric_results
 function

---
 .../metrics/aggregate_metric_results.py      | 18 ++++++++++++++++--
 1 file changed, 16 insertions(+), 2 deletions(-)

diff --git a/supervision/metrics/aggregate_metric_results.py b/supervision/metrics/aggregate_metric_results.py
index ff30a5de5..03ea737c7 100644
--- a/supervision/metrics/aggregate_metric_results.py
+++ b/supervision/metrics/aggregate_metric_results.py
@@ -13,7 +13,22 @@ def aggregate_metric_results(
     include_object_sizes=False,
 ) -> "pd.DataFrame":
     """
-    Raises when different types of metrics results are passed in
+    Convert a list of results to a pandas DataFrame.
+
+    Args:
+        metrics_results (List[MetricResult]): List of results to be aggregated.
+        model_names (List[str]): List of model names corresponding to the results.
+        include_object_sizes (bool, optional): Whether to include object sizes in the
+            DataFrame. Defaults to False.
+
+    Raises:
+        ValueError: `metrics_results` can not be empty
+        ValueError: All elements of `metrics_results` must be of the same type
+        ValueError: Base class of elements in `metrics_results` must be of type
+            `MetricResult`
+
+    Returns:
+        pd.DataFrame: The results as a DataFrame.
     """
     ensure_pandas_installed()
     import pandas as pd
@@ -46,4 +61,3 @@ def aggregate_metric_results(
         df_merged = df_merged.drop(columns=list(df_merged.filter(regex=regex_pattern)))
 
     return df_merged
-

From d5e1b93547f1e4ac6fd2da8c78c1dade4121945a Mon Sep 17 00:00:00 2001
From: David Redo
Date: Sun, 8 Dec 2024 12:21:12 +0100
Subject: [PATCH 4/7] feat: Added function to plot the aggregated metric
 results

---
 .../metrics/aggregate_metric_results.py      | 99 +++++++++++++++++++
 supervision/metrics/mean_average_precision.py | 10 +-
 2 files changed, 106 insertions(+), 3 deletions(-)

diff --git a/supervision/metrics/aggregate_metric_results.py b/supervision/metrics/aggregate_metric_results.py
index 03ea737c7..6707cca0c 100644
--- a/supervision/metrics/aggregate_metric_results.py
+++ b/supervision/metrics/aggregate_metric_results.py
@@ -1,5 +1,9 @@
 from typing import TYPE_CHECKING, List
 
+import numpy as np
+from matplotlib import pyplot as plt
+
+from supervision.draw.color import LEGACY_COLOR_PALETTE
 from supervision.metrics.core import MetricResult
 from supervision.metrics.utils.utils import ensure_pandas_installed
 
@@ -61,3 +65,98 @@ def aggregate_metric_results(
         df_merged = df_merged.drop(columns=list(df_merged.filter(regex=regex_pattern)))
 
     return df_merged
+
+
+def plot_aggregate_metric_results(
+    metrics_results: List[MetricResult],
+    model_names: List[str],
+    include_object_sizes=False,
+):
+    """
+    Plot a bar chart with the results of multiple metrics.
+
+    Args:
+        metrics_results (List[MetricResult]): List of results to be plotted.
+        model_names (List[str]): List of model names corresponding to the results.
+        include_object_sizes (bool, optional): Whether to include object sizes in the
+            plot. Defaults to False.
+
+    Raises:
+        ValueError: `metrics_results` can not be empty
+        ValueError: All elements of `metrics_results` must be of the same type
+        ValueError: Base class of elements in `metrics_results` must be of type
+            `MetricResult`
+    """
+    assert len(metrics_results) == len(
+        model_names
+    ), "Number of metrics results and model names must be equal"
+
+    if len(metrics_results) == 0:
+        raise ValueError("metrics_results must not be empty")
+
+    first_elem_type = type(metrics_results[0])
+    all_same_type = all(isinstance(x, first_elem_type) for x in metrics_results)
+    if not all_same_type:
+        raise ValueError("All metrics_results must be of the same type")
+
+    if not isinstance(metrics_results[0], MetricResult):
+        raise ValueError("Base class of metrics_results must be of type MetricResult")
+
+    model_values = []
+    labels, value, plot_title = metrics_results[0].plot(return_params=True)
+    model_values.append(value)
+
+    for metric in metrics_results[1:]:
+        _, value, _ = metric.plot(return_params=True)
+        model_values.append(value)
+
+    if not include_object_sizes:
+        labels = labels[0:3]
+        aux_values = []
+        for values in model_values:
+            aux_values.append(values[0:3])
+        model_values = aux_values
+
+    n = len(model_names)
+    x_positions = np.arange(len(labels))
+    width = 0.8 / n
+    value_text_rotation = 90 if include_object_sizes else 0
+
+    plt.rcParams["font.family"] = "monospace"
+
+    _, ax = plt.subplots(figsize=(10, 6))
+    ax.set_ylim(0, 1)
+    ax.set_ylabel("Value", fontweight="bold")
+    ax.set_title(plot_title, fontweight="bold")
+
+    ax.set_xticks(x_positions)
+    ax.set_xticklabels(labels, rotation=45, ha="right")
+
+    colors = LEGACY_COLOR_PALETTE[:n]
+
+    for i, model_value in enumerate(model_values):
+        offset = (i - (n - 1) / 2) * width
+        bars = ax.bar(
+            x_positions + offset,
+            model_value,
+            width=width,
+            label=model_names[i],
+            color=colors[i % len(colors)],
+        )
+
+        for bar in bars:
+            y_value = bar.get_height()
+            ax.text(
+                bar.get_x() + bar.get_width() / 2,
+                y_value + 0.02,
+                f"{y_value:.2f}",
+                ha="center",
+                va="bottom",
+                rotation=value_text_rotation,
+            )
+
+    plt.rcParams["font.family"] = "sans-serif"
+
+    plt.legend(loc="best")
+    plt.tight_layout()
+    plt.show()
diff --git a/supervision/metrics/mean_average_precision.py b/supervision/metrics/mean_average_precision.py
index edfef9b55..4854697fb 100644
--- a/supervision/metrics/mean_average_precision.py
+++ b/supervision/metrics/mean_average_precision.py
@@ -15,7 +15,7 @@
     oriented_box_iou_batch,
 )
 from supervision.draw.color import LEGACY_COLOR_PALETTE
-from supervision.metrics.core import Metric, MetricTarget, MetricResult
+from supervision.metrics.core import Metric, MetricResult, MetricTarget
 from supervision.metrics.utils.object_size import (
     ObjectSizeCategory,
     get_detection_size_category,
@@ -559,7 +559,7 @@ def to_pandas(self) -> "pd.DataFrame":
             index=[0],
         )
 
-    def plot(self):
+    def plot(self, return_params=False):
         """
         Plot the mAP results.
 
@@ -571,6 +571,7 @@ def plot(self):
         labels = ["mAP@50:95", "mAP@50", "mAP@75"]
         values = [self.map50_95, self.map50, self.map75]
         colors = [LEGACY_COLOR_PALETTE[0]] * 3
+        plot_title = "Mean Average Precision"
 
         if self.small_objects is not None:
             labels += ["Small: mAP@50:95", "Small: mAP@50", "Small: mAP@75"]
@@ -599,12 +600,15 @@ def plot(self):
             ]
             colors += [LEGACY_COLOR_PALETTE[4]] * 3
 
+        if return_params:
+            return labels, values, plot_title
+
         plt.rcParams["font.family"] = "monospace"
 
         _, ax = plt.subplots(figsize=(10, 6))
         ax.set_ylim(0, 1)
         ax.set_ylabel("Value", fontweight="bold")
-        ax.set_title("Mean Average Precision", fontweight="bold")
+        ax.set_title(plot_title, fontweight="bold")
 
         x_positions = range(len(labels))
         bars = ax.bar(x_positions, values, color=colors, align="center")

From 38886ef0880711fa2b70b8ba2b9aa7c48d2f5a93 Mon Sep 17 00:00:00 2001
From: David Redo
Date: Wed, 11 Dec 2024 20:37:29 +0100
Subject: [PATCH 5/7] Moved aggregate_metric_results to utils

---
 .../{ => utils}/aggregate_metric_results.py   | 15 ++++++++-------
 1 file changed, 8 insertions(+), 7 deletions(-)
 rename supervision/metrics/{ => utils}/aggregate_metric_results.py (92%)

diff --git a/supervision/metrics/aggregate_metric_results.py b/supervision/metrics/utils/aggregate_metric_results.py
similarity index 92%
rename from supervision/metrics/aggregate_metric_results.py
rename to supervision/metrics/utils/aggregate_metric_results.py
index 6707cca0c..cfc3d76f3 100644
--- a/supervision/metrics/aggregate_metric_results.py
+++ b/supervision/metrics/utils/aggregate_metric_results.py
@@ -103,18 +103,19 @@ def plot_aggregate_metric_results(
         raise ValueError("Base class of metrics_results must be of type MetricResult")
 
     model_values = []
-    labels, value, plot_title = metrics_results[0].plot(return_params=True)
-    model_values.append(value)
+    labels, values, title, _ = metrics_results[0]._get_plot_details()
+    model_values.append(values)
 
     for metric in metrics_results[1:]:
-        _, value, _ = metric.plot(return_params=True)
-        model_values.append(value)
+        _, values, _, _ = metric._get_plot_details()
+        model_values.append(values)
 
     if not include_object_sizes:
-        labels = labels[0:3]
+        labels_length = 3 if len(labels) % 3 == 0 else 2
+        labels = labels[:labels_length]
         aux_values = []
         for values in model_values:
-            aux_values.append(values[0:3])
+            aux_values.append(values[:labels_length])
         model_values = aux_values
 
     n = len(model_names)
@@ -127,7 +128,7 @@ def plot_aggregate_metric_results(
     _, ax = plt.subplots(figsize=(10, 6))
     ax.set_ylim(0, 1)
     ax.set_ylabel("Value", fontweight="bold")
-    ax.set_title(plot_title, fontweight="bold")
+    ax.set_title(title, fontweight="bold")
 
     ax.set_xticks(x_positions)
     ax.set_xticklabels(labels, rotation=45, ha="right")

From 5fa6fc07b4967cbd3d84e80bcdeebd5e4a83401a Mon Sep 17 00:00:00 2001
From: David Redo
Date: Wed, 11 Dec 2024 20:38:49 +0100
Subject: [PATCH 6/7] Added _get_plot_details function to MetricResult classes

---
 supervision/metrics/core.py                   |  4 ++
 supervision/metrics/f1_score.py               | 36 ++++++++++++------
 supervision/metrics/mean_average_precision.py | 30 ++++++++++-----
 supervision/metrics/mean_average_recall.py    | 34 ++++++++++++-----
 supervision/metrics/precision.py              | 37 ++++++++++++------
 supervision/metrics/recall.py                 | 38 +++++++++++++------
 6 files changed, 123 insertions(+), 56 deletions(-)

diff --git a/supervision/metrics/core.py b/supervision/metrics/core.py
index bae56785f..b1335ecd6 100644
--- a/supervision/metrics/core.py
+++ b/supervision/metrics/core.py
@@ -46,6 +46,10 @@ def to_pandas(self):
         raise NotImplementedError()
 
     @abstractmethod
     def plot(self):
         raise NotImplementedError()
 
+    @abstractmethod
+    def _get_plot_details(self):
+        raise NotImplementedError()
+
 
 class MetricTarget(Enum):
     """
diff --git a/supervision/metrics/f1_score.py b/supervision/metrics/f1_score.py
index 91d587fbf..5184d5121 100644
--- a/supervision/metrics/f1_score.py
+++ b/supervision/metrics/f1_score.py
@@ -15,7 +15,7 @@
     oriented_box_iou_batch,
 )
 from supervision.draw.color import LEGACY_COLOR_PALETTE
-from supervision.metrics.core import AveragingMethod, Metric, MetricTarget, MetricResult
+from supervision.metrics.core import AveragingMethod, Metric, MetricResult, MetricTarget
 from supervision.metrics.utils.object_size import (
     ObjectSizeCategory,
     get_detection_size_category,
@@ -583,15 +583,15 @@ def to_pandas(self) -> "pd.DataFrame":
 
         return pd.DataFrame(pandas_data, index=[0])
 
-    def plot(self):
+    def _get_plot_details(self) -> Tuple[List[str], List[float], str, List[str]]:
         """
-        Plot the F1 results.
+        Obtain the metric details for plotting them.
 
-        ![example_plot](\
-        https://media.roboflow.com/supervision-docs/metrics/f1_plot_example.png\
-        ){ align=center width="800" }
+        Returns:
+            Tuple[List[str], List[float], str, List[str]]: The details for plotting the
+                metric. It is a tuple of four elements: a list of labels, a list of
+                values, the title of the plot and the bar colors.
         """
-
         labels = ["F1@50", "F1@75"]
         values = [self.f1_50, self.f1_75]
         colors = [LEGACY_COLOR_PALETTE[0]] * 2
@@ -614,16 +614,28 @@ def plot(self):
             values += [large_objects.f1_50, large_objects.f1_75]
             colors += [LEGACY_COLOR_PALETTE[4]] * 2
 
-        plt.rcParams["font.family"] = "monospace"
-
-        _, ax = plt.subplots(figsize=(10, 6))
-        ax.set_ylim(0, 1)
-        ax.set_ylabel("Value", fontweight="bold")
         title = (
             f"F1 Score, by Object Size"
            f"\n(target: {self.metric_target.value},"
             f" averaging: {self.averaging_method.value})"
         )
+        return labels, values, title, colors
+
+    def plot(self):
+        """
+        Plot the F1 results.
+
+        ![example_plot](\
+        https://media.roboflow.com/supervision-docs/metrics/f1_plot_example.png\
+        ){ align=center width="800" }
+        """
+        labels, values, title, colors = self._get_plot_details()
+
+        plt.rcParams["font.family"] = "monospace"
+
+        _, ax = plt.subplots(figsize=(10, 6))
+        ax.set_ylim(0, 1)
+        ax.set_ylabel("Value", fontweight="bold")
         ax.set_title(title, fontweight="bold")
 
         x_positions = range(len(labels))
diff --git a/supervision/metrics/mean_average_precision.py b/supervision/metrics/mean_average_precision.py
index 4854697fb..1e0637a33 100644
--- a/supervision/metrics/mean_average_precision.py
+++ b/supervision/metrics/mean_average_precision.py
@@ -559,19 +559,19 @@ def to_pandas(self) -> "pd.DataFrame":
             index=[0],
         )
 
-    def plot(self, return_params=False):
+    def _get_plot_details(self) -> Tuple[List[str], List[float], str, List[str]]:
         """
-        Plot the mAP results.
+        Obtain the metric details for plotting them.
 
-        ![example_plot](\
-        https://media.roboflow.com/supervision-docs/metrics/mAP_plot_example.png\
-        ){ align=center width="800" }
+        Returns:
+            Tuple[List[str], List[float], str, List[str]]: The details for plotting the
+                metric. It is a tuple of four elements: a list of labels, a list of
+                values, the title of the plot and the bar colors.
""" - labels = ["mAP@50:95", "mAP@50", "mAP@75"] values = [self.map50_95, self.map50, self.map75] colors = [LEGACY_COLOR_PALETTE[0]] * 3 - plot_title = "Mean Average Precision" + title = "Mean Average Precision" if self.small_objects is not None: labels += ["Small: mAP@50:95", "Small: mAP@50", "Small: mAP@75"] @@ -600,15 +600,25 @@ def plot(self, return_params=False): ] colors += [LEGACY_COLOR_PALETTE[4]] * 3 - if return_params: - return labels, values, plot_title + return labels, values, title, colors + + def plot(self): + """ + Plot the mAP results. + + ![example_plot](\ + https://media.roboflow.com/supervision-docs/metrics/mAP_plot_example.png\ + ){ align=center width="800" } + """ + + labels, values, title, colors = self._get_plot_details() plt.rcParams["font.family"] = "monospace" _, ax = plt.subplots(figsize=(10, 6)) ax.set_ylim(0, 1) ax.set_ylabel("Value", fontweight="bold") - ax.set_title(plot_title, fontweight="bold") + ax.set_title(title, fontweight="bold") x_positions = range(len(labels)) bars = ax.bar(x_positions, values, color=colors, align="center") diff --git a/supervision/metrics/mean_average_recall.py b/supervision/metrics/mean_average_recall.py index 6f2aa28ae..6e4d9c0d6 100644 --- a/supervision/metrics/mean_average_recall.py +++ b/supervision/metrics/mean_average_recall.py @@ -15,7 +15,7 @@ oriented_box_iou_batch, ) from supervision.draw.color import LEGACY_COLOR_PALETTE -from supervision.metrics.core import Metric, MetricTarget, MetricResult +from supervision.metrics.core import Metric, MetricResult, MetricTarget from supervision.metrics.utils.object_size import ( ObjectSizeCategory, get_detection_size_category, @@ -622,13 +622,14 @@ def to_pandas(self) -> "pd.DataFrame": return pd.DataFrame(pandas_data, index=[0]) - def plot(self): + def _get_plot_details(self) -> Tuple[List[str], List[float], str, List[str]]: """ - Plot the Mean Average Recall results. + Obtain the metric details for plotting them. - ![example_plot](\ - https://media.roboflow.com/supervision-docs/metrics/mAR_plot_example.png\ - ){ align=center width="800" } + Returns: + Tuple[List[str], List[float], str, List[str]]: The details for plotting the + metric. It is a tuple of four elements: a list of labels, a list of + values, the title of the plot and the bar colors. """ labels = ["mAR @ 1", "mAR @ 10", "mAR @ 100"] values = [self.mAR_at_1, self.mAR_at_10, self.mAR_at_100] @@ -664,15 +665,28 @@ def plot(self): ] colors += [LEGACY_COLOR_PALETTE[4]] * 3 + title = ( + f"Mean Average Recall, by Object Size" + f"\n(target: {self.metric_target.value})" + ) + return labels, values, title, colors + + def plot(self): + """ + Plot the Mean Average Recall results. 
+
+        ![example_plot](\
+        https://media.roboflow.com/supervision-docs/metrics/mAR_plot_example.png\
+        ){ align=center width="800" }
+        """
+
+        labels, values, title, colors = self._get_plot_details()
+
         plt.rcParams["font.family"] = "monospace"
 
         _, ax = plt.subplots(figsize=(10, 6))
         ax.set_ylim(0, 1)
         ax.set_ylabel("Value", fontweight="bold")
-        title = (
-            f"Mean Average Recall, by Object Size"
-            f"\n(target: {self.metric_target.value})"
-        )
         ax.set_title(title, fontweight="bold")
 
         x_positions = range(len(labels))
diff --git a/supervision/metrics/precision.py b/supervision/metrics/precision.py
index c10c7aca4..49eae5adb 100644
--- a/supervision/metrics/precision.py
+++ b/supervision/metrics/precision.py
@@ -15,7 +15,7 @@
     oriented_box_iou_batch,
 )
 from supervision.draw.color import LEGACY_COLOR_PALETTE
-from supervision.metrics.core import AveragingMethod, Metric, MetricTarget, MetricResult
+from supervision.metrics.core import AveragingMethod, Metric, MetricResult, MetricTarget
 from supervision.metrics.utils.object_size import (
     ObjectSizeCategory,
     get_detection_size_category,
@@ -588,15 +588,15 @@ def to_pandas(self) -> "pd.DataFrame":
 
         return pd.DataFrame(pandas_data, index=[0])
 
-    def plot(self):
+    def _get_plot_details(self) -> Tuple[List[str], List[float], str, List[str]]:
         """
-        Plot the precision results.
+        Obtain the metric details for plotting them.
 
-        ![example_plot](\
-        https://media.roboflow.com/supervision-docs/metrics/precision_plot_example.png\
-        ){ align=center width="800" }
+        Returns:
+            Tuple[List[str], List[float], str, List[str]]: The details for plotting the
+                metric. It is a tuple of four elements: a list of labels, a list of
+                values, the title of the plot and the bar colors.
         """
-
         labels = ["Precision@50", "Precision@75"]
         values = [self.precision_at_50, self.precision_at_75]
         colors = [LEGACY_COLOR_PALETTE[0]] * 2
@@ -619,16 +619,29 @@ def plot(self):
             values += [large_objects.precision_at_50, large_objects.precision_at_75]
             colors += [LEGACY_COLOR_PALETTE[4]] * 2
 
-        plt.rcParams["font.family"] = "monospace"
-
-        _, ax = plt.subplots(figsize=(10, 6))
-        ax.set_ylim(0, 1)
-        ax.set_ylabel("Value", fontweight="bold")
         title = (
             f"Precision, by Object Size"
             f"\n(target: {self.metric_target.value},"
             f" averaging: {self.averaging_method.value})"
         )
+        return labels, values, title, colors
+
+    def plot(self):
+        """
+        Plot the precision results.
+
+        ![example_plot](\
+        https://media.roboflow.com/supervision-docs/metrics/precision_plot_example.png\
+        ){ align=center width="800" }
+        """
+
+        labels, values, title, colors = self._get_plot_details()
+
+        plt.rcParams["font.family"] = "monospace"
+
+        _, ax = plt.subplots(figsize=(10, 6))
+        ax.set_ylim(0, 1)
+        ax.set_ylabel("Value", fontweight="bold")
         ax.set_title(title, fontweight="bold")
 
         x_positions = range(len(labels))
diff --git a/supervision/metrics/recall.py b/supervision/metrics/recall.py
index 3f8b9e808..b84658aec 100644
--- a/supervision/metrics/recall.py
+++ b/supervision/metrics/recall.py
@@ -15,7 +15,7 @@
     oriented_box_iou_batch,
 )
 from supervision.draw.color import LEGACY_COLOR_PALETTE
-from supervision.metrics.core import AveragingMethod, Metric, MetricTarget, MetricResult
+from supervision.metrics.core import AveragingMethod, Metric, MetricResult, MetricTarget
 from supervision.metrics.utils.object_size import (
     ObjectSizeCategory,
     get_detection_size_category,
@@ -587,15 +587,15 @@ def to_pandas(self) -> "pd.DataFrame":
 
         return pd.DataFrame(pandas_data, index=[0])
 
-    def plot(self):
+    def _get_plot_details(self) -> Tuple[List[str], List[float], str, List[str]]:
         """
-        Plot the recall results.
+        Obtain the metric details for plotting them.
 
-        ![example_plot](\
-        https://media.roboflow.com/supervision-docs/metrics/recall_plot_example.png\
-        ){ align=center width="800" }
+        Returns:
+            Tuple[List[str], List[float], str, List[str]]: The details for plotting the
+                metric. It is a tuple of four elements: a list of labels, a list of
+                values, the title of the plot and the bar colors.
         """
-
         labels = ["Recall@50", "Recall@75"]
         values = [self.recall_at_50, self.recall_at_75]
         colors = [LEGACY_COLOR_PALETTE[0]] * 2
@@ -618,16 +618,30 @@ def plot(self):
             values += [large_objects.recall_at_50, large_objects.recall_at_75]
             colors += [LEGACY_COLOR_PALETTE[4]] * 2
 
-        plt.rcParams["font.family"] = "monospace"
-
-        _, ax = plt.subplots(figsize=(10, 6))
-        ax.set_ylim(0, 1)
-        ax.set_ylabel("Value", fontweight="bold")
         title = (
             f"Recall, by Object Size"
             f"\n(target: {self.metric_target.value},"
             f" averaging: {self.averaging_method.value})"
         )
+
+        return labels, values, title, colors
+
+    def plot(self):
+        """
+        Plot the recall results.
+
+        ![example_plot](\
+        https://media.roboflow.com/supervision-docs/metrics/recall_plot_example.png\
+        ){ align=center width="800" }
+        """
+
+        labels, values, title, colors = self._get_plot_details()
+
+        plt.rcParams["font.family"] = "monospace"
+
+        _, ax = plt.subplots(figsize=(10, 6))
+        ax.set_ylim(0, 1)
+        ax.set_ylabel("Value", fontweight="bold")
         ax.set_title(title, fontweight="bold")
 
         x_positions = range(len(labels))

From 4bc57d4a3397040e5d6c51ee21e9317b472bba8f Mon Sep 17 00:00:00 2001
From: David Redo
Date: Wed, 11 Dec 2024 20:48:40 +0100
Subject: [PATCH 7/7] Updated comments for aggregate functions

---
 supervision/metrics/core.py                   | 12 ++++++++++++
 .../metrics/utils/aggregate_metric_results.py | 16 ++++++++--------
 2 files changed, 20 insertions(+), 8 deletions(-)

diff --git a/supervision/metrics/core.py b/supervision/metrics/core.py
index b1335ecd6..203bb59ed 100644
--- a/supervision/metrics/core.py
+++ b/supervision/metrics/core.py
@@ -40,14 +40,26 @@ class MetricResult(ABC):
 
     @abstractmethod
     def to_pandas(self):
+        """
+        Convert the result to a pandas DataFrame.
+
+        Returns:
+            (pd.DataFrame): The result as a DataFrame.
+        """
         raise NotImplementedError()
 
     @abstractmethod
     def plot(self):
+        """
+        Plot the results.
+        """
         raise NotImplementedError()
 
     @abstractmethod
     def _get_plot_details(self):
+        """
+        Get the metric details to be plotted.
+        """
         raise NotImplementedError()
 
diff --git a/supervision/metrics/utils/aggregate_metric_results.py b/supervision/metrics/utils/aggregate_metric_results.py
index cfc3d76f3..57feac230 100644
--- a/supervision/metrics/utils/aggregate_metric_results.py
+++ b/supervision/metrics/utils/aggregate_metric_results.py
@@ -26,7 +26,7 @@ def aggregate_metric_results(
         DataFrame. Defaults to False.
 
     Raises:
-        ValueError: `metrics_results` can not be empty
+        ValueError: List `metrics_results` can not be empty
         ValueError: All elements of `metrics_results` must be of the same type
         ValueError: Base class of elements in `metrics_results` must be of type
             `MetricResult`
@@ -39,15 +39,15 @@ def aggregate_metric_results(
 
     assert len(metrics_results) == len(
         model_names
-    ), "Number of metrics results and model names must be equal"
+    ), "Length of metrics_results and model_names must be equal"
 
     if len(metrics_results) == 0:
-        raise ValueError("metrics_results must not be empty")
+        raise ValueError("List metrics_results must not be empty")
 
     first_elem_type = type(metrics_results[0])
     all_same_type = all(isinstance(x, first_elem_type) for x in metrics_results)
     if not all_same_type:
-        raise ValueError("All metrics_results must be of the same type")
+        raise ValueError("All metrics_results elements must be of the same type")
 
     if not isinstance(metrics_results[0], MetricResult):
         raise ValueError("Base class of metrics_results must be of type MetricResult")
@@ -82,22 +82,22 @@ def plot_aggregate_metric_results(
         plot. Defaults to False.
 
     Raises:
-        ValueError: `metrics_results` can not be empty
+        ValueError: List `metrics_results` can not be empty
         ValueError: All elements of `metrics_results` must be of the same type
         ValueError: Base class of elements in `metrics_results` must be of type
             `MetricResult`
     """
     assert len(metrics_results) == len(
         model_names
-    ), "Number of metrics results and model names must be equal"
+    ), "Length of metrics_results and model_names must be equal"
 
     if len(metrics_results) == 0:
-        raise ValueError("metrics_results must not be empty")
+        raise ValueError("List metrics_results must not be empty")
 
     first_elem_type = type(metrics_results[0])
     all_same_type = all(isinstance(x, first_elem_type) for x in metrics_results)
     if not all_same_type:
-        raise ValueError("All metrics_results must be of the same type")
+        raise ValueError("All metrics_results elements must be of the same type")
 
     if not isinstance(metrics_results[0], MetricResult):
         raise ValueError("Base class of metrics_results must be of type MetricResult")
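
Example usage of the helpers added by this series. This is a minimal sketch,
not part of the patches themselves: the two models, the prediction and target
detections, and the `MeanAveragePrecision` workflow shown here are
illustrative assumptions; only `aggregate_metric_results` and
`plot_aggregate_metric_results` come from this series.

    from supervision.metrics import MeanAveragePrecision
    from supervision.metrics.utils.aggregate_metric_results import (
        aggregate_metric_results,
        plot_aggregate_metric_results,
    )

    # Assumed to exist: predictions_model_a, predictions_model_b and targets,
    # e.g. sv.Detections collected from two models on the same dataset.
    result_a = MeanAveragePrecision().update(predictions_model_a, targets).compute()
    result_b = MeanAveragePrecision().update(predictions_model_b, targets).compute()

    # One row per model; the small/medium/large columns are dropped unless
    # include_object_sizes=True is passed.
    df = aggregate_metric_results(
        [result_a, result_b],
        model_names=["model-a", "model-b"],
    )
    print(df)

    # Grouped bar chart: one group per label (e.g. mAP@50:95, mAP@50, mAP@75),
    # with one bar per model inside each group.
    plot_aggregate_metric_results(
        [result_a, result_b],
        model_names=["model-a", "model-b"],
    )

With include_object_sizes=True both helpers keep the per-size entries; the
plot then rotates the per-bar value labels by 90 degrees so they stay
readable.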