From 63b2100b9be43924d52bbd352913fc6b6f3ef3a7 Mon Sep 17 00:00:00 2001 From: vanderplas Date: Mon, 8 Apr 2024 11:28:38 -0700 Subject: [PATCH] Replace deprecated jax.random.shuffle with jax.random.permutation jax.random.shuffle has long been deprecated, because it cannot operate in-place like np.random.shuffle, and because its functionality can be performed with jax.random.permutation (with independent=True in the case of multi-dimensional arrays). PiperOrigin-RevId: 622905545 --- .../python/internal/backend/numpy/random_generators.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow_probability/python/internal/backend/numpy/random_generators.py b/tensorflow_probability/python/internal/backend/numpy/random_generators.py index 2e6807b34c..b14238caa1 100644 --- a/tensorflow_probability/python/internal/backend/numpy/random_generators.py +++ b/tensorflow_probability/python/internal/backend/numpy/random_generators.py @@ -219,7 +219,7 @@ def _shuffle_jax(value, seed=None, name=None): # pylint: disable=unused-argumen import jax.random as jaxrand # pylint: disable=g-import-not-at-top if seed is None: raise ValueError('Must provide PRNGKey to sample in JAX.') - return jaxrand.shuffle(seed, value, axis=0) + return jaxrand.permutation(seed, value, axis=0, independent=True) def _truncated_normal(