From 0af6f4106a033ed0f35cbe5b439a7082d4b30fdc Mon Sep 17 00:00:00 2001 From: vanderplas Date: Fri, 8 Mar 2024 09:29:44 -0800 Subject: [PATCH] tensorflow_probability: avoid JAX key reuse in tests These issues were found by running the test suite with jax_enable_key_reuse_checks=True, available since https://github.com/google/jax/pull/19795. I did find some issues in non-test code, but I plan to address those issues individually as they may have user-visible effects. This change only touches innocuous key reuse in test files. PiperOrigin-RevId: 613958607 --- .../python/bijectors/ffjord_test.py | 3 +-- .../scale_matvec_linear_operator_test.py | 7 +++--- .../python/distributions/bernoulli_test.py | 2 +- .../continuous_bernoulli_test.py | 2 +- .../python/distributions/empirical_test.py | 3 +++ .../distributions/jax_transformation_test.py | 1 + .../python/distributions/kumaraswamy_test.py | 1 + .../distributions/linear_gaussian_ssm_test.py | 3 ++- .../distributions/probit_bernoulli_test.py | 2 +- .../transformed_distribution_test.py | 1 + .../bijectors/highway_flow_test.py | 1 + .../distributions/importance_resample_test.py | 21 +++++++++-------- .../mcmc/elliptical_slice_sampler_test.py | 5 +++- ...based_trajectory_length_adaptation_test.py | 8 +++++-- .../experimental/mcmc/sample_fold_test.py | 2 ++ .../sequential_monte_carlo_kernel_test.py | 1 + .../sts_gibbs/dynamic_spike_and_slab_test.py | 1 + .../sts_gibbs/spike_and_slab_test.py | 1 + .../experimental/util/trainable_test.py | 5 +++- .../batched_rejection_sampler_test.py | 2 ++ .../python/internal/samplers_test.py | 2 ++ .../python/internal/test_util.py | 11 +++++++++ .../internal/trainable_state_util_test.py | 6 ++++- .../python/mcmc/hmc_test.py | 3 ++- .../python/mcmc/transformed_kernel_test.py | 9 ++++---- .../python/sts/components/regression_test.py | 8 +++++-- .../python/sts/structural_time_series_test.py | 11 ++++++--- .../python/vi/csiszar_divergence_test.py | 23 ++++++++++++++++--- 28 files changed, 108 insertions(+), 37 deletions(-) diff --git a/tensorflow_probability/python/bijectors/ffjord_test.py b/tensorflow_probability/python/bijectors/ffjord_test.py index c23fedab14..2e4159b140 100644 --- a/tensorflow_probability/python/bijectors/ffjord_test.py +++ b/tensorflow_probability/python/bijectors/ffjord_test.py @@ -159,7 +159,6 @@ def testJacobianDiagonalScaling(self, dtype): ) def testHutchinsonsNormalEstimator(self, dtype): - seed = test_util.test_seed() tf_dtype = tf.as_dtype(dtype) num_dims = 10 np.random.seed(seed=test_util.test_seed(sampler_type='integer')) @@ -170,7 +169,7 @@ def testHutchinsonsNormalEstimator(self, dtype): def trace_augmentation_fn(ode_fn, z_shape, dtype): return ffjord.trace_jacobian_hutchinson( - ode_fn, z_shape, dtype, num_samples=128, seed=seed) + ode_fn, z_shape, dtype, num_samples=128, seed=test_util.test_seed()) bijector = ffjord.FFJORD( trace_augmentation_fn=trace_augmentation_fn, diff --git a/tensorflow_probability/python/bijectors/scale_matvec_linear_operator_test.py b/tensorflow_probability/python/bijectors/scale_matvec_linear_operator_test.py index 824fede01f..b03b0a6594 100644 --- a/tensorflow_probability/python/bijectors/scale_matvec_linear_operator_test.py +++ b/tensorflow_probability/python/bijectors/scale_matvec_linear_operator_test.py @@ -234,15 +234,16 @@ def build_operator(self): tf.linalg.LinearOperatorIdentity(2)]], is_non_singular=True) def build_batched_operator(self): - seed = test_util.test_seed() + seed1 = test_util.test_seed() + seed2 = test_util.test_seed() return tf.linalg.LinearOperatorBlockLowerTriangular([ [tf.linalg.LinearOperatorFullMatrix( - tf.random.normal((3, 4, 4), dtype=tf.float32, seed=seed), + tf.random.normal((3, 4, 4), dtype=tf.float32, seed=seed1), is_non_singular=True)], [tf.linalg.LinearOperatorZeros( 3, 4, is_square=False, is_self_adjoint=False), tf.linalg.LinearOperatorFullMatrix( - tf.random.normal((3, 3), dtype=tf.float32, seed=seed), + tf.random.normal((3, 3), dtype=tf.float32, seed=seed2), is_non_singular=True)] ], is_non_singular=True) diff --git a/tensorflow_probability/python/distributions/bernoulli_test.py b/tensorflow_probability/python/distributions/bernoulli_test.py index ef8b3a883e..ba357b01b0 100644 --- a/tensorflow_probability/python/distributions/bernoulli_test.py +++ b/tensorflow_probability/python/distributions/bernoulli_test.py @@ -245,7 +245,7 @@ def _seed(seed=None): seed = test_util.test_seed() if seed is None else seed if tf.executing_eagerly(): tf.random.set_seed(seed) - return seed + return test_util.clone_seed(seed) seed = _seed() self.assertAllEqual( self.evaluate(dist.sample(n, _seed(seed=seed))), diff --git a/tensorflow_probability/python/distributions/continuous_bernoulli_test.py b/tensorflow_probability/python/distributions/continuous_bernoulli_test.py index fd899093c9..6ffe512b04 100644 --- a/tensorflow_probability/python/distributions/continuous_bernoulli_test.py +++ b/tensorflow_probability/python/distributions/continuous_bernoulli_test.py @@ -342,7 +342,7 @@ def _seed(seed=None): seed = test_util.test_seed() if seed is None else seed if tf.executing_eagerly(): tf1.set_random_seed(seed) - return seed + return test_util.clone_seed(seed) seed = _seed() self.assertAllClose( diff --git a/tensorflow_probability/python/distributions/empirical_test.py b/tensorflow_probability/python/distributions/empirical_test.py index dc6fe3dd3d..3c12807c8f 100644 --- a/tensorflow_probability/python/distributions/empirical_test.py +++ b/tensorflow_probability/python/distributions/empirical_test.py @@ -192,6 +192,7 @@ def testSampleN(self): if tf.executing_eagerly(): tf.random.set_seed(seed) samples1 = dist.sample(n, seed) + seed = test_util.clone_seed(seed) if tf.executing_eagerly(): tf.random.set_seed(seed) samples2 = dist.sample(n, seed) @@ -448,6 +449,7 @@ def testSampleN(self): if tf.executing_eagerly(): tf.random.set_seed(seed) samples1 = dist.sample(n, seed) + seed = test_util.clone_seed(seed) if tf.executing_eagerly(): tf.random.set_seed(seed) samples2 = dist.sample(n, seed) @@ -650,6 +652,7 @@ def testSampleN(self): if tf.executing_eagerly(): tf.random.set_seed(seed) samples1 = dist.sample(n, seed) + seed = test_util.clone_seed(seed) if tf.executing_eagerly(): tf.random.set_seed(seed) samples2 = dist.sample(n, seed) diff --git a/tensorflow_probability/python/distributions/jax_transformation_test.py b/tensorflow_probability/python/distributions/jax_transformation_test.py index 3604b4e464..d96a8876a9 100644 --- a/tensorflow_probability/python/distributions/jax_transformation_test.py +++ b/tensorflow_probability/python/distributions/jax_transformation_test.py @@ -160,6 +160,7 @@ def _sample(seed): seed = test_util.test_seed() result = jax.jit(_sample)(seed) if not FLAGS.execute_only: + seed = test_util.clone_seed(seed) self.assertAllClose(_sample(seed), result, rtol=1e-6, atol=1e-6) diff --git a/tensorflow_probability/python/distributions/kumaraswamy_test.py b/tensorflow_probability/python/distributions/kumaraswamy_test.py index 08a13321c3..3ad38d7371 100644 --- a/tensorflow_probability/python/distributions/kumaraswamy_test.py +++ b/tensorflow_probability/python/distributions/kumaraswamy_test.py @@ -291,6 +291,7 @@ def testKumaraswamySampleMultipleTimes(self): validate_args=True) samples1 = self.evaluate(dist1.sample(n_val, seed=seed)) + seed = test_util.clone_seed(seed) tf.random.set_seed(seed) dist2 = kumaraswamy.Kumaraswamy( concentration1=a_val, diff --git a/tensorflow_probability/python/distributions/linear_gaussian_ssm_test.py b/tensorflow_probability/python/distributions/linear_gaussian_ssm_test.py index c8ddf3953c..62a1d8a34d 100644 --- a/tensorflow_probability/python/distributions/linear_gaussian_ssm_test.py +++ b/tensorflow_probability/python/distributions/linear_gaussian_ssm_test.py @@ -971,11 +971,12 @@ def testInitSetsDefaultMask(self): model_with_mask.posterior_marginals(observed_time_series), model.posterior_marginals(observed_time_series, mask=observation_mask)) seed = test_util.test_seed(sampler_type='stateless') + seed2 = test_util.clone_seed(seed) self.assertAllEqual( model_with_mask.posterior_sample( observed_time_series, seed=seed), model.posterior_sample( - observed_time_series, mask=observation_mask, seed=seed)) + observed_time_series, mask=observation_mask, seed=seed2)) class MissingObservationsTestsSequential(_MissingObservationsTests): diff --git a/tensorflow_probability/python/distributions/probit_bernoulli_test.py b/tensorflow_probability/python/distributions/probit_bernoulli_test.py index 34852897ad..74c7800a2d 100644 --- a/tensorflow_probability/python/distributions/probit_bernoulli_test.py +++ b/tensorflow_probability/python/distributions/probit_bernoulli_test.py @@ -259,7 +259,7 @@ def _seed(seed=None): seed = test_util.test_seed() if seed is None else seed if tf.executing_eagerly(): tf.random.set_seed(seed) - return seed + return test_util.clone_seed(seed) seed = _seed() self.assertAllEqual( self.evaluate(dist.sample(n, seed)), diff --git a/tensorflow_probability/python/distributions/transformed_distribution_test.py b/tensorflow_probability/python/distributions/transformed_distribution_test.py index 8ea33bea6f..3076040611 100644 --- a/tensorflow_probability/python/distributions/transformed_distribution_test.py +++ b/tensorflow_probability/python/distributions/transformed_distribution_test.py @@ -321,6 +321,7 @@ def _forward_log_det_jacobian(self, x): sample, s=sigma, scale=np.exp(mu)) self.assertAllClose(expected_log_pdf, log_pdf, rtol=1e-4, atol=0.) + seed = test_util.clone_seed(seed) sample2 = self.evaluate(log_normal.sample(seed=seed)) self.assertAllClose(sample, sample2, rtol=1e-4) diff --git a/tensorflow_probability/python/experimental/bijectors/highway_flow_test.py b/tensorflow_probability/python/experimental/bijectors/highway_flow_test.py index 567f154aaa..caf4165f62 100644 --- a/tensorflow_probability/python/experimental/bijectors/highway_flow_test.py +++ b/tensorflow_probability/python/experimental/bijectors/highway_flow_test.py @@ -171,6 +171,7 @@ def testBijectorIsDeterministicGivenSeed(self): seed = test_util.test_seed(sampler_type='stateless') bijector1 = highway_flow.build_trainable_highway_flow( width, activation_fn=tf.nn.softplus, seed=seed) + seed = test_util.clone_seed(seed) bijector2 = highway_flow.build_trainable_highway_flow( width, activation_fn=tf.nn.softplus, seed=seed) self.evaluate([v.initializer for v in bijector1.trainable_variables]) diff --git a/tensorflow_probability/python/experimental/distributions/importance_resample_test.py b/tensorflow_probability/python/experimental/distributions/importance_resample_test.py index d09f05f110..a8c81d76be 100644 --- a/tensorflow_probability/python/experimental/distributions/importance_resample_test.py +++ b/tensorflow_probability/python/experimental/distributions/importance_resample_test.py @@ -111,7 +111,7 @@ def test_samples_approach_target_distribution(self): xs = self.evaluate(target.sample(num_samples, seed=seeds[0])) xs2 = self.evaluate(resampled2.sample(num_samples, seed=seeds[1])) xs20 = self.evaluate(resampled20.sample(num_samples, seed=seeds[2])) - for statistic_fn in (lambda x: x, lambda x: x**2): + for statistic_fn, seed in zip([lambda x: x, lambda x: x**2], seeds[3:]): true_statistic = tf.reduce_mean(statistic_fn(xs)) # Statistics should approach those of the target distribution as # `importance_sample_size` increases. @@ -120,7 +120,7 @@ def test_samples_approach_target_distribution(self): tf.abs(tf.reduce_mean(statistic_fn(xs20)) - true_statistic)) expectation_no_resampling = resampled2.self_normalized_expectation( - statistic_fn, importance_sample_size=10000, seed=seeds[3]) + statistic_fn, importance_sample_size=10000, seed=seed) self.assertAllClose(true_statistic, expectation_no_resampling, atol=0.15) def test_log_prob_approaches_target_distribution(self): @@ -237,32 +237,33 @@ def target_log_prob_fn(x): return prior + likelihood # Use importance sampling to infer an approximate posterior. - seed = test_util.test_seed(sampler_type='stateless') + seed = lambda: test_util.test_seed(sampler_type='stateless') approximate_posterior = importance_resample.ImportanceResample( proposal_distribution=normal.Normal(loc=0., scale=2.), target_log_prob_fn=target_log_prob_fn, importance_sample_size=3, - stochastic_approximation_seed=seed) + stochastic_approximation_seed=seed()) # Directly compute expectations under the posterior via importance weights. posterior_mean = approximate_posterior.self_normalized_expectation( - lambda x: x, seed=seed) + lambda x: x, seed=seed()) approximate_posterior.self_normalized_expectation( - lambda x: (x - posterior_mean)**2, seed=seed) + lambda x: (x - posterior_mean)**2, seed=seed()) - posterior_samples = approximate_posterior.sample(5, seed=seed) + posterior_samples = approximate_posterior.sample(5, seed=seed()) tf.reduce_mean(posterior_samples) tf.math.reduce_variance(posterior_samples) posterior_mean_efficient = ( approximate_posterior.self_normalized_expectation( - lambda x: x, sample_size=10, seed=seed)) + lambda x: x, sample_size=10, seed=seed())) approximate_posterior.self_normalized_expectation( - lambda x: (x - posterior_mean_efficient)**2, sample_size=10, seed=seed) + lambda x: ( + x - posterior_mean_efficient)**2, sample_size=10, seed=seed()) # Approximate the posterior density. xs = tf.linspace(-3., 3., 101) - approximate_posterior.prob(xs, sample_size=10, seed=seed) + approximate_posterior.prob(xs, sample_size=10, seed=seed()) def test_log_prob_independence_per_x(self): dist = importance_resample.ImportanceResample( diff --git a/tensorflow_probability/python/experimental/mcmc/elliptical_slice_sampler_test.py b/tensorflow_probability/python/experimental/mcmc/elliptical_slice_sampler_test.py index 5d7f310366..1c2db8f282 100644 --- a/tensorflow_probability/python/experimental/mcmc/elliptical_slice_sampler_test.py +++ b/tensorflow_probability/python/experimental/mcmc/elliptical_slice_sampler_test.py @@ -18,6 +18,7 @@ import tensorflow.compat.v2 as tf from tensorflow_probability.python.distributions import normal from tensorflow_probability.python.experimental.mcmc import elliptical_slice_sampler +from tensorflow_probability.python.internal import samplers from tensorflow_probability.python.internal import test_util from tensorflow_probability.python.mcmc import sample @@ -166,9 +167,11 @@ def normal_log_likelihood(state): def testTupleShapes(self): def normal_sampler(seed): + shapes = [(8, 31, 3), (8,)] + seeds = samplers.split_seed(seed, len(shapes)) return tuple( normal.Normal(0, 1).sample(shp, seed=seed) - for shp in [(8, 31, 3), (8,)]) + for shp, seed in zip(shapes, seeds)) params = normal_sampler(test_util.test_seed()) def normal_log_likelihood(p0, p1): diff --git a/tensorflow_probability/python/experimental/mcmc/gradient_based_trajectory_length_adaptation_test.py b/tensorflow_probability/python/experimental/mcmc/gradient_based_trajectory_length_adaptation_test.py index 768c516084..1c874be708 100644 --- a/tensorflow_probability/python/experimental/mcmc/gradient_based_trajectory_length_adaptation_test.py +++ b/tensorflow_probability/python/experimental/mcmc/gradient_based_trajectory_length_adaptation_test.py @@ -477,10 +477,14 @@ def target_log_prob_fn(x): state = tf.zeros([64], self.dtype) seed = test_util.test_seed(sampler_type='stateless') step_0_kernel_results = kernel.bootstrap_results(state) + + seed, step_seed = samplers.split_seed(seed) state, step_1_kernel_results = kernel.one_step( - state, step_0_kernel_results, seed=seed) + state, step_0_kernel_results, seed=step_seed) + + seed, step_seed = samplers.split_seed(seed) _, step_2_kernel_results = kernel.one_step( - state, step_1_kernel_results, seed=seed) + state, step_1_kernel_results, seed=step_seed) (step_0_kernel_results, step_1_kernel_results, step_2_kernel_results) = self.evaluate([ diff --git a/tensorflow_probability/python/experimental/mcmc/sample_fold_test.py b/tensorflow_probability/python/experimental/mcmc/sample_fold_test.py index 336a961bf1..4b5f27bc42 100644 --- a/tensorflow_probability/python/experimental/mcmc/sample_fold_test.py +++ b/tensorflow_probability/python/experimental/mcmc/sample_fold_test.py @@ -212,6 +212,7 @@ def test_seed_reproducibility(self): kernel=fake_kernel, reducer=fake_reducer, seed=seed) + seed = test_util.clone_seed(seed) second_reduction_rslt, _, _ = sample_fold( num_steps=3, current_state=0., @@ -436,6 +437,7 @@ def test_seed_reproducibility(self): first_trace = sample_chain_with_burnin( num_results=5, current_state=0., kernel=first_fake_kernel, seed=seed).trace + seed = test_util.clone_seed(seed) second_trace = sample_chain_with_burnin( num_results=5, current_state=1., # difference should be irrelevant diff --git a/tensorflow_probability/python/experimental/mcmc/sequential_monte_carlo_kernel_test.py b/tensorflow_probability/python/experimental/mcmc/sequential_monte_carlo_kernel_test.py index 2a9302a420..3084972fc6 100644 --- a/tensorflow_probability/python/experimental/mcmc/sequential_monte_carlo_kernel_test.py +++ b/tensorflow_probability/python/experimental/mcmc/sequential_monte_carlo_kernel_test.py @@ -70,6 +70,7 @@ def propose_and_update_log_weights_fn(_, weighted_particles, seed=None): tf.nest.map_structure(tf.convert_to_tensor, results))) # Re-initialize and run the same steps with the same seed. + seeds = test_util.clone_seed(seeds) kernel2 = SequentialMonteCarlo( propose_and_update_log_weights_fn=propose_and_update_log_weights_fn, resample_fn=weighted_resampling.resample_systematic, diff --git a/tensorflow_probability/python/experimental/sts_gibbs/dynamic_spike_and_slab_test.py b/tensorflow_probability/python/experimental/sts_gibbs/dynamic_spike_and_slab_test.py index 16e7da442d..0c5ec9726f 100644 --- a/tensorflow_probability/python/experimental/sts_gibbs/dynamic_spike_and_slab_test.py +++ b/tensorflow_probability/python/experimental/sts_gibbs/dynamic_spike_and_slab_test.py @@ -325,6 +325,7 @@ def do_sample(seed): return sampler.sample_noise_variance_and_weights( targets, initial_nonzeros, seed=seed) variance1, weights1 = self.evaluate(do_sample(seed)) + seed = test_util.clone_seed(seed) variance2, weights2 = self.evaluate(do_sample(seed)) self.assertAllFinite(variance1) self.assertAllClose(variance1, variance2) diff --git a/tensorflow_probability/python/experimental/sts_gibbs/spike_and_slab_test.py b/tensorflow_probability/python/experimental/sts_gibbs/spike_and_slab_test.py index e2c4ad8237..85b81bea29 100644 --- a/tensorflow_probability/python/experimental/sts_gibbs/spike_and_slab_test.py +++ b/tensorflow_probability/python/experimental/sts_gibbs/spike_and_slab_test.py @@ -366,6 +366,7 @@ def do_sample(seed): return sampler.sample_noise_variance_and_weights( targets, initial_nonzeros, seed=seed) variance1, weights1 = self.evaluate(do_sample(seed)) + seed = test_util.clone_seed(seed) variance2, weights2 = self.evaluate(do_sample(seed)) self.assertAllFinite(variance1) self.assertAllClose(variance1, variance2) diff --git a/tensorflow_probability/python/experimental/util/trainable_test.py b/tensorflow_probability/python/experimental/util/trainable_test.py index c9e23aae6b..315f61c1de 100644 --- a/tensorflow_probability/python/experimental/util/trainable_test.py +++ b/tensorflow_probability/python/experimental/util/trainable_test.py @@ -297,7 +297,10 @@ def test_initialization_is_deterministic_with_seed(self): seed = test_util.test_seed(sampler_type='stateless') init_fn, _ = trainable.make_trainable_stateless( normal.Normal, validate_args=True) - self.assertAllCloseNested(init_fn(seed=seed), init_fn(seed=seed)) + result1 = init_fn(seed=seed) + seed = test_util.clone_seed(seed) + result2 = init_fn(seed=seed) + self.assertAllCloseNested(result1, result2) def test_can_specify_parameter_dtype(self): init_fn, apply_fn = trainable.make_trainable_stateless( diff --git a/tensorflow_probability/python/internal/batched_rejection_sampler_test.py b/tensorflow_probability/python/internal/batched_rejection_sampler_test.py index 61b7c061ff..f002202d53 100644 --- a/tensorflow_probability/python/internal/batched_rejection_sampler_test.py +++ b/tensorflow_probability/python/internal/batched_rejection_sampler_test.py @@ -47,6 +47,7 @@ def uniform_less_than_point_five(seed): self.assertAllClose(-values, negative_values) # Check for reproducibility. + seed = test_util.clone_seed(seed) ((negative_values_2, values_2), _, _) = self.evaluate( brs.batched_las_vegas_algorithm( uniform_less_than_point_five, @@ -115,6 +116,7 @@ def proposal_fn(seed): self.assertLess(ks, 0.02) # Check for reproducibility. + seed = test_util.clone_seed(seed) all_samples_2, _ = self.evaluate(brs.batched_rejection_sampler( proposal_fn, target_fn, seed=seed, dtype=dtype)) self.assertAllEqual(all_samples, all_samples_2) diff --git a/tensorflow_probability/python/internal/samplers_test.py b/tensorflow_probability/python/internal/samplers_test.py index 2aa7009f6d..8c64db4a54 100644 --- a/tensorflow_probability/python/internal/samplers_test.py +++ b/tensorflow_probability/python/internal/samplers_test.py @@ -93,6 +93,7 @@ def test_sanitize_tensor_or_tensorlike(self): def test_split(self): seed = test_util.test_seed(sampler_type='stateless') seed1, seed2 = samplers.split_seed(seed) + seed = test_util.clone_seed(seed) seed3, seed4 = samplers.split_seed(seed) seed, seed1, seed2, seed3, seed4 = self.evaluate( [seed, seed1, seed2, seed3, seed4]) @@ -142,6 +143,7 @@ def test_sampler(self, sampler, kwargs): self.skipTest('gamma sampler not implemented for rbg PRNG.') seed = test_util.test_seed(sampler_type='stateless') s1 = sampler(seed=seed, **kwargs) + seed = test_util.clone_seed(seed) s2 = sampler(seed=seed, **kwargs) self.assertAllEqual(s1, s2) diff --git a/tensorflow_probability/python/internal/test_util.py b/tensorflow_probability/python/internal/test_util.py index f148f00ce6..ae17d5a9cc 100644 --- a/tensorflow_probability/python/internal/test_util.py +++ b/tensorflow_probability/python/internal/test_util.py @@ -1544,6 +1544,17 @@ def test_seed(hardcoded_seed=None, return answer +def clone_seed(seed): + """Clone a seed: this is useful for JAX's experimental key reuse checking.""" + # TODO(b/328085305): switch to standard clone API when possible. + if JAX_MODE: + import jax # pylint: disable=g-import-not-at-top + return jax.random.wrap_key_data( + jax.random.key_data(seed), impl=jax.random.key_impl(seed) + ) + return seed + + def test_seed_stream(salt='Salt of the Earth', hardcoded_seed=None): """Returns a command-line-controllable SeedStream PRNG for unit tests. diff --git a/tensorflow_probability/python/internal/trainable_state_util_test.py b/tensorflow_probability/python/internal/trainable_state_util_test.py index 47bcfea474..f14de01fec 100644 --- a/tensorflow_probability/python/internal/trainable_state_util_test.py +++ b/tensorflow_probability/python/internal/trainable_state_util_test.py @@ -140,9 +140,11 @@ def test_init_supports_arg_or_kwarg_seed(self): init_fn, _ = trainable_state_util.as_stateless_builder( seed_generator)() self.assertLen(init_fn(seed=seed), 5) + seed = test_util.clone_seed(seed) + seed2 = test_util.clone_seed(seed) # Check that we can invoke init_fn with an arg or kwarg seed, # regardless of how the inner functions are parameterized. - self.assertAllCloseNested(init_fn(seed), init_fn(seed=seed)) + self.assertAllCloseNested(init_fn(seed), init_fn(seed=seed2)) if not JAX_MODE: # Check that we can initialize with no seed. @@ -165,6 +167,7 @@ def test_distribution_init_apply(self, generator, expected_num_params, shape): # Check that the distribution's samples have the expected shape. dist = apply_fn(params) + seed = test_util.clone_seed(seed) x = dist.sample(seed=seed) self.assertAllEqualNested(shape, tf.nest.map_structure(ps.shape, x)) @@ -311,6 +314,7 @@ def test_initialization_is_deterministic_with_seed(self): variables1 = trainable_jd1.trainable_variables self.assertLen(variables1, 5) + seed = test_util.clone_seed(seed) trainable_jd2 = make_trainable_jd(seed=seed) variables2 = trainable_jd2.trainable_variables self.evaluate([v.initializer for v in variables1 + variables2]) diff --git a/tensorflow_probability/python/mcmc/hmc_test.py b/tensorflow_probability/python/mcmc/hmc_test.py index ef0fd4d4a4..246e1426bb 100644 --- a/tensorflow_probability/python/mcmc/hmc_test.py +++ b/tensorflow_probability/python/mcmc/hmc_test.py @@ -1160,7 +1160,8 @@ def testReproducibleSingleStepStatelessSeed(self): tr_nm1.accepted_results.target_log_prob) # Rerun the kernel with the seed that it reported it used - state, kr = k.one_step(states[n - 1], tr_nm1, seed=tr_n.seed) + state, kr = k.one_step( + states[n - 1], tr_nm1, seed=test_util.clone_seed(tr_n.seed)) # Check that the results are the same self.assertAllClose(state, states[n]) self.assertAllClose(kr, tr_n) diff --git a/tensorflow_probability/python/mcmc/transformed_kernel_test.py b/tensorflow_probability/python/mcmc/transformed_kernel_test.py index ce72df9c94..c7f36bc333 100644 --- a/tensorflow_probability/python/mcmc/transformed_kernel_test.py +++ b/tensorflow_probability/python/mcmc/transformed_kernel_test.py @@ -449,13 +449,14 @@ def test_nested_transform(self): outer_pkr_two.inner_results.inner_results.accepted_results) seed = test_util.test_seed(sampler_type='stateless') + dup_seed = lambda: test_util.clone_seed(seed) outer_results_one, outer_results_two = self.evaluate([ - outer_kernel.one_step(2., outer_pkr_one, seed=seed), - outer_kernel.one_step(9., outer_pkr_two, seed=seed) + outer_kernel.one_step(2., outer_pkr_one, seed=dup_seed()), + outer_kernel.one_step(9., outer_pkr_two, seed=dup_seed()) ]) chain_results_one, chain_results_two = self.evaluate([ - chain_kernel.one_step(2., chain_pkr_one, seed=seed), - chain_kernel.one_step(9., chain_pkr_two, seed=seed) + chain_kernel.one_step(2., chain_pkr_one, seed=dup_seed()), + chain_kernel.one_step(9., chain_pkr_two, seed=dup_seed()) ]) self.assertNear(chain_results_one[0], outer_results_one[0], diff --git a/tensorflow_probability/python/sts/components/regression_test.py b/tensorflow_probability/python/sts/components/regression_test.py index a8266dbf5b..dcc2ea8a6a 100644 --- a/tensorflow_probability/python/sts/components/regression_test.py +++ b/tensorflow_probability/python/sts/components/regression_test.py @@ -191,8 +191,12 @@ def test_builds_without_errors(self): sparse_regression = SparseLinearRegression( design_matrix=design_matrix, weights_batch_shape=weights_batch_shape) - prior_params = [param.prior.sample(seed=prior_seed) - for param in sparse_regression.parameters] + prior_seeds = samplers.split_seed( + prior_seed, len(sparse_regression.parameters)) + prior_params = [ + param.prior.sample(seed=seed) + for param, seed in zip(sparse_regression.parameters, prior_seeds) + ] ssm = sparse_regression.make_state_space_model( num_timesteps=num_timesteps, diff --git a/tensorflow_probability/python/sts/structural_time_series_test.py b/tensorflow_probability/python/sts/structural_time_series_test.py index 33d2acc670..1f049889a8 100644 --- a/tensorflow_probability/python/sts/structural_time_series_test.py +++ b/tensorflow_probability/python/sts/structural_time_series_test.py @@ -71,7 +71,9 @@ def test_broadcast_batch_shapes(self): initial_effect_prior=loc_prior) model = Sum([linear_trend, seasonal], observation_noise_scale_prior=partial_scale_prior) - param_samples = [p.prior.sample(seed=seed) for p in model.parameters] + seeds = samplers.split_seed(seed, n=len(model.parameters)) + param_samples = [ + p.prior.sample(seed=s) for p, s in zip(model.parameters, seeds)] ssm = model.make_state_space_model(num_timesteps=2, param_vals=param_samples) @@ -108,7 +110,7 @@ def test_adding_two_sums(self): seed = test_util.test_seed(sampler_type='stateless') def observation_noise_scale_prior_sample(s): - return s.parameters[0].prior.sample(seed=seed) + return s.parameters[0].prior.sample(seed=test_util.clone_seed(seed)) self.assertAllEqual(observation_noise_scale_prior_sample(s3), observation_noise_scale_prior_sample(s1)) self.assertAllEqual(observation_noise_scale_prior_sample(s3), @@ -179,7 +181,10 @@ def test_state_space_model(self): seed = test_util.test_seed(sampler_type='stateless') model = self._build_sts() - dummy_param_vals = [p.prior.sample(seed=seed) for p in model.parameters] + seeds = samplers.split_seed(seed, n=len(model.parameters)) + dummy_param_vals = [ + p.prior.sample(seed=s) for p, s in zip(model.parameters, seeds) + ] initial_state_prior = mvn_diag.MultivariateNormalDiag( loc=-2. + tf.zeros([model.latent_size]), scale_diag=3. * tf.ones([model.latent_size])) diff --git a/tensorflow_probability/python/vi/csiszar_divergence_test.py b/tensorflow_probability/python/vi/csiszar_divergence_test.py index eff6bf2215..676a6dc1f4 100644 --- a/tensorflow_probability/python/vi/csiszar_divergence_test.py +++ b/tensorflow_probability/python/vi/csiszar_divergence_test.py @@ -481,6 +481,7 @@ def test_kl_forward(self): sample_size=int(4e5), seed=seed) + seed = test_util.clone_seed(seed) approx_kl_self_normalized = cd.monte_carlo_variational_loss( discrepancy_fn=(lambda logu: cd.kl_forward(logu, self_normalized=True)), target_log_prob_fn=p.log_prob, @@ -514,6 +515,7 @@ def test_kl_reverse(self): sample_size=int(4.5e5), seed=seed) + seed = test_util.clone_seed(seed) approx_kl_self_normalized = cd.monte_carlo_variational_loss( target_log_prob_fn=p.log_prob, surrogate_posterior=q, @@ -553,6 +555,8 @@ def test_kl_forward_multidim(self): sample_size=int(6e5), seed=seed) + seed = test_util.test_seed() + approx_kl_self_normalized = cd.monte_carlo_variational_loss( target_log_prob_fn=p.log_prob, surrogate_posterior=q, @@ -593,6 +597,8 @@ def test_kl_reverse_multidim(self): sample_size=int(6e5), seed=seed) + seed = test_util.test_seed() + approx_kl_self_normalized = cd.monte_carlo_variational_loss( target_log_prob_fn=p.log_prob, surrogate_posterior=q, @@ -637,6 +643,8 @@ def target_log_prob_fn(z, x): sample_size=int(3e5), seed=seed) + seed = test_util.test_seed() + reverse_kl_named = cd.monte_carlo_variational_loss( target_log_prob_fn=target_log_prob_fn, surrogate_posterior=q_named, @@ -669,6 +677,7 @@ def test_importance_weighted_objective(self): self.assertAllGreater(elbo_loss, 0.) # Check that importance sampling reduces the loss towards zero. + seed = test_util.clone_seed(seed) iwae_10_loss = cd.monte_carlo_variational_loss( target_log_prob_fn=target.log_prob, surrogate_posterior=proposal, @@ -679,6 +688,7 @@ def test_importance_weighted_objective(self): self.assertAllGreater(elbo_loss, iwae_10_loss) self.assertAllGreater(iwae_10_loss, 0) + seed = test_util.clone_seed(seed) iwae_100_loss = cd.monte_carlo_variational_loss( target_log_prob_fn=target.log_prob, surrogate_posterior=proposal, @@ -690,6 +700,7 @@ def test_importance_weighted_objective(self): self.assertAllClose(iwae_100_loss, 0, atol=0.1) # Check reproducibility + seed = test_util.clone_seed(seed) elbo_loss_again = cd.monte_carlo_variational_loss( target_log_prob_fn=target.log_prob, surrogate_posterior=proposal, @@ -699,6 +710,7 @@ def test_importance_weighted_objective(self): seed=seed) self.assertAllClose(elbo_loss_again, elbo_loss) + seed = test_util.clone_seed(seed) iwae_10_loss_again = cd.monte_carlo_variational_loss( target_log_prob_fn=target.log_prob, surrogate_posterior=proposal, @@ -712,7 +724,6 @@ def test_importance_weighted_objective(self): def test_score_trick(self): d = 5 # Dimension sample_size = int(4.5e5) - seed = test_util.test_seed() # Variance is very high when approximating Forward KL, so we make # scale_diag large. This ensures q "covers" p and thus Var_q[p/q] is @@ -731,7 +742,7 @@ def _fn(s): discrepancy_fn=func, sample_size=sample_size, gradient_estimator=gradient_estimator, - seed=seed) + seed=test_util.test_seed()) return _fn approx_kl = construct_monte_carlo_csiszar_f_divergence( @@ -866,6 +877,7 @@ def loss(params, gradient_estimator, seed): loss, gradient_estimator=cd.GradientEstimators.REPARAMETERIZATION, seed=seed), [initial_params]) + seed = test_util.clone_seed(seed) dreg_loss, dreg_grad = gradient.value_and_gradient( functools.partial( loss, @@ -918,6 +930,7 @@ def target_log_prob_fn(x): tf.cast(importance_sample_size, dtype=log_weights.dtype)), axis=0) + seed = test_util.clone_seed(seed) loss = cd.monte_carlo_variational_loss( target_log_prob_fn, surrogate_posterior=surrogate_posterior, @@ -1023,7 +1036,6 @@ def test_vimco_and_gradient(self): dims = 5 # Dimension num_draws = int(1e3) num_batch_draws = int(3) - seed = test_util.test_seed(sampler_type='stateless') f = lambda logu: cd.kl_reverse(logu, self_normalized=False) np_f = lambda logu: -logu @@ -1038,6 +1050,7 @@ def test_vimco_and_gradient(self): scale_diag=tf.tile([s], [dims]))) def vimco_loss(s): + seed = test_util.test_seed(sampler_type='stateless') return cd.monte_carlo_variational_loss( p.log_prob, surrogate_posterior=build_q(s), @@ -1048,6 +1061,7 @@ def vimco_loss(s): seed=seed) def logu(s): + seed = test_util.test_seed(sampler_type='stateless') q = build_q(s) x = q.sample(sample_shape=[num_draws, num_batch_draws], # Brittle hack to ensure that the q samples match those @@ -1060,6 +1074,7 @@ def f_log_sum_u(s): return f(leave_one_out.log_soomean_exp(logu(s), axis=0)[::-1][0]) def q_log_prob_x(s): + seed = test_util.test_seed(sampler_type='stateless') q = build_q(s) x = q.sample(sample_shape=[num_draws, num_batch_draws], # Brittle hack to ensure that the q samples match those @@ -1144,6 +1159,7 @@ def p_log_prob(z, x): gradient_estimator=cd.GradientEstimators.VIMCO, seed=seed) + seed = test_util.clone_seed(seed) reverse_kl_named = cd.monte_carlo_variational_loss( p_log_prob, surrogate_posterior=q_named, @@ -1178,6 +1194,7 @@ def p_log_prob(z, x): gradient_estimator=cd.GradientEstimators.VIMCO, seed=seed) + seed = test_util.clone_seed(seed) reverse_kl_again = cd.monte_carlo_variational_loss( p_log_prob, surrogate_posterior=q,