diff --git a/tensorflow_probability/python/internal/dtype_util_test.py b/tensorflow_probability/python/internal/dtype_util_test.py index abfcc7ca11..861b99bebc 100644 --- a/tensorflow_probability/python/internal/dtype_util_test.py +++ b/tensorflow_probability/python/internal/dtype_util_test.py @@ -74,37 +74,44 @@ def testCommonStructuredDtype(self): w = structured_dtype_obj(None) # Check that structured dtypes unify correctly. - self.assertAllEqualNested( + self.assertAllAssertsNested( + self.assertEqual, dtype_util.common_dtype([w, x, y, z]), {'a': tf.float32, 'b': (None, tf.float64)}) # Check that dict `args` works and that `dtype_hint` works. dtype_hint = {'a': tf.int32, 'b': (tf.int32, None)} - self.assertAllEqualNested( + self.assertAllAssertsNested( + self.assertEqual, dtype_util.common_dtype( {'x': x, 'y': y, 'z': z}, dtype_hint=dtype_hint), {'a': tf.float32, 'b': (tf.int32, tf.float64)}) - self.assertAllEqualNested( + self.assertAllAssertsNested( + self.assertEqual, dtype_util.common_dtype([w], dtype_hint=dtype_hint), dtype_hint) # Check that non-nested dtype_hint broadcasts. - self.assertAllEqualNested( + self.assertAllAssertsNested( + self.assertEqual, dtype_util.common_dtype([y, z], dtype_hint=tf.int32), {'a': tf.int32, 'b': (tf.int32, tf.float64)}) # Check that structured `dtype_hint` behaves as expected. s = {'a': [tf.ones([3], tf.float32), 4.], 'b': (np.float64(2.), None)} - self.assertAllEqualNested( + self.assertAllAssertsNested( + self.assertEqual, dtype_util.common_dtype([x, s], dtype_hint=z.dtype), {'a': tf.float32, 'b': (tf.float64, None)}) - self.assertAllEqualNested( + self.assertAllAssertsNested( + self.assertEqual, dtype_util.common_dtype([y, s], dtype_hint=z.dtype), {'a': tf.float32, 'b': (tf.float64, tf.float64)}) t = {'a': [[1., 2., 3.]], 'b': {'c': np.float64(1.), 'd': np.float64(2.)}} - self.assertAllEqualNested( + self.assertAllAssertsNested( + self.assertEqual, dtype_util.common_dtype( [w, t], dtype_hint={'a': tf.float32, 'b': tf.float32}),