From 90903545fdf63d6b2fc312f86f7791ae2bbac4d4 Mon Sep 17 00:00:00 2001 From: mikeevmm Date: Sun, 4 Apr 2021 19:21:53 +0100 Subject: [PATCH 1/4] feat: implementation of bounded LBFGS algorithm --- .../python/optimizer/__init__.py | 2 + .../python/optimizer/lbfgsb.py | 1625 +++++++++++++++++ 2 files changed, 1627 insertions(+) create mode 100644 tensorflow_probability/python/optimizer/lbfgsb.py diff --git a/tensorflow_probability/python/optimizer/__init__.py b/tensorflow_probability/python/optimizer/__init__.py index 2187b56593..7ad87a8b93 100644 --- a/tensorflow_probability/python/optimizer/__init__.py +++ b/tensorflow_probability/python/optimizer/__init__.py @@ -27,6 +27,7 @@ from tensorflow_probability.python.optimizer.differential_evolution import minimize as differential_evolution_minimize from tensorflow_probability.python.optimizer.differential_evolution import one_step as differential_evolution_one_step from tensorflow_probability.python.optimizer.lbfgs import minimize as lbfgs_minimize +from tensorflow_probability.python.optimizer.lbfgsb import minimize as lbfgsb_minimize from tensorflow_probability.python.optimizer.nelder_mead import minimize as nelder_mead_minimize from tensorflow_probability.python.optimizer.nelder_mead import nelder_mead_one_step from tensorflow_probability.python.optimizer.proximal_hessian_sparse import minimize as proximal_hessian_sparse_minimize @@ -42,6 +43,7 @@ 'differential_evolution_minimize', 'differential_evolution_one_step', 'lbfgs_minimize', + 'lbfgsb_minimize', 'nelder_mead_minimize', 'nelder_mead_one_step', 'proximal_hessian_sparse_minimize', diff --git a/tensorflow_probability/python/optimizer/lbfgsb.py b/tensorflow_probability/python/optimizer/lbfgsb.py new file mode 100644 index 0000000000..4e7c9e6d8e --- /dev/null +++ b/tensorflow_probability/python/optimizer/lbfgsb.py @@ -0,0 +1,1625 @@ +# Copyright 2018 The TensorFlow Probability Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""A constrained version of the Limited-Memory BFGS minimization algorithm. + +Limited-memory quasi-Newton methods are useful for solving large problems +whose Hessian matrices cannot be computed at a reasonable cost or are not +sparse. Instead of storing fully dense n x n approximations of Hessian +matrices, they only save a few vectors of length n that represent the +approximations implicitly. + +This module implements the algorithm known as L-BFGS-B, which, as its name +suggests, is a limited-memory version of the BFGS algorithm, with bounds. 
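+
+Following [Byrd, Lu, Nocedal and Zhu (1995)], the limited-memory approximation
+to the Hessian is kept in the implicit factored form `B = theta*I - W M W'`,
+where `W` collects the stored gradient deltas and (scaled) position deltas and
+`M` is a small `2m x 2m` matrix built from them. Each iteration computes a
+generalized Cauchy point along the projected steepest-descent direction,
+refines it over the currently free (unconstrained) variables, and finishes
+with a line search whose step is clipped to the bounds.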
+""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +from numpy.core.fromnumeric import argmin, clip + +from tensorflow.python.ops.gen_array_ops import gather, lower_bound, where +from tensorflow_probability.python.internal.backend.numpy import dtype, numpy_math + +# Dependency imports +import tensorflow.compat.v2 as tf + +from tensorflow_probability.python.internal import dtype_util +from tensorflow_probability.python.internal import prefer_static as ps +from tensorflow_probability.python.optimizer import bfgs_utils +from tensorflow_probability.python.optimizer import lbfgs_minimize + + +LBfgsBOptimizerResults = collections.namedtuple( + 'LBfgsBOptimizerResults', [ + 'converged', # Scalar boolean tensor indicating whether the minimum + # was found within tolerance. + 'failed', # Scalar boolean tensor indicating whether a line search + # step failed to find a suitable step size satisfying Wolfe + # conditions. In the absence of any constraints on the + # number of objective evaluations permitted, this value will + # be the complement of `converged`. However, if there is + # a constraint and the search stopped due to available + # evaluations being exhausted, both `failed` and `converged` + # will be simultaneously False. + 'num_iterations', # The number of iterations of the BFGS update. + 'num_objective_evaluations', # The total number of objective + # evaluations performed. + 'position', # A tensor containing the last argument value found + # during the search. If the search converged, then + # this value is the argmin of the objective function. + 'lower_bounds', # A tensor containing the lower bounds to the constrained + # optimization, cast to the shape of `position`. + 'upper_bounds', # A tensor containing the upper bounds to the constrained + # optimization, cast to the shape of `position`. + 'objective_value', # A tensor containing the value of the objective + # function at the `position`. If the search + # converged, then this is the (local) minimum of + # the objective function. + 'objective_gradient', # A tensor containing the gradient of the + # objective function at the + # `final_position`. If the search converged + # the max-norm of this tensor should be + # below the tolerance. + 'position_deltas', # A tensor encoding information about the latest + # changes in `position` during the algorithm + # execution. Its shape is of the form + # `(num_correction_pairs,) + position.shape` where + # `num_correction_pairs` is given as an argument to + # the minimize function. + 'gradient_deltas', # A tensor encoding information about the latest + # changes in `objective_gradient` during the + # algorithm execution. Has the same shape as + # position_deltas. + 'history', # How many gradient/position deltas should be considered. + ]) + +_ConstrainedCauchyState = collections.namedtuple( + '_ConstrainedCauchyResult', [ + 'theta', # `\theta` in [2]; n the Cauchy search, relates to the implicit Hessian + # `B = \theta*I - WMW'` (`I` the identity, see [1,2] for details) + 'm', # `M_k` matrix in [2]; part of the implicit representation of the Hessian, + # see the comment above + 'breakpoints', # `t_i` in [Byrd et al.][2]; + # the breakpoints in the branch definition of the + # projection of the gradients, batched + 'steepest', # `d` in [2]; steepest descent clamped to bounds + 'free_vars_idx', # `\mathcal{F}` of [2]; the indices of (currently) free variables. 
+ # Indices that are no longer free are marked with a negative value. + # This is used instead of a ragged tensor because the size of the + # state object must remain constant between iterations of the + # while loop. + 'free_mask', # Boolean mask of free variables + 'p', # as in [2] + 'c', # as in [2] + 'df', # `f'` in [2]; (corrected) gradient 2-norm + 'ddf', # `f''` in [2]; (corrected) laplacian 2-norm (?) + 'dt_min', # `\Delta t_min` in [2]; the minimizing parameter + # along the search direction + 'breakpoint_min', # `t` in [2] + 'breakpoint_min_idx', # `b` in [2] + 'dt', # `\Delta t` in [2] + 'breakpoint_min_old', # t_old in [2] + 'cauchy_point', # `x^cp` in [2]; the actual cauchy point (we're looking for) + 'active', # What batches are in active optimization + ]) + +def minimize(value_and_gradients_function, + initial_position, + bounds=None, + previous_optimizer_results=None, + num_correction_pairs=10, + tolerance=1e-8, + x_tolerance=0, + f_relative_tolerance=0, + initial_inverse_hessian_estimate=None, + max_iterations=50, + parallel_iterations=1, + stopping_condition=None, + max_line_search_iterations=50, + name=None): + """Applies the L-BFGS-B algorithm to minimize a differentiable function. + + Performs optionally constrained minimization of a differentiable function using the + L-BFGS-B scheme. See [Nocedal and Wright(2006)][1] for details on the unconstrained + version, and [Byrd et al.][2] for details on the constrained algorithm. + + ### Usage: + + The following example demonstrates the L-BFGS-B optimizer attempting to find the + constrained minimum for a simple high-dimensional quadratic objective function. + + ```python + TODO + ``` + + ### References: + + [1] Jorge Nocedal, Stephen Wright. Numerical Optimization. Springer Series + in Operations Research. pp 176-180. 2006 + + http://pages.mtu.edu/~struther/Courses/OLD/Sp2013/5630/Jorge_Nocedal_Numerical_optimization_267490.pdf + + [2] Richard H. Byrd, Peihuang Lu, Jorge Nocedal, & Ciyou Zhu (1995). + A Limited Memory Algorithm for Bound Constrained Optimization + SIAM Journal on Scientific Computing, 16(5), 1190–1208. + + https://doi.org/10.1137/0916069 + + Args: + value_and_gradients_function: A Python callable that accepts a point as a + real `Tensor` and returns a tuple of `Tensor`s of real dtype containing + the value of the function and its gradient at that point. The function + to be minimized. The input is of shape `[..., n]`, where `n` is the size + of the domain of input points, and all others are batching dimensions. + The first component of the return value is a real `Tensor` of matching + shape `[...]`. The second component (the gradient) is also of shape + `[..., n]` like the input value to the function. + initial_position: Real `Tensor` of shape `[..., n]`. The starting point, or + points when using batching dimensions, of the search procedure. At these + points the function value and the gradient norm should be finite. + Exactly one of `initial_position` and `previous_optimizer_results` can be + non-None. + bounds: Tuple of two real `Tensor`s of shape `[..., n]`. The first element + indicates the lower bounds in the constrained optimization, and the second + element of the tuple indicates the upper bounds of the optimization. If + `bounds` is `None`, the optimization is deferred to the unconstrained + version (see also `lbfgs_minimize`). If one of the elements of the tuple + is `None`, the optimization is assumed to be unconstrained (from above/below, + respectively). 
+    previous_optimizer_results: An `LBfgsBOptimizerResults` namedtuple to
+      initialize the optimizer state from, instead of an `initial_position`.
+      This can be passed in from a previous return value to resume optimization
+      with a different `stopping_condition`. Exactly one of `initial_position`
+      and `previous_optimizer_results` can be non-None.
+    num_correction_pairs: Positive integer. Specifies the maximum number of
+      (position_delta, gradient_delta) correction pairs to keep as implicit
+      approximation of the Hessian matrix.
+    tolerance: Scalar `Tensor` of real dtype. Specifies the gradient tolerance
+      for the procedure. If the supremum norm of the gradient vector is below
+      this number, the algorithm is stopped.
+    x_tolerance: Scalar `Tensor` of real dtype. If the absolute change in the
+      position between one iteration and the next is smaller than this number,
+      the algorithm is stopped.
+    f_relative_tolerance: Scalar `Tensor` of real dtype. If the relative change
+      in the objective value between one iteration and the next is smaller
+      than this value, the algorithm is stopped.
+    initial_inverse_hessian_estimate: None. Option currently not supported.
+    max_iterations: Scalar positive int32 `Tensor`. The maximum number of
+      iterations for L-BFGS updates.
+    parallel_iterations: Positive integer. The number of iterations allowed to
+      run in parallel.
+    stopping_condition: (Optional) A Python function that takes as input two
+      Boolean tensors of shape `[...]`, and returns a Boolean scalar tensor.
+      The input tensors are `converged` and `failed`, indicating the current
+      status of each respective batch member; the return value states whether
+      the algorithm should stop. The default is tfp.optimizer.converged_all
+      which only stops when all batch members have either converged or failed.
+      An alternative is tfp.optimizer.converged_any which stops as soon as one
+      batch member has converged, or when all have failed.
+    max_line_search_iterations: Python int. The maximum number of iterations
+      for the `hager_zhang` line search algorithm.
+    name: (Optional) Python str. The name prefixed to the ops created by this
+      function. If not supplied, the default name 'minimize' is used.
+
+  Returns:
+    optimizer_results: A namedtuple containing the following items:
+      converged: Scalar boolean tensor indicating whether the minimum was
+        found within tolerance.
+      failed: Scalar boolean tensor indicating whether a line search
+        step failed to find a suitable step size satisfying Wolfe
+        conditions. In the absence of any constraints on the
+        number of objective evaluations permitted, this value will
+        be the complement of `converged`. However, if there is
+        a constraint and the search stopped due to available
+        evaluations being exhausted, both `failed` and `converged`
+        will be simultaneously False.
+      num_objective_evaluations: The total number of objective
+        evaluations performed.
+      position: A tensor containing the last argument value found
+        during the search. If the search converged, then
+        this value is the argmin of the objective function.
+      objective_value: A tensor containing the value of the objective
+        function at the `position`. If the search converged, then this is
+        the (local) minimum of the objective function.
+      objective_gradient: A tensor containing the gradient of the objective
+        function at the `position`.
If the search converged the + max-norm of this tensor should be below the tolerance. + position_deltas: A tensor encoding information about the latest + changes in `position` during the algorithm execution. + gradient_deltas: A tensor encoding information about the latest + changes in `objective_gradient` during the algorithm execution. + """ + + def _lbfgs_defer(): + return lbfgs_minimize(value_and_gradients_function, + initial_position, + previous_optimizer_results, + num_correction_pairs, + tolerance, + x_tolerance, + f_relative_tolerance, + initial_inverse_hessian_estimate, + max_iterations, + parallel_iterations, + stopping_condition, + max_line_search_iterations, + name) + + if bounds is None: + return _lbfgs_defer() + + if len(bounds) != 2: + raise ValueError( + '`bounds` parameter has unexpected number of elements ' + '(expected 2).') + + lower_bounds, upper_bounds = bounds + + if lower_bounds is None and upper_bounds is None: + return _lbfgs_defer() + # Defer further conversion of the bounds to appropriate tensors + # until the shape of the input is known + + if initial_inverse_hessian_estimate is not None: + raise NotImplementedError( + 'Support of initial_inverse_hessian_estimate arg not yet implemented') + + if stopping_condition is None: + stopping_condition = bfgs_utils.converged_all + + with tf.name_scope(name or 'minimize'): + if (initial_position is None) == (previous_optimizer_results is None): + raise ValueError( + 'Exactly one of `initial_position` or ' + '`previous_optimizer_results` may be specified.') + + if initial_position is not None: + initial_position = tf.convert_to_tensor( + initial_position, name='initial_position') + # Force at least one batching dimension + if len(ps.shape(initial_position)) == 1: + initial_position = initial_position[tf.newaxis, :] + position_shape = ps.shape(initial_position) + dtype = dtype_util.base_dtype(initial_position.dtype) + + if previous_optimizer_results is not None: + position_shape = ps.shape(previous_optimizer_results.position) + dtype = dtype_util.base_dtype(previous_optimizer_results.position.dtype) + + # TODO: This isn't agnostic to the number of batch dimensions, it only + # supports one batch dimension, but I've found RaggedTensors to be far + # too finicky/undocumented to handle multiple batch dimensions in any + # sane way. (Even the way it's working so far is less than ideal.) + if len(position_shape) > 2: + raise NotImplementedError("More than a batch dimension is not implemented. " + "Consider flattening and then reshaping the results.") + # NOTE: Broadcasting the batched dimensions breaks when there are no + # batched dimensions. Although this isn't handled like this in + # `lbfgs.py`, I'd rather force a batch dimension with a single + # element than do conditional checks later. + if len(position_shape) == 1: + position_shape = tf.concat([[1], position_shape], axis=0) + initial_position = tf.broadcast_to(initial_position, position_shape) + + # NOTE: Could maybe use bfgs_utils._broadcast here, but would have to check + # that the non-batching dimensions also match; using `tf.broadcast_to` has + # the advantage that passing a (1,)-shaped tensor as bounds will correctly + # bound every variable at the single value. 
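+    # For example, passing `bounds=(0., None)` bounds every variable from
+    # below by 0 and leaves the optimization unconstrained from above.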
+ if lower_bounds is None: + lower_bounds = tf.constant( + [-float('inf')], shape=position_shape, dtype=dtype, name='lower_bounds') + else: + lower_bounds = tf.cast(tf.convert_to_tensor(lower_bounds), dtype=dtype) + try: + lower_bounds = tf.broadcast_to( + lower_bounds, position_shape, name='lower_bounds') + except tf.errors.InvalidArgumentError: + raise ValueError( + 'Failed to broadcast lower bounds tensor to the shape of starting position. ' + 'Are the lower bounds well formed?') + if upper_bounds is None: + upper_bounds = tf.constant( + [float('inf')], shape=position_shape, dtype=dtype, name='upper_bounds') + else: + upper_bounds = tf.cast(tf.convert_to_tensor(upper_bounds), dtype=dtype) + try: + upper_bounds = tf.broadcast_to( + upper_bounds, position_shape, name='upper_bounds') + except tf.errors.InvalidArgumentError: + raise ValueError( + 'Failed to broadcast upper bounds tensor to the shape of starting position. ' + 'Are the lower bounds well formed?') + + # Clamp the starting position to the bounds, because the algorithm expects the + # variables to be in range for the Hessian inverse estimation, but also because + # that fast-tracks the first iteration of the Cauchy optimization. + initial_position = tf.clip_by_value(initial_position, lower_bounds, upper_bounds) + + tolerance = tf.convert_to_tensor( + tolerance, dtype=dtype, name='grad_tolerance') + f_relative_tolerance = tf.convert_to_tensor( + f_relative_tolerance, dtype=dtype, name='f_relative_tolerance') + x_tolerance = tf.convert_to_tensor( + x_tolerance, dtype=dtype, name='x_tolerance') + max_iterations = tf.convert_to_tensor(max_iterations, name='max_iterations') + + # The `state` here is a `LBfgsBOptimizerResults` tuple with values for the + # current state of the algorithm computation. + def _cond(state): + """Continue if iterations remain and stopping condition is not met.""" + return ((state.num_iterations < max_iterations) & + tf.logical_not(stopping_condition(state.converged, state.failed))) + + def _body(current_state): + """Main optimization loop.""" + current_state = bfgs_utils.terminate_if_not_finite(current_state) + + cauchy_point, free_mask = \ + _cauchy_minimization(current_state, num_correction_pairs, parallel_iterations) + + search_direction = _get_search_direction(current_state) + + # TODO(b/120134934): Check if the derivative at the start point is not + # negative, if so then reset position/gradient deltas and recompute + # search direction. + # NOTE: Erasing is currently handled in `_bounded_line_search_step` + search_direction = tf.where( + free_mask, + search_direction, + 0.) + bad_direction = \ + (tf.reduce_sum(search_direction * current_state.objective_gradient, axis=-1) > 0) + + cauchy_search = _cauchy_line_search_step(current_state, + value_and_gradients_function, search_direction, + tolerance, f_relative_tolerance, x_tolerance, stopping_condition, + max_line_search_iterations, free_mask, cauchy_point) + + search_direction = cauchy_search.position - current_state.position + next_state = _bounded_line_search_step(current_state, + value_and_gradients_function, search_direction, + tolerance, f_relative_tolerance, x_tolerance, stopping_condition, + max_line_search_iterations, bad_direction) + + # If not failed or converged, update the Hessian estimate. 
+ # Only do this if the new pairs obey the s.y > 0 + position_delta = next_state.position - current_state.position + gradient_delta = next_state.objective_gradient - current_state.objective_gradient + positive_prod = (tf.math.reduce_sum(position_delta * gradient_delta, axis=-1) > \ + 1E-8*tf.reduce_sum(gradient_delta**2, axis=-1)) + should_push = ~(next_state.converged | next_state.failed) & positive_prod & ~bad_direction + new_position_deltas = _queue_push( + next_state.position_deltas, should_push, position_delta) + new_gradient_deltas = _queue_push( + next_state.gradient_deltas, should_push, gradient_delta) + new_history = tf.where( + should_push, + tf.math.minimum(next_state.history + 1, num_correction_pairs), + next_state.history) + + if not tf.executing_eagerly(): + # Hint the compiler that the shape of the properties has not changed + new_position_deltas = tf.ensure_shape( + new_position_deltas, next_state.position_deltas.shape) + new_gradient_deltas = tf.ensure_shape( + new_gradient_deltas, next_state.gradient_deltas.shape) + new_history = tf.ensure_shape( + new_history, next_state.history.shape) + + state_after_inv_hessian_update = bfgs_utils.update_fields( + next_state, + position_deltas=new_position_deltas, + gradient_deltas=new_gradient_deltas, + history=new_history) + + return [state_after_inv_hessian_update] + + if previous_optimizer_results is None: + assert initial_position is not None + initial_state = _get_initial_state(value_and_gradients_function, + initial_position, + lower_bounds, + upper_bounds, + num_correction_pairs, + tolerance) + else: + initial_state = previous_optimizer_results + + return tf.while_loop( + cond=_cond, + body=_body, + loop_vars=[initial_state], + parallel_iterations=parallel_iterations)[0] + + +def _cauchy_minimization(bfgs_state, num_correction_pairs, parallel_iterations): + """Calculates the Cauchy point (minimizes the quadratic approximation to the + objective function at the current position, in the direction of steepest + descent), but bounding the gradient by the corresponding bounds. + + See algorithm CP and associated discussion of [Byrd,Lu,Nocedal,Zhu][2] + for details. + + Args: + bfgs_state: A `_ConstrainedCauchyState` initialized to the starting point of the + constrained minimization. + Returns: + A potentially modified `state`, the obtained `cauchy_point` and boolean + `free_mask` indicating which variables are free (`True`) and which variables + are under active constrain (`False`) + """ + cauchy_state = _get_initial_cauchy_state(bfgs_state, num_correction_pairs) + # NOTE: See lbfgsb.f (l. 1649) + ddf_org = -cauchy_state.theta * cauchy_state.df + + def _cond(state): + """Test convergence to Cauchy point at current branch""" + return tf.math.reduce_any(state.active) + + def _body(state): + """Cauchy point iterative loop + + (While loop of CP algorithm [2])""" + # Remove b from the free indices + free_vars_idx, free_mask = _cauchy_remove_breakpoint_min( + state.free_vars_idx, + state.breakpoint_min_idx, + state.free_mask, + state.active) + + # Shape: [b] + d_b = tf.where( + state.active, + tf.gather( + state.steepest, + state.breakpoint_min_idx, + batch_dims=1), + 0.) + # Shape: [b] + x_b = tf.where( + state.active, + tf.gather( + bfgs_state.position, + state.breakpoint_min_idx, + batch_dims=1), + 0.) 
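+    # The variable `b` that reached its breakpoint is moved exactly onto the
+    # bound it hit: the upper bound if its steepest-descent component is
+    # positive, the lower bound if it is negative (the `x^cp_b` update of
+    # algorithm CP in [2]).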
+ + # Shape: [b] + x_cp_b = tf.where( + state.active, + tf.where( + d_b > 0., + tf.gather( + bfgs_state.upper_bounds, + state.breakpoint_min_idx, + batch_dims=1), + tf.where( + d_b < 0., + tf.gather( + bfgs_state.lower_bounds, + state.breakpoint_min_idx, + batch_dims=1), + x_b)), + tf.gather( + state.cauchy_point, + state.breakpoint_min_idx, + batch_dims=1)) + + keep_idx = (tf.range(ps.shape(state.cauchy_point)[-1]) != \ + state.breakpoint_min_idx[..., tf.newaxis]) + cauchy_point = tf.where( + state.active[..., tf.newaxis], + tf.where( + keep_idx, + state.cauchy_point, + x_cp_b[..., tf.newaxis]), + state.cauchy_point) + + z_b = tf.where( + state.active, + x_cp_b - x_b, + 0.) + + c = tf.where( + state.active[..., tf.newaxis], + state.c + state.dt[...,tf.newaxis] * state.p, + state.c) + + # The matrix M has shape + # + # [[ 0 0 ] + # [ 0 M_h ]] + # + # where M_h is the M matrix considering the current history `h`. + # Therefore, for W, we should consider that the last `h` columns + # are + # Y[k-h,...,k-1] theta*S[k-h,...k-1] + # (so that the first `2*(m-h)` columns are 0. + + # 1. Create the "full" W matrix row + # TODO: Transpose seems inevitable, because of batch dims? + w_b = tf.concat( + [ + tf.gather( + tf.transpose( + bfgs_state.gradient_deltas, + perm=[1,0,2]), + state.breakpoint_min_idx, + axis=-1, + batch_dims=1), + state.theta[..., tf.newaxis] * \ + tf.gather( + tf.transpose( + bfgs_state.position_deltas, + perm=[1,0,2]), + state.breakpoint_min_idx, + axis=-1, + batch_dims=1) + ], + axis=-1) + # 2. "Permute" the relevant items to the right + idx = tf.concat( + [ + tf.ragged.range( + num_correction_pairs - bfgs_state.history), + tf.ragged.range( + num_correction_pairs, + 2*num_correction_pairs - bfgs_state.history), + tf.ragged.range( + num_correction_pairs - bfgs_state.history, + num_correction_pairs), + tf.ragged.range( + 2*num_correction_pairs - bfgs_state.history, + 2*num_correction_pairs) + ], + axis=-1).to_tensor() + w_b = tf.gather( + w_b, + idx, + batch_dims=1) + + # NOTE Use of d_b = -g_b + df = tf.where( + state.active, + state.df + state.dt * state.ddf + \ + d_b**2 - \ + state.theta * d_b * z_b + \ + d_b * tf.einsum( + '...j,...jk,...k->...', + w_b, + state.m, + c), + state.df) + + # NOTE use of d_b = -g_b + ddf = tf.where( + state.active, + state.ddf - state.theta * d_b**2 + \ + 2. * d_b * tf.einsum( + "...i,...ij,...j->...", + w_b, + state.m, + state.p) - \ + d_b**2 * tf.einsum( + "...i,...ij,...j->...", + w_b, + state.m, + w_b), + state.ddf) + # NOTE: See lbfgsb.f (l. 1649) + # TODO: How to get machine epsilon? + ddf = tf.math.maximum(ddf, 1E-8*ddf_org) + + # NOTE use of d_b = -g_b + p = tf.where( + state.active[..., tf.newaxis], + state.p - d_b[..., tf.newaxis] * w_b, + state.p) + + steepest_idx = tf.range( + ps.shape(state.steepest)[-1], + dtype=state.breakpoint_min_idx.dtype)[tf.newaxis, ...] 
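+    # Zero the steepest-descent component of the variable that just hit its
+    # bound, so it no longer moves in subsequent segments of the search path.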
+ steepest = tf.where( + state.active[..., tf.newaxis], + tf.where( + steepest_idx == state.breakpoint_min_idx[..., tf.newaxis], + 0., + state.steepest), + state.steepest) + + dt_min = tf.where( + state.active, + -tf.math.divide_no_nan(df, ddf), + state.dt_min) + + breakpoint_min_old = tf.where( + state.active, + state.breakpoint_min, + state.breakpoint_min_old) + + # Find b + breakpoint_min_idx, breakpoint_min = \ + _cauchy_get_breakpoint_min( + state.breakpoints, + free_vars_idx) + breakpoint_min_idx = tf.where( + state.active, + breakpoint_min_idx, + state.breakpoint_min_idx) + breakpoint_min = tf.where( + state.active, + breakpoint_min, + state.breakpoint_min) + + dt = tf.where( + state.active, + breakpoint_min - state.breakpoint_min, + state.dt) + + active = tf.where( + state.active, + _cauchy_update_active(free_vars_idx, dt_min, dt), + state.active) + + # We have to hint the "compiler" that the shapes of the new + # values are the same as the old values. + if not tf.executing_eagerly(): + steepest = tf.ensure_shape(steepest, state.steepest.shape) + free_vars_idx = tf.ensure_shape(free_vars_idx, state.free_vars_idx.shape) + free_mask = tf.ensure_shape(free_mask, state.free_mask.shape) + p = tf.ensure_shape(p, state.p.shape) + c = tf.ensure_shape(c, state.c.shape) + df = tf.ensure_shape(df, state.df.shape) + ddf = tf.ensure_shape(ddf, state.ddf.shape) + dt_min = tf.ensure_shape(dt_min, state.dt_min.shape) + breakpoint_min = tf.ensure_shape(breakpoint_min, state.breakpoint_min.shape) + breakpoint_min_idx = tf.ensure_shape(breakpoint_min_idx, state.breakpoint_min_idx.shape) + dt = tf.ensure_shape(dt, state.dt.shape) + breakpoint_min_old = tf.ensure_shape(breakpoint_min_old, state.breakpoint_min_old.shape) + cauchy_point = tf.ensure_shape(cauchy_point, state.cauchy_point.shape) + active = tf.ensure_shape(active, state.active.shape) + + new_state = bfgs_utils.update_fields( + state, steepest=steepest, free_vars_idx=free_vars_idx, + free_mask=free_mask, p=p, c=c, df=df, ddf=ddf, dt_min=dt_min, + breakpoint_min=breakpoint_min, breakpoint_min_idx=breakpoint_min_idx, + dt=dt, breakpoint_min_old=breakpoint_min_old, + cauchy_point=cauchy_point, active=active) + + return [new_state] + + cauchy_loop = tf.while_loop( + cond=_cond, + body=_body, + loop_vars=[cauchy_state], + parallel_iterations=parallel_iterations)[0] + + # The loop broke, so the last identified `b` index never got + # removed + _free_vars_idx, free_mask = _cauchy_remove_breakpoint_min( + cauchy_loop.free_vars_idx, + cauchy_loop.breakpoint_min_idx, + cauchy_loop.free_mask, + cauchy_loop.active) + + dt_min = tf.math.maximum(cauchy_loop.dt_min, 0) + t_old = cauchy_loop.breakpoint_min_old + dt_min + + # A breakpoint of -1 means that we ran out of free variables + flagged_breakpoint_min = tf.where( + cauchy_loop.breakpoint_min < 0, + float('inf'), + cauchy_loop.breakpoint_min) + cauchy_point = tf.where( + ~(bfgs_state.converged | bfgs_state.failed)[..., tf.newaxis], + tf.where( + cauchy_loop.breakpoints >= flagged_breakpoint_min[..., tf.newaxis], + bfgs_state.position + t_old[..., tf.newaxis] * cauchy_loop.steepest, + cauchy_loop.cauchy_point), + bfgs_state.position) + + # NOTE: We only return the cauchy point and the free mask, so there is no + # need to update the actual state, even though we could at this point update + # `free_vars_idx`, `free_mask`, and `cauchy_point` + free_mask = free_mask & ~(cauchy_loop.breakpoints != cauchy_loop.breakpoint_min) + + return cauchy_point, free_mask + + +def _cauchy_update_active(free_vars_idx, 
dt_min, dt): + return tf.where( + tf.reduce_any(free_vars_idx >= 0, axis=-1) & (dt_min >= dt), + True, + False) + + +def _hz_line_search(state, value_and_gradients_function, + search_direction, max_iterations, inactive): + line_search_value_grad_func = bfgs_utils._restrict_along_direction( + value_and_gradients_function, state.position, search_direction) + derivative_at_start_pt = tf.reduce_sum( + state.objective_gradient * search_direction, axis=-1) + val_0 = bfgs_utils.ValueAndGradient(x=bfgs_utils._broadcast(0, state.position), + f=state.objective_value, + df=derivative_at_start_pt, + full_gradient=state.objective_gradient) + return bfgs_utils.linesearch.hager_zhang( + line_search_value_grad_func, + initial_step_size=bfgs_utils._broadcast(1, state.position), + value_at_zero=val_0, + converged=inactive, + max_iterations=max_iterations) # No search needed for these. + + +def _cauchy_line_search_step(state, value_and_gradients_function, search_direction, + grad_tolerance, f_relative_tolerance, x_tolerance, + stopping_condition, max_iterations, free_mask, cauchy_point): + """Performs the line search in given direction, backtracking in direction to the cauchy point, + and clamping actively contrained variables to the cauchy point.""" + inactive = state.failed | state.converged + ls_result = _hz_line_search(state, value_and_gradients_function, + search_direction, max_iterations, inactive) + + state_after_ls = bfgs_utils.update_fields( + state, + failed=state.failed | (~state.converged & ~ls_result.converged & tf.reduce_any(free_mask, axis=-1)), + num_iterations=state.num_iterations + 1, + num_objective_evaluations=( + state.num_objective_evaluations + ls_result.func_evals + 1)) + + def _do_update_position(): + # For inactive batch members `left.x` is zero. However, their + # `search_direction` might also be undefined, so we can't rely on + # multiplication by zero to produce a `position_delta` of zero. 
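+    # The unconstrained step from the line search gives an "ideal" position
+    # for the free variables; variables under an active bound are instead held
+    # at the Cauchy point.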
+ alpha = ls_result.left.x[..., tf.newaxis] + ideal_position = tf.where( + inactive[..., tf.newaxis], + state.position, + tf.where( + free_mask, + state.position + search_direction * alpha, + cauchy_point)) + + # Backtrack from the ideal position in direction to the Cauchy point + cauchy_to_ideal = ideal_position - cauchy_point + clip_lower = tf.math.divide_no_nan( + state.lower_bounds - cauchy_point, + cauchy_to_ideal) + clip_upper = tf.math.divide_no_nan( + state.upper_bounds - cauchy_point, + cauchy_to_ideal) + clip = tf.math.reduce_min( + tf.where( + cauchy_to_ideal > 0, + clip_upper, + tf.where( + cauchy_to_ideal < 0, + clip_lower, + float('inf'))), + axis=-1) + alpha = tf.minimum(1.0, clip)[..., tf.newaxis] + + next_position = tf.where( + inactive[..., tf.newaxis], + state.position, + tf.where( + free_mask, + cauchy_point + alpha * cauchy_to_ideal, + cauchy_point)) + + # NOTE: one extra call to the function + next_objective, next_gradient = \ + value_and_gradients_function(next_position) + + return _update_position( + state_after_ls, + next_position, + next_objective, + next_gradient, + grad_tolerance, + f_relative_tolerance, + x_tolerance, + tf.constant(False)) + + return ps.cond( + stopping_condition(state.converged, state.failed), + true_fn=lambda: state_after_ls, + false_fn=_do_update_position) + + +def _bounded_line_search_step(state, value_and_gradients_function, search_direction, + grad_tolerance, f_relative_tolerance, x_tolerance, + stopping_condition, max_iterations, bad_direction): + """Performs a line search in given direction, clamping to the bounds, and fixing the actively + constrained values to the given values.""" + inactive = state.failed | state.converged | bad_direction + ls_result = _hz_line_search(state, value_and_gradients_function, + search_direction, max_iterations, inactive) + + new_failed = state.failed | (~state.converged & ~ls_result.converged \ + & tf.reduce_any(search_direction != 0, axis=-1)) \ + & ~bad_direction + new_num_iterations = state.num_iterations + 1 + new_num_objective_evaluations = ( + state.num_objective_evaluations + ls_result.func_evals + 1) + + if not tf.executing_eagerly(): + # Hint the compiler that the properties' shape will not change + new_failed = tf.ensure_shape( + new_failed, state.failed.shape) + new_num_iterations = tf.ensure_shape( + new_num_iterations, state.num_iterations.shape) + new_num_objective_evaluations = tf.ensure_shape( + new_num_objective_evaluations, state.num_objective_evaluations.shape) + + state_after_ls = bfgs_utils.update_fields( + state, + failed=new_failed, + num_iterations=new_num_iterations, + num_objective_evaluations=new_num_objective_evaluations) + + def _do_update_position(): + lower_term = tf.math.divide_no_nan( + state.lower_bounds - state.position, + search_direction) + upper_term = tf.math.divide_no_nan( + state.upper_bounds - state.position, + search_direction) + + under_clip = tf.math.reduce_max( + tf.where( + (search_direction > 0), + lower_term, + tf.where( + (search_direction < 0), + upper_term, + -float('inf'))), + axis=-1) + over_clip = tf.math.reduce_min( + tf.where( + (search_direction > 0), + upper_term, + tf.where( + (search_direction < 0), + lower_term, + float('inf'))), + axis=-1) + + alpha_clip = tf.clip_by_value( + ls_result.left.x, + under_clip, + over_clip)[..., tf.newaxis] + + # For inactive batch members `left.x` is zero. However, their + # `search_direction` might also be undefined, so we can't rely on + # multiplication by zero to produce a `position_delta` of zero. 
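+    # The update below therefore keeps the old position for inactive members
+    # and otherwise advances along `search_direction` by the box-clipped step
+    # size `alpha_clip`.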
+ next_position = tf.where( + inactive[..., tf.newaxis], + state.position, + state.position + search_direction * alpha_clip) + + # one extra call to the function, counted above + next_objective, next_gradient = \ + value_and_gradients_function(next_position) + + return _update_position( + state_after_ls, + next_position, + next_objective, + next_gradient, + grad_tolerance, + f_relative_tolerance, + x_tolerance, + bad_direction) + + return ps.cond( + stopping_condition(state.converged, state.failed), + true_fn=lambda: state_after_ls, + false_fn=_do_update_position) + + +def _update_position(state, + next_position, + next_objective, + next_gradient, + grad_tolerance, + f_relative_tolerance, + x_tolerance, + erase_memory): + """Updates the state advancing its position by a given position_delta. + Also erases the LBFGS memory if indicated.""" + state = bfgs_utils.terminate_if_not_finite(state, next_objective, next_gradient) + + converged = ~state.failed & \ + _check_convergence_bounded(state.position, + next_position, + state.objective_value, + next_objective, + next_gradient, + grad_tolerance, + f_relative_tolerance, + x_tolerance, + state.lower_bounds, + state.upper_bounds) + new_position_deltas = tf.where( + erase_memory[..., tf.newaxis], + tf.zeros_like(state.position_deltas), + state.position_deltas) + new_gradient_deltas = tf.where( + erase_memory[..., tf.newaxis], + tf.zeros_like(state.gradient_deltas), + state.gradient_deltas) + new_history = tf.where( + erase_memory, + tf.zeros_like(state.history), + state.history) + new_converged = (state.converged | converged) + + if not tf.executing_eagerly(): + # Hint the compiler that the properties have not changed shape + new_converged = tf.ensure_shape(new_converged, state.converged.shape) + next_position = tf.ensure_shape(next_position, state.position.shape) + next_objective = tf.ensure_shape(next_objective, state.objective_value.shape) + next_gradient = tf.ensure_shape(next_gradient, state.objective_gradient.shape) + new_position_deltas = tf.ensure_shape(new_position_deltas, state.position_deltas.shape) + new_gradient_deltas = tf.ensure_shape(new_gradient_deltas, state.gradient_deltas.shape) + new_history = tf.ensure_shape(new_history, state.history.shape) + + return bfgs_utils.update_fields( + state, + converged=new_converged, + position=next_position, + objective_value=next_objective, + objective_gradient=next_gradient, + position_deltas=new_position_deltas, + gradient_deltas=new_gradient_deltas, + history=new_history) + + +def _check_convergence_bounded(current_position, + next_position, + current_objective, + next_objective, + next_gradient, + grad_tolerance, + f_relative_tolerance, + x_tolerance, + lower_bounds, + upper_bounds): + """Checks if the algorithm satisfies the convergence criteria.""" + proj_grad_converged = bfgs_utils.norm( + tf.clip_by_value( + next_position - next_gradient, + lower_bounds, + upper_bounds) - next_position, dims=1) <= grad_tolerance + x_converged = bfgs_utils.norm(next_position - current_position, dims=1) <= x_tolerance + f_converged = bfgs_utils.norm(next_objective - current_objective, dims=0) <= \ + f_relative_tolerance * current_objective + return proj_grad_converged | x_converged | f_converged + + +def _get_initial_state(value_and_gradients_function, + initial_position, + lower_bounds, + upper_bounds, + num_correction_pairs, + tolerance): + """Create LBfgsBOptimizerResults with initial state of search procedure.""" + init_args = bfgs_utils.get_initial_state_args( + value_and_gradients_function, + 
initial_position, + tolerance) + init_args.update(lower_bounds=lower_bounds, upper_bounds=upper_bounds) + empty_queue = _make_empty_queue_for(num_correction_pairs, initial_position) + init_args.update( + position_deltas=empty_queue, + gradient_deltas=empty_queue, + history=tf.zeros(ps.shape(initial_position)[:-1], dtype=tf.int32)) + return LBfgsBOptimizerResults(**init_args) + + +def _get_initial_cauchy_state(state, num_correction_pairs): + """Create _ConstrainedCauchyState with initial parameters""" + + theta = tf.math.divide_no_nan( + tf.reduce_sum(state.gradient_deltas[-1, ...]**2, axis=-1), + tf.reduce_sum(state.gradient_deltas[-1,...] * state.position_deltas[-1, ...], axis=-1)) + theta = tf.where( + theta != 0, + theta, + 1.0) + + m, refresh = _cauchy_init_m( + state, + ps.shape(state.position_deltas), + theta, + num_correction_pairs) + # Erase the history where M isn't invertible + state = \ + bfgs_utils.update_fields( + state, + gradient_deltas=tf.where( + refresh[..., tf.newaxis], + tf.zeros_like(state.gradient_deltas), + state.gradient_deltas), + position_deltas=tf.where( + refresh[..., tf.newaxis], + tf.zeros_like(state.position_deltas), + state.position_deltas), + history=tf.where(refresh, 0, state.history)) + theta = tf.where(refresh, 1.0, theta) + + breakpoints = _cauchy_init_breakpoints(state) + + steepest = tf.where( + breakpoints != 0., + -state.objective_gradient, + 0.) + + free_mask = (breakpoints > 0) + free_vars_idx = tf.where( + free_mask, + tf.broadcast_to( + tf.range(ps.shape(state.position)[-1], dtype=tf.int32), + ps.shape(state.position)), + -1) + + # We need to account for the varying histories: + # we assume that the first `2*(m-h)` rows of W'^T + # are 0 (where `m` is the number of correction pairs + # and `h` is the history), in concordance with the first + # `2*(m-h)` rows of M being 0. + # 1. Calculate all elements + p = tf.concat( + [ + tf.einsum( + "m...i,...i->...m", + state.gradient_deltas, + steepest), + theta[..., tf.newaxis] * \ + tf.einsum( + "m...i,...i->...m", + state.position_deltas, + steepest) + ], + axis=-1) + # 2. Assemble the rows in the correct order + idx = tf.concat( + [ + tf.ragged.range( + num_correction_pairs - state.history), + tf.ragged.range( + num_correction_pairs, + 2*num_correction_pairs - state.history), + tf.ragged.range( + num_correction_pairs - state.history, + num_correction_pairs), + tf.ragged.range( + 2*num_correction_pairs - state.history, + 2*num_correction_pairs) + ], + axis=-1).to_tensor() + p = tf.gather( + p, + idx, + batch_dims=1) + + c = tf.zeros_like(p) + + df = -tf.reduce_sum(steepest**2, axis=-1) + ddf = -theta*df - tf.einsum("...i,...ij,...j->...", p, m, p) + dt_min = -tf.math.divide_no_nan(df, ddf) + + breakpoint_min_idx, breakpoint_min = \ + _cauchy_get_breakpoint_min(breakpoints, free_vars_idx) + + dt = breakpoint_min + + breakpoint_min_old = tf.zeros_like(breakpoint_min) + + cauchy_point = state.position + + active = ~(state.converged | state.failed) & \ + _cauchy_update_active(free_vars_idx, dt_min, dt) + + return _ConstrainedCauchyState( + theta, m, breakpoints, steepest, free_vars_idx, free_mask, + p, c, df, ddf, dt_min, breakpoint_min, breakpoint_min_idx, + dt, breakpoint_min_old, cauchy_point, active) + + +def _cauchy_init_m(state, deltas_shape, theta, num_correction_pairs): + def build_m(): + # All of the below block matrices have dimensions [..., m, m] + # where `...` denotes the batch dimensions, and `m` the number + # of correction pairs (compare to `deltas_shape`, which is [m,...,n]). 
+ # New elements are pushed in "from the back", so we want to index + # position_deltas and gradient_deltas with negative indices. + # Index 0 of `position_deltas` and `gradient_deltas` is oldest, and index -1 + # is most recent, so the below respects the indexing of the article. + + # 1. calculate inner product (s_i.y_j) in shape [..., m, m] + l = tf.einsum( + "m...i,u...i->...mu", + state.position_deltas, + state.gradient_deltas) + # 2. Zero out diagonal and upper triangular + l_shape = ps.shape(l) + l = tf.linalg.set_diag( + tf.linalg.band_part(l, -1, 0), + tf.zeros([l_shape[0], l_shape[-1]])) + l_transpose = tf.linalg.matrix_transpose(l) + s_t_s = tf.einsum( + 'm...i,n...i->...mn', + state.position_deltas, + state.position_deltas) + d = tf.linalg.diag( + tf.einsum( + 'm...i,m...i->...m', + state.position_deltas, + state.gradient_deltas)) + + # Assemble into full matrix + # TODO: Is there no better way to create a block matrix? + block_d = tf.concat([-d, tf.zeros_like(d)], axis=-1) + block_d = tf.concat([block_d, tf.zeros_like(block_d)], axis=-2) + block_l_transpose = tf.concat([tf.zeros_like(l_transpose), l_transpose], axis=-1) + block_l_transpose = tf.concat([block_l_transpose, tf.zeros_like(block_l_transpose)], axis=-2) + block_l = tf.concat([l, tf.zeros_like(l)], axis=-1) + block_l = tf.concat([tf.zeros_like(block_l), block_l], axis=-2) + block_s_t_s = tf.concat([tf.zeros_like(s_t_s), s_t_s], axis=-1) + block_s_t_s = tf.concat([tf.zeros_like(block_s_t_s), block_s_t_s], axis=-2) + + # shape [b, 2m, 2m] + m_inv = block_d + block_l_transpose + block_l + \ + theta[..., tf.newaxis, tf.newaxis] * block_s_t_s + + # Adjust for varying history: + # Push columns indexed h,...,2m-h to the left (but to the right of 0...m-h) + # and same index rows to the bottom + idx = tf.concat( + [tf.ragged.range(num_correction_pairs-state.history), + tf.ragged.range(num_correction_pairs, 2*num_correction_pairs-state.history), + tf.ragged.range(num_correction_pairs-state.history, num_correction_pairs), + tf.ragged.range(2*num_correction_pairs-state.history, 2*num_correction_pairs)], + axis=-1).to_tensor() + m_inv = tf.gather( + m_inv, + idx, + axis=-1, + batch_dims=1) + m_inv = tf.gather( + m_inv, + idx, + axis=-2, + batch_dims=1) + + # Insert an identity in the empty block + identity_mask = \ + (tf.range(ps.shape(m_inv)[-1])[tf.newaxis, ...] 
< \ + 2*(num_correction_pairs - state.history[..., tf.newaxis]))[..., tf.newaxis] + + m_inv = tf.where( + identity_mask, + tf.eye(deltas_shape[0]*2, batch_shape=[deltas_shape[1]]), + m_inv) + + # If M is not invertible, refresh the memory + refresh = (tf.linalg.det(m_inv) == 0) + + # Invert where invertible; 0s otherwise + m = tf.where( + refresh[..., tf.newaxis, tf.newaxis], + tf.zeros_like(m_inv), + tf.linalg.inv( + tf.where( + refresh[..., tf.newaxis, tf.newaxis], + tf.eye(deltas_shape[0]*2, batch_shape=[deltas_shape[1]]), + m_inv))) + + # Re-zero the introduced identity blocks + m = tf.where( + identity_mask, + tf.zeros_like(m), + m) + + return m, refresh + + # M is 0 for the first iterations + return tf.cond( + state.num_iterations < 1, + lambda: (tf.zeros([deltas_shape[1], 2*deltas_shape[0], 2*deltas_shape[0]]), + tf.broadcast_to(False, ps.shape(state.history))), + build_m) + + +def _cauchy_init_breakpoints(state): + breakpoints = \ + tf.where( + state.objective_gradient < 0, + tf.math.divide_no_nan( + state.position - state.upper_bounds, + state.objective_gradient), + tf.where( + state.objective_gradient > 0, + tf.math.divide_no_nan( + state.position - state.lower_bounds, + state.objective_gradient), + float('inf'))) + + return breakpoints + + +def _cauchy_remove_breakpoint_min(free_vars_idx, + breakpoint_min_idx, + free_mask, + active): + """Update the free variable indices to remove the minimum breakpoint index. + + Returns: + Updated `free_vars_idx`, `free_mask` + """ + + # NOTE: In situations where none of the indices are free, breakpoint_min_idx + # will falsely report 0. However, this is fine, because in this situation, + # every element of free_vars_idx is -1, and so there is no match. + matching = (free_vars_idx == breakpoint_min_idx[..., tf.newaxis]) + free_vars_idx = tf.where( + matching, + -1, + free_vars_idx) + free_mask = tf.where( + active[..., tf.newaxis], + free_vars_idx >= 0, + free_mask) + + return free_vars_idx, free_mask + + +def _cauchy_get_breakpoint_min(breakpoints, free_vars_idx): + """Find the smallest breakpoint of free indices, returning the minimum breakpoint + and the corresponding index. + + Returns: + Tuple of `breakpoint_min_idx`, `breakpoint_min` + where + `breakpoint_min_idx` is the index that has min. breakpoint + `breakpoint_min` is the corresponding breakpoint + """ + # A tensor of shape [batch, dims] that has +infinity where free_vars_idx < 0, + # and has breakpoints[free_vars_idx] otherwise. + flagged_breakpoints = tf.where( + free_vars_idx < 0, + float('inf'), + tf.gather( + breakpoints, + tf.where( + free_vars_idx < 0, + 0, + free_vars_idx), + batch_dims=1)) + + argmin_idx = tf.math.argmin( + flagged_breakpoints, + axis=-1, + output_type=tf.int32) + + # NOTE: For situations where there are no more free indices + # (and therefore argmin_idx indexes into -1), we set + # breakpoint_min_idx to 0 and flag that there are no free + # indices by setting the breakpoint to -1 (this is an impossible + # value, as breakpoints are g.e. to 0). + # This is because in branching situations, indexing with + # breakpoint_min_idx can occur, and later be discarded, but all + # elements in breakpoint_min_idx must be a priori valid indices. 
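+  # For example, with `free_vars_idx = [[-1, -1]]` (no free variables left)
+  # and `breakpoints = [[2., 5.]]`, this yields `breakpoint_min_idx = [0]`
+  # and `breakpoint_min = [-1.]`.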
+ no_free = tf.gather( + free_vars_idx, + argmin_idx, + batch_dims=1) < 0 + breakpoint_min_idx = tf.where( + no_free, + 0, + tf.gather( + free_vars_idx, + argmin_idx, + batch_dims=1)) + breakpoint_min = tf.where( + no_free, + -1., + tf.gather( + breakpoints, + argmin_idx, + batch_dims=1)) + + return breakpoint_min_idx, breakpoint_min + + +def _get_search_direction(state): + """Computes the search direction to follow at the current state. + + On the `k`-th iteration of the main L-BFGS algorithm, the state has collected + the most recent `m` correction pairs in position_deltas and gradient_deltas, + where `k = state.num_iterations` and `m = min(k, num_correction_pairs)`. + + Assuming these, the code below is an implementation of the L-BFGS two-loop + recursion algorithm given by [Nocedal and Wright(2006)][1]: + + ```None + q_direction = objective_gradient + for i in reversed(range(m)): # First loop. + inv_rho[i] = gradient_deltas[i]^T * position_deltas[i] + alpha[i] = position_deltas[i]^T * q_direction / inv_rho[i] + q_direction = q_direction - alpha[i] * gradient_deltas[i] + + kth_inv_hessian_factor = (gradient_deltas[-1]^T * position_deltas[-1] / + gradient_deltas[-1]^T * gradient_deltas[-1]) + r_direction = kth_inv_hessian_factor * I * q_direction + + for i in range(m): # Second loop. + beta = gradient_deltas[i]^T * r_direction / inv_rho[i] + r_direction = r_direction + position_deltas[i] * (alpha[i] - beta) + + return -r_direction # Approximates - H_k * objective_gradient. + ``` + + Args: + state: A `LBfgsBOptimizerResults` tuple with the current state of the + search procedure. + + Returns: + A real `Tensor` of the same shape as the `state.position`. The direction + along which to perform line search. + """ + # The number of correction pairs that have been collected so far. + #num_elements = ps.minimum( + # state.num_iterations, # TODO(b/162733947): Change loop state -> closure. + # ps.shape(state.position_deltas)[0]) + + def _two_loop_algorithm(): + """L-BFGS two-loop algorithm.""" + # Correction pairs are always appended to the end, so only the latest + # `num_elements` vectors have valid position/gradient deltas. Vectors + # that haven't been computed yet are zero. + position_deltas = state.position_deltas + gradient_deltas = state.gradient_deltas + num_correction_pairs, num_batches, _point_dims = \ + ps.shape(gradient_deltas, out_type=tf.int32) + + # Pre-compute all `inv_rho[i]`s. + inv_rhos = tf.reduce_sum( + gradient_deltas * position_deltas, axis=-1) + + def first_loop(acc, args): + _, q_direction, num_iter = acc + position_delta, gradient_delta, inv_rho = args + active = (num_iter < state.history) + alpha = tf.math.divide_no_nan( + tf.reduce_sum( + position_delta * q_direction, + axis=-1), + inv_rho) + direction_delta = alpha[..., tf.newaxis] * gradient_delta + new_q_direction = tf.where( + active[..., tf.newaxis], + q_direction - direction_delta, + q_direction) + + return (alpha, new_q_direction, num_iter + 1) + + # Run first loop body computing and collecting `alpha[i]`s, while also + # computing the updated `q_direction` at each step. + zero = tf.zeros_like(inv_rhos[0]) + alphas, q_directions, _num_iters = tf.scan( + first_loop, [position_deltas, gradient_deltas, inv_rhos], + initializer=(zero, state.objective_gradient, 0), reverse=True) + + # We use `H^0_k = gamma_k * I` as an estimate for the initial inverse + # hessian for the k-th iteration; then `r_direction = H^0_k * q_direction`. 
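+    # `gamma_k` is the usual scaling `(s^T y) / (y^T y)` computed from a
+    # stored correction pair (see [1]); batch members with no history yet
+    # fall back to `gamma_k = 1`, i.e. `H^0_k = I`.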
+ idx = tf.transpose( + tf.stack( + [tf.where( + state.history > 0, + num_correction_pairs - state.history, + 0), + tf.range(num_batches)])) + gamma_k = tf.math.divide_no_nan( + tf.gather_nd(inv_rhos, idx), + tf.reduce_sum( + tf.gather_nd(gradient_deltas, idx)**2, + axis=-1)) + gamma_k = tf.where( + (state.history > 0), + gamma_k, + 1.0) + r_direction = gamma_k[..., tf.newaxis] * tf.gather_nd(q_directions, idx) + + def second_loop(acc, args): + r_direction, iter_idx = acc + alpha, position_delta, gradient_delta, inv_rho = args + active = (iter_idx >= num_correction_pairs - state.history) + beta = tf.math.divide_no_nan( + tf.reduce_sum( + gradient_delta * r_direction, + axis=-1), + inv_rho) + direction_delta = (alpha - beta)[..., tf.newaxis] * position_delta + new_r_direction = tf.where( + active[..., tf.newaxis], + r_direction + direction_delta, + r_direction) + return (new_r_direction, iter_idx + 1) + + # Finally, run second loop body computing the updated `r_direction` at each + # step. + r_directions, _num_iters = tf.scan( + second_loop, [alphas, position_deltas, gradient_deltas, inv_rhos], + initializer=(r_direction, 0)) + + return -r_directions[-1] + + return ps.cond(tf.reduce_any(state.history != 0), + _two_loop_algorithm, + lambda: -state.objective_gradient) + + +def _get_ragged_sizes(tensor, dtype=tf.int32): + """Creates a tensor indicating the size of each component of + a ragged dimension. + + For example: + + ```python + element = tf.ragged.constant([[1,2], [3,4,5], [], [0]]) + _get_ragged_sizes(element) + # => + ``` + """ + return tf.reduce_sum( + tf.ones_like( + tensor, + dtype=dtype), + axis=-1)[..., tf.newaxis] + + +def _get_range_like_ragged(tensor, dtype=tf.int32): + """Creates a batched range for the elements of the batched tensor. + + For example: + + ```python + element = tf.ragged.constant([[1,2], [3,4,5], [], [0]]) + _get_range_like_ragged(element) + # => + + Args: + tensor: a RaggedTensor of shape `[n, None]`. + + Returns: + A ragged tensor of shape `[n, None]` where the ragged dimensions + match the ragged dimensions of `tensor`, and are a range from `0` to + the size of the ragged dimension. + ``` + """ + sizes = _get_ragged_sizes(tensor) + flat_ranges = tf.ragged.range( + tf.reshape( + sizes, + [tf.reduce_prod(sizes.shape)]), + dtype=dtype) + return tf.RaggedTensor.from_row_lengths(flat_ranges, sizes.shape[:-1])[0] + + +def _make_empty_queue_for(k, element): + """Creates a `tf.Tensor` suitable to hold `k` element-shaped tensors. + + For example: + + ```python + element = tf.constant([[0., 1., 2., 3., 4.], + [5., 6., 7., 8., 9.]]) + + # A queue capable of holding 3 elements. + _make_empty_queue_for(3, element) + # => [[[ 0., 0., 0., 0., 0.], + # [ 0., 0., 0., 0., 0.]], + # + # [[ 0., 0., 0., 0., 0.], + # [ 0., 0., 0., 0., 0.]], + # + # [[ 0., 0., 0., 0., 0.], + # [ 0., 0., 0., 0., 0.]]] + ``` + + Args: + k: A positive scalar integer, number of elements that each queue will hold. + element: A `tf.Tensor`, only its shape and dtype information are relevant. + + Returns: + A zero-filed `tf.Tensor` of shape `(k,) + tf.shape(element)` and same dtype + as `element`. + """ + queue_shape = ps.concat([[k], ps.shape(element)], axis=0) + return tf.zeros(queue_shape, dtype=dtype_util.base_dtype(element.dtype)) + + +def _queue_push(queue, should_update, new_vecs): + """Conditionally push new vectors into a batch of first-in-first-out queues. 
+ + The `queue` of shape `[k, ..., n]` can be thought of as a batch of queues, + each holding `k` n-D vectors; while `new_vecs` of shape `[..., n]` is a + fresh new batch of n-D vectors. The `should_update` batch of Boolean scalars, + i.e. shape `[...]`, indicates batch members whose corresponding n-D vector in + `new_vecs` should be added at the back of its queue, pushing out the + corresponding n-D vector from the front. Batch members in `new_vecs` for + which `should_update` is False are ignored. + + Note: the choice of placing `k` at the dimension 0 of the queue is + constrained by the L-BFGS two-loop algorithm above. The algorithm uses + tf.scan to iterate over the `k` correction pairs simulatneously across all + batches, and tf.scan itself can only iterate over dimension 0. + + For example: + + ```python + k, b, n = (3, 2, 5) + queue = tf.reshape(tf.range(30), (k, b, n)) + # => [[[ 0, 1, 2, 3, 4], + # [ 5, 6, 7, 8, 9]], + # + # [[10, 11, 12, 13, 14], + # [15, 16, 17, 18, 19]], + # + # [[20, 21, 22, 23, 24], + # [25, 26, 27, 28, 29]]] + + element = tf.reshape(tf.range(30, 40), (b, n)) + # => [[30, 31, 32, 33, 34], + [35, 36, 37, 38, 39]] + + should_update = tf.constant([True, False]) # Shape: (b,) + + _queue_add(should_update, queue, element) + # => [[[10, 11, 12, 13, 14], + # [ 5, 6, 7, 8, 9]], + # + # [[20, 21, 22, 23, 24], + # [15, 16, 17, 18, 19]], + # + # [[30, 31, 32, 33, 34], + # [25, 26, 27, 28, 29]]] + ``` + + Args: + queue: A `tf.Tensor` of shape `[k, ..., n]`; a batch of queues each with + `k` n-D vectors. + should_update: A Boolean `tf.Tensor` of shape `[...]` indicating batch + members where new vectors should be added to their queues. + new_vecs: A `tf.Tensor` of shape `[..., n]`; a batch of n-D vectors to add + at the end of their respective queues, pushing out the first element from + each. + + Returns: + A new `tf.Tensor` of shape `[k, ..., n]`. + """ + new_queue = tf.concat([queue[1:], [new_vecs]], axis=0) + return tf.where( + should_update[tf.newaxis, ..., tf.newaxis], new_queue, queue) From ea9557c4fcc21a23eb6980084a6f41bd1dfb9d1e Mon Sep 17 00:00:00 2001 From: mikeevmm Date: Thu, 29 Apr 2021 12:10:43 +0100 Subject: [PATCH 2/4] feat: correct subspace minimization (less fn. evals.) --- .../python/optimizer/lbfgsb.py | 1197 ++++++++++------- 1 file changed, 679 insertions(+), 518 deletions(-) diff --git a/tensorflow_probability/python/optimizer/lbfgsb.py b/tensorflow_probability/python/optimizer/lbfgsb.py index 4e7c9e6d8e..1ec36a61f1 100644 --- a/tensorflow_probability/python/optimizer/lbfgsb.py +++ b/tensorflow_probability/python/optimizer/lbfgsb.py @@ -28,10 +28,6 @@ from __future__ import print_function import collections -from numpy.core.fromnumeric import argmin, clip - -from tensorflow.python.ops.gen_array_ops import gather, lower_bound, where -from tensorflow_probability.python.internal.backend.numpy import dtype, numpy_math # Dependency imports import tensorflow.compat.v2 as tf @@ -87,7 +83,7 @@ ]) _ConstrainedCauchyState = collections.namedtuple( - '_ConstrainedCauchyResult', [ + '_CauchyMinimizationResult', [ 'theta', # `\theta` in [2]; n the Cauchy search, relates to the implicit Hessian # `B = \theta*I - WMW'` (`I` the identity, see [1,2] for details) 'm', # `M_k` matrix in [2]; part of the implicit representation of the Hessian, @@ -103,7 +99,7 @@ # while loop. 
'free_mask', # Boolean mask of free variables 'p', # as in [2] - 'c', # as in [2] + 'c', # as in [2]; eventually made to equal `W'(cauchy_point - position)` 'df', # `f'` in [2]; (corrected) gradient 2-norm 'ddf', # `f''` in [2]; (corrected) laplacian 2-norm (?) 'dt_min', # `\Delta t_min` in [2]; the minimizing parameter @@ -142,7 +138,29 @@ def minimize(value_and_gradients_function, constrained minimum for a simple high-dimensional quadratic objective function. ```python - TODO + ndims = 60 + minimum = tf.convert_to_tensor( + np.ones([ndims]), dtype=tf.float32) + lower_bounds = tf.convert_to_tensor( + np.arange(ndims), dtype=tf.float32) + upper_bounds = tf.convert_to_tensor( + np.arange(100, 100-ndims, -1), dtype=tf.float32) + scales = tf.convert_to_tensor( + (np.random.rand(ndims) + 1.)*5. + 1., dtype=tf.float32) + start = tf.constant(np.random.rand(2, ndims)*100, dtype=tf.float32) + + # The objective function and the gradient. + def quadratic_loss_and_gradient(x): + return tfp.math.value_and_gradient( + lambda x: tf.reduce_sum( + scales * tf.math.squared_difference(x, minimum), axis=-1), + x) + opt_results = tfp.optimizer.lbfgsb_minimize( + quadratic_loss_and_gradient, + initial_position=start, + num_correction_pairs=10, + tolerance=1e-10, + bounds=[lower_bounds, upper_bounds]) ``` ### References: @@ -158,6 +176,13 @@ def minimize(value_and_gradients_function, https://doi.org/10.1137/0916069 + [3] Jose Luis Morales, Jorge Nocedal (2011). + "Remark On Algorithm 788: L-BFGS-B: Fortran Subroutines for Large-Scale + Bound Constrained Optimization" + ACM Trans. Math. Softw. 38, 1, Article 7. + + https://dl.acm.org/doi/abs/10.1145/2049662.2049669 + Args: value_and_gradients_function: A Python callable that accepts a point as a real `Tensor` and returns a tuple of `Tensor`s of real dtype containing @@ -369,40 +394,28 @@ def _body(current_state): """Main optimization loop.""" current_state = bfgs_utils.terminate_if_not_finite(current_state) - cauchy_point, free_mask = \ - _cauchy_minimization(current_state, num_correction_pairs, parallel_iterations) - - search_direction = _get_search_direction(current_state) - - # TODO(b/120134934): Check if the derivative at the start point is not - # negative, if so then reset position/gradient deltas and recompute - # search direction. - # NOTE: Erasing is currently handled in `_bounded_line_search_step` - search_direction = tf.where( - free_mask, - search_direction, - 0.) - bad_direction = \ - (tf.reduce_sum(search_direction * current_state.objective_gradient, axis=-1) > 0) - - cauchy_search = _cauchy_line_search_step(current_state, - value_and_gradients_function, search_direction, - tolerance, f_relative_tolerance, x_tolerance, stopping_condition, - max_line_search_iterations, free_mask, cauchy_point) + cauchy_state, current_state = ( + _cauchy_minimization(current_state, num_correction_pairs, parallel_iterations)) + + search_direction, current_state, refresh = ( + _find_search_direction(current_state, cauchy_state, num_correction_pairs)) - search_direction = cauchy_search.position - current_state.position - next_state = _bounded_line_search_step(current_state, - value_and_gradients_function, search_direction, + next_state = _constrained_line_search_step( + current_state, value_and_gradients_function, search_direction, tolerance, f_relative_tolerance, x_tolerance, stopping_condition, - max_line_search_iterations, bad_direction) + max_line_search_iterations, refresh) # If not failed or converged, update the Hessian estimate. 
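The Hessian update applied just below only stores a correction pair when it passes the curvature guard `s.y > eps * ||g||^2` (cf. the lbfgs.f reference cited there). A hedged NumPy restatement of that single-pair test, with illustrative names:

```python
import numpy as np

def should_store_pair(position_delta, gradient_delta, gradient,
                      eps=np.finfo(np.float32).eps):
  """True if the new (s, y) pair is safe to add to the L-BFGS-B memory."""
  return (np.dot(position_delta, gradient_delta) >
          eps * np.dot(gradient, gradient))
```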
- # Only do this if the new pairs obey the s.y > 0 - position_delta = next_state.position - current_state.position - gradient_delta = next_state.objective_gradient - current_state.objective_gradient - positive_prod = (tf.math.reduce_sum(position_delta * gradient_delta, axis=-1) > \ - 1E-8*tf.reduce_sum(gradient_delta**2, axis=-1)) - should_push = ~(next_state.converged | next_state.failed) & positive_prod & ~bad_direction + # Only do this if the new pairs obey the s.y > eps.||g|| + position_delta = (next_state.position - current_state.position) + gradient_delta = (next_state.objective_gradient - current_state.objective_gradient) + # Article is ambiguous; see lbfgs.f:863 + positive_prod = ( + tf.reduce_sum(position_delta * gradient_delta, axis=-1) > + dtype_util.eps(current_state.position.dtype) * + tf.reduce_sum(current_state.objective_gradient**2, axis=-1) + ) + should_push = ~(next_state.converged | next_state.failed) & positive_prod & ~refresh new_position_deltas = _queue_push( next_state.position_deltas, should_push, position_delta) new_gradient_deltas = _queue_push( @@ -421,13 +434,13 @@ def _body(current_state): new_history = tf.ensure_shape( new_history, next_state.history.shape) - state_after_inv_hessian_update = bfgs_utils.update_fields( + next_state = bfgs_utils.update_fields( next_state, position_deltas=new_position_deltas, gradient_deltas=new_gradient_deltas, history=new_history) - return [state_after_inv_hessian_update] + return [next_state] if previous_optimizer_results is None: assert initial_position is not None @@ -455,26 +468,28 @@ def _cauchy_minimization(bfgs_state, num_correction_pairs, parallel_iterations): See algorithm CP and associated discussion of [Byrd,Lu,Nocedal,Zhu][2] for details. + This function may modify the given `bfgs_state`, in that it refreshes the memory + for batches that are found to be in an invalid state. + Args: - bfgs_state: A `_ConstrainedCauchyState` initialized to the starting point of the - constrained minimization. + bfgs_state: current `LBfgsBOptimizerResults` state + num_correction_pairs: typically `m`; the (maximum) number of past steps to keep as + history for the LBFGS algorithm + parallel_iterations: argument of `tf.while` loops Returns: - A potentially modified `state`, the obtained `cauchy_point` and boolean - `free_mask` indicating which variables are free (`True`) and which variables - are under active constrain (`False`) + A `_CauchyMinimizationResult` containing the results of the Cauchy point computation. + Updated `bfgs_state` """ - cauchy_state = _get_initial_cauchy_state(bfgs_state, num_correction_pairs) - # NOTE: See lbfgsb.f (l. 1649) + cauchy_state, bfgs_state = _get_initial_cauchy_state(bfgs_state, num_correction_pairs) + # NOTE: See lbfgsb.f (l. 1524) ddf_org = -cauchy_state.theta * cauchy_state.df def _cond(state): """Test convergence to Cauchy point at current branch""" - return tf.math.reduce_any(state.active) + return tf.reduce_any(state.active) def _body(state): - """Cauchy point iterative loop - - (While loop of CP algorithm [2])""" + """Cauchy point iterative loop (While loop of CP algorithm [2])""" # Remove b from the free indices free_vars_idx, free_mask = _cauchy_remove_breakpoint_min( state.free_vars_idx, @@ -520,7 +535,8 @@ def _body(state): state.breakpoint_min_idx, batch_dims=1)) - keep_idx = (tf.range(ps.shape(state.cauchy_point)[-1]) != \ + # Set the `b`th component of the `cauchy_point` to `x_cp_b` + keep_idx = (tf.range(ps.shape(state.cauchy_point)[-1])[tf.newaxis, ...] 
!= state.breakpoint_min_idx[..., tf.newaxis]) cauchy_point = tf.where( state.active[..., tf.newaxis], @@ -562,14 +578,14 @@ def _body(state): state.breakpoint_min_idx, axis=-1, batch_dims=1), - state.theta[..., tf.newaxis] * \ + (state.theta[..., tf.newaxis] * tf.gather( tf.transpose( bfgs_state.position_deltas, perm=[1,0,2]), state.breakpoint_min_idx, axis=-1, - batch_dims=1) + batch_dims=1)) ], axis=-1) # 2. "Permute" the relevant items to the right @@ -596,34 +612,33 @@ def _body(state): # NOTE Use of d_b = -g_b df = tf.where( state.active, - state.df + state.dt * state.ddf + \ - d_b**2 - \ - state.theta * d_b * z_b + \ + (state.df + state.dt * state.ddf + + d_b**2 - + state.theta * d_b * z_b + d_b * tf.einsum( '...j,...jk,...k->...', w_b, state.m, - c), + c)), state.df) # NOTE use of d_b = -g_b ddf = tf.where( state.active, - state.ddf - state.theta * d_b**2 + \ + (state.ddf - state.theta * d_b**2 + 2. * d_b * tf.einsum( "...i,...ij,...j->...", w_b, state.m, - state.p) - \ + state.p) - d_b**2 * tf.einsum( "...i,...ij,...j->...", w_b, state.m, - w_b), + w_b)), state.ddf) # NOTE: See lbfgsb.f (l. 1649) - # TODO: How to get machine epsilon? - ddf = tf.math.maximum(ddf, 1E-8*ddf_org) + ddf = tf.math.maximum(ddf, dtype_util.eps(ddf.dtype)*ddf_org) # NOTE use of d_b = -g_b p = tf.where( @@ -653,10 +668,10 @@ def _body(state): state.breakpoint_min_old) # Find b - breakpoint_min_idx, breakpoint_min = \ + breakpoint_min_idx, breakpoint_min = ( _cauchy_get_breakpoint_min( state.breakpoints, - free_vars_idx) + free_vars_idx)) breakpoint_min_idx = tf.where( state.active, breakpoint_min_idx, @@ -670,11 +685,9 @@ def _body(state): state.active, breakpoint_min - state.breakpoint_min, state.dt) - - active = tf.where( - state.active, - _cauchy_update_active(free_vars_idx, dt_min, dt), - state.active) + + active = (state.active & + _cauchy_update_active(free_vars_idx, state.breakpoints, dt_min, dt)) # We have to hint the "compiler" that the shapes of the new # values are the same as the old values. 
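To summarize what the loop body above computes: for the memoryless model `B = theta*I` (no correction pairs, so the `w_b' M c` and `w_b' M p` terms vanish), the generalized Cauchy point search reduces to the sketch below. This is only an illustration of the piecewise-quadratic search over breakpoints; unlike the batched implementation above, it recomputes `f'` and `f''` from scratch on every segment instead of using the incremental updates of [2], and all names are illustrative.

```python
import numpy as np

def cauchy_point_diagonal(x, g, lower, upper, theta):
  """Generalized Cauchy point of m(z) = g'z + 0.5*theta*z'z over the box."""
  # Breakpoints: the value of t at which x - t*g hits a bound.
  t = np.where(g < 0., (x - upper) / g,
               np.where(g > 0., (x - lower) / g, np.inf))
  x_cp = x.astype(float).copy()
  t_old = 0.
  for t_b in np.append(np.unique(t[np.isfinite(t)]), np.inf):
    d = np.where(t > t_old, -g, 0.)           # Steepest descent on free vars.
    z = x_cp - x
    f1 = np.dot(g, d) + theta * np.dot(z, d)  # f'  at the segment start.
    f2 = theta * np.dot(d, d)                 # f'' along the segment.
    if f1 >= 0.:                              # Path starts going uphill: done.
      return x_cp
    dt_min = -f1 / f2
    if dt_min < t_b - t_old:                  # Minimum inside this segment.
      return x_cp + dt_min * d
    # Otherwise step to the breakpoint, freezing the variables that hit it.
    x_cp = np.clip(x_cp + (t_b - t_old) * d, lower, upper)
    t_old = t_b
  return x_cp
```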
@@ -710,212 +723,444 @@ def _body(state): parallel_iterations=parallel_iterations)[0] # The loop broke, so the last identified `b` index never got - # removed - _free_vars_idx, free_mask = _cauchy_remove_breakpoint_min( - cauchy_loop.free_vars_idx, - cauchy_loop.breakpoint_min_idx, - cauchy_loop.free_mask, - cauchy_loop.active) + # removed; we do not require knowledge of the free mask to + # terminate the algorithm and recalculate the free mask below + # with a different method, so we do not correct for this + #free_vars_idx, free_mask = _cauchy_remove_breakpoint_min( + # cauchy_loop.free_vars_idx, + # cauchy_loop.breakpoint_min_idx, + # cauchy_loop.free_mask, + # cauchy_loop.active) dt_min = tf.math.maximum(cauchy_loop.dt_min, 0) t_old = cauchy_loop.breakpoint_min_old + dt_min # A breakpoint of -1 means that we ran out of free variables - flagged_breakpoint_min = tf.where( - cauchy_loop.breakpoint_min < 0, - float('inf'), - cauchy_loop.breakpoint_min) + change_cauchy = ( + (cauchy_loop.breakpoint_min >= 0)[..., tf.newaxis] & + (cauchy_loop.breakpoints >= cauchy_loop.breakpoint_min[..., tf.newaxis]) + ) cauchy_point = tf.where( ~(bfgs_state.converged | bfgs_state.failed)[..., tf.newaxis], tf.where( - cauchy_loop.breakpoints >= flagged_breakpoint_min[..., tf.newaxis], + change_cauchy, bfgs_state.position + t_old[..., tf.newaxis] * cauchy_loop.steepest, cauchy_loop.cauchy_point), bfgs_state.position) - # NOTE: We only return the cauchy point and the free mask, so there is no - # need to update the actual state, even though we could at this point update - # `free_vars_idx`, `free_mask`, and `cauchy_point` - free_mask = free_mask & ~(cauchy_loop.breakpoints != cauchy_loop.breakpoint_min) + c = cauchy_loop.c + dt_min[..., tf.newaxis]*cauchy_loop.p + # NOTE: `c` is already permuted to match the subspace of `M`, because `w_b` + # was already permuted. + # You can explicitly check this by comparing its value with W'.(x^c - x) + # at this point. + + # Update the free mask; + # Instead of updating the mask as suggested in [1, CP Algorithm], we recalculate + # whether each variable is free by looking at whether the Cauchy point is near + # the bound. This matches other implementations, and avoids weirdness where + # the first considered variable is always marked as constrained. + # NOTE: the 10 epsilon margin is fairly arbitrary + free_mask =( + tf.math.minimum( + tf.math.abs(cauchy_point - bfgs_state.upper_bounds), + tf.math.abs(cauchy_point - bfgs_state.lower_bounds), + ) > 10. * dtype_util.eps(cauchy_point.dtype) + ) + free_vars_idx = ( + tf.where( + free_mask, + tf.range(ps.shape(free_mask)[-1])[tf.newaxis, ...], + -1)) - return cauchy_point, free_mask + # Update the final cauchy_state + # Hint the compiler that shape of things will not change + if not tf.executing_eagerly(): + free_vars_idx = ( + tf.ensure_shape( + free_vars_idx, + cauchy_loop.free_vars_idx.shape)) + free_mask = ( + tf.ensure_shape( + free_mask, + cauchy_loop.free_mask.shape)) + cauchy_point = ( + tf.ensure_shape( + cauchy_point, + cauchy_loop.cauchy_point.shape)) + c = ( + tf.ensure_shape( + c, + cauchy_loop.c.shape)) + # Do the actual updating + final_cauchy_state = bfgs_utils.update_fields( + cauchy_loop, + free_vars_idx=free_vars_idx, + free_mask=free_mask, + cauchy_point=cauchy_point, + c=c) + + return final_cauchy_state, bfgs_state + + +def _cauchy_update_active(free_vars_idx, breakpoints, dt_min, dt): + """Determines whether each batch of a `_CauchyMinimizationResult` is active. 
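The free mask produced by the Cauchy search feeds the subspace minimization in `_find_search_direction` below. As a cross-check of the formulas quoted in that function's comment ([2], Sec. 5.1), here is a hedged dense NumPy sketch of the direct primal step, where `W` stands for the `[Y, theta*S]` matrix, `M` for the already-inverted middle matrix, and all names are illustrative:

```python
import numpy as np

def direct_primal_direction(x, x_cauchy, g, W, M, theta, free_mask):
  """d = -Bhat^{-1}.rhat on the free subspace, zero on constrained variables."""
  c = W.T @ (x_cauchy - x)
  r_full = g + theta * (x_cauchy - x) - W @ (M @ c)
  A = W[free_mask, :]               # Rows of W for the free variables (Z'W).
  r = r_full[free_mask]             # Reduced gradient Z'r.
  N = np.eye(W.shape[1]) - (M @ A.T @ A) / theta
  d_free = -(r + A @ np.linalg.solve(N, M @ (A.T @ r)) / theta) / theta
  d = np.zeros_like(x, dtype=float)
  d[free_mask] = d_free
  return d
```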
+ + The conditions for a batch being active (i.e. for the loop of "Algorithm CP" + of [2] to proceed for that batch are): + + 1. That `dt_min >= dt` (as made explicit in the paper), + 2. That there are free variables, and + 3. That of those free variables, at least one of the corresponding breakpoints is + finite. + Args: + free_vars_idx: tensor of shape [batch, dims] where each element corresponds to + the index of the variable if the variable is free, and `-1` if the variable is + actively constrained + breakpoints: the breakpoints (`t` in [2]) of the `_CauchyMinimizationResult` + dt_min: the current `dt_min` property of the `_CauchyMinimizationResult` + dt: the current `dt` property of the `_CauchyMinimizationResult` + """ + free_vars = (free_vars_idx >= 0) + return ( + (dt_min >= dt) & + tf.reduce_any(free_vars, axis=-1) & + tf.reduce_any(free_vars & (breakpoints != float('inf')), axis=-1)) -def _cauchy_update_active(free_vars_idx, dt_min, dt): - return tf.where( - tf.reduce_any(free_vars_idx >= 0, axis=-1) & (dt_min >= dt), - True, - False) + +def _find_search_direction(bfgs_state, cauchy_state, num_correction_pairs): + """Finds the search direction based on the direct primal method. + + This function corresponds to points 1-6 of the Direct Primal Method presented + in [2, p. 1199], with the first modification suggested in [3]. + + If an invalid condition is reached for a given batch, its history is reset. Therefore, + this function also returns an updated `bfgs_state`. + + Args: + bfgs_state: the `LBfgsBOptimizerResults` object representing the current iteration. + cauchy_state: the `_CauchyMinimizationResult` results of a cauchy search computation. + Typically the output of `_cauchy_minimization`. + num_correction_pairs: The (maximum) number of correction pairs stored in memory (`m`) + Returns: + Tensor of batched search directions, + Updated `bfgs_state` + Boolean mask of batches that have been refreshed + """ + # Let the reduced gradient be [2, eq. 5.4] + # + # ρ = Z'r + # r = g + Θ(x^c - x) - W.M.c + # + # and the search direction [2, eq. 5.7] + # + # d = -B⁻¹ρ + # + # and [2, eq. 5.10] + # + # B⁻¹ = 1/Θ [ I + 1/Θ Z'.W.N⁻¹.M.W'.Z ] + # N = I - 1/Θ M.W'.Z.Z'.W + # + # Therefore, (see NOTE below regarding minus sign) + # + # d = Z' . (-1/Θ) . [ r + 1/Θ W.N⁻¹.M.W'.Z.Z'.r ] + # + # Letting + # + # K = M.W'.Z.Z' + # + # this is rewritten as + # + # d = -Z' . (1/Θ) . [ r + 1/Θ W.N⁻¹.K.r ] + # N = I - (1/Θ) K.W + + idx = ( + tf.concat([ + tf.ragged.range( + num_correction_pairs - bfgs_state.history), + tf.ragged.range( + num_correction_pairs, + 2*num_correction_pairs - bfgs_state.history), + tf.ragged.range( + num_correction_pairs - bfgs_state.history, + num_correction_pairs), + tf.ragged.range( + 2*num_correction_pairs - bfgs_state.history, + 2*num_correction_pairs) + ], + axis=-1).to_tensor()) + + w_transpose = ( + tf.gather( + tf.transpose( + tf.concat( + [bfgs_state.gradient_deltas, + cauchy_state.theta[tf.newaxis, ..., tf.newaxis] * bfgs_state.position_deltas], + axis=0), + perm=[1, 0, 2] + ), + idx, + batch_dims=1) + ) + + k = ( + tf.einsum( + '...ij,...jk->...ik', + cauchy_state.m, + tf.where( + cauchy_state.free_mask[..., tf.newaxis, :], + w_transpose, + 0. + ) + ) + ) + + n = ( + tf.eye(2*num_correction_pairs, batch_shape=ps.shape(bfgs_state.position)[:-1]) - + tf.einsum( + '...ij,...kj->...ik', + k, + w_transpose + ) / cauchy_state.theta[..., tf.newaxis, tf.newaxis] + ) + + n_mask = ( + tf.range(2*num_correction_pairs)[tf.newaxis, ...] 
< + (2*num_correction_pairs - 2*bfgs_state.history)[..., tf.newaxis] + )[..., tf.newaxis] + + n = ( + tf.where( + n_mask, + tf.eye(2*num_correction_pairs, batch_shape=ps.shape(bfgs_state.position)[:-1]), + n + ) + ) + + # NOTE: For no history, N is at the moment the identity (so never triggers a refresh), + # and is correctly completely zeroed once the inversion is complete + refresh = (tf.linalg.det(n) == 0) + + n = ( + tf.where( + (refresh[..., tf.newaxis, tf.newaxis] | n_mask), + 0., + tf.linalg.inv( + tf.where( + refresh[..., tf.newaxis, tf.newaxis], + tf.eye(2*num_correction_pairs, batch_shape=ps.shape(bfgs_state.position)[:-1]), + n + ) + ) + ) + ) + + r = ( + bfgs_state.objective_gradient + + (cauchy_state.cauchy_point - bfgs_state.position) * cauchy_state.theta[..., tf.newaxis] - + tf.einsum( + '...ji,...jk,...k->...i', + w_transpose, + cauchy_state.m, + cauchy_state.c + ) + ) + + # TODO: According to the comment at the top of this function's definition + # there's a leading minus here, but both the article and the Fortran + # implementation do not use it. I cannot understand why, but the negative + # sign seems to produce the correct results. + # (See lbfgsb.f:3021) + d = ( + -(r + + tf.einsum( + '...ji,...jk,...kl,...l->...i', + w_transpose, + n, + k, + r + ) / cauchy_state.theta[..., tf.newaxis] + ) / cauchy_state.theta[..., tf.newaxis] + ) + + d = ( + tf.where( + cauchy_state.free_mask, + d, + 0. + ) + ) + + # Per [3]: + # Clip the `(cauchy point) + d` into the bounds + lower_term = tf.math.divide_no_nan( + bfgs_state.lower_bounds - cauchy_state.cauchy_point, + d) + upper_term = tf.math.divide_no_nan( + bfgs_state.upper_bounds - cauchy_state.cauchy_point, + d) + clip_per_var = ( + tf.where( + (d > 0), + upper_term, + tf.where( + (d < 0), + lower_term, + float('inf'))) + ) + + movement_clip = ( + tf.math.minimum( + tf.math.reduce_min(clip_per_var, axis=-1), + 1.) + ) + # NOTE: `d` is zeroed for constrained variables, and `movement_clip` is at most 1. + minimizer = ( + cauchy_state.cauchy_point + movement_clip[..., tf.newaxis]*d + ) + + # Per [3]: If the search direction obtained with this minimizer is not a direction + # of strong descent, do not clip `d` to obtain the minimizer (i.e. 
fall back to the + # original algorithm) + fallback = ( + tf.reduce_sum( + (minimizer - bfgs_state.position) * bfgs_state.objective_gradient, axis=-1) > 0 + ) + + minimizer = ( + tf.where( + fallback[..., tf.newaxis], + cauchy_state.cauchy_point + d, + minimizer + ) + ) + + search_direction = (minimizer - bfgs_state.position) + + # Reset if the search direction still isn't a direction of strong descent + refresh |= ( + tf.reduce_sum(search_direction * bfgs_state.objective_gradient, axis=-1) > 0) + + # Apply refresh + bfgs_state = _erase_history(bfgs_state, refresh) + + return search_direction, bfgs_state, refresh -def _hz_line_search(state, value_and_gradients_function, - search_direction, max_iterations, inactive): +def _hz_line_search(starting_position, starting_value, starting_gradient, + value_and_gradients_function, search_direction, max_iterations, inactive): + """Performs Hager Zhang line search via `bfgs_utils.linesearch.hager_zhang`.""" line_search_value_grad_func = bfgs_utils._restrict_along_direction( - value_and_gradients_function, state.position, search_direction) + value_and_gradients_function, starting_position, search_direction) derivative_at_start_pt = tf.reduce_sum( - state.objective_gradient * search_direction, axis=-1) - val_0 = bfgs_utils.ValueAndGradient(x=bfgs_utils._broadcast(0, state.position), - f=state.objective_value, + starting_gradient * search_direction, axis=-1) + val_0 = bfgs_utils.ValueAndGradient(x=bfgs_utils._broadcast(0, starting_position), + f=starting_value, df=derivative_at_start_pt, - full_gradient=state.objective_gradient) + full_gradient=starting_gradient) return bfgs_utils.linesearch.hager_zhang( line_search_value_grad_func, - initial_step_size=bfgs_utils._broadcast(1, state.position), + initial_step_size=bfgs_utils._broadcast(1, starting_position), value_at_zero=val_0, converged=inactive, max_iterations=max_iterations) # No search needed for these. -def _cauchy_line_search_step(state, value_and_gradients_function, search_direction, +def _constrained_line_search_step(bfgs_state, value_and_gradients_function, search_direction, grad_tolerance, f_relative_tolerance, x_tolerance, - stopping_condition, max_iterations, free_mask, cauchy_point): - """Performs the line search in given direction, backtracking in direction to the cauchy point, - and clamping actively contrained variables to the cauchy point.""" - inactive = state.failed | state.converged - ls_result = _hz_line_search(state, value_and_gradients_function, - search_direction, max_iterations, inactive) - - state_after_ls = bfgs_utils.update_fields( - state, - failed=state.failed | (~state.converged & ~ls_result.converged & tf.reduce_any(free_mask, axis=-1)), - num_iterations=state.num_iterations + 1, - num_objective_evaluations=( - state.num_objective_evaluations + ls_result.func_evals + 1)) - - def _do_update_position(): - # For inactive batch members `left.x` is zero. However, their - # `search_direction` might also be undefined, so we can't rely on - # multiplication by zero to produce a `position_delta` of zero. 
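A hedged NumPy restatement of the step clipping and fallback from [3] used in `_find_search_direction` above: the subspace step is clipped back into the box, unless the clipped point no longer defines a descent direction, in which case the unclipped point is used. Names are illustrative.

```python
import numpy as np

def clip_subspace_minimizer(x, x_cauchy, d, g, lower, upper):
  with np.errstate(divide='ignore', invalid='ignore'):
    ratios = np.where(d > 0., (upper - x_cauchy) / d,
                      np.where(d < 0., (lower - x_cauchy) / d, np.inf))
  alpha = min(1., ratios.min())
  minimizer = x_cauchy + alpha * d
  if np.dot(minimizer - x, g) > 0.:   # Not a descent direction: don't clip.
    minimizer = x_cauchy + d
  return minimizer
```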
- alpha = ls_result.left.x[..., tf.newaxis] - ideal_position = tf.where( - inactive[..., tf.newaxis], - state.position, - tf.where( - free_mask, - state.position + search_direction * alpha, - cauchy_point)) - - # Backtrack from the ideal position in direction to the Cauchy point - cauchy_to_ideal = ideal_position - cauchy_point - clip_lower = tf.math.divide_no_nan( - state.lower_bounds - cauchy_point, - cauchy_to_ideal) - clip_upper = tf.math.divide_no_nan( - state.upper_bounds - cauchy_point, - cauchy_to_ideal) - clip = tf.math.reduce_min( - tf.where( - cauchy_to_ideal > 0, - clip_upper, - tf.where( - cauchy_to_ideal < 0, - clip_lower, - float('inf'))), - axis=-1) - alpha = tf.minimum(1.0, clip)[..., tf.newaxis] - - next_position = tf.where( - inactive[..., tf.newaxis], - state.position, - tf.where( - free_mask, - cauchy_point + alpha * cauchy_to_ideal, - cauchy_point)) - - # NOTE: one extra call to the function - next_objective, next_gradient = \ - value_and_gradients_function(next_position) + stopping_condition, max_iterations, refresh): + """Performs a constrained line search clamped to bounds in given direction.""" + inactive = (bfgs_state.failed | bfgs_state.converged) | refresh - return _update_position( - state_after_ls, - next_position, - next_objective, - next_gradient, - grad_tolerance, - f_relative_tolerance, - x_tolerance, - tf.constant(False)) - - return ps.cond( - stopping_condition(state.converged, state.failed), - true_fn=lambda: state_after_ls, - false_fn=_do_update_position) - - -def _bounded_line_search_step(state, value_and_gradients_function, search_direction, - grad_tolerance, f_relative_tolerance, x_tolerance, - stopping_condition, max_iterations, bad_direction): - """Performs a line search in given direction, clamping to the bounds, and fixing the actively - constrained values to the given values.""" - inactive = state.failed | state.converged | bad_direction - ls_result = _hz_line_search(state, value_and_gradients_function, - search_direction, max_iterations, inactive) - - new_failed = state.failed | (~state.converged & ~ls_result.converged \ - & tf.reduce_any(search_direction != 0, axis=-1)) \ - & ~bad_direction - new_num_iterations = state.num_iterations + 1 - new_num_objective_evaluations = ( - state.num_objective_evaluations + ls_result.func_evals + 1) - - if not tf.executing_eagerly(): - # Hint the compiler that the properties' shape will not change - new_failed = tf.ensure_shape( - new_failed, state.failed.shape) - new_num_iterations = tf.ensure_shape( - new_num_iterations, state.num_iterations.shape) - new_num_objective_evaluations = tf.ensure_shape( - new_num_objective_evaluations, state.num_objective_evaluations.shape) - - state_after_ls = bfgs_utils.update_fields( - state, - failed=new_failed, - num_iterations=new_num_iterations, - num_objective_evaluations=new_num_objective_evaluations) - - def _do_update_position(): + def _do_line_search_step(): + """Do unconstrained line search.""" + # Truncation bounds lower_term = tf.math.divide_no_nan( - state.lower_bounds - state.position, - search_direction) + bfgs_state.lower_bounds - bfgs_state.position, + search_direction) upper_term = tf.math.divide_no_nan( - state.upper_bounds - state.position, + bfgs_state.upper_bounds - bfgs_state.position, search_direction) + + # Truncate the search direction to bounds before search + bounds_clip = ( + tf.reduce_min( + tf.where( + (search_direction > 0), + upper_term, + tf.where( + (search_direction < 0), + lower_term, + float('inf'))), + axis=-1) + ) + pre_clip = ( + 
tf.math.minimum( + bounds_clip, + 1.) + ) + + clipped_search_direction = search_direction * pre_clip[..., tf.newaxis] - under_clip = tf.math.reduce_max( - tf.where( - (search_direction > 0), - lower_term, - tf.where( - (search_direction < 0), - upper_term, - -float('inf'))), - axis=-1) - over_clip = tf.math.reduce_min( - tf.where( - (search_direction > 0), - upper_term, - tf.where( - (search_direction < 0), - lower_term, - float('inf'))), - axis=-1) + ls_result = _hz_line_search(bfgs_state.position, bfgs_state.objective_value, bfgs_state.objective_gradient, + value_and_gradients_function, clipped_search_direction, + max_iterations, inactive) + + new_failed = ((bfgs_state.failed | (~bfgs_state.converged & ~ls_result.converged)) & ~inactive) + new_num_iterations = bfgs_state.num_iterations + 1 + new_num_objective_evaluations = ( + bfgs_state.num_objective_evaluations + ls_result.func_evals + 1) + + # Also truncate to bounds after search + step = ( + tf.math.minimum( + bounds_clip, + ls_result.left.x + ) + ) - alpha_clip = tf.clip_by_value( - ls_result.left.x, - under_clip, - over_clip)[..., tf.newaxis] + # Hint the compiler that the properties' shape will not change + if not tf.executing_eagerly(): + new_failed = tf.ensure_shape( + new_failed, bfgs_state.failed.shape) + new_num_iterations = tf.ensure_shape( + new_num_iterations, bfgs_state.num_iterations.shape) + new_num_objective_evaluations = tf.ensure_shape( + new_num_objective_evaluations, bfgs_state.num_objective_evaluations.shape) + + state_after_ls = bfgs_utils.update_fields( + state=bfgs_state, + failed=new_failed, + num_iterations=new_num_iterations, + num_objective_evaluations=new_num_objective_evaluations) + + return step, state_after_ls + + # NOTE: It's important that the default (false `pred`) step matches + # the shape of true `pred` shape for graph purposes + step, state_after_ls = ( + tf.cond( + pred=tf.math.logical_not(tf.reduce_all(inactive)), + true_fn=_do_line_search_step, + false_fn=lambda: (tf.zeros_like(inactive, dtype=search_direction.dtype), bfgs_state) + )) + def _do_update_position(): + """Update the position""" # For inactive batch members `left.x` is zero. However, their # `search_direction` might also be undefined, so we can't rely on # multiplication by zero to produce a `position_delta` of zero. + # Also, the search direction has already been clipped to make sure + # it does not go out of bounds. next_position = tf.where( inactive[..., tf.newaxis], - state.position, - state.position + search_direction * alpha_clip) + bfgs_state.position, + bfgs_state.position + step[..., tf.newaxis] * search_direction) # one extra call to the function, counted above - next_objective, next_gradient = \ + next_objective, next_gradient = ( value_and_gradients_function(next_position) + ) return _update_position( state_after_ls, @@ -925,10 +1170,11 @@ def _do_update_position(): grad_tolerance, f_relative_tolerance, x_tolerance, - bad_direction) + inactive) return ps.cond( - stopping_condition(state.converged, state.failed), + (stopping_condition(bfgs_state.converged, bfgs_state.failed) & + tf.math.logical_not(tf.reduce_all(inactive))), true_fn=lambda: state_after_ls, false_fn=_do_update_position) @@ -940,34 +1186,22 @@ def _update_position(state, grad_tolerance, f_relative_tolerance, x_tolerance, - erase_memory): - """Updates the state advancing its position by a given position_delta. 
- Also erases the LBFGS memory if indicated.""" + inactive): + """Updates the state advancing its position by a given position_delta.""" state = bfgs_utils.terminate_if_not_finite(state, next_objective, next_gradient) - converged = ~state.failed & \ - _check_convergence_bounded(state.position, - next_position, - state.objective_value, - next_objective, - next_gradient, - grad_tolerance, - f_relative_tolerance, - x_tolerance, - state.lower_bounds, - state.upper_bounds) - new_position_deltas = tf.where( - erase_memory[..., tf.newaxis], - tf.zeros_like(state.position_deltas), - state.position_deltas) - new_gradient_deltas = tf.where( - erase_memory[..., tf.newaxis], - tf.zeros_like(state.gradient_deltas), - state.gradient_deltas) - new_history = tf.where( - erase_memory, - tf.zeros_like(state.history), - state.history) + converged = (~inactive & + ~state.failed & + _check_convergence_bounded(state.position, + next_position, + state.objective_value, + next_objective, + next_gradient, + grad_tolerance, + f_relative_tolerance, + x_tolerance, + state.lower_bounds, + state.upper_bounds)) new_converged = (state.converged | converged) if not tf.executing_eagerly(): @@ -976,19 +1210,57 @@ def _update_position(state, next_position = tf.ensure_shape(next_position, state.position.shape) next_objective = tf.ensure_shape(next_objective, state.objective_value.shape) next_gradient = tf.ensure_shape(next_gradient, state.objective_gradient.shape) - new_position_deltas = tf.ensure_shape(new_position_deltas, state.position_deltas.shape) - new_gradient_deltas = tf.ensure_shape(new_gradient_deltas, state.gradient_deltas.shape) - new_history = tf.ensure_shape(new_history, state.history.shape) return bfgs_utils.update_fields( state, converged=new_converged, position=next_position, objective_value=next_objective, - objective_gradient=next_gradient, - position_deltas=new_position_deltas, - gradient_deltas=new_gradient_deltas, - history=new_history) + objective_gradient=next_gradient) + + +def _erase_history(bfgs_state, where_erase): + """Erases the BFGS correction pairs for the specified batches. + + This function will zero `gradient_deltas`, `position_deltas`, and `history`. + + Args: + `bfgs_state`: a `LBfgsBOptimizerResults` to modify + `where_erase`: a Boolean tensor with shape matching the batch dimensions + with `True` for the batches to erase the history of. + Returns: + Modified `bfgs_state`. 
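As a toy illustration of the masked reset performed here (shapes and values are illustrative only): batches flagged in `where_erase` get their correction-pair queues and history counters zeroed via broadcasting, while the remaining batches are untouched.

```python
import tensorflow as tf

deltas = tf.ones([3, 2, 4])                 # [num_correction_pairs, batch, n]
history = tf.constant([3, 1])               # Per-batch history counters.
where_erase = tf.constant([True, False])
new_deltas = tf.where(where_erase[..., tf.newaxis], 0., deltas)
new_history = tf.where(where_erase, 0, history)
```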
+ """ + # Calculate new values + new_gradient_deltas = (tf.where( + where_erase[..., tf.newaxis], + 0., + bfgs_state.gradient_deltas)) + new_position_deltas = (tf.where( + where_erase[..., tf.newaxis], + 0., + bfgs_state.position_deltas)) + new_history = tf.where(where_erase, 0, bfgs_state.history) + # Assure the compiler that the shape of things has not changed + if not tf.executing_eagerly(): + new_gradient_deltas = ( + tf.ensure_shape( + new_gradient_deltas, + bfgs_state.gradient_deltas.shape)) + new_position_deltas = ( + tf.ensure_shape( + new_position_deltas, + bfgs_state.position_deltas.shape)) + new_history = ( + tf.ensure_shape( + new_history, + bfgs_state.history.shape)) + # Update and return + return bfgs_utils.update_fields( + bfgs_state, + gradient_deltas=new_gradient_deltas, + position_deltas=new_position_deltas, + history=new_history) def _check_convergence_bounded(current_position, @@ -1002,17 +1274,22 @@ def _check_convergence_bounded(current_position, lower_bounds, upper_bounds): """Checks if the algorithm satisfies the convergence criteria.""" + # NOTE: The original algorithm (as described in [2]) only considers halting on + # the projected gradient condition. However, `x_converged` and `f_converged` do + # not seem to pose a problem when refreshing is correctly accounted for (so that + # the optimization does not halt upon a refresh), and the default values of `0` + # for `f_relative_tolerance` and `x_tolerance` further strengthen these conditions. proj_grad_converged = bfgs_utils.norm( tf.clip_by_value( next_position - next_gradient, lower_bounds, upper_bounds) - next_position, dims=1) <= grad_tolerance x_converged = bfgs_utils.norm(next_position - current_position, dims=1) <= x_tolerance - f_converged = bfgs_utils.norm(next_objective - current_objective, dims=0) <= \ - f_relative_tolerance * current_objective + f_converged = ( + bfgs_utils.norm(next_objective - current_objective, dims=0) <= + f_relative_tolerance * current_objective) return proj_grad_converged | x_converged | f_converged - def _get_initial_state(value_and_gradients_function, initial_position, lower_bounds, @@ -1033,50 +1310,56 @@ def _get_initial_state(value_and_gradients_function, return LBfgsBOptimizerResults(**init_args) -def _get_initial_cauchy_state(state, num_correction_pairs): - """Create _ConstrainedCauchyState with initial parameters""" +def _get_initial_cauchy_state(bfgs_state, num_correction_pairs): + """Create `_ConstrainedCauchyState` with initial parameters. + + This will calculate the elements of `_ConstrainedCauchyState` based on the given + `LBfgsBOptimizerResults` state object. Some of these properties may be incalculable, + for which batches the state will be reset. + + Args: + bfgs_state: `LBfgsBOptimizerResults` object representing the current state of the + LBFGSB optimization + num_correction_pairs: typically `m`; the (maximum) number of past steps to keep as + history for the LBFGS algorithm + + Returns: + Initialized `_ConstrainedCauchyState` + Updated `bfgs_state` + """ theta = tf.math.divide_no_nan( - tf.reduce_sum(state.gradient_deltas[-1, ...]**2, axis=-1), - tf.reduce_sum(state.gradient_deltas[-1,...] * state.position_deltas[-1, ...], axis=-1)) + tf.reduce_sum(bfgs_state.gradient_deltas[-1, ...]**2, axis=-1), + (tf.reduce_sum(bfgs_state.gradient_deltas[-1,...] * + bfgs_state.position_deltas[-1, ...], axis=-1))) theta = tf.where( theta != 0, theta, - 1.0) + 1.) 
m, refresh = _cauchy_init_m( - state, - ps.shape(state.position_deltas), + bfgs_state, + ps.shape(bfgs_state.position_deltas), theta, num_correction_pairs) + # Erase the history where M isn't invertible - state = \ - bfgs_utils.update_fields( - state, - gradient_deltas=tf.where( - refresh[..., tf.newaxis], - tf.zeros_like(state.gradient_deltas), - state.gradient_deltas), - position_deltas=tf.where( - refresh[..., tf.newaxis], - tf.zeros_like(state.position_deltas), - state.position_deltas), - history=tf.where(refresh, 0, state.history)) - theta = tf.where(refresh, 1.0, theta) - - breakpoints = _cauchy_init_breakpoints(state) + bfgs_state = _erase_history(bfgs_state, refresh) + theta = tf.where(refresh, 1., theta) + + breakpoints = _cauchy_init_breakpoints(bfgs_state) steepest = tf.where( breakpoints != 0., - -state.objective_gradient, + -bfgs_state.objective_gradient, 0.) free_mask = (breakpoints > 0) free_vars_idx = tf.where( free_mask, tf.broadcast_to( - tf.range(ps.shape(state.position)[-1], dtype=tf.int32), - ps.shape(state.position)), + tf.range(ps.shape(bfgs_state.position)[-1], dtype=tf.int32), + ps.shape(bfgs_state.position)), -1) # We need to account for the varying histories: @@ -1089,28 +1372,28 @@ def _get_initial_cauchy_state(state, num_correction_pairs): [ tf.einsum( "m...i,...i->...m", - state.gradient_deltas, + bfgs_state.gradient_deltas, steepest), - theta[..., tf.newaxis] * \ + (theta[..., tf.newaxis] * tf.einsum( "m...i,...i->...m", - state.position_deltas, - steepest) + bfgs_state.position_deltas, + steepest)) ], axis=-1) # 2. Assemble the rows in the correct order idx = tf.concat( [ tf.ragged.range( - num_correction_pairs - state.history), + num_correction_pairs - bfgs_state.history), tf.ragged.range( num_correction_pairs, - 2*num_correction_pairs - state.history), + 2*num_correction_pairs - bfgs_state.history), tf.ragged.range( - num_correction_pairs - state.history, + num_correction_pairs - bfgs_state.history, num_correction_pairs), tf.ragged.range( - 2*num_correction_pairs - state.history, + 2*num_correction_pairs - bfgs_state.history, 2*num_correction_pairs) ], axis=-1).to_tensor() @@ -1125,26 +1408,31 @@ def _get_initial_cauchy_state(state, num_correction_pairs): ddf = -theta*df - tf.einsum("...i,...ij,...j->...", p, m, p) dt_min = -tf.math.divide_no_nan(df, ddf) - breakpoint_min_idx, breakpoint_min = \ - _cauchy_get_breakpoint_min(breakpoints, free_vars_idx) + breakpoint_min_idx, breakpoint_min = ( + _cauchy_get_breakpoint_min(breakpoints, free_vars_idx)) dt = breakpoint_min breakpoint_min_old = tf.zeros_like(breakpoint_min) - cauchy_point = state.position + cauchy_point = bfgs_state.position - active = ~(state.converged | state.failed) & \ - _cauchy_update_active(free_vars_idx, dt_min, dt) + active = (~(bfgs_state.converged | bfgs_state.failed) & + _cauchy_update_active(free_vars_idx, breakpoints, dt_min, dt)) - return _ConstrainedCauchyState( - theta, m, breakpoints, steepest, free_vars_idx, free_mask, - p, c, df, ddf, dt_min, breakpoint_min, breakpoint_min_idx, - dt, breakpoint_min_old, cauchy_point, active) + cauchy_state = ( + _ConstrainedCauchyState( + theta, m, breakpoints, steepest, free_vars_idx, free_mask, + p, c, df, ddf, dt_min, breakpoint_min, breakpoint_min_idx, + dt, breakpoint_min_old, cauchy_point, active)) + + return cauchy_state, bfgs_state def _cauchy_init_m(state, deltas_shape, theta, num_correction_pairs): + """Initialize the M matrix for a `_CauchyMinimizationResult` state.""" def build_m(): + """Construct and invert the M block matrix.""" 
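For reference, `build_m` assembles and inverts the middle matrix of the compact representation `B = theta*I - W M W'` from [2]. A hedged dense NumPy sketch of the full-history case (the TF code below additionally permutes rows and columns to handle partial history and batches the computation); names are illustrative.

```python
import numpy as np

def compact_m(position_deltas, gradient_deltas, theta):
  """M and W for B = theta*I - W M W', with s_i, y_i as columns of S, Y."""
  S = np.stack(position_deltas, axis=1)    # shape [n, m]
  Y = np.stack(gradient_deltas, axis=1)    # shape [n, m]
  sty = S.T @ Y                            # entries s_i' y_j
  D = np.diag(np.diag(sty))
  L = np.tril(sty, k=-1)                   # strictly lower triangle of S'Y
  m_inv = np.block([[-D, L.T],
                    [L, theta * (S.T @ S)]])
  return np.linalg.inv(m_inv), np.concatenate([Y, theta * S], axis=1)
```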
# All of the below block matrices have dimensions [..., m, m] # where `...` denotes the batch dimensions, and `m` the number # of correction pairs (compare to `deltas_shape`, which is [m,...,n]). @@ -1176,18 +1464,23 @@ def build_m(): # Assemble into full matrix # TODO: Is there no better way to create a block matrix? - block_d = tf.concat([-d, tf.zeros_like(d)], axis=-1) - block_d = tf.concat([block_d, tf.zeros_like(block_d)], axis=-2) - block_l_transpose = tf.concat([tf.zeros_like(l_transpose), l_transpose], axis=-1) - block_l_transpose = tf.concat([block_l_transpose, tf.zeros_like(block_l_transpose)], axis=-2) - block_l = tf.concat([l, tf.zeros_like(l)], axis=-1) - block_l = tf.concat([tf.zeros_like(block_l), block_l], axis=-2) - block_s_t_s = tf.concat([tf.zeros_like(s_t_s), s_t_s], axis=-1) - block_s_t_s = tf.concat([tf.zeros_like(block_s_t_s), block_s_t_s], axis=-2) + m_inv = tf.concat( + [ + tf.concat([-d, l_transpose], axis=-1), + tf.concat([l, theta[..., tf.newaxis, tf.newaxis] * s_t_s], axis=-1) + ], axis=-2) + #block_d = tf.concat([-d, tf.zeros_like(d)], axis=-1) + #block_d = tf.concat([block_d, tf.zeros_like(block_d)], axis=-2) + #block_l_transpose = tf.concat([tf.zeros_like(l_transpose), l_transpose], axis=-1) + #block_l_transpose = tf.concat([block_l_transpose, tf.zeros_like(block_l_transpose)], axis=-2) + #block_l = tf.concat([l, tf.zeros_like(l)], axis=-1) + #block_l = tf.concat([tf.zeros_like(block_l), block_l], axis=-2) + #block_s_t_s = tf.concat([tf.zeros_like(s_t_s), s_t_s], axis=-1) + #block_s_t_s = tf.concat([tf.zeros_like(block_s_t_s), block_s_t_s], axis=-2) # shape [b, 2m, 2m] - m_inv = block_d + block_l_transpose + block_l + \ - theta[..., tf.newaxis, tf.newaxis] * block_s_t_s + #m_inv = (block_d + block_l_transpose + block_l + + # theta[..., tf.newaxis, tf.newaxis] * block_s_t_s) # Adjust for varying history: # Push columns indexed h,...,2m-h to the left (but to the right of 0...m-h) @@ -1210,9 +1503,9 @@ def build_m(): batch_dims=1) # Insert an identity in the empty block - identity_mask = \ - (tf.range(ps.shape(m_inv)[-1])[tf.newaxis, ...] < \ - 2*(num_correction_pairs - state.history[..., tf.newaxis]))[..., tf.newaxis] + identity_mask = ( + (tf.range(ps.shape(m_inv)[-1])[tf.newaxis, ...] < + 2*(num_correction_pairs - state.history[..., tf.newaxis]))[..., tf.newaxis]) m_inv = tf.where( identity_mask, @@ -1249,7 +1542,8 @@ def build_m(): def _cauchy_init_breakpoints(state): - breakpoints = \ + """Calculate the breakpoints for a `_CauchyMinimizationResult` state.""" + breakpoints = ( tf.where( state.objective_gradient < 0, tf.math.divide_no_nan( @@ -1261,6 +1555,7 @@ def _cauchy_init_breakpoints(state): state.position - state.lower_bounds, state.objective_gradient), float('inf'))) + ) return breakpoints @@ -1271,6 +1566,18 @@ def _cauchy_remove_breakpoint_min(free_vars_idx, active): """Update the free variable indices to remove the minimum breakpoint index. + This will set the `breakpoint_min_idx`th entry of `free_mask` to `False`, + and of `free_vars_idx` to `-1`. 
+ + Args: + free_vars_idx: tensor of shape [batch, dims] where each entry is the index of the + entry for the batch if the corresponding variable is free, and -1 otherwise + breakpoint_min_idx: tensor of shape [batch] denoting the indices to mark as + constrained for each batch + free_mask: tensor of shape [batch, dims] where `True` denotes a free variable, and + `False` an actively constrained variable + active: tensor of shape [batch] denoting whether each batch should be updated + Returns: Updated `free_vars_idx`, `free_mask` """ @@ -1280,28 +1587,35 @@ def _cauchy_remove_breakpoint_min(free_vars_idx, # every element of free_vars_idx is -1, and so there is no match. matching = (free_vars_idx == breakpoint_min_idx[..., tf.newaxis]) free_vars_idx = tf.where( - matching, + active[..., tf.newaxis] & matching, -1, free_vars_idx) free_mask = tf.where( active[..., tf.newaxis], free_vars_idx >= 0, free_mask) - return free_vars_idx, free_mask def _cauchy_get_breakpoint_min(breakpoints, free_vars_idx): - """Find the smallest breakpoint of free indices, returning the minimum breakpoint - and the corresponding index. + """Find the smallest breakpoint of free indices. + + If every breakpoint is equal, this function will return the first found variable + that is not actively constrained. + + Args: + breakpoints: tensor of breakpoints as initialized in a `_CauchyMinimizationResult` + state + free_vars_idx: tensor denoting free and constrained variables, as initialized in + a `_CauchyMinimizationResult` state Returns: - Tuple of `breakpoint_min_idx`, `breakpoint_min` - where - `breakpoint_min_idx` is the index that has min. breakpoint - `breakpoint_min` is the corresponding breakpoint + Index that has min. breakpoint + Corresponding breakpoint """ - # A tensor of shape [batch, dims] that has +infinity where free_vars_idx < 0, + no_free = (~tf.reduce_any(free_vars_idx >= 0, axis=-1)) + + # A tensor of shape [batch, dims] that has inf where free_vars_idx < 0, # and has breakpoints[free_vars_idx] otherwise. flagged_breakpoints = tf.where( free_vars_idx < 0, @@ -1319,6 +1633,36 @@ def _cauchy_get_breakpoint_min(breakpoints, free_vars_idx): axis=-1, output_type=tf.int32) + # Sometimes free variables have 'inf' breakpoints, and then there + # is no guarantee that argmin will not have picked a constrained variable + # In this case, grab the first free variable by iterating along the variables + # until one is free + + def _check_gathered(active, _): + """Whether we are still looking for a free variable.""" + return tf.reduce_any(active) + + def _get_first(active, new_idx): + """Check if next variable is free.""" + new_idx = tf.where(active, new_idx+1, new_idx) + active = (~no_free & + (tf.gather( + free_vars_idx, + new_idx, + batch_dims=1) < 0)) + return [active, new_idx] + + active = (~no_free & + (tf.gather( + free_vars_idx, + argmin_idx, + batch_dims=1) < 0)) + _, argmin_idx = ( + tf.while_loop( + cond=_check_gathered, + body=_get_first, + loop_vars=[active, argmin_idx])) + # NOTE: For situations where there are no more free indices # (and therefore argmin_idx indexes into -1), we set # breakpoint_min_idx to 0 and flag that there are no free @@ -1327,10 +1671,6 @@ def _cauchy_get_breakpoint_min(breakpoints, free_vars_idx): # This is because in branching situations, indexing with # breakpoint_min_idx can occur, and later be discarded, but all # elements in breakpoint_min_idx must be a priori valid indices. 
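A small NumPy restatement of what `_cauchy_get_breakpoint_min` above computes for a single batch member: the smallest breakpoint among the still-free variables, with constrained entries treated as `+inf`. The TF version additionally guards the case where the argmin lands on a constrained variable because every remaining breakpoint is infinite; names here are illustrative.

```python
import numpy as np

def breakpoint_min(breakpoints, free_mask):
  flagged = np.where(free_mask, breakpoints, np.inf)
  idx = int(np.argmin(flagged))
  return idx, flagged[idx]
```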
- no_free = tf.gather( - free_vars_idx, - argmin_idx, - batch_dims=1) < 0 breakpoint_min_idx = tf.where( no_free, 0, @@ -1349,185 +1689,6 @@ def _cauchy_get_breakpoint_min(breakpoints, free_vars_idx): return breakpoint_min_idx, breakpoint_min -def _get_search_direction(state): - """Computes the search direction to follow at the current state. - - On the `k`-th iteration of the main L-BFGS algorithm, the state has collected - the most recent `m` correction pairs in position_deltas and gradient_deltas, - where `k = state.num_iterations` and `m = min(k, num_correction_pairs)`. - - Assuming these, the code below is an implementation of the L-BFGS two-loop - recursion algorithm given by [Nocedal and Wright(2006)][1]: - - ```None - q_direction = objective_gradient - for i in reversed(range(m)): # First loop. - inv_rho[i] = gradient_deltas[i]^T * position_deltas[i] - alpha[i] = position_deltas[i]^T * q_direction / inv_rho[i] - q_direction = q_direction - alpha[i] * gradient_deltas[i] - - kth_inv_hessian_factor = (gradient_deltas[-1]^T * position_deltas[-1] / - gradient_deltas[-1]^T * gradient_deltas[-1]) - r_direction = kth_inv_hessian_factor * I * q_direction - - for i in range(m): # Second loop. - beta = gradient_deltas[i]^T * r_direction / inv_rho[i] - r_direction = r_direction + position_deltas[i] * (alpha[i] - beta) - - return -r_direction # Approximates - H_k * objective_gradient. - ``` - - Args: - state: A `LBfgsBOptimizerResults` tuple with the current state of the - search procedure. - - Returns: - A real `Tensor` of the same shape as the `state.position`. The direction - along which to perform line search. - """ - # The number of correction pairs that have been collected so far. - #num_elements = ps.minimum( - # state.num_iterations, # TODO(b/162733947): Change loop state -> closure. - # ps.shape(state.position_deltas)[0]) - - def _two_loop_algorithm(): - """L-BFGS two-loop algorithm.""" - # Correction pairs are always appended to the end, so only the latest - # `num_elements` vectors have valid position/gradient deltas. Vectors - # that haven't been computed yet are zero. - position_deltas = state.position_deltas - gradient_deltas = state.gradient_deltas - num_correction_pairs, num_batches, _point_dims = \ - ps.shape(gradient_deltas, out_type=tf.int32) - - # Pre-compute all `inv_rho[i]`s. - inv_rhos = tf.reduce_sum( - gradient_deltas * position_deltas, axis=-1) - - def first_loop(acc, args): - _, q_direction, num_iter = acc - position_delta, gradient_delta, inv_rho = args - active = (num_iter < state.history) - alpha = tf.math.divide_no_nan( - tf.reduce_sum( - position_delta * q_direction, - axis=-1), - inv_rho) - direction_delta = alpha[..., tf.newaxis] * gradient_delta - new_q_direction = tf.where( - active[..., tf.newaxis], - q_direction - direction_delta, - q_direction) - - return (alpha, new_q_direction, num_iter + 1) - - # Run first loop body computing and collecting `alpha[i]`s, while also - # computing the updated `q_direction` at each step. - zero = tf.zeros_like(inv_rhos[0]) - alphas, q_directions, _num_iters = tf.scan( - first_loop, [position_deltas, gradient_deltas, inv_rhos], - initializer=(zero, state.objective_gradient, 0), reverse=True) - - # We use `H^0_k = gamma_k * I` as an estimate for the initial inverse - # hessian for the k-th iteration; then `r_direction = H^0_k * q_direction`. 
- idx = tf.transpose( - tf.stack( - [tf.where( - state.history > 0, - num_correction_pairs - state.history, - 0), - tf.range(num_batches)])) - gamma_k = tf.math.divide_no_nan( - tf.gather_nd(inv_rhos, idx), - tf.reduce_sum( - tf.gather_nd(gradient_deltas, idx)**2, - axis=-1)) - gamma_k = tf.where( - (state.history > 0), - gamma_k, - 1.0) - r_direction = gamma_k[..., tf.newaxis] * tf.gather_nd(q_directions, idx) - - def second_loop(acc, args): - r_direction, iter_idx = acc - alpha, position_delta, gradient_delta, inv_rho = args - active = (iter_idx >= num_correction_pairs - state.history) - beta = tf.math.divide_no_nan( - tf.reduce_sum( - gradient_delta * r_direction, - axis=-1), - inv_rho) - direction_delta = (alpha - beta)[..., tf.newaxis] * position_delta - new_r_direction = tf.where( - active[..., tf.newaxis], - r_direction + direction_delta, - r_direction) - return (new_r_direction, iter_idx + 1) - - # Finally, run second loop body computing the updated `r_direction` at each - # step. - r_directions, _num_iters = tf.scan( - second_loop, [alphas, position_deltas, gradient_deltas, inv_rhos], - initializer=(r_direction, 0)) - - return -r_directions[-1] - - return ps.cond(tf.reduce_any(state.history != 0), - _two_loop_algorithm, - lambda: -state.objective_gradient) - - -def _get_ragged_sizes(tensor, dtype=tf.int32): - """Creates a tensor indicating the size of each component of - a ragged dimension. - - For example: - - ```python - element = tf.ragged.constant([[1,2], [3,4,5], [], [0]]) - _get_ragged_sizes(element) - # => - ``` - """ - return tf.reduce_sum( - tf.ones_like( - tensor, - dtype=dtype), - axis=-1)[..., tf.newaxis] - - -def _get_range_like_ragged(tensor, dtype=tf.int32): - """Creates a batched range for the elements of the batched tensor. - - For example: - - ```python - element = tf.ragged.constant([[1,2], [3,4,5], [], [0]]) - _get_range_like_ragged(element) - # => - - Args: - tensor: a RaggedTensor of shape `[n, None]`. - - Returns: - A ragged tensor of shape `[n, None]` where the ragged dimensions - match the ragged dimensions of `tensor`, and are a range from `0` to - the size of the ragged dimension. - ``` - """ - sizes = _get_ragged_sizes(tensor) - flat_ranges = tf.ragged.range( - tf.reshape( - sizes, - [tf.reduce_prod(sizes.shape)]), - dtype=dtype) - return tf.RaggedTensor.from_row_lengths(flat_ranges, sizes.shape[:-1])[0] - - def _make_empty_queue_for(k, element): """Creates a `tf.Tensor` suitable to hold `k` element-shaped tensors. From 41b22679c6df83d8e9ea77146b1d010b7054ac40 Mon Sep 17 00:00:00 2001 From: mikeevmm Date: Fri, 25 Jun 2021 14:49:44 +0100 Subject: [PATCH 3/4] refact: minor optimizations and documentation refactoring. --- .../python/optimizer/lbfgs.py | 1752 ++++++++++++++--- 1 file changed, 1434 insertions(+), 318 deletions(-) diff --git a/tensorflow_probability/python/optimizer/lbfgs.py b/tensorflow_probability/python/optimizer/lbfgs.py index 886a3084ef..e4382b5bf6 100644 --- a/tensorflow_probability/python/optimizer/lbfgs.py +++ b/tensorflow_probability/python/optimizer/lbfgs.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ -"""The Limited-Memory BFGS minimization algorithm. +"""A constrained version of the Limited-Memory BFGS minimization algorithm. 
Limited-memory quasi-Newton methods are useful for solving large problems whose Hessian matrices cannot be computed at a reasonable cost or are not @@ -20,8 +20,8 @@ matrices, they only save a few vectors of length n that represent the approximations implicitly. -This module implements the algorithm known as L-BFGS, which, as its name -suggests, is a limited-memory version of the BFGS algorithm. +This module implements the algorithm known as L-BFGS-B, which, as its name +suggests, is a limited-memory version of the BFGS algorithm, with bounds. """ from __future__ import absolute_import from __future__ import division @@ -35,183 +35,261 @@ from tensorflow_probability.python.internal import dtype_util from tensorflow_probability.python.internal import prefer_static as ps from tensorflow_probability.python.optimizer import bfgs_utils - - -LBfgsOptimizerResults = collections.namedtuple( - 'LBfgsOptimizerResults', [ - 'converged', # Scalar boolean tensor indicating whether the minimum - # was found within tolerance. - 'failed', # Scalar boolean tensor indicating whether a line search - # step failed to find a suitable step size satisfying Wolfe - # conditions. In the absence of any constraints on the - # number of objective evaluations permitted, this value will - # be the complement of `converged`. However, if there is - # a constraint and the search stopped due to available - # evaluations being exhausted, both `failed` and `converged` - # will be simultaneously False. - 'num_iterations', # The number of iterations of the BFGS update. - 'num_objective_evaluations', # The total number of objective - # evaluations performed. - 'position', # A tensor containing the last argument value found - # during the search. If the search converged, then - # this value is the argmin of the objective function. - 'objective_value', # A tensor containing the value of the objective - # function at the `position`. If the search - # converged, then this is the (local) minimum of - # the objective function. - 'objective_gradient', # A tensor containing the gradient of the - # objective function at the - # `final_position`. If the search converged - # the max-norm of this tensor should be - # below the tolerance. - 'position_deltas', # A tensor encoding information about the latest - # changes in `position` during the algorithm - # execution. Its shape is of the form - # `(num_correction_pairs,) + position.shape` where - # `num_correction_pairs` is given as an argument to - # the minimize function. - 'gradient_deltas', # A tensor encoding information about the latest - # changes in `objective_gradient` during the - # algorithm execution. Has the same shape as - # position_deltas. - ]) +from tensorflow_probability.python.optimizer import lbfgs_minimize + + +LBfgsBOptimizerResults = collections.namedtuple( + 'LBfgsBOptimizerResults', [ + 'converged', # Scalar boolean tensor indicating whether the minimum + # was found within tolerance. + 'failed', # Scalar boolean tensor indicating whether a line search + # step failed to find a suitable step size satisfying Wolfe + # conditions. In the absence of any constraints on the + # number of objective evaluations permitted, this value will + # be the complement of `converged`. However, if there is + # a constraint and the search stopped due to available + # evaluations being exhausted, both `failed` and `converged` + # will be simultaneously False. + 'num_iterations', # The number of iterations of the BFGS update. 
+ 'num_objective_evaluations', # The total number of objective + # evaluations performed. + 'position', # A tensor containing the last argument value found + # during the search. If the search converged, then + # this value is the argmin of the objective function. + 'lower_bounds', # A tensor containing the lower bounds to the constrained + # optimization, cast to the shape of `position`. + 'upper_bounds', # A tensor containing the upper bounds to the constrained + # optimization, cast to the shape of `position`. + 'objective_value', # A tensor containing the value of the objective + # function at the `position`. If the search + # converged, then this is the (local) minimum of + # the objective function. + 'objective_gradient', # A tensor containing the gradient of the + # objective function at the + # `final_position`. If the search converged + # the max-norm of this tensor should be + # below the tolerance. + 'position_deltas', # A tensor encoding information about the latest + # changes in `position` during the algorithm + # execution. Its shape is of the form + # `(num_correction_pairs,) + position.shape` where + # `num_correction_pairs` is given as an argument to + # the minimize function. + 'gradient_deltas', # A tensor encoding information about the latest + # changes in `objective_gradient` during the + # algorithm execution. Has the same shape as + # position_deltas. + 'history', # How many gradient/position deltas should be considered. + ]) + +_ConstrainedCauchyState = collections.namedtuple( + '_CauchyMinimizationResult', [ + # `\theta` in [2]; n the Cauchy search, relates to the implicit Hessian + 'theta', + # `B = \theta*I - WMW'` (`I` the identity, see [1,2] for details) + # `M_k` matrix in [2]; part of the implicit representation of the Hessian, + 'm', + # see the comment above + 'breakpoints', # `t_i` in [Byrd et al.][2]; + # the breakpoints in the branch definition of the + # projection of the gradients, batched + 'breakpoints_argsort', # Range from 0...n-1 sorted by increasing breakpoints + # Tensor of shape [batch]; the index into `breakpoints_argsort` + 'next_free_idx', + # for the breakpoint in effect + 'steepest', # `d` in [2]; steepest descent clamped to bounds + 'p', # as in [2] + # as in [2]; eventually made to equal `W'(cauchy_point - position)` + 'c', + 'df', # `f'` in [2] + 'ddf', # `f''` in [2] + 'dt', # `\Delta t` in [2] + 'dt_min', # `\Delta t_min` in [2] + 'tsum', # Sum of all the considered breakpoints so far + 'breakpoint_min_old', # t_old in [2] + # `x^cp` in [2]; the actual cauchy point (we're looking for) + 'cauchy_point', + 'active', # What batches are in active optimization + 'free_mask', # Boolean tensor of what variables are actively constrained + ]) def minimize(value_and_gradients_function, - initial_position, - previous_optimizer_results=None, - num_correction_pairs=10, - tolerance=1e-8, - x_tolerance=0, - f_relative_tolerance=0, - initial_inverse_hessian_estimate=None, - max_iterations=50, - parallel_iterations=1, - stopping_condition=None, - max_line_search_iterations=50, - name=None): - """Applies the L-BFGS algorithm to minimize a differentiable function. - - Performs unconstrained minimization of a differentiable function using the - L-BFGS scheme. See [Nocedal and Wright(2006)][1] for details of the algorithm. 
+ initial_position, + bounds=None, + previous_optimizer_results=None, + num_correction_pairs=10, + tolerance=1e-5, + x_tolerance=0, + f_relative_tolerance=0, + initial_inverse_hessian_estimate=None, + max_iterations=50, + parallel_iterations=1, + stopping_condition=None, + max_line_search_iterations=50, + name=None): + """Applies the L-BFGS-B algorithm to minimize a differentiable function. + + Performs optionally constrained minimization of a differentiable function using the + L-BFGS-B scheme. See [Nocedal and Wright(2006)][1] for details on the unconstrained + version, and [Byrd et al.][2] for details on the constrained algorithm. ### Usage: - The following example demonstrates the L-BFGS optimizer attempting to find the - minimum for a simple high-dimensional quadratic objective function. + The following example demonstrates the L-BFGS-B optimizer attempting to find the + constrained minimum for a simple high-dimensional quadratic objective function. ```python - # A high-dimensional quadratic bowl. - ndims = 60 - minimum = np.ones([ndims], dtype='float64') - scales = np.arange(ndims, dtype='float64') + 1.0 - - # The objective function and the gradient. - def quadratic_loss_and_gradient(x): - return tfp.math.value_and_gradient( - lambda x: tf.reduce_sum( - scales * tf.math.squared_difference(x, minimum), axis=-1), - x) - start = np.arange(ndims, 0, -1, dtype='float64') - optim_results = tfp.optimizer.lbfgs_minimize( - quadratic_loss_and_gradient, - initial_position=start, - num_correction_pairs=10, - tolerance=1e-8) - - # Check that the search converged - assert(optim_results.converged) - # Check that the argmin is close to the actual value. - np.testing.assert_allclose(optim_results.position, minimum) + ndims = 60 + minimum = tf.convert_to_tensor( + np.ones([ndims]), dtype=tf.float32) + lower_bounds = tf.convert_to_tensor( + np.arange(ndims), dtype=tf.float32) + upper_bounds = tf.convert_to_tensor( + np.arange(100, 100-ndims, -1), dtype=tf.float32) + scales = tf.convert_to_tensor( + (np.random.rand(ndims) + 1.)*5. + 1., dtype=tf.float32) + start = tf.constant(np.random.rand(2, ndims)*100, dtype=tf.float32) + + # The objective function and the gradient. + def quadratic_loss_and_gradient(x): + return tfp.math.value_and_gradient( + lambda x: tf.reduce_sum( + scales * tf.math.squared_difference(x, minimum), axis=-1), + x) + opt_results = tfp.optimizer.lbfgsb_minimize( + quadratic_loss_and_gradient, + initial_position=start, + num_correction_pairs=10, + tolerance=1e-10, + bounds=[lower_bounds, upper_bounds]) ``` ### References: [1] Jorge Nocedal, Stephen Wright. Numerical Optimization. Springer Series - in Operations Research. pp 176-180. 2006 + in Operations Research. pp 176-180. 2006 http://pages.mtu.edu/~struther/Courses/OLD/Sp2013/5630/Jorge_Nocedal_Numerical_optimization_267490.pdf + [2] Richard H. Byrd, Peihuang Lu, Jorge Nocedal, & Ciyou Zhu (1995). + A Limited Memory Algorithm for Bound Constrained Optimization + SIAM Journal on Scientific Computing, 16(5), 1190–1208. + + https://doi.org/10.1137/0916069 + + [3] Jose Luis Morales, Jorge Nocedal (2011). + "Remark On Algorithm 788: L-BFGS-B: Fortran Subroutines for Large-Scale + Bound Constrained Optimization" + ACM Trans. Math. Softw. 38, 1, Article 7. + + https://dl.acm.org/doi/abs/10.1145/2049662.2049669 + Args: value_and_gradients_function: A Python callable that accepts a point as a - real `Tensor` and returns a tuple of `Tensor`s of real dtype containing - the value of the function and its gradient at that point. 
The function
-      to be minimized. The input is of shape `[..., n]`, where `n` is the size
-      of the domain of input points, and all others are batching dimensions.
-      The first component of the return value is a real `Tensor` of matching
-      shape `[...]`. The second component (the gradient) is also of shape
-      `[..., n]` like the input value to the function.
+      real `Tensor` and reporting arguments, and returns a tuple of `Tensor`s of
+      real dtype containing the value of the function and its gradient at that
+      point. The function to be minimized. The input is of shape `[..., n]`,
+      where `n` is the size of the domain of input points, and all others are
+      batching dimensions. The first component of the return value is a real
+      `Tensor` of matching shape `[...]`. The second component (the gradient) is
+      also of shape `[..., n]` like the input value to the function.
+      The reporting arguments consist of a Boolean `Tensor` of shape `[...]`
+      denoting which batches have terminated, and two real `Tensor`s of shape
+      `[..., n]`, denoting the last evaluated objective values and gradients
+      (respectively).
     initial_position: Real `Tensor` of shape `[..., n]`. The starting point, or
-      points when using batching dimensions, of the search procedure. At these
-      points the function value and the gradient norm should be finite.
-      Exactly one of `initial_position` and `previous_optimizer_results` can be
-      non-None.
-    previous_optimizer_results: An `LBfgsOptimizerResults` namedtuple to
-      intialize the optimizer state from, instead of an `initial_position`.
-      This can be passed in from a previous return value to resume optimization
-      with a different `stopping_condition`. Exactly one of `initial_position`
-      and `previous_optimizer_results` can be non-None.
+      points when using batching dimensions, of the search procedure. At these
+      points the function value and the gradient norm should be finite.
+      Exactly one of `initial_position` and `previous_optimizer_results` can be
+      non-None.
+    bounds: Tuple of two real `Tensor`s of shape `[..., n]`. The first element
+      gives the lower bounds of the constrained optimization, and the second
+      element gives the upper bounds. If `bounds` is `None`, the optimization
+      is unconstrained, as in `lbfgs_minimize`. If one of the elements of the
+      tuple is `None`, the optimization is assumed to be unconstrained from
+      above/below, respectively.
+    previous_optimizer_results: An `LBfgsBOptimizerResults` namedtuple to
+      initialize the optimizer state from, instead of an `initial_position`.
+      This can be passed in from a previous return value to resume optimization
+      with a different `stopping_condition`. Exactly one of `initial_position`
+      and `previous_optimizer_results` can be non-None.
     num_correction_pairs: Positive integer. Specifies the maximum number of
-      (position_delta, gradient_delta) correction pairs to keep as implicit
-      approximation of the Hessian matrix.
+        (position_delta, gradient_delta) correction pairs to keep as implicit
+        approximation of the Hessian matrix.
     tolerance: Scalar `Tensor` of real dtype. Specifies the gradient tolerance
-      for the procedure. If the supremum norm of the gradient vector is below
-      this number, the algorithm is stopped.
+        for the procedure. If the supremum norm of the gradient vector is below
+        this number, the algorithm is stopped.
x_tolerance: Scalar `Tensor` of real dtype. If the absolute change in the - position between one iteration and the next is smaller than this number, - the algorithm is stopped. + position between one iteration and the next is smaller than this number, + the algorithm is stopped. f_relative_tolerance: Scalar `Tensor` of real dtype. If the relative change - in the objective value between one iteration and the next is smaller - than this value, the algorithm is stopped. + in the objective value between one iteration and the next is smaller + than this value, referenced to the current objective value, the previous + objective value, or `1`, whichever is greatest, the algorithm is stopped. initial_inverse_hessian_estimate: None. Option currently not supported. max_iterations: Scalar positive int32 `Tensor`. The maximum number of - iterations for L-BFGS updates. + iterations for L-BFGS updates. parallel_iterations: Positive integer. The number of iterations allowed to - run in parallel. + run in parallel. stopping_condition: (Optional) A Python function that takes as input two - Boolean tensors of shape `[...]`, and returns a Boolean scalar tensor. - The input tensors are `converged` and `failed`, indicating the current - status of each respective batch member; the return value states whether - the algorithm should stop. The default is tfp.optimizer.converged_all - which only stops when all batch members have either converged or failed. - An alternative is tfp.optimizer.converged_any which stops as soon as one - batch member has converged, or when all have failed. + Boolean tensors of shape `[...]`, and returns a Boolean scalar tensor. + The input tensors are `converged` and `failed`, indicating the current + status of each respective batch member; the return value states whether + the algorithm should stop. The default is tfp.optimizer.converged_all + which only stops when all batch members have either converged or failed. + An alternative is tfp.optimizer.converged_any which stops as soon as one + batch member has converged, or when all have failed. max_line_search_iterations: Python int. The maximum number of iterations - for the `hager_zhang` line search algorithm. + for the `hager_zhang` line search algorithm. name: (Optional) Python str. The name prefixed to the ops created by this - function. If not supplied, the default name 'minimize' is used. + function. If not supplied, the default name 'minimize' is used. Returns: optimizer_results: A namedtuple containing the following items: - converged: Scalar boolean tensor indicating whether the minimum was - found within tolerance. - failed: Scalar boolean tensor indicating whether a line search - step failed to find a suitable step size satisfying Wolfe - conditions. In the absence of any constraints on the - number of objective evaluations permitted, this value will - be the complement of `converged`. However, if there is - a constraint and the search stopped due to available - evaluations being exhausted, both `failed` and `converged` - will be simultaneously False. - num_objective_evaluations: The total number of objective - evaluations performed. - position: A tensor containing the last argument value found - during the search. If the search converged, then - this value is the argmin of the objective function. - objective_value: A tensor containing the value of the objective - function at the `position`. If the search converged, then this is - the (local) minimum of the objective function. 
- objective_gradient: A tensor containing the gradient of the objective - function at the `position`. If the search converged the - max-norm of this tensor should be below the tolerance. - position_deltas: A tensor encoding information about the latest - changes in `position` during the algorithm execution. - gradient_deltas: A tensor encoding information about the latest - changes in `objective_gradient` during the algorithm execution. + converged: Scalar boolean tensor indicating whether the minimum was + found within tolerance. + failed: Scalar boolean tensor indicating whether a line search + step failed to find a suitable step size satisfying Wolfe + conditions. In the absence of any constraints on the + number of objective evaluations permitted, this value will + be the complement of `converged`. However, if there is + a constraint and the search stopped due to available + evaluations being exhausted, both `failed` and `converged` + will be simultaneously False. + num_objective_evaluations: The total number of objective + evaluations performed. + position: A tensor containing the last argument value found + during the search. If the search converged, then + this value is the argmin of the objective function. + objective_value: A tensor containing the value of the objective + function at the `position`. If the search converged, then this is + the (local) minimum of the objective function. + objective_gradient: A tensor containing the gradient of the objective + function at the `position`. If the search converged the + max-norm of this tensor should be below the tolerance. + position_deltas: A tensor encoding information about the latest + changes in `position` during the algorithm execution. + gradient_deltas: A tensor encoding information about the latest + changes in `objective_gradient` during the algorithm execution. 
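+
+  For reference, a callable matching the reporting signature described in
+  `value_and_gradients_function` can simply accept and ignore the extra
+  arguments. A minimal sketch, reusing `scales` and `minimum` from the usage
+  example above:
+
+  ```python
+  def quadratic_loss_and_gradient(x, finished, last_values, last_gradients):
+    # The reporting arguments are informational only: which batch members
+    # have terminated, and the last evaluated objective values and gradients.
+    del finished, last_values, last_gradients
+    return tfp.math.value_and_gradient(
+        lambda x: tf.reduce_sum(
+            scales * tf.math.squared_difference(x, minimum), axis=-1),
+        x)
+  ```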
""" + + if len(bounds) != 2: + raise ValueError( + '`bounds` parameter has unexpected number of elements ' + '(expected 2).') + + lower_bounds, upper_bounds = bounds + + # Defer further conversion of the bounds to appropriate tensors + # until the shape of the input is known + if initial_inverse_hessian_estimate is not None: raise NotImplementedError( - 'Support of initial_inverse_hessian_estimate arg not yet implemented') + 'Support of initial_inverse_hessian_estimate arg not yet implemented') if stopping_condition is None: stopping_condition = bfgs_utils.converged_all @@ -219,180 +297,1223 @@ def quadratic_loss_and_gradient(x): with tf.name_scope(name or 'minimize'): if (initial_position is None) == (previous_optimizer_results is None): raise ValueError( - 'Exactly one of `initial_position` or ' - '`previous_optimizer_results` may be specified.') + 'Exactly one of `initial_position` or ' + '`previous_optimizer_results` may be specified.') if initial_position is not None: initial_position = tf.convert_to_tensor( - initial_position, name='initial_position') + initial_position, name='initial_position') + # Force at least one batching dimension + if len(ps.shape(initial_position)) == 1: + initial_position = initial_position[tf.newaxis, :] + position_shape = ps.shape(initial_position) dtype = dtype_util.base_dtype(initial_position.dtype) if previous_optimizer_results is not None: - dtype = dtype_util.base_dtype(previous_optimizer_results.position.dtype) + position_shape = ps.shape(previous_optimizer_results.position) + dtype = dtype_util.base_dtype( + previous_optimizer_results.position.dtype) + + # TODO: This isn't agnostic to the number of batch dimensions, it only + # supports one batch dimension, but I've found RaggedTensors to be far + # too finicky/undocumented to handle multiple batch dimensions in any + # sane way. (Even the way it's working so far is less than ideal.) + if len(position_shape) > 2: + raise NotImplementedError( + "More than a batch dimension is not implemented. " + "Consider flattening and then reshaping the results.") + # NOTE: Broadcasting the batched dimensions breaks when there are no + # batched dimensions. Although this isn't handled like this in + # `lbfgs.py`, I'd rather force a batch dimension with a single + # element than do conditional checks later. + if len(position_shape) == 1: + position_shape = tf.concat([[1], position_shape], axis=0) + initial_position = tf.broadcast_to( + initial_position, position_shape) + + # NOTE: Could maybe use bfgs_utils._broadcast here, but would have to check + # that the non-batching dimensions also match; using `tf.broadcast_to` has + # the advantage that passing a (1,)-shaped tensor as bounds will correctly + # bound every variable at the single value. + if lower_bounds is None: + lower_bounds = tf.constant( + [-float('inf')], shape=position_shape, dtype=dtype, name='lower_bounds') + else: + lower_bounds = tf.cast( + tf.convert_to_tensor(lower_bounds), dtype=dtype) + try: + lower_bounds = tf.broadcast_to( + lower_bounds, position_shape, name='lower_bounds') + except tf.errors.InvalidArgumentError: + raise ValueError( + 'Failed to broadcast lower bounds tensor to the shape of starting ' + 'position. 
Are the lower bounds well formed?') + if upper_bounds is None: + upper_bounds = tf.constant( + [float('inf')], shape=position_shape, dtype=dtype, name='upper_bounds') + else: + upper_bounds = tf.cast( + tf.convert_to_tensor(upper_bounds), dtype=dtype) + try: + upper_bounds = tf.broadcast_to( + upper_bounds, position_shape, name='upper_bounds') + except tf.errors.InvalidArgumentError: + raise ValueError( + 'Failed to broadcast upper bounds tensor to the shape of starting ' + 'position. Are the lower bounds well formed?') + + # Clamp the starting position to the bounds, because the algorithm expects + # the variables to be in range for the Hessian inverse estimation, but also + # because that fast-tracks the first iteration of the Cauchy optimization. + initial_position = tf.clip_by_value( + initial_position, lower_bounds, upper_bounds) tolerance = tf.convert_to_tensor( - tolerance, dtype=dtype, name='grad_tolerance') + tolerance, dtype=dtype, name='grad_tolerance') f_relative_tolerance = tf.convert_to_tensor( - f_relative_tolerance, dtype=dtype, name='f_relative_tolerance') + f_relative_tolerance, dtype=dtype, name='f_relative_tolerance') x_tolerance = tf.convert_to_tensor( - x_tolerance, dtype=dtype, name='x_tolerance') - max_iterations = tf.convert_to_tensor(max_iterations, name='max_iterations') + x_tolerance, dtype=dtype, name='x_tolerance') + max_iterations = tf.convert_to_tensor( + max_iterations, name='max_iterations') - # The `state` here is a `LBfgsOptimizerResults` tuple with values for the + # The `state` here is a `LBfgsBOptimizerResults` tuple with values for the # current state of the algorithm computation. def _cond(state): """Continue if iterations remain and stopping condition is not met.""" return ((state.num_iterations < max_iterations) & - tf.logical_not(stopping_condition(state.converged, state.failed))) + tf.logical_not(stopping_condition(state.converged, state.failed))) def _body(current_state): """Main optimization loop.""" current_state = bfgs_utils.terminate_if_not_finite(current_state) - search_direction = _get_search_direction(current_state) + cauchy_state, current_state = _cauchy_minimization( + current_state, num_correction_pairs, parallel_iterations) - # TODO(b/120134934): Check if the derivative at the start point is not - # negative, if so then reset position/gradient deltas and recompute - # search direction. + search_direction, current_state, clip_before, refreshed = ( + _find_search_direction( + current_state, cauchy_state, num_correction_pairs)) - next_state = bfgs_utils.line_search_step( - current_state, - value_and_gradients_function, search_direction, - tolerance, f_relative_tolerance, x_tolerance, stopping_condition, - max_line_search_iterations) + # If any batch needs a refresh, restart the whole thing, to reduce number + # of function evaluations - # If not failed or converged, update the Hessian estimate. - should_update = ~(next_state.converged | next_state.failed) - state_after_inv_hessian_update = bfgs_utils.update_fields( + def _continue_minimization(): + """Proceeds with minimization iteration.""" + next_state = _constrained_line_search_step( + current_state, value_and_gradients_function, search_direction, + tolerance, f_relative_tolerance, x_tolerance, stopping_condition, + max_line_search_iterations, clip_before) + + # If not failed or converged, update the Hessian estimate. 
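+      # (i.e. push the newest (position_delta, gradient_delta) correction
+      # pair into the limited-memory queues below).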
+ # Only do this if the new pairs obey the s.y > eps.||y|| + position_delta = (next_state.position - current_state.position) + gradient_delta = (next_state.objective_gradient - + current_state.objective_gradient) + # Article is ambiguous; see lbfgs.f:863 + curvature_cond = ( + tf.reduce_sum(position_delta * gradient_delta, axis=-1) >= + bfgs_utils.norm(current_state.objective_gradient, dims=1) * + dtype_util.eps(position_delta.dtype)) + should_push = (~(next_state.converged | next_state.failed) & + curvature_cond & ~refreshed) + # TODO: Track number of skipped pairs + new_position_deltas = _queue_push( + next_state.position_deltas, should_push, position_delta) + new_gradient_deltas = _queue_push( + next_state.gradient_deltas, should_push, gradient_delta) + new_history = tf.where( + should_push, + tf.math.minimum(next_state.history + 1, + num_correction_pairs), + next_state.history) + + if not tf.executing_eagerly(): + # Hint the compiler that the shape of the properties has not changed + new_position_deltas = tf.ensure_shape( + new_position_deltas, next_state.position_deltas.shape) + new_gradient_deltas = tf.ensure_shape( + new_gradient_deltas, next_state.gradient_deltas.shape) + new_history = tf.ensure_shape( + new_history, next_state.history.shape) + + next_state = bfgs_utils.update_fields( next_state, - position_deltas=_queue_push( - current_state.position_deltas, should_update, - next_state.position - current_state.position), - gradient_deltas=_queue_push( - current_state.gradient_deltas, should_update, - next_state.objective_gradient - current_state.objective_gradient)) - return [state_after_inv_hessian_update] + position_deltas=new_position_deltas, + gradient_deltas=new_gradient_deltas, + history=new_history) + + return [next_state] + + return tf.cond( + pred=tf.reduce_any(refreshed), + true_fn=lambda: [current_state], + false_fn=_continue_minimization) if previous_optimizer_results is None: assert initial_position is not None initial_state = _get_initial_state(value_and_gradients_function, - initial_position, - num_correction_pairs, - tolerance) + initial_position, + lower_bounds, + upper_bounds, + num_correction_pairs, + tolerance) else: initial_state = previous_optimizer_results return tf.while_loop( - cond=_cond, - body=_body, - loop_vars=[initial_state], - parallel_iterations=parallel_iterations)[0] - - -def _get_initial_state(value_and_gradients_function, - initial_position, - num_correction_pairs, - tolerance): - """Create LBfgsOptimizerResults with initial state of search procedure.""" - init_args = bfgs_utils.get_initial_state_args( - value_and_gradients_function, - initial_position, - tolerance) - empty_queue = _make_empty_queue_for(num_correction_pairs, initial_position) - init_args.update(position_deltas=empty_queue, gradient_deltas=empty_queue) - return LBfgsOptimizerResults(**init_args) + cond=_cond, + body=_body, + loop_vars=[initial_state], + parallel_iterations=parallel_iterations)[0] -def _get_search_direction(state): - """Computes the search direction to follow at the current state. +def _cauchy_minimization(bfgs_state, num_correction_pairs, parallel_iterations): + """Calculates the Cauchy point, bounding the gradient by the bounds. - On the `k`-th iteration of the main L-BFGS algorithm, the state has collected - the most recent `m` correction pairs in position_deltas and gradient_deltas, - where `k = state.num_iterations` and `m = min(k, num_correction_pairs)`. 
+ This function minimizes the quadratic approximation to the objective + function at the current position, in the direction of steepest descent, + but bounding the gradient by the corresponding bounds. - Assuming these, the code below is an implementation of the L-BFGS two-loop - recursion algorithm given by [Nocedal and Wright(2006)][1]: + See algorithm CP and associated discussion of [Byrd,Lu,Nocedal,Zhu][2] + for details. - ```None - q_direction = objective_gradient - for i in reversed(range(m)): # First loop. - inv_rho[i] = gradient_deltas[i]^T * position_deltas[i] - alpha[i] = position_deltas[i]^T * q_direction / inv_rho[i] - q_direction = q_direction - alpha[i] * gradient_deltas[i] + This function may modify the given `bfgs_state`, in that it refreshes the + memory for batches that are found to be in an invalid state. - kth_inv_hessian_factor = (gradient_deltas[-1]^T * position_deltas[-1] / - gradient_deltas[-1]^T * gradient_deltas[-1]) - r_direction = kth_inv_hessian_factor * I * q_direction + Args: + bfgs_state: current `LBfgsBOptimizerResults` state + num_correction_pairs: the (maximum) number of past steps to keep as + history for the LBFGS algorithm + parallel_iterations: argument of `tf.while` loops + Returns: + A `_CauchyMinimizationResult` containing the results of the Cauchy point + computation. + Updated `bfgs_state` + """ + cauchy_state, bfgs_state = _get_initial_cauchy_state( + bfgs_state, num_correction_pairs) + n = ps.shape(bfgs_state.position)[-1] + idx_range = tf.range(ps.shape(bfgs_state.position)[-1])[tf.newaxis, ...] + # NOTE: See lbfgsb.f (l. 1524) + ddf_org = -cauchy_state.theta * cauchy_state.df + + def _cond(state): + """Test convergence to Cauchy point at current branch""" + return tf.reduce_any(state.active) + + def _body(state): + """Cauchy point iterative loop (While loop of CP algorithm [2])""" + # Because of `where` statements, the indices for gathering must always + # be valid, even if the result is not used afterwards. For batches that + # are no longer active, the `next_free_idx` (which points to the index + # of the current minimum breakpoint via `breakpoints_argsort`) may + # exceed the size of `breakpoints_argsort` (if the batch isn't active + # because there are no free variables left). So, instead, we take 0 as a + # dummy value, which will later be discarded by the `where` statements. 
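+      # `breakpoint_min` below is the smallest breakpoint not yet examined
+      # (i.e. the next variable to reach its bound along the projected
+      # steepest descent path), found by indexing `breakpoints_argsort`
+      # with `next_free_idx`.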
+ next_free_idx = tf.where( + state.active, + state.next_free_idx, + 0) + breakpoint_min_idx = tf.where( + state.active, + tf.gather( + state.breakpoints_argsort, + next_free_idx, + batch_dims=1), + 0) + breakpoint_min = tf.where( + state.active, + tf.gather( + state.breakpoints, + breakpoint_min_idx, + batch_dims=1), + state.breakpoint_min_old) + + dt = (breakpoint_min - state.breakpoint_min_old) + + # NOTE: We immediately update active to simulate an early return + # This value should be used below (instead of `state.active`) + active = (state.active & (state.dt_min >= dt)) + + # Set the considered variable as fixed + tsum = tf.where(active, state.tsum + dt, state.tsum) + breakpoint_min_idx_mask = ( + idx_range == breakpoint_min_idx[..., tf.newaxis]) + steepest = tf.where( + active[..., tf.newaxis], + tf.where( + breakpoint_min_idx_mask, + 0., + state.steepest), + state.steepest) + free_mask = tf.where( + active[..., tf.newaxis], + (state.free_mask & ~breakpoint_min_idx_mask), + state.free_mask) + d_b = tf.gather( + state.steepest, + breakpoint_min_idx, + batch_dims=1) + x_cp_b = tf.gather( + tf.where( + (d_b > 0.)[..., tf.newaxis], + bfgs_state.upper_bounds, + tf.where( + (d_b < 0.)[..., tf.newaxis], + bfgs_state.lower_bounds, + state.cauchy_point + )), + breakpoint_min_idx, + batch_dims=1) + cauchy_point = tf.where( + active[..., tf.newaxis], + tf.where( + breakpoint_min_idx_mask, + x_cp_b[..., tf.newaxis], + state.cauchy_point), + state.cauchy_point) + + # If we're out of free variables, set dt_min to dt and "return" + next_free_idx = tf.where(active, next_free_idx + 1, next_free_idx) + no_more_free = (next_free_idx >= n) + dt_min = tf.where(no_more_free, dt, state.dt_min) + active &= ~no_more_free + + # Update remaining properties + # - Update `c` + c = tf.where( + active[..., tf.newaxis], + state.c + dt[..., tf.newaxis] * state.p, + state.c) + # - Get the `b`th row of W (needed for f', f'') + # The matrix M has shape + # + # [[ 0 0 ] + # [ 0 M_h ]] + # + # where M_h is the M matrix considering the current history `h`. + # Therefore, for W, we should consider that the last `h` columns + # are + # Y[k-h,...,k-1] theta*S[k-h,...k-1] + # (so that the first `2*(m-h)` columns are 0. + # 1. Create the "full" W matrix row + w_b = tf.concat( + [tf.gather( + bfgs_state.gradient_deltas, + breakpoint_min_idx, + axis=-1, + batch_dims=1), + (state.theta[..., tf.newaxis] * + tf.gather( + bfgs_state.position_deltas, + breakpoint_min_idx, + axis=-1, + batch_dims=1)) + ], + axis=-1) + # 2. "Permute" the relevant items to the right + idx = tf.concat( + [ + tf.ragged.range( + num_correction_pairs - bfgs_state.history), + tf.ragged.range( + num_correction_pairs, + 2*num_correction_pairs - bfgs_state.history), + tf.ragged.range( + num_correction_pairs - bfgs_state.history, + num_correction_pairs), + tf.ragged.range( + 2*num_correction_pairs - bfgs_state.history, + 2*num_correction_pairs) + ], + axis=-1).to_tensor() + w_b = tf.gather( + w_b, + idx, + batch_dims=1) + + # - Update f' + x_b = tf.gather( + bfgs_state.position, + breakpoint_min_idx, + batch_dims=1) + # NOTE Use of d_b = -g_b + df = tf.where( + active, + (state.df + dt * state.ddf + + d_b**2 - + state.theta * d_b * (x_cp_b - x_b) + + d_b * tf.einsum( + '...j,...jk,...k->...', + w_b, + state.m, + c)), + state.df) + + # - Update f'' + # NOTE use of d_b = -g_b + ddf = tf.where( + active, + (state.ddf - state.theta * d_b**2 + + 2. 
* d_b * tf.einsum( + "...i,...ij,...j->...", + w_b, + state.m, + state.p) - + d_b**2 * tf.einsum( + "...i,...ij,...j->...", + w_b, + state.m, + w_b)), + state.ddf) + # NOTE: See lbfgsb.f (l. 1649) + ddf = tf.where( + active, + tf.math.maximum(ddf, dtype_util.eps(ddf.dtype)*ddf_org), + state.ddf) + + # - Update p + # NOTE use of d_b = -g_b + p = tf.where( + active[..., tf.newaxis], + state.p - d_b[..., tf.newaxis] * w_b, + state.p) + + # - Update dt_min + dt_min = tf.where( + active, -tf.math.divide_no_nan(df, ddf), state.dt_min) + + # Create the updated state + + # We need to hint the compiler that nothing changed shapes + if not tf.executing_eagerly(): + steepest = tf.ensure_shape(steepest, state.steepest.shape) + p = tf.ensure_shape(p, state.p.shape) + c = tf.ensure_shape(c, state.c.shape) + df = tf.ensure_shape(df, state.df.shape) + ddf = tf.ensure_shape(ddf, state.ddf.shape) + dt = tf.ensure_shape(dt, state.dt.shape) + dt_min = tf.ensure_shape(dt_min, state.dt_min.shape) + tsum = tf.ensure_shape(tsum, state.tsum.shape) + breakpoint_min = tf.ensure_shape( + breakpoint_min, state.breakpoint_min_old.shape) + next_free_idx = tf.ensure_shape( + next_free_idx, state.next_free_idx.shape) + cauchy_point = tf.ensure_shape( + cauchy_point, state.cauchy_point.shape) + free_mask = tf.ensure_shape(free_mask, state.free_mask.shape) + active = tf.ensure_shape(active, state.active.shape) + + new_state = bfgs_utils.update_fields( + state, steepest=steepest, p=p, c=c, df=df, ddf=ddf, dt=dt, + dt_min=dt_min, tsum=tsum, breakpoint_min_old=breakpoint_min, + next_free_idx=next_free_idx, cauchy_point=cauchy_point, + free_mask=free_mask, active=active) + + return [new_state] + + cauchy_loop = tf.while_loop( + cond=_cond, + body=_body, + loop_vars=[cauchy_state], + parallel_iterations=parallel_iterations)[0] + + # NOTE: See lbfgs.f lines 1584, 1606, 1667, 1682 + free_remaining = (cauchy_loop.next_free_idx < n) + dt_min = tf.where( + free_remaining, + tf.math.maximum(cauchy_loop.dt_min, 0), + cauchy_loop.dt_min) + tsum = tf.where( + free_remaining, + cauchy_loop.tsum + dt_min, + cauchy_loop.tsum) + + cauchy_point = tf.where( + (bfgs_state.converged | bfgs_state.failed)[..., tf.newaxis], + bfgs_state.position, + tf.where( + free_remaining[..., tf.newaxis], + cauchy_loop.cauchy_point + + tsum[..., tf.newaxis] * cauchy_loop.steepest, + cauchy_loop.cauchy_point)) + + c = cauchy_loop.c + dt_min[..., tf.newaxis]*cauchy_loop.p + # NOTE: `c` is already permuted to match the subspace of `M`, because `w_b` + # was already permuted. + # You can explicitly check this by comparing its value with W'.(x^c - x) + # at this point. + + # Set points where gradient is 0 as fixed + # TODO: Does this cause problems with sadle points? 
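+  # (A zero gradient component produces an infinite breakpoint and a zero
+  # entry in `steepest`, so such a variable cannot move in the Cauchy step;
+  # marking it as fixed simply excludes it from the subspace minimization.)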
+ free_mask = (cauchy_loop.free_mask & (bfgs_state.objective_gradient != 0)) + + # Hint the compiler that shape of things will not change + if not tf.executing_eagerly(): + dt_min = tf.ensure_shape(dt_min, cauchy_loop.dt_min.shape) + tsum = tf.ensure_shape(tsum, cauchy_loop.tsum.shape) + cauchy_point = tf.ensure_shape( + cauchy_point, cauchy_loop.cauchy_point.shape) + c = tf.ensure_shape(c, cauchy_loop.c.shape) + free_mask = tf.ensure_shape(free_mask, cauchy_loop.free_mask.shape) + # Do the actual updating + final_cauchy_state = bfgs_utils.update_fields( + cauchy_loop, dt_min=dt_min, tsum=tsum, cauchy_point=cauchy_point, c=c, + free_mask=free_mask) + + return final_cauchy_state, bfgs_state + + +def _get_initial_cauchy_state(bfgs_state, num_correction_pairs): + """Create `_ConstrainedCauchyState` with initial parameters. + + This will calculate the elements of `_ConstrainedCauchyState` based on the + given `LBfgsBOptimizerResults` state object. Some of these properties may be + incalculable, for which batches the state will be reset. - for i in range(m): # Second loop. - beta = gradient_deltas[i]^T * r_direction / inv_rho[i] - r_direction = r_direction + position_deltas[i] * (alpha[i] - beta) + Args: + bfgs_state: `LBfgsBOptimizerResults` object representing the current state + of the LBFGSB optimization + num_correction_pairs: typically `m`; the (maximum) number of past steps to + keep as history for the LBFGS algorithm - return -r_direction # Approximates - H_k * objective_gradient. - ``` + Returns: + Initialized `_ConstrainedCauchyState` + Updated `bfgs_state` + """ + cauchy_point = bfgs_state.position + + theta = tf.math.divide_no_nan( + tf.reduce_sum(bfgs_state.gradient_deltas[..., -1, :]**2, axis=-1), + (tf.reduce_sum(bfgs_state.gradient_deltas[..., -1, :] * + bfgs_state.position_deltas[..., -1, :], axis=-1))) + theta = tf.where(bfgs_state.history == 0, 1., theta) + + m, refresh = _cauchy_init_m(bfgs_state, theta, num_correction_pairs) + + # Erase the history where M isn't invertible + bfgs_state = _erase_history(bfgs_state, refresh) + theta = tf.where(refresh, 1., theta) + + breakpoints = _cauchy_init_breakpoints(bfgs_state) + breakpoints_argsort = tf.argsort(breakpoints) + + steepest = tf.where((breakpoints > 0.), -bfgs_state.objective_gradient, 0.) + + # We need to account for the varying histories: + # we assume that the first `2*(m-h)` rows of W'^T + # are 0 (where `m` is the number of correction pairs + # and `h` is the history), in concordance with the first + # `2*(m-h)` rows of M being 0. + # 1. Calculate all elements + p = tf.concat( + [ + tf.einsum( + "...mi,...i->...m", + bfgs_state.gradient_deltas, + steepest), + (theta[..., tf.newaxis] * + tf.einsum( + "...mi,...i->...m", + bfgs_state.position_deltas, + steepest)) + ], + axis=-1) + # 2. Assemble the rows in the correct order + idx = tf.concat( + [ + tf.ragged.range( + num_correction_pairs - bfgs_state.history), + tf.ragged.range( + num_correction_pairs, + 2*num_correction_pairs - bfgs_state.history), + tf.ragged.range( + num_correction_pairs - bfgs_state.history, + num_correction_pairs), + tf.ragged.range( + 2*num_correction_pairs - bfgs_state.history, + 2*num_correction_pairs) + ], + axis=-1).to_tensor() + p = tf.gather( + p, + idx, + batch_dims=1) + + c = tf.zeros_like(p) + df = -tf.reduce_sum(steepest**2, axis=-1) + ddf = -theta*df - tf.einsum("...i,...ij,...j->...", p, m, p) + dt_min = -tf.math.divide_no_nan(df, ddf) + tsum = tf.zeros_like(dt_min) + + # NOTE: These are placeholder values. 
+ # All of these have shape [batch], which matches dt_min + dt = tf.zeros_like(dt_min) + breakpoint_min_old = tf.zeros_like(dt_min) + + next_free_idx = tf.reduce_sum(tf.where(breakpoints <= 0., 1, 0), axis=-1) + free_mask = (breakpoints > 0.) + + # NOTE: _cauchy_update_active should NOT be accounted for here; the first + # iteration should always run (if the batch is overall active) + active = ~(bfgs_state.converged | bfgs_state.failed) + + cauchy_state = _ConstrainedCauchyState( + theta=theta, m=m, breakpoints=breakpoints, + breakpoints_argsort=breakpoints_argsort, next_free_idx=next_free_idx, + steepest=steepest, p=p, c=c, df=df, ddf=ddf, dt=dt, dt_min=dt_min, + tsum=tsum, breakpoint_min_old=breakpoint_min_old, cauchy_point=cauchy_point, + active=active, free_mask=free_mask) + + return cauchy_state, bfgs_state + + +def _cauchy_init_breakpoints(state): + """Calculate the breakpoints for a `_CauchyMinimizationResult` state.""" + breakpoints = ( + tf.where( + state.objective_gradient < 0, + tf.math.divide_no_nan( + state.position - state.upper_bounds, + state.objective_gradient), + tf.where( + state.objective_gradient > 0, + tf.math.divide_no_nan( + state.position - state.lower_bounds, + state.objective_gradient), + float('inf'))) + ) + + return breakpoints + + +def _find_search_direction(bfgs_state, cauchy_state, num_correction_pairs): + """Finds the search direction based on the direct primal method. + + This function corresponds to points 1-6 of the Direct Primal Method presented + in [2, p. 1199] for subspace minimization, with the first modification + suggested in [3]. + + If an invalid condition is reached for a given batch, its history is reset. + Therefore, this function also returns an updated `bfgs_state`. Args: - state: A `LBfgsOptimizerResults` tuple with the current state of the - search procedure. + bfgs_state: the `LBfgsBOptimizerResults` object representing the current + iteration. + cauchy_state: the `_CauchyMinimizationResult` results of a cauchy search + computation. Typically the output of `_cauchy_minimization`. + num_correction_pairs: The (maximum) number of correction pairs stored in + memory (`m`) + Returns: + Tensor of batched search directions, + Updated `bfgs_state`, + Tensor of Boolean dtype indicating whether the search direction should be + clamped to bounds before the search is performed, + Tensor of Boolean dtype indicating what batches have been refreshed. + """ + def _find_constrained_minimizer(): + """Performs free subspace minimization based on the Direct Method.""" + # Let the reduced gradient be [2, eq. 5.4] + # + # ρ = Z'r + # r = g + Θ(x^c - x) + (1/Θ).W.M.c + # + # and the search direction [2, eq. 5.7] + # + # d = -B⁻¹ρ + # + # and [2, eq. 5.10] + # + # B⁻¹ = 1/Θ [ I + 1/Θ Z'.W.N⁻¹.M.W'.Z ] + # N = I - 1/Θ M.W'.Z.Z'.W + # + # Therefore, + # + # d = Z' . (-1/Θ) . [ r + 1/Θ W.N⁻¹.M.W'.Z.Z'.r ] + # + # NOTE that the leading sign does not match that of [2, eq. 5.11]. This is + # because the article conflates the definition of r in [2, eq. 5.4] and the + # definition employed in the Fortran implementation, where + # + # r = -Z'B(x^c - x) - Z'g + # + # From which follows + # + # d = Z' (1/Θ) . 
[ r + 1/Θ W.N⁻¹.M.W'.Z.Z'.r ] + idx = ( + tf.concat([ + tf.ragged.range( + num_correction_pairs - bfgs_state.history), + tf.ragged.range( + num_correction_pairs, + 2*num_correction_pairs - bfgs_state.history), + tf.ragged.range( + num_correction_pairs - bfgs_state.history, + num_correction_pairs), + tf.ragged.range( + 2*num_correction_pairs - bfgs_state.history, + 2*num_correction_pairs) + ], + axis=-1).to_tensor()) + + w_transpose = ( + tf.gather( + tf.concat( + [bfgs_state.gradient_deltas, + cauchy_state.theta[..., tf.newaxis, tf.newaxis] * + bfgs_state.position_deltas], + axis=-2), + idx, + batch_dims=1) + ) + + r = ( + cauchy_state.theta[..., tf.newaxis] * + (bfgs_state.position - cauchy_state.cauchy_point) + + tf.einsum( + '...ji,...jk,...k->...i', + w_transpose, + cauchy_state.m, + cauchy_state.c) - + bfgs_state.objective_gradient) + + n = ( + tf.eye( + num_rows=num_correction_pairs*2, + batch_shape=ps.shape(bfgs_state.position)[:-1]) - + (tf.einsum( + '...ij,...jk,...lk->...il', + cauchy_state.m, + w_transpose, + tf.where( + cauchy_state.free_mask[..., tf.newaxis, :], + w_transpose, + 0.) + ) / cauchy_state.theta[..., tf.newaxis, tf.newaxis])) + + # NOTE: No need to "mask" the no-history subspace of N: because of I - (...) + # we correctly get a block form. The extraneous identity block is then + # zeroed when the product with M is taken + refresh = (tf.linalg.det(n) == 0.) + + n = tf.linalg.inv( + tf.where( + refresh[..., tf.newaxis, tf.newaxis], + tf.eye( + num_rows=num_correction_pairs*2, + batch_shape=ps.shape(bfgs_state.position)[:-1]), + n)) + + n = tf.where( + refresh[..., tf.newaxis, tf.newaxis], + tf.zeros_like(n), + n) + + # d is composed in three parts + d = tf.einsum('...ji,...jk,...kl,...lm,...m->...i', + w_transpose, + n, + cauchy_state.m, + tf.where( + cauchy_state.free_mask[..., tf.newaxis, :], + w_transpose, + 0.), + r) + + d = r + d/cauchy_state.theta[..., tf.newaxis] + d = d/cauchy_state.theta[..., tf.newaxis] + + d = tf.where(cauchy_state.free_mask, d, 0.) + + # Per [3]: + # Project `(cauchy point) + d` into the bounds + # NOTE: `d` is zeroed for constrained variables, and `movement_clip` is + # at most 1. + minimizer = tf.clip_by_value( + cauchy_state.cauchy_point + d, + bfgs_state.lower_bounds, + bfgs_state.upper_bounds) + + # Per [3]: If the search direction obtained with this minimizer is not a + # direction of strong descent, allow the minimizer to be oob, and clip the + # direction (i.e. fall back to the original algorithm). The clipping is + # handled outside this fn. + fallback = (tf.reduce_sum((minimizer - bfgs_state.position) * + bfgs_state.objective_gradient, axis=-1) > 0) + + minimizer = tf.where( + fallback[..., tf.newaxis], + cauchy_state.cauchy_point + d, + minimizer) + + active = (tf.reduce_any(cauchy_state.free_mask, axis=-1) & + (bfgs_state.history > 0)) + minimizer = tf.where( + active[..., tf.newaxis], minimizer, cauchy_state.cauchy_point) + + return minimizer, refresh, fallback + + # NOTE: we're abusing `bfgs_state.history.shape` again to get the batch + # dimensions Also: the Cauchy point is a minimization along the (projected) + # minus gradient direction; this is why we can skip subspace minimization if + # there's no history (because the search direction would indeed have been the + # minus gradient), but should run it otherwise (to make use of the BFGS + # information). 
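+  # Note that `skip_subspace` is a single scalar predicate: the subspace step
+  # is skipped only when no batch member has any free variable, or when every
+  # batch member has an empty history.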
+ skip_subspace = ( + (~tf.reduce_any(cauchy_state.free_mask)) | + tf.reduce_all(bfgs_state.history == 0)) + minimizer, refresh, clip_before = ( + tf.cond( + pred=skip_subspace, + true_fn=lambda: (cauchy_state.cauchy_point, + tf.broadcast_to( + False, ps.shape(bfgs_state.history)), + tf.broadcast_to(True, ps.shape(bfgs_state.history))), + false_fn=_find_constrained_minimizer)) + + search_direction = (minimizer - bfgs_state.position) + + # Reset if the search direction still isn't a direction of strong descent + refresh |= ( + tf.reduce_sum( + search_direction * bfgs_state.objective_gradient, axis=-1) > 0) + + # Refresh conditions only make sense if a batch had not already converged + refresh &= ~ (bfgs_state.converged | bfgs_state.failed) + + # Apply refresh + bfgs_state = _erase_history(bfgs_state, refresh) + + return search_direction, bfgs_state, clip_before, refresh + + +def _constrained_line_search_step(bfgs_state, value_and_gradients_function, + search_direction, grad_tolerance, f_relative_tolerance, + x_tolerance, stopping_condition, max_iterations, clip_before): + """Performs a constrained line search clamped to bounds in given direction.""" + inactive = (bfgs_state.failed | bfgs_state.converged) + + def _do_line_search_step(): + """Do unconstrained line search.""" + nonlocal search_direction + # Truncation bounds + lower_term = tf.math.divide_no_nan( + bfgs_state.lower_bounds - bfgs_state.position, + search_direction) + upper_term = tf.math.divide_no_nan( + bfgs_state.upper_bounds - bfgs_state.position, + search_direction) + bounds_clip = ( + tf.reduce_min( + tf.where( + (search_direction > 0), + upper_term, + tf.where( + (search_direction < 0), + lower_term, + float('inf'))), + axis=-1) + ) + + search_direction *= tf.where( + clip_before, + tf.math.minimum(1., bounds_clip), + 1.)[..., tf.newaxis] + + def _fn_with_report(x): + return value_and_gradients_function( + x, inactive, bfgs_state.objective_value, bfgs_state.objective_gradient) + + ls_result = _hz_line_search( + bfgs_state.position, bfgs_state.objective_value, + bfgs_state.objective_gradient, + _fn_with_report, search_direction, + max_iterations, inactive) + + # Truncate to bounds after search + step = ( + tf.math.minimum( + bounds_clip, + ls_result.left.x + ) + ) + + # For inactive batch members `left.x` is zero. However, their + # `search_direction` might also be undefined, so we can't rely on + # multiplication by zero to produce a `position_delta` of zero. + next_position = tf.where( + inactive[..., tf.newaxis], + bfgs_state.position, + bfgs_state.position + step[..., tf.newaxis] * search_direction) + + # If the movement isn't clipped, we can use the final results of the + # line search. 
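+    # Otherwise the objective and gradient must be re-evaluated at the
+    # truncated position, at the cost of one extra function evaluation
+    # (accounted for in `num_objective_evaluations` below).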
+ reevaluated = (tf.reduce_any(ls_result.left.x > bounds_clip)) + next_objective, next_gradient = ( + tf.cond( + pred=reevaluated, + true_fn=lambda: value_and_gradients_function( + next_position, inactive, bfgs_state.objective_value, + bfgs_state.objective_gradient), + false_fn=lambda: (ls_result.left.f, + ls_result.left.full_gradient) + ) + ) + + new_failed = (bfgs_state.failed | ( + ~inactive & ~bfgs_state.converged & ~ls_result.converged)) + new_num_iterations = bfgs_state.num_iterations + 1 + new_num_objective_evaluations = tf.cond( + pred=reevaluated, + true_fn=lambda: ( + bfgs_state.num_objective_evaluations + ls_result.func_evals + 1), + false_fn=lambda: ( + bfgs_state.num_objective_evaluations + ls_result.func_evals)) + + # Hint the compiler that the properties' shape will not change + if not tf.executing_eagerly(): + new_failed = tf.ensure_shape(new_failed, bfgs_state.failed.shape) + new_num_iterations = tf.ensure_shape( + new_num_iterations, bfgs_state.num_iterations.shape) + new_num_objective_evaluations = tf.ensure_shape( + new_num_objective_evaluations, bfgs_state.num_objective_evaluations.shape) + + state_after_ls = bfgs_utils.update_fields( + state=bfgs_state, + failed=new_failed, + num_iterations=new_num_iterations, + num_objective_evaluations=new_num_objective_evaluations) + + return state_after_ls, next_position, next_objective, next_gradient + + # NOTE: It's important that the default (false `pred`) step matches + # the shape of true `pred` shape for graph purposes + state_after_ls, next_position, next_objective, next_gradient = ( + tf.cond( + pred=tf.math.logical_not(tf.reduce_all(inactive)), + true_fn=_do_line_search_step, + false_fn=lambda: (bfgs_state, + bfgs_state.position, + bfgs_state.objective_value, + bfgs_state.objective_gradient) + )) + + def _do_update_position(): + """Update the position""" + return _update_position( + state_after_ls, + next_position, + next_objective, + next_gradient, + grad_tolerance, + f_relative_tolerance, + x_tolerance, + inactive) + + return ps.cond( + (stopping_condition(bfgs_state.converged, bfgs_state.failed) | + tf.reduce_all(inactive)), + true_fn=lambda: state_after_ls, + false_fn=_do_update_position) + + +def _hz_line_search(starting_position, starting_value, starting_gradient, + value_and_gradients_function, search_direction, max_iterations, + inactive): + """Performs Hager Zhang line search via `bfgs_utils.linesearch.hager_zhang`.""" + line_search_value_grad_func = bfgs_utils._restrict_along_direction( + value_and_gradients_function, starting_position, search_direction) + derivative_at_start_pt = tf.reduce_sum( + starting_gradient * search_direction, axis=-1) + val_0 = bfgs_utils.ValueAndGradient( + x=bfgs_utils._broadcast(0, starting_position), + f=starting_value, + df=derivative_at_start_pt, + full_gradient=starting_gradient) + return bfgs_utils.linesearch.hager_zhang( + line_search_value_grad_func, + initial_step_size=bfgs_utils._broadcast(1, starting_position), + value_at_zero=val_0, + converged=inactive, + max_iterations=max_iterations) + + +def _update_position(state, + next_position, + next_objective, + next_gradient, + grad_tolerance, + f_relative_tolerance, + x_tolerance, + inactive): + """Updates the state advancing its position by a given position_delta.""" + state = bfgs_utils.terminate_if_not_finite( + state, next_objective, next_gradient) + + converged = (~inactive & ~state.failed & + _check_convergence_bounded(state.position, + next_position, + state.objective_value, + next_objective, + next_gradient, + 
grad_tolerance, + f_relative_tolerance, + x_tolerance, + state.lower_bounds, + state.upper_bounds)) + new_converged = (state.converged | converged) + + if not tf.executing_eagerly(): + # Hint the compiler that the properties have not changed shape + new_converged = tf.ensure_shape(new_converged, state.converged.shape) + next_position = tf.ensure_shape(next_position, state.position.shape) + next_objective = tf.ensure_shape( + next_objective, state.objective_value.shape) + next_gradient = tf.ensure_shape( + next_gradient, state.objective_gradient.shape) + + return bfgs_utils.update_fields( + state, + converged=new_converged, + position=next_position, + objective_value=next_objective, + objective_gradient=next_gradient) + + +def _erase_history(bfgs_state, where_erase): + """Erases the BFGS correction pairs for the specified batches. + + This function will zero `gradient_deltas`, `position_deltas`, and `history`. + Args: + `bfgs_state`: a `LBfgsBOptimizerResults` to modify + `where_erase`: a Boolean tensor with shape matching the batch dimensions + with `True` for the batches to erase the history of. Returns: - A real `Tensor` of the same shape as the `state.position`. The direction - along which to perform line search. + Modified `bfgs_state`. """ - # The number of correction pairs that have been collected so far. - num_elements = ps.minimum( - state.num_iterations, # TODO(b/162733947): Change loop state -> closure. - ps.shape(state.position_deltas)[0]) - - def _two_loop_algorithm(): - """L-BFGS two-loop algorithm.""" - # Correction pairs are always appended to the end, so only the latest - # `num_elements` vectors have valid position/gradient deltas. Vectors - # that haven't been computed yet are zero. - position_deltas = state.position_deltas - gradient_deltas = state.gradient_deltas - - # Pre-compute all `inv_rho[i]`s. - inv_rhos = tf.reduce_sum( - gradient_deltas * position_deltas, axis=-1) - - def first_loop(acc, args): - _, q_direction = acc - position_delta, gradient_delta, inv_rho = args - alpha = tf.math.divide_no_nan( - tf.reduce_sum(position_delta * q_direction, axis=-1), inv_rho) - direction_delta = alpha[..., tf.newaxis] * gradient_delta - return (alpha, q_direction - direction_delta) - - # Run first loop body computing and collecting `alpha[i]`s, while also - # computing the updated `q_direction` at each step. - zero = tf.zeros_like(inv_rhos[-num_elements]) - alphas, q_directions = tf.scan( - first_loop, [position_deltas, gradient_deltas, inv_rhos], - initializer=(zero, state.objective_gradient), reverse=True) - - # We use `H^0_k = gamma_k * I` as an estimate for the initial inverse - # hessian for the k-th iteration; then `r_direction = H^0_k * q_direction`. - gamma_k = inv_rhos[-1] / tf.reduce_sum( - gradient_deltas[-1] * gradient_deltas[-1], axis=-1) - r_direction = gamma_k[..., tf.newaxis] * q_directions[-num_elements] - - def second_loop(r_direction, args): - alpha, position_delta, gradient_delta, inv_rho = args - beta = tf.math.divide_no_nan( - tf.reduce_sum(gradient_delta * r_direction, axis=-1), inv_rho) - direction_delta = (alpha - beta)[..., tf.newaxis] * position_delta - return r_direction + direction_delta - - # Finally, run second loop body computing the updated `r_direction` at each - # step. 
- r_directions = tf.scan( - second_loop, [alphas, position_deltas, gradient_deltas, inv_rhos], - initializer=r_direction) - return -r_directions[-1] - - return ps.cond(ps.equal(num_elements, 0), - lambda: -state.objective_gradient, - _two_loop_algorithm) + # Calculate new values + new_gradient_deltas = (tf.where( + where_erase[..., tf.newaxis, tf.newaxis], + 0., + bfgs_state.gradient_deltas)) + new_position_deltas = (tf.where( + where_erase[..., tf.newaxis, tf.newaxis], + 0., + bfgs_state.position_deltas)) + new_history = tf.where(where_erase, 0, bfgs_state.history) + # Assure the compiler that the shape of things has not changed + if not tf.executing_eagerly(): + new_gradient_deltas = ( + tf.ensure_shape( + new_gradient_deltas, + bfgs_state.gradient_deltas.shape)) + new_position_deltas = ( + tf.ensure_shape( + new_position_deltas, + bfgs_state.position_deltas.shape)) + new_history = ( + tf.ensure_shape( + new_history, + bfgs_state.history.shape)) + # Update and return + return bfgs_utils.update_fields( + bfgs_state, + gradient_deltas=new_gradient_deltas, + position_deltas=new_position_deltas, + history=new_history) + + +def _check_convergence_bounded(current_position, + next_position, + current_objective, + next_objective, + next_gradient, + grad_tolerance, + f_relative_tolerance, + x_tolerance, + lower_bounds, + upper_bounds): + """Checks if the algorithm satisfies the convergence criteria.""" + # NOTE: The original algorithm (as described in [2]) only considers halting on + # the projected gradient condition. However, `x_converged` and `f_converged` + # do not seem to pose a problem when refreshing is correctly accounted for + # (so that the optimization does not halt upon a refresh), and the default + # values of `0` for `f_relative_tolerance` and `x_tolerance` further + # strengthen these conditions. 
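+  # The test below is the projected-gradient criterion of [2]: the norm of
+  # P(x - g) - x, where P(.) clips its argument to the bound constraints,
+  # must fall below `grad_tolerance`.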
+ proj_grad_converged = bfgs_utils.norm( + tf.clip_by_value( + next_position - next_gradient, + lower_bounds, + upper_bounds) - next_position, dims=1) <= grad_tolerance + x_converged = bfgs_utils.norm( + next_position - current_position, dims=1) <= x_tolerance + f_ref = tf.math.maximum(1., tf.math.maximum( + tf.math.abs(next_objective), + tf.math.abs(current_objective))) + f_converged = (tf.math.abs(next_objective - current_objective) + <= f_ref*f_relative_tolerance) + return proj_grad_converged | x_converged | f_converged + + +def _get_initial_state(value_and_gradients_function, + initial_position, + lower_bounds, + upper_bounds, + num_correction_pairs, + tolerance): + """Create LBfgsBOptimizerResults with initial state of search procedure.""" + init_args = get_initial_state_args(value_and_gradients_function, + initial_position, + tolerance) + empty_queue = _make_empty_queue_for(num_correction_pairs, initial_position) + zero_history = tf.zeros(ps.shape(initial_position)[:-1], dtype=tf.int32) + init_args.update( + lower_bounds=lower_bounds, + upper_bounds=upper_bounds, + position_deltas=empty_queue, + gradient_deltas=empty_queue, + history=zero_history) + return LBfgsBOptimizerResults(**init_args) + + +def get_initial_state_args(value_and_gradients_function, + initial_position, + grad_tolerance, + control_inputs=None): + none_finished = tf.broadcast_to(False, ps.shape(initial_position)[:-1]) + zero_values = bfgs_utils._broadcast(0., initial_position) + zero_gradients = tf.zeros_like(initial_position) + if control_inputs: + with tf.control_dependencies(control_inputs): + f0, df0 = value_and_gradients_function( + initial_position, none_finished, zero_values, zero_gradients) + else: + f0, df0 = value_and_gradients_function( + initial_position, none_finished, zero_values, zero_gradients) + # This is a gradient-based convergence check. We only do it for finite + # objective values because we assume the gradient reported at a position with + # a non-finite objective value is untrustworthy. The main loop handles + # non-finite objective values itself (see `terminate_if_not_finite`). + init_converged = (tf.math.is_finite(f0) & + (bfgs_utils.norm(df0, dims=1) < grad_tolerance)) + return dict( + converged=init_converged, + failed=tf.zeros_like(init_converged), # i.e. False. + num_iterations=tf.convert_to_tensor(0), + num_objective_evaluations=tf.convert_to_tensor(1), + position=initial_position, + objective_value=f0, + objective_gradient=df0) + + +def _cauchy_init_m(state, theta, num_correction_pairs): + """Initialize the M matrix for a `_CauchyMinimizationResult` state.""" + def build_m(): + """Construct and invert the M block matrix.""" + # All of the below block matrices have dimensions [..., 2m, 2m] + # where `...` denotes the batch dimensions, and `m` the number + # of correction pairs. + # New elements are pushed in "from the back", so we want to index + # position_deltas and gradient_deltas with negative indices. + # Index 0 of `position_deltas` and `gradient_deltas` is oldest, and index -1 + # is most recent, so the below respects the indexing of the article. + + # 1. calculate inner product (s_i.y_j) in shape [..., m, m] + l = tf.einsum( + "...mi,...ui->...mu", + state.position_deltas, + state.gradient_deltas) + # 2. 
Zero out diagonal and upper triangular + l_shape = ps.shape(l) + l = tf.linalg.set_diag( + tf.linalg.band_part(l, -1, 0), + tf.zeros([l_shape[0], l_shape[-1]])) + l_transpose = tf.linalg.matrix_transpose(l) + s_t_s = tf.einsum( + '...mi,...ni->...mn', + state.position_deltas, + state.position_deltas) + d = tf.linalg.diag( + tf.einsum( + '...mi,...mi->...m', + state.position_deltas, + state.gradient_deltas)) + + # Assemble into full matrix + # shape [b, 2m, 2m] + m_inv = tf.concat( + [ + tf.concat([-d, l_transpose], axis=-1), + tf.concat( + [l, theta[..., tf.newaxis, tf.newaxis] * s_t_s], axis=-1) + ], axis=-2) + + # Adjust for varying history: + # Push columns indexed h,...,2m-h to the left (but to the right of 0...m-h) + # and same index rows to the bottom + idx = tf.concat( + [tf.ragged.range(num_correction_pairs-state.history), + tf.ragged.range(num_correction_pairs, 2 * + num_correction_pairs-state.history), + tf.ragged.range(num_correction_pairs - + state.history, num_correction_pairs), + tf.ragged.range( + 2*num_correction_pairs-state.history, 2*num_correction_pairs)], + axis=-1).to_tensor() + m_inv = tf.gather( + m_inv, + idx, + axis=-1, + batch_dims=1) + m_inv = tf.gather( + m_inv, + idx, + axis=-2, + batch_dims=1) + + # Insert an identity in the empty block + identity_mask = ( + (tf.range(ps.shape(m_inv)[-1])[tf.newaxis, ...] < + 2*(num_correction_pairs - state.history[..., tf.newaxis]))[..., tf.newaxis]) + + m_inv = tf.where( + identity_mask, + tf.eye(ps.shape(m_inv)[-1], batch_shape=ps.shape(m_inv)[:-2]), + m_inv) + + # If M is not invertible, refresh the memory + # TODO: Checking the determinant is likely overkill? + refresh = (tf.linalg.det(m_inv) == 0) + + # Invert where invertible; 0s otherwise + m = tf.where( + refresh[..., tf.newaxis, tf.newaxis], + tf.zeros_like(m_inv), + tf.linalg.inv( + tf.where( + refresh[..., tf.newaxis, tf.newaxis], + tf.eye(ps.shape(m_inv)[-1], + batch_shape=ps.shape(m_inv)[:-2]), + m_inv))) + + # Re-zero the introduced identity blocks + m = tf.where( + identity_mask, + tf.zeros_like(m), + m) + + return m, refresh + + # M is 0 for the first iterations + # We abuse `state.history` to extract the batch shape + m_shape = ps.concat([ps.shape(state.history), + [num_correction_pairs*2, num_correction_pairs*2]], axis=0) + return tf.cond( + state.num_iterations < 1, + lambda: (tf.zeros(m_shape), + tf.broadcast_to(False, ps.shape(state.history))), + build_m) def _make_empty_queue_for(k, element): @@ -402,18 +1523,16 @@ def _make_empty_queue_for(k, element): ```python element = tf.constant([[0., 1., 2., 3., 4.], - [5., 6., 7., 8., 9.]]) + [5., 6., 7., 8., 9.]]) # A queue capable of holding 3 elements. _make_empty_queue_for(3, element) - # => [[[ 0., 0., 0., 0., 0.], - # [ 0., 0., 0., 0., 0.]], - # - # [[ 0., 0., 0., 0., 0.], - # [ 0., 0., 0., 0., 0.]], - # - # [[ 0., 0., 0., 0., 0.], - # [ 0., 0., 0., 0., 0.]]] + # => [[[0., 0., 0., 0., 0.], + # [0., 0., 0., 0., 0.], + # [0., 0., 0., 0., 0.]], + # [[0., 0., 0., 0., 0.], + # [0., 0., 0., 0., 0.], + # [0., 0., 0., 0., 0.]]] ``` Args: @@ -421,17 +1540,18 @@ def _make_empty_queue_for(k, element): element: A `tf.Tensor`, only its shape and dtype information are relevant. Returns: - A zero-filed `tf.Tensor` of shape `(k,) + tf.shape(element)` and same dtype - as `element`. + A zero-filed `tf.Tensor` of shape `(s[:-1], k, s[-1])`, where + `s = tf.shape(element)`, and same dtype as `element`. 
""" - queue_shape = ps.concat([[k], ps.shape(element)], axis=0) + queue_shape = ps.concat( + [ps.shape(element)[:-1], [k], ps.shape(element)[-1:]], axis=0) return tf.zeros(queue_shape, dtype=dtype_util.base_dtype(element.dtype)) def _queue_push(queue, should_update, new_vecs): """Conditionally push new vectors into a batch of first-in-first-out queues. - The `queue` of shape `[k, ..., n]` can be thought of as a batch of queues, + The `queue` of shape `[..., k, n]` can be thought of as a batch of queues, each holding `k` n-D vectors; while `new_vecs` of shape `[..., n]` is a fresh new batch of n-D vectors. The `should_update` batch of Boolean scalars, i.e. shape `[...]`, indicates batch members whose corresponding n-D vector in @@ -439,54 +1559,50 @@ def _queue_push(queue, should_update, new_vecs): corresponding n-D vector from the front. Batch members in `new_vecs` for which `should_update` is False are ignored. - Note: the choice of placing `k` at the dimension 0 of the queue is - constrained by the L-BFGS two-loop algorithm above. The algorithm uses - tf.scan to iterate over the `k` correction pairs simulatneously across all - batches, and tf.scan itself can only iterate over dimension 0. + Note: whereas `lbfgs.py` places the `k` at dimension 0 due to constraints + of `tf.scan`, those do not apply here, and in fact it is more advantageous + to have the batch dimensions before `k`. For example: ```python - k, b, n = (3, 2, 5) - queue = tf.reshape(tf.range(30), (k, b, n)) + b, k, n = (2, 3, 5) + queue = tf.reshape(tf.range(30), (b, k, n)) # => [[[ 0, 1, 2, 3, 4], - # [ 5, 6, 7, 8, 9]], - # - # [[10, 11, 12, 13, 14], - # [15, 16, 17, 18, 19]], - # - # [[20, 21, 22, 23, 24], - # [25, 26, 27, 28, 29]]] + # [ 5, 6, 7, 8, 9], + # [10, 11, 12, 13, 14]], + # [[15, 16, 17, 18, 19], + # [20, 21, 22, 23, 24], + # [25, 26, 27, 28, 29]]] element = tf.reshape(tf.range(30, 40), (b, n)) # => [[30, 31, 32, 33, 34], - [35, 36, 37, 38, 39]] + [35, 36, 37, 38, 39]] should_update = tf.constant([True, False]) # Shape: (b,) - _queue_add(should_update, queue, element) - # => [[[10, 11, 12, 13, 14], - # [ 5, 6, 7, 8, 9]], - # - # [[20, 21, 22, 23, 24], - # [15, 16, 17, 18, 19]], - # - # [[30, 31, 32, 33, 34], + _queue_push(queue, should_update, element) + # => [[[ 5, 6, 7, 8, 9], + # [10, 11, 12, 13, 14], + # [30, 31, 32, 33, 34]], + # [[15, 16, 17, 18, 19], + # [20, 21, 22, 23, 24], # [25, 26, 27, 28, 29]]] ``` Args: - queue: A `tf.Tensor` of shape `[k, ..., n]`; a batch of queues each with - `k` n-D vectors. + queue: A `tf.Tensor` of shape `[..., k, n]`; a batch of queues each with + `k` n-D vectors. should_update: A Boolean `tf.Tensor` of shape `[...]` indicating batch - members where new vectors should be added to their queues. + members where new vectors should be added to their queues. new_vecs: A `tf.Tensor` of shape `[..., n]`; a batch of n-D vectors to add - at the end of their respective queues, pushing out the first element from - each. + at the end of their respective queues, pushing out the first element from + each. Returns: - A new `tf.Tensor` of shape `[k, ..., n]`. + A new `tf.Tensor` of shape `[..., k, n]`. 
""" - new_queue = tf.concat([queue[1:], [new_vecs]], axis=0) + new_queue = tf.concat( + [queue[..., 1:, :], new_vecs[..., tf.newaxis, :]], axis=-2) return tf.where( - should_update[tf.newaxis, ..., tf.newaxis], new_queue, queue) + should_update[..., tf.newaxis, tf.newaxis], new_queue, queue) From 64d30a428ea3d81b08c5a23faad2795372411084 Mon Sep 17 00:00:00 2001 From: mikeevmm Date: Fri, 25 Jun 2021 14:55:57 +0100 Subject: [PATCH 4/4] wip: test file base --- .../python/optimizer/lbfgsb_test.py | 573 ++++++++++++++++++ 1 file changed, 573 insertions(+) create mode 100644 tensorflow_probability/python/optimizer/lbfgsb_test.py diff --git a/tensorflow_probability/python/optimizer/lbfgsb_test.py b/tensorflow_probability/python/optimizer/lbfgsb_test.py new file mode 100644 index 0000000000..03dd8fdf9c --- /dev/null +++ b/tensorflow_probability/python/optimizer/lbfgsb_test.py @@ -0,0 +1,573 @@ +# Copyright 2018 The TensorFlow Probability Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Tests for the constrained L-BFGS-B optimizer.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import functools + +from absl.testing import parameterized +import numpy as np +from scipy.stats import special_ortho_group + +import tensorflow.compat.v1 as tf1 +import tensorflow.compat.v2 as tf +import tensorflow_probability as tfp + +from tensorflow_probability.python.internal import test_util + + +def _make_val_and_grad_fn(value_fn): + @functools.wraps(value_fn) + def val_and_grad(x): + return tfp.math.value_and_gradient(value_fn, x) + return val_and_grad + + +def _norm(x): + return np.linalg.norm(x, np.inf) + + +@test_util.test_all_tf_execution_regimes +class LBfgsTest(test_util.TestCase): + """Tests for LBFGSB optimization algorithm.""" + + def test_quadratic_bowl_2d(self): + """Can minimize a two dimensional quadratic function when constrained.""" + minimum = np.array([1.0, 1.0]) + scales = np.array([2.0, 3.0]) + lower_bounds = np.array([0., 2.]) + upper_bounds = np.array([2., 5.]) + expected = np.array([1.0, 2.0]) + + @_make_val_and_grad_fn + def quadratic(x): + return tf.reduce_sum(scales * tf.math.squared_difference(x, minimum)) + + start = tf.constant([0.6, 0.8]) + results = self.evaluate(tfp.optimizer.lbfgsb_minimize( + quadratic, initial_position=start, tolerance=1e-8, + lower_bounds=lower_bounds, upper_bounds=upper_bounds)) + self.assertTrue(results.converged) + self.assertLessEqual(_norm(results.objective_gradient), 1e-8) + self.assertArrayNear(results.position, expected, 1e-5) + + # TODO: + def test_high_dims_quadratic_bowl_trivial(self): + """Can minimize a high-dimensional trivial bowl (sphere).""" + ndims = 100 + minimum = np.ones([ndims], dtype='float64') + scales = np.ones([ndims], dtype='float64') + + @_make_val_and_grad_fn + def quadratic(x): + return tf.reduce_sum(scales * tf.math.squared_difference(x, minimum)) + + start = np.zeros([ndims], 
dtype='float64') + results = self.evaluate(tfp.optimizer.lbfgs_minimize( + quadratic, initial_position=start, tolerance=1e-8)) + self.assertTrue(results.converged) + self.assertEqual(results.num_iterations, 1) # Solved by first line search. + self.assertLessEqual(_norm(results.objective_gradient), 1e-8) + self.assertArrayNear(results.position, minimum, 1e-5) + + # TODO: + def test_quadratic_bowl_40d(self): + """Can minimize a high-dimensional quadratic function.""" + dim = 40 + np.random.seed(14159) + minimum = np.random.randn(dim) + scales = np.exp(np.random.randn(dim)) + + @_make_val_and_grad_fn + def quadratic(x): + return tf.reduce_sum(scales * tf.math.squared_difference(x, minimum)) + + start = tf.ones_like(minimum) + results = self.evaluate(tfp.optimizer.lbfgs_minimize( + quadratic, initial_position=start, tolerance=1e-8)) + self.assertTrue(results.converged) + self.assertLessEqual(_norm(results.objective_gradient), 1e-8) + self.assertArrayNear(results.position, minimum, 1e-5) + + # TODO: + def test_quadratic_with_skew(self): + """Can minimize a general quadratic function.""" + dim = 50 + np.random.seed(26535) + minimum = np.random.randn(dim) + principal_values = np.diag(np.exp(np.random.randn(dim))) + rotation = special_ortho_group.rvs(dim) + hessian = np.dot(np.transpose(rotation), np.dot(principal_values, rotation)) + + @_make_val_and_grad_fn + def quadratic(x): + y = x - minimum + yp = tf.tensordot(hessian, y, axes=[1, 0]) + return tf.reduce_sum(y * yp) / 2 + + start = tf.ones_like(minimum) + results = self.evaluate(tfp.optimizer.lbfgs_minimize( + quadratic, initial_position=start, tolerance=1e-8)) + self.assertTrue(results.converged) + self.assertLessEqual(_norm(results.objective_gradient), 1e-8) + self.assertArrayNear(results.position, minimum, 1e-5) + + # TODO: + def test_quadratic_with_strong_skew(self): + """Can minimize a strongly skewed quadratic function.""" + np.random.seed(89793) + minimum = np.random.randn(3) + principal_values = np.diag(np.array([0.1, 2.0, 50.0])) + rotation = special_ortho_group.rvs(3) + hessian = np.dot(np.transpose(rotation), np.dot(principal_values, rotation)) + + @_make_val_and_grad_fn + def quadratic(x): + y = x - minimum + yp = tf.tensordot(hessian, y, axes=[1, 0]) + return tf.reduce_sum(y * yp) / 2 + + start = tf.ones_like(minimum) + results = self.evaluate(tfp.optimizer.lbfgs_minimize( + quadratic, initial_position=start, tolerance=1e-8)) + self.assertTrue(results.converged) + self.assertLessEqual(_norm(results.objective_gradient), 1e-8) + self.assertArrayNear(results.position, minimum, 1e-5) + + # TODO: + def test_rosenbrock_2d(self): + """Tests L-BFGS on the Rosenbrock function. + + The Rosenbrock function is a standard optimization test case. In two + dimensions, the function is (a, b > 0): + f(x, y) = (a - x)^2 + b (y - x^2)^2 + The function has a global minimum at (a, a^2). This minimum lies inside + a parabolic valley (y = x^2). + """ + def rosenbrock(coord): + """The Rosenbrock function in two dimensions with a=1, b=100. + + Args: + coord: A Tensor of shape [2]. The coordinate of the point to evaluate + the function at. + + Returns: + fv: A scalar tensor containing the value of the Rosenbrock function at + the supplied point. + dfx: Scalar tensor. The derivative of the function with respect to x. + dfy: Scalar tensor. The derivative of the function with respect to y. 
+ """ + x, y = coord[0], coord[1] + fv = (1 - x)**2 + 100 * (y - x**2)**2 + dfx = 2 * (x - 1) + 400 * x * (x**2 - y) + dfy = 200 * (y - x**2) + return fv, tf.stack([dfx, dfy]) + + start = tf.constant([-1.2, 1.0]) + results = self.evaluate(tfp.optimizer.lbfgs_minimize( + rosenbrock, initial_position=start, tolerance=1e-5)) + self.assertTrue(results.converged) + self.assertLessEqual(_norm(results.objective_gradient), 1e-5) + self.assertArrayNear(results.position, np.array([1.0, 1.0]), 1e-5) + + # TODO: + def test_himmelblau(self): + """Tests minimization on the Himmelblau's function. + + Himmelblau's function is a standard optimization test case. The function is + given by: + + f(x, y) = (x^2 + y - 11)^2 + (x + y^2 - 7)^2 + + The function has four minima located at (3, 2), (-2.805118, 3.131312), + (-3.779310, -3.283186), (3.584428, -1.848126). + + All these minima may be reached from appropriate starting points. To keep + the runtime of this test small, here we only find the first two minima. + However, all four can be easily found in `test_himmelblau_batch_all` below + with the help of batching. + """ + @_make_val_and_grad_fn + def himmelblau(coord): + x, y = coord[0], coord[1] + return (x * x + y - 11) ** 2 + (x + y * y - 7) ** 2 + + starts_and_targets = [ + # Start Point, Target Minimum, Num evaluations expected. + [(1, 1), (3, 2), 31], + [(-2, 2), (-2.805118, 3.131312), 17], + ] + dtype = 'float64' + for start, expected_minima, expected_evals in starts_and_targets: + start = tf.constant(start, dtype=dtype) + results = self.evaluate(tfp.optimizer.lbfgs_minimize( + himmelblau, initial_position=start, tolerance=1e-8)) + self.assertTrue(results.converged) + self.assertArrayNear(results.position, + np.array(expected_minima, dtype=dtype), + 1e-5) + self.assertEqual(results.num_objective_evaluations, expected_evals) + + # TODO: + def test_himmelblau_batch_all(self): + @_make_val_and_grad_fn + def himmelblau(coord): + x, y = coord[..., 0], coord[..., 1] + return (x * x + y - 11) ** 2 + (x + y * y - 7) ** 2 + + dtype = 'float64' + starts = tf.constant([[1, 1], + [-2, 2], + [-1, -1], + [1, -2]], dtype=dtype) + expected_minima = np.array([[3, 2], + [-2.805118, 3.131312], + [-3.779310, -3.283186], + [3.584428, -1.848126]], dtype=dtype) + batch_results = self.evaluate(tfp.optimizer.lbfgs_minimize( + himmelblau, initial_position=starts, + stopping_condition=tfp.optimizer.converged_all, tolerance=1e-8)) + + self.assertFalse(np.any(batch_results.failed)) # None have failed. + self.assertTrue(np.all(batch_results.converged)) # All converged. + + # All converged points are near expected minima. + for actual, expected in zip(batch_results.position, expected_minima): + self.assertArrayNear(actual, expected, 1e-5) + self.assertEqual(batch_results.num_objective_evaluations, 36) + + # TODO: + def test_himmelblau_batch_any(self): + @_make_val_and_grad_fn + def himmelblau(coord): + x, y = coord[..., 0], coord[..., 1] + return (x * x + y - 11) ** 2 + (x + y * y - 7) ** 2 + + dtype = 'float64' + starts = tf.constant([[1, 1], + [-2, 2], + [-1, -1], + [1, -2]], dtype=dtype) + expected_minima = np.array([[3, 2], + [-2.805118, 3.131312], + [-3.779310, -3.283186], + [3.584428, -1.848126]], dtype=dtype) + + # Run with `converged_any` stopping condition, to stop as soon as any of + # the batch members have converged. 
+    batch_results = self.evaluate(tfp.optimizer.lbfgs_minimize(
+        himmelblau, initial_position=starts,
+        stopping_condition=tfp.optimizer.converged_any, tolerance=1e-8))
+
+    self.assertFalse(np.any(batch_results.failed))  # None have failed.
+    self.assertTrue(np.any(batch_results.converged))  # At least one converged.
+    self.assertFalse(np.all(batch_results.converged))  # But not all did.
+
+    # Converged points are near expected minima.
+    for actual, expected in zip(batch_results.position[batch_results.converged],
+                                expected_minima[batch_results.converged]):
+      self.assertArrayNear(actual, expected, 1e-5)
+    self.assertEqual(batch_results.num_objective_evaluations, 28)
+
+  # TODO:
+  def test_himmelblau_batch_any_resume_then_all(self):
+    @_make_val_and_grad_fn
+    def himmelblau(coord):
+      x, y = coord[..., 0], coord[..., 1]
+      return (x * x + y - 11) ** 2 + (x + y * y - 7) ** 2
+
+    dtype = 'float64'
+    starts = tf.constant([[1, 1],
+                          [-2, 2],
+                          [-1, -1],
+                          [1, -2]], dtype=dtype)
+    expected_minima = np.array([[3, 2],
+                                [-2.805118, 3.131312],
+                                [-3.779310, -3.283186],
+                                [3.584428, -1.848126]], dtype=dtype)
+
+    # Run with `converged_any` stopping condition, to stop as soon as any of
+    # the batch members have converged.
+    raw_batch_results = tfp.optimizer.lbfgs_minimize(
+        himmelblau, initial_position=starts,
+        stopping_condition=tfp.optimizer.converged_any, tolerance=1e-8)
+    batch_results = self.evaluate(raw_batch_results)
+
+    self.assertFalse(np.any(batch_results.failed))  # None have failed.
+    self.assertTrue(np.any(batch_results.converged))  # At least one converged.
+    self.assertFalse(np.all(batch_results.converged))  # But not all did.
+
+    # Converged points are near expected minima.
+    for actual, expected in zip(batch_results.position[batch_results.converged],
+                                expected_minima[batch_results.converged]):
+      self.assertArrayNear(actual, expected, 1e-5)
+    self.assertEqual(batch_results.num_objective_evaluations, 28)
+
+    # Run with `converged_all`, starting from previous state.
+    batch_results = self.evaluate(tfp.optimizer.lbfgs_minimize(
+        himmelblau, initial_position=None,
+        previous_optimizer_results=raw_batch_results,
+        stopping_condition=tfp.optimizer.converged_all, tolerance=1e-8))
+
+    # All converged points are near expected minima and the number of
+    # evaluations is as if we never stopped.
+    for actual, expected in zip(batch_results.position, expected_minima):
+      self.assertArrayNear(actual, expected, 1e-5)
+    self.assertEqual(batch_results.num_objective_evaluations, 36)
+
+  # TODO:
+  def test_initial_position_and_previous_optimizer_results_are_exclusive(self):
+    minimum = np.array([1.0, 1.0])
+    scales = np.array([2.0, 3.0])
+
+    @_make_val_and_grad_fn
+    def quadratic(x):
+      return tf.reduce_sum(scales * tf.math.squared_difference(x, minimum))
+
+    start = tf.constant([0.6, 0.8])
+
+    def run(position, state):
+      raw_results = tfp.optimizer.lbfgs_minimize(
+          quadratic, initial_position=position,
+          previous_optimizer_results=state, tolerance=1e-8)
+      self.evaluate(raw_results)
+      return raw_results
+
+    self.assertRaises(ValueError, run, None, None)
+    results = run(start, None)
+    self.assertRaises(ValueError, run, start, results)
+
+  # TODO:
+  def test_data_fitting(self):
+    """Tests MLE estimation for a simple geometric GLM."""
+    n, dim = 100, 30
+    dtype = tf.float64
+    np.random.seed(234095)
+    x = np.random.choice([0, 1], size=[dim, n])
+    s = 0.01 * np.sum(x, 0)
+    p = 1. 
/ (1 + np.exp(-s))
+    y = np.random.geometric(p)
+    x_data = tf.convert_to_tensor(x, dtype=dtype)
+    y_data = tf.convert_to_tensor(y, dtype=dtype)[..., tf.newaxis]
+
+    @_make_val_and_grad_fn
+    def neg_log_likelihood(state):
+      state_ext = tf.expand_dims(state, 0)
+      linear_part = tf.matmul(state_ext, x_data)
+      linear_part_ex = tf.stack([tf.zeros_like(linear_part),
+                                 linear_part], axis=0)
+      term1 = tf.squeeze(
+          tf.matmul(
+              tf.reduce_logsumexp(linear_part_ex, axis=0), y_data),
+          -1)
+      term2 = (
+          0.5 * tf.reduce_sum(state_ext * state_ext, axis=-1) -
+          tf.reduce_sum(linear_part, axis=-1))
+      return tf.squeeze(term1 + term2)
+
+    start = tf.ones(shape=[dim], dtype=dtype)
+
+    results = self.evaluate(tfp.optimizer.lbfgs_minimize(
+        neg_log_likelihood, initial_position=start, tolerance=1e-6))
+    self.assertTrue(results.converged)
+
+  # TODO:
+  def test_determinism(self):
+    """Tests that the results are deterministic."""
+    dim = 25
+
+    @_make_val_and_grad_fn
+    def rastrigin(x):
+      """The value and gradient of the Rastrigin function.
+
+      The Rastrigin function is a standard optimization test case. It is a
+      multimodal non-convex function. While it has a large number of local
+      minima, the global minimum is located at the origin, where the function
+      value is zero. The standard search domain for optimization problems is the
+      hypercube [-5.12, 5.12]**d in d-dimensions.
+
+      Args:
+        x: Real `Tensor` of shape [2]. The position at which to evaluate the
+          function.
+
+      Returns:
+        value_and_gradient: A tuple of two `Tensor`s containing
+          value: A scalar `Tensor` of the function value at the supplied point.
+          gradient: A `Tensor` of shape [2] containing the gradient of the
+            function along the two axes.
+      """
+      return tf.reduce_sum(x**2 - 10.0 * tf.cos(2 * np.pi * x)) + 10.0 * dim
+
+    start_position = np.random.rand(dim) * 2.0 * 5.12 - 5.12
+
+    def get_results():
+      start = tf.constant(start_position)
+      return self.evaluate(tfp.optimizer.lbfgs_minimize(
+          rastrigin, initial_position=start, tolerance=1e-5))
+
+    res1, res2 = get_results(), get_results()
+
+    self.assertTrue(res1.converged)
+    self.assertEqual(res1.converged, res2.converged)
+    self.assertEqual(res1.failed, res2.failed)
+    self.assertEqual(res1.num_objective_evaluations,
+                     res2.num_objective_evaluations)
+    self.assertArrayNear(res1.position, res2.position, 1e-5)
+    self.assertAlmostEqual(res1.objective_value, res2.objective_value)
+    self.assertArrayNear(res1.objective_gradient, res2.objective_gradient, 1e-5)
+    self.assertArrayNear(res1.position_deltas.reshape([-1]),
+                         res2.position_deltas.reshape([-1]), 1e-5)
+    self.assertArrayNear(res1.gradient_deltas.reshape([-1]),
+                         res2.gradient_deltas.reshape([-1]), 1e-5)
+
+  # TODO:
+  def test_compile(self):
+    """Tests that the computation can be XLA-compiled."""
+
+    self.skip_if_no_xla()
+
+    dim = 25
+
+    @_make_val_and_grad_fn
+    def rastrigin(x):
+      """The value and gradient of the Rastrigin function.
+
+      The Rastrigin function is a standard optimization test case. It is a
+      multimodal non-convex function. While it has a large number of local
+      minima, the global minimum is located at the origin, where the function
+      value is zero. The standard search domain for optimization problems is the
+      hypercube [-5.12, 5.12]**d in d-dimensions.
+
+      Args:
+        x: Real `Tensor` of shape [2]. The position at which to evaluate the
+          function.
+
+      Returns:
+        value_and_gradient: A tuple of two `Tensor`s containing
+          value: A scalar `Tensor` of the function value at the supplied point.
+ gradient: A `Tensor` of shape [2] containing the gradient of the + function along the two axes. + """ + return tf.reduce_sum(x**2 - 10.0 * tf.cos(2 * np.pi * x)) + 10.0 * dim + + start_position = np.random.rand(dim) * 2.0 * 5.12 - 5.12 + + res = tf.function(tfp.optimizer.lbfgs_minimize, jit_compile=True)( + rastrigin, + initial_position=tf.constant(start_position), + tolerance=1e-5) + + # We simply verify execution & convergence. + self.assertTrue(self.evaluate(res.converged)) + + # TODO: + def test_dynamic_shapes(self): + """Can build an lbfgs_op with dynamic shapes in graph mode.""" + if tf.executing_eagerly(): return + ndims = 60 + minimum = np.ones([ndims], dtype='float64') + scales = np.arange(ndims, dtype='float64') + minimum + + @_make_val_and_grad_fn + def quadratic(x): + return tf.reduce_sum(scales * tf.math.squared_difference(x, minimum)) + + # Test with a vector of unknown dimension, and a fully unknown shape. + for shape in ([None], None): + start = tf1.placeholder(tf.float32, shape=shape) + lbfgs_op = tfp.optimizer.lbfgs_minimize( + quadratic, initial_position=start, tolerance=1e-8) + self.assertFalse(lbfgs_op.position.shape.is_fully_defined()) + + start_value = np.arange(ndims, 0, -1, dtype='float64') + with self.cached_session() as session: + results = session.run(lbfgs_op, feed_dict={start: start_value}) + self.assertTrue(results.converged) + self.assertLessEqual(_norm(results.objective_gradient), 1e-8) + self.assertArrayNear(results.position, minimum, 1e-5) + + # TODO: + @parameterized.named_parameters( + [{'testcase_name': '_from_start', 'start': np.array([0.8, 0.8])}, + {'testcase_name': '_during_opt', 'start': np.array([0.0, 0.0])}, + {'testcase_name': '_mixed', 'start': np.array([[0.8, 0.8], [0.0, 0.0]])}, + ]) + def test_stop_at_negative_infinity(self, start): + """Stops gently when encountering a -inf objective.""" + minimum = np.array([1.0, 1.0]) + scales = np.array([2.0, 3.0]) + + @_make_val_and_grad_fn + def quadratic_with_hole(x): + quadratic = tf.reduce_sum( + scales * tf.math.squared_difference(x, minimum), axis=-1) + square_hole = tf.reduce_all(tf.logical_and((x > 0.7), (x < 1.3)), axis=-1) + minus_infty = tf.constant(float('-inf'), dtype=quadratic.dtype) + answer = tf.where(square_hole, minus_infty, quadratic) + return answer + + start = tf.constant(start) + results = self.evaluate(tfp.optimizer.lbfgs_minimize( + quadratic_with_hole, initial_position=start, tolerance=1e-8)) + self.assertAllTrue(results.converged) + self.assertAllFalse(results.failed) + self.assertAllNegativeInf(results.objective_value) + self.assertAllFinite(results.position) + self.assertAllNegativeInf(quadratic_with_hole(results.position)[0]) + + # TODO: + @parameterized.named_parameters( + [{'testcase_name': '_from_start', 'start': np.array([0.8, 0.8])}, + {'testcase_name': '_during_opt', 'start': np.array([0.0, 0.0])}, + {'testcase_name': '_mixed', 'start': np.array([[0.8, 0.8], [0.0, 0.0]])}, + ]) + def test_fail_at_non_finite(self, start): + """Fails promptly when encountering a non-finite but not -inf objective.""" + # Meaning, +inf (tested here) and nan (not tested separately due to nearly + # identical code paths) objective values cause a "stop with failure". + # Actually, there is a further nitpick: +inf is currently treated a little + # inconsistently. To wit, if the outer loop hits a +inf, it gives up and + # reports failure, because it assumes the gradient from a +inf value is + # garbage and no further progress is possible. 
However, if the line search + # encounters an intermediate +inf, it in some cases knows to just treat it + # as a large finite value and avoid it. So in principle, minimizing this + # test function starting outside the +inf region could stop at the actual + # minimum at the edge of said +inf region. However, currently it happens to + # fail. + minimum = np.array([1.0, 1.0]) + scales = np.array([2.0, 3.0]) + + @_make_val_and_grad_fn + def quadratic_with_spike(x): + quadratic = tf.reduce_sum( + scales * tf.math.squared_difference(x, minimum), axis=-1) + square_hole = tf.reduce_all(tf.logical_and((x > 0.7), (x < 1.3)), axis=-1) + infty = tf.constant(float('+inf'), dtype=quadratic.dtype) + answer = tf.where(square_hole, infty, quadratic) + return answer + + start = tf.constant(start) + results = self.evaluate(tfp.optimizer.lbfgs_minimize( + quadratic_with_spike, initial_position=start, tolerance=1e-8)) + self.assertAllFalse(results.converged) + self.assertAllTrue(results.failed) + self.assertAllFinite(results.position) + + +if __name__ == '__main__': + tf.test.main()
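
For reference, the convergence check added in this patch replaces the plain gradient-norm test with a box-projected one: convergence is declared when the norm of `clip(position - gradient, lower_bounds, upper_bounds) - position` falls below `grad_tolerance`, alongside the position-delta and relative-objective-change tests. The snippet below is a minimal standalone NumPy sketch of that projected-gradient criterion only; it is not part of the patch, the helper name `projected_gradient_converged` is made up for illustration, and the sample bounds are borrowed from `test_quadratic_bowl_2d` above.

```python
# Standalone sketch (illustrative only) of the box-projected gradient
# criterion: ||clip(x - g, l, u) - x||_inf <= grad_tolerance.
import numpy as np


def projected_gradient_converged(position, gradient, lower, upper, tol=1e-8):
  """Returns True when the box-projected gradient step is within tolerance."""
  projected_step = np.clip(position - gradient, lower, upper) - position
  return np.max(np.abs(projected_step)) <= tol


# Constrained minimum of 2*(x0 - 1)**2 + 3*(x1 - 1)**2 over [0, 2] x [2, 5]
# is (1, 2) (cf. `test_quadratic_bowl_2d`). The raw gradient there is (0, 6),
# but it points out of the box, so the projected step vanishes.
x = np.array([1.0, 2.0])
g = np.array([0.0, 6.0])
print(projected_gradient_converged(
    x, g, lower=np.array([0.0, 2.0]), upper=np.array([2.0, 5.0])))    # True
print(projected_gradient_converged(
    x, g, lower=np.array([-10.0, -10.0]), upper=np.array([10.0, 10.0])))  # False
```

At an active bound the outward-pointing gradient component is ignored, which is why the constrained minimizer passes this check even though its raw gradient norm is nonzero; in the unconstrained (wide-bound) case the criterion reduces to the usual gradient-norm test.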