diff --git a/STYLE_GUIDE.md b/STYLE_GUIDE.md index 7105bc51ee..1534103f3f 100644 --- a/STYLE_GUIDE.md +++ b/STYLE_GUIDE.md @@ -109,7 +109,7 @@ they supersede all previous conventions. * Definitely use named args for 2nd args onward in docstrings. 1. Use names which describe semantics, not computation or mathematics, e.g., - avoid `xp1 = x+1` or `tfd.Normal(loc=mu, scale=sigma)`. + avoid `xp1 = x + 1` or `tfd.Normal(loc=mu, scale=sigma)`. 1. Prefer inlining intermediates which are used once. @@ -157,16 +157,16 @@ they supersede all previous conventions. 1. Prefer using the most specific TF operator. E.g, - * Use `tf.squared_difference(x,y)` over `(x-y)**2`. - * Use `tf.rsqrt` over `1./tf.sqrt(x)`. + * Use `tf.squared_difference(x, y)` over `(x - y)**2`. + * Use `tf.rsqrt` over `1. / tf.sqrt(x)`. 1. Worry about gradients! (It's often not automatic for API builders!) 1. When forced to choose between FLOPS and numerical accuracy, prefer numerical accuracy. -1. Avoid tf.cast if possible. Eg, prefer `tf.where(cond, a, b)` over - `tf.cast(cond,dtype=a.dtype)*a + (1-tf.cast(cond,dtype=b.dtype)*b` +1. Avoid tf.cast if possible. Eg, prefer `tf.where(pred, a, b)` over + `tf.cast(cond, dtype=a.dtype) * a + (1 - tf.cast(cond, dtype=b.dtype) * b` 1. Preserve static shape hints. @@ -217,3 +217,15 @@ they supersede all previous conventions. `Tensor`s, and Numpy objects. When converting a user-provided literal to a `Tensor` (see e.g. `Distribution._call_log_prob`), specify the dtype to `tf.convert_to_tensor` if it is available. + +1. Prefer overloaded operators on `Tensor`s (`+`, `-`, etc.) to explicit + method calls (`tf.add`, `tf.sub`, etc.). Exceptions: + + * Prefer `tf.equal` to `==` when checking element-wise equality, because the + semantics of the latter are inconsistent between eager and graph + (`tf.function`) modes. + * Use `&` and `|` only if you want bitwise logic. Note that these are + equivalent to logical ops only if all inputs are `bool`s or are in + `{0, 1}`. + + diff --git a/spinoffs/oryx/oryx/bijectors/__init__.py b/spinoffs/oryx/oryx/bijectors/__init__.py index 346d61ace4..31a5d8b4c6 100644 --- a/spinoffs/oryx/oryx/bijectors/__init__.py +++ b/spinoffs/oryx/oryx/bijectors/__init__.py @@ -18,24 +18,15 @@ from oryx.bijectors import bijector_extensions from tensorflow_probability.substrates import jax as tfp -__all__ = [ - 'bijector_extensions' -] - tfb = tfp.bijectors -_bijectors = {} +__all__ = tfb.__all__ -for name in dir(tfb): +for name in __all__: bij = getattr(tfb, name) if inspect.isclass(bij) and issubclass(bij, tfb.Bijector): if bij is not tfb.Bijector: bij = bijector_extensions.make_type(bij) - _bijectors[name] = bij - - -for key, val in _bijectors.items(): - locals()[key] = val - + locals()[name] = bij -del _bijectors +del tfb diff --git a/spinoffs/oryx/oryx/core/interpreters/harvest.py b/spinoffs/oryx/oryx/core/interpreters/harvest.py index df32ca0eba..c376f25662 100644 --- a/spinoffs/oryx/oryx/core/interpreters/harvest.py +++ b/spinoffs/oryx/oryx/core/interpreters/harvest.py @@ -335,8 +335,8 @@ def process_higher_order_primitive(self, primitive, f, tracers, params, params = params.copy() new_params = dict( params, - mapped_invars=(True,) * len(tree_util.tree_leaves(plants)) + - params['mapped_invars']) + in_axes=(0,) * len(tree_util.tree_leaves(plants)) + + params['in_axes']) else: new_params = dict(params) all_args, all_tree = tree_util.tree_flatten((plants, vals)) diff --git a/spinoffs/oryx/oryx/core/interpreters/inverse/core.py b/spinoffs/oryx/oryx/core/interpreters/inverse/core.py index 04220ab28e..0044c247eb 100644 --- a/spinoffs/oryx/oryx/core/interpreters/inverse/core.py +++ b/spinoffs/oryx/oryx/core/interpreters/inverse/core.py @@ -373,8 +373,8 @@ def remove_slice(cell): flat_vals, in_tree = tree_util.tree_flatten((mapped_incells, mapped_outcells)) f, aux = flat_propagate(f, in_tree) # Assume all invars as mapped - new_mapped_invars = (True,) * len(flat_vals) - new_params = dict(params, mapped_invars=new_mapped_invars) + new_in_axes = (0,) * len(flat_vals) + new_params = dict(params, in_axes=new_in_axes) if 'donated_invars' in params: new_params['donated_invars'] = (False,) * len(flat_vals) subenv_vals = prim.bind(f, *flat_vals, **new_params) diff --git a/spinoffs/oryx/oryx/core/interpreters/unzip.py b/spinoffs/oryx/oryx/core/interpreters/unzip.py index a195930a41..c7cbd613dd 100644 --- a/spinoffs/oryx/oryx/core/interpreters/unzip.py +++ b/spinoffs/oryx/oryx/core/interpreters/unzip.py @@ -34,9 +34,9 @@ from jax import core as jax_core from jax import custom_derivatives as cd from jax import linear_util as lu -from jax import source_info_util from jax import tree_util from jax import util as jax_util +from jax._src import source_info_util from jax.interpreters import partial_eval as pe import numpy as onp @@ -282,14 +282,13 @@ def handle_call_primitive(self, call_primitive, f, tracers, params, is_map): return current_custom_rules()[call_primitive](self, f, *tracers, **params) if call_primitive in pe.call_partial_eval_rules: raise NotImplementedError - in_pvs, in_consts = jax_util.unzip2(t.pval for t in tracers) + in_pvals = [t.pval for t in tracers] if is_map: - pvs = [ - None if pv is None else mapped_aval(params['axis_size'], pv) - for pv in in_pvs - ] - else: - pvs = in_pvs + unknown = pe.PartialVal.unknown + in_pvals = [pval if pval.is_known() or in_axis is None else + unknown(mapped_aval(params['axis_size'], in_axis, pval[0])) + for pval, in_axis in zip(in_pvals, params['in_axes'])] + pvs, in_consts = jax_util.unzip2(t.pval for t in tracers) keys = tuple(t.is_key() for t in tracers) new_settings = UnzipSettings(settings.tag, call_primitive in block_registry) fun, aux = unzip_eval(f, self, keys, tuple(pvs), new_settings) @@ -360,12 +359,6 @@ def _bound_output_tracers(self, primitive, params, jaxpr, consts, env, for pv, const, key in safe_zip(out_pvs, out_consts, out_keys) ] new_params = dict(params, name=name, call_jaxpr=lifted_jaxpr) - if is_map: - new_params = dict( - new_params, - mapped_invars=tuple([True] * len(const_tracers) + - [False] * len(env_tracers) + - [True] * len(in_tracers))) if 'donated_invars' in params: new_donated_invars = ( (False,) * len(const_tracers) + (False,) * len(env_tracers) + diff --git a/spinoffs/oryx/oryx/distributions/__init__.py b/spinoffs/oryx/oryx/distributions/__init__.py index baa140ae66..34a15658d2 100644 --- a/spinoffs/oryx/oryx/distributions/__init__.py +++ b/spinoffs/oryx/oryx/distributions/__init__.py @@ -16,23 +16,12 @@ from oryx.distributions import distribution_extensions from tensorflow_probability.substrates import jax as tfp -__all__ = [ - 'distribution_extensions' -] - - tfd = tfp.distributions -_distributions = {} +__all__ = tfd.__all__ -for name in dir(tfd): +for name in __all__: dist = getattr(tfd, name) - _distributions[name] = dist - - -for key, val in _distributions.items(): - locals()[key] = val - + locals()[name] = dist -del _distributions -del distribution_extensions # Only needed for registration. +del tfd diff --git a/spinoffs/oryx/oryx/experimental/nn/normalization_test.py b/spinoffs/oryx/oryx/experimental/nn/normalization_test.py index d6aeeeebd0..2e26913dd7 100644 --- a/spinoffs/oryx/oryx/experimental/nn/normalization_test.py +++ b/spinoffs/oryx/oryx/experimental/nn/normalization_test.py @@ -171,7 +171,7 @@ def test_check_grads(self): net = net_init.init(net_rng, state.Shape(in_shape)) x = random.normal(data_rng, in_shape) - jtu.check_grads(net, (x,), 2) + jtu.check_grads(net.call, (x,), 2) def mse(x, y): diff --git a/tensorflow_probability/python/__init__.py b/tensorflow_probability/python/__init__.py index bfccbe631e..601c74228f 100644 --- a/tensorflow_probability/python/__init__.py +++ b/tensorflow_probability/python/__init__.py @@ -18,20 +18,23 @@ from __future__ import division from __future__ import print_function +import functools + from tensorflow_probability.python.internal import all_util from tensorflow_probability.python.internal import lazy_loader -# Ensure TensorFlow is importable and its version is sufficiently recent. This -# needs to happen before anything else, since the imports below will try to -# import tensorflow, too. # pylint: disable=g-import-not-at-top -def _ensure_tf_install(): - """Attempt to import tensorflow, and ensure its version is sufficient. +def _validate_tf_environment(package): + """Check TF version and (depending on package) warn about TensorFloat32. + + Args: + package: Python `str` indicating which package is being imported. Used for + package-dependent warning about TensorFloat32. Raises: ImportError: if either tensorflow is not importable or its version is - inadequate. + inadequate. """ try: import tensorflow.compat.v1 as tf @@ -62,9 +65,10 @@ def _ensure_tf_install(): required=required_tensorflow_version, present=tf.__version__)) - if tf.config.experimental.tensor_float_32_execution_enabled(): + if (package == 'mcmc' and + tf.config.experimental.tensor_float_32_execution_enabled()): # Must import here, because symbols get pruned to __all__. - import warnings # pylint: disable=g-import-not-at-top + import warnings warnings.warn( 'TensorFloat-32 matmul/conv are enabled for NVIDIA Ampere+ GPUs. The ' 'resulting loss of precision may hinder MCMC convergence. To turn off, ' @@ -94,6 +98,8 @@ def _ensure_tf_install(): for pkg in _allowed_symbols: globals()[pkg] = lazy_loader.LazyLoader( pkg, globals(), 'tensorflow_probability.python.{}'.format(pkg), - on_first_access=_ensure_tf_install) + # These checks need to happen before lazy-loading, since the modules + # themselves will try to import tensorflow, too. + on_first_access=functools.partial(_validate_tf_environment, pkg)) all_util.remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow_probability/python/bijectors/__init__.py b/tensorflow_probability/python/bijectors/__init__.py index 5d29392c32..86eaf037cc 100644 --- a/tensorflow_probability/python/bijectors/__init__.py +++ b/tensorflow_probability/python/bijectors/__init__.py @@ -40,7 +40,6 @@ from tensorflow_probability.python.bijectors.expm1 import Log1p from tensorflow_probability.python.bijectors.ffjord import FFJORD from tensorflow_probability.python.bijectors.fill_scale_tril import FillScaleTriL -from tensorflow_probability.python.bijectors.fill_scale_tril import ScaleTriL from tensorflow_probability.python.bijectors.fill_triangular import FillTriangular from tensorflow_probability.python.bijectors.frechet_cdf import FrechetCDF from tensorflow_probability.python.bijectors.generalized_pareto import GeneralizedPareto @@ -159,7 +158,6 @@ "ScaleMatvecLinearOperatorBlock", "ScaleMatvecLU", "ScaleMatvecTriL", - "ScaleTriL", "Shift", "ShiftedGompertzCDF", "Sigmoid", diff --git a/tensorflow_probability/python/bijectors/bijector_properties_test.py b/tensorflow_probability/python/bijectors/bijector_properties_test.py index f3e3935895..a1277657b2 100644 --- a/tensorflow_probability/python/bijectors/bijector_properties_test.py +++ b/tensorflow_probability/python/bijectors/bijector_properties_test.py @@ -77,7 +77,6 @@ 'ScaleMatvecTriL', 'Shift', 'ShiftedGompertzCDF', - 'ScaleTriL', 'Sigmoid', 'Sinh', 'SinhArcsinh', diff --git a/tensorflow_probability/python/bijectors/fill_scale_tril.py b/tensorflow_probability/python/bijectors/fill_scale_tril.py index 3d815ee687..0718e11c80 100644 --- a/tensorflow_probability/python/bijectors/fill_scale_tril.py +++ b/tensorflow_probability/python/bijectors/fill_scale_tril.py @@ -26,12 +26,10 @@ from tensorflow_probability.python.bijectors import transform_diagonal from tensorflow_probability.python.internal import dtype_util from tensorflow_probability.python.internal import tensor_util -from tensorflow.python.util import deprecation # pylint: disable=g-direct-tensorflow-import __all__ = [ 'FillScaleTriL', - 'ScaleTriL', ] @@ -127,56 +125,3 @@ def __init__(self, validate_args=validate_args, parameters=parameters, name=name) - - -class ScaleTriL(chain.Chain): - """DEPRECATED. Please use `tfp.bijectors.FillScaleTriL`.""" - - @deprecation.deprecated( - '2020-01-01', - '`ScaleTriL` has been deprecated and renamed `FillScaleTriL`; please use ' - 'that symbol instead.') - def __init__(self, - diag_bijector=None, - diag_shift=1e-5, - validate_args=False, - name='scale_tril'): - """Instantiates the `ScaleTriL` bijector. - - Args: - diag_bijector: `Bijector` instance, used to transform the output diagonal - to be positive. - Default value: `None` (i.e., `tfb.Softplus()`). - diag_shift: Float value broadcastable and added to all diagonal entries - after applying the `diag_bijector`. Setting a positive - value forces the output diagonal entries to be positive, but - prevents inverting the transformation for matrices with - diagonal entries less than this value. - Default value: `1e-5`. - validate_args: Python `bool` indicating whether arguments should be - checked for correctness. - Default value: `False` (i.e., arguments are not validated). - name: Python `str` name given to ops managed by this object. - Default value: `scale_tril`. - """ - parameters = dict(locals()) - with tf.name_scope(name) as name: - if diag_bijector is None: - diag_bijector = softplus.Softplus(validate_args=validate_args) - - if diag_shift is not None: - dtype = dtype_util.common_dtype([diag_bijector, diag_shift], tf.float32) - diag_shift = tensor_util.convert_nonref_to_tensor(diag_shift, - name='diag_shift', - dtype=dtype) - diag_bijector = chain.Chain([ - shift.Shift(diag_shift), - diag_bijector - ]) - - super(ScaleTriL, self).__init__( - [transform_diagonal.TransformDiagonal(diag_bijector=diag_bijector), - fill_triangular.FillTriangular()], - validate_args=validate_args, - parameters=parameters, - name=name) diff --git a/tensorflow_probability/python/bijectors/glow.py b/tensorflow_probability/python/bijectors/glow.py index 9deaf17126..5367160208 100644 --- a/tensorflow_probability/python/bijectors/glow.py +++ b/tensorflow_probability/python/bijectors/glow.py @@ -581,7 +581,7 @@ def bijector_fn(inputs, ignored_input): output = this_shift(this_scale) elif target_shape[-1] == output_shape[-1]: - output = shift.Shift(possible_output[..., c:]) + output = shift.Shift(possible_output[..., :c]) else: raise ValueError('Shape inconsistent with input. Expected shape' '{0} or {1} but tensor was shape {2}'.format( @@ -676,7 +676,7 @@ def bijector_fn(inputs, ignored_input): output = this_shift(this_scale) elif input_shape[-1] == output_shape[-1]: - output = shift.Shift(possible_output[..., c:]) + output = shift.Shift(possible_output[..., :c]) else: raise ValueError('Shape inconsistent with input. Expected shape' '{0} or {1} but tensor was shape {2}'.format( @@ -860,4 +860,3 @@ def __init__(self, input_shape, output_chan, kernel_shape=3): super(GlowDefaultExitNetwork, self).__init__([ tfkl.Input(input_shape), conv(this_nchan, kernel_shape)]) - diff --git a/tensorflow_probability/python/bijectors/glow_test.py b/tensorflow_probability/python/bijectors/glow_test.py index f3e6a5d238..f2bcd26941 100644 --- a/tensorflow_probability/python/bijectors/glow_test.py +++ b/tensorflow_probability/python/bijectors/glow_test.py @@ -351,5 +351,34 @@ def float64_exit(input_shape, output_chan): self.assertAllFinite(self.evaluate(z)) self.assertAllFinite(self.evaluate(zf64)) + def testBijectorFn(self): + """Test if the bijector function works for additive coupling.""" + ims = self._make_images() + def shiftfn(input_shape): + input_nchan = input_shape[-1] + return tf.keras.Sequential([ + tf.keras.layers.Input(input_shape), + tf.keras.layers.Conv2D( + input_nchan, 3, padding='same')]) + + def shiftexitfn(input_shape, output_chan): + return tf.keras.Sequential([ + tf.keras.layers.Input(input_shape), + tf.keras.layers.Conv2D( + output_chan, 3, padding='same')]) + + shiftonlyglow = tfb.Glow( + output_shape=self.output_shape, + num_glow_blocks=2, + num_steps_per_block=1, + coupling_bijector_fn=shiftfn, + exit_bijector_fn=shiftexitfn, + grab_after_block=[0.5, 0.5] + ) + z = shiftonlyglow.inverse(ims) + self.evaluate([v.initializer for v in shiftonlyglow.variables]) + self.assertAllFinite(self.evaluate(z)) + + if __name__ == '__main__': tf.test.main() diff --git a/tensorflow_probability/python/bijectors/hypothesis_testlib.py b/tensorflow_probability/python/bijectors/hypothesis_testlib.py index c96f258248..1a8f8160dc 100644 --- a/tensorflow_probability/python/bijectors/hypothesis_testlib.py +++ b/tensorflow_probability/python/bijectors/hypothesis_testlib.py @@ -227,9 +227,6 @@ def bijector_supports(): 'ScaleMatvecTriL': BijectorSupport(Support.VECTOR_UNCONSTRAINED, Support.VECTOR_UNCONSTRAINED), - 'ScaleTriL': - BijectorSupport(Support.VECTOR_SIZE_TRIANGULAR, - Support.MATRIX_LOWER_TRIL_POSITIVE_DEFINITE), 'Shift': BijectorSupport(Support.SCALAR_UNCONSTRAINED, Support.SCALAR_UNCONSTRAINED), diff --git a/tensorflow_probability/python/bijectors/rational_quadratic_spline.py b/tensorflow_probability/python/bijectors/rational_quadratic_spline.py index 863762c8bc..e7e1688dde 100644 --- a/tensorflow_probability/python/bijectors/rational_quadratic_spline.py +++ b/tensorflow_probability/python/bijectors/rational_quadratic_spline.py @@ -82,29 +82,33 @@ def __init__(self, nbins=32): self._bin_heights = None self._knot_slopes = None - def _bin_positions(self, x): - x = tf.reshape(x, [-1, self._nbins]) - return tf.math.softmax(x, axis=-1) * (2 - self._nbins * 1e-2) + 1e-2 - - def _slopes(self, x): - x = tf.reshape(x, [-1, self._nbins - 1]) - return tf.math.softplus(x) + 1e-2 - def __call__(self, x, nunits): if not self._built: + def _bin_positions(x): + out_shape = tf.concat((tf.shape(x)[:-1], (nunits, self._nbins)), 0) + x = tf.reshape(x, out_shape) + return tf.math.softmax(x, axis=-1) * (2 - self._nbins * 1e-2) + 1e-2 + + def _slopes(x): + out_shape = tf.concat(( + tf.shape(x)[:-1], (nunits, self._nbins - 1)), 0) + x = tf.reshape(x, out_shape) + return tf.math.softplus(x) + 1e-2 + self._bin_widths = tf.keras.layers.Dense( - nunits * self._nbins, activation=self._bin_positions, name='w') + nunits * self._nbins, activation=_bin_positions, name='w') self._bin_heights = tf.keras.layers.Dense( - nunits * self._nbins, activation=self._bin_positions, name='h') + nunits * self._nbins, activation=_bin_positions, name='h') self._knot_slopes = tf.keras.layers.Dense( - nunits * (self._nbins - 1), activation=self._slopes, name='s') + nunits * (self._nbins - 1), activation=_slopes, name='s') self._built = True + return tfb.RationalQuadraticSpline( bin_widths=self._bin_widths(x), bin_heights=self._bin_heights(x), knot_slopes=self._knot_slopes(x)) - xs = np.random.randn(1, 15).astype(np.float32) # Keras won't Dense(.)(vec). + xs = np.random.randn(3, 15).astype(np.float32) # Keras won't Dense(.)(vec). splines = [SplineParams() for _ in range(nsplits)] def spline_flow(): diff --git a/tensorflow_probability/python/bijectors/rational_quadratic_spline_test.py b/tensorflow_probability/python/bijectors/rational_quadratic_spline_test.py index 1632253a8f..1e678e4ed2 100644 --- a/tensorflow_probability/python/bijectors/rational_quadratic_spline_test.py +++ b/tensorflow_probability/python/bijectors/rational_quadratic_spline_test.py @@ -84,29 +84,33 @@ def __init__(self, nbins=32): self._bin_heights = None self._knot_slopes = None - def _bin_positions(self, x): - x = tf.reshape(x, [-1, self._nbins]) - return tf.math.softmax(x, axis=-1) * (2 - self._nbins * 1e-2) + 1e-2 - - def _slopes(self, x): - x = tf.reshape(x, [-1, self._nbins - 1]) - return tf.math.softplus(x) + 1e-2 - def __call__(self, x, nunits): if not self._built: + def _bin_positions(x): + out_shape = tf.concat((tf.shape(x)[:-1], (nunits, self._nbins)), 0) + x = tf.reshape(x, out_shape) + return tf.math.softmax(x, axis=-1) * (2 - self._nbins * 1e-2) + 1e-2 + + def _slopes(x): + out_shape = tf.concat(( + tf.shape(x)[:-1], (nunits, self._nbins - 1)), 0) + x = tf.reshape(x, out_shape) + return tf.math.softplus(x) + 1e-2 + self._bin_widths = tf.keras.layers.Dense( - nunits * self._nbins, activation=self._bin_positions, name='w') + nunits * self._nbins, activation=_bin_positions, name='w') self._bin_heights = tf.keras.layers.Dense( - nunits * self._nbins, activation=self._bin_positions, name='h') + nunits * self._nbins, activation=_bin_positions, name='h') self._knot_slopes = tf.keras.layers.Dense( - nunits * (self._nbins - 1), activation=self._slopes, name='s') + nunits * (self._nbins - 1), activation=_slopes, name='s') self._built = True + return tfb.RationalQuadraticSpline( bin_widths=self._bin_widths(x), bin_heights=self._bin_heights(x), knot_slopes=self._knot_slopes(x)) - xs = np.random.randn(1, 15).astype(np.float32) # Keras won't Dense(.)(vec). + xs = np.random.randn(3, 15).astype(np.float32) # Keras won't Dense(.)(vec). splines = [SplineParams() for _ in range(nsplits)] def spline_flow(): diff --git a/tensorflow_probability/python/bijectors/real_nvp.py b/tensorflow_probability/python/bijectors/real_nvp.py index c76074d87e..7645195606 100644 --- a/tensorflow_probability/python/bijectors/real_nvp.py +++ b/tensorflow_probability/python/bijectors/real_nvp.py @@ -107,7 +107,7 @@ class RealNVP(bijector_lib.Bijector): x = nvp.sample() nvp.log_prob(x) - nvp.log_prob(0.) + nvp.log_prob([0.0, 0.0, 0.0]) ``` For more examples, see [Jang (2018)][3]. diff --git a/tensorflow_probability/python/distributions/autoregressive.py b/tensorflow_probability/python/distributions/autoregressive.py index fea93af816..76ad728fe5 100644 --- a/tensorflow_probability/python/distributions/autoregressive.py +++ b/tensorflow_probability/python/distributions/autoregressive.py @@ -90,7 +90,7 @@ class Autoregressive(distribution.Distribution): def _normal_fn(event_size): n = event_size * (event_size + 1) // 2 p = tf.Variable(tfd.Normal(loc=0., scale=1.).sample(n)) - affine = tfb.ScaleTriL(tfp.math.fill_triangular(0.25 * p)) + affine = tfb.FillScaleTriL(tfp.math.fill_triangular(0.25 * p)) def _fn(samples): scale = tf.exp(affine(samples)) return tfd.Independent( diff --git a/tensorflow_probability/python/distributions/continuous_bernoulli.py b/tensorflow_probability/python/distributions/continuous_bernoulli.py index 1c8107559d..e7da25101a 100644 --- a/tensorflow_probability/python/distributions/continuous_bernoulli.py +++ b/tensorflow_probability/python/distributions/continuous_bernoulli.py @@ -21,6 +21,7 @@ # Dependency imports import numpy as np import tensorflow.compat.v2 as tf +from tensorflow_probability.python import math as tfp_math from tensorflow_probability.python.bijectors import sigmoid as sigmoid_bijector from tensorflow_probability.python.distributions import distribution from tensorflow_probability.python.distributions import kullback_leibler @@ -330,11 +331,31 @@ def _quantile(self, p, probs=None): if probs is None: probs = self._probs_parameter_no_checks() cut_probs = self._cut_probs(probs) + cut_logits = tf.math.log(cut_probs) - tf.math.log1p(-cut_probs) + logp = tf.math.log(p) + # The expression for the quantile function is: + # log(1 + (e^s - 1) * p) / s, where s is `cut_logits`. When s is large, + # the e^s sub-term becomes increasingly ill-conditioned. However, + # since the numerator tends to s, we can reformulate the s > 0 case + # as a offset from 1, which is more accurate. Coincidentally, + # this eliminates a ratio of infinities problem when `s == +inf`. + result = tf.where( + cut_logits > 0., + 1. + tfp_math.log_add_exp( + logp + tfp_math.log1mexp(cut_logits), -cut_logits) / cut_logits, + tf.math.log1p(tf.math.expm1(cut_logits) * p) / cut_logits) + + # Finally, handle the case where `cut_logits` and `p` are on the boundary, + # as the above expressions can result in ratio of `infs` in that case as + # well. + result = tf.where( + (tf.math.equal(cut_probs, 0.) & tf.math.equal(logp, 0.)) | + (tf.math.equal(cut_probs, 1.) & tf.math.is_inf(logp)), + tf.ones_like(cut_probs), + result) + return tf.where( - (probs < self._lims[0]) | (probs > self._lims[1]), - (tf.math.log1p(-cut_probs + p * (2.0 * cut_probs - 1.0)) - - tf.math.log1p(-cut_probs)) - / (tf.math.log(cut_probs) - tf.math.log1p(-cut_probs)), p) + (probs < self._lims[0]) | (probs > self._lims[1]), result, p) def _mode(self): """Returns `1` if `prob > 0.5` and `0` otherwise.""" diff --git a/tensorflow_probability/python/distributions/continuous_bernoulli_test.py b/tensorflow_probability/python/distributions/continuous_bernoulli_test.py index e377300ec7..6acebb9517 100644 --- a/tensorflow_probability/python/distributions/continuous_bernoulli_test.py +++ b/tensorflow_probability/python/distributions/continuous_bernoulli_test.py @@ -431,6 +431,21 @@ def testQuantile(self): [quantile(0.1, 0.2), quantile(0.3, 0.2), quantile(0.9, 0.2)], dtype=np.float32)) + def testQuantileAtExtremesIsNotNaN(self): + prob = [[0.], [1.]] + dist = tfd.ContinuousBernoulli(probs=prob, validate_args=True) + self.assertAllNotNan( + self.evaluate( + dist.quantile(np.array( + [[0., 0.1, 0.3, 0.9, 1.]], dtype=np.float32)))) + + def testSampleAtExtremesIsNotNaN(self): + prob = [[0.], [1.]] + dist = tfd.ContinuousBernoulli(probs=prob, validate_args=True) + self.assertAllNotNan( + self.evaluate( + dist.sample(int(1e2), seed=test_util.test_seed()))) + def testContinuousBernoulliContinuousBernoulliKL(self): batch_size = 6 a_p = np.array([0.6] * batch_size, dtype=np.float32) diff --git a/tensorflow_probability/python/distributions/distribution_properties_test.py b/tensorflow_probability/python/distributions/distribution_properties_test.py index 84289baa35..282cd9a86f 100644 --- a/tensorflow_probability/python/distributions/distribution_properties_test.py +++ b/tensorflow_probability/python/distributions/distribution_properties_test.py @@ -289,6 +289,7 @@ def testCanConstructAndSampleDistribution(self, data): non_trainable_tensor_params = ( 'atol', 'rtol', + 'eigenvectors', # TODO(b/171872834): DeterminantalPointProcess 'total_count', 'num_samples', 'df', # Can't represent constraint that Wishart df > dimension. diff --git a/tensorflow_probability/python/distributions/half_normal.py b/tensorflow_probability/python/distributions/half_normal.py index ee7541cbe5..fa63739e0e 100644 --- a/tensorflow_probability/python/distributions/half_normal.py +++ b/tensorflow_probability/python/distributions/half_normal.py @@ -18,6 +18,8 @@ from __future__ import division from __future__ import print_function +import math + # Dependency imports import numpy as np import tensorflow.compat.v2 as tf @@ -154,7 +156,7 @@ def _sample_n(self, n, seed=None): def _prob(self, x): scale = tf.convert_to_tensor(self.scale) - coeff = np.sqrt(2) / scale / np.sqrt(np.pi) + coeff = math.sqrt(2) / scale / math.sqrt(np.pi) pdf = coeff * tf.exp(-0.5 * (x / scale)**2) return pdf * tf.cast(x >= 0, self.dtype) @@ -173,10 +175,10 @@ def _entropy(self): return 0.5 * tf.math.log(np.pi * self.scale**2.0 / 2.0) + 0.5 def _mean(self): - return self.scale * np.sqrt(2.0) / np.sqrt(np.pi) + return self.scale * math.sqrt(2.0) / math.sqrt(np.pi) def _quantile(self, p): - return np.sqrt(2.0) * self.scale * tf.math.erfinv(p) + return math.sqrt(2.0) * self.scale * tf.math.erfinv(p) def _mode(self): return tf.zeros(self.batch_shape_tensor()) diff --git a/tensorflow_probability/python/distributions/half_student_t_test.py b/tensorflow_probability/python/distributions/half_student_t_test.py index 7eb6f827f2..e8ac309c6f 100644 --- a/tensorflow_probability/python/distributions/half_student_t_test.py +++ b/tensorflow_probability/python/distributions/half_student_t_test.py @@ -134,8 +134,10 @@ def testLogPDFMultidimensional(self): np.log(2.) + sp_stats.t.logpdf(t, df_v, loc=loc_v, scale=sigma_v)) expected_pdf = ( 2. * sp_stats.t.pdf(t, df_v, loc=loc_v, scale=sigma_v)) - self.assertAllClose(expected_log_pdf, log_pdf_values) - self.assertAllClose(np.log(expected_pdf), log_pdf_values) + self.assertAllClose(expected_log_pdf, log_pdf_values, + atol=0, rtol=1e-5) # relaxed tol for fp32 in JAX + self.assertAllClose(np.log(expected_pdf), log_pdf_values, + atol=0, rtol=1e-5) # relaxed tol for fp32 in JAX self.assertAllClose(expected_pdf, pdf_values) self.assertAllClose(np.exp(expected_log_pdf), pdf_values) diff --git a/tensorflow_probability/python/distributions/ordered_logistic.py b/tensorflow_probability/python/distributions/ordered_logistic.py index 5e56fec6d5..d3ebd5dbd7 100644 --- a/tensorflow_probability/python/distributions/ordered_logistic.py +++ b/tensorflow_probability/python/distributions/ordered_logistic.py @@ -56,7 +56,7 @@ def _broadcast_cat_event_and_params(event, params, base_dtype): if not shape_known_statically or params.shape[:-1] != event.shape: params = params * tf.ones_like(event[..., tf.newaxis], dtype=params.dtype) - params_shape = tf.shape(params)[:-1] + params_shape = ps.shape(params)[:-1] event = event * tf.ones(params_shape, dtype=event.dtype) if tensorshape_util.rank(params.shape) is not None: tensorshape_util.set_shape(event, params.shape[:-1]) diff --git a/tensorflow_probability/python/distributions/platform_compatibility_test.py b/tensorflow_probability/python/distributions/platform_compatibility_test.py index b38794cc5a..5ecc7dbfee 100644 --- a/tensorflow_probability/python/distributions/platform_compatibility_test.py +++ b/tensorflow_probability/python/distributions/platform_compatibility_test.py @@ -119,6 +119,7 @@ # TODO(b/142827327): Bring tolerance down to 0 for all distributions. VECTORIZED_LOGPROB_ATOL = collections.defaultdict(lambda: 1e-6) VECTORIZED_LOGPROB_ATOL.update({ + 'Beta': 1e-5, 'BetaBinomial': 1e-5, 'CholeskyLKJ': 1e-4, 'LKJ': 1e-3, @@ -395,7 +396,8 @@ def testCompositeTensor(self, dist_name, data): dist = data.draw( dhps.distributions( dist_name=dist_name, enable_vars=False, validate_args=False)) - self._test_sample_and_log_prob(dist_name, dist) + with tfp_hps.no_tf_rank_errors(): + self._test_sample_and_log_prob(dist_name, dist) @test_util.test_graph_mode_only @@ -472,7 +474,8 @@ def testVmap(self, dist_name, data): dist = data.draw(dhps.distributions( dist_name=dist_name, enable_vars=False, validate_args=False)) # TODO(b/142826246): Enable validate_args. - self._test_vectorization(dist_name, dist) + with tfp_hps.no_tf_rank_errors(): + self._test_vectorization(dist_name, dist) if __name__ == '__main__': diff --git a/tensorflow_probability/python/distributions/skellam_test.py b/tensorflow_probability/python/distributions/skellam_test.py index 0142b67a5a..e70a5fa766 100644 --- a/tensorflow_probability/python/distributions/skellam_test.py +++ b/tensorflow_probability/python/distributions/skellam_test.py @@ -103,12 +103,12 @@ def testSkellamLogPmfGradient(self): err = self.compute_max_gradient_error( lambda lam: self._make_skellam( # pylint:disable=g-long-lambda rate1=lam, rate2=rate2).log_prob(x), [rate1]) - self.assertLess(err, 3e-4) + self.assertLess(err, 5e-4) err = self.compute_max_gradient_error( lambda lam: self._make_skellam( # pylint:disable=g-long-lambda rate1=rate1, rate2=lam).log_prob(x), [rate2]) - self.assertLess(err, 3e-4) + self.assertLess(err, 5e-4) @test_util.numpy_disable_gradient_test @test_util.jax_disable_test_missing_functionality( diff --git a/tensorflow_probability/python/experimental/BUILD b/tensorflow_probability/python/experimental/BUILD index 6fca9f8b85..42f2e5e98a 100644 --- a/tensorflow_probability/python/experimental/BUILD +++ b/tensorflow_probability/python/experimental/BUILD @@ -52,6 +52,7 @@ multi_substrate_py_library( deps = [ ":composite_tensor", "//tensorflow_probability/python/experimental/auto_batching", + "//tensorflow_probability/python/experimental/bijectors", "//tensorflow_probability/python/experimental/distribute", "//tensorflow_probability/python/experimental/distributions", "//tensorflow_probability/python/experimental/lazybones", diff --git a/tensorflow_probability/python/experimental/__init__.py b/tensorflow_probability/python/experimental/__init__.py index 1cdce4ad26..85db29fa9a 100644 --- a/tensorflow_probability/python/experimental/__init__.py +++ b/tensorflow_probability/python/experimental/__init__.py @@ -32,6 +32,7 @@ from __future__ import print_function from tensorflow_probability.python.experimental import auto_batching +from tensorflow_probability.python.experimental import bijectors from tensorflow_probability.python.experimental import distribute from tensorflow_probability.python.experimental import distributions from tensorflow_probability.python.experimental import lazybones @@ -55,6 +56,7 @@ 'auto_batching', 'as_composite', 'auto_composite_tensor', + 'bijectors', 'distribute', 'distributions', 'lazybones', diff --git a/tensorflow_probability/python/experimental/bijectors/BUILD b/tensorflow_probability/python/experimental/bijectors/BUILD new file mode 100644 index 0000000000..96ed131f8a --- /dev/null +++ b/tensorflow_probability/python/experimental/bijectors/BUILD @@ -0,0 +1,69 @@ +# Copyright 2020 The TensorFlow Probability Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +# tfp.experimental distributions library. + +load( + "//tensorflow_probability/python:build_defs.bzl", + "multi_substrate_py_library", + "multi_substrate_py_test", +) + +package( + default_visibility = [ + "//tensorflow_probability:__subpackages__", + ], +) + +licenses(["notice"]) + +exports_files(["LICENSE"]) + +multi_substrate_py_library( + name = "bijectors", + srcs = ["__init__.py"], + srcs_version = "PY3", + deps = [ + ":scalar_function_with_inferred_inverse", + ], +) + +multi_substrate_py_library( + name = "scalar_function_with_inferred_inverse", + srcs = ["scalar_function_with_inferred_inverse.py"], + srcs_version = "PY3", + deps = [ + # numpy dep, + # tensorflow dep, + "//tensorflow_probability/python/bijectors:bijector", + "//tensorflow_probability/python/internal:custom_gradient", + "//tensorflow_probability/python/internal:prefer_static", + "//tensorflow_probability/python/internal:tensor_util", + "//tensorflow_probability/python/math", + ], +) + +multi_substrate_py_test( + name = "scalar_function_with_inferred_inverse_test", + size = "medium", + srcs = ["scalar_function_with_inferred_inverse_test.py"], + jax_size = "medium", + srcs_version = "PY3", + deps = [ + # numpy dep, + # tensorflow dep, + "//tensorflow_probability", + "//tensorflow_probability/python/internal:test_util", + ], +) diff --git a/tensorflow_probability/python/experimental/bijectors/__init__.py b/tensorflow_probability/python/experimental/bijectors/__init__.py new file mode 100644 index 0000000000..81650620f0 --- /dev/null +++ b/tensorflow_probability/python/experimental/bijectors/__init__.py @@ -0,0 +1,21 @@ +# Copyright 2020 The TensorFlow Probability Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""TensorFlow Probability experimental bijectors package.""" + +from tensorflow_probability.python.experimental.bijectors.scalar_function_with_inferred_inverse import ScalarFunctionWithInferredInverse + +__all__ = [ + 'ScalarFunctionWithInferredInverse' +] diff --git a/tensorflow_probability/python/experimental/bijectors/scalar_function_with_inferred_inverse.py b/tensorflow_probability/python/experimental/bijectors/scalar_function_with_inferred_inverse.py new file mode 100644 index 0000000000..de1b11bc39 --- /dev/null +++ b/tensorflow_probability/python/experimental/bijectors/scalar_function_with_inferred_inverse.py @@ -0,0 +1,144 @@ +# Copyright 2020 The TensorFlow Probability Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Bijector to associate a numeric inverse with any invertible function.""" + +import numpy as np +import tensorflow.compat.v2 as tf + +from tensorflow_probability.python import math as tfp_math +from tensorflow_probability.python.bijectors import bijector +from tensorflow_probability.python.internal import custom_gradient as tfp_custom_gradient + +__all__ = ['ScalarFunctionWithInferredInverse'] + + +class ScalarFunctionWithInferredInverse(bijector.Bijector): + """Bijector to associate a numeric inverse with any invertible function.""" + + def __init__(self, + fn, + domain_constraint_fn=None, + root_search_fn=tfp_math.secant_root, + max_iterations=50, + require_convergence=True, + validate_args=False, + name='scalar_function_with_inferred_inverse'): + """Initialize the ScalarFunctionWithInferredInverse bijector. + + Args: + fn: Python `callable` taking a single Tensor argument `x`, and returning a + Tensor `y` of the same shape. This is assumed to be an invertible + (continuous and monotonic) function applied elementwise to `x`. + domain_constraint_fn: optional Python `callable` that returns values + within the domain of `fn`, used to constrain the root search. For any + real-valued input `r`, the value `x = domain_constraint_fn(r)` should be + a valid input to `fn`. + Default value: `None`. + root_search_fn: Optional Python `callable` used to search for roots of an + objective function. This should have signature + `root_search_fn(objective_fn, initial_x, max_iterations=None)` + and return a tuple containing three `Tensor`s + `(estimated_root, objective_at_estimated_root, num_iterations)`. + Default value: `tfp.math.secant_root`. + max_iterations: Optional Python integer maximum number of iterations to + run the root search algorithm. + Default value: `50`. + require_convergence: Optional Python `bool` indicating whether to return + inverse values when the root-finding algorithm may not have + converged. If `True`, such values are replaced by `NaN`. + Default value: `True`. + validate_args: Python `bool` indicating whether arguments should be + checked for correctness. + name: Python `str` name given to ops managed by this object. + Default value: `scalar_function_with_inferred_inverse`. + """ + parameters = locals() + with tf.name_scope(name): + if domain_constraint_fn is None: + domain_constraint_fn = lambda x: x + self._fn = fn + self._root_search_fn = root_search_fn + self._domain_constraint_fn = domain_constraint_fn + self._require_convergence = require_convergence + self._max_iterations = max_iterations + + self._inverse = self._wrap_inverse_with_implicit_gradient() + + super(ScalarFunctionWithInferredInverse, self).__init__( + parameters=parameters, + forward_min_event_ndims=0, + inverse_min_event_ndims=0, + validate_args=validate_args, + name=name) + + @property + def domain_constraint_fn(self): + return self._domain_constraint_fn + + @property + def fn(self): + return self._fn + + @property + def max_iterations(self): + return self._max_iterations + + @property + def require_convergence(self): + return self._require_convergence + + @property + def root_search_fn(self): + return self._root_search_fn + + def _forward(self, x): + return self.fn(x) # pylint: disable=not-callable + + def _inverse_no_gradient(self, y): + # Search for a root in unconstrained space. + unconstrained_root, _, num_iterations = self.root_search_fn( + lambda ux: (self.fn(self.domain_constraint_fn(ux)) - y), # pylint: disable=not-callable + tf.ones_like(y), + max_iterations=self.max_iterations) + x = self.domain_constraint_fn(unconstrained_root) # pylint: disable=not-callable + if self.require_convergence: + x = tf.where( + num_iterations < self.max_iterations, + x, + tf.cast(np.nan, x.dtype)) + return x + + def _wrap_inverse_with_implicit_gradient(self): + """Wraps the inverse to provide implicit reparameterization gradients.""" + + def _vjp_fwd(y): + x = self._inverse_no_gradient(y) + return x, x # Keep `x` as an auxiliary value for the backwards pass. + + # By the inverse function theorem, the derivative of an + # inverse function is the reciprocal of the forward derivative. This has + # been popularized in machine learning by [1]. + # [1] Michael Figurnov, Shakir Mohamed, Andriy Mnih (2018). Implicit + # Reparameterization Gradients. https://arxiv.org/abs/1805.08498. + def _vjp_bwd(x, grad_x): + _, grads = tfp_math.value_and_gradient(self.fn, x) + return (grad_x / grads,) + + @tfp_custom_gradient.custom_gradient( + vjp_fwd=_vjp_fwd, + vjp_bwd=_vjp_bwd) + def _inverse_with_gradient(y): + return self._inverse_no_gradient(y) + return _inverse_with_gradient diff --git a/tensorflow_probability/python/experimental/bijectors/scalar_function_with_inferred_inverse_test.py b/tensorflow_probability/python/experimental/bijectors/scalar_function_with_inferred_inverse_test.py new file mode 100644 index 0000000000..c826ac3608 --- /dev/null +++ b/tensorflow_probability/python/experimental/bijectors/scalar_function_with_inferred_inverse_test.py @@ -0,0 +1,82 @@ +# Copyright 2020 The TensorFlow Probability Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Tests for ScalarFunctionWithInferredInverse bijector.""" + + +import tensorflow.compat.v2 as tf +import tensorflow_probability as tfp + +from tensorflow_probability.python.internal import samplers +from tensorflow_probability.python.internal import test_util + +tfb = tfp.bijectors +tfd = tfp.distributions +tfbe = tfp.experimental.bijectors + + +@test_util.test_all_tf_execution_regimes +class ScalarFunctionWithInferredInverseTests(test_util.TestCase): + + @test_util.numpy_disable_gradient_test + def test_student_t_cdf(self): + dist = tfd.StudentT(df=7, loc=3., scale=2.) + xs = self.evaluate(dist.sample([100], seed=test_util.test_seed())) + + bij = tfbe.ScalarFunctionWithInferredInverse(dist.cdf) + ys = bij.forward(xs) + xxs = bij.inverse(ys) + self.assertAllClose(xs, xxs) + + @test_util.numpy_disable_gradient_test + def test_normal_cdf_gradients(self): + dist = tfd.Normal(loc=3., scale=2.) + bij = tfbe.ScalarFunctionWithInferredInverse(dist.cdf) + + ys = self.evaluate(samplers.uniform([100], seed=test_util.test_seed())) + xs_true, grad_true = tfp.math.value_and_gradient(dist.quantile, ys) + xs_numeric, grad_numeric = tfp.math.value_and_gradient(bij.inverse, ys) + self.assertAllClose(xs_true, xs_numeric, atol=1e-4) + self.assertAllClose(grad_true, grad_numeric, rtol=1e-4) + + @test_util.numpy_disable_gradient_test + def test_domain_constraint_fn(self): + dist = tfd.Beta(concentration0=5., concentration1=3.) + xs = self.evaluate(dist.sample([100], seed=test_util.test_seed())) + + bij = tfbe.ScalarFunctionWithInferredInverse( + dist.cdf, + domain_constraint_fn=dist.experimental_default_event_space_bijector()) + self.assertAllClose(xs, bij.inverse(bij.forward(xs))) + + @test_util.numpy_disable_gradient_test + def test_transformed_distribution_log_prob(self): + uniform = tfd.Uniform(low=0, high=1.) + normal = tfd.Normal(loc=0., scale=1.) + xs = self.evaluate(normal.sample(100, seed=test_util.test_seed())) + + # Define a normal distribution using inverse-CDF sampling. Computing + # log probs under this definition requires inverting the quantile function, + # i.e., numerically approximating `normal.cdf`. + inverse_transform_normal = tfbe.ScalarFunctionWithInferredInverse( + fn=normal.quantile, + domain_constraint_fn=uniform.experimental_default_event_space_bijector() + )(uniform) + self.assertAllClose(normal.log_prob(xs), + inverse_transform_normal.log_prob(xs), + atol=1e-4) + + +if __name__ == '__main__': + tf.test.main() diff --git a/tensorflow_probability/python/experimental/mcmc/BUILD b/tensorflow_probability/python/experimental/mcmc/BUILD index 4b132cceac..7e223ff3b4 100644 --- a/tensorflow_probability/python/experimental/mcmc/BUILD +++ b/tensorflow_probability/python/experimental/mcmc/BUILD @@ -561,6 +561,7 @@ py_test( ":sample", # tensorflow dep, "//tensorflow_probability", + "//tensorflow_probability/python/experimental/mcmc/internal:test_fixtures", "//tensorflow_probability/python/internal:test_util", ], ) @@ -713,7 +714,6 @@ py_library( py_test( name = "potential_scale_reduction_reducer_test", size = "small", - timeout = "moderate", srcs = ["potential_scale_reduction_reducer_test.py"], python_version = "PY3", srcs_version = "PY3", @@ -721,6 +721,7 @@ py_test( # numpy dep, # tensorflow dep, "//tensorflow_probability", + "//tensorflow_probability/python/experimental/mcmc/internal:test_fixtures", "//tensorflow_probability/python/internal:test_util", ], ) diff --git a/tensorflow_probability/python/experimental/mcmc/covariance_reducer.py b/tensorflow_probability/python/experimental/mcmc/covariance_reducer.py index f266f3b7f2..7d4970176a 100644 --- a/tensorflow_probability/python/experimental/mcmc/covariance_reducer.py +++ b/tensorflow_probability/python/experimental/mcmc/covariance_reducer.py @@ -148,7 +148,7 @@ def __init__( name=name or 'covariance_reducer' ) - def initialize(self, initial_chain_state, initial_kernel_results): + def initialize(self, initial_chain_state, initial_kernel_results=None): """Initializes a `CovarianceReducerState` using previously defined metadata. For calculation purposes, the `initial_chain_state` does not count as a @@ -172,14 +172,13 @@ def initialize(self, initial_chain_state, initial_kernel_results): initial_chain_state = tf.nest.map_structure( tf.convert_to_tensor, initial_chain_state) - initial_kernel_results = tf.nest.map_structure( - tf.convert_to_tensor, - initial_kernel_results, - ) + if initial_kernel_results is not None: + initial_kernel_results = tf.nest.map_structure( + tf.convert_to_tensor, + initial_kernel_results) initial_fn_result = tf.nest.map_structure( lambda fn: fn(initial_chain_state, initial_kernel_results), - self.transform_fn, - ) + self.transform_fn) event_ndims = _canonicalize_event_ndims( initial_fn_result, self.event_ndims) def init(tensor, event_ndims): @@ -193,7 +192,7 @@ def one_step( self, new_chain_state, current_reducer_state, - previous_kernel_results, + previous_kernel_results=None, axis=None): """Update the `current_reducer_state` with a new chain state. @@ -229,13 +228,13 @@ def one_step( new_chain_state = tf.nest.map_structure( tf.convert_to_tensor, new_chain_state) - previous_kernel_results = tf.nest.map_structure( - tf.convert_to_tensor, - previous_kernel_results) + if previous_kernel_results is not None: + previous_kernel_results = tf.nest.map_structure( + tf.convert_to_tensor, + previous_kernel_results) fn_results = tf.nest.map_structure( lambda fn: fn(new_chain_state, previous_kernel_results), - self.transform_fn, - ) + self.transform_fn) if not nest.is_nested(axis): axis = nest_util.broadcast_structure(fn_results, axis) running_covariances = nest.map_structure( @@ -243,8 +242,7 @@ def one_step( current_reducer_state.cov_state, fn_results, axis, - check_types=False, - ) + check_types=False) return CovarianceReducerState(running_covariances) def finalize(self, final_reducer_state): diff --git a/tensorflow_probability/python/experimental/mcmc/covariance_reducer_test.py b/tensorflow_probability/python/experimental/mcmc/covariance_reducer_test.py index a2f1eda767..3ac7e89d2d 100644 --- a/tensorflow_probability/python/experimental/mcmc/covariance_reducer_test.py +++ b/tensorflow_probability/python/experimental/mcmc/covariance_reducer_test.py @@ -26,6 +26,7 @@ import tensorflow.compat.v2 as tf import tensorflow_probability as tfp +from tensorflow_probability.python.experimental.mcmc.internal import test_fixtures from tensorflow_probability.python.internal import test_util @@ -40,116 +41,51 @@ @test_util.test_all_tf_execution_regimes class CovarianceReducersTest(test_util.TestCase): - def test_zero_covariance(self): - cov_reducer = tfp.experimental.mcmc.CovarianceReducer() - fake_kr = FakeKernelResults(0, FakeInnerResults(0)) - state = cov_reducer.initialize(0., fake_kr) - for _ in range(2): - state = cov_reducer.one_step(0., state, fake_kr) - final_num_samples, final_mean, final_cov = self.evaluate([ - state.cov_state.num_samples, - state.cov_state.mean, - cov_reducer.finalize(state)]) - self.assertEqual(2, final_num_samples) - self.assertEqual(0, final_mean) - self.assertEqual(0, final_cov) - def test_random_sanity_check(self): rng = test_util.test_np_rng() x = rng.rand(100) cov_reducer = tfp.experimental.mcmc.CovarianceReducer() - fake_kr = FakeKernelResults(0, FakeInnerResults(0)) - state = cov_reducer.initialize(0., fake_kr) - for sample in x: - state = cov_reducer.one_step(sample, state, fake_kr) - final_mean, final_cov = self.evaluate([ - state.cov_state.mean, - cov_reducer.finalize(state)]) - self.assertNear(np.mean(x), final_mean, err=1e-6) + final_cov = self.evaluate(test_fixtures.reduce(cov_reducer, x)) self.assertNear(np.var(x, ddof=0), final_cov, err=1e-6) - def test_covariance_shape(self): + def test_covariance_shape_and_zero_covariance(self): cov_reducer = tfp.experimental.mcmc.CovarianceReducer(event_ndims=1) - fake_kr = FakeKernelResults(0, FakeInnerResults(0)) - state = cov_reducer.initialize(tf.ones((9, 3)), fake_kr) + state = cov_reducer.initialize(tf.ones((9, 3))) for _ in range(2): - state = cov_reducer.one_step( - tf.zeros((5, 9, 3)), state, fake_kr, axis=0) - final_mean, final_cov = self.evaluate([ + state = cov_reducer.one_step(tf.zeros((5, 9, 3)), state, axis=0) + final_num_samples, final_mean, final_cov = self.evaluate([ + state.cov_state.num_samples, state.cov_state.mean, cov_reducer.finalize(state)]) self.assertEqual((9, 3), final_mean.shape) self.assertEqual((9, 3, 3), final_cov.shape) - - def test_variance_shape(self): - var_reducer = tfp.experimental.mcmc.VarianceReducer() - fake_kr = FakeKernelResults(0, FakeInnerResults(0)) - state = var_reducer.initialize(tf.ones((9, 3)), fake_kr) - for _ in range(2): - state = var_reducer.one_step( - tf.zeros((5, 9, 3)), state, fake_kr, axis=0) - final_mean, final_var = self.evaluate([ - state.cov_state.mean, - var_reducer.finalize(state)]) - self.assertEqual((9, 3), final_mean.shape) - self.assertEqual((9, 3), final_var.shape) + self.assertEqual(10, final_num_samples) + self.assertAllEqual(tf.zeros((9, 3)), final_mean) + self.assertAllEqual(tf.zeros((9, 3, 3)), final_cov) def test_attributes(self): cov_reducer = tfp.experimental.mcmc.CovarianceReducer( event_ndims=1, ddof=1) - fake_kr = FakeKernelResults(0, FakeInnerResults(0)) - state = cov_reducer.initialize( - tf.ones((2, 3), dtype=tf.float64), fake_kr) + state = cov_reducer.initialize(tf.ones((2, 3), dtype=tf.float64)) # check attributes are correct right after initialization self.assertEqual(1, cov_reducer.event_ndims) self.assertEqual(1, cov_reducer.ddof) for _ in range(2): - state = cov_reducer.one_step( - tf.zeros((2, 3), dtype=tf.float64), state, fake_kr) + state = cov_reducer.one_step(tf.zeros((2, 3), dtype=tf.float64), state) # check attributes don't change after stepping through self.assertEqual(1, cov_reducer.event_ndims) self.assertEqual(1, cov_reducer.ddof) - def test_tf_while(self): - cov_reducer = tfp.experimental.mcmc.CovarianceReducer() - fake_kr = FakeKernelResults(0, FakeInnerResults(0)) - state = cov_reducer.initialize(tf.ones((2, 3)), fake_kr) - print(state) - _, state = tf.while_loop( - lambda i, _: i < 100, - lambda i, s: (i + 1, cov_reducer.one_step(tf.ones((2, 3)), s, fake_kr)), - (0., state) - ) - final_cov = self.evaluate(cov_reducer.finalize(state)) - self.assertAllClose(tf.zeros((2, 3, 2, 3)), final_cov, rtol=1e-6) - - def test_nested_chain_state(self): - cov_reducer = tfp.experimental.mcmc.CovarianceReducer(event_ndims=0) - chain_state = ({'one': tf.ones((2, 3)), 'zero': tf.zeros((2, 3))}, - {'two': tf.ones((2, 3)) * 2}) - fake_kr = FakeKernelResults(0, FakeInnerResults(0)) - state = cov_reducer.initialize(chain_state, fake_kr) - _, state = tf.while_loop( - lambda i, _: i < 10, - lambda i, s: (i + 1, cov_reducer.one_step(chain_state, s, fake_kr)), - (0., state) - ) - final_cov = self.evaluate(cov_reducer.finalize(state)) - self.assertAllEqualNested( - final_cov, ({'one': tf.zeros((2, 3)), 'zero': tf.zeros((2, 3))}, - {'two': tf.zeros((2, 3))})) - def test_nested_with_batching_and_chunking(self): cov_reducer = tfp.experimental.mcmc.CovarianceReducer(event_ndims=1) chain_state = ({'one': tf.ones((3, 4)), 'zero': tf.zeros((3, 4))}, {'two': tf.ones((3, 4)) * 2}) - fake_kr = FakeKernelResults(0, FakeInnerResults(0)) - state = cov_reducer.initialize(chain_state, fake_kr) + state = cov_reducer.initialize(chain_state) _, state = tf.while_loop( lambda i, _: i < 10, - lambda i, s: (i + 1, cov_reducer.one_step(chain_state, s, fake_kr, 0)), + lambda i, s: (i + 1, cov_reducer.one_step(chain_state, s, 0)), (0., state) ) final_cov = self.evaluate(cov_reducer.finalize(state)) @@ -157,50 +93,6 @@ def test_nested_with_batching_and_chunking(self): final_cov, ({'one': tf.zeros((3, 4, 4)), 'zero': tf.zeros((3, 4, 4))}, {'two': tf.zeros((3, 4, 4))})) - def test_manual_variance_transform_fn(self): - var_reducer = tfp.experimental.mcmc.VarianceReducer( - transform_fn=lambda _, kr: kr.inner_results.value) - fake_kr = FakeKernelResults(0., FakeInnerResults( - tf.zeros((2, 3)))) - # chain state should be irrelevant - state = var_reducer.initialize(0., fake_kr) - for sample in range(5): - fake_kr = FakeKernelResults( - sample, FakeInnerResults(tf.ones((2, 3)) * sample)) - state = var_reducer.one_step(sample, state, fake_kr) - final_mean, final_var = self.evaluate([ - state.cov_state.mean, - var_reducer.finalize(state)]) - self.assertEqual((2, 3), final_mean.shape) - self.assertAllEqual(np.ones((2, 3)) * 2, final_mean) - self.assertEqual((2, 3), final_var.shape) - self.assertAllEqual(np.ones((2, 3)) * 2, final_var) - - def test_manual_covariance_transform_fn_with_random_states(self): - rng = test_util.test_np_rng() - x = rng.rand(100, 5, 2) - cov_reducer = tfp.experimental.mcmc.CovarianceReducer( - transform_fn=lambda _, kr: kr.inner_results.value) - fake_kr = FakeKernelResults(0., FakeInnerResults( - tf.zeros((5, 2)))) - state = cov_reducer.initialize(0., fake_kr) - for sample in x: - fake_kr = FakeKernelResults(0., FakeInnerResults(sample)) - state = cov_reducer.one_step(0., state, fake_kr) - final_mean, final_cov = self.evaluate([ - state.cov_state.mean, - cov_reducer.finalize(state)]) - - # reshaping to be compatible with a check against numpy - x_reshaped = x.reshape(100, 10) - final_cov_reshaped = tf.reshape(final_cov, (10, 10)) - self.assertEqual((5, 2), final_mean.shape) - self.assertAllClose(np.mean(x, axis=0), final_mean, rtol=1e-5) - self.assertEqual((5, 2, 5, 2), final_cov.shape) - self.assertAllClose(np.cov(x_reshaped.T, ddof=0), - final_cov_reshaped, - rtol=1e-5) - def test_latent_state_with_multiple_transform_fn(self): cov_reducer = tfp.experimental.mcmc.CovarianceReducer( event_ndims=1, @@ -220,10 +112,8 @@ def test_latent_state_with_multiple_transform_fn(self): final_cov = self.evaluate(cov_reducer.finalize(state)) cov_latent = ({'one': tf.zeros((3, 4, 4)), 'zero': tf.zeros((3, 4, 4))}, {'two': tf.zeros((3, 4, 4))}) - self.assertAllEqualNested( - final_cov[0], cov_latent) - self.assertAllEqualNested( - final_cov[1], cov_latent) + self.assertAllEqualNested(final_cov[0], cov_latent) + self.assertAllEqualNested(final_cov[1], cov_latent) def test_transform_fn_with_nested_return(self): cov_red = tfp.experimental.mcmc.CovarianceReducer( diff --git a/tensorflow_probability/python/experimental/mcmc/expectations_reducer.py b/tensorflow_probability/python/experimental/mcmc/expectations_reducer.py index a5c3245683..aab6ad6e45 100644 --- a/tensorflow_probability/python/experimental/mcmc/expectations_reducer.py +++ b/tensorflow_probability/python/experimental/mcmc/expectations_reducer.py @@ -76,7 +76,7 @@ def __init__(self, transform_fn=_get_sample, name=None): name=name or 'expectations_reducer' ) - def initialize(self, initial_chain_state, initial_kernel_results): + def initialize(self, initial_chain_state, initial_kernel_results=None): """Initializes an empty `ExpectationsReducerState`. Args: @@ -94,25 +94,23 @@ def initialize(self, initial_chain_state, initial_kernel_results): initial_chain_state = tf.nest.map_structure( tf.convert_to_tensor, initial_chain_state) - initial_kernel_results = tf.nest.map_structure( - tf.convert_to_tensor, - initial_kernel_results - ) + if initial_kernel_results is not None: + initial_kernel_results = tf.nest.map_structure( + tf.convert_to_tensor, + initial_kernel_results) initial_fn_results = tf.nest.map_structure( lambda fn: fn(initial_chain_state, initial_kernel_results), - self.transform_fn - ) - stream = _prepare_args(initial_fn_results) - return tf.nest.map_structure( - lambda run_mean: ExpectationsReducerState(run_mean.initialize()), - stream - ) + self.transform_fn) + def from_example(res): + return sample_stats.RunningMean.from_shape(res.shape, res.dtype) + return ExpectationsReducerState(tf.nest.map_structure( + from_example, initial_fn_results)) def one_step( self, new_chain_state, current_reducer_state, - previous_kernel_results, + previous_kernel_results=None, axis=None): """Update the `current_reducer_state` with a new chain state. @@ -146,28 +144,21 @@ def one_step( new_chain_state = tf.nest.map_structure( tf.convert_to_tensor, new_chain_state) - previous_kernel_results = tf.nest.map_structure( - tf.convert_to_tensor, - previous_kernel_results - ) - if not nest.is_nested(axis): - axis = nest_util.broadcast_structure(self.transform_fn, axis) + if previous_kernel_results is not None: + previous_kernel_results = tf.nest.map_structure( + tf.convert_to_tensor, + previous_kernel_results) fn_results = tf.nest.map_structure( lambda fn: fn(new_chain_state, previous_kernel_results), - self.transform_fn - ) - stream = _prepare_args(fn_results) - def update(run_mean, fn_results, state, axis): - return run_mean.update(state.expectation_state, fn_results, axis=axis) - updated_expectation = nest.map_structure_up_to( - self.transform_fn, + self.transform_fn) + if not nest.is_nested(axis): + axis = nest_util.broadcast_structure(fn_results, axis) + def update(fn_results, state, axis): + return state.update(fn_results, axis=axis) + return ExpectationsReducerState(nest.map_structure( update, - stream, fn_results, current_reducer_state, axis, - check_types=False) - return nest.map_structure_up_to( - self.transform_fn, - ExpectationsReducerState, - updated_expectation) + fn_results, current_reducer_state.expectation_state, axis, + check_types=False)) def finalize(self, final_reducer_state): """Finalizes expectation calculation from the `final_reducer_state`. @@ -185,16 +176,9 @@ def finalize(self, final_reducer_state): """ with tf.name_scope( mcmc_util.make_name(self.name, 'expectations_reducer', 'finalize')): - fn_results = nest.map_structure_up_to( - self.transform_fn, - lambda state: state.expectation_state.mean, - final_reducer_state - ) - stream = _prepare_args(fn_results) - return nest.map_structure_up_to( - self.transform_fn, - lambda run_mean, state: run_mean.finalize(state.expectation_state), - stream, final_reducer_state, + return nest.map_structure( + lambda state: state.mean, + final_reducer_state.expectation_state, check_types=False) @property @@ -208,11 +192,3 @@ def name(self): @property def parameters(self): return self._parameters - - -def _prepare_args(fn_results): - """Creates a structure of compatible `RunningMean` streams.""" - stream = tf.nest.map_structure( - lambda res: sample_stats.RunningMean(shape=res.shape, dtype=res.dtype), - fn_results) - return stream diff --git a/tensorflow_probability/python/experimental/mcmc/expectations_reducer_test.py b/tensorflow_probability/python/experimental/mcmc/expectations_reducer_test.py index 03330fe0fa..7148854ea3 100644 --- a/tensorflow_probability/python/experimental/mcmc/expectations_reducer_test.py +++ b/tensorflow_probability/python/experimental/mcmc/expectations_reducer_test.py @@ -38,45 +38,6 @@ @test_util.test_all_tf_execution_regimes class ExpectationsReducerTest(test_util.TestCase): - def test_simple_operation(self): - mean_reducer = tfp.experimental.mcmc.ExpectationsReducer() - fake_kr = FakeKernelResults(0, FakeInnerResults(0)) - state = mean_reducer.initialize(0, fake_kr) - for sample in range(6): - state = mean_reducer.one_step(sample, state, fake_kr) - mean = self.evaluate(mean_reducer.finalize(state)) - self.assertEqual(2.5, mean) - - def test_with_transform_fn(self): - transform_fn = [lambda x, y: x + 1, lambda x, y: x + 2] - mean_reducer = tfp.experimental.mcmc.ExpectationsReducer( - transform_fn=transform_fn - ) - fake_kr = FakeKernelResults(0, FakeInnerResults(0)) - state = mean_reducer.initialize(0, fake_kr) - for sample in range(6): - state = mean_reducer.one_step(sample, state, fake_kr) - mean = self.evaluate(mean_reducer.finalize(state)) - self.assertEqual([3.5, 4.5], mean) - - def test_with_nested_transform_fn(self): - transform_fn = [ - {'add_one': lambda x, y: x + 1}, - {'add_two': lambda x, y: x + 2, 'zero': lambda x, y: tf.zeros(())} - ] - expectations_reducer = tfp.experimental.mcmc.ExpectationsReducer( - transform_fn=transform_fn - ) - fake_kr = FakeKernelResults(0, FakeInnerResults(0)) - state = expectations_reducer.initialize(0, fake_kr) - for sample in range(6): - state = expectations_reducer.one_step(sample, state, fake_kr) - mean = self.evaluate(expectations_reducer.finalize(state)) - self.assertEqual([ - {'add_one': 3.5}, - {'add_two': 4.5, 'zero': 0} - ], mean) - def test_with_kernel_results(self): def kernel_average(sample, kr): del sample @@ -86,62 +47,35 @@ def inner_average(sample, kr): return kr.inner_results.value mean_reducer = tfp.experimental.mcmc.ExpectationsReducer( - transform_fn=[kernel_average, inner_average] - ) - kernel_results = FakeKernelResults( - 0, FakeInnerResults(0)) + transform_fn=[kernel_average, inner_average]) + kernel_results = FakeKernelResults(0, FakeInnerResults(0)) state = mean_reducer.initialize(0, kernel_results) for sample in range(6): - kernel_results = FakeKernelResults( - sample, FakeInnerResults(sample + 1)) + kernel_results = FakeKernelResults(sample, FakeInnerResults(sample + 1)) state = mean_reducer.one_step(sample, state, kernel_results) mean = self.evaluate(mean_reducer.finalize(state)) self.assertEqual([2.5, 3.5], mean) def test_chunking(self): mean_reducer = tfp.experimental.mcmc.ExpectationsReducer() - kernel_results = FakeKernelResults( - tf.zeros((3, 9)), FakeInnerResults(tf.ones((3, 9)))) - state = mean_reducer.initialize(tf.ones((3,)), kernel_results) + state = mean_reducer.initialize(tf.ones((3,))) for sample in range(6): state = mean_reducer.one_step( - tf.ones((3, 9)) * sample, state, kernel_results, axis=1) + tf.ones((3, 9)) * sample, state, axis=1) mean = self.evaluate(mean_reducer.finalize(state)) self.assertEqual((3,), mean.shape) self.assertAllEqual([2.5, 2.5, 2.5], mean) - def test_no_steps(self): - mean_reducer = tfp.experimental.mcmc.ExpectationsReducer() - kernel_results = FakeKernelResults( - 0, FakeInnerResults(0)) - state = mean_reducer.initialize(0, kernel_results) - mean = self.evaluate(mean_reducer.finalize(state)) - self.assertEqual(0, mean) - - def test_in_with_reductions(self): - fake_kernel = test_fixtures.TestTransitionKernel() - mean_reducer = tfp.experimental.mcmc.ExpectationsReducer() - reduced_kernel = tfp.experimental.mcmc.WithReductions( - fake_kernel, mean_reducer, - ) - pkr = reduced_kernel.bootstrap_results(8) - _, kernel_results = reduced_kernel.one_step(8, pkr) - reduction_results = self.evaluate( - mean_reducer.finalize(kernel_results.reduction_results)) - self.assertEqual(9, reduction_results) - def test_in_step_kernel(self): fake_kernel = test_fixtures.TestTransitionKernel() mean_reducer = tfp.experimental.mcmc.ExpectationsReducer() reduced_kernel = tfp.experimental.mcmc.WithReductions( - fake_kernel, mean_reducer, - ) + fake_kernel, mean_reducer) _, kernel_results = tfp.experimental.mcmc.step_kernel( num_steps=5, current_state=8, kernel=reduced_kernel, - return_final_kernel_results=True, - ) + return_final_kernel_results=True) reduction_results = self.evaluate( mean_reducer.finalize(kernel_results.reduction_results)) self.assertEqual(11, reduction_results) diff --git a/tensorflow_probability/python/experimental/mcmc/internal/test_fixtures.py b/tensorflow_probability/python/experimental/mcmc/internal/test_fixtures.py index b528207c8a..a377aa092f 100644 --- a/tensorflow_probability/python/experimental/mcmc/internal/test_fixtures.py +++ b/tensorflow_probability/python/experimental/mcmc/internal/test_fixtures.py @@ -107,3 +107,13 @@ def initialize(self, initial_chain_state, initial_kernel_results=None): def one_step( self, new_chain_state, current_reducer_state, previous_kernel_results): return new_chain_state + + +def reduce(reducer, elems): + """Reduces `elems` along the first dimension with `reducer`.""" + elems = tf.convert_to_tensor(elems) + state = reducer.initialize(elems[0]) + def body(i, state): + return i + 1, reducer.one_step(elems[i], state) + _, state = tf.while_loop(lambda i, _: i < elems.shape[0], body, (0, state)) + return reducer.finalize(state) diff --git a/tensorflow_probability/python/experimental/mcmc/potential_scale_reduction_reducer_test.py b/tensorflow_probability/python/experimental/mcmc/potential_scale_reduction_reducer_test.py index 6bf1080df5..b1c0715dcb 100644 --- a/tensorflow_probability/python/experimental/mcmc/potential_scale_reduction_reducer_test.py +++ b/tensorflow_probability/python/experimental/mcmc/potential_scale_reduction_reducer_test.py @@ -18,102 +18,22 @@ from __future__ import division from __future__ import print_function -import collections - # Dependency imports import numpy as np import tensorflow.compat.v2 as tf import tensorflow_probability as tfp +from tensorflow_probability.python.experimental.mcmc.internal import test_fixtures from tensorflow_probability.python.internal import test_util -TestTransitionKernelResults = collections.namedtuple( - 'TestTransitionKernelResults', 'counter_1, counter_2') - - -class TestTransitionKernel(tfp.mcmc.TransitionKernel): - """Fake deterministic Transition Kernel.""" - - def __init__(self, shape=(), target_log_prob_fn=None, is_calibrated=True): - self._is_calibrated = is_calibrated - self._shape = shape - # for composition purposes - self.parameters = dict( - target_log_prob_fn=target_log_prob_fn) - - def one_step(self, current_state, previous_kernel_results, seed=None): - return (current_state + tf.ones(self._shape), - TestTransitionKernelResults( - counter_1=previous_kernel_results.counter_1 + 1, - counter_2=previous_kernel_results.counter_2 + 2)) - - def bootstrap_results(self, current_state): - return TestTransitionKernelResults( - counter_1=tf.zeros(()), - counter_2=tf.zeros(())) - - @property - def is_calibrated(self): - return self._is_calibrated - - @test_util.test_all_tf_execution_regimes class PotentialScaleReductionReducerTest(test_util.TestCase): - def test_simple_operation(self): - rhat_reducer = tfp.experimental.mcmc.PotentialScaleReductionReducer( - independent_chain_ndims=1, - ) - state = rhat_reducer.initialize(tf.zeros(5,)) - chain_state = np.arange(20, dtype=np.float32).reshape((4, 5)) - for sample in chain_state: - state = rhat_reducer.one_step(sample, state) - rhat = rhat_reducer.finalize(state) - true_rhat = tfp.mcmc.potential_scale_reduction( - chains_states=chain_state, - independent_chain_ndims=1, - ) - rhat, true_rhat = self.evaluate([rhat, true_rhat]) - self.assertAllClose(true_rhat, rhat, rtol=1e-6) - - def test_non_scalar_sample(self): - rhat_reducer = tfp.experimental.mcmc.PotentialScaleReductionReducer( - independent_chain_ndims=1, - ) - state = rhat_reducer.initialize(tf.zeros((5, 3))) - chain_state = np.arange(60, dtype=np.float32).reshape((4, 5, 3)) - for sample in chain_state: - state = rhat_reducer.one_step(sample, state) - rhat = rhat_reducer.finalize(state) - true_rhat = tfp.mcmc.potential_scale_reduction( - chains_states=chain_state, - independent_chain_ndims=1, - ) - rhat, true_rhat = self.evaluate([rhat, true_rhat]) - self.assertAllClose(true_rhat, rhat, rtol=1e-6) - - def test_independent_chain_ndims(self): - rhat_reducer = tfp.experimental.mcmc.PotentialScaleReductionReducer( - independent_chain_ndims=2, - ) - state = rhat_reducer.initialize(tf.zeros((2, 5, 3))) - chain_state = np.arange(120, dtype=np.float32).reshape((4, 2, 5, 3)) - for sample in chain_state: - state = rhat_reducer.one_step(sample, state) - rhat = rhat_reducer.finalize(state) - true_rhat = tfp.mcmc.potential_scale_reduction( - chains_states=chain_state, - independent_chain_ndims=2, - ) - rhat, true_rhat = self.evaluate([rhat, true_rhat]) - self.assertAllClose(true_rhat, rhat, rtol=1e-6) - def test_int_samples(self): rhat_reducer = tfp.experimental.mcmc.PotentialScaleReductionReducer( - independent_chain_ndims=1, - ) + independent_chain_ndims=1) state = rhat_reducer.initialize(tf.zeros((5, 3), dtype=tf.int64)) chain_state = np.arange(60).reshape((4, 5, 3)) for sample in chain_state: @@ -121,42 +41,19 @@ def test_int_samples(self): rhat = rhat_reducer.finalize(state) true_rhat = tfp.mcmc.potential_scale_reduction( chains_states=chain_state, - independent_chain_ndims=1, - ) + independent_chain_ndims=1) self.assertEqual(tf.float64, rhat.dtype) rhat, true_rhat = self.evaluate([rhat, true_rhat]) self.assertAllClose(true_rhat, rhat, rtol=1e-6) - def test_in_with_reductions(self): - rhat_reducer = tfp.experimental.mcmc.PotentialScaleReductionReducer( - independent_chain_ndims=1, - ) - fake_kernel = TestTransitionKernel(shape=(5,)) - reduced_kernel = tfp.experimental.mcmc.WithReductions( - inner_kernel=fake_kernel, - reducer=rhat_reducer, - ) - chain_state = tf.zeros(5,) - pkr = reduced_kernel.bootstrap_results(chain_state) - for _ in range(2): - chain_state, pkr = reduced_kernel.one_step( - chain_state, pkr) - rhat = self.evaluate( - rhat_reducer.finalize(pkr.reduction_results)) - self.assertEqual(0.5, rhat) - def test_iid_normal_passes(self): n_samples = 500 - # two scalar chains taken from iid Normal(0, 1) + # five scalar chains taken from iid Normal(0, 1) rng = test_util.test_np_rng() - iid_normal_samples = rng.randn(n_samples, 2) + iid_normal_samples = rng.randn(n_samples, 5) rhat_reducer = tfp.experimental.mcmc.PotentialScaleReductionReducer( - independent_chain_ndims=1, - ) - state = rhat_reducer.initialize(iid_normal_samples[0]) - for sample in iid_normal_samples: - state = rhat_reducer.one_step(sample, state) - rhat = self.evaluate(rhat_reducer.finalize(state)) + independent_chain_ndims=1) + rhat = self.evaluate(test_fixtures.reduce(rhat_reducer, iid_normal_samples)) self.assertAllEqual((), rhat.shape) self.assertAllClose(1., rhat, rtol=0.02) @@ -169,21 +66,17 @@ def test_offset_normal_fails(self): rng = test_util.test_np_rng() offset_samples = rng.randn(n_samples, 3, 4) + offset rhat_reducer = tfp.experimental.mcmc.PotentialScaleReductionReducer( - independent_chain_ndims=1, - ) - state = rhat_reducer.initialize(offset_samples[0]) - for sample in offset_samples: - state = rhat_reducer.one_step(sample, state) - rhat = self.evaluate(rhat_reducer.finalize(state)) + independent_chain_ndims=1) + rhat = self.evaluate(test_fixtures.reduce(rhat_reducer, offset_samples)) self.assertAllEqual((4,), rhat.shape) - self.assertAllEqual(np.ones_like(rhat).astype(bool), rhat > 1.2) + self.assertAllGreater(rhat, 1.2) def test_with_hmc(self): target_dist = tfp.distributions.Normal(loc=0., scale=1.) hmc_kernel = tfp.mcmc.HamiltonianMonteCarlo( target_log_prob_fn=target_dist.log_prob, num_leapfrog_steps=27, - step_size=1/3) + step_size=0.33) reduced_stats, _, _ = tfp.experimental.mcmc.sample_fold( num_steps=50, current_state=tf.zeros((2,)), @@ -191,30 +84,27 @@ def test_with_hmc(self): reducer=[ tfp.experimental.mcmc.TracingReducer(), tfp.experimental.mcmc.PotentialScaleReductionReducer() - ] - ) + ]) rhat = reduced_stats[1] true_rhat = tfp.mcmc.potential_scale_reduction( chains_states=reduced_stats[0][0], - independent_chain_ndims=1, - ) + independent_chain_ndims=1) true_rhat, rhat = self.evaluate([true_rhat, rhat]) self.assertAllClose(true_rhat, rhat, rtol=1e-6) - def test_multiple_latent_state(self): + def test_multiple_latent_states_and_independent_chain_ndims(self): + rng = test_util.test_np_rng() rhat_reducer = tfp.experimental.mcmc.PotentialScaleReductionReducer( - independent_chain_ndims=1, - ) - state = rhat_reducer.initialize([tf.zeros(5,), tf.zeros((2, 5))]) - chain_state = np.arange(20, dtype=np.float32).reshape((4, 5)) - second_chain_state = np.arange(40, dtype=np.float32).reshape((4, 2, 5)) + independent_chain_ndims=2) + state = rhat_reducer.initialize([tf.zeros((2, 5, 3)), tf.zeros((7, 2, 8))]) + chain_state = rng.randn(4, 2, 5, 3) + second_chain_state = rng.randn(4, 7, 2, 8) for latent in zip(chain_state, second_chain_state): state = rhat_reducer.one_step(latent, state) rhat = rhat_reducer.finalize(state) true_rhat = tfp.mcmc.potential_scale_reduction( chains_states=[chain_state, second_chain_state], - independent_chain_ndims=1, - ) + independent_chain_ndims=2) rhat, true_rhat = self.evaluate([rhat, true_rhat]) self.assertAllClose(true_rhat, rhat, rtol=1e-6) diff --git a/tensorflow_probability/python/experimental/mcmc/tracing_reducer_test.py b/tensorflow_probability/python/experimental/mcmc/tracing_reducer_test.py index 2cc1f0e467..057428b1bc 100644 --- a/tensorflow_probability/python/experimental/mcmc/tracing_reducer_test.py +++ b/tensorflow_probability/python/experimental/mcmc/tracing_reducer_test.py @@ -19,7 +19,6 @@ from __future__ import print_function # Dependency imports -import numpy as np import tensorflow.compat.v2 as tf import tensorflow_probability as tfp @@ -31,53 +30,6 @@ @test_util.test_all_tf_execution_regimes class TracingReducerTest(test_util.TestCase): - def test_simple_operation(self): - tracer = tfp.experimental.mcmc.TracingReducer() - state = tracer.initialize(tf.zeros(()), tf.zeros(())) - for sample in range(1, 6): - # kernel results is simply the last sample - state = tracer.one_step(sample, state, sample) - all_states, final_trace = self.evaluate(tracer.finalize(state)) - self.assertAllEqual([1, 2, 3, 4, 5], all_states) - self.assertAllEqual([1, 2, 3, 4, 5], final_trace) - - def test_custom_tracing(self): - tracer = tfp.experimental.mcmc.TracingReducer( - trace_fn=lambda sample, pkr: (sample + pkr,)) - state = tracer.initialize(tf.zeros(()), tf.zeros(())) - for sample in range(1, 6): - state = tracer.one_step(sample, state, sample * 2) - final_trace = self.evaluate(tracer.finalize(state)) - self.assertAllEqual(([3, 6, 9, 12, 15],), final_trace) - - def test_latent_chain_state(self): - tracer = tfp.experimental.mcmc.TracingReducer( - trace_fn=lambda current_state, _: current_state - ) - chain_state = ({'one': np.ones((2, 3)), 'zero': np.zeros((2, 3))}, - {'two': np.ones((2, 3)) * 2}) - state = tracer.initialize(chain_state) - for _ in range(3): - state = tracer.one_step(chain_state, state, None) - final_trace = self.evaluate(tracer.finalize(state)) - self.assertEqual(2, len(final_trace[0])) - self.assertAllEqualNested(chain_state, tf.nest.map_structure( - lambda trace_state: trace_state[0], final_trace)) - - def test_differently_structured_trace_results(self): - def trace_fn(sample, pkr): - return sample, (sample, pkr), {'one': sample, 'two': pkr} - tracer = tfp.experimental.mcmc.TracingReducer(trace_fn=trace_fn) - state = tracer.initialize(tf.zeros(()), tf.zeros(())) - for sample in range(1, 3): - state = tracer.one_step(sample, state, sample * 2) - final_trace = self.evaluate(tracer.finalize(state)) - self.assertEqual(3, len(final_trace)) - self.assertAllEqual([1, 2], final_trace[0]) - self.assertAllEqual(([1, 2], [2, 4]), final_trace[1]) - self.assertAllEqualNested(final_trace[2], ({'one': [1, 2], - 'two': [2, 4]})) - def test_tf_while(self): def trace_fn(sample, pkr): return sample, (sample, pkr), {'one': sample, 'two': pkr} @@ -89,14 +41,12 @@ def _body(sample, pkr, state): _, _, state = tf.while_loop( cond=lambda i, _, __: i < 3, body=_body, - loop_vars=(1., 2., state) - ) + loop_vars=(1., 2., state)) final_trace = self.evaluate(tracer.finalize(state)) self.assertEqual(3, len(final_trace)) self.assertAllEqual([1, 2], final_trace[0]) self.assertAllEqual(([1, 2], [2, 4]), final_trace[1]) - self.assertAllEqualNested(final_trace[2], ({'one': [1, 2], - 'two': [2, 4]})) + self.assertAllEqualNested(final_trace[2], ({'one': [1, 2], 'two': [2, 4]})) def test_in_sample_fold(self): tracer = tfp.experimental.mcmc.TracingReducer() @@ -105,13 +55,11 @@ def test_in_sample_fold(self): num_steps=3, current_state=0., kernel=fake_kernel, - reducer=tracer, - ) + reducer=tracer) trace, final_state, kernel_results = self.evaluate([ trace, final_state, - kernel_results - ]) + kernel_results]) self.assertAllEqual([1, 2, 3], trace[0]) self.assertAllEqual([1, 2, 3], trace[1].counter_1) self.assertAllEqual([2, 4, 6], trace[1].counter_2) @@ -126,13 +74,9 @@ def test_known_size(self): for sample in range(3): state = tracer.one_step(sample, state, sample) all_states, final_trace = tracer.finalize(state) - self.assertAllClose( - [3], tensorshape_util.as_list(all_states.shape)) - self.assertAllClose( - [3], tensorshape_util.as_list(final_trace.shape)) - all_states, final_trace = self.evaluate([ - all_states, final_trace - ]) + self.assertAllEqual([3], tensorshape_util.as_list(all_states.shape)) + self.assertAllEqual([3], tensorshape_util.as_list(final_trace.shape)) + all_states, final_trace = self.evaluate([all_states, final_trace]) self.assertAllEqual([0, 1, 2], all_states) self.assertAllEqual([0, 1, 2], final_trace) diff --git a/tensorflow_probability/python/experimental/nn/BUILD b/tensorflow_probability/python/experimental/nn/BUILD index c2c09e3e1f..959b560aa8 100644 --- a/tensorflow_probability/python/experimental/nn/BUILD +++ b/tensorflow_probability/python/experimental/nn/BUILD @@ -148,7 +148,7 @@ py_library( py_test( name = "convolutional_transpose_layers_test", - size = "small", + size = "medium", srcs = ["convolutional_transpose_layers_test.py"], python_version = "PY3", srcs_version = "PY3", diff --git a/tensorflow_probability/python/experimental/stats/__init__.py b/tensorflow_probability/python/experimental/stats/__init__.py index b2ea559b69..668f14fe83 100644 --- a/tensorflow_probability/python/experimental/stats/__init__.py +++ b/tensorflow_probability/python/experimental/stats/__init__.py @@ -19,20 +19,16 @@ from __future__ import print_function from tensorflow_probability.python.experimental.stats.sample_stats import RunningCentralMoments -from tensorflow_probability.python.experimental.stats.sample_stats import RunningCentralMomentsState from tensorflow_probability.python.experimental.stats.sample_stats import RunningCovariance from tensorflow_probability.python.experimental.stats.sample_stats import RunningMean -from tensorflow_probability.python.experimental.stats.sample_stats import RunningMeanState from tensorflow_probability.python.experimental.stats.sample_stats import RunningPotentialScaleReduction from tensorflow_probability.python.experimental.stats.sample_stats import RunningVariance __all__ = [ 'RunningCentralMoments', - 'RunningCentralMomentsState', 'RunningCovariance', 'RunningMean', - 'RunningMeanState', 'RunningPotentialScaleReduction', 'RunningVariance', ] diff --git a/tensorflow_probability/python/experimental/stats/sample_stats.py b/tensorflow_probability/python/experimental/stats/sample_stats.py index 7ab6ea77c8..035351d0f4 100644 --- a/tensorflow_probability/python/experimental/stats/sample_stats.py +++ b/tensorflow_probability/python/experimental/stats/sample_stats.py @@ -18,7 +18,6 @@ from __future__ import division from __future__ import print_function -import collections import functools import inspect import math @@ -36,10 +35,8 @@ __all__ = [ 'RunningCentralMoments', - 'RunningCentralMomentsState', 'RunningCovariance', 'RunningMean', - 'RunningMeanState', 'RunningPotentialScaleReduction', 'RunningVariance', ] @@ -326,32 +323,61 @@ def variance(self, ddof=0): """ return self.covariance(ddof) + @classmethod + def init_from_stats(cls, num_samples, mean, variance): + """Initialize a `RunningVariance` object with given stats. + + This allows the user to initialize knowing the mean, variance, and number + of samples seen so far. + + Args: + num_samples: Scalar `float` `Tensor`, for number of examples already seen. + mean: `float` `Tensor`, for starting mean of estimate. + variance: `float` `Tensor`, for starting estimate of the variance. -RunningMeanState = collections.namedtuple( - 'RunningMeanState', 'num_samples, mean') + Returns: + `RunningVariance` object, with given mean and variance estimate. + """ + # TODO(b/173736911): Add this to RunningCovariance + num_samples = tf.convert_to_tensor(num_samples, name='num_samples') + mean = tf.convert_to_tensor(mean, name='mean') + variance = tf.convert_to_tensor(variance, name='variance') + return cls(num_samples=num_samples, + mean=mean, + sum_squared_residuals=num_samples * variance, + event_ndims=0) +@auto_composite_tensor.auto_composite_tensor(omit_kwargs='name') class RunningMean(object): - """Holds metadata for and computes a running mean. + """Computes a running mean. In computation, samples can be provided individually or in chunks. A "chunk" of size M implies incorporating M samples into a single expectation - computation at once, which is more efficient than one by one. If more than one - sample is accepted and chunking is enabled, the chunked `axis` will define - chunking semantics for all samples. - - `RunningMean` objects do not hold state information. That information, - which includes intermediate calculations, are held in a - `RunningMeanState` as returned via `initialize` and `update` method - calls. + computation at once, which is more efficient than one by one. `RunningMean` is meant to serve general streaming expectations. For a specialized version that fits streaming over MCMC samples, see `ExpectationsReducer` in `tfp.experimental.mcmc`. """ - def __init__(self, shape, dtype=tf.float32): - """Instantiates this object. + def __init__(self, num_samples, mean): + """Instantiates a `RunningMean`. + + Support batch accumulation of multiple independent running means. + + Args: + num_samples: A `Tensor` counting the number of samples + accumulated so far. + mean: A `Tensor` broadcast-compatible with `num_samples` giving the + current mean. + """ + self.num_samples = num_samples + self.mean = mean + + @classmethod + def from_shape(cls, shape, dtype=tf.float32): + """Initialize an empty `RunningMean`. Args: shape: Python `Tuple` or `TensorShape` representing the shape of @@ -361,26 +387,17 @@ def __init__(self, shape, dtype=tf.float32): cast to corresponding floats (i.e. `tf.int32` will be cast to `tf.float32`), as intermediate calculations should be performing floating-point division. - """ - self.shape = shape - if dtype is tf.int64: - dtype = tf.float64 - elif dtype.is_integer: - dtype = tf.float32 - self.dtype = dtype - - def initialize(self): - """Initializes an empty `RunningMeanState`. Returns: state: `RunningMeanState` representing a stream of no inputs. """ - return RunningMeanState( - num_samples=tf.zeros((), dtype=self.dtype), - mean=tf.zeros(self.shape, self.dtype)) + dtype = _float_dtype_like(dtype) + return cls( + num_samples=tf.zeros((), dtype=dtype), + mean=tf.zeros(shape, dtype)) - def update(self, state, new_sample, axis=None): - """Update the `RunningMeanState` with a new sample. + def update(self, new_sample, axis=None): + """Update the `RunningMean` with a new sample. The update formula is from Philippe Pebay (2008) [1] and is identical to that used to calculate the intermediate mean in @@ -388,16 +405,14 @@ def update(self, state, new_sample, axis=None): `tfp.experimental.stats.RunningVariance`. Args: - state: `RunningMeanState` that represents the current state of - running statistics. new_sample: Incoming `Tensor` sample with shape and dtype compatible with - those used to form the `RunningMeanState`. + those used to form the `RunningMean`. axis: If chunking is desired, this is an integer that specifies the axis with chunked samples. For individual samples, set this to `None`. By default, samples are not chunked (`axis` is None). Returns: - state: `RunningMeanState` with updated calculations. + mean: `RunningMean` updated to the new sample. #### References [1]: Philippe Pebay. Formulas for Robust, One-Pass Parallel Computation of @@ -405,43 +420,25 @@ def update(self, state, new_sample, axis=None): SAND2008-6212_, 2008. https://prod-ng.sandia.gov/techlib-noauth/access-control.cgi/2008/086212.pdf """ + dtype = self.mean.dtype new_sample = tf.nest.map_structure( - lambda new_sample: tf.cast(new_sample, dtype=self.dtype), + lambda new_sample: tf.cast(new_sample, dtype=dtype), new_sample) if axis is None: - chunk_n = tf.cast(1, dtype=self.dtype) + chunk_n = tf.constant(1, dtype=dtype) chunk_mean = new_sample else: - chunk_n = tf.cast(ps.shape(new_sample)[axis], dtype=self.dtype) + chunk_n = tf.cast(ps.shape(new_sample)[axis], dtype=dtype) chunk_mean = tf.math.reduce_mean(new_sample, axis=axis) - new_n = state.num_samples + chunk_n - delta_mean = chunk_mean - state.mean - new_mean = state.mean + chunk_n * delta_mean / new_n - return RunningMeanState(new_n, new_mean) - - def finalize(self, state): - """Finalizes expectation computation for the `state`. - - If the `finalized` method is invoked on a running state of no inputs, - `RunningMean` will return a corresponding structure of `tf.zeros`. - - Args: - state: `RunningMeanState` that represents the current state of - running statistics. - - Returns: - mean: An estimate of the mean. - """ - return state.mean - - -RunningCentralMomentsState = collections.namedtuple( - 'RunningCentralMomentsState', - 'mean_state, sum_exponentiated_residuals') + new_n = self.num_samples + chunk_n + delta_mean = chunk_mean - self.mean + new_mean = self.mean + chunk_n * delta_mean / new_n + return RunningMean(new_n, new_mean) +@auto_composite_tensor.auto_composite_tensor class RunningCentralMoments(object): - """Holds metadata for and computes running central moments. + """Computes running central moments. `RunningCentralMoments` will compute arbitrary central moments in streaming fashion following the formula proposed by Philippe Pebay @@ -454,11 +451,6 @@ class RunningCentralMoments(object): `RunningCentralMoments` cannot guarantee numerical stability for all moments. - `RunningCentralMoments` objects do not hold state information. That - information, which includes intermediate calculations, are held in a - `RunningCentralMomentsState` as returned via `initialize` and `update` - method calls. - #### References [1]: Philippe Pebay. Formulas for Robust, One-Pass Parallel Computation of Covariances and Arbitrary-Order Statistical Moments. _Technical Report @@ -466,8 +458,28 @@ class RunningCentralMoments(object): https://prod-ng.sandia.gov/techlib-noauth/access-control.cgi/2008/086212.pdf """ - def __init__(self, shape, moment, dtype=tf.float32): - """Instantiates this object. + def __init__(self, mean_state, exponentiated_residuals, desired_moments): + """Constructs a `RunningCentralMoments`. + + All moments up to the maximum of the desired moments will be computed. + + Args: + mean_state: A `RunningMean` carrying the running mean estimate. + exponentiated_residuals: A `Tensor` representing the sum of exponentiated + residuals. This is a `Tensor` of shape `[max_moment - 1] + + mean_state.mean.shape`, which contains the sum of the residuals raised + to the kth power, for all `2 <= k <= max_moment`. + desired_moments: A Python list of integers giving the moments to return. + The maximum element of this list gives the number of moments that + will be computed. + """ + self.mean_state = mean_state + self.exponentiated_residuals = exponentiated_residuals + self.desired_moments = desired_moments + + @classmethod + def from_shape(cls, shape, moment, dtype=tf.float32): + """Returns an empty `RunningCentralMoments`. Args: shape: Python `Tuple` or `TensorShape` representing the shape of @@ -479,117 +491,101 @@ def __init__(self, shape, moment, dtype=tf.float32): cast to corresponding floats (i.e. `tf.int32` will be cast to `tf.float32`), as intermediate calculations should be performing floating-point division. + + Returns: + state: `RunningCentralMoments` representing a stream of no + inputs. """ - self.shape = shape if isinstance(moment, (tuple, list, np.ndarray)): # we want to support numpy arrays too, but must convert to a list to not # confuse `tf.nest.map_structure` in `finalize` - self.moment = list(moment) - self.max_moment = max(self.moment) + desired_moments = list(moment) + max_moment = max(desired_moments) else: - self.moment = moment - self.max_moment = moment - if dtype is tf.int64: - dtype = tf.float64 - elif dtype.is_integer: - dtype = tf.float32 - self.dtype = dtype - self.mean_stream = RunningMean( - self.shape, self.dtype - ) - - def initialize(self): - """Initializes an empty `RunningCentralMomentsState`. - - The `RunningCentralMomentsState` contains a `RunningMeanState` and - a `Tensor` representing the sum of exponentiated residuals. The sum - of exponentiated residuals is a `Tensor` of shape - (`self.max_moment - 1`, `self.shape`), which contains the sum of - residuals raised to the nth power, for all `2 <= n <= self.max_moment`. - - Returns: - state: `RunningCentralMomentsState` representing a stream of no - inputs. - """ - return RunningCentralMomentsState( - mean_state=self.mean_stream.initialize(), - sum_exponentiated_residuals=tf.zeros( - (self.max_moment - 1,) + self.shape, self.dtype), - ) + desired_moments = [moment] + max_moment = moment + dtype = _float_dtype_like(dtype) + return cls( + mean_state=RunningMean.from_shape(shape, dtype), + exponentiated_residuals=tf.zeros( + ps.concat([(max_moment - 1,), shape], axis=0), dtype), + desired_moments=desired_moments) - def update(self, state, new_sample): - """Update the `RunningCentralMomentsState` with a new sample. + def update(self, new_sample): + """Update with a new sample. Args: - state: `RunningCentralMomentsState` that represents the current - state of running statistics. new_sample: Incoming `Tensor` sample with shape and dtype compatible with - those used to form the `RunningCentralMomentsState`. + those used to form the `RunningCentralMoments`. Returns: - state: `RunningCentralMomentsState` with updated calculations. + state: `RunningCentralMoments` updated to include the new sample. """ + shape = self.mean_state.mean.shape + dtype = self.mean_state.mean.dtype n_2 = 1 - n_1 = state.mean_state.num_samples - n = tf.cast(n_1 + n_2, dtype=self.dtype) - delta_mean = new_sample - state.mean_state.mean - new_mean_state = self.mean_stream.update(state.mean_state, new_sample) - old_res = tf.concat([ - tf.zeros((1,) + self.shape, self.dtype), - state.sum_exponentiated_residuals], axis=0) - # the sum of exponentiated residuals can be thought of as an estimation - # of the central moment before diving through by the number of samples. + n_1 = self.mean_state.num_samples + n = tf.cast(n_1 + n_2, dtype=dtype) + delta_mean = new_sample - self.mean_state.mean + new_mean_state = self.mean_state.update(new_sample) + # The sum of exponentiated residuals can be thought of as an estimation + # of the central moment before dividing through by the number of samples. # Since the first central moment is always 0, it simplifies update # logic to prepend an appropriate structure of zeros. - new_sum_exponentiated_residuals = [tf.zeros(self.shape, self.dtype)] + old_res = tf.concat([ + tf.zeros(ps.concat([(1,), shape], axis=0), dtype), + self.exponentiated_residuals], axis=0) + # Not storing said zeros in the carried state, though + new_exponentiated_residuals = [] # the following two nested for loops calculate equation 2.9 in Pebay's # 2008 paper from smallest moment to highest. - for p in range(2, self.max_moment + 1): - summation = tf.zeros(self.shape, self.dtype) + max_moment = max(self.desired_moments) + for p in range(2, max_moment + 1): + summation = tf.zeros(shape, dtype) for k in range(1, p - 1): adjusted_old_res = ((-delta_mean / n) ** k) * old_res[p - k - 1] - summation += self._n_choose_k(p, k) * adjusted_old_res + summation += _n_choose_k(p, k) * adjusted_old_res # the `adj_term` refers to the final term in equation 2.9 and is not # transcribed exactly; rather, it's simplified to avoid having a # `(n - 1)` denominator. adj_term = (((delta_mean / n) ** p) * (n - 1) * ((n - 1) ** (p - 1) + (-1) ** p)) - new_sum_pth_residual = old_res[p - 1] + summation + adj_term - new_sum_exponentiated_residuals.append(new_sum_pth_residual) + new_pth_residual = old_res[p - 1] + summation + adj_term + new_exponentiated_residuals.append(new_pth_residual) - return RunningCentralMomentsState( + return RunningCentralMoments( new_mean_state, - sum_exponentiated_residuals=tf.convert_to_tensor( - new_sum_exponentiated_residuals[1:], dtype=self.dtype - ) - ) + # The cast is needed in case new_exponentiated_residuals is the empty + # list, which will happen if the user requested only the first moment. + exponentiated_residuals=tf.cast( + tf.stack(new_exponentiated_residuals, axis=0), dtype=dtype), + desired_moments=self.desired_moments) - def finalize(self, state): - """Finalizes streaming computation for all central moments. - - Args: - state: `RunningCentralMomentsState` that represents the current state - of running statistics. + def moments(self): + """Returns the central moments represented by this `RunningCentralMoments`. Returns: all_moments: A `Tensor` representing estimates of the requested central moments. Its leading dimension indexes the moment, in order of those - requested (i.e. in order of `self.moment`). + requested (i.e. in order of `self.desired_moments`). """ # prepend a structure of zeros for the first moment + shape = self.mean_state.mean.shape + dtype = self.mean_state.mean.dtype all_unfinalized_moments = tf.concat([ - tf.zeros((1,) + self.shape, self.dtype), - state.sum_exponentiated_residuals], axis=0) + tf.zeros(ps.concat([(1,), shape], axis=0), dtype), + self.exponentiated_residuals], axis=0) all_moments = all_unfinalized_moments / tf.cast( - state.mean_state.num_samples, self.dtype) - return tf.convert_to_tensor(tf.nest.map_structure( - lambda i: all_moments[i - 1], - self.moment), self.dtype) - - def _n_choose_k(self, n, k): - """Computes nCk.""" - return math.factorial(n) // math.factorial(k) // math.factorial(n - k) + self.mean_state.num_samples, dtype) + desired_moment_indices = tf.convert_to_tensor( + self.desired_moments, dtype=tf.int32) - 1 + return tf.gather(all_moments, desired_moment_indices) + + +def _n_choose_k(n, k): + """Computes nCk.""" + return math.factorial(n) // math.factorial(k) // math.factorial(n - k) @auto_composite_tensor.auto_composite_tensor(omit_kwargs='name') diff --git a/tensorflow_probability/python/experimental/stats/sample_stats_test.py b/tensorflow_probability/python/experimental/stats/sample_stats_test.py index 2d67aa8150..f8cc6433e6 100644 --- a/tensorflow_probability/python/experimental/stats/sample_stats_test.py +++ b/tensorflow_probability/python/experimental/stats/sample_stats_test.py @@ -34,6 +34,15 @@ @test_util.test_all_tf_execution_regimes class RunningCovarianceTest(test_util.TestCase): + def test_from_stats(self): + num_counts = 10. + mean = 1. + variance = 3. + var = tfp.experimental.stats.RunningVariance.init_from_stats( + num_counts, mean, variance) + self.assertEqual(self.evaluate(var.mean), mean) + self.assertEqual(self.evaluate(var.variance()), variance) + def test_zero_running_variance(self): deterministic_samples = [0., 0., 0., 0.] var = tfp.experimental.stats.RunningVariance.from_shape() @@ -457,45 +466,38 @@ def _loop_body(i, running_rhat): class RunningMeanTest(test_util.TestCase): def test_zero_mean(self): - running_mean = tfp.experimental.stats.RunningMean( - shape=(), - ) - state = running_mean.initialize() + running_mean = tfp.experimental.stats.RunningMean.from_shape( + shape=()) for _ in range(6): - state = running_mean.update(state, 0) - mean = self.evaluate(running_mean.finalize(state)) + running_mean = running_mean.update(0) + mean = self.evaluate(running_mean.mean) self.assertEqual(0, mean) def test_higher_rank_shape(self): - running_mean = tfp.experimental.stats.RunningMean( - shape=(5, 3), - ) - state = running_mean.initialize() + running_mean = tfp.experimental.stats.RunningMean.from_shape( + shape=(5, 3)) for sample in range(6): - state = running_mean.update(state, tf.ones((5, 3)) * sample) - mean = self.evaluate(running_mean.finalize(state)) + running_mean = running_mean.update(tf.ones((5, 3)) * sample) + mean = self.evaluate(running_mean.mean) self.assertAllEqual(np.ones((5, 3)) * 2.5, mean) def test_manual_dtype(self): - running_mean = tfp.experimental.stats.RunningMean( + running_mean = tfp.experimental.stats.RunningMean.from_shape( shape=(), - dtype=tf.float64, - ) - state = running_mean.initialize() + dtype=tf.float64) for _ in range(6): - state = running_mean.update(state, 0) - mean = running_mean.finalize(state) + running_mean = running_mean.update(0) + mean = running_mean.mean self.assertEqual(tf.float64, mean.dtype) def test_integer_dtype(self): - running_mean = tfp.experimental.stats.RunningMean( + running_mean = tfp.experimental.stats.RunningMean.from_shape( shape=(), dtype=tf.int32, ) - state = running_mean.initialize() for sample in range(6): - state = running_mean.update(state, sample) - mean = running_mean.finalize(state) + running_mean = running_mean.update(sample) + mean = running_mean.mean self.assertEqual(tf.float32, mean.dtype) mean = self.evaluate(mean) self.assertEqual(2.5, mean) @@ -503,72 +505,41 @@ def test_integer_dtype(self): def test_random_mean(self): rng = test_util.test_np_rng() x = rng.rand(100) - running_mean = tfp.experimental.stats.RunningMean( - shape=(), - ) - state = running_mean.initialize() + running_mean = tfp.experimental.stats.RunningMean.from_shape( + shape=()) for sample in x: - state = running_mean.update(state, sample) - mean = self.evaluate(running_mean.finalize(state)) + running_mean = running_mean.update(sample) + mean = self.evaluate(running_mean.mean) self.assertAllClose(np.mean(x), mean, rtol=1e-6) def test_chunking(self): rng = test_util.test_np_rng() x = rng.rand(100, 10, 5) - running_mean = tfp.experimental.stats.RunningMean( + running_mean = tfp.experimental.stats.RunningMean.from_shape( shape=(5,), ) - state = running_mean.initialize() for sample in x: - state = running_mean.update(state, sample, axis=0) - mean = self.evaluate(running_mean.finalize(state)) + running_mean = running_mean.update(sample, axis=0) + mean = self.evaluate(running_mean.mean) self.assertAllClose(np.mean(x.reshape(1000, 5), axis=0), mean, rtol=1e-6) def test_tf_while(self): rng = test_util.test_np_rng() x = rng.rand(100, 10) tensor_x = tf.convert_to_tensor(x, dtype=tf.float32) - running_mean = tfp.experimental.stats.RunningMean( + running_mean = tfp.experimental.stats.RunningMean.from_shape( shape=(10,)) - _, state = tf.while_loop( + _, running_mean = tf.while_loop( lambda i, _: i < 100, - lambda i, state: (i + 1, running_mean.update(state, tensor_x[i])), - (0, running_mean.initialize())) - mean = self.evaluate(running_mean.finalize(state)) - self.assertAllClose(np.mean(x, axis=0), mean, rtol=1e-6) - - def test_tf_while_with_dynamic_shape(self): - rng = test_util.test_np_rng() - x = rng.rand(100, 10) - tensor_x = tf.convert_to_tensor(x, dtype=tf.float32) - running_mean = tfp.experimental.stats.RunningMean( - shape=(10,)) - - def _loop_body(i, state): - if not tf.executing_eagerly(): - sample = tf1.placeholder_with_default(tensor_x[i], shape=None) - else: - sample = tensor_x[i] - return (i + 1, running_mean.update(state, sample)) - - _, state = tf.while_loop( - lambda i, _: i < 100, - _loop_body, - (tf.constant(0, dtype=tf.int32), running_mean.initialize()), - shape_invariants=( - None, tfp.experimental.stats.RunningMeanState( - None, - tf.TensorShape(None), - ))) - mean = self.evaluate(running_mean.finalize(state)) + lambda i, running_mean: (i + 1, running_mean.update(tensor_x[i])), + (0, running_mean)) + mean = self.evaluate(running_mean.mean) self.assertAllClose(np.mean(x, axis=0), mean, rtol=1e-6) def test_no_inputs(self): - running_mean = tfp.experimental.stats.RunningMean( - shape=(), - ) - state = running_mean.initialize() - mean = self.evaluate(running_mean.finalize(state)) + running_mean = tfp.experimental.stats.RunningMean.from_shape( + shape=()) + mean = self.evaluate(running_mean.mean) self.assertEqual(0, mean) @@ -576,15 +547,13 @@ def test_no_inputs(self): class RunningCentralMomentsTest(test_util.TestCase): def test_first_five_moments(self): - running_moments = tfp.experimental.stats.RunningCentralMoments( + running_moments = tfp.experimental.stats.RunningCentralMoments.from_shape( shape=(), - moment=np.arange(5) + 1 - ) - state = running_moments.initialize() + moment=np.arange(5) + 1) for sample in range(5): - state = running_moments.update(state, sample) + running_moments = running_moments.update(sample) zeroth_moment, var, skew, kur, fifth_moment = self.evaluate( - running_moments.finalize(state)) + running_moments.moments()) self.assertNear(0, zeroth_moment, err=1e-6) self.assertNear(2, var, err=1e-6) self.assertNear(0, skew, err=1e-6) @@ -592,43 +561,35 @@ def test_first_five_moments(self): self.assertNear(0, fifth_moment, err=1e-6) def test_specific_moments(self): - running_moments = tfp.experimental.stats.RunningCentralMoments( + running_moments = tfp.experimental.stats.RunningCentralMoments.from_shape( shape=(), - moment=[5, 3] - ) - state = running_moments.initialize() + moment=[5, 3]) for sample in range(5): - state = running_moments.update(state, sample) - fifth_moment, skew = self.evaluate( - running_moments.finalize(state)) + running_moments = running_moments.update(sample) + fifth_moment, skew = self.evaluate(running_moments.moments()) self.assertNear(0, skew, err=1e-6) self.assertNear(0, fifth_moment, err=1e-6) def test_very_high_moments(self): - running_moments = tfp.experimental.stats.RunningCentralMoments( + running_moments = tfp.experimental.stats.RunningCentralMoments.from_shape( shape=(), - moment=np.arange(15) + 1 - ) - state = running_moments.initialize() + moment=np.arange(15) + 1) for sample in range(5): - state = running_moments.update(state, sample) - moments = self.evaluate( - running_moments.finalize(state)) + running_moments = running_moments.update(sample) + moments = self.evaluate(running_moments.moments()) self.assertAllClose( stats.moment(np.arange(5), moment=np.arange(15) + 1), moments, rtol=1e-6) def test_higher_rank_samples(self): - running_moments = tfp.experimental.stats.RunningCentralMoments( + running_moments = tfp.experimental.stats.RunningCentralMoments.from_shape( shape=(2, 2), - moment=np.arange(5) + 1 - ) - state = running_moments.initialize() + moment=np.arange(5) + 1) for sample in range(5): - state = running_moments.update(state, tf.ones((2, 2)) * sample) + running_moments = running_moments.update(tf.ones((2, 2)) * sample) zeroth_moment, var, skew, kur, fifth_moment = self.evaluate( - running_moments.finalize(state)) + running_moments.moments()) self.assertAllClose(tf.zeros((2, 2)), zeroth_moment, rtol=1e-6) self.assertAllClose(tf.ones((2, 2)) * 2, var, rtol=1e-6) self.assertAllClose(tf.zeros((2, 2)), skew, rtol=1e-6) @@ -638,66 +599,54 @@ def test_higher_rank_samples(self): def test_random_scalar_samples(self): rng = test_util.test_np_rng() x = rng.rand(100) - running_moments = tfp.experimental.stats.RunningCentralMoments( + running_moments = tfp.experimental.stats.RunningCentralMoments.from_shape( shape=(), - moment=np.arange(5) + 1 - ) - state = running_moments.initialize() + moment=np.arange(5) + 1) for sample in x: - state = running_moments.update(state, sample) - moments = self.evaluate(running_moments.finalize(state)) + running_moments = running_moments.update(sample) + moments = self.evaluate(running_moments.moments()) self.assertAllClose( stats.moment(x, moment=[1, 2, 3, 4, 5]), moments, rtol=1e-6) def test_random_higher_rank_samples(self): rng = test_util.test_np_rng() x = rng.rand(100, 10) - running_moments = tfp.experimental.stats.RunningCentralMoments( + running_moments = tfp.experimental.stats.RunningCentralMoments.from_shape( shape=(10,), - moment=np.arange(5) + 1 - ) - state = running_moments.initialize() + moment=np.arange(5) + 1) for sample in x: - state = running_moments.update(state, sample) - moments = self.evaluate(running_moments.finalize(state)) + running_moments = running_moments.update(sample) + moments = self.evaluate(running_moments.moments()) self.assertAllClose( stats.moment(x, moment=[1, 2, 3, 4, 5]), moments, rtol=1e-6) def test_manual_dtype(self): - running_moments = tfp.experimental.stats.RunningCentralMoments( + running_moments = tfp.experimental.stats.RunningCentralMoments.from_shape( shape=(), moment=1, - dtype=tf.float64 - ) - state = running_moments.initialize() - state = running_moments.update(state, 0) - moment = running_moments.finalize(state) + dtype=tf.float64) + running_moments = running_moments.update(0) + moment = running_moments.moments() self.assertEqual(tf.float64, moment.dtype) def test_int_dtype_casts(self): - running_moments = tfp.experimental.stats.RunningCentralMoments( + running_moments = tfp.experimental.stats.RunningCentralMoments.from_shape( shape=(), moment=1, - dtype=tf.int32 - ) - state = running_moments.initialize() - state = running_moments.update(state, 0) - moment = running_moments.finalize(state) + dtype=tf.int32) + running_moments = running_moments.update(0) + moment = running_moments.moments() self.assertEqual(tf.float32, moment.dtype) def test_in_tf_while(self): - running_moments = tfp.experimental.stats.RunningCentralMoments( - shape=(), - moment=np.arange(4) + 1 - ) - state = running_moments.initialize() - _, state = tf.while_loop( + running_moments = tfp.experimental.stats.RunningCentralMoments.from_shape( + shape=(), moment=[1, 2, 3, 4]) + _, running_moments = tf.while_loop( lambda i, _: i < 5, - lambda i, st: (i + 1, running_moments.update(st, tf.ones(()) * i)), - (0., state) + lambda i, mom: (i + 1, mom.update(tf.ones(()) * i)), + (0., running_moments) ) - moments = self.evaluate( - running_moments.finalize(state)) + moments = self.evaluate(running_moments.moments()) self.assertAllClose( stats.moment(np.arange(5), moment=np.arange(4) + 1), moments, diff --git a/tensorflow_probability/python/internal/backend/numpy/gen/tensor_shape.py b/tensorflow_probability/python/internal/backend/numpy/gen/tensor_shape.py index f328fd432b..d918546a0e 100755 --- a/tensorflow_probability/python/internal/backend/numpy/gen/tensor_shape.py +++ b/tensorflow_probability/python/internal/backend/numpy/gen/tensor_shape.py @@ -950,8 +950,28 @@ def num_elements(self): def merge_with(self, other): """Returns a `TensorShape` combining the information in `self` and `other`. - The dimensions in `self` and `other` are merged elementwise, - according to the rules defined for `Dimension.merge_with()`. + The dimensions in `self` and `other` are merged element-wise, + according to the rules below: + + ```python + Dimension(n).merge_with(Dimension(None)) == Dimension(n) + Dimension(None).merge_with(Dimension(n)) == Dimension(n) + Dimension(None).merge_with(Dimension(None)) == Dimension(None) + # raises ValueError for n != m + Dimension(n).merge_with(Dimension(m)) + ``` + >> ts = tf.TensorShape([1,2]) + >> ot1 = tf.TensorShape([1,2]) + >> ts.merge_with(ot).as_list() + [1,2] + + >> ot2 = tf.TensorShape([1,None]) + >> ts.merge_with(ot2).as_list() + [1,2] + + >> ot3 = tf.TensorShape([None, None]) + >> ot3.merge_with(ot2).as_list() + [1, None] Args: other: Another `TensorShape`. @@ -1226,7 +1246,50 @@ def as_proto(self): ]) def __eq__(self, other): - """Returns True if `self` is equivalent to `other`.""" + """Returns True if `self` is equivalent to `other`. + + It first tries to convert `other` to `TensorShape`. `TypeError` is thrown + when the conversion fails. Otherwise, it compares each element in the + TensorShape dimensions. + + * Two *Fully known* shapes, return True iff each element is equal. + >>> t_a = tf.TensorShape([1,2]) + >>> a = [1, 2] + >>> t_b = tf.TensorShape([1,2]) + >>> t_c = tf.TensorShape([1,2,3]) + >>> t_a.__eq__(a) + True + >>> t_a.__eq__(t_b) + True + >>> t_a.__eq__(t_c) + False + + * Two *Partially-known* shapes, return False. + >>> p_a = tf.TensorShape([1,None]) + >>> p_b = tf.TensorShape([2,None]) + >>> p_a.__eq__(p_b) + False + >>> t_a.__eq__(p_a) + False + + * Two *Unknown shape*, return True. + >>> unk_a = tf.TensorShape(None) + >>> unk_b = tf.TensorShape(None) + >>> unk_a.__eq__(unk_b) + True + >>> unk_a.__eq__(t_a) + False + + Args: + other: A `TensorShape` or type that can be converted to `TensorShape`. + + Returns: + True if the dimensions are all equal. + + Raises: + TypeError if `other` can not be converted to `TensorShape`. + """ + try: other = as_shape(other) except TypeError: diff --git a/tensorflow_probability/python/internal/backend/numpy/linalg_impl.py b/tensorflow_probability/python/internal/backend/numpy/linalg_impl.py index 118d434b15..1be19332e9 100644 --- a/tensorflow_probability/python/internal/backend/numpy/linalg_impl.py +++ b/tensorflow_probability/python/internal/backend/numpy/linalg_impl.py @@ -161,9 +161,9 @@ def _lu(input, output_idx_type=np.int32, name=None): # pylint: disable=redefine input = ops.convert_to_tensor(input) if JAX_MODE: # JAX uses XLA, which can do a batched factorization. lu_out, pivots = scipy_linalg.lu_factor(input) - from jax import lax_linalg # pylint: disable=g-import-not-at-top + from jax import lax # pylint: disable=g-import-not-at-top return Lu(lu_out, - lax_linalg.lu_pivots_to_permutation(pivots, lu_out.shape[-1])) + lax.linalg.lu_pivots_to_permutation(pivots, lu_out.shape[-1])) # Scipy can't batch, so we must do so manually. nbatch = int(np.prod(input.shape[:-2])) dim = input.shape[-1] diff --git a/tensorflow_probability/python/internal/hypothesis_testlib.py b/tensorflow_probability/python/internal/hypothesis_testlib.py index 952c0526cb..3188944703 100644 --- a/tensorflow_probability/python/internal/hypothesis_testlib.py +++ b/tensorflow_probability/python/internal/hypothesis_testlib.py @@ -686,6 +686,7 @@ def no_tf_rank_errors(): r'rank > (8|9|[1-9][0-9]+).') does_not_work_pat = (r'does not work on tensors with ' r'more than (8|9|[1-9][0-9]+) dimensions') + only_support_pat = r'only support up to 7 input dimensions' pat_1 = _rank_broadcasting_error_pattern(1, 6) pat_2 = _rank_broadcasting_error_pattern(6, 1) try: @@ -710,6 +711,14 @@ def no_tf_rank_errors(): hp.assume(False) else: raise + except tf.errors.InvalidArgumentError as e: + msg = str(e) + if re.search(only_support_pat, msg): + # We asked some TF op (Argmin/Argmax/...) to operate on a Tensor of + # rank >= 8. + hp.assume(False) + else: + raise except ValueError as e: msg = str(e) if re.search(does_not_work_pat, msg): diff --git a/tensorflow_probability/python/internal/test_util.py b/tensorflow_probability/python/internal/test_util.py index e2b3e8fb77..1efaf2dfc2 100644 --- a/tensorflow_probability/python/internal/test_util.py +++ b/tensorflow_probability/python/internal/test_util.py @@ -251,6 +251,16 @@ def assertNotAllZero(self, a): """ self.assertNotAllEqual(a, tf.nest.map_structure(tf.zeros_like, a)) + def assertAllNotNan(self, a): + """Assert that every entry in a `Tensor` is not NaN. + + Args: + a: A `Tensor` whose entries must be verified as not NaN. + """ + is_not_nan = ~np.isnan(self._GetNdArray(a)) + all_true = np.ones_like(is_not_nan, dtype=np.bool) + self.assertAllEqual(all_true, is_not_nan) + def assertAllNan(self, a): """Assert that every entry in a `Tensor` is NaN. diff --git a/tensorflow_probability/python/math/BUILD b/tensorflow_probability/python/math/BUILD index 53378839bd..8e3646f0bb 100644 --- a/tensorflow_probability/python/math/BUILD +++ b/tensorflow_probability/python/math/BUILD @@ -45,6 +45,7 @@ multi_substrate_py_library( ":generic", ":gradient", ":gram_schmidt", + ":hypergeometric", ":interpolation", ":linalg", ":minimize", @@ -175,6 +176,8 @@ multi_substrate_py_library( srcs_version = "PY3", deps = [ # tensorflow dep, + "//tensorflow_probability/python/internal:tensor_util", + "//tensorflow_probability/python/internal/backend/numpy:tf_inspect", ], ) @@ -190,6 +193,35 @@ multi_substrate_py_test( ], ) +multi_substrate_py_library( + name = "hypergeometric", + srcs = [ + "hypergeometric.py", + ], + srcs_version = "PY2AND3", + deps = [ + # numpy dep, + # tensorflow dep, + "//tensorflow_probability/python/internal:dtype_util", + "//tensorflow_probability/python/internal:prefer_static", + ], +) + +multi_substrate_py_test( + name = "hypergeometric_test", + size = "medium", + srcs = ["hypergeometric_test.py"], + shard_count = 3, + deps = [ + # absl/testing:parameterized dep, + # numpy dep, + # tensorflow dep, + "//tensorflow_probability", + "//tensorflow_probability/python/internal:test_util", + "//tensorflow_probability/python/math:hypergeometric", + ], +) + multi_substrate_py_library( name = "interpolation", srcs = [ diff --git a/tensorflow_probability/python/math/__init__.py b/tensorflow_probability/python/math/__init__.py index bd2375b1a8..d430404dde 100644 --- a/tensorflow_probability/python/math/__init__.py +++ b/tensorflow_probability/python/math/__init__.py @@ -59,6 +59,8 @@ from tensorflow_probability.python.math.minimize import MinimizeTraceableQuantities from tensorflow_probability.python.math.numeric import clip_by_value_preserve_gradient from tensorflow_probability.python.math.numeric import log1psquare +from tensorflow_probability.python.math.root_search import find_root_chandrupatla +from tensorflow_probability.python.math.root_search import find_root_secant from tensorflow_probability.python.math.root_search import secant_root from tensorflow_probability.python.math.scan_associative import scan_associative from tensorflow_probability.python.math.sparse import dense_to_sparse @@ -93,6 +95,8 @@ 'dense_to_sparse', 'diag_jacobian', 'erfcinv', + 'find_root_chandrupatla', + 'find_root_secant', 'fill_triangular', 'fill_triangular_inverse', 'gram_schmidt', diff --git a/tensorflow_probability/python/math/bessel.py b/tensorflow_probability/python/math/bessel.py index 1e268f9268..f554dd3f05 100644 --- a/tensorflow_probability/python/math/bessel.py +++ b/tensorflow_probability/python/math/bessel.py @@ -369,7 +369,7 @@ def grad(dy): ] -def _olver_asymptotic_uniform(v, z, name=None): +def _olver_asymptotic_uniform(v, z, output_log_space=False, name=None): """Use Olver's uniform asymptotic expansion for the Bessel function. Olver's uniform asymptotic expansion [1] is specified by @@ -391,6 +391,8 @@ def _olver_asymptotic_uniform(v, z, name=None): Args: v: value for which `I_{v}(z)` and `K_{v}(z) should be computed. z: value for which `I_{v}(z)` and `K_{v}(z) should be computed. + output_log_space: `bool`. If `True`, output is in log-space. + Default value: `False`. name: A name for the operation (optional). Default value: `None` (i.e., 'olver_asymptotic_uniform'). Returns: @@ -423,22 +425,35 @@ def _olver_asymptotic_uniform(v, z, name=None): # since we are subtracting off x. shared_prefactor = 1. / (tf.math.sqrt( 1 + tf.math.square(w)) + w) + tf.math.log(w / (1 + 1. / t)) - i_prefactor = tf.math.sqrt( - t / (2 * np.pi * v_abs)) * tf.math.exp(v_abs * shared_prefactor) + log_i_prefactor = 0.5 * tf.math.log( + t / (2 * np.pi * v_abs)) + v_abs * shared_prefactor # Not the same here since they will have the same sign. - k_prefactor = tf.math.sqrt(np.pi * t / (2 * v_abs)) * tf.math.exp( - -v_abs * shared_prefactor) - kve = k_prefactor * kve_sum - - ive = tf.where( - v > 0., - i_prefactor * ive_sum, - # This uses the reflection formulation for negative v, to - # write this in terms of kve. - i_prefactor * ive_sum + 2 / np.pi * tf.math.sin( - np.pi * v_abs) * k_prefactor * kve_sum * tf.math.exp(-2. * z)) - return ive, kve + log_k_prefactor = 0.5 * tf.math.log( + np.pi * t / (2 * v_abs)) - v_abs * shared_prefactor + + log_kve = log_k_prefactor + tf.math.log(kve_sum) + log_ive = log_i_prefactor + tf.math.log(ive_sum) + + # We need to add a correction term for negative v. + negative_v_correction = log_kve - 2. * z + n = tf.math.round(v) + u = v - n + coeff = 2 / np.pi * tf.math.sin(np.pi * u) + coeff = (1. - 2. * tf.math.mod(n, 2.)) * coeff + ive_negative_v = tf.where( + log_ive > negative_v_correction, + tf.math.exp(log_ive + tf.math.log1p( + coeff * tf.math.exp(negative_v_correction - log_ive))), + tf.math.exp(negative_v_correction) * (tf.math.exp( + log_ive - negative_v_correction) + coeff)) + + ive = tf.where(v > 0., tf.math.exp(log_ive), ive_negative_v) + if output_log_space: + log_ive = tf.where( + v > 0., log_ive, tf.math.log(tf.math.abs(ive_negative_v))) + return log_ive, log_kve + return ive, tf.math.exp(log_kve) def _evaluate_temme_coeffs(v): @@ -491,7 +506,7 @@ def _evaluate_temme_coeffs(v): def _temme_series(v, z): - """Computes K(v, z) and K(v + 1., z) via Power series expansion.""" + """Computes Kve(v, z) and Kve(v + 1., z) via Power series expansion.""" # This is based on: # [1] N. Temme, On the Numerical Evaluation of the Modified Bessel Function # of the Third Kind. Journal of Computational Physics 19, 1975. @@ -556,10 +571,12 @@ def body_fn(should_stop, index, f, p, q, coeff, kv_sum, kvp1_sum): initial_f, initial_p)) - return kv_sum, 2 * kvp1_sum / z + log_kve = tf.math.log(kv_sum) + z + log_kvep1 = tf.math.log(2. * kvp1_sum) + z - tf.math.log(z) + return tf.math.exp(log_kve), tf.math.exp(log_kvep1) -def _continued_fraction_kv(v, z): +def _continued_fraction_kv(v, z, output_log_space=False): """Compute Modified Bessels of Second Kind using Hypergeometric functions. First define `k_n(z) = (-1)**n U(v + n + 0.5, 2 * v + 1., 2 * z)` where @@ -567,7 +584,7 @@ def _continued_fraction_kv(v, z): We can compute via [1] `K_v(z)` and `K_{v + 1}(z)` via the identities: - `K_v(z) = sqrt(pi) * (2 * z) ** v * ezp(-z) * k_0(z)`, + `K_v(z) = sqrt(pi) * (2 * z) ** v * exp(-z) * k_0(z)`, `K_{v + 1}(z) = K_v(z) * (v + z + 0.5 - k_1(z) / k_0(z)`, This function aims to compute the ratio `k_1(z) / k_0(z)` via @@ -577,6 +594,8 @@ def _continued_fraction_kv(v, z): Args: v: Floating-point `Tensor` broadcastable with `z`. z: Floating-point `Tensor` broadcastable with `v`. + output_log_space: `bool`. If `True`, output is in log-space. + Default value: `False`. Returns: kv_tuple: `K_v(z)` and `K_{v + 1}(z)`. @@ -669,9 +688,12 @@ def steeds_algorithm( initial_seq, 1 - initial_numerator * initial_ratio)) - kv = tf.math.sqrt(np.pi / (2 * z)) * tf.math.exp(-z) / hypergeometric_sum - kvp1 = kv * (0.5 + v + z + initial_numerator * hypergeometric_ratio) / z - return kv, kvp1 + log_kve = 0.5 * tf.math.log(np.pi / (2 * z)) - tf.math.log(hypergeometric_sum) + log_kvp1e = log_kve + tf.math.log( + 0.5 + v + z + initial_numerator * hypergeometric_ratio) - tf.math.log(z) + if output_log_space: + return log_kve, log_kvp1e + return tf.math.exp(log_kve), tf.math.exp(log_kvp1e) def _temme_expansion(v, x): @@ -690,53 +712,66 @@ def _temme_expansion(v, x): small_x = tf.where(x_abs <= 2., x_abs, numpy_dtype(0.1)) large_x = tf.where(x_abs > 2., x_abs, numpy_dtype(1000.)) - temme_ku, temme_kup1 = _temme_series(u, small_x) - cf_ku, cf_kup1 = _continued_fraction_kv(u, large_x) + temme_kue, temme_kuep1 = _temme_series(u, small_x) + cf_kue, cf_kuep1 = _continued_fraction_kv(u, large_x) - ku = tf.where(x_abs <= 2., temme_ku, cf_ku) - kup1 = tf.where(x_abs <= 2., temme_kup1, cf_kup1) + kue = tf.where(x_abs <= 2., temme_kue, cf_kue) + kuep1 = tf.where(x_abs <= 2., temme_kuep1, cf_kuep1) # Now use the forward recurrence for modified bessel functions # to compute Kv(v, x). That is, # K_{v + 1}(z) - (2v / z) K_v(z) - K_{v - 1}(z) = 0. # This is known to be forward numerically stable. + # Note: This recurrence is also satisfied by K_v(z) * exp(z) - def bessel_recurrence(index, kv, kvp1): - next_kvp1 = 2 * (u + index) * kvp1 / x_abs + kv - kv = tf.where(index > n, kv, kvp1) - kvp1 = tf.where(index > n, kvp1, next_kvp1) - return index + 1., kv, kvp1 + def bessel_recurrence(index, kve, kvep1): + next_kvep1 = 2 * (u + index) * kvep1 / x_abs + kve + kve = tf.where(index > n, kve, kvep1) + kvep1 = tf.where(index > n, kvep1, next_kvep1) + return index + 1., kve, kvep1 - _, kv, kvp1 = tf.while_loop( + _, kve, kvep1 = tf.while_loop( cond=lambda i, *_: tf.reduce_any(i <= n), body=bessel_recurrence, - loop_vars=(tf.cast(1., dtype=dtype), ku, kup1)) + loop_vars=(tf.cast(1., dtype=dtype), kue, kuep1)) # Finally, it is known that the Wronskian # det(I_v * K'_v - K_v * I'_v) = - 1. / x. We can # use this to evaluate I_v by taking advantage of identities of Bessel # derivatives. - iv = tf.math.reciprocal( - x_abs * (kv * bessel_iv_ratio(v + 1., x) + kvp1)) + ive = tf.math.reciprocal( + x_abs * (kve * bessel_iv_ratio(v + 1., x) + kvep1)) + + # We need to add a correction term for negative v. + negative_v_correction = tf.math.log(kve) - 2. * x_abs + coeff = 2 / np.pi * tf.math.sin(np.pi * u) + coeff = (1. - 2. * tf.math.mod(n, 2.)) * coeff + log_ive = tf.math.log(ive) + + ive_negative_v = tf.where( + log_ive > negative_v_correction, + tf.math.exp(log_ive + tf.math.log1p( + coeff * tf.math.exp(negative_v_correction - log_ive))), + tf.math.exp(negative_v_correction) * (tf.math.exp( + log_ive - negative_v_correction) + coeff)) + + ive = tf.where(v_less_than_zero, ive_negative_v, ive) + z = u + tf.math.mod(n, 2.) - iv = tf.where( - v_less_than_zero, - iv + 2. / np.pi * tf.math.sin(np.pi * z) * kv, iv) - iv = tf.where( + ive = tf.where( tf.math.equal(x, 0.), - tf.where(tf.math.equal(v, 0.), numpy_dtype(1.), numpy_dtype(0.)), iv) - iv = tf.where(tf.math.equal(x, 0.) & v_less_than_zero, - tf.where( - tf.math.equal(z, tf.math.floor(z)), - iv, - numpy_dtype(np.inf)), - iv) - kv = tf.where(tf.math.equal(x, 0.), numpy_dtype(np.inf), kv) - iv = tf.where(x < 0., numpy_dtype(np.nan), iv) - kv = tf.where(x < 0., numpy_dtype(np.nan), kv) - return iv, kv + tf.where(tf.math.equal(v, 0.), numpy_dtype(1.), numpy_dtype(0.)), ive) + ive = tf.where(tf.math.equal(x, 0.) & v_less_than_zero, + tf.where( + tf.math.equal(z, tf.math.floor(z)), + ive, + numpy_dtype(np.inf)), ive) + kve = tf.where(tf.math.equal(x, 0.), numpy_dtype(np.inf), kve) + ive = tf.where(x < 0., numpy_dtype(np.nan), ive) + kve = tf.where(x < 0., numpy_dtype(np.nan), kve) + return ive, kve @tf.custom_gradient @@ -782,7 +817,7 @@ def bessel_ive(v, z, name=None): large_v = tf.where(tf.math.abs(v_abs) >= 50., v_abs, numpy_dtype(1000.)) olver_ive, _ = _olver_asymptotic_uniform(large_v, z_abs) - temme_ive = _temme_expansion(small_v, z_abs)[0] * tf.math.exp(-z_abs) + temme_ive = _temme_expansion(small_v, z_abs)[0] ive = tf.where(tf.math.abs(v) >= 50., olver_ive, temme_ive) # Handle when z is zero. @@ -873,7 +908,7 @@ def bessel_kve(v, z, name=None): large_v = tf.where(v >= 50., v, numpy_dtype(1000.)) _, olver_kve = _olver_asymptotic_uniform(large_v, z_abs) - temme_kve = _temme_expansion(small_v, z_abs)[1] * tf.math.exp(z_abs) + temme_kve = _temme_expansion(small_v, z_abs)[1] kve = tf.where(v >= 50., olver_kve, temme_kve) # Handle when z is zero. diff --git a/tensorflow_probability/python/math/bessel_test.py b/tensorflow_probability/python/math/bessel_test.py index 2d332a3c1e..49d3ade751 100644 --- a/tensorflow_probability/python/math/bessel_test.py +++ b/tensorflow_probability/python/math/bessel_test.py @@ -81,7 +81,7 @@ def testBesselIvRatioVAndZSmall(self): # the computation become numerically unstable. # Anecdotally (when comparing to mpmath) the computation is more often # 'right' compared to the naive ratio. - self.VerifyBesselIvRatio(v, z, rtol=2e-4) + self.VerifyBesselIvRatio(v, z, rtol=3e-4) def testBesselIvRatioVAndZMedium(self): seed_stream = test_util.test_seed_stream() @@ -119,7 +119,7 @@ def testBesselIvRatioVLessThanZ(self): z = tf.random.uniform([int(1e5)], 1., 10., seed=seed_stream()) # Make v randomly less than z v = z * tf.random.uniform([int(1e5)], 0.1, 0.5, seed=seed_stream()) - self.VerifyBesselIvRatio(v, z, rtol=1e-6) + self.VerifyBesselIvRatio(v, z, rtol=6e-6) def testBesselIvRatioVGreaterThanZ(self): seed_stream = test_util.test_seed_stream() @@ -269,13 +269,14 @@ def testBesselIveVGreaterThanZ(self, dtype, rtol): self.VerifyBesselIve(v, z, rtol=rtol) @parameterized.named_parameters( - ("float32", np.float32, 2e-3), - ("float64", np.float64, 8e-4), + ("float32", np.float32, 1e-4), + ("float64", np.float64, 1e-6), ) def testBesselIveVNegative(self, dtype, rtol): seed_stream = test_util.test_seed_stream() - v = tf.random.uniform([int(1e5)], -10., -1., seed=seed_stream()) - z = tf.random.uniform([int(1e5)], 1., 15., seed=seed_stream()) + v = tf.random.uniform( + [int(1e5)], -10., -1., seed=seed_stream(), dtype=dtype) + z = tf.random.uniform([int(1e5)], 1., 15., seed=seed_stream(), dtype=dtype) self.VerifyBesselIve(v, z, rtol=rtol) @parameterized.named_parameters( @@ -288,6 +289,18 @@ def testBesselIveVZero(self, dtype, rtol): z = tf.random.uniform([int(1e5)], 1., 15., seed=seed_stream(), dtype=dtype) self.VerifyBesselIve(v, z, rtol=rtol) + @parameterized.named_parameters( + ("float32", np.float32, 1e-6), + ("float64", np.float64, 1e-6), + ) + def testBesselIveLargeZ(self, dtype, rtol): + seed_stream = test_util.test_seed_stream() + v = tf.random.uniform( + [int(1e5)], minval=0., maxval=0.5, seed=seed_stream(), dtype=dtype) + z = tf.random.uniform( + [int(1e5)], minval=100., maxval=10000., seed=seed_stream(), dtype=dtype) + self.VerifyBesselIve(v, z, rtol=rtol) + @test_util.numpy_disable_gradient_test @test_util.jax_disable_test_missing_functionality( "Relies on Tensorflow gradient_checker") @@ -296,7 +309,7 @@ def testBesselIveVZero(self, dtype, rtol): ("float64", np.float64)) def testBesselIveGradient(self, dtype): v = tf.constant([-1., 0.5, 1., 10., 20.], dtype=dtype)[..., tf.newaxis] - z = tf.constant([0.1, 0.5, 0.9, 1., 12., 14., 22.], dtype=dtype) + z = tf.constant([0.2, 0.5, 0.9, 1., 12., 14., 22.], dtype=dtype) err = self.compute_max_gradient_error( functools.partial(tfp_math.bessel_ive, v), [z]) @@ -310,7 +323,7 @@ def testBesselIveGradient(self, dtype): ("float64", np.float64)) def testBesselIveNegativeGradient(self, dtype): v = tf.constant([1., 10., 20.], dtype=dtype)[..., tf.newaxis] - z = tf.constant([-.1, -2.5, -3.5, -5.], dtype=dtype) + z = tf.constant([-.2, -2.5, -3.5, -5.], dtype=dtype) err = self.compute_max_gradient_error( functools.partial(tfp_math.bessel_ive, v), [z]) @@ -320,15 +333,15 @@ def testBesselIveNegativeGradient(self, dtype): @test_util.jax_disable_test_missing_functionality( "Relies on Tensorflow gradient_checker") @parameterized.named_parameters( - ("float32", np.float32), - ("float64", np.float64)) - def testLogBesselIveGradient(self, dtype): + ("float32", np.float32, 1e-3), + ("float64", np.float64, 1e-4)) + def testLogBesselIveGradient(self, dtype, tol): v = tf.constant([-0.2, -1., 1., 0.5, 2.], dtype=dtype)[..., tf.newaxis] - z = tf.constant([0.1, 0.5, 0.9, 1., 12., 22.], dtype=dtype) + z = tf.constant([0.3, 0.5, 0.9, 1., 12., 22.], dtype=dtype) err = self.compute_max_gradient_error( functools.partial(tfp_math.log_bessel_ive, v), [z]) - self.assertLess(err, 8e-4) + self.assertLess(err, tol) class BesselKveTest(test_util.TestCase): @@ -404,8 +417,8 @@ def testBesselKveVAndZSmall(self, dtype, rtol): self.VerifyBesselKve(v, z, rtol=rtol) @parameterized.named_parameters( - ("float32", np.float32, 3e-5), - ("float64", np.float64, 2e-5), + ("float32", np.float32, 1.5e-5), + ("float64", np.float64, 1.2e-5), ) def testBesselKveVAndZMedium(self, dtype, rtol): seed_stream = test_util.test_seed_stream() @@ -414,7 +427,7 @@ def testBesselKveVAndZMedium(self, dtype, rtol): self.VerifyBesselKve(v, z, rtol=rtol) @parameterized.named_parameters( - ("float32", np.float32, 2e-5), + ("float32", np.float32, 2e-6), ("float64", np.float64, 1e-6), ) def testBesselKveVAndZLarge(self, dtype, rtol): @@ -424,8 +437,8 @@ def testBesselKveVAndZLarge(self, dtype, rtol): self.VerifyBesselKve(v, z, rtol=rtol) @parameterized.named_parameters( - ("float32", np.float32, 2e-5), - ("float64", np.float64, 2e-5), + ("float32", np.float32, 1.5e-5), + ("float64", np.float64, 1.5e-5), ) def testBesselKveVLessThanZ(self, dtype, rtol): seed_stream = test_util.test_seed_stream() @@ -436,7 +449,7 @@ def testBesselKveVLessThanZ(self, dtype, rtol): self.VerifyBesselKve(v, z, rtol=rtol) @parameterized.named_parameters( - ("float32", np.float32, 2e-5), + ("float32", np.float32, 3e-6), ("float64", np.float64, 3e-6), ) def testBesselKveVGreaterThanZ(self, dtype, rtol): @@ -448,13 +461,26 @@ def testBesselKveVGreaterThanZ(self, dtype, rtol): self.VerifyBesselKve(v, z, rtol=rtol) @parameterized.named_parameters( - ("float32", np.float32, 2e-5), - ("float64", np.float64, 2e-5), + ("float32", np.float32, 1.5e-5), + ("float64", np.float64, 1.5e-5), ) def testBesselKveVNegative(self, dtype, rtol): seed_stream = test_util.test_seed_stream() - v = tf.random.uniform([int(1e5)], -10., -1., seed=seed_stream()) - z = tf.random.uniform([int(1e5)], 1., 15., seed=seed_stream()) + v = tf.random.uniform( + [int(1e5)], -10., -1., seed=seed_stream(), dtype=dtype) + z = tf.random.uniform([int(1e5)], 1., 15., seed=seed_stream(), dtype=dtype) + self.VerifyBesselKve(v, z, rtol=rtol) + + @parameterized.named_parameters( + ("float32", np.float32, 1e-6), + ("float64", np.float64, 1e-6), + ) + def testBesselKveLargeZ(self, dtype, rtol): + seed_stream = test_util.test_seed_stream() + v = tf.random.uniform( + [int(1e5)], minval=0., maxval=0.5, seed=seed_stream(), dtype=dtype) + z = tf.random.uniform( + [int(1e5)], minval=100., maxval=10000., seed=seed_stream(), dtype=dtype) self.VerifyBesselKve(v, z, rtol=rtol) @test_util.numpy_disable_gradient_test diff --git a/tensorflow_probability/python/math/gradient.py b/tensorflow_probability/python/math/gradient.py index f7a202e1e1..9b609d3ff9 100644 --- a/tensorflow_probability/python/math/gradient.py +++ b/tensorflow_probability/python/math/gradient.py @@ -18,114 +18,357 @@ from __future__ import division from __future__ import print_function - import tensorflow.compat.v2 as tf +from tensorflow_probability.python.internal import tensor_util +from tensorflow_probability.python.internal.backend.numpy import tf_inspect + __all__ = [ 'value_and_gradient', ] -def _prepare_args(xs): - """Returns a `list` and a `bool` indicating whether args started list-like.""" - is_list_like = isinstance(xs, (tuple, list)) - if not is_list_like: - xs = [xs] - xs = [ - tf.convert_to_tensor(x, dtype_hint=tf.float32, name='x{}'.format(i)) - for i, x in enumerate(xs) - ] - return xs, is_list_like - - def value_and_gradient(f, - xs, + *args, output_gradients=None, use_gradient_tape=False, - name=None): - """Computes `f(*xs)` and its gradients wrt to `*xs`. + auto_unpack_single_arg=True, + name=None, + **kwargs): + """Computes `f(*args, **kwargs)` and its gradients wrt to `args`, `kwargs`. + + The function `f` is invoked according to one of the following rules: + + 1. If `f` is a function of no arguments then it is called as `f()`. + + 2. If `len(args) == 1`, `len(kwargs) == 0`, `auto_unpack_single_arg == True` + and `isinstance(args[0], (list, tuple))` then `args` is presumed to be a + packed sequence of args, i.e., the function is called as `f(*args[0])`. + + 3. Otherwise, the function is called as `f(*args, **kwargs)`. + + Regardless of how `f` is called, gradients are computed with respect to `args` + and `kwargs`. + + #### Examples + + ```python + tfd = tfp.distributions + tfm = tfp.math + + # Case 1: argless `f`. + x = tf.constant(2.) + tfm.value_and_gradient(lambda: tf.math.log(x), x) + # ==> [log(2.), 0.5] + + # Case 2: packed arguments. + tfm.value_and_gradient(lambda x, y: x * tf.math.log(y), [2., 3.]) + # ==> [2. * np.log(3.), (np.log(3.), 2. / 3)] + + # Case 3: default. + tfm.value_and_gradient(tf.math.log, [1., 2., 3.], + auto_unpack_single_arg=False) + # ==> [(log(1.), log(2.), log(3.)), (1., 0.5, 0.333)] + ``` Args: f: Python `callable` to be differentiated. If `f` returns a scalar, this scalar will be differentiated. If `f` returns a tensor or list of tensors, - by default a scalar will be computed by adding all their values to produce - a single scalar. If desired, the tensors can be elementwise multiplied by - the tensors passed as the `dy` keyword argument to the returned gradient - function. - xs: Python list of parameters of `f` for which to differentiate. (Can also - be single `Tensor`.) - output_gradients: A `Tensor` or list of `Tensor`s the same size as the - result `ys = f(*xs)` and holding the gradients computed for each `y` in - `ys`. This argument is forwarded to the underlying gradient implementation - (i.e., either the `grad_ys` argument of `tf.gradients` or the - `output_gradients` argument of `tf.GradientTape.gradient`). + the gradient will be the sum of the gradients of each part. If desired the + sum can be weighted by `output_gradients` (see below). + *args: Arguments as in `f(*args, **kwargs)` and basis for gradient. + output_gradients: A `Tensor` or structure of `Tensor`s the same size as the + result `ys = f(*args, **kwargs)` and holding the gradients computed for + each `y` in `ys`. This argument is forwarded to the underlying gradient + implementation (i.e., either the `grad_ys` argument of `tf.gradients` or + the `output_gradients` argument of `tf.GradientTape.gradient`). + Default value: `None`. use_gradient_tape: Python `bool` indicating that `tf.GradientTape` should be - used regardless of `tf.executing_eagerly()` status. + used rather than `tf.gradient` and regardless of `tf.executing_eagerly()`. + (It is only possible to use `tf.gradient` when `not use_gradient_tape and + not tf.executing_eagerly()`.) Default value: `False`. + auto_unpack_single_arg: Python `bool` which when `False` means the single + arg case will not be interpreted as a list of arguments. (See case 2.) + Default value: `True`. name: Python `str` name prefixed to ops created by this function. Default value: `None` (i.e., `'value_and_gradient'`). + **kwargs: Named arguments as in `f(*args, **kwargs)` and basis for gradient. Returns: - y: `y = f(*xs)`. - dydx: Gradient of `y` wrt each of `xs`. + y: `y = f(*args, **kwargs)`. + dydx: Gradients of `y` with respect to each of `args` and `kwargs`. """ with tf.name_scope(name or 'value_and_gradient'): - xs, is_xs_list_like = _prepare_args(xs) - if tf.executing_eagerly() or use_gradient_tape: - with tf.GradientTape(watch_accessed_variables=False) as tape: - for x in xs: - tape.watch(x) - y = f(*xs) - dydx = tape.gradient(y, xs, output_gradients=output_gradients) - else: - y = f(*xs) - dydx = tf.gradients(ys=y, xs=xs, grad_ys=output_gradients) - if not is_xs_list_like: - dydx = dydx[0] - return y, dydx + return _value_and_grad_impl( + f, + _gradient_new if tf.executing_eagerly() or use_gradient_tape else + _gradient_old, + *args, + output_gradients=output_gradients, + auto_unpack_single_arg=auto_unpack_single_arg, + expand_tf_modules_as_trainable_vars=False, + **kwargs) + +def value_and_gradient_with_auto_expansion(f, + *args, + output_gradients=None, + use_gradient_tape=False, + auto_unpack_single_arg=True, + name=None, + **kwargs): + """Computes `f(*args, **kwargs)` and its gradients wrt to `args`, `kwargs`. -def value_and_batch_jacobian(f, xs): - """Computes the value and batch jacobian of `f(arg)` w.r.t. `arg`. + The function `f` is invoked according to one of the following rules: + + 1. If `f` is a function of no arguments then it is called as `f()`. + + 2. If `len(args) == 1`, `len(kwargs) == 0`, `auto_unpack_single_arg == True` + and `isinstance(args[0], (list, tuple))` then `args` is presumed to be a + packed sequence of args, i.e., the function is called as `f(*args[0])`. + + 3. Otherwise, the function is called as `f(*args, **kwargs)`. + + Regardless of how `f` is called, gradients are computed with respect to `args` + and `kwargs`. + + #### Examples + + ```python + tfd = tfp.distributions + tfm = tfp.math + + # Case 1: argless `f`. + x = tf.constant(2.) + tfm.value_and_gradient(lambda: tf.math.log(x), x) + # ==> [log(2.), 0.5] + + # Case 2: packed arguments. + tfm.value_and_gradient(lambda x, y: x * tf.math.log(y), [2., 3.]) + # ==> [2. * np.log(3.), (np.log(3.), 2. / 3)] + + # Case 3: default. + tfm.value_and_gradient(tf.math.log, [1., 2., 3.], + auto_unpack_single_arg=False) + # ==> [(log(1.), log(2.), log(3.)), (1., 0.5, 0.333)] + + # The following examples demonstrate computing gradients wrt trainable + # variables. + q = tfd.Normal(tf.Variable(1.), tf.Variable(1., trainable=False)) + r = tfd.Normal(0., tf.Variable(1.)) + tfm.value_and_gradient(tfd.kl_divergence, q, r) + # ==> 0.5, [[1.], [-1.]] + + # The following all produce the same numerical result as above. + tfm.value_and_gradient(lambda: tfd.kl_divergence(q, r), q, r) + tfm.value_and_gradient(lambda *_: tfd.kl_divergence(q, r), q, r) + tfm.value_and_gradient(lambda **kw: tfd.kl_divergence( + tfd.Normal(kw['loc_q'], 1), tfd.Normal(0, kw['scale_r'])), + loc_q=1., scale_r=1.) + + ``` Args: - f: Python callable, returning a 2D `(batch, n)` shaped `Tensor`. - xs: 2D `(batch, n)`-shaped argument `Tensor`(s). If multiple are provided, - a tuple of jacobians are returned. + f: Python `callable` to be differentiated. If `f` returns a scalar, this + scalar will be differentiated. If `f` returns a tensor or list of tensors, + the gradient will be the sum of the gradients of each part. If desired the + sum can be weighted by `output_gradients` (see below). + *args: Arguments as in `f(*args, **kwargs)` and basis for gradient. + output_gradients: A `Tensor` or structure of `Tensor`s the same size as the + result `ys = f(*args, **kwargs)` and holding the gradients computed for + each `y` in `ys`. This argument is forwarded to the underlying gradient + implementation (i.e., either the `grad_ys` argument of `tf.gradients` or + the `output_gradients` argument of `tf.GradientTape.gradient`). + Default value: `None`. + use_gradient_tape: Python `bool` indicating that `tf.GradientTape` should be + used rather than `tf.gradient` and regardless of `tf.executing_eagerly()`. + (It is only possible to use `tf.gradient` when `not use_gradient_tape and + not tf.executing_eagerly()`.) + Default value: `False`. + auto_unpack_single_arg: Python `bool` which when `False` means the single + arg case will not be interpreted as a list of arguments. (See case 2.) + Default value: `True`. + name: Python `str` name prefixed to ops created by this function. + Default value: `None` (i.e., `'value_and_gradient'`). + **kwargs: Named arguments as in `f(*args, **kwargs)` and basis for gradient. Returns: - value: The result of `f(xs)`. - jacobian: A `(batch, n, n)` shaped `Tensor`, `d f(xs) / d xs`, or a tuple - thereof. + y: `y = f(*args, **kwargs)`. + dydx: Gradients of `y` with respect to each of `args` and `kwargs`. """ - xs, is_xs_list_like = _prepare_args(xs) - with tf.GradientTape(persistent=True) as tape: - tape.watch(xs) - result = f(*xs) - try: - jacobian = tuple(tape.batch_jacobian(result, x) for x in xs) - except ValueError: # Fallback to for-loop jacobian. - jacobian = tuple( - tape.batch_jacobian(result, x, experimental_use_pfor=False) for x in xs) - if not is_xs_list_like: - jacobian = jacobian[0] - return result, jacobian + with tf.name_scope(name or 'value_and_gradient'): + return _value_and_grad_impl( + f, + _gradient_new if tf.executing_eagerly() or use_gradient_tape else + _gradient_old, + *args, + output_gradients=output_gradients, + auto_unpack_single_arg=auto_unpack_single_arg, + expand_tf_modules_as_trainable_vars=True, + **kwargs) + + +def value_and_batch_jacobian(f, + *args, + auto_unpack_single_arg=True, + name=None, + **kwargs): + """Computes `f(*args, **kwargs)` and batch Jacobian wrt to `args`, `kwargs`. + + Args: + f: Python `callable`, returning a 2D `(batch, n)` shaped `Tensor`. + *args: Arguments as in `f(*args, **kwargs)` and basis for Jacobian. Each + element must be 2D `(batch, n)`-shaped argument `Tensor`(s). If multiple + are provided, a tuple of jacobians are returned. + auto_unpack_single_arg: Python `bool` which when `False` means the single + arg case will not be interpreted as a list of arguments. + Default value: `True`. + name: Python `str` name prefixed to ops created by this function. + Default value: `None` (i.e., `'value_and_gradient'`). + **kwargs: Named arguments as in `f(*args, **kwargs)` and basis for Jacobian. + Each element must be 2D `(batch, n)`-shaped argument `Tensor`(s). If + multiple are provided, a tuple of jacobians are returned. + + Returns: + y: `y = f(*args, **kwargs)`. + jacobian: A `(batch, n, n)` shaped `Tensor`, `dy/dx`, or a tuple thereof. + """ + with tf.name_scope(name or 'value_and_batch_jacobian'): + return _value_and_grad_impl( + f, + _jacobian, + *args, + output_gradients=None, + auto_unpack_single_arg=auto_unpack_single_arg, + expand_tf_modules_as_trainable_vars=False, + **kwargs) -def batch_jacobian(f, xs): - """Computes the batch jacobian of `f(xs)` w.r.t. `xs`. +def batch_jacobian(f, + *args, + auto_unpack_single_arg=True, + name=None, + **kwargs): + """Computes batch Jacobian of `f(*args, **kwargs)` wrt to `args`, `kwargs`. Args: - f: Python callable, returning a 2D `(batch, n)` shaped `Tensor`. - xs: 2D `(batch, n)`-shaped argument `Tensor`(s). If multiple are provided, - a tuple of jacobians are returned. + f: Python `callable`, returning a 2D `(batch, n)` shaped `Tensor`. + *args: Arguments as in `f(*args, **kwargs)` and basis for Jacobian. Each + element must be 2D `(batch, n)`-shaped argument `Tensor`(s). If multiple + are provided, a tuple of jacobians are returned. + auto_unpack_single_arg: Python `bool` which when `False` means the single + arg case will not be interpreted as a list of arguments. + Default value: `True`. + name: Python `str` name prefixed to ops created by this function. + Default value: `None` (i.e., `'value_and_gradient'`). + **kwargs: Named arguments as in `f(*args, **kwargs)` and basis for Jacobian. + Each element must be 2D `(batch, n)`-shaped argument `Tensor`(s). If + multiple are provided, a tuple of jacobians are returned. Returns: - jacobian: A `(batch, n, n)` shaped `Tensor`, `d f(xs) / d xs`, or a tuple - thereof. + jacobian: A `(batch, n, n)` shaped `Tensor`, `dy/dx`, or a tuple thereof. """ - return value_and_batch_jacobian(f, xs)[1] + return value_and_batch_jacobian( + f, + *args, + auto_unpack_single_arg=auto_unpack_single_arg, + name=name, + **kwargs)[1] + + +def _gradient_new(f, xs, grad_ys): + with tf.GradientTape(watch_accessed_variables=False) as tape: + for x in xs: + tape.watch(x) + y = f() + return y, tape.gradient(y, xs, output_gradients=grad_ys) + + +def _gradient_old(f, xs, grad_ys): + assert not tf.executing_eagerly() + y = f() + return y, tf.gradients(y, xs, grad_ys=grad_ys) + + +def _jacobian(f, xs, grad_ys): + assert grad_ys is None + with tf.GradientTape(persistent=True, watch_accessed_variables=False) as tape: + for x in xs: + tape.watch(x) + y = f() + try: + return y, tuple(tape.batch_jacobian(y, x) for x in xs) + except ValueError: # Fallback to for-loop jacobian. + return y, tuple(tape.batch_jacobian(y, x, experimental_use_pfor=False) + for x in xs) + + +def _value_and_grad_impl(f, grad_fn, *args, output_gradients, + auto_unpack_single_arg, + expand_tf_modules_as_trainable_vars=False, + **kwargs): + """Helper which generalizes gradient / Jacobian.""" + if not args and not kwargs: + raise ValueError('Gradient is not defined unless at least one of `arg` or ' + '`kwarg` is specified.') + # The following is for backwards compatibility. In the one arg case with no + # kwargs we can't tell which protocol to use if not for + # `auto_unpack_single_arg`. When `True` and when the sole arg is a tuple + # or list then we unpack it as if it was the args, i.e., preserve the old + # behavior. + do_unpack = (auto_unpack_single_arg and len(args) == 1 and not(kwargs) and + isinstance(args[0], (tuple, list))) + if do_unpack: + args = args[0] + args, kwargs = _prepare_args(args, kwargs) + if expand_tf_modules_as_trainable_vars: + expand_args, expand_kwargs = tf.nest.map_structure( + lambda x: x.trainable_variables if tensor_util.is_module(x) else x, + [args, kwargs]) + else: + expand_args, expand_kwargs = args, kwargs + y, dydx = grad_fn(lambda: f(*args, **kwargs) if _has_args(f) else f(), + tf.nest.flatten([expand_args, expand_kwargs]), + output_gradients) + dydx_args, dydx_kwargs = tf.nest.pack_sequence_as( + [expand_args, expand_kwargs], dydx) + if len(args) == 1 and not do_unpack: + dydx_args = dydx_args[0] + if not kwargs: + return y, dydx_args + if not args: + return y, dydx_kwargs + return y, dydx_args, dydx_kwargs + + +def _prepare_args(args, kwargs): + """Returns structures like inputs with values as Tensors.""" + i = [-1] + def c2t(x): + # Don't use convert_nonref_to_tensor here. We want to have semantics like + # tf.GradientTape which watches only trainable_variables. (Note: we also + # don't want to cal c2t on non-trainable variables since these are already + # watchable by GradientTape.) + if tensor_util.is_module(x) or tensor_util.is_variable(x): + return x + i[0] += 1 + return tf.convert_to_tensor( + x, dtype_hint=tf.float32, name='x{}'.format(i[0])) + return ( + type(args)(c2t(v) for v in args), + type(kwargs)((k, c2t(v)) for k, v in kwargs.items()), + ) + + +def _has_args(fn): + """Returns `True` if the function takes an argument.""" + argspec = tf_inspect.getfullargspec(fn) + return (bool(argspec.args) or + bool(argspec.kwonlyargs) or + argspec.varargs is not None or + argspec.varkw is not None) JAX_MODE = False # Rewritten by script. @@ -136,24 +379,41 @@ def batch_jacobian(f, xs): import numpy as onp # pylint: disable=g-import-not-at-top def value_and_gradient(f, # pylint: disable=function-redefined - xs, + *args, output_gradients=None, use_gradient_tape=False, # pylint: disable=unused-argument - name=None): # pylint: disable=unused-argument - """Computes `f(*xs)` and its gradients wrt to `*xs`.""" - xs, is_xs_list_like = _prepare_args(xs) - y, f_vjp = jax.vjp(f, *xs) + name=None, # pylint: disable=unused-argument + auto_unpack_single_arg=True, + **kwargs): + """Computes `f(*args)` and its gradients wrt to `*args`.""" + if kwargs: + raise NotImplementedError('Jax version of `value_and_gradient` does ' + 'not support `kwargs`.') + do_unpack = (auto_unpack_single_arg and len(args) == 1 and + isinstance(args[0], (tuple, list))) + if do_unpack: + args = args[0] + args, _ = _prepare_args(args, {}) + y, f_vjp = jax.vjp(f, *args) if output_gradients is None: output_gradients = tf.nest.map_structure(np.ones_like, y) dydx = list(f_vjp(output_gradients)) - if not is_xs_list_like: + if len(args) == 1 and not do_unpack: dydx = dydx[0] return y, dydx - def value_and_batch_jacobian(f, xs): # pylint: disable=function-redefined + def value_and_batch_jacobian( # pylint: disable=function-redefined + f, *args, auto_unpack_single_arg=True, name=None, **kwargs): # pylint: disable=unused-argument """JAX implementation of value_and_batch_jacobian.""" - xs, is_xs_list_like = _prepare_args(xs) - y, f_vjp = jax.vjp(f, *xs) + if kwargs: + raise NotImplementedError('Jax version of `value_and_batch_jacobian` ' + 'does not support `kwargs`.') + do_unpack = (auto_unpack_single_arg and len(args) == 1 and + isinstance(args[0], (tuple, list))) + if do_unpack: + args = args[0] + args, _ = _prepare_args(args, {}) + y, f_vjp = jax.vjp(f, *args) # Let `[B, E_1, ..., E_k]` be the shape of `y`, where the first dimension # is a batch dimension. We construct a basis for the cotangent space @@ -164,12 +424,17 @@ def value_and_batch_jacobian(f, xs): # pylint: disable=function-redefined basis = np.broadcast_to( basis, y.shape[:1] + basis.shape[1:]) # `[B, size, E_1, ..., E_k]` - jacobian = jax.vmap(f_vjp, in_axes=1, out_axes=1)(basis) - jacobian = [x.reshape(y.shape + x.shape[2:]) for x in jacobian] - if not is_xs_list_like: - jacobian = jacobian[0] - return y, jacobian + dydx = jax.vmap(f_vjp, in_axes=1, out_axes=1)(basis) + dydx = [x.reshape(y.shape + x.shape[2:]) for x in dydx] + if len(args) == 1 and not do_unpack: + dydx = dydx[0] + return y, dydx - def batch_jacobian(f, xs): # pylint: disable=function-redefined + def batch_jacobian( # pylint: disable=function-redefined + f, *args, auto_unpack_single_arg=True, name=None, **kwargs): # pylint: disable=unused-argument """Computes the batch jacobian of `f(xs)` w.r.t. `xs`.""" - return value_and_batch_jacobian(f, xs)[1] + return value_and_batch_jacobian( + f, + *args, + auto_unpack_single_arg=auto_unpack_single_arg, + **kwargs)[1] diff --git a/tensorflow_probability/python/math/gradient_test.py b/tensorflow_probability/python/math/gradient_test.py index a518afbaec..b1193c7f28 100644 --- a/tensorflow_probability/python/math/gradient_test.py +++ b/tensorflow_probability/python/math/gradient_test.py @@ -27,6 +27,11 @@ from tensorflow_probability.python.internal import test_util from tensorflow_probability.python.math.gradient import batch_jacobian +from tensorflow_probability.python.math.gradient import value_and_gradient_with_auto_expansion + + +tfd = tfp.distributions +tfm = tfp.math @test_util.test_all_tf_execution_regimes @@ -37,7 +42,7 @@ def test_non_list(self): f = lambda x: x**2 / 2 g = lambda x: x x = np.concatenate([np.linspace(-100, 100, int(1e1)), [0]], axis=0) - y, dydx = self.evaluate(tfp.math.value_and_gradient(f, x)) + y, dydx = self.evaluate(tfm.value_and_gradient(f, x)) self.assertAllClose(f(x), y, atol=1e-6, rtol=1e-6) self.assertAllClose(g(x), dydx, atol=1e-6, rtol=1e-6) @@ -47,7 +52,7 @@ def test_list(self): g = lambda x, y: [y, x] args = [np.linspace(0, 100, int(1e1)), np.linspace(-100, 0, int(1e1))] - y, dydx = self.evaluate(tfp.math.value_and_gradient(f, args)) + y, dydx = self.evaluate(tfm.value_and_gradient(f, args)) self.assertAllClose(f(*args), y, atol=1e-6, rtol=1e-6) self.assertAllClose(g(*args), dydx, atol=1e-6, rtol=1e-6) @@ -57,10 +62,130 @@ def test_output_list(self): g = lambda x, y: [y + 1., x] args = [np.linspace(0, 100, int(1e1)), np.linspace(-100, 0, int(1e1))] - y, dydx = self.evaluate(tfp.math.value_and_gradient(f, args)) + y, dydx = self.evaluate(tfm.value_and_gradient(f, args)) self.assertAllClose(f(*args), y, atol=1e-6, rtol=1e-6) self.assertAllClose(g(*args), dydx, atol=1e-6, rtol=1e-6) + @test_util.numpy_disable_gradient_test + @test_util.jax_disable_test_missing_functionality('value_and_gradient') + def test_multi_input_old_style(self): + arg0 = [2., 3., 4.] + arg1 = [5., 6., 7.] + f_actual = lambda x, y: x * np.log(y) + g_actual = lambda x, y: (np.log(y), x / np.array(y)) + y, dydx = self.evaluate( + tfm.value_and_gradient(lambda x, y: x * tf.math.log(y), [arg0, arg1])) + self.assertAllClose(f_actual(arg0, arg1), y, atol=1e-6, rtol=1e-6) + self.assertAllClose(g_actual(arg0, arg1), dydx, atol=1e-6, rtol=1e-6) + + @test_util.numpy_disable_gradient_test + @test_util.jax_disable_test_missing_functionality('value_and_gradient') + def test_multi_input_no_auto_unpack(self): + arg0 = [2., 3., 4.] + arg1 = [5., 6., 7.] + f_actual = lambda x, y: x * np.log(y) + g_actual = lambda x, y: (np.log(y), x / np.array(y)) + + # This is how users would typically write things. + y, dydx = self.evaluate( + tfm.value_and_gradient(lambda x, y: x * tf.math.log(y), arg0, arg1)) + self.assertAllClose(f_actual(arg0, arg1), y, atol=1e-6, rtol=1e-6) + self.assertAllClose(g_actual(arg0, arg1), dydx, atol=1e-6, rtol=1e-6) + + # This is uncommon but possible and unambigous under new style. + y, dydx = self.evaluate(tfm.value_and_gradient( + lambda x: x[0] * tf.math.log(x[1]), [arg0, arg1], + auto_unpack_single_arg=False)) + self.assertAllClose(f_actual(arg0, arg1), y, atol=1e-6, rtol=1e-6) + self.assertAllClose(g_actual(arg0, arg1), dydx, atol=1e-6, rtol=1e-6) + + @test_util.numpy_disable_gradient_test + @test_util.jax_disable_test_missing_functionality('value_and_gradient') + def test_simple_input_no_auto_unpack(self): + x = [1., 2., 3.] + y, dydx = self.evaluate(tfm.value_and_gradient( + tf.math.log, x, auto_unpack_single_arg=False)) + self.assertAllClose(np.log(x), y, atol=1e-6, rtol=1e-6) + self.assertAllClose(1. / np.array(x), dydx, atol=1e-6, rtol=1e-6) + + @test_util.numpy_disable_gradient_test + @test_util.jax_disable_test_missing_functionality('value_and_gradient') + def test_variable_and_constant_identical(self): + expected = (2. * np.log(2.), [1. + np.log(2.), 1. + np.log(2.)]) + x = tf.constant(2.) + self.assertAllClose( + expected, + self.evaluate(tfp.math.value_and_gradient( + lambda a, b: a * tf.math.log(x), x, x)), + atol=1e-6, rtol=1e-6) + x = tf.Variable(2.) + self.evaluate(x.initializer) + self.assertAllClose( + expected, + self.evaluate(tfp.math.value_and_gradient( + lambda a, b: a*tf.math.log(x), x, x)), + atol=1e-6, rtol=1e-6) + + @test_util.numpy_disable_gradient_test + @test_util.jax_disable_test_missing_functionality('value_and_gradient') + def test_docstring_examples(self): + # Case 1: argless `f`. + x = tf.constant(2.) + self.assertAllClose( + [np.log(2.), 0.5], + self.evaluate(tfm.value_and_gradient(lambda: tf.math.log(x), x)), + atol=1e-6, rtol=1e-6) + + # Case 2: packed arguments. + self.assertAllClose( + [2. * np.log(3.), (np.log(3.), 2. / 3)], + self.evaluate(tfm.value_and_gradient( + lambda x, y: x * tf.math.log(y), [2., 3.])), + atol=1e-6, rtol=1e-6) + + # Case 3: default. + x = np.array([1., 2, 3]) + self.assertAllClose( + (np.log(x), 1. / x), + self.evaluate(tfm.value_and_gradient( + tf.math.log, [1., 2., 3.], auto_unpack_single_arg=False)), + atol=1e-6, rtol=1e-6) + + @test_util.numpy_disable_gradient_test + @test_util.jax_disable_test_missing_functionality('value_and_gradient') + def test_variable_tracking(self): + value_and_gradient = value_and_gradient_with_auto_expansion + q = tfd.Normal(tf.Variable(1.), tf.Variable(1., trainable=False)) + r = tfd.Normal(0., tf.Variable(1.)) + self.evaluate([v.initializer for v in q.variables + r.variables]) + + y, dydx = self.evaluate(value_and_gradient(tfd.kl_divergence, q, r)) + self.assertAllClose(0.5, y, atol=1e-6, rtol=1e-6) + self.assertAllClose([[1.], [-1.]], dydx, atol=1e-6, rtol=1e-6) + + y, dydx = self.evaluate(value_and_gradient( + lambda: tfd.kl_divergence(q, r), q, r)) + self.assertAllClose(0.5, y, atol=1e-6, rtol=1e-6) + self.assertAllClose([[1.], [-1.]], dydx, atol=1e-6, rtol=1e-6) + + y, dydx = self.evaluate(value_and_gradient( + lambda: tfd.kl_divergence(q, r), *q.trainable_variables, r)) + self.assertAllClose(0.5, y, atol=1e-6, rtol=1e-6) + self.assertAllClose([1., [-1.]], dydx, atol=1e-6, rtol=1e-6) + + y, dydx = self.evaluate(value_and_gradient( + lambda *_: tfd.kl_divergence(q, r), q, r)) + self.assertAllClose(0.5, y, atol=1e-6, rtol=1e-6) + self.assertAllClose([[1.], [-1.]], dydx, atol=1e-6, rtol=1e-6) + + y, dydx = self.evaluate(value_and_gradient( + lambda **kw: tfd.kl_divergence(tfd.Normal(kw['loc_q'], 1), # pylint: disable=g-long-lambda + tfd.Normal(0, kw['scale_r'])), + loc_q=1., scale_r=1.)) + self.assertAllClose(0.5, y, atol=1e-6, rtol=1e-6) + self.assertAllClose({'loc_q': 1., 'scale_r': -1.}, dydx, + atol=1e-6, rtol=1e-6) + @test_util.numpy_disable_gradient_test def test_output_gradients(self): jacobian = np.float32([[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]) @@ -68,7 +193,7 @@ def test_output_gradients(self): x = np.ones([3], dtype=np.float32) output_gradients = np.float32([1., 2., 3.]) y, dydx = self.evaluate( - tfp.math.value_and_gradient(f, x, output_gradients=output_gradients)) + tfm.value_and_gradient(f, x, output_gradients=output_gradients)) self.assertAllClose(f(x), y, atol=1e-6, rtol=1e-6) self.assertAllClose( np.dot(output_gradients, jacobian), dydx, atol=1e-6, rtol=1e-6) @@ -113,7 +238,7 @@ def f(x, y): # [4, 2, 3], [4, 2, 1, 3] -> [4, 3, 2] for i in range(np.prod(out_shape)): idx = (slice(None),) + np.unravel_index(i, out_shape) # pylint: disable=cell-var-from-loop - _, grad = tfp.math.value_and_gradient(lambda x, y: f(x, y)[idx], [x, y]) + _, grad = tfm.value_and_gradient(lambda x, y: f(x, y)[idx], [x, y]) print(grad[0].shape, jac[0].shape, jac[0][idx].shape) self.assertAllClose(grad[0], jac[0][idx]) self.assertAllClose(grad[1], jac[1][idx]) diff --git a/tensorflow_probability/python/math/hypergeometric.py b/tensorflow_probability/python/math/hypergeometric.py new file mode 100644 index 0000000000..a88a6bdf0f --- /dev/null +++ b/tensorflow_probability/python/math/hypergeometric.py @@ -0,0 +1,387 @@ +# Copyright 2020 The TensorFlow Probability Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Implements hypergeometric functions in TensorFlow.""" + +from __future__ import absolute_import +from __future__ import division +# [internal] enable type annotations +from __future__ import print_function + +import functools + +# Dependency imports +import numpy as np +import tensorflow.compat.v2 as tf +from tensorflow_probability.python.internal import dtype_util +from tensorflow_probability.python.internal import prefer_static as ps +from tensorflow_probability.python.internal import tensorshape_util + + +__all__ = [ + 'hyp2f1_small_argument', +] + + +def _hyp2f1_taylor_series(a, b, c, z): + """Compute Hyp2F1(a, b, c, z) via the Taylor Series expansion.""" + with tf.name_scope('hyp2f1_taylor_series'): + dtype = dtype_util.common_dtype([a, b, c, z], tf.float32) + a = tf.convert_to_tensor(a, dtype=dtype) + b = tf.convert_to_tensor(b, dtype=dtype) + c = tf.convert_to_tensor(c, dtype=dtype) + z = tf.convert_to_tensor(z, dtype=dtype) + np_finfo = np.finfo(dtype_util.as_numpy_dtype(dtype)) + tolerance = tf.cast(np_finfo.resolution, dtype=dtype) + + broadcast_shape = functools.reduce( + ps.broadcast_shape, + [ps.shape(x) for x in [a, b, c, z]]) + + def taylor_series( + should_stop, + index, + term, + taylor_sum, + previous_term, + previous_taylor_sum, + two_before_taylor_sum): + new_term = term * (a + index) * (b + index) * z / ( + (c + index) * (index + 1.)) + new_term = tf.where(should_stop, term, new_term) + new_taylor_sum = tf.where(should_stop, taylor_sum, taylor_sum + new_term) + + # When a or be is near a negative integer n, it's possibly the term is + # small because we are computing (a + n) * (b + n) in the numerator. + # Checking that three consecutive terms are small compared their + # corresponding sum will let us avoid this error. + should_stop = ( + (tf.math.abs(new_term) < tolerance * tf.math.abs(taylor_sum)) & + (tf.math.abs(term) < tolerance * tf.math.abs(previous_taylor_sum)) & + (tf.math.abs(previous_term) < tolerance * tf.math.abs( + two_before_taylor_sum))) + return ( + should_stop, + index + 1., + new_term, + new_taylor_sum, + term, + taylor_sum, + previous_taylor_sum) + + (_, _, _, taylor_sum, _, _, _) = tf.while_loop( + cond=lambda stop, *_: tf.reduce_any(~stop), + body=taylor_series, + loop_vars=( + tf.zeros(broadcast_shape, dtype=tf.bool), + tf.cast(0., dtype=dtype), + # Only the previous term and taylor sum are used for computation. + # The rest are used for checking convergence. We can safely set + # these to zero. + tf.ones(broadcast_shape, dtype=dtype), + tf.ones(broadcast_shape, dtype=dtype), + tf.zeros(broadcast_shape, dtype=dtype), + tf.zeros(broadcast_shape, dtype=dtype), + tf.zeros(broadcast_shape, dtype=dtype))) + return taylor_sum + + +def _hyp2f1_fraction(a, b, c, z): + """Compute 2F1(a, b, c, z) by using a running fraction.""" + with tf.name_scope('hyp2f1_fraction'): + dtype = dtype_util.common_dtype([a, b, c, z], tf.float32) + a = tf.convert_to_tensor(a, dtype=dtype) + b = tf.convert_to_tensor(b, dtype=dtype) + c = tf.convert_to_tensor(c, dtype=dtype) + z = tf.convert_to_tensor(z, dtype=dtype) + np_finfo = np.finfo(dtype_util.as_numpy_dtype(dtype)) + tolerance = tf.cast(np_finfo.resolution, dtype=dtype) + + broadcast_shape = functools.reduce( + ps.broadcast_shape, + [ps.shape(x) for x in [a, b, c, z]]) + + def hypergeometric_fraction( + should_stop, + index, + numerator_term0, + numerator_term1, + denominator, + fraction, + previous_fraction, + two_before_fraction): + new_numerator_term0 = (numerator_term0 + numerator_term1) * index + new_numerator_term1 = ( + numerator_term1 * (a + index - 1.) * (b + index - 1.) * z) / ( + c + index - 1.) + new_denominator = denominator * index + + # Rescale to prevent overflow. + should_rescale = ((tf.math.abs(new_numerator_term0) > 10.) | + (tf.math.abs(new_numerator_term1) > 10.) | + (tf.math.abs(new_denominator) > 10.)) + new_numerator_term0 = tf.where( + should_rescale, new_numerator_term0 / 10., new_numerator_term0) + new_numerator_term1 = tf.where( + should_rescale, new_numerator_term1 / 10., new_numerator_term1) + new_denominator = tf.where( + should_rescale, new_denominator / 10., new_denominator) + + new_fraction = ( + new_numerator_term0 + new_numerator_term1) / new_denominator + new_fraction = tf.where(should_stop, fraction, new_fraction) + + # When a or be is near a negative integer n, it's possibly the term is + # small because we are computing (a + n) * (b + n) in the numerator. + # Checking that three consecutive terms are small compared their + # corresponding sum will let us avoid this error. + should_stop = ( + (tf.math.abs(new_fraction - fraction) < + tolerance * tf.math.abs(fraction)) & + (tf.math.abs(fraction - previous_fraction) < + tolerance * tf.math.abs(previous_fraction)) & + (tf.math.abs(previous_fraction - two_before_fraction) < + tolerance * tf.math.abs(two_before_fraction))) + return ( + should_stop | (index > 50.), + index + 1., + new_numerator_term0, + new_numerator_term1, + new_denominator, + new_fraction, + fraction, + previous_fraction) + + (_, _, _, _, _, fraction, _, _) = tf.while_loop( + cond=lambda stop, *_: tf.reduce_any(~stop), + body=hypergeometric_fraction, + loop_vars=( + tf.zeros(broadcast_shape, dtype=tf.bool), + tf.cast(1., dtype=dtype), + tf.zeros(broadcast_shape, dtype=dtype), + tf.ones(broadcast_shape, dtype=dtype), + tf.ones(broadcast_shape, dtype=dtype), + # Only the previous term and taylor sum are used for computation. + # The rest are used for checking convergence. We can safely set + # these to zero. + tf.ones(broadcast_shape, dtype=dtype), + tf.zeros(broadcast_shape, dtype=dtype), + tf.zeros(broadcast_shape, dtype=dtype))) + return fraction + + +def _hyp2f1_small_parameters(a, b, c, z): + """"Compute 2F1(a, b, c, z) when a, b, and c are small.""" + safe_c = tf.where(tf.math.abs(c) < 1., c, 0.) + safe_a = tf.where(tf.math.abs(c) < 1., a, 0.) + safe_b = tf.where(tf.math.abs(c) < 1., b, 0.) + result = _hyp2f1_fraction(safe_a, safe_b, safe_c, z) + safe_c = tf.where( + tf.math.abs(c) < 1., tf.math.abs(a) + tf.math.abs(b), c) + result = tf.where( + tf.math.abs(c) < 1., + result, + _hyp2f1_taylor_series(a, b, safe_c, z)) + return result + + +def _gamma_negative(z): + """Returns whether Sign(Gamma(z)) == -1.""" + return (z < 0.) & tf.math.not_equal( + tf.math.floormod(tf.math.floor(z), 2.), 0.) + + +def _hyp2f1_z_near_one(a, b, c, z): + """"Compute 2F1(a, b, c, z) when z is near 1.""" + with tf.name_scope('hyp2f1_z_near_one'): + dtype = dtype_util.common_dtype([a, b, c, z], tf.float32) + a = tf.convert_to_tensor(a, dtype=dtype) + b = tf.convert_to_tensor(b, dtype=dtype) + c = tf.convert_to_tensor(c, dtype=dtype) + z = tf.convert_to_tensor(z, dtype=dtype) + + # When z > 0.9, We can transform z to 1 - z and make use of a hypergeometric + # identity. + + # TODO(b/171982819): When tfp.math.log_gamma_difference and tfp.math.lbeta + # support negative parameters, use them here for greater accuracy. + log_first_coefficient = (tf.math.lgamma(c) + tf.math.lgamma(c - a - b) - + tf.math.lgamma(c - a) - tf.math.lgamma(c - b)) + + sign_first_coefficient = ( + _gamma_negative(c) ^ _gamma_negative(c - a - b) ^ + _gamma_negative(c - a) ^ _gamma_negative(c - b)) + sign_first_coefficient = -2. * tf.cast(sign_first_coefficient, dtype) + 1. + + log_second_coefficient = ( + tf.math.xlog1py(c - a - b, -z) + + tf.math.lgamma(c) + tf.math.lgamma(a + b - c) - + tf.math.lgamma(a) - tf.math.lgamma(b)) + + sign_second_coefficient = ( + _gamma_negative(c) ^ _gamma_negative(a) ^ _gamma_negative(b) ^ + _gamma_negative(a + b - c)) + sign_second_coefficient = -2. * tf.cast(sign_second_coefficient, dtype) + 1. + + safe_a = tf.where(c > 1., b - c + 1, a) + safe_b = tf.where(c > 1., a - c + 1, b) + first_term = _hyp2f1_small_parameters(safe_a, safe_b, a + b - c + 1., 1 - z) + first_term = tf.where( + c > 1., + tf.math.exp(tf.math.xlogy(1. - c, z)) * first_term, + first_term) + + safe_a = tf.where(c > 1., 1. - b, c - a) + safe_b = tf.where(c > 1., 1. - a, c - b) + second_term = _hyp2f1_small_parameters( + safe_a, safe_b, c - a - b + 1., 1 - z) + second_term = tf.where( + c > 1., + tf.math.exp(tf.math.xlogy(1. - c, z)) * second_term, + second_term) + + result = (sign_first_coefficient * tf.math.exp(log_first_coefficient) * + first_term + + sign_second_coefficient * tf.math.exp(log_second_coefficient) * + second_term) + result = tf.where( + c > 1., + tf.math.exp(tf.math.xlogy(1. - c, z)) * result, + result) + return result + + +def _hyp2f1_z_near_negative_one(a, b, c, z): + # 2F1(a, b, c, z) = (1 - z)**(-b) * 2F1(b, c - a, c, z / (z - 1)) + # When z < -0.9, we can transform z to z / (z - 1) and make use of a + # hypergeometric identity. + return tf.math.exp( + tf.math.xlog1py(-b, -z)) * _hyp2f1_small_parameters( + b, c - a, c, z / (z - 1.)) + + +@tf.custom_gradient +def hyp2f1_small_argument(a, b, c, z, name=None): + """Compute the Hypergeometric function 2f1(a, b, c, z) when |z| <= 1. + + Given `a, b, c` and `z`, compute Gauss' Hypergeometric Function, specified + by the series: + + `1 + (a * b/c) * z + (a * (a + 1) * b * (b + 1) / ((c * (c + 1)) * z**2 / 2 + + ... (a)_n * (b)_n / (c)_n * z ** n / n! + ....` + + + NOTE: Gradients with only respect to `z` are available. + NOTE: It is recommended that the arguments are `float64` due to the heavy + loss of precision in float32. + + Args: + a: Floating-point `Tensor`, broadcastable with `b, c, z`. Parameter for the + numerator of the series fraction. + b: Floating-point `Tensor`, broadcastable with `a, c, z`. Parameter for the + numerator of the series fraction. + c: Floating-point `Tensor`, broadcastable with `a, b, z`. Parameter for the + denominator of the series fraction. + z: Floating-point `Tensor`, broadcastable `a, b, c`. Value to compute + `2F1(a, b, c, z)` at. Only values of `|z| < 1` are allowed. + name: A name for the operation (optional). + Default value: `None` (i.e., 'continued_fraction'). + + Returns: + hypergeo: `2F1(a, b, c, z)` + + + #### References + + [1] F. Johansson. Computing hypergeometric functions rigorously. + ACM Transactions on Mathematical Software, August 2019. + https://arxiv.org/abs/1606.06977 + [2] J. Pearson, S. Olver, M. Porter. Numerical methods for the computation of + the confluent and Gauss hypergeometric functions. + Numerical Algorithms, August 2016. + """ + with tf.name_scope(name or 'hyp2f1_small_argument'): + dtype = dtype_util.common_dtype([a, b, c, z], tf.float32) + numpy_dtype = dtype_util.as_numpy_dtype(dtype) + a = tf.convert_to_tensor(a, dtype=dtype) + b = tf.convert_to_tensor(b, dtype=dtype) + c = tf.convert_to_tensor(c, dtype=dtype) + z = tf.convert_to_tensor(z, dtype=dtype) + + # TODO(b/128632717): Extend this by including transformations for: + # * Large parameter ranges. Specifically use Hypergeometric recurrences + # to decrease the parameter values. + # * Include |z| > 1. This can be done via Hypergeometric identities that + # transform to |z| < 1. + # * Handling exceptional cases where parameters are negative integers. + + # Assume that |b| > |a|. Swapping the two makes no effect on the + # calculation. + a_small = tf.where(tf.math.abs(a) > tf.math.abs(b), b, a) + b = tf.where(tf.math.abs(a) > tf.math.abs(b), a, b) + a = a_small + + safe_a = tf.where(c < a + b, c - a, a) + safe_b = tf.where(c < a + b, c - b, b) + + # When |z| < 0.9, use approximations to Taylor Series. + safe_z_small = tf.where(tf.math.abs(z) > 0.9, numpy_dtype(0.), z) + taylor_series = _hyp2f1_small_parameters(safe_a, safe_b, c, safe_z_small) + taylor_series = tf.where( + c < a + b, + tf.math.exp((c - a - b) * tf.math.log1p(-z)) * taylor_series, + taylor_series) + + # When |z| >= 0.9, we use hypergeometric identities to ensure that |z| is + # small. + safe_positive_z_large = tf.where(z >= 0.9, z, numpy_dtype(1.)) + hyp2f1_z_near_one = _hyp2f1_z_near_one(a, b, c, safe_positive_z_large) + + safe_negative_z_large = tf.where(z <= -0.9, z, numpy_dtype(-1.)) + hyp2f1_z_near_negative_one = _hyp2f1_z_near_negative_one( + a, b, c, safe_negative_z_large) + + result = tf.where( + z >= 0.9, hyp2f1_z_near_one, + tf.where(z <= -0.9, hyp2f1_z_near_negative_one, taylor_series)) + + def grad(dy): + grad_z = a * b * dy * hyp2f1_small_argument( + a + 1., b + 1., c + 1., z) / c + # We don't have an easily computable gradient with respect to parameters, + # so ignore that for now. + broadcast_shape = functools.reduce( + ps.broadcast_shape, + [ps.shape(x) for x in [a, b, c]]) + + _, grad_z = _fix_gradient_for_broadcasting( + tf.ones(broadcast_shape, dtype=z.dtype), + z, tf.ones_like(grad_z), grad_z) + return None, None, None, grad_z + + return result, grad + + +def _fix_gradient_for_broadcasting(a, b, grad_a, grad_b): + """Reduces broadcast dimensions for a custom gradient.""" + if (tensorshape_util.is_fully_defined(a.shape) and + tensorshape_util.is_fully_defined(b.shape) and + a.shape == b.shape): + return [grad_a, grad_b] + a_shape = tf.shape(a) + b_shape = tf.shape(b) + ra, rb = tf.raw_ops.BroadcastGradientArgs(s0=a_shape, s1=b_shape) + grad_a = tf.reshape(tf.reduce_sum(grad_a, axis=ra), a_shape) + grad_b = tf.reshape(tf.reduce_sum(grad_b, axis=rb), b_shape) + return [grad_a, grad_b] diff --git a/tensorflow_probability/python/math/hypergeometric_test.py b/tensorflow_probability/python/math/hypergeometric_test.py new file mode 100644 index 0000000000..9d63975ca3 --- /dev/null +++ b/tensorflow_probability/python/math/hypergeometric_test.py @@ -0,0 +1,179 @@ +# Copyright 2020 The TensorFlow Probability Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Tests for special.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import functools + +from absl.testing import parameterized +import numpy as np +from scipy import special as scipy_special +import tensorflow.compat.v2 as tf + +from tensorflow_probability.python.internal import test_util +from tensorflow_probability.python.math import hypergeometric as tfp_math + + +class Hyp2F1Test(test_util.TestCase): + + def GenParam(self, low, high, dtype, seed): + return tf.random.uniform( + [int(1e4)], seed=seed, + minval=low, maxval=high, dtype=dtype) + + def VerifyHyp2F1( + self, + dtype, + rtol, + a, + b, + c, + z_lower=-0.9, + z_upper=0.9): + seed_stream = test_util.test_seed_stream() + z = tf.random.uniform( + [int(1e4)], seed=seed_stream(), + minval=z_lower, maxval=z_upper, dtype=dtype) + + hyp2f1, a, b, c, z = self.evaluate([ + tfp_math.hyp2f1_small_argument(a, b, c, z), a, b, c, z]) + scipy_hyp2f1 = scipy_special.hyp2f1(a, b, c, z) + self.assertAllClose(hyp2f1, scipy_hyp2f1, rtol=rtol) + + @parameterized.parameters( + ([1], [1], [1], [1]), + ([2], [3, 1], [5, 1, 1], [7, 1, 1, 1]), + ([2, 1], [3], [5, 1, 1, 1], [7, 1, 1]), + ([2, 1, 1, 1], [3, 1, 1], [5], [7, 1]), + ([2, 1, 1], [3, 1, 1, 1], [5, 1], [7]) + ) + def testHyp2F1ShapeBroadcast(self, a_shape, b_shape, c_shape, z_shape): + a = tf.zeros(a_shape, dtype=tf.float32) + b = tf.zeros(b_shape, dtype=tf.float32) + c = 10.5 * tf.ones(c_shape, dtype=tf.float32) + z = tf.zeros(z_shape, dtype=tf.float32) + broadcast_shape = functools.reduce( + tf.broadcast_dynamic_shape, [a_shape, b_shape, c_shape, z_shape]) + hyp2f1 = tfp_math.hyp2f1_small_argument(a, b, c, z) + broadcast_shape = self.evaluate(broadcast_shape) + self.assertAllEqual(hyp2f1.shape, broadcast_shape) + + @parameterized.named_parameters( + ("float32", np.float32, 6e-3), + ("float64", np.float64, 2e-4)) + def testHyp2F1ParamsSmallZSmallCLargerPositive(self, dtype, rtol): + # Ensure that |c| > |b|. + seed_stream = test_util.test_seed_stream() + a = self.GenParam(-0.5, 0.5, dtype, seed_stream()) + b = self.GenParam(-0.5, 0.5, dtype, seed_stream()) + c = self.GenParam(0.5, 1., dtype, seed_stream()) + self.VerifyHyp2F1(dtype, rtol, a, b, c) + self.VerifyHyp2F1(dtype, rtol, a, b, c) + + @parameterized.named_parameters( + ("float64", np.float64, 5e-3)) + def testHyp2F1ParamsSmallZSmallCLargerNegative(self, dtype, rtol): + # Ensure that |c| > |b|. + seed_stream = test_util.test_seed_stream() + a = self.GenParam(-0.5, 0.5, dtype, seed_stream()) + b = self.GenParam(-0.5, 0.5, dtype, seed_stream()) + c = self.GenParam(-1., -0.5, dtype, seed_stream()) + self.VerifyHyp2F1(dtype, rtol, a, b, c) + self.VerifyHyp2F1(dtype, rtol, a, b, c) + + @parameterized.named_parameters( + ("float64", np.float64, 2e-4)) + def testHyp2F1ParamsSmallZSmallCSmaller(self, dtype, rtol): + # Ensure that |c| < |b|. + seed_stream = test_util.test_seed_stream() + a = self.GenParam(0.5, 1., dtype, seed_stream()) + b = self.GenParam(0.5, 1., dtype, seed_stream()) + c = self.GenParam(0., 0.5, dtype, seed_stream()) + self.VerifyHyp2F1(dtype, rtol, a, b, c) + self.VerifyHyp2F1(dtype, rtol, a, b, c) + + @parameterized.named_parameters( + ("float64", np.float64, 2e-4)) + def testHyp2F1ParamsSmallZPositiveLargeCLarger(self, dtype, rtol): + seed_stream = test_util.test_seed_stream() + a = self.GenParam(-0.5, 0.5, dtype, seed_stream()) + b = self.GenParam(-0.5, 0.5, dtype, seed_stream()) + c = self.GenParam(0.5, 1., dtype, seed_stream()) + self.VerifyHyp2F1(dtype, rtol, a, b, c, z_lower=0.9, z_upper=1.) + + @parameterized.named_parameters( + ("float64", np.float64, 2e-4)) + def testHyp2F1ParamsSmallZPositiveLargeCSmaller(self, dtype, rtol): + seed_stream = test_util.test_seed_stream() + a = self.GenParam(0.5, 1., dtype, seed_stream()) + b = self.GenParam(0.5, 1., dtype, seed_stream()) + c = self.GenParam(-0.5, 0.5, dtype, seed_stream()) + self.VerifyHyp2F1(dtype, rtol, a, b, c, z_lower=0.9, z_upper=1.) + + @parameterized.named_parameters( + ("float32", np.float32, 6e-3), + ("float64", np.float64, 2e-5)) + def testHyp2F1ParamsSmallZNegativeLargeCLarger(self, dtype, rtol): + seed_stream = test_util.test_seed_stream() + a = self.GenParam(-0.5, 0.5, dtype, seed_stream()) + b = self.GenParam(-0.5, 0.5, dtype, seed_stream()) + c = self.GenParam(0.5, 1., dtype, seed_stream()) + self.VerifyHyp2F1(dtype, rtol, a, b, c, z_lower=-1., z_upper=-0.9) + + @parameterized.named_parameters( + ("float64", np.float64, 2e-5)) + def testHyp2F1ParamsSmallZNegativeLargeCSmaller(self, dtype, rtol): + seed_stream = test_util.test_seed_stream() + a = self.GenParam(0.5, 1., dtype, seed_stream()) + b = self.GenParam(0.5, 1., dtype, seed_stream()) + c = self.GenParam(-0.5, 0.5, dtype, seed_stream()) + self.VerifyHyp2F1(dtype, rtol, a, b, c, z_lower=-1., z_upper=-0.9) + + @parameterized.named_parameters( + ("float64", np.float64, 2e-5)) + def testHyp2F1ParamsMediumCLarger(self, dtype, rtol): + seed_stream = test_util.test_seed_stream() + a = self.GenParam(-10., 10., dtype, seed_stream()) + b = self.GenParam(-10., 10., dtype, seed_stream()) + c = self.GenParam(10., 20., dtype, seed_stream()) + self.VerifyHyp2F1(dtype, rtol, a, b, c, z_lower=-1., z_upper=-1.) + + @parameterized.named_parameters( + ("float64", np.float64, 2e-5)) + def testHyp2F1ParamsLargerCLarger(self, dtype, rtol): + seed_stream = test_util.test_seed_stream() + a = self.GenParam(10., 50., dtype, seed_stream()) + b = self.GenParam(10., 50., dtype, seed_stream()) + c = self.GenParam(50., 100., dtype, seed_stream()) + self.VerifyHyp2F1(dtype, rtol, a, b, c, z_lower=-1., z_upper=-1.) + + @test_util.numpy_disable_gradient_test + @test_util.jax_disable_test_missing_functionality( + "Gradients not supported in JAX.") + def test2F1HypergeometricGradient(self): + a = tf.constant([-0.1,], dtype=np.float64)[..., tf.newaxis] + b = tf.constant([0.8,], dtype=np.float64)[..., tf.newaxis] + c = tf.constant([9.9,], dtype=np.float64)[..., tf.newaxis] + z = tf.constant([0.1], dtype=np.float64) + err = self.compute_max_gradient_error( + functools.partial(tfp_math.hyp2f1_small_argument, a, b, c), [z]) + self.assertLess(err, 2e-4) + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorflow_probability/python/math/psd_kernels/parabolic.py b/tensorflow_probability/python/math/psd_kernels/parabolic.py index 84d6a701c9..6f1b62021a 100644 --- a/tensorflow_probability/python/math/psd_kernels/parabolic.py +++ b/tensorflow_probability/python/math/psd_kernels/parabolic.py @@ -12,13 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ -"""ExpSinSquared kernel.""" +"""Parabolic kernel.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -import numpy as np import tensorflow.compat.v2 as tf from tensorflow_probability.python.internal import assert_util @@ -34,8 +33,7 @@ class Parabolic(PositiveSemidefiniteKernel): """The Parabolic kernel. ```none - k(x, y) = 3 / (4 * sqrt(5)) * amplitude * - max(0, 1 - (||x_k - y_k|| / (length_scale * sqrt(5)))**2) + k(x, y) = 3 / 4 * amplitude * max(0, 1 - (||x_k - y_k|| / length_scale**2) ``` where the double-bars represent vector length (ie, Euclidean, or L2 norm). @@ -46,7 +44,7 @@ class Parabolic(PositiveSemidefiniteKernel): `efficiency = sqrt(integral(u**2 k(u) du)) integral(k(u)**2 du)`. This optimality was first derived in a different context [1], and suggested for use in KDE by Epanechnikov in [2]. This is nicely summarized in [3], adjacent to - Fig 3.1. + Fig 3.1. The Epanechnikov kernel integrates to `1` over its support `[-1, 1]`. #### References @@ -121,15 +119,13 @@ def _batch_shape_tensor(self): def _apply_with_distance( self, x1, x2, pairwise_square_distance, example_ndims=0): - default_bandwidth_sq = 5. - pairwise_square_distance = pairwise_square_distance / default_bandwidth_sq if self.length_scale is not None: length_scale = tf.convert_to_tensor(self.length_scale) length_scale = util.pad_shape_with_ones( length_scale, example_ndims) pairwise_square_distance = pairwise_square_distance / length_scale**2 - default_scale = tf.cast(.75 / np.sqrt(5.), pairwise_square_distance.dtype) + default_scale = tf.cast(.75, pairwise_square_distance.dtype) result = tf.nn.relu(1 - pairwise_square_distance) * default_scale if self.amplitude is not None: diff --git a/tensorflow_probability/python/math/psd_kernels/parabolic_test.py b/tensorflow_probability/python/math/psd_kernels/parabolic_test.py index e8912f4400..8a0318ed2d 100644 --- a/tensorflow_probability/python/math/psd_kernels/parabolic_test.py +++ b/tensorflow_probability/python/math/psd_kernels/parabolic_test.py @@ -56,10 +56,16 @@ def testValuesAreCorrect(self, feature_ndims, dims): x = np.random.uniform(-1, 1, size=shape).astype(np.float32) y = np.random.uniform(-1, 1, size=shape).astype(np.float32) self.assertAllClose( - amplitude * .75 / np.sqrt(5) * - np.maximum(0., 1 - np.sum((x - y)**2) / length_scale**2 / 5), + amplitude * .75 * + np.maximum(0., 1 - np.sum((x - y)**2) / length_scale**2), self.evaluate(k.apply(x, y))) + def testEpanechnikov(self): + k = tfp.math.psd_kernels.Parabolic() + self.assertAllClose(.75, k.matrix([[0.]], [[0.]])[0, 0]) + self.assertAllEqual([0., 0.], k.matrix([[0.]], [[1.], [-1.]])[0]) + self.assertAllEqual([0., 0.], k.matrix([[0.]], [[1.1], [-1.1]])[0]) + def testNoneShapes(self): k = tfp.math.psd_kernels.Parabolic( amplitude=np.reshape(np.arange(12.), [2, 3, 2])) diff --git a/tensorflow_probability/python/math/root_search.py b/tensorflow_probability/python/math/root_search.py index 1c9232568a..99da7d3966 100644 --- a/tensorflow_probability/python/math/root_search.py +++ b/tensorflow_probability/python/math/root_search.py @@ -22,10 +22,16 @@ import tensorflow.compat.v2 as tf +from tensorflow_probability.python.internal import assert_util from tensorflow_probability.python.internal import dtype_util +from tensorflow_probability.python.internal import prefer_static as ps +from tensorflow.python.util import deprecation # pylint: disable=g-direct-tensorflow-import + __all__ = [ 'secant_root', + 'find_root_chandrupatla', + 'find_root_secant', ] RootSearchResults = collections.namedtuple( @@ -43,16 +49,16 @@ ]) -def secant_root(objective_fn, - initial_position, - next_position=None, - value_at_position=None, - position_tolerance=1e-8, - value_tolerance=1e-8, - max_iterations=50, - stopping_policy_fn=tf.reduce_all, - validate_args=False, - name=None): +def find_root_secant(objective_fn, + initial_position, + next_position=None, + value_at_position=None, + position_tolerance=1e-8, + value_tolerance=1e-8, + max_iterations=50, + stopping_policy_fn=tf.reduce_all, + validate_args=False, + name=None): r"""Finds root(s) of a function of single variable using the secant method. The [secant method](https://en.wikipedia.org/wiki/Secant_method) is a @@ -291,7 +297,7 @@ def _body(position, value_at_position, num_iterations, step, finished): return (next_position, value_at_next_position, num_iterations, next_step, is_finished) - with tf.name_scope(name or 'secant_root'): + with tf.name_scope(name or 'find_root_secant'): assertions = [] if validate_args: @@ -319,3 +325,207 @@ def _body(position, value_at_position, num_iterations, step, finished): estimated_root=root, objective_at_estimated_root=value_at_root, num_iterations=num_iterations) + + +secant_root = deprecation.deprecated_alias( + 'tfp.math.secant_root', 'tfp.math.find_root_secant', find_root_secant) + + +def _structure_broadcasting_where(c, x, y): + """Selects elements from two structures using a shared condition `c`.""" + return tf.nest.map_structure( + lambda xp, yp: tf.where(c, xp, yp), x, y) + + +def find_root_chandrupatla(objective_fn, + low, + high, + position_tolerance=1e-8, + value_tolerance=0., + max_iterations=50, + stopping_policy_fn=tf.reduce_all, + validate_args=False, + name='find_root_chandrupatla'): + r"""Finds root(s) of a scalar function using Chandrupatla's method. + + Chandrupatla's method [1, 2] is a root-finding algorithm that is guaranteed + to converge if a root lies within the given bounds. It generalizes the + [bisection method](https://en.wikipedia.org/wiki/Bisection_method); at each + step it chooses to perform either bisection or inverse quadratic + interpolation. This makes it similar in spirit to [Brent's method]( + https://en.wikipedia.org/wiki/Brent%27s_method), which also considers steps + that use the secant method, but Chandrupatla's method is simpler and often + converges at least as quickly [3]. + + Args: + objective_fn: Python callable for which roots are searched. It must be a + callable of a single variable. `objective_fn` must return a `Tensor` with + shape `batch_shape` and dtype matching `lower_bound` and `upper_bound`. + low: Float `Tensor` of shape `batch_shape` representing a lower + bound(s) on the value of a root(s). + high: Float `Tensor` of shape `batch_shape` representing an upper + bound(s) on the value of a root(s). + position_tolerance: Optional `Tensor` representing the maximum absolute + error in the positions of the estimated roots. Shape must broadcast with + `batch_shape`. + Default value: `1e-8`. + value_tolerance: Optional `Tensor` representing the absolute error allowed + in the value of the objective function. If the absolute value of + `objective_fn` is smaller than + `value_tolerance` at a given position, then that position is considered a + root for the function. Shape must broadcast with `batch_shape`. + Default value: `1e-8`. + max_iterations: Optional `Tensor` or Python integer specifying the maximum + number of steps to perform. Shape must broadcast with `batch_shape`. + Default value: `50`. + stopping_policy_fn: Python `callable` controlling the algorithm termination. + It must be a callable accepting a `Tensor` of booleans with the same shape + as `lower_bound` and `upper_bound` (denoting whether each search is + finished), and returning a scalar boolean `Tensor` indicating + whether the overall search should stop. Typical values are + `tf.reduce_all` (which returns only when the search is finished for all + points), and `tf.reduce_any` (which returns as soon as the search is + finished for any point). + Default value: `tf.reduce_all` (returns only when the search is finished + for all points). + validate_args: Python `bool` indicating whether to validate arguments. + Default value: `False`. + name: Python `str` name prefixed to ops created by this function. + Default value: 'find_root_chandrupatla'. + + Returns: + root_search_results: A Python `namedtuple` containing the following items: + estimated_root: `Tensor` containing the last position explored. If the + search was successful within the specified tolerance, this position is + a root of the objective function. + objective_at_estimated_root: `Tensor` containing the value of the + objective function at `position`. If the search was successful within + the specified tolerance, then this is close to 0. + num_iterations: The number of iterations performed. + + #### References + + [1] Tirupathi R. Chandrupatla. A new hybrid quadratic/bisection algorithm for + finding the zero of a nonlinear function without using derivatives. + _Advances in Engineering Software_, 28.3:145-149, 1997. + [2] Philipp OJ Scherer. Computational Physics. _Springer Berlin_, + Heidelberg, 2010. + Section 6.1.7.3 https://books.google.com/books?id=cC-8BAAAQBAJ&pg=PA95 + [3] Jason Sachs. Ten Little Algorithms, Part 5: Quadratic Extremum + Interpolation and Chandrupatla's Method (2015). + https://www.embeddedrelated.com/showarticle/855.php + """ + + ################################################ + # Loop variables used by Chandrupatla's method: + # + # a: endpoint of an interval `[min(a, b), max(a, b)]` containing the + # root. There is no guarantee as to which of `a` and `b` is larger. + # b: endpoint of an interval `[min(a, b), max(a, b)]` containing the + # root. There is no guarantee as to which of `a` and `b` is larger. + # f_a: value of the objective at `a`. + # f_b: value of the objective at `b`. + # t: the next position to be evaluated as the coefficient of a convex + # combination of `a` and `b` (i.e., a value in the unit interval). + # num_iterations: integer number of steps taken so far. + # converged: boolean indicating whether each batch element has converged. + # + # All variables have the same shape `batch_shape`. + + def _should_continue(a, b, f_a, f_b, t, num_iterations, converged): + del a, b, f_a, f_b, t # Unused. + all_converged = stopping_policy_fn( + tf.logical_or(converged, + num_iterations >= max_iterations)) + return ~all_converged + + def _body(a, b, f_a, f_b, t, num_iterations, converged): + """One step of Chandrupatla's method for root finding.""" + previous_loop_vars = (a, b, f_a, f_b, t, num_iterations, converged) + finalized_elements = tf.logical_or(converged, + num_iterations >= max_iterations) + + # Evaluate the new point. + x_new = (1 - t) * a + t * b + f_new = objective_fn(x_new) + # If we've bisected (t==0.5) and the new float value for `a` is identical to + # that from the previous iteration, then we'll keep bisecting (the + # logic below will set t==0.5 for the next step), and nothing further will + # change. + at_fixed_point = tf.equal(x_new, a) & tf.equal(t, 0.5) + # Otherwise, tighten the bounds. + a, b, c, f_a, f_b, f_c = _structure_broadcasting_where( + tf.equal(tf.math.sign(f_new), tf.math.sign(f_a)), + (x_new, b, a, f_new, f_b, f_a), + (x_new, a, b, f_new, f_a, f_b)) + + # Check for convergence. + f_best = tf.where(tf.abs(f_a) < tf.abs(f_b), f_a, f_b) + interval_tolerance = position_tolerance / (tf.abs(b - c)) + converged = tf.logical_or(interval_tolerance > 0.5, + tf.logical_or( + tf.math.abs(f_best) <= value_tolerance, + at_fixed_point)) + + # Propose next point to evaluate. + xi = (a - b) / (c - b) + phi = (f_a - f_b) / (f_c - f_b) + t = tf.where( + # Condition for inverse quadratic interpolation. + tf.logical_and(1 - tf.math.sqrt(1 - xi) < phi, + tf.math.sqrt(xi) > phi), + # Propose a point by inverse quadratic interpolation. + (f_a / (f_b - f_a) * f_c / (f_b - f_c) + + (c - a) / (b - a) * f_a / (f_c - f_a) * f_b / (f_c - f_b)), + # Otherwise, just cut the interval in half (bisection). + 0.5) + # Constrain the proposal to the current interval (0 < t < 1). + t = tf.minimum(tf.maximum(t, interval_tolerance), + 1 - interval_tolerance) + + # Update elements that haven't converged. + return _structure_broadcasting_where( + finalized_elements, + previous_loop_vars, + (a, b, f_a, f_b, t, num_iterations + 1, converged)) + + with tf.name_scope(name): + max_iterations = tf.convert_to_tensor( + max_iterations, name='max_iterations', dtype_hint=tf.int32) + a = tf.convert_to_tensor(low, name='lower_bound') + b = tf.convert_to_tensor(high, name='upper_bound') + f_a, f_b = objective_fn(a), objective_fn(b) + batch_shape = ps.broadcast_shape(ps.shape(f_a), ps.shape(f_b)) + + assertions = [] + if validate_args: + assertions += [ + assert_util.assert_none_equal( + tf.math.sign(f_a), tf.math.sign(f_b), + message='Bounds must be on different sides of a root.')] + + with tf.control_dependencies(assertions): + initial_loop_vars = [ + a, + b, + f_a, + f_b, + tf.cast(0.5, dtype=f_a.dtype), + tf.cast(0, dtype=max_iterations.dtype), + False + ] + a, b, f_a, f_b, _, num_iterations, _ = tf.while_loop( + _should_continue, + _body, + loop_vars=tf.nest.map_structure( + lambda x: tf.broadcast_to(x, batch_shape), + initial_loop_vars)) + + x_best, f_best = _structure_broadcasting_where( + tf.abs(f_a) < tf.abs(f_b), + (a, f_a), + (b, f_b)) + return RootSearchResults( + estimated_root=x_best, + objective_at_estimated_root=f_best, + num_iterations=num_iterations) diff --git a/tensorflow_probability/python/math/root_search_test.py b/tensorflow_probability/python/math/root_search_test.py index 962ae34005..7a0754f9cb 100644 --- a/tensorflow_probability/python/math/root_search_test.py +++ b/tensorflow_probability/python/math/root_search_test.py @@ -24,12 +24,13 @@ import tensorflow.compat.v2 as tf import tensorflow_probability as tfp - +from tensorflow_probability.python.internal import samplers +from tensorflow_probability.python.internal import special_math from tensorflow_probability.python.internal import test_util @test_util.test_all_tf_execution_regimes -class RootSearchTest(test_util.TestCase): +class SecantRootSearchTest(test_util.TestCase): def test_secant_finds_all_roots_from_one_initial_position(self): f = lambda x: (63 * x**5 - 70 * x**3 + 15 * x) / 8. @@ -39,7 +40,7 @@ def test_secant_finds_all_roots_from_one_initial_position(self): tolerance = 1e-8 roots, value_at_roots, _ = self.evaluate( - tfp.math.secant_root(f, guess, position_tolerance=tolerance)) + tfp.math.find_root_secant(f, guess, position_tolerance=tolerance)) expected_roots = [optimize.newton(f, x0), optimize.newton(f, x1)] zeros = [0., 0.] @@ -56,7 +57,7 @@ def test_secant_finds_any_root_from_one_initial_position(self): tolerance = 1e-8 # Only the root close to the first starting point will be found. roots, value_at_roots, _ = self.evaluate( - tfp.math.secant_root( + tfp.math.find_root_secant( f, guess, position_tolerance=tolerance, @@ -80,7 +81,8 @@ def test_secant_finds_all_roots_from_two_initial_positions(self): tolerance = 1e-8 roots, value_at_roots, _ = self.evaluate( - tfp.math.secant_root(f, guess, guess_1, position_tolerance=tolerance)) + tfp.math.find_root_secant( + f, guess, guess_1, position_tolerance=tolerance)) expected_roots = [optimize.newton(f, x0), optimize.newton(f, x1)] zeros = [0., 0.] @@ -97,7 +99,7 @@ def test_secant_finds_any_roots_from_two_initial_positions(self): tolerance = 1e-8 roots, value_at_roots, _ = self.evaluate( - tfp.math.secant_root( + tfp.math.find_root_secant( f, guess, next_guess, @@ -121,7 +123,7 @@ def test_secant_finds_all_roots_using_float32(self): tolerance = 1e-8 roots, value_at_roots, _ = self.evaluate( - tfp.math.secant_root(f, guess, position_tolerance=tolerance)) + tfp.math.find_root_secant(f, guess, position_tolerance=tolerance)) expected_roots = [optimize.newton(f, x0), optimize.newton(f, x1)] zeros = [0., 0.] @@ -137,7 +139,7 @@ def test_secant_skips_iteration(self): # Skip iteration entirely. This should be a no-op. guess, result = self.evaluate( - [guess, tfp.math.secant_root(f, guess, max_iterations=0)]) + [guess, tfp.math.find_root_secant(f, guess, max_iterations=0)]) self.assertAllEqual(result.estimated_root, guess) @@ -148,7 +150,7 @@ def test_secant_invalid_position_tolerance(self): with self.assertRaisesOpError( '`position_tolerance` must be greater than 0.'): self.evaluate( - tfp.math.secant_root( + tfp.math.find_root_secant( f, guess, position_tolerance=-1e-8, validate_args=True)) def test_secant_invalid_value_tolerance(self): @@ -157,7 +159,7 @@ def test_secant_invalid_value_tolerance(self): guess = tf.constant(-2, dtype=tf.float64) with self.assertRaisesOpError('`value_tolerance` must be greater than 0.'): self.evaluate( - tfp.math.secant_root( + tfp.math.find_root_secant( f, guess, value_tolerance=-1e-8, validate_args=True)) def test_secant_invalid_max_iterations(self): @@ -166,8 +168,94 @@ def test_secant_invalid_max_iterations(self): guess = tf.constant(-2, dtype=tf.float64) with self.assertRaisesOpError('`max_iterations` must be nonnegative.'): self.evaluate( - tfp.math.secant_root(f, guess, max_iterations=-1, validate_args=True)) + tfp.math.find_root_secant( + f, guess, max_iterations=-1, validate_args=True)) + +@test_util.test_all_tf_execution_regimes +class ChandrupatlaRootSearchTest(test_util.TestCase): + + def test_chandrupatla_scalar_inverse_gaussian_cdf(self): + true_x = 3.14159 + u = special_math.ndtr(true_x) + + roots, value_at_roots, _ = tfp.math.find_root_chandrupatla( + objective_fn=lambda x: special_math.ndtr(x) - u, + low=-100., + high=100., + position_tolerance=1e-8) + self.assertAllClose(value_at_roots, tf.zeros_like(value_at_roots)) + # The normal CDF function is not precise enough to be inverted to a + # position tolerance of 1e-8 (the objective goes to zero relatively + # far from the expected point), so check it at a lower tolerance. + self.assertAllClose(roots, true_x, atol=1e-4) + + def test_chandrupatla_batch_high_degree_polynomial(self): + seed = test_util.test_seed(sampler_type='stateless') + expected_roots = self.evaluate(samplers.normal( + [4, 3], seed=seed)) + roots, value_at_roots, _ = tfp.math.find_root_chandrupatla( + objective_fn=lambda x: (x - expected_roots)**15, + low=-20., + high=20., + position_tolerance=1e-8) + self.assertAllClose(value_at_roots, tf.zeros_like(value_at_roots)) + # The function is not precise enough to be inverted to a + # position tolerance of 1e-8, (the objective goes to zero relatively + # far from the expected point), so check it at a lower tolerance. + self.assertAllClose(roots, expected_roots, atol=1e-2) + + def test_chandrupatla_max_iterations(self): + expected_roots = samplers.normal( + [4, 3], seed=test_util.test_seed(sampler_type='stateless')) + max_iterations = samplers.uniform( + [4, 3], minval=1, maxval=6, dtype=tf.int32, + seed=test_util.test_seed(sampler_type='stateless')) + _, _, num_iterations = tfp.math.find_root_chandrupatla( + objective_fn=lambda x: (x - expected_roots)**3, + low=-1000000., + high=1000000., + position_tolerance=1e-8, + max_iterations=max_iterations) + self.assertAllClose(num_iterations, + max_iterations) + + def test_chandrupatla_halts_at_fixed_point(self): + # This search would naively get stuck at the interval + # {a=1.4717137813568115, b=1.471713662147522}, which does not quite + # satisfy the tolerance, but will never be tightened further because it has + # the property that `0.5 * a + 0.5 * b == a` in float32. The search should + # detect the fixed point and halt early. + max_iterations = 50 + _, _, num_iterations = tfp.math.find_root_chandrupatla( + lambda ux: tf.math.igamma(2., tf.nn.softplus(ux)) - 0.5, + low=-100., + high=100., + position_tolerance=1e-8, + value_tolerance=1e-8, + max_iterations=max_iterations) + self.assertLess(self.evaluate(num_iterations), max_iterations) + + def test_chandrupatla_float64_high_precision(self): + expected_roots = samplers.normal( + [4, 3], seed=test_util.test_seed(sampler_type='stateless'), + dtype=tf.float64) + tolerance = 1e-12 + roots, value_at_roots, _ = tfp.math.find_root_chandrupatla( + objective_fn=lambda x: (x - expected_roots)**3, + low=tf.convert_to_tensor(-100., dtype=expected_roots.dtype), + high=tf.convert_to_tensor(100., dtype=expected_roots.dtype), + position_tolerance=tolerance) + self.assertAllClose(roots, expected_roots, atol=tolerance) + self.assertAllClose(value_at_roots, tf.zeros_like(value_at_roots)) + + def test_chandrupatla_invalid_bounds(self): + with self.assertRaisesOpError('must be on different sides of a root'): + self.evaluate(tfp.math.find_root_chandrupatla( + lambda x: x**2 - 2., + 3., + 4., + validate_args=True)) if __name__ == '__main__': tf.test.main() diff --git a/tensorflow_probability/python/random/BUILD b/tensorflow_probability/python/random/BUILD index 7a6e0cd8b9..9f35fc0a7a 100644 --- a/tensorflow_probability/python/random/BUILD +++ b/tensorflow_probability/python/random/BUILD @@ -58,5 +58,6 @@ multi_substrate_py_test( "//tensorflow_probability", "//tensorflow_probability/python/internal:dtype_util", "//tensorflow_probability/python/internal:test_util", + # tensorflow/compiler/jit dep, ], ) diff --git a/tensorflow_probability/python/random/__init__.py b/tensorflow_probability/python/random/__init__.py index 194ae24059..f8661560cb 100644 --- a/tensorflow_probability/python/random/__init__.py +++ b/tensorflow_probability/python/random/__init__.py @@ -22,10 +22,12 @@ from tensorflow_probability.python.internal.samplers import split_seed from tensorflow_probability.python.random.random_ops import rademacher from tensorflow_probability.python.random.random_ops import rayleigh +from tensorflow_probability.python.random.random_ops import spherical_uniform _allowed_symbols = [ 'rademacher', 'rayleigh', + 'spherical_uniform', 'split_seed', ] diff --git a/tensorflow_probability/python/random/random_ops.py b/tensorflow_probability/python/random/random_ops.py index 60ce32569e..f4ed20da2a 100644 --- a/tensorflow_probability/python/random/random_ops.py +++ b/tensorflow_probability/python/random/random_ops.py @@ -21,13 +21,17 @@ from __future__ import division from __future__ import print_function +import numpy as np + import tensorflow.compat.v2 as tf +from tensorflow_probability.python.internal import prefer_static as ps from tensorflow_probability.python.internal import samplers __all__ = [ 'rademacher', 'rayleigh', + 'spherical_uniform', ] @@ -53,6 +57,7 @@ def rademacher(shape, dtype=tf.float32, seed=None, name=None): # memory (host or device) as the downstream cast will want to put it. The # convention on GPU is that int32 are in host memory and int64 are in device # memory. + shape = ps.convert_to_shape_tensor(shape) generation_dtype = tf.int64 if tf.as_dtype(dtype) != tf.int32 else tf.int32 random_bernoulli = samplers.uniform( shape, minval=0, maxval=2, dtype=generation_dtype, seed=seed) @@ -98,3 +103,82 @@ def rayleigh(shape, scale=None, dtype=tf.float32, seed=None, name=None): if scale is None: return x return x * scale + + +def spherical_uniform( + shape, + dimension, + dtype=tf.float32, + seed=None, + name=None): + """Generates `Tensor` drawn from a uniform distribution on the sphere. + + Args: + shape: Vector-shaped, `int` `Tensor` representing shape of output. + dimension: Scalar `int` `Tensor`, representing the dimensionality of the + space where the sphere is embedded. + dtype: (Optional) TF `dtype` representing `dtype` of output. + Default value: `tf.float32`. + seed: (Optional) Python integer to seed the random number generator. + Default value: `None` (i.e., no seed). + name: Python `str` name prefixed to Ops created by this function. + Default value: `None` (i.e., 'random_rayleigh'). + + Returns: + spherical_uniform: `Tensor` with specified `shape` and `dtype` consisting + of positive real values drawn from a Rayleigh distribution with specified + `scale`. + """ + with tf.name_scope(name or 'spherical_uniform'): + seed = samplers.sanitize_seed(seed) + dimension = ps.convert_to_shape_tensor(tf.cast(dimension, dtype=tf.int32)) + shape = ps.convert_to_shape_tensor(shape, dtype=tf.int32) + dimension_static = tf.get_static_value(dimension) + sample_shape = ps.concat([shape, [dimension]], axis=0) + sample_shape = ps.convert_to_shape_tensor(sample_shape) + # Special case one and two dimensions. This is to guard against the case + # where the normal samples are zero. This can happen in dimensions 1 and 2. + if dimension_static is not None: + # This is equivalent to sampling Rademacher random variables. + if dimension_static == 1: + return rademacher(sample_shape, dtype=dtype, seed=seed) + elif dimension_static == 2: + u = samplers.uniform( + shape, minval=0, maxval=2 * np.pi, dtype=dtype, seed=seed) + return tf.stack([tf.math.cos(u), tf.math.sin(u)], axis=-1) + else: + normal_samples = samplers.normal( + shape=ps.concat([shape, [dimension_static]], axis=0), + seed=seed, + dtype=dtype) + unit_norm = normal_samples / tf.norm( + normal_samples, ord=2, axis=-1)[..., tf.newaxis] + return unit_norm + + # If we can't determine the dimension statically, tf.where between the + # different options. + r_seed, u_seed, n_seed = samplers.split_seed( + seed, n=3, salt='spherical_uniform_dynamic_shape') + rademacher_samples = rademacher(sample_shape, dtype=dtype, seed=r_seed) + u = samplers.uniform( + shape, minval=0, maxval=2 * np.pi, dtype=dtype, seed=u_seed) + twod_samples = tf.concat( + [tf.math.cos(u)[..., tf.newaxis], + tf.math.sin(u)[..., tf.newaxis] * tf.ones( + [dimension - 1], dtype=dtype)], axis=-1) + + normal_samples = samplers.normal( + shape=ps.concat([shape, [dimension]], axis=0), + seed=n_seed, + dtype=dtype) + nd_samples = normal_samples / tf.norm( + normal_samples, ord=2, axis=-1)[..., tf.newaxis] + + return tf.where( + tf.math.equal(dimension, 1), + rademacher_samples, + tf.where( + tf.math.equal(dimension, 2), + twod_samples, + nd_samples)) + diff --git a/tensorflow_probability/python/random/random_ops_test.py b/tensorflow_probability/python/random/random_ops_test.py index 3af617ac25..d3ed813b49 100644 --- a/tensorflow_probability/python/random/random_ops_test.py +++ b/tensorflow_probability/python/random/random_ops_test.py @@ -104,5 +104,82 @@ class RandomRayleighDynamic64(test_util.TestCase, _RandomRayleigh): use_static_shape = True +class _RandomSphericalUniform(object): + + def verify_expectations(self, d): + shape_ = np.array([int(1e6)], np.int32) + shape = ( + tf.constant(shape_) if self.use_static_shape else + tf1.placeholder_with_default(shape_, shape=None)) + # This shape will require broadcasting before sampling. + dimension = ( + tf.constant(d) if self.use_static_shape else + tf1.placeholder_with_default(d, shape=None)) + x = tfp.random.spherical_uniform( + dimension=dimension, + shape=shape, + dtype=self.dtype, + seed=test_util.test_seed()) + self.assertEqual(self.dtype, dtype_util.as_numpy_dtype(x.dtype)) + final_shape_ = [int(1e6), d] + if self.use_static_shape: + self.assertAllEqual(final_shape_, x.shape) + sample_mean = tf.reduce_mean(x, axis=0, keepdims=True) + sample_covar = tfp.stats.covariance(x) + [x_, sample_mean_, sample_covar_] = self.evaluate([ + x, sample_mean, sample_covar]) + self.assertAllEqual(final_shape_, x_.shape) + self.assertAllClose( + np.zeros_like(sample_mean_), sample_mean_, atol=2e-3, rtol=1e-3) + self.assertAllClose( + np.eye(d, dtype=self.dtype) / d, sample_covar_, atol=2e-3, rtol=1e-2) + + def test_expectations_1d(self): + self.verify_expectations(1) + + def test_expectations_2d(self): + self.verify_expectations(2) + + def test_expectations_3d(self): + self.verify_expectations(3) + + def test_expectations_5d(self): + self.verify_expectations(4) + + def test_expectations_9d(self): + self.verify_expectations(9) + + def test_jitted_sampling(self): + self.skip_if_no_xla() + shape = np.int32([2, 3]) + seed = test_util.test_seed() + dimension = np.int32(10) + + @tf.function(experimental_compile=True) + def sample(): + return tfp.random.spherical_uniform( + dimension=dimension, + shape=shape, + seed=seed, + dtype=self.dtype) + + samples = self.evaluate(sample()) + self.assertAllEqual([2, 3, 10], samples.shape) + + +@test_util.test_all_tf_execution_regimes +class RandomSphericalUniformDynamic32( + test_util.TestCase, _RandomSphericalUniform): + dtype = np.float32 + use_static_shape = False + + +@test_util.test_all_tf_execution_regimes +class RandomSphericalUniformStatic64( + test_util.TestCase, _RandomSphericalUniform): + dtype = np.float64 + use_static_shape = True + + if __name__ == '__main__': tf.test.main() diff --git a/tensorflow_probability/python/sts/BUILD b/tensorflow_probability/python/sts/BUILD index ce137ebcf3..2e1cacd3c9 100644 --- a/tensorflow_probability/python/sts/BUILD +++ b/tensorflow_probability/python/sts/BUILD @@ -356,6 +356,7 @@ py_test( # numpy dep, # tensorflow dep, "//tensorflow_probability", + "//tensorflow_probability/python/internal:samplers", "//tensorflow_probability/python/internal:test_util", ], ) diff --git a/tensorflow_probability/python/sts/autoregressive.py b/tensorflow_probability/python/sts/autoregressive.py index cd05ef3197..fdd32f60b2 100644 --- a/tensorflow_probability/python/sts/autoregressive.py +++ b/tensorflow_probability/python/sts/autoregressive.py @@ -165,6 +165,7 @@ def __init__(self, name: Python `str` name prefixed to ops created by this class. Default value: "AutoregressiveStateSpaceModel". """ + parameters = dict(locals()) with tf.name_scope(name or 'AutoregressiveStateSpaceModel') as name: # The initial state prior determines the dtype of sampled values. @@ -204,6 +205,7 @@ def __init__(self, initial_step=initial_step, validate_args=validate_args, name=name) + self._parameters = parameters @property def order(self): diff --git a/tensorflow_probability/python/sts/dynamic_regression.py b/tensorflow_probability/python/sts/dynamic_regression.py index 6d779de60f..6ad093ece2 100644 --- a/tensorflow_probability/python/sts/dynamic_regression.py +++ b/tensorflow_probability/python/sts/dynamic_regression.py @@ -166,7 +166,7 @@ def __init__(self, Default value: 'DynamicLinearRegressionStateSpaceModel'. """ - + parameters = dict(locals()) with tf.name_scope( name or 'DynamicLinearRegressionStateSpaceModel') as name: dtype = dtype_util.common_dtype( @@ -215,6 +215,7 @@ def observation_matrix_fn(t): allow_nan_stats=allow_nan_stats, validate_args=validate_args, name=name) + self._parameters = parameters @property def drift_scale(self): diff --git a/tensorflow_probability/python/sts/local_level.py b/tensorflow_probability/python/sts/local_level.py index 0af59b67a0..bba44f437e 100644 --- a/tensorflow_probability/python/sts/local_level.py +++ b/tensorflow_probability/python/sts/local_level.py @@ -145,7 +145,7 @@ def __init__(self, name: Python `str` name prefixed to ops created by this class. Default value: "LocalLevelStateSpaceModel". """ - + parameters = dict(locals()) with tf.name_scope(name or 'LocalLevelStateSpaceModel') as name: # The initial state prior determines the dtype of sampled values. # Other model parameters must have the same dtype. @@ -180,6 +180,7 @@ def __init__(self, allow_nan_stats=allow_nan_stats, validate_args=validate_args, name=name) + self._parameters = parameters @property def level_scale(self): diff --git a/tensorflow_probability/python/sts/local_linear_trend.py b/tensorflow_probability/python/sts/local_linear_trend.py index cb803b7a7d..2ccb6cb0a7 100644 --- a/tensorflow_probability/python/sts/local_linear_trend.py +++ b/tensorflow_probability/python/sts/local_linear_trend.py @@ -158,7 +158,7 @@ def __init__(self, name: Python `str` name prefixed to ops created by this class. Default value: "LocalLinearTrendStateSpaceModel". """ - + parameters = dict(locals()) with tf.name_scope(name or 'LocalLinearTrendStateSpaceModel') as name: # The initial state prior determines the dtype of sampled values. # Other model parameters must have the same dtype. @@ -205,6 +205,7 @@ def __init__(self, allow_nan_stats=allow_nan_stats, validate_args=validate_args, name=name) + self._parameters = parameters @property def level_scale(self): diff --git a/tensorflow_probability/python/sts/seasonal.py b/tensorflow_probability/python/sts/seasonal.py index 8d6578c19d..0e665450ff 100644 --- a/tensorflow_probability/python/sts/seasonal.py +++ b/tensorflow_probability/python/sts/seasonal.py @@ -207,7 +207,7 @@ def __init__(self, {seasonal_init_args} """ - + parameters = dict(locals()) with tf.name_scope(name or 'SeasonalStateSpaceModel') as name: # The initial state prior determines the dtype of sampled values. # Other model parameters must have the same dtype. @@ -262,6 +262,7 @@ def __init__(self, allow_nan_stats=allow_nan_stats, validate_args=validate_args, name=name) + self._parameters = parameters @property def drift_scale(self): @@ -419,7 +420,7 @@ def __init__(self, {seasonal_init_args} """ - + parameters = dict(locals()) with tf.name_scope(name or 'ConstrainedSeasonalStateSpaceModel') as name: # The initial state prior determines the dtype of sampled values. @@ -483,6 +484,7 @@ def __init__(self, allow_nan_stats=allow_nan_stats, validate_args=validate_args, name=name) + self._parameters = parameters @property def drift_scale(self): diff --git a/tensorflow_probability/python/sts/semilocal_linear_trend.py b/tensorflow_probability/python/sts/semilocal_linear_trend.py index c3fb77fdee..f5f00b689d 100644 --- a/tensorflow_probability/python/sts/semilocal_linear_trend.py +++ b/tensorflow_probability/python/sts/semilocal_linear_trend.py @@ -173,7 +173,7 @@ def __init__(self, name: Python `str` name prefixed to ops created by this class. Default value: "SemiLocalLinearTrendStateSpaceModel". """ - + parameters = dict(locals()) with tf.name_scope(name or 'SemiLocalLinearTrendStateSpaceModel') as name: dtype = initial_state_prior.dtype @@ -211,6 +211,7 @@ def __init__(self, allow_nan_stats=allow_nan_stats, validate_args=validate_args, name=name) + self._parameters = parameters @property def level_scale(self): diff --git a/tensorflow_probability/python/sts/smooth_seasonal.py b/tensorflow_probability/python/sts/smooth_seasonal.py index 3de0c20efd..3c79e6f74a 100644 --- a/tensorflow_probability/python/sts/smooth_seasonal.py +++ b/tensorflow_probability/python/sts/smooth_seasonal.py @@ -200,7 +200,7 @@ def __init__(self, Default value: 'SmoothSeasonalStateSpaceModel'. """ - + parameters = dict(locals()) with tf.name_scope(name or 'SmoothSeasonalStateSpaceModel') as name: dtype = dtype_util.common_dtype( @@ -254,6 +254,7 @@ def __init__(self, allow_nan_stats=allow_nan_stats, validate_args=validate_args, name=name) + self._parameters = parameters @property def drift_scale(self): diff --git a/tensorflow_probability/python/sts/structural_time_series_test.py b/tensorflow_probability/python/sts/structural_time_series_test.py index 11c5e86896..e3fa7db931 100644 --- a/tensorflow_probability/python/sts/structural_time_series_test.py +++ b/tensorflow_probability/python/sts/structural_time_series_test.py @@ -23,6 +23,7 @@ import tensorflow.compat.v2 as tf import tensorflow_probability as tfp from tensorflow_probability.python import distributions as tfd +from tensorflow_probability.python.internal import samplers from tensorflow_probability.python.internal import test_util from tensorflow_probability.python.sts import Autoregressive from tensorflow_probability.python.sts import DynamicLinearRegression @@ -154,6 +155,14 @@ def test_state_space_model(self): ssm.latent_size_tensor())), model.latent_size) + # Verify that the SSM tracks its parameters. + observed_time_series = self.evaluate( + samplers.normal([10, 1], seed=test_util.test_seed())) + ssm_copy = ssm.copy(name='copied_ssm') + self.assertAllClose(*self.evaluate(( + ssm.log_prob(observed_time_series), + ssm_copy.log_prob(observed_time_series)))) + def test_log_joint(self): seed = test_util.test_seed_stream() model = self._build_sts() diff --git a/tensorflow_probability/python/sts/sum.py b/tensorflow_probability/python/sts/sum.py index c8f3e1373b..7fd015b094 100644 --- a/tensorflow_probability/python/sts/sum.py +++ b/tensorflow_probability/python/sts/sum.py @@ -215,7 +215,7 @@ def __init__(self, Raises: ValueError: if components have different `num_timesteps`. """ - + parameters = dict(locals()) with tf.name_scope(name or 'AdditiveStateSpaceModel') as name: # Check that all components have the same dtype dtype = tf.debugging.assert_same_float_dtype(component_ssms) @@ -333,6 +333,7 @@ def observation_noise_fn(t): validate_args=validate_args, allow_nan_stats=allow_nan_stats, name=name) + self._parameters = parameters class Sum(StructuralTimeSeries): diff --git a/tensorflow_probability/python/version.py b/tensorflow_probability/python/version.py index 2d6a3e8646..bb2e0a94fa 100644 --- a/tensorflow_probability/python/version.py +++ b/tensorflow_probability/python/version.py @@ -24,7 +24,7 @@ # stable release (indicated by `_VERSION_SUFFIX = ''`). Outside the context of a # release branch, the current version is by default assumed to be a # 'development' version, labeled 'dev'. -_VERSION_SUFFIX = 'rc1' +_VERSION_SUFFIX = 'rc2' # Example, '0.4.0-dev' __version__ = '.'.join([