From 8791164b29796d32a9717bfc9d5d48242deb85f0 Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Tue, 6 Dec 2016 09:26:09 +0100 Subject: [PATCH 1/5] Drop self.decls all over the place --- coffee/cse.py | 38 ++++++++++++++++++++------------------ coffee/expander.py | 7 +------ coffee/hoister.py | 9 ++++----- coffee/optimizer.py | 21 +++++++++------------ coffee/plan.py | 20 +++++++++----------- coffee/rewriter.py | 13 +++++-------- coffee/scheduler.py | 20 +++++++++----------- coffee/vectorizer.py | 39 ++++++++++++++++++--------------------- 8 files changed, 75 insertions(+), 92 deletions(-) diff --git a/coffee/cse.py b/coffee/cse.py index fb9867de..480071dc 100644 --- a/coffee/cse.py +++ b/coffee/cse.py @@ -198,11 +198,10 @@ class CSEUnpicker(object): symbols (further information concerning loop linearity is available in the module ``expression.py``).""" - def __init__(self, exprs, header, hoisted, decls): + def __init__(self, exprs, header, hoisted): self.exprs = exprs self.header = header self.hoisted = hoisted - self.decls = decls @property def type(self): @@ -212,7 +211,7 @@ def type(self): def linear_dims(self): return self.exprs.values()[0].linear_dims - def _push_temporaries(self, temporaries, trace, global_trace, ra): + def _push_temporaries(self, temporaries, trace, global_trace, ra, decls): def is_pushable(temporary, temporaries): # To be pushable ... @@ -237,7 +236,9 @@ def is_pushable(temporary, temporaries): # they will be pushed if s.urepr in global_trace and global_trace[s.urepr].pushed: continue - if any(l not in ra[self.decls[s.symbol]] for l in pushed_in): + if s.symbol not in decls: + continue + if any(l not in ra[decls[s.symbol]] for l in pushed_in): return False return True @@ -256,7 +257,6 @@ def is_pushable(temporary, temporaries): all(rb.urepr in global_trace for rb in t.readby): global_trace[t.urepr].pushed = True t.main_loop.body.remove(t.node) - self.decls.pop(t.name, None) # Transform the AST (note: node replacement must happen in the order # in which the temporaries have been encountered) @@ -275,7 +275,7 @@ def is_pushable(temporary, temporaries): for p, p_c in r_linear_reads_costs.items() or [(r, 0)]: t.linear_reads_costs[p] = c + p_c - def _transform_temporaries(self, temporaries): + def _transform_temporaries(self, temporaries, decls): from .rewriter import ExpressionRewriter # Never attempt to transform the main expression @@ -287,8 +287,7 @@ def _transform_temporaries(self, temporaries): rewriters = OrderedDict() for t in temporaries: expr_info = MetaExpr(self.type, t.main_loop.block, t.main_nest) - ew = ExpressionRewriter(t.node, expr_info, self.decls, self.header, - self.hoisted) + ew = ExpressionRewriter(t.node, expr_info, self.header, self.hoisted) ew.replacediv() ew.expand(mode='all', lda=lda) ew.reassociate(lambda i: all(r != t.main_loop.dim for r in lda[i.symbol])) @@ -303,10 +302,13 @@ def _transform_temporaries(self, temporaries): if t.linearity_degree > 1: ew.licm(mode='only_linear', lda=lda) - def _analyze_expr(self, expr, loop, lda): + # Keep track of new declarations (recomputation might otherwise be too expensive) + decls.update(OrderedDict([(k, v.decl) for k, v in self.hoisted.items()])) + + def _analyze_expr(self, expr, loop, lda, decls): finder = FindInstances(Symbol) reads = finder.visit(expr, ret=FindInstances.default_retval())[Symbol] - reads = [s for s in reads if s.symbol in self.decls] + reads = [s for s in reads if s.symbol in decls] syms = [s for s in reads if any(d in loop.dim for d in lda[s])] linear_reads_costs = OrderedDict() 
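A recurring pattern in this commit: instead of threading a `decls` dictionary through every constructor and keeping it in sync by hand (e.g. the `self.decls.pop(t.name, None)` bookkeeping deleted above), each pass recomputes the declarations it needs from the AST root on demand. A minimal sketch of the idea, using hypothetical stand-in node classes rather than coffee's real `base` hierarchy, of what `visit(root, info_items=['decls'])['decls']` is expected to provide:

    from collections import OrderedDict

    class Node(object):
        def __init__(self, children=()):
            self.children = list(children)

    class Decl(Node):
        def __init__(self, sym):
            super(Decl, self).__init__()
            self.sym = sym  # name of the declared symbol

    def collect_decls(root):
        # Preorder walk mapping symbol names to their Decl nodes,
        # mirroring visit(root, info_items=['decls'])['decls']
        decls, stack = OrderedDict(), [root]
        while stack:
            node = stack.pop()
            if isinstance(node, Decl):
                decls[node.sym] = node
            stack.extend(reversed(node.children))
        return decls

    block = Node([Decl('a'), Decl('b')])
    assert list(collect_decls(block)) == ['a', 'b']

Recomputation costs an extra traversal per pass, but it removes a whole class of staleness bugs: the dictionary can no longer drift out of sync with the tree it describes.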
@@ -328,7 +330,7 @@ def wrapper(node, found=0): return reads, linear_reads_costs - def _analyze_loop(self, loop, nest, lda, global_trace): + def _analyze_loop(self, loop, nest, lda, global_trace, decls): linear_dims = [l.dim for l, _ in nest if l.is_linear] trace = OrderedDict() @@ -338,7 +340,7 @@ def _analyze_loop(self, loop, nest, lda, global_trace): for t in not_ssa: t.readby.append(t.symbol) continue - reads, linear_reads_costs = self._analyze_expr(node.rvalue, loop, lda) + reads, linear_reads_costs = self._analyze_expr(node.rvalue, loop, lda, decls) affected = [s for s in reads if any(i in linear_dims for i in lda[s])] for s in affected: if s.urepr in global_trace: @@ -464,8 +466,8 @@ def _cost_fact(self, trace, levels, lda, bounds): def unpick(self): # Collect all necessary info - external_decls = [d for d in self.decls.values() if d.scope == EXTERNAL] - fors = visit(self.header, info_items=['fors'])['fors'] + info = visit(self.header, info_items=['decls', 'fors']) + decls, fors = info['decls'], info['fors'] lda = loops_analysis(self.header, value='dim') # Collect all loops to be analyzed @@ -479,7 +481,7 @@ def unpick(self): global_trace = OrderedDict() mapper = OrderedDict() for loop, nest in nests.items(): - trace = self._analyze_loop(loop, nest, lda, global_trace) + trace = self._analyze_loop(loop, nest, lda, global_trace, decls) if trace: mapper[loop] = trace global_trace.update(trace) @@ -499,9 +501,9 @@ def unpick(self): # Transform the loop for i in range(global_best[0] + 1, global_best[1] + 1): - ra = reachability_analysis(self.header, external_decls) - self._push_temporaries(levels[i-1], trace, global_trace, ra) - self._transform_temporaries(levels[i]) + ra = reachability_analysis(self.header) + self._push_temporaries(levels[i-1], trace, global_trace, ra, decls) + self._transform_temporaries(levels[i], decls) # Clean up for transformed_loop, nest in reversed(nests.items()): diff --git a/coffee/expander.py b/coffee/expander.py index 9ece54cc..6e1962cc 100644 --- a/coffee/expander.py +++ b/coffee/expander.py @@ -50,14 +50,11 @@ class Expander(object): GROUP = 0 # Expression /will/ not trigger expansion EXPAND = 1 # Expression /could/ be expanded - def __init__(self, stmt, expr_info=None, decls=None, hoisted=None): + def __init__(self, stmt, expr_info=None, hoisted=None): self.stmt = stmt self.expr_info = expr_info - self.decls = decls self.hoisted = hoisted - self.local_decls = {} - def _build(self, exp, grp): """Create a node for the expansion and keep track of it.""" expansion = Prod(exp, dcopy(grp)) @@ -122,5 +119,3 @@ def expand(self, should_expand, **kwargs): for node, parent in expressions: self.expansions = [] self._expand(node, parent) - - self.decls.update(self.local_decls) diff --git a/coffee/hoister.py b/coffee/hoister.py index 81cfbaac..90303727 100644 --- a/coffee/hoister.py +++ b/coffee/hoister.py @@ -122,12 +122,11 @@ class Hoister(object): # Temporary variables template _template = "ct%d" - def __init__(self, stmt, expr_info, header, decls, hoisted): + def __init__(self, stmt, expr_info, header, hoisted): """Initialize the Hoister.""" self.stmt = stmt self.expr_info = expr_info self.header = header - self.decls = decls self.hoisted = hoisted def _filter(self, dep, subexprs, make_unique=True, sharing=None): @@ -266,7 +265,6 @@ def licm(self, should_extract, **kwargs): for i, j in zip(stmts, decls): name = j.lvalue.symbol self.hoisted[name] = (i, j, clone, place) - self.decls[name] = j lda.update({s: set(dep) for s in replacements}) if not iterative: @@ -320,12 
+318,13 @@ def trim(self, candidate, **kwargs): return # Inject the reductions into the AST + decls = visit(self.header, info_items=['decls'])['decls'] for w, p in make_reduce: name = self._template % len(self.hoisted) reduction = Incr(Symbol(name, w.lvalue.rank, w.lvalue.offset), ast_reconstruct(w.rvalue)) insert_at_elem(p.body, w, reduction) - handle = self.decls[w.lvalue.symbol] + handle = decls[w.lvalue.symbol] declaration = Decl(handle.typ, Symbol(name, handle.lvalue.rank), ArrayInit(np.array([0.0])), handle.qual, handle.attr) insert_at_elem(parents[index].children, candidate, declaration) @@ -350,5 +349,5 @@ def trim(self, candidate, **kwargs): if w.lvalue.symbol not in reads: p.body.remove(w) if not isinstance(w, Decl): - key = self.decls.pop(w.lvalue.symbol) + key = decls[w.lvalue.symbol] declarations[key].children.remove(key) diff --git a/coffee/optimizer.py b/coffee/optimizer.py index a9a9ceb0..16c312d2 100644 --- a/coffee/optimizer.py +++ b/coffee/optimizer.py @@ -53,17 +53,15 @@ class LoopOptimizer(object): - def __init__(self, loop, header, decls, exprs): + def __init__(self, loop, header, exprs): """Initialize the LoopOptimizer. :param loop: root AST node of a loop nest :param header: the kernel's top node - :param decls: list of Decl objects accessible in ``loop`` :param exprs: list of expressions to be optimized """ self.loop = loop self.header = header - self.decls = decls self.exprs = exprs # Track nonzero regions accessed in each symbol @@ -105,8 +103,7 @@ def rewrite(self, mode): # Expression rewriting, expressed as a sequence of AST transformation passes for stmt, expr_info in self.exprs.items(): - ew = ExpressionRewriter(stmt, expr_info, self.decls, self.header, - self.hoisted) + ew = ExpressionRewriter(stmt, expr_info, self.header, self.hoisted) if expr_info.mode == 1: if expr_info.dimension in [0, 1]: @@ -163,14 +160,14 @@ def eliminate_zeros(self): avoid evaluation of arithmetic operations involving zero-valued blocks in statically initialized arrays.""" - zls = ZeroRemover(self.exprs, self.decls, self.hoisted) + zls = ZeroRemover(self.exprs, self.hoisted) self.nz_syms = zls.reschedule(self.header) def _unpick_cse(self): """Search for factorization opportunities across temporaries created by common sub-expression elimination. 
If a gain in operation count is detected, unpick CSE and apply factorization + code motion.""" - cse_unpicker = CSEUnpicker(self.exprs, self.header, self.hoisted, self.decls) + cse_unpicker = CSEUnpicker(self.exprs, self.header, self.hoisted) cse_unpicker.unpick() def _min_temporaries(self): @@ -207,7 +204,6 @@ def _min_temporaries(self): place.children.remove(decl) # Update trackers self.hoisted.pop(temporary) - self.decls.pop(temporary) # Replace temporary symbols and clean up l_innermost_body = inner_loops(l)[-1].body @@ -375,7 +371,7 @@ def find_save(target_expr, expr_info): fake_stmt = stmt.__class__(stmt.children[0], dcopy(target_expr)) fake_parent = expr_info.parent.children fake_parent[fake_parent.index(stmt)] = fake_stmt - ew = ExpressionRewriter(fake_stmt, expr_info, self.decls) + ew = ExpressionRewriter(fake_stmt, expr_info) ew.expand(mode='all').factorize(mode='all').factorize(mode='linear') nterms = ew.licm(mode='aggressive', look_ahead=True) nterms = len(uniquify(nterms[expr_info.dims])) or 1 @@ -450,6 +446,7 @@ def find_save(target_expr, expr_info): # 3) Purge the AST from now useless symbols/expressions if should_unroll: + decls = visit(self.header, info_items=['decls'])['decls'] for stmt, expr_info in self.exprs.items(): nests = [n for n in visit(expr_info.loops_parents[0])['fors']] injectable_nests = [n for n in nests if list(zip(*n))[0] != expr_info.loops] @@ -458,10 +455,9 @@ def find_save(target_expr, expr_info): for l, p in unrolled: p.children.remove(l) for i_sym in injectable.keys(): - decl = self.decls.get(i_sym) + decl = decls.get(i_sym) if decl and decl in p.children: p.children.remove(decl) - self.decls.pop(i_sym) # 4) Split the expressions if injection has been performed for stmt, expr_info in self.exprs.items(): @@ -478,13 +474,14 @@ def find_save(target_expr, expr_info): def _recoil(self): """Increase the stack size if the kernel arrays exceed the stack limit threshold (at the C level).""" + decls = visit(self.header, info_items=['decls'])['decls'] # Assume the size of a C type double is 8 bytes c_double_size = 8 # Assume the stack size is 1.7 MB (2 MB is usually the limit) stack_size = 1.7*1024*1024 - decls = [d for d in self.decls.values() if d.sym.rank] + decls = [d for d in decls.values() if d.size] size = sum([reduce(operator.mul, d.sym.rank) for d in decls]) if size * c_double_size > stack_size: diff --git a/coffee/plan.py b/coffee/plan.py index fbff92c8..8c194ff3 100644 --- a/coffee/plan.py +++ b/coffee/plan.py @@ -87,17 +87,16 @@ def plan_cpu(self, opts): split = opts.get('split') dead_ops_elimination = opts.get('dead_ops_elimination') - info = visit(kernel) - decls = info['decls'] + info = visit(kernel, info_items=['decls', 'exprs']) # Collect expressions and related metadata nests = defaultdict(OrderedDict) for stmt, expr_info in info['exprs'].items(): parent, nest = expr_info if not nest: continue - metaexpr = MetaExpr(check_type(stmt, decls), parent, nest) + metaexpr = MetaExpr(check_type(stmt, info['decls']), parent, nest) nests[nest[0]].update({stmt: metaexpr}) - loop_opts = [CPULoopOptimizer(loop, header, decls, exprs) + loop_opts = [CPULoopOptimizer(loop, header, exprs) for (loop, header), exprs in nests.items()] # Combining certain optimizations is forbidden. 
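For context on the `nests` construction kept above: expressions are bucketed by the outermost loop of their enclosing nest, so exactly one LoopOptimizer is instantiated per outermost loop. A toy illustration of the idiom with hypothetical placeholder data (the real keys are `(loop, header)` pairs and the real values `MetaExpr` objects):

    from collections import OrderedDict, defaultdict

    nests = defaultdict(OrderedDict)
    exprs = [('stmt0', ('nestA', 'innerA0')),
             ('stmt1', ('nestA', 'innerA1')),
             ('stmt2', ('nestB',))]
    for stmt, nest in exprs:
        # nest[0] is the outermost loop of the enclosing nest
        nests[nest[0]].update({stmt: 'metaexpr-for-%s' % stmt})

    assert sorted(nests) == ['nestA', 'nestB']         # one optimizer per nest
    assert list(nests['nestA']) == ['stmt0', 'stmt1']  # insertion order preserved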
@@ -178,22 +177,21 @@ def plan_gpu(self): # The optimization passes are performed individually (i.e., "locally") for # each function (or "kernel") found in the provided AST - retval = FindInstances.default_retval() - kernels = FindInstances(FunDecl, stop_when_found=True).visit(self.ast, - ret=retval)[FunDecl] + kernels = FindInstances(FunDecl, stop_when_found=True).visit(self.ast)[FunDecl] + for kernel in kernels: info = visit(kernel, info_items=['decls', 'exprs']) - decls = info['decls'] - # Structure up expressions and related metadata + # Collect expressions and related metadata nests = defaultdict(OrderedDict) for stmt, expr_info in info['exprs'].items(): parent, nest = expr_info if not nest: continue - metaexpr = MetaExpr(check_type(stmt, decls), parent, nest) + metaexpr = MetaExpr(check_type(stmt, info['decls']), parent, nest) nests[nest[0]].update({stmt: metaexpr}) + loop_opts = [CPULoopOptimizer(loop, header, exprs) + for (loop, header), exprs in nests.items()] - loop_opts = [GPULoopOptimizer(l, header, decls) for l, header in nests] for loop_opt in loop_opts: itspace_vrs, accessed_vrs = loop_opt.extract() diff --git a/coffee/rewriter.py b/coffee/rewriter.py index b20e8fa1..4d65d147 100644 --- a/coffee/rewriter.py +++ b/coffee/rewriter.py @@ -56,25 +56,21 @@ class ExpressionRewriter(object): * Expansion: transform an expression ``(a + b)*c`` into ``(a*c + b*c)`` * Factorization: transform an expression ``a*b + a*c`` into ``a*(b+c)``""" - def __init__(self, stmt, expr_info, decls, header=None, hoisted=None): + def __init__(self, stmt, expr_info, header=None, hoisted=None): """Initialize the ExpressionRewriter. :param stmt: the node whose rvalue is the expression for rewriting :param expr_info: ``MetaExpr`` object describing the expression - :param decls: all declarations for the symbols in ``stmt``. :param header: the kernel's top node :param hoisted: dictionary that tracks all hoisted expressions """ self.stmt = stmt self.expr_info = expr_info - self.decls = decls self.header = header or Root() self.hoisted = hoisted if hoisted is not None else StmtTracker() - self.expr_hoister = Hoister(self.stmt, self.expr_info, self.header, - self.decls, self.hoisted) - self.expr_expander = Expander(self.stmt, self.expr_info, self.decls, - self.hoisted) + self.expr_hoister = Hoister(self.stmt, self.expr_info, self.header, self.hoisted) + self.expr_expander = Expander(self.stmt, self.expr_info, self.hoisted) self.expr_factorizer = Factorizer(self.stmt) def licm(self, mode='normal', **kwargs): """Perform generalized loop-invariant code motion, a transformation @@ -513,7 +509,8 @@ def preevaluate(self): self.expr_info._loops_info.remove((l, p)) # Precompute constant expressions - evaluator = Evaluate(self.decls, any(d.nonzero for s, d in self.decls.items())) + decls = visit(self.header, info_items=['decls'])['decls'] + evaluator = Evaluate(decls, any(d.nonzero for s, d in decls.items())) for hoisted_loop in self.hoisted.all_loops: evals = evaluator.visit(hoisted_loop, **Evaluate.default_args) # First, find out identical tables diff --git a/coffee/scheduler.py b/coffee/scheduler.py index 383f5068..fe45d2eb 100644 --- a/coffee/scheduler.py +++ b/coffee/scheduler.py @@ -478,15 +478,13 @@ class ZeroRemover(LoopScheduler): THRESHOLD = 1 # Only skip if there are more than THRESHOLD consecutive zeros - def __init__(self, exprs, decls, hoisted): + def __init__(self, exprs, hoisted): """Initialize the ZeroRemover. :param exprs: the expressions for which zero removal is performed. - :param decls: lists of declarations visible to ``exprs``.
:param hoisted: dictionary that tracks hoisted sub-expressions """ self.exprs = exprs - self.decls = decls self.hoisted = hoisted def _track_nz_expr(self, node, nz_syms, nest): @@ -658,7 +656,7 @@ def _track_nz_blocks(self, node, nz_syms, nz_info, nest=None, parent=None, candi else: raise ControlFlowError - def _reschedule_itspace(self, root, candidates): + def _reschedule_itspace(self, root, candidates, decls): """Consider two statements A and B, and their iteration space. If the two iteration spaces have @@ -688,9 +686,9 @@ def _reschedule_itspace(self, root, candidates): """ nz_info = OrderedDict() - # Elaborate the initial sparsity pattern of the symbols in /root/ + # Compute the initial sparsity pattern of the symbols in /root/ nz_syms = defaultdict(list) - for s, d in self.decls.items(): + for s, d in decls.items(): if not d.nonzero: continue for nz_b in product(*d.nonzero): @@ -786,8 +784,7 @@ def _recombine(self, nz_info): new_exprs[stmt] = self.exprs[i] for stmt, expr_info in new_exprs.items(): - ew = ExpressionRewriter(stmt, expr_info, self.decls, - expr_info.outermost_parent, self.hoisted) + ew = ExpressionRewriter(stmt, expr_info) ew.factorize('heuristic') if new_exprs: @@ -816,9 +813,10 @@ def reschedule(self, root): zero-valued data spaces. This is achieved through symbolic execution starting from ``root``. Control flow, in the form of If, Switch, etc., is forbidden.""" + decls = visit(root, info_items=['decls'])['decls'] # Avoid rescheduling if zero-valued blocks are too small - zero_decls = [d for d in self.decls.values() if d.nonzero] + zero_decls = [d for d in decls.values() if d.nonzero] if self._should_skip(zero_decls): return {} @@ -842,12 +840,12 @@ def reschedule(self, root): self.exprs.update(elf.fission(stmt, expr_info)) # Apply the rescheduling - nz_syms, nz_info = self._reschedule_itspace(root, candidates) + nz_syms, nz_info = self._reschedule_itspace(root, candidates, decls) # Finally, "inline" the expressions that were originally split, if possible self._recombine(nz_info) else: # Apply the rescheduling - nz_syms, nz_info = self._reschedule_itspace(root, candidates) + nz_syms, nz_info = self._reschedule_itspace(root, candidates, decls) return nz_syms diff --git a/coffee/vectorizer.py b/coffee/vectorizer.py index 2659f1a5..95650c29 100644 --- a/coffee/vectorizer.py +++ b/coffee/vectorizer.py @@ -75,7 +75,6 @@ def __init__(self, loop_opt, kernel=None): self.kernel = kernel or loop_opt.header self.header = loop_opt.header self.loop = loop_opt.loop - self.decls = loop_opt.decls self.exprs = loop_opt.exprs self.nz_syms = loop_opt.nz_syms @@ -140,29 +139,27 @@ def autovectorize(self, p_dim=-1): :arg p_dim: the array dimension that should be padded (defaults to the innermost, or -1) """ - buffer = self._pad(p_dim) - if buffer: - self._align_data(buffer, p_dim) + info = visit(self.header, info_items=['decls', 'fors', 'symbols_dep', + 'symbols_mode', 'symbol_refs']) + + padded = self._pad(p_dim, info['decls'], info['fors'], info['symbols_dep'], + info['symbols_mode'], info['symbol_refs']) + if padded: + self._align_data(p_dim, info['decls']) - def _pad(self, p_dim): + def _pad(self, p_dim, decls, fors, symbols_dep, symbols_mode, symbol_refs): """Apply padding.""" - info = visit(self.header, info_items=['fors', 'symbols_dep', - 'symbols_mode', 'symbol_refs']) - symbols_dep = info['symbols_dep'] - symbols_mode = info['symbols_mode'] - symbol_refs = info['symbol_refs'] - retval = FindInstances.default_retval() - to_invert = FindInstances(Invert).visit(self.header, 
ret=retval)[Invert] + to_invert = FindInstances(Invert).visit(self.header)[Invert] # Loop increments different than 1 are unsupported - if any([l.increment != 1 for l, _ in flatten(info['fors'])]): + if any([l.increment != 1 for l, _ in flatten(fors)]): return None DSpace = namedtuple('DSpace', ['region', 'nest', 'symbols']) ISpace = namedtuple('ISpace', ['region', 'nest', 'bag']) buf_decl = None - for decl_name, decl in self.decls.items(): + for decl_name, decl in decls.items(): if not decl.size or decl.is_pointer_type: continue @@ -272,11 +269,11 @@ def _pad(self, p_dim): self.header.children.append(copy_back[0]) # D) Update the global data structures - self.decls[buf_name] = buf_decl + decls[buf_name] = buf_decl return buf_decl - def _align_data(self, buffer, p_dim): + def _align_data(self, p_dim, decls): """Apply data alignment. This boils down to: * Decorate declarations with qualifiers for data alignment @@ -288,7 +285,7 @@ def _align_data(self, buffer, p_dim): align = system.compiler['align'](system.isa['alignment']) # Array alignment - for decl in self.decls.values(): + for decl in decls.values(): if decl.sym.rank and decl.scope == LOCAL: decl.attr.append(align) @@ -298,7 +295,7 @@ def _align_data(self, buffer, p_dim): for stmt in l.body: sym, expr = stmt.lvalue, stmt.rvalue - decl = self.decls[sym.symbol] + decl = decls[sym.symbol] # Condition A: the lvalue can be a scalar only if /stmt/ is not an # augmented assignment, otherwise the extra iterations would alter @@ -329,7 +326,7 @@ def _align_data(self, buffer, p_dim): # Condition E: extra iterations induced by bounds and offset rounding # must not alter the computation for s in symbols: - decl = self.decls[s.symbol] + decl = decls[s.symbol] index = s.rank.index(l.dim) stride = s.strides[index] extra = list(range(stride + l.size, stride + vect_roundup(l.size))) @@ -370,7 +367,7 @@ def _align_data(self, buffer, p_dim): l.pragma.add(system.compiler["align_forloop"]) l.pragma.add(system.compiler['force_simdization']) - def _transpose_layout(self): + def _transpose_layout(self, decls): dim = self.loop.dim retval = FindInstances.default_retval() symbols = FindInstances(Symbol).visit(self.loop, ret=retval)[Symbol] @@ -382,7 +379,7 @@ def _transpose_layout(self): mapper = OrderedDict() for s in symbols: - mapper.setdefault(self.decls[s.symbol], list()).append(s) + mapper.setdefault(decls[s.symbol], list()).append(s) for decl, syms in mapper.items(): # Adjust the declaration From bb360c2aff7b850a39d5bd515876b4c6d31c7603 Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Tue, 6 Dec 2016 12:33:45 +0100 Subject: [PATCH 2/5] minor refactoring --- coffee/expander.py | 29 +++++++++++++---------------- coffee/optimizer.py | 16 ++++++++-------- coffee/plan.py | 5 ++--- coffee/rewriter.py | 16 ++++++++-------- coffee/scheduler.py | 6 ++++-- coffee/utils.py | 25 ++++--------------------- 6 files changed, 39 insertions(+), 58 deletions(-) diff --git a/coffee/expander.py b/coffee/expander.py index 6e1962cc..6d2e312a 100644 --- a/coffee/expander.py +++ b/coffee/expander.py @@ -50,22 +50,20 @@ class Expander(object): GROUP = 0 # Expression /will/ not trigger expansion EXPAND = 1 # Expression /could/ be expanded - def __init__(self, stmt, expr_info=None, hoisted=None): + def __init__(self, stmt): self.stmt = stmt - self.expr_info = expr_info - self.hoisted = hoisted - def _build(self, exp, grp): + def _build(self, exp, grp, expansions): """Create a node for the expansion and keep track of it.""" expansion = Prod(exp, dcopy(grp)) # Track the new 
expansion - self.expansions.append(expansion) + expansions.append(expansion) # Untrack any expansions that occurred in child nodes - if grp in self.expansions: - self.expansions.remove(grp) + if grp in expansions: + expansions.remove(grp) return expansion - def _expand(self, node, parent): + def _expand(self, node, parent, expansions): if isinstance(node, Symbol): return ([node], self.EXPAND) if self.should_expand(node) \ else ([node], self.GROUP) @@ -74,12 +72,12 @@ def _expand(self, node, parent): # Try to expand /within/ the children, but then return saying "I'm not # expandable any further" for n in node.children: - self._expand(n, node) + self._expand(n, node, expansions) return ([node], self.GROUP) elif isinstance(node, Prod): - l_exps, l_type = self._expand(node.left, node) - r_exps, r_type = self._expand(node.right, node) + l_exps, l_type = self._expand(node.left, node, expansions) + r_exps, r_type = self._expand(node.right, node, expansions) if l_type == self.GROUP and r_type == self.GROUP: return ([node], self.GROUP) # At least one child is expandable (marked as EXPAND), whereas the @@ -89,7 +87,7 @@ def _expand(self, node, parent): expandable = r_exps if l_type == self.GROUP else l_exps to_replace = OrderedDict() for exp, grp in itertools.product(expandable, groupable): - expansion = self._build(exp, grp) + expansion = self._build(exp, grp, expansions) to_replace.setdefault(exp, []).append(expansion) ast_replace(node, {k: ast_make_expr(Sum, v) for k, v in to_replace.items()}, copy=False, mode='symbol') @@ -99,8 +97,8 @@ def _expand(self, node, parent): return (list(flatten(to_replace.values())) or [expanded], self.EXPAND) elif isinstance(node, (Sum, Sub)): - l_exps, l_type = self._expand(node.left, node) - r_exps, r_type = self._expand(node.right, node) + l_exps, l_type = self._expand(node.left, node, expansions) + r_exps, r_type = self._expand(node.right, node, expansions) if l_type == self.EXPAND and r_type == self.EXPAND and isinstance(node, Sum): return (l_exps + r_exps, self.EXPAND) elif l_type == self.EXPAND and r_type == self.EXPAND and isinstance(node, Sub): @@ -117,5 +115,4 @@ def expand(self, should_expand, **kwargs): self.should_expand = should_expand for node, parent in expressions: - self.expansions = [] - self._expand(node, parent) + self._expand(node, parent, []) diff --git a/coffee/optimizer.py b/coffee/optimizer.py index 16c312d2..cabd105f 100644 --- a/coffee/optimizer.py +++ b/coffee/optimizer.py @@ -177,28 +177,28 @@ def _min_temporaries(self): * it is written once, AND * it is read once OR it is read n times, but it hosts only a Symbol """ - occs = count(self.header, mode='symbol_id', read_only=True) + + occurrences = count(self.header, mode='symbol_id', read_only=True) for l in self.hoisted.all_loops: - info = visit(l) - l_occs = count(l, read_only=True) + info = visit(l, info_items=['symbol_refs', 'symbols_mode']) to_replace, to_remove = {}, [] - for (temporary, _, _), temporary_occs in l_occs.items(): + for (temporary, _, _), c in count(l, read_only=True).items(): if temporary not in self.hoisted: continue if self.hoisted[temporary].loop is not l: continue - if occs.get(temporary) != temporary_occs: + if occurrences.get(temporary) != c: continue decl = self.hoisted[temporary].decl place = self.hoisted[temporary].place expr = self.hoisted[temporary].stmt.rvalue - if temporary_occs > 1 and explore_operator(expr): + if c > 1 and explore_operator(expr): continue - temporary_refs = info['symbol_refs'][temporary] + references = info['symbol_refs'][temporary]
syms_mode = info['symbols_mode'] # Note: only one write is possible at this point - write = [(s, p) for s, p in temporary_refs if syms_mode[s][0] == WRITE][0] + write = [(s, p) for s, p in references if syms_mode[s][0] == WRITE][0] to_replace[write[0]] = expr to_remove.append(write[1]) place.children.remove(decl) diff --git a/coffee/plan.py b/coffee/plan.py index 8c194ff3..51fb999a 100644 --- a/coffee/plan.py +++ b/coffee/plan.py @@ -70,8 +70,7 @@ def plan_cpu(self, opts): start_time = time.time() - finder = FindInstances(FunDecl, stop_when_found=True) - kernels = finder.visit(self.ast, ret=FindInstances.default_retval())[FunDecl] + kernels = FindInstances(FunDecl, stop_when_found=True).visit(self.ast)[FunDecl] if opts is None: opts = coffee.OptimizationLevel.retrieve(coffee.options['optimizations']) @@ -189,7 +188,7 @@ def plan_gpu(self): continue metaexpr = MetaExpr(check_type(stmt, info['decls']), parent, nest) nests[nest[0]].update({stmt: metaexpr}) - loop_opts = [CPULoopOptimizer(loop, header, exprs) + loop_opts = [GPULoopOptimizer(loop, header, exprs) for (loop, header), exprs in nests.items()] for loop_opt in loop_opts: diff --git a/coffee/rewriter.py b/coffee/rewriter.py index 4d65d147..1f68ff27 100644 --- a/coffee/rewriter.py +++ b/coffee/rewriter.py @@ -69,9 +69,9 @@ def __init__(self, stmt, expr_info, header=None, hoisted=None): self.header = header or Root() self.hoisted = hoisted if hoisted is not None else StmtTracker() - self.expr_hoister = Hoister(self.stmt, self.expr_info, self.header, self.hoisted) - self.expr_expander = Expander(self.stmt, self.expr_info, self.hoisted) - self.expr_factorizer = Factorizer(self.stmt) + self.codemotion = Hoister(self.stmt, self.expr_info, self.header, self.hoisted) + self.expander = Expander(self.stmt) + self.factorizer = Factorizer(self.stmt) def licm(self, mode='normal', **kwargs): """Perform generalized loop-invariant code motion, a transformation @@ -161,9 +161,9 @@ def licm(self, mode='normal', **kwargs): out_linear_dims = set(self.expr_info.out_linear_dims) if kwargs.get('look_ahead'): - hoist = self.expr_hoister.extract + hoist = self.codemotion.extract else: - hoist = self.expr_hoister.licm + hoist = self.codemotion.licm if mode == 'normal': should_extract = lambda d: d != dims @@ -183,7 +183,7 @@ def licm(self, mode='normal', **kwargs): non_candidates = {l.dim for l in candidates[:-1]} self.reassociate(lambda i: not lda[i].intersection(non_candidates)) hoist(should_extract, with_promotion=True, lda=lda) - self.expr_hoister.trim(candidate) + self.codemotion.trim(candidate) elif mode == 'incremental': lda = kwargs.get('lda') or loops_analysis(self.header, value='dim') should_extract = lambda d: not (d and d.issubset(dims)) @@ -292,7 +292,7 @@ def expand(self, mode='standard', **kwargs): warn('Skipping unknown expansion strategy.') return - self.expr_expander.expand(should_expand, **kwargs) + self.expander.expand(should_expand, **kwargs) return self def factorize(self, mode='standard', **kwargs): @@ -384,7 +384,7 @@ def factorize(self, mode='standard', **kwargs): return # Perform the factorization - self.expr_factorizer.factorize(should_factorize, **kwargs) + self.factorizer.factorize(should_factorize, **kwargs) return self def reassociate(self, reorder=None): diff --git a/coffee/scheduler.py b/coffee/scheduler.py index fe45d2eb..8dc5a802 100644 --- a/coffee/scheduler.py +++ b/coffee/scheduler.py @@ -824,8 +824,10 @@ def reschedule(self, root): # read-after-write dependencies) linear_expr_loops = [(l for l in ei.linear_loops) for ei 
in self.exprs.values()] linear_expr_loops = set(flatten(linear_expr_loops)) - candidates = [l for l in inner_loops(root) if not l.is_linear or l in linear_expr_loops] - candidates = [l for l in candidates if not ExpressionGraph(l.body).has_dependency()] + candidates = [l for l in inner_loops(root) + if not l.is_linear or l in linear_expr_loops] + candidates = [l for l in candidates + if not ExpressionGraph(l.body).has_dependency()] if not candidates: return {} diff --git a/coffee/utils.py b/coffee/utils.py index e08a47ee..7fb92d36 100644 --- a/coffee/utils.py +++ b/coffee/utils.py @@ -141,18 +141,6 @@ def ast_update_rank(node, mapper): s.rank = tuple([r if r not in mapper else mapper[r] for r in s.rank]) -def ast_update_id(symbol, name, id): - """Search for string ``name`` in Symbol ``symbol`` and replaces all of the - occurrences of ``name`` with ``name_id``.""" - if not isinstance(symbol, Symbol): - return - new_name = "%s_%s" % (name, str(id)) - if name == symbol.symbol: - symbol.symbol = new_name - new_rank = [new_name if name == r else r for r in symbol.rank] - symbol.rank = tuple(new_rank) - - ############################################### # Functions to simplify creation of AST nodes # ############################################### @@ -311,17 +299,12 @@ def loops_analysis(node, key='default', value='default'): return lda -def reachability_analysis(node, decls=None): - """Perform reachability analysis in the AST rooted in ``node``. Return +def reachability_analysis(node): + """ + Perform reachability analysis in the AST rooted in ``node``. Return a dictionary mapping symbols to scopes in which they are visible. - - :param decls: an iterator of :class:`Decl`s which are known to be visible - within ``node`` """ - symbols_vis, scopes = SymbolVisibility().visit(node) - for d in decls: - symbols_vis[d].extend(scopes) - return symbols_vis + return SymbolVisibility().visit(node)[0] def explore_operator(node): From 2f7d7c6ce07dcad508ace7219b32f7e32d05e59b Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Tue, 6 Dec 2016 14:58:11 +0100 Subject: [PATCH 3/5] Better clean up after CSE --- coffee/cse.py | 6 +----- coffee/utils.py | 19 +++++++++++++++++++ 2 files changed, 20 insertions(+), 5 deletions(-) diff --git a/coffee/cse.py b/coffee/cse.py index 480071dc..9c2b7412 100644 --- a/coffee/cse.py +++ b/coffee/cse.py @@ -505,8 +505,4 @@ def unpick(self): self._push_temporaries(levels[i-1], trace, global_trace, ra, decls) self._transform_temporaries(levels[i], decls) - # Clean up - for transformed_loop, nest in reversed(nests.items()): - for loop, parent in nest: - if loop == transformed_loop and not loop.body: - parent.children.remove(loop) + cleanup(self.header) diff --git a/coffee/utils.py b/coffee/utils.py index 7fb92d36..fb7f3135 100644 --- a/coffee/utils.py +++ b/coffee/utils.py @@ -841,6 +841,25 @@ def remove_empty_loops(node): parent.children.remove(loop) +def remove_unused_decls(node): + """Remove all unused decls within node, which must be of type :class:`Block`.""" + + assert isinstance(node, Block) + + decls = FindInstances(Decl, with_parent=True).visit(node)[Decl] + references = visit(node, info_items=['symbol_refs'])['symbol_refs'] + for d, p in decls: + if len(references[d.sym.symbol]) == 1: + p.children.remove(d) + + +def cleanup(node): + """Remove useless nodes in the AST rooted in node.""" + + remove_empty_loops(node) + remove_unused_decls(node) + + def postprocess(node): """Rearrange the Nodes in the AST rooted in ``node`` to improve the code quality when unparsing the 
tree.""" From a2c125872087122899f6edfb3982f460f39dc65f Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Tue, 6 Dec 2016 15:19:22 +0100 Subject: [PATCH 4/5] FindInstances ---> Find --- coffee/base.py | 4 ++-- coffee/cse.py | 5 ++--- coffee/expression.py | 2 +- coffee/factorizer.py | 3 +-- coffee/hoister.py | 18 ++++++++---------- coffee/optimizer.py | 8 ++++---- coffee/plan.py | 6 +++--- coffee/rewriter.py | 23 ++++++++--------------- coffee/utils.py | 25 ++++++++++--------------- coffee/vectorizer.py | 9 ++++----- coffee/visitors/inspectors.py | 6 +++--- 11 files changed, 46 insertions(+), 63 deletions(-) diff --git a/coffee/base.py b/coffee/base.py index bc519546..c8f4d030 100644 --- a/coffee/base.py +++ b/coffee/base.py @@ -1232,7 +1232,7 @@ def gencode(self, not_scope=False): class Rank(tuple): def __contains__(self, val): - from coffee.visitors import FindInstances + from coffee.visitors import Find if isinstance(val, Node): val, search = str(val), type(Node) elif isinstance(val, str): @@ -1241,7 +1241,7 @@ def __contains__(self, val): return False for i in self: if isinstance(i, Node): - items = FindInstances(search).visit(i) + items = Find(search).visit(i) if any(val == str(i) for i in items[search]): return True elif isinstance(i, str) and val == i: diff --git a/coffee/cse.py b/coffee/cse.py index 9c2b7412..5d846c3c 100644 --- a/coffee/cse.py +++ b/coffee/cse.py @@ -95,7 +95,7 @@ def urepr(self): @property def reads(self): - return FindInstances(Symbol).visit(self.expr)[Symbol] if self.expr else [] + return Find(Symbol).visit(self.expr)[Symbol] if self.expr else [] @property def linear_reads(self): @@ -306,8 +306,7 @@ def _transform_temporaries(self, temporaries, decls): decls.update(OrderedDict([(k, v.decl) for k, v in self.hoisted.items()])) def _analyze_expr(self, expr, loop, lda, decls): - finder = FindInstances(Symbol) - reads = finder.visit(expr, ret=FindInstances.default_retval())[Symbol] + reads = Find(Symbol).visit(expr)[Symbol] reads = [s for s in reads if s.symbol in decls] syms = [s for s in reads if any(d in loop.dim for d in lda[s])] diff --git a/coffee/expression.py b/coffee/expression.py index 96d8e784..5b09c685 100644 --- a/coffee/expression.py +++ b/coffee/expression.py @@ -118,7 +118,7 @@ def out_linear_loops_info(self): @property def reduction_loops(self): - stmts = FindInstances((Writer, Incr)).visit(self.parent) + stmts = Find((Writer, Incr)).visit(self.parent) if stmts[Incr]: writers = flatten(stmts.values()) return tuple(l for l in self.loops diff --git a/coffee/factorizer.py b/coffee/factorizer.py index 53a97d82..48c2d7a4 100644 --- a/coffee/factorizer.py +++ b/coffee/factorizer.py @@ -163,8 +163,7 @@ def _filter(self, factorizable_term): if not grp: return False for f in factorizable_term.factors: - retval = FindInstances.default_retval() - symbols = FindInstances(Symbol).visit(f, ret=retval)[Symbol] + symbols = Find(Symbol).visit(f)[Symbol] if any(s.urepr in grp for s in symbols): return False return True diff --git a/coffee/hoister.py b/coffee/hoister.py index 90303727..3c15dfe3 100644 --- a/coffee/hoister.py +++ b/coffee/hoister.py @@ -53,11 +53,10 @@ def _apply_cse(self): # operations (i.e., a terminal has two Symbols as children). 
This may # induce more sweeps of extraction to find all common sub-expressions, # but at least it keeps the algorithm simple and probably more effective - finder = FindInstances(Symbol, with_parent=True) + finder = Find(Symbol, with_parent=True) for dep, subexprs in self.extracted.items(): cs = OrderedDict() - retval = FindInstances.default_retval() - values = [finder.visit(e, retval=retval)[Symbol] for e in subexprs] + values = [finder.visit(e)[Symbol] for e in subexprs] binexprs = list(zip(*flatten(values)))[1] binexprs = [b for b in binexprs if binexprs.count(b) > 1] for b in binexprs: @@ -141,10 +140,9 @@ def _filter(self, dep, subexprs, make_unique=True, sharing=None): if dep == self.expr_info.dims: return [] sharing = [str(s) for s in sharing] - finder = FindInstances(Symbol) partitions = defaultdict(list) for e in subexprs: - symbols = tuple(set(str(s) for s in finder.visit(e)[Symbol] + symbols = tuple(set(str(s) for s in Find(Symbol).visit(e)[Symbol] if str(s) in sharing)) partitions[symbols].append(e) for shared, partition in partitions.items(): @@ -157,9 +155,9 @@ def _is_hoistable(self, subexprs, loop): """Return True if the sub-expressions provided in ``subexprs`` are hoistable outside of ``loop``, False otherwise.""" written = in_written(loop, 'symbol') - finder, reads = FindInstances(Symbol), FindInstances.default_retval() + reads = Find.default_retval() for e in subexprs: - finder.visit(e, ret=reads) + Find(Symbol).visit(e, ret=reads) reads = [s.symbol for s in reads[Symbol]] return set.isdisjoint(set(reads), set(written)) @@ -284,7 +282,7 @@ def trim(self, candidate, **kwargs): lda = kwargs.get('lda') or loops_analysis(self.header) reducible, other = [], [] for i in summands(self.stmt.rvalue): - symbols = FindInstances(Symbol).visit(i)[Symbol] + symbols = Find(Symbol).visit(i)[Symbol] unavoidable = set.intersection(*[set(lda[s]) for s in symbols]) if candidate in unavoidable: return @@ -293,7 +291,7 @@ def trim(self, candidate, **kwargs): # Make sure we do not break data dependencies make_reduce = [] - writes = FindInstances(Writer).visit(candidate) + writes = Find(Writer).visit(candidate) for w in flatten(writes.values()): if isinstance(w.rvalue, EmptyStatement): continue @@ -343,7 +341,7 @@ def trim(self, candidate, **kwargs): # Clean up removing any now unnecessary symbols reads = in_read(candidate, key='symbol') - declarations = FindInstances(Decl, with_parent=True).visit(self.header)[Decl] + declarations = Find(Decl, with_parent=True).visit(self.header)[Decl] declarations = dict(declarations) for w, p in make_reduce: if w.lvalue.symbol not in reads: diff --git a/coffee/optimizer.py b/coffee/optimizer.py index cabd105f..fb4cd231 100644 --- a/coffee/optimizer.py +++ b/coffee/optimizer.py @@ -47,7 +47,7 @@ from .rewriter import ExpressionRewriter from .cse import CSEUnpicker from .logger import warn -from coffee.visitors import FindInstances, ProjectExpansion +from coffee.visitors import Find, ProjectExpansion from functools import reduce @@ -267,7 +267,7 @@ def _dissect(self, heuristics): to_unroll = [(l, p) for l, p in nest if l not in expr_info.loops] unroll_cost = reduce(operator.mul, (l.size for l, p in to_unroll)) - nest_writers = FindInstances(Writer).visit(to_unroll[0][0]) + nest_writers = Find(Writer).visit(to_unroll[0][0]) for op, i_stmts in nest_writers.items(): # Check safety of unrolling if op in [Assign, IMul, IDiv]: @@ -286,7 +286,7 @@ def _dissect(self, heuristics): for l, p in reversed(to_unroll): i_expr = [dcopy(i_expr) for i in range(l.size)] for i, e 
in enumerate(i_expr): - e_syms = FindInstances(Symbol).visit(e)[Symbol] + e_syms = Find(Symbol).visit(e)[Symbol] for s in e_syms: s.rank = tuple([r if r != l.dim else i for r in s.rank]) i_expr = ast_make_expr(Sum, i_expr) @@ -308,7 +308,7 @@ def find_save(target_expr, expr_info): # that will /not/ be pre-evaluated. To obtain this number, we # can exploit the linearity of the expression in the terms # depending on the linear loops. - syms = FindInstances(Symbol).visit(target_expr)[Symbol] + syms = Find(Symbol).visit(target_expr)[Symbol] inner = lambda s: any(r == expr_info.linear_dims[-1] for r in s.rank) nterms = len(set(s.symbol for s in syms if inner(s))) save = nterms * save_factor diff --git a/coffee/plan.py b/coffee/plan.py index 51fb999a..9504a42b 100644 --- a/coffee/plan.py +++ b/coffee/plan.py @@ -41,7 +41,7 @@ from .vectorizer import LoopVectorizer, VectStrategy from .expression import MetaExpr from .logger import log, warn, PERF_OK, PERF_WARN -from coffee.visitors import FindInstances, EstimateFlops +from coffee.visitors import Find, EstimateFlops from collections import defaultdict, OrderedDict import time @@ -70,7 +70,7 @@ def plan_cpu(self, opts): start_time = time.time() - kernels = FindInstances(FunDecl, stop_when_found=True).visit(self.ast)[FunDecl] + kernels = Find(FunDecl, stop_when_found=True).visit(self.ast)[FunDecl] if opts is None: opts = coffee.OptimizationLevel.retrieve(coffee.options['optimizations']) @@ -176,7 +176,7 @@ def plan_gpu(self): # The optimization passes are performed individually (i.e., "locally") for # each function (or "kernel") found in the provided AST - kernels = FindInstances(FunDecl, stop_when_found=True).visit(self.ast)[FunDecl] + kernels = Find(FunDecl, stop_when_found=True).visit(self.ast)[FunDecl] for kernel in kernels: info = visit(kernel, info_items=['decls', 'exprs']) diff --git a/coffee/rewriter.py b/coffee/rewriter.py index 1f68ff27..de9710ee 100644 --- a/coffee/rewriter.py +++ b/coffee/rewriter.py @@ -259,8 +259,7 @@ def expand(self, mode='standard', **kwargs): """ if mode == 'standard': - retval = FindInstances.default_retval() - symbols = FindInstances(Symbol).visit(self.stmt.rvalue, ret=retval)[Symbol] + symbols = Find(Symbol).visit(self.stmt.rvalue)[Symbol] # The heuristics privileges linear dimensions dims = self.expr_info.out_linear_dims if not dims or self.expr_info.dimension >= 2: @@ -340,8 +339,7 @@ def factorize(self, mode='standard', **kwargs): """ if mode == 'standard': - retval = FindInstances.default_retval() - symbols = FindInstances(Symbol).visit(self.stmt.rvalue, ret=retval)[Symbol] + symbols = Find(Symbol).visit(self.stmt.rvalue)[Symbol] # The heuristics privileges linear dimensions dims = self.expr_info.out_linear_dims if not dims or self.expr_info.dimension >= 2: @@ -432,8 +430,7 @@ def _reassociate(node, parent): def replacediv(self): """Replace divisions by a constant with multiplications.""" - retval = FindInstances.default_retval() - divisions = FindInstances(Div).visit(self.stmt.rvalue, ret=retval)[Div] + divisions = Find(Div).visit(self.stmt.rvalue)[Div] to_replace = {} for i in divisions: if isinstance(i.right, Symbol): @@ -457,8 +454,7 @@ def preevaluate(self): if not isinstance(stmt, (Incr, Decr, IMul, IDiv)): # Not a reduction expression, give up return - retval = FindInstances.default_retval() - expr_syms = FindInstances(Symbol).visit(stmt.rvalue, ret=retval)[Symbol] + expr_syms = Find(Symbol).visit(stmt.rvalue)[Symbol] reduction_loops = expr_info.out_linear_loops_info if any([not is_perfect_loop(l) for 
l, p in reduction_loops]): # Unsafe if not a perfect loop nest @@ -466,7 +462,7 @@ def preevaluate(self): # The following check is because it is unsafe to simplify if non-loop or # non-constant dimensions are present hoisted_stmts = self.hoisted.all_stmts - hoisted_syms = [FindInstances(Symbol).visit(h)[Symbol] for h in hoisted_stmts] + hoisted_syms = [Find(Symbol).visit(h)[Symbol] for h in hoisted_stmts] hoisted_dims = [s.rank for s in flatten(hoisted_syms)] hoisted_dims = set([r for r in flatten(hoisted_dims) if not is_const_dim(r)]) if any(d not in expr_info.dims for d in hoisted_dims): @@ -474,9 +470,7 @@ def preevaluate(self): # not being a loop iteration variable return for i, (l, p) in enumerate(reduction_loops): - retval = SymbolDependencies.default_retval() - syms_dep = SymbolDependencies().visit(l, ret=retval, - **SymbolDependencies.default_args) + syms_dep = SymbolDependencies().visit(l, **SymbolDependencies.default_args) if not all([tuple(syms_dep[s]) == expr_info.loops and s.dim == len(expr_info.loops) for s in expr_syms if syms_dep[s]]): # A sufficient (although not necessary) condition for loop reduction to @@ -490,10 +484,9 @@ def preevaluate(self): if not all([s.symbol in self.hoisted for s in reducible_syms]): return # Replace hoisted assignments with reductions - finder = FindInstances(Assign, stop_when_found=True, with_parent=True) + finder = Find(Assign, stop_when_found=True, with_parent=True) for hoisted_loop in self.hoisted.all_loops: - retval = FindInstances.default_retval() - for assign, parent in finder.visit(hoisted_loop, ret=retval)[Assign]: + for assign, parent in finder.visit(hoisted_loop)[Assign]: sym, expr = assign.children decl = self.hoisted[sym.symbol].decl if sym.symbol in [s.symbol for s in reducible_syms]: diff --git a/coffee/utils.py b/coffee/utils.py index fb7f3135..dacf7c88 100644 --- a/coffee/utils.py +++ b/coffee/utils.py @@ -107,7 +107,7 @@ def ast_update_ofs(node, ofs, **kwargs): """ increase = kwargs.get('increase', False) - symbols = FindInstances(Symbol).visit(node, ret=FindInstances.default_retval())[Symbol] + symbols = Find(Symbol).visit(node)[Symbol] for s in symbols: new_offset = [] for r, o in zip(s.rank, s.offset): @@ -135,9 +135,7 @@ def ast_update_rank(node, mapper): transformed into 'A[j] = B[j]' """ - retval = FindInstances.default_retval() - FindInstances(Symbol).visit(node, ret=retval) - for s in retval[Symbol]: + for s in Find(Symbol).visit(node)[Symbol]: s.rank = tuple([r if r not in mapper else mapper[r] for r in s.rank]) @@ -353,7 +351,7 @@ def in_written(node, key='default'): raise RuntimeError("Illegal key=%s for in_written" % key) found = [] - writers = FindInstances(Writer).visit(node) + writers = Find(Writer).visit(node) for type, stmts in writers.items(): for stmt in stmts: found.append(gen_key(stmt.lvalue)) @@ -380,10 +378,10 @@ def in_read(node, key='default'): raise RuntimeError("Illegal key=%s for in_read" % key) found = [] - writers = FindInstances(Writer).visit(node) + writers = Find(Writer).visit(node) for type, stmts in writers.items(): for stmt in stmts: - reads = FindInstances(Symbol).visit(stmt.rvalue)[Symbol] + reads = Find(Symbol).visit(stmt.rvalue)[Symbol] found.extend([gen_key(s) for s in reads]) return found @@ -709,7 +707,7 @@ def __init__(self, node): :param node: root of the AST visited to initialize the ExpressionGraph. 
""" self.deps = nx.DiGraph() - writes = FindInstances(Writer).visit(node, ret=FindInstances.default_retval()) + writes = Find(Writer).visit(node) for type, nodes in writes.items(): for n in nodes: if isinstance(n.rvalue, EmptyStatement): @@ -718,8 +716,7 @@ def __init__(self, node): def add_dependency(self, sym, expr): """Add dependency between ``sym`` and symbols appearing in ``expr``.""" - retval = FindInstances.default_retval() - expr_symbols = FindInstances(Symbol).visit(expr, ret=retval)[Symbol] + expr_symbols = Find(Symbol).visit(expr)[Symbol] for es in expr_symbols: self.deps.add_edge(sym.symbol, es.symbol) @@ -736,8 +733,7 @@ def is_read(self, expr, target_sym=None): """Return True if any symbols in ``expr`` is read by ``target_sym``, False otherwise. If ``target_sym`` is None, Return True if any symbols in ``expr`` are read by at least one symbol, False otherwise.""" - retval = FindInstances.default_retval() - input_syms = FindInstances(Symbol).visit(expr, ret=retval)[Symbol] + input_syms = Find(Symbol).visit(expr)[Symbol] for s in input_syms: if s.symbol not in self.deps: continue @@ -752,8 +748,7 @@ def is_written(self, expr, target_sym=None): """Return True if any symbols in ``expr`` is written by ``target_sym``, False otherwise. If ``target_sym`` is None, Return True if any symbols in ``expr`` are written by at least one symbol, False otherwise.""" - retval = FindInstances.default_retval() - input_syms = FindInstances(Symbol).visit(expr, ret=retval)[Symbol] + input_syms = Find(Symbol).visit(expr)[Symbol] for s in input_syms: if s.symbol not in self.deps: continue @@ -846,7 +841,7 @@ def remove_unused_decls(node): assert isinstance(node, Block) - decls = FindInstances(Decl, with_parent=True).visit(node)[Decl] + decls = Find(Decl, with_parent=True).visit(node)[Decl] references = visit(node, info_items=['symbol_refs'])['symbol_refs'] for d, p in decls: if len(references[d.sym.symbol]) == 1: diff --git a/coffee/vectorizer.py b/coffee/vectorizer.py index 95650c29..d22e6444 100644 --- a/coffee/vectorizer.py +++ b/coffee/vectorizer.py @@ -43,7 +43,7 @@ from .utils import * from . 
import system from .logger import warn -from coffee.visitors import FindInstances +from coffee.visitors import Find class VectStrategy(object): @@ -149,7 +149,7 @@ def autovectorize(self, p_dim=-1): def _pad(self, p_dim, decls, fors, symbols_dep, symbols_mode, symbol_refs): """Apply padding.""" - to_invert = FindInstances(Invert).visit(self.header)[Invert] + to_invert = Find(Invert).visit(self.header)[Invert] # Loop increments different than 1 are unsupported if any([l.increment != 1 for l, _ in flatten(fors)]): @@ -314,7 +314,7 @@ def _align_data(self, p_dim, decls): should_round = False break - symbols = [sym] + FindInstances(Symbol).visit(expr)[Symbol] + symbols = [sym] + Find(Symbol).visit(expr)[Symbol] symbols = [s for s in symbols if s.rank and any(r == l.dim for r in s.rank)] # Condition D: the access pattern must be accessible @@ -369,8 +369,7 @@ def _align_data(self, p_dim, decls): def _transpose_layout(self, decls): dim = self.loop.dim - retval = FindInstances.default_retval() - symbols = FindInstances(Symbol).visit(self.loop, ret=retval)[Symbol] + symbols = Find(Symbol).visit(self.loop)[Symbol] symbols = [s for s in symbols if any(r == dim for r in s.rank) and s.dim > 1] # Cannot handle arrays with more than 2 dimensions diff --git a/coffee/visitors/inspectors.py b/coffee/visitors/inspectors.py index 90948295..82ff83bf 100644 --- a/coffee/visitors/inspectors.py +++ b/coffee/visitors/inspectors.py @@ -7,7 +7,7 @@ __all__ = ["FindInnerLoops", "CheckPerfectLoop", "CountOccurences", "FindLoopNests", "FindCoffeeExpressions", "SymbolReferences", "SymbolDependencies", "SymbolModes", "SymbolDeclarations", - "SymbolVisibility", "FindInstances", "FindExpression"] + "SymbolVisibility", "Find", "FindExpression"] class FindInnerLoops(Visitor): @@ -542,7 +542,7 @@ def visit_Node(self, o, ret=None, *args, **kwargs): return ret -class FindInstances(Visitor): +class Find(Visitor): @classmethod def default_retval(cls): @@ -561,7 +561,7 @@ def __init__(self, types, stop_when_found=False, with_parent=False): self.types = types self.stop_when_found = stop_when_found self.with_parent = with_parent - super(FindInstances, self).__init__() + super(Find, self).__init__() def useless_traversal(self, o): """ From 963244441153ce8495aa22f7762ac95fd6c25c77 Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Tue, 6 Dec 2016 16:00:09 +0100 Subject: [PATCH 5/5] Search ArrayInit in symbol references visitor --- coffee/visitors/inspectors.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/coffee/visitors/inspectors.py b/coffee/visitors/inspectors.py index 82ff83bf..d54e49cb 100644 --- a/coffee/visitors/inspectors.py +++ b/coffee/visitors/inspectors.py @@ -278,6 +278,11 @@ def visit_Symbol(self, o, ret=None, parent=None): ret[o.symbol].append((o, parent)) return ret + def visit_ArrayInit(self, o, ret=None, *args, **kwargs): + for entry in o.values: + ret = self.visit(entry, ret=ret, *args, **kwargs) + return ret + def visit_object(self, o, ret=None, *args, **kwargs): # Identity return ret
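This last hunk presumably matters for the clean-up passes introduced earlier in the series: a symbol mentioned only inside an array initializer must still count as referenced, otherwise a pass such as `remove_unused_decls` (patch 3) could drop a declaration whose value another initializer still reads. A minimal self-contained sketch of the intended behavior, with hypothetical stand-in classes rather than coffee's actual visitor machinery:

    from collections import defaultdict

    class Symbol(object):
        def __init__(self, symbol):
            self.symbol = symbol

    class ArrayInit(object):
        def __init__(self, values):
            self.values = values  # initializer entries (nodes or constants)

    def symbol_refs(node, ret=None):
        # Map symbol names to the occurrences that reference them
        ret = defaultdict(list) if ret is None else ret
        if isinstance(node, Symbol):
            ret[node.symbol].append(node)
        elif isinstance(node, ArrayInit):
            for entry in node.values:  # mirrors the new visit_ArrayInit hook
                symbol_refs(entry, ret)
        return ret

    refs = symbol_refs(ArrayInit([Symbol('x'), 0.0]))
    assert 'x' in refs  # 'x' now counts as referenced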