Skip to content

Commit

Permalink
tensorflow_probability: avoid JAX key reuse in tests
Browse files Browse the repository at this point in the history
These issues were found by running the test suite with jax_enable_key_reuse_checks=True, available since jax-ml/jax#19795.

I did find some issues in non-test code, but I plan to address those issues individually as they may have user-visible effects. This change only touches innocuous key reuse in test files.

PiperOrigin-RevId: 613958607
  • Loading branch information
vanderplas authored and tensorflower-gardener committed Mar 8, 2024
1 parent 93bdfbf commit 0af6f41
Show file tree
Hide file tree
Showing 28 changed files with 108 additions and 37 deletions.
3 changes: 1 addition & 2 deletions tensorflow_probability/python/bijectors/ffjord_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,6 @@ def testJacobianDiagonalScaling(self, dtype):
)

def testHutchinsonsNormalEstimator(self, dtype):
seed = test_util.test_seed()
tf_dtype = tf.as_dtype(dtype)
num_dims = 10
np.random.seed(seed=test_util.test_seed(sampler_type='integer'))
Expand All @@ -170,7 +169,7 @@ def testHutchinsonsNormalEstimator(self, dtype):

def trace_augmentation_fn(ode_fn, z_shape, dtype):
return ffjord.trace_jacobian_hutchinson(
ode_fn, z_shape, dtype, num_samples=128, seed=seed)
ode_fn, z_shape, dtype, num_samples=128, seed=test_util.test_seed())

bijector = ffjord.FFJORD(
trace_augmentation_fn=trace_augmentation_fn,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -234,15 +234,16 @@ def build_operator(self):
tf.linalg.LinearOperatorIdentity(2)]], is_non_singular=True)

def build_batched_operator(self):
seed = test_util.test_seed()
seed1 = test_util.test_seed()
seed2 = test_util.test_seed()
return tf.linalg.LinearOperatorBlockLowerTriangular([
[tf.linalg.LinearOperatorFullMatrix(
tf.random.normal((3, 4, 4), dtype=tf.float32, seed=seed),
tf.random.normal((3, 4, 4), dtype=tf.float32, seed=seed1),
is_non_singular=True)],
[tf.linalg.LinearOperatorZeros(
3, 4, is_square=False, is_self_adjoint=False),
tf.linalg.LinearOperatorFullMatrix(
tf.random.normal((3, 3), dtype=tf.float32, seed=seed),
tf.random.normal((3, 3), dtype=tf.float32, seed=seed2),
is_non_singular=True)]
], is_non_singular=True)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ def _seed(seed=None):
seed = test_util.test_seed() if seed is None else seed
if tf.executing_eagerly():
tf.random.set_seed(seed)
return seed
return test_util.clone_seed(seed)
seed = _seed()
self.assertAllEqual(
self.evaluate(dist.sample(n, _seed(seed=seed))),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,7 @@ def _seed(seed=None):
seed = test_util.test_seed() if seed is None else seed
if tf.executing_eagerly():
tf1.set_random_seed(seed)
return seed
return test_util.clone_seed(seed)

seed = _seed()
self.assertAllClose(
Expand Down
3 changes: 3 additions & 0 deletions tensorflow_probability/python/distributions/empirical_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@ def testSampleN(self):
if tf.executing_eagerly():
tf.random.set_seed(seed)
samples1 = dist.sample(n, seed)
seed = test_util.clone_seed(seed)
if tf.executing_eagerly():
tf.random.set_seed(seed)
samples2 = dist.sample(n, seed)
Expand Down Expand Up @@ -448,6 +449,7 @@ def testSampleN(self):
if tf.executing_eagerly():
tf.random.set_seed(seed)
samples1 = dist.sample(n, seed)
seed = test_util.clone_seed(seed)
if tf.executing_eagerly():
tf.random.set_seed(seed)
samples2 = dist.sample(n, seed)
Expand Down Expand Up @@ -650,6 +652,7 @@ def testSampleN(self):
if tf.executing_eagerly():
tf.random.set_seed(seed)
samples1 = dist.sample(n, seed)
seed = test_util.clone_seed(seed)
if tf.executing_eagerly():
tf.random.set_seed(seed)
samples2 = dist.sample(n, seed)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ def _sample(seed):
seed = test_util.test_seed()
result = jax.jit(_sample)(seed)
if not FLAGS.execute_only:
seed = test_util.clone_seed(seed)
self.assertAllClose(_sample(seed), result, rtol=1e-6,
atol=1e-6)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,7 @@ def testKumaraswamySampleMultipleTimes(self):
validate_args=True)
samples1 = self.evaluate(dist1.sample(n_val, seed=seed))

seed = test_util.clone_seed(seed)
tf.random.set_seed(seed)
dist2 = kumaraswamy.Kumaraswamy(
concentration1=a_val,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -971,11 +971,12 @@ def testInitSetsDefaultMask(self):
model_with_mask.posterior_marginals(observed_time_series),
model.posterior_marginals(observed_time_series, mask=observation_mask))
seed = test_util.test_seed(sampler_type='stateless')
seed2 = test_util.clone_seed(seed)
self.assertAllEqual(
model_with_mask.posterior_sample(
observed_time_series, seed=seed),
model.posterior_sample(
observed_time_series, mask=observation_mask, seed=seed))
observed_time_series, mask=observation_mask, seed=seed2))


class MissingObservationsTestsSequential(_MissingObservationsTests):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ def _seed(seed=None):
seed = test_util.test_seed() if seed is None else seed
if tf.executing_eagerly():
tf.random.set_seed(seed)
return seed
return test_util.clone_seed(seed)
seed = _seed()
self.assertAllEqual(
self.evaluate(dist.sample(n, seed)),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,7 @@ def _forward_log_det_jacobian(self, x):
sample, s=sigma, scale=np.exp(mu))
self.assertAllClose(expected_log_pdf, log_pdf, rtol=1e-4, atol=0.)

seed = test_util.clone_seed(seed)
sample2 = self.evaluate(log_normal.sample(seed=seed))
self.assertAllClose(sample, sample2, rtol=1e-4)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@ def testBijectorIsDeterministicGivenSeed(self):
seed = test_util.test_seed(sampler_type='stateless')
bijector1 = highway_flow.build_trainable_highway_flow(
width, activation_fn=tf.nn.softplus, seed=seed)
seed = test_util.clone_seed(seed)
bijector2 = highway_flow.build_trainable_highway_flow(
width, activation_fn=tf.nn.softplus, seed=seed)
self.evaluate([v.initializer for v in bijector1.trainable_variables])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def test_samples_approach_target_distribution(self):
xs = self.evaluate(target.sample(num_samples, seed=seeds[0]))
xs2 = self.evaluate(resampled2.sample(num_samples, seed=seeds[1]))
xs20 = self.evaluate(resampled20.sample(num_samples, seed=seeds[2]))
for statistic_fn in (lambda x: x, lambda x: x**2):
for statistic_fn, seed in zip([lambda x: x, lambda x: x**2], seeds[3:]):
true_statistic = tf.reduce_mean(statistic_fn(xs))
# Statistics should approach those of the target distribution as
# `importance_sample_size` increases.
Expand All @@ -120,7 +120,7 @@ def test_samples_approach_target_distribution(self):
tf.abs(tf.reduce_mean(statistic_fn(xs20)) - true_statistic))

expectation_no_resampling = resampled2.self_normalized_expectation(
statistic_fn, importance_sample_size=10000, seed=seeds[3])
statistic_fn, importance_sample_size=10000, seed=seed)
self.assertAllClose(true_statistic, expectation_no_resampling, atol=0.15)

def test_log_prob_approaches_target_distribution(self):
Expand Down Expand Up @@ -237,32 +237,33 @@ def target_log_prob_fn(x):
return prior + likelihood

# Use importance sampling to infer an approximate posterior.
seed = test_util.test_seed(sampler_type='stateless')
seed = lambda: test_util.test_seed(sampler_type='stateless')
approximate_posterior = importance_resample.ImportanceResample(
proposal_distribution=normal.Normal(loc=0., scale=2.),
target_log_prob_fn=target_log_prob_fn,
importance_sample_size=3,
stochastic_approximation_seed=seed)
stochastic_approximation_seed=seed())

# Directly compute expectations under the posterior via importance weights.
posterior_mean = approximate_posterior.self_normalized_expectation(
lambda x: x, seed=seed)
lambda x: x, seed=seed())
approximate_posterior.self_normalized_expectation(
lambda x: (x - posterior_mean)**2, seed=seed)
lambda x: (x - posterior_mean)**2, seed=seed())

posterior_samples = approximate_posterior.sample(5, seed=seed)
posterior_samples = approximate_posterior.sample(5, seed=seed())
tf.reduce_mean(posterior_samples)
tf.math.reduce_variance(posterior_samples)

posterior_mean_efficient = (
approximate_posterior.self_normalized_expectation(
lambda x: x, sample_size=10, seed=seed))
lambda x: x, sample_size=10, seed=seed()))
approximate_posterior.self_normalized_expectation(
lambda x: (x - posterior_mean_efficient)**2, sample_size=10, seed=seed)
lambda x: (
x - posterior_mean_efficient)**2, sample_size=10, seed=seed())

# Approximate the posterior density.
xs = tf.linspace(-3., 3., 101)
approximate_posterior.prob(xs, sample_size=10, seed=seed)
approximate_posterior.prob(xs, sample_size=10, seed=seed())

def test_log_prob_independence_per_x(self):
dist = importance_resample.ImportanceResample(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import tensorflow.compat.v2 as tf
from tensorflow_probability.python.distributions import normal
from tensorflow_probability.python.experimental.mcmc import elliptical_slice_sampler
from tensorflow_probability.python.internal import samplers
from tensorflow_probability.python.internal import test_util
from tensorflow_probability.python.mcmc import sample

Expand Down Expand Up @@ -166,9 +167,11 @@ def normal_log_likelihood(state):
def testTupleShapes(self):

def normal_sampler(seed):
shapes = [(8, 31, 3), (8,)]
seeds = samplers.split_seed(seed, len(shapes))
return tuple(
normal.Normal(0, 1).sample(shp, seed=seed)
for shp in [(8, 31, 3), (8,)])
for shp, seed in zip(shapes, seeds))
params = normal_sampler(test_util.test_seed())

def normal_log_likelihood(p0, p1):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -477,10 +477,14 @@ def target_log_prob_fn(x):
state = tf.zeros([64], self.dtype)
seed = test_util.test_seed(sampler_type='stateless')
step_0_kernel_results = kernel.bootstrap_results(state)

seed, step_seed = samplers.split_seed(seed)
state, step_1_kernel_results = kernel.one_step(
state, step_0_kernel_results, seed=seed)
state, step_0_kernel_results, seed=step_seed)

seed, step_seed = samplers.split_seed(seed)
_, step_2_kernel_results = kernel.one_step(
state, step_1_kernel_results, seed=seed)
state, step_1_kernel_results, seed=step_seed)

(step_0_kernel_results, step_1_kernel_results,
step_2_kernel_results) = self.evaluate([
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,7 @@ def test_seed_reproducibility(self):
kernel=fake_kernel,
reducer=fake_reducer,
seed=seed)
seed = test_util.clone_seed(seed)
second_reduction_rslt, _, _ = sample_fold(
num_steps=3,
current_state=0.,
Expand Down Expand Up @@ -436,6 +437,7 @@ def test_seed_reproducibility(self):
first_trace = sample_chain_with_burnin(
num_results=5, current_state=0., kernel=first_fake_kernel,
seed=seed).trace
seed = test_util.clone_seed(seed)
second_trace = sample_chain_with_burnin(
num_results=5,
current_state=1., # difference should be irrelevant
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ def propose_and_update_log_weights_fn(_, weighted_particles, seed=None):
tf.nest.map_structure(tf.convert_to_tensor, results)))

# Re-initialize and run the same steps with the same seed.
seeds = test_util.clone_seed(seeds)
kernel2 = SequentialMonteCarlo(
propose_and_update_log_weights_fn=propose_and_update_log_weights_fn,
resample_fn=weighted_resampling.resample_systematic,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,7 @@ def do_sample(seed):
return sampler.sample_noise_variance_and_weights(
targets, initial_nonzeros, seed=seed)
variance1, weights1 = self.evaluate(do_sample(seed))
seed = test_util.clone_seed(seed)
variance2, weights2 = self.evaluate(do_sample(seed))
self.assertAllFinite(variance1)
self.assertAllClose(variance1, variance2)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,7 @@ def do_sample(seed):
return sampler.sample_noise_variance_and_weights(
targets, initial_nonzeros, seed=seed)
variance1, weights1 = self.evaluate(do_sample(seed))
seed = test_util.clone_seed(seed)
variance2, weights2 = self.evaluate(do_sample(seed))
self.assertAllFinite(variance1)
self.assertAllClose(variance1, variance2)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,10 @@ def test_initialization_is_deterministic_with_seed(self):
seed = test_util.test_seed(sampler_type='stateless')
init_fn, _ = trainable.make_trainable_stateless(
normal.Normal, validate_args=True)
self.assertAllCloseNested(init_fn(seed=seed), init_fn(seed=seed))
result1 = init_fn(seed=seed)
seed = test_util.clone_seed(seed)
result2 = init_fn(seed=seed)
self.assertAllCloseNested(result1, result2)

def test_can_specify_parameter_dtype(self):
init_fn, apply_fn = trainable.make_trainable_stateless(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def uniform_less_than_point_five(seed):
self.assertAllClose(-values, negative_values)

# Check for reproducibility.
seed = test_util.clone_seed(seed)
((negative_values_2, values_2), _, _) = self.evaluate(
brs.batched_las_vegas_algorithm(
uniform_less_than_point_five,
Expand Down Expand Up @@ -115,6 +116,7 @@ def proposal_fn(seed):
self.assertLess(ks, 0.02)

# Check for reproducibility.
seed = test_util.clone_seed(seed)
all_samples_2, _ = self.evaluate(brs.batched_rejection_sampler(
proposal_fn, target_fn, seed=seed, dtype=dtype))
self.assertAllEqual(all_samples, all_samples_2)
Expand Down
2 changes: 2 additions & 0 deletions tensorflow_probability/python/internal/samplers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ def test_sanitize_tensor_or_tensorlike(self):
def test_split(self):
seed = test_util.test_seed(sampler_type='stateless')
seed1, seed2 = samplers.split_seed(seed)
seed = test_util.clone_seed(seed)
seed3, seed4 = samplers.split_seed(seed)
seed, seed1, seed2, seed3, seed4 = self.evaluate(
[seed, seed1, seed2, seed3, seed4])
Expand Down Expand Up @@ -142,6 +143,7 @@ def test_sampler(self, sampler, kwargs):
self.skipTest('gamma sampler not implemented for rbg PRNG.')
seed = test_util.test_seed(sampler_type='stateless')
s1 = sampler(seed=seed, **kwargs)
seed = test_util.clone_seed(seed)
s2 = sampler(seed=seed, **kwargs)
self.assertAllEqual(s1, s2)

Expand Down
11 changes: 11 additions & 0 deletions tensorflow_probability/python/internal/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -1544,6 +1544,17 @@ def test_seed(hardcoded_seed=None,
return answer


def clone_seed(seed):
  """Clone a seed so it can be used again under JAX key-reuse checking.

  Tests that deliberately pass the same seed to two calls (for example to
  check reproducibility) would otherwise trip JAX's experimental key reuse
  checker. Cloning yields a key with the same contents that the checker
  treats as not yet consumed.

  Args:
    seed: The seed to clone. In JAX mode it is handed to JAX's key APIs
      (presumably a PRNG key -- confirm against callers). Outside JAX mode it
      may be any value.

  Returns:
    In JAX mode, a clone of `seed` carrying the same key data and PRNG
    implementation. Otherwise `seed` itself, unchanged.
  """
  if not JAX_MODE:
    # Non-JAX backends have no key-reuse tracking; the seed is reused as is.
    return seed

  import jax  # pylint: disable=g-import-not-at-top

  # TODO(b/328085305): drop the fallback once every supported JAX release
  # provides `jax.random.clone`.
  clone = getattr(jax.random, 'clone', None)
  if clone is not None:
    return clone(seed)

  # Older JAX without a public clone API: rebuild the key from its raw data
  # using the same PRNG implementation.
  return jax.random.wrap_key_data(
      jax.random.key_data(seed), impl=jax.random.key_impl(seed)
  )


def test_seed_stream(salt='Salt of the Earth', hardcoded_seed=None):
"""Returns a command-line-controllable SeedStream PRNG for unit tests.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -140,9 +140,11 @@ def test_init_supports_arg_or_kwarg_seed(self):
init_fn, _ = trainable_state_util.as_stateless_builder(
seed_generator)()
self.assertLen(init_fn(seed=seed), 5)
seed = test_util.clone_seed(seed)
seed2 = test_util.clone_seed(seed)
# Check that we can invoke init_fn with an arg or kwarg seed,
# regardless of how the inner functions are parameterized.
self.assertAllCloseNested(init_fn(seed), init_fn(seed=seed))
self.assertAllCloseNested(init_fn(seed), init_fn(seed=seed2))

if not JAX_MODE:
# Check that we can initialize with no seed.
Expand All @@ -165,6 +167,7 @@ def test_distribution_init_apply(self, generator, expected_num_params, shape):

# Check that the distribution's samples have the expected shape.
dist = apply_fn(params)
seed = test_util.clone_seed(seed)
x = dist.sample(seed=seed)
self.assertAllEqualNested(shape, tf.nest.map_structure(ps.shape, x))

Expand Down Expand Up @@ -311,6 +314,7 @@ def test_initialization_is_deterministic_with_seed(self):
variables1 = trainable_jd1.trainable_variables
self.assertLen(variables1, 5)

seed = test_util.clone_seed(seed)
trainable_jd2 = make_trainable_jd(seed=seed)
variables2 = trainable_jd2.trainable_variables
self.evaluate([v.initializer for v in variables1 + variables2])
Expand Down
3 changes: 2 additions & 1 deletion tensorflow_probability/python/mcmc/hmc_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1160,7 +1160,8 @@ def testReproducibleSingleStepStatelessSeed(self):
tr_nm1.accepted_results.target_log_prob)

# Rerun the kernel with the seed that it reported it used
state, kr = k.one_step(states[n - 1], tr_nm1, seed=tr_n.seed)
state, kr = k.one_step(
states[n - 1], tr_nm1, seed=test_util.clone_seed(tr_n.seed))
# Check that the results are the same
self.assertAllClose(state, states[n])
self.assertAllClose(kr, tr_n)
Expand Down
9 changes: 5 additions & 4 deletions tensorflow_probability/python/mcmc/transformed_kernel_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,13 +449,14 @@ def test_nested_transform(self):
outer_pkr_two.inner_results.inner_results.accepted_results)

seed = test_util.test_seed(sampler_type='stateless')
dup_seed = lambda: test_util.clone_seed(seed)
outer_results_one, outer_results_two = self.evaluate([
outer_kernel.one_step(2., outer_pkr_one, seed=seed),
outer_kernel.one_step(9., outer_pkr_two, seed=seed)
outer_kernel.one_step(2., outer_pkr_one, seed=dup_seed()),
outer_kernel.one_step(9., outer_pkr_two, seed=dup_seed())
])
chain_results_one, chain_results_two = self.evaluate([
chain_kernel.one_step(2., chain_pkr_one, seed=seed),
chain_kernel.one_step(9., chain_pkr_two, seed=seed)
chain_kernel.one_step(2., chain_pkr_one, seed=dup_seed()),
chain_kernel.one_step(9., chain_pkr_two, seed=dup_seed())
])
self.assertNear(chain_results_one[0],
outer_results_one[0],
Expand Down
Loading

0 comments on commit 0af6f41

Please sign in to comment.