add support for directional derivative and jvp (#9)
* add support for directional derivative and jvp

* add documentation and support for lambdas

* punctuation is important
lucaferranti authored Oct 30, 2022
1 parent beff859 commit 8d84eaf
Showing 5 changed files with 177 additions and 18 deletions.
52 changes: 41 additions & 11 deletions docs/source/api/differentiation.rst
@@ -7,32 +7,46 @@ Functions for differentiation

.. function:: proc initdual(x: real)


Initializes the input to the appropriate dual number to evaluate the derivative.

:arg x: point where to evaluate the derivative
:type x: real or [dom] real

:returns: If ``x`` is a real number, then it is initialized to :math:`x+\epsilon`. If ``x`` is a vector of reals, it is initialized to the vector of multiduals :math:`\begin{bmatrix}x_1+\epsilon_1\\\vdots\\x_n+\epsilon_n\end{bmatrix}`.
:rtype: ``dual`` if ``x`` is ``real`` or ``[dom] multidual`` if ``x`` is ``[dom] real``.


.. function:: proc initdual(x: [?D] ?t, v: [D] ?s)


Given a vector ``x`` and a vector ``v``, creates a vector of duals :math:`[x_i + \epsilon v_i]`.
Used to compute the directional derivative and the Jacobian-vector product (JVP).

:arg x: point where to evaluate the directional derivative / JVP
:type x: [D] real

:arg v: direction
:type v: [D] real

:returns: Vector of dual numbers
:rtype: [D] dual

.. function:: proc derivative(f, x: real)


Evaluates the derivative of ``f`` at ``x``.

:arg f: Function, note that this must be a concrete function.
:type f: Function

:arg x: point at which the derivative is evaluated
:type x: real

:returns: value of f'(x)
:rtype: real

Note that ``f`` must be a concrete function. If it is written as a generic function, you can pass ``derivative`` a lambda as follows

.. code-block:: chapel

    proc f(x) {
@@ -95,7 +109,7 @@ Functions for differentiation
:returns: value of :math:`J_f`
:rtype: [Dout, Din] real

Note that `f` must be a concrete function, if it's written as a generic function, you can pass ``jacobian`` a lambda as follows
Note that ``f`` must be a concrete function, if it's written as a generic function, you can pass ``jacobian`` a lambda as follows

.. code-block:: chapel
@@ -122,4 +136,20 @@ Functions for differentiation
Extracts the function value.

:arg x: result of computations using dual numbers.
:type x: dual, multidual or [] multidual.
:type x: dual, multidual or [] multidual.

.. function:: proc directionalDerivative(x: dual)

Extracts the directional derivative from a dual number.

.. function:: proc directionalDerivative(f, x: [?D], v: [D])

Computes the directional derivative of ``f`` at ``x`` in the direction of ``v``.

.. function:: proc jvp(x: [] dual)

Extracts the Jacobian-vector product from a vector of dual numbers.

.. function:: proc jvp(f, x: [?D], v: [D])

Computes the Jacobian-vector product of the Jacobian of ``f`` at ``x`` and vector ``v``.
67 changes: 66 additions & 1 deletion docs/source/tutorial.rst
@@ -138,7 +138,7 @@ Next, we can compute the gradient similarly to before
Computing the Jacobian
**********************

For multivariate vector-valued functions :math:`f:\mathbb{R}^m\rightarrow\mathbb{R}^n` we can compute the Jacobian :math:`J_f`. Both methods described so far still apply.
For multivariate vector-valued functions :math:`F:\mathbb{R}^m\rightarrow\mathbb{R}^n` we can compute the Jacobian :math:`J_F`. Both methods described so far still apply.

Using ``initdual`` the strategy is very similar to before, except that now the Jacobian should be extracted using the ``jacobian`` function.

@@ -181,3 +181,68 @@ Using the example function above
    2.0 1.0
    3.0 5.0

Computing directional derivative and Jacobian-vector product
************************************************************

In some applications, instead of the gradient, one may need to compute the directional derivative, that is, the dot product :math:`\nabla f(\mathbf{x})\cdot \mathbf{v}`, where :math:`\mathbf{v}` is the direction vector.
Instead of computing the gradient and the dot product separately, one can compute the directional derivative directly by evaluating :math:`f` over the vector of dual numbers :math:`[x_1+v_1\epsilon,\ldots,x_n+v_n\epsilon]^T` and taking the dual part of the result.
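
This works because :math:`\epsilon^2 = 0`, so every term of second order in :math:`\epsilon` vanishes and the coefficient of :math:`\epsilon` is exactly :math:`\nabla f(\mathbf{x})\cdot\mathbf{v}`. As a hand-worked check, take :math:`f(x_1, x_2) = x_1^2 + 3x_1x_2` at :math:`\mathbf{x} = (1, 2)` with :math:`\mathbf{v} = (0.5, 2)`; here :math:`x_1, x_2` correspond to ``x[0]``, ``x[1]`` in the code:

.. math::

   f(\mathbf{x} + \epsilon\mathbf{v})
   &= (1 + 0.5\epsilon)^2 + 3(1 + 0.5\epsilon)(2 + 2\epsilon) \\
   &= (1 + \epsilon + 0.25\epsilon^2) + 3(2 + 3\epsilon + \epsilon^2) \\
   &= 7 + 10\epsilon

The primal part is :math:`f(\mathbf{x}) = 7` and the dual part is :math:`\nabla f(\mathbf{x})\cdot\mathbf{v} = (2x_1 + 3x_2)\,v_1 + 3x_1\,v_2 = 8\cdot 0.5 + 3\cdot 2 = 10`.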

In practice, this is achieved by passing both the point ``x`` and the direction ``v`` to ``initdual``.
The directional derivative can be extracted using ``directionalDerivative``.

.. code-block:: chapel

    proc f(x) {
        return x[0] ** 2 + 3 * x[0] * x[1];
    }

    var dirder = f(initdual([1, 2], [0.5, 2.0]));
    writeln(value(dirder));
    writeln(directionalDerivative(dirder));

.. code-block::

    7.0
    10.0

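
To make the mechanism concrete, the following self-contained sketch re-implements just enough dual arithmetic to reproduce the example by hand. The record ``MiniDual``, its fields ``prim`` and ``tang``, and the three operators are hypothetical stand-ins written only for this illustration; they are not the library's ``dual`` type, which is what should be used in practice.

.. code-block:: chapel

    // Hypothetical minimal dual number a + b*eps with eps**2 = 0 (illustration only).
    record MiniDual {
      var prim: real;  // primal part a
      var tang: real;  // dual part b
    }

    // Sum rule: primal and dual parts add componentwise.
    operator +(a: MiniDual, b: MiniDual) {
      return new MiniDual(a.prim + b.prim, a.tang + b.tang);
    }

    // Product rule: (a + b*eps)(c + d*eps) = ac + (ad + bc)*eps.
    operator *(a: MiniDual, b: MiniDual) {
      return new MiniDual(a.prim * b.prim, a.prim * b.tang + a.tang * b.prim);
    }

    // Scaling by a real constant.
    operator *(c: real, a: MiniDual) {
      return new MiniDual(c * a.prim, c * a.tang);
    }

    // Same function as above, with x[0] ** 2 written as a product
    // because MiniDual defines no ** operator.
    proc f(x) {
      return x[0] * x[0] + 3.0 * x[0] * x[1];
    }

    // Seed each component with x_i as primal part and v_i as dual part.
    var x = [new MiniDual(1.0, 0.5), new MiniDual(2.0, 2.0)];
    var y = f(x);
    writeln(y.prim);  // function value
    writeln(y.tang);  // directional derivative

Under these assumptions the sketch prints ``7.0`` and ``10.0``, the same values as ``value`` and ``directionalDerivative`` above, at the cost of a single evaluation of ``f``.
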
Similarly, for vector-valued functions, one can compute the Jacobian-vector product (JVP) :math:`J_F\mathbf{v}` directly using the same strategy. The JVP can be extracted using the ``jvp`` function.

.. code-block:: chapel

    proc F(x) {
        return [x[0] ** 2 + x[1] + 1, x[0] + x[1] ** 2 + x[0] * x[1]];
    }

    var valjvp = F(initdual([1, 2], [0.5, 2.0]));
    writeln(value(valjvp), "\n");
    writeln(jvp(valjvp));

.. code-block::

    4.0 7.0

    3.0 11.5

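
As a hand-worked check of this result, the Jacobian of :math:`F(x_1, x_2) = (x_1^2 + x_2 + 1,\; x_1 + x_2^2 + x_1x_2)` can be written out and multiplied by :math:`\mathbf{v} = (0.5, 2)` at the point :math:`(1, 2)`, with :math:`x_1, x_2` standing for ``x[0]``, ``x[1]``:

.. math::

   J_F(\mathbf{x}) = \begin{bmatrix} 2x_1 & 1 \\ 1 + x_2 & x_1 + 2x_2 \end{bmatrix},
   \qquad
   J_F(1, 2)\,\mathbf{v} = \begin{bmatrix} 2 & 1 \\ 3 & 5 \end{bmatrix}
   \begin{bmatrix} 0.5 \\ 2 \end{bmatrix} = \begin{bmatrix} 3 \\ 11.5 \end{bmatrix}

The dual-number evaluation gives the same vector without ever forming the full Jacobian.
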
As in the previous cases, ``directionalDerivative`` and ``jvp`` can also take a function as input.
As with ``gradient`` and ``jacobian``, the function passed cannot be generic, and the array type of its argument must be written down as a type alias (``D`` below).

.. code-block:: chapel

    type D = [0..#2] dual;

    var dirder = directionalDerivative(lambda(x: D) {return f(x);}, [1, 2], [0.5, 2.0]);
    var Jv = jvp(lambda(x: D) {return F(x);}, [1, 2], [0.5, 2.0]);
    writeln(dirder, "\n");
    writeln(Jv);

.. code-block::

    10.0

    3.0 11.5
41 changes: 38 additions & 3 deletions src/differentiation.chpl
@@ -16,7 +16,7 @@ module differentiation {
}

pragma "no doc"
proc initdual(x : [?D] real) {
proc initdual(x : [?D] ?t) {
var x0 : [D] multidual;
forall i in D {
var eps : [D] real = 0.0;
@@ -26,6 +26,22 @@ module differentiation {
return x0;
}

/*
Given a vector ``x`` and a vector ``v``, creates a vector of duals :math:`[x_i + \epsilon v_i]`.
Used to compute the directional derivative and the Jacobian-vector product (JVP).
:arg x: point where to evaluate the directional derivative / JVP
:type x: [D] real
:arg v: direction
:type v: [D] real
:returns: Vector of dual numbers
:rtype: [D] dual
*/
proc initdual(x: [?D] ?t, v: [D] ?s) {
return [(xi, vi) in zip(x, v)] todual(xi, vi);
}

/*
Evaluates the derivative of ``f`` at ``x``.
@@ -82,7 +98,6 @@ module differentiation {
}
type D = [0..#2] multidual; // domain for the lambda function
var dh = gradient(lambda(x : D){return h(x);}, [1.0, 2.0]);
//outputs
//8.0 3.0
@@ -108,7 +123,7 @@ module differentiation {
:returns: value of :math:`J_f`
:rtype: [Dout, Din] real
Note that `f` must be a concrete function, if it's written as a generic function, you can pass ``jacobian`` a lambda as follows
Note that ``f`` must be a concrete function, if it's written as a generic function, you can pass ``jacobian`` a lambda as follows
.. code-block:: chapel
@@ -141,4 +156,24 @@ module differentiation {
:type x: dual, multidual or [] multidual.
*/
proc value(x) {return primalPart(x);}

/* Extracts the directional derivative from a dual number. */
proc directionalDerivative(x: dual) {
return dualPart(x);
}

/* Computes the directional derivative of ``f`` at ``x`` in the direction of ``v``. */
proc directionalDerivative(f, x: [?D], v: [D]) {
return dualPart(f(initdual(x, v)));
}

/* Extracts the Jacobian-vector product from a vector of dual numbers. */
proc jvp(x: [] dual) {
return dualPart(x);
}

/* Computes the Jacobian-vector product of the Jacobian of ``f`` at ``x`` and vector ``v``. */
proc jvp(f, x: [?D], v: [D]) {
return dualPart(f(initdual(x, v)));
}
}
6 changes: 5 additions & 1 deletion src/dualtype.chpl
@@ -57,14 +57,18 @@ module dualtype {
*/
proc dualPart(a) where isDualType(a.type) {return a.dualPart;}

proc primalPart(a : [] multidual) {return [i in a] primalPart(i);}
proc primalPart(a : [] ?t) where isDualType(t) {return [i in a] primalPart(i);}

proc dualPart(a : [?Dout] multidual, Din : domain(1) = a(0).dom) {
var res : [Dout.dim(0), Din.dim(0)] real;
[i in Dout] res(i, Din) = dualPart(a(i));
return res;
}

proc dualPart(a: [] dual) {
return [ai in a] dualPart(ai);
}

pragma "no doc"
proc primalPart(a) {return a;}

29 changes: 27 additions & 2 deletions test/test_differentiation.chpl
@@ -2,6 +2,7 @@ use UnitTest;
use ForwardModeAD;

type D = [0..#2] multidual;
type D2 = [0..#2] dual;

proc testUnivariateFunctions(test: borrowed Test) throws {
proc f(x) {
@@ -28,8 +29,7 @@ proc testGradient(test: borrowed Test) throws {
return x[0] ** 2 + 3 * x[0] * x[1];
}

// TODO: debug why doesn't work with integer input
var valgradh = h(initdual([1.0, 2.0]));
var valgradh = h(initdual([1, 2]));
test.assertEqual(value(valgradh), 7);
test.assertEqual(gradient(valgradh), [8.0, 3.0]);
}
@@ -53,4 +53,29 @@ proc testJacobian(test: borrowed Test) throws {
test.assertEqual(Jg, _Jg);
}

proc testDirectionalAndJvp(test: borrowed Test) throws {
proc f(x) {
return x[0] ** 2 + 3 * x[0] * x[1];
}

var valdirder = f(initdual([1, 2], [0.5, 2.0]));

test.assertEqual(value(valdirder), 7);
test.assertEqual(directionalDerivative(valdirder), 10);

var dirder = directionalDerivative(lambda(x: D2) {return f(x);}, [1, 2], [0.5, 2.0]);
test.assertEqual(dirder, 10);

proc F(x) {
return [x[0] ** 2 + x[1] + 1, x[0] + x[1] ** 2 + x[0] * x[1]];
}

var valjvp = F(initdual([1, 2], [0.5, 2.0]));
test.assertEqual(value(valjvp), [4.0, 7.0]);
test.assertEqual(jvp(valjvp), [3.0, 11.5]);

var Jv = jvp(lambda(x: D2) {return F(x);}, [1, 2], [0.5, 2.0]);
test.assertEqual(Jv, [3.0, 11.5]);
}

UnitTest.main();
