diff --git a/docs/api/training.rst b/docs/api/training.rst
index 88383c91..f04f2a84 100644
--- a/docs/api/training.rst
+++ b/docs/api/training.rst
@@ -10,3 +10,8 @@ Alternatively, we can use ``fit_to_variational_target`` to fit the flow to a fun
 using variational inference.
 
 .. autofunction:: flowjax.train.fit_to_variational_target
+
+Finally, for more control over the training loop, you may find the ``step``
+function useful.
+
+.. autofunction:: flowjax.train.step
diff --git a/docs/api/wrappers.rst b/docs/api/wrappers.rst
index d8122555..f523b4cc 100644
--- a/docs/api/wrappers.rst
+++ b/docs/api/wrappers.rst
@@ -3,3 +3,4 @@ Wrappers
 .. automodule:: flowjax.wrappers
    :members:
    :undoc-members:
+   :member-order: bysource
\ No newline at end of file
diff --git a/flowjax/bijections/affine.py b/flowjax/bijections/affine.py
index cb46bdbd..497a361e 100644
--- a/flowjax/bijections/affine.py
+++ b/flowjax/bijections/affine.py
@@ -146,9 +146,7 @@ def __init__(
         lower: bool = True,
     ):
         loc, arr = (arraylike_to_array(a, dtype=float) for a in (loc, arr))
-        if (arr.ndim != 2) or (
-            arr.shape[0] != arr.shape[1]
-        ):  # TODO unnecersary if beartype enabled
+        if (arr.ndim != 2) or (arr.shape[0] != arr.shape[1]):
             raise ValueError("arr must be a square, 2-dimensional matrix.")
 
         dim = arr.shape[0]
diff --git a/flowjax/train/__init__.py b/flowjax/train/__init__.py
index 575ba878..ea2fd1be 100644
--- a/flowjax/train/__init__.py
+++ b/flowjax/train/__init__.py
@@ -1,8 +1,11 @@
-"""Utilities for training flows, fitting to samples or ysing variational inference."""
+"""Utilities for training flows, fitting to samples or using variational inference."""
+
 from .data_fit import fit_to_data
+from .train_utils import step
 from .variational_fit import fit_to_variational_target
 
 __all__ = [
     "fit_to_data",
     "fit_to_variational_target",
+    "step",
 ]
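The newly exported ``step`` function is intended as a building block for custom
training loops. This patch does not show its signature, so the ``train_step``
below is a hypothetical stand-in sketching the usual equinox/optax pattern such
a step follows; consult the ``flowjax.train.step`` autodoc for the real
interface.

    import equinox as eqx
    import optax


    @eqx.filter_jit
    def train_step(params, static, x, *, optimizer, opt_state, loss_fn):
        # Differentiate the loss with respect to the trainable (array) partition only.
        loss_val, grads = eqx.filter_value_and_grad(loss_fn)(params, static, x)
        updates, opt_state = optimizer.update(grads, opt_state, params)
        params = eqx.apply_updates(params, updates)
        return params, opt_state, loss_val


    # Hypothetical usage, assuming a flow and a batch of samples x:
    # params, static = eqx.partition(flow, eqx.is_inexact_array)
    # optimizer = optax.adam(1e-3)
    # opt_state = optimizer.init(params)
    # params, opt_state, loss = train_step(
    #     params, static, x, optimizer=optimizer, opt_state=opt_state,
    #     loss_fn=lambda p, s, x: -eqx.combine(p, s).log_prob(x).mean(),
    # )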
diff --git a/flowjax/wrappers.py b/flowjax/wrappers.py
index 818f26cc..d2e38a0a 100644
--- a/flowjax/wrappers.py
+++ b/flowjax/wrappers.py
@@ -14,18 +14,19 @@
   ``sample_and_log_prob``.
 * Prior to computing the loss functions.
 
-If implementing a custom unwrappable, bear in mind:
-* The wrapper should avoid implementing information or logic beyond what is required
-  for initialization and unwrapping, as this information will be lost when unwrapping.
-* The unwrapping should support broadcasting/vmapped initializations. Otherwise, if
-  the unwrappable is created within a batched context, it will fail to unwrap
-  correctly.
+.. note::
+
+    If creating a custom unwrappable, remember that unwrapping will generally occur
+    after initialization of the model. Because of this, we recommend ensuring that
+    the ``unwrap`` method supports unwrapping if the model is constructed in a
+    vectorized context, such as ``eqx.filter_vmap``, e.g. through broadcasting or
+    vectorization.
 """
 
 from abc import abstractmethod
 from collections.abc import Callable, Iterable
-from typing import Any, ClassVar, Generic, TypeVar
+from typing import Any, Generic, TypeVar
 
 import equinox as eqx
 import jax
@@ -40,11 +41,29 @@
 
 
 def unwrap(tree: PyTree):
-    """Unwrap all :class:`AbstractUnwrappable` nodes within a pytree."""
+    """Recursively unwraps all :class:`AbstractUnwrappable` nodes within a pytree.
+
+    This leaves all other nodes unchanged. If nested, the innermost
+    ``AbstractUnwrappable`` nodes are unwrapped first.
+
+    Example:
+        >>> from flowjax.wrappers import Parameterize, unwrap
+        >>> import jax.numpy as jnp
+        >>> params = Parameterize(jnp.exp, jnp.zeros(3))
+        >>> unwrap(("abc", 1, params))
+        ('abc', 1, Array([1., 1., 1.], dtype=float32))
+    """
+
+    def _map_fn(leaf):
+        if isinstance(leaf, AbstractUnwrappable):
+            # Flatten one level so contained AbstractUnwrappables unwrap first
+            flat, tree_def = eqx.tree_flatten_one_level(leaf)
+            tree = jax.tree_util.tree_unflatten(tree_def, unwrap(flat))
+            return tree.unwrap()
+        return leaf
+
     return jax.tree_util.tree_map(
-        f=lambda leaf: (
-            leaf.recursive_unwrap() if isinstance(leaf, AbstractUnwrappable) else leaf
-        ),
+        f=_map_fn,
         tree=tree,
         is_leaf=lambda x: isinstance(x, AbstractUnwrappable),
     )
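The innermost-first behaviour documented in the docstring above can be
exercised directly; a minimal sketch using the public API (mirroring
``test_nested_Parameterized`` in the test changes below):

    import jax.numpy as jnp

    from flowjax.wrappers import Parameterize, unwrap

    # The inner Parameterize is unwrapped before the outer fn is called.
    nested = Parameterize(jnp.square, Parameterize(jnp.square, jnp.asarray(2.0)))
    assert unwrap(nested) == 16.0  # square(square(2)) == 16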
@@ -57,37 +76,54 @@ class AbstractUnwrappable(eqx.Module, Generic[T]):
     behaviour to apply upon unwrapping before use. This can be used e.g. to apply
-    parameter constraints, such as making scale parameters postive, or applying
+    parameter constraints, such as making scale parameters positive, or applying
     stop_gradient before accessing the parameters.
-
-    If ``_dummy`` is set to an array (must have shape ()), this is used for inferring
-    vmapped dimensions (and sizes) when calling :func:`unwrap` to automatically
-    vecotorize the method. In some cases this is important for supporting the case where
-    an :class:`AbstractUnwrappable` is created within e.g. ``eqx.filter_vmap``.
     """
 
-    _dummy: eqx.AbstractVar[Int[Scalar, ""] | None]
+    @abstractmethod
+    def unwrap(self) -> T:
+        """Returns the unwrapped pytree, assuming no wrapped subnodes exist."""
+        pass
 
-    def recursive_unwrap(self) -> T:
-        """Returns the unwrapped pytree, unwrapping subnodes as required."""
 
-        def vectorized_unwrap(unwrappable):
-            if unwrappable._dummy is None:
-                return unwrappable.unwrap()
+class Parameterize(AbstractUnwrappable[T]):
+    """Unwrap an object by calling fn with args and kwargs.
 
-            def v_unwrap(unwrappable):
-                return unwrappable.unwrap()
+    All of fn, args and kwargs may contain trainable parameters. If the Parameterize is
+    created within ``eqx.filter_vmap``, unwrapping is automatically vectorized
+    correctly, as long as the vmapped constructor adds leading batch
+    dimensions to all arrays (the default for ``eqx.filter_vmap``).
 
-            for dim in reversed(unwrappable._dummy.shape):
-                v_unwrap = eqx.filter_vmap(v_unwrap, axis_size=dim)
-            return v_unwrap(unwrappable)
+    Example:
+        >>> from flowjax.wrappers import Parameterize, unwrap
+        >>> import jax.numpy as jnp
+        >>> positive = Parameterize(jnp.exp, jnp.zeros(3))
+        >>> unwrap(positive)  # Applies exp on unwrapping
+        Array([1., 1., 1.], dtype=float32)
 
-        flat, tree_def = eqx.tree_flatten_one_level(self)
-        tree = jax.tree_util.tree_unflatten(tree_def, unwrap(flat))
-        return vectorized_unwrap(tree)
+    Args:
+        fn: Callable to call with args and kwargs.
+        *args: Positional arguments to pass to fn.
+        **kwargs: Keyword arguments to pass to fn.
+    """
+
+    fn: Callable[..., T]
+    args: Iterable
+    kwargs: dict[str, Any]
+    _dummy: Int[Scalar, ""]  # Used to track vectorized construction.
+
+    def __init__(self, fn: Callable, *args, **kwargs):
+        self.fn = fn
+        self.args = args
+        self.kwargs = kwargs
+        self._dummy = jnp.empty((), int)
 
-    @abstractmethod
     def unwrap(self) -> T:
-        """Returns the unwrapped pytree, assuming no wrapped subnodes exist."""
-        pass
+
+        def _unwrap_fn(unwrappable):
+            return unwrappable.fn(*unwrappable.args, **unwrappable.kwargs)
+
+        for dim in reversed(self._dummy.shape):  # vectorize if constructed under vmap
+            _unwrap_fn = eqx.filter_vmap(_unwrap_fn, axis_size=dim)
+        return _unwrap_fn(self)
 
 
 class NonTrainable(AbstractUnwrappable[T]):
@@ -103,7 +139,6 @@ class NonTrainable(AbstractUnwrappable[T]):
     """
 
     tree: T
-    _dummy: ClassVar[None] = None
 
     def unwrap(self) -> T:
         differentiable, static = eqx.partition(self.tree, eqx.is_array_like)
@@ -130,35 +165,6 @@ def _map_fn(leaf):
     )
 
 
-class Parameterize(AbstractUnwrappable[T]):
-    """Unwrap an object by calling fn with args and kwargs.
-
-    All of fn, args and kwargs may contain trainable parameters. If the Parameterize is
-    created within ``eqx.filter_vmap``, unwrapping is automatically vectorized
-    correctly, as long as the vmapped constructor adds leading batch
-    dimensions to all arrays (the default for ``eqx.filter_vmap``).
-
-    Args:
-        fn: Callable to call with args, and kwargs.
-        *args: Positional arguments to pass to fn.
-        **kwargs: Keyword arguments to pass to fn.
-    """
-
-    fn: Callable[..., T]
-    args: Iterable
-    kwargs: dict[str, Any]
-    _dummy: Int[Scalar, ""]
-
-    def __init__(self, fn: Callable, *args, **kwargs):
-        self.fn = fn
-        self.args = args
-        self.kwargs = kwargs
-        self._dummy = jnp.empty((), int)
-
-    def unwrap(self) -> T:
-        return self.fn(*self.args, **self.kwargs)
-
-
 class WeightNormalization(AbstractUnwrappable[Array]):
     """Applies weight normalization (https://arxiv.org/abs/1602.07868).
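To illustrate the contract left by the abstract ``unwrap`` method: a custom
unwrappable only needs to declare its fields and implement ``unwrap``. The
``Exp`` class below is a hypothetical example, not part of flowjax; being
elementwise, it satisfies the broadcasting recommendation from the module
docstring note:

    import jax.numpy as jnp
    from jaxtyping import Array

    from flowjax.wrappers import AbstractUnwrappable, unwrap


    class Exp(AbstractUnwrappable[Array]):
        """Hypothetical unwrappable applying exp on unwrap, e.g. for positivity."""

        arr: Array

        def unwrap(self) -> Array:
            # Elementwise, so it broadcasts correctly under vmapped construction.
            return jnp.exp(self.arr)


    assert unwrap(Exp(jnp.zeros(3))).tolist() == [1.0, 1.0, 1.0]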
@@ -168,7 +174,6 @@ class WeightNormalization(AbstractUnwrappable[Array]):
 
     weight: Array | AbstractUnwrappable[Array]
     scale: Array | AbstractUnwrappable[Array] = eqx.field(init=False)
-    _dummy: ClassVar[None] = None
 
     def __init__(self, weight: Array | AbstractUnwrappable[Array]):
         self.weight = weight
diff --git a/tests/test_wrappers.py b/tests/test_wrappers.py
index 8420bdd2..807403ab 100644
--- a/tests/test_wrappers.py
+++ b/tests/test_wrappers.py
@@ -1,8 +1,11 @@
+from math import prod
+
 import equinox as eqx
 import jax
 import jax.numpy as jnp
 import jax.random as jr
 import pytest
+from jax.tree_util import tree_map
 
 from flowjax.distributions import Normal
 from flowjax.wrappers import (
@@ -18,24 +21,6 @@ def test_Parameterize():
     diag = Parameterize(jnp.diag, jnp.ones(3))
     assert pytest.approx(jnp.eye(3)) == unwrap(diag)
 
-    # Test works when vmapped (note diag does not follow standard vectorization rules)
-    v_diag = eqx.filter_vmap(Parameterize)(jnp.diag, jnp.ones((4, 3)))
-    expected = eqx.filter_vmap(jnp.eye, axis_size=4)(3)
-    assert pytest.approx(expected) == unwrap(v_diag)
-
-    # Test works when double vmapped
-    v_diag = eqx.filter_vmap(eqx.filter_vmap(Parameterize))(
-        jnp.diag, jnp.ones((5, 4, 3))
-    )
-    expected = eqx.filter_vmap(eqx.filter_vmap(jnp.eye, axis_size=4), axis_size=5)(3)
-    assert pytest.approx(expected) == unwrap(v_diag)
-
-    # Test works when no arrays present (in which case axis_size is relied on)
-    unwrappable = eqx.filter_vmap(
-        eqx.filter_vmap(Parameterize, axis_size=2), axis_size=3
-    )(lambda: jnp.zeros(()))
-    assert pytest.approx(unwrap(unwrappable)) == jnp.zeros((3, 2))
-
 
 def test_nested_Parameterized():
     param = Parameterize(
@@ -45,16 +30,14 @@ def test_nested_Parameterized():
     assert unwrap(param) == jnp.square(jnp.square(jnp.square(2)))
 
 
-def test_NonTrainable_and_non_trainable():
-    dist1 = eqx.tree_at(lambda dist: dist.bijection, Normal(), replace_fn=NonTrainable)
-    dist2 = non_trainable(Normal())
+def test_non_trainable():
+    dist = non_trainable(Normal())
 
     def loss(dist, x):
         return dist.log_prob(x)
 
-    for dist in [dist1, dist2]:
-        grad = eqx.filter_grad(loss)(dist, 1)
-        assert pytest.approx(0) == jax.flatten_util.ravel_pytree(grad)[0]
+    grad = eqx.filter_grad(loss)(dist, 1)
+    assert pytest.approx(0) == jax.flatten_util.ravel_pytree(grad)[0]
 
 
 def test_WeightNormalization():
@@ -67,10 +50,32 @@ def test_WeightNormalization():
         unwrap(weight_norm), axis=-1, keepdims=True
     )
 
-    # Test under vmap
-    arr = jr.normal(jr.PRNGKey(0), (5, 10, 3))
-    weight_norm = eqx.filter_vmap(WeightNormalization)(arr)
-    expected = unwrap(weight_norm.scale)
-    assert pytest.approx(expected) == eqx.filter_vmap(
-        lambda arr: jnp.linalg.norm(arr, axis=1, keepdims=True)
-    )(unwrap(weight_norm))
+
+test_cases = {
+    "NonTrainable": lambda key: NonTrainable(jr.normal(key, 10)),
+    "Parameterize-exp": lambda key: Parameterize(jnp.exp, jr.normal(key, 10)),
+    "Parameterize-diag": lambda key: Parameterize(jnp.diag, jr.normal(key, 10)),
+    "WeightNormalization": lambda key: WeightNormalization(jr.normal(key, (10, 2))),
+}
+
+
+@pytest.mark.parametrize("shape", [(), (2,), (5, 2, 4)])
+@pytest.mark.parametrize("wrapper_fn", test_cases.values(), ids=test_cases.keys())
+def test_vectorization_invariance(wrapper_fn, shape):
+    keys = jr.split(jr.PRNGKey(0), prod(shape))
+    wrapper = wrapper_fn(keys[0])  # Standard init
+
+    # Multiple vmap init - should have same result in zero-th index
+    vmap_wrapper_fn = wrapper_fn
+    for _ in shape:
+        vmap_wrapper_fn = eqx.filter_vmap(vmap_wrapper_fn)
+
+    vmap_wrapper = vmap_wrapper_fn(keys.reshape((*shape, 2)))
+
+    unwrapped = unwrap(wrapper)
+    unwrapped_vmap = unwrap(vmap_wrapper)
+    unwrapped_vmap_zero = tree_map(
+        lambda leaf: leaf[*([0] * len(shape)), ...],
+        unwrapped_vmap,
+    )
+    assert eqx.tree_equal(unwrapped, unwrapped_vmap_zero, atol=1e-7)
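The invariance checked by ``test_vectorization_invariance`` above can also be
seen in isolation; a small sketch adapted from the vmap tests this patch
removes:

    import equinox as eqx
    import jax.numpy as jnp

    from flowjax.wrappers import Parameterize, unwrap

    # Constructing under vmap records the batch size in _dummy, so unwrapping is
    # itself vectorized (note jnp.diag does not follow standard broadcasting rules).
    v_diag = eqx.filter_vmap(Parameterize)(jnp.diag, jnp.ones((4, 3)))
    assert unwrap(v_diag).shape == (4, 3, 3)  # a stack of four 3x3 identity matrices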