Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/complex'
Browse files Browse the repository at this point in the history
  • Loading branch information
dham committed Oct 3, 2018
2 parents aa483a3 + f62fcbf commit 19e2c0d
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 12 deletions.
50 changes: 42 additions & 8 deletions coffee/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
from __future__ import absolute_import, print_function, division

from copy import deepcopy as dcopy
from math import isnan
from cmath import isnan
import numbers
import numpy as np

Expand Down Expand Up @@ -171,6 +171,7 @@ def __deepcopy__(self, memo):

def gencode(self, not_scope=True, parent=None):
children = [n.gencode(not_scope, self) for n in self.children]
children = ["("+child+")" for child in children]
subtree = (" "+type(self).op+" ").join(children)
if parent:
return wrap(subtree)
Expand Down Expand Up @@ -259,20 +260,25 @@ def values(self, val):
self._values = val

def _formatter(self, v):
"""Format a float into a string, showing up to ``precision`` decimal digits.
This function is partly extracted from the open_source "FFC: the FEniCS Form
Compiler", freely accessible at https://bitbucket.org/fenics-project/ffc."""
"""Format a real or complex value into a string, showing up to
``precision`` decimal digits. This function is partly
extracted from the open_source "FFC: the FEniCS Form
Compiler", freely accessible at
https://bitbucket.org/fenics-project/ffc.
"""
f = "%%.%dg" % self.precision
f_int = "%%.%df" % 1
eps = 10.0**(-self.precision)
if not isinstance(v, numbers.Number):
return v.gencode(not_scope=True)
elif isnan(v):
return "NAN"
elif abs(v - round(v, 1)) < eps:
return f_int % v
elif abs(v.real - round(v.real, 1)) < eps and abs(v.imag - round(v.imag, 1)) < eps:
formatter = f_int
else:
return f % v
formatter = f
re, im, zero = map(lambda arg: formatter % arg, (v.real, v.imag, 0))
return re if im == zero else re + ' + ' + im + ' * I'

def _tabulate_values(self, arr):
if len(arr.shape) == 1:
Expand Down Expand Up @@ -468,7 +474,7 @@ def is_const(self):
@property
def is_number(self):
try:
float(self.symbol)
complex(self.symbol)
return True
except ValueError:
return False
Expand Down Expand Up @@ -1153,6 +1159,34 @@ def gencode(self, not_scope=True):
""" % (str(dim), str(lda), str(sym), str(sym))


class ComplexInvert(Statement, LinAlg):
"""In-place inversion of a square array."""
# this should probably be changed later to not require a real and complex version
def __init__(self, sym, dim, pragma=None):
super(ComplexInvert, self).__init__([sym, dim, dim], pragma)

def reconstruct(self, sym, dim, **kwargs):
return type(self)(sym, dim, **kwargs)

def operands(self):
return [self.children[0], self.children[1]], {'pragma': self.pragma}

def gencode(self, not_scope=True):
sym, dim, lda = self.children
return """{
int n = %s;
int lda = %s;
int ipiv[n];
int lwork = n*n;
double complex work[lwork];
int info;
zgetrf_(&n,&n,%s,&lda,ipiv,&info);
zgetri_(&n,%s,&lda,ipiv,work,&lwork,&info);
}
""" % (str(dim), str(lda), str(sym), str(sym))


class Determinant(Expr, LinAlg):
"""Generic determinant"""
def __init__(self, sym, pragma=None):
Expand Down
4 changes: 2 additions & 2 deletions coffee/visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ def __init__(self):
# Check the argument specification
# Valid options are:
# visit_Foo(self, o, [*args, **kwargs])
signature = inspect.signature(meth)
if len(signature.parameters) < 2:
argspec = inspect.getfullargspec(meth)
if len(argspec.args) < 2:
raise RuntimeError("Visit method signature must be visit_Foo(self, o, [*args, **kwargs])")
handlers[name[len(prefix):]] = meth
self._handlers = handlers
Expand Down
4 changes: 2 additions & 2 deletions tests/test_precedence.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@
def test_prod_div():
tree = ast.Prod("a", ast.Div("1", "b"))

assert tree.gencode() == "a * (1 / b)"
assert tree.gencode() == "(a) * (((1) / (b)))"


def test_unary_op():
tree = ast.Not(ast.Or("a", ast.And("b", "c")))

assert tree.gencode() == "!(a || (b && c))"
assert tree.gencode() == "!((a) || (((b) && (c))))"

0 comments on commit 19e2c0d

Please sign in to comment.