Skip to content

Commit

Permalink
Fix a bug in log_loosum_exp.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 393441595
  • Loading branch information
gjtucker authored and tensorflower-gardener committed Aug 27, 2021
1 parent 4d6a311 commit df96f3d
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions tensorflow_probability/python/stats/leave_one_out.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,19 +219,19 @@ def _log_loosum_exp_impl(logx, axis, keepdims, compute_mean):

neg_inf = tf.constant(-np.inf, dtype=dtype)

# When not(d_ok) and is_positive_and_largest then we manually compute the
# When not(d_ok) and is_largest then we manually compute the
# log_loosum_x. (We can efficiently do this for any one point but not all,
# hence we still need the above calculation.) This is good because when
# this condition is met, we cannot use the above calculation; its -inf.
# We now compute the log-leave-out-max-sum, replicate it to every
# point and make sure to select it only when we need to.
max_logx = tf.reduce_max(logx, axis=axis, keepdims=True)
is_positive_and_largest = (logx > 0.) & tf.equal(logx, max_logx)
is_largest = tf.equal(logx, max_logx)
log_lomsum_x = tf.reduce_logsumexp(
tf.where(is_positive_and_largest, neg_inf, logx),
tf.where(is_largest, neg_inf, logx),
axis=axis,
keepdims=True)
d_not_ok_result = tf.where(is_positive_and_largest, log_lomsum_x, neg_inf)
d_not_ok_result = tf.where(is_largest, log_lomsum_x, neg_inf)

log_loosum_x = tf.where(d_ok, d_ok_result, d_not_ok_result)

Expand Down

0 comments on commit df96f3d

Please sign in to comment.