Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions devito/ir/clusters/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,10 +259,17 @@ def guard(clusters):
# Separate out the indirect ConditionalDimensions, which only serve
# the purpose of protecting from OOB accesses
cds = [d for d in cds if not d.indirect]
modes = [cd.relation for cd in cds]
if modes.count('strict') > 1:
raise CompilationError("Only one `strict` condition"
"can be used in an equation")
elif 'strict' in modes:
mode = 'strict'
else:
mode = sympy.And if sympy.And in modes else sympy.Or

# Chain together all `cds` conditions from all expressions in `c`
guards = {}
mode = sympy.Or
for cd in cds:
# `BOTTOM` parent implies a guard that lives outside of
# any iteration space, which corresponds to the placeholder None
Expand All @@ -279,7 +286,6 @@ def guard(clusters):

# Pull `cd` from any expr
condition = guards.setdefault(k, [])
mode = mode and cd.relation
for e in exprs:
try:
condition.append(e.conditionals[cd])
Expand All @@ -296,7 +302,10 @@ def guard(clusters):

# Combination `mode` is And by default.
# If all conditions are Or then Or combination `mode` is used.
guards = {d: mode(*v, evaluate=False) for d, v in guards.items()}
if mode == 'strict':
guards = {d: v[0] for d, v in guards.items()}
else:
guards = {d: mode(*v, evaluate=False) for d, v in guards.items()}

# Construct a guarded Cluster
processed.append(c.rebuild(exprs=exprs, guards=guards))
Expand Down
32 changes: 26 additions & 6 deletions devito/ir/equations/equation.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@
)
from devito.symbolics import IntDiv, limits_mapper, uxreplace
from devito.tools import Pickable, Tag, frozendict
from devito.types import Eq, Inc, ReduceMax, ReduceMin, ReduceMinMax, relational_min
from devito.types import (
Eq, Inc, ReduceMax, ReduceMin, ReduceMinMax, relational_min, relational_shift
)

__all__ = [
'ClusterizedEq',
Expand Down Expand Up @@ -222,7 +224,7 @@ def __new__(cls, *args, **kwargs):
relations=ordering.relations, mode='partial')
ispace = IterationSpace(intervals, iterators)

# Construct the conditionals and replace the ConditionalDimensions in `expr`
# Construct the conditionals
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should place this whole block of code, which constructs/lowers the conditionals, into its own separate functions, and a docstring with some examples

conditionals = {}
for d in ordering:
if not d.is_Conditional:
Expand All @@ -234,13 +236,31 @@ def __new__(cls, *args, **kwargs):
if d._factor is not None:
cond = d.relation(cond, GuardFactor(d))
conditionals[d] = cond

# Merge conditionals when possible. E.g if we have an implicit_dim
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

btw this block imho deserves its own function

# and there is a dimension with the same parent, we ca merged
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Dimension

"ca merged"

"their conditions"

you could also make the example a bit more practical

# its condition
for d in input_expr.implicit_dims:
if d not in conditionals:
continue
for cd in dict(conditionals):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

list(...) is fine

if cd.parent == d.parent and cd != d:
cond = conditionals.pop(d)
if d.relation == 'strict':
conditionals[cd] = conditionals[d] = cond
else:
mode = cd.relation and d.relation
conditionals[cd] = mode(cond, conditionals[cd])
break

# Replace the ConditionalDimensions in `expr`
for d, cond in conditionals.items():
# Replace dimension with index
index = d.index
if d.condition is not None and d in expr.free_symbols:
index = index - relational_min(d.condition, d.parent)
expr = uxreplace(expr, {d: IntDiv(index, d.symbolic_factor)})

conditionals = frozendict(conditionals)
index = index - relational_min(cond, d.parent)
shift = relational_shift(cond, d.parent)
expr = uxreplace(expr, {d: IntDiv(index, d.symbolic_factor) + shift})

# Lower all Differentiable operations into SymPy operations
rhs = diff2sympy(expr.rhs)
Expand Down
67 changes: 43 additions & 24 deletions devito/ir/support/guards.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from sympy.logic.boolalg import BooleanFunction

from devito.ir.support.space import Forward, IterationDirection
from devito.symbolics import CondEq, CondNe, search
from devito.symbolics import CondEq, CondNe, IntDiv, search
from devito.tools import Pickable, as_tuple, frozendict, split
from devito.types import Dimension, LocalObject

Expand All @@ -31,6 +31,34 @@
]


@singledispatch
def bound_index(expr, dim, dir):
if dir == Forward:
return expr._subs(dim, dim + 1)
else:
return expr._subs(dim, dim - 1)


@bound_index.register(Expr)
def _(expr, dim, dir):
if not expr.args:
if dir == Forward:
return expr._subs(dim, dim + 1)
else:
return expr._subs(dim, dim - 1)
return expr.func(*[bound_index(a, dim, dir) for a in expr.args])


@bound_index.register(IntDiv)
def _(expr, dim, dir):
v = dim.symbolic_factor
p0 = dim.root
if dir == Forward:
return Mul((((p0 + 1) + v - 1) / v), v, evaluate=False)
else:
return (p0 - 1) - abs(p0 - 1) % v


class AbstractGuard:
pass

Expand Down Expand Up @@ -138,37 +166,27 @@ class BaseGuardBoundNext(Guard, Pickable):
given `direction`.
"""

__rargs__ = ('d', 'direction')
__rargs__ = ('d', 'index', 'direction')

def __new__(cls, d, direction, **kwargs):
def __new__(cls, d, index, direction, **kwargs):
assert isinstance(d, Dimension)
assert isinstance(direction, IterationDirection)

if direction == Forward:
p0 = d.root
p1 = d.root.symbolic_max
# Always take the next index in the iteration direction
next_index = bound_index(index, d, direction)

if d.is_Conditional:
v = d.symbolic_factor
# Round `p0 + 1` up to the nearest multiple of `v`
p0 = Mul((((p0 + 1) + v - 1) / v), v, evaluate=False)
else:
p0 = p0 + 1
# The direction might be forward but accessing c - d
# making the access backward w.r.t
# Update direction according to access direction for valid guard
if index.has(-d):
direction = -direction

if direction == Forward:
p0 = next_index
p1 = d.root.symbolic_max
else:
p0 = d.root.symbolic_min
p1 = d.root

if d.is_Conditional:
v = d.symbolic_factor
# Round `p1 - 1` down to the nearest sub-multiple of `v`
# NOTE: we use ABS to make sure we handle negative values properly.
# Once `p1 - 1` is negative (e.g. `iteration=time - 1` and `time=0`),
# as long as we get a negative number, rather than 0 and even if it's
# not `-v`, we're good
p1 = (p1 - 1) - abs(p1 - 1) % v
else:
p1 = p1 - 1
p1 = next_index

try:
if cls.__base__._eval_relation(p0, p1) is true:
Expand All @@ -180,6 +198,7 @@ def __new__(cls, d, direction, **kwargs):

obj.d = d
obj.direction = direction
obj.index = index

return obj

Expand Down
8 changes: 8 additions & 0 deletions devito/ir/support/space.py
Original file line number Diff line number Diff line change
Expand Up @@ -601,6 +601,14 @@ def __repr__(self):
def __hash__(self):
return hash(self._name)

def __neg__(self):
if self._name == '++':
return Backward
elif self._name == '--':
return Forward
else:
return Any


Forward = IterationDirection('++')
"""Forward iteration direction ('++')."""
Expand Down
3 changes: 3 additions & 0 deletions devito/ir/support/vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ def __lt__(self, other):
return True
elif q_positive(i):
return False

raise TypeError("Non-comparable index functions") from e

return False
Expand Down Expand Up @@ -164,6 +165,7 @@ def __gt__(self, other):
return True
elif q_negative(i):
return False

raise TypeError("Non-comparable index functions") from e

return False
Expand Down Expand Up @@ -203,6 +205,7 @@ def __le__(self, other):
return True
elif q_positive(i):
return False

raise TypeError("Non-comparable index functions") from e

# Note: unlike `__lt__`, if we end up here, then *it is* <=. For example,
Expand Down
80 changes: 72 additions & 8 deletions devito/passes/clusters/asynchrony.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from collections import defaultdict
from functools import singledispatch

from sympy import true
from sympy import Expr, Mod, true

from devito.ir import (
Backward, Forward, GuardBoundNext, PrefetchUpdate, Queue, ReleaseLock, SyncArray,
Expand All @@ -9,11 +10,65 @@
from devito.passes.clusters.utils import in_critical_region, is_memcpy
from devito.symbolics import IntDiv, uxreplace
from devito.tools import OrderedSet, is_integer, timed_pass
from devito.types import CustomDimension, Lock
from devito.types import CustomDimension, Lock, VirtualDimension

__all__ = ['memcpy_prefetch', 'tasking']


@singledispatch
def next_index(expr, dim, dir):
return expr._subs(dim, dim + dir)


@next_index.register(Expr)
def _(expr, dim, dir):
if not expr.args:
return expr._subs(dim, dim + dir)
return expr.func(*[next_index(a, dim, dir) for a in expr.args])


@next_index.register(IntDiv)
def _(expr, dim, dir):
"""
Handle forward and backward fetches separately to handle non-canonical index
expressions of the form:

t//factor + cond(t)

where ``cond(t)`` is a piecewise correction term.

The forward fetch advances to the next coarse-grained slot while evaluating
the correction at the next time point:

t//factor + cond(t)
-> (t//factor + 1) + cond(t + 1)

The backward fetch is not, in general, the inverse transformation obtained by
replacing ``+1`` with ``-1``. The correction may already be applied at the
current time point, causing the forward and backward fetches to be asymmetric.

For example, with ``factor=2`` and ``cond(t) := (t == a)``, the index at
``t=a=3`` is:

3//2 + 1 = 2

while the previous index is:

2//2 + 0 = 1

A symmetric backward transformation would instead yield:

3//2 - 1 + 0 = 0
"""
if expr.lhs._defines & dim._defines:
if dir == 1:
return expr + dir
else:
return expr._subs(dim, dim + dir)
else:
return expr


def async_trigger(c, dims):
"""
Return the Dimension in `c`'s IterationSpace that triggers the
Expand Down Expand Up @@ -78,7 +133,8 @@ def callback(self, clusters, prefix):
d = self.key0(c0)
if d is not dim:
continue

if d in c0.guards and not c0.guards[d].has(Mod):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

searching for Mod is a bit meh, I'd rather add a special guard to ir/support/guards.py and look for that instead (there's quite a few already in there!)

continue
protected = self._schedule_waitlocks(c0, d, clusters, locks, syncs)
self._schedule_withlocks(c0, d, protected, locks, syncs)

Expand Down Expand Up @@ -193,7 +249,7 @@ def memcpy_prefetch(clusters, key0, sregistry):
if c.properties.is_prefetchable(d._defines):
_actions_from_update_memcpy(c, d, clusters, actions, sregistry)
elif d.is_Custom and is_integer(c.ispace[d].size):
_actions_from_init(c, d, actions)
_actions_from_init(c, d, clusters, actions)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

leftover, I guess


# Attach the computed Actions
processed = []
Expand All @@ -214,7 +270,7 @@ def memcpy_prefetch(clusters, key0, sregistry):
return processed


def _actions_from_init(c, d, actions):
def _actions_from_init(c, d, clusters, actions):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

leftover, I guess

e = c.exprs[0]
function = e.rhs.function
target = e.lhs.function
Expand All @@ -240,7 +296,7 @@ def _actions_from_update_memcpy(c, d, clusters, actions, sregistry):

fetch = e.rhs.indices[d]
fshift = {Forward: 1, Backward: -1}.get(direction, 0)
findex = fetch + fshift if fetch.find(IntDiv) else fetch._subs(pd, pd + fshift)
findex = next_index(fetch, pd, fshift)

# If fetching into e.g. `ub[t1]` we might need to prefetch into e.g. `ub[t0]`
tindex0 = e.lhs.indices[d]
Expand Down Expand Up @@ -271,8 +327,16 @@ def _actions_from_update_memcpy(c, d, clusters, actions, sregistry):
ispace = c.ispace.augment({pd: tindex}) if tindex is not tindex0 else c.ispace

guard0 = c.guards.get(d, true)._subs(fetch, findex)
guard1 = GuardBoundNext(function.indices[d], direction)
guards = c.guards.impose(d, guard0 & guard1)
guard1 = GuardBoundNext(function.indices[d], e.rhs.indices[d], direction)

# First guard1 then if guard1 is valid we can safely evaluate guard0
# that will have valid indices into f
vdnext = VirtualDimension(name=f'vdnext_{d.name}', parent=pd)
ispace = ispace.insert(pd, vdnext)
# Check valid tindex first
guards = c.guards.impose(d, guard1)
# THen check valid access
guards = guards.impose(vdnext, guard0)

syncs = {d: [
ReleaseLock(handle, target),
Expand Down
Loading
Loading