-
Notifications
You must be signed in to change notification settings - Fork 255
api: fix handling of multiple conditions for buffering #2850
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
5c778ca
8cbdc7f
2e2cc4c
fa27859
9cc52bb
f6e2ee1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -11,7 +11,9 @@ | |
| ) | ||
| from devito.symbolics import IntDiv, limits_mapper, uxreplace | ||
| from devito.tools import Pickable, Tag, frozendict | ||
| from devito.types import Eq, Inc, ReduceMax, ReduceMin, ReduceMinMax, relational_min | ||
| from devito.types import ( | ||
| Eq, Inc, ReduceMax, ReduceMin, ReduceMinMax, relational_min, relational_shift | ||
| ) | ||
|
|
||
| __all__ = [ | ||
| 'ClusterizedEq', | ||
|
|
@@ -222,7 +224,7 @@ def __new__(cls, *args, **kwargs): | |
| relations=ordering.relations, mode='partial') | ||
| ispace = IterationSpace(intervals, iterators) | ||
|
|
||
| # Construct the conditionals and replace the ConditionalDimensions in `expr` | ||
| # Construct the conditionals | ||
| conditionals = {} | ||
| for d in ordering: | ||
| if not d.is_Conditional: | ||
|
|
@@ -234,13 +236,31 @@ def __new__(cls, *args, **kwargs): | |
| if d._factor is not None: | ||
| cond = d.relation(cond, GuardFactor(d)) | ||
| conditionals[d] = cond | ||
|
|
||
| # Merge conditionals when possible. E.g., if we have an implicit_dim | |
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. btw this block imho deserves its own function |
||
| # and there is a dimension with the same parent, we can merge | |
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Dimension "ca merged" "their conditions" you could also make the example a bit more practical |
||
| # its condition | ||
| for d in input_expr.implicit_dims: | ||
| if d not in conditionals: | ||
| continue | ||
| for cd in dict(conditionals): | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. list(...) is fine |
||
| if cd.parent == d.parent and cd != d: | ||
| cond = conditionals.pop(d) | ||
| if d.relation == 'strict': | ||
| conditionals[cd] = conditionals[d] = cond | ||
| else: | ||
| mode = cd.relation and d.relation | ||
| conditionals[cd] = mode(cond, conditionals[cd]) | ||
| break | ||
|
|
||
| # Replace the ConditionalDimensions in `expr` | ||
| for d, cond in conditionals.items(): | ||
| # Replace dimension with index | ||
| index = d.index | ||
| if d.condition is not None and d in expr.free_symbols: | ||
| index = index - relational_min(d.condition, d.parent) | ||
| expr = uxreplace(expr, {d: IntDiv(index, d.symbolic_factor)}) | ||
|
|
||
| conditionals = frozendict(conditionals) | ||
| index = index - relational_min(cond, d.parent) | ||
| shift = relational_shift(cond, d.parent) | ||
| expr = uxreplace(expr, {d: IntDiv(index, d.symbolic_factor) + shift}) | ||
|
|
||
| # Lower all Differentiable operations into SymPy operations | ||
| rhs = diff2sympy(expr.rhs) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,6 +1,7 @@ | ||
| from collections import defaultdict | ||
| from functools import singledispatch | ||
|
|
||
| from sympy import true | ||
| from sympy import Expr, Mod, true | ||
|
|
||
| from devito.ir import ( | ||
| Backward, Forward, GuardBoundNext, PrefetchUpdate, Queue, ReleaseLock, SyncArray, | ||
|
|
@@ -9,11 +10,65 @@ | |
| from devito.passes.clusters.utils import in_critical_region, is_memcpy | ||
| from devito.symbolics import IntDiv, uxreplace | ||
| from devito.tools import OrderedSet, is_integer, timed_pass | ||
| from devito.types import CustomDimension, Lock | ||
| from devito.types import CustomDimension, Lock, VirtualDimension | ||
|
|
||
| __all__ = ['memcpy_prefetch', 'tasking'] | ||
|
|
||
|
|
||
| @singledispatch | ||
| def next_index(expr, dim, dir): | ||
| return expr._subs(dim, dim + dir) | ||
|
|
||
|
|
||
| @next_index.register(Expr) | ||
| def _(expr, dim, dir): | ||
| if not expr.args: | ||
| return expr._subs(dim, dim + dir) | ||
| return expr.func(*[next_index(a, dim, dir) for a in expr.args]) | ||
|
|
||
|
|
||
| @next_index.register(IntDiv) | ||
| def _(expr, dim, dir): | ||
| """ | ||
| Handle forward and backward fetches separately to support non-canonical index | |
| expressions of the form: | ||
|
|
||
| t//factor + cond(t) | ||
|
|
||
| where ``cond(t)`` is a piecewise correction term. | ||
|
|
||
| The forward fetch advances to the next coarse-grained slot while evaluating | ||
| the correction at the next time point: | ||
|
|
||
| t//factor + cond(t) | ||
| -> (t//factor + 1) + cond(t + 1) | ||
|
|
||
| The backward fetch is not, in general, the inverse transformation obtained by | ||
| replacing ``+1`` with ``-1``. The correction may already be applied at the | ||
| current time point, causing the forward and backward fetches to be asymmetric. | ||
|
|
||
| For example, with ``factor=2`` and ``cond(t) := (t == a)``, the index at | ||
| ``t=a=3`` is: | ||
|
|
||
| 3//2 + 1 = 2 | ||
|
|
||
| while the previous index is: | ||
|
|
||
| 2//2 + 0 = 1 | ||
|
|
||
| A symmetric backward transformation would instead yield: | ||
|
|
||
| 3//2 - 1 + 0 = 0 | ||
| """ | ||
| if expr.lhs._defines & dim._defines: | ||
| if dir == 1: | ||
| return expr + dir | ||
| else: | ||
| return expr._subs(dim, dim + dir) | ||
| else: | ||
| return expr | ||
|
|
||
|
|
||
| def async_trigger(c, dims): | ||
| """ | ||
| Return the Dimension in `c`'s IterationSpace that triggers the | ||
|
|
@@ -78,7 +133,8 @@ def callback(self, clusters, prefix): | |
| d = self.key0(c0) | ||
| if d is not dim: | ||
| continue | ||
|
|
||
| if d in c0.guards and not c0.guards[d].has(Mod): | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. searching for Mod is a bit meh, I'd rather add a special guard to ir/support/guards.py and look for that instead (there's quite a few already in there!) |
||
| continue | ||
| protected = self._schedule_waitlocks(c0, d, clusters, locks, syncs) | ||
| self._schedule_withlocks(c0, d, protected, locks, syncs) | ||
|
|
||
|
|
@@ -193,7 +249,7 @@ def memcpy_prefetch(clusters, key0, sregistry): | |
| if c.properties.is_prefetchable(d._defines): | ||
| _actions_from_update_memcpy(c, d, clusters, actions, sregistry) | ||
| elif d.is_Custom and is_integer(c.ispace[d].size): | ||
| _actions_from_init(c, d, actions) | ||
| _actions_from_init(c, d, clusters, actions) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. leftover, I guess |
||
|
|
||
| # Attach the computed Actions | ||
| processed = [] | ||
|
|
@@ -214,7 +270,7 @@ def memcpy_prefetch(clusters, key0, sregistry): | |
| return processed | ||
|
|
||
|
|
||
| def _actions_from_init(c, d, actions): | ||
| def _actions_from_init(c, d, clusters, actions): | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. leftover, I guess |
||
| e = c.exprs[0] | ||
| function = e.rhs.function | ||
| target = e.lhs.function | ||
|
|
@@ -240,7 +296,7 @@ def _actions_from_update_memcpy(c, d, clusters, actions, sregistry): | |
|
|
||
| fetch = e.rhs.indices[d] | ||
| fshift = {Forward: 1, Backward: -1}.get(direction, 0) | ||
| findex = fetch + fshift if fetch.find(IntDiv) else fetch._subs(pd, pd + fshift) | ||
| findex = next_index(fetch, pd, fshift) | ||
|
|
||
| # If fetching into e.g. `ub[t1]` we might need to prefetch into e.g. `ub[t0]` | ||
| tindex0 = e.lhs.indices[d] | ||
|
|
@@ -271,8 +327,16 @@ def _actions_from_update_memcpy(c, d, clusters, actions, sregistry): | |
| ispace = c.ispace.augment({pd: tindex}) if tindex is not tindex0 else c.ispace | ||
|
|
||
| guard0 = c.guards.get(d, true)._subs(fetch, findex) | ||
| guard1 = GuardBoundNext(function.indices[d], direction) | ||
| guards = c.guards.impose(d, guard0 & guard1) | ||
| guard1 = GuardBoundNext(function.indices[d], e.rhs.indices[d], direction) | ||
|
|
||
| # First guard1 then if guard1 is valid we can safely evaluate guard0 | ||
| # that will have valid indices into f | ||
| vdnext = VirtualDimension(name=f'vdnext_{d.name}', parent=pd) | ||
| ispace = ispace.insert(pd, vdnext) | ||
| # Check valid tindex first | ||
| guards = c.guards.impose(d, guard1) | ||
| # Then check valid access | |
| guards = guards.impose(vdnext, guard0) | ||
|
|
||
| syncs = {d: [ | ||
| ReleaseLock(handle, target), | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we should place this whole block of code, which constructs/lowers the conditionals, into its own separate function, and add a docstring with some examples.