From 30c737caec412e59ee79b979713af5f138249dc7 Mon Sep 17 00:00:00 2001 From: siege Date: Thu, 22 Aug 2024 15:34:55 -0700 Subject: [PATCH] Make PowerSpherical preserve 32 bit precision. PiperOrigin-RevId: 666514877 --- .../python/distributions/power_spherical.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tensorflow_probability/python/distributions/power_spherical.py b/tensorflow_probability/python/distributions/power_spherical.py index 820566e930..acfeafa13a 100644 --- a/tensorflow_probability/python/distributions/power_spherical.py +++ b/tensorflow_probability/python/distributions/power_spherical.py @@ -217,8 +217,9 @@ def _log_normalization(self, concentration=None, mean_direction=None): concentration1 = concentration + (event_size - 1.) / 2. concentration0 = (event_size - 1.) / 2. - return ((concentration1 + concentration0) * np.log(2.) + - concentration0 * np.log(np.pi) + + np_dtype = dtype_util.as_numpy_dtype(concentration.dtype) + return ((concentration1 + concentration0) * np.log(2.).astype(np_dtype) + + concentration0 * np.log(np.pi).astype(np_dtype) + special.log_gamma_difference(concentration0, concentration1)) def _sample_control_dependencies(self, samples):