Skip to content

Commit

Permalink
Fix use of deprecated validate_data
Browse files Browse the repository at this point in the history
  • Loading branch information
flying-sheep committed Jan 13, 2025
1 parent 6a8f60d commit bac08ef
Show file tree
Hide file tree
Showing 5 changed files with 10 additions and 5 deletions.
3 changes: 2 additions & 1 deletion src/sklearn_ann/cluster/rnn_dbscan.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from sklearn.base import BaseEstimator, ClusterMixin
from sklearn.neighbors import KNeighborsTransformer
from sklearn.utils import Tags
from sklearn.utils.validation import validate_data

from ..utils import get_sparse_row

Expand Down Expand Up @@ -145,7 +146,7 @@ def __init__(
self.keep_knns = keep_knns

def fit(self, X, y=None):
X = self._validate_data(X, accept_sparse="csr")
X = validate_data(self, X, accept_sparse="csr")
if self.input_guarantee == "none":
algorithm = KNeighborsTransformer(n_neighbors=self.n_neighbors)
X = algorithm.fit_transform(X)
Expand Down
3 changes: 2 additions & 1 deletion src/sklearn_ann/kneighbors/annoy.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from scipy.sparse import csr_matrix
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.utils import Tags, TargetTags, TransformerTags
from sklearn.utils.validation import validate_data

from ..utils import TransformerChecksMixin

Expand All @@ -17,7 +18,7 @@ def __init__(self, n_neighbors=5, *, metric="euclidean", n_trees=10, search_k=-1
self.metric = metric

def fit(self, X, y=None):
X = self._validate_data(X)
X = validate_data(self, X)
self.n_samples_fit_ = X.shape[0]
metric = self.metric if self.metric != "sqeuclidean" else "euclidean"
self.annoy_ = annoy.AnnoyIndex(X.shape[1], metric=metric)
Expand Down
3 changes: 2 additions & 1 deletion src/sklearn_ann/kneighbors/faiss.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from scipy.sparse import csr_matrix
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.utils import Tags, TargetTags, TransformerTags
from sklearn.utils.validation import validate_data

from ..utils import TransformerChecksMixin, postprocess_knn_csr

Expand Down Expand Up @@ -86,7 +87,7 @@ def _metric_info(self):

def fit(self, X, y=None):
normalize = self._metric_info.get("normalize", False)
X = self._validate_data(X, dtype=np.float32, copy=normalize)
X = validate_data(self, X, dtype=np.float32, copy=normalize)
self.n_samples_fit_ = X.shape[0]
if self.n_jobs == -1:
n_jobs = cpu_count()
Expand Down
3 changes: 2 additions & 1 deletion src/sklearn_ann/kneighbors/nmslib.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from scipy.sparse import csr_matrix
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.utils import Tags, TransformerTags
from sklearn.utils.validation import validate_data

from ..utils import TransformerChecksMixin, check_metric

Expand All @@ -29,7 +30,7 @@ def __init__(
self.n_jobs = n_jobs

def fit(self, X, y=None):
X = self._validate_data(X)
X = validate_data(self, X)
self.n_samples_fit_ = X.shape[0]

check_metric(self.metric, METRIC_MAP)
Expand Down
3 changes: 2 additions & 1 deletion src/sklearn_ann/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import numpy as np
from scipy.sparse import csr_matrix
from sklearn.utils.validation import validate_data


def check_metric(metric, metrics):
Expand Down Expand Up @@ -90,6 +91,6 @@ class TransformerChecksMixin:
def _transform_checks(self, X, *fitted_props, **check_params):
from sklearn.utils.validation import check_is_fitted

X = self._validate_data(X, reset=False, **check_params)
X = validate_data(self, X, reset=False, **check_params)
check_is_fitted(self, *fitted_props)
return X

0 comments on commit bac08ef

Please sign in to comment.