Skip to content

Commit

Permalink
Avoid packing dtypes into ndarrays for comparison (which is what asse…
Browse files Browse the repository at this point in the history
…rtAllEqual will do).

PiperOrigin-RevId: 567640584
brianwa84 authored and tensorflower-gardener committed Sep 22, 2023

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature. The key has expired.
1 parent dff8111 commit 6cc612f
Showing 1 changed file with 14 additions and 7 deletions.
21 changes: 14 additions & 7 deletions tensorflow_probability/python/internal/dtype_util_test.py
Original file line number Diff line number Diff line change
@@ -74,37 +74,44 @@ def testCommonStructuredDtype(self):
w = structured_dtype_obj(None)

# Check that structured dtypes unify correctly.
self.assertAllEqualNested(
self.assertAllAssertsNested(
self.assertEqual,
dtype_util.common_dtype([w, x, y, z]),
{'a': tf.float32, 'b': (None, tf.float64)})

# Check that dict `args` works and that `dtype_hint` works.
dtype_hint = {'a': tf.int32, 'b': (tf.int32, None)}
self.assertAllEqualNested(
self.assertAllAssertsNested(
self.assertEqual,
dtype_util.common_dtype(
{'x': x, 'y': y, 'z': z}, dtype_hint=dtype_hint),
{'a': tf.float32, 'b': (tf.int32, tf.float64)})
self.assertAllEqualNested(
self.assertAllAssertsNested(
self.assertEqual,
dtype_util.common_dtype([w], dtype_hint=dtype_hint),
dtype_hint)

# Check that non-nested dtype_hint broadcasts.
self.assertAllEqualNested(
self.assertAllAssertsNested(
self.assertEqual,
dtype_util.common_dtype([y, z], dtype_hint=tf.int32),
{'a': tf.int32, 'b': (tf.int32, tf.float64)})

# Check that structured `dtype_hint` behaves as expected.
s = {'a': [tf.ones([3], tf.float32), 4.],
'b': (np.float64(2.), None)}
self.assertAllEqualNested(
self.assertAllAssertsNested(
self.assertEqual,
dtype_util.common_dtype([x, s], dtype_hint=z.dtype),
{'a': tf.float32, 'b': (tf.float64, None)})
self.assertAllEqualNested(
self.assertAllAssertsNested(
self.assertEqual,
dtype_util.common_dtype([y, s], dtype_hint=z.dtype),
{'a': tf.float32, 'b': (tf.float64, tf.float64)})

t = {'a': [[1., 2., 3.]], 'b': {'c': np.float64(1.), 'd': np.float64(2.)}}
self.assertAllEqualNested(
self.assertAllAssertsNested(
self.assertEqual,
dtype_util.common_dtype(
[w, t],
dtype_hint={'a': tf.float32, 'b': tf.float32}),

0 comments on commit 6cc612f

Please sign in to comment.