diff --git a/coffee/base.py b/coffee/base.py index c8f4d030..3bf2f55b 100644 --- a/coffee/base.py +++ b/coffee/base.py @@ -36,7 +36,7 @@ from __future__ import absolute_import, print_function, division from copy import deepcopy as dcopy -from math import isnan +from cmath import isnan import numbers import numpy as np @@ -171,6 +171,7 @@ def __deepcopy__(self, memo): def gencode(self, not_scope=True, parent=None): children = [n.gencode(not_scope, self) for n in self.children] + children = ["("+child+")" for child in children] subtree = (" "+type(self).op+" ").join(children) if parent: return wrap(subtree) @@ -259,9 +260,12 @@ def values(self, val): self._values = val def _formatter(self, v): - """Format a float into a string, showing up to ``precision`` decimal digits. - This function is partly extracted from the open_source "FFC: the FEniCS Form - Compiler", freely accessible at https://bitbucket.org/fenics-project/ffc.""" + """Format a real or complex value into a string, showing up to + ``precision`` decimal digits. This function is partly + extracted from the open_source "FFC: the FEniCS Form + Compiler", freely accessible at + https://bitbucket.org/fenics-project/ffc. + """ f = "%%.%dg" % self.precision f_int = "%%.%df" % 1 eps = 10.0**(-self.precision) @@ -269,10 +273,12 @@ def _formatter(self, v): return v.gencode(not_scope=True) elif isnan(v): return "NAN" - elif abs(v - round(v, 1)) < eps: - return f_int % v + elif abs(v.real - round(v.real, 1)) < eps and abs(v.imag - round(v.imag, 1)) < eps: + formatter = f_int else: - return f % v + formatter = f + re, im, zero = map(lambda arg: formatter % arg, (v.real, v.imag, 0)) + return re if im == zero else re + ' + ' + im + ' * I' def _tabulate_values(self, arr): if len(arr.shape) == 1: @@ -468,7 +474,7 @@ def is_const(self): @property def is_number(self): try: - float(self.symbol) + complex(self.symbol) return True except ValueError: return False @@ -1153,6 +1159,34 @@ def gencode(self, not_scope=True): """ % (str(dim), str(lda), str(sym), str(sym)) +class ComplexInvert(Statement, LinAlg): + """In-place inversion of a square array.""" + # this should probably be changed later to not require a real and complex version + def __init__(self, sym, dim, pragma=None): + super(ComplexInvert, self).__init__([sym, dim, dim], pragma) + + def reconstruct(self, sym, dim, **kwargs): + return type(self)(sym, dim, **kwargs) + + def operands(self): + return [self.children[0], self.children[1]], {'pragma': self.pragma} + + def gencode(self, not_scope=True): + sym, dim, lda = self.children + return """{ + int n = %s; + int lda = %s; + int ipiv[n]; + int lwork = n*n; + double complex work[lwork]; + int info; + + zgetrf_(&n,&n,%s,&lda,ipiv,&info); + zgetri_(&n,%s,&lda,ipiv,work,&lwork,&info); +} +""" % (str(dim), str(lda), str(sym), str(sym)) + + class Determinant(Expr, LinAlg): """Generic determinant""" def __init__(self, sym, pragma=None): diff --git a/coffee/visitor.py b/coffee/visitor.py index bf0ea69d..12ea3bb6 100644 --- a/coffee/visitor.py +++ b/coffee/visitor.py @@ -44,8 +44,8 @@ def __init__(self): # Check the argument specification # Valid options are: # visit_Foo(self, o, [*args, **kwargs]) - signature = inspect.signature(meth) - if len(signature.parameters) < 2: + argspec = inspect.getfullargspec(meth) + if len(argspec.args) < 2: raise RuntimeError("Visit method signature must be visit_Foo(self, o, [*args, **kwargs])") handlers[name[len(prefix):]] = meth self._handlers = handlers diff --git a/tests/test_precedence.py b/tests/test_precedence.py index e30d82a7..7566b37a 100644 --- a/tests/test_precedence.py +++ b/tests/test_precedence.py @@ -6,10 +6,10 @@ def test_prod_div(): tree = ast.Prod("a", ast.Div("1", "b")) - assert tree.gencode() == "a * (1 / b)" + assert tree.gencode() == "(a) * (((1) / (b)))" def test_unary_op(): tree = ast.Not(ast.Or("a", ast.And("b", "c"))) - assert tree.gencode() == "!(a || (b && c))" + assert tree.gencode() == "!((a) || (((b) && (c))))"