From 0342c2305c020b38af98164d6b2610f0a331aeb3 Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Mon, 13 Jan 2025 16:39:39 +0100 Subject: [PATCH] Run all tests and fix nmslib (#73) --- pyproject.toml | 2 +- src/sklearn_ann/cluster/tests/__init__.py | 0 src/sklearn_ann/kneighbors/nmslib.py | 3 ++- .../cluster/tests/test_common.py => tests/test_cluster.py | 0 tests/{ => test_kneighbors}/conftest.py | 0 tests/{ => test_kneighbors}/test_annoy.py | 0 tests/{ => test_kneighbors}/test_common.py | 0 tests/{ => test_kneighbors}/test_faiss.py | 0 tests/{ => test_kneighbors}/test_nmslib.py | 0 9 files changed, 3 insertions(+), 2 deletions(-) delete mode 100644 src/sklearn_ann/cluster/tests/__init__.py rename src/sklearn_ann/cluster/tests/test_common.py => tests/test_cluster.py (100%) rename tests/{ => test_kneighbors}/conftest.py (100%) rename tests/{ => test_kneighbors}/test_annoy.py (100%) rename tests/{ => test_kneighbors}/test_common.py (100%) rename tests/{ => test_kneighbors}/test_faiss.py (100%) rename tests/{ => test_kneighbors}/test_nmslib.py (100%) diff --git a/pyproject.toml b/pyproject.toml index 37086a7..9cdfaa9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,7 +40,7 @@ pynndescent = [ "numba>=0.52", ] nmslib = [ - "nmslib>=2.1.1,<3.0.0 ; python_version < '3.11'", + "nmslib-metabrainz>=2.1.1,<3.0.0", ] annlibs = [ "sklearn-ann[annoy,faiss,pynndescent,nmslib]", diff --git a/src/sklearn_ann/cluster/tests/__init__.py b/src/sklearn_ann/cluster/tests/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/src/sklearn_ann/kneighbors/nmslib.py b/src/sklearn_ann/kneighbors/nmslib.py index c10fb43..4736947 100644 --- a/src/sklearn_ann/kneighbors/nmslib.py +++ b/src/sklearn_ann/kneighbors/nmslib.py @@ -2,7 +2,7 @@ import numpy as np from scipy.sparse import csr_matrix from sklearn.base import BaseEstimator, TransformerMixin -from sklearn.utils import Tags, TransformerTags +from sklearn.utils import Tags, TargetTags, TransformerTags from sklearn.utils.validation import validate_data from ..utils import TransformerChecksMixin, check_metric @@ -67,5 +67,6 @@ def transform(self, X): def __sklearn_tags__(self) -> Tags: return Tags( estimator_type="transformer", + target_tags=TargetTags(required=False), transformer_tags=TransformerTags(preserves_dtype=[np.float32]), ) diff --git a/src/sklearn_ann/cluster/tests/test_common.py b/tests/test_cluster.py similarity index 100% rename from src/sklearn_ann/cluster/tests/test_common.py rename to tests/test_cluster.py diff --git a/tests/conftest.py b/tests/test_kneighbors/conftest.py similarity index 100% rename from tests/conftest.py rename to tests/test_kneighbors/conftest.py diff --git a/tests/test_annoy.py b/tests/test_kneighbors/test_annoy.py similarity index 100% rename from tests/test_annoy.py rename to tests/test_kneighbors/test_annoy.py diff --git a/tests/test_common.py b/tests/test_kneighbors/test_common.py similarity index 100% rename from tests/test_common.py rename to tests/test_kneighbors/test_common.py diff --git a/tests/test_faiss.py b/tests/test_kneighbors/test_faiss.py similarity index 100% rename from tests/test_faiss.py rename to tests/test_kneighbors/test_faiss.py diff --git a/tests/test_nmslib.py b/tests/test_kneighbors/test_nmslib.py similarity index 100% rename from tests/test_nmslib.py rename to tests/test_kneighbors/test_nmslib.py