From b2cb71787250e15e221569a5b2dbfdcb9faa6491 Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Wed, 3 Jun 2026 16:06:02 +0100 Subject: [PATCH] Tweak scheduling of scalar aliases in the presence of guards --- devito/passes/clusters/aliases.py | 55 ++++++++++++++++++++++++------- tests/test_dimension.py | 22 ++++++++----- tests/test_dse.py | 38 ++++++++++----------- 3 files changed, 74 insertions(+), 41 deletions(-) diff --git a/devito/passes/clusters/aliases.py b/devito/passes/clusters/aliases.py index a51de1f25c..58078847a2 100644 --- a/devito/passes/clusters/aliases.py +++ b/devito/passes/clusters/aliases.py @@ -125,13 +125,13 @@ def _aliases_from_clusters(self, cgroup, exclude, meta): variants = [] for mapper in self._generate(cgroup, exclude): # Clusters -> AliasList - found = collect(mapper.extracted, meta.ispace, self.opt_minstorage) + found = collect(mapper.extracted, meta, self.opt_minstorage) exprs, aliases = self._choose(found, cgroup, mapper) # AliasList -> Schedule schedule = lower_aliases(aliases, meta, self.opt_maxpar) - variants.append(Variant(schedule, exprs)) + variants.append(make_variant(schedule, exprs, mapper)) if not variants: return [] @@ -282,8 +282,6 @@ def _do_generate(self, exprs, exclude, cbk_search, cbk_compose=None): class CireInvariants(CireTransformerLegacy, Queue): - _q_guards_in_key = True - def __init__(self, sregistry, options, platform): super().__init__(sregistry, options, platform) @@ -511,7 +509,7 @@ def _cbk_search2(self, expr, rank): } -def collect(extracted, ispace, minstorage): +def collect(extracted, meta, minstorage): """ Find groups of aliasing expressions. @@ -575,11 +573,11 @@ def collect(extracted, ispace, minstorage): group.append(u) unseen.remove(u) - group = Group(group, ispace=ispace) + group = Group(group, ispace=meta.ispace) k = group.dimensions_translated if minstorage else group.dimensions - k = frozenset(d for d in k if not d.is_NonlinearDerived) + mapper.setdefault(k, []).append(group) aliases = AliasList() @@ -657,8 +655,9 @@ def collect(extracted, ispace, minstorage): # Compute the alias score na = g.naliases - nr = nredundants(ispace, pivot) + nr = nredundants(meta.ispace, pivot) score = estimate_cost(pivot, True)*((na - 1) + nr) + aliases.add(pivot, aliaseds, list(mapper), distances, score) return aliases @@ -728,8 +727,9 @@ def lower_aliases(aliases, meta, maxpar): m = i.dim.symbolic_min - i.dim.parent.symbolic_min else: m = 0 - d = dmapper[i.dim] = IncrDimension(f"{i.dim.name}s", i.dim, m, - dd.symbolic_size, 1, dd.step) + d = dmapper[i.dim] = IncrDimension( + f"{i.dim.name}s", i.dim, m, dd.symbolic_size, 1, dd.step + ) sub_iterators[i.dim] = d else: d = i.dim @@ -745,6 +745,11 @@ def lower_aliases(aliases, meta, maxpar): # The alias write-to space writeto = IterationSpace(IntervalGroup(writeto), sub_iterators) + # Avoid scalar aliases in the presence of guards, since hoisting them + # might cause scope issues (see `test_dse.py::TestAliases::test_split_cond`) + if not writeto and meta.guards: + continue + # The alias iteration space ispace = IterationSpace(IntervalGroup(intervals, meta.ispace.relations), meta.ispace.sub_iterators, @@ -764,6 +769,34 @@ def lower_aliases(aliases, meta, maxpar): return Schedule(*processed, dmapper=dmapper, is_frame=aliases.is_frame) +def make_variant(schedule, exprs, mapper): + """ + Create a Variant from a Schedule and the corresponding expressions. + """ + # Some aliases may have been discarded along the way, and for + # them we reinstate the original sub-expressions + retained = flatten(sa.aliaseds for sa in schedule) + + subs = {} + for k, v in mapper.items(): + if v in retained: + continue + elif isinstance(v, dict): + # E.g., `mapper = {u[t0, x+3, y+3] + u[t0, x+3, y+4]: + # {u[t0, x+3, y+4]: None, u[t0, x+3, y+3]: dummy0}}` + try: + v1, = [i for i in v.values() if i not in retained] + except ValueError: + continue + subs[v1] = k + else: + subs[v] = k + + exprs = [uxreplace(e, subs) for e in exprs] + + return Variant(schedule, exprs) + + def optimize_schedule_rotations(schedule, sregistry): """ Transform the schedule such that the tensor temporaries "rotate" along @@ -1493,7 +1526,7 @@ def nredundants(ispace, expr): redundant if it defines an iteration space for `expr` while not appearing among its free symbols. Note that the converse isn't generally true: there could be a Dimension that does not appear in the free symbols while defining - a non-redundant iteration space (e.g., a BlockDimension). + a non-redundant iteration space (e.g., a BlockDimension or a reduction). """ iterated = {i.dim for i in ispace} used = {i for i in expr.free_symbols if i.is_Dimension} diff --git a/tests/test_dimension.py b/tests/test_dimension.py index 19d807466b..713378fb4f 100644 --- a/tests/test_dimension.py +++ b/tests/test_dimension.py @@ -22,6 +22,7 @@ from devito.types import Array, StencilDimension, Symbol from devito.types.basic import Scalar from devito.types.dimension import AffineIndexAccessFunction, Thickness +from devito.types.misc import Temp class TestIndexAccessFunction: @@ -2130,9 +2131,10 @@ def test_topofusion_w_subdims_conddims(self): assert exprs[0].write is h exprs = FindNodes(Expression).visit(bns['x2_blk0']) - assert len(exprs) == 2 - assert exprs[0].write is fsave - assert exprs[1].write is gsave + assert len(exprs) == 3 + assert isinstance(exprs[0].expr.lhs, Temp) + assert exprs[1].write is fsave + assert exprs[2].write is gsave def test_topofusion_w_subdims_conddims_v2(self): """ @@ -2163,9 +2165,10 @@ def test_topofusion_w_subdims_conddims_v2(self): bns, _ = assert_blocking(op, {'x0_blk0', 'x1_blk0'}) assert len(FindNodes(Expression).visit(bns['x0_blk0'])) == 3 exprs = FindNodes(Expression).visit(bns['x1_blk0']) - assert len(exprs) == 2 - assert exprs[0].write is fsave - assert exprs[1].write is gsave + assert len(exprs) == 3 + assert isinstance(exprs[0].expr.lhs, Temp) + assert exprs[1].write is fsave + assert exprs[2].write is gsave def test_topofusion_w_subdims_conddims_v3(self): """ @@ -2200,9 +2203,10 @@ def test_topofusion_w_subdims_conddims_v3(self): assert exprs[1].write is g exprs = FindNodes(Expression).visit(bns['x2_blk0']) - assert len(exprs) == 2 - assert exprs[0].write is fsave - assert exprs[1].write is gsave + assert len(exprs) == 3 + assert isinstance(exprs[0].expr.lhs, Temp) + assert exprs[1].write is fsave + assert exprs[2].write is gsave # Additional nest due to anti-dependence exprs = FindNodes(Expression).visit(bns['x1_blk0']) diff --git a/tests/test_dse.py b/tests/test_dse.py index 8399261563..3dcbd40336 100644 --- a/tests/test_dse.py +++ b/tests/test_dse.py @@ -19,7 +19,7 @@ Conditional, DummyEq, Expression, FindNodes, FindSymbols, Iteration, ParallelIteration, retrieve_iteration_tree ) -from devito.passes.clusters.aliases import collect +from devito.passes.clusters.aliases import AliasKey, collect from devito.passes.clusters.factorization import collect_nested from devito.passes.iet.parpragma import VExpanded from devito.symbolics import ( # noqa @@ -423,8 +423,9 @@ def test_collection(self, exprs, expected): extracted = {i.rhs: i.lhs for i in exprs} ispace = exprs[0].ispace + meta = AliasKey(ispace, None, None, None, None) - aliases = collect(extracted, ispace, False) + aliases = collect(extracted, meta, False) aliases.filter(lambda a: a.score > 0) assert len(aliases) == len(expected) @@ -2553,7 +2554,7 @@ def test_invariants_with_conditional(self): op = Operator(eqn, opt='advanced') - assert_structure(op, ['t', 't,fd', 't,fd,x,y'], 't,fd,x,y') + assert_structure(op, ['t', 't,fd,x,y'], 't,fd,x,y') # Make sure it compiles _ = op.cfunction @@ -2561,7 +2562,7 @@ def test_invariants_with_conditional(self): eqn = Eq(u, u - (cos(time_sub * factor * f) * sin(g) * uf)) op = Operator(eqn, opt='advanced') - assert_structure(op, ['x,y', 't', 't,fd', 't,fd,x,y'], 'x,y,t,fd,x,y') + assert_structure(op, ['x,y', 't', 't,fd,x,y'], 'x,y,t,fd,x,y') # Make sure it compiles _ = op.cfunction @@ -2705,10 +2706,9 @@ def test_split_cond(self): cond = FindNodes(Conditional).visit(op) assert len(cond) == 3 - # Each guard should have its own alias for cos(time) - assert 'float r0 = cos(time);' in str(body0(op)) + # No aliases in this case due to guards scalars = [i for i in FindSymbols().visit(op) if isinstance(i, Temp)] - assert len(scalars) == 2 + assert len(scalars) == 0 def test_split_cond_multi_alias(self): grid = Grid((11, 11)) @@ -2728,11 +2728,9 @@ def test_split_cond_multi_alias(self): cond = FindNodes(Conditional).visit(op) assert len(cond) == 3 - # Each guard should have its own aliases for cos(time) and sin(time) - assert 'const float r0 = sin(time) + cos(time)' in str(body0(op)) - assert 'const float r1 = cos(time);' in str(body0(op)) + # No aliases in this case due to guards scalars = [i for i in FindSymbols().visit(op) if isinstance(i, Temp)] - assert len(scalars) == 3 + assert len(scalars) == 0 def test_multi_cond_no_split(self): grid = Grid((11, 11)) @@ -2758,7 +2756,7 @@ def test_multi_cond_no_split(self): ) scalars = [i for i in FindSymbols().visit(op) if isinstance(i, Temp)] - assert len(scalars) == 3 + assert len(scalars) == 0 def test_alias_with_conditional(self): grid = Grid((11, 11)) @@ -2779,9 +2777,9 @@ def test_alias_with_conditional(self): cond = FindNodes(Conditional).visit(op) assert len(cond) == 3 - # Each guard should have its own alias for cos(time/ctf) + # No aliases in this case due to guards scalars = [i for i in FindSymbols().visit(op) if isinstance(i, Temp)] - assert len(scalars) == 2 + assert len(scalars) == 0 def test_scalar_alias_interp(self): grid = Grid(shape=(11, 11)) @@ -2825,9 +2823,9 @@ def test_scalar_with_cond_access(self): cond = FindNodes(Conditional).visit(op) assert len(cond) == 3 - # # Each guard should have its own alias for cos/sin(f1[time-2]) + # The guards prevent some aliases from being hoisted out scalars = [i for i in FindSymbols().visit(op) if isinstance(i, Temp)] - assert len(scalars) == 3 + assert len(scalars) == 0 assert_structure( op, @@ -2855,9 +2853,9 @@ def test_scalar_with_cond_tinvariant(self): cond = FindNodes(Conditional).visit(op) assert len(cond) == 1 - # One for each 1/dt 1/dt**2 + # One for 1/dt, while 1/dt**2 ain't hoisted out due to the guard scalars = [i for i in FindSymbols().visit(op) if isinstance(i, Temp)] - assert len(scalars) == 2 + assert len(scalars) == 1 assert_structure( op, @@ -2865,11 +2863,9 @@ def test_scalar_with_cond_tinvariant(self): 'txyxy' ) - # Both aliases should be hoisted outside the time loop + # The 1/dt alias should be hoisted outside the time loop assert str(body0(op).body[0]) == 'const float r0 = 1.0F/dt;' assert not body0(op).body[0].ispace - assert str(body0(op).body[1]) == 'const float r1 = 1.0F/(dt*dt);' - assert not body0(op).body[1].ispace class TestIsoAcoustic: