Skip to content

Commit

Permalink
Merge pull request #142 from firedrakeproject/underintegration
Browse files Browse the repository at this point in the history
* underintegration:
  add a short comment
  fix previous bug
  expose delta_elimination bug in test case
  change default mode: coffee -> spectral
  add some more comments
  rename argument indices to linear indices
  move COFFEE algorithm from TSFC to GEM
  test underintegration tricks
  fix Python 2 flake8
  fix failing test case
  rewrite spectral mode
  tolerate and skip monomials with no atomics
  delay index substitution in delta_elimination
  disconnect coffee_mode.Integrals from spectral.Integrals
  optimise away Delta in UFL arguments
  • Loading branch information
miklos1 committed Aug 3, 2017
2 parents 8706fe4 + a783f99 commit 9519235
Show file tree
Hide file tree
Showing 7 changed files with 512 additions and 233 deletions.
202 changes: 202 additions & 0 deletions gem/coffee.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
"""This module contains an implementation of the COFFEE optimisation
algorithm operating on a GEM representation.
This file is NOT for code generation as a COFFEE AST.
"""

from __future__ import absolute_import, print_function, division
from six.moves import map, range

from collections import OrderedDict
import itertools
import logging

import numpy

from gem.gem import IndexSum, one
from gem.optimise import make_sum, make_product
from gem.refactorise import Monomial
from gem.utils import groupby


# Firedrake is optional here: when it can be imported, register the
# "Luporini2016" citation; when it cannot, carry on without it.
try:
    from firedrake import Citations
    Citations().register("Luporini2016")
except ImportError:
    pass


__all__ = ['optimise_monomial_sum']


def monomial_sum_to_expression(monomial_sum):
    """Build a GEM expression from a collection of monomials.

    Monomials that share the same set of sum indices are added together
    under a single :class:`IndexSum`; the returned expression is the sum
    of all those index sums.

    :arg monomial_sum: an iterable of :class:`Monomial`s
    :returns: GEM expression
    """
    def as_product(monomial):
        # A monomial stands for the product of its atomics and its rest
        return make_product(monomial.atomics + (monomial.rest,))

    # One group per distinct set of sum indices
    by_sum_indices = groupby(monomial_sum, key=lambda m: frozenset(m.sum_indices))

    # The sum indices of the first member are used for the whole group
    return make_sum([
        IndexSum(make_sum([as_product(member) for member in group]),
                 group[0].sum_indices)
        for _, group in by_sum_indices
    ])


def index_extent(factor, linear_indices):
    """Multiply together the extents of the linear free indices of a
    GEM expression.

    :arg factor: GEM expression
    :arg linear_indices: set of linear indices
    :returns: product of the extents of those free indices of
              ``factor`` that are contained in ``linear_indices``
    """
    extents = []
    for index in factor.free_indices:
        # Only free indices that are also linear indices contribute
        if index in linear_indices:
            extents.append(index.extent)
    return numpy.prod(extents)


def find_optimal_atomics(monomials, linear_indices):
    """Select atomic common subexpressions such that every monomial
    contains at least one selected atomic, preferring selections that
    yield the fewest terms in the resulting IndexSum when factorised.

    Selections are compared by ``(number of atomics, -total extent)``,
    so fewer atomics win, and ties are broken in favour of the larger
    sum of linear index extents.

    :arg monomials: A list of :class:`Monomial`s, all of which should
                    have the same sum indices
    :arg linear_indices: tuple of linear indices
    :returns: tuple of atomic GEM expressions, in order of first
              appearance in ``monomials``
    """
    # Unique atomics, keeping the order in which they first occur
    candidates = tuple(OrderedDict.fromkeys(
        itertools.chain.from_iterable(m.atomics for m in monomials)))

    def cost(chosen):
        total_extent = sum(index_extent(atomic, linear_indices)
                           for atomic in chosen)
        # Prefer shorter solutions, but larger extents
        return (len(chosen), -total_extent)

    best = set(candidates)  # pessimal but feasible solution
    current = set()         # selection under construction

    # Budget on the number of complete selections examined
    max_it = 1 << 12
    budget = iter(range(max_it))

    def search(position):
        # Skip monomials already covered by the current selection
        while (position < len(monomials)
               and not current.isdisjoint(monomials[position].atomics)):
            position += 1

        if position >= len(monomials):
            # Every monomial is covered: a complete selection
            if cost(current) < cost(best):
                best.clear()
                best.update(current)
            # Raises StopIteration once the budget is used up
            next(budget)
            return

        if len(current) >= len(best):
            # Adding another atomic cannot give a shorter selection
            return

        # Branch on which atomic covers the first uncovered monomial
        for atomic in monomials[position].atomics:
            current.add(atomic)
            search(position + 1)
            current.remove(atomic)

    try:
        search(0)
    except StopIteration:
        logger = logging.getLogger('tsfc')
        logger.warning("Solution to ILP problem may not be optimal: search "
                       "interrupted after examining %d solutions.", max_it)

    return tuple(atomic for atomic in candidates if atomic in best)


def factorise_atomics(monomials, optimal_atomics, linear_indices):
    """Factorise monomials, using the given atomics as common
    subexpressions.

    Each monomial is assigned to the first optimal atomic it contains.
    Within each group the common atomic is factored out, and the
    remaining monomials are optimised recursively.

    :arg monomials: a collection of :class:`Monomial`s, all of which
                    should have the same sum indices
    :arg optimal_atomics: atomics to be used as common subexpressions
    :arg linear_indices: tuple of linear indices
    :returns: the input itself when there is nothing to factorise,
              otherwise a list of :class:`Monomial`s after factorisation
    """
    if not optimal_atomics or len(monomials) <= 1:
        # Nothing to factor out
        return monomials

    def common_factor_of(monomial):
        # The first optimal atomic found in the monomial decides its group
        for candidate in optimal_atomics:
            if candidate in monomial.atomics:
                return candidate
        assert False, "Expect at least one optimal atomic per monomial."

    grouped = groupby(monomials, key=common_factor_of)

    # We should not drop monomials
    assert sum(len(members) for _, members in grouped) == len(monomials)

    def without_factor(monomial, factor):
        # Same monomial with one occurrence of the common factor removed,
        # and without sum indices
        remaining = list(monomial.atomics)
        remaining.remove(factor)
        return Monomial((), tuple(remaining), monomial.rest)

    sum_indices = next(iter(monomials)).sum_indices
    factorised = []
    for common, members in grouped:
        # Factor the common atomic out of the group, then continue
        # factorising whatever is left
        remainders = optimise_monomials(
            [without_factor(member, common) for member in members],
            linear_indices)

        if len(remainders) != 1:
            # The factorised part is a summation: build a GEM node for it
            # and multiply it with the common factor.
            node = monomial_sum_to_expression(remainders)
            # The new node becomes an `atomic` when its free indices
            # intersect the linear indices, otherwise it becomes `rest`.
            # Note: we might want to continue to factorise with the new
            # atomics by running optimise_monomials twice.
            if set(linear_indices) & set(node.free_indices):
                factorised.append(Monomial(sum_indices, (common, node), one))
            else:
                factorised.append(Monomial(sum_indices, (common,), node))
        else:
            # The factorised part is a plain product: put the common
            # atomic back rather than forming a product node. Retaining
            # the monomial structure enables applying associativity when
            # forming GEM nodes later.
            (remainder,) = remainders
            factorised.append(Monomial(sum_indices,
                                       (common,) + remainder.atomics,
                                       remainder.rest))
    return factorised


def optimise_monomial_sum(monomial_sum, linear_indices):
    """Factorise a :class:`MonomialSum` and turn it into a GEM expression.

    The monomials are grouped by their set of sum indices, each group is
    optimised on its own, and the optimised monomials of all groups are
    converted to a single GEM expression.

    :arg monomial_sum: a :class:`MonomialSum` object
    :arg linear_indices: tuple of linear indices
    :returns: factorised GEM expression
    """
    by_sum_indices = groupby(monomial_sum, key=lambda m: frozenset(m.sum_indices))
    optimised = [monomial
                 for _, group in by_sum_indices
                 for monomial in optimise_monomials(group, linear_indices)]
    return monomial_sum_to_expression(optimised)


def optimise_monomials(monomials, linear_indices):
    """Pick optimal common atomic subexpressions for a collection of
    monomials and factorise them.

    Monomials without atomics cannot be factorised; they are passed
    through unchanged and placed in front of the factorised ones.

    :arg monomials: a list of :class:`Monomial`s, all of which should have
                    the same sum indices
    :arg linear_indices: tuple of linear indices
    :returns: a list of factorised :class:`Monomial`s
    """
    assert len(set(frozenset(m.sum_indices) for m in monomials)) <= 1,\
        "All monomials required to have same sum indices for factorisation"

    skipped = []  # monomials with no atomics
    active = []   # monomials taking part in the factorisation
    for monomial in monomials:
        if monomial.atomics:
            active.append(monomial)
        else:
            skipped.append(monomial)

    chosen_atomics = find_optimal_atomics(active, linear_indices)
    skipped.extend(factorise_atomics(active, chosen_atomics, linear_indices))
    return skipped
16 changes: 13 additions & 3 deletions gem/optimise.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,15 @@ def delta_elimination(sum_indices, factors):
"""
sum_indices = list(sum_indices) # copy for modification

def substitute(expression, from_, to_):
    """Replace the free index ``from_`` by ``to_`` in ``expression``.

    :arg expression: GEM expression
    :arg from_: index to be replaced
    :arg to_: index to replace it with
    :returns: ``expression`` itself when ``from_`` is not among its
              free indices; otherwise an expression with the index
              replaced
    """
    if from_ not in expression.free_indices:
        # Nothing to replace
        return expression

    if not isinstance(expression, Delta):
        # Do not rewrite the expression now: express the substitution
        # as an Indexed ComponentTensor instead.
        return Indexed(ComponentTensor(expression, (from_,)), (to_,))

    # Deltas have their indices replaced directly
    replace_indices = MemoizerArg(filtered_replace_indices)
    return replace_indices(expression, ((from_, to_),))

delta_queue = [(f, index)
for f in factors if isinstance(f, Delta)
for index in (f.i, f.j) if index in sum_indices]
Expand All @@ -251,8 +260,7 @@ def delta_elimination(sum_indices, factors):

sum_indices.remove(from_)

mapper = MemoizerArg(filtered_replace_indices)
factors = [mapper(e, ((from_, to_),)) for e in factors]
factors = [substitute(f, from_, to_) for f in factors]

delta_queue = [(f, index)
for f in factors if isinstance(f, Delta)
Expand Down Expand Up @@ -492,7 +500,9 @@ def contraction(expression):

# Flatten product tree, eliminate deltas, sum factorise
def rebuild(expression):
    """Flatten the product tree, eliminate deltas, then sum factorise.

    :arg expression: GEM expression to rebuild
    :returns: the result of ``sum_factorise`` applied to the sum
              indices and factors left after delta elimination
    """
    sum_indices, factors = delta_elimination(*traverse_product(expression))
    # delta_elimination may return factors of the form
    # Indexed(ComponentTensor(...)); remove_componenttensors is applied
    # before sum factorisation (presumably to resolve those wrappers).
    factors = remove_componenttensors(factors)
    return sum_factorise(sum_indices, factors)

# Sometimes the value shape is composed as a ListTensor, which
# could get in the way of decomposing factors. In particular,
Expand Down
28 changes: 28 additions & 0 deletions tests/test_delta_elimination.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from __future__ import absolute_import, print_function, division

import pytest

from gem.gem import Delta, Identity, Index, Indexed, one
from gem.optimise import delta_elimination, remove_componenttensors


def test_delta_elimination():
    """Eliminating both sum indices from a product of two deltas and an
    indexed identity leaves no sum indices, and leaves the identity
    indexed by ``(k, k)``.
    """
    i, j, k = Index(), Index(), Index()
    identity = Identity(3)

    remaining_indices, factors = delta_elimination(
        (i, j),
        [Delta(i, j), Delta(i, k), Indexed(identity, (j, k))])
    factors = remove_componenttensors(factors)

    # Both sum indices are gone, and both deltas have collapsed to one
    assert remaining_indices == []
    assert factors == [one, one, Indexed(identity, (k, k))]


if __name__ == "__main__":
    # Run this file's tests through pytest when executed as a script,
    # forwarding any extra command line arguments.
    import os
    import sys
    this_file = os.path.abspath(__file__)
    extra_args = sys.argv[1:]
    pytest.main(args=[this_file] + extra_args)
103 changes: 103 additions & 0 deletions tests/test_underintegration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
from __future__ import absolute_import, print_function, division
from six.moves import range

from functools import reduce

import numpy
import pytest

from coffee.visitors import EstimateFlops

from ufl import (Mesh, FunctionSpace, FiniteElement, VectorElement,
TestFunction, TrialFunction, TensorProductCell, dx,
action, interval, quadrilateral, dot, grad)

from FIAT import ufc_cell
from FIAT.quadrature import GaussLobattoLegendreQuadratureLineRule

from finat.point_set import GaussLobattoLegendrePointSet
from finat.quadrature import QuadratureRule, TensorProductQuadratureRule

from tsfc import compile_form


def gll_quadrature_rule(cell, elem_deg):
    """Build a tensor product of Gauss-Lobatto-Legendre line rules.

    A single GLL line rule is created on the UFC interval with
    ``elem_deg + 1`` passed as its second argument (presumably the
    number of points -- confirm against FIAT). One FInAT line rule per
    topological dimension of ``cell`` is made from it, and the line
    rules are combined pairwise, left to right, into a tensor product
    rule.

    :arg cell: cell whose topological dimension sets the number of
               line rules
    :arg elem_deg: element degree
    :returns: FInAT quadrature rule
    """
    reference_line = ufc_cell("interval")
    gll_line = GaussLobattoLegendreQuadratureLineRule(reference_line, elem_deg + 1)

    def make_line_rule():
        point_set = GaussLobattoLegendrePointSet(gll_line.get_points())
        return QuadratureRule(point_set, gll_line.get_weights())

    def tensor_product(left, right):
        return TensorProductQuadratureRule([left, right])

    per_direction = [make_line_rule()
                     for _ in range(cell.topological_dimension())]
    return reduce(tensor_product, per_direction)


def mass_cg(cell, degree):
    """Mass form on a spectral-variant 'Q' element, integrated with the
    rule returned by ``gll_quadrature_rule``.

    :arg cell: UFL cell
    :arg degree: element degree
    :returns: UFL form
    """
    mesh = Mesh(VectorElement('Q', cell, 1))
    element = FiniteElement('Q', cell, degree, variant='spectral')
    space = FunctionSpace(mesh, element)
    trial = TrialFunction(space)
    test = TestFunction(space)
    integrand = trial*test
    return integrand*dx(rule=gll_quadrature_rule(cell, degree))


def mass_dg(cell, degree):
    """Mass form on a spectral-variant 'DQ' element, integrated with the
    default measure.

    :arg cell: UFL cell
    :arg degree: element degree
    :returns: UFL form
    """
    mesh = Mesh(VectorElement('Q', cell, 1))
    element = FiniteElement('DQ', cell, degree, variant='spectral')
    space = FunctionSpace(mesh, element)
    trial = TrialFunction(space)
    test = TestFunction(space)
    # No explicit quadrature rule is given: in this case, the estimated
    # quadrature degree will give the correct number of quadrature
    # points by luck.
    return trial*test*dx


def laplace(cell, degree):
    """Laplace form on a spectral-variant 'Q' element, integrated with
    the rule returned by ``gll_quadrature_rule``.

    :arg cell: UFL cell
    :arg degree: element degree
    :returns: UFL form
    """
    mesh = Mesh(VectorElement('Q', cell, 1))
    element = FiniteElement('Q', cell, degree, variant='spectral')
    space = FunctionSpace(mesh, element)
    trial = TrialFunction(space)
    test = TestFunction(space)
    integrand = dot(grad(trial), grad(test))
    return integrand*dx(rule=gll_quadrature_rule(cell, degree))


def count_flops(form):
    """Compile ``form`` in 'spectral' mode and estimate the flops of the
    generated kernel.

    The form is expected to compile to exactly one kernel.

    :arg form: UFL form
    :returns: result of ``EstimateFlops`` visiting the kernel AST
    """
    kernels = compile_form(form, parameters={'mode': 'spectral'})
    (kernel,) = kernels  # fails unless there is exactly one kernel
    estimator = EstimateFlops()
    return estimator.visit(kernel.ast)


@pytest.mark.parametrize('form', [mass_cg, mass_dg])
@pytest.mark.parametrize(('cell', 'order'),
                         [(quadrilateral, 2),
                          (TensorProductCell(interval, interval), 2),
                          (TensorProductCell(quadrilateral, interval), 3)])
def test_mass(form, cell, order):
    """The flop count of the mass matrix kernel must grow slower than
    ``(degree + 1)**order`` between consecutive degrees 4 to 9.
    """
    degrees = numpy.arange(4, 10)
    flops = []
    for degree in degrees:
        flops.append(count_flops(form(cell, int(degree))))

    # Observed growth rate: slope of log(flops) against log(degree + 1)
    log_flops = numpy.log(flops)
    log_size = numpy.log(degrees + 1)
    rates = numpy.diff(log_flops) / numpy.diff(log_size)
    assert numpy.all(rates < order)


@pytest.mark.parametrize('form', [mass_cg, mass_dg])
@pytest.mark.parametrize(('cell', 'order'),
                         [(quadrilateral, 2),
                          (TensorProductCell(interval, interval), 2),
                          (TensorProductCell(quadrilateral, interval), 3)])
def test_mass_action(form, cell, order):
    """The flop count of the mass action kernel must grow slower than
    ``(degree + 1)**order`` between consecutive degrees 4 to 9.
    """
    degrees = numpy.arange(4, 10)
    flops = []
    for degree in degrees:
        flops.append(count_flops(action(form(cell, int(degree)))))

    # Observed growth rate: slope of log(flops) against log(degree + 1)
    log_flops = numpy.log(flops)
    log_size = numpy.log(degrees + 1)
    rates = numpy.diff(log_flops) / numpy.diff(log_size)
    assert numpy.all(rates < order)


@pytest.mark.parametrize(('cell', 'order'),
                         [(quadrilateral, 4),
                          (TensorProductCell(interval, interval), 4),
                          (TensorProductCell(quadrilateral, interval), 5)])
def test_laplace(cell, order):
    """The flop count of the Laplace kernel must grow slower than
    ``(degree + 1)**order`` between consecutive degrees 4 to 9.
    """
    degrees = numpy.arange(4, 10)
    flops = []
    for degree in degrees:
        flops.append(count_flops(laplace(cell, int(degree))))

    # Observed growth rate: slope of log(flops) against log(degree + 1)
    log_flops = numpy.log(flops)
    log_size = numpy.log(degrees + 1)
    rates = numpy.diff(log_flops) / numpy.diff(log_size)
    assert numpy.all(rates < order)


if __name__ == "__main__":
    # Run this file's tests through pytest when executed as a script,
    # forwarding any extra command line arguments.
    import os
    import sys
    this_file = os.path.abspath(__file__)
    extra_args = sys.argv[1:]
    pytest.main(args=[this_file] + extra_args)
Loading

0 comments on commit 9519235

Please sign in to comment.