-
Notifications
You must be signed in to change notification settings - Fork 25
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #142 from firedrakeproject/underintegration
* underintegration: add a short comment; fix previous bug; expose delta_elimination bug in test case; change default mode: coffee -> spectral; add some more comments; rename argument indices to linear indices; move COFFEE algorithm from TSFC to GEM; test underintegration tricks; fix Python 2 flake8; fix failing test case; rewrite spectral mode; tolerate and skip monomials with no atomics; delay index substitution in delta_elimination; disconnect coffee_mode.Integrals from spectral.Integrals; optimise away Delta in UFL arguments
- Loading branch information
Showing
7 changed files
with
512 additions
and
233 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,202 @@ | ||
"""This module contains an implementation of the COFFEE optimisation | ||
algorithm operating on a GEM representation. | ||
This file is NOT for code generation as a COFFEE AST. | ||
""" | ||
|
||
from __future__ import absolute_import, print_function, division | ||
from six.moves import map, range | ||
|
||
from collections import OrderedDict | ||
import itertools | ||
import logging | ||
|
||
import numpy | ||
|
||
from gem.gem import IndexSum, one | ||
from gem.optimise import make_sum, make_product | ||
from gem.refactorise import Monomial | ||
from gem.utils import groupby | ||
|
||
|
||
try: | ||
from firedrake import Citations | ||
Citations().register("Luporini2016") | ||
except ImportError: | ||
pass | ||
|
||
|
||
__all__ = ['optimise_monomial_sum'] | ||
|
||
|
||
def monomial_sum_to_expression(monomial_sum):
    """Build a GEM expression out of a sum of monomials.

    Monomials sharing the same set of sum indices are added together
    under a single :class:`IndexSum`; the per-group index sums are then
    added up to give the result.

    :arg monomial_sum: an iterable of :class:`Monomial`s
    :returns: GEM expression
    """
    def group_to_indexsum(group):
        # All members of a group share the same sum indices (as a set);
        # the ordering of the first member is used for the IndexSum.
        sum_indices = group[0].sum_indices
        terms = [make_product(m.atomics + (m.rest,)) for m in group]
        return IndexSum(make_sum(terms), sum_indices)

    # Bucket the monomials by their (unordered) sum indices
    grouped = groupby(monomial_sum, key=lambda m: frozenset(m.sum_indices))
    return make_sum([group_to_indexsum(group) for _, group in grouped])
|
||
|
||
def index_extent(factor, linear_indices):
    """Multiply together the extents of those free indices of a GEM
    expression that are linear indices.

    :arg factor: GEM expression
    :arg linear_indices: collection of linear indices
    :returns: product of the extents of the linear indices among the
              free indices of ``factor`` (the empty product if none)
    """
    extents = []
    for index in factor.free_indices:
        # Only linear indices contribute to the product
        if index in linear_indices:
            extents.append(index.extent)
    return numpy.prod(extents)
|
||
|
||
def find_optimal_atomics(monomials, linear_indices):
    """Select the atomics to factor out as common subexpressions.

    A selection is feasible when every monomial contains at least one
    selected atomic.  Among the feasible selections examined, the one
    with the fewest atomics wins; ties are broken in favour of the
    larger total extent over the linear indices.

    :arg monomials: A list of :class:`Monomial`s, all of which should have
                    the same sum indices
    :arg linear_indices: tuple of linear indices
    :returns: tuple of the selected atomic GEM expressions, in order of
              first appearance in ``monomials``
    """
    # Distinct atomics, in order of first appearance
    atomics = tuple(OrderedDict.fromkeys(
        itertools.chain.from_iterable(m.atomics for m in monomials)))

    def cost(candidate):
        total_extent = sum(index_extent(a, linear_indices) for a in candidate)
        # Prefer shorter solutions, but larger extents
        return (len(candidate), -total_extent)

    best = set(atomics)  # pessimal but feasible solution
    current = set()

    max_it = 1 << 12
    # Exhausting this iterator raises StopIteration, which aborts the
    # (potentially exponential) search from arbitrarily deep recursion.
    budget = iter(range(max_it))

    def search(start):
        # Skip monomials which already contain a selected atomic
        pos = start
        while pos < len(monomials) and current.intersection(monomials[pos].atomics):
            pos += 1

        if pos >= len(monomials):
            # Every monomial is covered: a complete candidate solution
            if cost(current) < cost(best):
                best.clear()
                best.update(current)
            next(budget)
            return

        # Prune: extending the selection cannot beat the best known size
        if len(current) >= len(best):
            return

        # Branch on each atomic of the first uncovered monomial
        for atomic in monomials[pos].atomics:
            current.add(atomic)
            search(pos + 1)
            current.remove(atomic)

    try:
        search(0)
    except StopIteration:
        logger = logging.getLogger('tsfc')
        logger.warning("Solution to ILP problem may not be optimal: search "
                       "interrupted after examining %d solutions.", max_it)

    return tuple(atomic for atomic in atomics if atomic in best)
|
||
|
||
def factorise_atomics(monomials, optimal_atomics, linear_indices):
    """Factor the chosen atomics out of a collection of monomials.

    Each monomial is assigned to the first optimal atomic it contains.
    Within each group that atomic is removed from every member, the
    remainders are optimised recursively, and the common atomic is
    multiplied back in.

    :arg monomials: a list of :class:`Monomial`s, all of which should have
                    the same sum indices
    :arg optimal_atomics: sequence of atomics to be used as common
                          subexpressions
    :arg linear_indices: tuple of linear indices
    :returns: a list of :class:`Monomial`s after factorisation
    """
    if not optimal_atomics or len(monomials) <= 1:
        return monomials

    def first_optimal_atomic(monomial):
        # A monomial is filed under the first optimal atomic it contains
        found = next((oa for oa in optimal_atomics if oa in monomial.atomics), None)
        assert found is not None, "Expect at least one optimal atomic per monomial."
        return found

    def without(monomial, common):
        # Strip one occurrence of the common factor; sum indices are
        # reattached by the caller once the group has been factorised.
        remaining = list(monomial.atomics)
        remaining.remove(common)
        return Monomial((), tuple(remaining), monomial.rest)

    # Group monomials with respect to each optimal atomic
    grouped = groupby(monomials, key=first_optimal_atomic)

    # We should not drop monomials
    assert sum(len(group) for _, group in grouped) == len(monomials)

    sum_indices = next(iter(monomials)).sum_indices
    factorised = []
    for common, group in grouped:
        # Continue to factorise what is left after removing the common factor
        remainders = optimise_monomials([without(m, common) for m in group],
                                        linear_indices)

        if len(remainders) == 1:
            # The factorised part is a single product: put the common
            # atomic back and keep it as a monomial rather than forming
            # a product node.  Retaining the monomial structure enables
            # applying associativity when forming GEM nodes later.
            remainder, = remainders
            factorised.append(
                Monomial(sum_indices, (common,) + remainder.atomics, remainder.rest))
            continue

        # The factorised part is a summation: create a new GEM node for
        # it and multiply with the common factor.
        node = monomial_sum_to_expression(remainders)
        # If the free indices of the new node intersect with the linear
        # indices, the node goes in as an atomic, otherwise as the rest.
        # Note: we might want to continue to factorise with the new
        # atomics by running optimise_monomials twice.
        if set(linear_indices).intersection(node.free_indices):
            factorised.append(Monomial(sum_indices, (common, node), one))
        else:
            factorised.append(Monomial(sum_indices, (common,), node))
    return factorised
|
||
|
||
def optimise_monomial_sum(monomial_sum, linear_indices):
    """Factorise a :class:`MonomialSum` and turn it into a GEM expression.

    The monomials are grouped by their sum indices, each group is
    optimised on its own, and the optimised monomials are then
    assembled into a single expression.

    :arg monomial_sum: a :class:`MonomialSum` object
    :arg linear_indices: tuple of linear indices
    :returns: factorised GEM expression
    """
    by_sum_indices = groupby(monomial_sum, key=lambda m: frozenset(m.sum_indices))
    optimised = [monomial
                 for _, group in by_sum_indices
                 for monomial in optimise_monomials(group, linear_indices)]
    return monomial_sum_to_expression(optimised)
|
||
|
||
def optimise_monomials(monomials, linear_indices):
    """Pick common atomic subexpressions and factorise a collection of
    monomials with them.

    Monomials without any atomics cannot take part in the factorisation;
    they are passed through unchanged at the front of the result.

    :arg monomials: a list of :class:`Monomial`s, all of which should have
                    the same sum indices
    :arg linear_indices: tuple of linear indices
    :returns: a list of factorised :class:`Monomial`s
    """
    distinct_sum_indices = set(frozenset(m.sum_indices) for m in monomials)
    assert len(distinct_sum_indices) <= 1,\
        "All monomials required to have same sum indices for factorisation"

    skipped = [m for m in monomials if not m.atomics]  # nothing to factor out
    active = [m for m in monomials if m.atomics]

    chosen_atomics = find_optimal_atomics(active, linear_indices)
    skipped.extend(factorise_atomics(active, chosen_atomics, linear_indices))
    return skipped
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
from __future__ import absolute_import, print_function, division | ||
|
||
import pytest | ||
|
||
from gem.gem import Delta, Identity, Index, Indexed, one | ||
from gem.optimise import delta_elimination, remove_componenttensors | ||
|
||
|
||
def test_delta_elimination():
    """Eliminating the deltas from sum_{i,j} delta(i,j) delta(i,k) I[j,k]
    must leave no sum indices and reduce the factors to 1, 1 and I[k,k].
    """
    i, j, k = Index(), Index(), Index()
    identity = Identity(3)

    sum_indices, factors = delta_elimination(
        (i, j),
        [Delta(i, j), Delta(i, k), Indexed(identity, (j, k))])
    factors = remove_componenttensors(factors)

    # Both summed indices are eliminated and replaced by k
    assert sum_indices == []
    assert factors == [one, one, Indexed(identity, (k, k))]
|
||
|
||
if __name__ == "__main__":
    import os
    import sys
    # Run this file's tests, forwarding any extra command line arguments.
    # pytest.main returns the exit code instead of exiting, so hand it to
    # sys.exit; otherwise the script reports success even on failures.
    sys.exit(pytest.main(args=[os.path.abspath(__file__)] + sys.argv[1:]))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
from __future__ import absolute_import, print_function, division | ||
from six.moves import range | ||
|
||
from functools import reduce | ||
|
||
import numpy | ||
import pytest | ||
|
||
from coffee.visitors import EstimateFlops | ||
|
||
from ufl import (Mesh, FunctionSpace, FiniteElement, VectorElement, | ||
TestFunction, TrialFunction, TensorProductCell, dx, | ||
action, interval, quadrilateral, dot, grad) | ||
|
||
from FIAT import ufc_cell | ||
from FIAT.quadrature import GaussLobattoLegendreQuadratureLineRule | ||
|
||
from finat.point_set import GaussLobattoLegendrePointSet | ||
from finat.quadrature import QuadratureRule, TensorProductQuadratureRule | ||
|
||
from tsfc import compile_form | ||
|
||
|
||
def gll_quadrature_rule(cell, elem_deg):
    """Build a tensor product Gauss-Lobatto-Legendre quadrature rule.

    One line rule, constructed with ``elem_deg + 1`` as its size
    argument, is created per topological dimension of ``cell``; the line
    rules are then combined pairwise into a tensor product rule.

    :arg cell: UFL cell; only its topological dimension is used
    :arg elem_deg: degree of the element
    :returns: FInAT quadrature rule
    """
    fiat_rule = GaussLobattoLegendreQuadratureLineRule(ufc_cell("interval"),
                                                       elem_deg + 1)

    def make_line_rule():
        point_set = GaussLobattoLegendrePointSet(fiat_rule.get_points())
        return QuadratureRule(point_set, fiat_rule.get_weights())

    dimension = cell.topological_dimension()
    line_rules = [make_line_rule() for _ in range(dimension)]
    return reduce(lambda a, b: TensorProductQuadratureRule([a, b]), line_rules)
|
||
|
||
def mass_cg(cell, degree):
    """Mass form on the 'Q' element (spectral variant), integrated with
    the Gauss-Lobatto-Legendre rule from :func:`gll_quadrature_rule`.

    :arg cell: UFL cell
    :arg degree: degree of the element
    :returns: UFL bilinear form
    """
    mesh = Mesh(VectorElement('Q', cell, 1))
    element = FiniteElement('Q', cell, degree, variant='spectral')
    space = FunctionSpace(mesh, element)
    trial = TrialFunction(space)
    test = TestFunction(space)
    integrand = trial*test
    return integrand*dx(rule=gll_quadrature_rule(cell, degree))
|
||
|
||
def mass_dg(cell, degree):
    """Mass form on the 'DQ' element (spectral variant), integrated with
    the default measure.

    :arg cell: UFL cell
    :arg degree: degree of the element
    :returns: UFL bilinear form
    """
    mesh = Mesh(VectorElement('Q', cell, 1))
    element = FiniteElement('DQ', cell, degree, variant='spectral')
    space = FunctionSpace(mesh, element)
    trial = TrialFunction(space)
    test = TestFunction(space)
    # No explicit quadrature rule here: the estimated quadrature degree
    # happens (by luck) to give the correct number of quadrature points.
    return trial*test*dx
|
||
|
||
def laplace(cell, degree):
    """Laplace form on the 'Q' element (spectral variant), integrated
    with the Gauss-Lobatto-Legendre rule from :func:`gll_quadrature_rule`.

    :arg cell: UFL cell
    :arg degree: degree of the element
    :returns: UFL bilinear form
    """
    mesh = Mesh(VectorElement('Q', cell, 1))
    element = FiniteElement('Q', cell, degree, variant='spectral')
    space = FunctionSpace(mesh, element)
    trial = TrialFunction(space)
    test = TestFunction(space)
    integrand = dot(grad(trial), grad(test))
    return integrand*dx(rule=gll_quadrature_rule(cell, degree))
|
||
|
||
def count_flops(form):
    """Compile a form in 'spectral' mode and estimate the flop count of
    the generated kernel.

    :arg form: UFL form which must compile to exactly one kernel
    :returns: flop estimate for the kernel AST
    """
    kernels = compile_form(form, parameters={'mode': 'spectral'})
    kernel, = kernels  # unpacking fails unless there is exactly one kernel
    estimator = EstimateFlops()
    return estimator.visit(kernel.ast)
|
||
|
||
@pytest.mark.parametrize('form', [mass_cg, mass_dg])
@pytest.mark.parametrize(('cell', 'order'),
                         [(quadrilateral, 2),
                          (TensorProductCell(interval, interval), 2),
                          (TensorProductCell(quadrilateral, interval), 3)])
def test_mass(form, cell, order):
    """The flop count of the mass matrix kernel must grow more slowly
    than ``(degree + 1)**order`` for degrees 4 to 9.
    """
    degrees = numpy.arange(4, 10)
    flops = []
    for degree in degrees:
        flops.append(count_flops(form(cell, int(degree))))
    # Empirical growth rate between consecutive degrees (log-log slope)
    log_flops = numpy.log(flops)
    log_sizes = numpy.log(degrees + 1)
    rates = numpy.diff(log_flops) / numpy.diff(log_sizes)
    assert (rates < order).all()
|
||
|
||
@pytest.mark.parametrize('form', [mass_cg, mass_dg])
@pytest.mark.parametrize(('cell', 'order'),
                         [(quadrilateral, 2),
                          (TensorProductCell(interval, interval), 2),
                          (TensorProductCell(quadrilateral, interval), 3)])
def test_mass_action(form, cell, order):
    """The flop count of the mass action kernel must grow more slowly
    than ``(degree + 1)**order`` for degrees 4 to 9.
    """
    degrees = numpy.arange(4, 10)
    flops = []
    for degree in degrees:
        flops.append(count_flops(action(form(cell, int(degree)))))
    # Empirical growth rate between consecutive degrees (log-log slope)
    log_flops = numpy.log(flops)
    log_sizes = numpy.log(degrees + 1)
    rates = numpy.diff(log_flops) / numpy.diff(log_sizes)
    assert (rates < order).all()
|
||
|
||
@pytest.mark.parametrize(('cell', 'order'),
                         [(quadrilateral, 4),
                          (TensorProductCell(interval, interval), 4),
                          (TensorProductCell(quadrilateral, interval), 5)])
def test_laplace(cell, order):
    """The flop count of the Laplace kernel must grow more slowly than
    ``(degree + 1)**order`` for degrees 4 to 9.
    """
    degrees = numpy.arange(4, 10)
    flops = []
    for degree in degrees:
        flops.append(count_flops(laplace(cell, int(degree))))
    # Empirical growth rate between consecutive degrees (log-log slope)
    log_flops = numpy.log(flops)
    log_sizes = numpy.log(degrees + 1)
    rates = numpy.diff(log_flops) / numpy.diff(log_sizes)
    assert (rates < order).all()
|
||
|
||
if __name__ == "__main__":
    import os
    import sys
    # Run this file's tests, forwarding any extra command line arguments.
    # pytest.main returns the exit code instead of exiting, so hand it to
    # sys.exit; otherwise the script reports success even on failures.
    sys.exit(pytest.main(args=[os.path.abspath(__file__)] + sys.argv[1:]))
Oops, something went wrong.