From 5d97efd075c57e9b61bff9222e37b57b4828c67e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E9=BC=8E=E5=BD=A6?= Date: Mon, 8 Mar 2021 16:05:42 +0800 Subject: [PATCH] [FIX] Fix regressor y value check (#53) * [FIX] Fix regressor y value check * [FIX] Revert Code Quality * [FIX] Fix Logic * [FIX] Reformat to pass code quality check * refactor some code to pass ci * Update CHANGELOG.rst --- CHANGELOG.rst | 1 + deepforest/cascade.py | 27 ++++++++++++++++++++++----- 2 files changed, 23 insertions(+), 5 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 8aae007..7c82c20 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -30,6 +30,7 @@ Version 0.1.* .. |Fix| replace:: :raw-html:`Fix` :raw-latex:`{\small\sc [Fix]}` .. |API| replace:: :raw-html:`API Change` :raw-latex:`{\small\sc [API Change]}` +- |Enhancement| improve target checks for :obj:`CascadeForestRegressor` (`#53 `__) @chendingyan - |Fix| fix the prediction workflow with only one cascade layer (`#56 `__) @xuyxu - |Fix| fix inconsistency on predictor name (`#52 `__) @xuyxu - |Feature| add official support for ManyLinux-aarch64 (`#47 `__) @xuyxu diff --git a/deepforest/cascade.py b/deepforest/cascade.py index e427786..4da19b4 100644 --- a/deepforest/cascade.py +++ b/deepforest/cascade.py @@ -1415,20 +1415,37 @@ def __init__( self.type_of_target_ = None def _check_target_values(self, y): - """ - Check the input target values for regressor. - """ + """Check the input target values for regressor.""" self.type_of_target_ = type_of_target(y) + + if not self._check_array_numeric(y): + msg = ( + "CascadeForestRegressor only accepts numeric values as" + " valid target values." + ) + raise ValueError(msg) + if self.type_of_target_ not in ( "continuous", "continuous-multioutput", + "multiclass", + "multiclass-multioutput", ): msg = ( - "CascadeForestRegressor is used for univariate or multi-variate regression," - " but the target values seem not to be one of them." + "CascadeForestRegressor is used for univariate or" + " multi-variate regression, but the target values seem not" + " to be one of them." ) raise ValueError(msg) + def _check_array_numeric(self, y): + """Check the input numpy array y is all numeric.""" + numeric_types = np.typecodes['AllInteger'] + np.typecodes["AllFloat"] + if y.dtype.kind in numeric_types: + return True + else: + return False + def _repr_performance(self, pivot): msg = "Val MSE = {:.5f}" return msg.format(pivot)