diff --git a/backends/qualcomm/debugger/README.md b/backends/qualcomm/debugger/README.md index fb8f9a1c662..7456d7ca91a 100644 --- a/backends/qualcomm/debugger/README.md +++ b/backends/qualcomm/debugger/README.md @@ -90,6 +90,38 @@ Note: Files ending with `.bin ` do not support graph visualization in qairt_visu For more details, visit the [QAIRT Visualizer](https://pypi.org/project/qairt-visualizer/). +# Observatory + +A new, review-focused Observatory implementation is available under: + +`backends/qualcomm/debugger/observatory` + +Use the Observatory CLI to wrap any Qualcomm AOT export script. Use `--lens_recipe=accuracy` +to enable accuracy lenses. + +```bash +python -m executorch.backends.qualcomm.debugger.observatory \ + --output-html obs_report.html \ + --lens_recipe=accuracy \ + {original script and args} +``` + +For example: + +```bash +python -m executorch.backends.qualcomm.debugger.observatory \ + --output-html obs_report.html \ + --lens_recipe=accuracy \ + examples/qualcomm/oss_scripts/mobilevit_v2.py \ + --backend htp --model SM8650 -d ./imagenet-mini-val/ -b build-android/ --compile_only +``` + +> **Note**: Qualcomm example scripts (e.g. `oss_scripts/roberta.py`) use only absolute imports +> and are run as plain scripts. The Observatory CLI auto-selects `runpy.run_path` for these since +> their directories do not contain `__init__.py`. + +See `backends/qualcomm/debugger/observatory/README.md` for full documentation. + # ExecuTorch QNN Intermediate Output Debugger ExecuTorch QNN Intermediate Output Debugger is a tool that helps users debug intermediate output accuracy by comparing CPU outputs with QNN outputs. This tool offers a variety of output formats and flexibility for users to define their own metrics when debugging. diff --git a/backends/qualcomm/debugger/observatory/README.md b/backends/qualcomm/debugger/observatory/README.md new file mode 100644 index 00000000000..af785af5638 --- /dev/null +++ b/backends/qualcomm/debugger/observatory/README.md @@ -0,0 +1,146 @@ +# Qualcomm Observatory CLI + +Qualcomm-specific Observatory CLI that wraps `devtools/observatory` with QNN backend patches and +accuracy lenses. Requires a QNN SDK environment (source `$QNN_SDK_ROOT/bin/envsetup.sh` before +running on-device jobs). + +## Usage + +### Collection mode (default) + +```bash +python -m executorch.backends.qualcomm.debugger.observatory \ + [--output-html PATH] [--output-json PATH] SCRIPT [SCRIPT_ARGS...] +``` + +### With accuracy debugging + +```bash +python -m executorch.backends.qualcomm.debugger.observatory \ + --lens_recipe=accuracy \ + [--output-html PATH] [--output-json PATH] \ + SCRIPT [SCRIPT_ARGS...] +``` + +### Visualize mode (JSON → HTML, no re-execution) + +```bash +python -m executorch.backends.qualcomm.debugger.observatory visualize \ + --input-json report.json --output-html report.html +``` + +## Qualcomm examples + +Qualcomm example scripts use only absolute imports and live in directories without `__init__.py`, +so the Observatory CLI runs them as plain scripts via `runpy.run_path` (no special invocation +needed). + +### Vision model (ImageNet) + +```bash +source $QNN_SDK_ROOT/bin/envsetup.sh + +python -m executorch.backends.qualcomm.debugger.observatory \ + --output-html /tmp/obs_vit/report.html \ + --output-json /tmp/obs_vit/report.json \ + --lens_recipe=accuracy \ + examples/qualcomm/scripts/torchvision_vit.py \ + -m SM8650 -b ./build-android \ + --dataset imagenet-mini-val/ \ + -H mlgtw-linux -s \ + -a /tmp/obs_vit --seed 1126 --compile_only +``` + +### NLP model (Wikipedia sentences) + +```bash +python -m executorch.backends.qualcomm.debugger.observatory \ + --output-html /tmp/obs_roberta/report.html \ + --lens_recipe=accuracy \ + examples/qualcomm/oss_scripts/roberta.py \ + -m SM8650 -b ./build-android \ + -H mlgtw-linux -s \ + -a /tmp/obs_roberta --compile_only +``` + +### Compile-only (no device required) + +Add `--compile_only` to any Qualcomm script to export and lower without pushing to device. +This is useful for inspecting the compilation pipeline in CI or on a dev machine. + +## Available example scripts + +### `examples/qualcomm/scripts/` — vision models + +| Script | Model | +|---|---| +| `torchvision_vit.py` | Vision Transformer | +| `mobilenet_v2.py` | MobileNetV2 | +| `mobilenet_v3.py` | MobileNetV3 | +| `inception_v3.py` | InceptionV3 | +| `inception_v4.py` | InceptionV4 | + +Dataset: ImageNet (pass with `--dataset ` or `-d `). + +### `examples/qualcomm/oss_scripts/` — NLP/open-source models + +| Script | Model | +|---|---| +| `roberta.py` | RoBERTa | +| `bert.py` | BERT | +| `albert.py` | ALBERT | +| `distilbert.py` | DistilBERT | +| `eurobert.py` | EuroBERT | + +Dataset: Wikipedia sentences (`wikisent2.txt`). Pass with `-d `. + +Common flags: `-m ` (e.g. `SM8650`), `-b `, `-H `, +`-s `, `-a `, `--compile_only`. + +## Accuracy lenses (`--lens_recipe=accuracy`) + +Registers `AccuracyLens` and `PerLayerAccuracyLens` (with QNN dataset patches) on top of the +default `PipelineGraphCollectorLens`. These produce: + +- Per-stage accuracy metrics (PSNR, cosine similarity, MSE, top-k) +- Per-layer accuracy heat-map overlaid on the graph +- Cross-stage diff labels in the left panel of the HTML report + +QNN dataset patches (`lenses/qnn_dataset_patches.py`) wire the on-device inference output back +into the accuracy lens so metrics reflect true QNN outputs, not emulated CPU results. + +## Two-step workflow + +Collect on-device in CI, visualize locally without re-running: + +```bash +# Step 1 — collect (e.g., in CI with device attached) +python -m executorch.backends.qualcomm.debugger.observatory \ + --output-html /tmp/obs/report.html \ + --output-json /tmp/obs/report.json \ + examples/qualcomm/scripts/torchvision_vit.py \ + -m SM8650 -b ./build-android -d imagenet-mini-val/ \ + -H mlgtw-linux -s -a /tmp/obs + +# Step 2 — re-generate HTML from JSON (e.g., locally after lens update) +python -m executorch.backends.qualcomm.debugger.observatory visualize \ + --input-json /tmp/obs/report.json \ + --output-html /tmp/obs/report_v2.html +``` + +## Backend patches + +`lenses/qnn_patches.py` installs a monkey-patch on `ptq_calibrate` so the +`PipelineGraphCollectorLens` can intercept the QNN quantization calibration stage and capture +the graph at that point. The patch is active only while the Observatory context is open. + +`lenses/qnn_dataset_patches.py` wires on-device inference results into `AccuracyLens` so that +accuracy metrics use real QNN outputs. + +## See also + +- `backends/qualcomm/debugger/README.md` — broader Qualcomm debugger overview (QAIRT visualizer, + intermediate output debugger) +- `devtools/observatory/README.md` — framework overview, Python API, custom lens guide +- `devtools/observatory/USAGE.md` — full CLI reference +- `devtools/observatory/lenses/LENSES.md` — built-in lens details diff --git a/backends/qualcomm/debugger/observatory/__init__.py b/backends/qualcomm/debugger/observatory/__init__.py new file mode 100644 index 00000000000..b5f86874fd4 --- /dev/null +++ b/backends/qualcomm/debugger/observatory/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/backends/qualcomm/debugger/observatory/__main__.py b/backends/qualcomm/debugger/observatory/__main__.py new file mode 100644 index 00000000000..e69e40ec80e --- /dev/null +++ b/backends/qualcomm/debugger/observatory/__main__.py @@ -0,0 +1,9 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from .cli import main + +main() diff --git a/backends/qualcomm/debugger/observatory/cli.py b/backends/qualcomm/debugger/observatory/cli.py new file mode 100644 index 00000000000..c05485a0430 --- /dev/null +++ b/backends/qualcomm/debugger/observatory/cli.py @@ -0,0 +1,81 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Qualcomm Observatory CLI -- QNN-specific lens configuration. + +Collection mode (default): + python -m executorch.backends.qualcomm.debugger.observatory \\ + [--output-html PATH] [--output-json PATH] SCRIPT [SCRIPT_ARGS...] + +With accuracy debugging: + python -m executorch.backends.qualcomm.debugger.observatory \\ + --lens_recipe=accuracy SCRIPT [SCRIPT_ARGS...] + +Visualize mode (JSON -> HTML): + python -m executorch.backends.qualcomm.debugger.observatory visualize \\ + --input-json report.json --output-html report.html +""" + +from __future__ import annotations + +import sys + + +def main(): + from executorch.devtools.observatory.cli import ( + make_collect_parser, + make_visualize_parser, + run_observatory, + run_visualize, + ) + + if len(sys.argv) > 1 and sys.argv[1] == "visualize": + parser = make_visualize_parser() + args = parser.parse_args(sys.argv[2:]) + run_visualize(args.input_json, args.output_html) + return + + parser = make_collect_parser( + prog="python -m executorch.backends.qualcomm.debugger.observatory" + ) + parser.add_argument( + "--lens_recipe", + choices=["accuracy"], + default=None, + help="Lens recipe to enable (e.g. accuracy)", + ) + args = parser.parse_args(sys.argv[1:]) + + from executorch.devtools.observatory.observatory import Observatory + from executorch.devtools.observatory.lenses.pipeline_graph_collector import ( + PipelineGraphCollectorLens, + ) + from .lenses.qnn_patches import install_qnn_patches + + Observatory.clear() + PipelineGraphCollectorLens.register_backend_patches(install_qnn_patches) + Observatory.register_lens(PipelineGraphCollectorLens) + + if args.lens_recipe == "accuracy": + from executorch.devtools.observatory.lenses.accuracy import AccuracyLens + from .lenses.qnn_dataset_patches import install_qnn_dataset_patches + + AccuracyLens.register_dataset_patches(install_qnn_dataset_patches) + Observatory.register_lens(AccuracyLens) + + from executorch.devtools.observatory.lenses.per_layer_accuracy import ( + PerLayerAccuracyLens, + ) + + Observatory.register_lens(PerLayerAccuracyLens) + + run_observatory( + args.script, args.script_args, Observatory, args.output_html, args.output_json + ) + + +if __name__ == "__main__": + main() diff --git a/backends/qualcomm/debugger/observatory/lenses/__init__.py b/backends/qualcomm/debugger/observatory/lenses/__init__.py new file mode 100644 index 00000000000..b5f86874fd4 --- /dev/null +++ b/backends/qualcomm/debugger/observatory/lenses/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/backends/qualcomm/debugger/observatory/lenses/qnn_dataset_patches.py b/backends/qualcomm/debugger/observatory/lenses/qnn_dataset_patches.py new file mode 100644 index 00000000000..aa67427581d --- /dev/null +++ b/backends/qualcomm/debugger/observatory/lenses/qnn_dataset_patches.py @@ -0,0 +1,75 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""QNN dataset patches for AccuracyLens. + +Installs monkey-patches on executorch.examples.qualcomm.utils dataset functions +to capture targets and task type for accuracy evaluation. +""" + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from executorch.devtools.observatory.lenses.accuracy import AccuracyLens + + +def install_qnn_dataset_patches(cls: type[AccuracyLens]) -> None: + """Install QNN dataset capture patches on AccuracyLens.""" + try: + import executorch.examples.qualcomm.utils as utils_module + + if hasattr(utils_module, "get_imagenet_dataset"): + original = utils_module.get_imagenet_dataset + cls._originals["get_imagenet_dataset"] = original + + def patched_imagenet(*args, **kwargs): + inputs, targets = original(*args, **kwargs) + cls._captured_targets = targets + cls._task_type = "classification" + logging.info( + "[AccuracyLens] Captured ImageNet targets (%d samples)", + len(targets), + ) + return inputs, targets + + utils_module.get_imagenet_dataset = patched_imagenet + logging.info("[AccuracyLens] Installed patch: get_imagenet_dataset") + + if hasattr(utils_module, "get_masked_language_model_dataset"): + original_mlm = utils_module.get_masked_language_model_dataset + cls._originals["get_masked_language_model_dataset"] = original_mlm + + def patched_mlm(*args, **kwargs): + inputs, targets = original_mlm(*args, **kwargs) + cls._captured_targets = targets + cls._task_type = "mlm" + logging.info( + "[AccuracyLens] Captured MLM targets (%d samples)", + len(targets), + ) + return inputs, targets + + utils_module.get_masked_language_model_dataset = patched_mlm + logging.info( + "[AccuracyLens] Installed patch: get_masked_language_model_dataset" + ) + + def _uninstall(): + try: + for key, orig in cls._originals.items(): + if hasattr(utils_module, key): + setattr(utils_module, key, orig) + except Exception: + pass + + cls._dataset_uninstallers.append(_uninstall) + except ImportError: + logging.debug( + "[AccuracyLens] qualcomm utils not available, skipping dataset patches" + ) diff --git a/backends/qualcomm/debugger/observatory/lenses/qnn_patches.py b/backends/qualcomm/debugger/observatory/lenses/qnn_patches.py new file mode 100644 index 00000000000..3de4a56ae73 --- /dev/null +++ b/backends/qualcomm/debugger/observatory/lenses/qnn_patches.py @@ -0,0 +1,76 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""QNN backend patches for PipelineGraphCollectorLens. + +Installs a monkey-patch on executorch.examples.qualcomm.utils.ptq_calibrate +to capture the float ExportedProgram with from_node metadata populated. +""" + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from executorch.devtools.observatory.lenses.pipeline_graph_collector import ( + PipelineGraphCollectorLens, + ) + + +def install_qnn_patches(cls: type[PipelineGraphCollectorLens]) -> None: + """Install QNN ptq_calibrate patch on the PipelineGraphCollectorLens.""" + try: + import executorch.backends.qualcomm.export_utils as qnn_utils_module + + original = qnn_utils_module._ptq_calibrate + cls._originals["qnn._ptq_calibrate"] = original + + def patched_ptq_calibrate(captured_model, quantizer, dataset): + cls._set_accuracy_fallback_dataset( + dataset, source="qnn.ptq_calibrate" + ) + + collect_target = captured_model + try: + sample = ( + cls._last_calibration_dataset[0] + if cls._last_calibration_dataset + else None + ) + if sample is not None: + import torch + + ep = torch.export.export(captured_model, sample, strict=False) + collect_target = ep.run_decompositions({}) + except Exception as exc: + logging.debug( + "[PipelineGraphCollector] from_node re-export skipped: %s", exc + ) + + try: + cls._collect_fn("Exported Float", collect_target) + except Exception as exc: + logging.debug( + "[PipelineGraphCollector] collect skipped (Exported Float): %s", + exc, + ) + return original(captured_model, quantizer, dataset) + + qnn_utils_module._ptq_calibrate = patched_ptq_calibrate + logging.info("[PipelineGraphCollector] Installed QNN patch: _ptq_calibrate") + + def _uninstall(): + try: + qnn_utils_module._ptq_calibrate = original + except Exception: + pass + + cls._backend_uninstallers.append(_uninstall) + except Exception as exc: + logging.warning( + "[PipelineGraphCollector] Failed to patch QNN ptq_calibrate: %s", exc + ) diff --git a/backends/xnnpack/debugger/__init__.py b/backends/xnnpack/debugger/__init__.py new file mode 100644 index 00000000000..2e41cd717f6 --- /dev/null +++ b/backends/xnnpack/debugger/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/backends/xnnpack/debugger/observatory/README.md b/backends/xnnpack/debugger/observatory/README.md new file mode 100644 index 00000000000..cd5a203656a --- /dev/null +++ b/backends/xnnpack/debugger/observatory/README.md @@ -0,0 +1,119 @@ +# XNNPack Observatory CLI + +XNNPack-specific Observatory CLI that wraps `devtools/observatory` with XNNPack backend patches and accuracy lenses. + +## Usage + +### Collection mode (default) + +```bash +python -m executorch.backends.xnnpack.debugger.observatory \ + [--output-html PATH] [--output-json PATH] SCRIPT [SCRIPT_ARGS...] +``` + +### With accuracy debugging + +```bash +python -m executorch.backends.xnnpack.debugger.observatory \ + --lense_recipe=accuracy \ + [--output-html PATH] [--output-json PATH] \ + SCRIPT [SCRIPT_ARGS...] +``` + +### Visualize mode (JSON → HTML, no re-execution) + +```bash +python -m executorch.backends.xnnpack.debugger.observatory visualize \ + --input-json report.json --output-html report.html +``` + +## XNNPack examples + +`examples/xnnpack/aot_compiler.py` uses relative imports (`from . import MODEL_NAME_TO_OPTIONS`, +`from ..models import ...`) and must be executed as a Python module. The CLI handles this +automatically: when the supplied path ends in `.py` and its directory contains `__init__.py`, it +uses `runpy.run_module` instead of `runpy.run_path`. + +### File path (auto-detected as module) + +```bash +python -m executorch.backends.xnnpack.debugger.observatory \ + --output-html /tmp/mv2/report.html \ + --lense_recipe=accuracy \ + examples/xnnpack/aot_compiler.py \ + --model_name=mv2 --delegate --quantize --output_dir /tmp/mv2 +``` + +### Dotted module name (explicit) + +```bash +python -m executorch.backends.xnnpack.debugger.observatory \ + --output-html /tmp/mv2/report.html \ + --lense_recipe=accuracy \ + examples.xnnpack.aot_compiler \ + --model_name=mv2 --delegate --quantize --output_dir /tmp/mv2 +``` + +### Available model names + +Pass `--model_name` with any of the models defined in `examples/xnnpack/__init__.py`: + +| Model | Notes | +|---|---| +| `mv2` | MobileNetV2 — fast, quantizable | +| `mv3` | MobileNetV3 | +| `resnet18` | ResNet-18 | +| `resnet50` | ResNet-50 | +| `vit` | Vision Transformer | +| `ic3` | InceptionV3 | +| `ic4` | InceptionV4 | +| `dl3` | DeepLabV3 | +| `edsr` | Super-resolution | +| `mobilebert` | MobileBERT | +| `w2l` | Wav2Letter | +| `linear` | Linear baseline | +| `add` / `add_mul` | Arithmetic baselines | +| `llama2` | Llama 2 (requires HuggingFace token) | +| `emformer_join` / `emformer_transcribe` | Speech | + +Common flags: `--delegate` (XNNPACK delegation, on by default), `--quantize` (8-bit PTQ), +`--output_dir` (where the `.pte` is written). + +## Accuracy lenses (`--lense_recipe=accuracy`) + +Registers `AccuracyLens` and `PerLayerAccuracyLens` on top of the default +`PipelineGraphCollectorLens`. These produce: + +- Per-stage accuracy metrics (PSNR, cosine similarity, MSE, top-k) +- Per-layer accuracy heat-map overlaid on the graph +- Cross-stage diff labels in the left panel of the HTML report + +## Two-step workflow + +Collect in one environment, visualize in another: + +```bash +# Step 1 — collect +python -m executorch.backends.xnnpack.debugger.observatory \ + --output-html /tmp/mv2/report.html \ + --output-json /tmp/mv2/report.json \ + examples/xnnpack/aot_compiler.py \ + --model_name=mv2 --delegate --quantize --output_dir /tmp/mv2 + +# Step 2 — re-generate HTML from JSON (e.g., after lens code update) +python -m executorch.backends.xnnpack.debugger.observatory visualize \ + --input-json /tmp/mv2/report.json \ + --output-html /tmp/mv2/report_v2.html +``` + +## Backend patches + +`lenses/xnnpack_patches.py` installs XNNPack-specific monkey-patches so the +`PipelineGraphCollectorLens` can intercept XNNPack-specific lowering steps. These patches are +active only while the Observatory context is open and are removed when it closes. + +## See also + +- `devtools/observatory/README.md` — framework overview, Python API, custom lens guide +- `devtools/observatory/USAGE.md` — full CLI reference +- `devtools/observatory/lenses/LENSES.md` — built-in lens details diff --git a/backends/xnnpack/debugger/observatory/__init__.py b/backends/xnnpack/debugger/observatory/__init__.py new file mode 100644 index 00000000000..2e41cd717f6 --- /dev/null +++ b/backends/xnnpack/debugger/observatory/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/backends/xnnpack/debugger/observatory/__main__.py b/backends/xnnpack/debugger/observatory/__main__.py new file mode 100644 index 00000000000..487ded86e42 --- /dev/null +++ b/backends/xnnpack/debugger/observatory/__main__.py @@ -0,0 +1,9 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from .cli import main + +main() diff --git a/backends/xnnpack/debugger/observatory/cli.py b/backends/xnnpack/debugger/observatory/cli.py new file mode 100644 index 00000000000..483e96f55dd --- /dev/null +++ b/backends/xnnpack/debugger/observatory/cli.py @@ -0,0 +1,79 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""XNNPACK Observatory CLI -- XNNPACK-specific lens configuration. + +Collection mode (default): + python -m executorch.backends.xnnpack.debugger.observatory \\ + [--output-html PATH] [--output-json PATH] SCRIPT [SCRIPT_ARGS...] + +With accuracy debugging: + python -m executorch.backends.xnnpack.debugger.observatory \\ + --lens_recipe=accuracy SCRIPT [SCRIPT_ARGS...] + +Visualize mode (JSON -> HTML): + python -m executorch.backends.xnnpack.debugger.observatory visualize \\ + --input-json report.json --output-html report.html +""" + +from __future__ import annotations + +import sys + + +def main(): + from executorch.devtools.observatory.cli import ( + make_collect_parser, + make_visualize_parser, + run_observatory, + run_visualize, + ) + + if len(sys.argv) > 1 and sys.argv[1] == "visualize": + parser = make_visualize_parser() + args = parser.parse_args(sys.argv[2:]) + run_visualize(args.input_json, args.output_html) + return + + parser = make_collect_parser( + prog="python -m executorch.backends.xnnpack.debugger.observatory" + ) + parser.add_argument( + "--lens_recipe", + choices=["accuracy"], + default=None, + help="Lens recipe to enable (e.g. accuracy)", + ) + args = parser.parse_args(sys.argv[1:]) + + from executorch.devtools.observatory.observatory import Observatory + from executorch.devtools.observatory.lenses.pipeline_graph_collector import ( + PipelineGraphCollectorLens, + ) + from .lenses.xnnpack_patches import install_xnnpack_patches + + Observatory.clear() + PipelineGraphCollectorLens.register_backend_patches(install_xnnpack_patches) + Observatory.register_lens(PipelineGraphCollectorLens) + + if args.lens_recipe == "accuracy": + from executorch.devtools.observatory.lenses.accuracy import AccuracyLens + + Observatory.register_lens(AccuracyLens) + + from executorch.devtools.observatory.lenses.per_layer_accuracy import ( + PerLayerAccuracyLens, + ) + + Observatory.register_lens(PerLayerAccuracyLens) + + run_observatory( + args.script, args.script_args, Observatory, args.output_html, args.output_json + ) + + +if __name__ == "__main__": + main() diff --git a/backends/xnnpack/debugger/observatory/lenses/__init__.py b/backends/xnnpack/debugger/observatory/lenses/__init__.py new file mode 100644 index 00000000000..2e41cd717f6 --- /dev/null +++ b/backends/xnnpack/debugger/observatory/lenses/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/backends/xnnpack/debugger/observatory/lenses/xnnpack_patches.py b/backends/xnnpack/debugger/observatory/lenses/xnnpack_patches.py new file mode 100644 index 00000000000..e3ea2554338 --- /dev/null +++ b/backends/xnnpack/debugger/observatory/lenses/xnnpack_patches.py @@ -0,0 +1,130 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""XNNPACK backend patches for PipelineGraphCollectorLens. + +Installs a monkey-patch on XNNPACK quantization helpers (both +``examples`` and ``executorch.examples`` import paths) to capture the +float ExportedProgram with from_node metadata populated. +""" + +from __future__ import annotations + +import importlib +import logging +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from executorch.devtools.observatory.lenses.pipeline_graph_collector import ( + PipelineGraphCollectorLens, + ) + + +MODULE_CANDIDATES = ( + "examples.xnnpack.quantization.utils", + "executorch.examples.xnnpack.quantization.utils", +) + + +def _install_patch_for_module( + cls: type[PipelineGraphCollectorLens], module, alias: str +) -> bool: + try: + original = module.quantize + except AttributeError: + logging.debug( + "[PipelineGraphCollector] XNNPACK patch skipped; no quantize in %s", + alias, + ) + return False + + key = f"xnnpack.quantize[{alias}]" + if key in cls._originals: + return True + + cls._originals[key] = original + + def patched_quantize(model, example_inputs, quant_type=None): + sample = None + try: + if isinstance(example_inputs, (tuple, list)): + sample = tuple(example_inputs) + else: + sample = (example_inputs,) + cls._set_accuracy_fallback_dataset([sample], source=key) + except Exception: + pass + + collect_target = model + try: + import torch + + if sample is not None: + ep = torch.export.export(model, sample, strict=False) + collect_target = ep.run_decompositions({}) + except Exception as exc: + logging.debug( + "[PipelineGraphCollector] XNNPACK from_node re-export skipped: %s", + exc, + ) + + try: + cls._collect_fn("Exported Float", collect_target) + except Exception as exc: + logging.debug( + "[PipelineGraphCollector] collect skipped (Exported Float): %s", + exc, + ) + + if quant_type is None: + return original(model, example_inputs) + return original(model, example_inputs, quant_type) + + module.quantize = patched_quantize + logging.info( + "[PipelineGraphCollector] Installed XNNPACK patch: quantize (%s)", alias + ) + + def _uninstall(): + try: + module.quantize = original + except Exception: + pass + + cls._backend_uninstallers.append(_uninstall) + return True + + +def install_xnnpack_patches(cls: type[PipelineGraphCollectorLens]) -> None: + """Install XNNPACK quantize patch on the PipelineGraphCollectorLens.""" + + patched = False + seen_modules: set[int] = set() + + for alias in MODULE_CANDIDATES: + try: + module = importlib.import_module(alias) + except ImportError: + continue + + module_id = id(module) + if module_id in seen_modules: + continue + seen_modules.add(module_id) + + try: + patched |= _install_patch_for_module(cls, module, alias) + except Exception as exc: + logging.warning( + "[PipelineGraphCollector] Failed to patch XNNPACK quantize (%s): %s", + alias, + exc, + ) + + if not patched: + logging.warning( + "[PipelineGraphCollector] Failed to patch XNNPACK quantize: no candidate module found" + ) diff --git a/devtools/fx_viewer/API_IMPLEMENTATION_STATUS.md b/devtools/fx_viewer/API_IMPLEMENTATION_STATUS.md new file mode 100644 index 00000000000..28b8e56e8f1 --- /dev/null +++ b/devtools/fx_viewer/API_IMPLEMENTATION_STATUS.md @@ -0,0 +1,60 @@ +# RFC API Implementation Status + +Date: 2026-03-13 +Scope: `backends/qualcomm/utils/fx_viewer/templates/*` +Reference: `backends/qualcomm/utils/fx_viewer/RFC_FX_VIEWER_API_INTERFACE.md` + +## Summary + +Most RFC APIs are implemented in the current JS runtime. +Remaining gaps are minor and documented below. + +## Implemented + +1. Construction and presets +- `FXGraphViewer.create(config)` +- Presets: `split`, `compact`, `headless`, `custom` +- Slot precedence and layout merge behavior + +2. Canonical state API +- `getState` +- `setState` +- `replaceState` (state replacement with camera/search handling) +- `batch` + +3. Convenience APIs +- `setTheme`, `setLayers`, `setColorBy` +- `selectNode`, `clearSelection`, `search`, `zoomToFit`, `panToNode`, `animateToNode` +- `setUIVisibility`, `setLayout` +- `enterFullscreen`, `exitFullscreen`, `destroy` + +4. Runtime layer mutation APIs +- `upsertLayer`, `removeLayer`, `patchLayerNodes`, `setLayerLabel`, `setColorRule` + +5. Events +- `statechange`, `selectionchange`, `themechange`, `layoutchange`, `error` +- `on`/`off` subscription model + +6. Compare API +- `FXGraphCompare.create` +- `setColumns` (applies to optional compare container) +- `setCompact`, `setSync`, `destroy` + +7. UI synchronization contract +- External state updates reflected in theme/layers/colorBy controls +- `syncControlsFromState()` in `UIManager` + +8. Fullscreen taskbar support +- Optional taskbar fullscreen button via `layout.fullscreen.button` / `ui.controls.fullscreenButton` + +## Partial / Follow-up + +1. Strict state schema validation +- RFC describes validation-rich state store; current implementation uses pragmatic checks and coercions. + +2. Theme registration depth +- `registerTheme` works; deeper token validation and compatibility checks are not yet strict. + +3. Compare camera/theme/layer sync +- Compare selection sync is implemented. +- Other sync dimensions are modeled in config but not fully propagated yet. diff --git a/devtools/fx_viewer/README.md b/devtools/fx_viewer/README.md new file mode 100644 index 00000000000..22090c9a6bb --- /dev/null +++ b/devtools/fx_viewer/README.md @@ -0,0 +1,307 @@ +# fx_viewer + +`fx_viewer` exports FX graphs to interactive HTML and provides an embeddable JavaScript runtime. + +## What It Provides + +Python side: +1. Extract FX graph (`torch.export` / `torch.fx`). +2. Compute layout (fast-sugiyama — Rust-backed Sugiyama). +3. Build payload (`base` + `extensions`). +4. Export JSON / JS snippet / standalone HTML. + +JS side: +1. Canvas graph + minimap + info panel + search. +2. Layer toggles and color-by controls. +3. State-driven API for embedding, compare mode, fullscreen, and runtime layer mutation. + +## Dependencies + +Layout computation is delegated to the external +[`fast-sugiyama`](https://github.com/austinorr/fast-sugiyama) package +(Rust-backed, drop-in replacement for the previously vendored `grandalf` +fork). Install with the `[all]` extra so that `rectangle-packer` is pulled +in — it is required to pack disconnected graph components into a single +non-overlapping layout: + +```bash +pip install 'fast-sugiyama[all]' +``` + +- Python ≥ 3.11 is required by `fast-sugiyama`. Users still on Python 3.10 + cannot use `FXGraphExporter`'s HTML/JSON output until they upgrade. +- The package is **not** declared in executorch's `pyproject.toml`: it is + imported lazily on the first layout call. If it is missing, the exporter + raises an `ImportError` with install instructions. + +## Quick Start + +From repo root: + +```bash +source .venv/bin/activate +python backends/qualcomm/utils/fx_viewer/examples/demo_fx_viewer_extensions.py --model both +``` + +Outputs: +1. `swin_graph_v3_extensions.html` +2. `llama_graph_v3_extensions.html` + +## Python API + +```python +from executorch.backends.qualcomm.utils.fx_viewer import ( + FXGraphExporter, + GraphExtension, + CategoricalColorRule, +) + +# Step 1 instantiate exporter +f = FXGraphExporter(graph_module) + +# Optional (define graph extension Layer) +ext = GraphExtension(id="backend", name="Backend Assignment") +ext.add_node_data("node_0", {"backend": "cpu"}) +ext.set_color_rule(CategoricalColorRule(attribute="backend")) +f.add_extension(ext) + +# Step 2 save standalone html +f.export_html("graph.html") +``` + +Main exporter methods: +1. [ ] `generate_json_payload()` +2. `export_json(path)` +3. `export_js(container_id)` +4. `export_html(path)` + +Python tutorial: +1. `backends/qualcomm/utils/fx_viewer/examples/PYTHON_API_TUTORIAL.md` + +## JS API (Runtime) + +Construction: +1. `FXGraphViewer.create(config)` +2. Compatibility constructor: `new FXGraphViewer(containerId, payload)` + +State/events: +1. `getState`, `setState`, `replaceState`, `batch` +2. `on`, `off` + +Viewer actions: +1. `setTheme`, `setLayers`, `setColorBy` +2. `selectNode`, `clearSelection`, `search` +3. `zoomToFit`, `panToNode`, `animateToNode` +4. `setUIVisibility`, `setLayout` +5. `enterFullscreen`, `exitFullscreen`, `destroy` + +Runtime layer mutation: +1. `upsertLayer`, `removeLayer`, `patchLayerNodes`, `setLayerLabel`, `setColorRule` + +Compare: +1. `FXGraphCompare.create({ viewers, layout, sync })` — `viewers` accepts `FXGraphViewer[]` or `Map` +2. `setSync(patch)`, `destroy()` +3. `setTiled()` / `setCompact()` — no-ops (tiled layout is always used in compare mode) +4. `sync.mode`: `'auto'` (default) | `'id'` | `'layer'` | `'none'` + - `'auto'`: tries `debug_handle` set-intersection first, falls back to node-ID match + - `'id'`: matches by node id only + - `'layer'`: matches by `extensions[layer].nodes[nodeId].info[field]`; set `sync.layer` and `sync.field` + - `'none'`: no sync +5. `layout.container`: CSS selector or `HTMLElement` — required + +Highlight groups (programmatic overlay, independent of selection): +1. `viewer.addHighlightGroup(groupId, nodeIds, color)` — add/replace a named group +2. `viewer.removeHighlightGroup(groupId)` — remove one group +3. `viewer.clearAllHighlightGroups()` — remove all groups +4. `viewer.getHighlightGroups()` — returns `Map` + +## Compare View Architecture + +`FXGraphCompare` owns the compare layout DOM entirely. It builds a structured shell inside `layout.container` and moves canvas/minimap elements out of each viewer's own wrapper into that shell. + +### DOM Structure + +``` +layout.container (user-supplied div) + .fx-compare-root (flex column, fills container — created by FXGraphCompare) + .fx-compare-grid (CSS grid: 160px sidebar + N×1fr graph columns, 3 rows) + .fx-compare-sidebar-cell (col 1, rows 1-2 — shared controls) + .fx-compare-minimap-cell (col i+2, row 1 — one per viewer) + viewer.minimapRenderer.container (moved here from viewer.sidebar) + .fx-compare-graph-name (graph title label, absolute overlay) + .fx-compare-canvas-cell (col i+2, row 2 — one per viewer) + viewer.mainArea (moved here from viewer.wrapper) + .fx-compare-info-row (col 1..-1, row 3 — CSS subgrid, merged info panel) + .fx-compare-sidebar-info-cell (col 1 — empty spacer) + .fx-compare-info-hdr (col i+2 — graph name header, one per visible viewer) + .fx-compare-info-prop (col 1 — property name, sticky left) + .fx-compare-info-val (col i+2 — property value, one per visible viewer) +``` + +Each viewer's own `.fx-viewer-wrapper` (sidebar, resizer, etc.) is hidden (`display: none`) while compare is active. The viewer's public API (`setTheme`, `selectNode`, `renderAll`, etc.) continues to work normally because it operates on `mainArea` and `minimapRenderer.container` regardless of where they are in the DOM. + +### Uniform Row Heights + +All minimap cells are the same fixed height (CSS `minmax(100px, 200px)`). All canvas cells share `minmax(50vh, 100vh)` in the same grid row, so they expand to fill identical space. Vertical boundaries are aligned across graphs because the cells are siblings in the same CSS grid — no per-column height negotiation needed. + +### Ownership and Lifecycle + +1. **`FXGraphCompare` owns the compare DOM.** It creates `.fx-compare-root`, `.fx-compare-grid`, `.fx-compare-sidebar-cell`, `.fx-compare-minimap-cell`, `.fx-compare-canvas-cell`, and `.fx-compare-info-row` elements and appends them to `layout.container`. +2. **Viewers own their renderers.** `FXGraphCompare` only moves `viewer.mainArea` and `viewer.minimapRenderer.container` — it does not touch canvas contexts, event listeners, or state machines. +3. **DOM snapshots for teardown.** Before moving any element, `FXGraphCompare` records its original parent and next sibling in a `WeakMap`. `destroy()` calls `_teardownCompareDOM()` which restores every element to its original position and un-hides each viewer wrapper. +4. **Canvas resize.** A `ResizeObserver` is attached to each `.fx-compare-canvas-cell`. When the cell resizes (window resize, column visibility change), it calls `viewer.canvasRenderer.resize()` + `viewer.renderAll()`. An initial `requestAnimationFrame` resize fires after `_buildCompareDOM()` to handle the first layout pass. + +### Interaction Control + +| Action | Owner | Mechanism | +|--------|-------|-----------| +| Node selection sync | `FXGraphCompare._wireSelectionSync()` | Listens to `viewer.on('selectionchange')`; propagates via `viewer.selectNode()` with source guard to prevent loops | +| Theme sync (state change) | `FXGraphCompare._wireStateSync()` | Listens to `viewer.on('statechange')`; propagates theme changes to other viewers; calls `_applyCompareTheme()` | +| Theme (compare shell) | `FXGraphCompare._applyCompareTheme()` | Sets CSS custom properties on `.fx-compare-root`; styles sidebar controls inline | +| Layers / ColorBy | Sidebar Layers button | Builds union of all extension ids; calls `viewer.setLayers()` / `viewer.setColorBy()` per viewer | +| Zoom to Fit | Sidebar Fit button | Calls `viewer.controller.zoomToFit()` on all viewers | +| Fullscreen | Sidebar Full button | Calls `requestFullscreen()` on `.fx-compare-root`; `fullscreenchange` listener updates button icon | +| Sync mode | Sidebar sync selector → `FXGraphCompare.setSync()` | Updates `this.sync`; next selection event uses new mode | +| Merged info panel | `FXGraphCompare._updateMergedInfo()` | Called after selection sync; renders a diff table into `.fx-compare-info-row` | + +### Selection Sync Modes + +| Mode | Sidebar label | Behavior | +|------|--------------|----------| +| `'auto'` (default) | Auto (handle→id) | `debug_handle` set-intersection first; falls back to node-ID match | +| `'id'` | ID only | Matches by node id; no-op if absent | +| `'layer'` | Ext: \.\ | Matches by extension field value; picks last in topo order on multiple matches | +| `'none'` | Don't sync | No propagation | + +`debug_handle` normalization: `int` → `{int}`, `int[]` → `Set(int[])`, `null/0/[]` → empty set. Two nodes match if their sets have a non-empty intersection. The **last in topological order** is selected on multiple matches. + +Three mapping patterns: +- **1-to-1**: same handle on both sides. +- **1-to-many** (decomposed ops): `linear` → `t + mm + add`, all share the same handle; last decomposed op is selected. +- **many-to-1** (fused ops): fused node carries union tuple handle `(h1, h2)`; any source node whose handle intersects `{h1, h2}` matches. + +### Merged Info Panel + +When a node is selected (and sync propagates), `_updateMergedInfo(nodeIdMap)` renders a comparison table into `.fx-compare-info-row`: +- Header row: "Property" | Graph 1 | Graph 2 | ... +- One row per property (union of all `node.info` keys across all selected nodes) +- Rows where values differ across graphs are highlighted amber (`.fx-diff`) +- Missing values shown as `—` + +## Canonical Data Contract + +Top-level payload: +1. `base`: `{ legend, nodes, edges }` +2. `extensions`: map keyed by extension id + +`base.nodes[]` fields: +1. `id`, `label`, `x`, `y`, `width`, `height` +2. `info`: metadata used by search/info panel +3. `tooltip`: base tooltip lines +4. `fill_color` (optional) + +`base.edges[]` fields: +1. `v`, `w` +2. `points` (optional routed polyline) + +## Extension Authoring Guide + +Key contract: +1. Add extension data explicitly with `add_node_data(node_id, data)`. +2. Formatter input is exactly that stored `data` dictionary. +3. Formatters must return `list[str]`. + +What formatters do not receive implicitly: +1. Full FX node object. +2. Base graph `info` fields. +3. Global graph context. + +If you need base attributes (for example `target`, `op`) in extension label/tooltip, +copy them into extension data before formatter use. + +### Sync key registration + +To expose an extension field as an explicit sync option in the compare sidebar: + +```python +ext.set_sync_key("debug_handle") +``` + +This makes `Ext: .debug_handle` appear as a selectable option in the compare sidebar. Selecting it activates `mode: 'layer'` with that extension and field. + +The `per_layer_accuracy` extension automatically registers `debug_handle` as a sync key when built via `_add_accuracy_extension`. + +## Color Rules + +Available rules: +1. `CategoricalColorRule(attribute, color_map=None)` +2. `NumericColorRule(attribute, cmap="viridis", handle_outliers=True)` + +Rule selection: +1. Use categorical for discrete semantic labels. +2. Use numeric for continuous measured metrics. +3. Keep `handle_outliers=True` for noisy distributions. +4. For rank/index-like metrics, set `handle_outliers=False`. + +## 3-Graph Compare Demo + +Standalone demo showing all three `debug_handle` mapping patterns in one compare view: + +```bash +python backends/qualcomm/utils/fx_viewer/examples/demo_3graph_compare.py +``` + +Output: `demo_3graph_compare.html` + +Three graphs: +1. **Reference (float)**: unique int handle per node. +2. **Decomposed (1→many)**: each `linear` → `t + mm + add`, all three share the same handle. +3. **Fused (many→1)**: `relu` nodes that follow a `linear` carry a union tuple handle `(linear_h, relu_h)`. + +Expected sync behavior (mode `auto`, set intersection): +- Click `linear` (handle `{6}`) in Graph 1 → Graph 2: `add_tensor` (last of `{t,mm,add}`). Graph 3: `relu` (handle `{6,7}`). +- Click `relu` (handle `{7}`) in Graph 1 → Graph 2: `relu_default`. Graph 3: `relu` (handle `{6,7}`). +- Click `relu` (handle `{6,7}`) in Graph 3 → Graph 1: `relu` (last among `{linear,relu}`). Graph 2: `relu_default`. + +## Unified API Harness + +Files: +1. Generator: `backends/qualcomm/utils/fx_viewer/examples/generate_api_test_harness.py` +2. Template: `backends/qualcomm/utils/fx_viewer/examples/harness_template.html` +3. Testcases: `backends/qualcomm/utils/fx_viewer/examples/harness_testcases.py` +4. Tutorial testcase guide: `backends/qualcomm/utils/fx_viewer/examples/FX_VIEWER_API_TESTCASES.md` + +Generate harnesses: + +```bash +source .venv/bin/activate +export PYTHONPATH=~/:$PYTHONPATH +python backends/qualcomm/utils/fx_viewer/examples/generate_api_test_harness.py +``` + +Generated outputs: +1. `fx_viewer_api_test_harness_portable.html` +2. `fx_viewer_api_test_harness_qualcomm.html` + +Suggested learning order: +1. JS beginner ladder (`js_01` ... `js_08` in testcase guide). +2. Advanced combos (`adv_01` ... `adv_04`). +3. Final mixed demo (`js_99_combo_mixed`). + +## Testing + +Contract tests: +1. `tests/test_exporter_contract.py` + +Run: + +```bash +source .venv/bin/activate +pytest -q tests/test_exporter_contract.py +``` + +## References + +1. API RFC: `backends/qualcomm/utils/fx_viewer/RFC_FX_VIEWER_API_INTERFACE.md` +2. Implementation status: `backends/qualcomm/utils/fx_viewer/RFC_API_IMPLEMENTATION_STATUS.md` +3. JS runtime internals: `backends/qualcomm/utils/fx_viewer/templates/README.md` diff --git a/devtools/fx_viewer/__init__.py b/devtools/fx_viewer/__init__.py new file mode 100644 index 00000000000..15b2f133e67 --- /dev/null +++ b/devtools/fx_viewer/__init__.py @@ -0,0 +1,25 @@ +from .color_rules import ColorRule, CategoricalColorRule, NumericColorRule +from .exporter import FXGraphExporter +from .extension import GraphExtension +from .models import ( + BaseGraphPayload, + GraphEdge, + GraphExtensionNodePayload, + GraphExtensionPayload, + GraphNode, + GraphPayload, +) + +__all__ = [ + "FXGraphExporter", + "GraphExtension", + "ColorRule", + "CategoricalColorRule", + "NumericColorRule", + "GraphNode", + "GraphEdge", + "BaseGraphPayload", + "GraphExtensionNodePayload", + "GraphExtensionPayload", + "GraphPayload", +] diff --git a/devtools/fx_viewer/color_rules.py b/devtools/fx_viewer/color_rules.py new file mode 100644 index 00000000000..5f79197dd64 --- /dev/null +++ b/devtools/fx_viewer/color_rules.py @@ -0,0 +1,152 @@ +"""Color rules for mapping node attributes to display colors.""" + +import hashlib +import colorsys + +class ColorRule: + """Base class for node->color mapping.""" + def __init__(self, attribute: str): + self.attribute = attribute + + def apply(self, nodes_data: dict) -> tuple[dict, list]: + """ + Takes a dictionary mapping node_id -> node_info_dict. + Returns: + - node_colors: Dict[str, str] mapping node_id -> hex color. + - legend: List[Dict[str, str]] containing legend items {"label": ..., "color": ...}. + """ + raise NotImplementedError + +class CategoricalColorRule(ColorRule): + """Assign deterministic colors to string/categorical values.""" + def __init__(self, attribute: str, color_map=None): + super().__init__(attribute) + self.color_map = color_map or {} + + def apply(self, nodes_data: dict) -> tuple[dict, list]: + node_colors = {} + unique_values = set() + + for node_id, data in nodes_data.items(): + if self.attribute not in data: + continue + + val = data[self.attribute] + if val is None: + continue + + val_str = str(val) + unique_values.add(val_str) + + if val_str in self.color_map: + node_colors[node_id] = self.color_map[val_str] + else: + # Consistent hashing to a hue value in HSV space + hash_val = int(hashlib.md5(val_str.encode('utf-8')).hexdigest(), 16) + hue = (hash_val % 360) / 360.0 + saturation = 0.65 + value_hsv = 0.85 + + r, g, b = colorsys.hsv_to_rgb(hue, saturation, value_hsv) + r, g, b = int(r * 255), int(g * 255), int(b * 255) + node_colors[node_id] = f"#{r:02x}{g:02x}{b:02x}" + + # Generate Legend + legend = [] + # First add explicit map entries + for k, v in self.color_map.items(): + if k in unique_values: + legend.append({"label": str(k), "color": v}) + unique_values.remove(k) + + # Then add hashed ones + for val_str in sorted(unique_values): + # Recalculate hash for the legend to avoid storing it twice + hash_val = int(hashlib.md5(val_str.encode('utf-8')).hexdigest(), 16) + hue = (hash_val % 360) / 360.0 + r, g, b = colorsys.hsv_to_rgb(hue, 0.65, 0.85) + hex_color = f"#{int(r * 255):02x}{int(g * 255):02x}{int(b * 255):02x}" + legend.append({"label": str(val_str), "color": hex_color}) + + return node_colors, legend + +class NumericColorRule(ColorRule): + """Assign gradient colors to numeric values.""" + def __init__(self, attribute: str, cmap="viridis", handle_outliers=True): + super().__init__(attribute) + self.cmap = cmap + self.handle_outliers = handle_outliers + + def _interpolate_color(self, ratio): + ratio = max(0.0, min(1.0, ratio)) + + if self.cmap.lower() == 'reds': + r = 255 + g = b = int(127 * (1 - ratio) + 128) + elif self.cmap.lower() == 'blues': + r = g = int(127 * (1 - ratio) + 128) + b = 255 + elif self.cmap.lower() == 'greens': + r = b = int(127 * (1 - ratio) + 128) + g = 255 + else: # viridis-like fallback + if ratio < 0.5: + r, g, b = 68 + 50, int(1 + ratio * 2 * 120)+ 50, int(34 + ratio * 2 * 50) + 50 + else: + r = int(68 + (ratio - 0.5) * 2 * 187) + g = int(171 + (ratio - 0.5) * 2 * 84) + b = int(134 - (ratio - 0.5) * 2 * 134) + + return f"#{r:02x}{g:02x}{b:02x}" + + def apply(self, nodes_data: dict) -> tuple[dict, list]: + # Pass 1: Collect valid values for fitting + valid_values = [] + for data in nodes_data.values(): + if self.attribute in data: + val = data[self.attribute] + if isinstance(val, (int, float)): + valid_values.append(val) + + if not valid_values: + return {}, [] + + # Fit bounds + if self.handle_outliers and len(valid_values) > 10: + valid_values.sort() + p5_idx = max(0, int(len(valid_values) * 0.05)) + p95_idx = min(len(valid_values) - 1, int(len(valid_values) * 0.95)) + _min = valid_values[p5_idx] + _max = valid_values[p95_idx] + else: + _min = min(valid_values) + _max = max(valid_values) + + if _min == _max: + _max = _min + 1e-9 + + # Pass 2: Calculate colors + node_colors = {} + for node_id, data in nodes_data.items(): + if self.attribute in data: + val = data[self.attribute] + if isinstance(val, (int, float)): + ratio = (val - _min) / (_max - _min) + node_colors[node_id] = self._interpolate_color(ratio) + + # Generate Legend + legend = [] + for i in range(5): + ratio = i / 4.0 + val = _min + ratio * (_max - _min) + color = self._interpolate_color(ratio) + + if abs(val) >= 1000 or (abs(val) < 0.01 and val != 0): + label_str = f"{val:.2e}" + elif isinstance(val, int) or float(val).is_integer(): + label_str = f"{int(val)}" + else: + label_str = f"{val:.2f}" + legend.append({"label": label_str, "color": color}) + + return node_colors, legend diff --git a/devtools/fx_viewer/examples/FX_VIEWER_API_TESTCASES.md b/devtools/fx_viewer/examples/FX_VIEWER_API_TESTCASES.md new file mode 100644 index 00000000000..ee74b0c8bc4 --- /dev/null +++ b/devtools/fx_viewer/examples/FX_VIEWER_API_TESTCASES.md @@ -0,0 +1,159 @@ +# FX Viewer API Harness: Tutorial Testcases + +This document is a learning guide for the unified harness. +Read top-to-bottom and run cases in order. + +## Harness Outputs + +1. `fx_viewer_api_test_harness_portable.html` +2. `fx_viewer_api_test_harness_qualcomm.html` + +Portable harness requires no Qualcomm SDK. +Qualcomm harness requires QAIRT/QNN environment. + +## Learning Order + +### Level 1: API Fundamentals + +1. `js_01_create_init_destroy` +- Purpose: learn viewer lifecycle. +- Target APIs: `FXGraphViewer.create`, `init`, `destroy`, `getState`. +- What to try: +1. Click `Create + Init`. +2. Click `Destroy`. +3. Repeat and observe state panel. + +2. `js_02_state_theme` +- Purpose: understand state-driven controls. +- Target APIs: `getState`, `setState`, `setTheme`. +- What to try: +1. Switch light/dark theme. +2. Toggle highlight mode. +3. Watch `getState()` snapshot update. + +3. `js_03_selection_camera` +- Purpose: learn navigation semantics. +- Target APIs: `selectNode`, `animateToNode`, `panToNode`, `zoomToFit`, `clearSelection`. +- What to try: +1. Use each button and compare motion behavior. +2. Confirm clear-selection resets visual focus. + +4. `js_04_layers_colorby` +- Purpose: separate "active layers" from "color source". +- Target APIs: `setLayers`, `setColorBy`. +- What to try: +1. Disable one layer and keep colorBy on the other. +2. Set colorBy to `base` and compare legend. + +5. `js_05_runtime_mutation` +- Purpose: mutate overlays at runtime. +- Target APIs: `upsertLayer`, `patchLayerNodes`, `setLayerLabel`, `setColorRule`, `removeLayer`. +- What to try: +1. Move threshold slider. +2. Rename layer. +3. Apply color rule function. +4. Remove layer and revert to base. + +6. `js_06_layout_slots` +- Purpose: embed UI components in external host divs. +- Target APIs: `mount.slots`, `setLayout`, `setUIVisibility`. +- What to try: +1. Toggle info/minimap visibility. +2. Hide/show toolbar chrome. +3. Inspect slot ownership behavior. + +7. `js_07_events` +- Purpose: subscribe to viewer events from host code. +- Target APIs: `on`, `off` with `statechange`, `selectionchange`, `themechange`, `layoutchange`. +- What to try: +1. Trigger events with buttons. +2. Click `Unsubscribe Events`. +3. Confirm log no longer updates. + +8. `js_08_compare_basics` +- Purpose: 3-graph compare with automatic `debug_handle` sync. +- Target APIs: `FXGraphCompare.create` (Map API), `setSync`, default `mode: 'auto'`. +- What to try: +1. Click a node in any graph — observe that the other two graphs sync to the matching node via `debug_handle` set intersection. +2. Open the sidebar sync selector — confirm it shows `Auto (handle→id)` as the default. +3. Switch to `ID only` and click a node — confirm sync still works (same graph, same node ids). +4. Switch to `Don't sync` — confirm no propagation. +5. If `per_layer_accuracy.debug_handle` appears in the selector, try it to see extension-field sync. + +### Level 2: Interesting Combinations + +9. `adv_01_accuracy_dynamic` +- Purpose: real per-layer accuracy workflow with host controls. +- Target APIs: `setTheme`, `patchLayerNodes`, `setColorBy`, `selectNode`. +- What to try: +1. Adjust severity percentile. +2. Focus highest-severity node. + +10. `adv_02_headless_slots_slider` +- Purpose: host-owned layout + slot embedding + dynamic recolor. +- Target APIs: `mount.slots`, `patchLayerNodes`, `setColorBy`. +- What to try: +1. Drag threshold slider. +2. Check info/minimap/legend in external panes. + +11. `adv_03_fullscreen_toolbar` +- Purpose: fullscreen via both toolbar and direct API. +- Target APIs: `layout.fullscreen.button`, `enterFullscreen`, `exitFullscreen`. +- What to try: +1. Enter/exit fullscreen from side buttons. +2. Use taskbar fullscreen toggle too. + +11b. `demo_3graph_compare` (standalone HTML, not in harness) +- Purpose: see all three `debug_handle` mapping patterns in one view. +- Run: `python backends/qualcomm/utils/fx_viewer/examples/demo_3graph_compare.py` +- What to try: +1. Click `linear` in Graph 1 — Graph 2 selects `add_tensor` (last of decomposed ops), Graph 3 selects `relu` (fused handle `{6,7}`). +2. Click `relu` in Graph 3 — Graph 1 selects `relu` (last among `{linear,relu}` intersecting `{6,7}`), Graph 2 selects `relu_default`. +3. Click "Highlight Demo" button — all three graphs show orange borders on linear-family nodes. +4. Click "Clear Highlights" — borders disappear. +5. In browser console: `fxRef.addHighlightGroup('test', ['linear'], '#00aaff')` — blue border on `linear`. + +12. `adv_04_tiled_compare` +- Purpose: 3-graph compare starting with explicit extension-field sync (`per_layer_accuracy.debug_handle`). +- Target APIs: `FXGraphCompare.create` (Map API), `sync.mode: 'layer'`, `sync.layer`, `sync.field`, `setSync`. +- What to try: +1. Click a node in any graph — observe sync via `per_layer_accuracy.debug_handle` field value matching. +2. Open the sidebar sync selector — confirm `Ext: per_layer_accuracy.debug_handle` is selected. +3. Switch to `Auto (handle→id)` — confirm sync still works via `debug_handle` set intersection. +4. Switch to `ID only` — confirm sync works by node name. +5. Switch to `Don't sync` — confirm no propagation. +6. Inspect the merged info panel — compare `debug_handle` values across all three graphs. + +### Level 3: Current Mixed Demo + +13. `js_99_combo_mixed` +- Purpose: demonstrate a realistic mixed usage pattern. +- Target APIs: compare sync, runtime mutation, event subscriptions, theme control, camera APIs. +- What to try: +1. Toggle compare sync flags. +2. Move threshold slider. +3. Focus worst node. +4. Run scripted sequence. + +### Qualcomm-only + +13. `qualcomm_metadata` +- Purpose: inspect Qualcomm PTQ metadata beside rendered graph. +- Target APIs: `create` plus host metadata composition. + +## How to Learn Effectively + +1. Run one testcase at a time. +2. Edit JS pane in small changes and rerun. +3. Keep the "Target APIs" list in view while editing. +4. Move to the next level only after you can explain current behavior. + +## Common Mistakes + +1. Using `setColorBy(layer)` when that layer is not active. +2. Forgetting to call `init()` after `create()`. +3. Patching a layer that has not been added via `upsertLayer`. +4. Assuming compare sync covers all dimensions when only selection is enabled. +5. Expecting `mode: 'id'` to work across decomposed/fused graphs — use `mode: 'auto'` instead. +6. Forgetting to call `ext.set_sync_key(field)` when you want an extension field to appear in the compare sidebar. +7. Calling `addHighlightGroup` with node IDs that don't exist in the graph — they are silently skipped. diff --git a/devtools/fx_viewer/examples/PYTHON_API_TUTORIAL.md b/devtools/fx_viewer/examples/PYTHON_API_TUTORIAL.md new file mode 100644 index 00000000000..345a49e86b6 --- /dev/null +++ b/devtools/fx_viewer/examples/PYTHON_API_TUTORIAL.md @@ -0,0 +1,140 @@ +# fx_viewer Python API Tutorial + +This tutorial is intentionally practical and maps directly to harness usage. + +## 1) Minimal Export + +```python +import torch +from executorch.backends.qualcomm.utils.fx_viewer import FXGraphExporter + +model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.ReLU()).eval() +sample = (torch.randn(1, 16),) + +ep = torch.export.export(model, sample, strict=False) +exporter = FXGraphExporter(ep.graph_module) +exporter.export_html("minimal_graph.html") +``` + +What you learn: +1. Create exporter from `graph_module`. +2. Generate standalone HTML quickly. + +## 2) Add One Extension Layer + +```python +from executorch.backends.qualcomm.utils.fx_viewer import GraphExtension + +ext = GraphExtension(id="backend", name="Backend") + +payload = exporter.generate_json_payload() +for node in payload["base"]["nodes"]: + # Example: fake backend assignment + ext.add_node_data(node["id"], {"backend": "cpu"}) + +ext.set_label_formatter(lambda d: [f"backend={d.get('backend', 'unknown')}"]) +exporter.add_extension(ext) +exporter.export_html("graph_with_backend_layer.html") +``` + +What you learn: +1. `add_node_data(node_id, data)` is the core extension contract. +2. Formatter input is exactly stored extension data. + +## 3) Add Color Rules + +### Categorical + +```python +from executorch.backends.qualcomm.utils.fx_viewer import CategoricalColorRule + +ext.set_color_rule(CategoricalColorRule(attribute="backend")) +``` + +Use when values are discrete labels. + +### Numeric + +```python +from executorch.backends.qualcomm.utils.fx_viewer import NumericColorRule + +metric_ext = GraphExtension(id="latency", name="Latency") +metric_ext.add_node_data("node_a", {"latency_ms": 1.2}) +metric_ext.set_color_rule(NumericColorRule(attribute="latency_ms", cmap="viridis")) +``` + +Use when values are continuous metrics. + +## 4) Export Modes + +```python +payload = exporter.generate_json_payload() # in-memory dict +exporter.export_json("graph_payload.json") +js_snippet = exporter.export_js("graph-host") +exporter.export_html("graph_standalone.html") +``` + +Use cases: +1. `export_html`: easiest for local inspection. +2. `export_json` + JS runtime: best for custom host applications. +3. `export_js`: quick embed in existing HTML. + +## 5) debug_handle Extraction and Compare Sync + +`debug_handle` is a per-node integer assigned by `generate_missing_debug_handles`. The exporter +extracts it explicitly so it is always present in `node.info` regardless of type: + +```python +from executorch.exir.passes.debug_handle_generator_pass import generate_missing_debug_handles + +ep = torch.export.export(model, sample, strict=False) +generate_missing_debug_handles(ep) +gm = ep.module() + +exporter = FXGraphExporter(gm) +payload = exporter.generate_json_payload() +# payload["base"]["nodes"][i]["info"]["debug_handle"] is now int or list[int] +``` + +Fused nodes may carry a tuple handle `(h1, h2)`. The exporter normalizes: +- `int` → stored as `int` +- `tuple/list` with one element → stored as `int` +- `tuple/list` with multiple elements → stored as `list[int]` + +### Registering a sync key for compare mode + +To expose an extension field as an explicit sync option in the compare sidebar: + +```python +ext = GraphExtension(id="my_ext", name="My Extension") +ext.add_node_data(node_id, {"debug_handle": 42, "latency_ms": 1.5}) +ext.set_sync_key("debug_handle") # appears as "Ext: my_ext.debug_handle" in sidebar +``` + +The `per_layer_accuracy` extension (built by `_add_accuracy_extension`) automatically registers +`debug_handle` as a sync key. This enables the compare sidebar to offer +`Ext: per_layer_accuracy.debug_handle` as an explicit sync option alongside the default +`Auto (handle→id)` mode. + +## 6) Connect Python Output to JS Harness Thinking + +If Python emits: +1. `extensions["per_layer_accuracy"]` (with `set_sync_key("debug_handle")`) +2. `extensions["topological_order"]` + +Then JS harness can immediately use: +1. `viewer.setLayers(["per_layer_accuracy", "topological_order"])` +2. `viewer.setColorBy("per_layer_accuracy")` +3. `viewer.patchLayerNodes("per_layer_accuracy", patchByNodeId)` +4. `FXGraphCompare.create({ viewers, layout, sync: { mode: 'auto' } })` — auto sync via `debug_handle` + +This is the core Python/JS contract boundary. + +## 7) Recommended Practice Path + +1. Start with `minimal_graph.html`. +2. Add one extension with one field. +3. Add categorical color. +4. Add numeric metric layer. +5. Add `set_sync_key` and test in compare mode (`js_08`, `adv_04`). +6. Run `demo_3graph_compare.py` to see all three `debug_handle` mapping patterns. diff --git a/devtools/fx_viewer/examples/demo_3graph_compare.py b/devtools/fx_viewer/examples/demo_3graph_compare.py new file mode 100644 index 00000000000..98d5891e5dc --- /dev/null +++ b/devtools/fx_viewer/examples/demo_3graph_compare.py @@ -0,0 +1,316 @@ +#!/usr/bin/env python3 +"""Standalone 3-graph compare demo: Reference → Decomposed (1-to-many) → Fused (many-to-1). + +Demonstrates from_node_root as the primary sync mode: +- Graph 1 (Reference): torch.export float model; run_decompositions populates from_node. +- Graph 2 (Decomposed): linear → t + mm + add (3 nodes share from_node_root="linear"). +- Graph 3 (Fused): deepcopy of ref_gm; relu nodes get union debug_handle for many-to-1. + +Sync mode 'auto' (from_node_root → debug_handle → id) connects all three graphs: +- Click linear in Graph 1 → Graph 2: add_tensor (last of {t,mm,add} with from_node_root=linear). +- Click t_default/mm_default/add_tensor in Graph 2 → Graph 1: linear (from_node_root match). +- Click relu in Graph 1 → Graph 2: relu_default (from_node_root=relu). +- Click relu in Graph 3 (union handle) → Graph 1: relu. Graph 2: relu_default. + +Run from repo root: + python backends/qualcomm/utils/fx_viewer/examples/demo_3graph_compare.py +""" + +from __future__ import annotations + +import copy +import json +from pathlib import Path +from typing import Any, Dict + +import torch +import torch.fx + +from executorch.devtools.fx_viewer import ( + FXGraphExporter, + GraphExtension, + NumericColorRule, +) + +THIS_DIR = Path(__file__).resolve().parent + + +# --------------------------------------------------------------------------- +# Model +# --------------------------------------------------------------------------- + +class _ToyModel(torch.nn.Module): + """Small MLP — Linear+ReLU blocks are decomposed/fused to demo handle mapping.""" + + def __init__(self) -> None: + super().__init__() + self.fc1 = torch.nn.Linear(32, 64) + self.fc2 = torch.nn.Linear(64, 32) + self.fc3 = torch.nn.Linear(32, 10) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = torch.relu(self.fc1(x)) + x = torch.relu(self.fc2(x)) + return self.fc3(x) + + +# --------------------------------------------------------------------------- +# 1-to-many: DecomposeLinear transformer +# --------------------------------------------------------------------------- + +class _DecomposeLinearPass(torch.fx.Transformer): + """1-to-many: aten.linear → aten.t + aten.mm + aten.add (3 nodes, same handle).""" + + def call_function(self, target, args, kwargs): + if target is torch.ops.aten.linear.default: + inp, weight = args[0], args[1] + bias = args[2] if len(args) > 2 else kwargs.get("bias") + t = super().call_function(torch.ops.aten.t.default, (weight,), {}) + mm = super().call_function(torch.ops.aten.mm.default, (inp, t), {}) + if bias is not None: + return super().call_function(torch.ops.aten.add.Tensor, (mm, bias), {}) + return mm + return super().call_function(target, args, kwargs) + + + + +# --------------------------------------------------------------------------- +# many-to-1: fused graph (no Transformer needed) +# --------------------------------------------------------------------------- + +def _build_fused_graph(ref_gm: torch.fx.GraphModule) -> torch.fx.GraphModule: + """Simulate fusion: relu nodes that follow a linear get a union handle (linear_h, relu_h). + + from_node is preserved from ref_gm via deepcopy, so from_node_root sync still works. + The union debug_handle demonstrates many-to-1 matching as a secondary sync mechanism. + """ + fused_gm = copy.deepcopy(ref_gm) + nodes = list(fused_gm.graph.nodes) + for i, node in enumerate(nodes): + if node.op != "call_function" or "relu" not in node.name: + continue + prev = nodes[i - 1] if i > 0 else None + if prev and prev.op == "call_function" and "linear" in prev.name: + lh = prev.meta.get("debug_handle") + rh = node.meta.get("debug_handle") + if lh and rh: + node.meta["debug_handle"] = (lh, rh) # union tuple → many-to-1 + return fused_gm + + +# --------------------------------------------------------------------------- +# Extension builder +# --------------------------------------------------------------------------- + +def _build_debug_handle_extension(graph_module: torch.fx.GraphModule) -> GraphExtension: + ext = GraphExtension(id="debug_handle_sync", name="Debug Handle") + for node in graph_module.graph.nodes: + raw = node.meta.get("debug_handle") + if not raw or raw == 0 or raw == () or raw == []: + continue + if isinstance(raw, int): + dh_val: Any = raw + elif isinstance(raw, (tuple, list)): + ints = [int(x) for x in raw if isinstance(x, int) and x != 0] + if not ints: + continue + dh_val = ints[0] if len(ints) == 1 else ints + else: + continue + ext.add_node_data(node.name, {"debug_handle": dh_val}) + ext.set_sync_key("debug_handle") + ext.set_label_formatter(lambda d: [f"dh={d.get('debug_handle', '?')}"]) + ext.set_color_rule(NumericColorRule(attribute="debug_handle", cmap="viridis", handle_outliers=False)) + return ext + + +# --------------------------------------------------------------------------- +# HTML output +# --------------------------------------------------------------------------- + +def _write_compare_html( + output_path: Path, + ref_payload: Dict[str, Any], + decomp_payload: Dict[str, Any], + fused_payload: Dict[str, Any], +) -> None: + js_bundle = FXGraphExporter._load_viewer_js_bundle() + + payloads_json = json.dumps({ + "ref": ref_payload, + "decomp": decomp_payload, + "fused": fused_payload, + }) + + html = f""" + + + + 3-Graph Compare: from_node_root Sync Demo + + + +
+
3-Graph Compare: from_node_root Sync Demo
+ + +
+ About this demo +

+ Graph 1 (Reference): float model exported; run_decompositions populates from_node on all nodes.
+ Graph 2 (Decomposed, 1→many): each linear → t + mm + add; all 3 share from_node_root="linear".
+ Graph 3 (Fused, many→1): deepcopy of ref; relu nodes that follow a linear get a union debug_handle.
+ Sync mode 'Auto (from_node→handle→id)': from_node_root is tried first, then debug_handle set intersection. +

+
+
+
+
+
+ + + + + + +""" + output_path.write_text(html) + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + +def main() -> None: + model = _ToyModel().eval() + sample = (torch.randn(1, 32),) + + print("Exporting reference graph...") + ref_ep = torch.export.export(model, sample, strict=False) + # run_decompositions populates from_node on all nodes + ref_ep_decomp = ref_ep.run_decompositions({}) + ref_gm = ref_ep_decomp.module() + + print("Building decomposed graph (1-to-many via Transformer)...") + # _DecomposeLinearPass is a torch.fx.Transformer — it auto-propagates from_node + decomp_gm = _DecomposeLinearPass(ref_gm).transform() + # from_node_root is now set on t_default, mm_default, add_tensor → "linear" + + print("Building fused graph (many-to-1)...") + fused_gm = _build_fused_graph(ref_gm) + + print("Exporting payloads...") + ref_exp = FXGraphExporter(ref_gm) + decomp_exp = FXGraphExporter(decomp_gm) + fused_exp = FXGraphExporter(fused_gm) + for exp, gm in [(ref_exp, ref_gm), (decomp_exp, decomp_gm), (fused_exp, fused_gm)]: + exp.add_extension(_build_debug_handle_extension(gm)) + + ref_payload = ref_exp.generate_json_payload() + decomp_payload = decomp_exp.generate_json_payload() + fused_payload = fused_exp.generate_json_payload() + + output = Path("demo_3graph_compare.html") + _write_compare_html(output, ref_payload, decomp_payload, fused_payload) + print(f"Wrote: {output}") + print() + print("Expected sync behavior (mode: auto, from_node_root → debug_handle → id):") + print(" Click linear in Graph 1 → Graph 2: add_tensor (last of {t,mm,add} with from_node_root=linear).") + print(" Click t_default/mm_default/add_tensor in Graph 2 → Graph 1: linear (from_node_root match).") + print(" Click relu in Graph 1 → Graph 2: relu_default (from_node_root=relu).") + print(" Click relu in Graph 3 (union handle) → Graph 1: relu. Graph 2: relu_default.") + + +if __name__ == "__main__": + main() diff --git a/devtools/fx_viewer/examples/demo_fx_viewer_extensions.py b/devtools/fx_viewer/examples/demo_fx_viewer_extensions.py new file mode 100644 index 00000000000..af387fe2760 --- /dev/null +++ b/devtools/fx_viewer/examples/demo_fx_viewer_extensions.py @@ -0,0 +1,204 @@ +#!/usr/bin/env python3 +"""Demo for fx_viewer V3 extensions with Swin and Llama models. + +This script exports standalone HTML files using the new fx_viewer module and +adds two extension layers: +1) Target/op-type categorical coloring. +2) Topological-order numeric coloring. + +Run (from repo root): + source ~/executorch/.venv/bin/activate + python examples/demo_fx_viewer_extensions.py --model both +""" + +from __future__ import annotations + +import argparse +from collections import deque +from pathlib import Path +from typing import Any + +import torch + + +from executorch.devtools.fx_viewer import ( + CategoricalColorRule, + FXGraphExporter, + GraphExtension, + GraphNode, + NumericColorRule, +) + + +def _base_label(node: GraphNode) -> str: + target = str(node.info.get("target") if node.info.get("op") == "call_function" else node.info.get("op")) + return target.replace("aten.", "").replace(".default", "") + +def _compute_topological_index(nodes: list[dict[str, Any]], edges: list[dict[str, Any]]) -> dict[str, int]: + node_ids = [n["id"] for n in nodes] + indeg = {nid: 0 for nid in node_ids} + adj: dict[str, list[str]] = {nid: [] for nid in node_ids} + + for e in edges: + src, dst = e["v"], e["w"] + if src in adj and dst in indeg: + adj[src].append(dst) + indeg[dst] += 1 + + q = deque([nid for nid in node_ids if indeg[nid] == 0]) + topo_index: dict[str, int] = {} + idx = 0 + + while q: + cur = q.popleft() + topo_index[cur] = idx + idx += 1 + for nxt in adj[cur]: + indeg[nxt] -= 1 + if indeg[nxt] == 0: + q.append(nxt) + + # Fallback for unexpected cycles: keep deterministic order. + for nid in node_ids: + if nid not in topo_index: + topo_index[nid] = idx + idx += 1 + + return topo_index + + +def _build_color_by_type_extension(nodes: list[dict[str, Any]]) -> GraphExtension: + ext = GraphExtension(id="color_by_type", name="Color By Type") + + for n in nodes: + info = n.get("info", {}) + ext.add_node_data( + n["id"], + { + "target": str(info.get("target", "unknown")), + "op": str(info.get("op", "unknown")), + "color_data": str(info.get("target") if info.get("op") == "call_function" else info.get("op")) + + }, + ) + + ext.set_label_formatter(lambda d: [f"color_data: {d.get('color_data', 'unknown')}"]) + ext.set_color_rule(CategoricalColorRule(attribute="color_data")) + return ext + + +def _build_topology_extension( + nodes: list[dict[str, Any]], + edges: list[dict[str, Any]], +) -> GraphExtension: + topo_idx = _compute_topological_index(nodes, edges) + + ext = GraphExtension(id="topological_order", name="Topological Order") + for n in nodes: + idx = topo_idx[n["id"]] + ext.add_node_data(n["id"], {"topo_index": idx}) + + ext.set_label_formatter(lambda d: [f"topo: {d.get('topo_index', -1)}"]) + ext.set_tooltip_formatter( + lambda d: [ + f"Topological index: {d.get('topo_index', -1)}", + ] + ) + ext.set_color_rule(NumericColorRule(attribute="topo_index", cmap="viridis", handle_outliers=False)) + return ext + + +def _export_with_extensions(model: torch.nn.Module, inputs: tuple[Any, ...], output_html: Path) -> None: + try: + ep_model = torch.export.export(model, inputs, strict=False) + ep_model = ep_model.run_decompositions() + graph_module = ep_model.graph_module + except Exception: + graph_module = torch.fx.symbolic_trace(model) + + exporter = FXGraphExporter(graph_module) + + # override base behavior + exporter.set_base_label_formatter(_base_label) + + base_payload = exporter.generate_json_payload() + base_nodes = base_payload["base"]["nodes"] + base_edges = base_payload["base"]["edges"] + + exporter.add_extension(_build_color_by_type_extension(base_nodes)) + exporter.add_extension(_build_topology_extension(base_nodes, base_edges)) + + exporter.export_html(str(output_html)) + + +def _build_swin_model() -> tuple[torch.nn.Module, tuple[Any, ...]]: + from transformers import SwinConfig, SwinForImageClassification + + config = SwinConfig( + image_size=224, + patch_size=4, + num_channels=3, + embed_dim=96, + depths=[2, 2, 2, 2], + num_heads=[3, 6, 12, 24], + window_size=7, + num_labels=10, + ) + model = SwinForImageClassification(config).eval().to("cpu") + inputs = (torch.rand(1, 3, 224, 224),) + return model, inputs + + +def _build_llama_model() -> tuple[torch.nn.Module, tuple[Any, ...]]: + from transformers import LlamaConfig, LlamaForCausalLM + + config = LlamaConfig( + vocab_size=128, + hidden_size=128, + intermediate_size=256, + num_hidden_layers=6, + num_attention_heads=8, + num_key_value_heads=8, + max_position_embeddings=256, + ) + model = LlamaForCausalLM(config).eval().to("cpu") + input_ids = torch.randint(0, config.vocab_size, (1, 32)) + inputs = (input_ids,) + return model, inputs + + +def main() -> None: + parser = argparse.ArgumentParser(description="FX Viewer V3 extension demo") + parser.add_argument( + "--model", + choices=["swin", "llama", "both"], + default="both", + help="Which model demo to export", + ) + parser.add_argument( + "--out-dir", + default=".", + help="Output directory for generated HTML files", + ) + args = parser.parse_args() + + out_dir = Path(args.out_dir).resolve() + out_dir.mkdir(parents=True, exist_ok=True) + + if args.model in ("swin", "both"): + print("Building Swin demo model...") + model, inputs = _build_swin_model() + out_file = out_dir / "swin_graph_v3_extensions.html" + _export_with_extensions(model, inputs, out_file) + print(f"Exported: {out_file}") + + if args.model in ("llama", "both"): + print("Building Llama demo model...") + model, inputs = _build_llama_model() + out_file = out_dir / "llama_graph_v3_extensions.html" + _export_with_extensions(model, inputs, out_file) + print(f"Exported: {out_file}") + + +if __name__ == "__main__": + main() diff --git a/devtools/fx_viewer/examples/demo_per_layer_accuracy_fx.py b/devtools/fx_viewer/examples/demo_per_layer_accuracy_fx.py new file mode 100644 index 00000000000..9bf179e1a59 --- /dev/null +++ b/devtools/fx_viewer/examples/demo_per_layer_accuracy_fx.py @@ -0,0 +1,956 @@ +#!/usr/bin/env python3 +"""Standalone fx_viewer per-layer accuracy demo (no Observatory UI). + +This demo compares two FX graphs and visualizes per-layer accuracy deltas using +an fx_viewer extension. It supports two pipelines: +- fake_quant: backend-agnostic simulated quantization (weight rounding only) +- qualcomm_ptq: Qualcomm PTQ path using QnnQuantizer + prepare/convert PT2E + +It also follows the debug workflow: +1) Run end-to-end on multiple input samples. +2) Pick the worst sample by output drop score. +3) Capture per-layer outputs only on that worst sample. + +Run from repo root: + source .venv/bin/activate + export PYTHONPATH=~/ + + # Backend-agnostic demo: + python backends/qualcomm/utils/fx_viewer/examples/demo_per_layer_accuracy_fx.py \ + --pipeline fake_quant --model swin + + # Qualcomm PTQ demo (requires QNN/QAIRT env): + source ~/executorch/qairt/2.37.0.250724/bin/envsetup.sh + python backends/qualcomm/utils/fx_viewer/examples/demo_per_layer_accuracy_fx.py \ + --pipeline qualcomm_ptq --model swin --soc-model SM8650 +""" + +from __future__ import annotations + +import argparse +import copy +import json +import math +import os +import random +import sys +from dataclasses import asdict, dataclass +from pathlib import Path +from typing import Any, Dict, Iterable, Mapping, Sequence + +import torch + + +from executorch.devtools.inspector._inspector_utils import ( # noqa: E402 + DebugHandle, + get_aot_debug_handle_to_op_name_mapping, +) +from executorch.devtools.inspector._intermediate_output_capturer import ( # noqa: E402 + IntermediateOutputCapturer, +) +from executorch.exir.passes.debug_handle_generator_pass import ( # noqa: E402 + generate_missing_debug_handles, +) + +from executorch.devtools.fx_viewer import ( # noqa: E402 + FXGraphExporter, + GraphExtension, + NumericColorRule, +) + + +@dataclass +class MatchRecord: + candidate_node: str + reference_node: str + candidate_debug_handle: DebugHandle + reference_debug_handle: DebugHandle + matched_by: str + + +@dataclass +class LayerMetric: + candidate_node: str + reference_node: str + candidate_debug_handle: DebugHandle + reference_debug_handle: DebugHandle + matched_by: str + numel_compared: int + candidate_shape: str + reference_shape: str + max_abs_err: float + mean_abs_err: float + mse: float + cosine_similarity: float + severity_score: float + + +@dataclass +class SampleScore: + sample_index: int + mse: float + max_abs_err: float + cosine_similarity: float + drop_score: float + + +@dataclass +class GraphPair: + pipeline: str + reference_name: str + candidate_name: str + reference_graph: torch.fx.GraphModule + candidate_graph: torch.fx.GraphModule + metadata: dict[str, Any] + + +def _set_seed(seed: int) -> None: + random.seed(seed) + torch.manual_seed(seed) + + +def _patch_swin_window_ops() -> None: + # Mirrors examples/qualcomm/oss_scripts/swin_transformer.py adjustments. + from transformers.models.swin import modeling_swin + + def window_partition(input_feature: torch.Tensor, window_size: int) -> torch.Tensor: + batch_size, height, width, num_channels = input_feature.shape + input_feature = input_feature.view( + batch_size, + height // window_size, + window_size, + width // window_size, + window_size * num_channels, + ) + windows = input_feature.permute(0, 1, 3, 2, 4).contiguous() + return windows.view(-1, window_size, window_size, num_channels) + + def window_reverse( + windows: torch.Tensor, window_size: int, height: int, width: int + ) -> torch.Tensor: + num_channels = windows.shape[-1] + windows = windows.view( + -1, + height // window_size, + width // window_size, + window_size, + window_size * num_channels, + ) + windows = windows.permute(0, 1, 3, 2, 4).contiguous() + return windows.view(-1, height, width, num_channels) + + modeling_swin.window_partition = window_partition + modeling_swin.window_reverse = window_reverse + + +def _build_swin_model() -> tuple[torch.nn.Module, tuple[int, ...]]: + from transformers import SwinConfig, SwinForImageClassification + + _patch_swin_window_ops() + config = SwinConfig( + image_size=224, + patch_size=4, + num_channels=3, + embed_dim=64, + depths=[1, 1, 1, 1], + num_heads=[2, 4, 8, 16], + window_size=7, + num_labels=10, + ) + model = SwinForImageClassification(config).eval().to("cpu") + return model, (1, 3, 224, 224) + + +class _ToyModel(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.features = torch.nn.Sequential( + torch.nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1), + torch.nn.ReLU(), + torch.nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1), + torch.nn.GELU(), + torch.nn.AdaptiveAvgPool2d((1, 1)), + ) + self.classifier = torch.nn.Linear(32, 10) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.features(x) + x = torch.flatten(x, 1) + return self.classifier(x) + + +def _build_toy_model() -> tuple[torch.nn.Module, tuple[int, ...]]: + return _ToyModel().eval().to("cpu"), (1, 3, 128, 128) + + +def _make_random_samples(input_shape: tuple[int, ...], num_samples: int) -> list[tuple[torch.Tensor, ...]]: + samples: list[tuple[torch.Tensor, ...]] = [] + for _ in range(num_samples): + samples.append((torch.rand(*input_shape),)) + return samples + + +def _fake_quantize_tensor(tensor: torch.Tensor, num_bits: int = 8) -> torch.Tensor: + if tensor.numel() == 0: + return tensor + qmax = (1 << (num_bits - 1)) - 1 + max_abs = tensor.detach().abs().max() + if float(max_abs) == 0.0: + return tensor + scale = max_abs / float(qmax) + q = (tensor / scale).round().clamp(-qmax, qmax) + return q * scale + + +def _make_fake_quantized_copy(model: torch.nn.Module) -> torch.nn.Module: + quantized = copy.deepcopy(model) + with torch.no_grad(): + for parameter in quantized.parameters(): + parameter.copy_(_fake_quantize_tensor(parameter)) + return quantized.eval().to("cpu") + + +def _export_with_debug_handles( + model: torch.nn.Module, sample_inputs: tuple[torch.Tensor, ...] +) -> torch.export.ExportedProgram: + ep = torch.export.export(model, sample_inputs, strict=False) + generate_missing_debug_handles(ep) + return ep + + +def _capture_outputs( + graph_module: torch.fx.GraphModule, sample_inputs: tuple[torch.Tensor, ...] +) -> Dict[DebugHandle, Any]: + capturer = IntermediateOutputCapturer(graph_module) + return capturer.run_and_capture(*sample_inputs) + + +def _node_to_handle( + handle_to_nodes: Mapping[DebugHandle, Sequence[str]], +) -> Dict[str, DebugHandle]: + result: Dict[str, DebugHandle] = {} + for handle, names in handle_to_nodes.items(): + for name in names: + result[name] = handle + return result + + +def _ensure_graph_module_debug_handles(graph_module: torch.fx.GraphModule) -> None: + max_handle = 0 + for node in graph_module.graph.nodes: + handle = node.meta.get("debug_handle") + if isinstance(handle, int): + max_handle = max(max_handle, handle) + elif isinstance(handle, (tuple, list)): + numeric = [int(x) for x in handle if isinstance(x, int)] + if numeric: + max_handle = max(max_handle, max(numeric)) + + next_handle = max_handle + 1 + for node in graph_module.graph.nodes: + if node.op in ("placeholder", "output"): + continue + handle = node.meta.get("debug_handle") + missing = handle is None or handle == 0 or handle == () or handle == [] + if missing: + node.meta["debug_handle"] = next_handle + next_handle += 1 + + +def _match_nodes( + reference_map: Mapping[DebugHandle, Sequence[str]], + candidate_map: Mapping[DebugHandle, Sequence[str]], +) -> tuple[list[MatchRecord], dict[str, int]]: + matches: list[MatchRecord] = [] + + ref_node_to_handle = _node_to_handle(reference_map) + cand_node_to_handle = _node_to_handle(candidate_map) + + matched_candidate_nodes: set[str] = set() + + # Phase 1: exact debug-handle matching. + for handle in sorted(set(reference_map.keys()) & set(candidate_map.keys())): + reference_node = reference_map[handle][0] + for candidate_node in candidate_map[handle]: + matches.append( + MatchRecord( + candidate_node=candidate_node, + reference_node=reference_node, + candidate_debug_handle=handle, + reference_debug_handle=handle, + matched_by="debug_handle", + ) + ) + matched_candidate_nodes.add(candidate_node) + + # Phase 2: node-name fallback. + for candidate_node, candidate_handle in cand_node_to_handle.items(): + if candidate_node in matched_candidate_nodes: + continue + if candidate_node not in ref_node_to_handle: + continue + reference_handle = ref_node_to_handle[candidate_node] + matches.append( + MatchRecord( + candidate_node=candidate_node, + reference_node=candidate_node, + candidate_debug_handle=candidate_handle, + reference_debug_handle=reference_handle, + matched_by="node_name", + ) + ) + matched_candidate_nodes.add(candidate_node) + + stats = { + "reference_handles": len(reference_map), + "candidate_handles": len(candidate_map), + "handle_intersection": len(set(reference_map.keys()) & set(candidate_map.keys())), + "matched_nodes": len(matches), + "matched_by_debug_handle": sum(1 for m in matches if m.matched_by == "debug_handle"), + "matched_by_node_name": sum(1 for m in matches if m.matched_by == "node_name"), + "candidate_nodes_unmatched": max(0, len(cand_node_to_handle) - len(matched_candidate_nodes)), + } + return matches, stats + + +def _flatten_for_metric(value: Any) -> tuple[torch.Tensor | None, str]: + if isinstance(value, torch.Tensor): + return value.detach().cpu().to(torch.float64).reshape(-1), str(tuple(value.shape)) + + if isinstance(value, (tuple, list)): + tensor_parts = [ + v.detach().cpu().to(torch.float64).reshape(-1) + for v in value + if isinstance(v, torch.Tensor) + ] + if tensor_parts: + shape = "[" + ", ".join(str(tuple(v.shape)) for v in value if isinstance(v, torch.Tensor)) + "]" + return torch.cat(tensor_parts), shape + scalar_parts = [float(v) for v in value if isinstance(v, (int, float))] + if scalar_parts: + return torch.tensor(scalar_parts, dtype=torch.float64), f"list(len={len(scalar_parts)})" + return None, "unsupported_sequence" + + if isinstance(value, (int, float, bool)): + return torch.tensor([float(value)], dtype=torch.float64), "scalar" + + return None, f"unsupported:{type(value).__name__}" + + +def _compute_metric_for_pair( + reference_value: Any, + candidate_value: Any, +) -> tuple[int, str, str, float, float, float, float] | None: + ref_flat, ref_shape = _flatten_for_metric(reference_value) + cand_flat, cand_shape = _flatten_for_metric(candidate_value) + + if ref_flat is None or cand_flat is None: + return None + + compared = min(ref_flat.numel(), cand_flat.numel()) + if compared == 0: + return None + + ref = torch.nan_to_num(ref_flat[:compared], nan=0.0, posinf=0.0, neginf=0.0) + cand = torch.nan_to_num(cand_flat[:compared], nan=0.0, posinf=0.0, neginf=0.0) + diff = cand - ref + abs_diff = diff.abs() + + max_abs = float(abs_diff.max().item()) + mean_abs = float(abs_diff.mean().item()) + mse = float((diff * diff).mean().item()) + + ref_norm = float(ref.norm().item()) + cand_norm = float(cand.norm().item()) + if ref_norm == 0.0 or cand_norm == 0.0: + cosine = 1.0 if ref_norm == cand_norm else 0.0 + else: + cosine = float(torch.nn.functional.cosine_similarity(ref, cand, dim=0).item()) + if math.isnan(cosine): + cosine = 0.0 + + return compared, ref_shape, cand_shape, max_abs, mean_abs, mse, cosine + + +def _compute_layer_metrics( + matches: Iterable[MatchRecord], + reference_outputs: Mapping[DebugHandle, Any], + candidate_outputs: Mapping[DebugHandle, Any], +) -> list[LayerMetric]: + metrics: list[LayerMetric] = [] + for match in matches: + if match.reference_debug_handle not in reference_outputs: + continue + if match.candidate_debug_handle not in candidate_outputs: + continue + computed = _compute_metric_for_pair( + reference_outputs[match.reference_debug_handle], + candidate_outputs[match.candidate_debug_handle], + ) + if computed is None: + continue + ( + compared, + reference_shape, + candidate_shape, + max_abs, + mean_abs, + mse, + cosine, + ) = computed + # Severity of performance drop: larger is worse. + severity = max_abs + max(0.0, 1.0 - cosine) + metrics.append( + LayerMetric( + candidate_node=match.candidate_node, + reference_node=match.reference_node, + candidate_debug_handle=match.candidate_debug_handle, + reference_debug_handle=match.reference_debug_handle, + matched_by=match.matched_by, + numel_compared=compared, + candidate_shape=candidate_shape, + reference_shape=reference_shape, + max_abs_err=max_abs, + mean_abs_err=mean_abs, + mse=mse, + cosine_similarity=cosine, + severity_score=severity, + ) + ) + return metrics + + +def _add_accuracy_extension(exporter: FXGraphExporter, metrics: Iterable[LayerMetric]) -> None: + ext = GraphExtension(id="per_layer_accuracy", name="Per-layer Accuracy (Worst Sample)") + for metric in metrics: + dh = metric.candidate_debug_handle + if isinstance(dh, (tuple, list)) and len(dh) > 0: + dh_scalar = int(dh[0]) + elif isinstance(dh, int) and dh != 0: + dh_scalar = dh + else: + dh_scalar = None + node_data: dict[str, Any] = { + "reference_node": metric.reference_node, + "candidate_debug_handle": list(metric.candidate_debug_handle), + "reference_debug_handle": list(metric.reference_debug_handle), + "matched_by": metric.matched_by, + "numel_compared": metric.numel_compared, + "candidate_shape": metric.candidate_shape, + "reference_shape": metric.reference_shape, + "max_abs_err": metric.max_abs_err, + "mean_abs_err": metric.mean_abs_err, + "mse": metric.mse, + "cosine_similarity": metric.cosine_similarity, + "severity_score": metric.severity_score, + } + if dh_scalar is not None: + node_data["debug_handle"] = dh_scalar + ext.add_node_data(metric.candidate_node, node_data) + + ext.set_sync_key("debug_handle") + ext.set_label_formatter( + lambda d: [ + f"severity={d.get('severity_score', 0.0):.2e}", + f"max_abs={d.get('max_abs_err', 0.0):.2e}", + ] + ) + ext.set_tooltip_formatter( + lambda d: [ + f"match={d.get('matched_by', 'n/a')}", + f"ref_node={d.get('reference_node', 'n/a')}", + f"ref_debug_handle={d.get('reference_debug_handle', [])}", + f"cand_debug_handle={d.get('candidate_debug_handle', [])}", + f"shape(ref)={d.get('reference_shape', 'n/a')}", + f"shape(cand)={d.get('candidate_shape', 'n/a')}", + f"numel={d.get('numel_compared', 0)}", + f"severity={d.get('severity_score', 0.0):.6e}", + f"max_abs={d.get('max_abs_err', 0.0):.6e}", + f"mean_abs={d.get('mean_abs_err', 0.0):.6e}", + f"mse={d.get('mse', 0.0):.6e}", + f"cos={d.get('cosine_similarity', 0.0):.6f}", + ] + ) + # Red severity map: higher severity => stronger red. + ext.set_color_rule( + NumericColorRule(attribute="severity_score", cmap="reds", handle_outliers=True) + ) + exporter.add_extension(ext) + + +def _to_primary_tensor(value: Any) -> torch.Tensor | None: + if isinstance(value, torch.Tensor): + return value + if hasattr(value, "logits") and isinstance(value.logits, torch.Tensor): + return value.logits + if isinstance(value, (tuple, list)): + for item in value: + t = _to_primary_tensor(item) + if t is not None: + return t + if isinstance(value, dict): + for item in value.values(): + t = _to_primary_tensor(item) + if t is not None: + return t + return None + + +def _score_samples_by_e2e_drop( + reference_graph: torch.fx.GraphModule, + candidate_graph: torch.fx.GraphModule, + samples: Sequence[tuple[torch.Tensor, ...]], +) -> tuple[list[SampleScore], int]: + scores: list[SampleScore] = [] + with torch.no_grad(): + for idx, sample in enumerate(samples): + ref_out = reference_graph(*sample) + cand_out = candidate_graph(*sample) + ref_t = _to_primary_tensor(ref_out) + cand_t = _to_primary_tensor(cand_out) + if ref_t is None or cand_t is None: + scores.append( + SampleScore( + sample_index=idx, + mse=float("inf"), + max_abs_err=float("inf"), + cosine_similarity=0.0, + drop_score=float("inf"), + ) + ) + continue + + ref = ref_t.detach().cpu().to(torch.float64).reshape(-1) + cand = cand_t.detach().cpu().to(torch.float64).reshape(-1) + compared = min(ref.numel(), cand.numel()) + if compared == 0: + scores.append( + SampleScore( + sample_index=idx, + mse=float("inf"), + max_abs_err=float("inf"), + cosine_similarity=0.0, + drop_score=float("inf"), + ) + ) + continue + + ref = torch.nan_to_num(ref[:compared], nan=0.0, posinf=0.0, neginf=0.0) + cand = torch.nan_to_num(cand[:compared], nan=0.0, posinf=0.0, neginf=0.0) + diff = cand - ref + mse = float((diff * diff).mean().item()) + max_abs = float(diff.abs().max().item()) + + ref_norm = float(ref.norm().item()) + cand_norm = float(cand.norm().item()) + if ref_norm == 0.0 or cand_norm == 0.0: + cosine = 1.0 if ref_norm == cand_norm else 0.0 + else: + cosine = float(torch.nn.functional.cosine_similarity(ref, cand, dim=0).item()) + if math.isnan(cosine): + cosine = 0.0 + + # Composite score for selecting the worst E2E sample. + drop_score = max_abs + mse + 5.0 * max(0.0, 1.0 - cosine) + scores.append( + SampleScore( + sample_index=idx, + mse=mse, + max_abs_err=max_abs, + cosine_similarity=cosine, + drop_score=drop_score, + ) + ) + + worst = max(scores, key=lambda s: s.drop_score) + return scores, worst.sample_index + + +def _write_compare_html( + output_path: Path, + panels: Sequence[dict[str, Any]], + default_columns: int, +) -> None: + js_bundle = FXGraphExporter._load_viewer_js_bundle() + + panel_html = [] + for idx, panel in enumerate(panels): + panel_html.append( + f""" +
+
{panel['title']}
+
+
+""" + ) + + html = f""" + + + + FX Viewer Accuracy Compare + + + +
+
FX Graph Multi-Compare
+ + + +
+
{''.join(panel_html)} +
+ + + + + +""" + output_path.write_text(html) + + +def _write_metrics_json( + output_path: Path, + metrics: Sequence[LayerMetric], + match_stats: Mapping[str, int], + metadata: Mapping[str, Any], + sample_scores: Sequence[SampleScore], + worst_sample_index: int, +) -> None: + payload = { + "metadata": dict(metadata), + "match_stats": dict(match_stats), + "worst_sample_index": worst_sample_index, + "sample_scores": [asdict(s) for s in sample_scores], + "summary": { + "layers_with_metrics": len(metrics), + "severity_max": max((m.severity_score for m in metrics), default=0.0), + "severity_mean": ( + sum(m.severity_score for m in metrics) / len(metrics) if metrics else 0.0 + ), + "max_abs_err_max": max((m.max_abs_err for m in metrics), default=0.0), + "max_abs_err_mean": ( + sum(m.max_abs_err for m in metrics) / len(metrics) if metrics else 0.0 + ), + "cosine_similarity_mean": ( + sum(m.cosine_similarity for m in metrics) / len(metrics) if metrics else 0.0 + ), + }, + "layers": [asdict(metric) for metric in metrics], + "top10_severity": [ + asdict(metric) + for metric in sorted(metrics, key=lambda m: m.severity_score, reverse=True)[:10] + ], + } + output_path.write_text(json.dumps(payload, indent=2)) + + +def _build_graph_pair_fake_quant( + model: torch.nn.Module, + export_sample: tuple[torch.Tensor, ...], +) -> GraphPair: + reference_ep = _export_with_debug_handles(model, export_sample) + candidate_model = _make_fake_quantized_copy(model) + candidate_ep = _export_with_debug_handles(candidate_model, export_sample) + return GraphPair( + pipeline="fake_quant", + reference_name="Reference Float", + candidate_name="Candidate Fake-Quantized", + reference_graph=reference_ep.module(), + candidate_graph=candidate_ep.module(), + metadata={ + "method": "deepcopy + weight rounding to int8 grid", + "qnn_sdk_root": os.getenv("QNN_SDK_ROOT", ""), + }, + ) + + +def _build_graph_pair_qualcomm_ptq( + model: torch.nn.Module, + export_sample: tuple[torch.Tensor, ...], + calibration_samples: Sequence[tuple[torch.Tensor, ...]], + soc_model: str, + backend_name: str, +) -> GraphPair: + from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype + from executorch.backends.qualcomm.serialization.qc_schema import ( + QnnExecuTorchBackendType, + ) + try: + from executorch.examples.qualcomm.utils import make_quantizer + except ModuleNotFoundError: + from examples.qualcomm.utils import make_quantizer + from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e + + if os.getenv("QNN_SDK_ROOT") is None: + raise RuntimeError( + "QNN_SDK_ROOT is not set. Run: source ~/executorch/qairt/2.37.0.250724/bin/envsetup.sh" + ) + + backend = getattr(QnnExecuTorchBackendType, f"k{backend_name.title()}Backend") + + reference_ep = _export_with_debug_handles(model, export_sample) + quant_input_graph = reference_ep.module() + reference_graph = copy.deepcopy(quant_input_graph) + + quantizer = make_quantizer( + quant_dtype=QuantDtype.use_8a8w, + backend=backend, + soc_model=soc_model, + ) + annotated_model = prepare_pt2e(quant_input_graph, quantizer) + + with torch.no_grad(): + for sample in calibration_samples: + annotated_model(*sample) + + candidate_graph = convert_pt2e(annotated_model) + _ensure_graph_module_debug_handles(candidate_graph) + + return GraphPair( + pipeline="qualcomm_ptq", + reference_name="Reference Float (Exported)", + candidate_name=f"Candidate Qualcomm PTQ ({soc_model}, {backend_name.upper()})", + reference_graph=reference_graph, + candidate_graph=candidate_graph, + metadata={ + "method": "QnnQuantizer + prepare_pt2e/convert_pt2e", + "soc_model": soc_model, + "backend": backend_name, + "calibration_samples": len(calibration_samples), + "qnn_sdk_root": os.getenv("QNN_SDK_ROOT", ""), + }, + ) + + +def _run_single_pipeline( + graph_pair: GraphPair, + samples: Sequence[tuple[torch.Tensor, ...]], + pipeline_output_dir: Path, + seed: int, + model_name: str, + default_compare_columns: int, +) -> None: + pipeline_output_dir.mkdir(parents=True, exist_ok=True) + + print(f"[{graph_pair.pipeline}] Scoring end-to-end drop over {len(samples)} samples...") + sample_scores, worst_sample_idx = _score_samples_by_e2e_drop( + graph_pair.reference_graph, + graph_pair.candidate_graph, + samples, + ) + worst_sample = samples[worst_sample_idx] + worst_score = next(s for s in sample_scores if s.sample_index == worst_sample_idx) + print( + f"[{graph_pair.pipeline}] Worst sample idx={worst_sample_idx}, " + f"drop={worst_score.drop_score:.6e}, max_abs={worst_score.max_abs_err:.6e}, " + f"mse={worst_score.mse:.6e}, cos={worst_score.cosine_similarity:.6f}" + ) + + print(f"[{graph_pair.pipeline}] Capturing per-layer outputs on worst sample...") + reference_outputs = _capture_outputs(graph_pair.reference_graph, worst_sample) + candidate_outputs = _capture_outputs(graph_pair.candidate_graph, worst_sample) + + print(f"[{graph_pair.pipeline}] Building debug-handle mappings...") + reference_map = get_aot_debug_handle_to_op_name_mapping(graph_pair.reference_graph) + candidate_map = get_aot_debug_handle_to_op_name_mapping(graph_pair.candidate_graph) + matches, match_stats = _match_nodes(reference_map, candidate_map) + + print(f"[{graph_pair.pipeline}] Computing per-layer metrics...") + metrics = _compute_layer_metrics(matches, reference_outputs, candidate_outputs) + + print(f"[{graph_pair.pipeline}] Exporting fx_viewer HTML...") + reference_exporter = FXGraphExporter(graph_pair.reference_graph) + candidate_exporter = FXGraphExporter(graph_pair.candidate_graph) + _add_accuracy_extension(candidate_exporter, metrics) + + reference_html = pipeline_output_dir / "reference_fx_graph.html" + candidate_html = pipeline_output_dir / "candidate_fx_graph_per_layer_accuracy.html" + compare_html = pipeline_output_dir / "compare_side_by_side.html" + metrics_json = pipeline_output_dir / "per_layer_accuracy_metrics.json" + + reference_payload = reference_exporter.generate_json_payload() + candidate_payload = candidate_exporter.generate_json_payload() + + reference_exporter.export_html(str(reference_html)) + candidate_exporter.export_html(str(candidate_html)) + + panels = [ + { + "title": graph_pair.reference_name, + "payload": reference_payload, + }, + { + "title": f"{graph_pair.candidate_name} [worst sample={worst_sample_idx}]", + "payload": candidate_payload, + }, + ] + _write_compare_html(compare_html, panels=panels, default_columns=default_compare_columns) + + _write_metrics_json( + metrics_json, + metrics, + match_stats=match_stats, + metadata={ + "pipeline": graph_pair.pipeline, + "model": model_name, + "seed": seed, + "reference_node_count": len(list(graph_pair.reference_graph.graph.nodes)), + "candidate_node_count": len(list(graph_pair.candidate_graph.graph.nodes)), + "reference_captured_outputs": len(reference_outputs), + "candidate_captured_outputs": len(candidate_outputs), + **graph_pair.metadata, + }, + sample_scores=sample_scores, + worst_sample_index=worst_sample_idx, + ) + + top5 = sorted(metrics, key=lambda m: m.severity_score, reverse=True)[:5] + print(f"[{graph_pair.pipeline}] Done. Output: {pipeline_output_dir}") + print(f"[{graph_pair.pipeline}] Top-5 severity layers (red = worse):") + for item in top5: + print( + " " + f"{item.candidate_node}: severity={item.severity_score:.6e}, " + f"max_abs={item.max_abs_err:.6e}, cos={item.cosine_similarity:.6f}, " + f"match={item.matched_by}" + ) + + +def main() -> None: + parser = argparse.ArgumentParser(description="Standalone per-layer accuracy demo for fx_viewer.") + parser.add_argument("--model", choices=["toy", "swin"], default="swin") + parser.add_argument("--output-dir", default="fx_viewer_accuracy_demo") + parser.add_argument("--seed", type=int, default=1126) + parser.add_argument("--num-samples", type=int, default=10) + parser.add_argument( + "--pipeline", + choices=["fake_quant", "qualcomm_ptq", "both"], + default="both", + help="Which comparison pipeline(s) to run.", + ) + parser.add_argument("--soc-model", default="SM8650") + parser.add_argument("--backend", choices=["htp", "gpu"], default="htp") + parser.add_argument("--calibration-steps", type=int, default=4) + parser.add_argument("--compare-columns", type=int, default=2) + args = parser.parse_args() + + _set_seed(args.seed) + output_dir = Path(args.output_dir).resolve() + output_dir.mkdir(parents=True, exist_ok=True) + + if args.model == "swin": + reference_model, input_shape = _build_swin_model() + else: + reference_model, input_shape = _build_toy_model() + + samples = _make_random_samples(input_shape, max(1, args.num_samples)) + export_sample = samples[0] + + requested_pipelines = ( + ["fake_quant", "qualcomm_ptq"] if args.pipeline == "both" else [args.pipeline] + ) + + for pipeline in requested_pipelines: + if pipeline == "fake_quant": + pair = _build_graph_pair_fake_quant(reference_model, export_sample) + elif pipeline == "qualcomm_ptq": + calib_samples = samples[: max(1, args.calibration_steps)] + pair = _build_graph_pair_qualcomm_ptq( + reference_model, + export_sample, + calibration_samples=calib_samples, + soc_model=args.soc_model, + backend_name=args.backend, + ) + else: + raise AssertionError(f"Unsupported pipeline: {pipeline}") + + _run_single_pipeline( + pair, + samples=samples, + pipeline_output_dir=output_dir / pipeline, + seed=args.seed, + model_name=args.model, + default_compare_columns=max(1, min(4, args.compare_columns)), + ) + + print("\nDemo complete.") + print(f"Output root: {output_dir}") + for pipeline in requested_pipelines: + print(f" - {pipeline}/reference_fx_graph.html") + print(f" - {pipeline}/candidate_fx_graph_per_layer_accuracy.html") + print(f" - {pipeline}/compare_side_by_side.html") + print(f" - {pipeline}/per_layer_accuracy_metrics.json") + + +if __name__ == "__main__": + main() diff --git a/devtools/fx_viewer/examples/generate_api_test_harness.py b/devtools/fx_viewer/examples/generate_api_test_harness.py new file mode 100644 index 00000000000..0bc852cb941 --- /dev/null +++ b/devtools/fx_viewer/examples/generate_api_test_harness.py @@ -0,0 +1,378 @@ +#!/usr/bin/env python3 +"""Generate unified fx_viewer API harness HTML files. + +This generator builds two educational harness outputs: +1) Portable harness (no Qualcomm SDK required): + - Swin graph + - real per-layer accuracy extension from fake-quant comparison + - topology + color-by-type structural extensions +2) Qualcomm harness (requires QNN/QAIRT env): + - Swin graph + - real per-layer accuracy extension from Qualcomm PTQ comparison + - same structural extensions + Qualcomm metadata testcase + +Design goals: +1) Few CLI options. +2) One shared HTML template. +3) One shared testcase catalog. +4) Payload/testcase composition done in one place for easy extension. +""" + +from __future__ import annotations + +import argparse +import importlib.util +import json +import os +import random +import sys +from collections import deque +from pathlib import Path +from typing import Any + +import torch + +from executorch.devtools.fx_viewer import ( + CategoricalColorRule, + FXGraphExporter, + GraphExtension, + NumericColorRule, +) + +THIS_DIR = Path(__file__).resolve().parent +if str(THIS_DIR) not in sys.path: + sys.path.insert(0, str(THIS_DIR)) +from harness_testcases import build_testcases + + +def _load_local_accuracy_demo_module(): + module_path = THIS_DIR / "demo_per_layer_accuracy_fx.py" + spec = importlib.util.spec_from_file_location("fx_acc_demo_local", module_path) + if spec is None or spec.loader is None: + raise RuntimeError(f"Unable to load module from {module_path}") + module = importlib.util.module_from_spec(spec) + sys.modules[spec.name] = module + spec.loader.exec_module(module) + return module + + +acc_demo = _load_local_accuracy_demo_module() + + +def _set_seed(seed: int) -> None: + random.seed(seed) + torch.manual_seed(seed) + + +def _load_viewer_js_bundle_local() -> str: + """Load viewer JS from workspace templates. + + We intentionally load from the local repo tree so the harness always uses + the current in-workspace JS runtime under development. + """ + template_dir = THIS_DIR.parent / "templates" + ordered_files = [ + "runtime.js", + "graph_data_store.js", + "search_engine.js", + "view_controller.js", + "canvas_renderer.js", + "minimap_renderer.js", + "ui_manager.js", + "fx_graph_viewer.js", + "compare.js", + ] + chunks: list[str] = [] + for filename in ordered_files: + path = template_dir / filename + chunks.append(f"\n// ---- {filename} ----\n") + chunks.append(path.read_text()) + return "\n".join(chunks) + + +def _compute_topological_index( + nodes: list[dict[str, Any]], edges: list[dict[str, Any]] +) -> dict[str, int]: + node_ids = [n["id"] for n in nodes] + indeg = {nid: 0 for nid in node_ids} + adj: dict[str, list[str]] = {nid: [] for nid in node_ids} + + for e in edges: + src, dst = e["v"], e["w"] + if src in adj and dst in indeg: + adj[src].append(dst) + indeg[dst] += 1 + + q = deque([nid for nid in node_ids if indeg[nid] == 0]) + topo: dict[str, int] = {} + idx = 0 + while q: + cur = q.popleft() + topo[cur] = idx + idx += 1 + for nxt in adj[cur]: + indeg[nxt] -= 1 + if indeg[nxt] == 0: + q.append(nxt) + + for nid in node_ids: + if nid not in topo: + topo[nid] = idx + idx += 1 + + return topo + + +def _build_color_by_type_extension(nodes: list[dict[str, Any]]) -> GraphExtension: + ext = GraphExtension(id="color_by_type", name="Color By Type") + for n in nodes: + info = n.get("info", {}) + color_data = str( + info.get("target") if info.get("op") == "call_function" else info.get("op", "unknown") + ) + ext.add_node_data( + n["id"], + { + "target": str(info.get("target", "unknown")), + "op": str(info.get("op", "unknown")), + "color_data": color_data, + }, + ) + + ext.set_label_formatter(lambda d: [f"color_data: {d.get('color_data', 'unknown')}"]) + ext.set_color_rule(CategoricalColorRule(attribute="color_data")) + return ext + + +def _build_topology_extension( + nodes: list[dict[str, Any]], edges: list[dict[str, Any]] +) -> GraphExtension: + topo_idx = _compute_topological_index(nodes, edges) + ext = GraphExtension(id="topological_order", name="Topological Order") + for n in nodes: + ext.add_node_data(n["id"], {"topo_index": topo_idx[n["id"]]}) + + ext.set_label_formatter(lambda d: [f"topo: {d.get('topo_index', -1)}"]) + ext.set_tooltip_formatter(lambda d: [f"Topological index: {d.get('topo_index', -1)}"]) + ext.set_color_rule( + NumericColorRule(attribute="topo_index", cmap="viridis", handle_outliers=False) + ) + return ext + + +def _add_structural_extensions(exporter: FXGraphExporter) -> None: + """Attach structural extensions (type + topology) to any exporter.""" + base = exporter.generate_json_payload() + nodes = base["base"]["nodes"] + edges = base["base"]["edges"] + exporter.add_extension(_build_color_by_type_extension(nodes)) + exporter.add_extension(_build_topology_extension(nodes, edges)) + + +def _compute_accuracy_metrics_for_pair( + graph_pair: Any, + samples: list[tuple[torch.Tensor, ...]], +) -> tuple[list[Any], int]: + sample_scores, worst_sample_idx = acc_demo._score_samples_by_e2e_drop( + graph_pair.reference_graph, + graph_pair.candidate_graph, + samples, + ) + worst_sample = samples[worst_sample_idx] + + reference_outputs = acc_demo._capture_outputs(graph_pair.reference_graph, worst_sample) + candidate_outputs = acc_demo._capture_outputs(graph_pair.candidate_graph, worst_sample) + + reference_map = acc_demo.get_aot_debug_handle_to_op_name_mapping(graph_pair.reference_graph) + candidate_map = acc_demo.get_aot_debug_handle_to_op_name_mapping(graph_pair.candidate_graph) + matches, _ = acc_demo._match_nodes(reference_map, candidate_map) + metrics = acc_demo._compute_layer_metrics(matches, reference_outputs, candidate_outputs) + + _ = sample_scores # kept for future extensions/reporting + return metrics, worst_sample_idx + + +def _build_swin_samples(num_samples: int) -> tuple[Any, list[tuple[torch.Tensor, ...]], tuple[torch.Tensor, ...]]: + model, input_shape = acc_demo._build_swin_model() + samples = acc_demo._make_random_samples(input_shape, num_samples=num_samples) + export_sample = samples[0] + return model, samples, export_sample + + +def _build_portable_payloads(num_samples: int) -> dict[str, Any]: + model, samples, export_sample = _build_swin_samples(num_samples) + graph_pair = acc_demo._build_graph_pair_fake_quant(model, export_sample) + metrics, worst_sample_idx = _compute_accuracy_metrics_for_pair(graph_pair, samples) + + reference_exporter = FXGraphExporter(graph_pair.reference_graph) + candidate_exporter = FXGraphExporter(graph_pair.candidate_graph) + _add_structural_extensions(reference_exporter) + _add_structural_extensions(candidate_exporter) + acc_demo._add_accuracy_extension(candidate_exporter, metrics) + + reference_payload = reference_exporter.generate_json_payload() + candidate_payload = candidate_exporter.generate_json_payload() + + # Second candidate: different fake-quant seed for 3-graph harness demo + torch.manual_seed(42) + candidate_model_2 = acc_demo._make_fake_quantized_copy(model) + candidate_ep_2 = acc_demo._export_with_debug_handles(candidate_model_2, export_sample) + graph_pair_2 = acc_demo.GraphPair( + pipeline="fake_quant_2", + reference_name="Reference Float", + candidate_name="Candidate Fake-Quantized (seed 42)", + reference_graph=graph_pair.reference_graph, + candidate_graph=candidate_ep_2.module(), + metadata={}, + ) + metrics_2, _ = _compute_accuracy_metrics_for_pair(graph_pair_2, samples) + candidate_exporter_2 = FXGraphExporter(graph_pair_2.candidate_graph) + _add_structural_extensions(candidate_exporter_2) + acc_demo._add_accuracy_extension(candidate_exporter_2, metrics_2) + candidate_payload_2 = candidate_exporter_2.generate_json_payload() + + return { + "profile": "portable", + "model": "swin", + "method": "fake_quant + intermediate_output_capturer", + "worst_sample_index": worst_sample_idx, + "structural": reference_payload, + "accuracy_reference": reference_payload, + "accuracy_candidate": candidate_payload, + "accuracy_candidate_2": candidate_payload_2, + } + + +def _build_qualcomm_payloads( + num_samples: int, + calibration_steps: int, + soc_model: str, + backend: str, +) -> dict[str, Any]: + model, samples, export_sample = _build_swin_samples(num_samples) + calibration_samples = samples[: max(1, calibration_steps)] + + graph_pair = acc_demo._build_graph_pair_qualcomm_ptq( + model=model, + export_sample=export_sample, + calibration_samples=calibration_samples, + soc_model=soc_model, + backend_name=backend, + ) + metrics, worst_sample_idx = _compute_accuracy_metrics_for_pair(graph_pair, samples) + + reference_exporter = FXGraphExporter(graph_pair.reference_graph) + candidate_exporter = FXGraphExporter(graph_pair.candidate_graph) + _add_structural_extensions(reference_exporter) + _add_structural_extensions(candidate_exporter) + acc_demo._add_accuracy_extension(candidate_exporter, metrics) + + reference_payload = reference_exporter.generate_json_payload() + candidate_payload = candidate_exporter.generate_json_payload() + + return { + "profile": "qualcomm", + "model": "swin", + "method": "QnnQuantizer + prepare_pt2e/convert_pt2e + intermediate_output_capturer", + "soc_model": soc_model, + "backend": backend, + "qnn_sdk_root": os.getenv("QNN_SDK_ROOT", ""), + "worst_sample_index": worst_sample_idx, + "structural": reference_payload, + "accuracy_reference": reference_payload, + "accuracy_candidate": candidate_payload, + } + + +def _load_template() -> str: + return (THIS_DIR / "harness_template.html").read_text() + + +def _render_html(payloads: dict[str, Any], testcases: list[dict[str, Any]]) -> str: + template = _load_template() + js_bundle = _load_viewer_js_bundle_local() + payload_json = json.dumps( + { + "meta": { + "profile": payloads.get("profile", "unknown"), + "model": payloads.get("model", "unknown"), + "method": payloads.get("method", "unknown"), + "soc_model": payloads.get("soc_model"), + "backend": payloads.get("backend"), + "qnn_sdk_root": payloads.get("qnn_sdk_root"), + "worst_sample_index": payloads.get("worst_sample_index"), + }, + "structural": payloads["structural"], + "accuracy_reference": payloads["accuracy_reference"], + "accuracy_candidate": payloads["accuracy_candidate"], + "accuracy_candidate_2": payloads.get("accuracy_candidate_2"), + } + ) + testcases_json = json.dumps(testcases) + + out = template.replace("__PAYLOADS_JSON__", payload_json) + out = out.replace("__TEST_CASES_JSON__", testcases_json) + out = out.replace("__VIEWER_JS_BUNDLE__", js_bundle) + return out + + +def _write_harness(output_path: Path, payloads: dict[str, Any], include_qualcomm: bool) -> None: + testcases = build_testcases(include_qualcomm=include_qualcomm) + output_path.write_text(_render_html(payloads, testcases)) + + +def main() -> None: + parser = argparse.ArgumentParser(description="Generate unified fx_viewer API test harnesses") + parser.add_argument( + "--output-dir", + default=str(THIS_DIR), + help="Directory to write generated harness HTML files", + ) + parser.add_argument("--seed", type=int, default=1126) + parser.add_argument("--num-samples", type=int, default=6) + parser.add_argument("--calibration-steps", type=int, default=3) + parser.add_argument("--soc-model", default="SM8650") + parser.add_argument("--backend", choices=["htp", "gpu"], default="htp") + args = parser.parse_args() + + _set_seed(args.seed) + + output_dir = Path(args.output_dir).resolve() + output_dir.mkdir(parents=True, exist_ok=True) + + print("[1/4] Building portable payloads (Swin + fake-quant accuracy)...") + portable = _build_portable_payloads(num_samples=args.num_samples) + + portable_html = output_dir / "fx_viewer_api_test_harness_portable.html" + print("[2/4] Writing portable harness...") + _write_harness(portable_html, portable, include_qualcomm=False) + print(f"Wrote: {portable_html}") + + qualcomm_html = output_dir / "fx_viewer_api_test_harness_qualcomm.html" + print("[3/4] Building Qualcomm payloads (if environment is ready)...") + try: + qualcomm = _build_qualcomm_payloads( + num_samples=args.num_samples, + calibration_steps=args.calibration_steps, + soc_model=args.soc_model, + backend=args.backend, + ) + print("[4/4] Writing Qualcomm harness...") + _write_harness(qualcomm_html, qualcomm, include_qualcomm=True) + print(f"Wrote: {qualcomm_html}") + except Exception as exc: + message = ( + "Qualcomm harness was not generated because environment setup is incomplete " + f"or PTQ build failed:\n{exc}\n\n" + "Portable harness is available and fully functional." + ) + qualcomm_html.write_text( + "
"
+            + message.replace("&", "&").replace("<", "<").replace(">", ">")
+            + "
" + ) + print(f"Wrote fallback note: {qualcomm_html}") + + +if __name__ == "__main__": + main() diff --git a/devtools/fx_viewer/examples/harness_template.html b/devtools/fx_viewer/examples/harness_template.html new file mode 100644 index 00000000000..aaf7cbd6634 --- /dev/null +++ b/devtools/fx_viewer/examples/harness_template.html @@ -0,0 +1,174 @@ + + + + + fx_viewer Unified API Harness + + + +
+
+ fx_viewer Unified API Harness + + + + +
+
+ +
+
+
+
HTML Input (Editable)
+
JS API Input (Editable)
+
Run Log
+
+ +
+
Outcome (Resizable Host)
+
+
+
+
+ + + + + + diff --git a/devtools/fx_viewer/examples/harness_testcases.py b/devtools/fx_viewer/examples/harness_testcases.py new file mode 100644 index 00000000000..334eb1c09ec --- /dev/null +++ b/devtools/fx_viewer/examples/harness_testcases.py @@ -0,0 +1,978 @@ +"""Testcase catalog for the unified fx_viewer API harness. + +This file intentionally orders cases from simple to advanced so the harness doubles +as a tutorial. +""" + +from __future__ import annotations + +from typing import Any + + +def build_testcases(*, include_qualcomm: bool) -> list[dict[str, Any]]: + cases: list[dict[str, Any]] = [ + { + "id": "js_01_create_init_destroy", + "title": "JS 01: Create / Init / Destroy", + "description": "Smallest viewer lifecycle example.", + "html": """ +
+
+ + + Target APIs: create, init, destroy, getState +
+
+
+

+  
+
+""".strip(), + "js": """ +let viewer = null; + +function renderState() { + const stateEl = document.getElementById('c1_state'); + if (!viewer) { + stateEl.textContent = 'viewer = null'; + return; + } + const s = viewer.getState(); + stateEl.textContent = JSON.stringify({ + theme: s.theme, + colorBy: s.colorBy, + selectedNodeId: s.selectedNodeId, + activeExtensions: s.activeExtensions, + }, null, 2); +} + +function createViewer() { + if (viewer) viewer.destroy(); + viewer = FXGraphViewer.create({ + payload: api.payloads.structural, + mount: { root: '#c1_view' }, + layout: { preset: 'split' }, + }); + viewer.init(); + renderState(); + api.log('Created + initialized viewer'); +} + +function destroyViewer() { + if (!viewer) return; + viewer.destroy(); + viewer = null; + renderState(); + api.log('Destroyed viewer'); +} + +document.getElementById('c1_create').addEventListener('click', createViewer); +document.getElementById('c1_destroy').addEventListener('click', destroyViewer); + +createViewer(); +api.setCleanup(() => destroyViewer()); +""".strip(), + }, + { + "id": "js_02_state_theme", + "title": "JS 02: State + Theme", + "description": "Learn getState/setState/setTheme with visible state snapshot.", + "html": """ +
+
+ + + Target APIs: getState, setState, setTheme +
+
+
+

+  
+
+""".strip(), + "js": """ +const viewer = FXGraphViewer.create({ + payload: api.payloads.structural, + mount: { root: '#c2_view' }, + layout: { preset: 'split' }, + state: { theme: 'light' }, +}); +viewer.init(); +api.registerViewer(viewer); + +const stateEl = document.getElementById('c2_state'); +const themeSel = document.getElementById('c2_theme'); + +themeSel.addEventListener('change', () => { + viewer.setTheme(themeSel.value); + renderState(); +}); + +document.getElementById('c2_toggle_highlight').addEventListener('click', () => { + const s = viewer.getState(); + viewer.setState({ highlightAncestors: !s.highlightAncestors }); + renderState(); +}); + +function renderState() { + const s = viewer.getState(); + stateEl.textContent = JSON.stringify({ + theme: s.theme, + highlightAncestors: s.highlightAncestors, + colorBy: s.colorBy, + activeExtensions: s.activeExtensions, + camera: s.camera, + }, null, 2); +} + +renderState(); +api.log('Use theme dropdown and highlight toggle, then inspect getState output.'); +""".strip(), + }, + { + "id": "js_03_selection_camera", + "title": "JS 03: Selection + Camera", + "description": "Control navigation APIs explicitly from custom host buttons.", + "html": """ +
+
+ + + + + + Target APIs: selectNode, animateToNode, panToNode, zoomToFit, clearSelection +
+
+
+""".strip(), + "js": """ +const viewer = FXGraphViewer.create({ + payload: api.payloads.structural, + mount: { root: '#c3_view' }, + layout: { preset: 'split' }, +}); +viewer.init(); +api.registerViewer(viewer); + +const ids = viewer.store.baseData.nodes.map((n) => n.id); +const firstId = ids[0]; +const midId = ids[Math.floor(ids.length / 2)]; +const lastId = ids[ids.length - 1]; + +document.getElementById('c3_select_first').addEventListener('click', () => { + viewer.selectNode(firstId, { center: true }); +}); + +document.getElementById('c3_select_mid').addEventListener('click', () => { + viewer.selectNode(midId, { animate: true, center: true }); +}); + +document.getElementById('c3_pan_last').addEventListener('click', () => { + viewer.panToNode(lastId); +}); + +document.getElementById('c3_zoom_fit').addEventListener('click', () => viewer.zoomToFit()); +document.getElementById('c3_clear').addEventListener('click', () => viewer.clearSelection()); + +api.log('Use buttons to see how camera + selection APIs differ.'); +""".strip(), + }, + { + "id": "js_04_layers_colorby", + "title": "JS 04: Layers + ColorBy", + "description": "Learn extension activation and color source switching.", + "html": """ +
+
+
Layer Controls
+ + +
+ +
+
+
+""".strip(), + "js": """ +const viewer = FXGraphViewer.create({ + payload: api.payloads.structural, + mount: { root: '#c4_view' }, + layout: { preset: 'split' }, + state: { activeExtensions: ['color_by_type', 'topological_order'], colorBy: 'topological_order' }, +}); +viewer.init(); +api.registerViewer(viewer); + +const layerType = document.getElementById('c4_layer_type'); +const layerTopo = document.getElementById('c4_layer_topo'); +const colorBy = document.getElementById('c4_colorby'); + +function applyLayers() { + const layers = []; + if (layerType.checked) layers.push('color_by_type'); + if (layerTopo.checked) layers.push('topological_order'); + viewer.setLayers(layers); +} + +layerType.addEventListener('change', applyLayers); +layerTopo.addEventListener('change', applyLayers); +colorBy.addEventListener('change', () => viewer.setColorBy(colorBy.value)); + +api.log('Toggle layers and colorBy to observe legend/canvas updates.'); +""".strip(), + }, + { + "id": "js_05_runtime_mutation", + "title": "JS 05: Runtime Layer Mutation", + "description": "Create, patch, recolor, relabel, and remove a dynamic layer at runtime.", + "html": """ +
+
+
Runtime Mutation Controls
+ +
+
+ + + +
+
+
+""".strip(), + "js": """ +const layerId = 'runtime_score'; +const viewer = FXGraphViewer.create({ + payload: api.payloads.structural, + mount: { root: '#c5_view' }, + layout: { preset: 'split' }, + state: { activeExtensions: ['color_by_type', 'topological_order'], colorBy: 'base' }, +}); +viewer.init(); +api.registerViewer(viewer); + +const allNodes = viewer.store.baseData.nodes.slice(0, 140); +const runtimeNodes = {}; +allNodes.forEach((n, idx) => { + runtimeNodes[n.id] = { + info: { runtime_score: idx / Math.max(1, allNodes.length - 1) }, + label_append: [`r=${(idx / Math.max(1, allNodes.length - 1)).toFixed(2)}`], + fill_color: '#93c5fd', + }; +}); + +viewer.upsertLayer(layerId, { + name: 'Runtime Score', + legend: [ + { label: 'low', color: '#93c5fd' }, + { label: 'high', color: '#b91c1c' }, + ], + nodes: runtimeNodes, +}); +viewer.setLayers(['color_by_type', 'topological_order', layerId]); +viewer.setColorBy(layerId); + +const slider = document.getElementById('c5_threshold'); +const valueEl = document.getElementById('c5_threshold_value'); + +function applyThreshold() { + const t = Number(slider.value) / 100; + valueEl.textContent = `threshold=${t.toFixed(2)}`; + const patch = {}; + Object.entries(runtimeNodes).forEach(([nodeId, nodeData]) => { + const score = Number((nodeData.info && nodeData.info.runtime_score) || 0); + patch[nodeId] = { + fill_color: score >= t ? '#b91c1c' : '#93c5fd', + label_append: [`r=${score.toFixed(2)}`], + }; + }); + viewer.patchLayerNodes(layerId, patch); + viewer.setColorBy(layerId); +} + +slider.addEventListener('input', applyThreshold); +applyThreshold(); + +document.getElementById('c5_rename').addEventListener('click', () => { + viewer.setLayerLabel(layerId, 'Runtime Score (renamed)'); +}); + +document.getElementById('c5_rule').addEventListener('click', () => { + viewer.setColorRule(layerId, (nodeData) => { + const s = Number((nodeData.info && nodeData.info.runtime_score) || 0); + return s > 0.7 ? '#14532d' : '#fef08a'; + }); + viewer.setColorBy(layerId); +}); + +document.getElementById('c5_remove').addEventListener('click', () => { + viewer.removeLayer(layerId); + viewer.setColorBy('base'); +}); + +api.log('This case targets runtime mutation APIs end-to-end.'); +""".strip(), + }, + { + "id": "js_06_layout_slots", + "title": "JS 06: Layout + External Slots", + "description": "Mount viewer pieces into external host divs and control layout/UI visibility.", + "html": """ +
+
+ + + + Target APIs: mount.slots, setLayout, setUIVisibility +
+ +
+
+
+
+
+
+
+
+
+
+
+
+""".strip(), + "js": """ +let showInfo = true; +let showMinimap = true; +let showChrome = true; + +const viewer = FXGraphViewer.create({ + payload: api.payloads.structural, + mount: { + root: '#c6_hidden_root', + slots: { + canvas: '#c6_canvas', + toolbar: '#c6_toolbar', + info: '#c6_info', + minimap: '#c6_minimap', + legend: '#c6_legend', + }, + }, + layout: { + preset: 'headless', + panels: { + info: { visible: true }, + minimap: { visible: true, height: 220, resizable: false }, + legend: { visible: true }, + }, + }, + ui: { + controls: { + toolbar: true, + search: true, + layers: true, + colorBy: true, + theme: true, + legend: true, + zoomButtons: true, + highlightButton: true, + fullscreenButton: false, + }, + }, + state: { activeExtensions: ['topological_order'], colorBy: 'topological_order' }, +}); +viewer.init(); +api.registerViewer(viewer); + +document.getElementById('c6_toggle_info').addEventListener('click', () => { + showInfo = !showInfo; + viewer.setLayout({ panels: { info: { visible: showInfo } } }); +}); + +document.getElementById('c6_toggle_minimap').addEventListener('click', () => { + showMinimap = !showMinimap; + viewer.setLayout({ panels: { minimap: { visible: showMinimap } } }); +}); + +document.getElementById('c6_toggle_chrome').addEventListener('click', () => { + showChrome = !showChrome; + viewer.setUIVisibility({ + toolbar: showChrome, + search: showChrome, + layers: showChrome, + theme: showChrome, + }); +}); + +api.log('This case demonstrates host-owned slots and runtime layout toggles.'); +""".strip(), + }, + { + "id": "js_07_events", + "title": "JS 07: Events and Subscriptions", + "description": "Observe and unsubscribe viewer events from host code.", + "html": """ +
+
+ + + + + +
+
+
+

+  
+
+""".strip(), + "js": """ +const viewer = FXGraphViewer.create({ + payload: api.payloads.structural, + mount: { root: '#c7_view' }, + layout: { preset: 'split' }, +}); +viewer.init(); +api.registerViewer(viewer); + +const logEl = document.getElementById('c7_log'); +const firstNode = viewer.store.baseData.nodes[0].id; +let minimapVisible = true; + +function log(line) { + logEl.textContent = `${line}\n${logEl.textContent}`.slice(0, 3000); +} + +const offState = viewer.on('statechange', (evt) => log(`statechange source=${evt.source}`)); +const offSel = viewer.on('selectionchange', (evt) => log(`selection ${evt.prevSelection} -> ${evt.nextSelection}`)); +const offTheme = viewer.on('themechange', (evt) => log(`theme ${evt.prevTheme} -> ${evt.nextTheme}`)); +const offLayout = viewer.on('layoutchange', () => log('layoutchange')); + +document.getElementById('c7_theme').addEventListener('click', () => { + viewer.setTheme(viewer.getState().theme === 'light' ? 'dark' : 'light'); +}); +document.getElementById('c7_select').addEventListener('click', () => viewer.selectNode(firstNode, { animate: true, center: true })); +document.getElementById('c7_clear').addEventListener('click', () => viewer.clearSelection()); +document.getElementById('c7_layout').addEventListener('click', () => { + minimapVisible = !minimapVisible; + viewer.setLayout({ panels: { minimap: { visible: minimapVisible } } }); +}); +document.getElementById('c7_unsub').addEventListener('click', () => { + offState(); offSel(); offTheme(); offLayout(); + log('All event listeners unsubscribed'); +}); + +api.log('Trigger controls and inspect event stream in the right log panel.'); +""".strip(), + }, + { + "id": "js_08_compare_basics", + "title": "JS 08: 3-Graph Compare + Auto debug_handle Sync", + "description": "Three-view compare (Reference, Candidate A, Candidate B) using Map API. Default sync is 'Auto (handle→id)' — tries debug_handle first, falls back to node name.", + "html": """ +
+
+ 3-graph compare with auto debug_handle sync. Sidebar: Auto (handle→id) | ID only | Ext: per_layer_accuracy.debug_handle | Don't sync +
+
+ + + +
+
+""".strip(), + "js": """ +let compare = null; + +function buildCompare() { + if (compare) { compare.destroy(); compare = null; } + + const ref = FXGraphViewer.create({ + payload: api.payloads.accuracy_reference, + mount: { root: '#c8_ref_mount' }, + layout: { preset: 'split' }, + state: { activeExtensions: ['per_layer_accuracy', 'color_by_type'], colorBy: 'color_by_type' }, + }); + ref.init(); + api.registerViewer(ref); + + const cand1 = FXGraphViewer.create({ + payload: api.payloads.accuracy_candidate, + mount: { root: '#c8_cand1_mount' }, + layout: { preset: 'split' }, + state: { activeExtensions: ['per_layer_accuracy'], colorBy: 'per_layer_accuracy' }, + }); + cand1.init(); + api.registerViewer(cand1); + + const cand2Payload = api.payloads.accuracy_candidate_2 || api.payloads.accuracy_candidate; + const cand2 = FXGraphViewer.create({ + payload: cand2Payload, + mount: { root: '#c8_cand2_mount' }, + layout: { preset: 'split' }, + state: { activeExtensions: ['per_layer_accuracy'], colorBy: 'per_layer_accuracy' }, + }); + cand2.init(); + api.registerViewer(cand2); + + compare = FXGraphCompare.create({ + viewers: new Map([ + ['Reference', ref], + ['Candidate A', cand1], + ['Candidate B', cand2], + ]), + layout: { container: '#c8_grid' }, + // sync defaults to { mode: 'auto' } + }); + api.registerCompare(compare); + + api.log('3-graph compare ready. Sidebar shows Auto (handle→id) sync by default.'); +} + +buildCompare(); +api.setCleanup(() => { if (compare) compare.destroy(); }); +""".strip(), + }, + { + "id": "adv_01_accuracy_dynamic", + "title": "ADV 01: Per-layer Accuracy Controls", + "description": "Interesting combo: real per-layer metrics + dynamic threshold + theme + focus.", + "html": """ +
+
+
Accuracy Controls
+ +
+ +
+ +
+
+
+""".strip(), + "js": """ +const viewer = FXGraphViewer.create({ + payload: api.payloads.accuracy_candidate, + mount: { root: '#acc_view' }, + layout: { preset: 'split', panels: { sidebar: { width: 420 } } }, + state: { + activeExtensions: ['per_layer_accuracy', 'topological_order', 'color_by_type'], + colorBy: 'per_layer_accuracy', + theme: 'light', + }, +}); +viewer.init(); +api.registerViewer(viewer); + +const extId = 'per_layer_accuracy'; +const nodes = viewer.store.extensions[extId].nodes; +const severities = Object.values(nodes) + .map((n) => Number((n.info && n.info.severity_score) || 0)) + .filter(Number.isFinite) + .sort((a, b) => a - b); + +const slider = document.getElementById('acc_threshold'); +const label = document.getElementById('acc_threshold_value'); +const themeSel = document.getElementById('acc_theme'); +const focusBtn = document.getElementById('acc_focus_worst'); + +function quantile(q) { + if (severities.length === 0) return 0; + const i = Math.min(severities.length - 1, Math.max(0, Math.floor(q * (severities.length - 1)))); + return severities[i]; +} + +function applyThreshold() { + const p = Number(slider.value) / 100; + const threshold = quantile(p); + label.textContent = `percentile=${slider.value}, threshold=${threshold.toExponential(3)}`; + const patch = {}; + Object.entries(nodes).forEach(([nodeId, nodeData]) => { + const s = Number((nodeData.info && nodeData.info.severity_score) || 0); + patch[nodeId] = { + fill_color: s >= threshold ? '#991b1b' : '#fecaca', + label_append: [`sev=${s.toExponential(2)}`], + }; + }); + viewer.patchLayerNodes(extId, patch); + viewer.setColorBy(extId); +} + +slider.addEventListener('input', applyThreshold); +themeSel.addEventListener('change', () => viewer.setTheme(themeSel.value)); +focusBtn.addEventListener('click', () => { + let worst = null; + let worstScore = -Infinity; + Object.entries(nodes).forEach(([nodeId, nodeData]) => { + const s = Number((nodeData.info && nodeData.info.severity_score) || 0); + if (s > worstScore) { + worstScore = s; + worst = nodeId; + } + }); + if (worst) viewer.selectNode(worst, { animate: true, center: true }); +}); + +applyThreshold(); +api.log(`Loaded real accuracy payload. worst_sample_index=${api.payloads.meta.worst_sample_index}`); +""".strip(), + }, + { + "id": "adv_02_headless_slots_slider", + "title": "ADV 02: Headless Slots + Slider", + "description": "Interesting combo: custom host layout + external slots + dynamic recoloring.", + "html": """ +
+ +
+
+
Custom Controls
+ +
+
+
+
+
+
+
+
+
+
+
+
+""".strip(), + "js": """ +const viewer = FXGraphViewer.create({ + payload: api.payloads.structural, + mount: { + root: '#adv2_headless_mount', + slots: { + canvas: '#adv2_slot_canvas', + info: '#adv2_slot_info', + minimap: '#adv2_slot_minimap', + legend: '#adv2_slot_legend', + }, + }, + layout: { + preset: 'headless', + panels: { + minimap: { visible: true, height: 220, resizable: false }, + info: { visible: true }, + legend: { visible: true }, + }, + }, + ui: { + controls: { + toolbar: false, + search: false, + layers: false, + colorBy: false, + theme: false, + legend: true, + zoomButtons: false, + highlightButton: false, + fullscreenButton: false, + }, + }, + state: { activeExtensions: ['topological_order'], colorBy: 'topological_order' }, +}); +viewer.init(); +api.registerViewer(viewer); + +const slider = document.getElementById('adv2_topo_threshold'); +const valueEl = document.getElementById('adv2_topo_threshold_value'); +const nodes = viewer.store.extensions['topological_order'].nodes; +const maxTopo = Math.max(...Object.values(nodes).map((n) => Number(n.info.topo_index || 0))); +slider.max = String(maxTopo); + +function renderThreshold() { + const threshold = Number(slider.value); + valueEl.textContent = `threshold=${threshold} / ${maxTopo}`; + const patch = {}; + Object.entries(nodes).forEach(([nodeId, nodeData]) => { + const idx = Number(nodeData.info.topo_index || 0); + patch[nodeId] = { fill_color: idx >= threshold ? '#b91c1c' : '#93c5fd' }; + }); + viewer.patchLayerNodes('topological_order', patch); + viewer.setColorBy('topological_order'); +} + +slider.addEventListener('input', renderThreshold); +renderThreshold(); +api.log('Headless slot composition active. Move slider and inspect recoloring.'); +""".strip(), + }, + { + "id": "adv_03_fullscreen_toolbar", + "title": "ADV 03: Fullscreen + Toolbar API", + "description": "Interesting combo: fullscreen button in toolbar + direct fullscreen APIs.", + "html": """ +
+
+
Fullscreen Controls
+ + +

Taskbar also has a fullscreen toggle button in this case.

+
+
+
+""".strip(), + "js": """ +const viewer = FXGraphViewer.create({ + payload: api.payloads.structural, + mount: { root: '#adv3_view' }, + layout: { preset: 'split', fullscreen: { enabled: true, button: true } }, + state: { activeExtensions: ['color_by_type'], colorBy: 'color_by_type' }, +}); +viewer.init(); +api.registerViewer(viewer); + +document.getElementById('adv3_enter_fs').addEventListener('click', () => viewer.enterFullscreen()); +document.getElementById('adv3_exit_fs').addEventListener('click', () => viewer.exitFullscreen()); + +api.log('Use taskbar fullscreen button or side controls to validate API + UI integration.'); +""".strip(), + }, + { + "id": "adv_04_tiled_compare", + "title": "ADV 04: 3-Graph Compare + Extension Sync Key", + "description": "Three-graph compare demonstrating set_sync_key('debug_handle') on per_layer_accuracy. Sidebar shows 'Ext: per_layer_accuracy.debug_handle' as an explicit sync option.", + "html": """ +
+ + + +
+""".strip(), + "js": """ +const ref = FXGraphViewer.create({ + payload: api.payloads.accuracy_reference, + mount: { root: '#adv4_ref_mount' }, + layout: { preset: 'split' }, + state: { activeExtensions: ['per_layer_accuracy', 'topological_order'], colorBy: 'topological_order' }, +}); +ref.init(); +api.registerViewer(ref); + +const left = FXGraphViewer.create({ + payload: api.payloads.accuracy_candidate, + mount: { root: '#adv4_left_mount' }, + layout: { preset: 'split' }, + state: { + activeExtensions: ['per_layer_accuracy', 'topological_order'], + colorBy: 'per_layer_accuracy', + }, +}); +left.init(); +api.registerViewer(left); + +const cand2Payload = api.payloads.accuracy_candidate_2 || api.payloads.accuracy_candidate; +const right = FXGraphViewer.create({ + payload: cand2Payload, + mount: { root: '#adv4_right_mount' }, + layout: { preset: 'split' }, + state: { + activeExtensions: ['per_layer_accuracy', 'topological_order'], + colorBy: 'per_layer_accuracy', + }, +}); +right.init(); +api.registerViewer(right); + +const compare = FXGraphCompare.create({ + viewers: new Map([ + ['Reference', ref], + ['Candidate A', left], + ['Candidate B', right], + ]), + layout: { container: '#adv4_grid' }, + sync: { mode: 'layer', layer: 'per_layer_accuracy', field: 'debug_handle' }, +}); +api.registerCompare(compare); + +api.log('ADV04: 3-graph compare with extension sync key per_layer_accuracy.debug_handle active.'); +""".strip(), + }, + { + "id": "js_99_combo_mixed", + "title": "JS 99: Mixed Combo Demo", + "description": "Current mixed demo: compare + sync + runtime mutation + events + themed controls.", + "html": """ +
+
+
Combo Controls
+ +
+ +
+
+
+
+ + +
+

+  
+
+ + +
+
+""".strip(), + "js": """ +const left = FXGraphViewer.create({ + payload: api.payloads.accuracy_reference, + mount: { root: '#c99_left_mount' }, + layout: { preset: 'split' }, + state: { activeExtensions: ['color_by_type'], colorBy: 'color_by_type', theme: 'light' }, +}); +left.init(); +api.registerViewer(left); + +const right = FXGraphViewer.create({ + payload: api.payloads.accuracy_candidate, + mount: { root: '#c99_right_mount' }, + layout: { preset: 'split' }, + state: { + activeExtensions: ['per_layer_accuracy', 'topological_order', 'color_by_type'], + colorBy: 'per_layer_accuracy', + theme: 'light', + }, +}); +right.init(); +api.registerViewer(right); + +const compare = FXGraphCompare.create({ + viewers: [left, right], + layout: { columns: 2, container: '#c99_grid' }, + sync: { mode: 'id' }, +}); +api.registerCompare(compare); + +const logEl = document.getElementById('c99_log'); +function log(msg) { + logEl.textContent = `${new Date().toLocaleTimeString()} ${msg}\n${logEl.textContent}`.slice(0, 4000); +} + +const offSel = right.on('selectionchange', (evt) => log(`selection ${evt.prevSelection} -> ${evt.nextSelection}`)); +const offTheme = right.on('themechange', (evt) => log(`theme ${evt.prevTheme} -> ${evt.nextTheme}`)); + +api.setCleanup(() => { + offSel(); + offTheme(); +}); + +const extId = 'per_layer_accuracy'; +const nodes = right.store.extensions[extId].nodes; +const severities = Object.values(nodes) + .map((n) => Number((n.info && n.info.severity_score) || 0)) + .filter(Number.isFinite) + .sort((a, b) => a - b); + +function quantile(q) { + if (severities.length === 0) return 0; + const i = Math.min(severities.length - 1, Math.max(0, Math.floor(q * (severities.length - 1)))); + return severities[i]; +} + +function applyThreshold() { + const slider = document.getElementById('c99_threshold'); + const threshold = quantile(Number(slider.value) / 100); + document.getElementById('c99_threshold_text').textContent = + `percentile=${slider.value} threshold=${threshold.toExponential(3)}`; + const patch = {}; + Object.entries(nodes).forEach(([nodeId, nodeData]) => { + const s = Number((nodeData.info && nodeData.info.severity_score) || 0); + patch[nodeId] = { + fill_color: s >= threshold ? '#991b1b' : '#fecaca', + label_append: [`sev=${s.toExponential(2)}`], + }; + }); + right.patchLayerNodes(extId, patch); + right.setColorBy(extId); +} + +function focusWorst() { + let worst = null; + let worstScore = -Infinity; + Object.entries(nodes).forEach(([nodeId, nodeData]) => { + const s = Number((nodeData.info && nodeData.info.severity_score) || 0); + if (s > worstScore) { + worstScore = s; + worst = nodeId; + } + }); + if (worst) { + right.selectNode(worst, { animate: true, center: true }); + log(`focus worst node=${worst} score=${worstScore.toExponential(3)}`); + } +} + +document.getElementById('c99_theme').addEventListener('change', (e) => { + left.setTheme(e.target.value); + right.setTheme(e.target.value); +}); +document.getElementById('c99_threshold').addEventListener('input', applyThreshold); +document.getElementById('c99_sync_sel').addEventListener('change', (e) => compare.setSync({ mode: e.target.checked ? 'id' : 'none' })); +document.getElementById('c99_focus_worst').addEventListener('click', focusWorst); +document.getElementById('c99_sequence').addEventListener('click', () => { + log('scripted sequence start'); + left.setTheme('dark'); + applyThreshold(); + focusWorst(); + left.zoomToFit(); + right.zoomToFit(); + log('scripted sequence done'); +}); + +applyThreshold(); +log('Mixed combo demo ready.'); +api.log('Final mixed demo: compare + sync + mutation + events + controls.'); +""".strip(), + }, + ] + + if include_qualcomm: + cases.append( + { + "id": "qualcomm_metadata", + "title": "QUALCOMM: PTQ Metadata", + "description": "Qualcomm-specific payload metadata from real QNN PTQ path.", + "html": """ +
+
+

Qualcomm Metadata

+

+  
+
+
+""".strip(), + "js": """ +const viewer = FXGraphViewer.create({ + payload: api.payloads.accuracy_candidate, + mount: { root: '#qnn_view' }, + layout: { preset: 'split' }, + state: { activeExtensions: ['per_layer_accuracy'], colorBy: 'per_layer_accuracy' }, +}); +viewer.init(); +api.registerViewer(viewer); + +document.getElementById('qnn_meta').textContent = JSON.stringify(api.payloads.meta, null, 2); +api.log('Rendered Qualcomm PTQ payload + metadata snapshot.'); +""".strip(), + } + ) + + return cases diff --git a/devtools/fx_viewer/exporter.py b/devtools/fx_viewer/exporter.py new file mode 100644 index 00000000000..11f0158932e --- /dev/null +++ b/devtools/fx_viewer/exporter.py @@ -0,0 +1,888 @@ +from __future__ import annotations + +import copy +import json +import os +import warnings +from dataclasses import asdict +from typing import Callable, List, Optional, Any, Dict, Sequence + +import torch +import torch.fx + +from .color_rules import ColorRule +from .extension import GraphExtension + +try: + from executorch.exir.dialects.edge._ops import EdgeOpOverload +except ImportError: + EdgeOpOverload = None +from .models import ( + BaseGraphPayload, + GraphEdge, + GraphExtensionPayload, + GraphNode, + GraphPayload, +) + + +class FXGraphExporter: + """Export PyTorch FX graphs to JSON/JS/HTML payloads for the viewer. + + The exporter extracts node metadata from ``fx_node.meta`` into ``node.info``. + Scalar meta values (str, int, float, bool) are included automatically. + ``debug_handle`` is handled explicitly to support both scalar (int) and + fused/tuple forms (tuple[int, ...] / list[int]): + - int → stored as int + - tuple/list with one non-zero element → stored as int + - tuple/list with multiple non-zero elements → stored as list[int] + This ensures ``node.info.debug_handle`` is always present and usable by the + JS compare sync engine (``mode: 'auto'`` set-intersection matching). + """ + + def __init__(self, graph_module: torch.fx.GraphModule): + self.graph_module = graph_module + self.extensions: List[GraphExtension] = [] + + self.base_label_formatter: Callable[[GraphNode], str] = self._default_base_label + self.base_tooltip_formatter: Callable[[GraphNode], List[str]] = self._default_base_tooltip + self.base_color_rule: Optional[ColorRule] = None + + _NODE_CHAR_WIDTH = 7 + _NODE_MIN_WIDTH = 100 + _NODE_X_PADDING = 20 + _NODE_LINE_HEIGHT = 16 + _NODE_Y_PADDING = 20 + _LAYOUT_XSPACE = 50 + _LAYOUT_YSPACE = 30 + _DUMMY_SIZE_X = 100 # dummy nodes (from fast-sugiyama) occupy no real width/height + _DUMMY_SIZE_Y = 30 # dummy nodes (from fast-sugiyama) occupy no real width/height + _SPINE_COHESION_ITER = 20 + + def _default_base_label(self, node: GraphNode) -> str: + target = str(node.info.get("target") or node.info.get("op") or "") + return target.replace("aten.", "").replace(".default", "") + + def _default_base_tooltip(self, node: GraphNode) -> List[str]: + lines = [ + f"Name: {node.info.get('name', 'n/a')}", + f"Op: {node.info.get('op', 'n/a')}", + f"Target: {node.info.get('target', 'n/a')}", + ] + return lines + + def set_base_label_formatter(self, formatter: Callable[[GraphNode], str]): + self.base_label_formatter = formatter + + def set_base_tooltip_formatter(self, formatter: Callable[[GraphNode], List[str]]): + self.base_tooltip_formatter = formatter + + def set_base_color_rule(self, rule: ColorRule): + self.base_color_rule = rule + + def add_extension(self, extension: GraphExtension): + if not isinstance(extension, GraphExtension): + raise TypeError("extension must be a GraphExtension") + if any(ext.id == extension.id for ext in self.extensions): + raise ValueError(f"duplicate extension id: '{extension.id}'") + self.extensions.append(extension) + + @staticmethod + def _get_from_node_root_name(from_node_list): + """Walk from_node chain to root, return root node name.""" + if not from_node_list: + return None + ns = from_node_list[-1] + while getattr(ns, "from_node", None): + ns = ns.from_node[-1] + return getattr(ns, "name", None) + + @staticmethod + def _format_arg(arg): + if isinstance(arg, torch.fx.Node): + return arg.name + if isinstance(arg, (list, tuple)): + return type(arg)(FXGraphExporter._format_arg(a) for a in arg) + if isinstance(arg, dict): + return {k: FXGraphExporter._format_arg(v) for k, v in arg.items()} + return str(arg) + + def _extract_graph(self) -> tuple[dict[str, GraphNode], list[GraphEdge]]: + # torch.fx.Graph.nodes iterates in topological order (documented guarantee). + print("Building graph payload model...") + nodes: dict[str, GraphNode] = {} + edges: list[GraphEdge] = [] + + for idx, fx_node in enumerate(self.graph_module.graph.nodes): + target = fx_node.target + schema = None + if EdgeOpOverload is not None and isinstance(target, EdgeOpOverload): + target_str = target.__name__ + schema = target._schema.schema + elif isinstance(target, torch._ops.OpOverload): + target_str = str(target) + schema = getattr(target, "_schema", None) + else: + target_str = str(target) + + info: dict[str, Any] = { + "op": fx_node.op, + "name": fx_node.name, + "target": target_str, + "args": self._format_arg(fx_node.args), + "kwargs": self._format_arg(fx_node.kwargs), + } + + if schema is not None: + info["schema"] = str(schema) + pos_schema_args = [ + a for a in schema.arguments if not a.kwarg_only + ] + formatted_args = self._format_arg(fx_node.args) + if not isinstance(formatted_args, (list, tuple)): + formatted_args = (formatted_args,) if formatted_args else () + named = {} + for i, val in enumerate(formatted_args): + name = ( + pos_schema_args[i].name + if i < len(pos_schema_args) + else f"arg_{i}" + ) + named[name] = val + formatted_kwargs = self._format_arg(fx_node.kwargs) + if isinstance(formatted_kwargs, dict): + named.update(formatted_kwargs) + info["named_args"] = named + + if "tensor_meta" in fx_node.meta: + tm = fx_node.meta["tensor_meta"] + if isinstance(tm, list): + info["tensor_shape"] = [tuple(t.shape) if hasattr(t, "shape") else None for t in tm] + info["dtype"] = [str(t.dtype) if hasattr(t, "dtype") else None for t in tm] + elif hasattr(tm, "shape"): + info["tensor_shape"] = tuple(tm.shape) + info["dtype"] = str(tm.dtype) if hasattr(tm, "dtype") else None + + for key, value in fx_node.meta.items(): + if key != "tensor_meta" and isinstance(value, (str, int, float, bool)): + info[key] = value + + # Explicitly handle debug_handle (may be int or tuple — not caught by scalar loop) + raw_dh = fx_node.meta.get("debug_handle") + if raw_dh is not None and raw_dh != () and raw_dh != []: + if isinstance(raw_dh, int): + info["debug_handle"] = raw_dh + elif isinstance(raw_dh, (tuple, list)): + ints = [int(x) for x in raw_dh if isinstance(x, int) and x != 0] + if ints: + info["debug_handle"] = ints[0] if len(ints) == 1 else ints + + raw_fn = fx_node.meta.get("from_node") + if raw_fn and isinstance(raw_fn, list) and len(raw_fn) > 0: + root_name = self._get_from_node_root_name(raw_fn) + if root_name: + info["from_node_root"] = root_name + + nodes[fx_node.name] = GraphNode(id=fx_node.name, topo_index=idx, info=info) + + for input_node in fx_node.all_input_nodes: + edges.append(GraphEdge(v=input_node.name, w=fx_node.name)) + + return nodes, edges + + @staticmethod + def _validate_str_list(value: Any, *, context: str) -> list[str]: + if not isinstance(value, list) or any(not isinstance(x, str) for x in value): + warnings.warn(f"{context} must return list[str]", RuntimeWarning, stacklevel=2) + return [] + return value + + def _safe_base_label(self, node: GraphNode) -> str: + label = self.base_label_formatter(node) + if not isinstance(label, str): + warnings.warn( + f"base_label_formatter returned non-str for node '{node.id}', coercing to str", + RuntimeWarning, + stacklevel=2, + ) + return str(label) + return label + + def _safe_base_tooltip(self, node: GraphNode) -> list[str]: + try: + value = self.base_tooltip_formatter(node) + except Exception as exc: + warnings.warn( + f"base_tooltip_formatter failed for node '{node.id}': {exc}", + RuntimeWarning, + stacklevel=2, + ) + return [] + return self._validate_str_list(value, context=f"base_tooltip_formatter(node='{node.id}')") + + def _ext_label_lines_for_layout(self, extension: GraphExtension, node_id: str) -> list[str]: + if not extension.label_formatter or node_id not in extension.nodes_data: + return [] + try: + result = extension.label_formatter(extension.nodes_data[node_id]) + except Exception as exc: + warnings.warn( + f"Extension '{extension.id}' label formatter failed for node '{node_id}' during layout: {exc}", + RuntimeWarning, + stacklevel=2, + ) + return [] + return self._validate_str_list( + result, + context=f"extension '{extension.id}' label formatter(node='{node_id}')", + ) + + @classmethod + def _compute_node_box_size( + cls, + base_label: str, + extension_lines: Sequence[str] | None = None, + ) -> tuple[int, int]: + max_char_width = len(base_label or "") + total_lines = 1 + + if extension_lines: + for line in extension_lines: + max_char_width = max(max_char_width, len(line)) + total_lines += 1 + + width = max(max_char_width * cls._NODE_CHAR_WIDTH + cls._NODE_X_PADDING, cls._NODE_MIN_WIDTH) + height = total_lines * cls._NODE_LINE_HEIGHT + cls._NODE_Y_PADDING + return width, height + + @classmethod + def _compute_layout_with_ext_lines( + cls, + nodes: dict[str, GraphNode], + edges: list[GraphEdge], + ext_label_lines_by_node: dict[str, list[str]], + base_label_getter: Callable[[GraphNode], str], + ) -> None: + for node_id, node in nodes.items(): + base_label = base_label_getter(node) + ext_lines = ext_label_lines_by_node.get(node.id, []) + node.width, node.height = cls._compute_node_box_size(base_label, ext_lines) + + if not nodes: + return + + try: + from fast_sugiyama import from_edges + except ImportError as exc: + raise ImportError( + "fx_viewer layout requires 'fast-sugiyama' (and rectangle-packer " + "for multi-component packing). Install with: " + "pip install 'fast-sugiyama[all]' (requires Python >= 3.11)" + ) from exc + + print("Running fast-sugiyama layout...") + edge_list = [(e.v, e.w) for e in edges if e.v in nodes and e.w in nodes] + widths = [n.width for n in nodes.values()] + # Median width as vertex_spacing keeps the layout tight for typical + # nodes; wide outliers are handled by the per-layer compaction below. + from statistics import median + baseline_w = median(widths) + max_w = max(widths) + expected_gap = baseline_w + cls._LAYOUT_XSPACE + + raw_layouts = from_edges( + edge_list, + vertex_spacing=expected_gap, + minimum_length=1, + dummy_vertices=True, + crossing_minimization="median", + ) + + # Adaptive per-layer x-spacing + y-row compaction using actual node + # sizes. Preserves fast-sugiyama's within-layer ordering (so crossings + # are preserved) while tightening gaps for narrow nodes and widening + # them for long labels. + adjusted = cls._compact_components(raw_layouts, nodes, expected_gap) + + # Isolated nodes (zero edges in the edge_list) are invisible to + # from_edges; synthesize a one-node component for each so rect_pack + # tiles them into the layout instead of leaving them stacked at 0,0. + referenced: set = set() + for a, b in edge_list: + referenced.add(a) + referenced.add(b) + for nid, node in nodes.items(): + if nid not in referenced: + adjusted.append( + ([(nid, (0.0, 0.0))], float(node.width), float(node.height), []) + ) + + # Pack components. fast-sugiyama's bbox is center-to-center only, so + # pad spacing by max_w + xspace to prevent edge-level overlap between + # adjacent components. + from fast_sugiyama.layout import Layouts + pack_spacing = max_w + cls._LAYOUT_XSPACE + widest_component = max((w for _pos, w, _h, _e in adjusted), default=0.0) + pack_width = int(max(widest_component*3 + pack_spacing, 2000.0)) + layouts = Layouts(adjusted).rect_pack_layouts( + max_width=pack_width, + spacing=pack_spacing, + ) + + positions: dict[Any, tuple[float, float]] = {} + expanded_edges: list[tuple[Any, Any]] = [] + for positions_list, _w, _h, edges_with_dummies in layouts: + positions.update(dict(positions_list)) + if edges_with_dummies: + expanded_edges.extend(edges_with_dummies) + + for node_id, node in nodes.items(): + if node_id in positions: + x, y = positions[node_id] + node.x = float(x) + node.y = float(y) + + edge_map = {(e.v, e.w): e for e in edges} + for (u, v), pts in cls._polylines_from_dummy_chain( + expanded_edges, nodes, positions + ).items(): + if (u, v) not in edge_map: + continue + clipped = cls._clip_edge_polyline(pts, nodes[u], nodes[v]) + edge_map[(u, v)].points = [ + {"x": float(x), "y": float(y)} for (x, y) in clipped + ] + + @classmethod + def _compact_components(cls, layouts, nodes, expected_gap): + from collections import defaultdict + + def _w(nid): + n = nodes.get(nid) + return n.width if n is not None else cls._DUMMY_SIZE_X + + def _h(nid): + n = nodes.get(nid) + return n.height if n is not None else cls._DUMMY_SIZE_Y + + xspace = cls._LAYOUT_XSPACE + yspace = cls._LAYOUT_YSPACE + expected = max(float(expected_gap), 1.0) + + new_layouts = [] + for pos, w, h, el in layouts: + x_orig: dict = {nid: px for nid, (px, _) in pos} + x: dict = dict(x_orig) + + by_y: dict = defaultdict(list) + for nid, (_, py) in pos: + by_y[py].append(nid) + + def _sweep_min_gap(nids): + nids.sort(key=lambda n: x[n]) + for i in range(1, len(nids)): + prev, cur = nids[i - 1], nids[i] + min_gap = (_w(prev) + _w(cur)) / 2 + xspace + if x[cur] - x[prev] < min_gap: + x[cur] = x[prev] + min_gap + + # Phase 1: chain detection (real + dummy members) + chains = cls._detect_chains(el or [], nodes) + + FAIR_RUNS=5 + # Phase 2: iterative spine cohesion + pure-A overlap fix + for i in range(cls._SPINE_COHESION_ITER + FAIR_RUNS): + + # delta for each node + # We record the relative weight (chain length) and their base delta + node_delta: dict = defaultdict(list) + for ch in chains: + if not ch: + continue + mean_x = sum(x[v] for v in ch) / len(ch) + # emphasize end point (disable for last FAIR_RUNS iters) + if i < cls._SPINE_COHESION_ITER: + # we use common start and end node of chain to attract chain close together + mean_x = (mean_x + x[ch[0]] + x[ch[-1]]) / 3.0 + for v in ch: + # weight: len(ch) + # delta: mean_x - x[v] + node_delta[v].append((len(ch), (mean_x - x[v]))) + # move the node x + for n, deltas in node_delta.items(): + total_weight = sum(w for w, _ in deltas) + x[n] += sum(w / total_weight * d for w, d in deltas) + # adjust node overlapping + for nids in by_y.values(): + _sweep_min_gap(nids) + + + # Phase 3: vertical compaction with flipped y so inputs land + # at the top of the canvas and outputs at the bottom. The + # iteration runs largest-original-y first so that rank + # receives new_y = 0 (top of canvas) and deeper ranks get + # monotonically larger new_y values. + distinct_ys = sorted(by_y.keys(), reverse=True) + layer_h = { + y: max((_h(n) for n in by_y[y]), default=0.0) + for y in distinct_ys + } + new_y: dict = {} + cursor = 0.0 + for i, y in enumerate(distinct_ys): + if i == 0: + new_y[y] = cursor + else: + cursor += ( + layer_h[distinct_ys[i - 1]] + layer_h[y] + ) / 2 + yspace + new_y[y] = cursor + + new_positions = [(nid, (x[nid], new_y[py])) for nid, (_, py) in pos] + xs = [xy[0] for _, xy in new_positions] + ys = [xy[1] for _, xy in new_positions] + new_w = (max(xs) - min(xs)) if xs else w + new_h = (max(ys) - min(ys)) if ys else h + new_layouts.append((new_positions, new_w, new_h, el)) + return new_layouts + + @classmethod + def _detect_chains(cls, edge_list, nodes): + # Here we break the graph into chains (connected node list) + # The longer the chain the better (aligned visual vertical axis) + # We achieve long chain by calculating best_prev and best_succ with rank + # We must let the chain start and end node to be shared + # The shared end points (common nodes) will be used to pull chains near in later iterative loop + from collections import defaultdict + + if not edge_list: + return [] + + succ: dict = defaultdict(set) + prev: dict = defaultdict(set) + for u, v in edge_list: + succ[u].add(v) + prev[v].add(u) + + all_nodes = set(succ) | set(prev) + + # max depth a node's output can reach + node_out_rank: dict = defaultdict(int) + # the succer node that have maximal node_out_rank + best_succ: dict = {} + graph_output_nodes = [n for n in all_nodes if len(succ[n]) == 0] + stack = graph_output_nodes + while stack: + n = stack.pop() + for pn in prev[n]: + score = 2 if pn in nodes else 1 + if node_out_rank[pn] < node_out_rank[n] + score: + node_out_rank[pn] = node_out_rank[n] + score + stack.append(pn) + best_succ[pn] = n + + # max depth a node's input can reach + node_in_rank: dict = defaultdict(int) + # the prev node that have maximal node_in_rank + best_prev: dict = {} + graph_input_nodes = [n for n in all_nodes if len(prev[n]) == 0] + stack = graph_input_nodes + while stack: + n = stack.pop() + for nn in succ[n]: + score = 2 if nn in nodes else 1 + if node_in_rank[nn] < node_in_rank[n] + score: + node_in_rank[nn] = node_in_rank[n] + score + stack.append(nn) + best_prev[nn] = n + + + visited: set = set() + chains: list = [] + for start in sorted(all_nodes, key=lambda n:node_out_rank[n], reverse=True): + if start in visited: + continue + cur = start + walk = [start] + if start in best_prev: + walk.insert(0, best_prev[start]) + while cur not in visited: + visited.add(cur) + if cur not in best_succ: + break + nxt = best_succ[cur] + walk.append(nxt) + cur = nxt + if len(walk) >= 2: # always true + chains.append(walk) + return chains + + @staticmethod + def _polylines_from_dummy_chain( + expanded_edges: list[tuple[Any, Any]], + nodes: dict[str, GraphNode], + positions: dict[Any, tuple[float, float]], + ) -> dict[tuple[Any, Any], list[tuple[float, float]]]: + forward: dict[Any, Any] = {} + for u, v in expanded_edges: + forward[u] = v + + polylines: dict[tuple[Any, Any], list[tuple[float, float]]] = {} + for u, v in expanded_edges: + if u not in nodes: + continue + chain: list[tuple[float, float]] = [] + cur = v + while cur not in nodes: + if cur not in positions: + break + chain.append(positions[cur]) + if cur not in forward: + break + cur = forward[cur] + if cur in nodes and u in positions and cur in positions: + polylines[(u, cur)] = [positions[u], *chain, positions[cur]] + return polylines + + @staticmethod + def _clip_point_to_aabb( + center: tuple[float, float], + half: tuple[float, float], + toward: tuple[float, float], + ) -> tuple[float, float]: + cx, cy = center + hw, hh = half + dx = toward[0] - cx + dy = toward[1] - cy + if dx == 0.0 and dy == 0.0: + return center + tx = hw / abs(dx) if dx != 0.0 else float("inf") + ty = hh / abs(dy) if dy != 0.0 else float("inf") + t = min(tx, ty, 1.0) + return (cx + t * dx, cy + t * dy) + + @classmethod + def _clip_edge_polyline( + cls, + points: list[tuple[float, float]], + src_node: GraphNode, + tgt_node: GraphNode, + ) -> list[tuple[float, float]]: + if len(points) < 2: + return points + clipped = list(points) + # Edges exit the source from its bottom-midpoint and enter the + # target at its top-midpoint. Predictable anchors keep parallel + # edges docked together and avoid endpoint/box intersections. + clipped[0] = ( + float(src_node.x), + float(src_node.y) + float(src_node.height) / 2.0, + ) + clipped[-1] = ( + float(tgt_node.x), + float(tgt_node.y) - float(tgt_node.height) / 2.0, + ) + return clipped + + @staticmethod + def _segment_crosses_aabb( + p1: tuple[float, float], + p2: tuple[float, float], + aabb_center: tuple[float, float], + aabb_half: tuple[float, float], + ) -> bool: + # Liang-Barsky style parametric slab test in the AABB's local frame. + cx, cy = aabb_center + hx, hy = aabb_half + x1, y1 = p1[0] - cx, p1[1] - cy + x2, y2 = p2[0] - cx, p2[1] - cy + dx, dy = x2 - x1, y2 - y1 + t_enter, t_exit = 0.0, 1.0 + for p, q in ((-dx, x1 + hx), (dx, hx - x1), (-dy, y1 + hy), (dy, hy - y1)): + if p == 0.0: + if q < 0.0: + return False + continue + t = q / p + if p < 0.0: + if t > t_exit: + return False + if t > t_enter: + t_enter = t + else: + if t < t_enter: + return False + if t < t_exit: + t_exit = t + return t_enter < t_exit + + def _compute_layout(self, nodes: dict[str, GraphNode], edges: list[GraphEdge]) -> None: + ext_label_lines_by_node: dict[str, list[str]] = {} + for node_id in nodes: + ext_lines: list[str] = [] + for ext in self.extensions: + ext_lines.extend(self._ext_label_lines_for_layout(ext, node_id)) + ext_label_lines_by_node[node_id] = ext_lines + + self._compute_layout_with_ext_lines( + nodes, + edges, + ext_label_lines_by_node=ext_label_lines_by_node, + base_label_getter=self._safe_base_label, + ) + + @staticmethod + def _coerce_str_lines(value: Any) -> list[str]: + if not isinstance(value, list): + return [] + return [x for x in value if isinstance(x, str)] + + @classmethod + def relayout_payload_base( + cls, + base_payload: Dict[str, Any], + extensions_payload: Optional[Dict[str, Any]] = None, + include_layers: Optional[List[str]] = None, + ) -> Dict[str, Any]: + """Recompute base graph node/edge layout using extension label lines. + + This API operates on payload dictionaries and does not require + a ``torch.fx.GraphModule`` instance. + """ + if not isinstance(base_payload, dict): + raise TypeError("base_payload must be a dict") + + relaid = copy.deepcopy(base_payload) + + raw_nodes = relaid.get("nodes") + raw_edges = relaid.get("edges") + if not isinstance(raw_nodes, list) or not isinstance(raw_edges, list): + raise ValueError("base_payload must contain list fields: 'nodes' and 'edges'") + + nodes: dict[str, GraphNode] = {} + for idx, node_data in enumerate(raw_nodes): + if not isinstance(node_data, dict): + continue + node_id = str(node_data.get("id", "")).strip() + if not node_id: + continue + + info_value = node_data.get("info", {}) + tooltip_value = node_data.get("tooltip", []) + node = GraphNode( + id=node_id, + label=str(node_data.get("label", "")), + topo_index=int(node_data.get("topo_index", idx)), + info=info_value if isinstance(info_value, dict) else {}, + tooltip=tooltip_value if isinstance(tooltip_value, list) else [], + fill_color=node_data.get("fill_color"), + ) + nodes[node_id] = node + + edges: list[GraphEdge] = [] + for edge_data in raw_edges: + if not isinstance(edge_data, dict): + continue + v = str(edge_data.get("v", "")).strip() + w = str(edge_data.get("w", "")).strip() + if not v or not w or v not in nodes or w not in nodes: + continue + edges.append(GraphEdge(v=v, w=w, points=[])) + + ext_payloads = extensions_payload if isinstance(extensions_payload, dict) else {} + if include_layers is None: + active_layer_ids = list(ext_payloads.keys()) + else: + active_layer_ids = [layer_id for layer_id in include_layers if layer_id in ext_payloads] + + ext_label_lines_by_node: dict[str, list[str]] = {node_id: [] for node_id in nodes} + for layer_id in active_layer_ids: + layer_payload = ext_payloads.get(layer_id) + if not isinstance(layer_payload, dict): + continue + layer_nodes = layer_payload.get("nodes") + if not isinstance(layer_nodes, dict): + continue + for node_id, node_payload in layer_nodes.items(): + if node_id not in ext_label_lines_by_node or not isinstance(node_payload, dict): + continue + ext_label_lines_by_node[node_id].extend( + cls._coerce_str_lines(node_payload.get("label_append")) + ) + + cls._compute_layout_with_ext_lines( + nodes, + edges, + ext_label_lines_by_node=ext_label_lines_by_node, + base_label_getter=lambda node: str(node.label or ""), + ) + + node_by_id = {node.get("id"): node for node in raw_nodes if isinstance(node, dict)} + for node_id, node in nodes.items(): + node_dict = node_by_id.get(node_id) + if not isinstance(node_dict, dict): + continue + node_dict["x"] = node.x + node_dict["y"] = node.y + node_dict["width"] = node.width + node_dict["height"] = node.height + + edge_points_by_key = {(edge.v, edge.w): edge.points for edge in edges} + for edge_data in raw_edges: + if not isinstance(edge_data, dict): + continue + key = ( + str(edge_data.get("v", "")).strip(), + str(edge_data.get("w", "")).strip(), + ) + edge_data["points"] = edge_points_by_key.get(key, []) + + return relaid + + def _build_base_payload(self, nodes: dict[str, GraphNode], edges: list[GraphEdge]) -> BaseGraphPayload: + print("[FX Graph Viewer] Compiling base graph payload...") + + base_color_input = {node_id: node.info for node_id, node in nodes.items()} + base_colors: dict[str, str] = {} + base_legend: list[dict[str, str]] = [] + if self.base_color_rule: + base_colors, base_legend = self.base_color_rule.apply(base_color_input) + + for node in nodes.values(): + node.label = self._safe_base_label(node) + node.tooltip = self._safe_base_tooltip(node) + if node.id in base_colors: + node.fill_color = base_colors[node.id] + + return BaseGraphPayload( + legend=base_legend, + nodes=list(nodes.values()), + edges=edges, + ) + + def _build_extensions_payload(self) -> dict[str, GraphExtensionPayload]: + print("[ FX Graph Viewer ] Compiling extension payloads...") + return {ext.id: ext.build_payload() for ext in self.extensions} + + def generate_json_payload(self) -> Dict[str, Any]: + nodes, edges = self._extract_graph() + self._compute_layout(nodes, edges) + base_payload = self._build_base_payload(nodes, edges) + extensions_payload = self._build_extensions_payload() + payload = GraphPayload(base=base_payload, extensions=extensions_payload) + return asdict(payload) + + def export_json(self, output_path: str): + data = self.generate_json_payload() + with open(output_path, "w") as f: + json.dump(data, f, indent=2) + print(f"Success! Exported JSON payload to {output_path}") + + @staticmethod + def _load_viewer_js_bundle() -> str: + template_dir = os.path.join(os.path.dirname(__file__), "templates") + ordered_files = [ + "runtime.js", + "graph_data_store.js", + "search_engine.js", + "view_controller.js", + "canvas_renderer.js", + "minimap_renderer.js", + "ui_manager.js", + "fx_graph_viewer.js", + "compare.js", + ] + chunks = [] + for filename in ordered_files: + path = os.path.join(template_dir, filename) + with open(path, "r") as f: + chunks.append(f"\n// ---- {filename} ----\n") + chunks.append(f.read()) + + return "\n".join(chunks) + + def export_js(self, container_id: str) -> str: + data = self.generate_json_payload() + json_str = json.dumps(data) + js_content = self._load_viewer_js_bundle() + + return f""" + const graphPayload = {json_str}; + + {js_content} + + (function() {{ + try {{ + const viewer = FXGraphViewer.create({{ + payload: graphPayload, + mount: {{ root: '#{container_id}' }}, + }}); + viewer.init(); + window.fxViewer = viewer; + }} catch (e) {{ + console.error("Failed to mount FXGraphViewer:", e); + const container = document.getElementById('{container_id}'); + if (container) {{ + container.innerHTML = "
Error mounting graph: " + e.message + "
"; + }} + }} + }})(); + """ + + def export_html(self, output_html: str = "model_graph.html"): + data = self.generate_json_payload() + json_str = json.dumps(data) + js_content = self._load_viewer_js_bundle() + + html_content = f""" + + + + PyTorch FX Graph Viewer V3 + + + +
+
+
Loading Graph Viewer...
+
+
+ + + + + + +""" + + print(f"Writing to {output_html}...") + with open(output_html, "w") as f: + f.write(html_content) + + print(f"Success! Exported extensible graph to {output_html}") diff --git a/devtools/fx_viewer/extension.py b/devtools/fx_viewer/extension.py new file mode 100644 index 00000000000..c0472932ce5 --- /dev/null +++ b/devtools/fx_viewer/extension.py @@ -0,0 +1,136 @@ +from __future__ import annotations + +from dataclasses import asdict +from typing import Dict, Any, Callable, Optional +import warnings + +from .color_rules import ColorRule +from .models import GraphExtensionNodePayload, GraphExtensionPayload + + +class GraphExtension: + """Optional annotation layer attached to the base FX graph.""" + + def __init__(self, id: str, name: str): + clean_id = id.strip() + if not clean_id: + raise ValueError("GraphExtension id must be non-empty") + if not name.strip(): + raise ValueError("GraphExtension name must be non-empty") + + self.id = clean_id + self.name = name + self.nodes_data: Dict[str, Dict[str, Any]] = {} + self.sync_keys: list[str] = [] + + self.color_rule: Optional[ColorRule] = None + self.label_formatter: Optional[Callable[[Dict[str, Any]], list[str]]] = None + self.tooltip_formatter: Optional[Callable[[Dict[str, Any]], list[str]]] = None + + def add_node_data(self, node_id: str, data: Dict[str, Any]): + if node_id not in self.nodes_data: + self.nodes_data[node_id] = {} + self.nodes_data[node_id].update(data) + + def set_color_rule(self, rule: ColorRule): + self.color_rule = rule + + def set_sync_key(self, field: str): + """Mark a data field as a compare-mode sync key. + + When registered, the field appears as an explicit option in the compare + sidebar under "Ext: .". Selecting it activates + ``mode: 'layer'`` sync with this extension and field. + + Example:: + + ext.set_sync_key("debug_handle") + # sidebar shows: "Ext: my_ext.debug_handle" + + Multiple sync keys can be registered on the same extension. + """ + if field not in self.sync_keys: + self.sync_keys.append(field) + + def set_label_formatter(self, formatter: Callable[[Dict[str, Any]], list[str]]): + self.label_formatter = formatter + + def set_tooltip_formatter(self, formatter: Callable[[Dict[str, Any]], list[str]]): + self.tooltip_formatter = formatter + + def _format_lines( + self, + *, + formatter: Callable[[Dict[str, Any]], list[str]], + data: Dict[str, Any], + node_id: str, + kind: str, + ) -> list[str]: + try: + result = formatter(data) + except Exception as exc: + warnings.warn( + f"Extension '{self.id}' {kind} formatter failed for node '{node_id}': {exc}", + RuntimeWarning, + stacklevel=2, + ) + return [] + + if not isinstance(result, list) or any(not isinstance(x, str) for x in result): + warnings.warn( + f"Extension '{self.id}' {kind} formatter must return list[str] for node '{node_id}'", + RuntimeWarning, + stacklevel=2, + ) + return [] + + return result + + def build_payload(self) -> GraphExtensionPayload: + node_colors = {} + legend = [] + + if self.color_rule: + node_colors, legend = self.color_rule.apply(self.nodes_data) + + compiled_nodes: Dict[str, GraphExtensionNodePayload] = {} + + for node_id, data in self.nodes_data.items(): + compiled = GraphExtensionNodePayload(info=data) + + if self.label_formatter: + lines = self._format_lines( + formatter=self.label_formatter, + data=data, + node_id=node_id, + kind="label", + ) + if lines: + compiled.label_append = lines + + if self.tooltip_formatter: + lines = self._format_lines( + formatter=self.tooltip_formatter, + data=data, + node_id=node_id, + kind="tooltip", + ) + if lines: + compiled.tooltip = lines + + if node_id in node_colors: + compiled.fill_color = node_colors[node_id] + + compiled_nodes[node_id] = compiled + + return GraphExtensionPayload( + id=self.id, + name=self.name, + legend=legend, + nodes=compiled_nodes, + sync_keys=list(self.sync_keys), + ) + + def build(self) -> Dict[str, Any]: + """Backward-compatible dict payload export.""" + return asdict(self.build_payload()) diff --git a/devtools/fx_viewer/models.py b/devtools/fx_viewer/models.py new file mode 100644 index 00000000000..28c0a8f8f70 --- /dev/null +++ b/devtools/fx_viewer/models.py @@ -0,0 +1,65 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any + + +@dataclass +class GraphNode: + """Wire-format node schema used by exporter and JSON payload.""" + + id: str + label: str = "" + x: float = 0.0 + y: float = 0.0 + width: float = 100.0 + height: float = 40.0 + topo_index: int = 0 + info: dict[str, Any] = field(default_factory=dict) + tooltip: list[str] = field(default_factory=list) + fill_color: str | None = None + + +@dataclass +class GraphEdge: + """Wire-format edge schema used by exporter and JSON payload.""" + + v: str + w: str + points: list[dict[str, float]] = field(default_factory=list) + + +@dataclass +class BaseGraphPayload: + """Nodes are in topological order (guaranteed by torch.fx.Graph.nodes iteration).""" + + legend: list[dict[str, str]] + nodes: list[GraphNode] + edges: list[GraphEdge] + + +@dataclass +class GraphExtensionNodePayload: + """Wire-format extension node schema.""" + + info: dict[str, Any] = field(default_factory=dict) + tooltip: list[str] = field(default_factory=list) + label_append: list[str] = field(default_factory=list) + fill_color: str | None = None + + +@dataclass +class GraphExtensionPayload: + """Wire-format extension layer schema.""" + + id: str + name: str + legend: list[dict[str, str]] = field(default_factory=list) + nodes: dict[str, GraphExtensionNodePayload] = field(default_factory=dict) + sync_keys: list[str] = field(default_factory=list) + + +@dataclass +class GraphPayload: + base: BaseGraphPayload + extensions: dict[str, GraphExtensionPayload] diff --git a/devtools/fx_viewer/templates/README.md b/devtools/fx_viewer/templates/README.md new file mode 100644 index 00000000000..0788f105287 --- /dev/null +++ b/devtools/fx_viewer/templates/README.md @@ -0,0 +1,537 @@ +# fx_viewer JS Runtime + +This folder contains the browser runtime used by `FXGraphExporter`. + +## Runtime Model + +1. `FXGraphViewer` is the facade and public API. +2. `ViewerController` is the interaction/state controller. +3. `GraphDataStore` owns base + extension data and composes active virtual nodes. +4. Renderers (`CanvasRenderer`, `MinimapRenderer`) paint from controller/store state. +5. `UIManager` is a state adapter for taskbar/search/layers/info/legend controls. +6. `FXGraphCompare` orchestrates multi-view compare: builds compare DOM, wires sync, owns lifecycle. +7. `FXCompareTaskbar` renders the optional shared taskbar above the compare grid. + +## Files and Responsibilities + +1. `runtime.js`: shared theme tokens (`THEMES`), event utilities (`fxOn`, `fxOffAll`, `fxEsc`). +2. `graph_data_store.js`: payload normalization, topology cache, virtual-node composition. +3. `search_engine.js`: fuzzy search over active nodes. +4. `view_controller.js`: state machine and interaction orchestration. +5. `canvas_renderer.js`: primary graph rendering + canvas interactions. +6. `minimap_renderer.js`: minimap rendering + minimap navigation. +7. `ui_manager.js`: taskbar/search/layers/info panel/legend DOM. +8. `fx_graph_viewer.js`: `FXGraphViewer` facade (CSS in `_injectStyles()`). +9. `compare.js`: `FXCompareTaskbar` + `FXGraphCompare` orchestrator. + +## Public API + +The canonical API reference is maintained in: +1. `backends/qualcomm/utils/fx_viewer/README.md` (`JS API (Runtime)` section) + +This file focuses on runtime internals, file responsibilities, and script load order. + +## Config Precedence + +1. `mount.slots.*` (placement) has highest precedence. +2. Explicit `layout.*` overrides preset defaults. +3. `layout.preset` fills missing values. +4. Built-in defaults are last fallback. + +## Compare View DOM and Ownership + +`FXGraphCompare` builds its own DOM shell inside `layout.container` and moves viewer sub-elements into it. The viewer's own wrapper is hidden but not destroyed. + +### DOM tree + +``` +layout.container + .fx-compare-root flex column; created by FXGraphCompare + .fx-compare-taskbar optional; created by FXCompareTaskbar when sharedTaskbar.enabled + .fx-compare-grid CSS grid, repeat(N, 1fr) columns + .fx-compare-col one per viewer; flex column + .fx-compare-col-header + .fx-compare-minimap-row fixed height; viewer.minimapRenderer.container moved here + .fx-compare-canvas-row flex:1; viewer.mainArea moved here + .fx-compare-info-bar single shared merged info panel; full width +``` + +### Ownership rules + +- `FXGraphCompare` creates `.fx-compare-root`, `.fx-compare-grid`, `.fx-compare-col`, `.fx-compare-info-bar`. +- `FXCompareTaskbar` creates `.fx-compare-taskbar` and prepends it to `.fx-compare-root`. +- Each viewer's `.fx-viewer-wrapper` is hidden (`display:none`) while compare is active. +- `viewer.mainArea` and `viewer.minimapRenderer.container` are moved into compare columns; all other viewer internals stay in the hidden wrapper. +- DOM snapshots (parent + nextSibling) are recorded before any move. `_teardownCompareDOM()` restores them on `destroy()`. + +### Resize handling + +- A `ResizeObserver` on each `.fx-compare-canvas-row` calls `viewer.canvasRenderer.resize()` + `renderAll()`. +- An initial `requestAnimationFrame` fires after `_buildCompareDOM()` for the first layout pass. +- `MinimapRenderer` has its own `ResizeObserver` on its container for deferred-visibility cases (observatory, collapsed sections). + +### Interaction ownership + +| Concern | Owner | +|---------|-------| +| Selection sync | `FXGraphCompare._wireSelectionSync()` — source-guarded, propagates via `viewer.selectNode()` | +| Theme sync | `FXCompareTaskbar` (shared taskbar) + `FXGraphCompare._wireStateSync()` (state events) | +| Layers / ColorBy | `FXCompareTaskbar` — union of all extension ids, per-viewer `setLayers()` | +| Zoom / Fullscreen | `FXCompareTaskbar` — calls viewer public API or `requestFullscreen()` on root | +| Column count | `FXGraphCompare.setColumns()` — updates grid CSS, triggers resize RAF | +| Merged info panel | `FXGraphCompare._updateMergedInfo()` — called after selection sync, renders diff table | +| Viewer visibility | `FXGraphCompare.setViewerVisible(name, visible)` — show/hide graph columns | +| Follow selection | Sidebar toggle — controls auto-pan on selection change | + +### Selection Sync Modes + +`FXGraphCompare` propagates the primary selection from the source viewer to all other viewers using `_findSyncTarget`. The sidebar selector controls the active mode. + +| Mode | Sidebar label | Behavior | +|------|--------------|----------| +| `'auto'` (default) | Auto (handle→id) | Tries `debug_handle` set-intersection first; falls back to node-ID match | +| `'id'` | ID only | Selects the node with the same id in each target viewer; no-op if absent | +| `'layer'` | Ext: \.\ | Matches by `extensions[layer].nodes[nodeId].info[field]` value equality; picks last in topo order on multiple matches | +| `'none'` | Don't sync | No cross-viewer selection propagation | + +The initial sync mode can be set programmatically via `config.sync` at construction time. Observatory's `GraphCompareSpec.default_sync` maps directly to this field, allowing lenses to declare their preferred default sync strategy. + +#### `debug_handle` normalization (`mode: 'auto'`) + +`debug_handle` in `node.info` is `int` (scalar) or `int[]` (list, for fused nodes). The sync engine normalizes both to a `Set` and uses **set intersection** to find matches: + +- `int dh` → `{dh}` (if non-zero) +- `int[] dh` → `Set(dh.filter(x => x !== 0))` +- `null / 0 / []` → empty set (no match) + +Two nodes match if their normalized sets have a non-empty intersection. When multiple target nodes match, the **last in topological order** is selected (highest `topo_index`). + +This enables three mapping patterns: +- **1-to-1**: same handle on both sides → direct match. +- **1-to-many** (decomposed ops): one source node → multiple target nodes sharing the same handle (e.g. `linear` → `t + mm + add`). The last decomposed op is selected. +- **many-to-1** (fused ops): a fused node carries a union tuple handle `(h1, h2)`. Any source node whose handle intersects `{h1, h2}` will match the fused node. + +#### Registering a sync key on an extension + +To expose an extension field as an explicit sync option in the sidebar, call `set_sync_key` on the Python `GraphExtension`: + +```python +ext = GraphExtension(id="my_ext", name="My Extension") +ext.add_node_data(node_id, {"debug_handle": 42, ...}) +ext.set_sync_key("debug_handle") # appears as "Ext: my_ext.debug_handle" in sidebar +``` + +The sidebar will show `Ext: my_ext.debug_handle` as a selectable option. Selecting it activates `mode: 'layer'` with `layer='my_ext'` and `field='debug_handle'`. + +### `compare.setViewerVisible(name, visible)` + +Show or hide a graph column by name. Equivalent to checking/unchecking in the Graphs menu. + +- `name` {string} — viewer name as passed to `FXGraphCompare.create()` +- `visible` {boolean} — `true` to show, `false` to hide + +When showing a viewer that has a corresponding node for the current selection in another viewer, the newly visible viewer will pan to that node automatically. + +### Follow Selection toggle + +The sidebar includes a "Follow" toggle button. When enabled (\u2299), every selection change auto-pans all viewers to the corresponding node. When disabled (\u25cb), selections still sync but viewers don't auto-pan (user can navigate freely). Newly re-enabled viewers always pan to the active selection regardless of this toggle. + +## Payload Contract + +Runtime input payload: + +```js +{ + base: { + legend: [{ label, color }], + nodes: [{ id, label, x, y, width, height, info, tooltip, fill_color? }], + // info.debug_handle: int | int[] — present when generate_missing_debug_handles() was called + edges: [{ v, w, points? }] + }, + extensions: { + [extId]: { + name: string, + legend: [{ label, color }], + sync_keys: string[], // fields registered via ext.set_sync_key(); drives sidebar options + nodes: { + [nodeId]: { + info?: object, // arbitrary key/value; info.debug_handle used by 'auto' sync + tooltip?: string[], + label_append?: string[], + fill_color?: string + } + } + } + } +} +``` + +## Script Load Order + +1. `runtime.js` +2. `graph_data_store.js` +3. `search_engine.js` +4. `view_controller.js` +5. `canvas_renderer.js` +6. `minimap_renderer.js` +7. `ui_manager.js` +8. `fx_graph_viewer.js` +9. `compare.js` + +## Maintenance Notes + +1. Keep module boundaries strict; route orchestration through controller/facade. +2. Preserve payload compatibility when adding UI/runtime features. +3. If state shape changes, update docs and relevant contracts in this folder. +4. `FXGraphCompare` must always restore viewer DOM on `destroy()` — never leave viewer.mainArea detached. + +--- + +## Runtime Internals + +### GraphDataStore + +Manages the raw JSON graph payload and constructs the "Virtual Node" topology. + +**State:** `baseData` (structural nodes/edges), `extensions` (annotation overlays), `activeNodes` (pre-computed flat array), `activeNodeMap` (O(1) id→node), `adjList`/`revAdjList` (edge traversal), `graphBounds` (camera zoom target). + +**Topology Init (`_initTopology`):** Loops over `baseData.nodes` once to calculate global bounds. Normalizes coordinates so the top-left node starts at (50, 50). Builds adjacency lists. + +**Virtual Node Composition (`computeActiveGraph`):** For each base node, creates a flat `info` dict. Iterates active extension ids; if an extension has data for the node, prefixes keys (e.g. `Profiler.latency: 15`) and merges them. Concatenates `label_append` and `tooltip` arrays. Resolves `fill_color` from `colorById`. Pre-computing `activeNodes` on checkbox toggle avoids GC pressure during the 60FPS render loop. + +**Traversal:** `getAncestors(id)` / `getDescendants(id)` — BFS over `revAdjList`/`adjList` for canvas selection highlighting. + +--- + +### SearchEngine + +Fuzzy search over active graph nodes with token scoring and context highlighting. + +**Algorithm:** +1. Tokenize query by spaces (e.g. `"conv 15ms"` → `["conv", "15ms"]`). +2. Iterate `activeNodes`. Because `node.info` is a flattened dict with extension prefixes, the engine searches extension data natively. +3. Scoring: `node.id` match = +10, `op` = +5, `target` = +3, other key/value = +1. +4. Context highlighting: wraps matched substring in `` so the dropdown shows exactly why a node matched. +5. Filter to nodes matching the maximum number of tokens (fuzzy AND). + +--- + +### ViewerController + +Centralized state machine managing interactions, camera transforms, selections, and extension visibility. + +**State fields:** `hoveredNodeId`, `hoveredEdge`, `selectedNodeId`, `selectedEdge`, `previewNodeId`, `ancestors`/`descendants` (Sets), `searchCandidates`, `searchSelectedIndex`, `highlightAncestors`, `themeName`, `activeExtensions` (Set), `colorBy`, `highlightGroups` (Map, color: string}>). + +**`setState(patch)`:** Merges patch. If `activeExtensions` or `colorBy` changed, calls `store.computeActiveGraph()`, regenerates minimap thumbnail, updates legend and info panel. If `themeName` changed, calls `ui.applyThemeToDOM()`. Always calls `viewer.renderAll()` and emits `statechange`. + +**`animateToTransform(x, y, k)`:** Uses `requestAnimationFrame` with easeOutCubic over 300ms to interpolate camera position. + +**`zoomToFit()`:** If a node is selected, collects its 2-hop neighborhood; if an edge is selected, collects 1-hop neighbors of both endpoints. Computes bounding box and animates camera to fit. Falls back to full graph bounds if nothing is selected. + +**Selection:** `selectNode` / `selectEdge` run BFS ancestors/descendants via `store`, update state, and call `ui.updateInfoPanel` / `ui.updateEdgeInfoPanel`. `clearSelection` nullifies all selection state and hides the info panel. + +**Search flow:** `handleSearch` → `SearchEngine.search` → `setState({searchCandidates})` → `ui.updateSearchResults`. Arrow keys call `handleSearchNavigate` (pan preview). Enter/click calls `handleSearchSelect` (full select + close menu). + +--- + +### CanvasRenderer + +High-performance 2D canvas rendering of the main graph with pan/zoom and hover interactions. + +**Coordinate spaces:** DOM mouse events are in Screen Space. Nodes/edges live in Graph Space. Conversion: `graphX = (screenX - transform.x) / transform.k`. Device pixel ratio (`dpr`) is applied via `ctx.scale(dpr, dpr)` to prevent blurring on retina displays. + +**Pan/zoom:** `mousedown`/`mousemove` delta → `transform.x/y`. `wheel` → exponential `zoomFactor`, pivot at cursor: `transform.x = mouseX - graphX * newK`. + +**`render()` loop:** +1. Clear canvas, fill theme background. +2. Apply `ctx.scale(dpr, dpr)`, `ctx.translate(transform.x, y)`, `ctx.scale(k, k)`. +3. Compute `opacity = 0.15` for nodes/edges outside the active selection ancestry. +4. Draw edges (with midpoint tensor-shape labels), then node rectangles with multi-line labels. +5. **Highlight group overlay pass** (after node rendering): iterates `state.highlightGroups`; for each group draws a thick 6px solid border outside the node rect (offset by `borderWidth/2` so it does not clip node fill or text). Multiple groups coexist; last group in Map iteration order wins visually on overlapping nodes. + +**`drawSmartTooltip()`:** Tests 4 candidate positions (up/down/left/right) against viewport bounds. Prefers "right" if it fits. Draws a dashed connector line scaled by `1/transform.k`. + +**Dynamic color:** `shadeColor()` lightens/darkens custom extension fill colors for hover/selected/ancestor states, preserving analytical heatmap context. + +--- + +### MinimapRenderer + +Minimap overview rendering with viewport tracking and click/drag navigation. + +**Coordinate transforms:** +- `minimapScale = min(mw/graphW, mh/graphH) * 0.9` — shrinks Graph Space to Minimap Space. +- `thumbnailOffset` — centers the scaled graph in the minimap container. +- Click/drag: `graphX = (screenX - thumbnailOffset.x) / minimapScale` → update `transform.x/y` to center main canvas on that point. +- Viewport rect: `viewX = -transform.x / transform.k`, projected to minimap: `mx = viewX * minimapScale + thumbnailOffset.x`. + +**`generateThumbnail()`:** Draws the full graph to an off-screen canvas buffer. Called only when graph data or theme changes. + +**`render()`:** Blits thumbnail (O(1)), then overlays search candidate dots, ancestor/descendant highlights, and the red viewport rectangle. + +--- + +### UIManager + +Manages all non-canvas DOM elements: taskbar, search, layers dropdown, legend overlay, and info panel. + +**`buildUI()`:** Constructs HTML overlay components programmatically. Attaches `input`/`keydown` listeners on search input, relaying to `ViewerController`. + +**Search rendering:** `updateSearchResults` renders 50 items initially. `onscroll` listener appends 20 more when near the bottom (chunked rendering to prevent DOM freeze). + +**Layers menu:** Reads `viewer.store.extensions` to build checkboxes (`activeExtensions`) and radio buttons (`colorBy`). `onchange` calls `controller.setState`. + +**Info panel (`updateInfoPanel`):** +1. Renders core PyTorch keys (`op`, `name`, `target`, `args`, `kwargs`, `shape`, `dtype`) at top. +2. Renders Inputs/Outputs as clickable `fx-link` elements that animate camera to related nodes. +3. Groups remaining prefixed keys (e.g. `Profiler.latency`) by prefix, rendering section headers. + +**Legend (`renderLegend`):** Reads `colorBy` state, fetches legend array from `store`, renders color swatches with `shadeColor` adjustment for dark mode. + +--- + +## FXGraphViewer API Reference + +### `FXGraphViewer.create(config)` + +| Field | Type | Default | Description | +|-------|------|---------|-------------| +| `config.payload` | object | required | Graph payload (base + extensions) | +| `config.mount.root` | string\|HTMLElement | required | CSS selector or element for root container | +| `config.mount.slots.canvas` | string\|HTMLElement | — | External canvas mount | +| `config.mount.slots.toolbar` | string\|HTMLElement | — | External toolbar mount | +| `config.mount.slots.info` | string\|HTMLElement | — | External info panel mount | +| `config.mount.slots.minimap` | string\|HTMLElement | — | External minimap mount | +| `config.mount.slots.legend` | string\|HTMLElement | — | External legend mount | +| `config.layout.preset` | `'split'`\|`'compact'`\|`'headless'`\|`'custom'` | `'split'` | Layout preset | +| `config.layout.panels.sidebar.visible` | boolean | `true` | Show sidebar | +| `config.layout.panels.sidebar.width` | number | `500` | Sidebar width in px | +| `config.layout.panels.sidebar.resizable` | boolean | `true` | Allow sidebar resize | +| `config.layout.panels.sidebar.collapsible` | boolean | `true` | Double-click resizer to collapse | +| `config.layout.panels.info.visible` | boolean | `true` | Show info panel | +| `config.layout.panels.minimap.visible` | boolean | `true` | Show minimap | +| `config.layout.panels.minimap.height` | number | `240` | Minimap height in px | +| `config.layout.panels.minimap.resizable` | boolean | `true` | Allow minimap resize | +| `config.layout.panels.legend.visible` | boolean | `true` | Show legend overlay | +| `config.layout.fullscreen.enabled` | boolean | `true` | Allow fullscreen | +| `config.layout.fullscreen.button` | boolean | `false` | Show fullscreen button in taskbar | +| `config.ui.controls.toolbar` | boolean | `true` | Show/hide entire taskbar | +| `config.ui.controls.search` | boolean | `true` | Show/hide search input | +| `config.ui.controls.layers` | boolean | `true` | Show/hide layers button | +| `config.ui.controls.theme` | boolean | `true` | Show/hide theme selector | +| `config.ui.controls.legend` | boolean | `true` | Show/hide legend overlay | +| `config.ui.controls.zoomButtons` | boolean | `true` | Show/hide zoom-to-fit button | +| `config.ui.controls.fullscreenButton` | boolean | `false` | Show/hide fullscreen button | +| `config.ui.controls.highlightButton` | boolean | `true` | Show/hide ancestor/descendant highlight toggle | +| `config.state.theme` | string | `'light'` | Initial theme (`'light'`, `'dark'`, or custom) | +| `config.state.colorBy` | string | `'base'` | Initial color-by extension id or `'base'` | +| `config.state.activeExtensions` | string[] | `[]` | Initially active extension ids | +| `config.state.highlightAncestors` | boolean | `true` | Highlight ancestors/descendants on select | + +### Preset defaults + +| Preset | sidebar | minimap | info | toolbar | search | layers | theme | +|--------|---------|---------|------|---------|--------|--------|-------| +| `split` (default) | visible | visible | visible | on | on | on | on | +| `compact` | hidden | hidden | hidden | on | on | on | on | +| `headless` | hidden | hidden | hidden | off | off | off | off | +| `custom` | hidden | hidden | hidden | on | on | on | on | + +### Public methods + +``` +viewer.init() Initial zoom-to-fit / position +viewer.renderAll() Re-render canvas + minimap +viewer.getState() Snapshot of current state +viewer.setState(patch) Update controller state +viewer.setTheme(name) Switch theme ('light'|'dark'|custom) +viewer.setLayers(layerIds[]) Set active extensions +viewer.setColorBy(layerId) Set active color-by extension +viewer.selectNode(nodeId, opts?) Select node; opts: { animate, center, k } +viewer.panToNode(nodeId) Pan to node without selecting +viewer.animateToNode(nodeId, opts?) Animated pan; opts: { k } +viewer.setUIVisibility(flags) Show/hide individual controls at runtime +viewer.setLayout(layoutPatch) Apply layout changes at runtime +viewer.upsertLayer(id, payload) Add or update an extension layer +viewer.removeLayer(id) Remove an extension layer +viewer.patchLayerNodes(id, nodePatch) Update node data in an extension +viewer.enterFullscreen() Enter fullscreen (returns Promise) +viewer.exitFullscreen() Exit fullscreen (returns Promise) +viewer.addHighlightGroup(groupId, nodeIds, color) Add/replace a named highlight group overlay +viewer.removeHighlightGroup(groupId) Remove a named highlight group +viewer.clearAllHighlightGroups() Remove all highlight groups +viewer.getHighlightGroups() Returns Map +viewer.destroy() Teardown all DOM, listeners, renderers +viewer.on(event, fn) Subscribe to event; returns unsubscribe fn +viewer.off(event, fn) Unsubscribe from event +FXGraphViewer.registerTheme(name, tokens) Register a custom theme globally +``` + +### Events + +``` +viewer.on('selectionchange', (e) => { e.nodeId, e.prevNodeId, e.source }) +viewer.on('statechange', (e) => { e.prevState, e.nextState, e.source }) +viewer.on('layoutchange', (e) => { e.prevState, e.nextState, e.source }) +viewer.on('hover', (e) => { e.nodeId }) +``` + +### Examples + +**1. Minimal split viewer** +```js +const viewer = FXGraphViewer.create({ + payload: myPayload, + mount: { root: '#my-container' }, +}); +viewer.init(); +``` + +**2. Compact viewer** (no sidebar, no minimap) +```js +const viewer = FXGraphViewer.create({ + payload: myPayload, + mount: { root: '#my-container' }, + layout: { preset: 'compact' }, +}); +viewer.init(); +``` + +**3. Headless with external slots** +```js +const viewer = FXGraphViewer.create({ + payload: myPayload, + mount: { + root: '#root', + slots: { canvas: '#canvas-div', minimap: '#minimap-div', info: '#info-div', legend: '#legend-div' }, + }, + layout: { preset: 'headless' }, +}); +viewer.init(); +``` + +**4. Runtime layer mutation** +```js +viewer.upsertLayer('quant', quantPayload); +viewer.patchLayerNodes('quant', { node_0: { fill_color: '#f00' } }); +viewer.setColorBy('quant'); +``` + +--- + +## Highlight Groups + +Highlight groups are a **read-only programmatic overlay** separate from the single-node selection system. They are set via the JS API and render as thick colored borders around specified nodes. + +### Key properties + +- Multiple named groups coexist simultaneously on the same viewer. +- Each group has a `groupId` (string), a set of `nodeIds`, and a CSS `color`. +- Rendering: 6px solid border drawn **outside** the node rect (offset by `borderWidth/2`) so it does not clip node fill or text. +- Groups survive `clearSelection()` — they are independent of selection state. +- Groups are drawn as a separate pass **after** all node rendering, so they always appear on top. + +### API + +```js +// Create or replace a named group +viewer.addHighlightGroup(groupId, nodeIds, color) +// groupId: string — unique name +// nodeIds: string[] — node IDs to highlight +// color: string — CSS color, e.g. '#ff6600' or 'rgba(255,100,0,0.8)' + +// Remove one group +viewer.removeHighlightGroup(groupId) + +// Remove all groups +viewer.clearAllHighlightGroups() + +// Read current groups (returns a shallow copy) +viewer.getHighlightGroups() // → Map, color: string}> +``` + +### Example + +```js +// Highlight two groups with different colors +viewer.addHighlightGroup('critical', ['conv_0', 'conv_1'], '#ff0000'); +viewer.addHighlightGroup('attention', ['attn_q', 'attn_k', 'attn_v'], '#0066ff'); + +// Remove one group +viewer.removeHighlightGroup('critical'); + +// Clear all +viewer.clearAllHighlightGroups(); +``` + +### Compare mode + +`FXGraphCompare` does **not** automatically propagate highlight groups across viewers. Groups are set explicitly per-viewer via the JS API. The compare sync only propagates the primary selection (unchanged). + +### State storage + +`highlightGroups` is stored on `ViewerController.state.highlightGroups` as a `Map, color: string}>`. It is included in `snapshotState()` (shallow copy). `setState({ highlightGroups: newMap })` replaces the reference atomically and triggers a re-render. + +--- + +## FXGraphCompare API Reference + +### `FXGraphCompare.create(config)` + +| Field | Type | Default | Description | +|-------|------|---------|-------------| +| `config.viewers` | FXGraphViewer[] | required | Viewers to compare | +| `config.layout.container` | string\|HTMLElement | required | CSS selector or element; compare DOM built inside | +| `config.layout.columns` | number | `2` | Side-by-side column count | +| `config.layout.minimapHeight` | number | `180` | Minimap row height in px (same for all columns) | +| `config.layout.infoHeight` | number | `200` | Merged info bar max-height in px | +| `config.layout.canvasHeightRatio` | number | `0.7` | Reserved; canvas fills remaining space after minimap | +| `config.sharedTaskbar.enabled` | boolean | `false` | Opt-in shared taskbar above the grid | +| `config.sharedTaskbar.controls.theme` | boolean | `true` | Theme selector in shared taskbar | +| `config.sharedTaskbar.controls.layers` | boolean | `true` | Layers+colorBy dropdown in shared taskbar | +| `config.sharedTaskbar.controls.zoomFit` | boolean | `true` | Zoom-to-fit button in shared taskbar | +| `config.sharedTaskbar.controls.syncMode` | boolean | `true` | Sync mode selector in shared taskbar | +| `config.sharedTaskbar.controls.fullscreen` | boolean | `true` | Fullscreen button in shared taskbar | +| `config.sync.mode` | `'none'`\|`'id'`\|`'auto'`\|`'layer'` | `'auto'` | Selection sync mode | +| `config.sync.layer` | string | `''` | Extension id (when `mode='layer'`) | +| `config.sync.field` | string | `''` | Info key to match on (when `mode='layer'`) | + +### Container height requirement + +The compare root uses `height: 100%` to fill its container. If the container has no explicit height (e.g. only `min-height`), the flex chain collapses. A `ResizeObserver` on the container automatically sets `height = 90vh` as a fallback when `offsetHeight < 100`, and switches back to `100%` if the container later gains an explicit height. + +### Methods + +``` +compare.setColumns(n) Update column count; triggers resize +compare.setSync(patch) Update sync config; e.g. { mode: 'none' } +compare.destroy() Restore all viewer DOM, disconnect observers, remove compare root +``` + +### Examples + +**1. Minimal two-viewer compare** (no shared taskbar, sync by id) +```js +const compare = FXGraphCompare.create({ + viewers: [viewerA, viewerB], + layout: { container: '#compare-host' }, + sync: { mode: 'id' }, +}); +``` + +**2. With shared taskbar** +```js +const compare = FXGraphCompare.create({ + viewers: [viewerA, viewerB], + layout: { container: '#compare-host' }, + sharedTaskbar: { enabled: true }, +}); +``` + +**3. Layer-field sync** +```js +const compare = FXGraphCompare.create({ + viewers: [viewerA, viewerB], + layout: { container: '#compare-host' }, + sync: { mode: 'layer', layer: 'quant', field: 'node_name' }, +}); diff --git a/devtools/fx_viewer/templates/canvas_renderer.js b/devtools/fx_viewer/templates/canvas_renderer.js new file mode 100644 index 00000000000..a111fbdb55f --- /dev/null +++ b/devtools/fx_viewer/templates/canvas_renderer.js @@ -0,0 +1,630 @@ +// High-performance 2D canvas rendering of the main graph with pan/zoom and hover interactions. +// +// render() pass order: +// 1. Clear canvas, fill theme background. +// 2. Apply DPR scale + camera transform (translate + scale). +// 3. Draw edges (with midpoint tensor-shape labels). +// 4. Draw node rectangles with multi-line labels and interaction-state coloring. +// 5. Highlight group overlay pass — iterates state.highlightGroups; draws a thick solid border +// outside each node rect with an explicit outer gap from the single-selection border. +// Multiple groups coexist; last group in Map iteration order wins on overlapping nodes. +// 6. Draw smart tooltip for hovered node or edge. +class CanvasRenderer { + constructor(container, viewer) { + this.container = container; + this.viewer = viewer; + this._teardownFns = []; + + this.canvasContainer = document.createElement('div'); + this.canvasContainer.style.width = '100%'; + this.canvasContainer.style.height = '100%'; + this.container.appendChild(this.canvasContainer); + + this.canvas = document.createElement('canvas'); + this.canvas.className = 'fx-canvas'; + this.canvasContainer.appendChild(this.canvas); + + this.ctx = this.canvas.getContext('2d'); + + this.isDragging = false; + this.lastMousePos = { x: 0, y: 0 }; + + this.resize(); + this._onWindowResize = () => this.resize(); + fxOn(this._teardownFns, window, 'resize', this._onWindowResize); + this._resizeObserver = null; + if (typeof ResizeObserver !== 'undefined') { + this._resizeObserver = new ResizeObserver(() => this.resize()); + this._resizeObserver.observe(this.canvasContainer); + } + this.setupEvents(); + } + + resize() { + const dpr = window.devicePixelRatio || 1; + const rect = this.canvasContainer.getBoundingClientRect(); + if (rect.width <= 0 || rect.height <= 0) return; + this.canvas.width = rect.width * dpr; + this.canvas.height = rect.height * dpr; + this.viewer.renderAll(); + } + + resetInteractionState() { + this.isDragging = false; + this.dragMoved = false; + this.lastMousePos = { x: 0, y: 0 }; + this.viewer.controller.handleHover(null, null); + } + + destroy() { + if (this._resizeObserver) { + this._resizeObserver.disconnect(); + this._resizeObserver = null; + } + fxOffAll(this._teardownFns); + } + + setupEvents() { + const onMouseDown = (e) => { + this.isDragging = true; + this.dragMoved = false; + this.lastMousePos = { x: e.clientX, y: e.clientY }; + }; + fxOn(this._teardownFns, this.canvas, 'mousedown', onMouseDown); + + const onMouseMove = (e) => { + if (this.isDragging) { + const dx = e.clientX - this.lastMousePos.x; + const dy = e.clientY - this.lastMousePos.y; + if (Math.abs(dx) > 2 || Math.abs(dy) > 2) { + this.dragMoved = true; + } + this.viewer.controller.transform.x += dx; + this.viewer.controller.transform.y += dy; + this.lastMousePos = { x: e.clientX, y: e.clientY }; + this.viewer.renderAll(); + } else { + const rect = this.canvas.getBoundingClientRect(); + const mouseX = e.clientX - rect.left; + const mouseY = e.clientY - rect.top; + + const transform = this.viewer.controller.transform; + const graphX = (mouseX - transform.x) / transform.k; + const graphY = (mouseY - transform.y) / transform.k; + + this.detectHover(graphX, graphY); + } + }; + fxOn(this._teardownFns, window, 'mousemove', onMouseMove); + + const onMouseUp = () => { + this.isDragging = false; + }; + fxOn(this._teardownFns, window, 'mouseup', onMouseUp); + + const onWheel = (e) => { + e.preventDefault(); + const zoomIntensity = 0.1; + const wheel = e.deltaY < 0 ? 1 : -1; + const zoomFactor = Math.exp(wheel * zoomIntensity); + + const rect = this.canvas.getBoundingClientRect(); + const mouseX = e.clientX - rect.left; + const mouseY = e.clientY - rect.top; + + const transform = this.viewer.controller.transform; + const graphX = (mouseX - transform.x) / transform.k; + const graphY = (mouseY - transform.y) / transform.k; + + transform.k *= zoomFactor; + transform.x = mouseX - graphX * transform.k; + transform.y = mouseY - graphY * transform.k; + + this.viewer.renderAll(); + }; + fxOn(this._teardownFns, this.canvas, 'wheel', onWheel, { passive: false }); + + const onClick = (e) => { + if (this.dragMoved) return; + const state = this.viewer.controller.state; + this.viewer.controller.handleClick(state.hoveredNodeId, state.hoveredEdge); + }; + fxOn(this._teardownFns, this.canvas, 'click', onClick); + } + + detectHover(graphX, graphY) { + let nearestNode = null; + let nearestEdge = null; + + for (let i = 0; i < this.viewer.store.baseData.nodes.length; i++) { + const node = this.viewer.store.baseData.nodes[i]; + const w = node.width; + const h = node.height; + if (graphX >= node.x - w/2 && graphX <= node.x + w/2 && + graphY >= node.y - h/2 && graphY <= node.y + h/2) { + nearestNode = node.id; + break; + } + } + + if (!nearestNode) { + const transform = this.viewer.controller.transform; + const hoverDist = 5 / transform.k; + for (let i = 0; i < this.viewer.store.baseData.edges.length; i++) { + const edge = this.viewer.store.baseData.edges[i]; + if (!edge.bounds) continue; + if (graphX < edge.bounds.minX - hoverDist || graphX > edge.bounds.maxX + hoverDist || + graphY < edge.bounds.minY - hoverDist || graphY > edge.bounds.maxY + hoverDist) continue; + + const v = this.viewer.store.activeNodeMap.get(edge.v); + const w = this.viewer.store.activeNodeMap.get(edge.w); + let min_d = Infinity; + if (edge.points && edge.points.length > 0) { + for (let j = 0; j < edge.points.length - 1; j++) { + const d = this.distToSegment({x: graphX, y: graphY}, edge.points[j], edge.points[j+1]); + min_d = Math.min(min_d, d); + } + } else if (v && w) { + min_d = this.distToSegment({x: graphX, y: graphY}, v, w); + } + if (min_d <= hoverDist) { + nearestEdge = edge; + break; + } + } + } + + this.viewer.controller.handleHover(nearestNode, nearestEdge); + } + + distToSegment(p, v, w) { + const l2 = (v.x - w.x)**2 + (v.y - w.y)**2; + if (l2 === 0) return Math.hypot(p.x - v.x, p.y - v.y); + let t = ((p.x - v.x) * (w.x - v.x) + (p.y - v.y) * (w.y - v.y)) / l2; + t = Math.max(0, Math.min(1, t)); + return Math.hypot(p.x - (v.x + t * (w.x - v.x)), p.y - (v.y + t * (w.y - v.y))); + } + + render() { + const dpr = window.devicePixelRatio || 1; + const ctx = this.ctx; + const transform = this.viewer.controller.transform; + const state = this.viewer.controller.state; + const theme = THEMES[state.themeName]; + const minEdgeWidth = 1 / Math.max(transform.k, 1e-6); + const edgeLineWidths = { + normal: 2, + input: 3, + output: 3, + hover: 3, + }; + const nodeBorderWidths = { + selected: 4, + preview: 4, + hover: 4, + }; + const hoverDashPattern = [8, 6]; + const groupBorderWidth = 3; + const groupBorderGap = 0; + + ctx.setTransform(1, 0, 0, 1, 0, 0); + ctx.fillStyle = theme.bg; + ctx.fillRect(0, 0, this.canvas.width, this.canvas.height); + + ctx.save(); + ctx.scale(dpr, dpr); + ctx.translate(transform.x, transform.y); + ctx.scale(transform.k, transform.k); + + const inNodes = new Set(); + const outNodes = new Set(); + + // nodes under selection or hovering + const activeNodes = [state.previewNodeId, state.selectedNodeId, state.hoveredNodeId] + activeNodes.forEach( + (activeNode) => { + (this.viewer.store.revAdjList.get(activeNode) || []).forEach(e => inNodes.add(e.v)); + (this.viewer.store.adjList.get(activeNode) || []).forEach(e => outNodes.add(e.w)); + } + ) + + const isSelectionMode = !!state.selectedNodeId || !!state.previewNodeId || !!state.selectedEdge; + + this.viewer.store.baseData.edges.forEach(edge => { + const v = this.viewer.store.activeNodeMap.get(edge.v); + const w = this.viewer.store.activeNodeMap.get(edge.w); + if (!v || !w) return; + + let opacity = 1.0; + if (isSelectionMode) { + const targetNode = state.previewNodeId || state.selectedNodeId; + const isSelectedNodeEdge = (targetNode && (edge.v === targetNode || edge.w === targetNode)) || state.selectedEdge === edge; + if (state.highlightAncestors) { + const inAncestors = state.ancestors.has(edge.v) && state.ancestors.has(edge.w); + const inDescendants = state.descendants.has(edge.v) && state.descendants.has(edge.w); + if (!inAncestors && !inDescendants && !isSelectedNodeEdge) { + opacity = 0.15; + } + } + // If highlightAncestors is false, opacity remains 1.0 for all edges + } + + const isHovered = state.hoveredEdge === edge || state.selectedEdge === edge; + const isInputEdge = activeNodes.includes(edge.w); + const isOutputEdge = activeNodes.includes(edge.v); + + if (isHovered) { + ctx.strokeStyle = theme.edgeHover; + ctx.globalAlpha = opacity; + ctx.lineWidth = Math.max(edgeLineWidths.hover, minEdgeWidth); + } else if (isInputEdge) { + ctx.strokeStyle = theme.edgeInput; + ctx.globalAlpha = opacity; + ctx.lineWidth = Math.max(edgeLineWidths.input, minEdgeWidth); + } else if (isOutputEdge) { + ctx.strokeStyle = theme.edgeOutput; + ctx.globalAlpha = opacity; + ctx.lineWidth = Math.max(edgeLineWidths.output, minEdgeWidth); + } else { + ctx.strokeStyle = theme.edgeNormal; + ctx.globalAlpha = opacity; + ctx.lineWidth = Math.max(edgeLineWidths.normal, minEdgeWidth); + } + + ctx.beginPath(); + let midX = 0, midY = 0; + if (edge.points && edge.points.length > 0) { + ctx.moveTo(edge.points[0].x, edge.points[0].y); + for (let i = 1; i < edge.points.length; i++) { + ctx.lineTo(edge.points[i].x, edge.points[i].y); + } + const midIdx = Math.floor(edge.points.length / 2); + midX = edge.points[midIdx].x; + midY = edge.points[midIdx].y; + if (edge.points.length % 2 === 0 && midIdx > 0) { + midX = (edge.points[midIdx].x + edge.points[midIdx-1].x) / 2; + midY = (edge.points[midIdx].y + edge.points[midIdx-1].y) / 2; + } + } else { + ctx.moveTo(v.x, v.y); + ctx.lineTo(w.x, w.y); + midX = (v.x + w.x) / 2; + midY = (v.y + w.y) / 2; + } + ctx.stroke(); + ctx.globalAlpha = 1.0; + + const srcNode = v; + if (srcNode && srcNode.info && srcNode.info.tensor_shape) { + let shapeStr = JSON.stringify(srcNode.info.tensor_shape).replace(/"/g, ''); + let dtypeStr = typeof srcNode.info.dtype === 'string' ? ` [${srcNode.info.dtype.replace('torch.', '')}]` : ''; + let label = `${shapeStr}${dtypeStr}`; + + ctx.font = '10px sans-serif'; + const tw = ctx.measureText(label).width; + const th = 12; + + ctx.globalAlpha = Math.max(opacity, 0.8); + ctx.fillStyle = theme.bg; + ctx.fillRect(midX - tw/2 - 2, midY - th/2 - 2, tw + 4, th + 4); + + ctx.fillStyle = isHovered ? theme.edgeHover : theme.textMuted; + ctx.textAlign = 'center'; + ctx.textBaseline = 'middle'; + ctx.fillText(label, midX, midY); + ctx.globalAlpha = 1.0; + } + }); + + ctx.textAlign = 'center'; + ctx.textBaseline = 'middle'; + + // Helper to lighten/darken hex colors dynamically based on active theme + const shadeColor = (color, percent) => { + if (!color || !color.startsWith('#')) return color; + let R = parseInt(color.substring(1,3), 16); + let G = parseInt(color.substring(3,5), 16); + let B = parseInt(color.substring(5,7), 16); + R = parseInt(R * (100 + percent) / 100); + G = parseInt(G * (100 + percent) / 100); + B = parseInt(B * (100 + percent) / 100); + R = (R<255)?R:255; G = (G<255)?G:255; B = (B<255)?B:255; + R = (R>0)?R:0; G = (G>0)?G:0; B = (B>0)?B:0; + const RR = ((R.toString(16).length==1)?"0"+R.toString(16):R.toString(16)); + const GG = ((G.toString(16).length==1)?"0"+G.toString(16):G.toString(16)); + const BB = ((B.toString(16).length==1)?"0"+B.toString(16):B.toString(16)); + return "#"+RR+GG+BB; + }; + + this.viewer.store.activeNodes.forEach(node => { + const isHovered = node.id === state.hoveredNodeId; + const isSelected = node.id === state.selectedNodeId; + const isPreview = node.id === state.previewNodeId; + const isInput = inNodes.has(node.id); + const isOutput = outNodes.has(node.id); + + let opacity = 1.0; + let isEdgeEndpoint = state.selectedEdge && (state.selectedEdge.v === node.id || state.selectedEdge.w === node.id); + + if (isSelectionMode) { + const targetNode = state.previewNodeId || state.selectedNodeId; + if (state.highlightAncestors) { + const isAncestors = state.ancestors.has(node.id); + const isDescendants = state.descendants.has(node.id); + if (!isAncestors && !isDescendants && node.id !== targetNode && !isEdgeEndpoint) { + opacity = 0.15; + } + } + } + + ctx.globalAlpha = opacity; + + // Determine base fill color (either custom extension color or theme default) + let baseColor = node.fill_color ? node.fill_color : theme.nodeFill; + + // Adjust lightness for Dark Mode to ensure custom colors aren't too bright + if (state.themeName === 'dark' && node.fill_color) { + baseColor = shadeColor(baseColor, -20); + } + + // Apply interaction state coloring dynamically instead of overriding with theme defaults + let renderedFill; + if (isSelected || isPreview || isEdgeEndpoint) { + renderedFill = shadeColor(baseColor, state.themeName === 'dark' ? 30 : 20); + ctx.fillStyle = renderedFill; + ctx.globalAlpha = Math.max(opacity, 0.8); + } else if (isHovered) { + renderedFill = shadeColor(baseColor, state.themeName === 'dark' ? 20 : 20); + ctx.fillStyle = renderedFill; + } else if (isInput) { + renderedFill = shadeColor(baseColor, state.themeName === 'dark' ? 10 : 10); + ctx.fillStyle = renderedFill; + } else if (isOutput) { + renderedFill = shadeColor(baseColor, state.themeName === 'dark' ? 10 : 10); + ctx.fillStyle = renderedFill; + } else { + renderedFill = baseColor; + ctx.fillStyle = renderedFill; + } + + ctx.fillRect(node.x - node.width/2, node.y - node.height/2, node.width, node.height); + + if (isSelected || isPreview || isHovered) { + ctx.strokeStyle = theme.edgeHover; + if (isSelected || isPreview) { + ctx.lineWidth = isSelected ? nodeBorderWidths.selected : nodeBorderWidths.preview; + } else { + ctx.lineWidth = nodeBorderWidths.hover; + } + if (isHovered && !isSelected && !isPreview) { + ctx.setLineDash(hoverDashPattern); + } else { + ctx.setLineDash([]); + } + ctx.strokeRect(node.x - node.width/2, node.y - node.height/2, node.width, node.height); + ctx.setLineDash([]); + } + + ctx.fillStyle = node.fill_color + ? (fxReadableTextColor(renderedFill) || theme.text) + : theme.text; + let allLines = [node.label || node.id]; + if (node.label_append && node.label_append.length > 0) { + allLines = allLines.concat(node.label_append); + } + + const lineHeight = 16; + const startY = node.y - ((allLines.length - 1) * lineHeight) / 2; + + for (let i = 0; i < allLines.length; i++) { + if (i === 0) ctx.font = 'bold 14px sans-serif'; + else ctx.font = '12px sans-serif'; + ctx.fillText(allLines[i], node.x, startY + (i * lineHeight)); + } + + ctx.globalAlpha = 1.0; + }); + + // Highlight group overlay pass — drawn after all nodes + const singleBorderHalf = Math.max( + nodeBorderWidths.selected, + nodeBorderWidths.preview, + nodeBorderWidths.hover + ) / 2; + const groupHalf = groupBorderWidth / 2; + const outerOffset = singleBorderHalf + groupBorderGap + groupHalf; + if (state.highlightGroups && state.highlightGroups.size > 0) { + state.highlightGroups.forEach(({ nodeIds, color }) => { + ctx.strokeStyle = color; + ctx.lineWidth = groupBorderWidth; + ctx.setLineDash([]); + nodeIds.forEach((id) => { + const node = this.viewer.store.activeNodeMap.get(id); + if (!node) return; + ctx.strokeRect( + node.x - node.width / 2 - outerOffset, + node.y - node.height / 2 - outerOffset, + node.width + outerOffset * 2, + node.height + outerOffset * 2 + ); + }); + }); + } + + if (state.hoveredNodeId || state.hoveredEdge) { + this.drawSmartTooltip(ctx, state.hoveredNodeId, state.hoveredEdge); + } + + ctx.restore(); + } + + drawSmartTooltip(ctx, hoveredNodeId, hoveredEdge) { + const theme = THEMES[this.viewer.controller.state.themeName]; + let tooltipLines = []; + let groupNodes = []; + let targetX = 0; + let targetY = 0; + + if (hoveredNodeId) { + const node = this.viewer.store.activeNodeMap.get(hoveredNodeId); + if (!node) return; + targetX = node.x; + targetY = node.y; + groupNodes.push(node); + + (this.viewer.store.revAdjList.get(hoveredNodeId) || []).forEach(e => { + const n = this.viewer.store.activeNodeMap.get(e.v); + if (n) groupNodes.push(n); + }); + (this.viewer.store.adjList.get(hoveredNodeId) || []).forEach(e => { + const n = this.viewer.store.activeNodeMap.get(e.w); + if (n) groupNodes.push(n); + }); + + if (node.tooltip && node.tooltip.length > 0) { + tooltipLines.push(...node.tooltip); + } + } else if (hoveredEdge) { + const srcNode = this.viewer.store.activeNodeMap.get(hoveredEdge.v); + const dstNode = this.viewer.store.activeNodeMap.get(hoveredEdge.w); + if (!srcNode || !dstNode) return; + groupNodes.push(srcNode, dstNode); + + if (hoveredEdge.points && hoveredEdge.points.length > 0) { + const midIdx = Math.floor(hoveredEdge.points.length / 2); + if (hoveredEdge.points.length % 2 === 0 && midIdx > 0) { + targetX = (hoveredEdge.points[midIdx].x + hoveredEdge.points[midIdx-1].x) / 2; + targetY = (hoveredEdge.points[midIdx].y + hoveredEdge.points[midIdx-1].y) / 2; + } else { + targetX = hoveredEdge.points[midIdx].x; + targetY = hoveredEdge.points[midIdx].y; + } + } else { + targetX = (srcNode.x + dstNode.x) / 2; + targetY = (srcNode.y + dstNode.y) / 2; + } + + if (srcNode.info && srcNode.info.tensor_shape) { + tooltipLines.push(`Shape: ${JSON.stringify(srcNode.info.tensor_shape).replace(/"/g, '')}`); + } + if (srcNode.info && srcNode.info.dtype && typeof srcNode.info.dtype === "string") { + tooltipLines.push(`Dtype: ${srcNode.info.dtype.replace('torch.', '')}`); + } + } + + if (tooltipLines.length === 0 || groupNodes.length === 0) return; + + let minX = Infinity, maxX = -Infinity, minY = Infinity, maxY = -Infinity; + groupNodes.forEach(n => { + minX = Math.min(minX, n.x - n.width/2); + maxX = Math.max(maxX, n.x + n.width/2); + minY = Math.min(minY, n.y - n.height/2); + maxY = Math.max(maxY, n.y + n.height/2); + }); + + const transform = this.viewer.controller.transform; + const dpr = window.devicePixelRatio || 1; + + const fontSize = 12 / transform.k; + ctx.font = `bold ${fontSize}px sans-serif`; + let maxW = 0; + tooltipLines.forEach(line => { + maxW = Math.max(maxW, ctx.measureText(line).width); + }); + + const padding = 8 / transform.k; + const tw = maxW + padding * 2; + const lineHeight = 16 / transform.k; + const th = (tooltipLines.length * lineHeight) + padding * 2; + + const viewLeft = -transform.x / transform.k; + const viewTop = -transform.y / transform.k; + const viewRight = viewLeft + (this.canvas.width / dpr) / transform.k; + const viewBottom = viewTop + (this.canvas.height / dpr) / transform.k; + + const margin = 20 / transform.k; + + const candidates = [ + { id: 'up', x: targetX - tw/2, y: minY - margin - th }, + { id: 'left', x: minX - margin - tw, y: targetY - th/2 }, + { id: 'right', x: maxX + margin, y: targetY - th/2 }, + { id: 'down', x: targetX - tw/2, y: maxY + margin } + ]; + + let bestCand = null; + let minD = Infinity; + let rightCand = null; + + let validCandidates = candidates.filter(c => + c.x >= viewLeft && (c.x + tw) <= viewRight && + c.y >= viewTop && (c.y + th) <= viewBottom + ); + + if (validCandidates.length === 0) validCandidates = candidates; + + validCandidates.forEach(c => { + const cx = c.x + tw/2; + const cy = c.y + th/2; + const d = Math.hypot(cx - targetX, cy - targetY); + c.distance = d; + + if (c.id === 'right') { + rightCand = c; + } + + if (d < minD) { + minD = d; + bestCand = c; + } + }); + + if (rightCand && rightCand.distance <= minD * 10) { + bestCand = rightCand; + } + + let tooltipX = bestCand.x; + let tooltipY = bestCand.y; + + let lineStartX = targetX; + let lineStartY = targetY; + if (hoveredNodeId) { + const node = this.viewer.store.activeNodeMap.get(hoveredNodeId); + if (node) { + if (bestCand.id === 'right') lineStartX = node.x + node.width / 2; + else if (bestCand.id === 'left') lineStartX = node.x - node.width / 2; + else if (bestCand.id === 'up') lineStartY = node.y - node.height / 2; + else if (bestCand.id === 'down') lineStartY = node.y + node.height / 2; + } + } + + ctx.strokeStyle = theme.edgeHover; + ctx.lineWidth = 2 / transform.k; + ctx.setLineDash([5 / transform.k, 5 / transform.k]); + ctx.beginPath(); + ctx.moveTo(lineStartX, lineStartY); + if (bestCand.id === 'up') { + ctx.lineTo(tooltipX + tw/2, tooltipY + th); + } else if (bestCand.id === 'down') { + ctx.lineTo(tooltipX + tw/2, tooltipY); + } else if (bestCand.id === 'left') { + ctx.lineTo(tooltipX + tw, tooltipY + th/2); + } else { // right + ctx.lineTo(tooltipX, tooltipY + th/2); + } + ctx.stroke(); + ctx.setLineDash([]); + + ctx.fillStyle = theme.uiBg; + ctx.fillRect(tooltipX, tooltipY, tw, th); + ctx.strokeStyle = theme.uiBorder; + ctx.lineWidth = 1 / transform.k; + ctx.strokeRect(tooltipX, tooltipY, tw, th); + + ctx.fillStyle = theme.text; + ctx.textAlign = 'left'; + ctx.textBaseline = 'top'; + tooltipLines.forEach((line, idx) => { + ctx.fillText(line, tooltipX + padding, tooltipY + padding + idx * lineHeight); + }); + } +} diff --git a/devtools/fx_viewer/templates/compare.js b/devtools/fx_viewer/templates/compare.js new file mode 100644 index 00000000000..501d34b1721 --- /dev/null +++ b/devtools/fx_viewer/templates/compare.js @@ -0,0 +1,1107 @@ +// FXGraphCompare: orchestrates multi-view compare DOM, selection sync, and lifecycle. +// Layout: 3-row × (N+1)-col CSS grid. +// Col 0 (sidebar): shared controls spanning rows 0-1. +// Cols 1..N: per-graph minimap (row 0) and canvas (row 1). +// Row 2: info row using CSS subgrid for aligned property columns. +// +// Selection sync modes (config.sync.mode): +// 'auto' (default) — tries from_node_root first, then debug_handle set-intersection, falls back to node-ID match. +// 'id' — matches by node id only. +// 'layer' — matches by extensions[layer].nodes[nodeId].info[field] value equality; +// picks last in topological order on multiple matches. +// 'none' — no cross-viewer selection propagation. +// +// debug_handle normalization (_normalizeHandle): +// int dh → Set{dh} (if non-zero) +// int[] dh → Set(dh.filter(x => x !== 0)) +// null/0/[] → empty Set (no match) +// Two nodes match if their normalized sets have a non-empty intersection. +// +// Sidebar sync selector options (rebuilt by _rebuildSyncPanel): +// "Auto (from_node→handle→id)" → mode: 'auto' +// "ID only" → mode: 'id' +// "Don't sync" → mode: 'none' +// "Ext: ." → mode: 'layer' (one per registered sync_key) + +class FXGraphCompare { + static create(config) { + return new FXGraphCompare(config); + } + + constructor(config = {}) { + // Accept Map or Array (backward compat) + if (config.viewers instanceof Map) { + this._viewerMap = config.viewers; + } else if (Array.isArray(config.viewers)) { + this._viewerMap = new Map(config.viewers.map((v, i) => + [(v.config && v.config.title) || `Graph ${i + 1}`, v] + )); + } else { + this._viewerMap = new Map(); + } + this.viewers = [...this._viewerMap.values()]; + this._viewerNames = [...this._viewerMap.keys()]; + this._visibleViewers = new Set(this._viewerNames); + + this.sync = { + mode: 'auto', + layer: '', + field: '', + ...(config.sync || {}), + }; + + this.container = null; + if (config.layout && config.layout.container) { + if (typeof config.layout.container === 'string') { + this.container = document.querySelector(config.layout.container); + } else if (config.layout.container instanceof HTMLElement) { + this.container = config.layout.container; + } + } + + this._guards = new WeakSet(); + this._offs = []; + this._root = null; + this._grid = null; + this._infoRow = null; + this._minimapCells = []; + this._nameCells = []; + this._canvasCells = []; + this._colResizeObservers = []; + this._domSnapshots = new WeakMap(); + this._followSelection = true; + this._openPortalMenus = []; + this._currentTheme = this.viewers[0]?.controller?.state?.themeName || 'light'; + this._layoutRefreshQueued = false; + this._layoutRefreshAgain = false; + this._pendingRefreshAgainResetView = false; + this._needsResetOnNextVisibleLayout = false; + + if (this.container) { + this._buildCompareDOM(); + } + + this._wireSelectionSync(); + this._wireStateSync(); + + // Suppress per-viewer UI in compare mode (keep taskbar + search only) + this.viewers.forEach((v) => { + v.setUIVisibility({ + toolbar: true, + search: true, + layers: false, + theme: false, + zoomButtons: false, + fullscreenButton: false, + highlightButton: false, + }); + }); + } + + _buildCompareDOM() { + this._root = document.createElement('div'); + this._root.className = 'fx-compare-root'; + // Fallback height if container has no defined height + if (this.container.offsetHeight < 100) { + this._root.style.height = Math.round(window.innerHeight * 0.85) + 'px'; + } + this.container.appendChild(this._root); + + const N = this.viewers.length; + + // Main grid: col 0 = 160px sidebar, cols 1..N = 1fr each + this._grid = document.createElement('div'); + this._grid.className = 'fx-compare-grid'; + this._grid.style.gridTemplateColumns = `160px repeat(${N}, 1fr)`; + this._root.appendChild(this._grid); + + // Sidebar cell (col 1, rows 1-3) + const sidebar = document.createElement('div'); + sidebar.className = 'fx-compare-sidebar-cell'; + this._grid.appendChild(sidebar); + this._buildSidebar(sidebar); + this._sidebarEl = sidebar; + + // Per-viewer cells + this._minimapCells = []; + this._nameCells = []; + this._canvasCells = []; + + this.viewers.forEach((viewer, i) => { + // Snapshot original DOM positions for teardown + this._domSnapshots.set(viewer, { + mainAreaParent: viewer.mainArea.parentNode, + mainAreaNextSibling: viewer.mainArea.nextSibling, + minimapParent: viewer.minimapRenderer ? viewer.minimapRenderer.container.parentNode : null, + minimapNextSibling: viewer.minimapRenderer ? viewer.minimapRenderer.container.nextSibling : null, + wrapperDisplay: viewer.wrapper.style.display, + }); + + // Minimap cell: col i+2, row 1 + const minimapCell = document.createElement('div'); + minimapCell.className = 'fx-compare-minimap-cell'; + minimapCell.style.gridColumn = String(i + 2); + minimapCell.style.gridRow = '1'; + if (viewer.minimapRenderer) { + minimapCell.appendChild(viewer.minimapRenderer.container); + } + this._grid.appendChild(minimapCell); + this._minimapCells.push(minimapCell); + + // Name cell: col i+2, row 2 + const nameCell = document.createElement('div'); + nameCell.className = 'fx-compare-name-cell'; + nameCell.style.gridColumn = String(i + 2); + nameCell.textContent = this._viewerNames[i]; + this._grid.appendChild(nameCell); + this._nameCells.push(nameCell); + + // Canvas cell: col i+2, row 3 + const canvasCell = document.createElement('div'); + canvasCell.className = 'fx-compare-canvas-cell'; + canvasCell.style.gridColumn = String(i + 2); + canvasCell.style.gridRow = '3'; + canvasCell.appendChild(viewer.mainArea); + this._grid.appendChild(canvasCell); + this._canvasCells.push(canvasCell); + + // Hide viewer's own wrapper shell + viewer.wrapper.style.display = 'none'; + + // ResizeObserver on canvas cell + if (typeof ResizeObserver !== 'undefined') { + const ro = new ResizeObserver(() => this._scheduleLayoutRefresh()); + ro.observe(canvasCell); + if (viewer.minimapRenderer) { + ro.observe(minimapCell); + } + this._colResizeObservers.push(ro); + } + }); + + // Info row (col 1..-1, row 4) — uses CSS subgrid + this._infoRow = document.createElement('div'); + this._infoRow.className = 'fx-compare-info-row'; + this._grid.appendChild(this._infoRow); + this._updateMergedInfo(null); + + this._applyCompareTheme(this.viewers[0]?.controller?.state?.themeName || 'light'); + + this._scheduleLayoutRefresh({ resetView: true }); + } + + _scheduleLayoutRefresh(options = {}) { + if (!this._root) return; + if (this._layoutRefreshQueued) { + this._layoutRefreshAgain = true; + this._pendingRefreshAgainResetView = this._pendingRefreshAgainResetView || !!options.resetView; + return; + } + this._pendingRefreshResetView = this._pendingRefreshResetView || !!options.resetView; + this._layoutRefreshQueued = true; + requestAnimationFrame(() => { + requestAnimationFrame(() => { + this._layoutRefreshQueued = false; + if (!this._root) return; + const resetView = !!this._pendingRefreshResetView || this._needsResetOnNextVisibleLayout; + this._pendingRefreshResetView = false; + this._refreshViewerLayout({ resetView }); + if (this._layoutRefreshAgain) { + const resetAgain = this._pendingRefreshAgainResetView; + this._layoutRefreshAgain = false; + this._pendingRefreshAgainResetView = false; + this._scheduleLayoutRefresh({ resetView: resetAgain }); + } + }); + }); + } + + _refreshViewerLayout(options = {}) { + const resetView = !!options.resetView; + let sawInvalidLayout = false; + this.viewers.forEach((viewer, i) => { + if (!this._visibleViewers.has(this._viewerNames[i])) return; + if (viewer.canvasRenderer && typeof viewer.canvasRenderer.resetInteractionState === 'function') { + viewer.canvasRenderer.resetInteractionState(); + } + if (viewer.minimapRenderer && typeof viewer.minimapRenderer.resetInteractionState === 'function') { + viewer.minimapRenderer.resetInteractionState(); + } + + const canvasRect = viewer.canvasContainer && viewer.canvasContainer.getBoundingClientRect(); + const minimapRect = viewer.minimapRenderer && viewer.minimapRenderer.container.getBoundingClientRect(); + const hasCanvasLayout = canvasRect && canvasRect.width > 0 && canvasRect.height > 0; + const hasMinimapLayout = !viewer.minimapRenderer || (minimapRect && minimapRect.width > 0 && minimapRect.height > 0); + if (!hasCanvasLayout || !hasMinimapLayout) { + sawInvalidLayout = true; + return; + } + + if (viewer.canvasRenderer) viewer.canvasRenderer.resize(); + if (viewer.minimapRenderer) { + viewer.minimapRenderer.resize(); + viewer.minimapRenderer.generateThumbnail(); + } + if (resetView) viewer.init(); + else viewer.renderAll(); + }); + this._needsResetOnNextVisibleLayout = sawInvalidLayout; + } + + _buildSidebar(sidebar) { + // Layers button + menu (portal pattern — menu appended to document.body when open) + const layersWrap = document.createElement('div'); + layersWrap.style.position = 'relative'; + const layersBtn = document.createElement('button'); + layersBtn.className = 'fx-button'; + layersBtn.title = 'Layers / Color By'; + layersBtn.textContent = 'Layers'; + layersBtn.style.marginLeft = '0'; + layersBtn.style.boxSizing = 'border-box'; + layersBtn.style.width = '100%'; + const layersMenu = document.createElement('div'); + layersMenu.className = 'fx-compare-portal-menu'; + layersBtn.onclick = () => { + if (layersMenu.parentNode) { + this._closePortalMenu(layersMenu); + } else { + this._rebuildLayersMenu(layersMenu); + this._openPortalMenu(layersBtn, layersMenu); + } + }; + fxOn(this._offs, document, 'click', (e) => { + if (!layersWrap.contains(e.target) && !layersMenu.contains(e.target)) { + this._closePortalMenu(layersMenu); + } + }); + layersWrap.appendChild(layersBtn); + sidebar.appendChild(layersWrap); + + // Theme selector + const themeSel = document.createElement('select'); + themeSel.className = 'fx-select'; + themeSel.style.marginLeft = '0'; + themeSel.style.boxSizing = 'border-box'; + themeSel.style.width = '100%'; + themeSel.innerHTML = ``; + themeSel.onchange = (e) => this.viewers.forEach((v) => v.setTheme(e.target.value)); + sidebar.appendChild(themeSel); + this._themeSelect = themeSel; + + // Zoom-fit button + const zoomBtn = document.createElement('button'); + zoomBtn.className = 'fx-button'; + zoomBtn.style.marginLeft = '0'; + zoomBtn.innerHTML = '⤢ Fit'; + zoomBtn.title = 'Zoom to Fit All'; + zoomBtn.style.boxSizing = 'border-box'; + zoomBtn.style.width = '100%'; + zoomBtn.onclick = () => this.viewers.forEach((v) => v.controller.zoomToFit()); + sidebar.appendChild(zoomBtn); + + // Fullscreen button + const fsBtn = document.createElement('button'); + fsBtn.className = 'fx-button'; + fsBtn.style.marginLeft = '0'; + fsBtn.innerHTML = '⛶ Full'; + fsBtn.title = 'Fullscreen'; + fsBtn.style.boxSizing = 'border-box'; + fsBtn.style.width = '100%'; + fsBtn.onclick = () => { + if (document.fullscreenElement) { + document.exitFullscreen(); + } else { + this._root.requestFullscreen && this._root.requestFullscreen(); + } + }; + fxOn(this._offs, document, 'fullscreenchange', () => { + fsBtn.innerHTML = document.fullscreenElement ? '✕ Exit' : '⛶ Full'; + fsBtn.title = document.fullscreenElement ? 'Exit Fullscreen' : 'Fullscreen'; + }); + sidebar.appendChild(fsBtn); + + // Sync mode selector (only registered sync_keys) + const syncSel = document.createElement('select'); + syncSel.className = 'fx-select'; + syncSel.style.marginLeft = '0'; + syncSel.style.boxSizing = 'border-box'; + syncSel.style.width = '100%'; + this._syncSelect = syncSel; + this._rebuildSyncPanel(); + syncSel.onchange = (e) => { + const val = e.target.value; + if (val === 'none') { this.setSync({ mode: 'none' }); } + else if (val === 'id') { this.setSync({ mode: 'id' }); } + else if (val === 'auto') { this.setSync({ mode: 'auto' }); } + else { const [layer, field] = val.split('::'); this.setSync({ mode: 'layer', layer, field }); } + }; + sidebar.appendChild(syncSel); + + // Visible graphs toggle (portal pattern) + const visWrap = document.createElement('div'); + visWrap.style.position = 'relative'; + const visBtn = document.createElement('button'); + visBtn.className = 'fx-button'; + visBtn.style.marginLeft = '0'; + visBtn.style.boxSizing = 'border-box'; + visBtn.style.width = '100%'; + visBtn.title = 'Toggle visible graphs'; + visBtn.textContent = 'Graphs'; + const visMenu = document.createElement('div'); + visMenu.className = 'fx-compare-portal-menu'; + visBtn.onclick = () => { + if (visMenu.parentNode) { + this._closePortalMenu(visMenu); + } else { + this._rebuildVisMenu(visMenu); + this._openPortalMenu(visBtn, visMenu); + } + }; + fxOn(this._offs, document, 'click', (e) => { + if (!visWrap.contains(e.target) && !visMenu.contains(e.target)) { + this._closePortalMenu(visMenu); + } + }); + visWrap.appendChild(visBtn); + sidebar.appendChild(visWrap); + + // Follow selection toggle + const followBtn = document.createElement('button'); + followBtn.className = 'fx-button'; + followBtn.style.marginLeft = '0'; + followBtn.style.boxSizing = 'border-box'; + followBtn.style.width = '100%'; + followBtn.title = 'Toggle: zoom-to-fit on selection sync'; + const updateFollowBtn = () => { + followBtn.textContent = this._followSelection ? '\u2299 Zoom-Fit' : '\u25cb Zoom-Fit'; + followBtn.style.opacity = this._followSelection ? '1' : '0.6'; + }; + updateFollowBtn(); + followBtn.onclick = () => { + this._followSelection = !this._followSelection; + updateFollowBtn(); + }; + sidebar.appendChild(followBtn); + + // Highlight ancestors toggle + const hlBtn = document.createElement('button'); + hlBtn.className = 'fx-button'; + hlBtn.style.marginLeft = '0'; + hlBtn.style.boxSizing = 'border-box'; + hlBtn.style.width = '100%'; + hlBtn.innerHTML = '🔗'; + hlBtn.title = 'Toggle Highlight Ancestors/Descendants'; + const updateHlBtn = () => { + const on = this.viewers[0]?.controller?.state?.highlightAncestors !== false; + hlBtn.style.opacity = on ? '1' : '0.6'; + }; + updateHlBtn(); + hlBtn.onclick = () => { + const on = this.viewers[0]?.controller?.state?.highlightAncestors !== false; + this.viewers.forEach((v) => { + v.controller.state.highlightAncestors = !on; + v.controller.setState({}); + }); + updateHlBtn(); + }; + sidebar.appendChild(hlBtn); + this._hlBtn = hlBtn; + } + + _rebuildLayersMenu(menu) { + const allLayers = new Map(); + this.viewers.forEach((v) => { + Object.entries(v.store.extensions || {}).forEach(([id, ext]) => { + if (!allLayers.has(id)) allLayers.set(id, ext.name || id); + }); + }); + let html = '
Extensions
'; + allLayers.forEach((name, id) => { + html += ``; + }); + html += '
Color By
'; + html += ``; + allLayers.forEach((name, id) => { + html += ``; + }); + menu.innerHTML = html; + menu.querySelectorAll('input[type="checkbox"]').forEach((cb) => { + cb.onchange = (e) => { + this.viewers.forEach((v) => { + const active = new Set(v.controller.state.activeExtensions); + if (e.target.checked) active.add(e.target.value); + else active.delete(e.target.value); + v.setLayers([...active]); + }); + }; + }); + const currentColorBy = this.viewers[0]?.controller?.state?.colorBy || 'base'; + const matchingRadio = menu.querySelector(`input[type="radio"][name="cmp_colorby"][value="${currentColorBy}"]`); + if (matchingRadio) matchingRadio.checked = true; + menu.querySelectorAll('input[type="radio"][name="cmp_colorby"]').forEach((rb) => { + rb.onchange = (e) => { + if (e.target.checked) this.viewers.forEach((v) => v.setColorBy(e.target.value)); + }; + }); + } + + _rebuildSyncPanel() { + if (!this._syncSelect) return; + let html = '' + + '' + + ''; + const seen = new Set(['auto', 'id', 'none']); + this.viewers.forEach((v) => { + Object.entries(v.store.extensions || {}).forEach(([extId, ext]) => { + (ext.sync_keys || []).forEach((field) => { + const key = `${extId}::${field}`; + if (!seen.has(key)) { + seen.add(key); + html += ``; + } + }); + }); + }); + this._syncSelect.innerHTML = html; + const sync = this.sync; + if (sync.mode === 'none') this._syncSelect.value = 'none'; + else if (sync.mode === 'id') this._syncSelect.value = 'id'; + else if (sync.mode === 'layer' && sync.layer && sync.field) this._syncSelect.value = `${sync.layer}::${sync.field}`; + else this._syncSelect.value = 'auto'; + } + + _rebuildVisMenu(menu) { + menu.innerHTML = ''; + this._viewerNames.forEach((name) => { + const label = document.createElement('label'); + label.style.cssText = 'display:block;padding:5px;cursor:pointer;'; + const cb = document.createElement('input'); + cb.type = 'checkbox'; + cb.checked = this._visibleViewers.has(name); + cb.onchange = (e) => this._setViewerVisible(name, e.target.checked); + label.appendChild(cb); + label.appendChild(document.createTextNode(' ' + name)); + menu.appendChild(label); + }); + } + + _openPortalMenu(btn, menu) { + if (menu.parentNode) menu.parentNode.removeChild(menu); + const btnRect = btn.getBoundingClientRect(); + const rootRect = this._root.getBoundingClientRect(); + const scrollTop = this._root.scrollTop; + const scrollLeft = this._root.scrollLeft; + const top = btnRect.top - rootRect.top + scrollTop; + const left = btnRect.right - rootRect.left + scrollLeft + 4; + menu.style.top = top + 'px'; + menu.style.left = left + 'px'; + menu.style.maxHeight = Math.min(window.innerHeight - btnRect.top - 8, window.innerHeight * 0.6) + 'px'; + this._root.appendChild(menu); + this._openPortalMenus.push(menu); + const theme = (typeof THEMES !== 'undefined' && THEMES[this._currentTheme]) || THEMES?.light; + if (theme) { + menu.style.backgroundColor = theme.uiBg; + menu.style.color = theme.text; + menu.style.borderColor = theme.uiBorder; + menu.querySelectorAll('label, .fx-button, .fx-select').forEach((el) => { + el.style.color = theme.text; + }); + } + } + + _closePortalMenu(menu) { + if (menu.parentNode) menu.parentNode.removeChild(menu); + const idx = this._openPortalMenus.indexOf(menu); + if (idx !== -1) this._openPortalMenus.splice(idx, 1); + } + + _setViewerVisible(name, visible) { + if (visible) this._visibleViewers.add(name); + else this._visibleViewers.delete(name); + + const visCount = [...this._viewerNames].filter((n) => this._visibleViewers.has(n)).length; + this._grid.style.gridTemplateColumns = `160px repeat(${Math.max(1, visCount)}, 1fr)`; + + let colIdx = 2; + this.viewers.forEach((v, i) => { + const isVis = this._visibleViewers.has(this._viewerNames[i]); + if (this._minimapCells[i]) { + this._minimapCells[i].style.display = isVis ? '' : 'none'; + if (isVis) this._minimapCells[i].style.gridColumn = String(colIdx); + } + if (this._nameCells[i]) { + this._nameCells[i].style.display = isVis ? '' : 'none'; + if (isVis) this._nameCells[i].style.gridColumn = String(colIdx); + } + if (this._canvasCells[i]) { + this._canvasCells[i].style.display = isVis ? '' : 'none'; + if (isVis) this._canvasCells[i].style.gridColumn = String(colIdx); + } + if (isVis) { + colIdx++; + } + }); + + this._scheduleLayoutRefresh(); + + if (visible) { + const newViewer = this.viewers[this._viewerNames.indexOf(name)]; + requestAnimationFrame(() => { + const canvasRect = newViewer.canvasContainer && newViewer.canvasContainer.getBoundingClientRect(); + if (!canvasRect || canvasRect.width <= 0 || canvasRect.height <= 0) { + this._needsResetOnNextVisibleLayout = true; + return; + } + if (newViewer.canvasRenderer) newViewer.canvasRenderer.resize(); + if (newViewer.minimapRenderer) { + newViewer.minimapRenderer.resize(); + newViewer.minimapRenderer.generateThumbnail(); + } + let srcViewer = null, srcNodeId = null; + this.viewers.forEach((v, i) => { + if (v === newViewer || !this._visibleViewers.has(this._viewerNames[i])) return; + const sel = v.controller?.state?.selectedNodeId; + if (sel && !srcNodeId) { srcViewer = v; srcNodeId = sel; } + }); + if (srcViewer && srcNodeId) { + const targetId = this._findSyncTarget(srcViewer, srcNodeId, newViewer); + if (targetId) { + newViewer.selectNode(targetId, { center: false }); + newViewer.controller.zoomToFit(); + } else { + newViewer.controller.zoomToFit(); + } + } else { + newViewer.controller.zoomToFit(); + } + }); + } + + this._updateMergedInfo(null); + } + + _teardownCompareDOM() { + // Close any open portal menus + this._openPortalMenus.slice().forEach((m) => this._closePortalMenu(m)); + + this._colResizeObservers.forEach((ro) => ro.disconnect()); + this._colResizeObservers = []; + + this.viewers.forEach((viewer) => { + const snap = this._domSnapshots.get(viewer); + if (!snap) return; + + if (snap.mainAreaParent) { + if (snap.mainAreaNextSibling) { + snap.mainAreaParent.insertBefore(viewer.mainArea, snap.mainAreaNextSibling); + } else { + snap.mainAreaParent.appendChild(viewer.mainArea); + } + } + + if (viewer.minimapRenderer && snap.minimapParent) { + if (snap.minimapNextSibling) { + snap.minimapParent.insertBefore(viewer.minimapRenderer.container, snap.minimapNextSibling); + } else { + snap.minimapParent.appendChild(viewer.minimapRenderer.container); + } + } + + viewer.wrapper.style.display = snap.wrapperDisplay; + this._domSnapshots.delete(viewer); + }); + + if (this._root && this._root.parentNode) { + this._root.parentNode.removeChild(this._root); + } + this._root = null; + this._grid = null; + this._infoRow = null; + this._minimapCells = []; + this._nameCells = []; + this._canvasCells = []; + } + + _wireSelectionSync() { + this.viewers.forEach((viewer) => { + const off = viewer.on('selectionchange', (evt) => { + if (evt.nextSelection){ + if (this._followSelection) viewer.controller.zoomToFit(); + } + + if (this._guards.has(viewer)) return; + if (!evt.nextSelection) { + this.viewers.forEach((other) => { + if (other === viewer) return; + this._applyGuarded(other, () => other.clearSelection()); + }); + this._updateMergedInfo(null); + this.viewers.forEach((v) => v.removeHighlightGroup('_sync_candidates')); + return; + } + + const nodeIdMap = this._buildSyncedNodeMap(viewer, evt.nextSelection); + + this.viewers.forEach((other) => { + if (other === viewer) return; + const targetId = nodeIdMap.get(other); + if (targetId) { + this._applyGuarded(other, () => { + other.selectNode(targetId, { center: false }); + if (this._followSelection) other.controller.panToNode(targetId, {}); + }); + } else { + this._applyGuarded(other, () => other.clearSelection()); + } + }); + + this._updateMergedInfo(nodeIdMap); + this._applyAutoCandidateHighlights(viewer, evt.nextSelection); + }); + this._offs.push(off); + }); + } + + _buildSyncedNodeMap(sourceViewer, nodeId) { + const map = new Map(); + if (!nodeId) return map; + map.set(sourceViewer, nodeId); + this.viewers.forEach((viewer) => { + if (viewer === sourceViewer) return; + const targetId = this._findSyncTarget(sourceViewer, nodeId, viewer); + if (targetId) { + map.set(viewer, targetId); + } + }); + return map; + } + + _collectCurrentSelectionMap() { + const map = new Map(); + this.viewers.forEach((viewer) => { + const selectedNodeId = viewer.controller?.state?.selectedNodeId; + if (selectedNodeId) { + map.set(viewer, selectedNodeId); + } + }); + return map.size > 0 ? map : null; + } + + _buildAutoCandidates(sourceViewer, nodeId, targetViewer) { + const rootCandidates = this._getAllFromNodeRootCandidates(sourceViewer, nodeId, targetViewer); + const handleCandidates = this._getAllDebugHandleCandidates(sourceViewer, nodeId, targetViewer); + const targetId = this._findSyncTarget(sourceViewer, nodeId, targetViewer); + const candidates = [...rootCandidates, ...handleCandidates]; + if (targetId) candidates.push(targetId); + if (targetViewer === sourceViewer) candidates.push(nodeId); + return [...new Set(candidates)]; + } + + _applyAutoCandidateHighlights(sourceViewer, nodeId) { + if (this.sync.mode !== 'auto' || !nodeId) { + this.viewers.forEach((v) => v.removeHighlightGroup('_sync_candidates')); + return; + } + + this.viewers.forEach((viewer) => { + const allCandidates = this._buildAutoCandidates(sourceViewer, nodeId, viewer); + if (allCandidates.length > 0) { + viewer.addHighlightGroup('_sync_candidates', allCandidates, '#ffaa00'); + } else { + viewer.removeHighlightGroup('_sync_candidates'); + } + }); + } + + _syncPreviewAcrossViewers(sourceViewer, previewNodeId) { + if (!previewNodeId) { + this.viewers.forEach((other) => { + if (other === sourceViewer) return; + this._applyGuarded(other, () => { + const selectedNodeId = other.controller.state.selectedNodeId; + const selectedEdge = other.controller.state.selectedEdge; + let ancestors = new Set(); + let descendants = new Set(); + if (selectedNodeId) { + ancestors = other.store.getAncestors(selectedNodeId); + descendants = other.store.getDescendants(selectedNodeId); + } else if (selectedEdge) { + ancestors = other.store.getAncestors(selectedEdge.v); + descendants = other.store.getDescendants(selectedEdge.w); + } + other.controller.setState({ + previewNodeId: null, + ancestors, + descendants, + }, { source: 'compare-preview-sync' }); + }); + }); + this._updateMergedInfo(this._collectCurrentSelectionMap()); + this._applyAutoCandidateHighlights(sourceViewer, sourceViewer.controller?.state?.selectedNodeId || null); + return; + } + + const nodeIdMap = this._buildSyncedNodeMap(sourceViewer, previewNodeId); + this.viewers.forEach((other) => { + if (other === sourceViewer) return; + const targetPreviewId = nodeIdMap.get(other) || null; + this._applyGuarded(other, () => { + if (!targetPreviewId) { + const selectedNodeId = other.controller.state.selectedNodeId; + const selectedEdge = other.controller.state.selectedEdge; + let ancestors = new Set(); + let descendants = new Set(); + if (selectedNodeId) { + ancestors = other.store.getAncestors(selectedNodeId); + descendants = other.store.getDescendants(selectedNodeId); + } else if (selectedEdge) { + ancestors = other.store.getAncestors(selectedEdge.v); + descendants = other.store.getDescendants(selectedEdge.w); + } + other.controller.setState({ + previewNodeId: null, + ancestors, + descendants, + }, { source: 'compare-preview-sync' }); + return; + } + + other.controller.setState({ + previewNodeId: targetPreviewId, + ancestors: other.store.getAncestors(targetPreviewId), + descendants: other.store.getDescendants(targetPreviewId), + }, { source: 'compare-preview-sync' }); + if (this._followSelection) { + other.controller.panToNode(targetPreviewId); + } + }); + }); + this._updateMergedInfo(nodeIdMap); + this._applyAutoCandidateHighlights(sourceViewer, previewNodeId); + } + + _findSyncTarget(sourceViewer, nodeId, targetViewer) { + const mode = this.sync.mode; + if (mode === 'none') return null; + if (mode === 'auto') { + const byRoot = this._syncByFromNodeRoot(sourceViewer, nodeId, targetViewer); + if (byRoot) return byRoot; + const byHandle = this._syncByDebugHandle(sourceViewer, nodeId, targetViewer); + if (byHandle) return byHandle; + return targetViewer.store.activeNodeMap.has(nodeId) ? nodeId : null; + } + if (mode === 'id') { + return targetViewer.store.activeNodeMap.has(nodeId) ? nodeId : null; + } + if (mode === 'layer') { + const { layer, field } = this.sync; + const srcVal = sourceViewer.store.extensions[layer]?.nodes[nodeId]?.info[field]; + if (srcVal === undefined) return null; + const sourceNode = sourceViewer.store.activeNodeMap.get(nodeId); + const candidates = targetViewer.store.activeNodes.filter( + (n) => targetViewer.store.extensions[layer]?.nodes[n.id]?.info[field] === srcVal + ); + const picked = this._pickCandidateByTargetMode(sourceNode, candidates); + return picked ? picked.id : null; + } + return null; + } + + _getTargetMode(node) { + const target = String(node?.info?.target || '').toLowerCase(); + if (target.includes('dequantize')) return 'dequantize'; + if (target.includes('quantize') || target.includes('activation_post_process')) return 'quantize'; + return 'none'; + } + + _pickCandidateByTargetMode(sourceNode, candidates) { + if (!Array.isArray(candidates) || candidates.length === 0) return null; + const sourceMode = this._getTargetMode(sourceNode); + for (let i = candidates.length - 1; i >= 0; i--) { + if (this._getTargetMode(candidates[i]) === sourceMode) return candidates[i]; + } + return candidates[candidates.length - 1]; + } + + // Normalize debug_handle (int | int[] | null) → Set + _normalizeHandle(dh) { + if (!dh && dh !== 0) return new Set(); + if (typeof dh === 'number') return dh !== 0 ? new Set([dh]) : new Set(); + if (Array.isArray(dh)) return new Set(dh.filter((x) => typeof x === 'number' && x !== 0)); + return new Set(); + } + + _syncByFromNodeRoot(sourceViewer, nodeId, targetViewer) { + const srcNode = sourceViewer.store.activeNodeMap.get(nodeId); + const srcRoot = srcNode?.info?.from_node_root; + if (!srcRoot) return null; + const candidates = targetViewer.store.activeNodes.filter((n) => { + const tgtNode = targetViewer.store.activeNodeMap.get(n.id); + return tgtNode?.info?.from_node_root === srcRoot; + }); + const picked = this._pickCandidateByTargetMode(srcNode, candidates); + return picked ? picked.id : null; + } + + _getAllFromNodeRootCandidates(sourceViewer, nodeId, targetViewer) { + const srcNode = sourceViewer.store.activeNodeMap.get(nodeId); + const srcRoot = srcNode?.info?.from_node_root; + if (!srcRoot) return []; + return targetViewer.store.activeNodes + .filter((n) => { + const tgtNode = targetViewer.store.activeNodeMap.get(n.id); + return tgtNode?.info?.from_node_root === srcRoot; + }) + .map((n) => n.id); + } + + _syncByDebugHandle(sourceViewer, nodeId, targetViewer) { + const srcNode = sourceViewer.store.activeNodeMap.get(nodeId); + const srcSet = this._normalizeHandle(srcNode?.info?.debug_handle); + if (srcSet.size === 0) return null; + const candidates = targetViewer.store.activeNodes.filter((n) => { + const tgtSet = this._normalizeHandle( + targetViewer.store.activeNodeMap.get(n.id)?.info?.debug_handle + ); + for (const v of srcSet) { if (tgtSet.has(v)) return true; } + return false; + }); + const picked = this._pickCandidateByTargetMode(srcNode, candidates); + return picked ? picked.id : null; + } + + _getAllDebugHandleCandidates(sourceViewer, nodeId, targetViewer) { + const srcNode = sourceViewer.store.activeNodeMap.get(nodeId); + const srcSet = this._normalizeHandle(srcNode?.info?.debug_handle); + if (srcSet.size === 0) return []; + return targetViewer.store.activeNodes + .filter((n) => { + const tgtSet = this._normalizeHandle( + targetViewer.store.activeNodeMap.get(n.id)?.info?.debug_handle + ); + for (const v of srcSet) { if (tgtSet.has(v)) return true; } + return false; + }) + .map((n) => n.id); + } + + _updateMergedInfo(nodeIdMap) { + if (!this._infoRow) return; + while (this._infoRow.firstChild) this._infoRow.removeChild(this._infoRow.firstChild); + + if (!nodeIdMap) { + const ph = document.createElement('div'); + ph.className = 'fx-compare-info-placeholder'; + ph.textContent = 'No node selected — click a node to compare'; + this._infoRow.appendChild(ph); + return; + } + + const makeCell = (cls, text) => { + const el = document.createElement('div'); + el.className = cls; + el.textContent = text; + return el; + }; + + const visViewerPairs = this.viewers + .map((v, i) => ({ v, name: this._viewerNames[i] })) + .filter(({ name }) => this._visibleViewers.has(name)); + const sections = [{ key: 'base', title: null, props: [] }]; + const sectionByKey = new Map([['base', sections[0]]]); + const rowData = new Map(); + const addProp = (section, prop) => { + if (!section.props.includes(prop)) section.props.push(prop); + }; + const addSection = (key, title) => { + if (!sectionByKey.has(key)) { + const section = { key, title, props: [] }; + sectionByKey.set(key, section); + sections.push(section); + } + return sectionByKey.get(key); + }; + + nodeIdMap.forEach((nid, v) => { + const node = v.store.activeNodeMap.get(nid); + if (!node) return; + const baseNode = (v.store.baseData.nodes || []).find((n) => n.id === nid) || node; + const bySection = new Map(); + const baseProps = { id: nid, op: baseNode.op || node.op || '', ...(baseNode.info || {}) }; + Object.keys(baseProps).forEach((prop) => addProp(sections[0], prop)); + bySection.set('base', baseProps); + + const activeExtensions = Array.from(v.controller?.state?.activeExtensions || []); + activeExtensions.forEach((extId) => { + const ext = v.store.extensions && v.store.extensions[extId]; + const extNode = ext && ext.nodes && ext.nodes[nid]; + if (!extNode || !extNode.info) return; + const title = ext.name || extId; + const section = addSection(`ext:${extId}`, title); + const props = { ...extNode.info }; + Object.keys(props).forEach((prop) => addProp(section, prop)); + bySection.set(section.key, props); + }); + + rowData.set(v, bySection); + }); + + // Header row + this._infoRow.appendChild(makeCell('fx-compare-info-hdr fx-compare-info-prop', 'Property')); + visViewerPairs.forEach(({ name }) => { + this._infoRow.appendChild(makeCell('fx-compare-info-hdr', name)); + }); + + let rowIdx = 0; + sections.forEach((section) => { + if (section.props.length === 0) return; + if (section.title) { + this._infoRow.appendChild(makeCell('fx-compare-info-section', section.title)); + } + section.props.forEach((prop) => { + const rowCls = rowIdx % 2 === 1 ? ' fx-compare-info-row-alt' : ''; + rowIdx++; + const vals = visViewerPairs.map(({ v }) => { + const bySection = rowData.get(v); + const d = bySection && bySection.get(section.key); + if (!d || d[prop] === undefined) return ' -- '; + const raw = d[prop]; + return (raw !== null && typeof raw === 'object') ? JSON.stringify(raw, null, 2) : String(raw); + }); + const allSame = vals.every((v) => v === vals[0]); + this._infoRow.appendChild(makeCell('fx-compare-info-prop' + rowCls, prop)); + vals.forEach((val) => { + this._infoRow.appendChild(makeCell('fx-compare-info-val' + rowCls + (allSame ? '' : ' fx-compare-info-diff'), val)); + }); + }); + }); + } + + _wireStateSync() { + this.viewers.forEach((viewer) => { + const off = viewer.on('statechange', (evt) => { + if (this._guards.has(viewer)) return; + const prev = evt.prevState || {}; + const next = evt.nextState || {}; + const themeChanged = prev.theme !== next.theme && typeof next.theme === 'string'; + if (themeChanged) { + this.viewers.forEach((other) => { + if (other === viewer) return; + this._applyGuarded(other, () => { + other.setTheme(next.theme); + }); + }); + } + + const themeNameChanged = prev.themeName !== next.themeName && typeof next.themeName === 'string'; + if (themeNameChanged && this._themeSelect) { + this._themeSelect.value = next.themeName; + } + if (themeNameChanged) { + this._applyCompareTheme(next.themeName); + } + + if (prev.previewNodeId !== next.previewNodeId) { + this._syncPreviewAcrossViewers(viewer, next.previewNodeId || null); + } + }); + this._offs.push(off); + }); + } + + setSync(syncPatch = {}) { + this.sync = { ...this.sync, ...syncPatch }; + if (this.sync.mode !== 'auto') { + this.viewers.forEach((v) => v.removeHighlightGroup('_sync_candidates')); + } + this._rebuildSyncPanel(); + } + + /** @deprecated No-op — tiled layout is always used in compare mode */ + setTiled() {} + + /** @deprecated No-op — sidebar replaces sharedTaskbar */ + setCompact() {} + + _applyCompareTheme(themeName) { + this._currentTheme = themeName; + const isDark = themeName === 'dark'; + const r = this._root.style; + if (isDark) { + r.setProperty('--cmp-bg', '#1e1e1e'); + r.setProperty('--cmp-text', '#ffffff'); + r.setProperty('--cmp-border', '#444444'); + r.setProperty('--cmp-border-strong', '#555555'); + r.setProperty('--cmp-sidebar-bg', 'rgba(255,255,255,0.05)'); + r.setProperty('--cmp-info-bg', '#1e1e1e'); + r.setProperty('--cmp-prop-bg', 'rgba(255,255,255,0.05)'); + r.setProperty('--cmp-hdr-bg', 'rgba(255,255,255,0.08)'); + r.setProperty('--cmp-diff-bg', 'rgba(255,160,40,0.10)'); + r.setProperty('--cmp-diff-accent', '#c87830'); + r.setProperty('--cmp-name-bg', 'rgba(255,255,255,0.1)'); + r.setProperty('--cmp-ui-bg', 'rgba(30,30,30,0.95)'); + r.setProperty('--cmp-ui-hover', '#333333'); + } else { + r.setProperty('--cmp-bg', '#ffffff'); + r.setProperty('--cmp-text', '#000000'); + r.setProperty('--cmp-border', '#e5e7eb'); + r.setProperty('--cmp-border-strong', '#cccccc'); + r.setProperty('--cmp-sidebar-bg', 'rgba(0,0,0,0.02)'); + r.setProperty('--cmp-info-bg', '#ffffff'); + r.setProperty('--cmp-prop-bg', 'rgba(0,0,0,0.02)'); + r.setProperty('--cmp-hdr-bg', 'rgba(0,0,0,0.04)'); + r.setProperty('--cmp-diff-bg', 'rgba(255,140,0,0.06)'); + r.setProperty('--cmp-diff-accent', '#e08a3c'); + r.setProperty('--cmp-name-bg', 'rgba(0,0,0,0.06)'); + r.setProperty('--cmp-ui-bg', 'rgba(255,255,255,0.95)'); + r.setProperty('--cmp-ui-hover', '#f0f8ff'); + } + const theme = (typeof THEMES !== 'undefined' && THEMES[themeName]) || THEMES?.light; + if (theme && this._sidebarEl) { + this._sidebarEl.querySelectorAll('.fx-button, .fx-select').forEach((el) => { + el.style.backgroundColor = theme.uiBg; + el.style.color = theme.text; + el.style.borderColor = theme.uiBorder; + }); + } + if (theme) { + for (const menu of this._openPortalMenus) { + menu.style.backgroundColor = theme.uiBg; + menu.style.color = theme.text; + menu.style.borderColor = theme.uiBorder; + menu.querySelectorAll('label, .fx-button, .fx-select').forEach((el) => { + el.style.color = theme.text; + }); + } + } + } + + _applyGuarded(viewer, fn) { + this._guards.add(viewer); + try { + fn(); + } finally { + setTimeout(() => this._guards.delete(viewer), 0); + } + } + + setViewerVisible(name, visible) { + this._setViewerVisible(name, visible); + } + + refreshLayout(options = {}) { + this._scheduleLayoutRefresh(options); + } + + destroy() { + this._teardownCompareDOM(); + this._offs.forEach((off) => { + try { off(); } catch (_) {} + }); + this._offs = []; + } +} + +if (typeof globalThis !== 'undefined') { + globalThis.FXGraphCompare = FXGraphCompare; +} diff --git a/devtools/fx_viewer/templates/fx_graph_viewer.js b/devtools/fx_viewer/templates/fx_graph_viewer.js new file mode 100644 index 00000000000..61279b6e474 --- /dev/null +++ b/devtools/fx_viewer/templates/fx_graph_viewer.js @@ -0,0 +1,841 @@ +// FXGraphViewer: public facade for the graph viewer runtime. +// +// Construction: +// FXGraphViewer.create(config) — preferred factory +// new FXGraphViewer(containerId, payload) — legacy compat +// +// Key public methods: +// init() — initial zoom-to-fit / position +// getState() / setState(patch) — state snapshot and update +// setTheme(name) — 'light' | 'dark' | custom +// setLayers(ids[]) / setColorBy(id) +// selectNode(id, opts?) / clearSelection() +// panToNode(id) / animateToNode(id, opts?) / zoomToFit() +// upsertLayer / removeLayer / patchLayerNodes / setLayerLabel / setColorRule +// enterFullscreen() / exitFullscreen() +// on(event, fn) / off(event, fn) +// destroy() +// +// Highlight group API (programmatic overlay, independent of selection): +// addHighlightGroup(groupId, nodeIds, color) — add/replace a named group +// removeHighlightGroup(groupId) — remove one group +// clearAllHighlightGroups() — remove all groups +// getHighlightGroups() — Map +// +// Events: 'selectionchange', 'statechange', 'layoutchange', 'themechange', 'error' +class FXGraphViewer { + static create(config) { + return new FXGraphViewer(config); + } + + static registerTheme(name, themeTokens) { + if (!name || !themeTokens || typeof themeTokens !== 'object') { + throw new Error('registerTheme(name, themeTokens) requires valid inputs'); + } + THEMES[name] = themeTokens; + } + + constructor(arg1, arg2) { + this._listeners = new Map(); + this._layoutState = {}; + this._teardownFns = []; + + this.config = this._normalizeConfig(arg1, arg2); + this.containerId = this.config._resolved.root.id || 'fx-viewer-root'; + this.rootContainer = this.config._resolved.root; + this._injectStyles(); + this._buildShell(); + this.config._resolved.slots = this._resolveSlots((this.config.mount && this.config.mount.slots) || {}); + + this.store = new GraphDataStore(this.config.payload); + this.searchEngine = new SearchEngine(this.store); + this.controller = new ViewerController(this, this.config.state || {}); + + const slots = this.config._resolved.slots; + const canvasMount = slots.canvas || this.mainArea; + this.canvasRenderer = new CanvasRenderer(canvasMount, this); + this.canvasContainer = this.canvasRenderer.canvasContainer; + + if (this._shouldUseInternalSidebar()) { + this.resizerH = document.createElement('div'); + this.resizerH.className = 'fx-resizer-h'; + this.resizerH.title = 'Drag to resize minimap height.'; + } else { + this.resizerH = null; + } + + this.ui = new UIManager(this.mainArea, this, { + controls: this.config.ui.controls, + mounts: { + toolbarContainer: slots.toolbar || this.mainArea, + legendContainer: slots.legend || this.mainArea, + infoContainer: slots.info || (this._shouldUseInternalSidebar() ? this.sidebar : this.mainArea), + }, + }); + + if (this._isMinimapVisible()) { + const minimapMount = slots.minimap || (this._shouldUseInternalSidebar() ? this.sidebar : this.mainArea); + this.minimapRenderer = new MinimapRenderer(minimapMount, this); + const minimapHeight = this.config.layout?.panels?.minimap?.height; + if (typeof minimapHeight === 'number' && minimapHeight > 0) { + this.minimapRenderer.container.style.height = `${minimapHeight}px`; + } + if (this.resizerH && minimapMount === this.sidebar) { + this.sidebar.insertBefore(this.resizerH, this.minimapRenderer.container); + } + } else { + this.minimapRenderer = null; + } + + this.setupResizer(); + this.applyLayout(this.config.layout); + + if (slots.info) { + slots.info.style.overflow = 'hidden'; + slots.info.style.minHeight = '0'; + if (this.ui && this.ui.infoPanel) { + this.ui.infoPanel.style.height = '100%'; + this.ui.infoPanel.style.overflowY = 'auto'; + this.ui.infoPanel.style.boxSizing = 'border-box'; + } + } + } + + _normalizeConfig(arg1, arg2) { + const isNewConfig = arg1 && typeof arg1 === 'object' && 'payload' in arg1; + const config = isNewConfig + ? { ...arg1 } + : { + payload: arg2, + mount: { root: typeof arg1 === 'string' ? `#${arg1}` : arg1 }, + }; + + if (!config.payload) { + throw new Error('FXGraphViewer requires a payload'); + } + + const root = this._resolveElement(config.mount && config.mount.root); + if (!root) { + throw new Error(`FXGraphViewer root mount not found: ${String(config.mount && config.mount.root)}`); + } + + const preset = ((config.layout && config.layout.preset) || 'split').toLowerCase(); + const presetDefaults = this._presetDefaults(preset); + const mergedLayout = this._deepMerge(presetDefaults.layout, config.layout || {}); + + const mergedUI = this._deepMerge(presetDefaults.ui, config.ui || {}); + if ( + mergedLayout && + mergedLayout.fullscreen && + Object.prototype.hasOwnProperty.call(mergedLayout.fullscreen, 'button') + ) { + mergedUI.controls.fullscreenButton = !!mergedLayout.fullscreen.button; + } + const mergedState = this._deepMerge(presetDefaults.state, config.state || {}); + + const slots = (config.mount && config.mount.slots) || {}; + const resolvedSlots = this._resolveSlots(slots); + + if (Array.isArray(mergedState.activeExtensions)) { + mergedState.activeExtensions = mergedState.activeExtensions.slice(); + } + + return { + payload: config.payload, + mount: config.mount || { root }, + layout: mergedLayout, + ui: mergedUI, + state: mergedState, + _resolved: { + root, + slots: resolvedSlots, + preset, + }, + }; + } + + _resolveElement(ref) { + if (!ref) return null; + if (typeof ref === 'string') { + if (ref.startsWith('#') || ref.startsWith('.')) { + return document.querySelector(ref); + } + return document.getElementById(ref) || document.querySelector(ref); + } + if (ref instanceof HTMLElement) return ref; + return null; + } + + _resolveSlots(slots) { + return { + canvas: this._resolveElement(slots.canvas), + toolbar: this._resolveElement(slots.toolbar), + info: this._resolveElement(slots.info), + minimap: this._resolveElement(slots.minimap), + legend: this._resolveElement(slots.legend), + }; + } + + _presetDefaults(preset) { + const split = { + layout: { + preset, + panels: { + sidebar: { visible: true, width: 500, resizable: true, collapsible: true }, + info: { visible: true, dock: 'sidebar' }, + minimap: { visible: true, dock: 'sidebar', height: 500, resizable: true }, + legend: { visible: true, dock: 'canvas' }, + }, + fullscreen: { enabled: true, button: false }, + }, + ui: { + controls: { + toolbar: true, + search: true, + layers: true, + colorBy: true, + theme: true, + legend: true, + zoomButtons: true, + fullscreenButton: false, + highlightButton: true, + }, + }, + state: { + theme: 'light', + colorBy: 'base', + highlightAncestors: true, + }, + }; + + if (preset === 'compact') { + split.layout.panels.sidebar.visible = false; + split.layout.panels.minimap.visible = false; + split.layout.panels.info.visible = false; + } + if (preset === 'headless') { + split.layout.panels.sidebar.visible = false; + split.layout.panels.minimap.visible = false; + split.layout.panels.info.visible = false; + split.ui.controls.toolbar = false; + split.ui.controls.search = false; + split.ui.controls.layers = false; + split.ui.controls.colorBy = false; + split.ui.controls.theme = false; + split.ui.controls.legend = false; + } + if (preset === 'custom') { + split.layout.panels.sidebar.visible = false; + split.layout.panels.minimap.visible = false; + split.layout.panels.info.visible = false; + } + return split; + } + + _deepMerge(base, patch) { + if (!patch || typeof patch !== 'object') { + return Array.isArray(base) ? base.slice() : { ...base }; + } + const out = Array.isArray(base) ? base.slice() : { ...base }; + Object.keys(patch).forEach((key) => { + const patchVal = patch[key]; + const baseVal = out[key]; + if ( + patchVal && + typeof patchVal === 'object' && + !Array.isArray(patchVal) && + baseVal && + typeof baseVal === 'object' && + !Array.isArray(baseVal) + ) { + out[key] = this._deepMerge(baseVal, patchVal); + } else { + out[key] = patchVal; + } + }); + return out; + } + + _injectStyles() { + if (document.getElementById('fx-viewer-styles')) return; + const style = document.createElement('style'); + style.id = 'fx-viewer-styles'; + style.innerHTML = ` + .fx-viewer-wrapper { display: flex; flex-direction: row; width: 100%; height: 100%; overflow: hidden; font-family: sans-serif; } + .fx-main-area { flex: 1; position: relative; overflow: hidden; min-width: 60%; } + .fx-resizer { width: 6px; background: #ccc; cursor: col-resize; z-index: 20; transition: background 0.2s; } + .fx-resizer:hover, .fx-resizer.dragging { background: #999; } + .fx-sidebar { width: 500px; display: flex; flex-direction: column; background: #fff; border-left: 1px solid #ccc; z-index: 10; } + .fx-sidebar.collapsed { display: none; } + .fx-canvas { display: block; width: 100%; height: 100%; } + .fx-taskbar { position: absolute; top: 10px; left: 10px; right: 10px; min-height: 40px; border-radius: 4px; display: flex; align-items: center; padding: 0 10px; box-shadow: 0 2px 5px rgba(0,0,0,0.1); z-index: 10; border: 1px solid transparent; overflow: visible; flex-wrap: wrap; gap: 6px; } + .fx-search-container { position: relative; flex: 1; max-width: 400px; } + .fx-search-input { width: 100%; padding: 6px; box-sizing: border-box; } + .fx-layers-menu { position: absolute; top: 100%; right: 0; min-width: 260px; max-width: 420px; max-height: 60vh; overflow-y: auto; display: none; box-shadow: 0 4px 6px rgba(0,0,0,0.1); border: 1px solid transparent; z-index: 200; } + .fx-search-item { padding: 8px; cursor: pointer; border-bottom: 1px solid transparent; } + .fx-search-item:hover, .fx-search-item.active { background: var(--fx-ui-hover, #f0f8ff); } + .fx-minimap-container { width: 100%; height: 500px; border-top: 1px solid transparent; flex-shrink: 0; } + .fx-minimap { width: 100%; height: 100%; display: block; cursor: crosshair; } + .fx-info-panel { flex: 1; overflow-y: auto; padding: 15px; font-size: 13px; display: block; } + .fx-info-panel h3 { margin-top: 0; margin-bottom: 15px; font-size: 15px; word-break: break-all; } + .fx-info-table { width: 100%; border-collapse: collapse; border: 1px solid transparent; } + .fx-info-table th, .fx-info-table td { border: 1px solid transparent; padding: 6px; text-align: left; vertical-align: top; } + .fx-info-table th { width: 60px; font-weight: bold; } + .fx-ext-header { margin-top: 15px; padding: 4px 6px; font-weight: bold; font-size: 12px; letter-spacing: 0.5px; background: rgba(0,0,0,0.03); } + .fx-legend-overlay { position: absolute; right: 10px; bottom: 10px; padding: 8px 10px; border: 1px solid transparent; border-radius: 4px; font-size: 12px; max-width: 260px; max-height: 40vh; overflow-y: auto; box-shadow: 0 2px 6px rgba(0,0,0,0.1); z-index: 15; } + .fx-link { color: #0366d6; cursor: pointer; text-decoration: none; font-family: monospace; display: inline-block; margin-bottom: 4px; word-break: break-all; } + .fx-link:hover { text-decoration: underline; } + .fx-button { margin-left: 10px; padding: 6px 12px; cursor: pointer; background: transparent; border: 1px solid transparent; border-radius: 4px; font-size: 16px; display: flex; align-items: center; justify-content: center; transition: background 0.2s; } + .fx-select { margin-left: 10px; padding: 4px; border-radius: 4px; font-size: 14px; text-align-last: center; } + .fx-resizer-h { height: 6px; background: #ccc; cursor: row-resize; z-index: 20; transition: background 0.2s; flex-shrink: 0; } + .fx-resizer-h:hover, .fx-resizer-h.dragging { background: #999; } + .fx-hidden { display: none !important; } + .fx-compare-root { --cmp-bg: #fff; --cmp-text: #000; --cmp-border: #e5e7eb; --cmp-border-strong: #ccc; --cmp-sidebar-bg: rgba(0,0,0,0.02); --cmp-info-bg: #fff; --cmp-prop-bg: rgba(0,0,0,0.02); --cmp-hdr-bg: rgba(0,0,0,0.04); --cmp-diff-bg: rgba(255,140,0,0.06); --cmp-diff-accent: #e08a3c; --cmp-name-bg: rgba(0,0,0,0.06); --cmp-ui-bg: rgba(255,255,255,0.95); --cmp-ui-hover: #f0f8ff; display: flex; flex-direction: column; width: 100%; height: 100%; overflow-y: auto; overflow-x: hidden; background: var(--cmp-bg); color: var(--cmp-text); padding-bottom: 24px; box-sizing: border-box; position: relative; } + .fx-compare-grid { display: grid; grid-template-rows: minmax(300px, 600px) auto minmax(400px, 80vh) auto; overflow: visible; border: 6px solid var(--cmp-border-strong); margin: 0 8px 8px 0; } + .fx-compare-sidebar-cell { grid-column: 1; grid-row: 1 / 4; display: flex; flex-direction: column; border-right: 1px solid var(--cmp-border); background: var(--cmp-sidebar-bg); overflow: hidden; position: sticky; top: 0; align-self: start; max-height: 100vh; z-index: 10; padding: 6px 4px; gap: 4px; } + .fx-compare-sidebar-info-cell { grid-column: 1; grid-row: 4; border-right: 1px solid var(--cmp-border); border-top: 1px solid var(--cmp-border-strong); background: var(--cmp-sidebar-bg); overflow: hidden; } + .fx-compare-minimap-cell { overflow: hidden; border-bottom: 6px solid var(--cmp-border-strong); border-left: 6px solid var(--cmp-border-strong); } + .fx-compare-minimap-cell .fx-minimap-container { width: 100%; height: 100% !important; border-top: none; } + .fx-compare-name-cell { grid-row: 2; display: flex; align-items: center; justify-content: center; height: 24px; padding: 0 6px; font-size: 11px; font-weight: 600; background: var(--cmp-name-bg); border-bottom: 6px solid var(--cmp-border-strong); border-left: 6px solid var(--cmp-border-strong); overflow: hidden; text-overflow: ellipsis; white-space: nowrap; box-sizing: border-box; } + .fx-compare-canvas-cell { overflow: hidden; border-left: 6px solid var(--cmp-border-strong); position: relative; } + .fx-compare-canvas-cell .fx-main-area { min-width: 0; width: 100%; height: 100%; } + .fx-compare-info-row { grid-column: 1 / -1; grid-row: 4; display: grid; grid-template-columns: subgrid; align-content: start; border-top: 3px solid var(--cmp-border-strong); overflow: visible; font-size: 13px; background: var(--cmp-info-bg); } + .fx-compare-info-prop { font-weight: 600; padding: 6px 10px; border-bottom: 1px solid var(--cmp-border); background: var(--cmp-prop-bg); border-right: 2px solid var(--cmp-border-strong); min-width: 0; overflow-x: auto; overflow-y: hidden; white-space: nowrap; position: sticky; left: 0; z-index: 2; letter-spacing: 0.2px; } + .fx-compare-info-val { font-family: monospace; padding: 6px 12px; border-bottom: 1px solid var(--cmp-border); border-left: 2px solid var(--cmp-border-strong); min-width: 0; overflow: hidden; white-space: normal; word-break: break-word; } + .fx-compare-info-hdr { font-weight: 700; padding: 7px 10px; border-bottom: 3px solid var(--cmp-border-strong); border-left: 6px solid var(--cmp-border-strong); background: var(--cmp-hdr-bg); position: sticky; top: 0; min-width: 0; overflow: hidden; white-space: normal; word-break: break-word; z-index: 1; letter-spacing: 0.4px; font-size: 12px; text-transform: uppercase; } + .fx-compare-info-section { grid-column: 1 / -1; padding: 7px 10px; font-weight: 700; font-size: 12px; letter-spacing: 0.4px; text-transform: uppercase; border-top: 2px solid var(--cmp-border-strong); border-bottom: 1px solid var(--cmp-border); background: var(--cmp-hdr-bg); position: sticky; left: 0; z-index: 2; } + .fx-compare-info-diff { background: var(--cmp-diff-bg); border-left: 3px solid var(--cmp-diff-accent); } + .fx-compare-info-row-alt { background: rgba(0,0,0,0.018); } + .fx-compare-root:fullscreen, .fx-compare-root:-webkit-full-screen, .fx-compare-root:-moz-full-screen, .fx-compare-root:-ms-fullscreen { background: var(--cmp-bg); color: var(--cmp-text); width: 100vw; height: 100vh; overflow-y: auto; overflow-x: hidden; } + .fx-compare-root:fullscreen .fx-compare-sidebar-cell, .fx-compare-root:-webkit-full-screen .fx-compare-sidebar-cell, .fx-compare-root:-moz-full-screen .fx-compare-sidebar-cell, .fx-compare-root:-ms-fullscreen .fx-compare-sidebar-cell { background: var(--cmp-sidebar-bg); border-right-color: var(--cmp-border); } + .fx-compare-info-placeholder { color: #888; text-align: center; padding: 12px; grid-column: 1 / -1; } + .fx-compare-canvas-cell .fx-legend-overlay { max-width: 40%; max-height: 20%; overflow: auto; font-size: 11px; } + .fx-compare-portal-menu { position: absolute; min-width: 220px; max-height: 60vh; overflow-y: auto; background: rgba(255,255,255,0.95); backdrop-filter: blur(4px); -webkit-backdrop-filter: blur(4px); border: 1px solid #ccc; box-shadow: 0 4px 12px rgba(0,0,0,0.15); border-radius: 4px; z-index: 9999; padding: 4px 0; font-size: 13px; } + .fx-search-menu { position: absolute; top: 100%; left: 0; right: 0; max-height: 300px; overflow-y: auto; display: none; box-shadow: 0 4px 6px rgba(0,0,0,0.1); border: 1px solid transparent; border-top: none; z-index: 100; background: rgba(255,255,255,0.88); backdrop-filter: blur(2px); } + `; + document.head.appendChild(style); + } + + _buildShell() { + const root = this.rootContainer; + const oldWrappers = root.querySelectorAll(':scope > .fx-viewer-wrapper[data-fx-viewer-owned="true"]'); + oldWrappers.forEach((node) => node.remove()); + + this.wrapper = document.createElement('div'); + this.wrapper.className = 'fx-viewer-wrapper'; + this.wrapper.dataset.fxViewerOwned = 'true'; + root.appendChild(this.wrapper); + + this.mainArea = document.createElement('div'); + this.mainArea.className = 'fx-main-area'; + this.wrapper.appendChild(this.mainArea); + + this.resizer = document.createElement('div'); + this.resizer.className = 'fx-resizer'; + this.resizer.title = 'Drag to resize sidebar. Double click to toggle.'; + this.wrapper.appendChild(this.resizer); + + this.sidebar = document.createElement('div'); + this.sidebar.className = 'fx-sidebar'; + this.wrapper.appendChild(this.sidebar); + } + + _shouldUseInternalSidebar() { + const slots = this.config._resolved.slots; + const panels = this.config.layout?.panels || {}; + const infoInternal = panels.info?.visible !== false && !slots.info; + const minimapInternal = panels.minimap?.visible !== false && !slots.minimap; + return !!(infoInternal || minimapInternal); + } + + _isSidebarVisible() { + const panels = this.config.layout?.panels || {}; + if (panels.sidebar && panels.sidebar.visible === false) { + return false; + } + return this._shouldUseInternalSidebar(); + } + + _isMinimapVisible() { + const panels = this.config.layout?.panels || {}; + return !(panels.minimap && panels.minimap.visible === false); + } + + setupResizer() { + let isResizing = false; + let isResizingH = false; + + if (!this.resizer) return; + + const onResizerMouseDown = (e) => { + if (!this._isSidebarVisible() || this.config.layout?.panels?.sidebar?.resizable === false) return; + isResizing = true; + this.resizer.classList.add('dragging'); + document.body.style.cursor = 'col-resize'; + e.preventDefault(); + }; + fxOn(this._teardownFns, this.resizer, 'mousedown', onResizerMouseDown); + + if (this.resizerH) { + const onResizerHMouseDown = (e) => { + if (this.config.layout?.panels?.minimap?.resizable === false) return; + isResizingH = true; + this.resizerH.classList.add('dragging'); + document.body.style.cursor = 'row-resize'; + e.preventDefault(); + }; + fxOn(this._teardownFns, this.resizerH, 'mousedown', onResizerHMouseDown); + } + + const onWindowMouseMove = (e) => { + if (isResizing) { + const containerRect = this.wrapper.getBoundingClientRect(); + let newWidth = containerRect.right - e.clientX; + newWidth = Math.max(150, Math.min(newWidth, containerRect.width * 0.4)); + this.sidebar.style.width = `${newWidth}px`; + this._layoutState.sidebarWidth = newWidth; + + this.canvasRenderer.resize(); + if (this.minimapRenderer) { + this.minimapRenderer.resize(); + this.minimapRenderer.generateThumbnail(); + } + this.renderAll(); + } else if (isResizingH && this.minimapRenderer) { + const containerRect = this.wrapper.getBoundingClientRect(); + let newHeight = containerRect.bottom - e.clientY; + newHeight = Math.max(100, Math.min(newHeight, containerRect.height - 100)); + this.minimapRenderer.container.style.height = `${newHeight}px`; + this._layoutState.minimapHeight = newHeight; + + this.minimapRenderer.resize(); + this.minimapRenderer.generateThumbnail(); + this.renderAll(); + } + }; + fxOn(this._teardownFns, window, 'mousemove', onWindowMouseMove); + + const onWindowMouseUp = () => { + if (isResizing) { + isResizing = false; + this.resizer.classList.remove('dragging'); + document.body.style.cursor = ''; + } + if (isResizingH && this.resizerH) { + isResizingH = false; + this.resizerH.classList.remove('dragging'); + document.body.style.cursor = ''; + } + }; + fxOn(this._teardownFns, window, 'mouseup', onWindowMouseUp); + + const onResizerDblClick = () => { + if (this.config.layout?.panels?.sidebar?.collapsible === false) return; + this.sidebar.classList.toggle('collapsed'); + requestAnimationFrame(() => { + this.canvasRenderer.resize(); + this.renderAll(); + }); + }; + fxOn(this._teardownFns, this.resizer, 'dblclick', onResizerDblClick); + } + + applyLayout(layoutPatch) { + if (layoutPatch) { + this.config.layout = this._deepMerge(this.config.layout, layoutPatch); + } + const panels = this.config.layout?.panels || {}; + + const sidebarVisible = this._isSidebarVisible(); + this.sidebar.style.display = sidebarVisible ? '' : 'none'; + this.resizer.style.display = sidebarVisible && panels.sidebar?.resizable !== false ? '' : 'none'; + + if (panels.sidebar && typeof panels.sidebar.width === 'number') { + this.sidebar.style.width = `${panels.sidebar.width}px`; + this._layoutState.sidebarWidth = panels.sidebar.width; + } + + if (this.ui && this.ui.infoPanel) { + this.ui.infoPanel.style.display = panels.info?.visible === false ? 'none' : ''; + } + + if (this.minimapRenderer) { + this.minimapRenderer.container.style.display = panels.minimap?.visible === false ? 'none' : ''; + if (typeof panels.minimap?.height === 'number') { + this.minimapRenderer.container.style.height = `${panels.minimap.height}px`; + this._layoutState.minimapHeight = panels.minimap.height; + } + } + + if (this.resizerH) { + const showResizerH = panels.minimap?.visible !== false && panels.minimap?.resizable !== false; + this.resizerH.style.display = showResizerH ? '' : 'none'; + } + + if (this.ui) { + this.ui.setControlVisibility({ + toolbar: this.config.ui.controls.toolbar, + search: this.config.ui.controls.search, + layers: this.config.ui.controls.layers || this.config.ui.controls.colorBy, + theme: this.config.ui.controls.theme, + legend: this.config.ui.controls.legend && panels.legend?.visible !== false, + fullscreenButton: !!this.config.ui.controls.fullscreenButton, + }); + } + + this.canvasRenderer.resize(); + if (this.minimapRenderer) { + this.minimapRenderer.resize(); + this.minimapRenderer.generateThumbnail(); + } + this.renderAll(); + } + + init() { + if (this.minimapRenderer) { + this.minimapRenderer.generateThumbnail(); + } + + if (this.store.baseData.nodes.length > 10) { + const firstNode = this.store.baseData.nodes[0]; + const k = 0.5; + const rect = this.canvasContainer.getBoundingClientRect(); + this.controller.transform.k = k; + this.controller.transform.x = rect.width / 2 - firstNode.x * k; + this.controller.transform.y = rect.height / 2 - firstNode.y * k; + this.renderAll(); + } else { + this.controller.zoomToFit(); + } + } + + renderAll() { + if (this.canvasRenderer) this.canvasRenderer.render(); + if (this.minimapRenderer) this.minimapRenderer.render(); + } + + on(eventName, listener) { + if (!this._listeners.has(eventName)) { + this._listeners.set(eventName, new Set()); + } + this._listeners.get(eventName).add(listener); + return () => this.off(eventName, listener); + } + + off(eventName, listener) { + if (!this._listeners.has(eventName)) return; + this._listeners.get(eventName).delete(listener); + } + + _emit(eventName, payload) { + const listeners = this._listeners.get(eventName); + if (!listeners) return; + const event = { + type: eventName, + timestamp: Date.now(), + ...payload, + }; + listeners.forEach((cb) => { + try { + cb(event); + } catch (err) { + console.error(`FXGraphViewer listener error on '${eventName}':`, err); + } + }); + } + + getState() { + const state = this.controller.snapshotState(); + return { + ...state, + layoutState: { + ...this._layoutState, + sidebarCollapsed: this.sidebar && this.sidebar.classList.contains('collapsed'), + }, + uiVisibility: { + ...(state.uiVisibility || {}), + toolbar: this.ui && this.ui.taskbar ? this.ui.taskbar.style.display !== 'none' : false, + search: this.ui && this.ui.searchContainer ? this.ui.searchContainer.style.display !== 'none' : false, + layers: this.ui && this.ui.layersContainer ? this.ui.layersContainer.style.display !== 'none' : false, + theme: this.ui && this.ui.themeSelect ? this.ui.themeSelect.style.display !== 'none' : false, + legend: this.ui && this.ui.legendOverlay ? this.ui.legendOverlay.style.display !== 'none' : false, + fullscreenButton: this.ui && this.ui.btnFullscreen ? this.ui.btnFullscreen.style.display !== 'none' : false, + }, + }; + } + + setState(patch, opts = {}) { + const source = opts.source || 'api'; + const nextPatch = { ...(patch || {}) }; + if (nextPatch.camera && typeof nextPatch.camera === 'object') { + if (Number.isFinite(nextPatch.camera.x)) this.controller.transform.x = nextPatch.camera.x; + if (Number.isFinite(nextPatch.camera.y)) this.controller.transform.y = nextPatch.camera.y; + if (Number.isFinite(nextPatch.camera.k)) this.controller.transform.k = nextPatch.camera.k; + delete nextPatch.camera; + } + const hasSearchQuery = typeof nextPatch.searchQuery === 'string'; + const searchQuery = hasSearchQuery ? nextPatch.searchQuery : null; + if (hasSearchQuery) delete nextPatch.searchQuery; + + this.controller.setState(nextPatch, { source }); + if (hasSearchQuery) { + if (this.ui && this.ui.searchInput) { + this.ui.searchInput.value = searchQuery; + } + this.controller.handleSearch(searchQuery); + } + this.renderAll(); + } + + replaceState(nextState, opts = {}) { + const source = opts.source || 'api'; + const ns = nextState || {}; + const replacePatch = { + hoveredNodeId: null, + hoveredEdge: null, + selectedNodeId: ns.selectedNodeId || null, + selectedEdge: ns.selectedEdge || null, + previewNodeId: null, + ancestors: new Set(), + descendants: new Set(), + searchCandidates: [], + searchSelectedIndex: -1, + highlightAncestors: ns.highlightAncestors !== false, + themeName: ns.themeName || ns.theme || 'light', + activeExtensions: new Set(ns.activeExtensions || []), + colorBy: ns.colorBy || 'base', + uiVisibility: { ...(ns.uiVisibility || {}) }, + }; + this.controller.setState(replacePatch, { source }); + + if (ns.camera && typeof ns.camera === 'object') { + if (Number.isFinite(ns.camera.x)) this.controller.transform.x = ns.camera.x; + if (Number.isFinite(ns.camera.y)) this.controller.transform.y = ns.camera.y; + if (Number.isFinite(ns.camera.k)) this.controller.transform.k = ns.camera.k; + this.renderAll(); + } + if (typeof ns.searchQuery === 'string') { + if (this.ui && this.ui.searchInput) { + this.ui.searchInput.value = ns.searchQuery; + } + this.controller.handleSearch(ns.searchQuery); + } + } + + batch(fn) { + if (typeof fn === 'function') { + fn(); + } + } + + setTheme(themeName) { + if (!(themeName in THEMES)) { + const err = new Error(`Unknown theme '${themeName}'`); + this._emit('error', { error: err, source: 'api' }); + throw err; + } + this.setState({ themeName }, { source: 'api' }); + } + + setLayers(layerIds) { + this.setState({ activeExtensions: new Set(layerIds || []) }, { source: 'api' }); + } + + setColorBy(layerId) { + this.setState({ colorBy: layerId || 'base' }, { source: 'api' }); + } + + selectNode(nodeId, opts = {}) { + this.controller.selectNode(nodeId); + if (opts.animate) { + this.controller.animateToNode(nodeId, opts.k || null); + } else if (opts.center !== false) { + this.controller.panToNode(nodeId); + } + } + + clearSelection() { + this.controller.clearSelection(); + } + + search(query) { + if (this.ui && this.ui.searchInput) { + this.ui.searchInput.value = query; + } + this.controller.handleSearch(query); + } + + zoomToFit() { + this.controller.zoomToFit(); + } + + panToNode(nodeId) { + this.controller.panToNode(nodeId); + } + + animateToNode(nodeId, options = {}) { + const targetK = Object.prototype.hasOwnProperty.call(options, 'k') ? options.k : null; + this.controller.animateToNode(nodeId, targetK); + } + + setUIVisibility(flags = {}) { + if (!this.ui) return; + const prev = this.getState(); + this.ui.setControlVisibility(flags); + this.controller.state.uiVisibility = { + ...(this.controller.state.uiVisibility || {}), + ...flags, + }; + this._emit('statechange', { + prevState: prev, + nextState: this.getState(), + source: 'api', + }); + } + + setLayout(layoutPatch = {}) { + const prev = this.getState(); + this.applyLayout(layoutPatch); + this._emit('layoutchange', { + prevState: prev, + nextState: this.getState(), + source: 'api', + }); + } + + _refreshAfterLayerMutation({ rebuildMenu = false } = {}) { + this._refreshLayerControls({ rebuildMenu }); + this.controller.setState({}, { source: 'api' }); + } + + _refreshLayerControls({ rebuildMenu = false } = {}) { + if (!this.ui) return; + if (rebuildMenu) this.ui.rebuildLayersMenu(); + this.ui.syncControlsFromState(); + this.ui.renderLegend(); + } + + upsertLayer(layerId, layerPayload) { + this.store.upsertExtension(layerId, layerPayload); + this._refreshAfterLayerMutation({ rebuildMenu: true }); + } + + removeLayer(layerId) { + this.store.removeExtension(layerId); + const active = new Set(this.controller.state.activeExtensions); + active.delete(layerId); + const nextColorBy = this.controller.state.colorBy === layerId ? 'base' : this.controller.state.colorBy; + this.controller.setState({ activeExtensions: active, colorBy: nextColorBy }, { source: 'api' }); + this._refreshLayerControls({ rebuildMenu: true }); + } + + patchLayerNodes(layerId, patchByNodeId) { + this.store.patchExtensionNodes(layerId, patchByNodeId); + this._refreshAfterLayerMutation(); + } + + setLayerLabel(layerId, label) { + this.store.setExtensionLabel(layerId, label); + this._refreshAfterLayerMutation({ rebuildMenu: true }); + } + + setColorRule(layerId, colorRule) { + const ext = this.store.extensions[layerId]; + if (!ext || !ext.nodes) return; + + if (typeof colorRule === 'function') { + Object.entries(ext.nodes).forEach(([nodeId, nodeData]) => { + const nextColor = colorRule(nodeData, nodeId); + if (typeof nextColor === 'string') { + nodeData.fill_color = nextColor; + } + }); + this._refreshAfterLayerMutation(); + return; + } + + if (colorRule && colorRule.type === 'threshold' && colorRule.field && Array.isArray(colorRule.thresholds)) { + const thresholds = colorRule.thresholds + .filter((x) => x && typeof x.value === 'number' && typeof x.color === 'string') + .sort((a, b) => a.value - b.value); + Object.values(ext.nodes).forEach((nodeData) => { + const v = Number(nodeData.info && nodeData.info[colorRule.field]); + if (!Number.isFinite(v)) return; + let chosen = thresholds.length > 0 ? thresholds[0].color : null; + thresholds.forEach((t) => { + if (v >= t.value) chosen = t.color; + }); + if (chosen) nodeData.fill_color = chosen; + }); + this._refreshAfterLayerMutation(); + } + } + + enterFullscreen() { + let target = this.rootContainer; + if (!target || !target.isConnected) { + const fallback = this.wrapper && this.wrapper.parentElement ? this.wrapper.parentElement : null; + if (fallback && fallback.isConnected) { + target = fallback; + this.rootContainer = fallback; + if (this.config && this.config.mount) this.config.mount.root = fallback; + if (this.config && this.config._resolved) this.config._resolved.root = fallback; + } + } + if (!target || !target.isConnected) { + return Promise.resolve(); + } + if (target.requestFullscreen) { + return target.requestFullscreen(); + } + return Promise.resolve(); + } + + exitFullscreen() { + if (document.fullscreenElement && document.exitFullscreen) { + return document.exitFullscreen(); + } + return Promise.resolve(); + } + + addHighlightGroup(groupId, nodeIds, color) { + const groups = new Map(this.controller.state.highlightGroups); + groups.set(groupId, { nodeIds: new Set(nodeIds), color: color || '#ff6600' }); + this.controller.setState({ highlightGroups: groups }); + } + + removeHighlightGroup(groupId) { + const groups = new Map(this.controller.state.highlightGroups); + groups.delete(groupId); + this.controller.setState({ highlightGroups: groups }); + } + + clearAllHighlightGroups() { + this.controller.setState({ highlightGroups: new Map() }); + } + + getHighlightGroups() { + return new Map(this.controller.state.highlightGroups); + } + + destroy() { + this._listeners.clear(); + if (this.canvasRenderer && this.canvasRenderer.destroy) { + this.canvasRenderer.destroy(); + } + if (this.minimapRenderer && this.minimapRenderer.destroy) { + this.minimapRenderer.destroy(); + } + if (this.ui && this.ui.destroy) { + this.ui.destroy(); + } + fxOffAll(this._teardownFns); + if (this.wrapper && this.wrapper.parentNode) { + this.wrapper.parentNode.removeChild(this.wrapper); + } + } +} + +if (typeof globalThis !== 'undefined') { + globalThis.FXGraphViewer = FXGraphViewer; +} diff --git a/devtools/fx_viewer/templates/graph_data_store.js b/devtools/fx_viewer/templates/graph_data_store.js new file mode 100644 index 00000000000..aa6a93e21a5 --- /dev/null +++ b/devtools/fx_viewer/templates/graph_data_store.js @@ -0,0 +1,248 @@ +// Manages base graph + extension data and composes virtual nodes. +class GraphDataStore { + constructor(payload) { + this.baseData = payload.base; + this.extensions = payload.extensions || {}; + + // The pre-computed array/map of active Virtual Nodes + this.activeNodes = []; + this.activeNodeMap = new Map(); + + // Structural Topology (never changes when toggling extensions) + this.adjList = new Map(); + this.revAdjList = new Map(); + this.graphBounds = { minX: Infinity, maxX: -Infinity, minY: Infinity, maxY: -Infinity, width: 0, height: 0 }; + + this._initTopology(); + } + + _initTopology() { + let minX = Infinity, maxX = -Infinity, minY = Infinity, maxY = -Infinity; + + this.baseData.nodes.forEach(node => { + if (!node.width) node.width = 100; // Fallback + if (!node.height) node.height = 40; + + minX = Math.min(minX, node.x - node.width/2); + maxX = Math.max(maxX, node.x + node.width/2); + minY = Math.min(minY, node.y - node.height/2); + maxY = Math.max(maxY, node.y + node.height/2); + + this.adjList.set(node.id, []); + this.revAdjList.set(node.id, []); + }); + + const offsetX = 50 - minX; + const offsetY = 50 - minY; + + this.baseData.nodes.forEach(node => { + node.x += offsetX; + node.y += offsetY; + }); + + this.baseData.edges.forEach(edge => { + if (edge.points) { + edge.points.forEach(p => { p.x += offsetX; p.y += offsetY; }); + } + }); + + minX += offsetX; maxX += offsetX; minY += offsetY; maxY += offsetY; + + this.graphBounds = { + minX, maxX, minY, maxY, + width: maxX - minX + 100, + height: maxY - minY + 100 + }; + + this.baseData.edges.forEach(edge => { + let eMinX = Infinity, eMaxX = -Infinity, eMinY = Infinity, eMaxY = -Infinity; + const v = this.baseData.nodes.find(n => n.id === edge.v); + const w = this.baseData.nodes.find(n => n.id === edge.w); + + if (edge.points && edge.points.length > 0) { + edge.points.forEach(p => { + eMinX = Math.min(eMinX, p.x); eMaxX = Math.max(eMaxX, p.x); + eMinY = Math.min(eMinY, p.y); eMaxY = Math.max(eMaxY, p.y); + }); + } else if (v && w) { + eMinX = Math.min(v.x, w.x); eMaxX = Math.max(v.x, w.x); + eMinY = Math.min(v.y, w.y); eMaxY = Math.max(v.y, w.y); + } + edge.bounds = { minX: eMinX, maxX: eMaxX, minY: eMinY, maxY: eMaxY }; + + if (this.adjList.has(edge.v)) this.adjList.get(edge.v).push(edge); + if (this.revAdjList.has(edge.w)) this.revAdjList.get(edge.w).push(edge); + }); + } + + /** + * Called whenever the user toggles Extension checkboxes or Color Radio buttons. + * Rebuilds `activeNodes` by flattening the enabled JSON hierarchies into single Virtual Nodes. + */ + computeActiveGraph(activeExtensionIds, colorById) { + this.activeNodes = []; + this.activeNodeMap.clear(); + + this.baseData.nodes.forEach(baseNode => { + // 1. Initialize with Base Info + let flatInfo = { ...baseNode.info }; + let label_append = []; + let tooltip = [...(baseNode.tooltip || [])]; + let fill_color = baseNode.fill_color; + + // 2. Iterate through visible extensions + activeExtensionIds.forEach(extId => { + const ext = this.extensions[extId]; + if (!ext) return; + const extNode = ext.nodes[baseNode.id]; + if (!extNode) return; + + // Merge Info with Prefixes (e.g. "Profiler.latency" = 15) + if (extNode.info) { + for (const [k, v] of Object.entries(extNode.info)) { + flatInfo[`${ext.name}.${k}`] = v; + } + } + + if (extNode.label_append) { + label_append.push(...extNode.label_append); + } + + if (extNode.tooltip) { + tooltip.push(`[${ext.name}]`); + tooltip.push(...extNode.tooltip); + } + }); + + // 3. Resolve Node Fill Color + if (colorById !== 'base' && this.extensions[colorById]) { + const colorNode = this.extensions[colorById].nodes[baseNode.id]; + if (colorNode && colorNode.fill_color) { + fill_color = colorNode.fill_color; + } + } + + // 4. Cache Virtual Node + const virtualNode = { + ...baseNode, + info: flatInfo, + label_append: label_append, + tooltip: tooltip, + fill_color: fill_color + }; + + this.activeNodes.push(virtualNode); + this.activeNodeMap.set(virtualNode.id, virtualNode); + }); + } + + upsertExtension(extensionId, extensionPayload) { + if (!extensionId) { + throw new Error("upsertExtension requires a non-empty extensionId"); + } + if (!extensionPayload || typeof extensionPayload !== 'object') { + throw new Error(`upsertExtension('${extensionId}') requires an object payload`); + } + + const previous = this.extensions[extensionId] || {}; + this.extensions[extensionId] = { + name: extensionPayload.name || previous.name || extensionId, + legend: Array.isArray(extensionPayload.legend) ? extensionPayload.legend : (previous.legend || []), + nodes: extensionPayload.nodes && typeof extensionPayload.nodes === 'object' + ? extensionPayload.nodes + : (previous.nodes || {}), + }; + } + + removeExtension(extensionId) { + if (!extensionId) return; + delete this.extensions[extensionId]; + } + + setExtensionLabel(extensionId, label) { + const ext = this.extensions[extensionId]; + if (!ext) return; + ext.name = label || ext.name; + } + + patchExtensionNodes(extensionId, patchByNodeId) { + if (!extensionId || !patchByNodeId || typeof patchByNodeId !== 'object') return; + const ext = this.extensions[extensionId]; + if (!ext) { + this.upsertExtension(extensionId, { name: extensionId, nodes: {} }); + } + + const target = this.extensions[extensionId]; + if (!target.nodes) target.nodes = {}; + + Object.entries(patchByNodeId).forEach(([nodeId, patch]) => { + const prev = target.nodes[nodeId] || {}; + const next = { ...prev, ...patch }; + if (patch && patch.info && typeof patch.info === 'object') { + next.info = { ...(prev.info || {}), ...patch.info }; + } + if (patch && patch.tooltip && Array.isArray(patch.tooltip)) { + next.tooltip = patch.tooltip.slice(); + } + if (patch && patch.label_append && Array.isArray(patch.label_append)) { + next.label_append = patch.label_append.slice(); + } + target.nodes[nodeId] = next; + }); + } + + setExtensionLegend(extensionId, legend) { + const ext = this.extensions[extensionId]; + if (!ext) return; + ext.legend = Array.isArray(legend) ? legend : []; + } + + getAncestors(nodeId) { + const visited = new Set(); + const queue = [nodeId]; + visited.add(nodeId); + while (queue.length > 0) { + const curr = queue.shift(); + const inEdges = this.revAdjList.get(curr) || []; + inEdges.forEach(e => { + if (!visited.has(e.v)) { + visited.add(e.v); + queue.push(e.v); + } + }); + } + return visited; + } + + computeBoundsForNodes(nodeIds) { + let minX = Infinity, maxX = -Infinity, minY = Infinity, maxY = -Infinity; + nodeIds.forEach(nid => { + const node = this.activeNodeMap.get(nid); + if (node) { + minX = Math.min(minX, node.x - node.width / 2); + maxX = Math.max(maxX, node.x + node.width / 2); + minY = Math.min(minY, node.y - node.height / 2); + maxY = Math.max(maxY, node.y + node.height / 2); + } + }); + if (minX === Infinity) return null; + return { minX, maxX, minY, maxY, width: maxX - minX, height: maxY - minY }; + } + + getDescendants(nodeId) { + const visited = new Set(); + const queue = [nodeId]; + visited.add(nodeId); + while (queue.length > 0) { + const curr = queue.shift(); + const outEdges = this.adjList.get(curr) || []; + outEdges.forEach(e => { + if (!visited.has(e.w)) { + visited.add(e.w); + queue.push(e.w); + } + }); + } + return visited; + } +} diff --git a/devtools/fx_viewer/templates/minimap_renderer.js b/devtools/fx_viewer/templates/minimap_renderer.js new file mode 100644 index 00000000000..1e5961c712b --- /dev/null +++ b/devtools/fx_viewer/templates/minimap_renderer.js @@ -0,0 +1,304 @@ +// Minimap overview rendering with viewport tracking and click/drag navigation. +class MinimapRenderer { + constructor(container, viewer) { + this.viewer = viewer; + this._teardownFns = []; + const mountPoint = container || this.viewer.sidebar || this.viewer.mainArea; + this.container = document.createElement('div'); + this.container.className = 'fx-minimap-container'; + mountPoint.appendChild(this.container); + + this.canvas = document.createElement('canvas'); + this.canvas.className = 'fx-minimap'; + this.container.appendChild(this.canvas); + this.ctx = this.canvas.getContext('2d'); + + this.thumbnailCanvas = document.createElement('canvas'); + this.thumbnailCtx = this.thumbnailCanvas.getContext('2d'); + + this.minimapScale = 1; + this.thumbnailOffset = { x: 0, y: 0 }; + this.isDragging = false; + + this.resize(); + this._onWindowResize = () => { + this.resize(); + this.generateThumbnail(); + this.render(); + }; + fxOn(this._teardownFns, window, 'resize', this._onWindowResize); + + this.setupEvents(); + + if (typeof ResizeObserver !== 'undefined') { + this._resizeObserver = new ResizeObserver(() => { + const rect = this.container.getBoundingClientRect(); + if (rect.width > 0 && rect.height > 0) { + this.resize(); + this.generateThumbnail(); + this.render(); + } + }); + this._resizeObserver.observe(this.container); + this._teardownFns.push(() => this._resizeObserver.disconnect()); + } + } + + resize() { + const dpr = window.devicePixelRatio || 1; + const rect = this.container.getBoundingClientRect(); + this.canvas.width = rect.width * dpr; + this.canvas.height = rect.height * dpr; + } + + generateThumbnail() { + const mw = this.canvas.width; + const mh = this.canvas.height; + if (mw === 0 || mh === 0) return; + + this.thumbnailCanvas.width = mw; + this.thumbnailCanvas.height = mh; + + const bounds = this.viewer.store.graphBounds; + if (bounds.width === 0) return; + + const scaleX = mw / bounds.width; + const scaleY = mh / bounds.height; + this.minimapScale = Math.min(scaleX, scaleY) * 0.9; + + this.thumbnailOffset = { + x: (mw - bounds.width * this.minimapScale) / 2, + y: (mh - bounds.height * this.minimapScale) / 2 + }; + + const ctx = this.thumbnailCtx; + const theme = THEMES[this.viewer.controller.state.themeName]; + ctx.fillStyle = theme.bg; + ctx.fillRect(0, 0, mw, mh); + ctx.save(); + ctx.translate(this.thumbnailOffset.x, this.thumbnailOffset.y); + ctx.scale(this.minimapScale, this.minimapScale); + + ctx.strokeStyle = theme.edgeNormal; + const dpr = window.devicePixelRatio || 1; + ctx.lineWidth = Math.max(1 / this.minimapScale, dpr / this.minimapScale); + ctx.beginPath(); + this.viewer.store.baseData.edges.forEach(edge => { + const v = this.viewer.store.activeNodeMap.get(edge.v); + const w = this.viewer.store.activeNodeMap.get(edge.w); + if (v && w) { + if (edge.points && edge.points.length > 0) { + ctx.moveTo(edge.points[0].x, edge.points[0].y); + for (let i = 1; i < edge.points.length; i++) { + ctx.lineTo(edge.points[i].x, edge.points[i].y); + } + } else { + ctx.moveTo(v.x, v.y); + ctx.lineTo(w.x, w.y); + } + } + }); + ctx.stroke(); + + this.viewer.store.activeNodes.forEach(node => { + ctx.fillStyle = node.fill_color ? node.fill_color : theme.nodeFill; + const minSize = 2 / this.minimapScale; + const w = Math.max(node.width, minSize); + const h = Math.max(node.height, minSize); + ctx.fillRect(node.x - w/2, node.y - h/2, w, h); + }); + ctx.restore(); + } + + setupEvents() { + const onMouseDown = (e) => { + this.isDragging = true; + this.handleDrag(e); + }; + fxOn(this._teardownFns, this.canvas, 'mousedown', onMouseDown); + + const onMouseMove = (e) => { + if (this.isDragging) this.handleDrag(e); + }; + fxOn(this._teardownFns, window, 'mousemove', onMouseMove); + + const onMouseUp = () => { + this.isDragging = false; + }; + fxOn(this._teardownFns, window, 'mouseup', onMouseUp); + + const onWheel = (e) => { + e.preventDefault(); + const zoomIntensity = 0.1; + const wheel = e.deltaY < 0 ? 1 : -1; + const zoomFactor = Math.exp(wheel * zoomIntensity); + + const mainCanvasRect = this.viewer.canvasRenderer.canvas.getBoundingClientRect(); + const mouseX = mainCanvasRect.width / 2; + const mouseY = mainCanvasRect.height / 2; + + const transform = this.viewer.controller.transform; + const graphX = (mouseX - transform.x) / transform.k; + const graphY = (mouseY - transform.y) / transform.k; + + transform.k *= zoomFactor; + transform.x = mouseX - graphX * transform.k; + transform.y = mouseY - graphY * transform.k; + + this.viewer.renderAll(); + }; + fxOn(this._teardownFns, this.canvas, 'wheel', onWheel, { passive: false }); + } + + handleDrag(e) { + const dpr = window.devicePixelRatio || 1; + const rect = this.canvas.getBoundingClientRect(); + const mx = (e.clientX - rect.left) * dpr; + const my = (e.clientY - rect.top) * dpr; + + const graphX = (mx - this.thumbnailOffset.x) / this.minimapScale; + const graphY = (my - this.thumbnailOffset.y) / this.minimapScale; + + const canvasRect = this.viewer.canvasContainer.getBoundingClientRect(); + const transform = this.viewer.controller.transform; + + transform.x = (canvasRect.width / 2) - (graphX * transform.k); + transform.y = (canvasRect.height / 2) - (graphY * transform.k); + + this.viewer.renderAll(); + } + + resetInteractionState() { + this.isDragging = false; + } + + render() { + const dpr = window.devicePixelRatio || 1; + if (this.canvas.width === 0 || this.canvas.height === 0) return; + + const state = this.viewer.controller.state; + const theme = THEMES[state.themeName]; + const minEdgeWidth = dpr / Math.max(this.minimapScale, 1e-6); + + this.ctx.setTransform(1, 0, 0, 1, 0, 0); + this.ctx.fillStyle = theme.bg; + this.ctx.fillRect(0, 0, this.canvas.width, this.canvas.height); + + const isSelectionMode = !!state.selectedNodeId || !!state.previewNodeId || !!state.selectedEdge; + + if (isSelectionMode && state.highlightAncestors) { + this.ctx.globalAlpha = 0.2; + } + this.ctx.drawImage(this.thumbnailCanvas, 0, 0); + this.ctx.globalAlpha = 1.0; + + this.ctx.save(); + this.ctx.translate(this.thumbnailOffset.x, this.thumbnailOffset.y); + this.ctx.scale(this.minimapScale, this.minimapScale); + + const drawNodes = (nodes, padding = 0, fillColor = null) => { + nodes.forEach(nid => { + const node = this.viewer.store.activeNodeMap.get(nid); + if (node) { + this.ctx.fillStyle = fillColor || (node.fill_color ? node.fill_color : theme.nodeFill); + const minSize = 3 / this.minimapScale; + const w = Math.max(node.width, minSize) + padding; + const h = Math.max(node.height, minSize) + padding; + this.ctx.fillRect(node.x - w/2, node.y - h/2, w, h); + } + }); + }; + + const drawEdgePath = (edge, v, w) => { + if (edge.points && edge.points.length > 0) { + this.ctx.moveTo(edge.points[0].x, edge.points[0].y); + for (let i = 1; i < edge.points.length; i++) { + this.ctx.lineTo(edge.points[i].x, edge.points[i].y); + } + } else { + this.ctx.moveTo(v.x, v.y); + this.ctx.lineTo(w.x, w.y); + } + }; + + const drawEdges = (edges, color, width) => { + const edgeList = Array.from(edges || []); + if (edgeList.length === 0) return; + this.ctx.strokeStyle = color; + this.ctx.lineWidth = Math.max(width, minEdgeWidth); + this.ctx.beginPath(); + edgeList.forEach((edge) => { + const v = this.viewer.store.activeNodeMap.get(edge.v); + const w = this.viewer.store.activeNodeMap.get(edge.w); + if (!v || !w) return; + drawEdgePath(edge, v, w); + }); + this.ctx.stroke(); + }; + + const selectionEdges = new Set(); + const target = state.previewNodeId || state.selectedNodeId; + + if (target) { + (this.viewer.store.revAdjList.get(target) || []).forEach((edge) => selectionEdges.add(edge)); + (this.viewer.store.adjList.get(target) || []).forEach((edge) => selectionEdges.add(edge)); + } + if (state.selectedEdge) selectionEdges.add(state.selectedEdge); + drawEdges(selectionEdges, theme.edgeInput, 2 / this.minimapScale); + + if (target) { + if (state.highlightAncestors) { + drawNodes(Array.from(state.ancestors)); + drawNodes(Array.from(state.descendants)); + } + } + + if (state.highlightGroups && state.highlightGroups.size > 0) { + state.highlightGroups.forEach(({ nodeIds, color }) => { + const nodeSet = new Set(nodeIds || []); + const edges = this.viewer.store.baseData.edges.filter( + (edge) => nodeSet.has(edge.v) && nodeSet.has(edge.w) + ); + drawEdges(edges, color, 2 / this.minimapScale); + drawNodes(Array.from(nodeSet), 2/this.minimapScale, color); + }); + } + + if (state.selectedEdge) { + drawNodes([state.selectedEdge.v, state.selectedEdge.w], 2/this.minimapScale, theme.nodeSelected); + } + + if (state.searchCandidates.length > 0) { + drawNodes(state.searchCandidates.map(c => c.node.id), 2/this.minimapScale, theme.nodeSelected); + } + + if (target) { + drawNodes([target], 3/ this.minimapScale, theme.nodeSelected); + } + + this.ctx.restore(); + + const transform = this.viewer.controller.transform; + const canvasRect = this.viewer.canvasContainer.getBoundingClientRect(); + + const vx = -transform.x / transform.k; + const vy = -transform.y / transform.k; + const vw = canvasRect.width / transform.k; + const vh = canvasRect.height / transform.k; + + const mx = vx * this.minimapScale + this.thumbnailOffset.x; + const my = vy * this.minimapScale + this.thumbnailOffset.y; + const mw = vw * this.minimapScale; + const mh = vh * this.minimapScale; + + this.ctx.strokeStyle = theme.minimapBorder; + this.ctx.lineWidth = 2 * dpr; + this.ctx.strokeRect(mx, my, mw, mh); + this.ctx.fillStyle = theme.minimapBox; + this.ctx.fillRect(mx, my, mw, mh); + } + + destroy() { + fxOffAll(this._teardownFns); + } +} diff --git a/devtools/fx_viewer/templates/runtime.js b/devtools/fx_viewer/templates/runtime.js new file mode 100644 index 00000000000..c92e9724340 --- /dev/null +++ b/devtools/fx_viewer/templates/runtime.js @@ -0,0 +1,75 @@ +function fxOn(teardownFns, target, eventName, handler, options) { + if (!target || !target.addEventListener || !target.removeEventListener) return; + target.addEventListener(eventName, handler, options); + teardownFns.push(() => target.removeEventListener(eventName, handler, options)); +} + +function fxOffAll(teardownFns) { + while (teardownFns.length > 0) { + const off = teardownFns.pop(); + try { + off(); + } catch (_) {} + } +} + +function fxEsc(s) { + if (typeof s !== 'string') s = String(s); + return s.replace(/&/g, '&').replace(//g, '>').replace(/"/g, '"').replace(/'/g, '''); +} + +// Pick a readable ink for a given background hex. Uses WCAG 2.x relative +// luminance and returns one of two standard inks — dark #111111 or light +// #f8f8f8 — whichever has higher contrast against the background. Returns +// null for malformed input so callers can fall back to theme defaults. +function fxReadableTextColor(hex) { + if (typeof hex !== 'string' || hex.charAt(0) !== '#' || hex.length !== 7) return null; + const r = parseInt(hex.substring(1, 3), 16) / 255; + const g = parseInt(hex.substring(3, 5), 16) / 255; + const b = parseInt(hex.substring(5, 7), 16) / 255; + if (!isFinite(r) || !isFinite(g) || !isFinite(b)) return null; + const lin = (c) => (c <= 0.03928 ? c / 12.92 : Math.pow((c + 0.055) / 1.055, 2.4)); + const L = 0.2126 * lin(r) + 0.7152 * lin(g) + 0.0722 * lin(b); + return L > 0.179 ? '#111111' : '#f8f8f8'; +} + +const THEMES = { + 'light': { + bg: '#ffffff', + text: '#000000', + textMuted: '#666666', + nodeFill: '#66ccee', + nodeInput: '#75dcfe', + nodeOutput: '#75dcfe', + nodeSelected: '#fbc02d ', + edgeNormal: '#333333', + edgeInput: '#ff9800', + edgeOutput: '#ff9800', + edgeHover: '#e91e63', + minimapBox: 'rgba(255, 0, 0, 0.1)', + minimapBorder: 'red', + uiBg: 'rgba(255, 255, 255, 0.95)', + uiBorder: '#cccccc', + uiHover: '#f0f8ff', + legendBg: 'rgba(255, 255, 255, 0.8)' + }, + 'dark': { + bg: '#1e1e1e', + text: '#ffffff', + textMuted: '#cccccc', + nodeFill: '#0277a1', + nodeInput: '#1287b1', + nodeOutput: '#1287b1', + nodeSelected: '#ffeb3b', + edgeNormal: '#cccccc', + edgeInput: '#ffb74d', + edgeOutput: '#ffb74d', + edgeHover: '#ff4081', + minimapBox: 'rgba(255, 100, 100, 0.3)', + minimapBorder: '#ff5555', + uiBg: 'rgba(30, 30, 30, 0.95)', + uiBorder: '#777777', + uiHover: '#333333', + legendBg: 'rgba(30, 30, 30, 0.8)' + } +}; diff --git a/devtools/fx_viewer/templates/search_engine.js b/devtools/fx_viewer/templates/search_engine.js new file mode 100644 index 00000000000..f4cd99d0939 --- /dev/null +++ b/devtools/fx_viewer/templates/search_engine.js @@ -0,0 +1,244 @@ +// Fuzzy search over active graph nodes with token scoring and context highlighting. +class SearchEngine { + constructor(dataStore) { + this.dataStore = dataStore; + this._kvScores = { + exactFieldExactValue: 40, + exactFieldFuzzyValue: 34, + fuzzyFieldFuzzyValue: 28, + }; + } + + search(query) { + if (!query || query.trim() === '') return []; + const queryTokens = this._tokenizeQuery(query); + if (queryTokens.length === 0) return []; + + const results = []; + const hOpen = ``; + const hClose = ``; + + let maxMatchedTokensCount = 0; + + this.dataStore.activeNodes.forEach(node => { + let totalScore = 0; + let matchedTokensCount = 0; + let bestMatchField = null; + let bestMatchString = null; + let highestTokenScore = 0; + + queryTokens.forEach(token => { + const tokenResult = this._scoreToken(node, token, hOpen, hClose); + const tokenScore = tokenResult.score; + const tokenMatchField = tokenResult.matchField; + const tokenMatchString = tokenResult.matchString; + + if (tokenScore > 0) { + totalScore += tokenScore; + matchedTokensCount += 1; + if (tokenScore > highestTokenScore) { + highestTokenScore = tokenScore; + bestMatchField = tokenMatchField; + bestMatchString = tokenMatchString; + } + } + }); + + if (totalScore > 0) { + if (matchedTokensCount > maxMatchedTokensCount) { + maxMatchedTokensCount = matchedTokensCount; + } + + let highlightedId = node.id; + let stop = false; + queryTokens.forEach((token) => { + const plain = this._stripQuotes(token).toLowerCase(); + if (!plain) return; + if (!stop) highlightedId = this._highlightText(highlightedId, plain, hOpen, hClose); + if (highlightedId.includes(hClose)) stop=true; + }); + + results.push({ + node, + score: totalScore, + matchedTokensCount: matchedTokensCount, + matchField: bestMatchField, + matchString: bestMatchString, + highlightedId: highlightedId + }); + } + }); + + const filteredResults = results.filter(r => r.matchedTokensCount === maxMatchedTokensCount); + filteredResults.sort((a, b) => b.score - a.score); + return filteredResults; + } + + _tokenizeQuery(query) { + const tokens = []; + let current = ''; + let currentQuote = null; + + for (let i = 0; i < query.length; i++) { + const ch = query[i]; + if (ch === '"' || ch === "'") { + if(!currentQuote){ + currentQuote = ch; + }else if(currentQuote === ch){ + currentQuote = null; + }else{ + current += ch; + } + continue; + } + if (!currentQuote && /\s/.test(ch)) { + if (current.trim().length > 0) tokens.push(current.trim()); + current = ''; + continue; + } + current += ch; + } + + if (current.trim().length > 0) tokens.push(current.trim()); + return tokens; + } + + _stripQuotes(s) { + return s.replace(/"|'/g, ""); + } + + _escapeRegExp(s) { + return String(s || '').replace(/[.*+?^${}()|[\]\\]/g, '\\$&'); + } + + _escapeHTML(str) { + return str.replace(/[&<>"']/g, function(m) { + return { + '&': '&', + '<': '<', + '>': '>', + '"': '"', + "'": ''' + }[m]; + }); + } + + _highlightText(text, pattern, hOpen, hClose) { + let safeText = this._escapeHTML(text); + if (!pattern) return safeText; + const regex = new RegExp(`(${this._escapeRegExp(pattern)})`, 'gi'); + return String(safeText).replace(regex, `${hOpen}$1${hClose}`); + } + + _valueToString(value) { + const valueStr = typeof value === 'object' ? JSON.stringify(value) : String(value); + return valueStr.split(/\r?\n/).join(""); + } + + _formatValueSnippet(actualValStr, pattern, hOpen, hClose) { + const valStrLower = actualValStr.toLowerCase(); + const idx = valStrLower.indexOf(pattern); + if (idx === -1) { + if (actualValStr.length > 30) return `${actualValStr.substring(0, 30)}...`; + return actualValStr; + } + const start = Math.max(0, idx - 15); + const end = Math.min(actualValStr.length, idx + pattern.length + 15); + let snippet = actualValStr.substring(start, end); + snippet = this._highlightText(snippet, pattern, hOpen, hClose); + return `...${snippet}...`; + } + + _scoreToken(node, rawToken, hOpen, hClose) { + const normalizedToken = this._stripQuotes(rawToken).toLowerCase(); + if (!normalizedToken) return { score: 0, matchField: null, matchString: null }; + + const firstEq = rawToken.indexOf('='); + if (firstEq !== -1) { + const fieldPattern = this._stripQuotes(rawToken.slice(0, firstEq)).toLowerCase(); + const valuePattern = this._stripQuotes(rawToken.slice(firstEq + 1)).toLowerCase(); + if (fieldPattern && valuePattern) { + const kvResult = this._scoreFieldValueToken(node, fieldPattern, valuePattern, hOpen, hClose); + if (kvResult.fieldMatched) return kvResult; + } + return this._scorePlainToken(node, normalizedToken, hOpen, hClose); + } + + return this._scorePlainToken(node, normalizedToken, hOpen, hClose); + } + + _unifyValStr(valStr){ + const valLower = valStr.toLowerCase() + return this._stripQuotes(valLower.replace(/\s/g, '')); + } + + _scoreFieldValueToken(node, fieldPattern, valuePattern, hOpen, hClose) { + const entries = [['id', node.id], ...Object.entries(node.info || {})]; + let fieldMatched = false; + let best = { score: 0, matchField: null, matchString: null, fieldMatched: true }; + + for (const [key, value] of entries) { + const keyStr = String(key); + const keyLower = keyStr.toLowerCase(); + if (!keyLower.includes(fieldPattern)) continue; + fieldMatched = true; + + const actualValStr = this._valueToString(value); + const unifiedValStr = this._unifyValStr(actualValStr); + const unifiedValPattern = this._unifyValStr(valuePattern); + if (!unifiedValStr.includes(unifiedValPattern)) continue; + + const keyExact = keyLower === fieldPattern; + const valueExact = unifiedValStr === unifiedValPattern; + let score = this._kvScores.fuzzyFieldFuzzyValue; + if (keyExact && valueExact) score = this._kvScores.exactFieldExactValue; + else if (keyExact) score = this._kvScores.exactFieldFuzzyValue; + + const matchField = this._highlightText(keyStr, fieldPattern, hOpen, hClose); + const matchString = this._formatValueSnippet(actualValStr, valuePattern, hOpen, hClose); + if (score > best.score) { + best = { score, matchField, matchString, fieldMatched: true }; + } + } + + if (!fieldMatched) return { score: 0, matchField: null, matchString: null, fieldMatched: false }; + return best; + } + + _scorePlainToken(node, token, hOpen, hClose) { + if (node.id.toLowerCase().includes(token)) { + return { score: 10, matchField: 'id', matchString: node.id }; + } + + for (const [key, value] of Object.entries(node.info || {})) { + const keyStr = String(key); + const keyLower = keyStr.toLowerCase(); + const actualValStr = this._valueToString(value); + const unifiedValStr = this._unifyValStr(actualValStr); + const unifiedToken = this._unifyValStr(token); + if (!keyLower.includes(unifiedToken) && !unifiedValStr.includes(unifiedToken)) continue; + + let score = 1; + if (key === 'op') score = 5; + else if (key === 'target') score = 3; + + if (unifiedValStr.includes(unifiedToken)) { + return { + score, + matchField: keyStr, + matchString: this._formatValueSnippet(actualValStr, token, hOpen, hClose), + }; + } + + const highlightedField = this._highlightText(keyStr, token, hOpen, hClose); + const shortValue = actualValStr.length > 30 ? `${actualValStr.substring(0, 30)}...` : actualValStr; + return { + score, + matchField: highlightedField, + matchString: shortValue, + }; + } + + return { score: 0, matchField: null, matchString: null }; + } +} diff --git a/devtools/fx_viewer/templates/ui_manager.js b/devtools/fx_viewer/templates/ui_manager.js new file mode 100644 index 00000000000..c50c6b22c54 --- /dev/null +++ b/devtools/fx_viewer/templates/ui_manager.js @@ -0,0 +1,635 @@ +// Manages all non-canvas DOM elements: taskbar, search, layers dropdown, legend, and info panel. +class UIManager { + constructor(container, viewer, options = {}) { + this.container = container; + this.viewer = viewer; + this.controller = viewer.controller; + this.options = options; + this._teardownFns = []; + this.controls = { + toolbar: true, + search: true, + layers: true, + colorBy: true, + theme: true, + legend: true, + zoomButtons: true, + fullscreenButton: false, + highlightButton: true, + ...(options.controls || {}), + }; + this.mounts = options.mounts || {}; + this.layerCheckboxes = new Map(); + this.colorByRadios = new Map(); + this.buildUI(); + } + + _createTaskbarButton({ html, title, onClick, className = 'fx-button' }) { + const btn = document.createElement('button'); + btn.className = className; + btn.innerHTML = html; + if (title) btn.title = title; + if (typeof onClick === 'function') btn.onclick = onClick; + return btn; + } + + buildUI() { + this.taskbar = document.createElement('div'); + this.taskbar.className = 'fx-taskbar'; + + this.searchContainer = null; + this.layersContainer = null; + this.layersMenu = null; + this.themeSelect = null; + this.btnHighlight = null; + this.btnZoomFit = null; + this.btnFullscreen = null; + this.btnClear = null; + + if (this.controls.search) { + this.searchContainer = document.createElement('div'); + this.searchContainer.className = 'fx-search-container'; + + this.searchInput = document.createElement('input'); + this.searchInput.className = 'fx-search-input'; + this.searchInput.placeholder = 'Search nodes (fuzzy)...'; + + this.searchMenu = document.createElement('div'); + this.searchMenu.className = 'fx-search-menu'; + + this.searchContainer.appendChild(this.searchInput); + this.searchContainer.appendChild(this.searchMenu); + this.taskbar.appendChild(this.searchContainer); + } else { + this.searchInput = null; + this.searchMenu = null; + } + + if (this.controls.layers || this.controls.colorBy) { + this.layersContainer = document.createElement('div'); + this.layersContainer.style.position = 'relative'; + this.layersContainer.style.marginLeft = '10px'; + + this.btnLayers = this._createTaskbarButton({ + html: '📚 Layers', + }); + + this.layersMenu = document.createElement('div'); + this.layersMenu.className = 'fx-layers-menu'; + this.rebuildLayersMenu(); + + this.btnLayers.onclick = () => { + this.layersMenu.style.display = this.layersMenu.style.display === 'block' ? 'none' : 'block'; + }; + + this.layersContainer.appendChild(this.btnLayers); + this.layersContainer.appendChild(this.layersMenu); + this.taskbar.appendChild(this.layersContainer); + } + + if (this.controls.highlightButton) { + this.btnHighlight = this._createTaskbarButton({ + html: '🔗', + title: 'Toggle Highlight Ancestors/Descendants', + className: 'fx-button active', + onClick: () => { + this.controller.state.highlightAncestors = !this.controller.state.highlightAncestors; + this.btnHighlight.classList.toggle('active', this.controller.state.highlightAncestors); + this.controller.setState({}); + }, + }); + this.taskbar.appendChild(this.btnHighlight); + } + + if (this.controls.zoomButtons) { + this.btnZoomFit = this._createTaskbarButton({ + html: '⤢', + title: 'Zoom to Fit', + onClick: () => this.controller.zoomToFit(), + }); + this.taskbar.appendChild(this.btnZoomFit); + } + + if (this.controls.fullscreenButton) { + this.btnFullscreen = this._createTaskbarButton({ + html: '⛶', + title: 'Enter Fullscreen', + onClick: async () => { + if (document.fullscreenElement) { + await this.viewer.exitFullscreen(); + } else { + await this.viewer.enterFullscreen(); + } + this.syncFullscreenButton(); + }, + }); + this.taskbar.appendChild(this.btnFullscreen); + this._onFullscreenChange = () => this.syncFullscreenButton(); + fxOn(this._teardownFns, document, 'fullscreenchange', this._onFullscreenChange); + } + + if (this.controls.theme) { + this.themeSelect = document.createElement('select'); + this.themeSelect.className = 'fx-select'; + this.themeSelect.innerHTML = ``; + this.themeSelect.onchange = (e) => { + this.controller.setState({ themeName: e.target.value }); + }; + this.taskbar.appendChild(this.themeSelect); + } + + this.infoPanel = document.createElement('div'); + this.infoPanel.className = 'fx-info-panel'; + this.infoPanel.innerHTML = '
No node selected

Hover or click a node
'; + + this.legendOverlay = document.createElement('div'); + this.legendOverlay.className = 'fx-legend-overlay'; + + const toolbarContainer = this.mounts.toolbarContainer || this.container; + const legendContainer = this.mounts.legendContainer || this.container; + const infoContainer = this.mounts.infoContainer || this.viewer.sidebar || this.container; + if (this.controls.toolbar) { + toolbarContainer.appendChild(this.taskbar); + } + if (this.controls.legend) { + legendContainer.appendChild(this.legendOverlay); + } else { + this.legendOverlay.style.display = 'none'; + } + infoContainer.appendChild(this.infoPanel); + + this.applyThemeToDOM(); + this.renderLegend(); + + if (this.searchInput) { + fxOn(this._teardownFns, this.searchInput, 'input', (e) => { + this.controller.handleSearch(e.target.value); + if (e.target.value) { + this._showSearchMenu(); + } else { + this.searchMenu.style.display = 'none'; + } + }); + + fxOn(this._teardownFns, this.searchInput, 'keydown', (e) => { + if (e.key === 'ArrowDown') { + e.preventDefault(); + this.controller.handleSearchNavigate(1); + } else if (e.key === 'ArrowUp') { + e.preventDefault(); + this.controller.handleSearchNavigate(-1); + } else if (e.key === 'Enter') { + e.preventDefault(); + this.controller.handleSearchSelect(); + } + }); + } + + this._onDocumentClick = (e) => { + if (this.searchContainer && !this.searchContainer.contains(e.target)) { + this.closeSearchMenu(); + if (this.controller.state.searchCandidates.length > 0) { + this.controller.setState({ searchCandidates: [], searchSelectedIndex: -1, previewNodeId: null }); + } + } + if (this.layersContainer && !this.layersContainer.contains(e.target)) { + this.layersMenu.style.display = 'none'; + } + }; + fxOn(this._teardownFns, document, 'click', this._onDocumentClick); + + this.syncControlsFromState(); + this.syncFullscreenButton(); + } + + rebuildLayersMenu() { + if (!this.layersMenu) return; + this.layerCheckboxes.clear(); + this.colorByRadios.clear(); + const radioName = `fx-color-by-${this.viewer.containerId || 'viewer'}`; + + let layersHtml = ''; + if (this.controls.layers) { + layersHtml += `
Extensions
`; + for (const [extId, extData] of Object.entries(this.viewer.store.extensions)) { + layersHtml += ` + + `; + } + } + + if (this.controls.colorBy) { + layersHtml += `
Color By
`; + layersHtml += ` + + `; + for (const [extId, extData] of Object.entries(this.viewer.store.extensions)) { + if (extData.legend && extData.legend.length > 0) { + layersHtml += ` + + `; + } + } + } + + this.layersMenu.innerHTML = layersHtml; + this.layersMenu.querySelectorAll('.fx-layer-checkbox').forEach(cb => { + this.layerCheckboxes.set(cb.value, cb); + cb.onchange = (e) => { + const active = new Set(this.controller.state.activeExtensions); + if (e.target.checked) active.add(e.target.value); + else active.delete(e.target.value); + this.controller.setState({ activeExtensions: active }); + }; + }); + this.layersMenu.querySelectorAll('input[type="radio"]').forEach(radio => { + this.colorByRadios.set(radio.value, radio); + radio.onchange = (e) => { + this.controller.setState({ colorBy: e.target.value }); + }; + }); + } + + syncControlsFromState() { + const state = this.controller.state; + if (this.themeSelect) { + this.themeSelect.value = state.themeName; + } + if (this.btnHighlight) { + this.btnHighlight.classList.toggle('active', !!state.highlightAncestors); + } + if (this.layerCheckboxes.size > 0) { + this.layerCheckboxes.forEach((checkbox, extId) => { + checkbox.checked = state.activeExtensions.has(extId); + }); + } + if (this.colorByRadios.size > 0) { + this.colorByRadios.forEach((radio, extId) => { + radio.checked = extId === state.colorBy; + }); + } + } + + setControlVisibility(flags = {}) { + if ('toolbar' in flags && this.taskbar) { + this.taskbar.style.display = flags.toolbar ? '' : 'none'; + } + if ('search' in flags && this.searchContainer) { + this.searchContainer.style.display = flags.search ? '' : 'none'; + } + if ('layers' in flags && this.layersContainer) { + this.layersContainer.style.display = flags.layers ? '' : 'none'; + } + if ('theme' in flags && this.themeSelect) { + this.themeSelect.style.display = flags.theme ? '' : 'none'; + } + if ('legend' in flags && this.legendOverlay) { + this.legendOverlay.style.display = flags.legend ? '' : 'none'; + } + if ('fullscreenButton' in flags && this.btnFullscreen) { + this.btnFullscreen.style.display = flags.fullscreenButton ? '' : 'none'; + } + if ('highlightButton' in flags && this.btnHighlight) { + this.btnHighlight.style.display = flags.highlightButton ? '' : 'none'; + } + } + + syncFullscreenButton() { + if (!this.btnFullscreen) return; + const active = !!document.fullscreenElement; + this.btnFullscreen.title = active ? 'Exit Fullscreen' : 'Enter Fullscreen'; + this.btnFullscreen.innerHTML = active ? '✕' : '⛶'; + } + + renderLegend() { + if (!this.controls.legend || !this.legendOverlay) return; + const colorBy = this.controller.state.colorBy; + let legendData = []; + let title = "Legend"; + const theme = THEMES[this.controller.state.themeName]; + + if (colorBy === 'base') { + legendData = this.viewer.store.baseData.legend || []; + title = "Base Graph"; + } else if (this.viewer.store.extensions[colorBy]) { + legendData = this.viewer.store.extensions[colorBy].legend || []; + title = this.viewer.store.extensions[colorBy].name; + } + + if (legendData.length === 0) { + this.legendOverlay.style.display = 'none'; + return; + } + + this.legendOverlay.style.display = 'block'; + const shadeColor = (color, percent) => { + if (!color || !color.startsWith('#')) return color; + let R = parseInt(color.substring(1,3), 16); + let G = parseInt(color.substring(3,5), 16); + let B = parseInt(color.substring(5,7), 16); + R = parseInt(R * (100 + percent) / 100); + G = parseInt(G * (100 + percent) / 100); + B = parseInt(B * (100 + percent) / 100); + R = (R<255)?R:255; G = (G<255)?G:255; B = (B<255)?B:255; + R = (R>0)?R:0; G = (G>0)?G:0; B = (B>0)?B:0; + const RR = ((R.toString(16).length==1)?"0"+R.toString(16):R.toString(16)); + const GG = ((G.toString(16).length==1)?"0"+G.toString(16):G.toString(16)); + const BB = ((B.toString(16).length==1)?"0"+B.toString(16):B.toString(16)); + return "#"+RR+GG+BB; + }; + + let html = `${title}
`; + legendData.forEach(item => { + const swatchColor = theme && theme === THEMES.dark ? shadeColor(item.color, -20) : item.color; + html += `
+
+ ${fxEsc(item.label)} +
`; + }); + html += `
`; + this.legendOverlay.innerHTML = html; + } + + applyThemeToDOM() { + const theme = THEMES[this.controller.state.themeName]; + this.viewer.wrapper.style.setProperty('--fx-ui-hover', theme.uiHover); + this.viewer.wrapper.style.backgroundColor = theme.bg; + this.viewer.wrapper.style.color = theme.text; + if (this.viewer.sidebar) { + this.viewer.sidebar.style.backgroundColor = theme.bg; + this.viewer.sidebar.style.borderLeftColor = theme.uiBorder; + } + + if (this.taskbar) { + this.taskbar.style.backgroundColor = theme.uiBg; + this.taskbar.style.borderColor = theme.uiBorder; + } + if (this.searchMenu) { + this.searchMenu.style.backgroundColor = theme.uiBg; + this.searchMenu.style.borderColor = theme.uiBorder; + } + + if (this.legendOverlay) { + this.legendOverlay.style.backgroundColor = theme.legendBg; + this.legendOverlay.style.borderColor = theme.uiBorder; + } + + this.infoPanel.style.backgroundColor = theme.bg; + + const sel = '.fx-button, .fx-search-input, .fx-select, .fx-layers-menu'; + const wrapperControls = this.viewer.wrapper.querySelectorAll(sel); + const mainAreaControls = this.viewer.mainArea && this.viewer.mainArea !== this.viewer.wrapper + ? this.viewer.mainArea.querySelectorAll(sel) + : []; + const controls = new Set([...wrapperControls, ...mainAreaControls]); + controls.forEach(ctrl => { + ctrl.style.borderColor = theme.uiBorder; + ctrl.style.color = theme.text; + ctrl.style.backgroundColor = theme.uiBg; + }); + + document.querySelectorAll('.fx-search-item').forEach(item => { + item.style.borderBottomColor = theme.uiBorder; + }); + + this.viewer.resizer.style.backgroundColor = theme.uiBorder; + if (this.viewer.resizerH) { + this.viewer.resizerH.style.backgroundColor = theme.uiBorder; + } + + if (this.viewer.minimapRenderer && this.viewer.minimapRenderer.container) { + this.viewer.minimapRenderer.container.style.borderTopColor = theme.uiBorder; + this.viewer.minimapRenderer.generateThumbnail(); + this.viewer.minimapRenderer.render(); + } + } + + closeSearchMenu() { + if (this.searchMenu) this.searchMenu.style.display = 'none'; + } + + _showSearchMenu() { + if (!this.searchMenu) return; + const canvas = this.viewer.canvasRenderer?.canvas; + if (canvas) { + const rect = canvas.getBoundingClientRect(); + this.searchMenu.style.maxHeight = Math.min(Math.floor(rect.height * 0.7), 500) + 'px'; + this.searchMenu.style.maxWidth = Math.min(Math.floor(rect.width * 0.7), 500) + 'px'; + } + this.searchMenu.style.display = 'block'; + } + + updateSearchResults(candidates, selectedIndex) { + if (!this.searchMenu) return; + this.searchMenu.innerHTML = ''; + if (candidates.length === 0) return; + this._showSearchMenu(); + + this.visibleCandidatesCount = 50; + this.renderSearchCandidatesChunk(candidates, selectedIndex, 0, this.visibleCandidatesCount); + + this.searchMenu.onscroll = () => { + if (this.searchMenu.scrollTop + this.searchMenu.clientHeight >= this.searchMenu.scrollHeight - 10) { + if (this.visibleCandidatesCount < candidates.length) { + const start = this.visibleCandidatesCount; + this.visibleCandidatesCount += 20; + this.renderSearchCandidatesChunk(candidates, this.controller.state.searchSelectedIndex, start, this.visibleCandidatesCount); + } + } + }; + } + + updateSearchActiveItem(selectedIndex) { + if (!this.searchMenu) return; + Array.from(this.searchMenu.children).forEach((item, idx) => { + if (idx === selectedIndex) { + item.classList.add('active'); + } else { + item.classList.remove('active'); + } + }); + if (selectedIndex >= 0 && selectedIndex < this.searchMenu.children.length) { + const childNode = this.searchMenu.children[selectedIndex]; + if (childNode) childNode.scrollIntoView({ block: 'nearest' }); + } + } + + renderSearchCandidatesChunk(candidates, selectedIndex, start=0, end=50) { + if (!this.searchMenu) return; + const theme = THEMES[this.controller.state.themeName]; + for (let idx = start; idx < Math.min(end, candidates.length); idx++) { + const cand = candidates[idx]; + const item = document.createElement('div'); + item.className = 'fx-search-item' + (idx === selectedIndex ? ' active' : ''); + item.style.borderBottomColor = theme.uiBorder; + + let matchText = ''; + if (cand.matchField === 'id') { + matchText = ''; + } else if (cand.matchField) { + matchText = `${cand.matchField}: ${cand.matchString}`; + } + + // highlightedId and matchString contain intentional HTML (highlight spans from SearchEngine) + item.innerHTML = `
${cand.highlightedId}
${matchText}
`; + + item.onmouseenter = () => this.controller.handleSearchHover(idx); + item.onmousedown = (e) => { + e.preventDefault(); + this.controller.handleSearchSelect(idx); + }; + + this.searchMenu.appendChild(item); + } + } + + updateInfoPanel(nodeId) { + const node = this.viewer.store.activeNodeMap.get(nodeId); + if (!node) return; + + this.infoPanel.style.display = 'block'; + const theme = THEMES[this.controller.state.themeName]; + + let html = `

Node: ${fxEsc(node.id)}

`; + html += ``; + + const renderRow = (key, val) => { + html += ``; + }; + + // 1. Core PyTorch Properties + const coreKeys = ['op', 'name', 'target', 'schema', 'args', 'kwargs', 'named_args', 'shape', 'dtype', 'tensor_shape']; + if (node.info) { + coreKeys.forEach(k => { + if (!(k in node.info)) return; + let val = node.info[k]; + if (k === 'args' || k === 'kwargs') { + if (typeof val === 'object') val = JSON.stringify(val, null, 2); + if (val !== '()' && val !== '{}') { + renderRow(k.charAt(0).toUpperCase() + k.slice(1), `
${fxEsc(String(val))}
`); + } + } else if (k === 'named_args' || k === 'schema') { + let display = typeof val === 'object' ? JSON.stringify(val, null, 2) : String(val); + renderRow(k === 'schema' ? 'Schema' : 'Named Args', `
${fxEsc(display)}
`); + } else if (k === 'shape' || k === 'tensor_shape') { + renderRow("Shape", JSON.stringify(val).replace(/"/g, '')); + } else if (k === 'dtype') { + renderRow("Dtype", val.replace('torch.', '')); + } else { + renderRow(k.charAt(0).toUpperCase() + k.slice(1), fxEsc(String(val))); + } + }); + } + + const inEdges = this.viewer.store.revAdjList.get(nodeId) || []; + if (inEdges.length > 0) { + let links = inEdges.map(e => ``).join('
'); + renderRow("Inputs", links); + } + + const outEdges = this.viewer.store.adjList.get(nodeId) || []; + if (outEdges.length > 0) { + let links = outEdges.map(e => ``).join('
'); + renderRow("Outputs", links); + } + + // Render base custom meta that isn't from an extension + if (node.info) { + for (const [key, value] of Object.entries(node.info)) { + if (coreKeys.includes(key) || key.includes('.')) continue; // skip core and extensions + let valStr = typeof value === 'object' ? JSON.stringify(value, null, 2) : String(value); + renderRow(`Meta: ${key}`, `
${valStr}
`); + } + } + html += `
${fxEsc(key)}${val}
`; + + // 2. Extension Groups (Split by Prefix) + if (node.info) { + let extensionGroups = {}; + for (const [key, value] of Object.entries(node.info)) { + if (key.includes('.')) { + const parts = key.split('.'); + const extName = parts[0]; + const subKey = parts.slice(1).join('.'); + if (!extensionGroups[extName]) extensionGroups[extName] = {}; + extensionGroups[extName][subKey] = value; + } + } + + for (const [extName, extDict] of Object.entries(extensionGroups)) { + html += `
--- ${fxEsc(extName)} ---
`; + html += ``; + for (const [k, v] of Object.entries(extDict)) { + let valStr = typeof v === 'object' ? JSON.stringify(v, null, 2) : String(v); + html += ``; + } + html += `
${fxEsc(k)}
${fxEsc(valStr)}
`; + } + } + + this.infoPanel.innerHTML = html; + + const links = this.infoPanel.querySelectorAll('.fx-link'); + links.forEach(link => { + link.onclick = (e) => { + const targetNode = e.target.getAttribute('data-node'); + this.controller.selectNode(targetNode); + this.controller.animateToNode(targetNode); + }; + }); + } + + updateEdgeInfoPanel(edge) { + const srcNode = this.viewer.store.activeNodeMap.get(edge.v); + const dstNode = this.viewer.store.activeNodeMap.get(edge.w); + if (!srcNode || !dstNode) return; + + const theme = THEMES[this.controller.state.themeName]; + let html = `

Edge: ${fxEsc(srcNode.id)} →
${fxEsc(dstNode.id)}

`; + html += ``; + + let shapeStr = '', dtypeStr = ''; + if (srcNode.info && srcNode.info.shape) shapeStr = JSON.stringify(srcNode.info.shape).replace(/"/g, ''); + else if (srcNode.info && srcNode.info.tensor_shape) shapeStr = JSON.stringify(srcNode.info.tensor_shape).replace(/"/g, ''); + + if (srcNode.info && srcNode.info.dtype && typeof srcNode.info.dtype === "string") dtypeStr = srcNode.info.dtype.replace('torch.', ''); + + if (shapeStr) html += ``; + if (dtypeStr) html += ``; + + html += ``; + html += ``; + + html += `
Shape${fxEsc(shapeStr)}
Dtype${fxEsc(dtypeStr)}
Src Node
Dst Node
`; + this.infoPanel.innerHTML = html; + + const links = this.infoPanel.querySelectorAll('.fx-link'); + links.forEach(link => { + link.onclick = (e) => { + const targetNode = e.target.getAttribute('data-node'); + this.controller.selectNode(targetNode); + this.controller.animateToNode(targetNode); + }; + }); + } + + hideInfoPanel() { + this.infoPanel.innerHTML = '
No node selected

Hover or click a node
'; + } + + destroy() { + fxOffAll(this._teardownFns); + if (this.taskbar && this.taskbar.parentNode) this.taskbar.parentNode.removeChild(this.taskbar); + if (this.legendOverlay && this.legendOverlay.parentNode) this.legendOverlay.parentNode.removeChild(this.legendOverlay); + if (this.infoPanel && this.infoPanel.parentNode) this.infoPanel.parentNode.removeChild(this.infoPanel); + } +} diff --git a/devtools/fx_viewer/templates/view_controller.js b/devtools/fx_viewer/templates/view_controller.js new file mode 100644 index 00000000000..dbe667448ea --- /dev/null +++ b/devtools/fx_viewer/templates/view_controller.js @@ -0,0 +1,334 @@ +// Centralized state machine managing interactions, camera transforms, selections, and extension visibility. +// +// State fields: +// hoveredNodeId, hoveredEdge, selectedNodeId, selectedEdge, previewNodeId +// ancestors / descendants (Sets) — BFS results for the active selection +// searchCandidates, searchSelectedIndex +// highlightAncestors (bool) — dim non-ancestor nodes when a node is selected +// themeName (string) +// activeExtensions (Set) — which extension layers are visible +// colorBy (string) — which extension drives node fill color ('base' or extId) +// highlightGroups (Map, color: string}>) +// — programmatic overlay groups; drawn as thick colored borders after node rendering; +// independent of selection state; set via FXGraphViewer.addHighlightGroup() API. +class ViewerController { + constructor(viewer, initialState = {}) { + this.viewer = viewer; + this.store = viewer.store; + this.transform = { x: 0, y: 0, k: 1 }; + + const initialTheme = initialState.themeName || initialState.theme || 'light'; + const initialExtensions = initialState.activeExtensions + ? new Set(initialState.activeExtensions) + : new Set(Object.keys(this.store.extensions)); + const initialColorBy = initialState.colorBy || 'base'; + + this.state = { + hoveredNodeId: null, + hoveredEdge: null, + selectedNodeId: null, + selectedEdge: null, + previewNodeId: null, + ancestors: new Set(), + descendants: new Set(), + searchCandidates: [], + searchSelectedIndex: -1, + highlightAncestors: initialState.highlightAncestors !== false, + themeName: initialTheme, + uiVisibility: { ...(initialState.uiVisibility || {}) }, + + // V3 Extensibility State + activeExtensions: initialExtensions, + colorBy: initialColorBy, + highlightGroups: new Map(), + }; + + // Initial computation of the virtual graph + this.store.computeActiveGraph(this.state.activeExtensions, this.state.colorBy); + } + + snapshotState() { + return { + hoveredNodeId: this.state.hoveredNodeId, + hoveredEdge: this.state.hoveredEdge, + selectedNodeId: this.state.selectedNodeId, + selectedEdge: this.state.selectedEdge, + previewNodeId: this.state.previewNodeId, + searchCandidates: this.state.searchCandidates.slice(), + searchSelectedIndex: this.state.searchSelectedIndex, + highlightAncestors: this.state.highlightAncestors, + themeName: this.state.themeName, + theme: this.state.themeName, + activeExtensions: Array.from(this.state.activeExtensions), + colorBy: this.state.colorBy, + highlightGroups: new Map(this.state.highlightGroups), + searchQuery: this.viewer.ui && this.viewer.ui.searchInput ? this.viewer.ui.searchInput.value : "", + camera: { ...this.transform }, + uiVisibility: { ...(this.state.uiVisibility || {}) }, + }; + } + + setState(newState, options = {}) { + const prev = this.snapshotState(); + + const patch = { ...newState }; + if ('theme' in patch && !('themeName' in patch)) { + patch.themeName = patch.theme; + } + if ('activeExtensions' in patch && !(patch.activeExtensions instanceof Set)) { + patch.activeExtensions = new Set(patch.activeExtensions || []); + } + + Object.assign(this.state, patch); + + // If graph structure or color changed, we must recompute and update UI + if ('activeExtensions' in patch || 'colorBy' in patch) { + this.store.computeActiveGraph(this.state.activeExtensions, this.state.colorBy); + + if (this.viewer.minimapRenderer) { + this.viewer.minimapRenderer.generateThumbnail(); + } + if (this.viewer.ui) { + this.viewer.ui.renderLegend(); + if (this.state.selectedNodeId) { + this.viewer.ui.updateInfoPanel(this.state.selectedNodeId); + } + } + } + + if ('themeName' in patch || 'theme' in patch) { + if (this.viewer.ui) { + this.viewer.ui.applyThemeToDOM(); + } + if (this.viewer.minimapRenderer) { + this.viewer.minimapRenderer.generateThumbnail(); + } + } + + if (this.viewer.ui) { + this.viewer.ui.syncControlsFromState(); + } + + this.viewer.renderAll(); + + const next = this.snapshotState(); + this.viewer._emit('statechange', { prevState: prev, nextState: next, source: options.source || 'api' }); + if (prev.selectedNodeId !== next.selectedNodeId) { + this.viewer._emit('selectionchange', { + prevSelection: prev.selectedNodeId, + nextSelection: next.selectedNodeId, + source: options.source || 'api', + }); + } + if (prev.theme !== next.theme) { + this.viewer._emit('themechange', { prevTheme: prev.theme, nextTheme: next.theme, source: options.source || 'api' }); + } + } + + animateToTransform(targetX, targetY, targetK, duration = 300) { + const startX = this.transform.x; + const startY = this.transform.y; + const startK = this.transform.k; + const startTime = performance.now(); + + const animate = (currentTime) => { + const elapsed = currentTime - startTime; + const progress = Math.min(elapsed / duration, 1); + const ease = 1 - Math.pow(1 - progress, 3); // easeOutCubic + + this.transform.x = startX + (targetX - startX) * ease; + this.transform.y = startY + (targetY - startY) * ease; + this.transform.k = startK + (targetK - startK) * ease; + + this.viewer.renderAll(); + + if (progress < 1) { + requestAnimationFrame(animate); + } + }; + requestAnimationFrame(animate); + } + + zoomToFit() { + const rect = this.viewer.canvasContainer.getBoundingClientRect(); + const padding = Math.min(50, rect.width/5, rect.height/5); + const availableW = rect.width - padding * 2; + const availableH = rect.height - padding * 2; + + let bounds = this.store.graphBounds; + + if (this.state.selectedNodeId) { + bounds = this.store.computeBoundsForNodes( + this._collectNeighbors(this.state.selectedNodeId) + ) || bounds; + } else if (this.state.selectedEdge) { + bounds = this.store.computeBoundsForNodes( + this._collectEdgeNeighbors(this.state.selectedEdge) + ) || bounds; + } + + if (bounds.width === 0 || bounds.height === 0) return; + + const scaleW = availableW / bounds.width; + const scaleH = availableH / bounds.height; + let targetK = Math.min(scaleW, scaleH); + if (this.state.selectedNodeId || this.state.selectedEdge) { + targetK = Math.min(targetK, 1.2); + } + + const centerX = bounds.minX + bounds.width / 2; + const centerY = bounds.minY + bounds.height / 2; + + const targetX = (rect.width / 2) - centerX * targetK; + const targetY = (rect.height / 2) - centerY * targetK; + + this.animateToTransform(targetX, targetY, targetK); + } + + _collectNeighbors(nodeId) { + const nodes = new Set([nodeId]); + for (const e of this.store.revAdjList.get(nodeId) || []) { + nodes.add(e.v); + } + for (const e of this.store.adjList.get(nodeId) || []) { + nodes.add(e.w); + } + return nodes; + } + + _collectEdgeNeighbors(edge) { + const nodes = new Set([edge.v, edge.w]); + for (const e of this.store.revAdjList.get(edge.v) || []) nodes.add(e.v); + for (const e of this.store.adjList.get(edge.v) || []) nodes.add(e.w); + for (const e of this.store.revAdjList.get(edge.w) || []) nodes.add(e.v); + for (const e of this.store.adjList.get(edge.w) || []) nodes.add(e.w); + return nodes; + } + + panToNode(nodeId) { + const node = this.store.activeNodeMap.get(nodeId); + if (!node) return; + const rect = this.viewer.canvasContainer.getBoundingClientRect(); + this.transform.x = rect.width / 2 - node.x * this.transform.k; + this.transform.y = rect.height / 2 - node.y * this.transform.k; + this.viewer.renderAll(); + } + + animateToNode(nodeId, targetK = null) { + const node = this.store.activeNodeMap.get(nodeId); + if (!node) return; + const rect = this.viewer.canvasContainer.getBoundingClientRect(); + const k = targetK !== null ? targetK : this.transform.k; + const targetX = rect.width / 2 - node.x * k; + const targetY = rect.height / 2 - node.y * k; + this.animateToTransform(targetX, targetY, k); + } + + handleHover(nodeId, edge) { + if (this.state.hoveredNodeId !== nodeId || this.state.hoveredEdge !== edge) { + this.setState({ hoveredNodeId: nodeId, hoveredEdge: edge }); + } + } + + handleClick(nodeId, edge) { + if (nodeId) { + this.selectNode(nodeId); + } else if (edge) { + this.selectEdge(edge); + } else { + this.clearSelection(); + } + } + + selectNode(nodeId) { + const ancestors = this.store.getAncestors(nodeId); + const descendants = this.store.getDescendants(nodeId); + this.setState({ + selectedNodeId: nodeId, + selectedEdge: null, + ancestors, + descendants, + previewNodeId: null + }); + this.viewer.ui.updateInfoPanel(nodeId); + } + + selectEdge(edge) { + const ancestors = this.store.getAncestors(edge.v); + const descendants = this.store.getDescendants(edge.w); + this.setState({ + selectedNodeId: null, + selectedEdge: edge, + ancestors, + descendants, + previewNodeId: null + }); + this.viewer.ui.updateEdgeInfoPanel(edge); + } + + clearSelection() { + this.setState({ + selectedNodeId: null, + selectedEdge: null, + ancestors: new Set(), + descendants: new Set(), + previewNodeId: null + }); + this.viewer.ui.hideInfoPanel(); + } + + handleSearch(query) { + if (!query) { + this.setState({ searchCandidates: [], searchSelectedIndex: -1, previewNodeId: null }); + this.viewer.ui.updateSearchResults([], -1); + if (this.state.selectedNodeId) { + this.viewer.ui.updateInfoPanel(this.state.selectedNodeId); + this.panToNode(this.state.selectedNodeId); + } else { + this.viewer.ui.hideInfoPanel(); + } + return; + } + const candidates = this.viewer.searchEngine.search(query); + this.setState({ searchCandidates: candidates, searchSelectedIndex: -1, previewNodeId: null }); + this.viewer.ui.updateSearchResults(candidates, -1); + } + + handleSearchNavigate(direction) { + const { searchCandidates, searchSelectedIndex } = this.state; + if (searchCandidates.length === 0) return; + let newIndex = searchSelectedIndex + direction; + if (newIndex < 0) newIndex = searchCandidates.length - 1; + if (newIndex >= searchCandidates.length) newIndex = 0; + + const previewNode = searchCandidates[newIndex].node.id; + this.setState({ searchSelectedIndex: newIndex, previewNodeId: previewNode }); + this.viewer.ui.updateSearchActiveItem(newIndex); + this.viewer.ui.updateInfoPanel(previewNode); + this.panToNode(previewNode); + } + + handleSearchSelect(index) { + const { searchCandidates } = this.state; + const idx = index !== undefined ? index : this.state.searchSelectedIndex; + if (idx >= 0 && idx < searchCandidates.length) { + const nodeId = searchCandidates[idx].node.id; + this.selectNode(nodeId); + this.panToNode(nodeId); + this.viewer.ui.closeSearchMenu(); + this.setState({ searchCandidates: [], searchSelectedIndex: -1, previewNodeId: null }); + if (this.viewer.ui.searchInput) this.viewer.ui.searchInput.value = ''; + } + } + + handleSearchHover(index) { + const { searchCandidates } = this.state; + if (index >= 0 && index < searchCandidates.length) { + const previewNode = searchCandidates[index].node.id; + this.setState({ searchSelectedIndex: index, previewNodeId: previewNode }); + this.viewer.ui.updateSearchActiveItem(index); + this.viewer.ui.updateInfoPanel(previewNode); + this.panToNode(previewNode); + } + } +} diff --git a/devtools/observatory/README.md b/devtools/observatory/README.md new file mode 100644 index 00000000000..a3283796a2a --- /dev/null +++ b/devtools/observatory/README.md @@ -0,0 +1,244 @@ +# Observatory + +Observatory is a unified debugging framework for ExecuTorch that captures graph snapshots and analysis data across compilation stages, then exports the results as a standalone, shareable HTML report. + +Instead of collecting logs, traces, and artifacts from scattered sources, Observatory provides a single workflow: **capture, store, analyze, visualize, share**. The output is one HTML file that anyone can open in a browser to inspect graphs, accuracy metrics, per-layer analysis, and more. + +## Why it exists + +Debugging model compilation issues is often too manual. When something goes wrong, engineers typically need to collect information from multiple places, reconstruct execution context by hand, and pass partial artifacts between people. This is especially painful when the issue is hard to reproduce, the investigator is not the original developer, or the context needs to be shared across teams. + +Observatory closes this gap by providing a consistent, automated workflow from data capture to presentation. + +## The workflow + +``` +capture --> store --> analyze --> visualize --> share +``` + +1. **Capture**: Observatory wraps your export script and automatically collects graph snapshots at each compilation stage (export, quantize, lower). +2. **Store**: Raw data is persisted as structured JSON for later re-analysis. +3. **Analyze**: Each lens processes the collected data into findings, comparisons, and derived insights. +4. **Visualize**: Results are assembled into an interactive HTML report with multiple view types. +5. **Share**: The report is a single self-contained HTML file. Send it, attach it to a bug report, or host it on GitHub Pages. + +## What you get + +A standalone HTML report containing: + +- **Graph View**: Interactive fx_viewer graphs with color-coded overlays (accuracy error, op type, etc.) +- **Table View**: Key-value summaries, per-record metrics, cross-record comparisons +- **Compare View**: Side-by-side graph comparison with synchronized selection +- **Dashboard**: Session-level summary with badges and navigation + +## Quick start + +### CLI (zero-config) + +Point the CLI at any ExecuTorch export script: +```bash +python -m executorch.devtools.observatory SCRIPT [SCRIPT_ARGS...] +``` +Use `--output-html` / `--output-json` to set output paths explicitly: + +```bash +python -m executorch.devtools.observatory \ + --output-html /tmp/obs/report.html \ + --output-json /tmp/obs/report.json \ + examples/qualcomm/oss_scripts/swin_v2_t.py \ + --model SM8650 -b ./build-android -d imagenet-mini/val -a ./swin_v2_t +``` + +Use backend-specific observatory cli for additional customized lenses and hooks (for example, xnnpack backend with per-layer accuracy analysis) + +```bash +python -m executorch.backends.xnnpack.debugger.observatory \ + --output-html /tmp/obs/report.html \ + --lense_recipe=accuracy \ + examples/xnnpack/aot_compiler.py \ + --model_name=mv2 --delegate --quantize --output_dir /tmp/mv2 +``` + +> **XNNPack note**: `aot_compiler.py` uses relative imports so it must run as a Python module. +> The CLI auto-detects this from `__init__.py` presence. You can also pass the dotted module +> name directly: `examples.xnnpack.aot_compiler` + +This produces: +- `/tmp/obs/report.html` (interactive report) +- `/tmp/obs/report.json` (raw data, path auto-derived from HTML path) + +### Python API + +```python +from executorch.devtools.observatory import Observatory, observe_pass + +# Wrap passes for automatic graph collection +pass_a = observe_pass(SomePass()) + +Observatory.clear() +with Observatory.enable_context(): + # Auto: Lenses can auto-insert collection points by monkey patching when entering context + # Manual: Insert the collection point anywhere + Observatory.collect("step_0", graph_module) + + # observe_pass: auto-collects input and output graphs + result = pass_a(graph_module) + # collects "SomePass/input" and "SomePass/output" + +Observatory.export_html_report("/tmp/report.html") +``` + +## Core concepts + +### `observe_pass` decorator + +The `observe_pass` decorator wraps any pass (PassBase subclass or callable) to automatically collect graphs via Observatory. By default it captures both input and output graphs, making pass debugging a one-line change: + +```python +observed = observe_pass(SomePass()) # wrap once +result = observed(graph_module) # auto-collects input + output +``` + +Record names are derived from the class name and auto-deduplicated on repeat calls. +See [USAGE.md](USAGE.md) for the full decorator reference. + +### Lenses + +A **Lens** is a modular extension that adds domain-specific debugging logic. Each lens can participate in capture, analysis, and visualization. This is what makes Observatory a framework rather than a fixed tool. + +Built-in lenses: + +| Lens | What it does | +|------|-------------| +| `GraphLens` | Renders interactive fx_viewer graph for each collected artifact | +| `MetadataLens` | Collects artifact type, node count, environment info | +| `StackTraceLens` | Captures the call stack at each collection point | +| `PipelineGraphCollectorLens` | Auto-collects graphs at export/quantize/lower stages (patches framework functions) | +| `AccuracyLens` | Evaluates model accuracy at each stage (PSNR, cosine similarity, MSE, top-k) | +| `GraphColorLens` | Colors graph nodes by op_type or op_target | +| `PerLayerAccuracyLens` | Computes per-layer accuracy metrics with graph overlays | + +See [lenses/LENSES.md](lenses/LENSES.md) for detailed lens documentation. + +### The 2-step design + +Observatory separates **runtime collection** from **report generation**: + +1. **Step 1**: Run your script — both JSON and HTML are exported automatically +2. **Step 2**: Re-generate HTML any time later from the JSON (`cli visualize`) + +This means you can collect data in CI and re-generate reports locally, or regenerate HTML after updating lens code without re-running expensive export scripts. + +```bash +# Step 1: collect (e.g., in CI) +python -m executorch.devtools.observatory script.py ... + +# Step 2: re-visualize (e.g., locally) +python -m executorch.devtools.observatory visualize \ + --input-json observatory_report.json --output-html report.html +``` + +### Fx-Viewer + +Fx-Viewer (`devtools/utils/fx_viewer`) is the graph visualization component used inside Observatory's Graph View. Observatory owns the workflow; Fx-Viewer provides the interactive graph rendering, node inspection, and highlighting within that workflow. + +## How to use it + +See [USAGE.md](USAGE.md) for the full CLI usage guide, including: + +- Zero-config e2e workflow +- Visualize mode (JSON to HTML) +- Manual collection points in arbitrary code +- Demo script batch modes + +## Writing a custom Lens + +A lens implements the `observe -> digest -> analyze -> frontend` lifecycle: + +```python +from executorch.devtools.observatory.interfaces import ( + AnalysisResult, Frontend, TableBlock, TableRecordSpec, ViewList, +) + +class MyLens: + @classmethod + def get_name(cls): return "my_lens" + + @classmethod + def observe(cls, artifact, context): + return {"node_count": len(artifact.graph.nodes)} + + @classmethod + def digest(cls, observation, context): + return observation + + @staticmethod + def analyze(records, config): + return AnalysisResult() + + @staticmethod + def get_frontend_spec(): + class MyFrontend(Frontend): + def record(self, digest, analysis, context): + return ViewList(blocks=[ + TableBlock(id="summary", title="Summary", + record=TableRecordSpec(data=digest), order=0) + ]) + return MyFrontend() +``` + +Register it before entering the context: + +```python +Observatory.register_lens(MyLens) +``` + +See [lenses/LENSES.md](lenses/LENSES.md) for the full lens protocol and built-in lens details. + +## Adding graph overlays from a Lens + +Lenses can contribute colored graph overlays during the `analyze()` phase: + +```python +from executorch.backends.qualcomm.debugger.observatory.interfaces import ( + AnalysisResult, RecordAnalysis, +) +from executorch.devtools.fx_viewer import ( + GraphExtensionPayload, GraphExtensionNodePayload, +) + +payload = GraphExtensionPayload( + id="error", name="Accuracy Error", + legend=[{"label": "Low", "color": "#93c5fd"}], + nodes={"node_0": GraphExtensionNodePayload(fill_color="#93c5fd")}, +) + +record_analysis = RecordAnalysis(data={"max_mse": 0.1}) +record_analysis.add_graph_layer("error", payload) + +return AnalysisResult(per_record_data={"step_1": record_analysis}) +``` + +## Entry points + +| File | Purpose | +|------|---------| +| `observatory.py` | Runtime lifecycle, report assembly, export APIs | +| `interfaces.py` | Typed dataclass contracts for all blocks, lenses, and analysis | +| `graph_hub.py` | Graph asset/layer merge logic | +| `cli.py` | CLI runner (run mode + visualize mode) | +| `auto_collect.py` | ETRecord monkey-patch auto-collection | + +## Document map + +| Document | What it covers | +|----------|---------------| +| [USAGE.md](USAGE.md) | CLI usage guide, workflow examples, demo script modes | +| [lenses/LENSES.md](lenses/LENSES.md) | Built-in lens details, accuracy lens internals, custom lens patterns | +| [REFERENCE.md](REFERENCE.md) | Contract tables, API reference, JS callbacks, performance notes | + +## Tests + +```bash +pytest -q backends/devtools/observatory/tests +``` diff --git a/devtools/observatory/REFERENCE.md b/devtools/observatory/REFERENCE.md new file mode 100644 index 00000000000..122fdded26b --- /dev/null +++ b/devtools/observatory/REFERENCE.md @@ -0,0 +1,846 @@ +# Observatory Technical Reference + +This document contains the detailed contract tables, API references, and performance notes for Observatory internals. For an introduction and usage guide, see [README.md](README.md). + +## 1. Contract Tables + +The tables below define the actual contracts across lens code, report JSON, +frontend rendering, and custom JS callbacks. + +### 1.1 End-to-End Stage Matrix + +| Stage | Entrypoint / Signature | Primary input | Primary output | Output JSON path | +| --- | --- | --- | --- | --- | +| Runtime capture | `Observatory.collect(name: str, artifact: Any)` | `artifact`, `ObservationContext(config, shared_state)` | `RecordDigest(name, timestamp, data)` | `records[i].digests` | +| Session hooks | `Lens.on_session_start/end(context)` | `ObservationContext` | lens-scoped session payload | `session.start_data[lens]`, `session.end_data[lens]` | +| Analyze | `Lens.analyze(records, config) -> AnalysisResult` | `List[RecordDigest]`, merged config | `global_data`, `per_record_data[name]` | `analysis_results[lens]` | +| Graph merge | `GraphHub.add_analysis_layers(graph_ref, lens_name, record_analysis)` | `RecordAnalysis.graph_layers` | namespaced layer map | `graph_layers[graph_ref]["/"]` | +| Frontend assembly | `Frontend.dashboard(...)`, `Frontend.record(...)` | digest/session/analysis values | `ViewList(blocks=[...])` | `dashboard[lens]`, `records[i].views[lens]` | +| Browser render | `renderMain()/renderUnifiedView()` | full report payload | DOM sections / graph viewers | runtime only | +| Custom JS invoke | `fn(container, args, context, analysis)` | block spec + runtime context | user DOM changes | runtime only | + +### 1.2 Lens Lifecycle Signatures by Stage + +| Stage | Method | Signature | Return / Side Effect | +| --- | --- | --- | --- | +| Registration | `get_name` | `@classmethod get_name() -> str` | stable lens key | +| Registration | `setup` | `@classmethod setup() -> None` | one-time setup | +| Session | `on_session_start` | `@classmethod on_session_start(context: ObservationContext) -> Optional[Serializable]` | stored in `session.start_data[lens]` | +| Runtime | `observe` | `@classmethod observe(artifact: Any, context: ObservationContext) -> Any` | transient observation | +| Runtime | `digest` | `@classmethod digest(observation: Any, context: ObservationContext) -> Serializable` | persisted digest in `RecordDigest.data[lens]` | +| Session | `on_session_end` | `@classmethod on_session_end(context: ObservationContext) -> Optional[Serializable]` | stored in `session.end_data[lens]` | +| Analyze | `analyze` | `@staticmethod analyze(records: List[RecordDigest], config: Dict[str, Any]) -> AnalysisResult` | global + per-record derived data | +| Frontend | `get_frontend_spec` | `@staticmethod get_frontend_spec() -> Frontend` | returns strategy object | +| Reset | `clear` | `@classmethod clear() -> None` | clears lens internal state | + +### 1.3 Frontend Stage Signatures and Input Sources + +| Method | Signature | Python-side argument source | Serialized destination | +| --- | --- | --- | --- | +| `resources` | `resources() -> Dict[str, str]` | lens frontend implementation | `resources.js[]`, `resources.css[]` | +| `dashboard` | `dashboard(start, end, analysis, records) -> Optional[ViewList]` | `start=session.start_data[lens]`, `end=session.end_data[lens]`, `analysis=analysis_results[lens].global_data`, `records=List[RecordDigest]` | `dashboard[lens]` | +| `record` | `record(digest, analysis, context) -> Optional[ViewList]` | `digest=record.data[lens]`, `analysis={"global": global_data, "record": per_record_data[name].data}`, `context={"index", "name"}` | `records[i].views[lens]` | +| `check_badges` | `check_badges(digest, analysis) -> List[Dict[str, str]]` | current digest + `global_data` | `records[i].badges[]` | +| `check_index_diffs` | `check_index_diffs(prev_digest, curr_digest, analysis) -> Dict[str, str]` | previous/current digest + `global_data` | `records[i].diff_index` | + +### 1.4 View Block Contracts (Frontend Output) + +| Block type | Python dataclass | Required record fields | Compare modes | JS renderer path | +| --- | --- | --- | --- | --- | +| Table | `TableBlock(record=TableRecordSpec)` | `record.data: Dict[str, Serializable]` | `auto`, `disabled` | `renderTableContent` | +| HTML | `HtmlBlock(record=HtmlRecordSpec)` | `record.content: str` | `auto`, `disabled` | `content.innerHTML` | +| Custom | `CustomBlock(record=CustomRecordSpec)` | `record.js_func: str`, `record.args: dict` | `custom`, `disabled` | `resolveFunction(js_func)` then callback | +| Graph | `GraphBlock(record=GraphRecordSpec)` | `record.graph_ref: str` | `auto`, `custom`, `disabled` | `mountGraphViewer` | + +| Common block fields | Type | Notes | +| --- | --- | --- | +| `id` | `str` | must be non-empty, unique inside one `ViewList` | +| `title` | `str` | section header text | +| `order` | `int` | stable sort key for rendering | +| `collapsible` | `bool` | section open/close behavior | + +### 1.5 Information Object Map Across Boundaries + +| Python object | Produced in Python stage | JSON path in report | JS access pattern | +| --- | --- | --- | --- | +| `RecordDigest` | `collect` | `records[i]` | `state.data.records[i]` | +| `RecordDigest.data[lens]` | `digest` | `records[i].digests[lens]` | `context.record.digests[lens]` in record custom JS | +| `SessionResult.start_data[lens]` | `on_session_start` | `session.start_data[lens]` | dashboard custom context `start` | +| `SessionResult.end_data[lens]` | `on_session_end` | `session.end_data[lens]` | dashboard custom context `end` | +| `AnalysisResult.global_data` | `analyze` | `analysis_results[lens].global_data` | `analysis.global_data` in custom JS | +| `AnalysisResult.per_record_data[name].data` | `analyze` | `analysis_results[lens].per_record_data[name].data` | `analysis.per_record_data?.[recordName]?.data` | +| `AnalysisResult.per_record_data[name].graph_layers` | `analyze` | merged into `graph_layers[graph_ref]` | included via viewer `extensions` | +| `ViewList` | frontend callbacks | `dashboard[lens].blocks` / `records[i].views[lens].blocks` | `getLensBlocks(record, lensName)` | + +### 1.6 Custom JS Callback Signatures + +| Callback stage | Invocation site | Signature | `context` shape | +| --- | --- | --- | --- | +| Record block render | `renderRecordBlock(..., context={ index, record }, analysis)` | `fn(container, args, context, analysis)` | `{ index: number, record: SerializedRecord }` | +| Dashboard block render | `renderDashboard(..., context={ start, end, records }, analysis)` | `fn(container, args, context, analysis)` | `{ start: object, end: object, records: SerializedRecord[] }` | +| Compare render (`mode="custom"`) | `renderCustomCompare(..., context, analysis)` | `fn(container, args, context, analysis)` | `{ indices, names, records, blocks, lens, block_id }` | + +| Callback arg | Runtime value | Source contract | +| --- | --- | --- | +| `container` | target block DOM container | JS renderer internals | +| `args` | `block.record.args` or `block.compare.args` | `CustomRecordSpec.args` / `CustomCompareSpec.args` | +| `analysis` | `state.data.analysis_results[lensName]` | serialized `AnalysisResult` (`global_data`, `per_record_data`) | + +Example (record view): + +```javascript +function renderRecord(container, args, context, analysis) { + const lensName = "accuracy"; + const digest = context.record?.digests?.[lensName] || {}; + const perRecord = analysis?.per_record_data?.[context.record?.name]?.data || {}; + const global = analysis?.global_data || {}; + container.textContent = `${args.title}: mse_max=${perRecord.max_mse ?? "n/a"}`; +} +``` + +### 1.7 Graph Pipeline (Observatory -> Viewer) + +| Step | Python/JS API | Payload shape | Notes | +| --- | --- | --- | --- | +| Base graph capture | `GraphLens.observe` | `{ graph_ref, base, meta }` | `base` comes from `FXGraphExporter.generate_json_payload()["base"]` | +| Base graph registration | `GraphHub.register_asset(graph_ref, base, meta)` | `graph_assets[graph_ref] = { base, meta }` | one asset per record name by default | +| Layer authoring | `RecordAnalysis.add_graph_layer(key, extension)` | `RecordAnalysis.graph_layers[key]` | `extension` accepts `GraphExtensionPayload` or `GraphExtension` | +| Layer merge | `GraphHub.add_analysis_layers(graph_ref, lens_name, analysis)` | `graph_layers[graph_ref]["/"]` | namespaced IDs prevent cross-lens collisions | +| Viewer payload build | `buildViewerPayload(graphRef)` in `01_utils.js` | `{ base, extensions }` | `base <- graph_assets`, `extensions <- graph_layers` | +| Viewer mount | `FXGraphViewer.create({ payload, mount, layout, state })` | viewer instance | called from `mountGraphViewer` | +| Compare mount | `FXGraphCompare.create({ viewers, layout, sync })` | compare controller | called from `renderGraphCompare` | + +### 1.8 fx_viewer Type Bridge (Python Side) + +| Observatory usage point | fx_viewer type/API | Field-level mapping | +| --- | --- | --- | +| Base graph export | `FXGraphExporter.generate_json_payload()` | uses `.base.legend/nodes/edges` for `graph_assets[graph_ref].base` | +| Layer helper | `GraphExtension(id, name)` | accumulates `nodes_data[node_id]` then `build_payload()` | +| Stable layer payload | `GraphExtensionPayload` | `id`, `name`, `legend`, `nodes[node_id]` | +| Per-node layer payload | `GraphExtensionNodePayload` | `info`, `tooltip`, `label_append`, `fill_color` | +| Observatory conversion | `GraphLayerContribution.to_payload()` | converts `GraphExtension` to `GraphExtensionPayload` and applies overrides | + +### 1.9 fx_viewer Runtime API Used by Observatory (JS Side) + +| Runtime API | Called from observatory JS | Purpose | +| --- | --- | --- | +| `FXGraphViewer.create(config)` | `mountGraphViewer` | mount single graph view | +| `viewer.init()` | `mountGraphViewer` | initialize renderer/UI | +| `viewer.setLayout(patch)` | `mountGraphViewer` | hide sidebar in compact compare layouts | +| `viewer.setUIVisibility(flags)` | `mountGraphViewer` | hide minimap toggle in compare | +| `viewer.setLayers(layerIds)` | `ObservatoryAPI` graph handle | switch active extension layers | +| `viewer.setColorBy(layerId)` | `ObservatoryAPI` graph handle | set color source layer | +| `viewer.patchLayerNodes(layerId, patch)` | `ObservatoryAPI` graph handle | patch node style/info in a layer | +| `viewer.selectNode(nodeId, opts)` | `ObservatoryAPI` + delegated actions | focus node | +| `viewer.zoomToFit()` | `ObservatoryAPI` | reset camera to graph bounds | +| `viewer.enterFullscreen()/exitFullscreen()` | `ObservatoryAPI` | fullscreen control | +| `viewer.on('selectionchange', cb)` | `ObservatoryAPI`, `FXGraphCompare` | selection events + compare sync | +| `FXGraphCompare.create(config)` | `renderGraphCompare` | multi-view synchronization | +| `compare.setSync(patch)` | compare sync checkbox handler | toggle selection sync on/off | + +### 1.10 Graph Layer Naming and Selection Rules + +| Rule | Contract | Example | +| --- | --- | --- | +| Author key (lens analyze code) | free-form local key in `RecordAnalysis.graph_layers` | `"error"` | +| Report-level namespaced key | `"/"` | `"accuracy/error"` | +| Graph block default layers | `GraphRecordSpec.default_layers: list[str]` | `default_layers=["accuracy/error"]` | +| Graph block default color | `GraphRecordSpec.default_color_by: str | None` | `default_color_by="accuracy/error"` | +| Compare viewer options merge | `Object.assign({}, record.viewer_options, compare.viewer_options_compare)` | compare options override record defaults without mutating record options | + +### 1.11 Python-to-JS Dataflow Cheat Sheet + +| You define in Python | Appears in report | Read in JS callback | +| --- | --- | --- | +| `CustomRecordSpec.args` | `block.record.args` | callback `args` | +| `CustomCompareSpec.args` | `block.compare.args` | callback `args` (compare mode) | +| `Frontend.record(..., context={"index","name"})` | selection metadata + serialized record list | callback `context.index`, `context.record` | +| `AnalysisResult.global_data` | `analysis_results[lens].global_data` | callback `analysis.global_data` | +| `AnalysisResult.per_record_data[name].data` | `analysis_results[lens].per_record_data[name].data` | callback `analysis.per_record_data?.[context.record.name]?.data` | + + +## 2. Embedded References (Single Source) + +This README now embeds the former standalone references so contributors can +review runtime behavior and API contracts in one place. + +### 2.1 Architecture Reference + + +Observatory is a whitebox debugging runtime for ExecuTorch compilation and execution flows. +This implementation is intentionally graph-native and typed. + +#### 1. Core Principles + +1. Strict contracts over implicit dicts. +2. Runtime capture separated from offline analysis. +3. Graph assets shared by reference via `graph_ref`. +4. Graph overlays produced in `analyze()` and merged centrally. +5. UI runtime split into topic JS modules for reviewability. + +#### 2. Four-Phase Lifecycle + +##### Phase 1: Runtime Capture + +1. User enters `Observatory.enable_context(...)`. +2. Lenses run `observe()` and `digest()` during `collect(name, artifact)`. +3. Output is persisted as `RecordDigest`. + +##### Phase 2: Session Hooks + +1. Outermost context entry triggers `on_session_start`. +2. Outermost context exit triggers `on_session_end`. +3. ETRecord monkey-patch auto-collection is installed/uninstalled on outermost boundaries. + +##### Phase 3: Analysis + +1. Each lens runs `analyze(records, config)`. +2. Global results go to `AnalysisResult.global_data`. +3. Per-record results go to `AnalysisResult.per_record_data[record_name]` as `RecordAnalysis`. +4. Graph overlay contributions are attached in `RecordAnalysis.graph_layers`. + +##### Phase 4: Report Assembly + Rendering + +1. Frontend blocks are produced from typed `ViewList` contracts. +2. `GraphHub` merges base assets and analysis-time graph overlays. +3. Report payload is exported to JSON and HTML. + +#### 3. Graph-Native Runtime Model + +##### 3.1 Graph Asset Source + +1. `GraphLens` builds one canonical fx_viewer base payload per record. +2. Payload stored in report-level `graph_assets[graph_ref]`. + +##### 3.2 Graph Overlay Source + +1. Lenses attach layers in `analyze()` per record using `RecordAnalysis.graph_layers`. +2. Each layer uses typed fx_viewer payloads (`GraphExtensionPayload`) or `GraphExtension` authoring helper. +3. `GraphHub` namespaces internal layer IDs as `/`. + +##### 3.3 GraphView Consumption + +1. `GraphBlock.record.graph_ref` resolves base graph. +2. `graph_layers[graph_ref]` provides merged overlay layers. +3. Compare mode renders side-by-side viewers with optional selection sync. + +#### 4. UI Runtime Topology + +JS modules under `templates/js`: + +1. `00_state.js`: report state bootstrap. +2. `01_utils.js`: utility + viewer payload helpers. +3. `02_layout.js`: header/sidebar/index rendering. +4. `03_blocks.js`: block renderers + compare behavior. +5. `04_actions.js`: navigation/selection/theme actions. +6. `05_bootstrap_api.js`: app init + `window.ObservatoryAPI`. + +#### 5. Auto-Collection Architecture (ETRecord) + +1. `enable_context` installs ETRecord wrappers. +2. Wrapped ETRecord calls invoke `Observatory.collect(...)` transparently. +3. No manual observe points are required for ETRecord graph capture. +4. Wrappers are restored on outermost context exit. + +#### 6. Breaking API Policy + +This observatory path is intentionally breaking: + +1. Frontend methods must return typed `ViewList` block contracts. +2. Analyze-time graph layers use typed dataclasses, not raw dict hooks. +3. Legacy compatibility shims are intentionally not maintained. + +### 2.2 Python API Reference + + +#### 1. Entry Point + +Import: + +```python +from executorch.backends.qualcomm.debugger.observatory import Observatory +``` + +#### 2. Session Lifecycle + +##### `Observatory.enable_context(config: dict | None = None)` + +Context manager enabling collection. + +Behavior: + +1. Registers default lenses lazily on first use. +2. Installs ETRecord auto-collection wrappers on outermost entry. +3. Calls `on_session_start` hooks on outermost entry. +4. Calls `on_session_end` hooks on outermost exit. +5. Uninstalls ETRecord wrappers on outermost exit. + +Example: + +```python +with Observatory.enable_context(config={"profiling": {"enabled": True}}): + ... +``` + +##### Nested config behavior + +1. Configs are shallow-merged by key. +2. Nested dict values are merged per top-level lens key. +3. Inner context values override outer context values. + +#### 3. Capture APIs + +##### `Observatory.collect(name: str, artifact: Any) -> None` + +Captures one record across all registered lenses. + +Behavior: + +1. No-op if context disabled. +2. Populates `ObservationContext.shared_state["record_name"]`. +3. Executes each lens `observe -> digest` pipeline. +4. Stores `RecordDigest` keyed by `name`. + +##### `Observatory.ignore_graphs(names: list[str]) -> None` + +1. Marks matching names ignored for future collect calls. +2. Removes existing records with matching names. + +##### `Observatory.list_collected() -> list[str]` + +Returns all collected record names. + +##### `Observatory.get(name: str) -> RecordDigest | None` + +Returns one collected record by name. + +#### 4. Lens Registration and Reset + +##### `Observatory.register_lens(lens_cls)` + +Registers a custom lens and runs `lens_cls.setup()` once. + +##### `Observatory.clear() -> None` + +1. Clears all records. +2. Clears session data. +3. Uninstalls ETRecord wrappers. +4. Calls `clear()` on every registered lens. + +#### 5. Export APIs + +##### `Observatory.export_html_report(output_path, title="Observatory Report", config=None)` + +Builds analysis + frontend payload and emits interactive HTML. + +##### `Observatory.export_json(output_path)` + +Exports raw records + session payload only. + +##### `Observatory.generate_html_from_json(json_path, html_path, title="Observatory Report", config=None)` + +Reconstructs HTML report from exported raw JSON and current lens frontend/analyze logic. + +#### 6. Minimal End-to-End Example + +```python +import torch +from executorch.backends.qualcomm.debugger.observatory import Observatory + +class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.fc = torch.nn.Linear(8, 8) + + def forward(self, x): + return self.fc(x) + +model = M().eval() +graph = torch.fx.symbolic_trace(model) + +Observatory.clear() +with Observatory.enable_context(): + Observatory.collect("step_0", graph) + +Observatory.export_html_report("/tmp/observatory_report.html") +``` + +#### 7. Notes + +1. This observatory path is breaking-by-design and does not support legacy dict block contracts. +2. Frontend returns must use typed `ViewList` dataclass API. +3. Graph layers must be attached via `AnalysisResult.per_record_data[record].graph_layers`. + +### 2.3 Interface Reference + + +This document defines strict dataclass contracts used by observatory lenses and UI rendering. + +Source of truth: + +- `backends/qualcomm/debugger/observatory/interfaces.py` + +#### 1. Frontend Block Contracts + +All frontend methods return: + +```python +ViewList(blocks=[...]) +``` + +where each block is one of: + +1. `TableBlock` +2. `HtmlBlock` +3. `CustomBlock` +4. `GraphBlock` + +##### 1.1 TableBlock + +Fields: + +1. base: `id`, `title`, `order`, `collapsible` +2. `record: TableRecordSpec(data: dict[str, Serializable])` +3. `compare: TableCompareSpec(mode: "auto" | "disabled")` + +##### 1.2 HtmlBlock + +Fields: + +1. base: `id`, `title`, `order`, `collapsible` +2. `record: HtmlRecordSpec(content: str)` +3. `compare: HtmlCompareSpec(mode: "auto" | "disabled")` + +##### 1.3 CustomBlock + +Fields: + +1. base: `id`, `title`, `order`, `collapsible` +2. `record: CustomRecordSpec(js_func: str, args: dict)` +3. `compare: CustomCompareSpec(mode: "custom" | "disabled", js_func: str | None, args: dict)` + +Rules: + +1. `record.js_func` must be non-empty. +2. `compare.mode == "custom"` uses `compare.js_func` if set; otherwise falls back to `record.js_func`. + +##### 1.4 GraphBlock + +Fields: + +1. base: `id`, `title`, `order`, `collapsible` +2. `record: GraphRecordSpec` +3. `compare: GraphCompareSpec` + +`GraphRecordSpec` fields: + +1. `graph_ref: str` (required) +2. `default_layers: list[str]` +3. `default_color_by: str | None` +4. `layer_scope: "all" | "lens_only" | list[str]` +5. `viewer_options: dict` +6. `controls: dict` +7. `fullscreen: dict` + +`GraphCompareSpec` fields: + +1. `mode: "auto" | "disabled" | "custom"` +2. `max_parallel: int >= 1` +3. `sync_toggle: bool` +4. `default_sync: dict` — initial `FXGraphCompare` sync config passed as `config.sync`. Keys: `mode` (`"auto"`, `"id"`, `"layer"`, `"none"`), `layer` (extension id), `field` (info key). Empty dict (default) falls back to `{ mode: "auto" }`. +5. `viewer_options_compare: dict` +6. `js_func: str | None` +7. `args: dict` + +#### 2. Validation API + +Utilities: + +1. `validate_view_block(block)` +2. `validate_view_list(view_list)` + +Validation checks: + +1. Non-empty block id/title. +2. Unique block ids in one `ViewList`. +3. CustomBlock function requirements. +4. GraphBlock `graph_ref` and `max_parallel` constraints. + +#### 3. Analysis Contracts + +##### 3.1 `AnalysisResult` + +Fields: + +1. `global_data: dict[str, Serializable]` +2. `per_record_data: dict[str, RecordAnalysis]` + +##### 3.2 `RecordAnalysis` + +Fields: + +1. `data: dict[str, Serializable]` +2. `graph_layers: dict[str, GraphLayerContribution]` + +##### 3.3 `GraphLayerContribution` + +Fields: + +1. `extension`: `GraphExtensionPayload | GraphExtension` +2. `id_override: str | None` +3. `name_override: str | None` + +Method: + +1. `to_payload() -> GraphExtensionPayload` + +Notes: + +1. `GraphExtensionPayload` is preferred for stable serialization. +2. `GraphExtension` is accepted as authoring helper and converted lazily. + +#### 4. Runtime Core Contracts + +##### 4.1 ObservationContext + +1. `config: dict` +2. `shared_state: dict` + +##### 4.2 RecordDigest + +1. `name: str` +2. `timestamp: float` +3. `data: dict[str, Serializable]` + +##### 4.3 SessionResult + +1. `start_data: dict` +2. `end_data: dict` + +#### 5. Lens Protocol Summary + +Each lens may implement: + +1. `setup()` +2. `on_session_start(context)` +3. `observe(artifact, context)` +4. `digest(observation, context)` +5. `on_session_end(context)` +6. `clear()` +7. `analyze(records, config) -> AnalysisResult` +8. `get_frontend_spec() -> Frontend` + +No separate `contribute_graph_layers` hook is used in this architecture. +Graph layers are contributed through `analyze()` via typed `RecordAnalysis`. + +### 2.4 JavaScript API Reference + + +This document describes the host/runtime JS contracts exposed by observatory reports. + +#### 1. CustomBlock JS Contract + +`CustomBlock.record.js_func` signature: + +```javascript +function renderRecord(container, args, context, analysis) { + // container: HTMLElement + // args: static serializable args from Python + // context: { index, record } + // analysis: { global_data, per_record_data } +} +``` + +`CustomBlock.compare` with `mode="custom"` signature: + +```javascript +function renderCompare(container, args, context, analysis) { + // container: HTMLElement + // args: static compare args + // context: { + // indices: number[], + // names: string[], + // records: object[], + // blocks: object[], + // lens: string, + // block_id: string, + // } + // analysis: { global_data, per_record_data } +} +``` + +Behavior: + +1. If `compare.js_func` is set, it is used. +2. If `compare.js_func` is not set, runtime falls back to `record.js_func`. + +#### 2. `window.ObservatoryAPI` + +Defined in `templates/js/05_bootstrap_api.js`. + +##### 2.1 `mountGraph(container, graphRef, options)` + +Mount graph viewer into host container. + +Arguments: + +1. `container`: selector string or `HTMLElement`. +2. `graphRef`: key into report `graph_assets`. +3. `options`: + - `default_layers` + - `default_color_by` + - `viewer_options` + +Returns `GraphHandle` with methods: + +1. `setLayers(layerIds)` +2. `setColorBy(layerId)` +3. `updateLayerNodeStyle(layerId, nodeId, patch)` +4. `selectNode(nodeId, opts)` +5. `zoomToFit()` +6. `setSyncEnabled(enabled)` +7. `enterFullscreen()` +8. `exitFullscreen()` +9. `onNodeSelected(callback)` +10. `destroy()` + +##### 2.2 Navigation helpers + +1. `selectRecord(index)` +2. `openCompare(indices)` +3. `showSingleRecord(index)` + +##### 2.3 Utility helpers + +1. `showToast(message, type)` +2. `getContext()` + +#### 3. Delegated HTML Actions + +Supported action attributes: + +1. `data-ob-action="select-record" data-ob-record="N"` +2. `data-ob-action="open-compare" data-ob-indices="A,B"` +3. `data-ob-action="graph-focus-node" data-ob-node-id="node_name"` + +#### 4. Minimal Example + +```html +
+ +``` + +```javascript +const handle = window.ObservatoryAPI.mountGraph('#graph-slot', 'record_1', { + default_layers: ['accuracy/error'], + default_color_by: 'accuracy/error', +}); + +handle.zoomToFit(); +``` + +### 2.5 Lens-to-GraphHub Guide + + +This guide explains how lenses contribute graph overlays through `analyze()`. + +#### 1. Architectural Rule + +Graph layers are derived data and must be attached in analysis results. + +Do: + +1. Build per-record graph overlays in `analyze()`. +2. Store them in `RecordAnalysis.graph_layers`. + +Do not: + +1. Use separate runtime hooks for layer contribution. +2. Return raw dict layer payloads in user-facing lens APIs. + +#### 2. Preferred Payload Types + +Use fx_viewer typed payload API: + +1. `GraphExtensionPayload` (preferred persisted form) +2. `GraphExtension` (authoring helper, converted lazily) + +Imports: + +```python +from executorch.backends.qualcomm.utils.fx_viewer import ( + GraphExtension, + GraphExtensionPayload, + GraphExtensionNodePayload, +) +``` + +#### 3. Pattern A: Build payload directly + +```python +from executorch.backends.qualcomm.debugger.observatory.interfaces import ( + AnalysisResult, + RecordAnalysis, +) +from executorch.backends.qualcomm.utils.fx_viewer import ( + GraphExtensionPayload, + GraphExtensionNodePayload, +) + +@staticmethod +def analyze(records, config): + per_record = {} + for record in records: + payload = GraphExtensionPayload( + id="error", + name="Accuracy Error", + legend=[{"label": "Low", "color": "#93c5fd"}, {"label": "High", "color": "#b91c1c"}], + nodes={ + "node_0": GraphExtensionNodePayload( + info={"mse": 0.12}, + label_append=["mse=0.12"], + fill_color="#b91c1c", + ) + }, + ) + + analysis = RecordAnalysis(data={"max_mse": 0.12}) + analysis.add_graph_layer("error", payload) + per_record[record.name] = analysis + + return AnalysisResult(per_record_data=per_record) +``` + +#### 4. Pattern B: Use GraphExtension helper + +```python +from executorch.backends.qualcomm.utils.fx_viewer import GraphExtension, NumericColorRule + +ext = GraphExtension(id="latency", name="Layer Latency") +ext.add_node_data("node_0", {"latency_ms": 1.23}) +ext.set_color_rule(NumericColorRule(attribute="latency_ms")) + +analysis = RecordAnalysis(data={"p95_ms": 1.23}) +analysis.add_graph_layer("latency", ext) +``` + +`GraphHub` converts `GraphExtension` to `GraphExtensionPayload` internally. + +#### 5. Layer ID Policy + +User-facing key in `RecordAnalysis`: + +1. `graph_layers["error"] = ...` + +Internal report layer ID by GraphHub: + +1. `/` +2. Example: `accuracy/error` + +This keeps lens APIs free from hardcoded namespacing rules. + +#### 6. How GraphHub Resolves Target Graph + +1. Graph assets are registered by `graph_ref` from graph digest. +2. During payload assembly, observatory reads each record's `RecordAnalysis`. +3. `GraphHub.add_analysis_layers(graph_ref, lens_name, record_analysis)` merges layers for that graph. + +#### 7. Frontend Binding + +`GraphBlock.record.graph_ref` selects which graph asset/layers to render. + +Example: + +```python +GraphView( + id="acc_graph", + title="Accuracy Graph", + graph_ref="Candidate FakeQuant", + default_layers=["accuracy/error"], + default_color_by="accuracy/error", +) +``` + +### 2.6 UI Testcases + +See: +1. `examples/OBSERVATORY_UI_TESTCASES.md` + +## 3. Performance Notes — Viewer Lifecycle and Caching + +### 3.1 Single-record viewers: live DOM cache with LRU eviction + +Single-record graph blocks use a **live viewer cache** keyed by +`(recordIndex, lensName, blockId)`. + +On first visit a `FXGraphViewer` instance is created and its DOM wrapper is stored in +`state.viewerCache`. On return visit the existing wrapper is re-appended to the new +container via `appendChild` (moves the DOM node, no clone). The viewer's full state — +camera pan/zoom, selected node, active extension layers, colorBy, search query — is +preserved exactly as left. No `init()`, no `computeActiveGraph()`, no re-layout. + +On navigate away `destroyGraphRuntime()` detaches the wrapper from the DOM but does +**not** call `viewer.destroy()`. The viewer stays alive in memory. + +**LRU eviction**: the cache holds at most `MAX_CACHED_VIEWERS = 10` live viewers. When +the cap is reached, the least-recently-accessed viewer is destroyed and removed. For a +typical report with 5 records × 2 graph blocks = 10 viewers, nothing is ever evicted. + +**Memory budget per cached viewer** (approximate, 1200×640 viewport, dpr=2): + +| Component | Size | +|---|---| +| Canvas pixel buffer | ~12 MB | +| Node/edge JS objects (3600 nodes) | ~2 MB | +| DOM elements | ~0.1 MB | +| **Total** | **~14 MB** | + +10 cached viewers ≈ 140 MB. For reports with more records, the LRU cap bounds memory use. + +### 3.2 Compare-mode viewers: state-snapshot cache + +Compare-mode viewers are **always freshly created** on each visit. Keeping multiple +side-by-side viewers alive simultaneously would multiply the memory cost above by the +number of panes, with limited benefit since compare views are typically visited briefly. + +Instead, a lightweight **state snapshot** is saved to `state.compareStateCache` whenever +any compare viewer's state changes. The snapshot stores: +`{ camera: {x,y,k}, selectedNodeId, activeExtensions, colorBy }` — a few bytes per block. + +Cache key: `"compare::"` — intentionally **not** keyed by pool +composition or graphRef. Consequences: + +- Adding or removing records from the compare pool restores the same camera and layers. +- Switching from single-record mode back to compare mode restores the compare state. + +**Merge-on-write for `selectedNodeId`**: all viewers in the same compare block share one +snapshot and each registers a `statechange` listener that writes to it. To prevent a +viewer with no selection (e.g. viewer B panning after viewer A selected a node) from +overwriting the saved selection with `null`, the write is a **merge**: `selectedNodeId` +is only updated when the incoming value is non-null. `camera`, `activeExtensions`, and +`colorBy` are always overwritten with the latest value. The selection is cleared from the +snapshot only when another viewer explicitly selects a different node. + +On re-entry, each new viewer is seeded from the snapshot in priority order: +1. `selectedNodeId` present and node exists in **this viewer's** graph → + `selectNode(id, { animate: true })`. +2. Node not found in this viewer's graph (different record) or no selection → + `zoomToFit()`. +3. No snapshot at all → `init()` runs normally (zoom-to-fit or first-node centering). + +Layers (`activeExtensions`, `colorBy`) are always restored from the snapshot when +available, regardless of which priority path is taken for camera/selection. + +### 3.3 Section collapse state + +Section open/close state is persisted to `localStorage` under key +`graphCollectorViewPrefs` as `"${lensName}:${blockId}" → bool`. This is independent of +the viewer cache and survives full page reloads. + +### 3.4 Trade-off summary + +| Scenario | Approach | Memory | Switch-back latency | +|---|---|---|---| +| Single-record graph block | Live DOM cache + LRU | ~14 MB/viewer, cap 10 | ~0 ms (rAF resize only) | +| Compare graph block | State-snapshot + fresh create | ~bytes/block | ~50–200 ms (create + init) | +| Section collapse | localStorage | 0 | 0 | diff --git a/devtools/observatory/USAGE.md b/devtools/observatory/USAGE.md new file mode 100644 index 00000000000..72a8cd8adda --- /dev/null +++ b/devtools/observatory/USAGE.md @@ -0,0 +1,205 @@ +# Observatory CLI Usage Guide + +The Observatory CLI wraps any ExecuTorch export script in an Observatory context, +automatically collecting graph snapshots at each compilation stage. + +## 1. Zero-Config E2E Workflow + +The simplest invocation: point the CLI at your script and pass its arguments through. + +```bash +python -m executorch.devtools.observatory \ + my_export_script.py [SCRIPT_ARGS...] +``` + +Use `--output-html` / `--output-json` to control output paths: + +```bash +python -m executorch.devtools.observatory \ + --output-html /tmp/obs/report.html \ + --output-json /tmp/obs/report.json \ + examples/qualcomm/oss_scripts/swin_v2_t.py \ + --model SM8650 -b ./build-android -d imagenet-mini/val -a ./swin_v2_t +``` + +Use backend-specific observatory cli for additional customized lenses and hooks (qualcomm for example) + +```bash +python -m executorch.backends.xnnpack.debugger.observatory \ + --output-html /tmp/obs/report.html \ + --lense_recipe=accuracy \ + examples/xnnpack/aot_compiler.py \ + --model_name=mv2 --delegate --quantize --output_dir /tmp/mv2 +``` + +> **XNNPack note**: `examples/xnnpack/aot_compiler.py` uses relative imports (`from . import ...`). +> The CLI auto-detects this: when a `.py` path is given and its directory contains `__init__.py`, +> it runs via `runpy.run_module` instead of `runpy.run_path`. You can also pass a dotted module +> name directly (e.g. `examples.xnnpack.aot_compiler`) to force module mode explicitly. + +## 2. Convert JSON to HTML (Visualize Mode) + +Use the `visualize` subcommand to convert an existing JSON file to HTML without +re-running the export script. This re-runs the analysis phase (lens `analyze()` methods) +against the persisted data, so HTML reports can be updated after lens code changes. + +```bash +python -m executorch.backends.qualcomm.debugger.observatory visualize \ + --input-json /tmp/obs/report.json \ + --output-html /tmp/obs/report.html +``` + +Options: +- `--input-json` — path to the raw JSON file (required) +- `--output-html` — path for the generated HTML file (required) + +## 3. Two-Step Workflow (CI collect, local visualize) + +**Step 1 — CI: collect and export** +```bash +python -m executorch.backends.qualcomm.debugger.observatory \ + --output-json artifacts/report.json \ + --output-html artifacts/report.html \ + my_export_script.py --output_dir artifacts/ +``` + +**Step 2 — Local: re-generate HTML from JSON** +```bash +python -m executorch.backends.qualcomm.debugger.observatory visualize \ + --input-json artifacts/report.json \ + --output-html artifacts/report_v2.html +``` + +This separates the history archive results of on-device execution (Step 1) from the interactive +visualization (Step 2), which can be re-run on demand (e.g. comparing models between 2 history commits). + +## 4. Disabling Lenses via Config + +When using the Observatory Python API directly, pass a config dict to +`enable_context()` or `export_html_report()`: + +```python +from executorch.devtools.observatory import Observatory + + +config = { + "accuracy": {"enabled": False}, + "per_layer_accuracy": {"enabled": False}, +} + +with Observatory.enable_context(config=config): + # ... your export code ... + +Observatory.export_html_report("report.html", config=config) +``` + +Config keys correspond to lens names returned by `lens.get_name()`. Each lens +checks `config.get(lens_name, {}).get("enabled", True)` during setup. + +## 5. Manual Observation Collection Points + +You can insert `Observatory.collect()` calls anywhere in your code to capture +intermediate graph states. This is useful for debugging pass transforms or +custom lowering steps. + +### Basic usage + +```python +import torch +from executorch.devtools.observatory import Observatory + +model = MyModel().eval() +graph = torch.fx.symbolic_trace(model) + +Observatory.clear() +with Observatory.enable_context(): + Observatory.collect("original", graph) + + # Apply a pass + transformed = my_pass(graph) + Observatory.collect("after_my_pass", transformed) + +Observatory.export_html_report("pass_debug.html") +Observatory.export_json("pass_debug.json") +``` + +### Pass transform debugging + +Use `observe_pass` to automatically collect graphs before and after a pass. +Wrap any `PassBase` subclass instance, callable, or use it as a class decorator: + +```python +from executorch.devtools.observatory import Observatory, observe_pass +from executorch.exir.passes.remove_graph_asserts_pass import RemoveGraphAssertsPass + +# Wrap pass instances — default collects both input and output graphs +pass_a = observe_pass(RemoveGraphAssertsPass()) +pass_b = observe_pass(MyCustomPass()) + +Observatory.clear() +with Observatory.enable_context(): + result_a = pass_a(graph_module) + # collects "RemoveGraphAssertsPass/input" and "RemoveGraphAssertsPass/output" + + result_b = pass_b(result_a.graph_module) + # collects "MyCustomPass/input" and "MyCustomPass/output" + + # Call again — names auto-deduplicate + result_c = pass_a(result_b.graph_module) + # collects "RemoveGraphAssertsPass/input #2" and "RemoveGraphAssertsPass/output #2" + +Observatory.export_html_report("pass_debug.html") +``` + +Control what is collected with boolean flags: + +```python +# Collect only the output graph +observed = observe_pass(SomePass(), collect_input=False) + +# Collect only the input graph +observed = observe_pass(SomePass(), collect_output=False) + +# Override the record name +observed = observe_pass(SomePass(), name="step_1") +``` + +Use as a class decorator to make all instances observable: + +```python +@observe_pass +class MyPass(PassBase): + def call(self, gm): + # ... transform logic ... + return PassResult(gm, True) + +# Or with parameters: +@observe_pass(name="Quantize", collect_input=False) +class QuantizePass(PassBase): + def call(self, gm): + ... +``` + +`observe_pass` is a no-op when no Observatory context is active. + +### Inside the CLI-wrapped script (zero-code-change) + +When running via the CLI, `Observatory.enable_context()` is already active. +You can add collection points to your script without any setup: + +```python +# In your export script (e.g., my_model.py): +from executorch.devtools.observatory import Observatory + +# This fires only when Observatory context is active (i.e., when run via CLI). +# It is a no-op otherwise. +Observatory.collect("pre_quantize", exported_program) +``` +## 6. Quick Reference + +| Scenario | Command | +|----------|---------| +| E2E single script | `cli SCRIPT [SCRIPT_ARGS]` | +| E2E with explicit paths | `cli --output-html X.html --output-json X.json SCRIPT ...` | +| JSON -> HTML | `cli visualize --input-json X.json --output-html X.html` | +| With accuracy (backend CLI) | `backend-cli --lense_recipe=accuracy SCRIPT ...` | diff --git a/devtools/observatory/__init__.py b/devtools/observatory/__init__.py new file mode 100644 index 00000000000..26b88c789b5 --- /dev/null +++ b/devtools/observatory/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from .observe_pass import observe_pass +from .observatory import Observatory + +__all__ = ["Observatory", "observe_pass"] diff --git a/devtools/observatory/__main__.py b/devtools/observatory/__main__.py new file mode 100644 index 00000000000..e69e40ec80e --- /dev/null +++ b/devtools/observatory/__main__.py @@ -0,0 +1,9 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from .cli import main + +main() diff --git a/devtools/observatory/cli.py b/devtools/observatory/cli.py new file mode 100644 index 00000000000..fccb8266d79 --- /dev/null +++ b/devtools/observatory/cli.py @@ -0,0 +1,225 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Observatory CLI -- generic runner with shared helpers for backend CLIs. + +Collection mode (default): + python -m executorch.devtools.observatory \\ + [--output-html PATH] [--output-json PATH] SCRIPT [SCRIPT_ARGS...] + +Visualize mode (JSON -> HTML, no script execution): + python -m executorch.devtools.observatory visualize \\ + --input-json report.json --output-html report.html +""" + +from __future__ import annotations + +import argparse +import logging +import os +import runpy +import sys + +logging.basicConfig(level=logging.INFO, format="%(message)s") + + +# --------------------------------------------------------------------------- +# Shared helpers (importable by backend CLIs) +# --------------------------------------------------------------------------- + + +def make_collect_parser(prog=None): + """Create the base argparse parser for collection mode. + + Backend CLIs can extend this with additional arguments before calling + parse_args, e.g.:: + + parser = make_collect_parser(prog="my-backend") + parser.add_argument("--lense_recipe", choices=["accuracy"]) + args = parser.parse_args(sys.argv[1:]) + """ + parser = argparse.ArgumentParser( + prog=prog, + description="Run a script under Observatory and export reports.", + ) + parser.add_argument("--output-html", default=None, help="HTML report output path") + parser.add_argument("--output-json", default=None, help="JSON report output path") + parser.add_argument("script", help="Script to run under Observatory") + parser.add_argument( + "script_args", + nargs=argparse.REMAINDER, + help="Arguments passed to the script", + ) + return parser + + +def make_visualize_parser(prog=None): + """Create the argparse parser for visualize mode.""" + parser = argparse.ArgumentParser( + prog=prog, + description="Generate an HTML report from an existing JSON export.", + ) + parser.add_argument("--input-json", required=True, help="Input JSON report path") + parser.add_argument( + "--output-html", required=True, help="Output HTML report path" + ) + return parser + + +def run_visualize(input_json: str, output_html: str) -> None: + if not os.path.isfile(input_json): + logging.error("[Observatory CLI] Input JSON not found: %s", input_json) + sys.exit(1) + + from .observatory import Observatory + + Observatory.clear() + Observatory.generate_html_from_json(input_json, output_html) + logging.info( + "[Observatory CLI] visualize: html=%s from json=%s", output_html, input_json + ) + + +def _resolve_run_mode(script: str) -> tuple[str, str, str | None]: + """Return (mode, target, pkg_root). + + mode='module': target is a dotted module name; pkg_root is prepended to sys.path. + mode='script': target is an absolute file path; pkg_root is None. + + Detection rules (in order): + 1. No '.py' suffix and no path separator → explicit dotted module name → 'module' + 2. File path whose directory has __init__.py → walk up to find package root → 'module' + 3. Everything else → 'script' + """ + if not script.endswith(".py") and os.sep not in script and "/" not in script: + return "module", script, None + + abs_path = os.path.abspath(script) + script_dir = os.path.dirname(abs_path) + + if not os.path.isfile(os.path.join(script_dir, "__init__.py")): + return "script", abs_path, None + + parts = [os.path.splitext(os.path.basename(abs_path))[0]] + current = script_dir + while True: + parts.insert(0, os.path.basename(current)) + parent = os.path.dirname(current) + if parent == current or not os.path.isfile( + os.path.join(parent, "__init__.py") + ): + pkg_root = parent + break + current = parent + + return "module", ".".join(parts), pkg_root + + +def run_observatory( + script_path: str, + script_argv: list[str], + Observatory, + output_html: str | None = None, + output_json: str | None = None, +) -> None: + """Shared run logic for all CLIs.""" + sys.argv = [script_path] + script_argv + + if output_html is None: + output_html = "observatory_report.html" + if output_json is None: + if output_html.endswith(".html"): + output_json = output_html[:-5] + ".json" + else: + output_json = "observatory_report.json" + + title = f"Observatory: {os.path.basename(script_path)}" + mode, target, pkg_root = _resolve_run_mode(script_path) + + try: + with Observatory.enable_context(config={}): + if mode == "module": + if pkg_root is not None and pkg_root not in sys.path: + sys.path.insert(0, pkg_root) + logging.info("[Observatory CLI] Running as module: %s", target) + runpy.run_module(target, run_name="__main__", alter_sys=True) + else: + logging.info("[Observatory CLI] Running as script: %s", target) + runpy.run_path(target, run_name="__main__") + except SystemExit: + pass + except ImportError as exc: + logging.error("[Observatory CLI] Import error in '%s': %s", script_path, exc) + if "relative import" in str(exc) or "attempted relative import" in str(exc): + logging.error( + " Hint: this script uses relative imports and must run as a Python module.\n" + " Option A — ensure the script's directory contains __init__.py so it is\n" + " auto-detected as a package member.\n" + " Option B — pass a dotted module name instead of a file path:\n" + " python -m executorch.devtools.observatory" + " .. [args...]" + ) + elif mode == "module": + logging.error( + " Hint: module '%s' could not be imported.\n" + " Ensure the package root '%s' is on PYTHONPATH or the package is installed.", + target, + pkg_root or "unknown", + ) + except Exception as exc: + logging.error( + "[Observatory CLI] '%s' raised: %s (run mode: %s, target: %s)", + os.path.basename(script_path), + exc, + mode, + target, + ) + finally: + os.makedirs(os.path.dirname(output_html) or ".", exist_ok=True) + os.makedirs(os.path.dirname(output_json) or ".", exist_ok=True) + Observatory.export_json(output_json) + Observatory.export_html_report(output_html, title=title, config={}) + collected = Observatory.list_collected() + if collected: + logging.info( + "[Observatory CLI] Reports: html=%s json=%s (%d records: %s)", + output_html, + output_json, + len(collected), + ", ".join(collected), + ) + else: + logging.warning("[Observatory CLI] No records collected") + + +# --------------------------------------------------------------------------- +# Generic CLI entry point +# --------------------------------------------------------------------------- + + +def main(): + if len(sys.argv) > 1 and sys.argv[1] == "visualize": + parser = make_visualize_parser() + args = parser.parse_args(sys.argv[2:]) + run_visualize(args.input_json, args.output_html) + return + + parser = make_collect_parser() + args = parser.parse_args(sys.argv[1:]) + + from .observatory import Observatory + from .lenses.pipeline_graph_collector import PipelineGraphCollectorLens + + Observatory.clear() + Observatory.register_lens(PipelineGraphCollectorLens) + + run_observatory( + args.script, args.script_args, Observatory, args.output_html, args.output_json + ) + + +if __name__ == "__main__": + main() diff --git a/devtools/observatory/graph_hub.py b/devtools/observatory/graph_hub.py new file mode 100644 index 00000000000..ccbeb4c51dd --- /dev/null +++ b/devtools/observatory/graph_hub.py @@ -0,0 +1,79 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from dataclasses import asdict +from typing import Any, Dict + +from .interfaces import RecordAnalysis + + +class GraphHub: + """Framework graph asset/layer registry for report assembly.""" + + def __init__(self) -> None: + self._graph_assets: Dict[str, Dict[str, Any]] = {} + self._graph_layers: Dict[str, Dict[str, Dict[str, Any]]] = {} + + def register_asset(self, graph_ref: str, base_payload: Dict[str, Any], meta: Dict[str, Any]) -> None: + if not graph_ref or not isinstance(base_payload, dict): + return + self._graph_assets[graph_ref] = { + "base": base_payload, + "meta": meta or {}, + } + + def add_analysis_layers( + self, + graph_ref: str, + lens_name: str, + analysis: RecordAnalysis | None, + ) -> None: + """Merge per-record analysis graph layers into hub storage. + + Layer IDs are namespaced internally as `/`. + """ + + if not graph_ref or analysis is None: + return + + slot = self._graph_layers.setdefault(graph_ref, {}) + for layer_key, contribution in analysis.graph_layers.items(): + if not layer_key.strip(): + continue + + payload = contribution.to_payload() + namespaced_id = f"{lens_name}/{layer_key}" + + slot[namespaced_id] = { + "name": payload.name, + "legend": payload.legend, + "sync_keys": payload.sync_keys, + "nodes": { + node_id: asdict(node_payload) + for node_id, node_payload in payload.nodes.items() + }, + } + + def get_asset(self, graph_ref: str) -> Dict[str, Any]: + return self._graph_assets.get(graph_ref, {}) + + def build_payload(self) -> Dict[str, Any]: + return { + "graph_assets": self._graph_assets, + "graph_layers": self._graph_layers, + } + + @staticmethod + def build_viewer_payload(graph_assets: Dict[str, Any], graph_layers: Dict[str, Any], graph_ref: str) -> Dict[str, Any]: + asset = graph_assets.get(graph_ref, {}) + if not asset: + return {"base": {"legend": [], "nodes": [], "edges": []}, "extensions": {}} + return { + "base": asset.get("base", {"legend": [], "nodes": [], "edges": []}), + "extensions": graph_layers.get(graph_ref, {}), + } diff --git a/devtools/observatory/html_template.py b/devtools/observatory/html_template.py new file mode 100644 index 00000000000..7507c10d688 --- /dev/null +++ b/devtools/observatory/html_template.py @@ -0,0 +1,96 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +import html + +from .template_loader import load_css, load_js_chunks + + +def get_html_template(title: str, payload_json: str, is_compressed: bool = False) -> str: + """Generate observatory HTML shell. + + Args: + title: Report title shown in and page heading. + payload_json: Either the raw JSON string (is_compressed=False) or a + gzip+base64 encoded string (is_compressed=True). + is_compressed: When True, payload_json is a gzip+base64 blob that the + browser decompresses via DecompressionStream before parsing. + """ + + css = load_css() + js_bundle = "\n".join(load_js_chunks()) + + if is_compressed: + data_script = f'window.__OBS_RAW__ = "{payload_json}";' + decompress_block = """ + async function _obsDecompress(b64gz) { + const compressed = Uint8Array.from(atob(b64gz), c => c.charCodeAt(0)); + const ds = new DecompressionStream('gzip'); + const writer = ds.writable.getWriter(); + writer.write(compressed); + writer.close(); + const chunks = []; + const reader = ds.readable.getReader(); + while (true) { + const { done, value } = await reader.read(); + if (done) break; + chunks.push(value); + } + const total = chunks.reduce((n, c) => n + c.length, 0); + const out = new Uint8Array(total); + let off = 0; + for (const c of chunks) { out.set(c, off); off += c.length; } + return new TextDecoder().decode(out); + } + window.OBSERVATORY_DATA = JSON.parse(await _obsDecompress(window.__OBS_RAW__)); +""" + else: + data_script = f'window.OBSERVATORY_DATA = {payload_json};' + decompress_block = "" + + return f"""<!DOCTYPE html> +<html lang=\"en\"> +<head> + <meta charset=\"UTF-8\"> + <meta name=\"viewport\" content=\"width=device-width, initial-scale=1.0\"> + <title>{html.escape(title)} + + + +
+ + + + + +""" diff --git a/devtools/observatory/interfaces.py b/devtools/observatory/interfaces.py new file mode 100644 index 00000000000..b78502c33a9 --- /dev/null +++ b/devtools/observatory/interfaces.py @@ -0,0 +1,637 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Typed API contracts for the Observatory runtime. + +This module is the source-of-truth contract for: +1. Frontend view composition (`ViewList` + typed blocks/specs). +2. Runtime record/session objects (`ObservationContext`, `RecordDigest`, etc.). +3. Analyze-phase graph layer contribution via fx_viewer payload types. + +Architecture model: +1. Runtime phase (`observe`/`digest`) captures raw record data. +2. Analyze phase (`analyze`) computes global and per-record derived data. +3. Frontend phase (`Frontend.*`) maps typed data into renderable view blocks. + +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union + +if TYPE_CHECKING: + from executorch.devtools.fx_viewer.extension import GraphExtension + from executorch.devtools.fx_viewer.models import GraphExtensionPayload + + +# Type Alias for JSON-serializable leaf/object values. +Serializable = Union[Dict[str, Any], List[Any], str, int, float, bool, None] + + +# --------------------------------------------------------------------------- +# Frontend block contracts +# --------------------------------------------------------------------------- + + +@dataclass +class TableRecordSpec: + """Record payload for a table block. + + Args: + data: Key-value pairs rendered in the default table renderer. + """ + + data: Dict[str, Serializable] = field(default_factory=dict) + + +@dataclass +class TableCompareSpec: + """Compare behavior for table blocks. + + Modes: + 1. `auto`: runtime renders side-by-side table diff view. + 2. `disabled`: hide compare section for this block. + """ + + mode: Literal["auto", "disabled"] = "auto" + + +@dataclass +class HtmlRecordSpec: + """Record payload for an HTML block. + + Args: + content: Raw HTML fragment. In the report payload this field is + base64-encoded to prevent special characters from corrupting + the JSON embedding. The runtime decodes it before innerHTML assignment. + """ + + content: str = "" + + +@dataclass +class HtmlCompareSpec: + """Compare behavior for HTML blocks. + + Modes: + 1. `auto`: runtime renders selected HTML blocks side-by-side. + 2. `disabled`: hide compare section for this block. + """ + + mode: Literal["auto", "disabled"] = "auto" + + +@dataclass +class CustomRecordSpec: + """Record payload for a custom JS block. + + JS signature: + function renderRecord(container, args, context, analysis) + + JS argument mapping: + 1. `container`: host DOM container created by observatory runtime. + 2. `args`: exactly this dataclass field (`CustomRecordSpec.args`). + 3. `context`: runtime-selected context object. + - Record view path: `{ index, record }` + - `index`: record index in report payload. + - `record`: full serialized record object from report payload. + - Dashboard path: `{ start, end, records }` + - `start`/`end`: lens session payloads. + - `records`: full serialized records list. + 4. `analysis`: `report.analysis_results[lens_name]` object: + `{ global_data, per_record_data }`. + + Practical access pattern for record callbacks: + 1. Digest: `context.record.digests[lens_name]`. + 2. Per-record analysis: + `analysis.per_record_data?.[context.record.name]?.data`. + 3. Global analysis: `analysis.global_data`. + + Fields: + 1. `js_func`: global function path. + 2. `args`: static serializable args. + """ + + js_func: str = "" + args: Dict[str, Serializable] = field(default_factory=dict) + + +@dataclass +class CustomCompareSpec: + """Compare behavior for custom JS blocks. + + JS signature: + function renderCompare(container, args, context, analysis) + + JS argument mapping: + 1. `container`: compare section DOM container. + 2. `args`: exactly this dataclass field (`CustomCompareSpec.args`). + 3. `context`: + - `indices`: selected global record indices. + - `names`: selected record names. + - `records`: selected serialized record objects. + - `blocks`: selected block payloads for this block ID. + - `lens`: lens name. + - `block_id`: block ID. + 4. `analysis`: `report.analysis_results[lens_name]` object: + `{ global_data, per_record_data }`. + + Fields: + 1. `mode`: `custom` or `disabled`. + 2. `js_func`: compare function path. If omitted in `custom` mode, runtime + falls back to `record.js_func`. + 3. `args`: static compare args. + """ + + mode: Literal["custom", "disabled"] = "disabled" + js_func: Optional[str] = None + args: Dict[str, Serializable] = field(default_factory=dict) + + +GraphLayerScope = Union[Literal["all", "lens_only"], List[str]] + + +@dataclass +class GraphRecordSpec: + """Record payload for GraphView blocks. + + Core fields: + 1. `graph_ref`: key into report `graph_assets` and `graph_layers`. + 2. `default_layers`: initial extension layer IDs. + 3. `default_color_by`: initial color-by layer ID. + 4. `layer_scope`: `all`, `lens_only`, or explicit allowlist. + 5. `viewer_options`: passthrough options for embedded FX viewer. + """ + + graph_ref: str + default_layers: List[str] = field(default_factory=list) + default_color_by: Optional[str] = None + layer_scope: GraphLayerScope = "all" + viewer_options: Dict[str, Serializable] = field(default_factory=dict) + controls: Dict[str, Serializable] = field(default_factory=dict) + fullscreen: Dict[str, Serializable] = field(default_factory=dict) + + +@dataclass +class GraphCompareSpec: + """Compare behavior for graph blocks. + + Modes: + 1. `auto`: runtime mounts side-by-side graph compare with optional sync. + 2. `custom`: call user JS function for compare rendering. + 3. `disabled`: hide compare section. + """ + + mode: Literal["auto", "disabled", "custom"] = "auto" + max_parallel: int = 2 + sync_toggle: bool = True + default_sync: Dict[str, str] = field(default_factory=dict) + viewer_options_compare: Dict[str, Serializable] = field(default_factory=dict) + js_func: Optional[str] = None + args: Dict[str, Serializable] = field(default_factory=dict) + + +@dataclass +class TableBlock: + """Typed view block for key-value table rendering.""" + + id: str + title: str + record: TableRecordSpec + compare: TableCompareSpec = field(default_factory=TableCompareSpec) + order: int = 0 + collapsible: bool = True + type: Literal["table"] = "table" + + +@dataclass +class HtmlBlock: + """Typed view block for raw HTML rendering.""" + + id: str + title: str + record: HtmlRecordSpec + compare: HtmlCompareSpec = field(default_factory=HtmlCompareSpec) + order: int = 0 + collapsible: bool = True + type: Literal["html"] = "html" + + +@dataclass +class CustomBlock: + """Typed view block for custom JS rendering.""" + + id: str + title: str + record: CustomRecordSpec + compare: CustomCompareSpec = field(default_factory=CustomCompareSpec) + order: int = 0 + collapsible: bool = True + type: Literal["custom"] = "custom" + + +@dataclass +class GraphBlock: + """Typed view block for graph viewer rendering.""" + + id: str + title: str + record: GraphRecordSpec + compare: GraphCompareSpec = field(default_factory=GraphCompareSpec) + order: int = 0 + collapsible: bool = True + type: Literal["graph"] = "graph" + + +ViewBlock = Union[TableBlock, HtmlBlock, CustomBlock, GraphBlock] + + +@dataclass +class ViewList: + """Ordered block list returned by lens frontends. + + Rules: + 1. Block IDs must be unique within one ViewList. + 2. Rendering order is controlled by `block.order`. + """ + + blocks: List[ViewBlock] = field(default_factory=list) + + +@dataclass +class GraphView: + """Convenience authoring helper for one graph block. + + This helper is intended for lens authors who want the ergonomics of a + focused graph API while still returning canonical `GraphBlock`. + """ + + id: str + title: str + graph_ref: str + default_layers: List[str] = field(default_factory=list) + default_color_by: Optional[str] = None + layer_scope: GraphLayerScope = "all" + viewer_options: Dict[str, Serializable] = field(default_factory=dict) + controls: Dict[str, Serializable] = field(default_factory=dict) + fullscreen: Dict[str, Serializable] = field(default_factory=dict) + compare: GraphCompareSpec = field(default_factory=GraphCompareSpec) + order: int = 0 + collapsible: bool = True + + def as_block(self) -> GraphBlock: + """Build canonical `GraphBlock` from convenience fields.""" + + return GraphBlock( + id=self.id, + title=self.title, + record=GraphRecordSpec( + graph_ref=self.graph_ref, + default_layers=self.default_layers, + default_color_by=self.default_color_by, + layer_scope=self.layer_scope, + viewer_options=self.viewer_options, + controls=self.controls, + fullscreen=self.fullscreen, + ), + compare=self.compare, + order=self.order, + collapsible=self.collapsible, + ) + + +def _require_non_empty_text(value: str, field_name: str, block_id: str) -> None: + if not isinstance(value, str) or not value.strip(): + raise ValueError(f"ViewBlock '{block_id}' requires non-empty {field_name}") + + +def _require_str_list(value: Any, field_name: str, block_id: str) -> None: + if not isinstance(value, list) or any((not isinstance(x, str) or not x.strip()) for x in value): + raise ValueError(f"ViewBlock '{block_id}' requires {field_name} as list[str]") + + +def validate_view_block(block: ViewBlock) -> None: + """Validate one typed frontend block. + + Validation covers: + 1. Required identity fields (`id`, `title`). + 2. Per-block invariant checks. + 3. Compare-mode specific requirements (for custom/graph blocks). + """ + + if not isinstance(block, (TableBlock, HtmlBlock, CustomBlock, GraphBlock)): + raise TypeError(f"Unsupported ViewBlock type: {type(block)}") + + _require_non_empty_text(block.id, "id", "") + _require_non_empty_text(block.title, "title", block.id) + + if not isinstance(block.order, int): + raise TypeError(f"ViewBlock '{block.id}' order must be int") + + if isinstance(block, CustomBlock): + _require_non_empty_text(block.record.js_func, "record.js_func", block.id) + if block.compare.mode == "custom": + compare_js_func = (block.compare.js_func or "").strip() or block.record.js_func.strip() + if not compare_js_func: + raise ValueError(f"CustomBlock '{block.id}' compare mode custom requires js_func") + + if isinstance(block, GraphBlock): + _require_non_empty_text(block.record.graph_ref, "record.graph_ref", block.id) + + if block.record.default_layers: + _require_str_list(block.record.default_layers, "record.default_layers", block.id) + + scope = block.record.layer_scope + if isinstance(scope, str): + if scope not in {"all", "lens_only"}: + raise ValueError( + f"GraphBlock '{block.id}' layer_scope must be 'all', 'lens_only', or list[str]" + ) + elif isinstance(scope, list): + _require_str_list(scope, "record.layer_scope", block.id) + else: + raise ValueError( + f"GraphBlock '{block.id}' layer_scope must be 'all', 'lens_only', or list[str]" + ) + + if int(block.compare.max_parallel) < 1: + raise ValueError(f"GraphBlock '{block.id}' compare.max_parallel must be >= 1") + + if block.compare.mode == "custom" and not (block.compare.js_func or "").strip(): + raise ValueError(f"GraphBlock '{block.id}' compare mode custom requires js_func") + + +def validate_view_list(view_list: ViewList) -> None: + """Validate a full frontend `ViewList` contract.""" + + if not isinstance(view_list, ViewList): + raise TypeError(f"Expected ViewList, got {type(view_list)}") + + seen_ids = set() + for block in view_list.blocks: + validate_view_block(block) + if block.id in seen_ids: + raise ValueError(f"Duplicate ViewBlock id in one ViewList: {block.id}") + seen_ids.add(block.id) + + +# --------------------------------------------------------------------------- +# Analysis contracts +# --------------------------------------------------------------------------- + + +@dataclass +class GraphLayerContribution: + """Graph layer contribution attached during analyze-phase. + + `extension` accepts either: + 1. `GraphExtensionPayload` (preferred stable payload type). + 2. `GraphExtension` (authoring helper converted lazily). + """ + + extension: Union["GraphExtension", "GraphExtensionPayload"] + id_override: Optional[str] = None + name_override: Optional[str] = None + + def to_payload(self) -> "GraphExtensionPayload": + """Resolve contribution into a `GraphExtensionPayload`.""" + + from executorch.devtools.fx_viewer.extension import GraphExtension + from executorch.devtools.fx_viewer.models import GraphExtensionPayload + + payload: GraphExtensionPayload + if isinstance(self.extension, GraphExtensionPayload): + payload = self.extension + elif isinstance(self.extension, GraphExtension): + payload = self.extension.build_payload() + else: + raise TypeError( + "GraphLayerContribution.extension must be GraphExtensionPayload or GraphExtension" + ) + + if self.id_override or self.name_override: + return GraphExtensionPayload( + id=self.id_override or payload.id, + name=self.name_override or payload.name, + legend=payload.legend, + nodes=payload.nodes, + ) + + return payload + + +@dataclass +class RecordAnalysis: + """Per-record analysis output. + + Fields: + 1. `data`: record-specific derived values consumed by frontend record views. + 2. `graph_layers`: map from local layer key to typed graph contribution. + """ + + data: Dict[str, Serializable] = field(default_factory=dict) + graph_layers: Dict[str, GraphLayerContribution] = field(default_factory=dict) + + def add_graph_layer( + self, + key: str, + extension: Union["GraphExtension", "GraphExtensionPayload"], + *, + id_override: Optional[str] = None, + name_override: Optional[str] = None, + ) -> None: + """Add or replace a graph layer contribution for this record.""" + + if not key.strip(): + raise ValueError("RecordAnalysis graph layer key must be non-empty") + self.graph_layers[key] = GraphLayerContribution( + extension=extension, + id_override=id_override, + name_override=name_override, + ) + + +# --------------------------------------------------------------------------- +# Runtime core contracts +# --------------------------------------------------------------------------- + + +class Frontend: + """Visualization strategy object returned by each lens. + + Frontend methods are block-oriented: + 1. `dashboard(...) -> ViewList | None` + 2. `record(...) -> ViewList | None` + + Compare behavior is declared per block (`block.compare`) instead of a + separate lens-level `compare()` callback. + """ + + def resources(self) -> Dict[str, str]: + """Return optional shared JS/CSS resources. + + Returns: + Dict with optional keys: + 1. `js`: inline JavaScript source. + 2. `css`: inline CSS source. + """ + + return {} + + def dashboard( + self, + start: Dict[str, Any], + end: Dict[str, Any], + analysis: Dict[str, Any], + records: List[Any], + ) -> Optional[ViewList]: + """Build dashboard-level block list for one lens. + + Python-side inputs: + 1. `start` <- `SessionResult.start_data[lens_name]`. + 2. `end` <- `SessionResult.end_data[lens_name]`. + 3. `analysis` <- `AnalysisResult.global_data` (this lens). + 4. `records` <- collected `RecordDigest` list. + + Render dataflow: + 1. Return `ViewList(blocks=[...])`. + 2. Blocks are serialized into report payload. + 3. For `CustomBlock`, JS callback receives: + `fn(container, block.record.args, {start,end,records}, analysis_results[lens_name])`. + + Args: + start: Session start payload from `on_session_start`. + end: Session end payload from `on_session_end`. + analysis: Lens global analysis payload. + records: Serialized record list for context-aware summaries. + """ + return None + + def record( + self, + digest: Any, + analysis: Dict[str, Any], + context: Dict[str, Any], + ) -> Optional[ViewList]: + """Build record-level block list for one lens. + + Python-side inputs: + 1. `digest` <- current record digest for this lens. + 2. `analysis` <- `{ "global": global_data, "record": per_record_data[name].data }`. + 3. `context` <- `{ "index": int, "name": str }`. + + Render dataflow: + 1. Return `ViewList(blocks=[...])`. + 2. Runtime mounts block renderers per selected record. + 3. For `CustomBlock`, JS callback receives: + `fn(container, block.record.args, {index, record}, analysis_results[lens_name])`. + 4. JS `record` is the serialized report record object, so digest data is + available via `context.record.digests[lens_name]`. + + Args: + digest: Current record digest for this lens. + analysis: Dict with `global` and `record` derived analysis. + context: Record context metadata (`index`, `name`). + """ + return None + + def check_badges(self, digest: Any, analysis: Dict[str, Any]) -> List[Dict[str, str]]: + return [] + + def check_index_diffs( + self, + prev_digest: Any, + curr_digest: Any, + analysis: Dict[str, Any], + ) -> Dict[str, str]: + return {} + + +@dataclass +class ObservationContext: + """Context shared across runtime lens hooks. + + `shared_state` is a per-collect broker for cross-lens hints (for example, + exposing record name or artifact hints discovered by one lens). + """ + + config: Dict[str, Any] + shared_state: Dict[str, Any] = field(default_factory=dict) + + +@dataclass +class RecordDigest: + """Persistent observation item. + + This is the canonical persisted unit produced by runtime capture. + """ + + name: str + timestamp: float + data: Dict[str, Serializable] = field(default_factory=dict) + + +@dataclass +class SessionResult: + """Session start/end data from lens hooks.""" + + start_data: Dict[str, Serializable] = field(default_factory=dict) + end_data: Dict[str, Serializable] = field(default_factory=dict) + + +@dataclass +class AnalysisResult: + """Global + per-record analysis contract for a lens.""" + + global_data: Dict[str, Serializable] = field(default_factory=dict) + per_record_data: Dict[str, RecordAnalysis] = field(default_factory=dict) + + +class Lens: + """Protocol for Observatory lenses. + + Lifecycle phases: + 1. Runtime (stateful): `setup`, session hooks, `observe`, `digest`, `clear`. + 2. Analyze (pure-data): `analyze(records, config)`. + 3. Frontend strategy: `get_frontend_spec()`. + """ + + @classmethod + def get_name(cls) -> str: + raise NotImplementedError() + + @classmethod + def setup(cls) -> None: + pass + + @classmethod + def on_session_start(cls, context: ObservationContext) -> Optional[Serializable]: + return None + + @classmethod + def observe(cls, artifact: Any, context: ObservationContext) -> Any: + return None + + @classmethod + def digest(cls, observation: Any, context: ObservationContext) -> Serializable: + return None + + @classmethod + def on_session_end(cls, context: ObservationContext) -> Optional[Serializable]: + return None + + @classmethod + def clear(cls) -> None: + pass + + @staticmethod + def analyze(records: List[RecordDigest], config: Dict[str, Any]) -> AnalysisResult: + return AnalysisResult() + + @staticmethod + def get_frontend_spec() -> Frontend: + return Frontend() diff --git a/devtools/observatory/lenses/LENSES.md b/devtools/observatory/lenses/LENSES.md new file mode 100644 index 00000000000..04ef317a20d --- /dev/null +++ b/devtools/observatory/lenses/LENSES.md @@ -0,0 +1,401 @@ +# Observatory Lenses Reference + +## Overview + +Observatory lenses are plugins that observe, analyze, and render model artifacts +at each stage of the ExecuTorch compilation pipeline. Lenses install their own +monkey-patches in `on_session_start()` and remove them in `on_session_end()`. + +## Built-in Lenses + +| Lens | Purpose | Patches? | +|------|---------|----------| +| GraphLens | Renders fx_viewer graph visualization | No | +| MetadataLens | Collects artifact type, node count, environment info | No | +| StackTraceLens | Captures repo-local call stack at collection time | No | +| PipelineGraphCollectorLens | Auto-collects graphs at each pipeline stage | Yes | +| AccuracyLens | Evaluates model accuracy at each stage | Yes | +| PerLayerAccuracyLens | Sparse per-layer accuracy via `from_node_root` matching | No | + +--- + +## PipelineGraphCollectorLens + +### Purpose + +Automatically collects graph snapshots at each stage of the export -> quantize -> +lower pipeline by monkey-patching framework-level functions plus backend-specific +entrypoints. Framework-level patches cover common stages (`prepare_pt2e`, +`convert_pt2e`, `to_edge_transform_and_lower`), while backend-specific patches +are used to collect "Exported Float" with stable fallback dataset capture. + +### Observation Points + +| # | Pipeline Stage | Record Name | Patched Function | Source File | Collected Artifact | +|---|---------------|-------------|-----------------|-------------|-------------------| +| 1 | Backend-specific pre-quant (QNN) | "Exported Float" | `ptq_calibrate()` | `executorch/examples/qualcomm/utils.py` | `ExportedProgram` (`run_decompositions({})`) | +| 2 | Backend-specific pre-quant (XNNPACK) | "Exported Float" | `quantize()` | `executorch/examples/xnnpack/quantization/utils.py` | `ExportedProgram` (`run_decompositions({})`) | +| 3 | Quantizer Prepare | "Annotated Model" | `prepare_pt2e()` | `torchao/.../quantize_pt2e.py` | Output `GraphModule` with observers | +| 4 | Quantizer Convert (input) | "Calibrated Model" | `convert_pt2e()` | same | Input `GraphModule` (post-calibration) | +| 5 | Quantizer Convert (output) | "Quantized Model" | `convert_pt2e()` | same | Output `GraphModule` with Q/DQ ops | +| 6 | Edge transform input | "Pre-EdgeTransform/{method}" | `to_edge_transform_and_lower()` | `executorch/exir/program/_program.py` + `executorch/exir/__init__.py` | Input `ExportedProgram` (single or dict entry) | +| 7 | ETRecord Export | "ETRecord Exported/{method}" | `ETRecord.add_exported_program()` | `executorch/devtools/etrecord/` | Exported program | +| 8 | ETRecord Edge | "ETRecord Edge/{method}" | `ETRecord.add_edge_dialect_program()` | same | Edge dialect program | +| 9 | ETRecord Extra | "ETRecord Extra/{module}" | `ETRecord.add_extra_export_modules()` | same | Extra modules | +| 10 | Edge transform output (final) | "EdgeProgramManager EP" | `to_edge_transform_and_lower()` | `executorch/exir/program/_program.py` + `executorch/exir/__init__.py` | `EdgeProgramManager.exported_program()` | + +### Patching Strategy + +Each patch follows the same pattern: +1. Save original function in `_originals[key]` +2. Create wrapper that calls `Observatory.collect(name, artifact)` then calls original +3. Replace function in module namespace via `setattr` +4. On session end, restore all originals + +Patch install order is explicit: +1. backend-agnostic framework patches (`torchao`, `executorch.exir`, ETRecord) +2. backend-specific patches (QNN/XNNPACK) + +This ordering avoids early-import alias freezing in e2e scripts. + +The `to_edge_transform_and_lower` patch also forces `generate_etrecord=True` to +ensure ETRecord collection fires (rows 7-9). The post-call edge output record +("EdgeProgramManager EP") is collected after ETRecord hooks complete. + +### Backend Contract for AccuracyLens + +`PipelineGraphCollectorLens` owns a cross-lens fallback dataset field: + +- `_last_calibration_dataset` + +Contract: +- Any backend-specific patch that emits `"Exported Float"` must also populate + `_last_calibration_dataset`. +- This is done through `_set_accuracy_fallback_dataset(...)` in + `pipeline_graph_collector.py`. +- AccuracyLens uses this field when dataset loader patches did not provide + `_captured_dataset`. + +Current backend-specific implementations: +- QNN: `ptq_calibrate` patch +- XNNPACK: `quantize` patch + +--- + +## AccuracyLens + +### Purpose + +Evaluates model accuracy at each collected pipeline stage by running inference +with a dataset and computing metrics (TopK, PSNR, CosineSimilarity, etc.). + +### How It Works + +AccuracyLens depends on PipelineGraphCollectorLens for graph collection timing. +It configures itself lazily when it first observes the "Exported Float" record: + +``` +Observatory.collect("Exported Float", exported_program) + -> AccuracyLens.observe() triggered + -> Recognizes record name "Exported Float" + -> Extracts float model from ExportedProgram + -> Configures evaluator with captured dataset + auto-detected metrics + -> Runs evaluation on float model -> returns accuracy digest + -> Evaluator is now ready for subsequent records + +Observatory.collect("Quantized Model", quantized_model) + -> AccuracyLens.observe() triggered + -> Evaluator already configured + -> Runs evaluation on quantized model -> returns accuracy digest +``` + +### Data Sources and Fallback Strategy + +| Data | Primary Source | Fallback | When Fallback Triggers | +|------|---------------|----------|----------------------| +| Dataset (inputs) | `get_imagenet_dataset` / `get_masked_language_model_dataset` patches | `PipelineGraphCollectorLens._last_calibration_dataset` from backend-specific patch | Dataset loader patches do not fire or do not capture usable inputs | +| Targets (labels) | `get_imagenet_dataset` / `get_masked_language_model_dataset` patch | None (skip target-specific metrics) | Custom dataset, non-classification task | +| Float model | "Exported Float" artifact (ExportedProgram) | -- | Always available | +| Golden outputs | Computed from float model + dataset | -- | Always available if dataset exists | +| Post-process | Auto-detected from model output type | Identity function | Detection failure | + +**Fallback behavior when dataset patches don't fire:** +- AccuracyLens uses backend-captured fallback dataset from PipelineGraphCollectorLens +- Target-dependent metrics (TopK/MaskedTokenAccuracy) are skipped when targets are unavailable +- Golden-output metrics (PSNR, CosineSimilarity, MSE, AbsErr) still run when fallback inputs exist + +### Dataset Loader Patches + +| Patched Function | Module | Return Format | What's Captured | +|-----------------|--------|---------------|-----------------| +| `get_imagenet_dataset` | `executorch.examples.qualcomm.utils` | `(List[Tuple[Tensor]], List[Tensor])` | inputs + class targets | +| `get_masked_language_model_dataset` | `executorch.examples.qualcomm.utils` | `(List[Tuple[Tensor, Tensor]], List[Tensor])` | inputs + masked targets | + +Note: `AccuracyLens` no longer patches `build_executorch_binary`; dataset fallback is +owned by `PipelineGraphCollectorLens` backend-specific contract. + +### Auto-Detection + +**Task type** (from target format): +- Targets contain -100 values -> MLM mode -> uses `MaskedTokenAccuracy` + `MLMEvaluator` +- Otherwise -> Classification mode -> uses `TopKAccuracy` + `StandardEvaluator` + +**Post-process** (from model output type): +- `torch.Tensor` -> identity +- Has `.logits` attribute -> `lambda x: x.logits` (HuggingFace models) +- Tuple -> `lambda x: x[0]` + +### Default Metrics + +All metrics that compare against golden outputs are always included when golden +outputs are available. Target-dependent metrics are added when targets are captured. + +| Task Type | Metrics | +|-----------|---------| +| Any (with golden) | PSNR, CosineSimilarity, MSE, AbsErr | +| Classification (with targets) | + TopKAccuracy(k=1), TopKAccuracy(k=5) | +| MLM (with targets) | + MaskedTokenAccuracy | + +### Metric Design: `higher_is_better` and Worst Direction + +Every `Metric` subclass declares `higher_is_better` which controls how the +worst-case input is identified: + +| Metric | higher_is_better | Worst = | +|--------|-----------------|---------| +| PSNR | True | argmin (lowest dB = worst quality) | +| CosineSimilarity | True | argmin (lowest similarity = worst) | +| TopKAccuracy | True | argmin (0.0 = incorrect) | +| MaskedTokenAccuracy | True | argmin (lowest token accuracy) | +| MSE | False | argmax (highest error = worst) | +| AbsErr | False | argmax (highest error = worst) | + +### PSNR Cap + +PSNR is capped at `PSNR.MAX_PSNR = 100.0` dB. Raw PSNR above 100 dB (e.g., +128 dB for near-zero error) is not meaningfully different from perfect match and +creates confusing display. The cap gives a uniform ceiling: perfect match → +100.0, real degradation → actual dB value below 100.0. + +### Per-Sample Statistics + +When the dataset has more than one sample, each metric emits additional keys in +the digest alongside the primary mean value: + +``` +psnr → mean PSNR across all samples (primary display value) +psnr_min → worst PSNR sample value +psnr_max → best PSNR sample value +psnr_worst_idx → dataset index of the worst-performing sample +``` + +The same pattern applies to all metrics: `{name}_min`, `{name}_max`, +`{name}_worst_idx`. + +The frontend renders these as three separate tables: + +| Table | Block ID | Content | Shown when | +|-------|----------|---------|------------| +| Accuracy | `accuracy_table` | Mean metric values | Always | +| Per-Sample Stats | `accuracy_stats_table` | `{name}_min` / `{name}_max` per metric | >1 sample | +| Worst Input Index | `accuracy_worst_idx_table` | `{name}_worst_idx` per metric (suffix stripped) | >1 sample | + +### Cross-Lens Data Sharing via `_worst_indices` + +AccuracyLens exposes the worst-case input indices as class-level state so that +future lenses can access them during their own `observe()` call without +re-running inference: + +```python +from executorch.devtools.observatory.lenses.accuracy import AccuracyLens + +class PerLayerAccuracyLens(Lens): + @classmethod + def observe(cls, artifact, context): + # Use the worst input identified by AccuracyLens for focused analysis + worst_idx = AccuracyLens._worst_indices.get("psnr") + if worst_idx is not None: + # Run per-layer analysis on dataset[worst_idx] + ... +``` + +**Contract:** +- `AccuracyLens._worst_indices` is a `Dict[str, int]` mapping metric name to + dataset index (e.g., `{"psnr": 3, "cosine_sim": 3, "mse": 7}`) +- Updated after every `evaluate()` call, so it reflects the current record +- Cleared in `_clear_state()` (session end / Observatory.clear()) +- Only populated when dataset has >1 sample +- AccuracyLens must be registered before any lens that reads `_worst_indices` + (lenses run in registration order within each `Observatory.collect()` call) + +--- + +## PerLayerAccuracyLens + +### Purpose + +Computes sparse per-layer metrics between an anchor graph (default: +`"Exported Float"`) and each collected graph, then renders: + +1. Lens-specific graph overlays with per-metric coloring. +2. One merged per-layer metrics table (worst -> best). + +### Sparse Matching Rule + +For each graph: +1. Iterate nodes in topological order. +2. Build key per node: + - `root:` when available. + - `id:` fallback when root is missing. +3. Store key -> node using overwrite semantics. + +Effect: +- Last topological node for a key is selected (sparse map). +- Pairwise correspondence uses key intersection only. +- No group aggregation. + +### Data / Sample Selection + +Input sample source: +1. `AccuracyLens._captured_dataset` (primary) +2. `PipelineGraphCollectorLens._last_calibration_dataset` (fallback) + +Sample index selection: +1. `config["per_layer_accuracy"]["sample_index"]` if provided. +2. `AccuracyLens._worst_indices` using metric priority list. +3. Fallback index `0`. + +### Metrics and Visual Layers + +Per matched node: +- `PSNR` +- `CosineSimilarity` +- `MSE` +- `AbsErr` + +Graph layers emitted in analyze phase: +- `per_layer_accuracy/psnr` (low PSNR = severe red) +- `per_layer_accuracy/cosine_sim` (low cosine similarity = severe red) +- `per_layer_accuracy/mse` (high MSE = severe red) +- `per_layer_accuracy/abs_err` (high AbsErr = severe red) + +Each layer includes all metric values in node labels/tooltips so users can +inspect cross-metric behavior even when a different metric is selected for +coloring. + +Default lens graph section: +- `default_layers = ["per_layer_accuracy/psnr"]` (other metric layers are still available in layer controls) +- `default_color_by = "per_layer_accuracy/psnr"` + +### Frontend Sections + +Record view includes: +1. Summary table. +2. Lens-specific graph section. +3. One merged metrics table with metric-specific column coloring: + - PSNR column (low PSNR is severe) + - Cosine column (low cosine is severe) + - MSE column (high MSE is severe) + - AbsErr column (high AbsErr is severe) + - text color is auto-contrasted per cell background + +### Config + +```python +config = { + "per_layer_accuracy": { + "anchor_record_name": "Exported Float", + # optional explicit sample index + # "sample_index": 0, + # optional priority when sample_index is omitted + "worst_metric_priority": ["psnr", "cosine_sim", "mse", "abs_err"], + } +} +``` + +### Registration Note + +Register `AccuracyLens` before `PerLayerAccuracyLens` so worst-index hints are +available in the same `collect()` call. + +--- + +## Custom Usage + +### Providing a Custom Dataset and Evaluator + +For scripts with custom datasets not covered by the auto-patches, users can +provide their own evaluator via the Observatory config: + +```python +from executorch.devtools.observatory import Observatory +from executorch.devtools.observatory.lenses.accuracy import ( + StandardEvaluator, TopKAccuracy, PSNR, CosineSimilarity +) + +# Prepare your dataset and golden outputs +dataset = [...] # List of input tuples +targets = [...] # Ground truth labels +golden = [model(*inp) for inp in dataset] # Reference outputs + +evaluator = StandardEvaluator( + dataset=dataset, + metrics=[ + TopKAccuracy(targets, k=1), + PSNR(golden), + CosineSimilarity(golden), + ], + post_process=lambda x: x.logits, # optional +) + +config = {"accuracy": {"evaluator": evaluator}} + +with Observatory.enable_context(config=config): + build_executorch_binary(model, inputs, ...) +``` + +### Using MLMEvaluator for Language Models + +```python +from executorch.devtools.observatory.lenses.accuracy import ( + MLMEvaluator, MaskedTokenAccuracy, PSNR +) + +evaluator = MLMEvaluator( + dataset=inputs, + metrics=[MaskedTokenAccuracy(targets), PSNR(golden)], +) + +config = {"accuracy": {"evaluator": evaluator}} +with Observatory.enable_context(config=config): + ... +``` + +### Custom Metrics + +Implement the `Metric` protocol: + +```python +class MyMetric: + def name(self) -> str: + return "my_metric" + + def calculate(self, predictions: List[torch.Tensor]) -> float: + # Your metric logic here + return score + +evaluator = StandardEvaluator( + dataset=dataset, + metrics=[MyMetric(), PSNR(golden)], +) +``` + +### Disabling Accuracy Evaluation + +Accuracy lenses are opt-in via `--lense_recipe=accuracy` on backend CLIs. When +omitted, no accuracy evaluation runs. To disable accuracy when using the Python +API directly: + +```python +config = {"accuracy": {"enabled": False}} +``` diff --git a/devtools/observatory/lenses/__init__.py b/devtools/observatory/lenses/__init__.py new file mode 100644 index 00000000000..b5f86874fd4 --- /dev/null +++ b/devtools/observatory/lenses/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/devtools/observatory/lenses/accuracy.py b/devtools/observatory/lenses/accuracy.py new file mode 100644 index 00000000000..a38cae2999d --- /dev/null +++ b/devtools/observatory/lenses/accuracy.py @@ -0,0 +1,781 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Accuracy Evaluation Lens — auto-captures datasets and evaluates model accuracy. + +This lens patches dataset loaders to transparently capture evaluation data, then +lazily configures an evaluator when the first "Exported Float" record is observed. + +Patches installed on session start: + - get_imagenet_dataset: captures (inputs, targets) for ImageNet classification + - get_masked_language_model_dataset: captures (inputs, targets) for MLM tasks + +Lazy configuration (on first "Exported Float" observe): + - Extracts float model from ExportedProgram + - Uses captured dataset (primary) or sample input from PipelineGraphCollectorLens (fallback) + - Auto-detects task type, post_process, and metrics + - Computes golden outputs for PSNR/CosineSimilarity/MSE/AbsErr + +Per-sample statistics (when dataset has >1 sample): + - Each metric emits mean (primary), min, max, and worst_idx in the digest + - worst_idx is determined by each metric's higher_is_better direction + - AccuracyLens._worst_indices exposes {metric_name: index} as class-level state + so future lenses (e.g., per-layer accuracy) can read the worst input index + during their own observe() call without re-running inference +""" + +from __future__ import annotations + +import logging +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional + +import numpy as np +import torch +import torch.nn.functional as F + +from ..interfaces import ( + AnalysisResult, + Frontend, + Lens, + ObservationContext, + RecordAnalysis, + RecordDigest, + TableBlock, + TableRecordSpec, + ViewList, +) + + +# --------------------------------------------------------------------------- +# Data model +# --------------------------------------------------------------------------- + + +@dataclass +class PrecomputedOutputs: + """Wrapper for inference results obtained externally (e.g., from device).""" + + outputs: List[torch.Tensor] + + def __post_init__(self): + if isinstance(self.outputs, list) and self.outputs: + if isinstance(self.outputs[0], np.ndarray): + self.outputs = [torch.from_numpy(o) for o in self.outputs] + + +# --------------------------------------------------------------------------- +# Metric base class +# --------------------------------------------------------------------------- + + +class Metric: + """Base class for accuracy metrics. + + Subclasses implement calculate_per_sample() which returns one scalar per + input sample. The base class derives the aggregate mean and provides + worst_index() using the metric's built-in direction knowledge. + + higher_is_better controls worst-case direction: + True → worst = argmin (PSNR, cosine_sim, TopK — lower means worse quality) + False → worst = argmax (MSE, AbsErr — higher means worse quality) + """ + + higher_is_better: bool = True + + def name(self) -> str: + raise NotImplementedError + + def calculate_per_sample(self, predictions: List[torch.Tensor]) -> List[float]: + raise NotImplementedError + + def calculate(self, predictions: List[torch.Tensor]) -> float: + values = self.calculate_per_sample(predictions) + return float(np.mean(values)) if values else 0.0 + + def worst_index(self, per_sample: List[float]) -> int: + if not per_sample: + return 0 + if self.higher_is_better: + return int(np.argmin(per_sample)) + else: + return int(np.argmax(per_sample)) + + +# --------------------------------------------------------------------------- +# Metric implementations +# --------------------------------------------------------------------------- + + +class TopKAccuracy(Metric): + """Classification accuracy: fraction of samples where true label is in top-k.""" + + higher_is_better = True + + def __init__(self, targets: List[Any], k: int = 1): + self.targets = targets + self.k = k + + def name(self) -> str: + return f"top_{self.k}" + + def calculate_per_sample(self, predictions: List[torch.Tensor]) -> List[float]: + values = [] + for pred, target in zip(predictions, self.targets): + if not isinstance(pred, torch.Tensor): + pred = torch.tensor(pred) + if not isinstance(target, torch.Tensor): + target = torch.tensor(target) + if pred.dim() == 2: + pred = pred.squeeze(0) + _, indices = pred.topk(self.k) + values.append(100.0 if target.view(-1) in indices else 0.0) + return values + + def calculate(self, predictions: List[torch.Tensor]) -> float: + values = self.calculate_per_sample(predictions) + return float(np.mean(values)) if values else 0.0 + + +class CosineSimilarity(Metric): + """Cosine similarity between predictions and golden outputs.""" + + higher_is_better = True + + def __init__(self, golden_outputs: List[torch.Tensor]): + self.golden = golden_outputs + + def name(self) -> str: + return "cosine_sim" + + def calculate_per_sample(self, predictions: List[torch.Tensor]) -> List[float]: + if not self.golden or len(predictions) != len(self.golden): + return [] + values = [] + for p, g in zip(predictions, self.golden): + p_flat = p.flatten().float() + g_flat = g.flatten().float() + values.append( + F.cosine_similarity(p_flat.unsqueeze(0), g_flat.unsqueeze(0)).item() + ) + return values + + +class PSNR(Metric): + """Peak Signal-to-Noise Ratio, capped at MAX_PSNR for UI consistency. + + Raw PSNR above MAX_PSNR (e.g., 128 dB for near-zero error) is not + meaningfully different from perfect match, so we clamp to MAX_PSNR. + This gives a uniform ceiling: perfect match → MAX_PSNR, real degradation + → actual dB value below MAX_PSNR. + """ + + higher_is_better = True + MAX_PSNR = 100.0 + + def __init__(self, golden_outputs: List[torch.Tensor]): + self.golden = golden_outputs + self.max_val = ( + max(torch.max(g).item() for g in golden_outputs) if golden_outputs else 1.0 + ) + + def name(self) -> str: + return "psnr" + + def calculate_per_sample(self, predictions: List[torch.Tensor]) -> List[float]: + if not self.golden or len(predictions) != len(self.golden): + return [] + values = [] + for p, g in zip(predictions, self.golden): + mse = F.mse_loss(p.float(), g.float()) + if mse == 0: + values.append(self.MAX_PSNR) + else: + db = ( + 20 + * torch.log10( + torch.tensor(self.max_val) / torch.sqrt(mse) + ).item() + ) + values.append(min(db, self.MAX_PSNR)) + return values + + +class MSE(Metric): + """Mean Squared Error per sample. Lower is better (higher_is_better=False).""" + + higher_is_better = False + + def __init__(self, golden_outputs: List[torch.Tensor]): + self.golden = golden_outputs + + def name(self) -> str: + return "mse" + + def calculate_per_sample(self, predictions: List[torch.Tensor]) -> List[float]: + if not self.golden or len(predictions) != len(self.golden): + return [] + return [ + F.mse_loss(p.float(), g.float()).item() + for p, g in zip(predictions, self.golden) + ] + + +class AbsErr(Metric): + """Mean Absolute Error per sample. Lower is better (higher_is_better=False).""" + + higher_is_better = False + + def __init__(self, golden_outputs: List[torch.Tensor]): + self.golden = golden_outputs + + def name(self) -> str: + return "abs_err" + + def calculate_per_sample(self, predictions: List[torch.Tensor]) -> List[float]: + if not self.golden or len(predictions) != len(self.golden): + return [] + return [ + torch.mean(torch.abs(p.float() - g.float())).item() + for p, g in zip(predictions, self.golden) + ] + + +class MaskedTokenAccuracy(Metric): + """Token-level accuracy for MLM models, filtering by ignore_index (-100).""" + + higher_is_better = True + + def __init__(self, targets: List[torch.Tensor], ignore_index: int = -100): + self.targets = targets + self.ignore_index = ignore_index + + def name(self) -> str: + return "masked_token_accuracy" + + def calculate_per_sample(self, predictions: List[torch.Tensor]) -> List[float]: + values = [] + for pred, target in zip(predictions, self.targets): + if not isinstance(target, torch.Tensor): + target = torch.tensor(target) + indices = [ + i + for i, t in enumerate(target.view(-1)) + if t.item() != self.ignore_index + ] + if not indices: + values.append(0.0) + continue + if pred.dim() >= 2: + pred_tokens = pred.view(-1, pred.shape[-1]).argmax(dim=-1) + else: + pred_tokens = pred.view(-1) + correct = sum( + 1 + for i in indices + if i < len(pred_tokens) + and pred_tokens[i].item() == target.view(-1)[i].item() + ) + values.append((correct / len(indices)) * 100.0) + return values + + def calculate(self, predictions: List[torch.Tensor]) -> float: + values = self.calculate_per_sample(predictions) + return float(np.mean(values)) if values else 0.0 + + +# --------------------------------------------------------------------------- +# Evaluators +# --------------------------------------------------------------------------- + + +class Evaluator: + def __init__( + self, + dataset: List[Any], + metrics: List[Metric], + post_process: Optional[Callable] = None, + ): + self.dataset = dataset + self.metrics = metrics + self.post_process = post_process or (lambda x: x) + + def evaluate(self, model: Any) -> Dict[str, Any]: + """Run inference and compute all metrics. + + For each metric, always emits the mean value under the metric name. + When dataset has >1 sample, also emits: + {name}_min, {name}_max — range across samples + {name}_worst_idx — index of the worst-performing sample + (argmin for higher_is_better, argmax otherwise) + """ + predictions = self.run_inference(model, self.dataset) + results: Dict[str, Any] = {"_num_samples": len(predictions)} + for metric in self.metrics: + name = metric.name() + try: + per_sample = metric.calculate_per_sample(predictions) + if not per_sample: + results[name] = 0.0 + continue + results[name] = round(float(np.mean(per_sample)), 4) + if len(per_sample) > 1: + results[f"{name}_min"] = round(float(min(per_sample)), 4) + results[f"{name}_max"] = round(float(max(per_sample)), 4) + results[f"{name}_worst_idx"] = metric.worst_index(per_sample) + except Exception as e: + logging.error("Metric %s failed: %s", name, e) + results[name] = f"error: {e}" + return results + + def run_inference(self, model: Any, dataset: List[Any]) -> List[torch.Tensor]: + raise NotImplementedError + + +class StandardEvaluator(Evaluator): + """Standard evaluator for classification and regression models.""" + + def run_inference(self, model: Any, dataset: List[Any]) -> List[torch.Tensor]: + if isinstance(model, PrecomputedOutputs): + return model.outputs + predictions = [] + is_ep = hasattr(model, "module") and callable(model.module) + executable = model.module() if is_ep else model + # torch.no_grad() matches _compute_golden_outputs — without it the + # autograd context can cause subtle numerical differences vs golden. + with torch.no_grad(): + for inputs in dataset: + args = inputs if isinstance(inputs, (tuple, list)) else (inputs,) + raw_out = executable(*args) + out = self.post_process(raw_out) + if isinstance(out, torch.Tensor): + out = out.detach().cpu() + predictions.append(out) + return predictions + + +class MLMEvaluator(Evaluator): + """Evaluator for masked language models with -100 masking. + + Uses self.post_process (from _auto_detect_post_process) to extract logits, + keeping the inference path consistent with how golden outputs are computed + in _compute_golden_outputs. An earlier version had hardcoded + ``out.logits if hasattr(out, "logits") else out`` which diverged from the + golden computation and produced wrong PSNR/cosine for HuggingFace models. + """ + + def run_inference(self, model: Any, dataset: List[Any]) -> List[torch.Tensor]: + if isinstance(model, PrecomputedOutputs): + return model.outputs + predictions = [] + is_ep = hasattr(model, "module") and callable(model.module) + executable = model.module() if is_ep else model + with torch.no_grad(): + for inputs in dataset: + args = inputs if isinstance(inputs, (tuple, list)) else (inputs,) + raw_out = executable(*args) + out = self.post_process(raw_out) + if isinstance(out, torch.Tensor): + out = out.detach().cpu() + predictions.append(out) + return predictions + + +# --------------------------------------------------------------------------- +# AccuracyLens +# --------------------------------------------------------------------------- + + +class AccuracyLens(Lens): + """Evaluates model accuracy at each collected pipeline stage. + + Configures itself lazily when it first observes the "Exported Float" record, + extracting the float model from the ExportedProgram and building an evaluator + with captured dataset + auto-detected metrics. + + Cross-lens data sharing: + _worst_indices is a class-level dict {metric_name: worst_input_index} updated + after every evaluate() call. Future lenses (e.g., per-layer accuracy) can + read it during their own observe() to focus analysis on the worst input: + + from .accuracy import AccuracyLens + worst = AccuracyLens._worst_indices.get("psnr") # int or None + + This follows the same pattern as PipelineGraphCollectorLens._last_export_inputs. + Lenses run in registration order, so AccuracyLens must be registered before + any lens that reads _worst_indices. + """ + + _installed: bool = False + _originals: Dict[str, Any] = {} + # Backend-specific dataset patch installers. + _dataset_patch_installers: List[Callable] = [] + _dataset_uninstallers: List[Callable] = [] + + _float_model: Any = None # cached GraphModule from "Exported Float" ExportedProgram + _captured_dataset: Optional[List[Any]] = None + _captured_targets: Optional[List[Any]] = None + _golden_outputs: Optional[List[torch.Tensor]] = None + _post_process: Optional[Callable] = None + _evaluator: Optional[Evaluator] = None + _task_type: Optional[str] = None # "classification", "mlm", or None + _worst_indices: Dict[str, int] = {} # {metric_name: worst_input_index} + + @classmethod + def register_dataset_patches( + cls, installer: Callable[["AccuracyLens"], None] + ) -> None: + """Register a backend-specific dataset patch installer. + + The installer receives the AccuracyLens class and should set + cls._captured_targets and cls._task_type when dataset functions + are called. It may also append to cls._dataset_uninstallers. + """ + if installer not in cls._dataset_patch_installers: + cls._dataset_patch_installers.append(installer) + + @classmethod + def get_name(cls) -> str: + return "accuracy" + + @classmethod + def on_session_start(cls, context: ObservationContext) -> None: + if cls._installed: + return + for installer in cls._dataset_patch_installers: + try: + installer(cls) + except Exception as exc: + logging.warning( + "[AccuracyLens] Dataset patch installer failed: %s", exc + ) + cls._installed = True + + @classmethod + def on_session_end(cls, context: ObservationContext) -> None: + cls._uninstall_all() + cls._clear_state() + + @classmethod + def clear(cls) -> None: + cls._uninstall_all() + cls._clear_state() + cls._dataset_patch_installers.clear() + cls._dataset_uninstallers.clear() + + @classmethod + def _clear_state(cls) -> None: + cls._float_model = None + cls._captured_dataset = None + cls._captured_targets = None + cls._golden_outputs = None + cls._post_process = None + cls._evaluator = None + cls._task_type = None + cls._worst_indices = {} + + @classmethod + def observe(cls, artifact: Any, context: ObservationContext) -> Any: + record_name = context.shared_state.get("record_name", "") + + # Lazily configure evaluator on first "Exported Float" record + if record_name == "Exported Float" and cls._evaluator is None: + cls._configure_from_float_model(artifact) + + acc_config = context.config.get("accuracy", {}) + evaluator = acc_config.get("evaluator") or cls._evaluator + if not evaluator: + return None + + if not isinstance( + artifact, + (torch.nn.Module, torch.fx.GraphModule, torch.export.ExportedProgram), + ): + return None + + eval_artifact = artifact + if record_name == "Exported Float" and cls._float_model is not None: + eval_artifact = cls._float_model + + try: + raw = evaluator.evaluate(eval_artifact) + # Update class-level worst indices for cross-lens access + cls._worst_indices = { + k[: -len("_worst_idx")]: v + for k, v in raw.items() + if k.endswith("_worst_idx") + } + return { + k: round(v, 4) if isinstance(v, float) else v + for k, v in raw.items() + } + except Exception as e: + logging.error("[AccuracyLens] Evaluation failed: %s", e) + return {"error_message": str(e)} + + @classmethod + def digest(cls, observation: Any, context: ObservationContext) -> Any: + return observation + + @staticmethod + def analyze( + records: List[RecordDigest], config: Dict[str, Any] + ) -> AnalysisResult: + result = AnalysisResult() + for i, record in enumerate(records): + digest = record.data.get("accuracy") + if digest is None: + continue + analysis = RecordAnalysis() + if i > 0: + prev = records[i - 1].data.get("accuracy", {}) + for key in digest: + if key.startswith("_"): + continue + if ( + isinstance(digest.get(key), (int, float)) + and isinstance(prev.get(key), (int, float)) + ): + analysis.data[f"{key}_diff"] = round( + digest[key] - prev[key], 4 + ) + result.per_record_data[record.name] = analysis + return result + + @staticmethod + def get_frontend_spec() -> Frontend: + return _AccuracyFrontend() + + # ------------------------------------------------------------------ + # Auto-configuration helpers + # ------------------------------------------------------------------ + + @classmethod + def _auto_detect_post_process(cls, model: Any, dataset: List[Any]) -> Callable: + try: + sample = dataset[0] + args = sample if isinstance(sample, (tuple, list)) else (sample,) + with torch.no_grad(): + out = model(*args) + if isinstance(out, torch.Tensor): + return lambda x: x + if hasattr(out, "logits"): + return lambda x: x.logits + if isinstance(out, tuple): + return lambda x: x[0] + except Exception as e: + logging.debug("[AccuracyLens] post_process auto-detect failed: %s", e) + return lambda x: x + + @classmethod + def _compute_golden_outputs( + cls, model: Any, dataset: List[Any], post_process: Callable + ) -> List[torch.Tensor]: + golden = [] + with torch.no_grad(): + for inputs in dataset: + args = inputs if isinstance(inputs, (tuple, list)) else (inputs,) + out = model(*args) + processed = post_process(out) + if isinstance(processed, torch.Tensor): + processed = processed.detach().cpu() + golden.append(processed) + return golden + + @classmethod + def _configure_from_float_model(cls, artifact: Any) -> None: + """Lazily configure evaluator from the "Exported Float" ExportedProgram. + + Caches the extracted GraphModule as _float_model so that observe() can + reuse it for the "Exported Float" evaluation instead of calling + artifact.module() a second time (which would create a different + GraphModule instance and risk numerical mismatch with golden outputs). + """ + try: + is_ep = hasattr(artifact, "module") and callable(artifact.module) + model = artifact.module() if is_ep else artifact + cls._float_model = model + + # Primary: captured dataset from dataset loader patches + # Fallback: sample input captured by backend-specific patches in + # PipelineGraphCollectorLens. + if cls._captured_dataset is None: + from .pipeline_graph_collector import PipelineGraphCollectorLens + + calibration_dataset = PipelineGraphCollectorLens._last_calibration_dataset + # cls._captured_dataset might already be captured + # from dataloader patching in _install_dataset_patches + if calibration_dataset is not None: + cls._captured_dataset = calibration_dataset + logging.info( + "[AccuracyLens] Using backend-captured fallback dataset from PipelineGraphCollectorLens" + ) + + if cls._captured_dataset is None: + logging.debug("[AccuracyLens] No dataset available, skipping auto-config") + return + + cls._post_process = cls._auto_detect_post_process(model, cls._captured_dataset) + cls._golden_outputs = cls._compute_golden_outputs( + model, cls._captured_dataset, cls._post_process + ) + cls._evaluator = cls._build_default_evaluator() + if cls._evaluator: + logging.info( + "[AccuracyLens] Auto-configured %s evaluator with %d metrics", + cls._task_type, + len(cls._evaluator.metrics), + ) + except Exception as e: + logging.warning("[AccuracyLens] Auto-config from float model failed: %s", e) + + @classmethod + def _build_default_evaluator(cls) -> Optional[Evaluator]: + if cls._captured_dataset is None: + logging.info("[AccuracyLens] Unable to auto build evaluator because no dataset is captured") + return None + + dataset = cls._captured_dataset + targets = cls._captured_targets + golden = cls._golden_outputs + task_type = cls._task_type + + metrics: List[Metric] = [] + if golden: + metrics.extend([ + PSNR(golden), + CosineSimilarity(golden), + MSE(golden), + AbsErr(golden), + ]) + + if task_type == "classification" and targets: + metrics.extend([TopKAccuracy(targets, k=1), TopKAccuracy(targets, k=5)]) + return StandardEvaluator( + dataset=dataset, metrics=metrics, post_process=cls._post_process + ) + elif task_type == "mlm" and targets: + metrics.append(MaskedTokenAccuracy(targets)) + return MLMEvaluator( + dataset=dataset, metrics=metrics, post_process=cls._post_process + ) + elif metrics: + return StandardEvaluator( + dataset=dataset, metrics=metrics, post_process=cls._post_process + ) + return None + + # ------------------------------------------------------------------ + # Patches + # ------------------------------------------------------------------ + + @classmethod + def _uninstall_all(cls) -> None: + if not cls._installed: + return + for uninstaller in cls._dataset_uninstallers: + try: + uninstaller() + except Exception: + pass + cls._dataset_uninstallers.clear() + cls._originals.clear() + cls._installed = False + logging.info("[AccuracyLens] Uninstalled all patches") + + +# --------------------------------------------------------------------------- +# Frontend +# --------------------------------------------------------------------------- + + +class _AccuracyFrontend(Frontend): + def record( + self, digest: Any, analysis: Dict[str, Any], context: Dict[str, Any] + ) -> Optional[ViewList]: + if not digest or not isinstance(digest, dict): + return None + + # Partition digest keys into three groups: + # primary — mean metric values (no suffix) + # stats — per-sample min/max (_min / _max suffix) + # worst — worst-case input indices (_worst_idx suffix) + num_samples = digest.get("_num_samples", 1) + primary_data = {} + stats_data = {} + worst_data = {} + for k, v in digest.items(): + if k.startswith("_"): + continue # internal keys (_num_samples, etc.) + if k.endswith("_worst_idx"): + worst_data[k[: -len("_worst_idx")]] = v + elif k.endswith("_min") or k.endswith("_max"): + stats_data[k] = v + else: + primary_data[k] = v + + n = f"{num_samples} sample{'s' if num_samples != 1 else ''}" + blocks = [ + TableBlock( + id="accuracy_table", + title=f"Accuracy ({n})", + record=TableRecordSpec(data=primary_data), + order=20, + ) + ] + + # Per-sample min/max table: only present when >1 sample was evaluated. + if stats_data: + blocks.append( + TableBlock( + id="accuracy_stats_table", + title=f"Per-Sample Stats ({n})", + record=TableRecordSpec(data=stats_data), + order=21, + ) + ) + + # Worst input index table: only present when >1 sample was evaluated. + if worst_data: + blocks.append( + TableBlock( + id="accuracy_worst_idx_table", + title=f"Worst Input Index ({n})", + record=TableRecordSpec(data=worst_data), + order=22, + ) + ) + + return ViewList(blocks=blocks) + + def check_index_diffs( + self, prev_digest: Any, curr_digest: Any, analysis: Dict[str, Any] + ) -> Dict[str, str]: + result = {} + if not prev_digest or not curr_digest: + return result + for key in [ + "psnr", "cosine_sim", "mse", "abs_err", + "top_1", "top_5", "masked_token_accuracy", + ]: + if key in prev_digest and key in curr_digest: + prev_val = prev_digest[key] + curr_val = curr_digest[key] + if isinstance(prev_val, (int, float)) and isinstance( + curr_val, (int, float) + ): + if abs(curr_val - prev_val) < 0.0001: + continue + result[key] = f"{curr_val - prev_val:+.4f}" + return result + + def check_badges( + self, digest: Any, analysis: Dict[str, Any] + ) -> List[Dict[str, str]]: + badges = [] + if digest and isinstance(digest, dict) and "error_message" in digest: + badges.append({"label": "ERR", "color": "#d73a49"}) + return badges diff --git a/devtools/observatory/lenses/graph.py b/devtools/observatory/lenses/graph.py new file mode 100644 index 00000000000..beecb7963c9 --- /dev/null +++ b/devtools/observatory/lenses/graph.py @@ -0,0 +1,91 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from typing import Any, Dict, Optional + +import torch + +from executorch.devtools.fx_viewer import FXGraphExporter + +from ..interfaces import Frontend, GraphView, Lens, ObservationContext, ViewList + + +class GraphLens(Lens): + """Canonical producer of base fx_viewer graph payload per record.""" + + @classmethod + def get_name(cls) -> str: + return "graph" + + @classmethod + def _to_graph_module(cls, artifact: Any) -> Optional[torch.fx.GraphModule]: + if isinstance(artifact, torch.fx.GraphModule): + return artifact + + graph_module = getattr(artifact, "graph_module", None) + if isinstance(graph_module, torch.fx.GraphModule): + return graph_module + + try: + from torch.export import ExportedProgram + + if isinstance(artifact, ExportedProgram): + return artifact.graph_module + + exported_program = getattr(artifact, "exported_program", None) + if isinstance(exported_program, ExportedProgram): + return exported_program.graph_module + except Exception: + pass + + return None + + @classmethod + def observe(cls, artifact: Any, context: ObservationContext) -> Any: + graph_module = cls._to_graph_module(artifact) + if graph_module is None: + return None + + exporter = FXGraphExporter(graph_module) + payload = exporter.generate_json_payload() + + base = payload.get("base", {}) + record_name = str(context.shared_state.get("record_name") or "record") + + return { + "graph_ref": record_name, + "base": base, + "meta": { + "record_name": record_name, + "node_count": len(base.get("nodes", [])), + "edge_count": len(base.get("edges", [])), + }, + } + + @classmethod + def digest(cls, observation: Any, context: ObservationContext) -> Any: + return observation + + class GraphFrontend(Frontend): + def record(self, digest, analysis, context) -> Optional[ViewList]: + if not digest: + return None + + view = GraphView( + id="graph_main", + title="Graph", + graph_ref=str(digest.get("graph_ref", "")), + default_layers=["graph_color/op_type", "graph_color/op_target"], + default_color_by="graph_color/op_type", + order=10, + ) + return ViewList(blocks=[view.as_block()]) + + @staticmethod + def get_frontend_spec() -> Frontend: + return GraphLens.GraphFrontend() diff --git a/devtools/observatory/lenses/graph_color.py b/devtools/observatory/lenses/graph_color.py new file mode 100644 index 00000000000..3e225439c7d --- /dev/null +++ b/devtools/observatory/lenses/graph_color.py @@ -0,0 +1,94 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Graph Color Lens — adds Op Type and Op Target color layers to graph views. + +Derives color-by layers from the base graph payload captured by GraphLens. +No runtime observation needed; all work happens in the analyze phase. +""" + +from __future__ import annotations + +from typing import Any, Dict, List + +from executorch.devtools.fx_viewer.color_rules import ( + CategoricalColorRule, +) +from executorch.devtools.fx_viewer.extension import GraphExtension +from executorch.devtools.fx_viewer.models import ( + GraphExtensionNodePayload, + GraphExtensionPayload, +) + +from ..interfaces import ( + AnalysisResult, + Lens, + RecordAnalysis, + RecordDigest, +) + +_OP_COLOR_MAP = { + "call_function": "#4A90D9", + "placeholder": "#50B86C", + "output": "#E85D5D", + "call_module": "#9B6DC6", + "call_method": "#E8A838", + "get_attr": "#7190C9", +} + + +class GraphColorLens(Lens): + """Produces Op Type and Op Target color layers from graph structure.""" + + @classmethod + def get_name(cls) -> str: + return "graph_color" + + @staticmethod + def analyze( + records: List[RecordDigest], config: Dict[str, Any] + ) -> AnalysisResult: + result = AnalysisResult() + + for record in records: + graph_digest = record.data.get("graph") + if not graph_digest or not isinstance(graph_digest, dict): + continue + base = graph_digest.get("base") + if not base or not isinstance(base, dict): + continue + nodes = base.get("nodes") + if not nodes: + continue + + op_type_ext = GraphExtension(id="op_type", name="Op Type") + target_ext = GraphExtension(id="op_target", name="Op Target") + + for node in nodes: + node_id = node.get("id", "") + info = node.get("info", {}) + op = info.get("op", "") + target_raw = info.get("target", "") + + op_type_ext.add_node_data(node_id, {"op_type": op}) + + if op == "call_function": + cat = target_raw.replace("aten.", "").replace(".default", "") + else: + cat = op + target_ext.add_node_data(node_id, {"target_category": cat}) + + op_type_ext.set_color_rule( + CategoricalColorRule("op_type", color_map=_OP_COLOR_MAP) + ) + target_ext.set_color_rule(CategoricalColorRule("target_category")) + + analysis = RecordAnalysis() + analysis.add_graph_layer("op_type", op_type_ext) + analysis.add_graph_layer("op_target", target_ext) + result.per_record_data[record.name] = analysis + + return result diff --git a/devtools/observatory/lenses/metadata.py b/devtools/observatory/lenses/metadata.py new file mode 100644 index 00000000000..c044aff8d69 --- /dev/null +++ b/devtools/observatory/lenses/metadata.py @@ -0,0 +1,154 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +import os +import platform +import sys +from datetime import datetime +from typing import Any, Dict, List, Optional + +import torch + +from ..interfaces import ( + AnalysisResult, + Frontend, + Lens, + ObservationContext, + RecordAnalysis, + RecordDigest, + TableBlock, + TableRecordSpec, + ViewList, +) + + +class MetadataLens(Lens): + """Collects basic metadata about artifacts and runtime environment.""" + + @classmethod + def get_name(cls) -> str: + return "metadata" + + @classmethod + def observe(cls, artifact: Any, context: ObservationContext) -> Any: + artifact_type = str(type(artifact).__name__) + node_count: Any = "N/A" + + try: + from torch.export import ExportedProgram + + if isinstance(artifact, torch.fx.GraphModule): + artifact_type = "GM" + node_count = len(list(artifact.graph.nodes)) + elif isinstance(artifact, ExportedProgram): + artifact_type = "EP" + node_count = len(list(artifact.graph_module.graph.nodes)) + elif isinstance(artifact, torch.nn.Module): + artifact_type = "NN" + except Exception: + pass + + context.shared_state["artifact_type"] = artifact_type + return { + "artifact_type": artifact_type, + "node_count": node_count, + } + + @classmethod + def digest(cls, observation: Any, context: ObservationContext) -> Any: + return observation + + @classmethod + def on_session_start(cls, context: ObservationContext) -> Optional[Dict[str, Any]]: + return { + "command_line": " ".join(sys.orig_argv), + "python_version": sys.version.split("\n")[0], + "platform_system": platform.system(), + "platform_release": platform.release(), + "platform_machine": platform.machine(), + "working_directory": os.getcwd(), + "start_time": datetime.now().strftime("%Y-%m-%d %H:%M:%S"), + } + + @staticmethod + def analyze(records: List[RecordDigest], config: Dict[str, Any]) -> AnalysisResult: + per_record: Dict[str, RecordAnalysis] = {} + + for i in range(len(records) - 1): + def _count(rec: RecordDigest) -> int: + data = rec.data.get("metadata") + if not data: + return 0 + value = data.get("node_count") + return int(value) if isinstance(value, (int, float)) else 0 + + before = _count(records[i]) + after = _count(records[i + 1]) + per_record[records[i + 1].name] = RecordAnalysis( + data={"node_diff": after - before} + ) + + return AnalysisResult(per_record_data=per_record) + + class MetadataFrontend(Frontend): + def dashboard(self, start, end, analysis, records) -> Optional[ViewList]: + return ViewList( + blocks=[ + TableBlock( + id="metadata_dashboard", + title="Session Metadata", + record=TableRecordSpec(data=start or {}), + order=0, + ) + ] + ) + + def record(self, digest, analysis, context) -> Optional[ViewList]: + data = digest.copy() if isinstance(digest, dict) else {} + record_analysis = (analysis or {}).get("record") or {} + node_diff = record_analysis.get("node_diff", 0) + if node_diff: + data["nodes_change"] = f"{node_diff:+d}" + + return ViewList( + blocks=[ + TableBlock( + id="metadata_record", + title="Metadata", + record=TableRecordSpec(data=data), + order=0, + ) + ] + ) + + def check_index_diffs(self, prev_digest, curr_digest, analysis): + try: + before = int(prev_digest.get("node_count", 0)) + after = int(curr_digest.get("node_count", 0)) + diff = after - before + if diff: + return {"nodes": f"{diff:+d}"} + except Exception: + return {} + return {} + + def check_badges(self, digest, analysis): + badges = [] + if digest and "artifact_type" in digest: + badges.append( + { + "label": str(digest["artifact_type"]), + "class": "badge", + "title": "Artifact Type", + } + ) + return badges + + @staticmethod + def get_frontend_spec() -> Frontend: + return MetadataLens.MetadataFrontend() diff --git a/devtools/observatory/lenses/per_layer_accuracy.py b/devtools/observatory/lenses/per_layer_accuracy.py new file mode 100644 index 00000000000..3e67844688c --- /dev/null +++ b/devtools/observatory/lenses/per_layer_accuracy.py @@ -0,0 +1,957 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Per-layer accuracy lens with sparse from_node-root matching. + +Design constraints: +1. Sparse correspondence only: each match key maps to one node per graph. +2. Key priority: from_node_root first, then node-id fallback when root missing. +3. Last topological node wins for duplicate keys in each graph. +4. Per-layer metrics are computed on one sample index, reusing AccuracyLens + worst-index selection when available. +""" + +from __future__ import annotations + +import html +import math +from dataclasses import dataclass +from typing import Any, Dict, Iterable, List, Optional, Tuple + +import torch + +from executorch.devtools.fx_viewer.color_rules import ColorRule +from executorch.devtools.fx_viewer.extension import GraphExtension + +from ..interfaces import ( + AnalysisResult, + Frontend, + GraphCompareSpec, + GraphView, + HtmlBlock, + HtmlRecordSpec, + Lens, + ObservationContext, + RecordAnalysis, + RecordDigest, + TableBlock, + TableRecordSpec, + ViewList, +) +from .accuracy import AbsErr, AccuracyLens, CosineSimilarity, MSE, PSNR +from .pipeline_graph_collector import PipelineGraphCollectorLens + + +@dataclass +class _SparseNodeRef: + node_id: str + key_kind: str + from_node_root: Optional[str] + topo_index: int + + +class _NodeOutputCapturer(torch.fx.Interpreter): + """Capture intermediate outputs by node id.""" + + def __init__(self, module: torch.fx.GraphModule): + super().__init__(module) + self.outputs: Dict[str, Any] = {} + + def run_node(self, n: torch.fx.Node) -> Any: + out = super().run_node(n) + if n.op not in ("placeholder", "output"): + self.outputs[n.name] = out + return out + + +class _MetricNumericColorRule(ColorRule): + """Numeric color rule with optional inverse severity direction. + + When ``fixed_range`` is supplied, the given ``(vmin, vmax)`` is used for + normalization instead of computing it from ``nodes_data``. This lets the + caller share a single color scale across multiple records. + """ + + def __init__( + self, + attribute: str, + *, + low_rgb: Tuple[int, int, int], + high_rgb: Tuple[int, int, int], + inverse: bool = False, + fixed_range: Optional[Tuple[float, float]] = None, + ) -> None: + super().__init__(attribute) + self.low_rgb = low_rgb + self.high_rgb = high_rgb + self.inverse = inverse + self.fixed_range = fixed_range + + @staticmethod + def _interp(low: int, high: int, ratio: float) -> int: + return int(low + (high - low) * ratio) + + def _color(self, ratio: float) -> str: + ratio = max(0.0, min(1.0, ratio)) + lr, lg, lb = self.low_rgb + hr, hg, hb = self.high_rgb + r = self._interp(lr, hr, ratio) + g = self._interp(lg, hg, ratio) + b = self._interp(lb, hb, ratio) + return f"#{r:02x}{g:02x}{b:02x}" + + def _resolve_range(self, vals: List[float]) -> Optional[Tuple[float, float]]: + if self.fixed_range is not None: + rmin, rmax = self.fixed_range + if math.isfinite(rmin) and math.isfinite(rmax): + vmin, vmax = float(rmin), float(rmax) + if vmin == vmax: + vmax = vmin + 1e-12 + return vmin, vmax + if not vals: + return None + vmin, vmax = min(vals), max(vals) + if vmin == vmax: + vmax = vmin + 1e-12 + return vmin, vmax + + def apply(self, nodes_data: dict) -> tuple[dict, list]: + vals = [] + for data in nodes_data.values(): + v = data.get(self.attribute) + if isinstance(v, (int, float)): + fv = float(v) + if math.isfinite(fv): + vals.append(fv) + + resolved = self._resolve_range(vals) + if resolved is None: + return {}, [] + vmin, vmax = resolved + + node_colors = {} + for node_id, data in nodes_data.items(): + v = data.get(self.attribute) + if not isinstance(v, (int, float)): + continue + fv = float(v) + if not math.isfinite(fv): + continue + ratio = (fv - vmin) / (vmax - vmin) + if self.inverse: + ratio = 1.0 - ratio + node_colors[node_id] = self._color(ratio) + + legend = [] + for i in range(5): + t = i / 4.0 + ratio = 1.0 - t if self.inverse else t + color = self._color(ratio) + val = vmin + t * (vmax - vmin) + legend.append({"label": f"{val:.3f}", "color": color}) + return node_colors, legend + + +class PerLayerAccuracyLens(Lens): + _anchor_graph_module: Optional[torch.fx.GraphModule] = None + _anchor_record_name: Optional[str] = None + _anchor_sparse_index: Dict[str, _SparseNodeRef] = {} + _anchor_outputs_cache: Dict[int, Dict[str, Any]] = {} + + @classmethod + def get_name(cls) -> str: + return "per_layer_accuracy" + + @classmethod + def clear(cls) -> None: + cls._anchor_graph_module = None + cls._anchor_record_name = None + cls._anchor_sparse_index = {} + cls._anchor_outputs_cache = {} + + @classmethod + def on_session_end(cls, context: ObservationContext) -> None: + cls.clear() + + @classmethod + def _lens_config(cls, context: ObservationContext) -> Dict[str, Any]: + cfg = context.config.get("per_layer_accuracy", {}) + return cfg if isinstance(cfg, dict) else {} + + @classmethod + def _to_graph_module(cls, artifact: Any) -> Optional[torch.fx.GraphModule]: + try: + from torch.export import ExportedProgram + + if isinstance(artifact, ExportedProgram): + # Use executable GraphModule with bound params/buffers. + # Raw ExportedProgram.graph_module has lifted placeholders and + # cannot be executed with dataset sample inputs alone. + gm = artifact.module() + return gm if isinstance(gm, torch.fx.GraphModule) else None + + exported_program = getattr(artifact, "exported_program", None) + if isinstance(exported_program, ExportedProgram): + gm = exported_program.module() + return gm if isinstance(gm, torch.fx.GraphModule) else None + except Exception: + pass + + if isinstance(artifact, torch.fx.GraphModule): + return artifact + + graph_module = getattr(artifact, "graph_module", None) + if isinstance(graph_module, torch.fx.GraphModule): + return graph_module + + return None + + @staticmethod + def _extract_from_node_root(node: torch.fx.Node) -> Optional[str]: + root_name = node.meta.get("from_node_root") + if isinstance(root_name, str) and root_name: + return root_name + + from_node = node.meta.get("from_node") + if not isinstance(from_node, list) or not from_node: + return None + + try: + ns = from_node[-1] + while getattr(ns, "from_node", None): + parent = ns.from_node + if not isinstance(parent, list) or not parent: + break + ns = parent[-1] + name = getattr(ns, "name", None) + return str(name) if name else None + except Exception: + return None + + @classmethod + def _build_sparse_node_index( + cls, graph_module: torch.fx.GraphModule + ) -> Dict[str, _SparseNodeRef]: + """Sparse index with last-topological node selection per key.""" + sparse: Dict[str, _SparseNodeRef] = {} + for topo, node in enumerate(graph_module.graph.nodes): + if node.op in ("placeholder", "output"): + continue + root = cls._extract_from_node_root(node) + if root: + key = f"root:{root}" + kind = "root" + else: + key = f"id:{node.name}" + kind = "id_fallback" + sparse[key] = _SparseNodeRef( + node_id=node.name, + key_kind=kind, + from_node_root=root, + topo_index=topo, + ) + return sparse + + @staticmethod + def _resolve_dataset() -> Optional[List[Any]]: + if ( + isinstance(AccuracyLens._captured_dataset, list) + and AccuracyLens._captured_dataset + ): + return AccuracyLens._captured_dataset + if ( + isinstance(PipelineGraphCollectorLens._last_calibration_dataset, list) + and PipelineGraphCollectorLens._last_calibration_dataset + ): + return PipelineGraphCollectorLens._last_calibration_dataset + return None + + @staticmethod + def _pick_sample_index( + dataset: Optional[List[Any]], + config: Dict[str, Any], + ) -> Tuple[int, str]: + if not dataset: + return 0, "default(0)" + + explicit = config.get("sample_index") + if isinstance(explicit, int): + idx = min(max(explicit, 0), len(dataset) - 1) + return idx, "config.sample_index" + + priority = config.get( + "worst_metric_priority", + ["psnr", "cosine_sim", "mse", "abs_err", "top_1", "top_5"], + ) + if not isinstance(priority, list): + priority = ["psnr", "cosine_sim", "mse", "abs_err", "top_1", "top_5"] + + for metric_name in priority: + idx = AccuracyLens._worst_indices.get(str(metric_name)) + if isinstance(idx, int): + return ( + min(max(idx, 0), len(dataset) - 1), + f"accuracy.worst[{metric_name}]", + ) + + return 0, "default(0)" + + @staticmethod + def _safe_float(value: Any, default: float = 0.0) -> float: + try: + parsed = float(value) + except (TypeError, ValueError): + return default + return parsed if math.isfinite(parsed) else default + + @staticmethod + def _normalize_sample(sample: Any) -> Tuple[Any, ...]: + if isinstance(sample, tuple): + return sample + if isinstance(sample, list): + return tuple(sample) + return (sample,) + + @classmethod + def _capture_outputs( + cls, + graph_module: torch.fx.GraphModule, + sample: Tuple[Any, ...], + ) -> Dict[str, Any]: + capturer = _NodeOutputCapturer(graph_module) + with torch.no_grad(): + capturer.run(*sample) + return capturer.outputs + + @staticmethod + def _flatten_for_metric(value: Any) -> Tuple[Optional[torch.Tensor], str]: + if isinstance(value, torch.Tensor): + return value.detach().cpu().to(torch.float64).reshape(-1), str( + tuple(value.shape) + ) + + if isinstance(value, (tuple, list)): + tensors = [ + v.detach().cpu().to(torch.float64).reshape(-1) + for v in value + if isinstance(v, torch.Tensor) + ] + if tensors: + shape = ( + "[" + + ", ".join( + str(tuple(v.shape)) + for v in value + if isinstance(v, torch.Tensor) + ) + + "]" + ) + return torch.cat(tensors), shape + scalars = [float(v) for v in value if isinstance(v, (int, float, bool))] + if scalars: + return ( + torch.tensor(scalars, dtype=torch.float64), + f"list(len={len(scalars)})", + ) + return None, "unsupported_sequence" + + if isinstance(value, (int, float, bool)): + return torch.tensor([float(value)], dtype=torch.float64), "scalar" + + return None, f"unsupported:{type(value).__name__}" + + @classmethod + def _compute_pair_metrics( + cls, + anchor_value: Any, + target_value: Any, + ) -> Optional[Dict[str, Any]]: + anchor_flat, anchor_shape = cls._flatten_for_metric(anchor_value) + target_flat, target_shape = cls._flatten_for_metric(target_value) + if anchor_flat is None or target_flat is None: + return None + + compared = min(anchor_flat.numel(), target_flat.numel()) + if compared <= 0: + return None + + anchor_vec = torch.nan_to_num( + anchor_flat[:compared], nan=0.0, posinf=0.0, neginf=0.0 + ) + target_vec = torch.nan_to_num( + target_flat[:compared], nan=0.0, posinf=0.0, neginf=0.0 + ) + + predictions = [target_vec] + golden = [anchor_vec] + + psnr = cls._safe_float(PSNR(golden).calculate(predictions)) + cosine = cls._safe_float(CosineSimilarity(golden).calculate(predictions)) + mse = cls._safe_float(MSE(golden).calculate(predictions)) + abs_err = cls._safe_float(AbsErr(golden).calculate(predictions)) + + return { + "numel_compared": int(compared), + "anchor_shape": anchor_shape, + "target_shape": target_shape, + "psnr": float(psnr), + "cosine_sim": float(cosine), + "mse": float(mse), + "abs_err": float(abs_err), + } + + @staticmethod + def _summarize_rows(rows: List[Dict[str, Any]]) -> Dict[str, float]: + if not rows: + return { + "psnr_mean": 0.0, + "psnr_min": 0.0, + "psnr_max": 0.0, + "cosine_sim_mean": 0.0, + "mse_mean": 0.0, + "abs_err_mean": 0.0, + } + + def _values(k: str) -> List[float]: + return [PerLayerAccuracyLens._safe_float(r.get(k, 0.0)) for r in rows] + + psnr_vals = _values("psnr") + return { + "psnr_mean": f"{float(sum(psnr_vals) / len(psnr_vals)):.4f}", + "psnr_min": f"{float(min(psnr_vals)):.4f}", + "psnr_max": f"{float(max(psnr_vals)):.4f}", + "cosine_sim_mean": f"{float(sum(_values('cosine_sim')) / len(rows)):.4f}", + "mse_mean": f"{float(sum(_values('mse')) / len(rows)):.4f}", + "abs_err_mean": f"{float(sum(_values('abs_err')) / len(rows)):.4f}", + } + + @classmethod + def observe(cls, artifact: Any, context: ObservationContext) -> Any: + graph_module = cls._to_graph_module(artifact) + if graph_module is None: + return None + + cfg = cls._lens_config(context) + anchor_name = str(cfg.get("anchor_record_name", "Exported Float")) + record_name = str(context.shared_state.get("record_name", "no_record_name!")) + + sparse_index = cls._build_sparse_node_index(graph_module) + if record_name == anchor_name: + cls._anchor_graph_module = graph_module + cls._anchor_record_name = record_name + cls._anchor_sparse_index = sparse_index + cls._anchor_outputs_cache = {} + + if cls._anchor_graph_module is None or not cls._anchor_sparse_index: + return { + "graph_ref": record_name, + "anchor_record": anchor_name, + "sample_index": 0, + "sample_source": "no_anchor", + "rows": [], + "summary": {}, + "match_count": 0, + "anchor_sparse_count": 0, + "target_sparse_count": len(sparse_index), + } + + dataset = cls._resolve_dataset() + if not dataset: + return { + "graph_ref": record_name, + "anchor_record": cls._anchor_record_name or anchor_name, + "sample_index": 0, + "sample_source": "no_dataset", + "rows": [], + "summary": {}, + "match_count": 0, + "anchor_sparse_count": len(cls._anchor_sparse_index), + "target_sparse_count": len(sparse_index), + } + + sample_index, sample_source = cls._pick_sample_index(dataset, cfg) + sample = cls._normalize_sample(dataset[sample_index]) + + if sample_index not in cls._anchor_outputs_cache: + cls._anchor_outputs_cache[sample_index] = cls._capture_outputs( + cls._anchor_graph_module, sample + ) + anchor_outputs = cls._anchor_outputs_cache[sample_index] + + if graph_module is cls._anchor_graph_module: + target_outputs = anchor_outputs + else: + target_outputs = cls._capture_outputs(graph_module, sample) + + matched_keys = sorted( + set(cls._anchor_sparse_index.keys()) & set(sparse_index.keys()) + ) + rows: List[Dict[str, Any]] = [] + for key in matched_keys: + anchor_ref = cls._anchor_sparse_index[key] + target_ref = sparse_index[key] + if ( + anchor_ref.node_id not in anchor_outputs + or target_ref.node_id not in target_outputs + ): + continue + + metrics = cls._compute_pair_metrics( + anchor_outputs[anchor_ref.node_id], + target_outputs[target_ref.node_id], + ) + if metrics is None: + continue + + rows.append( + { + "match_key": key, + "key_kind": target_ref.key_kind, + "from_node_root": target_ref.from_node_root, + "anchor_node": anchor_ref.node_id, + "target_node": target_ref.node_id, + "anchor_topo_index": anchor_ref.topo_index, + "target_topo_index": target_ref.topo_index, + **metrics, + } + ) + + return { + "graph_ref": record_name, + "anchor_record": cls._anchor_record_name or anchor_name, + "sample_index": sample_index, + "sample_source": sample_source, + "rows": rows, + "summary": cls._summarize_rows(rows), + "match_count": len(rows), + "anchor_sparse_count": len(cls._anchor_sparse_index), + "target_sparse_count": len(sparse_index), + } + + @classmethod + def digest(cls, observation: Any, context: ObservationContext) -> Any: + return observation + + @staticmethod + def _metric_specs() -> Dict[str, Dict[str, Any]]: + return { + # Lower value is worse for PSNR/Cosine. + "psnr": { + "name": "Per-layer PSNR", + "label": "PSNR", + "inverse": True, + }, + "cosine_sim": { + "name": "Per-layer Cosine Similarity", + "label": "Cosine", + "inverse": True, + }, + # Higher value is worse for error metrics. + "mse": { + "name": "Per-layer MSE", + "label": "MSE", + "inverse": False, + }, + "abs_err": { + "name": "Per-layer AbsErr", + "label": "AbsErr", + "inverse": False, + }, + } + + @classmethod + def _build_metric_extension( + cls, + rows: List[Dict[str, Any]], + metric_name: str, + *, + fixed_range: Optional[Tuple[float, float]] = None, + ) -> GraphExtension: + spec = cls._metric_specs().get(metric_name) + if not spec: + raise ValueError(f"Unsupported per-layer metric extension: {metric_name}") + + ext = GraphExtension(id=metric_name, name=str(spec["name"])) + for row in rows: + node_id = str(row["target_node"]) + info = { + "sparse_match_key": row.get("match_key", ""), + "key_kind": row.get("key_kind", ""), + "from_node_root": row.get("from_node_root", ""), + "anchor_node": row.get("anchor_node", ""), + "target_node": row.get("target_node", ""), + "anchor_topo_index": row.get("anchor_topo_index", -1), + "target_topo_index": row.get("target_topo_index", -1), + "numel_compared": row.get("numel_compared", 0), + "anchor_shape": row.get("anchor_shape", "n/a"), + "target_shape": row.get("target_shape", "n/a"), + "psnr": cls._safe_float(row.get("psnr", 0.0)), + "cosine_sim": cls._safe_float(row.get("cosine_sim", 0.0)), + "mse": cls._safe_float(row.get("mse", 0.0)), + "abs_err": cls._safe_float(row.get("abs_err", 0.0)), + } + ext.add_node_data(node_id, info) + + ext.set_sync_key("sparse_match_key") + + def _format_metric_value(value: float, *, tooltip: bool = False) -> str: + if metric_name in ("mse", "abs_err"): + return f"{value:.6e}" if tooltip else f"{value:.3e}" + return f"{value:.6f}" if tooltip else f"{value:.4f}" + + def _label_formatter(d: Dict[str, Any]) -> List[str]: + primary = cls._safe_float(d.get(metric_name, 0.0)) + primary_label = str(spec["label"]) + return [f"{primary_label}={_format_metric_value(primary)}"] + + ext.set_label_formatter(_label_formatter) + + def _tooltip_formatter(d: Dict[str, Any]) -> List[str]: + primary = cls._safe_float(d.get(metric_name, 0.0)) + primary_label = str(spec["label"]) + return [ + f"target_node={d.get('target_node', 'n/a')}", + f"match_key={d.get('sparse_match_key', '')}", + f"{primary_label}={_format_metric_value(primary, tooltip=True)}", + ] + + ext.set_tooltip_formatter(_tooltip_formatter) + ext.set_color_rule( + _MetricNumericColorRule( + attribute=metric_name, + # Severe values map to darker red. + low_rgb=(254, 224, 210), + high_rgb=(165, 15, 21), + inverse=bool(spec["inverse"]), + fixed_range=fixed_range, + ) + ) + return ext + + @staticmethod + def _aggregate_metric_ranges( + records: List[RecordDigest], + ) -> Dict[str, List[float]]: + """Union (min, max) per metric across every record's rows. + + Returned as list-of-floats for clean JSON round-tripping through + ``AnalysisResult.global_data``. + """ + pools: Dict[str, List[float]] = { + metric: [] for metric in PerLayerAccuracyLens._metric_specs().keys() + } + for record in records: + digest = record.data.get("per_layer_accuracy") + if not isinstance(digest, dict): + continue + rows = digest.get("rows") + if not isinstance(rows, list): + continue + for row in rows: + for metric in pools: + v = row.get(metric) + if isinstance(v, (int, float)): + fv = float(v) + if math.isfinite(fv): + pools[metric].append(fv) + + ranges: Dict[str, List[float]] = {} + for metric, vals in pools.items(): + if vals: + ranges[metric] = [min(vals), max(vals)] + return ranges + + @staticmethod + def analyze(records: List[RecordDigest], config: Dict[str, Any]) -> AnalysisResult: + result = AnalysisResult() + metric_ranges = PerLayerAccuracyLens._aggregate_metric_ranges(records) + if metric_ranges: + result.global_data["metric_ranges"] = metric_ranges + + for record in records: + digest = record.data.get("per_layer_accuracy") + if not isinstance(digest, dict): + continue + rows = digest.get("rows") + if not isinstance(rows, list) or not rows: + continue + + analysis = RecordAnalysis( + data={ + "match_count": digest.get("match_count", 0), + "sample_index": digest.get("sample_index", 0), + "sample_source": digest.get("sample_source", "n/a"), + } + ) + + for metric_name in ("cosine_sim",): + # TODO other options "psnr" "mse", "abs_err" + r = metric_ranges.get(metric_name) + fixed_range = (r[0], r[1]) if r else None + metric_ext = PerLayerAccuracyLens._build_metric_extension( + rows, metric_name, fixed_range=fixed_range + ) + analysis.add_graph_layer(metric_name, metric_ext) + + result.per_record_data[record.name] = analysis + + return result + + class _PerLayerAccuracyFrontend(Frontend): + @staticmethod + def _interp_color( + ratio: float, + low_rgb: Tuple[int, int, int], + high_rgb: Tuple[int, int, int], + ) -> str: + ratio = max(0.0, min(1.0, ratio)) + r = int(low_rgb[0] + (high_rgb[0] - low_rgb[0]) * ratio) + g = int(low_rgb[1] + (high_rgb[1] - low_rgb[1]) * ratio) + b = int(low_rgb[2] + (high_rgb[2] - low_rgb[2]) * ratio) + return f"#{r:02x}{g:02x}{b:02x}" + + @staticmethod + def _hex_to_rgb(hex_color: str) -> Tuple[int, int, int]: + h = hex_color.lstrip("#") + if len(h) != 6: + return (255, 255, 255) + return (int(h[0:2], 16), int(h[2:4], 16), int(h[4:6], 16)) + + @classmethod + def _text_color_for_bg(cls, hex_color: str) -> str: + r, g, b = cls._hex_to_rgb(hex_color) + # WCAG-ish luma heuristic. + luma = 0.299 * r + 0.587 * g + 0.114 * b + return "#111111" if luma > 150 else "#f8f8f8" + + @classmethod + def _metric_cell_style( + cls, + value: float, + vmin: float, + vmax: float, + *, + low_rgb: Tuple[int, int, int], + high_rgb: Tuple[int, int, int], + inverse: bool, + ) -> str: + value = PerLayerAccuracyLens._safe_float(value) + vmin = PerLayerAccuracyLens._safe_float(vmin) + vmax = PerLayerAccuracyLens._safe_float(vmax) + if vmax <= vmin: + ratio = 0.0 + else: + ratio = (value - vmin) / (vmax - vmin) + if not math.isfinite(ratio): + ratio = 0.0 + if inverse: + ratio = 1.0 - ratio + bg = cls._interp_color(ratio, low_rgb, high_rgb) + fg = cls._text_color_for_bg(bg) + return f"background:{bg};color:{fg};" + + @classmethod + def _merged_metrics_table_html( + cls, + rows: Iterable[Dict[str, Any]], + metric_ranges: Optional[Dict[str, List[float]]] = None, + ) -> str: + row_list = list(rows) + if not row_list: + return "
No matched nodes.
" + + # Worst -> best ranking uses PSNR primarily (lower is worse). + row_list.sort( + key=lambda r: ( + PerLayerAccuracyLens._safe_float(r.get("psnr", 0.0)), + -PerLayerAccuracyLens._safe_float(r.get("mse", 0.0)), + -PerLayerAccuracyLens._safe_float(r.get("abs_err", 0.0)), + PerLayerAccuracyLens._safe_float(r.get("cosine_sim", 0.0)), + ) + ) + + def _range(k: str) -> Tuple[float, float]: + if metric_ranges and k in metric_ranges: + r = metric_ranges[k] + if len(r) == 2: + return (float(r[0]), float(r[1])) + vals = [ + PerLayerAccuracyLens._safe_float(r.get(k, 0.0)) for r in row_list + ] + return (min(vals), max(vals)) if vals else (0.0, 0.0) + + psnr_min, psnr_max = _range("psnr") + cos_min, cos_max = _range("cosine_sim") + mse_min, mse_max = _range("mse") + abs_min, abs_max = _range("abs_err") + + parts = [ + "", + "", + "", + "", + "", + ] + + for rank, row in enumerate(row_list, start=1): + node = html.escape(str(row.get("target_node", ""))) + anchor = html.escape(str(row.get("anchor_node", ""))) + root = html.escape(str(row.get("from_node_root") or "n/a")) + psnr = PerLayerAccuracyLens._safe_float(row.get("psnr", 0.0)) + cosine = PerLayerAccuracyLens._safe_float(row.get("cosine_sim", 0.0)) + mse = PerLayerAccuracyLens._safe_float(row.get("mse", 0.0)) + abs_err = PerLayerAccuracyLens._safe_float(row.get("abs_err", 0.0)) + + psnr_style = cls._metric_cell_style( + psnr, + psnr_min, + psnr_max, + low_rgb=(254, 224, 210), + high_rgb=(165, 15, 21), + inverse=True, + ) + cos_style = cls._metric_cell_style( + cosine, + cos_min, + cos_max, + low_rgb=(254, 237, 222), + high_rgb=(217, 72, 1), + inverse=True, + ) + mse_style = cls._metric_cell_style( + mse, + mse_min, + mse_max, + low_rgb=(239, 243, 255), + high_rgb=(8, 81, 156), + inverse=False, + ) + abs_style = cls._metric_cell_style( + abs_err, + abs_min, + abs_max, + low_rgb=(242, 240, 247), + high_rgb=(84, 39, 143), + inverse=False, + ) + + parts.append( + "" + f"" + f"" + f"" + f"" + f"" + f"" + f"" + f"" + "" + ) + + parts.append("
#Target NodeAnchor NodeRootPSNRCosineMSEAbsErr
{rank}{node}{anchor}{root}{psnr:.4f}{cosine:.6f}{mse:.6e}{abs_err:.6e}
") + return "".join(parts) + + def resources(self) -> Dict[str, str]: + return { + "css": """ +.pla-metric-table { + width: 100%; + border-collapse: collapse; + font-size: 12px; +} +.pla-metric-table th, .pla-metric-table td { + border: 1px solid var(--border-color); + padding: 4px 6px; + text-align: left; +} +.pla-metric-table th { + background: var(--bg-tertiary); + color: var(--text-primary); +} +.pla-empty { + color: var(--text-secondary); + font-size: 12px; +} +""" + } + + def record( + self, digest: Any, analysis: Dict[str, Any], context: Dict[str, Any] + ) -> Optional[ViewList]: + if not isinstance(digest, dict): + return None + + rows = digest.get("rows") if isinstance(digest.get("rows"), list) else [] + summary = ( + digest.get("summary", {}) + if isinstance(digest.get("summary"), dict) + else {} + ) + graph_ref = str(digest.get("graph_ref") or context.get("name") or "") + lens_name = PerLayerAccuracyLens.get_name() + metric_ranges = None + if isinstance(analysis, dict): + raw = analysis.get("global") + if isinstance(raw, dict): + mr = raw.get("metric_ranges") + if isinstance(mr, dict): + metric_ranges = mr + + summary_data = { + "anchor_record": digest.get("anchor_record", "n/a"), + "sample_index": digest.get("sample_index", 0), + "sample_source": digest.get("sample_source", "n/a"), + "match_count": digest.get("match_count", 0), + "anchor_sparse_count": digest.get("anchor_sparse_count", 0), + "target_sparse_count": digest.get("target_sparse_count", 0), + **summary, + } + + blocks = [ + TableBlock( + id="per_layer_accuracy_summary", + title="Per-layer Accuracy Summary", + record=TableRecordSpec(data=summary_data), + order=20, + ), + GraphView( + id="per_layer_accuracy_graph", + title="Per-layer Accuracy Graph", + graph_ref=graph_ref, + default_layers=[f"{lens_name}/cosine_sim"], + default_color_by=f"{lens_name}/cosine_sim", + compare=GraphCompareSpec( + default_sync={ + "mode": "layer", + "layer": f"{lens_name}/cosine_sim", + "field": "sparse_match_key", + } + ), + order=21, + ).as_block(), + HtmlBlock( + id="per_layer_accuracy_metrics_table", + title="Per-layer Metrics (Worst → Best)", + record=HtmlRecordSpec( + content=self._merged_metrics_table_html(rows, metric_ranges) + ), + order=22, + ), + ] + return ViewList(blocks=blocks) + + def check_badges( + self, digest: Any, analysis: Dict[str, Any] + ) -> List[Dict[str, str]]: + if isinstance(digest, dict) and int(digest.get("match_count", 0)) > 0: + return [ + {"label": "PLA", "class": "badge", "title": "Per-layer accuracy"} + ] + return [] + + @staticmethod + def get_frontend_spec() -> Frontend: + return PerLayerAccuracyLens._PerLayerAccuracyFrontend() diff --git a/devtools/observatory/lenses/pipeline_graph_collector.py b/devtools/observatory/lenses/pipeline_graph_collector.py new file mode 100644 index 00000000000..1ed7a1e44ff --- /dev/null +++ b/devtools/observatory/lenses/pipeline_graph_collector.py @@ -0,0 +1,400 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Pipeline Graph Collector Lens — auto-collects graphs at compilation stages. + +This lens installs monkey-patches on framework-level functions to transparently +capture graph artifacts at each stage of the export → quantize → lower pipeline. +All patches are installed on session start and removed on session end. + +Collection points (in pipeline order): + 1. torch.export.export → "Exported Float" (ExportedProgram) + 2. prepare_pt2e → "Annotated Model" (GraphModule with observers) + 3. convert_pt2e (input) → "Calibrated Model" (post-calibration, pre-convert) + 4. convert_pt2e (output) → "Quantized Model" (GraphModule with Q/DQ ops) + 5. to_edge_transform_and_lower → "Pre-EdgeTransform/{method}" and "EdgeProgramManager EP" + 6. ETRecord.add_* → "ETRecord Exported/…", "ETRecord Edge/…", etc. + +Patching strategy: + - Framework-level patches (torchao, executorch.exir) work for ALL backends. + - Backend-specific patches are installed after framework-level patches to avoid + early module-import alias freezing. + - Contract: a backend-specific patch that emits "Exported Float" must also + populate `_last_calibration_dataset` so AccuracyLens can auto-configure. + - ETRecord patches fire when generate_etrecord=True (forced by the + to_edge_transform_and_lower patch). + - All originals are saved in _originals and restored on session end. +""" + +from __future__ import annotations + +import logging +from typing import Any, Callable, Dict, List, Optional + +from ..interfaces import AnalysisResult, Lens, ObservationContext, RecordDigest + + +class PipelineGraphCollectorLens(Lens): + """Unified graph collector — owns all graph-related monkey-patches.""" + + _installed: bool = False + _originals: Dict[str, Any] = {} + _collect_fn: Optional[Callable[[str, Any], None]] = None + # Cross-lens contract for AccuracyLens fallback dataset. + _last_calibration_dataset: Optional[list] = None + # Backend-specific patch installers registered via register_backend_patches(). + _backend_patch_installers: List[Callable] = [] + # Backend-specific uninstallers registered during patch installation. + _backend_uninstallers: List[Callable] = [] + + @classmethod + def register_backend_patches( + cls, installer: Callable[["PipelineGraphCollectorLens"], None] + ) -> None: + """Register a backend-specific patch installer. + + The installer receives the lens class and should use cls._originals, + cls._collect_fn, and cls._set_accuracy_fallback_dataset() for + standard integration. It may also append to cls._backend_uninstallers + to register cleanup logic. + """ + if installer not in cls._backend_patch_installers: + cls._backend_patch_installers.append(installer) + + @classmethod + def get_name(cls) -> str: + return "pipeline_graph_collector" + + @classmethod + def on_session_start(cls, context: ObservationContext) -> None: + if cls._installed: + return + + from ..observatory import Observatory + + cls._collect_fn = Observatory.collect + # Install backend-agnostic patches first. + cls._install_quantizer_patches() + cls._install_edge_lower_patch() + cls._install_etrecord_patches() + # Install backend-specific patches registered via register_backend_patches(). + for installer in cls._backend_patch_installers: + try: + installer(cls) + except Exception as exc: + logging.warning( + "[PipelineGraphCollector] Backend patch failed: %s", exc + ) + cls._installed = True + + @classmethod + def on_session_end(cls, context: ObservationContext) -> None: + cls._uninstall_all() + + @classmethod + def clear(cls) -> None: + cls._uninstall_all() + cls._last_calibration_dataset = None + cls._backend_patch_installers.clear() + cls._backend_uninstallers.clear() + + @classmethod + def observe(cls, artifact: Any, context: ObservationContext) -> Any: + return None + + @classmethod + def digest(cls, observation: Any, context: ObservationContext) -> Any: + return None + + @staticmethod + def analyze(records: List[RecordDigest], config: Dict[str, Any]) -> AnalysisResult: + return AnalysisResult() + + @classmethod + def _set_accuracy_fallback_dataset(cls, dataset: Any, source: str) -> None: + """Store dataset for AccuracyLens fallback. + + Backend-specific patch contract: + any patch that emits "Exported Float" should call this helper first. + """ + try: + dataset_list = list(dataset) if not isinstance(dataset, list) else dataset + if not dataset_list: + return + cls._last_calibration_dataset = dataset_list + logging.debug( + "[PipelineGraphCollector] Stored fallback dataset from %s (%d samples)", + source, + len(dataset_list), + ) + except Exception: + # Best-effort only; collection flow must not fail on dataset capture. + pass + + # ------------------------------------------------------------------ + # Patch: prepare_pt2e, convert_pt2e + # Captures annotated model (post-prepare) and quantized model + # (post-convert). Also captures the calibrated model (convert input). + # ------------------------------------------------------------------ + + @classmethod + def _install_quantizer_patches(cls) -> None: + try: + import torchao.quantization.pt2e.quantize_pt2e as qt_module + + # prepare_pt2e + original_prepare = qt_module.prepare_pt2e + cls._originals["prepare_pt2e"] = original_prepare + + def patched_prepare_pt2e(model, *args, **kwargs): + result = original_prepare(model, *args, **kwargs) + try: + cls._collect_fn("Annotated Model", result) + except Exception as exc: + logging.debug( + "[PipelineGraphCollector] collect skipped (Annotated Model): %s", + exc, + ) + return result + + qt_module.prepare_pt2e = patched_prepare_pt2e + logging.info("[PipelineGraphCollector] Installed patch: prepare_pt2e") + + # convert_pt2e — collect both input (calibrated) and output (quantized) + original_convert = qt_module.convert_pt2e + cls._originals["convert_pt2e"] = original_convert + + def patched_convert_pt2e(model, *args, **kwargs): + try: + cls._collect_fn("Calibrated Model", model) + except Exception as exc: + logging.debug( + "[PipelineGraphCollector] collect skipped (Calibrated Model): %s", + exc, + ) + result = original_convert(model, *args, **kwargs) + try: + cls._collect_fn("Quantized Model", result) + except Exception as exc: + logging.debug( + "[PipelineGraphCollector] collect skipped (Quantized Model): %s", + exc, + ) + return result + + qt_module.convert_pt2e = patched_convert_pt2e + logging.info("[PipelineGraphCollector] Installed patch: convert_pt2e") + except Exception as exc: + logging.warning( + "[PipelineGraphCollector] Failed to patch quantizer APIs: %s", exc + ) + + # ------------------------------------------------------------------ + # Patch: to_edge_transform_and_lower + # Forces generate_etrecord=True and collects: + # 1) pre-transform input programs, and + # 2) post-call EdgeProgramManager.exported_program(). + # ------------------------------------------------------------------ + + @classmethod + def _install_edge_lower_patch(cls) -> None: + try: + import executorch.exir.program._program as program_module + import executorch.exir as exir_module + + def _collect_pre_edge_transform_inputs(args, kwargs): + programs = kwargs.get("programs") + if programs is None and len(args) > 0: + programs = args[0] + if programs is None: + return + + if isinstance(programs, dict): + for method_name, program in programs.items(): + try: + cls._collect_fn(f"Pre-EdgeTransform/{method_name}", program) + except Exception as exc: + logging.debug( + "[PipelineGraphCollector] collect skipped (Pre-EdgeTransform/%s): %s", + method_name, + exc, + ) + else: + try: + cls._collect_fn("Pre-EdgeTransform/forward", programs) + except Exception as exc: + logging.debug( + "[PipelineGraphCollector] collect skipped (Pre-EdgeTransform/forward): %s", + exc, + ) + + def _make_patched_to_edge_transform_and_lower(original_fn): + def patched_to_edge_transform_and_lower(*args, **kwargs): + _collect_pre_edge_transform_inputs(args, kwargs) + kwargs["generate_etrecord"] = True + result = original_fn(*args, **kwargs) + try: + cls._collect_fn("EdgeProgramManager EP", result.exported_program()) + except Exception as exc: + logging.debug( + "[PipelineGraphCollector] collect skipped (EdgeProgramManager EP): %s", + exc, + ) + return result + + return patched_to_edge_transform_and_lower + + for i, module in enumerate([program_module, exir_module]): + original = module.to_edge_transform_and_lower + cls._originals[f"to_edge_transform_and_lower_{i}"] = original + module.to_edge_transform_and_lower = _make_patched_to_edge_transform_and_lower( + original + ) + logging.info( + "[PipelineGraphCollector] Installed patch: to_edge_transform_and_lower" + ) + except Exception as exc: + logging.warning( + "[PipelineGraphCollector] Failed to patch to_edge_transform_and_lower: %s", + exc, + ) + + # ------------------------------------------------------------------ + # Patch: ETRecord methods + # Auto-collects graph observations when ETRecord APIs are called. + # Absorbed from the former auto_collect.py module. + # ------------------------------------------------------------------ + + @classmethod + def _install_etrecord_patches(cls) -> None: + try: + from executorch.devtools.etrecord._etrecord import ETRecord + except Exception as exc: + logging.warning( + "[PipelineGraphCollector] Failed to import ETRecord: %s", exc + ) + return + + collect = cls._collect_fn + + def _safe_collect(name: str, artifact: Any) -> None: + try: + collect(name, artifact) + except Exception as exc: + logging.debug( + "[PipelineGraphCollector] ETRecord auto-collect skipped (%s): %s", + name, + exc, + ) + + def _wrap_add_exported_program(original): + def wrapped(self, exported_program): + result = original(self, exported_program) + if exported_program is None: + return result + if isinstance(exported_program, dict): + for method_name, program in exported_program.items(): + _safe_collect(f"ETRecord Exported/{method_name}", program) + else: + _safe_collect("ETRecord Exported/forward", exported_program) + return result + + return wrapped + + def _wrap_add_edge_dialect_program(original): + def wrapped(self, edge_dialect_program): + result = original(self, edge_dialect_program) + processed = getattr(self, "edge_dialect_program", None) + if isinstance(processed, dict): + for method_name, program in processed.items(): + _safe_collect(f"ETRecord Edge/{method_name}", program) + elif processed is not None: + _safe_collect("ETRecord Edge/forward", processed) + return result + + return wrapped + + def _wrap_add_extra_export_modules(original): + def wrapped(self, extra_recorded_export_modules): + result = original(self, extra_recorded_export_modules) + graph_map = getattr(self, "graph_map", {}) or {} + for module_name, program in graph_map.items(): + _safe_collect(f"ETRecord Extra/{module_name}", program) + return result + + return wrapped + + patches = { + "add_exported_program": _wrap_add_exported_program, + "add_edge_dialect_program": _wrap_add_edge_dialect_program, + "add_extra_export_modules": _wrap_add_extra_export_modules, + } + + for method_name, wrap_builder in patches.items(): + original = getattr(ETRecord, method_name, None) + if original is None: + continue + cls._originals[f"ETRecord.{method_name}"] = original + setattr(ETRecord, method_name, wrap_builder(original)) + + logging.info("[PipelineGraphCollector] Installed ETRecord patches") + + # ------------------------------------------------------------------ + # Uninstall all patches + # ------------------------------------------------------------------ + + @classmethod + def _uninstall_all(cls) -> None: + if not cls._installed: + return + + for key, original in cls._originals.items(): + try: + if key == "prepare_pt2e": + import torchao.quantization.pt2e.quantize_pt2e as qt_module + + qt_module.prepare_pt2e = original + elif key == "convert_pt2e": + import torchao.quantization.pt2e.quantize_pt2e as qt_module + + qt_module.convert_pt2e = original + elif key.startswith("to_edge_transform_and_lower"): + import executorch.exir.program._program as program_module + import executorch.exir as exir_module + for i, module in enumerate([program_module, exir_module]): + if str(i) == key[-1]: + module.to_edge_transform_and_lower = original + + elif key.startswith("ETRecord."): + try: + from executorch.devtools.etrecord._etrecord import ETRecord + + method_name = key.split(".", 1)[1] + setattr(ETRecord, method_name, original) + except Exception: + pass + else: + # Backend-specific patches store (module_attr, module) tuples + # or are handled by their own uninstall logic via _originals. + # Generic fallback: skip keys we don't recognize. + pass + except Exception as exc: + logging.warning( + "[PipelineGraphCollector] Failed to restore %s: %s", key, exc + ) + + cls._originals.clear() + cls._collect_fn = None + cls._last_calibration_dataset = None + for uninstaller in cls._backend_uninstallers: + try: + uninstaller() + except Exception as exc: + logging.warning( + "[PipelineGraphCollector] Backend uninstall failed: %s", exc + ) + cls._installed = False + logging.info("[PipelineGraphCollector] Uninstalled all patches") diff --git a/devtools/observatory/lenses/stack_trace.py b/devtools/observatory/lenses/stack_trace.py new file mode 100644 index 00000000000..81dc21aef21 --- /dev/null +++ b/devtools/observatory/lenses/stack_trace.py @@ -0,0 +1,128 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +import inspect +import logging +import os +from typing import Any, Dict, List + +from ..interfaces import Frontend, HtmlBlock, HtmlRecordSpec, Lens, ObservationContext, ViewList +from ..utils import get_git_info, get_repo_root, is_in_repo + + +class StackTraceLens(Lens): + """Collects repository-local stack trace frames.""" + + @classmethod + def get_name(cls) -> str: + return "stack_trace" + + @classmethod + def _get_stack_trace(cls) -> List[Dict[str, Any]]: + repo_root = get_repo_root() + git_info = get_git_info() + + frames = [] + for frame_info in inspect.stack(): + if not is_in_repo(frame_info.filename): + continue + if "/observatory/observatory.py" in frame_info.filename.replace("\\", "/"): + continue + + github_link = None + link_root = git_info.commit_blob_url or git_info.branch_blob_url or git_info.github_link + if link_root and repo_root: + try: + rel_path = os.path.relpath(frame_info.filename, repo_root) + github_link = f"{link_root}/{rel_path}#L{frame_info.lineno}" + except Exception: + pass + + rel_path = frame_info.filename + if repo_root and frame_info.filename.startswith(repo_root): + rel_path = os.path.relpath(frame_info.filename, repo_root) + + frames.append( + { + "function": frame_info.function, + "filename": os.path.basename(rel_path), + "dir": os.path.dirname(rel_path), + "line": frame_info.lineno, + "context": frame_info.code_context[0].strip() if frame_info.code_context else None, + "link": github_link, + } + ) + + frames.reverse() + return frames + + @classmethod + def observe(cls, artifact: Any, context: ObservationContext) -> Any: + try: + return cls._get_stack_trace() + except Exception as exc: + logging.warning("[Observatory] Failed to collect stack trace: %s", exc) + return [] + + @classmethod + def digest(cls, observation: Any, context: ObservationContext) -> Any: + return observation + + class StackTraceFrontend(Frontend): + def record(self, digest, analysis, context): + if not digest: + return ViewList( + blocks=[ + HtmlBlock( + id="stack_trace_record", + title="Stack Trace", + record=HtmlRecordSpec(content="
No stack trace available
"), + order=40, + ) + ] + ) + + html = ["
"] + for frame in digest: + link_prefix = ( + f'' + if frame.get("link") + else "" + ) + link_suffix = "" if frame.get("link") else "" + snippet = "" + if frame.get("context"): + snippet = ( + "
" + f"{frame['context']}" + "
" + ) + + html.append( + "
" + f"
{frame['function']}
" + f"
{link_prefix}{frame['dir']}/{frame['filename']}:{frame['line']}{link_suffix}
" + f"{snippet}
" + ) + html.append("
") + + return ViewList( + blocks=[ + HtmlBlock( + id="stack_trace_record", + title="Stack Trace", + record=HtmlRecordSpec(content="".join(html)), + order=40, + ) + ] + ) + + @staticmethod + def get_frontend_spec() -> Frontend: + return StackTraceLens.StackTraceFrontend() diff --git a/devtools/observatory/observatory.py b/devtools/observatory/observatory.py new file mode 100644 index 00000000000..3710ed690d5 --- /dev/null +++ b/devtools/observatory/observatory.py @@ -0,0 +1,607 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Observatory runtime core. + +Lifecycle summary: +1. Runtime capture: `observe -> digest`. +2. Analysis: per-lens `analyze(records, config)`. +3. Assembly: merge frontend blocks + graph assets/layers. +4. Rendering/export: JSON and HTML reports. + +The runtime enforces strict typed interfaces from `interfaces.py`. +""" + +from __future__ import annotations + +import base64 +import copy +import gzip +import json +import logging +import math +import os +import time +import traceback +from contextlib import contextmanager +from dataclasses import asdict +from datetime import datetime +from json.encoder import _make_iterencode, encode_basestring, encode_basestring_ascii +from typing import Any, ContextManager, Dict, List, Optional, Set, Type + +from executorch.devtools.fx_viewer.exporter import FXGraphExporter + +from .graph_hub import GraphHub +from .interfaces import ( + AnalysisResult, + Frontend, + Lens, + ObservationContext, + RecordAnalysis, + RecordDigest, + SessionResult, + ViewList, + validate_view_list, +) + + +class _NonFiniteFloatAsStringJSONEncoder(json.JSONEncoder): + """JSON encoder that emits non-finite floats as strings.""" + + def iterencode(self, o: Any, _one_shot: bool = False): + markers = {} if self.check_circular else None + _encoder = encode_basestring_ascii if self.ensure_ascii else encode_basestring + + def floatstr( + value: float, + _repr=float.__repr__, + ) -> str: + if math.isnan(value): + return '"nan"' + if math.isinf(value): + return '"inf"' if value > 0 else '"-inf"' + return _repr(value) + + _iterencode = _make_iterencode( + markers, + self.default, + _encoder, + self.indent, + floatstr, + self.key_separator, + self.item_separator, + self.sort_keys, + self.skipkeys, + _one_shot, + ) + return _iterencode(o, 0) + + +class Observatory: + """Global registry for collecting and rendering observability artifacts.""" + + _records: Dict[str, RecordDigest] = {} + _ignored_graphs: Set[str] = set() + _session_result: SessionResult = SessionResult() + _lens_registry: List[Type[Lens]] = [] + _lenses_initialized: bool = False + _config_stack: List[Dict[str, Any]] = [] + + @classmethod + def register_lens(cls, lens_cls: Type[Lens], append=True) -> None: + """Register lens class and run one-time setup.""" + + if lens_cls in cls._lens_registry: + return + if append: + cls._lens_registry.append(lens_cls) + else: + cls._lens_registry.insert(0, lens_cls) + try: + lens_cls.setup() + except Exception as exc: + logging.error("[Observatory] Failed to setup lens %s: %s", lens_cls, exc) + + @classmethod + def _ensure_default_lenses(cls) -> None: + """Lazy-register built-in lenses for the minimal observatory runtime.""" + + if cls._lenses_initialized: + return + + from .lenses.graph import GraphLens + from .lenses.graph_color import GraphColorLens + from .lenses.metadata import MetadataLens + from .lenses.stack_trace import StackTraceLens + + # defaut lenses should stay in front in case other lenses depends on their data + cls.register_lens(GraphColorLens, append=False) + cls.register_lens(StackTraceLens, append=False) + cls.register_lens(GraphLens, append=False) + cls.register_lens(MetadataLens, append=False) + cls._lenses_initialized = True + + @classmethod + @contextmanager + def enable_context(cls, config: Optional[Dict[str, Any]] = None) -> ContextManager[None]: + """Enable observation context with nested config overrides. + + Session hooks run once per outermost context: + 1. On first enter, auto-collection patches are installed and + `on_session_start` hooks are called. + 2. On last exit, `on_session_end` hooks are called and patches removed. + """ + + cls._ensure_default_lenses() + + def merge_config_dict(base: Dict[str, Any], new: Dict[str, Any]) -> Dict[str, Any]: + result = copy.copy(base) + result.update({k: copy.copy(v) for k, v in base.items() if isinstance(v, dict)}) + for key, value in new.items(): + if isinstance(value, dict) and isinstance(result.get(key), dict): + result[key].update(value) + else: + result[key] = value + return result + + parent_config = cls._config_stack[-1] if cls._config_stack else {} + context_config = merge_config_dict(parent_config, config or {}) + + is_outermost_start = len(cls._config_stack) == 0 + cls._config_stack.append(context_config) + hook_ctx = ObservationContext(config=context_config) + + if is_outermost_start: + for lens in cls._lens_registry: + try: + data = lens.on_session_start(hook_ctx) + if data: + cls._session_result.start_data[lens.get_name()] = data + except Exception as exc: + logging.error("[Observatory] Lens %s failed on_session_start: %s", lens, exc) + + try: + yield + finally: + is_outermost_end = len(cls._config_stack) == 1 + + if is_outermost_end: + for lens in cls._lens_registry: + try: + data = lens.on_session_end(hook_ctx) + if data: + cls._session_result.end_data[lens.get_name()] = data + except Exception as exc: + logging.error("[Observatory] Lens %s failed on_session_end: %s", lens, exc) + + cls._config_stack.pop() + + @classmethod + def _get_current_context(cls) -> Optional[ObservationContext]: + """Return current observation context or None if disabled.""" + + if not cls._config_stack: + return None + return ObservationContext(config=cls._config_stack[-1]) + + @classmethod + def ignore_graphs(cls, names: List[str]) -> None: + """Ignore future collect calls with matching names and drop existing records.""" + + for name in names: + cls._ignored_graphs.add(name) + if name in cls._records: + del cls._records[name] + + @classmethod + def collect(cls, name: str, artifact: Any) -> None: + """Capture one record across all registered lenses. + + Notes: + 1. No-op when context is disabled. + 2. Record name is exposed via `context.shared_state['record_name']`. + 3. Duplicate names are auto-suffixed with #2, #3, … to prevent overwrites. + """ + + if any(ignored in name for ignored in cls._ignored_graphs): + return + + if not cls._config_stack: + return + + # Deduplicate: if name exists, suffix with #2, #3, ... + if name in cls._records: + n = 2 + while f"{name} #{n}" in cls._records: + n += 1 + name = f"{name} #{n}" + + active_config = cls._config_stack[-1] + ctx = ObservationContext(config=active_config) + ctx.shared_state["record_name"] = name + + record = RecordDigest(name=name, timestamp=datetime.now().timestamp()) + t_start = time.perf_counter() + + for lens in cls._lens_registry: + try: + lens_name = lens.get_name() + observation = lens.observe(artifact, ctx) + if observation is None: + continue + digest = lens.digest(observation, ctx) + if digest is not None: + record.data[lens_name] = digest + except Exception as exc: + logging.error("[Observatory] Lens %s failed collection for %s: %s", lens, name, exc) + + cls._records[name] = record + elapsed_ms = (time.perf_counter() - t_start) * 1000.0 + logging.info("[Observatory] Collected %s in %.1f ms", name, elapsed_ms) + + @classmethod + def list_collected(cls) -> List[str]: + return list(cls._records.keys()) + + @classmethod + def get(cls, name: str) -> Optional[RecordDigest]: + return cls._records.get(name) + + @classmethod + def clear(cls) -> None: + """Clear records/session state and reset lens runtime state.""" + + cls._records.clear() + cls._session_result = SessionResult() + + for lens in cls._lens_registry: + try: + lens.clear() + except Exception as exc: + logging.error("[Observatory] Lens %s failed clear: %s", lens, exc) + + @staticmethod + def _serialize_view_list(result: Any) -> Optional[Dict[str, Any]]: + """Validate and serialize frontend return values.""" + + if result is None: + return None + + if not isinstance(result, ViewList): + raise TypeError(f"Frontend must return ViewList, got {type(result)}") + + validate_view_list(result) + return {"blocks": [asdict(block) for block in result.blocks]} + + @classmethod + def _safe_frontend_call(cls, lens_name: str, method: Any, *args: Any, **kwargs: Any) -> Optional[Dict[str, Any]]: + """Run frontend method with error isolation and fallback error block.""" + + try: + result = method(*args, **kwargs) + return cls._serialize_view_list(result) + except Exception as exc: + logging.error( + "[Observatory] Frontend %s.%s failed: %s\n%s", + lens_name, + getattr(method, "__name__", ""), + exc, + traceback.format_exc(), + ) + return { + "blocks": [ + { + "id": "frontend_error", + "title": "Frontend Error", + "type": "html", + "record": { + "content": ( + '
' + f"Error: {str(exc)}
" + ) + }, + "compare": {"mode": "disabled"}, + "order": 999, + "collapsible": True, + } + ] + } + + @staticmethod + def _encode_html_blocks(serialized_records: list, dashboard: dict) -> None: + """Base64-encode HtmlBlock.content strings in-place to prevent JSON corruption.""" + + def _encode_blocks(blocks: list) -> None: + for block in blocks: + if block.get("type") == "html": + content = (block.get("record") or {}).get("content", "") + if content: + block["record"]["content"] = base64.b64encode( + content.encode("utf-8") + ).decode("ascii") + + for record in serialized_records: + for view in (record.get("views") or {}).values(): + _encode_blocks(view.get("blocks") or []) + for view in dashboard.values(): + _encode_blocks(view.get("blocks") or []) + + @staticmethod + def _compress_payload(json_data: str, threshold: int = 8192) -> tuple: + """Gzip+base64 compress JSON payload if above threshold bytes.""" + + raw = json_data.encode("utf-8") + if len(raw) >= threshold: + compressed = gzip.compress(raw, compresslevel=6) + return base64.b64encode(compressed).decode("ascii"), True + return json_data, False + + @classmethod + def _generate_report_payload( + cls, + records: List[RecordDigest], + session: SessionResult, + config: Dict[str, Any], + lens_registry: List[Type[Lens]], + ) -> Dict[str, Any]: + """Build full report payload including graph assets/layers and views.""" + + analysis_results: Dict[str, AnalysisResult] = { + lens.get_name(): lens.analyze(records, config) for lens in lens_registry + } + + resources: Dict[str, List[str]] = {"js": [], "css": []} + try: + resources["js"].append(FXGraphExporter._load_viewer_js_bundle()) + except Exception as exc: + logging.warning("[Observatory] Failed loading fx_viewer runtime bundle: %s", exc) + + for lens in lens_registry: + frontend = lens.get_frontend_spec() + res = frontend.resources() if isinstance(frontend, Frontend) else {} + if res.get("js"): + resources["js"].append(res["js"]) + if res.get("css"): + resources["css"].append(res["css"]) + + resources["js"] = [ + base64.b64encode(s.encode("utf-8")).decode("ascii") for s in resources["js"] + ] + resources["css"] = [ + base64.b64encode(s.encode("utf-8")).decode("ascii") for s in resources["css"] + ] + + graph_hub = GraphHub() + serialized_records = [] + + for i, record in enumerate(records): + serialized = { + "name": record.name, + "timestamp": datetime.fromtimestamp(record.timestamp).strftime("%Y-%m-%d %H:%M:%S"), + "views": {}, + "badges": [], + "diff_index": {}, + "digests": record.data, + } + + for lens in lens_registry: + lens_name = lens.get_name() + digest = record.data.get(lens_name) + + analysis = analysis_results.get(lens_name, AnalysisResult()) + record_analysis: RecordAnalysis | None = analysis.per_record_data.get(record.name) + graph_ref = record.name + # extract graph data from graph lens runtime digest + if lens_name == "graph" and digest is not None: + assert isinstance(digest, dict) and isinstance(digest.get("base"), dict), "[Observatory] error validating graph lense output." + assert digest["graph_ref"] == graph_ref, "[Observatory] graph ref should be consistant with record name" + graph_hub.register_asset( + graph_ref, + digest["base"], + digest.get("meta", {}), + ) + + # Merge analyze-phase graph layers even for analyze-only lenses. + graph_hub.add_analysis_layers(graph_ref, lens_name, record_analysis) + + # Everything below depends on runtime digest presence. + if digest is None: + continue + + analysis_ctx = { + "global": analysis.global_data, + "record": (record_analysis.data if record_analysis else {}), + } + + frontend = lens.get_frontend_spec() + try: + serialized["badges"].extend(frontend.check_badges(digest, analysis.global_data)) + except Exception as exc: + logging.error("[Observatory] check_badges failed for %s: %s", lens_name, exc) + + if i > 0: + prev_digest = records[i - 1].data.get(lens_name) + if prev_digest is not None: + try: + serialized["diff_index"].update( + frontend.check_index_diffs(prev_digest, digest, analysis.global_data) + ) + except Exception as exc: + logging.error("[Observatory] check_index_diffs failed for %s: %s", lens_name, exc) + + serialized_view = cls._safe_frontend_call( + lens_name, + frontend.record, + digest, + analysis_ctx, + {"index": i, "name": record.name}, + ) + if serialized_view: + serialized["views"][lens_name] = serialized_view + + serialized_records.append(serialized) + + dashboard_views = {} + for lens in lens_registry: + lens_name = lens.get_name() + frontend = lens.get_frontend_spec() + dashboard_view = cls._safe_frontend_call( + lens_name, + frontend.dashboard, + session.start_data.get(lens_name, {}), + session.end_data.get(lens_name, {}), + analysis_results.get(lens_name, AnalysisResult()).global_data, + records, + ) + if dashboard_view: + dashboard_views[lens_name] = dashboard_view + + graph_payload = graph_hub.build_payload() + graph_assets = graph_payload["graph_assets"] + graph_layers = graph_payload["graph_layers"] + for graph_ref, asset in graph_assets.items(): + if not isinstance(asset, dict): + continue + base_payload = asset.get("base") + if not isinstance(base_payload, dict): + continue + + extensions_payload = graph_layers.get(graph_ref, {}) + if not isinstance(extensions_payload, dict) or not extensions_payload: + continue + + try: + asset["base"] = FXGraphExporter.relayout_payload_base( + base_payload, + extensions_payload=extensions_payload, + ) + except Exception as exc: + logging.warning( + "[Observatory] FX relayout failed for graph_ref=%s: %s", + graph_ref, + exc, + ) + + payload = { + "resources": resources, + "records": serialized_records, + "dashboard": dashboard_views, + "analysis_results": { + key: { + "global_data": value.global_data, + "per_record_data": { + rec_name: {"data": rec_analysis.data} + for rec_name, rec_analysis in value.per_record_data.items() + }, + } + for key, value in analysis_results.items() + }, + "session": { + "start_data": session.start_data, + "end_data": session.end_data, + }, + "graph_assets": graph_assets, + "graph_layers": graph_layers, + } + + Observatory._encode_html_blocks(serialized_records, dashboard_views) + return payload + + @classmethod + def export_html_report( + cls, + output_path: str, + title: str = "Observatory Report", + config: Optional[Dict[str, Any]] = None, + ) -> None: + """Export collected records to HTML report.""" + + if not cls._records: + logging.warning("[Observatory] No records collected, skipping HTML export") + return + + cls._ensure_default_lenses() + payload = cls._generate_report_payload( + list(cls._records.values()), + cls._session_result, + config or {}, + cls._lens_registry, + ) + payload["title"] = title + payload["generated_at"] = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + + from .html_template import get_html_template + + json_data = json.dumps(payload, cls=_NonFiniteFloatAsStringJSONEncoder).replace( + " None: + """Export raw records/session data as JSON.""" + + if not cls._records: + logging.warning("[Observatory] No records collected, skipping JSON export") + return + + data = { + "records": [asdict(r) for r in cls._records.values()], + "session": asdict(cls._session_result), + } + + os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True) + with open(output_path, "w", encoding="utf-8") as f: + json.dump(data, f, indent=2, cls=_NonFiniteFloatAsStringJSONEncoder) + + logging.info("[Observatory] Exported raw data to %s", output_path) + + @staticmethod + def generate_html_from_json( + json_path: str, + html_path: str, + title: str = "Observatory Report", + config: Optional[Dict[str, Any]] = None, + ) -> None: + """Generate HTML report from previously exported raw JSON.""" + + with open(json_path, "r", encoding="utf-8") as f: + data = json.load(f) + + records = [RecordDigest(**r) for r in data["records"]] + session = SessionResult(**data["session"]) + + Observatory._ensure_default_lenses() + payload = Observatory._generate_report_payload( + records, + session, + config or {}, + Observatory._lens_registry, + ) + payload["title"] = title + payload["generated_at"] = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + + from .html_template import get_html_template + + json_data = json.dumps(payload, cls=_NonFiniteFloatAsStringJSONEncoder).replace( + " Any: + """Wrap a pass to auto-collect its graph via Observatory. + + Works as a class decorator, instance wrapper, or function wrapper. + + Args: + target: A PassBase subclass (class decorator), a pass instance, or a + callable pass. When ``None``, returns a parameterized decorator. + name: Override the record name (default: derived from class/function name). + collect_input: Collect the input graph before the pass runs (default True). + collect_output: Collect the output graph after the pass runs (default True). + """ + collect_both = collect_input and collect_output + + def _base_name(obj: Any) -> str: + if name: + return name + if isinstance(obj, type): + return obj.__name__ + cls_name = type(obj).__name__ + if cls_name == "function": + return getattr(obj, "__name__", "pass") + return cls_name + + def _collect_artifact(record_name: str, artifact: Any) -> None: + from .observatory import Observatory + + try: + Observatory.collect(record_name, artifact) + except Exception as exc: + logging.debug("[observe_pass] collection failed for %s: %s", record_name, exc) + + def _output_graph(gm: Any, result: Any) -> Any: + if isinstance(result, tuple) and hasattr(result, "graph_module"): + return result.graph_module + if result is None: + return gm + return result + + def _wrap_callable(fn: Callable) -> Callable: + @functools.wraps(fn) + def wrapper(gm: Any, *args: Any, **kwargs: Any) -> Any: + base = _base_name(fn) + if collect_input: + _collect_artifact(f"{base}/input" if collect_both else base, gm) + result = fn(gm, *args, **kwargs) + if collect_output: + _collect_artifact( + f"{base}/output" if collect_both else base, + _output_graph(gm, result), + ) + return result + + return wrapper + + def _wrap_class(cls: Type) -> Type: + original_call = cls.__call__ + + @functools.wraps(original_call) + def patched_call(self: Any, gm: Any, *args: Any, **kwargs: Any) -> Any: + base = _base_name(self) + if collect_input: + _collect_artifact(f"{base}/input" if collect_both else base, gm) + result = original_call(self, gm, *args, **kwargs) + if collect_output: + _collect_artifact( + f"{base}/output" if collect_both else base, + _output_graph(gm, result), + ) + return result + + cls.__call__ = patched_call + return cls + + # Dispatch based on how observe_pass was called. + if target is None: + # Parameterized: @observe_pass(name="X", collect_output=False) + def decorator(t: Any) -> Any: + if isinstance(t, type): + return _wrap_class(t) + return _wrap_callable(t) + + return decorator + + if isinstance(target, type): + return _wrap_class(target) + + return _wrap_callable(target) diff --git a/devtools/observatory/template_loader.py b/devtools/observatory/template_loader.py new file mode 100644 index 00000000000..929f13a3a89 --- /dev/null +++ b/devtools/observatory/template_loader.py @@ -0,0 +1,44 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +import os +from typing import List + + +_TEMPLATE_DIR = os.path.join(os.path.dirname(__file__), "templates") + + +def _read_file(path: str) -> str: + with open(path, "r", encoding="utf-8") as f: + return f.read() + + +def load_css() -> str: + """Load base observatory CSS.""" + + return _read_file(os.path.join(_TEMPLATE_DIR, "css", "main.css")) + + +def load_js_chunks() -> List[str]: + """Load ordered observatory JS runtime chunks.""" + + ordered = [ + "00_state.js", + "01_utils.js", + "02_layout.js", + "03_blocks.js", + "04_actions.js", + "05_bootstrap_api.js", + ] + + chunks: List[str] = [] + for filename in ordered: + path = os.path.join(_TEMPLATE_DIR, "js", filename) + chunks.append(f"\n// ---- {filename} ----\n") + chunks.append(_read_file(path)) + return chunks diff --git a/devtools/observatory/templates/css/main.css b/devtools/observatory/templates/css/main.css new file mode 100644 index 00000000000..727ffb1e166 --- /dev/null +++ b/devtools/observatory/templates/css/main.css @@ -0,0 +1,402 @@ + :root { + /* Light mode colors */ + --bg-primary: #f6f8fa; + --bg-secondary: #ffffff; + --bg-tertiary: #fafbfc; + --bg-code: #f4f4fa; + --text-primary: #24292e; + --text-secondary: #6a737d; + --text-inverse: #ffffff; + --border-color: #e1e4e8; + --header-bg: #24292e; + --link-color: #0366d6; + --link-hover: #0256c7; + --accent-color: #0366d6; + --success-color: #28a745; + --error-color: #d73a49; + --code-text: #d4d4d4; + --shadow: rgba(0,0,0,0.1); + --frame-border: #0366d6; + --diff-add-bg: #1e3a24; + --diff-add-color: #8cc696; + --diff-rem-bg: #4a2626; + --diff-rem-color: #e09690; + --diff-deep-add: #2e5c35; + --diff-deep-rem: #693333; + --success-bg: #dafbe1; + --error-bg: #ffebe9; + } + + [data-theme="dark"] { + /* Dark mode colors */ + --bg-primary: #0d1117; + --bg-secondary: #161b22; + --bg-tertiary: #21262d; + --bg-code: #0d1117; + --text-primary: #c9d1d9; + --text-secondary: #8b949e; + --text-inverse: #c9d1d9; + --border-color: #30363d; + --header-bg: #161b22; + --link-color: #79c0ff; + --link-hover: #a5d6ff; + --accent-color: #58a6ff; + --success-color: #3fb950; + --error-color: #f85149; + --code-text: #c9d1d9; + --shadow: rgba(0,0,0,0.3); + --frame-border: #58a6ff; + --diff-add-bg: #1e3a24; + --diff-add-color: #8cc696; + --diff-rem-bg: #4a2626; + --diff-rem-color: #e09690; + --diff-deep-add: #3a7545; + --diff-deep-rem: #8a4444; + --success-bg: #1e3a24; + --error-bg: #4a2626; + } + + * { margin: 0; padding: 0; box-sizing: border-box; } + + body { + font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, Cantarell, sans-serif; + line-height: 1.6; + color: var(--text-primary); + background: var(--bg-primary); + height: 100vh; + overflow: hidden; + } + + /* App Layout */ + #app { + display: flex; + flex-direction: column; + height: 100vh; + } + + /* Header trigger zone — invisible strip at top */ + .header-trigger { + position: fixed; + top: 0; + left: 0; + right: 0; + height: 10px; + z-index: 201; + } + + header { + background: var(--bg-secondary); + color: var(--text-primary); + padding: 1rem 2rem; + display: flex; + justify-content: space-between; + align-items: center; + flex-shrink: 0; + border-bottom: 1px solid var(--border-color); + position: fixed; + top: 0; + left: 0; + right: 0; + transform: translateY(-100%); + transition: transform 0.25s ease, box-shadow 0.25s ease; + z-index: 200; + } + + .header-trigger:hover ~ header, + header:hover { + transform: translateY(0); + box-shadow: 0 4px 12px var(--shadow); + } + + header.hidden { + transform: translateY(-100%); + } + + [data-theme="dark"] header { + background: var(--header-bg); + color: var(--text-inverse); + border-bottom: none; + } + + .header-content h1 { font-size: 1.2rem; margin: 0; } + .header-meta { font-size: 0.85rem; opacity: 0.7; margin-top: 0.2rem; } + .header-meta span { margin-right: 1.5rem; } + .header-meta code { background: var(--bg-tertiary); padding: 0.2rem 0.4rem; border-radius: 3px; border: 1px solid var(--border-color); } + + [data-theme="dark"] .header-meta code { background: rgba(255,255,255,0.1); border: none; } + + /* Theme Toggle */ + .theme-toggle { + background: var(--bg-tertiary); + border: 1px solid var(--border-color); + color: var(--text-primary); + border-radius: 6px; + padding: 0.5rem 0.75rem; + cursor: pointer; + font-size: 1.2rem; + transition: all 0.2s; + display: flex; + align-items: center; + justify-content: center; + } + .theme-toggle:hover { background: var(--hover-color); transform: scale(1.05); } + + [data-theme="dark"] .theme-toggle { + background: rgba(255,255,255,0.1); + border: 1px solid rgba(255,255,255,0.2); + color: white; + } + [data-theme="dark"] .theme-toggle:hover { background: rgba(255,255,255,0.2); } + + .container { + display: flex; + flex: 1; + overflow: hidden; + } + + /* Sidebar trigger zone — invisible strip on left edge */ + .index-pane-trigger { + position: fixed; + left: 0; + top: 0; + bottom: 0; + width: 12px; + z-index: 200; + } + + /* Index Pane (Sidebar) */ + .index-pane { + position: fixed; + left: 0; + top: 0; + bottom: 0; + width: 300px; + background: var(--bg-secondary); + border-right: 1px solid var(--border-color); + display: flex; + flex-direction: column; + flex-shrink: 0; + z-index: 199; + transform: translateX(-100%); + transition: transform 0.25s ease, box-shadow 0.25s ease; + box-shadow: none; + } + + /* Show when trigger zone or pane itself is hovered */ + .index-pane-trigger:hover ~ .index-pane, + .index-pane:hover { + transform: translateX(0); + box-shadow: 4px 0 16px var(--shadow); + } + + .index-header { + padding: 1rem; + border-bottom: 1px solid var(--border-color); + background: var(--bg-tertiary); + } + .index-header h2 { font-size: 1rem; margin: 0; } + + .index-list { + flex: 1; + overflow-y: auto; + padding: 0.5rem; + list-style: none; + } + + .index-item { + padding: 0.6rem 0.8rem; + margin-bottom: 0.3rem; + border-radius: 6px; + cursor: pointer; + display: flex; + justify-content: space-between; + align-items: center; + transition: all 0.2s; + font-size: 0.9rem; + border-left: 3px solid transparent; + border: none; + background: var(--bg-primary); + box-shadow: 0 1px 2px rgba(0,0,0,0.05); + } + .index-item:hover { + background: var(--bg-tertiary); + box-shadow: 0 2px 4px var(--shadow); + transform: translateX(2px); + } + .index-item.active { + background: var(--accent-color); + color: white; + border-left-color: transparent; + border-color: var(--accent-color); + box-shadow: 0 2px 6px var(--shadow); + } + .index-item.diff-base { + border-left-color: var(--diff-rem-color); + background: rgba(224, 150, 144, 0.1); + border-color: var(--diff-rem-color); + } + .index-item.diff-new { + border-left-color: var(--diff-add-color); + background: rgba(140, 198, 150, 0.1); + border-color: var(--diff-add-color); + } + .index-item.selected { + background: rgba(3, 102, 214, 0.1); + border-left-color: var(--accent-color); + border-color: var(--accent-color); + } + .index-item.active .badge { background: rgba(255,255,255,0.2); border-color: rgba(255,255,255,0.3); color: white; } + + /* Badges */ + .badges { display: flex; gap: 0.3rem; flex-wrap: wrap; } + .badge { font-size: 0.7rem; padding: 0.1rem 0.3rem; border-radius: 3px; font-weight: 600; background: var(--bg-tertiary); border: 1px solid var(--border-color); color: var(--text-primary); } + .badge-error { background: var(--error-bg); border-color: var(--error-color); color: var(--error-color); } + .badge-success { background: var(--success-bg); border-color: var(--success-color); color: var(--success-color); } + + /* Diff Separator - Blends with background as spacing between records */ + .diff-separator { + padding: 0.6rem 1rem; + margin: 0.5rem 0; + background: var(--bg-secondary); + border: none; + cursor: pointer; + transition: all 0.2s; + list-style: none; + } + .diff-separator:hover { + background: var(--bg-tertiary); + border-radius: 4px; + } + .diff-content { + display: flex; + flex-direction: column; + gap: 0.25rem; + opacity: 0.7; + } + .diff-row { + display: flex; + justify-content: space-between; + align-items: center; + font-size: 0.75rem; + } + .diff-label { + font-weight: 500; + color: var(--text-secondary); + } + .diff-stats { + font-family: monospace; + display: flex; + gap: 0.4rem; + } + .stat-add { color: var(--success-color); font-weight: 600; } + .stat-rem { color: var(--error-color); font-weight: 600; } + + /* Main Content */ + .main-pane { + flex: 1; + overflow-y: auto; + padding: 2rem; + position: relative; + } + + /* Record View */ + .record-view h2 { color: var(--accent-color); border-bottom: 2px solid var(--border-color); padding-bottom: 0.5rem; margin-bottom: 1.5rem; } + + /* Toggleable Sections */ + .toggle-section { border: 1px solid var(--border-color); border-radius: 6px; overflow: hidden; margin-bottom: 1rem; } + .toggle-header { background: var(--bg-tertiary); padding: 0.75rem 1rem; font-size: 1rem; font-weight: 600; border-bottom: 1px solid var(--border-color); display: flex; justify-content: space-between; align-items: center; cursor: pointer; } + .toggle-header.collapsed { border-bottom: none; } + .toggle-header:hover { background: var(--border-color); } + .toggle-title { font-size: 1rem; } + .toggle-content { padding: 1rem; overflow-x: auto; background: var(--bg-secondary); } + .toggle-content.hidden { display: none; } + .copy-btn { + font-size: 0.75rem; + padding: 0.3rem 0.6rem; + border: none; + background: var(--bg-secondary); + color: var(--accent-fg); + border-radius: 4px; + cursor: pointer; + transition: all 0.2s; + font-weight: 500; + } + .copy-btn:hover { + background: var(--bg-primary); + transform: translateY(-3px); + box-shadow: 0 2px 4px var(--shadow); + } + .copy-btn.copied { + background: var(--success-color); + } + + /* Tables */ + table.kv-table { + width: 100%; + border-collapse: collapse; + font-size: 0.85rem; + } + .kv-table th, .kv-table td { + padding: 0.4rem 0.8rem; + text-align: left; + border-bottom: 1px solid var(--border-color); + } + .kv-table td { font-family: var(--font-mono); } + .kv-table tr:last-child th, .kv-table tr:last-child td { border-bottom: none; } + .kv-table th { color: var(--text-secondary); font-weight: normal; background: var(--bg-tertiary); } + + /* Regular KV tables - 40/60 split */ + .kv-table:not(.comparison-table) th { width: 40%; } + + /* Comparison tables - equal width columns */ + .kv-table.comparison-table { + table-layout: fixed; + } + .kv-table.comparison-table th, + .kv-table.comparison-table td { + width: auto; + } + + .split-view { display: flex; height: 100%; gap: 1rem; } + .split-pane { flex: 1; overflow: auto; border-right: 1px dashed var(--border-color); padding-right: 0.5rem; } + .split-pane:last-child { border-right: none; padding-right: 0; } + .split-pane h3 { font-size: 0.9rem; margin-bottom: 0.5rem; color: var(--accent-color); } + + /* Utility */ + .hidden { display: none; } + .btn { padding: 0.5rem 1rem; border-radius: 4px; cursor: pointer; border: 1px solid var(--border-color); background: var(--bg-tertiary); color: var(--text-primary); } + .btn:hover { background: var(--bg-secondary); } + .btn-sm { padding: 0.2rem 0.5rem; border-radius: 4px; cursor: pointer; border: 1px solid var(--border-color); background: var(--bg-tertiary); font-size: 0.8rem; margin-left: 0.5rem; color: var(--text-primary); } + .btn-sm:hover { background: var(--bg-secondary); } + .loading { padding: 2rem; text-align: center; color: var(--text-secondary); font-style: italic; } + + .clickable-name { cursor: pointer; color: var(--link-color); text-decoration: underline; } + .clickable-name:hover { color: var(--link-hover); } + + /* Toast Notifications */ + .toast { + position: fixed; + bottom: 2rem; + right: 2rem; + background: var(--bg-secondary); + color: var(--text-primary); + padding: 0.8rem 1.2rem; + border-radius: 6px; + box-shadow: 0 4px 12px var(--shadow); + border: 1px solid var(--border-color); + opacity: 0; + transform: translateY(1rem); + transition: all 0.3s ease; + z-index: 1000; + font-size: 0.9rem; + } + .toast.show { + opacity: 1; + transform: translateY(0); + } + .toast.toast-success { + border-left: 4px solid var(--success-color); + } + .toast.toast-error { + border-left: 4px solid var(--error-color); + } diff --git a/devtools/observatory/templates/js/00_state.js b/devtools/observatory/templates/js/00_state.js new file mode 100644 index 00000000000..9f3d5d407fc --- /dev/null +++ b/devtools/observatory/templates/js/00_state.js @@ -0,0 +1,18 @@ +(function() { + const OBS = (window.__observatory = window.__observatory || {}); + + OBS.state = { + data: window.OBSERVATORY_DATA || {}, + activeRecordIndex: -1, + theme: localStorage.getItem('graphCollectorTheme') || 'dark', + viewPrefs: JSON.parse(localStorage.getItem('graphCollectorViewPrefs') || '{}'), + selectionMode: false, + selectedIndices: new Set(), + mountedViewers: [], + mountedCompares: [], + viewerCache: new Map(), + graphCompareInstances: new Map(), + }; + + OBS.app = document.getElementById('app'); +})(); diff --git a/devtools/observatory/templates/js/01_utils.js b/devtools/observatory/templates/js/01_utils.js new file mode 100644 index 00000000000..42c0f9ac80a --- /dev/null +++ b/devtools/observatory/templates/js/01_utils.js @@ -0,0 +1,162 @@ +(function() { + const OBS = window.__observatory; + const state = OBS.state; + + function safeStr(val) { + if (val === null || val === undefined) return ''; + if (typeof val === 'object') return JSON.stringify(val); + return String(val); + } + + function escapeHtml(text) { + const div = document.createElement('div'); + div.textContent = text == null ? '' : String(text); + return div.innerHTML; + } + + function showToast(message, type = 'success') { + const existingToast = document.querySelector('.toast'); + if (existingToast) existingToast.remove(); + + const toast = document.createElement('div'); + toast.className = `toast toast-${type}`; + toast.textContent = message; + document.body.appendChild(toast); + + setTimeout(() => toast.classList.add('show'), 10); + setTimeout(() => { + toast.classList.remove('show'); + setTimeout(() => toast.remove(), 300); + }, 3000); + } + + function copyTable(tableEl) { + const rows = Array.from(tableEl.querySelectorAll('tr')); + const csv = rows + .map((row) => { + const cols = Array.from(row.querySelectorAll('td, th')); + return cols + .map((col) => `"${col.innerText.replace(/"/g, '""')}"`) + .join(','); + }) + .join('\n'); + + const html = `${tableEl.innerHTML}
`; + + try { + const blobHTML = new Blob([html], { type: 'text/html' }); + const blobText = new Blob([csv], { type: 'text/plain' }); + const clipboardItem = new ClipboardItem({ + 'text/html': blobHTML, + 'text/plain': blobText, + }); + navigator.clipboard + .write([clipboardItem]) + .then(() => showToast('Copied to clipboard!', 'success')) + .catch(() => showToast('Failed to copy', 'error')); + } catch (_e) { + navigator.clipboard + .writeText(csv) + .then(() => showToast('Copied to clipboard!', 'success')) + .catch(() => showToast('Failed to copy', 'error')); + } + } + + function resolveFunction(path) { + if (!path || typeof path !== 'string') return null; + const parts = path.split('.'); + let fn = window; + for (const p of parts) fn = fn && fn[p]; + return typeof fn === 'function' ? fn : null; + } + + function getLensBlocks(record, lensName) { + const lensView = (record.views || {})[lensName]; + if (!lensView || !Array.isArray(lensView.blocks)) return []; + return lensView.blocks.slice().sort((a, b) => Number(a.order || 0) - Number(b.order || 0)); + } + + function toArraySet(maybeSet) { + return Array.from(maybeSet || []).sort((a, b) => a - b); + } + + function buildViewerPayload(graphRef) { + const assets = state.data.graph_assets || {}; + const layers = state.data.graph_layers || {}; + const asset = assets[graphRef] || {}; + return { + base: asset.base || { legend: [], nodes: [], edges: [] }, + extensions: layers[graphRef] || {}, + }; + } + + const MAX_CACHED_VIEWERS = 10; + + function buildViewerCacheKey(mode, recordIndex, lensName, blockId) { + if (mode === 'single') { + return `single:${recordIndex}:${lensName}:${blockId}`; + } + return `compare:${lensName}:${blockId}`; + } + + function evictViewerCache() { + if (state.viewerCache.size < MAX_CACHED_VIEWERS) return; + let oldestKey = null, oldestTime = Infinity; + for (const [k, entry] of state.viewerCache) { + if (entry.lastAccessed < oldestTime) { + oldestTime = entry.lastAccessed; + oldestKey = k; + } + } + if (oldestKey) { + const entry = state.viewerCache.get(oldestKey); + try { entry.viewer.destroy(); } catch (_) {} + state.viewerCache.delete(oldestKey); + } + } + + function destroyGraphRuntime() { + const persistentCompares = new Set( + Array.from(state.graphCompareInstances.values()).map(inst => inst.compare) + ); + + for (const compare of state.mountedCompares) { + if (persistentCompares.has(compare)) continue; + try { + if (compare && typeof compare.destroy === 'function') compare.destroy(); + } catch (_e) {} + } + state.mountedCompares = []; + + for (const viewer of state.mountedViewers) { + if (!viewer) continue; + if (viewer.wrapper && viewer.wrapper.parentNode) { + viewer.wrapper.parentNode.removeChild(viewer.wrapper); + } + } + state.mountedViewers = []; + } + + function destroyAllGraphCompares() { + for (const [, inst] of state.graphCompareInstances) { + try { inst.compare.destroy(); } catch (_) {} + } + state.graphCompareInstances.clear(); + } + + OBS.utils = { + safeStr, + escapeHtml, + showToast, + copyTable, + resolveFunction, + getLensBlocks, + toArraySet, + buildViewerPayload, + destroyGraphRuntime, + destroyAllGraphCompares, + buildViewerCacheKey, + evictViewerCache, + MAX_CACHED_VIEWERS, + }; +})(); diff --git a/devtools/observatory/templates/js/02_layout.js b/devtools/observatory/templates/js/02_layout.js new file mode 100644 index 00000000000..39cf5cebeef --- /dev/null +++ b/devtools/observatory/templates/js/02_layout.js @@ -0,0 +1,152 @@ +(function() { + const OBS = window.__observatory; + const state = OBS.state; + const { escapeHtml } = OBS.utils; + + function renderLayout() { + const icon = state.theme === 'dark' ? '☀️' : '🌙'; + OBS.app.innerHTML = ` +
+
+
+

${escapeHtml(state.data.title || 'Observatory Report')}

+
+ Generated: ${escapeHtml(state.data.generated_at || '')} +
+
+
+ +
+
+
+
+ +
+
+ `; + updateIndexHeader(); + } + + function updateIndexHeader() { + const header = document.getElementById('index-header'); + if (!header) return; + + if (state.selectionMode) { + const total = (state.data.records || []).length; + const allSelected = state.selectedIndices.size === total; + header.innerHTML = ` +
+ Selected: ${state.selectedIndices.size} +
+ + +
+
+ `; + return; + } + + header.innerHTML = ` +
+

Collected Graphs (${(state.data.records || []).length})

+ +
+ `; + } + + function renderIndex() { + const list = document.getElementById('index-list'); + if (!list) return; + + const records = state.data.records || []; + let html = ` +
  • + 📊 Run Dashboard +
  • + `; + + records.forEach((rec, idx) => { + const isSelected = state.selectedIndices.has(idx); + const checkbox = state.selectionMode + ? `` + : ''; + + let activeClass = ''; + if (state.selectionMode) { + if (isSelected) activeClass = 'selected'; + } else if (typeof state.activeRecordIndex === 'object' && state.activeRecordIndex !== null) { + if (state.activeRecordIndex.pool && state.activeRecordIndex.pool.includes(idx)) activeClass = 'selected'; + if (state.activeRecordIndex.base === idx) activeClass = 'diff-base'; + if (state.activeRecordIndex.new === idx) activeClass = 'diff-new'; + } else if (state.activeRecordIndex === idx) { + activeClass = 'active'; + } + + const badges = (rec.badges || []) + .map((b) => { + const badgeClass = b.class || 'badge'; + const title = escapeHtml(b.title || b.label || ''); + const label = escapeHtml(b.label || ''); + return `${label}`; + }) + .join(''); + + if (rec.diff_index && Object.keys(rec.diff_index).length > 0 && idx > 0) { + const rows = Object.entries(rec.diff_index) + .map(([key, val]) => { + const text = String(val); + const plusMatch = text.match(/\+((?:\d+\.?\d*)|(?:\.\d+))/); + const minusMatch = text.match(/-((?:\d+\.?\d*)|(?:\.\d+))/); + let stats = ''; + if (plusMatch || minusMatch) { + if (plusMatch) stats += `+${plusMatch[1]}`; + if (minusMatch) stats += `-${minusMatch[1]}`; + } else { + stats = `${escapeHtml(text)}`; + } + return ` +
    + ${escapeHtml(key)} + ${stats} +
    + `; + }) + .join(''); + + html += ` +
  • +
    ${rows}
    +
  • + `; + } + + html += ` +
  • +
    + ${checkbox} +
    +
    ${escapeHtml(rec.name || '')}
    +
    +
    +
    ${badges}
    +
  • + `; + }); + + list.innerHTML = html; + updateIndexHeader(); + } + + OBS.layout = { + renderLayout, + renderIndex, + updateIndexHeader, + }; +})(); diff --git a/devtools/observatory/templates/js/03_blocks.js b/devtools/observatory/templates/js/03_blocks.js new file mode 100644 index 00000000000..a11b0b60e6e --- /dev/null +++ b/devtools/observatory/templates/js/03_blocks.js @@ -0,0 +1,683 @@ +(function() { + const OBS = window.__observatory; + const state = OBS.state; + const { + safeStr, + escapeHtml, + copyTable, + resolveFunction, + getLensBlocks, + toArraySet, + buildViewerPayload, + destroyGraphRuntime, + buildViewerCacheKey, + evictViewerCache, + } = OBS.utils; + + function refreshCompareLayouts(container, options) { + if (!container || !(state.graphCompareInstances instanceof Map)) return; + for (const [, inst] of state.graphCompareInstances) { + const compare = inst && inst.compare; + if (!compare || !compare._root || !compare._root.isConnected) continue; + if (container !== compare._root && !container.contains(compare._root)) continue; + try { + if (typeof compare.refreshLayout === 'function') { + compare.refreshLayout(options || {}); + } else { + for (const v of compare.viewers || []) { + try { v.canvasRenderer.resize(); } catch (_) {} + try { + if (v.minimapRenderer) { + v.minimapRenderer.resize(); + v.minimapRenderer.generateThumbnail(); + } + } catch (_) {} + try { v.renderAll(); } catch (_) {} + } + } + } catch (_) {} + } + } + + function createSection(title, storageKey, collapsible) { + const isCollapsible = collapsible !== false; + const isCollapsed = isCollapsible && state.viewPrefs[storageKey] === false; + + const section = document.createElement('div'); + section.className = 'toggle-section'; + + const header = document.createElement('div'); + header.className = `toggle-header ${isCollapsed ? 'collapsed' : ''}`; + + const titleSpan = document.createElement('span'); + titleSpan.className = 'toggle-title'; + titleSpan.textContent = title; + header.appendChild(titleSpan); + + const content = document.createElement('div'); + content.className = `toggle-content ${isCollapsed ? 'hidden' : ''}`; + + if (isCollapsible) { + header.onclick = () => { + content.classList.toggle('hidden'); + header.classList.toggle('collapsed'); + const isExpanded = !content.classList.contains('hidden'); + state.viewPrefs[storageKey] = isExpanded; + localStorage.setItem('graphCollectorViewPrefs', JSON.stringify(state.viewPrefs)); + if (isExpanded) { + requestAnimationFrame(() => refreshCompareLayouts(content)); + } + }; + } + + section.appendChild(header); + section.appendChild(content); + + return { section, header, content }; + } + + function renderTableContent(content, data) { + const table = document.createElement('table'); + table.className = 'kv-table'; + const tbody = document.createElement('tbody'); + + const entries = Object.entries(data || {}); + if (entries.length === 0) { + const tr = document.createElement('tr'); + const td = document.createElement('td'); + td.colSpan = 2; + td.textContent = '(empty)'; + tr.appendChild(td); + tbody.appendChild(tr); + } else { + for (const [key, val] of entries) { + const tr = document.createElement('tr'); + const th = document.createElement('th'); + th.textContent = key; + const td = document.createElement('td'); + td.textContent = safeStr(val); + tr.appendChild(th); + tr.appendChild(td); + tbody.appendChild(tr); + } + } + + table.appendChild(tbody); + content.appendChild(table); + return table; + } + + function resolveViewerCtor() { + if (typeof FXGraphViewer !== 'undefined') return FXGraphViewer; + if (window && window.FXGraphViewer) return window.FXGraphViewer; + return null; + } + + function resolveGraphRef(graphRecord, fallbackGraphRef) { + if (!graphRecord || typeof graphRecord !== 'object') return fallbackGraphRef || ''; + return ( + graphRecord.graph_ref || + graphRecord.graphRef || + graphRecord.record_name || + graphRecord.recordName || + fallbackGraphRef || + '' + ); + } + + function mountGraphViewer(root, graphRecord, viewerOptions, fallbackGraphRef, cacheKey) { + // Cache hit: reattach existing live viewer + if (cacheKey && state.viewerCache.has(cacheKey)) { + const entry = state.viewerCache.get(cacheKey); + entry.lastAccessed = Date.now(); + root.appendChild(entry.wrapper); + if (entry.viewer && entry.viewer.rootContainer !== root) { + entry.viewer.rootContainer = root; + if (entry.viewer.config && entry.viewer.config.mount) { + entry.viewer.config.mount.root = root; + } + if (entry.viewer.config && entry.viewer.config._resolved) { + entry.viewer.config._resolved.root = root; + } + } + requestAnimationFrame(() => { + try { entry.viewer.canvasRenderer.resize(); } catch (_) {} + try { entry.viewer.renderAll(); } catch (_) {} + }); + state.mountedViewers.push(entry.viewer); + return entry.viewer; + } + + // Cache miss: create new viewer + const ViewerCtor = resolveViewerCtor(); + const graphRef = resolveGraphRef(graphRecord, fallbackGraphRef); + + if (!ViewerCtor || !graphRef) { + const reason = !ViewerCtor ? 'FXGraphViewer unavailable' : 'graph_ref missing'; + root.innerHTML = `
    ${reason}.
    `; + return null; + } + + const payload = buildViewerPayload(graphRef); + const defaultLayers = Array.isArray(graphRecord.default_layers) ? graphRecord.default_layers : []; + const defaultColorBy = graphRecord.default_color_by || (defaultLayers.length > 0 ? defaultLayers[0] : 'base'); + + const layoutMode = (viewerOptions || {}).layout_mode || 'full'; + let preset = 'split'; + if (layoutMode === 'compare_compact') preset = 'compact'; + if (layoutMode === 'headless') preset = 'headless'; + + const viewer = ViewerCtor.create({ + payload, + mount: { root }, + layout: { preset, fullscreen: { button: true } }, + state: { activeExtensions: defaultLayers, colorBy: defaultColorBy, themeName: state.theme }, + }); + + // FIX: defer init() until after browser layout pass so getBoundingClientRect() is valid + requestAnimationFrame(() => viewer.init()); + + if ((viewerOptions || {}).sidebar_mode === 'hidden' && typeof viewer.setLayout === 'function') { + try { viewer.setLayout({ panels: { sidebar: { visible: false } } }); } catch (_e) {} + } + if ((viewerOptions || {}).minimap_mode === 'off' && typeof viewer.setUIVisibility === 'function') { + try { viewer.setUIVisibility({ minimapToggle: false }); } catch (_e) {} + } + + if (cacheKey) { + evictViewerCache(); + state.viewerCache.set(cacheKey, { + viewer, + wrapper: viewer.wrapper, + lastAccessed: Date.now(), + }); + } + + state.mountedViewers.push(viewer); + return viewer; + } + + function renderRecordBlock(container, lensName, block, context, analysis) { + const storageKey = `${lensName}:${block.id}`; + const title = block.title || block.id || lensName; + const { section, header, content } = createSection(title, storageKey, block.collapsible); + + if (block.type === 'table') { + const table = renderTableContent(content, block.record && block.record.data); + const copyBtn = document.createElement('button'); + copyBtn.className = 'copy-btn'; + copyBtn.innerText = 'Copy'; + copyBtn.onclick = (e) => { + e.stopPropagation(); + copyTable(table); + }; + header.appendChild(copyBtn); + } else if (block.type === 'html') { + const raw = (block.record && block.record.content) || ''; + let decoded = raw; + try { decoded = atob(raw); } catch(_) {} + content.innerHTML = decoded; + } else if (block.type === 'custom') { + const jsFunc = block.record && block.record.js_func; + const fn = resolveFunction(jsFunc); + if (!fn) { + content.innerHTML = `
    Function ${escapeHtml(jsFunc || '')} not found
    `; + } else { + try { + fn(content, (block.record && block.record.args) || {}, context, analysis); + } catch (err) { + content.innerHTML = `
    JS Error: ${escapeHtml(err.message || String(err))}
    `; + } + } + } else if (block.type === 'graph') { + const graphRoot = document.createElement('div'); + graphRoot.style.height = '1000px'; + graphRoot.style.minHeight = '800px'; + graphRoot.style.border = '1px solid var(--border-color)'; + graphRoot.style.borderRadius = '8px'; + graphRoot.style.overflow = 'hidden'; + content.appendChild(graphRoot); + const fallbackGraphRef = (context && context.record && context.record.name) || ''; + const recordIndex = (context && context.index !== undefined) ? context.index : -1; + const cacheKey = (recordIndex >= 0) + ? buildViewerCacheKey('single', recordIndex, lensName, block.id || block.type) + : null; + mountGraphViewer(graphRoot, block.record || {}, (block.record && block.record.viewer_options) || {}, fallbackGraphRef, cacheKey); + } else { + content.innerHTML = `
    Unsupported block type: ${escapeHtml(block.type || '')}
    `; + } + + container.appendChild(section); + } + + function renderDashboard(container) { + container.innerHTML = '

    Run Dashboard

    '; + + const dashboard = state.data.dashboard || {}; + let hasContent = false; + + for (const [lensName, viewList] of Object.entries(dashboard)) { + const blocks = Array.isArray(viewList && viewList.blocks) + ? viewList.blocks.slice().sort((a, b) => Number(a.order || 0) - Number(b.order || 0)) + : []; + if (blocks.length === 0) continue; + + const analysis = state.data.analysis_results && state.data.analysis_results[lensName]; + const context = { + start: (state.data.session && state.data.session.start_data && state.data.session.start_data[lensName]) || {}, + end: (state.data.session && state.data.session.end_data && state.data.session.end_data[lensName]) || {}, + records: state.data.records || [], + }; + + for (const block of blocks) { + renderRecordBlock(container, lensName, block, context, analysis); + hasContent = true; + } + } + + if (!hasContent) container.innerHTML += '

    No dashboard data available.

    '; + } + + function renderTableCompare(content, entries) { + const allKeys = new Set(); + for (const entry of entries) { + const data = entry.block && entry.block.record && entry.block.record.data; + if (!data || typeof data !== 'object') continue; + Object.keys(data).forEach((k) => allKeys.add(k)); + } + + if (allKeys.size === 0) { + content.innerHTML = '

    No table data to compare.

    '; + return null; + } + + const table = document.createElement('table'); + table.className = 'kv-table comparison-table'; + + const thead = document.createElement('thead'); + const headerRow = document.createElement('tr'); + const th0 = document.createElement('th'); + th0.textContent = 'Property'; + headerRow.appendChild(th0); + + for (const entry of entries) { + const th = document.createElement('th'); + const span = document.createElement('span'); + span.className = 'clickable-name'; + span.textContent = entry.record.name || `record_${entry.idx}`; + span.onclick = (e) => { + e.stopPropagation(); + window.selectRecord(entry.idx, true); + }; + th.appendChild(span); + headerRow.appendChild(th); + } + + thead.appendChild(headerRow); + table.appendChild(thead); + + const tbody = document.createElement('tbody'); + for (const key of Array.from(allKeys).sort()) { + const tr = document.createElement('tr'); + const th = document.createElement('th'); + th.textContent = key; + tr.appendChild(th); + + for (const entry of entries) { + const td = document.createElement('td'); + const data = entry.block && entry.block.record && entry.block.record.data; + td.textContent = data && data[key] !== undefined ? safeStr(data[key]) : '-'; + tr.appendChild(td); + } + + tbody.appendChild(tr); + } + + table.appendChild(tbody); + content.appendChild(table); + return table; + } + + function renderHtmlCompare(content, entries) { + const split = document.createElement('div'); + split.className = 'split-view'; + + for (const entry of entries) { + const pane = document.createElement('div'); + pane.className = 'split-pane'; + + const h3 = document.createElement('h3'); + const span = document.createElement('span'); + span.className = 'clickable-name'; + span.textContent = entry.record.name || `record_${entry.idx}`; + span.onclick = () => window.selectRecord(entry.idx, true); + h3.appendChild(span); + pane.appendChild(h3); + + const raw = (entry.block && entry.block.record && entry.block.record.content) || ''; + let decoded = raw; + try { decoded = atob(raw); } catch(_) {} + const div = document.createElement('div'); + div.innerHTML = decoded; + pane.appendChild(div); + split.appendChild(pane); + } + + content.appendChild(split); + } + + function renderCustomCompare(content, entries, compareSpec, sampleBlock, lensName, blockId) { + const recordJsFunc = sampleBlock && sampleBlock.record && sampleBlock.record.js_func; + const jsFunc = compareSpec.js_func || recordJsFunc || ''; + const fn = resolveFunction(jsFunc); + if (!fn) { + content.innerHTML = `
    Function ${escapeHtml(jsFunc)} not found
    `; + return; + } + + const context = { + indices: entries.map((e) => e.idx), + names: entries.map((e) => e.record.name), + records: entries.map((e) => e.record), + blocks: entries.map((e) => e.block), + lens: lensName, + block_id: blockId, + }; + + try { + fn(content, compareSpec.args || {}, context, state.data.analysis_results && state.data.analysis_results[lensName]); + } catch (err) { + content.innerHTML = `
    JS Error: ${escapeHtml(err.message || String(err))}
    `; + } + } + + function ensureCompareInstance(content, allEntries, compareSpec, lensName, blockId) { + const cacheKey = `${lensName}:${blockId}`; + const cached = state.graphCompareInstances.get(cacheKey); + + if (cached) { + content.appendChild(cached.compare._root); + requestAnimationFrame(() => { + if (typeof cached.compare.refreshLayout === 'function') { + try { cached.compare.refreshLayout(); } catch (_) {} + return; + } + for (const v of cached.compare.viewers) { + try { v.canvasRenderer.resize(); } catch (_) {} + try { + if (v.minimapRenderer) { + v.minimapRenderer.resize(); + v.minimapRenderer.generateThumbnail(); + } + } catch (_) {} + try { v.renderAll(); } catch (_) {} + } + }); + return cacheKey; + } + + const placeholder = document.createElement('div'); + placeholder.className = 'loading'; + placeholder.textContent = 'Building graph compare view\u2026'; + content.appendChild(placeholder); + + buildCompareAsync(content, placeholder, compareSpec, lensName, blockId, cacheKey); + return cacheKey; + } + + async function buildCompareAsync(content, placeholder, compareSpec, lensName, blockId, cacheKey) { + const CompareCtor = typeof FXGraphCompare !== 'undefined' ? FXGraphCompare : (window && window.FXGraphCompare); + if (!CompareCtor) { + if (placeholder) placeholder.innerHTML = 'FXGraphCompare unavailable.'; + return; + } + + const records = state.data.records || []; + const viewerMap = new Map(); + const nameToIndex = new Map(); + const isOnscreen = placeholder !== null; + + for (let idx = 0; idx < records.length; idx++) { + if (isOnscreen && !content.isConnected) return; + + const record = records[idx]; + if (!record) continue; + const blocks = getLensBlocks(record, lensName); + const block = blocks.find(b => (b.id || `${lensName}_${b.type}`) === (blockId || `${lensName}_graph`)); + if (!block || block.type !== 'graph') continue; + + const graphRoot = document.createElement('div'); + graphRoot.style.height = '520px'; + graphRoot.style.minHeight = '360px'; + graphRoot.style.overflow = 'hidden'; + + const options = Object.assign({}, (block.record && block.record.viewer_options) || {}); + const fallbackGraphRef = record.name || ''; + const viewer = mountGraphViewer(graphRoot, block.record || {}, options, fallbackGraphRef, null); + if (!viewer) continue; + + const name = record.name || `record_${idx}`; + viewerMap.set(name, viewer); + nameToIndex.set(name, idx); + + await new Promise(resolve => requestAnimationFrame(resolve)); + } + + if (isOnscreen && !content.isConnected) return; + + if (viewerMap.size === 0) { + if (placeholder) placeholder.innerHTML = 'No graph viewers could be created.'; + return; + } + + if (placeholder && placeholder.parentNode) placeholder.parentNode.removeChild(placeholder); + + const compare = CompareCtor.create({ + viewers: viewerMap, + layout: { container: content }, + sync: compareSpec.default_sync && compareSpec.default_sync.mode + ? compareSpec.default_sync + : { mode: 'auto' }, + }); + + state.graphCompareInstances.set(cacheKey, { compare, nameToIndex }); + + const compareViewerSet = new Set(viewerMap.values()); + state.mountedViewers = state.mountedViewers.filter(v => !compareViewerSet.has(v)); + + const visibleIndices = getCurrentVisibleIndices(); + if (visibleIndices) syncCompareVisibility(cacheKey, visibleIndices); + } + + function getCurrentVisibleIndices() { + if (state.selectionMode) return toArraySet(state.selectedIndices); + if (typeof state.activeRecordIndex === 'object' && state.activeRecordIndex !== null) { + return state.activeRecordIndex.pool + ? state.activeRecordIndex.pool + : [state.activeRecordIndex.base, state.activeRecordIndex.new].filter(x => Number.isInteger(x)); + } + return null; + } + + async function warmCompareInstances() { + const records = state.data.records || []; + if (records.length < 2) return; + + const graphBlockIds = new Map(); + for (const record of records) { + for (const lensName of Object.keys((record && record.views) || {})) { + const blocks = getLensBlocks(record, lensName); + for (const block of blocks) { + if (block.type !== 'graph') continue; + const id = block.id || `${lensName}_${block.type}`; + const key = `${lensName}:${id}`; + if (!graphBlockIds.has(key)) { + graphBlockIds.set(key, { lensName, blockId: id, compareSpec: block.compare || {} }); + } + } + } + } + + for (const [, spec] of graphBlockIds) { + const cacheKey = `${spec.lensName}:${spec.blockId}`; + if (state.graphCompareInstances.has(cacheKey)) continue; + const offscreen = document.createElement('div'); + await buildCompareAsync(offscreen, null, spec.compareSpec, spec.lensName, spec.blockId, cacheKey); + } + } + + function syncCompareVisibility(cacheKey, visibleIndices) { + const inst = state.graphCompareInstances.get(cacheKey); + if (!inst) return; + for (const [name, index] of inst.nameToIndex) { + inst.compare.setViewerVisible(name, visibleIndices.includes(index)); + } + } + + function defaultCompareMode(block) { + if (!block) return 'disabled'; + if (block.type === 'table' || block.type === 'html' || block.type === 'graph') return 'auto'; + return 'disabled'; + } + + function renderCompareLens(recordView, lensName, indices) { + const records = state.data.records || []; + const entriesByBlock = new Map(); + + for (const idx of indices) { + const record = records[idx]; + if (!record) continue; + const blocks = getLensBlocks(record, lensName); + for (const block of blocks) { + const id = block.id || `${lensName}_${block.type}`; + if (!entriesByBlock.has(id)) entriesByBlock.set(id, []); + entriesByBlock.get(id).push({ idx, record, block }); + } + } + + const blockEntries = Array.from(entriesByBlock.values()) + .filter((list) => list.length > 0) + .sort((a, b) => Number((a[0].block && a[0].block.order) || 0) - Number((b[0].block && b[0].block.order) || 0)); + + for (const entries of blockEntries) { + if (entries.length < 2) continue; + + const sample = entries[0].block; + const compareSpec = sample.compare || {}; + const mode = compareSpec.mode || defaultCompareMode(sample); + if (mode === 'disabled') continue; + + const blockLabel = sample.title || sample.id || sample.type; + const sectionKey = `cmp:${lensName}:${sample.id}`; + const { section, header, content } = createSection(`Comparison: ${lensName} / ${blockLabel}`, sectionKey, sample.collapsible); + + if (sample.type === 'table' && mode === 'auto') { + const table = renderTableCompare(content, entries); + if (table) { + const copyBtn = document.createElement('button'); + copyBtn.className = 'copy-btn'; + copyBtn.innerText = 'Copy'; + copyBtn.onclick = (e) => { + e.stopPropagation(); + copyTable(table); + }; + header.appendChild(copyBtn); + } + } else if (sample.type === 'html' && mode === 'auto') { + renderHtmlCompare(content, entries); + } else if (sample.type === 'graph' && mode === 'auto') { + const cacheKey = ensureCompareInstance(content, entries, compareSpec, lensName, sample.id); + syncCompareVisibility(cacheKey, indices); + } else if (mode === 'custom') { + renderCustomCompare(content, entries, compareSpec, sample, lensName, sample.id); + } else { + content.innerHTML = `

    Compare mode '${escapeHtml(mode)}' for block type '${escapeHtml(sample.type)}' is not supported in minimal runtime.

    `; + } + + recordView.appendChild(section); + } + } + + function renderUnifiedView(container, indices) { + const records = (state.data.records || []).filter((_, idx) => indices.includes(idx)); + const isSingle = indices.length === 1; + const title = isSingle + ? records[0].name + : `Comparison (${indices.map((i) => (state.data.records || [])[i].name).join(' vs ')})`; + + container.innerHTML = `

    ${escapeHtml(title)}

    `; + const recordView = container.querySelector('.record-view'); + + if (isSingle) { + const idx = indices[0]; + const record = (state.data.records || [])[idx]; + let hasContent = false; + + for (const lensName of Object.keys((record && record.views) || {})) { + const blocks = getLensBlocks(record, lensName); + const analysis = state.data.analysis_results && state.data.analysis_results[lensName]; + const context = { index: idx, record }; + for (const block of blocks) { + renderRecordBlock(recordView, lensName, block, context, analysis); + hasContent = true; + } + } + + if (!hasContent) recordView.innerHTML += '

    No views available for this record.

    '; + return; + } + + const allLenses = new Set(); + for (const idx of indices) { + const record = (state.data.records || [])[idx]; + for (const lensName of Object.keys((record && record.views) || {})) { + allLenses.add(lensName); + } + } + + for (const lensName of allLenses) { + renderCompareLens(recordView, lensName, indices); + } + } + + function renderMain() { + destroyGraphRuntime(); + + const container = document.getElementById('main-pane'); + if (!container) return; + container.innerHTML = ''; + + if (state.selectionMode) { + const indices = toArraySet(state.selectedIndices); + if (indices.length === 0) { + container.innerHTML = '
    Select items to compare...
    '; + } else { + renderUnifiedView(container, indices); + } + return; + } + + if (state.activeRecordIndex === -1) { + renderDashboard(container); + return; + } + + if (typeof state.activeRecordIndex === 'object' && state.activeRecordIndex !== null) { + const indices = state.activeRecordIndex.pool + ? state.activeRecordIndex.pool + : [state.activeRecordIndex.base, state.activeRecordIndex.new].filter((x) => Number.isInteger(x)); + renderUnifiedView(container, indices); + return; + } + + renderUnifiedView(container, [state.activeRecordIndex]); + } + + OBS.render = { + renderMain, + renderDashboard, + renderUnifiedView, + mountGraphViewer, + warmCompareInstances, + }; +})(); diff --git a/devtools/observatory/templates/js/04_actions.js b/devtools/observatory/templates/js/04_actions.js new file mode 100644 index 00000000000..cfc06c11710 --- /dev/null +++ b/devtools/observatory/templates/js/04_actions.js @@ -0,0 +1,116 @@ +(function() { + const OBS = window.__observatory; + const state = OBS.state; + const { renderIndex, updateIndexHeader } = OBS.layout; + const { renderMain } = OBS.render; + + function setTheme(theme) { + document.documentElement.setAttribute('data-theme', theme); + state.theme = theme; + localStorage.setItem('graphCollectorTheme', theme); + const icon = document.querySelector('.theme-icon'); + if (icon) icon.textContent = theme === 'dark' ? '☀️' : '🌙'; + + // Sync theme to all mounted fx_viewer instances + for (const viewer of (state.mountedViewers || [])) { + try { + if (typeof viewer.setTheme === 'function') viewer.setTheme(theme); + } catch (_) {} + } + for (const [, inst] of (state.graphCompareInstances || new Map())) { + const firstViewer = inst.compare.viewers[0]; + if (firstViewer && typeof firstViewer.setTheme === 'function') { + try { firstViewer.setTheme(theme); } catch (_) {} + } + } + } + + function toggleTheme() { + setTheme(state.theme === 'light' ? 'dark' : 'light'); + } + + function showCompareView() { + const indices = Array.from(arguments).filter((n) => Number.isInteger(n)); + state.activeRecordIndex = { pool: indices }; + renderIndex(); + renderMain(); + } + + function selectRecord(index, forceNavigate) { + if (forceNavigate && state.selectionMode) { + state.selectionMode = false; + state.selectedIndices.clear(); + } + + if (state.selectionMode && index !== -1) { + toggleSelect(index); + return; + } + + state.activeRecordIndex = index; + renderIndex(); + renderMain(); + } + + function toggleSelectionMode() { + state.selectionMode = !state.selectionMode; + if (state.selectionMode) { + state.selectedIndices.clear(); + if (typeof state.activeRecordIndex === 'number' && state.activeRecordIndex !== -1) { + state.selectedIndices.add(state.activeRecordIndex); + } else if ( + typeof state.activeRecordIndex === 'object' && + state.activeRecordIndex && + Array.isArray(state.activeRecordIndex.pool) + ) { + for (const idx of state.activeRecordIndex.pool) state.selectedIndices.add(idx); + } + } else { + state.selectedIndices.clear(); + } + + updateIndexHeader(); + renderIndex(); + renderMain(); + } + + function toggleSelectAll() { + const records = state.data.records || []; + if (state.selectedIndices.size === records.length) { + state.selectedIndices.clear(); + } else { + state.selectedIndices = new Set(records.map((_, i) => i)); + } + renderIndex(); + renderMain(); + } + + function toggleSelect(idx, event) { + if (event) event.stopPropagation(); + if (state.selectedIndices.has(idx)) { + state.selectedIndices.delete(idx); + } else { + state.selectedIndices.add(idx); + } + renderIndex(); + renderMain(); + } + + OBS.actions = { + setTheme, + toggleTheme, + showCompareView, + selectRecord, + toggleSelectionMode, + toggleSelectAll, + toggleSelect, + }; + + window.setTheme = setTheme; + window.toggleTheme = toggleTheme; + window.showCompareView = showCompareView; + window.selectRecord = selectRecord; + window.toggleSelectionMode = toggleSelectionMode; + window.toggleSelectAll = toggleSelectAll; + window.toggleSelect = toggleSelect; +})(); diff --git a/devtools/observatory/templates/js/05_bootstrap_api.js b/devtools/observatory/templates/js/05_bootstrap_api.js new file mode 100644 index 00000000000..9f4f0f83489 --- /dev/null +++ b/devtools/observatory/templates/js/05_bootstrap_api.js @@ -0,0 +1,147 @@ +(function() { + const OBS = window.__observatory; + const state = OBS.state; + const { renderLayout, renderIndex } = OBS.layout; + const { renderMain, mountGraphViewer } = OBS.render; + const { showToast } = OBS.utils; + const actions = OBS.actions; + + function wrapGraphHandle(viewer) { + return { + setLayers(layerIds) { + if (viewer && typeof viewer.setLayers === 'function') viewer.setLayers(layerIds || []); + }, + setColorBy(layerId) { + if (viewer && typeof viewer.setColorBy === 'function') viewer.setColorBy(layerId); + }, + updateLayerNodeStyle(layerId, nodeId, patch) { + if (!viewer || typeof viewer.patchLayerNodes !== 'function') return; + const payload = {}; + payload[nodeId] = patch || {}; + viewer.patchLayerNodes(layerId, payload); + }, + selectNode(nodeId, opts) { + if (viewer && typeof viewer.selectNode === 'function') viewer.selectNode(nodeId, opts || {}); + }, + zoomToFit() { + if (viewer && typeof viewer.zoomToFit === 'function') viewer.zoomToFit(); + }, + setSyncEnabled(enabled) { + if (viewer && typeof viewer.setState === 'function') { + try { + viewer.setState({ syncSelection: !!enabled }); + } catch (_e) {} + } + }, + enterFullscreen() { + if (viewer && typeof viewer.enterFullscreen === 'function') viewer.enterFullscreen(); + }, + exitFullscreen() { + if (viewer && typeof viewer.exitFullscreen === 'function') viewer.exitFullscreen(); + }, + onNodeSelected(callback) { + if (viewer && typeof viewer.on === 'function') { + viewer.on('selectionchange', callback); + } + }, + destroy() { + if (viewer && typeof viewer.destroy === 'function') viewer.destroy(); + }, + _viewer: viewer, + }; + } + + window.ObservatoryAPI = { + mountGraph(container, graphRef, options) { + let root = container; + if (typeof container === 'string') root = document.querySelector(container); + if (!root) throw new Error('mountGraph: container not found'); + + const graphRecord = { + graph_ref: graphRef, + default_layers: (options && options.default_layers) || [], + default_color_by: options && options.default_color_by, + viewer_options: (options && options.viewer_options) || {}, + }; + + const viewer = mountGraphViewer(root, graphRecord, graphRecord.viewer_options); + if (!viewer) throw new Error('mountGraph: failed to mount viewer'); + return wrapGraphHandle(viewer); + }, + + selectRecord(index) { + actions.selectRecord(index, true); + }, + + openCompare(indices) { + if (!Array.isArray(indices) || indices.length === 0) return; + state.activeRecordIndex = { pool: indices.slice() }; + renderIndex(); + renderMain(); + }, + + showSingleRecord(index) { + actions.selectRecord(index, true); + }, + + showToast(message, type) { + showToast(message, type || 'success'); + }, + + getContext() { + return { + activeRecordIndex: state.activeRecordIndex, + selectionMode: state.selectionMode, + selectedIndices: Array.from(state.selectedIndices), + records: state.data.records || [], + }; + }, + }; + + function setupDelegatedActions() { + document.body.addEventListener('click', (event) => { + const target = event.target && event.target.closest && event.target.closest('[data-ob-action]'); + if (!target) return; + + const action = target.getAttribute('data-ob-action'); + if (action === 'select-record') { + const rec = Number(target.getAttribute('data-ob-record')); + if (Number.isInteger(rec)) actions.selectRecord(rec, true); + return; + } + + if (action === 'open-compare') { + const raw = target.getAttribute('data-ob-indices') || ''; + const indices = raw + .split(',') + .map((v) => Number(v.trim())) + .filter((v) => Number.isInteger(v)); + if (indices.length > 0) window.ObservatoryAPI.openCompare(indices); + return; + } + + if (action === 'graph-focus-node') { + const nodeId = target.getAttribute('data-ob-node-id'); + if (!nodeId) return; + const firstViewer = state.mountedViewers && state.mountedViewers[0]; + if (firstViewer && typeof firstViewer.selectNode === 'function') { + firstViewer.selectNode(nodeId, { center: true, animate: true }); + } + } + }); + } + + function init() { + actions.setTheme(state.theme); + renderLayout(); + renderIndex(); + renderMain(); + setupDelegatedActions(); + + if (OBS.render.warmCompareInstances) { + requestAnimationFrame(() => OBS.render.warmCompareInstances()); + } + } + + init(); +})(); diff --git a/devtools/observatory/tests/test_graph_hub.py b/devtools/observatory/tests/test_graph_hub.py new file mode 100644 index 00000000000..94ce1cf3d84 --- /dev/null +++ b/devtools/observatory/tests/test_graph_hub.py @@ -0,0 +1,56 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from executorch.devtools.observatory.graph_hub import GraphHub +from executorch.devtools.observatory.interfaces import RecordAnalysis +from executorch.devtools.fx_viewer import ( + GraphExtensionNodePayload, + GraphExtensionPayload, +) + + +def test_graph_hub_register_and_layers() -> None: + hub = GraphHub() + hub.register_asset( + "r0", + base_payload={"legend": [], "nodes": [{"id": "n0"}], "edges": []}, + meta={"record_name": "r0"}, + ) + analysis = RecordAnalysis() + analysis.add_graph_layer( + "error", + GraphExtensionPayload( + id="error", + name="Error", + legend=[{"label": "L", "color": "#000"}], + sync_keys=["debug_handle"], + nodes={ + "n0": GraphExtensionNodePayload( + fill_color="#000", + ) + }, + ), + ) + hub.add_analysis_layers("r0", "accuracy", analysis) + + payload = hub.build_payload() + assert "r0" in payload["graph_assets"] + assert "accuracy/error" in payload["graph_layers"]["r0"] + assert payload["graph_layers"]["r0"]["accuracy/error"]["sync_keys"] == ["debug_handle"] + + +def test_build_viewer_payload() -> None: + graph_assets = { + "r1": { + "base": {"legend": [], "nodes": [{"id": "a"}], "edges": []}, + "meta": {}, + } + } + graph_layers = {"r1": {"x/y": {"name": "L", "legend": [], "nodes": {}}}} + + payload = GraphHub.build_viewer_payload(graph_assets, graph_layers, "r1") + assert payload["base"]["nodes"][0]["id"] == "a" + assert "x/y" in payload["extensions"] diff --git a/devtools/observatory/tests/test_graph_payload_relayout.py b/devtools/observatory/tests/test_graph_payload_relayout.py new file mode 100644 index 00000000000..8e748787947 --- /dev/null +++ b/devtools/observatory/tests/test_graph_payload_relayout.py @@ -0,0 +1,152 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from executorch.devtools.observatory.interfaces import ( + AnalysisResult, + Frontend, + Lens, + RecordAnalysis, + RecordDigest, + SessionResult, +) +from executorch.devtools.observatory.observatory import Observatory +from executorch.devtools.fx_viewer import ( + FXGraphExporter, + GraphExtensionNodePayload, + GraphExtensionPayload, +) + + +def _node_by_id(base_payload: dict, node_id: str) -> dict: + for node in base_payload["nodes"]: + if node.get("id") == node_id: + return node + raise KeyError(node_id) + + +def test_relayout_payload_base_uses_extension_label_lines() -> None: + base_payload = { + "legend": [], + "nodes": [ + {"id": "n0", "label": "a", "x": 0.0, "y": 0.0, "width": 100, "height": 36, "info": {}}, + {"id": "n1", "label": "b", "x": 0.0, "y": 0.0, "width": 100, "height": 36, "info": {}}, + ], + "edges": [{"v": "n0", "w": "n1", "points": []}], + } + ext_payload = { + "acc/psnr": { + "name": "PSNR", + "legend": [], + "nodes": { + "n0": {"label_append": ["psnr=12.34567890123456789"]}, + }, + } + } + + relaid = FXGraphExporter.relayout_payload_base(base_payload, ext_payload) + + assert _node_by_id(base_payload, "n0")["width"] == 100 + assert _node_by_id(base_payload, "n0")["height"] == 36 + assert base_payload["edges"][0]["points"] == [] + + n0 = _node_by_id(relaid, "n0") + assert n0["width"] > 100 + assert n0["height"] > 36 + assert relaid["edges"][0]["points"] + + +class _FakeGraphLens(Lens): + @classmethod + def get_name(cls) -> str: + return "graph" + + @staticmethod + def get_frontend_spec() -> Frontend: + return Frontend() + + +class _FakeLayerLens(Lens): + @classmethod + def get_name(cls) -> str: + return "fake_layer" + + @staticmethod + def analyze(records, config) -> AnalysisResult: + result = AnalysisResult() + for record in records: + rec_analysis = RecordAnalysis() + rec_analysis.add_graph_layer( + "test", + GraphExtensionPayload( + id="test", + name="Test Layer", + legend=[], + nodes={ + "n0": GraphExtensionNodePayload( + label_append=["a long extension label for relayout"], + ) + }, + ), + ) + result.per_record_data[record.name] = rec_analysis + return result + + @staticmethod + def get_frontend_spec() -> Frontend: + return Frontend() + + +def test_observatory_relayout_applied_during_payload_assembly() -> None: + records = [ + RecordDigest( + name="r0", + timestamp=0.0, + data={ + "graph": { + "graph_ref": "r0", + "base": { + "legend": [], + "nodes": [ + { + "id": "n0", + "label": "a", + "x": 0.0, + "y": 0.0, + "width": 100, + "height": 36, + "info": {}, + }, + { + "id": "n1", + "label": "b", + "x": 0.0, + "y": 0.0, + "width": 100, + "height": 36, + "info": {}, + }, + ], + "edges": [{"v": "n0", "w": "n1", "points": []}], + }, + "meta": {}, + } + }, + ) + ] + + payload = Observatory._generate_report_payload( + records=records, + session=SessionResult(), + config={}, + lens_registry=[_FakeGraphLens, _FakeLayerLens], + ) + + base = payload["graph_assets"]["r0"]["base"] + assert _node_by_id(base, "n0")["width"] > 100 + assert _node_by_id(base, "n0")["height"] > 36 + assert base["edges"][0]["points"] diff --git a/devtools/observatory/tests/test_observatory_smoke.py b/devtools/observatory/tests/test_observatory_smoke.py new file mode 100644 index 00000000000..76f61e2a202 --- /dev/null +++ b/devtools/observatory/tests/test_observatory_smoke.py @@ -0,0 +1,99 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import json + +from executorch.devtools.observatory import Observatory +from executorch.devtools.observatory.observatory import ( + _NonFiniteFloatAsStringJSONEncoder, +) + + +class _SmokeModel(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.fc = torch.nn.Linear(4, 4) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.fc(x) + + +def _build_report_payload_for_tests(): + Observatory._ensure_default_lenses() + return Observatory._generate_report_payload( + list(Observatory._records.values()), + Observatory._session_result, + {}, + Observatory._lens_registry, + ) + + +def test_observatory_collect_and_export_html(tmp_path) -> None: + Observatory.clear() + + model = _SmokeModel().eval() + graph_module = torch.fx.symbolic_trace(model) + + with Observatory.enable_context(): + Observatory.collect("smoke", graph_module) + + out = tmp_path / "report.html" + Observatory.export_html_report(str(out), title="Smoke") + + assert out.exists() + assert out.stat().st_size > 0 + + Observatory.clear() + + +def test_observatory_merges_analyze_only_graph_layers_and_defaults() -> None: + Observatory.clear() + + +def test_nonfinite_floats_are_serialized_as_strings() -> None: + payload = { + "nan_value": float("nan"), + "pos_inf": float("inf"), + "neg_inf": float("-inf"), + "finite": 1.25, + } + + encoded = json.dumps(payload, cls=_NonFiniteFloatAsStringJSONEncoder) + decoded = json.loads(encoded) + + assert decoded["nan_value"] == "nan" + assert decoded["pos_inf"] == "inf" + assert decoded["neg_inf"] == "-inf" + assert decoded["finite"] == 1.25 + + model = _SmokeModel().eval() + graph_module = torch.fx.symbolic_trace(model) + + with Observatory.enable_context(): + Observatory.collect("smoke", graph_module) + + payload = _build_report_payload_for_tests() + assert "smoke" in payload["graph_assets"] + + graph_layers = payload["graph_layers"].get("smoke", {}) + assert "graph_color/op_type" in graph_layers + assert "graph_color/op_target" in graph_layers + + records = payload["records"] + assert len(records) == 1 + graph_record = records[0]["views"]["graph"]["blocks"][0]["record"] + + default_layers = graph_record["default_layers"] + default_color_by = graph_record["default_color_by"] + + assert default_layers == ["graph_color/op_type", "graph_color/op_target"] + assert default_color_by == "graph_color/op_type" + for layer_id in default_layers: + assert layer_id in graph_layers + assert default_color_by in graph_layers + + Observatory.clear() diff --git a/devtools/observatory/tests/test_observe_pass.py b/devtools/observatory/tests/test_observe_pass.py new file mode 100644 index 00000000000..739e7356c11 --- /dev/null +++ b/devtools/observatory/tests/test_observe_pass.py @@ -0,0 +1,218 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from torch.fx.passes.infra.pass_base import PassBase, PassResult + +from executorch.devtools.observatory import Observatory, observe_pass + + +class _ToyModel(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.fc = torch.nn.Linear(4, 4) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.fc(x) + + +class _IdentityPass(PassBase): + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + return PassResult(graph_module, False) + + +def _make_gm() -> torch.fx.GraphModule: + return torch.fx.symbolic_trace(_ToyModel().eval()) + + +# --- Name deduplication in Observatory.collect() --- + + +def test_collect_dedup_names() -> None: + Observatory.clear() + gm = _make_gm() + with Observatory.enable_context(): + Observatory.collect("x", gm) + Observatory.collect("x", gm) + Observatory.collect("x", gm) + names = Observatory.list_collected() + assert "x" in names + assert "x #2" in names + assert "x #3" in names + assert len(names) == 3 + Observatory.clear() + + +def test_collect_no_dedup_for_unique_names() -> None: + Observatory.clear() + gm = _make_gm() + with Observatory.enable_context(): + Observatory.collect("a", gm) + Observatory.collect("b", gm) + names = Observatory.list_collected() + assert names == ["a", "b"] + Observatory.clear() + + +# --- observe_pass: instance wrapper --- + + +def test_observe_pass_instance_both() -> None: + Observatory.clear() + gm = _make_gm() + observed = observe_pass(_IdentityPass()) + with Observatory.enable_context(): + result = observed(gm) + assert isinstance(result, PassResult) + names = Observatory.list_collected() + assert "_IdentityPass/input" in names + assert "_IdentityPass/output" in names + Observatory.clear() + + +def test_observe_pass_instance_output_only() -> None: + Observatory.clear() + gm = _make_gm() + observed = observe_pass(_IdentityPass(), collect_input=False) + with Observatory.enable_context(): + observed(gm) + names = Observatory.list_collected() + assert "_IdentityPass" in names + assert len(names) == 1 + Observatory.clear() + + +def test_observe_pass_instance_input_only() -> None: + Observatory.clear() + gm = _make_gm() + observed = observe_pass(_IdentityPass(), collect_output=False) + with Observatory.enable_context(): + observed(gm) + names = Observatory.list_collected() + assert "_IdentityPass" in names + assert len(names) == 1 + Observatory.clear() + + +# --- observe_pass: class decorator --- + + +def test_observe_pass_class_decorator() -> None: + @observe_pass + class _DecoratedPass(PassBase): + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + return PassResult(graph_module, False) + + Observatory.clear() + gm = _make_gm() + p = _DecoratedPass() + with Observatory.enable_context(): + result = p(gm) + assert isinstance(result, PassResult) + names = Observatory.list_collected() + assert "_DecoratedPass/input" in names + assert "_DecoratedPass/output" in names + Observatory.clear() + + +def test_observe_pass_parameterized_class_decorator() -> None: + @observe_pass(name="Custom", collect_input=False) + class _ParamPass(PassBase): + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + return PassResult(graph_module, False) + + Observatory.clear() + gm = _make_gm() + with Observatory.enable_context(): + _ParamPass()(gm) + names = Observatory.list_collected() + assert "Custom" in names + assert len(names) == 1 + Observatory.clear() + + +# --- observe_pass: function wrapper --- + + +def test_observe_pass_function() -> None: + def my_pass_fn(gm: torch.fx.GraphModule) -> PassResult: + return PassResult(gm, False) + + Observatory.clear() + gm = _make_gm() + observed = observe_pass(my_pass_fn) + with Observatory.enable_context(): + result = observed(gm) + assert isinstance(result, PassResult) + names = Observatory.list_collected() + assert "my_pass_fn/input" in names + assert "my_pass_fn/output" in names + Observatory.clear() + + +# --- Deduplication on repeated calls --- + + +def test_observe_pass_repeated_calls_dedup() -> None: + Observatory.clear() + gm = _make_gm() + observed = observe_pass(_IdentityPass()) + with Observatory.enable_context(): + observed(gm) + observed(gm) + names = Observatory.list_collected() + assert "_IdentityPass/input" in names + assert "_IdentityPass/output" in names + assert "_IdentityPass/input #2" in names + assert "_IdentityPass/output #2" in names + assert len(names) == 4 + Observatory.clear() + + +# --- No-op outside context --- + + +def test_observe_pass_noop_outside_context() -> None: + Observatory.clear() + gm = _make_gm() + observed = observe_pass(_IdentityPass()) + result = observed(gm) + assert isinstance(result, PassResult) + assert len(Observatory.list_collected()) == 0 + Observatory.clear() + + +# --- PassResult preservation --- + + +def test_observe_pass_preserves_return_value() -> None: + class _ModifyingPass(PassBase): + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + return PassResult(graph_module, True) + + Observatory.clear() + gm = _make_gm() + observed = observe_pass(_ModifyingPass()) + with Observatory.enable_context(): + result = observed(gm) + assert result.modified is True + assert result.graph_module is gm + Observatory.clear() + + +# --- Custom name --- + + +def test_observe_pass_custom_name() -> None: + Observatory.clear() + gm = _make_gm() + observed = observe_pass(_IdentityPass(), name="MyStep") + with Observatory.enable_context(): + observed(gm) + names = Observatory.list_collected() + assert "MyStep/input" in names + assert "MyStep/output" in names + Observatory.clear() diff --git a/devtools/observatory/tests/test_per_layer_accuracy_lens.py b/devtools/observatory/tests/test_per_layer_accuracy_lens.py new file mode 100644 index 00000000000..a542f2b5dd3 --- /dev/null +++ b/devtools/observatory/tests/test_per_layer_accuracy_lens.py @@ -0,0 +1,238 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from types import SimpleNamespace +from typing import List + +import torch + +from executorch.devtools.observatory.interfaces import ObservationContext, RecordDigest +from executorch.devtools.observatory.lenses.accuracy import AccuracyLens +from executorch.devtools.observatory.lenses.per_layer_accuracy import ( + PerLayerAccuracyLens, +) + + +class _ToyModel(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.fc1 = torch.nn.Linear(4, 4) + self.fc2 = torch.nn.Linear(4, 4) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.fc1(x) + x = torch.relu(x) + return self.fc2(x) + + +def _non_io_nodes(graph_module: torch.fx.GraphModule) -> List[torch.fx.Node]: + return [ + n for n in graph_module.graph.nodes if n.op not in ("placeholder", "output") + ] + + +def _attach_same_root(graph_module: torch.fx.GraphModule, root_name: str) -> None: + for node in _non_io_nodes(graph_module): + node.meta["from_node"] = [SimpleNamespace(name=root_name, from_node=[])] + + +def _attach_ordered_roots(graph_module: torch.fx.GraphModule) -> None: + for idx, node in enumerate(_non_io_nodes(graph_module)): + node.meta["from_node"] = [SimpleNamespace(name=f"root_{idx}", from_node=[])] + + +def test_sparse_index_uses_last_topological_node_per_root() -> None: + gm = torch.fx.symbolic_trace(_ToyModel().eval()) + _attach_same_root(gm, "shared_root") + + sparse = PerLayerAccuracyLens._build_sparse_node_index(gm) + key = "root:shared_root" + assert key in sparse + assert sparse[key].node_id == _non_io_nodes(gm)[-1].name + + +def test_sparse_index_fallback_uses_node_id_when_from_node_missing() -> None: + gm = torch.fx.symbolic_trace(_ToyModel().eval()) + sparse = PerLayerAccuracyLens._build_sparse_node_index(gm) + node_names = [n.name for n in _non_io_nodes(gm)] + assert all(f"id:{name}" in sparse for name in node_names) + + +def test_analyze_produces_unified_metric_ranges_across_records() -> None: + def _digest(rows): + return { + "rows": rows, + "match_count": len(rows), + "sample_index": 0, + "sample_source": "test", + } + + rec_a = RecordDigest( + name="record_a", + timestamp=0.0, + data={ + "per_layer_accuracy": _digest( + [ + { + "target_node": "n1", + "cosine_sim": 0.90, + "psnr": 20.0, + "mse": 1e-3, + "abs_err": 1e-2, + }, + { + "target_node": "n2", + "cosine_sim": 0.95, + "psnr": 25.0, + "mse": 5e-4, + "abs_err": 5e-3, + }, + ] + ) + }, + ) + rec_b = RecordDigest( + name="record_b", + timestamp=1.0, + data={ + "per_layer_accuracy": _digest( + [ + { + "target_node": "n1", + "cosine_sim": 0.90, + "psnr": 30.0, + "mse": 2e-3, + "abs_err": 1e-2, + }, + { + "target_node": "n3", + "cosine_sim": 0.70, + "psnr": 18.0, + "mse": 8e-3, + "abs_err": 4e-2, + }, + ] + ) + }, + ) + + result = PerLayerAccuracyLens.analyze([rec_a, rec_b], {}) + + # Global ranges span the union across both records. + ranges = result.global_data["metric_ranges"] + assert ranges["cosine_sim"] == [0.70, 0.95] + assert ranges["psnr"] == [18.0, 30.0] + assert ranges["mse"] == [5e-4, 8e-3] + assert ranges["abs_err"] == [5e-3, 4e-2] + + # A node with the same metric value in both records gets the same color. + payload_a = ( + result.per_record_data["record_a"].graph_layers["cosine_sim"].to_payload() + ) + payload_b = ( + result.per_record_data["record_b"].graph_layers["cosine_sim"].to_payload() + ) + assert payload_a.nodes["n1"].fill_color == payload_b.nodes["n1"].fill_color + # And the two endpoints of the unified scale map to distinct colors. + assert payload_a.nodes["n2"].fill_color != payload_b.nodes["n3"].fill_color + + +def test_per_layer_accuracy_observe_analyze_and_frontend_defaults() -> None: + PerLayerAccuracyLens.clear() + + old_dataset = AccuracyLens._captured_dataset + old_worst = dict(AccuracyLens._worst_indices) + try: + torch.manual_seed(0) + anchor_model = _ToyModel().eval() + target_model = _ToyModel().eval() + with torch.no_grad(): + for p in target_model.parameters(): + p.add_(0.01) + + anchor_gm = torch.fx.symbolic_trace(anchor_model) + target_gm = torch.fx.symbolic_trace(target_model) + _attach_ordered_roots(anchor_gm) + _attach_ordered_roots(target_gm) + + sample = (torch.randn(1, 4),) + AccuracyLens._captured_dataset = [sample] + AccuracyLens._worst_indices = {"mse": 0} + + cfg = { + "per_layer_accuracy": { + "anchor_record_name": "Exported Float", + "worst_metric_priority": ["mse"], + } + } + + anchor_ctx = ObservationContext( + config=cfg, + shared_state={"record_name": "Exported Float"}, + ) + target_ctx = ObservationContext( + config=cfg, + shared_state={"record_name": "Quantized Model"}, + ) + + anchor_digest = PerLayerAccuracyLens.observe(anchor_gm, anchor_ctx) + target_digest = PerLayerAccuracyLens.observe(target_gm, target_ctx) + + assert isinstance(anchor_digest, dict) + assert isinstance(target_digest, dict) + assert target_digest["sample_index"] == 0 + assert target_digest["match_count"] > 0 + assert len(target_digest["rows"]) > 0 + first_row = target_digest["rows"][0] + assert "psnr" in first_row + assert "psnr_drop" not in first_row + assert "cosine_sim" in first_row + assert "mse" in first_row + assert "abs_err" in first_row + + records = [ + RecordDigest( + name="Exported Float", + timestamp=0.0, + data={"per_layer_accuracy": anchor_digest}, + ), + RecordDigest( + name="Quantized Model", + timestamp=1.0, + data={"per_layer_accuracy": target_digest}, + ), + ] + analysis = PerLayerAccuracyLens.analyze(records, {}) + assert "Quantized Model" in analysis.per_record_data + rec_analysis = analysis.per_record_data["Quantized Model"] + assert "psnr" in rec_analysis.graph_layers + assert "cosine_sim" in rec_analysis.graph_layers + assert "mse" in rec_analysis.graph_layers + assert "abs_err" in rec_analysis.graph_layers + psnr_payload = rec_analysis.graph_layers["psnr"].to_payload() + assert "sparse_match_key" in psnr_payload.sync_keys + mse_payload = rec_analysis.graph_layers["mse"].to_payload() + assert "sparse_match_key" in mse_payload.sync_keys + + frontend = PerLayerAccuracyLens.get_frontend_spec() + view = frontend.record( + target_digest, {"record": {}}, {"name": "Quantized Model", "index": 1} + ) + assert view is not None + graph_blocks = [b for b in view.blocks if getattr(b, "type", "") == "graph"] + assert graph_blocks + assert graph_blocks[0].record.default_color_by == "per_layer_accuracy/psnr" + assert "per_layer_accuracy/psnr" in graph_blocks[0].record.default_layers + html_blocks = [b for b in view.blocks if getattr(b, "type", "") == "html"] + assert html_blocks + assert html_blocks[0].id == "per_layer_accuracy_metrics_table" + assert "Per-layer Metrics" in html_blocks[0].title + finally: + AccuracyLens._captured_dataset = old_dataset + AccuracyLens._worst_indices = old_worst + PerLayerAccuracyLens.clear() diff --git a/devtools/observatory/tests/test_utils_and_stacktrace.py b/devtools/observatory/tests/test_utils_and_stacktrace.py new file mode 100644 index 00000000000..69e5996053a --- /dev/null +++ b/devtools/observatory/tests/test_utils_and_stacktrace.py @@ -0,0 +1,95 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from types import SimpleNamespace + +import pytest + +from executorch.devtools.observatory.lenses.stack_trace import StackTraceLens +from executorch.devtools.observatory import utils + + +@pytest.mark.parametrize( + "remote_url, expected", + [ + ("git@github.com:pytorch/executorch.git", "https://github.com/pytorch/executorch"), + ("git@github.qualcomm.com:MLG/executorch", "https://github.qualcomm.com/MLG/executorch"), + ( + "ssh://git@github.enterprise.local:8022/org/repo.git", + "https://github.enterprise.local:8022/org/repo", + ), + ("https://github.com/pytorch/executorch.git", "https://github.com/pytorch/executorch"), + ("file:///tmp/executorch", None), + ], +) +def test_normalize_remote_url(remote_url, expected): + assert utils._normalize_remote_url(remote_url) == expected # type: ignore[attr-defined] + + +def test_stack_trace_uses_commit_links(monkeypatch, tmp_path): + repo_root = tmp_path / "repo" + repo_root.mkdir() + source_file = repo_root / "pkg" / "module.py" + source_file.parent.mkdir(parents=True) + source_file.write_text("line = 1\n") + + frame = SimpleNamespace( + filename=str(source_file), + function="test_fn", + lineno=1, + code_context=["line = 1\n"], + ) + + git_info = SimpleNamespace( + commit_blob_url="https://github.com/org/repo/blob/abc123", + branch_blob_url="https://github.com/org/repo/blob/main", + github_link="https://github.com/org/repo/tree/main", + ) + + from executorch.devtools.observatory.lenses import stack_trace as stack_trace_mod + + monkeypatch.setattr(stack_trace_mod, "inspect", SimpleNamespace(stack=lambda: [frame])) + monkeypatch.setattr(stack_trace_mod, "get_repo_root", lambda: str(repo_root)) + monkeypatch.setattr(stack_trace_mod, "get_git_info", lambda: git_info) + monkeypatch.setattr(stack_trace_mod, "is_in_repo", lambda path: str(path).startswith(str(repo_root))) + + frames = StackTraceLens._get_stack_trace() + assert len(frames) == 1 + assert frames[0]["link"] == "https://github.com/org/repo/blob/abc123/pkg/module.py#L1" + + +def test_stack_trace_falls_back_to_branch_links(monkeypatch, tmp_path): + repo_root = tmp_path / "repo" + repo_root.mkdir() + source_file = repo_root / "pkg" / "fallback.py" + source_file.parent.mkdir(parents=True) + source_file.write_text("line = 2\n") + + frame = SimpleNamespace( + filename=str(source_file), + function="fallback_fn", + lineno=2, + code_context=["line = 2\n"], + ) + + git_info = SimpleNamespace( + commit_blob_url=None, + branch_blob_url="https://github.com/org/repo/blob/main", + github_link="https://github.com/org/repo/tree/main", + ) + + from executorch.devtools.observatory.lenses import stack_trace as stack_trace_mod + + monkeypatch.setattr(stack_trace_mod, "inspect", SimpleNamespace(stack=lambda: [frame])) + monkeypatch.setattr(stack_trace_mod, "get_repo_root", lambda: str(repo_root)) + monkeypatch.setattr(stack_trace_mod, "get_git_info", lambda: git_info) + monkeypatch.setattr(stack_trace_mod, "is_in_repo", lambda path: str(path).startswith(str(repo_root))) + + frames = StackTraceLens._get_stack_trace() + assert len(frames) == 1 + assert frames[0]["link"] == "https://github.com/org/repo/blob/main/pkg/fallback.py#L2" diff --git a/devtools/observatory/utils.py b/devtools/observatory/utils.py new file mode 100644 index 00000000000..6f563f92716 --- /dev/null +++ b/devtools/observatory/utils.py @@ -0,0 +1,191 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +import logging +import os +import subprocess +from dataclasses import dataclass +from typing import Optional +from urllib.parse import urlparse + + +@dataclass +class GitInfo: + """Git repository information for source links.""" + + remote_url: Optional[str] = None + remote_https_url: Optional[str] = None + branch: Optional[str] = None + commit_hash: Optional[str] = None + is_dirty: bool = False + github_link: Optional[str] = None + branch_blob_url: Optional[str] = None + commit_blob_url: Optional[str] = None + + +_cached_git_info: Optional[GitInfo] = None +_cached_repo_root: Optional[str] = None + + +def _strip_git_suffix(value: str) -> str: + return value[:-4] if value.endswith(".git") else value + + +def _normalize_remote_url(remote_url: str) -> Optional[str]: + """Return an https URL for browser links regardless of git remote scheme.""" + + if not remote_url: + return None + + remote_url = remote_url.strip() + if not remote_url: + return None + + if remote_url.startswith("git@"): + try: + user_host, path = remote_url.split(":", 1) + except ValueError: + return None + host = user_host.split("@", 1)[1] + cleaned_path = _strip_git_suffix(path.lstrip("/").rstrip("/")) + return f"https://{host}/{cleaned_path}" if cleaned_path else f"https://{host}" + + if remote_url.startswith("ssh://"): + parsed = urlparse(remote_url) + host = parsed.hostname + if not host: + return None + port_suffix = f":{parsed.port}" if parsed.port and parsed.port not in (22,) else "" + cleaned_path = _strip_git_suffix(parsed.path.lstrip("/").rstrip("/")) + return f"https://{host}{port_suffix}/{cleaned_path}" if cleaned_path else f"https://{host}{port_suffix}" + + parsed = urlparse(remote_url) + if parsed.scheme in {"http", "https"}: + host = parsed.hostname + if not host: + return None + port_suffix = "" + if parsed.port and not (parsed.scheme == "http" and parsed.port == 80) and not (parsed.scheme == "https" and parsed.port == 443): + port_suffix = f":{parsed.port}" + cleaned_path = _strip_git_suffix(parsed.path.lstrip("/").rstrip("/")) + return f"https://{host}{port_suffix}/{cleaned_path}" if cleaned_path else f"https://{host}{port_suffix}" + + return None + + +def get_repo_root() -> Optional[str]: + """Return repository root, or None when unavailable.""" + + global _cached_repo_root + if _cached_repo_root is not None: + return _cached_repo_root + + try: + result = subprocess.run( + ["git", "rev-parse", "--show-toplevel"], + capture_output=True, + text=True, + check=False, + ) + if result.returncode == 0: + _cached_repo_root = result.stdout.strip() + return _cached_repo_root + except Exception as exc: + logging.debug("[Observatory] Failed to detect repo root: %s", exc) + + return None + + +def get_git_info() -> GitInfo: + """Return git metadata used for stack trace source links.""" + + global _cached_git_info + if _cached_git_info is not None: + return _cached_git_info + + info = GitInfo() + + try: + upstream = subprocess.run( + ["git", "rev-parse", "--abbrev-ref", "--symbolic-full-name", "@{u}"], + capture_output=True, + text=True, + check=False, + ) + + remote_name = None + if upstream.returncode == 0: + upstream_name = upstream.stdout.strip() + if "/" in upstream_name: + remote_name, info.branch = upstream_name.split("/", 1) + + if info.branch is None: + local = subprocess.run( + ["git", "rev-parse", "--abbrev-ref", "HEAD"], + capture_output=True, + text=True, + check=False, + ) + if local.returncode == 0: + info.branch = local.stdout.strip() + remote_name = "origin" + + if remote_name: + remote_url = subprocess.run( + ["git", "config", "--get", f"remote.{remote_name}.url"], + capture_output=True, + text=True, + check=False, + ) + if remote_url.returncode == 0: + info.remote_url = remote_url.stdout.strip() + info.remote_https_url = _normalize_remote_url(info.remote_url) + + commit = subprocess.run( + ["git", "rev-parse", "HEAD"], + capture_output=True, + text=True, + check=False, + ) + if commit.returncode == 0: + info.commit_hash = commit.stdout.strip() + + dirty = subprocess.run( + ["git", "status", "--porcelain"], + capture_output=True, + text=True, + check=False, + ) + if dirty.returncode == 0: + info.is_dirty = bool(dirty.stdout.strip()) + + if info.remote_https_url and info.branch: + info.github_link = f"{info.remote_https_url}/tree/{info.branch}" + info.branch_blob_url = f"{info.remote_https_url}/blob/{info.branch}" + if info.remote_https_url and info.commit_hash: + info.commit_blob_url = f"{info.remote_https_url}/blob/{info.commit_hash}" + except Exception as exc: + logging.debug("[Observatory] Failed to query git info: %s", exc) + + _cached_git_info = info + return info + + +def is_in_repo(filepath: str) -> bool: + """Check whether filepath is inside the current repository root.""" + + repo_root = get_repo_root() + if repo_root is None: + return False + + try: + abs_path = os.path.abspath(filepath) + abs_root = os.path.abspath(repo_root) + return abs_path.startswith(abs_root) + except Exception: + return False