diff --git a/tensorflow_probability/python/internal/backend/numpy/random_generators.py b/tensorflow_probability/python/internal/backend/numpy/random_generators.py index 2e6807b34c..b14238caa1 100644 --- a/tensorflow_probability/python/internal/backend/numpy/random_generators.py +++ b/tensorflow_probability/python/internal/backend/numpy/random_generators.py @@ -219,7 +219,7 @@ def _shuffle_jax(value, seed=None, name=None): # pylint: disable=unused-argumen import jax.random as jaxrand # pylint: disable=g-import-not-at-top if seed is None: raise ValueError('Must provide PRNGKey to sample in JAX.') - return jaxrand.shuffle(seed, value, axis=0) + return jaxrand.permutation(seed, value, axis=0, independent=True) def _truncated_normal(