Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 76 additions & 0 deletions fireflyframework_agentic/pipeline/_psycopg_backend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
# Copyright 2026 Firefly Software Foundation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Shared scaffolding for Postgres-backed pipeline backends.

The checkpointer and audit-log backends both need the same boilerplate:
optional-dep guard on ``psycopg``, ``dsn`` xor ``connection`` constructor
check, table-name validation, and lazy idempotent DDL on first write. This
module centralizes it so each backend only has to declare its DDL and
default table name.
"""

from __future__ import annotations

from typing import Any

try:
import psycopg as _psycopg # type: ignore[import-not-found]
except ImportError: # pragma: no cover - optional dep
_psycopg = None # type: ignore[assignment]


class PsycopgBackend:
"""Base class for backends that persist into a single Postgres table.

Subclasses set the class attribute ``_DDL`` to a format string with a
single ``{table}`` placeholder, and pass their human-readable name and
default table to ``__init__``. The base class handles the rest:

* Raises ``ImportError`` if the ``postgres`` extra is not installed.
* Enforces ``dsn`` xor ``connection``.
* Validates ``table_name`` against SQL injection (interpolated into DDL).
* Opens the connection (with ``autocommit=True``) when only ``dsn`` is given.
* Applies the DDL lazily and idempotently on first ``_ensure_table()`` call.
"""

_DDL: str = ""

def __init__(
self,
*,
kind: str,
dsn: str | None,
connection: Any,
table_name: str,
) -> None:
if _psycopg is None:
raise ImportError(
f"{kind} requires the 'postgres' extra. Install with: pip install fireflyframework-agentic[postgres]"
)
if (dsn is None) == (connection is None):
raise ValueError(f"{kind} needs exactly one of `dsn` or `connection`.")
# Table name is interpolated into DDL — validate strictly to avoid SQL injection.
if not table_name.replace("_", "").isalnum():
raise ValueError(f"Invalid table_name {table_name!r}: must be alphanumeric/underscore only.")
self._conn = connection if connection is not None else _psycopg.connect(dsn, autocommit=True)
self._table = table_name
self._ddl_applied = False

def _ensure_table(self) -> None:
if self._ddl_applied:
return
with self._conn.cursor() as cur:
cur.execute(self._DDL.format(table=self._table))
self._ddl_applied = True
27 changes: 3 additions & 24 deletions fireflyframework_agentic/pipeline/audit.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,7 @@

from pydantic import BaseModel

try:
import psycopg as _psycopg # type: ignore[import-not-found]
except ImportError: # pragma: no cover - optional dep
_psycopg = None # type: ignore[assignment]
from fireflyframework_agentic.pipeline._psycopg_backend import PsycopgBackend

try:
from opentelemetry._logs import LogRecord as _OtelLogRecord # type: ignore[import-not-found]
Expand Down Expand Up @@ -145,7 +142,7 @@ def list_entries(self, pipeline_name: str, run_id: str) -> list[AuditEntry]:
return entries


class PostgresAuditLog:
class PostgresAuditLog(PsycopgBackend):
"""Postgres-backed audit log. Single table created on first ``record`` call.

Reuses ``psycopg`` from the ``postgres`` optional extra. ``dsn`` or a
Expand Down Expand Up @@ -181,25 +178,7 @@ def __init__(
connection: Any = None,
table_name: str = "firefly_audit",
) -> None:
if _psycopg is None:
raise ImportError(
"PostgresAuditLog requires the 'postgres' extra. "
"Install with: pip install fireflyframework-agentic[postgres]"
)
if (dsn is None) == (connection is None):
raise ValueError("PostgresAuditLog needs exactly one of `dsn` or `connection`.")
if not table_name.replace("_", "").isalnum():
raise ValueError(f"Invalid table_name {table_name!r}: must be alphanumeric/underscore only.")
self._conn = connection if connection is not None else _psycopg.connect(dsn, autocommit=True)
self._table = table_name
self._ddl_applied = False

def _ensure_table(self) -> None:
if self._ddl_applied:
return
with self._conn.cursor() as cur:
cur.execute(self._DDL.format(table=self._table))
self._ddl_applied = True
super().__init__(kind="PostgresAuditLog", dsn=dsn, connection=connection, table_name=table_name)

def record(self, entry: AuditEntry) -> None:
self._ensure_table()
Expand Down
30 changes: 4 additions & 26 deletions fireflyframework_agentic/pipeline/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,16 +39,13 @@

from pydantic import BaseModel

from fireflyframework_agentic.pipeline._psycopg_backend import PsycopgBackend

try:
import redis as _redis # type: ignore[import-not-found]
except ImportError: # pragma: no cover - optional dep
_redis = None # type: ignore[assignment]

try:
import psycopg as _psycopg # type: ignore[import-not-found]
except ImportError: # pragma: no cover - optional dep
_psycopg = None # type: ignore[assignment]


class CheckpointRecord(BaseModel):
"""One saved checkpoint.
Expand Down Expand Up @@ -194,7 +191,7 @@ def list_runs(self, pipeline_name: str) -> list[str]:
return list(self._client.zrange(self._runs_index_key(pipeline_name), 0, -1))


class PostgresCheckpointer:
class PostgresCheckpointer(PsycopgBackend):
"""Postgres-backed checkpointer.

Uses a single table created on first ``save`` call. The DDL is idempotent
Expand Down Expand Up @@ -232,26 +229,7 @@ def __init__(
connection: Any = None,
table_name: str = "firefly_checkpoints",
) -> None:
if _psycopg is None:
raise ImportError(
"PostgresCheckpointer requires the 'postgres' extra. "
"Install with: pip install fireflyframework-agentic[postgres]"
)
if (dsn is None) == (connection is None):
raise ValueError("PostgresCheckpointer needs exactly one of `dsn` or `connection`.")
# Table name is interpolated into DDL — validate it strictly to avoid SQL injection.
if not table_name.replace("_", "").isalnum():
raise ValueError(f"Invalid table_name {table_name!r}: must be alphanumeric/underscore only.")
self._conn = connection if connection is not None else _psycopg.connect(dsn, autocommit=True)
self._table = table_name
self._ddl_applied = False

def _ensure_table(self) -> None:
if self._ddl_applied:
return
with self._conn.cursor() as cur:
cur.execute(self._DDL.format(table=self._table))
self._ddl_applied = True
super().__init__(kind="PostgresCheckpointer", dsn=dsn, connection=connection, table_name=table_name)

def save(self, record: CheckpointRecord) -> None:
self._ensure_table()
Expand Down
7 changes: 4 additions & 3 deletions tests/unit/pipeline/test_audit_log.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import pytest
from pydantic import BaseModel

import fireflyframework_agentic.pipeline._psycopg_backend as psycopg_backend_module
import fireflyframework_agentic.pipeline.audit as audit_module
from fireflyframework_agentic.pipeline import (
AuditEntry,
Expand Down Expand Up @@ -85,8 +86,8 @@ def test_file_audit_log_unknown_run_returns_empty(tmp_path: Path) -> None:
@pytest.fixture(autouse=True)
def _stub_optional_deps(monkeypatch: pytest.MonkeyPatch) -> None:
"""Stub _psycopg and OTel symbols so backends can be constructed with mocks."""
if audit_module._psycopg is None:
monkeypatch.setattr(audit_module, "_psycopg", MagicMock(name="psycopg_stub"))
if psycopg_backend_module._psycopg is None:
monkeypatch.setattr(psycopg_backend_module, "_psycopg", MagicMock(name="psycopg_stub"))
if audit_module._otel_get_logger is None:
monkeypatch.setattr(audit_module, "_otel_get_logger", MagicMock(name="otel_logger_factory"))
monkeypatch.setattr(audit_module, "_OtelLogRecord", MagicMock(name="LogRecord"))
Expand All @@ -99,7 +100,7 @@ def _stub_optional_deps(monkeypatch: pytest.MonkeyPatch) -> None:


def test_postgres_audit_missing_dep_raises(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(audit_module, "_psycopg", None)
monkeypatch.setattr(psycopg_backend_module, "_psycopg", None)
with pytest.raises(ImportError, match=r"\[postgres\]"):
PostgresAuditLog(dsn="postgresql://x")

Expand Down
7 changes: 4 additions & 3 deletions tests/unit/pipeline/test_checkpoint_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import pytest
from pydantic import BaseModel

import fireflyframework_agentic.pipeline._psycopg_backend as psycopg_backend_module
import fireflyframework_agentic.pipeline.checkpoint as checkpoint_module
from fireflyframework_agentic.pipeline import (
CheckpointRecord,
Expand All @@ -39,8 +40,8 @@ def _stub_optional_deps(monkeypatch: pytest.MonkeyPatch) -> None:
"""
if checkpoint_module._redis is None:
monkeypatch.setattr(checkpoint_module, "_redis", MagicMock(name="redis_stub"))
if checkpoint_module._psycopg is None:
monkeypatch.setattr(checkpoint_module, "_psycopg", MagicMock(name="psycopg_stub"))
if psycopg_backend_module._psycopg is None:
monkeypatch.setattr(psycopg_backend_module, "_psycopg", MagicMock(name="psycopg_stub"))


# =============================================================================
Expand Down Expand Up @@ -244,7 +245,7 @@ def fake_execute(sql: str, params: tuple | None = None) -> None:


def test_postgres_checkpointer_missing_dep_raises(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(checkpoint_module, "_psycopg", None)
monkeypatch.setattr(psycopg_backend_module, "_psycopg", None)
with pytest.raises(ImportError, match=r"\[postgres\]"):
PostgresCheckpointer(dsn="postgresql://x")

Expand Down