From 39a4111c6023ac2e4b67e2967b70db04c6e3ab38 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Tue, 22 Nov 2022 14:02:48 -0500 Subject: [PATCH] Fixes #3156 (#3157) --- pyro/infer/mcmc/hmc.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/pyro/infer/mcmc/hmc.py b/pyro/infer/mcmc/hmc.py index 6de9ddc51a..ff73dedde5 100644 --- a/pyro/infer/mcmc/hmc.py +++ b/pyro/infer/mcmc/hmc.py @@ -66,6 +66,8 @@ class HMC(MCMCKernel): step size, hence the sampling will be slower and more robust. Default to 0.8. :param callable init_strategy: A per-site initialization function. See :ref:`autoguide-initialization` section for available functions. + :param min_stepsize (float): Lower bound on stepsize in adaptation strategy. + :param max_stepsize (float): Upper bound on stepsize in adaptation strategy. .. note:: Internally, the mass matrix will be ordered according to the order of the names of latent variables, not the order of their appearance in @@ -108,6 +110,9 @@ def __init__( ignore_jit_warnings=False, target_accept_prob=0.8, init_strategy=init_to_uniform, + *, + min_stepsize: float = 1e-10, + max_stepsize: float = 1e10, ): if not ((model is None) ^ (potential_fn is None)): raise ValueError("Only one of `model` or `potential_fn` must be specified.") @@ -119,6 +124,8 @@ def __init__( self._jit_options = jit_options self._ignore_jit_warnings = ignore_jit_warnings self._init_strategy = init_strategy + self._min_stepsize = min_stepsize + self._max_stepsize = max_stepsize self.potential_fn = potential_fn if trajectory_length is not None: @@ -188,9 +195,11 @@ def _find_reasonable_step_size(self, z): step_size_scale = 2**direction direction_new = direction # keep scale step_size until accept_prob crosses its target - # TODO: make thresholds for too small step_size or too large step_size t = 0 - while direction_new == direction: + while ( + direction_new == direction + and self._min_stepsize < step_size < self._max_stepsize + ): t += 1 step_size = step_size_scale * step_size r, r_unscaled = self._sample_r(name="r_presample_{}".format(t)) @@ -206,6 +215,8 @@ def _find_reasonable_step_size(self, z): energy_new = self._kinetic_energy(r_new_unscaled) + potential_energy_new delta_energy = energy_new - energy_current direction_new = 1 if self._direction_threshold < -delta_energy else -1 + step_size = max(step_size, self._min_stepsize) + step_size = min(step_size, self._max_stepsize) return step_size def _sample_r(self, name):