Skip to content

Commit

Permalink
[FIX] Fix regressor y value check (#53)
Browse files Browse the repository at this point in the history
* [FIX] Fix regressor y value check

* [FIX] Revert Code Quality

* [FIX] Fix Logic

* [FIX] Reformat to pass code quality check

* refactor some code to pass ci

* Update CHANGELOG.rst
  • Loading branch information
chendingyan authored Mar 8, 2021
1 parent 165e5d5 commit 5d97efd
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 5 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ Version 0.1.*
.. |Fix| replace:: :raw-html:`<span class="badge badge-danger">Fix</span>` :raw-latex:`{\small\sc [Fix]}`
.. |API| replace:: :raw-html:`<span class="badge badge-warning">API Change</span>` :raw-latex:`{\small\sc [API Change]}`

- |Enhancement| improve target checks for :obj:`CascadeForestRegressor` (`#53 <https://github.com/LAMDA-NJU/Deep-Forest/pull/53>`__) @chendingyan
- |Fix| fix the prediction workflow with only one cascade layer (`#56 <https://github.com/LAMDA-NJU/Deep-Forest/pull/56>`__) @xuyxu
- |Fix| fix inconsistency on predictor name (`#52 <https://github.com/LAMDA-NJU/Deep-Forest/pull/52>`__) @xuyxu
- |Feature| add official support for ManyLinux-aarch64 (`#47 <https://github.com/LAMDA-NJU/Deep-Forest/pull/47>`__) @xuyxu
Expand Down
27 changes: 22 additions & 5 deletions deepforest/cascade.py
Original file line number Diff line number Diff line change
Expand Up @@ -1415,20 +1415,37 @@ def __init__(
self.type_of_target_ = None

def _check_target_values(self, y):
"""
Check the input target values for regressor.
"""
"""Check the input target values for regressor."""
self.type_of_target_ = type_of_target(y)

if not self._check_array_numeric(y):
msg = (
"CascadeForestRegressor only accepts numeric values as"
" valid target values."
)
raise ValueError(msg)

if self.type_of_target_ not in (
"continuous",
"continuous-multioutput",
"multiclass",
"multiclass-multioutput",
):
msg = (
"CascadeForestRegressor is used for univariate or multi-variate regression,"
" but the target values seem not to be one of them."
"CascadeForestRegressor is used for univariate or"
" multi-variate regression, but the target values seem not"
" to be one of them."
)
raise ValueError(msg)

def _check_array_numeric(self, y):
"""Check the input numpy array y is all numeric."""
numeric_types = np.typecodes['AllInteger'] + np.typecodes["AllFloat"]
if y.dtype.kind in numeric_types:
return True
else:
return False

def _repr_performance(self, pivot):
msg = "Val MSE = {:.5f}"
return msg.format(pivot)
Expand Down

0 comments on commit 5d97efd

Please sign in to comment.