Skip to content

Commit

Permalink
Merge pull request #206 from aseyboldt/add-bijections-only
Browse files Browse the repository at this point in the history
Closes #70
  • Loading branch information
danielward27 authored Feb 18, 2025
2 parents 9c45347 + 1002c9a commit d597485
Show file tree
Hide file tree
Showing 10 changed files with 162 additions and 15 deletions.
6 changes: 6 additions & 0 deletions flowjax/bijections/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,10 @@
NumericalInverse,
Permute,
Reshape,
Sandwich,
)
from .utils import EmbedCondition, Flip, Identity, Invert, Permute, Reshape, Sandwich
from .orthogonal import Householder, DiscreteCosine

__all__ = [
"AdditiveCondition",
Expand All @@ -34,9 +37,11 @@
"Chain",
"Concatenate",
"Coupling",
"DiscreteCosine",
"EmbedCondition",
"Exp",
"Flip",
"Householder",
"Identity",
"Invert",
"LeakyTanh",
Expand All @@ -48,6 +53,7 @@
"Planar",
"RationalQuadraticSpline",
"Reshape",
"Sandwich",
"Scale",
"Scan",
"Sigmoid",
Expand Down
4 changes: 2 additions & 2 deletions flowjax/bijections/affine.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ def __init__(
self.cond_shape = cond_shape

def transform_and_log_det(self, x, condition=None):
return x + self.module(condition), jnp.array(0)
return x + self.module(condition), jnp.zeros(())

def inverse_and_log_det(self, y, condition=None):
return y - self.module(condition), jnp.array(0)
return y - self.module(condition), jnp.zeros(())
1 change: 1 addition & 0 deletions flowjax/bijections/bijection.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from equinox import AbstractVar
from jaxtyping import Array, ArrayLike
from paramax import unwrap
import jax

from flowjax.utils import _get_ufunc_signature, arraylike_to_array

Expand Down
14 changes: 10 additions & 4 deletions flowjax/bijections/chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from collections.abc import Sequence

from paramax import AbstractUnwrappable, unwrap
import jax.numpy as jnp
from jax import Array

from flowjax.bijections.bijection import AbstractBijection
from flowjax.utils import check_shapes_match, merge_cond_shapes
Expand Down Expand Up @@ -36,15 +38,19 @@ def __init__(
self.cond_shape = merge_cond_shapes([unwrap(b).cond_shape for b in unwrapped])
self.bijections = tuple(bijections)

def transform_and_log_det(self, x, condition=None):
log_abs_det_jac = 0
def transform_and_log_det(
self, x: Array, condition: Array | None = None
) -> tuple[Array, Array]:
log_abs_det_jac = jnp.zeros(())
for bijection in self.bijections:
x, log_abs_det_jac_i = bijection.transform_and_log_det(x, condition)
log_abs_det_jac += log_abs_det_jac_i.sum()
return x, log_abs_det_jac

def inverse_and_log_det(self, y, condition=None):
log_abs_det_jac = 0
def inverse_and_log_det(
self, y: Array, condition: Array | None = None
) -> tuple[Array, Array]:
log_abs_det_jac = jnp.zeros(())
for bijection in reversed(self.bijections):
y, log_abs_det_jac_i = bijection.inverse_and_log_det(y, condition)
log_abs_det_jac += log_abs_det_jac_i.sum()
Expand Down
1 change: 1 addition & 0 deletions flowjax/bijections/coupling.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import equinox as eqx
import jax.nn as jnn
import jax.numpy as jnp
import jax
import paramax
from jaxtyping import PRNGKeyArray

Expand Down
1 change: 1 addition & 0 deletions flowjax/bijections/jax_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from collections.abc import Callable

import equinox as eqx
from jax import Array
import jax.numpy as jnp
from jax.lax import scan
from jax.tree_util import tree_leaves, tree_map
Expand Down
71 changes: 71 additions & 0 deletions flowjax/bijections/orthogonal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
from paramax import AbstractUnwrappable, Parameterize
from flowjax.bijections.bijection import AbstractBijection
from jax import Array
import jax.numpy as jnp
import jax.nn as jnn
from jax.scipy import fft


class Householder(AbstractBijection):
"""A Householder reflection bijection.
This bijection implements a Householder reflection, which is a linear
transformation that reflects vectors across a hyperplane defined by a normal
vector (params). The transformation is its own inverse and volume-preserving
(determinant = ±1).
Given a unit vector v, the transformation is:
x → x - 2(x·v)v
Attributes:
shape: Shape of the input/output vectors
cond_shape: Shape of conditional inputs (None as this bijection is unconditional)
params: Normal vector defining the reflection hyperplane. The vector is
normalized in the transformation, so scaling params will have no effect
on the bijection.
"""
shape: tuple[int, ...]
unit_vec: Array | AbstractUnwrappable
cond_shape = None

def __init__(self, params: Array):
self.shape = (params.shape[-1],)
self.unit_vec = Parameterize(lambda x: x / jnp.linalg.norm(x), params)

def _householder(self, x: Array) -> Array:
return x - 2 * self.unit_vec * (x @ self.unit_vec)

def transform_and_log_det(self, x: jnp.ndarray, condition: Array | None = None):
return self._householder(x), jnp.zeros(())

def inverse_and_log_det(self, y: Array, condition: Array | None = None):
return self._householder(y), jnp.zeros(())


class DiscreteCosine(AbstractBijection):
"""Discrete Cosine Transform (DCT) bijection.
This bijection applies the DCT or its inverse along a specified axis.
Attributes:
shape: Shape of the input/output arrays
cond_shape: Shape of conditional inputs (None as this bijection is unconditional)
axis: Axis along which to apply the DCT
norm: Normalization method, fixed to 'ortho' to ensure bijectivity
"""

shape: tuple[int, ...]
cond_shape = None
axis: int

def __init__(self, shape, *, axis: int = -1):
self.shape = shape
self.axis = axis

def transform_and_log_det(self, x: jnp.ndarray, condition: Array | None = None):
y = fft.dct(x, norm="ortho", axis=self.axis)
return y, jnp.zeros(())

def inverse_and_log_det(self, y: Array, condition: Array | None = None):
x = fft.idct(y, norm="ortho", axis=self.axis)
return x, jnp.zeros(())
8 changes: 6 additions & 2 deletions flowjax/bijections/softplus.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,15 @@

from typing import ClassVar

import jax
import jax.numpy as jnp
from jax.nn import softplus
from jax.nn import softplus, soft_sign
from jaxtyping import Array, ArrayLike
from paramax import AbstractUnwrappable, Parameterize, unwrap
from paramax.utils import inv_softplus

from flowjax.bijections.bijection import AbstractBijection

from flowjax.utils import arraylike_to_array

class SoftPlus(AbstractBijection):
r"""Transforms to positive domain using softplus :math:`y = \log(1 + \exp(x))`."""
Expand Down
60 changes: 53 additions & 7 deletions flowjax/bijections/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
from jaxtyping import Array, Int

from flowjax.bijections.bijection import AbstractBijection
from flowjax.utils import arraylike_to_array
from flowjax.bijections.chain import Chain
from flowjax.utils import arraylike_to_array, check_shapes_match, merge_cond_shapes


class Invert(AbstractBijection):
Expand Down Expand Up @@ -79,10 +80,10 @@ def __init__(self, permutation: Int[Array | np.ndarray, "..."]):
)

def transform_and_log_det(self, x, condition=None):
return x[self.permutation], jnp.array(0)
return x[self.permutation], jnp.array(0.0)

def inverse_and_log_det(self, y, condition=None):
return y[self.inverse_permutation], jnp.array(0)
return y[self.inverse_permutation], jnp.array(0.0)


class Flip(AbstractBijection):
Expand All @@ -95,11 +96,15 @@ class Flip(AbstractBijection):
shape: tuple[int, ...] = ()
cond_shape: ClassVar[None] = None

def transform_and_log_det(self, x, condition=None):
return jnp.flip(x), jnp.array(0)
def transform_and_log_det(
self, x: Array, condition: Array | None = None
) -> tuple[Array, Array]:
return jnp.flip(x), jnp.zeros(())

def inverse_and_log_det(self, y, condition=None):
return jnp.flip(y), jnp.array(0)
def inverse_and_log_det(
self, y: Array, condition: Array | None = None
) -> tuple[Array, Array]:
return jnp.flip(y), jnp.zeros(())


class Indexed(AbstractBijection):
Expand Down Expand Up @@ -300,3 +305,44 @@ def inverse_and_log_det(self, y, condition=None):
x = self.inverter(self.bijection, y, condition)
_, log_det = self.bijection.transform_and_log_det(x, condition)
return x, -log_det


class Sandwich(AbstractBijection):
"""A bijection that composes bijections in a nested structure: g⁻¹ ∘ f ∘ g.
The Sandwich bijection creates a new transformation by "sandwiching" one
bijection between the forward and inverse applications of another. Given
bijections f and g, it computes:
Forward: x → g⁻¹(f(g(x)))
Inverse: y → g⁻¹(f⁻¹(g(y)))
This composition pattern is useful for:
- Creating symmetries in the transformation
- Applying a transformation in a different coordinate system
- Building more complex bijections from simpler ones
Attributes:
shape: Shape of the input/output arrays
cond_shape: Shape of conditional inputs
outer: Transformation g applied first and inverted last
inner: Transformation f applied in the middle
"""
shape: tuple[int, ...]
cond_shape: tuple[int, ...] | None
outer: AbstractBijection
inner: AbstractBijection

def __init__(self, outer: AbstractBijection, inner: AbstractBijection):
check_shapes_match([outer.shape, inner.shape])
self.cond_shape = merge_cond_shapes([outer.cond_shape, inner.cond_shape])
self.shape = inner.shape
self.outer = outer
self.inner = inner

def transform_and_log_det(self, x: Array, condition=None) -> tuple[Array, Array]:
chain = Chain([self.outer, self.inner, Invert(self.outer)])
return chain.transform_and_log_det(x, condition)

def inverse_and_log_det(self, y: Array, condition=None) -> tuple[Array, Array]:
chain = Chain([self.outer, self.inner, Invert(self.outer)])
return chain.inverse_and_log_det(y, condition)
11 changes: 11 additions & 0 deletions tests/test_bijections/test_bijections.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
import jax.numpy as jnp
import jax.random as jr
import pytest
import numpy as np
from scipy import stats

from flowjax.bijections import (
AbstractBijection,
Expand All @@ -16,9 +18,11 @@
Chain,
Concatenate,
Coupling,
DiscreteCosine,
EmbedCondition,
Exp,
Flip,
Householder,
Identity,
Indexed,
LeakyTanh,
Expand All @@ -30,6 +34,7 @@
Power,
RationalQuadraticSpline,
Reshape,
Sandwich,
Scale,
Scan,
Sigmoid,
Expand Down Expand Up @@ -217,6 +222,12 @@
partial(bisection_search, lower=-1, upper=1, atol=1e-7),
),
),
"Sandwich": lambda: Sandwich(
Exp(),
Affine(0.1, 0.5),
),
"DiscreteCosine": lambda: DiscreteCosine(shape=(3, 4)),
"Householder": lambda: Householder(jnp.ones(3)),
}


Expand Down

0 comments on commit d597485

Please sign in to comment.