Skip to content

Commit

Permalink
Removes/relaxes dtype requirements from kernel in `multitask_gaussi…
Browse files Browse the repository at this point in the history
…an_process_regression_model.py`

PiperOrigin-RevId: 715188555
  • Loading branch information
Googler authored and tensorflower-gardener committed Jan 14, 2025
1 parent f1dd1c7 commit 8d655ca
Showing 1 changed file with 9 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,8 @@ def _flattened_conditional_mean_fn_helper(
observations = tf.convert_to_tensor(observations)
if observation_index_points is not None:
observation_index_points = nest_util.convert_to_nested_tensor(
observation_index_points, dtype=kernel.dtype, allow_packing=True)
observation_index_points, dtype_hint=kernel.dtype, allow_packing=True
)

k_x_obs_linop = kernel.matrix_over_all_tasks(x, observation_index_points)
if solve_on_observations is None:
Expand Down Expand Up @@ -296,12 +297,13 @@ def __init__(self,

input_dtype = dtype_util.common_dtype(
dict(
kernel=kernel,
index_points=index_points,
observation_index_points=observation_index_points,
),
dtype_hint=nest_util.broadcast_structure(
kernel.feature_ndims, tf.float32))
kernel.feature_ndims, tf.float32
),
)

# If the input dtype is non-nested float, we infer a single dtype for the
# input and the float parameters, which is also the dtype of the MTGP's
Expand Down Expand Up @@ -573,9 +575,11 @@ def precompute_regression_model(
with tf.name_scope(name) as name:
if tf.nest.is_nested(kernel.feature_ndims):
input_dtype = dtype_util.common_dtype(
[kernel, index_points, observation_index_points],
[index_points, observation_index_points],
dtype_hint=nest_util.broadcast_structure(
kernel.feature_ndims, tf.float32))
kernel.feature_ndims, tf.float32
),
)
dtype = dtype_util.common_dtype(
[observations, observation_noise_variance,
predictive_noise_variance], tf.float32)
Expand Down

0 comments on commit 8d655ca

Please sign in to comment.