81 changes: 81 additions & 0 deletions gigl/analytics/SPEC.md
@@ -0,0 +1,81 @@
# SPEC: `gigl.analytics`

Per-symbol contracts for **agent-owned** public symbols in `gigl.analytics`. Each section below is the contract for one
agent-owned symbol; the matching tests under `tests/unit/analytics/` pin its behavior, and the implementation in
`_*_impl_ai.py` is regenerable from this spec.

Other public symbols in `gigl.analytics` are human-owned and not spec'd here — their docstrings in source are
authoritative.

## Table of Contents

Agent-owned (spec'd below):

- [`summary`](#summary) — Per-seed fanout summary for a sampled HeteroData batch.
- `HeteroDataSummary` — Frozen dataclass returned by `summary`; see the `summary` section.

Human-owned (see docstrings in source):

- `log_startup_diagnostics` (`inspect.py`) — Logs sampler-RAM estimate + local partition counts at training startup.

______________________________________________________________________

## `summary(data: HeteroData) -> HeteroDataSummary`

### Purpose

Per-seed fanout summary for sampled `HeteroData` batches from GiGL's neighbor loader implementations. Use it in
progress logs to spot empty hops, degenerate fanouts, or seed-side mis-attribution at a glance.

```python
from gigl.analytics.inspect import summary

logger.info(f"fanout: {summary(batch)}")
# → "seeds=128 hop1(min=3 med=10 avg=12.5 max=25) hop2(min=12 med=80 avg=91.2 max=240)"
```

Use with batches produced by `DistNeighborLoader` and `DistABLPLoader` (`gigl/distributed/`); both attach the required
sampler metadata.

### Interface

```python
summary(data: HeteroData) -> HeteroDataSummary
```

`HeteroDataSummary` — `@dataclass(frozen=True)`:

- `seeds: int`
- `per_hop: list[HeteroDataSummary.HopStats]` — one per hop, ordered 1..K.
- `__str__` renders `"seeds=N hop1(min=X med=Y avg=Z max=W) hop2(...) ..."` with `avg` formatted to one decimal place.

`HeteroDataSummary.HopStats` — nested `@dataclass(frozen=True)` with `min: int`, `med: int`, `avg: float`, `max: int`:
per-seed neighbor count distribution at that hop.
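
As a rough illustration of how one hop's per-seed counts collapse into a `HopStats` entry and the rendered string, here is a minimal pure-Python sketch. The helper names `hop_stats_from_counts` and `render_summary` are hypothetical and not part of `gigl.analytics`. The lower-median convention is an assumption, chosen because the implementation uses `torch.median`, which returns the lower of the two middle values for even-length inputs.

```python
from statistics import median_low


def hop_stats_from_counts(counts: list[int]) -> dict:
    """Collapse per-seed neighbor counts for one hop into min/med/avg/max."""
    return {
        "min": min(counts),
        # Assumption: lower median for even-length inputs, as torch.median does.
        "med": median_low(counts),
        "avg": sum(counts) / len(counts),
        "max": max(counts),
    }


def render_summary(seeds: int, per_hop: list[dict]) -> str:
    """Render in the documented "seeds=N hop1(...) hop2(...)" shape."""
    parts = [f"seeds={seeds}"]
    for k, s in enumerate(per_hop, 1):
        parts.append(
            f"hop{k}(min={s['min']} med={s['med']} avg={s['avg']:.1f} max={s['max']})"
        )
    return " ".join(parts)


# Four seeds with 3, 10, 12 and 25 hop-1 neighbors respectively.
hop1 = hop_stats_from_counts([3, 10, 12, 25])
line = render_summary(seeds=4, per_hop=[hop1])
print(line)  # seeds=4 hop1(min=3 med=10 avg=12.5 max=25)
```

Note that `med` stays an integer under this convention, which matches the `med: int` field type.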

### Input requirements

The batch must carry sampler-set metadata, or `summary` raises `ValueError`:

- `data[seed_type].batch_size > 0` on exactly one node type.
- `data.num_sampled_nodes: dict[NodeType, Tensor]` at the HeteroData root — entry for the seed type, length `K + 1`.
- `data.num_sampled_edges: dict[EdgeType, Tensor]` at the HeteroData root — entry for every edge type in
`data.edge_types`, length `K`.
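
To make the expected shapes concrete, the sketch below uses plain lists as stand-ins for the tensors and a made-up `("user", "follows", "user")` edge type. It assumes the first entry of `num_sampled_nodes[seed_type]` counts the seeds and each later entry counts the nodes added at that hop. The number of hops `K` falls out of the series length, and prefix sums over the per-hop edge counts give the column ranges of `edge_index` that belong to each hop.

```python
from itertools import accumulate

# Hypothetical 2-hop batch (K = 2) with a single seed type and one edge type.
num_sampled_nodes = {"user": [2, 3, 5]}  # length K + 1: seeds, then nodes per hop
num_sampled_edges = {("user", "follows", "user"): [4, 7]}  # length K: edges per hop

seed_type = "user"
num_hops = len(num_sampled_nodes[seed_type]) - 1

# Prefix sums turn per-hop edge counts into (start, end) column ranges.
hop_slices = {}
for edge_type, series in num_sampled_edges.items():
    bounds = list(accumulate(series, initial=0))
    hop_slices[edge_type] = [(bounds[k], bounds[k + 1]) for k in range(num_hops)]

print(num_hops)    # 2
print(hop_slices)  # {('user', 'follows', 'user'): [(0, 4), (4, 11)]}
```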

### Error contract

`summary` raises `ValueError` (no other exception types, no fallback) when:

- Zero or multiple node types have `batch_size > 0`.
- `data.num_sampled_nodes` is missing, not a dict, or lacks the seed type.
- `data.num_sampled_edges` is missing, not a dict, or lacks any edge type in `data.edge_types`.
- `num_sampled_nodes[seed_type]` implies fewer than 1 hop.
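
Because `ValueError` is the only exception in the contract, a caller that logs fanout opportunistically can catch just that type. The wrapper below is a hypothetical sketch: `fanout_log_line` is not part of `gigl.analytics`, and it takes the summarizer as a parameter so the example runs without GiGL installed. In real code the argument would be `summary`.

```python
from typing import Any, Callable


def fanout_log_line(summarize: Callable[[Any], Any], batch: Any) -> str:
    """Build a progress-log line; degrade gracefully on a contract violation."""
    try:
        return f"fanout: {summarize(batch)}"
    except ValueError as err:
        # Per the error contract, ValueError is the only exception to expect.
        return f"fanout: unavailable ({err})"


# Stand-ins for the real `summary`, used only to exercise both paths.
def _ok(_batch: Any) -> str:
    return "seeds=2 hop1(min=1 med=1 avg=1.5 max=2)"


def _bad(_batch: Any) -> str:
    raise ValueError("data.num_sampled_nodes is missing")


ok_line = fanout_log_line(_ok, batch=None)
bad_line = fanout_log_line(_bad, batch=None)
```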

### Non-goals

- Homogeneous `Data` objects: not supported directly. Convert to `HeteroData` first, then call `summary`.

### Verification

`tests/unit/analytics/inspect_test.py` pins the behavior with hand-rolled HeteroData fixtures (per-seed walk tables in
the docstrings), `ValueError` guardrails for each contract violation, canonical `__str__` assertions, and end-to-end
runs against `DistNeighborLoader` and `DistABLPLoader`. Implementations must pass that suite.
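
A per-seed walk table of the kind those fixtures describe can be cross-checked with a brute-force reference. The sketch below is hypothetical and deliberately simplified: a single edge type, edges given as `(frontier_node, neighbor)` pairs already split by hop, and local node indices where seeds occupy `0..num_seeds-1`. It also assumes a node reached along several paths is counted once per path, which is how the vectorized walk in `_inspect_impl_ai.py` appears to behave.

```python
def per_seed_hop_counts(
    num_seeds: int, hop_edges: list[list[tuple[int, int]]]
) -> list[list[int]]:
    """Return table[k][s]: neighbors reached at hop k+1 when walking from seed s."""
    # Each frontier entry is a (node, seed) pair: "node was reached from seed".
    frontier = [(s, s) for s in range(num_seeds)]
    table = []
    for edges in hop_edges:
        counts = [0] * num_seeds
        next_frontier = []
        for node, seed in frontier:
            for src, dst in edges:
                if src == node:
                    counts[seed] += 1
                    next_frontier.append((dst, seed))
        table.append(counts)
        frontier = next_frontier
    return table


# Two seeds (0 and 1). Hop 1 reaches nodes 2 and 3; hop 2 reaches nodes 4 and 5.
hop1_edges = [(0, 2), (0, 3), (1, 3)]
hop2_edges = [(2, 4), (3, 4), (3, 5)]
table = per_seed_hop_counts(2, [hop1_edges, hop2_edges])
print(table)  # [[2, 1], [3, 2]]
```

Seed 0 reaches node 4 through both node 2 and node 3, so its hop-2 count is 3 rather than 2 under the per-path assumption.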
248 changes: 248 additions & 0 deletions gigl/analytics/_inspect_impl_ai.py
@@ -0,0 +1,248 @@
# AI-OWNED FILE
# spec: ./SPEC.md
# last-generated: 2026-05-22T00:00:00Z
# ---
"""Implementation of the HeteroData batch inspector.

Contract is defined in ``gigl/analytics/SPEC.md``. The public surface lives in
``gigl.analytics.inspect``; this module is regenerable and not intended for
direct import by application code.
"""

from dataclasses import dataclass

import torch
from torch_geometric.data import HeteroData
from torch_geometric.typing import EdgeType, NodeType


@dataclass(frozen=True)
class HeteroDataSummary:
    """Diagnostic summary of a sampled HeteroData batch.

    Attributes:
        seeds: Number of seed nodes in the batch.
        per_hop: ``HopStats`` per hop, in order from hop 1 to hop K.
    """

    @dataclass(frozen=True)
    class HopStats:
        """Per-seed neighbor count distribution at one hop."""

        min: int
        med: int
        avg: float
        max: int

    seeds: int
    per_hop: list[HopStats]

    def __str__(self) -> str:
        parts = [f"seeds={self.seeds}"]
        for k, s in enumerate(self.per_hop, 1):
            parts.append(f"hop{k}(min={s.min} med={s.med} avg={s.avg:.1f} max={s.max})")
        return " ".join(parts)


def _summary_impl(data: HeteroData) -> HeteroDataSummary:
    """Implementation. See ``gigl.analytics.inspect.summary`` for the public contract."""
    seed_type = _detect_seed_type(data)
    batch_size = int(data[seed_type].batch_size)
    num_hops = _detect_num_hops(data, seed_type)
    edge_hop_bounds = _hop_boundaries(data)

    device = _pick_device(data)
    seeds = torch.arange(batch_size, dtype=torch.long, device=device)
    frontier: dict[NodeType, tuple[torch.Tensor, torch.Tensor]] = {
        seed_type: (seeds, seeds)
    }

    per_hop: list[HeteroDataSummary.HopStats] = []
    for hop in range(1, num_hops + 1):
        hop_counts = torch.zeros(batch_size, dtype=torch.long, device=device)
        new_parts: dict[NodeType, list[tuple[torch.Tensor, torch.Tensor]]] = {}

        for edge_type in data.edge_types:
            src_type, dst_type = edge_type[0], edge_type[2]
            # Walk from whichever side of the edge a frontier already covers.
            # Edge_dir="out" loaders reverse the stored edges, so the seed
            # often sits on the dst side at hop 1.
            if src_type in frontier:
                f_nodes, f_seeds = frontier[src_type]
                walk_ei_src = 0
                other_type = dst_type
            elif dst_type in frontier:
                f_nodes, f_seeds = frontier[dst_type]
                walk_ei_src = 1
                other_type = src_type
            else:
                continue
            if f_nodes.numel() == 0:
                continue

            start, end = edge_hop_bounds[edge_type][hop - 1 : hop + 1]
            if start == end:
                continue
            hop_ei = data[edge_type].edge_index[:, start:end]
            walk_src = hop_ei[walk_ei_src]
            walk_dst = hop_ei[1 - walk_ei_src]

            new_nodes, new_seeds = _expand_frontier(
                walk_src, walk_dst, f_nodes, f_seeds
            )
            if new_seeds.numel() == 0:
                continue

            hop_counts.scatter_add_(
                0,
                new_seeds,
                torch.ones(new_seeds.numel(), dtype=torch.long, device=device),
            )
            if hop < num_hops:
                new_parts.setdefault(other_type, []).append((new_nodes, new_seeds))

        per_hop.append(_hop_stats(hop_counts))

        if hop < num_hops:
            frontier = {
                nt: (
                    torch.cat([n for n, _ in parts]),
                    torch.cat([s for _, s in parts]),
                )
                for nt, parts in new_parts.items()
            }

    return HeteroDataSummary(seeds=batch_size, per_hop=per_hop)


def _detect_seed_type(data: HeteroData) -> NodeType:
    seed_types = [
        nt
        for nt in data.node_types
        if getattr(data[nt], "batch_size", None) is not None and data[nt].batch_size > 0
    ]
    if len(seed_types) != 1:
        raise ValueError(
            f"Expected exactly one node type with batch_size > 0; found {seed_types}"
        )
    return seed_types[0]


def _detect_num_hops(data: HeteroData, seed_type: NodeType) -> int:
    num_sampled = _get_root_dict(data, "num_sampled_nodes")
    if seed_type not in num_sampled:
        raise ValueError(
            f"data.num_sampled_nodes[{seed_type}] is missing — required to "
            "infer the number of hops. Was this batch produced by a GiGL "
            "sampler?"
        )
    series = num_sampled[seed_type]
    num_hops = len(series) - 1
    if num_hops < 1:
        raise ValueError(
            f"data.num_sampled_nodes[{seed_type}] implies {num_hops} hops; "
            "expected at least 1."
        )
    return num_hops


def _hop_boundaries(data: HeteroData) -> dict[EdgeType, list[int]]:
    """Return prefix sums of ``data.num_sampled_edges[edge_type]`` per edge type.

    For edge type E with sampled-edge series ``[n1, n2, ...]``, the result
    ``[0, n1, n1+n2, ...]`` lets us slice ``edge_index[:, bounds[K-1]:bounds[K]]``
    to get hop-K edges.
    """
    num_sampled_edges = _get_root_dict(data, "num_sampled_edges")
    result: dict[EdgeType, list[int]] = {}
    for edge_type in data.edge_types:
        if edge_type not in num_sampled_edges:
            raise ValueError(
                f"data.num_sampled_edges[{edge_type}] is missing — required to "
                "slice edges per hop. Was this batch produced by a GiGL sampler?"
            )
        series = num_sampled_edges[edge_type]
        prefix = [0]
        running = 0
        for n in series.tolist() if torch.is_tensor(series) else series:
            running += int(n)
            prefix.append(running)
        result[edge_type] = prefix
    return result


def _get_root_dict(data: HeteroData, attr: str) -> dict:
    """Read a sampler-set dict (``num_sampled_nodes`` / ``num_sampled_edges``)
    from the HeteroData root; raise ``ValueError`` if absent."""
    try:
        value = getattr(data, attr)
    except (AttributeError, KeyError) as e:
        raise ValueError(
            f"data.{attr} is missing — required for hop accounting. "
            "Was this batch produced by a GiGL sampler?"
        ) from e
    if not isinstance(value, dict):
        raise ValueError(
            f"Expected data.{attr} to be a dict keyed by type; got "
            f"{type(value).__name__}."
        )
    return value


def _expand_frontier(
    walk_src: torch.Tensor,
    walk_dst: torch.Tensor,
    frontier_nodes: torch.Tensor,
    frontier_seeds: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Expand the ``(node, seed)`` frontier by one hop.

    For each ``(node, seed)`` pair, enumerates every edge where ``walk_src ==
    node`` and emits ``(walk_dst, seed)`` pairs. ``walk_src`` / ``walk_dst``
    let the caller choose which side of the edge corresponds to the frontier
    (so the inspector works under either edge orientation).
    """
    device = walk_src.device
    sort_idx = torch.argsort(walk_src, stable=True)
    sorted_src = walk_src[sort_idx]
    sorted_dst = walk_dst[sort_idx]

    range_start = torch.searchsorted(sorted_src, frontier_nodes, right=False)
    range_end = torch.searchsorted(sorted_src, frontier_nodes, right=True)
    lengths = range_end - range_start

    total = int(lengths.sum().item())
    if total == 0:
        return frontier_nodes[:0], frontier_seeds[:0]

    new_seeds = frontier_seeds.repeat_interleave(lengths)
    cumlen = torch.cat(
        [torch.zeros(1, dtype=lengths.dtype, device=device), lengths.cumsum(0)]
    )
    positions = torch.arange(total, device=device)
    entry_idx = torch.searchsorted(cumlen[1:], positions, right=True)
    within = positions - cumlen[entry_idx]
    return sorted_dst[range_start[entry_idx] + within], new_seeds


def _pick_device(data: HeteroData) -> torch.device:
    """Pick a device from any populated tensor in ``data`` (falls back to CPU)."""
    for nt in data.node_types:
        x = getattr(data[nt], "x", None)
        if x is not None:
            return x.device
    for et in data.edge_types:
        ei = data[et].edge_index
        if ei.numel() > 0:
            return ei.device
    return torch.device("cpu")


def _hop_stats(counts: torch.Tensor) -> HeteroDataSummary.HopStats:
    min_val, max_val = torch.aminmax(counts)
    return HeteroDataSummary.HopStats(
        min=int(min_val.item()),
        med=int(counts.median().item()),
        avg=float(counts.float().mean().item()),
        max=int(max_val.item()),
    )