diff --git a/.dockerignore b/.dockerignore index 66349c8a4..09e319073 100644 --- a/.dockerignore +++ b/.dockerignore @@ -106,6 +106,7 @@ venv.bak/ # Visual Code .vscode/ +*.code-workspace # terraform .terraform/ diff --git a/.gitignore b/.gitignore index 044d3b64c..e39b39890 100644 --- a/.gitignore +++ b/.gitignore @@ -115,6 +115,7 @@ venv.bak/ # local dev stuff +*.code-workspace .claude/ .devcontainer/ *.ipynb diff --git a/Dockerfile b/Dockerfile index 968f93043..2c394a889 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,6 +1,11 @@ # syntax=docker/dockerfile:1 ARG PYTHON_VERSION=3.12 -ARG BASE_IMAGE=tiangolo/uwsgi-nginx-flask:python${PYTHON_VERSION} +# Pin by digest. Without it, upstream rebuilds of the +# `python3.12` tag invalidate the Stage-1 cache and pull in newer +# transitive Python packages (e.g. importlib_metadata) that conflict +# with our requirements.txt pins. Bump the digest manually when you +# want to pull a fresher base. +ARG BASE_IMAGE=tiangolo/uwsgi-nginx-flask:python${PYTHON_VERSION}@sha256:329d84f4cc50ccd14d60eb02384713b4ae8723eddefda9fda342c7c3f17cdcb1 ###################################################### diff --git a/cloudbuild.yaml b/cloudbuild.yaml index 5f28ab80b..434248a9b 100644 --- a/cloudbuild.yaml +++ b/cloudbuild.yaml @@ -5,19 +5,26 @@ steps: args: ["-c", "docker login --username=$$USERNAME --password=$$PASSWORD"] secretEnv: ["USERNAME", "PASSWORD"] + # Build + push in one BuildKit invocation using a docker-container + # builder (required for registry-type cache export). The builder + # `--use` setting is client-side and doesn't persist across cloudbuild + # steps, so create + use + build must happen in a single step. + # Registry cache at the fixed :buildcache tag lets unchanged stages + # (conda env, bigtable emulator, pip install) reuse the previous + # build's exact layer artifacts, so already-warm nodes only download + # what actually changed on pull. 
- name: "gcr.io/cloud-builders/docker" entrypoint: "bash" args: - "-c" - | - DOCKER_BUILDKIT=1 docker build -t $$USERNAME/pychunkedgraph:$TAG_NAME . - timeout: 600s - secretEnv: ["USERNAME"] - - # Push the final image to Dockerhub - - name: "gcr.io/cloud-builders/docker" - entrypoint: "bash" - args: ["-c", "docker push $$USERNAME/pychunkedgraph:$TAG_NAME"] + docker buildx create --use --name pcg-builder --driver docker-container + docker buildx build \ + --cache-from type=registry,ref=$$USERNAME/pychunkedgraph:buildcache \ + --cache-to type=registry,ref=$$USERNAME/pychunkedgraph:buildcache,mode=max \ + --push \ + -t $$USERNAME/pychunkedgraph:$TAG_NAME . + timeout: 1800s secretEnv: ["USERNAME"] availableSecrets: diff --git a/docs/precomputed_ocdbt_hybrid.md b/docs/precomputed_ocdbt_hybrid.md new file mode 100644 index 000000000..d1caafd5f --- /dev/null +++ b/docs/precomputed_ocdbt_hybrid.md @@ -0,0 +1,101 @@ +# Hybrid base: precomputed + OCDBT fork (proposal) + +Status: proposal, not implemented. Open question is whether storage and ingest-compute savings justify the read-path complexity. + +## Problem + +PCG ingest copies the entire watershed segmentation into `/ocdbt/base/` in OCDBT format before any CG edit can happen. Per-CG forks at `/ocdbt/<graph_id>/` store only the deltas from SV splits. Two costs follow: + +- **Storage**: roughly 2× the segmentation footprint per dataset — original precomputed plus full OCDBT copy. +- **Ingest compute**: a per-chunk pass that reads the precomputed and writes it through the OCDBT driver. Hours of cluster time on TB-scale datasets. + +Both costs are paid up-front, before any user has done a single edit. The proposal here: skip the base copy and serve unedited chunks directly from the raw precomputed directory. Per-CG OCDBT forks remain as the delta store. 
+ +## Why the current architecture has the base copy + +Today's per-CG read spec is: + +``` +neuroglancer_precomputed + └─ kvstore: ocdbt + ├─ base: kvstack [base_layer, fork_manifest, fork_data] + └─ config: { compression, max_inline_value_bytes, ... } +``` + +When a reader asks for chunk key `8_8_40/1024-..._0-128`: + +1. The `neuroglancer_precomputed` driver passes the chunk key to its kvstore (the OCDBT driver). +2. OCDBT looks up the key **in its B+tree**. The B+tree's leaves map chunk keys to values. +3. If the key isn't in the B+tree, OCDBT returns not-found. It does not consult the kvstack any further. + +The three kvstack layers serve OCDBT's *internal* storage (B+tree manifest + node blobs + leaf blobs) — they have no visibility into chunk-key lookups. So the OCDBT B+tree must contain every chunk key the reader will ever ask for, and that's why ingest copies the whole watershed: to populate the B+tree. + +## What tensorstore primitives provide + +Confirmed against tensorstore docs: + +- **`kvstack` routes by exact / prefix match, with no fallthrough on miss.** A layer that claims a key range absorbs misses — they return `state='missing'` and do not cascade to the next layer. So we can't put raw precomputed below an OCDBT layer in a kvstack and expect kvstack to fall through when OCDBT doesn't have a key. +- **No native overlay/fallback kvstore driver.** `kvstack` is the only composition primitive at the kvstore level; it's precedence-based, not fallthrough. +- **OCDBT has no external-blob references.** B+tree leaves either inline the value or point to a data file under the OCDBT directory. There's no way to make a leaf reference a raw GCS precomputed file. +- **Array-level `stack` / `ts.overlay`** layers arrays by spatial domain. In overlapping regions, the later layer takes absolute precedence — missing-in-later does not fall back to earlier. + +No single tensorstore primitive provides "try OCDBT delta first, fall through to raw precomputed on miss." 
+ +## Architectural options + +### A — Two-stage read at the pcg layer + +PCG reads open two handles: the OCDBT fork for the delta, and a raw `neuroglancer_precomputed` reader for the watershed base. For any voxel region, issue both reads and merge with "delta wins where present, base fills the rest." + +- **Pros**: works inside pcg (`lookup_svs_from_seg`, sanity checks, debug tools) without any tensorstore changes. +- **Cons**: every pcg caller that uses `meta.ws_ocdbt` needs to route through a new merging reader. Neuroglancer doesn't benefit — it still gets a single kvstore spec from `dataset_info`. Either NG runs two layers itself (Option B) or we stand up a server-side proxy that does the merge before serving. + +### B — NG-side layer stack + +`dataset_info` publishes two precomputed layers: the raw watershed (read-only base) and the per-CG OCDBT fork (delta). NG composites them — visible segmentation is whichever has data at a given chunk. + +- **Pros**: no change to pcg's read path. Pushes the architecture complexity into the viewer. +- **Cons**: requires NG to treat "missing chunk in delta" as "fall through to base," not "render as background." Default NG behavior is the latter, so a viewer-side or proxy-side shim is likely needed. + +### C — Custom tensorstore kvstore driver + +A new "fallthrough" kvstore driver: read tries layer N, falls through on miss to layer N−1. Implement upstream in tensorstore or fork-and-maintain. + +- **Pros**: cleanest consumer-facing story — pcg and NG both keep using a single kvstore spec. +- **Cons**: tensorstore kvstore drivers are C++. Non-trivial maintenance surface; review/merge timeline if upstreaming. + +### D — Lazy base population (not a win on its own) + +Skip the ingest copy; copy a chunk from precomputed to OCDBT on first edit. Saves ingest compute. Does **not** save storage for reads — unedited chunks still 404 in OCDBT for a reader that doesn't have a fallback. Only useful in combination with A/B/C. 
+ +## Recommendation + +Measure first. Confirm the actual storage and ingest-compute savings on a real dataset and weigh against the engineering cost of A/B/C. + +If the savings justify the work, **A + B together** is the most pragmatic path: +- A gives pcg a single merged-read API. Edits, sanity checks, debug tooling keep working. +- B avoids standing up a proxy service for the viewer by letting NG handle the overlay. + +Both require upstream verification: +- **For A**: confirm that `(x0:x1, y0:y1, z0:z1)` reads on an OCDBT with sparse keys surface missing-ness *per chunk* at the `neuroglancer_precomputed` array layer (not per-region, not silently fill-valued). +- **For B**: confirm NG's segmentation loader can be configured to fall through gaps in one layer to another. If it can't, build a small server-side merging shim — at which point Option A's reader becomes that shim and B reduces to "publish two specs." + +C is the cleanest design but carries the highest cost. Pursue only if A/B turn out to have unworkable semantics. + +## Open questions before any implementation + +1. Does OCDBT's `read_result.state == 'missing'` surface per-chunk at the `neuroglancer_precomputed` array layer, or does the array silently fill missing chunks with fill-value? Verifiable by opening an OCDBT with sparse keys and reading a region that spans present + missing chunks. +2. Does NG distinguish "chunk returned as missing" from "chunk is all fill-value"? If not, a viewer-side overlay needs a shim regardless. +3. What's the actual delta volume per CG over its lifetime? If SV splits eventually touch a significant fraction of chunks, the storage win shrinks toward zero — at which point the simpler architecture (today's full base copy) wins on engineering cost. + +## Files to start from when implementing + +- `pychunkedgraph/graph/ocdbt.py` — spec construction (`build_cg_ocdbt_spec`), base population (`create_base_ocdbt`), fork setup (`fork_base_manifest`). 
+- `pychunkedgraph/ingest/cli.py`, `pychunkedgraph/ingest/cluster.py` — current base-copy flow. +- `pychunkedgraph/graph/utils/generic.py::get_local_segmentation` — single pcg read entry point that would need the two-stage merge in Option A. + +## Verification (per chosen option) + +- **A**: unit test that simulates a partial-delta OCDBT + raw precomputed and confirms the pcg reader returns the correct labels for spans crossing both. +- **B**: configure an NG link with both layers against a test dataset; compare the rendered segmentation to a known-good reference at edited and unedited regions. +- **C**: a tensorstore build with the new driver passes a fallthrough test (missing key in upper layer resolves from lower layer). diff --git a/pychunkedgraph/app/app_utils.py b/pychunkedgraph/app/app_utils.py index 9d69c3650..888b39ae4 100644 --- a/pychunkedgraph/app/app_utils.py +++ b/pychunkedgraph/app/app_utils.py @@ -229,20 +229,28 @@ def ccs(coordinates_nm_): return ccs coordinates = np.array(coordinates, dtype=int) - coordinates_nm = coordinates * cg.meta.resolution - max_dist_steps = np.array([4, 8, 14, 28], dtype=float) * np.mean(cg.meta.resolution) - - node_ids = np.array(node_ids, dtype=np.uint64) if len(coordinates.shape) != 2: raise cg_exceptions.BadRequest( f"Could not determine supervoxel ID for coordinates " f"{coordinates} - Validation stage." ) - # Fast path: all node_ids are L1 and OCDBT — single seg read for all coords - if cg.meta.ocdbt_seg and np.all(cg.get_chunk_layers(np.unique(node_ids)) == 1): + # OCDBT: always read the current segmentation at the click coords, + # regardless of node_ids layer. + # - 2D slice click: NG sends `node_id` = L1 SV from the slice view. + # That slice can be stale after an SV split; the seg read returns + # the current SV at that voxel (which may have a different root). + # - 3D mesh click: NG sends `node_id` = root; no L1 SV is attached, + # so we have to look it up against current seg anyway. 
+ # `node_ids` are not used as a constraint here. Stale UI surfaces + # downstream as "different roots" with the sv_id->root diagnostic + # mapping added in operation.py / cutting.py. + if cg.meta.ocdbt_seg: return lookup_svs_from_seg(cg.meta, coordinates) + coordinates_nm = coordinates * cg.meta.resolution + max_dist_steps = np.array([4, 8, 14, 28], dtype=float) * np.mean(cg.meta.resolution) + node_ids = np.array(node_ids, dtype=np.uint64) atomic_ids = np.zeros(len(coordinates), dtype=np.uint64) for node_id in np.unique(node_ids): node_id_m = node_ids == node_id diff --git a/pychunkedgraph/app/meshing/common.py b/pychunkedgraph/app/meshing/common.py index 10306543a..ca5277762 100644 --- a/pychunkedgraph/app/meshing/common.py +++ b/pychunkedgraph/app/meshing/common.py @@ -15,7 +15,6 @@ from pychunkedgraph.meshing.manifest import get_children_before_start_layer from pychunkedgraph.meshing.manifest import ManifestCache - __meshing_url_prefix__ = os.environ.get("MESHING_URL_PREFIX", "meshing") @@ -180,3 +179,12 @@ def _remeshing(serialized_cg_info, lvl2_nodes): def clear_manifest_cache(cg, node_id): node_ids = get_children_before_start_layer(cg, node_id, start_layer=2) ManifestCache(cg.graph_id).clear_fragments(node_ids) + + +def clear_manifest_cache_all(cg) -> int: + """Delete every cached manifest fragment for this graph. + + Returns the number of redis keys deleted across both initial and + dynamic caches (they share the ``:`` namespace). 
+ """ + return ManifestCache(cg.graph_id).clear_namespace() diff --git a/pychunkedgraph/app/meshing/v1/routes.py b/pychunkedgraph/app/meshing/v1/routes.py index dda067e90..8ee249942 100644 --- a/pychunkedgraph/app/meshing/v1/routes.py +++ b/pychunkedgraph/app/meshing/v1/routes.py @@ -9,7 +9,6 @@ from pychunkedgraph.app.app_utils import get_cg from pychunkedgraph.app.app_utils import remap_public - bp = Blueprint( "pcg_meshing_v1", __name__, url_prefix=f"/{common.__meshing_url_prefix__}/api/v1" ) @@ -98,3 +97,12 @@ def handle_remesh(table_id): def handle_clear_manifest_cache(table_id, node_id): cg = get_cg(table_id) common.clear_manifest_cache(cg, node_id) + + +@bp.route("/table/<table_id>/clear_manifest_cache", methods=["POST"]) +@auth_requires_permission("admin") +def handle_clear_manifest_cache_all(table_id): + """Drop every cached manifest fragment for this graph.""" + cg = get_cg(table_id) + deleted = common.clear_manifest_cache_all(cg) + return {"deleted": deleted} diff --git a/pychunkedgraph/app/segmentation/common.py b/pychunkedgraph/app/segmentation/common.py index 0ff758c2d..6ca817420 100644 --- a/pychunkedgraph/app/segmentation/common.py +++ b/pychunkedgraph/app/segmentation/common.py @@ -2,6 +2,7 @@ import json import os +import pickle import time from datetime import datetime, timezone from functools import reduce @@ -9,8 +10,8 @@ import numpy as np import pandas as pd -import fastremap from flask import current_app, g, jsonify, make_response, request +from messagingclient import MessagingClient from pytz import UTC from pychunkedgraph import __version__, get_logger @@ -25,9 +26,8 @@ exceptions as cg_exceptions, ) from pychunkedgraph.graph.analysis import pathing -from pychunkedgraph.graph.edits_sv import split_supervoxel from pychunkedgraph.graph.misc import get_contact_sites -from pychunkedgraph.debug.sv_split import check_unsplit_sv_bridges +from pychunkedgraph.graph.sv_split.debug import check_unsplit_sv_bridges from pychunkedgraph.graph.operation import 
GraphEditOperation from pychunkedgraph.graph import basetypes from pychunkedgraph.meshing import mesh_analysis @@ -106,11 +106,18 @@ def handle_info(table_id): combined_info["verify_mesh"] = cg.meta.custom_data.get("mesh", {}).get( "verify", False ) - mesh_dir = cg.meta.custom_data.get("mesh", {}).get("dir", None) + mesh_meta = cg.meta.custom_data.get("mesh", {}) + mesh_dir = mesh_meta.get("dir", None) if mesh_dir is not None: combined_info["mesh_dir"] = mesh_dir elif combined_info.get("mesh_dir", None) is not None: combined_info["mesh_dir"] = "graphene_meshes" + # `dynamic_mesh_dir` lets a dataset name the unsharded dynamic-mesh + # subdir explicitly. Default `"dynamic"` matches mesh_worker.py's + # fallback and NG's current hardcoded subdir name in graphene + # backend.ts (`${fragmentUrl}dynamic/`). NG must read + # this info field before non-default values route correctly. + combined_info["dynamic_mesh_dir"] = mesh_meta.get("dynamic_mesh_dir", "dynamic") return jsonify(combined_info) @@ -322,15 +329,13 @@ def publish_edit( is_priority=True, remesh: bool = True, ): - import pickle - - from messagingclient import MessagingClient - + downsample = bool(result.seg_bbox) attributes = { "table_id": table_id, "user_id": user_id, "remesh_priority": "true" if is_priority else "false", "remesh": "true" if remesh else "false", + "downsample": "true" if downsample else "false", } payload = { "operation_id": int(result.operation_id), @@ -338,6 +343,13 @@ def publish_edit( "new_root_ids": result.new_root_ids.tolist(), "old_root_ids": result.old_root_ids.tolist(), } + if downsample: + # Each entry is the base-resolution bbox of one supervoxel split's + # writes. Kept as a list (not merged) so the worker only rewrites + # tiles whose base footprint actually changed. 
+ payload["seg_bboxes"] = [ + [bbs.tolist(), bbe.tolist()] for bbs, bbe in result.seg_bbox + ] exchange = os.getenv("PYCHUNKEDGRAPH_EDITS_EXCHANGE", "pychunkedgraph") c = MessagingClient() @@ -434,84 +446,6 @@ def _get_sources_and_sinks(cg: ChunkedGraph, data): return (source_ids, sink_ids, source_coords, sink_coords) -def split_with_sv_splits(cg, data, user_id="test", mincut=True): - """Remove edges with automatic supervoxel splitting when needed. - - Attempts remove_edges. If source/sink SVs share a cross-chunk representative, - splits the overlapping SVs in the segmentation and retries. - """ - sources, sinks, source_coords, sink_coords = _get_sources_and_sinks(cg, data) - logger.note(f"pre-split: sources={sources}, sinks={sinks}") - t0 = time.time() - try: - ret = cg.remove_edges( - user_id=user_id, - source_ids=sources, - sink_ids=sinks, - source_coords=source_coords, - sink_coords=sink_coords, - mincut=mincut, - ) - logger.note(f"remove_edges ({time.time() - t0:.2f}s)") - except cg_exceptions.SupervoxelSplitRequiredError as e: - logger.note(f"sv split required ({time.time() - t0:.2f}s): {e}") - sources_remapped = fastremap.remap( - sources, - e.sv_remapping, - preserve_missing_labels=True, - in_place=False, - ) - sinks_remapped = fastremap.remap( - sinks, - e.sv_remapping, - preserve_missing_labels=True, - in_place=False, - ) - logger.note(f"remapped sources={sources_remapped}, sinks={sinks_remapped}") - overlap_mask = np.isin(sources_remapped, sinks_remapped) - logger.note(f"overlapping reps: {np.unique(sources_remapped[overlap_mask])}") - t1 = time.time() - for rep in np.unique(sources_remapped[overlap_mask]): - _mask0 = sources_remapped == rep - _mask1 = sinks_remapped == rep - split_supervoxel( - cg, - sources[_mask0][0], - source_coords[_mask0], - sink_coords[_mask1], - e.operation_id, - sv_remapping=e.sv_remapping, - ) - logger.note(f"sv splits done ({time.time() - t1:.2f}s)") - - sources, sinks, source_coords, sink_coords = _get_sources_and_sinks(cg, 
data) - logger.note(f"post-split: sources={sources}, sinks={sinks}") - t1 = time.time() - try: - ret = cg.remove_edges( - user_id=user_id, - source_ids=sources, - sink_ids=sinks, - source_coords=source_coords, - sink_coords=sink_coords, - mincut=mincut, - ) - except cg_exceptions.SupervoxelSplitRequiredError as e2: - # The cross-chunk representative group extends beyond the split - # bbox. Unsplit SVs inside the bbox still have inf edges to SVs - # outside, bridging source and sink through the broader component. - - logger.note(f"retry still requires sv split") - # check_unsplit_sv_bridges(cg, e2.sv_remapping, sources, sinks) - raise cg_exceptions.PreconditionError( - "Supervoxel split succeeded but the split region is too small " - "to fully separate source and sink. " - "Try placing source and sink points farther apart." - ) from e2 - logger.note(f"remove_edges after sv split ({time.time() - t1:.2f}s)") - return ret - - def handle_split(table_id): current_app.table_id = table_id user_id = str(g.auth_user.get("id", current_app.user_id)) @@ -523,8 +457,17 @@ def handle_split(table_id): cg = app_utils.get_cg(table_id, skip_cache=True) current_app.logger.debug(data) + sources, sinks, source_coords, sink_coords = _get_sources_and_sinks(cg, data) + logger.note(f"split inputs: sources={sources}, sinks={sinks}") try: - ret = split_with_sv_splits(cg, data, user_id, mincut) + ret = cg.remove_edges( + user_id=user_id, + source_ids=sources, + sink_ids=sinks, + source_coords=source_coords, + sink_coords=sink_coords, + mincut=mincut, + ) except cg_exceptions.LockingError as e: raise cg_exceptions.InternalServerError(e) except cg_exceptions.PreconditionError as e: diff --git a/pychunkedgraph/debug/profiler.py b/pychunkedgraph/debug/profiler.py deleted file mode 100644 index b74ddac76..000000000 --- a/pychunkedgraph/debug/profiler.py +++ /dev/null @@ -1,121 +0,0 @@ -from typing import Dict -from typing import List -from typing import Tuple - -import os -import time -from 
collections import defaultdict -from contextlib import contextmanager - - -class HierarchicalProfiler: - """ - Hierarchical profiler for detailed timing breakdowns. - Tracks timing at multiple levels and prints a breakdown at the end. - """ - - def __init__(self, enabled: bool = True): - self.enabled = enabled - self.timings: Dict[str, List[float]] = defaultdict(list) - self.call_counts: Dict[str, int] = defaultdict(int) - self.stack: List[Tuple[str, float]] = [] - self.current_path: List[str] = [] - - @contextmanager - def profile(self, name: str): - """Context manager for profiling a code block.""" - if not self.enabled: - yield - return - - full_path = ".".join(self.current_path + [name]) - self.current_path.append(name) - start_time = time.perf_counter() - - try: - yield - finally: - elapsed = time.perf_counter() - start_time - self.timings[full_path].append(elapsed) - self.call_counts[full_path] += 1 - self.current_path.pop() - - def print_report(self, operation_id=None): - """Print a detailed timing breakdown.""" - if not self.enabled or not self.timings: - return - - print("\n" + "=" * 80) - print( - f"PROFILER REPORT{f' (operation_id={operation_id})' if operation_id else ''}" - ) - print("=" * 80) - - # Group by depth level - by_depth: Dict[int, List[Tuple[str, float, int]]] = defaultdict(list) - for path, times in self.timings.items(): - depth = path.count(".") - total_time = sum(times) - count = self.call_counts[path] - by_depth[depth].append((path, total_time, count)) - - # Sort each level by total time - for depth in sorted(by_depth.keys()): - items = sorted(by_depth[depth], key=lambda x: -x[1]) - for path, total_time, count in items: - indent = " " * depth - avg_time = total_time / count if count > 0 else 0 - if count > 1: - print( - f"{indent}{path}: {total_time*1000:.2f}ms total " - f"({count} calls, {avg_time*1000:.2f}ms avg)" - ) - else: - print(f"{indent}{path}: {total_time*1000:.2f}ms") - - # Print summary - print("-" * 80) - top_level_total = 
sum( - sum(times) for path, times in self.timings.items() if "." not in path - ) - print(f"Total top-level time: {top_level_total*1000:.2f}ms") - - # Print top 10 slowest operations - print("\nTop 10 slowest operations:") - all_ops = [ - (path, sum(times), self.call_counts[path]) - for path, times in self.timings.items() - ] - all_ops.sort(key=lambda x: -x[1]) - for i, (path, total_time, count) in enumerate(all_ops[:10]): - pct = (total_time / top_level_total * 100) if top_level_total > 0 else 0 - print(f" {i+1}. {path}: {total_time*1000:.2f}ms ({pct:.1f}%)") - - print("=" * 80 + "\n") - - def reset(self): - """Reset all timing data.""" - self.timings.clear() - self.call_counts.clear() - self.stack.clear() - self.current_path.clear() - - -# Global profiler instance - enable via environment variable -PROFILER_ENABLED = os.environ.get("PCG_PROFILER_ENABLED", "0") == "1" -_profiler: HierarchicalProfiler = None - - -def get_profiler() -> HierarchicalProfiler: - """Get or create the global profiler instance.""" - global _profiler - if _profiler is None: - _profiler = HierarchicalProfiler(enabled=PROFILER_ENABLED) - return _profiler - - -def reset_profiler(): - """Reset the global profiler.""" - global _profiler - if _profiler is not None: - _profiler.reset() diff --git a/pychunkedgraph/graph/chunkedgraph.py b/pychunkedgraph/graph/chunkedgraph.py index 63791a932..1b1405217 100644 --- a/pychunkedgraph/graph/chunkedgraph.py +++ b/pychunkedgraph/graph/chunkedgraph.py @@ -3,6 +3,7 @@ import time import typing import datetime +from copy import deepcopy from itertools import chain from functools import reduce @@ -71,13 +72,20 @@ def __init__( self._cache_service = None self.mock_edges = None # hack for unit tests - # shim to update graph_id in meta for copied graphs + # Shim for copied bigtables: rewrite graph_id-bearing fields in one + # update_meta call. 
Bigtable row-copies preserve the source table's + # values for `graph_config.ID` and `custom_data["mesh"]["dynamic_mesh_dir"]` + # — left as-is, x0's edited meshes would alias clean's at the same + # fragment-id keys. `mesh.dir` (initial sharded meshes) is dataset- + # scoped and intentionally shared, so it's not rewritten here. if graph_id != self.graph_id: gc = self.meta.graph_config._asdict() gc["ID"] = graph_id - new_meta = ChunkedGraphMeta( - GraphConfig(**gc), self.meta.data_source, self.meta.custom_data - ) + cd = deepcopy(self.meta.custom_data) + mesh = cd.get("mesh") + if mesh is not None and "dynamic_mesh_dir" in mesh: + mesh["dynamic_mesh_dir"] = f"dynamic_{graph_id}" + new_meta = ChunkedGraphMeta(GraphConfig(**gc), self.meta.data_source, cd) self.update_meta(new_meta, overwrite=True) self._meta = new_meta @@ -1046,6 +1054,21 @@ def get_chunk_coordinates_multiple(self, node_or_chunk_ids: typing.Sequence): assert len(layers) == 0 or np.all(layers == layers[0]), "must be same layer." return chunk_utils.get_chunk_coordinates_multiple(self.meta, node_or_chunk_ids) + def get_chunk_center_voxel(self, node_or_chunk_id: basetypes.NODE_ID) -> np.ndarray: + """Approximate base-resolution voxel coord at the chunk's center. + + Useful for debugging: feed the returned ``[x, y, z]`` to NGL's + position bar to navigate to where a chunk lives in the volume. + Layer L chunk side = ``CHUNK_SIZE * 2 ** (L - 2)`` base voxels. 
+ """ + layer = int(self.get_chunk_layer(node_or_chunk_id)) + cx, cy, cz = self.get_chunk_coordinates(node_or_chunk_id) + chunk_size = np.asarray(self.meta.graph_config.CHUNK_SIZE, dtype=int) * ( + 2 ** (layer - 2) + ) + origin = self.meta.voxel_bounds[:, 0] + np.array([cx, cy, cz]) * chunk_size + return (origin + chunk_size // 2).astype(int) + def get_chunk_id( self, node_id: basetypes.NODE_ID = None, diff --git a/pychunkedgraph/graph/chunks/utils.py b/pychunkedgraph/graph/chunks/utils.py index 0e39fbf9f..cd3b96ccc 100644 --- a/pychunkedgraph/graph/chunks/utils.py +++ b/pychunkedgraph/graph/chunks/utils.py @@ -170,11 +170,9 @@ def _compute_chunk_id( ) -> np.uint64: s_bits_per_dim = meta.bitmasks[layer] if not (x < 2**s_bits_per_dim and y < 2**s_bits_per_dim and z < 2**s_bits_per_dim): - raise ValueError( - f"Coordinate is out of range \ + raise ValueError(f"Coordinate is out of range \ layer: {layer} bits/dim {s_bits_per_dim}. \ - [{x}, {y}, {z}]; max = {2 ** s_bits_per_dim}." - ) + [{x}, {y}, {z}]; max = {2 ** s_bits_per_dim}.") layer_offset = 64 - meta.graph_config.LAYER_ID_BITS x_offset = layer_offset - s_bits_per_dim y_offset = x_offset - s_bits_per_dim @@ -241,7 +239,9 @@ def get_bounding_children_chunks( @lru_cache() -def get_l2chunkids_along_boundary(cg_meta, mlayer: int, coord_a, coord_b, padding: int = 0): +def get_l2chunkids_along_boundary( + cg_meta, mlayer: int, coord_a, coord_b, padding: int = 0 +): """ Gets L2 Chunk IDs along opposing faces for larger chunks. If padding is enabled, more faces of L2 chunks are padded on both sides. @@ -284,17 +284,23 @@ def get_l2chunkids_along_boundary(cg_meta, mlayer: int, coord_a, coord_b, paddin return l2chunk_ids_a, l2chunk_ids_b -def chunks_overlapping_bbox(bbox_min, bbox_max, chunk_size) -> dict: +def chunks_overlapping_bbox(bbox_min, bbox_max, chunk_size, origin=0) -> dict: """ Find octree chunks overlapping with a bounding box in 3D and return a dictionary mapping chunk indices to clipped bounding boxes. 
+ + `origin` is the voxel coordinate of chunk index (0, 0, 0). Pass + `meta.voxel_bounds[:, 0]` so the lattice aligns to the dataset's + chunks; the default 0 leaves the lattice anchored at the volume + origin. """ bbox_min = np.asarray(bbox_min, dtype=int) bbox_max = np.asarray(bbox_max, dtype=int) chunk_size = np.asarray(chunk_size, dtype=int) + origin = np.asarray(origin, dtype=int) - start_idx = np.floor_divide(bbox_min, chunk_size).astype(int) - end_idx = np.floor_divide(bbox_max, chunk_size).astype(int) + start_idx = np.floor_divide(bbox_min - origin, chunk_size).astype(int) + end_idx = np.floor_divide(bbox_max - origin, chunk_size).astype(int) ix = np.arange(start_idx[0], end_idx[0] + 1) iy = np.arange(start_idx[1], end_idx[1] + 1) @@ -302,7 +308,7 @@ def chunks_overlapping_bbox(bbox_min, bbox_max, chunk_size) -> dict: grid = np.stack(np.meshgrid(ix, iy, iz, indexing="ij"), axis=-1, dtype=int) grid = grid.reshape(-1, 3) - chunk_min = grid * chunk_size + chunk_min = grid * chunk_size + origin chunk_max = chunk_min + chunk_size clipped_min = np.maximum(chunk_min, bbox_min) clipped_max = np.minimum(chunk_max, bbox_max) diff --git a/pychunkedgraph/graph/cutting.py b/pychunkedgraph/graph/cutting.py index e49cc9ded..95031c6d9 100644 --- a/pychunkedgraph/graph/cutting.py +++ b/pychunkedgraph/graph/cutting.py @@ -5,7 +5,8 @@ import graph_tool import graph_tool.flow -from typing import Tuple +from dataclasses import dataclass +from typing import Tuple, Union from typing import Sequence from typing import Iterable @@ -19,6 +20,40 @@ DEBUG_MODE = False +@dataclass +class Cut: + """Multicut produced a clean partition — these SV-pair edges are to be cut.""" + + atomic_edges: np.ndarray # shape (N, 2) + + +@dataclass +class PreviewCut: + """Multicut in preview mode — connected components after the proposed cut. + + `illegal_split` flags cases where the cut isolates source or sink. 
+ """ + + supervoxel_ccs: list + illegal_split: bool + + +@dataclass +class SvSplitRequired: + """Multicut could not partition without first splitting a supervoxel. + + Carries the cross-chunk-representative remapping the caller needs to + run the actual SV split. Returned (not raised) from run_multicut; the + SupervoxelSplitRequiredError that surfaces this condition is caught + inside run_multicut and never escapes as control flow. + """ + + sv_remapping: dict # old_sv_id -> rep_sv_id + + +MulticutResult = Union[Cut, PreviewCut, SvSplitRequired] + + class IsolatingCutException(Exception): """Raised when mincut would split off one of the labeled supervoxel exactly. This is used to trigger a PostconditionError with a custom message. @@ -668,21 +703,38 @@ def run_multicut( path_augment: bool = True, disallow_isolating_cut: bool = True, sv_split_supported: bool = False, -): - local_mincut_graph = LocalMincutGraph( - edges.get_pairs(), - edges.affinities, - source_ids, - sink_ids, - split_preview, - path_augment, - disallow_isolating_cut=disallow_isolating_cut, - sv_split_supported=sv_split_supported, - ) - atomic_edges = local_mincut_graph.compute_mincut() - if len(atomic_edges) == 0: +) -> MulticutResult: + """Run the multicut and return either the cut edges or an SV-split request. + + When `sv_split_supported=True`, the "source and sink share a cross-chunk + rep" condition is returned as `SvSplitRequired` rather than raised — + `SupervoxelSplitRequiredError` is an implementation detail of + `LocalMincutGraph` unwinding, caught at this boundary so it never + drives control flow in callers. 
+ """ + try: + local_mincut_graph = LocalMincutGraph( + edges.get_pairs(), + edges.affinities, + source_ids, + sink_ids, + split_preview, + path_augment, + disallow_isolating_cut=disallow_isolating_cut, + sv_split_supported=sv_split_supported, + ) + mincut_output = local_mincut_graph.compute_mincut() + except SupervoxelSplitRequiredError as err: + return SvSplitRequired(err.sv_remapping) + + if split_preview: + # compute_mincut returns (ccs, illegal_split) in preview mode. + supervoxel_ccs, illegal_split = mincut_output + return PreviewCut(supervoxel_ccs, illegal_split) + + if len(mincut_output) == 0: raise PostconditionError(f"Mincut failed. Try with a different set of points.") - return atomic_edges + return Cut(mincut_output) def run_split_preview( @@ -695,11 +747,15 @@ def run_split_preview( path_augment: bool = True, disallow_isolating_cut: bool = True, ): - root_ids = set( - cg.get_roots(np.concatenate([source_ids, sink_ids]), assert_roots=True) - ) + sink_and_source_ids = np.concatenate([source_ids, sink_ids]) + roots = cg.get_roots(sink_and_source_ids, assert_roots=True) + root_ids = set(roots) if len(root_ids) > 1: - raise PreconditionError("Supervoxels must belong to the same object.") + raise PreconditionError( + f"Supervoxels must belong to the same object. 
" + f"sources={list(source_ids)} sinks={list(sink_ids)} " + f"sv_id->root: {dict(zip(sink_and_source_ids.tolist(), roots.tolist()))}" + ) bbox = get_bounding_box(source_coords, sink_coords, bb_offset) l2id_agglomeration_d, edges = cg.get_subgraph( @@ -713,7 +769,7 @@ def run_split_preview( mask0 = np.isin(edges.node_ids1, supervoxels) mask1 = np.isin(edges.node_ids2, supervoxels) edges = edges[mask0 & mask1] - edges_to_remove, illegal_split = run_multicut( + result = run_multicut( edges, source_ids, sink_ids, @@ -722,8 +778,14 @@ def run_split_preview( disallow_isolating_cut=disallow_isolating_cut, sv_split_supported=cg.meta.ocdbt_seg, ) + if isinstance(result, SvSplitRequired): + # Preview callers can't perform an SV split; surface as a precondition. + raise PreconditionError( + "Supervoxel split required to cut these source/sink points; " + "preview is not available until an edit is applied." + ) - if len(edges_to_remove) == 0: + assert isinstance(result, PreviewCut), f"unexpected preview result type: {result!r}" + if len(result.supervoxel_ccs) == 0: raise PostconditionError("Mincut could not find any edges to remove.") - - return edges_to_remove, illegal_split + return result.supervoxel_ccs, result.illegal_split diff --git a/pychunkedgraph/graph/downsample.py b/pychunkedgraph/graph/downsample.py new file mode 100644 index 000000000..7a28305c4 --- /dev/null +++ b/pychunkedgraph/graph/downsample.py @@ -0,0 +1,342 @@ +"""Async mip-pyramid downsample worker support. + +An SV split writes at base resolution only; coarser mips are produced +afterwards by a pubsub worker that consumes this module's primitives. + +Work is organized into `pyramid_block`s. A block is a cubic physical +region sized so that at the coarsest scale in the pyramid it equals +exactly one storage chunk. 
def num_output_mips(meta) -> int:
    """Return the number of non-base scales in the pyramid.

    ``meta.ws_ocdbt_scales`` holds every scale including the base one, so
    the count of mips the worker writes is its length minus one.
    """
    return len(meta.ws_ocdbt_scales) - 1


def uniform_factor(meta) -> tuple:
    """Return the per-axis downsample factor shared by all consecutive scales.

    tinybrain takes one factor tuple per call, so the factor has to be the
    same between every pair of neighbouring scales. Each entry of
    ``meta.ws_ocdbt_resolutions`` is divided by the entry before it.

    Args:
        meta: object exposing ``ws_ocdbt_resolutions``, one per-axis
            resolution sequence per scale.

    Returns:
        Tuple of Python ints, one factor per axis.

    Raises:
        ValueError: if fewer than two scales are listed, or if a ratio
            between consecutive resolutions is not a positive integer.
        AssertionError: if the factor differs between scale pairs. Raised
            explicitly so the check survives ``python -O``.
    """
    resolutions = [np.asarray(r, dtype=float) for r in meta.ws_ocdbt_resolutions]
    if len(resolutions) < 2:
        raise ValueError(
            "at least two scales are required to derive a downsample factor, "
            f"got {len(resolutions)}"
        )

    factors = []
    for finer, coarser in zip(resolutions[:-1], resolutions[1:]):
        with np.errstate(divide="ignore", invalid="ignore"):
            ratio = coarser / finer
        rounded = np.rint(ratio)
        # Round instead of truncating: float division such as 0.3 / 0.1
        # yields 2.9999999999999996, which a plain int cast turns into 2.
        is_integer_ratio = (
            np.all(np.isfinite(ratio))
            and np.allclose(ratio, rounded)
            and np.all(rounded >= 1)
        )
        if not is_integer_ratio:
            raise ValueError(
                f"downsample factor {ratio.tolist()} between resolutions "
                f"{finer.tolist()} and {coarser.tolist()} is not a positive integer"
            )
        factors.append(tuple(int(f) for f in rounded))

    if any(f != factors[0] for f in factors):
        raise AssertionError(f"non-uniform downsample factors {factors}")
    return factors[0]
def _chunk_size_at_scale(meta, scale_idx: int) -> np.ndarray:
    """Return the first three entries of a scale's read-chunk shape.

    The shape is taken from
    ``meta.ws_ocdbt_scales[scale_idx].chunk_layout.read_chunk.shape``; the
    trailing entry (the channel dim) is dropped.
    """
    read_chunk_shape = meta.ws_ocdbt_scales[scale_idx].chunk_layout.read_chunk.shape
    return np.array(read_chunk_shape[:3], dtype=int)


def block_shape(meta) -> np.ndarray:
    """Return the pyramid_block size in base-resolution voxels.

    Computed as the coarsest scale's storage chunk size multiplied by
    ``factor ** K`` per axis, so that one block maps to exactly one
    storage chunk at the coarsest scale.
    """
    coarsest_idx = num_output_mips(meta)
    chunk_at_coarsest = _chunk_size_at_scale(meta, coarsest_idx)
    per_axis_factor = np.array(uniform_factor(meta), dtype=int)
    base_voxels_per_coarsest_voxel = per_axis_factor**coarsest_idx
    return chunk_at_coarsest * base_voxels_per_coarsest_voxel


def blocks_for_bbox(meta, bbs, bbe) -> list:
    """List the block coords touched by a base-resolution bbox.

    The start is floored and the end is ceiled onto the block grid, so a
    bbox lying inside a single block yields that one block.

    Returns:
        Sorted list of ``(bx, by, bz)`` tuples of Python ints; callers
        use the sorted order for deadlock-free lock acquisition.
    """
    extent = block_shape(meta)
    first = np.asarray(bbs, dtype=int) // extent
    # Negate-floor-negate is integer ceil division.
    stop = -(-np.asarray(bbe, dtype=int) // extent)

    coords = []
    for bx in range(first[0], stop[0]):
        for by in range(first[1], stop[1]):
            for bz in range(first[2], stop[2]):
                coords.append((int(bx), int(by), int(bz)))
    coords.sort()
    return coords
def block_base_bbox(meta, block_coord) -> tuple:
    """Return the base-voxel bounds ``(lo, hi)`` of one block coord."""
    extent = block_shape(meta)
    origin = np.asarray(block_coord, dtype=int) * extent
    return origin, origin + extent


def _seg_bboxes_to_np(seg_bboxes):
    """Convert each ``(bbs, bbe)`` pair into a pair of integer numpy arrays."""
    converted = []
    for start, end in seg_bboxes:
        converted.append((np.asarray(start, dtype=int), np.asarray(end, dtype=int)))
    return converted


def _affected_region_base(meta, block_coord, seg_bboxes_np):
    """Compute the base-voxel region of this block that must be reprocessed.

    The region is the union of every seg bbox intersected with the block
    and with ``meta.voxel_bounds``, expanded outward to multiples of
    ``factor ** K`` and then clipped back to block-and-volume limits.

    Returns:
        ``(base_lo, base_hi)`` arrays, or ``None`` when the block lies
        outside the volume or no seg bbox overlaps it.
    """
    mip_count = num_output_mips(meta)
    grid = np.array(uniform_factor(meta), dtype=int) ** mip_count

    block_lo, block_hi = block_base_bbox(meta, block_coord)
    bounds = meta.voxel_bounds
    lo_limit = np.maximum(block_lo, bounds[:, 0])
    hi_limit = np.minimum(block_hi, bounds[:, 1])
    if np.any(hi_limit <= lo_limit):
        return None

    # Gather the non-empty intersections, then take their bounding box.
    starts, stops = [], []
    for seg_lo, seg_hi in seg_bboxes_np:
        start = np.maximum(seg_lo, lo_limit)
        stop = np.minimum(seg_hi, hi_limit)
        if np.any(stop <= start):
            continue
        starts.append(start)
        stops.append(stop)
    if not starts:
        return None
    union_lo = np.minimum.reduce(starts)
    union_hi = np.maximum.reduce(stops)

    # Expand to the factor**K grid, then clip. The clip keeps the grid
    # alignment only where the limit itself is aligned: block corners are
    # multiples of factor**K, volume bounds need not be.
    base_lo = np.maximum((union_lo // grid) * grid, lo_limit)
    base_hi = np.minimum(-(-union_hi // grid) * grid, hi_limit)
    if np.any(base_hi <= base_lo):
        return None
    return base_lo, base_hi
def _mip_bounds_from_base(base_lo, base_hi, scale):
    """Map a base-voxel region to mip voxel coords for a given ``scale``.

    The lower bound is floored and the upper bound is ceiled, so a
    partially covered mip voxel at the upper edge is included.
    """
    mip_lo = base_lo // scale
    mip_hi = -(-base_hi // scale)
    return mip_lo, mip_hi


def _process_block_in_memory(meta, base_region, K, factor):
    """Read the base region once, downsample all mips, write each output.

    Args:
        meta: provides ``ws_ocdbt_scales``; index 0 is read, indices
            ``1..K`` are written.
        base_region: ``(base_lo, base_hi)`` in base voxels, as returned by
            `_affected_region_base`.
        K: number of mips to produce.
        factor: per-axis downsample factor as an integer array.

    The write bounds use ceil for the upper edge. That matches the
    volume-edge handling in `_affected_region_at_mip` and keeps the write
    shape equal to the downsampled array when ``base_hi`` was clipped to
    an unaligned volume bound.
    NOTE(review): this relies on tinybrain emitting ceil(n / factor)
    voxels for odd extents — confirm against the pinned tinybrain version.
    For factor**K-aligned regions floor and ceil coincide.
    """
    base_lo, base_hi = base_region
    base = meta.ws_ocdbt_scales[0]
    arr = (
        base[
            base_lo[0] : base_hi[0],
            base_lo[1] : base_hi[1],
            base_lo[2] : base_hi[2],
            :,
        ]
        .read()
        .result()
    )
    mips = tinybrain.downsample_segmentation(
        arr, factor=tuple(int(f) for f in factor), num_mips=K, sparse=False
    )
    for m, out in enumerate(mips, start=1):
        mip_lo, mip_hi = _mip_bounds_from_base(base_lo, base_hi, factor**m)
        dst = meta.ws_ocdbt_scales[m]
        dst[
            mip_lo[0] : mip_hi[0],
            mip_lo[1] : mip_hi[1],
            mip_lo[2] : mip_hi[2],
            :,
        ].write(out).result()
def _affected_region_at_mip(
    block_lo_base,
    block_hi_base,
    vol_lo,
    vol_hi,
    seg_bboxes_base,
    mip: int,
    factor: np.ndarray,
    mip_chunk: np.ndarray,
):
    """Compute the write region for one mip, in that mip's voxel coords.

    Takes the union of (seg bbox ∩ block ∩ volume) in base voxels, scales
    it by ``factor ** mip``, expands it outward to multiples of
    ``mip_chunk`` and clips it to the volume bounds at this mip (lower
    bound floored, upper bound ceiled).

    Returns:
        ``(mip_lo, mip_hi)`` arrays, or ``None`` when nothing overlaps.
    """
    scale = factor**mip
    clipped_lo = np.maximum(block_lo_base, vol_lo)
    clipped_hi = np.minimum(block_hi_base, vol_hi)
    if np.any(clipped_hi <= clipped_lo):
        return None

    union_lo, union_hi = None, None
    for sb, eb in seg_bboxes_base:
        ilo = np.maximum(sb, clipped_lo)
        ihi = np.minimum(eb, clipped_hi)
        if np.any(ihi <= ilo):
            continue
        union_lo = ilo if union_lo is None else np.minimum(union_lo, ilo)
        union_hi = ihi if union_hi is None else np.maximum(union_hi, ihi)
    if union_lo is None:
        return None

    # Base voxels -> mip voxels, then outward to the chunk grid.
    mip_lo = ((union_lo // scale) // mip_chunk) * mip_chunk
    mip_hi = -(-(-(-union_hi // scale)) // mip_chunk) * mip_chunk

    mip_lo = np.maximum(mip_lo, vol_lo // scale)
    mip_hi = np.minimum(mip_hi, -(-vol_hi // scale))
    if np.any(mip_hi <= mip_lo):
        return None
    return mip_lo, mip_hi


def _source_read_bounds(mip_lo, mip_hi, factor, vol_hi, mip):
    """Return the region of mip ``mip - 1`` that feeds ``[mip_lo, mip_hi)``.

    The upper bound is clipped to the ceil-scaled volume extent at the
    source mip. Without the clip, a write region ending on a ceiled
    volume edge would ask for ``mip_hi * factor`` source voxels, one past
    the extent whenever that extent is odd.
    """
    src_lo = mip_lo * factor
    src_extent_hi = -(-np.asarray(vol_hi) // factor ** (mip - 1))
    src_hi = np.minimum(mip_hi * factor, src_extent_hi)
    return src_lo, src_hi


def _process_block_per_mip(meta, block_coord, seg_bboxes_np, K, factor):
    """Fallback path: downsample one mip at a time.

    Each iteration reads the affected region from mip ``N - 1``, runs one
    tinybrain step and writes mip ``N``. Used when the single base read
    of the in-memory path would exceed the memory budget.

    Reading back what the previous iteration wrote is only safe because
    the caller holds the block lock, so no other task writes the storage
    chunks this block owns.

    NOTE(review): the source read is clipped on the assumption that every
    scale's upper bound is ceil(volume_hi / factor**mip), as
    `_affected_region_at_mip` already assumes — confirm against how the
    scales are created.
    """
    vol_lo = meta.voxel_bounds[:, 0]
    vol_hi = meta.voxel_bounds[:, 1]
    block_lo_base, block_hi_base = block_base_bbox(meta, block_coord)

    for mip in range(1, K + 1):
        mip_chunk = _chunk_size_at_scale(meta, mip)
        region = _affected_region_at_mip(
            block_lo_base,
            block_hi_base,
            vol_lo,
            vol_hi,
            seg_bboxes_np,
            mip,
            factor,
            mip_chunk,
        )
        if region is None:
            continue
        mip_lo, mip_hi = region
        src = meta.ws_ocdbt_scales[mip - 1]
        src_lo, src_hi = _source_read_bounds(mip_lo, mip_hi, factor, vol_hi, mip)
        arr = (
            src[
                src_lo[0] : src_hi[0],
                src_lo[1] : src_hi[1],
                src_lo[2] : src_hi[2],
                :,
            ]
            .read()
            .result()
        )
        out = tinybrain.downsample_segmentation(
            arr, factor=tuple(int(f) for f in factor), num_mips=1, sparse=False
        )[0]
        dst = meta.ws_ocdbt_scales[mip]
        dst[
            mip_lo[0] : mip_hi[0],
            mip_lo[1] : mip_hi[1],
            mip_lo[2] : mip_hi[2],
            :,
        ].write(out).result()
def process_block(
    meta,
    block_coord,
    seg_bboxes,
    memory_budget_bytes: int = DEFAULT_MEMORY_BUDGET_BYTES,
):
    """Downsample one pyramid_block through every non-base mip.

    The caller must hold the block lock for the duration of the call.

    Only the aligned region covering `seg_bboxes` inside the block is
    read and written. If the byte size of that base region is within
    `memory_budget_bytes`, all mips are produced from a single base read;
    otherwise the mips are produced one at a time.

    Args:
        meta: ChunkedGraphMeta with `ws_ocdbt_scales` / `ws_ocdbt_resolutions`.
        block_coord: (bx, by, bz) block grid coord.
        seg_bboxes: iterable of `(bbs, bbe)` base-voxel bbox pairs from
            the SV splits that triggered this job.
        memory_budget_bytes: upper limit for the in-memory base read.

    Returns:
        None. Returns early without any I/O when no seg bbox overlaps
        the block.
    """
    mip_count = num_output_mips(meta)
    factor = np.array(uniform_factor(meta), dtype=int)
    seg_bboxes_np = _seg_bboxes_to_np(seg_bboxes)

    region = _affected_region_base(meta, block_coord, seg_bboxes_np)
    if region is None:
        return

    region_lo, region_hi = region
    itemsize = meta.ws_ocdbt_scales[0].dtype.numpy_dtype.itemsize
    voxel_count = int(np.prod(region_hi - region_lo))
    base_bytes = voxel_count * itemsize

    if base_bytes > memory_budget_bytes:
        logger.info(
            f"block {block_coord} base read {base_bytes / 1e9:.2f} GB exceeds "
            f"budget {memory_budget_bytes / 1e9:.2f} GB; using per-mip path"
        )
        _process_block_per_mip(meta, block_coord, seg_bboxes_np, mip_count, factor)
        return

    _process_block_in_memory(meta, region, mip_count, factor)
DRY_RUN_ENV = "PCG_DRY_RUN"


def is_dry_run() -> bool:
    """Report whether the process is in dry-run mode.

    Dry-run mode is on only when the ``PCG_DRY_RUN`` environment variable
    is exactly ``"1"``. An unset or empty variable, ``"true"`` or any
    other value counts as off, so a typo cannot enable it by accident.

    Write paths in the edit flow consult this flag and return early
    without persisting, which lets debug tooling re-run edits against
    production state without mutating it.
    """
    return os.getenv(DRY_RUN_ENV) == "1"


@contextmanager
def dry_run_scope():
    """Force ``PCG_DRY_RUN=1`` inside the ``with`` block.

    On exit the variable is put back to whatever the caller had before,
    including removing it if it was absent. Restoration also happens when
    the block raises.
    """
    saved = os.environ.get(DRY_RUN_ENV)
    os.environ[DRY_RUN_ENV] = "1"
    try:
        yield
    finally:
        if saved is not None:
            os.environ[DRY_RUN_ENV] = saved
        else:
            os.environ.pop(DRY_RUN_ENV, None)
edges["cross"].get_pairs() - if len(cx_edges) == 0: - return {node} - - explored_nodes = set([node]) - queue = deque([node]) - while queue: - vertex = queue.popleft() - mask = cx_edges[:, 0] == vertex - neighbors = cx_edges[mask][:, 1] - - if len(neighbors) > 0: - neighbor_coords = cg.get_chunk_coordinates_multiple(neighbors) - min_mask = (neighbor_coords >= min_coord).all(axis=1) - max_mask = (neighbor_coords < max_coord).all(axis=1) - neighbors = neighbors[min_mask & max_mask] - - for neighbor in neighbors: - if neighbor not in explored_nodes: - explored_nodes.add(neighbor) - queue.append(neighbor) - return explored_nodes - - -def _update_chunks(cg, chunks_bbox_map, seg, result_seg, bb_start): - """Process all chunks in a single pass: assign new SV IDs to split fragments. - - For each chunk overlapping the split bbox, finds split labels and - batch-allocates new IDs. No multiprocessing needed. - """ - results = [] - for chunk_coord, chunk_bbox in chunks_bbox_map.items(): - x, y, z = chunk_coord - chunk_id = cg.get_chunk_id(layer=1, x=x, y=y, z=z) - - _s, _e = chunk_bbox - bb_start - og_chunk_seg = seg[_s[0] : _e[0], _s[1] : _e[1], _s[2] : _e[2]] - chunk_seg = result_seg[_s[0] : _e[0], _s[1] : _e[1], _s[2] : _e[2]] - - labels = fastremap.unique(chunk_seg[chunk_seg != 0]) - if labels.size < 2: - continue - - new_ids = cg.id_client.create_node_ids(chunk_id, size=len(labels)) - _indices = [] - _old_values = [] - _new_values = [] - _label_id_map = {} - for _id, new_id in zip(labels, new_ids): - _mask = chunk_seg == _id - voxel_locs = np.where(_mask) - _og_value = og_chunk_seg[ - voxel_locs[0][0], voxel_locs[1][0], voxel_locs[2][0] - ] - _index = np.column_stack(voxel_locs) - n = len(_index) - _indices.append(_index) - _old_values.append(np.full(n, _og_value, dtype=basetypes.NODE_ID)) - _new_values.append(np.full(n, new_id, dtype=basetypes.NODE_ID)) - _label_id_map[int(_id)] = new_id - - _indices = np.concatenate(_indices) + (chunk_bbox[0] - bb_start) - _old_values = 
np.concatenate(_old_values) - _new_values = np.concatenate(_new_values) - results.append((_indices, _old_values, _new_values, _label_id_map)) - return results - - -def _voxel_crop(bbs, bbe, bbs_, bbe_): - xS, yS, zS = bbs - bbs_ - xE, yE, zE = (None if i == 0 else -1 for i in bbe_ - bbe) - voxel_overlap_crop = np.s_[xS:xE, yS:yE, zS:zE] - return voxel_overlap_crop - - -def _parse_results(results, seg, bbs, bbe): - """Merge per-chunk split results into a single segmentation volume. - - Applies new SV IDs from each chunk's split result to `seg` (in-place) - and builds the old→new mapping + label→new-id mapping. - - Returns (seg, old_new_map, new_id_label_map). - """ - old_new_map = defaultdict(set) - new_id_label_map = {} - for result in results: - if result: - indexer, old_values, new_values, label_id_map = result - seg[tuple(indexer.T)] = new_values - for old_sv, new_sv in zip(old_values, new_values): - old_new_map[old_sv].add(new_sv) - for label, new_id in label_id_map.items(): - new_id_label_map[new_id] = label - - assert np.all(seg.shape == bbe - bbs), f"{seg.shape} != {bbe - bbs}" - return seg, old_new_map, new_id_label_map - - -def split_supervoxel( - cg: ChunkedGraph, - sv_id: basetypes.NODE_ID, - source_coords: np.ndarray, - sink_coords: np.ndarray, - operation_id: int, - sv_remapping: dict, - verbose: bool = False, - time_stamp: datetime = None, -) -> dict[int, set]: - """ - Lookups coordinates of given supervoxel in segmentation. - Finds its counterparts split by chunk boundaries and splits them as a whole. - Updates the segmentation with new IDs. 
- """ - vol_start = cg.meta.voxel_bounds[:, 0] - vol_end = cg.meta.voxel_bounds[:, 1] - chunk_size = cg.meta.graph_config.CHUNK_SIZE - _coords = np.concatenate([source_coords, sink_coords]) - _padding = np.array([cg.meta.resolution[-1] * 2] * 3) / cg.meta.resolution - - bbs = np.clip((np.min(_coords, 0) - _padding).astype(int), vol_start, vol_end) - bbe = np.clip((np.max(_coords, 0) + _padding).astype(int), vol_start, vol_end) - chunk_min, chunk_max = bbs // chunk_size, np.ceil(bbe / chunk_size).astype(int) - bbs, bbe = chunk_min * chunk_size, chunk_max * chunk_size - logger.note(f"cg.meta.ws_ocdbt: {cg.meta.ws_ocdbt.shape}; res {cg.meta.resolution}") - logger.note(f"chunk and padding {chunk_size}; {_padding}") - logger.note(f"bbox and chunk min max {(bbs, bbe)}; {(chunk_min, chunk_max)}") - - t0 = time.time() - rep = sv_remapping.get(sv_id, sv_id) - all_svs = np.array( - [sv for sv, r in sv_remapping.items() if r == rep], - dtype=basetypes.NODE_ID, - ) - coords = cg.get_chunk_coordinates_multiple(all_svs) - in_bbox = (coords >= chunk_min).all(axis=1) & (coords < chunk_max).all(axis=1) - cut_supervoxels = set(all_svs[in_bbox].tolist()) - supervoxel_ids = np.array(list(cut_supervoxels), dtype=basetypes.NODE_ID) - logger.note( - f"whole sv {sv_id} -> {supervoxel_ids.tolist()} ({time.time() - t0:.2f}s)" - ) - - # one voxel overlap for neighbors - bbs_ = np.clip(bbs - 1, vol_start, vol_end) - bbe_ = np.clip(bbe + 1, vol_start, vol_end) - t0 = time.time() - seg = get_local_segmentation(cg.meta, bbs_, bbe_).squeeze() - logger.note(f"segmentation read {seg.shape} ({time.time() - t0:.2f}s)") - - binary_seg = np.isin(seg, supervoxel_ids) - voxel_overlap_crop = _voxel_crop(bbs, bbe, bbs_, bbe_) - t0 = time.time() - split_result = split_supervoxel_helper( - binary_seg[voxel_overlap_crop], - source_coords - bbs, - sink_coords - bbs, - cg.meta.resolution, - verbose=verbose, - ) - logger.note(f"split computation {split_result.shape} ({time.time() - t0:.2f}s)") - - 
chunks_bbox_map = chunks_overlapping_bbox(bbs, bbe, cg.meta.graph_config.CHUNK_SIZE) - t0 = time.time() - results = _update_chunks( - cg, chunks_bbox_map, seg[voxel_overlap_crop], split_result, bbs - ) - logger.note( - f"chunk updates {len(chunks_bbox_map)} chunks, {len(results)} with splits ({time.time() - t0:.2f}s)" - ) - - seg_cropped = seg[voxel_overlap_crop].copy() - new_seg, old_new_map, new_id_label_map = _parse_results( - results, seg_cropped, bbs, bbe - ) - logger.note( - f"old_new_map: {len(old_new_map)} SVs split, whole_sv: {len(cut_supervoxels)} SVs" - ) - unsplit = cut_supervoxels - set(old_new_map.keys()) - if unsplit: - logger.note(f"unsplit SVs (kept IDs): {unsplit}") - - sv_ids = fastremap.unique(seg) - roots = cg.get_roots(sv_ids) - sv_root_map = dict(zip(sv_ids, roots)) - root = sv_root_map[sv_id] - logger.note(f"{sv_id} -> {root}") - - root_mask = fastremap.remap(seg, sv_root_map, in_place=False) == root - seg[~root_mask] = 0 - sv_ids = fastremap.unique(seg) - seg[voxel_overlap_crop] = new_seg - t0 = time.time() - edges_tuple = update_edges( - cg, - root, - np.array([bbs, bbe]), - seg, - old_new_map, - new_id_label_map, - ) - logger.note(f"edge update ({time.time() - t0:.2f}s)") - - rows0 = copy_parents_and_add_lineage(cg, operation_id, old_new_map) - rows1 = add_new_edges(cg, edges_tuple, old_new_map, time_stamp=time_stamp) - rows = rows0 + rows1 - - t0 = time.time() - write_seg(cg.meta, bbs, bbe, new_seg) - cg.client.write(rows) - logger.note(f"write seg + {len(rows)} rows ({time.time() - t0:.2f}s)") - return old_new_map, edges_tuple - - -def copy_parents_and_add_lineage( - cg: ChunkedGraph, - operation_id: int, - old_new_map: dict, -) -> list: - """ - Copy parents column from `old_id` to each of `new_ids`. - This makes it easy to get old hierarchy with `new_ids` using an older timestamp. - Link `old_id` and `new_ids` to create a lineage at supervoxel layer. - Returns a list of mutations to be persisted. 
- """ - result = [] - parents = set() - old_new_map = {k: list(v) for k, v in old_new_map.items()} - parent_cells_map = cg.client.read_nodes( - node_ids=list(old_new_map.keys()), properties=attributes.Hierarchy.Parent - ) - for old_id, new_ids in old_new_map.items(): - for new_id in new_ids: - val_dict = { - attributes.Hierarchy.FormerIdentity: np.array( - [old_id], dtype=basetypes.NODE_ID - ), - attributes.OperationLogs.OperationID: operation_id, - } - result.append( - cg.client.mutate_row(serializers.serialize_uint64(new_id), val_dict) - ) - for cell in parent_cells_map[old_id]: - cache_utils.update(cg.cache.parents_cache, [new_id], cell.value) - parents.add(cell.value) - result.append( - cg.client.mutate_row( - serializers.serialize_uint64(new_id), - {attributes.Hierarchy.Parent: cell.value}, - time_stamp=cell.timestamp, - ) - ) - val_dict = { - attributes.Hierarchy.NewIdentity: np.array(new_ids, dtype=basetypes.NODE_ID) - } - result.append( - cg.client.mutate_row(serializers.serialize_uint64(old_id), val_dict) - ) - - children_cells_map = cg.client.read_nodes( - node_ids=list(parents), properties=attributes.Hierarchy.Child - ) - for parent, children_cells in children_cells_map.items(): - assert len(children_cells) == 1, children_cells - for cell in children_cells: - mask = np.isin(cell.value, list(old_new_map.keys())) - replace = np.concatenate([old_new_map[x] for x in cell.value[mask]]) - children = np.concatenate([cell.value[~mask], replace]) - cg.cache.children_cache[parent] = children - result.append( - cg.client.mutate_row( - serializers.serialize_uint64(parent), - {attributes.Hierarchy.Child: children}, - time_stamp=cell.timestamp, - ) - ) - return result diff --git a/pychunkedgraph/graph/locks.py b/pychunkedgraph/graph/locks.py index 47a63dacf..7f8304c27 100644 --- a/pychunkedgraph/graph/locks.py +++ b/pychunkedgraph/graph/locks.py @@ -1,6 +1,7 @@ +import hashlib +import time from concurrent.futures import ThreadPoolExecutor, as_completed -from typing 
import Union -from typing import Sequence +from typing import Sequence, Union from collections import defaultdict import networkx as nx @@ -8,9 +9,10 @@ from pychunkedgraph import get_logger -from . import exceptions +from . import attributes, exceptions, serializers from .types import empty_1d from .lineage import lineage_graph +from .dry_run import is_dry_run logger = get_logger(__name__) @@ -57,6 +59,9 @@ def __enter__(self): if not self.operation_id: self.operation_id = self.cg.id_client.create_operation_id() + if is_dry_run(): + return self + if self.privileged_mode: return self @@ -82,6 +87,8 @@ def __enter__(self): return self def __exit__(self, exception_type, exception_value, traceback): + if is_dry_run(): + return if self.lock_acquired: max_workers = min(8, max(1, len(self.locked_root_ids))) with ThreadPoolExecutor(max_workers=max_workers) as executor: @@ -138,6 +145,8 @@ def __init__( self.future_root_ids_d = future_root_ids_d def __enter__(self): + if is_dry_run(): + return self if self.privileged_mode: return self if not self.cg.client.renew_locks(self.root_ids, self.operation_id): @@ -165,6 +174,16 @@ def __enter__(self): return self def __exit__(self, exception_type, exception_value, traceback): + if is_dry_run(): + return + if exception_type is not None: + # Partial bigtable hierarchy writes may have landed before + # the exception propagated. Keep the indefinite cells held + # so subsequent ops on these roots refuse to acquire — + # forces operator recovery (`repair_operation(..., unlock= + # True)`) rather than letting a silent corruption slip into + # further edits. 
def _downsample_block_lock_row_key(block_coord) -> bytes:
    """Build the row key of one pyramid_block's downsample lock cell.

    Layout (26 bytes):
      - 2 bytes: blake2b digest of the packed coord. The prefix scatters
        spatially clustered coords across the key space instead of
        leaving them in one lexicographic range.
      - 24 bytes: the coord itself, one big-endian unsigned 64-bit value
        per axis. Carrying the full coord keeps keys unique even when
        two coords share the 2-byte prefix.

    Raises:
        ValueError: if `block_coord` does not have exactly three entries.
        OverflowError: if a coordinate is negative or exceeds 64 bits.
    """
    bx, by, bz = (int(c) for c in block_coord)
    packed = (
        bx.to_bytes(8, "big", signed=False)
        + by.to_bytes(8, "big", signed=False)
        + bz.to_bytes(8, "big", signed=False)
    )
    return hashlib.blake2b(packed, digest_size=2).digest() + packed


class DownsampleBlockLock:
    """Lock a set of pyramid_blocks for the lifetime of a downsample task.

    The downsample worker holds one across read → tinybrain → write for
    every block it touches. Acquisition is all-or-nothing: if any block
    cannot be locked, the locks taken so far are released and the whole
    set is retried with exponential backoff. After the last attempt a
    `LockingError` is raised so the pubsub message ends up un-acked and
    redelivered.

    Uses `cg.client.lock_by_row_key` with hash-prefixed row keys, so
    these rows do not collide with node-id-keyed root locks even though
    both use the same `Concurrency.Lock` column.
    """

    __slots__ = ["cg", "block_coords", "operation_id", "acquired_keys"]

    # Retry budget for partial-acquire failures. Every attempt starts
    # from an empty lock set.
    _MAX_ACQUIRE_ATTEMPTS = 7
    _ACQUIRE_BACKOFF_BASE_SEC = 0.5

    def __init__(
        self,
        cg,
        block_coords: Sequence,
        operation_id: np.uint64,
    ) -> None:
        self.cg = cg
        # Sort so every `__enter__` uses a consistent acquisition order
        # across workers — reduces contention between workers whose block
        # sets overlap. Sort is on the coord tuple (not the hashed row
        # key) so the order is stable and debuggable.
        self.block_coords = sorted(
            (int(bx), int(by), int(bz)) for bx, by, bz in block_coords
        )
        self.operation_id = np.uint64(operation_id)
        self.acquired_keys: list = []

    def __enter__(self):
        """Acquire every block lock or raise `LockingError`."""
        final_attempt = self._MAX_ACQUIRE_ATTEMPTS - 1
        for attempt in range(self._MAX_ACQUIRE_ATTEMPTS):
            if self._try_acquire_all():
                return self
            # No point backing off after the last attempt: nothing is
            # retried, so sleeping would only delay the error.
            if attempt < final_attempt:
                time.sleep(self._ACQUIRE_BACKOFF_BASE_SEC * (2**attempt))
        raise exceptions.LockingError(
            f"Could not acquire downsample block locks for coords "
            f"{self.block_coords} after {self._MAX_ACQUIRE_ATTEMPTS} attempts"
        )

    def __exit__(self, exception_type, exception_value, traceback):
        self._release_acquired()

    def _try_acquire_all(self) -> bool:
        """Make one pass over all blocks; hold everything or nothing.

        Returns True when every block was locked. On a refused lock, or
        when the client raises, the locks taken in this pass are released
        first — `__exit__` does not run if `__enter__` fails, so they
        would otherwise stay held.
        """
        self.acquired_keys = []
        try:
            for coord in self.block_coords:
                row_key = _downsample_block_lock_row_key(coord)
                if not self.cg.client.lock_by_row_key(row_key, self.operation_id):
                    self._release_acquired()
                    return False
                self.acquired_keys.append(row_key)
        except BaseException:
            self._release_acquired()
            raise
        return True

    def _release_acquired(self):
        """Unlock every held key in parallel; failures are logged, not raised."""
        if not self.acquired_keys:
            return
        max_workers = min(8, max(1, len(self.acquired_keys)))
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            futures = [
                executor.submit(
                    self.cg.client.unlock_by_row_key, key, self.operation_id
                )
                for key in self.acquired_keys
            ]
            for future in as_completed(futures):
                try:
                    future.result()
                except Exception as e:
                    logger.warning(f"Failed to unlock downsample block: {e}")
        self.acquired_keys = []

    def renew(self) -> bool:
        """Extend expiry on every held lock. Returns False if any failed."""
        ok = True
        for key in self.acquired_keys:
            if not self.cg.client.renew_lock_by_row_key(key, self.operation_id):
                logger.warning(f"Failed to renew downsample block lock {key!r}")
                ok = False
        return ok
+ self.acquired_keys = [] + + def renew(self) -> bool: + """Extend expiry on every held lock. Returns False if any failed.""" + ok = True + for key in self.acquired_keys: + if not self.cg.client.renew_lock_by_row_key(key, self.operation_id): + logger.warning(f"Failed to renew downsample block lock {key!r}") + ok = False + return ok + + +def _l2_chunk_lock_row_key(chunk_id) -> bytes: + """Row key for one L2 chunk's spatial lock cell. + + Hash-prefixed so spatially-clustered chunk IDs scatter across + bigtable tablets instead of piling up in one lexicographic range, + which would hot-spot a single tablet under concurrent load. + + 10 bytes total: + - 2-byte blake2b hash of the chunk_id (tablet distribution). + - 8 bytes of big-endian uint64 chunk_id. + chunk_id already encodes layer+xyz in its bits, so the full key is + unique per L2 chunk. + """ + packed = int(chunk_id).to_bytes(8, "big", signed=False) + return hashlib.blake2b(packed, digest_size=2).digest() + packed + + +class L2ChunkLock: + """Lock a set of L2 chunks to serialize SV splits that touch them. + + Closes the cross-root spatial race: two SV splits on overlapping L2 + chunks but distinct roots acquire disjoint root-lock sets and would + otherwise race on seg state. This lock is held across the + `split_supervoxel` loop (seg write + SV-level hierarchy row write) + so the pair commits atomically. + + All-or-nothing: on partial acquisition we release what we got and + retry with backoff; on repeated failure we raise `LockingError`. + + Uses `cg.client.lock_by_row_key` — the generic row-key lock in + kvdbclient — with a row-key namespace distinct from root and + downsample block locks (all three share `attributes.Concurrency.Lock` + under the hood; the row key disambiguates). + """ + + __slots__ = [ + "cg", + "chunk_ids", + "operation_id", + "privileged_mode", + "acquired_keys", + ] + + # Retry budget for partial-acquire failures. 
class L2ChunkLock:
    """Lock a set of L2 chunks to serialize SV splits that touch them.

    Closes the cross-root spatial race: two SV splits on overlapping L2
    chunks but distinct roots acquire disjoint root-lock sets and would
    otherwise race on seg state. This lock is held across the
    `split_supervoxel` loop (seg write + SV-level hierarchy row write)
    so the pair commits atomically.

    All-or-nothing: on partial acquisition we release what we got and
    retry with backoff; on repeated failure we raise `LockingError`. If the
    client raises while acquiring, the locks taken so far are released
    before the exception propagates.

    Uses `cg.client.lock_by_row_key_with_indefinite` for the acquire, with
    a row-key namespace (`_l2_chunk_lock_row_key`) distinct from root and
    downsample block locks.

    Args:
        cg: graph handle; only `cg.client` is used.
        chunk_ids: L2 chunk ids to lock. Stored sorted, as plain ints.
        operation_id: id the lock cells are written and released under.
        privileged_mode: keyword-only. When True, `__enter__` acquires
            nothing (replay path, see `__enter__`).
    """

    __slots__ = [
        "cg",
        "chunk_ids",
        "operation_id",
        "privileged_mode",
        "acquired_keys",
    ]

    # Retry budget for partial-acquire failures. Each attempt releases
    # anything it got in the previous pass, then re-acquires from scratch.
    _MAX_ACQUIRE_ATTEMPTS = 7
    _ACQUIRE_BACKOFF_BASE_SEC = 0.5

    def __init__(
        self,
        cg,
        chunk_ids: Sequence[int],
        operation_id: np.uint64,
        *,
        privileged_mode: bool = False,
    ) -> None:
        self.cg = cg
        # Sort so every `__enter__` uses a consistent acquisition order
        # across workers — reduces contention when overlapping lock sets
        # would otherwise race AB/BA.
        self.chunk_ids = sorted(int(c) for c in chunk_ids)
        self.operation_id = np.uint64(operation_id)
        self.privileged_mode = privileged_mode
        self.acquired_keys: list = []

    def __enter__(self):
        """Acquire all chunk locks or raise `LockingError`.

        Does nothing in dry-run mode or in privileged mode.
        """
        if is_dry_run():
            return self
        if self.privileged_mode:
            # Replay path: the crashed op's `IndefiniteL2ChunkLock` cells
            # are still set on these chunks (that's what's blocking new
            # ops), and `lock_by_row_key_with_indefinite` would refuse.
            # Mirror `RootLock`/`IndefiniteRootLock`'s privileged escape
            # hatch — skip temporal acquire, the indefinite cells are
            # our de-facto lock and they'll be released by the inner
            # `IndefiniteL2ChunkLock(privileged_mode=True)` on exit.
            return self
        final_attempt = self._MAX_ACQUIRE_ATTEMPTS - 1
        for attempt in range(self._MAX_ACQUIRE_ATTEMPTS):
            if self._try_acquire_all():
                return self
            self._release_acquired()
            # Back off only when another attempt follows; sleeping after
            # the final refusal would just delay the LockingError.
            if attempt < final_attempt:
                time.sleep(self._ACQUIRE_BACKOFF_BASE_SEC * (2**attempt))
        raise exceptions.LockingError(
            f"Could not acquire L2 chunk locks for chunks {self.chunk_ids} "
            f"after {self._MAX_ACQUIRE_ATTEMPTS} attempts"
        )

    def _try_acquire_all(self) -> bool:
        """Run one acquisition pass; return False at the first refusal.

        Keys taken so far stay in `acquired_keys` on a refusal so the
        caller can release them. If the client raises, they are released
        here, because `__exit__` is not invoked when `__enter__` fails.
        """
        self.acquired_keys = []
        try:
            for chunk_id in self.chunk_ids:
                row_key = _l2_chunk_lock_row_key(chunk_id)
                # `_with_indefinite`: the temporal acquire must also
                # refuse if the indefinite column is set. Closes the
                # crash-recovery race — a worker that died holding
                # `IndefiniteL2ChunkLock` leaves the indefinite cell
                # set, and the next op must see it rather than silently
                # racing into partial state.
                if not self.cg.client.lock_by_row_key_with_indefinite(
                    row_key, self.operation_id
                ):
                    return False
                self.acquired_keys.append(row_key)
        except BaseException:
            self._release_acquired()
            raise
        return True

    def __exit__(self, exception_type, exception_value, traceback):
        """Release every held lock; no-op in dry-run mode."""
        if is_dry_run():
            return
        self._release_acquired()

    def _release_acquired(self):
        """Unlock all keys in `acquired_keys` in parallel, then clear it.

        Best-effort: an unlock that raises is logged and skipped so the
        remaining keys are still released.
        """
        if not self.acquired_keys:
            return
        max_workers = min(8, len(self.acquired_keys))
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            futures = [
                executor.submit(
                    self.cg.client.unlock_by_row_key, key, self.operation_id
                )
                for key in self.acquired_keys
            ]
            for future in as_completed(futures):
                try:
                    future.result()
                except Exception as e:
                    logger.warning(f"Failed to unlock L2 chunk: {e}")
        self.acquired_keys = []

    def renew(self) -> bool:
        """Extend expiry on every held lock. Returns False if any failed."""
        ok = True
        for key in self.acquired_keys:
            # Keep going after a failure so the other locks still get renewed.
            if not self.cg.client.renew_lock_by_row_key(key, self.operation_id):
                logger.warning(f"Failed to renew L2 chunk lock {key!r}")
                ok = False
        return ok
class IndefiniteL2ChunkLock:
    """Upgrade held-temporal L2 chunk locks to indefinite.

    Structurally mirrors `IndefiniteRootLock`: acquired inside the
    temporal lock (`L2ChunkLock`) context after preconditions are
    established, and held across the write phase. Doesn't expire — the
    cell persists on bigtable until explicitly released (or operator
    recovery clears it), so a worker that dies with writes in flight
    leaves the chunks marked indefinitely-held.

    The temporal `L2ChunkLock` must already be held by the same
    `operation_id`; the acquire filter for temporal now rejects on
    indefinite cells, so future temporal acquires on these chunks
    refuse until this lock is released.

    Durable scope: `__enter__` writes `chunk_ids` to the op-log row's
    `OperationLogs.L2ChunkLockScope` column. This persists through a
    worker crash, giving `stuck_ops replay` the exact chunk set to
    clean up without a bigtable-wide lock-row scan.

    `privileged_mode=True` is the operator recovery escape hatch:
    skips the acquire step (the cells already exist, held by this same
    op_id from the crashed attempt), pre-populates `acquired_keys` from
    `chunk_ids` so `__exit__` still value-matches-releases those cells,
    and does not re-write the op-log scope column.

    Failure inside `__enter__` (a refused cell, a client error, or a
    failed scope write) releases every cell taken so far before the
    exception propagates: no seg/bigtable write has started at that
    point, and `__exit__` is not invoked when `__enter__` raises.
    """

    __slots__ = ["cg", "chunk_ids", "operation_id", "privileged_mode", "acquired_keys"]

    def __init__(
        self,
        cg,
        chunk_ids: Sequence[int],
        operation_id: np.uint64,
        *,
        privileged_mode: bool = False,
    ) -> None:
        self.cg = cg
        self.chunk_ids = sorted(int(c) for c in chunk_ids)
        self.operation_id = np.uint64(operation_id)
        self.privileged_mode = privileged_mode
        self.acquired_keys: list = []

    def __enter__(self):
        """Take the indefinite cell on every chunk, then record the scope.

        Raises:
            exceptions.LockingError: a chunk's indefinite cell was refused.
                No retry — an indefinite cell belongs to a currently-
                running or crashed op and won't clear on its own.
        """
        if is_dry_run():
            return self
        if self.privileged_mode:
            # Recovery path: crashed op's indefinite cells already exist
            # under this op_id. Populate acquired_keys so __exit__'s
            # value-matched release deletes them after the replay writes
            # succeed.
            self.acquired_keys = [_l2_chunk_lock_row_key(c) for c in self.chunk_ids]
            return self
        scope_write_started = False
        try:
            for chunk_id in self.chunk_ids:
                row_key = _l2_chunk_lock_row_key(chunk_id)
                if not self.cg.client.lock_by_row_key_indefinitely(
                    row_key, self.operation_id
                ):
                    raise exceptions.LockingError(
                        f"Could not upgrade L2 chunk {chunk_id} to indefinite lock "
                        f"(another op holds it)"
                    )
                self.acquired_keys.append(row_key)
            scope_write_started = True
            self._write_scope_to_op_log()
        except BaseException:
            # Undo exactly what a clean `__exit__` would undo. Without this
            # a failed `__enter__` strands non-expiring cells.
            self._release_acquired()
            if scope_write_started:
                # The failed write may still have landed; overwrite it with
                # the empty scope (best-effort, never raises).
                self._clear_scope_on_op_log()
            raise
        return self

    def __exit__(self, exception_type, exception_value, traceback):
        """Release cells and clear the scope, but only on a clean exit."""
        if is_dry_run():
            return
        if exception_type is not None:
            # Partial OCDBT seg / bigtable SV-hierarchy writes may have
            # landed before the exception propagated. Leave the
            # indefinite cells held and the op-log scope intact so
            # subsequent ops refuse at `L2ChunkLock` acquire — forces
            # operator recovery (`stuck_ops replay`) rather than
            # leaking orphan SV IDs into downstream reads.
            return
        self._release_acquired()
        self._clear_scope_on_op_log()

    def _write_scope_to_op_log(self):
        """Record the chunk scope on the op-log row before seg/bigtable
        writes begin. A worker crash after this point leaves both the
        per-chunk indefinite cells AND this field set, so recovery can
        locate the partial-write region without a bigtable scan.
        """
        row_key = serializers.serialize_uint64(self.operation_id)
        scope = np.asarray(self.chunk_ids, dtype=np.uint64)
        entry = self.cg.client.mutate_row(
            row_key,
            {attributes.OperationLogs.L2ChunkLockScope: scope},
        )
        self.cg.client.write([entry])

    def _clear_scope_on_op_log(self):
        """Clear the scope record on normal exit — op completed or was
        cleanly rolled back, so no partial state needs recovery. Overwrites
        with an empty array; a subsequent `read_log_entries` returns an
        empty scope (recovery skips). Best-effort; failures here are
        logged but don't propagate.
        """
        try:
            row_key = serializers.serialize_uint64(self.operation_id)
            empty = np.array([], dtype=np.uint64)
            entry = self.cg.client.mutate_row(
                row_key,
                {attributes.OperationLogs.L2ChunkLockScope: empty},
            )
            self.cg.client.write([entry])
        except Exception as e:
            logger.warning(f"Failed to clear L2ChunkLockScope on op-log row: {e}")

    def _release_acquired(self):
        """Release every indefinite cell in `acquired_keys` in parallel.

        Best-effort: a release that raises is logged and skipped so the
        remaining cells are still released. Clears `acquired_keys` after.
        """
        if not self.acquired_keys:
            return
        max_workers = min(8, len(self.acquired_keys))
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            futures = [
                executor.submit(
                    self.cg.client.unlock_indefinitely_locked_by_row_key,
                    key,
                    self.operation_id,
                )
                for key in self.acquired_keys
            ]
            for future in as_completed(futures):
                try:
                    future.result()
                except Exception as e:
                    logger.warning(f"Failed to unlock indefinite L2 chunk: {e}")
        self.acquired_keys = []
def _redis_cached_json(key: str, loader):
    """Return the JSON value cached at ``key`` in Redis, loading on a miss.

    On a miss, ``loader()`` is called and a non-``None`` result is written
    through to Redis as JSON. This spares distributed workers from
    re-fetching the same GCS object on every CG instantiation.

    Failure handling:
      - Redis unreachable (connect or ``get`` raises): bypass the cache
        and return ``loader()`` directly; nothing is written.
      - Cached value cannot be decoded as JSON: treated as a miss, so the
        freshly loaded value overwrites the bad entry instead of leaving
        it in place forever.
      - Write-through fails: ignored; the loaded value is still returned.

    Args:
        key: Redis key to read and write.
        loader: zero-argument callable producing a JSON-serialisable value,
            or ``None`` for "nothing to cache".

    Returns:
        The decoded cached value, or whatever ``loader()`` returned.
    """
    try:
        redis = get_redis_connection()
        cached = redis.get(key)
    except Exception:
        return loader()

    if cached is not None:
        try:
            return json.loads(cached)
        except (TypeError, ValueError):
            # json.JSONDecodeError and UnicodeDecodeError are ValueErrors.
            # Fall through so the entry gets replaced below.
            pass

    value = loader()
    if value is not None:
        try:
            redis.set(key, json.dumps(value))
        except Exception:
            # Best-effort cache fill; the caller already has its value.
            pass
    return value
@property
def ocdbt_config(self) -> OcdbtConfig:
    """Resolved per-CG OCDBT settings, computed once per instance.

    Two inputs are handed to ``OcdbtConfig.resolve``:

    * the per-CG dict from ``custom_data["ocdbt_config"]``. When that key
      is absent, it is rebuilt from the legacy ``custom_data["seg"]``
      shape (``ocdbt`` -> ``enabled``, ``sv_split_threshold`` with a
      default of 10) so pre-refactor CGs still open;
    * the watershed's populate metadata from ``read_populate_meta``,
      fetched through ``_redis_cached_json`` under
      ``ocdbt_info_cached:<watershed>``, or ``None`` when no watershed
      path is configured.

    How the two are merged is decided by ``OcdbtConfig.resolve``, which is
    not shown here; the file's own notes give the precedence as
    info-file > custom_data > dataclass defaults.
    """
    memo = self._ocdbt_config_cached
    if memo is not None:
        return memo

    custom = self._custom_data
    settings = custom.get("ocdbt_config")
    if settings is None:
        legacy = custom.get("seg", {})
        settings = {
            "enabled": bool(legacy.get("ocdbt", False)),
            "sv_split_threshold": int(legacy.get("sv_split_threshold", 10)),
        }

    ws = self._data_source.WATERSHED
    on_disk = (
        _redis_cached_json(
            f"ocdbt_info_cached:{ws}",
            lambda: read_populate_meta(ws),
        )
        if ws
        else None
    )

    resolved = OcdbtConfig.resolve(settings, on_disk)
    self._ocdbt_config_cached = resolved
    return resolved
+ ensure_fork_synced(ws, self.graph_id) _, self._ws_ocdbt_scales, self._ws_ocdbt_resolutions = ( get_seg_source_and_destination_ocdbt( - self.data_source.WATERSHED, self.graph_id + ws, self.graph_id, self.ocdbt_config ) ) return self._ws_ocdbt_scales @@ -272,7 +346,7 @@ def READ_ONLY(self): @property def sv_split_threshold(self) -> int: - return self._custom_data.get("seg", {}).get("sv_split_threshold", 10) + return self.ocdbt_config.sv_split_threshold @property def split_bounding_offset(self): @@ -296,12 +370,22 @@ def dataset_info(self) -> Dict: "n_layers": self.layer_count, "spatial_bit_masks": self.bitmasks, "ocdbt_seg": self.ocdbt_seg, - # Per-CG delta OCDBT path. Neuroglancer must open this - # via the kvstack spec from build_cg_ocdbt_spec() to see - # both base + delta data. Opening it as plain OCDBT only - # sees the delta. - "ocdbt_path": ( - f"ocdbt/{self.graph_id}" if self._graph_config.ID else None + # Full kvstore spec a reader hands to tensorstore's + # `neuroglancer_precomputed` driver. Server owns the + # contract — paths, data prefixes, and OCDBT config + # (e.g. `max_inline_value_bytes`) are all resolved + # here, so readers don't duplicate configuration and + # future schema changes are picked up on re-fetch. + # Readers pass this verbatim as `kvstore`; add a + # `version` field for time-travel reads. + "ocdbt_kvstore_spec": ( + build_cg_ocdbt_spec( + self._data_source.WATERSHED, + self.graph_id, + self.ocdbt_config, + ) + if self.ocdbt_seg and self._graph_config.ID + else None ), }, } diff --git a/pychunkedgraph/graph/ocdbt.py b/pychunkedgraph/graph/ocdbt.py deleted file mode 100644 index fb12cb7d0..000000000 --- a/pychunkedgraph/graph/ocdbt.py +++ /dev/null @@ -1,429 +0,0 @@ -"""OCDBT-backed neuroglancer_precomputed segmentation store. - -Architecture: one immutable base OCDBT per watershed + one delta OCDBT per -ChunkedGraph. Reads merge base + delta via tensorstore's kvstack driver. -Writes land in the delta via OCDBT's *_data_prefix options. 
- -Multi-scale (MIP pyramid) is supported: the source watershed's info JSON -drives the scale layout. All scales share one OCDBT kvstore; the precomputed -driver prefixes keys by scale key automatically. -""" - -import json - -import numpy as np -import tensorstore as ts - -from pychunkedgraph import get_logger - -logger = get_logger(__name__) - -OCDBT_SEG_COMPRESSION_LEVEL = 12 - -OCDBT_CONFIG = { - "compression": {"id": "zstd", "level": OCDBT_SEG_COMPRESSION_LEVEL}, - # Inline chunk values into B+tree leaves so they share the leaf's zstd - # compression context. Default (100 bytes) puts every chunk in its own - # out-of-line blob with independent zstd framing → ~7x bloat on GCS. - # 512 KiB captures every compressed_segmentation chunk we've measured. - "max_inline_value_bytes": 524288, -} - - -def _read_source_scales(ws_path): - """Read the source precomputed `info` JSON to get scale count and resolutions. - - The leading '/' in '/info' is required for GCS — without it the read - returns empty. 
- """ - kvs = ts.KvStore.open(ws_path).result() - info = json.loads(kvs.read("/info").result().value) - return info["scales"] - - -def _open_precomputed_scale(kvstore, scale_index, create=False, **schema_kw): - """Open one neuroglancer_precomputed scale on top of a kvstore spec.""" - spec = { - "driver": "neuroglancer_precomputed", - "kvstore": kvstore, - "scale_index": scale_index, - } - return ts.open(spec, create=create, **schema_kw).result() - - -def _schema_from_src(src_handle): - """Extract the schema kwargs needed to open a matching destination.""" - s = src_handle.schema - return dict( - rank=s.rank, - dtype=s.dtype, - codec=s.codec, - domain=s.domain, - shape=s.shape, - chunk_layout=s.chunk_layout, - dimension_units=s.dimension_units, - ) - - -# --------------------------------------------------------------------------- -# Base OCDBT (shared, immutable after ingest) -# --------------------------------------------------------------------------- - - -def _ensure_trailing_slash(path): - """Ensure kvstore paths end with / so they're treated as directories.""" - return path if path.endswith("/") else path + "/" - - -def _base_ocdbt_path(ws_path): - return _ensure_trailing_slash(f"{ws_path.rstrip('/')}/ocdbt/base") - - -def base_exists(ws_path: str) -> bool: - """Check if the base OCDBT has already been created for this watershed.""" - base = _base_ocdbt_path(ws_path) - kvs = ts.KvStore.open(base).result() - result = kvs.read("manifest.ocdbt").result() - return result.value is not None and len(result.value) > 0 - - -def create_base_ocdbt(ws_path: str): - """One-time bootstrap: create the shared base OCDBT at /ocdbt/base/. - - Wipes any existing base first, then opens each scale with create=True - so the info JSON is built from the source. Populating the base with - actual chunk data happens separately via copy_ws_chunk_multiscale - during the per-chunk ingest tasks. 
- - Returns (src_list, dst_list, resolutions) for the caller to use with - copy_ws_chunk_multiscale. - """ - base = _base_ocdbt_path(ws_path) - # Wipe existing base for a clean slate. - try: - kvs = ts.KvStore.open({"driver": "ocdbt", "base": base}).result() - kvs.delete_range(ts.KvStore.KeyRange()).result() - except Exception: - pass - - scales = _read_source_scales(ws_path) - resolutions = [s["resolution"] for s in scales] - base_kvstore = {"driver": "ocdbt", "base": base, "config": dict(OCDBT_CONFIG)} - - src_list, dst_list = [], [] - for i in range(len(scales)): - src_i = ts.open( - {"driver": "neuroglancer_precomputed", "kvstore": ws_path, "scale_index": i} - ).result() - dst_i = _open_precomputed_scale( - base_kvstore, i, create=True, **_schema_from_src(src_i) - ) - src_list.append(src_i) - dst_list.append(dst_i) - return src_list, dst_list, resolutions - - -def wipe_base_ocdbt(ws_path: str): - """Wipe the base OCDBT entirely (for --reset-ocdbt).""" - base = _base_ocdbt_path(ws_path) - try: - kvs = ts.KvStore.open({"driver": "ocdbt", "base": base}).result() - kvs.delete_range(ts.KvStore.KeyRange()).result() - except Exception: - pass - - -def open_base_ocdbt(ws_path: str): - """Open the existing base OCDBT (read/write) for populating during ingest. - - Used by per-chunk ingest tasks that copy precomputed data into the shared - base. NOT used at runtime — runtime always goes through the per-CG fork - spec via get_seg_source_and_destination_ocdbt. - - Returns (src_list, dst_list, resolutions). 
- """ - base = _base_ocdbt_path(ws_path) - scales = _read_source_scales(ws_path) - resolutions = [s["resolution"] for s in scales] - base_kvstore = {"driver": "ocdbt", "base": base, "config": dict(OCDBT_CONFIG)} - - src_list, dst_list = [], [] - for i in range(len(scales)): - src_i = ts.open( - {"driver": "neuroglancer_precomputed", "kvstore": ws_path, "scale_index": i} - ).result() - dst_i = _open_precomputed_scale(base_kvstore, i, **_schema_from_src(src_i)) - src_list.append(src_i) - dst_list.append(dst_i) - return src_list, dst_list, resolutions - - -# --------------------------------------------------------------------------- -# Per-CG delta (fork of the base) -# --------------------------------------------------------------------------- - - -def build_cg_ocdbt_spec(ws_path: str, graph_id: str) -> dict: - """Open-time kvstore spec for a CG's OCDBT, backed by a shared immutable base. - - The fork directory and its manifest are created automatically by - `fork_base_manifest` as part of CG creation — no manual setup. - - All three kvstack layers below AND all three `*_data_prefix` options - are load-bearing; removing any of them causes fork writes to leak - into the immutable base (verified empirically). - """ - base = _base_ocdbt_path(ws_path) - fork_dir = _ensure_trailing_slash(f"{ws_path.rstrip('/')}/ocdbt/{graph_id}") - data_prefix = f"{graph_id}_d/" - - # Catch-all. Lets the fork READ base's B+tree (manifest + d/ - # data files) via fall-through. Must be first so later layers can - # override sub-ranges. - base_layer = {"base": base} - - # Single-key override. Routes the fork's manifest file so new - # commits by this CG are visible only to this CG. Without this layer - # manifest writes silently clobber base's manifest. - fork_manifest_layer = { - "exact": "manifest.ocdbt", - "base": fork_dir + "manifest.ocdbt", - } - - # Catches OCDBT's new data-file writes for the fork. 
Pairs with the - # *_data_prefix options: OCDBT would otherwise write under the - # default `d/` prefix — no later layer claims `d/`, so kvstack - # falls through to the base catch-all and the writes corrupt base. - fork_data_layer = { - "prefix": data_prefix, - "base": _ensure_trailing_slash(fork_dir + data_prefix), - } - - return { - "driver": "ocdbt", - "base": { - "driver": "kvstack", - "layers": [base_layer, fork_manifest_layer, fork_data_layer], - }, - "config": dict(OCDBT_CONFIG), - # Steer every kind of OCDBT write under `_d/` so the - # fork_data_layer catches them. - "value_data_prefix": data_prefix, - "btree_node_data_prefix": data_prefix, - "version_tree_node_data_prefix": data_prefix, - } - - -def fork_base_manifest(ws_path: str, graph_id: str, wipe_existing: bool = False): - """Initialize a CG's delta directory by copying the base manifest. - - If wipe_existing=True, deletes the existing fork directory first (for - --retry when a prior ingest failed and left partial delta state). - """ - assert base_exists(ws_path), "base OCDBT must exist before forking" - base = _base_ocdbt_path(ws_path) - fork_dir = _ensure_trailing_slash(f"{ws_path.rstrip('/')}/ocdbt/{graph_id}") - - if wipe_existing: - try: - kvs = ts.KvStore.open(fork_dir).result() - kvs.delete_range(ts.KvStore.KeyRange()).result() - except Exception: - pass - - base_kvs = ts.KvStore.open(base).result() - fork_kvs = ts.KvStore.open(fork_dir).result() - manifest = base_kvs.read("manifest.ocdbt").result().value - fork_kvs.write("manifest.ocdbt", manifest).result() - - -def get_seg_source_and_destination_ocdbt(ws_path: str, graph_id: str) -> tuple: - """Open source watershed + CG's delta OCDBT destination (all scales). - - Always uses the fork-based kvstack spec. Requires the base to exist and - the fork's manifest to be present (set up at ingest time). - - Returns: - (src_list, dst_list, resolutions): per-scale TensorStore handles - and [x,y,z] resolutions. 
- """ - scales = _read_source_scales(ws_path) - resolutions = [s["resolution"] for s in scales] - cg_kvstore = build_cg_ocdbt_spec(ws_path, graph_id) - - src_list, dst_list = [], [] - for i in range(len(scales)): - src_i = ts.open( - {"driver": "neuroglancer_precomputed", "kvstore": ws_path, "scale_index": i} - ).result() - dst_i = _open_precomputed_scale(cg_kvstore, i, **_schema_from_src(src_i)) - src_list.append(src_i) - dst_list.append(dst_i) - return src_list, dst_list, resolutions - - -def copy_ws_chunk( - source, - destination, - chunk_size: tuple, - coords: list, - voxel_bounds: np.ndarray, -): - """Copy one chunk from source watershed to OCDBT destination at the same scale. - - Coordinates are interpreted at the source/destination's native scale — - callers must pre-scale them when copying coarser MIP levels. - """ - coords = np.array(coords, dtype=int) - chunk_size = np.array(chunk_size, dtype=int) - vx_start = coords * chunk_size + voxel_bounds[:, 0] - vx_end = vx_start + chunk_size - xE, yE, zE = voxel_bounds[:, 1] - - x0, y0, z0 = vx_start - x1, y1, z1 = vx_end - x1 = min(x1, xE) - y1 = min(y1, yE) - z1 = min(z1, zE) - - data = source[x0:x1, y0:y1, z0:z1].read().result() - destination[x0:x1, y0:y1, z0:z1].write(data).result() - - -def copy_ws_chunk_multiscale( - src_list, - dst_list, - resolutions, - chunk_size: tuple, - coords: list, - voxel_bounds: np.ndarray, -): - """Copy a base-resolution chunk's physical region across all MIP scales. - - The graph's chunk grid is defined at base resolution. For each coarser - scale we copy the SAME physical region — voxel coordinates are divided - by the cumulative downsample factor (derived from resolution ratios). - Source already has correct data at every scale, so this is a pure copy - with no recomputation. 
- """ - assert len(src_list) == len(dst_list) == len(resolutions) - coords = np.array(coords, dtype=int) - chunk_size_arr = np.array(chunk_size, dtype=int) - base_res = np.array(resolutions[0]) - - # Physical region at base resolution. - vx_start_base = coords * chunk_size_arr + voxel_bounds[:, 0] - vx_end_base = np.minimum(vx_start_base + chunk_size_arr, voxel_bounds[:, 1]) - - for i, (src, dst) in enumerate(zip(src_list, dst_list)): - # Cumulative factor from base to this scale (e.g. [2,2,1] per level). - factor = (np.array(resolutions[i]) / base_res).astype(int) - x0, y0, z0 = vx_start_base // factor - x1, y1, z1 = vx_end_base // factor - if x1 <= x0 or y1 <= y0 or z1 <= z0: - logger.debug(f"skipping empty region at scale {i}") - continue - data = src[x0:x1, y0:y1, z0:z1].read().result() - dst[x0:x1, y0:y1, z0:z1].write(data).result() - - -def _mode_downsample(data: np.ndarray, factors: tuple) -> np.ndarray: - """Mode downsample 4D segmentation array [X,Y,Z,C] by per-axis factors. - - Mode (most-frequent label) is the correct downsampling for segmentation: - it preserves exact label IDs (no interpolation) and biases toward the - dominant label in each block. - - Fast path for 2x2x1: uses a vectorized 4-element pairwise comparison. - Among 4 voxels {a,b,c,d}, if any value appears at least twice it is the - mode. Order of comparisons biases ties toward the top-left corner, which - is the standard convention for segmentation downsampling. - """ - fx, fy, fz = factors - X, Y, Z, C = data.shape - - # Pad with edge values so dimensions are divisible by the factor. - # Using 'edge' (not zeros) avoids introducing a phantom background label. - pad = [(0, (-X % fx) % fx), (0, (-Y % fy) % fy), (0, (-Z % fz) % fz), (0, 0)] - if any(p[1] > 0 for p in pad): - data = np.pad(data, pad, mode="edge") - X, Y, Z, C = data.shape - - if fx == 2 and fy == 2 and fz == 1: - # Fast vectorized path for the common 2x2x1 case. 
- reshaped = data.reshape(X // 2, 2, Y // 2, 2, Z, C) - a = reshaped[:, 0, :, 0] - b = reshaped[:, 0, :, 1] - c = reshaped[:, 1, :, 0] - d = reshaped[:, 1, :, 1] - return np.where( - (a == b) | (a == c) | (a == d), - a, - np.where((b == c) | (b == d), b, np.where(c == d, c, a)), - ) - - if fx == 2 and fy == 2 and fz == 2: - # 2x2x2 (8-element mode) — strided subsample is fast and label-safe - # for typical segmentation where adjacent voxels share labels. - return data[::2, ::2, ::2] - - # Generic factor: reshape into blocks, take strided first element. - # This is label-safe but loses the mode property; downsampling factor - # ratios in production are 2x2x1 or 2x2x2 so the fast paths cover them. - reshaped = data.reshape(X // fx, fx, Y // fy, fy, Z // fz, fz, C) - return reshaped[:, 0, :, 0, :, 0] - - -def propagate_to_coarser_scales(dst_scales, resolutions, base_slices): - """Cascade-downsample data from base scale through all coarser scales. - - Called after writing to the base scale (e.g. after an SV split). Each - coarser scale reads from the level below it (not from base directly), - so total downsample cost shrinks geometrically — each level processes - 1/N the data of the previous one. - - Args: - dst_scales: TensorStore handles, one per MIP level. - resolutions: [x,y,z] resolution arrays per scale, used to derive - per-axis downsample factors from consecutive resolution ratios. - base_slices: tuple of 3 slices (x, y, z) covering the region written - at base resolution. - """ - prev_slices = base_slices - for i in range(1, len(dst_scales)): - # Per-axis downsample factor from actual resolution ratio. - # Never hardcoded — different datasets may have different ratios. - factor = (np.array(resolutions[i]) / np.array(resolutions[i - 1])).astype(int) - - # Map prev-level slices to this level's coordinates. - # Ceil division on stop ensures we cover any partial block. 
- target_slices = tuple( - slice(s.start // f, -(-s.stop // f)) for s, f in zip(prev_slices, factor) - ) - - data = dst_scales[i - 1][prev_slices + (slice(None),)].read().result() - downsampled = _mode_downsample(data, tuple(int(f) for f in factor)) - dst_scales[i][target_slices + (slice(None),)].write(downsampled).result() - - prev_slices = target_slices - - -def write_seg(meta, bbs, bbe, data): - """Write segmentation at base scale and propagate to coarser scales. - - Single entry point for all SV-split-time segmentation writes. Builds - the tensorstore slices from the bounding box and adds the channel - dimension, so callers just pass the 3D bbox + 3D data. - - Args: - meta: ChunkedGraphMeta with ws_ocdbt_scales and ws_ocdbt_resolutions. - bbs: (3,) array — start of the region in base-resolution voxels. - bbe: (3,) array — end of the region in base-resolution voxels. - data: 3D numpy array of new segmentation IDs. - """ - slices = tuple(slice(int(s), int(e)) for s, e in zip(bbs, bbe)) - meta.ws_ocdbt[slices + (slice(None),)] = data[..., np.newaxis] - if len(meta.ws_ocdbt_scales) > 1: - propagate_to_coarser_scales( - meta.ws_ocdbt_scales, meta.ws_ocdbt_resolutions, slices - ) diff --git a/pychunkedgraph/graph/ocdbt/TENSORSTORE_REFERENCE.md b/pychunkedgraph/graph/ocdbt/TENSORSTORE_REFERENCE.md new file mode 100644 index 000000000..7faf4f140 --- /dev/null +++ b/pychunkedgraph/graph/ocdbt/TENSORSTORE_REFERENCE.md @@ -0,0 +1,134 @@ +# tensorstore OCDBT reference + +Every entry below was verified by probing tensorstore directly (intentional-bad-value + spec round-trip) against the binary in this workspace's venv. Re-verify if the tensorstore version changes. 
+ +## OCDBT kvstore spec — top-level fields + +Sibling of `driver: "ocdbt"`: + +| Field | Type | Default | Notes | +|---|---|---|---| +| `base` | kvstore spec or URL | — | underlying kvstore (gcs/file/s3/…) | +| `manifest` | kvstore spec or URL | (under `base`) | the manifest *can* live in a separate kvstore from data | +| `config` | object | `{}` | see Config sub-fields below | +| `assume_config` | bool | `false` | skip reading the existing config from the manifest (use with care) | +| `coordinator` | ocdbt_coordinator resource | named ref `"ocdbt_coordinator"` | enables distributed mode when set | +| `cache_pool` | cache_pool resource | named ref `"cache_pool"` | | +| `data_copy_concurrency` | data_copy_concurrency resource | named ref | | +| `target_data_file_size` | uint64 | driver default | when a single commit's d/ writes exceed this, the writer rolls a new d/ file | +| `experimental_read_coalescing_threshold_bytes` | uint64 | — | | +| `experimental_read_coalescing_merged_bytes` | uint64 | — | | +| `experimental_read_coalescing_interval` | uint64 | — | | +| `btree_node_data_prefix` | string | `"d/"` | path prefix for btree-node files | +| `value_data_prefix` | string | `"d/"` | path prefix for value files | +| `version_tree_node_data_prefix` | string | `"d/"` | path prefix for version-tree files | +| `path` | string | `""` | sub-prefix in the kvstore | + +**Not fields**: `data_file_prefixes`, `version_spec`, `recheck_cached*`, `transaction`, `btree_writer_concurrency`, `manifest_kind` (lives under `config`). + +## OCDBT `config` sub-fields + +| Field | Type | tensorstore default | Notes | +|---|---|---|---| +| `compression` | object | `{}` (none) | `{"id": "zstd", "level": N}` — zstd level 1–22 | +| `max_inline_value_bytes` | uint64 | `100` | values ≤ this size live inline in the btree leaf bytes; larger values get written to a d/ file and the mutation carries only an `IndirectDataReference`. 
In distributed mode this **directly bounds cooperator-forwarded RPC size**: inline values are carried inside the `WriteRequest.mutations` field, so a leaf's batch blows past the 4 MiB gRPC max-receive whenever multiple inline values pile up on one node. Source: `distributed/btree_writer.cc` `StagePending`. Setting low (≤ a few KB) pushes chunk values out-of-line → small mutations → small RPCs. | +| `max_decoded_node_bytes` | uint64 | `8388608` (8 MiB) | btree node split threshold. Larger nodes → shallower tree → fewer per-commit node touches. Setting this *smaller* than the default INCREASES per-commit forwarded bytes — empirically went from ~8 MiB to ~23 MiB RPCs when set to 1 MiB. | +| `version_tree_arity_log2` | int | — | controls version tree branching; rarely tuned | +| `manifest_kind` | enum | `"single"` | `"single"` or `"numbered"` (manifest history retained — needed for time-travel reads) | +| `uuid` | string | (auto) | 32-hex per-base UUID assigned at create time | + +**Not fields**: `data_file_prefixes`, `data_file_prefix`, `btree_node_arity_log2`, `version_tree_node_arity`. + +## `ocdbt_coordinator` context resource + +| Field | Type | Default | Notes | +|---|---|---|---| +| `address` | string | — | `"host:port"` of the DistributedCoordinatorServer | +| `lease_duration` | duration string (`"1s"`, `"500ms"`, etc.) | — | how long a lease holder owns a btree node | +| `security` | object | `{method: "insecure"}` | requires `method` key. This build has **no** security methods registered (build flag) — all calls cleartext. | + +## `DistributedCoordinatorServer({...})` + +| Field | Type | Default | Notes | +|---|---|---|---| +| `bind_addresses` | list[string] | one ephemeral port | gRPC server bind address(es). `.port` after construction gives the ephemeral port. 
| +| `security` | object | insecure | same shape as the resource's security | + +**There is NO Python knob for the gRPC server's max-receive message size.** The 4 MiB default is set inside tensorstore's gRPC server builder. Confirmed by strings on the binary: no `TENSORSTORE_*` env var, no spec/resource field, no Context resource that maps to `grpc.max_receive_message_length`. + +## Distributed vs non-distributed write paths + +The OCDBT driver picks one of two compiled implementations at open time: + +- **non-distributed** (`btree_writer.cc`): coordinator absent. Each commit writes the manifest itself. Concurrent writers race the manifest CAS; losers retry; their pre-commit d/ writes become orphans. +- **distributed** (`distributed/btree_writer.cc`, `cooperator_*.cc`): coordinator present. One lease holder per btree node serializes commits. Other cooperators **forward their mutations over gRPC** to the lease holder. + +### Constraints unique to distributed mode + +1. **`ts.Transaction(atomic=True)` is incompatible.** "Cannot read/write … as single atomic transaction" — verified on (info + chunk) and on (cross-key). A plain `ts.Transaction()` still batches all writes into one OCDBT commit; only the *atomicity* across keys is lost. +2. **Cooperator-forwarded RPC ≤ ~4 MiB.** Carries (btree node delta) + (value bytes for keys committed into that node). +3. **Disjoint user-key writes still trigger forwarding.** Leases are per btree node, not per user-key range. Two workers writing distinct keys into the same node → one forwards to the other. + +## Cooperator batching + +`cooperator_submit_mutation_batch.cc` `SendToPeer` is the gRPC sender. The `WriteRequest` proto has `repeated bytes mutations` — each entry is one encoded `BtreeNodeWriteMutation` destined for the same leaf. The encoded mutation embeds the value_reference inline if it's an `absl::Cord`, or carries just an `IndirectDataReference` (small struct) otherwise. 
So **what's actually on the wire per RPC = (small request header) + Σ encoded mutations**, and each encoded mutation's size is dominated by its value bytes IF the value is inline. + +Threshold for inline-vs-ref is `max_inline_value_bytes` (see config table). That's the real lever for RPC size. + +What changes RPC size (verified by production dumps): +- `max_inline_value_bytes=1 MiB`, default node bytes → RPCs 5–8 MiB (inline chunks pile up in the batch) +- `max_inline_value_bytes=1 MiB` + `max_decoded_node_bytes=1 MiB` → RPCs up to 23 MiB (smaller nodes ≠ smaller RPCs) +- `max_inline_value_bytes=1 MiB` + dst `chunk_size` halved → RPCs grew to 12 MiB (more mutations per node → bigger batches) +- `max_inline_value_bytes=4 KiB` (chunks go out-of-line) → mutations carry only refs; RPC = small header + N×(key + ref + generation) → fits 4 MiB regardless of value sizes (this is the path our code takes) + +## Defaults visible from spec round-trip + +```json +{ + "assume_config": false, + "btree_node_data_prefix": "d/", + "config": {}, + "coordinator": "ocdbt_coordinator", + "cache_pool": "cache_pool", + "data_copy_concurrency": "data_copy_concurrency", + "value_data_prefix": "d/", + "version_tree_node_data_prefix": "d/" +} +``` + +## Env vars + +- `OCDBT_COORDINATOR_HOST`, `OCDBT_COORDINATOR_PORT`: **NO EFFECT**. Not referenced anywhere in the binary. Address must go in spec's `coordinator.address`. +- `TENSORSTORE_VERBOSE_LOGGING`: comma-separated tag list to stderr. Tags include `ocdbt`, `coordinator`. + +Other `TENSORSTORE_*` vars exist (CA paths, S3/GCS concurrency, etc.) — grep the binary. + +## On-disk layout + +- `manifest.ocdbt` at the base — root btree node + current data file refs. +- `d/` — directory of "data files" each holding concatenated values + (optionally) btree node bytes + version-tree node bytes. +- Each commit creates **at least one** d/ file holding all values + nodes for that commit, then a CAS-update of `manifest.ocdbt`. 
+- `target_data_file_size` controls when a single commit splits its d/ writes across files. + +## How this maps onto pychunkedgraph + +- `OcdbtConfig` (`pychunkedgraph/graph/ocdbt/meta.py`) → `compression: zstd 12`, `max_inline_value_bytes = 4 KiB`. The 4 KiB threshold keeps small metadata (info JSON, populate markers) inline while forcing every chunk value out-of-line into d/ files — this is what keeps cooperator RPCs under the 4 MiB gRPC ceiling. +- `create_base_ocdbt` is the **only** path that embeds `config.ts_config()` in its kvstore spec — that write persists the values into `manifest.ocdbt`. Every open path (`open_base_ocdbt`, `build_cg_ocdbt_spec`) omits the `config` block: tensorstore would otherwise assert our in-code defaults against the on-disk manifest and raise `FAILED_PRECONDITION` on any drift, bricking every existing base. On-disk wins. +- `populate_chunk` (`pychunkedgraph/ingest/ocdbt.py`) opens the base with `coordinator_address` (distributed mode). +- `copy_ws_bbox_multiscale` uses **non-atomic** `ts.Transaction()` because of the distributed-mode constraint above. +- `dump_failure_to_gcs` (`pychunkedgraph/graph/ocdbt/debug.py`) writes JSON failure forensics when the `ERROR_DUMP` env var is set. + +## Empirically tried and ruled out + +- `OCDBT_COORDINATOR_HOST/PORT` env vars — no effect. +- Bumping gRPC max-receive via env / channel arg / spec field — no such knob. +- Smaller `dst chunk_size` alone — RPC size grew (more mutations per node). +- Smaller `max_decoded_node_bytes` alone — RPC size grew (more per-commit node touches). +- `--ocdbt-edges` legacy path — decommissioned, removed. +- `ts.Transaction(atomic=True)` with distributed coordinator — incompatible. + +## Open observations (not verified at production scale) + +- `lease_duration` may reduce cross-cooperator forwarding if held long enough that a worker's whole task lands on its own nodes. +- `target_data_file_size` may affect manifest growth but not RPC size.
+- Switching dst encoding from `compressed_segmentation` to `raw` would make per-value size predictable (`chunk_volume × bytes_per_voxel`), bypassing the dense-region pathological CS encoding (one observed key encoded to 23 MiB at 256×256×64). diff --git a/pychunkedgraph/graph/ocdbt/__init__.py b/pychunkedgraph/graph/ocdbt/__init__.py new file mode 100644 index 000000000..633964e1c --- /dev/null +++ b/pychunkedgraph/graph/ocdbt/__init__.py @@ -0,0 +1,57 @@ +"""Public API for the OCDBT-backed segmentation store. + +See ``main.py`` for the architectural notes. This module just re-exports +the names that external callers (ingest, edits, runtime, tests) reach for. +""" + +from .meta import OcdbtConfig +from .utils import ( + _layer_bbox, + _read_source_scales, + base_exists, + fork_exists, + is_chunk_populated, + mark_chunk_populated, + read_populate_meta, + write_populate_meta, +) +from .main import ( + _mode_downsample, + build_cg_ocdbt_spec, + copy_ws_bbox_multiscale, + copy_ws_chunk, + copy_ws_chunk_multiscale, + create_base_ocdbt, + ensure_fork_synced, + fork_base_manifest, + get_seg_source_and_destination_ocdbt, + open_base_ocdbt, + propagate_to_coarser_scales, + wipe_base_ocdbt, + write_seg_chunks, +) + +__all__ = [ + "OcdbtConfig", + "_layer_bbox", + "_mode_downsample", + "_read_source_scales", + "base_exists", + "build_cg_ocdbt_spec", + "copy_ws_bbox_multiscale", + "copy_ws_chunk", + "copy_ws_chunk_multiscale", + "create_base_ocdbt", + "ensure_fork_synced", + "fork_base_manifest", + "fork_exists", + "get_seg_source_and_destination_ocdbt", + "is_chunk_populated", + "mark_chunk_populated", + "open_base_ocdbt", + "propagate_to_coarser_scales", + "read_populate_meta", + "wipe_base_ocdbt", + "write_populate_meta", + "write_seg_chunks", +] diff --git a/pychunkedgraph/graph/ocdbt/debug.py b/pychunkedgraph/graph/ocdbt/debug.py new file mode 100644 index 000000000..b64376695 --- /dev/null +++ b/pychunkedgraph/graph/ocdbt/debug.py @@ -0,0 +1,143 @@ +"""Diagnostic 
def humanize_count(n: int) -> str:
    """Return a compact count for log lines: 1234567 → '1.2M', 950 → '950'.

    Values below 1000 (zero and negatives included) come back as plain
    ``str(n)``. Larger values are scaled to K / M / G with one decimal.
    A value that would round to ``1000.0`` of its unit is promoted to the
    next unit, so 999_999 renders as '1.0M' rather than '1000.0K'.
    """
    units = (("G", 1_000_000_000), ("M", 1_000_000), ("K", 1_000))
    for idx, (unit, scale) in enumerate(units):
        if n < scale:
            continue
        scaled = f"{n / scale:.1f}"
        # For K and M, n is below the next-larger unit here (that unit was
        # tested first), so "1000.0" can only be a rounding artefact.
        if idx > 0 and scaled == "1000.0":
            return f"1.0{units[idx - 1][0]}"
        return f"{scaled}{unit}"
    return str(n)


def failure_envelope(exc: BaseException, dump_tag: Optional[str]) -> dict:
    """Build the generic metadata shared by every failure dump.

    Covers timestamp, dump tag, host / pod identity (from the ``MY_POD_*``
    and ``MY_NODE_NAME`` env vars), tensorstore and Python versions, the
    ``OCDBT_COORDINATOR_*`` env vars, and the exception itself. The caller
    merges this with its failure-specific fields to build the final payload.

    Args:
        exc: the exception being reported.
        dump_tag: caller-supplied identifier stored verbatim; may be None.

    Returns:
        A JSON-serialisable dict.
    """
    return {
        "timestamp_utc": datetime.now(timezone.utc).isoformat(),
        "dump_tag": dump_tag,
        "host": {
            "hostname": socket.gethostname(),
            "pid": os.getpid(),
            "pod_name": environ.get("MY_POD_NAME"),
            "pod_ip": environ.get("MY_POD_IP"),
            "node_name": environ.get("MY_NODE_NAME"),
        },
        "versions": {
            "tensorstore": getattr(ts, "__version__", None),
            "python": sys.version,
        },
        "ocdbt_coordinator_env": {
            "OCDBT_COORDINATOR_HOST": environ.get("OCDBT_COORDINATOR_HOST"),
            "OCDBT_COORDINATOR_PORT": environ.get("OCDBT_COORDINATOR_PORT"),
        },
        "exception": {
            "type": type(exc).__name__,
            "module": type(exc).__module__,
            "message": str(exc),
            # Format the traceback of `exc` itself. `traceback.format_exc()`
            # reports whatever exception is *currently being handled*, which
            # is only `exc` when we happen to be called inside its `except`.
            "traceback": "".join(
                traceback.format_exception(type(exc), exc, exc.__traceback__)
            ),
        },
    }


def _kvstore_spec_or_error(handle):
    """Return ``handle``'s kvstore spec as JSON, or a string naming the error."""
    try:
        return handle.kvstore.spec().to_json()
    except Exception as e:
        # Keep the reason in the dump instead of dropping it; a malformed
        # handle must not shadow the original exception.
        return f"unavailable: {e!r}"


def bbox_failure_payload(
    exc: BaseException,
    dump_tag: Optional[str],
    bbox_lo,
    bbox_hi,
    resolutions,
    per_scale,
    dst_handle,
    src_handle,
) -> dict:
    """Build the full diagnostic dict for a ``copy_ws_bbox_multiscale``
    commit failure.

    Merges the generic ``failure_envelope`` metadata with bbox-specific
    fields: the bbox corners, resolutions, per-scale statistics and their
    totals, and the src / dst kvstore specs.

    Args:
        exc: the commit exception.
        dump_tag: caller-supplied identifier, forwarded to the envelope.
        bbox_lo, bbox_hi: bbox corners; each component is cast to ``int``.
        resolutions: per-scale resolution triples; cast to lists of ``int``.
        per_scale: rows of ``(scale_index, dims, voxels, raw_bytes,
            chunk_shape, n_keys, max_raw_per_key_bytes)``.
        dst_handle, src_handle: objects exposing ``.kvstore.spec().to_json()``.
            If that call raises, the spec field holds a string describing
            the error instead.

    Returns:
        A dict suitable for ``dump_failure_to_gcs``.
    """
    dst_spec = _kvstore_spec_or_error(dst_handle)
    src_spec = _kvstore_spec_or_error(src_handle)
    total_voxels = sum(p[2] for p in per_scale)
    total_raw = sum(p[3] for p in per_scale)
    total_keys = sum(p[5] for p in per_scale)
    return {
        **failure_envelope(exc, dump_tag),
        "bbox_lo": [int(c) for c in bbox_lo],
        "bbox_hi": [int(c) for c in bbox_hi],
        "resolutions": [list(map(int, r)) for r in resolutions],
        "n_scales": len(per_scale),
        "total_voxels": total_voxels,
        "total_raw_bytes": total_raw,
        "total_keys": total_keys,
        "per_scale": [
            {
                "scale_index": i,
                "dims": list(dims),
                "voxels": nvox,
                "raw_bytes": raw_bytes,
                "chunk_shape": list(chunk_shape),
                "n_keys": n_keys,
                "max_raw_per_key_bytes": max_per_key,
            }
            for i, dims, nvox, raw_bytes, chunk_shape, n_keys, max_per_key in per_scale
        ],
        "dst_kvstore_spec": dst_spec,
        "src_kvstore_spec": src_spec,
    }


def dump_failure_to_gcs(payload: dict, dump_tag: Optional[str]) -> Optional[str]:
    """Write a per-task failure report to ``$ERROR_DUMP/{dump_tag}__{utc}.json``.

    ``dump_tag`` carries the calling-context identifier (graph id, layer,
    coords, …) so multiple experiments can share one ``ERROR_DUMP`` root
    without collisions.

    Args:
        payload: dict to serialise as indented JSON.
        dump_tag: file-name prefix; when None or empty nothing is written.

    Returns:
        The full path written, or None when ``ERROR_DUMP`` is unset/blank,
        ``dump_tag`` is empty, or the write failed (the failure is logged
        as a warning, never raised).
    """
    root = environ.get("ERROR_DUMP", "").strip()
    if not root or not dump_tag:
        return None
    if not root.endswith("/"):
        root += "/"
    utc = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%S.%fZ")
    rel = f"{dump_tag}__{utc}.json"
    full = root + rel
    try:
        # default=repr is deliberate: a forensic dump should not be lost
        # because one field is not a JSON-native type.
        body = json.dumps(payload, indent=2, default=repr).encode("utf-8")
        ts.KvStore.open(root).result().write(rel, body).result()
        return full
    except Exception as e:
        _logger.warning("failed to write ERROR_DUMP at %s: %r", full, e)
        return None
def create_base_ocdbt(ws_path: str, config: OcdbtConfig):
    """One-time bootstrap: create the shared base OCDBT under ``ws_path``.

    Wipes any existing base first, then opens each scale with
    ``create=True`` so the info JSON is built from the source. Populating
    the base with actual chunk data happens separately via
    ``copy_ws_chunk_multiscale`` or ``copy_ws_bbox_multiscale`` during the
    per-chunk ingest tasks.

    Args:
        ws_path: path of the source watershed (neuroglancer_precomputed).
        config: OCDBT settings; ``config.ts_config()`` is embedded in the
            kvstore spec used to create the base.

    Returns:
        ``(src_list, dst_list, resolutions)`` — per-scale source handles,
        per-scale destination handles, and the per-scale ``resolution``
        entries read from the source, for use with the copy helpers.
    """
    # Wipe via the underlying GCS/file driver, NOT through the ocdbt
    # driver. Opening as ocdbt on an empty dir creates a default-config
    # `manifest.ocdbt` stub (max_inline_value_bytes=100); on a dir with
    # an existing manifest it only clears the B+tree, leaving the
    # manifest's config in place. Either way the subsequent open with a
    # different config mismatches.
    wipe_base_ocdbt(ws_path)

    base = _base_ocdbt_path(ws_path)
    scales = _read_source_scales(ws_path)
    resolutions = [s["resolution"] for s in scales]
    base_kvstore = {"driver": "ocdbt", "base": base, "config": config.ts_config()}

    src_list, dst_list = [], []
    for i in range(len(scales)):
        src_i = ts.open(
            {"driver": "neuroglancer_precomputed", "kvstore": ws_path, "scale_index": i}
        ).result()
        dst_i = _open_precomputed_scale(
            base_kvstore, i, create=True, **_schema_from_src(src_i)
        )
        src_list.append(src_i)
        dst_list.append(dst_i)
    return src_list, dst_list, resolutions


def wipe_base_ocdbt(ws_path: str):
    """Wipe the base OCDBT entirely (for --reset-ocdbt).

    Best-effort: a failure to open or clear the directory is logged and
    swallowed rather than raised, so callers can proceed to (re)create
    the base.
    """
    base = _base_ocdbt_path(ws_path)
    # Wipe via the underlying GCS/file driver so the manifest file is
    # deleted too. Opening as ocdbt only clears the B+tree.
    try:
        kvs = ts.KvStore.open(base).result()
        kvs.delete_range(ts.KvStore.KeyRange()).result()
    except Exception as exc:
        # Still best-effort, but no longer silent: a stale manifest left
        # behind here shows up later as a config-mismatch error on create.
        logger.warning(f"could not wipe base OCDBT at {base}: {exc!r}")
def build_cg_ocdbt_spec(
    ws_path: str,
    graph_id: str,
    config: OcdbtConfig,
    *,
    pinned_at: "int | str | None" = None,
) -> dict:
    """Return the open-time kvstore spec for one CG's OCDBT fork.

    Pure spec construction: nothing is opened or written here. The fork's
    ``manifest.ocdbt`` has to exist already (``fork_base_manifest`` writes
    it) before ``ts.open`` on the returned spec can succeed.

    The spec is an ``ocdbt`` driver over a three-layer ``kvstack``. Per the
    original author's notes, every layer and every ``*_data_prefix`` option
    is load-bearing: dropping any one lets fork writes leak into the shared
    base.

    ``config`` is accepted but intentionally not embedded: tensorstore
    validates a ``config`` block against the on-disk manifest and raises
    FAILED_PRECONDITION on mismatch, so the persisted values win.

    Args:
        ws_path: watershed path the base and fork live under.
        graph_id: CG identifier; names the fork directory and data prefix.
        config: unused, see above.
        pinned_at: when not None it is stored under ``"version"``. Per the
            module notes this yields a read-only open at an integer
            generation number or an ISO-8601 UTC timestamp with ``Z``
            suffix (``commit_time <= T``).

    Returns:
        The kvstore spec dict.
    """
    base_root = _base_ocdbt_path(ws_path)
    fork_root = _ensure_trailing_slash(f"{ws_path.rstrip('/')}/ocdbt/{graph_id}")
    fork_prefix = f"{graph_id}_d/"

    layers = [
        # Catch-all first, so later layers can override sub-ranges. It is
        # what lets the fork READ the base's manifest and d/ files.
        {"base": base_root},
        # Exact-key override for the manifest, so this CG's commits are
        # visible only to this CG and never clobber the base manifest.
        {"exact": "manifest.ocdbt", "base": fork_root + "manifest.ocdbt"},
        # Receives the fork's new data files. Without it (and the prefix
        # options below) writes would go under `d/`, fall through to the
        # catch-all, and corrupt the base.
        {
            "prefix": fork_prefix,
            "base": _ensure_trailing_slash(fork_root + fork_prefix),
        },
    ]

    spec = {
        "driver": "ocdbt",
        "base": {"driver": "kvstack", "layers": layers},
        # Route every kind of OCDBT write under the fork prefix so the
        # third layer above catches it.
        "value_data_prefix": fork_prefix,
        "btree_node_data_prefix": fork_prefix,
        "version_tree_node_data_prefix": fork_prefix,
    }
    if pinned_at is None:
        return spec
    spec["version"] = pinned_at
    return spec


def fork_base_manifest(ws_path: str, graph_id: str, wipe_existing: bool = False):
    """Seed a CG's fork directory with a copy of the base manifest.

    Args:
        ws_path: watershed path the base and fork live under.
        graph_id: CG identifier naming the fork directory.
        wipe_existing: when True, first delete everything already in the
            fork directory (best-effort; errors are ignored). Meant for
            --retry after a failed ingest left partial delta state.

    Raises:
        AssertionError: if the base OCDBT does not exist.
    """
    assert base_exists(ws_path), "base OCDBT must exist before forking"
    manifest_key = "manifest.ocdbt"
    base_root = _base_ocdbt_path(ws_path)
    fork_root = _ensure_trailing_slash(f"{ws_path.rstrip('/')}/ocdbt/{graph_id}")

    if wipe_existing:
        try:
            stale = ts.KvStore.open(fork_root).result()
            stale.delete_range(ts.KvStore.KeyRange()).result()
        except Exception:
            pass

    source_store = ts.KvStore.open(base_root).result()
    target_store = ts.KvStore.open(fork_root).result()
    manifest_bytes = source_store.read(manifest_key).result().value
    target_store.write(manifest_key, manifest_bytes).result()
+ Edit files are stable, so listing the prefix is a sufficient + short-circuit and skips reading both manifests on every runtime + open. + + Returns True iff the fork manifest was refreshed. + """ + if is_dry_run(): + return False + if not fork_exists(ws_path, graph_id): + return False + fork_dir = _ensure_trailing_slash(f"{ws_path.rstrip('/')}/ocdbt/{graph_id}") + fork_kvs = ts.KvStore.open(fork_dir).result() + data_prefix = f"{graph_id}_d/" + edit_files = fork_kvs.list( + ts.KvStore.KeyRange(data_prefix, data_prefix[:-1] + chr(ord("/") + 1)) + ).result() + if len(edit_files) > 0: + # Steady state — fork has progressed forward by design. + return False + base = _base_ocdbt_path(ws_path) + base_kvs = ts.KvStore.open(base).result() + base_manifest = base_kvs.read("manifest.ocdbt").result().value + fork_manifest = fork_kvs.read("manifest.ocdbt").result().value + if base_manifest == fork_manifest: + return False + fork_kvs.write("manifest.ocdbt", base_manifest).result() + logger.note(f"refreshed fork manifest at {fork_dir} from base (no edits)") + return True + + +def get_seg_source_and_destination_ocdbt( + ws_path: str, + graph_id: str, + config: OcdbtConfig, + *, + pinned_at: "int | str | None" = None, +) -> tuple: + """Open source watershed + CG's delta OCDBT destination (all scales). + + Always uses the fork-based kvstack spec. Requires the base to exist and + the fork's manifest to be present (set up at ingest time). + + When ``pinned_at`` is set, the destination OCDBT handles are opened + read-only at that version — used by the recovery path to read + pre-op seg values via ``ChunkedGraphMeta.pinned_seg_reads``. + + Returns: + (src_list, dst_list, resolutions): per-scale TensorStore handles + and [x,y,z] resolutions. 
def copy_ws_chunk(
    source,
    destination,
    chunk_size: tuple,
    coords: list,
    voxel_bounds: np.ndarray,
):
    """Copy a single chunk from ``source`` to ``destination`` at one scale.

    The chunk's voxel window starts at
    ``coords * chunk_size + voxel_bounds[:, 0]`` and spans ``chunk_size``
    voxels per axis, with each upper edge clamped to
    ``voxel_bounds[:, 1]``. Coordinates are taken at the handles' native
    scale, so callers pre-scale them when copying coarser MIP levels.

    Args:
        source: handle supporting ``source[x, y, z].read().result()``.
        destination: handle supporting ``destination[x, y, z].write(...)``.
        chunk_size: per-axis chunk extent in voxels.
        coords: chunk grid coordinates.
        voxel_bounds: array whose column 0 holds per-axis lower bounds and
            column 1 per-axis upper bounds.
    """
    extent = np.array(chunk_size, dtype=int)
    lower = np.array(coords, dtype=int) * extent + voxel_bounds[:, 0]
    upper = lower + extent

    # Clamp each axis to the volume's upper bound; edge chunks are partial.
    x_span, y_span, z_span = (
        slice(lo, min(hi, cap))
        for lo, hi, cap in zip(lower, upper, voxel_bounds[:, 1])
    )

    block = source[x_span, y_span, z_span].read().result()
    destination[x_span, y_span, z_span].write(block).result()
+ """ + assert len(src_list) == len(dst_list) == len(resolutions) + coords = np.array(coords, dtype=int) + chunk_size_arr = np.array(chunk_size, dtype=int) + base_res = np.array(resolutions[0]) + + # Physical region at base resolution. + vx_start_base = coords * chunk_size_arr + voxel_bounds[:, 0] + vx_end_base = np.minimum(vx_start_base + chunk_size_arr, voxel_bounds[:, 1]) + + for i, (src, dst) in enumerate(zip(src_list, dst_list)): + # Cumulative factor from base to this scale (e.g. [2,2,1] per level). + factor = (np.array(resolutions[i]) / base_res).astype(int) + x0, y0, z0 = vx_start_base // factor + x1, y1, z1 = vx_end_base // factor + if x1 <= x0 or y1 <= y0 or z1 <= z0: + logger.debug(f"skipping empty region at scale {i}") + continue + data = src[x0:x1, y0:y1, z0:z1].read().result() + dst[x0:x1, y0:y1, z0:z1].write(data).result() + + +def copy_ws_bbox_multiscale( + src_list, + dst_list, + resolutions, + bbox_lo: np.ndarray, + bbox_hi: np.ndarray, + dump_tag: str | None = None, +): + """Copy a base-resolution voxel bbox across all MIP scales under one + transaction so the whole multi-scale write lands as a single OCDBT commit. + + The transaction (not ``atomic=True``) is what's load-bearing: it batches + every per-chunk underlying-kvstore write across every scale into one + commit, so the d/ file count for one call is constant in bbox size and + grows only with scale count. ``atomic=True`` would add cross-key + isolation but is rejected by tensorstore's distributed-OCDBT path — + when the kvstore is opened with a ``coordinator``, atomic transactions + cannot span multiple keys (verified empirically). Non-atomic still + batches; the coordinator handles concurrency by serializing the commit + on the wire. + + Passing the source TensorStore directly into ``write(...)`` lets + tensorstore stream the copy without materializing an intermediate + numpy array in Python — peak RSS drops by roughly one scale's + worth versus the read-into-numpy-then-write pattern. 
+ """ + assert len(src_list) == len(dst_list) == len(resolutions) + dump_enabled = bool(environ.get("ERROR_DUMP")) + base_res = np.array(resolutions[0]) + txn = ts.Transaction() + # per_scale rows are only populated when dump_enabled, so the failure + # path has enough context for the structured GCS dump without paying any + # bookkeeping cost on the happy path. + per_scale: list = [] + for i, (src, dst) in enumerate(zip(src_list, dst_list)): + factor = (np.array(resolutions[i]) / base_res).astype(int) + x0, y0, z0 = bbox_lo // factor + x1, y1, z1 = bbox_hi // factor + if x1 <= x0 or y1 <= y0 or z1 <= z0: + continue + if dump_enabled: + dims = (int(x1 - x0), int(y1 - y0), int(z1 - z0)) + nvox = dims[0] * dims[1] * dims[2] + bpv = int(np.dtype(dst.dtype.numpy_dtype).itemsize) + # The precomputed driver's read_chunk shape includes a channel + # axis; the spatial chunk shape is the first three dims. + chunk_shape = tuple(int(s) for s in dst.chunk_layout.read_chunk.shape[:3]) + n_keys = int( + np.prod( + [int(np.ceil(d / c)) if c else 0 for d, c in zip(dims, chunk_shape)] + ) + ) + max_raw_per_key = int(np.prod(chunk_shape)) * bpv + per_scale.append( + (i, dims, nvox, nvox * bpv, chunk_shape, n_keys, max_raw_per_key) + ) + dst.with_transaction(txn)[x0:x1, y0:y1, z0:z1].write( + src[x0:x1, y0:y1, z0:z1] + ).result() + try: + txn.commit_async().result() + except Exception as exc: + if dump_enabled: + payload = bbox_failure_payload( + exc, + dump_tag, + bbox_lo, + bbox_hi, + resolutions, + per_scale, + dst_list[0], + src_list[0], + ) + path = dump_failure_to_gcs(payload, dump_tag) + if path: + logger.note(f"OCDBT commit failure dump → {path}") + raise + + +def _mode_downsample(data: np.ndarray, factors: tuple) -> np.ndarray: + """Mode downsample 4D segmentation array [X,Y,Z,C] by per-axis factors. 
+ + Mode (most-frequent label) is the correct downsampling for segmentation: + it preserves exact label IDs (no interpolation) and biases toward the + dominant label in each block. + + Fast path for 2x2x1: uses a vectorized 4-element pairwise comparison. + Among 4 voxels {a,b,c,d}, if any value appears at least twice it is the + mode. Order of comparisons biases ties toward the top-left corner, which + is the standard convention for segmentation downsampling. + """ + fx, fy, fz = factors + X, Y, Z, C = data.shape + + # Pad with edge values so dimensions are divisible by the factor. + # Using 'edge' (not zeros) avoids introducing a phantom background label. + pad = [(0, (-X % fx) % fx), (0, (-Y % fy) % fy), (0, (-Z % fz) % fz), (0, 0)] + if any(p[1] > 0 for p in pad): + data = np.pad(data, pad, mode="edge") + X, Y, Z, C = data.shape + + if fx == 2 and fy == 2 and fz == 1: + # Fast vectorized path for the common 2x2x1 case. + reshaped = data.reshape(X // 2, 2, Y // 2, 2, Z, C) + a = reshaped[:, 0, :, 0] + b = reshaped[:, 0, :, 1] + c = reshaped[:, 1, :, 0] + d = reshaped[:, 1, :, 1] + return np.where( + (a == b) | (a == c) | (a == d), + a, + np.where((b == c) | (b == d), b, np.where(c == d, c, a)), + ) + + if fx == 2 and fy == 2 and fz == 2: + # 2x2x2 (8-element mode) — strided subsample is fast and label-safe + # for typical segmentation where adjacent voxels share labels. + return data[::2, ::2, ::2] + + # Generic factor: reshape into blocks, take strided first element. + # This is label-safe but loses the mode property; downsampling factor + # ratios in production are 2x2x1 or 2x2x2 so the fast paths cover them. + reshaped = data.reshape(X // fx, fx, Y // fy, fy, Z // fz, fz, C) + return reshaped[:, 0, :, 0, :, 0] + + +def propagate_to_coarser_scales(dst_scales, resolutions, base_slices): + """Cascade-downsample data from base scale through all coarser scales. + + Called after writing to the base scale (e.g. after an SV split). 
Each + coarser scale reads from the level below it (not from base directly), + so total downsample cost shrinks geometrically — each level processes + 1/N the data of the previous one. + + Args: + dst_scales: TensorStore handles, one per MIP level. + resolutions: [x,y,z] resolution arrays per scale, used to derive + per-axis downsample factors from consecutive resolution ratios. + base_slices: tuple of 3 slices (x, y, z) covering the region written + at base resolution. + """ + prev_slices = base_slices + for i in range(1, len(dst_scales)): + # Per-axis downsample factor from actual resolution ratio. + # Never hardcoded — different datasets may have different ratios. + factor = (np.array(resolutions[i]) / np.array(resolutions[i - 1])).astype(int) + + # Map prev-level slices to this level's coordinates. + # Ceil division on stop ensures we cover any partial block. + target_slices = tuple( + slice(s.start // f, -(-s.stop // f)) for s, f in zip(prev_slices, factor) + ) + + data = dst_scales[i - 1][prev_slices + (slice(None),)].read().result() + downsampled = _mode_downsample(data, tuple(int(f) for f in factor)) + dst_scales[i][target_slices + (slice(None),)].write(downsampled).result() + + prev_slices = target_slices + + +def write_seg_chunks(meta, seg_writes): + """Write a flat batch of pre-sliced L2 chunks to OCDBT in parallel. + + ``seg_writes`` is the aggregated output of ``sv_split.edits.split_supervoxels`` + across every rep in an operation — each pair is one L2 chunk's worth + of ``(voxel_slices, data)``. Flattening across reps matters: one + ``write_seg_chunks`` call fires every chunk write in one parallel + tensorstore batch instead of serializing rep-by-rep. + + Only chunks that actually received new SV IDs appear here; gap + chunks between cross-chunk-connected pieces and neighbor chunks the + overlap read touched are skipped by the split planner. 
+ + Coarser MIP levels stay the downsample worker's job — it picks up + the pubsub message ``publish_edit`` sends after this returns. + + Args: + meta: ChunkedGraphMeta with ``ws_ocdbt`` (base-scale handle). + seg_writes: iterable of ``(voxel_slices, data)`` pairs, where + ``voxel_slices`` is a 3-tuple of ``slice`` objects covering one + L2 chunk's x/y/z extent and ``data`` is the 3D label block + (shape matches the slice extents). + """ + if is_dry_run(): + return + futures = [ + meta.ws_ocdbt[voxel_slices + (slice(None),)].write(data[..., np.newaxis]) + for voxel_slices, data in seg_writes + ] + for f in futures: + f.result() diff --git a/pychunkedgraph/graph/ocdbt/meta.py b/pychunkedgraph/graph/ocdbt/meta.py new file mode 100644 index 000000000..df18b4dca --- /dev/null +++ b/pychunkedgraph/graph/ocdbt/meta.py @@ -0,0 +1,78 @@ +"""OcdbtConfig dataclass — single source of truth for per-CG OCDBT settings.""" + +from dataclasses import asdict, dataclass, field +from typing import Dict, Optional + + +@dataclass +class OcdbtConfig: + """Per-CG OCDBT settings, persisted in ``ChunkedGraphMeta.custom_data["ocdbt_config"]``. + + Carries both ingest-time choices (populate base? at which layer?) and + tensorstore kvstore options (compression, inline byte cap) that must + stay consistent for the lifetime of the OCDBT base. Built once from + the dataset yaml's ``ocdbt_config:`` section and stored alongside the + CG so every code path that opens an OCDBT store reads back the same + values. + """ + + enabled: bool = False + populate_base: bool = False + populate_layer: int = 3 + sv_split_threshold: int = 10 + compression: Dict = field(default_factory=lambda: {"id": "zstd", "level": 12}) + # Inline-vs-out-of-line threshold. Values ≤ this size live in the btree + # leaf bytes; larger values get written to a d/ file and the mutation + # carries only an IndirectDataReference. 
This directly determines + # cooperator-forwarded RPC size in distributed mode: inline values are + # carried inside the gRPC WriteRequest's `mutations` field, so a leaf's + # batch can blow past tensorstore's hardcoded 4 MiB gRPC max-receive + # whenever multiple inline values pile up on the same node. Verified + # by reading btree_writer.cc StagePending in v0.1.81. + # + # 4 KiB keeps small metadata (info JSON ~1.5 KB, populate-marker files) + # inline while forcing every segmentation chunk value out-of-line — + # chunks compress to 100s of KB even for the smallest scales. With + # chunk bytes out-of-line the WriteRequest stays tiny regardless of + # how many keys a worker commits at once. Tradeoff vs the previous + # 1 MiB cap: each chunk now has its own zstd-framed d/ blob instead of + # sharing a leaf's compression context, which can cost a few percent + # of compression ratio (much less than the originally-feared "7× + # bloat", which only applied at the 100-byte default). + max_inline_value_bytes: int = 4096 + + @classmethod + def from_dict(cls, d: Optional[Dict]) -> "OcdbtConfig": + """Build from a dict. Unknown keys are ignored so older configs don't + break newer code, and newer fields default in when older configs are + loaded. + """ + if not d: + return cls() + known = {f.name for f in cls.__dataclass_fields__.values()} + return cls(**{k: v for k, v in d.items() if k in known}) + + @classmethod + def resolve(cls, *dicts: Optional[Dict]) -> "OcdbtConfig": + """Layered merge: later dicts override earlier ones, all over defaults. + + Use to express precedence — e.g. ``resolve(yaml_dict, info_file_dict)`` + gives info-file values priority over yaml-supplied ones, with + dataclass defaults filling anything neither side specifies. + ``None`` and empty dicts are no-ops. 
+ """ + merged: Dict = {} + for d in dicts: + if d: + merged.update(d) + return cls.from_dict(merged) + + def to_dict(self) -> Dict: + return asdict(self) + + def ts_config(self) -> Dict: + """The subset that belongs inside a tensorstore OCDBT kvstore ``config``.""" + return { + "compression": dict(self.compression), + "max_inline_value_bytes": self.max_inline_value_bytes, + } diff --git a/pychunkedgraph/graph/ocdbt/utils.py b/pychunkedgraph/graph/ocdbt/utils.py new file mode 100644 index 000000000..4c04e5799 --- /dev/null +++ b/pychunkedgraph/graph/ocdbt/utils.py @@ -0,0 +1,156 @@ +"""Internal helpers for the OCDBT package. + +Path builders, schema extraction, populate-marker IO, layer-bbox math. +Not part of the public API except for the marker IO and ``_layer_bbox`` +which the ingest worker uses across the package boundary. +""" + +import json +from typing import Optional + +import numpy as np +import tensorstore as ts +from tenacity import ( + retry, + retry_if_exception_message, + stop_after_attempt, + wait_exponential, +) + +# tensorstore raises ValueError with an absl/grpc status-code prefix. Retry +# only the transient classes — DNS hiccups, deadline-blown reads, server +# 5xx — so a single flaky GCS call doesn't kill the populate task. Persistent +# errors (NOT_FOUND, INVALID_ARGUMENT, RESOURCE_EXHAUSTED, …) propagate. 
+_transient = retry( + retry=retry_if_exception_message( + match=r"^(UNAVAILABLE|DEADLINE_EXCEEDED|ABORTED|INTERNAL):" + ), + stop=stop_after_attempt(5), + wait=wait_exponential(multiplier=0.5, min=0.5, max=8), + reraise=True, +) + + +def _ensure_trailing_slash(path: str) -> str: + """Ensure kvstore paths end with / so they're treated as directories.""" + return path if path.endswith("/") else path + "/" + + +def _base_ocdbt_path(ws_path: str) -> str: + return _ensure_trailing_slash(f"{ws_path.rstrip('/')}/ocdbt/base") + + +def _populate_markers_path(ws_path: str) -> str: + return _ensure_trailing_slash(f"{ws_path.rstrip('/')}/ocdbt/.populated") + + +def _marker_key(layer: int, coords) -> str: + return f"l{int(layer)}_{int(coords[0])}_{int(coords[1])}_{int(coords[2])}" + + +def _read_source_scales(ws_path: str): + """Read the source precomputed ``info`` JSON to get scale count and resolutions. + + The leading '/' in '/info' is required for GCS — without it the read + returns empty. + """ + kvs = ts.KvStore.open(ws_path).result() + info = json.loads(kvs.read("/info").result().value) + return info["scales"] + + +def _open_precomputed_scale( + kvstore, scale_index: int, create: bool = False, **schema_kw +): + """Open one neuroglancer_precomputed scale on top of a kvstore spec.""" + spec = { + "driver": "neuroglancer_precomputed", + "kvstore": kvstore, + "scale_index": scale_index, + } + return ts.open(spec, create=create, **schema_kw).result() + + +def _schema_from_src(src_handle) -> dict: + """Extract the schema kwargs needed to open a matching destination. + + ``domain`` already carries both extent and origin (voxel_offset). Passing + ``shape`` alongside conflicts with non-zero-origin sources because shape + implies origin=0 — tensorstore refuses to merge ``[0, N)`` with + ``[offset, offset+N)``. 
+ """ + s = src_handle.schema + return dict( + rank=s.rank, + dtype=s.dtype, + codec=s.codec, + domain=s.domain, + chunk_layout=s.chunk_layout, + dimension_units=s.dimension_units, + ) + + +@_transient +def is_chunk_populated(ws_path: str, layer: int, coords) -> bool: + """Check whether this chunk's precomputed→OCDBT copy has already completed. + + Markers live outside the OCDBT keyspace at + ``{ws_path}/ocdbt/.populated/l{layer}_{x}_{y}_{z}`` so retried ingest tasks + don't re-copy chunks and bloat the database with redundant versioned + writes. + """ + kvs = ts.KvStore.open(_populate_markers_path(ws_path)).result() + result = kvs.read(_marker_key(layer, coords)).result() + return result.value is not None and len(result.value) > 0 + + +@_transient +def mark_chunk_populated(ws_path: str, layer: int, coords) -> None: + """Record that this chunk's precomputed→OCDBT copy completed.""" + kvs = ts.KvStore.open(_populate_markers_path(ws_path)).result() + kvs.write(_marker_key(layer, coords), b"1").result() + + +@_transient +def read_populate_meta(ws_path: str) -> Optional[dict]: + """Return the per-base populate config dict, or None if not yet written.""" + kvs = ts.KvStore.open(_populate_markers_path(ws_path)).result() + r = kvs.read("meta.json").result() + if r.value is None or len(r.value) == 0: + return None + return json.loads(r.value) + + +@_transient +def write_populate_meta(ws_path: str, meta: dict) -> None: + """Persist the per-base populate config (layer, etc.) 
alongside markers.""" + kvs = ts.KvStore.open(_populate_markers_path(ws_path)).result() + kvs.write("meta.json", json.dumps(meta).encode()).result() + + +def base_exists(ws_path: str) -> bool: + """Check if the base OCDBT has already been created for this watershed.""" + base = _base_ocdbt_path(ws_path) + kvs = ts.KvStore.open(base).result() + result = kvs.read("manifest.ocdbt").result() + return result.value is not None and len(result.value) > 0 + + +def fork_exists(ws_path: str, graph_id: str) -> bool: + """Check if this ChunkedGraph's fork has been initialized.""" + fork_dir = _ensure_trailing_slash(f"{ws_path.rstrip('/')}/ocdbt/{graph_id}") + kvs = ts.KvStore.open(fork_dir).result() + result = kvs.read("manifest.ocdbt").result() + return result.value is not None and len(result.value) > 0 + + +def _layer_bbox(meta, layer: int, coords) -> tuple: + """Base-resolution voxel bbox of a chunk at this layer.""" + chunk_size = np.array(meta.graph_config.CHUNK_SIZE, dtype=int) + layer_chunk_size = chunk_size * (1 << (layer - 2)) + coords = np.array(coords, dtype=int) + vol_start = meta.voxel_bounds[:, 0] + vol_end = meta.voxel_bounds[:, 1] + lo = coords * layer_chunk_size + vol_start + hi = np.minimum(lo + layer_chunk_size, vol_end) + return lo, hi diff --git a/pychunkedgraph/graph/operation.py b/pychunkedgraph/graph/operation.py index 73ad898d8..7b5d72777 100644 --- a/pychunkedgraph/graph/operation.py +++ b/pychunkedgraph/graph/operation.py @@ -21,15 +21,18 @@ from . import locks from . import edits +from . import sv_split from . 
import types +from .ocdbt import write_seg_chunks +from .dry_run import is_dry_run from pychunkedgraph.graph import attributes from .edges import Edges from .edges.utils import get_edges_status from pychunkedgraph.graph import basetypes from pychunkedgraph.graph import serializers from .cache import CacheService -from .cutting import run_multicut -from .exceptions import PreconditionError, SupervoxelSplitRequiredError +from .cutting import Cut, SvSplitRequired, run_multicut +from .exceptions import PreconditionError from .exceptions import PostconditionError from .utils.generic import get_bounding_box as get_bbox from pychunkedgraph.graph import get_valid_timestamp @@ -50,7 +53,9 @@ class GraphEditOperation(ABC): "do_sanity_check", ] Result = namedtuple( - "Result", ["operation_id", "new_root_ids", "new_lvl2_ids", "old_root_ids"] + "Result", + ["operation_id", "new_root_ids", "new_lvl2_ids", "old_root_ids", "seg_bbox"], + defaults=(None,), ) def __init__( @@ -444,7 +449,7 @@ def execute( operation_ts=override_ts if override_ts else timestamp, status=attributes.OperationLogs.StatusCodes.CREATED.value, ) - self.cg.client.write([log_record_before_edit]) + self._persist_rows([log_record_before_edit]) try: with TimeIt(f"{op_type}.apply", self.cg.graph_id, lock.operation_id): @@ -452,7 +457,7 @@ def execute( operation_id=lock.operation_id, timestamp=override_ts if override_ts else timestamp, ) - if self.cg.meta.READ_ONLY: + if is_dry_run(): # return without persisting changes return GraphEditOperation.Result( operation_id=lock.operation_id, @@ -460,11 +465,6 @@ def execute( new_lvl2_ids=new_lvl2_ids, old_root_ids=root_ids, ) - except SupervoxelSplitRequiredError as err: - # no need for self.cg.cache = None, the cache must be retained after sv split - raise SupervoxelSplitRequiredError( - str(err), err.sv_remapping, operation_id=lock.operation_id - ) from err except PreconditionError as err: self.cg.cache = None raise PreconditionError(err) from err @@ -485,7 +485,7 @@ 
def execute( status=attributes.OperationLogs.StatusCodes.EXCEPTION.value, exception=repr(err), ) - self.cg.client.write([log_record_error]) + self._persist_rows([log_record_error]) raise Exception(err) from err with TimeIt(f"{op_type}.write", self.cg.graph_id, lock.operation_id): @@ -550,8 +550,18 @@ def _write( new_root_ids=new_root_ids, new_lvl2_ids=new_lvl2_ids, old_root_ids=old_root_ids, + # Only set when the operation actually ran SV splits (MulticutOperation + # populates this; other operations leave the attr absent and it defaults + # to None via the Result namedtuple's default). + seg_bbox=getattr(self, "seg_bboxes", None) or None, ) + def _persist_rows(self, rows): + """Persist BT mutation rows; no-op under ``PCG_DRY_RUN=1``.""" + if is_dry_run(): + return + self.cg.client.write(rows) + class MergeOperation(GraphEditOperation): """Merge Operation: Connect *known* pairs of supervoxels by adding a (weighted) edge. @@ -630,13 +640,14 @@ def _update_root_ids(self) -> np.ndarray: def _apply( self, *, operation_id, timestamp ) -> Tuple[np.ndarray, np.ndarray, List[Any]]: - root_ids = set( - self.cg.get_roots( - self.added_edges.ravel(), assert_roots=True, time_stamp=self.parent_ts - ) - ) + sv_ids = self.added_edges.ravel() + roots = self.cg.get_roots(sv_ids, assert_roots=True, time_stamp=self.parent_ts) + root_ids = set(roots) if len(root_ids) < 2 and not self.allow_same_segment_merge: - raise PreconditionError("Supervoxels must belong to different objects.") + raise PreconditionError( + f"Supervoxels must belong to different objects. " + f"sv_id->root: {dict(zip(sv_ids.tolist(), roots.tolist()))}" + ) atomic_edges = self.added_edges fake_edge_rows = [] @@ -761,33 +772,26 @@ def __init__( assert np.sum(layers) == layers.size, "IDs must be supervoxels." 
def _update_root_ids(self) -> np.ndarray: - root_ids = np.unique( - self.cg.get_roots( - self.removed_edges.ravel(), - assert_roots=True, - time_stamp=self.parent_ts, - ) - ) + sv_ids = self.removed_edges.ravel() + roots = self.cg.get_roots(sv_ids, assert_roots=True, time_stamp=self.parent_ts) + root_ids = np.unique(roots) if len(root_ids) > 1: - raise PreconditionError("Supervoxels must belong to the same object.") + raise PreconditionError( + f"Supervoxels must belong to the same object. " + f"sv_id->root: {dict(zip(sv_ids.tolist(), roots.tolist()))}" + ) return root_ids def _apply( self, *, operation_id, timestamp ) -> Tuple[np.ndarray, np.ndarray, List[Any]]: - if ( - len( - set( - self.cg.get_roots( - self.removed_edges.ravel(), - assert_roots=True, - time_stamp=self.parent_ts, - ) - ) + sv_ids = self.removed_edges.ravel() + roots = self.cg.get_roots(sv_ids, assert_roots=True, time_stamp=self.parent_ts) + if len(set(roots)) > 1: + raise PreconditionError( + f"Supervoxels must belong to the same object. " + f"sv_id->root: {dict(zip(sv_ids.tolist(), roots.tolist()))}" ) - > 1 - ): - raise PreconditionError("Supervoxels must belong to the same object.") with TimeIt("remove_edges", self.cg.graph_id, operation_id): return edits.remove_edges( @@ -866,6 +870,11 @@ class MulticutOperation(GraphEditOperation): "path_augment", "disallow_isolating_cut", "do_sanity_check", + # Base-resolution bboxes of SV splits done as part of this op, one + # per rep. Populated only when the multicut hit SvSplitRequired and + # split_supervoxels actually ran. Surfaced on the Result so the + # downsample worker knows which regions to re-mip. 
+ "seg_bboxes", ] def __init__( @@ -893,6 +902,7 @@ def __init__( self.path_augment = path_augment self.disallow_isolating_cut = disallow_isolating_cut self.do_sanity_check = do_sanity_check + self.seg_bboxes = [] ids = np.concatenate([self.source_ids, self.sink_ids]).astype(basetypes.NODE_ID) layers = self.cg.get_chunk_layers(ids) @@ -902,30 +912,114 @@ def _update_root_ids(self) -> np.ndarray: sink_and_source_ids = np.concatenate((self.source_ids, self.sink_ids)).astype( basetypes.NODE_ID ) - root_ids = np.unique( - self.cg.get_roots( - sink_and_source_ids, assert_roots=True, time_stamp=self.parent_ts - ) + roots = self.cg.get_roots( + sink_and_source_ids, assert_roots=True, time_stamp=self.parent_ts ) + root_ids = np.unique(roots) if len(root_ids) > 1: - raise PreconditionError("Supervoxels must belong to the same segment.") + raise PreconditionError( + f"Supervoxels must belong to the same segment. " + f"sources={self.source_ids.tolist()} sinks={self.sink_ids.tolist()} " + f"sv_id->root: {dict(zip(sink_and_source_ids.tolist(), roots.tolist()))}" + ) return root_ids def _apply( self, *, operation_id, timestamp ) -> Tuple[np.ndarray, np.ndarray, List[Any]]: - # Verify that sink and source are from the same root object - root_ids = set( - self.cg.get_roots( - np.concatenate([self.source_ids, self.sink_ids]).astype( - basetypes.NODE_ID - ), - assert_roots=True, - time_stamp=self.parent_ts, + result = self._run_multicut(operation_id) + if isinstance(result, SvSplitRequired): + # Running under GraphEditOperation.execute's RootLock — no same-root + # edit can interleave between the SV split and the retry multicut. + # `plan_sv_splits` returns the chunk scope for both locks below, + # `split_supervoxels` is a pure planner that computes the full + # payload. Writes happen here inside nested L2 chunk locks: + # - `L2ChunkLock` (temporal) spans the seg reads (inside + # `split_supervoxels`) and the writes, so no concurrent + # op can mutate our chunks mid-compute. 
+ # - `IndefiniteL2ChunkLock` is scoped tightly to the writes + # only. A worker death inside it leaves the indefinite + # cell set on every chunk row in scope, blocking future + # ops until operator replay clears them. + tasks, chunk_ids = sv_split.edits.plan_sv_splits( + self.cg, + sv_remapping=result.sv_remapping, + source_ids=self.source_ids, + sink_ids=self.sink_ids, + source_coords=self.source_coords, + sink_coords=self.sink_coords, ) + with locks.L2ChunkLock( + self.cg, + chunk_ids, + operation_id, + privileged_mode=self.privileged_mode, + ): + sv_result = sv_split.edits.split_supervoxels( + self.cg, + tasks=tasks, + sv_remapping=result.sv_remapping, + source_ids=self.source_ids, + sink_ids=self.sink_ids, + operation_id=operation_id, + timestamp=timestamp, + parent_ts=self.parent_ts, + ) + with locks.IndefiniteL2ChunkLock( + self.cg, + chunk_ids, + operation_id, + privileged_mode=self.privileged_mode, + ): + write_seg_chunks(self.cg.meta, sv_result.seg_writes) + self._persist_rows(sv_result.bigtable_rows) + self.seg_bboxes = sv_result.seg_bboxes + self.source_ids = sv_result.source_ids_fresh + self.sink_ids = sv_result.sink_ids_fresh + result = self._run_multicut(operation_id) + if isinstance(result, SvSplitRequired): + raise PreconditionError( + "Supervoxel split succeeded but source and sink remain " + "connected; place source and sink farther apart." + ) + + assert isinstance(result, Cut), f"unexpected multicut result: {result!r}" + self.removed_edges = result.atomic_edges + if not self.removed_edges.size: + raise PostconditionError("Mincut could not find any edges to remove.") + + with TimeIt("remove_edges", self.cg.graph_id, operation_id): + return edits.remove_edges( + self.cg, + operation_id=operation_id, + atomic_edges=self.removed_edges, + time_stamp=timestamp, + parent_ts=self.parent_ts, + do_sanity_check=self.do_sanity_check, + ) + + def _run_multicut(self, operation_id): + """Build the local subgraph and run multicut; returns the tagged result. 
+ + Factored so `_apply` can call it twice — once for initial detection + and again after an SV split to get fresh atomic_edges against the + post-split graph topology. + """ + sink_and_source_ids = np.concatenate([self.source_ids, self.sink_ids]).astype( + basetypes.NODE_ID ) + roots = self.cg.get_roots( + sink_and_source_ids, + assert_roots=True, + time_stamp=self.parent_ts, + ) + root_ids = set(roots) if len(root_ids) > 1: - raise PreconditionError("Supervoxels must belong to the same object.") + raise PreconditionError( + f"Supervoxels must belong to the same object. " + f"sources={self.source_ids.tolist()} sinks={self.sink_ids.tolist()} " + f"sv_id->root: {dict(zip(sink_and_source_ids.tolist(), roots.tolist()))}" + ) bbox = get_bbox( self.source_coords, @@ -936,7 +1030,6 @@ def _apply( l2id_agglomeration_d, edges_tuple = self.cg.get_subgraph( root_ids.pop(), bbox=bbox, bbox_is_coordinate=True ) - edges = reduce(lambda x, y: x + y, edges_tuple, Edges([], [])) supervoxels = np.concatenate( [agg.supervoxels for agg in l2id_agglomeration_d.values()] @@ -948,7 +1041,7 @@ def _apply( raise PreconditionError("No local edges found.") with TimeIt("multicut", self.cg.graph_id, operation_id): - self.removed_edges = run_multicut( + return run_multicut( edges, self.source_ids, self.sink_ids, @@ -956,18 +1049,6 @@ def _apply( disallow_isolating_cut=self.disallow_isolating_cut, sv_split_supported=self.cg.meta.ocdbt_seg, ) - if not self.removed_edges.size: - raise PostconditionError("Mincut could not find any edges to remove.") - - with TimeIt("remove_edges", self.cg.graph_id, operation_id): - return edits.remove_edges( - self.cg, - operation_id=operation_id, - atomic_edges=self.removed_edges, - time_stamp=timestamp, - parent_ts=self.parent_ts, - do_sanity_check=self.do_sanity_check, - ) def _create_log_record( self, diff --git a/pychunkedgraph/graph/sv_split/README.md b/pychunkedgraph/graph/sv_split/README.md new file mode 100644 index 000000000..8cdede87b --- /dev/null +++ 
b/pychunkedgraph/graph/sv_split/README.md @@ -0,0 +1,147 @@ +# Supervoxel splitting + +## What it is + +A *supervoxel split* bisects one physical supervoxel — a connected region in the raw segmentation — along a user-seeded cut. The user supplies a source coordinate and a sink coordinate inside one supervoxel; the system finds a cut surface separating them and assigns new supervoxel IDs to each half, writing the updated segmentation and the corresponding graph hierarchy. + +This only runs on segmentations stored in OCDBT (a writable, append-only segmentation backend). With a read-only segmentation backend the split path is never entered; the multicut instead surfaces a precondition error asking the user to pick different source/sink points. + +## Why it's needed + +The graph is stored in **chunks**: the segmentation volume is partitioned into a regular 3D grid, and each chunk owns its own set of supervoxel IDs. When a physical supervoxel spans a chunk boundary, it is artificially cut into multiple graph-level supervoxel IDs — one per chunk — with infinite-affinity *cross-chunk edges* connecting the pieces so the graph still represents one physical object. + +The multicut algorithm runs on a local graph around source and sink. If it finds that source and sink sit inside the same cross-chunk-connected component — i.e., in the same physical supervoxel — a clean graph cut cannot separate them without first **splitting that physical supervoxel at the voxel level** and giving the resulting halves fresh IDs. That voxel-level cut is what the split flow does. The multicut runs again against the refreshed graph and produces the graph-level edges to remove. 
+ +## End-to-end flow + +``` +Split request (source coord, sink coord) + │ + ▼ +Resolve coords → current supervoxel IDs at those pixels + │ + ▼ +┌───────────────────────────────────────────────────────────────────────┐ +│ ROOT LOCK (held across the whole operation) │ +│ │ +│ multicut: │ +│ build local subgraph around source/sink │ +│ stitch cross-chunk-connected SVs via inf-affinity edges │ +│ run mincut between source and sink │ +│ result ─► one of: │ +│ ● clean cut → edges to remove │ +│ ● SV split needed → cross-chunk-representative mapping │ +│ ● precondition → surface to user, abort │ +│ │ +│ if SV split needed: │ +│ ┌───────────────────────────────────────────────────────────────┐ │ +│ │ L2 CHUNK LOCK (spatial; sparse set + 1-chunk margin) │ │ +│ │ │ │ +│ │ for each cross-chunk rep linking source to sink: │ │ +│ │ bbs/bbe ◄ envelope of src+sink seeds + 1-chunk margin │ │ +│ │ read seg in [bbs-1, bbe+1] │ │ +│ │ (1-voxel shell → anchor voxels for edge routing) │ │ +│ │ compute voxel-level cut between seeds │ │ +│ │ allocate fresh SV IDs per chunk to each half │ │ +│ │ route existing cross-chunk edges onto the new fragments │ │ +│ │ write seg — only chunks that actually received new IDs │ │ +│ │ write hierarchy rows (lineage + new cross-chunk edges) │ │ +│ └───────────────────────────────────────────────────────────────┘ │ +│ │ +│ refresh source/sink IDs: │ +│ look up the new IDs in the in-memory split output │ +│ (bit-identical to what just landed on storage; no extra read) │ +│ │ +│ multicut (retry against post-split graph): │ +│ result ─► clean cut → edges to remove │ +│ │ still-split-needed → surface precondition error │ +│ │ +│ commit the cut: │ +│ remove graph-level edges │ +│ produce new roots │ +│ write hierarchy rows + operation log │ +└───────────────────────────────────────────────────────────────────────┘ + │ + ▼ +Release root lock — edit is durable + │ + ▼ +Publish pubsub message; when an SV split ran it carries the list of +base-resolution 
bounding boxes that were rewritten + │ + ▼ +┌───────────────────────────────────────────────────────────────┐ +│ Async downsample worker │ +│ │ +│ partition each published bbox into pyramid blocks │ +│ (cube regions aligned to the coarsest MIP's chunk grid; │ +│ two distinct blocks never share a storage chunk at │ +│ any MIP level) │ +│ │ +│ ┌──────────────────────────────────────────────────────┐ │ +│ │ PYRAMID BLOCK LOCK (separate lock family from L2) │ │ +│ │ │ │ +│ │ for each pyramid block: │ │ +│ │ read base resolution │ │ +│ │ downsample through every coarser MIP │ │ +│ │ write only tiles whose footprint intersects │ │ +│ │ a published bbox │ │ +│ └──────────────────────────────────────────────────────┘ │ +└───────────────────────────────────────────────────────────────┘ +``` + +### Notes on the flow + +- **"SV split required" is a return value, not an exception.** The multicut returns one of several tagged outcomes so the caller dispatches with a straight branch. Nothing uses raise/catch for control flow, which is what allows the root lock to stay held across the detect-then-split-then-commit sequence without the exception unwinding the lock. + +- **The cross-chunk-representative mapping** comes out of the multicut for free: as part of building its local graph it stitches every cross-chunk-connected group of graph-level supervoxels into one node and records the mapping. That map tells the split step which supervoxels are artificially-cut pieces of one physical SV, and which of them sit on a source→sink bridge. + +- **The split is per-representative.** If two unrelated physical supervoxels both need splitting in one edit (rare but possible), each is handled in its own pass under the same L2 chunk lock. 
+ +## Concurrency design + +Two races exist at the segmentation layer even with root locks in place: + +- **Same-root race.** Without care, the root lock could drop between "detect split needed" and "perform split", letting another edit on the same root slip in and race for the same supervoxel pieces. +- **Cross-root spatial race.** Two edits on entirely distinct roots can target supervoxels whose pieces live in overlapping chunks. Root locks don't serialize them; segmentation writes would clobber each other. + +The split flow closes both: + +- **Root lock scope covers the full operation.** Detection, supervoxel-level split, retry detection, commit — all under one root lock. Same-root interleaving is impossible; any other edit on the root waits for this one to finish. + +- **L2 chunk lock covers the supervoxel-level split only.** Inside the root lock, the split step additionally acquires a spatial lock on every L2 chunk it will read or write. Keyed by chunk, so edits on different roots but overlapping chunks serialize here. Released as soon as the split writes land; the graph-level commit afterwards runs under the root lock alone. + +### How the spatial lock set is computed + +For each cross-chunk representative being split, the read/cut region is the base-voxel envelope of that rep's source and sink seed coordinates, padded by one CG chunk on each side. The cut surface lives between the seeds, so pieces of the rep far from both seeds never participate — the seed envelope is the region that gets read and rewritten, not the rep's full piece-set envelope, which for an SV cut across many chunks can be orders of magnitude larger. 
To derive the lock set, expand that envelope by one voxel (because the edge-routing step reads a 1-voxel shell outside the rewritten region to see neighboring supervoxels' labels), map it to the overlapping L2 chunks, union the per-representative chunk sets, and sort deterministically so workers with overlapping sets never acquire in opposite orders. + +The chunks locked are exactly the chunks the split will touch, plus the 1-chunk margin the shell read requires. + +### How the write scope is kept minimal + +Only chunks that actually receive new supervoxel IDs get written to storage. Gap chunks that happen to sit inside an envelope but contain no cross-chunk-connected pieces, and neighbor chunks read only for the edge-routing shell, are never written. The segmentation backend is append-only, so writing unchanged bytes would inflate the on-disk delta for no real change. + +### Why the post-split ID refresh is safe without an extra read + +After the split lands, the caller-supplied source and sink supervoxel IDs reference now-superseded supervoxels. The retry multicut needs the *current* IDs at the source and sink pixels — the subgraph fetch returns only live supervoxels, so a mincut asking about superseded ones would fail to find its endpoints. + +The in-memory segmentation block produced by the split is bitwise identical to what was just written to storage, and the storage write is synchronous (we wait for it) and happens under the L2 chunk lock (so nothing else can have mutated those voxels). Looking up source/sink coords in that block returns the same IDs a storage re-read would — no extra round-trip needed. + +### Worker crash mid-write + +A worker that dies — or raises from the persist block — inside the indefinite L2 chunk lock's scope leaves the lock cells set and the op-log row's `L2ChunkLockScope` populated with the exact chunks being written. Future ops on any of those chunks refuse to start — the crashed state is isolated, not amplified. 
An operator runs the recovery flow described in [recovery.md](recovery.md) to revert the partial writes and replay the op. + +## Invariants + +- A supervoxel split and its graph-level commit are one atomic operation. Either both land or neither does, under a single root lock. +- Within the supervoxel-split step, concurrent splits on overlapping L2 chunks serialize. No two operations write segmentation to the same chunk at the same time. +- Supervoxel-level writes touch only chunks whose voxels actually changed. Gap chunks between cross-chunk-connected pieces and neighbor chunks read for edge routing are untouched. +- After the commit, readers at the operation's timestamp see new supervoxel IDs in the cut region and new roots reflecting the cut. +- Coarser MIP levels are eventually consistent with the base scale, lagging at most until the async downsample worker processes the operation's pubsub message. + +## Related docs + +- [Algorithm](algorithm.md) — the voxel-level geodesic cut. +- [Design](design.md) — rationale behind the cut. +- [Edges](edges.md) — edge re-routing after the cut. +- [Recovery](recovery.md) — replay / recovery of interrupted splits. diff --git a/pychunkedgraph/graph/sv_split/__init__.py b/pychunkedgraph/graph/sv_split/__init__.py new file mode 100644 index 000000000..7067100af --- /dev/null +++ b/pychunkedgraph/graph/sv_split/__init__.py @@ -0,0 +1,5 @@ +""" +Supervoxel splitting. +""" + +from . import cutting, edges, edits diff --git a/pychunkedgraph/graph/sv_split/algorithm.md b/pychunkedgraph/graph/sv_split/algorithm.md new file mode 100644 index 000000000..17e82e352 --- /dev/null +++ b/pychunkedgraph/graph/sv_split/algorithm.md @@ -0,0 +1,89 @@ +# SV splitting — voxel-level cut algorithm + +This is the algorithmic core of SV splitting: given the voxels of one +supervoxel and two cut seeds, partition every voxel into a *source* side and a +*sink* side. 
It lives in `graph/sv_split/cutting.py` (`split_supervoxel`, +`enforce_cc`) plus the coord extraction helper `build_coords_by_label`. + +The cut is a **geodesic region grow**, not a voxel graph-mincut: each voxel is +assigned to whichever seed is nearer in geodesic (anisotropy-aware) travel +cost across the supervoxel's interior. + +## In plain terms + +The cut works on a boolean mask of the one supervoxel being split: a voxel is +*true* if it belongs to that supervoxel and *false* otherwise. Call the true +voxels "the supervoxel's voxels". + +1. **Snap the seeds.** The user clicks a source point and a sink point. Those + clicks need not land exactly on one of the supervoxel's voxels, so each seed + is moved to the nearest true voxel. This keeps the next step starting inside + the object. +2. **Flood from each seed.** From the source seed, compute the cheapest path to + every true voxel, where travelling through the thick interior is cheap and + skimming the thin surface is expensive; do the same from the sink seed. +3. **Assign each voxel to its nearer seed.** A voxel goes to the source side if + the source seed reaches it more cheaply, otherwise the sink side. The cut + surface is simply where the two floods meet — it naturally lands at the + object's thinnest neck. +4. **Clean up.** Make sure each side is a single connected piece, reassigning + any stray fragments to the side they border most. + +The rest of this document is the precise version of those four steps. + +## Inputs + +- `coords` — `(N, 3)` global voxel coordinates of the supervoxel. +- `seed_source`, `seed_sink` — global-voxel coordinates of the two cut seeds. +- `resolution` — voxel size in nm `(x, y, z)`; supplies anisotropy. +- `voxel_offset` — volume origin (chunk corner). + +## Steps + +1. **ROI restriction.** Compute the foreground bounding box from `coords` and + work inside it only. 
A dense occupancy volume is allocated over the ROI, not + over the whole chunk, so cost scales with foreground extent. +2. **Seed snapping.** Build a KDTree over the foreground voxel coordinates and + snap each seed to its nearest occupied voxel. Seeds need not lie exactly on + a foreground voxel. +3. **EDT → speed / travel-cost map.** A Euclidean distance transform (sampled + by `resolution`) gives each voxel its distance to the boundary. Speed favors + the interior; geodesic travel cost is `1 / speed`, so paths hug the medial + region rather than skimming the surface. +4. **Geodesic arrival.** `MCP_Geometric` (with `sampling=resolution`, so the + metric is anisotropy-aware) computes arrival cost from each seed to every + voxel. Each voxel is assigned to the side whose seed it reaches more cheaply. +5. **Narrow-band proximity boost.** Voxels within `narrow_band_width` of the + opposing side receive a cost boost (`proximity_boost`) so the boundary + tracks the geometric midline more closely. +6. **Optional downsampled grid.** When `downsample` is set, the geodesic grow + runs on a downsampled grid and is upsampled back. Upsampling happens *before* + single-CC enforcement, because upsampling under the foreground mask can + fragment a label into disconnected pieces. +7. **Single-CC enforcement** (`enforce_cc`, when `enforce_single_cc=True`). + For each side, keep the largest / seeded 26-connected component; relabel + stray components to a transient label `3`. Each label-`3` component is then + dilated **within its own bounding box** and reassigned to side 1 or 2 by + which side it borders more, with EDT distance as the tie-break. Confining + the dilation to the component's bbox is exact: per-component border counts + are identical to a full-volume dilation. 
+ +## The single-CC invariant + +`_update_chunks` (`graph/sv_split/edits.py`) assigns **exactly one new +supervoxel id per distinct label value per chunk** and performs **no +connected-component analysis of its own**. It relies entirely on the cut having already +produced single-CC labels. Therefore every output label out of `enforce_cc` +**must already be a single connected component** — if a side were left in two +pieces, both pieces would collapse into one new supervoxel id and the graph +would gain a spuriously-connected supervoxel. + +This is why single-CC is enforced at full resolution and why label-`3` strays +are resolved rather than dropped. + +## Related docs + +- [Overview](README.md) — where this fits in the edit pipeline. +- [Design](design.md) — rationale for the choices above. +- [Edges](edges.md) — edge re-routing after the cut. +- [Recovery](recovery.md) — replay-safety of splits. diff --git a/pychunkedgraph/graph/cutting_sv.py b/pychunkedgraph/graph/sv_split/cutting.py similarity index 73% rename from pychunkedgraph/graph/cutting_sv.py rename to pychunkedgraph/graph/sv_split/cutting.py index d1869951c..3ebdf85d5 100644 --- a/pychunkedgraph/graph/cutting_sv.py +++ b/pychunkedgraph/graph/sv_split/cutting.py @@ -1,7 +1,8 @@ from time import perf_counter +import fastremap import numpy as np -from typing import Dict, Tuple, Optional, Sequence +from typing import Dict, Iterable, Tuple, Optional, Sequence from scipy.spatial import cKDTree # EDT backends: prefer Seung-Lab edt, fallback to scipy.ndimage @@ -19,6 +20,11 @@ ball, ) # keep only ball; use ndi.binary_dilation everywhere +from pychunkedgraph import get_logger +from pychunkedgraph.profiler import get_profiler + +logger = get_logger(__name__) + # ---------- Fast CC wrappers ---------- try: import cc3d @@ -194,10 +200,11 @@ def snap_seeds_to_segment( downsample_mode="stride", # 'stride' or 'random' downsample_stride=2, # used if mode='stride' downsample_target=None, # used if mode='random' + 
use_bbox=False, + bbox_pad_phys=None, # physical pad per side; None -> derived from voxel_size rng=None, return_index=False, leafsize=16, - log=lambda x: None, tag="snap", method="kdtree", # accepted for compatibility; only 'kdtree' currently ): @@ -220,6 +227,14 @@ def snap_seeds_to_segment( downsample_mode : 'stride' or 'random' for boundary sampling. downsample_stride : If stride mode, use every Nth boundary voxel. downsample_target : If random mode, target number of boundary points to keep. + use_bbox : If True, restrict the candidate scan to a bounding box around + the seeds instead of the whole mask. The nearest true voxel is + always near a seed, so this returns the identical snapped voxel + while building a far smaller KDTree. The box grows until it + contains a true voxel, so it never changes the result. + bbox_pad_phys : Physical pad (same units as voxel_size) added on each side of + the seed bbox. None -> derived from voxel_size. Only a starting + size; the grow loop makes it correctness-independent. rng : Optional np.random.Generator for reproducible random sampling. return_index : If True, also return indices of nearest boundary points. leafsize : cKDTree leafsize parameter. @@ -238,7 +253,9 @@ def snap_seeds_to_segment( """ t0 = perf_counter() if method != "kdtree": - log(f"[{tag}] Warning: 'method={method}' not supported; using 'kdtree'.") + logger.debug( + f"[{tag}] Warning: 'method={method}' not supported; using 'kdtree'." 
+ ) # Validate mask if mask.ndim != 3: @@ -249,33 +266,74 @@ def snap_seeds_to_segment( if mask_order not in ("zyx", "xyz"): raise ValueError("mask_order must be 'zyx' or 'xyz'") - # Optional boundary extraction for speed - tb = perf_counter() - if use_boundary: - candidate_mask = _extract_mask_boundary(mask, erosion_iters=erosion_iters) - # Fallback to full mask if boundary is empty - if not candidate_mask.any(): - candidate_mask = mask - log(f"[{tag}] boundary empty → fallback to full mask") - else: - candidate_mask = mask - log(f"[{tag}] candidate extraction | {perf_counter()-tb:.3f}s") + # Prepare seeds array (needed up-front for the bbox window below) + seeds_xyz = np.asarray(seeds_xyz, dtype=np.float64) + if seeds_xyz.ndim == 1: + seeds_xyz = seeds_xyz[None, :] + if seeds_xyz.shape[1] != 3: + raise ValueError("seeds_xyz must be shape (N, 3)") - # Obtain candidate voxel coordinates in XYZ order - tc = perf_counter() + # Full-mask voxel bounds (XYZ), used for the final clip regardless of windowing. if mask_order == "zyx": - # mask shape is (Z, Y, X), np.where -> (z, y, x) - zc, yc, xc = np.where(candidate_mask) - points_xyz = np.stack([xc, yc, zc], axis=1) max_x, max_y, max_z = mask.shape[2] - 1, mask.shape[1] - 1, mask.shape[0] - 1 else: - # mask shape is (X, Y, Z), np.where -> (x, y, z) - xc, yc, zc = np.where(candidate_mask) - points_xyz = np.stack([xc, yc, zc], axis=1) max_x, max_y, max_z = mask.shape[0] - 1, mask.shape[1] - 1, mask.shape[2] - 1 - log( - f"[{tag}] candidate coordinates | {perf_counter()-tc:.3f}s (n={len(points_xyz)})" - ) + + # Axis order of `mask` as XYZ indices: voxel (x, y, z) lives at mask[idx]. + ax_xyz = (2, 1, 0) if mask_order == "zyx" else (0, 1, 2) + + def _candidates_xyz(window_mask, origin_xyz): + """np.where over a (cropped) mask → candidate voxel coords in XYZ. + + `origin_xyz` is the XYZ coordinate of the crop's [0,0,0] corner, + added back so coords are global. Honors use_boundary exactly as + the full-mask path does. 
+ """ + if use_boundary: + cand = _extract_mask_boundary(window_mask, erosion_iters=erosion_iters) + if not cand.any(): + cand = window_mask + logger.debug(f"[{tag}] boundary empty → fallback to full mask") + else: + cand = window_mask + wc = np.where(cand) # tuple in mask-axis order + pts = np.stack([wc[ax_xyz[0]], wc[ax_xyz[1]], wc[ax_xyz[2]]], axis=1) + return pts + np.asarray(origin_xyz, dtype=pts.dtype) + + tb = perf_counter() + if use_bbox and seeds_xyz.shape[0] > 0: + # Per-axis voxel pad derived from physical pad / voxel_size (anisotropy + # correct, no hardcoded axis). The window grows until it contains a true + # voxel, so the pad is only a starting size — never a correctness bound. + vsize_xyz = np.asarray(voxel_size, dtype=np.float64) + if bbox_pad_phys is None: + bbox_pad_phys = float(vsize_xyz.max()) + seed_min_xyz = np.floor(seeds_xyz.min(axis=0)).astype(np.int64) + seed_max_xyz = np.ceil(seeds_xyz.max(axis=0)).astype(np.int64) + full_max_xyz = np.array([max_x, max_y, max_z], dtype=np.int64) + pad_phys = float(bbox_pad_phys) + points_xyz = np.empty((0, 3), dtype=np.int64) + while True: + pad_vox = np.ceil(pad_phys / vsize_xyz).astype(np.int64) + lo_xyz = np.maximum(seed_min_xyz - pad_vox, 0) + hi_xyz = np.minimum(seed_max_xyz + pad_vox, full_max_xyz) + # Slice the mask in its own axis order. + lo_mask = lo_xyz[list(ax_xyz)] + hi_mask = hi_xyz[list(ax_xyz)] + sl = tuple(slice(int(lo_mask[a]), int(hi_mask[a]) + 1) for a in range(3)) + points_xyz = _candidates_xyz(mask[sl], lo_xyz) + if points_xyz.shape[0] > 0: + break + if np.all(lo_xyz == 0) and np.all(hi_xyz == full_max_xyz): + # Window already spans the full mask and it is still empty. 
+ break + pad_phys *= 2.0 + logger.debug(f"[{tag}] bbox candidate scan | {perf_counter()-tb:.3f}s") + else: + points_xyz = _candidates_xyz(mask, (0, 0, 0)) + logger.debug(f"[{tag}] candidate scan | {perf_counter()-tb:.3f}s") + + logger.debug(f"[{tag}] candidate coordinates (n={len(points_xyz)})") if points_xyz.shape[0] == 0: raise ValueError( @@ -294,14 +352,9 @@ def snap_seeds_to_segment( rng=rng, ) after = len(points_xyz) - log(f"[{tag}] downsample points {before} → {after} | {perf_counter()-td:.3f}s") - - # Prepare seeds array - seeds_xyz = np.asarray(seeds_xyz, dtype=np.float64) - if seeds_xyz.ndim == 1: - seeds_xyz = seeds_xyz[None, :] - if seeds_xyz.shape[1] != 3: - raise ValueError("seeds_xyz must be shape (N, 3)") + logger.debug( + f"[{tag}] downsample points {before} → {after} | {perf_counter()-td:.3f}s" + ) # Scale coordinates to physical space to respect anisotropy vx, vy, vz = voxel_size @@ -314,7 +367,7 @@ def snap_seeds_to_segment( te = perf_counter() tree = cKDTree(points_scaled, leafsize=leafsize) _, nn_indices = tree.query(seeds_scaled, k=1, workers=-1) - log(f"[{tag}] KDTree build+query | {perf_counter()-te:.3f}s") + logger.debug(f"[{tag}] KDTree build+query | {perf_counter()-te:.3f}s") # Map back to integer voxel coords (XYZ) snapped_xyz = points_xyz[nn_indices].astype(np.int64) @@ -324,7 +377,9 @@ def snap_seeds_to_segment( snapped_xyz[:, 1] = np.clip(snapped_xyz[:, 1], 0, max_y) snapped_xyz[:, 2] = np.clip(snapped_xyz[:, 2], 0, max_z) - log(f"[{tag}] snapped {len(seeds_xyz)} seeds | total {perf_counter()-t0:.3f}s") + logger.debug( + f"[{tag}] snapped {len(seeds_xyz)} seeds | total {perf_counter()-t0:.3f}s" + ) if return_index: return snapped_xyz, nn_indices else: @@ -334,7 +389,7 @@ def snap_seeds_to_segment( # ============================================================ # EDT wrapper (Seung-Lab edt preferred, fallback to scipy) # ============================================================ -def _compute_edt(mask: np.ndarray, sampling_zyx, 
log=lambda x: None, tag="edt"): +def _compute_edt(mask: np.ndarray, sampling_zyx, tag="edt"): """ Compute Euclidean distance transform using Seung-Lab edt if available, otherwise fallback to scipy.ndimage.distance_transform_edt. @@ -345,11 +400,11 @@ def _compute_edt(mask: np.ndarray, sampling_zyx, log=lambda x: None, tag="edt"): t0 = perf_counter() if _HAVE_EDT_FAST: dist = _edt_fast(mask.astype(np.uint8, copy=False), anisotropy=sampling_zyx) - log(f"[{tag}] Seung-Lab edt | {perf_counter()-t0:.3f}s") + logger.debug(f"[{tag}] Seung-Lab edt | {perf_counter()-t0:.3f}s") return dist else: dist = ndi.distance_transform_edt(mask, sampling=sampling_zyx) - log(f"[{tag}] SciPy EDT | {perf_counter()-t0:.3f}s") + logger.debug(f"[{tag}] SciPy EDT | {perf_counter()-t0:.3f}s") return dist @@ -384,11 +439,8 @@ def connect_both_seeds_via_ridge( refine_fullres_when_fail: bool = True, snap_method: str = "kdtree", snap_kwargs: dict | None = None, - verbose: bool = False, ): - def log(msg: str): - if verbose: - print(msg, flush=True) + _prof = get_profiler() def _bbox_pad_zyx(points_zyx, shape, pad=(24, 48, 48)): pts = np.asarray(points_zyx, int) @@ -434,10 +486,10 @@ def _mst_edges_phys(pts_zyx, sampling): return edges t0 = perf_counter() - log( + logger.debug( f"[connect] vol_order={vol_order}, vox_order={vox_order}, seed_order={seed_order}" ) - log( + logger.debug( f"[connect] mask shape: {binary_sv.shape}, ridge_power={ridge_power}, ds={downsample}" ) @@ -475,18 +527,20 @@ def _snap(pts_zyx, name): sampling[1], sampling[0], ), # convert ZYX->XYZ spacing - log=log, tag=f"{name}@snap", **snap_cfg, ) # Back to ZYX return snapped_xyz[:, [2, 1, 0]] - A_zyx = _snap(A_in_zyx, "A") - B_zyx = _snap(B_in_zyx, "B") + with _prof.profile("snap_seeds"): + A_zyx = _snap(A_in_zyx, "A") + B_zyx = _snap(B_in_zyx, "B") if len(A_zyx) == 0 or len(B_zyx) == 0: - log("[connect] after snapping, one side has no seeds; skipping connection") + logger.debug( + "[connect] after snapping, one side has no 
seeds; skipping connection" + ) return ( _seeds_from_zyx(A_zyx, seed_order), _seeds_from_zyx(B_zyx, seed_order), @@ -499,7 +553,9 @@ def _snap(pts_zyx, name): np.vstack([A_zyx, B_zyx]), sv_zyx.shape, pad=roi_pad_zyx ) roi = sv_zyx[z0:z1, y0:y1, x0:x1] - log(f"[connect] ROI: z[{z0}:{z1}] y[{y0}:{y1}] x[{x0}:{x1}] → shape {roi.shape}") + logger.debug( + f"[connect] ROI: z[{z0}:{z1}] y[{y0}:{y1}] x[{x0}:{x1}] → shape {roi.shape}" + ) # Downsample ROI sz, sy, sx = map(int, downsample) @@ -509,7 +565,7 @@ def _snap(pts_zyx, name): else: roi_ds = roi sampling_ds = (sampling[0] * sz, sampling[1] * sy, sampling[2] * sx) - log( + logger.debug( f"[connect] ROI downsampled {roi.shape} -> {roi_ds.shape} | {perf_counter()-ti_ds:.3f}s" ) @@ -532,7 +588,6 @@ def _to_roi_ds_snapped(pts_zyx, name="seedDS"): mask=roi_ds, mask_order="zyx", voxel_size=(sampling_ds[2], sampling_ds[1], sampling_ds[0]), - log=log, tag=f"{name}@roi_ds", use_boundary=False, downsample=False, @@ -542,7 +597,7 @@ def _to_roi_ds_snapped(pts_zyx, name="seedDS"): return snapped_ds_zyx.astype(int) except ValueError as e: # If roi_ds is empty or degenerate, bail out gracefully: - log( + logger.debug( f"[{name}@roi_ds] snapping failed ({e}); falling back to nearest-int grid & mask check." ) approx = np.floor(seeds_ds + 0.5).astype(int) @@ -560,7 +615,7 @@ def _to_roi_ds_snapped(pts_zyx, name="seedDS"): okA = len(A_ds) >= 1 okB = len(B_ds) >= 1 if not (okA and okB): - log( + logger.debug( "[connect] seeds disappeared or failed to map on DS grid; consider smaller ds or use_boundary=False/downsample=False in snapping." 
) return ( @@ -572,9 +627,9 @@ def _to_roi_ds_snapped(pts_zyx, name="seedDS"): # EDT and cost on DS ROI (Seung-Lab edt if available) t1 = perf_counter() - dist = _compute_edt(roi_ds, sampling_ds, log=log, tag="connect:EDT") + dist = _compute_edt(roi_ds, sampling_ds, tag="connect:EDT") if dist.max() <= 0: - log("[connect] empty EDT in ROI; skipping connection") + logger.debug("[connect] empty EDT in ROI; skipping connection") return ( _seeds_from_zyx(A_zyx, seed_order), _seeds_from_zyx(B_zyx, seed_order), @@ -585,7 +640,7 @@ def _to_roi_ds_snapped(pts_zyx, name="seedDS"): eps = 1e-6 cost = np.full_like(dn, 1e12, dtype=float) cost[roi_ds] = 1.0 / (eps + np.clip(dn[roi_ds], 0, 1) ** max(0.0, ridge_power)) - log(f"[connect] EDT/cost ready on DS-ROI | {perf_counter()-t1:.3f}s") + logger.debug(f"[connect] EDT/cost ready on DS-ROI | {perf_counter()-t1:.3f}s") # Shortest paths via MST def _path_mask_ds(start, end): @@ -595,14 +650,14 @@ def _path_mask_ds(start, end): mid = perf_counter() v = costs[tuple(end)] if not np.isfinite(v): - log( + logger.debug( f"[MCP] start={tuple(start)} -> end={tuple(end)} FAILED | setup+run={mid-tmcp:.3f}s" ) return None path = np.asarray(mcp.traceback(tuple(end)), int) m = np.zeros_like(roi_ds, bool) m[tuple(path.T)] = True - log( + logger.debug( f"[MCP] start={tuple(start)} -> end={tuple(end)} OK | total={perf_counter()-tmcp:.3f}s" ) return m @@ -616,14 +671,12 @@ def _augment_team_ds(team_name, pts_ds): for i, j in edges: m = _path_mask_ds(pts_ds[i], pts_ds[j]) if m is None: - log(f"[connect:{team_name}] DS path FAILED for edge {i}-{j}") + logger.debug(f"[connect:{team_name}] DS path FAILED for edge {i}-{j}") ok = False if refine_fullres_when_fail: # fallback full-res EDT and path tfr = perf_counter() - dist_fr = _compute_edt( - roi, sampling, log=log, tag="connect:EDT(fullres)" - ) + dist_fr = _compute_edt(roi, sampling, tag="connect:EDT(fullres)") dnm = dist_fr / (dist_fr.max() if dist_fr.max() > 0 else 1.0) cost_fr = 
np.full_like(dist_fr, 1e12, dtype=float) cost_fr[roi] = 1.0 / ( @@ -639,11 +692,11 @@ def _augment_team_ds(team_name, pts_ds): m_fr[tuple(path_fr.T)] = True m = m_fr[::sz, ::sy, ::sx] ok = True - log( + logger.debug( f"[connect:{team_name}] fallback full-res path OK | {perf_counter()-tfr:.3f}s" ) else: - log( + logger.debug( f"[connect:{team_name}] Full-res ROI path also FAILED for edge {i}-{j}" ) m = None @@ -656,10 +709,10 @@ def _augment_team_ds(team_name, pts_ds): pB_ds, okB2 = _augment_team_ds("B", B_ds) okA &= okA2 okB &= okB2 - log(f"[connect] MST+paths built | {perf_counter()-t_aug:.3f}s") + logger.debug(f"[connect] MST+paths built | {perf_counter()-t_aug:.3f}s") if not (okA and okB): - log( + logger.debug( "[connect] connection failed for at least one team — consider smaller downsample or refine_fullres_when_fail." ) return ( @@ -676,7 +729,7 @@ def _augment_team_ds(team_name, pts_ds): tpost = perf_counter() pA = ndi.binary_dilation(pA, structure=struc) & roi pB = ndi.binary_dilation(pB, structure=struc) & roi - log(f"[connect] postproc dilation on paths | {perf_counter()-tpost:.3f}s") + logger.debug(f"[connect] postproc dilation on paths | {perf_counter()-tpost:.3f}s") A_aug = set(map(tuple, A_zyx)) B_aug = set(map(tuple, B_zyx)) @@ -689,7 +742,7 @@ def _augment_team_ds(team_name, pts_ds): A_aug = _seeds_from_zyx(np.array(sorted(list(A_aug)), int), seed_order) B_aug = _seeds_from_zyx(np.array(sorted(list(B_aug)), int), seed_order) - log( + logger.debug( f"[connect] done; +{len(A_aug)-len(seeds_a)} vox for A, +{len(B_aug)-len(seeds_b)} for B | total {perf_counter()-t0:.3f}s" ) return A_aug, B_aug, True, True @@ -724,12 +777,8 @@ def split_supervoxel_growing( # snapping control (NEW) snap_method: str = "kdtree", snap_kwargs: dict | None = None, - # logging - verbose: bool = False, ): - def log(msg: str): - if verbose: - print(msg, flush=True) + _prof = get_profiler() # Helpers reused from the module: _cc_label_26, _largest_component_id, 
_to_internal_zyx_volume, _from_internal_zyx_volume # _seeds_to_zyx, _compute_edt, etc. are assumed available. @@ -742,7 +791,7 @@ def _enforce_single_component(out_labels, lab, seed_pts_global, allow3=True): return 0, 0 comp, ncomp = _cc_label_26(mask) if ncomp <= 1: - log(f"[single-cc:{lab}] ncomp=1 | {perf_counter()-t:.3f}s") + logger.debug(f"[single-cc:{lab}] ncomp=1 | {perf_counter()-t:.3f}s") return 1, 0 keep_ids = set() @@ -766,7 +815,7 @@ def _enforce_single_component(out_labels, lab, seed_pts_global, allow3=True): moved = int(bad_mask.sum()) if allow3 and moved: out_labels[bad_mask] = 3 - log( + logger.debug( f"[single-cc:{lab}] kept={len(keep_ids)}, moved_to_3={moved} | {perf_counter()-t:.3f}s" ) return len(keep_ids), moved @@ -777,24 +826,38 @@ def _resolve_label3_touching_vectorized( t0 = perf_counter() comp3, n3 = _cc_label_26(out_labels == 3) n3_vox = int((out_labels == 3).sum()) - log(f"[touching] n3 comps={n3}, vox={n3_vox}") + logger.debug(f"[touching] n3 comps={n3}, vox={n3_vox}") if n3 == 0: - log(f"[touching] no label-3 components | {perf_counter()-t0:.3f}s") + logger.debug( + f"[touching] no label-3 components | {perf_counter()-t0:.3f}s" + ) return 0, 0 + # Per label-3 component, count how many of its voxels border each + # side (26-conn), and assign it to the side it borders more. This + # is the dilate-side-mask-and-intersect-with-comp3 test, but run + # only inside the label-3 bounding box (+1 halo) instead of over + # the whole volume — the +1 halo captures the side voxels just + # outside the label-3 region that the dilation reaches, so the + # counts are identical to a full-volume dilation. 
t1 = perf_counter() struc = np.ones((3, 3, 3), bool) - N1 = ndi.binary_dilation(out_labels == 1, structure=struc) & (comp3 > 0) - N2 = ndi.binary_dilation(out_labels == 2, structure=struc) & (comp3 > 0) - - cnt1 = np.bincount(comp3[N1], minlength=n3 + 1) - cnt2 = np.bincount(comp3[N2], minlength=n3 + 1) + nz = np.argwhere(comp3 > 0) + lo = np.maximum(nz.min(0) - 1, 0) + hi = np.minimum(nz.max(0) + 2, comp3.shape) + sl = tuple(slice(int(a), int(b)) for a, b in zip(lo, hi)) + comp3_sl = comp3[sl] + m3_sl = comp3_sl > 0 + N1 = ndi.binary_dilation(out_labels[sl] == 1, structure=struc) & m3_sl + N2 = ndi.binary_dilation(out_labels[sl] == 2, structure=struc) & m3_sl + cnt1 = np.bincount(comp3_sl[N1], minlength=n3 + 1) + cnt2 = np.bincount(comp3_sl[N2], minlength=n3 + 1) assign = np.zeros(n3 + 1, dtype=np.int16) # 0=undecided, 1 or 2 otherwise assign[cnt1 > cnt2] = 1 assign[cnt2 > cnt1] = 2 undec = np.where(assign[1:] == 0)[0] + 1 - log( + logger.debug( f"[touching] maj→1={int((assign==1).sum())}, maj→2={int((assign==2).sum())}, ties={len(undec)} | {perf_counter()-t1:.3f}s" ) @@ -810,8 +873,8 @@ def _resolve_label3_touching_vectorized( sA[tuple(np.array(seedsA).T)] = True sB = np.zeros_like(out_labels, bool) sB[tuple(np.array(seedsB).T)] = True - dA = _compute_edt(~sA, sampling, log=log, tag="split:EDT(dA)") - dB = _compute_edt(~sB, sampling, log=log, tag="split:EDT(dB)") + dA = _compute_edt(~sA, sampling, tag="split:EDT(dA)") + dB = _compute_edt(~sB, sampling, tag="split:EDT(dB)") closer2 = (dB < dA) & (comp3 > 0) pref2 = np.bincount(comp3[closer2], minlength=n3 + 1) @@ -821,7 +884,7 @@ def _resolve_label3_touching_vectorized( choose2 = pref2[tie_ids] > (total[tie_ids] - pref2[tie_ids]) assign[tie_ids[choose2]] = 2 assign[tie_ids[~choose2]] = 1 - log( + logger.debug( f"[touching] tie-break EDT done: to2={int(choose2.sum())}, to1={int((~choose2).sum())} | {perf_counter()-t2:.3f}s" ) @@ -835,26 +898,28 @@ def _resolve_label3_touching_vectorized( moved2 = int(mask2.sum()) 
out_labels[mask2] = 2 - log( + logger.debug( f"[touching] reassigned 3→1: {moved1}, 3→2: {moved2} | total {perf_counter()-t0:.3f}s" ) return moved1, moved2 # ---------- begin ---------- t0 = perf_counter() - log(f"[init] vol_order={vol_order}, vox_order={vox_order}, seed_order={seed_order}") - log(f"[init] input volume shape: {binary_sv.shape}") + logger.debug( + f"[init] vol_order={vol_order}, vox_order={vox_order}, seed_order={seed_order}" + ) + logger.debug(f"[init] input volume shape: {binary_sv.shape}") # Convert input volumes and sampling into internal ZYX sv_zyx, _ = _to_internal_zyx_volume(binary_sv, vol_order) sampling = _to_zyx_sampling(voxel_size, vox_order) - log(f"[init] internal shape (z,y,x): {sv_zyx.shape}") - log(f"[init] sampling (z,y,x): {sampling}") + logger.debug(f"[init] internal shape (z,y,x): {sv_zyx.shape}") + logger.debug(f"[init] sampling (z,y,x): {sampling}") # SNAP seeds to mask using the same KDTree-based method A_all = _seeds_to_zyx(seeds_a, seed_order) B_all = _seeds_to_zyx(seeds_b, seed_order) - log("[snap] snapping seeds to segment mask...") + logger.debug("[snap] snapping seeds to segment mask...") snap_cfg = dict( use_boundary=True, @@ -881,50 +946,53 @@ def _snap_ZYX(pts_zyx, tagname): sampling[1], sampling[0], ), # convert ZYX→XYZ spacing - log=log, tag=tagname, **snap_cfg, ) return snapped_xyz[:, [2, 1, 0]] - A = _snap_ZYX(A_all, "A@snap") - B = _snap_ZYX(B_all, "B@snap") - log(f"[seeds] A={len(A)}, B={len(B)}") + with _prof.profile("snap_seeds"): + A = _snap_ZYX(A_all, "A@snap") + B = _snap_ZYX(B_all, "B@snap") + logger.debug(f"[seeds] A={len(A)}, B={len(B)}") out_zyx = np.zeros_like(sv_zyx, dtype=np.int16) if A.size == 0 or B.size == 0 or not np.any(sv_zyx): - log("[seeds] missing seeds or empty SV; returning label=1 for entire SV") + logger.debug( + "[seeds] missing seeds or empty SV; returning label=1 for entire SV" + ) out_zyx[sv_zyx] = 1 return _from_internal_zyx_volume(out_zyx, vol_order) # Tight bbox ROI around mask 
with halo t_bbox = perf_counter() - Z, Y, X = sv_zyx.shape - coords = np.argwhere(sv_zyx) - z0, y0, x0 = coords.min(0) - z1, y1, x1 = coords.max(0) + 1 - z0h = max(z0 - halo, 0) - y0h = max(y0 - halo, 0) - x0h = max(x0 - halo, 0) - z1h = min(z1 + halo, Z) - y1h = min(y1 + halo, Y) - x1h = min(x1 + halo, X) - sv = sv_zyx[z0h:z1h, y0h:y1h, x0h:x1h] - A_roi = A - np.array([z0h, y0h, x0h]) - B_roi = B - np.array([z0h, y0h, x0h]) - log( + with _prof.profile("roi_crop"): + Z, Y, X = sv_zyx.shape + coords = np.argwhere(sv_zyx) + z0, y0, x0 = coords.min(0) + z1, y1, x1 = coords.max(0) + 1 + z0h = max(z0 - halo, 0) + y0h = max(y0 - halo, 0) + x0h = max(x0 - halo, 0) + z1h = min(z1 + halo, Z) + y1h = min(y1 + halo, Y) + x1h = min(x1 + halo, X) + sv = sv_zyx[z0h:z1h, y0h:y1h, x0h:x1h] + A_roi = A - np.array([z0h, y0h, x0h]) + B_roi = B - np.array([z0h, y0h, x0h]) + logger.debug( f"[crop] ROI shape (internal): {sv.shape} (halo {halo}) | {perf_counter()-t_bbox:.3f}s" ) # Build travel cost via EDT (Seung-Lab edt if available) t1 = perf_counter() - dist = _compute_edt(sv, sampling, log=log, tag="split:EDT(mask)") + dist = _compute_edt(sv, sampling, tag="split:EDT(mask)") distn = dist / dist.max() if dist.max() > 0 else dist eps = 1e-6 speed = np.clip(distn ** max(gamma_neck, 0.0), eps, 1.0) travel_cost = np.full_like(speed, 1e12, dtype=float) travel_cost[sv] = 1.0 / speed[sv] - log( + logger.debug( f"[speed] EDT + speed map | {perf_counter()-t1:.3f}s (total {perf_counter()-t0:.3f}s)" ) @@ -932,7 +1000,7 @@ def _snap_ZYX(pts_zyx, tagname): use_ds = downsample_geodesic is not None if use_ds: dz, dy, dx = map(int, downsample_geodesic) - log(f"[geodesic] downsample grid: {downsample_geodesic}") + logger.debug(f"[geodesic] downsample grid: {downsample_geodesic}") cost_ds = travel_cost[::dz, ::dy, ::dx] mask_ds = sv[::dz, ::dy, ::dx] sampling_ds = (sampling[0] * dz, sampling[1] * dy, sampling[2] * dx) @@ -948,9 +1016,9 @@ def _to_ds(pts): A_sub = _to_ds(A_roi) B_sub = _to_ds(B_roi) - 
log(f"[geodesic] seeds on DS grid: A={len(A_sub)}, B={len(B_sub)}") + logger.debug(f"[geodesic] seeds on DS grid: A={len(A_sub)}, B={len(B_sub)}") if len(A_sub) == 0 or len(B_sub) == 0: - log("[geodesic] DS removed all seeds; falling back to full-res") + logger.debug("[geodesic] DS removed all seeds; falling back to full-res") use_ds = False if not use_ds: cost_ds = travel_cost @@ -967,7 +1035,7 @@ def _to_ds(pts): TB, _ = mcpB.find_costs(B_sub, find_all_ends=False) TA = np.where(mask_ds, TA, np.inf) TB = np.where(mask_ds, TB, np.inf) - log( + logger.debug( f"[geodesic] TA/TB computed | {perf_counter()-t2:.3f}s (total {perf_counter()-t0:.3f}s)" ) @@ -982,8 +1050,8 @@ def _to_ds(pts): band = ndi.binary_dilation(band, structure=ball(nb_dilate)) & mask_ds if band.sum() < 64: band = mask_ds.copy() - log("[band] tiny band -> using full ROI on current grid") - log( + logger.debug("[band] tiny band -> using full ROI on current grid") + logger.debug( f"[band] voxels: {int(band.sum())} | {perf_counter()-t3:.3f}s (total {perf_counter()-t0:.3f}s)" ) @@ -1003,7 +1071,7 @@ def _to_ds(pts): sub_labels_ds[z, y, x] = 1 for z, y, x in B_sub: sub_labels_ds[z, y, x] = 2 - log( + logger.debug( f"[label] DS labeling done | {perf_counter()-t4:.3f}s (total {perf_counter()-t0:.3f}s)" ) @@ -1015,7 +1083,7 @@ def _to_ds(pts): sub_labels[z, y, x] = 1 for z, y, x in B_roi: sub_labels[z, y, x] = 2 - log(f"[label] upsampled DS→full ROI") + logger.debug(f"[label] upsampled DS→full ROI") else: sub_labels = sub_labels_ds @@ -1023,34 +1091,50 @@ def _to_ds(pts): out_zyx[sv_zyx] = 1 out_zyx[z0h:z1h, y0h:y1h, x0h:x1h][sub_labels == 1] = 1 out_zyx[z0h:z1h, y0h:y1h, x0h:x1h][sub_labels == 2] = 2 - log("[writeback] labels written to full volume") - - # Enforce single CC per label - if enforce_single_cc: - keptA, movedA = _enforce_single_component( - out_zyx, 1, A, allow3=allow_third_label - ) - keptB, movedB = _enforce_single_component( - out_zyx, 2, B, allow3=allow_third_label - ) - log( - 
f"[single-cc] label1 kept {keptA}, moved {movedA} -> 3; label2 kept {keptB}, moved {movedB} -> 3" - ) + logger.debug("[writeback] labels written to full volume") - # Resolve 3-touching - moved1, moved2 = _resolve_label3_touching_vectorized(out_zyx, A, B, sampling) - if moved1 or moved2: + # Enforce single CC per label (full res). The upsampled labeling can + # fragment under the foreground mask, so enforcement must run here, + # not on the DS grid. The label-3 resolution inside uses + # cc3d.contacts for adjacency instead of two full-volume dilations. + with _prof.profile("enforce_cc"): if enforce_single_cc: - keptA, movedA = _enforce_single_component( - out_zyx, 1, A, allow3=allow_third_label - ) - keptB, movedB = _enforce_single_component( - out_zyx, 2, B, allow3=allow_third_label - ) - log( - f"[single-cc 2nd] label1 kept {keptA}, moved {movedA}; label2 kept {keptB}, moved {movedB}" + with _prof.profile("label1"): + keptA, movedA = _enforce_single_component( + out_zyx, 1, A, allow3=allow_third_label + ) + with _prof.profile("label2"): + keptB, movedB = _enforce_single_component( + out_zyx, 2, B, allow3=allow_third_label + ) + logger.debug( + f"[single-cc] label1 kept {keptA}, moved {movedA} -> 3; label2 kept {keptB}, moved {movedB} -> 3" ) + with _prof.profile("resolve3"): + # Only reassign strays if any exist. Label 3 is created solely + # by _enforce_single_component above; with none present, + # _resolve_label3_touching_vectorized would early-return after a + # full-volume CC scan, so the np.any check skips that scan. 
+ moved1 = moved2 = 0 + n3 = int(np.count_nonzero(out_zyx == 3)) + logger.note(f"resolve3: label-3 stray voxels: {n3}") + if n3: + moved1, moved2 = _resolve_label3_touching_vectorized( + out_zyx, A, B, sampling + ) + if moved1 or moved2: + if enforce_single_cc: + keptA, movedA = _enforce_single_component( + out_zyx, 1, A, allow3=allow_third_label + ) + keptB, movedB = _enforce_single_component( + out_zyx, 2, B, allow3=allow_third_label + ) + logger.debug( + f"[single-cc 2nd] label1 kept {keptA}, moved {movedA}; label2 kept {keptB}, moved {movedB}" + ) + # Final check for lab in (1, 2): _, ncomp = _cc_label_26(out_zyx == lab) @@ -1059,9 +1143,9 @@ def _to_ds(pts): if raise_if_multi_cc: raise ValueError(msg) else: - log(msg) + logger.debug(msg) - log(f"[done] total elapsed {perf_counter()-t0:.3f}s") + logger.debug(f"[done] total elapsed {perf_counter()-t0:.3f}s") return _from_internal_zyx_volume(out_zyx, vol_order) @@ -1162,38 +1246,43 @@ def build_kdtrees_by_label( def build_coords_by_label( vol: np.ndarray, *, + labels: Optional[Iterable[int]] = None, background: int = 0, min_points: int = 1, dtype: np.dtype = np.float32, ) -> Dict[int, np.ndarray]: - """Group voxel coordinates by label without building kdtrees. - - Returns mapping label -> (N, 3) coordinate array in (z, y, x) order. + """Group voxel coords by label via ``fastremap.point_cloud``. + + Returns ``{label: (M_label, 3) coords in (z, y, x)}`` cast to + ``dtype``. ``fastremap.point_cloud`` is a C++ single-pass + implementation that emits ``uint16`` coords grouped by label, + treating ``0`` as background. + + ``labels`` restricts the output dict; the underlying C++ scan + visits every voxel regardless (faster and lighter than a + label-filtered Python scan). ``min_points`` drops labels with + fewer than that many voxels. ``background != 0`` removes that + label from the result after the call. 
""" if vol.ndim != 3: raise ValueError("`vol` must be a 3D array.") - Z, Y, X = vol.shape - - flat = vol.ravel() - nz = np.flatnonzero(flat) if background == 0 else np.flatnonzero(flat != background) - if nz.size == 0: - return {} - - labels = flat[nz] - z, y, x = np.unravel_index(nz, (Z, Y, X)) - coords = np.column_stack((z, y, x)).astype(dtype, copy=False) - - order = np.argsort(labels, kind="mergesort") - labels_sorted = labels[order] - starts = np.flatnonzero(np.r_[True, labels_sorted[1:] != labels_sorted[:-1]]) - ends = np.r_[starts[1:], labels_sorted.size] + raw = fastremap.point_cloud(vol) + if background != 0: + raw.pop(background, None) + + if labels is not None: + wanted = {int(x) for x in labels} + if not wanted: + return {} + items = ((k, v) for k, v in raw.items() if int(k) in wanted) + else: + items = raw.items() result: Dict[int, np.ndarray] = {} - for s, e in zip(starts, ends): - n = e - s - if n < min_points: + for k, coords in items: + if coords.shape[0] < min_points: continue - result[int(labels_sorted[s])] = coords[order[s:e]] + result[int(k)] = coords.astype(dtype, copy=False) return result @@ -1266,56 +1355,58 @@ def split_supervoxel_helper( source_coords: np.ndarray, sink_coords: np.ndarray, voxel_size: tuple, - verbose: bool = False, ): voxel_size = np.array(voxel_size) downsample = voxel_size.max() // voxel_size + _prof = get_profiler() # 1) Connect seed teams first - A_aug, B_aug, okA, okB = connect_both_seeds_via_ridge( - binary_seg, - source_coords, - sink_coords, - voxel_size=voxel_size, - downsample=downsample, - vol_order="xyz", - vox_order="xyz", - seed_order="xyz", - snap_method="kdtree", - snap_kwargs=dict( - use_boundary=False, # disables boundary-only snapping for maximum safety - downsample=False, # avoids losing candidates - method="kdtree", - ), - verbose=verbose, - ) + with _prof.profile("connect_seeds"): + A_aug, B_aug, okA, okB = connect_both_seeds_via_ridge( + binary_seg, + source_coords, + sink_coords, + 
voxel_size=voxel_size, + downsample=downsample, + vol_order="xyz", + vox_order="xyz", + seed_order="xyz", + snap_method="kdtree", + snap_kwargs=dict( + use_boundary=False, # disables boundary-only snapping for maximum safety + downsample=False, # avoids losing candidates + use_bbox=True, # window candidates to a box around the seeds; same result + method="kdtree", + ), + ) if not (okA and okB): raise RuntimeError( "In-mask connection failed for at least one team; skipping split." ) # 2) Run the corridor-free splitter with same snapping settings - return split_supervoxel_growing( - binary_seg, - A_aug, - B_aug, - voxel_size=voxel_size, - vol_order="xyz", - vox_order="xyz", - seed_order="xyz", - halo=1, - gamma_neck=1.6, - narrow_band_rel=0.08, - nb_dilate=1, - downsample_geodesic=(1, 2, 2), - enforce_single_cc=True, - raise_if_seed_split=True, - raise_if_multi_cc=True, - verbose=verbose, - snap_method="kdtree", - snap_kwargs=dict( - use_boundary=False, # match the connector for consistency - downsample=False, - method="kdtree", - ), - ) + with _prof.profile("split_growing"): + return split_supervoxel_growing( + binary_seg, + A_aug, + B_aug, + voxel_size=voxel_size, + vol_order="xyz", + vox_order="xyz", + seed_order="xyz", + halo=1, + gamma_neck=1.6, + narrow_band_rel=0.08, + nb_dilate=1, + downsample_geodesic=(1, 2, 2), + enforce_single_cc=True, + raise_if_seed_split=True, + raise_if_multi_cc=True, + snap_method="kdtree", + snap_kwargs=dict( + use_boundary=False, # match the connector for consistency + downsample=False, + use_bbox=True, # window candidates to a box around the seeds; same result + method="kdtree", + ), + ) diff --git a/pychunkedgraph/debug/sv_split.py b/pychunkedgraph/graph/sv_split/debug.py similarity index 98% rename from pychunkedgraph/debug/sv_split.py rename to pychunkedgraph/graph/sv_split/debug.py index 8caca0fc2..ba6fe51a5 100644 --- a/pychunkedgraph/debug/sv_split.py +++ b/pychunkedgraph/graph/sv_split/debug.py @@ -5,10 +5,10 @@ import numpy 
as np import fastremap -from ..app.app_utils import handle_supervoxel_id_lookup -from ..graph import attributes -from ..graph.chunkedgraph import ChunkedGraph -from ..graph.edges import Edges +from pychunkedgraph.app.app_utils import handle_supervoxel_id_lookup +from pychunkedgraph.graph import attributes +from pychunkedgraph.graph.chunkedgraph import ChunkedGraph +from pychunkedgraph.graph.edges import Edges def get_subgraph_edges(cg: ChunkedGraph, root_id, bbox): @@ -339,8 +339,6 @@ def trace_stale_sv(cg: ChunkedGraph, sv_id, bbox=None, root_id=None): )[0] chunk_ids = np.unique(cg.get_chunk_ids_from_node_ids(l2ids)) - from ..io.edges import get_chunk_edges - chunk_edges_d = cg.read_chunk_edges(chunk_ids) chunk_edges_all = reduce( lambda x, y: x + y, chunk_edges_d.values(), Edges([], []) diff --git a/pychunkedgraph/graph/sv_split/design.md b/pychunkedgraph/graph/sv_split/design.md new file mode 100644 index 000000000..c44273f63 --- /dev/null +++ b/pychunkedgraph/graph/sv_split/design.md @@ -0,0 +1,64 @@ +# SV splitting — design rationale + +Why the [cut](algorithm.md) is shaped the way it is. + +## Geodesic region-grow, not a voxel graph-mincut + +The split assigns each voxel to the nearer seed by geodesic travel cost rather +than solving a min-cut on a voxel adjacency graph. A region grow over an +anisotropy-aware geodesic metric follows the object's medial geometry and +produces a smooth boundary near the midline, without building and cutting a +large per-voxel graph for every split. + +## Snap seeds to the boundary + +Cut seeds come from operator clicks / mincut output and need not land exactly +on a foreground voxel. Snapping each seed (via KDTree) to the nearest occupied +voxel makes the grow well-defined regardless of where the seed falls, and keeps +the two arrival fields rooted inside the object. 
+ +## Single-CC enforced at full resolution + +Each output side must be a single connected component because the downstream +writer assigns one new supervoxel id per label and does no CC of its own (see +the [single-CC invariant](algorithm.md#the-single-cc-invariant)). +Enforcement runs at full resolution: if the geodesic grow is done on a +downsampled grid, the result is upsampled *first*, because upsampling under the +foreground mask can fragment a label into disconnected pieces — enforcing CC +before that would let the fragments through. + +## Resolve strays, don't drop them + +Small components cut off from a seeded body are relabeled to a transient label +and reassigned to whichever side they border more (EDT tie-break) rather than +discarded. Dropping voxels would lose mass; merging blindly could bridge the +two sides. Border-count reassignment keeps every voxel while respecting the +cut. + +## Label-3 dilation confined to its bounding box + +Reassigning a stray component only needs its local border with each side, so +the dilation is run inside the component's own bounding box rather than over +the whole volume. The per-component border counts are identical either way, so +this is an exact optimization, not an approximation. + +## Single-pass coord extraction, in-place masking + +`build_coords_by_label` groups voxels by label in a single `fastremap.point_cloud` pass (after `update_edges` +masks the volume in place), instead of allocating a volume-sized boolean array per label. +The cut likewise works inside the foreground bounding box. Both choices bound +work to the foreground extent rather than the chunk volume. + +## Reads pinned to `parent_ts` + +All graph reads during a split are pinned to the operation's `parent_ts` so a +replay sees the same graph state and allocates the same new supervoxel ids. +This is what makes interrupted splits safe to re-run (see +[Recovery](recovery.md)).
+ +## Related docs + +- [Overview](README.md) +- [Algorithm](algorithm.md) +- [Edges](edges.md) +- [Recovery](recovery.md) diff --git a/pychunkedgraph/graph/sv_split/edges.md b/pychunkedgraph/graph/sv_split/edges.md new file mode 100644 index 000000000..1552bd3a6 --- /dev/null +++ b/pychunkedgraph/graph/sv_split/edges.md @@ -0,0 +1,139 @@ +# Edge updates after a supervoxel split + +## Context + +A supervoxel split rewrites voxels inside a bbox: a single old SV is replaced by N new fragments (one per chunk × per side of the cut). Every atomic edge that referenced the old SV — to neighbors inside the same root, to neighbors in a different root, and to other pieces of the same physical supervoxel — must now reference an appropriate fragment instead, or the graph hierarchy diverges from the new segmentation. + +Edge update is the second half of `split_supervoxel`. The first half produced a labeled bbox, an `old_new_map` (`old_sv_id → set[new_sv_ids]`), and a `new_id_label_map` (`new_sv_id → cut-side label`). This document covers what happens from there. + +## Algorithm overview + +``` +inputs from voxel-level split + ├─ new_seg bbox volume with new SV IDs in place of old + ├─ old_new_map which old SVs got split, and into which new IDs + └─ new_id_label_map for each new ID, which side of the cut it's on + +update_edges (sv_split/edges.py): + 1. fetch atomic subgraph inside bbox, rooted at the rep's root + 2. dedupe edges, drop self-loops + 3. group by partner-root vs split-root → active / inactive + 4. for each old SV: + inactive partners → broadcast edge to every fragment + active partners → expand split partners, match by label/proximity + intra-fragment → low-affinity edges between every fragment pair + 5. validate (no cross-label inf bridges, no self-loops, completeness) + 6. return new (edges, affinities, areas) + +add_new_edges (sv_split/edges.py): + 1. duplicate bidirectional, group by L2 parent chunk + 2. 
per chunk: append to SplitEdges (history) and rewrite + CompactedSplitEdges (snapshot, with stale rows filtered) +``` + +## Inputs to `update_edges` + +- `cg, root_id, bbox` — the rep's root and the bbox the voxel-level cut acted on. +- `new_seg` — segmentation in the read window (bbox + 1-voxel shell). The shell is what makes anchor lookups work for unsplit pieces of the rep on the other side of a chunk boundary; without it, cross-chunk edges from those pieces would route to whatever happens to lie at the boundary face, not to the actual fragment the cut produced. +- `old_new_map` — drives which edges need re-routing. +- `new_id_label_map` — used to pair fragments with the same cut-side label across cross-chunk edges. + +`update_edges` calls `cg.get_subgraph(root_id, bbox, bbox_is_coordinate=True)`. This returns every atomic edge whose endpoint sits in the bbox under the rep's root. That set already includes both intra-cut edges (between split SVs) and the cross-chunk-shell edges to neighbors outside the rewritten region. + +After fetch, edges are sorted within each pair, deduped, and self-loops filtered. The remaining set is the input to classification. + +## Classification + +For each edge, the partner's root determines the routing path. `sv_root_map` is built from one batched `cg.get_roots(...)` over all unique partners. + +### Inactive partner (`partner_root != root_id`) + +The partner sits in a different agglomerated object. The split's cut-side has no semantic relationship to that neighbor — *any* fragment of the old SV that touched the neighbor's voxels still touches them after the split. **Broadcast**: for each old SV split into N fragments, copy the edge to every fragment, preserving affinity and area. + +This intentionally over-creates edges. They cost nothing if both endpoints stay in different roots forever; they collapse harmlessly into a single root-level edge if the two roots later merge. 
+ +### Active partner (`partner_root == root_id`) + +The partner is inside the same agglomerated object as the rep — the partner is either: + +- another piece of the same physical SV (cross-chunk-connected), +- a different SV in the same root reachable via L2 hierarchy. + +For active partners, edges are routed based on affinity type: + +#### Inf-affinity, partner also split + +The partner SV is itself in `old_new_map` (e.g. it's another piece of the rep that was rewritten). We need each new fragment of the old SV to connect to the *matching-label* new fragment of the partner — the one on the same cut-side. `_match_by_label` does this lookup via `new_id_label_map`. If no fragment of the partner shares the source SV's label (rare, indicates a partial split), it falls back to the closest fragment by distance. + +#### Inf-affinity, partner unsplit + +This is the cross-chunk edge to a piece of the rep that the bbox didn't include — by construction with the seed-driven bbox, these are the rep's far-away pieces that keep their old IDs. The unsplit partner has no `new_id_label_map` entry. + +**Critical**: do *not* broadcast this edge to all fragments. An unsplit partner connected via inf-affinity to fragments on both sides of the cut would form an uncuttable bridge — a future mincut on this object would route through `frag_a → unsplit_partner → frag_b` with infinite affinity and never separate them. So `_match_inf_unsplit` assigns the edge to exactly one fragment: the one closest to the partner. + +`validate_split_edges` enforces this with check (A): no inf-affinity edge from an unsplit partner to fragments with different cut-side labels. + +#### Finite-affinity (regular) + +Real adjacency edges between SVs based on per-pair affinity. `_match_by_proximity` assigns the edge to *every* fragment within `cg.meta.sv_split_threshold` voxels of the partner, falling back to the closest fragment if none qualify.
Multiple fragments may legitimately neighbor the partner; the threshold preserves the original adjacency where it actually exists. + +### Intra-fragment edges + +For each old SV split into multiple new fragments, add a low-affinity (0.001) edge between every pair of fragments. These are cuttable by future mincut operations — they record that the fragments share a graph-level neighborhood (they came from the same SV) without forcing them to stay agglomerated. Without these edges, an entirely-disconnected fragment of an old SV would have no link to the rest of the object; with them, the standard mincut machinery handles the relationship correctly. + +## Distance computation + +Distances drive both proximity matching and the closest-fragment fallback. Each fragment gets a `cKDTree` over its voxels (built from `build_coords_by_label(new_seg)`). + +- **Partner inside bbox**: build a kdtree on the partner's voxels too. For each fragment, the smaller-tree-queries-larger heuristic minimizes work; result is the minimum voxel distance. +- **Active partner outside bbox**: the partner's voxels aren't in `new_seg`. `_compute_boundary_distances` uses the partner's chunk coordinate to determine which face of the source chunk the edge crosses, then measures each fragment's distance to that boundary plane. This is an over-estimate for non-boundary-aligned partners but it's the only signal available without extra reads. + +## Validation + +`validate_split_edges` checks four invariants and raises `PostconditionError` on any violation. Failures abort the operation cleanly under the indefinite L2 chunk lock; the recovery flow then handles cleanup. 
+ +| Check | Why | +|-------|-----| +| (A) No inf-affinity bridges between cut-sides via an unsplit partner | Would be uncuttable by future mincuts | +| (B) No self-loops | Indicates a routing bug; would skew degree counts and break some traversal assumptions | +| (C) Every old SV has at least one replacement edge from its fragments | Catches old SVs that vanished from the edge set entirely (would orphan them in the hierarchy) | +| (D) All fragment pairs of each old SV are connected | Confirms the intra-fragment low-affinity edges were emitted | + +These run before any bigtable write, so the validation is the last line of defense before the writes commit under the lock. + +## Persisting: `add_new_edges` + +The new edges are batched into bigtable per L2 chunk. Two columns get written per chunk per op: + +### `SplitEdges` (history) + +An append-only column. Each split op writes its new edges as a fresh cell with the op's logical timestamp. Time-travel reads at any timestamp T walk all cells with `ts ≤ T`, then apply the stale-edge resolution path to filter out edges whose endpoints have been superseded by later ops. This is the authoritative store for historical reads. + +### `CompactedSplitEdges` (snapshot) + +A latest-only column for fast current-time reads. On each op: + +1. Read the previous compacted cell (if any) plus its matching `CompactedAffinity` and `CompactedArea`. +2. Filter out rows whose endpoints reference any old SV in `old_new_map.keys()` (these are the SVs that just got split — their edges are stale). +3. Concatenate the new rows. +4. Write the whole thing as one fresh cell. + +Current-time readers can take this single cell directly without history walks or stale-edge resolution. + +The chunk grouping uses each edge's first endpoint's L2 parent chunk: `cg.get_chunk_ids_from_node_ids(cg.get_parents(nodes))`. Routing by parent chunk (not the SV's own L1 chunk) is correct — the edge belongs to the chunk where its endpoint lives in the L2 hierarchy.
Bidirectional duplication ensures every edge is owned by both endpoints' chunks; readers picking up either side find it. + +Both writes use `time_stamp=task.operation_ts`, so all rows from one op land at the same logical time. Concurrent SV-splits on disjoint chunks don't interfere because they write disjoint chunk rows. + +## Invariants + +- For every old SV in `old_new_map`, every atomic edge that referenced it in the pre-split graph has at least one corresponding edge among its fragments after the split. +- No inf-affinity edge crosses cut-sides through an unsplit partner. +- Every cross-chunk piece of the rep that the bbox didn't include keeps its old ID and its existing edges resolve unchanged (because no edge in those rows references the now-split SVs at endpoints — the routing only touches edges whose endpoints are in the bbox or its 1-voxel shell). +- `SplitEdges` and `CompactedSplitEdges` agree at the latest timestamp: the compacted snapshot is the result of replaying the history through the stale-edge filter. + +## Related docs + +- [Overview](README.md) +- [Algorithm](algorithm.md) — the cut that produces the labeling this step consumes. +- [Design](design.md) +- [Recovery](recovery.md) diff --git a/pychunkedgraph/graph/edges_sv.py b/pychunkedgraph/graph/sv_split/edges.py similarity index 93% rename from pychunkedgraph/graph/edges_sv.py rename to pychunkedgraph/graph/sv_split/edges.py index ea9354990..2d49af158 100644 --- a/pychunkedgraph/graph/edges_sv.py +++ b/pychunkedgraph/graph/sv_split/edges.py @@ -23,8 +23,8 @@ Distance computation: For partners within the segmentation bbox, distances are precomputed via kdtree pairwise distances. For active partners outside the bbox (e.g. - cross-chunk fragments excluded by _get_whole_sv's bbox clipping), distances - are computed from each new fragment's kdtree to the partner's chunk boundary. 
+ cross-chunk fragments not in the rep's CC member set), distances are + computed from each new fragment's kdtree to the partner's chunk boundary. """ from __future__ import annotations @@ -38,10 +38,11 @@ import numpy as np from pychunkedgraph import get_logger +from pychunkedgraph.profiler import get_profiler from pychunkedgraph.graph import attributes, basetypes, serializers from pychunkedgraph.graph.exceptions import PostconditionError from scipy.spatial import cKDTree -from pychunkedgraph.graph.cutting_sv import build_coords_by_label +from .cutting import build_coords_by_label from pychunkedgraph.graph.edges import Edges if TYPE_CHECKING: @@ -349,15 +350,11 @@ def update_edges( new_seg: np.ndarray, old_new_map: dict, new_id_label_map: dict = None, + parent_ts: datetime = None, ): old_new_map = dict(old_new_map) - t0 = time.time() - coords_by_label = build_coords_by_label(new_seg) + _prof = get_profiler() new_ids = np.array(list(set.union(*old_new_map.values())), dtype=basetypes.NODE_ID) - new_kdtrees = [cKDTree(coords_by_label[int(k)]) for k in new_ids] - logger.note( - f"build_coords {len(coords_by_label)} labels, {len(new_ids)} fragment trees ({time.time() - t0:.2f}s)" - ) t0 = time.time() _, edges_tuple = cg.get_subgraph(root_id, bbox, bbox_is_coordinate=True) @@ -380,10 +377,31 @@ def update_edges( t0 = time.time() all_edge_svs = np.unique(edges) - all_roots = cg.get_roots(all_edge_svs) + all_roots = cg.get_roots(all_edge_svs, time_stamp=parent_ts) sv_root_map = dict(zip(all_edge_svs, all_roots)) logger.note(f"get_roots {len(all_edge_svs)} svs ({time.time() - t0:.2f}s)") + # Coords are only ever read for new fragment ids (kdtrees) and for + # partners queried via coords_by_label.get(...) in _get_new_edges. + # Partners can only come from subgraph-edge endpoints, so this + # union is a tight superset of every key that gets looked up. 
+ t0 = time.time() + with _prof.profile("build_coords"): + # Zero out every label whose coords nothing downstream will + # query, in place. fastremap.point_cloud (called by + # build_coords_by_label) groups by every nonzero label in the + # vol, so trimming the input is the only way to shrink its + # C++ scan; the labels= kwarg only filters the result dict + # after the scan. Caller does not read seg after update_edges + # returns, so the in-place mutation is safe. + wanted_labels = np.union1d(new_ids, all_edge_svs) + fastremap.mask_except(new_seg, list(wanted_labels), in_place=True) + coords_by_label = build_coords_by_label(new_seg) + new_kdtrees = [cKDTree(coords_by_label[int(k)]) for k in new_ids] + logger.note( + f"build_coords {len(coords_by_label)} labels, {len(new_ids)} fragment trees ({time.time() - t0:.2f}s)" + ) + t0 = time.time() result = _get_new_edges( (edges, affinities, areas), diff --git a/pychunkedgraph/graph/sv_split/edits.py b/pychunkedgraph/graph/sv_split/edits.py new file mode 100644 index 000000000..13b9b7878 --- /dev/null +++ b/pychunkedgraph/graph/sv_split/edits.py @@ -0,0 +1,775 @@ +""" +Manage new supervoxels after a supervoxel split. 
+""" + +import time +from dataclasses import dataclass +from datetime import datetime +from collections import defaultdict +from typing import TYPE_CHECKING, List, Tuple + +import fastremap +import numpy as np + +from pychunkedgraph import get_logger +from pychunkedgraph.profiler import get_profiler +from pychunkedgraph.graph import ( + attributes, + cache as cache_utils, + basetypes, + serializers, +) +from pychunkedgraph.graph.chunks.utils import chunks_overlapping_bbox +from .cutting import split_supervoxel_helper +from .edges import update_edges, add_new_edges +from pychunkedgraph.graph.utils import get_local_segmentation + +if TYPE_CHECKING: + from pychunkedgraph.graph.chunkedgraph import ChunkedGraph + +logger = get_logger(__name__) + + +@dataclass +class SvSplitTask: + """One SV-split task per cross-chunk rep. + + Produced by `plan_sv_splits` (pure, no IO), consumed by + `split_supervoxel`. `src_mask`/`sink_mask` are positional masks + back into the caller's `source_ids`/`sink_ids` arrays so the + aggregator can splice the per-task fresh IDs in at the right + positions. + """ + + sv_id: int + src_coords: np.ndarray + sink_coords: np.ndarray + src_mask: np.ndarray + sink_mask: np.ndarray + bbs: np.ndarray + bbe: np.ndarray + + +@dataclass +class _SplitCtx: + """Per-task context shared across the split stage helpers. + + Holds the inputs that every stage threads through unchanged. `seg` is + mutated in place across stages (fresh IDs written, then root mask); + the reference is stable, so storing it here is sound. 
+ """ + + cg: "ChunkedGraph" + seg: np.ndarray + bbs: np.ndarray + bbe: np.ndarray + bbs_: np.ndarray + bbe_: np.ndarray + sv_id: int + sv_ids: np.ndarray + source_coords: np.ndarray + sink_coords: np.ndarray + operation_id: int + time_stamp: datetime + parent_ts: datetime + + +@dataclass +class _ApplyResult: + """Outputs of `_apply_and_capture` consumed by the orchestrator.""" + + old_new_map: dict + new_id_label_map: dict + seg_write_pairs: List[Tuple[Tuple[slice, slice, slice], np.ndarray]] + src_new_ids: np.ndarray + sink_new_ids: np.ndarray + + +@dataclass +class SvSplitOutcome: + """Output of `split_supervoxel` for one task. Aggregated into + `SplitResult` by `split_supervoxels`.""" + + seg_bbox: Tuple[np.ndarray, np.ndarray] + src_new_ids: np.ndarray + sink_new_ids: np.ndarray + # Per-chunk OCDBT write payloads for this task. + seg_write_pairs: List[Tuple[Tuple[slice, slice, slice], np.ndarray]] + bigtable_rows: list + + +@dataclass +class SplitResult: + """Pure planner output of `split_supervoxels`. + + The caller (`MulticutOperation._apply`) performs the actual writes + under the L2 chunk locks: + - `seg_writes` is fed to `write_seg_chunks` as one flat parallel batch. + - `bigtable_rows` is written via `cg.client.write` in one batch. + """ + + seg_bboxes: List[Tuple[np.ndarray, np.ndarray]] + source_ids_fresh: np.ndarray + sink_ids_fresh: np.ndarray + # Flat list across all tasks: (voxel_slices, data_block) per OCDBT + # chunk write. `voxel_slices` is a 3-tuple of `slice` objects; the + # caller appends the channel slice and writes to `meta.ws_ocdbt`. + seg_writes: List[Tuple[Tuple[slice, slice, slice], np.ndarray]] + bigtable_rows: list + + +def _coords_bbox( + cg: "ChunkedGraph", + src_coords_rep: np.ndarray, + sink_coords_rep: np.ndarray, +) -> tuple: + """Base-voxel bbox covering the user's source/sink seeds plus a margin. 
+ + The cut surface lives between the user-placed source and sink + voxels; voxels of the rep that are far from those seeds never + contribute to the cut. So the read region is the seeds' envelope, + not the rep's full chunk envelope — for a physical SV cut into many + pieces across chunks, this can be orders of magnitude smaller. + + The margin is one CG chunk on each side. It matches the existing + L2 chunk lock margin and the 1-voxel shell read in + `split_supervoxel`, and gives `split_supervoxel_helper` headroom + around the seeds for the cut surface to travel along the SV. + + Pieces of the rep that fall outside the bbox keep their existing + IDs — they aren't read here and aren't rewritten. Cross-chunk-edge + routing for boundary-adjacent pieces is handled by the 1-voxel + shell at read time; cross-chunk edges entirely between unsplit + pieces don't change because their IDs don't change. + """ + coords = np.concatenate([src_coords_rep, sink_coords_rep], axis=0) + margin = np.array(cg.meta.graph_config.CHUNK_SIZE, dtype=int) + vol_start = cg.meta.voxel_bounds[:, 0] + vol_end = cg.meta.voxel_bounds[:, 1] + bbs = np.clip(coords.min(axis=0) - margin, vol_start, vol_end) + bbe = np.clip(coords.max(axis=0) + margin, vol_start, vol_end) + return bbs, bbe + + +def _l2_chunks_for_splits(cg: "ChunkedGraph", per_rep_bboxes: list) -> list[int]: + """Layer-2 chunk IDs every rep's split will read or write. + + Reads extend 1 voxel past `[bbs, bbe]` so `update_edges` has anchor + voxels for cross-chunk neighbors; the lock must cover those neighbor + chunks too, hence the `bbs - 1` / `bbe + 1` expansion. Clipped to + volume bounds so a bbox on the volume edge doesn't enumerate phantom + negative-index chunks. Sorted for deterministic lock-acquire order + (L2ChunkLock relies on sorted input for deadlock avoidance). 
+ """ + vol_start = cg.meta.voxel_bounds[:, 0] + vol_end = cg.meta.voxel_bounds[:, 1] + chunk_size = cg.meta.graph_config.CHUNK_SIZE + chunk_coords = set() + for bbs, bbe in per_rep_bboxes: + read_lo = np.clip(bbs - 1, vol_start, vol_end) + read_hi = np.clip(bbe + 1, vol_start, vol_end) + chunk_coords.update( + chunks_overlapping_bbox( + read_lo, read_hi, chunk_size, origin=vol_start + ).keys() + ) + return sorted( + int(cg.get_chunk_id(layer=2, x=x, y=y, z=z)) for (x, y, z) in chunk_coords + ) + + +def _overlapping_reps( + *, + sv_remapping: dict, + source_ids: np.ndarray, + sink_ids: np.ndarray, + source_coords: np.ndarray, + sink_coords: np.ndarray, +): + """Yield per-rep data for every rep that links source and sink. + + A rep is a cross-chunk-representative SV shared by at least one + source and one sink in `sv_remapping`. These are the SVs that must + be split before the multicut can partition source from sink. + + Yields `(sv_id, src_coords_rep, sink_coords_rep, src_mask, sink_mask)`: + sv_id — one of the rep's source SV IDs, used as the + seed for `split_supervoxel`. + src_coords_rep — slice of source_coords whose SV maps to this rep. + sink_coords_rep — slice of sink_coords whose SV maps to this rep. + src_mask — positional boolean mask over source_ids; the + caller uses it to splice per-rep results back + into the full source arrays. + sink_mask — same, for sink_ids. + + Keyword-only signature — positional source/sink args of the same + shape are easy to swap without noticing. 
+ """ + sources_remapped = fastremap.remap( + source_ids, sv_remapping, preserve_missing_labels=True, in_place=False + ) + sinks_remapped = fastremap.remap( + sink_ids, sv_remapping, preserve_missing_labels=True, in_place=False + ) + overlap_mask = np.isin(sources_remapped, sinks_remapped) + for rep in np.unique(sources_remapped[overlap_mask]): + src_mask = sources_remapped == rep + sink_mask = sinks_remapped == rep + yield ( + source_ids[src_mask][0], + source_coords[src_mask], + sink_coords[sink_mask], + src_mask, + sink_mask, + ) + + +def plan_sv_splits( + cg: "ChunkedGraph", + *, + sv_remapping: dict, + source_ids: np.ndarray, + sink_ids: np.ndarray, + source_coords: np.ndarray, + sink_coords: np.ndarray, +) -> Tuple[List[SvSplitTask], list]: + """Compute one `SvSplitTask` per rep and the L2 chunk set the splits + will touch. + + Pure function — no bigtable/OCDBT IO, no locks. Lets the caller + acquire the L2 chunk locks (both temporal and indefinite) around + `split_supervoxels` without recomputing the plan inside. + + Returns `(tasks, chunk_ids)` — `tasks` feeds `split_supervoxels`, + `chunk_ids` is the sorted union of read-expanded L2 chunks the full + operation touches. 
+ """ + tasks: List[SvSplitTask] = [] + for ( + sv_id, + src_coords_rep, + sink_coords_rep, + src_mask, + sink_mask, + ) in _overlapping_reps( + sv_remapping=sv_remapping, + source_ids=source_ids, + sink_ids=sink_ids, + source_coords=source_coords, + sink_coords=sink_coords, + ): + bbs, bbe = _coords_bbox(cg, src_coords_rep, sink_coords_rep) + tasks.append( + SvSplitTask( + sv_id=sv_id, + src_coords=src_coords_rep, + sink_coords=sink_coords_rep, + src_mask=src_mask, + sink_mask=sink_mask, + bbs=bbs, + bbe=bbe, + ) + ) + chunk_ids = _l2_chunks_for_splits(cg, [(t.bbs, t.bbe) for t in tasks]) + return tasks, chunk_ids + + +def split_supervoxels( + cg: "ChunkedGraph", + *, + tasks: List[SvSplitTask], + sv_remapping: dict, + source_ids: np.ndarray, + sink_ids: np.ndarray, + operation_id: int, + timestamp: datetime = None, + parent_ts: datetime = None, +) -> SplitResult: + """Pure planner for the SV-split step. Returns a `SplitResult` with + all the data the caller needs to persist under locks. + + Does **not** write — the caller (`MulticutOperation._apply`) owns + the L2 chunk lock lifecycle and fires the OCDBT + bigtable writes + inside `IndefiniteL2ChunkLock`. + + Must be called inside the caller's `L2ChunkLock` for the + `plan.chunk_ids` set — the seg reads inside `split_supervoxel` need + to be consistent with concurrent writers. + + `timestamp` is the op's logical write time; threaded down to every + `mutate_row` in the persist block so all new-SV cells land at the + same logical time (atomic visibility for `parent_ts`-filtered + readers, and deterministic replay via `override_ts`). + + Fields on the returned `SplitResult`: + seg_bboxes: per-task base-resolution `(bbs, bbe)` — downsample + worker input. + source_ids_fresh / sink_ids_fresh: input `source_ids`/`sink_ids` + with positions touched by an overlap task replaced by the + new SV ID that now lives at that coord. Untouched positions + stay unchanged. Feeds the retry multicut. 
+ seg_writes: flat list of `(voxel_slices, data)` pairs across all + tasks — one tensorstore write per pair, fired in parallel. + bigtable_rows: flattened rows from `copy_parents_and_add_lineage` + + `add_new_edges` across all tasks. + """ + source_ids_fresh = np.asarray(source_ids, dtype=basetypes.NODE_ID).copy() + sink_ids_fresh = np.asarray(sink_ids, dtype=basetypes.NODE_ID).copy() + + seg_bboxes = [] + seg_writes: List[Tuple[Tuple[slice, slice, slice], np.ndarray]] = [] + bigtable_rows: list = [] + for task in tasks: + out = split_supervoxel( + cg, + task, + operation_id, + sv_remapping=sv_remapping, + time_stamp=timestamp, + parent_ts=parent_ts, + ) + seg_bboxes.append(out.seg_bbox) + source_ids_fresh[task.src_mask] = out.src_new_ids + sink_ids_fresh[task.sink_mask] = out.sink_new_ids + seg_writes.extend(out.seg_write_pairs) + bigtable_rows.extend(out.bigtable_rows) + return SplitResult( + seg_bboxes=seg_bboxes, + source_ids_fresh=source_ids_fresh, + sink_ids_fresh=sink_ids_fresh, + seg_writes=seg_writes, + bigtable_rows=bigtable_rows, + ) + + +def _update_chunks(cg: "ChunkedGraph", chunks_bbox_map, seg, result_seg, bb_start): + """Process all chunks in a single pass: assign new SV IDs to split fragments. + + Returns `(results, change_chunks)`: + results: per-chunk (indices, old_values, new_values, label_id_map) + tuples; consumed by `_parse_results`. + change_chunks: `(chunk_coord, chunk_bbox)` for the chunks whose + voxels received new SV IDs. `write_seg_chunks` uses this to + rewrite only those chunks (skipping gap chunks that had no + split activity keeps the OCDBT delta proportional to actual + label changes). 
+ """ + results = [] + change_chunks = [] + for chunk_coord, chunk_bbox in chunks_bbox_map.items(): + x, y, z = chunk_coord + chunk_id = cg.get_chunk_id(layer=1, x=x, y=y, z=z) + + _s, _e = chunk_bbox - bb_start + og_chunk_seg = seg[_s[0] : _e[0], _s[1] : _e[1], _s[2] : _e[2]] + chunk_seg = result_seg[_s[0] : _e[0], _s[1] : _e[1], _s[2] : _e[2]] + + labels = fastremap.unique(chunk_seg[chunk_seg != 0]) + if labels.size < 2: + continue + + new_ids = cg.id_client.create_node_ids(chunk_id, size=len(labels)) + _indices = [] + _old_values = [] + _new_values = [] + _label_id_map = {} + for _id, new_id in zip(labels, new_ids): + _mask = chunk_seg == _id + voxel_locs = np.where(_mask) + _og_value = og_chunk_seg[ + voxel_locs[0][0], voxel_locs[1][0], voxel_locs[2][0] + ] + _index = np.column_stack(voxel_locs) + n = len(_index) + _indices.append(_index) + _old_values.append(np.full(n, _og_value, dtype=basetypes.NODE_ID)) + _new_values.append(np.full(n, new_id, dtype=basetypes.NODE_ID)) + _label_id_map[int(_id)] = new_id + + _indices = np.concatenate(_indices) + (chunk_bbox[0] - bb_start) + _old_values = np.concatenate(_old_values) + _new_values = np.concatenate(_new_values) + results.append((_indices, _old_values, _new_values, _label_id_map)) + change_chunks.append((chunk_coord, chunk_bbox)) + return results, change_chunks + + +def _voxel_crop(bbs, bbe, bbs_, bbe_): + xS, yS, zS = bbs - bbs_ + xE, yE, zE = (None if i == 0 else -1 for i in bbe_ - bbe) + voxel_overlap_crop = np.s_[xS:xE, yS:yE, zS:zE] + return voxel_overlap_crop + + +def _assert_same_chunk(cg: "ChunkedGraph", old_new_map: dict) -> None: + """Every new SV must live in the same chunk as the SV it split from. + + PCG segment IDs are unique only within a chunk; a split fragment that + landed in a different chunk than its parent would break the hierarchy. 
+ """ + olds = np.fromiter(old_new_map.keys(), dtype=basetypes.NODE_ID) + news = np.fromiter( + (n for ns in old_new_map.values() for n in ns), dtype=basetypes.NODE_ID + ) + expected = np.repeat( + cg.get_chunk_ids_from_node_ids(olds), [len(ns) for ns in old_new_map.values()] + ) + assert np.array_equal( + cg.get_chunk_ids_from_node_ids(news), expected + ), "new supervoxel landed in a different chunk than the SV it split from" + + +def _parse_results(results, seg, bbs, bbe): + """Merge per-chunk split results into a single segmentation volume. + + Applies new SV IDs from each chunk's split result to `seg` (in-place) + and builds the old→new mapping + label→new-id mapping. + + Returns (seg, old_new_map, new_id_label_map). + """ + old_new_map = defaultdict(set) + new_id_label_map = {} + for result in results: + if result: + indexer, old_values, new_values, label_id_map = result + seg[tuple(indexer.T)] = new_values + for old_sv, new_sv in zip(old_values, new_values): + old_new_map[old_sv].add(new_sv) + for label, new_id in label_id_map.items(): + new_id_label_map[new_id] = label + + assert np.all(seg.shape == bbe - bbs), f"{seg.shape} != {bbe - bbs}" + return seg, old_new_map, new_id_label_map + + +def _read_seg_and_ids(cg: "ChunkedGraph", bbs, bbe): + """Read seg over [bbs-1, bbe+1] and return its distinct SV IDs. + + The 1-voxel shell gives update_edges anchor voxels from neighbouring + SVs. Returns (seg, sv_ids, bbs_, bbe_). + """ + vol_start = cg.meta.voxel_bounds[:, 0] + vol_end = cg.meta.voxel_bounds[:, 1] + bbs_ = np.clip(bbs - 1, vol_start, vol_end) + bbe_ = np.clip(bbe + 1, vol_start, vol_end) + _prof = get_profiler() + t0 = time.time() + with _prof.profile("seg_read"): + seg = get_local_segmentation(cg.meta, bbs_, bbe_).squeeze() + logger.note(f"segmentation read {seg.shape} ({time.time() - t0:.2f}s)") + + with _prof.profile("seg_unique"): + # Unique per chunk on the segment-id field only. 
Segment IDs are + # injective within a chunk and narrower than uint64, so the per- + # block sort is cheap; the chunk bits are OR'd back before the + # final union. The lattice is anchored at voxel_bounds[:, 0] so + # each block is exactly one chunk. Background 0 carries no chunk + # and is restored once at the end. + chunk_map = chunks_overlapping_bbox( + bbs_, bbe_, cg.meta.graph_config.CHUNK_SIZE, origin=vol_start + ) + parts = [] + has_zero = False + for (cx, cy, cz), cbbox in chunk_map.items(): + s, e = cbbox[0] - bbs_, cbbox[1] - bbs_ + sub = seg[s[0] : e[0], s[1] : e[1], s[2] : e[2]] + chunk_id = np.uint64(cg.get_chunk_id(layer=1, x=cx, y=cy, z=cz)) + limit = np.uint64(cg.get_segment_id_limit(chunk_id)) + narrow = np.min_scalar_type(int(limit)) + u = fastremap.unique((sub & limit).astype(narrow, copy=False)) + if u.size and u[0] == 0: + has_zero = True + u = u[1:] + parts.append(u.astype(np.uint64) | chunk_id) + sv_ids = np.unique(np.concatenate(parts)) if parts else np.array([], np.uint64) + if has_zero: + sv_ids = np.concatenate([[np.uint64(0)], sv_ids]) + return seg, sv_ids, bbs_, bbe_ + + +def _select_cut_supervoxels(sv_id, sv_ids, rep_pieces): + """Narrow the rep to the pieces actually present in the bbox seg. + + Rep pieces whose voxels lie outside the seed-driven bbox don't appear + in seg and contribute nothing to the cut. Returns (cut_supervoxels, + supervoxel_ids). + """ + seg_ids = {int(x) for x in sv_ids if x != 0} + cut_supervoxels = rep_pieces & seg_ids + supervoxel_ids = np.array(list(cut_supervoxels), dtype=basetypes.NODE_ID) + logger.note( + f"whole sv {sv_id} -> {supervoxel_ids.tolist()} " + f"({len(rep_pieces) - len(cut_supervoxels)} rep pieces outside bbox)" + ) + return cut_supervoxels, supervoxel_ids + + +def _compute_split(ctx: _SplitCtx, supervoxel_ids): + """Build the binary mask over the overlap crop and run the cut. + + Returns (split_result, voxel_overlap_crop). 
+ """ + _prof = get_profiler() + with _prof.profile("binary_seg"): + # Per-SV OR over the overlap crop: each `== sv` is one C pass with + # no seg-size auxiliaries (unlike np.isin's sort+search), and the + # crop is the only region split_supervoxel_helper consumes. + voxel_overlap_crop = _voxel_crop(ctx.bbs, ctx.bbe, ctx.bbs_, ctx.bbe_) + seg_overlap = ctx.seg[voxel_overlap_crop] + binary_seg = np.zeros(seg_overlap.shape, dtype=bool) + for sv in supervoxel_ids: + binary_seg |= seg_overlap == sv + t0 = time.time() + with _prof.profile("geodesic_split"): + split_result = split_supervoxel_helper( + binary_seg, + ctx.source_coords - ctx.bbs, + ctx.sink_coords - ctx.bbs, + ctx.cg.meta.resolution, + ) + logger.note(f"split computation {split_result.shape} ({time.time() - t0:.2f}s)") + return split_result, voxel_overlap_crop + + +def _apply_and_capture( + ctx: _SplitCtx, voxel_overlap_crop, split_result, cut_supervoxels +): + """Apply fresh IDs to seg's crop and capture the write/lookup outputs. + + Writes fresh SV IDs into seg's overlap crop in place (a view, no + full-crop copy; _parse_results only writes, never reads crop values), + then captures the OCDBT write payloads and the src/sink id lookups + while the crop still holds unmasked neighbour IDs. Everything here + runs before the root mask, which would otherwise zero the neighbour + IDs the write must preserve. Returns an `_ApplyResult`. 
+ """ + cg, seg, bbs, bbe = ctx.cg, ctx.seg, ctx.bbs, ctx.bbe + _prof = get_profiler() + chunks_bbox_map = chunks_overlapping_bbox( + bbs, bbe, cg.meta.graph_config.CHUNK_SIZE, origin=cg.meta.voxel_bounds[:, 0] + ) + t0 = time.time() + results, change_chunks = _update_chunks( + cg, chunks_bbox_map, seg[voxel_overlap_crop], split_result, bbs + ) + logger.note( + f"chunk updates {len(chunks_bbox_map)} chunks, " + f"{len(change_chunks)} with splits ({time.time() - t0:.2f}s)" + ) + + with _prof.profile("parse_results"): + new_seg = seg[voxel_overlap_crop] + new_seg, old_new_map, new_id_label_map = _parse_results( + results, new_seg, bbs, bbe + ) + _assert_same_chunk(cg, old_new_map) + logger.note( + f"old_new_map: {len(old_new_map)} SVs split, whole_sv: {len(cut_supervoxels)} SVs" + ) + unsplit = cut_supervoxels - set(old_new_map.keys()) + if unsplit: + logger.note(f"unsplit SVs (kept IDs): {unsplit}") + + # .copy() per changed chunk detaches each payload from seg before the + # mask / update_edges mutate it; changed chunks only, so the copies + # stay proportional to the edit. The caller batches them into one + # parallel tensorstore write. 
+ seg_write_pairs: List[Tuple[Tuple[slice, slice, slice], np.ndarray]] = [] + for _, chunk_bbox in change_chunks: + lo, hi = chunk_bbox[0], chunk_bbox[1] + local_lo = lo - bbs + local_hi = hi - bbs + data = new_seg[ + local_lo[0] : local_hi[0], + local_lo[1] : local_hi[1], + local_lo[2] : local_hi[2], + ].copy() + voxel_slices = tuple(slice(int(s), int(e)) for s, e in zip(lo, hi)) + seg_write_pairs.append((voxel_slices, data)) + + local_src = (np.asarray(ctx.source_coords, dtype=int) - bbs).astype(int) + local_sink = (np.asarray(ctx.sink_coords, dtype=int) - bbs).astype(int) + src_new_ids = new_seg[tuple(local_src.T)].copy() + sink_new_ids = new_seg[tuple(local_sink.T)].copy() + return _ApplyResult( + old_new_map=old_new_map, + new_id_label_map=new_id_label_map, + seg_write_pairs=seg_write_pairs, + src_new_ids=src_new_ids, + sink_new_ids=sink_new_ids, + ) + + +def _route_edges_and_rows(ctx: _SplitCtx, old_new_map, new_id_label_map): + """Resolve the split's root, route edges, build bigtable rows. + + Returns the flat list of bigtable rows. 
+ """ + cg, seg = ctx.cg, ctx.seg + _prof = get_profiler() + with _prof.profile("get_roots"): + roots = cg.get_roots(ctx.sv_ids, time_stamp=ctx.parent_ts) + sv_root_map = dict(zip(ctx.sv_ids, roots)) + root = sv_root_map[ctx.sv_id] + logger.note(f"{ctx.sv_id} -> {root}") + + t0 = time.time() + with _prof.profile("update_edges"): + edges_tuple = update_edges( + cg, + root, + np.array([ctx.bbs, ctx.bbe]), + seg, + old_new_map, + new_id_label_map, + parent_ts=ctx.parent_ts, + ) + logger.note(f"edge update ({time.time() - t0:.2f}s)") + + rows0 = copy_parents_and_add_lineage( + cg, ctx.operation_id, old_new_map, time_stamp=ctx.time_stamp + ) + rows1 = add_new_edges(cg, edges_tuple, old_new_map, time_stamp=ctx.time_stamp) + return rows0 + rows1 + + +def split_supervoxel( + cg: "ChunkedGraph", + task: SvSplitTask, + operation_id: int, + *, + sv_remapping: dict, + time_stamp: datetime = None, + parent_ts: datetime = None, +) -> SvSplitOutcome: + """Split one cross-chunk-connected SV into connected components. + + `task.bbs` / `task.bbe` are the base-voxel bbox covering the user's + source and sink seeds plus a one-chunk margin — `plan_sv_splits` + pre-computed this via `_coords_bbox`. The bbox is driven by where + the user wants the cut, not by the rep's full chunk envelope; rep + pieces outside the bbox aren't read and keep their existing IDs. + + `time_stamp` is the op's logical write time; threaded through to + `copy_parents_and_add_lineage` + `add_new_edges` so every new-SV + mutation lands at the same timestamp. 
+ """ + sv_id = task.sv_id + bbs = task.bbs + bbe = task.bbe + + logger.note(f"cg.meta.ws_ocdbt: {cg.meta.ws_ocdbt.shape}; res {cg.meta.resolution}") + logger.note(f"bbox: {(bbs, bbe)}") + + rep = sv_remapping.get(sv_id, sv_id) + rep_pieces = {int(sv) for sv, r in sv_remapping.items() if r == rep} + + seg, sv_ids, bbs_, bbe_ = _read_seg_and_ids(cg, bbs, bbe) + ctx = _SplitCtx( + cg=cg, + seg=seg, + bbs=bbs, + bbe=bbe, + bbs_=bbs_, + bbe_=bbe_, + sv_id=sv_id, + sv_ids=sv_ids, + source_coords=task.src_coords, + sink_coords=task.sink_coords, + operation_id=operation_id, + time_stamp=time_stamp, + parent_ts=parent_ts, + ) + cut_supervoxels, supervoxel_ids = _select_cut_supervoxels(sv_id, sv_ids, rep_pieces) + split_result, voxel_overlap_crop = _compute_split(ctx, supervoxel_ids) + applied = _apply_and_capture(ctx, voxel_overlap_crop, split_result, cut_supervoxels) + rows = _route_edges_and_rows(ctx, applied.old_new_map, applied.new_id_label_map) + + return SvSplitOutcome( + seg_bbox=(bbs, bbe), + src_new_ids=applied.src_new_ids, + sink_new_ids=applied.sink_new_ids, + seg_write_pairs=applied.seg_write_pairs, + bigtable_rows=rows, + ) + + +def copy_parents_and_add_lineage( + cg: "ChunkedGraph", + operation_id: int, + old_new_map: dict, + *, + time_stamp: datetime = None, +) -> list: + """Copy parent pointers from old SVs onto their new-ID fragments + and write the lineage (FormerIdentity / NewIdentity) + L2 Child + list updates. + + `time_stamp` is the op's logical write time — used for every new-SV + cell this function writes so a `parent_ts`-filtered reader sees the + op atomically. The Parent-copy and Child-list writes deliberately + preserve the old cell's timestamp (so pre-op readers still see the + old hierarchy via the old timestamp). + + Returns a list of mutations to be persisted. 
+ """ + result = [] + parents = set() + old_new_map = {k: list(v) for k, v in old_new_map.items()} + parent_cells_map = cg.client.read_nodes( + node_ids=list(old_new_map.keys()), properties=attributes.Hierarchy.Parent + ) + for old_id, new_ids in old_new_map.items(): + for new_id in new_ids: + val_dict = { + attributes.Hierarchy.FormerIdentity: np.array( + [old_id], dtype=basetypes.NODE_ID + ), + attributes.OperationLogs.OperationID: operation_id, + } + result.append( + cg.client.mutate_row( + serializers.serialize_uint64(new_id), + val_dict, + time_stamp=time_stamp, + ) + ) + for cell in parent_cells_map[old_id]: + cache_utils.update(cg.cache.parents_cache, [new_id], cell.value) + parents.add(cell.value) + result.append( + cg.client.mutate_row( + serializers.serialize_uint64(new_id), + {attributes.Hierarchy.Parent: cell.value}, + time_stamp=cell.timestamp, + ) + ) + val_dict = { + attributes.Hierarchy.NewIdentity: np.array(new_ids, dtype=basetypes.NODE_ID) + } + result.append( + cg.client.mutate_row( + serializers.serialize_uint64(old_id), + val_dict, + time_stamp=time_stamp, + ) + ) + + children_cells_map = cg.client.read_nodes( + node_ids=list(parents), properties=attributes.Hierarchy.Child + ) + for parent, children_cells in children_cells_map.items(): + assert len(children_cells) == 1, children_cells + for cell in children_cells: + mask = np.isin(cell.value, list(old_new_map.keys())) + replace = np.concatenate([old_new_map[x] for x in cell.value[mask]]) + children = np.concatenate([cell.value[~mask], replace]) + cg.cache.children_cache[parent] = children + result.append( + cg.client.mutate_row( + serializers.serialize_uint64(parent), + {attributes.Hierarchy.Child: children}, + time_stamp=cell.timestamp, + ) + ) + return result diff --git a/pychunkedgraph/graph/sv_split/profile.py b/pychunkedgraph/graph/sv_split/profile.py new file mode 100644 index 000000000..1c6e135b8 --- /dev/null +++ b/pychunkedgraph/graph/sv_split/profile.py @@ -0,0 +1,345 @@ 
+"""Re-runnable dry-run profile harness for SV splits. + +Drives an SV-split operation end-to-end under ``PCG_DRY_RUN=1`` so no +BT or OCDBT state is mutated, captures per-stage timing + memory + IO +metrics into a ``HierarchicalProfiler`` (one ``BlockMetrics`` row per +stage), and snapshots each stage's intermediate result into a +``SplitInputs`` dataclass that's persisted alongside the profiler. + +The persisted run lets the user iterate on a single heavy stage in +isolation (e.g. profile just ``split_supervoxels`` after editing it) +without re-running the prior stages. +""" + +import hashlib +import json +import pickle +import shutil +import sys +import tempfile +import traceback +from contextlib import contextmanager, redirect_stdout +from dataclasses import dataclass +from io import StringIO +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple + +import numpy as np + +from pychunkedgraph.app.segmentation.common import _get_sources_and_sinks +from pychunkedgraph.profiler import HierarchicalProfiler, get_profiler +from . import edits +from pychunkedgraph.graph import utils as _utils_pkg +from pychunkedgraph.graph.dry_run import dry_run_scope +from pychunkedgraph.graph.operation import Cut, MulticutOperation, SvSplitRequired +from pychunkedgraph.graph.utils import generic as _utils_generic +from pychunkedgraph.graph.utils import id_helpers as _utils_id_helpers + +_CACHE_ROOT = Path(tempfile.gettempdir()) / "pcg_split_profile" + + +@dataclass +class SplitInputs: + """Per-stage inputs/outputs captured during a ``run_split_profile`` run. + + Persisted to disk alongside the profiler so single-stage replays + can reuse the inputs without re-running prior stages. Every field + matches the exact value at the corresponding call site in + ``MulticutOperation._apply``. 
+ """ + + operation_id: Optional[int] = None + timestamp: Any = None + source_ids_pre: Optional[np.ndarray] = None + sink_ids_pre: Optional[np.ndarray] = None + source_coords: Optional[np.ndarray] = None + sink_coords: Optional[np.ndarray] = None + sv_remapping: Optional[dict] = None + plan_tasks: Any = None + plan_chunk_ids: Any = None + sv_result: Any = None + cut: Any = None + + +def _payload_canonical(payload: dict) -> str: + """Canonical JSON encoding used for hashing and collision detection.""" + return json.dumps(payload, sort_keys=True) + + +def _payload_sha(payload: dict) -> str: + """First 8 hex chars of sha256 over the canonical-JSON payload.""" + return hashlib.sha256(_payload_canonical(payload).encode()).hexdigest()[:8] + + +def run_dir(cg, payload: dict) -> Path: + """Cache directory for ``(cg.graph_id, payload)`` under system tmp.""" + return _CACHE_ROOT / cg.graph_id / _payload_sha(payload) + + +@contextmanager +def count_io(cg): + """Count BT row reads + OCDBT bytes read for the wrapped block. + + Wraps ``cg.client.read_nodes`` and ``cg.client.read_log_entries`` + (BT reads), plus ``get_local_segmentation`` at every top-level + binding site reached from the SV-split flow (OCDBT reads). All + originals are restored on exit. + """ + counters: Dict[str, int] = { + "bt_row_reads": 0, + "bt_log_reads": 0, + "ocdbt_reads": 0, + "ocdbt_bytes": 0, + } + + orig_read_nodes = cg.client.read_nodes + + def wrap_read_nodes(*a, **k): + result = orig_read_nodes(*a, **k) + counters["bt_row_reads"] += len(result) if result is not None else 0 + return result + + orig_read_log = cg.client.read_log_entries + + def wrap_read_log(*a, **k): + result = orig_read_log(*a, **k) + counters["bt_log_reads"] += len(result) if result is not None else 0 + return result + + cg.client.read_nodes = wrap_read_nodes + cg.client.read_log_entries = wrap_read_log + + # Patch every binding of get_local_segmentation reached from the + # SV-split flow. 
The source module is _utils_generic; the others + # imported it by name at module load time, so they hold separate + # references that need their own swap. + seg_modules = [_utils_generic, _utils_pkg, edits, _utils_id_helpers] + orig_seg_fns = {m: m.get_local_segmentation for m in seg_modules} + + def wrap_get_local_seg(meta, bbox_start, bbox_end, mip=0): + # Always call the source function so we don't double-count if + # one wrapped binding calls another. + arr = orig_seg_fns[_utils_generic](meta, bbox_start, bbox_end, mip) + counters["ocdbt_bytes"] += int(arr.nbytes) + counters["ocdbt_reads"] += 1 + return arr + + for m in seg_modules: + m.get_local_segmentation = wrap_get_local_seg + + try: + yield counters + finally: + cg.client.read_nodes = orig_read_nodes + cg.client.read_log_entries = orig_read_log + for m, fn in orig_seg_fns.items(): + m.get_local_segmentation = fn + + +def profile_call(cg, name, fn, *args, **kwargs): + """Profile a single callable under dry-run with IO counters. + + Standalone replay helper for per-stage profiling (e.g. after + editing a single function's source). Opens ``dry_run_scope`` + + ``count_io``, runs ``profiler.profile(name, counters=counters)`` + around ``fn(*args, **kwargs)``, and returns ``(profiler, result)``. + The profiler has exactly one block. + """ + profiler = HierarchicalProfiler(enabled=True) + with dry_run_scope(), count_io(cg) as counters: + with profiler.profile(name, counters=counters): + result = fn(*args, **kwargs) + return profiler, result + + +def build_op( + cg, + payload: dict, + *, + user_id: str = "dry_run_profile", + bbox_offset: Tuple[int, int, int] = (240, 240, 24), +) -> MulticutOperation: + """Decode a /split JSON payload and instantiate ``MulticutOperation``. + + Mirrors ``ChunkedGraph.remove_edges`` direct instantiation pattern. + The caller drives ``op.execute()``.
+ """ + source_ids, sink_ids, source_coords, sink_coords = _get_sources_and_sinks( + cg, payload + ) + op = MulticutOperation( + cg, + user_id=user_id, + source_ids=source_ids, + sink_ids=sink_ids, + source_coords=source_coords, + sink_coords=sink_coords, + bbox_offset=bbox_offset, + ) + return op + + +def annotate_chunks(cg, chunk_ids) -> List[str]: + """Annotate each chunk id with its NGL-navigable center voxel.""" + out: List[str] = [] + for cid in chunk_ids: + coord = cg.get_chunk_center_voxel(int(cid)).tolist() + out.append(f"{int(cid):#x} -> voxel {coord}") + return out + + +def _save_run( + cg, + payload: dict, + profiler: HierarchicalProfiler, + inputs: SplitInputs, +) -> Path: + """Write run artifacts under ``run_dir``; raise on payload-hash collision.""" + target = run_dir(cg, payload) + payload_path = target / "payload.json" + incoming = _payload_canonical(payload) + if payload_path.exists(): + existing = payload_path.read_text() + if existing != incoming: + raise RuntimeError( + f"payload-hash collision at {target}: " + "existing payload != incoming payload" + ) + target.mkdir(parents=True, exist_ok=True) + with open(target / "inputs.pkl", "wb") as f: + pickle.dump(inputs, f) + with open(target / "profiler.pkl", "wb") as f: + pickle.dump(profiler, f) + buf = StringIO() + with redirect_stdout(buf): + profiler.metrics_report() + (target / "metrics.txt").write_text(buf.getvalue()) + payload_path.write_text(incoming) + return target + + +def load_run(cg, payload: dict) -> Tuple[HierarchicalProfiler, SplitInputs]: + """Restore a prior ``run_split_profile`` result from disk.""" + target = run_dir(cg, payload) + with open(target / "profiler.pkl", "rb") as f: + profiler = pickle.load(f) + with open(target / "inputs.pkl", "rb") as f: + inputs = pickle.load(f) + return profiler, inputs + + +def load_traceback(cg, payload: dict) -> Optional[str]: + """Return the saved traceback for a run, or ``None`` if it succeeded. 
+ + ``run_split_profile`` writes ``traceback.txt`` only when + ``op.execute()`` raised; its absence means the run completed. + """ + tb_path = run_dir(cg, payload) / "traceback.txt" + return tb_path.read_text() if tb_path.exists() else None + + +def run_split_profile( + cg, payload: dict, *, overwrite: bool = False +) -> Tuple[HierarchicalProfiler, SplitInputs]: + """Drive an SV split under dry-run with per-stage metrics captured. + + Returns ``(profiler, inputs)``. Uses the global profiler so inline + ``get_profiler().profile()`` blocks inside the SV-split call path + are captured automatically. ``inputs`` holds each stage's + intermediate values for standalone replay. + + ``overwrite=True`` wipes any existing cached run for this payload + before starting; otherwise an existing run raises ``FileExistsError``. + + Writes a cache (profiler + inputs + metrics.txt) to + ``run_dir(cg, payload)`` on completion — even when ``op.execute()`` + raises — and prints the cache path; a failed save is only reported on stderr. + """ + target_dir = run_dir(cg, payload) + if target_dir.exists(): + if overwrite: + shutil.rmtree(target_dir) + else: + raise FileExistsError( + f"cached run already exists at {target_dir}; " + "pass overwrite=True to wipe and re-run, or " + "load_run(cg, payload) to read it" + ) + + profiler = get_profiler() + profiler.reset() + profiler.enabled = True + inputs = SplitInputs() + + op = build_op(cg, payload) + inputs.source_ids_pre = op.source_ids.copy() + inputs.sink_ids_pre = op.sink_ids.copy() + inputs.source_coords = op.source_coords + inputs.sink_coords = op.sink_coords + + with dry_run_scope(), count_io(cg) as counters: + profiler.default_counters = counters + + # Capture-only wrappers for SplitInputs replay — no profile() + # blocks. The real per-step metrics come from inline profile() + # blocks inside the called functions.
+ orig_run_multicut = MulticutOperation._run_multicut + orig_plan_sv_splits = edits.plan_sv_splits + orig_split_supervoxels = edits.split_supervoxels + + mincut_call_count = [0] + + def wrap_run_multicut(self_op, operation_id): + result = orig_run_multicut(self_op, operation_id) + mincut_call_count[0] += 1 + if mincut_call_count[0] == 1 and isinstance(result, SvSplitRequired): + inputs.sv_remapping = result.sv_remapping + elif isinstance(result, Cut): + inputs.cut = result + return result + + def wrap_plan_sv_splits(*a, **k): + result = orig_plan_sv_splits(*a, **k) + inputs.plan_tasks, inputs.plan_chunk_ids = result + return result + + def wrap_split_supervoxels(*a, **k): + if "operation_id" in k: + inputs.operation_id = k["operation_id"] + if "timestamp" in k: + inputs.timestamp = k["timestamp"] + result = orig_split_supervoxels(*a, **k) + inputs.sv_result = result + return result + + MulticutOperation._run_multicut = wrap_run_multicut + edits.plan_sv_splits = wrap_plan_sv_splits + edits.split_supervoxels = wrap_split_supervoxels + + tb_text = None + try: + op.execute() + except Exception as e: + print( + f"[split_profile] op.execute() raised " f"{type(e).__name__}: {e}", + file=sys.stderr, + ) + tb_text = traceback.format_exc() + finally: + MulticutOperation._run_multicut = orig_run_multicut + edits.plan_sv_splits = orig_plan_sv_splits + edits.split_supervoxels = orig_split_supervoxels + profiler.default_counters = None + + try: + target = _save_run(cg, payload, profiler, inputs) + if tb_text is not None: + (target / "traceback.txt").write_text(tb_text) + print(f"[split_profile] run cached at {target}") + except Exception as save_err: + print(f"[split_profile] cache save failed: {save_err}", file=sys.stderr) + + # Disable so the global profiler is a no-op for callers outside + # this harness (production code paths included). 
+ profiler.enabled = False + return profiler, inputs diff --git a/pychunkedgraph/graph/sv_split/recovery.md b/pychunkedgraph/graph/sv_split/recovery.md new file mode 100644 index 000000000..1c22e973e --- /dev/null +++ b/pychunkedgraph/graph/sv_split/recovery.md @@ -0,0 +1,73 @@ +# Supervoxel split recovery + +## What it is + +A recovery path for supervoxel-split operations that crash mid-write, leaving partial state in segmentation and per-chunk locks held indefinitely. An operator runs a one-shot command that reverts the crashed op's partial segmentation writes and re-runs the op from scratch, producing a clean successful edit and freeing the affected chunks for future work. + +## When it applies + +Every supervoxel split writes two things atomically from the point of view of the operation: +- Segmentation chunks — the voxel-level split, where each L2 chunk touched by the split gets fresh supervoxel IDs at the voxels that moved to a new fragment. +- Graph hierarchy rows — lineage from the old supervoxels to the new ones plus the cross-chunk edges linking the new fragments. + +Both writes happen under an indefinite L2 chunk lock covering the exact chunks being rewritten. If the worker running the op dies before the lock's context exits — process kill, pod eviction, hardware failure, OOM — or raises a caught exception inside the persist block, the lock stays held and the op-log row's `L2ChunkLockScope` stays populated with the affected chunk IDs. From that moment on, any new op whose chunk set overlaps the stuck op's chunks refuses to start, blocking further corruption. + +The authoritative signal that an op is stuck is `L2ChunkLockScope` being non-empty — the clean `__exit__` path clears it on success. Either a crash (`Status=CREATED`, exit never ran) or a caught exception (`Status=FAILED`, held-cells-on-exception path) keeps the scope set. + +The operator runs recovery when the lock has been held long enough that the worker is definitively gone, not merely slow. 
A minimum-age threshold (10 minutes is a reasonable default) distinguishes stuck ops from ops still in flight. + +## Concurrent edits on other regions keep working + +The indefinite lock is per L2 chunk. While op X is stuck on chunks `{C1, C2, C3}`, another op Y on chunks `{C4, C5}` sees no indefinite cell on its chunks and proceeds normally. Its writes advance the latest OCDBT manifest. By the time the operator gets to recovery, the manifest has moved past the stuck op's `OperationTimeStamp` and other regions of segmentation reflect Y's (and any subsequent ops') work. + +This is important: the recovery must not undo Y's changes. It also cannot rely on a single "read pre-op segmentation" pin because that would return pre-Y state outside the stuck op's own chunks, and the replay would overwrite neighbor state with stale values. + +## Why pre-op pinning is not enough on its own + +A supervoxel split reads more than just its own chunks. To route existing cross-chunk edges onto the new fragments, the split reads a one-voxel shell around its chunk envelope — supervoxel IDs from neighboring chunks serve as anchors for the re-routed edges. + +If the replay opened segmentation with a pin at the stuck op's `OperationTimeStamp` and then read that shell, the neighbor voxels would show their pre-op state, not their current state. If Y had split a supervoxel in one of those neighbor chunks in the interim, the pinned read would see the old neighbor IDs, and the replay would route its cross-chunk edges to supervoxel IDs that no longer exist. Graph corruption. + +So the replay cannot read the world through a single pinned view. It must see latest state for the neighbor shell and clean pre-op state for its own chunks. + +## Cleanup-then-replay + +Recovery proceeds in two steps. 
+ +**Cleanup.** For each chunk in the stuck op's durably-recorded scope, the operator reads the chunk's pre-op voxel values from a segmentation handle pinned at the op's `OperationTimeStamp`, and writes those values back to the latest (unpinned) handle. The result: those chunks, at the latest manifest, now show pre-op segmentation — as if the crashed op had never started. Neighbor chunks and every chunk outside the stuck op's scope are untouched, so any concurrent op's work is preserved. + +**Replay.** The operator then re-runs the op under the privileged-repair path. The run reads latest state, which is now consistent — clean pre-op values on the stuck op's own chunks, current state everywhere else — and goes through the normal edit flow. It allocates fresh supervoxel IDs, re-computes the split, writes new segmentation and hierarchy, and lands the op-log row at `SUCCESS`. + +When the replay's indefinite lock context exits, it issues value-matched releases on every chunk in the scope. Because the replay re-uses the crashed op's operation ID, the value match succeeds and the pre-existing indefinite cells are deleted. The chunks are free for new ops again. + +## Orphans in segmentation history + +OCDBT is append-only. The crashed op's partial segmentation writes still exist in the store's commit history — they are not deleted, only overshadowed. At the latest manifest, the cleanup step has overwritten them with pre-op values, so a normal (unpinned) read returns the pre-op state and the replay's fresh writes take effect on top of that. Readers that explicitly pin a historical version between the crash and the replay will still see the partial writes as a snapshot, but readers at latest never observe them. + +The orphan supervoxel IDs allocated by the crashed op are never referenced by any hierarchy row — the crashed op never wrote its hierarchy rows to completion, and the replay allocated a new set of IDs. From the graph's perspective those orphan IDs do not exist. 
+ +## Operator workflow + +1. **List stuck ops.** The operator runs the list command with a minimum-age threshold. It returns op-log rows whose `L2ChunkLockScope` is still populated past that age (excluding any that have reached `SUCCESS`), along with each op's user ID, timestamp, age, status, and the number of chunks in its recorded scope. Ops too young to classify are skipped. + +2. **Inspect.** For each candidate, confirm from logs or monitoring that the worker that submitted the op is definitively dead — not, for example, paused on a long-running multicut. The minimum-age threshold exists to reduce false positives but the operator retains final judgment. + +3. **Replay.** The operator runs the replay command with the op ID. Before any destructive step the replay cross-checks the recorded scope against live lock state: for every chunk in `L2ChunkLockScope` it reads back the `Concurrency.IndefiniteLock` cell and verifies it's held by the op being replayed. Any discrepancy (cell missing, or held by a different op) aborts the replay loudly — a stale scope could otherwise have cleanup revert chunks another op legitimately owns. On clean verification, cleanup reverts the op's partial writes, then the privileged-repair path reruns the op. On success, the op-log row shows `SUCCESS` and the previously-held indefinite lock cells are released. + +4. **Verify.** A second list invocation should no longer include the op. Any new ops that were waiting on the affected chunks proceed. + +If the replay itself fails — for example, the operator's judgment about the worker's status was wrong and the original worker comes back — the replay surfaces the error and leaves the op-log row and lock state as it found them. The operator investigates, potentially clears the lock manually via direct bigtable tools, and tries again. 
+ +## Invariants + +- A stuck op's durable scope record (written before any segmentation or hierarchy write begins) lets recovery locate every chunk that might have received a partial write, without a bigtable-wide scan. +- Cleanup only touches chunks in the stuck op's scope. Neighbor state and any concurrent ops' changes are preserved byte-for-byte. +- The replay sees a consistent world: pre-op values on the stuck op's own chunks (from the cleanup), current state on every other chunk (from the latest manifest). +- After successful replay, the op-log row is at `SUCCESS`, all indefinite cells previously held by that op are released, and the affected chunks are available to new ops. The op's original intent — the edit the user asked for — is realized with a fresh set of supervoxel IDs. + +## Related docs + +- [Overview](README.md) +- [Algorithm](algorithm.md) +- [Design](design.md) — why reads are pinned to `parent_ts`. +- [Edges](edges.md) diff --git a/pychunkedgraph/graph/utils/generic.py b/pychunkedgraph/graph/utils/generic.py index 4ebe0533e..64d37a5bc 100644 --- a/pychunkedgraph/graph/utils/generic.py +++ b/pychunkedgraph/graph/utils/generic.py @@ -173,8 +173,9 @@ def get_local_segmentation(meta, bbox_start, bbox_end, mip: int = 0) -> np.ndarr def lookup_svs_from_seg(meta, coordinates): """Read SV IDs directly from OCDBT segmentation at given coordinates.""" - bbox_start = np.min(coordinates, axis=0) - bbox_end = np.max(coordinates, axis=0) + 1 + coordinates = np.asarray(coordinates, dtype=int) + bbox_start = coordinates.min(axis=0) + bbox_end = coordinates.max(axis=0) + 1 seg = get_local_segmentation(meta, bbox_start, bbox_end)[..., 0] - local_coords = coordinates - bbox_start - return np.array([seg[tuple(c)] for c in local_coords], dtype=np.uint64) + local = coordinates - bbox_start + return seg[local[:, 0], local[:, 1], local[:, 2]].astype(np.uint64) diff --git a/pychunkedgraph/graph/utils/id_helpers.py b/pychunkedgraph/graph/utils/id_helpers.py index 
7f7d8f927..60b1d0799 100644 --- a/pychunkedgraph/graph/utils/id_helpers.py +++ b/pychunkedgraph/graph/utils/id_helpers.py @@ -128,10 +128,14 @@ def get_atomic_ids_from_coords( """ import fastremap - if parent_id_layer == 1 and meta.ocdbt_seg: + if meta.ocdbt_seg: + # Unified path: any OCDBT lookup reads the current seg at the + # coords, ignoring the user-supplied parent (the parent may be + # stale after an SV split, or — for 3D mesh clicks — not an L1 + # SV at all). See handle_supervoxel_id_lookup for the rationale. return lookup_svs_from_seg(meta, coordinates) - if parent_id_layer == 1 and not meta.ocdbt_seg: + if parent_id_layer == 1: return np.array([parent_id] * len(coordinates), dtype=np.uint64) coordinates_nm = coordinates * np.array(meta.resolution) diff --git a/pychunkedgraph/ingest/cli.py b/pychunkedgraph/ingest/cli.py index bd2f07626..47659d1ca 100644 --- a/pychunkedgraph/ingest/cli.py +++ b/pychunkedgraph/ingest/cli.py @@ -5,6 +5,8 @@ """ import os +from functools import partial +from time import sleep from pychunkedgraph import configure_logging, DEBUG @@ -14,23 +16,22 @@ from .cluster import create_atomic_chunk, create_parent_chunk, enqueue_l2_tasks from .manager import IngestionManager +from .ocdbt import coordinator, setup_base from .utils import ( bootstrap, - chunk_id_str, + job_type_guard, print_completion_rate, print_status, + purge_layer_state, queue_layer_helper, - job_type_guard, + requeue_chunk, ) from .simple_tests import run_all from .create.parent_layer import add_parent_chunk from ..graph.chunkedgraph import ChunkedGraph -from ..graph.ocdbt import ( - base_exists, - create_base_ocdbt, - fork_base_manifest, - wipe_base_ocdbt, -) +from ..graph.ocdbt import OcdbtConfig +from ..meshing.meta import MeshConfig +from ..meshing.setup import setup_mesh_meta from ..utils.redis import get_redis_connection, keys as r_keys group_name = "ingest" @@ -52,78 +53,102 @@ def flush_redis(): @ingest_cli.command("graph") @click.argument("graph_id", 
type=str) -@click.argument("dataset", type=click.Path(exists=True)) -@click.option("--ocdbt", is_flag=True, help="Precomputed supervoxel seg into ocdbt.") +@click.argument("dataset", type=click.Path(exists=True), required=False) +@click.option("--raw", is_flag=True, help="Read edges from agglomeration output.") @click.option( - "--sv-split-threshold", - type=int, - default=10, - help="Distance threshold for SV split edge matching.", + "--retry", + "-r", + is_flag=True, + help="Re-run setup against the existing table (no cg.create()).", ) -@click.option("--raw", is_flag=True, help="Read edges from agglomeration output.") -@click.option("--retry", is_flag=True, help="Rerun without creating a new table.") @click.option( - "--reset-ocdbt", + "--skip-queue", + "-s", is_flag=True, - help="Wipe base AND this CG's delta OCDBT, then recreate from scratch.", + help="Set up everything but don't enqueue L2 tasks.", +) +@click.option( + "--test", + "-t", + is_flag=True, + help="Test 8 chunks at the center of dataset.", ) -@click.option("--test", is_flag=True, help="Test 8 chunks at the center of dataset.") @job_type_guard(group_name) def ingest_graph( graph_id: str, dataset: click.Path, - ocdbt: bool, - sv_split_threshold: int, raw: bool, retry: bool, - reset_ocdbt: bool, + skip_queue: bool, test: bool, ): - """ - Main ingest command. - Takes ingest config from a yaml file and queues atomic tasks. + """Main ingest command. Takes config from yaml, queues atomic tasks. + + Purely about the bigtable graph: creates the table and enqueues L2 + tasks. OCDBT base + fork creation happens in ``ingest layer N`` when + N matches ``ocdbt_populate_layer``; that's the single owner of the + OCDBT lifecycle. + + ``--retry`` reuses the existing IngestionManager from redis and skips + ``cg.create()``. Pair with ``--skip-queue`` to skip L2 enqueue too. 
""" redis = get_redis_connection() - redis.set(r_keys.JOB_TYPE, group_name) - with open(dataset, "r") as stream: - config = yaml.safe_load(stream) - if test: configure_logging(level=DEBUG) - meta, ingest_config, client_info = bootstrap(graph_id, config, raw, test) - cg = ChunkedGraph(meta=meta, client_info=client_info) - if not retry: + if retry: + imanager_pickle = redis.get(r_keys.INGESTION_MANAGER) + if imanager_pickle is None: + raise click.ClickException( + f"--retry requires an existing `{group_name}` job in redis. " + f"Run without --retry to start a new job." + ) + imanager = IngestionManager.from_pickle(imanager_pickle) + else: + if dataset is None: + raise click.ClickException("dataset is required unless --retry is passed.") + redis.set(r_keys.JOB_TYPE, group_name) + with open(dataset, "r") as stream: + config = yaml.safe_load(stream) + meta, ingest_config, client_info, ocdbt_config_dict = bootstrap( + graph_id, config, raw, test + ) + cg = ChunkedGraph(meta=meta, client_info=client_info) cg.create() - - needs_base = False - if ocdbt: - ws = cg.meta.data_source.WATERSHED - cg.meta.custom_data["seg"] = { - "ocdbt": True, - "sv_split_threshold": sv_split_threshold, - } - cg.update_meta(cg.meta, overwrite=True) - - if reset_ocdbt: - wipe_base_ocdbt(ws) - - needs_base = not base_exists(ws) - if needs_base: - create_base_ocdbt(ws) - - fork_base_manifest(ws, graph_id, wipe_existing=retry or reset_ocdbt) - - imanager = IngestionManager( - ingest_config, - meta, - ocdbt_seg=ocdbt, - ocdbt_populate_base=needs_base, - ) - enqueue_l2_tasks(imanager, create_atomic_chunk) + imanager = IngestionManager( + ingest_config, + meta, + ocdbt_config=ocdbt_config_dict, + ) + + if not skip_queue: + enqueue_l2_tasks(imanager, create_atomic_chunk) os._exit(0) +@ingest_cli.command("mesh_meta") +@click.argument("graph_id", type=str) +@click.argument("dataset", type=click.Path(exists=True)) +@job_type_guard(group_name) +def mesh_meta(graph_id: str, dataset: click.Path): + """Set up 
every mesh.* metadata field for GRAPH_ID from DATASET yaml. + + Reads ``mesh_config:`` from the yaml, applies it to the graph. Run + once per new/copied graph, after the operator has verified initial + ingest (including the root layer) is complete — no automatic gate. + """ + with open(dataset, "r") as stream: + config = yaml.safe_load(stream) + if "mesh_config" not in config: + raise click.ClickException( + f"{dataset} has no `mesh_config:` block — required for mesh_meta." + ) + mesh_cfg = MeshConfig.from_dict(config["mesh_config"]) + cg = ChunkedGraph(graph_id=graph_id) + result = setup_mesh_meta(cg, mesh_cfg) + click.echo(f"mesh meta written for {graph_id}: {result}") + + @ingest_cli.command("imanager") @click.argument("graph_id", type=str) @click.argument("dataset", type=click.Path(exists=True)) @@ -140,33 +165,104 @@ def pickle_imanager(graph_id: str, dataset: click.Path, raw: bool): except yaml.YAMLError as exc: print(exc) - meta, ingest_config, _ = bootstrap(graph_id, config=config, raw=raw) - imanager = IngestionManager(ingest_config, meta) + meta, ingest_config, _, ocdbt_config_dict = bootstrap( + graph_id, config=config, raw=raw + ) + imanager = IngestionManager(ingest_config, meta, ocdbt_config=ocdbt_config_dict) imanager.redis.set(r_keys.JOB_TYPE, group_name) @ingest_cli.command("layer") @click.argument("parent_layer", type=int) +@click.option( + "--queue-only", + "-q", + is_flag=True, + help="Only enqueue tasks; do not start the OCDBT coordinator. " + "Use when a coordinator is already running in another process.", +) +@click.option( + "--ocdbt-only", + "-o", + is_flag=True, + help="Workers run only OCDBT populate (skip add_parent_chunk). " + "Requires the OCDBT populate layer.", +) +@click.option( + "--ingest-only", + "-i", + is_flag=True, + help="Workers run only add_parent_chunk (skip OCDBT populate). 
" + "Use when the OCDBT base is already populated for this layer.", +) @job_type_guard(group_name) -def queue_layer(parent_layer): +def queue_layer(parent_layer, queue_only, ocdbt_only, ingest_only): """ Queue all chunk tasks at a given layer. Must be used when all the chunks at `parent_layer - 1` have completed. + + When this layer is the OCDBT populate layer, this command also owns the + OCDBT lifecycle: idempotently creates the base + fork via ``setup_base`` + and starts a ``DistributedCoordinatorServer`` so every worker's commit + routes through one process (eliminates manifest-CAS races and orphan + ``d/`` files). Stays in the foreground until killed. + + Flags: + ``--queue-only`` skips the coordinator (one is assumed running elsewhere). + ``--ocdbt-only`` task body = OCDBT populate only. + ``--ingest-only`` task body = add_parent_chunk only. """ assert parent_layer > 2, "This command is for layers 3 and above." + if ocdbt_only and ingest_only: + raise click.ClickException( + "--ocdbt-only and --ingest-only are mutually exclusive." + ) redis = get_redis_connection() imanager = IngestionManager.from_pickle(redis.get(r_keys.INGESTION_MANAGER)) - queue_layer_helper(parent_layer, imanager, create_parent_chunk) + + is_populate_layer = imanager.is_ocdbt_populate_layer(parent_layer) + if ocdbt_only and not is_populate_layer: + raise click.ClickException( + "--ocdbt-only requires running at the OCDBT populate layer." + ) + + if is_populate_layer: + # Single owner of the OCDBT lifecycle: create base + fork if + # missing, reconcile config with on-disk meta, then re-pickle + # imanager so queued workers read the resolved config. 
+ resolved = setup_base(imanager.cg, OcdbtConfig.from_dict(imanager.ocdbt_config)) + imanager.ocdbt_config = resolved.to_dict() + imanager.redis.set(r_keys.INGESTION_MANAGER, imanager.serialized(pickled=True)) + + mode = "ocdbt" if ocdbt_only else ("ingest" if ingest_only else "full") + task_fn = ( + partial(create_parent_chunk, mode=mode) + if mode != "full" + else create_parent_chunk + ) + + # Coordinator only matters when OCDBT populate will actually run. + needs_coordinator = ( + is_populate_layer and mode in ("full", "ocdbt") and not queue_only + ) + if needs_coordinator: + with coordinator(imanager.redis): + queue_layer_helper(parent_layer, imanager, task_fn) + while True: + sleep(60) + else: + queue_layer_helper(parent_layer, imanager, task_fn) @ingest_cli.command("status") +@click.option("--refresh", type=int, default=5, help="Seconds between redis polls.") @job_type_guard(group_name) -def ingest_status(): +def ingest_status(refresh: int): """Print ingest status to console by layer.""" redis = get_redis_connection() try: imanager = IngestionManager.from_pickle(redis.get(r_keys.INGESTION_MANAGER)) - print_status(imanager, redis) + print_status(imanager, redis, refresh_seconds=refresh) except TypeError as err: print(f"\nNo current `{group_name}` job found in redis: {err}") @@ -177,23 +273,7 @@ def ingest_status(): @job_type_guard(group_name) def ingest_chunk(queue: str, chunk_info): """Manually queue chunk when a job is stuck for whatever reason.""" - redis = get_redis_connection() - imanager = IngestionManager.from_pickle(redis.get(r_keys.INGESTION_MANAGER)) - layer, coords = chunk_info[0], chunk_info[1:] - - func = create_parent_chunk - args = (layer, coords) - if layer == 2: - func = create_atomic_chunk - args = (coords,) - queue = imanager.get_task_queue(queue) - queue.enqueue( - func, - job_id=chunk_id_str(layer, coords), - job_timeout=f"{int(layer * layer)}m", - result_ttl=0, - args=args, - ) + requeue_chunk(queue, chunk_info, create_atomic_chunk, 
create_parent_chunk) @ingest_cli.command("chunk_local") @@ -228,3 +308,14 @@ def rate(layer: int, span: int): @job_type_guard(group_name) def run_tests(graph_id): run_all(ChunkedGraph(graph_id=graph_id)) + + +@ingest_cli.command("purge_layer") +@click.argument("layer", type=int) +@click.confirmation_option(prompt="Purge ALL redis state for this layer?") +@job_type_guard(group_name) +def purge_layer(layer: int): + """Drop the per-layer RQ queue + registries + completion set so the + layer can be re-run from a previous layer's backup.""" + purge_layer_state(get_redis_connection(), layer) + click.echo(f"purged redis state for layer {layer}") diff --git a/pychunkedgraph/ingest/cli_upgrade.py b/pychunkedgraph/ingest/cli_upgrade.py index 3a3ccb2e1..83e9e53c8 100644 --- a/pychunkedgraph/ingest/cli_upgrade.py +++ b/pychunkedgraph/ingest/cli_upgrade.py @@ -4,44 +4,30 @@ cli for running upgrade """ -from time import sleep - -from pychunkedgraph import get_logger - -logger = get_logger(__name__) - import click -import tensorstore as ts from flask.cli import AppGroup -from pychunkedgraph import __version__ + +from pychunkedgraph import __version__, get_logger from pychunkedgraph.graph.meta import GraphConfig from . 
import IngestConfig -from .cluster import ( - convert_edges_to_ocdbt, - enqueue_l2_tasks, - upgrade_atomic_chunk, - upgrade_parent_chunk, -) +from .cluster import enqueue_l2_tasks, upgrade_atomic_chunk, upgrade_parent_chunk from .manager import IngestionManager +from .ocdbt import setup_base from .utils import ( - chunk_id_str, + job_type_guard, print_completion_rate, print_status, queue_layer_helper, - start_ocdbt_server, - job_type_guard, + requeue_chunk, ) from ..graph.chunkedgraph import ChunkedGraph, ChunkedGraphMeta -from ..graph.ocdbt import ( - base_exists, - create_base_ocdbt, - fork_base_manifest, - wipe_base_ocdbt, -) +from ..graph.ocdbt import OcdbtConfig from ..utils.redis import get_redis_connection from ..utils.redis import keys as r_keys +logger = get_logger(__name__) + group_name = "upgrade" upgrade_cli = AppGroup(group_name) @@ -63,26 +49,18 @@ def flush_redis(): @click.argument("graph_id", type=str) @click.option("--test", is_flag=True, help="Test 8 chunks at the center of dataset.") @click.option("--ocdbt", is_flag=True, help="Enable ocdbt seg (SV splitting support).") -@click.option("--ocdbt-edges", is_flag=True, help="Convert edges to ocdbt kv store.") @click.option( "--sv-split-threshold", type=int, default=10, help="Distance threshold for SV split edge matching.", ) -@click.option( - "--reset-ocdbt", - is_flag=True, - help="Wipe base AND this CG's delta OCDBT, then recreate from scratch.", -) @job_type_guard(group_name) def upgrade_graph( graph_id: str, test: bool, ocdbt: bool, - ocdbt_edges: bool, sv_split_threshold: int, - reset_ocdbt: bool, ): """ Main upgrade command. Queues atomic tasks. 
@@ -103,38 +81,18 @@ def upgrade_graph( cg = ChunkedGraph(graph_id=graph_id) if ocdbt: - ws = cg.meta.data_source.WATERSHED - cg.meta.custom_data["seg"] = { - "ocdbt": True, - "sv_split_threshold": sv_split_threshold, - } - cg.update_meta(cg.meta, overwrite=True) + ocdbt_cfg = OcdbtConfig.from_dict(cg.meta.custom_data.get("ocdbt_config")) + ocdbt_cfg.enabled = True + ocdbt_cfg.sv_split_threshold = sv_split_threshold + setup_base(cg, ocdbt_cfg) logger.note(f"enabled ocdbt seg with sv_split_threshold={sv_split_threshold}") - - if reset_ocdbt: - wipe_base_ocdbt(ws) - - if not base_exists(ws): - create_base_ocdbt(ws) - - fork_base_manifest(ws, graph_id, wipe_existing=reset_ocdbt) try: cg.client.create_column_family("4") except Exception: ... imanager = IngestionManager(ingest_config, cg.meta) - if ocdbt_edges: - server = ts.ocdbt.DistributedCoordinatorServer() - start_ocdbt_server(imanager, server) - - fn = convert_edges_to_ocdbt if ocdbt_edges else upgrade_atomic_chunk - enqueue_l2_tasks(imanager, fn) - - if ocdbt_edges: - logger.note("All tasks queued. 
Keep this alive for ocdbt coordinator server.") - while True: - sleep(60) + enqueue_l2_tasks(imanager, upgrade_atomic_chunk) @upgrade_cli.command("layer") @@ -153,13 +111,14 @@ def queue_layer(parent_layer: int, splits: int = 0): @upgrade_cli.command("status") +@click.option("--refresh", type=int, default=5, help="Seconds between redis polls.") @job_type_guard(group_name) -def upgrade_status(): +def upgrade_status(refresh: int): """Print upgrade status to console.""" redis = get_redis_connection() try: imanager = IngestionManager.from_pickle(redis.get(r_keys.INGESTION_MANAGER)) - print_status(imanager, redis, upgrade=True) + print_status(imanager, redis, upgrade=True, refresh_seconds=refresh) except TypeError as err: print(f"\nNo current `{group_name}` job found in redis: {err}") @@ -170,23 +129,7 @@ def upgrade_status(): @job_type_guard(group_name) def upgrade_chunk(queue: str, chunk_info): """Manually queue chunk when a job is stuck for whatever reason.""" - redis = get_redis_connection() - imanager = IngestionManager.from_pickle(redis.get(r_keys.INGESTION_MANAGER)) - layer, coords = chunk_info[0], chunk_info[1:] - - func = upgrade_parent_chunk - args = (layer, coords) - if layer == 2: - func = upgrade_atomic_chunk - args = (coords,) - queue = imanager.get_task_queue(queue) - queue.enqueue( - func, - job_id=chunk_id_str(layer, coords), - job_timeout=f"{int(layer * layer)}m", - result_ttl=0, - args=args, - ) + requeue_chunk(queue, chunk_info, upgrade_atomic_chunk, upgrade_parent_chunk) @upgrade_cli.command("rate") diff --git a/pychunkedgraph/ingest/cluster.py b/pychunkedgraph/ingest/cluster.py index 4e8149ead..36d111f1a 100644 --- a/pychunkedgraph/ingest/cluster.py +++ b/pychunkedgraph/ingest/cluster.py @@ -5,19 +5,20 @@ """ from os import environ - -from pychunkedgraph import get_logger - -logger = get_logger(__name__) from time import sleep from typing import Callable, Dict, Iterable, Tuple, Sequence import numpy as np from rq import Queue as RQueue, Retry +from 
pychunkedgraph import get_logger + +logger = get_logger(__name__) + from .utils import chunk_id_str, get_chunks_not_done, randomize_grid_points from .manager import IngestionManager +from .ocdbt import get_coordinator_address, populate_chunk from .ran_agglomeration import ( get_active_edges, read_raw_edge_data, @@ -27,11 +28,10 @@ from .create.parent_layer import add_parent_chunk from .upgrade.atomic_layer import update_chunk as update_atomic_chunk from .upgrade.parent_layer import update_chunk as update_parent_chunk -from ..graph.edges import EDGE_TYPES, Edges, put_edges +from ..graph.edges import EDGE_TYPES from ..graph import ChunkedGraph, ChunkedGraphMeta -from ..graph.ocdbt import copy_ws_chunk_multiscale, open_base_ocdbt +from ..graph.ocdbt import is_chunk_populated from ..graph.chunks.hierarchy import get_children_chunk_coords -from ..graph.basetypes import NODE_ID from ..io.edges import get_chunk_edges from ..io.components import get_chunk_components from ..utils.redis import keys as r_keys, get_redis_connection @@ -64,18 +64,47 @@ def _post_task_completion( def create_parent_chunk( parent_layer: int, parent_coords: Sequence[int], + mode: str = "full", ) -> None: + """One parent-chunk task. ``mode`` (bound at queue time via partial) + selects which halves run: + ``full`` : OCDBT populate (if eligible) + add_parent_chunk + ``ocdbt`` : only OCDBT populate (skip add_parent_chunk) + ``ingest`` : only add_parent_chunk (skip OCDBT populate) + + ``_post_task_completion`` always runs so the layer's progress tracking + in redis stays consistent. + + OCDBT populate runs FIRST so any failure aborts the task BEFORE graph + mutation; otherwise a half-built graph would force corrupt-state retries. 
+ """ imanager = _get_imanager() - add_parent_chunk( - imanager.cg, - parent_layer, - parent_coords, - get_children_chunk_coords( - imanager.cg_meta, + + do_ocdbt = mode in ("full", "ocdbt") and imanager.is_ocdbt_populate_layer( + parent_layer + ) + do_ingest = mode in ("full", "ingest") + + if do_ocdbt: + ws = imanager.cg.meta.data_source.WATERSHED + if not is_chunk_populated(ws, parent_layer, parent_coords): + address = get_coordinator_address(imanager.redis) + populate_chunk( + imanager, ws, parent_layer, parent_coords, coordinator_address=address + ) + + if do_ingest: + add_parent_chunk( + imanager.cg, parent_layer, parent_coords, - ), - ) + get_children_chunk_coords( + imanager.cg_meta, + parent_layer, + parent_coords, + ), + ) + _post_task_completion(imanager, parent_layer, parent_coords) @@ -146,21 +175,6 @@ def create_atomic_chunk(coords: Sequence[int]): for k, v in chunk_edges_active.items(): logger.debug(f"active_{k}: {len(v)}") - if imanager.ocdbt_seg and imanager.ocdbt_populate_base: - # Populate the shared base OCDBT with precomputed chunks (one-time - # per watershed). Uses the raw base handles, NOT the per-CG fork - # spec — the fork only stores SV-split deltas. - src_list, dst_list, resolutions = open_base_ocdbt( - imanager.cg.meta.data_source.WATERSHED - ) - copy_ws_chunk_multiscale( - src_list, - dst_list, - resolutions, - imanager.cg.meta.graph_config.CHUNK_SIZE, - coords, - imanager.cg.meta.voxel_bounds, - ) _post_task_completion(imanager, 2, coords) @@ -172,47 +186,6 @@ def upgrade_atomic_chunk(coords: Sequence[int]): _post_task_completion(imanager, 2, coords) -def convert_edges_to_ocdbt(coords: Sequence[int]): - """ - Convert edges stored per chunk to ajacency list in the tensorstore ocdbt kv store. 
- """ - imanager = _get_imanager() - coords = np.array(list(coords), dtype=int) - chunk_edges_all, mapping = _get_atomic_chunk_data(imanager, coords) - - node_ids1 = [] - node_ids2 = [] - affinities = [] - areas = [] - for edges in chunk_edges_all.values(): - node_ids1.extend(edges.node_ids1) - node_ids2.extend(edges.node_ids2) - affinities.extend(edges.affinities) - areas.extend(edges.areas) - - edges = Edges(node_ids1, node_ids2, affinities=affinities, areas=areas) - nodes = np.concatenate( - [edges.node_ids1, edges.node_ids2, np.fromiter(mapping.keys(), dtype=NODE_ID)] - ) - nodes = np.unique(nodes) - - chunk_id = imanager.cg.get_chunk_id(layer=1, x=coords[0], y=coords[1], z=coords[2]) - chunk_ids = imanager.cg.get_chunk_ids_from_node_ids(nodes) - - host = imanager.redis.get("OCDBT_COORDINATOR_HOST").decode() - port = imanager.redis.get("OCDBT_COORDINATOR_PORT").decode() - environ["OCDBT_COORDINATOR_HOST"] = host - environ["OCDBT_COORDINATOR_PORT"] = port - logger.note(f"OCDBT Coordinator address {host}:{port}") - - put_edges( - f"{imanager.cg.meta.data_source.EDGES}/ocdbt", - nodes[chunk_ids == chunk_id], - edges, - ) - _post_task_completion(imanager, 2, coords) - - def _get_test_chunks(meta: ChunkedGraphMeta): """Chunks at the center most likely not to be empty""" parent_coords = np.array(meta.layer_chunk_bounds[3]) // 2 diff --git a/pychunkedgraph/ingest/create/atomic_layer.py b/pychunkedgraph/ingest/create/atomic_layer.py index 30043710d..69ea8a709 100644 --- a/pychunkedgraph/ingest/create/atomic_layer.py +++ b/pychunkedgraph/ingest/create/atomic_layer.py @@ -11,6 +11,8 @@ import numpy as np +from pychunkedgraph import get_logger + from ...graph import attributes, basetypes, serializers, get_valid_timestamp from ...graph.chunkedgraph import ChunkedGraph from ...graph.edges import Edges @@ -19,6 +21,8 @@ from ...graph.utils.flatgraph import build_gt_graph from ...graph.utils.flatgraph import connected_components +logger = get_logger(__name__) + def 
add_atomic_chunk( cg: ChunkedGraph, @@ -28,6 +32,10 @@ def add_atomic_chunk( time_stamp: Optional[datetime.datetime] = None, ): chunk_node_ids, chunk_edge_ids = _get_chunk_nodes_and_edges(chunk_edges_d, isolated) + logger.note( + f"L2 chunk {tuple(coords)}: nodes={len(chunk_node_ids):,} " + f"edges={len(chunk_edge_ids):,}" + ) if not chunk_node_ids.size: return diff --git a/pychunkedgraph/ingest/create/parent_layer.py b/pychunkedgraph/ingest/create/parent_layer.py index a12d2b858..1d36f9edd 100644 --- a/pychunkedgraph/ingest/create/parent_layer.py +++ b/pychunkedgraph/ingest/create/parent_layer.py @@ -12,6 +12,9 @@ import fastremap import numpy as np + +from pychunkedgraph import get_logger + from ...graph import types, attributes, basetypes, serializers, get_valid_timestamp from ...utils.general import chunked from ...graph.utils import flatgraph @@ -22,6 +25,8 @@ from .cross_edges import get_children_chunk_cross_edges from .cross_edges import get_chunk_nodes_cross_edge_layer +logger = get_logger(__name__) + def add_parent_chunk( cg: ChunkedGraph, @@ -50,6 +55,11 @@ def add_parent_chunk( raw_ccs = flatgraph.connected_components(graph) # connected components with indices connected_components = [graph_ids[cc] for cc in raw_ccs] + logger.note( + f"L{layer_id} chunk {tuple(coords)}: nodes={len(connected_components):,} " + f"cx_edges={len(cx_edges):,}" + ) + _write_connected_components( cg, layer_id, diff --git a/pychunkedgraph/ingest/manager.py b/pychunkedgraph/ingest/manager.py index 915538320..566558e05 100644 --- a/pychunkedgraph/ingest/manager.py +++ b/pychunkedgraph/ingest/manager.py @@ -15,8 +15,7 @@ def __init__( self, config: IngestConfig, chunkedgraph_meta: ChunkedGraphMeta, - ocdbt_seg: bool = False, - ocdbt_populate_base: bool = False, + ocdbt_config: dict = None, _from_pickle: bool = False, ): self._config = config @@ -25,8 +24,7 @@ def __init__( self._redis = None self._task_queues = {} self._from_pickle = _from_pickle - self.ocdbt_seg = ocdbt_seg - 
self.ocdbt_populate_base = ocdbt_populate_base + self.ocdbt_config = ocdbt_config or {} if not _from_pickle: # initiate redis and store serialized state @@ -55,12 +53,34 @@ def redis(self): self._redis.set(r_keys.INGESTION_MANAGER, self.serialized(pickled=True)) return self._redis + @property + def ocdbt_seg(self) -> bool: + return bool(self.ocdbt_config.get("enabled")) + + @property + def ocdbt_populate_base(self) -> bool: + return bool(self.ocdbt_config.get("populate_base")) + + @property + def ocdbt_populate_layer(self) -> int: + return int(self.ocdbt_config.get("populate_layer", 3)) + + def is_ocdbt_populate_layer(self, layer: int) -> bool: + """True iff OCDBT is enabled, base-populate is on, AND the given + layer matches the configured populate layer. Single guard for any + code that branches on 'should this layer touch OCDBT?'. + """ + return ( + self.ocdbt_seg + and self.ocdbt_populate_base + and layer == self.ocdbt_populate_layer + ) + def serialized(self, pickled=False): params = { "config": self._config, "chunkedgraph_meta": self._chunkedgraph_meta, - "ocdbt_seg": self.ocdbt_seg, - "ocdbt_populate_base": self.ocdbt_populate_base, + "ocdbt_config": self.ocdbt_config, } if pickled: return pickle.dumps(params) diff --git a/pychunkedgraph/ingest/ocdbt.py b/pychunkedgraph/ingest/ocdbt.py new file mode 100644 index 000000000..a8b10666f --- /dev/null +++ b/pychunkedgraph/ingest/ocdbt.py @@ -0,0 +1,128 @@ +"""OCDBT-specific ingest helpers. 
+ +Single home for everything OCDBT-related at the ingest layer: + * coordinator-server lifecycle (`coordinator`) + * per-chunk populate task (`populate_chunk`), used from `create_parent_chunk` + * shared base setup (`setup_base`), used by both ingest and upgrade CLIs +""" + +from contextlib import contextmanager +from os import environ + +import tensorstore as ts + +from pychunkedgraph import get_logger + +from ..graph.ocdbt import ( + OcdbtConfig, + _layer_bbox, + base_exists, + copy_ws_bbox_multiscale, + create_base_ocdbt, + fork_base_manifest, + mark_chunk_populated, + open_base_ocdbt, + read_populate_meta, + write_populate_meta, +) + +logger = get_logger(__name__) + +_COORD_HOST_KEY = "OCDBT_COORDINATOR_HOST" +_COORD_PORT_KEY = "OCDBT_COORDINATOR_PORT" + + +@contextmanager +def coordinator(redis): + """Start a ``DistributedCoordinatorServer`` and advertise its address in + Redis so parallel populate workers route every OCDBT commit through this + one server — no manifest-CAS races, no orphan ``d/`` files. + + The server lives as long as the ``with`` block does; on exit the Redis + advertisement is cleared so a stale address can't outlive the server. + Caller blocks inside the ``with`` body (e.g. ``while True: sleep(60)``) + to keep the server reference alive across the populate phase. + """ + server = ts.ocdbt.DistributedCoordinatorServer() + host = environ.get("MY_POD_IP", "localhost") + redis.set(_COORD_HOST_KEY, host) + redis.set(_COORD_PORT_KEY, str(server.port)) + logger.note(f"OCDBT Coordinator listening at {host}:{server.port}") + try: + yield server + finally: + redis.delete(_COORD_HOST_KEY, _COORD_PORT_KEY) + logger.note("OCDBT Coordinator advertisement cleared.") + + +def get_coordinator_address(redis) -> str: + """Return the advertised ``"host:port"`` for the OCDBT coordinator. 
+ + The address goes into the OCDBT kvstore spec's ``coordinator`` field — + the only routing knob tensorstore actually honors (verified against + the tensorstore binary; ``OCDBT_COORDINATOR_HOST/PORT`` env vars are + not consulted). + + Distributed callers MUST go through this getter so the populate fails + loudly when the coordinator isn't advertised — uncoordinated parallel + commits race the shared manifest and leak orphan ``d/`` files, the + exact bug this code exists to prevent. + """ + host = redis.get(_COORD_HOST_KEY) + port = redis.get(_COORD_PORT_KEY) + if not host or not port: + raise RuntimeError( + "OCDBT coordinator address not advertised in Redis " + f"({_COORD_HOST_KEY}/{_COORD_PORT_KEY} unset). " + "Run `flask ingest layer N` (with N == ocdbt_populate_layer) to " + "start the coordinator before queuing populate workers." + ) + return f"{host.decode()}:{port.decode()}" + + +def populate_chunk( + imanager, ws: str, layer: int, coords, coordinator_address: str | None = None +) -> None: + """One LN parent-layer task's OCDBT populate. + + When ``coordinator_address`` is set, every commit routes through that + server (mandatory for distributed workers — see ``get_coordinator_address``). + Single-process callers (notebooks, local one-off runs) can omit it and + write directly; safe as long as no other writer is committing concurrently. + + Copies the base-resolution bbox at every scale under one atomic + transaction and records the per-chunk completion marker. 
+ """ + cfg = OcdbtConfig.from_dict(imanager.ocdbt_config) + src_list, dst_list, resolutions = open_base_ocdbt( + ws, cfg, coordinator_address=coordinator_address + ) + lo, hi = _layer_bbox(imanager.cg.meta, layer, coords) + coord_str = "_".join(str(int(c)) for c in coords) + dump_tag = f"{imanager.cg.meta.graph_id}/L{layer}/{coord_str}" + logger.note(f"L{layer} OCDBT populate {tuple(int(c) for c in coords)}") + copy_ws_bbox_multiscale(src_list, dst_list, resolutions, lo, hi, dump_tag=dump_tag) + mark_chunk_populated(ws, layer, coords) + + +def setup_base(cg, ocdbt_cfg: OcdbtConfig) -> OcdbtConfig: + """Idempotent OCDBT base + fork setup, shared by ingest and upgrade. + + Creates the base if missing; reconciles the yaml/CLI-supplied config + with the on-disk populate_meta (info-file wins per + ``OcdbtConfig.resolve``); persists the resolved config to + ``cg.meta.custom_data["ocdbt_config"]``; forks the manifest for this + CG. Returns the resolved OcdbtConfig. To wipe and start over, use + ``gcloud storage rm -r gs:///ocdbt/`` before invoking. 
+ """ + ws = cg.meta.data_source.WATERSHED + if not base_exists(ws): + create_base_ocdbt(ws, ocdbt_cfg) + info = read_populate_meta(ws) + resolved = OcdbtConfig.resolve(ocdbt_cfg.to_dict(), info) + if resolved.populate_base: + write_populate_meta(ws, resolved.to_dict()) + cg.meta.custom_data["ocdbt_config"] = resolved.to_dict() + cg.update_meta(cg.meta, overwrite=True) + fork_base_manifest(ws, cg.meta.graph_id) + return resolved diff --git a/pychunkedgraph/ingest/utils.py b/pychunkedgraph/ingest/utils.py index d69756104..fbd55dedc 100644 --- a/pychunkedgraph/ingest/utils.py +++ b/pychunkedgraph/ingest/utils.py @@ -1,29 +1,45 @@ # pylint: disable=invalid-name, missing-docstring import functools - -from pychunkedgraph import get_logger - -logger = get_logger(__name__) -import math, random, sys +import math +import sys from os import environ from time import sleep -from typing import Any, Generator, Tuple +from typing import Dict, Generator, Tuple import numpy as np -import tensorstore as ts -from rq import Queue, Retry, Worker -from rq.worker import WorkerStatus +from kvdbclient import BigTableConfig, HBaseConfig +from rich import box +from rich.console import Group +from rich.live import Live +from rich.panel import Panel +from rich.rule import Rule +from rich.table import Table +from rich.text import Text +from rq import Queue, Retry +from rq.registry import ( + CanceledJobRegistry, + DeferredJobRegistry, + FailedJobRegistry, + FinishedJobRegistry, + ScheduledJobRegistry, + StartedJobRegistry, +) +from rq.worker_registration import WORKERS_BY_QUEUE_KEY + +from pychunkedgraph import get_logger from . 
import IngestConfig from .manager import IngestionManager -from ..graph.meta import ChunkedGraphMeta, DataSource, GraphConfig from ..graph import BackendClientInfo -from kvdbclient import BigTableConfig, HBaseConfig +from ..graph.meta import ChunkedGraphMeta, DataSource, GraphConfig +from ..graph.ocdbt import OcdbtConfig from ..utils.general import chunked from ..utils.redis import get_redis_connection from ..utils.redis import keys as r_keys +logger = get_logger(__name__) + chunk_id_str = lambda layer, coords: f"{layer}_{'_'.join(map(str, coords))}" @@ -32,8 +48,13 @@ def bootstrap( config: dict, raw: bool = False, test_run: bool = False, -) -> Tuple[ChunkedGraphMeta, IngestConfig, BackendClientInfo]: - """Parse config loaded from a yaml file.""" +) -> Tuple[ChunkedGraphMeta, IngestConfig, BackendClientInfo, Dict]: + """Parse config loaded from a yaml file. + + Returns ``(meta, ingest_config, client_info, ocdbt_config_dict)`` where the + ocdbt config dict is sanitized through ``OcdbtConfig.from_dict(...).to_dict()`` + so unknown yaml keys are dropped and missing fields take dataclass defaults. 
+ """ ingest_config = IngestConfig( **config.get("ingest_config", {}), USE_RAW_EDGES=raw, @@ -55,7 +76,8 @@ def bootstrap( data_source = DataSource(**config["data_source"]) meta = ChunkedGraphMeta(graph_config, data_source) - return (meta, ingest_config, client_info) + ocdbt_config_dict = OcdbtConfig.from_dict(config.get("ocdbt_config")).to_dict() + return (meta, ingest_config, client_info, ocdbt_config_dict) def move_up(lines: int = 1): @@ -95,16 +117,6 @@ def postprocess_edge_data(im, edge_dict): raise ValueError(f"Unknown data_version: {data_version}") -def start_ocdbt_server(imanager: IngestionManager, server: Any): - spec = {"driver": "ocdbt", "base": f"{imanager.cg.meta.data_source.EDGES}/ocdbt"} - spec["coordinator"] = {"address": f"localhost:{server.port}"} - ts.KvStore.open(spec).result() - imanager.redis.set("OCDBT_COORDINATOR_PORT", str(server.port)) - ocdbt_host = environ.get("MY_POD_IP", "localhost") - imanager.redis.set("OCDBT_COORDINATOR_HOST", ocdbt_host) - logger.note(f"OCDBT Coordinator address {ocdbt_host}:{server.port}") - - def randomize_grid_points(X: int, Y: int, Z: int) -> Generator[int, int, int]: indices = np.arange(X * Y * Z) np.random.shuffle(indices) @@ -148,64 +160,237 @@ def print_completion_rate(imanager: IngestionManager, layer: int, span: int = 30 move_up() -def print_status(imanager: IngestionManager, redis, upgrade: bool = False): +def _workers_busy_per_queue(redis, worker_keys_per_layer): + """For each layer's set of worker keys, return parallel (workers, busy) + string lists — "-" / "-" when no workers are registered for that layer. + + Two-round-trip approach: caller already fetched the SMEMBERS sets; this + function pipelines HGET state for every worker key and counts busy. 
+ """ + state_pipe = redis.pipeline() + for keys in worker_keys_per_layer: + for wk in keys: + state_pipe.hget(wk, "state") + states = state_pipe.execute() if any(worker_keys_per_layer) else [] + + workers, busy = [], [] + idx = 0 + for keys in worker_keys_per_layer: + total = len(keys) + b = 0 + for _ in keys: + if states[idx] == b"busy": + b += 1 + idx += 1 + workers.append(f"{total}" if total else "-") + busy.append(f"{b}" if total else "-") + return workers, busy + + +def _layer_keys(layers) -> list: + """Stable per-layer redis keys (completed-set, queue list, failed zset, workers set). + + Returned once before the refresh loop so each refresh skips Queue / + FailedJobRegistry construction and the lazy rq.registry import. """ - Helper to print status to console. + return [ + ( + f"{layer}c", + f"rq:queue:l{layer}", + f"rq:failed:l{layer}", + WORKERS_BY_QUEUE_KEY % f"l{layer}", + ) + for layer in layers + ] + + +def _layer_status(redis, layer_keys): + """Pipelined fetch of job_type + per-layer counts + busy-worker ratios.""" + pipeline = redis.pipeline() + pipeline.get(r_keys.JOB_TYPE) + for completed_key, queue_key, failed_key, workers_key in layer_keys: + pipeline.scard(completed_key) + pipeline.llen(queue_key) + pipeline.zcard(failed_key) + pipeline.smembers(workers_key) + results = pipeline.execute() + + job_type = results[0].decode() if results[0] else "not_available" + completed, queued, failed, worker_keys_per_layer = [], [], [], [] + for i in range(1, len(results), 4): + completed.append(results[i]) + queued.append(results[i + 1]) + failed.append(results[i + 2]) + worker_keys_per_layer.append(results[i + 3]) + + workers, busy = _workers_busy_per_queue(redis, worker_keys_per_layer) + return job_type, completed, queued, failed, workers, busy + + +def _sized_table(columns: list, rows: list, **table_kwargs) -> Table: + """Build a Rich Table whose column widths are sized to the actual data. + + `columns` is a list of (name, justify) tuples. 
+ `rows` is a list of tuples of cell strings (one per column). + Each column gets width = max(len(name), max(len(cell)) over rows) so Rich + never wraps or crops because no column is implicitly squeezed. + """ + table = Table( + box=None, + pad_edge=False, + padding=(0, 2), + show_header=True, + header_style="bold", + **table_kwargs, + ) + for col_idx, (name, justify) in enumerate(columns): + width = max(len(name), max((len(row[col_idx]) for row in rows), default=0)) + # Header wrapped in Text so any brackets in `name` render literally + # rather than being parsed as Rich markup tags. + table.add_column( + Text(name, style="bold"), justify=justify, width=width, no_wrap=True + ) + for row in rows: + table.add_row(*row) + return table + + +def _aligned_kv_table(pairs: list, widths: list) -> Table: + """One-data-row mini-table with externally-provided per-column widths.""" + table = Table( + box=None, pad_edge=False, padding=(0, 1), show_header=True, header_style="bold" + ) + for (name, _), w in zip(pairs, widths): + table.add_column(name, justify="left", width=w, no_wrap=True) + table.add_row(*(v for _, v in pairs)) + return table + + +def _header_renderables(imanager: IngestionManager) -> list: + """Graph and ocdbt rows as mini-tables sharing column widths so columns line up.""" + graph_pairs = [ + ("version", str(imanager.cg.version)), + ("graph_id", imanager.cg.graph_id), + ("chunk_size", str(imanager.cg.meta.graph_config.CHUNK_SIZE)), + ] + ocdbt_pairs = [] + if imanager.ocdbt_seg: + ocdbt_pairs = [ + ("ocdbt", str(imanager.ocdbt_seg)), + ("populate_base", str(imanager.ocdbt_populate_base)), + ("populate_layer", str(imanager.ocdbt_populate_layer)), + ] + + # Per-column width = max length seen in EITHER row's header or value at that index. 
+ n = max(len(graph_pairs), len(ocdbt_pairs)) + widths = [] + for i in range(n): + sizes = [] + if i < len(graph_pairs): + sizes.append(len(graph_pairs[i][0])) + sizes.append(len(graph_pairs[i][1])) + if i < len(ocdbt_pairs): + sizes.append(len(ocdbt_pairs[i][0])) + sizes.append(len(ocdbt_pairs[i][1])) + widths.append(max(sizes)) + + out = [_aligned_kv_table(graph_pairs, widths)] + if ocdbt_pairs: + out.append(Rule(style="dim")) + out.append(_aligned_kv_table(ocdbt_pairs, widths)) + return out + + +def _status_table( + layers, layer_counts, completed, queued, failed, workers, busy +) -> Table: + """One row per layer with progress, queue, and worker stats.""" + columns = [ + ("layer", "center"), + ("queued", "right"), + ("completed", "right"), + ("total", "right"), + ("progress", "right"), + ("failed", "right"), + ("workers", "right"), + ("busy", "right"), + ] + rows = [] + for layer, done, count, q, f, w, b in zip( + layers, completed, layer_counts, queued, failed, workers, busy + ): + pct = math.floor((done / count) * 100) if count else 0 + rows.append( + ( + str(layer), + f"{q:,}", + f"{done:,}", + f"{count:,}", + f"{pct}%", + f"{f:,}", + str(w), + str(b), + ) + ) + return _sized_table(columns, rows) + + +def _status_renderable( + imanager, + layers, + layer_counts, + job_type, + completed, + queued, + failed, + workers, + busy, +): + """Combine header rows + per-layer table inside one Panel; job_type goes in the title.""" + body = Group( + *_header_renderables(imanager), + Rule(style="dim"), + _status_table(layers, layer_counts, completed, queued, failed, workers, busy), + ) + return Panel( + body, + title=job_type, + title_align="left", + box=box.ROUNDED, + padding=(0, 1), + expand=False, + ) + + +def print_status( + imanager: IngestionManager, + redis, + upgrade: bool = False, + refresh_seconds: int = 5, +): + """ + Print status to console. If `upgrade=True`, status does not include the root layer, since there is no need to update cross edges for root ids. 
+ `refresh_seconds` is how often redis is re-polled between redraws. """ layers = range(2, imanager.cg_meta.layer_count + 1) if upgrade: layers = range(2, imanager.cg_meta.layer_count) - - def _refresh_status(): - pipeline = redis.pipeline() - pipeline.get(r_keys.JOB_TYPE) - worker_busy = ["-"] * len(layers) - for layer in layers: - pipeline.scard(f"{layer}c") - queue = Queue(f"l{layer}", connection=redis) - pipeline.llen(queue.key) - pipeline.zcard(queue.failed_job_registry.key) - - results = pipeline.execute() - job_type = "not_available" - if results[0] is not None: - job_type = results[0].decode() - completed = [] - queued = [] - failed = [] - for i in range(1, len(results), 3): - result = results[i : i + 3] - completed.append(result[0]) - queued.append(result[1]) - failed.append(result[2]) - return job_type, completed, queued, failed, worker_busy - - job_type, completed, queued, failed, worker_busy = _refresh_status() - layer_counts = imanager.cg_meta.layer_chunk_counts - header = ( - f"\njob_type: \t{job_type}" - f"\nversion: \t{imanager.cg.version}" - f"\ngraph_id: \t{imanager.cg.graph_id}" - f"\nchunk_size: \t{imanager.cg.meta.graph_config.CHUNK_SIZE}" - "\n\nlayer status:" - ) - print(header) - while True: - for layer, done, count in zip(layers, completed, layer_counts): - print( - f"{layer}\t| {done:9} / {count} \t| {math.floor((done/count)*100):6}%" - ) + layer_keys = _layer_keys(layers) - print("\n\nqueue status:") - for layer, q, f, wb in zip(layers, queued, failed, worker_busy): - print(f"l{layer}\t| queued: {q:<10} failed: {f:<10} busy: {wb}") + def render(): + return _status_renderable( + imanager, layers, layer_counts, *_layer_status(redis, layer_keys) + ) - sleep(1) - _, completed, queued, failed, worker_busy = _refresh_status() - move_up(lines=2 * len(layers) + 3) + # Start Live with a placeholder so the panel paints instantly; the first + # real fetch (which includes redis connection setup) replaces it. 
+ with Live(Text("loading…"), screen=False) as live: + while True: + live.update(render()) + sleep(refresh_seconds) def queue_layer_helper( @@ -267,6 +452,54 @@ def queue_layer_helper( logger.note(f"Queued {len(job_datas)} chunks.") +_RQ_REGISTRY_CLASSES = ( + FailedJobRegistry, + StartedJobRegistry, + DeferredJobRegistry, + ScheduledJobRegistry, + FinishedJobRegistry, + CanceledJobRegistry, +) + + +def purge_layer_state(redis, layer: int) -> None: + """Reset per-layer state so a layer can be re-run from a previous + layer's backup: drop the RQ queue (deletes jobs too), wipe each RQ + registry by its own ``.key`` attribute (so we don't hardcode RQ's + internal key naming), and clear the pychunkedgraph completion set + ``f"{layer}c"``. + """ + name = f"l{layer}" + Queue(name=name, connection=redis).delete(delete_jobs=True) + for cls in _RQ_REGISTRY_CLASSES: + redis.delete(cls(name=name, connection=redis).key) + redis.delete(f"{layer}c") + + +def requeue_chunk(queue_name: str, chunk_info, atomic_fn, parent_fn): + """Body of the ``chunk`` CLI command (shared by ingest and upgrade). + + Loads the manager from Redis, dispatches ``atomic_fn`` for L2 or + ``parent_fn`` for L3+, and enqueues a single task with the standard + job_id / timeout convention. 
+ """ + redis = get_redis_connection() + imanager = IngestionManager.from_pickle(redis.get(r_keys.INGESTION_MANAGER)) + layer, coords = chunk_info[0], chunk_info[1:] + if layer == 2: + fn, args = atomic_fn, (coords,) + else: + fn, args = parent_fn, (layer, coords) + queue = imanager.get_task_queue(queue_name) + queue.enqueue( + fn, + job_id=chunk_id_str(layer, coords), + job_timeout=f"{int(layer * layer)}m", + result_ttl=0, + args=args, + ) + + def job_type_guard(job_type: str): def decorator_job_type_guard(func): @functools.wraps(func) diff --git a/pychunkedgraph/meshing/manifest/cache.py b/pychunkedgraph/meshing/manifest/cache.py index f38a830c2..0decc3a65 100644 --- a/pychunkedgraph/meshing/manifest/cache.py +++ b/pychunkedgraph/meshing/manifest/cache.py @@ -12,6 +12,8 @@ DOES_NOT_EXIST = "X" INITIAL_PATH_PREFIX = "initial_path_prefix" +MANIFEST_TTL_SECONDS = 3 * 24 * 3600 + REDIS_HOST = os.environ.get("MANIFEST_CACHE_REDIS_HOST", "localhost") REDIS_PORT = os.environ.get("MANIFEST_CACHE_REDIS_PORT", "6379") REDIS_PASSWORD = os.environ.get("MANIFEST_CACHE_REDIS_PASSWORD", "") @@ -74,6 +76,27 @@ def clear_fragments(self, node_ids) -> None: keys = [f"{self.namespace}:{n}" for n in node_ids] REDIS.delete(*keys) + def clear_namespace(self, batch_size: int = 1000) -> int: + """Delete every key under this graph_id's namespace. + + SCAN-based pattern delete (non-blocking on the redis side). + Returns the number of keys deleted. 
+ """ + if REDIS is None: + return 0 + + pattern = f"{self.namespace}:*" + deleted = 0 + batch = [] + for key in REDIS.scan_iter(match=pattern, count=batch_size): + batch.append(key) + if len(batch) >= batch_size: + deleted += REDIS.delete(*batch) + batch.clear() + if batch: + deleted += REDIS.delete(*batch) + return deleted + def _get_cached_initial_fragments(self, node_ids: List[np.uint64]): if REDIS is None: return {}, node_ids, [] @@ -140,10 +163,16 @@ def _set_cached_initial_fragments( for node_id, fragment_info in fragments_d.items(): path, offset, size = fragment_info key = f"{self.namespace}:{node_id}" - pipeline.set(key, f"{path[prefix_idx:]}:{offset}:{size}") + pipeline.set( + key, f"{path[prefix_idx:]}:{offset}:{size}", ex=MANIFEST_TTL_SECONDS + ) for node_id in not_existing: - pipeline.set(f"{self.namespace}:{node_id}", DOES_NOT_EXIST) + pipeline.set( + f"{self.namespace}:{node_id}", + DOES_NOT_EXIST, + ex=MANIFEST_TTL_SECONDS, + ) pipeline.execute() @@ -155,9 +184,15 @@ def _set_cached_dynamic_fragments( pipeline = REDIS.pipeline() for node_id, fragment in fragments_d.items(): - pipeline.set(f"{self.namespace}:{node_id}", fragment) + pipeline.set( + f"{self.namespace}:{node_id}", fragment, ex=MANIFEST_TTL_SECONDS + ) for node_id in not_existing: - pipeline.set(f"{self.namespace}:{node_id}", DOES_NOT_EXIST) + pipeline.set( + f"{self.namespace}:{node_id}", + DOES_NOT_EXIST, + ex=MANIFEST_TTL_SECONDS, + ) pipeline.execute() diff --git a/pychunkedgraph/meshing/meshgen_utils.py b/pychunkedgraph/meshing/meshgen_utils.py index 1a35f5f73..db8f6d092 100644 --- a/pychunkedgraph/meshing/meshgen_utils.py +++ b/pychunkedgraph/meshing/meshgen_utils.py @@ -145,7 +145,15 @@ def get_json_info(cg): dataset_info = cg.meta.dataset_info dummy_app_info = {"app": {"supported_api_versions": [0, 1]}} info = {**dataset_info, **dummy_app_info} - info["mesh"] = cg.meta.custom_data.get("mesh", {}).get("dir", "graphene_meshes") + mesh_meta = cg.meta.custom_data.get("mesh", {}) + 
info["mesh"] = mesh_meta.get("dir", "graphene_meshes") + # `dynamic_mesh_dir` lets a dataset name the unsharded dynamic-mesh + # subdir explicitly. Default `"dynamic"` matches the mesh worker's + # fallback and NG's current hardcoded subdir name — see the + # spelunker-ocdbt graphene backend (looks up + # `dynamic/`). NG must be patched to read + # this info field before non-default values route correctly. + info["dynamic_mesh_dir"] = mesh_meta.get("dynamic_mesh_dir", "dynamic") info_str = dumps(info) return loads(info_str) diff --git a/pychunkedgraph/meshing/meta.py b/pychunkedgraph/meshing/meta.py new file mode 100644 index 000000000..2969ea19a --- /dev/null +++ b/pychunkedgraph/meshing/meta.py @@ -0,0 +1,68 @@ +"""MeshConfig dataclass — single source of truth for per-CG mesh setup values. + +Read from the dataset yaml under a ``mesh_config:`` block, exactly like +``OcdbtConfig`` is read from ``ocdbt_config:``. Every static field is +required — the helper that applies it (``setup_mesh_meta``) does not +substitute defaults for missing yaml entries. The only optional field +is :attr:`dynamic_mesh_dir`, which is graph-id-derived and filled in by +:meth:`with_graph_id` when omitted from the yaml. + +Example yaml block:: + + mesh_config: + dir: graphene_meshes + mip: 0 + max_layer: 6 + max_error: 40 + chunk_size: [512, 512, 256] + minishard_bits: {2: 1, 3: 3, 4: 6, 5: 9, 6: 12} + # dynamic_mesh_dir: my_custom_dir # optional; default "dynamic_" +""" + +from dataclasses import asdict, dataclass, replace +from typing import Dict, List, Optional + + +@dataclass +class MeshConfig: + """Per-CG mesh setup config. See module docstring for yaml schema.""" + + dir: str + mip: int + max_layer: int + max_error: int + chunk_size: List[int] + minishard_bits: Dict[int, int] + dynamic_mesh_dir: Optional[str] = None + + @classmethod + def from_dict(cls, d: Dict) -> "MeshConfig": + """Build from a yaml-parsed dict. + + Unknown keys are dropped (so older yamls don't break newer code). 
+ ``minishard_bits`` keys are coerced to ``int`` so the yaml is + tolerant of bare-int vs quoted-string keys. + """ + if not d: + raise ValueError( + "MeshConfig.from_dict: empty config — yaml `mesh_config:` " + "block is required" + ) + known = {f for f in cls.__dataclass_fields__} + kwargs = {k: v for k, v in d.items() if k in known} + if "minishard_bits" in kwargs: + kwargs["minishard_bits"] = { + int(k): int(v) for k, v in kwargs["minishard_bits"].items() + } + if "chunk_size" in kwargs: + kwargs["chunk_size"] = [int(x) for x in kwargs["chunk_size"]] + return cls(**kwargs) + + def with_graph_id(self, graph_id: str) -> "MeshConfig": + """Return a copy with ``dynamic_mesh_dir`` filled in if unset.""" + if self.dynamic_mesh_dir is not None: + return self + return replace(self, dynamic_mesh_dir=f"dynamic_{graph_id}") + + def to_dict(self) -> Dict: + return asdict(self) diff --git a/pychunkedgraph/meshing/setup.py b/pychunkedgraph/meshing/setup.py new file mode 100644 index 000000000..41f2d3b74 --- /dev/null +++ b/pychunkedgraph/meshing/setup.py @@ -0,0 +1,147 @@ +"""One-shot mesh metadata setup for a CG. + +Writes every mesh-related field a new or freshly-copied graph needs +before any mesh fragment can be served: + + 1. ``cg.meta.ws_cv.info["mesh"]`` (mesh dir in the watershed cv info.json) + 2. ``cg.meta.ws_cv.mesh.meta.info`` (per-layer sharded mesh spec) + 3. ``cg.meta.ws_cv.info["mesh_metadata"]`` (uniform draco grid + dynamic dir) + 4. ``cg.meta.custom_data["mesh"]`` (CG bigtable meta block) + +Idempotent: re-running overwrites the same fields with the current +inputs, with one exception — ``initial_ts`` is preserved if already +set, because changing it after the fact would reclassify every node +id and silently break served manifests. 
+""" + +import logging + +import numpy as np + +from ..graph.chunkedgraph import ChunkedGraph +from .meshgen import get_draco_encoding_settings_for_chunk +from .meta import MeshConfig + +logger = logging.getLogger(__name__) + + +def derive_initial_ts(cg: ChunkedGraph) -> int: + """Unix-seconds timestamp of a root id sampled from the dataset center. + + ``mesh.initial_ts`` is the threshold ``segregate_node_ids`` (see + ``meshing/manifest/utils.py``) uses to classify root ids as initial + vs post-ingest. It must sit above the last initial-ingest commit + and below any post-ingest commit. Picking a root id near the + volume center and using its commit timestamp satisfies both bounds + for any graph that completed initial ingest. + + Walks shells outward from the center of the L2 chunk grid (L1 + shares L2's coordinate grid) and returns the timestamp of the root + of the first SV found. + """ + hi = np.asarray(cg.meta.layer_chunk_bounds[2]) + center = hi // 2 + for r in range(int(hi.max()) + 1): + box = ( + np.array( + np.meshgrid( + np.arange(-r, r + 1), + np.arange(-r, r + 1), + np.arange(-r, r + 1), + indexing="ij", + ) + ) + .reshape(3, -1) + .T + ) + shell = box[np.max(np.abs(box), axis=1) == r] + coords = np.unique(np.clip(center + shell, 0, hi - 1), axis=0) + for c in coords: + chunk_id = cg.get_chunk_id(layer=1, x=int(c[0]), y=int(c[1]), z=int(c[2])) + svs = list(cg.range_read_chunk(chunk_id)) + if svs: + sv = svs[len(svs) // 2] + root = cg.get_root(sv) + ts = cg.get_node_timestamps(np.array([root]), return_numpy=False)[0] + return int(ts.timestamp()) + raise RuntimeError("derive_initial_ts: no SVs found anywhere in the volume") + + +def setup_mesh_meta( + cg: ChunkedGraph, + mesh_config: MeshConfig, +) -> dict: + """Write every mesh.* metadata field this graph needs to serve meshes. + + Writes go to two places: the watershed CloudVolume (steps 1-3, via + ``info.json`` / ``mesh/info`` on GCS) and the CG's bigtable meta + block (step 4). 
+ + ``initial_ts`` is set once and never overwritten — if the existing + bigtable mesh meta already has one, it is reused as-is. Otherwise + it is derived via :func:`derive_initial_ts` and persisted. + + Returns the mesh meta dict persisted into bigtable. + """ + cfg = mesh_config.with_graph_id(cg.graph_id) + existing_mesh = cg.meta.custom_data.get("mesh", {}) + existing_ts = existing_mesh.get("initial_ts") + initial_ts = int(existing_ts) if existing_ts is not None else derive_initial_ts(cg) + if existing_ts is not None: + logger.info("preserving existing initial_ts=%d", initial_ts) + + # 1. watershed CV info — mesh dir. + cg.meta.ws_cv.info["mesh"] = cfg.dir + cg.meta.ws_cv.commit_info() + logger.info("wrote ws_cv.info['mesh']=%r", cfg.dir) + + # 2. sharded mesh spec — same template per layer, layer-specific bits. + layer_shard_spec = { + "@type": "neuroglancer_uint64_sharded_v1", + "preshift_bits": 0, + "hash": "murmurhash3_x86_128", + "shard_bits": 0, + "minishard_index_encoding": "gzip", + "data_encoding": "raw", + } + sharding = { + str(layer): {**layer_shard_spec, "minishard_bits": int(bits)} + for layer, bits in cfg.minishard_bits.items() + if layer <= cfg.max_layer + } + mesh_spec = { + "@type": "neuroglancer_legacy_mesh", + "spatial_index": None, + "mip": int(cfg.mip), + "chunk_size": list(cfg.chunk_size), + "sharding": sharding, + } + cg.meta.ws_cv.mesh.meta.info = mesh_spec + cg.meta.ws_cv.mesh.meta.commit_info() + logger.info("wrote sharded mesh spec for layers %s", sorted(sharding.keys())) + + # 3. uniform draco grid size, derived from layer-2 draco settings. 
+ draco = get_draco_encoding_settings_for_chunk( + cg, cg.get_chunk_id(layer=2, x=0, y=0, z=0), mip=cfg.mip + ) + grid_size = draco["quantization_range"] / (2 ** draco["quantization_bits"] - 1) + cg.meta.ws_cv.info["mesh_metadata"] = { + "uniform_draco_grid_size": grid_size, + "unsharded_mesh_dir": "dynamic", + } + cg.meta.ws_cv.commit_info() + logger.info("wrote mesh_metadata uniform_draco_grid_size=%s", grid_size) + + # 4. CG-side bigtable meta block. + mesh_meta = { + "max_layer": int(cfg.max_layer), + "dynamic_mesh_dir": cfg.dynamic_mesh_dir, + "mip": int(cfg.mip), + "max_error": int(cfg.max_error), + "dir": cfg.dir, + "initial_ts": int(initial_ts), + } + cg.meta.custom_data["mesh"] = mesh_meta + cg.update_meta(cg.meta, overwrite=True) + logger.info("wrote cg.meta.custom_data['mesh']=%r", mesh_meta) + return mesh_meta diff --git a/pychunkedgraph/profiler/__init__.py b/pychunkedgraph/profiler/__init__.py new file mode 100644 index 000000000..7db7575f3 --- /dev/null +++ b/pychunkedgraph/profiler/__init__.py @@ -0,0 +1,6 @@ +""" +Hierarchical profiling. 
+""" + +from .hierarchical import BlockMetrics, HierarchicalProfiler +from .main import PROFILER_ENABLED, get_profiler, reset_profiler diff --git a/pychunkedgraph/profiler/hierarchical.py b/pychunkedgraph/profiler/hierarchical.py new file mode 100644 index 000000000..030c23552 --- /dev/null +++ b/pychunkedgraph/profiler/hierarchical.py @@ -0,0 +1,384 @@ +from typing import Dict +from typing import List +from typing import Optional +from typing import Tuple + +import threading +import time +import tracemalloc +from collections import defaultdict +from contextlib import contextmanager +from dataclasses import dataclass, field + +import psutil + +from .utils import _fmt_time, _fmt_bytes, _fmt_count + + +@dataclass +class BlockMetrics: + """Per-block metrics captured by HierarchicalProfiler.profile().""" + + path: str + elapsed_s: float + py_heap_peak_bytes: int = 0 + rss_start_bytes: int = 0 + rss_peak_bytes: int = 0 + counter_deltas: Dict[str, int] = field(default_factory=dict) + # Wall-clock time when this block finished, measured relative to + # the first profile() entry since the profiler was reset. + wall_end_s: float = 0.0 + + +class _RSSSampler: + """Daemon thread that samples process RSS and tracks the max. + + Internal to this module; not part of the profiler's public API. 
+ """ + + def __init__(self, interval_s: float = 0.050): + self._interval = interval_s + self._stop = threading.Event() + self._thread: Optional[threading.Thread] = None + self._proc = psutil.Process() + self.start_rss = 0 + self.peak_rss = 0 + + def start(self) -> None: + self.start_rss = self._proc.memory_info().rss + self.peak_rss = self.start_rss + self._stop.clear() + self._thread = threading.Thread(target=self._loop, daemon=True) + self._thread.start() + + def stop(self) -> None: + self._stop.set() + if self._thread is not None: + self._thread.join(timeout=1.0) + self._thread = None + + def _loop(self) -> None: + while not self._stop.is_set(): + try: + rss = self._proc.memory_info().rss + if rss > self.peak_rss: + self.peak_rss = rss + except Exception: + pass + self._stop.wait(self._interval) + + +class HierarchicalProfiler: + """ + Hierarchical profiler for detailed timing breakdowns. + Tracks timing at multiple levels and prints a breakdown at the end. + + Optional per-block memory + IO collection is opt-in via kwargs on + profile(); see BlockMetrics for the captured fields. + """ + + def __init__( + self, + enabled: bool = True, + *, + with_memory: bool = True, + with_rss: bool = True, + ): + self.enabled = enabled + self.timings: Dict[str, List[float]] = defaultdict(list) + self.call_counts: Dict[str, int] = defaultdict(int) + self.stack: List[Tuple[str, float]] = [] + self.current_path: List[str] = [] + self.blocks: List[BlockMetrics] = [] + # perf_counter at the first profile() entry since reset. + # Used to stamp each block's wall_end_s for cumulative-wall view. + self._base_perf: Optional[float] = None + # Per-instance defaults so inline profile() call sites stay short + # (no per-block kwargs). Callers can override per-call. + self.with_memory_default = with_memory + self.with_rss_default = with_rss + # Optional caller-set default counters dict used when profile() + # is called without an explicit `counters=` kwarg. 
Lets inline + # profile() blocks in production code pick up an outer harness's + # IO counters without needing to thread them through. + self.default_counters: Optional[Dict[str, int]] = None + + @contextmanager + def profile( + self, + name: str, + *, + with_memory: Optional[bool] = None, + with_rss: Optional[bool] = None, + counters: Optional[Dict[str, int]] = None, + ): + """Context manager for profiling a code block. + + Default behavior (no kwargs) records only timing into + `self.timings` / `self.call_counts`, matching the original + implementation. + + Optional kwargs collect extra metrics into `self.blocks`: + - with_memory: tracemalloc Python heap peak per block. + - with_rss: psutil RSS peak via a 50 ms sampler thread. + - counters: caller-supplied dict; per-key deltas recorded + (after - before for keys present at exit). + """ + if not self.enabled: + yield + return + + if with_memory is None: + with_memory = self.with_memory_default + if with_rss is None: + with_rss = self.with_rss_default + if counters is None: + counters = self.default_counters + + full_path = ".".join(self.current_path + [name]) + self.current_path.append(name) + + started_tracemalloc = False + if with_memory: + if not tracemalloc.is_tracing(): + tracemalloc.start() + started_tracemalloc = True + tracemalloc.reset_peak() + + sampler: Optional[_RSSSampler] = None + if with_rss: + sampler = _RSSSampler() + sampler.start() + + counters_before = dict(counters) if counters is not None else None + + start_time = time.perf_counter() + if self._base_perf is None: + self._base_perf = start_time + try: + yield + finally: + end_time = time.perf_counter() + elapsed = end_time - start_time + self.timings[full_path].append(elapsed) + self.call_counts[full_path] += 1 + self.current_path.pop() + + py_peak = 0 + if with_memory: + _curr, py_peak = tracemalloc.get_traced_memory() + if started_tracemalloc: + tracemalloc.stop() + + rss_start = 0 + rss_peak = 0 + if sampler is not None: + sampler.stop() + 
rss_start = sampler.start_rss + rss_peak = sampler.peak_rss + + counter_deltas: Dict[str, int] = {} + if counters is not None and counters_before is not None: + for k, v_after in counters.items(): + counter_deltas[k] = v_after - counters_before.get(k, 0) + + self.blocks.append( + BlockMetrics( + path=full_path, + elapsed_s=elapsed, + py_heap_peak_bytes=int(py_peak), + rss_start_bytes=int(rss_start), + rss_peak_bytes=int(rss_peak), + counter_deltas=counter_deltas, + wall_end_s=end_time - self._base_perf, + ) + ) + + def print_report(self, operation_id=None): + """Print a detailed timing breakdown.""" + if not self.timings: + return + + print("\n" + "=" * 80) + print( + f"PROFILER REPORT{f' (operation_id={operation_id})' if operation_id else ''}" + ) + print("=" * 80) + + # Group by depth level + by_depth: Dict[int, List[Tuple[str, float, int]]] = defaultdict(list) + for path, times in self.timings.items(): + depth = path.count(".") + total_time = sum(times) + count = self.call_counts[path] + by_depth[depth].append((path, total_time, count)) + + # Sort each level by total time + for depth in sorted(by_depth.keys()): + items = sorted(by_depth[depth], key=lambda x: -x[1]) + for path, total_time, count in items: + indent = " " * depth + avg_time = total_time / count if count > 0 else 0 + if count > 1: + print( + f"{indent}{path}: {total_time*1000:.2f}ms total " + f"({count} calls, {avg_time*1000:.2f}ms avg)" + ) + else: + print(f"{indent}{path}: {total_time*1000:.2f}ms") + + # Print summary + print("-" * 80) + top_level_total = sum( + sum(times) for path, times in self.timings.items() if "." 
not in path + ) + print(f"Total top-level time: {top_level_total*1000:.2f}ms") + + # Print top 10 slowest operations + print("\nTop 10 slowest operations:") + all_ops = [ + (path, sum(times), self.call_counts[path]) + for path, times in self.timings.items() + ] + all_ops.sort(key=lambda x: -x[1]) + for i, (path, total_time, count) in enumerate(all_ops[:10]): + pct = (total_time / top_level_total * 100) if top_level_total > 0 else 0 + print(f" {i+1}. {path}: {total_time*1000:.2f}ms ({pct:.1f}%)") + + print("=" * 80 + "\n") + + # Counter keys that are uninformative for the SV-split flow and + # only add visual noise to the report. + _SKIP_COUNTERS = ("ocdbt_reads",) + + def metrics_report(self, operation_id=None) -> None: + """Print a compact, human-readable table over self.blocks. + + Columns: stage, wall, cum_wall, py_peak, rss_start, rss_peak, + rss_Δ (signed), plus one column per counter key that has a + non-zero value in at least one block. ``cum_wall`` is wall + time elapsed from the first ``profile()`` block since reset. + rss_start / rss_peak are absolute process RSS; rss_Δ is the + new-allocation delta inside the block. + + Rows are laid out as a tree: parent-first pre-order so each + group reads top-down (rollup, then per-step breakdown). Each + nesting level gets its own column (L0, L1, …); a block's name + sits in the column matching its depth, deeper columns blank. + Top-level groups keep execution order. No blank separator + rows. + """ + if not self.blocks: + return + + # Collect counter keys in first-seen order; drop ones that are + # zero in every block (e.g., bt_log_reads is usually 0 for SV + # splits and only adds noise) or in the static skip list. 
+ counter_keys: List[str] = [] + seen_keys: set = set() + for b in self.blocks: + for k in b.counter_deltas: + if k not in seen_keys: + seen_keys.add(k) + counter_keys.append(k) + counter_keys = [ + k + for k in counter_keys + if k not in self._SKIP_COUNTERS + and any(b.counter_deltas.get(k, 0) for b in self.blocks) + ] + + def fmt_counter(key: str, val: int) -> str: + if key.endswith("_bytes"): + return _fmt_bytes(val) + return _fmt_count(val) + + # Build a parent-first pre-order over the dotted paths so each + # group reads top-down. self.blocks is in completion order + # (children before parents); order_idx preserves that as the + # tie-break for sibling ordering and top-level group order. + by_path: Dict[str, BlockMetrics] = {} + order_idx: Dict[str, int] = {} + for i, b in enumerate(self.blocks): + by_path[b.path] = b + order_idx[b.path] = i + + children: Dict[str, List[str]] = defaultdict(list) + roots: List[str] = [] + for b in self.blocks: + if "." in b.path: + children[b.path.rsplit(".", 1)[0]].append(b.path) + else: + roots.append(b.path) + roots.sort(key=lambda p: order_idx[p]) + for kids in children.values(): + kids.sort(key=lambda p: order_idx[p]) + + ordered: List[str] = [] + + def _visit(path: str) -> None: + ordered.append(path) + for child in children.get(path, []): + _visit(child) + + for r in roots: + _visit(r) + + # One name column per nesting level; a block's leaf name sits + # in the column matching its depth. 
+ max_depth = max(p.count(".") for p in ordered) + level_cols = [f"L{i}" for i in range(max_depth + 1)] + metric_cols = [ + "wall", + "cum_wall", + "py_peak", + "rss_start", + "rss_peak", + "rss_Δ", + ] + counter_keys + cols = level_cols + metric_cols + + rows: List[List[str]] = [] + for path in ordered: + b = by_path[path] + depth = path.count(".") + level_cells = [""] * len(level_cols) + level_cells[depth] = path.rsplit(".", 1)[-1] + row = level_cells + [ + _fmt_time(b.elapsed_s), + _fmt_time(getattr(b, "wall_end_s", 0.0)), + _fmt_bytes(b.py_heap_peak_bytes), + _fmt_bytes(b.rss_start_bytes), + _fmt_bytes(b.rss_peak_bytes), + _fmt_bytes(b.rss_peak_bytes - b.rss_start_bytes, signed=True), + ] + for k in counter_keys: + row.append(fmt_counter(k, b.counter_deltas.get(k, 0))) + rows.append(row) + + widths = [len(c) for c in cols] + for row in rows: + for i, v in enumerate(row): + if len(v) > widths[i]: + widths[i] = len(v) + + def line(values: List[str]) -> str: + return " ".join(v.ljust(widths[i]) for i, v in enumerate(values)) + + title = "metrics report" + if operation_id is not None: + title = f"{title} (operation_id={operation_id})" + print(title) + print(line(cols)) + print(line(["-" * w for w in widths])) + for row in rows: + print(line(row)) + + def reset(self): + """Reset all timing data.""" + self.timings.clear() + self.call_counts.clear() + self.stack.clear() + self.current_path.clear() + self.blocks.clear() + self._base_perf = None diff --git a/pychunkedgraph/profiler/main.py b/pychunkedgraph/profiler/main.py new file mode 100644 index 000000000..63661e234 --- /dev/null +++ b/pychunkedgraph/profiler/main.py @@ -0,0 +1,21 @@ +import os + +from .hierarchical import HierarchicalProfiler + +PROFILER_ENABLED = os.environ.get("PCG_PROFILER_ENABLED", "0") == "1" +_profiler: HierarchicalProfiler = None + + +def get_profiler() -> HierarchicalProfiler: + """Get or create the global profiler instance.""" + global _profiler + if _profiler is None: + _profiler = 
HierarchicalProfiler(enabled=PROFILER_ENABLED) + return _profiler + + +def reset_profiler(): + """Reset the global profiler.""" + global _profiler + if _profiler is not None: + _profiler.reset() diff --git a/pychunkedgraph/profiler/utils.py b/pychunkedgraph/profiler/utils.py new file mode 100644 index 000000000..d2925730c --- /dev/null +++ b/pychunkedgraph/profiler/utils.py @@ -0,0 +1,29 @@ +def _fmt_time(s: float) -> str: + """Auto-scale seconds → ``12.3 ms`` / ``1.23 s``.""" + if s < 1.0: + return f"{s * 1000:.1f} ms" + return f"{s:.2f} s" + + +def _fmt_bytes(n: int, *, signed: bool = False) -> str: + """Auto-scale bytes (binary) → ``512 B`` / ``1.5 KB`` / ``45.6 MB`` / ``1.23 GB``. + + With ``signed=True``, positive values get a ``+`` prefix (for delta columns). + """ + if signed: + sign = "+" if n > 0 else "-" if n < 0 else "" + else: + sign = "-" if n < 0 else "" + n = abs(int(n)) + if n < 1024: + return f"{sign}{n} B" + if n < 1024**2: + return f"{sign}{n / 1024:.1f} KB" + if n < 1024**3: + return f"{sign}{n / 1024**2:.1f} MB" + return f"{sign}{n / 1024**3:.2f} GB" + + +def _fmt_count(n: int) -> str: + """Thousands-separator integer: ``1234567`` → ``1,234,567``.""" + return f"{int(n):,}" diff --git a/pychunkedgraph/repair/stuck_ops.py b/pychunkedgraph/repair/stuck_ops.py new file mode 100644 index 000000000..2b9064dca --- /dev/null +++ b/pychunkedgraph/repair/stuck_ops.py @@ -0,0 +1,290 @@ +"""Operator recovery for SV-split ops that crashed mid-write. + +A crash inside `IndefiniteL2ChunkLock`'s critical section leaves the +per-chunk `Concurrency.IndefiniteLock` cells set *and* records the +chunk scope on the op-log row's `OperationLogs.L2ChunkLockScope` field. +Ops on other (non-overlapping) chunks continue to succeed and advance +the OCDBT manifest while the stuck op sits there blocking its own +chunks. + +Recovery = cleanup + replay. 
The cleanup step reverts the stuck op's +partial OCDBT writes by copying pre-op voxel values (read from a +version-pinned OCDBT handle at the op's `OperationTimeStamp`) back to +the latest manifest. The replay then runs the op normally via the +existing `repair.edits.repair_operation` path — reads latest (clean on +the stuck op's chunks, current on everyone else's), writes fresh SV +IDs, and `IndefiniteL2ChunkLock`'s privileged-mode exit deletes the +crashed op's pre-existing cells. + +See `pychunkedgraph/graph/sv_split/recovery.md` for the full +architecture and correctness argument. +""" + +import argparse +from concurrent.futures import ThreadPoolExecutor, as_completed +from datetime import datetime, timedelta, timezone + +import numpy as np + +from pychunkedgraph import get_logger +from pychunkedgraph.graph import ChunkedGraph, attributes +from pychunkedgraph.graph.chunks.utils import get_chunk_coordinates +from pychunkedgraph.graph.locks import _l2_chunk_lock_row_key +from pychunkedgraph.graph.ocdbt import get_seg_source_and_destination_ocdbt +from pychunkedgraph.repair.edits import repair_operation + +logger = get_logger(__name__) + + +def _operation_ts_to_pin(operation_ts: datetime) -> str: + """Convert an op-log `OperationTimeStamp` to the OCDBT `version` + string format — ISO-8601 UTC with `Z` suffix, microsecond + precision. OCDBT's binder rejects `+00:00`. 
+ """ + if operation_ts.tzinfo is None: + operation_ts = operation_ts.replace(tzinfo=timezone.utc) + else: + operation_ts = operation_ts.astimezone(timezone.utc) + return operation_ts.isoformat().replace("+00:00", "Z") + + +def _chunk_voxel_slices(cg: ChunkedGraph, chunk_id: int) -> tuple: + """Voxel-space slice tuple covering one L2 chunk, clipped to volume bounds.""" + coords = get_chunk_coordinates(cg.meta, np.uint64(chunk_id)) + chunk_size = np.array(cg.meta.graph_config.CHUNK_SIZE, dtype=int) + voxel_bounds = cg.meta.voxel_bounds + lo = coords * chunk_size + voxel_bounds[:, 0] + hi = np.minimum(lo + chunk_size, voxel_bounds[:, 1]) + return tuple(slice(int(s), int(e)) for s, e in zip(lo, hi)) + + +def list_stuck(cg: ChunkedGraph, min_age: timedelta = timedelta(minutes=10)) -> list: + """Return op-log entries whose `L2ChunkLockScope` is set past `min_age`, + excluding successfully-completed ops. + + The authoritative signal for a stuck op is "scope recorded" — + `IndefiniteL2ChunkLock.__enter__` writes it before any seg/bigtable + write and its clean `__exit__` clears it. An op whose scope is + still populated is either a worker crash (Status=CREATED, Fix 1's + `__exit__` short-circuit never ran) or an exception during the + persist block (Status=EXCEPTION, Fix 1 held the cells on the way + out). Either way it's still holding `Concurrency.IndefiniteLock` + cells on the listed chunks and blocking any new op that overlaps. + + Ops that reach `SUCCESS` normally have scope cleared — we defensively + filter them out in case `_clear_scope_on_op_log`'s best-effort write + failed and logged. Failed ops that never touched the persist block + (e.g. a PreconditionError from multicut) have no scope and don't + show up here; they're not blocking anything. 
+ """ + now = datetime.now(timezone.utc) + cutoff = now - min_age + entries = cg.client.read_log_entries() + stuck = [] + success_code = attributes.OperationLogs.StatusCodes.SUCCESS.value + for op_id, entry in entries.items(): + scope = entry.get(attributes.OperationLogs.L2ChunkLockScope) + if scope is None or len(scope) == 0: + continue + if entry.get(attributes.OperationLogs.Status) == success_code: + continue + op_ts = entry.get(attributes.OperationLogs.OperationTimeStamp) + if op_ts is None: + continue + if op_ts.tzinfo is None: + op_ts = op_ts.replace(tzinfo=timezone.utc) + if op_ts > cutoff: + continue + stuck.append( + { + "op_id": int(op_id), + "operation_ts": op_ts, + "age": now - op_ts, + "user_id": entry.get(attributes.OperationLogs.UserID), + "l2_chunk_scope": scope, + "status": entry.get(attributes.OperationLogs.Status), + } + ) + stuck.sort(key=lambda r: r["op_id"]) + return stuck + + +def cleanup_partial_writes(cg: ChunkedGraph, op_id: int) -> int: + """Revert a stuck op's partial OCDBT writes to pre-op voxel values. + + Reads each chunk in the op's `L2ChunkLockScope` through an OCDBT + handle pinned at the op's `OperationTimeStamp` (which pre-dates any + of its commits), then writes those pre-op values back to the latest + manifest. Overwrites the crashed op's partial seg writes at the + same chunk keys; neighbor chunks are untouched, preserving any + concurrent ops' updates. + + Returns the number of chunks rewritten. 
+ """ + log_entries = cg.client.read_log_entries(operation_ids=[np.uint64(op_id)]) + if not log_entries: + raise ValueError(f"No op-log row for op_id={op_id}") + entry = log_entries[np.uint64(op_id)] + + scope = entry.get(attributes.OperationLogs.L2ChunkLockScope) + if scope is None or len(scope) == 0: + logger.info(f"op {op_id} has no L2ChunkLockScope — nothing to clean up") + return 0 + + operation_ts = entry.get(attributes.OperationLogs.OperationTimeStamp) + if operation_ts is None: + raise ValueError(f"op {op_id} has no OperationTimeStamp") + pin_str = _operation_ts_to_pin(operation_ts) + + # Pinned read handle (read-only at pre-op version) vs. unpinned + # write handle (latest). Tensorstore refuses writes on version-pinned + # kvstores, so the two paths use separate handles. + _, pinned_scales, _ = get_seg_source_and_destination_ocdbt( + cg.meta.data_source.WATERSHED, + cg.meta.graph_id, + cg.meta.ocdbt_config, + pinned_at=pin_str, + ) + pinned_ws = pinned_scales[0] + latest_ws = cg.meta.ws_ocdbt + + def _revert_chunk(chunk_id: int) -> None: + voxel_slices = _chunk_voxel_slices(cg, int(chunk_id)) + pre_op = pinned_ws[voxel_slices + (slice(None),)].read().result() + latest_ws[voxel_slices + (slice(None),)].write(pre_op).result() + + # Parallel read-then-write per chunk. Bounded pool so large scopes + # don't saturate tensorstore's internal concurrency. + max_workers = min(16, max(1, len(scope))) + with ThreadPoolExecutor(max_workers=max_workers) as executor: + futures = [executor.submit(_revert_chunk, int(c)) for c in scope] + for future in as_completed(futures): + future.result() + + logger.info(f"op {op_id}: reverted {len(scope)} partial chunk writes") + return len(scope) + + +def _verify_indefinite_cells(cg: ChunkedGraph, op_id: int, scope) -> list: + """Check each chunk in `scope` actually has `Concurrency.IndefiniteLock` + held by `op_id`. 
def replay(cg: ChunkedGraph, op_id: int):
    """Recover a stuck op: verify its locks, revert partial writes, rerun.

    Steps, in order:
      1. Read the op-log row and its `L2ChunkLockScope`.
      2. `_verify_indefinite_cells` must report no discrepancy for the
         scope; nothing destructive happens before this check passes.
      3. `cleanup_partial_writes` reverts the op's partial OCDBT writes.
      4. `repair_operation(cg, op_id, unlock=True)` reruns the op; its
         result is returned.

    Raises:
        ValueError: no op-log row exists for ``op_id``.
        RuntimeError: the row has no scope, or the lock verification
            reported mismatching chunks.
    """
    row_key = np.uint64(op_id)
    rows = cg.client.read_log_entries(operation_ids=[row_key])
    if not rows:
        raise ValueError(f"No op-log row for op_id={op_id}")

    locked_chunks = rows[row_key].get(attributes.OperationLogs.L2ChunkLockScope)
    if locked_chunks is None or len(locked_chunks) == 0:
        no_scope_msg = (
            f"op {op_id} has no L2ChunkLockScope — not a stuck SV-split op. "
            "If the op failed cleanly, the client should re-submit under a "
            "fresh op_id rather than replay."
        )
        raise RuntimeError(no_scope_msg)

    # Refuse to touch any data unless live lock state matches the scope.
    not_held = _verify_indefinite_cells(cg, op_id, locked_chunks)
    if not_held:
        listed = [int(c) for c in locked_chunks]
        mismatch_msg = (
            f"op {op_id}: L2ChunkLockScope lists chunks {listed}, "
            f"but the following chunks do not have Concurrency.IndefiniteLock "
            f"held by op_id={op_id}: {not_held}. Refusing to replay — the "
            "recorded scope disagrees with live lock state. Possible causes: "
            "replay already ran, cells were manually cleared, or a different "
            "op acquired these chunks. Investigate before retrying."
        )
        raise RuntimeError(mismatch_msg)

    cleanup_partial_writes(cg, op_id)
    outcome = repair_operation(cg, op_id, unlock=True)
    return outcome
+ ) + p_replay.add_argument("--graph", required=True, help="Graph ID.") + p_replay.add_argument("--op-id", type=int, required=True, help="Op ID to replay.") + + args = parser.parse_args() + cg = ChunkedGraph(graph_id=args.graph) + + if args.cmd == "list": + stuck = list_stuck(cg, min_age=timedelta(minutes=args.min_age)) + if not stuck: + print("No stuck ops.") + return + for row in stuck: + scope_size = ( + len(row["l2_chunk_scope"]) if row["l2_chunk_scope"] is not None else 0 + ) + print( + f"op {row['op_id']}: user={row['user_id']} " + f"ts={row['operation_ts'].isoformat()} " + f"age={row['age']} " + f"l2_chunks={scope_size}" + ) + elif args.cmd == "replay": + result = replay(cg, args.op_id) + print(f"replay complete: {result}") + + +if __name__ == "__main__": + _main() diff --git a/pychunkedgraph/tests/conftest.py b/pychunkedgraph/tests/conftest.py index 0e737f7f5..fa4683253 100644 --- a/pychunkedgraph/tests/conftest.py +++ b/pychunkedgraph/tests/conftest.py @@ -180,7 +180,7 @@ def _cgraph(request, n_layers=10, atomic_chunk_bounds: np.ndarray = np.array([]) "ingest_config": {}, } - meta, _, client_info = bootstrap("test", config=config) + meta, _, client_info, _ = bootstrap("test", config=config) graph = ChunkedGraph(graph_id="test", meta=meta, client_info=client_info) graph.mock_edges = Edges([], []) graph.meta._ws_cv = CloudVolumeMock() @@ -247,7 +247,7 @@ def _cgraph(request, n_layers=10, atomic_chunk_bounds: np.ndarray = np.array([]) "ingest_config": {}, } - meta, _, client_info = bootstrap("test", config=config) + meta, _, client_info, _ = bootstrap("test", config=config) graph = ChunkedGraph(graph_id="test", meta=meta, client_info=client_info) # No mock_edges - use real I/O via file:// protocol graph.meta._ws_cv = CloudVolumeMock() diff --git a/pychunkedgraph/tests/graph/test_cutting.py b/pychunkedgraph/tests/graph/test_cutting.py index 89cf4969d..4411d876c 100644 --- a/pychunkedgraph/tests/graph/test_cutting.py +++ 
b/pychunkedgraph/tests/graph/test_cutting.py @@ -4,8 +4,10 @@ import pytest from pychunkedgraph.graph.cutting import ( + Cut, IsolatingCutException, LocalMincutGraph, + PreviewCut, merge_cross_chunk_edges_graph_tool, run_multicut, ) @@ -336,8 +338,9 @@ def test_basic_split(self): path_augment=True, disallow_isolating_cut=False, ) - assert len(result) > 0 - result_set = set(map(tuple, result)) + assert isinstance(result, Cut) + assert len(result.atomic_edges) > 0 + result_set = set(map(tuple, result.atomic_edges)) assert (2, 3) in result_set or (3, 2) in result_set def test_basic_split_direct(self): @@ -354,8 +357,9 @@ def test_basic_split_direct(self): path_augment=False, disallow_isolating_cut=False, ) - assert len(result) > 0 - result_set = set(map(tuple, result)) + assert isinstance(result, Cut) + assert len(result.atomic_edges) > 0 + result_set = set(map(tuple, result.atomic_edges)) assert (2, 3) in result_set or (3, 2) in result_set def test_no_edges_raises(self): @@ -377,7 +381,7 @@ def test_no_edges_raises(self): ) def test_split_preview_mode(self): - """run_multicut with split_preview=True returns (ccs, illegal_split).""" + """run_multicut with split_preview=True returns a PreviewCut.""" node_ids1 = np.array([1, 2, 3], dtype=np.uint64) node_ids2 = np.array([2, 3, 4], dtype=np.uint64) affinities = np.array([0.9, 0.05, 0.9], dtype=np.float32) @@ -391,10 +395,10 @@ def test_split_preview_mode(self): path_augment=False, disallow_isolating_cut=False, ) - supervoxel_ccs, illegal_split = result - assert isinstance(supervoxel_ccs, list) - assert len(supervoxel_ccs) >= 2 - assert isinstance(illegal_split, bool) + assert isinstance(result, PreviewCut) + assert isinstance(result.supervoxel_ccs, list) + assert len(result.supervoxel_ccs) >= 2 + assert isinstance(result.illegal_split, bool) class TestMergeCrossChunkEdgesOverlap: @@ -641,7 +645,7 @@ class TestRunMulticutSplitPreview: """Test run_multicut in split_preview mode returns correct structure.""" def 
test_split_preview_returns_ccs_and_flag(self): - """run_multicut with split_preview=True should return (ccs, illegal_split).""" + """run_multicut with split_preview=True should return a PreviewCut.""" node_ids1 = np.array([1, 2, 3], dtype=np.uint64) node_ids2 = np.array([2, 3, 4], dtype=np.uint64) affinities = np.array([0.9, 0.01, 0.9], dtype=np.float32) @@ -656,15 +660,15 @@ def test_split_preview_returns_ccs_and_flag(self): disallow_isolating_cut=False, ) - supervoxel_ccs, illegal_split = result - assert isinstance(supervoxel_ccs, list) - assert len(supervoxel_ccs) >= 2 - assert isinstance(illegal_split, bool) + assert isinstance(result, PreviewCut) + assert isinstance(result.supervoxel_ccs, list) + assert len(result.supervoxel_ccs) >= 2 + assert isinstance(result.illegal_split, bool) # Source side CC - assert 1 in supervoxel_ccs[0] + assert 1 in result.supervoxel_ccs[0] # Sink side CC - assert 4 in supervoxel_ccs[1] + assert 4 in result.supervoxel_ccs[1] def test_split_preview_with_path_augment(self): """run_multicut with split_preview=True and path_augment=True.""" @@ -682,12 +686,12 @@ def test_split_preview_with_path_augment(self): disallow_isolating_cut=False, ) - supervoxel_ccs, illegal_split = result - assert len(supervoxel_ccs) >= 2 + assert isinstance(result, PreviewCut) + assert len(result.supervoxel_ccs) >= 2 # Source side - assert 1 in supervoxel_ccs[0] + assert 1 in result.supervoxel_ccs[0] # Sink side - assert 5 in supervoxel_ccs[1] + assert 5 in result.supervoxel_ccs[1] def test_split_preview_larger_graph(self): """split_preview on a larger graph with a clear cut point.""" @@ -709,14 +713,14 @@ def test_split_preview_larger_graph(self): disallow_isolating_cut=False, ) - supervoxel_ccs, illegal_split = result - source_cc = set(supervoxel_ccs[0]) - sink_cc = set(supervoxel_ccs[1]) + assert isinstance(result, PreviewCut) + source_cc = set(result.supervoxel_ccs[0]) + sink_cc = set(result.supervoxel_ccs[1]) # Source cluster assert {1, 2, 
3}.issubset(source_cc) # Sink cluster assert {4, 5, 6}.issubset(sink_cc) - assert not illegal_split + assert not result.illegal_split class TestLocalMincutGraphWithLogger: @@ -1040,7 +1044,7 @@ class TestRunSplitPreview: """ def test_basic_split_preview(self): - """run_multicut with split_preview should return CCs and a flag.""" + """run_multicut with split_preview should return a PreviewCut.""" edges_sv = Edges( np.array([1, 2, 3, 4], dtype=np.uint64), np.array([2, 3, 4, 5], dtype=np.uint64), @@ -1049,16 +1053,17 @@ def test_basic_split_preview(self): ) sources = np.array([1], dtype=np.uint64) sinks = np.array([5], dtype=np.uint64) - ccs, illegal_split = run_multicut( + result = run_multicut( edges_sv, sources, sinks, split_preview=True, disallow_isolating_cut=False, ) - assert isinstance(ccs, list) - assert isinstance(illegal_split, bool) - assert len(ccs) >= 2 + assert isinstance(result, PreviewCut) + assert isinstance(result.supervoxel_ccs, list) + assert isinstance(result.illegal_split, bool) + assert len(result.supervoxel_ccs) >= 2 def test_split_preview_with_areas(self): """Split preview with areas provided.""" @@ -1070,7 +1075,7 @@ def test_split_preview_with_areas(self): ) sources = np.array([10], dtype=np.uint64) sinks = np.array([40], dtype=np.uint64) - ccs, illegal_split = run_multicut( + result = run_multicut( edges_sv, sources, sinks, @@ -1078,12 +1083,10 @@ def test_split_preview_with_areas(self): path_augment=False, disallow_isolating_cut=False, ) - assert isinstance(ccs, list) - assert len(ccs) >= 2 - # Source side should contain 10 - assert 10 in ccs[0] - # Sink side should contain 40 - assert 40 in ccs[1] + assert isinstance(result, PreviewCut) + assert len(result.supervoxel_ccs) >= 2 + assert 10 in result.supervoxel_ccs[0] + assert 40 in result.supervoxel_ccs[1] def test_split_preview_path_augment(self): """Split preview with path_augment=True.""" @@ -1094,7 +1097,7 @@ def test_split_preview_path_augment(self): ) sources = np.array([1], 
dtype=np.uint64) sinks = np.array([6], dtype=np.uint64) - ccs, illegal_split = run_multicut( + result = run_multicut( edges_sv, sources, sinks, @@ -1102,11 +1105,11 @@ def test_split_preview_path_augment(self): path_augment=True, disallow_isolating_cut=False, ) - assert isinstance(ccs, list) - assert len(ccs) >= 2 - assert 1 in ccs[0] - assert 6 in ccs[1] - assert not illegal_split + assert isinstance(result, PreviewCut) + assert len(result.supervoxel_ccs) >= 2 + assert 1 in result.supervoxel_ccs[0] + assert 6 in result.supervoxel_ccs[1] + assert not result.illegal_split class TestFilterGraphCCsWithLogger: diff --git a/pychunkedgraph/tests/graph/test_cutting_sv.py b/pychunkedgraph/tests/graph/test_cutting_sv.py index a2b29ac74..b27142e44 100644 --- a/pychunkedgraph/tests/graph/test_cutting_sv.py +++ b/pychunkedgraph/tests/graph/test_cutting_sv.py @@ -1,10 +1,10 @@ -"""Tests for pychunkedgraph.graph.cutting_sv""" +"""Tests for pychunkedgraph.graph.sv_split.cutting""" import numpy as np import pytest from scipy.spatial import cKDTree -from pychunkedgraph.graph.cutting_sv import ( +from pychunkedgraph.graph.sv_split.cutting import ( _cc_label_26, _largest_component_id, _to_zyx_sampling, @@ -19,6 +19,7 @@ _upsample_bool, _upsample_labels, build_kdtrees_by_label, + build_coords_by_label, pairwise_min_distance_two_sets, split_supervoxel_growing, connect_both_seeds_via_ridge, @@ -391,6 +392,108 @@ def test_invalid_mask_order(self): ) +class TestSnapSeedsBboxEquivalence: + """use_bbox=True must return the identical snapped voxel as the + full-mask scan. 
The bbox restriction is purely a candidate-set + optimization; a different nearest voxel would be a correctness + regression, which is exactly what these cases probe.""" + + def _both(self, seeds, mask, *, mask_order, voxel_size, bbox_pad_phys): + full = snap_seeds_to_segment( + seeds, + mask, + mask_order=mask_order, + voxel_size=voxel_size, + use_boundary=False, + downsample=False, + use_bbox=False, + ) + boxed = snap_seeds_to_segment( + seeds, + mask, + mask_order=mask_order, + voxel_size=voxel_size, + use_boundary=False, + downsample=False, + use_bbox=True, + bbox_pad_phys=bbox_pad_phys, + ) + return full, boxed + + def test_seed_outside_isotropic(self): + mask = np.zeros((30, 30, 30), dtype=bool) + mask[10:20, 10:20, 10:20] = True + seeds = np.array([[0.0, 0.0, 0.0]]) # outside; nearest is a corner + full, boxed = self._both( + seeds, + mask, + mask_order="zyx", + voxel_size=(1.0, 1.0, 1.0), + bbox_pad_phys=2.0, # smaller than the seed->mask gap, forces grow + ) + np.testing.assert_array_equal(full, boxed) + + def test_anisotropic_voxel_size(self): + # Anisotropy means the nearest voxel by physical distance need not + # be the nearest by index distance; the bbox pad is in voxels per + # axis, so this is the case most likely to expose a wrong window. 
+ mask = np.zeros((40, 40, 10), dtype=bool) + mask[15:25, 15:25, 3:7] = True + seeds = np.array([[0.0, 0.0, 0.0]]) + full, boxed = self._both( + seeds, + mask, + mask_order="zyx", + voxel_size=(8.0, 8.0, 40.0), + bbox_pad_phys=8.0, + ) + np.testing.assert_array_equal(full, boxed) + + def test_multiple_far_apart_seeds(self): + mask = np.zeros((30, 30, 30), dtype=bool) + mask[2:6, 2:6, 2:6] = True + mask[24:28, 24:28, 24:28] = True + seeds = np.array([[0.0, 0.0, 0.0], [29.0, 29.0, 29.0]]) + full, boxed = self._both( + seeds, + mask, + mask_order="zyx", + voxel_size=(1.0, 1.0, 1.0), + bbox_pad_phys=1.0, + ) + np.testing.assert_array_equal(full, boxed) + + def test_xyz_mask_order(self): + mask_xyz = np.zeros((20, 24, 16), dtype=bool) + mask_xyz[6:14, 6:18, 4:12] = True + seeds = np.array([[0.0, 0.0, 0.0]]) + full, boxed = self._both( + seeds, + mask_xyz, + mask_order="xyz", + voxel_size=(1.0, 1.0, 1.0), + bbox_pad_phys=2.0, + ) + np.testing.assert_array_equal(full, boxed) + + def test_irregular_shape_grow_loop(self): + # Dumbbell: two blobs joined by a thin neck. A seed near one blob + # with a tiny pad starts with an empty window and must grow. 
+ mask = np.zeros((40, 20, 20), dtype=bool) + mask[2:10, 6:14, 6:14] = True + mask[30:38, 6:14, 6:14] = True + mask[10:30, 9:11, 9:11] = True # neck + seeds = np.array([[10.0, 10.0, 1.0]]) # outside, off the z=0 end + full, boxed = self._both( + seeds, + mask, + mask_order="zyx", + voxel_size=(1.0, 1.0, 1.0), + bbox_pad_phys=1.0, + ) + np.testing.assert_array_equal(full, boxed) + + # ============================================================ # Tests: EDT # ============================================================ @@ -486,6 +589,111 @@ def test_uint64_labels(self): assert int(2**60) in trees +# ============================================================ +# Tests: select_voxel_indices + build_coords_by_label +# ============================================================ +def _reference_coords_by_label(vol, *, labels=None, background=0, min_points=1): + """Reference: brute-force np.argwhere(vol == k), restricted to labels. + + Used to verify build_coords_by_label against the algorithmic + definition (independent of its implementation). 
+ """ + present = set(int(x) for x in np.unique(vol) if int(x) != background) + keys = present if labels is None else present & {int(x) for x in labels} + result = {} + for k in keys: + pts = np.argwhere(vol == k) + if pts.shape[0] < min_points: + continue + result[k] = pts.astype(np.float32) + return result + + +def _assert_dict_equal(actual, expected): + assert set(actual.keys()) == set(expected.keys()) + for k in expected: + a = actual[k][np.lexsort(actual[k].T[::-1])] + e = expected[k][np.lexsort(expected[k].T[::-1])] + assert a.shape == e.shape + np.testing.assert_array_equal(a, e) + + +class TestBuildCoordsByLabel: + def test_matches_reference_full_scan(self): + rng = np.random.default_rng(0) + vol = rng.integers(0, 4, size=(6, 5, 7), dtype=np.uint64) + actual = build_coords_by_label(vol) + expected = _reference_coords_by_label(vol) + _assert_dict_equal(actual, expected) + + def test_labels_filter_restricts_keys(self): + vol = np.zeros((5, 5, 5), dtype=np.uint64) + vol[1, 1, 1] = 7 + vol[2, 3, 0] = 8 + vol[0, 0, 3] = 9 + out = build_coords_by_label(vol, labels=[7, 9]) + assert set(out.keys()) == {7, 9} + assert 8 not in out + + def test_labels_filter_matches_reference(self): + rng = np.random.default_rng(1) + vol = rng.integers(0, 5, size=(5, 6, 4), dtype=np.uint64) + wanted = [1, 3] + actual = build_coords_by_label(vol, labels=wanted) + expected = _reference_coords_by_label(vol, labels=wanted) + _assert_dict_equal(actual, expected) + + def test_labels_filter_keeps_label_with_no_voxels_absent(self): + vol = np.zeros((3, 3, 3), dtype=np.uint64) + vol[1, 1, 1] = 5 + out = build_coords_by_label(vol, labels=[5, 42]) + assert set(out.keys()) == {5} + + def test_empty_labels(self): + vol = np.ones((3, 3, 3), dtype=np.uint64) + out = build_coords_by_label(vol, labels=[]) + assert out == {} + + def test_min_points_filter(self): + vol = np.zeros((4, 4, 4), dtype=np.uint64) + vol[0, 0, 0] = 1 + vol[1, 1, 1] = 2 + vol[1, 1, 2] = 2 + vol[1, 2, 1] = 2 + out = 
build_coords_by_label(vol, min_points=2) + assert 1 not in out + assert 2 in out + assert out[2].shape == (3, 3) + + def test_empty_volume(self): + vol = np.zeros((3, 3, 3), dtype=np.uint64) + assert build_coords_by_label(vol) == {} + + def test_background_nonzero(self): + vol = np.full((4, 4, 4), 99, dtype=np.uint64) + vol[2, 2, 2] = 5 + out = build_coords_by_label(vol, background=99) + assert set(out.keys()) == {5} + np.testing.assert_array_equal(out[5], np.array([[2, 2, 2]], dtype=np.float32)) + + def test_uint64_labels(self): + vol = np.zeros((4, 4, 4), dtype=np.uint64) + big = np.uint64(2**60) + vol[1, 1, 1] = big + out = build_coords_by_label(vol) + assert int(big) in out + + def test_dtype_float32_default(self): + vol = np.zeros((3, 3, 3), dtype=np.uint64) + vol[0, 0, 0] = 1 + out = build_coords_by_label(vol) + assert out[1].dtype == np.float32 + + def test_non_3d_raises(self): + with pytest.raises(ValueError, match="3D"): + build_coords_by_label(np.zeros((5, 5), dtype=np.uint64)) + + # ============================================================ # Tests: pairwise_min_distance_two_sets # ============================================================ @@ -565,7 +773,6 @@ def test_basic_split_xyz(self): vol_order="xyz", vox_order="xyz", seed_order="xyz", - verbose=False, snap_kwargs=dict(use_boundary=False, downsample=False), enforce_single_cc=True, raise_if_multi_cc=False, @@ -588,7 +795,6 @@ def test_basic_split_zyx(self): vol_order="zyx", vox_order="zyx", seed_order="zyx", - verbose=False, snap_kwargs=dict(use_boundary=False, downsample=False), enforce_single_cc=True, raise_if_multi_cc=False, @@ -610,7 +816,6 @@ def test_empty_seeds_returns_label1(self): vol_order="zyx", vox_order="zyx", seed_order="zyx", - verbose=False, snap_kwargs=dict(use_boundary=False, downsample=False), ) assert np.all(result[mask] == 1) @@ -627,7 +832,6 @@ def test_with_downsample_geodesic(self): vox_order="zyx", seed_order="zyx", downsample_geodesic=(1, 2, 2), - verbose=False, 
snap_kwargs=dict(use_boundary=False, downsample=False), enforce_single_cc=True, raise_if_multi_cc=False, @@ -636,6 +840,30 @@ def test_with_downsample_geodesic(self): assert np.any(result == 1) assert np.any(result == 2) + def test_single_cc_guarantee_with_downsample(self): + """Each output label must remain a single connected component + after the DS-grid enforce_cc rewrite. raise_if_multi_cc=True so + the function's own guard trips if the guarantee is broken.""" + mask, seeds_a, seeds_b = _make_dumbbell_mask(shape=(20, 30, 30)) + result = split_supervoxel_growing( + mask, + seeds_a, + seeds_b, + voxel_size=(1.0, 1.0, 1.0), + vol_order="zyx", + vox_order="zyx", + seed_order="zyx", + downsample_geodesic=(1, 2, 2), + enforce_single_cc=True, + raise_if_multi_cc=True, + snap_kwargs=dict(use_boundary=False, downsample=False), + ) + assert np.any(result == 1) + assert np.any(result == 2) + for lab in (1, 2): + _, ncomp = _cc_label_26(result == lab) + assert ncomp == 1, f"label {lab} split into {ncomp} components" + # ============================================================ # Tests: connect_both_seeds_via_ridge @@ -656,7 +884,6 @@ def test_basic_connection(self): vox_order="xyz", seed_order="xyz", downsample=(1, 1, 1), - verbose=False, snap_kwargs=dict(use_boundary=False, downsample=False), ) assert okA @@ -680,7 +907,6 @@ def test_single_seed_per_team(self): vol_order="xyz", seed_order="xyz", downsample=(1, 1, 1), - verbose=False, snap_kwargs=dict(use_boundary=False, downsample=False), ) assert okA @@ -701,7 +927,6 @@ def test_empty_seeds(self): vol_order="xyz", seed_order="xyz", downsample=(1, 1, 1), - verbose=False, snap_kwargs=dict(use_boundary=False, downsample=False), ) assert not okA @@ -722,7 +947,6 @@ def test_basic_split(self): seeds_a_xyz, seeds_b_xyz, voxel_size=(1.0, 1.0, 1.0), - verbose=False, ) assert result.shape == mask_xyz.shape assert np.any(result == 1) diff --git a/pychunkedgraph/tests/graph/test_downsample.py 
b/pychunkedgraph/tests/graph/test_downsample.py new file mode 100644 index 000000000..2eb799334 --- /dev/null +++ b/pychunkedgraph/tests/graph/test_downsample.py @@ -0,0 +1,309 @@ +"""Tests for pychunkedgraph.graph.downsample.""" + +import shutil +import tempfile +import threading +import time +from types import SimpleNamespace + +import numpy as np +import pytest +import tensorstore as ts + +from pychunkedgraph.graph import downsample as ds +from pychunkedgraph.graph.locks import ( + DownsampleBlockLock, + _downsample_block_lock_row_key, +) +from pychunkedgraph.graph import exceptions +from pychunkedgraph.tests.helpers import ( + RowKeyLockRegistry, + make_cg_with_row_key_lock_registry, +) + + +@pytest.fixture +def local_ocdbt(): + """3-scale file-backed OCDBT store with factor (2,2,1) between scales. + + Matches the fixture in test_ocdbt.py so downsample behaviour can be + exercised end-to-end against real tensorstore handles. + """ + tmpdir = tempfile.mkdtemp() + base = f"file://{tmpdir}/ocdbt/base" + mm = {"type": "segmentation", "data_type": "uint64", "num_channels": 1} + + def mk(size, resolution, extra_mm=None): + spec = { + "driver": "neuroglancer_precomputed", + "kvstore": {"driver": "ocdbt", "base": base}, + "scale_metadata": { + "size": size, + "resolution": resolution, + "encoding": "compressed_segmentation", + "compressed_segmentation_block_size": [8, 8, 8], + "chunk_size": [32, 32, 32], + }, + } + if extra_mm: + spec["multiscale_metadata"] = extra_mm + return ts.open(spec, create=True).result() + + scales = [ + mk([64, 64, 32], [4, 4, 40], extra_mm=mm), + mk([32, 32, 32], [8, 8, 40]), + mk([16, 16, 32], [16, 16, 40]), + ] + resolutions = [[4, 4, 40], [8, 8, 40], [16, 16, 40]] + + yield {"scales": scales, "resolutions": resolutions} + shutil.rmtree(tmpdir) + + +def _make_meta(local_ocdbt_, voxel_bounds=None): + """Minimal ChunkedGraphMeta stand-in with only the attributes downsample reads.""" + scales = local_ocdbt_["scales"] + if voxel_bounds is None: 
+ # Full volume from scale 0. + dom = scales[0].domain + voxel_bounds = np.array( + [ + [dom[0].inclusive_min, dom[0].exclusive_max], + [dom[1].inclusive_min, dom[1].exclusive_max], + [dom[2].inclusive_min, dom[2].exclusive_max], + ], + dtype=int, + ) + return SimpleNamespace( + ws_ocdbt_scales=scales, + ws_ocdbt_resolutions=local_ocdbt_["resolutions"], + voxel_bounds=voxel_bounds, + ) + + +class TestBlockGeometry: + def test_num_output_mips(self, local_ocdbt): + meta = _make_meta(local_ocdbt) + assert ds.num_output_mips(meta) == 2 + + def test_uniform_factor(self, local_ocdbt): + meta = _make_meta(local_ocdbt) + assert ds.uniform_factor(meta) == (2, 2, 1) + + def test_non_uniform_factor_asserts(self, local_ocdbt): + meta = _make_meta(local_ocdbt) + meta.ws_ocdbt_resolutions = [[4, 4, 40], [8, 8, 40], [8, 16, 40]] + with pytest.raises(AssertionError): + ds.uniform_factor(meta) + + def test_block_shape_covers_one_coarsest_chunk(self, local_ocdbt): + # coarsest chunk = 32 mip-2 voxels per axis; factor^2 = (4,4,1). + # Block = 32 * (4,4,1) = (128, 128, 32) base voxels. + meta = _make_meta(local_ocdbt) + assert tuple(ds.block_shape(meta).tolist()) == (128, 128, 32) + + def test_blocks_for_bbox_single(self, local_ocdbt): + meta = _make_meta(local_ocdbt) + # Tiny bbox entirely inside block (0,0,0). + blocks = ds.blocks_for_bbox(meta, [10, 10, 5], [20, 20, 10]) + assert blocks == [(0, 0, 0)] + + def test_blocks_for_bbox_spans_block_boundary(self, local_ocdbt): + meta = _make_meta(local_ocdbt) + # Block shape = (128,128,32). Bbox from (120,0,0) to (200,50,10) + # crosses the x-axis boundary at 128. 
+ blocks = ds.blocks_for_bbox(meta, [120, 0, 0], [200, 50, 10]) + assert blocks == sorted([(0, 0, 0), (1, 0, 0)]) + + def test_block_base_bbox_roundtrip(self, local_ocdbt): + meta = _make_meta(local_ocdbt) + lo, hi = ds.block_base_bbox(meta, (0, 0, 0)) + assert tuple(lo.tolist()) == (0, 0, 0) + assert tuple(hi.tolist()) == (128, 128, 32) + + lo, hi = ds.block_base_bbox(meta, (2, 1, 0)) + assert tuple(lo.tolist()) == (256, 128, 0) + assert tuple(hi.tolist()) == (384, 256, 32) + + +class TestProcessBlockInMemory: + def test_writes_to_every_non_base_scale(self, local_ocdbt): + """Base region intersected by bbox propagates to mip 1 and mip 2.""" + scales = local_ocdbt["scales"] + # Seed base with a constant label. + data = np.full((32, 32, 32), 7, dtype=np.uint64) + scales[0][0:32, 0:32, 0:32, :].write(data[..., np.newaxis]).result() + + meta = _make_meta(local_ocdbt) + # Block (0,0,0) has shape (128,128,32); only its (0..32, 0..32, 0..32) + # subregion has real data — the rest is zeros. + ds.process_block( + meta, (0, 0, 0), [(np.array([0, 0, 0]), np.array([32, 32, 32]))] + ) + + mip1 = scales[1][0:16, 0:16, 0:32, :].read().result() + mip2 = scales[2][0:8, 0:8, 0:32, :].read().result() + assert (mip1 == 7).all() + assert (mip2 == 7).all() + + def test_region_outside_bbox_stays_zero(self, local_ocdbt): + """Mip tiles whose base footprint misses the bbox are not written.""" + scales = local_ocdbt["scales"] + # Seed base with 3 inside the edit bbox only. + edit_data = np.full((16, 16, 16), 3, dtype=np.uint64) + scales[0][0:16, 0:16, 0:16, :].write(edit_data[..., np.newaxis]).result() + + meta = _make_meta(local_ocdbt) + ds.process_block( + meta, (0, 0, 0), [(np.array([0, 0, 0]), np.array([16, 16, 16]))] + ) + + # Tile inside edit: written with label 3. + mip1_inside = scales[1][0:8, 0:8, 0:16, :].read().result() + assert (mip1_inside == 3).all() + # Tile outside edit (far corner of block): still zero. 
+ mip1_outside = scales[1][12:16, 12:16, 16:32, :].read().result() + assert (mip1_outside == 0).all() + + +class TestProcessBlockDispatcher: + def test_selects_in_memory_when_under_budget(self, local_ocdbt, monkeypatch): + """Typical small affected region → in-memory path.""" + calls = {"in_memory": 0, "per_mip": 0} + monkeypatch.setattr( + ds, + "_process_block_in_memory", + lambda *a, **kw: calls.__setitem__("in_memory", calls["in_memory"] + 1), + ) + monkeypatch.setattr( + ds, + "_process_block_per_mip", + lambda *a, **kw: calls.__setitem__("per_mip", calls["per_mip"] + 1), + ) + meta = _make_meta(local_ocdbt) + ds.process_block( + meta, (0, 0, 0), [(np.array([0, 0, 0]), np.array([16, 16, 16]))] + ) + assert calls == {"in_memory": 1, "per_mip": 0} + + def test_selects_per_mip_when_over_budget(self, local_ocdbt, monkeypatch): + """When the base read would exceed budget, the per-mip path runs.""" + calls = {"in_memory": 0, "per_mip": 0} + monkeypatch.setattr( + ds, + "_process_block_in_memory", + lambda *a, **kw: calls.__setitem__("in_memory", calls["in_memory"] + 1), + ) + monkeypatch.setattr( + ds, + "_process_block_per_mip", + lambda *a, **kw: calls.__setitem__("per_mip", calls["per_mip"] + 1), + ) + meta = _make_meta(local_ocdbt) + ds.process_block( + meta, + (0, 0, 0), + [(np.array([0, 0, 0]), np.array([128, 128, 32]))], + memory_budget_bytes=1, # force the fallback + ) + assert calls == {"in_memory": 0, "per_mip": 1} + + +class TestDownsampleBlockRowKey: + def test_length(self): + assert len(_downsample_block_lock_row_key((0, 0, 0))) == 26 + + def test_deterministic(self): + assert _downsample_block_lock_row_key( + (7, 8, 9) + ) == _downsample_block_lock_row_key((7, 8, 9)) + + def test_distinct_coords_distinct_keys(self): + a = _downsample_block_lock_row_key((1, 0, 0)) + b = _downsample_block_lock_row_key((0, 1, 0)) + assert a != b + + def test_hash_prefix_scatters(self): + """Adjacent block coords should not produce adjacent row keys (the whole + point of 
the hash prefix).""" + # Gather hash prefixes for a line of adjacent coords; they should span + # many distinct first-bytes, not cluster in one byte. + prefixes = {_downsample_block_lock_row_key((i, 0, 0))[0] for i in range(128)} + assert len(prefixes) > 32 + + +class TestDownsampleBlockLock: + def test_acquire_and_release(self): + registry = RowKeyLockRegistry() + cg = make_cg_with_row_key_lock_registry(registry) + with DownsampleBlockLock(cg, [(0, 0, 0), (1, 0, 0)], np.uint64(42)): + assert len(registry._held) == 2 + assert registry._held == {} + + def test_non_overlapping_concurrent(self): + """Two locks on disjoint block sets can coexist.""" + registry = RowKeyLockRegistry() + cg = make_cg_with_row_key_lock_registry(registry) + l1 = DownsampleBlockLock(cg, [(0, 0, 0)], np.uint64(1)) + l2 = DownsampleBlockLock(cg, [(5, 5, 5)], np.uint64(2)) + l1.__enter__() + l2.__enter__() + assert len(registry._held) == 2 + l1.__exit__(None, None, None) + l2.__exit__(None, None, None) + assert registry._held == {} + + def test_overlapping_contends(self, monkeypatch): + """Two overlapping acquisitions serialize: second blocks until first releases.""" + # Short backoff so the waiting thread retries quickly after release. + monkeypatch.setattr(DownsampleBlockLock, "_ACQUIRE_BACKOFF_BASE_SEC", 0.05) + + registry = RowKeyLockRegistry() + cg = make_cg_with_row_key_lock_registry(registry) + + l1 = DownsampleBlockLock(cg, [(0, 0, 0)], np.uint64(1)) + l1.__enter__() + + second_entered = threading.Event() + second_failed = threading.Event() + + def second(): + lock = DownsampleBlockLock(cg, [(0, 0, 0)], np.uint64(2)) + try: + lock.__enter__() + second_entered.set() + lock.__exit__(None, None, None) + except exceptions.LockingError: + second_failed.set() + + t = threading.Thread(target=second) + t.start() + time.sleep(0.2) + # l1 is still holding; second should not have entered. + assert not second_entered.is_set() + # Now release; second should succeed on its next retry. 
+ l1.__exit__(None, None, None) + t.join(timeout=2.0) + assert second_entered.is_set() + assert not second_failed.is_set() + assert registry._held == {} + + def test_partial_acquire_released_on_failure(self, monkeypatch): + """If any coord in the set fails to lock, prior ones are released.""" + monkeypatch.setattr(DownsampleBlockLock, "_MAX_ACQUIRE_ATTEMPTS", 2) + monkeypatch.setattr(DownsampleBlockLock, "_ACQUIRE_BACKOFF_BASE_SEC", 0.01) + + registry = RowKeyLockRegistry() + # Pre-hold (1,0,0) so the second coord always fails. + registry.lock_by_row_key( + _downsample_block_lock_row_key((1, 0, 0)), np.uint64(99) + ) + + cg = make_cg_with_row_key_lock_registry(registry) + lock = DownsampleBlockLock(cg, [(0, 0, 0), (1, 0, 0)], np.uint64(1)) + with pytest.raises(exceptions.LockingError): + lock.__enter__() + # Only (1,0,0) should remain held, by the pre-existing holder. + assert len(registry._held) == 1 + only_key = next(iter(registry._held)) + assert only_key == _downsample_block_lock_row_key((1, 0, 0)) diff --git a/pychunkedgraph/tests/graph/test_edges_sv.py b/pychunkedgraph/tests/graph/test_edges_sv.py index d8ad0f2ba..9265cf85c 100644 --- a/pychunkedgraph/tests/graph/test_edges_sv.py +++ b/pychunkedgraph/tests/graph/test_edges_sv.py @@ -1,4 +1,4 @@ -"""Comprehensive tests for pychunkedgraph.graph.edges_sv — edge routing after SV split.""" +"""Comprehensive tests for pychunkedgraph.graph.sv_split.edges — edge routing after SV split.""" import numpy as np import pytest @@ -6,7 +6,7 @@ from pychunkedgraph.graph import basetypes from pychunkedgraph.graph.exceptions import PostconditionError -from pychunkedgraph.graph.edges_sv import ( +from pychunkedgraph.graph.sv_split.edges import ( _get_new_edges, _match_by_label, _match_by_proximity, diff --git a/pychunkedgraph/tests/graph/test_edits_sv.py b/pychunkedgraph/tests/graph/test_edits_sv.py index bced0a070..1d852b433 100644 --- a/pychunkedgraph/tests/graph/test_edits_sv.py +++ 
b/pychunkedgraph/tests/graph/test_edits_sv.py @@ -1,14 +1,16 @@ -"""Tests for pychunkedgraph.graph.edits_sv""" +"""Tests for pychunkedgraph.graph.sv_split.edits""" import numpy as np import pytest from collections import defaultdict from unittest.mock import MagicMock, patch -from pychunkedgraph.graph.edits_sv import ( +from pychunkedgraph.graph.sv_split.edits import ( + _coords_bbox, _voxel_crop, _parse_results, copy_parents_and_add_lineage, + plan_sv_splits, ) from pychunkedgraph.graph import attributes, basetypes @@ -237,3 +239,130 @@ def test_operation_id_stored(self): assert val_dict[attributes.OperationLogs.OperationID] == 99 op_id_found = True assert op_id_found + + def test_time_stamp_threaded_to_new_sv_writes(self): + """New-SV writes (FormerIdentity/OperationID on new, NewIdentity + on old) land at `time_stamp`. Parent-copy and Child-list writes + preserve the old cell's timestamp so pre-op readers still see + the old hierarchy. + """ + from datetime import datetime, timezone + + old = np.uint64(10) + new1 = np.uint64(101) + parent = np.uint64(1000) + + old_cell_ts = 42 # old cell's timestamp, preserved on Parent/Child copies + op_ts = datetime(2026, 4, 23, tzinfo=timezone.utc) # op's logical write time + + parent_cells_map = {old: [_FakeCell(parent, timestamp=old_cell_ts)]} + children_cells_map = { + parent: [ + _FakeCell( + np.array([old], dtype=basetypes.NODE_ID), timestamp=old_cell_ts + ) + ] + } + cg = self._make_cg(parent_cells_map, children_cells_map) + + copy_parents_and_add_lineage( + cg, operation_id=7, old_new_map={old: {new1}}, time_stamp=op_ts + ) + + # Classify each mutate_row call by which column it writes. + for call in cg.client.mutate_row.call_args_list: + val_dict = call[0][1] + kw = call[1] + ts = kw.get("time_stamp") + cols = set(val_dict.keys()) + + if attributes.Hierarchy.FormerIdentity in cols: + # New-SV lineage write — should use op's time_stamp. 
+ assert ts == op_ts, f"FormerIdentity write ts={ts}, expected {op_ts}" + elif attributes.Hierarchy.NewIdentity in cols: + # Old-SV NewIdentity write — should use op's time_stamp. + assert ts == op_ts, f"NewIdentity write ts={ts}, expected {op_ts}" + elif attributes.Hierarchy.Parent in cols: + # Copied-parent write — preserves old cell's timestamp. + assert ( + ts == old_cell_ts + ), f"Parent-copy write ts={ts}, expected {old_cell_ts}" + elif attributes.Hierarchy.Child in cols: + # Updated-children write on L2 parent — preserves old timestamp. + assert ( + ts == old_cell_ts + ), f"Child-list write ts={ts}, expected {old_cell_ts}" + + +# ============================================================ +# Tests: _coords_bbox / plan_sv_splits bbox is seed-driven, not rep-driven +# ============================================================ +class TestCoordsBbox: + def _make_cg(self, chunk_size=(64, 64, 64), volume=(1024, 1024, 1024)): + cg = MagicMock() + cg.meta.graph_config.CHUNK_SIZE = list(chunk_size) + cg.meta.voxel_bounds = np.array( + [[0, volume[0]], [0, volume[1]], [0, volume[2]]] + ) + cg.get_chunk_id.side_effect = lambda layer, x, y, z: ( + (layer << 60) | (x << 40) | (y << 20) | z + ) + return cg + + def test_envelope_around_seeds_with_one_chunk_margin(self): + cg = self._make_cg(chunk_size=(64, 64, 64)) + src = np.array([[100, 200, 300]]) + sink = np.array([[150, 250, 350]]) + bbs, bbe = _coords_bbox(cg, src, sink) + # min - chunk_size, max + chunk_size, clipped to volume bounds. + np.testing.assert_array_equal(bbs, np.array([100 - 64, 200 - 64, 300 - 64])) + np.testing.assert_array_equal(bbe, np.array([150 + 64, 250 + 64, 350 + 64])) + + def test_clipped_to_volume_bounds(self): + cg = self._make_cg(chunk_size=(64, 64, 64), volume=(256, 256, 256)) + src = np.array([[10, 10, 10]]) + sink = np.array([[250, 250, 250]]) + bbs, bbe = _coords_bbox(cg, src, sink) + # Lower seed - 64 = -54 → clipped to 0; upper seed + 64 = 314 → clipped to 256. 
+ np.testing.assert_array_equal(bbs, np.array([0, 0, 0])) + np.testing.assert_array_equal(bbe, np.array([256, 256, 256])) + + def test_plan_sv_splits_bbox_independent_of_rep_extent(self): + """The returned per-task bbox follows the seeds, not the rep's + cross-chunk pieces. A rep whose pieces span the whole volume + produces the same tight bbox as a rep with one piece, given the + same src/sink coords. + """ + cg = self._make_cg(chunk_size=(64, 64, 64), volume=(1024, 1024, 1024)) + + # Two source/sink IDs that map to the same rep — the SV-split + # trigger condition. The rep's other pieces (b..z) sit far from + # the seeds. They would have ballooned the old `_rep_bbox`; the + # new `_coords_bbox` ignores them. + rep = np.uint64(1) + sv_remapping = { + np.uint64(10): rep, # src + np.uint64(20): rep, # sink + **{np.uint64(100 + i): rep for i in range(28)}, # 28 distant pieces + } + + source_ids = np.array([10], dtype=basetypes.NODE_ID) + sink_ids = np.array([20], dtype=basetypes.NODE_ID) + source_coords = np.array([[100, 200, 300]]) + sink_coords = np.array([[150, 250, 350]]) + + tasks, _ = plan_sv_splits( + cg, + sv_remapping=sv_remapping, + source_ids=source_ids, + sink_ids=sink_ids, + source_coords=source_coords, + sink_coords=sink_coords, + ) + assert len(tasks) == 1 + np.testing.assert_array_equal( + tasks[0].bbs, np.array([100 - 64, 200 - 64, 300 - 64]) + ) + np.testing.assert_array_equal( + tasks[0].bbe, np.array([150 + 64, 250 + 64, 350 + 64]) + ) diff --git a/pychunkedgraph/tests/graph/test_locks.py b/pychunkedgraph/tests/graph/test_locks.py index 97da9334c..0d82262a9 100644 --- a/pychunkedgraph/tests/graph/test_locks.py +++ b/pychunkedgraph/tests/graph/test_locks.py @@ -1,10 +1,23 @@ +import threading +import time from time import sleep from datetime import datetime, timedelta, UTC import numpy as np import pytest -from ..helpers import create_chunk, to_label +from ..helpers import ( + RowKeyLockRegistry, + create_chunk, + 
make_cg_with_row_key_lock_registry, + to_label, +) +from ...graph import attributes, exceptions +from ...graph.locks import ( + IndefiniteL2ChunkLock, + L2ChunkLock, + _l2_chunk_lock_row_key, +) from ...graph.lineage import get_future_root_ids from ...ingest.create.parent_layer import add_parent_chunk @@ -702,6 +715,30 @@ def test_indefiniterootlock_exit_handles_exception(self): # Should not raise lock.__exit__(None, None, None) + def test_indefiniterootlock_exit_holds_on_exception_path(self): + """When `__exit__` is called with a propagating exception, cells + stay held — partial bigtable hierarchy writes may have landed + and further ops must refuse until operator recovery runs. + """ + cg = _make_mock_cg() + root_ids = np.array([np.uint64(100), np.uint64(101)]) + cg.client.lock_roots_indefinitely.return_value = ( + True, + list(root_ids), + [], + ) + + lock = IndefiniteRootLock( + cg, + np.uint64(10), + root_ids, + future_root_ids_d=defaultdict(list), + ) + lock.__enter__() + lock.__exit__(ValueError, ValueError("boom"), None) + + cg.client.unlock_indefinitely_locked_root.assert_not_called() + class TestIndefiniteRootLockComputesFutureRootIds: def test_indefiniterootlock_computes_future_root_ids(self): @@ -750,3 +787,266 @@ def test_rootlock_as_context_manager(self): assert lock.lock_acquired is True cg.client.unlock_root.assert_called_once() + + +class TestL2ChunkLockRowKey: + def test_length(self): + assert len(_l2_chunk_lock_row_key(0)) == 10 + + def test_deterministic(self): + assert _l2_chunk_lock_row_key(0xDEADBEEF) == _l2_chunk_lock_row_key(0xDEADBEEF) + + def test_distinct_chunks_distinct_keys(self): + assert _l2_chunk_lock_row_key(42) != _l2_chunk_lock_row_key(43) + + def test_hash_prefix_scatters(self): + """Adjacent chunk IDs should not cluster in one first-byte prefix — + that's the whole point of the hash prefix.""" + prefixes = {_l2_chunk_lock_row_key(i)[0] for i in range(256)} + # blake2b over 8 bytes of changing input distributes uniformly. 
+ assert len(prefixes) > 128 + + +class TestL2ChunkLock: + def test_acquire_and_release(self): + registry = RowKeyLockRegistry() + cg = make_cg_with_row_key_lock_registry(registry) + with L2ChunkLock(cg, [np.uint64(1), np.uint64(2)], np.uint64(42)): + assert len(registry._held) == 2 + assert registry._held == {} + + def test_non_overlapping_concurrent(self): + """Disjoint chunk sets can coexist — no shared row keys.""" + registry = RowKeyLockRegistry() + cg = make_cg_with_row_key_lock_registry(registry) + l1 = L2ChunkLock(cg, [np.uint64(1)], np.uint64(1)) + l2 = L2ChunkLock(cg, [np.uint64(5)], np.uint64(2)) + l1.__enter__() + l2.__enter__() + assert len(registry._held) == 2 + l1.__exit__(None, None, None) + l2.__exit__(None, None, None) + assert registry._held == {} + + def test_overlapping_contends(self, monkeypatch): + """Two overlapping acquisitions serialize: second blocks until first releases.""" + monkeypatch.setattr(L2ChunkLock, "_ACQUIRE_BACKOFF_BASE_SEC", 0.05) + + registry = RowKeyLockRegistry() + cg = make_cg_with_row_key_lock_registry(registry) + + l1 = L2ChunkLock(cg, [np.uint64(7)], np.uint64(1)) + l1.__enter__() + + second_entered = threading.Event() + second_failed = threading.Event() + + def second(): + lock = L2ChunkLock(cg, [np.uint64(7)], np.uint64(2)) + try: + lock.__enter__() + second_entered.set() + lock.__exit__(None, None, None) + except exceptions.LockingError: + second_failed.set() + + t = threading.Thread(target=second) + t.start() + time.sleep(0.2) + assert not second_entered.is_set() + l1.__exit__(None, None, None) + t.join(timeout=2.0) + assert second_entered.is_set() + assert not second_failed.is_set() + assert registry._held == {} + + def test_partial_acquire_released_on_failure(self, monkeypatch): + """If any chunk in the set fails to lock, prior ones are released.""" + monkeypatch.setattr(L2ChunkLock, "_MAX_ACQUIRE_ATTEMPTS", 2) + monkeypatch.setattr(L2ChunkLock, "_ACQUIRE_BACKOFF_BASE_SEC", 0.01) + + registry = 
RowKeyLockRegistry() + registry.lock_by_row_key(_l2_chunk_lock_row_key(np.uint64(2)), np.uint64(99)) + + cg = make_cg_with_row_key_lock_registry(registry) + lock = L2ChunkLock(cg, [np.uint64(1), np.uint64(2)], np.uint64(1)) + with pytest.raises(exceptions.LockingError): + lock.__enter__() + # Only chunk 2 remains held, by the pre-existing holder. + assert len(registry._held) == 1 + assert next(iter(registry._held)) == _l2_chunk_lock_row_key(np.uint64(2)) + + def test_privileged_mode_skips_acquire(self): + """Replay path: indefinite cells from the crashed op are still + set, so a normal temporal acquire would refuse. Privileged mode + bypasses the acquire entirely — the indefinite cells are the + de-facto lock and the inner `IndefiniteL2ChunkLock(privileged=True)` + releases them on exit. + """ + registry = RowKeyLockRegistry() + # Crashed op's indefinite cells block a normal temporal acquire. + crashed_op = np.uint64(42) + for c in (np.uint64(1), np.uint64(2)): + registry.lock_by_row_key_indefinitely(_l2_chunk_lock_row_key(c), crashed_op) + + cg = make_cg_with_row_key_lock_registry(registry) + + # Normal acquire refuses because indefinite is held. + normal = L2ChunkLock(cg, [np.uint64(1), np.uint64(2)], np.uint64(99)) + with pytest.raises(exceptions.LockingError): + normal.__enter__() + + # Privileged acquire — called from replay with the same op_id as + # the crashed op — skips the acquire and returns cleanly. + priv = L2ChunkLock( + cg, [np.uint64(1), np.uint64(2)], crashed_op, privileged_mode=True + ) + priv.__enter__() + priv.__exit__(None, None, None) + # Indefinite cells still held (privileged-L2ChunkLock doesn't + # touch them — that's IndefiniteL2ChunkLock(privileged=True)'s job). + assert len(registry._held_indefinite) == 2 + + +class TestIndefiniteL2ChunkLock: + """`IndefiniteL2ChunkLock` lifecycle: acquire + scope write on enter, + release + scope clear on exit; privileged mode releases pre-existing + cells left by a crashed op. 
+ """ + + def _scope_mutate_calls(self, cg): + """Extract (row_key, scope_value) from cg.client.mutate_row calls + that set `L2ChunkLockScope`. Lets tests assert on what was written.""" + calls = [] + for call in cg.client.mutate_row.call_args_list: + row_key, val_dict = call[0][:2] + if attributes.OperationLogs.L2ChunkLockScope in val_dict: + calls.append( + (row_key, val_dict[attributes.OperationLogs.L2ChunkLockScope]) + ) + return calls + + def test_enter_writes_scope_and_acquires_cells(self): + registry = RowKeyLockRegistry() + cg = make_cg_with_row_key_lock_registry(registry) + chunks = [np.uint64(3), np.uint64(1), np.uint64(2)] + op_id = np.uint64(42) + + lock = IndefiniteL2ChunkLock(cg, chunks, op_id) + lock.__enter__() + try: + # Every chunk now has an indefinite cell. + assert len(registry._held_indefinite) == 3 + # Scope written to op-log row; value is the sorted chunk list. + scope_calls = self._scope_mutate_calls(cg) + non_empty = [c for c in scope_calls if len(c[1]) > 0] + assert len(non_empty) == 1 + assert list(non_empty[0][1]) == [1, 2, 3] + finally: + lock.__exit__(None, None, None) + + def test_exit_releases_cells_and_clears_scope(self): + registry = RowKeyLockRegistry() + cg = make_cg_with_row_key_lock_registry(registry) + chunks = [np.uint64(1), np.uint64(2)] + with IndefiniteL2ChunkLock(cg, chunks, np.uint64(42)): + pass + # Cells released. + assert registry._held_indefinite == {} + # Scope cleared: one write of an empty array to L2ChunkLockScope. + empty_calls = [c for c in self._scope_mutate_calls(cg) if len(c[1]) == 0] + assert len(empty_calls) == 1 + + def test_privileged_mode_releases_preexisting(self): + """Crashed op left indefinite cells under its op_id; the replay + re-enters with privileged_mode=True and the `__exit__` is expected + to delete those pre-existing cells (value-matched by op_id). 
+ """ + registry = RowKeyLockRegistry() + op_id = np.uint64(42) + chunks = [np.uint64(10), np.uint64(20)] + for c in chunks: + assert registry.lock_by_row_key_indefinitely( + _l2_chunk_lock_row_key(c), op_id + ) + assert len(registry._held_indefinite) == 2 + + cg = make_cg_with_row_key_lock_registry(registry) + with IndefiniteL2ChunkLock(cg, chunks, op_id, privileged_mode=True): + # Privileged enter skips acquire, so pre-existing cells persist. + assert len(registry._held_indefinite) == 2 + # Privileged mode does not re-write the scope either; only the + # clear-on-exit writes `L2ChunkLockScope`. + assert self._scope_mutate_calls(cg) == [] + # Exit released the pre-existing cells. + assert registry._held_indefinite == {} + + def test_double_acquire_fails(self): + registry = RowKeyLockRegistry() + cg = make_cg_with_row_key_lock_registry(registry) + op_a = np.uint64(1) + op_b = np.uint64(2) + with IndefiniteL2ChunkLock(cg, [np.uint64(5)], op_a): + lock_b = IndefiniteL2ChunkLock(cg, [np.uint64(5)], op_b) + with pytest.raises(exceptions.LockingError): + lock_b.__enter__() + # Op A's cell still held. + assert len(registry._held_indefinite) == 1 + + def test_replay_nested_privileged_clears_crashed_cells(self): + """Replay lock-dance against a crashed op's pre-existing cells. + + Simulates what `MulticutOperation._apply` does during replay: + `with L2ChunkLock(privileged=True): with IndefiniteL2ChunkLock( + privileged=True): ...`. Both locks must succeed despite indefinite + cells being pre-held, and the inner `__exit__` must release them. + + This regresses the bug where `L2ChunkLock` lacked a privileged + escape hatch — the temporal acquire would refuse because + `lock_by_row_key_with_indefinite` sees the crashed op's + indefinite cell. + """ + registry = RowKeyLockRegistry() + crashed_op = np.uint64(42) + chunks = [np.uint64(1), np.uint64(2), np.uint64(3)] + # Seed crashed op's indefinite cells. 
+ for c in chunks: + registry.lock_by_row_key_indefinitely(_l2_chunk_lock_row_key(c), crashed_op) + assert len(registry._held_indefinite) == 3 + + cg = make_cg_with_row_key_lock_registry(registry) + + # Replay's exact lock-dance from operation.py _apply. + with L2ChunkLock(cg, chunks, crashed_op, privileged_mode=True): + with IndefiniteL2ChunkLock(cg, chunks, crashed_op, privileged_mode=True): + # Simulated replay writes would happen here; we just + # assert the locks entered without raising. + pass + # Crashed op's cells released. + assert registry._held_indefinite == {} + + def test_exit_holds_on_exception_path(self): + """When `__exit__` is called with a propagating exception, cells + stay held and the op-log scope is NOT cleared — partial OCDBT / + bigtable writes may exist and subsequent ops must refuse until + operator recovery runs. + """ + registry = RowKeyLockRegistry() + cg = make_cg_with_row_key_lock_registry(registry) + chunks = [np.uint64(1), np.uint64(2)] + op_id = np.uint64(42) + + lock = IndefiniteL2ChunkLock(cg, chunks, op_id) + lock.__enter__() + # Enter wrote scope + held cells. + assert len(registry._held_indefinite) == 2 + scope_writes = self._scope_mutate_calls(cg) + assert any(len(v) > 0 for _, v in scope_writes) + + # Simulate an exception propagating through the `with` block. + lock.__exit__(ValueError, ValueError("boom"), None) + + # Cells still held, scope not cleared (no empty-array mutate). 
+ assert len(registry._held_indefinite) == 2 + empty_writes = [(k, v) for k, v in self._scope_mutate_calls(cg) if len(v) == 0] + assert empty_writes == [] diff --git a/pychunkedgraph/tests/graph/test_meta.py b/pychunkedgraph/tests/graph/test_meta.py index fc24f9917..7ee2896a4 100644 --- a/pychunkedgraph/tests/graph/test_meta.py +++ b/pychunkedgraph/tests/graph/test_meta.py @@ -425,7 +425,10 @@ def test_ws_cv_redis_cached(self, mock_get_redis, mock_cv_cls): @patch("pychunkedgraph.graph.meta.CloudVolume") @patch("pychunkedgraph.graph.meta.get_redis_connection") def test_ws_cv_redis_failure_fallback(self, mock_get_redis, mock_cv_cls): - """When redis raises, ws_cv falls back to direct CloudVolume.""" + """When redis raises, ws_cv still fetches `.info` (via the loader) and + then constructs the cached CloudVolume with that info — it just skips + writing the info back to redis. + """ gc = GraphConfig(ID="test_graph", CHUNK_SIZE=[64, 64, 64]) ds = DataSource(WATERSHED="gs://bucket/ws", DATA_VERSION=4) meta = ChunkedGraphMeta(gc, ds) @@ -439,8 +442,12 @@ def test_ws_cv_redis_failure_fallback(self, mock_get_redis, mock_cv_cls): result = meta.ws_cv assert result is mock_cv_instance - # Should have been called without info kwarg (fallback) - mock_cv_cls.assert_called_with("gs://bucket/ws", progress=False) + # Loader path runs even when redis is dead — fetch .info first… + mock_cv_cls.assert_any_call("gs://bucket/ws", progress=False) + # …then construct the cached handle with explicit info=. 
+ mock_cv_cls.assert_any_call( + "gs://bucket/ws", info={"scales": []}, progress=False + ) @patch("pychunkedgraph.graph.meta.CloudVolume") @patch("pychunkedgraph.graph.meta.get_redis_connection") @@ -462,9 +469,14 @@ def test_ws_cv_caches_to_redis(self, mock_get_redis, mock_cv_cls): result = meta.ws_cv assert result is mock_cv_instance - # The fallback CloudVolume call (no info= kwarg) - mock_cv_cls.assert_called_with("gs://bucket/ws", progress=False) - # Should try to cache in redis + # Cache-miss path: loader fetches .info, then handle is constructed + # with explicit info=, and the info gets written to redis. + mock_cv_cls.assert_any_call("gs://bucket/ws", progress=False) + mock_cv_cls.assert_any_call( + "gs://bucket/ws", + info={"scales": [{"resolution": [8, 8, 40]}]}, + progress=False, + ) mock_redis.set.assert_called_once() @patch("pychunkedgraph.graph.meta.CloudVolume") @@ -582,8 +594,17 @@ def test_ws_ocdbt_asserts_when_not_ocdbt(self): with pytest.raises(AssertionError, match="ocdbt"): _ = meta.ws_ocdbt + @patch("pychunkedgraph.graph.meta.read_populate_meta", return_value=None) + @patch("pychunkedgraph.graph.meta.ensure_fork_synced") + @patch("pychunkedgraph.graph.meta.fork_exists", return_value=True) @patch("pychunkedgraph.graph.meta.get_seg_source_and_destination_ocdbt") - def test_ws_ocdbt_returns_base_scale(self, mock_get_ocdbt): + def test_ws_ocdbt_returns_base_scale( + self, + mock_get_ocdbt, + _mock_fork_exists, + _mock_ensure_synced, + _mock_read_populate_meta, + ): gc = GraphConfig(ID="test_graph", CHUNK_SIZE=[64, 64, 64]) ds = DataSource(WATERSHED="gs://bucket/ws", DATA_VERSION=4) meta = ChunkedGraphMeta(gc, ds, custom_data={"seg": {"ocdbt": True}}) @@ -600,10 +621,23 @@ def test_ws_ocdbt_returns_base_scale(self, mock_get_ocdbt): assert meta.ws_ocdbt is mock_dst_base assert meta.ws_ocdbt_scales == [mock_dst_base, mock_dst_mip1] assert meta.ws_ocdbt_resolutions == [[4, 4, 40], [8, 8, 40]] - mock_get_ocdbt.assert_called_once_with("gs://bucket/ws", 
"test_graph") + # `get_seg_source_and_destination_ocdbt` is called with a third positional + # arg (the resolved OcdbtConfig) — verify only the (ws, graph_id) part; + # the OcdbtConfig.resolve contract has its own unit tests. + mock_get_ocdbt.assert_called_once() + assert mock_get_ocdbt.call_args.args[:2] == ("gs://bucket/ws", "test_graph") + @patch("pychunkedgraph.graph.meta.read_populate_meta", return_value=None) + @patch("pychunkedgraph.graph.meta.ensure_fork_synced") + @patch("pychunkedgraph.graph.meta.fork_exists", return_value=True) @patch("pychunkedgraph.graph.meta.get_seg_source_and_destination_ocdbt") - def test_ws_ocdbt_cached(self, mock_get_ocdbt): + def test_ws_ocdbt_cached( + self, + mock_get_ocdbt, + _mock_fork_exists, + _mock_ensure_synced, + _mock_read_populate_meta, + ): gc = GraphConfig(ID="test_graph", CHUNK_SIZE=[64, 64, 64]) ds = DataSource(WATERSHED="gs://bucket/ws", DATA_VERSION=4) meta = ChunkedGraphMeta(gc, ds, custom_data={"seg": {"ocdbt": True}}) diff --git a/pychunkedgraph/tests/graph/test_multicut.py b/pychunkedgraph/tests/graph/test_multicut.py index 590476ffd..4edd5962f 100644 --- a/pychunkedgraph/tests/graph/test_multicut.py +++ b/pychunkedgraph/tests/graph/test_multicut.py @@ -2,8 +2,7 @@ import pytest from ...graph.edges import Edges -from ...graph import exceptions -from ...graph.cutting import run_multicut +from ...graph.cutting import Cut, SvSplitRequired, run_multicut class TestGraphMultiCut: @@ -25,13 +24,15 @@ def test_cut_multi_tree(self, gen_graph): source_ids = np.array([1, 2], dtype=np.uint64) sink_ids = np.array([5, 6], dtype=np.uint64) - cut_edges = run_multicut( + result = run_multicut( edges, source_ids, sink_ids, path_augment=False, disallow_isolating_cut=False, ) + assert isinstance(result, Cut) + cut_edges = result.atomic_edges assert cut_edges.shape[0] > 0 # Verify the cut actually separates sources from sinks @@ -64,14 +65,19 @@ def test_path_augmented_multicut(self, sv_data): edges = Edges( sv_edges[:, 0], 
sv_edges[:, 1], affinities=sv_affinity, areas=sv_area ) - cut_edges_aug = run_multicut(edges, sv_sources, sv_sinks, path_augment=True) - assert cut_edges_aug.shape[0] == 350 + result = run_multicut(edges, sv_sources, sv_sinks, path_augment=True) + assert isinstance(result, Cut) + assert result.atomic_edges.shape[0] == 350 - with pytest.raises(exceptions.SupervoxelSplitRequiredError): - run_multicut( - edges, - sv_sources, - sv_sinks, - path_augment=False, - sv_split_supported=True, - ) + # Without path augmentation on this fixture, source/sink share a + # cross-chunk representative — returned as SvSplitRequired when + # sv_split_supported=True (no exception escapes run_multicut). + sv_result = run_multicut( + edges, + sv_sources, + sv_sinks, + path_augment=False, + sv_split_supported=True, + ) + assert isinstance(sv_result, SvSplitRequired) + assert sv_result.sv_remapping # non-empty mapping diff --git a/pychunkedgraph/tests/graph/test_ocdbt.py b/pychunkedgraph/tests/graph/test_ocdbt.py index efa6ed048..e1694f004 100644 --- a/pychunkedgraph/tests/graph/test_ocdbt.py +++ b/pychunkedgraph/tests/graph/test_ocdbt.py @@ -4,6 +4,8 @@ import os import shutil import tempfile +import time +from datetime import datetime, timezone import numpy as np import pytest @@ -13,6 +15,13 @@ from pychunkedgraph.graph import ocdbt as ocdbt_mod from pychunkedgraph.graph.meta import ChunkedGraphMeta, GraphConfig, DataSource +SCALE_META_BASE = { + "encoding": "compressed_segmentation", + "compressed_segmentation_block_size": [8, 8, 8], + "chunk_size": [32, 32, 32], +} +MULTISCALE_META = {"type": "segmentation", "data_type": "uint64", "num_channels": 1} + def _make_mock_src(num_scales=2): """Build a mock TensorStore source handle with a copyable schema.""" @@ -62,7 +71,9 @@ def _setup_ts_mock(mock_ts, num_scales=2): class TestBuildCgOcdbtSpec: def test_spec_structure(self): """build_cg_ocdbt_spec returns the expected kvstack-layered spec.""" - spec = 
ocdbt_mod.build_cg_ocdbt_spec("gs://bucket/ws", "my_graph") + spec = ocdbt_mod.build_cg_ocdbt_spec( + "gs://bucket/ws", "my_graph", ocdbt_mod.OcdbtConfig() + ) assert spec["driver"] == "ocdbt" layers = spec["base"]["layers"] assert len(layers) == 3 @@ -80,41 +91,33 @@ def test_spec_structure(self): class TestForkBaseManifest: - def test_copies_manifest(self): + """Byte-level behavior of `fork_base_manifest` — manifest copy + wipe.""" + + def test_copies_manifest(self, local_ocdbt): """fork_base_manifest copies the base manifest via tensorstore kvstore.""" - tmpdir = tempfile.mkdtemp() - ws = f"file://{tmpdir}" - try: - # Create a real base OCDBT with a manifest. - base_kvs = ts.KvStore.open(f"{ws}/ocdbt/base/").result() - base_kvs.write("manifest.ocdbt", b"fake_manifest_bytes").result() + ws = local_ocdbt["ws"] + base_kvs = ts.KvStore.open(f"{ws}/ocdbt/base/").result() + base_kvs.write("manifest.ocdbt", b"fake_manifest_bytes").result() - ocdbt_mod.fork_base_manifest(ws, "my_graph") + ocdbt_mod.fork_base_manifest(ws, "my_graph") - fork_kvs = ts.KvStore.open(f"{ws}/ocdbt/my_graph/").result() - result = fork_kvs.read("manifest.ocdbt").result() - assert result.value == b"fake_manifest_bytes" - finally: - shutil.rmtree(tmpdir) + fork_kvs = ts.KvStore.open(f"{ws}/ocdbt/my_graph/").result() + assert fork_kvs.read("manifest.ocdbt").result().value == b"fake_manifest_bytes" - def test_wipe_existing_cleans_fork_dir(self): + def test_wipe_existing_cleans_fork_dir(self, local_ocdbt): """wipe_existing=True removes the fork directory before copying.""" - tmpdir = tempfile.mkdtemp() - ws = f"file://{tmpdir}" - try: - base_kvs = ts.KvStore.open(f"{ws}/ocdbt/base/").result() - base_kvs.write("manifest.ocdbt", b"manifest_v1").result() + ws = local_ocdbt["ws"] + base_kvs = ts.KvStore.open(f"{ws}/ocdbt/base/").result() + base_kvs.write("manifest.ocdbt", b"manifest_v1").result() - fork_kvs = ts.KvStore.open(f"{ws}/ocdbt/my_graph/").result() - fork_kvs.write("stale_file", 
b"stale").result() + fork_kvs = ts.KvStore.open(f"{ws}/ocdbt/my_graph/").result() + fork_kvs.write("stale_file", b"stale").result() - ocdbt_mod.fork_base_manifest(ws, "my_graph", wipe_existing=True) + ocdbt_mod.fork_base_manifest(ws, "my_graph", wipe_existing=True) - fork_kvs2 = ts.KvStore.open(f"{ws}/ocdbt/my_graph/").result() - assert fork_kvs2.read("manifest.ocdbt").result().value == b"manifest_v1" - assert len(fork_kvs2.read("stale_file").result().value) == 0 - finally: - shutil.rmtree(tmpdir) + fork_kvs2 = ts.KvStore.open(f"{ws}/ocdbt/my_graph/").result() + assert fork_kvs2.read("manifest.ocdbt").result().value == b"manifest_v1" + assert len(fork_kvs2.read("stale_file").result().value) == 0 class TestModeDownsample: @@ -216,45 +219,80 @@ def test_boundary_clipping(self): @pytest.fixture def local_ocdbt(): - """Create a local precomputed multi-scale OCDBT store. - - Builds 3 scales (factors 2,2,1 between each) with known segmentation IDs - so downsampling behaviour and propagation can be asserted against exact - values. Returns paths + handles for tests to work against directly. + """Shared OCDBT test environment. + + Creates a local 3-scale precomputed base OCDBT (factors 2,2,1 per + level) and exposes helpers for fork-based tests. Every OCDBT test + that needs real storage uses this fixture — no duplicated tmpdir + scaffolding. + + Yields: + tmpdir: on-disk workspace (cleaned up on teardown). + ws: `file://{tmpdir}` URL — what `build_cg_ocdbt_spec` expects. + base: base OCDBT kvstore URL. + scales: 3 precomputed handles on the base (multi-scale tests). + resolutions: per-scale [x,y,z] resolution arrays. + make_fork(graph_id, *, scale_index=0, pinned_at=None): opens a + precomputed handle through a fork of the base. Creates the + fork on first call per `graph_id` and reuses it thereafter; + repeated calls with the same id never re-copy the manifest + (which would clobber fork writes). 
""" tmpdir = tempfile.mkdtemp() - base = f"file://{tmpdir}/ocdbt/base" - - mm = {"type": "segmentation", "data_type": "uint64", "num_channels": 1} + ws = f"file://{tmpdir}" + base = f"{ws}/ocdbt/base" - def mk(scale_idx, size, resolution, extra_mm=None): + def _mk_scale(size, resolution, *, include_mm): + # Match OcdbtConfig defaults so forks (which always use them) don't + # trip the "Configuration mismatch on max_inline_value_bytes" check. spec = { "driver": "neuroglancer_precomputed", - "kvstore": {"driver": "ocdbt", "base": base}, + "kvstore": { + "driver": "ocdbt", + "base": base, + "config": ocdbt_mod.OcdbtConfig().ts_config(), + }, "scale_metadata": { "size": size, "resolution": resolution, - "encoding": "compressed_segmentation", - "compressed_segmentation_block_size": [8, 8, 8], - "chunk_size": [32, 32, 32], + **SCALE_META_BASE, }, } - if extra_mm: - spec["multiscale_metadata"] = extra_mm + if include_mm: + spec["multiscale_metadata"] = MULTISCALE_META return ts.open(spec, create=True).result() scales = [ - mk(0, [64, 64, 32], [4, 4, 40], extra_mm=mm), - mk(1, [32, 32, 32], [8, 8, 40]), - mk(2, [16, 16, 32], [16, 16, 40]), + _mk_scale([64, 64, 32], [4, 4, 40], include_mm=True), + _mk_scale([32, 32, 32], [8, 8, 40], include_mm=False), + _mk_scale([16, 16, 32], [16, 16, 40], include_mm=False), ] resolutions = [[4, 4, 40], [8, 8, 40], [16, 16, 40]] + _created_forks = set() + + def make_fork(graph_id, *, scale_index=0, pinned_at=None): + if graph_id not in _created_forks: + ocdbt_mod.fork_base_manifest(ws, graph_id) + _created_forks.add(graph_id) + spec = ocdbt_mod.build_cg_ocdbt_spec( + ws, graph_id, ocdbt_mod.OcdbtConfig(), pinned_at=pinned_at + ) + return ts.open( + { + "driver": "neuroglancer_precomputed", + "kvstore": spec, + "scale_index": scale_index, + } + ).result() + yield { "tmpdir": tmpdir, + "ws": ws, "base": base, "scales": scales, "resolutions": resolutions, + "make_fork": make_fork, } shutil.rmtree(tmpdir) @@ -398,220 +436,235 @@ def 
test_repeated_update_reflects_latest_base(self, local_ocdbt): assert (scales[2][0:4, 0:4, 0:16, :].read().result() == 2).all() -class TestWriteSeg: - def test_writes_base_and_propagates(self, local_ocdbt): - """`write_seg` writes to base scale AND propagates to all coarser scales.""" +class TestWriteSegChunks: + """`write_seg_chunks` now takes a flat list of (slices, data) pairs. + + `sv_split.edits.split_supervoxels` is responsible for producing this list + across all reps so the outer rep loop is a pure data gather — + tensorstore writes fire in one parallel batch. + """ + + def test_writes_only_supplied_chunks(self, local_ocdbt): + """Chunks absent from `seg_writes` stay untouched (OCDBT delta + stays proportional to the actual SV change).""" scales = local_ocdbt["scales"] - res = local_ocdbt["resolutions"] meta = MagicMock() meta.ws_ocdbt = scales[0] - meta.ws_ocdbt_scales = scales - meta.ws_ocdbt_resolutions = res - data = np.full((16, 16, 16), 55, dtype=np.uint64) - ocdbt_mod.write_seg(meta, [0, 0, 0], [16, 16, 16], data) + # One chunk at [0..32] with label 55. The adjacent chunk at + # [32..64] is NOT in the write list, so it should stay zero. + chunk_data = np.full((32, 32, 32), 55, dtype=np.uint64) + seg_writes = [ + ( + (slice(0, 32), slice(0, 32), slice(0, 32)), + chunk_data, + ) + ] + ocdbt_mod.write_seg_chunks(meta, seg_writes) - # Base scale: written region has label 55. - assert (scales[0][0:16, 0:16, 0:16, :].read().result() == 55).all() - # Coarser scales: propagated. - assert (scales[1][0:8, 0:8, 0:16, :].read().result() == 55).all() - assert (scales[2][0:4, 0:4, 0:16, :].read().result() == 55).all() + assert (scales[0][0:32, 0:32, 0:32, :].read().result() == 55).all() + assert (scales[0][32:64, 0:32, 0:32, :].read().result() == 0).all() + # Coarser scales untouched — downsample worker's job. 
+ assert (scales[1][0:16, 0:16, 0:32, :].read().result() == 0).all() + assert (scales[2][0:8, 0:8, 0:32, :].read().result() == 0).all() - def test_single_scale_skips_propagation(self, local_ocdbt): - """With only one scale in the list, propagation is a no-op (no IndexError).""" + def test_multiple_chunks_in_one_batch(self, local_ocdbt): + """Multiple chunks (e.g. from different reps) fire in one call.""" + scales = local_ocdbt["scales"] meta = MagicMock() - meta.ws_ocdbt = local_ocdbt["scales"][0] - meta.ws_ocdbt_scales = [local_ocdbt["scales"][0]] - meta.ws_ocdbt_resolutions = [local_ocdbt["resolutions"][0]] + meta.ws_ocdbt = scales[0] - data = np.full((8, 8, 8), 99, dtype=np.uint64) - ocdbt_mod.write_seg(meta, [0, 0, 0], [8, 8, 8], data) - assert (meta.ws_ocdbt[0:8, 0:8, 0:8, :].read().result() == 99).all() + seg_writes = [ + ( + (slice(0, 32), slice(0, 32), slice(0, 32)), + np.full((32, 32, 32), 11, dtype=np.uint64), + ), + ( + (slice(32, 64), slice(0, 32), slice(0, 32)), + np.full((32, 32, 32), 22, dtype=np.uint64), + ), + ] + ocdbt_mod.write_seg_chunks(meta, seg_writes) + assert (scales[0][0:32, 0:32, 0:32, :].read().result() == 11).all() + assert (scales[0][32:64, 0:32, 0:32, :].read().result() == 22).all() -class TestMetaToForkEndToEnd: - """Full path: ChunkedGraphMeta.ws_ocdbt_scales → real kvstack fork → read/write.""" + def test_offset_region(self, local_ocdbt): + """Writes at a non-origin offset land in the right chunk.""" + scales = local_ocdbt["scales"] + meta = MagicMock() + meta.ws_ocdbt = scales[0] - def test_meta_opens_fork_and_merges_base(self): - """meta.ws_ocdbt_scales opens a real kvstack-backed OCDBT and reads - merge base + fork correctly. + seg_writes = [ + ( + (slice(32, 64), slice(0, 32), slice(0, 32)), + np.full((32, 32, 32), 99, dtype=np.uint64), + ) + ] + ocdbt_mod.write_seg_chunks(meta, seg_writes) - Only `_read_source_scales` is mocked (it reads `/info` which is a - GCS-only key). 
The full meta → build_cg_ocdbt_spec → kvstack → - OCDBT → read/write path is exercised for real. - """ - tmpdir = tempfile.mkdtemp() - ws = f"file://{tmpdir}" - try: - MM = {"type": "segmentation", "data_type": "uint64", "num_channels": 1} - SCALE = { - "size": [64, 64, 32], - "resolution": [4, 4, 40], - "encoding": "compressed_segmentation", - "compressed_segmentation_block_size": [8, 8, 8], - "chunk_size": [32, 32, 32], - } - FAKE_SCALES = [ - { - "resolution": [4, 4, 40], - "size": [64, 64, 32], - "chunk_sizes": [[32, 32, 32]], - "encoding": "compressed_segmentation", - "compressed_segmentation_block_size": [8, 8, 8], - } - ] - - # Source precomputed — needed by get_seg_source_and_destination_ocdbt - # to open the source handle and copy its schema. - ts.open( - { - "driver": "neuroglancer_precomputed", - "kvstore": f"{ws}/", - "multiscale_metadata": MM, - "scale_metadata": SCALE, - }, - create=True, - ).result() + assert (scales[0][32:64, 0:32, 0:32, :].read().result() == 99).all() + assert (scales[0][0:32, 0:32, 0:32, :].read().result() == 0).all() - # Create base OCDBT with known data. - base_kvstore = { - "driver": "ocdbt", - "base": f"{ws}/ocdbt/base/", - "config": dict(ocdbt_mod.OCDBT_CONFIG), - } - base_store = ts.open( - { - "driver": "neuroglancer_precomputed", - "kvstore": base_kvstore, - "multiscale_metadata": MM, - "scale_metadata": SCALE, + +class TestWsOcdbtScalesProperty: + """`ChunkedGraphMeta.ws_ocdbt_scales` opens a fork over the shared base. + + Full path exercised: property → build_cg_ocdbt_spec → kvstack → OCDBT + read/write. Only `_read_source_scales` is mocked (it reads `/info` + which lives on the source watershed, not the OCDBT fork). + """ + + def test_opens_fork_and_merges_base(self, local_ocdbt): + ws = local_ocdbt["ws"] + + # Source precomputed at ws root — needed by + # get_seg_source_and_destination_ocdbt to copy the schema. 
+ ts.open( + { + "driver": "neuroglancer_precomputed", + "kvstore": f"{ws}/", + "multiscale_metadata": MULTISCALE_META, + "scale_metadata": { + "size": [64, 64, 32], + "resolution": [4, 4, 40], + **SCALE_META_BASE, }, - create=True, - ).result() - base_store[...] = np.full((64, 64, 32, 1), 50, dtype=np.uint64) - - # Fork for graph "test_cg". - ocdbt_mod.fork_base_manifest(f"{ws}/", "test_cg") - - gc = GraphConfig(ID="test_cg", CHUNK_SIZE=[32, 32, 32]) - ds = DataSource(WATERSHED=f"{ws}/", DATA_VERSION=4) - meta = ChunkedGraphMeta(gc, ds, custom_data={"seg": {"ocdbt": True}}) - - # Mock only _read_source_scales ('/info' is GCS-only). - with patch.object( - ocdbt_mod, "_read_source_scales", return_value=FAKE_SCALES - ): - scales = meta.ws_ocdbt_scales - assert len(scales) == 1 - - # Read: should see base data. - r = scales[0][0:16, 0:16, 0:16, :].read().result() - assert (r == 50).all(), f"fork should see base, got {np.unique(r)}" - - # Write via the fork handle. - scales[0][0:16, 0:16, 0:16, :] = np.full( - (16, 16, 16, 1), 7, dtype=np.uint64 - ) + }, + create=True, + ).result() - # Read back: edited = 7, untouched = 50. - assert (scales[0][0:16, 0:16, 0:16, :].read().result() == 7).all() - assert (scales[0][32:48, 0:16, 0:16, :].read().result() == 50).all() - - # Base unchanged. - base_ro = ts.open( - { - "driver": "neuroglancer_precomputed", - "kvstore": base_kvstore, - } - ).result() - assert (base_ro[0:16, 0:16, 0:16, :].read().result() == 50).all() - finally: - shutil.rmtree(tmpdir) + # Seed base scale 0 with a known value via the fixture's handle. + local_ocdbt["scales"][0][...] 
= np.full((64, 64, 32, 1), 50, dtype=np.uint64) + gc = GraphConfig(ID="ws_scales_cg", CHUNK_SIZE=[32, 32, 32]) + ds = DataSource(WATERSHED=f"{ws}/", DATA_VERSION=4) + meta = ChunkedGraphMeta(gc, ds, custom_data={"seg": {"ocdbt": True}}) -class TestForkIsolation: - """End-to-end: two forks on the same base, writes isolated, base immutable.""" + # Trigger fork creation through the same helper the property will use. + local_ocdbt["make_fork"]("ws_scales_cg") - def test_two_forks_isolated(self): - tmpdir = tempfile.mkdtemp() - ws = f"file://{tmpdir}" - try: - # Build a base OCDBT with known data. - MM = {"type": "segmentation", "data_type": "uint64", "num_channels": 1} - SCALE = { - "size": [64, 64, 32], + fake_scales = [ + { "resolution": [4, 4, 40], + "size": [64, 64, 32], + "chunk_sizes": [[32, 32, 32]], "encoding": "compressed_segmentation", "compressed_segmentation_block_size": [8, 8, 8], - "chunk_size": [32, 32, 32], } - base_kvstore = { - "driver": "ocdbt", - "base": f"{ws}/ocdbt/base/", - "config": dict(ocdbt_mod.OCDBT_CONFIG), - } - base_store = ts.open( - { - "driver": "neuroglancer_precomputed", - "kvstore": base_kvstore, - "multiscale_metadata": MM, - "scale_metadata": SCALE, - }, - create=True, - ).result() - base_store[...] = np.full((64, 64, 32, 1), 50, dtype=np.uint64) - - base_path = f"{tmpdir}/ocdbt/base" - base_files_before = set( - os.path.relpath(os.path.join(r, f), base_path) - for r, _, fs in os.walk(base_path) - for f in fs + ] + with patch.object( + ocdbt_mod.main, "_read_source_scales", return_value=fake_scales + ): + scales = meta.ws_ocdbt_scales + assert len(scales) == 1 + + # Fork sees base data. + assert (scales[0][0:16, 0:16, 0:16, :].read().result() == 50).all() + + # Write to the fork and confirm isolation. 
+ scales[0][0:16, 0:16, 0:16, :] = np.full( + (16, 16, 16, 1), 7, dtype=np.uint64 ) + assert (scales[0][0:16, 0:16, 0:16, :].read().result() == 7).all() + assert (scales[0][32:48, 0:16, 0:16, :].read().result() == 50).all() - # Fork A and B via fork_base_manifest. - ocdbt_mod.fork_base_manifest(ws, "fork_a") - ocdbt_mod.fork_base_manifest(ws, "fork_b") - - def open_fork(gid): - spec = ocdbt_mod.build_cg_ocdbt_spec(ws, gid) - return ts.open( - {"driver": "neuroglancer_precomputed", "kvstore": spec}, - ).result() - - fork_a = open_fork("fork_a") - fork_b = open_fork("fork_b") - - # Both see base data. - assert (fork_a[0:16, 0:16, 0:16, :].read().result() == 50).all() - assert (fork_b[0:16, 0:16, 0:16, :].read().result() == 50).all() - - # Write different values to each fork. - fork_a[0:16, 0:16, 0:16, :] = np.full((16, 16, 16, 1), 1, dtype=np.uint64) - fork_b[32:48, 0:16, 0:16, :] = np.full((16, 16, 16, 1), 2, dtype=np.uint64) - - # Each fork sees ONLY its own edit + base for the rest. - assert (fork_a[0:16, 0:16, 0:16, :].read().result() == 1).all() - assert (fork_a[32:48, 0:16, 0:16, :].read().result() == 50).all() - assert (fork_b[32:48, 0:16, 0:16, :].read().result() == 2).all() - assert (fork_b[0:16, 0:16, 0:16, :].read().result() == 50).all() - - # Base is unchanged. - base_files_after = set( - os.path.relpath(os.path.join(r, f), base_path) - for r, _, fs in os.walk(base_path) - for f in fs - ) - assert ( - base_files_before == base_files_after - ), f"base was mutated: new={base_files_after - base_files_before}" - - # Fork writes went to their own directories. - fork_a_files = os.listdir(f"{tmpdir}/ocdbt/fork_a") - fork_b_files = os.listdir(f"{tmpdir}/ocdbt/fork_b") - assert any("fork_a_d" in f for f in fork_a_files) - assert any("fork_b_d" in f for f in fork_b_files) - finally: - shutil.rmtree(tmpdir) + # Base still reports the original value (fork write didn't leak). 
+ assert ( + local_ocdbt["scales"][0][0:16, 0:16, 0:16, :].read().result() == 50 + ).all() + + +class TestForkIsolation: + """Two forks on the same base: writes isolated, base immutable.""" + + def test_two_forks_isolated(self, local_ocdbt): + tmpdir = local_ocdbt["tmpdir"] + # Seed base scale 0 with a known value. + local_ocdbt["scales"][0][...] = np.full((64, 64, 32, 1), 50, dtype=np.uint64) + + base_path = f"{tmpdir}/ocdbt/base" + base_files_before = { + os.path.relpath(os.path.join(r, f), base_path) + for r, _, fs in os.walk(base_path) + for f in fs + } + + fork_a = local_ocdbt["make_fork"]("fork_a") + fork_b = local_ocdbt["make_fork"]("fork_b") + + # Both see base data. + assert (fork_a[0:16, 0:16, 0:16, :].read().result() == 50).all() + assert (fork_b[0:16, 0:16, 0:16, :].read().result() == 50).all() + + # Write different values to each fork. + fork_a[0:16, 0:16, 0:16, :] = np.full((16, 16, 16, 1), 1, dtype=np.uint64) + fork_b[32:48, 0:16, 0:16, :] = np.full((16, 16, 16, 1), 2, dtype=np.uint64) + + # Each fork sees ONLY its own edit + base for the rest. + assert (fork_a[0:16, 0:16, 0:16, :].read().result() == 1).all() + assert (fork_a[32:48, 0:16, 0:16, :].read().result() == 50).all() + assert (fork_b[32:48, 0:16, 0:16, :].read().result() == 2).all() + assert (fork_b[0:16, 0:16, 0:16, :].read().result() == 50).all() + + # Base files unchanged (no new bytes written under ocdbt/base/). + base_files_after = { + os.path.relpath(os.path.join(r, f), base_path) + for r, _, fs in os.walk(base_path) + for f in fs + } + assert ( + base_files_before == base_files_after + ), f"base was mutated: new={base_files_after - base_files_before}" + + # Fork writes went to their own directories. 
+ assert any("fork_a_d" in f for f in os.listdir(f"{tmpdir}/ocdbt/fork_a")) + assert any("fork_b_d" in f for f in os.listdir(f"{tmpdir}/ocdbt/fork_b")) + + +class TestPinnedAt: + """Versioned reads: pinning a fork to a prior generation/timestamp + returns pre-write state; default (unpinned) returns latest. + + Documents both pin forms OCDBT accepts — integer generation (exact) + and ISO-8601 UTC timestamp with `Z` suffix (commit_time upper bound). + """ + + def test_pin_by_generation_and_by_timestamp(self, local_ocdbt): + # Seed base so fork reads see data even before the first fork write. + local_ocdbt["scales"][0][...] = np.full((64, 64, 32, 1), 50, dtype=np.uint64) + + fork = local_ocdbt["make_fork"]("pin_cg") + + # Write v1 then v2 at the same voxels. Capture pin markers between + # the two writes so pre-v2 state is what each pin should return. + fork[0:16, 0:16, 0:16, :] = np.full((16, 16, 16, 1), 1, dtype=np.uint64) + + fork_manifest_kvs = ts.KvStore.open( + f"{local_ocdbt['ws']}/ocdbt/pin_cg/" + ).result() + pin_gen = ts.ocdbt.dump(fork_manifest_kvs).result()["versions"][-1][ + "generation_number" + ] + + time.sleep(0.01) + pin_ts = datetime.now(tz=timezone.utc).isoformat().replace("+00:00", "Z") + time.sleep(0.01) + + fork[0:16, 0:16, 0:16, :] = np.full((16, 16, 16, 1), 2, dtype=np.uint64) + + fork_latest = local_ocdbt["make_fork"]("pin_cg") + assert (fork_latest[0:16, 0:16, 0:16, :].read().result() == 2).all() + + fork_gen = local_ocdbt["make_fork"]("pin_cg", pinned_at=pin_gen) + assert (fork_gen[0:16, 0:16, 0:16, :].read().result() == 1).all() + + fork_ts = local_ocdbt["make_fork"]("pin_cg", pinned_at=pin_ts) + assert (fork_ts[0:16, 0:16, 0:16, :].read().result() == 1).all() + + # Untouched region still shows base data under every pin. 
+ for handle in (fork_latest, fork_gen, fork_ts): + assert (handle[32:48, 0:16, 0:16, :].read().result() == 50).all() class TestCopyWsChunkMultiscale: diff --git a/pychunkedgraph/tests/graph/test_stuck_ops.py b/pychunkedgraph/tests/graph/test_stuck_ops.py new file mode 100644 index 000000000..d3e8c3b8d --- /dev/null +++ b/pychunkedgraph/tests/graph/test_stuck_ops.py @@ -0,0 +1,385 @@ +"""Tests for pychunkedgraph.repair.stuck_ops — cleanup + replay path +for SV-split ops that crashed mid-write. + +The heavy test (`test_cleanup_reverts_partial_writes_to_pre_op`) +exercises the full cleanup flow against a real local OCDBT store — it +writes a known pre-op state, snapshots the manifest, writes simulated +"partial" data, constructs an op-log row with `L2ChunkLockScope` and +`OperationTimeStamp`, and asserts that cleanup reverts the scoped +chunks to pre-op values while leaving neighbor chunks alone. +""" + +from datetime import datetime, timedelta, timezone +from unittest.mock import MagicMock, patch + +import numpy as np +import pytest +import tensorstore as ts + +from pychunkedgraph.graph import attributes, ocdbt as ocdbt_mod +from pychunkedgraph.graph.chunks.utils import get_chunk_coordinates +from pychunkedgraph.graph.locks import _l2_chunk_lock_row_key +from pychunkedgraph.graph.meta import ChunkedGraphMeta, DataSource, GraphConfig +from pychunkedgraph.repair import stuck_ops + +# Pick up the shared `local_ocdbt` fixture from test_ocdbt. +from .test_ocdbt import local_ocdbt # noqa: F401 + + +class TestListStuck: + """`list_stuck` surfaces ops with non-empty `L2ChunkLockScope` past + `min_age` whose Status isn't SUCCESS — i.e. 
still holding + `Concurrency.IndefiniteLock` cells somewhere.""" + + def _entry(self, status, age_seconds, user="u", scope=None): + now = datetime.now(timezone.utc) + entry = { + attributes.OperationLogs.Status: status, + attributes.OperationLogs.OperationTimeStamp: now + - timedelta(seconds=age_seconds), + attributes.OperationLogs.UserID: user, + } + if scope is not None: + entry[attributes.OperationLogs.L2ChunkLockScope] = np.asarray( + scope, dtype=np.uint64 + ) + return entry + + def _cg(self, entries): + cg = MagicMock() + cg.client.read_log_entries.return_value = entries + return cg + + def test_filters_out_success_with_scope(self): + """Defensive: a SUCCESS op with stale scope (if + `_clear_scope_on_op_log` ever failed silently) must not be + listed as stuck.""" + success = attributes.OperationLogs.StatusCodes.SUCCESS.value + created = attributes.OperationLogs.StatusCodes.CREATED.value + cg = self._cg( + { + np.uint64(1): self._entry(success, 900, scope=[10, 20]), + np.uint64(2): self._entry(created, 900, scope=[10, 20]), + } + ) + stuck = stuck_ops.list_stuck(cg, min_age=timedelta(minutes=1)) + assert [r["op_id"] for r in stuck] == [2] + + def test_filters_out_empty_scope(self): + """Ops that never touched the persist block (no scope) are not + stuck via L2 locks — they're outside `stuck_ops`' concern.""" + created = attributes.OperationLogs.StatusCodes.CREATED.value + exception = attributes.OperationLogs.StatusCodes.EXCEPTION.value + cg = self._cg( + { + np.uint64(1): self._entry(created, 900), # no scope + np.uint64(2): self._entry(exception, 900), # no scope + np.uint64(3): self._entry(created, 900, scope=[42]), + } + ) + stuck = stuck_ops.list_stuck(cg, min_age=timedelta(minutes=1)) + assert [r["op_id"] for r in stuck] == [3] + + def test_surfaces_exception_path_with_scope(self): + """After Fix 1, a Python exception during the persist block + leaves cells held + scope set but Status=EXCEPTION. 
The op must + be listed so the operator can recover it.""" + exception = attributes.OperationLogs.StatusCodes.EXCEPTION.value + cg = self._cg( + { + np.uint64(42): self._entry( + exception, 900, user="alice", scope=[100, 200] + ), + } + ) + stuck = stuck_ops.list_stuck(cg, min_age=timedelta(minutes=1)) + assert len(stuck) == 1 + row = stuck[0] + assert row["op_id"] == 42 + assert row["status"] == exception + assert list(row["l2_chunk_scope"]) == [100, 200] + + def test_filters_out_young_ops(self): + created = attributes.OperationLogs.StatusCodes.CREATED.value + cg = self._cg( + { + np.uint64(1): self._entry(created, 10, scope=[1]), # too young + np.uint64(2): self._entry(created, 3600, scope=[2]), # an hour old + } + ) + stuck = stuck_ops.list_stuck(cg, min_age=timedelta(minutes=10)) + assert [r["op_id"] for r in stuck] == [2] + + def test_returns_scope_and_user(self): + created = attributes.OperationLogs.StatusCodes.CREATED.value + cg = self._cg( + { + np.uint64(7): self._entry(created, 1800, user="op", scope=[100, 200]), + } + ) + stuck = stuck_ops.list_stuck(cg, min_age=timedelta(minutes=10)) + assert len(stuck) == 1 + row = stuck[0] + assert row["op_id"] == 7 + assert row["user_id"] == "op" + assert list(row["l2_chunk_scope"]) == [100, 200] + assert row["age"] > timedelta(minutes=10) + + +class TestVerifyIndefiniteCells: + """`_verify_indefinite_cells` reads each chunk's indefinite-lock cell + and reports any that don't match the expected op_id.""" + + class _Cell: + def __init__(self, value): + self.value = value + + def _cg(self, cells_by_row_key): + cg = MagicMock() + + def read(row_key, columns=None): + return cells_by_row_key.get(row_key, []) + + cg.client._read_byte_row.side_effect = read + return cg + + def test_all_held_by_same_op(self): + op_id = 42 + scope = [np.uint64(1), np.uint64(2)] + cells = { + stuck_ops._l2_chunk_lock_row_key(1): [self._Cell(np.uint64(op_id))], + stuck_ops._l2_chunk_lock_row_key(2): [self._Cell(np.uint64(op_id))], + } + cg = 
class TestReplayVerifies:
    """`replay` must bail out before `cleanup_partial_writes` or
    `repair_operation` run when the op-log scope is empty or is not backed
    by live indefinite-lock cells."""

    @staticmethod
    def _cg_with_log_row(op_id, row):
        """Mock graph whose `read_log_entries` yields `row` under `op_id`."""
        cg = MagicMock()
        cg.client.read_log_entries.return_value = {np.uint64(op_id): row}
        return cg

    @staticmethod
    def _spy(monkeypatch, name, calls):
        """Swap `stuck_ops.<name>` for a stub that appends `name` to `calls`."""
        monkeypatch.setattr(
            stuck_ops,
            name,
            lambda *a, **k: calls.append(name),
        )

    def test_replay_refuses_when_cells_missing(self, monkeypatch):
        op_id = 77
        logs = attributes.OperationLogs
        cg = self._cg_with_log_row(
            op_id,
            {
                logs.L2ChunkLockScope: np.asarray([1, 2], dtype=np.uint64),
                logs.OperationTimeStamp: datetime.now(timezone.utc),
            },
        )
        # Neither scoped chunk has an indefinite-lock cell.
        cg.client._read_byte_row.return_value = []

        # Record any call to the destructive steps; none is expected.
        calls = []
        self._spy(monkeypatch, "cleanup_partial_writes", calls)
        self._spy(monkeypatch, "repair_operation", calls)

        with pytest.raises(RuntimeError, match="Refusing to replay"):
            stuck_ops.replay(cg, op_id)
        assert "cleanup_partial_writes" not in calls
        assert "repair_operation" not in calls

    def test_replay_refuses_when_empty_scope(self, monkeypatch):
        op_id = 77
        # The log row has a timestamp but no L2ChunkLockScope column.
        cg = self._cg_with_log_row(
            op_id,
            {
                attributes.OperationLogs.OperationTimeStamp: datetime.now(
                    timezone.utc
                ),
            },
        )
        calls = []
        self._spy(monkeypatch, "cleanup_partial_writes", calls)

        with pytest.raises(RuntimeError, match="not a stuck SV-split op"):
            stuck_ops.replay(cg, op_id)
        assert "cleanup_partial_writes" not in calls
def _capture_fork_pin(self, local_ocdbt_fixture, graph_id):
    """Return a pin timestamp just past the fork's latest manifest commit.

    The result is a timezone-aware UTC `datetime` (not a string), placed
    1 ms after the newest entry in the fork manifest's `versions` list.
    The cleanup test stores it as the op's `OperationTimeStamp`, i.e. the
    pre-op point in time for cleanup to pin on.

    Args:
        local_ocdbt_fixture: fixture dict; only its `"ws"` root is read.
        graph_id: fork name under `<ws>/ocdbt/`.
    """
    fork_root = f"{local_ocdbt_fixture['ws']}/ocdbt/{graph_id}/"
    fork_manifest_kvs = ts.KvStore.open(fork_root).result()
    manifest = ts.ocdbt.dump(fork_manifest_kvs).result()
    # commit_time is recorded as int ns since epoch; use a timestamp
    # just past the last commit as the pin so the upper-bound filter
    # picks up everything written so far.
    last_ns = manifest["versions"][-1]["commit_time"]
    return datetime.fromtimestamp(last_ns / 1e9 + 0.001, tz=timezone.utc)
+ # Chunk grid is at base resolution with 32^3 voxels per chunk. + fork_scale0[0:32, 0:32, 0:32, :] = np.full( + (32, 32, 32, 1), 111, dtype=np.uint64 + ) + fork_scale0[32:64, 0:32, 0:32, :] = np.full( + (32, 32, 32, 1), 222, dtype=np.uint64 + ) + + # Snapshot pin timestamp just after the pre-op writes. + pre_op_pin_dt = self._capture_fork_pin(fixture, "stuck_cg") + + # Partial "crash" writes: overwrite chunk 0 with garbage, touch + # chunk 1 too to prove scope-boundedness (scope will only list + # chunk 0, so chunk 1's garbage must persist after cleanup). + fork_scale0[0:32, 0:32, 0:32, :] = np.full( + (32, 32, 32, 1), 999, dtype=np.uint64 + ) + fork_scale0[32:64, 0:32, 0:32, :] = np.full( + (32, 32, 32, 1), 888, dtype=np.uint64 + ) + + # Chunk IDs for chunk-coord (0,0,0) and (1,0,0) at layer 2. + chunk_id_0 = _chunk_id_from_coord(meta, layer=2, coord=(0, 0, 0)) + chunk_id_1 = _chunk_id_from_coord(meta, layer=2, coord=(1, 0, 0)) + + # Sanity: scope chunk decodes back to the right coord. + assert tuple(get_chunk_coordinates(meta, chunk_id_0)) == (0, 0, 0) + + # Synthetic op-log row with scope=[chunk_id_0] and OperationTimeStamp=pre_op_pin. + op_id = 777 + op_log_row = { + attributes.OperationLogs.L2ChunkLockScope: np.asarray( + [chunk_id_0], dtype=np.uint64 + ), + attributes.OperationLogs.OperationTimeStamp: pre_op_pin_dt, + } + + cg = MagicMock() + cg.meta = meta + cg.client.read_log_entries.return_value = {np.uint64(op_id): op_log_row} + + # `_read_source_scales` reads `/info` from the watershed via + # tensorstore's kvstore interface — fine on GCS, not on file://. + # Bypass with a fake scale list matching the test's scale 0. + fake_scales = [ + { + "resolution": [4, 4, 40], + "size": [64, 64, 32], + "chunk_sizes": [[32, 32, 32]], + "encoding": "compressed_segmentation", + "compressed_segmentation_block_size": [8, 8, 8], + } + ] + # Patch the binding in `ocdbt.main` (where it's actually called). 
class RowKeyLockRegistry:
    """In-memory, thread-safe fake of kvdbclient's row-key lock calls.

    Provides `lock_by_row_key*`, `unlock_by_row_key`,
    `unlock_indefinitely_locked_by_row_key` and `renew_lock_by_row_key`
    so row-key-based lock primitives can be exercised in tests without a
    bigtable emulator.

    Holders are tracked in two dicts, one for the temporal column and one
    for the indefinite column. Only `lock_by_row_key_with_indefinite`
    consults both: it refuses when either dict already has the row.
    """

    def __init__(self):
        self._lock = threading.Lock()
        self._held = {}
        self._held_indefinite = {}

    def _acquire(self, target, blockers, row_key, operation_id):
        """Record `operation_id` as holder of `row_key` in `target`,
        unless any map in `blockers` already contains the row."""
        with self._lock:
            if any(row_key in held for held in blockers):
                return False
            target[row_key] = operation_id
            return True

    def _release(self, target, row_key, operation_id):
        """Drop `row_key` from `target` only if `operation_id` holds it."""
        with self._lock:
            if target.get(row_key) == operation_id:
                del target[row_key]
                return True
            return False

    def lock_by_row_key(self, row_key, operation_id):
        # Temporal lock: only another temporal holder blocks it.
        return self._acquire(self._held, (self._held,), row_key, operation_id)

    def lock_by_row_key_with_indefinite(self, row_key, operation_id):
        # Temporal lock that also yields to an indefinite holder.
        return self._acquire(
            self._held,
            (self._held, self._held_indefinite),
            row_key,
            operation_id,
        )

    def lock_by_row_key_indefinitely(self, row_key, operation_id):
        # Indefinite lock: only another indefinite holder blocks it.
        return self._acquire(
            self._held_indefinite,
            (self._held_indefinite,),
            row_key,
            operation_id,
        )

    def unlock_by_row_key(self, row_key, operation_id):
        return self._release(self._held, row_key, operation_id)

    def unlock_indefinitely_locked_by_row_key(self, row_key, operation_id):
        return self._release(self._held_indefinite, row_key, operation_id)

    def renew_lock_by_row_key(self, row_key, operation_id):
        # Renewal succeeds only for the current temporal holder; nothing
        # is mutated.
        with self._lock:
            return self._held.get(row_key) == operation_id
+ return cg diff --git a/pychunkedgraph/tests/ingest/test_ingest_utils.py b/pychunkedgraph/tests/ingest/test_ingest_utils.py index 4c5bdf0af..400ce3a0c 100644 --- a/pychunkedgraph/tests/ingest/test_ingest_utils.py +++ b/pychunkedgraph/tests/ingest/test_ingest_utils.py @@ -44,10 +44,13 @@ def test_from_config(self): }, "ingest_config": {}, } - meta, ingest_config, client_info = bootstrap("test_graph", config=config) + meta, ingest_config, client_info, ocdbt_config_dict = bootstrap( + "test_graph", config=config + ) assert meta.graph_config.ID == "test_graph" assert meta.graph_config.FANOUT == 2 assert ingest_config.USE_RAW_EDGES is False + assert isinstance(ocdbt_config_dict, dict) class TestPostprocessEdgeData: @@ -329,7 +332,6 @@ def my_func(): # ===================================================================== # Additional pure unit tests # ===================================================================== -from pychunkedgraph.ingest.utils import start_ocdbt_server class TestGetChunksNotDoneWithSplits: @@ -390,57 +392,6 @@ def test_get_chunks_not_done_splits_coord_str_format(self): assert call_args[0][1] == ["2_3_4_0"] -class TestStartOcdbtServer: - """Test start_ocdbt_server function.""" - - @patch("pychunkedgraph.ingest.utils.ts") - @patch.dict("os.environ", {"MY_POD_IP": "10.0.0.1"}) - def test_start_ocdbt_server(self, mock_ts): - """start_ocdbt_server should open a KvStore and set redis keys.""" - imanager = MagicMock() - imanager.cg.meta.data_source.EDGES = "gs://bucket/edges" - mock_redis = MagicMock() - imanager.redis = mock_redis - - server = MagicMock() - server.port = 12345 - - mock_kv_future = MagicMock() - mock_ts.KvStore.open.return_value = mock_kv_future - - start_ocdbt_server(imanager, server) - - # Verify tensorstore was called with the right spec - call_args = mock_ts.KvStore.open.call_args[0][0] - assert call_args["driver"] == "ocdbt" - assert "gs://bucket/edges/ocdbt" in call_args["base"] - assert call_args["coordinator"]["address"] == 
"localhost:12345" - mock_kv_future.result.assert_called_once() - - # Verify redis keys were set - mock_redis.set.assert_any_call("OCDBT_COORDINATOR_PORT", "12345") - mock_redis.set.assert_any_call("OCDBT_COORDINATOR_HOST", "10.0.0.1") - - @patch("pychunkedgraph.ingest.utils.ts") - @patch.dict("os.environ", {}, clear=True) - def test_start_ocdbt_server_default_host(self, mock_ts): - """When MY_POD_IP is not set, should default to localhost.""" - imanager = MagicMock() - imanager.cg.meta.data_source.EDGES = "gs://bucket/edges" - mock_redis = MagicMock() - imanager.redis = mock_redis - - server = MagicMock() - server.port = 9999 - - mock_kv_future = MagicMock() - mock_ts.KvStore.open.return_value = mock_kv_future - - start_ocdbt_server(imanager, server) - - mock_redis.set.assert_any_call("OCDBT_COORDINATOR_HOST", "localhost") - - class TestPostprocessEdgeDataNoneValues: """Test postprocess_edge_data when edge_dict values are None.""" diff --git a/requirements.in b/requirements.in index 143d90399..42f6cd592 100644 --- a/requirements.in +++ b/requirements.in @@ -14,6 +14,7 @@ pyyaml cachetools werkzeug tensorstore +rich edt connected-components-3d scikit-image @@ -28,8 +29,9 @@ task-queue>=2.14.0 messagingclient>0.3.0 dracopy>=1.5.0 datastoreflex>=0.5.0 -kvdbclient>0.5.0 +kvdbclient>=0.7.0 zstandard>=0.23.0 +tinybrain>=1.7.0 # Conda only - use requirements.yml (or install manually): # graph-tool \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index af29a75bd..8c4a4d032 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,18 +4,18 @@ # # pip-compile --output-file=requirements.txt requirements.in # -attrs==25.4.0 +attrs==26.1.0 # via # jsonschema # referencing blinker==1.9.0 # via flask -boto3==1.42.53 +boto3==1.43.14 # via # cloud-files # cloud-volume # task-queue -botocore==1.42.53 +botocore==1.43.14 # via # boto3 # s3transfer @@ -23,21 +23,21 @@ brotli==1.2.0 # via # cloud-files # urllib3 -cachetools==7.0.1 +cachetools==7.1.4 # via # 
-r requirements.in # middle-auth-client -certifi==2026.1.4 +certifi==2026.5.20 # via requests cffi==2.0.0 # via cryptography -chardet==5.2.0 +chardet==7.4.3 # via # cloud-files # cloud-volume -charset-normalizer==3.4.4 +charset-normalizer==3.4.7 # via requests -click==8.3.1 +click==8.4.1 # via # -r requirements.in # cloud-files @@ -46,22 +46,22 @@ click==8.3.1 # microviewer # rq # task-queue -cloud-files==6.2.1 +cloud-files==6.3.0 # via # -r requirements.in # cloud-volume # datastoreflex -cloud-volume==12.10.0 +cloud-volume==12.13.1 # via -r requirements.in -compressed-segmentation==2.3.2 +compressed-segmentation==2.3.3 # via cloud-volume -connected-components-3d==3.26.1 +connected-components-3d==3.28.0 # via -r requirements.in crc32c==2.8 # via cloud-files -croniter==6.0.0 +croniter==6.2.2 # via rq -cryptography==46.0.5 +cryptography==48.0.0 # via google-auth datastoreflex==0.5.0 # via -r requirements.in @@ -79,7 +79,7 @@ edt==3.1.1 # via -r requirements.in fasteners==0.20 # via cloud-files -fastremap==1.17.7 +fastremap==1.19.0 # via # -r requirements.in # cloud-volume @@ -93,19 +93,19 @@ flask-cors==6.0.2 # via -r requirements.in furl==2.1.4 # via middle-auth-client -gevent==25.9.1 +gevent==26.5.0 # via # cloud-files # cloud-volume # task-queue -google-api-core[grpc]==2.30.0 +google-api-core[grpc]==2.30.3 # via # google-cloud-bigtable # google-cloud-core # google-cloud-datastore # google-cloud-pubsub # google-cloud-storage -google-auth==2.48.0 +google-auth==2.53.0 # via # cloud-files # cloud-volume @@ -116,9 +116,9 @@ google-auth==2.48.0 # google-cloud-pubsub # google-cloud-storage # task-queue -google-cloud-bigtable==2.35.0 +google-cloud-bigtable==2.38.0 # via kvdbclient -google-cloud-core==2.5.0 +google-cloud-core==2.6.0 # via # cloud-files # cloud-volume @@ -126,13 +126,13 @@ google-cloud-core==2.5.0 # google-cloud-datastore # google-cloud-storage # task-queue -google-cloud-datastore==2.23.0 +google-cloud-datastore==2.24.0 # via # -r requirements.in # 
datastoreflex -google-cloud-pubsub==2.35.0 +google-cloud-pubsub==2.38.0 # via messagingclient -google-cloud-storage==3.9.0 +google-cloud-storage==3.10.1 # via # cloud-files # cloud-volume @@ -142,37 +142,36 @@ google-crc32c==1.8.0 # google-cloud-bigtable # google-cloud-storage # google-resumable-media -google-resumable-media==2.8.0 +google-resumable-media==2.9.0 # via google-cloud-storage -googleapis-common-protos[grpc]==1.72.0 +googleapis-common-protos[grpc]==1.75.0 # via # google-api-core # grpc-google-iam-v1 # grpcio-status -greenlet==3.3.1 +greenlet==3.5.1 # via gevent -grpc-google-iam-v1==0.14.3 +grpc-google-iam-v1==0.14.4 # via # google-cloud-bigtable # google-cloud-pubsub -grpcio==1.78.0 +grpcio==1.80.0 # via # google-api-core + # google-cloud-bigtable # google-cloud-datastore # google-cloud-pubsub # googleapis-common-protos # grpc-google-iam-v1 # grpcio-status -grpcio-status==1.78.0 +grpcio-status==1.80.0 # via # google-api-core # google-cloud-pubsub -idna==3.11 +idna==3.16 # via requests -imageio==2.37.2 +imageio==2.37.3 # via scikit-image -importlib-metadata==8.7.1 - # via opentelemetry-api inflection==0.5.1 # via python-jsonschema-objects iniconfig==2.3.0 @@ -187,7 +186,7 @@ jmespath==1.1.0 # via # boto3 # botocore -json5==0.13.0 +json5==0.14.0 # via cloud-volume jsonschema==4.26.0 # via @@ -195,20 +194,24 @@ jsonschema==4.26.0 # python-jsonschema-objects jsonschema-specifications==2025.9.1 # via jsonschema -kvdbclient==0.6.0 +kvdbclient==0.7.0 # via -r requirements.in -lazy-loader==0.4 +lazy-loader==0.5 # via scikit-image markdown==3.10.2 # via python-jsonschema-objects +markdown-it-py==4.2.0 + # via rich markupsafe==3.0.3 # via # flask # jinja2 # werkzeug +mdurl==0.1.2 + # via markdown-it-py messagingclient==0.4.0 # via -r requirements.in -microviewer==1.20.0 +microviewer==1.21.0 # via cloud-volume middle-auth-client==3.19.2 # via -r requirements.in @@ -222,7 +225,7 @@ networkx==3.6.1 # cloud-volume # osteoid # scikit-image -numpy==2.4.2 +numpy==2.4.6 
# via # -r requirements.in # cloud-files @@ -244,30 +247,31 @@ numpy==2.4.2 # task-queue # tensorstore # tifffile + # tinybrain # zmesh -opentelemetry-api==1.39.1 +opentelemetry-api==1.42.1 # via # google-cloud-pubsub # opentelemetry-sdk # opentelemetry-semantic-conventions -opentelemetry-sdk==1.39.1 +opentelemetry-sdk==1.42.1 # via google-cloud-pubsub -opentelemetry-semantic-conventions==0.60b1 +opentelemetry-semantic-conventions==0.63b1 # via opentelemetry-sdk orderedmultidict==1.0.2 # via furl -orjson==3.11.7 +orjson==3.11.9 # via # cloud-files # task-queue osteoid==0.6.0 # via cloud-volume -packaging==26.0 +packaging==26.2 # via # lazy-loader # pytest # scikit-image -pandas==3.0.1 +pandas==3.0.3 # via -r requirements.in pathos==0.3.5 # via @@ -276,7 +280,7 @@ pathos==0.3.5 # task-queue pbr==7.0.3 # via task-queue -pillow==12.1.1 +pillow==12.2.0 # via # imageio # scikit-image @@ -288,13 +292,13 @@ pox==0.3.7 # via pathos ppft==1.7.8 # via pathos -proto-plus==1.27.1 +proto-plus==1.28.0 # via # google-api-core # google-cloud-bigtable # google-cloud-datastore # google-cloud-pubsub -protobuf==6.33.5 +protobuf==6.33.6 # via # -r requirements.in # cloud-files @@ -309,21 +313,23 @@ protobuf==6.33.5 # proto-plus psutil==7.2.2 # via cloud-volume -pyasn1==0.6.2 +pyasn1==0.6.3 # via # pyasn1-modules # rsa pyasn1-modules==0.4.2 # via google-auth -pybind11==3.0.2 +pybind11==3.0.4 # via osteoid pycparser==3.0 # via cffi -pygments==2.19.2 - # via pytest +pygments==2.20.0 + # via + # pytest + # rich pysimdjson==7.0.2 # via cloud-volume -pytest==9.0.2 +pytest==9.0.3 # via compressed-segmentation python-dateutil==2.9.0.post0 # via @@ -331,17 +337,15 @@ python-dateutil==2.9.0.post0 # cloud-volume # croniter # pandas -python-json-logger==4.0.0 +python-json-logger==4.1.0 # via -r requirements.in python-jsonschema-objects==0.5.7 # via cloud-volume -pytz==2025.2 - # via - # croniter - # kvdbclient +pytz==2026.2 + # via kvdbclient pyyaml==6.0.3 # via -r requirements.in -redis==7.2.0 
+redis==7.4.0 # via # -r requirements.in # rq @@ -349,7 +353,7 @@ referencing==0.37.0 # via # jsonschema # jsonschema-specifications -requests==2.32.5 +requests==2.34.2 # via # -r requirements.in # cloud-files @@ -359,17 +363,17 @@ requests==2.32.5 # kvdbclient # middle-auth-client # task-queue +rich==15.0.0 + # via -r requirements.in rpds-py==0.30.0 # via # jsonschema # referencing -rq==2.6.1 +rq==2.9.0 # via -r requirements.in rsa==4.9.1 - # via - # cloud-files - # google-auth -s3transfer==0.16.0 + # via cloud-files +s3transfer==0.17.0 # via boto3 scikit-image==0.26.0 # via -r requirements.in @@ -393,10 +397,12 @@ tenacity==9.1.4 # cloud-volume # kvdbclient # task-queue -tensorstore==0.1.81 +tensorstore==0.1.84 # via -r requirements.in -tifffile==2026.3.3 +tifffile==2026.5.15 # via scikit-image +tinybrain==1.7.0 + # via -r requirements.in tqdm==4.67.3 # via # cloud-files @@ -409,24 +415,22 @@ typing-extensions==4.15.0 # opentelemetry-sdk # opentelemetry-semantic-conventions # referencing -urllib3[brotli]==2.6.3 +urllib3[brotli]==2.7.0 # via # botocore # cloud-files # cloud-volume # requests -werkzeug==3.1.6 +werkzeug==3.1.8 # via # -r requirements.in # flask # flask-cors -zipp==3.23.0 - # via importlib-metadata -zmesh==1.10.0 +zmesh==1.13.1 # via -r requirements.in -zope-event==6.1 +zope-event==6.2 # via gevent -zope-interface==8.2 +zope-interface==8.4 # via gevent zstandard==0.25.0 # via diff --git a/workers/downsample_worker.py b/workers/downsample_worker.py new file mode 100644 index 000000000..ae436a9aa --- /dev/null +++ b/workers/downsample_worker.py @@ -0,0 +1,86 @@ +# pylint: disable=invalid-name, missing-docstring, logging-fstring-interpolation + +"""Pubsub worker that updates coarser segmentation mips after an SV split. + +Consumes the same edits exchange the mesh worker uses, but binds its own +queue and filters on the `downsample="true"` attribute set by +`publish_edit` when `result.seg_bbox` is populated. 
def callback(payload):
    """Handle one edit message: downsample the blocks an SV split touched.

    Messages are ignored (returning None) when the `downsample` attribute
    is not the string "true", when the pickled body has no `seg_bboxes`,
    or when the target graph is not OCDBT-backed. Otherwise every block
    overlapped by any published bbox is processed once while a
    `DownsampleBlockLock` over all of those blocks is held.
    """
    attrs = payload.attributes
    # All edit-triggered workers share one exchange, so filtering is done
    # on the message attribute rather than on the queue binding.
    if attrs.get("downsample") != "true":
        return

    # NOTE(review): pickle.loads trusts the publisher — confirm the
    # exchange only carries messages from trusted producers.
    message = pickle.loads(payload.data)
    op_id = int(message["operation_id"])
    table_id = attrs["table_id"]
    seg_bboxes = message.get("seg_bboxes")
    if not seg_bboxes:
        return

    # One ChunkedGraph per table, reused across messages.
    if table_id not in PCG_CACHE:
        PCG_CACHE[table_id] = ChunkedGraph(graph_id=table_id)
    cg = PCG_CACHE[table_id]

    # Defensive: without the OCDBT flag there is nothing to downsample.
    if not cg.meta.custom_data.get("seg", {}).get("ocdbt"):
        logging.log(
            INFO_HIGH,
            f"graph {table_id} not OCDBT-backed; skipping downsample op {op_id}",
        )
        return

    # Union of blocks over all bboxes, so a block shared by two bboxes is
    # locked and processed exactly once.
    block_list = sorted(
        {
            block
            for bbs, bbe in seg_bboxes
            for block in blocks_for_bbox(cg.meta, bbs, bbe)
        }
    )

    logging.log(
        INFO_HIGH,
        f"downsampling {len(block_list)} block(s) for op {op_id} graph {table_id}",
    )
    with DownsampleBlockLock(cg, block_list, op_id):
        for block in block_list:
            process_block(cg.meta, block, seg_bboxes)
    logging.log(INFO_HIGH, f"downsample complete op {op_id} graph {table_id}")
    gc.collect()