Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Replace deprecated jax.random.shuffle with jax.random.permutation
jax.random.shuffle has long been deprecated, because it cannot operate in-place like np.random.shuffle, and because its functionality can be performed with jax.random.permutation (with independent=True in the case of multi-dimensional arrays). PiperOrigin-RevId: 622905545