Skip to content

Commit

Permalink
simplify inlining array accesses (#735)
Browse files Browse the repository at this point in the history
* simplify inlining array accesses

* add a test

* make the code easier to read

* Add another test
  • Loading branch information
isuruf authored Jan 25, 2023
1 parent 95a9dc4 commit 5cb9dc1
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 35 deletions.
40 changes: 6 additions & 34 deletions loopy/transform/callable.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
from loopy.kernel.instruction import (CallInstruction, MultiAssignmentBase,
Assignment, CInstruction, _DataObliviousInstruction)
from loopy.symbolic import (
simplify_using_aff,
RuleAwareIdentityMapper,
RuleAwareSubstitutionMapper, SubstitutionRuleMappingContext)
from loopy.kernel.function_interface import (
Expand Down Expand Up @@ -130,43 +129,16 @@ def __init__(self, rule_mapping_context, caller_knl,

def map_subscript(self, expr, expn_state):
if expr.aggregate.name in self.callee_knl.arg_dict:
from loopy.symbolic import get_start_subscript_from_sar
from loopy.symbolic import simplify_via_aff
from pymbolic.primitives import Subscript, Variable
from pymbolic import substitute

sar = self.callee_arg_to_call_param[expr.aggregate.name] # SubArrayRef

callee_arg = self.callee_knl.arg_dict[expr.aggregate.name]
if sar.subscript.aggregate.name in self.caller_knl.arg_dict:
caller_arg = self.caller_knl.arg_dict[sar.subscript.aggregate.name]
else:
caller_arg = self.caller_knl.temporary_variables[
sar.subscript.aggregate.name]

flatten_index = 0
for i, idx in enumerate(get_start_subscript_from_sar(sar,
self.caller_knl).index_tuple):
flatten_index += idx*caller_arg.dim_tags[i].stride

flatten_index += sum(
idx * self.rec(tag.stride, expn_state)
for idx, tag in zip(self.rec(expr.index_tuple, expn_state),
callee_arg.dim_tags))

flatten_index = simplify_via_aff(flatten_index)

new_indices = []
for dim_tag in caller_arg.dim_tags:
if dim_tag.stride != 0:
ind = flatten_index // dim_tag.stride
else:
# argument has 0-stride i.e. doesn't matter how we index into it.
ind = 0
flatten_index -= (dim_tag.stride * ind)
new_indices.append(ind)

new_indices = tuple(simplify_using_aff(
self.callee_knl, i) for i in new_indices)
index_tuple = self.rec(expr.index_tuple, expn_state)
subs_map = {iname: idx for idx, iname in
zip(index_tuple, sar.swept_inames)}
new_indices = tuple(substitute(idx, subs_map) for idx in
sar.subscript.index_tuple)

return Subscript(Variable(sar.subscript.aggregate.name), new_indices)
else:
Expand Down
58 changes: 57 additions & 1 deletion test/test_callables.py
Original file line number Diff line number Diff line change
Expand Up @@ -1423,7 +1423,7 @@ def test_inline_stride():
dtype=np.float64,
shape=("n", "n")),
...],
assumptions="n>=1",
assumptions="n>=2",
)
knl = lp.merge([parent_knl, child_knl])
knl = lp.inline_callable_kernel(knl, "linear_combo")
Expand Down Expand Up @@ -1458,6 +1458,62 @@ def test_inline_predicate():
assert code.count("if (a)") == 1


def test_subarray_ref_with_repeated_indices(ctx_factory):
# https://github.com/inducer/loopy/pull/735#discussion_r1071690388

ctx = ctx_factory()
cq = cl.CommandQueue(ctx)
child_knl = lp.make_function(
["{[i]: 0<=i<10}"],
"""
g[i] = 1
""", name="ones")

parent_knl = lp.make_kernel(
["{[i]:0<=i<10}", "{[j]: 0<=j<10}"],
"""
z[i, j] = 0 {id = a}
[i]: z[i, i] = ones() {dep=a,dup=i}
""",
kernel_data=[
lp.GlobalArg(
name="z",
dtype=np.float64,
is_input=False,
shape=(10, 10)),
...],
)
knl = lp.merge([parent_knl, child_knl])
knl = lp.inline_callable_kernel(knl, "ones")
evt, (z_dev,) = knl(cq)
assert np.allclose(z_dev.get(), np.eye(10))


def test_inline_constant_access():
child_knl = lp.make_function(
[],
"""
g[0] = 2*e[0] + 3*f[0] {id=a}
g[1] = 2*e[1] + 3*f[1] {dep=a}
""", name="linear_combo")
parent_knl = lp.make_kernel(
["{[j]:0<=j<n}", "{[i]:0<=i<n}"],
"""
[i]: z[i, j] = linear_combo([i]: x[i, j], [i]: y[i,j])
""",
kernel_data=[
lp.GlobalArg(
name="x, y, z",
dtype=np.float64,
shape=(3, "n")),
...],
)
knl = lp.merge([parent_knl, child_knl])
knl = lp.inline_callable_kernel(knl, "linear_combo")
knl = lp.tag_array_axes(knl, ["x", "y", "z"], "sep,C")
lp.generate_code_v2(knl).device_code()


if __name__ == "__main__":
if len(sys.argv) > 1:
exec(sys.argv[1])
Expand Down

0 comments on commit 5cb9dc1

Please sign in to comment.