diff --git a/tensorflow_probability/python/distributions/gaussian_process_regression_model.py b/tensorflow_probability/python/distributions/gaussian_process_regression_model.py index 9fff7be726..7b034a0f71 100644 --- a/tensorflow_probability/python/distributions/gaussian_process_regression_model.py +++ b/tensorflow_probability/python/distributions/gaussian_process_regression_model.py @@ -25,6 +25,7 @@ from tensorflow_probability.python.internal import dtype_util from tensorflow_probability.python.internal import nest_util from tensorflow_probability.python.internal import parameter_properties +from tensorflow_probability.python.internal import slicing from tensorflow_probability.python.internal import tensor_util from tensorflow_probability.python.math.psd_kernels import schur_complement from tensorflow.python.util import deprecation # pylint: disable=g-direct-tensorflow-import @@ -819,6 +820,7 @@ def _event_ndims_fn(self): shape_fn=parameter_properties.SHAPE_FN_NOT_IMPLEMENTED, ), kernel=parameter_properties.BatchedComponentProperties(), + _conditional_kernel=parameter_properties.BatchedComponentProperties(), observation_noise_variance=parameter_properties.ParameterProperties( event_ndims=0, shape_fn=lambda sample_shape: sample_shape[:-1], @@ -829,3 +831,8 @@ def _event_ndims_fn(self): shape_fn=lambda sample_shape: sample_shape[:-1], default_constraining_bijector_fn=( lambda: softplus_bijector.Softplus(low=dtype_util.eps(dtype))))) + + def __getitem__(self, slices) -> 'GaussianProcessRegressionModel': + # _conditional_mean_fn is a closure over possibly-sliced values, but will + # be rebuilt by the constructor. + return slicing.batch_slice(self, dict(_conditional_mean_fn=None), slices) diff --git a/tensorflow_probability/python/distributions/gaussian_process_regression_model_test.py b/tensorflow_probability/python/distributions/gaussian_process_regression_model_test.py index 8d9f1ca623..71daa27df6 100644 --- a/tensorflow_probability/python/distributions/gaussian_process_regression_model_test.py +++ b/tensorflow_probability/python/distributions/gaussian_process_regression_model_test.py @@ -689,6 +689,25 @@ def testPrivateArgPreventsCholeskyRecomputation(self): self.assertAllClose(d.log_prob(y_obs), d2.log_prob(y_obs)) self.assertEqual(mock_cholesky_fn.call_count, 2) + def test_batch_slice_precomputed_gprm(self): + base_kernel = exponentiated_quadratic.ExponentiatedQuadratic( + length_scale=tf.linspace(tf.ones([]), 2., 64), feature_ndims=0) + x = tf.linspace(tf.zeros([]), 1., 126) + y = tf.linspace(tf.zeros([]), 1.5, 162) + d = gprm.GaussianProcessRegressionModel.precompute_regression_model( + base_kernel, + index_points=y, + observation_index_points=x, + observations=tf.math.sin(x), + observation_noise_variance=1e-3) + self.assertEqual((64,), d.batch_shape) + self.assertEqual((162,), d.event_shape) + self.assertEqual((64, 162,), d.sample(seed=test_util.test_seed()).shape) + + self.assertEqual((), d[2].batch_shape) + self.assertEqual((162,), d[2].event_shape) + self.assertEqual((162,), d[2].sample(seed=test_util.test_seed()).shape) + class GaussianProcessRegressionModelStaticTest( _GaussianProcessRegressionModelTest):