From 01801e2f143dd10c60d9e072ac3879b8e8d2d5a6 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Mon, 9 Jan 2023 14:51:32 +0530 Subject: [PATCH] introduce lp.decouple_domain --- loopy/__init__.py | 3 ++ loopy/transform/domain.py | 90 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 93 insertions(+) create mode 100644 loopy/transform/domain.py diff --git a/loopy/__init__.py b/loopy/__init__.py index 7491de6cc..3102174e5 100644 --- a/loopy/__init__.py +++ b/loopy/__init__.py @@ -122,6 +122,7 @@ from loopy.transform.parameter import assume, fix_parameters from loopy.transform.save import save_and_reload_temporaries from loopy.transform.add_barrier import add_barrier +from loopy.transform.domain import decouple_domain from loopy.transform.callable import (register_callable, merge, inline_callable_kernel, rename_callable) from loopy.transform.pack_and_unpack_args import pack_and_unpack_args_for_call @@ -251,6 +252,8 @@ "add_barrier", + "decouple_domain", + "register_callable", "merge", diff --git a/loopy/transform/domain.py b/loopy/transform/domain.py new file mode 100644 index 000000000..3cbdceb14 --- /dev/null +++ b/loopy/transform/domain.py @@ -0,0 +1,90 @@ +__copyright__ = "Copyright (C) 2023 Kaushik Kulkarni" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +""" + +__doc__ = """ +.. currentmodule:: loopy + +.. autofunction:: decouple_domain +""" + +import islpy as isl + +from loopy.translation_unit import for_each_kernel +from loopy.kernel import LoopKernel +from loopy.diagnostic import LoopyError +from collections.abc import Collection + + +@for_each_kernel +def decouple_domain(kernel: LoopKernel, + inames: Collection[str], + parent_inames: Collection[str]) -> LoopKernel: + r""" + Returns a copy of *kernel* with altered domains. The home domain of + *inames* i.e. :math:`\mathcal{D}^{\text{home}}({\text{inames}})` is + replaced with two domains :math:`\mathcal{D}_1` and :math:`\mathcal{D}_2`. + :math:`\mathcal{D}_1` is the domain with dimensions corresponding to *inames* + projected out and :math:`\mathcal{D}_2` is the domain with all the dimensions + other than the ones corresponding to *inames* projected out. + + .. note:: + + An error is raised if all the *inames* do not correspond to the same home + domain of *kernel*. + """ + + if not inames: + raise LoopyError("No inames were provided to decouple into" + " a different domain.") + + hdi = kernel.get_home_domain_index(next(iter(inames))) + for iname in inames: + if kernel.get_home_domain_index(iname) != hdi: + raise LoopyError("inames are not a part of the same home domain.") + + for parent_iname in parent_inames: + if parent_iname not in set(kernel.domains[hdi].get_var_dict()): + raise LoopyError(f"Parent iname '{parent_iname}' not a part of the" + f" corresponding home domain '{kernel.domains[hdi]}'.") + + all_dims = frozenset(kernel.domains[hdi].get_var_dict()) + D1 = kernel.domains[hdi] + D2 = kernel.domains[hdi] + + for iname in sorted(all_dims): + if iname in inames: + dt, pos = D1.get_var_dict()[iname] + D1 = D1.project_out(dt, pos, 1) + elif iname in parent_inames: + dt, pos = D2.get_var_dict()[iname] + if dt != isl.dim_type.param: + n_params = D2.dim(isl.dim_type.param) + D2 = D2.move_dims(isl.dim_type.param, n_params, dt, pos, 1) + else: + dt, pos = D2.get_var_dict()[iname] + D2 = D2.project_out(dt, pos, 1) + + new_domains = kernel.domains[:] + new_domains[hdi] = D1 + new_domains.append(D2) + kernel = kernel.copy(domains=new_domains) + return kernel