Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sklearn 1.6 compat #70

Merged
merged 4 commits into from
Jan 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 10 additions & 11 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ urls.Documentation = "https://sklearn-ann.readthedocs.io/"
dynamic = ["version", "readme"]
requires-python = "<3.13,>=3.9" # enforced by scipy
dependencies = [
"scikit-learn>=0.24.0",
"scikit-learn>=1.6.0",
"scipy>=1.11.1,<2.0.0",
]

Expand All @@ -23,7 +23,7 @@ tests = [
docs = [
"sphinx>=7",
"sphinx-gallery>=0.8.2",
"sphinx-book-theme>=1.1.0rc1",
"sphinx-book-theme>=1.1.0",
"sphinx-issues>=1.2.0",
"numpydoc>=1.1.0",
"matplotlib>=3.3.3",
Expand All @@ -37,6 +37,7 @@ faiss = [
]
pynndescent = [
"pynndescent>=0.5.1,<1.0.0",
"numba>=0.52",
]
nmslib = [
"nmslib>=2.1.1,<3.0.0 ; python_version < '3.11'",
Expand Down Expand Up @@ -84,16 +85,14 @@ ignore = [
[tool.ruff.lint.isort]
known-first-party = ["sklearn_ann"]

[tool.hatch.envs.default]
features = [
"tests",
"docs",
"annlibs",
]
[tool.hatch.envs.docs]
installer = "uv"
features = ["docs", "annlibs"]
scripts.build = "sphinx-build -M html docs docs/_build"

[tool.hatch.envs.default.scripts]
test = "pytest {args:tests}"
build-docs = "sphinx-build -M html docs docs/_build"
[tool.hatch.envs.hatch-test]
default-args = []
features = ["tests", "annlibs"]

[tool.hatch.build.targets.wheel]
packages = ["src/sklearn_ann"]
Expand Down
10 changes: 9 additions & 1 deletion src/sklearn_ann/cluster/rnn_dbscan.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from collections import deque
from typing import cast

import numpy as np
from sklearn.base import BaseEstimator, ClusterMixin
from sklearn.neighbors import KNeighborsTransformer
from sklearn.utils import Tags
from sklearn.utils.validation import validate_data

from ..utils import get_sparse_row

Expand Down Expand Up @@ -143,7 +146,7 @@ def __init__(
self.keep_knns = keep_knns

def fit(self, X, y=None):
X = self._validate_data(X, accept_sparse="csr")
X = validate_data(self, X, accept_sparse="csr")
if self.input_guarantee == "none":
algorithm = KNeighborsTransformer(n_neighbors=self.n_neighbors)
X = algorithm.fit_transform(X)
Expand Down Expand Up @@ -181,6 +184,11 @@ def drop_knns(self):
del self.knns_
del self.rev_knns_

def __sklearn_tags__(self) -> Tags:
    """Declare estimator capabilities for scikit-learn >= 1.6.

    Extends the inherited tags to advertise that ``fit`` accepts sparse
    (CSR) input.
    """
    inherited = cast(Tags, super().__sklearn_tags__())
    inherited.input_tags.sparse = True  # fit() validates with accept_sparse="csr"
    return inherited


def simple_rnn_dbscan_pipeline(
neighbor_transformer, n_neighbors, n_jobs=None, keep_knns=None, **kwargs
Expand Down
15 changes: 9 additions & 6 deletions src/sklearn_ann/kneighbors/annoy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import numpy as np
from scipy.sparse import csr_matrix
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.utils import Tags, TargetTags, TransformerTags
from sklearn.utils.validation import validate_data

from ..utils import TransformerChecksMixin

Expand All @@ -16,7 +18,7 @@ def __init__(self, n_neighbors=5, *, metric="euclidean", n_trees=10, search_k=-1
self.metric = metric

def fit(self, X, y=None):
X = self._validate_data(X)
X = validate_data(self, X)
self.n_samples_fit_ = X.shape[0]
metric = self.metric if self.metric != "sqeuclidean" else "euclidean"
self.annoy_ = annoy.AnnoyIndex(X.shape[1], metric=metric)
Expand Down Expand Up @@ -68,8 +70,9 @@ def _transform(self, X):

return kneighbors_graph

def _more_tags(self):
return {
"_xfail_checks": {"check_estimators_pickle": "Cannot pickle AnnoyIndex"},
"requires_y": False,
}
def __sklearn_tags__(self) -> Tags:
    """Return sklearn >= 1.6 estimator tags for this transformer.

    Declares a transformer that does not require ``y``. Uses the default
    ``TransformerTags`` (``preserves_dtype=["float64"]``); a reviewer noted
    this default may need confirming for Annoy — TODO verify.
    """
    return Tags(
        estimator_type="transformer",
        target_tags=TargetTags(required=False),
        transformer_tags=TransformerTags(),
    )
21 changes: 10 additions & 11 deletions src/sklearn_ann/kneighbors/faiss.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from joblib import cpu_count
from scipy.sparse import csr_matrix
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.utils import Tags, TargetTags, TransformerTags
from sklearn.utils.validation import validate_data

from ..utils import TransformerChecksMixin, postprocess_knn_csr

Expand Down Expand Up @@ -85,7 +87,7 @@ def _metric_info(self):

def fit(self, X, y=None):
normalize = self._metric_info.get("normalize", False)
X = self._validate_data(X, dtype=np.float32, copy=normalize)
X = validate_data(self, X, dtype=np.float32, copy=normalize)
self.n_samples_fit_ = X.shape[0]
if self.n_jobs == -1:
n_jobs = cpu_count()
Expand Down Expand Up @@ -157,14 +159,11 @@ def _transform(self, X):
def fit_transform(self, X, y=None):
return self.fit(X, y=y)._transform(X=None)

def _more_tags(self):
return {
"_xfail_checks": {
"check_estimators_pickle": "Cannot pickle FAISS index",
"check_methods_subset_invariance": "Unable to reset FAISS internal RNG",
},
"requires_y": False,
"preserves_dtype": [np.float32],
def __sklearn_tags__(self) -> Tags:
    """Return sklearn >= 1.6 estimator tags for this transformer.

    Declares a transformer that does not require ``y``, preserves float32
    input, and is non-deterministic.
    """
    return Tags(
        estimator_type="transformer",
        target_tags=TargetTags(required=False),
        transformer_tags=TransformerTags(preserves_dtype=[np.float32]),
        # Could be made deterministic *if* we could reset FAISS's internal RNG
        non_deterministic=True,
    )
14 changes: 8 additions & 6 deletions src/sklearn_ann/kneighbors/nmslib.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import numpy as np
from scipy.sparse import csr_matrix
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.utils import Tags, TransformerTags
from sklearn.utils.validation import validate_data

from ..utils import TransformerChecksMixin, check_metric

Expand All @@ -28,7 +30,7 @@ def __init__(
self.n_jobs = n_jobs

def fit(self, X, y=None):
X = self._validate_data(X)
X = validate_data(self, X)
self.n_samples_fit_ = X.shape[0]

check_metric(self.metric, METRIC_MAP)
Expand Down Expand Up @@ -62,8 +64,8 @@ def transform(self, X):

return kneighbors_graph

def _more_tags(self):
return {
"_xfail_checks": {"check_estimators_pickle": "Cannot pickle NMSLib index"},
"preserves_dtype": [np.float32],
}
def __sklearn_tags__(self) -> Tags:
    """Return sklearn >= 1.6 estimator tags: a float32-preserving transformer."""
    dtype_tags = TransformerTags(preserves_dtype=[np.float32])
    return Tags(estimator_type="transformer", transformer_tags=dtype_tags)
3 changes: 2 additions & 1 deletion src/sklearn_ann/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import numpy as np
from scipy.sparse import csr_matrix
from sklearn.utils.validation import validate_data


def check_metric(metric, metrics):
Expand Down Expand Up @@ -90,6 +91,6 @@ class TransformerChecksMixin:
def _transform_checks(self, X, *fitted_props, **check_params):
    """Validate ``X`` and confirm the estimator is fitted before transform.

    Uses sklearn's ``validate_data`` with ``reset=False`` so the feature
    count recorded at fit time is enforced, then checks that each of
    ``fitted_props`` exists on the estimator. Returns the validated ``X``.
    """
    from sklearn.utils.validation import check_is_fitted

    X = validate_data(self, X, reset=False, **check_params)
    check_is_fitted(self, *fitted_props)
    return X
14 changes: 13 additions & 1 deletion tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,15 @@
pytest.param(KDTreeTransformer),
]

# Known-failing sklearn common checks per estimator, mapping check name to
# the reason it is expected to fail; passed to check_estimator below.
PER_ESTIMATOR_XFAIL_CHECKS = {
    AnnoyTransformer: {"check_estimators_pickle": "Cannot pickle AnnoyIndex"},
    FAISSTransformer: {
        "check_estimators_pickle": "Cannot pickle FAISS index",
        "check_methods_subset_invariance": "Unable to reset FAISS internal RNG",
    },
    NMSlibTransformer: {"check_estimators_pickle": "Cannot pickle NMSLib index"},
}


def add_mark(param, mark):
    """Return a copy of *param* with *mark* appended to its existing marks."""
    combined_marks = [*param.marks, mark]
    return pytest.param(*param.values, marks=combined_marks, id=param.id)
Expand All @@ -51,7 +60,10 @@ def add_mark(param, mark):
],
)
def test_all_estimators(Estimator):
    """Run sklearn's common estimator checks, tolerating known failures.

    Checks listed in ``PER_ESTIMATOR_XFAIL_CHECKS`` for this estimator are
    reported as expected failures rather than errors.
    """
    check_estimator(
        Estimator(),
        expected_failed_checks=PER_ESTIMATOR_XFAIL_CHECKS.get(Estimator, {}),
    )


# The following criteria are from:
Expand Down
Loading