FEA Add support for accepting a Numpy RandomState #6150

Open
wants to merge 10 commits into base: branch-25.02
8 changes: 7 additions & 1 deletion python/cuml/cuml/cluster/kmeans.pyx
@@ -21,6 +21,7 @@ np = cpu_only_import('numpy')
from cuml.internals.safe_imports import gpu_only_import
rmm = gpu_only_import('rmm')
from cuml.internals.safe_imports import safe_import_from, return_false
from cuml.internals.utils import check_random_seed
import typing

IF GPUBUILD == 1:
@@ -206,7 +207,10 @@ class KMeans(UniversalBase,
params.max_iter = <int>self.max_iter
params.tol = <double>self.tol
params.verbosity = <int>self.verbose
params.rng_state.seed = self.random_state
# After transferring from one device to another `_seed` might not be set
# so we need to pass a dummy value here. Its value does not matter as the
# seed is only used during fitting
params.rng_state.seed = getattr(self, "_seed", 0)
params.metric = DistanceType.L2Expanded # distance metric as squared L2: @todo - support other metrics # noqa: E501
params.batch_samples = <int>self.max_samples_per_batch
params.oversampling_factor = <double>self.oversampling_factor
@@ -302,6 +306,8 @@ class KMeans(UniversalBase,
else None),
check_dtype=check_dtype)

self._seed = check_random_seed(self.random_state)

IF GPUBUILD == 1:

cdef uintptr_t input_ptr = _X_m.ptr
3 changes: 3 additions & 0 deletions python/cuml/cuml/cluster/kmeans_mg.pyx
@@ -32,6 +32,7 @@ from cuml.common import input_to_cuml_array

from cuml.cluster import KMeans
from cuml.cluster.kmeans_utils cimport params as KMeansParams
from cuml.internals.utils import check_random_seed


cdef extern from "cuml/cluster/kmeans_mg.hpp" \
@@ -129,6 +130,8 @@ class KMeansMG(KMeans):

cdef uintptr_t sample_weight_ptr = sample_weight_m.ptr

self._seed = check_random_seed(self.random_state)

if (self.init in ['scalable-k-means++', 'k-means||', 'random']):
self.cluster_centers_ = CumlArray.zeros(shape=(self.n_clusters,
self.n_cols),
8 changes: 2 additions & 6 deletions python/cuml/cuml/decomposition/pca.pyx
@@ -209,9 +209,6 @@ class PCA(UniversalBase,

``n_components = min(n_samples, n_features)``

random_state : int / None (default = None)
If you want results to be the same when you restart Python, select a
state.
svd_solver : 'full' or 'jacobi' or 'auto' (default = 'full')
Full uses a eigendecomposition of the covariance matrix then discards
components.
@@ -292,7 +289,7 @@

@device_interop_preparation
def __init__(self, *, copy=True, handle=None, iterated_power=15,
n_components=None, random_state=None, svd_solver='auto',
n_components=None, svd_solver='auto',
tol=1e-7, verbose=False, whiten=False,
output_type=None):
# parameters
@@ -302,7 +299,6 @@
self.copy = copy
self.iterated_power = iterated_power
self.n_components = n_components
self.random_state = random_state
self.svd_solver = svd_solver
self.tol = tol
self.whiten = whiten
@@ -739,7 +735,7 @@
def _get_param_names(cls):
return super()._get_param_names() + \
["copy", "iterated_power", "n_components", "svd_solver", "tol",
"whiten", "random_state"]
"whiten"]

def _check_is_fitted(self, attr):
if not hasattr(self, attr) or (getattr(self, attr) is None):
3 changes: 2 additions & 1 deletion python/cuml/cuml/ensemble/randomforestclassifier.pyx
@@ -33,6 +33,7 @@ import cuml.internals
from cuml.common.doc_utils import generate_docstring
from cuml.common.doc_utils import insert_into_docstring
from cuml.common import input_to_cuml_array
from cuml.internals.utils import check_random_seed

from cuml.ensemble.randomforest_common import BaseRandomForestModel
from cuml.ensemble.randomforest_common import _obtain_fil_model
@@ -450,7 +451,7 @@ class RandomForestClassifier(BaseRandomForestModel,
if self.random_state is None:
seed_val = <uintptr_t>NULL
else:
seed_val = <uintptr_t>self.random_state
seed_val = <uintptr_t>check_random_seed(self.random_state)

rf_params = set_rf_params(<int> self.max_depth,
<int> self.max_leaves,
3 changes: 2 additions & 1 deletion python/cuml/cuml/ensemble/randomforestregressor.pyx
@@ -33,6 +33,7 @@ from cuml.internals.mixins import RegressorMixin
from cuml.common.doc_utils import generate_docstring
from cuml.common.doc_utils import insert_into_docstring
from cuml.common import input_to_cuml_array
from cuml.internals.utils import check_random_seed

from cuml.ensemble.randomforest_common import BaseRandomForestModel
from cuml.ensemble.randomforest_common import _obtain_fil_model
@@ -437,7 +438,7 @@ class RandomForestRegressor(BaseRandomForestModel,
if self.random_state is None:
seed_val = <uintptr_t>NULL
else:
seed_val = <uintptr_t>self.random_state
seed_val = <uintptr_t>check_random_seed(self.random_state)

rf_params = set_rf_params(<int> self.max_depth,
<int> self.max_leaves,
39 changes: 39 additions & 0 deletions python/cuml/cuml/internals/utils.py
@@ -0,0 +1,39 @@
#
# Copyright (c) 2024, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import numbers
import numpy as np


def check_random_seed(seed):
"""Turn a seed or random state specification into an integer seed.

Parameters
----------
seed : None | int | instance of RandomState
If seed is None, return a freshly generated random int as the seed.
If seed is an int, return it.
If seed is a RandomState instance, derive a seed from it.
Otherwise raise ValueError.
"""
if seed is None:
seed = np.random.RandomState(None)

if isinstance(seed, numbers.Integral):
return seed
if isinstance(seed, np.random.RandomState):
return seed.randint(
low=0, high=np.iinfo(np.uint32).max, dtype=np.uint32
)
raise ValueError("%r cannot be used to create a seed." % seed)
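
For reference, a short usage sketch (not part of the diff) showing how this helper resolves each accepted input:

import numpy as np

from cuml.internals.utils import check_random_seed

check_random_seed(123)                        # ints pass through unchanged -> 123
check_random_seed(None)                       # fresh entropy -> some uint32 value
check_random_seed(np.random.RandomState(42))  # draws a uint32 from the given state
# check_random_seed(object())                 # anything else raises ValueError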
3 changes: 2 additions & 1 deletion python/cuml/cuml/manifold/t_sne.pyx
@@ -32,6 +32,7 @@ from pylibraft.common.handle cimport handle_t
from cuml.internals.api_decorators import device_interop_preparation
from cuml.internals.api_decorators import enable_device_interop
import cuml.internals.logger as logger
from cuml.internals.utils import check_random_seed


from cuml.internals.array import CumlArray
@@ -595,7 +596,7 @@ class TSNE(UniversalBase,
def _build_tsne_params(self, algo):
cdef long long seed = -1
if self.random_state is not None:
seed = self.random_state
seed = check_random_seed(self.random_state)

cdef TSNEParams* params = new TSNEParams()
params.dim = <int> self.n_components
99 changes: 42 additions & 57 deletions python/cuml/cuml/manifold/umap.pyx
@@ -46,6 +46,7 @@ from cuml.internals.array_sparse import SparseCumlArray
from cuml.internals.mem_type import MemoryType
from cuml.internals.mixins import CMajorInputTagMixin, SparseInputTagMixin
from cuml.common.sparse_utils import is_sparse
from cuml.internals.utils import check_random_seed

from cuml.common.array_descriptor import CumlArrayDescriptor
from cuml.internals.api_decorators import device_interop_preparation
@@ -400,22 +401,7 @@ class UMAP(UniversalBase,

self.deterministic = random_state is not None

# Check to see if we are already a random_state (type==np.uint64).
# Reuse this if already passed (can happen from get_params() of another
# instance)
if isinstance(random_state, np.uint64):
self.random_state = random_state
else:
# Otherwise create a RandomState instance to generate a new
# np.uint64
if isinstance(random_state, np.random.RandomState):
rs = random_state
else:
rs = np.random.RandomState(random_state)

self.random_state = rs.randint(low=0,
high=np.iinfo(np.uint32).max,
dtype=np.uint32)
self.random_state = random_state

if target_metric == "euclidean" or target_metric == "categorical":
self.target_metric = target_metric
@@ -455,77 +441,76 @@
if self.min_dist > self.spread:
raise ValueError("min_dist should be <= spread")

@staticmethod
def _build_umap_params(cls, sparse):
def _build_umap_params(self, sparse):
IF GPUBUILD == 1:
cdef UMAPParams* umap_params = new UMAPParams()
umap_params.n_neighbors = <int> cls.n_neighbors
umap_params.n_components = <int> cls.n_components
umap_params.n_epochs = <int> cls.n_epochs if cls.n_epochs else 0
umap_params.learning_rate = <float> cls.learning_rate
umap_params.min_dist = <float> cls.min_dist
umap_params.spread = <float> cls.spread
umap_params.set_op_mix_ratio = <float> cls.set_op_mix_ratio
umap_params.local_connectivity = <float> cls.local_connectivity
umap_params.repulsion_strength = <float> cls.repulsion_strength
umap_params.negative_sample_rate = <int> cls.negative_sample_rate
umap_params.transform_queue_size = <int> cls.transform_queue_size
umap_params.verbosity = <int> cls.verbose
umap_params.a = <float> cls.a
umap_params.b = <float> cls.b
if cls.init == "spectral":
umap_params.n_neighbors = <int> self.n_neighbors
umap_params.n_components = <int> self.n_components
umap_params.n_epochs = <int> self.n_epochs if self.n_epochs else 0
umap_params.learning_rate = <float> self.learning_rate
umap_params.min_dist = <float> self.min_dist
umap_params.spread = <float> self.spread
umap_params.set_op_mix_ratio = <float> self.set_op_mix_ratio
umap_params.local_connectivity = <float> self.local_connectivity
umap_params.repulsion_strength = <float> self.repulsion_strength
umap_params.negative_sample_rate = <int> self.negative_sample_rate
umap_params.transform_queue_size = <int> self.transform_queue_size
umap_params.verbosity = <int> self.verbose
umap_params.a = <float> self.a
umap_params.b = <float> self.b
if self.init == "spectral":
umap_params.init = <int> 1
else: # self.init == "random"
umap_params.init = <int> 0
umap_params.target_n_neighbors = <int> cls.target_n_neighbors
if cls.target_metric == "euclidean":
umap_params.target_n_neighbors = <int> self.target_n_neighbors
if self.target_metric == "euclidean":
umap_params.target_metric = MetricType.EUCLIDEAN
else: # self.target_metric == "categorical"
umap_params.target_metric = MetricType.CATEGORICAL
if cls.build_algo == "brute_force_knn":
if self.build_algo == "brute_force_knn":
umap_params.build_algo = graph_build_algo.BRUTE_FORCE_KNN
else: # self.init == "nn_descent"
umap_params.build_algo = graph_build_algo.NN_DESCENT
if cls.build_kwds is None:
if self.build_kwds is None:
umap_params.nn_descent_params.graph_degree = <uint64_t> 64
umap_params.nn_descent_params.intermediate_graph_degree = <uint64_t> 128
umap_params.nn_descent_params.max_iterations = <uint64_t> 20
umap_params.nn_descent_params.termination_threshold = <float> 0.0001
umap_params.nn_descent_params.return_distances = <bool> True
umap_params.nn_descent_params.n_clusters = <uint64_t> 1
else:
umap_params.nn_descent_params.graph_degree = <uint64_t> cls.build_kwds.get("nnd_graph_degree", 64)
umap_params.nn_descent_params.intermediate_graph_degree = <uint64_t> cls.build_kwds.get("nnd_intermediate_graph_degree", 128)
umap_params.nn_descent_params.max_iterations = <uint64_t> cls.build_kwds.get("nnd_max_iterations", 20)
umap_params.nn_descent_params.termination_threshold = <float> cls.build_kwds.get("nnd_termination_threshold", 0.0001)
umap_params.nn_descent_params.return_distances = <bool> cls.build_kwds.get("nnd_return_distances", True)
if cls.build_kwds.get("nnd_n_clusters", 1) < 1:
umap_params.nn_descent_params.graph_degree = <uint64_t> self.build_kwds.get("nnd_graph_degree", 64)
umap_params.nn_descent_params.intermediate_graph_degree = <uint64_t> self.build_kwds.get("nnd_intermediate_graph_degree", 128)
umap_params.nn_descent_params.max_iterations = <uint64_t> self.build_kwds.get("nnd_max_iterations", 20)
umap_params.nn_descent_params.termination_threshold = <float> self.build_kwds.get("nnd_termination_threshold", 0.0001)
umap_params.nn_descent_params.return_distances = <bool> self.build_kwds.get("nnd_return_distances", True)
if self.build_kwds.get("nnd_n_clusters", 1) < 1:
logger.info("Negative number of nnd_n_clusters not allowed. Changing nnd_n_clusters to 1")
umap_params.nn_descent_params.n_clusters = <uint64_t> cls.build_kwds.get("nnd_n_clusters", 1)
umap_params.nn_descent_params.n_clusters = <uint64_t> self.build_kwds.get("nnd_n_clusters", 1)

umap_params.target_weight = <float> cls.target_weight
umap_params.random_state = <uint64_t> cls.random_state
umap_params.deterministic = <bool> cls.deterministic
umap_params.target_weight = <float> self.target_weight
umap_params.random_state = <uint64_t> check_random_seed(self.random_state)
umap_params.deterministic = <bool> self.deterministic

try:
umap_params.metric = metric_parsing[cls.metric.lower()]
umap_params.metric = metric_parsing[self.metric.lower()]
if sparse:
if umap_params.metric not in SPARSE_SUPPORTED_METRICS:
raise NotImplementedError(f"Metric '{cls.metric}' not supported for sparse inputs.")
raise NotImplementedError(f"Metric '{self.metric}' not supported for sparse inputs.")
elif umap_params.metric not in DENSE_SUPPORTED_METRICS:
raise NotImplementedError(f"Metric '{cls.metric}' not supported for dense inputs.")
raise NotImplementedError(f"Metric '{self.metric}' not supported for dense inputs.")

except KeyError:
raise ValueError(f"Invalid value for metric: {cls.metric}")
raise ValueError(f"Invalid value for metric: {self.metric}")

if cls.metric_kwds is None:
if self.metric_kwds is None:
umap_params.p = <float> 2.0
else:
umap_params.p = <float>cls.metric_kwds.get('p')
umap_params.p = <float>self.metric_kwds.get('p')

cdef uintptr_t callback_ptr = 0
if cls.callback:
callback_ptr = cls.callback.get_native_callback()
if self.callback:
callback_ptr = self.callback.get_native_callback()
umap_params.callback = <GraphBasedDimRedCallback*>callback_ptr

return <size_t>umap_params
@@ -658,7 +643,7 @@ class UMAP(UniversalBase,
<handle_t*> <size_t> self.handle.getHandle()
fss_graph = GraphHolder.new_graph(handle_.get_stream())
cdef UMAPParams* umap_params = \
<UMAPParams*> <size_t> UMAP._build_umap_params(self,
<UMAPParams*> <size_t> self._build_umap_params(
self.sparse_fit)
if self.sparse_fit:
fit_sparse(handle_[0],
@@ -817,7 +802,7 @@ class UMAP(UniversalBase,

IF GPUBUILD == 1:
cdef UMAPParams* umap_params = \
<UMAPParams*> <size_t> UMAP._build_umap_params(self,
<UMAPParams*> <size_t> self._build_umap_params(
self.sparse_fit)
cdef handle_t * handle_ = \
<handle_t*> <size_t> self.handle.getHandle()
42 changes: 42 additions & 0 deletions python/cuml/cuml/tests/test_common.py
@@ -0,0 +1,42 @@
#
# Copyright (c) 2024, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import pytest

import numpy as np
import cuml
from cuml.datasets import make_blobs


@pytest.mark.parametrize(
"Estimator",
[
cuml.KMeans,
cuml.RandomForestRegressor,
cuml.RandomForestClassifier,
cuml.TSNE,
cuml.UMAP,
],
)
def test_random_state_argument(Estimator):
Member:
Should we add a quick test here that the results are the same with the seed, or is that tested in the individual algo tests?

Member Author:
I don't think the results will be the same because RandomState(42) will not lead to 42 being passed as the seed to the internal functions that cuml calls.

We can't pass any form of "RNG state" to the internal functions; we can only pass an integer. So I think the best we can do when a RandomState is passed in is to use it to generate a uint64 and use that as the seed for the internal functions. I think this is better than trying to extract the (original) seed from the RandomState, because that way you get a different value if the random state has been used previously.

For example, in this (contrived) snippet I think the two RFs should not both use 42 as the seed internally, as they are two separate instances.

# Both estimators share the same RandomState instance.
rs = np.random.RandomState(42)

rf1 = cuml.RandomForestClassifier(random_state=rs)
rf2 = cuml.RandomForestClassifier(random_state=rs)
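
To make that concrete, here is a small illustrative sketch (not part of the diff) using the check_random_seed helper added in this PR:

import numpy as np

from cuml.internals.utils import check_random_seed

rs = np.random.RandomState(42)

# Each call advances the shared RandomState, so consecutive calls
# almost certainly return different uint32 seeds for the two estimators.
seed_for_rf1 = check_random_seed(rs)
seed_for_rf2 = check_random_seed(rs)
print(seed_for_rf1, seed_for_rf2)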

X, y = make_blobs(random_state=0)
# Check that both integer and np.random.RandomState are accepted
for seed in (42, np.random.RandomState(42)):
est = Estimator(random_state=seed)

if est.__class__.__name__ != "TSNE":
est.fit(X, y)
else:
est.fit(X)
4 changes: 2 additions & 2 deletions python/cuml/cuml/tests/test_device_selection.py
@@ -977,11 +977,11 @@ def test_hdbscan_methods(train_device, infer_device):
@pytest.mark.parametrize("infer_device", ["cpu", "gpu"])
def test_kmeans_methods(train_device, infer_device):
n_clusters = 20
ref_model = skKMeans(n_clusters=n_clusters)
ref_model = skKMeans(n_clusters=n_clusters, random_state=42)
ref_model.fit(X_train_blob)
ref_output = ref_model.predict(X_test_blob)

model = KMeans(n_clusters=n_clusters)
model = KMeans(n_clusters=n_clusters, random_state=42)
with using_device_type(train_device):
model.fit(X_train_blob)
with using_device_type(infer_device):