Skip to content

Commit

Permalink
Merge pull request #139 from firedrakeproject/mixed-element
Browse files Browse the repository at this point in the history
Use FInAT mixed element and update UFC kernel interface
  • Loading branch information
miklos1 authored Aug 2, 2017
2 parents 252c3f3 + 518a3df commit 8706fe4
Show file tree
Hide file tree
Showing 8 changed files with 121 additions and 106 deletions.
10 changes: 8 additions & 2 deletions gem/gem.py
Original file line number Diff line number Diff line change
Expand Up @@ -697,12 +697,18 @@ class Concatenate(Node):
"""
__slots__ = ('children',)

def __init__(self, *children):
def __new__(cls, *children):
if all(isinstance(child, Zero) for child in children):
size = int(sum(numpy.prod(child.shape, dtype=int) for child in children))
return Zero((size,))

self = super(Concatenate, cls).__new__(cls)
self.children = children
return self

@property
def shape(self):
return (sum(numpy.prod(child.shape, dtype=int) for child in self.children),)
return (int(sum(numpy.prod(child.shape, dtype=int) for child in self.children)),)


class Delta(Scalar, Terminal):
Expand Down
5 changes: 2 additions & 3 deletions tsfc/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
from tsfc import fem, ufl_utils
from tsfc.coffee import SCALAR_TYPE, generate as generate_coffee
from tsfc.fiatinterface import as_fiat_cell
from tsfc.finatinterface import create_element
from tsfc.logging import logger
from tsfc.parameters import default_parameters

Expand Down Expand Up @@ -100,14 +99,14 @@ def compile_integral(integral_data, form_data, prefix, parameters,
fiat_cell = as_fiat_cell(cell)
integration_dim, entity_ids = lower_integral_type(fiat_cell, integral_type)

argument_multiindices = tuple(create_element(arg.ufl_element()).get_indices()
for arg in arguments)
quadrature_indices = []

# Dict mapping domains to index in original_form.ufl_domains()
domain_numbering = form_data.original_form.domain_numbering()
builder = interface.KernelBuilder(integral_type, integral_data.subdomain_id,
domain_numbering[integral_data.domain])
argument_multiindices = tuple(builder.create_element(arg.ufl_element()).get_indices()
for arg in arguments)
return_variables = builder.set_arguments(arguments, argument_multiindices)

coordinates = ufl_utils.coordinate_coefficient(mesh)
Expand Down
6 changes: 3 additions & 3 deletions tsfc/fem.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from finat.quadrature import make_quadrature

from tsfc import ufl2gem
from tsfc.finatinterface import create_element, as_fiat_cell
from tsfc.finatinterface import as_fiat_cell
from tsfc.kernel_interface import ProxyKernelInterface
from tsfc.modified_terminals import analyse_modified_terminal
from tsfc.parameters import NUMPY_TYPE, PARAMETERS
Expand Down Expand Up @@ -319,7 +319,7 @@ def fiat_to_ufl(fiat_dict, order):
def translate_argument(terminal, mt, ctx):
argument_multiindex = ctx.argument_multiindices[terminal.number()]
sigma = tuple(gem.Index(extent=d) for d in mt.expr.ufl_shape)
element = create_element(terminal.ufl_element())
element = ctx.create_element(terminal.ufl_element())

def callback(entity_id):
finat_dict = ctx.basis_evaluation(element, mt.local_derivatives, entity_id)
Expand All @@ -346,7 +346,7 @@ def translate_coefficient(terminal, mt, ctx):
assert mt.local_derivatives == 0
return vec

element = create_element(terminal.ufl_element())
element = ctx.create_element(terminal.ufl_element())

# Collect FInAT tabulation for all entities
per_derivative = collections.defaultdict(list)
Expand Down
64 changes: 43 additions & 21 deletions tsfc/finatinterface.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def fiat_compat(element):


@singledispatch
def convert(element):
def convert(element, shape_innermost=True):
"""Handler for converting UFL elements to FInAT elements.
:arg element: The UFL element to convert.
Expand All @@ -95,7 +95,7 @@ def convert(element):

# Base finite elements first
@convert.register(ufl.FiniteElement)
def convert_finiteelement(element):
def convert_finiteelement(element, shape_innermost=True):
cell = as_fiat_cell(element.cell())
if element.family() == "Quadrature":
degree = element.degree()
Expand All @@ -113,7 +113,7 @@ def convert_finiteelement(element):
element.family())
# Handle quadrilateral short names like RTCF and RTCE.
element = element.reconstruct(cell=quad_tpc)
return finat.QuadrilateralElement(create_element(element))
return finat.QuadrilateralElement(create_element(element, shape_innermost))

kind = element.variant()
if kind is None:
Expand All @@ -137,62 +137,84 @@ def convert_finiteelement(element):
return lmbda(cell, element.degree())


# EnrichedElement case
@convert.register(ufl.EnrichedElement)
def convert_enrichedelement(element):
return finat.EnrichedElement([create_element(elem) for elem in element._elements])
def convert_enrichedelement(element, shape_innermost=True):
return finat.EnrichedElement([create_element(elem, shape_innermost)
for elem in element._elements])


# Generic MixedElement case
@convert.register(ufl.MixedElement)
def convert_mixedelement(element, shape_innermost=True):
return finat.MixedElement([create_element(elem, shape_innermost)
for elem in element.sub_elements()])


# VectorElement case
@convert.register(ufl.VectorElement)
def convert_vectorelement(element):
scalar_element = create_element(element.sub_elements()[0])
return finat.TensorFiniteElement(scalar_element, (element.num_sub_elements(),))
def convert_vectorelement(element, shape_innermost=True):
scalar_element = create_element(element.sub_elements()[0], shape_innermost)
return finat.TensorFiniteElement(scalar_element,
(element.num_sub_elements(),),
transpose=not shape_innermost)


# TensorElement case
@convert.register(ufl.TensorElement)
def convert_tensorelement(element):
scalar_element = create_element(element.sub_elements()[0])
return finat.TensorFiniteElement(scalar_element, element.reference_value_shape())
def convert_tensorelement(element, shape_innermost=True):
scalar_element = create_element(element.sub_elements()[0], shape_innermost)
return finat.TensorFiniteElement(scalar_element,
element.reference_value_shape(),
transpose=not shape_innermost)


# TensorProductElement case
@convert.register(ufl.TensorProductElement)
def convert_tensorproductelement(element):
def convert_tensorproductelement(element, shape_innermost=True):
cell = element.cell()
if type(cell) is not ufl.TensorProductCell:
raise ValueError("TensorProductElement not on TensorProductCell?")
return finat.TensorProductElement([create_element(elem)
return finat.TensorProductElement([create_element(elem, shape_innermost)
for elem in element.sub_elements()])


# HDivElement case
@convert.register(ufl.HDivElement)
def convert_hdivelement(element):
return finat.HDivElement(create_element(element._element))
def convert_hdivelement(element, shape_innermost=True):
return finat.HDivElement(create_element(element._element, shape_innermost))


# HDivElement case
@convert.register(ufl.HCurlElement)
def convert_hcurlelement(element):
return finat.HCurlElement(create_element(element._element))
def convert_hcurlelement(element, shape_innermost=True):
return finat.HCurlElement(create_element(element._element, shape_innermost))


quad_tpc = ufl.TensorProductCell(ufl.interval, ufl.interval)
_cache = weakref.WeakKeyDictionary()


def create_element(element):
def create_element(element, shape_innermost=True):
"""Create a FInAT element (suitable for tabulating with) given a UFL element.
:arg element: The UFL element to create a FInAT element from.
:arg shape_innermost: Vector/tensor indices come after basis function indices
"""
try:
return _cache[element]
cache = _cache[element]
except KeyError:
_cache[element] = {}
cache = _cache[element]

try:
return cache[shape_innermost]
except KeyError:
pass

if element.cell() is None:
raise ValueError("Don't know how to build element when cell is not given")

finat_element = convert(element)
_cache[element] = finat_element
finat_element = convert(element, shape_innermost=shape_innermost)
cache[shape_innermost] = finat_element
return finat_element
5 changes: 5 additions & 0 deletions tsfc/kernel_interface/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,5 +23,10 @@ def cell_orientation(self, restriction):
def entity_number(self, restriction):
"""Facet or vertex number as a GEM index."""

@abstractmethod
def create_element(self, element):
"""Create a FInAT element (suitable for tabulating with) given
a UFL element."""


ProxyKernelInterface = make_proxy_class('ProxyKernelInterface', KernelInterface)
5 changes: 5 additions & 0 deletions tsfc/kernel_interface/firedrake.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,11 @@ def needs_cell_orientations(ir):
return True
return False

def create_element(self, element):
"""Create a FInAT element (suitable for tabulating with) given
a UFL element."""
return create_element(element)


class ExpressionKernelBuilder(KernelBuilderBase):
"""Builds expression kernels for UFL interpolation in Firedrake."""
Expand Down
Loading

0 comments on commit 8706fe4

Please sign in to comment.