Skip to content

Commit

Permalink
Add an auto_composite_tensor_bijector decorator for bijectors that …
Browse files Browse the repository at this point in the history
…preserves the `name` attribute through flattening/unflattening and in serialization.

PiperOrigin-RevId: 374731367
  • Loading branch information
emilyfertig authored and brianwa84 committed May 20, 2021
1 parent 35932a7 commit 1b1beb7
Show file tree
Hide file tree
Showing 57 changed files with 68 additions and 207 deletions.
48 changes: 0 additions & 48 deletions tensorflow_probability/python/bijectors/BUILD

Large diffs are not rendered by default.

4 changes: 1 addition & 3 deletions tensorflow_probability/python/bijectors/absolute_value.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,14 @@

from tensorflow_probability.python.bijectors import bijector
from tensorflow_probability.python.internal import assert_util
from tensorflow_probability.python.internal import auto_composite_tensor
from tensorflow_probability.python.internal import dtype_util

__all__ = [
'AbsoluteValue',
]


@auto_composite_tensor.auto_composite_tensor(
omit_kwargs=('name',), module_name='tfp.bijectors')
@bijector.auto_composite_tensor_bijector
class AbsoluteValue(bijector.AutoCompositeTensorBijector):
"""Computes `Y = g(X) = Abs(X)`, element-wise.
Expand Down
4 changes: 1 addition & 3 deletions tensorflow_probability/python/bijectors/ascending.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,14 @@

from tensorflow_probability.python.bijectors import bijector
from tensorflow_probability.python.internal import assert_util
from tensorflow_probability.python.internal import auto_composite_tensor


__all__ = [
'Ascending',
]


@auto_composite_tensor.auto_composite_tensor(
omit_kwargs=('name',), module_name='tfp.bijectors')
@bijector.auto_composite_tensor_bijector
class Ascending(bijector.AutoCompositeTensorBijector):
"""Maps unconstrained R^n to R^n in ascending order.
Expand Down
8 changes: 8 additions & 0 deletions tensorflow_probability/python/bijectors/bijector.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import abc
import contextlib
import functools

# Dependency imports
import numpy as np
Expand Down Expand Up @@ -1620,6 +1621,13 @@ class MyBijector(tfb.AutoCompositeTensorBijector):
pass


auto_composite_tensor_bijector = functools.partial(
auto_composite_tensor.auto_composite_tensor,
omit_kwargs=('parameters',),
non_identifying_kwargs=('name',),
module_name='tfp.bijectors')


def check_valid_ndims(ndims, validate=True):
"""Ensures that `ndims` is a non-negative integer.
Expand Down
4 changes: 2 additions & 2 deletions tensorflow_probability/python/bijectors/bijector_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
import tensorflow.compat.v1 as tf1
import tensorflow.compat.v2 as tf
from tensorflow_probability.python import bijectors as tfb
from tensorflow_probability.python.internal import auto_composite_tensor
from tensorflow_probability.python.bijectors import bijector as bijector_lib
from tensorflow_probability.python.internal import cache_util
from tensorflow_probability.python.internal import tensor_util
from tensorflow_probability.python.internal import test_util
Expand Down Expand Up @@ -768,7 +768,7 @@ def testNestedCondition(self):
mock_method.assert_called_once_with(mock.ANY, arg1=arg1, arg2=arg2)


@auto_composite_tensor.auto_composite_tensor(omit_kwargs=('name',))
@bijector_lib.auto_composite_tensor_bijector
class CompositeForwardBijector(tfb.AutoCompositeTensorBijector):

def __init__(self, scale=2., validate_args=False, name=None):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@
]


class CategoricalToDiscrete(bijector.Bijector):
@bijector.auto_composite_tensor_bijector
class CategoricalToDiscrete(bijector.AutoCompositeTensorBijector):
"""Bijector which computes `Y = g(X) = values[X]`.
Example Usage:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@

from tensorflow_probability.python.bijectors import bijector
from tensorflow_probability.python.internal import assert_util
from tensorflow_probability.python.internal import auto_composite_tensor
from tensorflow_probability.python.internal import distribution_util
from tensorflow_probability.python.internal import dtype_util
from tensorflow_probability.python.internal import prefer_static as ps
Expand All @@ -36,8 +35,7 @@
]


@auto_composite_tensor.auto_composite_tensor(
omit_kwargs=('name',), module_name='tfp.bijectors')
@bijector.auto_composite_tensor_bijector
class CholeskyOuterProduct(bijector.AutoCompositeTensorBijector):
"""Compute `g(X) = X @ X.T`; X is lower-triangular, positive-diagonal matrix.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
from tensorflow_probability.python.bijectors import bijector
from tensorflow_probability.python.bijectors.cholesky_outer_product import CholeskyOuterProduct
from tensorflow_probability.python.internal import assert_util
from tensorflow_probability.python.internal import auto_composite_tensor
from tensorflow_probability.python.internal import dtype_util
from tensorflow_probability.python.internal import prefer_static as ps

Expand All @@ -33,8 +32,7 @@
]


@auto_composite_tensor.auto_composite_tensor(
omit_kwargs=('name',), module_name='tfp.bijectors')
@bijector.auto_composite_tensor_bijector
class CholeskyToInvCholesky(bijector.AutoCompositeTensorBijector):
"""Maps the Cholesky factor of `M` to the Cholesky factor of `M^{-1}`.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@

from tensorflow_probability.python.bijectors import bijector
from tensorflow_probability.python.bijectors import fill_triangular
from tensorflow_probability.python.internal import auto_composite_tensor
from tensorflow_probability.python.internal import prefer_static as ps
from tensorflow_probability.python.internal import tensorshape_util

Expand All @@ -33,8 +32,7 @@
]


@auto_composite_tensor.auto_composite_tensor(
omit_kwargs=('name',), module_name='tfp.bijectors')
@bijector.auto_composite_tensor_bijector
class CorrelationCholesky(bijector.AutoCompositeTensorBijector):
"""Maps unconstrained reals to Cholesky-space correlation matrices.
Expand Down
4 changes: 1 addition & 3 deletions tensorflow_probability/python/bijectors/cumsum.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,14 @@

import tensorflow.compat.v2 as tf
from tensorflow_probability.python.bijectors import bijector
from tensorflow_probability.python.internal import auto_composite_tensor
from tensorflow_probability.python.internal import prefer_static

__all__ = [
'Cumsum',
]


@auto_composite_tensor.auto_composite_tensor(
omit_kwargs=('name',), module_name='tfp.bijectors')
@bijector.auto_composite_tensor_bijector
class Cumsum(bijector.AutoCompositeTensorBijector):
"""Computes the cumulative sum of a tensor along a specified axis.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import tensorflow.compat.v2 as tf

from tensorflow_probability.python.bijectors import bijector
from tensorflow_probability.python.internal import auto_composite_tensor
from tensorflow_probability.python.internal import dtype_util


Expand All @@ -30,8 +29,7 @@
]


@auto_composite_tensor.auto_composite_tensor(
omit_kwargs=('name',), module_name='tfp.bijectors')
@bijector.auto_composite_tensor_bijector
class DiscreteCosineTransform(bijector.AutoCompositeTensorBijector):
"""Compute `Y = g(X) = DCT(X)`, where DCT type is indicated by the `type` arg.
Expand Down
7 changes: 3 additions & 4 deletions tensorflow_probability/python/bijectors/exp.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import tensorflow.compat.v2 as tf

from tensorflow_probability.python.bijectors import bijector as bijector_lib
from tensorflow_probability.python.bijectors import invert
from tensorflow_probability.python.bijectors import power_transform
from tensorflow_probability.python.internal import auto_composite_tensor
Expand All @@ -31,8 +32,7 @@
]


@auto_composite_tensor.auto_composite_tensor(
omit_kwargs=('name',), module_name='tfp.bijectors')
@bijector_lib.auto_composite_tensor_bijector
class Exp(power_transform.PowerTransform):
"""Compute `Y = g(X) = exp(X)`.
Expand Down Expand Up @@ -76,8 +76,7 @@ def __init__(self,

# TODO(b/182603117): Remove `AutoCompositeTensor` when `Invert` subclasses
# `AutoCompositeTensor` and ensure `tf.saved_model` still works.
@auto_composite_tensor.auto_composite_tensor(
omit_kwargs=('name',), module_name='tfp.bijectors')
@bijector_lib.auto_composite_tensor_bijector
class Log(invert.Invert,
auto_composite_tensor.AutoCompositeTensor):
"""Compute `Y = log(X)`. This is `Invert(Exp())`."""
Expand Down
6 changes: 2 additions & 4 deletions tensorflow_probability/python/bijectors/expm1.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,7 @@
]


@auto_composite_tensor.auto_composite_tensor(
omit_kwargs=('name',), module_name='tfp.bijectors')
@bijector.auto_composite_tensor_bijector
class Expm1(bijector.AutoCompositeTensorBijector):
"""Compute `Y = g(X) = exp(X) - 1`.
Expand Down Expand Up @@ -95,8 +94,7 @@ def _forward_log_det_jacobian(self, x):

# TODO(b/182603117): Remove `AutoCompositeTensor` when `Invert` subclasses
# `AutoCompositeTensor`.
@auto_composite_tensor.auto_composite_tensor(
omit_kwargs=('name',), module_name='tfp.bijectors')
@bijector.auto_composite_tensor_bijector
class Log1p(invert.Invert, auto_composite_tensor.AutoCompositeTensor):
"""Compute `Y = log1p(X)`. This is `Invert(Expm1())`."""

Expand Down
4 changes: 1 addition & 3 deletions tensorflow_probability/python/bijectors/fill_triangular.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@

from tensorflow_probability.python.bijectors import bijector
from tensorflow_probability.python.internal import assert_util
from tensorflow_probability.python.internal import auto_composite_tensor
from tensorflow_probability.python.internal import tensorshape_util
from tensorflow_probability.python.math.linalg import fill_triangular
from tensorflow_probability.python.math.linalg import fill_triangular_inverse
Expand All @@ -36,8 +35,7 @@
]


@auto_composite_tensor.auto_composite_tensor(
omit_kwargs=('name',), module_name='tfp.bijectors')
@bijector.auto_composite_tensor_bijector
class FillTriangular(bijector.AutoCompositeTensorBijector):
"""Transforms vectors to triangular.
Expand Down
4 changes: 1 addition & 3 deletions tensorflow_probability/python/bijectors/frechet_cdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
from tensorflow_probability.python.bijectors import bijector
from tensorflow_probability.python.bijectors import softplus as softplus_bijector
from tensorflow_probability.python.internal import assert_util
from tensorflow_probability.python.internal import auto_composite_tensor
from tensorflow_probability.python.internal import dtype_util
from tensorflow_probability.python.internal import parameter_properties
from tensorflow_probability.python.internal import tensor_util
Expand All @@ -34,8 +33,7 @@
]


@auto_composite_tensor.auto_composite_tensor(
omit_kwargs=('name',), module_name='tfp.bijectors')
@bijector.auto_composite_tensor_bijector
class FrechetCDF(bijector.AutoCompositeTensorBijector):
"""The Frechet cumulative density function.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
from tensorflow_probability.python.bijectors import shift as shift_bijector
from tensorflow_probability.python.bijectors import sigmoid as sigmoid_bijector
from tensorflow_probability.python.bijectors import softplus as softplus_bijector
from tensorflow_probability.python.internal import auto_composite_tensor
from tensorflow_probability.python.internal import dtype_util
from tensorflow_probability.python.internal import parameter_properties
from tensorflow_probability.python.internal import tensor_util
Expand All @@ -35,8 +34,7 @@
]


@auto_composite_tensor.auto_composite_tensor(
omit_kwargs=('name',), module_name='tfp.bijectors')
@bijector_lib.auto_composite_tensor_bijector
class GeneralizedPareto(bijector_lib.AutoCompositeTensorBijector):
"""Bijector mapping R**n to non-negative reals.
Expand Down
4 changes: 1 addition & 3 deletions tensorflow_probability/python/bijectors/gev_cdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
from tensorflow_probability.python.bijectors import bijector
from tensorflow_probability.python.bijectors import softplus as softplus_bijector
from tensorflow_probability.python.internal import assert_util
from tensorflow_probability.python.internal import auto_composite_tensor
from tensorflow_probability.python.internal import dtype_util
from tensorflow_probability.python.internal import parameter_properties
from tensorflow_probability.python.internal import tensor_util
Expand All @@ -33,8 +32,7 @@
]


@auto_composite_tensor.auto_composite_tensor(
omit_kwargs=('name',), module_name='tfp.bijectors')
@bijector.auto_composite_tensor_bijector
class GeneralizedExtremeValueCDF(bijector.AutoCompositeTensorBijector):
"""Compute the GeneralizedExtremeValue CDF.
Expand Down
4 changes: 1 addition & 3 deletions tensorflow_probability/python/bijectors/gompertz_cdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
from tensorflow_probability.python.bijectors import bijector
from tensorflow_probability.python.bijectors import softplus as softplus_bijector
from tensorflow_probability.python.internal import assert_util
from tensorflow_probability.python.internal import auto_composite_tensor
from tensorflow_probability.python.internal import dtype_util
from tensorflow_probability.python.internal import parameter_properties
from tensorflow_probability.python.internal import tensor_util
Expand All @@ -34,8 +33,7 @@
]


@auto_composite_tensor.auto_composite_tensor(
omit_kwargs=('name',), module_name='tfp.bijectors')
@bijector.auto_composite_tensor_bijector
class GompertzCDF(bijector.AutoCompositeTensorBijector):
"""Compute `Y = g(X) = 1 - exp(-c * (exp(rate * X) - 1)`, the Gompertz CDF.
Expand Down
4 changes: 1 addition & 3 deletions tensorflow_probability/python/bijectors/gumbel_cdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
from tensorflow_probability.python.bijectors import bijector
from tensorflow_probability.python.bijectors import softplus as softplus_bijector
from tensorflow_probability.python.internal import assert_util
from tensorflow_probability.python.internal import auto_composite_tensor
from tensorflow_probability.python.internal import dtype_util
from tensorflow_probability.python.internal import parameter_properties
from tensorflow_probability.python.internal import tensor_util
Expand All @@ -34,8 +33,7 @@
]


@auto_composite_tensor.auto_composite_tensor(
omit_kwargs=('name',), module_name='tfp.bijectors')
@bijector.auto_composite_tensor_bijector
class GumbelCDF(bijector.AutoCompositeTensorBijector):
"""Compute `Y = g(X) = exp(-exp(-(X - loc) / scale))`, the Gumbel CDF.
Expand Down
4 changes: 1 addition & 3 deletions tensorflow_probability/python/bijectors/identity.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import tensorflow.compat.v2 as tf

from tensorflow_probability.python.bijectors import bijector
from tensorflow_probability.python.internal import auto_composite_tensor

__all__ = [
'Identity',
Expand All @@ -34,8 +33,7 @@ def __getitem__(self, _):
return {}


@auto_composite_tensor.auto_composite_tensor(
omit_kwargs=('name',), module_name='tfp.bijectors')
@bijector.auto_composite_tensor_bijector
class Identity(bijector.AutoCompositeTensorBijector):
"""Compute Y = g(X) = X.
Expand Down
4 changes: 1 addition & 3 deletions tensorflow_probability/python/bijectors/inline.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,14 @@
import tensorflow.compat.v2 as tf

from tensorflow_probability.python.bijectors import bijector
from tensorflow_probability.python.internal import auto_composite_tensor


__all__ = [
'Inline',
]


@auto_composite_tensor.auto_composite_tensor(
omit_kwargs=('name',), module_name='tfp.bijectors')
@bijector.auto_composite_tensor_bijector
class Inline(bijector.AutoCompositeTensorBijector):
"""Bijector constructed from custom callables.
Expand Down
3 changes: 1 addition & 2 deletions tensorflow_probability/python/bijectors/invert.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 +138,7 @@ def forward_event_ndims(self, event_ndims, **kwargs):
return self.bijector.inverse_event_ndims(event_ndims, **kwargs)


@auto_composite_tensor.auto_composite_tensor(
omit_kwargs=('name', 'parameters'), module_name='tfp.bijectors')
@bijector_lib.auto_composite_tensor_bijector
class Invert(_Invert, auto_composite_tensor.AutoCompositeTensor):

def __new__(cls, *args, **kwargs):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@

from tensorflow_probability.python.bijectors import bijector
from tensorflow_probability.python.internal import assert_util
from tensorflow_probability.python.internal import auto_composite_tensor
from tensorflow_probability.python.internal import dtype_util
from tensorflow_probability.python.internal import prefer_static as ps

Expand All @@ -32,8 +31,7 @@
]


@auto_composite_tensor.auto_composite_tensor(
omit_kwargs=('name',), module_name='tfp.bijectors')
@bijector.auto_composite_tensor_bijector
class IteratedSigmoidCentered(bijector.AutoCompositeTensorBijector):
"""Bijector which applies a Stick Breaking procedure.
Expand Down
Loading

0 comments on commit 1b1beb7

Please sign in to comment.