NUTS sampler cleanup:
- Add references for the current implementation.
- Expose the energy diagnostic for each sample. This is needed for computing
  the marginal energy plot and the Bayesian fraction of missing information
  (BFMI); a sketch of the BFMI computation follows below.
- Implement the generalized U-turn criterion (currently behind a static flag).
  This lays the groundwork for extending NUTS to non-Euclidean metrics.
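For context, a minimal sketch of how BFMI could be computed from the energies this change exposes (NumPy only; `energies` is a hypothetical array holding the traced `energy` field, and the ~0.3 warning threshold is a common heuristic, not something defined in this change):

```python
import numpy as np

def bfmi(energies):
  """Estimate the Bayesian fraction of missing information.

  BFMI ~= mean squared energy transition / energy variance; values well
  below ~0.3 suggest momentum resampling explores the energy marginal
  poorly.
  """
  energies = np.asarray(energies)
  return np.sum(np.square(np.diff(energies))) / (
      energies.size * np.var(energies))

# `energies` would come from tracing the new `energy` kernel-result field.
energies = np.random.normal(loc=-12.0, scale=1.5, size=1000)
print(bfmi(energies))
```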

PiperOrigin-RevId: 264675589
junpenglao authored and tensorflower-gardener committed Aug 21, 2019
1 parent 1509c0c commit 7757171
Showing 2 changed files with 79 additions and 31 deletions.
4 changes: 3 additions & 1 deletion discussion/technical_note_on_unrolled_nuts.md
@@ -28,7 +28,7 @@
The NUTS recursion is a pre-order tree traversal. In the original algorithm 3,
this traversal terminates when the trajectory makes a U-turn or there is a
divergent sample (during leapfrog integration). We observed that:
1. Typical implementations of NUTS cap the recursion
(by default we cap `max_tree_depth` to 6).
(by default we cap `max_tree_depth` to 10).
2. The NUTS computation is dominated by gradient evals in the leapfrog
calculation.
3. The remaining computation is in U-turn checking and slice sampling, which notably
@@ -302,6 +302,8 @@
step 1(0): x0 ==> U([x_,x0], [0,1]) ==> x1 --> MH([x',x1], 1/1) --> x'
```

which means that for the purpose of slice sampling, it could be memory-less.
This also holds for multinomial sampling, since we accumulate the weight in
the same way.
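As an illustration, a toy sketch of the two accumulation schemes (the helper below is hypothetical, not code from this repo): slice sampling counts acceptable states, multinomial sampling accumulates log-weights, and both reduce each subtree to a single scalar, which is what makes the traversal memory-less.

```python
import numpy as np

def combine_subtree_weights(w_left, w_right, multinomial=True):
  # Both schemes reduce two subtree weights to one scalar, so no
  # per-state memory is needed while the tree is traversed.
  if multinomial:
    # Multinomial sampling: weights are log-probabilities of the
    # subtrees; combine with log-sum-exp.
    return np.logaddexp(w_left, w_right)
  # Slice sampling: weights are counts of states above the slice;
  # combine by addition.
  return w_left + w_right
```

In both cases the candidate from the right subtree is kept with a probability proportional to its share of the combined weight, roughly mirroring the n'/n acceptance ratio of the original algorithm.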

## FAQ

106 changes: 76 additions & 30 deletions tensorflow_probability/python/mcmc/nuts.py
@@ -15,7 +15,8 @@
"""No U-Turn Sampler.
The implementation closely follows [1; Algorithm 3], with Multinomial sampling
on the tree instead of slice sampling.
on the tree (instead of slice sampling) and a generalized No-U-Turn termination
criterion [2; Appendix A].
Achieves batch execution across chains by precomputing the recursive tree
doubling data access patterns and then executing this "unrolled" data pattern via
@@ -27,6 +28,9 @@
Setting Path Lengths in Hamiltonian Monte Carlo.
In _Journal of Machine Learning Research_, 15(1):1593-1623, 2014.
http://jmlr.org/papers/volume15/hoffman14a/hoffman14a.pdf
[2]: Michael Betancourt. A Conceptual Introduction to Hamiltonian Monte Carlo.
_arXiv preprint arXiv:1701.02434_, 2018. https://arxiv.org/abs/1701.02434
"""

from __future__ import absolute_import
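For orientation, a hedged usage sketch of this kernel (assuming it is exported as `tfp.mcmc.NoUTurnSampler` in your build; the sampler may live under an experimental namespace, and the `energy` field traced here is the one added by this change):

```python
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

# Hypothetical usage; adjust the import path to wherever NoUTurnSampler is
# exposed in your build.
kernel = tfp.mcmc.NoUTurnSampler(
    target_log_prob_fn=tfd.Normal(0., 1.).log_prob,
    step_size=0.3)
samples, energies = tfp.mcmc.sample_chain(
    num_results=500,
    current_state=tf.zeros([]),
    kernel=kernel,
    # The per-sample `energy` field added by this change can be traced
    # and fed into a BFMI diagnostic.
    trace_fn=lambda _, kr: kr.energy)
```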
@@ -51,9 +55,13 @@

TREE_COUNT_DTYPE = tf.int32 # Default: tf.int32

# Whether to use slice sampling (original NUTS implementation) or multinomial
# sampling from the tree trajectory.
# Whether to use slice sampling (original NUTS implementation in [1]) or
# multinomial sampling (implementation in [2]) from the tree trajectory.
MULTINOMIAL_SAMPLE = True # Default: True

# Whether to use the U-turn criterion in [1] or the generalized U-turn
# criterion in [2] to check the tree trajectory.
GENERALIZED_UTURN = True # Default: True
##############################################################
### END STATIC CONFIGURATION #################################
##############################################################
@@ -72,6 +80,7 @@
'is_accepted',
'reach_max_depth',
'has_divergence',
'energy',
])

MomentumStateSwap = collections.namedtuple('MomentumStateSwap', [
@@ -98,6 +107,7 @@
'state',
'target',
'target_grad_parts',
'energy',
'weight',
])

@@ -298,6 +308,7 @@ def _copy(v):
state=current_state,
target=current_target_log_prob,
target_grad_parts=previous_kernel_results.grads_target_log_prob,
energy=init_energy,
weight=init_weight)

initial_step_metastate = TreeDoublingMetaState(
@@ -363,6 +374,7 @@ def _copy(v):
is_accepted=new_step_metastate.is_accepted,
reach_max_depth=new_step_metastate.continue_tree,
has_divergence=~new_step_metastate.not_divergence,
energy=new_step_metastate.candidate_state.energy
)

result_state = new_step_metastate.candidate_state.state
@@ -426,6 +438,7 @@ def _init(shape_and_dtype):
has_divergence=tf.zeros_like(current_target_log_prob,
dtype=tf.bool,
name='has_divergence'),
energy=compute_hamiltonian(current_target_log_prob, dummy_momentum)
)

def _start_trajectory_batched(self, state, target_log_prob):
@@ -495,7 +508,8 @@ def loop_tree_doubling(self, step_size, momentum_state_memory,
final_not_divergence,
continue_tree_final,
energy_diff_tree_sum,
leapfrogs_taken,
momentum_tree_cumsum,
leapfrogs_taken
] = self._build_sub_tree(
directions_expanded,
integrator,
@@ -550,6 +564,11 @@ def loop_tree_doubling(self, step_size, momentum_state_memory,
for grad0, grad1 in zip(candidate_tree_state.target_grad_parts,
last_candidate_state.target_grad_parts)
],
energy=tf.where(
_rightmost_expand_to_rank(
choose_new_state,
prefer_static.rank(candidate_tree_state.target)),
candidate_tree_state.energy, last_candidate_state.energy),
weight=weight_sum)

# Update left right information of the trajectory, and check trajectory
@@ -571,10 +590,15 @@ def loop_tree_doubling(self, step_size, momentum_state_memory,
for l, r in zip(tf.nest.flatten(tree_final_states),
tf.nest.flatten(tree_otherend_states))
])

if GENERALIZED_UTURN:
state_diff = momentum_tree_cumsum
else:
state_diff = [s[1] - s[0] for s in new_step_state.state]

no_u_turns_trajectory = has_not_u_turn(
[s[0] for s in new_step_state.state],
state_diff,
[m[0] for m in new_step_state.momentum],
[s[1] for s in new_step_state.state],
[m[1] for m in new_step_state.momentum],
log_prob_rank=len(batch_shape))
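
Why the running momentum sum can stand in for the end-point state difference: with an identity metric, each leapfrog step moves the state by `eps` times a momentum, so `x_right - x_left` tracks `eps * sum(momenta)` up to boundary half-steps, and the momentum sum stays well-defined when no flat "position difference" exists. A self-contained NumPy sketch (toy target, all names hypothetical, not code from this module):

```python
import numpy as np

# Toy leapfrog on a standard Gaussian target (grad U(x) = x), identity
# metric, to show the identity the generalized criterion relies on.
def leapfrog(x, p, eps, n_steps):
  xs, ps = [x.copy()], [p.copy()]
  for _ in range(n_steps):
    p = p - 0.5 * eps * x   # half step on momentum
    x = x + eps * p         # full step on position
    p = p - 0.5 * eps * x   # half step on momentum
    xs.append(x.copy())
    ps.append(p.copy())
  return np.array(xs), np.array(ps)

rng = np.random.default_rng(0)
x0, p0, eps = rng.normal(size=3), rng.normal(size=3), 0.1
xs, ps = leapfrog(x0, p0, eps, 8)

# Original criterion: dot the end momenta with the end-point state
# difference. Generalized criterion: use the running momentum sum, which
# tracks (state difference) / eps up to boundary half-steps.
state_diff = xs[-1] - xs[0]
rho = ps.sum(axis=0)
print(np.dot(ps[0], state_diff) >= 0, np.dot(ps[0], rho) >= 0)
print(np.dot(ps[-1], state_diff) >= 0, np.dot(ps[-1], rho) >= 0)
```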

@@ -610,36 +634,43 @@ def _build_sub_tree(self,
dtype=current_step_meta_info.init_energy.dtype))
else:
init_weight = tf.zeros(batch_shape, dtype=TREE_COUNT_DTYPE)

init_momentum_cumsum = [tf.zeros_like(x) for x in initial_state.momentum]
initial_state_candidate = TreeDoublingStateCandidate(
state=initial_state.state,
target=initial_state.target,
target_grad_parts=initial_state.target_grad_parts,
energy=initial_state.target,
weight=init_weight)
energy_diff_sum = tf.zeros_like(current_step_meta_info.init_energy,
name='energy_diff_sum')
[
_,
energy_diff_tree_sum,
momentum_tree_cumsum,
leapfrogs_taken,
final_state,
candidate_tree_state,
final_continue_tree,
final_not_divergence,
momentum_state_memory,
] = tf.while_loop(
cond=lambda iter_, energy_diff_sum, leapfrogs_taken, state, state_c, # pylint: disable=g-long-lambda
continue_tree, not_divergence, momentum_state_memory: (
cond=lambda iter_, energy_diff_sum, init_momentum_cumsum, # pylint: disable=g-long-lambda
leapfrogs_taken, state, state_c, continue_tree,
not_divergence, momentum_state_memory: (
(iter_ < nsteps) & tf.reduce_any(continue_tree)),
body=lambda iter_, energy_diff_sum, leapfrogs_taken, state, state_c, # pylint: disable=g-long-lambda
continue_tree, not_divergence, momentum_state_memory: (
body=lambda iter_, energy_diff_sum, init_momentum_cumsum, # pylint: disable=g-long-lambda
leapfrogs_taken, state, state_c, continue_tree,
not_divergence, momentum_state_memory: (
self._loop_build_sub_tree(
directions, integrator, current_step_meta_info,
iter_, energy_diff_sum, leapfrogs_taken, state,
state_c, continue_tree, not_divergence,
momentum_state_memory)),
iter_, energy_diff_sum, init_momentum_cumsum,
leapfrogs_taken, state, state_c, continue_tree,
not_divergence, momentum_state_memory)),
loop_vars=(
tf.zeros([], dtype=tf.int32, name='iter'),
energy_diff_sum,
init_momentum_cumsum,
tf.zeros(batch_shape, dtype=TREE_COUNT_DTYPE),
initial_state,
initial_state_candidate,
@@ -656,6 +687,7 @@ def _build_sub_tree(self,
final_not_divergence,
final_continue_tree,
energy_diff_tree_sum,
momentum_tree_cumsum,
leapfrogs_taken,
)

@@ -665,6 +697,7 @@ def _loop_build_sub_tree(self,
current_step_meta_info,
iter_,
energy_diff_sum_previous,
momentum_cumsum_previous,
leapfrogs_taken,
prev_tree_state,
candidate_tree_state,
@@ -689,6 +722,8 @@ def _loop_build_sub_tree(self,
state=next_state_parts,
target=next_target,
target_grad_parts=next_target_grad_parts)
momentum_cumsum = [p0 + p1 for p0, p1 in zip(momentum_cumsum_previous,
next_momentum_parts)]
# If the tree has not yet terminated previously, we count this leapfrog.
leapfrogs_taken = tf.where(
continue_tree_previous, leapfrogs_taken + 1, leapfrogs_taken)
@@ -704,6 +739,11 @@ def _loop_build_sub_tree(self,
write_instruction.gather([iter_ // 2]),
self.max_tree_depth)

if GENERALIZED_UTURN:
state_to_write = momentum_cumsum
else:
state_to_write = next_state_parts

momentum_state_memory = MomentumStateSwap(
momentum_swap=[
tf.tensor_scatter_nd_update(old, [write_index], [new])
@@ -713,7 +753,7 @@ def _loop_build_sub_tree(self,
state_swap=[
tf.tensor_scatter_nd_update(old, [write_index], [new])
for old, new in zip(momentum_state_memory.state_swap,
next_state_parts)
state_to_write)
])
batch_shape = prefer_static.shape(next_target)
has_not_u_turn_at_even_step = tf.ones(batch_shape, dtype=tf.bool)
@@ -724,15 +764,15 @@ def _loop_build_sub_tree(self,
lambda: has_not_u_turn_at_even_step,
lambda: has_not_u_turn_at_odd_step( # pylint: disable=g-long-lambda
read_index, directions, momentum_state_memory,
next_momentum_parts, next_state_parts,
next_momentum_parts, state_to_write,
has_not_u_turn_at_even_step,
log_prob_rank=prefer_static.rank(next_target)))

energy = compute_hamiltonian(next_target, next_momentum_parts)
energy = tf.where(tf.math.is_nan(energy),
tf.constant(-np.inf, dtype=energy.dtype),
energy)
energy_diff = energy - init_energy
current_energy = tf.where(tf.math.is_nan(energy),
tf.constant(-np.inf, dtype=energy.dtype),
energy)
energy_diff = current_energy - init_energy

if MULTINOMIAL_SAMPLE:
not_divergent = -energy_diff < self.max_energy_diff
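
A NumPy sketch of the energy convention this guard implies (an inference from the surrounding code, not the module's actual `compute_hamiltonian`): the value behaves like a log-probability, so `NaN` maps to `-inf` and a divergence is a drop of more than `max_energy_diff` below the initial energy.

```python
import numpy as np

def compute_hamiltonian_sketch(target_log_prob, momentum_parts):
  # Assumed convention: log-prob minus kinetic energy, so larger is
  # better; a divergence is then a *drop* of more than `max_energy_diff`
  # below the initial energy, matching the NaN -> -inf guard above.
  kinetic = sum(0.5 * np.sum(np.square(p)) for p in momentum_parts)
  return target_log_prob - kinetic
```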
@@ -777,6 +817,10 @@ def _loop_build_sub_tree(self,
for grad0, grad1 in zip(next_target_grad_parts,
candidate_tree_state.target_grad_parts)
],
energy=tf.where(
_rightmost_expand_to_rank(is_sample_accepted,
prefer_static.rank(next_target)),
current_energy, init_energy),
weight=weight_sum)

continue_tree = not_divergent & continue_tree_previous
@@ -795,6 +839,7 @@ def _loop_build_sub_tree(self,
return (
iter_ + 1,
energy_diff_sum,
momentum_cumsum,
leapfrogs_taken,
next_tree_state,
next_candidate_tree_state,
Expand All @@ -819,12 +864,14 @@ def _get_left_state_and_check_u_turn(left_current_index, no_u_turns_last):
tf.gather(x, left_current_index, axis=0)
for x in momentum_state_memory.state_swap
]
state_diff = [s1 - s2 for s1, s2 in zip(state_right, state_left)]
if not GENERALIZED_UTURN:
state_diff = [tf.where(d, m, -m) for d, m in zip(direction, state_diff)]

no_u_turns_current = has_not_u_turn(
state_left,
[tf.where(d, m, -m) for d, m in zip(direction, momentum_left)],
state_right,
[tf.where(d, m, -m) for d, m in zip(direction, momentum_right)],
state_diff,
momentum_left,
momentum_right,
log_prob_rank)
return left_current_index + 1, no_u_turns_current & no_u_turns_last

@@ -836,24 +883,23 @@ def _get_left_state_and_check_u_turn(left_current_index, no_u_turns_last):
return no_u_turns_within_tree


def has_not_u_turn(state_left,
def has_not_u_turn(state_diff,
momentum_left,
state_right,
momentum_right,
log_prob_rank):
"""If two given states and momentum do not exhibit a U-turn pattern."""
"""If the trajectory does not exhibit a U-turn pattern."""
with tf.name_scope('has_not_u_turn'):
batch_dot_product_left = sum([
tf.reduce_sum( # pylint: disable=g-complex-comprehension
(s1 - s2) * m,
s_diff * m,
axis=tf.range(log_prob_rank, prefer_static.rank(m)))
for s1, s2, m in zip(state_right, state_left, momentum_left)
for s_diff, m in zip(state_diff, momentum_left)
])
batch_dot_product_right = sum([
tf.reduce_sum( # pylint: disable=g-complex-comprehension
(s1 - s2) * m,
s_diff * m,
axis=tf.range(log_prob_rank, prefer_static.rank(m)))
for s1, s2, m in zip(state_right, state_left, momentum_right)
for s_diff, m in zip(state_diff, momentum_right)
])
return (batch_dot_product_left >= 0) & (batch_dot_product_right >= 0)
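
Concretely, the trajectory keeps expanding only while both end momenta still point along the vector separating the ends. A single-chain NumPy analogue of the check (hypothetical helper, scalar batch, one state part):

```python
import numpy as np

def has_not_u_turn_single(state_diff, momentum_left, momentum_right):
  # Keep expanding only while both ends still move apart along the
  # vector separating them.
  return (np.dot(state_diff, momentum_left) >= 0 and
          np.dot(state_diff, momentum_right) >= 0)

delta = np.array([1.0, 0.5])
print(has_not_u_turn_single(delta, np.array([0.9, 0.1]),
                            np.array([0.7, 0.2])))  # True: still expanding
print(has_not_u_turn_single(delta, np.array([-0.4, 0.0]),
                            np.array([0.7, 0.2])))  # False: left end turned
```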

