Allow multiple datasets in fit_to_data and add option to return opt_state #210

Open · wants to merge 3 commits into main
44 changes: 36 additions & 8 deletions flowjax/train/loops.py
@@ -1,6 +1,7 @@
"""Training loops."""

from collections.abc import Callable
from warnings import warn

import equinox as eqx
import jax.numpy as jnp
@@ -28,6 +29,8 @@ def fit_to_key_based_loss(
learning_rate: float = 5e-4,
optimizer: optax.GradientTransformation | None = None,
show_progress: bool = True,
opt_state: optax.OptState | None = None,
return_opt_state: bool = False,
):
"""Train a pytree, using a loss with params, static and key as arguments.

@@ -43,6 +46,8 @@ def fit_to_key_based_loss(
learning_rate: The adam learning rate. Ignored if optimizer is provided.
optimizer: Optax optimizer. Defaults to None.
show_progress: Whether to show progress bar. Defaults to True.
opt_state: Optional initial optimizer state, used to resume training. Defaults to None.
return_opt_state: Whether to return the optimizer state.

Returns:
A tuple containing the trained pytree and the losses.
@@ -55,7 +60,8 @@ def fit_to_key_based_loss(
eqx.is_inexact_array,
is_leaf=lambda leaf: isinstance(leaf, paramax.NonTrainable),
)
opt_state = optimizer.init(params)
if opt_state is None:
opt_state = optimizer.init(params)

losses = []

@@ -72,14 +78,15 @@
)
losses.append(loss.item())
keys.set_postfix({"loss": loss.item()})
if return_opt_state:
return eqx.combine(params, static), losses, opt_state
return eqx.combine(params, static), losses


def fit_to_data(
key: PRNGKeyArray,
dist: PyTree, # Custom losses may support broader types than AbstractDistribution
x: ArrayLike,
*,
*data: ArrayLike,
condition: ArrayLike | None = None,
loss_fn: Callable | None = None,
learning_rate: float = 5e-4,
@@ -90,6 +97,8 @@ def fit_to_data(
val_prop: float = 0.1,
return_best: bool = True,
show_progress: bool = True,
opt_state: optax.OptState | None = None,
return_opt_state: bool = False,
):
r"""Train a PyTree (e.g. a distribution) to samples from the target.

@@ -101,11 +110,14 @@
Args:
key: Jax random seed.
dist: The pytree to train (usually a distribution).
x: Samples from target distribution.
data: Samples from the target distribution. If several arrays are passed,
each one is split into batches along the first axis, and one batch of each
is passed to the loss function.
learning_rate: The learning rate for adam optimizer. Ignored if optimizer is
provided.
optimizer: Optax optimizer. Defaults to None.
condition: Conditioning variables. Defaults to None.
condition: Conditioning variables. Defaults to None. This argument is
deprecated; pass conditioning variables as an additional data array instead.
loss_fn: Loss function. Defaults to MaximumLikelihoodLoss.
max_epochs: Maximum number of epochs. Defaults to 100.
max_patience: Number of consecutive epochs with no validation loss improvement
@@ -116,11 +128,22 @@
was reached (when True), or the parameters after the last update (when
False). Defaults to True.
show_progress: Whether to show progress bar. Defaults to True.
opt_state: Optional initial optimizer state, used to resume training. Defaults to None.
return_opt_state: Whether to return the optimizer state.

Returns:
A tuple containing the trained distribution and the losses.
A tuple containing the trained distribution and a dict of losses with
"train" and "val" keys. If return_opt_state is True, the final optimizer
state is also returned.
"""
data = (x,) if condition is None else (x, condition)
if condition is not None:
    warn(
        "The `condition` argument is deprecated. "
        "You can pass conditioning data as additional data arrays.",
        DeprecationWarning,
    )
    data = (*data, condition)
data = tuple(jnp.asarray(a) for a in data)

if loss_fn is None:
@@ -135,7 +158,8 @@
is_leaf=lambda leaf: isinstance(leaf, paramax.NonTrainable),
)
best_params = params
opt_state = optimizer.init(params)
if opt_state is None:
opt_state = optimizer.init(params)

# train val split
key, subkey = jr.split(key)
@@ -184,4 +208,8 @@

params = best_params if return_best else params
dist = eqx.combine(params, static)

if return_opt_state:
return dist, losses, opt_state

return dist, losses
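
Taken together, the changes above let callers pass any number of data arrays to fit_to_data and carry the optimizer state across calls (fit_to_key_based_loss gains the same opt_state/return_opt_state pattern). Below is a minimal usage sketch, not part of the diff: the toy distribution, data, and custom loss are illustrative, and the call pattern assumes the signature shown above.

import equinox as eqx
import jax.numpy as jnp
from jax import random
from paramax.wrappers import unwrap

from flowjax.bijections import Affine
from flowjax.distributions import Normal, Transformed
from flowjax.train import fit_to_data

dim = 3
dist = Transformed(Normal(jnp.zeros(dim)), Affine(jnp.ones(dim), jnp.ones(dim)))
values = random.normal(random.key(0), (100, dim))
log_probs = random.normal(random.key(1), (100,))


def loss_fn(params, static, values, log_probs, key=None):
    # One batch of each data array is passed to the loss per step.
    model = unwrap(eqx.combine(params, static, is_leaf=eqx.is_inexact_array))
    return ((log_probs - model.log_prob(values)) ** 2).mean()


# First call: also return the optimizer state.
dist, losses, opt_state = fit_to_data(
    random.key(2),
    dist,
    values,
    log_probs,
    loss_fn=loss_fn,
    max_epochs=5,
    batch_size=50,
    return_opt_state=True,
)

# Later call: resume training (e.g. on new data) from the previous optimizer state.
dist, losses = fit_to_data(
    random.key(3),
    dist,
    values,
    log_probs,
    loss_fn=loss_fn,
    max_epochs=5,
    batch_size=50,
    opt_state=opt_state,
)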
68 changes: 66 additions & 2 deletions tests/test_train/test_data_fit.py
@@ -1,6 +1,8 @@
import equinox as eqx
import jax.numpy as jnp
from jax import random
from paramax.wrappers import unwrap
import optax

from flowjax.bijections import Affine
from flowjax.distributions import Normal, Transformed
@@ -18,13 +20,75 @@ def test_data_fit():
x = random.normal(random.key(0), (100, dim))
flow, losses = fit_to_data(
random.key(0),
dist=flow,
x=x,
flow,
x,
max_epochs=1,
batch_size=50,
)
after = eqx.filter(flow, eqx.is_inexact_array)

flow2, losses2, opt_state = fit_to_data(
random.key(0),
flow,
x,
max_epochs=1,
batch_size=50,
return_opt_state=True,
)

assert jnp.all(before.base_dist.bijection.loc != after.base_dist.bijection.loc)
assert jnp.all(before.bijection.loc != after.bijection.loc)
assert isinstance(losses["train"][0], float)
assert isinstance(losses["val"][0], float)


def test_data_fit_opt_state():
dim = 3
mean, std = jnp.ones(dim), jnp.ones(dim)
base_dist = Normal(mean, std)
flow = Transformed(base_dist, Affine(jnp.ones(dim), jnp.ones(dim)))

# All params should change by default
before = eqx.filter(flow, eqx.is_inexact_array)
values = random.normal(random.key(0), (100, dim))
log_probs = random.normal(random.key(1), (100,))

def loss_fn(params, static, values, log_probs, key=None):
    # Custom loss over two data arrays: match model log-density to target log_probs.
    flow = unwrap(eqx.combine(params, static, is_leaf=eqx.is_inexact_array))
    return ((log_probs - flow.log_prob(values)) ** 2).mean()

flow, losses, opt_state = fit_to_data(
random.key(0),
flow,
values,
log_probs,
loss_fn=loss_fn,
max_epochs=1,
batch_size=50,
return_opt_state=True,
)
after = eqx.filter(flow, eqx.is_inexact_array)

assert jnp.all(before.base_dist.bijection.loc != after.base_dist.bijection.loc)
assert jnp.all(before.bijection.loc != after.bijection.loc)
assert isinstance(losses["train"][0], float)
assert isinstance(losses["val"][0], float)

# Continue training on new data
values = random.normal(random.key(2), (100, dim))
log_probs = random.normal(random.key(3), (100,))

flow, losses, opt_state = fit_to_data(
random.key(4),
flow,
values,
log_probs,
loss_fn=loss_fn,
max_epochs=1,
batch_size=50,
return_opt_state=True,
opt_state=opt_state,
)
after = eqx.filter(flow, eqx.is_inexact_array)

assert jnp.all(before.base_dist.bijection.loc != after.base_dist.bijection.loc)
assert jnp.all(before.bijection.loc != after.bijection.loc)
assert isinstance(losses["train"][0], float)
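
The tests above cover fit_to_data; the same resumption pattern applies to fit_to_key_based_loss. A sketch under stated assumptions: the `loss_fn` and `steps` keyword names are not visible in this diff and are assumed from the existing signature, and the key-based toy loss is illustrative.

import equinox as eqx
import jax.numpy as jnp
from jax import random
from paramax.wrappers import unwrap

from flowjax.bijections import Affine
from flowjax.distributions import Normal, Transformed
from flowjax.train.loops import fit_to_key_based_loss

dim = 3
tree = Transformed(Normal(jnp.zeros(dim)), Affine(jnp.ones(dim), jnp.ones(dim)))


def loss_fn(params, static, key):
    # Key-based toy loss: mean squared norm of samples drawn with `key`.
    model = unwrap(eqx.combine(params, static, is_leaf=eqx.is_inexact_array))
    return jnp.mean(model.sample(key, (64,)) ** 2)


# `loss_fn` and `steps` keyword names are assumptions (not shown in this diff).
tree, losses, opt_state = fit_to_key_based_loss(
    random.key(0), tree, loss_fn=loss_fn, steps=10, return_opt_state=True
)

# Resume from the returned optimizer state.
tree, losses = fit_to_key_based_loss(
    random.key(1), tree, loss_fn=loss_fn, steps=10, opt_state=opt_state
)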