
Commit

Merge pull request #1174 from jburnim/r0.12
Prepare branch for the TFP 0.12.0rc2 release
  • Loading branch information
2 parents 7784466 + 1fd985d commit ed47dda
Showing 75 changed files with 2,646 additions and 1,114 deletions.
22 changes: 17 additions & 5 deletions STYLE_GUIDE.md
@@ -109,7 +109,7 @@ they supersede all previous conventions.
* Definitely use named args for 2nd args onward in docstrings.
1. Use names which describe semantics, not computation or mathematics, e.g.,
avoid `xp1 = x+1` or `tfd.Normal(loc=mu, scale=sigma)`.
avoid `xp1 = x + 1` or `tfd.Normal(loc=mu, scale=sigma)`.
1. Prefer inlining intermediates which are used once.
@@ -157,16 +157,16 @@ they supersede all previous conventions.
1. Prefer using the most specific TF operator. E.g,
* Use `tf.squared_difference(x,y)` over `(x-y)**2`.
* Use `tf.rsqrt` over `1./tf.sqrt(x)`.
* Use `tf.squared_difference(x, y)` over `(x - y)**2`.
* Use `tf.rsqrt` over `1. / tf.sqrt(x)`.
1. Worry about gradients! (It's often not automatic for API builders!)
1. When forced to choose between FLOPS and numerical accuracy, prefer numerical
accuracy.
1. Avoid tf.cast if possible. Eg, prefer `tf.where(cond, a, b)` over
`tf.cast(cond,dtype=a.dtype)*a + (1-tf.cast(cond,dtype=b.dtype)*b`
1. Avoid tf.cast if possible. Eg, prefer `tf.where(pred, a, b)` over
`tf.cast(cond, dtype=a.dtype) * a + (1 - tf.cast(cond, dtype=b.dtype) * b`
1. Preserve static shape hints.
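Not part of the commit, just an editor's illustration of the rules the reworded lines above spell out, assuming TF 2.x where the specific ops live under `tf.math`:

```python
import tensorflow as tf

x = tf.constant([1.0, 2.0, 3.0])
y = tf.constant([0.5, 2.0, 10.0])

# Most specific op available, which also tends to help gradients and accuracy.
sq_diff = tf.math.squared_difference(x, y)   # rather than (x - y)**2
inv_root = tf.math.rsqrt(x)                  # rather than 1. / tf.sqrt(x)

# Select without folding casted booleans into arithmetic.
blended = tf.where(x > y, x, y)              # rather than tf.cast(x > y, x.dtype) * x + ...
```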
@@ -217,3 +217,15 @@ they supersede all previous conventions.
`Tensor`s, and Numpy objects. When converting a user-provided literal to a
`Tensor` (see e.g. `Distribution._call_log_prob`), specify the dtype to
`tf.convert_to_tensor` if it is available.
1. Prefer overloaded operators on `Tensor`s (`+`, `-`, etc.) to explicit
method calls (`tf.add`, `tf.sub`, etc.). Exceptions:
* Prefer `tf.equal` to `==` when checking element-wise equality, because the
semantics of the latter are inconsistent between eager and graph
(`tf.function`) modes.
* Use `&` and `|` only if you want bitwise logic. Note that these are
equivalent to logical ops only if all inputs are `bool`s or are in
`{0, 1}`.
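Again not from the diff: a short sketch of the two exceptions listed above, written against TF 2.x.

```python
import tensorflow as tf

a = tf.constant([0, 1, 1, 0])
b = tf.constant([0, 1, 0, 0])

eq = tf.equal(a, b)        # element-wise equality, same meaning eagerly and inside tf.function
both = (a > 0) & (b > 0)   # operands are bool here, so & coincides with logical AND
```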
17 changes: 4 additions & 13 deletions spinoffs/oryx/oryx/bijectors/__init__.py
@@ -18,24 +18,15 @@
from oryx.bijectors import bijector_extensions
from tensorflow_probability.substrates import jax as tfp

__all__ = [
'bijector_extensions'
]

tfb = tfp.bijectors

_bijectors = {}
__all__ = tfb.__all__

for name in dir(tfb):
for name in __all__:
bij = getattr(tfb, name)
if inspect.isclass(bij) and issubclass(bij, tfb.Bijector):
if bij is not tfb.Bijector:
bij = bijector_extensions.make_type(bij)
_bijectors[name] = bij


for key, val in _bijectors.items():
locals()[key] = val

locals()[name] = bij

del _bijectors
del tfb
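For orientation (an editor's sketch, not oryx code): the rewritten `__init__` mirrors whatever the TFP JAX substrate lists in `__all__`, wrapping bijector classes on the way. The generic idiom looks roughly like this, with `json` standing in for `tfp.bijectors` and the wrapping step left as a comment:

```python
import inspect
import json as _src   # stand-in for tfb; any module that defines __all__ works

__all__ = list(_src.__all__)

for _name in __all__:
    _obj = getattr(_src, _name)
    if inspect.isclass(_obj):
        pass  # this is where make_type-style wrapping would happen
    globals()[_name] = _obj   # at module scope, locals() and globals() coincide

del _src
```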
4 changes: 2 additions & 2 deletions spinoffs/oryx/oryx/core/interpreters/harvest.py
@@ -335,8 +335,8 @@ def process_higher_order_primitive(self, primitive, f, tracers, params,
params = params.copy()
new_params = dict(
params,
mapped_invars=(True,) * len(tree_util.tree_leaves(plants)) +
params['mapped_invars'])
in_axes=(0,) * len(tree_util.tree_leaves(plants)) +
params['in_axes'])
else:
new_params = dict(params)
all_args, all_tree = tree_util.tree_flatten((plants, vals))
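This hunk (and the matching one in `inverse/core.py` below) tracks JAX's change of map-primitive parameters from boolean `mapped_invars` to `in_axes`, where `True` becomes axis `0` and `False` becomes `None`. The convention is the same one `jax.vmap` exposes publicly; a minimal sketch, not from the commit:

```python
import jax
import jax.numpy as jnp

xs = jnp.arange(6.).reshape(3, 2)   # mapped over its leading axis (axis 0)
w = jnp.ones(2)                     # not mapped, broadcast to every slice (None)

out = jax.vmap(lambda x, w: x @ w, in_axes=(0, None))(xs, w)
print(out.shape)   # (3,)
```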
4 changes: 2 additions & 2 deletions spinoffs/oryx/oryx/core/interpreters/inverse/core.py
@@ -373,8 +373,8 @@ def remove_slice(cell):
flat_vals, in_tree = tree_util.tree_flatten((mapped_incells, mapped_outcells))
f, aux = flat_propagate(f, in_tree)
# Assume all invars as mapped
new_mapped_invars = (True,) * len(flat_vals)
new_params = dict(params, mapped_invars=new_mapped_invars)
new_in_axes = (0,) * len(flat_vals)
new_params = dict(params, in_axes=new_in_axes)
if 'donated_invars' in params:
new_params['donated_invars'] = (False,) * len(flat_vals)
subenv_vals = prim.bind(f, *flat_vals, **new_params)
21 changes: 7 additions & 14 deletions spinoffs/oryx/oryx/core/interpreters/unzip.py
@@ -34,9 +34,9 @@
from jax import core as jax_core
from jax import custom_derivatives as cd
from jax import linear_util as lu
from jax import source_info_util
from jax import tree_util
from jax import util as jax_util
from jax._src import source_info_util
from jax.interpreters import partial_eval as pe
import numpy as onp

@@ -282,14 +282,13 @@ def handle_call_primitive(self, call_primitive, f, tracers, params, is_map):
return current_custom_rules()[call_primitive](self, f, *tracers, **params)
if call_primitive in pe.call_partial_eval_rules:
raise NotImplementedError
in_pvs, in_consts = jax_util.unzip2(t.pval for t in tracers)
in_pvals = [t.pval for t in tracers]
if is_map:
pvs = [
None if pv is None else mapped_aval(params['axis_size'], pv)
for pv in in_pvs
]
else:
pvs = in_pvs
unknown = pe.PartialVal.unknown
in_pvals = [pval if pval.is_known() or in_axis is None else
unknown(mapped_aval(params['axis_size'], in_axis, pval[0]))
for pval, in_axis in zip(in_pvals, params['in_axes'])]
pvs, in_consts = jax_util.unzip2(t.pval for t in tracers)
keys = tuple(t.is_key() for t in tracers)
new_settings = UnzipSettings(settings.tag, call_primitive in block_registry)
fun, aux = unzip_eval(f, self, keys, tuple(pvs), new_settings)
@@ -360,12 +359,6 @@ def _bound_output_tracers(self, primitive, params, jaxpr, consts, env,
for pv, const, key in safe_zip(out_pvs, out_consts, out_keys)
]
new_params = dict(params, name=name, call_jaxpr=lifted_jaxpr)
if is_map:
new_params = dict(
new_params,
mapped_invars=tuple([True] * len(const_tracers) +
[False] * len(env_tracers) +
[True] * len(in_tracers)))
if 'donated_invars' in params:
new_donated_invars = (
(False,) * len(const_tracers) + (False,) * len(env_tracers) +
19 changes: 4 additions & 15 deletions spinoffs/oryx/oryx/distributions/__init__.py
@@ -16,23 +16,12 @@
from oryx.distributions import distribution_extensions
from tensorflow_probability.substrates import jax as tfp

__all__ = [
'distribution_extensions'
]


tfd = tfp.distributions

_distributions = {}
__all__ = tfd.__all__

for name in dir(tfd):
for name in __all__:
dist = getattr(tfd, name)
_distributions[name] = dist


for key, val in _distributions.items():
locals()[key] = val

locals()[name] = dist

del _distributions
del distribution_extensions # Only needed for registration.
del tfd
2 changes: 1 addition & 1 deletion spinoffs/oryx/oryx/experimental/nn/normalization_test.py
@@ -171,7 +171,7 @@ def test_check_grads(self):
net = net_init.init(net_rng, state.Shape(in_shape))

x = random.normal(data_rng, in_shape)
jtu.check_grads(net, (x,), 2)
jtu.check_grads(net.call, (x,), 2)


def mse(x, y):
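For context (a stand-alone sketch, not the oryx test): `jax.test_util.check_grads` numerically checks derivatives of the callable it is handed against its positional arguments, which is why the test now passes `net.call` rather than the layer object itself.

```python
import jax.numpy as jnp
from jax import test_util as jtu

def f(x):
  return jnp.tanh(x).sum()

x = jnp.linspace(-1., 1., 8)
jtu.check_grads(f, (x,), order=2)   # checks first- and second-order gradients
```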
24 changes: 15 additions & 9 deletions tensorflow_probability/python/__init__.py
@@ -18,20 +18,23 @@
from __future__ import division
from __future__ import print_function

import functools

from tensorflow_probability.python.internal import all_util
from tensorflow_probability.python.internal import lazy_loader


# Ensure TensorFlow is importable and its version is sufficiently recent. This
# needs to happen before anything else, since the imports below will try to
# import tensorflow, too.
# pylint: disable=g-import-not-at-top
def _ensure_tf_install():
"""Attempt to import tensorflow, and ensure its version is sufficient.
def _validate_tf_environment(package):
"""Check TF version and (depending on package) warn about TensorFloat32.
Args:
package: Python `str` indicating which package is being imported. Used for
package-dependent warning about TensorFloat32.
Raises:
ImportError: if either tensorflow is not importable or its version is
inadequate.
inadequate.
"""
try:
import tensorflow.compat.v1 as tf
@@ -62,9 +65,10 @@ def _ensure_tf_install():
required=required_tensorflow_version,
present=tf.__version__))

if tf.config.experimental.tensor_float_32_execution_enabled():
if (package == 'mcmc' and
tf.config.experimental.tensor_float_32_execution_enabled()):
# Must import here, because symbols get pruned to __all__.
import warnings # pylint: disable=g-import-not-at-top
import warnings
warnings.warn(
'TensorFloat-32 matmul/conv are enabled for NVIDIA Ampere+ GPUs. The '
'resulting loss of precision may hinder MCMC convergence. To turn off, '
@@ -94,6 +98,8 @@ def _ensure_tf_install():
for pkg in _allowed_symbols:
globals()[pkg] = lazy_loader.LazyLoader(
pkg, globals(), 'tensorflow_probability.python.{}'.format(pkg),
on_first_access=_ensure_tf_install)
# These checks need to happen before lazy-loading, since the modules
# themselves will try to import tensorflow, too.
on_first_access=functools.partial(_validate_tf_environment, pkg))

all_util.remove_undocumented(__name__, _allowed_symbols)
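The warning added above now fires only when the `mcmc` package is imported with TensorFloat-32 enabled. Not part of the commit, but for reference, the switch it points at (available in TF 2.4+):

```python
import tensorflow as tf

print(tf.config.experimental.tensor_float_32_execution_enabled())
tf.config.experimental.enable_tensor_float_32_execution(False)   # run matmul/conv in full float32
```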
2 changes: 0 additions & 2 deletions tensorflow_probability/python/bijectors/__init__.py
@@ -40,7 +40,6 @@
from tensorflow_probability.python.bijectors.expm1 import Log1p
from tensorflow_probability.python.bijectors.ffjord import FFJORD
from tensorflow_probability.python.bijectors.fill_scale_tril import FillScaleTriL
from tensorflow_probability.python.bijectors.fill_scale_tril import ScaleTriL
from tensorflow_probability.python.bijectors.fill_triangular import FillTriangular
from tensorflow_probability.python.bijectors.frechet_cdf import FrechetCDF
from tensorflow_probability.python.bijectors.generalized_pareto import GeneralizedPareto
@@ -159,7 +158,6 @@
"ScaleMatvecLinearOperatorBlock",
"ScaleMatvecLU",
"ScaleMatvecTriL",
"ScaleTriL",
"Shift",
"ShiftedGompertzCDF",
"Sigmoid",
@@ -77,7 +77,6 @@
'ScaleMatvecTriL',
'Shift',
'ShiftedGompertzCDF',
'ScaleTriL',
'Sigmoid',
'Sinh',
'SinhArcsinh',
55 changes: 0 additions & 55 deletions tensorflow_probability/python/bijectors/fill_scale_tril.py
@@ -26,12 +26,10 @@
from tensorflow_probability.python.bijectors import transform_diagonal
from tensorflow_probability.python.internal import dtype_util
from tensorflow_probability.python.internal import tensor_util
from tensorflow.python.util import deprecation # pylint: disable=g-direct-tensorflow-import


__all__ = [
'FillScaleTriL',
'ScaleTriL',
]


@@ -127,56 +125,3 @@ def __init__(self,
validate_args=validate_args,
parameters=parameters,
name=name)


class ScaleTriL(chain.Chain):
"""DEPRECATED. Please use `tfp.bijectors.FillScaleTriL`."""

@deprecation.deprecated(
'2020-01-01',
'`ScaleTriL` has been deprecated and renamed `FillScaleTriL`; please use '
'that symbol instead.')
def __init__(self,
diag_bijector=None,
diag_shift=1e-5,
validate_args=False,
name='scale_tril'):
"""Instantiates the `ScaleTriL` bijector.
Args:
diag_bijector: `Bijector` instance, used to transform the output diagonal
to be positive.
Default value: `None` (i.e., `tfb.Softplus()`).
diag_shift: Float value broadcastable and added to all diagonal entries
after applying the `diag_bijector`. Setting a positive
value forces the output diagonal entries to be positive, but
prevents inverting the transformation for matrices with
diagonal entries less than this value.
Default value: `1e-5`.
validate_args: Python `bool` indicating whether arguments should be
checked for correctness.
Default value: `False` (i.e., arguments are not validated).
name: Python `str` name given to ops managed by this object.
Default value: `scale_tril`.
"""
parameters = dict(locals())
with tf.name_scope(name) as name:
if diag_bijector is None:
diag_bijector = softplus.Softplus(validate_args=validate_args)

if diag_shift is not None:
dtype = dtype_util.common_dtype([diag_bijector, diag_shift], tf.float32)
diag_shift = tensor_util.convert_nonref_to_tensor(diag_shift,
name='diag_shift',
dtype=dtype)
diag_bijector = chain.Chain([
shift.Shift(diag_shift),
diag_bijector
])

super(ScaleTriL, self).__init__(
[transform_diagonal.TransformDiagonal(diag_bijector=diag_bijector),
fill_triangular.FillTriangular()],
validate_args=validate_args,
parameters=parameters,
name=name)
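With the deprecated alias gone, `FillScaleTriL` is the remaining symbol. A short usage sketch (editor's example, assuming an installed `tensorflow_probability`):

```python
import tensorflow_probability as tfp

b = tfp.bijectors.FillScaleTriL(diag_shift=1e-5)
tril = b.forward([1., 2., 3., 4., 5., 6.])   # length n(n+1)/2 vector -> n x n lower triangular
# Here a 3 x 3 matrix whose diagonal is made positive via Softplus plus diag_shift.
```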
5 changes: 2 additions & 3 deletions tensorflow_probability/python/bijectors/glow.py
@@ -581,7 +581,7 @@ def bijector_fn(inputs, ignored_input):
output = this_shift(this_scale)
elif target_shape[-1] == output_shape[-1]:

output = shift.Shift(possible_output[..., c:])
output = shift.Shift(possible_output[..., :c])
else:
raise ValueError('Shape inconsistent with input. Expected shape'
'{0} or {1} but tensor was shape {2}'.format(
@@ -676,7 +676,7 @@ def bijector_fn(inputs, ignored_input):
output = this_shift(this_scale)
elif input_shape[-1] == output_shape[-1]:

output = shift.Shift(possible_output[..., c:])
output = shift.Shift(possible_output[..., :c])
else:
raise ValueError('Shape inconsistent with input. Expected shape'
'{0} or {1} but tensor was shape {2}'.format(
@@ -860,4 +860,3 @@ def __init__(self, input_shape, output_chan, kernel_shape=3):
super(GlowDefaultExitNetwork, self).__init__([
tfkl.Input(input_shape),
conv(this_nchan, kernel_shape)])

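As I read the two hunks above, in the shift-only (additive coupling) branch the network emits exactly `c` channels, so `possible_output[..., c:]` was an empty slice and `[..., :c]` picks up the intended shift; the channel-count reading is my assumption. A shape check in plain NumPy, not Glow code:

```python
import numpy as np

c = 4
possible_output = np.zeros((8, 8, c))   # assumed: network output with exactly c channels
print(possible_output[..., c:].shape)   # (8, 8, 0)  empty slice
print(possible_output[..., :c].shape)   # (8, 8, 4)  the intended channels
```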
29 changes: 29 additions & 0 deletions tensorflow_probability/python/bijectors/glow_test.py
@@ -351,5 +351,34 @@ def float64_exit(input_shape, output_chan):
self.assertAllFinite(self.evaluate(z))
self.assertAllFinite(self.evaluate(zf64))

def testBijectorFn(self):
"""Test if the bijector function works for additive coupling."""
ims = self._make_images()
def shiftfn(input_shape):
input_nchan = input_shape[-1]
return tf.keras.Sequential([
tf.keras.layers.Input(input_shape),
tf.keras.layers.Conv2D(
input_nchan, 3, padding='same')])

def shiftexitfn(input_shape, output_chan):
return tf.keras.Sequential([
tf.keras.layers.Input(input_shape),
tf.keras.layers.Conv2D(
output_chan, 3, padding='same')])

shiftonlyglow = tfb.Glow(
output_shape=self.output_shape,
num_glow_blocks=2,
num_steps_per_block=1,
coupling_bijector_fn=shiftfn,
exit_bijector_fn=shiftexitfn,
grab_after_block=[0.5, 0.5]
)
z = shiftonlyglow.inverse(ims)
self.evaluate([v.initializer for v in shiftonlyglow.variables])
self.assertAllFinite(self.evaluate(z))


if __name__ == '__main__':
tf.test.main()
3 changes: 0 additions & 3 deletions tensorflow_probability/python/bijectors/hypothesis_testlib.py
@@ -227,9 +227,6 @@ def bijector_supports():
'ScaleMatvecTriL':
BijectorSupport(Support.VECTOR_UNCONSTRAINED,
Support.VECTOR_UNCONSTRAINED),
'ScaleTriL':
BijectorSupport(Support.VECTOR_SIZE_TRIANGULAR,
Support.MATRIX_LOWER_TRIL_POSITIVE_DEFINITE),
'Shift':
BijectorSupport(Support.SCALAR_UNCONSTRAINED,
Support.SCALAR_UNCONSTRAINED),