From dff8111eb646fbbd94818d63eb36e26a891bd02e Mon Sep 17 00:00:00 2001 From: Srinivas Vasudevan Date: Tue, 19 Sep 2023 19:14:36 -0700 Subject: [PATCH] Ensure low_rank_cholesky works well with `LinearOperator`s on non-TF backends. PiperOrigin-RevId: 566817003 --- tensorflow_probability/python/math/linalg.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tensorflow_probability/python/math/linalg.py b/tensorflow_probability/python/math/linalg.py index c7af29af78..fd421f354c 100644 --- a/tensorflow_probability/python/math/linalg.py +++ b/tensorflow_probability/python/math/linalg.py @@ -454,8 +454,9 @@ def low_rank_cholesky(matrix, max_rank, trace_atol=0, trace_rtol=0, name=None): dtype_hint=tf.float32) if not isinstance(matrix, tf.linalg.LinearOperator): matrix = tf.convert_to_tensor(matrix, name='matrix', dtype=dtype) + matrix = tf.linalg.LinearOperatorFullMatrix(matrix) - mtrace = tf.linalg.trace(matrix) + mtrace = matrix.trace() mrank = tensorshape_util.rank(matrix.shape) batch_dims = mrank - 2 @@ -485,7 +486,7 @@ def lr_cholesky_body(i, lr, residual_diag): matrix_row = tf.squeeze(matrix.row(max_j), axis=-2) else: matrix_row = tf.gather( - matrix, max_j, axis=-1, batch_dims=batch_dims)[..., 0] + matrix.to_dense(), max_j, axis=-1, batch_dims=batch_dims)[..., 0] # residual_matrix[max_j, :] = matrix_row[max_j, :] - (lr * lr^t)[max_j, :] # And (lr * lr^t)[max_j, :] = lr[max_j, :] * lr^t lr_row_maxj = tf.gather(lr, max_j, axis=-2, batch_dims=batch_dims) @@ -530,7 +531,7 @@ def lr_cholesky_body(i, lr, residual_diag): lr = tf.zeros(matrix.shape, dtype=matrix.dtype)[..., :max_rank] - mdiag = tf.linalg.diag_part(matrix) + mdiag = matrix.diag_part() i, lr, residual_diag = tf.while_loop( cond=lr_cholesky_cond, body=lr_cholesky_body,