From d5b5941e23a1e79750430a5c9369c9d8afad8bc6 Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Tue, 28 Jan 2025 15:51:40 +0100 Subject: [PATCH 1/3] Add DCT, Householder, Neg and Sandwich bijectons --- flowjax/bijections/__init__.py | 7 +++ flowjax/bijections/affine.py | 4 +- flowjax/bijections/chain.py | 14 +++-- flowjax/bijections/orthogonal.py | 72 ++++++++++++++++++++++++ flowjax/bijections/utils.py | 61 ++++++++++++++++++-- tests/test_bijections/test_bijections.py | 11 ++++ 6 files changed, 157 insertions(+), 12 deletions(-) create mode 100644 flowjax/bijections/orthogonal.py diff --git a/flowjax/bijections/__init__.py b/flowjax/bijections/__init__.py index 916e1219..987aea1e 100644 --- a/flowjax/bijections/__init__.py +++ b/flowjax/bijections/__init__.py @@ -24,7 +24,10 @@ NumericalInverse, Permute, Reshape, + Sandwich, ) +from .utils import EmbedCondition, Flip, Identity, Invert, Permute, Reshape, Sandwich +from .orthogonal import Householder, DCT, Neg __all__ = [ "AdditiveCondition", @@ -34,20 +37,24 @@ "Chain", "Concatenate", "Coupling", + "DCT", "EmbedCondition", "Exp", "Flip", + "Householder", "Identity", "Invert", "LeakyTanh", "Loc", "MaskedAutoregressive", "Indexed", + "Neg", "Permute", "Power", "Planar", "RationalQuadraticSpline", "Reshape", + "Sandwich", "Scale", "Scan", "Sigmoid", diff --git a/flowjax/bijections/affine.py b/flowjax/bijections/affine.py index d066a610..1bcf995e 100644 --- a/flowjax/bijections/affine.py +++ b/flowjax/bijections/affine.py @@ -200,7 +200,7 @@ def __init__( self.cond_shape = cond_shape def transform_and_log_det(self, x, condition=None): - return x + self.module(condition), jnp.array(0) + return x + self.module(condition), jnp.zeros(()) def inverse_and_log_det(self, y, condition=None): - return y - self.module(condition), jnp.array(0) + return y - self.module(condition), jnp.zeros(()) diff --git a/flowjax/bijections/chain.py b/flowjax/bijections/chain.py index b028c09b..58decd46 100644 --- a/flowjax/bijections/chain.py +++ b/flowjax/bijections/chain.py @@ -3,6 +3,8 @@ from collections.abc import Sequence from paramax import AbstractUnwrappable, unwrap +import jax.numpy as jnp +from jax import Array from flowjax.bijections.bijection import AbstractBijection from flowjax.utils import check_shapes_match, merge_cond_shapes @@ -36,15 +38,19 @@ def __init__( self.cond_shape = merge_cond_shapes([unwrap(b).cond_shape for b in unwrapped]) self.bijections = tuple(bijections) - def transform_and_log_det(self, x, condition=None): - log_abs_det_jac = 0 + def transform_and_log_det( + self, x: Array, condition: Array | None = None + ) -> tuple[Array, Array]: + log_abs_det_jac = jnp.zeros(()) for bijection in self.bijections: x, log_abs_det_jac_i = bijection.transform_and_log_det(x, condition) log_abs_det_jac += log_abs_det_jac_i.sum() return x, log_abs_det_jac - def inverse_and_log_det(self, y, condition=None): - log_abs_det_jac = 0 + def inverse_and_log_det( + self, y: Array, condition: Array | None = None + ) -> tuple[Array, Array]: + log_abs_det_jac = jnp.zeros(()) for bijection in reversed(self.bijections): y, log_abs_det_jac_i = bijection.inverse_and_log_det(y, condition) log_abs_det_jac += log_abs_det_jac_i.sum() diff --git a/flowjax/bijections/orthogonal.py b/flowjax/bijections/orthogonal.py new file mode 100644 index 00000000..abc62a94 --- /dev/null +++ b/flowjax/bijections/orthogonal.py @@ -0,0 +1,72 @@ +from flowjax.bijections.bijection import AbstractBijection +from jax import Array +import jax.numpy as jnp +import jax.nn as jnn +from jax.scipy import fft + + +class Neg(AbstractBijection): + shape: tuple[int, ...] + cond_shape = None + + def __init__(self, shape): + """Initialize the MvScale bijection with `params`.""" + self.shape = shape + + def transform_and_log_det(self, x: jnp.ndarray, condition: Array | None = None): + return -x, jnp.zeros(()) + + def inverse_and_log_det(self, y: Array, condition: Array | None = None): + return -y, jnp.zeros(()) + + +class Householder(AbstractBijection): + shape: tuple[int, ...] + params: Array + cond_shape = None + + def __init__(self, params: Array): + """Initialize the MvScale bijection with `params`.""" + self.shape = (params.shape[-1],) + self.params = params + + def _householder(self, x: Array, params: Array) -> Array: + norm_sq = params @ params + norm = jnp.sqrt(norm_sq) + + vec = params / norm + return x - 2 * vec * (x @ vec) + + def transform_and_log_det(self, x: jnp.ndarray, condition: Array | None = None): + return self._householder(x, self.params), jnp.zeros(()) + + def inverse_and_log_det(self, y: Array, condition: Array | None = None): + return self._householder(y, self.params), jnp.zeros(()) + + +class DCT(AbstractBijection): + shape: tuple[int, ...] + cond_shape = None + axis: int + norm: str + + def __init__(self, shape, *, axis: int = -1): + self.shape = shape + self.axis = axis + self.norm = "ortho" + + def _dct(self, x: Array, inverse: bool = False) -> Array: + if inverse: + z = fft.idct(x, norm=self.norm, axis=self.axis) + else: + z = fft.dct(x, norm=self.norm, axis=self.axis) + + return z + + def transform_and_log_det(self, x: jnp.ndarray, condition: Array | None = None): + y = self._dct(x) + return y, jnp.zeros(()) + + def inverse_and_log_det(self, y: Array, condition: Array | None = None): + x = self._dct(y, inverse=True) + return x, jnp.zeros(()) diff --git a/flowjax/bijections/utils.py b/flowjax/bijections/utils.py index 5f46e5db..f1066acf 100644 --- a/flowjax/bijections/utils.py +++ b/flowjax/bijections/utils.py @@ -79,10 +79,10 @@ def __init__(self, permutation: Int[Array | np.ndarray, "..."]): ) def transform_and_log_det(self, x, condition=None): - return x[self.permutation], jnp.array(0) + return x[self.permutation], jnp.array(0.0) def inverse_and_log_det(self, y, condition=None): - return y[self.inverse_permutation], jnp.array(0) + return y[self.inverse_permutation], jnp.array(0.0) class Flip(AbstractBijection): @@ -95,11 +95,15 @@ class Flip(AbstractBijection): shape: tuple[int, ...] = () cond_shape: ClassVar[None] = None - def transform_and_log_det(self, x, condition=None): - return jnp.flip(x), jnp.array(0) + def transform_and_log_det( + self, x: Array, condition: Array | None = None + ) -> tuple[Array, Array]: + return jnp.flip(x), jnp.zeros(()) - def inverse_and_log_det(self, y, condition=None): - return jnp.flip(y), jnp.array(0) + def inverse_and_log_det( + self, y: Array, condition: Array | None = None + ) -> tuple[Array, Array]: + return jnp.flip(y), jnp.zeros(()) class Indexed(AbstractBijection): @@ -300,3 +304,48 @@ def inverse_and_log_det(self, y, condition=None): x = self.inverter(self.bijection, y, condition) _, log_det = self.bijection.transform_and_log_det(x, condition) return x, -log_det + + +class Sandwich(AbstractBijection): + """ + A bijection that sandwiches one transformation inside another. + + The `Sandwich` bijection applies an "outer" transformation, followed by an + "inner" transformation, and then the inverse of the "outer" transformation. + This allows for the composition of transformations in a nested structure. + + Args: + outer (AbstractBijection): The outer transformation applied first and + inverted last. + inner (AbstractBijection): The inner transformation applied between + the forward and inverse outer transformations. + """ + shape: tuple[int, ...] + cond_shape: tuple[int, ...] | None + outer: AbstractBijection + inner: AbstractBijection + + def __init__(self, outer: AbstractBijection, inner: AbstractBijection): + shape = inner.shape + if outer.shape != shape: + raise ValueError("Inner and outer transformations are incompatible") + self.cond_shape = inner.cond_shape + if outer.cond_shape != self.cond_shape: + raise ValueError("Inner and outer transformations are incompatible") + self.shape = shape + self.outer = outer + self.inner = inner + + def transform_and_log_det(self, x: Array, condition=None) -> tuple[Array, Array]: + z1, logdet1 = self.outer.transform_and_log_det(x, condition) + z2, logdet2 = self.inner.transform_and_log_det(z1, condition) + y, logdet3 = self.outer.inverse_and_log_det(z2, condition) + + return y, logdet1 + logdet2 + logdet3 + + def inverse_and_log_det(self, y: Array, condition=None) -> tuple[Array, Array]: + z1, logdet1 = self.outer.transform_and_log_det(y, condition) + z2, logdet2 = self.inner.inverse_and_log_det(z1, condition) + x, logdet3 = self.outer.inverse_and_log_det(z2, condition) + + return x, logdet1 + logdet2 + logdet3 diff --git a/tests/test_bijections/test_bijections.py b/tests/test_bijections/test_bijections.py index 46cc3d87..5c156927 100644 --- a/tests/test_bijections/test_bijections.py +++ b/tests/test_bijections/test_bijections.py @@ -16,20 +16,24 @@ Chain, Concatenate, Coupling, + DCT, EmbedCondition, Exp, Flip, + Householder, Identity, Indexed, LeakyTanh, Loc, MaskedAutoregressive, NumericalInverse, + Neg, Permute, Planar, Power, RationalQuadraticSpline, Reshape, + Sandwich, Scale, Scan, Sigmoid, @@ -132,6 +136,7 @@ nn_depth=2, ) ), + "Neg": lambda: Neg(shape=(DIM,)), "BlockAutoregressiveNetwork (unconditional)": lambda: BlockAutoregressiveNetwork( KEY, dim=DIM, @@ -217,6 +222,12 @@ partial(bisection_search, lower=-1, upper=1, atol=1e-7), ), ), + "Sandwich": lambda: Sandwich( + Exp(), + Affine(0.1, 0.5), + ), + "DCT": lambda: DCT(shape=(3, 4)), + "Householder": lambda: Householder(jnp.ones(3)), } From 328de18cbcf347e365aa792ff8546f477e1846e6 Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Thu, 30 Jan 2025 18:21:36 +0100 Subject: [PATCH 2/3] Add AsymmetricAffine and add some docs --- flowjax/bijections/__init__.py | 3 +- flowjax/bijections/bijection.py | 1 + flowjax/bijections/coupling.py | 1 + flowjax/bijections/jax_transforms.py | 1 + flowjax/bijections/orthogonal.py | 49 ++++++++- flowjax/bijections/softplus.py | 125 ++++++++++++++++++++++- flowjax/bijections/utils.py | 30 +++--- tests/test_bijections/test_bijections.py | 8 ++ 8 files changed, 201 insertions(+), 17 deletions(-) diff --git a/flowjax/bijections/__init__.py b/flowjax/bijections/__init__.py index 987aea1e..4f69851f 100644 --- a/flowjax/bijections/__init__.py +++ b/flowjax/bijections/__init__.py @@ -13,7 +13,7 @@ from .power import Power from .rational_quadratic_spline import RationalQuadraticSpline from .sigmoid import Sigmoid -from .softplus import SoftPlus +from .softplus import SoftPlus, AsymmetricAffine from .tanh import LeakyTanh, Tanh from .utils import ( EmbedCondition, @@ -33,6 +33,7 @@ "AdditiveCondition", "Affine", "AbstractBijection", + "AsymmetricAffine", "BlockAutoregressiveNetwork", "Chain", "Concatenate", diff --git a/flowjax/bijections/bijection.py b/flowjax/bijections/bijection.py index c76e0c44..c7a1c490 100644 --- a/flowjax/bijections/bijection.py +++ b/flowjax/bijections/bijection.py @@ -15,6 +15,7 @@ from equinox import AbstractVar from jaxtyping import Array, ArrayLike from paramax import unwrap +import jax from flowjax.utils import _get_ufunc_signature, arraylike_to_array diff --git a/flowjax/bijections/coupling.py b/flowjax/bijections/coupling.py index 5fb66995..675b0f28 100644 --- a/flowjax/bijections/coupling.py +++ b/flowjax/bijections/coupling.py @@ -8,6 +8,7 @@ import equinox as eqx import jax.nn as jnn import jax.numpy as jnp +import jax import paramax from jaxtyping import PRNGKeyArray diff --git a/flowjax/bijections/jax_transforms.py b/flowjax/bijections/jax_transforms.py index d3672c3c..939394df 100644 --- a/flowjax/bijections/jax_transforms.py +++ b/flowjax/bijections/jax_transforms.py @@ -3,6 +3,7 @@ from collections.abc import Callable import equinox as eqx +from jax import Array import jax.numpy as jnp from jax.lax import scan from jax.tree_util import tree_leaves, tree_map diff --git a/flowjax/bijections/orthogonal.py b/flowjax/bijections/orthogonal.py index abc62a94..d79a6bec 100644 --- a/flowjax/bijections/orthogonal.py +++ b/flowjax/bijections/orthogonal.py @@ -6,11 +6,18 @@ class Neg(AbstractBijection): + """A bijection that negates its input (multiplies by -1). + + This is a simple bijection that flips the sign of all elements in the input array. + + Attributes: + shape: Shape of the input/output arrays + cond_shape: Shape of conditional inputs (None as this bijection is unconditional) + """ shape: tuple[int, ...] cond_shape = None def __init__(self, shape): - """Initialize the MvScale bijection with `params`.""" self.shape = shape def transform_and_log_det(self, x: jnp.ndarray, condition: Array | None = None): @@ -21,12 +28,28 @@ def inverse_and_log_det(self, y: Array, condition: Array | None = None): class Householder(AbstractBijection): + """A Householder reflection bijection. + + This bijection implements a Householder reflection, which is a linear + transformation that reflects vectors across a hyperplane defined by a normal + vector (params). The transformation is its own inverse and volume-preserving + (determinant = ±1). + + Given a unit vector v, the transformation is: + x → x - 2(x·v)v + + Attributes: + shape: Shape of the input/output vectors + cond_shape: Shape of conditional inputs (None as this bijection is unconditional) + params: Normal vector defining the reflection hyperplane. The vector is + normalized in the transformation, so scaling params will have no effect + on the bijection. + """ shape: tuple[int, ...] params: Array cond_shape = None def __init__(self, params: Array): - """Initialize the MvScale bijection with `params`.""" self.shape = (params.shape[-1],) self.params = params @@ -43,8 +66,30 @@ def transform_and_log_det(self, x: jnp.ndarray, condition: Array | None = None): def inverse_and_log_det(self, y: Array, condition: Array | None = None): return self._householder(y, self.params), jnp.zeros(()) + def inverse_gradient_and_val( + self, + y: Array, + y_grad: Array, + y_logp: Array, + condition: Array | None = None, + ) -> tuple[Array, Array, Array]: + x, logdet = self.inverse_and_log_det(y) + x_grad = self._householder(y_grad, params=self.params) + return (x, x_grad, y_logp - logdet) + class DCT(AbstractBijection): + """Discrete Cosine Transform (DCT) bijection. + + This bijection applies the DCT or its inverse along a specified axis. + + Attributes: + shape: Shape of the input/output arrays + cond_shape: Shape of conditional inputs (None as this bijection is unconditional) + axis: Axis along which to apply the DCT + norm: Normalization method, fixed to 'ortho' to ensure bijectivity + """ + shape: tuple[int, ...] cond_shape = None axis: int diff --git a/flowjax/bijections/softplus.py b/flowjax/bijections/softplus.py index 1987cd4d..c45b903b 100644 --- a/flowjax/bijections/softplus.py +++ b/flowjax/bijections/softplus.py @@ -2,11 +2,15 @@ from typing import ClassVar +import jax import jax.numpy as jnp -from jax.nn import softplus +from jax.nn import softplus, soft_sign +from jaxtyping import Array, ArrayLike +from paramax import AbstractUnwrappable, Parameterize, unwrap +from paramax.utils import inv_softplus from flowjax.bijections.bijection import AbstractBijection - +from flowjax.utils import arraylike_to_array class SoftPlus(AbstractBijection): r"""Transforms to positive domain using softplus :math:`y = \log(1 + \exp(x))`.""" @@ -20,3 +24,120 @@ def transform_and_log_det(self, x, condition=None): def inverse_and_log_det(self, y, condition=None): x = jnp.log(-jnp.expm1(-y)) + y return x, softplus(-x).sum() + + +class AsymmetricAffine(AbstractBijection): + """An asymmetric bijection that applies different scaling factors for + positive and negative inputs. + + This bijection implements a continuous, differentiable transformation that + scales positive and negative inputs differently while maintaining smoothness + at zero. It's particularly useful for modeling data with different variances + in positive and negative regions. + + The forward transformation is defined as: + y = σ θ x for x ≥ 0 + y = σ x/θ for x < 0 + where: + - σ (scale) controls the overall scaling + - θ (theta) controls the asymmetry between positive and negative regions + - μ (loc) controls the location shift + + The transformation uses a smooth transition between the two regions to + maintain differentiability. + + For θ = 0, this is exactly an affine function with the specified location + and scale. + + Attributes: + shape: The shape of the transformation parameters + cond_shape: Shape of conditional inputs (None as this bijection is + unconditional) + loc: Location parameter μ for shifting the distribution + scale: Scale parameter σ (positive) + theta: Asymmetry parameter θ (positive) + """ + shape: tuple[int, ...] = () + cond_shape: ClassVar[None] = None + loc: Array + scale: Array | AbstractUnwrappable[Array] + theta: Array | AbstractUnwrappable[Array] + + def __init__( + self, + loc: ArrayLike = 0, + scale: ArrayLike = 1, + theta: ArrayLike = 1, + ): + self.loc, scale, theta = jnp.broadcast_arrays( + *(arraylike_to_array(a, dtype=float) for a in (loc, scale, theta)), + ) + self.shape = scale.shape + self.scale = Parameterize(softplus, inv_softplus(scale)) + self.theta = Parameterize(softplus, inv_softplus(theta)) + + def _log_derivative_f(self, x, mu, sigma, theta): + abs_x = jnp.abs(x) + theta = jnp.log(theta) + + sinh_theta = jnp.sinh(theta) + #sinh_theta = (theta - 1 / theta) / 2 + cosh_theta = jnp.cosh(theta) + #cosh_theta = (theta + 1 / theta) / 2 + numerator = sinh_theta * x * (abs_x + 2.0) + denominator = (abs_x + 1.0)**2 + term = numerator / denominator + dy_dx = sigma * (cosh_theta + term) + return jnp.log(dy_dx) + + def transform_and_log_det(self, x: ArrayLike, condition: ArrayLike | None = None) -> tuple[Array, Array]: + + def transform(x, mu, sigma, theta): + weight = (soft_sign(x) + 1) / 2 + z = x * sigma + y_pos = z * theta + y_neg = z / theta + y = weight * y_pos + (1.0 - weight) * y_neg + mu + return y + + mu, sigma, theta = self.loc, self.scale, self.theta + + y = transform(x, mu, sigma, theta) + logjac = self._log_derivative_f(x, mu, sigma, theta) + return y, logjac.sum() + + def inverse_and_log_det(self, y: ArrayLike, condition: ArrayLike | None = None) -> tuple[Array, Array]: + + def inverse(y, mu, sigma, theta): + delta = y - mu + inv_theta = 1 / theta + + # Case 1: y >= mu (delta >= 0) + a = sigma * (theta + inv_theta) + discriminant_pos = jnp.square(a - 2.0 * delta) + 16.0 * sigma * theta * delta + discriminant_pos = jnp.where(discriminant_pos < 0, 1., discriminant_pos) + sqrt_pos = jnp.sqrt(discriminant_pos) + numerator_pos = 2.0 * delta - a + sqrt_pos + denominator_pos = 4.0 * sigma * theta + x_pos = numerator_pos / denominator_pos + + # Case 2: y < mu (delta < 0) + sigma_part = sigma * (1.0 + theta * theta) + term2 = 2.0 * delta * theta + inside_sqrt_neg = jnp.square(sigma_part + term2) - 16.0 * sigma * delta * theta + inside_sqrt_neg = jnp.where(inside_sqrt_neg < 0, 1., inside_sqrt_neg) + sqrt_neg = jnp.sqrt(inside_sqrt_neg) + numerator_neg = sigma_part + term2 - sqrt_neg + denominator_neg = 4.0 * sigma + x_neg = numerator_neg / denominator_neg + + # Combine cases based on delta + x = jnp.where(delta >= 0.0, x_pos, x_neg) + return x + + mu, sigma, theta = self.loc, self.scale, self.theta + + x = inverse(y, mu, sigma, theta) + logjac = self._log_derivative_f(x, mu, sigma, theta) + return x, -logjac.sum() + diff --git a/flowjax/bijections/utils.py b/flowjax/bijections/utils.py index f1066acf..1ff766eb 100644 --- a/flowjax/bijections/utils.py +++ b/flowjax/bijections/utils.py @@ -307,18 +307,24 @@ def inverse_and_log_det(self, y, condition=None): class Sandwich(AbstractBijection): - """ - A bijection that sandwiches one transformation inside another. - - The `Sandwich` bijection applies an "outer" transformation, followed by an - "inner" transformation, and then the inverse of the "outer" transformation. - This allows for the composition of transformations in a nested structure. - - Args: - outer (AbstractBijection): The outer transformation applied first and - inverted last. - inner (AbstractBijection): The inner transformation applied between - the forward and inverse outer transformations. + """A bijection that composes bijections in a nested structure: g⁻¹ ∘ f ∘ g. + + The Sandwich bijection creates a new transformation by "sandwiching" one + bijection between the forward and inverse applications of another. Given + bijections f and g, it computes: + Forward: x → g⁻¹(f(g(x))) + Inverse: y → g⁻¹(f⁻¹(g(y))) + + This composition pattern is useful for: + - Creating symmetries in the transformation + - Applying a transformation in a different coordinate system + - Building more complex bijections from simpler ones + + Attributes: + shape: Shape of the input/output arrays + cond_shape: Shape of conditional inputs + outer: Transformation g applied first and inverted last + inner: Transformation f applied in the middle """ shape: tuple[int, ...] cond_shape: tuple[int, ...] | None diff --git a/tests/test_bijections/test_bijections.py b/tests/test_bijections/test_bijections.py index 5c156927..380c7667 100644 --- a/tests/test_bijections/test_bijections.py +++ b/tests/test_bijections/test_bijections.py @@ -7,11 +7,14 @@ import jax.numpy as jnp import jax.random as jr import pytest +import numpy as np +from scipy import stats from flowjax.bijections import ( AbstractBijection, AdditiveCondition, Affine, + AsymmetricAffine, BlockAutoregressiveNetwork, Chain, Concatenate, @@ -94,6 +97,11 @@ ), jnp.diag(jnp.array([-1, 2, -3])), ), + "AsymmetricAffine": lambda: AsymmetricAffine( + jnp.ones(DIM), + jnp.full(DIM, 2.6), + jnp.full(DIM, 0.1), + ), "RationalQuadraticSpline": lambda: RationalQuadraticSpline(knots=4, interval=1), "Coupling (unconditional)": lambda: Coupling( KEY, From 1002c9a5e7e92da102ca4cb713be7a801cef9810 Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Wed, 12 Feb 2025 19:04:06 +0100 Subject: [PATCH 3/3] Remove Neg and AssymetricAffine and minor edits --- flowjax/bijections/__init__.py | 8 +- flowjax/bijections/orthogonal.py | 66 ++----------- flowjax/bijections/softplus.py | 117 ----------------------- flowjax/bijections/utils.py | 27 ++---- tests/test_bijections/test_bijections.py | 12 +-- 5 files changed, 24 insertions(+), 206 deletions(-) diff --git a/flowjax/bijections/__init__.py b/flowjax/bijections/__init__.py index 4f69851f..8fe2bcf7 100644 --- a/flowjax/bijections/__init__.py +++ b/flowjax/bijections/__init__.py @@ -13,7 +13,7 @@ from .power import Power from .rational_quadratic_spline import RationalQuadraticSpline from .sigmoid import Sigmoid -from .softplus import SoftPlus, AsymmetricAffine +from .softplus import SoftPlus from .tanh import LeakyTanh, Tanh from .utils import ( EmbedCondition, @@ -27,18 +27,17 @@ Sandwich, ) from .utils import EmbedCondition, Flip, Identity, Invert, Permute, Reshape, Sandwich -from .orthogonal import Householder, DCT, Neg +from .orthogonal import Householder, DiscreteCosine __all__ = [ "AdditiveCondition", "Affine", "AbstractBijection", - "AsymmetricAffine", "BlockAutoregressiveNetwork", "Chain", "Concatenate", "Coupling", - "DCT", + "DiscreteCosine", "EmbedCondition", "Exp", "Flip", @@ -49,7 +48,6 @@ "Loc", "MaskedAutoregressive", "Indexed", - "Neg", "Permute", "Power", "Planar", diff --git a/flowjax/bijections/orthogonal.py b/flowjax/bijections/orthogonal.py index d79a6bec..4d06534f 100644 --- a/flowjax/bijections/orthogonal.py +++ b/flowjax/bijections/orthogonal.py @@ -1,3 +1,4 @@ +from paramax import AbstractUnwrappable, Parameterize from flowjax.bijections.bijection import AbstractBijection from jax import Array import jax.numpy as jnp @@ -5,28 +6,6 @@ from jax.scipy import fft -class Neg(AbstractBijection): - """A bijection that negates its input (multiplies by -1). - - This is a simple bijection that flips the sign of all elements in the input array. - - Attributes: - shape: Shape of the input/output arrays - cond_shape: Shape of conditional inputs (None as this bijection is unconditional) - """ - shape: tuple[int, ...] - cond_shape = None - - def __init__(self, shape): - self.shape = shape - - def transform_and_log_det(self, x: jnp.ndarray, condition: Array | None = None): - return -x, jnp.zeros(()) - - def inverse_and_log_det(self, y: Array, condition: Array | None = None): - return -y, jnp.zeros(()) - - class Householder(AbstractBijection): """A Householder reflection bijection. @@ -46,39 +25,24 @@ class Householder(AbstractBijection): on the bijection. """ shape: tuple[int, ...] - params: Array + unit_vec: Array | AbstractUnwrappable cond_shape = None def __init__(self, params: Array): self.shape = (params.shape[-1],) - self.params = params - - def _householder(self, x: Array, params: Array) -> Array: - norm_sq = params @ params - norm = jnp.sqrt(norm_sq) + self.unit_vec = Parameterize(lambda x: x / jnp.linalg.norm(x), params) - vec = params / norm - return x - 2 * vec * (x @ vec) + def _householder(self, x: Array) -> Array: + return x - 2 * self.unit_vec * (x @ self.unit_vec) def transform_and_log_det(self, x: jnp.ndarray, condition: Array | None = None): - return self._householder(x, self.params), jnp.zeros(()) + return self._householder(x), jnp.zeros(()) def inverse_and_log_det(self, y: Array, condition: Array | None = None): - return self._householder(y, self.params), jnp.zeros(()) - - def inverse_gradient_and_val( - self, - y: Array, - y_grad: Array, - y_logp: Array, - condition: Array | None = None, - ) -> tuple[Array, Array, Array]: - x, logdet = self.inverse_and_log_det(y) - x_grad = self._householder(y_grad, params=self.params) - return (x, x_grad, y_logp - logdet) + return self._householder(y), jnp.zeros(()) -class DCT(AbstractBijection): +class DiscreteCosine(AbstractBijection): """Discrete Cosine Transform (DCT) bijection. This bijection applies the DCT or its inverse along a specified axis. @@ -93,25 +57,15 @@ class DCT(AbstractBijection): shape: tuple[int, ...] cond_shape = None axis: int - norm: str def __init__(self, shape, *, axis: int = -1): self.shape = shape self.axis = axis - self.norm = "ortho" - - def _dct(self, x: Array, inverse: bool = False) -> Array: - if inverse: - z = fft.idct(x, norm=self.norm, axis=self.axis) - else: - z = fft.dct(x, norm=self.norm, axis=self.axis) - - return z def transform_and_log_det(self, x: jnp.ndarray, condition: Array | None = None): - y = self._dct(x) + y = fft.dct(x, norm="ortho", axis=self.axis) return y, jnp.zeros(()) def inverse_and_log_det(self, y: Array, condition: Array | None = None): - x = self._dct(y, inverse=True) + x = fft.idct(y, norm="ortho", axis=self.axis) return x, jnp.zeros(()) diff --git a/flowjax/bijections/softplus.py b/flowjax/bijections/softplus.py index c45b903b..90c33151 100644 --- a/flowjax/bijections/softplus.py +++ b/flowjax/bijections/softplus.py @@ -24,120 +24,3 @@ def transform_and_log_det(self, x, condition=None): def inverse_and_log_det(self, y, condition=None): x = jnp.log(-jnp.expm1(-y)) + y return x, softplus(-x).sum() - - -class AsymmetricAffine(AbstractBijection): - """An asymmetric bijection that applies different scaling factors for - positive and negative inputs. - - This bijection implements a continuous, differentiable transformation that - scales positive and negative inputs differently while maintaining smoothness - at zero. It's particularly useful for modeling data with different variances - in positive and negative regions. - - The forward transformation is defined as: - y = σ θ x for x ≥ 0 - y = σ x/θ for x < 0 - where: - - σ (scale) controls the overall scaling - - θ (theta) controls the asymmetry between positive and negative regions - - μ (loc) controls the location shift - - The transformation uses a smooth transition between the two regions to - maintain differentiability. - - For θ = 0, this is exactly an affine function with the specified location - and scale. - - Attributes: - shape: The shape of the transformation parameters - cond_shape: Shape of conditional inputs (None as this bijection is - unconditional) - loc: Location parameter μ for shifting the distribution - scale: Scale parameter σ (positive) - theta: Asymmetry parameter θ (positive) - """ - shape: tuple[int, ...] = () - cond_shape: ClassVar[None] = None - loc: Array - scale: Array | AbstractUnwrappable[Array] - theta: Array | AbstractUnwrappable[Array] - - def __init__( - self, - loc: ArrayLike = 0, - scale: ArrayLike = 1, - theta: ArrayLike = 1, - ): - self.loc, scale, theta = jnp.broadcast_arrays( - *(arraylike_to_array(a, dtype=float) for a in (loc, scale, theta)), - ) - self.shape = scale.shape - self.scale = Parameterize(softplus, inv_softplus(scale)) - self.theta = Parameterize(softplus, inv_softplus(theta)) - - def _log_derivative_f(self, x, mu, sigma, theta): - abs_x = jnp.abs(x) - theta = jnp.log(theta) - - sinh_theta = jnp.sinh(theta) - #sinh_theta = (theta - 1 / theta) / 2 - cosh_theta = jnp.cosh(theta) - #cosh_theta = (theta + 1 / theta) / 2 - numerator = sinh_theta * x * (abs_x + 2.0) - denominator = (abs_x + 1.0)**2 - term = numerator / denominator - dy_dx = sigma * (cosh_theta + term) - return jnp.log(dy_dx) - - def transform_and_log_det(self, x: ArrayLike, condition: ArrayLike | None = None) -> tuple[Array, Array]: - - def transform(x, mu, sigma, theta): - weight = (soft_sign(x) + 1) / 2 - z = x * sigma - y_pos = z * theta - y_neg = z / theta - y = weight * y_pos + (1.0 - weight) * y_neg + mu - return y - - mu, sigma, theta = self.loc, self.scale, self.theta - - y = transform(x, mu, sigma, theta) - logjac = self._log_derivative_f(x, mu, sigma, theta) - return y, logjac.sum() - - def inverse_and_log_det(self, y: ArrayLike, condition: ArrayLike | None = None) -> tuple[Array, Array]: - - def inverse(y, mu, sigma, theta): - delta = y - mu - inv_theta = 1 / theta - - # Case 1: y >= mu (delta >= 0) - a = sigma * (theta + inv_theta) - discriminant_pos = jnp.square(a - 2.0 * delta) + 16.0 * sigma * theta * delta - discriminant_pos = jnp.where(discriminant_pos < 0, 1., discriminant_pos) - sqrt_pos = jnp.sqrt(discriminant_pos) - numerator_pos = 2.0 * delta - a + sqrt_pos - denominator_pos = 4.0 * sigma * theta - x_pos = numerator_pos / denominator_pos - - # Case 2: y < mu (delta < 0) - sigma_part = sigma * (1.0 + theta * theta) - term2 = 2.0 * delta * theta - inside_sqrt_neg = jnp.square(sigma_part + term2) - 16.0 * sigma * delta * theta - inside_sqrt_neg = jnp.where(inside_sqrt_neg < 0, 1., inside_sqrt_neg) - sqrt_neg = jnp.sqrt(inside_sqrt_neg) - numerator_neg = sigma_part + term2 - sqrt_neg - denominator_neg = 4.0 * sigma - x_neg = numerator_neg / denominator_neg - - # Combine cases based on delta - x = jnp.where(delta >= 0.0, x_pos, x_neg) - return x - - mu, sigma, theta = self.loc, self.scale, self.theta - - x = inverse(y, mu, sigma, theta) - logjac = self._log_derivative_f(x, mu, sigma, theta) - return x, -logjac.sum() - diff --git a/flowjax/bijections/utils.py b/flowjax/bijections/utils.py index 1ff766eb..36549372 100644 --- a/flowjax/bijections/utils.py +++ b/flowjax/bijections/utils.py @@ -10,7 +10,8 @@ from jaxtyping import Array, Int from flowjax.bijections.bijection import AbstractBijection -from flowjax.utils import arraylike_to_array +from flowjax.bijections.chain import Chain +from flowjax.utils import arraylike_to_array, check_shapes_match, merge_cond_shapes class Invert(AbstractBijection): @@ -332,26 +333,16 @@ class Sandwich(AbstractBijection): inner: AbstractBijection def __init__(self, outer: AbstractBijection, inner: AbstractBijection): - shape = inner.shape - if outer.shape != shape: - raise ValueError("Inner and outer transformations are incompatible") - self.cond_shape = inner.cond_shape - if outer.cond_shape != self.cond_shape: - raise ValueError("Inner and outer transformations are incompatible") - self.shape = shape + check_shapes_match([outer.shape, inner.shape]) + self.cond_shape = merge_cond_shapes([outer.cond_shape, inner.cond_shape]) + self.shape = inner.shape self.outer = outer self.inner = inner def transform_and_log_det(self, x: Array, condition=None) -> tuple[Array, Array]: - z1, logdet1 = self.outer.transform_and_log_det(x, condition) - z2, logdet2 = self.inner.transform_and_log_det(z1, condition) - y, logdet3 = self.outer.inverse_and_log_det(z2, condition) - - return y, logdet1 + logdet2 + logdet3 + chain = Chain([self.outer, self.inner, Invert(self.outer)]) + return chain.transform_and_log_det(x, condition) def inverse_and_log_det(self, y: Array, condition=None) -> tuple[Array, Array]: - z1, logdet1 = self.outer.transform_and_log_det(y, condition) - z2, logdet2 = self.inner.inverse_and_log_det(z1, condition) - x, logdet3 = self.outer.inverse_and_log_det(z2, condition) - - return x, logdet1 + logdet2 + logdet3 + chain = Chain([self.outer, self.inner, Invert(self.outer)]) + return chain.inverse_and_log_det(y, condition) diff --git a/tests/test_bijections/test_bijections.py b/tests/test_bijections/test_bijections.py index 380c7667..4ff1561a 100644 --- a/tests/test_bijections/test_bijections.py +++ b/tests/test_bijections/test_bijections.py @@ -14,12 +14,11 @@ AbstractBijection, AdditiveCondition, Affine, - AsymmetricAffine, BlockAutoregressiveNetwork, Chain, Concatenate, Coupling, - DCT, + DiscreteCosine, EmbedCondition, Exp, Flip, @@ -30,7 +29,6 @@ Loc, MaskedAutoregressive, NumericalInverse, - Neg, Permute, Planar, Power, @@ -97,11 +95,6 @@ ), jnp.diag(jnp.array([-1, 2, -3])), ), - "AsymmetricAffine": lambda: AsymmetricAffine( - jnp.ones(DIM), - jnp.full(DIM, 2.6), - jnp.full(DIM, 0.1), - ), "RationalQuadraticSpline": lambda: RationalQuadraticSpline(knots=4, interval=1), "Coupling (unconditional)": lambda: Coupling( KEY, @@ -144,7 +137,6 @@ nn_depth=2, ) ), - "Neg": lambda: Neg(shape=(DIM,)), "BlockAutoregressiveNetwork (unconditional)": lambda: BlockAutoregressiveNetwork( KEY, dim=DIM, @@ -234,7 +226,7 @@ Exp(), Affine(0.1, 0.5), ), - "DCT": lambda: DCT(shape=(3, 4)), + "DiscreteCosine": lambda: DiscreteCosine(shape=(3, 4)), "Householder": lambda: Householder(jnp.ones(3)), }