Skip to content

Commit

Permalink
Merge pull request #127 from danielward27/sample_wrapper
Browse files Browse the repository at this point in the history
numpyro sample wrapper
  • Loading branch information
danielward27 authored Jan 12, 2024
2 parents 5248aa1 + 49df452 commit ffcb557
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 20 deletions.
28 changes: 19 additions & 9 deletions docs/api/experimental.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,27 +5,37 @@ Interfacing with numpyro
--------------------------

Supporting complex inference approaches such as MCMC or variational inference
with arbitrary probabilistic models is out of the scope of this package. However, we do
provide an (experimental) wrapper class,
:class:`~flowjax.experimental.numpyro.TransformedToNumpyro`, which will wrap
a flowjax :class:`~flowjax.distributions.AbstractTransformed` distribution, into a
`numpyro <https://github.com/pyro-ppl/numpyro>`_ distribution.
with arbitrary probabilistic models is out of the scope of this package. However,
we do provide some basic suppport for interfacing with numpyro. We note this support is
in its infancy and there may be breaking changes without warning.

.. warning::
Batch dimensions are handled differently for flowjax distributions and numpyro
distributions. In flowjax we do not make a clear distinction between
event shapes and batch shapes. Hence, when a flowjax distribution is converted to a
numpyro distribution, we assume its shape corresponds to the event shape.

In general, we can use a combination of flowjax and numpyro distributions in a
numpyro model by using :func:`~flowjax.experimental.numpyro.sample`, in place of
numpyro's ``sample``. This will wrap flowjax
:class:`~flowjax.distributions.AbstractTransformed` distributions to numpyro
distributions, using :class:`~flowjax.experimental.numpyro.TransformedToNumpyro`.
This can be used for example to embed normalising flows into arbitrary
probabilistic models. Here is a simple example

.. doctest::


>>> from numpyro.infer import MCMC, NUTS
>>> from flowjax.experimental.numpyro import TransformedToNumpyro
>>> from numpyro import sample
>>> from flowjax.experimental.numpyro import sample
>>> from flowjax.distributions import Normal
>>> import jax.random as jr
>>> import numpy as np

>>> def numpyro_model(X, y):
... "Example regression model defined in terms of flowjax distributions"
... beta = sample("beta", TransformedToNumpyro(Normal(np.zeros(2))))
... sample("y", TransformedToNumpyro(Normal(X @ beta)), obs=y)
... beta = sample("beta", Normal(np.zeros(2)))
... sample("y", Normal(X @ beta), obs=y)

>>> X = np.random.randn(100, 2)
>>> beta_true = np.array([-1, 1])
Expand Down
20 changes: 19 additions & 1 deletion flowjax/experimental/numpyro.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
raise

from numpyro.distributions import constraints

from flowjax.bijections import AbstractBijection
from flowjax.distributions import AbstractTransformed
from flowjax.utils import _get_ufunc_signature
Expand Down Expand Up @@ -138,6 +137,25 @@ def _base_condition(self):
return self.condition if self.dist.base_dist.cond_shape else None


def sample(name: str, fn: Any, *args, condition=None, **kwargs):
"""Numpyro sample wrapper that wraps flowjax AbstractTransformed distributions.
Args:
name: Name of the sample site.
fn: A flowjax distribution, numpyro distribution or a stochastic function that
returns a sample.
condition: Conditioning variable if fn is a conditional flowjax distribution.
Defaults to None.
*args: Passed to numpyro sample.
**kwargs: Passed to numpyro sample.
"""

if isinstance(fn, AbstractTransformed):
fn = TransformedToNumpyro(fn, condition)

return numpyro.sample(name, fn, *args, **kwargs)


def register_params(
name: str,
model: PyTree,
Expand Down
19 changes: 9 additions & 10 deletions tests/test_experimental/test_numpyro.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import pytest
from equinox.nn import Linear
from jax.flatten_util import ravel_pytree
from numpyro import sample
from flowjax.experimental.numpyro import sample
from numpyro.infer import MCMC, NUTS, SVI, Trace_ELBO
from numpyro.optim import Adam

Expand All @@ -21,7 +21,7 @@


def numpyro_model():
sample("x", TransformedToNumpyro(Normal(true_mean, true_std)))
sample("x", Normal(true_mean, true_std))


def test_mcmc():
Expand All @@ -39,7 +39,7 @@ def plate_model():
# We could add support in the numpyro context later if we wish. Note below,
# the plate dim is -1 (as the flowjax normal has event dim (2,)).
with numpyro.plate("obs", 10, dim=-1):
sample("x", TransformedToNumpyro(Normal(true_mean, true_std)))
sample("x", Normal(true_mean, true_std))

mcmc = MCMC(NUTS(plate_model), num_warmup=50, num_samples=500) # 2d N(1, 2)
key, subkey = jr.split(key)
Expand All @@ -56,7 +56,6 @@ def test_vi():

def guide(dist):
dist = register_params("guide", dist)
dist = TransformedToNumpyro(dist)
sample("x", dist)

optimizer = Adam(step_size=0.01)
Expand Down Expand Up @@ -94,12 +93,12 @@ def test_conditional_vi():

def model():
cond = sample("cond", ndist.Normal(jnp.zeros((3,))))
sample("x", TransformedToNumpyro(true_dist, cond))
sample("x", true_dist, condition=cond)

def guide(guide_dist):
guide_dist = register_params("guide", guide_dist)
cond = sample("cond", ndist.Normal(jnp.zeros((3,))))
sample("x", TransformedToNumpyro(guide_dist, cond))
sample("x", guide_dist, condition=cond)

optimizer = Adam(step_size=0.01)
svi = SVI(model, partial(guide, guide_dist), optimizer, loss=Trace_ELBO())
Expand All @@ -122,14 +121,14 @@ def test_vi_plate():

def model():
with numpyro.plate("obs", plate_dim):
sample("x", TransformedToNumpyro(Normal(true_mean, true_std)))
sample("x", Normal(true_mean, true_std))

guide_dist = Normal(jnp.ones_like(true_mean), jnp.ones_like(true_std))

def guide(guide_dist):
guide = register_params("guide", guide_dist)
with numpyro.plate("obs", plate_dim):
sample("x", TransformedToNumpyro(guide))
sample("x", guide)

guide = partial(guide, guide_dist)
optimizer = Adam(step_size=0.01)
Expand Down Expand Up @@ -186,13 +185,13 @@ def test_batched_condition():
def model():
with numpyro.plate("N", 10, dim=-2):
cond = sample("cond", ndist.Normal(jnp.zeros(3)).to_event())
sample("x", TransformedToNumpyro(true_dist, cond))
sample("x", true_dist, condition=cond)

def guide(guide_dist):
guide_dist = register_params("guide", guide_dist)
with numpyro.plate("N", 10, dim=-2):
cond = sample("cond", ndist.Normal(jnp.zeros(3)).to_event())
sample("x", TransformedToNumpyro(guide_dist, cond))
sample("x", guide_dist, condition=cond)

optimizer = Adam(step_size=0.01)
svi = SVI(model, partial(guide, guide_dist), optimizer, loss=Trace_ELBO())
Expand Down

0 comments on commit ffcb557

Please sign in to comment.