From ef65f2cd7cc62c04cf5c9031c4e6c9ec3118b245 Mon Sep 17 00:00:00 2001 From: svij Date: Tue, 26 May 2026 05:00:35 +0000 Subject: [PATCH] inspect init --- gigl/analytics/SPEC.md | 81 +++++ gigl/analytics/_inspect_impl_ai.py | 248 ++++++++++++++++ gigl/analytics/inspect.py | 129 ++++++++ tests/unit/analytics/__init__.py | 0 tests/unit/analytics/inspect_test.py | 428 +++++++++++++++++++++++++++ 5 files changed, 886 insertions(+) create mode 100644 gigl/analytics/SPEC.md create mode 100644 gigl/analytics/_inspect_impl_ai.py create mode 100644 gigl/analytics/inspect.py create mode 100644 tests/unit/analytics/__init__.py create mode 100644 tests/unit/analytics/inspect_test.py diff --git a/gigl/analytics/SPEC.md b/gigl/analytics/SPEC.md new file mode 100644 index 000000000..6b883705e --- /dev/null +++ b/gigl/analytics/SPEC.md @@ -0,0 +1,81 @@ +# SPEC: `gigl.analytics` + +Per-symbol contracts for **agent-owned** public symbols in `gigl.analytics`. Each section below is the contract for one +agent-owned symbol; the matching tests under `tests/unit/analytics/` pin its behavior, and the implementation in +`_*_impl_ai.py` is regenerable from this spec. + +Other public symbols in `gigl.analytics` are human-owned and not spec'd here — their docstrings in source are +authoritative. + +## Table of Contents + +Agent-owned (spec'd below): + +- [`summary`](#summary) — Per-seed fanout summary for a sampled HeteroData batch. +- `HeteroDataSummary` — Frozen dataclass returned by `summary`; see the `summary` section. + +Human-owned (see docstrings in source): + +- `log_startup_diagnostics` (`inspect.py`) — Logs sampler-RAM estimate + local partition counts at training startup. + +______________________________________________________________________ + +## `summary(data: HeteroData) -> HeteroDataSummary` + +### Purpose + +Per-seed fanout summary for sampled `HeteroData` batches from GiGL's neighbor loader implementations.
Use it in +progress logs to spot empty hops, degenerate fanouts, or seed-side mis-attribution at a glance. + +```python +from gigl.analytics.inspect import summary + +logger.info(f"fanout: {summary(batch)}") +# → "seeds=128 hop1(min=3 med=10 avg=12.5 max=25) hop2(min=12 med=80 avg=91.2 max=240)" +``` + +Use with batches produced by `DistNeighborLoader` and `DistABLPLoader` (`gigl/distributed/`); both attach the required +sampler metadata. + +### Interface + +```python +summary(data: HeteroData) -> HeteroDataSummary +``` + +`HeteroDataSummary` — `@dataclass(frozen=True)`: + +- `seeds: int` +- `per_hop: list[HeteroDataSummary.HopStats]` — one per hop, ordered 1..K. +- `__str__` renders `"seeds=N hop1(min=X med=Y avg=Z max=W) hop2(...) ..."` with `avg` formatted to one decimal place. + +`HeteroDataSummary.HopStats` — nested `@dataclass(frozen=True)` with `min: int`, `med: int`, `avg: float`, `max: int`: +per-seed neighbor count distribution at that hop. `med` is the lower of the two middle values when the number of seeds is even (e.g. counts `[1, 2]` give `med=1`). + +### Input requirements + +The batch must carry sampler-set metadata, or `summary` raises `ValueError`: + +- `data[seed_type].batch_size > 0` on exactly one node type. +- `data.num_sampled_nodes: dict[NodeType, Tensor]` at the HeteroData root — entry for the seed type, length `K + 1`. +- `data.num_sampled_edges: dict[EdgeType, Tensor]` at the HeteroData root — entry for every edge type in + `data.edge_types`, length `K`. + +### Error contract + +`summary` raises `ValueError` (no other exception types, no fallback) when: + +- Zero or multiple node types have `batch_size > 0`. +- `data.num_sampled_nodes` is missing, not a dict, or lacks the seed type. +- `data.num_sampled_edges` is missing, not a dict, or lacks any edge type in `data.edge_types`. +- `num_sampled_nodes[seed_type]` implies fewer than 1 hop. + +### Non-goals + +- Homogeneous `Data` objects — users can still call `summary` after converting them to `HeteroData`.
+ +### Verification + +`tests/unit/analytics/inspect_test.py` pins the behavior with hand-rolled HeteroData fixtures (per-seed walk tables in +the docstrings), `ValueError` guardrails for each contract violation, canonical `__str__` assertions, and end-to-end +runs against `DistNeighborLoader` and `DistABLPLoader`. Implementations must pass that suite. diff --git a/gigl/analytics/_inspect_impl_ai.py b/gigl/analytics/_inspect_impl_ai.py new file mode 100644 index 000000000..c051033cd --- /dev/null +++ b/gigl/analytics/_inspect_impl_ai.py @@ -0,0 +1,248 @@ +# AI-OWNED FILE +# spec: ./SPEC.md +# last-generated: 2026-05-22T00:00:00Z +# --- +"""Implementation of the HeteroData batch inspector. + +Contract is defined in ``gigl/analytics/SPEC.md``. The public surface lives in +``gigl.analytics.inspect``; this module is regenerable and not intended for +direct import by application code. +""" + +from dataclasses import dataclass + +import torch +from torch_geometric.data import HeteroData +from torch_geometric.typing import EdgeType, NodeType + + +@dataclass(frozen=True) +class HeteroDataSummary: + """Diagnostic summary of a sampled HeteroData batch. + + Attributes: + seeds: Number of seed nodes in the batch. + per_hop: ``HopStats`` per hop, in order from hop 1 to hop K. + """ + + @dataclass(frozen=True) + class HopStats: + """Per-seed neighbor count distribution at one hop.""" + + min: int + med: int + avg: float + max: int + + seeds: int + per_hop: list[HopStats] + + def __str__(self) -> str: + parts = [f"seeds={self.seeds}"] + for k, s in enumerate(self.per_hop, 1): + parts.append(f"hop{k}(min={s.min} med={s.med} avg={s.avg:.1f} max={s.max})") + return " ".join(parts) + + +def _summary_impl(data: HeteroData) -> HeteroDataSummary: + """Implementation. 
See ``gigl.analytics.inspect.summary`` for the public contract.""" + seed_type = _detect_seed_type(data) + batch_size = int(data[seed_type].batch_size) + num_hops = _detect_num_hops(data, seed_type) + edge_hop_bounds = _hop_boundaries(data) + + device = _pick_device(data) + seeds = torch.arange(batch_size, dtype=torch.long, device=device) + frontier: dict[NodeType, tuple[torch.Tensor, torch.Tensor]] = { + seed_type: (seeds, seeds) + } + + per_hop: list[HeteroDataSummary.HopStats] = [] + for hop in range(1, num_hops + 1): + hop_counts = torch.zeros(batch_size, dtype=torch.long, device=device) + new_parts: dict[NodeType, list[tuple[torch.Tensor, torch.Tensor]]] = {} + + for edge_type in data.edge_types: + src_type, dst_type = edge_type[0], edge_type[2] + # Walk from whichever side of the edge a frontier already covers. + # Edge_dir="out" loaders reverse the stored edges, so the seed + # often sits on the dst side at hop 1. + if src_type in frontier: + f_nodes, f_seeds = frontier[src_type] + walk_ei_src = 0 + other_type = dst_type + elif dst_type in frontier: + f_nodes, f_seeds = frontier[dst_type] + walk_ei_src = 1 + other_type = src_type + else: + continue + if f_nodes.numel() == 0: + continue + + start, end = edge_hop_bounds[edge_type][hop - 1 : hop + 1] + if start == end: + continue + hop_ei = data[edge_type].edge_index[:, start:end] + walk_src = hop_ei[walk_ei_src] + walk_dst = hop_ei[1 - walk_ei_src] + + new_nodes, new_seeds = _expand_frontier( + walk_src, walk_dst, f_nodes, f_seeds + ) + if new_seeds.numel() == 0: + continue + + hop_counts.scatter_add_( + 0, + new_seeds, + torch.ones(new_seeds.numel(), dtype=torch.long, device=device), + ) + if hop < num_hops: + new_parts.setdefault(other_type, []).append((new_nodes, new_seeds)) + + per_hop.append(_hop_stats(hop_counts)) + + if hop < num_hops: + frontier = { + nt: ( + torch.cat([n for n, _ in parts]), + torch.cat([s for _, s in parts]), + ) + for nt, parts in new_parts.items() + } + + return 
HeteroDataSummary(seeds=batch_size, per_hop=per_hop) + + +def _detect_seed_type(data: HeteroData) -> NodeType: + seed_types = [ + nt + for nt in data.node_types + if getattr(data[nt], "batch_size", None) is not None and data[nt].batch_size > 0 + ] + if len(seed_types) != 1: + raise ValueError( + f"Expected exactly one node type with batch_size > 0; found {seed_types}" + ) + return seed_types[0] + + +def _detect_num_hops(data: HeteroData, seed_type: NodeType) -> int: + num_sampled = _get_root_dict(data, "num_sampled_nodes") + if seed_type not in num_sampled: + raise ValueError( + f"data.num_sampled_nodes[{seed_type}] is missing — required to " + "infer the number of hops. Was this batch produced by a GiGL " + "sampler?" + ) + series = num_sampled[seed_type] + num_hops = len(series) - 1 + if num_hops < 1: + raise ValueError( + f"data.num_sampled_nodes[{seed_type}] implies {num_hops} hops; " + "expected at least 1." + ) + return num_hops + + +def _hop_boundaries(data: HeteroData) -> dict[EdgeType, list[int]]: + """Return prefix sums of ``data.num_sampled_edges[edge_type]`` per edge type. + + For edge type E with sampled-edge series ``[n1, n2, ...]``, the result + ``[0, n1, n1+n2, ...]`` lets us slice ``edge_index[:, bounds[K-1]:bounds[K]]`` + to get hop-K edges. + """ + num_sampled_edges = _get_root_dict(data, "num_sampled_edges") + result: dict[EdgeType, list[int]] = {} + for edge_type in data.edge_types: + if edge_type not in num_sampled_edges: + raise ValueError( + f"data.num_sampled_edges[{edge_type}] is missing — required to " + "slice edges per hop. Was this batch produced by a GiGL sampler?" 
+ ) + series = num_sampled_edges[edge_type] + prefix = [0] + running = 0 + for n in series.tolist() if torch.is_tensor(series) else series: + running += int(n) + prefix.append(running) + result[edge_type] = prefix + return result + + +def _get_root_dict(data: HeteroData, attr: str) -> dict: + """Read a sampler-set dict (``num_sampled_nodes`` / ``num_sampled_edges``) + from the HeteroData root; raise ``ValueError`` if absent.""" + try: + value = getattr(data, attr) + except (AttributeError, KeyError) as e: + raise ValueError( + f"data.{attr} is missing — required for hop accounting. " + "Was this batch produced by a GiGL sampler?" + ) from e + if not isinstance(value, dict): + raise ValueError( + f"Expected data.{attr} to be a dict keyed by type; got " + f"{type(value).__name__}." + ) + return value + + +def _expand_frontier( + walk_src: torch.Tensor, + walk_dst: torch.Tensor, + frontier_nodes: torch.Tensor, + frontier_seeds: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: + """Expand the ``(node, seed)`` frontier by one hop. + + For each ``(node, seed)`` pair, enumerates every edge where ``walk_src == + node`` and emits ``(walk_dst, seed)`` pairs. ``walk_src`` / ``walk_dst`` + let the caller choose which side of the edge corresponds to the frontier + (so the inspector works under either edge orientation). 
+ """ + device = walk_src.device + sort_idx = torch.argsort(walk_src, stable=True) + sorted_src = walk_src[sort_idx] + sorted_dst = walk_dst[sort_idx] + + range_start = torch.searchsorted(sorted_src, frontier_nodes, right=False) + range_end = torch.searchsorted(sorted_src, frontier_nodes, right=True) + lengths = range_end - range_start + + total = int(lengths.sum().item()) + if total == 0: + return frontier_nodes[:0], frontier_seeds[:0] + + new_seeds = frontier_seeds.repeat_interleave(lengths) + cumlen = torch.cat( + [torch.zeros(1, dtype=lengths.dtype, device=device), lengths.cumsum(0)] + ) + positions = torch.arange(total, device=device) + entry_idx = torch.searchsorted(cumlen[1:], positions, right=True) + within = positions - cumlen[entry_idx] + return sorted_dst[range_start[entry_idx] + within], new_seeds + + +def _pick_device(data: HeteroData) -> torch.device: + """Pick a device from any populated tensor in ``data`` (falls back to CPU).""" + for nt in data.node_types: + x = getattr(data[nt], "x", None) + if x is not None: + return x.device + for et in data.edge_types: + ei = data[et].edge_index + if ei.numel() > 0: + return ei.device + return torch.device("cpu") + + +def _hop_stats(counts: torch.Tensor) -> HeteroDataSummary.HopStats: + min_val, max_val = torch.aminmax(counts) + return HeteroDataSummary.HopStats( + min=int(min_val.item()), + med=int(counts.median().item()), + avg=float(counts.float().mean().item()), + max=int(max_val.item()), + ) diff --git a/gigl/analytics/inspect.py b/gigl/analytics/inspect.py new file mode 100644 index 000000000..f916ac512 --- /dev/null +++ b/gigl/analytics/inspect.py @@ -0,0 +1,129 @@ +"""Inspection utilities for GiGL distributed training. + +Includes per-batch fanout summaries (``summary``) for in-loop sanity and a +startup-time partition + sampler RAM logger (``log_startup_diagnostics``) +for catching silent failure modes at process boot. 
+""" + +import torch +from torch_geometric.data import HeteroData + +from gigl.analytics._inspect_impl_ai import HeteroDataSummary, _summary_impl +from gigl.common.logger import Logger +from gigl.distributed import DistDataset + +logger = Logger() + +# GLT allocates ~64 MB per sampling worker by default; used for the +# back-of-envelope RAM estimate logged at startup. +# TODO: (svij) - this should be auto configured i.e. not needed. +_GLT_DEFAULT_RAM_MB_PER_SAMPLING_WORKER = 64 + + +def summary(data: HeteroData) -> HeteroDataSummary: + """Per-seed fanout summary for a sampled HeteroData batch. + + Auto-detects the seed node type (the unique node type with ``batch_size > 0``) + and the number of hops (from ``data.num_sampled_nodes[seed_type]``). At each + hop, walks every edge type whose ``num_sampled_edges`` slice contains hop-K + edges. Edges are followed from whichever end is in the current frontier + (so the inspector works under both ``edge_dir="in"`` and ``edge_dir="out"`` + — under ``"out"`` the loader stores edges reversed, putting the seed on the + destination side). + + ``str(summary(data))`` produces: + ``"seeds=N hop1(min=X med=Y avg=Z max=W) hop2(...) ..."``. + + Example: + >>> from gigl.analytics.inspect import summary + >>> result = summary(batch) + >>> print(result) + seeds=128 hop1(min=3 med=10 avg=12.5 max=25) hop2(min=12 med=80 avg=91.2 max=240) + + Args: + data: HeteroData batch produced by a GiGL neighbor loader. Must carry + sampler metadata: ``batch_size`` on exactly one node type and the + root-level dicts ``data.num_sampled_nodes`` (keyed by node type) + and ``data.num_sampled_edges`` (keyed by edge type). + + Returns: + ``HeteroDataSummary`` with the seed count and per-hop + ``HeteroDataSummary.HopStats``. + + Raises: + ValueError: zero or multiple node types have ``batch_size > 0``, or + sampler metadata is missing for the seed type / any edge type.
+ """ + return _summary_impl(data) + + +def log_startup_diagnostics( + rank: int, + world_size: int, + dataset: DistDataset, + sampling_workers_per_process: int, + sampling_worker_shared_channel_size: str, +) -> None: + """Log sampler-RAM estimate and local partition counts at startup. + + Surfaces two silent failure modes in GLT-based distributed training: + sampler-worker RAM blowup (silent OOM) and partition misload (a rank + receives an empty or wrong shard and silently overfits to it). + + Call once per rank from the training process bootstrap, after the + dataset has been built and the distributed process group is initialized. + Emits one INFO line for the sampler RAM accounting, one INFO line for + the local node counts, and a WARNING for every node type with zero + local nodes. + + Example: + >>> from gigl.analytics.inspect import log_startup_diagnostics + >>> log_startup_diagnostics( + ... rank=0, + ... world_size=8, + ... dataset=dataset, + ... sampling_workers_per_process=4, + ... sampling_worker_shared_channel_size="4GB", + ... ) + + Args: + rank: Global rank of the calling process. + world_size: Total number of ranks in the distributed group. + dataset: Built ``DistDataset`` with ``node_ids`` populated. + sampling_workers_per_process: Number of GLT sampling workers per + training process; used for the RAM estimate. + sampling_worker_shared_channel_size: Shared-channel size string + (e.g. ``"4GB"``) passed to GLT; logged for visibility. + + Raises: + ValueError: ``dataset.node_ids`` is ``None`` (dataset not built). 
+ """ + ram_mb_per_rank = ( + sampling_workers_per_process * _GLT_DEFAULT_RAM_MB_PER_SAMPLING_WORKER + ) + logger.info( + f"rank={rank} sampler RAM/rank: " + f"workers={sampling_workers_per_process} " + f"channel={sampling_worker_shared_channel_size} " + f"≈ {ram_mb_per_rank} MB/rank × world_size={world_size}" + ) + + node_ids = dataset.node_ids + if node_ids is None: + raise ValueError("dataset.node_ids is None — dataset not built") + + if isinstance(node_ids, torch.Tensor): + count = node_ids.numel() + logger.info(f"rank={rank} local node count: {count}") + if count == 0: + logger.warning(f"rank={rank} has 0 nodes — partition misload?") + return + + node_counts = {nt: node_ids[nt].numel() for nt in node_ids} + logger.info(f"rank={rank} local node counts per type: {node_counts}") + for nt, count in node_counts.items(): + if count == 0: + logger.warning(f"rank={rank} has 0 {nt} nodes — partition misload?") + + +__all__ = ["HeteroDataSummary", "log_startup_diagnostics", "summary"] diff --git a/tests/unit/analytics/__init__.py b/tests/unit/analytics/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unit/analytics/inspect_test.py b/tests/unit/analytics/inspect_test.py new file mode 100644 index 000000000..4419f8110 --- /dev/null +++ b/tests/unit/analytics/inspect_test.py @@ -0,0 +1,428 @@ +"""Unit tests for gigl.analytics.inspect.""" + +import torch +import torch.multiprocessing as mp +from absl.testing import absltest +from graphlearn_torch.distributed import shutdown_rpc +from torch_geometric.data import HeteroData + +from gigl.analytics.inspect import HeteroDataSummary, summary +from gigl.distributed.dist_ablp_neighborloader import DistABLPLoader +from gigl.distributed.dist_dataset import DistDataset +from gigl.distributed.distributed_neighborloader import DistNeighborLoader +from gigl.src.common.types.graph_data import EdgeType, NodeType, Relation +from gigl.types.graph import ( + FeaturePartitionData, + GraphPartitionData, + 
PartitionOutput, + message_passing_to_negative_label, + message_passing_to_positive_label, +) +from tests.test_assets.distributed.utils import create_test_process_group +from tests.test_assets.test_case import TestCase + +_USER = NodeType("user") +_STORY = NodeType("story") +_USER_TO_STORY = EdgeType(_USER, Relation("to"), _STORY) +_STORY_TO_USER = EdgeType(_STORY, Relation("to"), _USER) + + +def _build_synthetic_hetero_dataset() -> DistDataset: + """Build a tiny in-memory heterogeneous DistDataset for inspector tests. + + Graph: 5 users, 5 stories. + + user → story story → user + (1 per user) (uneven fanout) + ───────────── ────────────── + u0 ──→ s0 s0 ──┬──→ u0 + u1 ──→ s1 └──→ u1 + u2 ──→ s2 s1 ──────→ u2 + u3 ──→ s3 s2 ──┬──→ u3 + u4 ──→ s4 └──→ u4 + s3 (no outgoing) + s4 (no outgoing) + """ + partition_output = PartitionOutput( + node_partition_book={ + _USER: torch.zeros(5), + _STORY: torch.zeros(5), + }, + edge_partition_book={ + _USER_TO_STORY: torch.zeros(5), + _STORY_TO_USER: torch.zeros(5), + }, + partitioned_edge_index={ + _USER_TO_STORY: GraphPartitionData( + edge_index=torch.tensor( + [[0, 1, 2, 3, 4], [0, 1, 2, 3, 4]], dtype=torch.long + ), + edge_ids=None, + ), + _STORY_TO_USER: GraphPartitionData( + edge_index=torch.tensor( + [[0, 0, 1, 2, 2], [0, 1, 2, 3, 4]], dtype=torch.long + ), + edge_ids=None, + ), + }, + partitioned_node_features={ + _USER: FeaturePartitionData(feats=torch.zeros(5, 2), ids=torch.arange(5)), + _STORY: FeaturePartitionData(feats=torch.zeros(5, 2), ids=torch.arange(5)), + }, + partitioned_edge_features=None, + partitioned_positive_labels=None, + partitioned_negative_labels=None, + partitioned_node_labels=None, + ) + dataset = DistDataset(rank=0, world_size=1, edge_dir="out") + dataset.build(partition_output=partition_output) + return dataset + + +def _run_inspect_summary_hetero(_, dataset: DistDataset) -> None: + """Subprocess body: build a loader, take one batch, validate summary(). + + All 5 users are seeds. 
Tracing each seed's 2-hop walk through the + synthetic graph (see ``_build_synthetic_hetero_dataset``): + + seed │ hop 1 │ hop 2 │ hop1 count │ hop2 count + ─────┼───────────┼────────────────┼────────────┼─────────── + u0 │ → {s0} │ → {u0, u1} │ 1 │ 2 + u1 │ → {s1} │ → {u2} │ 1 │ 1 + u2 │ → {s2} │ → {u3, u4} │ 1 │ 2 + u3 │ → {s3} │ → ∅ │ 1 │ 0 + u4 │ → {s4} │ → ∅ │ 1 │ 0 + + hop 1 counts = [1, 1, 1, 1, 1] → min=1 med=1 avg=1.0 max=1 + hop 2 counts = [2, 1, 2, 0, 0] → min=0 med=1 avg=1.0 max=2 + + GiGL's patch_fanout_for_sampling rejects -1 despite the loader docstring + claiming "all neighbors" is supported, so we pass a fanout that exceeds + every per-node out-degree in the synthetic graph (saturates to all). + """ + create_test_process_group() + assert isinstance(dataset.node_ids, dict) + loader = DistNeighborLoader( + dataset=dataset, + input_nodes=(_USER, dataset.node_ids[_USER]), # ty: ignore[invalid-argument-type] TODO(ty-torch-keyed-access): fix ty false positives for torch-backed keyed container access. + num_neighbors=[100, 100], + batch_size=5, + pin_memory_device=torch.device("cpu"), + ) + + batch = next(iter(loader)) + assert isinstance(batch, HeteroData) + + result = summary(batch) + + assert isinstance(result, HeteroDataSummary) + assert result.seeds == 5 + assert len(result.per_hop) == 2 + + hop1, hop2 = result.per_hop + assert hop1 == HeteroDataSummary.HopStats(min=1, med=1, avg=1.0, max=1) + assert hop2 == HeteroDataSummary.HopStats(min=0, med=1, avg=1.0, max=2) + + assert str(result) == ( + "seeds=5 hop1(min=1 med=1 avg=1.0 max=1) hop2(min=0 med=1 avg=1.0 max=2)" + ) + + shutdown_rpc() + + +def _build_synthetic_ablp_dataset() -> DistDataset: + """Build a tiny in-memory ABLP DistDataset with positive/negative labels. + + Anchors: u0, u1 (2 users). Supervision edge type: user → story. 
+ + message-passing edges label edges (stripped from batch.edge_types, + ───────────────────── surfaced as batch.y_positive / y_negative) + user → story ───────────────────────────────────────── + u0 ──→ s0 positive labels (user → story): + u1 ──→ s1 u0 ──→ s2 + story → user u1 ──→ s3 + s0 ──→ u0 negative labels (user → story): + s1 ──┬──→ u0 u0 ──→ s4 + └──→ u1 u1 ──→ s5 + """ + positive_et = message_passing_to_positive_label(_USER_TO_STORY) + negative_et = message_passing_to_negative_label(_USER_TO_STORY) + + edge_index = { + _USER_TO_STORY: torch.tensor([[0, 1], [0, 1]], dtype=torch.long), + _STORY_TO_USER: torch.tensor([[0, 1, 1], [0, 0, 1]], dtype=torch.long), + positive_et: torch.tensor([[0, 1], [2, 3]], dtype=torch.long), + negative_et: torch.tensor([[0, 1], [4, 5]], dtype=torch.long), + } + + partition_output = PartitionOutput( + node_partition_book={ + _USER: torch.zeros(2), + _STORY: torch.zeros(6), + }, + edge_partition_book={ + et: torch.zeros(int(ei.max().item()) + 1) for et, ei in edge_index.items() + }, + partitioned_edge_index={ + et: GraphPartitionData(edge_index=ei, edge_ids=torch.arange(ei.size(1))) + for et, ei in edge_index.items() + }, + partitioned_node_features={ + _USER: FeaturePartitionData(feats=torch.zeros(2, 2), ids=torch.arange(2)), + _STORY: FeaturePartitionData(feats=torch.zeros(6, 2), ids=torch.arange(6)), + }, + partitioned_edge_features=None, + partitioned_positive_labels=None, + partitioned_negative_labels=None, + partitioned_node_labels=None, + ) + dataset = DistDataset(rank=0, world_size=1, edge_dir="out") + dataset.build(partition_output=partition_output) + return dataset + + +def _run_inspect_summary_ablp(_, dataset: DistDataset) -> None: + """Subprocess body: build an ABLP loader, take one batch, validate summary(). + + Anchors: u0, u1 — both seeded. 
+ + seed │ hop 1 │ hop 2 │ hop1 count │ hop2 count + ─────┼───────────┼────────────────┼────────────┼─────────── + u0 │ → {s0} │ → {u0} │ 1 │ 1 + u1 │ → {s1} │ → {u0, u1} │ 1 │ 2 + + hop 1 counts = [1, 1] → min=1 med=1 avg=1.0 max=1 + hop 2 counts = [1, 2] → min=1 med=1 avg=1.5 max=2 + + Label stories (s2, s3, s4, s5) are added to the sampling frontier + internally — they have no outgoing message-passing edges in this graph, + so they do not contribute to the per-anchor counts. The batch carries + ``y_positive`` and ``y_negative`` dicts; label edge types are stripped + from ``batch.edge_types`` by the loader. + """ + create_test_process_group() + assert isinstance(dataset.node_ids, dict) + loader = DistABLPLoader( + dataset=dataset, + num_neighbors=[100, 100], + input_nodes=(_USER, torch.tensor([0, 1])), + supervision_edge_type=_USER_TO_STORY, + batch_size=2, + pin_memory_device=torch.device("cpu"), + ) + + batch = next(iter(loader)) + assert isinstance(batch, HeteroData) + + # Verify we are exercising the ABLP path: positive/negative label dicts + # must be attached, and label edge types must NOT appear in edge_types + # (they get stripped by the loader). 
+ assert hasattr(batch, "y_positive") + assert hasattr(batch, "y_negative") + assert isinstance(batch.y_positive, dict) and len(batch.y_positive) > 0 + assert isinstance(batch.y_negative, dict) and len(batch.y_negative) > 0 + positive_et = message_passing_to_positive_label(_USER_TO_STORY) + negative_et = message_passing_to_negative_label(_USER_TO_STORY) + assert positive_et not in batch.edge_types + assert negative_et not in batch.edge_types + + result = summary(batch) + + assert isinstance(result, HeteroDataSummary) + assert result.seeds == 2 + hop1, hop2 = result.per_hop + assert hop1 == HeteroDataSummary.HopStats(min=1, med=1, avg=1.0, max=1) + assert hop2 == HeteroDataSummary.HopStats(min=1, med=1, avg=1.5, max=2) + assert str(result) == ( + "seeds=2 hop1(min=1 med=1 avg=1.0 max=1) hop2(min=1 med=1 avg=1.5 max=2)" + ) + + shutdown_rpc() + + +class SummaryIntegrationTest(TestCase): + """End-to-end tests against real loader output (DistNeighborLoader + DistABLPLoader).""" + + def test_summary_with_dist_neighbor_loader(self): + dataset = _build_synthetic_hetero_dataset() + mp.spawn(fn=_run_inspect_summary_hetero, args=(dataset,)) + + def test_summary_with_dist_ablp_loader(self): + dataset = _build_synthetic_ablp_dataset() + mp.spawn(fn=_run_inspect_summary_ablp, args=(dataset,)) + + +class SummaryValidationTest(TestCase): + """Hand-rolled HeteroData inputs that exercise both strict-contract + guardrails and per-seed stat correctness. + + Hand-rolling lets us pre-compute the exact expected fanout from the + edge_index without going through a sampler. 
+ """ + + def test_no_batch_size_raises(self): + """No node type has batch_size > 0 → can't pick a seed type.""" + data = HeteroData( + { + _USER: {"x": torch.zeros((3, 1))}, + _STORY: {"x": torch.zeros((2, 1))}, + } + ) + with self.assertRaises(ValueError): + summary(data) + + def test_multiple_batch_sizes_raises(self): + """Two node types both have batch_size > 0 → ambiguous seed type.""" + data = HeteroData( + { + _USER: {"x": torch.zeros((3, 1)), "batch_size": 2}, + _STORY: {"x": torch.zeros((4, 1)), "batch_size": 3}, + } + ) + with self.assertRaises(ValueError): + summary(data) + + def test_missing_num_sampled_nodes_raises(self): + """data.num_sampled_nodes absent → can't infer hop count.""" + data = HeteroData( + { + _USER: {"x": torch.zeros((3, 1)), "batch_size": 3}, + _STORY: {"x": torch.zeros((2, 1))}, + _USER_TO_STORY: { + "edge_index": torch.tensor([[0, 1], [0, 1]], dtype=torch.long), + }, + } + ) + data.num_sampled_edges = {_USER_TO_STORY: torch.tensor([2])} + with self.assertRaises(ValueError): + summary(data) + + def test_missing_num_sampled_edges_raises(self): + """data.num_sampled_edges absent → can't slice edges per hop.""" + data = HeteroData( + { + _USER: {"x": torch.zeros((3, 1)), "batch_size": 3}, + _STORY: {"x": torch.zeros((2, 1))}, + _USER_TO_STORY: { + "edge_index": torch.tensor([[0, 1], [0, 1]], dtype=torch.long), + }, + } + ) + data.num_sampled_nodes = {_USER: torch.tensor([3, 2])} + with self.assertRaises(ValueError): + summary(data) + + def test_two_hop_uneven_fanout(self): + """Seeds: u0, u1, u2. 
+ + user → story story → user + u0 ──→ s0 s0 ──┬──→ u3 + u1 ──┬──→ s0 └──→ u4 + └──→ s1 s1 ──────→ u0 + u2 ──→ s2 s2 ──────→ u3 + + seed │ hop 1 │ hop 2 │ counts + ─────┼──────────────┼────────────────────┼───────── + u0 │ → {s0} │ → {u3, u4} │ 1, 2 + u1 │ → {s0, s1} │ → {u3, u4, u0} │ 2, 3 + u2 │ → {s2} │ → {u3} │ 1, 1 + + hop1 counts = [1, 2, 1] → min=1 med=1 avg=1.3 max=2 + hop2 counts = [2, 3, 1] → min=1 med=2 avg=2.0 max=3 + """ + data = HeteroData( + { + _USER: {"x": torch.zeros((5, 1)), "batch_size": 3}, + _STORY: {"x": torch.zeros((3, 1))}, + _USER_TO_STORY: { + "edge_index": torch.tensor( + [[0, 1, 1, 2], [0, 0, 1, 2]], dtype=torch.long + ), + }, + _STORY_TO_USER: { + "edge_index": torch.tensor( + [[0, 0, 1, 2], [3, 4, 0, 3]], dtype=torch.long + ), + }, + } + ) + data.num_sampled_nodes = { + _USER: torch.tensor([3, 0, 2]), + _STORY: torch.tensor([0, 3, 0]), + } + data.num_sampled_edges = { + _USER_TO_STORY: torch.tensor([4, 0]), + _STORY_TO_USER: torch.tensor([0, 4]), + } + + result = summary(data) + + self.assertEqual(result.seeds, 3) + self.assertEqual(len(result.per_hop), 2) + + hop1, hop2 = result.per_hop + self.assertEqual(hop1.min, 1) + self.assertEqual(hop1.med, 1) + self.assertEqual(hop1.max, 2) + self.assertAlmostEqual(hop1.avg, 4 / 3, places=5) + + self.assertEqual(hop2.min, 1) + self.assertEqual(hop2.med, 2) + self.assertEqual(hop2.max, 3) + self.assertAlmostEqual(hop2.avg, 2.0, places=5) + + self.assertEqual( + str(result), + "seeds=3 hop1(min=1 med=1 avg=1.3 max=2) hop2(min=1 med=2 avg=2.0 max=3)", + ) + + def test_zero_edges_all_hops(self): + """Seeds: u0, u1, u2 — no edges anywhere. 
+ + user → story story → user + (empty) (empty) + + seed │ hop 1 │ hop 2 │ counts + ─────┼───────┼───────┼───────── + u0 │ → ∅ │ → ∅ │ 0, 0 + u1 │ → ∅ │ → ∅ │ 0, 0 + u2 │ → ∅ │ → ∅ │ 0, 0 + + hop1 counts = [0, 0, 0] → min=0 med=0 avg=0.0 max=0 + hop2 counts = [0, 0, 0] → min=0 med=0 avg=0.0 max=0 + """ + data = HeteroData( + { + _USER: {"x": torch.zeros((3, 1)), "batch_size": 3}, + _STORY: {"x": torch.zeros((1, 1))}, + _USER_TO_STORY: { + "edge_index": torch.empty((2, 0), dtype=torch.long), + }, + _STORY_TO_USER: { + "edge_index": torch.empty((2, 0), dtype=torch.long), + }, + } + ) + data.num_sampled_nodes = { + _USER: torch.tensor([3, 0, 0]), + _STORY: torch.tensor([0, 0, 0]), + } + data.num_sampled_edges = { + _USER_TO_STORY: torch.tensor([0, 0]), + _STORY_TO_USER: torch.tensor([0, 0]), + } + + result = summary(data) + + self.assertEqual(result.seeds, 3) + self.assertEqual(len(result.per_hop), 2) + for hop_stats in result.per_hop: + self.assertEqual(hop_stats.min, 0) + self.assertEqual(hop_stats.med, 0) + self.assertEqual(hop_stats.max, 0) + self.assertEqual(hop_stats.avg, 0.0) + self.assertEqual( + str(result), + "seeds=3 hop1(min=0 med=0 avg=0.0 max=0) hop2(min=0 med=0 avg=0.0 max=0)", + )