Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 126 additions & 0 deletions devtools/inspector/_inspector.py
Original file line number Diff line number Diff line change
Expand Up @@ -1649,3 +1649,129 @@ def get_stacktraces_for_row(aot_ops: List[str]) -> Dict[str, Optional[str]]:
df["stacktraces"] = df["aot_ops"].apply(get_stacktraces_for_row)

return df

def calculate_numeric_gap_from_taps(
self,
flat_runtime_outputs: Sequence,
tap_specs: Sequence,
distance: Union[str, NumericalComparatorBase] = "MSE",
reference_graph: Optional[str] = None,
disable_debug_handle_valdiation: bool = False,
) -> pd.DataFrame:
"""
Compares AOT intermediate outputs (from the ETRecord, captured by
IntermediateOutputCapturer) with runtime tap values exposed as
USER_OUTPUTs by the `intermediate_output_tap` package.

Unlike `calculate_numeric_gap`, this method works through delegates
with no backend-side support: the runtime values come from extra
outputs the AOT pass added to the ExportedProgram before lowering.

IMPORTANT: ETRecord serialization regenerates `debug_handle`s during
roundtrip, so the handles in `tap_specs` (set at AOT-pass time) are
stale. Each spec's `reducer_node_name` (set by
`strip_taps_(edge, tap_specs=specs)`) is used to look up the
post-roundtrip `debug_handle` in the AOT reference graph.

Args:
flat_runtime_outputs: The flat output tuple/list returned by
running the lowered program (e.g. `Method.execute(inputs)`).
tap_specs: The list of TapSpec returned by
`strip_taps_(edge, tap_specs=specs)` — these carry
`reducer_node_name` for alignment.
distance: "MSE", "L1", "SNR", or a `NumericalComparatorBase`.
reference_graph: AOT graph to use as the golden. See
`calculate_numeric_gap` for valid values.
disable_debug_handle_valdiation: Bypass debug handle validation.

Returns:
DataFrame with one row per (aot_handle, runtime_handle) pair.
Same shape produced by `calculate_numeric_gap`.
"""
reference_graph_module, _resolved_graph_name = self._resolve_reference_graph(
reference_graph,
disable_debug_handle_valdiation,
)
aot_intermediate_outputs, aot_debug_handle_to_op_names = (
self._get_aot_intermediate_outputs_and_op_names(reference_graph_module)
)
if len(aot_intermediate_outputs) == 0:
raise ValueError(
"No AOT intermediate outputs were captured. ETRecord must be "
"provided with representative_inputs for tap-based comparison."
)

spec_handles = _lookup_handles_by_name(reference_graph_module, tap_specs)

runtime_intermediate_outputs: Dict[DebugHandle, Tuple[Any, int]] = {}
runtime_debug_handle_to_op_names: Dict[DebugHandle, List[str]] = {}
for spec, dh in zip(tap_specs, spec_handles):
if dh is None:
continue
key: DebugHandle = (int(dh),)
runtime_intermediate_outputs[key] = (
flat_runtime_outputs[spec.output_index],
1,
)
runtime_debug_handle_to_op_names[key] = [spec.op_target]

if not runtime_intermediate_outputs:
raise ValueError(
"Could not recover any post-roundtrip handles for tap_specs. "
"Verify that strip_taps_(edge, tap_specs=specs) was called and "
"the returned specs were passed to this method, and that "
"generate_etrecord ran AFTER strip_taps_."
)

mapping = map_runtime_aot_intermediate_outputs(
aot_intermediate_outputs, runtime_intermediate_outputs
)

if isinstance(distance, NumericalComparatorBase):
comparator = distance
if comparator.inspector is None:
comparator.inspector = self
else:
metric = distance.strip().upper()
if metric == "MSE":
comparator = MSEComparator(inspector=self)
elif metric == "L1":
comparator = L1Comparator(inspector=self)
elif metric == "SNR":
comparator = SNRComparator(inspector=self)
else:
raise ValueError(f"Unsupported distance metric {distance!r}")

return comparator.compare(
mapping,
aot_debug_handle_to_op_names,
runtime_debug_handle_to_op_names,
)


def _lookup_handles_by_name(
reference_graph_module,
tap_specs: Sequence,
) -> List[Optional[int]]:
"""
For each TapSpec, return the post-roundtrip `debug_handle` of the FX node
whose name equals `spec.reducer_node_name`. Returns None for any spec
without a `reducer_node_name` set or whose name is not found.

`reducer_node_name` is set by `strip_taps_(edge, tap_specs=specs)`. FX
node names survive ETRecord serialization roundtrip, so this lookup is
stable.
"""
name_to_handle: Dict[str, Optional[int]] = {}
for n in reference_graph_module.graph.nodes:
h = n.meta.get("debug_handle")
name_to_handle[n.name] = int(h) if isinstance(h, int) else None

out: List[Optional[int]] = []
for spec in tap_specs:
rn = getattr(spec, "reducer_node_name", None)
if rn is None:
out.append(None)
continue
out.append(name_to_handle.get(rn))
return out
85 changes: 85 additions & 0 deletions devtools/intermediate_output_tap/TARGETS
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")

oncall("executorch")

runtime.python_library(
name = "spec",
srcs = ["_spec.py"],
)

runtime.python_library(
name = "custom_ops_lib",
srcs = ["custom_ops_lib.py"],
deps = [
"//caffe2:torch",
],
)

runtime.python_library(
name = "selectors",
srcs = ["_selectors.py"],
deps = [
"//caffe2:torch",
],
)

runtime.python_library(
name = "reducers",
srcs = ["_reducers.py"],
deps = [
"//caffe2:torch",
"//executorch/exir/dialects:lib",
],
)

runtime.python_library(
name = "tap_pass",
srcs = ["_tap_pass.py"],
deps = [
"//caffe2:torch",
":custom_ops_lib",
":reducers",
":selectors",
":spec",
],
)

runtime.python_library(
name = "strip_pass",
srcs = ["_strip_pass.py"],
deps = [
"//caffe2:torch",
":reducers",
":tap_pass",
],
)

runtime.python_library(
name = "convenience",
srcs = ["_convenience.py"],
deps = [
"fbsource//third-party/pypi/pandas:pandas",
"//caffe2:torch",
"//executorch/exir:lib",
"//executorch/runtime:runtime",
":reducers",
":selectors",
":spec",
":strip_pass",
":tap_pass",
],
)

runtime.python_library(
name = "lib",
srcs = ["__init__.py"],
deps = [
":convenience",
":custom_ops_lib",
":reducers",
":selectors",
":spec",
":strip_pass",
":tap_pass",
],
)
101 changes: 101 additions & 0 deletions devtools/intermediate_output_tap/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

"""
Public API for the ExecuTorch numerical debugger.

Backend-agnostic intermediate-value tap that complements the existing
Inspector framework:

- AOT side : `IntermediateOutputCapturer` (existing)
- Runtime side : ETDump intermediate output events (existing, opaque inside delegates)
- Runtime side : USER_OUTPUT taps (this module — works through delegates without
any backend-side changes)

Typical usage:

from executorch.devtools.intermediate_output_tap import (
tap_intermediate_outputs, strip_taps_, DEFAULT_STATS,
)

ep = export(model, example_inputs)
ep_tapped, specs = tap_intermediate_outputs(ep, reducer=DEFAULT_STATS)
edge = to_edge_transform_and_lower(ep_tapped, partitioner=[XnnpackPartitioner()])
strip_taps_(edge)
et_program = edge.to_executorch()

flat_outputs = runtime.forward(*example_inputs)
df = inspector.calculate_numeric_gap_from_taps(flat_outputs, specs)
"""

from executorch.devtools.intermediate_output_tap import (

Check warning on line 36 in devtools/intermediate_output_tap/__init__.py

View workflow job for this annotation

GitHub Actions / lintrunner

FLAKE8 F401

'executorch.devtools.intermediate_output_tap.custom_ops_lib' imported but unused See https://www.flake8rules.com/rules/F401.html.
custom_ops_lib, # noqa: F401 ensures torch.ops.executorch_devtools.tap is registered
)
from executorch.devtools.intermediate_output_tap._convenience import (
compare_aot_runtime_dataframe,
format_tap_dataframe,
specs_to_dataframe,
tap_all_and_run,
)
from executorch.devtools.intermediate_output_tap._reducers import (
ABS_MAX_ONLY,
DEFAULT_STATS,
FULL_TENSOR,
get_reducer,
MIN_MAX_MEAN,
StatReducer,
)
from executorch.devtools.intermediate_output_tap._selectors import (
NodeSelector,
select_all,
select_all_call_function,
select_any,
select_by_meta_tag,
select_by_module_path,
select_by_op_type,
select_not,
)
from executorch.devtools.intermediate_output_tap._spec import TapSpec
from executorch.devtools.intermediate_output_tap._strip_pass import strip_taps_
from executorch.devtools.intermediate_output_tap._tap_pass import (
find_tap_nodes,
is_tap_node,
tap_intermediate_outputs,
)


__all__ = [
# Core API
"tap_intermediate_outputs",
"strip_taps_",
"TapSpec",
# Convenience
"tap_all_and_run",
"specs_to_dataframe",
"format_tap_dataframe",
"compare_aot_runtime_dataframe",
# Reducers
"StatReducer",
"FULL_TENSOR",
"ABS_MAX_ONLY",
"MIN_MAX_MEAN",
"DEFAULT_STATS",
"get_reducer",
# Selectors
"NodeSelector",
"select_all_call_function",
"select_by_op_type",
"select_by_module_path",
"select_by_meta_tag",
"select_any",
"select_all",
"select_not",
# Helpers
"find_tap_nodes",
"is_tap_node",
]
Loading
Loading