diff --git a/fireflyframework_agentic/pipeline/_psycopg_backend.py b/fireflyframework_agentic/pipeline/_psycopg_backend.py new file mode 100644 index 00000000..b6784c9d --- /dev/null +++ b/fireflyframework_agentic/pipeline/_psycopg_backend.py @@ -0,0 +1,76 @@ +# Copyright 2026 Firefly Software Foundation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Shared scaffolding for Postgres-backed pipeline backends. + +The checkpointer and audit-log backends both need the same boilerplate: +optional-dep guard on ``psycopg``, ``dsn`` xor ``connection`` constructor +check, table-name validation, and lazy idempotent DDL on first write. This +module centralizes it so each backend only has to declare its DDL and +default table name. +""" + +from __future__ import annotations + +from typing import Any + +try: + import psycopg as _psycopg # type: ignore[import-not-found] +except ImportError: # pragma: no cover - optional dep + _psycopg = None # type: ignore[assignment] + + +class PsycopgBackend: + """Base class for backends that persist into a single Postgres table. + + Subclasses set the class attribute ``_DDL`` to a format string with a + single ``{table}`` placeholder, and pass their human-readable name and + default table to ``__init__``. The base class handles the rest: + + * Raises ``ImportError`` if the ``postgres`` extra is not installed. + * Enforces ``dsn`` xor ``connection``. + * Validates ``table_name`` against SQL injection (interpolated into DDL). + * Opens the connection (with ``autocommit=True``) when only ``dsn`` is given. + * Applies the DDL lazily and idempotently on first ``_ensure_table()`` call. + """ + + _DDL: str = "" + + def __init__( + self, + *, + kind: str, + dsn: str | None, + connection: Any, + table_name: str, + ) -> None: + if _psycopg is None: + raise ImportError( + f"{kind} requires the 'postgres' extra. Install with: pip install fireflyframework-agentic[postgres]" + ) + if (dsn is None) == (connection is None): + raise ValueError(f"{kind} needs exactly one of `dsn` or `connection`.") + # Table name is interpolated into DDL — validate strictly to avoid SQL injection. + if not table_name.replace("_", "").isalnum(): + raise ValueError(f"Invalid table_name {table_name!r}: must be alphanumeric/underscore only.") + self._conn = connection if connection is not None else _psycopg.connect(dsn, autocommit=True) + self._table = table_name + self._ddl_applied = False + + def _ensure_table(self) -> None: + if self._ddl_applied: + return + with self._conn.cursor() as cur: + cur.execute(self._DDL.format(table=self._table)) + self._ddl_applied = True diff --git a/fireflyframework_agentic/pipeline/audit.py b/fireflyframework_agentic/pipeline/audit.py index 481ea08d..d9691775 100644 --- a/fireflyframework_agentic/pipeline/audit.py +++ b/fireflyframework_agentic/pipeline/audit.py @@ -46,10 +46,7 @@ from pydantic import BaseModel -try: - import psycopg as _psycopg # type: ignore[import-not-found] -except ImportError: # pragma: no cover - optional dep - _psycopg = None # type: ignore[assignment] +from fireflyframework_agentic.pipeline._psycopg_backend import PsycopgBackend try: from opentelemetry._logs import LogRecord as _OtelLogRecord # type: ignore[import-not-found] @@ -145,7 +142,7 @@ def list_entries(self, pipeline_name: str, run_id: str) -> list[AuditEntry]: return entries -class PostgresAuditLog: +class PostgresAuditLog(PsycopgBackend): """Postgres-backed audit log. Single table created on first ``record`` call. Reuses ``psycopg`` from the ``postgres`` optional extra. ``dsn`` or a @@ -181,25 +178,7 @@ def __init__( connection: Any = None, table_name: str = "firefly_audit", ) -> None: - if _psycopg is None: - raise ImportError( - "PostgresAuditLog requires the 'postgres' extra. " - "Install with: pip install fireflyframework-agentic[postgres]" - ) - if (dsn is None) == (connection is None): - raise ValueError("PostgresAuditLog needs exactly one of `dsn` or `connection`.") - if not table_name.replace("_", "").isalnum(): - raise ValueError(f"Invalid table_name {table_name!r}: must be alphanumeric/underscore only.") - self._conn = connection if connection is not None else _psycopg.connect(dsn, autocommit=True) - self._table = table_name - self._ddl_applied = False - - def _ensure_table(self) -> None: - if self._ddl_applied: - return - with self._conn.cursor() as cur: - cur.execute(self._DDL.format(table=self._table)) - self._ddl_applied = True + super().__init__(kind="PostgresAuditLog", dsn=dsn, connection=connection, table_name=table_name) def record(self, entry: AuditEntry) -> None: self._ensure_table() diff --git a/fireflyframework_agentic/pipeline/checkpoint.py b/fireflyframework_agentic/pipeline/checkpoint.py index ecedb16c..5cb2b5cf 100644 --- a/fireflyframework_agentic/pipeline/checkpoint.py +++ b/fireflyframework_agentic/pipeline/checkpoint.py @@ -39,16 +39,13 @@ from pydantic import BaseModel +from fireflyframework_agentic.pipeline._psycopg_backend import PsycopgBackend + try: import redis as _redis # type: ignore[import-not-found] except ImportError: # pragma: no cover - optional dep _redis = None # type: ignore[assignment] -try: - import psycopg as _psycopg # type: ignore[import-not-found] -except ImportError: # pragma: no cover - optional dep - _psycopg = None # type: ignore[assignment] - class CheckpointRecord(BaseModel): """One saved checkpoint. @@ -194,7 +191,7 @@ def list_runs(self, pipeline_name: str) -> list[str]: return list(self._client.zrange(self._runs_index_key(pipeline_name), 0, -1)) -class PostgresCheckpointer: +class PostgresCheckpointer(PsycopgBackend): """Postgres-backed checkpointer. Uses a single table created on first ``save`` call. The DDL is idempotent @@ -232,26 +229,7 @@ def __init__( connection: Any = None, table_name: str = "firefly_checkpoints", ) -> None: - if _psycopg is None: - raise ImportError( - "PostgresCheckpointer requires the 'postgres' extra. " - "Install with: pip install fireflyframework-agentic[postgres]" - ) - if (dsn is None) == (connection is None): - raise ValueError("PostgresCheckpointer needs exactly one of `dsn` or `connection`.") - # Table name is interpolated into DDL — validate it strictly to avoid SQL injection. - if not table_name.replace("_", "").isalnum(): - raise ValueError(f"Invalid table_name {table_name!r}: must be alphanumeric/underscore only.") - self._conn = connection if connection is not None else _psycopg.connect(dsn, autocommit=True) - self._table = table_name - self._ddl_applied = False - - def _ensure_table(self) -> None: - if self._ddl_applied: - return - with self._conn.cursor() as cur: - cur.execute(self._DDL.format(table=self._table)) - self._ddl_applied = True + super().__init__(kind="PostgresCheckpointer", dsn=dsn, connection=connection, table_name=table_name) def save(self, record: CheckpointRecord) -> None: self._ensure_table() diff --git a/tests/unit/pipeline/test_audit_log.py b/tests/unit/pipeline/test_audit_log.py index e6b426b2..369cf6a5 100644 --- a/tests/unit/pipeline/test_audit_log.py +++ b/tests/unit/pipeline/test_audit_log.py @@ -17,6 +17,7 @@ import pytest from pydantic import BaseModel +import fireflyframework_agentic.pipeline._psycopg_backend as psycopg_backend_module import fireflyframework_agentic.pipeline.audit as audit_module from fireflyframework_agentic.pipeline import ( AuditEntry, @@ -85,8 +86,8 @@ def test_file_audit_log_unknown_run_returns_empty(tmp_path: Path) -> None: @pytest.fixture(autouse=True) def _stub_optional_deps(monkeypatch: pytest.MonkeyPatch) -> None: """Stub _psycopg and OTel symbols so backends can be constructed with mocks.""" - if audit_module._psycopg is None: - monkeypatch.setattr(audit_module, "_psycopg", MagicMock(name="psycopg_stub")) + if psycopg_backend_module._psycopg is None: + monkeypatch.setattr(psycopg_backend_module, "_psycopg", MagicMock(name="psycopg_stub")) if audit_module._otel_get_logger is None: monkeypatch.setattr(audit_module, "_otel_get_logger", MagicMock(name="otel_logger_factory")) monkeypatch.setattr(audit_module, "_OtelLogRecord", MagicMock(name="LogRecord")) @@ -99,7 +100,7 @@ def _stub_optional_deps(monkeypatch: pytest.MonkeyPatch) -> None: def test_postgres_audit_missing_dep_raises(monkeypatch: pytest.MonkeyPatch) -> None: - monkeypatch.setattr(audit_module, "_psycopg", None) + monkeypatch.setattr(psycopg_backend_module, "_psycopg", None) with pytest.raises(ImportError, match=r"\[postgres\]"): PostgresAuditLog(dsn="postgresql://x") diff --git a/tests/unit/pipeline/test_checkpoint_backends.py b/tests/unit/pipeline/test_checkpoint_backends.py index 091f6cf3..47a6674c 100644 --- a/tests/unit/pipeline/test_checkpoint_backends.py +++ b/tests/unit/pipeline/test_checkpoint_backends.py @@ -18,6 +18,7 @@ import pytest from pydantic import BaseModel +import fireflyframework_agentic.pipeline._psycopg_backend as psycopg_backend_module import fireflyframework_agentic.pipeline.checkpoint as checkpoint_module from fireflyframework_agentic.pipeline import ( CheckpointRecord, @@ -39,8 +40,8 @@ def _stub_optional_deps(monkeypatch: pytest.MonkeyPatch) -> None: """ if checkpoint_module._redis is None: monkeypatch.setattr(checkpoint_module, "_redis", MagicMock(name="redis_stub")) - if checkpoint_module._psycopg is None: - monkeypatch.setattr(checkpoint_module, "_psycopg", MagicMock(name="psycopg_stub")) + if psycopg_backend_module._psycopg is None: + monkeypatch.setattr(psycopg_backend_module, "_psycopg", MagicMock(name="psycopg_stub")) # ============================================================================= @@ -244,7 +245,7 @@ def fake_execute(sql: str, params: tuple | None = None) -> None: def test_postgres_checkpointer_missing_dep_raises(monkeypatch: pytest.MonkeyPatch) -> None: - monkeypatch.setattr(checkpoint_module, "_psycopg", None) + monkeypatch.setattr(psycopg_backend_module, "_psycopg", None) with pytest.raises(ImportError, match=r"\[postgres\]"): PostgresCheckpointer(dsn="postgresql://x")