Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Check for statically known rank
Browse files Browse the repository at this point in the history
Parametrize tests
nicolaspi committed Sep 26, 2022

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature.
1 parent 169f7f5 commit c90e961
Showing 2 changed files with 63 additions and 47 deletions.
6 changes: 6 additions & 0 deletions tensorflow_probability/python/stats/sample_stats.py
Original file line number Diff line number Diff line change
@@ -915,6 +915,12 @@ def _prepare_window_args(x, low_indices=None, high_indices=None, axis=0):
low_indices = high_indices // 2
else:
low_indices = tf.convert_to_tensor(low_indices)

indices_rank = tf.get_static_value(ps.rank(low_indices))
x_rank = tf.get_static_value(ps.rank(x))
if indices_rank is None or x_rank is None:
raise ValueError("`indices` and `x` ranks must be statically known.")

# Broadcast indices together.
high_indices = high_indices + tf.zeros_like(low_indices)
low_indices = low_indices + tf.zeros_like(high_indices)
104 changes: 57 additions & 47 deletions tensorflow_probability/python/stats/sample_stats_test.py
Original file line number Diff line number Diff line change
@@ -15,10 +15,14 @@
"""Tests for Sample Stats Ops."""

# Dependency imports
import functools
import itertools

import numpy as np
import tensorflow.compat.v1 as tf1
import tensorflow.compat.v2 as tf
from absl.testing import parameterized
from tensorflow.python.framework.errors_impl import InvalidArgumentError

from tensorflow_probability.python.internal import test_util
from tensorflow_probability.python.stats import sample_stats

@@ -721,7 +725,8 @@ def apply_func(vector, l, h):
out = np.transpose(t_out, axes=dims)
return out

def check_gaussian_windowed(self, shape, indice_shape, axis,

def check_gaussian_windowed_func(self, shape, indice_shape, axis,
window_func, np_func):
stat_shape = np.array(shape).astype(np.int32)
stat_shape[axis] = 1
@@ -753,51 +758,56 @@ def check_gaussian_windowed(self, shape, indice_shape, axis,
def _make_dynamic_shape(self, x):
return tf1.placeholder_with_default(x, shape=(None,)*len(x.shape))

def check_windowed(self, func, numpy_func):
check_fn = functools.partial(self.check_gaussian_windowed,
window_func=func, np_func=numpy_func)
check_fn((64, 4, 8), (128, 1, 1), axis=0)
check_fn((64, 4, 8), (32, 1, 1), axis=0)
check_fn((64, 4, 8), (32, 4, 1), axis=0)
check_fn((64, 4, 8), (32, 4, 8), axis=0)
check_fn((64, 4, 8), (64, 4, 8), axis=0)
check_fn((64, 4, 8), (128, 1), axis=0)
check_fn((64, 4, 8), (32,), axis=0)
check_fn((64, 4, 8), (32, 4), axis=0)

check_fn((64, 4, 8), (64, 64, 1), axis=1)
check_fn((64, 4, 8), (1, 64, 1), axis=1)
check_fn((64, 4, 8), (64, 2, 8), axis=1)
check_fn((64, 4, 8), (64, 4, 8), axis=1)
check_fn((64, 4, 8), (16,), axis=1)
check_fn((64, 4, 8), (1, 64), axis=1)

check_fn((64, 4, 8), (64, 4, 64), axis=2)
check_fn((64, 4, 8), (1, 1, 64), axis=2)
check_fn((64, 4, 8), (64, 4, 4), axis=2)
check_fn((64, 4, 8), (1, 1, 4), axis=2)
check_fn((64, 4, 8), (64, 4, 8), axis=2)
check_fn((64, 4, 8), (16,), axis=2)
check_fn((64, 4, 8), (1, 4), axis=2)
check_fn((64, 4, 8), (64, 4), axis=2)

with self.assertRaises(Exception):
# Non broadcastable shapes
check_fn((64, 4, 8), (4, 1, 4), axis=2)

with self.assertRaises(Exception):
# Non broadcastable shapes
check_fn((64, 4, 8), (2, 4), axis=2)

def test_windowed_mean(self):
self.check_windowed(func=sample_stats.windowed_mean, numpy_func=np.mean)

def test_windowed_mean_graph(self):
func = tf.function(sample_stats.windowed_mean)
self.check_windowed(func=func, numpy_func=np.mean)

def test_windowed_variance(self):
self.check_windowed(func=sample_stats.windowed_variance, numpy_func=np.var)
@parameterized.named_parameters(*[(
f"{np_func.__name__} shape={a} indices_shape={b} axis={axis}", a, b, axis,
tf_func, np_func) for a, (b, axis), (tf_func, np_func) in
itertools.product([(64, 4, 8), ],
[((128, 1, 1), 0),
((32, 1, 1), 0),
((32, 4, 1), 0),
((32, 4, 8), 0),
((64, 4, 8), 0),
((128, 1), 0),
((32,), 0),
((32, 4), 0),
((64, 64, 1), 1),
((1, 64, 1), 1),
((64, 2, 8), 1),
((64, 4, 8), 1),
((16,), 1),
((1, 64), 1),
((64, 4, 64), 2),
((1, 1, 64), 2),
((64, 4, 4), 2),
((1, 1, 4), 2),
((64, 4, 8), 2),
((16,), 2),
((1, 4), 2),
((64, 4), 2)],
[
(sample_stats.windowed_mean, np.mean),
(sample_stats.windowed_variance, np.var)
])])
def test_windowed(self, shape, indice_shape, axis, window_func, np_func):
self.check_gaussian_windowed_func(shape, indice_shape, axis, window_func,
np_func)


@parameterized.named_parameters(*[(
f"{np_func.__name__} shape={a} indices_shape={b} axis={axis}", a, b, axis,
tf_func, np_func) for a, (b, axis), (tf_func, np_func) in
itertools.product([(64, 4, 8), ],
[((4, 1, 4), 2), ((2, 4), 2)],
[(sample_stats.windowed_mean, np.mean),
(sample_stats.windowed_variance, np.var)])])
def test_non_broadcastable_shapes(self, shape, indice_shape, axis,
window_func, np_func):
with self.assertRaisesRegexp((IndexError, ValueError, InvalidArgumentError),
'^shape mismatch|Incompatible shapes'):
self.check_gaussian_windowed_func(shape, indice_shape, axis, window_func,
np_func)


@test_util.test_all_tf_execution_regimes

0 comments on commit c90e961

Please sign in to comment.