diff --git a/CHANGELOG.md b/CHANGELOG.md index ed207623..c4a06fa8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -39,6 +39,77 @@ Copyright 2026 Firefly Software Solutions Inc. Licensed under the Apache License Supports `--port`, `--host`, `--no-browser`, and `--dev` flags. Defaults to the `studio` subcommand when run without arguments. +### Changed + +- **Middleware Protocol** -- Renamed `before`/`after` to `before_run`/`after_run` + on `PromptCacheMiddleware` and `CircuitBreakerMiddleware` to conform to the + `AgentMiddleware` protocol contract. +- **Exception Hierarchy** -- Renamed `MemoryError` to `FireflyMemoryError` to + avoid shadowing the Python built-in. A deprecated alias is kept for backwards + compatibility. +- **Quota Defaults** -- `quota_enabled` now defaults to `False` to avoid + unexpected enforcement on first install. +- **Cost Calculator Type** -- `cost_calculator` config field is now + `Literal["auto", "genai_prices", "static"]`. + +### Security + +- **ShellTool** -- Replaced `create_subprocess_shell` with + `create_subprocess_exec` to prevent command-injection via shell metacharacters. +- **FileSystemTool** -- Replaced `str.startswith` path check with + `Path.is_relative_to` to prevent symlink-based path traversal. +- **RBAC Decorator** -- Fixed `require_permission` to use `inspect.signature` + for positional argument binding and replaced `nonlocal` mutation with local + `manager` variable. +- **Encryption** -- Each `AESEncryptionProvider.encrypt()` call now generates a + random 16-byte salt for PBKDF2 key derivation, stored as + `salt[16]+nonce[12]+ciphertext+tag`. +- **REST Middleware** -- `allow_credentials` is now automatically set to `False` + when `allow_origins=["*"]`. API key comparison uses `hmac.compare_digest`. +- **REST Router** -- Exception details are no longer exposed to clients; errors + are logged server-side and a generic message is returned. 
+- **Database Store** -- Schema name is validated against `^[a-zA-Z_][a-zA-Z0-9_]*$` + to prevent SQL injection. +- **FileStore** -- Added `Path.is_relative_to` check in `_path()` to prevent + namespace-based path traversal. + +### Fixed + +- **Thread Safety** -- Added `threading.Lock` to `InMemoryStore`, `CachedTool`, + `RateLimitGuard`, `ConversationMemory.get_turns/get_total_tokens/clear/ + clear_all/new_conversation/conversation_ids`. +- **Pipeline Engine** -- `_gather_inputs` now correctly extracts `output_key` + from dict and object results. `started_at` is initialised before the retry + loop. +- **asyncio.run Crash** -- `database_store.py` and `manager.py` sync wrappers + now detect a running event loop and offload to a `ThreadPoolExecutor` instead + of crashing. +- **TextTool ReDoS** -- Regex operations in `_extract`, `_replace`, `_split` now + run via `asyncio.to_thread` with a 5-second timeout. +- **SandboxGuard ReDoS** -- User-supplied patterns are compiled with a safe + `_safe_compile` helper. +- **Observability Decorators** -- `@metered` now records latency in a `finally` + block so it is captured even on exceptions. +- **Logging** -- `ColoredFormatter.format` now operates on a `copy.copy(record)` + to avoid mutating shared log records. +- **SlidingWindowManager** -- Uses `collections.deque` and `_running_tokens` + counter instead of re-estimating the entire window on every eviction. +- **PromptTemplate** -- Added `_UNSET` sentinel for `PromptVariable.default` so + that `default=None` is correctly propagated. +- **Queue Consumers** -- Kafka, RabbitMQ, and Redis consumers now wrap + `_process_message` in try/except to prevent one bad message from killing the + consumer loop. +- **Goal Decomposition** -- `_execute_task` now passes `memory=memory` to the + delegated `_task_pattern.execute()`. +- **ConversationMemory** -- `clear()` and `clear_all()` now also clear + `_summaries` to prevent stale summary leaks. 
+- **Reasoning Registry** -- Six built-in patterns are auto-registered at import + time. +- **Observability Exports** -- `extract_trace_context`, `inject_trace_context`, + and `trace_context_scope` are now re-exported from `observability/__init__.py`. +- **UsageTracker** -- `_check_budget` exception handler now logs at DEBUG instead + of silently passing. + ## [26.02.07] - 2026-02-17 ### Added diff --git a/docs/security.md b/docs/security.md index d244a6ac..501a6b4c 100644 --- a/docs/security.md +++ b/docs/security.md @@ -422,6 +422,11 @@ memory = MemoryManager(store=encrypted_store) All data is encrypted before writing and decrypted after reading, with no changes to application code. +Each call to `encrypt()` generates a random 16-byte salt for PBKDF2 key +derivation and a random 12-byte nonce for AES-GCM. The ciphertext is stored as +`salt[16] + nonce[12] + ciphertext + tag`, ensuring that identical plaintexts +produce different ciphertexts. + ### Environment Configuration ```bash diff --git a/docs/tools.md b/docs/tools.md index 8f86a89e..e5c2d147 100644 --- a/docs/tools.md +++ b/docs/tools.md @@ -162,7 +162,7 @@ The framework ships with nine ready-to-use tools in `tools/builtins/`. - **TextTool** -- Text utilities: count (words/chars/sentences/lines), extract (regex), truncate, replace, and split. - **HttpTool** -- Make HTTP requests (GET, POST, PUT, DELETE). Uses `asyncio.to_thread` to keep the event loop non-blocking. - **FileSystemTool** -- Read, write, and list files within a sandboxed base directory. Path-traversal attacks are rejected. -- **ShellTool** -- Execute shell commands restricted to an explicit allow-list. Empty allow-list rejects all commands (safe default). +- **ShellTool** -- Execute shell commands restricted to an explicit allow-list using `create_subprocess_exec` (no shell metacharacter injection). Empty allow-list rejects all commands (safe default). 
### Abstract tools (subclass to use) diff --git a/src/fireflyframework_genai/__init__.py b/src/fireflyframework_genai/__init__.py index 76bb9eea..65755823 100644 --- a/src/fireflyframework_genai/__init__.py +++ b/src/fireflyframework_genai/__init__.py @@ -32,14 +32,18 @@ from fireflyframework_genai.exceptions import ( AgentError, AgentNotFoundError, + BudgetExceededError, ChunkingError, CompressionError, ConfigurationError, + DatabaseConnectionError, + DatabaseStoreError, DelegationError, ExperimentError, ExplainabilityError, ExposureError, FireflyGenAIError, + FireflyMemoryError, MemoryError, ObservabilityError, OutputReviewError, @@ -50,6 +54,8 @@ PromptValidationError, QoSError, QueueConnectionError, + QuotaError, + RateLimitError, ReasoningError, ReasoningPatternNotFoundError, ReasoningStepLimitError, @@ -85,6 +91,7 @@ "get_config", "reset_config", "FireflyGenAIError", + "FireflyMemoryError", "ConfigurationError", "AgentError", "AgentNotFoundError", @@ -110,6 +117,11 @@ "OutputValidationError", "PipelineError", "QoSError", + "QuotaError", + "BudgetExceededError", + "RateLimitError", + "DatabaseStoreError", + "DatabaseConnectionError", "MemoryError", "AgentDepsT", "AgentLike", diff --git a/src/fireflyframework_genai/agents/__init__.py b/src/fireflyframework_genai/agents/__init__.py index 9fb724df..d53c2c8d 100644 --- a/src/fireflyframework_genai/agents/__init__.py +++ b/src/fireflyframework_genai/agents/__init__.py @@ -26,8 +26,10 @@ OutputGuardMiddleware, PromptGuardError, PromptGuardMiddleware, + RetryMiddleware, ValidationMiddleware, ) +from fireflyframework_genai.agents.prompt_cache import CacheStatistics, PromptCacheMiddleware from fireflyframework_genai.agents.cache import ResultCache from fireflyframework_genai.agents.context import AgentContext from fireflyframework_genai.agents.decorators import firefly_agent @@ -64,6 +66,7 @@ "AgentRegistry", "BudgetExceededError", "CacheMiddleware", + "CacheStatistics", "CapabilityStrategy", "ContentBasedStrategy", 
"CostAwareStrategy", @@ -80,8 +83,10 @@ "OutputGuardError", "OutputGuardMiddleware", "PromptGuardError", + "PromptCacheMiddleware", "PromptGuardMiddleware", "ResultCache", + "RetryMiddleware", "RoundRobinStrategy", "ValidationMiddleware", "agent_registry", diff --git a/src/fireflyframework_genai/agents/delegation.py b/src/fireflyframework_genai/agents/delegation.py index c19b4246..6c192329 100644 --- a/src/fireflyframework_genai/agents/delegation.py +++ b/src/fireflyframework_genai/agents/delegation.py @@ -61,6 +61,7 @@ class RoundRobinStrategy: def __init__(self) -> None: # Lazily initialised so the cycle resets if the agent pool changes. self._cycle: itertools.cycle[FireflyAgent[Any, Any]] | None = None + self._last_agents: list[FireflyAgent[Any, Any]] = [] async def select( self, @@ -70,8 +71,9 @@ async def select( ) -> FireflyAgent[Any, Any]: if not agents: raise DelegationError("No agents available for delegation") - if self._cycle is None: + if self._cycle is None or self._last_agents != list(agents): self._cycle = itertools.cycle(agents) + self._last_agents = list(agents) return next(self._cycle) @@ -214,7 +216,7 @@ async def select( best_cost = float("inf") for agent in agents: - model_name = getattr(agent, "model_name", "") + model_name = getattr(agent, "model_name", "") or getattr(agent, "_model_identifier", "") cost = self._cost_tier(model_name) if model_name else 3 if cost < best_cost: best_cost = cost diff --git a/src/fireflyframework_genai/agents/prompt_cache.py b/src/fireflyframework_genai/agents/prompt_cache.py index 8ae88139..4ff5adf1 100644 --- a/src/fireflyframework_genai/agents/prompt_cache.py +++ b/src/fireflyframework_genai/agents/prompt_cache.py @@ -130,7 +130,7 @@ def __init__( self._cache_ttl_seconds = cache_ttl_seconds self._enabled = enabled - async def before(self, context: Any) -> None: + async def before_run(self, context: Any) -> None: """Configure prompt caching before agent execution. 
This method modifies the agent run parameters to enable provider-specific @@ -163,7 +163,7 @@ async def before(self, context: Any) -> None: family, ) - async def after(self, context: Any, result: Any) -> Any: + async def after_run(self, context: Any, result: Any) -> Any: """Record cache usage metrics after agent execution. Args: diff --git a/src/fireflyframework_genai/config.py b/src/fireflyframework_genai/config.py index e2dbf774..87f6921e 100644 --- a/src/fireflyframework_genai/config.py +++ b/src/fireflyframework_genai/config.py @@ -22,6 +22,7 @@ from __future__ import annotations import threading +from typing import Literal from pydantic import model_validator from pydantic_settings import BaseSettings, SettingsConfigDict @@ -109,7 +110,7 @@ class FireflyGenAIConfig(BaseSettings): budget_alert_threshold_usd: float | None = None """Soft alert threshold in USD. A warning is logged when reached.""" - cost_calculator: str = "auto" + cost_calculator: Literal["auto", "genai_prices", "static"] = "auto" """Cost calculator preference: ``"auto"``, ``"genai_prices"``, or ``"static"``.""" # -- Memory ------------------------------------------------------------- @@ -165,7 +166,7 @@ class FireflyGenAIConfig(BaseSettings): are evicted when this limit is reached (FIFO).""" # -- Quota & Rate Limiting ----------------------------------------------- - quota_enabled: bool = True + quota_enabled: bool = False """Whether API quota management and rate limiting is active.""" quota_budget_daily_usd: float | None = None diff --git a/src/fireflyframework_genai/content/chunking.py b/src/fireflyframework_genai/content/chunking.py index a4140468..e5f36ff8 100644 --- a/src/fireflyframework_genai/content/chunking.py +++ b/src/fireflyframework_genai/content/chunking.py @@ -131,12 +131,14 @@ def _chunk_by_token(self, content: str) -> list[Chunk]: step = max(1, words_per_chunk - overlap_words) chunks: list[Chunk] = [] + search_offset = 0 for start_idx in range(0, len(words), step): end_idx = 
min(start_idx + words_per_chunk, len(words)) chunk_words = words[start_idx:end_idx] text = " ".join(chunk_words) - char_start = content.index(chunk_words[0]) if chunk_words else 0 + char_start = content.index(chunk_words[0], search_offset) if chunk_words else search_offset char_end = char_start + len(text) + search_offset = char_start + 1 chunks.append( Chunk( content=text, @@ -238,7 +240,9 @@ def split(self, content: str) -> list[Chunk]: metadata={"type": "document_segment"}, ) ) - offset += len(part) + offset = start + len(stripped) + else: + offset += len(part) for c in chunks: c.total_chunks = len(chunks) diff --git a/src/fireflyframework_genai/content/compression.py b/src/fireflyframework_genai/content/compression.py index 3e709a11..753527d2 100644 --- a/src/fireflyframework_genai/content/compression.py +++ b/src/fireflyframework_genai/content/compression.py @@ -21,6 +21,7 @@ from __future__ import annotations import logging +from collections import deque from typing import Any, Protocol, runtime_checkable from fireflyframework_genai.content.chunking import TextChunker @@ -241,11 +242,14 @@ def __init__( ) -> None: self._max_tokens = max_tokens self._estimator = estimator or TokenEstimator() - self._segments: list[str] = [] + self._segments: deque[str] = deque() + self._running_tokens = 0 def add(self, segment: str) -> None: """Append a new segment to the window, evicting oldest if needed.""" + seg_tokens = self._estimator.estimate(segment) self._segments.append(segment) + self._running_tokens += seg_tokens self._evict() def get_context(self) -> str: @@ -258,12 +262,14 @@ def segment_count(self) -> int: @property def estimated_tokens(self) -> int: - return self._estimator.estimate(self.get_context()) if self._segments else 0 + return self._running_tokens if self._segments else 0 def clear(self) -> None: self._segments.clear() + self._running_tokens = 0 def _evict(self) -> None: """Remove oldest segments until the window fits.""" - while len(self._segments) > 1 
and self._estimator.estimate(self.get_context()) > self._max_tokens: - self._segments.pop(0) + while len(self._segments) > 1 and self._running_tokens > self._max_tokens: + removed = self._segments.popleft() + self._running_tokens -= self._estimator.estimate(removed) diff --git a/src/fireflyframework_genai/exceptions.py b/src/fireflyframework_genai/exceptions.py index 2a1dbefa..9e8393a9 100644 --- a/src/fireflyframework_genai/exceptions.py +++ b/src/fireflyframework_genai/exceptions.py @@ -165,11 +165,15 @@ class PipelineError(FireflyGenAIError): # -- Memory ------------------------------------------------------------------ -class MemoryError(FireflyGenAIError): +class FireflyMemoryError(FireflyGenAIError): """Raised for errors during memory storage, retrieval, or management.""" -class DatabaseStoreError(MemoryError): +# Deprecated alias for backwards compatibility +MemoryError = FireflyMemoryError + + +class DatabaseStoreError(FireflyMemoryError): """Raised for errors in database-backed memory store operations.""" diff --git a/src/fireflyframework_genai/exposure/queues/kafka.py b/src/fireflyframework_genai/exposure/queues/kafka.py index 2c9f0fd9..a843e64e 100644 --- a/src/fireflyframework_genai/exposure/queues/kafka.py +++ b/src/fireflyframework_genai/exposure/queues/kafka.py @@ -82,7 +82,11 @@ async def start(self) -> None: # Process message within trace context scope with trace_context_scope(span_context): - await self._process_message(message) + try: + await self._process_message(message) + except Exception: + logger.exception("Failed to process Kafka message on topic '%s'", self._topic) + continue finally: await self.stop() diff --git a/src/fireflyframework_genai/exposure/queues/rabbitmq.py b/src/fireflyframework_genai/exposure/queues/rabbitmq.py index 147e847c..a1bfd70d 100644 --- a/src/fireflyframework_genai/exposure/queues/rabbitmq.py +++ b/src/fireflyframework_genai/exposure/queues/rabbitmq.py @@ -80,7 +80,11 @@ async def start(self) -> None: # Process 
message within trace context scope with trace_context_scope(span_context): - await self._process_message(message) + try: + await self._process_message(message) + except Exception: + logger.exception("Failed to process RabbitMQ message on queue '%s'", self._queue_name) + continue async def stop(self) -> None: """Stop the RabbitMQ consumer.""" diff --git a/src/fireflyframework_genai/exposure/queues/redis.py b/src/fireflyframework_genai/exposure/queues/redis.py index 13670847..3cbfabfc 100644 --- a/src/fireflyframework_genai/exposure/queues/redis.py +++ b/src/fireflyframework_genai/exposure/queues/redis.py @@ -93,7 +93,11 @@ async def start(self) -> None: # Process message within trace context scope with trace_context_scope(span_context): - await self._process_message(message) + try: + await self._process_message(message) + except Exception: + logger.exception("Failed to process Redis message on channel '%s'", self._channel) + continue finally: await self.stop() diff --git a/src/fireflyframework_genai/exposure/rest/middleware.py b/src/fireflyframework_genai/exposure/rest/middleware.py index 11da601a..d295f0d9 100644 --- a/src/fireflyframework_genai/exposure/rest/middleware.py +++ b/src/fireflyframework_genai/exposure/rest/middleware.py @@ -16,6 +16,7 @@ from __future__ import annotations +import hmac import logging import time import uuid @@ -83,7 +84,7 @@ def add_cors_middleware( app.add_middleware( CORSMiddleware, allow_origins=allow_origins, - allow_credentials=True, + allow_credentials="*" not in allow_origins, allow_methods=allow_methods or ["GET", "POST", "PUT", "DELETE", "OPTIONS"], allow_headers=["*"], ) @@ -105,6 +106,11 @@ def __init__(self, max_requests: int = 60, window_seconds: float = 60.0) -> None def is_allowed(self, key: str) -> bool: """Return *True* if the request is within the rate limit.""" now = time.monotonic() + # Cleanup stale entries to prevent unbounded memory growth + if len(self._timestamps) > 10000: + stale_keys = [k for k, v in 
self._timestamps.items() if not v or now - v[-1] > self._window] + for k in stale_keys: + del self._timestamps[k] ts = self._timestamps.setdefault(key, []) ts[:] = [t for t in ts if now - t < self._window] if len(ts) >= self._max: @@ -156,7 +162,7 @@ async def dispatch(self, request: Request, call_next: Any) -> Response: # Try API key if _api_keys: key = request.headers.get(api_key_header, "") - if key in _api_keys: + if any(hmac.compare_digest(key, k) for k in _api_keys): return await call_next(request) # Try bearer token @@ -164,7 +170,7 @@ async def dispatch(self, request: Request, call_next: Any) -> Response: auth_value = request.headers.get(auth_header, "") if auth_value.startswith("Bearer "): token = auth_value[7:] - if token in _bearer_tokens: + if any(hmac.compare_digest(token, t) for t in _bearer_tokens): return await call_next(request) # If no auth methods configured, allow all requests diff --git a/src/fireflyframework_genai/exposure/rest/router.py b/src/fireflyframework_genai/exposure/rest/router.py index 7d65db81..dbbe8e00 100644 --- a/src/fireflyframework_genai/exposure/rest/router.py +++ b/src/fireflyframework_genai/exposure/rest/router.py @@ -20,11 +20,14 @@ from __future__ import annotations +import logging from typing import TYPE_CHECKING, Any if TYPE_CHECKING: from fastapi import APIRouter # type: ignore[import-not-found] +logger = logging.getLogger(__name__) + from fireflyframework_genai.agents.registry import agent_registry from fireflyframework_genai.exposure.rest.schemas import AgentRequest, AgentResponse from fireflyframework_genai.exposure.rest.streaming import sse_stream, sse_stream_incremental @@ -76,15 +79,13 @@ async def run_agent(name: str, request: AgentRequest) -> AgentResponse: agent = agent_registry.get(name) try: prompt = _resolve_prompt(request) - # Wire memory for conversational requests conv_id = request.conversation_id - if conv_id is not None and agent.memory is None: - agent.memory = _rest_memory result = await 
agent.run(prompt, deps=request.deps, conversation_id=conv_id) output = result.output if hasattr(result, "output") else str(result) return AgentResponse(agent_name=name, output=output) except Exception as exc: - return AgentResponse(agent_name=name, output=None, success=False, error=str(exc)) + logger.exception("Agent '%s' run failed", name) + return AgentResponse(agent_name=name, output=None, success=False, error="Internal server error") @router.post("/{name}/stream") async def stream_agent(name: str, request: AgentRequest) -> Any: @@ -100,8 +101,6 @@ async def stream_agent(name: str, request: AgentRequest) -> Any: agent = agent_registry.get(name) prompt = _resolve_prompt(request) conv_id = request.conversation_id - if conv_id is not None and agent.memory is None: - agent.memory = _rest_memory return StreamingResponse( sse_stream(agent, prompt, deps=request.deps, conversation_id=conv_id), media_type="text/event-stream", @@ -131,8 +130,6 @@ async def stream_agent_incremental( agent = agent_registry.get(name) prompt = _resolve_prompt(request) conv_id = request.conversation_id - if conv_id is not None and agent.memory is None: - agent.memory = _rest_memory return StreamingResponse( sse_stream_incremental( agent, diff --git a/src/fireflyframework_genai/exposure/rest/websocket.py b/src/fireflyframework_genai/exposure/rest/websocket.py index a2182d77..45479b36 100644 --- a/src/fireflyframework_genai/exposure/rest/websocket.py +++ b/src/fireflyframework_genai/exposure/rest/websocket.py @@ -97,11 +97,6 @@ async def agent_ws(websocket: WebSocket, name: str) -> None: {"type": "conversation_id", "data": conversation_id}, ) - # Set per-connection memory only if the agent doesn't already - # have one configured by the user. 
- if agent.memory is None: - agent.memory = conn_memory - deps = msg.get("deps") # Attempt streaming; if it fails, report the error rather than @@ -123,8 +118,9 @@ async def agent_ws(websocket: WebSocket, name: str) -> None: {"type": "token", "data": token}, ) final = "".join(full_output) - except Exception: + except Exception as exc: # Streaming not supported or failed — fall back + logger.debug("Streaming failed for '%s': %s", name, exc) final = None if final is None: diff --git a/src/fireflyframework_genai/logging.py b/src/fireflyframework_genai/logging.py index 22d84692..1cb8decc 100644 --- a/src/fireflyframework_genai/logging.py +++ b/src/fireflyframework_genai/logging.py @@ -35,6 +35,7 @@ from __future__ import annotations +import copy import json import logging import sys @@ -114,6 +115,7 @@ def __init__( super().__init__(fmt or _DEFAULT_FORMAT, datefmt or _DEFAULT_DATEFMT) def format(self, record: logging.LogRecord) -> str: # noqa: C901 + record = copy.copy(record) # Coloured level badge lvl = record.levelname color = _LEVEL_COLORS.get(lvl, "") diff --git a/src/fireflyframework_genai/memory/conversation.py b/src/fireflyframework_genai/memory/conversation.py index 409c846a..e930dc34 100644 --- a/src/fireflyframework_genai/memory/conversation.py +++ b/src/fireflyframework_genai/memory/conversation.py @@ -161,30 +161,38 @@ def _evict_oldest( def get_turns(self, conversation_id: str) -> list[ConversationTurn]: """Return all turns for a conversation (unfiltered).""" - return list(self._conversations.get(conversation_id, [])) + with self._lock: + return list(self._conversations.get(conversation_id, [])) def get_total_tokens(self, conversation_id: str) -> int: """Return the total estimated token count for a conversation.""" - return sum(t.token_estimate for t in self._conversations.get(conversation_id, [])) + with self._lock: + return sum(t.token_estimate for t in self._conversations.get(conversation_id, [])) def clear(self, conversation_id: str) -> None: 
"""Remove all turns for a conversation.""" - self._conversations.pop(conversation_id, None) + with self._lock: + self._conversations.pop(conversation_id, None) + self._summaries.pop(conversation_id, None) def clear_all(self) -> None: """Remove all conversations.""" - self._conversations.clear() + with self._lock: + self._conversations.clear() + self._summaries.clear() def new_conversation(self) -> str: """Create a new conversation and return its ID.""" - cid = uuid.uuid4().hex - self._conversations[cid] = [] - return cid + with self._lock: + cid = uuid.uuid4().hex + self._conversations[cid] = [] + return cid @property def conversation_ids(self) -> list[str]: """Return all active conversation IDs.""" - return list(self._conversations.keys()) + with self._lock: + return list(self._conversations.keys()) @property def max_tokens(self) -> int: diff --git a/src/fireflyframework_genai/memory/database_store.py b/src/fireflyframework_genai/memory/database_store.py index 583e7ed2..b68fba62 100644 --- a/src/fireflyframework_genai/memory/database_store.py +++ b/src/fireflyframework_genai/memory/database_store.py @@ -48,14 +48,37 @@ import asyncio import logging +import re +from concurrent.futures import ThreadPoolExecutor from datetime import UTC, datetime from typing import Any +_SAFE_IDENTIFIER = re.compile(r"^[a-zA-Z_][a-zA-Z0-9_]*$") + +_sync_pool = ThreadPoolExecutor(max_workers=4) + from fireflyframework_genai.exceptions import DatabaseConnectionError, DatabaseStoreError from fireflyframework_genai.memory.types import MemoryEntry logger = logging.getLogger(__name__) + +def _run_sync(coro: Any) -> Any: + """Run *coro* synchronously, safe even when an event loop is already running.""" + try: + loop = asyncio.get_running_loop() + except RuntimeError: + loop = None + + if loop is not None: + # Already inside an event loop -- offload to a background thread. 
+ import concurrent.futures + + future = _sync_pool.submit(asyncio.run, coro) + return future.result() + return asyncio.run(coro) + + # -- PostgreSQL Store ------------------------------------------------------- @@ -87,6 +110,8 @@ def __init__( timeout: float = 30.0, schema_name: str = "firefly_memory", ) -> None: + if not _SAFE_IDENTIFIER.match(schema_name): + raise ValueError(f"Invalid schema_name: {schema_name!r}. Must be a valid SQL identifier.") self._url = url self._pool_size = pool_size self._pool_min_size = pool_min_size @@ -181,7 +206,7 @@ def save(self, namespace: str, entry: MemoryEntry) -> None: This is the synchronous wrapper that runs the async version in a thread. """ - asyncio.run(self.async_save(namespace, entry)) + _run_sync(self.async_save(namespace, entry)) async def async_save(self, namespace: str, entry: MemoryEntry) -> None: """Async version of :meth:`save`.""" @@ -218,7 +243,7 @@ async def async_save(self, namespace: str, entry: MemoryEntry) -> None: def load(self, namespace: str) -> list[MemoryEntry]: """Return all non-expired entries stored under *namespace*.""" - return asyncio.run(self.async_load(namespace)) + return _run_sync(self.async_load(namespace)) async def async_load(self, namespace: str) -> list[MemoryEntry]: """Async version of :meth:`load`.""" @@ -244,7 +269,7 @@ async def async_load(self, namespace: str) -> list[MemoryEntry]: def load_by_key(self, namespace: str, key: str) -> MemoryEntry | None: """Return the entry matching *key*, or *None*.""" - return asyncio.run(self.async_load_by_key(namespace, key)) + return _run_sync(self.async_load_by_key(namespace, key)) async def async_load_by_key(self, namespace: str, key: str) -> MemoryEntry | None: """Async version of :meth:`load_by_key`.""" @@ -276,7 +301,7 @@ async def async_load_by_key(self, namespace: str, key: str) -> MemoryEntry | Non def delete(self, namespace: str, entry_id: str) -> None: """Remove a single entry by ID.""" - asyncio.run(self.async_delete(namespace, 
entry_id)) + _run_sync(self.async_delete(namespace, entry_id)) async def async_delete(self, namespace: str, entry_id: str) -> None: """Async version of :meth:`delete`.""" @@ -298,7 +323,7 @@ async def async_delete(self, namespace: str, entry_id: str) -> None: def clear(self, namespace: str) -> None: """Remove all entries in *namespace*.""" - asyncio.run(self.async_clear(namespace)) + _run_sync(self.async_clear(namespace)) async def async_clear(self, namespace: str) -> None: """Async version of :meth:`clear`.""" @@ -455,7 +480,7 @@ async def _create_indexes(self) -> None: def save(self, namespace: str, entry: MemoryEntry) -> None: """Persist a single :class:`MemoryEntry` under *namespace*.""" - asyncio.run(self.async_save(namespace, entry)) + _run_sync(self.async_save(namespace, entry)) async def async_save(self, namespace: str, entry: MemoryEntry) -> None: """Async version of :meth:`save`.""" @@ -477,7 +502,7 @@ async def async_save(self, namespace: str, entry: MemoryEntry) -> None: def load(self, namespace: str) -> list[MemoryEntry]: """Return all non-expired entries stored under *namespace*.""" - return asyncio.run(self.async_load(namespace)) + return _run_sync(self.async_load(namespace)) async def async_load(self, namespace: str) -> list[MemoryEntry]: """Async version of :meth:`load`.""" @@ -503,7 +528,7 @@ async def async_load(self, namespace: str) -> list[MemoryEntry]: def load_by_key(self, namespace: str, key: str) -> MemoryEntry | None: """Return the entry matching *key*, or *None*.""" - return asyncio.run(self.async_load_by_key(namespace, key)) + return _run_sync(self.async_load_by_key(namespace, key)) async def async_load_by_key(self, namespace: str, key: str) -> MemoryEntry | None: """Async version of :meth:`load_by_key`.""" @@ -533,7 +558,7 @@ async def async_load_by_key(self, namespace: str, key: str) -> MemoryEntry | Non def delete(self, namespace: str, entry_id: str) -> None: """Remove a single entry by ID.""" - 
asyncio.run(self.async_delete(namespace, entry_id))
+        _run_sync(self.async_delete(namespace, entry_id))

     async def async_delete(self, namespace: str, entry_id: str) -> None:
         """Async version of :meth:`delete`."""
@@ -547,7 +572,7 @@ async def async_delete(self, namespace: str, entry_id: str) -> None:

     def clear(self, namespace: str) -> None:
         """Remove all entries in *namespace*."""
-        asyncio.run(self.async_clear(namespace))
+        _run_sync(self.async_clear(namespace))

     async def async_clear(self, namespace: str) -> None:
         """Async version of :meth:`clear`."""
diff --git a/src/fireflyframework_genai/memory/manager.py b/src/fireflyframework_genai/memory/manager.py
index 474355ac..6b0a371e 100644
--- a/src/fireflyframework_genai/memory/manager.py
+++ b/src/fireflyframework_genai/memory/manager.py
@@ -87,9 +87,24 @@ def from_config(cls) -> MemoryManager:
             ``pip install fireflyframework-genai[mongodb]``
         """
         import asyncio
+        from concurrent.futures import ThreadPoolExecutor

         from fireflyframework_genai.config import get_config

+        def _run_sync(coro: object) -> object:
+            """Run *coro* synchronously, safe even when an event loop is already running."""
+            try:
+                loop = asyncio.get_running_loop()
+            except RuntimeError:
+                loop = None
+            if loop is not None:
+                pool = ThreadPoolExecutor(max_workers=1)
+                try:
+                    return pool.submit(asyncio.run, coro).result()
+                finally:
+                    pool.shutdown(wait=False)
+            return asyncio.run(coro)
+
         cfg = get_config()

         store: MemoryStore
@@ -114,7 +129,7 @@ def from_config(cls) -> MemoryManager:
                 schema_name=cfg.memory_postgres_schema,
             )
             # Initialize the database connection pool
-            asyncio.run(store.initialize())
+            _run_sync(store.initialize())
             logger.info("PostgreSQL memory backend initialized")

         elif cfg.memory_backend == "mongodb":
@@ -133,7 +148,7 @@ def from_config(cls) -> MemoryManager:
                 pool_size=cfg.memory_mongodb_pool_size,
             )
             # Initialize the database connection pool
-            asyncio.run(store.initialize())
+            _run_sync(store.initialize())
             logger.info("MongoDB memory backend initialized")

         else:
diff --git a/src/fireflyframework_genai/memory/store.py b/src/fireflyframework_genai/memory/store.py
index cbc21722..3e7963e6 100644
--- a/src/fireflyframework_genai/memory/store.py
+++ b/src/fireflyframework_genai/memory/store.py
@@ -24,6 +24,7 @@
 import asyncio
 import json
 import logging
+import threading
 from collections import defaultdict
 from pathlib import Path
 from typing import Any, Protocol, runtime_checkable
@@ -71,25 +72,31 @@ class InMemoryStore:

     def __init__(self) -> None:
         self._data: dict[str, dict[str, MemoryEntry]] = defaultdict(dict)
+        self._lock = threading.Lock()

     def save(self, namespace: str, entry: MemoryEntry) -> None:
-        self._data[namespace][entry.entry_id] = entry
+        with self._lock:
+            self._data[namespace][entry.entry_id] = entry

     def load(self, namespace: str) -> list[MemoryEntry]:
-        entries = list(self._data.get(namespace, {}).values())
+        with self._lock:
+            entries = list(self._data.get(namespace, {}).values())
         return [e for e in entries if not e.is_expired]

     def load_by_key(self, namespace: str, key: str) -> MemoryEntry | None:
-        for entry in self._data.get(namespace, {}).values():
-            if entry.key == key and not entry.is_expired:
-                return entry
+        with self._lock:
+            for entry in self._data.get(namespace, {}).values():
+                if entry.key == key and not entry.is_expired:
+                    return entry
         return None

     def delete(self, namespace: str, entry_id: str) -> None:
-        self._data.get(namespace, {}).pop(entry_id, None)
+        with self._lock:
+            self._data.get(namespace, {}).pop(entry_id, None)

     def clear(self, namespace: str) -> None:
-        self._data.pop(namespace, None)
+        with self._lock:
+            self._data.pop(namespace, None)

     @property
     def namespaces(self) -> list[str]:
@@ -112,7 +119,10 @@ def __init__(self, base_dir: str | Path = ".firefly_memory") -> None:

     def _path(self, namespace: str) -> Path:
         safe_name = namespace.replace("/", "_").replace("\\", "_")
-        return self._base_dir / f"{safe_name}.json"
+        resolved = (self._base_dir / f"{safe_name}.json").resolve()
+        if not resolved.is_relative_to(self._base_dir.resolve()):
+            raise ValueError(f"Path traversal detected in namespace: {namespace!r}")
+        return resolved

     def _read(self, namespace: str) -> dict[str, Any]:
         path = self._path(namespace)
diff --git a/src/fireflyframework_genai/observability/__init__.py b/src/fireflyframework_genai/observability/__init__.py
index d8926fbf..1045f7aa 100644
--- a/src/fireflyframework_genai/observability/__init__.py
+++ b/src/fireflyframework_genai/observability/__init__.py
@@ -24,7 +24,13 @@
 from fireflyframework_genai.observability.events import FireflyEvent, FireflyEvents, default_events
 from fireflyframework_genai.observability.exporters import configure_exporters
 from fireflyframework_genai.observability.metrics import FireflyMetrics, default_metrics
-from fireflyframework_genai.observability.tracer import FireflyTracer, default_tracer
+from fireflyframework_genai.observability.tracer import (
+    FireflyTracer,
+    default_tracer,
+    extract_trace_context,
+    inject_trace_context,
+    trace_context_scope,
+)
 from fireflyframework_genai.observability.usage import (
     UsageRecord,
     UsageSummary,
@@ -48,7 +54,10 @@
     "default_metrics",
     "default_tracer",
     "default_usage_tracker",
+    "extract_trace_context",
     "get_cost_calculator",
+    "inject_trace_context",
     "metered",
+    "trace_context_scope",
     "traced",
 ]
diff --git a/src/fireflyframework_genai/observability/decorators.py b/src/fireflyframework_genai/observability/decorators.py
index 3bdcc70b..452220e3 100644
--- a/src/fireflyframework_genai/observability/decorators.py
+++ b/src/fireflyframework_genai/observability/decorators.py
@@ -87,24 +87,26 @@ async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
             start = time.perf_counter()
             try:
                 result = await func(*args, **kwargs)
-                elapsed = (time.perf_counter() - start) * 1000
-                default_metrics.record_latency(elapsed, operation=op_name)
                 return result
             except Exception:
                 default_metrics.record_error(operation=op_name)
                 raise
+            finally:
+                elapsed = (time.perf_counter() - start) * 1000
+                default_metrics.record_latency(elapsed, operation=op_name)

         @functools.wraps(func)
         def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
             start = time.perf_counter()
             try:
                 result = func(*args, **kwargs)
-                elapsed = (time.perf_counter() - start) * 1000
-                default_metrics.record_latency(elapsed, operation=op_name)
                 return result
             except Exception:
                 default_metrics.record_error(operation=op_name)
                 raise
+            finally:
+                elapsed = (time.perf_counter() - start) * 1000
+                default_metrics.record_latency(elapsed, operation=op_name)

         import asyncio

diff --git a/src/fireflyframework_genai/observability/usage.py b/src/fireflyframework_genai/observability/usage.py
index 17e99259..21ace612 100644
--- a/src/fireflyframework_genai/observability/usage.py
+++ b/src/fireflyframework_genai/observability/usage.py
@@ -282,7 +282,7 @@ def _check_budget(self, usage: UsageRecord) -> None:
                     usage.model,
                 )
         except Exception:  # noqa: BLE001
-            pass
+            logger.debug("Failed to check budget", exc_info=True)


 def _create_default_tracker() -> UsageTracker:
diff --git a/src/fireflyframework_genai/pipeline/__init__.py b/src/fireflyframework_genai/pipeline/__init__.py
index 6c33c97c..1dfdc189 100644
--- a/src/fireflyframework_genai/pipeline/__init__.py
+++ b/src/fireflyframework_genai/pipeline/__init__.py
@@ -26,6 +26,7 @@
 from fireflyframework_genai.pipeline.result import ExecutionTraceEntry, NodeResult, PipelineResult
 from fireflyframework_genai.pipeline.steps import (
     AgentStep,
+    BatchLLMStep,
     BranchStep,
     CallableStep,
     FanInStep,
@@ -36,6 +37,7 @@

 __all__ = [
     "AgentStep",
+    "BatchLLMStep",
     "BranchStep",
     "CallableStep",
     "DAG",
diff --git a/src/fireflyframework_genai/pipeline/builder.py b/src/fireflyframework_genai/pipeline/builder.py
index 57339c23..242fbc36 100644
--- a/src/fireflyframework_genai/pipeline/builder.py
+++ b/src/fireflyframework_genai/pipeline/builder.py
@@ -147,4 +147,4 @@ def _resolve_step(step: Any) -> Any:
     if asyncio.iscoroutinefunction(step):
         return CallableStep(step)

-    return step
+    raise TypeError(f"Cannot resolve {type(step).__name__} as a pipeline step. Must be StepExecutor, agent-like, or async callable.")
diff --git a/src/fireflyframework_genai/pipeline/engine.py b/src/fireflyframework_genai/pipeline/engine.py
index 1d96d2b9..f378bc82 100644
--- a/src/fireflyframework_genai/pipeline/engine.py
+++ b/src/fireflyframework_genai/pipeline/engine.py
@@ -269,6 +269,7 @@ async def _execute_node(
         backoff_factor = node.backoff_factor
         retries = 0
         last_error: str | None = None
+        started_at = datetime.now(UTC)

         while retries <= max_retries:
             started_at = datetime.now(UTC)
@@ -323,7 +324,7 @@ async def _execute_node(
         trace_entries.append(
             ExecutionTraceEntry(
                 node_id=node_id,
-                started_at=started_at,  # type: ignore[possibly-undefined]
+                started_at=started_at,
                 completed_at=completed_at,
                 status="failed",
             )
@@ -406,6 +407,13 @@ def _gather_inputs(self, node_id: str, context: PipelineContext) -> dict[str, An
         for edge in edges:
             upstream_result = context.get_node_result(edge.source)
             if upstream_result is not None:
-                value = upstream_result.output if hasattr(upstream_result, "output") else upstream_result
+                raw = upstream_result.output if hasattr(upstream_result, "output") else upstream_result
+                if edge.output_key and edge.output_key != "output":
+                    if isinstance(raw, dict):
+                        value = raw.get(edge.output_key, raw)
+                    else:
+                        value = getattr(raw, edge.output_key, raw)
+                else:
+                    value = raw
                 inputs[edge.input_key] = value
         return inputs
diff --git a/src/fireflyframework_genai/prompts/template.py b/src/fireflyframework_genai/prompts/template.py
index 57c7ebcf..a1783291 100644
--- a/src/fireflyframework_genai/prompts/template.py
+++ b/src/fireflyframework_genai/prompts/template.py
@@ -33,6 +33,9 @@
 # Shared Jinja2 environment with safe defaults
 _jinja_env = Environment(loader=BaseLoader(), autoescape=False, keep_trailing_newline=True)

+# Sentinel value to distinguish "no default" from ``default=None``.
+_UNSET = object()
+

 class PromptVariable(BaseModel):
     """Describes a single variable expected by a prompt template."""
@@ -40,7 +43,7 @@ class PromptVariable(BaseModel):
     name: str
     description: str = ""
     required: bool = True
-    default: Any = None
+    default: Any = _UNSET


 class PromptInfo(BaseModel):
@@ -129,7 +132,7 @@ def render(self, **kwargs: Any) -> str:
         for var in self._variables:
             if var.name in kwargs:
                 merged[var.name] = kwargs[var.name]
-            elif var.default is not None:
+            elif var.default is not _UNSET:
                 merged[var.name] = var.default
         # Include any extra kwargs not declared as variables
         for k, v in kwargs.items():
diff --git a/src/fireflyframework_genai/reasoning/goal_decomposition.py b/src/fireflyframework_genai/reasoning/goal_decomposition.py
index 8ab2e831..cfe5549f 100644
--- a/src/fireflyframework_genai/reasoning/goal_decomposition.py
+++ b/src/fireflyframework_genai/reasoning/goal_decomposition.py
@@ -200,7 +200,7 @@ async def _execute_task(
     ) -> str:
         """Execute a single task, optionally delegating to a sub-pattern."""
         if self._task_pattern is not None:
-            sub_result: ReasoningResult = await self._task_pattern.execute(agent, task, **kwargs)
+            sub_result: ReasoningResult = await self._task_pattern.execute(agent, task, memory=memory, **kwargs)
             return str(sub_result.output)
         template = self._get_prompt("execute_task", GOAL_TASK_EXECUTION_PROMPT)
         prompt = template.render(goal=goal, task=task)
diff --git a/src/fireflyframework_genai/reasoning/registry.py b/src/fireflyframework_genai/reasoning/registry.py
index 534b804f..cd5df350 100644
--- a/src/fireflyframework_genai/reasoning/registry.py
+++ b/src/fireflyframework_genai/reasoning/registry.py
@@ -69,3 +69,30 @@ def __len__(self) -> int:

 # Module-level singleton
 reasoning_registry = ReasoningPatternRegistry()
+
+
+def _auto_register_builtins() -> None:
+    """Lazily register the six built-in reasoning patterns."""
+    try:
+        from fireflyframework_genai.reasoning.chain_of_thought import ChainOfThoughtPattern
+        from fireflyframework_genai.reasoning.goal_decomposition import GoalDecompositionPattern
+        from fireflyframework_genai.reasoning.plan_and_execute import PlanAndExecutePattern
+        from fireflyframework_genai.reasoning.react import ReActPattern
+        from fireflyframework_genai.reasoning.reflexion import ReflexionPattern
+        from fireflyframework_genai.reasoning.tree_of_thoughts import TreeOfThoughtsPattern
+
+        for name, cls in [
+            ("react", ReActPattern),
+            ("chain_of_thought", ChainOfThoughtPattern),
+            ("plan_and_execute", PlanAndExecutePattern),
+            ("reflexion", ReflexionPattern),
+            ("tree_of_thoughts", TreeOfThoughtsPattern),
+            ("goal_decomposition", GoalDecompositionPattern),
+        ]:
+            if not reasoning_registry.has(name):
+                reasoning_registry.register(name, cls)
+    except Exception:  # noqa: BLE001
+        logger.debug("Could not auto-register built-in reasoning patterns", exc_info=True)
+
+
+_auto_register_builtins()
diff --git a/src/fireflyframework_genai/resilience/circuit_breaker.py b/src/fireflyframework_genai/resilience/circuit_breaker.py
index 25930ecd..c9109b1e 100644
--- a/src/fireflyframework_genai/resilience/circuit_breaker.py
+++ b/src/fireflyframework_genai/resilience/circuit_breaker.py
@@ -352,7 +352,7 @@ def __init__(
         else:
             self._breaker = None

-    async def before(self, context: Any) -> None:
+    async def before_run(self, context: Any) -> None:
         """Check circuit breaker before agent execution.

         Args:
@@ -372,7 +372,7 @@ async def before(self, context: Any) -> None:
         # Enter circuit breaker (will raise if open)
         await self._breaker.__aenter__()

-    async def after(self, context: Any, result: Any) -> Any:
+    async def after_run(self, context: Any, result: Any) -> Any:
         """Update circuit breaker after agent execution.

         Args:
@@ -390,6 +390,19 @@ async def after(self, context: Any, result: Any) -> Any:

         return result

+    async def on_error(self, context: Any, error: Exception) -> None:
+        """Update circuit breaker on agent execution error.
+
+        Args:
+            context: Middleware context.
+            error: The exception that caused the failure.
+        """
+        if not self._enabled or self._breaker is None:
+            return
+
+        # Record failure in the circuit breaker
+        await self._breaker.__aexit__(type(error), error, None)
+
     def get_metrics(self) -> dict[str, Any]:
         """Get circuit breaker metrics.

diff --git a/src/fireflyframework_genai/security/__init__.py b/src/fireflyframework_genai/security/__init__.py
index c8a97dc1..e5118c51 100644
--- a/src/fireflyframework_genai/security/__init__.py
+++ b/src/fireflyframework_genai/security/__init__.py
@@ -21,12 +21,18 @@
 """

 from fireflyframework_genai.security.encryption import AESEncryptionProvider, EncryptedMemoryStore, EncryptionProvider
+from fireflyframework_genai.security.output_guard import OutputGuard, default_output_guard
+from fireflyframework_genai.security.prompt_guard import PromptGuard, default_prompt_guard
 from fireflyframework_genai.security.rbac import RBACManager, require_permission

 __all__ = [
-    "RBACManager",
-    "require_permission",
-    "EncryptionProvider",
     "AESEncryptionProvider",
     "EncryptedMemoryStore",
+    "EncryptionProvider",
+    "OutputGuard",
+    "PromptGuard",
+    "RBACManager",
+    "default_output_guard",
+    "default_prompt_guard",
+    "require_permission",
 ]
diff --git a/src/fireflyframework_genai/security/encryption.py b/src/fireflyframework_genai/security/encryption.py
index 734f574e..0050d850 100644
--- a/src/fireflyframework_genai/security/encryption.py
+++ b/src/fireflyframework_genai/security/encryption.py
@@ -54,6 +54,7 @@

 import base64
 import logging
+import os
 from typing import Any, Protocol, runtime_checkable

 from fireflyframework_genai.memory.types import MemoryEntry
@@ -132,21 +133,29 @@ def __init__(self, key: str | bytes) -> None:
                 "Encryption support requires 'cryptography'. Install with: pip install fireflyframework-genai[security]"
             ) from exc

-        # Derive 32-byte key if needed
-        key_bytes = key.encode("utf-8") if isinstance(key, str) else key
-
-        if len(key_bytes) != 32:
-            # Use PBKDF2 to derive a 32-byte key from the password
-            logger.debug("Deriving 32-byte key from provided key using PBKDF2")
-            kdf = PBKDF2HMAC(
-                algorithm=hashes.SHA256(),
-                length=32,
-                salt=b"firefly_genai_salt",  # Fixed salt (not ideal for high security)
-                iterations=100_000,
-            )
-            key_bytes = kdf.derive(key_bytes)
-
-        self._cipher = AESGCM(key_bytes)
+        # Store raw key bytes for per-call salt derivation
+        self._raw_key = key.encode("utf-8") if isinstance(key, str) else key
+        self._AESGCM = AESGCM
+        self._PBKDF2HMAC = PBKDF2HMAC
+        self._hashes = hashes
+
+        # If the key is exactly 32 bytes, use it directly (pre-derived)
+        if len(self._raw_key) == 32:
+            self._direct_key = self._raw_key
+        else:
+            self._direct_key = None
+
+    def _derive_key(self, salt: bytes) -> bytes:
+        """Derive a 32-byte key from the raw key using PBKDF2 with the given salt."""
+        if self._direct_key is not None:
+            return self._direct_key
+        kdf = self._PBKDF2HMAC(
+            algorithm=self._hashes.SHA256(),
+            length=32,
+            salt=salt,
+            iterations=100_000,
+        )
+        return kdf.derive(self._raw_key)

     def encrypt(self, plaintext: str) -> str:
         """Encrypt plaintext using AES-256-GCM.
@@ -155,19 +164,22 @@ def encrypt(self, plaintext: str) -> str:
             plaintext: String to encrypt.

         Returns:
-            Base64-encoded string containing: nonce + ciphertext + tag
+            Base64-encoded string containing: salt[16] + nonce[12] + ciphertext + tag
         """
-        import os
+        # Generate random 16-byte salt for PBKDF2 key derivation
+        salt = os.urandom(16)
+        key_bytes = self._derive_key(salt)
+        cipher = self._AESGCM(key_bytes)

         # Generate random 12-byte nonce (recommended for GCM)
         nonce = os.urandom(12)

         # Encrypt (returns ciphertext + 16-byte authentication tag)
         plaintext_bytes = plaintext.encode("utf-8")
-        ciphertext = self._cipher.encrypt(nonce, plaintext_bytes, None)
+        ciphertext = cipher.encrypt(nonce, plaintext_bytes, None)

-        # Combine nonce + ciphertext for storage
-        encrypted_data = nonce + ciphertext
+        # Combine salt + nonce + ciphertext for storage
+        encrypted_data = salt + nonce + ciphertext

         # Base64 encode for safe storage
         return base64.b64encode(encrypted_data).decode("ascii")
@@ -188,12 +200,17 @@ def decrypt(self, ciphertext: str) -> str:
             # Base64 decode
             encrypted_data = base64.b64decode(ciphertext.encode("ascii"))

-            # Split nonce and ciphertext
-            nonce = encrypted_data[:12]
-            ciphertext_bytes = encrypted_data[12:]
+            # Split salt, nonce, and ciphertext
+            salt = encrypted_data[:16]
+            nonce = encrypted_data[16:28]
+            ciphertext_bytes = encrypted_data[28:]
+
+            # Derive key from salt
+            key_bytes = self._derive_key(salt)
+            cipher = self._AESGCM(key_bytes)

             # Decrypt (automatically verifies authentication tag)
-            plaintext_bytes = self._cipher.decrypt(nonce, ciphertext_bytes, None)
+            plaintext_bytes = cipher.decrypt(nonce, ciphertext_bytes, None)
             return plaintext_bytes.decode("utf-8")

         except Exception as exc:
@@ -326,4 +343,7 @@ def create_encryption_provider_from_config() -> EncryptionProvider | None:


 # Module-level default instance
-default_encryption_provider: EncryptionProvider | None = create_encryption_provider_from_config()
+try:
+    default_encryption_provider: EncryptionProvider | None = create_encryption_provider_from_config()
+except Exception:  # noqa: BLE001
+    default_encryption_provider = None
diff --git a/src/fireflyframework_genai/security/rbac.py b/src/fireflyframework_genai/security/rbac.py
index dc0df645..ec5d99f7 100644
--- a/src/fireflyframework_genai/security/rbac.py
+++ b/src/fireflyframework_genai/security/rbac.py
@@ -326,33 +326,38 @@ async def run_agent(token: str, agent_name: str, prompt: str):
     def decorator(func: Callable) -> Callable:
         @functools.wraps(func)
         async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
-            nonlocal rbac
-
             # Get RBAC manager
-            if rbac is None:
-                rbac = _get_default_rbac()
+            manager = rbac or _get_default_rbac()

-            if rbac is None:
+            if manager is None:
                 raise ValueError(
                     "No RBAC manager configured. Set FIREFLY_GENAI_RBAC_ENABLED=true "
                     "and FIREFLY_GENAI_RBAC_JWT_SECRET in environment."
                 )

-            # Extract token from kwargs
-            token = kwargs.get(token_param)
+            # Extract token from args/kwargs using signature binding
+            import inspect as _inspect
+
+            try:
+                sig = _inspect.signature(func)
+                bound = sig.bind(*args, **kwargs)
+                bound.apply_defaults()
+                token = bound.arguments.get(token_param)
+            except TypeError as exc:
+                raise ValueError(f"Missing required parameter: {token_param}") from exc
             if not token:
                 raise ValueError(f"Missing required parameter: {token_param}")

             # Validate token and check permission
             try:
-                claims = rbac.validate_token(token)
+                claims = manager.validate_token(token)
             except ValueError as exc:
                 logger.warning("Token validation failed: %s", exc)
                 raise

-            if not rbac.has_permission(claims, permission):
-                user_id = rbac.get_user_id(claims)
-                roles = rbac.get_roles(claims)
+            if not manager.has_permission(claims, permission):
+                user_id = manager.get_user_id(claims)
+                roles = manager.get_roles(claims)
                 logger.warning(
                     "Permission denied: user=%s, roles=%s, required=%s",
                     user_id,
@@ -366,33 +371,38 @@ async def async_wrapper(*args: Any, **kwargs: Any) -> Any:

         @functools.wraps(func)
         def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
-            nonlocal rbac
-
             # Get RBAC manager
-            if rbac is None:
-                rbac = _get_default_rbac()
+            manager = rbac or _get_default_rbac()

-            if rbac is None:
+            if manager is None:
                 raise ValueError(
                     "No RBAC manager configured. Set FIREFLY_GENAI_RBAC_ENABLED=true "
                     "and FIREFLY_GENAI_RBAC_JWT_SECRET in environment."
                 )

-            # Extract token from kwargs
-            token = kwargs.get(token_param)
+            # Extract token from args/kwargs
+            import inspect as _inspect
+
+            try:
+                sig = _inspect.signature(func)
+                bound = sig.bind(*args, **kwargs)
+                bound.apply_defaults()
+                token = bound.arguments.get(token_param)
+            except TypeError as exc:
+                raise ValueError(f"Missing required parameter: {token_param}") from exc
             if not token:
                 raise ValueError(f"Missing required parameter: {token_param}")

             # Validate token and check permission
             try:
-                claims = rbac.validate_token(token)
+                claims = manager.validate_token(token)
             except ValueError as exc:
                 logger.warning("Token validation failed: %s", exc)
                 raise

-            if not rbac.has_permission(claims, permission):
-                user_id = rbac.get_user_id(claims)
-                roles = rbac.get_roles(claims)
+            if not manager.has_permission(claims, permission):
+                user_id = manager.get_user_id(claims)
+                roles = manager.get_roles(claims)
                 logger.warning(
                     "Permission denied: user=%s, roles=%s, required=%s",
                     user_id,
diff --git a/src/fireflyframework_genai/tools/builtins/filesystem.py b/src/fireflyframework_genai/tools/builtins/filesystem.py
index f4a8a010..9dcf2234 100644
--- a/src/fireflyframework_genai/tools/builtins/filesystem.py
+++ b/src/fireflyframework_genai/tools/builtins/filesystem.py
@@ -69,7 +69,7 @@ def _resolve(self, relative: str) -> Path:
         """Resolve *relative* under base_dir and ensure it does not escape."""
         resolved = (self._base_dir / relative).resolve()
         base_resolved = self._base_dir.resolve()
-        if not str(resolved).startswith(str(base_resolved)):
+        if not resolved.is_relative_to(base_resolved):
             raise PermissionError(f"Path '{relative}' escapes the sandbox directory")
         return resolved

diff --git a/src/fireflyframework_genai/tools/builtins/http.py b/src/fireflyframework_genai/tools/builtins/http.py
index e8bf630f..4c66c996 100644
--- a/src/fireflyframework_genai/tools/builtins/http.py
+++ b/src/fireflyframework_genai/tools/builtins/http.py
@@ -68,6 +68,7 @@ def __init__(
             description="Send HTTP requests (GET, POST, PUT, DELETE)",
             tags=["http", "web"],
             guards=guards,
+            timeout=timeout,
             parameters=[
                 ParameterSpec(name="url", type_annotation="str", description="Request URL", required=True),
                 ParameterSpec(
@@ -112,6 +113,17 @@ def __init__(
         if use_pool and not HTTPX_AVAILABLE:
             logger.warning("Connection pooling requested but httpx not available. Install with: pip install httpx")

+    def __del__(self) -> None:
+        """Warn if the client was not explicitly closed."""
+        if self._client is not None:
+            import warnings
+
+            warnings.warn(
+                f"Unclosed {self.__class__.__name__!r}. Call 'await tool.close()' to release connections.",
+                ResourceWarning,
+                stacklevel=2,
+            )
+
     async def close(self) -> None:
         """Close the HTTP client and release connections."""
         if self._client is not None:
diff --git a/src/fireflyframework_genai/tools/builtins/shell.py b/src/fireflyframework_genai/tools/builtins/shell.py
index 6f318e67..b72303c7 100644
--- a/src/fireflyframework_genai/tools/builtins/shell.py
+++ b/src/fireflyframework_genai/tools/builtins/shell.py
@@ -76,8 +76,8 @@ async def _execute(self, **kwargs: Any) -> dict[str, Any]:
         if executable not in self._allowed:
             raise PermissionError(f"Command '{executable}' is not in the allowed list: {sorted(self._allowed)}")

-        proc = await asyncio.create_subprocess_shell(
-            command,
+        proc = await asyncio.create_subprocess_exec(
+            *parts,
             stdout=asyncio.subprocess.PIPE,
             stderr=asyncio.subprocess.PIPE,
             cwd=self._working_dir,
diff --git a/src/fireflyframework_genai/tools/builtins/text_tool.py b/src/fireflyframework_genai/tools/builtins/text_tool.py
index 811ef3ea..e0885fde 100644
--- a/src/fireflyframework_genai/tools/builtins/text_tool.py
+++ b/src/fireflyframework_genai/tools/builtins/text_tool.py
@@ -20,6 +20,7 @@

 from __future__ import annotations

+import asyncio
 import re
 from collections.abc import Sequence
 from typing import Any
@@ -106,7 +107,7 @@ async def _execute(self, **kwargs: Any) -> Any:
             pattern = kwargs.get("pattern")
             if not pattern:
                 raise ValueError("'pattern' is required for the extract action")
-            return self._extract(text, pattern)
+            return await self._extract(text, pattern)
         if action == "truncate":
             max_words = kwargs.get("max_words")
             if max_words is None:
@@ -117,10 +118,10 @@ async def _execute(self, **kwargs: Any) -> Any:
             replacement = kwargs.get("replacement", "")
             if not pattern:
                 raise ValueError("'pattern' is required for the replace action")
-            return self._replace(text, pattern, replacement or "")
+            return await self._replace(text, pattern, replacement or "")
         if action == "split":
             pattern = kwargs.get("pattern", r"\n")
-            return self._split(text, pattern or r"\n")
+            return await self._split(text, pattern or r"\n")
         raise ValueError(f"Unknown action '{action}'; expected count, extract, truncate, replace, or split")


@@ -139,9 +140,9 @@ def _count(text: str, unit: str) -> dict[str, int]:
         raise ValueError(f"Unknown unit '{unit}'; expected words, chars, sentences, or lines")

     @staticmethod
-    def _extract(text: str, pattern: str) -> list[str]:
-        """Extract all regex matches."""
-        return re.findall(pattern, text)
+    async def _extract(text: str, pattern: str) -> list[str]:
+        """Extract all regex matches with timeout protection."""
+        return await asyncio.wait_for(asyncio.to_thread(re.findall, pattern, text), timeout=5.0)

     @staticmethod
     def _truncate(text: str, max_words: int) -> str:
@@ -152,11 +153,11 @@ def _truncate(text: str, max_words: int) -> str:
         return " ".join(words[:max_words]) + "..."

     @staticmethod
-    def _replace(text: str, pattern: str, replacement: str) -> str:
-        """Regex-based find-and-replace."""
-        return re.sub(pattern, replacement, text)
+    async def _replace(text: str, pattern: str, replacement: str) -> str:
+        """Regex-based find-and-replace with timeout protection."""
+        return await asyncio.wait_for(asyncio.to_thread(re.sub, pattern, replacement, text), timeout=5.0)

     @staticmethod
-    def _split(text: str, pattern: str) -> list[str]:
-        """Split text by a regex pattern."""
-        return re.split(pattern, text)
+    async def _split(text: str, pattern: str) -> list[str]:
+        """Split text by a regex pattern with timeout protection."""
+        return await asyncio.wait_for(asyncio.to_thread(re.split, pattern, text), timeout=5.0)
diff --git a/src/fireflyframework_genai/tools/cached.py b/src/fireflyframework_genai/tools/cached.py
index b15cd59c..c0afcbf2 100644
--- a/src/fireflyframework_genai/tools/cached.py
+++ b/src/fireflyframework_genai/tools/cached.py
@@ -32,6 +32,7 @@
 import hashlib
 import json
 import logging
+import threading
 import time
 from typing import Any

@@ -72,6 +73,7 @@ def __init__(
         self._ttl = ttl_seconds
         self._max_entries = max_entries
         self._cache: dict[str, _CacheEntry] = {}
+        self._lock = threading.Lock()

     # -- ToolProtocol conformance --------------------------------------------

@@ -89,15 +91,17 @@ async def execute(self, **kwargs: Any) -> Any:
             return await self._tool.execute(**kwargs)

         key = self._make_key(kwargs)
-        entry = self._cache.get(key)
         now = time.monotonic()

-        if entry is not None and entry.expires_at > now:
-            logger.debug("Cache hit for tool '%s' (key=%s)", self.name, key[:12])
-            return entry.value
+        with self._lock:
+            entry = self._cache.get(key)
+            if entry is not None and entry.expires_at > now:
+                logger.debug("Cache hit for tool '%s' (key=%s)", self.name, key[:12])
+                return entry.value

         result = await self._tool.execute(**kwargs)
-        self._put(key, result, now)
+        with self._lock:
+            self._put(key, result, time.monotonic())
         return result

     # -- Cache management ----------------------------------------------------
@@ -105,18 +109,21 @@ async def execute(self, **kwargs: Any) -> Any:
     def invalidate(self, **kwargs: Any) -> bool:
         """Remove a specific entry from the cache. Returns *True* if found."""
         key = self._make_key(kwargs)
-        return self._cache.pop(key, None) is not None
+        with self._lock:
+            return self._cache.pop(key, None) is not None

     def clear(self) -> int:
         """Drop all cached entries. Returns the number evicted."""
-        count = len(self._cache)
-        self._cache.clear()
-        return count
+        with self._lock:
+            count = len(self._cache)
+            self._cache.clear()
+            return count

     @property
     def cache_size(self) -> int:
         """Number of entries currently in the cache (including expired)."""
-        return len(self._cache)
+        with self._lock:
+            return len(self._cache)

     # -- Internals -----------------------------------------------------------

diff --git a/src/fireflyframework_genai/tools/guards.py b/src/fireflyframework_genai/tools/guards.py
index 09fdf58c..e265d581 100644
--- a/src/fireflyframework_genai/tools/guards.py
+++ b/src/fireflyframework_genai/tools/guards.py
@@ -23,6 +23,7 @@
 from __future__ import annotations

 import re
+import threading
 import time
 from collections.abc import Awaitable, Callable, Sequence
 from typing import Any
@@ -65,19 +66,21 @@ def __init__(self, max_calls: int, period_seconds: float = 60.0) -> None:
         self._max_calls = max_calls
         self._period = period_seconds
         self._timestamps: list[float] = []
+        self._lock = threading.Lock()

     async def check(self, tool_name: str, kwargs: dict[str, Any]) -> GuardResult:
         now = time.monotonic()
-        # Sliding-window rate limiter: discard timestamps that have aged out
-        # of the current window, then check if capacity remains.
-        self._timestamps = [t for t in self._timestamps if now - t < self._period]
-        if len(self._timestamps) >= self._max_calls:
-            return GuardResult(
-                passed=False,
-                reason=f"Rate limit exceeded: {self._max_calls} calls per {self._period}s",
-            )
-        # Record this invocation timestamp to count against the window.
-        self._timestamps.append(now)
+        with self._lock:
+            # Sliding-window rate limiter: discard timestamps that have aged out
+            # of the current window, then check if capacity remains.
+            self._timestamps = [t for t in self._timestamps if now - t < self._period]
+            if len(self._timestamps) >= self._max_calls:
+                return GuardResult(
+                    passed=False,
+                    reason=f"Rate limit exceeded: {self._max_calls} calls per {self._period}s",
+                )
+            # Record this invocation timestamp to count against the window.
+            self._timestamps.append(now)
         return GuardResult(passed=True)


@@ -119,8 +122,16 @@ def __init__(
         allowed_patterns: Sequence[str] = (),
         denied_patterns: Sequence[str] = (),
     ) -> None:
-        self._allowed = [re.compile(p) for p in allowed_patterns]
-        self._denied = [re.compile(p) for p in denied_patterns]
+        self._allowed = [self._safe_compile(p) for p in allowed_patterns]
+        self._denied = [self._safe_compile(p) for p in denied_patterns]
+
+    @staticmethod
+    def _safe_compile(pattern: str) -> re.Pattern[str]:
+        """Compile a regex pattern, raising ValueError on invalid patterns."""
+        try:
+            return re.compile(pattern)
+        except re.error as exc:
+            raise ValueError(f"Invalid regex pattern {pattern!r}: {exc}") from exc

     async def check(self, tool_name: str, kwargs: dict[str, Any]) -> GuardResult:
         # Check every kwarg value against deny patterns. An allowed pattern
diff --git a/tests/agents/test_prompt_cache.py b/tests/agents/test_prompt_cache.py
index 06cba182..fdc73f6d 100644
--- a/tests/agents/test_prompt_cache.py
+++ b/tests/agents/test_prompt_cache.py
@@ -58,7 +58,7 @@ async def test_before_hook_with_disabled_middleware(self):
         context.model = "anthropic:claude-3-5-sonnet-20241022"

         # Should not raise or modify context
-        await middleware.before(context)
+        await middleware.before_run(context)

     async def test_before_hook_anthropic_caching(self):
         """Test Anthropic-specific caching configuration."""
@@ -71,7 +71,7 @@ async def test_before_hook_anthropic_caching(self):
         context.model = "anthropic:claude-3-5-sonnet-20241022"
         context.metadata = {}

-        await middleware.before(context)
+        await middleware.before_run(context)

         # Should configure caching metadata
         assert context.metadata["_prompt_cache_enabled"] is True
@@ -85,7 +85,7 @@ async def test_before_hook_openai_caching(self):
         context.model = "openai:gpt-4o"

         # Should not raise (OpenAI caching is automatic)
-        await middleware.before(context)
+        await middleware.before_run(context)

     async def test_before_hook_gemini_caching(self):
         """Test Gemini-specific caching configuration."""
@@ -95,7 +95,7 @@ async def test_before_hook_gemini_caching(self):
         context.model = "gemini:gemini-1.5-pro"
         context.metadata = {}

-        await middleware.before(context)
+        await middleware.before_run(context)

         # Should configure Gemini caching
         assert context.metadata["_gemini_cache_enabled"] is True
@@ -112,7 +112,7 @@ async def test_before_hook_bedrock_anthropic_routes_to_anthropic_caching(self):
         context.model = "bedrock:anthropic.claude-3-5-sonnet-latest"
         context.metadata = {}

-        await middleware.before(context)
+        await middleware.before_run(context)

         assert context.metadata["_prompt_cache_enabled"] is True
         assert context.metadata["_cache_min_tokens"] == 2048
@@ -125,7 +125,7 @@ async def test_before_hook_azure_openai_routes_to_openai_caching(self):
         context.model = "azure:gpt-4o"

         # Should not raise (OpenAI caching is automatic)
-        await middleware.before(context)
+        await middleware.before_run(context)

     async def test_before_hook_unsupported_provider(self):
         """Test behavior with unsupported provider."""
@@ -135,7 +135,7 @@ async def test_before_hook_unsupported_provider(self):
         context.model = "unknown:model"

         # Should not raise, just log debug message
-        await middleware.before(context)
+        await middleware.before_run(context)

     async def test_before_hook_no_model(self):
         """Test behavior when model is not set."""
@@ -145,7 +145,7 @@ async def test_before_hook_no_model(self):
         context.model = ""

         # Should not raise
-        await middleware.before(context)
+        await middleware.before_run(context)

     async def test_after_hook_with_cache_usage(self):
         """Test after hook records cache usage metrics."""
@@ -160,7 +160,7 @@ async def test_after_hook_with_cache_usage(self):
         usage.cache_read_tokens = 0
         result.usage = Mock(return_value=usage)

-        returned_result = await middleware.after(context, result)
+        returned_result = await middleware.after_run(context, result)

         # Should return unchanged result
         assert returned_result == result
@@ -178,7 +178,7 @@ async def test_after_hook_with_cache_hits(self):
         usage.cache_read_tokens = 5000
         result.usage = Mock(return_value=usage)

-        returned_result = await middleware.after(context, result)
+        returned_result = await middleware.after_run(context, result)

         assert returned_result == result

@@ -190,7 +190,7 @@ async def test_after_hook_no_usage(self):
         result = Mock(spec=[])  # No usage method

         # Should not raise
-        returned_result = await middleware.after(context, result)
+        returned_result = await middleware.after_run(context, result)
         assert returned_result == result

     async def test_after_hook_disabled(self):
@@ -200,7 +200,7 @@ async def test_after_hook_disabled(self):
         context = Mock()
         result = Mock()

-        returned_result = await middleware.after(context, result)
+        returned_result = await middleware.after_run(context, result)

         # Should return result unchanged
         assert returned_result == result
@@ -213,7 +213,7 @@ async def test_cache_system_prompt_disabled(self):
         context.model = "anthropic:claude-3-5-sonnet-20241022"

         # Should not configure caching
-        await middleware.before(context)
+        await middleware.before_run(context)


 class TestCacheStatistics:
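
Review note on the `_run_sync` helper added to `memory/manager.py` earlier in this diff: the pattern can be tried in isolation with the standard library alone. The sketch below is an illustration under stated assumptions, not the project's code. The names `run_sync`, `fetch_answer`, and `call_from_running_loop` are hypothetical, and the pool is closed with a `with` block (which waits for the worker) instead of the diff's `pool.shutdown(wait=False)`.

```python
import asyncio
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Coroutine


def run_sync(coro: Coroutine[Any, Any, Any]) -> Any:
    """Run *coro* to completion from synchronous code (hypothetical name)."""
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No event loop is running in this thread, so asyncio.run is safe.
        return asyncio.run(coro)
    # A loop is already running in this thread; asyncio.run would raise
    # RuntimeError here, so run the coroutine on a fresh loop in a worker
    # thread and block until it finishes.
    with ThreadPoolExecutor(max_workers=1) as pool:
        return pool.submit(asyncio.run, coro).result()


async def fetch_answer() -> int:
    """Stand-in for an async call such as a store's initialize()."""
    await asyncio.sleep(0)
    return 42


async def call_from_running_loop() -> int:
    # A synchronous call made while an event loop is active in this thread.
    return run_sync(fetch_answer())


if __name__ == "__main__":
    print(run_sync(fetch_answer()))
    print(asyncio.run(call_from_running_loop()))
```

One trade-off worth keeping in mind: when called from inside a running loop, the calling thread blocks on `.result()`, so that loop makes no progress until the coroutine finishes. This avoids the crash but is not a substitute for awaiting the async method directly.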