Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reshape #138

Merged
merged 3 commits into from
Feb 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion flowjax/bijections/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from .rational_quadratic_spline import RationalQuadraticSpline
from .softplus import SoftPlus
from .tanh import LeakyTanh, Tanh
from .utils import EmbedCondition, Flip, Identity, Invert, Partial, Permute
from .utils import EmbedCondition, Flip, Identity, Invert, Partial, Permute, Reshape

__all__ = [
"AdditiveCondition",
Expand All @@ -34,6 +34,7 @@
"Permute",
"Planar",
"RationalQuadraticSpline",
"Reshape",
"Scale",
"Scan",
"SoftPlus",
Expand Down
88 changes: 88 additions & 0 deletions flowjax/bijections/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
"""Utility bijections (embedding network, permutations, inversion etc.)."""

from __future__ import annotations

from collections.abc import Callable
from math import prod
from typing import ClassVar

import jax.numpy as jnp
Expand Down Expand Up @@ -233,3 +235,89 @@ def inverse_and_log_det(self, y, condition=None):
@property
def shape(self):
return self.bijection.shape


class Reshape(AbstractBijection):
"""Wraps bijection methods with reshaping operations.

One use case for this is for bijections that do not directly support a scalar
shape, but this allows construction with shape (1, ) and reshaping to ().

Args:
bijection: The bijection to wrap.
shape: The new input and output shape of the bijection. Defaults to
unchanged.
cond_shape: The new cond_shape of the bijection. Defaults to unchanged.

Example:
.. doctest::

>>> import jax.numpy as jnp
>>> from flowjax.bijections import Affine, Reshape
>>> affine = Affine(loc=jnp.arange(4))
>>> affine.shape
(4,)
>>> affine = Reshape(affine, (2,2))
>>> affine.shape
(2, 2)
>>> affine.transform(jnp.zeros((2,2)))
Array([[0., 1.],
[2., 3.]], dtype=float32)
"""

bijection: AbstractBijection
shape: tuple[int, ...]
cond_shape: tuple[int, ...] | None = None

def __init__(
self,
bijection: AbstractBijection,
shape: tuple[int, ...] | None = None,
cond_shape: tuple[int, ...] | None = None,
):
self.bijection = bijection
self.shape = shape if shape is not None else bijection.shape
self.cond_shape = cond_shape if cond_shape is not None else bijection.cond_shape

def __check_init__(self):
if self.bijection.cond_shape is None and self.cond_shape is not None:
raise ValueError(
"Cannot reshape cond_shape for unconditional bijection.",
)
shapes = {
"shape": (self.shape, self.bijection.shape),
"cond_shape": (self.cond_shape, self.bijection.cond_shape),
}

for k, v in shapes.items():
if v != (None, None) and prod(v[0]) != prod(v[1]):
raise ValueError(
f"Cannot reshape to a different number of elements. Got {k} "
f"{v[0]}, but bijection has shape {v[1]}.",
)

def transform(self, x, condition=None):
x = x.reshape(self.bijection.shape)
if self.cond_shape is not None:
condition = condition.reshape(self.bijection.cond_shape)
return self.bijection.transform(x, condition).reshape(self.shape)

def inverse(self, y, condition=None):
y = y.reshape(self.bijection.shape)
if self.cond_shape is not None:
condition = condition.reshape(self.bijection.cond_shape)
return self.bijection.inverse(y, condition).reshape(self.shape)

def transform_and_log_det(self, x, condition=None):
x = x.reshape(self.bijection.shape)
if self.cond_shape is not None:
condition = condition.reshape(self.bijection.cond_shape)
y, log_det = self.bijection.transform_and_log_det(x, condition)
return y.reshape(self.shape), log_det

def inverse_and_log_det(self, y, condition=None):
y = y.reshape(self.bijection.shape)
if self.cond_shape is not None:
condition = condition.reshape(self.bijection.cond_shape)
x, log_det = self.bijection.inverse_and_log_det(y, condition)
return x.reshape(self.shape), log_det
9 changes: 6 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,17 @@ pythonpath = ["."]


[tool.ruff]
select = ["E", "F", "B", "D", "COM", "I", "UP", "TRY004", "RET", "PT", "FBT"]
include = ["*.py", "*.pyi", "**/pyproject.toml", "*.ipynb"]

[tool.ruff.lint]
select = ["E", "F", "B", "D", "COM", "I", "UP", "TRY004", "RET", "PT", "FBT"]
ignore = ["D102", "D105", "D107"]

[tool.ruff.pydocstyle]

[tool.ruff.lint.pydocstyle]
convention = "google"

[tool.ruff.per-file-ignores]
[tool.ruff.lint.per-file-ignores]
"tests/*" = ["D"]
"*.ipynb" = ["D"]
"__init__.py" = ["D"]
14 changes: 14 additions & 0 deletions tests/test_bijections/test_bijections.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
Permute,
Planar,
RationalQuadraticSpline,
Reshape,
Scale,
Scan,
SoftPlus,
Expand Down Expand Up @@ -166,6 +167,19 @@
eqx.filter_vmap(Affine)(jnp.ones(3)),
in_axis=eqx.if_array(0),
),
"Reshape (unconditional)": Reshape(Affine(scale=jnp.arange(1, 5)), (2, 2)),
"Reshape (conditional)": Reshape(
MaskedAutoregressive(
KEY,
transformer=Affine(),
dim=4,
cond_dim=1,
nn_width=3,
nn_depth=1,
),
shape=(1, 4, 1),
cond_shape=(),
),
}


Expand Down
Loading