Skip to content

Commit

Permalink
Sklearn 1.6 compat (#70)
Browse files Browse the repository at this point in the history
  • Loading branch information
flying-sheep authored Jan 13, 2025
1 parent 93c1f19 commit eaae17c
Show file tree
Hide file tree
Showing 7 changed files with 61 additions and 37 deletions.
21 changes: 10 additions & 11 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ urls.Documentation = "https://sklearn-ann.readthedocs.io/"
dynamic = ["version", "readme"]
requires-python = "<3.13,>=3.9" # enforced by scipy
dependencies = [
"scikit-learn>=0.24.0",
"scikit-learn>=1.6.0",
"scipy>=1.11.1,<2.0.0",
]

Expand All @@ -23,7 +23,7 @@ tests = [
docs = [
"sphinx>=7",
"sphinx-gallery>=0.8.2",
"sphinx-book-theme>=1.1.0rc1",
"sphinx-book-theme>=1.1.0",
"sphinx-issues>=1.2.0",
"numpydoc>=1.1.0",
"matplotlib>=3.3.3",
Expand All @@ -37,6 +37,7 @@ faiss = [
]
pynndescent = [
"pynndescent>=0.5.1,<1.0.0",
"numba>=0.52",
]
nmslib = [
"nmslib>=2.1.1,<3.0.0 ; python_version < '3.11'",
Expand Down Expand Up @@ -84,16 +85,14 @@ ignore = [
[tool.ruff.lint.isort]
known-first-party = ["sklearn_ann"]

[tool.hatch.envs.default]
features = [
"tests",
"docs",
"annlibs",
]
[tool.hatch.envs.docs]
installer = "uv"
features = ["docs", "annlibs"]
scripts.build = "sphinx-build -M html docs docs/_build"

[tool.hatch.envs.default.scripts]
test = "pytest {args:tests}"
build-docs = "sphinx-build -M html docs docs/_build"
[tool.hatch.envs.hatch-test]
default-args = []
features = ["tests", "annlibs"]

[tool.hatch.build.targets.wheel]
packages = ["src/sklearn_ann"]
Expand Down
10 changes: 9 additions & 1 deletion src/sklearn_ann/cluster/rnn_dbscan.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from collections import deque
from typing import cast

import numpy as np
from sklearn.base import BaseEstimator, ClusterMixin
from sklearn.neighbors import KNeighborsTransformer
from sklearn.utils import Tags
from sklearn.utils.validation import validate_data

from ..utils import get_sparse_row

Expand Down Expand Up @@ -143,7 +146,7 @@ def __init__(
self.keep_knns = keep_knns

def fit(self, X, y=None):
X = self._validate_data(X, accept_sparse="csr")
X = validate_data(self, X, accept_sparse="csr")
if self.input_guarantee == "none":
algorithm = KNeighborsTransformer(n_neighbors=self.n_neighbors)
X = algorithm.fit_transform(X)
Expand Down Expand Up @@ -181,6 +184,11 @@ def drop_knns(self):
del self.knns_
del self.rev_knns_

# Estimator tags hook: take the tags produced by the parent classes and
# mark sparse input as supported before returning them.
def __sklearn_tags__(self) -> Tags:
# cast() only narrows the static type of the inherited result; no runtime effect.
tags = cast(Tags, super().__sklearn_tags__())
tags.input_tags.sparse = True
return tags


def simple_rnn_dbscan_pipeline(
neighbor_transformer, n_neighbors, n_jobs=None, keep_knns=None, **kwargs
Expand Down
15 changes: 9 additions & 6 deletions src/sklearn_ann/kneighbors/annoy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import numpy as np
from scipy.sparse import csr_matrix
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.utils import Tags, TargetTags, TransformerTags
from sklearn.utils.validation import validate_data

from ..utils import TransformerChecksMixin

Expand All @@ -16,7 +18,7 @@ def __init__(self, n_neighbors=5, *, metric="euclidean", n_trees=10, search_k=-1
self.metric = metric

def fit(self, X, y=None):
X = self._validate_data(X)
X = validate_data(self, X)
self.n_samples_fit_ = X.shape[0]
metric = self.metric if self.metric != "sqeuclidean" else "euclidean"
self.annoy_ = annoy.AnnoyIndex(X.shape[1], metric=metric)
Expand Down Expand Up @@ -68,8 +70,9 @@ def _transform(self, X):

return kneighbors_graph

def _more_tags(self):
return {
"_xfail_checks": {"check_estimators_pickle": "Cannot pickle AnnoyIndex"},
"requires_y": False,
}
# Estimator tags hook: declares a transformer whose target is not required,
# with default TransformerTags.
# NOTE(review): this builds a fresh Tags object instead of extending
# super().__sklearn_tags__(), so any tags set by parent classes are not
# carried over -- confirm that is intended.
def __sklearn_tags__(self) -> Tags:
return Tags(
estimator_type="transformer",
target_tags=TargetTags(required=False),
transformer_tags=TransformerTags(),
)
21 changes: 10 additions & 11 deletions src/sklearn_ann/kneighbors/faiss.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from joblib import cpu_count
from scipy.sparse import csr_matrix
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.utils import Tags, TargetTags, TransformerTags
from sklearn.utils.validation import validate_data

from ..utils import TransformerChecksMixin, postprocess_knn_csr

Expand Down Expand Up @@ -85,7 +87,7 @@ def _metric_info(self):

def fit(self, X, y=None):
normalize = self._metric_info.get("normalize", False)
X = self._validate_data(X, dtype=np.float32, copy=normalize)
X = validate_data(self, X, dtype=np.float32, copy=normalize)
self.n_samples_fit_ = X.shape[0]
if self.n_jobs == -1:
n_jobs = cpu_count()
Expand Down Expand Up @@ -157,14 +159,11 @@ def _transform(self, X):
def fit_transform(self, X, y=None):
return self.fit(X, y=y)._transform(X=None)

def _more_tags(self):
return {
"_xfail_checks": {
"check_estimators_pickle": "Cannot pickle FAISS index",
"check_methods_subset_invariance": "Unable to reset FAISS internal RNG",
},
"requires_y": False,
"preserves_dtype": [np.float32],
def __sklearn_tags__(self) -> Tags:
return Tags(
estimator_type="transformer",
target_tags=TargetTags(required=False),
transformer_tags=TransformerTags(preserves_dtype=[np.float32]),
# Could be made deterministic *if* we could reset FAISS's internal RNG
"non_deterministic": True,
}
non_deterministic=True,
)
14 changes: 8 additions & 6 deletions src/sklearn_ann/kneighbors/nmslib.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import numpy as np
from scipy.sparse import csr_matrix
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.utils import Tags, TransformerTags
from sklearn.utils.validation import validate_data

from ..utils import TransformerChecksMixin, check_metric

Expand All @@ -28,7 +30,7 @@ def __init__(
self.n_jobs = n_jobs

def fit(self, X, y=None):
X = self._validate_data(X)
X = validate_data(self, X)
self.n_samples_fit_ = X.shape[0]

check_metric(self.metric, METRIC_MAP)
Expand Down Expand Up @@ -62,8 +64,8 @@ def transform(self, X):

return kneighbors_graph

def _more_tags(self):
return {
"_xfail_checks": {"check_estimators_pickle": "Cannot pickle NMSLib index"},
"preserves_dtype": [np.float32],
}
# Estimator tags hook: declares a transformer that preserves float32 dtype.
# NOTE(review): no target_tags argument is passed here, unlike the sibling
# Annoy/FAISS transformers in this commit, and the Tags object is built from
# scratch rather than from super().__sklearn_tags__() -- confirm that Tags
# accepts this call and that parent tags are not needed.
def __sklearn_tags__(self) -> Tags:
return Tags(
estimator_type="transformer",
transformer_tags=TransformerTags(preserves_dtype=[np.float32]),
)
3 changes: 2 additions & 1 deletion src/sklearn_ann/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import numpy as np
from scipy.sparse import csr_matrix
from sklearn.utils.validation import validate_data


def check_metric(metric, metrics):
Expand Down Expand Up @@ -90,6 +91,6 @@ class TransformerChecksMixin:
def _transform_checks(self, X, *fitted_props, **check_params):
from sklearn.utils.validation import check_is_fitted

X = self._validate_data(X, reset=False, **check_params)
X = validate_data(self, X, reset=False, **check_params)
check_is_fitted(self, *fitted_props)
return X
14 changes: 13 additions & 1 deletion tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,15 @@
pytest.param(KDTreeTransformer),
]

# Maps an estimator class to the checks that are expected to fail for it,
# as {check name: reason}. Looked up per estimator in test_all_estimators and
# passed to check_estimator via expected_failed_checks; estimators without an
# entry get an empty dict there.
PER_ESTIMATOR_XFAIL_CHECKS = {
AnnoyTransformer: dict(check_estimators_pickle="Cannot pickle AnnoyIndex"),
FAISSTransformer: dict(
check_estimators_pickle="Cannot pickle FAISS index",
check_methods_subset_invariance="Unable to reset FAISS internal RNG",
),
NMSlibTransformer: dict(check_estimators_pickle="Cannot pickle NMSLib index"),
}


# Return a new pytest.param with the same values and id as `param`, and with
# `mark` appended after its existing marks. `param` itself is not modified.
def add_mark(param, mark):
return pytest.param(*param.values, marks=[*param.marks, mark], id=param.id)
Expand All @@ -51,7 +60,10 @@ def add_mark(param, mark):
],
)
def test_all_estimators(Estimator):
check_estimator(Estimator())
check_estimator(
Estimator(),
expected_failed_checks=PER_ESTIMATOR_XFAIL_CHECKS.get(Estimator, {}),
)


# The following criteria are from:
Expand Down

0 comments on commit eaae17c

Please sign in to comment.