From 06407d642f97cb8c3d946768c8eae0fc8882007d Mon Sep 17 00:00:00 2001 From: Saoirse Stewart Date: Wed, 25 Feb 2026 14:53:34 +0000 Subject: [PATCH 1/2] Arm backend: Fix to ensure nested control flow graphs are delegated - Added function to ensure True/False come in the right order. - Update to quantizer to ensure submodules are converted first. - Update to scalar pass for nested conditions. Signed-off-by: Saoirse Stewart --- backends/arm/_passes/arm_pass_utils.py | 49 ++++++++++++- .../arm/_passes/control_flow_const_inline.py | 35 +++++++-- .../arm/_passes/scalars_to_attribute_pass.py | 9 ++- .../operator_support/control_flow_support.py | 73 +++++++++++-------- backends/arm/quantizer/arm_quantizer.py | 61 +++++++++------- backends/arm/test/ops/test_cond.py | 37 ++++------ backends/arm/tosa/backend.py | 7 +- backends/arm/tosa/partitioner.py | 8 +- 8 files changed, 183 insertions(+), 96 deletions(-) diff --git a/backends/arm/_passes/arm_pass_utils.py b/backends/arm/_passes/arm_pass_utils.py index a8260d07620..2a047f18778 100644 --- a/backends/arm/_passes/arm_pass_utils.py +++ b/backends/arm/_passes/arm_pass_utils.py @@ -8,7 +8,7 @@ import traceback from inspect import isclass -from typing import Optional, Sequence +from typing import List, Optional, Sequence, Tuple import torch import torch.fx @@ -17,6 +17,10 @@ from executorch.exir import ExportedProgram from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.dialects.edge._ops import EdgeOpOverload +from executorch.exir.graph_module import ( + _get_control_flow_submodules, + get_control_flow_submodules, +) from torch._export.utils import ( get_buffer, @@ -29,6 +33,7 @@ from torch._ops import OpOverload from torch._subclasses.fake_tensor import FakeTensor from torch.export.graph_signature import InputKind +from torch.fx import GraphModule, Node def is_submodule_node(node: torch.fx.Node): @@ -284,3 +289,45 @@ def set_node_arg(node: torch.fx.Node, i: int | str, value): def 
get_output_dim_orders(graph_module): output_node = graph_module.graph.output_node() return [get_first_fake_tensor(node).dim_order() for node in output_node.args[0]] + + +def is_nested_control_flow_graph(graph_module: GraphModule) -> bool: + """Returns True if graph_module is a nested control-flow graph.""" + + # Find all top-level control-flow submodules + top_cf = get_control_flow_submodules(graph_module) + # For each submodule, see if it itself has control-flow inside + for _, submod, _ in top_cf: + if get_control_flow_submodules(submod): + return True + return False + + +def get_cond_while_submodules_nested( + graph_module: GraphModule, + apply_quantization: bool = False, +) -> List[Tuple[str, GraphModule, Node]]: + """Recursively find cond/while_loop submodules in a GraphModule. + + In nested control flow graphs, FX records the submodule functions + (true/false or cond/body) in reverse order compared to top-level graphs. We + must swap the indices when nested so that cond (first) and body/true_fn + (second) are consistently identified across all nesting levels. + + """ + + # Determine arg indices based on nesting and whether only cond branch is needed + nested = is_nested_control_flow_graph(graph_module) + # cond: [true_fn, false_fn] or swapped if nested + cond_indices = [2, 1] if nested else [1, 2] + # while_loop: [cond_fn, body_fn] or swapped if nested + while_indices = [1, 0] if nested else [0, 1] + if apply_quantization: + # only keep the cond_fn for while_loop (first index) when quantizing. 
+ while_indices = [while_indices[0]] + mapping = { + torch.ops.higher_order.cond: cond_indices, + torch.ops.higher_order.while_loop: while_indices, + } + # collect cond/while submodules (using mapping indices) + return _get_control_flow_submodules(graph_module, mapping) diff --git a/backends/arm/_passes/control_flow_const_inline.py b/backends/arm/_passes/control_flow_const_inline.py index 3fdabb42511..cc76e5d9957 100644 --- a/backends/arm/_passes/control_flow_const_inline.py +++ b/backends/arm/_passes/control_flow_const_inline.py @@ -5,11 +5,14 @@ from typing import Set, Type +import torch from executorch.backends.arm._passes.arm_pass import ArmPass +from executorch.backends.arm._passes.arm_pass_utils import ( + get_cond_while_submodules_nested, + is_submodule_node, +) from executorch.backends.transforms.utils import is_get_attr_node from executorch.exir.dialects._ops import ops as exir_ops -from executorch.exir.graph_module import get_cond_while_submodules - from executorch.exir.pass_base import ExportPass, PassResult from torch.fx import GraphModule @@ -27,15 +30,23 @@ class ControlFlowConstInlinePass(ArmPass): _passes_required_after: Set[Type[ExportPass]] = set() - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) + _targeted_ops = { + torch.ops.higher_order.cond, + torch.ops.higher_order.while_loop, + } - def call(self, graph_module: GraphModule) -> PassResult: + def _convert_getattr(self, graph_module): modified = False - - for _, submodule, _ in get_cond_while_submodules(graph_module): + for _, submodule, _ in get_cond_while_submodules_nested(graph_module): for submodule_node in submodule.graph.nodes: - if is_get_attr_node(submodule_node): + if submodule_node.target in self._targeted_ops: + self._convert_getattr(submodule) + + # For nested control flow, a "node" may actually be a GraphModule. + # Ensure we are only checking for nodes here. 
+ if is_get_attr_node(submodule_node) and not is_submodule_node( + submodule_node + ): val = getattr( submodule_node.graph.owning_module, submodule_node.target ) @@ -53,6 +64,14 @@ def call(self, graph_module: GraphModule) -> PassResult: submodule_node.replace_all_uses_with(const_node) submodule.graph.erase_node(submodule_node) modified = True + return modified + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def call(self, graph_module: GraphModule) -> PassResult: + + modified = self._convert_getattr(graph_module) if modified: graph_module.recompile() diff --git a/backends/arm/_passes/scalars_to_attribute_pass.py b/backends/arm/_passes/scalars_to_attribute_pass.py index 0e889f4e59b..731ea9b4a68 100644 --- a/backends/arm/_passes/scalars_to_attribute_pass.py +++ b/backends/arm/_passes/scalars_to_attribute_pass.py @@ -8,10 +8,11 @@ import torch from executorch.backends.arm._passes import ArmPass -from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor +from executorch.backends.arm._passes.arm_pass_utils import ( + get_cond_while_submodules_nested, + get_first_fake_tensor, +) from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass -from executorch.exir.graph_module import get_cond_while_submodules - from executorch.exir.pass_base import ExportPass, PassResult from torch.fx import GraphModule, Node from torchao.quantization.pt2e.utils import get_new_attr_name_with_prefix @@ -99,7 +100,7 @@ def handle_control_nodes(self, node: Node, graph_module: GraphModule) -> None: """Apply scalar argument conversion on subgraphs of control-flow nodes. 
""" - for _, submodule, _ in get_cond_while_submodules(graph_module): + for _, submodule, _ in get_cond_while_submodules_nested(graph_module): for submodule_node in submodule.graph.nodes: # use aten.full.default for scalar constants in control subgraphs self._convert_scalar_args(submodule, submodule_node) diff --git a/backends/arm/operator_support/control_flow_support.py b/backends/arm/operator_support/control_flow_support.py index a474466fe16..b34ebeaece0 100644 --- a/backends/arm/operator_support/control_flow_support.py +++ b/backends/arm/operator_support/control_flow_support.py @@ -20,32 +20,43 @@ def _fully_partitioned(submodule: fx.GraphModule) -> bool: + """Check that all nested control-flow ops within this submodule are also + fully partitioned. + """ partition_tag = None + for submodule_node in submodule.graph.nodes: - if submodule_node.op == "call_function": - # Input Q ops and output DQ ops will be de-tagged even if the submodule is fully supported. - if ( - submodule_node.target in Q_OPS - and list(submodule_node.all_input_nodes)[0].op == "placeholder" - ): - continue - if ( - submodule_node.target in DQ_OPS - and list(submodule_node.users)[0].op == "output" - ): - continue - if "delegation_tag" not in submodule_node.meta: - return False - if partition_tag is None: - partition_tag = submodule_node.meta["delegation_tag"] - elif submodule_node.meta["delegation_tag"] != partition_tag: - return False + if submodule_node.target in ControlFlowOpSupported._targeted_ops: + if _submodules_fully_partitioned(submodule_node, submodule): + return True + + if submodule_node.op != "call_function": + continue + # skip no-op quantize/dequantize boundary + if ( + submodule_node.target in Q_OPS + and list(submodule_node.all_input_nodes)[0].op == "placeholder" + ): + continue + if ( + submodule_node.target in DQ_OPS + and list(submodule_node.users)[0].op == "output" + ): + continue + + if "delegation_tag" not in submodule_node.meta: + return False + + if partition_tag is 
None: + partition_tag = submodule_node.meta["delegation_tag"] + + elif submodule_node.meta["delegation_tag"] != partition_tag: + return False + return True -def _submodules_fully_partitioned( - node: fx.Node, exported_program: ExportedProgram -) -> bool: +def _submodules_fully_partitioned(node: fx.Node, graph_module: fx.GraphModule) -> bool: """Returns whether the submodule arguments to a cond node were fully partitioned. @@ -61,9 +72,7 @@ def _submodules_fully_partitioned( raise ValueError(f"Unexpected target: {node.target}") cond_submodules = ( ( - exported_program.graph_module.get_submodule( - str(cast(torch.fx.Node, submodule_node).target) - ), + graph_module.get_submodule(str(cast(torch.fx.Node, submodule_node).target)), cast(torch.fx.Node, submodule_node), ) for submodule_node in submodule_args @@ -71,10 +80,11 @@ def _submodules_fully_partitioned( for submodule, submodule_node in cond_submodules: submodule = cast(torch.fx.GraphModule, submodule) - if _fully_partitioned(submodule): - submodule_node.meta["val"] = submodule.graph.output_node().meta["val"] - else: + if not _fully_partitioned(submodule): return False + else: + submodule_node.meta["val"] = submodule.graph.output_node().meta["val"] + return True @@ -105,6 +115,7 @@ def __init__( def is_node_supported( self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node ) -> bool: + if is_submodule_node(node): if not _tosa_spec_supports_cf(self.tosa_spec): self.reporter.report_reject( @@ -118,7 +129,9 @@ def is_node_supported( node, f"Submodule had unsupported user {user}" ) return False - if not _submodules_fully_partitioned(user, self.exported_program): + if not _submodules_fully_partitioned( + user, self.exported_program.graph_module + ): self.reporter.report_reject( node, "One submodule was not fully partitioned" ) @@ -161,7 +174,9 @@ def is_node_supported( ) return False - if not _submodules_fully_partitioned(node, self.exported_program): + if not _submodules_fully_partitioned( + node, 
self.exported_program.graph_module + ): self.reporter.report_reject( node, "Submodule was not fully partitioned." ) diff --git a/backends/arm/quantizer/arm_quantizer.py b/backends/arm/quantizer/arm_quantizer.py index c53070007dc..62a4ff9fa18 100644 --- a/backends/arm/quantizer/arm_quantizer.py +++ b/backends/arm/quantizer/arm_quantizer.py @@ -24,8 +24,11 @@ from executorch.backends.arm.common.arm_compile_spec import ( ArmCompileSpec, ) # isort: skip +from executorch.backends.arm._passes.arm_pass_utils import ( + get_cond_while_submodules_nested, + is_submodule_node, +) from executorch.backends.arm.vgf import VgfCompileSpec -from executorch.exir.graph_module import _get_control_flow_submodules from torch.fx import GraphModule, Node from torchao.quantization.pt2e import ( @@ -639,27 +642,6 @@ def validate(self, model: GraphModule) -> None: f"Quantizer detected operator {node.name} with different device inputs: {devices}." ) - @staticmethod - def _get_submodules_not_handled_by_torchao( - graph_module: GraphModule, - ): - """Returns control flow submodules that torchao's - prepare_pt2e/convert_pt2e do not handle natively. torchao now - recursively handles while_loop body_fn. 
- - (arg 1), so we only need to manually handle: - - cond true/false branches (args 1, 2) - - while_loop cond_fn (arg 0) - - """ - return _get_control_flow_submodules( - graph_module, - { - torch.ops.higher_order.cond: [1, 2], - torch.ops.higher_order.while_loop: [0], - }, - ) - def quantize_with_submodules( self, model: GraphModule, @@ -688,15 +670,40 @@ def quantize_with_submodules( prepare_fn = prepare_qat_pt2e if is_qat else prepare_pt2e prepared = prepare_fn(model, self) - for name, submodule, _ in self._get_submodules_not_handled_by_torchao(prepared): + # Prepare conditional submodules (e.g., if/while bodies) + # prepare only cond branches and while_loop cond_fn + for name, submodule, _ in get_cond_while_submodules_nested( + prepared, apply_quantization=True + ): prepared.set_submodule(name, prepare_fn(submodule, self), strict=True) + for submodule_node in submodule.graph.nodes: + if is_submodule_node(submodule_node): + for nested_name, nested_sub, _ in get_cond_while_submodules_nested( + submodule, apply_quantization=True + ): + prepared.set_submodule( + nested_name, prepare_fn(nested_sub, self), strict=True + ) + for inp in calibration_samples: prepared(*inp) - for name, submodule, _ in self._get_submodules_not_handled_by_torchao(prepared): - prepared.set_submodule(name, convert_pt2e(submodule), strict=True) - converted = convert_pt2e(prepared) - return converted + # Prepare conditional submodules (e.g., if/while bodies) + # convert only cond branches and while_loop cond_fn + for _, submodule, _ in get_cond_while_submodules_nested( + prepared, apply_quantization=True + ): + converted = convert_pt2e(submodule) + for submodule_node in submodule.graph.nodes: + if is_submodule_node(submodule_node): + for nested_name, nested_sub, _ in get_cond_while_submodules_nested( + submodule, apply_quantization=True + ): + converted.set_submodule( + nested_name, convert_pt2e(nested_sub), strict=True + ) + + return convert_pt2e(prepared) class EthosUQuantizer(TOSAQuantizer): 
diff --git a/backends/arm/test/ops/test_cond.py b/backends/arm/test/ops/test_cond.py index 2ccbb6886f1..ce7d8c5bad8 100644 --- a/backends/arm/test/ops/test_cond.py +++ b/backends/arm/test/ops/test_cond.py @@ -225,13 +225,7 @@ def _set_branch_calibration_samples( quant_stage.calibration_samples = calibration_samples -@common.parametrize( - "case", - test_cases, - xfails={ - "nested_one_arg_one_output": "Not fully delegated.", - }, -) +@common.parametrize("case", test_cases) def test_cond_tosa_FP(case: Callable[[], tuple[torch.nn.Module, tuple]]): module, example_inputs = case() pipeline = TosaPipelineFP[tuple]( @@ -248,20 +242,18 @@ def test_cond_tosa_FP(case: Callable[[], tuple[torch.nn.Module, tuple]]): pipeline.run() -@common.parametrize( - "case", - test_cases, - xfails={ - "nested_one_arg_one_output": "Node submodule_0 target submodule_0 references nonexistent attribute submodule_0", - }, -) +@common.parametrize("case", test_cases) def test_cond_tosa_INT(case: Callable[[], tuple[torch.nn.Module, tuple]]): module, example_inputs = case() pipeline = TosaPipelineINT[tuple]( - module, example_inputs, aten_op, tosa_extensions=["cf"] + module, + example_inputs, + aten_op, + tosa_extensions=["cf"], + frobenius_threshold=0.8, + cosine_threshold=0.8, # MLETORCH-1808 ) _set_branch_calibration_samples(pipeline, module, example_inputs) - # Make sure no cond ops are left after partitioning. 
pipeline.add_stage_after( "to_edge_transform_and_lower", @@ -272,10 +264,7 @@ def test_cond_tosa_INT(case: Callable[[], tuple[torch.nn.Module, tuple]]): pipeline.run() -@common.parametrize( - "case", - test_cases, -) +@common.parametrize("case", test_cases) def test_cond_u55_INT(case: Callable[[], tuple[torch.nn.Module, tuple]]): module, example_inputs = case() pipeline = OpNotSupportedPipeline[tuple](module, example_inputs, {aten_op: 1}) @@ -286,8 +275,12 @@ def test_cond_u55_INT(case: Callable[[], tuple[torch.nn.Module, tuple]]): @common.parametrize( "case", test_cases, - xfails={ - "nested_one_arg_one_output": "Node submodule_0 target submodule_0 references nonexistent attribute submodule_0", + skips={ + "one_arg_one_output": "Segfault when transpose goes into cond. MLBEDSW-11416.", + "one_arg_const_one_output": "Segfault when transpose goes into cond. MLBEDSW-11416.", + "multiple_one_arg_one_output": "Segfault when transpose goes into cond. MLBEDSW-11416.", + "one_arg_and_scalar_one_output": "Segfault when transpose goes into cond. MLBEDSW-11416.", + "nested_one_arg_one_output": "Segfault when transpose goes into cond. 
MLBEDSW-11416.", }, ) @common.XfailIfNoCorstone320.with_args(raises=None) diff --git a/backends/arm/tosa/backend.py b/backends/arm/tosa/backend.py index b7b81e36ced..44d11444b59 100644 --- a/backends/arm/tosa/backend.py +++ b/backends/arm/tosa/backend.py @@ -22,6 +22,10 @@ import torch import tosa_serializer as ts + +from executorch.backends.arm._passes.arm_pass_utils import ( + get_cond_while_submodules_nested, +) from executorch.backends.arm.common.arm_compile_spec import ArmCompileSpec from executorch.backends.arm.common.debug import debug_fail, debug_tosa_dump from executorch.backends.arm.debug.schema import DebugHook @@ -35,7 +39,6 @@ from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult from executorch.exir.backend.compile_spec_schema import CompileSpec from executorch.exir.dim_order_utils import get_memory_format -from executorch.exir.graph_module import get_cond_while_submodules from torch.export.exported_program import ExportedProgram from torch.fx import Graph, GraphModule, Node @@ -398,7 +401,7 @@ def _preprocess_module( # noqa: C901 raise # Recursively preprocess controlflow submodules. 
- for name, submodule, control_flow_node in get_cond_while_submodules( + for name, submodule, control_flow_node in get_cond_while_submodules_nested( graph_module ): TOSABackend._regularize_submodule(submodule, control_flow_node) diff --git a/backends/arm/tosa/partitioner.py b/backends/arm/tosa/partitioner.py index 4cacfd4975c..27f24e5958d 100644 --- a/backends/arm/tosa/partitioner.py +++ b/backends/arm/tosa/partitioner.py @@ -18,7 +18,10 @@ from typing import Callable, List, Optional, Sequence, Tuple import torch -from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor +from executorch.backends.arm._passes.arm_pass_utils import ( + get_cond_while_submodules_nested, + get_first_fake_tensor, +) from executorch.backends.arm._passes.convert_expand_copy_to_repeat import ( calculate_multiples, ) @@ -37,7 +40,6 @@ ) from executorch.exir.backend.utils import tag_constant_data, WhyNoPartitionReporter from executorch.exir.dialects._ops import ops as exir_ops -from executorch.exir.graph_module import get_cond_while_submodules from torch.export.exported_program import ExportedProgram from torch.fx import GraphModule from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner, Partition @@ -219,7 +221,7 @@ def _tag_module( # noqa tags: set[str] = set() if tag_iterator is None: tag_iterator = count(0) - for _, submodule, _ in get_cond_while_submodules(module): + for _, submodule, _ in get_cond_while_submodules_nested(module): submodule_tags = self._tag_module( submodule, containing_program, reporter, tag_iterator ) From 0fb0a34d6bb1a4be858371358c813452d2a0012e Mon Sep 17 00:00:00 2001 From: Saoirse Stewart Date: Tue, 10 Mar 2026 16:14:57 +0000 Subject: [PATCH 2/2] Arm backend: Remove skips from cond testcases --- backends/arm/test/ops/test_cond.py | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/backends/arm/test/ops/test_cond.py b/backends/arm/test/ops/test_cond.py index ce7d8c5bad8..d4f856ec761 100644 --- 
a/backends/arm/test/ops/test_cond.py +++ b/backends/arm/test/ops/test_cond.py @@ -272,17 +272,7 @@ def test_cond_u55_INT(case: Callable[[], tuple[torch.nn.Module, tuple]]): pipeline.run() -@common.parametrize( - "case", - test_cases, - skips={ - "one_arg_one_output": "Segfault when transpose goes into cond. MLBEDSW-11416.", - "one_arg_const_one_output": "Segfault when transpose goes into cond. MLBEDSW-11416.", - "multiple_one_arg_one_output": "Segfault when transpose goes into cond. MLBEDSW-11416.", - "one_arg_and_scalar_one_output": "Segfault when transpose goes into cond. MLBEDSW-11416.", - "nested_one_arg_one_output": "Segfault when transpose goes into cond. MLBEDSW-11416.", - }, -) +@common.parametrize("case", test_cases) @common.XfailIfNoCorstone320.with_args(raises=None) def test_cond_u85_INT(case: Callable[[], tuple[torch.nn.Module, tuple]]): module, example_inputs = case()