Skip to content

Commit

Permalink
Fix batch slicing of precomputed GPRM.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 548144627
  • Loading branch information
brianwa84 authored and jburnim committed Jul 28, 2023
1 parent 6f62a06 commit 3f42739
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from tensorflow_probability.python.internal import dtype_util
from tensorflow_probability.python.internal import nest_util
from tensorflow_probability.python.internal import parameter_properties
from tensorflow_probability.python.internal import slicing
from tensorflow_probability.python.internal import tensor_util
from tensorflow_probability.python.math.psd_kernels import schur_complement
from tensorflow.python.util import deprecation # pylint: disable=g-direct-tensorflow-import
Expand Down Expand Up @@ -819,6 +820,7 @@ def _event_ndims_fn(self):
shape_fn=parameter_properties.SHAPE_FN_NOT_IMPLEMENTED,
),
kernel=parameter_properties.BatchedComponentProperties(),
_conditional_kernel=parameter_properties.BatchedComponentProperties(),
observation_noise_variance=parameter_properties.ParameterProperties(
event_ndims=0,
shape_fn=lambda sample_shape: sample_shape[:-1],
Expand All @@ -829,3 +831,8 @@ def _event_ndims_fn(self):
shape_fn=lambda sample_shape: sample_shape[:-1],
default_constraining_bijector_fn=(
lambda: softplus_bijector.Softplus(low=dtype_util.eps(dtype)))))

def __getitem__(self, slices) -> 'GaussianProcessRegressionModel':
# _conditional_mean_fn is a closure over possibly-sliced values, but will
# be rebuilt by the constructor.
return slicing.batch_slice(self, dict(_conditional_mean_fn=None), slices)
Original file line number Diff line number Diff line change
Expand Up @@ -689,6 +689,25 @@ def testPrivateArgPreventsCholeskyRecomputation(self):
self.assertAllClose(d.log_prob(y_obs), d2.log_prob(y_obs))
self.assertEqual(mock_cholesky_fn.call_count, 2)

def test_batch_slice_precomputed_gprm(self):
base_kernel = exponentiated_quadratic.ExponentiatedQuadratic(
length_scale=tf.linspace(tf.ones([]), 2., 64), feature_ndims=0)
x = tf.linspace(tf.zeros([]), 1., 126)
y = tf.linspace(tf.zeros([]), 1.5, 162)
d = gprm.GaussianProcessRegressionModel.precompute_regression_model(
base_kernel,
index_points=y,
observation_index_points=x,
observations=tf.math.sin(x),
observation_noise_variance=1e-3)
self.assertEqual((64,), d.batch_shape)
self.assertEqual((162,), d.event_shape)
self.assertEqual((64, 162,), d.sample(seed=test_util.test_seed()).shape)

self.assertEqual((), d[2].batch_shape)
self.assertEqual((162,), d[2].event_shape)
self.assertEqual((162,), d[2].sample(seed=test_util.test_seed()).shape)


class GaussianProcessRegressionModelStaticTest(
_GaussianProcessRegressionModelTest):
Expand Down

0 comments on commit 3f42739

Please sign in to comment.