diff --git a/tensorflow_probability/python/bijectors/BUILD b/tensorflow_probability/python/bijectors/BUILD
index d07635b0fb..9eff496250 100644
--- a/tensorflow_probability/python/bijectors/BUILD
+++ b/tensorflow_probability/python/bijectors/BUILD
@@ -1521,12 +1521,16 @@ multi_substrate_py_test(
     name = "softplus_test",
     size = "small",
     srcs = ["softplus_test.py"],
+    jax_size = "medium",
     deps = [
         ":bijector_test_util",
         ":bijectors",
+        # absl/testing:parameterized dep,
         # numpy dep,
         # tensorflow dep,
         "//tensorflow_probability/python/internal:test_util",
+        "//tensorflow_probability/python/math",
+        # tensorflow/compiler/jit dep,
     ],
 )
 
diff --git a/tensorflow_probability/python/bijectors/softplus.py b/tensorflow_probability/python/bijectors/softplus.py
index 75b8e3086e..cad12b2f0b 100644
--- a/tensorflow_probability/python/bijectors/softplus.py
+++ b/tensorflow_probability/python/bijectors/softplus.py
@@ -33,6 +33,31 @@
 ]
 
 
+JAX_MODE = False  # Overwritten by rewrite script.
+
+
+# TODO(b/155501444): Remove this when tf.nn.softplus is fixed.
+if JAX_MODE:
+  _stable_grad_softplus = tf.nn.softplus
+else:
+
+  @tf.custom_gradient
+  def _stable_grad_softplus(x):
+    """A (more) numerically stable softplus than `tf.nn.softplus`."""
+    x = tf.convert_to_tensor(x)
+    if x.dtype == tf.float64:
+      cutoff = -20
+    else:
+      cutoff = -9
+
+    y = tf.where(x < cutoff, tf.math.log1p(tf.exp(x)), tf.nn.softplus(x))
+
+    def grad_fn(dy):
+      return dy * tf.where(x < cutoff, tf.exp(x), tf.nn.sigmoid(x))
+
+    return y, grad_fn
+
+
 class Softplus(bijector.Bijector):
   """Bijector which computes `Y = g(X) = Log[1 + exp(X)]`.
 
@@ -101,9 +126,9 @@ def _is_increasing(cls):
 
   def _forward(self, x):
     if self.hinge_softness is None:
-      return tf.math.softplus(x)
+      return _stable_grad_softplus(x)
     hinge_softness = tf.cast(self.hinge_softness, x.dtype)
-    return hinge_softness * tf.math.softplus(x / hinge_softness)
+    return hinge_softness * _stable_grad_softplus(x / hinge_softness)
 
   def _inverse(self, y):
     if self.hinge_softness is None:
diff --git a/tensorflow_probability/python/bijectors/softplus_test.py b/tensorflow_probability/python/bijectors/softplus_test.py
index 6d9553be46..af3cfa1919 100644
--- a/tensorflow_probability/python/bijectors/softplus_test.py
+++ b/tensorflow_probability/python/bijectors/softplus_test.py
@@ -20,9 +20,11 @@
 
 # Dependency imports
 
+from absl.testing import parameterized
 import numpy as np
 import tensorflow.compat.v2 as tf
 from tensorflow_probability.python import bijectors as tfb
+from tensorflow_probability.python import math as tfp_math
 from tensorflow_probability.python.bijectors import bijector_test_util
 from tensorflow_probability.python.internal import test_util
 
@@ -149,6 +151,25 @@ def testVariableHingeSoftness(self):
     with tf.control_dependencies([hinge_softness.assign(0.)]):
       self.evaluate(b.forward(0.5))
 
+  @parameterized.named_parameters(
+      ('32bitGraph', np.float32, False),
+      ('64bitGraph', np.float64, False),
+      ('32bitXLA', np.float32, True),
+      ('64bitXLA', np.float64, True),
+  )
+  @test_util.numpy_disable_gradient_test
+  def testLeftTailGrad(self, dtype, do_compile):
+    x = np.linspace(-50., -8., 1000).astype(dtype)
+
+    @tf.function(autograph=False, experimental_compile=do_compile)
+    def fn(x):
+      return tf.math.log(tfb.Softplus().forward(x))
+
+    _, grad = tfp_math.value_and_gradient(fn, x)
+
+    true_grad = 1 / (1 + np.exp(-x)) / np.log1p(np.exp(x))
+    self.assertAllClose(true_grad, self.evaluate(grad), atol=1e-3)
+
 
 if __name__ == '__main__':
   tf.test.main()
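Note: the sketch below is added context, not part of the patch. It runs the same check as `testLeftTailGrad` above as a standalone script, comparing the autodiff gradient of `log(softplus(x))` deep in the left tail against its closed form `sigmoid(x) / softplus(x)`; the `_stable_grad_softplus` custom gradient is what keeps the two in agreement in this region. It assumes a TFP build that includes this change.

# Standalone sketch (not part of the patch; assumes TF 2.x and a TFP build
# with this change). Analytically,
#   d/dx log(softplus(x)) = sigmoid(x) / softplus(x)
#                         = 1 / (1 + exp(-x)) / log1p(exp(x)),
# which tends to 1 as x -> -inf; the test checks autodiff matches this.
import numpy as np
import tensorflow.compat.v2 as tf
from tensorflow_probability.python import bijectors as tfb
from tensorflow_probability.python import math as tfp_math

x = np.linspace(-50., -8., 1000).astype(np.float32)
fn = lambda x_: tf.math.log(tfb.Softplus().forward(x_))
_, grad = tfp_math.value_and_gradient(fn, x)
true_grad = 1. / (1. + np.exp(-x)) / np.log1p(np.exp(x))
print(np.max(np.abs(true_grad - grad.numpy())))  # Small (within ~1e-3) with the fix.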
diff --git a/tensorflow_probability/python/distributions/joint_distribution_named.py b/tensorflow_probability/python/distributions/joint_distribution_named.py
index 67c194f518..691827b494 100644
--- a/tensorflow_probability/python/distributions/joint_distribution_named.py
+++ b/tensorflow_probability/python/distributions/joint_distribution_named.py
@@ -287,5 +287,7 @@ def _convert_to_dict(x):
   if isinstance(x, collections.OrderedDict):
     return x
   if hasattr(x, '_asdict'):
-    return x._asdict()
+    # Wrap with `OrderedDict` to indicate that namedtuples have a well-defined
+    # order (by default, they convert to just `dict` in Python 3.8+).
+    return collections.OrderedDict(x._asdict())
   return dict(x)
diff --git a/tensorflow_probability/python/layers/distribution_layer.py b/tensorflow_probability/python/layers/distribution_layer.py
index 075d887a4d..2050957c43 100644
--- a/tensorflow_probability/python/layers/distribution_layer.py
+++ b/tensorflow_probability/python/layers/distribution_layer.py
@@ -25,7 +25,7 @@
 import pickle
 
 # Dependency imports
-from cloudpickle import CloudPickler
+from cloudpickle.cloudpickle import CloudPickler
 import numpy as np
 import six
 import tensorflow.compat.v2 as tf
@@ -47,7 +47,7 @@
 from tensorflow_probability.python.distributions import variational_gaussian_process as variational_gaussian_process_lib
 from tensorflow_probability.python.internal import distribution_util as dist_util
 from tensorflow_probability.python.layers.internal import distribution_tensor_coercible as dtc
-from tensorflow_probability.python.layers.internal import tensor_tuple as tensor_tuple
+from tensorflow_probability.python.layers.internal import tensor_tuple
 from tensorflow.python.keras.utils import tf_utils as keras_tf_utils  # pylint: disable=g-direct-tensorflow-import
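The `OrderedDict` wrap in `_convert_to_dict` is easy to verify in isolation. A minimal illustration (added context, not part of the patch; the `Point` namedtuple is a hypothetical example type):

# On Python 3.8+, namedtuple._asdict() returns a plain dict, so the explicit
# OrderedDict wrap preserves the "ordered by construction" signal that
# JointDistributionNamed relies on.
import collections

Point = collections.namedtuple('Point', ['x', 'y'])  # Hypothetical example type.
d = Point(1, 2)._asdict()
print(type(d).__name__)                           # 'dict' on Python 3.8+, 'OrderedDict' before.
print(type(collections.OrderedDict(d)).__name__)  # 'OrderedDict' on all versions.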