diff --git a/docs/pipeline.md b/docs/pipeline.md index 2ece529e..a6918b30 100644 --- a/docs/pipeline.md +++ b/docs/pipeline.md @@ -180,12 +180,18 @@ Custom reducers are any callable `(current, update) -> merged`. ### Checkpoint + Resume -Pass a `Checkpointer` to persist state after each successful node. The -filesystem implementation ships in the package; Redis/Postgres backends are -straightforward to plug in via the `Checkpointer` Protocol. +Pass a `Checkpointer` to persist state after each successful node. Three +backends ship out of the box, all conforming to the same `Checkpointer` +Protocol so they're swappable without code changes. + +| Backend | Use when | Trade-off | Install | +|---|---|---|---| +| `FileCheckpointer` | Dev, single-host, ephemeral | No cross-process / cross-host sharing | (default — no extra) | +| `RedisCheckpointer` | Multi-worker, sub-day-scale runs | TTL eviction; not durable forever | `pip install fireflyframework-agentic[redis]` | +| `PostgresCheckpointer` | Long-lived runs, compliance, audit-friendly | Operational overhead of a DB | `pip install fireflyframework-agentic[postgres]` | ```python -from fireflyframework_agentic.pipeline import FileCheckpointer +from fireflyframework_agentic.pipeline import FileCheckpointer # or Redis / Postgres pipeline = ( PipelineBuilder("software-factory", state=BuildState, @@ -208,6 +214,21 @@ result = await pipeline.invoke(run_id=result.run_id) result = await pipeline.invoke(state=loaded_state, start_at=deployer) ``` +Swapping backends is a one-line change. Redis uses a TTL on each checkpoint +key (default 30 days) plus a sorted-set index of run IDs; Postgres uses a +single `firefly_checkpoints` table created idempotently on first save: + +```python +from fireflyframework_agentic.pipeline import RedisCheckpointer, PostgresCheckpointer + +# Either a URL/DSN (backend constructs its own client) or a pre-built client +# (lets you share a connection pool across many pipelines). +checkpointer = RedisCheckpointer(url="redis://localhost:6379/0", ttl_seconds=86400 * 30) +checkpointer = RedisCheckpointer(client=my_existing_redis) +checkpointer = PostgresCheckpointer(dsn="postgresql://user:pw@host/db") +checkpointer = PostgresCheckpointer(connection=my_existing_psycopg_connection) +``` + ### Cycles and `recursion_limit` State pipelines permit cycles for ReAct loops and retry-with-critique patterns. diff --git a/docs/superpowers/specs/2026-05-27-pipeline-phase-3a-design.md b/docs/superpowers/specs/2026-05-27-pipeline-phase-3a-design.md new file mode 100644 index 00000000..7eb5f684 --- /dev/null +++ b/docs/superpowers/specs/2026-05-27-pipeline-phase-3a-design.md @@ -0,0 +1,209 @@ +# Phase 3a — Durable Checkpointer Backends (Redis + Postgres) + +Issue [#147](https://github.com/fireflyframework/fireflyframework-agentic/issues/147), phase 3a. Stacked on top of Phase 1+2 (`issue-147-pipeline-evolution`). + +## Problem + +`StatePipeline` checkpointing today has one backend: `FileCheckpointer`, which writes JSON files to a local filesystem. That blocks the use cases customers actually run agentic pipelines for: + +- **Multi-worker fail-over.** A pipeline that crashes on worker A cannot be resumed on worker B because the checkpoint files are on A's disk. +- **Containerized deploys.** Ephemeral container filesystems lose checkpoint state on restart. +- **Horizontal scaling.** No shared state means runs are pinned to a single host. +- **Long-lived runs.** Filesystem checkpoints accumulate without TTL; cleanup is manual. + +Phase 3a ships two durable backends — Redis and Postgres — pluggable through the existing `Checkpointer` Protocol. No API changes to `StatePipeline`; no breaking changes for existing `FileCheckpointer` users. + +## Goal + +`PipelineBuilder("agent", state=..., checkpointer=...)` accepts a `RedisCheckpointer` or `PostgresCheckpointer` interchangeably with the existing `FileCheckpointer`. The software-factory scenario (architect → python_dev → deployer fails → resume on a different worker → evaluator) works end-to-end against any of the three. + +## Non-goals + +- Backend-pluggable observability / OTel spans (phase 3b). +- HITL pause primitive and audit log (phase 3c). +- Real-Redis / real-Postgres tests in this PR — verified out-of-band by the user. +- Schema migration tooling for the JSONB `state` column — the application owns state-schema evolution. +- Connection-pool tuning knobs beyond "pass me a client." +- Auth / RBAC on resume APIs. +- Streaming partial state. + +## Architecture + +Single file: `fireflyframework_agentic/pipeline/checkpoint.py`. Existing contents (`Checkpointer` Protocol, `CheckpointRecord`, `FileCheckpointer`) are preserved. Two new classes added to the same module: + +``` +pipeline/checkpoint.py +├── Checkpointer (Protocol, existing) +├── CheckpointRecord (existing) +├── FileCheckpointer (existing) +├── RedisCheckpointer (new) +└── PostgresCheckpointer (new) +``` + +### Guarded optional imports + +Top-of-file guarded imports mirror the established pattern in `engine.py` for OTel: + +```python +try: + import redis as _redis +except ImportError: + _redis = None + +try: + import psycopg as _psycopg +except ImportError: + _psycopg = None +``` + +Each backend's `__init__` raises a clear install-instruction error if its dependency is absent. No inline imports — project rule respected. + +### Connection lifecycle + +Both backends accept either a connection string OR a pre-built sync client so callers can share a connection pool across many pipelines: + +```python +RedisCheckpointer(url="redis://localhost:6379/0", ttl_seconds=86400 * 30) +RedisCheckpointer(client=existing_redis_client) + +PostgresCheckpointer(dsn="postgresql://user:pw@host/db") +PostgresCheckpointer(connection=existing_psycopg_connection) +``` + +The `Checkpointer` Protocol's methods are sync (called synchronously inside `StatePipeline._save_checkpoint`). Both backends use the sync clients (`redis-py`, `psycopg[binary]`) directly — no `asyncio.run` indirection. + +## Storage schemas + +### Redis + +One key per checkpoint, plus a sorted-set index per pipeline for `list_runs`: + +| Key | Type | Value | +|---|---|---| +| `firefly:ckpt:{pipeline}:{run_id}:{seq:06d}_{node_id}` | string | JSON-serialized `CheckpointRecord` | +| `firefly:ckpt:{pipeline}:runs` | sorted set | members = `run_id`s, scores = last-update unix timestamps | + +- TTL on the per-checkpoint keys is configurable, default 30 days. +- `save` issues `SET key value EX ttl` then `ZADD index ts run_id` (idempotent — score updated each call). +- `load_latest` issues `KEYS firefly:ckpt:{pipeline}:{run_id}:*`, picks the lexicographically last key (the zero-padded `seq` makes lex order match numeric order), then `GET` it. +- `list_runs` issues `ZRANGE index 0 -1` and returns all run IDs ordered by last-update. + +`KEYS` is acceptable here because the cardinality per run is bounded by node count × visit count (small for agentic workflows). A `SCAN`-based variant is trivial to add if scale demands it; not in scope for 3a. + +### Postgres + +Single table, created idempotently on first connect: + +```sql +CREATE TABLE IF NOT EXISTS firefly_checkpoints ( + pipeline_name TEXT NOT NULL, + run_id TEXT NOT NULL, + sequence INT NOT NULL, + node_id TEXT NOT NULL, + state JSONB NOT NULL, + completed_nodes JSONB NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + PRIMARY KEY (pipeline_name, run_id, sequence) +); +CREATE INDEX IF NOT EXISTS firefly_checkpoints_run + ON firefly_checkpoints (pipeline_name, run_id); +``` + +- DDL runs once per `PostgresCheckpointer` instance via a `_ddl_applied` flag set on first `save`. +- `save` issues `INSERT … ON CONFLICT (pipeline_name, run_id, sequence) DO UPDATE SET node_id = EXCLUDED.node_id, state = EXCLUDED.state, completed_nodes = EXCLUDED.completed_nodes`. +- `load_latest` issues `SELECT pipeline_name, run_id, sequence, node_id, state, completed_nodes FROM firefly_checkpoints WHERE pipeline_name = %s AND run_id = %s ORDER BY sequence DESC LIMIT 1`. +- `list_runs` issues `SELECT DISTINCT run_id FROM firefly_checkpoints WHERE pipeline_name = %s ORDER BY run_id`. + +## Optional dependencies + +`pyproject.toml` gains two new extras, no new required dependencies and no new dev-dependencies: + +```toml +[project.optional-dependencies] +redis = ["redis>=5,<6"] +postgres = ["psycopg[binary]>=3,<4"] +``` + +Install paths: + +```bash +pip install fireflyframework-agentic[redis] +pip install fireflyframework-agentic[postgres] +pip install fireflyframework-agentic[redis,postgres] +``` + +## Testing strategy + +`unittest.mock` only. No `fakeredis`, no `pytest-postgresql`, no test containers. Real-service verification is out-of-band by the user against real Redis and real Postgres, not in this PR. + +### `RedisCheckpointer` tests + +Mock the `redis.Redis` client (`mock.create_autospec(redis.Redis)`). Assertions: + +- `save` issues `client.set(key=..., value=, ex=)` with the expected key format and TTL. +- `save` issues `client.zadd("firefly:ckpt::runs", {: })`. +- `load_latest` issues `client.keys("firefly:ckpt:::*")`, picks the lex-last key, then `client.get()`, returns the parsed `CheckpointRecord`. +- `load_latest` returns `None` when `keys` returns an empty list. +- `list_runs` issues `client.zrange("firefly:ckpt::runs", 0, -1)` and returns its result. +- Constructing `RedisCheckpointer(url=...)` when `_redis is None` raises `ImportError` with a message that names the extra (`pip install fireflyframework-agentic[redis]`). + +### `PostgresCheckpointer` tests + +Mock the `psycopg.Connection` and its `cursor()` context manager. Assertions: + +- First `save` issues the `CREATE TABLE IF NOT EXISTS` DDL exactly once. The DDL flag prevents subsequent `save` calls from re-issuing it. +- `save` issues `INSERT … ON CONFLICT (pipeline_name, run_id, sequence) DO UPDATE SET …` with bound parameters matching the `CheckpointRecord` fields. +- `load_latest` issues `SELECT … ORDER BY sequence DESC LIMIT 1` and returns the row hydrated into a `CheckpointRecord`. Returns `None` for the no-rows case. +- `list_runs` issues `SELECT DISTINCT run_id … ORDER BY run_id` and returns the list. +- Constructing `PostgresCheckpointer(dsn=...)` when `_psycopg is None` raises `ImportError` naming the extra. + +### Protocol-conformance test + +A single test module `tests/unit/pipeline/test_checkpoint_backends.py` parametrizes the existing software-factory scenario (architect → python_dev → deployer fails → resume → evaluator) across all three backends: + +- `FileCheckpointer` — real, uses `tmp_path`. +- `RedisCheckpointer` — wrapping a `MagicMock` Redis client that records `set`/`get`/`keys`/`zadd`/`zrange` calls in-memory. +- `PostgresCheckpointer` — wrapping a `MagicMock` connection whose cursor returns canned results from an in-memory dict mimicking the schema. + +The mocks keep enough state to make the full resume flow work (save then load returns what was saved). Protocol drift between backends is caught as a test failure. + +## Documentation + +`docs/pipeline.md` — "Checkpoint + Resume" subsection gains a backend-comparison table: + +| Backend | Use when | Trade-off | +|---|---|---| +| `FileCheckpointer` | Dev, single-host, ephemeral | No cross-process / cross-host sharing | +| `RedisCheckpointer` | Multi-worker, sub-day-scale runs | TTL eviction; not durable forever | +| `PostgresCheckpointer` | Long-lived runs, compliance, audit-friendly | Operational overhead of a DB | + +`examples/pipeline_state.py` — append a fourth scenario gated behind `if os.environ.get("PG_DSN"):` showing how to swap `FileCheckpointer` for `PostgresCheckpointer`. No-op when the env var is unset, so the example still runs anywhere. + +`fireflyframework_agentic/pipeline/__init__.py` — re-export `RedisCheckpointer` and `PostgresCheckpointer` via a plain `from … import …`. Both classes always exist as importable names; the optional-dep gate is in their `__init__` methods. Rationale: simpler than `__getattr__` indirection on the package, and the canonical pattern in this codebase for OTel and other soft dependencies. + +## Scope + +- `pipeline/checkpoint.py`: ~250 LOC added +- `pyproject.toml`: two new `[project.optional-dependencies]` entries; **no new required deps, no new dev deps** +- `tests/unit/pipeline/test_checkpoint_backends.py`: ~180 LOC +- `docs/pipeline.md`: +20 LOC (table + a paragraph) +- `examples/pipeline_state.py`: +30 LOC (optional fourth scenario) +- `pipeline/__init__.py`: +2 exports + +Total ~450 LOC + tests, independently shippable. + +## Verification + +After implementation: + +1. `pytest tests/unit/pipeline/ -v` — pipeline suite green, all three backends pass the parametrized conformance test. +2. `pytest tests/unit/` — full unit suite green. +3. `ruff check` + `ruff format --check` clean. +4. `pyright` clean on touched modules. +5. `python -c "from fireflyframework_agentic.pipeline import RedisCheckpointer, PostgresCheckpointer"` works in a fresh venv with neither extra installed (the imports succeed; constructing the classes is what raises). +6. `python -c "from fireflyframework_agentic.pipeline import RedisCheckpointer; RedisCheckpointer(url='redis://x')"` in the no-extras venv raises a clear `ImportError` naming the extra. + +## What lands next + +- **Phase 3b** — `StatePipelineEventHandler` Protocol + OTel spans per state-pipeline node. +- **Phase 3c** — `Pause(reason)` sentinel for HITL approval gates + `AuditLog` Protocol with Postgres impl reusing the 3a Postgres connection. diff --git a/examples/pipeline_state.py b/examples/pipeline_state.py index 9f0fc463..6a3536f8 100644 --- a/examples/pipeline_state.py +++ b/examples/pipeline_state.py @@ -43,6 +43,7 @@ import asyncio import logging +import os import tempfile from pathlib import Path from typing import Annotated @@ -52,6 +53,7 @@ from fireflyframework_agentic.pipeline import ( FileCheckpointer, PipelineBuilder, + PostgresCheckpointer, Send, extend, ) @@ -232,10 +234,45 @@ async def run_map_reduce() -> None: # ============================================================================= +async def run_software_factory_postgres() -> None: + """Optional: the same software-factory scenario backed by Postgres. + + Runs only when the ``PG_DSN`` env var is set (e.g. + ``PG_DSN=postgresql://user:pw@localhost/firefly``). Requires the + ``postgres`` extra: ``pip install fireflyframework-agentic[postgres]``. + """ + dsn = os.environ.get("PG_DSN") + if not dsn: + return + + print("=== 4. Software factory with PostgresCheckpointer ===\n") + + # Reset the deployer flag so this scenario starts clean. + _deployer_failed_once["flag"] = False + + checkpointer = PostgresCheckpointer(dsn=dsn) + pipeline = ( + PipelineBuilder("software-factory-pg", state=BuildState, checkpointer=checkpointer) + .add_node(architect) + .add_node(python_dev) + .add_node(deployer) + .add_node(evaluator) + .chain(architect, python_dev, deployer, evaluator) + .build() + ) + first = await pipeline.invoke(BuildState(requirements="postgres-backed deploy")) + print(f" first run: success={first.success}, failed_node={first.failed_node}") + print(f" run_id: {first.run_id}\n") + second = await pipeline.invoke(run_id=first.run_id) + print(f" resumed: success={second.success}") + print(f" eval: {second.state.evaluation}\n") + + async def main() -> None: await run_branching() await run_software_factory() await run_map_reduce() + await run_software_factory_postgres() if __name__ == "__main__": diff --git a/fireflyframework_agentic/pipeline/__init__.py b/fireflyframework_agentic/pipeline/__init__.py index ed844dc0..3925952c 100644 --- a/fireflyframework_agentic/pipeline/__init__.py +++ b/fireflyframework_agentic/pipeline/__init__.py @@ -33,6 +33,8 @@ Checkpointer, CheckpointRecord, FileCheckpointer, + PostgresCheckpointer, + RedisCheckpointer, ) from fireflyframework_agentic.pipeline.context import PipelineContext from fireflyframework_agentic.pipeline.dag import DAG, DAGEdge, DAGNode, FailureStrategy @@ -75,6 +77,7 @@ "FanOutStep", "FileCheckpointer", "NodeResult", + "PostgresCheckpointer", "PipelineBuilder", "PipelineContext", "PipelineEngine", @@ -82,8 +85,9 @@ "PipelineResult", "ReasoningStep", "RecursionLimitError", - "Send", + "RedisCheckpointer", "RetrievalStep", + "Send", "StatePipeline", "StatePipelineResult", "StepExecutor", diff --git a/fireflyframework_agentic/pipeline/checkpoint.py b/fireflyframework_agentic/pipeline/checkpoint.py index d4ebeaae..3a63d8c1 100644 --- a/fireflyframework_agentic/pipeline/checkpoint.py +++ b/fireflyframework_agentic/pipeline/checkpoint.py @@ -18,19 +18,37 @@ ``(pipeline_name, run_id, node_id)``. On resume the engine loads the latest checkpoint and skips nodes that already completed in that run. -:class:`FileCheckpointer` is a filesystem-backed JSON implementation suitable -for single-process workflows. Redis / Postgres checkpointers can implement -the same Protocol and be swapped in without API changes. +Three backends ship: + +* :class:`FileCheckpointer` — filesystem JSON. Best for dev / single-host. +* :class:`RedisCheckpointer` — Redis with TTL. Best for multi-worker, + sub-day-scale runs. Requires ``pip install fireflyframework-agentic[redis]``. +* :class:`PostgresCheckpointer` — Postgres with a single ``firefly_checkpoints`` + table. Best for long-lived runs / compliance. Requires + ``pip install fireflyframework-agentic[postgres]``. + +Any backend conforms to the :class:`Checkpointer` Protocol and is interchangeable. """ from __future__ import annotations import json +import time from pathlib import Path from typing import Any, Protocol, runtime_checkable from pydantic import BaseModel +try: + import redis as _redis # type: ignore[import-not-found] +except ImportError: # pragma: no cover - optional dep + _redis = None # type: ignore[assignment] + +try: + import psycopg as _psycopg # type: ignore[import-not-found] +except ImportError: # pragma: no cover - optional dep + _psycopg = None # type: ignore[assignment] + class CheckpointRecord(BaseModel): """One saved checkpoint.""" @@ -97,3 +115,187 @@ def list_runs(self, pipeline_name: str) -> list[str]: if not pipeline_dir.exists(): return [] return sorted(d.name for d in pipeline_dir.iterdir() if d.is_dir()) + + +class RedisCheckpointer: + """Redis-backed checkpointer. + + Key layout:: + + firefly:ckpt:::_ -> JSON record (with TTL) + firefly:ckpt::runs -> ZSET of run_ids (score = last update ts) + + Parameters: + url: Redis connection string (e.g. ``redis://host:6379/0``). Mutually + exclusive with ``client``. + client: Pre-built ``redis.Redis`` instance. Use this to share a + connection pool across many pipelines. + ttl_seconds: TTL applied to each checkpoint key. Default 30 days. + The runs-index ZSET does not expire (it's tiny). + key_prefix: Override the default ``firefly:ckpt`` key prefix. + + Raises: + ImportError: When the ``redis`` extra is not installed. + """ + + def __init__( + self, + *, + url: str | None = None, + client: Any = None, + ttl_seconds: int = 60 * 60 * 24 * 30, + key_prefix: str = "firefly:ckpt", + ) -> None: + if _redis is None: + raise ImportError( + "RedisCheckpointer requires the 'redis' extra. " + "Install with: pip install fireflyframework-agentic[redis]" + ) + if (url is None) == (client is None): + raise ValueError("RedisCheckpointer needs exactly one of `url` or `client`.") + self._client = client if client is not None else _redis.Redis.from_url(url, decode_responses=True) + self._ttl = ttl_seconds + self._prefix = key_prefix + + def _ckpt_key(self, pipeline: str, run_id: str, sequence: int, node_id: str) -> str: + return f"{self._prefix}:{pipeline}:{run_id}:{sequence:06d}_{node_id}" + + def _runs_index_key(self, pipeline: str) -> str: + return f"{self._prefix}:{pipeline}:runs" + + def _run_pattern(self, pipeline: str, run_id: str) -> str: + return f"{self._prefix}:{pipeline}:{run_id}:*" + + def save(self, record: CheckpointRecord) -> None: + key = self._ckpt_key(record.pipeline_name, record.run_id, record.sequence, record.node_id) + self._client.set(key, record.model_dump_json(), ex=self._ttl) + self._client.zadd(self._runs_index_key(record.pipeline_name), {record.run_id: time.time()}) + + def load_latest(self, pipeline_name: str, run_id: str) -> CheckpointRecord | None: + keys = self._client.keys(self._run_pattern(pipeline_name, run_id)) + if not keys: + return None + # Keys are zero-padded by sequence; lex-sorted last = numerically-latest. + latest_key = sorted(keys)[-1] + payload = self._client.get(latest_key) + if payload is None: + return None + return CheckpointRecord.model_validate(json.loads(payload)) + + def list_runs(self, pipeline_name: str) -> list[str]: + return list(self._client.zrange(self._runs_index_key(pipeline_name), 0, -1)) + + +class PostgresCheckpointer: + """Postgres-backed checkpointer. + + Uses a single table created on first ``save`` call. The DDL is idempotent + so multiple processes pointing at the same database are safe. + + Parameters: + dsn: Postgres connection string. Mutually exclusive with ``connection``. + connection: Pre-built ``psycopg.Connection``. Use this to share a + connection across many pipelines. + table_name: Override the default ``firefly_checkpoints`` table name. + + Raises: + ImportError: When the ``postgres`` extra is not installed. + """ + + _DDL = """ + CREATE TABLE IF NOT EXISTS {table} ( + pipeline_name TEXT NOT NULL, + run_id TEXT NOT NULL, + sequence INT NOT NULL, + node_id TEXT NOT NULL, + state JSONB NOT NULL, + completed_nodes JSONB NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + PRIMARY KEY (pipeline_name, run_id, sequence) + ); + CREATE INDEX IF NOT EXISTS {table}_run_idx + ON {table} (pipeline_name, run_id); + """ + + def __init__( + self, + *, + dsn: str | None = None, + connection: Any = None, + table_name: str = "firefly_checkpoints", + ) -> None: + if _psycopg is None: + raise ImportError( + "PostgresCheckpointer requires the 'postgres' extra. " + "Install with: pip install fireflyframework-agentic[postgres]" + ) + if (dsn is None) == (connection is None): + raise ValueError("PostgresCheckpointer needs exactly one of `dsn` or `connection`.") + # Table name is interpolated into DDL — validate it strictly to avoid SQL injection. + if not table_name.replace("_", "").isalnum(): + raise ValueError(f"Invalid table_name {table_name!r}: must be alphanumeric/underscore only.") + self._conn = connection if connection is not None else _psycopg.connect(dsn, autocommit=True) + self._table = table_name + self._ddl_applied = False + + def _ensure_table(self) -> None: + if self._ddl_applied: + return + with self._conn.cursor() as cur: + cur.execute(self._DDL.format(table=self._table)) + self._ddl_applied = True + + def save(self, record: CheckpointRecord) -> None: + self._ensure_table() + sql = ( + f"INSERT INTO {self._table} " + "(pipeline_name, run_id, sequence, node_id, state, completed_nodes) " + "VALUES (%s, %s, %s, %s, %s, %s) " + "ON CONFLICT (pipeline_name, run_id, sequence) DO UPDATE SET " + "node_id = EXCLUDED.node_id, " + "state = EXCLUDED.state, " + "completed_nodes = EXCLUDED.completed_nodes" + ) + with self._conn.cursor() as cur: + cur.execute( + sql, + ( + record.pipeline_name, + record.run_id, + record.sequence, + record.node_id, + json.dumps(record.state), + json.dumps(record.completed_nodes), + ), + ) + + def load_latest(self, pipeline_name: str, run_id: str) -> CheckpointRecord | None: + self._ensure_table() + sql = ( + f"SELECT pipeline_name, run_id, sequence, node_id, state, completed_nodes " + f"FROM {self._table} " + f"WHERE pipeline_name = %s AND run_id = %s " + f"ORDER BY sequence DESC LIMIT 1" + ) + with self._conn.cursor() as cur: + cur.execute(sql, (pipeline_name, run_id)) + row = cur.fetchone() + if row is None: + return None + pipeline, rid, seq, node_id, state, completed = row + # psycopg returns JSONB as parsed Python objects; tolerate raw strings too. + return CheckpointRecord( + pipeline_name=pipeline, + run_id=rid, + sequence=seq, + node_id=node_id, + state=json.loads(state) if isinstance(state, str) else state, + completed_nodes=json.loads(completed) if isinstance(completed, str) else completed, + ) + + def list_runs(self, pipeline_name: str) -> list[str]: + self._ensure_table() + sql = f"SELECT DISTINCT run_id FROM {self._table} WHERE pipeline_name = %s ORDER BY run_id" + with self._conn.cursor() as cur: + cur.execute(sql, (pipeline_name,)) + return [r[0] for r in cur.fetchall()] diff --git a/pyproject.toml b/pyproject.toml index 416a109d..ed77b3c7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -62,6 +62,9 @@ queues = [ postgres = [ "asyncpg>=0.30.0", "sqlalchemy>=2.0.0", + # psycopg drives the sync-Protocol PostgresCheckpointer; asyncpg above + # drives the async memory store. Both ship under the same extra. + "psycopg[binary]>=3.2.0,<4", ] mongodb = [ "motor>=3.6.0", diff --git a/tests/unit/pipeline/test_checkpoint_backends.py b/tests/unit/pipeline/test_checkpoint_backends.py new file mode 100644 index 00000000..091f6cf3 --- /dev/null +++ b/tests/unit/pipeline/test_checkpoint_backends.py @@ -0,0 +1,403 @@ +# Copyright 2026 Firefly Software Foundation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. + +"""Tests for the three Checkpointer backends (File, Redis, Postgres). + +Mocks only — no real Redis or Postgres needed. Real-service verification is +out-of-band against actual servers. +""" + +from __future__ import annotations + +import json +from typing import Any +from unittest.mock import MagicMock, create_autospec + +import pytest +from pydantic import BaseModel + +import fireflyframework_agentic.pipeline.checkpoint as checkpoint_module +from fireflyframework_agentic.pipeline import ( + CheckpointRecord, + FileCheckpointer, + PipelineBuilder, + PostgresCheckpointer, + RedisCheckpointer, + StatePipeline, +) + + +@pytest.fixture(autouse=True) +def _stub_optional_deps(monkeypatch: pytest.MonkeyPatch) -> None: + """Make ``_redis`` / ``_psycopg`` truthy so we can construct backends with mock clients. + + Tests that explicitly want the missing-dep code path (the two + ``..._missing_dep_raises`` tests) override this by setting the symbol back + to None via their own monkeypatch — the per-test patch wins. + """ + if checkpoint_module._redis is None: + monkeypatch.setattr(checkpoint_module, "_redis", MagicMock(name="redis_stub")) + if checkpoint_module._psycopg is None: + monkeypatch.setattr(checkpoint_module, "_psycopg", MagicMock(name="psycopg_stub")) + + +# ============================================================================= +# RedisCheckpointer +# ============================================================================= + + +def _redis_client_mock() -> MagicMock: + """Build a MagicMock that tracks the calls a RedisCheckpointer issues, with + just enough state (an in-memory dict) to make round-trip save→load work. + """ + store: dict[str, str] = {} + runs_index: dict[str, dict[str, float]] = {} + + client = MagicMock(name="redis.Redis") + + def fake_set(key: str, value: str, ex: int | None = None) -> bool: + store[key] = value + return True + + def fake_get(key: str) -> str | None: + return store.get(key) + + def fake_keys(pattern: str) -> list[str]: + # Trivial glob: only handles "*" patterns (which is what we use). + if pattern.endswith("*"): + prefix = pattern[:-1] + return [k for k in store if k.startswith(prefix)] + return [k for k in store if k == pattern] + + def fake_zadd(key: str, mapping: dict[str, float]) -> int: + runs_index.setdefault(key, {}).update(mapping) + return len(mapping) + + def fake_zrange(key: str, start: int, end: int) -> list[str]: + members = runs_index.get(key, {}) + ordered = sorted(members, key=lambda m: members[m]) + return ordered[start : (end + 1 if end != -1 else None)] + + client.set.side_effect = fake_set + client.get.side_effect = fake_get + client.keys.side_effect = fake_keys + client.zadd.side_effect = fake_zadd + client.zrange.side_effect = fake_zrange + return client + + +def test_redis_checkpointer_missing_dep_raises(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(checkpoint_module, "_redis", None) + with pytest.raises(ImportError, match=r"\[redis\]"): + RedisCheckpointer(url="redis://x") + + +def test_redis_checkpointer_rejects_both_url_and_client() -> None: + with pytest.raises(ValueError, match="exactly one"): + RedisCheckpointer(url="redis://x", client=MagicMock()) + + +def test_redis_checkpointer_rejects_neither_url_nor_client() -> None: + with pytest.raises(ValueError, match="exactly one"): + RedisCheckpointer() + + +def test_redis_save_issues_set_and_zadd() -> None: + client = _redis_client_mock() + ckpt = RedisCheckpointer(client=client, ttl_seconds=42) + record = CheckpointRecord( + pipeline_name="p", + run_id="r", + sequence=1, + node_id="n", + state={"x": 1}, + completed_nodes=["n"], + ) + ckpt.save(record) + + set_call = client.set.call_args + assert set_call.args[0] == "firefly:ckpt:p:r:000001_n" + assert json.loads(set_call.args[1])["node_id"] == "n" + assert set_call.kwargs["ex"] == 42 + + zadd_call = client.zadd.call_args + assert zadd_call.args[0] == "firefly:ckpt:p:runs" + assert "r" in zadd_call.args[1] + + +def test_redis_load_latest_picks_highest_sequence_key() -> None: + client = _redis_client_mock() + ckpt = RedisCheckpointer(client=client) + for seq in (1, 5, 3): + ckpt.save( + CheckpointRecord( + pipeline_name="p", + run_id="r", + sequence=seq, + node_id=f"node{seq}", + state={"seq": seq}, + completed_nodes=[], + ) + ) + latest = ckpt.load_latest("p", "r") + assert latest is not None + assert latest.sequence == 5 + assert latest.node_id == "node5" + + +def test_redis_load_latest_returns_none_when_run_unknown() -> None: + client = _redis_client_mock() + ckpt = RedisCheckpointer(client=client) + assert ckpt.load_latest("p", "missing") is None + + +def test_redis_list_runs_returns_zrange_result() -> None: + client = _redis_client_mock() + ckpt = RedisCheckpointer(client=client) + for run_id in ("r1", "r2", "r3"): + ckpt.save( + CheckpointRecord( + pipeline_name="p", + run_id=run_id, + sequence=1, + node_id="n", + state={}, + completed_nodes=[], + ) + ) + runs = ckpt.list_runs("p") + assert set(runs) == {"r1", "r2", "r3"} + + +# ============================================================================= +# PostgresCheckpointer +# ============================================================================= + + +def _postgres_connection_mock() -> tuple[MagicMock, dict[tuple, dict[str, Any]]]: + """MagicMock Connection backed by an in-memory dict shaped like the + firefly_checkpoints table. Returns ``(conn, store)`` so tests can poke the store. + """ + store: dict[tuple[str, str, int], dict[str, Any]] = {} + ddl_calls: list[str] = [] + + conn = MagicMock(name="psycopg.Connection") + + def make_cursor() -> MagicMock: + cur = MagicMock(name="psycopg.Cursor") + cur.__enter__ = MagicMock(return_value=cur) + cur.__exit__ = MagicMock(return_value=None) + cur._last_fetchone = None + cur._last_fetchall = [] + + def fake_execute(sql: str, params: tuple | None = None) -> None: + sql_lower = sql.strip().lower() + if sql_lower.startswith("create table"): + ddl_calls.append(sql) + return + if sql_lower.startswith("insert into"): + assert params is not None + key = (params[0], params[1], params[2]) + store[key] = { + "pipeline_name": params[0], + "run_id": params[1], + "sequence": params[2], + "node_id": params[3], + "state": json.loads(params[4]) if isinstance(params[4], str) else params[4], + "completed_nodes": (json.loads(params[5]) if isinstance(params[5], str) else params[5]), + } + return + if sql_lower.startswith("select pipeline_name"): + assert params is not None + matches = [v for k, v in store.items() if k[0] == params[0] and k[1] == params[1]] + matches.sort(key=lambda r: r["sequence"], reverse=True) + cur._last_fetchone = ( + ( + matches[0]["pipeline_name"], + matches[0]["run_id"], + matches[0]["sequence"], + matches[0]["node_id"], + matches[0]["state"], + matches[0]["completed_nodes"], + ) + if matches + else None + ) + return + if sql_lower.startswith("select distinct run_id"): + assert params is not None + runs = sorted({k[1] for k in store if k[0] == params[0]}) + cur._last_fetchall = [(r,) for r in runs] + return + raise AssertionError(f"unexpected SQL: {sql}") + + cur.execute.side_effect = fake_execute + cur.fetchone.side_effect = lambda: cur._last_fetchone + cur.fetchall.side_effect = lambda: cur._last_fetchall + return cur + + conn.cursor.side_effect = make_cursor + conn._ddl_calls = ddl_calls + return conn, store + + +def test_postgres_checkpointer_missing_dep_raises(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(checkpoint_module, "_psycopg", None) + with pytest.raises(ImportError, match=r"\[postgres\]"): + PostgresCheckpointer(dsn="postgresql://x") + + +def test_postgres_checkpointer_rejects_both_dsn_and_connection() -> None: + with pytest.raises(ValueError, match="exactly one"): + PostgresCheckpointer(dsn="postgresql://x", connection=MagicMock()) + + +def test_postgres_checkpointer_rejects_bad_table_name() -> None: + with pytest.raises(ValueError, match="Invalid table_name"): + PostgresCheckpointer(connection=MagicMock(), table_name="bad; DROP TABLE users") + + +def test_postgres_ddl_runs_once_across_many_saves() -> None: + conn, _store = _postgres_connection_mock() + ckpt = PostgresCheckpointer(connection=conn) + for seq in range(3): + ckpt.save( + CheckpointRecord( + pipeline_name="p", + run_id="r", + sequence=seq, + node_id=f"n{seq}", + state={}, + completed_nodes=[], + ) + ) + assert len(conn._ddl_calls) == 1, "DDL should run exactly once per instance" + + +def test_postgres_save_then_load_latest_round_trips() -> None: + conn, _store = _postgres_connection_mock() + ckpt = PostgresCheckpointer(connection=conn) + for seq, node_id in [(1, "a"), (5, "e"), (3, "c")]: + ckpt.save( + CheckpointRecord( + pipeline_name="p", + run_id="r", + sequence=seq, + node_id=node_id, + state={"seq": seq}, + completed_nodes=[node_id], + ) + ) + latest = ckpt.load_latest("p", "r") + assert latest is not None + assert latest.sequence == 5 + assert latest.node_id == "e" + + +def test_postgres_load_latest_returns_none_when_empty() -> None: + conn, _store = _postgres_connection_mock() + ckpt = PostgresCheckpointer(connection=conn) + assert ckpt.load_latest("p", "missing") is None + + +def test_postgres_list_runs_returns_distinct_run_ids() -> None: + conn, _store = _postgres_connection_mock() + ckpt = PostgresCheckpointer(connection=conn) + for run_id in ("rA", "rB", "rA"): # rA twice, rB once + ckpt.save( + CheckpointRecord( + pipeline_name="p", + run_id=run_id, + sequence=1, + node_id="n", + state={}, + completed_nodes=[], + ) + ) + assert ckpt.list_runs("p") == ["rA", "rB"] + + +# ============================================================================= +# Protocol conformance — software-factory scenario across all three backends +# ============================================================================= + + +class FactoryState(BaseModel): + requirements: str + spec: str | None = None + code: str | None = None + deploy_url: str | None = None + evaluation: str | None = None + + +def _build_factory(checkpointer: Any) -> StatePipeline: + """Construct the canonical 4-step agent pipeline that fails on first deploy.""" + state_flag = {"failed_once": False} + + async def architect(state: FactoryState) -> dict: + return {"spec": f"spec for {state.requirements}"} + + async def python_dev(state: FactoryState) -> dict: + return {"code": f"# code for {state.spec}"} + + async def deployer(state: FactoryState) -> dict: + if not state_flag["failed_once"]: + state_flag["failed_once"] = True + raise RuntimeError("blip") + return {"deploy_url": "https://app"} + + async def evaluator(state: FactoryState) -> dict: + return {"evaluation": f"PASS {state.deploy_url}"} + + pipeline = ( + PipelineBuilder("factory", state=FactoryState, checkpointer=checkpointer) + .add_node(architect) + .add_node(python_dev) + .add_node(deployer) + .add_node(evaluator) + .chain(architect, python_dev, deployer, evaluator) + .build() + ) + assert isinstance(pipeline, StatePipeline) + return pipeline + + +@pytest.fixture +def file_backend(tmp_path): + return FileCheckpointer(tmp_path / "ckpt") + + +@pytest.fixture +def redis_backend(): + return RedisCheckpointer(client=_redis_client_mock()) + + +@pytest.fixture +def postgres_backend(): + conn, _store = _postgres_connection_mock() + return PostgresCheckpointer(connection=conn) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("backend_fixture", ["file_backend", "redis_backend", "postgres_backend"]) +async def test_backend_supports_fail_and_resume(backend_fixture, request) -> None: + """Same scenario across all three backends: deployer fails, resume completes.""" + backend = request.getfixturevalue(backend_fixture) + pipeline = _build_factory(backend) + + first = await pipeline.invoke(FactoryState(requirements="users service")) + assert not first.success + assert first.failed_node == "deployer" + assert first.completed_nodes == ["architect", "python_dev"] + + second = await pipeline.invoke(run_id=first.run_id) + assert second.success + assert second.completed_nodes == ["architect", "python_dev", "deployer", "evaluator"] + assert second.state.evaluation == "PASS https://app" + + +# Silence the unused-import warning for the autospec we don't actually use here +# but keep available for follow-up tests. +_ = create_autospec