diff --git a/gem/coffee.py b/gem/coffee.py new file mode 100644 index 00000000..6415e28d --- /dev/null +++ b/gem/coffee.py @@ -0,0 +1,202 @@ +"""This module contains an implementation of the COFFEE optimisation +algorithm operating on a GEM representation. + +This file is NOT for code generation as a COFFEE AST. +""" + +from __future__ import absolute_import, print_function, division +from six.moves import map, range + +from collections import OrderedDict +import itertools +import logging + +import numpy + +from gem.gem import IndexSum, one +from gem.optimise import make_sum, make_product +from gem.refactorise import Monomial +from gem.utils import groupby + + +try: + from firedrake import Citations + Citations().register("Luporini2016") +except ImportError: + pass + + +__all__ = ['optimise_monomial_sum'] + + +def monomial_sum_to_expression(monomial_sum): + """Convert a monomial sum to a GEM expression. + + :arg monomial_sum: an iterable of :class:`Monomial`s + + :returns: GEM expression + """ + indexsums = [] # The result is summation of indexsums + # Group monomials according to their sum indices + groups = groupby(monomial_sum, key=lambda m: frozenset(m.sum_indices)) + # Create IndexSum's from each monomial group + for _, monomials in groups: + sum_indices = monomials[0].sum_indices + products = [make_product(monomial.atomics + (monomial.rest,)) for monomial in monomials] + indexsums.append(IndexSum(make_sum(products), sum_indices)) + return make_sum(indexsums) + + +def index_extent(factor, linear_indices): + """Compute the product of the extents of linear indices of a GEM expression + + :arg factor: GEM expression + :arg linear_indices: set of linear indices + + :returns: product of extents of linear indices + """ + return numpy.prod([i.extent for i in factor.free_indices if i in linear_indices]) + + +def find_optimal_atomics(monomials, linear_indices): + """Find optimal atomic common subexpressions, which produce least number of + terms in the resultant IndexSum when factorised. + + :arg monomials: A list of :class:`Monomial`s, all of which should have + the same sum indices + :arg linear_indices: tuple of linear indices + + :returns: list of atomic GEM expressions + """ + atomics = tuple(OrderedDict.fromkeys(itertools.chain(*(monomial.atomics for monomial in monomials)))) + + def cost(solution): + extent = sum(map(lambda atomic: index_extent(atomic, linear_indices), solution)) + # Prefer shorter solutions, but larger extents + return (len(solution), -extent) + + optimal_solution = set(atomics) # pessimal but feasible solution + solution = set() + + max_it = 1 << 12 + it = iter(range(max_it)) + + def solve(idx): + while idx < len(monomials) and solution.intersection(monomials[idx].atomics): + idx += 1 + + if idx < len(monomials): + if len(solution) < len(optimal_solution): + for atomic in monomials[idx].atomics: + solution.add(atomic) + solve(idx + 1) + solution.remove(atomic) + else: + if cost(solution) < cost(optimal_solution): + optimal_solution.clear() + optimal_solution.update(solution) + next(it) + + try: + solve(0) + except StopIteration: + logger = logging.getLogger('tsfc') + logger.warning("Solution to ILP problem may not be optimal: search " + "interrupted after examining %d solutions.", max_it) + + return tuple(atomic for atomic in atomics if atomic in optimal_solution) + + +def factorise_atomics(monomials, optimal_atomics, linear_indices): + """Group and factorise monomials using a list of atomics as common + subexpressions. 
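Illustrative aside, not part of the API description above: the search in
``find_optimal_atomics`` is a small branch-and-bound over hitting sets, i.e.
choose at least one atomic from every monomial while minimising ``cost``. A
rough standalone sketch of the same idea on plain Python sets follows;
``toy_optimal_atomics`` and its inputs are hypothetical, not part of this
module.

def toy_optimal_atomics(monomials, extent):
    # Each "monomial" is a set of atomic labels; extent gives their weights.
    best = set().union(*monomials)            # pessimal but feasible solution
    partial = set()

    def cost(solution):
        # Prefer fewer atomics; break ties by preferring larger total extent.
        return (len(solution), -sum(extent[a] for a in solution))

    def solve(idx):
        # Skip monomials already hit by the current partial solution.
        while idx < len(monomials) and partial & monomials[idx]:
            idx += 1
        if idx < len(monomials):
            if len(partial) < len(best):      # prune branches that cannot win
                for atomic in monomials[idx]:
                    partial.add(atomic)
                    solve(idx + 1)
                    partial.remove(atomic)
        elif cost(partial) < cost(best):
            best.clear()
            best.update(partial)

    solve(0)
    return best

print(toy_optimal_atomics([{'u', 'v'}, {'u', 'w'}, {'v', 'w'}],
                          extent={'u': 8, 'v': 4, 'w': 2}))
# -> {'u', 'v'}: the fewest atomics, preferring the largest extents.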
Create new monomials for each group and optimise them recursively. + + :arg monomials: an iterable of :class:`Monomial`s, all of which should have + the same sum indices + :arg optimal_atomics: list of tuples of atomics to be used as common subexpression + :arg linear_indices: tuple of linear indices + + :returns: an iterable of :class:`Monomials`s after factorisation + """ + if not optimal_atomics or len(monomials) <= 1: + return monomials + + # Group monomials with respect to each optimal atomic + def group_key(monomial): + for oa in optimal_atomics: + if oa in monomial.atomics: + return oa + assert False, "Expect at least one optimal atomic per monomial." + factor_group = groupby(monomials, key=group_key) + + # We should not drop monomials + assert sum(len(ms) for _, ms in factor_group) == len(monomials) + + sum_indices = next(iter(monomials)).sum_indices + new_monomials = [] + for oa, monomials in factor_group: + # Create new MonomialSum for the factorised out terms + sub_monomials = [] + for monomial in monomials: + atomics = list(monomial.atomics) + atomics.remove(oa) # remove common factor + sub_monomials.append(Monomial((), tuple(atomics), monomial.rest)) + # Continue to factorise the remaining expression + sub_monomials = optimise_monomials(sub_monomials, linear_indices) + if len(sub_monomials) == 1: + # Factorised part is a product, we add back the common atomics then + # add to new MonomialSum directly rather than forming a product node + # Retaining the monomial structure enables applying associativity + # when forming GEM nodes later. + sub_monomial, = sub_monomials + new_monomials.append( + Monomial(sum_indices, (oa,) + sub_monomial.atomics, sub_monomial.rest)) + else: + # Factorised part is a summation, we need to create a new GEM node + # and multiply with the common factor + node = monomial_sum_to_expression(sub_monomials) + # If the free indices of the new node intersect with linear indices, + # add to the new monomial as `atomic`, otherwise add as `rest`. + # Note: we might want to continue to factorise with the new atomics + # by running optimise_monoials twice. + if set(linear_indices) & set(node.free_indices): + new_monomials.append(Monomial(sum_indices, (oa, node), one)) + else: + new_monomials.append(Monomial(sum_indices, (oa, ), node)) + return new_monomials + + +def optimise_monomial_sum(monomial_sum, linear_indices): + """Choose optimal common atomic subexpressions and factorise a + :class:`MonomialSum` object to create a GEM expression. + + :arg monomial_sum: a :class:`MonomialSum` object + :arg linear_indices: tuple of linear indices + + :returns: factorised GEM expression + """ + groups = groupby(monomial_sum, key=lambda m: frozenset(m.sum_indices)) + new_monomials = [] + for _, monomials in groups: + new_monomials.extend(optimise_monomials(monomials, linear_indices)) + return monomial_sum_to_expression(new_monomials) + + +def optimise_monomials(monomials, linear_indices): + """Choose optimal common atomic subexpressions and factorise an iterable + of monomials. 
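Aside: the payoff of factorising out an optimal atomic is plain
distributivity. With hypothetical factors A, B, C indexed by j of extent n,
evaluating sum_j A[j]*B[j] + A[j]*C[j] takes 2n multiplications, whereas the
factorised form sum_j A[j]*(B[j] + C[j]) takes n. A quick numpy sanity check
of the equivalence, using made-up data:

import numpy

n = 8
A, B, C = numpy.random.rand(3, n)
unfactorised = (A * B + A * C).sum()      # 2*n multiplications
factorised = (A * (B + C)).sum()          # n multiplications
assert numpy.isclose(unfactorised, factorised)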
+ + :arg monomials: a list of :class:`Monomial`s, all of which should have + the same sum indices + :arg linear_indices: tuple of linear indices + + :returns: an iterable of factorised :class:`Monomials`s + """ + assert len(set(frozenset(m.sum_indices) for m in monomials)) <= 1,\ + "All monomials required to have same sum indices for factorisation" + + result = [m for m in monomials if not m.atomics] # skipped monomials + active_monomials = [m for m in monomials if m.atomics] + optimal_atomics = find_optimal_atomics(active_monomials, linear_indices) + result += factorise_atomics(active_monomials, optimal_atomics, linear_indices) + return result diff --git a/gem/optimise.py b/gem/optimise.py index 11979fe1..d19006de 100644 --- a/gem/optimise.py +++ b/gem/optimise.py @@ -242,6 +242,15 @@ def delta_elimination(sum_indices, factors): """ sum_indices = list(sum_indices) # copy for modification + def substitute(expression, from_, to_): + if from_ not in expression.free_indices: + return expression + elif isinstance(expression, Delta): + mapper = MemoizerArg(filtered_replace_indices) + return mapper(expression, ((from_, to_),)) + else: + return Indexed(ComponentTensor(expression, (from_,)), (to_,)) + delta_queue = [(f, index) for f in factors if isinstance(f, Delta) for index in (f.i, f.j) if index in sum_indices] @@ -251,8 +260,7 @@ def delta_elimination(sum_indices, factors): sum_indices.remove(from_) - mapper = MemoizerArg(filtered_replace_indices) - factors = [mapper(e, ((from_, to_),)) for e in factors] + factors = [substitute(f, from_, to_) for f in factors] delta_queue = [(f, index) for f in factors if isinstance(f, Delta) @@ -492,7 +500,9 @@ def contraction(expression): # Flatten product tree, eliminate deltas, sum factorise def rebuild(expression): - return sum_factorise(*delta_elimination(*traverse_product(expression))) + sum_indices, factors = delta_elimination(*traverse_product(expression)) + factors = remove_componenttensors(factors) + return sum_factorise(sum_indices, factors) # Sometimes the value shape is composed as a ListTensor, which # could get in the way of decomposing factors. 
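Aside on the new ``substitute`` helper above: for factors that are not Delta
nodes it renames a summation index by binding it with ``ComponentTensor`` and
re-indexing the result, and the wrappers are collapsed again later by
``remove_componenttensors`` (see the rebuilt ``contraction`` above and the new
unit test below). A rough self-contained sketch of that round trip, assuming
gem's structural equality behaves as it does in the tests:

from gem.gem import ComponentTensor, Identity, Index, Indexed
from gem.optimise import remove_componenttensors

i, j, k = Index(), Index(), Index()
A = Identity(3)

# Rename i -> k in A[i, j]: bind i with ComponentTensor, then re-index with k.
renamed = Indexed(ComponentTensor(Indexed(A, (i, j)), (i,)), (k,))

# remove_componenttensors strips the wrapper, substituting the index.
assert remove_componenttensors([renamed]) == [Indexed(A, (k, j))]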
In particular, diff --git a/tests/test_delta_elimination.py b/tests/test_delta_elimination.py new file mode 100644 index 00000000..df76629f --- /dev/null +++ b/tests/test_delta_elimination.py @@ -0,0 +1,28 @@ +from __future__ import absolute_import, print_function, division + +import pytest + +from gem.gem import Delta, Identity, Index, Indexed, one +from gem.optimise import delta_elimination, remove_componenttensors + + +def test_delta_elimination(): + i = Index() + j = Index() + k = Index() + I = Identity(3) + + sum_indices = (i, j) + factors = [Delta(i, j), Delta(i, k), Indexed(I, (j, k))] + + sum_indices, factors = delta_elimination(sum_indices, factors) + factors = remove_componenttensors(factors) + + assert sum_indices == [] + assert factors == [one, one, Indexed(I, (k, k))] + + +if __name__ == "__main__": + import os + import sys + pytest.main(args=[os.path.abspath(__file__)] + sys.argv[1:]) diff --git a/tests/test_underintegration.py b/tests/test_underintegration.py new file mode 100644 index 00000000..9bd2cc1c --- /dev/null +++ b/tests/test_underintegration.py @@ -0,0 +1,103 @@ +from __future__ import absolute_import, print_function, division +from six.moves import range + +from functools import reduce + +import numpy +import pytest + +from coffee.visitors import EstimateFlops + +from ufl import (Mesh, FunctionSpace, FiniteElement, VectorElement, + TestFunction, TrialFunction, TensorProductCell, dx, + action, interval, quadrilateral, dot, grad) + +from FIAT import ufc_cell +from FIAT.quadrature import GaussLobattoLegendreQuadratureLineRule + +from finat.point_set import GaussLobattoLegendrePointSet +from finat.quadrature import QuadratureRule, TensorProductQuadratureRule + +from tsfc import compile_form + + +def gll_quadrature_rule(cell, elem_deg): + fiat_cell = ufc_cell("interval") + fiat_rule = GaussLobattoLegendreQuadratureLineRule(fiat_cell, elem_deg + 1) + line_rules = [QuadratureRule(GaussLobattoLegendrePointSet(fiat_rule.get_points()), + fiat_rule.get_weights()) + for _ in range(cell.topological_dimension())] + finat_rule = reduce(lambda a, b: TensorProductQuadratureRule([a, b]), line_rules) + return finat_rule + + +def mass_cg(cell, degree): + m = Mesh(VectorElement('Q', cell, 1)) + V = FunctionSpace(m, FiniteElement('Q', cell, degree, variant='spectral')) + u = TrialFunction(V) + v = TestFunction(V) + return u*v*dx(rule=gll_quadrature_rule(cell, degree)) + + +def mass_dg(cell, degree): + m = Mesh(VectorElement('Q', cell, 1)) + V = FunctionSpace(m, FiniteElement('DQ', cell, degree, variant='spectral')) + u = TrialFunction(V) + v = TestFunction(V) + # In this case, the estimated quadrature degree will give the + # correct number of quadrature points by luck. 
+ return u*v*dx + + +def laplace(cell, degree): + m = Mesh(VectorElement('Q', cell, 1)) + V = FunctionSpace(m, FiniteElement('Q', cell, degree, variant='spectral')) + u = TrialFunction(V) + v = TestFunction(V) + return dot(grad(u), grad(v))*dx(rule=gll_quadrature_rule(cell, degree)) + + +def count_flops(form): + kernel, = compile_form(form, parameters=dict(mode='spectral')) + return EstimateFlops().visit(kernel.ast) + + +@pytest.mark.parametrize('form', [mass_cg, mass_dg]) +@pytest.mark.parametrize(('cell', 'order'), + [(quadrilateral, 2), + (TensorProductCell(interval, interval), 2), + (TensorProductCell(quadrilateral, interval), 3)]) +def test_mass(form, cell, order): + degrees = numpy.arange(4, 10) + flops = [count_flops(form(cell, int(degree))) for degree in degrees] + rates = numpy.diff(numpy.log(flops)) / numpy.diff(numpy.log(degrees + 1)) + assert (rates < order).all() + + +@pytest.mark.parametrize('form', [mass_cg, mass_dg]) +@pytest.mark.parametrize(('cell', 'order'), + [(quadrilateral, 2), + (TensorProductCell(interval, interval), 2), + (TensorProductCell(quadrilateral, interval), 3)]) +def test_mass_action(form, cell, order): + degrees = numpy.arange(4, 10) + flops = [count_flops(action(form(cell, int(degree)))) for degree in degrees] + rates = numpy.diff(numpy.log(flops)) / numpy.diff(numpy.log(degrees + 1)) + assert (rates < order).all() + + +@pytest.mark.parametrize(('cell', 'order'), + [(quadrilateral, 4), + (TensorProductCell(interval, interval), 4), + (TensorProductCell(quadrilateral, interval), 5)]) +def test_laplace(cell, order): + degrees = numpy.arange(4, 10) + flops = [count_flops(laplace(cell, int(degree))) for degree in degrees] + rates = numpy.diff(numpy.log(flops)) / numpy.diff(numpy.log(degrees + 1)) + assert (rates < order).all() + + +if __name__ == "__main__": + import os + import sys + pytest.main(args=[os.path.abspath(__file__)] + sys.argv[1:]) diff --git a/tsfc/coffee_mode.py b/tsfc/coffee_mode.py index c80298f3..2dc62174 100644 --- a/tsfc/coffee_mode.py +++ b/tsfc/coffee_mode.py @@ -1,23 +1,40 @@ from __future__ import absolute_import, print_function, division -from six.moves import map, range, zip +from six.moves import zip -import numpy -import itertools from functools import partial, reduce -from collections import OrderedDict -from gem.optimise import make_sum, make_product -from gem.refactorise import Monomial, collect_monomials -from gem.unconcatenate import unconcatenate + from gem.node import traversal -from gem.gem import IndexSum, Failure, Sum, one +from gem.gem import Failure, Sum, index_sum +from gem.optimise import replace_division, unroll_indexsum +from gem.refactorise import collect_monomials +from gem.unconcatenate import unconcatenate +from gem.coffee import optimise_monomial_sum from gem.utils import groupby -from tsfc.logging import logger - import tsfc.spectral as spectral -Integrals = spectral.Integrals +def Integrals(expressions, quadrature_multiindex, argument_multiindices, parameters): + """Constructs an integral representation for each GEM integrand + expression. 
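Aside on the flop-scaling assertions in the tests above: the observed
algorithmic order is estimated as the slope of log(flops) against
log(degree + 1), so a kernel costing O((p + 1)^k) flops yields a slope close
to k. A minimal self-contained check with synthetic counts:

import numpy

degrees = numpy.arange(4, 10)
flops = 10 * (degrees + 1)**3     # synthetic counts growing like (p + 1)^3
rates = numpy.diff(numpy.log(flops)) / numpy.diff(numpy.log(degrees + 1))
assert numpy.allclose(rates, 3)   # the log-log slope recovers the order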
+ + :arg expressions: integrand multiplied with quadrature weight; + multi-root GEM expression DAG + :arg quadrature_multiindex: quadrature multiindex (tuple) + :arg argument_multiindices: tuple of argument multiindices, + one multiindex for each argument + :arg parameters: parameters dictionary + + :returns: list of integral representations + """ + # Unroll + max_extent = parameters["unroll_indexsum"] + if max_extent: + def predicate(index): + return index.extent <= max_extent + expressions = unroll_indexsum(expressions, predicate=predicate) + # Integral representation: just a GEM expression + return replace_division([index_sum(e, quadrature_multiindex) for e in expressions]) def flatten(var_reps, index_cache): @@ -29,11 +46,6 @@ def flatten(var_reps, index_cache): :returns: series of (return variable, GEM expression root) pairs """ - try: - from firedrake import Citations - Citations().register("Luporini2016") - except ImportError: - pass assignments = unconcatenate([(variable, reduce(Sum, reps)) for variable, reps in var_reps], cache=index_cache) @@ -69,172 +81,3 @@ def optimise_expressions(expressions, argument_indices): classifier = partial(spectral.classify, set(argument_indices)) monomial_sums = collect_monomials(expressions, classifier) return [optimise_monomial_sum(ms, argument_indices) for ms in monomial_sums] - - -def monomial_sum_to_expression(monomial_sum): - """Convert a monomial sum to a GEM expression. - - :arg monomial_sum: an iterable of :class:`Monomial`s - - :returns: GEM expression - """ - indexsums = [] # The result is summation of indexsums - # Group monomials according to their sum indices - groups = groupby(monomial_sum, key=lambda m: frozenset(m.sum_indices)) - # Create IndexSum's from each monomial group - for _, monomials in groups: - sum_indices = monomials[0].sum_indices - products = [make_product(monomial.atomics + (monomial.rest,)) for monomial in monomials] - indexsums.append(IndexSum(make_sum(products), sum_indices)) - return make_sum(indexsums) - - -def index_extent(factor, argument_indices): - """Compute the product of the extents of argument indices of a GEM expression - - :arg factor: GEM expression - :arg argument_indices: set of argument indices - - :returns: product of extents of argument indices - """ - return numpy.prod([i.extent for i in factor.free_indices if i in argument_indices]) - - -def find_optimal_atomics(monomials, argument_indices): - """Find optimal atomic common subexpressions, which produce least number of - terms in the resultant IndexSum when factorised. 
- - :arg monomials: A list of :class:`Monomial`s, all of which should have - the same sum indices - :arg argument_indices: tuple of argument indices - - :returns: list of atomic GEM expressions - """ - atomics = tuple(OrderedDict.fromkeys(itertools.chain(*(monomial.atomics for monomial in monomials)))) - - def cost(solution): - extent = sum(map(lambda atomic: index_extent(atomic, argument_indices), solution)) - # Prefer shorter solutions, but larger extents - return (len(solution), -extent) - - optimal_solution = set(atomics) # pessimal but feasible solution - solution = set() - - max_it = 1 << 12 - it = iter(range(max_it)) - - def solve(idx): - while idx < len(monomials) and solution.intersection(monomials[idx].atomics): - idx += 1 - - if idx < len(monomials): - if len(solution) < len(optimal_solution): - for atomic in monomials[idx].atomics: - solution.add(atomic) - solve(idx + 1) - solution.remove(atomic) - else: - if cost(solution) < cost(optimal_solution): - optimal_solution.clear() - optimal_solution.update(solution) - next(it) - - try: - solve(0) - except StopIteration: - logger.warning("Solution to ILP problem may not be optimal: search " - "interrupted after examining %d solutions.", max_it) - - return tuple(atomic for atomic in atomics if atomic in optimal_solution) - - -def factorise_atomics(monomials, optimal_atomics, argument_indices): - """Group and factorise monomials using a list of atomics as common - subexpressions. Create new monomials for each group and optimise them recursively. - - :arg monomials: an iterable of :class:`Monomial`s, all of which should have - the same sum indices - :arg optimal_atomics: list of tuples of atomics to be used as common subexpression - :arg argument_indices: tuple of argument indices - - :returns: an iterable of :class:`Monomials`s after factorisation - """ - if not optimal_atomics or len(monomials) <= 1: - return monomials - - # Group monomials with respect to each optimal atomic - def group_key(monomial): - for oa in optimal_atomics: - if oa in monomial.atomics: - return oa - assert False, "Expect at least one optimal atomic per monomial." - factor_group = groupby(monomials, key=group_key) - - # We should not drop monomials - assert sum(len(ms) for _, ms in factor_group) == len(monomials) - - sum_indices = next(iter(monomials)).sum_indices - new_monomials = [] - for oa, monomials in factor_group: - # Create new MonomialSum for the factorised out terms - sub_monomials = [] - for monomial in monomials: - atomics = list(monomial.atomics) - atomics.remove(oa) # remove common factor - sub_monomials.append(Monomial((), tuple(atomics), monomial.rest)) - # Continue to factorise the remaining expression - sub_monomials = optimise_monomials(sub_monomials, argument_indices) - if len(sub_monomials) == 1: - # Factorised part is a product, we add back the common atomics then - # add to new MonomialSum directly rather than forming a product node - # Retaining the monomial structure enables applying associativity - # when forming GEM nodes later. - sub_monomial, = sub_monomials - new_monomials.append( - Monomial(sum_indices, (oa,) + sub_monomial.atomics, sub_monomial.rest)) - else: - # Factorised part is a summation, we need to create a new GEM node - # and multiply with the common factor - node = monomial_sum_to_expression(sub_monomials) - # If the free indices of the new node intersect with argument indices, - # add to the new monomial as `atomic`, otherwise add as `rest`. 
- # Note: we might want to continue to factorise with the new atomics - # by running optimise_monoials twice. - if set(argument_indices) & set(node.free_indices): - new_monomials.append(Monomial(sum_indices, (oa, node), one)) - else: - new_monomials.append(Monomial(sum_indices, (oa, ), node)) - return new_monomials - - -def optimise_monomial_sum(monomial_sum, argument_indices): - """Choose optimal common atomic subexpressions and factorise a - :class:`MonomialSum` object to create a GEM expression. - - :arg monomial_sum: a :class:`MonomialSum` object - :arg argument_indices: tuple of argument indices - - :returns: factorised GEM expression - """ - groups = groupby(monomial_sum, key=lambda m: frozenset(m.sum_indices)) - new_monomials = [] - for _, monomials in groups: - new_monomials.extend(optimise_monomials(monomials, argument_indices)) - return monomial_sum_to_expression(new_monomials) - - -def optimise_monomials(monomials, argument_indices): - """Choose optimal common atomic subexpressions and factorise an iterable - of monomials. - - :arg monomials: a list of :class:`Monomial`s, all of which should have - the same sum indices - :arg argument_indices: tuple of argument indices - - :returns: an iterable of factorised :class:`Monomials`s - """ - assert len(set(frozenset(m.sum_indices) for m in monomials)) <= 1,\ - "All monomials required to have same sum indices for factorisation" - - optimal_atomics = find_optimal_atomics(monomials, argument_indices) - return factorise_atomics(monomials, optimal_atomics, argument_indices) diff --git a/tsfc/parameters.py b/tsfc/parameters.py index 3d112a41..51893807 100644 --- a/tsfc/parameters.py +++ b/tsfc/parameters.py @@ -14,7 +14,7 @@ "quadrature_degree": "auto", # Default mode - "mode": "coffee", + "mode": "spectral", # Maximum extent to unroll index sums. 
Default is 3, so that loops # over geometric dimensions are unrolled; this improves assembly diff --git a/tsfc/spectral.py b/tsfc/spectral.py index 722889ca..96f59948 100644 --- a/tsfc/spectral.py +++ b/tsfc/spectral.py @@ -1,34 +1,22 @@ from __future__ import absolute_import, print_function, division -from six.moves import zip +from six.moves import zip, zip_longest +from collections import OrderedDict, defaultdict, namedtuple from functools import partial, reduce +from itertools import chain -from gem import Delta, Indexed, Sum, index_sum +from gem.gem import Delta, Indexed, Sum, index_sum, one from gem.optimise import delta_elimination as _delta_elimination -from gem.optimise import sum_factorise as _sum_factorise -from gem.optimise import replace_division, unroll_indexsum +from gem.optimise import remove_componenttensors, replace_division, unroll_indexsum from gem.refactorise import ATOMIC, COMPOUND, OTHER, MonomialSum, collect_monomials from gem.unconcatenate import unconcatenate +from gem.coffee import optimise_monomial_sum from gem.utils import groupby -def delta_elimination(sum_indices, args, rest): - """IndexSum-Delta cancellation for monomials.""" - factors = [rest] + list(args) # construct factors - sum_indices, factors = _delta_elimination(sum_indices, factors) - # Destructure factors after cancellation - rest = factors.pop(0) - args = factors - return sum_indices, args, rest - - -def sum_factorise(sum_indices, args, rest): - """Optimised monomial product construction through sum factorisation - with reversed sum indices.""" - sum_indices = list(sum_indices) - sum_indices.reverse() - factors = args + (rest,) - return _sum_factorise(sum_indices, factors) +Integral = namedtuple('Integral', ['expression', + 'quadrature_multiindex', + 'argument_indices']) def Integrals(expressions, quadrature_multiindex, argument_multiindices, parameters): @@ -44,14 +32,86 @@ def Integrals(expressions, quadrature_multiindex, argument_multiindices, paramet :returns: list of integral representations """ + # Rewrite: a / b => a * (1 / b) + expressions = replace_division(expressions) + # Unroll max_extent = parameters["unroll_indexsum"] if max_extent: def predicate(index): return index.extent <= max_extent expressions = unroll_indexsum(expressions, predicate=predicate) - # Integral representation: just a GEM expression - return replace_division([index_sum(e, quadrature_multiindex) for e in expressions]) + + expressions = [index_sum(e, quadrature_multiindex) for e in expressions] + argument_indices = tuple(chain(*argument_multiindices)) + return [Integral(e, quadrature_multiindex, argument_indices) for e in expressions] + + +def flatten(var_reps, index_cache): + quadrature_indices = OrderedDict() + + pairs = [] # assignment pairs + for variable, reps in var_reps: + # Extract argument indices + argument_indices, = set(r.argument_indices for r in reps) + assert set(variable.free_indices) == set(argument_indices) + + # Extract and verify expressions + expressions = [r.expression for r in reps] + assert all(set(e.free_indices) <= set(argument_indices) + for e in expressions) + + # Save assignment pair + pairs.append((variable, reduce(Sum, expressions))) + + # Collect quadrature_indices + for r in reps: + quadrature_indices.update(zip_longest(r.quadrature_multiindex, ())) + + # Split Concatenate nodes + pairs = unconcatenate(pairs, cache=index_cache) + + def group_key(pair): + variable, expression = pair + return frozenset(variable.free_indices) + + # Variable ordering after delta cancellation + 
narrow_variables = OrderedDict() + # Assignments are variable -> MonomialSum map + delta_simplified = defaultdict(MonomialSum) + # Group assignment pairs by argument indices + for free_indices, pair_group in groupby(pairs, group_key): + variables, expressions = zip(*pair_group) + # Argument factorise expressions + classifier = partial(classify, set(free_indices)) + monomial_sums = collect_monomials(expressions, classifier) + # For each monomial, apply delta cancellation and insert + # result into delta_simplified. + for variable, monomial_sum in zip(variables, monomial_sums): + for monomial in monomial_sum: + var, s, a, r = delta_elimination(variable, *monomial) + narrow_variables.setdefault(var) + delta_simplified[var].add(s, a, r) + + # Final factorisation + for variable in narrow_variables: + monomial_sum = delta_simplified[variable] + # Collect sum indices applicable to the current MonomialSum + sum_indices = set().union(*[m.sum_indices for m in monomial_sum]) + # Put them in a deterministic order + sum_indices = [i for i in quadrature_indices if i in sum_indices] + # Sort for increasing index extent, this obtains the good + # factorisation for triangle x interval cells. Python sort is + # stable, so in the common case when index extents are equal, + # the previous deterministic ordering applies which is good + # for getting smaller temporaries. + sum_indices = sorted(sum_indices, key=lambda index: index.extent) + # Apply sum factorisation combined with COFFEE technology + expression = sum_factorise(variable, sum_indices, monomial_sum) + yield (variable, expression) + + +finalise_options = dict(replace_delta=False) def classify(argument_indices, expression): @@ -68,28 +128,61 @@ def classify(argument_indices, expression): return COMPOUND -def flatten(var_reps, index_cache): - assignments = unconcatenate([(variable, reduce(Sum, reps)) - for variable, reps in var_reps], - cache=index_cache) +def delta_elimination(variable, sum_indices, args, rest): + """IndexSum-Delta cancellation for monomials.""" + factors = list(args) + [variable, rest] # construct factors - def group_key(assignment): - variable, expression = assignment - return variable.free_indices + def prune(factors): + # Skip last factor (``rest``, see above) which can be + # arbitrarily complicated, so its pruning may be expensive, + # and its early pruning brings no advantages. 
+ result = remove_componenttensors(factors[:-1]) + result.append(factors[-1]) + return result - for free_indices, assignment_group in groupby(assignments, group_key): - variables, expressions = zip(*assignment_group) - classifier = partial(classify, set(free_indices)) - monomial_sums = collect_monomials(expressions, classifier) - for variable, monomial_sum in zip(variables, monomial_sums): - # Compact MonomialSum after IndexSum-Delta cancellation - delta_simplified = MonomialSum() - for monomial in monomial_sum: - delta_simplified.add(*delta_elimination(*monomial)) - - # Yield assignments - for monomial in delta_simplified: - yield (variable, sum_factorise(*monomial)) + # Cancel sum indices + sum_indices, factors = _delta_elimination(sum_indices, factors) + factors = prune(factors) + # Cancel variable indices + var_indices, factors = _delta_elimination(variable.free_indices, factors) + factors = prune(factors) -finalise_options = dict(remove_componenttensors=False) + # Destructure factors after cancellation + rest = factors.pop() + variable = factors.pop() + args = [f for f in factors if f != one] + + assert set(var_indices) == set(variable.free_indices) + return variable, sum_indices, args, rest + + +def sum_factorise(variable, tail_ordering, monomial_sum): + if tail_ordering: + key_ordering = OrderedDict() + sub_monosums = defaultdict(MonomialSum) + for sum_indices, atomics, rest in monomial_sum: + # Pull out those sum indices that are not contained in the + # tail ordering, together with those atomics which do not + # share free indices with the tail ordering. + # + # Based on this, split the monomial sum, then recursively + # optimise each sub monomial sum with the first tail index + # removed. + tail_indices = tuple(i for i in sum_indices if i in tail_ordering) + tail_atomics = tuple(a for a in atomics + if set(tail_indices) & set(a.free_indices)) + head_indices = tuple(i for i in sum_indices if i not in tail_ordering) + head_atomics = tuple(a for a in atomics if a not in tail_atomics) + key = (head_indices, head_atomics) + key_ordering.setdefault(key) + sub_monosums[key].add(tail_indices, tail_atomics, rest) + sub_monosums = [(k, sub_monosums[k]) for k in key_ordering] + + monomial_sum = MonomialSum() + for (sum_indices, atomics), monosum in sub_monosums: + new_rest = sum_factorise(variable, tail_ordering[1:], monosum) + monomial_sum.add(sum_indices, atomics, new_rest) + + # Use COFFEE algorithm to optimise the monomial sum + return optimise_monomial_sum(monomial_sum, variable.index_ordering())
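A closing note on an idiom that appears several times in the new spectral
code: an OrderedDict whose values are never used acts as an insertion-ordered
set. ``zip_longest(multiindex, ())`` pads each index with a None value so that
``update`` can consume it, ``setdefault`` records first appearance without
overwriting, and Python's stable sort then keeps that deterministic order
among indices of equal extent. A small stdlib-only sketch with made-up index
names and extents:

from collections import OrderedDict
from itertools import zip_longest   # the code above uses six.moves for Python 2/3

ordered_set = OrderedDict()
for multiindex in [('q0', 'q1'), ('q1', 'q2')]:
    # Pair every index with None so update() accepts it; a repeated index
    # keeps its original position, giving a deterministic ordering.
    ordered_set.update(zip_longest(multiindex, ()))
assert list(ordered_set) == ['q0', 'q1', 'q2']

# Stable sort by extent: q0 and q2 tie, so they keep the order above.
extents = {'q0': 3, 'q1': 2, 'q2': 3}
assert sorted(ordered_set, key=lambda index: extents[index]) == ['q1', 'q0', 'q2']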