diff --git a/tensorflow_probability/python/__init__.py b/tensorflow_probability/python/__init__.py index d5fd9afa7e..fcccb0644a 100644 --- a/tensorflow_probability/python/__init__.py +++ b/tensorflow_probability/python/__init__.py @@ -55,7 +55,7 @@ def _validate_tf_environment(package): # # Update this whenever we need to depend on a newer TensorFlow release. # - required_tensorflow_version = '2.4' + required_tensorflow_version = '2.5' # required_tensorflow_version = '1.15' # Needed internally -- DisableOnExport if (distutils.version.LooseVersion(tf.__version__) < diff --git a/tensorflow_probability/python/bijectors/BUILD b/tensorflow_probability/python/bijectors/BUILD index 783121242a..81dff46d00 100644 --- a/tensorflow_probability/python/bijectors/BUILD +++ b/tensorflow_probability/python/bijectors/BUILD @@ -167,7 +167,6 @@ multi_substrate_py_library( ":bijector", # tensorflow dep, "//tensorflow_probability/python/internal:assert_util", - "//tensorflow_probability/python/internal:auto_composite_tensor", "//tensorflow_probability/python/internal:distribution_util", ], ) @@ -214,7 +213,6 @@ multi_substrate_py_library( deps = [ ":bijector", # tensorflow dep, - "//tensorflow_probability/python/internal:auto_composite_tensor", "//tensorflow_probability/python/internal:dtype_util", "//tensorflow_probability/python/internal:parameter_properties", "//tensorflow_probability/python/internal:tensor_util", @@ -227,7 +225,6 @@ multi_substrate_py_library( deps = [ ":bijector", # tensorflow dep, - "//tensorflow_probability/python/internal:auto_composite_tensor", "//tensorflow_probability/python/internal:dtype_util", "//tensorflow_probability/python/internal:parameter_properties", "//tensorflow_probability/python/internal:tensor_util", @@ -242,7 +239,6 @@ multi_substrate_py_library( ":softplus", # tensorflow dep, "//tensorflow_probability/python/internal:assert_util", - "//tensorflow_probability/python/internal:auto_composite_tensor", "//tensorflow_probability/python/internal:dtype_util", "//tensorflow_probability/python/internal:parameter_properties", "//tensorflow_probability/python/internal:tensor_util", @@ -359,7 +355,6 @@ multi_substrate_py_library( # numpy dep, # tensorflow dep, "//tensorflow_probability/python/internal:assert_util", - "//tensorflow_probability/python/internal:auto_composite_tensor", "//tensorflow_probability/python/internal:distribution_util", "//tensorflow_probability/python/internal:dtype_util", "//tensorflow_probability/python/internal:prefer_static", @@ -375,7 +370,6 @@ multi_substrate_py_library( ":cholesky_outer_product", # tensorflow dep, "//tensorflow_probability/python/internal:assert_util", - "//tensorflow_probability/python/internal:auto_composite_tensor", "//tensorflow_probability/python/internal:dtype_util", "//tensorflow_probability/python/internal:prefer_static", ], @@ -388,7 +382,6 @@ multi_substrate_py_library( ":bijector", ":fill_triangular", # tensorflow dep, - "//tensorflow_probability/python/internal:auto_composite_tensor", "//tensorflow_probability/python/internal:prefer_static", "//tensorflow_probability/python/internal:tensorshape_util", ], @@ -400,7 +393,6 @@ multi_substrate_py_library( deps = [ ":bijector", # tensorflow dep, - "//tensorflow_probability/python/internal:auto_composite_tensor", "//tensorflow_probability/python/internal:prefer_static", ], ) @@ -412,7 +404,6 @@ multi_substrate_py_library( ":bijector", # tensorflow dep, # "//tensorflow_probability/google:platform_google", # DisableOnExport - "//tensorflow_probability/python/internal:auto_composite_tensor", "//tensorflow_probability/python/internal:dtype_util", ], ) @@ -463,7 +454,6 @@ multi_substrate_py_library( # numpy dep, # tensorflow dep, "//tensorflow_probability/python/internal:assert_util", - "//tensorflow_probability/python/internal:auto_composite_tensor", "//tensorflow_probability/python/internal:distribution_util", "//tensorflow_probability/python/internal:tensorshape_util", "//tensorflow_probability/python/math:linalg", @@ -478,7 +468,6 @@ multi_substrate_py_library( ":softplus", # tensorflow dep, "//tensorflow_probability/python/internal:assert_util", - "//tensorflow_probability/python/internal:auto_composite_tensor", "//tensorflow_probability/python/internal:dtype_util", "//tensorflow_probability/python/internal:parameter_properties", "//tensorflow_probability/python/internal:tensor_util", @@ -495,7 +484,6 @@ multi_substrate_py_library( ":sigmoid", ":softplus", # tensorflow dep, - "//tensorflow_probability/python/internal:auto_composite_tensor", "//tensorflow_probability/python/internal:dtype_util", "//tensorflow_probability/python/internal:parameter_properties", "//tensorflow_probability/python/internal:tensor_util", @@ -531,7 +519,6 @@ multi_substrate_py_library( ":softplus", # tensorflow dep, "//tensorflow_probability/python/internal:assert_util", - "//tensorflow_probability/python/internal:auto_composite_tensor", "//tensorflow_probability/python/internal:dtype_util", "//tensorflow_probability/python/internal:parameter_properties", "//tensorflow_probability/python/internal:tensor_util", @@ -546,7 +533,6 @@ multi_substrate_py_library( ":softplus", # tensorflow dep, "//tensorflow_probability/python/internal:assert_util", - "//tensorflow_probability/python/internal:auto_composite_tensor", "//tensorflow_probability/python/internal:dtype_util", "//tensorflow_probability/python/internal:parameter_properties", "//tensorflow_probability/python/internal:tensor_util", @@ -561,7 +547,6 @@ multi_substrate_py_library( ":softplus", # tensorflow dep, "//tensorflow_probability/python/internal:assert_util", - "//tensorflow_probability/python/internal:auto_composite_tensor", "//tensorflow_probability/python/internal:dtype_util", "//tensorflow_probability/python/internal:parameter_properties", "//tensorflow_probability/python/internal:tensor_util", @@ -574,7 +559,6 @@ multi_substrate_py_library( deps = [ ":bijector", # tensorflow dep, - "//tensorflow_probability/python/internal:auto_composite_tensor", ], ) @@ -584,7 +568,6 @@ multi_substrate_py_library( deps = [ ":bijector", # tensorflow dep, - "//tensorflow_probability/python/internal:auto_composite_tensor", ], ) @@ -593,6 +576,7 @@ multi_substrate_py_library( srcs = ["invert.py"], deps = [ ":bijector", + "//tensorflow_probability/python/internal:auto_composite_tensor", ], ) @@ -605,7 +589,6 @@ multi_substrate_py_library( ":softplus", # tensorflow dep, "//tensorflow_probability/python/internal:assert_util", - "//tensorflow_probability/python/internal:auto_composite_tensor", "//tensorflow_probability/python/internal:distribution_util", "//tensorflow_probability/python/internal:dtype_util", "//tensorflow_probability/python/internal:parameter_properties", @@ -659,7 +642,6 @@ multi_substrate_py_library( ":bijector", # tensorflow dep, "//tensorflow_probability/python/internal:assert_util", - "//tensorflow_probability/python/internal:auto_composite_tensor", "//tensorflow_probability/python/internal:dtype_util", ], ) @@ -673,7 +655,6 @@ multi_substrate_py_library( # numpy dep, # tensorflow dep, "//tensorflow_probability/python/internal:assert_util", - "//tensorflow_probability/python/internal:auto_composite_tensor", "//tensorflow_probability/python/internal:dtype_util", "//tensorflow_probability/python/internal:parameter_properties", "//tensorflow_probability/python/internal:tensor_util", @@ -702,7 +683,6 @@ multi_substrate_py_library( # numpy dep, # tensorflow dep, "//tensorflow_probability/python/internal:assert_util", - "//tensorflow_probability/python/internal:auto_composite_tensor", "//tensorflow_probability/python/internal:distribution_util", "//tensorflow_probability/python/internal:special_math", ], @@ -715,7 +695,6 @@ multi_substrate_py_library( ":bijector", # tensorflow dep, "//tensorflow_probability/python/internal:assert_util", - "//tensorflow_probability/python/internal:auto_composite_tensor", "//tensorflow_probability/python/internal:distribution_util", ], ) @@ -727,7 +706,6 @@ multi_substrate_py_library( ":bijector", # tensorflow dep, "//tensorflow_probability/python/internal:assert_util", - "//tensorflow_probability/python/internal:auto_composite_tensor", "//tensorflow_probability/python/internal:distribution_util", ], ) @@ -740,7 +718,6 @@ multi_substrate_py_library( # numpy dep, # tensorflow dep, "//tensorflow_probability/python/internal:assert_util", - "//tensorflow_probability/python/internal:auto_composite_tensor", "//tensorflow_probability/python/internal:parameter_properties", "//tensorflow_probability/python/internal:prefer_static", "//tensorflow_probability/python/internal:tensor_util", @@ -756,7 +733,6 @@ multi_substrate_py_library( # numpy dep, # tensorflow dep, "//tensorflow_probability/python/internal:assert_util", - "//tensorflow_probability/python/internal:auto_composite_tensor", "//tensorflow_probability/python/internal:distribution_util", "//tensorflow_probability/python/internal:dtype_util", "//tensorflow_probability/python/internal:parameter_properties", @@ -771,7 +747,6 @@ multi_substrate_py_library( ":bijector", # tensorflow dep, "//tensorflow_probability/python/internal:assert_util", - "//tensorflow_probability/python/internal:auto_composite_tensor", "//tensorflow_probability/python/internal:parameter_properties", "//tensorflow_probability/python/internal:tensor_util", ], @@ -784,7 +759,6 @@ multi_substrate_py_library( ":bijector", # tensorflow dep, "//tensorflow_probability/python/internal:assert_util", - "//tensorflow_probability/python/internal:auto_composite_tensor", ], ) @@ -797,7 +771,6 @@ multi_substrate_py_library( ":softplus", # tensorflow dep, "//tensorflow_probability/python/internal:assert_util", - "//tensorflow_probability/python/internal:auto_composite_tensor", "//tensorflow_probability/python/internal:dtype_util", "//tensorflow_probability/python/internal:parameter_properties", "//tensorflow_probability/python/internal:tensor_util", @@ -814,7 +787,6 @@ multi_substrate_py_library( ":softplus", # tensorflow dep, "//tensorflow_probability/python/internal:assert_util", - "//tensorflow_probability/python/internal:auto_composite_tensor", "//tensorflow_probability/python/internal:parameter_properties", "//tensorflow_probability/python/internal:tensor_util", ], @@ -839,7 +811,6 @@ multi_substrate_py_library( ":bijector", # tensorflow dep, "//tensorflow_probability/python/internal:assert_util", - "//tensorflow_probability/python/internal:auto_composite_tensor", "//tensorflow_probability/python/internal:distribution_util", ], ) @@ -852,7 +823,6 @@ multi_substrate_py_library( # numpy dep, # tensorflow dep, "//tensorflow_probability/python/internal:assert_util", - "//tensorflow_probability/python/internal:auto_composite_tensor", "//tensorflow_probability/python/internal:dtype_util", "//tensorflow_probability/python/internal:nest_util", "//tensorflow_probability/python/internal:parameter_properties", @@ -869,7 +839,6 @@ multi_substrate_py_library( ":bijector", ":invert", # tensorflow dep, - "//tensorflow_probability/python/internal:auto_composite_tensor", "//tensorflow_probability/python/internal:nest_util", "//tensorflow_probability/python/internal:parameter_properties", "//tensorflow_probability/python/internal:prefer_static", @@ -898,7 +867,6 @@ multi_substrate_py_library( ":bijector", # tensorflow dep, "//tensorflow_probability/python/internal:assert_util", - "//tensorflow_probability/python/internal:auto_composite_tensor", "//tensorflow_probability/python/internal:parameter_properties", "//tensorflow_probability/python/internal:tensor_util", ], @@ -910,7 +878,6 @@ multi_substrate_py_library( deps = [ ":bijector", # tensorflow dep, - "//tensorflow_probability/python/internal:auto_composite_tensor", "//tensorflow_probability/python/math", ], ) @@ -925,7 +892,6 @@ multi_substrate_py_library( # numpy dep, # tensorflow dep, "//tensorflow_probability/python/internal:assert_util", - "//tensorflow_probability/python/internal:auto_composite_tensor", "//tensorflow_probability/python/internal:dtype_util", "//tensorflow_probability/python/internal:parameter_properties", "//tensorflow_probability/python/internal:tensor_util", @@ -940,7 +906,6 @@ multi_substrate_py_library( ":bijector", # tensorflow dep, "//tensorflow_probability/python/internal:assert_util", - "//tensorflow_probability/python/internal:auto_composite_tensor", "//tensorflow_probability/python/internal:dtype_util", "//tensorflow_probability/python/internal:parameter_properties", "//tensorflow_probability/python/internal:tensor_util", @@ -956,7 +921,6 @@ multi_substrate_py_library( ":softplus", # tensorflow dep, "//tensorflow_probability/python/internal:assert_util", - "//tensorflow_probability/python/internal:auto_composite_tensor", "//tensorflow_probability/python/internal:dtype_util", "//tensorflow_probability/python/internal:parameter_properties", "//tensorflow_probability/python/math", @@ -972,7 +936,6 @@ multi_substrate_py_library( # numpy dep, # tensorflow dep, "//tensorflow_probability/python/internal:assert_util", - "//tensorflow_probability/python/internal:auto_composite_tensor", "//tensorflow_probability/python/internal:dtype_util", "//tensorflow_probability/python/internal:prefer_static", ], @@ -985,7 +948,6 @@ multi_substrate_py_library( ":bijector", # tensorflow dep, "//tensorflow_probability/python/internal:assert_util", - "//tensorflow_probability/python/internal:auto_composite_tensor", "//tensorflow_probability/python/internal:distribution_util", "//tensorflow_probability/python/internal:dtype_util", "//tensorflow_probability/python/internal:parameter_properties", @@ -1000,7 +962,6 @@ multi_substrate_py_library( ":bijector", # tensorflow dep, "//tensorflow_probability/python/internal:assert_util", - "//tensorflow_probability/python/internal:auto_composite_tensor", "//tensorflow_probability/python/internal:distribution_util", "//tensorflow_probability/python/internal:dtype_util", ], @@ -1014,7 +975,6 @@ multi_substrate_py_library( # numpy dep, # tensorflow dep, "//tensorflow_probability/python/internal:assert_util", - "//tensorflow_probability/python/internal:auto_composite_tensor", "//tensorflow_probability/python/internal:parameter_properties", "//tensorflow_probability/python/internal:prefer_static", "//tensorflow_probability/python/internal:tensor_util", @@ -1030,7 +990,6 @@ multi_substrate_py_library( # numpy dep, # tensorflow dep, "//tensorflow_probability/python/internal:assert_util", - "//tensorflow_probability/python/internal:auto_composite_tensor", "//tensorflow_probability/python/internal:distribution_util", ], ) @@ -1042,7 +1001,6 @@ multi_substrate_py_library( ":bijector", # tensorflow dep, "//tensorflow_probability/python/internal:assert_util", - "//tensorflow_probability/python/internal:auto_composite_tensor", "//tensorflow_probability/python/internal:dtype_util", "//tensorflow_probability/python/internal:prefer_static", ], @@ -1055,7 +1013,6 @@ multi_substrate_py_library( ":bijector", # numpy dep, # tensorflow dep, - "//tensorflow_probability/python/internal:auto_composite_tensor", ], ) @@ -1065,7 +1022,6 @@ multi_substrate_py_library( deps = [ ":bijector", # tensorflow dep, - "//tensorflow_probability/python/internal:auto_composite_tensor", ], ) @@ -1077,7 +1033,6 @@ multi_substrate_py_library( # numpy dep, # tensorflow dep, "//tensorflow_probability/python/internal:assert_util", - "//tensorflow_probability/python/internal:auto_composite_tensor", "//tensorflow_probability/python/internal:distribution_util", "//tensorflow_probability/python/internal:dtype_util", "//tensorflow_probability/python/internal:parameter_properties", @@ -1093,7 +1048,6 @@ multi_substrate_py_library( ":softplus", # tensorflow dep, "//tensorflow_probability/python/internal:assert_util", - "//tensorflow_probability/python/internal:auto_composite_tensor", "//tensorflow_probability/python/internal:parameter_properties", "//tensorflow_probability/python/internal:tensor_util", ], @@ -1141,7 +1095,6 @@ multi_substrate_py_test( # numpy dep, # tensorflow dep, "//tensorflow_probability/python/distributions", - "//tensorflow_probability/python/internal:auto_composite_tensor", "//tensorflow_probability/python/internal:cache_util", "//tensorflow_probability/python/internal:parameter_properties", "//tensorflow_probability/python/internal:tensor_util", diff --git a/tensorflow_probability/python/bijectors/absolute_value.py b/tensorflow_probability/python/bijectors/absolute_value.py index 735398c4ae..3c84a85156 100644 --- a/tensorflow_probability/python/bijectors/absolute_value.py +++ b/tensorflow_probability/python/bijectors/absolute_value.py @@ -22,7 +22,6 @@ from tensorflow_probability.python.bijectors import bijector from tensorflow_probability.python.internal import assert_util -from tensorflow_probability.python.internal import auto_composite_tensor from tensorflow_probability.python.internal import dtype_util __all__ = [ @@ -30,8 +29,7 @@ ] -@auto_composite_tensor.auto_composite_tensor( - omit_kwargs=('name',), module_name='tfp.bijectors') +@bijector.auto_composite_tensor_bijector class AbsoluteValue(bijector.AutoCompositeTensorBijector): """Computes `Y = g(X) = Abs(X)`, element-wise. diff --git a/tensorflow_probability/python/bijectors/ascending.py b/tensorflow_probability/python/bijectors/ascending.py index 0f2223d466..be3594445c 100644 --- a/tensorflow_probability/python/bijectors/ascending.py +++ b/tensorflow_probability/python/bijectors/ascending.py @@ -22,7 +22,6 @@ from tensorflow_probability.python.bijectors import bijector from tensorflow_probability.python.internal import assert_util -from tensorflow_probability.python.internal import auto_composite_tensor __all__ = [ @@ -30,8 +29,7 @@ ] -@auto_composite_tensor.auto_composite_tensor( - omit_kwargs=('name',), module_name='tfp.bijectors') +@bijector.auto_composite_tensor_bijector class Ascending(bijector.AutoCompositeTensorBijector): """Maps unconstrained R^n to R^n in ascending order. diff --git a/tensorflow_probability/python/bijectors/bijector.py b/tensorflow_probability/python/bijectors/bijector.py index b5df1d41d1..fc45cc6644 100644 --- a/tensorflow_probability/python/bijectors/bijector.py +++ b/tensorflow_probability/python/bijectors/bijector.py @@ -20,6 +20,7 @@ import abc import contextlib +# import functools # Dependency imports import numpy as np @@ -1598,26 +1599,39 @@ def _composite_tensor_shape_params(self): return () -class AutoCompositeTensorBijector( - Bijector, auto_composite_tensor.AutoCompositeTensor): - r"""Base for `CompositeTensor` bijectors with auto-generated `TypeSpec`s. +# Temporarily disable AutoCT for TFP 0.13 release +# class AutoCompositeTensorBijector( +# Bijector, auto_composite_tensor.AutoCompositeTensor): +# r"""Base for `CompositeTensor` bijectors with auto-generated `TypeSpec`s. - `CompositeTensor` objects are able to pass in and out of `tf.function` and - `tf.while_loop`, or serve as part of the signature of a TF saved model. - `Bijector` subclasses that follow the contract of - `tfp.experimental.auto_composite_tensor` may be defined as `CompositeTensor`s - by inheriting from `AutoCompositeTensorBijector` and applying a class - decorator as shown here: +# `CompositeTensor` objects are able to pass in and out of `tf.function` and +# `tf.while_loop`, or serve as part of the signature of a TF saved model. +# `Bijector` subclasses that follow the contract of +# `tfp.experimental.auto_composite_tensor` may be defined as `CompositeTensor`s # pylint: disable=line-too-long +# by inheriting from `AutoCompositeTensorBijector` and applying a class +# decorator as shown here: - ```python - @tfp.experimental.auto_composite_tensor( - omit_kwargs=('name',), module_name='my_module') - class MyBijector(tfb.AutoCompositeTensorBijector): +# ```python +# @tfp.experimental.auto_composite_tensor( +# omit_kwargs=('name',), module_name='my_module') +# class MyBijector(tfb.AutoCompositeTensorBijector): - # The remainder of the subclass implementation is unchanged. - ``` - """ - pass +# # The remainder of the subclass implementation is unchanged. +# ``` +# """ +# pass + + +# auto_composite_tensor_bijector = functools.partial( +# auto_composite_tensor.auto_composite_tensor, +# omit_kwargs=('parameters',), +# non_identifying_kwargs=('name',), +# module_name='tfp.bijectors') + +AutoCompositeTensorBijector = Bijector + + +auto_composite_tensor_bijector = lambda cls, **kwargs: cls def check_valid_ndims(ndims, validate=True): diff --git a/tensorflow_probability/python/bijectors/bijector_properties_test.py b/tensorflow_probability/python/bijectors/bijector_properties_test.py index c6063d4e25..8eab905268 100644 --- a/tensorflow_probability/python/bijectors/bijector_properties_test.py +++ b/tensorflow_probability/python/bijectors/bijector_properties_test.py @@ -34,7 +34,7 @@ from tensorflow_probability.python.internal import tensor_util from tensorflow_probability.python.internal import tensorshape_util from tensorflow_probability.python.internal import test_util -from tensorflow_probability.python.util.deferred_tensor import DeferredTensor +# from tensorflow_probability.python.util.deferred_tensor import DeferredTensor TF2_FRIENDLY_BIJECTORS = ( @@ -185,6 +185,7 @@ COMPOSITE_TENSOR_IS_BROKEN = [ 'BatchNormalization', # tf.layers arg + 'Inline', # callable 'RationalQuadraticSpline', # TODO(b/185628453): Debug loss of static info. ] @@ -197,7 +198,6 @@ # TODO(b/182603117): Enable AutoCT for meta-bijectors and LinearOperator. AUTO_COMPOSITE_TENSOR_IS_BROKEN = [ 'FillScaleTriL', - 'Invert', 'ScaleMatvecDiag', 'ScaleMatvecTriL', ] @@ -244,8 +244,8 @@ def _constraint(param): # TODO(b/141098791): Eliminate this. -@experimental.auto_composite_tensor -class CallableModule(tf.Module, experimental.AutoCompositeTensor): +# @experimental.auto_composite_tensor +class CallableModule(tf.Module): # , experimental.AutoCompositeTensor): """Convenience object for capturing variables closed over by Inline.""" def __init__(self, fn, varobj): @@ -630,7 +630,7 @@ def exception(bijector): return False if isinstance(bijector, tfb.Softfloor): return True - if isinstance(bijector, tfb.Invert): + if is_invert(bijector): return exception(bijector.bijector) return False if (bijector.forward_min_event_ndims == 0 and @@ -887,32 +887,16 @@ def testEquality(self, bijector_name, data): @hp.given(hps.data()) @tfp_hps.tfp_hp_settings() def testCompositeTensor(self, bijector_name, data): - + # Test that making a composite tensor of this bijector doesn't throw any + # errors. bijector, event_dim = self._draw_bijector( - bijector_name, data, - batch_shape=[], - validate_args=True, + bijector_name, data, batch_shape=[], allowed_bijectors=(set(TF2_FRIENDLY_BIJECTORS) - set(COMPOSITE_TENSOR_IS_BROKEN))) - - # TODO(b/182603117): Remove "if" condition and s/composite_bij/bijector - # when AutoCT is enabled for meta-bijectors and LinearOperator. - if type(bijector).__name__ in AUTO_COMPOSITE_TENSOR_IS_BROKEN: - composite_bij = experimental.as_composite(bijector) - else: - composite_bij = bijector - - if not tf.executing_eagerly(): - composite_bij = tf.nest.map_structure( - lambda x: (tf.convert_to_tensor(x) # pylint: disable=g-long-lambda - if isinstance(x, DeferredTensor) else x), - composite_bij, - expand_composites=True) - - self.assertIsInstance(composite_bij, tf.__internal__.CompositeTensor) + composite_bij = experimental.as_composite(bijector) flat = tf.nest.flatten(composite_bij, expand_composites=True) - unflat = tf.nest.pack_sequence_as( - composite_bij, flat, expand_composites=True) + unflat = tf.nest.pack_sequence_as(composite_bij, flat, + expand_composites=True) # Compare forward maps before and after compositing. n = 3 @@ -927,26 +911,6 @@ def testCompositeTensor(self, bijector_name, data): after_xs = unflat.inverse(ys) self.assertAllClose(*self.evaluate((before_xs, after_xs))) - # Input to tf.function - self.assertAllClose( - before_ys, - tf.function(lambda b: b.forward(xs))(composite_bij), - rtol=COMPOSITE_TENSOR_RTOL[bijector_name], - atol=COMPOSITE_TENSOR_ATOL[bijector_name]) - - # Forward mapping: Check differentiation through forward mapping with - # respect to the input and parameter variables. Also check that any - # variables are not referenced overmuch. - xs = self._draw_domain_tensor(bijector, data, event_dim) - wrt_vars = [xs] + [v for v in composite_bij.trainable_variables - if v.dtype.is_floating] - with tf.GradientTape() as tape: - tape.watch(wrt_vars) - # TODO(b/73073515): Fix graph mode gradients with bijector caching. - ys = bijector.forward(xs + 0) - grads = tape.gradient(ys, wrt_vars) - assert_no_none_grad(bijector, 'forward', wrt_vars, grads) - def ensure_nonzero(x): return tf.where(x < 1e-6, tf.constant(1e-3, x.dtype), x) diff --git a/tensorflow_probability/python/bijectors/bijector_test.py b/tensorflow_probability/python/bijectors/bijector_test.py index cb45c1fae1..0d26899e63 100644 --- a/tensorflow_probability/python/bijectors/bijector_test.py +++ b/tensorflow_probability/python/bijectors/bijector_test.py @@ -25,7 +25,7 @@ import tensorflow.compat.v1 as tf1 import tensorflow.compat.v2 as tf from tensorflow_probability.python import bijectors as tfb -from tensorflow_probability.python.internal import auto_composite_tensor +from tensorflow_probability.python.bijectors import bijector as bijector_lib from tensorflow_probability.python.internal import cache_util from tensorflow_probability.python.internal import tensor_util from tensorflow_probability.python.internal import test_util @@ -768,7 +768,7 @@ def testNestedCondition(self): mock_method.assert_called_once_with(mock.ANY, arg1=arg1, arg2=arg2) -@auto_composite_tensor.auto_composite_tensor(omit_kwargs=('name',)) +@bijector_lib.auto_composite_tensor_bijector class CompositeForwardBijector(tfb.AutoCompositeTensorBijector): def __init__(self, scale=2., validate_args=False, name=None): @@ -790,25 +790,26 @@ def _forward_log_det_jacobian(self, _): return tf.math.log(self._scale) -@test_util.test_all_tf_execution_regimes -class AutoCompositeTensorBijectorTest(test_util.TestCase): + # Test disabled temporarily for TFP 0.13 release. +# @test_util.test_all_tf_execution_regimes +# class AutoCompositeTensorBijectorTest(test_util.TestCase): - def test_disable_ct_bijector(self): +# def test_disable_ct_bijector(self): - ct_bijector = CompositeForwardBijector() - self.assertIsInstance(ct_bijector, tf.__internal__.CompositeTensor) +# ct_bijector = CompositeForwardBijector() +# self.assertIsInstance(ct_bijector, tf.__internal__.CompositeTensor) - non_ct_bijector = ForwardOnlyBijector() - self.assertNotIsInstance(non_ct_bijector, tf.__internal__.CompositeTensor) +# non_ct_bijector = ForwardOnlyBijector() +# self.assertNotIsInstance(non_ct_bijector, tf.__internal__.CompositeTensor) - flat = tf.nest.flatten(ct_bijector, expand_composites=True) - unflat = tf.nest.pack_sequence_as( - ct_bijector, flat, expand_composites=True) +# flat = tf.nest.flatten(ct_bijector, expand_composites=True) +# unflat = tf.nest.pack_sequence_as( +# ct_bijector, flat, expand_composites=True) - x = tf.constant([2., 3.]) - self.assertAllClose( - non_ct_bijector.forward(x), - tf.function(lambda b: b.forward(x))(unflat)) +# x = tf.constant([2., 3.]) +# self.assertAllClose( +# non_ct_bijector.forward(x), +# tf.function(lambda b: b.forward(x))(unflat)) if __name__ == '__main__': diff --git a/tensorflow_probability/python/bijectors/categorical_to_discrete.py b/tensorflow_probability/python/bijectors/categorical_to_discrete.py index 91a4426de4..e762bf52bb 100644 --- a/tensorflow_probability/python/bijectors/categorical_to_discrete.py +++ b/tensorflow_probability/python/bijectors/categorical_to_discrete.py @@ -36,7 +36,8 @@ ] -class CategoricalToDiscrete(bijector.Bijector): +@bijector.auto_composite_tensor_bijector +class CategoricalToDiscrete(bijector.AutoCompositeTensorBijector): """Bijector which computes `Y = g(X) = values[X]`. Example Usage: diff --git a/tensorflow_probability/python/bijectors/cholesky_outer_product.py b/tensorflow_probability/python/bijectors/cholesky_outer_product.py index e826aacf28..9b1ebc7780 100644 --- a/tensorflow_probability/python/bijectors/cholesky_outer_product.py +++ b/tensorflow_probability/python/bijectors/cholesky_outer_product.py @@ -24,7 +24,6 @@ from tensorflow_probability.python.bijectors import bijector from tensorflow_probability.python.internal import assert_util -from tensorflow_probability.python.internal import auto_composite_tensor from tensorflow_probability.python.internal import distribution_util from tensorflow_probability.python.internal import dtype_util from tensorflow_probability.python.internal import prefer_static as ps @@ -36,8 +35,7 @@ ] -@auto_composite_tensor.auto_composite_tensor( - omit_kwargs=('name',), module_name='tfp.bijectors') +@bijector.auto_composite_tensor_bijector class CholeskyOuterProduct(bijector.AutoCompositeTensorBijector): """Compute `g(X) = X @ X.T`; X is lower-triangular, positive-diagonal matrix. diff --git a/tensorflow_probability/python/bijectors/cholesky_to_inv_cholesky.py b/tensorflow_probability/python/bijectors/cholesky_to_inv_cholesky.py index f1ec24ddac..c3cdebad11 100644 --- a/tensorflow_probability/python/bijectors/cholesky_to_inv_cholesky.py +++ b/tensorflow_probability/python/bijectors/cholesky_to_inv_cholesky.py @@ -23,7 +23,6 @@ from tensorflow_probability.python.bijectors import bijector from tensorflow_probability.python.bijectors.cholesky_outer_product import CholeskyOuterProduct from tensorflow_probability.python.internal import assert_util -from tensorflow_probability.python.internal import auto_composite_tensor from tensorflow_probability.python.internal import dtype_util from tensorflow_probability.python.internal import prefer_static as ps @@ -33,8 +32,7 @@ ] -@auto_composite_tensor.auto_composite_tensor( - omit_kwargs=('name',), module_name='tfp.bijectors') +@bijector.auto_composite_tensor_bijector class CholeskyToInvCholesky(bijector.AutoCompositeTensorBijector): """Maps the Cholesky factor of `M` to the Cholesky factor of `M^{-1}`. diff --git a/tensorflow_probability/python/bijectors/correlation_cholesky.py b/tensorflow_probability/python/bijectors/correlation_cholesky.py index c83985a23d..1f5e425f8f 100644 --- a/tensorflow_probability/python/bijectors/correlation_cholesky.py +++ b/tensorflow_probability/python/bijectors/correlation_cholesky.py @@ -24,7 +24,6 @@ from tensorflow_probability.python.bijectors import bijector from tensorflow_probability.python.bijectors import fill_triangular -from tensorflow_probability.python.internal import auto_composite_tensor from tensorflow_probability.python.internal import prefer_static as ps from tensorflow_probability.python.internal import tensorshape_util @@ -33,8 +32,7 @@ ] -@auto_composite_tensor.auto_composite_tensor( - omit_kwargs=('name',), module_name='tfp.bijectors') +@bijector.auto_composite_tensor_bijector class CorrelationCholesky(bijector.AutoCompositeTensorBijector): """Maps unconstrained reals to Cholesky-space correlation matrices. diff --git a/tensorflow_probability/python/bijectors/cumsum.py b/tensorflow_probability/python/bijectors/cumsum.py index 0023840500..1f89d661fc 100644 --- a/tensorflow_probability/python/bijectors/cumsum.py +++ b/tensorflow_probability/python/bijectors/cumsum.py @@ -20,7 +20,6 @@ import tensorflow.compat.v2 as tf from tensorflow_probability.python.bijectors import bijector -from tensorflow_probability.python.internal import auto_composite_tensor from tensorflow_probability.python.internal import prefer_static __all__ = [ @@ -28,8 +27,7 @@ ] -@auto_composite_tensor.auto_composite_tensor( - omit_kwargs=('name',), module_name='tfp.bijectors') +@bijector.auto_composite_tensor_bijector class Cumsum(bijector.AutoCompositeTensorBijector): """Computes the cumulative sum of a tensor along a specified axis. diff --git a/tensorflow_probability/python/bijectors/discrete_cosine_transform.py b/tensorflow_probability/python/bijectors/discrete_cosine_transform.py index d66e62349d..c9a96eb897 100644 --- a/tensorflow_probability/python/bijectors/discrete_cosine_transform.py +++ b/tensorflow_probability/python/bijectors/discrete_cosine_transform.py @@ -21,7 +21,6 @@ import tensorflow.compat.v2 as tf from tensorflow_probability.python.bijectors import bijector -from tensorflow_probability.python.internal import auto_composite_tensor from tensorflow_probability.python.internal import dtype_util @@ -30,8 +29,7 @@ ] -@auto_composite_tensor.auto_composite_tensor( - omit_kwargs=('name',), module_name='tfp.bijectors') +@bijector.auto_composite_tensor_bijector class DiscreteCosineTransform(bijector.AutoCompositeTensorBijector): """Compute `Y = g(X) = DCT(X)`, where DCT type is indicated by the `type` arg. diff --git a/tensorflow_probability/python/bijectors/exp.py b/tensorflow_probability/python/bijectors/exp.py index 0fe2565118..6e59ef49de 100644 --- a/tensorflow_probability/python/bijectors/exp.py +++ b/tensorflow_probability/python/bijectors/exp.py @@ -20,9 +20,9 @@ import tensorflow.compat.v2 as tf +from tensorflow_probability.python.bijectors import bijector as bijector_lib from tensorflow_probability.python.bijectors import invert from tensorflow_probability.python.bijectors import power_transform -from tensorflow_probability.python.internal import auto_composite_tensor __all__ = [ @@ -31,8 +31,7 @@ ] -@auto_composite_tensor.auto_composite_tensor( - omit_kwargs=('name',), module_name='tfp.bijectors') +@bijector_lib.auto_composite_tensor_bijector class Exp(power_transform.PowerTransform): """Compute `Y = g(X) = exp(X)`. @@ -76,10 +75,8 @@ def __init__(self, # TODO(b/182603117): Remove `AutoCompositeTensor` when `Invert` subclasses # `AutoCompositeTensor` and ensure `tf.saved_model` still works. -@auto_composite_tensor.auto_composite_tensor( - omit_kwargs=('name',), module_name='tfp.bijectors') -class Log(invert.Invert, - auto_composite_tensor.AutoCompositeTensor): +@bijector_lib.auto_composite_tensor_bijector +class Log(invert.Invert): """Compute `Y = log(X)`. This is `Invert(Exp())`.""" def __init__(self, validate_args=False, name='log'): diff --git a/tensorflow_probability/python/bijectors/expm1.py b/tensorflow_probability/python/bijectors/expm1.py index d51b644678..69daa73d97 100644 --- a/tensorflow_probability/python/bijectors/expm1.py +++ b/tensorflow_probability/python/bijectors/expm1.py @@ -21,7 +21,6 @@ import tensorflow.compat.v2 as tf from tensorflow_probability.python.bijectors import bijector from tensorflow_probability.python.bijectors import invert -from tensorflow_probability.python.internal import auto_composite_tensor __all__ = [ @@ -30,8 +29,7 @@ ] -@auto_composite_tensor.auto_composite_tensor( - omit_kwargs=('name',), module_name='tfp.bijectors') +@bijector.auto_composite_tensor_bijector class Expm1(bijector.AutoCompositeTensorBijector): """Compute `Y = g(X) = exp(X) - 1`. @@ -95,9 +93,8 @@ def _forward_log_det_jacobian(self, x): # TODO(b/182603117): Remove `AutoCompositeTensor` when `Invert` subclasses # `AutoCompositeTensor`. -@auto_composite_tensor.auto_composite_tensor( - omit_kwargs=('name',), module_name='tfp.bijectors') -class Log1p(invert.Invert, auto_composite_tensor.AutoCompositeTensor): +@bijector.auto_composite_tensor_bijector +class Log1p(invert.Invert): """Compute `Y = log1p(X)`. This is `Invert(Expm1())`.""" def __init__(self, validate_args=False, name='log1p'): diff --git a/tensorflow_probability/python/bijectors/fill_triangular.py b/tensorflow_probability/python/bijectors/fill_triangular.py index ac3ba881dc..84f1643515 100644 --- a/tensorflow_probability/python/bijectors/fill_triangular.py +++ b/tensorflow_probability/python/bijectors/fill_triangular.py @@ -25,7 +25,6 @@ from tensorflow_probability.python.bijectors import bijector from tensorflow_probability.python.internal import assert_util -from tensorflow_probability.python.internal import auto_composite_tensor from tensorflow_probability.python.internal import tensorshape_util from tensorflow_probability.python.math.linalg import fill_triangular from tensorflow_probability.python.math.linalg import fill_triangular_inverse @@ -36,8 +35,7 @@ ] -@auto_composite_tensor.auto_composite_tensor( - omit_kwargs=('name',), module_name='tfp.bijectors') +@bijector.auto_composite_tensor_bijector class FillTriangular(bijector.AutoCompositeTensorBijector): """Transforms vectors to triangular. diff --git a/tensorflow_probability/python/bijectors/frechet_cdf.py b/tensorflow_probability/python/bijectors/frechet_cdf.py index 9f8297b527..24f4261f4b 100644 --- a/tensorflow_probability/python/bijectors/frechet_cdf.py +++ b/tensorflow_probability/python/bijectors/frechet_cdf.py @@ -23,7 +23,6 @@ from tensorflow_probability.python.bijectors import bijector from tensorflow_probability.python.bijectors import softplus as softplus_bijector from tensorflow_probability.python.internal import assert_util -from tensorflow_probability.python.internal import auto_composite_tensor from tensorflow_probability.python.internal import dtype_util from tensorflow_probability.python.internal import parameter_properties from tensorflow_probability.python.internal import tensor_util @@ -34,8 +33,7 @@ ] -@auto_composite_tensor.auto_composite_tensor( - omit_kwargs=('name',), module_name='tfp.bijectors') +@bijector.auto_composite_tensor_bijector class FrechetCDF(bijector.AutoCompositeTensorBijector): """The Frechet cumulative density function. diff --git a/tensorflow_probability/python/bijectors/generalized_pareto.py b/tensorflow_probability/python/bijectors/generalized_pareto.py index 6f4910589e..b5a75d04d3 100644 --- a/tensorflow_probability/python/bijectors/generalized_pareto.py +++ b/tensorflow_probability/python/bijectors/generalized_pareto.py @@ -24,7 +24,6 @@ from tensorflow_probability.python.bijectors import shift as shift_bijector from tensorflow_probability.python.bijectors import sigmoid as sigmoid_bijector from tensorflow_probability.python.bijectors import softplus as softplus_bijector -from tensorflow_probability.python.internal import auto_composite_tensor from tensorflow_probability.python.internal import dtype_util from tensorflow_probability.python.internal import parameter_properties from tensorflow_probability.python.internal import tensor_util @@ -35,8 +34,7 @@ ] -@auto_composite_tensor.auto_composite_tensor( - omit_kwargs=('name',), module_name='tfp.bijectors') +@bijector_lib.auto_composite_tensor_bijector class GeneralizedPareto(bijector_lib.AutoCompositeTensorBijector): """Bijector mapping R**n to non-negative reals. diff --git a/tensorflow_probability/python/bijectors/gev_cdf.py b/tensorflow_probability/python/bijectors/gev_cdf.py index 11cfbdd767..00ce4c4cc3 100644 --- a/tensorflow_probability/python/bijectors/gev_cdf.py +++ b/tensorflow_probability/python/bijectors/gev_cdf.py @@ -23,7 +23,6 @@ from tensorflow_probability.python.bijectors import bijector from tensorflow_probability.python.bijectors import softplus as softplus_bijector from tensorflow_probability.python.internal import assert_util -from tensorflow_probability.python.internal import auto_composite_tensor from tensorflow_probability.python.internal import dtype_util from tensorflow_probability.python.internal import parameter_properties from tensorflow_probability.python.internal import tensor_util @@ -33,8 +32,7 @@ ] -@auto_composite_tensor.auto_composite_tensor( - omit_kwargs=('name',), module_name='tfp.bijectors') +@bijector.auto_composite_tensor_bijector class GeneralizedExtremeValueCDF(bijector.AutoCompositeTensorBijector): """Compute the GeneralizedExtremeValue CDF. diff --git a/tensorflow_probability/python/bijectors/gompertz_cdf.py b/tensorflow_probability/python/bijectors/gompertz_cdf.py index 334de5d027..ebe493ee41 100644 --- a/tensorflow_probability/python/bijectors/gompertz_cdf.py +++ b/tensorflow_probability/python/bijectors/gompertz_cdf.py @@ -23,7 +23,6 @@ from tensorflow_probability.python.bijectors import bijector from tensorflow_probability.python.bijectors import softplus as softplus_bijector from tensorflow_probability.python.internal import assert_util -from tensorflow_probability.python.internal import auto_composite_tensor from tensorflow_probability.python.internal import dtype_util from tensorflow_probability.python.internal import parameter_properties from tensorflow_probability.python.internal import tensor_util @@ -34,8 +33,7 @@ ] -@auto_composite_tensor.auto_composite_tensor( - omit_kwargs=('name',), module_name='tfp.bijectors') +@bijector.auto_composite_tensor_bijector class GompertzCDF(bijector.AutoCompositeTensorBijector): """Compute `Y = g(X) = 1 - exp(-c * (exp(rate * X) - 1)`, the Gompertz CDF. diff --git a/tensorflow_probability/python/bijectors/gumbel_cdf.py b/tensorflow_probability/python/bijectors/gumbel_cdf.py index aab68232e8..f1d215eef5 100644 --- a/tensorflow_probability/python/bijectors/gumbel_cdf.py +++ b/tensorflow_probability/python/bijectors/gumbel_cdf.py @@ -23,7 +23,6 @@ from tensorflow_probability.python.bijectors import bijector from tensorflow_probability.python.bijectors import softplus as softplus_bijector from tensorflow_probability.python.internal import assert_util -from tensorflow_probability.python.internal import auto_composite_tensor from tensorflow_probability.python.internal import dtype_util from tensorflow_probability.python.internal import parameter_properties from tensorflow_probability.python.internal import tensor_util @@ -34,8 +33,7 @@ ] -@auto_composite_tensor.auto_composite_tensor( - omit_kwargs=('name',), module_name='tfp.bijectors') +@bijector.auto_composite_tensor_bijector class GumbelCDF(bijector.AutoCompositeTensorBijector): """Compute `Y = g(X) = exp(-exp(-(X - loc) / scale))`, the Gumbel CDF. diff --git a/tensorflow_probability/python/bijectors/hypothesis_testlib.py b/tensorflow_probability/python/bijectors/hypothesis_testlib.py index 652338b1ab..8cd9aa2451 100644 --- a/tensorflow_probability/python/bijectors/hypothesis_testlib.py +++ b/tensorflow_probability/python/bijectors/hypothesis_testlib.py @@ -122,6 +122,8 @@ def bijector_supports(): return BIJECTOR_SUPPORTS Support = tfp_hps.Support # pylint: disable=invalid-name supports = { + '_Invert': + BijectorSupport(Support.OTHER, Support.OTHER), 'Ascending': BijectorSupport(Support.VECTOR_UNCONSTRAINED, Support.VECTOR_STRICTLY_INCREASING), diff --git a/tensorflow_probability/python/bijectors/identity.py b/tensorflow_probability/python/bijectors/identity.py index a7d350ce1a..add76e7121 100644 --- a/tensorflow_probability/python/bijectors/identity.py +++ b/tensorflow_probability/python/bijectors/identity.py @@ -21,7 +21,6 @@ import tensorflow.compat.v2 as tf from tensorflow_probability.python.bijectors import bijector -from tensorflow_probability.python.internal import auto_composite_tensor __all__ = [ 'Identity', @@ -34,8 +33,7 @@ def __getitem__(self, _): return {} -@auto_composite_tensor.auto_composite_tensor( - omit_kwargs=('name',), module_name='tfp.bijectors') +@bijector.auto_composite_tensor_bijector class Identity(bijector.AutoCompositeTensorBijector): """Compute Y = g(X) = X. diff --git a/tensorflow_probability/python/bijectors/inline.py b/tensorflow_probability/python/bijectors/inline.py index b60a307ea1..9b5e8a5ece 100644 --- a/tensorflow_probability/python/bijectors/inline.py +++ b/tensorflow_probability/python/bijectors/inline.py @@ -21,7 +21,6 @@ import tensorflow.compat.v2 as tf from tensorflow_probability.python.bijectors import bijector -from tensorflow_probability.python.internal import auto_composite_tensor __all__ = [ @@ -29,8 +28,7 @@ ] -@auto_composite_tensor.auto_composite_tensor( - omit_kwargs=('name',), module_name='tfp.bijectors') +@bijector.auto_composite_tensor_bijector class Inline(bijector.AutoCompositeTensorBijector): """Bijector constructed from custom callables. diff --git a/tensorflow_probability/python/bijectors/invert.py b/tensorflow_probability/python/bijectors/invert.py index 7f764d8dd0..d747920774 100644 --- a/tensorflow_probability/python/bijectors/invert.py +++ b/tensorflow_probability/python/bijectors/invert.py @@ -21,9 +21,10 @@ import tensorflow.compat.v2 as tf from tensorflow_probability.python.bijectors import bijector as bijector_lib +# from tensorflow_probability.python.internal import auto_composite_tensor __all__ = [ - "Invert", + 'Invert', ] @@ -67,9 +68,9 @@ def __init__(self, bijector, validate_args=False, parameters=None, name=None): parameters = dict(locals()) if parameters is None else parameters if not bijector._is_injective: # pylint: disable=protected-access raise NotImplementedError( - "Invert is not implemented for non-injective bijectors.") + 'Invert is not implemented for non-injective bijectors.') - name = name or "_".join(["invert", bijector.name]) + name = name or '_'.join(['invert', bijector.name]) with tf.name_scope(name) as name: self._bijector = bijector super(Invert, self).__init__( @@ -135,3 +136,30 @@ def inverse_event_ndims(self, event_ndims, **kwargs): def forward_event_ndims(self, event_ndims, **kwargs): return self.bijector.inverse_event_ndims(event_ndims, **kwargs) + + +# Temporarily removing AutoCompositeTensor for TFP 0.13 release. +# pylint: disable=line-too-long +# @bijector_lib.auto_composite_tensor_bijector +# class Invert(_Invert, auto_composite_tensor.AutoCompositeTensor): + +# def __new__(cls, *args, **kwargs): +# """Returns an `_Invert` instance if `bijector` is not a `CompositeTensor.""" +# if cls is Invert: +# if args: +# bijector = args[0] +# elif 'bijector' in kwargs: +# bijector = kwargs['bijector'] +# else: +# raise TypeError('`Invert.__new__()` is missing argument `bijector`.') + +# if not isinstance(bijector, tf.__internal__.CompositeTensor): +# return _Invert(*args, **kwargs) +# return super(Invert, cls).__new__(cls) + + +# Invert.__doc__ = _Invert.__doc__ + '/n' + ( +# 'When an `Invert` bijector is constructed, if its `bijector` arg is not a ' +# '`CompositeTensor` instance, an `_Invert` instance is returned instead. ' +# 'Bijectors subclasses that inherit from `Invert` will also inherit from ' +# ' `CompositeTensor`.') diff --git a/tensorflow_probability/python/bijectors/invert_test.py b/tensorflow_probability/python/bijectors/invert_test.py index 05c1831983..e101d51d43 100644 --- a/tensorflow_probability/python/bijectors/invert_test.py +++ b/tensorflow_probability/python/bijectors/invert_test.py @@ -28,6 +28,22 @@ from tensorflow_probability.python.internal import test_util +class NonCompositeScale(tfb.Bijector): + """Bijector that is not a `CompositeTensor`.""" + + def __init__(self, scale): + parameters = dict(locals()) + self.scale = scale + super(NonCompositeScale, self).__init__( + validate_args=True, + forward_min_event_ndims=0., + parameters=parameters, + name='non_composite_scale') + + def _inverse(self, y): + return y / self.scale + + @test_util.test_all_tf_execution_regimes class InvertBijectorTest(test_util.TestCase): """Tests the correctness of the Y = Invert(bij) transformation.""" @@ -41,7 +57,7 @@ def testBijector(self): tfb.SoftmaxCentered(), ]: rev = tfb.Invert(fwd) - self.assertStartsWith(rev.name, "_".join(["invert", fwd.name])) + self.assertStartsWith(rev.name, '_'.join(['invert', fwd.name])) x = [[[1., 2.], [2., 3.]]] self.assertAllClose( @@ -114,6 +130,40 @@ def testNoReductionWhenEventNdimsIsOmitted(self): x, self.evaluate(bij.inverse_log_det_jacobian(x))) + def testNonCompositeTensorBijectorTfFunction(self): + scale = tf.Variable(5.) + b = NonCompositeScale(scale) + inv_b = tfb.Invert(b) + x = tf.constant([3.]) + + @tf.function + def f(bij, x): + return bij.forward(x) + + self.evaluate(scale.initializer) + self.assertAllClose(self.evaluate(f(inv_b, x)), [0.6]) + + @test_util.numpy_disable_variable_test + @test_util.jax_disable_variable_test + def testNonCompositeTensorBijectorRetainsVariable(self): + + class BijectorContainer(tf.Module): + + def __init__(self, bijector): + self.bijector = bijector + + b = NonCompositeScale(tf.Variable(3.)) + inv_b = tfb.Invert(b) + bc = BijectorContainer(inv_b) + + # If `Invert` subclasses `CompositeTensor` but its inner bijector does not, + # this test fails because `tf.Module.trainable_variables` calls + # `nest.flatten(..., expand_composites=True` on the `tf.Module`s attributes. + # `Invert._type_spec` will treat the inner bijector as a callable (see + # `AutoCompositeTensor` docs) and not decompose the inner bijector correctly + # into its `Tensor` components. + self.assertLen(bc.trainable_variables, 1) + -if __name__ == "__main__": +if __name__ == '__main__': tf.test.main() diff --git a/tensorflow_probability/python/bijectors/iterated_sigmoid_centered.py b/tensorflow_probability/python/bijectors/iterated_sigmoid_centered.py index 30ef831ced..a8ac21a49c 100644 --- a/tensorflow_probability/python/bijectors/iterated_sigmoid_centered.py +++ b/tensorflow_probability/python/bijectors/iterated_sigmoid_centered.py @@ -22,7 +22,6 @@ from tensorflow_probability.python.bijectors import bijector from tensorflow_probability.python.internal import assert_util -from tensorflow_probability.python.internal import auto_composite_tensor from tensorflow_probability.python.internal import dtype_util from tensorflow_probability.python.internal import prefer_static as ps @@ -32,8 +31,7 @@ ] -@auto_composite_tensor.auto_composite_tensor( - omit_kwargs=('name',), module_name='tfp.bijectors') +@bijector.auto_composite_tensor_bijector class IteratedSigmoidCentered(bijector.AutoCompositeTensorBijector): """Bijector which applies a Stick Breaking procedure. diff --git a/tensorflow_probability/python/bijectors/kumaraswamy_cdf.py b/tensorflow_probability/python/bijectors/kumaraswamy_cdf.py index 79a77ebd7e..f961fadd06 100644 --- a/tensorflow_probability/python/bijectors/kumaraswamy_cdf.py +++ b/tensorflow_probability/python/bijectors/kumaraswamy_cdf.py @@ -23,7 +23,6 @@ from tensorflow_probability.python.bijectors import bijector from tensorflow_probability.python.bijectors import softplus as softplus_bijector from tensorflow_probability.python.internal import assert_util -from tensorflow_probability.python.internal import auto_composite_tensor from tensorflow_probability.python.internal import distribution_util from tensorflow_probability.python.internal import dtype_util from tensorflow_probability.python.internal import parameter_properties @@ -35,8 +34,7 @@ ] -@auto_composite_tensor.auto_composite_tensor( - omit_kwargs=('name',), module_name='tfp.bijectors') +@bijector.auto_composite_tensor_bijector class KumaraswamyCDF(bijector.AutoCompositeTensorBijector): """Compute `Y = g(X) = (1 - X**a)**b, X in [0, 1]`. diff --git a/tensorflow_probability/python/bijectors/matrix_inverse_tril.py b/tensorflow_probability/python/bijectors/matrix_inverse_tril.py index 88f0b8785c..15a4f856ff 100644 --- a/tensorflow_probability/python/bijectors/matrix_inverse_tril.py +++ b/tensorflow_probability/python/bijectors/matrix_inverse_tril.py @@ -22,7 +22,6 @@ from tensorflow_probability.python.bijectors import bijector from tensorflow_probability.python.internal import assert_util -from tensorflow_probability.python.internal import auto_composite_tensor from tensorflow_probability.python.internal import dtype_util @@ -31,8 +30,7 @@ ] -@auto_composite_tensor.auto_composite_tensor( - omit_kwargs=('name',), module_name='tfp.bijectors') +@bijector.auto_composite_tensor_bijector class MatrixInverseTriL(bijector.AutoCompositeTensorBijector): """Computes `g(L) = inv(L)`, where `L` is a lower-triangular matrix. diff --git a/tensorflow_probability/python/bijectors/moyal_cdf.py b/tensorflow_probability/python/bijectors/moyal_cdf.py index 265f71ae1b..89aeb97719 100644 --- a/tensorflow_probability/python/bijectors/moyal_cdf.py +++ b/tensorflow_probability/python/bijectors/moyal_cdf.py @@ -27,7 +27,6 @@ from tensorflow_probability.python.bijectors import bijector from tensorflow_probability.python.bijectors import softplus as softplus_bijector from tensorflow_probability.python.internal import assert_util -from tensorflow_probability.python.internal import auto_composite_tensor from tensorflow_probability.python.internal import dtype_util from tensorflow_probability.python.internal import parameter_properties from tensorflow_probability.python.internal import tensor_util @@ -38,8 +37,7 @@ ] -@auto_composite_tensor.auto_composite_tensor( - omit_kwargs=('name',), module_name='tfp.bijectors') +@bijector.auto_composite_tensor_bijector class MoyalCDF(bijector.AutoCompositeTensorBijector): """Compute `Y = g(X) = erfc(exp(- 1/2 * (X - loc) / scale) / sqrt(2))`. diff --git a/tensorflow_probability/python/bijectors/normal_cdf.py b/tensorflow_probability/python/bijectors/normal_cdf.py index abcfea99d0..3ca3369a5e 100644 --- a/tensorflow_probability/python/bijectors/normal_cdf.py +++ b/tensorflow_probability/python/bijectors/normal_cdf.py @@ -25,7 +25,6 @@ from tensorflow_probability.python.bijectors import bijector from tensorflow_probability.python.internal import assert_util -from tensorflow_probability.python.internal import auto_composite_tensor from tensorflow_probability.python.internal import dtype_util from tensorflow_probability.python.internal import special_math @@ -34,8 +33,7 @@ ] -@auto_composite_tensor.auto_composite_tensor( - omit_kwargs=('name',), module_name='tfp.bijectors') +@bijector.auto_composite_tensor_bijector class NormalCDF(bijector.AutoCompositeTensorBijector): """Compute `Y = g(X) = NormalCDF(x)`. diff --git a/tensorflow_probability/python/bijectors/ordered.py b/tensorflow_probability/python/bijectors/ordered.py index 4bad67c8f3..046a0c2c53 100644 --- a/tensorflow_probability/python/bijectors/ordered.py +++ b/tensorflow_probability/python/bijectors/ordered.py @@ -22,7 +22,6 @@ from tensorflow_probability.python.bijectors import bijector from tensorflow_probability.python.internal import assert_util -from tensorflow_probability.python.internal import auto_composite_tensor from tensorflow.python.util import deprecation # pylint: disable=g-direct-tensorflow-import @@ -31,8 +30,7 @@ ] -@auto_composite_tensor.auto_composite_tensor( - omit_kwargs=('name',), module_name='tfp.bijectors') +@bijector.auto_composite_tensor_bijector class Ordered(bijector.AutoCompositeTensorBijector): """Maps a vector of increasing elements to an unconstrained vector. diff --git a/tensorflow_probability/python/bijectors/pad.py b/tensorflow_probability/python/bijectors/pad.py index 813089fa45..d8c423c1d5 100644 --- a/tensorflow_probability/python/bijectors/pad.py +++ b/tensorflow_probability/python/bijectors/pad.py @@ -23,7 +23,6 @@ from tensorflow_probability.python.bijectors import bijector from tensorflow_probability.python.internal import assert_util -from tensorflow_probability.python.internal import auto_composite_tensor from tensorflow_probability.python.internal import parameter_properties from tensorflow_probability.python.internal import prefer_static as ps from tensorflow_probability.python.internal import tensor_util @@ -35,8 +34,7 @@ ] -@auto_composite_tensor.auto_composite_tensor( - omit_kwargs=('name',), module_name='tfp.bijectors') +@bijector.auto_composite_tensor_bijector class Pad(bijector.AutoCompositeTensorBijector): """Pads a value to the `event_shape` of a `Tensor`. diff --git a/tensorflow_probability/python/bijectors/permute.py b/tensorflow_probability/python/bijectors/permute.py index accf30b591..c2412f92eb 100644 --- a/tensorflow_probability/python/bijectors/permute.py +++ b/tensorflow_probability/python/bijectors/permute.py @@ -25,7 +25,6 @@ from tensorflow_probability.python.bijectors import bijector from tensorflow_probability.python.internal import assert_util -from tensorflow_probability.python.internal import auto_composite_tensor from tensorflow_probability.python.internal import dtype_util from tensorflow_probability.python.internal import parameter_properties from tensorflow_probability.python.internal import tensor_util @@ -37,8 +36,7 @@ ] -@auto_composite_tensor.auto_composite_tensor( - omit_kwargs=('name',), module_name='tfp.bijectors') +@bijector.auto_composite_tensor_bijector class Permute(bijector.AutoCompositeTensorBijector): """Permutes the rightmost dimension of a `Tensor`. diff --git a/tensorflow_probability/python/bijectors/power.py b/tensorflow_probability/python/bijectors/power.py index e595e9b409..b9b628cb83 100644 --- a/tensorflow_probability/python/bijectors/power.py +++ b/tensorflow_probability/python/bijectors/power.py @@ -24,7 +24,6 @@ from tensorflow_probability.python.bijectors import bijector from tensorflow_probability.python.internal import assert_util -from tensorflow_probability.python.internal import auto_composite_tensor from tensorflow_probability.python.internal import parameter_properties from tensorflow_probability.python.internal import prefer_static as ps from tensorflow_probability.python.internal import tensor_util @@ -38,8 +37,7 @@ def _is_odd_integer(x): return ps.equal(x, ps.round(x)) & ps.not_equal(2. * ps.floor(x / 2.), x) -@auto_composite_tensor.auto_composite_tensor( - omit_kwargs=('name',), module_name='tfp.bijectors') +@bijector.auto_composite_tensor_bijector class Power(bijector.AutoCompositeTensorBijector): """Compute `g(X) = X ** power`; where X is a non-negative real number. diff --git a/tensorflow_probability/python/bijectors/power_transform.py b/tensorflow_probability/python/bijectors/power_transform.py index cd653d9fe3..d96d945772 100644 --- a/tensorflow_probability/python/bijectors/power_transform.py +++ b/tensorflow_probability/python/bijectors/power_transform.py @@ -22,7 +22,6 @@ from tensorflow_probability.python.bijectors import bijector from tensorflow_probability.python.internal import assert_util -from tensorflow_probability.python.internal import auto_composite_tensor from tensorflow_probability.python.internal import prefer_static as ps @@ -31,8 +30,7 @@ ] -@auto_composite_tensor.auto_composite_tensor( - omit_kwargs=('name',), module_name='tfp.bijectors') +@bijector.auto_composite_tensor_bijector class PowerTransform(bijector.AutoCompositeTensorBijector): """Compute `Y = g(X) = (1 + X * c)**(1 / c), X >= -1 / c`. diff --git a/tensorflow_probability/python/bijectors/rational_quadratic_spline.py b/tensorflow_probability/python/bijectors/rational_quadratic_spline.py index eb3454a533..eab94194c8 100644 --- a/tensorflow_probability/python/bijectors/rational_quadratic_spline.py +++ b/tensorflow_probability/python/bijectors/rational_quadratic_spline.py @@ -26,7 +26,6 @@ from tensorflow_probability.python.bijectors import bijector from tensorflow_probability.python.bijectors import softplus as softplus_bijector from tensorflow_probability.python.internal import assert_util -from tensorflow_probability.python.internal import auto_composite_tensor from tensorflow_probability.python.internal import dtype_util from tensorflow_probability.python.internal import parameter_properties from tensorflow_probability.python.internal import tensor_util @@ -59,8 +58,7 @@ def _knot_positions(bin_sizes, range_min): 'SplineShared', 'out_of_bounds,x_k,y_k,d_k,d_kp1,h_k,w_k,s_k') -@auto_composite_tensor.auto_composite_tensor( - omit_kwargs=('name',), module_name='tfp.bijectors') +@bijector.auto_composite_tensor_bijector class RationalQuadraticSpline(bijector.AutoCompositeTensorBijector): """A piecewise rational quadratic spline, as developed in [1]. diff --git a/tensorflow_probability/python/bijectors/rayleigh_cdf.py b/tensorflow_probability/python/bijectors/rayleigh_cdf.py index f36fcfade9..22ca6de0e8 100644 --- a/tensorflow_probability/python/bijectors/rayleigh_cdf.py +++ b/tensorflow_probability/python/bijectors/rayleigh_cdf.py @@ -21,7 +21,6 @@ from tensorflow_probability.python.bijectors import bijector from tensorflow_probability.python.bijectors import softplus as softplus_bijector from tensorflow_probability.python.internal import assert_util -from tensorflow_probability.python.internal import auto_composite_tensor from tensorflow_probability.python.internal import dtype_util from tensorflow_probability.python.internal import parameter_properties from tensorflow_probability.python.internal import tensor_util @@ -31,8 +30,7 @@ ] -@auto_composite_tensor.auto_composite_tensor( - omit_kwargs=('name',), module_name='tfp.bijectors') +@bijector.auto_composite_tensor_bijector class RayleighCDF(bijector.AutoCompositeTensorBijector): """Compute `Y = g(X) = 1 - exp( -(X/scale)**2 / 2 ), X >= 0`. diff --git a/tensorflow_probability/python/bijectors/reciprocal.py b/tensorflow_probability/python/bijectors/reciprocal.py index e5fca69ae2..40eda26d9d 100644 --- a/tensorflow_probability/python/bijectors/reciprocal.py +++ b/tensorflow_probability/python/bijectors/reciprocal.py @@ -22,14 +22,12 @@ from tensorflow_probability.python.bijectors import bijector from tensorflow_probability.python.internal import assert_util -from tensorflow_probability.python.internal import auto_composite_tensor from tensorflow_probability.python.internal import dtype_util __all__ = ['Reciprocal'] -@auto_composite_tensor.auto_composite_tensor( - omit_kwargs=('name',), module_name='tfp.bijectors') +@bijector.auto_composite_tensor_bijector class Reciprocal(bijector.AutoCompositeTensorBijector): """A `Bijector` that computes the reciprocal `b(x) = 1. / x` entrywise. diff --git a/tensorflow_probability/python/bijectors/reshape.py b/tensorflow_probability/python/bijectors/reshape.py index 0a56d9193f..362b8f4d9b 100644 --- a/tensorflow_probability/python/bijectors/reshape.py +++ b/tensorflow_probability/python/bijectors/reshape.py @@ -25,7 +25,6 @@ from tensorflow_probability.python.bijectors import bijector from tensorflow_probability.python.internal import assert_util -from tensorflow_probability.python.internal import auto_composite_tensor from tensorflow_probability.python.internal import dtype_util from tensorflow_probability.python.internal import parameter_properties from tensorflow_probability.python.internal import prefer_static as ps @@ -38,8 +37,7 @@ ] -@auto_composite_tensor.auto_composite_tensor( - omit_kwargs=('name',), module_name='tfp.bijectors') +@bijector.auto_composite_tensor_bijector class Reshape(bijector.AutoCompositeTensorBijector): """Reshapes the `event_shape` of a `Tensor`. diff --git a/tensorflow_probability/python/bijectors/restructure.py b/tensorflow_probability/python/bijectors/restructure.py index 98a8b3fd1f..cb6b2b13c9 100644 --- a/tensorflow_probability/python/bijectors/restructure.py +++ b/tensorflow_probability/python/bijectors/restructure.py @@ -24,7 +24,6 @@ from tensorflow_probability.python.bijectors import bijector from tensorflow_probability.python.bijectors import invert -from tensorflow_probability.python.internal import auto_composite_tensor from tensorflow_probability.python.internal import nest_util from tensorflow_probability.python.internal import parameter_properties from tensorflow.python.util import nest # pylint: disable=g-direct-tensorflow-import @@ -45,8 +44,7 @@ def unique_token_set(source_structure): return flat_token_set -@auto_composite_tensor.auto_composite_tensor( - omit_kwargs=('name',), module_name='tfp.bijectors') +@bijector.auto_composite_tensor_bijector class Restructure(bijector.AutoCompositeTensorBijector): """Converts between nested structures of Tensors. diff --git a/tensorflow_probability/python/bijectors/scale.py b/tensorflow_probability/python/bijectors/scale.py index 0402f15e64..ac6cb662de 100644 --- a/tensorflow_probability/python/bijectors/scale.py +++ b/tensorflow_probability/python/bijectors/scale.py @@ -22,7 +22,6 @@ from tensorflow_probability.python.bijectors import bijector from tensorflow_probability.python.internal import assert_util -from tensorflow_probability.python.internal import auto_composite_tensor from tensorflow_probability.python.internal import dtype_util from tensorflow_probability.python.internal import parameter_properties from tensorflow_probability.python.internal import tensor_util @@ -33,8 +32,7 @@ ] -@auto_composite_tensor.auto_composite_tensor( - omit_kwargs=('name',), module_name='tfp.bijectors') +@bijector.auto_composite_tensor_bijector class Scale(bijector.AutoCompositeTensorBijector): """Compute `Y = g(X; scale) = scale * X`. diff --git a/tensorflow_probability/python/bijectors/scale_matvec_lu.py b/tensorflow_probability/python/bijectors/scale_matvec_lu.py index 0f93a0ddfc..f9b6b19158 100644 --- a/tensorflow_probability/python/bijectors/scale_matvec_lu.py +++ b/tensorflow_probability/python/bijectors/scale_matvec_lu.py @@ -22,7 +22,6 @@ from tensorflow_probability.python.bijectors import bijector from tensorflow_probability.python.internal import assert_util -from tensorflow_probability.python.internal import auto_composite_tensor from tensorflow_probability.python.internal import parameter_properties from tensorflow_probability.python.internal import prefer_static as ps from tensorflow_probability.python.internal import tensor_util @@ -38,8 +37,7 @@ ] -@auto_composite_tensor.auto_composite_tensor( - omit_kwargs=('name',), module_name='tfp.bijectors') +@bijector.auto_composite_tensor_bijector class ScaleMatvecLU(bijector.AutoCompositeTensorBijector): """Matrix-vector multiply using LU decomposition. diff --git a/tensorflow_probability/python/bijectors/shift.py b/tensorflow_probability/python/bijectors/shift.py index 38962dd5df..c11fc1228d 100644 --- a/tensorflow_probability/python/bijectors/shift.py +++ b/tensorflow_probability/python/bijectors/shift.py @@ -21,7 +21,6 @@ import tensorflow.compat.v2 as tf from tensorflow_probability.python.bijectors import bijector -from tensorflow_probability.python.internal import auto_composite_tensor from tensorflow_probability.python.internal import dtype_util from tensorflow_probability.python.internal import parameter_properties from tensorflow_probability.python.internal import tensor_util @@ -32,8 +31,7 @@ ] -@auto_composite_tensor.auto_composite_tensor( - omit_kwargs=('name',), module_name='tfp.bijectors') +@bijector.auto_composite_tensor_bijector class Shift(bijector.AutoCompositeTensorBijector): """Compute `Y = g(X; shift) = X + shift`. diff --git a/tensorflow_probability/python/bijectors/shifted_gompertz_cdf.py b/tensorflow_probability/python/bijectors/shifted_gompertz_cdf.py index c18a25e40b..90f30f83e9 100644 --- a/tensorflow_probability/python/bijectors/shifted_gompertz_cdf.py +++ b/tensorflow_probability/python/bijectors/shifted_gompertz_cdf.py @@ -25,7 +25,6 @@ from tensorflow_probability.python.bijectors import bijector from tensorflow_probability.python.bijectors import softplus as softplus_bijector from tensorflow_probability.python.internal import assert_util -from tensorflow_probability.python.internal import auto_composite_tensor from tensorflow_probability.python.internal import dtype_util from tensorflow_probability.python.internal import parameter_properties from tensorflow_probability.python.internal import tensor_util @@ -36,8 +35,7 @@ ] -@auto_composite_tensor.auto_composite_tensor( - omit_kwargs=('name',), module_name='tfp.bijectors') +@bijector.auto_composite_tensor_bijector class ShiftedGompertzCDF(bijector.AutoCompositeTensorBijector): """Compute `Y = g(X) = (1 - exp(-rate * X)) * exp(-c * exp(-rate * X))`. diff --git a/tensorflow_probability/python/bijectors/sigmoid.py b/tensorflow_probability/python/bijectors/sigmoid.py index 48b3a9af87..839ad3b0ab 100644 --- a/tensorflow_probability/python/bijectors/sigmoid.py +++ b/tensorflow_probability/python/bijectors/sigmoid.py @@ -21,7 +21,6 @@ import tensorflow.compat.v2 as tf from tensorflow_probability.python.bijectors import bijector from tensorflow_probability.python.internal import assert_util -from tensorflow_probability.python.internal import auto_composite_tensor from tensorflow_probability.python.internal import dtype_util from tensorflow_probability.python.internal import parameter_properties from tensorflow_probability.python.internal import tensor_util @@ -65,8 +64,7 @@ def grad_fn(dy): return y, grad_fn -@auto_composite_tensor.auto_composite_tensor( - omit_kwargs=('name',), module_name='tfp.bijectors') +@bijector.auto_composite_tensor_bijector class Sigmoid(bijector.AutoCompositeTensorBijector): """Bijector that computes the logistic sigmoid function. diff --git a/tensorflow_probability/python/bijectors/sinh.py b/tensorflow_probability/python/bijectors/sinh.py index c9f1c14a65..9e5b1606ab 100644 --- a/tensorflow_probability/python/bijectors/sinh.py +++ b/tensorflow_probability/python/bijectors/sinh.py @@ -21,7 +21,6 @@ import tensorflow.compat.v2 as tf from tensorflow_probability.python import math as tfp_math from tensorflow_probability.python.bijectors import bijector -from tensorflow_probability.python.internal import auto_composite_tensor __all__ = [ @@ -29,8 +28,7 @@ ] -@auto_composite_tensor.auto_composite_tensor( - omit_kwargs=('name',), module_name='tfp.bijectors') +@bijector.auto_composite_tensor_bijector class Sinh(bijector.AutoCompositeTensorBijector): """Bijector that computes `Y = sinh(X)`. diff --git a/tensorflow_probability/python/bijectors/sinh_arcsinh.py b/tensorflow_probability/python/bijectors/sinh_arcsinh.py index 8262963477..2aeadf0718 100644 --- a/tensorflow_probability/python/bijectors/sinh_arcsinh.py +++ b/tensorflow_probability/python/bijectors/sinh_arcsinh.py @@ -24,7 +24,6 @@ from tensorflow_probability.python.bijectors import bijector from tensorflow_probability.python.bijectors import softplus as softplus_bijector from tensorflow_probability.python.internal import assert_util -from tensorflow_probability.python.internal import auto_composite_tensor from tensorflow_probability.python.internal import dtype_util from tensorflow_probability.python.internal import parameter_properties from tensorflow_probability.python.internal import tensor_util @@ -34,8 +33,7 @@ ] -@auto_composite_tensor.auto_composite_tensor( - omit_kwargs=('name',), module_name='tfp.bijectors') +@bijector.auto_composite_tensor_bijector class SinhArcsinh(bijector.AutoCompositeTensorBijector): """`Y = g(X) = Sinh( (Arcsinh(X) + skewness) * tailweight ) * multiplier`. diff --git a/tensorflow_probability/python/bijectors/soft_clip.py b/tensorflow_probability/python/bijectors/soft_clip.py index 0a4d45b6dd..563d9074ec 100644 --- a/tensorflow_probability/python/bijectors/soft_clip.py +++ b/tensorflow_probability/python/bijectors/soft_clip.py @@ -31,7 +31,6 @@ from tensorflow_probability.python.bijectors import softplus from tensorflow_probability.python.internal import assert_util -from tensorflow_probability.python.internal import auto_composite_tensor from tensorflow_probability.python.internal import dtype_util from tensorflow_probability.python.internal import parameter_properties from tensorflow_probability.python.internal import tensor_util @@ -42,8 +41,7 @@ ] -@auto_composite_tensor.auto_composite_tensor( - omit_kwargs=('name',), module_name='tfp.bijectors') +@bijector.auto_composite_tensor_bijector class SoftClip(bijector.AutoCompositeTensorBijector): """Bijector that approximates clipping as a continuous, differentiable map. diff --git a/tensorflow_probability/python/bijectors/softfloor.py b/tensorflow_probability/python/bijectors/softfloor.py index a8e213bd0e..6df7a679f1 100644 --- a/tensorflow_probability/python/bijectors/softfloor.py +++ b/tensorflow_probability/python/bijectors/softfloor.py @@ -26,7 +26,6 @@ from tensorflow_probability.python.bijectors import bijector from tensorflow_probability.python.bijectors import softplus as softplus_bijector from tensorflow_probability.python.internal import assert_util -from tensorflow_probability.python.internal import auto_composite_tensor from tensorflow_probability.python.internal import dtype_util from tensorflow_probability.python.internal import parameter_properties from tensorflow_probability.python.internal import tensor_util @@ -37,8 +36,7 @@ ] -@auto_composite_tensor.auto_composite_tensor( - omit_kwargs=('name',), module_name='tfp.bijectors') +@bijector.auto_composite_tensor_bijector class Softfloor(bijector.AutoCompositeTensorBijector): """Compute a differentiable approximation to `tf.math.floor`. diff --git a/tensorflow_probability/python/bijectors/softmax_centered.py b/tensorflow_probability/python/bijectors/softmax_centered.py index 5d4d0b63fc..b5e69c9812 100644 --- a/tensorflow_probability/python/bijectors/softmax_centered.py +++ b/tensorflow_probability/python/bijectors/softmax_centered.py @@ -25,7 +25,6 @@ from tensorflow_probability.python.bijectors import bijector from tensorflow_probability.python.bijectors import pad as pad_lib from tensorflow_probability.python.internal import assert_util -from tensorflow_probability.python.internal import auto_composite_tensor from tensorflow_probability.python.internal import dtype_util from tensorflow_probability.python.internal import prefer_static as ps @@ -35,8 +34,7 @@ ] -@auto_composite_tensor.auto_composite_tensor( - omit_kwargs=('name',), module_name='tfp.bijectors') +@bijector.auto_composite_tensor_bijector class SoftmaxCentered(bijector.AutoCompositeTensorBijector): """Bijector which computes `Y = g(X) = exp([X 0]) / sum(exp([X 0]))`. diff --git a/tensorflow_probability/python/bijectors/softplus.py b/tensorflow_probability/python/bijectors/softplus.py index 9b7de51502..92abae943c 100644 --- a/tensorflow_probability/python/bijectors/softplus.py +++ b/tensorflow_probability/python/bijectors/softplus.py @@ -21,7 +21,6 @@ import tensorflow.compat.v2 as tf from tensorflow_probability.python.bijectors import bijector from tensorflow_probability.python.internal import assert_util -from tensorflow_probability.python.internal import auto_composite_tensor from tensorflow_probability.python.internal import distribution_util from tensorflow_probability.python.internal import dtype_util from tensorflow_probability.python.internal import parameter_properties @@ -59,8 +58,7 @@ def grad_fn(dy): return y, grad_fn -@auto_composite_tensor.auto_composite_tensor( - omit_kwargs=('name',), module_name='tfp.bijectors') +@bijector.auto_composite_tensor_bijector class Softplus(bijector.AutoCompositeTensorBijector): """Bijector which computes `Y = g(X) = Log[1 + exp(X)]`. diff --git a/tensorflow_probability/python/bijectors/softsign.py b/tensorflow_probability/python/bijectors/softsign.py index d7cdb076e5..f388dca1a8 100644 --- a/tensorflow_probability/python/bijectors/softsign.py +++ b/tensorflow_probability/python/bijectors/softsign.py @@ -22,7 +22,6 @@ from tensorflow_probability.python.bijectors import bijector from tensorflow_probability.python.internal import assert_util -from tensorflow_probability.python.internal import auto_composite_tensor from tensorflow_probability.python.internal import dtype_util @@ -31,8 +30,7 @@ ] -@auto_composite_tensor.auto_composite_tensor( - omit_kwargs=('name',), module_name='tfp.bijectors') +@bijector.auto_composite_tensor_bijector class Softsign(bijector.AutoCompositeTensorBijector): """Bijector which computes `Y = g(X) = X / (1 + |X|)`. diff --git a/tensorflow_probability/python/bijectors/split.py b/tensorflow_probability/python/bijectors/split.py index 87f98f3719..1bed48136b 100644 --- a/tensorflow_probability/python/bijectors/split.py +++ b/tensorflow_probability/python/bijectors/split.py @@ -25,7 +25,6 @@ from tensorflow_probability.python.bijectors import bijector from tensorflow_probability.python.internal import assert_util -from tensorflow_probability.python.internal import auto_composite_tensor from tensorflow_probability.python.internal import parameter_properties from tensorflow_probability.python.internal import prefer_static from tensorflow_probability.python.internal import tensor_util @@ -36,8 +35,7 @@ ] -@auto_composite_tensor.auto_composite_tensor( - omit_kwargs=('name',), module_name='tfp.bijectors') +@bijector.auto_composite_tensor_bijector class Split(bijector.AutoCompositeTensorBijector): """Split a `Tensor` event along an axis into a list of `Tensor`s. diff --git a/tensorflow_probability/python/bijectors/square.py b/tensorflow_probability/python/bijectors/square.py index ed7f91f6b0..5ec5168da6 100644 --- a/tensorflow_probability/python/bijectors/square.py +++ b/tensorflow_probability/python/bijectors/square.py @@ -25,7 +25,6 @@ from tensorflow_probability.python.bijectors import bijector from tensorflow_probability.python.internal import assert_util -from tensorflow_probability.python.internal import auto_composite_tensor __all__ = [ @@ -33,8 +32,7 @@ ] -@auto_composite_tensor.auto_composite_tensor( - omit_kwargs=('name',), module_name='tfp.bijectors') +@bijector.auto_composite_tensor_bijector class Square(bijector.AutoCompositeTensorBijector): """Compute `g(X) = X^2`; X is a positive real number. diff --git a/tensorflow_probability/python/bijectors/tanh.py b/tensorflow_probability/python/bijectors/tanh.py index 57b8309916..2483348760 100644 --- a/tensorflow_probability/python/bijectors/tanh.py +++ b/tensorflow_probability/python/bijectors/tanh.py @@ -21,7 +21,6 @@ import numpy as np import tensorflow.compat.v2 as tf from tensorflow_probability.python.bijectors import bijector -from tensorflow_probability.python.internal import auto_composite_tensor __all__ = [ @@ -29,8 +28,7 @@ ] -@auto_composite_tensor.auto_composite_tensor( - omit_kwargs=('name',), module_name='tfp.bijectors') +@bijector.auto_composite_tensor_bijector class Tanh(bijector.AutoCompositeTensorBijector): """Bijector that computes `Y = tanh(X)`, therefore `Y in (-1, 1)`. diff --git a/tensorflow_probability/python/bijectors/transform_diagonal.py b/tensorflow_probability/python/bijectors/transform_diagonal.py index 7df7ad6d2c..fc527ea475 100644 --- a/tensorflow_probability/python/bijectors/transform_diagonal.py +++ b/tensorflow_probability/python/bijectors/transform_diagonal.py @@ -20,15 +20,13 @@ import tensorflow.compat.v2 as tf from tensorflow_probability.python.bijectors import bijector -from tensorflow_probability.python.internal import auto_composite_tensor __all__ = [ 'TransformDiagonal', ] -@auto_composite_tensor.auto_composite_tensor( - omit_kwargs=('name', 'parameters'), module_name='tfp.bijectors') +@bijector.auto_composite_tensor_bijector class TransformDiagonal(bijector.AutoCompositeTensorBijector): """Applies a Bijector to the diagonal of a matrix. diff --git a/tensorflow_probability/python/bijectors/transpose.py b/tensorflow_probability/python/bijectors/transpose.py index 67d11931f1..b6dfdd91a0 100644 --- a/tensorflow_probability/python/bijectors/transpose.py +++ b/tensorflow_probability/python/bijectors/transpose.py @@ -25,7 +25,6 @@ from tensorflow_probability.python.bijectors import bijector from tensorflow_probability.python.internal import assert_util -from tensorflow_probability.python.internal import auto_composite_tensor from tensorflow_probability.python.internal import distribution_util from tensorflow_probability.python.internal import dtype_util from tensorflow_probability.python.internal import parameter_properties @@ -38,8 +37,7 @@ ] -@auto_composite_tensor.auto_composite_tensor( - omit_kwargs=('name',), module_name='tfp.bijectors') +@bijector.auto_composite_tensor_bijector class Transpose(bijector.AutoCompositeTensorBijector): """Compute `Y = g(X) = transpose_rightmost_dims(X, rightmost_perm)`. diff --git a/tensorflow_probability/python/bijectors/weibull_cdf.py b/tensorflow_probability/python/bijectors/weibull_cdf.py index c20ec003aa..8734aa6aee 100644 --- a/tensorflow_probability/python/bijectors/weibull_cdf.py +++ b/tensorflow_probability/python/bijectors/weibull_cdf.py @@ -23,7 +23,6 @@ from tensorflow_probability.python.bijectors import bijector from tensorflow_probability.python.bijectors import softplus as softplus_bijector from tensorflow_probability.python.internal import assert_util -from tensorflow_probability.python.internal import auto_composite_tensor from tensorflow_probability.python.internal import dtype_util from tensorflow_probability.python.internal import parameter_properties from tensorflow_probability.python.internal import tensor_util @@ -34,8 +33,7 @@ ] -@auto_composite_tensor.auto_composite_tensor( - omit_kwargs=('name',), module_name='tfp.bijectors') +@bijector.auto_composite_tensor_bijector class WeibullCDF(bijector.AutoCompositeTensorBijector): """Compute `Y = g(X) = 1 - exp( -( X / scale) ** concentration), X >= 0`. diff --git a/tensorflow_probability/python/distributions/BUILD b/tensorflow_probability/python/distributions/BUILD index e2d8e938c1..d8e0b3f871 100644 --- a/tensorflow_probability/python/distributions/BUILD +++ b/tensorflow_probability/python/distributions/BUILD @@ -3029,6 +3029,7 @@ multi_substrate_py_test( # numpy dep, # tensorflow dep, "//tensorflow_probability", + "//tensorflow_probability/python/internal:prefer_static", "//tensorflow_probability/python/internal:test_util", # tensorflow/compiler/jit dep, ], diff --git a/tensorflow_probability/python/distributions/bernoulli.py b/tensorflow_probability/python/distributions/bernoulli.py index 6f00c9a0f9..327384ee32 100644 --- a/tensorflow_probability/python/distributions/bernoulli.py +++ b/tensorflow_probability/python/distributions/bernoulli.py @@ -168,7 +168,7 @@ def _logits_parameter_no_checks(self): if self._logits is None: probs = tf.convert_to_tensor(self._probs) return tf.math.log(probs) - tf.math.log1p(-probs) - return tf.identity(self._logits) + return tensor_util.identity_as_tensor(self._logits) def probs_parameter(self, name=None): """Probs computed from non-`None` input arg (`probs` or `logits`).""" @@ -177,7 +177,7 @@ def probs_parameter(self, name=None): def _probs_parameter_no_checks(self): if self._logits is None: - return tf.identity(self._probs) + return tensor_util.identity_as_tensor(self._probs) return tf.math.sigmoid(self._logits) def _default_event_space_bijector(self): diff --git a/tensorflow_probability/python/distributions/binomial.py b/tensorflow_probability/python/distributions/binomial.py index 6721802606..26d0c55b1c 100644 --- a/tensorflow_probability/python/distributions/binomial.py +++ b/tensorflow_probability/python/distributions/binomial.py @@ -486,7 +486,7 @@ def _logits_parameter_no_checks(self): if self._logits is None: probs = tf.convert_to_tensor(self._probs) return tf.math.log(probs) - tf.math.log1p(-probs) - return tf.identity(self._logits) + return tensor_util.identity_as_tensor(self._logits) def probs_parameter(self, name=None): """Probs computed from non-`None` input arg (`probs` or `logits`).""" @@ -495,7 +495,7 @@ def probs_parameter(self, name=None): def _probs_parameter_no_checks(self, total_count=None): if self._logits is None: - probs = tf.identity(self._probs) + probs = tensor_util.identity_as_tensor(self._probs) else: probs = tf.math.sigmoid(self._logits) # Suppress potentially nasty probs like `nan` b/c they don't matter where diff --git a/tensorflow_probability/python/distributions/categorical.py b/tensorflow_probability/python/distributions/categorical.py index 6413fe5412..0dadab9588 100644 --- a/tensorflow_probability/python/distributions/categorical.py +++ b/tensorflow_probability/python/distributions/categorical.py @@ -343,7 +343,7 @@ def logits_parameter(self, name=None): def _logits_parameter_no_checks(self): if self._logits is None: return tf.math.log(self._probs) - return tf.identity(self._logits) + return tensor_util.identity_as_tensor(self._logits) def probs_parameter(self, name=None): """Probs vec computed from non-`None` input arg (`probs` or `logits`).""" @@ -352,7 +352,7 @@ def probs_parameter(self, name=None): def _probs_parameter_no_checks(self): if self._logits is None: - return tf.identity(self._probs) + return tensor_util.identity_as_tensor(self._probs) return tf.math.softmax(self._logits) def _num_categories(self, x=None): diff --git a/tensorflow_probability/python/distributions/continuous_bernoulli.py b/tensorflow_probability/python/distributions/continuous_bernoulli.py index dea6b3db05..6359c83257 100644 --- a/tensorflow_probability/python/distributions/continuous_bernoulli.py +++ b/tensorflow_probability/python/distributions/continuous_bernoulli.py @@ -446,7 +446,7 @@ def _logits_parameter_no_checks(self): if self._logits is None: probs = tf.convert_to_tensor(self._probs) return tf.math.log(probs) - tf.math.log1p(-probs) - return tf.identity(self._logits) + return tensor_util.identity_as_tensor(self._logits) def probs_parameter(self, name=None): """probs computed from non-`None` input arg (`probs` or `logits`).""" @@ -455,7 +455,7 @@ def probs_parameter(self, name=None): def _probs_parameter_no_checks(self): if self._logits is None: - return tf.identity(self._probs) + return tensor_util.identity_as_tensor(self._probs) return tf.math.sigmoid(self._logits) def _default_event_space_bijector(self): diff --git a/tensorflow_probability/python/distributions/deterministic.py b/tensorflow_probability/python/distributions/deterministic.py index 2de7eddc67..0ded753ede 100644 --- a/tensorflow_probability/python/distributions/deterministic.py +++ b/tensorflow_probability/python/distributions/deterministic.py @@ -317,7 +317,7 @@ def _prob(self, x): return tf.cast(tf.abs(x - loc) <= self._slack(loc), dtype=prob_dtype) def _cdf(self, x): - loc = tf.identity(self.loc) + loc = tensor_util.identity_as_tensor(self.loc) return tf.cast(x >= loc - self._slack(loc), dtype=self.dtype) diff --git a/tensorflow_probability/python/distributions/distribution.py b/tensorflow_probability/python/distributions/distribution.py index a1724f679d..e88d8214b3 100644 --- a/tensorflow_probability/python/distributions/distribution.py +++ b/tensorflow_probability/python/distributions/distribution.py @@ -23,6 +23,7 @@ import contextlib import functools import inspect +import logging import types from absl import logging @@ -276,40 +277,61 @@ def __new__(mcs, classname, baseclasses, attrs): return super(_DistributionMeta, mcs).__new__( mcs, classname, baseclasses, attrs) - # Subclasses shouldn't inherit their parents' `_parameter_properties`, - # since (in general) they'll have different parameters. Exceptions (for - # convenience) are: + # Warn when a subclass inherits `_parameter_properties` from its parent + # (this is unsafe, since the subclass will in general have different + # parameters). Exceptions are: # - Subclasses that don't define their own `__init__` (handled above by # the short-circuit when `default_init is None`). # - Subclasses that define a passthrough `__init__(self, *args, **kwargs)`. - # - Direct children of `Distribution`, since the inherited method just - # raises a NotImplementedError. + # pylint: disable=protected-access init_argspec = tf_inspect.getfullargspec(default_init) if ('_parameter_properties' not in attrs - and base != Distribution # Passthrough exception: may only take `self` and at least one of # `*args` and `**kwargs`. and (len(init_argspec.args) > 1 or not (init_argspec.varargs or init_argspec.varkw))): - # TODO(b/183457779) remove warning and raise `NotImplementedError`. - attrs['_parameter_properties'] = deprecation.deprecated( - date='2021-07-01', - instructions=""" -Calling `_parameter_properties` on subclass {classname} that redefines the -parent ({basename}) `__init__` is unsafe and will raise an error in the future. -Please implement an explicit `_parameter_properties` for the subclass. If the -subclass `__init__` takes the same parameters as the parent, you may use the -placeholder implementation: - @classmethod - def _parameter_properties(cls, dtype, num_classes=None): - return {basename}._parameter_properties( - dtype=dtype, num_classes=num_classes) + @functools.wraps(base._parameter_properties) + def wrapped_properties(*args, **kwargs): # pylint: disable=missing-docstring + """Wrapper to warn if `parameter_properties` is inherited.""" + properties = base._parameter_properties(*args, **kwargs) + # Warn *after* calling the base method, so that we don't bother warning + # if it just raised NotImplementedError anyway. + logging.warning(""" +Distribution subclass %s inherits `_parameter_properties from its parent (%s) +while also redefining `__init__`. The inherited annotations cover the following +parameters: %s. It is likely that these do not match the subclass parameters. +This may lead to errors when computing batch shapes, slicing into batch +dimensions, calling `.copy()`, flattening the distribution as a CompositeTensor +(e.g., when it is passed or returned from a `tf.function`), and possibly other +cases. The recommended pattern for distribution subclasses is to define a new +`_parameter_properties` method with the subclass parameters, and to store the +corresponding parameter values as `self._parameters` in `__init__`, after +calling the superclass constructor: + +``` +class MySubclass(tfd.SomeDistribution): + + def __init__(self, param_a, param_b): + parameters = dict(locals()) + # ... do subclass initialization ... + super(MySubclass, self).__init__(**base_class_params) + # Ensure that the subclass (not base class) parameters are stored. + self._parameters = parameters + + def _parameter_properties(self, dtype, num_classes=None): + return dict( + # Annotations may optionally specify properties, such as `event_ndims`, + # `default_constraining_bijector_fn`, `specifies_shape`, etc.; see + # the `ParameterProperties` documentation for details. + param_a=tfp.util.ParameterProperties(), + param_b=tfp.util.ParameterProperties()) +``` +""", classname, base.__name__, str(properties.keys())) + return properties -""".format(classname=classname, - basename=base.__name__))(base._parameter_properties) + attrs['_parameter_properties'] = wrapped_properties - # pylint: disable=protected-access # For a comparison of different methods for wrapping functions, see: # https://hynek.me/articles/decorators/ @decorator.decorator @@ -657,10 +679,7 @@ def _composite_tensor_shape_params(self): @classmethod def _parameter_properties(cls, dtype, num_classes=None): raise NotImplementedError( - '_parameter_properties` is not implemented: {}. ' - 'Note that subclasses that redefine `__init__` are not assumed to ' - 'share parameters with their parent class and must provide a separate ' - 'implementation.'.format(cls.__name__)) + '_parameter_properties` is not implemented: {}.'.format(cls.__name__)) @classmethod def parameter_properties(cls, dtype=tf.float32, num_classes=None): diff --git a/tensorflow_probability/python/distributions/distribution_test.py b/tensorflow_probability/python/distributions/distribution_test.py index bc1b561d29..52a7d6b71a 100644 --- a/tensorflow_probability/python/distributions/distribution_test.py +++ b/tensorflow_probability/python/distributions/distribution_test.py @@ -19,6 +19,7 @@ import collections # Dependency imports +from absl import logging from absl.testing import parameterized import numpy as np @@ -29,8 +30,6 @@ from tensorflow_probability.python.internal import test_util from tensorflow.python.framework import test_util as tf_test_util # pylint: disable=g-direct-tensorflow-import,g-import-not-at-top -from tensorflow.python.platform import test as tf_test # pylint: disable=g-direct-tensorflow-import,g-import-not-at-top -from tensorflow.python.platform import tf_logging # pylint: disable=g-direct-tensorflow-import,g-import-not-at-top class TupleDistribution(tfd.Distribution): @@ -617,10 +616,7 @@ def normal_differential_entropy(scale): self.evaluate(normal_differential_entropy(scale)), err=1e-5) - @test_util.jax_disable_test_missing_functionality('tf_logging') - @tf_test.mock.patch.object(tf_logging, 'warning', autospec=True) - def testParameterPropertiesNotInherited(self, mock_warning): - # TODO(b/183457779) Test for NotImplementedError (rather than just warning). + def testParameterPropertiesNotInherited(self): # Subclasses that don't redefine __init__ can inherit properties. class NormalTrivialSubclass(tfd.Normal): @@ -640,11 +636,17 @@ class MyDistribution(tfd.Distribution): def __init__(self, param1, param2): pass - NormalTrivialSubclass.parameter_properties() - NormalWithPassThroughInit.parameter_properties() - with self.assertRaises(NotImplementedError): - MyDistribution.parameter_properties() - self.assertEqual(0, mock_warning.call_count) + with self.assertLogs(level=logging.WARNING) as log: + NormalTrivialSubclass.parameter_properties() + NormalWithPassThroughInit.parameter_properties() + with self.assertRaises(NotImplementedError): + MyDistribution.parameter_properties() + with self.assertRaises(NotImplementedError): + # Ensure that the unimplemented JD propertoes don't raise a warning. + tfd.JointDistributionCoroutine.parameter_properties() + logging.warning('assertLogs context requires at least one warning.') + # Assert that no warnings occurred other than the dummy warning. + self.assertLen(log.records, 1) class NormalWithExtraParam(tfd.Normal): @@ -652,8 +654,9 @@ def __init__(self, extra_param, *args, **kwargs): self._extra_param = extra_param super(NormalWithExtraParam, self).__init__(*args, **kwargs) - NormalWithExtraParam.parameter_properties() - self.assertEqual(1, mock_warning.call_count) + with self.assertLogs(level=logging.WARNING) as log: + NormalWithExtraParam.parameter_properties() + self.assertLen(log.records, 1) @test_util.test_all_tf_execution_regimes diff --git a/tensorflow_probability/python/distributions/geometric.py b/tensorflow_probability/python/distributions/geometric.py index 1eab6fdc85..4c2eced3b4 100644 --- a/tensorflow_probability/python/distributions/geometric.py +++ b/tensorflow_probability/python/distributions/geometric.py @@ -261,11 +261,11 @@ def _logits_parameter_no_checks(self): if self._logits is None: probs = tf.convert_to_tensor(self._probs) return tf.math.log(probs) - tf.math.log1p(-probs) - return tf.identity(self._logits) + return tensor_util.identity_as_tensor(self._logits) def _probs_parameter_no_checks(self): if self._logits is None: - return tf.identity(self._probs) + return tensor_util.identity_as_tensor(self._probs) return tf.math.sigmoid(self._logits) def _logits_and_probs_no_checks(self): diff --git a/tensorflow_probability/python/distributions/independent.py b/tensorflow_probability/python/distributions/independent.py index 214720803e..5cdc1f0633 100644 --- a/tensorflow_probability/python/distributions/independent.py +++ b/tensorflow_probability/python/distributions/independent.py @@ -99,12 +99,12 @@ class Independent(distribution_lib.Distribution): """ - @deprecation.deprecated_arg_values( - '2022-03-01', - 'Please pass an integer value for `reinterpreted_batch_ndims`. The ' - 'current behavior corresponds to `reinterpreted_batch_ndims=tf.size(' - 'distribution.batch_shape_tensor()) - 1`.', - reinterpreted_batch_ndims=None) + # @deprecation.deprecated_arg_values( + # '2022-03-01', + # 'Please pass an integer value for `reinterpreted_batch_ndims`. The ' + # 'current behavior corresponds to `reinterpreted_batch_ndims=tf.size(' + # 'distribution.batch_shape_tensor()) - 1`.', + # reinterpreted_batch_ndims=None) def __init__(self, distribution, reinterpreted_batch_ndims=None, diff --git a/tensorflow_probability/python/distributions/joint_distribution_auto_batched_test.py b/tensorflow_probability/python/distributions/joint_distribution_auto_batched_test.py index 2146ac8135..9e51fa16c9 100644 --- a/tensorflow_probability/python/distributions/joint_distribution_auto_batched_test.py +++ b/tensorflow_probability/python/distributions/joint_distribution_auto_batched_test.py @@ -30,6 +30,7 @@ import tensorflow.compat.v2 as tf import tensorflow_probability as tfp +from tensorflow_probability.python.internal import prefer_static as ps from tensorflow_probability.python.internal import test_util tfb = tfp.bijectors @@ -493,7 +494,9 @@ def test_unit_sample_shape(self): @tfd.JointDistributionCoroutineAutoBatched def dist(): x = yield tfd.Normal(loc=tf.zeros([3]), scale=1., name='x') - yield tfd.Bernoulli(logits=tf.einsum('n->', x), name='y') + if ps.rank(x) != 1: + raise ValueError('Unexpected shape.') + yield tfd.Bernoulli(logits=tf.reduce_sum(x), name='y') for sample_shape in [(), 1, [1], [1, 1], [2]]: self.assertAllEqual( diff --git a/tensorflow_probability/python/distributions/linear_gaussian_ssm.py b/tensorflow_probability/python/distributions/linear_gaussian_ssm.py index 2010b409f1..9189cff83d 100644 --- a/tensorflow_probability/python/distributions/linear_gaussian_ssm.py +++ b/tensorflow_probability/python/distributions/linear_gaussian_ssm.py @@ -661,7 +661,7 @@ def _get_time_varying_kwargs(self, idx): transition_noise = self.get_transition_noise_for_timestep(t) observation_noise = self.get_observation_noise_for_timestep(t) return tf.nest.map_structure( - tf.identity, + tensor_util.identity_as_tensor, {'transition_matrix': ( self.get_transition_matrix_for_timestep(t).to_dense()), 'observation_matrix': ( diff --git a/tensorflow_probability/python/distributions/multinomial.py b/tensorflow_probability/python/distributions/multinomial.py index 9ef75f7946..00e8e7bfd3 100644 --- a/tensorflow_probability/python/distributions/multinomial.py +++ b/tensorflow_probability/python/distributions/multinomial.py @@ -290,7 +290,7 @@ def logits_parameter(self, name=None): def _logits_parameter_no_checks(self): if self._logits is None: return tf.math.log(self._probs) - return tf.identity(self._logits) + return tensor_util.identity_as_tensor(self._logits) def probs_parameter(self, name=None): """Probs vec computed from non-`None` input arg (`probs` or `logits`).""" @@ -299,7 +299,7 @@ def probs_parameter(self, name=None): def _probs_parameter_no_checks(self): if self._logits is None: - return tf.identity(self._probs) + return tensor_util.identity_as_tensor(self._probs) return tf.math.softmax(self._logits) def _default_event_space_bijector(self): diff --git a/tensorflow_probability/python/distributions/negative_binomial.py b/tensorflow_probability/python/distributions/negative_binomial.py index 567475b61d..a8bbd19039 100644 --- a/tensorflow_probability/python/distributions/negative_binomial.py +++ b/tensorflow_probability/python/distributions/negative_binomial.py @@ -236,7 +236,7 @@ def _logits_parameter_no_checks(self, name=None): if self._logits is None: probs = tf.convert_to_tensor(self._probs) return tf.math.log(probs) - tf.math.log1p(-probs) - return tf.identity(self._logits) + return tensor_util.identity_as_tensor(self._logits) def logits_parameter(self, name=None): """Logits computed from non-`None` input arg (`probs` or `logits`).""" @@ -246,7 +246,7 @@ def logits_parameter(self, name=None): def _probs_parameter_no_checks(self, name=None): """Probs computed from non-`None` input arg (`probs` or `logits`).""" if self._logits is None: - return tf.identity(self._probs) + return tensor_util.identity_as_tensor(self._probs) return tf.math.sigmoid(self._logits) def probs_parameter(self, name=None): diff --git a/tensorflow_probability/python/distributions/onehot_categorical.py b/tensorflow_probability/python/distributions/onehot_categorical.py index 2e0c168b0f..a554d4126c 100644 --- a/tensorflow_probability/python/distributions/onehot_categorical.py +++ b/tensorflow_probability/python/distributions/onehot_categorical.py @@ -265,7 +265,7 @@ def logits_parameter(self, name=None): def _logits_parameter_no_checks(self): if self._logits is None: return tf.math.log(self._probs) - return tf.identity(self._logits) + return tensor_util.identity_as_tensor(self._logits) def probs_parameter(self, name=None): """Probs vec computed from non-`None` input arg (`probs` or `logits`).""" @@ -274,7 +274,7 @@ def probs_parameter(self, name=None): def _probs_parameter_no_checks(self): if self._logits is None: - return tf.identity(self._probs) + return tensor_util.identity_as_tensor(self._probs) return tf.math.softmax(self._logits) def _default_event_space_bijector(self): diff --git a/tensorflow_probability/python/distributions/poisson.py b/tensorflow_probability/python/distributions/poisson.py index 3a7424ef60..0f9231aca2 100644 --- a/tensorflow_probability/python/distributions/poisson.py +++ b/tensorflow_probability/python/distributions/poisson.py @@ -368,7 +368,7 @@ def rate_parameter(self, name=None): def _rate_parameter_no_checks(self): if self._rate is None: return tf.exp(self._log_rate) - return tf.identity(self._rate) + return tensor_util.identity_as_tensor(self._rate) def log_rate_parameter(self, name=None): """Log-rate vec computed from non-`None` input arg (`rate`, `log_rate`).""" @@ -378,7 +378,7 @@ def log_rate_parameter(self, name=None): def _log_rate_parameter_no_checks(self): if self._log_rate is None: return tf.math.log(self._rate) - return tf.identity(self._log_rate) + return tensor_util.identity_as_tensor(self._log_rate) def _default_event_space_bijector(self): return diff --git a/tensorflow_probability/python/distributions/probit_bernoulli.py b/tensorflow_probability/python/distributions/probit_bernoulli.py index 9ef56d98aa..51a363c2d0 100644 --- a/tensorflow_probability/python/distributions/probit_bernoulli.py +++ b/tensorflow_probability/python/distributions/probit_bernoulli.py @@ -170,7 +170,7 @@ def _probits_parameter_no_checks(self): if self._probits is None: probs = tf.convert_to_tensor(self._probs) return tf.math.ndtri(probs) - return tf.identity(self._probits) + return tensor_util.identity_as_tensor(self._probits) def probs_parameter(self, name=None): """Probs computed from non-`None` input arg (`probs` or `probits`).""" @@ -179,7 +179,7 @@ def probs_parameter(self, name=None): def _probs_parameter_no_checks(self): if self._probits is None: - return tf.identity(self._probs) + return tensor_util.identity_as_tensor(self._probs) return special_math.ndtr(self._probits) def _default_event_space_bijector(self): diff --git a/tensorflow_probability/python/distributions/relaxed_bernoulli.py b/tensorflow_probability/python/distributions/relaxed_bernoulli.py index f18629e4b4..43be808577 100644 --- a/tensorflow_probability/python/distributions/relaxed_bernoulli.py +++ b/tensorflow_probability/python/distributions/relaxed_bernoulli.py @@ -234,7 +234,7 @@ def _logits_parameter_no_checks(self): if self._logits is None: probs = tf.convert_to_tensor(self._probs) return tf.math.log(probs) - tf.math.log1p(-probs) - return tf.identity(self._logits) + return tensor_util.identity_as_tensor(self._logits) def probs_parameter(self, name=None): """Probs computed from non-`None` input arg (`probs` or `logits`).""" @@ -243,7 +243,7 @@ def probs_parameter(self, name=None): def _probs_parameter_no_checks(self): if self._logits is None: - return tf.identity(self._probs) + return tensor_util.identity_as_tensor(self._probs) return tf.math.sigmoid(self._logits) def _event_shape_tensor(self): diff --git a/tensorflow_probability/python/distributions/relaxed_onehot_categorical.py b/tensorflow_probability/python/distributions/relaxed_onehot_categorical.py index 546ecb4596..b05dd0a20e 100644 --- a/tensorflow_probability/python/distributions/relaxed_onehot_categorical.py +++ b/tensorflow_probability/python/distributions/relaxed_onehot_categorical.py @@ -297,7 +297,7 @@ def logits_parameter(self, name=None): def _logits_parameter_no_checks(self): if self._logits is None: return tf.math.log(self._probs) - return tf.identity(self._logits) + return tensor_util.identity_as_tensor(self._logits) def probs_parameter(self, name=None): """Probs vec computed from non-`None` input arg (`probs` or `logits`).""" @@ -306,7 +306,7 @@ def probs_parameter(self, name=None): def _probs_parameter_no_checks(self): if self._logits is None: - return tf.identity(self._probs) + return tensor_util.identity_as_tensor(self._probs) return tf.math.softmax(self._logits) def _sample_control_dependencies(self, x): diff --git a/tensorflow_probability/python/distributions/skellam.py b/tensorflow_probability/python/distributions/skellam.py index 60859a032d..25d9ab00a5 100644 --- a/tensorflow_probability/python/distributions/skellam.py +++ b/tensorflow_probability/python/distributions/skellam.py @@ -236,7 +236,7 @@ def rate1_parameter(self, name=None): def _rate1_parameter_no_checks(self): if self._rate1 is None: return tf.exp(self._log_rate1) - return tf.identity(self._rate1) + return tensor_util.identity_as_tensor(self._rate1) def log_rate1_parameter(self, name=None): """Log-rate computed from non-`None` input arg (`rate1`, `log_rate1`).""" @@ -246,7 +246,7 @@ def log_rate1_parameter(self, name=None): def _log_rate1_parameter_no_checks(self): if self._log_rate1 is None: return tf.math.log(self._rate1) - return tf.identity(self._log_rate1) + return tensor_util.identity_as_tensor(self._log_rate1) def rate2_parameter(self, name=None): """Rate computed from non-`None` input arg (`rate2` or `log_rate2`).""" @@ -256,7 +256,7 @@ def rate2_parameter(self, name=None): def _rate2_parameter_no_checks(self): if self._rate2 is None: return tf.exp(self._log_rate2) - return tf.identity(self._rate2) + return tensor_util.identity_as_tensor(self._rate2) def log_rate2_parameter(self, name=None): """Log-rate computed from non-`None` input arg (`rate2`, `log_rate2`).""" @@ -266,7 +266,7 @@ def log_rate2_parameter(self, name=None): def _log_rate2_parameter_no_checks(self): if self._log_rate2 is None: return tf.math.log(self._rate2) - return tf.identity(self._log_rate2) + return tensor_util.identity_as_tensor(self._log_rate2) def _all_rate_parameters(self): rate1 = None diff --git a/tensorflow_probability/python/experimental/composite_tensor_test.py b/tensorflow_probability/python/experimental/composite_tensor_test.py index 6543d67f69..1e683d47b6 100644 --- a/tensorflow_probability/python/experimental/composite_tensor_test.py +++ b/tensorflow_probability/python/experimental/composite_tensor_test.py @@ -454,7 +454,11 @@ def test_basics_mixture_same_family(self): self.evaluate(unflat.log_prob(.5)) def test_already_composite_tensor(self): - b = tfb.Scale(2.) + AutoScale = tfp.experimental.auto_composite_tensor( # pylint: disable=invalid-name + tfb.Scale, omit_kwargs=('parameters',), + non_identifying_kwargs=('name',), + module_name=('tfp.bijectors')) + b = AutoScale(2.) b2 = tfp.experimental.as_composite(b) self.assertIsInstance(b, tf.__internal__.CompositeTensor) self.assertIs(b, b2) diff --git a/tensorflow_probability/python/internal/auto_composite_tensor.py b/tensorflow_probability/python/internal/auto_composite_tensor.py index 83ff4cdcf5..7dfdfceaad 100644 --- a/tensorflow_probability/python/internal/auto_composite_tensor.py +++ b/tensorflow_probability/python/internal/auto_composite_tensor.py @@ -60,7 +60,7 @@ def _deferred_assertion_context(is_deferred=True): _SENTINEL = object() -_AUTO_COMPOSITE_TENSOR_VERSION = 2 +_AUTO_COMPOSITE_TENSOR_VERSION = 3 # Cache maps __init__ method to signature _sig_cache = {} @@ -171,7 +171,8 @@ class _AutoCompositeTensorTypeSpec(tf.TypeSpec): '_comparable') def __init__(self, param_specs, non_tensor_params, omit_kwargs, - prefer_static_value, callable_params=None): + prefer_static_value, non_identifying_kwargs, + callable_params=None): """Initializes a new `_AutoCompositeTensorTypeSpec`. Args: @@ -189,6 +190,11 @@ def __init__(self, param_specs, non_tensor_params, omit_kwargs, of `Tensor`-like kwargs to the `AutoCompositeTensor`s constructor that may be stored as static values, if known. These are typically shapes or axis values. + non_identifying_kwargs: Python `tuple` of strings corresponding to the + names of kwargs to the `AutoCompositeTensor`s constructor whose values + are not relevant to the unique identification of the + `_AutoCompositeTensorTypeSpec` instance. Equality/comparison checks and + `__hash__` do not depend on these kwargs. callable_params: Python `dict` of callable kwargs to the `AutoCompositeTensor`'s constructor that do not subclass `CompositeTensor`, or `None`. If `callable_params` is a non-empty @@ -199,6 +205,7 @@ def __init__(self, param_specs, non_tensor_params, omit_kwargs, self._non_tensor_params = non_tensor_params self._omit_kwargs = omit_kwargs self._prefer_static_value = prefer_static_value + self._non_identifying_kwargs = non_identifying_kwargs self._callable_params = {} if callable_params is None else callable_params self._serializable = ( @@ -206,15 +213,24 @@ def __init__(self, param_specs, non_tensor_params, omit_kwargs, self._param_specs, self._non_tensor_params, self._omit_kwargs, - self._prefer_static_value) + self._prefer_static_value, + self._non_identifying_kwargs) - # TODO(b/182603117): Distinguish between `omit_kwargs_from_constructor` - # and `omit_kwargs_for_comparison`. - self._comparable = self._serializable + ( - tf.nest.map_structure(id, self._callable_params),) + def remove_kwargs(d): + return {k: v for k, v in d.items() + if k not in self._non_identifying_kwargs} + + self._comparable = ( + _AUTO_COMPOSITE_TENSOR_VERSION, + remove_kwargs(self._param_specs), + remove_kwargs(self._non_tensor_params), + self._omit_kwargs, + self._prefer_static_value, + self._non_identifying_kwargs, + tf.nest.map_structure(id, remove_kwargs(self._callable_params))) @classmethod - def from_instance(cls, instance, omit_kwargs=()): + def from_instance(cls, instance, omit_kwargs=(), non_identifying_kwargs=()): cls_value_type = cls.value_type.fget(None) if type(instance) is not cls_value_type: # pylint: disable=unidiomatic-typecheck raise ValueError(f'`{type(instance).__name__}` has inherited the ' @@ -245,6 +261,7 @@ def from_instance(cls, instance, omit_kwargs=()): non_tensor_params=non_tensor_params, omit_kwargs=omit_kwargs, prefer_static_value=prefer_static_value, + non_identifying_kwargs=non_identifying_kwargs, callable_params=callable_params) def _to_components(self, obj): @@ -273,6 +290,9 @@ def _deserialize(cls, encoded): if version == 1: encoded = encoded + ((),) version = 2 + if version == 2: + encoded = encoded + ((),) + version = 3 if version != _AUTO_COMPOSITE_TENSOR_VERSION: raise ValueError(f'Expected version {_AUTO_COMPOSITE_TENSOR_VERSION},' f' but got {version}.') @@ -375,7 +395,8 @@ def _type_spec(self): pass -def auto_composite_tensor(cls=None, omit_kwargs=(), module_name=None): +def auto_composite_tensor( + cls=None, omit_kwargs=(), non_identifying_kwargs=(), module_name=None): """Automagically generate `CompositeTensor` behavior for `cls`. `CompositeTensor` objects are able to pass in and out of `tf.function` and @@ -499,6 +520,8 @@ def body(obj): Args: cls: The class for which to create a CompositeTensor subclass. omit_kwargs: Optional sequence of kwarg names to be omitted from the spec. + non_identifying_kwargs: Optional sequence of kwarg names to be omitted from + equality/comparison checks and the `__hash__` method of the spec. module_name: The module name with which to register the `TypeSpec`. If `None`, defaults to `cls.__module__`. @@ -508,6 +531,7 @@ def body(obj): if cls is None: return functools.partial(auto_composite_tensor, omit_kwargs=omit_kwargs, + non_identifying_kwargs=non_identifying_kwargs, module_name=module_name) if module_name is None: @@ -537,11 +561,15 @@ def value_type(self): _AlreadyCTTypeSpec.__name__ = type_spec_class_name - cls._type_spec = property( # pylint: disable=protected-access - lambda self: _AlreadyCTTypeSpec.from_instance(self, omit_kwargs)) + def _type_spec(obj): + return _AlreadyCTTypeSpec.from_instance( + obj, omit_kwargs, non_identifying_kwargs) + + cls._type_spec = property(_type_spec) # pylint: disable=protected-access return cls - clsid = (cls.__module__, cls.__name__, omit_kwargs) + clsid = (cls.__module__, cls.__name__, omit_kwargs, + non_identifying_kwargs) # Check for subclass if retrieving from the _registry, in case the user # has redefined the class (e.g. in a REPL/notebook). @@ -562,7 +590,8 @@ class _AutoCompositeTensor(cls, composite_tensor.CompositeTensor): @property def _type_spec(self): - return _GeneratedCTTypeSpec.from_instance(self, omit_kwargs) + return _GeneratedCTTypeSpec.from_instance( + self, omit_kwargs, non_identifying_kwargs) _AutoCompositeTensor.__name__ = cls.__name__ _registry[clsid] = _AutoCompositeTensor diff --git a/tensorflow_probability/python/internal/auto_composite_tensor_test.py b/tensorflow_probability/python/internal/auto_composite_tensor_test.py index 2fb3a13100..5d3015745d 100644 --- a/tensorflow_probability/python/internal/auto_composite_tensor_test.py +++ b/tensorflow_probability/python/internal/auto_composite_tensor_test.py @@ -22,7 +22,7 @@ import os from absl import flags -from absl.testing import absltest +# from absl.testing import absltest from absl.testing import parameterized import tensorflow.compat.v2 as tf @@ -49,20 +49,24 @@ AutoIdentity = tfp.experimental.auto_composite_tensor( - tf.linalg.LinearOperatorIdentity, omit_kwargs=('name',)) + tf.linalg.LinearOperatorIdentity, non_identifying_kwargs=('name',)) AutoDiag = tfp.experimental.auto_composite_tensor( - tf.linalg.LinearOperatorDiag, omit_kwargs=('name',)) + tf.linalg.LinearOperatorDiag, non_identifying_kwargs=('name',)) AutoBlockDiag = tfp.experimental.auto_composite_tensor( - tf.linalg.LinearOperatorBlockDiag, omit_kwargs=('name',)) + tf.linalg.LinearOperatorBlockDiag, non_identifying_kwargs=('name',)) AutoTriL = tfp.experimental.auto_composite_tensor( - tf.linalg.LinearOperatorLowerTriangular, omit_kwargs=('name',)) + tf.linalg.LinearOperatorLowerTriangular, non_identifying_kwargs=('name',)) AutoNormal = tfp.experimental.auto_composite_tensor( - tfd.Normal, omit_kwargs=('name',)) + tfd.Normal, non_identifying_kwargs=('name',)) AutoIndependent = tfp.experimental.auto_composite_tensor( - tfd.Independent, omit_kwargs=('name',)) + tfd.Independent, non_identifying_kwargs=('name',)) AutoReshape = tfp.experimental.auto_composite_tensor( - tfb.Reshape, omit_kwargs=('name',)) + tfb.Reshape, non_identifying_kwargs=('name',)) +AutoScale = tfp.experimental.auto_composite_tensor( + tfb.Scale, non_identifying_kwargs=('name',)) +AutoTransformDiagonal = tfp.experimental.auto_composite_tensor( + tfb.TransformDiagonal, non_identifying_kwargs=('name',)) class Model(tf.Module): @@ -71,9 +75,9 @@ def __init__(self): self.scale = tf.Variable([0., 1.], shape=[None]) @tf.function(input_signature=( - tfb.Scale([1., 2.], validate_args=True)._type_spec,)) + AutoScale([1., 2.], validate_args=True)._type_spec,)) def make_bij(self, b): - return tfb.Scale( + return AutoScale( tf.convert_to_tensor(self.scale) + b.scale, validate_args=True) @@ -105,7 +109,7 @@ def tearDownModule(): class AutoCompositeTensorTest(test_util.TestCase): def test_example(self): - @tfp.experimental.auto_composite_tensor(omit_kwargs=('name',)) + @tfp.experimental.auto_composite_tensor(non_identifying_kwargs=('name',)) class Adder(object): def __init__(self, x, y, name=None): @@ -185,7 +189,7 @@ def test_preconditioner(self): tfed = tfp.experimental.distributions auto_ct_mvn_prec_linop = tfp.experimental.auto_composite_tensor( tfed.MultivariateNormalPrecisionFactorLinearOperator, - omit_kwargs=('name',)) + non_identifying_kwargs=('name',)) tril = AutoTriL(**cov_linop.cholesky().parameters) momentum_distribution = auto_ct_mvn_prec_linop(precision_factor=tril) def body(d): @@ -314,7 +318,7 @@ def test_export_import(self): tf.saved_model.save(m1, os.path.join(path, 'saved_model1')) m2 = tf.saved_model.load(os.path.join(path, 'saved_model1')) self.evaluate(m2.scale.initializer) - b = tfb.Scale([5., 9.], validate_args=True) + b = AutoScale([5., 9.], validate_args=True) self.evaluate(m2.make_bij(b).forward(2.)) self.evaluate(m2.scale.assign(m2.scale + [1., 2.])) self.evaluate(m2.make_bij(b).forward(2.)) @@ -326,19 +330,20 @@ def test_export_import(self): with self.assertRaisesOpError('compatible shape'): self.evaluate(m3.make_bij(b).forward([3.])) - def test_saved_model_from_disk(self): + # Test disabled for 0.13 release. +# def test_saved_model_from_disk(self): - test_srcdir = absltest.get_default_test_srcdir() - relative_testdata_path = os.path.join( - TFP_PYTHON_DIR, 'internal/testdata/auto_composite_tensor') - absolute_testdata_path = os.path.join(test_srcdir, relative_testdata_path) +# test_srcdir = absltest.get_default_test_srcdir() +# relative_testdata_path = os.path.join( +# TFP_PYTHON_DIR, 'internal/testdata/auto_composite_tensor') +# absolute_testdata_path = os.path.join(test_srcdir, relative_testdata_path) - m = tf.saved_model.load(absolute_testdata_path) - self.evaluate(m.scale.initializer) - b = tfb.Scale([5., 9.], validate_args=True) - self.assertAllClose(self.evaluate(m.make_bij(b).forward(2.)), [10., 20.]) - self.evaluate(m.scale.assign(m.scale + [1., 2.])) - self.assertAllClose(self.evaluate(m.make_bij(b).forward(2.)), [12., 24.]) +# m = tf.saved_model.load(absolute_testdata_path) +# self.evaluate(m.scale.initializer) +# b = tfb.Scale([5., 9.], validate_args=True) +# self.assertAllClose(self.evaluate(m.make_bij(b).forward(2.)), [10., 20.]) +# self.evaluate(m.scale.assign(m.scale + [1., 2.])) +# self.assertAllClose(self.evaluate(m.make_bij(b).forward(2.)), [12., 24.]) def test_callable_arg(self): @@ -384,8 +389,8 @@ def __call__(self, *args, **kwargs): def test_composite_tensor_callable_arg(self): # Parameters that are both `CompositeTensor` and callable should be # handled by the `_type_spec` as `CompositeTensor`. - inner_bij = tfb.Scale([[1., 3.]], validate_args=True) - bij = tfb.TransformDiagonal(inner_bij, validate_args=True) + inner_bij = AutoScale([[1., 3.]], validate_args=True) + bij = AutoTransformDiagonal(inner_bij, validate_args=True) self.assertLen(tf.nest.flatten(bij), 1) self.assertLen(bij._type_spec._callable_params, 0) # pylint: disable=protected-access self.assertIn('diag_bijector', bij._type_spec._param_specs) # pylint: disable=protected-access @@ -408,9 +413,28 @@ def __init__(self): d_ct = AutoStandardNormal() self.assertLen(tf.nest.flatten(d_ct, expand_composites=True), 0) + def test_names_preserved_through_flatten(self): + + dist = AutoNormal(0., scale=3., name='ScaleThreeNormal') + flat = tf.nest.flatten(dist, expand_composites=True) + unflat = tf.nest.pack_sequence_as(dist, flat, expand_composites=True) + unflat_name = ('ScaleThreeNormal' if tf.executing_eagerly() + else 'ScaleThreeNormal_1') + self.assertEqual(unflat.name, unflat_name) + class _TestTypeSpec(auto_composite_tensor._AutoCompositeTensorTypeSpec): + def __init__(self, param_specs, non_tensor_params=None, omit_kwargs=(), + prefer_static_value=(), non_identifying_kwargs=(), + callable_params=None): + non_tensor_params = {} if non_tensor_params is None else non_tensor_params + super(_TestTypeSpec, self).__init__( + param_specs, non_tensor_params=non_tensor_params, + omit_kwargs=omit_kwargs, prefer_static_value=prefer_static_value, + non_identifying_kwargs=non_identifying_kwargs, + callable_params=callable_params) + @property def value_type(self): """Unused `value_type` to allow the `TypeSpec` to be instantiated.""" @@ -435,18 +459,25 @@ class AutoCompositeTensorTypeSpecTest(test_util.TestCase): ('WithCallable', _TestTypeSpec( param_specs={'a': tf.TensorSpec([3, None], tf.float32), - 'b': tfb.Scale(3.)._type_spec}, - non_tensor_params={}, + 'b': AutoScale(3.)._type_spec}, omit_kwargs=('name', 'foo'), prefer_static_value=('a',), callable_params={'f': tf.math.exp}), _TestTypeSpec( param_specs={'a': tf.TensorSpec([3, None], tf.float32), - 'b': tfb.Scale(3.)._type_spec}, - non_tensor_params={}, + 'b': AutoScale(3.)._type_spec}, omit_kwargs=('name', 'foo'), prefer_static_value=('a',), callable_params={'f': tf.math.exp})), + ('DifferentNonIdentifyingKwargsValues', + _TestTypeSpec( + param_specs={'x': tf.TensorSpec([], tf.float64)}, + non_tensor_params={'name': 'MyAutoCT'}, + non_identifying_kwargs=('name')), + _TestTypeSpec( + param_specs={'x': tf.TensorSpec([], tf.float64)}, + non_tensor_params={'name': 'OtherAutoCT'}, + non_identifying_kwargs=('name'))), ) def testEquality(self, v1, v2): # pylint: disable=g-generic-assert @@ -461,26 +492,28 @@ def testEquality(self, v1, v2): _TestTypeSpec( param_specs={'a': tf.TensorSpec([3, 2], tf.float32)}, non_tensor_params={'validate_args': True}, - omit_kwargs=('name',), - prefer_static_value=()), + omit_kwargs=('name',)), _TestTypeSpec( param_specs={'a': tf.TensorSpec([3, None], tf.float32)}, non_tensor_params={'validate_args': True}, - omit_kwargs=('name',), - prefer_static_value=())), + omit_kwargs=('name',))), ('DifferentCallables', _TestTypeSpec( param_specs={'a': tf.TensorSpec([3, None], tf.float32)}, - non_tensor_params={}, omit_kwargs=('name', 'foo'), - prefer_static_value=(), callable_params={'f': tf.math.exp}), _TestTypeSpec( param_specs={'a': tf.TensorSpec([3, None], tf.float32)}, - non_tensor_params={}, omit_kwargs=('name', 'foo'), - prefer_static_value=(), - callable_params={'f': tf.math.sigmoid})) + callable_params={'f': tf.math.sigmoid})), + ('DifferentMetadata', + _TestTypeSpec( + param_specs={'a': tf.TensorSpec([3, 2], tf.float32)}, + non_tensor_params={'validate_args': True}, + non_identifying_kwargs=('name',)), + _TestTypeSpec( + param_specs={'a': tf.TensorSpec([3, None], tf.float32)}, + non_tensor_params={'validate_args': True})), ) def testInequality(self, v1, v2): # pylint: disable=g-generic-assert @@ -504,19 +537,24 @@ def testInequality(self, v1, v2): ('WithCallable', _TestTypeSpec( param_specs={'a': tf.TensorSpec([3, None], tf.float32), - 'b': tfb.Scale( + 'b': AutoScale( tf.Variable(2., shape=None))._type_spec}, - non_tensor_params={}, omit_kwargs=('name', 'foo'), - prefer_static_value=(), callable_params={'f': tf.math.exp}), _TestTypeSpec( param_specs={'a': tf.TensorSpec([3, None], tf.float32), - 'b': tfb.Scale(3.)._type_spec}, - non_tensor_params={}, + 'b': AutoScale(3.)._type_spec}, omit_kwargs=('name', 'foo'), - prefer_static_value=(), - callable_params={'f': tf.math.exp})) + callable_params={'f': tf.math.exp})), + ('DifferentNonIdentifyingKwargsValues', + _TestTypeSpec( + param_specs={'x': tf.TensorSpec(None, tf.float64)}, + non_tensor_params={'name': 'MyAutoCT'}, + non_identifying_kwargs=('name')), + _TestTypeSpec( + param_specs={'x': tf.TensorSpec([], tf.float64)}, + non_tensor_params={'name': 'OtherAutoCT'}, + non_identifying_kwargs=('name'))), ) def testIsCompatibleWith(self, v1, v2): self.assertTrue(v1.is_compatible_with(v2)) @@ -539,29 +577,23 @@ def testIsCompatibleWith(self, v1, v2): param_specs={'a': tf.TensorSpec([4, 2], tf.float32)}, non_tensor_params={'validate_args': True}, omit_kwargs=('name',), - prefer_static_value=(), callable_params={'g': tf.math.softplus}), _TestTypeSpec( param_specs={'a': tf.TensorSpec([4, None], tf.float32)}, non_tensor_params={'validate_args': False}, omit_kwargs=('name',), - prefer_static_value=(), callable_params={'g': tf.math.softplus})), ('DifferentCallables', _TestTypeSpec( param_specs={'a': tf.TensorSpec([3, None], tf.float32), - 'b': tfb.Scale( + 'b': AutoScale( tf.Variable(2., shape=None))._type_spec}, - non_tensor_params={}, omit_kwargs=('name', 'foo'), - prefer_static_value=(), callable_params={'f': tf.math.exp}), _TestTypeSpec( param_specs={'a': tf.TensorSpec([3, None], tf.float32), - 'b': tfb.Scale(3.)._type_spec}, - non_tensor_params={}, + 'b': AutoScale(3.)._type_spec}, omit_kwargs=('name', 'foo'), - prefer_static_value=(), callable_params={'f': tf.math.sigmoid})) ) def testIsNotCompatibleWith(self, v1, v2): @@ -572,42 +604,27 @@ def testIsNotCompatibleWith(self, v1, v2): ('WithoutCallable', _TestTypeSpec( param_specs={'a': tf.TensorSpec([4, 2], tf.float32)}, - non_tensor_params={}, - omit_kwargs=('name',), - prefer_static_value=()), + omit_kwargs=('name',)), _TestTypeSpec( param_specs={'a': tf.TensorSpec([4, None], tf.float32)}, - non_tensor_params={}, - omit_kwargs=('name',), - prefer_static_value=()), + omit_kwargs=('name',)), _TestTypeSpec( param_specs={'a': tf.TensorSpec([4, None], tf.float32)}, - non_tensor_params={}, - omit_kwargs=('name',), - prefer_static_value=())), + omit_kwargs=('name',))), ('WithCallable', _TestTypeSpec( param_specs={'a': tf.TensorSpec(None, tf.float32), - 'b': tfb.Scale( + 'b': AutoScale( tf.Variable(2., shape=None))._type_spec}, - non_tensor_params={}, - omit_kwargs=(), - prefer_static_value=(), callable_params={'f': tf.math.exp}), _TestTypeSpec( param_specs={'a': tf.TensorSpec([3, None], tf.float32), - 'b': tfb.Scale(tf.Variable(3.))._type_spec}, - non_tensor_params={}, - omit_kwargs=(), - prefer_static_value=(), + 'b': AutoScale(tf.Variable(3.))._type_spec}, callable_params={'f': tf.math.exp}), _TestTypeSpec( param_specs={'a': tf.TensorSpec(None, tf.float32), - 'b': tfb.Scale( + 'b': AutoScale( tf.Variable(2., shape=None))._type_spec}, - non_tensor_params={}, - omit_kwargs=(), - prefer_static_value=(), callable_params={'f': tf.math.exp})), ) def testMostSpecificCompatibleType(self, v1, v2, expected): @@ -618,40 +635,26 @@ def testMostSpecificCompatibleType(self, v1, v2, expected): ('DifferentParamSpecs', _TestTypeSpec( param_specs={'a': tf.TensorSpec([4, 2], tf.float32)}, - non_tensor_params={}, - omit_kwargs=('foo',), - prefer_static_value=()), + omit_kwargs=('foo',)), _TestTypeSpec( param_specs={'b': tf.TensorSpec([5, None], tf.float32)}, - non_tensor_params={}, - omit_kwargs=('foo',), - prefer_static_value=())), + omit_kwargs=('foo',))), ('DifferentMetadata', _TestTypeSpec( param_specs={'a': tf.TensorSpec([4, 2], tf.float32)}, - non_tensor_params={}, - omit_kwargs=('foo',), - prefer_static_value=()), + omit_kwargs=('foo',)), _TestTypeSpec( param_specs={'a': tf.TensorSpec([4, None], tf.float32)}, - non_tensor_params={}, - omit_kwargs=('bar',), - prefer_static_value=())), + omit_kwargs=('bar',))), ('DifferentCallables', _TestTypeSpec( param_specs={'a': tf.TensorSpec(None, tf.float32), - 'b': tfb.Scale( + 'b': AutoScale( tf.Variable(2., shape=None))._type_spec}, - non_tensor_params={}, - omit_kwargs=(), - prefer_static_value=(), callable_params={'f': tf.math.exp}), _TestTypeSpec( param_specs={'a': tf.TensorSpec([3, None], tf.float32), - 'b': tfb.Scale(tf.Variable(3.))._type_spec}, - non_tensor_params={}, - omit_kwargs=(), - prefer_static_value=(), + 'b': AutoScale(tf.Variable(3.))._type_spec}, callable_params={'f': tf.math.softplus})), ) def testMostSpecificCompatibleTypeException(self, v1, v2): @@ -664,23 +667,19 @@ def testMostSpecificCompatibleTypeException(self, v1, v2): ('WithoutCallable', _TestTypeSpec( param_specs={'a': tf.TensorSpec([4, 2], tf.float32)}, - non_tensor_params={}, - omit_kwargs=('name',), - prefer_static_value=())), + omit_kwargs=('parameters',), non_identifying_kwargs=('name',))), ('WithCallable', _TestTypeSpec( param_specs={'a': tf.TensorSpec(None, tf.float32), - 'b': tfb.Scale( + 'b': AutoScale( tf.Variable(2., shape=None))._type_spec}, - non_tensor_params={}, - omit_kwargs=(), - prefer_static_value=(), callable_params={'f': tf.math.exp})), ) def testRepr(self, spec): spec_data = (auto_composite_tensor._AUTO_COMPOSITE_TENSOR_VERSION, spec._param_specs, spec._non_tensor_params, spec._omit_kwargs, - spec._prefer_static_value, spec._callable_params) + spec._prefer_static_value, spec._non_identifying_kwargs, + spec._callable_params) self.assertEqual(repr(spec), f'_TestTypeSpec{spec_data}') if __name__ == '__main__': diff --git a/tensorflow_probability/python/internal/backend/jax/BUILD b/tensorflow_probability/python/internal/backend/jax/BUILD index bd2cbac09b..9d52f16af7 100644 --- a/tensorflow_probability/python/internal/backend/jax/BUILD +++ b/tensorflow_probability/python/internal/backend/jax/BUILD @@ -53,6 +53,7 @@ FILENAMES = [ "private", "random_generators", "raw_ops", + "resource_variable_ops", "sets_lib", "sparse_lib", "tensor_array_ops", diff --git a/tensorflow_probability/python/internal/backend/numpy/BUILD b/tensorflow_probability/python/internal/backend/numpy/BUILD index c39fc14ff8..a57a85c911 100644 --- a/tensorflow_probability/python/internal/backend/numpy/BUILD +++ b/tensorflow_probability/python/internal/backend/numpy/BUILD @@ -52,6 +52,7 @@ py_library( ":private", ":random_generators", ":raw_ops", + ":resource_variable_ops", ":sets_lib", ":sparse_lib", ":static_rewrites", @@ -332,6 +333,11 @@ py_library( ], ) +py_library( + name = "resource_variable_ops", + srcs = ["resource_variable_ops.py"], +) + py_library( name = "sets_lib", srcs = ["sets_lib.py"], diff --git a/tensorflow_probability/python/internal/backend/numpy/__init__.py b/tensorflow_probability/python/internal/backend/numpy/__init__.py index d2f0565231..34fe3edbcb 100644 --- a/tensorflow_probability/python/internal/backend/numpy/__init__.py +++ b/tensorflow_probability/python/internal/backend/numpy/__init__.py @@ -46,6 +46,8 @@ from tensorflow_probability.python.internal.backend.numpy.numpy_array import * # pylint: disable=wildcard-import from tensorflow_probability.python.internal.backend.numpy.numpy_math import * # pylint: disable=wildcard-import from tensorflow_probability.python.internal.backend.numpy.ops import * # pylint: disable=wildcard-import +from tensorflow_probability.python.internal.backend.numpy.type_spec import BatchableTypeSpec +from tensorflow_probability.python.internal.backend.numpy.type_spec import TypeSpec Assert = debugging.Assert diff --git a/tensorflow_probability/python/internal/backend/numpy/ops.py b/tensorflow_probability/python/internal/backend/numpy/ops.py index 9d7953be53..3ce10dca5a 100644 --- a/tensorflow_probability/python/internal/backend/numpy/ops.py +++ b/tensorflow_probability/python/internal/backend/numpy/ops.py @@ -61,9 +61,7 @@ 'Module', 'Tensor', 'TensorSpec', - 'TypeSpec', 'Variable', - 'VariableSpec', # 'gradients', ] @@ -696,15 +694,10 @@ class Tensor(six.with_metaclass(_TensorMeta)): class TensorSpec(object): - pass - -class TypeSpec(object): - pass - - -class VariableSpec(object): - pass + def __init__(self, *args, **kwargs): + del args, kwargs + self.dtype = None class Module(object): diff --git a/tensorflow_probability/python/internal/backend/numpy/resource_variable_ops.py b/tensorflow_probability/python/internal/backend/numpy/resource_variable_ops.py new file mode 100644 index 0000000000..5cf9a1f943 --- /dev/null +++ b/tensorflow_probability/python/internal/backend/numpy/resource_variable_ops.py @@ -0,0 +1,26 @@ +# Copyright 2021 The TensorFlow Probability Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Numpy stub for `resource_variable_ops`.""" + +__all__ = [ + 'VariableSpec', +] + + +class VariableSpec(object): + + def __init__(self, *args, **kwargs): + del args, kwargs + self.dtype = None diff --git a/tensorflow_probability/python/internal/backend/numpy/tf_inspect.py b/tensorflow_probability/python/internal/backend/numpy/tf_inspect.py index d781a99da0..7c98144653 100644 --- a/tensorflow_probability/python/internal/backend/numpy/tf_inspect.py +++ b/tensorflow_probability/python/internal/backend/numpy/tf_inspect.py @@ -27,6 +27,7 @@ # Although `inspect` is different between Python 2 and 3, we should only ever # be using Python 3's inspect because JAX is Python 3 only and if TF is present # we will use `tf_inspect` which is compatible with both Python 2 and 3. +Parameter = inspect.Parameter getfullargspec = inspect.getfullargspec getcallargs = inspect.getcallargs getframeinfo = inspect.getframeinfo @@ -46,4 +47,5 @@ ismethod = inspect.ismethod ismodule = inspect.ismodule isroutine = inspect.isroutine +signature = inspect.signature stack = inspect.stack diff --git a/tensorflow_probability/python/internal/backend/numpy/type_spec.py b/tensorflow_probability/python/internal/backend/numpy/type_spec.py index 74080e4d26..534b196d6f 100644 --- a/tensorflow_probability/python/internal/backend/numpy/type_spec.py +++ b/tensorflow_probability/python/internal/backend/numpy/type_spec.py @@ -16,7 +16,9 @@ __all__ = [ 'lookup', - 'register' + 'register', + 'BatchableTypeSpec', + 'TypeSpec', ] @@ -30,3 +32,11 @@ def decorator_fn(cls): def lookup(_): # Raise ValueError instead of NotImplementedError to conform to TF. raise ValueError('`TypeSpec`s are not registered in Numpy/JAX.') + + +class TypeSpec(object): + pass + + +class BatchableTypeSpec(TypeSpec): + pass diff --git a/tensorflow_probability/python/internal/backend/numpy/v2.py b/tensorflow_probability/python/internal/backend/numpy/v2.py index 2edc956f5e..f5c11b775b 100644 --- a/tensorflow_probability/python/internal/backend/numpy/v2.py +++ b/tensorflow_probability/python/internal/backend/numpy/v2.py @@ -51,6 +51,8 @@ from tensorflow_probability.python.internal.backend.numpy.numpy_math import * # pylint: disable=wildcard-import from tensorflow_probability.python.internal.backend.numpy.ops import * # pylint: disable=wildcard-import from tensorflow_probability.python.internal.backend.numpy.tensor_array_ops import TensorArray +from tensorflow_probability.python.internal.backend.numpy.type_spec import BatchableTypeSpec +from tensorflow_probability.python.internal.backend.numpy.type_spec import TypeSpec # pylint: enable=unused-import diff --git a/tensorflow_probability/python/internal/tensor_util.py b/tensorflow_probability/python/internal/tensor_util.py index f77912fed3..9c1e5849c9 100644 --- a/tensorflow_probability/python/internal/tensor_util.py +++ b/tensorflow_probability/python/internal/tensor_util.py @@ -27,6 +27,7 @@ 'convert_nonref_to_tensor', 'discover_trainable_variables', 'discover_variables', + 'identity_as_tensor', 'is_module', 'is_ref', 'is_trainable_variable', @@ -123,6 +124,14 @@ def convert_nonref_to_tensor(value, dtype=None, dtype_hint=None, value, dtype=dtype, dtype_hint=dtype_hint, name=name) +def identity_as_tensor(value): + """Converts `value` to `Tensor` while ensuring an op is added to the graph.""" + t = tf.convert_to_tensor(value) + if t is value: + t = tf.identity(value) + return t + + def is_ref(x): """Evaluates if the object has reference semantics. diff --git a/tensorflow_probability/python/internal/tensor_util_test.py b/tensorflow_probability/python/internal/tensor_util_test.py index 4aaccf90e1..72c851f878 100644 --- a/tensorflow_probability/python/internal/tensor_util_test.py +++ b/tensorflow_probability/python/internal/tensor_util_test.py @@ -29,6 +29,7 @@ from tensorflow_probability.python.internal import test_util +tfb = tfp.bijectors tfd = tfp.distributions @@ -173,6 +174,15 @@ def test_is_module(self): self.assertTrue(tensor_util.is_module(m)) self.assertFalse(tensor_util.is_module(tf.Variable(0.))) + def test_identity_as_tensor(self): + for v in (tf.constant([4., 3.]), + tf.Variable(0.), + tfp.util.DeferredTensor(tf.Variable(1.), tf.math.exp), + tfp.util.TransformedVariable(2., tfb.Scale(tf.Variable(4.)))): + v_ = tensor_util.identity_as_tensor(v) + self.assertIsNot(v, v_) + self.assertIsInstance(v_, tf.Tensor) + if __name__ == '__main__': tf.test.main() diff --git a/tensorflow_probability/python/internal/testdata/auto_composite_tensor/saved_model.pb b/tensorflow_probability/python/internal/testdata/auto_composite_tensor/saved_model.pb index 3e61c7c1d4..c9d6e0640e 100644 Binary files a/tensorflow_probability/python/internal/testdata/auto_composite_tensor/saved_model.pb and b/tensorflow_probability/python/internal/testdata/auto_composite_tensor/saved_model.pb differ diff --git a/tensorflow_probability/python/layers/distribution_layer.py b/tensorflow_probability/python/layers/distribution_layer.py index f536417d1b..bedcbc37ca 100644 --- a/tensorflow_probability/python/layers/distribution_layer.py +++ b/tensorflow_probability/python/layers/distribution_layer.py @@ -68,8 +68,14 @@ 'VariationalGaussianProcess', ] - -tf.keras.__internal__.utils.register_symbolic_tensor_type(dtc._TensorCoercible) # pylint: disable=protected-access +try: + k_u = tf.keras.__internal__.utils +except: + try: + from tensorflow.python.keras.utils import tf_utils as k_u + except: + from keras.utils import tf_utils as k_u +k_u.register_symbolic_tensor_type(dtc._TensorCoercible) # pylint: disable=protected-access def _event_size(event_shape, name=None): diff --git a/tensorflow_probability/python/mcmc/hmc_test.py b/tensorflow_probability/python/mcmc/hmc_test.py index 864596b5d6..600c17c653 100644 --- a/tensorflow_probability/python/mcmc/hmc_test.py +++ b/tensorflow_probability/python/mcmc/hmc_test.py @@ -1053,7 +1053,7 @@ def trace_fn(_, pkr): sigma.pretransformed_input ]]) - weights_prior_estimated_scale = tf.identity(sigma) + weights_prior_estimated_scale = tf.convert_to_tensor(sigma) return (weights_prior_estimated_scale, weights[-1], loss, step_size[-1], avg_acceptance_ratio) diff --git a/tensorflow_probability/python/util/deferred_tensor.py b/tensorflow_probability/python/util/deferred_tensor.py index f479af7808..3303c56904 100644 --- a/tensorflow_probability/python/util/deferred_tensor.py +++ b/tensorflow_probability/python/util/deferred_tensor.py @@ -27,6 +27,9 @@ from tensorflow_probability.python.internal import tensor_util from tensorflow_probability.python.internal import tensorshape_util +from tensorflow.python.framework import type_spec # pylint: disable=g-direct-tensorflow-import +from tensorflow.python.ops import resource_variable_ops # pylint: disable=g-direct-tensorflow-import + __all__ = [ 'DeferredTensor', @@ -525,3 +528,234 @@ def assign_sub(self, value, use_locking=False, name=None, read_value=True): use_locking=use_locking, name=name, read_value=read_value) + + +@type_spec.register('tfp.util.TransformedVariableSpec') +class _TransformedVariableSpec(type_spec.BatchableTypeSpec): + """`tf.TypeSpec` for `tfp.util.TransformedVariable`.""" + + __slots__ = ('_input_spec', '_transform_or_spec', '_dtype', '_name', '_specs', + '_unique_id_params', '_transform_is_composite') + + def __init__(self, input_spec, transform_or_spec, dtype, name): + """Initializes a new `_TransformedVariableSpec`. + + Args: + input_spec: `tf.TypeSpec` instance describing the `TransformedVariable`s + `pretransformed_input` attribute. + transform_or_spec: The `bijector` passed to the `TransformedVariable`'s + constructor, or `bijector._type_spec` if `bijector` is a + `CompositeTensor`. + dtype: `tf.DType`, `dtype` property of the `TransformedVariable`. + name: `str`, name of the `TransformedVariable`. + """ + self._input_spec = input_spec + self._transform_or_spec = transform_or_spec + self._dtype = dtype + self._name = name + + self._unique_id_params = {'dtype': dtype} + self._transform_is_composite = isinstance(transform_or_spec, tf.TypeSpec) + + self._specs = {'input_spec': input_spec} + if self._transform_is_composite: + self._specs['transform_or_spec'] = transform_or_spec + + @property + def value_type(self): + return TransformedVariable + + @property + def name(self): + return self._name + + @property + def dtype(self): + return self._dtype + + @property + def transform_or_spec(self): + return self._transform_or_spec + + def most_specific_compatible_type(self, other): + """Returns the most specific TypeSpec compatible with `self` and `other`. + + Args: + other: A `TypeSpec`. + + Returns: + compatible_spec: The `TypeSpec` most compatible with `self` and `other`. + + Raises: + ValueError: If there is no TypeSpec that is compatible with both `self` + and `other`. + ValueError: If `self._transform_fn` is not a `CompositeTensor` and not + equal to `other._transform_fn`. + """ + if type(self) is not type(other): + raise ValueError( + f'No TypeSpec is compatible with both {self} and {other}.') + specs, params = self._TypeSpec__most_specific_compatible_type_serialization( + (self._specs, self._unique_id_params), + (other._specs, other._unique_id_params)) # pylint: disable=protected-access + kwargs = dict(specs, **params) + if not self._transform_is_composite: + if self.transform_or_spec != other.transform_or_spec: + raise ValueError( + f'{self.transform_or_spec} and {other.transform_or_spec} must be ' + f'identical.') + kwargs['transform_or_spec'] = self.transform_or_spec + return type(self)(**kwargs, name=None) + + def is_compatible_with(self, spec_or_value): + """Returns True if `spec_or_value` is compatible with this TypeSpec.""" + if not isinstance(spec_or_value, tf.TypeSpec): + spec_or_value = type_spec.type_spec_from_value(spec_or_value) + if type(self) is not type(spec_or_value): + return False + if not self._transform_is_composite: + if self.transform_or_spec != spec_or_value.transform_or_spec: + return False + return self._TypeSpec__is_compatible( + (self._specs, self._unique_id_params), + (spec_or_value._specs, spec_or_value._unique_id_params)) # pylint: disable=protected-access + + def _with_tensor_ranks_only(self): + """Returns a TypeSpec compatible with `self`, with Tensor shapes relaxed. + + Returns: + A `TypeSpec` that is compatible with `self`, where any `TensorShape` + information has been relaxed to include only Tensor rank (and not + the dimension sizes for individual axes). + """ + def relax(value): + if isinstance(value, tf.TypeSpec): + return value._with_tensor_ranks_only() # pylint: disable=protected-access + elif (isinstance(value, tf.TensorShape) and + value.rank is not None): + return tf.TensorShape([None] * value.rank) + else: + return value + + transform_or_spec = self._specs.pop( + 'transform_or_spec', self.transform_or_spec) + return type(self)( + **tf.nest.map_structure( + relax, + dict(self._specs, + transform_or_spec=transform_or_spec, + **self._unique_id_params, + name=self.name))) + + def _to_components(self, value): + """Encodes `value` as a nested structure of Tensor/CompositeTensor.""" + components = dict(pretransformed_input=value.pretransformed_input) + if isinstance(value.bijector, tf.__internal__.CompositeTensor): + components['bijector'] = value.bijector + return components + + def _from_components(self, components): + """Reconstructs a value from a structure of Tensor/CompositeTensor.""" + bijector = components.pop('bijector', self.transform_or_spec) + return TransformedVariable( + **components, initial_value=None, bijector=bijector, + dtype=self.dtype, name=self.name) + + @property + def _component_specs(self): + """A nested structure of TypeSpecs for the DeferredTensor's components.""" + specs = dict(pretransformed_input=self._input_spec) + if self._transform_is_composite: + specs['bijector'] = self.transform_or_spec + return specs + + def _batch(self, batch_size): + """Returns a TypeSpec representing a batch of DeferredTensors.""" + transform_or_spec = self._specs.pop( + 'transform_or_spec', self.transform_or_spec) + return type(self)( + self._get_batched_input_spec(batch_size), + transform_or_spec=transform_or_spec, + dtype=self.dtype, + name=self.name) + + def _unbatch(self): + """Returns a TypeSpec representing a single DeferredTensor.""" + transform_or_spec = self._specs.pop( + 'transform_or_spec', self.transform_or_spec) + return type(self)( + self._get_unbatched_input_spec(), + transform_or_spec=transform_or_spec, + dtype=self.dtype, + name=self.name) + + def _get_batched_input_spec(self, batch_size): + """Returns the batched `input_spec` for the given `batch_size`.""" + if isinstance(self._input_spec, type_spec.BatchableTypeSpec): + return self._input_spec._batch(batch_size) # pylint: disable=protected-access + if isinstance(self._input_spec, resource_variable_ops.VariableSpec): + return resource_variable_ops.VariableSpec( + shape=tf.TensorShape([batch_size]).concatenate( + self._input_spec.shape), + dtype=self._input_spec.dtype, + trainable=self._input_spec.trainable) + raise NotImplementedError( + f'`{self.value_type.__name__}`s `TypeSpec` is not supported for ' + f'inputs of type {type(self._input_spec)}.') + + def _get_unbatched_input_spec(self): + """Returns the `input_spec` with leading batch dimension removed.""" + if isinstance(self._input_spec, type_spec.BatchableTypeSpec): + return self._input_spec._unbatch() # pylint: disable=protected-access + if isinstance(self._input_spec, resource_variable_ops.VariableSpec): + return resource_variable_ops.VariableSpec( + shape=(None if self._input_spec.shape is None + else self._input_spec.shape[1:]), + dtype=self._input_spec.dtype, + trainable=self._input_spec.trainable) + else: + raise NotImplementedError( + f'`{self.value_type.__name__}`s `TypeSpec` is not supported for ' + f'inputs of type {type(self._input_spec)}.') + + def _serialize(self): + if not self._transform_is_composite: + raise ValueError( + f'Cannot serialize non-`CompositeTensor: {self.transform_or_spec}.') + return tuple( + dict(self._specs, **self._unique_id_params, name=self.name).items()) + + @classmethod + def _deserialize(cls, serialization): + return cls(**dict(serialization)) + + def __get_cmp_key(self): + fn_key = (None if self._transform_is_composite + else id(self.transform_or_spec)) + return (type(self), self._TypeSpec__make_cmp_key( + (self._specs, self._unique_id_params, fn_key))) + + def __repr__(self): + kwargs = dict(self._specs, **self._unique_id_params, name=self.name) + if not self._transform_is_composite: + kwargs['transform_or_spec'] = self.transform_or_spec + kwargs_str = ', '.join(f'{k}={v}' for k, v in kwargs.items()) + return f'{type(self).__name__}({kwargs_str})' + + def __reduce__(self): + if not self._transform_is_composite: + raise ValueError( + f'Cannot serialize object with callable parameters that are not ' + f'`CompositeTensor`s: {self.transform_or_spec}.') + super().__reduce__() + + def __eq__(self, other): + return (type(other) is type(self) and + self.__get_cmp_key() == other.__get_cmp_key()) # pylint: disable=protected-access + + def __ne__(self, other): + return not self == other + + def __hash__(self): + return hash(self.__get_cmp_key()) + diff --git a/tensorflow_probability/python/util/deferred_tensor_test.py b/tensorflow_probability/python/util/deferred_tensor_test.py index 944bdf33c6..91af7433f1 100644 --- a/tensorflow_probability/python/util/deferred_tensor_test.py +++ b/tensorflow_probability/python/util/deferred_tensor_test.py @@ -24,12 +24,17 @@ import numpy as np import tensorflow.compat.v2 as tf import tensorflow_probability as tfp +from tensorflow_probability.python.internal import auto_composite_tensor from tensorflow_probability.python.internal import test_util +from tensorflow_probability.python.util import deferred_tensor +from tensorflow.python.ops import resource_variable_ops # pylint: disable=g-direct-tensorflow-import tfb = tfp.bijectors tfd = tfp.distributions +JAX_MODE = False + @test_util.test_all_tf_execution_regimes class DeferredTensorTest(test_util.TestCase): @@ -71,7 +76,7 @@ def test_properties(self): @test_util.numpy_disable_variable_test def test_retains_trainable_variables_from_bijector(self): m = tf.Variable(0., name='m') - x = tfp.util.DeferredTensor(1., tfb.Scale(m)) + x = tfp.util.DeferredTensor(1., AutoScale(m)) self.assertIn(m, x.trainable_variables) @test_util.jax_disable_variable_test @@ -408,5 +413,232 @@ def fn(y): atol=0., rtol=1e-5) +def _make_transformed_variable_spec( + input_spec, transform_or_spec, dtype=None, name=None): + """Returns a `_TransformedVariableSpec` instance.""" + dtype = dtype or input_spec.dtype + return deferred_tensor._TransformedVariableSpec( + input_spec=input_spec, transform_or_spec=transform_or_spec, dtype=dtype, + name=name) + + +def _make_bijector_spec( + bijector_class, param, use_variable=False, variable_shape=None): + """Returns the `TypeSpec` of a Bijector with one Tensor-valued parameter. + + This utility avoids errors in the JAX backend due to instantiation of a + bijector before `app.run` is called. + + Args: + bijector_class: Subclass of `tfp.bijectors.Bijector`. + param: `Tensor`-like parameter of the bijector. + use_variable: Python `bool`. If True, `param` is converted to a + `tf.Variable`. + variable_shape: `tf.TensorShape` or list of `int`s. Static shape of the + `tf.Variable`, if `use_variable` is True. + + Returns: + bijector_spec: `TypeSpec` for a `Bijector` instance, or None if the test is + running in JAX mode. + """ + if JAX_MODE: + return None + if use_variable: + param = tf.Variable(param, shape=variable_shape) + return bijector_class(param)._type_spec + + +AutoScale = auto_composite_tensor.auto_composite_tensor( + tfb.Scale, omit_kwargs=('parameters',), non_identifying_kwargs=('name',), + module_name=('tfp.bijectors')) +AutoSigmoid = auto_composite_tensor.auto_composite_tensor( + tfb.Sigmoid, omit_kwargs=('parameters',), non_identifying_kwargs=('name',), + module_name=('tfp.bijectors')) +AutoShift = auto_composite_tensor.auto_composite_tensor( + tfb.Shift, omit_kwargs=('parameters',), non_identifying_kwargs=('name',), + module_name=('tfp.bijectors')) + + +@test_util.test_all_tf_execution_regimes +@test_util.disable_test_for_backend( + disable_numpy=True, disable_jax=True, + reason='JAX and Numpy have no notion of `TypeSpec`.') +class DeferredTensorSpecTest(test_util.TestCase): + + @parameterized.named_parameters( + ('TransformedVariableBijector', + _make_transformed_variable_spec( + input_spec=tf.TensorSpec([4, None], tf.float32), + transform_or_spec=_make_bijector_spec(AutoScale, [3.])), + _make_transformed_variable_spec( + input_spec=tf.TensorSpec([4, None], tf.float32), + transform_or_spec=_make_bijector_spec(AutoScale, [3.]))), + ('TranformedVariableCallable', + _make_transformed_variable_spec( + input_spec=resource_variable_ops.VariableSpec(None, tf.float64), + transform_or_spec=tf.math.sigmoid, + dtype=tf.float64, + name='one'), + _make_transformed_variable_spec( + input_spec=resource_variable_ops.VariableSpec(None, tf.float64), + transform_or_spec=tf.math.sigmoid, + dtype=tf.float64, + name='two')), + ) + def testEquality(self, v1, v2): + # pylint: disable=g-generic-assert + self.assertEqual(v1, v2) + self.assertEqual(v2, v1) + self.assertFalse(v1 != v2) + self.assertFalse(v2 != v1) + self.assertEqual(hash(v1), hash(v2)) + + @parameterized.named_parameters( + ('DifferentDtypes', + _make_transformed_variable_spec( + input_spec=tf.TensorSpec([4, 2], tf.float64), + transform_or_spec=AutoSigmoid(validate_args=True)._type_spec, + dtype=tf.float64), + _make_transformed_variable_spec( + input_spec=tf.TensorSpec([4, 2], tf.float32), + transform_or_spec=AutoSigmoid(validate_args=True)._type_spec)), + ('DifferentCallables', + _make_transformed_variable_spec( + input_spec=tf.TensorSpec([4, 2], tf.float64), + transform_or_spec=tf.math.sigmoid, + dtype=tf.float64, + name='one'), + _make_transformed_variable_spec( + input_spec=tf.TensorSpec([4, 2], tf.float64), + transform_or_spec=tf.math.softplus, + dtype=tf.float64, + name='two')), + ) + def testInequality(self, v1, v2): + # pylint: disable=g-generic-assert + self.assertNotEqual(v1, v2) + self.assertNotEqual(v2, v1) + self.assertFalse(v1 == v2) + self.assertFalse(v2 == v1) + + @parameterized.named_parameters( + ('TransformedVariableBijector', + _make_transformed_variable_spec( + input_spec=resource_variable_ops.VariableSpec([4, 2], tf.float32), + transform_or_spec=AutoSigmoid(validate_args=True)._type_spec), + _make_transformed_variable_spec( + input_spec=resource_variable_ops.VariableSpec([4, 2], tf.float32), + transform_or_spec=AutoSigmoid(validate_args=True)._type_spec)), + ('TransformedVariableCallable', + _make_transformed_variable_spec( + input_spec=tf.TensorSpec([4, 2], tf.float32), + transform_or_spec=tf.math.sigmoid, + name='one'), + _make_transformed_variable_spec( + input_spec=tf.TensorSpec([4, 2], tf.float32), + transform_or_spec=tf.math.sigmoid, + name='two')), + ) + def testIsCompatibleWith(self, v1, v2): + self.assertTrue(v1.is_compatible_with(v2)) + self.assertTrue(v2.is_compatible_with(v1)) + + @parameterized.named_parameters( + ('DifferentDtypes', + _make_transformed_variable_spec( + input_spec=tf.TensorSpec([4, 2], tf.float32), + transform_or_spec=AutoSigmoid(validate_args=True)._type_spec, + dtype=tf.float64), + _make_transformed_variable_spec( + input_spec=tf.TensorSpec([4, 2], tf.float32), + transform_or_spec=AutoSigmoid(validate_args=True)._type_spec)), + ('DifferentCallables', + _make_transformed_variable_spec( + input_spec=tf.TensorSpec([4, 2], tf.float64), + transform_or_spec=tf.math.sigmoid, + dtype=tf.float64, + name='one'), + _make_transformed_variable_spec( + input_spec=tf.TensorSpec([4, 2], tf.float64), + transform_or_spec=tf.math.softplus, + dtype=tf.float64, + name='two')), + ) + def testIsNotCompatibleWith(self, v1, v2): + self.assertFalse(v1.is_compatible_with(v2)) + self.assertFalse(v2.is_compatible_with(v1)) + + @parameterized.named_parameters( + ('TransformedVariableBijector', + _make_transformed_variable_spec( + input_spec=tf.TensorSpec([4, 2], tf.float32), + transform_or_spec=_make_bijector_spec( + AutoShift, [[2.]], use_variable=True, variable_shape=[1, 1])), + _make_transformed_variable_spec( + input_spec=tf.TensorSpec([4, 2], tf.float32), + transform_or_spec=_make_bijector_spec( + AutoShift, [[3.]], use_variable=True, variable_shape=[1, None])), + _make_transformed_variable_spec( + input_spec=tf.TensorSpec([4, 2], tf.float32), + transform_or_spec=_make_bijector_spec( + AutoShift, [[3.]], use_variable=True, variable_shape=[1, None])) + ), + ('TransformedVariableCallable', + _make_transformed_variable_spec( + input_spec=resource_variable_ops.VariableSpec([4, 2], tf.float32), + transform_or_spec=tf.math.sigmoid), + _make_transformed_variable_spec( + input_spec=resource_variable_ops.VariableSpec(None, tf.float32), + transform_or_spec=tf.math.sigmoid), + _make_transformed_variable_spec( + input_spec=resource_variable_ops.VariableSpec(None, tf.float32), + transform_or_spec=tf.math.sigmoid)) + ) + def testMostSpecificCompatibleType(self, v1, v2, expected): + self.assertEqual(v1.most_specific_compatible_type(v2), expected) + self.assertEqual(v2.most_specific_compatible_type(v1), expected) + + @parameterized.named_parameters( + ('DifferentDtypes', + _make_transformed_variable_spec( + input_spec=tf.TensorSpec([], tf.float32), + transform_or_spec=AutoSigmoid()._type_spec, + dtype=tf.float64), + _make_transformed_variable_spec( + input_spec=tf.TensorSpec([], tf.float32), + transform_or_spec=AutoSigmoid()._type_spec)), + ('DifferentCallables', + _make_transformed_variable_spec( + input_spec=tf.TensorSpec([4, 2], tf.float64), + transform_or_spec=tf.math.sigmoid, + dtype=tf.float64, + name='one'), + _make_transformed_variable_spec( + input_spec=tf.TensorSpec([4, 2], tf.float64), + transform_or_spec=tf.math.softplus, + dtype=tf.float64, + name='two')), + ) + def testMostSpecificCompatibleTypeException(self, v1, v2): + with self.assertRaises(ValueError): + v1.most_specific_compatible_type(v2) + with self.assertRaises(ValueError): + v2.most_specific_compatible_type(v1) + + # Disable test for TFP 0.13 release, in which bijectors are not AutoCT. + # def testRepr(self): + # spec = _make_transformed_variable_spec( + # input_spec=tf.TensorSpec([4, 2], tf.float32), + # transform_or_spec=tfb.Sigmoid(validate_args=True)._type_spec, + # dtype=tf.float64) + # expected = ( + # "_TransformedVariableSpec(input_spec=TensorSpec(shape=(4, 2), " + # "dtype=tf.float32, name=None), " + # "transform_or_spec=Sigmoid_ACTTypeSpec(3, {}, {'low': None, 'high': " + # "None, 'validate_args': True, 'name': 'sigmoid'}, ('parameters',), (), " # pylint: disable=line-too-long + # "('name',), {}), dtype=, name=None)") + # self.assertEqual(repr(spec), expected) + + if __name__ == '__main__': tf.test.main() diff --git a/tensorflow_probability/python/version.py b/tensorflow_probability/python/version.py index 503fc451e3..3993dc205d 100644 --- a/tensorflow_probability/python/version.py +++ b/tensorflow_probability/python/version.py @@ -24,7 +24,7 @@ # stable release (indicated by `_VERSION_SUFFIX = ''`). Outside the context of a # release branch, the current version is by default assumed to be a # 'development' version, labeled 'dev'. -_VERSION_SUFFIX = 'dev' +_VERSION_SUFFIX = '' # Example, '0.4.0-dev' __version__ = '.'.join([ diff --git a/tensorflow_probability/python/vi/optimization_test.py b/tensorflow_probability/python/vi/optimization_test.py index 27951c0b91..6c14ef77b9 100644 --- a/tensorflow_probability/python/vi/optimization_test.py +++ b/tensorflow_probability/python/vi/optimization_test.py @@ -73,7 +73,7 @@ def trainable_log_prob(z): with tf.control_dependencies([loss_curve]): final_q_loc = tf.identity(q.mean()) final_q_scale = tf.identity(q.stddev()) - final_likelihood_scale = tf.identity(likelihood_scale) + final_likelihood_scale = tf.convert_to_tensor(likelihood_scale) # We expect to recover the true posterior because the variational family # includes the true posterior, and the true parameters because we observed diff --git a/tensorflow_probability/substrates/meta/rewrite.py b/tensorflow_probability/substrates/meta/rewrite.py index 901ef4ccb7..12f76a691c 100644 --- a/tensorflow_probability/substrates/meta/rewrite.py +++ b/tensorflow_probability/substrates/meta/rewrite.py @@ -61,7 +61,7 @@ ('from tensorflow.python.ops import ' 'resource_variable_ops'): ('from tensorflow_probability.python.internal.backend.numpy ' - 'import ops'), + 'import resource_variable_ops'), 'from tensorflow.python.util import': 'from tensorflow_probability.python.internal.backend.numpy import', 'from tensorflow.python.util.all_util':