diff --git a/tensorflow_probability/python/experimental/mcmc/windowed_sampling.py b/tensorflow_probability/python/experimental/mcmc/windowed_sampling.py index 643b409864..c8297d24ef 100644 --- a/tensorflow_probability/python/experimental/mcmc/windowed_sampling.py +++ b/tensorflow_probability/python/experimental/mcmc/windowed_sampling.py @@ -641,12 +641,14 @@ def windowed_adaptive_nuts(n_draws, Where to initialize the step size for the leapfrog integrator. The structure should broadcast with `current_state`. For example, if the initial state is + ``` { - 'a': tf.zeros(n_chains), - 'b': tf.zeros([n_chains, n_features]), + 'a': tf.zeros(n_chains), + 'b': tf.zeros([n_chains, n_features]), } ``` + then any of `1.`, `{'a': 1., 'b': 1.}`, or `{'a': tf.ones(n_chains), 'b': tf.ones([n_chains, n_features])}` will work. Defaults to the dimension of the log density to the 0.25 power. @@ -690,6 +692,7 @@ class defaults are used otherwise. **pins: These are used to condition the provided joint distribution, and are passed directly to `joint_dist.experimental_pin(**pins)`. + Returns: A single structure of draws is returned in case the trace_fn is `None`, and `return_final_kernel_results` is `False`. If there is a trace function, @@ -766,10 +769,14 @@ def windowed_adaptive_hmc(n_draws, Where to initialize the step size for the leapfrog integrator. The structure should broadcast with `current_state`. For example, if the initial state is + ``` - {'a': tf.zeros(n_chains), - 'b': tf.zeros([n_chains, n_features])} - ``` + { + 'a': tf.zeros(n_chains), + 'b': tf.zeros([n_chains, n_features]), + } + ``` + then any of `1.`, `{'a': 1., 'b': 1.}`, or `{'a': tf.ones(n_chains), 'b': tf.ones([n_chains, n_features])}` will work. Defaults to the dimension of the log density to the 0.25 power. @@ -801,6 +808,7 @@ class defaults are used otherwise. **pins: These are used to condition the provided joint distribution, and are passed directly to `joint_dist.experimental_pin(**pins)`. + Returns: A single structure of draws is returned in case the trace_fn is `None`, and `return_final_kernel_results` is `False`. If there is a trace function,