Skip to content

Commit

Permalink
fix einsum subscript parsing for specs with spaces
Browse files Browse the repository at this point in the history
  • Loading branch information
kaushikcfd authored and inducer committed Feb 15, 2023
1 parent 5cb9dc1 commit 9fc33ae
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 0 deletions.
3 changes: 3 additions & 0 deletions loopy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,6 +532,9 @@ def make_einsum(spec, arg_names, **knl_creation_kwargs):
arg_spec, out_spec = spec.split("->")
arg_specs = arg_spec.split(",")

out_spec = out_spec.strip()
arg_specs = [arg_spec.strip() for arg_spec in arg_specs]

if len(arg_names) != len(arg_specs):
raise ValueError(
f"Number of arg names ({arg_names}) should match the number "
Expand Down
10 changes: 10 additions & 0 deletions test/test_loopy.py
Original file line number Diff line number Diff line change
Expand Up @@ -3473,6 +3473,16 @@ def test_type_inference_of_clbls_in_substitutions(ctx_factory):
np.testing.assert_allclose(out.get(), np.abs(10.0*(np.arange(10)-5)))


def test_einsum_parsing(ctx_factory):
ctx = ctx_factory()

# See <https://github.com/inducer/loopy/issues/753>
knl = lp.make_einsum("ik, kj -> ij", ["A", "B"])
knl = lp.add_dtypes(knl, {"A": np.float32, "B": np.float32})
lp.auto_test_vs_ref(knl, ctx, knl,
parameters={"Ni": 10, "Nj": 10, "Nk": 10})


if __name__ == "__main__":
if len(sys.argv) > 1:
exec(sys.argv[1])
Expand Down

0 comments on commit 9fc33ae

Please sign in to comment.