Skip to content

Commit

Permalink
Sklearn 1.6 compat
Browse files Browse the repository at this point in the history
  • Loading branch information
flying-sheep committed Dec 17, 2024
1 parent 8f2c8ec commit 32275a7
Show file tree
Hide file tree
Showing 5 changed files with 44 additions and 32 deletions.
21 changes: 10 additions & 11 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ urls.Documentation = "https://sklearn-ann.readthedocs.io/"
dynamic = ["version", "readme"]
requires-python = "<3.13,>=3.9" # enforced by scipy
dependencies = [
"scikit-learn>=0.24.0",
"scikit-learn>=1.6.0",
"scipy>=1.11.1,<2.0.0",
]

Expand All @@ -23,7 +23,7 @@ tests = [
docs = [
"sphinx>=7",
"sphinx-gallery>=0.8.2",
"sphinx-book-theme>=1.1.0rc1",
"sphinx-book-theme>=1.1.0",
"sphinx-issues>=1.2.0",
"numpydoc>=1.1.0",
"matplotlib>=3.3.3",
Expand All @@ -37,6 +37,7 @@ faiss = [
]
pynndescent = [
"pynndescent>=0.5.1,<1.0.0",
"numba>=0.52",
]
nmslib = [
"nmslib>=2.1.1,<3.0.0 ; python_version < '3.11'",
Expand Down Expand Up @@ -84,16 +85,14 @@ ignore = [
[tool.ruff.lint.isort]
known-first-party = ["sklearn_ann"]

[tool.hatch.envs.default]
features = [
"tests",
"docs",
"annlibs",
]
[tool.hatch.envs.docs]
installer = "uv"
features = ["docs", "annlibs"]
scripts.build = "sphinx-build -M html docs docs/_build"

[tool.hatch.envs.default.scripts]
test = "pytest {args:tests}"
build-docs = "sphinx-build -M html docs docs/_build"
[tool.hatch.envs.hatch-test]
default-args = []
features = ["tests", "annlibs"]

[tool.hatch.build.targets.wheel]
packages = ["src/sklearn_ann"]
Expand Down
12 changes: 7 additions & 5 deletions src/sklearn_ann/kneighbors/annoy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import numpy as np
from scipy.sparse import csr_matrix
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.utils import Tags, TargetTags, TransformerTags

from ..utils import TransformerChecksMixin

Expand Down Expand Up @@ -68,8 +69,9 @@ def _transform(self, X):

return kneighbors_graph

def _more_tags(self):
return {
"_xfail_checks": {"check_estimators_pickle": "Cannot pickle AnnoyIndex"},
"requires_y": False,
}
def __sklearn_tags__(self) -> Tags:
return Tags(
estimator_type="transformer",
target_tags=TargetTags(required=False),
transformer_tags=TransformerTags(),
)
18 changes: 8 additions & 10 deletions src/sklearn_ann/kneighbors/faiss.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from joblib import cpu_count
from scipy.sparse import csr_matrix
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.utils import Tags, TargetTags, TransformerTags

from ..utils import TransformerChecksMixin, postprocess_knn_csr

Expand Down Expand Up @@ -157,14 +158,11 @@ def _transform(self, X):
def fit_transform(self, X, y=None):
return self.fit(X, y=y)._transform(X=None)

def _more_tags(self):
return {
"_xfail_checks": {
"check_estimators_pickle": "Cannot pickle FAISS index",
"check_methods_subset_invariance": "Unable to reset FAISS internal RNG",
},
"requires_y": False,
"preserves_dtype": [np.float32],
def __sklearn_tags__(self) -> Tags:
return Tags(
estimator_type="transformer",
target_tags=TargetTags(required=False),
transformer_tags=TransformerTags(preserves_dtype=[np.float32]),
# Could be made deterministic *if* we could reset FAISS's internal RNG
"non_deterministic": True,
}
non_deterministic=True,
)
11 changes: 6 additions & 5 deletions src/sklearn_ann/kneighbors/nmslib.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import numpy as np
from scipy.sparse import csr_matrix
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.utils import Tags, TransformerTags

from ..utils import TransformerChecksMixin, check_metric

Expand Down Expand Up @@ -62,8 +63,8 @@ def transform(self, X):

return kneighbors_graph

def _more_tags(self):
return {
"_xfail_checks": {"check_estimators_pickle": "Cannot pickle NMSLib index"},
"preserves_dtype": [np.float32],
}
def __sklearn_tags__(self) -> Tags:
return Tags(
estimator_type="transformer",
transformer_tags=TransformerTags(preserves_dtype=[np.float32]),
)
14 changes: 13 additions & 1 deletion tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,15 @@
pytest.param(KDTreeTransformer),
]

PER_ESTIMATOR_XFAIL_CHECKS = {
AnnoyTransformer: dict(check_estimators_pickle="Cannot pickle AnnoyIndex"),
FAISSTransformer: dict(
check_estimators_pickle="Cannot pickle FAISS index",
check_methods_subset_invariance="Unable to reset FAISS internal RNG",
),
NMSlibTransformer: dict(check_estimators_pickle="Cannot pickle NMSLib index"),
}


def add_mark(param, mark):
return pytest.param(*param.values, marks=[*param.marks, mark], id=param.id)
Expand All @@ -51,7 +60,10 @@ def add_mark(param, mark):
],
)
def test_all_estimators(Estimator):
check_estimator(Estimator())
check_estimator(
Estimator(),
expected_failed_checks=PER_ESTIMATOR_XFAIL_CHECKS.get(Estimator, {}),
)


# The following critera are from:
Expand Down

0 comments on commit 32275a7

Please sign in to comment.