diff --git a/devtools/visualization/TARGETS b/devtools/visualization/TARGETS
index 88a5ba77107..c6b485f1461 100644
--- a/devtools/visualization/TARGETS
+++ b/devtools/visualization/TARGETS
@@ -9,12 +9,15 @@ runtime.python_library(
srcs = [
"__init__.py",
"visualization_utils.py",
+ "html_visualization.py",
+ "trace_passes.py",
],
visibility = ["PUBLIC"],
deps = [
"//caffe2:torch",
"//executorch/exir:lib",
"//executorch/exir/_serialize:lib",
+ "//executorch/devtools/etrecord:lib",
],
)
diff --git a/devtools/visualization/__init__.py b/devtools/visualization/__init__.py
index 8e91d7ffdb2..be0f2f8959d 100644
--- a/devtools/visualization/__init__.py
+++ b/devtools/visualization/__init__.py
@@ -3,11 +3,25 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
+# model_explorer-based visualization (requires model_explorer pip package)
+try:
+ from executorch.devtools.visualization.visualization_utils import ( # noqa: F401
+ ModelExplorerServer,
+ SingletonModelExplorerServer,
+ visualize,
+ visualize_graph,
+ visualize_model_explorer,
+ )
+except ImportError:
+ pass
-from executorch.devtools.visualization.visualization_utils import ( # noqa: F401
- ModelExplorerServer,
- SingletonModelExplorerServer,
- visualize,
- visualize_graph,
- visualize_model_explorer,
+# Self-contained HTML visualization (no external dependencies)
+from executorch.devtools.visualization.html_visualization import ( # noqa: F401
+ extract_from_exported_program,
+ extract_from_etrecord,
+ extract_from_pt2,
+ extract_from_pte,
+ generate_html,
+ generate_multi_pass_html,
+ visualize_edge_manager,
)
diff --git a/devtools/visualization/html_visualization.py b/devtools/visualization/html_visualization.py
new file mode 100644
index 00000000000..adea83e637a
--- /dev/null
+++ b/devtools/visualization/html_visualization.py
@@ -0,0 +1,1030 @@
+#
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+#
+"""Generate an interactive HTML visualization of an ExecuTorch graph.
+
+Supports pre-serialization (.pt2), post-serialization (.pte), ETRecord
+(.etrecord), and multi-pass trace (.json) files. Produces a self-contained
+HTML file using Cytoscape.js with dagre layout.
+
+Usage:
+ python3 -m executorch.devtools.visualization.html_visualization model.pt2
+ python3 -m executorch.devtools.visualization.html_visualization model.pte -o graph.html
+ python3 -m executorch.devtools.visualization.html_visualization trace.json -o passes.html
+
+Authored with Claude.
+"""
+
+import argparse
+import json
+import os
+import re
+import sys
+from typing import Any, Dict, List, Optional
+
+
# Hex RGB fill color for each node category; embedded into the generated
# HTML (legend and node styling) via generate_html()/generate_multi_pass_html().
CATEGORY_COLORS = {
    "backend": "#4caf50",
    "aten_compute": "#2196f3",
    "quantize": "#ff9800",
    "memory": "#9e9e9e",
    "placeholder": "#03a9f4",
    "param": "#78909c",
    "delegate": "#ab47bc",
}
+
# Backend custom-op prefixes. Ops containing these (case-insensitive) are
# categorized as "backend" rather than generic compute. Extend this tuple
# when new backends register custom op namespaces.
_BACKEND_OP_PREFIXES = (
    "cortex_m",
    "cadence",
    "qaisw",
)


def categorize_node(op_name: str) -> str:
    """Assign a visualization category to an operator name.

    Precedence: backend custom ops, then quantize/dequantize ops, then
    memory-layout ops, then graph plumbing (placeholder/output/getitem/
    get_attr), then delegate calls; everything else is generic compute.
    """
    lowered = op_name.lower()

    if any(prefix in lowered for prefix in _BACKEND_OP_PREFIXES):
        return "backend"

    quant_markers = (
        "quantize_per_tensor",
        "dequantize_per_",
        "quantize_per_channel",
        "dequantize_per_channel",
    )
    if any(marker in lowered for marker in quant_markers):
        return "quantize"

    memory_markers = (
        "view",
        "clone",
        "permute",
        "slice",
        "copy",
        "expand",
        "reshape",
        "t_copy",
        "unsqueeze",
        "squeeze",
    )
    if any(marker in lowered for marker in memory_markers):
        return "memory"

    plumbing_markers = ("placeholder", "output", "getitem", "get_attr")
    if any(marker in lowered for marker in plumbing_markers):
        return "placeholder"

    return "delegate" if "delegate" in lowered else "aten_compute"
+
+
+def _make_label(op_name: str) -> str:
+ name = op_name.split("::")[-1] if "::" in op_name else op_name
+ if "." in name:
+ name = name.rsplit(".", 1)[0]
+ if len(name) > 30:
+ name = name[:27] + "..."
+ return name
+
+
def extract_from_exported_program(ep, model_name: str) -> dict:
    """Walk an in-memory ExportedProgram's graph and extract visualization data.

    Args:
        ep: An ExportedProgram (anything exposing ``.graph.nodes`` with
            FX-style nodes carrying ``name``, ``op``, ``target``, ``meta``,
            ``args`` and ``kwargs``).
        model_name: Display name stored in the result metadata.

    Returns:
        A dict with ``metadata`` (name, source type, node/category counts),
        ``nodes`` (one entry per graph node) and ``edges``
        (producer -> consumer links).
    """
    graph = ep.graph

    nodes = []
    edges = []
    node_map = {}

    def _edge_sources(value):
        """Yield names of already-seen nodes referenced by an arg/kwarg value."""
        if hasattr(value, "name") and value.name in node_map:
            yield value.name
        elif isinstance(value, (list, tuple)):
            for item in value:
                if hasattr(item, "name") and item.name in node_map:
                    yield item.name

    for node in graph.nodes:
        node_id = node.name
        op_name = node.op
        if node.op == "call_function":
            op_name = str(node.target)
        elif node.op == "call_method":
            op_name = node.target

        details = {"op": node.op, "target": str(getattr(node, "target", ""))}
        meta_val = node.meta.get("val")
        if meta_val is not None:
            if hasattr(meta_val, "shape"):
                details["shape"] = str(list(meta_val.shape))
                details["dtype"] = str(meta_val.dtype)
            elif isinstance(meta_val, (list, tuple)):
                shapes = []
                for v in meta_val:
                    if hasattr(v, "shape"):
                        shapes.append(f"{list(v.shape)} {v.dtype}")
                if shapes:
                    details["shapes"] = shapes

        category = categorize_node(op_name)
        label = _make_label(op_name)

        if node.op == "placeholder":
            # Heuristic labeling based on export's parameter/buffer naming
            # convention: "p_*" for parameters, "b_*" for buffers.
            target = str(getattr(node, "target", ""))
            if target.startswith("p_") and target.endswith("_weight"):
                label = "weight"
                category = "param"
            elif target.startswith("p_") and target.endswith("_bias"):
                label = "bias"
                category = "param"
            elif target.startswith("b_") and "running_mean" in target:
                label = "bn_mean"
                category = "param"
            elif target.startswith("b_") and "running_var" in target:
                label = "bn_var"
                category = "param"
            elif target == "x" or not target.startswith(("p_", "b_")):
                label = "input"
        elif node.op == "output":
            label = "output"

        node_map[node_id] = len(nodes)
        nodes.append(
            {
                "id": node_id,
                "label": label,
                # Approximate pixel width for the HTML renderer.
                "w": max(len(label) * 8 + 16, 60),
                "category": category,
                "op_name": op_name,
                "details": details,
            }
        )

        # Bug fix: FX nodes receive inputs through BOTH args and kwargs; the
        # previous version only traversed args, silently dropping edges for
        # keyword-passed tensors.
        for arg in node.args:
            for src in _edge_sources(arg):
                edges.append({"source": src, "target": node_id})
        for kwarg in getattr(node, "kwargs", {}).values():
            for src in _edge_sources(kwarg):
                edges.append({"source": src, "target": node_id})

    category_counts = {}
    for n in nodes:
        cat = n["category"]
        category_counts[cat] = category_counts.get(cat, 0) + 1

    return {
        "metadata": {
            "model_name": model_name,
            "source_type": "pt2",
            "total_nodes": len(nodes),
            "category_counts": category_counts,
        },
        "nodes": nodes,
        "edges": edges,
    }
+
+
def extract_from_pt2(path: str) -> dict:
    """Load a .pt2 ExportedProgram from disk and extract its graph data."""
    import torch

    exported = torch.export.load(path)
    return extract_from_exported_program(exported, os.path.basename(path))
+
+
+def _extract_delegate_blob_info(
+ delegate_data: bytes,
+) -> Optional[Dict[str, Any]]:
+ """Extract metadata from a delegate blob. Backend-agnostic."""
+ if len(delegate_data) < 8:
+ return None
+ info: Dict[str, Any] = {"blob_size_bytes": len(delegate_data)}
+ op_patterns = re.findall(rb"(?:tosa|aten|xnnpack|qnn)\.\w+", delegate_data)
+ if op_patterns:
+ info["detected_ops"] = sorted(set(op.decode() for op in op_patterns[:20]))
+ return info
+
+
def extract_from_pte(path: str) -> dict:
    """Deserialize a .pte flatbuffer and extract visualization data.

    Walks the first execution plan's instruction chains, emitting one node
    per kernel/delegate call plus synthetic input/output nodes, and edges
    that follow value-index producer/consumer relationships.

    Args:
        path: Path to a serialized .pte file.

    Returns:
        A dict with ``metadata``, ``nodes`` and ``edges`` in the same shape
        as extract_from_exported_program().
    """
    from executorch.exir._serialize._program import deserialize_pte_binary
    from executorch.exir.schema import DelegateCall, KernelCall, Tensor

    with open(path, "rb") as f:
        data = f.read()

    pte_file = deserialize_pte_binary(data)
    # Only the first execution plan is visualized.
    plan = pte_file.program.execution_plan[0]

    nodes: List[dict] = []
    edges: List[dict] = []
    # value index -> id of the node that most recently touched that value.
    value_producers: Dict[int, str] = {}

    def _tensor_details(val, details: dict) -> None:
        """Record shape/dtype of a Tensor plan value into *details*."""
        if isinstance(val, Tensor):
            details["shape"] = list(val.sizes)
            details["dtype"] = (
                val.scalar_type.name
                if hasattr(val.scalar_type, "name")
                else str(val.scalar_type)
            )

    # --- Model inputs -----------------------------------------------------
    # Bug fix: inputs were previously registered in value_producers AFTER the
    # instruction walk, so no kernel ever saw a model input as "produced":
    # input nodes ended up with no outgoing edges and the first kernel's
    # inputs were misclassified as outputs. Registering them up front wires
    # them into the graph; appending (instead of the old insert(0), which
    # reversed the order) keeps input_0 before input_1.
    for i, idx in enumerate(plan.inputs):
        node_id = f"input_{i}"
        details = {"value_index": idx}
        _tensor_details(plan.values[idx].val, details)
        nodes.append(
            {
                "id": node_id,
                "label": f"input_{i}",
                "w": 70,
                "category": "placeholder",
                "op_name": "input",
                "details": details,
            }
        )
        value_producers[idx] = node_id

    # --- Delegate blob metadata (best effort, backend-agnostic) -----------
    delegate_info_map: Dict[int, dict] = {}
    for idx, delegate in enumerate(plan.delegates):
        if hasattr(delegate, "processed") and delegate.processed:
            blob = getattr(delegate.processed, "data", None)
            if blob:
                info = _extract_delegate_blob_info(bytes(blob))
                if info:
                    delegate_info_map[idx] = info

    # --- Instructions ------------------------------------------------------
    for chain_idx, chain in enumerate(plan.chains):
        for instr_idx, instr in enumerate(chain.instructions):
            args = instr.instr_args

            if isinstance(args, KernelCall):
                op = plan.operators[args.op_index]
                op_name = f"{op.name}.{op.overload}" if op.overload else op.name
                node_id = f"k_{chain_idx}_{instr_idx}"

                details: Dict[str, Any] = {"op_name": op_name}
                input_tensors = []
                output_tensors = []
                for val_idx in args.args:
                    if val_idx < len(plan.values):
                        val = plan.values[val_idx].val
                        if isinstance(val, Tensor):
                            shape_str = (
                                f"[{','.join(str(s) for s in val.sizes)}]"
                            )
                            dtype_str = (
                                val.scalar_type.name
                                if hasattr(val.scalar_type, "name")
                                else str(val.scalar_type)
                            )
                            info_str = f"{shape_str} {dtype_str}"
                            # Heuristic: a value seen before is an input;
                            # an unseen one is assumed to be an output.
                            if val_idx in value_producers:
                                input_tensors.append(info_str)
                            else:
                                output_tensors.append(info_str)

                if input_tensors:
                    details["inputs"] = input_tensors
                if output_tensors:
                    details["outputs"] = output_tensors

                category = categorize_node(op_name)
                label = _make_label(op_name)

                nodes.append(
                    {
                        "id": node_id,
                        "label": label,
                        "w": max(len(label) * 8 + 16, 60),
                        "category": category,
                        "op_name": op_name,
                        "details": details,
                    }
                )

                for val_idx in args.args:
                    if val_idx in value_producers:
                        edges.append(
                            {"source": value_producers[val_idx], "target": node_id}
                        )

                # Mark every referenced value as most recently touched here so
                # later consumers link back to this node.
                for val_idx in args.args:
                    value_producers[val_idx] = node_id

            elif isinstance(args, DelegateCall):
                node_id = f"d_{chain_idx}_{instr_idx}"
                delegate = plan.delegates[args.delegate_index]

                details = {
                    "delegate_id": delegate.id,
                    "delegate_index": args.delegate_index,
                }
                if args.delegate_index in delegate_info_map:
                    details.update(delegate_info_map[args.delegate_index])

                label = delegate.id
                if len(label) > 25:
                    label = label[:22] + "..."

                nodes.append(
                    {
                        "id": node_id,
                        "label": label,
                        "w": max(len(label) * 8 + 20, 100),
                        "category": "delegate",
                        "op_name": f"delegate:{delegate.id}",
                        "details": details,
                    }
                )

                for val_idx in args.args:
                    if val_idx in value_producers:
                        edges.append(
                            {"source": value_producers[val_idx], "target": node_id}
                        )

                for val_idx in args.args:
                    value_producers[val_idx] = node_id

    # --- Model outputs ------------------------------------------------------
    for i, idx in enumerate(plan.outputs):
        node_id = f"output_{i}"
        details = {"value_index": idx}
        _tensor_details(plan.values[idx].val, details)
        nodes.append(
            {
                "id": node_id,
                "label": f"output_{i}",
                "w": 80,
                "category": "placeholder",
                "op_name": "output",
                "details": details,
            }
        )
        if idx in value_producers:
            edges.append({"source": value_producers[idx], "target": node_id})

    # Filter edges to only reference existing nodes
    node_ids = {n["id"] for n in nodes}
    edges = [e for e in edges if e["source"] in node_ids and e["target"] in node_ids]

    category_counts = {}
    for n in nodes:
        cat = n["category"]
        category_counts[cat] = category_counts.get(cat, 0) + 1

    return {
        "metadata": {
            "model_name": os.path.basename(path),
            "source_type": "pte",
            "total_nodes": len(nodes),
            "category_counts": category_counts,
        },
        "nodes": nodes,
        "edges": edges,
    }
+
+
def extract_from_trace_json(path: str) -> dict:
    """Load a multi-pass trace JSON file, validating the top-level schema."""
    with open(path) as fp:
        trace = json.load(fp)
    if "passes" in trace:
        return trace
    raise ValueError(f"{path} does not contain a 'passes' key")
+
+
def extract_from_etrecord(path: str) -> dict:
    """Extract visualization data from an ETRecord file.

    Collects one snapshot per available graph (edge dialect program(s),
    graph_map stages, or the exported program as a fallback). A single
    snapshot is returned directly; multiple snapshots are wrapped in a
    multi-pass dict with a "passes" key.
    """
    from executorch.devtools.etrecord import parse_etrecord

    record = parse_etrecord(path)
    snapshots = []

    # edge_dialect_program can be a single ExportedProgram or a dict of them
    edge_program = record.edge_dialect_program
    if isinstance(edge_program, dict):
        for method_name, ep in edge_program.items():
            snapshots.append(
                extract_from_exported_program(ep, f"Edge Dialect: {method_name}")
            )
    elif edge_program is not None:
        snapshots.append(
            extract_from_exported_program(edge_program, "Edge Dialect (pre-delegation)")
        )

    # graph_map contains additional stages
    for name, ep in (record.graph_map or {}).items():
        snapshots.append(extract_from_exported_program(ep, name))

    # Fallback to exported_program if nothing else is available
    if not snapshots and record.exported_program is not None:
        snapshots.append(
            extract_from_exported_program(record.exported_program, "Exported Program")
        )

    if not snapshots:
        raise ValueError(f"No graph data found in {path}")
    if len(snapshots) == 1:
        return snapshots[0]
    return {
        "model_name": os.path.basename(path),
        "passes": snapshots,
    }
+
+
+# ---------------------------------------------------------------------------
+# Single-pass HTML template
+# ---------------------------------------------------------------------------
+
# Single-graph page template. generate_html() substitutes the $$MODEL_NAME$$,
# $$GRAPH_JSON$$ and $$COLORS_JSON$$ tokens verbatim before writing the file.
HTML_TEMPLATE = """



ExecuTorch Graph: $$MODEL_NAME$$











"""
+
+
+# ---------------------------------------------------------------------------
+# Multi-pass HTML template
+# ---------------------------------------------------------------------------
+
# Multi-pass page template. generate_multi_pass_html() substitutes the
# $$MODEL_NAME$$, $$PASSES_JSON$$ and $$COLORS_JSON$$ tokens before writing.
MULTI_PASS_HTML_TEMPLATE = """



ExecuTorch Pass Trace: $$MODEL_NAME$$








 Nodes: 0









"""
+
+
def generate_html(graph_data: dict, output_path: str) -> None:
    """Render a single-graph HTML page from *graph_data* and print a summary."""
    meta = graph_data["metadata"]
    rendered = (
        HTML_TEMPLATE.replace("$$MODEL_NAME$$", meta["model_name"])
        .replace("$$GRAPH_JSON$$", json.dumps(graph_data))
        .replace("$$COLORS_JSON$$", json.dumps(CATEGORY_COLORS))
    )
    with open(output_path, "w") as f:
        f.write(rendered)
    print(
        f"Wrote {output_path} "
        f"({meta['total_nodes']} nodes, "
        f"{len(graph_data['edges'])} edges)"
    )
+
+
def generate_multi_pass_html(trace_data: dict, output_path: str) -> None:
    """Render the multi-pass trace HTML page and print a summary."""
    passes = trace_data["passes"]

    page = MULTI_PASS_HTML_TEMPLATE
    for token, value in (
        ("$$MODEL_NAME$$", trace_data.get("model_name", "unknown")),
        ("$$PASSES_JSON$$", json.dumps(passes)),
        ("$$COLORS_JSON$$", json.dumps(CATEGORY_COLORS)),
    ):
        page = page.replace(token, value)
    with open(output_path, "w") as f:
        f.write(page)

    total_nodes = sum(p["metadata"]["total_nodes"] for p in passes)
    print(
        f"Wrote {output_path} ({len(passes)} passes, "
        f"{total_nodes} total nodes across all snapshots)"
    )
+
+
def visualize_edge_manager(edge_manager, output_path: str = "graph.html") -> str:
    """Write an HTML visualization of an EdgeProgramManager before to_executorch().

    Usage in your export script:
        from executorch.devtools.visualization.html_visualization import (
            visualize_edge_manager,
        )

        edge_manager = to_edge_transform_and_lower(...)
        visualize_edge_manager(edge_manager, "my_model_graph.html")
        et_program = edge_manager.to_executorch()

    Returns the path the HTML file was written to.
    """
    exported = edge_manager.exported_program()
    generate_html(
        extract_from_exported_program(exported, "Edge Manager Graph"), output_path
    )
    return output_path
+
+
def main():
    """CLI entry point: dispatch on the input file's extension."""
    parser = argparse.ArgumentParser(
        description="Visualize ExecuTorch graph as interactive HTML"
    )
    parser.add_argument("input", help="Path to .pt2, .pte, .etrecord, or .json file")
    parser.add_argument("-o", "--output", default=None)
    args = parser.parse_args()

    ext = os.path.splitext(args.input)[1].lower()
    output = args.output if args.output else os.path.splitext(args.input)[0] + ".html"

    if ext == ".json":
        generate_multi_pass_html(extract_from_trace_json(args.input), output)
    elif ext == ".pt2":
        generate_html(extract_from_pt2(args.input), output)
    elif ext == ".pte":
        generate_html(extract_from_pte(args.input), output)
    elif ext in (".etrecord", ".bin"):
        data = extract_from_etrecord(args.input)
        # ETRecords may yield one or many snapshots; pick the right renderer.
        writer = generate_multi_pass_html if "passes" in data else generate_html
        writer(data, output)
    else:
        print(f"Error: unsupported '{ext}'", file=sys.stderr)
        sys.exit(1)


if __name__ == "__main__":
    main()
diff --git a/devtools/visualization/trace_passes.py b/devtools/visualization/trace_passes.py
new file mode 100644
index 00000000000..1ee188cd4d2
--- /dev/null
+++ b/devtools/visualization/trace_passes.py
@@ -0,0 +1,838 @@
+#
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+#
+"""Trace any ExecuTorch backend's compilation pipeline pass-by-pass.
+
+Runs quantization, export, to_edge, then each backend pass individually,
+saving a JSON file with per-pass graph snapshots for visualization with
+html_visualization.py.
+
+Usage:
+ # Cortex-M backend:
+ python -m executorch.devtools.visualization.trace_passes \\
+ --backend cortex_m --model mobilenet_v2 -o trace.json
+
+ # XNNPACK backend:
+ python -m executorch.devtools.visualization.trace_passes \\
+ --backend xnnpack --model mobilenet_v2 -o trace.json
+
+ # Cadence backend:
+ python -m executorch.devtools.visualization.trace_passes \\
+ --backend cadence --model mobilenet_v2 -o trace.json
+
+ # Skip quantization (trace passes on a float model):
+ python -m executorch.devtools.visualization.trace_passes \\
+ --backend xnnpack --model mobilenet_v2 --no-quantize -o trace.json
+
+ # Then visualize:
+ python -m executorch.devtools.visualization.html_visualization trace.json -o trace.html
+
+Authored with Claude.
+"""
+
+import argparse
+import inspect
+import json
+import os
+import sys
+import traceback
+from typing import Any, Dict, List, Optional, Tuple, Type
+
+import torch
+from executorch.exir import EdgeCompileConfig, to_edge
+from executorch.exir.pass_base import ExportPass
+from executorch.exir.program._program import _transform
+from torch.export import export
+
+
+# ---------------------------------------------------------------------------
+# Node categorization (shared with html_visualization.py)
+# ---------------------------------------------------------------------------
+
# Hex RGB colors per node category. NOTE(review): duplicated from
# html_visualization.py (apparently so this module works standalone) —
# keep the two palettes in sync.
CATEGORY_COLORS = {
    "backend": "#4caf50",
    "aten_compute": "#2196f3",
    "quantize": "#ff9800",
    "memory": "#9e9e9e",
    "placeholder": "#03a9f4",
    "param": "#78909c",
    "delegate": "#ab47bc",
}
+
# Substrings identifying backend-specific custom ops (matched
# case-insensitively anywhere in the op name). Mirrors the tuple in
# html_visualization.py.
_BACKEND_OP_PREFIXES = (
    "cortex_m",
    "cadence",
    "qaisw",
)


def categorize_node(op_name: str) -> str:
    """Assign a visualization category to an operator name.

    Precedence: backend custom ops, then quantize/dequantize ops, then
    memory-layout ops, then graph plumbing (placeholder/output/getitem/
    get_attr), then delegate calls; everything else is generic compute.
    """
    lowered = op_name.lower()

    if any(prefix in lowered for prefix in _BACKEND_OP_PREFIXES):
        return "backend"

    quant_markers = (
        "quantize_per_tensor",
        "dequantize_per_",
        "quantize_per_channel",
        "dequantize_per_channel",
    )
    if any(marker in lowered for marker in quant_markers):
        return "quantize"

    memory_markers = (
        "view",
        "clone",
        "permute",
        "slice",
        "copy",
        "expand",
        "reshape",
        "t_copy",
        "unsqueeze",
        "squeeze",
    )
    if any(marker in lowered for marker in memory_markers):
        return "memory"

    plumbing_markers = ("placeholder", "output", "getitem", "get_attr")
    if any(marker in lowered for marker in plumbing_markers):
        return "placeholder"

    return "delegate" if "delegate" in lowered else "aten_compute"
+
+
+def _make_label(op_name: str) -> str:
+ name = op_name.split("::")[-1] if "::" in op_name else op_name
+ if "." in name:
+ name = name.rsplit(".", 1)[0]
+ if len(name) > 30:
+ name = name[:27] + "..."
+ return name
+
+
+# ---------------------------------------------------------------------------
+# QDQ group detection
+# ---------------------------------------------------------------------------
+
+
+def _get_qdq_ops():
+ """Get Q/DQ op sets. Works with or without Arm constants."""
+ q_ops = set()
+ dq_ops = set()
+ try:
+ from executorch.backends.arm.constants import DQ_OPS, Q_OPS
+
+ q_ops.update(Q_OPS)
+ dq_ops.update(DQ_OPS)
+ except ImportError:
+ pass
+ # Fallback: common quantize/dequantize ops
+ for op_name in [
+ "quantize_per_tensor",
+ "quantize_per_channel",
+ "quantize_per_tensor_tensor",
+ ]:
+ op = getattr(torch.ops.quantized_decomposed, op_name, None)
+ if op is not None:
+ q_ops.add(op.default)
+ for op_name in [
+ "dequantize_per_tensor",
+ "dequantize_per_channel",
+ "dequantize_per_tensor_tensor",
+ ]:
+ op = getattr(torch.ops.quantized_decomposed, op_name, None)
+ if op is not None:
+ dq_ops.add(op.default)
+ return q_ops, dq_ops
+
+
def detect_qdq_groups(graph) -> dict:
    """Detect DQ -> compute_op -> Q chains and assign group IDs.

    Walks the FX graph looking for a quantize node whose first argument is a
    compute op fed by at least one dequantize node, and tags all members of
    that chain with a shared integer group id.

    Args:
        graph: An FX-style graph exposing ``.nodes``.

    Returns:
        Mapping of node name -> group id; empty when no Q/DQ ops are known.
    """
    q_ops, dq_ops = _get_qdq_ops()
    if not q_ops or not dq_ops:
        return {}

    groups = {}
    group_id = 0

    # Anchor on each quantize node and walk backwards through its producer.
    for node in graph.nodes:
        if node.op != "call_function" or node.target not in q_ops:
            continue
        if not node.args:
            continue
        compute_node = node.args[0]
        if not hasattr(compute_node, "name") or not hasattr(compute_node, "target"):
            continue
        # Skip Q->Q / DQ->Q requantize chains: the anchor's producer must be
        # an actual compute op.
        if compute_node.target in q_ops or compute_node.target in dq_ops:
            continue

        # Collect the dequantize nodes feeding the compute op (directly or
        # inside a list/tuple argument).
        dq_nodes = []
        for arg in compute_node.args:
            if hasattr(arg, "target") and arg.target in dq_ops:
                dq_nodes.append(arg)
            elif isinstance(arg, (list, tuple)):
                for a in arg:
                    if hasattr(a, "target") and a.target in dq_ops:
                        dq_nodes.append(a)

        if not dq_nodes:
            continue

        # First group to claim a node wins; a node shared by two chains keeps
        # its earlier group id.
        members = [n.name for n in dq_nodes] + [compute_node.name, node.name]
        for m in members:
            if m not in groups:
                groups[m] = group_id
        group_id += 1

    return groups
+
+
+# ---------------------------------------------------------------------------
+# Graph snapshot extraction
+# ---------------------------------------------------------------------------
+
+
+def _serialize_qparams(qparams_dict):
+ if not qparams_dict:
+ return None
+ result = {}
+ for idx, qa in qparams_dict.items():
+ entry = {}
+ for attr in ("scale", "zp", "qmin", "qmax", "axis", "per_channel"):
+ val = getattr(qa, attr, None)
+ if val is not None:
+ entry[attr] = val
+ dtype = getattr(qa, "dtype", None)
+ if dtype is not None:
+ entry["dtype"] = str(dtype)
+ result[str(idx)] = entry
+ return result
+
+
def extract_from_exported_program(ep, stage_name, qdq_groups=None):
    """Walk an ExportedProgram's graph and extract visualization data.

    Like the html_visualization variant, but additionally records
    stack traces, input/output qparams metadata (when present in
    ``node.meta`` — presumably populated by backend annotation passes;
    verify per backend), and optional QDQ group ids.

    Args:
        ep: ExportedProgram-like object exposing ``.graph.nodes``.
        stage_name: Display name stored as the snapshot's model_name.
        qdq_groups: Optional node-name -> group-id map from detect_qdq_groups().

    Returns:
        Dict with ``metadata``, ``nodes`` and ``edges``.
    """
    graph = ep.graph
    nodes = []
    edges = []
    node_map = {}

    for node in graph.nodes:
        node_id = node.name
        op_name = node.op
        if node.op == "call_function":
            op_name = str(node.target)
        elif node.op == "call_method":
            op_name = node.target

        details = {"op": node.op, "target": str(getattr(node, "target", ""))}

        # Shape/dtype from FakeTensor metadata when available.
        meta_val = node.meta.get("val")
        if meta_val is not None:
            if hasattr(meta_val, "shape"):
                details["shape"] = str(list(meta_val.shape))
                details["dtype"] = str(meta_val.dtype)
            elif isinstance(meta_val, (list, tuple)):
                shapes = []
                for v in meta_val:
                    if hasattr(v, "shape"):
                        shapes.append(f"{list(v.shape)} {v.dtype}")
                if shapes:
                    details["shapes"] = shapes

        # Keep only the tail of long stack traces to bound payload size.
        stack_trace = node.meta.get("stack_trace")
        if stack_trace:
            if len(stack_trace) > 500:
                stack_trace = "..." + stack_trace[-500:]
            details["stack_trace"] = stack_trace

        input_qparams = node.meta.get("input_qparams")
        if input_qparams:
            serialized = _serialize_qparams(input_qparams)
            if serialized:
                details["input_qparams"] = serialized

        output_qparams = node.meta.get("output_qparams")
        if output_qparams:
            serialized = _serialize_qparams(output_qparams)
            if serialized:
                details["output_qparams"] = serialized

        category = categorize_node(op_name)
        label = _make_label(op_name)

        if node.op == "placeholder":
            # Heuristic labeling based on export's parameter/buffer naming
            # convention: "p_*" for parameters, "b_*" for buffers.
            target = str(getattr(node, "target", ""))
            if target.startswith("p_") and target.endswith("_weight"):
                label = "weight"
                category = "param"
            elif target.startswith("p_") and target.endswith("_bias"):
                label = "bias"
                category = "param"
            elif target.startswith("b_") and "running_mean" in target:
                label = "bn_mean"
                category = "param"
            elif target.startswith("b_") and "running_var" in target:
                label = "bn_var"
                category = "param"
            elif target == "x" or not target.startswith(("p_", "b_")):
                label = "input"
        elif node.op == "output":
            label = "output"

        node_data = {
            "id": node_id,
            "label": label,
            # Approximate pixel width for the HTML renderer.
            "w": max(len(label) * 8 + 16, 60),
            "category": category,
            "op_name": op_name,
            "details": details,
        }

        if qdq_groups and node_id in qdq_groups:
            node_data["qdq_group_id"] = qdq_groups[node_id]

        node_map[node_id] = len(nodes)
        nodes.append(node_data)

        # Edges from previously-seen producer nodes referenced in args.
        for arg in node.args:
            if hasattr(arg, "name") and arg.name in node_map:
                edges.append({"source": arg.name, "target": node_id})
            elif isinstance(arg, (list, tuple)):
                for a in arg:
                    if hasattr(a, "name") and a.name in node_map:
                        edges.append({"source": a.name, "target": node_id})

    category_counts = {}
    for n in nodes:
        cat = n["category"]
        category_counts[cat] = category_counts.get(cat, 0) + 1

    return {
        "metadata": {
            "model_name": stage_name,
            "source_type": "trace",
            "total_nodes": len(nodes),
            "category_counts": category_counts,
            "error": None,
        },
        "nodes": nodes,
        "edges": edges,
    }
+
+
+# ---------------------------------------------------------------------------
+# Backend registry
+# ---------------------------------------------------------------------------
+
+
class BackendConfig:
    """Configuration for tracing a backend's compilation pipeline.

    Quantizer and pass-list references are stored as import strings and
    resolved lazily, so merely registering a backend does not import it.
    """

    def __init__(
        self,
        name: str,
        quantizer_cls: Optional[str] = None,
        pass_list_source: Optional[str] = None,
        edge_compile_config: Optional[EdgeCompileConfig] = None,
        # Some passes should skip QDQ detection after them
        skip_qdq_after: Optional[List[str]] = None,
    ):
        self.name = name
        # Dotted path "pkg.module.ClassName" of the backend's quantizer.
        self._quantizer_cls_path = quantizer_cls
        # Either "module:attr_or_callable" or dotted "module.Class.attr".
        self._pass_list_source = pass_list_source
        # Default config skips IR validity checks so partially-lowered
        # graphs can still be converted.
        self.edge_compile_config = edge_compile_config or EdgeCompileConfig(
            _check_ir_validity=False,
        )
        self.skip_qdq_after = set(skip_qdq_after or [])

    def get_quantizer(self):
        """Import and instantiate the configured quantizer, or None."""
        if not self._quantizer_cls_path:
            return None
        module_path, cls_name = self._quantizer_cls_path.rsplit(".", 1)
        import importlib

        mod = importlib.import_module(module_path)
        cls = getattr(mod, cls_name)
        return cls()

    def get_pass_list(self) -> List[Type]:
        """Return a list of pass classes to iterate through."""
        if not self._pass_list_source:
            return []

        import importlib

        if ":" in self._pass_list_source:
            # "module.path:function_or_attribute" format
            module_path, attr = self._pass_list_source.split(":", 1)
            mod = importlib.import_module(module_path)
            obj = getattr(mod, attr)
            # A non-class callable is treated as a factory returning the list.
            if callable(obj) and not isinstance(obj, type):
                return obj()
            return list(obj)
        else:
            # "module.path.Class.pass_list" — dotted attribute access
            parts = self._pass_list_source.rsplit(".", 1)
            module_path, attr = parts
            # Try importing as module.Class.attr
            try:
                mod = importlib.import_module(module_path)
                return list(getattr(mod, attr))
            except (ImportError, AttributeError):
                # Try one level up: module.Class has attr
                mod_path, cls_name = module_path.rsplit(".", 1)
                mod = importlib.import_module(mod_path)
                cls = getattr(mod, cls_name)
                return list(getattr(cls, attr))
+
+
# Backend configurations — add new backends here.
# Maps backend name -> BackendConfig; populated by _register_builtin_backends()
# and extensible at runtime via register_backend().
_BACKEND_REGISTRY: Dict[str, BackendConfig] = {}


def register_backend(config: BackendConfig):
    """Register (or replace) a backend configuration under config.name."""
    _BACKEND_REGISTRY[config.name] = config
+
+
def _register_builtin_backends():
    """Register the backend configurations shipped with this tool.

    All quantizer / pass-list references are import strings resolved lazily,
    so backends that are not installed only fail when actually selected.
    NOTE(review): the xnnpack entry points at `_default_pass_list` in this
    module — not visible in this chunk; verify it exists.
    """
    register_backend(
        BackendConfig(
            name="cortex_m",
            quantizer_cls=(
                "executorch.backends.cortex_m.quantizer.quantizer"
                ".CortexMQuantizer"
            ),
            pass_list_source=(
                "executorch.backends.cortex_m.passes"
                ".cortex_m_pass_manager.CortexMPassManager.pass_list"
            ),
            edge_compile_config=EdgeCompileConfig(
                preserve_ops=[
                    torch.ops.aten.linear.default,
                    torch.ops.aten.hardsigmoid.default,
                    torch.ops.aten.hardsigmoid_.default,
                    torch.ops.aten.hardswish.default,
                    torch.ops.aten.hardswish_.default,
                ],
                _check_ir_validity=False,
                _core_aten_ops_exception_list=[
                    torch.ops.aten.max_pool2d.default,
                ],
            ),
            skip_qdq_after=["FoldAndAnnotateQParamsPass"],
        )
    )

    register_backend(
        BackendConfig(
            name="xnnpack",
            quantizer_cls=(
                "executorch.backends.xnnpack.quantizer"
                ".xnnpack_quantizer.XNNPACKQuantizer"
            ),
            pass_list_source=(
                "executorch.devtools.visualization.trace_passes"
                ":_default_pass_list"
            ),
        )
    )

    register_backend(
        BackendConfig(
            name="cadence",
            quantizer_cls=(
                "executorch.backends.cadence.aot.quantizer"
                ".quantizer.CadenceDefaultQuantizer"
            ),
            pass_list_source=(
                "executorch.backends.cadence.aot.passes"
                ":get_passes_in_default_order"
            ),
        )
    )

    register_backend(
        BackendConfig(
            name="vulkan",
            quantizer_cls=(
                "executorch.backends.vulkan.quantizer"
                ".vulkan_quantizer.VulkanQuantizer"
            ),
            # Vulkan has no static pass list — passes are inline in preprocess()
            pass_list_source=None,
        )
    )

    register_backend(
        BackendConfig(
            name="qnn",
            quantizer_cls=(
                "executorch.backends.qualcomm.quantizer"
                ".quantizer.QnnQuantizer"
            ),
            # QNN uses dynamic pipeline methods, no static pass list
            pass_list_source=None,
        )
    )
+
+
+# ---------------------------------------------------------------------------
+# Pass instantiation
+# ---------------------------------------------------------------------------
+
+
+def _instantiate_pass(pass_cls, exported_program):
+ """Instantiate a pass class, passing exported_program if needed."""
+ sig = inspect.signature(pass_cls.__init__)
+ params = list(sig.parameters.keys())
+ # Skip 'self'
+ if "exported_program" in params:
+ return pass_cls(exported_program)
+ elif len(params) > 1 and params[1] == "exported_program":
+ return pass_cls(exported_program)
+ else:
+ try:
+ return pass_cls()
+ except TypeError:
+ # Some passes require arguments we can't infer
+ return None
+
+
+# ---------------------------------------------------------------------------
+# Models
+# ---------------------------------------------------------------------------
+
+
def get_model(model_name: str) -> Tuple[torch.nn.Module, tuple]:
    """Load a model by name, returning (module, example_inputs).

    torchvision models are instantiated without pretrained weights; imports
    happen lazily so missing optional deps only fail for the chosen model.
    """

    def _torchvision_model(factory_name):
        import torchvision.models as tv_models

        net = getattr(tv_models, factory_name)(weights=None)
        net.eval()
        return net, (torch.randn(1, 3, 224, 224),)

    tv_names = ("mobilenet_v2", "mobilenet_v3_small", "resnet18", "resnet50")
    if model_name in tv_names:
        return _torchvision_model(model_name)

    if model_name == "lstm":
        from torch.nn.quantizable.modules import rnn

        net = rnn.LSTM(10, 20, 2)
        net.eval()
        return net, (
            torch.randn(5, 3, 10),
            (torch.randn(2, 3, 20), torch.randn(2, 3, 20)),
        )

    raise ValueError(
        f"Unknown model: {model_name}. "
        f"Available: mobilenet_v2, mobilenet_v3_small, resnet18, resnet50, lstm"
    )
+
+
+# ---------------------------------------------------------------------------
+# Pipeline
+# ---------------------------------------------------------------------------
+
+
+def _randn_like(x):
+ if isinstance(x, torch.Tensor):
+ return torch.randn_like(x)
+ elif isinstance(x, tuple):
+ return tuple(_randn_like(t) for t in x)
+ return x
+
+
def run_pipeline(
    backend_config: BackendConfig,
    model_name: str,
    quantize: bool = True,
) -> dict:
    """Run a backend's compilation pipeline, capturing snapshots.

    Stages: optional PT2E quantization, torch.export, to_edge, then each
    pass from the backend's static pass list, with a graph snapshot
    appended after export/to_edge/every pass.

    Args:
        backend_config: Supplies the quantizer, edge compile config, pass
            list, and the names of passes to skip QDQ detection after.
        model_name: Name understood by get_model().
        quantize: When True and a quantizer is configured, quantize first.

    Returns:
        dict with "model_name" and "passes" (list of snapshot dicts).
    """
    model, example_inputs = get_model(model_name)
    snapshots = []
    stage_num = 0  # sequential counter used to prefix snapshot names

    # --- Stage: Quantize ---
    if quantize:
        quantizer = backend_config.get_quantizer()
        if quantizer is None:
            print(
                f"Warning: no quantizer configured for {backend_config.name}, "
                f"skipping quantization"
            )
        else:
            stage_num += 1
            print("Quantizing...")
            # Imported lazily so tracing float-only backends never needs
            # torchao's quantize_pt2e module.
            from torchao.quantization.pt2e.quantize_pt2e import (
                convert_pt2e,
                prepare_pt2e,
            )

            model = torch.export.export_for_training(
                model, example_inputs
            ).module()
            prepared = prepare_pt2e(model, quantizer)
            # Calibrate observers with a few random batches shaped like the
            # example inputs (values are random — fine for tracing purposes).
            with torch.no_grad():
                for _ in range(5):
                    prepared(*[_randn_like(t) for t in example_inputs])
            model = convert_pt2e(prepared)

    # --- Stage: Export ---
    stage_num += 1
    print("Exporting...")
    ep = export(model, example_inputs, strict=True)
    qdq_groups = detect_qdq_groups(ep.graph)
    snapshots.append(
        extract_from_exported_program(
            ep, f"{stage_num}_post_export", qdq_groups
        )
    )
    print(f"  {stage_num}_post_export: {snapshots[-1]['metadata']['total_nodes']} nodes")

    # --- Stage: to_edge ---
    stage_num += 1
    print("Converting to edge...")
    edge_program = to_edge(ep, compile_config=backend_config.edge_compile_config)
    edge_ep = edge_program.exported_program()
    qdq_groups = detect_qdq_groups(edge_ep.graph)
    snapshots.append(
        extract_from_exported_program(
            edge_ep, f"{stage_num}_post_to_edge", qdq_groups
        )
    )
    print(
        f"  {stage_num}_post_to_edge: "
        f"{snapshots[-1]['metadata']['total_nodes']} nodes"
    )

    # --- Stages: Individual passes ---
    pass_list = backend_config.get_pass_list()
    if not pass_list:
        print(
            f"  No static pass list for {backend_config.name} — "
            f"showing export/edge stages only"
        )
    else:
        ep = edge_ep
        for i, pass_cls in enumerate(pass_list):
            stage_num += 1
            pass_name = f"{stage_num}_{pass_cls.__name__}"
            print(f"Running {pass_name}...")

            try:
                transform_pass = _instantiate_pass(pass_cls, ep)
                if transform_pass is None:
                    print(f"  Skipping {pass_name} (cannot instantiate)")
                    # Undo the increment so snapshot numbering stays
                    # contiguous across skipped passes.
                    stage_num -= 1
                    continue

                ep = _transform(ep, transform_pass)

                # QDQ detection is skipped for passes listed in
                # backend_config.skip_qdq_after (e.g. ones that fold Q/DQ
                # nodes), and is best-effort otherwise.
                qdq_groups_for_pass = None
                if pass_cls.__name__ not in backend_config.skip_qdq_after:
                    try:
                        qdq_groups_for_pass = detect_qdq_groups(ep.graph)
                    except Exception:
                        # Group info is an optional annotation; the
                        # snapshot is still useful without it.
                        pass

                snapshot = extract_from_exported_program(
                    ep, pass_name, qdq_groups_for_pass
                )
                snapshots.append(snapshot)
                print(
                    f"  {pass_name}: "
                    f"{snapshot['metadata']['total_nodes']} nodes"
                )

            except Exception as exc:
                # Record the failure as a final snapshot of the current
                # graph state, then stop tracing further passes.
                print(f"  ERROR in {pass_name}: {exc}", file=sys.stderr)
                error_snapshot = extract_from_exported_program(
                    ep, f"{pass_name}_ERROR"
                )
                error_snapshot["metadata"]["error"] = {
                    "pass_name": pass_name,
                    "message": str(exc),
                    "traceback": traceback.format_exc(),
                }
                snapshots.append(error_snapshot)
                break

    return {
        "model_name": f"{model_name} ({backend_config.name})",
        "passes": snapshots,
    }
+
+
+# ---------------------------------------------------------------------------
+# XNNPACK helper — extract default pass list from XNNPACKPassManager
+# ---------------------------------------------------------------------------
+
+
+def _default_pass_list():
+ """Extract the default pass list from XNNPACKPassManager.
+
+ XNNPACKPassManager stores passes as an instance attribute, so we need
+ to peek at the default set in __init__.
+ """
+ try:
+ from executorch.backends.xnnpack._passes import XNNPACKPassManager
+
+ # XNNPACKPassManager.__init__ takes exported_program, but we just
+ # need the default pass class list. Use the source to get it.
+ src = inspect.getsource(XNNPACKPassManager.__init__)
+ # The pass list is assigned as self.passes = [...] in __init__
+ # Rather than parsing source, just grab from a dummy instance
+ # Actually, we can't instantiate without an exported_program.
+ # Fall back to the known default list.
+ from executorch.backends.xnnpack._passes import (
+ ChannelsLastTaggedReshapePass,
+ ConstPropPass,
+ Conv1dUnsqueezePass,
+ ConvertToLinearPass,
+ ConvertToSDPAPass,
+ ConvertToUpsampleBilinear2d,
+ DecomposeBatchNorm,
+ DecomposeConcatenate,
+ DimOrderOpsRevertPass,
+ FuseActivationPass,
+ FuseBatchNormPass,
+ PReLUReshapePass,
+ PropagateCustomMetaPass,
+ RemoveGetItemPass,
+ RemoveRedundantCopyPass,
+ XNNPACKRemoveCloneOpsTransform,
+ )
+
+ return [
+ XNNPACKRemoveCloneOpsTransform,
+ DimOrderOpsRevertPass,
+ ConvertToUpsampleBilinear2d,
+ ConvertToLinearPass,
+ PropagateCustomMetaPass,
+ ConvertToSDPAPass,
+ ConstPropPass,
+ FuseBatchNormPass,
+ DecomposeBatchNorm,
+ FuseActivationPass,
+ DecomposeConcatenate,
+ RemoveGetItemPass,
+ Conv1dUnsqueezePass,
+ PReLUReshapePass,
+ ChannelsLastTaggedReshapePass,
+ RemoveRedundantCopyPass,
+ ]
+ except ImportError as e:
+ print(f"Warning: Could not import XNNPACK passes: {e}", file=sys.stderr)
+ return []
+
+
+# ---------------------------------------------------------------------------
+# CLI
+# ---------------------------------------------------------------------------
+
+
def main():
    """CLI entry point: parse args, trace the chosen backend pipeline,
    and write the snapshot JSON to disk."""
    _register_builtin_backends()

    parser = argparse.ArgumentParser(
        description="Trace ExecuTorch backend compilation passes",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=(
            "Available backends: "
            + ", ".join(sorted(_BACKEND_REGISTRY.keys()))
            + "\n\nAvailable models: mobilenet_v2, mobilenet_v3_small, "
            "resnet18, resnet50, lstm"
        ),
    )
    parser.add_argument(
        "--backend",
        default=None,
        help="Backend name (e.g., cortex_m, xnnpack, cadence)",
    )
    parser.add_argument(
        "--model",
        default="mobilenet_v2",
        help="Model name (default: mobilenet_v2)",
    )
    parser.add_argument(
        "-o",
        "--output",
        default=None,
        # Fixed: the help text previously advertised "__trace.json", which
        # does not match the fallback computed below.
        help="Output JSON path (default: <model>_<backend>_trace.json)",
    )
    parser.add_argument(
        "--no-quantize",
        action="store_true",
        help="Skip quantization (trace on float model)",
    )
    parser.add_argument(
        "--list-backends",
        action="store_true",
        help="List available backends and exit",
    )
    args = parser.parse_args()

    if args.list_backends:
        print("Available backends:")
        for name, cfg in sorted(_BACKEND_REGISTRY.items()):
            has_passes = "with passes" if cfg._pass_list_source else "export/edge only"
            has_quant = "quantized" if cfg._quantizer_cls_path else "float"
            print(f"  {name:15s} ({has_quant}, {has_passes})")
        return

    if not args.backend:
        parser.error("--backend is required (use --list-backends to see options)")

    if args.backend not in _BACKEND_REGISTRY:
        print(
            f"Error: unknown backend '{args.backend}'. "
            f"Available: {', '.join(sorted(_BACKEND_REGISTRY.keys()))}",
            file=sys.stderr,
        )
        sys.exit(1)

    backend_config = _BACKEND_REGISTRY[args.backend]
    output = args.output or f"{args.model}_{args.backend}_trace.json"

    result = run_pipeline(
        backend_config,
        args.model,
        quantize=not args.no_quantize,
    )

    with open(output, "w") as f:
        json.dump(result, f)
    print(
        f"\nWrote {output} ({len(result['passes'])} passes, "
        f"{os.path.getsize(output) / 1024:.0f} KB)"
    )
    print(
        f"Visualize: python -m executorch.devtools.visualization"
        f".html_visualization {output} -o {output.replace('.json', '.html')}"
    )
+
+
if __name__ == "__main__":  # entry point when run as a script
    main()
diff --git a/examples/arm/aot_arm_compiler.py b/examples/arm/aot_arm_compiler.py
index 61fae8ff8fc..551337ec32f 100644
--- a/examples/arm/aot_arm_compiler.py
+++ b/examples/arm/aot_arm_compiler.py
@@ -47,6 +47,8 @@
from executorch.devtools.backend_debug import get_delegation_info
from executorch.devtools.bundled_program.config import MethodTestCase, MethodTestSuite
+from executorch.exir.passes.quantize_io_pass import QuantizeInputs, QuantizeOutputs
+
from executorch.exir import (
EdgeCompileConfig,
ExecutorchBackendConfig,
@@ -860,6 +862,17 @@ def _to_channels_last(x):
),
)
+ # Strip the float I/O wrapper from the quantized model to produce
+ # fully int8 inputs and outputs. This must run before CortexMPassManager
+ # which renames quantized_decomposed ops to cortex_m variants.
+ if args.quantize:
+ print("Applying passes to create a fully int8 quantized model...")
+
+ edge = edge.transform([
+ QuantizeInputs(edge, [0]),
+ QuantizeOutputs(edge, [0]),
+ ])
+
pass_manager = CortexMPassManager(edge.exported_program())
edge._edge_programs["forward"] = pass_manager.transform()
diff --git a/trace_cortex_m_passes.py b/trace_cortex_m_passes.py
new file mode 100644
index 00000000000..4f4a6a30d92
--- /dev/null
+++ b/trace_cortex_m_passes.py
@@ -0,0 +1,445 @@
+"""Trace the Cortex-M compilation pipeline pass-by-pass, capturing graph snapshots.
+
+Runs quantization, export, to_edge, then each CortexMPassManager pass individually,
+saving a JSON file with per-pass graph snapshots for use with visualize_graph.py.
+
+Usage:
+ python3 trace_cortex_m_passes.py --model mobilenet_v2 -o mv2_trace.json
+
+Authored with Claude.
+"""
+
+import argparse
+import inspect
+import json
+import os
+import sys
+import traceback
+
+import torch
+from executorch.backends.arm.constants import DQ_OPS, Q_OPS
+from executorch.backends.cortex_m.passes.cortex_m_pass_manager import CortexMPassManager
+from executorch.backends.cortex_m.quantizer.quantizer import CortexMQuantizer
+from executorch.exir import EdgeCompileConfig, to_edge
+from executorch.exir.pass_base import ExportPass
+from executorch.exir.program._program import _transform
+from torch.export import export
+from torchao.quantization.pt2e.export_utils import model_is_exported
+from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
+
+
# Display color (hex) for each node category. Keys match the strings
# returned by categorize_node(), plus "param", which
# extract_from_exported_program() assigns directly to lifted parameter
# placeholders (weights, biases, batch-norm stats).
CATEGORY_COLORS = {
    "backend": "#4caf50",  # ops containing "cortex_m" in their name
    "aten_compute": "#2196f3",  # default bucket for compute ops
    "quantize": "#ff9800",  # quantize/dequantize ops
    "memory": "#9e9e9e",  # view/clone/permute/reshape-style ops
    "placeholder": "#03a9f4",  # graph inputs/outputs/getitem/get_attr
    "param": "#78909c",  # weight/bias/bn-stat placeholders
    "delegate": "#ab47bc",  # ops containing "delegate"
}
+
+
def categorize_node(op_name: str) -> str:
    """Classify an op name into a visualization category.

    Rules are checked in order; the first matching substring wins, and
    anything unmatched falls through to "aten_compute".
    """
    lowered = op_name.lower()
    rules = (
        ("backend", ("cortex_m",)),
        (
            "quantize",
            (
                "quantize_per_tensor",
                "dequantize_per_",
                "quantize_per_channel",
                "dequantize_per_channel",
            ),
        ),
        (
            "memory",
            (
                "view",
                "clone",
                "permute",
                "slice",
                "copy",
                "expand",
                "reshape",
                "t_copy",
                "unsqueeze",
                "squeeze",
            ),
        ),
        ("placeholder", ("placeholder", "output", "getitem", "get_attr")),
        ("delegate", ("delegate",)),
    )
    for category, keywords in rules:
        if any(keyword in lowered for keyword in keywords):
            return category
    return "aten_compute"
+
+
+def _make_label(op_name: str) -> str:
+ name = op_name.split("::")[-1] if "::" in op_name else op_name
+ if "." in name:
+ name = name.rsplit(".", 1)[0]
+ if len(name) > 30:
+ name = name[:27] + "..."
+ return name
+
+
+def _serialize_qparams(qparams_dict):
+ """Serialize a dict[int, QuantArgs] to JSON-safe form."""
+ if not qparams_dict:
+ return None
+ result = {}
+ for idx, qa in qparams_dict.items():
+ result[str(idx)] = {
+ "scale": qa.scale,
+ "zp": qa.zp,
+ "qmin": qa.qmin,
+ "qmax": qa.qmax,
+ "dtype": str(qa.dtype),
+ "axis": qa.axis,
+ "per_channel": qa.per_channel,
+ }
+ return result
+
+
def detect_qdq_groups(graph) -> dict:
    """Detect DQ -> compute_op -> Q chains and assign group IDs.

    For every quantize node whose first argument is a non-Q/DQ compute
    node that itself consumes at least one dequantize node, the DQ nodes,
    the compute node, and the Q node all receive the same group id (a
    node already grouped keeps its first assignment).

    Returns:
        {node_name: group_id} for nodes participating in DQ->op->Q chains.

    Fixed: removed a dead first scan that built q_op_set/dq_op_set but
    never read them.
    """
    groups = {}
    group_id = 0

    for node in graph.nodes:
        if node.op != "call_function" or node.target not in Q_OPS:
            continue
        # This is a Q node. Find its compute input.
        if not node.args:
            continue
        compute_node = node.args[0]
        if not hasattr(compute_node, "name") or not hasattr(compute_node, "target"):
            continue
        # Skip if compute_node is itself a Q/DQ op
        if compute_node.target in Q_OPS or compute_node.target in DQ_OPS:
            continue

        # Find DQ nodes feeding the compute node (flattening one level of
        # list/tuple arguments, as in e.g. cat/stack inputs).
        dq_nodes = []
        for arg in compute_node.args:
            candidates = arg if isinstance(arg, (list, tuple)) else (arg,)
            for candidate in candidates:
                if hasattr(candidate, "target") and candidate.target in DQ_OPS:
                    dq_nodes.append(candidate)

        if not dq_nodes:
            continue

        # Assign all members the same group_id
        members = [n.name for n in dq_nodes] + [compute_node.name, node.name]
        for member in members:
            groups.setdefault(member, group_id)
        group_id += 1

    return groups
+
+
def extract_from_exported_program(ep, stage_name, qdq_groups=None):
    """Walk an ExportedProgram's graph and extract visualization data.

    Args:
        ep: torch.export ExportedProgram whose .graph is traversed.
        stage_name: Label stored as metadata["model_name"] for this snapshot.
        qdq_groups: Optional {node_name: group_id} mapping (from
            detect_qdq_groups); matching nodes get a "qdq_group_id" field.

    Returns:
        JSON-serializable dict with "metadata", "nodes", and "edges".
    """
    graph = ep.graph

    nodes = []
    edges = []
    node_map = {}  # node name -> index into nodes; doubles as "seen" set

    for node in graph.nodes:
        node_id = node.name
        op_name = node.op
        if node.op == "call_function":
            op_name = str(node.target)
        elif node.op == "call_method":
            op_name = node.target

        details = {"op": node.op, "target": str(getattr(node, "target", ""))}

        # Shape/dtype from meta ("val" is the traced fake-tensor value; a
        # list/tuple indicates a multi-output node)
        meta_val = node.meta.get("val")
        if meta_val is not None:
            if hasattr(meta_val, "shape"):
                details["shape"] = str(list(meta_val.shape))
                details["dtype"] = str(meta_val.dtype)
            elif isinstance(meta_val, (list, tuple)):
                shapes = []
                for v in meta_val:
                    if hasattr(v, "shape"):
                        shapes.append(f"{list(v.shape)} {v.dtype}")
                if shapes:
                    details["shapes"] = shapes

        # Stack trace
        stack_trace = node.meta.get("stack_trace")
        if stack_trace:
            # Truncate long traces to last 500 chars
            if len(stack_trace) > 500:
                stack_trace = "..." + stack_trace[-500:]
            details["stack_trace"] = stack_trace

        # Quantization params (post-fold stages)
        input_qparams = node.meta.get("input_qparams")
        if input_qparams:
            serialized = _serialize_qparams(input_qparams)
            if serialized:
                details["input_qparams"] = serialized

        output_qparams = node.meta.get("output_qparams")
        if output_qparams:
            serialized = _serialize_qparams(output_qparams)
            if serialized:
                details["output_qparams"] = serialized

        category = categorize_node(op_name)
        label = _make_label(op_name)

        # Meaningful labels for placeholders: export lifts parameters and
        # buffers into placeholders named with "p_"/"b_" prefixes.
        if node.op == "placeholder":
            target = str(getattr(node, "target", ""))
            if target.startswith("p_") and target.endswith("_weight"):
                label = "weight"
                category = "param"
            elif target.startswith("p_") and target.endswith("_bias"):
                label = "bias"
                category = "param"
            elif target.startswith("b_") and "running_mean" in target:
                label = "bn_mean"
                category = "param"
            elif target.startswith("b_") and "running_var" in target:
                label = "bn_var"
                category = "param"
            elif target == "x" or not target.startswith(("p_", "b_")):
                # Anything not lifted from params/buffers is a user input.
                label = "input"
        elif node.op == "output":
            label = "output"

        node_data = {
            "id": node_id,
            "label": label,
            # Approximate render width in px: 8 px/char + padding, min 60.
            "w": max(len(label) * 8 + 16, 60),
            "category": category,
            "op_name": op_name,
            "details": details,
        }

        if qdq_groups and node_id in qdq_groups:
            node_data["qdq_group_id"] = qdq_groups[node_id]

        node_map[node_id] = len(nodes)
        nodes.append(node_data)

        # Edges are drawn only to args already in node_map; FX graphs
        # iterate in topological order, so producers precede consumers.
        for arg in node.args:
            if hasattr(arg, "name") and arg.name in node_map:
                edges.append({"source": arg.name, "target": node_id})
            elif isinstance(arg, (list, tuple)):
                for a in arg:
                    if hasattr(a, "name") and a.name in node_map:
                        edges.append({"source": a.name, "target": node_id})

    category_counts = {}
    for n in nodes:
        cat = n["category"]
        category_counts[cat] = category_counts.get(cat, 0) + 1

    return {
        "metadata": {
            "model_name": stage_name,
            "source_type": "trace",
            "total_nodes": len(nodes),
            "category_counts": category_counts,
            "error": None,
        },
        "nodes": nodes,
        "edges": edges,
    }
+
+
+def _to_channels_last(x):
+ if isinstance(x, torch.Tensor):
+ return x.to(memory_format=torch.channels_last) if x.dim() == 4 else x
+ elif isinstance(x, tuple):
+ return tuple(_to_channels_last(t) for t in x)
+ return x
+
+
def get_model(model_name: str):
    """Load a model by name, returning (module, example_inputs)."""
    if model_name == "mobilenet_v2":
        from torchvision.models import mobilenet_v2

        net = mobilenet_v2(weights=None)
        net.eval()
        return net, (torch.randn(1, 3, 224, 224),)
    if model_name == "lstm":
        from torch.nn.quantizable.modules import rnn

        net = rnn.LSTM(10, 20, 2)
        net.eval()
        # Sequence input plus the (h0, c0) hidden-state tuple.
        inputs = (
            torch.randn(5, 3, 10),
            (torch.randn(2, 3, 20), torch.randn(2, 3, 20)),
        )
        return net, inputs
    raise ValueError(f"Unknown model: {model_name}")
+
+
def run_pipeline(model_name: str) -> dict:
    """Run the full Cortex-M pipeline, capturing snapshots after each stage.

    Stages: PT2E quantization with CortexMQuantizer, torch.export,
    to_edge, then each CortexMPassManager pass applied individually with
    a snapshot appended after export/to_edge/every pass.

    Returns:
        dict with "model_name" and "passes" (list of snapshot dicts).
    """
    model, example_inputs = get_model(model_name)
    snapshots = []

    def _randn_like(x):
        # Recursively mirror the example-input structure with random data.
        if isinstance(x, torch.Tensor):
            return torch.randn_like(x)
        elif isinstance(x, tuple):
            return tuple(_randn_like(t) for t in x)
        return x

    # --- Stage 1: Quantize ---
    print("Quantizing...")
    quantizer = CortexMQuantizer()
    model = torch.export.export_for_training(model, example_inputs).module()
    prepared = prepare_pt2e(model, quantizer)
    # Calibrate with random data
    with torch.no_grad():
        for _ in range(5):
            prepared(*[_randn_like(t) for t in example_inputs])
    quantized = convert_pt2e(prepared)

    # --- Stage 2: Export ---
    print("Exporting...")
    ep = export(quantized, example_inputs, strict=True)
    qdq_groups = detect_qdq_groups(ep.graph)
    snapshots.append(
        extract_from_exported_program(ep, "1_post_export", qdq_groups)
    )
    print(f"  1_post_export: {snapshots[-1]['metadata']['total_nodes']} nodes")

    # --- Stage 3: to_edge ---
    print("Converting to edge...")
    # Ops that the Cortex-M flow expects to survive edge decomposition.
    edge_config = EdgeCompileConfig(
        preserve_ops=[
            torch.ops.aten.linear.default,
            torch.ops.aten.hardsigmoid.default,
            torch.ops.aten.hardsigmoid_.default,
            torch.ops.aten.hardswish.default,
            torch.ops.aten.hardswish_.default,
        ],
        _check_ir_validity=False,
        _core_aten_ops_exception_list=[torch.ops.aten.max_pool2d.default],
    )
    edge_program = to_edge(ep, compile_config=edge_config)
    edge_ep = edge_program.exported_program()
    qdq_groups = detect_qdq_groups(edge_ep.graph)
    snapshots.append(
        extract_from_exported_program(edge_ep, "2_post_to_edge", qdq_groups)
    )
    print(f"  2_post_to_edge: {snapshots[-1]['metadata']['total_nodes']} nodes")

    # --- Stages 4-10+: Individual passes ---
    pass_list = CortexMPassManager.pass_list
    ep = edge_ep

    for i, pass_cls in enumerate(pass_list):
        # Stage numbering continues from the two fixed stages above.
        pass_name = f"{i + 3}_{pass_cls.__name__}"
        print(f"Running {pass_name}...")

        try:
            # Pass constructors differ: some take the exported_program,
            # plain ExportPass subclasses take no arguments.
            signature = inspect.signature(pass_cls.__init__)
            if "exported_program" in signature.parameters:
                transform_pass = pass_cls(ep)
            elif issubclass(pass_cls, ExportPass):
                transform_pass = pass_cls()
            else:
                raise RuntimeError(
                    f"Unexpected pass type: {pass_cls} ({type(pass_cls)})"
                )
            ep = _transform(ep, transform_pass)

            # Detect QDQ groups for pre-fold stages (after the fold pass
            # the Q/DQ nodes are gone, so detection is skipped).
            qdq_groups_for_pass = None
            if pass_cls != FoldAndAnnotateQParamsPass:
                try:
                    qdq_groups_for_pass = detect_qdq_groups(ep.graph)
                except Exception:
                    # Best-effort annotation only; the snapshot is still
                    # useful without group info.
                    pass

            snapshot = extract_from_exported_program(
                ep, pass_name, qdq_groups_for_pass
            )
            snapshots.append(snapshot)
            print(f"  {pass_name}: {snapshot['metadata']['total_nodes']} nodes")

        except Exception as exc:
            # Record the failure as a final snapshot of the current graph
            # state, then stop tracing further passes.
            print(f"  ERROR in {pass_name}: {exc}", file=sys.stderr)
            error_snapshot = extract_from_exported_program(
                ep, f"{pass_name}_ERROR"
            )
            error_snapshot["metadata"]["error"] = {
                "pass_name": pass_name,
                "message": str(exc),
                "traceback": traceback.format_exc(),
            }
            snapshots.append(error_snapshot)
            break

    return {"model_name": model_name, "passes": snapshots}
+
+
# Module-level import kept at the bottom of the file; it runs at import time,
# so the name is bound before the equality check (not an isinstance check)
# in run_pipeline's pass loop uses it.
+from executorch.backends.arm._passes import FoldAndAnnotateQParamsPass # noqa: E402
+
+
def main():
    """CLI entry point: trace the Cortex-M pipeline for one model and
    dump the per-pass snapshot JSON."""
    parser = argparse.ArgumentParser(
        description="Trace Cortex-M compilation passes and output JSON snapshots"
    )
    parser.add_argument(
        "--model",
        default="mobilenet_v2",
        help="Model name (default: mobilenet_v2)",
    )
    parser.add_argument(
        "-o",
        "--output",
        default=None,
        # Fixed: help previously said "_trace.json", omitting the model
        # prefix actually used by the fallback below.
        help="Output JSON path (default: <model>_trace.json)",
    )
    args = parser.parse_args()

    output = args.output or f"{args.model}_trace.json"
    result = run_pipeline(args.model)

    with open(output, "w") as f:
        json.dump(result, f)
    print(
        f"Wrote {output} ({len(result['passes'])} passes, "
        f"{os.path.getsize(output) / 1024:.0f} KB)"
    )
+
+
if __name__ == "__main__":  # entry point when run as a script
    main()
diff --git a/visualize_graph.py b/visualize_graph.py
new file mode 100644
index 00000000000..b7b018c91ea
--- /dev/null
+++ b/visualize_graph.py
@@ -0,0 +1,42 @@
#
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#
"""DEPRECATED: Moved to executorch.devtools.visualization.html_visualization.

Update your imports:
    from executorch.devtools.visualization.html_visualization import (
        visualize_edge_manager,
        generate_html,
    )
"""

import warnings

# Warn at import time so callers of the old module location are nudged
# toward the new one; stacklevel=2 attributes the warning to the importing
# file rather than this shim.
warnings.warn(
    "visualize_graph has moved to "
    "executorch.devtools.visualization.html_visualization. "
    "Update your imports.",
    DeprecationWarning,
    stacklevel=2,
)

# Re-export the public API from the new location so existing
# `from visualize_graph import X` code keeps working.
from executorch.devtools.visualization.html_visualization import (  # noqa: F401
    CATEGORY_COLORS,
    categorize_node,
    extract_from_exported_program,
    extract_from_etrecord,
    extract_from_pt2,
    extract_from_pte,
    extract_from_trace_json,
    generate_html,
    generate_multi_pass_html,
    main,
    visualize_edge_manager,
)

# Preserve the old CLI: `python visualize_graph.py ...` dispatches to the
# relocated main().
if __name__ == "__main__":
    main()