Allow multiple positional data arrays in fit_to_data
aseyboldt committed Feb 18, 2025
1 parent 7364b4d commit 954fdd7
Showing 1 changed file with 19 additions and 9 deletions.
28 changes: 19 additions & 9 deletions flowjax/train/loops.py
@@ -1,6 +1,7 @@
"""Training loops."""

from collections.abc import Callable
import warnings

import equinox as eqx
import jax.numpy as jnp
@@ -28,6 +29,8 @@ def fit_to_key_based_loss(
learning_rate: float = 5e-4,
optimizer: optax.GradientTransformation | None = None,
show_progress: bool = True,
opt_state: optax.OptState | None = None,
return_opt_state: bool = False,
):
"""Train a pytree, using a loss with params, static and key as arguments.
@@ -43,6 +46,8 @@ def fit_to_key_based_loss(
learning_rate: The adam learning rate. Ignored if optimizer is provided.
optimizer: Optax optimizer. Defaults to None.
show_progress: Whether to show progress bar. Defaults to True.
        opt_state: Optional initial state of the optimizer.
return_opt_state: Whether to return the optimizer state.
Returns:
        A tuple containing the trained pytree and the losses, plus the final
        optimizer state if return_opt_state is true.
@@ -55,7 +60,8 @@ def fit_to_key_based_loss(
eqx.is_inexact_array,
is_leaf=lambda leaf: isinstance(leaf, paramax.NonTrainable),
)
opt_state = optimizer.init(params)
if opt_state is None:
opt_state = optimizer.init(params)

losses = []

@@ -72,14 +78,15 @@
)
losses.append(loss.item())
keys.set_postfix({"loss": loss.item()})
if return_opt_state:
return eqx.combine(params, static), losses, opt_state
return eqx.combine(params, static), losses


def fit_to_data(
key: PRNGKeyArray,
dist: PyTree, # Custom losses may support broader types than AbstractDistribution
x: ArrayLike | tuple[ArrayLike, ...],
*,
*data: ArrayLike,
condition: ArrayLike | None = None,
loss_fn: Callable | None = None,
learning_rate: float = 5e-4,
@@ -103,11 +110,14 @@ def fit_to_data(
Args:
key: Jax random seed.
dist: The pytree to train (usually a distribution).
x: Samples from target distribution.
        data: Samples from the target distribution. If several arrays are
            passed, each one is split into batches along the first axis, and
            one batch of each is passed to the loss function.
learning_rate: The learning rate for adam optimizer. Ignored if optimizer is
provided.
optimizer: Optax optimizer. Defaults to None.
condition: Conditioning variables. Defaults to None.
        condition: Conditioning variables. Defaults to None. This argument is
            deprecated; pass conditioning data as the last positional data
            array instead.
loss_fn: Loss function. Defaults to MaximumLikelihoodLoss.
max_epochs: Maximum number of epochs. Defaults to 100.
max_patience: Number of consecutive epochs with no validation loss improvement
@@ -127,11 +137,11 @@
If an opt_state is provided, it will also return the new opt_state.
"""
if isinstance(x, tuple):
data = x
else:
data = (x,)
    if condition is not None:
        warnings.warn(
            "The `condition` argument is deprecated. "
            "You can pass condition data as additional data arrays.",
            DeprecationWarning,
            stacklevel=2,
        )
        data = (*data, condition)
data = tuple(jnp.asarray(a) for a in data)

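A minimal sketch of the new calling convention for fit_to_data. The flow constructor, its arguments, and the data below are illustrative rather than taken from this commit, and it assumes the default maximum-likelihood loss accepts the conditioning batch as described in the docstring above:

import jax.numpy as jnp
import jax.random as jr
from flowjax.distributions import Normal
from flowjax.flows import masked_autoregressive_flow
from flowjax.train import fit_to_data

key, flow_key, train_key = jr.split(jr.PRNGKey(0), 3)
x = jr.normal(key, (1000, 2))   # samples from the target distribution
cond = jnp.sin(x[:, :1])        # conditioning variables

flow = masked_autoregressive_flow(
    flow_key, base_dist=Normal(jnp.zeros(2)), cond_dim=1
)

# Previously: fit_to_data(train_key, flow, x, condition=cond)
# Now both arrays are positional; each is split into batches along its
# first axis, and one batch of each is passed to the loss together.
flow, losses = fit_to_data(train_key, flow, x, cond, max_epochs=10)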

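And a sketch of round-tripping the optimizer state through fit_to_key_based_loss, the other change in this commit. The loss_fn and steps keywords are assumptions based on the docstring above, and toy_loss is a hypothetical stand-in for a real key-based objective:

import equinox as eqx
import jax.numpy as jnp
import jax.random as jr
from flowjax.distributions import Normal
from flowjax.train import fit_to_key_based_loss

def toy_loss(params, static, key):
    # Key-based loss: negative log prob of fresh standard-normal draws.
    dist = eqx.combine(params, static)
    samples = jr.normal(key, (256, 2))
    return -dist.log_prob(samples).mean()

dist = Normal(jnp.zeros(2))
key1, key2 = jr.split(jr.PRNGKey(0))

dist, losses, opt_state = fit_to_key_based_loss(
    key1, dist, loss_fn=toy_loss, steps=100, return_opt_state=True
)
# Resume later without resetting the Adam moment estimates:
dist, more_losses = fit_to_key_based_loss(
    key2, dist, loss_fn=toy_loss, steps=100, opt_state=opt_state
)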