diff --git a/coffee/visitors/utilities.py b/coffee/visitors/utilities.py index 4018ad3a..1b9a1749 100644 --- a/coffee/visitors/utilities.py +++ b/coffee/visitors/utilities.py @@ -337,6 +337,13 @@ def visit_object(self, o, *args, **kwargs): def visit_list(self, o, *args, **kwargs): return sum(self.visit(e) for e in o) + def visit_ArrayInit(self, o, *args, **kwargs): + vals = o.values + if isinstance(vals, np.ndarray): + return sum(self.visit(vals[i]) for i in np.ndindex(vals.shape)) + else: + return self.visit(vals) + def visit_Node(self, o, *args, **kwargs): ops, _ = o.operands() return sum(self.visit(op) for op in ops)