Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] sklearn 1.6.dev0 adjustments. #335

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions treeple/ensemble/_honest_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -728,10 +728,16 @@
return oob_samples

def __sklearn_tags__(self):
# XXX: nans should be supportable in HRF
tags = super().__sklearn_tags__()
tags.classifier_tags.multi_output = False
# XXX: nans should be supportable in HRF
tags.input_tags.allow_nan = False

try:
# sklearn >= 1.6 tags were revamped
tags.target_tags.multi_output = False
except AttributeError:
tags.classifier_tags.multi_output = False

Check warning on line 739 in treeple/ensemble/_honest_forest.py

View check run for this annotation

Codecov / codecov/patch

treeple/ensemble/_honest_forest.py#L738-L739

Added lines #L738 - L739 were not covered by tests

return tags

def decision_path(self, X):
Expand Down
1 change: 1 addition & 0 deletions treeple/experimental/tests/test_sdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ def test_sklearn_compatible_estimator(estimator, check):
# XXX: can include this "generalization" in the future if it's useful
if check.func.__name__ in [
"check_class_weight_classifiers",
"check_sample_weight_equivalence",
]:
pytest.skip()
check(estimator)
2 changes: 1 addition & 1 deletion treeple/neighbors.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from treeple.tree._neighbors import _compute_distance_matrix, compute_forest_similarity_matrix


class NearestNeighborsMetaEstimator(BaseEstimator, MetaEstimatorMixin):
class NearestNeighborsMetaEstimator(MetaEstimatorMixin, BaseEstimator):
"""Meta-estimator for nearest neighbors.

Uses a decision-tree, or forest model to compute distances between samples
Expand Down
2 changes: 1 addition & 1 deletion treeple/stats/permuteforest.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ class PermutationHonestForestClassifier(HonestForestClassifier):
**tree_estimator_params : dict
Parameters to pass to the underlying base tree estimators.
These must be parameters for ``tree_estimator``.

Attributes
----------
estimator : treeple.tree.HonestTreeClassifier
Expand Down
52 changes: 27 additions & 25 deletions treeple/stats/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import numpy as np
import scipy.sparse as sp
from joblib import Parallel, delayed
from joblib import Parallel, delayed, parallel_config
from numpy.typing import ArrayLike
from scipy.stats import entropy
from sklearn.ensemble._forest import _generate_unsampled_indices, _get_n_samples_bootstrap
Expand Down Expand Up @@ -234,18 +234,19 @@ def _compute_null_distribution_coleman(

# generate the random seeds for the parallel jobs
ss = np.random.SeedSequence(seed)
out = Parallel(n_jobs=n_jobs)(
delayed(_parallel_build_null_forests)(
y_pred_ind_arr,
n_estimators,
all_y_pred,
y_test,
seed,
metric,
**metric_kwargs,
with parallel_config("multiprocessing"):
out = Parallel(n_jobs=n_jobs)(
delayed(_parallel_build_null_forests)(
y_pred_ind_arr,
n_estimators,
all_y_pred,
y_test,
seed,
metric,
**metric_kwargs,
)
for i, seed in zip(range(n_repeats), ss.spawn(n_repeats))
Comment on lines +237 to +248
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why was this change made?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I remember correctly, the default loky backend would segfault while unit testing the *Oblique trees.

)
for i, seed in zip(range(n_repeats), ss.spawn(n_repeats))
)

for idx, (first_half_metric, second_half_metric) in enumerate(out):
metric_star[idx] = first_half_metric
Expand Down Expand Up @@ -512,20 +513,21 @@ def _compute_null_distribution_coleman_sparse(

# generate the random seeds for the parallel jobs
ss = np.random.SeedSequence(seed)
out = Parallel(n_jobs=n_jobs)(
delayed(_parallel_build_null_forests_sparse)(
np.arange(n_trees),
oob_predictions,
oob_indicators,
y_test,
n_outputs,
seed,
True,
metric,
**metric_kwargs,
with parallel_config("multiprocessing"):
out = Parallel(n_jobs=n_jobs)(
Comment on lines +516 to +517
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as above.

delayed(_parallel_build_null_forests_sparse)(
np.arange(n_trees),
oob_predictions,
oob_indicators,
y_test,
n_outputs,
seed,
True,
metric,
**metric_kwargs,
)
for _, seed in zip(range(n_repeats), ss.spawn(n_repeats))
)
for _, seed in zip(range(n_repeats), ss.spawn(n_repeats))
)

metric_star = np.zeros((n_repeats,))
metric_star_pi = np.zeros((n_repeats,))
Expand Down
1 change: 1 addition & 0 deletions treeple/tests/test_honest_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,7 @@ def test_sklearn_compatible_estimator(estimator, check):
# for fitting the tree's splits
if check.func.__name__ in [
"check_class_weight_classifiers",
"check_sample_weight_equivalence",
# TODO: this is an error. Somehow a segfault is raised when fit is called first and
# then partial_fit
"check_fit_score_takes_y",
Expand Down
4 changes: 4 additions & 0 deletions treeple/tests/test_multiview_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@
]
)
def test_sklearn_compatible_estimator(estimator, check):
if check.func.__name__ in [
"check_sample_weight_equivalence",
]:
pytest.skip()
check(estimator)


Expand Down
5 changes: 4 additions & 1 deletion treeple/tests/test_supervised_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,11 +196,14 @@ def test_sklearn_compatible_estimator(estimator, check):
if isinstance(
estimator,
(
ExtraObliqueRandomForestRegressor,
ObliqueRandomForestRegressor,
PatchObliqueRandomForestRegressor,
ExtraObliqueRandomForestClassifier,
ObliqueRandomForestClassifier,
PatchObliqueRandomForestClassifier,
),
) and check.func.__name__ in ["check_fit_score_takes_y"]:
) and check.func.__name__ in ["check_sample_weight_equivalence", "check_fit_score_takes_y"]:
pytest.skip()
check(estimator)

Expand Down
1 change: 1 addition & 0 deletions treeple/tests/test_unsupervised_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def test_sklearn_compatible_estimator(estimator, check):
"check_methods_subset_invariance",
# # sample weights do not necessarily imply a sample is not used in clustering
"check_sample_weights_invariance",
"check_sample_weight_equivalence",
# # sample order is not preserved in predict
"check_methods_sample_order_invariance",
]:
Expand Down
1 change: 1 addition & 0 deletions treeple/tree/tests/test_honest_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,7 @@ def test_sklearn_compatible_estimator(estimator, check):
"check_class_weight_classifiers",
"check_classifier_multioutput",
"check_do_not_raise_errors_in_init_or_set_params",
"check_sample_weight_equivalence",
]:
pytest.skip()
check(estimator)
Expand Down
1 change: 1 addition & 0 deletions treeple/tree/tests/test_unsupervised_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ def test_sklearn_compatible_transformer(estimator, check):
"check_sample_weights_invariance",
# sample order is not preserved in predict
"check_methods_sample_order_invariance",
"check_sample_weight_equivalence",
]:
pytest.skip()
check(estimator)
Expand Down
Loading