Merge pull request #175 from danielward27/simplify_wrappers
Simplify wrappers
danielward27 authored Sep 13, 2024
2 parents 298cc1d + 2e70dcb commit 969de58
Showing 6 changed files with 115 additions and 98 deletions.
5 changes: 5 additions & 0 deletions docs/api/training.rst
@@ -10,3 +10,8 @@ Alternatively, we can use ``fit_to_variational_target`` to fit the flow to a fun
using variational inference.

.. autofunction:: flowjax.train.fit_to_variational_target

Finally, for more control over the training script, you may still find the ``step``
function useful.

.. autofunction:: flowjax.train.step
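
For orientation, a minimal hand-rolled loop of this kind might look as follows (a sketch only: ``my_step`` below is a stand-in written for illustration, not the actual ``flowjax.train.step``, whose signature is documented above)::

    import equinox as eqx
    import jax.numpy as jnp
    import jax.random as jr
    import optax

    from flowjax.distributions import Normal

    def my_step(params, static, x, *, optimizer, opt_state):
        # One gradient step on the mean negative log likelihood.
        def loss_fn(params):
            return -eqx.combine(params, static).log_prob(x).mean()

        loss, grads = eqx.filter_value_and_grad(loss_fn)(params)
        updates, opt_state = optimizer.update(grads, opt_state, params)
        return eqx.apply_updates(params, updates), opt_state, loss

    dist = Normal(jnp.zeros(2))  # toy model; any flowjax distribution or flow works
    params, static = eqx.partition(dist, eqx.is_inexact_array)
    optimizer = optax.adam(1e-2)
    opt_state = optimizer.init(params)
    x = jr.normal(jr.PRNGKey(0), (256, 2)) + 1  # toy data

    for _ in range(100):
        params, opt_state, loss = my_step(
            params, static, x, optimizer=optimizer, opt_state=opt_state
        )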
1 change: 1 addition & 0 deletions docs/api/wrappers.rst
@@ -3,3 +3,4 @@ Wrappers
.. automodule:: flowjax.wrappers
:members:
:undoc-members:
:member-order: bysource
4 changes: 1 addition & 3 deletions flowjax/bijections/affine.py
@@ -146,9 +146,7 @@ def __init__(
lower: bool = True,
):
loc, arr = (arraylike_to_array(a, dtype=float) for a in (loc, arr))
if (arr.ndim != 2) or (
arr.shape[0] != arr.shape[1]
): # TODO unnecessary if beartype enabled
if (arr.ndim != 2) or (arr.shape[0] != arr.shape[1]):
raise ValueError("arr must be a square, 2-dimensional matrix.")
dim = arr.shape[0]

3 changes: 3 additions & 0 deletions flowjax/train/__init__.py
@@ -1,8 +1,11 @@
"""Utilities for training flows, fitting to samples or ysing variational inference."""

from .data_fit import fit_to_data
from .train_utils import step
from .variational_fit import fit_to_variational_target

__all__ = [
"fit_to_data",
"fit_to_variational_target",
"step",
]
133 changes: 69 additions & 64 deletions flowjax/wrappers.py
@@ -14,18 +14,19 @@
``sample_and_log_prob``.
* Prior to computing the loss functions.
If implementing a custom unwrappable, bear in mind:
* The wrapper should avoid holding information or implementing logic beyond what is
required for initialization and unwrapping, as this information will be lost when
unwrapping.
* The unwrapping should support broadcasting/vmapped initializations. Otherwise, if
the unwrappable is created within a batched context, it will fail to unwrap
correctly (a minimal sketch illustrating these points follows this docstring).
.. note::
If creating a custom unwrappable, remember that unwrapping will generally occur
after initialization of the model. Because of this, we recommend ensuring that
the ``unwrap`` method supports unwrapping if the model is constructed in a
vectorized context, such as ``eqx.filter_vmap``, e.g. through broadcasting or
vectorization.
"""

from abc import abstractmethod
from collections.abc import Callable, Iterable
from typing import Any, ClassVar, Generic, TypeVar
from typing import Any, Generic, TypeVar

import equinox as eqx
import jax
@@ -40,11 +41,29 @@


def unwrap(tree: PyTree):
"""Unwrap all :class:`AbstractUnwrappable` nodes within a pytree."""
"""Recursively unwraps all :class:`AbstractUnwrappable` nodes within a pytree.
This leaves all other nodes unchanged. If nested, the innermost
``AbstractUnwrappable`` nodes are unwrapped first.
Example:
>>> from flowjax.wrappers import Parameterize, unwrap
>>> import jax.numpy as jnp
>>> params = Parameterize(jnp.exp, jnp.zeros(3))
>>> unwrap(("abc", 1, params))
("abc", 1, Array([1., 1., 1.], dtype=float32))
"""

def _map_fn(leaf):
if isinstance(leaf, AbstractUnwrappable):
# Flatten to ignore until all contained AbstractUnwrappables are unwrapped
flat, tree_def = eqx.tree_flatten_one_level(leaf)
tree = jax.tree_util.tree_unflatten(tree_def, unwrap(flat))
return tree.unwrap()
return leaf

return jax.tree_util.tree_map(
f=lambda leaf: (
leaf.recursive_unwrap() if isinstance(leaf, AbstractUnwrappable) else leaf
),
f=_map_fn,
tree=tree,
is_leaf=lambda x: isinstance(x, AbstractUnwrappable),
)
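
A nested case (a small sketch based on the docstring above and ``test_nested_Parameterized`` in the test suite): the innermost unwrappables are resolved before the outer one.

import jax.numpy as jnp
from flowjax.wrappers import Parameterize, unwrap

nested = Parameterize(jnp.square, Parameterize(jnp.square, jnp.asarray(2.0)))
unwrap(nested)  # inner square runs first: (2**2)**2 = Array(16., dtype=float32)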
Expand All @@ -57,37 +76,54 @@ class AbstractUnwrappable(eqx.Module, Generic[T]):
behaviour to apply upon unwrapping before use. This can be used e.g. to apply
parameter constraints, such as making scale parameters positive, or applying
stop_gradient before accessing the parameters.
If ``_dummy`` is set to an array (must have shape ()), this is used for inferring
vmapped dimensions (and sizes) when calling :func:`unwrap` to automatically
vectorize the method. In some cases this is important for supporting the case where
an :class:`AbstractUnwrappable` is created within e.g. ``eqx.filter_vmap``.
"""

_dummy: eqx.AbstractVar[Int[Scalar, ""] | None]
@abstractmethod
def unwrap(self) -> T:
"""Returns the unwrapped pytree, assuming no wrapped subnodes exist."""
pass

    def recursive_unwrap(self) -> T:
        """Returns the unwrapped pytree, unwrapping subnodes as required."""

        def vectorized_unwrap(unwrappable):
            if unwrappable._dummy is None:
                return unwrappable.unwrap()

            def v_unwrap(unwrappable):
                return unwrappable.unwrap()

            for dim in reversed(unwrappable._dummy.shape):
                v_unwrap = eqx.filter_vmap(v_unwrap, axis_size=dim)
            return v_unwrap(unwrappable)

        flat, tree_def = eqx.tree_flatten_one_level(self)
        tree = jax.tree_util.tree_unflatten(tree_def, unwrap(flat))
        return vectorized_unwrap(tree)


class Parameterize(AbstractUnwrappable[T]):
    """Unwrap an object by calling fn with args and kwargs.

    All of fn, args and kwargs may contain trainable parameters. If the Parameterize is
    created within ``eqx.filter_vmap``, unwrapping is automatically vectorized
    correctly, as long as the vmapped constructor adds leading batch
    dimensions to all arrays (the default for ``eqx.filter_vmap``).

    Example:
        >>> from flowjax.wrappers import Parameterize, unwrap
        >>> import jax.numpy as jnp
        >>> positive = Parameterize(jnp.exp, jnp.zeros(3))
        >>> unwrap(positive)  # Applies exp on unwrapping
        Array([1., 1., 1.], dtype=float32)

    Args:
        fn: Callable to call with args, and kwargs.
        *args: Positional arguments to pass to fn.
        **kwargs: Keyword arguments to pass to fn.
    """

fn: Callable[..., T]
args: Iterable
kwargs: dict[str, Any]
_dummy: Int[Scalar, ""] # Used to track vectorized construction.

def __init__(self, fn: Callable, *args, **kwargs):
self.fn = fn
self.args = args
self.kwargs = kwargs
self._dummy = jnp.empty((), int)

@abstractmethod
def unwrap(self) -> T:
"""Returns the unwrapped pytree, assuming no wrapped subnodes exist."""
pass

def _unwrap_fn(self):
return self.fn(*self.args, **self.kwargs)

for dim in reversed(self._dummy.shape): # vectorize if constructed under vmap
_unwrap_fn = eqx.filter_vmap(_unwrap_fn, axis_size=dim)
return _unwrap_fn(self)
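
For instance (a small sketch of the vmapped behaviour described in the class docstring), constructing a ``Parameterize`` under ``eqx.filter_vmap`` and unwrapping yields a batch of results:

import equinox as eqx
import jax.numpy as jnp
from flowjax.wrappers import Parameterize, unwrap

# Each row of the (4, 3) array parameterizes one call to jnp.diag.
v_diag = eqx.filter_vmap(Parameterize)(jnp.diag, jnp.ones((4, 3)))
unwrap(v_diag).shape  # (4, 3, 3): a stack of 3x3 identity matrices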


class NonTrainable(AbstractUnwrappable[T]):
@@ -103,7 +139,6 @@ class NonTrainable(AbstractUnwrappable[T]):
"""

tree: T
_dummy: ClassVar[None] = None

def unwrap(self) -> T:
differentiable, static = eqx.partition(self.tree, eqx.is_array_like)
@@ -130,35 +165,6 @@ def _map_fn(leaf):
)
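
A short usage sketch (mirroring ``test_non_trainable`` in the test suite): parameters behind ``non_trainable`` receive zero gradients, so optimizers leave them unchanged.

import equinox as eqx
from flowjax.distributions import Normal
from flowjax.wrappers import non_trainable

dist = non_trainable(Normal())
grads = eqx.filter_grad(lambda d, x: d.log_prob(x))(dist, 1.0)  # all gradient leaves are zero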


class Parameterize(AbstractUnwrappable[T]):
"""Unwrap an object by calling fn with args and kwargs.
All of fn, args and kwargs may contain trainable parameters. If the Parameterize is
created within ``eqx.filter_vmap``, unwrapping is automatically vectorized
correctly, as long as the vmapped constructor adds leading batch
dimensions to all arrays (the default for ``eqx.filter_vmap``).
Args:
fn: Callable to call with args, and kwargs.
*args: Positional arguments to pass to fn.
**kwargs: Keyword arguments to pass to fn.
"""

fn: Callable[..., T]
args: Iterable
kwargs: dict[str, Any]
_dummy: Int[Scalar, ""]

def __init__(self, fn: Callable, *args, **kwargs):
self.fn = fn
self.args = args
self.kwargs = kwargs
self._dummy = jnp.empty((), int)

def unwrap(self) -> T:
return self.fn(*self.args, **self.kwargs)


class WeightNormalization(AbstractUnwrappable[Array]):
"""Applies weight normalization (https://arxiv.org/abs/1602.07868).
@@ -168,7 +174,6 @@ class WeightNormalization(AbstractUnwrappable[Array]):

weight: Array | AbstractUnwrappable[Array]
scale: Array | AbstractUnwrappable[Array] = eqx.field(init=False)
_dummy: ClassVar[None] = None

def __init__(self, weight: Array | AbstractUnwrappable[Array]):
self.weight = weight
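
A brief usage sketch (based on ``test_WeightNormalization`` in the test suite; the exact initialization of ``scale`` is defined in the class body):

import jax.numpy as jnp
import jax.random as jr
from flowjax.wrappers import WeightNormalization, unwrap

weight = jr.normal(jr.PRNGKey(0), (10, 3))
wn = WeightNormalization(weight)
w = unwrap(wn)  # same shape as weight, with rows rescaled by the scale parameter
jnp.linalg.norm(w, axis=-1, keepdims=True)  # matches unwrap(wn.scale) at initialization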
67 changes: 36 additions & 31 deletions tests/test_wrappers.py
@@ -1,8 +1,11 @@
from math import prod

import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jr
import pytest
from jax.tree_util import tree_map

from flowjax.distributions import Normal
from flowjax.wrappers import (
@@ -18,24 +21,6 @@ def test_Parameterize():
diag = Parameterize(jnp.diag, jnp.ones(3))
assert pytest.approx(jnp.eye(3)) == unwrap(diag)

# Test works when vmapped (note diag does not follow standard vectorization rules)
v_diag = eqx.filter_vmap(Parameterize)(jnp.diag, jnp.ones((4, 3)))
expected = eqx.filter_vmap(jnp.eye, axis_size=4)(3)
assert pytest.approx(expected) == unwrap(v_diag)

# Test works when double vmapped
v_diag = eqx.filter_vmap(eqx.filter_vmap(Parameterize))(
jnp.diag, jnp.ones((5, 4, 3))
)
expected = eqx.filter_vmap(eqx.filter_vmap(jnp.eye, axis_size=4), axis_size=5)(3)
assert pytest.approx(expected) == unwrap(v_diag)

# Test works when no arrays present (in which case axis_size is relied on)
unwrappable = eqx.filter_vmap(
eqx.filter_vmap(Parameterize, axis_size=2), axis_size=3
)(lambda: jnp.zeros(()))
assert pytest.approx(unwrap(unwrappable)) == jnp.zeros((3, 2))


def test_nested_Parameterized():
param = Parameterize(
@@ -45,16 +30,14 @@ def test_nested_Parameterized():
assert unwrap(param) == jnp.square(jnp.square(jnp.square(2)))


def test_NonTrainable_and_non_trainable():
dist1 = eqx.tree_at(lambda dist: dist.bijection, Normal(), replace_fn=NonTrainable)
dist2 = non_trainable(Normal())
def test_non_trainable():
dist = non_trainable(Normal())

def loss(dist, x):
return dist.log_prob(x)

for dist in [dist1, dist2]:
grad = eqx.filter_grad(loss)(dist, 1)
assert pytest.approx(0) == jax.flatten_util.ravel_pytree(grad)[0]
grad = eqx.filter_grad(loss)(dist, 1)
assert pytest.approx(0) == jax.flatten_util.ravel_pytree(grad)[0]


def test_WeightNormalization():
@@ -67,10 +50,32 @@ def test_WeightNormalization():
unwrap(weight_norm), axis=-1, keepdims=True
)

# Test under vmap
arr = jr.normal(jr.PRNGKey(0), (5, 10, 3))
weight_norm = eqx.filter_vmap(WeightNormalization)(arr)
expected = unwrap(weight_norm.scale)
assert pytest.approx(expected) == eqx.filter_vmap(
lambda arr: jnp.linalg.norm(arr, axis=1, keepdims=True)
)(unwrap(weight_norm))

test_cases = {
"NonTrainable": lambda key: NonTrainable(jr.normal(key, 10)),
"Parameterize-exp": lambda key: Parameterize(jnp.exp, jr.normal(key, 10)),
"Parameterize-diag": lambda key: Parameterize(jnp.diag, jr.normal(key, 10)),
"WeightNormalization": lambda key: WeightNormalization(jr.normal(key, (10, 2))),
}


@pytest.mark.parametrize("shape", [(), (2,), (5, 2, 4)])
@pytest.mark.parametrize("wrapper_fn", test_cases.values(), ids=test_cases.keys())
def test_vectorization_invariance(wrapper_fn, shape):
keys = jr.split(jr.PRNGKey(0), prod(shape))
wrapper = wrapper_fn(keys[0]) # Standard init

# Multiple vmap init - should have same result in zero-th index
vmap_wrapper_fn = wrapper_fn
for _ in shape:
vmap_wrapper_fn = eqx.filter_vmap(vmap_wrapper_fn)

vmap_wrapper = vmap_wrapper_fn(keys.reshape((*shape, 2)))

unwrapped = unwrap(wrapper)
unwrapped_vmap = unwrap(vmap_wrapper)
unwrapped_vmap_zero = tree_map(
lambda leaf: leaf[*([0] * len(shape)), ...],
unwrapped_vmap,
)
assert eqx.tree_equal(unwrapped, unwrapped_vmap_zero, atol=1e-7)
