Skip to content

Commit

Permalink
FunMC: Allow specifying max_steps to enable tracing even when `num_…
Browse files Browse the repository at this point in the history
…steps` is dynamic.

PiperOrigin-RevId: 645136707
  • Loading branch information
SiegeLordEx authored and tensorflower-gardener committed Jun 20, 2024
1 parent e60a96d commit b21689a
Show file tree
Hide file tree
Showing 4 changed files with 92 additions and 25 deletions.
47 changes: 33 additions & 14 deletions spinoffs/fun_mc/fun_mc/dynamic/backend_jax/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,33 +166,47 @@ def random_categorical(logits, num_samples, seed):
return jax.vmap(_searchsorted)(flat_cum_sum, flat_eta).reshape(eta.shape).T


def trace(state, fn, num_steps, unroll, **_):
def trace(state, fn, num_steps, unroll, max_steps, **_):
"""Implementation of `trace` operator, without the calling convention."""
# We need the shapes and dtypes of the outputs of `fn`.
_, untraced_spec, traced_spec = jax.eval_shape(
fn, map_tree(lambda s: jax.ShapeDtypeStruct(s.shape, s.dtype), state))
untraced_init = map_tree(lambda spec: jnp.zeros(spec.shape, spec.dtype),
untraced_spec)
untraced_init, traced_init = map_tree(
lambda spec: jnp.zeros(spec.shape, spec.dtype),
(untraced_spec, traced_spec),
)

try:
num_steps = int(num_steps)
use_scan = True
except TypeError:
use_scan = False
if flatten_tree(traced_spec):
raise ValueError(
'Cannot trace values when `num_steps` is not statically known. Pass '
'False to `trace_mask` or return an empty structure (e.g. `()`) as '
'the extra output.')
if unroll:
raise ValueError(
'Cannot unroll when `num_steps` is not statically known.')
if max_steps is None:
if flatten_tree(traced_spec):
raise ValueError( # pylint: disable=raise-missing-from
'Cannot trace values when `num_steps` is not statically known and '
'`max_steps` is not specified. Pass `False` to `trace_mask` or '
'return an empty structure (e.g. `()`) as '
'the extra output.'
)
if unroll:
raise ValueError( # pylint: disable=raise-missing-from
'Cannot unroll when `num_steps` is not statically known and '
'`max_steps` is not specified.'
)
if max_steps is not None:
use_scan = False

if unroll:
num_outputs = num_steps if max_steps is None else max_steps

traced_lists = map_tree(lambda _: [], traced_spec)
untraced = untraced_init
for _ in range(num_steps):
state, untraced, traced_element = fn(state)
for step in range(num_outputs):
if step < num_steps:
state, untraced, traced_element = fn(state)
else:
traced_element = traced_init
map_tree_up_to(traced_spec, lambda l, e: l.append(e), traced_lists,
traced_element)
# Using asarray instead of stack to handle empty arrays correctly.
Expand All @@ -213,8 +227,13 @@ def wrapper(state_untraced, _):
length=num_steps,
)
else:
num_outputs = num_steps if max_steps is None else max_steps
num_steps = (
num_steps if max_steps is None else jnp.minimum(num_steps, max_steps)
)

trace_arrays = map_tree(
lambda spec: jnp.zeros((num_steps,) + spec.shape, spec.dtype),
lambda spec: jnp.zeros((num_outputs,) + spec.shape, spec.dtype),
traced_spec)

def wrapper(i, state_untraced_traced):
Expand Down
32 changes: 24 additions & 8 deletions spinoffs/fun_mc/fun_mc/dynamic/backend_tensorflow/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,14 +177,16 @@ def compiled_fn(x):
return compiled_fn, output_spec


def trace(state, fn, num_steps, unroll, parallel_iterations=10):
def trace(state, fn, num_steps, unroll, max_steps, parallel_iterations=10):
"""TF implementation of `trace` operator, without the calling convention."""
num_outputs = num_steps if max_steps is None else max_steps

if tf.config.experimental_functions_run_eagerly() or tf.executing_eagerly():
state, first_untraced, first_traced = fn(state)
arrays = tf.nest.map_structure(
lambda v: tf.TensorArray( # pylint: disable=g-long-lambda
v.dtype,
size=num_steps,
size=num_outputs,
element_shape=v.shape).write(0, v),
first_traced)
start_idx = 1
Expand All @@ -198,7 +200,7 @@ def trace(state, fn, num_steps, unroll, parallel_iterations=10):
arrays = tf.nest.map_structure(
lambda spec: tf.TensorArray( # pylint: disable=g-long-lambda
spec.dtype,
size=num_steps,
size=num_outputs,
element_shape=spec.shape),
traced_spec)
first_untraced = tf.nest.map_structure(
Expand All @@ -214,22 +216,36 @@ def cond(i, *_):
return i < num_steps

static_num_steps = tf.get_static_value(num_steps)
static_num_outputs = tf.get_static_value(num_outputs)
loop_vars = (start_idx, state, first_untraced, arrays)

if unroll:
if static_num_steps is None:
raise ValueError(
'Cannot unroll when `num_steps` is not statically known.')
'Cannot unroll when `num_steps` is not statically known or '
'`max_steps` is None.'
)
static_num_iters = (
static_num_steps
if max_steps is None
else min(static_num_steps, max_steps)
)
# TODO(siege): Investigate if using lists instead of TensorArray's is faster
# (like is done in the JAX backend).
for _ in range(start_idx, static_num_steps):
for _ in range(start_idx, static_num_iters):
loop_vars = body(*loop_vars)
_, state, untraced, arrays = loop_vars
else:
if static_num_steps is None:
maximum_iterations = None
if max_steps is None:
maximum_iterations = None
else:
maximum_iterations = max_steps - start_idx
else:
maximum_iterations = static_num_steps - start_idx
if max_steps is None:
maximum_iterations = static_num_steps - start_idx
else:
maximum_iterations = min(static_num_steps, max_steps) - start_idx
_, state, untraced, arrays = tf.while_loop(
cond=cond,
body=body,
Expand All @@ -241,7 +257,7 @@ def cond(i, *_):
traced = tf.nest.map_structure(lambda a: a.stack(), arrays)

def _merge_static_length(x):
x.set_shape(tf.TensorShape(static_num_steps).concatenate(x.shape[1:]))
x.set_shape(tf.TensorShape(static_num_outputs).concatenate(x.shape[1:]))
return x

traced = tf.nest.map_structure(_merge_static_length, traced)
Expand Down
7 changes: 6 additions & 1 deletion spinoffs/fun_mc/fun_mc/fun_mc_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,7 @@ def trace(
trace_fn: Callable[[State, ArrayNest], ArrayNest] = _trace_extra,
trace_mask: bool | BooleanNest = True,
unroll: bool = False,
max_steps: int | None = None,
parallel_iterations: int = 10,
) -> tuple[State, ArrayNest]:
"""`TransitionOperator` that runs `fn` repeatedly and traces its outputs.
Expand All @@ -243,8 +244,11 @@ def trace(
unroll: Whether to unroll the loop. This can occasionally lead to improved
performance at the cost of increasing the XLA optimization time. Only
works if `num_steps` is statically known.
max_steps: If `num_steps` is not statically known and you still want to
trace values, you can use `max_steps` to allocate output trace to be of
this length. Only elements up to `num_steps` will be valid, however.
parallel_iterations: Number of iterations of the while loop to run in
parallel.
parallel (TensorFlow-only).
Returns:
state: The final state returned by `fn`.
Expand Down Expand Up @@ -295,6 +299,7 @@ def wrapper(state):
fn=wrapper,
num_steps=num_steps,
unroll=unroll,
max_steps=max_steps,
parallel_iterations=parallel_iterations,
)

Expand Down
31 changes: 29 additions & 2 deletions spinoffs/fun_mc/fun_mc/fun_mc_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,10 +211,37 @@ def testTraceDynamic(self):

@jax.jit
def trace_n(num_steps):
return fun_mc.trace(0, lambda x: (x + 1, ()), num_steps)[0]
return fun_mc.trace(
0,
lambda x: (x + 1, (10 * x, 100 * x)),
num_steps,
max_steps=6,
trace_mask=(True, False),
)

x = trace_n(5)
x, (traced, untraced) = trace_n(5)
self.assertAllEqual(5, x)
self.assertAllEqual(40, traced[4])
self.assertEqual(6, traced.shape[0])
self.assertAllEqual(400, untraced)

@parameterized.named_parameters(
('Unrolled', True),
('NotUnrolled', False),
)
def testTraceMaxSteps(self, unroll):
x, (traced, untraced) = fun_mc.trace(
0,
lambda x: (x + 1, (10 * x, 100 * x)),
5,
max_steps=6,
unroll=unroll,
trace_mask=(True, False),
)
self.assertAllEqual(5, x)
self.assertAllEqual(40, traced[4])
self.assertEqual(6, traced.shape[0])
self.assertAllEqual(400, untraced)

@parameterized.named_parameters(
('Unrolled', True),
Expand Down

0 comments on commit b21689a

Please sign in to comment.