From fadd4034360081b9be757bd51a53d3e2fe167a0f Mon Sep 17 00:00:00 2001 From: siege Date: Wed, 20 Nov 2024 12:13:21 -0800 Subject: [PATCH] FunMC: Add a better implementation of SMC to FunMC. The old implementation (i.e. AIS) will be rewritten to use this version at some future date. PiperOrigin-RevId: 698476095 --- spinoffs/fun_mc/fun_mc/BUILD | 1 - .../fun_mc/fun_mc/dynamic/backend_jax/BUILD | 1 + .../fun_mc/fun_mc/dynamic/backend_jax/util.py | 34 +- .../dynamic/backend_tensorflow/backend.py | 8 +- .../fun_mc/dynamic/backend_tensorflow/util.py | 66 +- spinoffs/fun_mc/fun_mc/smc.py | 369 +++++- spinoffs/fun_mc/fun_mc/smc_test.py | 1037 +++++++++++++++++ 7 files changed, 1508 insertions(+), 8 deletions(-) diff --git a/spinoffs/fun_mc/fun_mc/BUILD b/spinoffs/fun_mc/fun_mc/BUILD index f20d2bda7c..e499eec6c6 100644 --- a/spinoffs/fun_mc/fun_mc/BUILD +++ b/spinoffs/fun_mc/fun_mc/BUILD @@ -226,7 +226,6 @@ pytype_strict_contrib_test( ":types", # absl/testing:parameterized dep, # jax dep, - # jaxtyping dep, # mock dep, # tensorflow dep, # tensorflow_probability/python/internal:test_util dep, diff --git a/spinoffs/fun_mc/fun_mc/dynamic/backend_jax/BUILD b/spinoffs/fun_mc/fun_mc/dynamic/backend_jax/BUILD index 019998f434..83fac10892 100644 --- a/spinoffs/fun_mc/fun_mc/dynamic/backend_jax/BUILD +++ b/spinoffs/fun_mc/fun_mc/dynamic/backend_jax/BUILD @@ -33,6 +33,7 @@ py_library( srcs = ["util.py"], deps = [ # jaxtyping dep, + # numpy dep, ], ) diff --git a/spinoffs/fun_mc/fun_mc/dynamic/backend_jax/util.py b/spinoffs/fun_mc/fun_mc/dynamic/backend_jax/util.py index 8ed20c1399..895dfc5f00 100644 --- a/spinoffs/fun_mc/fun_mc/dynamic/backend_jax/util.py +++ b/spinoffs/fun_mc/fun_mc/dynamic/backend_jax/util.py @@ -14,7 +14,9 @@ # ============================================================================ """FunMC utilities implemented via JAX.""" +import dataclasses import functools +from typing import TypeVar, dataclass_transform import jax from jax import lax @@ -22,16 +24,19 @@ from jax import tree_util import jax.numpy as jnp import jaxtyping +import numpy as np __all__ = [ 'Array', 'assert_same_shallow_tree', 'block_until_ready', 'convert_to_tensor', + 'dataclass', 'diff', 'DType', 'flatten_tree', 'get_shallow_tree', + 'get_static_value', 'inverse_fn', 'make_tensor_seed', 'map_tree', @@ -406,7 +411,7 @@ def diff(x, prepend=None): return jnp.diff(x, prepend=prepend) -def repeat(x, repeats, total_repeat_length=None): +def repeat(x, repeats, total_repeat_length): """Like jnp.repeat.""" return jnp.repeat(x, repeats, total_repeat_length=total_repeat_length) @@ -436,3 +441,30 @@ def convert_to_tensor(x): if x is None: return x return jnp.asarray(x) + + +T = TypeVar('T') + + +@dataclass_transform() +def dataclass(cls: T) -> T: + """Create a tree-compatible dataclass.""" + cls = dataclasses.dataclass(frozen=True)(cls) + fields = [f.name for f in dataclasses.fields(cls)] + jax.tree_util.register_dataclass(cls, fields, []) + + def replace(self, **updates): + """Returns a new object replacing the specified fields with new values.""" + return dataclasses.replace(self, **updates) + + cls.replace = replace + + return cls + + +def get_static_value(x): + """Returns the static value of x, or None if x is dynamic.""" + try: + return np.array(x) + except TypeError: + return None diff --git a/spinoffs/fun_mc/fun_mc/dynamic/backend_tensorflow/backend.py b/spinoffs/fun_mc/fun_mc/dynamic/backend_tensorflow/backend.py index a1461e16c4..23dc14d2b1 100644 --- a/spinoffs/fun_mc/fun_mc/dynamic/backend_tensorflow/backend.py +++ b/spinoffs/fun_mc/fun_mc/dynamic/backend_tensorflow/backend.py @@ -24,7 +24,13 @@ tnp = tf.experimental.numpy _lax = types.ModuleType('lax') -_lax.cond = tf.cond + + +def cond(pred, true_fn, false_fn, *args): + return tf.cond(pred, lambda: true_fn(*args), lambda: false_fn(*args)) + + +_lax.cond = cond _lax.stop_gradient = tf.stop_gradient _nn = types.ModuleType('nn') diff --git a/spinoffs/fun_mc/fun_mc/dynamic/backend_tensorflow/util.py b/spinoffs/fun_mc/fun_mc/dynamic/backend_tensorflow/util.py index 41d9e0c43f..e1e81a1dd6 100644 --- a/spinoffs/fun_mc/fun_mc/dynamic/backend_tensorflow/util.py +++ b/spinoffs/fun_mc/fun_mc/dynamic/backend_tensorflow/util.py @@ -14,7 +14,9 @@ # ============================================================================ """FunMC utilities implemented via TensorFlow.""" +import dataclasses import functools +from typing import TypeVar, dataclass_transform import numpy as np import six @@ -29,10 +31,12 @@ 'assert_same_shallow_tree', 'block_until_ready', 'convert_to_tensor', + 'dataclass', 'diff', 'DType', 'flatten_tree', 'get_shallow_tree', + 'get_static_value', 'inverse_fn', 'make_tensor_seed', 'map_tree', @@ -419,10 +423,23 @@ def diff(x, prepend=None): def repeat(x, repeats, total_repeat_length): """Like jnp.repeat.""" - res = tf.repeat(x, repeats) - if total_repeat_length is not None: - res.set_shape([total_repeat_length] + [None] * (len(res.shape) - 1)) - return res + # Implementation based on JAX, with some adjustments due to TF's stricted + # indexing validation. + exclusive_repeats = tf.concat([[0], repeats[:-1]], axis=0) + scatter_indices = tf.cumsum(exclusive_repeats) + scatter_indices = tf.where( + scatter_indices < total_repeat_length, + scatter_indices, + total_repeat_length, + ) + block_split_indicators = tf.zeros([total_repeat_length + 1], tf.int32) + block_split_indicators = tf.tensor_scatter_nd_add( + block_split_indicators, + scatter_indices[..., tf.newaxis], + tf.ones_like(scatter_indices), + ) + gather_indices = tf.cumsum(block_split_indicators[:-1]) - 1 + return tf.gather(x, gather_indices) def new_dynamic_array(shape, dtype, size): @@ -454,3 +471,44 @@ def convert_to_tensor(x): if isinstance(x, tf.TensorArray): return x return tf.convert_to_tensor(x) + + +T = TypeVar('T') + + +@dataclass_transform() +def dataclass(cls: T) -> T: + """Create a tree-compatible dataclass.""" + cls = dataclasses.dataclass(frozen=True)(cls) + + def __tf_flatten__(self): # pylint: disable=invalid-name + metadata = () + fields = dataclasses.fields(self) + components = tuple(getattr(self, f.name) for f in fields) + return metadata, components + + @classmethod + def __tf_unflatten__(cls, metadata, leaves): # pylint: disable=invalid-name + del metadata + return cls(*leaves) + + def __len__(self): # pylint: disable=invalid-name + # This is to work around a bug in TF's tree-prefix matching. + return len(dataclasses.fields(self)) + + cls.__tf_flatten__ = __tf_flatten__ + cls.__tf_unflatten__ = __tf_unflatten__ + cls.__len__ = __len__ + + def replace(self, **updates): + """Returns a new object replacing the specified fields with new values.""" + return dataclasses.replace(self, **updates) + + cls.replace = replace + + return cls + + +def get_static_value(x): + """Returns the static value of x, or None if x is dynamic.""" + return tf.get_static_value(x) diff --git a/spinoffs/fun_mc/fun_mc/smc.py b/spinoffs/fun_mc/fun_mc/smc.py index ed5c32aae7..1479164c66 100644 --- a/spinoffs/fun_mc/fun_mc/smc.py +++ b/spinoffs/fun_mc/fun_mc/smc.py @@ -14,7 +14,7 @@ # ============================================================================ """Implementation of Sequential Monte Carlo.""" -from typing import Protocol, TypeVar, runtime_checkable +from typing import Any, Callable, Generic, Protocol, TypeVar, runtime_checkable from fun_mc import backend from fun_mc import types @@ -40,7 +40,14 @@ __all__ = [ 'conditional_systematic_resampling', + 'effective_sample_size_predicate', + 'ParticleGatherFn', + 'ResamplingPredicate', 'SampleAncestorsFn', + 'sequential_monte_carlo_init', + 'sequential_monte_carlo_step', + 'SequentialMonteCarloKernel', + 'SequentialMonteCarloState', 'systematic_resampling', ] @@ -165,3 +172,363 @@ def conditional_systematic_resampling( permuted_parents = util.random_permutation(parent_idxs[1:], permute_seed) parents = jnp.concatenate([parent_idxs[:1], permuted_parents]) return parents + + +@util.dataclass +class SequentialMonteCarloState(Generic[State]): + """State of sequential Monte Carlo. + + Attributes: + state: The particles. + log_weights: Unnormalized log weight of the particles. + step: Current timestep. + """ + + state: State + log_weights: Float[Array, 'num_particles'] + step: IntScalar + + def log_normalizing_constant(self) -> FloatScalar: + """Log of the unbiased normalizing constant estimator.""" + return tfp.math.reduce_logmeanexp(self.log_weights, axis=-1) + + def effective_sample_size(self) -> FloatScalar: + """Estimates the effective sample size.""" + norm_weights = jax.nn.softmax(self.log_weights) + return 1.0 / jnp.sum(norm_weights**2) + + +@util.dataclass +class SequentialMonteCarloExtra(Generic[State, Extra]): + """Extra outputs from sequential Monte Carlo. + + Attributes: + incremental_log_weights: Incremental log weights for this timestep. + kernel_extra: Extra outputs from the kernel operator. + resampled: Whether resampling happened or not. + ancestor_idxs: Ancestor indices. + state_after_resampling: State after resampling but before running the SMC + kernel. + log_weights_after_resampling: Log weights of particles after resampling but + before running the SMC kernel. + """ + + incremental_log_weights: Float[Array, 'num_particles'] + kernel_extra: Extra + resampled: BoolScalar + ancestor_idxs: Int[Array, 'num_particles'] + state_after_resampling: State + log_weights_after_resampling: Float[Array, 'num_particles'] + + +@runtime_checkable +class SequentialMonteCarloKernel(Protocol[State, Extra]): + """SMC kernel, a function that proposes new states and produces the incremental log weights.""" + + def __call__( + self, + state: State, + step: IntScalar, + seed: Seed, + ) -> tuple[ + State, + tuple[Float[Array, 'num_particles'], Extra], + ]: + """Perform an SMC kernel step. + + Given a previous state `x_{t - 1}^{1:K}` (`K` = number of particles) at + timestep `(t - 1)`, an SequentialMonteCarloKernel returns: + + 1. A new set of particles `x_t^{1:K}` at timestep t. + 2. The incremental log weights `iw_t^{1:K}` at step `t`. + + SMC is commonly performed on states whose dimension increases with each + timestep (see Section 3.5 of [1]), e.g. `len(x_t) = t` and + `len(x_{t - 1}) = t - 1`. Then, for every particle `k` in `{1, ..., K}`, the + new states are obtained as + + ```none + z_t^k ~ q_t(. | x_{t - 1}^k) + x_t^k = append(x_{t - 1}^k, z_t^k) + ``` + + the incremental log weight is computed as + + ```none + iw_t^k = log p_t(x_t^k) - log p_{t - 1}(x_{t - 1}^k) + - log q_t(x_t[-1]^k | x_{t - 1}^k) + ``` + + where `q_t` is the proposal and `{p_t(x_t); t = 1, ..., T}` is the sequence + of unnormalized target distributions. + + Alternatively, SMC can also be performed on states that live in the same + space, as in the SMC samplers [2]. Then, for every particle k in + `{1, ..., K}`, the new states are obtained as + + ```none + x_t^k ~ q_t(x_t | x_{t - 1}^k) + ``` + + and the incremental log weight is computed as + + ```none + iw_t^k = log p_t(x_t^k) + log r_{t - 1}(x_{t - 1}^k | x_t^k) + - log p_{t - 1}(x_{t - 1}^k) - log q_t(x_t^k | x_{t - 1}^k) + + where `q_t`, `r_{t - 1}` are the forward and reverse kernels respectively + and `{p_t(x_t); t = 1, ..., T}` is the sequence of unnormalized target + distributions. + + In the most general case, the kernel should maintain a 'proper weighting' + invariant. A set of weighted particles `x^{1:K}`, w^{1:K} is properly + weighted w.r.t. an unnormalized target `p(x)` if + + ```none + E[1 / K sum_k w^k f(x^k)] = c E_{pi(x)}[f(x)] for any f, + ``` + + where c is a constant and pi(x) is the normalized p(x). Commonly, c the + normalization constant of p(x), i.e. c = int p(x) dx and pi(x) = p(x) / c. + + The SequentialMonteCarloKernel maintains the 'proper weighting invariant' in + the sense that + if `x_{t - 1}^{1:K}`, `w_{t - 1}^{1:K}` is properly weighted w.r.t. an + unnormalized target `p_{t - 1}(x_{t - 1})`, then `x_t^{1:K}`, `w_t^{1:K}` is + properly weighted w.r.t. `p_t(x_t)`, where `w_t^k = w_{t - 1}^k * iw_t^k`. + + In this setting, the unnormalized targets are only defined implicitly. + + Args: + state: The previous particle state, `x_{t - 1}^{1:K}`. + step: The previous timestep, `t - 1`. + seed: A PRNG key. + + Returns: + state: The new particles, `x_t^{1:K}`. + extra: A 2-tuple of: + incremental_log_weights: The incremental log weight at timestep t, + `iw_t^{1:K}`. + kernel_extra: Extra information returned by the kernel. + + #### References: + + [1]: Doucet, Arnaud, and Adam M. Johansen. 'A tutorial on particle filtering + and smoothing: Fifteen years later.' Handbook of nonlinear filtering + 12.656-704 (2009): 3. + https://www.stats.ox.ac.uk/~doucet/doucet_johansen_tutorialPF2011.pdf + [2]: Del Moral, Pierre, Arnaud Doucet, and Ajay Jasra. 'Sequential monte + carlo samplers.' Journal of the Royal Statistical Society Series B: + Statistical Methodology 68.3 (2006): 411-436. + https://academic.oup.com/jrsssb/article/68/3/411/7110641 + """ + + +@runtime_checkable +class ResamplingPredicate(Protocol): + """Function that decides whether to resample.""" + + def __call__(self, state: SequentialMonteCarloState) -> BoolScalar: + """Return boolean indicating whether to resample. + + Note that resampling happens before stepping the kernel. + + Args: + state: State step `t - 1`. + + Returns: + Whether resampling happens during the SMC step at step `t`. + """ + + +@types.runtime_typed +def effective_sample_size_predicate( + state: SequentialMonteCarloState, +) -> BoolScalar: + """A resampling predicate that uses effective sample size. + + Args: + state: SMC state. + + Returns: + True if the effective sample size is less than half the number of particles, + False otherwise. + """ + num_particles = state.log_weights.shape[0] + return state.effective_sample_size() < num_particles / 2 + + +@runtime_checkable +class ParticleGatherFn(Protocol[State]): + """Function that indexes into a batch of states.""" + + def __call__( + self, + state: State, + indices: Int[Array, 'num_particles'], + ) -> State: + """Gather states at the given indices.""" + + +@types.runtime_typed +def _defalt_pytree_gather( + state: State, + indices: Int[Array, 'num_particles'], +) -> State: + """Indexes into states using the default gather. + + Assumes `state` is a pytree of arrays with a single leading batch dimension. + + Args: + state: The particles. + indices: The gather indices. + + Returns: + new_state: Gathered state (with the same leading dimension). + """ + return util.map_tree(lambda x: x[indices], state) + + +@types.runtime_typed +def sequential_monte_carlo_init( + state: State, + num_particles: int | None = None, + initial_step: IntScalar = 0, + weight_dtype: DType = jnp.float32, +) -> SequentialMonteCarloState[State]: + """Initializes the sequential Monte Carlo state. + + Args: + state: Initial state representing the SMC particles. + num_particles: Number of particles, if `None`, it is inferred from the first + element of `state`. + initial_step: Initial step number. + weight_dtype: DType of the `log_weights`. + + Returns: + `SequentialMonteCarloState`. + """ + if num_particles is None: + num_particles = util.flatten_tree(state)[0].shape[0] + return SequentialMonteCarloState( + state=state, + log_weights=jnp.zeros([num_particles], dtype=weight_dtype), + step=jnp.asarray(initial_step, dtype=jnp.int32), + ) + + +@types.runtime_typed +def sequential_monte_carlo_step( + smc_state: SequentialMonteCarloState[State], + kernel: SequentialMonteCarloKernel[State, Extra], + seed: Seed, + resampling_pred: ResamplingPredicate = effective_sample_size_predicate, + sample_ancestors_fn: SampleAncestorsFn = systematic_resampling, + state_gather_fn: ParticleGatherFn[State] = _defalt_pytree_gather, +) -> tuple[ + SequentialMonteCarloState[State], SequentialMonteCarloExtra[State, Extra] +]: + """Take a step of sequential Monte Carlo. + + Given a previous SMC state `x_{t - 1}^{1:K}`, `w_{t - 1}^{1:K}` (where `K` is + the number of particles) at timestep `t - 1` that is properly weighted w.r.t. + an unnormalized target `p_{t - 1}(x_{t - 1})`, returns a new SMC state + `x_t^{1:K}`, `w_t^{1:K}` at timestep `t` that is properly weighted w.r.t. + `p_t(x_t)`. + + Note that the unnormalized target is implicitly defined by the kernel. + + This implementation first resamples, then steps the kernel. + + Args: + smc_state: SMC state at timestep `t - 1`, `x_{t - 1}^{1:K}`, `w_{t - + 1}^{1:K}`. + kernel: SMC kernel. + seed: Random seed. + resampling_pred: Resampling predicate. + sample_ancestors_fn: Ancestor index sampling function. + state_gather_fn: State gather function. + + Returns: + smc_state: SMC state at timestep t, x_t^{1:K}, w_t^{1:K}. + smc_extra: Extra information for the SMC step. + """ + resample_seed, kernel_seed = util.split_seed(seed, 2) + + def do_resample( + state, + log_weights, + seed, + ): + ancestor_idxs = sample_ancestors_fn(log_weights, seed) + new_state = state_gather_fn(state, ancestor_idxs) + num_particles = log_weights.shape[0] + new_log_weights = jnp.full( + (num_particles,), tfp.math.reduce_logmeanexp(log_weights) + ) + return (new_state, ancestor_idxs, new_log_weights) + + def dont_resample( + state, + log_weights, + seed, + ): + del seed + num_particles = log_weights.shape[0] + return state, jnp.arange(num_particles), log_weights + + # NOTE: We don't explicitly disable resampling at the first step. However, if + # we initialize the log weights to zeros, either of + # 1. resampling according to the effective sample size criterion and + # 2. using systematic resampling effectively disables resampling at the first + # step. + # First-step resampling can always be forced via the `resampling_pred`. + should_resample = resampling_pred(smc_state) + state_after_resampling, ancestor_idxs, log_weights_after_resampling = ( + _smart_cond( + should_resample, + do_resample, + dont_resample, + smc_state.state, + smc_state.log_weights, + resample_seed, + ) + ) + + # Step kernel + state, (incremental_log_weights, kernel_extra) = kernel( + state_after_resampling, smc_state.step, kernel_seed + ) + + new_log_weights = log_weights_after_resampling + incremental_log_weights + + smc_state = smc_state.replace( # pytype: disable=attribute-error + state=state, + log_weights=new_log_weights, + step=smc_state.step + 1, + ) + smc_extra = SequentialMonteCarloExtra( + incremental_log_weights=incremental_log_weights, + kernel_extra=kernel_extra, + resampled=should_resample, + ancestor_idxs=ancestor_idxs, + state_after_resampling=state_after_resampling, + log_weights_after_resampling=log_weights_after_resampling, + ) + return smc_state, smc_extra + + +def _smart_cond( + pred: BoolScalar, + true_fn: Callable[..., T], + false_fn: Callable[..., T], + *args: Any +) -> T: + """Like lax.cond, but shortcircuiting for static predicates.""" + static_pred = util.get_static_value(pred) + if static_pred is None: + return jax.lax.cond(pred, true_fn, false_fn, *args) + elif static_pred: + return true_fn(*args) + else: + return false_fn(*args) diff --git a/spinoffs/fun_mc/fun_mc/smc_test.py b/spinoffs/fun_mc/fun_mc/smc_test.py index 8acf2535b3..1a7962e3fa 100644 --- a/spinoffs/fun_mc/fun_mc/smc_test.py +++ b/spinoffs/fun_mc/fun_mc/smc_test.py @@ -12,10 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ +import functools + # Dependency imports from absl.testing import parameterized import jax as real_jax +import mock import tensorflow.compat.v2 as real_tf from tensorflow_probability.python.internal import test_util as tfp_test_util from fun_mc import backend @@ -56,6 +59,128 @@ def _test_seed() -> Seed: return util.make_tensor_seed([seed, 0]) +@types.runtime_typed +def basic_kernel( + state: Float[Array, 'num_particles'], + step: IntScalar, + seed: Seed, +) -> tuple[ + Float[Array, 'num_particles'], + tuple[Float[Array, 'num_particles'], tuple[()]], +]: + del step + random_weights = util.random_uniform(state.shape, state.dtype, seed) + log_weights = jnp.log(random_weights) + return state, (log_weights, ()) + + +@types.runtime_typed +def ess_kernel( + state: Float[Array, 'num_particles'], + step: IntScalar, + seed: Seed, +) -> tuple[ + Float[Array, 'num_particles'], + tuple[Float[Array, 'num_particles'], tuple[()]], +]: + """Make the ESS low only for first 3 timesteps.""" + del seed + num_particles = state.shape[0] + high_ess_log_weights = jnp.zeros([num_particles], state.dtype) + low_ess_log_weights = jnp.stack( + [jnp.zeros([], dtype=state.dtype)] + + [jnp.array(-jnp.inf, dtype=state.dtype)] * (num_particles - 1), + 0, + ) + log_weights = jnp.where(step < 3, low_ess_log_weights, high_ess_log_weights) + return state, (log_weights, ()) + + +@types.runtime_typed +def kernel_log_weights_eq_two( + state: Float[Array, 'num_particles'], + step: IntScalar, + seed: Seed, +) -> tuple[ + Float[Array, 'num_particles'], + tuple[Float[Array, 'num_particles'], tuple[()]], +]: + """Always returns log weight equals 2.""" + del seed, step + num_particles = state.shape[0] + log_weights = jnp.full([num_particles], 2, dtype=jnp.float32) + return state, (log_weights, ()) + + +@types.runtime_typed +def kernel_log_weights_eq_step( + state: Float[Array, 'num_particles'], + step: IntScalar, + seed: Seed, +) -> tuple[ + Float[Array, 'num_particles'], + tuple[Float[Array, 'num_particles'], tuple[()]], +]: + """Log weight equals timestep.""" + del seed + num_particles = state.shape[0] + log_weights = jnp.full([num_particles], step, dtype=jnp.float32) + return state, (log_weights, ()) + + +@types.runtime_typed +def kernel_extra_is_finished( + state: Float[Array, 'num_particles'], + step: IntScalar, + seed: Seed, + num_timesteps: int, +) -> tuple[ + Float[Array, 'num_particles'], + tuple[Float[Array, 'num_particles'], BoolScalar], +]: + """Returns an extra `is_finished` value based on `num_timesteps`.""" + del seed + num_particles = state.shape[0] + log_weights = jnp.zeros([num_particles], state.dtype) + is_finished = step >= num_timesteps - 1 + return state, (log_weights, is_finished) + + +@types.runtime_typed +def kernel_log_weights_eq_neg_inf_if_state_lt_zero( + state: Float[Array, 'num_particles'], + step: IntScalar, + seed: Seed, +) -> tuple[ + Float[Array, 'num_particles'], + tuple[Float[Array, 'num_particles'], tuple[()]], +]: + """Decrements the state and returns -inf weight when it dips below zero.""" + del seed, step + new_state = state - 1 + num_particles = state.shape[0] + neg_infs = jnp.full([num_particles], -jnp.inf, dtype=state.dtype) + zeros = jnp.zeros([num_particles], dtype=state.dtype) + log_weights = jnp.where(new_state < 0, neg_infs, zeros) + return new_state, (log_weights, ()) + + +@types.runtime_typed +def always_predicate( + state: smc.SequentialMonteCarloState, +) -> BoolScalar: + del state + return True + + +@types.runtime_typed +def never_predicate( + state: smc.SequentialMonteCarloState, +) -> BoolScalar: + del state + return False + + class SMCTest(tfp_test_util.TestCase): @property @@ -147,6 +272,918 @@ def kernel(seed): ) self.assertAllClose(rejection_freqs, conditional_freqs, atol=0.05) + def test_smc_runs_and_shapes_correct(self): + num_particles = 3 + num_timesteps = 20 + + def kernel(smc_state, seed): + step_seed, seed = util.split_seed(seed, 2) + smc_state, _ = smc.sequential_monte_carlo_step( + smc_state, + kernel=basic_kernel, + seed=step_seed, + ) + return (smc_state, seed), () + + @jax.jit + def run_smc(seed): + (smc_state, _), _ = fun_mc.trace( + ( + smc.sequential_monte_carlo_init( + jnp.zeros((num_particles,), dtype=self._dtype), + weight_dtype=self._dtype, + ), + seed, + ), + kernel, + num_timesteps, + ) + return smc_state + + smc_state = run_smc(_test_seed()) + self.assertEqual(smc_state.state.shape, (num_particles,)) + self.assertEqual(smc_state.log_weights.shape, (num_particles,)) + self.assertEqual(smc_state.step.shape, ()) + self.assertEqual(smc_state.log_normalizing_constant().shape, ()) + self.assertEqual(smc_state.effective_sample_size().shape, ()) + + @parameterized.parameters(True, False) + def test_static_resampling(self, resample): + num_particles = 3 + num_timesteps = 20 + + patch_cond = self.enter_context( + mock.patch.object( + jax.lax, 'cond', autospec=True, side_effect=jax.lax.cond + ) + ) + + def kernel(smc_state, seed): + step_seed, seed = util.split_seed(seed, 2) + smc_state, extra = smc.sequential_monte_carlo_step( + smc_state, + kernel=basic_kernel, + seed=step_seed, + resampling_pred=lambda _: resample, + ) + return (smc_state, seed), extra.resampled + + @jax.jit + def run_smc(seed): + _, resampled_trace = fun_mc.trace( + ( + smc.sequential_monte_carlo_init( + state=jnp.zeros((num_particles,), dtype=self._dtype), + weight_dtype=self._dtype, + ), + seed, + ), + kernel, + num_timesteps, + ) + return resampled_trace + + resampled_trace = run_smc(_test_seed()) + self.assertFalse(patch_cond.called) + self.assertAllEqual( + jnp.full(resampled_trace.shape, resample), resampled_trace + ) + + def test_ess_resampling(self): + num_particles = 3 + num_timesteps = 20 + + patch_cond = self.enter_context( + mock.patch.object( + jax.lax, 'cond', autospec=True, side_effect=jax.lax.cond + ) + ) + + def kernel(smc_state, seed): + step_seed, seed = util.split_seed(seed, 2) + smc_state, extra = smc.sequential_monte_carlo_step( + smc_state, + kernel=ess_kernel, + seed=step_seed, + resampling_pred=smc.effective_sample_size_predicate, + ) + return (smc_state, seed), extra.resampled + + @jax.jit + def run_smc(seed): + _, resampled_trace = fun_mc.trace( + ( + smc.sequential_monte_carlo_init( + state=jnp.zeros((num_particles,), dtype=self._dtype), + weight_dtype=self._dtype, + ), + seed, + ), + kernel, + num_timesteps, + ) + return resampled_trace + + resampled_trace = run_smc(_test_seed()) + self.assertTrue(patch_cond.called) + + # The initial weights are zero so ESS is high so there is no resampling. + self.assertTrue(~resampled_trace[0]) + # The next weights for the next 3 steps have low ESS by design so there is + # resampling. + self.assertTrue(jnp.all(resampled_trace[1:4])) + # The weights for the rest of the steps have high ESS by design so there is + # no resampling. + self.assertTrue(jnp.all(~resampled_trace[4:])) + + @parameterized.product( + [ + dict( + max_num_timesteps=12, + num_timesteps=12, + ), + dict( + max_num_timesteps=12, + num_timesteps=10, + ), + dict( + max_num_timesteps=12, + num_timesteps=1, + ), + dict( + max_num_timesteps=1, + num_timesteps=1, + ), + ], + stop_early=[True, False], + resampling_pred=[ + always_predicate, + never_predicate, + ], + ) + def test_log_normalizing_constant( + self, + max_num_timesteps, + num_timesteps, + stop_early, + resampling_pred, + ): + num_particles = 3 + + def kernel(smc_state, seed): + step_seed, seed = util.split_seed(seed, 2) + smc_state, _ = smc.sequential_monte_carlo_step( + smc_state, + kernel=kernel_log_weights_eq_two, + resampling_pred=resampling_pred, + seed=step_seed, + ) + if stop_early: # pylint: disable=cell-var-from-loop + return (smc_state, seed), () + else: + return (smc_state, seed), smc_state.log_normalizing_constant() + + @jax.jit + def run_smc(seed): + def stop_fn(smc_state_and_seed, _): + smc_state, _ = smc_state_and_seed + return smc_state.step >= num_timesteps + + (smc_state, _), log_normalizing_constants = fun_mc.trace( + ( + smc.sequential_monte_carlo_init( + state=jnp.zeros((num_particles,), dtype=self._dtype), + weight_dtype=self._dtype, + ), + seed, + ), + kernel, # pylint: disable=cell-var-from-loop + max_num_timesteps, + stop_fn=stop_fn if stop_early else None, # pylint: disable=cell-var-from-loop + ) + if stop_early: # pylint: disable=cell-var-from-loop + return smc_state.log_normalizing_constant() + else: + return log_normalizing_constants[num_timesteps - 1] + + log_normalizing_constant = run_smc(_test_seed()) + # log_weights are 2 at each step and n step are taken, so total is n * 2. + self.assertAllClose(log_normalizing_constant, 2.0 * num_timesteps) + + @parameterized.product( + [ + dict( + max_num_timesteps=12, + num_timesteps=12, + ), + dict( + max_num_timesteps=12, + num_timesteps=10, + ), + dict( + max_num_timesteps=12, + num_timesteps=1, + ), + dict( + max_num_timesteps=1, + num_timesteps=1, + ), + ], + stop_early=[True, False], + resampling_pred=[ + always_predicate, + never_predicate, + ], + ) + def test_log_normalizing_constant_time_dependent( + self, + max_num_timesteps, + num_timesteps, + stop_early, + resampling_pred, + ): + num_particles = 3 + + def kernel(smc_state, seed): + step_seed, seed = util.split_seed(seed, 2) + smc_state, _ = smc.sequential_monte_carlo_step( + smc_state, + kernel=kernel_log_weights_eq_step, + resampling_pred=resampling_pred, + seed=step_seed, + ) + if stop_early: # pylint: disable=cell-var-from-loop + return (smc_state, seed), () + else: + return (smc_state, seed), smc_state.log_normalizing_constant() + + @jax.jit + def run_smc(seed): + def stop_fn(smc_state_and_seed, _): + smc_state, _ = smc_state_and_seed + return smc_state.step >= num_timesteps + + (smc_state, _), log_normalizing_constants = fun_mc.trace( + ( + smc.sequential_monte_carlo_init( + state=jnp.zeros((num_particles,), dtype=self._dtype), + weight_dtype=self._dtype, + ), + seed, + ), + kernel, # pylint: disable=cell-var-from-loop + max_num_timesteps, + stop_fn=stop_fn if stop_early else None, # pylint: disable=cell-var-from-loop + ) + if stop_early: # pylint: disable=cell-var-from-loop + return smc_state.log_normalizing_constant() + else: + return log_normalizing_constants[num_timesteps - 1] + + log_normalizing_constant = run_smc(_test_seed()) + # log_weights are 2 at each step and n step are taken, so total is n * 2. + self.assertAllClose( + log_normalizing_constant, jnp.sum(jnp.arange(num_timesteps)) + ) + + @parameterized.named_parameters( + ('12_steps_12_max_stop_early', 12, 12, True), + ('10_steps_12_max_stop_early', 10, 12, True), + ('1_step_10_max_stop_early', 1, 10, True), + ('1_step_1_max_stop_early', 1, 1, True), + ('12_steps_12_max', 12, 12, False), + ('10_steps_12_max', 10, 12, False), + ('1_step_10_max', 1, 10, False), + ('1_step_1_max', 1, 1, False), + ) + def test_log_normalizing_constant_no_resampling_last_step( + self, + num_timesteps, + max_num_timesteps, + stop_early, + ): + """Check log normalizing constant when not resampling only on the last step.""" + num_particles = 3 + + def resampling_pred(state): + return state.step < num_timesteps - 1 + + def kernel(smc_state, seed): + step_seed, seed = util.split_seed(seed, 2) + smc_state, extra = smc.sequential_monte_carlo_step( + smc_state, + kernel=kernel_log_weights_eq_two, + resampling_pred=resampling_pred, + seed=step_seed, + ) + return (smc_state, seed), ( + extra.resampled, + smc_state.log_normalizing_constant(), + ) + + @jax.jit + def run_smc(seed): + def stop_fn(smc_state_and_seed, _): + smc_state, _ = smc_state_and_seed + return smc_state.step >= num_timesteps + + (smc_state, _), (resampled_trace, log_normalizing_constants) = ( + fun_mc.trace( + ( + smc.sequential_monte_carlo_init( + state=jnp.zeros((num_particles,), dtype=self._dtype), + weight_dtype=self._dtype, + ), + seed, + ), + kernel, # pylint: disable=cell-var-from-loop + max_num_timesteps, + stop_fn=stop_fn if stop_early else None, # pylint: disable=cell-var-from-loop + ) + ) + if stop_early: # pylint: disable=cell-var-from-loop + log_normalizing_constant = smc_state.log_normalizing_constant() + else: + log_normalizing_constant = log_normalizing_constants[num_timesteps - 1] + return log_normalizing_constant, resampled_trace + + log_normaling_constant, resampled_trace = run_smc(_test_seed()) + + self.assertAllClose(log_normaling_constant, 2.0 * num_timesteps) + # Resampling never happens on the last step. + self.assertAllTrue(~resampled_trace[num_timesteps - 1]) + # Resampling happens on all but the last step. + self.assertAllTrue(jnp.all(resampled_trace[: num_timesteps - 1])) + + @parameterized.product( + num_timesteps=[0, 1, 9, 10, 11, 100], + max_num_timesteps=[1, 10], + ) + def test_final_timestep(self, num_timesteps, max_num_timesteps): + """Test returning the index of the first is_finished occurrence.""" + num_particles = 3 + + def kernel(smc_state, seed): + step_seed, seed = util.split_seed(seed, 2) + smc_state, extra = smc.sequential_monte_carlo_step( + smc_state, + kernel=functools.partial( + kernel_extra_is_finished, num_timesteps=num_timesteps + ), + seed=step_seed, + ) + is_finished = extra.kernel_extra + return (smc_state, seed), is_finished + + @jax.jit + def run_smc(seed): + _, is_finished_trace = fun_mc.trace( + ( + smc.sequential_monte_carlo_init( + state=jnp.zeros((num_particles,), dtype=self._dtype), + weight_dtype=self._dtype, + ), + seed, + ), + kernel, + max_num_timesteps, + ) + + # Find the index of the first occurence of `is_finished==True`. If all + # are `False`, return `max_num_timesteps - 1`. + final_timestep = jnp.where( + jnp.all(~is_finished_trace), + max_num_timesteps - 1, + jnp.argmax(is_finished_trace), + ) + return final_timestep + + final_timestep = run_smc(_test_seed()) + + if num_timesteps >= max_num_timesteps: + self.assertEqual(final_timestep, max_num_timesteps - 1) + else: + self.assertEqual(final_timestep, max(num_timesteps - 1, 0)) + + @parameterized.parameters(True, False) + def test_some_neg_inf_weights(self, stop_early): + """Test that SMC correctly handles some weights of negative infinity.""" + num_particles = 4 + num_timesteps = 3 + max_num_timesteps = 3 + + def kernel(smc_state, seed): + step_seed, seed = util.split_seed(seed, 2) + smc_state, extra = smc.sequential_monte_carlo_step( + smc_state, + kernel=kernel_log_weights_eq_neg_inf_if_state_lt_zero, + resampling_pred=always_predicate, + seed=step_seed, + ) + return (smc_state, seed), { + 'state': smc_state.state, + 'log_weights': smc_state.log_weights, + 'ancestor_idxs': extra.ancestor_idxs, + 'resampled': extra.resampled, + 'log_normalizing_constant': smc_state.log_normalizing_constant(), + } + + @jax.jit + def run_smc(seed): + def stop_fn(smc_state_and_seed, _): + smc_state, _ = smc_state_and_seed + return smc_state.step >= num_timesteps + + _, trace = fun_mc.trace( + ( + smc.sequential_monte_carlo_init( + state=jnp.arange(num_particles, dtype=self._dtype), + weight_dtype=self._dtype, + ), + seed, + ), + kernel, + max_num_timesteps, + stop_fn=stop_fn if stop_early else None, + ) + return trace + + trace = run_smc(_test_seed()) + + # assert that the weights of the non-negative particles are finite + self.assertTrue( + jnp.all( + jnp.equal( + jnp.isfinite(trace['log_weights'][:num_timesteps]), + jnp.greater_equal(trace['state'][:num_timesteps], 0.0), + ) + ) + ) + + # Assert that the weights of the negative particles are -inf + self.assertTrue( + jnp.all( + jnp.equal( + jnp.isneginf(trace['log_weights'][:num_timesteps]), + jnp.less(trace['state'][:num_timesteps], 0.0), + ) + ) + ) + + # Assert that the weights don't all become negative inf at once. + for i in range(num_timesteps): + self.assertTrue( + jnp.any(jnp.isfinite(trace['log_weights'][i])), msg=f'step={i}' + ) + + # Assert that -inf weights are never selected for resampling + for i in range(1, num_timesteps): + ancestor_inds = trace['ancestor_idxs'][i] + ancestor_weights = trace['log_weights'][i - 1][ancestor_inds] + self.assertTrue( + jnp.all(ancestor_weights > -float('inf')), msg=f'step={i}' + ) + + # Assert that the log normalizing constant is finite + self.assertTrue( + jnp.isfinite(trace['log_normalizing_constant'][num_timesteps - 1]) + ) + + @parameterized.parameters(True, False) + def test_all_neg_inf_weights(self, stop_early): + """Test handling of weights *all* becoming negative infinity.""" + num_particles = 4 + num_timesteps = 5 + max_num_timesteps = 7 + + def kernel(smc_state, seed): + step_seed, seed = util.split_seed(seed, 2) + smc_state, extra = smc.sequential_monte_carlo_step( + smc_state, + kernel=kernel_log_weights_eq_neg_inf_if_state_lt_zero, + resampling_pred=always_predicate, + seed=step_seed, + ) + return (smc_state, seed), { + 'state': smc_state.state, + 'log_weights': smc_state.log_weights, + 'ancestor_idxs': extra.ancestor_idxs, + 'resampled': extra.resampled, + 'log_normalizing_constant': smc_state.log_normalizing_constant(), + } + + @jax.jit + def run_smc(seed): + def stop_fn(smc_state_and_seed, _): + smc_state, _ = smc_state_and_seed + return smc_state.step >= num_timesteps + + _, trace = fun_mc.trace( + ( + smc.sequential_monte_carlo_init( + state=jnp.arange(num_particles, dtype=self._dtype), + weight_dtype=self._dtype, + ), + seed, + ), + kernel, + max_num_timesteps, + stop_fn=stop_fn if stop_early else None, + ) + return trace + + trace = run_smc(_test_seed()) + + # Assert that the weights of the non-negative states are finite + self.assertTrue( + jnp.all( + jnp.equal( + jnp.isfinite(trace['log_weights'][:num_timesteps]), + jnp.greater_equal(trace['state'][:num_timesteps], 0.0), + ) + ) + ) + + # Assert that the weights of the negative states are -inf + self.assertTrue( + jnp.all( + jnp.equal( + jnp.isneginf(trace['log_weights'][:num_timesteps]), + jnp.less(trace['state'][:num_timesteps], 0.0), + ) + ) + ) + + # Assert that the weights eventually all become negative inf at once. + self.assertTrue( + jnp.any( + jnp.logical_not( + jnp.any( + jnp.isfinite(trace['log_weights'][:num_timesteps]), axis=1 + ) + ) + ) + ) + + # Assert that -inf weights are never selected for resampling, if there's a + # choice not to pick them. + for i in range(1, num_timesteps): + ancestor_inds = trace['ancestor_idxs'][i] + if not jnp.any(trace['log_weights'][i - 1] > -float('inf')): + continue + ancestor_weights = trace['log_weights'][i - 1][ancestor_inds] + self.assertTrue( + jnp.all(ancestor_weights > -float('inf')), msg=f'step={i}' + ) + + # Assert that the bound is negative inf + self.assertTrue( + jnp.isneginf(trace['log_normalizing_constant'][num_timesteps - 1]) + ) + + @parameterized.parameters(True, False) + def test_some_neg_inf_weights_no_resampling(self, stop_early): + """Test handling some negative inf weights without resampling.""" + num_particles = 4 + num_timesteps = 3 + max_num_timesteps = 3 + + def kernel(smc_state, seed): + step_seed, seed = util.split_seed(seed, 2) + smc_state, _ = smc.sequential_monte_carlo_step( + smc_state, + kernel=kernel_log_weights_eq_neg_inf_if_state_lt_zero, + resampling_pred=never_predicate, + seed=step_seed, + ) + return (smc_state, seed), { + 'state': smc_state.state, + 'log_weights': smc_state.log_weights, + 'log_normalizing_constant': smc_state.log_normalizing_constant(), + } + + @jax.jit + def run_smc(seed): + def stop_fn(smc_state_and_seed, _): + smc_state, _ = smc_state_and_seed + return smc_state.step >= num_timesteps + + _, trace = fun_mc.trace( + ( + smc.sequential_monte_carlo_init( + state=jnp.arange(num_particles, dtype=self._dtype), + weight_dtype=self._dtype, + ), + seed, + ), + kernel, + max_num_timesteps, + stop_fn=stop_fn if stop_early else None, + ) + return trace + + trace = run_smc(_test_seed()) + + # Assert that the weights of the non-negative states are finite + self.assertTrue( + jnp.all( + jnp.equal( + jnp.isfinite(trace['log_weights'][:num_timesteps]), + jnp.greater_equal(trace['state'][:num_timesteps], 0.0), + ) + ) + ) + + # Assert that the weights of the negative states are -inf + self.assertTrue( + jnp.all( + jnp.equal( + jnp.isneginf(trace['log_weights'][:num_timesteps]), + jnp.less(trace['state'][:num_timesteps], 0.0), + ) + ) + ) + + # Assert that the weights don't all become negative inf at once. + for i in range(num_timesteps): + self.assertTrue( + jnp.any(jnp.isfinite(trace['log_weights'][i])), msg=f'step={i}' + ) + + # Assert that the log normalizing constant is finite + self.assertTrue( + jnp.isfinite(trace['log_normalizing_constant'][num_timesteps - 1]) + ) + + @parameterized.parameters(True, False) + def test_all_neg_inf_weights_no_resampling(self, stop_early): + """Test handling *all* negative inf weights without resampling.""" + num_particles = 4 + num_timesteps = 5 + max_num_timesteps = 5 + + def kernel(smc_state, seed): + step_seed, seed = util.split_seed(seed, 2) + smc_state, _ = smc.sequential_monte_carlo_step( + smc_state, + kernel=kernel_log_weights_eq_neg_inf_if_state_lt_zero, + resampling_pred=never_predicate, + seed=step_seed, + ) + return (smc_state, seed), { + 'state': smc_state.state, + 'log_weights': smc_state.log_weights, + 'log_normalizing_constant': smc_state.log_normalizing_constant(), + } + + @jax.jit + def run_smc(seed): + def stop_fn(smc_state_and_seed, _): + smc_state, _ = smc_state_and_seed + return smc_state.step >= num_timesteps + + _, trace = fun_mc.trace( + ( + smc.sequential_monte_carlo_init( + state=jnp.arange(num_particles, dtype=self._dtype), + weight_dtype=self._dtype, + ), + seed, + ), + kernel, + max_num_timesteps, + stop_fn=stop_fn if stop_early else None, + ) + return trace + + trace = run_smc(_test_seed()) + + # Assert that the weights of the non-negative states are finite + self.assertTrue( + jnp.all( + jnp.equal( + jnp.isfinite(trace['log_weights'][:num_timesteps]), + jnp.greater_equal(trace['state'][:num_timesteps], 0.0), + ) + ) + ) + + # Assert that the weights of the negative states are -inf + self.assertTrue( + jnp.all( + jnp.equal( + jnp.isneginf(trace['log_weights'][:num_timesteps]), + jnp.less(trace['state'][:num_timesteps], 0.0), + ) + ) + ) + + # Assert that the weights eventually all become negative inf at once. + self.assertTrue( + jnp.any( + jnp.logical_not( + jnp.any( + jnp.isfinite(trace['log_weights'][:num_timesteps]), axis=1 + ) + ) + ) + ) + + # Assert that the bound is negative inf + self.assertTrue( + jnp.isneginf(trace['log_normalizing_constant'][num_timesteps - 1]) + ) + + @parameterized.product( + resampling_pred=[ + always_predicate, + never_predicate, + smc.effective_sample_size_predicate, + ], + num_timesteps=[1, 5], + ) + def test_normalizing_const_gaussian(self, resampling_pred, num_timesteps): + """Check that the normalizing constant estimate of SMC is correct. + + This test chooses the unnormalized targets, s_t, as + + s_t(x) = prod_{r=1}^t exp(-(x_r^2)/2) + + i.e. the unnormalized targets are a sequence of independent Gaussian + potentials with mean 0 and variance 1. Specifically, each potential does not + include the 1/sqrt(2*pi) term that would be required to normalize a standard + Gaussian. Those terms end up being collected into the computed normalizing + constant, i.e. Z should be + + (2 * pi)^(t/2) + + for each timestep t. + + The chosen sequence of unnormalized targets implies that the incremental + weight should be + + s_t(x_{1:t}) / (s_{t-1}(x_{1:t-1}) * q_t(x_t)) + = exp(-(x_t^2)/2) / q_t(x_t). + + In log space this is + + - x_t^2 / 2 - log q_t(x_t). + + Args: + resampling_pred: Resampling predicate. + num_timesteps: The number of steps to run SMC for. + """ + num_particles = 2_000 * num_timesteps + q_std = 1.25 + + def smc_kernel(state, step, seed): + del state, step + q_dist = tfd.Normal( + jnp.zeros([num_particles], self._dtype), + jnp.full([num_particles], q_std, dtype=self._dtype), + ) + new_x = q_dist.sample(seed=seed) + log_q = q_dist.log_prob(new_x) + log_p = -jnp.square(new_x) / 2 + return new_x, (log_p - log_q, ()) + + def kernel(smc_state, seed): + step_seed, seed = util.split_seed(seed, 2) + smc_state, _ = smc.sequential_monte_carlo_step( + smc_state, + kernel=smc_kernel, + resampling_pred=resampling_pred, + seed=step_seed, + ) + return (smc_state, seed), () + + @jax.jit + def run_smc(seed): + (smc_state, _), _ = fun_mc.trace( + ( + smc.sequential_monte_carlo_init( + state=jnp.zeros((num_particles,), self._dtype), + weight_dtype=self._dtype, + ), + seed, + ), + kernel, + num_timesteps, + ) + return smc_state.log_normalizing_constant() + + log_normalizing_constant = run_smc(_test_seed()) + log_normalizing_constant_expected = (num_timesteps / 2.0) * ( + jnp.log(2) + jnp.log(jnp.pi) + ) + self.assertAllClose( + log_normalizing_constant, log_normalizing_constant_expected, atol=1e-2 + ) + + def test_lgssm(self): + """Check filtering distributions and the normalizing constant of a linear Gaussian state-space model.""" + num_timesteps = 10 + init_loc = 0.0 + init_scale = 1.0 + transition_mult = 1.0 + transition_add = 0.0 + transition_scale = 1.0 + obs_mult = 1.0 + obs_add = 0.0 + obs_scale = 1.0 + + obs_seed, smc_seed = util.split_seed(_test_seed(), 2) + obs = util.random_normal((num_timesteps,), self._dtype, obs_seed) + + def smc_kernel(state, step, seed): + # Sample from the initial distribution (when prev_timestep = 0) + init = tfd.Normal( + jnp.full_like(state, init_loc, dtype=self._dtype), + jnp.full_like(state, init_scale, dtype=self._dtype), + ).sample(seed=seed) + + # Sample from the transition distribution (when step > 0) + not_init = tfd.Normal( + state * transition_mult + transition_add, + transition_scale, + ).sample(seed=seed) + new_particles = jnp.where(step == 0, init, not_init) + incremental_log_weights = tfd.Normal( + new_particles * obs_mult + obs_add, obs_scale + ).log_prob(obs[step]) + return new_particles, (incremental_log_weights, ()) + + @jax.jit + def particle_expectation(state, log_weights): + # Assumes state's shape is [num_particles, ...] + return jnp.einsum('i...,i->...', state, jax.nn.softmax(log_weights)) + + def kernel(smc_state, seed): + step_seed, seed = util.split_seed(seed, 2) + smc_state, _ = smc.sequential_monte_carlo_step( + smc_state, + kernel=smc_kernel, + resampling_pred=smc.effective_sample_size_predicate, + seed=step_seed, + ) + state = smc_state.state + log_weights = smc_state.log_weights + filtering_mean = particle_expectation(state, log_weights) + filtering_std = jnp.sqrt( + particle_expectation(state**2, log_weights) - filtering_mean**2 + ) + + return (smc_state, seed), (filtering_mean, filtering_std) + + # SMC bootstrap filtering posterior + num_particles = 500 + + (smc_state, _), (filtering_means, filtering_stds) = fun_mc.trace( + ( + smc.sequential_monte_carlo_init( + state=jnp.full( + [num_particles], float('NaN'), dtype=self._dtype + ), + weight_dtype=self._dtype, + ), + smc_seed, + ), + kernel, + num_timesteps, + ) + + # Ground truth filtering posterior + tfp_lgssm = tfd.LinearGaussianStateSpaceModel( + num_timesteps=num_timesteps, + transition_matrix=jnp.array([[transition_mult]]), + transition_noise=tfd.MultivariateNormalDiag( + jnp.array([transition_add]), + scale_diag=jnp.array([transition_scale]), + ), + observation_matrix=jnp.array([[obs_mult]]), + observation_noise=tfd.MultivariateNormalDiag( + jnp.array([obs_add]), jnp.array([obs_scale]) + ), + initial_state_prior=tfd.MultivariateNormalDiag( + jnp.array([init_loc]), jnp.array([init_scale]) + ), + ) + + gt_filter_results = tfp_lgssm.forward_filter(obs[..., jnp.newaxis]) + gt_filtering_means = gt_filter_results.filtered_means[:, 0] + gt_filtering_stds = jnp.sqrt(gt_filter_results.filtered_covs[:, 0, 0]) + + # SMC log marginal likelihood + log_evidence = smc_state.log_normalizing_constant() + + # Ground truth log evidence + gt_log_evidence = jnp.sum(gt_filter_results.log_likelihoods) + + self.assertAllClose(gt_filtering_means, filtering_means, atol=0.2) + self.assertAllClose(gt_filtering_stds, filtering_stds, atol=0.2) + self.assertAllClose(gt_log_evidence, log_evidence, rtol=0.01) + self.assertAllClose(gt_log_evidence, log_evidence, atol=0.2) + @test_util.multi_backend_test(globals(), 'smc_test') class SMCTest32(SMCTest):