Commit
Provide interface for controlling the number of HMC iterations during which to adapt the step size.

Also fixes a small bug that prevented the use of list-valued step sizes.

PiperOrigin-RevId: 208095852
davmre authored and Copybara-Service committed Aug 9, 2018
1 parent 5eb4973 commit d740d84
Showing 4 changed files with 202 additions and 42 deletions.
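In brief: `step_size_update_fn` policies are now built by calling `tfp.mcmc.make_simple_step_size_update_policy(...)`, whose new `num_adaptation_steps` argument caps how many initial steps adapt the step size. A minimal usage sketch, adapted from the docstring example in hmc.py below (the standard-normal target and the particular step counts are illustrative, not taken from this commit):

import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

# Non-trainable variable holding the current step size; the update policy
# assigns new values to it as the chain runs.
step_size = tf.get_variable(
    name='step_size',
    initializer=np.array(1e-3, dtype=np.float32),
    use_resource=True,
    trainable=False)

hmc = tfp.mcmc.HamiltonianMonteCarlo(
    target_log_prob_fn=tfd.Normal(0., 1.).log_prob,
    num_leapfrog_steps=3,
    step_size=step_size,
    # Adapt for the first 500 steps only, then hold the step size fixed so
    # the chain can settle into its stationary distribution.
    step_size_update_fn=tfp.mcmc.make_simple_step_size_update_policy(
        num_adaptation_steps=500))

samples, kernel_results = tfp.mcmc.sample_chain(
    num_results=1000,
    num_burnin_steps=1000,
    current_state=0.,
    kernel=hmc)

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  samples_, final_step_size_ = sess.run([samples, step_size])

Per the docstring added below, `num_adaptation_steps` should generally be somewhat smaller than the number of burnin steps, so that post-adaptation burnin can wash out any bias introduced by adaptation.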
@@ -2235,7 +2235,7 @@
" target_log_prob_fn=unnormalized_posterior_log_prob,\n",
" num_leapfrog_steps=2,\n",
" step_size=step_size,\n",
" step_size_update_fn=tfp.mcmc.step_size_simple_update,\n",
" step_size_update_fn=tfp.mcmc.make_simple_step_size_update_policy(),\n",
" state_gradients_are_stopped=True)\n",
"\n",
"init_random_weights = tf.placeholder(dtype, shape=[len(log_county_uranium_ppm)])\n",
4 changes: 2 additions & 2 deletions tensorflow_probability/python/mcmc/__init__.py
@@ -21,7 +21,7 @@
from tensorflow_probability.python.mcmc.diagnostic import effective_sample_size
from tensorflow_probability.python.mcmc.diagnostic import potential_scale_reduction
from tensorflow_probability.python.mcmc.hmc import HamiltonianMonteCarlo
from tensorflow_probability.python.mcmc.hmc import step_size_simple_update
from tensorflow_probability.python.mcmc.hmc import make_simple_step_size_update_policy
from tensorflow_probability.python.mcmc.hmc import UncalibratedHamiltonianMonteCarlo
from tensorflow_probability.python.mcmc.kernel import TransitionKernel
from tensorflow_probability.python.mcmc.langevin import MetropolisAdjustedLangevinAlgorithm
@@ -61,7 +61,7 @@
'sample_annealed_importance_chain',
'sample_chain',
'sample_halton_sequence',
'step_size_simple_update',
'make_simple_step_size_update_policy',
]

remove_undocumented(__name__, _allowed_symbols)
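Note the rename is not a pure find-and-replace at call sites: the old symbol was itself the update function, while the new symbol is a factory that returns one. A sketch (the `num_adaptation_steps` value is illustrative):

# Before: the module exported the update function directly.
#   update_fn = tfp.mcmc.step_size_simple_update
# After: it exports a factory; call it to obtain an update function with the
# same (step_size_var, kernel_results) signature.
update_fn = tfp.mcmc.make_simple_step_size_update_policy(
    num_adaptation_steps=100)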
122 changes: 85 additions & 37 deletions tensorflow_probability/python/mcmc/hmc.py
@@ -21,6 +21,7 @@
import collections
# Dependency imports

import numpy as np
import tensorflow as tf

from tensorflow_probability.python.mcmc import kernel as kernel_base
@@ -33,7 +34,7 @@
__all__ = [
'HamiltonianMonteCarlo',
'UncalibratedHamiltonianMonteCarlo',
'step_size_simple_update',
'make_simple_step_size_update_policy',
]


@@ -52,24 +53,33 @@
])


def step_size_simple_update(
step_size_var,
kernel_results,
target_rate=0.75,
decrement_multiplier=0.01,
increment_multiplier=0.01):
"""Updates (list of) `step_size` using a standard adaptive MCMC procedure.
def make_simple_step_size_update_policy(num_adaptation_steps=None,
target_rate=0.75,
decrement_multiplier=0.01,
increment_multiplier=0.01,
step_counter=None):
"""Create a function implementing a step-size update policy.
This function increases or decreases the `step_size_var` based on the average
of `exp(minimum(0., log_accept_ratio))`. It is based on
The simple policy increases or decreases the `step_size_var` based on the
average of `exp(minimum(0., log_accept_ratio))`. It is based on
[Section 4.2 of Andrieu and Thoms (2008)](
http://www4.ncsu.edu/~rsmith/MA797V_S12/Andrieu08_AdaptiveMCMC_Tutorial.pdf).
The `num_adaptation_steps` argument is set independently of any burnin
for the overall chain. In general, adaptation prevents the chain from
reaching a stationary distribution, so obtaining consistent samples requires
`num_adaptation_steps` be set to a value [somewhat smaller](
http://andrewgelman.com/2017/12/15/burn-vs-warm-iterative-simulation-algorithms/#comment-627745)
than the number of burnin steps. However, it may sometimes be helpful to set
`num_adaptation_steps` to a larger value during development in order to
inspect the behavior of the chain during adaptation.
Args:
step_size_var: (List of) `tf.Variable`s representing the per `state_part`
HMC `step_size`.
kernel_results: `collections.namedtuple` containing `Tensor`s
representing values from most recent call to `one_step`.
num_adaptation_steps: Scalar `int` `Tensor` number of initial steps
during which to adjust the step size. This may be greater than, less than, or
equal to the number of burnin steps. If `None`, the step size is adapted
on every step.
Default value: `None`.
target_rate: Scalar `Tensor` representing desired `accept_ratio`.
Default value: `0.75` (i.e., [center of asymptotically optimal
rate](https://arxiv.org/abs/1411.6669)).
@@ -79,29 +89,63 @@ def step_size_simple_update(
increment_multiplier: `Tensor` representing amount to upscale current
`step_size`.
Default value: `0.01`.
step_counter: Scalar `int` `Variable` specifying the current step. The step
size is adapted iff `step_counter < num_adaptation_steps`.
Default value: if `None`, an internal variable
`step_size_adaptation_step_counter` is created and initialized to `-1`.
Returns:
step_size_assign: (List of) `Tensor`(s) representing updated
`step_size_var`(s).
step_size_simple_update_fn: Callable that takes args
`step_size_var, kernel_results` and returns updated step size(s).
"""
if kernel_results is None:
if mcmc_util.is_list_like(step_size_var):
return [tf.identity(ss) for ss in step_size_var]
return tf.identity(step_size_var)
log_n = tf.log(tf.cast(tf.size(kernel_results.log_accept_ratio),
kernel_results.log_accept_ratio.dtype))
log_mean_accept_ratio = tf.reduce_logsumexp(
tf.minimum(kernel_results.log_accept_ratio, 0.)) - log_n
adjustment = tf.where(
log_mean_accept_ratio < tf.log(target_rate),
-decrement_multiplier / (1. + decrement_multiplier),
increment_multiplier)
if not mcmc_util.is_list_like(step_size_var):
return step_size_var.assign_add(step_size_var * adjustment)
step_size_assign = []
for ss in step_size_var:
step_size_assign.append(ss.assign_add(ss * adjustment))
return step_size_assign
if step_counter is None and num_adaptation_steps is not None:
step_counter = tf.get_variable(
name='step_size_adaptation_step_counter',
initializer=np.array(-1, dtype=np.int32),
trainable=False,
use_resource=True)

def step_size_simple_update_fn(step_size_var, kernel_results):
"""Updates (list of) `step_size` using a standard adaptive MCMC procedure.
Args:
step_size_var: (List of) `tf.Variable`s representing the per `state_part`
HMC `step_size`.
kernel_results: `collections.namedtuple` containing `Tensor`s
representing values from most recent call to `one_step`.
Returns:
step_size_assign: (List of) `Tensor`(s) representing updated
`step_size_var`(s).
"""

if kernel_results is None:
if mcmc_util.is_list_like(step_size_var):
return [tf.identity(ss) for ss in step_size_var]
return tf.identity(step_size_var)
log_n = tf.log(tf.cast(tf.size(kernel_results.log_accept_ratio),
kernel_results.log_accept_ratio.dtype))
log_mean_accept_ratio = tf.reduce_logsumexp(
tf.minimum(kernel_results.log_accept_ratio, 0.)) - log_n
adjustment = tf.where(
log_mean_accept_ratio < tf.log(target_rate),
-decrement_multiplier / (1. + decrement_multiplier),
increment_multiplier)

def build_assign_op():
if mcmc_util.is_list_like(step_size_var):
return [ss.assign_add(ss * adjustment) for ss in step_size_var]
return step_size_var.assign_add(step_size_var * adjustment)

if num_adaptation_steps is None:
return build_assign_op()
else:
with tf.control_dependencies([step_counter.assign_add(1)]):
return tf.cond(step_counter < num_adaptation_steps,
build_assign_op,
lambda: step_size_var)

return step_size_simple_update_fn


class HamiltonianMonteCarlo(kernel_base.TransitionKernel):
@@ -150,7 +194,7 @@ def unnormalized_log_prob(x):
target_log_prob_fn=unnormalized_log_prob,
num_leapfrog_steps=3,
step_size=step_size,
step_size_update_fn=tfp.mcmc.step_size_simple_update)
step_size_update_fn=tfp.mcmc.make_simple_step_size_update_policy())
# Run the chain (with burn-in).
samples, kernel_results = tfp.mcmc.sample_chain(
@@ -268,7 +312,7 @@ def unnormalized_posterior_log_prob(w):
target_log_prob_fn=unnormalized_posterior_log_prob,
num_leapfrog_steps=2,
step_size=step_size,
step_size_update_fn=tfp.mcmc.step_size_simple_update,
step_size_update_fn=tfp.mcmc.make_simple_step_size_update_policy(),
state_gradients_are_stopped=True))
avg_acceptance_ratio = tf.reduce_mean(
@@ -453,7 +497,11 @@ def one_step(self, current_state, previous_kernel_results):
"""
previous_step_size_assign = (
[] if self.step_size_update_fn is None
else [previous_kernel_results.extra.step_size_assign])
else (previous_kernel_results.extra.step_size_assign
if mcmc_util.is_list_like(
previous_kernel_results.extra.step_size_assign)
else [previous_kernel_results.extra.step_size_assign]))

with tf.control_dependencies(previous_step_size_assign):
next_state, kernel_results = self._impl.one_step(
current_state, previous_kernel_results)
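To see what the returned closure does on each step, here is a standalone NumPy sketch of the multiplicative update implemented above (an illustration, not code from this commit):

import numpy as np

def simple_step_size_update(step_size, log_accept_ratio,
                            target_rate=0.75,
                            decrement_multiplier=0.01,
                            increment_multiplier=0.01):
  """NumPy mirror of the TF update rule, for one adaptation step."""
  # Mean acceptance probability across chains: mean(exp(min(0, ratio))).
  # (The TF code computes this in log space via reduce_logsumexp.)
  mean_accept = np.mean(np.exp(np.minimum(log_accept_ratio, 0.)))
  if mean_accept < target_rate:
    # Downscale. A factor of 1 - d/(1+d) = 1/(1+d) exactly undoes an
    # upscale by 1 + d, so equal multipliers make the two moves inverses.
    adjustment = -decrement_multiplier / (1. + decrement_multiplier)
  else:
    adjustment = increment_multiplier
  return step_size + step_size * adjustment

# With both multipliers set to 1, the step size doubles when the mean
# acceptance probability meets the target and halves otherwise;
# `test_finite_adaptation` below uses this to count adaptation steps.
print(simple_step_size_update(1e-5, np.array([0.]),
                              decrement_multiplier=1.,
                              increment_multiplier=1.))  # 2e-05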
116 changes: 114 additions & 2 deletions tensorflow_probability/python/mcmc/hmc_test.py
@@ -816,6 +816,118 @@ class HMCHandlesLists64(_HMCHandlesLists, tf.test.TestCase):
dtype = np.float64


class HMCAdaptiveStepSize(tf.test.TestCase):

def setUp(self):
random_seed.set_random_seed(10014)
np.random.seed(10014)

# TODO(b/112427830): Reenable eager tests for TF 1.11 release.
@run_in_graph_mode_only()
def test_multiple_step_sizes(self):
num_results = 5
initial_step_sizes = [1e-5, 1e-4]
initial_state = [0., 0.]
dtype = np.float32

# TODO(b/111765211): Switch to the following once
# `get_variable(use_resource=True)` has the same semantics as
# `tf.contrib.eager.Variable`.
# step_size = tf.get_variable(
# name='step_size',
# initializer=np.array(1e-3, dtype),
# use_resource=True,
# trainable=False)
step_size = [tf.contrib.eager.Variable(
initial_value=np.array(initial_step_size, dtype),
name='step_size',
trainable=False) for initial_step_size in initial_step_sizes]

def target_log_prob_fn(x1, x2):
return tf.reduce_sum(tfd.Normal(0., 1.).log_prob([x1, x2]))

_, kernel_results = tfp.mcmc.sample_chain(
num_results=num_results,
num_burnin_steps=0,
current_state=initial_state,
kernel=tfp.mcmc.HamiltonianMonteCarlo(
target_log_prob_fn=target_log_prob_fn,
num_leapfrog_steps=2,
step_size=step_size,
step_size_update_fn=tfp.mcmc.make_simple_step_size_update_policy(),
state_gradients_are_stopped=True,
seed=_set_seed(252)),
parallel_iterations=1)

init_op = tf.global_variables_initializer()
self.evaluate(init_op)

step_size_ = self.evaluate(kernel_results.extra.step_size_assign)

# We apply the same adjustment to each step size in the list, so
# the starting ratio of step sizes should match the final ratio.
self.assertNear(step_size_[0][0]/step_size_[1][0],
step_size_[0][-1]/step_size_[1][-1], err=1e-4)

# TODO(b/112427830): Reenable eager tests for TF 1.11 release.
@run_in_graph_mode_only()
def test_finite_adaptation(self):

# Test that the adaptation runs for the specified number of steps.
# We set up a chain with a tiny initial step size, so every step accepts,
# and test that the final step size is incremented exactly
# `num_adaptation_steps` times.
num_results = 10
num_adaptation_steps = 3
initial_step_size = 1e-5
dtype = np.float32

# TODO(b/111765211): Switch to the following once
# `get_variable(use_resource=True)` has the same semantics as
# `tf.contrib.eager.Variable`.
# step_size = tf.get_variable(
# name='step_size',
# initializer=np.array(1e-3, dtype),
# use_resource=True,
# trainable=False)
step_size = tf.contrib.eager.Variable(
initial_value=np.array(initial_step_size, dtype),
name='step_size',
trainable=False)

_, kernel_results = tfp.mcmc.sample_chain(
num_results=num_results,
num_burnin_steps=0,
current_state=0.,
kernel=tfp.mcmc.HamiltonianMonteCarlo(
target_log_prob_fn=tfd.Normal(0., 1.).log_prob,
num_leapfrog_steps=2,
step_size=step_size,
step_size_update_fn=tfp.mcmc.make_simple_step_size_update_policy(
num_adaptation_steps=num_adaptation_steps,
increment_multiplier=1., # double step_size on every accept
decrement_multiplier=1.), # halve step_size on every reject
state_gradients_are_stopped=True,
seed=_set_seed(252)),
parallel_iterations=1)

init_op = tf.global_variables_initializer()
self.evaluate(init_op)

[_, step_size_] = self.evaluate([
kernel_results, kernel_results.extra.step_size_assign])

# Test that we've incremented the step size every time. This verifies
# that adaptation ran on each of the first `num_adaptation_steps` steps.
self.assertNear(initial_step_size * 2**num_adaptation_steps,
step_size_[num_adaptation_steps], err=1e-6)

# Test that the step size does not change after the first
# `num_adaptation_steps` steps.
self.assertEqual(step_size_[num_adaptation_steps:].min(),
step_size_[num_adaptation_steps:].max())


class HMCEMAdaptiveStepSize(tf.test.TestCase):
"""This test verifies that the docstring example works as advertised."""

@@ -889,7 +1001,7 @@ def unnormalized_posterior_log_prob(w):
target_log_prob_fn=unnormalized_posterior_log_prob,
num_leapfrog_steps=2,
step_size=step_size,
step_size_update_fn=tfp.mcmc.step_size_simple_update,
step_size_update_fn=tfp.mcmc.make_simple_step_size_update_policy(),
state_gradients_are_stopped=True,
seed=_set_seed(252)),
parallel_iterations=1)
@@ -980,7 +1092,7 @@ def unnormalized_log_prob(x):
target_log_prob_fn=unnormalized_log_prob,
num_leapfrog_steps=2,
step_size=step_size,
step_size_update_fn=tfp.mcmc.step_size_simple_update,
step_size_update_fn=tfp.mcmc.make_simple_step_size_update_policy(),
seed=_set_seed(252)),
parallel_iterations=1)

