From 49e7f5e9f7d443bac6bb7584258f976e28b9e505 Mon Sep 17 00:00:00 2001 From: emilyaf Date: Thu, 20 May 2021 10:27:31 -0700 Subject: [PATCH] Temporarily remove AutoCompositeTensor from bijectors for TFP 0.13 release. PiperOrigin-RevId: 374899448 --- .../python/bijectors/bijector.py | 52 ++++++++------- .../bijectors/bijector_properties_test.py | 64 ++++-------------- .../python/bijectors/bijector_test.py | 29 +++++---- .../python/bijectors/exp.py | 4 +- .../python/bijectors/expm1.py | 3 +- .../python/bijectors/invert.py | 46 ++++++------- .../experimental/composite_tensor_test.py | 6 +- .../internal/auto_composite_tensor_test.py | 63 +++++++++--------- .../python/util/deferred_tensor_test.py | 65 +++++++++++-------- 9 files changed, 159 insertions(+), 173 deletions(-) diff --git a/tensorflow_probability/python/bijectors/bijector.py b/tensorflow_probability/python/bijectors/bijector.py index 8cb9515488..fc45cc6644 100644 --- a/tensorflow_probability/python/bijectors/bijector.py +++ b/tensorflow_probability/python/bijectors/bijector.py @@ -20,7 +20,7 @@ import abc import contextlib -import functools +# import functools # Dependency imports import numpy as np @@ -1599,33 +1599,39 @@ def _composite_tensor_shape_params(self): return () -class AutoCompositeTensorBijector( - Bijector, auto_composite_tensor.AutoCompositeTensor): - r"""Base for `CompositeTensor` bijectors with auto-generated `TypeSpec`s. +# Temporarily disable AutoCT for TFP 0.13 release +# class AutoCompositeTensorBijector( +# Bijector, auto_composite_tensor.AutoCompositeTensor): +# r"""Base for `CompositeTensor` bijectors with auto-generated `TypeSpec`s. - `CompositeTensor` objects are able to pass in and out of `tf.function` and - `tf.while_loop`, or serve as part of the signature of a TF saved model. - `Bijector` subclasses that follow the contract of - `tfp.experimental.auto_composite_tensor` may be defined as `CompositeTensor`s - by inheriting from `AutoCompositeTensorBijector` and applying a class - decorator as shown here: +# `CompositeTensor` objects are able to pass in and out of `tf.function` and +# `tf.while_loop`, or serve as part of the signature of a TF saved model. +# `Bijector` subclasses that follow the contract of +# `tfp.experimental.auto_composite_tensor` may be defined as `CompositeTensor`s # pylint: disable=line-too-long +# by inheriting from `AutoCompositeTensorBijector` and applying a class +# decorator as shown here: - ```python - @tfp.experimental.auto_composite_tensor( - omit_kwargs=('name',), module_name='my_module') - class MyBijector(tfb.AutoCompositeTensorBijector): +# ```python +# @tfp.experimental.auto_composite_tensor( +# omit_kwargs=('name',), module_name='my_module') +# class MyBijector(tfb.AutoCompositeTensorBijector): - # The remainder of the subclass implementation is unchanged. - ``` - """ - pass +# # The remainder of the subclass implementation is unchanged. +# ``` +# """ +# pass + + +# auto_composite_tensor_bijector = functools.partial( +# auto_composite_tensor.auto_composite_tensor, +# omit_kwargs=('parameters',), +# non_identifying_kwargs=('name',), +# module_name='tfp.bijectors') + +AutoCompositeTensorBijector = Bijector -auto_composite_tensor_bijector = functools.partial( - auto_composite_tensor.auto_composite_tensor, - omit_kwargs=('parameters',), - non_identifying_kwargs=('name',), - module_name='tfp.bijectors') +auto_composite_tensor_bijector = lambda cls, **kwargs: cls def check_valid_ndims(ndims, validate=True): diff --git a/tensorflow_probability/python/bijectors/bijector_properties_test.py b/tensorflow_probability/python/bijectors/bijector_properties_test.py index 34135cddfe..8eab905268 100644 --- a/tensorflow_probability/python/bijectors/bijector_properties_test.py +++ b/tensorflow_probability/python/bijectors/bijector_properties_test.py @@ -28,14 +28,13 @@ from tensorflow_probability.python import bijectors as tfb from tensorflow_probability.python import experimental from tensorflow_probability.python.bijectors import hypothesis_testlib as bijector_hps -from tensorflow_probability.python.bijectors import invert as invert_lib from tensorflow_probability.python.internal import hypothesis_testlib as tfp_hps from tensorflow_probability.python.internal import prefer_static from tensorflow_probability.python.internal import samplers from tensorflow_probability.python.internal import tensor_util from tensorflow_probability.python.internal import tensorshape_util from tensorflow_probability.python.internal import test_util -from tensorflow_probability.python.util.deferred_tensor import DeferredTensor +# from tensorflow_probability.python.util.deferred_tensor import DeferredTensor TF2_FRIENDLY_BIJECTORS = ( @@ -186,6 +185,7 @@ COMPOSITE_TENSOR_IS_BROKEN = [ 'BatchNormalization', # tf.layers arg + 'Inline', # callable 'RationalQuadraticSpline', # TODO(b/185628453): Debug loss of static info. ] @@ -204,7 +204,7 @@ def is_invert(bijector): - return isinstance(bijector, (tfb.Invert, invert_lib._Invert)) + return isinstance(bijector, tfb.Invert) def is_transform_diagonal(bijector): @@ -244,8 +244,8 @@ def _constraint(param): # TODO(b/141098791): Eliminate this. -@experimental.auto_composite_tensor -class CallableModule(tf.Module, experimental.AutoCompositeTensor): +# @experimental.auto_composite_tensor +class CallableModule(tf.Module): # , experimental.AutoCompositeTensor): """Convenience object for capturing variables closed over by Inline.""" def __init__(self, fn, varobj): @@ -887,38 +887,16 @@ def testEquality(self, bijector_name, data): @hp.given(hps.data()) @tfp_hps.tfp_hp_settings() def testCompositeTensor(self, bijector_name, data): - + # Test that making a composite tensor of this bijector doesn't throw any + # errors. bijector, event_dim = self._draw_bijector( - bijector_name, data, - batch_shape=[], - validate_args=True, + bijector_name, data, batch_shape=[], allowed_bijectors=(set(TF2_FRIENDLY_BIJECTORS) - set(COMPOSITE_TENSOR_IS_BROKEN))) - - if type(bijector) is invert_lib._Invert: # pylint: disable=unidiomatic-typecheck - if isinstance(bijector.bijector, tf.__internal__.CompositeTensor): - raise TypeError('`_Invert` should wrap only non-`CompositeTensor` ' - 'bijectors.') - self.skipTest('`_Invert` bijectors are not `CompositeTensor`s.') - - # TODO(b/182603117): Remove "if" condition and s/composite_bij/bijector - # when AutoCT is enabled for meta-bijectors and LinearOperator. - if type(bijector).__name__ in AUTO_COMPOSITE_TENSOR_IS_BROKEN: - composite_bij = experimental.as_composite(bijector) - else: - composite_bij = bijector - - if not tf.executing_eagerly(): - composite_bij = tf.nest.map_structure( - lambda x: (tf.convert_to_tensor(x) # pylint: disable=g-long-lambda - if isinstance(x, DeferredTensor) else x), - composite_bij, - expand_composites=True) - - self.assertIsInstance(composite_bij, tf.__internal__.CompositeTensor) + composite_bij = experimental.as_composite(bijector) flat = tf.nest.flatten(composite_bij, expand_composites=True) - unflat = tf.nest.pack_sequence_as( - composite_bij, flat, expand_composites=True) + unflat = tf.nest.pack_sequence_as(composite_bij, flat, + expand_composites=True) # Compare forward maps before and after compositing. n = 3 @@ -933,26 +911,6 @@ def testCompositeTensor(self, bijector_name, data): after_xs = unflat.inverse(ys) self.assertAllClose(*self.evaluate((before_xs, after_xs))) - # Input to tf.function - self.assertAllClose( - before_ys, - tf.function(lambda b: b.forward(xs))(composite_bij), - rtol=COMPOSITE_TENSOR_RTOL[bijector_name], - atol=COMPOSITE_TENSOR_ATOL[bijector_name]) - - # Forward mapping: Check differentiation through forward mapping with - # respect to the input and parameter variables. Also check that any - # variables are not referenced overmuch. - xs = self._draw_domain_tensor(bijector, data, event_dim) - wrt_vars = [xs] + [v for v in composite_bij.trainable_variables - if v.dtype.is_floating] - with tf.GradientTape() as tape: - tape.watch(wrt_vars) - # TODO(b/73073515): Fix graph mode gradients with bijector caching. - ys = bijector.forward(xs + 0) - grads = tape.gradient(ys, wrt_vars) - assert_no_none_grad(bijector, 'forward', wrt_vars, grads) - def ensure_nonzero(x): return tf.where(x < 1e-6, tf.constant(1e-3, x.dtype), x) diff --git a/tensorflow_probability/python/bijectors/bijector_test.py b/tensorflow_probability/python/bijectors/bijector_test.py index 82d4d34b39..0d26899e63 100644 --- a/tensorflow_probability/python/bijectors/bijector_test.py +++ b/tensorflow_probability/python/bijectors/bijector_test.py @@ -790,25 +790,26 @@ def _forward_log_det_jacobian(self, _): return tf.math.log(self._scale) -@test_util.test_all_tf_execution_regimes -class AutoCompositeTensorBijectorTest(test_util.TestCase): + # Test disabled temporarily for TFP 0.13 release. +# @test_util.test_all_tf_execution_regimes +# class AutoCompositeTensorBijectorTest(test_util.TestCase): - def test_disable_ct_bijector(self): +# def test_disable_ct_bijector(self): - ct_bijector = CompositeForwardBijector() - self.assertIsInstance(ct_bijector, tf.__internal__.CompositeTensor) +# ct_bijector = CompositeForwardBijector() +# self.assertIsInstance(ct_bijector, tf.__internal__.CompositeTensor) - non_ct_bijector = ForwardOnlyBijector() - self.assertNotIsInstance(non_ct_bijector, tf.__internal__.CompositeTensor) +# non_ct_bijector = ForwardOnlyBijector() +# self.assertNotIsInstance(non_ct_bijector, tf.__internal__.CompositeTensor) - flat = tf.nest.flatten(ct_bijector, expand_composites=True) - unflat = tf.nest.pack_sequence_as( - ct_bijector, flat, expand_composites=True) +# flat = tf.nest.flatten(ct_bijector, expand_composites=True) +# unflat = tf.nest.pack_sequence_as( +# ct_bijector, flat, expand_composites=True) - x = tf.constant([2., 3.]) - self.assertAllClose( - non_ct_bijector.forward(x), - tf.function(lambda b: b.forward(x))(unflat)) +# x = tf.constant([2., 3.]) +# self.assertAllClose( +# non_ct_bijector.forward(x), +# tf.function(lambda b: b.forward(x))(unflat)) if __name__ == '__main__': diff --git a/tensorflow_probability/python/bijectors/exp.py b/tensorflow_probability/python/bijectors/exp.py index 6bc0524fab..6e59ef49de 100644 --- a/tensorflow_probability/python/bijectors/exp.py +++ b/tensorflow_probability/python/bijectors/exp.py @@ -23,7 +23,6 @@ from tensorflow_probability.python.bijectors import bijector as bijector_lib from tensorflow_probability.python.bijectors import invert from tensorflow_probability.python.bijectors import power_transform -from tensorflow_probability.python.internal import auto_composite_tensor __all__ = [ @@ -77,8 +76,7 @@ def __init__(self, # TODO(b/182603117): Remove `AutoCompositeTensor` when `Invert` subclasses # `AutoCompositeTensor` and ensure `tf.saved_model` still works. @bijector_lib.auto_composite_tensor_bijector -class Log(invert.Invert, - auto_composite_tensor.AutoCompositeTensor): +class Log(invert.Invert): """Compute `Y = log(X)`. This is `Invert(Exp())`.""" def __init__(self, validate_args=False, name='log'): diff --git a/tensorflow_probability/python/bijectors/expm1.py b/tensorflow_probability/python/bijectors/expm1.py index ea9d022c70..69daa73d97 100644 --- a/tensorflow_probability/python/bijectors/expm1.py +++ b/tensorflow_probability/python/bijectors/expm1.py @@ -21,7 +21,6 @@ import tensorflow.compat.v2 as tf from tensorflow_probability.python.bijectors import bijector from tensorflow_probability.python.bijectors import invert -from tensorflow_probability.python.internal import auto_composite_tensor __all__ = [ @@ -95,7 +94,7 @@ def _forward_log_det_jacobian(self, x): # TODO(b/182603117): Remove `AutoCompositeTensor` when `Invert` subclasses # `AutoCompositeTensor`. @bijector.auto_composite_tensor_bijector -class Log1p(invert.Invert, auto_composite_tensor.AutoCompositeTensor): +class Log1p(invert.Invert): """Compute `Y = log1p(X)`. This is `Invert(Expm1())`.""" def __init__(self, validate_args=False, name='log1p'): diff --git a/tensorflow_probability/python/bijectors/invert.py b/tensorflow_probability/python/bijectors/invert.py index 02a1541365..d747920774 100644 --- a/tensorflow_probability/python/bijectors/invert.py +++ b/tensorflow_probability/python/bijectors/invert.py @@ -21,14 +21,14 @@ import tensorflow.compat.v2 as tf from tensorflow_probability.python.bijectors import bijector as bijector_lib -from tensorflow_probability.python.internal import auto_composite_tensor +# from tensorflow_probability.python.internal import auto_composite_tensor __all__ = [ 'Invert', ] -class _Invert(bijector_lib.Bijector): +class Invert(bijector_lib.Bijector): """Bijector which inverts another Bijector. Example Use: [ExpGammaDistribution (see Background & Context)]( @@ -73,7 +73,7 @@ def __init__(self, bijector, validate_args=False, parameters=None, name=None): name = name or '_'.join(['invert', bijector.name]) with tf.name_scope(name) as name: self._bijector = bijector - super(_Invert, self).__init__( + super(Invert, self).__init__( forward_min_event_ndims=bijector.inverse_min_event_ndims, inverse_min_event_ndims=bijector.forward_min_event_ndims, dtype=bijector.dtype, @@ -138,26 +138,28 @@ def forward_event_ndims(self, event_ndims, **kwargs): return self.bijector.inverse_event_ndims(event_ndims, **kwargs) -@bijector_lib.auto_composite_tensor_bijector -class Invert(_Invert, auto_composite_tensor.AutoCompositeTensor): +# Temporarily removing AutoCompositeTensor for TFP 0.13 release. +# pylint: disable=line-too-long +# @bijector_lib.auto_composite_tensor_bijector +# class Invert(_Invert, auto_composite_tensor.AutoCompositeTensor): - def __new__(cls, *args, **kwargs): - """Returns an `_Invert` instance if `bijector` is not a `CompositeTensor.""" - if cls is Invert: - if args: - bijector = args[0] - elif 'bijector' in kwargs: - bijector = kwargs['bijector'] - else: - raise TypeError('`Invert.__new__()` is missing argument `bijector`.') +# def __new__(cls, *args, **kwargs): +# """Returns an `_Invert` instance if `bijector` is not a `CompositeTensor.""" +# if cls is Invert: +# if args: +# bijector = args[0] +# elif 'bijector' in kwargs: +# bijector = kwargs['bijector'] +# else: +# raise TypeError('`Invert.__new__()` is missing argument `bijector`.') - if not isinstance(bijector, tf.__internal__.CompositeTensor): - return _Invert(*args, **kwargs) - return super(Invert, cls).__new__(cls) +# if not isinstance(bijector, tf.__internal__.CompositeTensor): +# return _Invert(*args, **kwargs) +# return super(Invert, cls).__new__(cls) -Invert.__doc__ = _Invert.__doc__ + '/n' + ( - 'When an `Invert` bijector is constructed, if its `bijector` arg is not a ' - '`CompositeTensor` instance, an `_Invert` instance is returned instead. ' - 'Bijectors subclasses that inherit from `Invert` will also inherit from ' - ' `CompositeTensor`.') +# Invert.__doc__ = _Invert.__doc__ + '/n' + ( +# 'When an `Invert` bijector is constructed, if its `bijector` arg is not a ' +# '`CompositeTensor` instance, an `_Invert` instance is returned instead. ' +# 'Bijectors subclasses that inherit from `Invert` will also inherit from ' +# ' `CompositeTensor`.') diff --git a/tensorflow_probability/python/experimental/composite_tensor_test.py b/tensorflow_probability/python/experimental/composite_tensor_test.py index 6543d67f69..1e683d47b6 100644 --- a/tensorflow_probability/python/experimental/composite_tensor_test.py +++ b/tensorflow_probability/python/experimental/composite_tensor_test.py @@ -454,7 +454,11 @@ def test_basics_mixture_same_family(self): self.evaluate(unflat.log_prob(.5)) def test_already_composite_tensor(self): - b = tfb.Scale(2.) + AutoScale = tfp.experimental.auto_composite_tensor( # pylint: disable=invalid-name + tfb.Scale, omit_kwargs=('parameters',), + non_identifying_kwargs=('name',), + module_name=('tfp.bijectors')) + b = AutoScale(2.) b2 = tfp.experimental.as_composite(b) self.assertIsInstance(b, tf.__internal__.CompositeTensor) self.assertIs(b, b2) diff --git a/tensorflow_probability/python/internal/auto_composite_tensor_test.py b/tensorflow_probability/python/internal/auto_composite_tensor_test.py index 03908bcb95..5d3015745d 100644 --- a/tensorflow_probability/python/internal/auto_composite_tensor_test.py +++ b/tensorflow_probability/python/internal/auto_composite_tensor_test.py @@ -22,7 +22,7 @@ import os from absl import flags -from absl.testing import absltest +# from absl.testing import absltest from absl.testing import parameterized import tensorflow.compat.v2 as tf @@ -63,6 +63,10 @@ tfd.Independent, non_identifying_kwargs=('name',)) AutoReshape = tfp.experimental.auto_composite_tensor( tfb.Reshape, non_identifying_kwargs=('name',)) +AutoScale = tfp.experimental.auto_composite_tensor( + tfb.Scale, non_identifying_kwargs=('name',)) +AutoTransformDiagonal = tfp.experimental.auto_composite_tensor( + tfb.TransformDiagonal, non_identifying_kwargs=('name',)) class Model(tf.Module): @@ -71,9 +75,9 @@ def __init__(self): self.scale = tf.Variable([0., 1.], shape=[None]) @tf.function(input_signature=( - tfb.Scale([1., 2.], validate_args=True)._type_spec,)) + AutoScale([1., 2.], validate_args=True)._type_spec,)) def make_bij(self, b): - return tfb.Scale( + return AutoScale( tf.convert_to_tensor(self.scale) + b.scale, validate_args=True) @@ -314,7 +318,7 @@ def test_export_import(self): tf.saved_model.save(m1, os.path.join(path, 'saved_model1')) m2 = tf.saved_model.load(os.path.join(path, 'saved_model1')) self.evaluate(m2.scale.initializer) - b = tfb.Scale([5., 9.], validate_args=True) + b = AutoScale([5., 9.], validate_args=True) self.evaluate(m2.make_bij(b).forward(2.)) self.evaluate(m2.scale.assign(m2.scale + [1., 2.])) self.evaluate(m2.make_bij(b).forward(2.)) @@ -326,19 +330,20 @@ def test_export_import(self): with self.assertRaisesOpError('compatible shape'): self.evaluate(m3.make_bij(b).forward([3.])) - def test_saved_model_from_disk(self): + # Test disabled for 0.13 release. +# def test_saved_model_from_disk(self): - test_srcdir = absltest.get_default_test_srcdir() - relative_testdata_path = os.path.join( - TFP_PYTHON_DIR, 'internal/testdata/auto_composite_tensor') - absolute_testdata_path = os.path.join(test_srcdir, relative_testdata_path) +# test_srcdir = absltest.get_default_test_srcdir() +# relative_testdata_path = os.path.join( +# TFP_PYTHON_DIR, 'internal/testdata/auto_composite_tensor') +# absolute_testdata_path = os.path.join(test_srcdir, relative_testdata_path) - m = tf.saved_model.load(absolute_testdata_path) - self.evaluate(m.scale.initializer) - b = tfb.Scale([5., 9.], validate_args=True) - self.assertAllClose(self.evaluate(m.make_bij(b).forward(2.)), [10., 20.]) - self.evaluate(m.scale.assign(m.scale + [1., 2.])) - self.assertAllClose(self.evaluate(m.make_bij(b).forward(2.)), [12., 24.]) +# m = tf.saved_model.load(absolute_testdata_path) +# self.evaluate(m.scale.initializer) +# b = tfb.Scale([5., 9.], validate_args=True) +# self.assertAllClose(self.evaluate(m.make_bij(b).forward(2.)), [10., 20.]) +# self.evaluate(m.scale.assign(m.scale + [1., 2.])) +# self.assertAllClose(self.evaluate(m.make_bij(b).forward(2.)), [12., 24.]) def test_callable_arg(self): @@ -384,8 +389,8 @@ def __call__(self, *args, **kwargs): def test_composite_tensor_callable_arg(self): # Parameters that are both `CompositeTensor` and callable should be # handled by the `_type_spec` as `CompositeTensor`. - inner_bij = tfb.Scale([[1., 3.]], validate_args=True) - bij = tfb.TransformDiagonal(inner_bij, validate_args=True) + inner_bij = AutoScale([[1., 3.]], validate_args=True) + bij = AutoTransformDiagonal(inner_bij, validate_args=True) self.assertLen(tf.nest.flatten(bij), 1) self.assertLen(bij._type_spec._callable_params, 0) # pylint: disable=protected-access self.assertIn('diag_bijector', bij._type_spec._param_specs) # pylint: disable=protected-access @@ -454,13 +459,13 @@ class AutoCompositeTensorTypeSpecTest(test_util.TestCase): ('WithCallable', _TestTypeSpec( param_specs={'a': tf.TensorSpec([3, None], tf.float32), - 'b': tfb.Scale(3.)._type_spec}, + 'b': AutoScale(3.)._type_spec}, omit_kwargs=('name', 'foo'), prefer_static_value=('a',), callable_params={'f': tf.math.exp}), _TestTypeSpec( param_specs={'a': tf.TensorSpec([3, None], tf.float32), - 'b': tfb.Scale(3.)._type_spec}, + 'b': AutoScale(3.)._type_spec}, omit_kwargs=('name', 'foo'), prefer_static_value=('a',), callable_params={'f': tf.math.exp})), @@ -532,13 +537,13 @@ def testInequality(self, v1, v2): ('WithCallable', _TestTypeSpec( param_specs={'a': tf.TensorSpec([3, None], tf.float32), - 'b': tfb.Scale( + 'b': AutoScale( tf.Variable(2., shape=None))._type_spec}, omit_kwargs=('name', 'foo'), callable_params={'f': tf.math.exp}), _TestTypeSpec( param_specs={'a': tf.TensorSpec([3, None], tf.float32), - 'b': tfb.Scale(3.)._type_spec}, + 'b': AutoScale(3.)._type_spec}, omit_kwargs=('name', 'foo'), callable_params={'f': tf.math.exp})), ('DifferentNonIdentifyingKwargsValues', @@ -581,13 +586,13 @@ def testIsCompatibleWith(self, v1, v2): ('DifferentCallables', _TestTypeSpec( param_specs={'a': tf.TensorSpec([3, None], tf.float32), - 'b': tfb.Scale( + 'b': AutoScale( tf.Variable(2., shape=None))._type_spec}, omit_kwargs=('name', 'foo'), callable_params={'f': tf.math.exp}), _TestTypeSpec( param_specs={'a': tf.TensorSpec([3, None], tf.float32), - 'b': tfb.Scale(3.)._type_spec}, + 'b': AutoScale(3.)._type_spec}, omit_kwargs=('name', 'foo'), callable_params={'f': tf.math.sigmoid})) ) @@ -609,16 +614,16 @@ def testIsNotCompatibleWith(self, v1, v2): ('WithCallable', _TestTypeSpec( param_specs={'a': tf.TensorSpec(None, tf.float32), - 'b': tfb.Scale( + 'b': AutoScale( tf.Variable(2., shape=None))._type_spec}, callable_params={'f': tf.math.exp}), _TestTypeSpec( param_specs={'a': tf.TensorSpec([3, None], tf.float32), - 'b': tfb.Scale(tf.Variable(3.))._type_spec}, + 'b': AutoScale(tf.Variable(3.))._type_spec}, callable_params={'f': tf.math.exp}), _TestTypeSpec( param_specs={'a': tf.TensorSpec(None, tf.float32), - 'b': tfb.Scale( + 'b': AutoScale( tf.Variable(2., shape=None))._type_spec}, callable_params={'f': tf.math.exp})), ) @@ -644,12 +649,12 @@ def testMostSpecificCompatibleType(self, v1, v2, expected): ('DifferentCallables', _TestTypeSpec( param_specs={'a': tf.TensorSpec(None, tf.float32), - 'b': tfb.Scale( + 'b': AutoScale( tf.Variable(2., shape=None))._type_spec}, callable_params={'f': tf.math.exp}), _TestTypeSpec( param_specs={'a': tf.TensorSpec([3, None], tf.float32), - 'b': tfb.Scale(tf.Variable(3.))._type_spec}, + 'b': AutoScale(tf.Variable(3.))._type_spec}, callable_params={'f': tf.math.softplus})), ) def testMostSpecificCompatibleTypeException(self, v1, v2): @@ -666,7 +671,7 @@ def testMostSpecificCompatibleTypeException(self, v1, v2): ('WithCallable', _TestTypeSpec( param_specs={'a': tf.TensorSpec(None, tf.float32), - 'b': tfb.Scale( + 'b': AutoScale( tf.Variable(2., shape=None))._type_spec}, callable_params={'f': tf.math.exp})), ) diff --git a/tensorflow_probability/python/util/deferred_tensor_test.py b/tensorflow_probability/python/util/deferred_tensor_test.py index e1e824de1d..91af7433f1 100644 --- a/tensorflow_probability/python/util/deferred_tensor_test.py +++ b/tensorflow_probability/python/util/deferred_tensor_test.py @@ -24,6 +24,7 @@ import numpy as np import tensorflow.compat.v2 as tf import tensorflow_probability as tfp +from tensorflow_probability.python.internal import auto_composite_tensor from tensorflow_probability.python.internal import test_util from tensorflow_probability.python.util import deferred_tensor from tensorflow.python.ops import resource_variable_ops # pylint: disable=g-direct-tensorflow-import @@ -75,7 +76,7 @@ def test_properties(self): @test_util.numpy_disable_variable_test def test_retains_trainable_variables_from_bijector(self): m = tf.Variable(0., name='m') - x = tfp.util.DeferredTensor(1., tfb.Scale(m)) + x = tfp.util.DeferredTensor(1., AutoScale(m)) self.assertIn(m, x.trainable_variables) @test_util.jax_disable_variable_test @@ -447,6 +448,17 @@ def _make_bijector_spec( return bijector_class(param)._type_spec +AutoScale = auto_composite_tensor.auto_composite_tensor( + tfb.Scale, omit_kwargs=('parameters',), non_identifying_kwargs=('name',), + module_name=('tfp.bijectors')) +AutoSigmoid = auto_composite_tensor.auto_composite_tensor( + tfb.Sigmoid, omit_kwargs=('parameters',), non_identifying_kwargs=('name',), + module_name=('tfp.bijectors')) +AutoShift = auto_composite_tensor.auto_composite_tensor( + tfb.Shift, omit_kwargs=('parameters',), non_identifying_kwargs=('name',), + module_name=('tfp.bijectors')) + + @test_util.test_all_tf_execution_regimes @test_util.disable_test_for_backend( disable_numpy=True, disable_jax=True, @@ -457,10 +469,10 @@ class DeferredTensorSpecTest(test_util.TestCase): ('TransformedVariableBijector', _make_transformed_variable_spec( input_spec=tf.TensorSpec([4, None], tf.float32), - transform_or_spec=_make_bijector_spec(tfb.Scale, [3.])), + transform_or_spec=_make_bijector_spec(AutoScale, [3.])), _make_transformed_variable_spec( input_spec=tf.TensorSpec([4, None], tf.float32), - transform_or_spec=_make_bijector_spec(tfb.Scale, [3.]))), + transform_or_spec=_make_bijector_spec(AutoScale, [3.]))), ('TranformedVariableCallable', _make_transformed_variable_spec( input_spec=resource_variable_ops.VariableSpec(None, tf.float64), @@ -485,11 +497,11 @@ def testEquality(self, v1, v2): ('DifferentDtypes', _make_transformed_variable_spec( input_spec=tf.TensorSpec([4, 2], tf.float64), - transform_or_spec=tfb.Sigmoid(validate_args=True)._type_spec, + transform_or_spec=AutoSigmoid(validate_args=True)._type_spec, dtype=tf.float64), _make_transformed_variable_spec( input_spec=tf.TensorSpec([4, 2], tf.float32), - transform_or_spec=tfb.Sigmoid(validate_args=True)._type_spec)), + transform_or_spec=AutoSigmoid(validate_args=True)._type_spec)), ('DifferentCallables', _make_transformed_variable_spec( input_spec=tf.TensorSpec([4, 2], tf.float64), @@ -513,10 +525,10 @@ def testInequality(self, v1, v2): ('TransformedVariableBijector', _make_transformed_variable_spec( input_spec=resource_variable_ops.VariableSpec([4, 2], tf.float32), - transform_or_spec=tfb.Sigmoid(validate_args=True)._type_spec), + transform_or_spec=AutoSigmoid(validate_args=True)._type_spec), _make_transformed_variable_spec( input_spec=resource_variable_ops.VariableSpec([4, 2], tf.float32), - transform_or_spec=tfb.Sigmoid(validate_args=True)._type_spec)), + transform_or_spec=AutoSigmoid(validate_args=True)._type_spec)), ('TransformedVariableCallable', _make_transformed_variable_spec( input_spec=tf.TensorSpec([4, 2], tf.float32), @@ -535,11 +547,11 @@ def testIsCompatibleWith(self, v1, v2): ('DifferentDtypes', _make_transformed_variable_spec( input_spec=tf.TensorSpec([4, 2], tf.float32), - transform_or_spec=tfb.Sigmoid(validate_args=True)._type_spec, + transform_or_spec=AutoSigmoid(validate_args=True)._type_spec, dtype=tf.float64), _make_transformed_variable_spec( input_spec=tf.TensorSpec([4, 2], tf.float32), - transform_or_spec=tfb.Sigmoid(validate_args=True)._type_spec)), + transform_or_spec=AutoSigmoid(validate_args=True)._type_spec)), ('DifferentCallables', _make_transformed_variable_spec( input_spec=tf.TensorSpec([4, 2], tf.float64), @@ -561,15 +573,15 @@ def testIsNotCompatibleWith(self, v1, v2): _make_transformed_variable_spec( input_spec=tf.TensorSpec([4, 2], tf.float32), transform_or_spec=_make_bijector_spec( - tfb.Shift, [[2.]], use_variable=True, variable_shape=[1, 1])), + AutoShift, [[2.]], use_variable=True, variable_shape=[1, 1])), _make_transformed_variable_spec( input_spec=tf.TensorSpec([4, 2], tf.float32), transform_or_spec=_make_bijector_spec( - tfb.Shift, [[3.]], use_variable=True, variable_shape=[1, None])), + AutoShift, [[3.]], use_variable=True, variable_shape=[1, None])), _make_transformed_variable_spec( input_spec=tf.TensorSpec([4, 2], tf.float32), transform_or_spec=_make_bijector_spec( - tfb.Shift, [[3.]], use_variable=True, variable_shape=[1, None])) + AutoShift, [[3.]], use_variable=True, variable_shape=[1, None])) ), ('TransformedVariableCallable', _make_transformed_variable_spec( @@ -590,11 +602,11 @@ def testMostSpecificCompatibleType(self, v1, v2, expected): ('DifferentDtypes', _make_transformed_variable_spec( input_spec=tf.TensorSpec([], tf.float32), - transform_or_spec=tfb.Sigmoid()._type_spec, + transform_or_spec=AutoSigmoid()._type_spec, dtype=tf.float64), _make_transformed_variable_spec( input_spec=tf.TensorSpec([], tf.float32), - transform_or_spec=tfb.Sigmoid()._type_spec)), + transform_or_spec=AutoSigmoid()._type_spec)), ('DifferentCallables', _make_transformed_variable_spec( input_spec=tf.TensorSpec([4, 2], tf.float64), @@ -613,18 +625,19 @@ def testMostSpecificCompatibleTypeException(self, v1, v2): with self.assertRaises(ValueError): v2.most_specific_compatible_type(v1) - def testRepr(self): - spec = _make_transformed_variable_spec( - input_spec=tf.TensorSpec([4, 2], tf.float32), - transform_or_spec=tfb.Sigmoid(validate_args=True)._type_spec, - dtype=tf.float64) - expected = ( - "_TransformedVariableSpec(input_spec=TensorSpec(shape=(4, 2), " - "dtype=tf.float32, name=None), " - "transform_or_spec=Sigmoid_ACTTypeSpec(3, {}, {'low': None, 'high': " - "None, 'validate_args': True, 'name': 'sigmoid'}, ('parameters',), (), " - "('name',), {}), dtype=, name=None)") - self.assertEqual(repr(spec), expected) + # Disable test for TFP 0.13 release, in which bijectors are not AutoCT. + # def testRepr(self): + # spec = _make_transformed_variable_spec( + # input_spec=tf.TensorSpec([4, 2], tf.float32), + # transform_or_spec=tfb.Sigmoid(validate_args=True)._type_spec, + # dtype=tf.float64) + # expected = ( + # "_TransformedVariableSpec(input_spec=TensorSpec(shape=(4, 2), " + # "dtype=tf.float32, name=None), " + # "transform_or_spec=Sigmoid_ACTTypeSpec(3, {}, {'low': None, 'high': " + # "None, 'validate_args': True, 'name': 'sigmoid'}, ('parameters',), (), " # pylint: disable=line-too-long + # "('name',), {}), dtype=, name=None)") + # self.assertEqual(repr(spec), expected) if __name__ == '__main__':