Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion arraycontext/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,11 @@
from .impl.jax import EagerJAXArrayContext
from .impl.numpy import NumpyArrayContext
from .impl.pyopencl import PyOpenCLArrayContext
from .impl.pytato import PytatoJAXArrayContext, PytatoPyOpenCLArrayContext
from .impl.pytato import (
PytatoJAXArrayContext,
PytatoParallelPyOpenCLArrayContext,
PytatoPyOpenCLArrayContext,
)
from .loopy import make_loopy_program
from .pytest import (
PytestArrayContextFactory,
Expand Down Expand Up @@ -140,6 +144,7 @@
"NumpyArrayContext",
"PyOpenCLArrayContext",
"PytatoJAXArrayContext",
"PytatoParallelPyOpenCLArrayContext",
"PytatoPyOpenCLArrayContext",
"PytestArrayContextFactory",
"PytestPyOpenCLArrayContextFactory",
Expand Down
112 changes: 110 additions & 2 deletions arraycontext/impl/pytato/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
The following :mod:`pytato`-based array contexts are provided:

.. autoclass:: PytatoPyOpenCLArrayContext
.. autoclass:: PytatoParallelPyOpenCLArrayContext
.. autoclass:: PytatoJAXArrayContext


Expand All @@ -28,7 +29,8 @@
.. automodule:: arraycontext.impl.pytato.utils
"""
__copyright__ = """
Copyright (C) 2020-1 University of Illinois Board of Trustees
Copyright (C) 2020-6 University of Illinois Board of Trustees
Copyright (C) 2022-3 Kaushik Kulkarni
"""

__license__ = """
Expand Down Expand Up @@ -827,9 +829,15 @@ def compile(self,
def transform_dag(self, dag: pytato.AbstractResultWithNamedArrays
                  ) -> pytato.AbstractResultWithNamedArrays:
    """
    Returns a transformed version of *dag*, where the applied transform is:

    #. Deduplicate data wrappers.
    #. Tag all calls to be inlined, then inline them.
    #. Materialize as per MPMS materialization heuristic.
    """
    import pytato as pt

    dag = pt.transform.deduplicate_data_wrappers(dag)

    dag = pt.tag_all_calls_to_be_inlined(dag)
    dag = pt.inline_calls(dag)

    # Single exit point: the materialized DAG is the final result.
    dag = pt.transform.materialize_with_mpms(dag)

    return dag

@override
def einsum(self, spec, *args, arg_names=None, tagged=()):
Expand Down Expand Up @@ -909,6 +917,106 @@ def clone(self):
# }}}


# {{{ PytatoParallelPyOpenCLArrayContext

class PytatoParallelPyOpenCLArrayContext(PytatoPyOpenCLArrayContext):
    """
    Same as :class:`PytatoPyOpenCLArrayContext`, but parallelizes across the device.

    .. note::

        Refer to :meth:`transform_dag` and :meth:`transform_loopy_program` for
        details on the transformation algorithm provided by this array context.
    """
    # FIXME: Is this something that the base PytatoParallelPyOpenCLArrayContext
    # should be calling, or should it be left for more-concrete derived array
    # contexts? If the latter, where should it live?
    def _materialize_einsum_inputs_and_outputs(
            self, dag: pytato.AbstractResultWithNamedArrays
            ) -> pytato.AbstractResultWithNamedArrays:
        """
        Returns *dag* mapped so that every array reported by
        ``get_inputs_and_outputs_of_einsum`` or
        ``get_inputs_and_outputs_of_reduction_nodes`` is tagged with
        :class:`pytato.tags.ImplStored`.

        Instances of :class:`pytato.array.InputArgumentBase` are returned
        untagged. Despite the method's name, reduction nodes are handled in
        addition to einsums.
        """
        import pytato as pt

        from .utils import (
            get_inputs_and_outputs_of_einsum,
            get_inputs_and_outputs_of_reduction_nodes,
        )

        einsum_inputs, einsum_outputs = get_inputs_and_outputs_of_einsum(dag)
        redn_inputs, redn_outputs = get_inputs_and_outputs_of_reduction_nodes(dag)
        reduction_inputs_outputs = (
            einsum_inputs | einsum_outputs | redn_inputs | redn_outputs)

        def materialize(
                expr: pt.transform.ArrayOrNames) -> pt.transform.ArrayOrNames:
            # Only non-input arrays feeding into/out of a reduction get tagged;
            # everything else passes through unchanged.
            if (expr in reduction_inputs_outputs
                    and not isinstance(expr, pt.InputArgumentBase)):
                return expr.tagged(pt.tags.ImplStored())

            return expr

        return pt.transform.map_and_copy(dag, materialize)

    @override
    def transform_dag(
            self, dag: pytato.AbstractResultWithNamedArrays
            ) -> pytato.AbstractResultWithNamedArrays:
        r"""
        Returns a transformed version of *dag*, where the applied transform is:

        #. Deduplicate data wrappers and inline all calls.
        #. Materialize as per MPMS materialization heuristic.
        #. Materialize the inputs and outputs of every
           :class:`pytato.array.Einsum` and of every reduction node.
        """
        import pytato as pt

        dag = pt.transform.deduplicate_data_wrappers(dag)

        dag = pt.tag_all_calls_to_be_inlined(dag)
        dag = pt.inline_calls(dag)

        dag = pt.transform.materialize_with_mpms(dag)
        dag = self._materialize_einsum_inputs_and_outputs(dag)

        return dag

    def _parallelize_across_device(
            self, t_unit: lp.TranslationUnit) -> lp.TranslationUnit:
        """
        Returns *t_unit* after applying, in order,
        ``split_iteration_domain_across_work_items`` (passed the
        ``max_compute_units`` of the device of ``self.queue``),
        ``add_gbarrier_between_disjoint_loop_sets`` and
        ``alias_global_temporaries``.
        """
        from .parallelize import (
            add_gbarrier_between_disjoint_loop_sets,
            alias_global_temporaries,
            split_iteration_domain_across_work_items,
        )

        t_unit = split_iteration_domain_across_work_items(
            t_unit, self.queue.device.max_compute_units)

        t_unit = add_gbarrier_between_disjoint_loop_sets(t_unit)

        # FIXME: Is this something that the base PytatoParallelPyOpenCLArrayContext
        # should be calling, or should it be left for more-concrete derived array
        # contexts? If the latter, where should it live?
        t_unit = alias_global_temporaries(t_unit)

        return t_unit

    @override
    def transform_loopy_program(
            self, t_unit: lp.TranslationUnit) -> lp.TranslationUnit:
        r"""
        Returns a transformed version of *t_unit*, where the applied transform is:

        #. An execution grid size :math:`G` is selected based on *self*'s
           OpenCL-device.
        #. The iteration domain for each statement in the *t_unit* is divided
           equally among the work-items in :math:`G`.
        #. Kernel boundaries are drawn between every set of disjoint loops.
        #. Once the kernel boundaries are inferred, :func:`alias_global_temporaries`
           is invoked to reduce the peak memory used by the transformed
           program.
        """
        return self._parallelize_across_device(t_unit)

# }}}


# {{{ PytatoJAXArrayContext

class PytatoJAXArrayContext(_BasePytatoArrayContext):
Expand Down
Loading
Loading