Skip to content

Commit

Permalink
Fix NMSLib transformer
Browse files Browse the repository at this point in the history
  • Loading branch information
flying-sheep committed Jan 13, 2025
1 parent 9fef4ac commit cf156f6
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion src/sklearn_ann/kneighbors/nmslib.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import numpy as np
from scipy.sparse import csr_matrix
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.utils import Tags, TransformerTags
from sklearn.utils import Tags, TargetTags, TransformerTags
from sklearn.utils.validation import validate_data

from ..utils import TransformerChecksMixin, check_metric
Expand Down Expand Up @@ -67,5 +67,6 @@ def transform(self, X):
def __sklearn_tags__(self) -> Tags:
return Tags(
estimator_type="transformer",
target_tags=TargetTags(required=False),
transformer_tags=TransformerTags(preserves_dtype=[np.float32]),
)

0 comments on commit cf156f6

Please sign in to comment.