Skip to content

Commit

Permalink
Temporarily remove AutoCompositeTensor from bijectors for TFP 0.13 re…
Browse files Browse the repository at this point in the history
…lease.

PiperOrigin-RevId: 374899448
  • Loading branch information
emilyfertig authored and brianwa84 committed May 20, 2021
1 parent 1b1beb7 commit 49e7f5e
Show file tree
Hide file tree
Showing 9 changed files with 159 additions and 173 deletions.
52 changes: 29 additions & 23 deletions tensorflow_probability/python/bijectors/bijector.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

import abc
import contextlib
import functools
# import functools

# Dependency imports
import numpy as np
Expand Down Expand Up @@ -1599,33 +1599,39 @@ def _composite_tensor_shape_params(self):
return ()


class AutoCompositeTensorBijector(
Bijector, auto_composite_tensor.AutoCompositeTensor):
r"""Base for `CompositeTensor` bijectors with auto-generated `TypeSpec`s.
# Temporarily disable AutoCT for TFP 0.13 release
# class AutoCompositeTensorBijector(
# Bijector, auto_composite_tensor.AutoCompositeTensor):
# r"""Base for `CompositeTensor` bijectors with auto-generated `TypeSpec`s.

`CompositeTensor` objects are able to pass in and out of `tf.function` and
`tf.while_loop`, or serve as part of the signature of a TF saved model.
`Bijector` subclasses that follow the contract of
`tfp.experimental.auto_composite_tensor` may be defined as `CompositeTensor`s
by inheriting from `AutoCompositeTensorBijector` and applying a class
decorator as shown here:
# `CompositeTensor` objects are able to pass in and out of `tf.function` and
# `tf.while_loop`, or serve as part of the signature of a TF saved model.
# `Bijector` subclasses that follow the contract of
# `tfp.experimental.auto_composite_tensor` may be defined as `CompositeTensor`s # pylint: disable=line-too-long
# by inheriting from `AutoCompositeTensorBijector` and applying a class
# decorator as shown here:

```python
@tfp.experimental.auto_composite_tensor(
omit_kwargs=('name',), module_name='my_module')
class MyBijector(tfb.AutoCompositeTensorBijector):
# ```python
# @tfp.experimental.auto_composite_tensor(
# omit_kwargs=('name',), module_name='my_module')
# class MyBijector(tfb.AutoCompositeTensorBijector):

# The remainder of the subclass implementation is unchanged.
```
"""
pass
# # The remainder of the subclass implementation is unchanged.
# ```
# """
# pass


# auto_composite_tensor_bijector = functools.partial(
# auto_composite_tensor.auto_composite_tensor,
# omit_kwargs=('parameters',),
# non_identifying_kwargs=('name',),
# module_name='tfp.bijectors')

AutoCompositeTensorBijector = Bijector


auto_composite_tensor_bijector = functools.partial(
auto_composite_tensor.auto_composite_tensor,
omit_kwargs=('parameters',),
non_identifying_kwargs=('name',),
module_name='tfp.bijectors')
auto_composite_tensor_bijector = lambda cls, **kwargs: cls


def check_valid_ndims(ndims, validate=True):
Expand Down
64 changes: 11 additions & 53 deletions tensorflow_probability/python/bijectors/bijector_properties_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,13 @@
from tensorflow_probability.python import bijectors as tfb
from tensorflow_probability.python import experimental
from tensorflow_probability.python.bijectors import hypothesis_testlib as bijector_hps
from tensorflow_probability.python.bijectors import invert as invert_lib
from tensorflow_probability.python.internal import hypothesis_testlib as tfp_hps
from tensorflow_probability.python.internal import prefer_static
from tensorflow_probability.python.internal import samplers
from tensorflow_probability.python.internal import tensor_util
from tensorflow_probability.python.internal import tensorshape_util
from tensorflow_probability.python.internal import test_util
from tensorflow_probability.python.util.deferred_tensor import DeferredTensor
# from tensorflow_probability.python.util.deferred_tensor import DeferredTensor


TF2_FRIENDLY_BIJECTORS = (
Expand Down Expand Up @@ -186,6 +185,7 @@

COMPOSITE_TENSOR_IS_BROKEN = [
'BatchNormalization', # tf.layers arg
'Inline', # callable
'RationalQuadraticSpline', # TODO(b/185628453): Debug loss of static info.
]

Expand All @@ -204,7 +204,7 @@


def is_invert(bijector):
return isinstance(bijector, (tfb.Invert, invert_lib._Invert))
return isinstance(bijector, tfb.Invert)


def is_transform_diagonal(bijector):
Expand Down Expand Up @@ -244,8 +244,8 @@ def _constraint(param):


# TODO(b/141098791): Eliminate this.
@experimental.auto_composite_tensor
class CallableModule(tf.Module, experimental.AutoCompositeTensor):
# @experimental.auto_composite_tensor
class CallableModule(tf.Module): # , experimental.AutoCompositeTensor):
"""Convenience object for capturing variables closed over by Inline."""

def __init__(self, fn, varobj):
Expand Down Expand Up @@ -887,38 +887,16 @@ def testEquality(self, bijector_name, data):
@hp.given(hps.data())
@tfp_hps.tfp_hp_settings()
def testCompositeTensor(self, bijector_name, data):

# Test that making a composite tensor of this bijector doesn't throw any
# errors.
bijector, event_dim = self._draw_bijector(
bijector_name, data,
batch_shape=[],
validate_args=True,
bijector_name, data, batch_shape=[],
allowed_bijectors=(set(TF2_FRIENDLY_BIJECTORS) -
set(COMPOSITE_TENSOR_IS_BROKEN)))

if type(bijector) is invert_lib._Invert: # pylint: disable=unidiomatic-typecheck
if isinstance(bijector.bijector, tf.__internal__.CompositeTensor):
raise TypeError('`_Invert` should wrap only non-`CompositeTensor` '
'bijectors.')
self.skipTest('`_Invert` bijectors are not `CompositeTensor`s.')

# TODO(b/182603117): Remove "if" condition and s/composite_bij/bijector
# when AutoCT is enabled for meta-bijectors and LinearOperator.
if type(bijector).__name__ in AUTO_COMPOSITE_TENSOR_IS_BROKEN:
composite_bij = experimental.as_composite(bijector)
else:
composite_bij = bijector

if not tf.executing_eagerly():
composite_bij = tf.nest.map_structure(
lambda x: (tf.convert_to_tensor(x) # pylint: disable=g-long-lambda
if isinstance(x, DeferredTensor) else x),
composite_bij,
expand_composites=True)

self.assertIsInstance(composite_bij, tf.__internal__.CompositeTensor)
composite_bij = experimental.as_composite(bijector)
flat = tf.nest.flatten(composite_bij, expand_composites=True)
unflat = tf.nest.pack_sequence_as(
composite_bij, flat, expand_composites=True)
unflat = tf.nest.pack_sequence_as(composite_bij, flat,
expand_composites=True)

# Compare forward maps before and after compositing.
n = 3
Expand All @@ -933,26 +911,6 @@ def testCompositeTensor(self, bijector_name, data):
after_xs = unflat.inverse(ys)
self.assertAllClose(*self.evaluate((before_xs, after_xs)))

# Input to tf.function
self.assertAllClose(
before_ys,
tf.function(lambda b: b.forward(xs))(composite_bij),
rtol=COMPOSITE_TENSOR_RTOL[bijector_name],
atol=COMPOSITE_TENSOR_ATOL[bijector_name])

# Forward mapping: Check differentiation through forward mapping with
# respect to the input and parameter variables. Also check that any
# variables are not referenced overmuch.
xs = self._draw_domain_tensor(bijector, data, event_dim)
wrt_vars = [xs] + [v for v in composite_bij.trainable_variables
if v.dtype.is_floating]
with tf.GradientTape() as tape:
tape.watch(wrt_vars)
# TODO(b/73073515): Fix graph mode gradients with bijector caching.
ys = bijector.forward(xs + 0)
grads = tape.gradient(ys, wrt_vars)
assert_no_none_grad(bijector, 'forward', wrt_vars, grads)


def ensure_nonzero(x):
return tf.where(x < 1e-6, tf.constant(1e-3, x.dtype), x)
Expand Down
29 changes: 15 additions & 14 deletions tensorflow_probability/python/bijectors/bijector_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -790,25 +790,26 @@ def _forward_log_det_jacobian(self, _):
return tf.math.log(self._scale)


@test_util.test_all_tf_execution_regimes
class AutoCompositeTensorBijectorTest(test_util.TestCase):
# Test disabled temporarily for TFP 0.13 release.
# @test_util.test_all_tf_execution_regimes
# class AutoCompositeTensorBijectorTest(test_util.TestCase):

def test_disable_ct_bijector(self):
# def test_disable_ct_bijector(self):

ct_bijector = CompositeForwardBijector()
self.assertIsInstance(ct_bijector, tf.__internal__.CompositeTensor)
# ct_bijector = CompositeForwardBijector()
# self.assertIsInstance(ct_bijector, tf.__internal__.CompositeTensor)

non_ct_bijector = ForwardOnlyBijector()
self.assertNotIsInstance(non_ct_bijector, tf.__internal__.CompositeTensor)
# non_ct_bijector = ForwardOnlyBijector()
# self.assertNotIsInstance(non_ct_bijector, tf.__internal__.CompositeTensor)

flat = tf.nest.flatten(ct_bijector, expand_composites=True)
unflat = tf.nest.pack_sequence_as(
ct_bijector, flat, expand_composites=True)
# flat = tf.nest.flatten(ct_bijector, expand_composites=True)
# unflat = tf.nest.pack_sequence_as(
# ct_bijector, flat, expand_composites=True)

x = tf.constant([2., 3.])
self.assertAllClose(
non_ct_bijector.forward(x),
tf.function(lambda b: b.forward(x))(unflat))
# x = tf.constant([2., 3.])
# self.assertAllClose(
# non_ct_bijector.forward(x),
# tf.function(lambda b: b.forward(x))(unflat))


if __name__ == '__main__':
Expand Down
4 changes: 1 addition & 3 deletions tensorflow_probability/python/bijectors/exp.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
from tensorflow_probability.python.bijectors import bijector as bijector_lib
from tensorflow_probability.python.bijectors import invert
from tensorflow_probability.python.bijectors import power_transform
from tensorflow_probability.python.internal import auto_composite_tensor


__all__ = [
Expand Down Expand Up @@ -77,8 +76,7 @@ def __init__(self,
# TODO(b/182603117): Remove `AutoCompositeTensor` when `Invert` subclasses
# `AutoCompositeTensor` and ensure `tf.saved_model` still works.
@bijector_lib.auto_composite_tensor_bijector
class Log(invert.Invert,
auto_composite_tensor.AutoCompositeTensor):
class Log(invert.Invert):
"""Compute `Y = log(X)`. This is `Invert(Exp())`."""

def __init__(self, validate_args=False, name='log'):
Expand Down
3 changes: 1 addition & 2 deletions tensorflow_probability/python/bijectors/expm1.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import tensorflow.compat.v2 as tf
from tensorflow_probability.python.bijectors import bijector
from tensorflow_probability.python.bijectors import invert
from tensorflow_probability.python.internal import auto_composite_tensor


__all__ = [
Expand Down Expand Up @@ -95,7 +94,7 @@ def _forward_log_det_jacobian(self, x):
# TODO(b/182603117): Remove `AutoCompositeTensor` when `Invert` subclasses
# `AutoCompositeTensor`.
@bijector.auto_composite_tensor_bijector
class Log1p(invert.Invert, auto_composite_tensor.AutoCompositeTensor):
class Log1p(invert.Invert):
"""Compute `Y = log1p(X)`. This is `Invert(Expm1())`."""

def __init__(self, validate_args=False, name='log1p'):
Expand Down
46 changes: 24 additions & 22 deletions tensorflow_probability/python/bijectors/invert.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,14 @@
import tensorflow.compat.v2 as tf

from tensorflow_probability.python.bijectors import bijector as bijector_lib
from tensorflow_probability.python.internal import auto_composite_tensor
# from tensorflow_probability.python.internal import auto_composite_tensor

__all__ = [
'Invert',
]


class _Invert(bijector_lib.Bijector):
class Invert(bijector_lib.Bijector):
"""Bijector which inverts another Bijector.
Example Use: [ExpGammaDistribution (see Background & Context)](
Expand Down Expand Up @@ -73,7 +73,7 @@ def __init__(self, bijector, validate_args=False, parameters=None, name=None):
name = name or '_'.join(['invert', bijector.name])
with tf.name_scope(name) as name:
self._bijector = bijector
super(_Invert, self).__init__(
super(Invert, self).__init__(
forward_min_event_ndims=bijector.inverse_min_event_ndims,
inverse_min_event_ndims=bijector.forward_min_event_ndims,
dtype=bijector.dtype,
Expand Down Expand Up @@ -138,26 +138,28 @@ def forward_event_ndims(self, event_ndims, **kwargs):
return self.bijector.inverse_event_ndims(event_ndims, **kwargs)


@bijector_lib.auto_composite_tensor_bijector
class Invert(_Invert, auto_composite_tensor.AutoCompositeTensor):
# Temporarily removing AutoCompositeTensor for TFP 0.13 release.
# pylint: disable=line-too-long
# @bijector_lib.auto_composite_tensor_bijector
# class Invert(_Invert, auto_composite_tensor.AutoCompositeTensor):

def __new__(cls, *args, **kwargs):
"""Returns an `_Invert` instance if `bijector` is not a `CompositeTensor."""
if cls is Invert:
if args:
bijector = args[0]
elif 'bijector' in kwargs:
bijector = kwargs['bijector']
else:
raise TypeError('`Invert.__new__()` is missing argument `bijector`.')
# def __new__(cls, *args, **kwargs):
# """Returns an `_Invert` instance if `bijector` is not a `CompositeTensor."""
# if cls is Invert:
# if args:
# bijector = args[0]
# elif 'bijector' in kwargs:
# bijector = kwargs['bijector']
# else:
# raise TypeError('`Invert.__new__()` is missing argument `bijector`.')

if not isinstance(bijector, tf.__internal__.CompositeTensor):
return _Invert(*args, **kwargs)
return super(Invert, cls).__new__(cls)
# if not isinstance(bijector, tf.__internal__.CompositeTensor):
# return _Invert(*args, **kwargs)
# return super(Invert, cls).__new__(cls)


Invert.__doc__ = _Invert.__doc__ + '/n' + (
'When an `Invert` bijector is constructed, if its `bijector` arg is not a '
'`CompositeTensor` instance, an `_Invert` instance is returned instead. '
'Bijectors subclasses that inherit from `Invert` will also inherit from '
' `CompositeTensor`.')
# Invert.__doc__ = _Invert.__doc__ + '/n' + (
# 'When an `Invert` bijector is constructed, if its `bijector` arg is not a '
# '`CompositeTensor` instance, an `_Invert` instance is returned instead. '
# 'Bijectors subclasses that inherit from `Invert` will also inherit from '
# ' `CompositeTensor`.')
Original file line number Diff line number Diff line change
Expand Up @@ -454,7 +454,11 @@ def test_basics_mixture_same_family(self):
self.evaluate(unflat.log_prob(.5))

def test_already_composite_tensor(self):
b = tfb.Scale(2.)
AutoScale = tfp.experimental.auto_composite_tensor( # pylint: disable=invalid-name
tfb.Scale, omit_kwargs=('parameters',),
non_identifying_kwargs=('name',),
module_name=('tfp.bijectors'))
b = AutoScale(2.)
b2 = tfp.experimental.as_composite(b)
self.assertIsInstance(b, tf.__internal__.CompositeTensor)
self.assertIs(b, b2)
Expand Down
Loading

0 comments on commit 49e7f5e

Please sign in to comment.