From 798c065b359f548d282ed3c5faba0fb4793c5d4f Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Fri, 22 May 2026 14:45:02 -0500 Subject: [PATCH 1/2] add PytatoParallelPyOpenCLArrayContext --- arraycontext/__init__.py | 7 +- arraycontext/impl/pytato/__init__.py | 112 ++++- arraycontext/impl/pytato/parallelize.py | 587 ++++++++++++++++++++++++ arraycontext/impl/pytato/utils.py | 77 ++++ arraycontext/pytest.py | 9 + test/test_arraycontext.py | 2 + test/test_pytato_arraycontext.py | 321 +++++++++++++ 7 files changed, 1112 insertions(+), 3 deletions(-) create mode 100644 arraycontext/impl/pytato/parallelize.py diff --git a/arraycontext/__init__.py b/arraycontext/__init__.py index 82028207..2140a7b2 100644 --- a/arraycontext/__init__.py +++ b/arraycontext/__init__.py @@ -80,7 +80,11 @@ from .impl.jax import EagerJAXArrayContext from .impl.numpy import NumpyArrayContext from .impl.pyopencl import PyOpenCLArrayContext -from .impl.pytato import PytatoJAXArrayContext, PytatoPyOpenCLArrayContext +from .impl.pytato import ( + PytatoJAXArrayContext, + PytatoParallelPyOpenCLArrayContext, + PytatoPyOpenCLArrayContext, +) from .loopy import make_loopy_program from .pytest import ( PytestArrayContextFactory, @@ -140,6 +144,7 @@ "NumpyArrayContext", "PyOpenCLArrayContext", "PytatoJAXArrayContext", + "PytatoParallelPyOpenCLArrayContext", "PytatoPyOpenCLArrayContext", "PytestArrayContextFactory", "PytestPyOpenCLArrayContextFactory", diff --git a/arraycontext/impl/pytato/__init__.py b/arraycontext/impl/pytato/__init__.py index 7c667a79..508cdc7a 100644 --- a/arraycontext/impl/pytato/__init__.py +++ b/arraycontext/impl/pytato/__init__.py @@ -13,6 +13,7 @@ The following :mod:`pytato`-based array contexts are provided: .. autoclass:: PytatoPyOpenCLArrayContext +.. autoclass:: PytatoParallelPyOpenCLArrayContext .. autoclass:: PytatoJAXArrayContext @@ -28,7 +29,8 @@ .. automodule:: arraycontext.impl.pytato.utils """ __copyright__ = """ -Copyright (C) 2020-1 University of Illinois Board of Trustees +Copyright (C) 2020-6 University of Illinois Board of Trustees +Copyright (C) 2022-3 Kaushik Kulkarni """ __license__ = """ @@ -827,9 +829,15 @@ def compile(self, def transform_dag(self, dag: pytato.AbstractResultWithNamedArrays ) -> pytato.AbstractResultWithNamedArrays: import pytato as pt + + dag = pt.transform.deduplicate_data_wrappers(dag) + dag = pt.tag_all_calls_to_be_inlined(dag) dag = pt.inline_calls(dag) - return pt.transform.materialize_with_mpms(dag) + + dag = pt.transform.materialize_with_mpms(dag) + + return dag @override def einsum(self, spec, *args, arg_names=None, tagged=()): @@ -909,6 +917,106 @@ def clone(self): # }}} +# {{{ PytatoParallelPyOpenCLArrayContext + +class PytatoParallelPyOpenCLArrayContext(PytatoPyOpenCLArrayContext): + """ + Same as :class:`PytatoPyOpenCLArrayContext`, but parallelizes across the device. + + .. note:: + + Refer to :meth:`transform_dag` and :meth:`transform_loopy_program` for + details on the transformation algorithm provided by this array context. + """ + # FIXME: Is this something that the base PytatoParallelPyOpenCLArrayContext + # should be calling, or should it be left for more-concrete derived array + # contexts? If the latter, where should it live? + def _materialize_einsum_inputs_and_outputs( + self, dag: pytato.AbstractResultWithNamedArrays + ) -> pytato.AbstractResultWithNamedArrays: + import pytato as pt + + from .utils import ( + get_inputs_and_outputs_of_einsum, + get_inputs_and_outputs_of_reduction_nodes, + ) + + einsum_inputs, einsum_outputs = get_inputs_and_outputs_of_einsum(dag) + redn_inputs, redn_outputs = get_inputs_and_outputs_of_reduction_nodes(dag) + reduction_inputs_outputs = ( + einsum_inputs | einsum_outputs | redn_inputs | redn_outputs) + + def materialize( + expr: pt.transform.ArrayOrNames) -> pt.transform.ArrayOrNames: + if expr in reduction_inputs_outputs: + if isinstance(expr, pt.InputArgumentBase): + return expr + else: + return expr.tagged(pt.tags.ImplStored()) + else: + return expr + + return pt.transform.map_and_copy(dag, materialize) + + @override + def transform_dag( + self, dag: pytato.AbstractResultWithNamedArrays + ) -> pytato.AbstractResultWithNamedArrays: + r""" + Returns a transformed version of *dag*, where the applied transform is: + + #. Materialize as per MPMS materialization heuristic. + #. materialize every :class:`pytato.array.Einsum`\ 's inputs and outputs. + """ + import pytato as pt + + dag = pt.transform.deduplicate_data_wrappers(dag) + + dag = pt.tag_all_calls_to_be_inlined(dag) + dag = pt.inline_calls(dag) + + dag = pt.transform.materialize_with_mpms(dag) + dag = self._materialize_einsum_inputs_and_outputs(dag) + + return dag + + def _parallelize_across_device( + self, t_unit: lp.TranslationUnit) -> lp.TranslationUnit: + from .parallelize import ( + add_gbarrier_between_disjoint_loop_sets, + alias_global_temporaries, + split_iteration_domain_across_work_items, + ) + + t_unit = split_iteration_domain_across_work_items( + t_unit, self.queue.device.max_compute_units) + + t_unit = add_gbarrier_between_disjoint_loop_sets(t_unit) + + # FIXME: Is this something that the base PytatoParallelPyOpenCLArrayContext + # should be calling, or should it be left for more-concrete derived array + # contexts? If the latter, where should it live? + t_unit = alias_global_temporaries(t_unit) + + return t_unit + + def transform_loopy_program( + self, t_unit: lp.TranslationUnit) -> lp.TranslationUnit: + r""" + Returns a transformed version of *t_unit*, where the applied transform is: + + #. An execution grid size :math:`G` is selected based on *self*'s + OpenCL-device. + #. The iteration domain for each statement in the *t_unit* is divided to + equally among the work-items in :math:`G`. + #. Kernel boundaries are drawn between every set of disjoint loops. + #. Once the kernel boundaries are inferred, :func:`alias_global_temporaries` + is invoked to reduce the memory peak memory used by the transformed + program. + """ + return self._parallelize_across_device(t_unit) + + # {{{ PytatoJAXArrayContext class PytatoJAXArrayContext(_BasePytatoArrayContext): diff --git a/arraycontext/impl/pytato/parallelize.py b/arraycontext/impl/pytato/parallelize.py new file mode 100644 index 00000000..532f608a --- /dev/null +++ b/arraycontext/impl/pytato/parallelize.py @@ -0,0 +1,587 @@ +# pyright: reportAny=warning + +from __future__ import annotations + + +__copyright__ = """ +Copyright (C) 2022-23 Kaushik Kulkarni +Copyright (C) 2022-26 University of Illinois Board of Trustees +""" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +""" +from dataclasses import dataclass +from typing import TYPE_CHECKING + +from typing_extensions import override + +import loopy as lp +from loopy.match import Matchable, MatchExpressionBase +from loopy.symbolic import WalkMapper +from loopy.translation_unit import CallablesTable, for_each_kernel + + +if TYPE_CHECKING: + from collections.abc import Mapping + +import logging + + +logger = logging.getLogger(__name__) + + +__doc__ = """ +.. autofunction:: split_iteration_domain_across_work_items +.. autofunction:: add_gbarrier_between_disjoint_loop_sets +""" + + +# {{{ disjoint loop sets + +@dataclass(frozen=True, eq=True) +class _LoopSet: + inames: frozenset[str] + insns_in_loop_set: frozenset[str] + + +def _get_disjoint_loop_sets(kernel: lp.LoopKernel) -> frozenset[_LoopSet]: + """ + Returns information about the disjoint loop sets in *kernel*. + """ + disjoint_inames_and_insns: list[tuple[set[str], set[str]]] = [] + iname_to_associated_inames_and_insns: dict[str, tuple[set[str], set[str]]] = {} + for insn in kernel.instructions: + inames = insn.within_inames | insn.reduction_inames() + associated_inames_and_insns: tuple[set[str], set[str]] | None = None + for iname in inames: + try: + associated_inames_and_insns = \ + iname_to_associated_inames_and_insns[iname] + except KeyError: + pass + if associated_inames_and_insns is not None: + associated_inames, associated_insns = associated_inames_and_insns + associated_inames.update(inames) + associated_insns.add(insn.id) + else: + associated_inames_and_insns = (set(inames), {insn.id}) + disjoint_inames_and_insns.append(associated_inames_and_insns) + for iname in inames: + iname_to_associated_inames_and_insns[iname] = associated_inames_and_insns + + return frozenset({ + _LoopSet( + frozenset(associated_inames), + frozenset(associated_insns)) + for associated_inames, associated_insns in disjoint_inames_and_insns}) + +# }}} + + +# {{{ split_iteration_domain_across_work_items + +def get_iname_approx_length(kernel: lp.LoopKernel, iname: str) -> float | int: + from loopy.isl_helpers import static_max_of_pw_aff + max_domain_size = static_max_of_pw_aff( + kernel.get_iname_bounds(iname).size, + constants_only=False).to_pw_aff().max_val() + if max_domain_size.is_infty(): + import math + return math.inf + else: + return max_domain_size.to_python() + + +class OuterReductionNestCollector(WalkMapper[[]]): + def __init__(self, outer_inames: frozenset[str]) -> None: + super().__init__() + self.outer_inames: frozenset[str] = outer_inames + # Since we're only looking for the reductions that are on the outside, we can + # use a list instead of a full graph + self.outer_redn_nest: list[frozenset[str]] = [] + + @override + def map_reduction(self, expr: lp.Reduction) -> None: + if not self.visit(expr): + return + + outer_redn_inames = frozenset(expr.inames) & self.outer_inames + + if outer_redn_inames: + self.outer_redn_nest.append(outer_redn_inames) + + self.rec(expr.expr) + + +def _get_outer_iname_pos_from_loop_set( + kernel: lp.LoopKernel, loop_set: _LoopSet, outer_inames: frozenset[str] + ) -> Mapping[str, int]: + if not outer_inames: + return {} + + import pymbolic.primitives as prim + + iname_orders: set[tuple[frozenset[str], ...]] = set() + + for insn_id in loop_set.insns_in_loop_set: + insn = kernel.id_to_insn[insn_id] + if isinstance(insn, lp.Assignment): + insn_iname_order: list[frozenset[str]] = [] + if isinstance(insn.assignee, prim.Subscript): + insn_iname_order.extend( + frozenset({idx.name}) + for idx in insn.assignee.index_tuple + if ( + isinstance(idx, prim.Variable) + and idx.name in outer_inames)) + ornc = OuterReductionNestCollector(outer_inames) + ornc(insn.expression) + insn_iname_order.extend(ornc.outer_redn_nest) + if insn_iname_order: + iname_orders.add(tuple(insn_iname_order)) + elif isinstance(insn, lp.CallInstruction): + # must be a callable kernel, don't touch. + pass + elif isinstance(insn, (lp.BarrierInstruction, lp.NoOpInstruction)): + pass + else: + raise NotImplementedError(type(insn)) + + iname_order = None + + if iname_orders: + # Merge the per-assignee partial orders into a single total order + from pytools.graph import CycleError, compute_topological_order + + successors: dict[str, set[str]] = {iname: set() for iname in outer_inames} + for order in iname_orders: + for earlier, later in zip(order[:-1], order[1:], strict=True): + for earlier_iname in earlier: + for later_iname in later: + successors[earlier_iname].add(later_iname) + + try: + # key= for determinism + iname_order = compute_topological_order(successors, key=lambda x: x) + except CycleError: + pass + + if not iname_order: + # No consistent merge of the per-assignee orderings exists; fall + # back to a deterministic order based on iname names + iname_order = sorted(outer_inames) + + return {iname: i + for i, iname in enumerate(iname_order)} + + +def _split_loop_set_across_work_items( + kernel: lp.LoopKernel, + callables: CallablesTable, + loop_set: _LoopSet, + iname_to_approx_length: Mapping[str, float | int], + max_device_compute_units: int, +) -> lp.LoopKernel: + + # Could possibly do something fancier that also includes the individual inner + # loops in the loop set, but for now just looking at the inames shared between + # all instructions in the set + + outer_non_redn_inames = loop_set.inames + for insn_id in loop_set.insns_in_loop_set: + outer_non_redn_inames &= kernel.id_to_insn[insn_id].within_inames + + outer_redn_inames = loop_set.inames + for insn_id in loop_set.insns_in_loop_set: + outer_redn_inames &= kernel.id_to_insn[insn_id].reduction_inames() + + outer_iname_pos: Mapping[str, int] + all_outer_inames = outer_non_redn_inames | outer_redn_inames + if all_outer_inames: + outer_iname_pos = _get_outer_iname_pos_from_loop_set( + kernel, loop_set, all_outer_inames) + else: + outer_iname_pos = {} + + # Prioritize the non-reduction loop with largest loop count. In case of ties, + # look at the iname position in the assignee and pick the iname indexing over + # leading axis for the work-group hardware iname + inames_to_parallelize = sorted( + outer_non_redn_inames, + key=lambda iname: ( + iname_to_approx_length[iname], + -outer_iname_pos[iname])) + + # Add the largest reduction loop if we don't already have 2 non-reduction loops + # to parallelize over + if len(inames_to_parallelize) < 2 and outer_redn_inames: + inames_to_parallelize.insert(0, + max( + outer_redn_inames, + key=lambda iname: ( + iname_to_approx_length[iname], + -outer_iname_pos[iname]))) + + vng = kernel.get_var_name_generator() + + if len(inames_to_parallelize) == 0: + pass + elif len(inames_to_parallelize) == 1: + iname, = inames_to_parallelize + if iname in outer_non_redn_inames: + ngroups = max_device_compute_units * 4 # '4' to overfill the device + l_one_size = 4 + l_zero_size = 16 + + kernel = lp.split_iname( + kernel, iname, ngroups * l_zero_size * l_one_size) + kernel = lp.split_iname( + kernel, f"{iname}_inner", l_zero_size, inner_tag="l.0") + kernel = lp.split_iname( + kernel, f"{iname}_inner_outer", l_one_size, inner_tag="l.1", + outer_tag="g.0") + else: + from loopy.match import Id + from loopy.transform.data import reduction_arg_to_subst_rule + from loopy.transform.precompute import precompute_for_single_kernel + + ngroups = max_device_compute_units + wg_size = 32 + + iredn_chunk = vng(f"{iname}_chunk") + iredn_inner = vng(f"{iname}_inner") + kernel = lp.split_iname( + kernel, iname, ngroups * wg_size, + inner_iname=iredn_inner, outer_iname=iredn_chunk) + + iredn_group = vng(f"{iname}_group") + iredn_thread = vng(f"{iname}_thread") + kernel = lp.split_iname( + kernel, iredn_inner, wg_size, + outer_iname=iredn_group, inner_iname=iredn_thread, + inner_tag="l.0") + kernel = lp.split_reduction_outward(kernel, iredn_group) + kernel = lp.split_reduction_outward(kernel, iredn_thread) + + insn_ids = sorted(loop_set.insns_in_loop_set) + + iprcmpt_redn_group = vng(f"iprcmpt_{iredn_group}") + + compute_insns: list[str] = [] + for insn_id in insn_ids: + subst_rule_name = vng(f"redn_subst_{iname}_{insn_id}") + kernel = reduction_arg_to_subst_rule( + kernel, iredn_group, + subst_rule_name=subst_rule_name, + insn_match=Id(insn_id)) + + temp_name = vng(f"redn_temp_{iname}_{insn_id}") + compute_insn_id = vng(f"redn_compute_{iname}_{insn_id}") + kernel = precompute_for_single_kernel( + kernel, callables, subst_rule_name, iredn_group, + temporary_name=temp_name, + temporary_address_space=lp.AddressSpace.GLOBAL, + precompute_inames=[iprcmpt_redn_group], + default_tag="g.0", + # Don't want a separate barrier to be added for each temporary; + # instead we will add one below (this is safe because the + # instructions inside a reduction-only outer loop can't depend + # on each other) + add_barrier_for_global_temporary=False, + compute_insn_id=compute_insn_id) + + compute_insns.append(compute_insn_id) + + barrier_id = vng(f"redn_barrier_{iname}") + kernel = lp.add_barrier( + kernel, + insn_before=InsnIds(frozenset(compute_insns)), + insn_after=InsnIds(frozenset(insn_ids)), + id_based_on=barrier_id, + synchronization_kind="global", + mem_kind="global", + within_inames=frozenset()) + + else: + bigger_loop = inames_to_parallelize[-1] + smaller_loop = inames_to_parallelize[-2] + + ngroups = max_device_compute_units * 4 # '4' to overfill the device + l_one_size = 4 + l_zero_size = 16 + + kernel = lp.split_iname( + kernel, f"{bigger_loop}", l_one_size * ngroups) + kernel = lp.split_iname( + kernel, f"{bigger_loop}_inner", l_one_size, inner_tag="l.1", + outer_tag="g.0") + if smaller_loop in outer_non_redn_inames: + kernel = lp.split_iname( + kernel, smaller_loop, l_zero_size, inner_tag="l.0") + else: + smaller_inner_loop = vng(f"{smaller_loop}_inner") + kernel = lp.split_iname( + kernel, smaller_loop, l_zero_size, inner_iname=smaller_inner_loop, + inner_tag="l.0") + kernel = lp.split_reduction_outward(kernel, smaller_inner_loop) + + return kernel + + +@for_each_kernel +def _split_iteration_domain_across_work_items_for_single_kernel( + kernel: lp.LoopKernel, + callables: CallablesTable, + max_device_compute_units: int, +) -> lp.LoopKernel: + + iname_to_approx_length = { + iname: get_iname_approx_length(kernel, iname) + for iname in kernel.all_inames()} + + loop_sets = _get_disjoint_loop_sets(kernel) + + for loop_set in loop_sets: + kernel = _split_loop_set_across_work_items(kernel, + callables, + loop_set, + iname_to_approx_length, + max_device_compute_units) + + return kernel + + +def split_iteration_domain_across_work_items( + t_unit: lp.TranslationUnit, + max_device_compute_units: int, +) -> lp.TranslationUnit: + # Need to pass callables table down into per-kernel function due to + # precompute_for_single_kernel call + return _split_iteration_domain_across_work_items_for_single_kernel( + t_unit, t_unit.callables_table, max_device_compute_units) + +# }}} + + +# {{{ get_call_kernel_insn_ids + +def get_call_kernel_insn_ids(kernel: lp.LoopKernel) -> tuple[frozenset[str], ...]: + """ + Returns a sequence of collection of instruction ids where each entry in the + sequence corresponds to the instructions in a call-kernel to launch. + + In this heuristic we simply draw kernel boundaries such that instruction + belonging to disjoint loop set pairs are executed in different call kernels. + """ + loop_sets = _get_disjoint_loop_sets(kernel) + + insn_id_to_loop_set = { + insn_id: loop_set + for loop_set in loop_sets + for insn_id in loop_set.insns_in_loop_set} + + from pytools.graph import compute_topological_order + + loop_set_dep_graph: dict[_LoopSet, set[_LoopSet]] = { + insn_id_to_loop_set[insn.id]: set() + for insn in kernel.instructions + } + + for insn in kernel.instructions: + insn_loop_set = insn_id_to_loop_set[insn.id] + for dep_id in insn.depends_on: + dep_loop_set = insn_id_to_loop_set[dep_id] + if insn_loop_set != dep_loop_set: + loop_set_dep_graph[dep_loop_set].add(insn_loop_set) + + # Break ties between ready loop sets using the lexicographically smallest + # instruction ID in each set. Loop sets are disjoint by construction, so these + # mins are unique across sets + toposorted_loop_sets: list[_LoopSet] = compute_topological_order( + loop_set_dep_graph, + key=lambda ls: min(ls.insns_in_loop_set)) + + return tuple(loop_set.insns_in_loop_set for loop_set in toposorted_loop_sets) + +# }}} + + +# {{{ add_gbarrier_between_disjoint_loop_sets + +@dataclass(frozen=True) +class InsnIds(MatchExpressionBase): + insn_ids_to_match: frozenset[str] + + @override + def __call__(self, kernel: lp.LoopKernel, matchable: Matchable): + return matchable.id in self.insn_ids_to_match + + +def add_gbarrier_between_disjoint_loop_sets( + t_unit: lp.TranslationUnit) -> lp.TranslationUnit: + kernel = t_unit.default_entrypoint + ing = kernel.get_instruction_id_generator() + + call_kernel_insn_ids = get_call_kernel_insn_ids(kernel) + gbarrier_ids: list[str] = [] + + for ibarrier, (insns_before, insns_after) in enumerate( + zip(call_kernel_insn_ids[:-1], call_kernel_insn_ids[1:], strict=True)): + id_based_on = ing(f"_actx_gbarrier_{ibarrier}") + kernel = lp.add_barrier( + kernel, + insn_before=InsnIds(insns_before), + insn_after=InsnIds(insns_after), + id_based_on=id_based_on, + within_inames=frozenset()) + assert id_based_on in kernel.id_to_insn + gbarrier_ids.append(id_based_on) + + from loopy.match import Id + for pred_gbarrier, succ_gbarrier in zip( + gbarrier_ids[:-1], gbarrier_ids[1:], strict=True): + kernel = lp.add_dependency(kernel, Id(succ_gbarrier), pred_gbarrier) + + return t_unit.with_kernel(kernel) + +# }}} + + +# {{{ global temp var aliasing for disjoint live intervals + +def alias_global_temporaries(t_unit: lp.TranslationUnit) -> lp.TranslationUnit: + """ + Returns a copy of *t_unit* with temporaries of that have disjoint live + intervals using the same :attr:`loopy.TemporaryVariable.base_storage`. + + .. warning:: + + This routine **assumes** that the entrypoint in *t_unit* global + barriers inserted as per :func:`get_call_kernel_insn_ids`. + """ + from collections import defaultdict + + import loopy as lp + from loopy.kernel.data import AddressSpace + from pytools import UniqueNameGenerator + + t_unit = lp.infer_unknown_types(t_unit) + + # all loopy programs from pytato DAGs have exactly one entrypoint. + kernel = t_unit.default_entrypoint + + temp_vars = frozenset(tv.name + for tv in kernel.temporary_variables.values() + if tv.address_space == AddressSpace.GLOBAL) + + call_kernel_insn_ids = get_call_kernel_insn_ids(kernel) + expanded_kernel = lp.expand_subst(kernel) + temp_to_live_interval_start: dict[str, int] = {} + temp_to_live_interval_end: dict[str, int] = {} + + for icall_kernel, insn_ids in enumerate(call_kernel_insn_ids): + for insn_id in insn_ids: + for var in (expanded_kernel.id_to_insn[insn_id].dependency_names() + & temp_vars): + if var not in temp_to_live_interval_start: + assert var not in temp_to_live_interval_end + temp_to_live_interval_start[var] = icall_kernel + assert var in temp_to_live_interval_start + temp_to_live_interval_end[var] = icall_kernel + + vng = UniqueNameGenerator() + + # {{{ get mappings from icall_kernel to temps that are just alive or dead + + icall_kernel_to_just_live_temp_vars: list[set[str]] = [ + set() for _ in call_kernel_insn_ids] + icall_kernel_to_just_dead_temp_vars: list[set[str]] = [ + set() for _ in call_kernel_insn_ids] + + for tv_name, just_alive_idx in temp_to_live_interval_start.items(): + icall_kernel_to_just_live_temp_vars[just_alive_idx].add(tv_name) + + for tv_name, just_dead_idx in temp_to_live_interval_end.items(): + if just_dead_idx != (len(call_kernel_insn_ids) - 1): + # we ignore the temporaries that died at the last kernel since we cannot + # reclaim their memory + icall_kernel_to_just_dead_temp_vars[just_dead_idx+1].add(tv_name) + + # }}} + + new_tvs: dict[str, lp.TemporaryVariable] = {} + # a mapping from shape to the available base storages from temp variables + # that were dead. + shape_to_available_base_storage: dict[int, set[str]] = defaultdict(set) + + for icall_kernel, _ in enumerate(call_kernel_insn_ids): + just_dead_temps = icall_kernel_to_just_dead_temp_vars[icall_kernel] + to_be_allocated_temps = icall_kernel_to_just_live_temp_vars[icall_kernel] + + # reclaim base storage from the dead temporaries + for tv_name in sorted(just_dead_temps): + tv = new_tvs[tv_name] + assert tv.base_storage is not None + assert isinstance(tv.nbytes, int) + assert tv.base_storage not in shape_to_available_base_storage[tv.nbytes] + shape_to_available_base_storage[tv.nbytes].add(tv.base_storage) + + # assign base storages to 'to_be_allocated_temps' + for tv_name in sorted(to_be_allocated_temps): + tv = kernel.temporary_variables[tv_name] + assert tv.name not in new_tvs + assert tv.base_storage is None + assert isinstance(tv.nbytes, int) + if shape_to_available_base_storage[tv.nbytes]: + base_storage = sorted(shape_to_available_base_storage[tv.nbytes])[0] + shape_to_available_base_storage[tv.nbytes].remove(base_storage) + else: + base_storage = vng("_actx_tmp_base") + + new_tvs[tv.name] = tv.copy(base_storage=base_storage) + + for name, tv in kernel.temporary_variables.items(): + if tv.address_space != AddressSpace.GLOBAL: + new_tvs[name] = tv + + kernel = kernel.copy(temporary_variables=new_tvs) + kernel = lp.allocate_temporaries_for_base_storage(kernel) + + def verify_is_int(x: object) -> int: + assert isinstance(x, int) + return x + + old_tmp_mem_requirement = sum( + verify_is_int(tv.nbytes) + for tv in kernel.temporary_variables.values()) + new_tmp_mem_requirement = sum( + {tv.base_storage: verify_is_int(tv.nbytes) + for tv in kernel.temporary_variables.values()}.values()) + logger.info( + "[alias_global_temporaries]: Reduced memory requirement from " + "%.1fMB to %.1fMB.", + old_tmp_mem_requirement*1e-6, new_tmp_mem_requirement*1e-6) + + return t_unit.with_kernel(kernel) + +# }}} + + +# vim: foldmethod=marker diff --git a/arraycontext/impl/pytato/utils.py b/arraycontext/impl/pytato/utils.py index 20b265e5..b9cedd7f 100644 --- a/arraycontext/impl/pytato/utils.py +++ b/arraycontext/impl/pytato/utils.py @@ -23,6 +23,7 @@ __copyright__ = """ Copyright (C) 2021 University of Illinois Board of Trustees +Copyright (C) 2022-3 Kaushik Kulkarni """ __license__ = """ @@ -51,12 +52,15 @@ from typing_extensions import override import pytools +from pymbolic.mapper.optimize import optimize_mapper from pytato.analysis import get_num_call_sites from pytato.array import ( Array, Axis as PtAxis, DataInterface, DataWrapper, + Einsum, + IndexLambda, Placeholder, SizeParam, make_placeholder, @@ -65,6 +69,7 @@ from pytato.transform import ( ArrayOrNames, ArrayOrNamesTc, + CachedWalkMapper, CopyMapper, TransformMapperCache, deduplicate, @@ -339,4 +344,76 @@ def _rec_str(key: object) -> str: # }}} + +# {{{ EinsumInputOutputCollector + +@optimize_mapper(drop_args=True, drop_kwargs=True, inline_get_cache_key=True) +class EinsumInputOutputCollector(CachedWalkMapper[[]]): + """ + .. note:: + + We deliberately avoid using :class:`pytato.transform.CombineMapper` since + the mapper's caching structure would still lead to recomputing + the union of sets for the results of a revisited node. + """ + def __init__(self) -> None: + self.collected_outputs: set[Array] = set() + self.collected_inputs: set[Array] = set() + super().__init__() + + @override + def get_cache_key(self, expr: ArrayOrNames) -> ArrayOrNames: + return expr + + @override + def post_visit(self, expr: ArrayOrNames | FunctionDefinition) -> None: + if isinstance(expr, Einsum): + self.collected_outputs.add(expr) + self.collected_inputs.update(expr.args) + + +def get_inputs_and_outputs_of_einsum( + expr: AbstractResultWithNamedArrays + ) -> tuple[frozenset[Array], frozenset[Array]]: + mapper = EinsumInputOutputCollector() + mapper(expr) + return frozenset(mapper.collected_inputs), frozenset(mapper.collected_outputs) + +# }}} + + +# {{{ ReductionInputOutputCollector + +class ReductionInputOutputCollector(CachedWalkMapper[[]]): + """ + .. note:: + We deliberately avoid using :class:`pytato.transform.CombineMapper` since + the mapper's caching structure would still lead to recomputing + the union of sets for the results of a revisited node. + """ + def __init__(self) -> None: + self.collected_outputs: set[Array] = set() + self.collected_inputs: set[Array] = set() + super().__init__() + + @override + def get_cache_key(self, expr: ArrayOrNames) -> ArrayOrNames: + return expr + + def post_visit(self, expr: ArrayOrNames | FunctionDefinition) -> None: + if isinstance(expr, IndexLambda) and expr.var_to_reduction_descr: + self.collected_outputs.add(expr) + self.collected_inputs.update(expr.bindings.values()) + + +def get_inputs_and_outputs_of_reduction_nodes( + expr: AbstractResultWithNamedArrays + ) -> tuple[frozenset[Array], frozenset[Array]]: + mapper = ReductionInputOutputCollector() + mapper(expr) + return frozenset(mapper.collected_inputs), frozenset(mapper.collected_outputs) + +# }}} + + # vim: foldmethod=marker diff --git a/arraycontext/pytest.py b/arraycontext/pytest.py index 1a160ad4..50941405 100644 --- a/arraycontext/pytest.py +++ b/arraycontext/pytest.py @@ -196,6 +196,14 @@ def __str__(self) -> str: f"on '{self.device.platform.name.strip()}'>>") +class _PytestPytatoParallelPyOpenCLArrayContextFactory( + _PytestPytatoPyOpenCLArrayContextFactory): + @property + def actx_class(self): + from arraycontext.impl.pytato import PytatoParallelPyOpenCLArrayContext + return PytatoParallelPyOpenCLArrayContext + + class _PytestEagerJaxArrayContextFactory(PytestArrayContextFactory): def __init__(self, *args, **kwargs) -> None: pass @@ -274,6 +282,7 @@ def __str__(self) -> str: "pyopencl": _PytestPyOpenCLArrayContextFactoryWithClass, "pytato:pyopencl": _PytestPytatoPyOpenCLArrayContextFactory, "pytato:jax": _PytestPytatoJaxArrayContextFactory, + "pytato:parallel_pyopencl": _PytestPytatoParallelPyOpenCLArrayContextFactory, "eagerjax": _PytestEagerJaxArrayContextFactory, "numpy": _PytestNumpyArrayContextFactory, } diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index 9948d71f..e4222b0e 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -55,6 +55,7 @@ _PytestNumpyArrayContextFactory, _PytestPyOpenCLArrayContextFactoryWithClass, _PytestPytatoJaxArrayContextFactory, + _PytestPytatoParallelPyOpenCLArrayContextFactory, _PytestPytatoPyOpenCLArrayContextFactory, ) from testlib import DOFArray, MyContainer, MyContainerDOFBcast, Velocity2D @@ -111,6 +112,7 @@ class _PytatoPyOpenCLArrayContextForTestsFactory( _PytatoPyOpenCLArrayContextForTestsFactory, _PytestEagerJaxArrayContextFactory, _PytestPytatoJaxArrayContextFactory, + _PytestPytatoParallelPyOpenCLArrayContextFactory, _PytestNumpyArrayContextFactory, ]) diff --git a/test/test_pytato_arraycontext.py b/test/test_pytato_arraycontext.py index a7f48fae..4635fe31 100644 --- a/test/test_pytato_arraycontext.py +++ b/test/test_pytato_arraycontext.py @@ -398,6 +398,327 @@ def twice(x): actx2._enable_profiling(True) +def test_split_iteration_domain_across_work_items_scalar(): + import loopy as lp + from loopy.kernel.data import GroupInameTag, LocalInameTag + + from arraycontext.impl.pytato.parallelize import ( + split_iteration_domain_across_work_items, + ) + + # Scalars only, nothing to parallelize + t_unit = lp.make_kernel( + "{:}", + "out = a + 1", + [ + lp.GlobalArg("a,out", np.float32, shape=()), + ..., + ]) + + t_unit = split_iteration_domain_across_work_items( + t_unit, max_device_compute_units=4) + + knl = t_unit.default_entrypoint + all_tags = {tag for iname in knl.all_inames() + for tag in knl.iname_tags(iname)} + assert not any(isinstance(t, (GroupInameTag, LocalInameTag)) for t in all_tags) + + +def test_split_iteration_domain_across_work_items_no_outer_inames(): + import loopy as lp + from loopy.kernel.data import GroupInameTag, LocalInameTag + + from arraycontext.impl.pytato.parallelize import ( + split_iteration_domain_across_work_items, + ) + + # No outer inames, nothing to parallelize + t_unit = lp.make_kernel( + "{[i, j]: 0<=i,j 1: From 66519a6e46bc41d2817aac138bd5139939deaaef Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Fri, 22 May 2026 17:24:51 -0500 Subject: [PATCH 2/2] change loopy branch --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index a4cb4025..6de7e639 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,5 +5,5 @@ git+https://github.com/inducer/pymbolic.git#egg=pymbolic git+https://github.com/inducer/pyopencl.git#egg=pyopencl git+https://github.com/inducer/islpy.git#egg=islpy -git+https://github.com/inducer/loopy.git#egg=loopy +git+https://github.com/majosm/loopy.git@pytato-parallel#egg=loopy git+https://github.com/inducer/pytato.git#egg=pytato