Commit
Enable parameter studies in meshmode.
nkoskelo committed Jan 28, 2025
1 parent 7fd9109 commit d563426
Showing 2 changed files with 176 additions and 13 deletions.
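For context, the commit's goal is to let one compiled call sweep a whole parameter study along a tagged array axis. Below is a minimal usage sketch. The helpers pack_for_parameter_study and unpack_parameter_study and their signatures are assumptions based on the companion arraycontext branch this commit imports from; only ParamStudyPytatoPyOpenCLArrayContext and ParameterStudyAxisTag actually appear in the diff.

import numpy as np
import pyopencl as cl

from arraycontext.parameter_study import (
    ParameterStudyAxisTag, pack_for_parameter_study, unpack_parameter_study)
from meshmode.array_context import PytatoPyOpenCLArrayContext

ctx = cl.create_some_context()
queue = cl.CommandQueue(ctx)
actx = PytatoPyOpenCLArrayContext(queue)

# Ten realizations of the same computation, varying only in their input.
inputs = [actx.from_numpy(np.random.rand(1000)) for _ in range(10)]

# Hypothetical helper: stacks the inputs along a new axis tagged with
# ParameterStudyAxisTag so one compiled kernel serves all realizations.
packed = pack_for_parameter_study(actx, ParameterStudyAxisTag, *inputs)

compiled = actx.compile(lambda x: 2*x + 1)
results = unpack_parameter_study(compiled(packed), ParameterStudyAxisTag)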
182 changes: 169 additions & 13 deletions meshmode/array_context.py
@@ -42,6 +42,12 @@
PyOpenCLArrayContext as PyOpenCLArrayContextBase,
PytatoPyOpenCLArrayContext as PytatoPyOpenCLArrayContextBase,
)

from arraycontext.parameter_study import (
ParamStudyPytatoPyOpenCLArrayContext as ParameterStudyPytatoPyOpenCLArrayContextBase,
ParameterStudyAxisTag,
)

from arraycontext.pytest import (
_PytestPyOpenCLArrayContextFactoryWithClass,
_PytestPytatoPyOpenCLArrayContextFactory,
@@ -61,7 +67,8 @@
DiscretizationTopologicalDimAxisTag,
DiscretizationAmbientDimAxisTag,
DiscretizationFlattenedDOFAxisTag,
DiscretizationEntityAxisTag)
DiscretizationEntityAxisTag,
DiscretizationNICKSAxisTag)
from dataclasses import dataclass

if TYPE_CHECKING:
@@ -282,7 +289,7 @@ def transform_loopy_program(self, t_unit):

# {{{ pytato pyopencl array context subclass

class PytatoPyOpenCLArrayContext(PytatoPyOpenCLArrayContextBase):
class PytatoPyOpenCLArrayContext(ParameterStudyPytatoPyOpenCLArrayContextBase):
def transform_dag(self, dag):
dag = super().transform_dag(dag)

@@ -605,7 +612,7 @@ def cached_data_wrapper_if_present(ary):
return dag


class SingleGridWorkBalancingPytatoArrayContext(PytatoPyOpenCLArrayContextBase):
class SingleGridWorkBalancingPytatoArrayContext(ParameterStudyPytatoPyOpenCLArrayContextBase):
"""
A :class:`PytatoPyOpenCLArrayContext` that parallelizes work in an OpenCL
kernel so that the work
@@ -933,6 +940,12 @@ def fuse_same_discretization_entity_loops(knl):
True,
orig_knl)

from meshmode.transform_metadata import DiscretizationNICKSAxisTag
knl = _fuse_loops_over_a_discr_entity(knl, DiscretizationNICKSAxisTag,
"inick",
True,
orig_knl)

return knl


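For orientation, _fuse_loops_over_a_discr_entity fuses loops whose inames carry the same discretization-entity tag, now including the new inick axis. A toy illustration of the underlying fusion idea in loopy, not the helper's actual implementation:

import loopy as lp

t_unit = lp.make_kernel(
    "{[i, j]: 0 <= i, j < n}",
    """
    a[i] = 2*i
    b[j] = 3*j
    """)

# Fusion here amounts to making both statements share one loop: rename
# one iname into the other, which loopy permits since the domains match.
t_unit = lp.rename_iname(t_unit, "j", "i", existing_ok=True)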
@@ -1043,6 +1056,13 @@ def _get_iel_to_idofs(kernel):
.tags_of_type(DiscretizationDimAxisTag))
}

inick_inames = {iname
for iname in kernel.all_inames()
if (kernel
.inames[iname]
.tags_of_type(DiscretizationNICKSAxisTag))
}

iel_to_idofs = {iel: set() for iel in iel_inames}

for insn in kernel.instructions:
@@ -1058,6 +1078,19 @@ def _get_iel_to_idofs(kernel):
raise NotImplementedError(f"The <iel> loop {insn.within_inames}"
" does not appear as a singly nested"
" loop.")
elif (len(insn.within_inames) == 1
and insn.within_inames <= inick_inames):
inick, = insn.within_inames
if all(kernel.id_to_insn[el_insn].within_inames == insn.within_inames
for el_insn in kernel.iname_to_insns()[inick]):
# a loop over only the parameter-study axis: the <iel, idof>
# patterns below cannot classify it, so signal the caller to
# fall back via this sentinel error.
raise NotImplementedError("Just NICKS LOOP")
else:
raise NotImplementedError(f"The <iel> loop {insn.within_inames}"
" does not appear as a singly nested"
" loop.")
elif ((len(insn.within_inames) == 2)
and (len(insn.within_inames & iel_inames) == 1)
and (len(insn.within_inames & idof_inames) == 1)):
@@ -1074,8 +1107,8 @@ def _get_iel_to_idofs(kernel):
elif ((len(insn.within_inames) > 2)
and (len(insn.within_inames & iel_inames) == 1)
and (len(insn.within_inames & idof_inames) == 1)
and (len(insn.within_inames & (idim_inames | iface_inames))
== (len(insn.within_inames) - 2))):
and (len(insn.within_inames & (idim_inames | iface_inames | inick_inames))
== (len(insn.within_inames) - 2))):
iel, = insn.within_inames & iel_inames
idof, = insn.within_inames & idof_inames
iel_to_idofs[iel].add(idof)
@@ -1086,6 +1119,8 @@ def _get_iel_to_idofs(kernel):
else:
raise NotImplementedError("Could not fit into <iel,idof,iface>"
" loop nest pattern.")
else:
raise NotImplementedError(f"Cannot fit loop nest '{insn.within_inames}'"
" into known set of loop-nest patterns.")
@@ -1132,7 +1167,8 @@ def _prepare_kernel_for_parallelization(kernel):
DiscretizationAmbientDimAxisTag: "idim",
DiscretizationTopologicalDimAxisTag: "idim",
DiscretizationFlattenedDOFAxisTag: "imsh_nodes",
DiscretizationFaceAxisTag: "iface"}
DiscretizationFaceAxisTag: "iface",
DiscretizationNICKSAxisTag: "inick",} # treat as new..
import loopy as lp
from loopy.match import ObjTagged

@@ -1144,7 +1180,6 @@ def _prepare_kernel_for_parallelization(kernel):
for insn in kernel.instructions:
inames = insn.within_inames | insn.reduction_inames()
ensm_buckets.setdefault(tuple(sorted(inames)), set()).add(insn.id)

# FIXME: Dependency violation is a big concern here
# Waiting on the loopy feature: https://github.com/inducer/loopy/issues/550

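The setdefault line above is the whole bucketing scheme: instructions are keyed by the exact set of inames they nest in, so einsums living in the same loop nest become one fusion candidate. A self-contained rendering of the idiom:

from typing import AbstractSet, Dict, Mapping, Set, Tuple

def bucket_by_inames(
        insn_to_inames: Mapping[str, AbstractSet[str]]
        ) -> Dict[Tuple[str, ...], Set[str]]:
    # Key each instruction by the sorted tuple of inames it nests in;
    # identical keys mean identical loop nests, hence fusable groups.
    buckets: Dict[Tuple[str, ...], Set[str]] = {}
    for insn_id, inames in insn_to_inames.items():
        buckets.setdefault(tuple(sorted(inames)), set()).add(insn_id)
    return buckets

print(bucket_by_inames({"a": {"iel", "idof"}, "b": {"idof", "iel"}}))
# {('idof', 'iel'): {'a', 'b'}}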
@@ -1261,8 +1296,8 @@ def __init__(
self, queue: "cl.CommandQueue", allocator=None, *,
use_memory_pool: Optional[bool] = None,
compile_trace_callback: Optional[Callable[[Any, str, Any], None]] = None,
use_axis_tag_inference_fallback: bool = False,
use_einsum_inference_fallback: bool = False,
use_axis_tag_inference_fallback: bool = False,  # allow falling back when axis-tag inference fails
use_einsum_inference_fallback: bool = False,  # allow falling back when einsum inference fails

# do not use: only for testing
_force_svm_arg_limit: Optional[int] = None,
@@ -1282,6 +1317,7 @@ def __init__(
def transform_dag(self, dag):
import pytato as pt

# {{{ Remove FEMEinsumTags that might have been propagated

# TODO: Is this too hacky?
@@ -1419,13 +1455,11 @@ def materialize_inverse_mass_inputs(expr):
# {{{ materialize all einsums

def materialize_all_einsums_or_reduces(expr):
from pytato.raising import (index_lambda_to_high_level_op,
ReduceOp)

if isinstance(expr, pt.Einsum):
return expr.tagged(pt.tags.ImplStored())
elif (isinstance(expr, pt.IndexLambda)
and isinstance(index_lambda_to_high_level_op(expr), ReduceOp)):
and (len(expr.var_to_reduction_descr) > 0)):
return expr.tagged(pt.tags.ImplStored())
else:
return expr
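The rewritten branch above drops the round-trip through pytato.raising and simply asks whether the IndexLambda carries any reduction descriptors. A small check of that property, assuming pt.sum lowers to an IndexLambda as in current pytato:

import numpy as np
import pytato as pt

x = pt.make_placeholder("x", shape=(10, 4), dtype=np.float64)
y = pt.sum(x, axis=1)  # a reduction over axis 1

assert isinstance(y, pt.IndexLambda)
assert len(y.var_to_reduction_descr) > 0  # exactly what the predicate tests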
@@ -1436,6 +1470,23 @@ def materialize_all_einsums_or_reduces(expr):

# }}}

# {{{ attach Nick's discretization tag to axes that carry a parameter-study tag

from arraycontext.parameter_study import ParameterStudyAxisTag
from meshmode.transform_metadata import DiscretizationNICKSAxisTag

def add_nicks_disc_tag(expr):
if isinstance(expr, pt.Array):
new_expr = expr
for iaxis, axis in enumerate(expr.axes):
if axis.tags_of_type(ParameterStudyAxisTag):
new_expr = new_expr.with_tagged_axis(iaxis, [DiscretizationNICKSAxisTag()])
return new_expr
return expr

dag = pt.transform.map_and_copy(dag, add_nicks_disc_tag)
# }}}

# {{{ infer axis types

from meshmode.pytato_utils import unify_discretization_entity_tags
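The add_nicks_disc_tag pass above exists because the unification step imported here propagates only DiscretizationEntityAxisTag subclasses; a parameter-study axis must therefore carry the NICKS tag as well to survive inference. A toy sketch of the resulting double tagging (the ParameterStudyAxisTag constructor field is an assumption; the diff only shows a .size attribute):

import numpy as np
import pytato as pt

from arraycontext.parameter_study import ParameterStudyAxisTag
from meshmode.transform_metadata import DiscretizationNICKSAxisTag

x = pt.make_placeholder("x", shape=(128, 4), dtype=np.float64)
x = x.with_tagged_axis(1, [ParameterStudyAxisTag(size=4)])  # field name assumed
x = x.with_tagged_axis(1, [DiscretizationNICKSAxisTag()])

# This is the invariant transform_loopy_program asserts later on:
assert x.axes[1].tags_of_type(ParameterStudyAxisTag)
assert x.axes[1].tags_of_type(DiscretizationNICKSAxisTag)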
@@ -1532,6 +1583,7 @@ def _get_rid_of_broadcasts_from_einsum(expr):

# }}}

# {{{ remove any PartID tags

# FIXME: Remove after https://github.com/inducer/pytato/pull/393 goes in
@@ -1552,6 +1604,8 @@ def remove_part_id_tags(expr):

# }}}

# {{{ attach FEMEinsumTag tags

dag_outputs = frozenset(dag._data.values())
@@ -1666,9 +1720,18 @@ def transform_loopy_program(self, t_unit):
# {{{ fallback: if the inames are not inferred which mesh entity they
# iterate over.

has_param_study = False
for iname in knl.all_inames():
if knl.iname_tags_of_type(iname, ParameterStudyAxisTag):
has_param_study = True
# also tag the iname as a discretization entity axis so that
# the rest of the pipeline can treat the study axis like any other
knl.inames[iname] = knl.inames[iname].tagged([DiscretizationNICKSAxisTag()])

if not knl.iname_tags_of_type(iname, DiscretizationEntityAxisTag):
if has_param_study:
return super().transform_loopy_program(original_t_unit)
if not self.use_axis_tag_inference_fallback:
raise AxisTagInferenceError("Unable to infer axis tags.")
else:
warn(f"[{knl.name}]: Falling back to a slower transformation"
@@ -1677,6 +1740,12 @@ def transform_loopy_program(self, t_unit):
stacklevel=2)
return super().transform_loopy_program(original_t_unit)

if __debug__:
for iname in knl.all_inames():
if knl.iname_tags_of_type(iname, ParameterStudyAxisTag):
assert knl.iname_tags_of_type(iname, DiscretizationNICKSAxisTag)

for insn in knl.instructions:
for assignee in insn.assignee_var_names():
var = knl.get_var_descriptor(assignee)
@@ -1705,11 +1774,31 @@ def transform_loopy_program(self, t_unit):
with ProcessLogger(logger, "Loop Fusion"):
knl = fuse_same_discretization_entity_loops(knl)

# }}}

# {{{ ensure NICKS inames are double-tagged

for iname in knl.all_inames():
if knl.iname_tags_of_type(iname, ParameterStudyAxisTag):
assert knl.iname_tags_of_type(iname, DiscretizationNICKSAxisTag)

# }}}

# {{{ align kernels for fused einsums

knl = _prepare_kernel_for_parallelization(knl)


# {{{ ensure NICKS inames are double-tagged after parallelization prep

for iname in knl.all_inames():
if knl.iname_tags_of_type(iname, ParameterStudyAxisTag):
# re-attach the entity axis tag in case preparation dropped it
knl.inames[iname] = knl.inames[iname].tagged([DiscretizationNICKSAxisTag()])

for iname in knl.all_inames():
if knl.iname_tags_of_type(iname, ParameterStudyAxisTag):
assert knl.iname_tags_of_type(iname, DiscretizationNICKSAxisTag)

# }}}

knl = _combine_einsum_domains(knl)

# }}}
@@ -1792,6 +1881,9 @@ def transform_loopy_program(self, t_unit):
except NotImplementedError as err:
if knl.tags_of_type(FromArrayContextCompile):
raise err
elif err == "Just NICKS LOOP":
return super().transform_loopy_program(original_t_unit)
return original_t_unit # We are going to do nothing in this case.
else:
warn(f"[{knl.name}]: FusionContractorArrayContext."
"transform_loopy_program not broad enough (yet)."
@@ -1818,6 +1910,7 @@ def transform_loopy_program(self, t_unit):
# {{{ Parallelization strategy: Use feinsum

t_unit = t_unit.with_kernel(knl)

del knl

if False and t_unit.default_entrypoint.tags_of_type(FromArrayContextCompile):
@@ -1877,12 +1970,75 @@ def transform_loopy_program(self, t_unit):
else:
knl = lp.split_iname(knl, iel, 32,
outer_tag="g.0", inner_tag="l.0")
inick_inames = {iname
for iname in knl.all_inames()
if (knl
.inames[iname]
.tags_of_type(DiscretizationNICKSAxisTag))
}
if inick_inames:
for iname in inick_inames:
tags = knl.inames[iname].tags_of_type(ParameterStudyAxisTag)
if tags:
#knl = lp.untag_inames(knl, iname, ParameterStudyAxisTag)
my_tag = next(iter(tags))
size = my_tag.size
if size > 8 and 0:  # splitting is disabled for now
# We are going to split this.
knl = lp.split_iname(knl, iname, 8, outer_tag="l.2",
inner_tag="ord")
else:
knl = lp.tag_inames(knl, {iname: "ord" for iname in inick_inames})

t_unit = t_unit.with_kernel(knl)

# }}}

self.transform_loopy_cache.store_if_not_present(original_t_unit, t_unit)
# {{{ Stats Collection (Disabled)

if inick_inames and 0:
with ProcessLogger(logger, "Counting Kernel Ops"):
mem_map = lp.get_mem_access_map(t_unit, subgroup_size=32)
grouped_map = mem_map.group_by("mtype", "dtype", "direction")

"""
f32_global_id = grouped_map[lp.MemAccess(mtype="global",
dtype=np.float32,
direction="load")
]
f32_global_st = grouped_map[lp.MemAccess(mtype="global",
dtype=np.float32,
direction="store")
]
"""

f64_local_id = grouped_map[lp.MemAccess(mtype="local",
dtype=np.float64,
direction="load")
]


f64_local_st = grouped_map[lp.MemAccess(mtype="local",
dtype=np.float64,
direction="store")
]

f64_global_id = grouped_map[lp.MemAccess(mtype="global",
dtype=np.float64,
direction="load")
]


f64_global_st = grouped_map[lp.MemAccess(mtype="global",
dtype=np.float64,
direction="store")
]


# }}}
#self.transform_loopy_cache.store_if_not_present(original_t_unit, t_unit)

return t_unit

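A design note on the "Just NICKS LOOP" sentinel used in this file: dispatching on an exception's message string is brittle, and a dedicated exception type would make the fallback explicit. A hypothetical refactor sketch, not what the commit does:

class ParameterStudyLoopOnly(NotImplementedError):
    """Hypothetical: an instruction nests only inside the parameter-study
    loop, so the <iel, idof> pattern matcher cannot classify it."""

def classify_loops(kernel):
    raise ParameterStudyLoopOnly("only the study iname is present")

def transform(kernel, fallback):
    try:
        return classify_loops(kernel)
    except ParameterStudyLoopOnly:
        # Type-based dispatch survives any rewording of the message.
        return fallback(kernel)

print(transform("knl", lambda k: f"fell back for {k}"))  # fell back for knl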
7 changes: 7 additions & 0 deletions meshmode/transform_metadata.py
@@ -72,6 +72,13 @@ class DiscretizationEntityAxisTag(UniqueTag):
the axis indexes over.
"""

@tag_dataclass
class DiscretizationNICKSAxisTag(DiscretizationEntityAxisTag):
"""
A tag marking Nick's parameter-study axes so that the discretization
machinery can route them through the compiler written by Kausik.
"""


@tag_dataclass
class DiscretizationElementAxisTag(DiscretizationEntityAxisTag):
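A quick check of what the subclassing buys: filters that ask for DiscretizationEntityAxisTag, as the axis-inference machinery does, also match the new tag, and tag_dataclass gives value equality so repeated tagging deduplicates:

from meshmode.transform_metadata import (
    DiscretizationEntityAxisTag, DiscretizationNICKSAxisTag)

tag = DiscretizationNICKSAxisTag()

# Entity-tag filters such as tags_of_type(DiscretizationEntityAxisTag)
# pick up NICKS axes through this subclass relationship.
assert isinstance(tag, DiscretizationEntityAxisTag)

# tag_dataclass (a frozen, eq dataclass) makes instances value-equal.
assert DiscretizationNICKSAxisTag() == DiscretizationNICKSAxisTag()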
