diff --git a/tensorflow_probability/python/experimental/distributions/multitask_gaussian_process_regression_model.py b/tensorflow_probability/python/experimental/distributions/multitask_gaussian_process_regression_model.py
index 1b73d34edf..8f9617b78d 100644
--- a/tensorflow_probability/python/experimental/distributions/multitask_gaussian_process_regression_model.py
+++ b/tensorflow_probability/python/experimental/distributions/multitask_gaussian_process_regression_model.py
@@ -67,7 +67,8 @@ def _flattened_conditional_mean_fn_helper(
   observations = tf.convert_to_tensor(observations)
   if observation_index_points is not None:
     observation_index_points = nest_util.convert_to_nested_tensor(
-        observation_index_points, dtype=kernel.dtype, allow_packing=True)
+        observation_index_points, dtype_hint=kernel.dtype, allow_packing=True
+    )
 
   k_x_obs_linop = kernel.matrix_over_all_tasks(x, observation_index_points)
   if solve_on_observations is None:
@@ -296,12 +297,13 @@ def __init__(self,
       input_dtype = dtype_util.common_dtype(
           dict(
-              kernel=kernel,
               index_points=index_points,
               observation_index_points=observation_index_points,
           ),
           dtype_hint=nest_util.broadcast_structure(
-              kernel.feature_ndims, tf.float32))
+              kernel.feature_ndims, tf.float32
+          ),
+      )
 
       # If the input dtype is non-nested float, we infer a single dtype for the
       # input and the float parameters, which is also the dtype of the MTGP's
@@ -573,9 +575,11 @@ def precompute_regression_model(
     with tf.name_scope(name) as name:
      if tf.nest.is_nested(kernel.feature_ndims):
        input_dtype = dtype_util.common_dtype(
-            [kernel, index_points, observation_index_points],
+            [index_points, observation_index_points],
            dtype_hint=nest_util.broadcast_structure(
-                kernel.feature_ndims, tf.float32))
+                kernel.feature_ndims, tf.float32
+            ),
+        )
      dtype = dtype_util.common_dtype(
          [observations, observation_noise_variance, predictive_noise_variance],
          tf.float32)
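
For context on the dtype= to dtype_hint= switch in the first hunk: below is a minimal sketch of the behavioral difference, written against tf.convert_to_tensor, whose dtype/dtype_hint arguments nest_util.convert_to_nested_tensor appears to apply per leaf of the index-point structure. The concrete tensors are illustrative assumptions, not values from the patch. With dtype=, an input that already carries a different float dtype is rejected; with dtype_hint=, the kernel dtype is only a fallback used when the input has no dtype of its own.

import tensorflow as tf

# Hypothetical observation index points for illustration only; any float64
# input paired with a float32 kernel dtype would show the same behavior.
x64 = tf.constant([[1.0], [2.0]], dtype=tf.float64)

# dtype_hint: an existing dtype wins, so the float64 input passes through.
assert tf.convert_to_tensor(x64, dtype_hint=tf.float32).dtype == tf.float64

# dtype_hint: a plain Python list has no dtype, so the hint applies.
assert tf.convert_to_tensor([[1.0], [2.0]], dtype_hint=tf.float32).dtype == tf.float32

# dtype: the same float64 input is rejected rather than silently recast.
try:
  tf.convert_to_tensor(x64, dtype=tf.float32)
except ValueError:
  pass

Dropping kernel= from the dtype_util.common_dtype inputs in the other two hunks reads as the same relaxation: the kernel's own dtype no longer constrains the dtype inferred for the index points.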