Skip to content

Commit

Permalink
FunMC: Allow early stopping inside trace.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 645136956
  • Loading branch information
SiegeLordEx authored and tensorflower-gardener committed Jun 20, 2024
1 parent b21689a commit f74f179
Show file tree
Hide file tree
Showing 4 changed files with 97 additions and 28 deletions.
46 changes: 31 additions & 15 deletions spinoffs/fun_mc/fun_mc/dynamic/backend_jax/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,8 +169,12 @@ def random_categorical(logits, num_samples, seed):
def trace(state, fn, num_steps, unroll, max_steps, **_):
"""Implementation of `trace` operator, without the calling convention."""
# We need the shapes and dtypes of the outputs of `fn`.
_, untraced_spec, traced_spec = jax.eval_shape(
_, untraced_spec, traced_spec, stop_spec = jax.eval_shape(
fn, map_tree(lambda s: jax.ShapeDtypeStruct(s.shape, s.dtype), state))
if isinstance(stop_spec, tuple):
stop = ()
else:
stop = False
untraced_init, traced_init = map_tree(
lambda spec: jnp.zeros(spec.shape, spec.dtype),
(untraced_spec, traced_spec),
Expand All @@ -194,7 +198,7 @@ def trace(state, fn, num_steps, unroll, max_steps, **_):
'Cannot unroll when `num_steps` is not statically known and '
'`max_steps` is not specified.'
)
if max_steps is not None:
if max_steps is not None or not isinstance(stop_spec, tuple):
use_scan = False

if unroll:
Expand All @@ -203,8 +207,8 @@ def trace(state, fn, num_steps, unroll, max_steps, **_):
traced_lists = map_tree(lambda _: [], traced_spec)
untraced = untraced_init
for step in range(num_outputs):
if step < num_steps:
state, untraced, traced_element = fn(state)
if step < num_steps and not stop:
state, untraced, traced_element, stop = fn(state)
else:
traced_element = traced_init
map_tree_up_to(traced_spec, lambda l, e: l.append(e), traced_lists,
Expand All @@ -217,7 +221,7 @@ def trace(state, fn, num_steps, unroll, max_steps, **_):

def wrapper(state_untraced, _):
state, _ = state_untraced
state, untraced, traced = fn(state)
state, untraced, traced, _ = fn(state)
return (state, untraced), traced

(state, untraced), traced = lax.scan(
Expand All @@ -234,19 +238,31 @@ def wrapper(state_untraced, _):

trace_arrays = map_tree(
lambda spec: jnp.zeros((num_outputs,) + spec.shape, spec.dtype),
traced_spec)
traced_spec,
)
loop_vars = (
jnp.zeros_like(num_steps),
stop,
state,
untraced_init,
trace_arrays,
)

def wrapper(i, state_untraced_traced):
state, _, trace_arrays = state_untraced_traced
state, untraced, traced = fn(state)
def cond(loop_vars):
i, stop, *_ = loop_vars
return (i < num_steps) & (isinstance(stop, tuple) or ~stop)

def body(loop_vars):
i, _, state, _, trace_arrays = loop_vars
state, untraced, traced, stop = fn(state)
trace_arrays = map_tree(lambda a, e: a.at[i].set(e), trace_arrays, traced)
return (state, untraced, trace_arrays)

state, untraced, traced = lax.fori_loop(
jnp.asarray(0, num_steps.dtype),
num_steps,
wrapper,
(state, untraced_init, trace_arrays),
return i + 1, stop, state, untraced, trace_arrays

_, _, state, untraced, traced = lax.while_loop(
cond,
body,
loop_vars,
)
return state, untraced, traced

Expand Down
27 changes: 17 additions & 10 deletions spinoffs/fun_mc/fun_mc/dynamic/backend_tensorflow/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ def trace(state, fn, num_steps, unroll, max_steps, parallel_iterations=10):
num_outputs = num_steps if max_steps is None else max_steps

if tf.config.experimental_functions_run_eagerly() or tf.executing_eagerly():
state, first_untraced, first_traced = fn(state)
state, first_untraced, first_traced, stop = fn(state)
arrays = tf.nest.map_structure(
lambda v: tf.TensorArray( # pylint: disable=g-long-lambda
v.dtype,
Expand All @@ -195,7 +195,7 @@ def trace(state, fn, num_steps, unroll, max_steps, parallel_iterations=10):
# the `TensorArray`s etc., we can get it by pre-compiling the wrapper
# function.
input_spec = tf.nest.map_structure(tf.TensorSpec.from_tensor, state)
fn, (_, untraced_spec, traced_spec) = _eval_shape(fn, input_spec)
fn, (_, untraced_spec, traced_spec, stop_spec) = _eval_shape(fn, input_spec)

arrays = tf.nest.map_structure(
lambda spec: tf.TensorArray( # pylint: disable=g-long-lambda
Expand All @@ -206,18 +206,23 @@ def trace(state, fn, num_steps, unroll, max_steps, parallel_iterations=10):
first_untraced = tf.nest.map_structure(
lambda spec: tf.zeros(spec.shape, spec.dtype), untraced_spec)
start_idx = 0
if isinstance(stop_spec, tuple):
stop = ()
else:
stop = False

def body(i, state, _, arrays):
state, untraced, traced = fn(state)
def body(i, stop, state, untraced, arrays):
del stop, untraced
state, untraced, traced, stop = fn(state)
arrays = tf.nest.map_structure(lambda a, e: a.write(i, e), arrays, traced)
return i + 1, state, untraced, arrays
return i + 1, stop, state, untraced, arrays

def cond(i, *_):
return i < num_steps
def cond(i, stop, *_):
return (i < num_steps) & (isinstance(stop, tuple) or ~stop)

static_num_steps = tf.get_static_value(num_steps)
static_num_outputs = tf.get_static_value(num_outputs)
loop_vars = (start_idx, state, first_untraced, arrays)
loop_vars = (start_idx, stop, state, first_untraced, arrays)

if unroll:
if static_num_steps is None:
Expand All @@ -233,8 +238,10 @@ def cond(i, *_):
# TODO(siege): Investigate if using lists instead of TensorArray's is faster
# (like is done in the JAX backend).
for _ in range(start_idx, static_num_iters):
if loop_vars[1]:
break
loop_vars = body(*loop_vars)
_, state, untraced, arrays = loop_vars
_, _, state, untraced, arrays = loop_vars
else:
if static_num_steps is None:
if max_steps is None:
Expand All @@ -246,7 +253,7 @@ def cond(i, *_):
maximum_iterations = static_num_steps - start_idx
else:
maximum_iterations = min(static_num_steps, max_steps) - start_idx
_, state, untraced, arrays = tf.while_loop(
_, _, state, untraced, arrays = tf.while_loop(
cond=cond,
body=body,
loop_vars=loop_vars,
Expand Down
16 changes: 13 additions & 3 deletions spinoffs/fun_mc/fun_mc/fun_mc_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,7 @@ def trace(
trace_mask: bool | BooleanNest = True,
unroll: bool = False,
max_steps: int | None = None,
stop_fn: Callable[[State, ArrayNest], BooleanArray] | None = None,
parallel_iterations: int = 10,
) -> tuple[State, ArrayNest]:
"""`TransitionOperator` that runs `fn` repeatedly and traces its outputs.
Expand All @@ -245,8 +246,11 @@ def trace(
performance at the cost of increasing the XLA optimization time. Only
works if `num_steps` is statically known.
max_steps: If `num_steps` is not statically known and you still want to
trace values, you can use `max_steps` to allocate output trace to be of
this length. Only elements up to `num_steps` will be valid, however.
trace values, you can use `max_steps` to allocate output trace to be of
this length. Only elements up to `num_steps` will be valid, however.
stop_fn: Optional callable that takes in the outputs of `fn` and returns a
boolean. If `True`, then the iteration is stopped. Only the elements
stored into traces up to and including the step on which `stop_fn`
returned `True` are valid.
parallel_iterations: Number of iterations of the while loop to run in
parallel (TensorFlow-only).
Expand Down Expand Up @@ -286,11 +290,17 @@ def wrapper(state):
state, extra = util.map_tree(
util.convert_to_tensor, call_transition_operator(fn, state)
)
if stop_fn is None:
# For TF compatibility, we can't use None. () is conveniently "falsy",
# which we rely on in the backend implementations.
stop = ()
else:
stop = stop_fn(state, extra)
trace_element = util.map_tree(
util.convert_to_tensor, trace_fn(state, extra)
)
untraced, traced = _split_trace(trace_element, trace_mask)
return state, untraced, traced
return state, untraced, traced, stop

state = util.map_tree(util.convert_to_tensor, state)

Expand Down
36 changes: 36 additions & 0 deletions spinoffs/fun_mc/fun_mc/fun_mc_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,42 @@ def testTraceMaxSteps(self, unroll):
self.assertEqual(6, traced.shape[0])
self.assertAllEqual(400, untraced)

@parameterized.named_parameters(
    ('Unrolled', True),
    ('NotUnrolled', False),
)
def testTraceStopFnSingle(self, unroll):
  """Checks `trace` when `stop_fn` fires right after the first step.

  The state starts at 0 and is incremented by one per step, so a `stop_fn`
  that tests for a state of 1 becomes true on the very first step. The trace
  is still allocated for all 5 requested steps.

  Args:
    unroll: Forwarded to `fun_mc.trace` as its `unroll` argument.
  """

  def step(x):
    # New state is `x + 1`; the extras are split by `trace_mask` below into a
    # traced component (`10 * x`) and an untraced one (`100 * x`).
    return x + 1, (10 * x, 100 * x)

  def should_stop(x, _):
    return x == 1

  final_state, (traced, untraced) = fun_mc.trace(
      0,
      step,
      5,
      unroll=unroll,
      trace_mask=(True, False),
      stop_fn=should_stop,
  )

  self.assertAllEqual(1, final_state)
  # Only the first trace slot was written before the loop stopped.
  self.assertAllEqual(0, traced[0])
  self.assertEqual(5, traced.shape[0])
  self.assertAllEqual(0, untraced)

@parameterized.named_parameters(
    ('Unrolled', True),
    ('NotUnrolled', False),
)
def testTraceStopFnMulti(self, unroll):
  """Checks `trace` when `stop_fn` fires after several steps.

  The state starts at 0 and is incremented by one per step, so a `stop_fn`
  that tests for a state of 3 becomes true on the third step. The trace is
  still allocated for all 5 requested steps.

  Args:
    unroll: Forwarded to `fun_mc.trace` as its `unroll` argument.
  """

  def step(x):
    # New state is `x + 1`; the extras are split by `trace_mask` below into a
    # traced component (`10 * x`) and an untraced one (`100 * x`).
    return x + 1, (10 * x, 100 * x)

  def should_stop(x, _):
    return x == 3

  final_state, (traced, untraced) = fun_mc.trace(
      0,
      step,
      5,
      unroll=unroll,
      trace_mask=(True, False),
      stop_fn=should_stop,
  )

  self.assertAllEqual(3, final_state)
  # The third step ran with x == 2, so it wrote 10 * 2 into the trace and
  # left 100 * 2 as the final untraced value.
  self.assertAllEqual(20, traced[2])
  self.assertEqual(5, traced.shape[0])
  self.assertAllEqual(200, untraced)

@parameterized.named_parameters(
('Unrolled', True),
('NotUnrolled', False),
Expand Down

0 comments on commit f74f179

Please sign in to comment.