diff --git a/engine/src/agent_control_engine/core.py b/engine/src/agent_control_engine/core.py index 99c2273b..e2ae8b6e 100644 --- a/engine/src/agent_control_engine/core.py +++ b/engine/src/agent_control_engine/core.py @@ -33,6 +33,108 @@ # Max concurrent evaluations (limits task spawning overhead for large policies) MAX_CONCURRENT_EVALUATIONS = int(os.environ.get("MAX_CONCURRENT_EVALUATIONS", "3")) +SELECTED_DATA_PREVIEW_MAX_CHARS = int( + os.environ.get("AGENT_CONTROL_SELECTED_DATA_PREVIEW_MAX_CHARS", "500") +) +SELECTED_DATA_PREVIEW_MAX_ITEMS = int( + os.environ.get("AGENT_CONTROL_SELECTED_DATA_PREVIEW_MAX_ITEMS", "20") +) +SELECTED_DATA_PREVIEW_MAX_DEPTH = int( + os.environ.get("AGENT_CONTROL_SELECTED_DATA_PREVIEW_MAX_DEPTH", "3") +) +_SENSITIVE_KEY_PARTS = ( + "api_key", + "apikey", + "authorization", + "credential", + "password", + "secret", + "token", +) + + +def _env_flag(name: str, *, default: bool = False) -> bool: + """Read a boolean environment flag.""" + value = os.environ.get(name) + if value is None: + return default + return value.strip().lower() in {"1", "true", "yes", "on"} + + +def _is_sensitive_key(key: object) -> bool: + """Return whether a mapping key is likely to contain a secret.""" + normalized = str(key).lower() + return any(part in normalized for part in _SENSITIVE_KEY_PARTS) + + +def _truncate_string(value: str, max_chars: int) -> tuple[str, bool]: + """Return a bounded string preview and whether it was truncated.""" + if len(value) <= max_chars: + return value, False + if max_chars <= 3: + return value[:max_chars], True + return f"{value[: max_chars - 3]}...", True + + +def _selected_data_preview_value( + value: Any, + *, + depth: int = 0, +) -> tuple[Any, bool]: + """Build a bounded, redacted preview of selected data.""" + if depth >= SELECTED_DATA_PREVIEW_MAX_DEPTH: + return "", True + + if value is None or isinstance(value, bool | int | float): + return value, False + + if isinstance(value, str): + return _truncate_string(value, SELECTED_DATA_PREVIEW_MAX_CHARS) + + if isinstance(value, dict): + preview: dict[str, Any] = {} + truncated = len(value) > SELECTED_DATA_PREVIEW_MAX_ITEMS + for index, (key, item) in enumerate(value.items()): + if index >= SELECTED_DATA_PREVIEW_MAX_ITEMS: + break + preview_key = str(key) + if _is_sensitive_key(key): + preview[preview_key] = "" + truncated = True + continue + preview_item, item_truncated = _selected_data_preview_value( + item, + depth=depth + 1, + ) + preview[preview_key] = preview_item + truncated = truncated or item_truncated + return preview, truncated + + if isinstance(value, list | tuple): + preview_items: list[Any] = [] + truncated = len(value) > SELECTED_DATA_PREVIEW_MAX_ITEMS + for item in value[:SELECTED_DATA_PREVIEW_MAX_ITEMS]: + preview_item, item_truncated = _selected_data_preview_value( + item, + depth=depth + 1, + ) + preview_items.append(preview_item) + truncated = truncated or item_truncated + return preview_items, truncated + + text_preview, truncated = _truncate_string(str(value), SELECTED_DATA_PREVIEW_MAX_CHARS) + return text_preview, truncated + + +def _selected_data_preview(value: Any) -> dict[str, Any]: + """Return UI-safe selector output details for evaluator-level inspection.""" + preview, truncated = _selected_data_preview_value(value) + return { + "type": type(value).__name__, + "value": preview, + "truncated": truncated, + } + @functools.lru_cache(maxsize=256) def _compile_regex(pattern: str) -> Any: @@ -102,9 +204,16 @@ def __init__( self, controls: Sequence[ControlWithIdentity], context: Literal["sdk", "server"] = "server", + *, + include_raw_selected_data: bool | None = None, ): self.controls = controls self.context = context + self.include_raw_selected_data = ( + _env_flag("AGENT_CONTROL_INCLUDE_RAW_SELECTED_DATA") + if include_raw_selected_data is None + else include_raw_selected_data + ) @staticmethod def _truncated_message(message: str | None) -> str | None: @@ -224,6 +333,9 @@ async def _evaluate_leaf( "message": self._truncated_message(result.message), } metadata = dict(result.metadata or {}) + if self.include_raw_selected_data: + metadata["engine_selected_data"] = data + metadata["engine_selected_data_preview"] = _selected_data_preview(data) metadata["condition_trace"] = trace return _ConditionEvaluation( result=result.model_copy(update={"metadata": metadata}), @@ -269,7 +381,21 @@ def _composite_metadata( *, matched: bool, ) -> dict[str, Any] | None: - """Select stable child metadata to preserve on composite results.""" + """Select stable child metadata to preserve on composite results. + + The engine_selected_data_preview value in this metadata is not all + evaluator inputs. It is the bounded selected value preview from the leaf + metadata the engine preserves for the final composite result: + - or where one child matches: engine_selected_data_preview comes from the + matching child. + - and where one child fails: engine_selected_data_preview comes from the + failing child. + - and where all children match: engine_selected_data_preview comes from the + first matching child, usually the first leaf. + - or where no children match: engine_selected_data_preview comes from the + first evaluated child. + - not: engine_selected_data_preview comes from its child. + """ source_result: EvaluatorResult | None = None if matched: source_result = next( diff --git a/engine/tests/test_core.py b/engine/tests/test_core.py index 9c8da751..ed4e6e00 100644 --- a/engine/tests/test_core.py +++ b/engine/tests/test_core.py @@ -157,7 +157,7 @@ async def evaluate(self, data: Any) -> EvaluatorResult: matched=matched, confidence=0.8 if matched else 0.4, message=f"Metadata {self.config.value}", - metadata={"source": self.config.value, "selected_data": data}, + metadata={"source": self.config.value, "selected_data": f"evaluator:{data}"}, ) _execution_log.append(f"metadata:{self.config.value}:end") return result @@ -831,7 +831,13 @@ async def test_confidence_is_full_on_deny_match(self): controls = [ make_control(1, "denier", "test-deny", action="deny", config_value="d"), ] + [ - make_control(i + 2, f"blocker{i}", "test-blocker", action="observe", config_value=str(i)) + make_control( + i + 2, + f"blocker{i}", + "test-blocker", + action="observe", + config_value=str(i), + ) for i in range(9) ] engine = ControlEngine(controls) @@ -1406,6 +1412,146 @@ async def test_or_short_circuit_records_skipped_trace(self): assert trace["children"][1]["matched"] is None assert trace["children"][1]["short_circuit_reason"] == "or_matched" + @pytest.mark.asyncio + async def test_leaf_metadata_includes_selector_selected_data_preview(self): + """Leaf metadata should expose a safe preview of the selected selector.path value.""" + # Given: a leaf control selecting a nested step input value + controls = [ + MockControlWithIdentity( + id=1, + name="city_control", + control=ControlDefinition( + description="City guardrail", + enabled=True, + execution="server", + scope={"step_types": ["tool"], "stages": ["pre"]}, + condition={ + "selector": {"path": "input.city"}, + "evaluator": {"name": "test-deny", "config": {"value": "match"}}, + }, + action={"decision": "observe"}, + ), + ) + ] + engine = ControlEngine(controls) + + # When: processing a request where input.city has a concrete value + request = EvaluationRequest( + agent_name="00000000-0000-0000-0000-000000000001", + step=Step( + type="tool", + name="lookup-weather", + input={"city": "San Francisco"}, + output=None, + ), + stage="pre", + ) + result = await engine.process(request) + + # Then: UI consumers can inspect the selected value without raw data export. + assert result.matches is not None + metadata = result.matches[0].result.metadata + assert metadata is not None + assert "selected_data" not in metadata + assert metadata["engine_selected_data_preview"] == { + "type": "str", + "value": "San Francisco", + "truncated": False, + } + + @pytest.mark.asyncio + async def test_leaf_selected_data_preview_is_bounded_and_redacted(self): + """Selected data previews should cap payload size and redact secret-like keys.""" + # Given: a leaf control selecting a large object with a secret-like key + controls = [ + MockControlWithIdentity( + id=1, + name="payload_control", + control=ControlDefinition( + description="Payload guardrail", + enabled=True, + execution="server", + scope={"step_types": ["tool"], "stages": ["pre"]}, + condition={ + "selector": {"path": "input"}, + "evaluator": {"name": "test-deny", "config": {"value": "match"}}, + }, + action={"decision": "observe"}, + ), + ) + ] + engine = ControlEngine(controls) + request = EvaluationRequest( + agent_name="00000000-0000-0000-0000-000000000001", + step=Step( + type="tool", + name="send-payload", + input={ + "prompt": "x" * 600, + "api_key": "secret-value", + }, + output=None, + ), + stage="pre", + ) + + # When: processing the request + result = await engine.process(request) + + # Then: the preview is useful for UI inspection but does not expose the raw payload. + assert result.matches is not None + metadata = result.matches[0].result.metadata + assert metadata is not None + preview = metadata["engine_selected_data_preview"] + assert preview["type"] == "dict" + assert preview["truncated"] is True + assert preview["value"]["api_key"] == "" + assert preview["value"]["prompt"].endswith("...") + assert len(preview["value"]["prompt"]) == 500 + + @pytest.mark.asyncio + async def test_engine_selected_data_does_not_overwrite_evaluator_metadata(self): + """Engine-owned selector data should not collide with evaluator-owned metadata.""" + # Given: an evaluator that deliberately returns its own selected_data key + controls = [ + MockControlWithIdentity( + id=1, + name="metadata_control", + control=ControlDefinition( + description="Metadata guardrail", + enabled=True, + execution="server", + scope={"step_types": ["llm"], "stages": ["pre"]}, + condition={ + "selector": {"path": "input"}, + "evaluator": {"name": "test-metadata", "config": {"value": "match"}}, + }, + action={"decision": "observe"}, + ), + ) + ] + engine = ControlEngine(controls, include_raw_selected_data=True) + request = EvaluationRequest( + agent_name="00000000-0000-0000-0000-000000000001", + step=Step(type="llm", name="test-step", input="raw input", output=None), + stage="pre", + ) + + # When: processing the request + result = await engine.process(request) + + # Then: evaluator-owned metadata remains intact and engine-owned data is namespaced. + assert result.matches is not None + metadata = result.matches[0].result.metadata + assert metadata is not None + assert metadata["selected_data"] == "evaluator:raw input" + assert metadata["engine_selected_data"] == "raw input" + assert metadata["engine_selected_data_preview"] == { + "type": "str", + "value": "raw input", + "truncated": False, + } + @pytest.mark.asyncio async def test_composite_results_preserve_decisive_child_metadata(self): """Composite results should retain structured metadata from the decisive child.""" @@ -1463,7 +1609,12 @@ async def test_composite_results_preserve_decisive_child_metadata(self): metadata = result.matches[0].result.metadata assert metadata is not None assert metadata["source"] == "match-right" - assert metadata["selected_data"] == "chosen" + assert metadata["selected_data"] == "evaluator:chosen" + assert metadata["engine_selected_data_preview"] == { + "type": "str", + "value": "chosen", + "truncated": False, + } assert metadata["condition_trace"]["type"] == "or" assert "slow:skip-tail:start" not in _execution_log @@ -1830,10 +1981,20 @@ async def test_server_context_only_runs_server_controls(self): """ controls = [ make_control_with_execution( - 1, "local_ctrl", "test-allow", action="observe", config_value="loc", execution="sdk" + 1, + "local_ctrl", + "test-allow", + action="observe", + config_value="loc", + execution="sdk", ), make_control_with_execution( - 2, "server_ctrl", "test-allow", action="observe", config_value="srv", execution="server" + 2, + "server_ctrl", + "test-allow", + action="observe", + config_value="srv", + execution="server", ), ] engine = ControlEngine(controls, context="server") @@ -1860,10 +2021,20 @@ async def test_sdk_context_only_runs_sdk_controls(self): """ controls = [ make_control_with_execution( - 1, "local_ctrl", "test-allow", action="observe", config_value="loc", execution="sdk" + 1, + "local_ctrl", + "test-allow", + action="observe", + config_value="loc", + execution="sdk", ), make_control_with_execution( - 2, "server_ctrl", "test-allow", action="observe", config_value="srv", execution="server" + 2, + "server_ctrl", + "test-allow", + action="observe", + config_value="srv", + execution="server", ), ] engine = ControlEngine(controls, context="sdk") @@ -1890,10 +2061,20 @@ async def test_default_context_is_server(self): """ controls = [ make_control_with_execution( - 1, "local_ctrl", "test-allow", action="observe", config_value="loc", execution="sdk" + 1, + "local_ctrl", + "test-allow", + action="observe", + config_value="loc", + execution="sdk", ), make_control_with_execution( - 2, "server_ctrl", "test-allow", action="observe", config_value="srv", execution="server" + 2, + "server_ctrl", + "test-allow", + action="observe", + config_value="srv", + execution="server", ), ] engine = ControlEngine(controls) # No context param diff --git a/evaluators/builtin/src/agent_control_evaluators/__init__.py b/evaluators/builtin/src/agent_control_evaluators/__init__.py index b1dabd9e..d435d801 100644 --- a/evaluators/builtin/src/agent_control_evaluators/__init__.py +++ b/evaluators/builtin/src/agent_control_evaluators/__init__.py @@ -28,7 +28,11 @@ __version__ = "0.0.0.dev" # Core infrastructure - export from _base and _registry -from agent_control_evaluators._base import Evaluator, EvaluatorConfig, EvaluatorMetadata +from agent_control_evaluators._base import ( + Evaluator, + EvaluatorConfig, + EvaluatorMetadata, +) from agent_control_evaluators._discovery import ( discover_evaluators, ensure_evaluators_discovered, diff --git a/evaluators/builtin/src/agent_control_evaluators/_base.py b/evaluators/builtin/src/agent_control_evaluators/_base.py index f5e6fc77..bf36f8c1 100644 --- a/evaluators/builtin/src/agent_control_evaluators/_base.py +++ b/evaluators/builtin/src/agent_control_evaluators/_base.py @@ -120,6 +120,7 @@ async def evaluate(self, data: Any) -> EvaluatorResult: message="Evaluation complete" ) ``` + """ metadata: ClassVar[EvaluatorMetadata] diff --git a/evaluators/builtin/tests/list/test_list_extra.py b/evaluators/builtin/tests/list/test_list_extra.py new file mode 100644 index 00000000..ff8fe90a --- /dev/null +++ b/evaluators/builtin/tests/list/test_list_extra.py @@ -0,0 +1,63 @@ +"""Targeted tests covering match_mode branches and edge-case messages.""" + +from __future__ import annotations + +import pytest +from agent_control_evaluators.list.config import ListEvaluatorConfig +from agent_control_evaluators.list.evaluator import ListEvaluator + + +@pytest.mark.asyncio +async def test_match_mode_contains_uses_word_boundary(): + """contains mode matches whole words but rejects sub-word matches.""" + config = ListEvaluatorConfig(values=["admin"], match_mode="contains") + evaluator = ListEvaluator(config) + + matched = await evaluator.evaluate("the admin user logged in") + assert matched.matched is True + + not_matched = await evaluator.evaluate("administrator") # sub-word, no boundary + assert not_matched.matched is False + + +@pytest.mark.asyncio +async def test_match_mode_exact_is_the_default(): + """No explicit mode uses anchored exact matching.""" + config = ListEvaluatorConfig(values=["admin"]) + evaluator = ListEvaluator(config) + + exact = await evaluator.evaluate("admin") + assert exact.matched is True + + partial = await evaluator.evaluate("admin user") # not anchored end + assert partial.matched is False + + +@pytest.mark.asyncio +async def test_data_none_returns_empty_input_message(): + """None input is treated as empty and the control is ignored.""" + config = ListEvaluatorConfig(values=["x"]) + evaluator = ListEvaluator(config) + + result = await evaluator.evaluate(None) + + assert result.matched is False + assert result.message == "Empty input - control ignored" + assert result.metadata["input_count"] == 0 + + +@pytest.mark.asyncio +async def test_message_truncates_match_list_at_five(): + """More than five matches collapse into a ``(+N more)`` suffix.""" + config = ListEvaluatorConfig( + values=["a", "b", "c", "d", "e", "f", "g"], + logic="any", + ) + evaluator = ListEvaluator(config) + + result = await evaluator.evaluate(["a", "b", "c", "d", "e", "f", "g"]) + + assert result.matched is True + # First five matches appear, the rest summarized. + assert "a, b, c, d, e" in result.message + assert "(+2 more)" in result.message diff --git a/evaluators/builtin/tests/regex/__init__.py b/evaluators/builtin/tests/regex/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/evaluators/builtin/tests/regex/test_regex.py b/evaluators/builtin/tests/regex/test_regex.py new file mode 100644 index 00000000..9df69560 --- /dev/null +++ b/evaluators/builtin/tests/regex/test_regex.py @@ -0,0 +1,115 @@ +"""Tests for the regex evaluator and its config validation.""" + +from __future__ import annotations + +import pytest +from agent_control_evaluators.regex.config import RegexEvaluatorConfig +from agent_control_evaluators.regex.evaluator import RegexEvaluator + + +class TestRegexConfig: + """Pattern validation rejects invalid RE2 syntax at config time.""" + + def test_valid_pattern_accepted(self): + config = RegexEvaluatorConfig(pattern=r"\d{3}-\d{2}-\d{4}") + assert config.pattern == r"\d{3}-\d{2}-\d{4}" + + def test_empty_pattern_accepted(self): + # Empty string is technically a valid RE2 pattern (matches everything). + config = RegexEvaluatorConfig(pattern="") + assert config.pattern == "" + + def test_invalid_pattern_rejected(self): + with pytest.raises(ValueError, match="Invalid regex pattern"): + RegexEvaluatorConfig(pattern="[invalid(regex") + + def test_flags_default_to_none(self): + config = RegexEvaluatorConfig(pattern=r"\d+") + assert config.flags is None + + def test_flags_can_be_specified(self): + config = RegexEvaluatorConfig(pattern="secret", flags=["IGNORECASE"]) + assert config.flags == ["IGNORECASE"] + + +class TestRegexEvaluator: + """Pattern matching against arbitrary data.""" + + @pytest.mark.asyncio + async def test_match_returns_matched_true(self): + evaluator = RegexEvaluator.from_dict({"pattern": r"\d{3}-\d{4}"}) + + result = await evaluator.evaluate("call 555-1234 today") + + assert result.matched is True + assert result.confidence == 1.0 + assert "found" in result.message + assert result.metadata["pattern"] == r"\d{3}-\d{4}" + + @pytest.mark.asyncio + async def test_no_match_returns_matched_false(self): + evaluator = RegexEvaluator.from_dict({"pattern": r"\d{3}-\d{4}"}) + + result = await evaluator.evaluate("no numbers here") + + assert result.matched is False + assert "not found" in result.message + + @pytest.mark.asyncio + async def test_none_data_returns_no_data_message(self): + evaluator = RegexEvaluator.from_dict({"pattern": r".*"}) + + result = await evaluator.evaluate(None) + + assert result.matched is False + assert result.message == "No data to match" + + @pytest.mark.asyncio + async def test_non_string_data_is_coerced(self): + """Non-string inputs are stringified before matching.""" + evaluator = RegexEvaluator.from_dict({"pattern": r"^42$"}) + + result = await evaluator.evaluate(42) + + assert result.matched is True + + @pytest.mark.asyncio + async def test_ignorecase_flag_short_form(self): + """The ``I`` short form is treated the same as ``IGNORECASE``.""" + evaluator = RegexEvaluator.from_dict( + {"pattern": "SECRET", "flags": ["I"]}, + ) + + result = await evaluator.evaluate("the secret value") + + assert result.matched is True + + @pytest.mark.asyncio + async def test_ignorecase_flag_long_form(self): + evaluator = RegexEvaluator.from_dict( + {"pattern": "secret", "flags": ["IGNORECASE"]}, + ) + + result = await evaluator.evaluate("THE SECRET VALUE") + + assert result.matched is True + + @pytest.mark.asyncio + async def test_unknown_flag_is_ignored(self): + """RE2 supports a narrow flag set; unknown flag names must not raise.""" + evaluator = RegexEvaluator.from_dict( + {"pattern": "x", "flags": ["MULTILINE"]}, + ) + + result = await evaluator.evaluate("xyz") + + # Should still work — unknown flag is silently dropped, not an error. + assert result.matched is True + + @pytest.mark.asyncio + async def test_case_sensitive_by_default(self): + evaluator = RegexEvaluator.from_dict({"pattern": "Secret"}) + + result = await evaluator.evaluate("the secret value") + + assert result.matched is False diff --git a/evaluators/builtin/tests/sql/test_sql_config_validation.py b/evaluators/builtin/tests/sql/test_sql_config_validation.py new file mode 100644 index 00000000..8842ed4f --- /dev/null +++ b/evaluators/builtin/tests/sql/test_sql_config_validation.py @@ -0,0 +1,103 @@ +"""Targeted tests for SQLEvaluatorConfig validate_config branches.""" + +from __future__ import annotations + +import warnings + +import pytest +from agent_control_evaluators.sql.config import SQLEvaluatorConfig + + +class TestConflictingRestrictions: + """Mutually-exclusive allow/block lists must be rejected at config time.""" + + def test_blocked_and_allowed_operations_conflict(self): + with pytest.raises(ValueError, match="blocked_operations and allowed_operations"): + SQLEvaluatorConfig( + blocked_operations=["DELETE"], + allowed_operations=["SELECT"], + ) + + def test_blocked_and_allowed_tables_conflict(self): + with pytest.raises(ValueError, match="allowed_tables and blocked_tables"): + SQLEvaluatorConfig( + allowed_tables=["users"], + blocked_tables=["secrets"], + ) + + def test_blocked_and_allowed_schemas_conflict(self): + with pytest.raises(ValueError, match="allowed_schemas and blocked_schemas"): + SQLEvaluatorConfig( + allowed_schemas=["public"], + blocked_schemas=["internal"], + ) + + +class TestLimitBounds: + """Numeric controls must be positive.""" + + def test_max_limit_must_be_positive(self): + with pytest.raises(ValueError, match="max_limit must be a positive integer"): + SQLEvaluatorConfig(max_limit=0) + + def test_max_limit_negative_rejected(self): + with pytest.raises(ValueError, match="max_limit must be a positive integer"): + SQLEvaluatorConfig(max_limit=-5) + + def test_max_statements_must_be_positive(self): + with pytest.raises(ValueError, match="max_statements must be a positive integer"): + SQLEvaluatorConfig( + allow_multi_statements=True, + max_statements=0, + ) + + +class TestColumnControls: + """Column-level validators cover required_column_values shape rules.""" + + def test_column_context_without_required_columns_warns(self): + with pytest.warns(UserWarning, match="column_context is set but required_columns"): + SQLEvaluatorConfig(column_context="where") + + def test_required_column_values_rejects_empty_column_ref(self): + with pytest.raises(ValueError, match="empty column reference"): + SQLEvaluatorConfig( + required_columns=["tenant_id"], + required_column_values={" ": "tenant_id"}, + ) + + def test_required_column_values_rejects_malformed_qualified_ref(self): + with pytest.raises( + ValueError, match="'table.column' format when qualified" + ): + SQLEvaluatorConfig( + required_columns=["tenant_id"], + required_column_values={"users.": "tenant_id"}, + ) + + def test_required_column_values_rejects_blank_qualified_table_side(self): + with pytest.raises( + ValueError, match="'table.column' format when qualified" + ): + SQLEvaluatorConfig( + required_columns=["tenant_id"], + required_column_values={".tenant_id": "tenant_id"}, + ) + + def test_required_column_values_rejects_empty_context_key(self): + with pytest.raises(ValueError, match="empty context key"): + SQLEvaluatorConfig( + required_columns=["tenant_id"], + required_column_values={"users.tenant_id": " "}, + ) + + def test_valid_required_column_values_accepted(self): + """Sanity check: a valid combination passes without raising.""" + with warnings.catch_warnings(): + warnings.simplefilter("error") # promote any warning to a failure + config = SQLEvaluatorConfig( + required_columns=["tenant_id"], + column_context="where", + required_column_values={"users.tenant_id": "tenant_id"}, + ) + assert config.required_column_values == {"users.tenant_id": "tenant_id"} diff --git a/evaluators/builtin/tests/test_discovery.py b/evaluators/builtin/tests/test_discovery.py new file mode 100644 index 00000000..62876412 --- /dev/null +++ b/evaluators/builtin/tests/test_discovery.py @@ -0,0 +1,187 @@ +"""Tests for entry-point-based evaluator discovery.""" + +from __future__ import annotations + +from typing import Any +from unittest.mock import MagicMock, patch + +import pytest +from agent_control_evaluators import ( + Evaluator, + EvaluatorConfig, + EvaluatorMetadata, + clear_evaluators, + discover_evaluators, + ensure_evaluators_discovered, + get_all_evaluators, + list_evaluators, + register_evaluator, + reset_evaluator_discovery, +) +from agent_control_evaluators import _discovery as discovery_module +from agent_control_models import EvaluatorResult + + +class _DiscoveryConfig(EvaluatorConfig): + pass + + +def _make_class(*, name: str, available: bool = True) -> type[Evaluator[_DiscoveryConfig]]: + class _Dummy(Evaluator[_DiscoveryConfig]): + metadata = EvaluatorMetadata(name=name, version="1.0.0", description="") + config_model = _DiscoveryConfig + + @classmethod + def is_available(cls) -> bool: + return available + + async def evaluate(self, data: Any) -> EvaluatorResult: + return EvaluatorResult(matched=False, confidence=1.0, message="") + + _Dummy.__name__ = f"Discovery_{name.replace('-', '_')}" + return _Dummy + + +@pytest.fixture +def isolated_discovery(): + """Snapshot registry + discovery flag, restore on teardown.""" + snapshot = dict(get_all_evaluators()) + clear_evaluators() + reset_evaluator_discovery() + yield + clear_evaluators() + reset_evaluator_discovery() + for cls in snapshot.values(): + register_evaluator(cls) + + +def _make_fake_entry_point(name: str, evaluator_class: type[Any]) -> MagicMock: + """Build a MagicMock that mimics importlib.metadata.EntryPoint.""" + ep = MagicMock() + ep.name = name + ep.load.return_value = evaluator_class + return ep + + +def test_discover_evaluators_registers_available_classes(isolated_discovery): + """Discover walks the entry-point group and registers each available class.""" + cls = _make_class(name="disc-a") + fake_ep = _make_fake_entry_point("disc-a", cls) + + with patch.object(discovery_module, "entry_points", return_value=[fake_ep]): + count = discover_evaluators() + + assert count == 1 + assert get_all_evaluators().get("disc-a") is cls + + +def test_discover_evaluators_skips_unavailable_classes(isolated_discovery): + """Evaluators whose is_available() is False must NOT be registered.""" + cls = _make_class(name="disc-unavailable", available=False) + fake_ep = _make_fake_entry_point("disc-unavailable", cls) + + with patch.object(discovery_module, "entry_points", return_value=[fake_ep]): + count = discover_evaluators() + + assert count == 0 + assert "disc-unavailable" not in get_all_evaluators() + + +def test_discover_evaluators_skips_already_registered(isolated_discovery): + """Already-registered names are skipped without raising.""" + cls = _make_class(name="disc-existing") + register_evaluator(cls) + + fake_ep = _make_fake_entry_point("disc-existing", cls) + with patch.object(discovery_module, "entry_points", return_value=[fake_ep]): + count = discover_evaluators() + + assert count == 0 + + +def test_discover_evaluators_only_runs_once(isolated_discovery): + """Repeat calls short-circuit on the _DISCOVERY_COMPLETE flag.""" + cls = _make_class(name="disc-once") + fake_ep = _make_fake_entry_point("disc-once", cls) + + with patch.object( + discovery_module, "entry_points", return_value=[fake_ep] + ) as patched: + first = discover_evaluators() + second = discover_evaluators() + + # First call discovers, second returns 0 without consulting entry_points. + assert first == 1 + assert second == 0 + assert patched.call_count == 1 + + +def test_discover_evaluators_swallows_load_failures(isolated_discovery): + """A broken entry point is logged and skipped, not propagated.""" + bad_ep = MagicMock() + bad_ep.name = "broken" + bad_ep.load.side_effect = RuntimeError("boom") + + good_cls = _make_class(name="disc-good") + good_ep = _make_fake_entry_point("disc-good", good_cls) + + with patch.object(discovery_module, "entry_points", return_value=[bad_ep, good_ep]): + count = discover_evaluators() + + assert count == 1 + assert get_all_evaluators().get("disc-good") is good_cls + + +def test_discover_evaluators_handles_entry_points_failure(isolated_discovery): + """If entry_points() itself raises, discovery completes with zero results.""" + with patch.object( + discovery_module, + "entry_points", + side_effect=RuntimeError("entry-point system unavailable"), + ): + count = discover_evaluators() + + assert count == 0 + + +def test_reset_evaluator_discovery_allows_rerun(isolated_discovery): + """reset_evaluator_discovery clears the completed flag so discover runs again.""" + cls = _make_class(name="disc-reset") + fake_ep = _make_fake_entry_point("disc-reset", cls) + + with patch.object( + discovery_module, "entry_points", return_value=[fake_ep] + ) as patched: + discover_evaluators() + clear_evaluators() + reset_evaluator_discovery() + count = discover_evaluators() + + assert count == 1 + assert patched.call_count == 2 + + +def test_ensure_evaluators_discovered_runs_once(isolated_discovery): + """ensure_evaluators_discovered is the lazy-init entry point.""" + cls = _make_class(name="disc-ensure") + fake_ep = _make_fake_entry_point("disc-ensure", cls) + + with patch.object( + discovery_module, "entry_points", return_value=[fake_ep] + ) as patched: + ensure_evaluators_discovered() + ensure_evaluators_discovered() + + assert patched.call_count == 1 + assert get_all_evaluators().get("disc-ensure") is cls + + +def test_list_evaluators_triggers_discovery(isolated_discovery): + """list_evaluators is the convenience accessor; it must trigger discovery.""" + cls = _make_class(name="disc-list") + fake_ep = _make_fake_entry_point("disc-list", cls) + + with patch.object(discovery_module, "entry_points", return_value=[fake_ep]): + result = list_evaluators() + + assert result.get("disc-list") is cls diff --git a/evaluators/builtin/tests/test_factory.py b/evaluators/builtin/tests/test_factory.py new file mode 100644 index 00000000..4bba4b82 --- /dev/null +++ b/evaluators/builtin/tests/test_factory.py @@ -0,0 +1,172 @@ +"""Tests for the LRU-cached evaluator factory.""" + +from __future__ import annotations + +import importlib +from typing import Any + +import pytest +from agent_control_evaluators import ( + Evaluator, + EvaluatorConfig, + EvaluatorMetadata, + clear_evaluator_cache, + clear_evaluators, + get_all_evaluators, + get_evaluator_instance, + register_evaluator, +) +from agent_control_evaluators import _factory as factory_module +from agent_control_models import EvaluatorResult, EvaluatorSpec + + +class _FactoryConfig(EvaluatorConfig): + payload: str = "default" + + +class _FactoryEvaluator(Evaluator[_FactoryConfig]): + metadata = EvaluatorMetadata(name="factory-dummy", version="1.0.0", description="") + config_model = _FactoryConfig + + async def evaluate(self, data: Any) -> EvaluatorResult: + return EvaluatorResult(matched=False, confidence=1.0, message="") + + +@pytest.fixture +def isolated_factory(): + """Snapshot registry/cache so factory tests don't leak state.""" + snapshot = dict(get_all_evaluators()) + clear_evaluators() + clear_evaluator_cache() + register_evaluator(_FactoryEvaluator) + yield + clear_evaluator_cache() + clear_evaluators() + for cls in snapshot.values(): + register_evaluator(cls) + + +def test_get_evaluator_instance_returns_evaluator(isolated_factory): + spec = EvaluatorSpec(name="factory-dummy", config={"payload": "p1"}) + + instance = get_evaluator_instance(spec) + + assert isinstance(instance, _FactoryEvaluator) + assert instance.config.payload == "p1" + + +def test_get_evaluator_instance_caches_by_config(isolated_factory): + spec_a = EvaluatorSpec(name="factory-dummy", config={"payload": "same"}) + spec_b = EvaluatorSpec(name="factory-dummy", config={"payload": "same"}) + + first = get_evaluator_instance(spec_a) + second = get_evaluator_instance(spec_b) + + # Same config = same cached instance. + assert first is second + + +def test_get_evaluator_instance_treats_different_configs_separately(isolated_factory): + spec_a = EvaluatorSpec(name="factory-dummy", config={"payload": "a"}) + spec_b = EvaluatorSpec(name="factory-dummy", config={"payload": "b"}) + + instance_a = get_evaluator_instance(spec_a) + instance_b = get_evaluator_instance(spec_b) + + assert instance_a is not instance_b + assert instance_a.config.payload == "a" + assert instance_b.config.payload == "b" + + +def test_get_evaluator_instance_raises_for_unknown_evaluator(isolated_factory): + with pytest.raises(ValueError, match="not found"): + get_evaluator_instance(EvaluatorSpec(name="no-such-evaluator", config={})) + + +def test_clear_evaluator_cache_forces_recreation(isolated_factory): + spec = EvaluatorSpec(name="factory-dummy", config={"payload": "p"}) + + first = get_evaluator_instance(spec) + clear_evaluator_cache() + second = get_evaluator_instance(spec) + + assert first is not second + + +def test_get_evaluator_instance_evicts_oldest_when_full(isolated_factory, monkeypatch): + """LRU eviction: when cache is full, the least-recently-used entry is dropped.""" + # Force a tiny cache so we can observe eviction without overhead. + monkeypatch.setattr(factory_module, "EVALUATOR_CACHE_SIZE", 2) + + spec_a = EvaluatorSpec(name="factory-dummy", config={"payload": "a"}) + spec_b = EvaluatorSpec(name="factory-dummy", config={"payload": "b"}) + spec_c = EvaluatorSpec(name="factory-dummy", config={"payload": "c"}) + + first_a = get_evaluator_instance(spec_a) + get_evaluator_instance(spec_b) + # Insert third → "a" is the LRU and must be evicted. + get_evaluator_instance(spec_c) + + re_a = get_evaluator_instance(spec_a) + # "a" was evicted: new instance must NOT be the original. + assert re_a is not first_a + + +def test_get_evaluator_instance_moves_hit_to_most_recent( + isolated_factory, monkeypatch +): + """Cache hit must refresh LRU recency so the touched entry isn't evicted next.""" + monkeypatch.setattr(factory_module, "EVALUATOR_CACHE_SIZE", 2) + + spec_a = EvaluatorSpec(name="factory-dummy", config={"payload": "a"}) + spec_b = EvaluatorSpec(name="factory-dummy", config={"payload": "b"}) + spec_c = EvaluatorSpec(name="factory-dummy", config={"payload": "c"}) + + first_a = get_evaluator_instance(spec_a) + get_evaluator_instance(spec_b) + # Touch "a" so "b" becomes the LRU. + re_a = get_evaluator_instance(spec_a) + assert re_a is first_a + + # Inserting "c" should evict "b", not "a". + get_evaluator_instance(spec_c) + + refetched_a = get_evaluator_instance(spec_a) + assert refetched_a is first_a # still cached + + +def test_parse_cache_size_uses_default_when_unset(monkeypatch): + monkeypatch.delenv("EVALUATOR_CACHE_SIZE", raising=False) + reloaded = importlib.reload(factory_module) + try: + assert reloaded.EVALUATOR_CACHE_SIZE == factory_module.DEFAULT_CACHE_SIZE + finally: + importlib.reload(factory_module) + + +def test_parse_cache_size_falls_back_on_invalid_value(monkeypatch): + monkeypatch.setenv("EVALUATOR_CACHE_SIZE", "not-a-number") + reloaded = importlib.reload(factory_module) + try: + assert reloaded.EVALUATOR_CACHE_SIZE == reloaded.DEFAULT_CACHE_SIZE + finally: + importlib.reload(factory_module) + + +def test_parse_cache_size_clamps_to_minimum(monkeypatch): + monkeypatch.setenv("EVALUATOR_CACHE_SIZE", "0") + reloaded = importlib.reload(factory_module) + try: + # Anything below MIN_CACHE_SIZE is clamped to avoid infinite eviction loops. + assert reloaded.EVALUATOR_CACHE_SIZE >= reloaded.MIN_CACHE_SIZE + finally: + importlib.reload(factory_module) + + +def test_parse_cache_size_accepts_valid_int(monkeypatch): + monkeypatch.setenv("EVALUATOR_CACHE_SIZE", "42") + reloaded = importlib.reload(factory_module) + try: + assert reloaded.EVALUATOR_CACHE_SIZE == 42 + finally: + importlib.reload(factory_module) diff --git a/evaluators/builtin/tests/test_registry.py b/evaluators/builtin/tests/test_registry.py new file mode 100644 index 00000000..6b663129 --- /dev/null +++ b/evaluators/builtin/tests/test_registry.py @@ -0,0 +1,119 @@ +"""Tests for the in-memory evaluator registry.""" + +from __future__ import annotations + +from typing import Any + +import pytest +from agent_control_evaluators import ( + Evaluator, + EvaluatorConfig, + EvaluatorMetadata, + clear_evaluators, + get_all_evaluators, + get_evaluator, + register_evaluator, +) +from agent_control_models import EvaluatorResult + + +class _DummyConfig(EvaluatorConfig): + pass + + +def _make_class(*, name: str, available: bool = True) -> type[Evaluator[_DummyConfig]]: + """Build a fresh Evaluator subclass with the supplied metadata name.""" + + class _Dummy(Evaluator[_DummyConfig]): + metadata = EvaluatorMetadata( + name=name, + version="1.0.0", + description="", + ) + config_model = _DummyConfig + + @classmethod + def is_available(cls) -> bool: + return available + + async def evaluate(self, data: Any) -> EvaluatorResult: + return EvaluatorResult(matched=False, confidence=1.0, message="") + + _Dummy.__name__ = f"Dummy_{name.replace('-', '_')}" + return _Dummy + + +@pytest.fixture +def isolated_registry(): + """Snapshot and restore the global registry so tests don't leak state.""" + snapshot = dict(get_all_evaluators()) + clear_evaluators() + yield + clear_evaluators() + for cls in snapshot.values(): + register_evaluator(cls) + + +def test_register_and_lookup_evaluator(isolated_registry): + cls = _make_class(name="reg-a") + + register_evaluator(cls) + + assert get_evaluator("reg-a") is cls + + +def test_get_evaluator_returns_none_when_not_registered(isolated_registry): + assert get_evaluator("does-not-exist") is None + + +def test_get_all_evaluators_returns_copy(isolated_registry): + cls = _make_class(name="reg-copy") + register_evaluator(cls) + + snapshot = get_all_evaluators() + snapshot["evil"] = cls # mutate the returned dict + + # Internal registry must not reflect external mutation. + assert "evil" not in get_all_evaluators() + + +def test_register_is_idempotent_for_same_class(isolated_registry): + cls = _make_class(name="reg-idem") + + register_evaluator(cls) + # Registering the exact same class again must not raise. + assert register_evaluator(cls) is cls + + +def test_register_rejects_name_collision_with_different_class(isolated_registry): + first = _make_class(name="reg-conflict") + second = _make_class(name="reg-conflict") + register_evaluator(first) + + with pytest.raises(ValueError, match="already registered"): + register_evaluator(second) + + +def test_register_skips_unavailable_evaluators(isolated_registry): + cls = _make_class(name="reg-unavailable", available=False) + + # Should not raise and should not register. + assert register_evaluator(cls) is cls + assert get_evaluator("reg-unavailable") is None + + +def test_clear_evaluators_empties_registry(isolated_registry): + register_evaluator(_make_class(name="reg-c1")) + register_evaluator(_make_class(name="reg-c2")) + assert len(get_all_evaluators()) == 2 + + clear_evaluators() + + assert get_all_evaluators() == {} + + +def test_register_decorator_returns_class(isolated_registry): + cls = _make_class(name="reg-decorator") + # The function is documented as decorator-compatible: it must return the class. + decorated = register_evaluator(cls) + assert decorated is cls diff --git a/evaluators/contrib/galileo/pyproject.toml b/evaluators/contrib/galileo/pyproject.toml index b671128e..e9769737 100644 --- a/evaluators/contrib/galileo/pyproject.toml +++ b/evaluators/contrib/galileo/pyproject.toml @@ -23,6 +23,7 @@ dev = [ ] [project.entry-points."agent_control.evaluators"] +"galileo.luna" = "agent_control_evaluator_galileo.luna:LunaEvaluator" "galileo.luna2" = "agent_control_evaluator_galileo.luna2:Luna2Evaluator" [build-system] diff --git a/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/__init__.py b/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/__init__.py index 6389087f..d9269fe1 100644 --- a/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/__init__.py +++ b/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/__init__.py @@ -3,6 +3,7 @@ This package provides Galileo evaluators for agent-control. Available evaluators: + - galileo.luna: Galileo Luna direct scorer evaluation - galileo.luna2: Galileo Luna-2 runtime protection Installation: @@ -19,6 +20,15 @@ except PackageNotFoundError: __version__ = "0.0.0.dev" +from agent_control_evaluator_galileo.luna import ( + LUNA_AVAILABLE, + GalileoLunaClient, + LunaEvaluator, + LunaEvaluatorConfig, + LunaOperator, + ScorerInvokeRequest, + ScorerInvokeResponse, +) from agent_control_evaluator_galileo.luna2 import ( LUNA2_AVAILABLE, Luna2Evaluator, @@ -28,6 +38,13 @@ ) __all__ = [ + "GalileoLunaClient", + "ScorerInvokeRequest", + "ScorerInvokeResponse", + "LunaEvaluator", + "LunaEvaluatorConfig", + "LunaOperator", + "LUNA_AVAILABLE", "Luna2Evaluator", "Luna2EvaluatorConfig", "Luna2Metric", diff --git a/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/luna/__init__.py b/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/luna/__init__.py new file mode 100644 index 00000000..b26feaac --- /dev/null +++ b/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/luna/__init__.py @@ -0,0 +1,21 @@ +"""Galileo Luna direct scorer evaluator.""" + +from agent_control_evaluator_galileo.luna.client import ( + GalileoLunaClient, + ScorerInvokeInputs, + ScorerInvokeRequest, + ScorerInvokeResponse, +) +from agent_control_evaluator_galileo.luna.config import LunaEvaluatorConfig, LunaOperator +from agent_control_evaluator_galileo.luna.evaluator import LUNA_AVAILABLE, LunaEvaluator + +__all__ = [ + "GalileoLunaClient", + "ScorerInvokeInputs", + "ScorerInvokeRequest", + "ScorerInvokeResponse", + "LunaEvaluatorConfig", + "LunaOperator", + "LunaEvaluator", + "LUNA_AVAILABLE", +] diff --git a/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/luna/client.py b/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/luna/client.py new file mode 100644 index 00000000..3bbc807f --- /dev/null +++ b/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/luna/client.py @@ -0,0 +1,388 @@ +"""Direct HTTP client for Galileo Luna scorer invocation.""" + +from __future__ import annotations + +import logging +import os +from base64 import urlsafe_b64encode +from hashlib import sha256 +from hmac import new as hmac_new +from json import dumps +from time import time +from typing import Literal + +import httpx +from agent_control_models import JSONObject, JSONValue +from pydantic import BaseModel, Field, PrivateAttr, model_validator + +logger = logging.getLogger(__name__) + +DEFAULT_TIMEOUT_SECS = 10.0 +DEFAULT_INTERNAL_TOKEN_TTL_SECS = 3600 +PUBLIC_SCORER_INVOKE_PATH = "/scorers/invoke" +INTERNAL_SCORER_INVOKE_PATH = "/internal/scorers/invoke" +AuthMode = Literal["public", "internal"] + + +def _b64url(data: bytes) -> str: + return urlsafe_b64encode(data).rstrip(b"=").decode("ascii") + + +def _internal_auth_token( + api_secret: str, + ttl_seconds: int = DEFAULT_INTERNAL_TOKEN_TTL_SECS, +) -> str: + """Create the internal JWT expected by Galileo API internal routes.""" + now = int(time()) + header = {"alg": "HS256", "typ": "JWT"} + payload = { + "internal": True, + "scope": "scorers.invoke", + "iat": now, + "exp": now + ttl_seconds, + } + signing_input = ".".join( + [ + _b64url(dumps(header, separators=(",", ":")).encode("utf-8")), + _b64url(dumps(payload, separators=(",", ":")).encode("utf-8")), + ] + ) + signature = hmac_new(api_secret.encode("utf-8"), signing_input.encode("ascii"), sha256).digest() + return f"{signing_input}.{_b64url(signature)}" + + +def _env_auth_mode() -> AuthMode | None: + value = os.getenv("GALILEO_LUNA_AUTH_MODE") + if value is None or value.strip() == "": + return None + normalized = value.strip().lower() + if normalized == "public": + return "public" + if normalized == "internal": + return "internal" + raise ValueError("GALILEO_LUNA_AUTH_MODE must be either 'public' or 'internal'.") + + +def _as_float_or_none(value: JSONValue) -> float | None: + if isinstance(value, bool) or value is None: + return None + if isinstance(value, (int, float)): + return float(value) + if isinstance(value, str): + try: + return float(value) + except ValueError: + return None + return None + + +def _has_value(value: JSONValue) -> bool: + if value is None: + return False + if isinstance(value, str): + return value.strip() != "" + if isinstance(value, (list, dict)): + return len(value) > 0 + return True + + +class ScorerInvokeInputs(BaseModel): + """Input values sent to Galileo's scorer invoke API.""" + + query: JSONValue = "" + response: JSONValue = "" + ground_truth: JSONValue = None + tools: JSONValue = None + + +class ScorerInvokeRequest(BaseModel): + """Request payload for Galileo Luna scorer invocation. + + Attributes: + inputs: Selected scorer input values. + scorer_label: Preset, registered, or fine-tuned scorer label. + scorer_id: Optional Galileo scorer identifier. + scorer_version_id: Optional Galileo scorer version identifier. + config: Optional scorer-specific configuration. + """ + + inputs: ScorerInvokeInputs + scorer_label: str | None = Field(default=None, min_length=1) + scorer_id: str | None = Field(default=None, min_length=1) + scorer_version_id: str | None = Field(default=None, min_length=1) + config: JSONObject | None = None + + @model_validator(mode="after") + def ensure_required_values(self) -> ScorerInvokeRequest: + if not (self.scorer_label or self.scorer_id or self.scorer_version_id): + raise ValueError( + "One of scorer_label, scorer_id, or scorer_version_id must be set." + ) + if not (_has_value(self.inputs.query) or _has_value(self.inputs.response)): + raise ValueError("Either inputs.query or inputs.response must be set.") + return self + + def to_dict(self) -> JSONObject: + """Convert to the Galileo scorer invoke API request shape.""" + return self.model_dump(mode="json", exclude_none=True) + + +class ScorerInvokeResponse(BaseModel): + """Response from Galileo Luna scorer invocation. + + Attributes: + scorer_label: Echoed scorer label, when returned. + score: Raw scorer value. + status: Invocation status. + execution_time: Execution time in seconds, when returned. + error_message: Error detail for non-success statuses. + """ + + scorer_label: str | None = None + score: JSONValue + status: str = "unknown" + execution_time: float | None = None + error_message: str | None = None + _raw_response: JSONObject = PrivateAttr(default_factory=dict) + + @property + def raw_response(self) -> JSONObject: + return self._raw_response + + @classmethod + def from_dict(cls, data: JSONObject) -> ScorerInvokeResponse: + """Create a response model from the API JSON object.""" + response = cls.model_validate( + data | {"execution_time": _as_float_or_none(data.get("execution_time"))} + ) + response._raw_response = data + return response + + +class GalileoLunaClient: + """Thin HTTP client for Galileo Luna direct scorer invocation. + + Environment Variables: + GALILEO_API_SECRET_KEY or GALILEO_API_SECRET: Galileo API internal JWT signing secret. + GALILEO_API_KEY: Galileo API key fallback for public scorer invocation. + GALILEO_LUNA_AUTH_MODE: Auth mode, either "public" or "internal". + GALILEO_CONSOLE_URL: Galileo Console URL (optional, defaults to production). + """ + + def __init__( + self, + api_key: str | None = None, + api_secret: str | None = None, + console_url: str | None = None, + api_url: str | None = None, + auth_mode: AuthMode | None = None, + ) -> None: + """Initialize the Galileo Luna client. + + Args: + api_key: Galileo API key. If not provided, reads from GALILEO_API_KEY. + api_secret: Galileo API secret for internal JWT auth. If not provided, + reads from GALILEO_API_SECRET_KEY or GALILEO_API_SECRET. + console_url: Galileo Console URL. If not provided, reads from + GALILEO_CONSOLE_URL or uses the production console URL. + api_url: Galileo API URL. If not provided, reads from GALILEO_API_URL + before deriving from the console URL. + auth_mode: Auth mode to use. If not provided, reads from + GALILEO_LUNA_AUTH_MODE, or infers from the single available credential. + + Raises: + ValueError: If credentials are missing, ambiguous, or incompatible with + the selected auth mode. + """ + resolved_api_secret = ( + api_secret or os.getenv("GALILEO_API_SECRET_KEY") or os.getenv("GALILEO_API_SECRET") + ) + resolved_api_key = api_key or os.getenv("GALILEO_API_KEY") + resolved_auth_mode = self._resolve_auth_mode( + auth_mode or _env_auth_mode(), + api_key=resolved_api_key, + api_secret=resolved_api_secret, + ) + + self.api_key = resolved_api_key + self.api_secret = resolved_api_secret + self.auth_mode = resolved_auth_mode + self.console_url = ( + console_url or os.getenv("GALILEO_CONSOLE_URL") or "https://console.galileo.ai" + ) + self.api_base = (api_url or os.getenv("GALILEO_API_URL") or "").rstrip( + "/" + ) or self._derive_api_url(self.console_url) + self._client: httpx.AsyncClient | None = None + logger.info("[GalileoLunaClient] Auth mode selected: %s", self.auth_mode) + + @staticmethod + def _resolve_auth_mode( + auth_mode: AuthMode | None, + *, + api_key: str | None, + api_secret: str | None, + ) -> AuthMode: + if auth_mode == "public": + if not api_key: + raise ValueError( + "GALILEO_API_KEY is required when GALILEO_LUNA_AUTH_MODE=public." + ) + return "public" + + if auth_mode == "internal": + if not api_secret: + raise ValueError( + "GALILEO_API_SECRET_KEY or GALILEO_API_SECRET is required when " + "GALILEO_LUNA_AUTH_MODE=internal." + ) + return "internal" + + if api_key and api_secret: + raise ValueError( + "Both Galileo API key and API secret are configured. Set " + "GALILEO_LUNA_AUTH_MODE to 'public' or 'internal' to choose the " + "runtime auth mode explicitly." + ) + if api_secret: + return "internal" + if api_key: + return "public" + raise ValueError( + "GALILEO_API_SECRET_KEY or GALILEO_API_KEY is required. " + "Set one as an environment variable or pass it to the constructor." + ) + + def _derive_api_url(self, console_url: str) -> str: + """Derive the API URL from a Galileo Console URL.""" + url = console_url.rstrip("/") + + if "console." in url: + return url.replace("console.", "api.") + if "console-" in url: + return url.replace("console-", "api-", 1) + + if url.startswith("https://"): + return url.replace("https://", "https://api.") + if url.startswith("http://"): + return url.replace("http://", "http://api.") + + return url + + async def _get_client(self) -> httpx.AsyncClient: + """Get or create the HTTP client.""" + if self._client is None or self._client.is_closed: + headers = {"Content-Type": "application/json"} + if self.auth_mode == "public" and self.api_key is not None: + headers["Galileo-API-Key"] = self.api_key + self._client = httpx.AsyncClient( + headers=headers, + timeout=httpx.Timeout(DEFAULT_TIMEOUT_SECS), + ) + return self._client + + def _endpoint_and_headers( + self, + headers: dict[str, str] | None, + ) -> tuple[str, dict[str, str]]: + request_headers = dict(headers or {}) + if self.auth_mode == "public": + return f"{self.api_base}{PUBLIC_SCORER_INVOKE_PATH}", request_headers + + if self.api_secret is None: + raise RuntimeError("Internal Luna auth mode is missing an API secret.") + request_headers["Authorization"] = f"Bearer {_internal_auth_token(self.api_secret)}" + return f"{self.api_base}{INTERNAL_SCORER_INVOKE_PATH}", request_headers + + async def invoke( + self, + *, + scorer_label: str | None = None, + scorer_id: str | None = None, + scorer_version_id: str | None = None, + input: JSONValue = None, + output: JSONValue = None, + config: JSONObject | None = None, + timeout: float = DEFAULT_TIMEOUT_SECS, + headers: dict[str, str] | None = None, + ) -> ScorerInvokeResponse: + """Invoke a Galileo Luna scorer. + + Args: + scorer_label: Preset, registered, or fine-tuned scorer label. + scorer_id: Optional Galileo scorer identifier. + scorer_version_id: Optional Galileo scorer version identifier. + input: Optional user/system prompt text. + output: Optional model response text. + config: Optional scorer-specific configuration. + timeout: Request timeout in seconds. + headers: Additional request headers. + + Returns: + Parsed scorer invocation response. + + Raises: + ValueError: If neither input nor output is provided. + RuntimeError: If the API response is not a JSON object. + httpx.HTTPStatusError: If the API returns an error status code. + httpx.RequestError: If the request fails before a response is received. + """ + if not (scorer_label or scorer_id or scorer_version_id): + raise ValueError("At least one scorer identifier must be provided.") + if not (_has_value(input) or _has_value(output)): + raise ValueError("At least one of input or output must be provided.") + + request_body = ScorerInvokeRequest( + scorer_label=scorer_label, + scorer_id=scorer_id, + scorer_version_id=scorer_version_id, + inputs=ScorerInvokeInputs( + query="" if input is None else input, response="" if output is None else output + ), + config=config, + ).to_dict() + endpoint, request_headers = self._endpoint_and_headers(headers) + + logger.debug("[GalileoLunaClient] POST %s", endpoint) + logger.debug("[GalileoLunaClient] Request body: %s", request_body) + + try: + client = await self._get_client() + response = await client.post( + endpoint, + json=request_body, + headers=request_headers, + timeout=timeout, + ) + response.raise_for_status() + response_data = response.json() + if not isinstance(response_data, dict): + raise RuntimeError("Invalid response payload: not a JSON object") + + parsed = ScorerInvokeResponse.from_dict(response_data) + logger.debug("[GalileoLunaClient] Response: %s", parsed.raw_response) + return parsed + except httpx.HTTPStatusError as exc: + logger.error( + "[GalileoLunaClient] API error: %s - %s", + exc.response.status_code, + exc.response.text, + ) + raise + except httpx.RequestError as exc: + logger.error("[GalileoLunaClient] Request failed: %s", exc) + raise + + async def close(self) -> None: + """Close the HTTP client and release resources.""" + if self._client is not None: + await self._client.aclose() + self._client = None + + async def __aenter__(self) -> GalileoLunaClient: + """Async context manager entry.""" + return self + + async def __aexit__(self, exc_type: object, exc_val: object, exc_tb: object) -> None: + """Async context manager exit.""" + await self.close() diff --git a/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/luna/config.py b/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/luna/config.py new file mode 100644 index 00000000..788fa24c --- /dev/null +++ b/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/luna/config.py @@ -0,0 +1,99 @@ +"""Configuration model for direct Galileo Luna scorer evaluation.""" + +from __future__ import annotations + +from typing import Literal + +from agent_control_evaluators import EvaluatorConfig +from agent_control_models import JSONObject, JSONValue +from pydantic import Field, model_validator + +LunaOperator = Literal["gt", "gte", "lt", "lte", "eq", "ne", "contains", "any"] +LunaPayloadField = Literal["input", "output"] + +_NUMERIC_OPERATORS = frozenset({"gt", "gte", "lt", "lte"}) + + +def coerce_number(value: JSONValue) -> float | None: + """Return a numeric value for JSON scalars that can be compared numerically.""" + if isinstance(value, bool) or value is None: + return None + if isinstance(value, (int, float)): + return float(value) + if isinstance(value, str): + try: + return float(value) + except ValueError: + return None + return None + + +class LunaEvaluatorConfig(EvaluatorConfig): + """Configuration for direct Luna scorer evaluation. + + Attributes: + scorer_label: Preset, registered, or fine-tuned scorer label. + scorer_id: Optional Galileo scorer identifier. + scorer_version_id: Optional Galileo scorer version identifier. + threshold: Local threshold used by the evaluator for comparison. + operator: Local comparison operator. Numeric operators use threshold as a number. + scorer_config: Optional scorer-specific config sent as ``config``. + payload_field: Explicit scorer input side for scalar selected data. + timeout_ms: Request timeout in milliseconds. + """ + + scorer_label: str | None = Field( + default=None, + min_length=1, + description="Luna scorer label to invoke.", + ) + scorer_id: str | None = Field( + default=None, + min_length=1, + description="Optional Galileo scorer identifier to invoke.", + ) + scorer_version_id: str | None = Field( + default=None, + min_length=1, + description="Optional Galileo scorer version identifier to invoke.", + ) + threshold: JSONValue = Field( + default=0.5, + description="Local threshold used to decide whether the control matches.", + ) + operator: LunaOperator = Field( + default="gte", + description="Local comparison operator applied to the raw Luna score.", + ) + scorer_config: JSONObject | None = Field( + default=None, + alias="config", + serialization_alias="config", + description="Optional scorer-specific configuration sent to Galileo.", + ) + payload_field: LunaPayloadField = Field( + default="input", + description=( + "Which scorer input side to use when selector output is a scalar value. " + "Structured selected data with input/output keys overrides this setting." + ), + ) + timeout_ms: int = Field( + default=10000, + ge=1000, + le=60000, + description="Request timeout in milliseconds (1-60 seconds)", + ) + + @model_validator(mode="after") + def validate_threshold(self) -> LunaEvaluatorConfig: + """Validate threshold compatibility with the configured operator.""" + if not (self.scorer_label or self.scorer_id or self.scorer_version_id): + raise ValueError( + "one of scorer_label, scorer_id, or scorer_version_id is required" + ) + if self.operator in _NUMERIC_OPERATORS and coerce_number(self.threshold) is None: + raise ValueError(f"operator '{self.operator}' requires a numeric threshold") + if self.operator != "any" and self.threshold is None: + raise ValueError("threshold is required unless operator is 'any'") + return self diff --git a/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/luna/evaluator.py b/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/luna/evaluator.py new file mode 100644 index 00000000..7b48052f --- /dev/null +++ b/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/luna/evaluator.py @@ -0,0 +1,270 @@ +"""Direct Galileo Luna evaluator implementation.""" + +from __future__ import annotations + +import json +import logging +import os +from importlib.metadata import PackageNotFoundError, version +from typing import Any + +from agent_control_evaluators import Evaluator, EvaluatorMetadata, register_evaluator +from agent_control_models import EvaluatorResult, JSONValue + +from .client import GalileoLunaClient, ScorerInvokeResponse +from .config import LunaEvaluatorConfig, coerce_number + +logger = logging.getLogger(__name__) + + +def _resolve_package_version() -> str: + """Return the installed package version, or a dev fallback during local imports.""" + try: + return version("agent-control-evaluator-galileo") + except PackageNotFoundError: + return "0.0.0.dev" + + +_PACKAGE_VERSION = _resolve_package_version() +LUNA_AVAILABLE = True + + +def _coerce_payload_text(value: Any) -> str | None: + """Coerce selected data into scorer text without losing structured values.""" + if value is None: + return None + if isinstance(value, str): + return value + if isinstance(value, (int, float, bool)): + return str(value) + try: + return json.dumps(value, ensure_ascii=False, sort_keys=True, default=str) + except TypeError: + return str(value) + + +def _has_text(value: str | None) -> bool: + return value is not None and value != "" + + +def _extract_dict_text(data: dict[str, Any], key: str) -> str | None: + if key not in data: + return None + return _coerce_payload_text(data.get(key)) + + +def _contains(score: JSONValue, threshold: JSONValue) -> bool: + if threshold is None: + return False + if isinstance(score, str): + return str(threshold) in score + if isinstance(score, list): + return threshold in score + if isinstance(score, dict): + return threshold in score.values() + return False + + +def _confidence_from_score(score: JSONValue) -> float: + if isinstance(score, bool): + return 1.0 if score else 0.0 + number = coerce_number(score) + if number is not None and 0.0 <= number <= 1.0: + return number + return 1.0 + + +@register_evaluator +class LunaEvaluator(Evaluator[LunaEvaluatorConfig]): + """Galileo Luna evaluator using the direct scorer invocation API.""" + + metadata = EvaluatorMetadata( + name="galileo.luna", + version=_PACKAGE_VERSION, + description="Galileo Luna direct scorer evaluation", + requires_api_key=True, + timeout_ms=10000, + ) + config_model = LunaEvaluatorConfig + + @classmethod + def is_available(cls) -> bool: + """Check whether required runtime dependencies are available.""" + return LUNA_AVAILABLE + + def __init__(self, config: LunaEvaluatorConfig) -> None: + """Initialize the direct Luna evaluator. + + Args: + config: Validated LunaEvaluatorConfig instance. + + Raises: + ValueError: If neither GALILEO_API_SECRET_KEY nor GALILEO_API_KEY is set. + """ + has_auth = ( + os.getenv("GALILEO_API_SECRET_KEY") + or os.getenv("GALILEO_API_SECRET") + or os.getenv("GALILEO_API_KEY") + ) + if not has_auth: + raise ValueError( + "GALILEO_API_SECRET_KEY or GALILEO_API_KEY environment variable must be set. " + "Set an API secret for internal auth or a Galileo API key before using " + "galileo.luna." + ) + + super().__init__(config) + self._client = GalileoLunaClient() + + def _get_client(self) -> GalileoLunaClient: + """Get the Galileo Luna client.""" + return self._client + + def _prepare_payload(self, data: Any) -> tuple[str | None, str | None]: + """Prepare scorer input/output fields from selected data.""" + if isinstance(data, dict): + input_text = _extract_dict_text(data, "input") + output_text = _extract_dict_text(data, "output") + if _has_text(input_text) or _has_text(output_text): + return input_text, output_text + + text = _coerce_payload_text(data) + if self.config.payload_field == "output": + return None, text + return text, None + + def _score_matches(self, score: JSONValue) -> bool: + """Apply the configured local threshold comparison to a raw Luna score.""" + operator = self.config.operator + threshold = self.config.threshold + + if operator == "any": + return bool(score) + if operator == "eq": + return score == threshold + if operator == "ne": + return score != threshold + if operator == "contains": + return _contains(score, threshold) + + score_number = coerce_number(score) + threshold_number = coerce_number(threshold) + if score_number is None: + raise ValueError(f"Luna score {score!r} is not numeric") + if threshold_number is None: + raise ValueError(f"Luna threshold {threshold!r} is not numeric") + + if operator == "gt": + return score_number > threshold_number + if operator == "gte": + return score_number >= threshold_number + if operator == "lt": + return score_number < threshold_number + if operator == "lte": + return score_number <= threshold_number + + raise ValueError(f"Unsupported Luna operator: {operator}") + + async def evaluate(self, data: Any) -> EvaluatorResult: + """Evaluate selected data with Galileo Luna direct scorer invocation. + + Args: + data: The data selected from the runtime step. + + Returns: + EvaluatorResult with local threshold decision and scorer metadata. + """ + input_text, output_text = self._prepare_payload(data) + if not (_has_text(input_text) or _has_text(output_text)): + return EvaluatorResult( + matched=False, + confidence=1.0, + message="No data to score with Luna", + metadata=self._base_metadata(), + ) + + try: + scorer_kwargs = self._scorer_kwargs() + response = await self._get_client().invoke( + **scorer_kwargs, + input=input_text if _has_text(input_text) else None, + output=output_text if _has_text(output_text) else None, + config=self.config.scorer_config, + timeout=self.get_timeout_seconds(), + ) + + if response.status.lower() != "success": + message = response.error_message or f"Luna scorer status: {response.status}" + raise RuntimeError(message) + + matched = self._score_matches(response.score) + metadata = self._metadata(response) + operator = self.config.operator + threshold = self.config.threshold + state = "triggered" if matched else "not triggered" + return EvaluatorResult( + matched=matched, + confidence=_confidence_from_score(response.score), + message=( + f"Luna score {response.score!r} {operator} threshold " + f"{threshold!r}: control {state}." + ), + metadata=metadata, + ) + except Exception as exc: + logger.error("Luna evaluation error: %s", exc, exc_info=True) + return self._handle_error(exc) + + def _base_metadata(self) -> dict[str, Any]: + metadata = { + "scorer_label": self.config.scorer_label, + "scorer_id": self.config.scorer_id, + "scorer_version_id": self.config.scorer_version_id, + } + return {key: value for key, value in metadata.items() if value is not None} + + def _scorer_kwargs(self) -> dict[str, Any]: + kwargs = { + "scorer_label": self.config.scorer_label, + "scorer_id": self.config.scorer_id, + "scorer_version_id": self.config.scorer_version_id, + } + return {key: value for key, value in kwargs.items() if value is not None} + + def _metadata( + self, + response: ScorerInvokeResponse, + ) -> dict[str, Any]: + metadata: dict[str, Any] = self._base_metadata() + metadata.update({ + "scorer_label": response.scorer_label or self.config.scorer_label, + "score": response.score, + "threshold": self.config.threshold, + "operator": self.config.operator, + "status": response.status, + "execution_time_seconds": response.execution_time, + "error_message": response.error_message, + }) + return metadata + + def _handle_error( + self, + error: Exception, + ) -> EvaluatorResult: + error_detail = str(error) + return EvaluatorResult( + matched=False, + confidence=0.0, + message=f"Luna evaluation error: {error_detail}", + metadata={ + "error_type": type(error).__name__, + "scorer_label": self.config.scorer_label, + "scorer_id": self.config.scorer_id, + "scorer_version_id": self.config.scorer_version_id, + }, + error=error_detail, + ) + + async def aclose(self) -> None: + """Close the underlying Galileo Luna client.""" + await self._client.close() diff --git a/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/py.typed b/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/py.typed new file mode 100644 index 00000000..8b137891 --- /dev/null +++ b/evaluators/contrib/galileo/src/agent_control_evaluator_galileo/py.typed @@ -0,0 +1 @@ + diff --git a/evaluators/contrib/galileo/tests/test_luna_coverage_gaps.py b/evaluators/contrib/galileo/tests/test_luna_coverage_gaps.py new file mode 100644 index 00000000..68755c99 --- /dev/null +++ b/evaluators/contrib/galileo/tests/test_luna_coverage_gaps.py @@ -0,0 +1,673 @@ +"""Targeted tests filling coverage gaps in luna/evaluator.py and luna/client.py. + +These tests cover the small utility functions and rare branches that the +integration-style tests in ``test_luna_evaluator.py`` skip past. +""" + +from __future__ import annotations + +import json +import os +from unittest.mock import AsyncMock, MagicMock, patch + +import httpx +import pytest + + +# ============================================================================= +# luna/evaluator.py: utility helpers +# ============================================================================= + + +class TestCoercePayloadText: + """``_coerce_payload_text`` normalises arbitrary values to strings.""" + + def test_none_returns_none(self): + from agent_control_evaluator_galileo.luna.evaluator import _coerce_payload_text + + assert _coerce_payload_text(None) is None + + def test_string_passed_through(self): + from agent_control_evaluator_galileo.luna.evaluator import _coerce_payload_text + + assert _coerce_payload_text("hello") == "hello" + + @pytest.mark.parametrize("value", [42, 3.14, True]) + def test_scalars_stringified(self, value): + from agent_control_evaluator_galileo.luna.evaluator import _coerce_payload_text + + assert _coerce_payload_text(value) == str(value) + + def test_dict_is_json_serialized(self): + from agent_control_evaluator_galileo.luna.evaluator import _coerce_payload_text + + result = _coerce_payload_text({"a": 1, "b": 2}) + + assert json.loads(result) == {"a": 1, "b": 2} + + def test_unserialisable_falls_back_to_str(self): + from agent_control_evaluator_galileo.luna.evaluator import _coerce_payload_text + + class CannotJson: + def __repr__(self): + return "" + + # json.dumps with default=str would actually serialize this, so use + # something that breaks both the JSON pass AND triggers TypeError. + cannot = CannotJson() + result = _coerce_payload_text({"obj": cannot}) + + # default=str converts the inner object, so we still get a JSON string. + assert isinstance(result, str) + + +class TestExtractDictText: + """``_extract_dict_text`` returns ``None`` for missing keys.""" + + def test_missing_key_returns_none(self): + from agent_control_evaluator_galileo.luna.evaluator import _extract_dict_text + + assert _extract_dict_text({}, "absent") is None + + def test_present_key_coerced(self): + from agent_control_evaluator_galileo.luna.evaluator import _extract_dict_text + + assert _extract_dict_text({"x": 7}, "x") == "7" + + +class TestContains: + """``_contains`` supports str/list and dict values against a threshold.""" + + def test_none_threshold_is_no_match(self): + from agent_control_evaluator_galileo.luna.evaluator import _contains + + assert _contains("anything", None) is False + + def test_string_contains_substring(self): + from agent_control_evaluator_galileo.luna.evaluator import _contains + + assert _contains("hello world", "world") is True + assert _contains("hello world", "absent") is False + + def test_list_contains_value(self): + from agent_control_evaluator_galileo.luna.evaluator import _contains + + assert _contains(["a", "b", "c"], "b") is True + assert _contains(["a", "b", "c"], "z") is False + + def test_dict_threshold_does_not_match_key(self): + from agent_control_evaluator_galileo.luna.evaluator import _contains + + assert _contains({"toxicity": 0.9}, "toxicity") is False + + def test_dict_threshold_matches_value(self): + from agent_control_evaluator_galileo.luna.evaluator import _contains + + assert _contains({"label": "flagged"}, "flagged") is True + + def test_other_types_return_false(self): + from agent_control_evaluator_galileo.luna.evaluator import _contains + + # Non-iterable score => no match. + assert _contains(42, 42) is False + + +class TestConfidenceFromScore: + """``_confidence_from_score`` maps a raw score to [0, 1].""" + + def test_true_bool_maps_to_one(self): + from agent_control_evaluator_galileo.luna.evaluator import _confidence_from_score + + assert _confidence_from_score(True) == 1.0 + + def test_false_bool_maps_to_zero(self): + from agent_control_evaluator_galileo.luna.evaluator import _confidence_from_score + + assert _confidence_from_score(False) == 0.0 + + def test_in_range_number_returned_as_is(self): + from agent_control_evaluator_galileo.luna.evaluator import _confidence_from_score + + assert _confidence_from_score(0.42) == 0.42 + + def test_out_of_range_falls_back_to_one(self): + from agent_control_evaluator_galileo.luna.evaluator import _confidence_from_score + + # Above 1.0 → fall back to default confidence + assert _confidence_from_score(7.2) == 1.0 + + def test_non_numeric_falls_back_to_one(self): + from agent_control_evaluator_galileo.luna.evaluator import _confidence_from_score + + assert _confidence_from_score("not-a-number") == 1.0 + + +# ============================================================================= +# luna/evaluator.py: _score_matches operator branches +# ============================================================================= + + +@pytest.fixture +def luna_evaluator(monkeypatch): + """A ready-to-use LunaEvaluator instance with auth env wired up.""" + monkeypatch.setenv("GALILEO_API_KEY", "test-key") + from agent_control_evaluator_galileo.luna import LunaEvaluator + + return LunaEvaluator.from_dict( + {"scorer_label": "toxicity", "threshold": 0.5, "operator": "gte"} + ) + + +class TestScoreMatchesOperators: + """Every operator branch in ``_score_matches`` should evaluate.""" + + def _make(self, operator, threshold, monkeypatch): + monkeypatch.setenv("GALILEO_API_KEY", "test-key") + from agent_control_evaluator_galileo.luna import LunaEvaluator + + if operator in {"eq", "ne", "contains"}: + threshold_value = threshold + else: + threshold_value = threshold + return LunaEvaluator.from_dict( + {"scorer_label": "toxicity", "threshold": threshold_value, "operator": operator} + ) + + def test_any_truthy_score_matches(self, monkeypatch): + evaluator = self._make("any", 0.5, monkeypatch) + assert evaluator._score_matches(1) is True + assert evaluator._score_matches(0) is False + + def test_eq_matches_threshold(self, monkeypatch): + evaluator = self._make("eq", "flagged", monkeypatch) + assert evaluator._score_matches("flagged") is True + assert evaluator._score_matches("safe") is False + + def test_ne_matches_when_different(self, monkeypatch): + evaluator = self._make("ne", "flagged", monkeypatch) + assert evaluator._score_matches("safe") is True + assert evaluator._score_matches("flagged") is False + + def test_contains_matches_substring(self, monkeypatch): + evaluator = self._make("contains", "flag", monkeypatch) + assert evaluator._score_matches("flagged") is True + assert evaluator._score_matches("clean") is False + + def test_numeric_operators_all_branches(self, monkeypatch): + for op, expectations in [ + ("gt", [(0.9, True), (0.5, False)]), + ("gte", [(0.5, True), (0.4, False)]), + ("lt", [(0.4, True), (0.5, False)]), + ("lte", [(0.5, True), (0.6, False)]), + ]: + evaluator = self._make(op, 0.5, monkeypatch) + for score, expected in expectations: + assert evaluator._score_matches(score) is expected, (op, score) + + def test_numeric_operator_rejects_non_numeric_score(self, monkeypatch): + evaluator = self._make("gte", 0.5, monkeypatch) + with pytest.raises(ValueError, match="not numeric"): + evaluator._score_matches("not-a-number") + + +# ============================================================================= +# luna/evaluator.py: payload preparation + aclose +# ============================================================================= + + +class TestPreparePayload: + """``_prepare_payload`` routes scalar data using explicit config.""" + + def test_scalar_routed_to_input_when_label_lacks_output(self, monkeypatch): + monkeypatch.setenv("GALILEO_API_KEY", "test-key") + from agent_control_evaluator_galileo.luna import LunaEvaluator + + evaluator = LunaEvaluator.from_dict({"scorer_label": "toxicity", "threshold": 0.5}) + + input_text, output_text = evaluator._prepare_payload("hello") + + assert input_text == "hello" + assert output_text is None + + def test_scalar_routed_to_output_when_payload_field_is_output(self, monkeypatch): + monkeypatch.setenv("GALILEO_API_KEY", "test-key") + from agent_control_evaluator_galileo.luna import LunaEvaluator + + evaluator = LunaEvaluator.from_dict( + { + "scorer_label": "toxicity", + "threshold": 0.5, + "payload_field": "output", + } + ) + + input_text, output_text = evaluator._prepare_payload("hello") + + assert input_text is None + assert output_text == "hello" + + def test_scalar_output_label_without_payload_field_still_defaults_to_input( + self, + monkeypatch, + ): + monkeypatch.setenv("GALILEO_API_KEY", "test-key") + from agent_control_evaluator_galileo.luna import LunaEvaluator + + evaluator = LunaEvaluator.from_dict( + {"scorer_label": "output_correctness", "threshold": 0.5} + ) + + input_text, output_text = evaluator._prepare_payload("hello") + + assert input_text == "hello" + assert output_text is None + + def test_structured_payload_uses_input_output_keys_over_payload_field(self, monkeypatch): + monkeypatch.setenv("GALILEO_API_KEY", "test-key") + from agent_control_evaluator_galileo.luna import LunaEvaluator + + evaluator = LunaEvaluator.from_dict( + { + "scorer_label": "toxicity", + "threshold": 0.5, + "payload_field": "output", + } + ) + + input_text, output_text = evaluator._prepare_payload( + {"input": "prompt", "output": "answer"} + ) + + assert input_text == "prompt" + assert output_text == "answer" + + +@pytest.mark.asyncio +async def test_evaluator_aclose_closes_underlying_client(monkeypatch): + """``aclose`` must release the eagerly-created client without clearing it.""" + monkeypatch.setenv("GALILEO_API_KEY", "test-key") + from agent_control_evaluator_galileo.luna import LunaEvaluator + + evaluator = LunaEvaluator.from_dict({"scorer_label": "toxicity", "threshold": 0.5}) + + fake = MagicMock() + fake.close = AsyncMock() + evaluator._client = fake + + await evaluator.aclose() + + fake.close.assert_awaited_once() + assert evaluator._client is fake + + +@pytest.mark.asyncio +async def test_evaluator_handles_non_success_status(monkeypatch): + """A non-success status from the scorer must surface as an error result.""" + monkeypatch.setenv("GALILEO_API_KEY", "test-key") + from agent_control_evaluator_galileo.luna import LunaEvaluator, ScorerInvokeResponse + from agent_control_evaluator_galileo.luna.client import GalileoLunaClient + + evaluator = LunaEvaluator.from_dict( + {"scorer_label": "toxicity", "threshold": 0.5, "operator": "gte"} + ) + + with patch.object(GalileoLunaClient, "invoke", new_callable=AsyncMock) as mock_invoke: + mock_invoke.return_value = ScorerInvokeResponse( + scorer_label="toxicity", + score=None, + status="failed", + error_message="upstream timeout", + ) + + result = await evaluator.evaluate("hello") + + assert result.matched is False + assert result.error is not None + assert "upstream timeout" in result.error + + +# ============================================================================= +# luna/evaluator.py: package version fallback +# ============================================================================= + + +def test_resolve_package_version_falls_back_when_metadata_missing(): + """The dev fallback must trigger when the package isn't installed by metadata.""" + from importlib.metadata import PackageNotFoundError + + from agent_control_evaluator_galileo.luna import evaluator as evaluator_module + + with patch.object(evaluator_module, "version", side_effect=PackageNotFoundError): + result = evaluator_module._resolve_package_version() + + assert result == "0.0.0.dev" + + +# ============================================================================= +# luna/client.py: small helpers + branches +# ============================================================================= + + +class TestAsFloatOrNone: + """``_as_float_or_none`` parses scalar values; strings may fail.""" + + def test_returns_none_for_bool(self): + from agent_control_evaluator_galileo.luna.client import _as_float_or_none + + assert _as_float_or_none(True) is None + + def test_returns_none_for_none(self): + from agent_control_evaluator_galileo.luna.client import _as_float_or_none + + assert _as_float_or_none(None) is None + + def test_returns_float_for_int(self): + from agent_control_evaluator_galileo.luna.client import _as_float_or_none + + assert _as_float_or_none(7) == 7.0 + + def test_returns_float_for_string_number(self): + from agent_control_evaluator_galileo.luna.client import _as_float_or_none + + assert _as_float_or_none("0.42") == 0.42 + + def test_returns_none_for_unparseable_string(self): + from agent_control_evaluator_galileo.luna.client import _as_float_or_none + + assert _as_float_or_none("not-a-number") is None + + def test_returns_none_for_other_types(self): + from agent_control_evaluator_galileo.luna.client import _as_float_or_none + + assert _as_float_or_none([1, 2]) is None + + +class TestHasValue: + """``_has_value`` is the "is this scorable" predicate.""" + + def test_none_is_empty(self): + from agent_control_evaluator_galileo.luna.client import _has_value + + assert _has_value(None) is False + + def test_empty_string_is_empty(self): + from agent_control_evaluator_galileo.luna.client import _has_value + + assert _has_value("") is False + assert _has_value(" ") is False + + def test_non_empty_string_has_value(self): + from agent_control_evaluator_galileo.luna.client import _has_value + + assert _has_value("hi") is True + + def test_empty_list_or_dict_is_empty(self): + from agent_control_evaluator_galileo.luna.client import _has_value + + assert _has_value([]) is False + assert _has_value({}) is False + + def test_non_empty_list_or_dict_has_value(self): + from agent_control_evaluator_galileo.luna.client import _has_value + + assert _has_value([1]) is True + assert _has_value({"k": "v"}) is True + + def test_scalar_other_types_have_value(self): + from agent_control_evaluator_galileo.luna.client import _has_value + + assert _has_value(42) is True + assert _has_value(0) is True # 0 is a real value, not empty + assert _has_value(True) is True + + +class TestScorerInvokeRequestValidation: + """``ScorerInvokeRequest`` rejects malformed input combos.""" + + def test_missing_all_identifiers_raises(self): + from agent_control_evaluator_galileo.luna.client import ( + ScorerInvokeInputs, + ScorerInvokeRequest, + ) + from pydantic import ValidationError + + with pytest.raises(ValidationError, match="One of scorer_label"): + ScorerInvokeRequest(inputs=ScorerInvokeInputs(query="hello")) + + +def test_client_raises_when_no_credentials(monkeypatch): + """The client requires at least an API secret or an API key.""" + for name in ( + "GALILEO_API_SECRET_KEY", + "GALILEO_API_SECRET", + "GALILEO_API_KEY", + "GALILEO_LUNA_AUTH_MODE", + ): + monkeypatch.delenv(name, raising=False) + from agent_control_evaluator_galileo.luna.client import GalileoLunaClient + + with pytest.raises(ValueError, match="GALILEO_API_SECRET_KEY"): + GalileoLunaClient() + + +def test_client_requires_explicit_mode_when_both_credentials_are_present(monkeypatch): + """A mixed credential environment must not silently choose an auth route.""" + monkeypatch.setenv("GALILEO_API_KEY", "public-key") + monkeypatch.setenv("GALILEO_API_SECRET_KEY", "internal-secret") + monkeypatch.delenv("GALILEO_LUNA_AUTH_MODE", raising=False) + from agent_control_evaluator_galileo.luna.client import GalileoLunaClient + + with pytest.raises(ValueError, match="Both Galileo API key and API secret"): + GalileoLunaClient() + + +def test_client_uses_explicit_public_mode_when_both_credentials_are_present(monkeypatch): + """Explicit public mode should use the API-key route even if a secret is also set.""" + monkeypatch.setenv("GALILEO_API_KEY", "public-key") + monkeypatch.setenv("GALILEO_API_SECRET_KEY", "internal-secret") + monkeypatch.setenv("GALILEO_LUNA_AUTH_MODE", "public") + from agent_control_evaluator_galileo.luna.client import GalileoLunaClient + + client = GalileoLunaClient() + + assert client.auth_mode == "public" + endpoint, request_headers = client._endpoint_and_headers(None) + assert endpoint.endswith("/scorers/invoke") + assert "Authorization" not in request_headers + + +def test_client_uses_explicit_internal_mode_when_both_credentials_are_present(monkeypatch): + """Explicit internal mode should use the internal JWT route.""" + monkeypatch.setenv("GALILEO_API_KEY", "public-key") + monkeypatch.setenv("GALILEO_API_SECRET_KEY", "internal-secret") + monkeypatch.setenv("GALILEO_LUNA_AUTH_MODE", "internal") + from agent_control_evaluator_galileo.luna.client import GalileoLunaClient + + client = GalileoLunaClient() + + assert client.auth_mode == "internal" + endpoint, request_headers = client._endpoint_and_headers(None) + assert endpoint.endswith("/internal/scorers/invoke") + assert request_headers["Authorization"].startswith("Bearer ") + + +def test_client_rejects_mode_without_matching_credential(monkeypatch): + """The selected mode must have its matching credential configured.""" + monkeypatch.delenv("GALILEO_API_SECRET_KEY", raising=False) + monkeypatch.delenv("GALILEO_API_SECRET", raising=False) + monkeypatch.setenv("GALILEO_API_KEY", "public-key") + monkeypatch.setenv("GALILEO_LUNA_AUTH_MODE", "internal") + from agent_control_evaluator_galileo.luna.client import GalileoLunaClient + + with pytest.raises(ValueError, match="GALILEO_API_SECRET_KEY"): + GalileoLunaClient() + + +def test_client_rejects_invalid_auth_mode(monkeypatch): + """Invalid auth mode values should fail during client initialization.""" + monkeypatch.setenv("GALILEO_API_KEY", "public-key") + monkeypatch.setenv("GALILEO_LUNA_AUTH_MODE", "sideways") + from agent_control_evaluator_galileo.luna.client import GalileoLunaClient + + with pytest.raises(ValueError, match="GALILEO_LUNA_AUTH_MODE"): + GalileoLunaClient() + + +class TestDeriveApiUrl: + """URL derivation covers every console.* → api.* substitution branch.""" + + def _client(self, monkeypatch): + monkeypatch.delenv("GALILEO_API_SECRET_KEY", raising=False) + monkeypatch.delenv("GALILEO_API_SECRET", raising=False) + monkeypatch.delenv("GALILEO_LUNA_AUTH_MODE", raising=False) + monkeypatch.setenv("GALILEO_API_KEY", "test-key") + from agent_control_evaluator_galileo.luna.client import GalileoLunaClient + + return GalileoLunaClient() + + def test_console_dot_rewritten_to_api_dot(self, monkeypatch): + client = self._client(monkeypatch) + assert ( + client._derive_api_url("https://console.galileo.ai") + == "https://api.galileo.ai" + ) + + def test_console_dash_rewritten_to_api_dash(self, monkeypatch): + client = self._client(monkeypatch) + assert ( + client._derive_api_url("https://console-staging.galileo.ai") + == "https://api-staging.galileo.ai" + ) + + def test_plain_https_host_gets_api_prefix(self, monkeypatch): + client = self._client(monkeypatch) + assert ( + client._derive_api_url("https://example.com") + == "https://api.example.com" + ) + + def test_plain_http_host_gets_api_prefix(self, monkeypatch): + client = self._client(monkeypatch) + assert client._derive_api_url("http://example.com") == "http://api.example.com" + + def test_unknown_scheme_returned_as_is(self, monkeypatch): + client = self._client(monkeypatch) + # No console./console- prefix, no http(s) scheme → return unchanged. + assert client._derive_api_url("api.example.com") == "api.example.com" + + +@pytest.mark.asyncio +async def test_get_client_adds_api_key_header_when_no_secret(monkeypatch): + """When only an API key is configured, the public-API header is set.""" + monkeypatch.delenv("GALILEO_API_SECRET_KEY", raising=False) + monkeypatch.delenv("GALILEO_API_SECRET", raising=False) + monkeypatch.setenv("GALILEO_API_KEY", "public-key") + from agent_control_evaluator_galileo.luna.client import GalileoLunaClient + + client = GalileoLunaClient() + http_client = await client._get_client() + try: + assert http_client.headers.get("Galileo-API-Key") == "public-key" + finally: + await client.close() + + +@pytest.mark.asyncio +async def test_invoke_rejects_missing_scorer_identifier(monkeypatch): + monkeypatch.setenv("GALILEO_API_KEY", "test-key") + from agent_control_evaluator_galileo.luna.client import GalileoLunaClient + + client = GalileoLunaClient() + try: + with pytest.raises(ValueError, match="At least one scorer identifier"): + await client.invoke(input="hello") + finally: + await client.close() + + +@pytest.mark.asyncio +async def test_invoke_raises_when_response_is_not_a_json_object(monkeypatch): + """A non-object JSON body must surface as a clear RuntimeError.""" + monkeypatch.setenv("GALILEO_API_KEY", "test-key") + from agent_control_evaluator_galileo.luna.client import GalileoLunaClient + + client = GalileoLunaClient() + + fake_response = MagicMock() + fake_response.raise_for_status = MagicMock() + fake_response.json = MagicMock(return_value=["not", "an", "object"]) + + fake_http = AsyncMock() + fake_http.post = AsyncMock(return_value=fake_response) + fake_http.is_closed = False + client._client = fake_http + + try: + with pytest.raises(RuntimeError, match="not a JSON object"): + await client.invoke(scorer_label="toxicity", input="hello") + finally: + await client.close() + + +@pytest.mark.asyncio +async def test_invoke_propagates_http_status_error(monkeypatch): + """The client logs and re-raises HTTP status errors.""" + monkeypatch.setenv("GALILEO_API_KEY", "test-key") + from agent_control_evaluator_galileo.luna.client import GalileoLunaClient + + client = GalileoLunaClient() + + fake_response = MagicMock(spec=httpx.Response) + fake_response.status_code = 500 + fake_response.text = "internal error" + fake_response.raise_for_status = MagicMock( + side_effect=httpx.HTTPStatusError( + "boom", request=MagicMock(spec=httpx.Request), response=fake_response + ) + ) + + fake_http = AsyncMock() + fake_http.post = AsyncMock(return_value=fake_response) + fake_http.is_closed = False + client._client = fake_http + + try: + with pytest.raises(httpx.HTTPStatusError): + await client.invoke(scorer_label="toxicity", input="hello") + finally: + await client.close() + + +@pytest.mark.asyncio +async def test_invoke_propagates_request_error(monkeypatch): + """RequestError is logged and re-raised so callers can decide policy.""" + monkeypatch.setenv("GALILEO_API_KEY", "test-key") + from agent_control_evaluator_galileo.luna.client import GalileoLunaClient + + client = GalileoLunaClient() + + fake_http = AsyncMock() + fake_http.post = AsyncMock(side_effect=httpx.RequestError("network down")) + fake_http.is_closed = False + client._client = fake_http + + try: + with pytest.raises(httpx.RequestError): + await client.invoke(scorer_label="toxicity", input="hello") + finally: + await client.close() + + +@pytest.mark.asyncio +async def test_client_async_context_manager_closes_on_exit(monkeypatch): + """Entering/exiting the async context manager must close the client.""" + monkeypatch.setenv("GALILEO_API_KEY", "test-key") + from agent_control_evaluator_galileo.luna.client import GalileoLunaClient + + async with GalileoLunaClient() as client: + # Trigger lazy client creation so close() has work to do. + await client._get_client() + assert client._client is not None + + # __aexit__ closes the underlying httpx client. + assert client._client is None diff --git a/evaluators/contrib/galileo/tests/test_luna_evaluator.py b/evaluators/contrib/galileo/tests/test_luna_evaluator.py new file mode 100644 index 00000000..e0cd2051 --- /dev/null +++ b/evaluators/contrib/galileo/tests/test_luna_evaluator.py @@ -0,0 +1,481 @@ +"""Tests for the direct Galileo Luna evaluator and client.""" + +from __future__ import annotations + +import json +import os +from base64 import urlsafe_b64decode +from unittest.mock import AsyncMock, patch + +import httpx +import pytest +from agent_control_models import EvaluatorResult +from pydantic import ValidationError + + +def _decode_jwt_payload(token: str) -> dict[str, object]: + payload_segment = token.split(".")[1] + padded = payload_segment + ("=" * (-len(payload_segment) % 4)) + return json.loads(urlsafe_b64decode(padded.encode()).decode()) + + +class TestLunaEvaluatorConfig: + """Tests for direct Luna evaluator configuration.""" + + def test_config_accepts_direct_scorer_fields(self) -> None: + from agent_control_evaluator_galileo.luna import LunaEvaluatorConfig + + # Given: a direct scorer config with local thresholding + config = LunaEvaluatorConfig( + scorer_label="toxicity", + scorer_id="scorer-123", + scorer_version_id="version-123", + threshold=0.7, + operator="gte", + config={"temperature": 0}, + ) + + # Then: config is retained without Protect concepts + assert config.scorer_label == "toxicity" + assert config.scorer_id == "scorer-123" + assert config.scorer_version_id == "version-123" + assert config.threshold == 0.7 + assert config.operator == "gte" + assert config.scorer_config == {"temperature": 0} + assert config.payload_field == "input" + + def test_config_accepts_scorer_id_without_label(self) -> None: + from agent_control_evaluator_galileo.luna import LunaEvaluatorConfig + + config = LunaEvaluatorConfig(scorer_id="scorer-123") + + assert config.scorer_id == "scorer-123" + assert config.scorer_label is None + + def test_config_requires_a_scorer_identifier(self) -> None: + from agent_control_evaluator_galileo.luna import LunaEvaluatorConfig + + with pytest.raises(ValidationError, match="one of scorer_label"): + LunaEvaluatorConfig(threshold=0.5) + + def test_numeric_operator_requires_numeric_threshold(self) -> None: + from agent_control_evaluator_galileo.luna import LunaEvaluatorConfig + + # Given/When/Then: numeric local comparison rejects non-numeric thresholds + with pytest.raises(ValidationError, match="numeric threshold"): + LunaEvaluatorConfig(scorer_label="toxicity", threshold="high", operator="gte") + + +class TestGalileoLunaClient: + """Tests for the GalileoLunaClient HTTP contract.""" + + def test_scorer_invoke_request_matches_api_schema_shape(self) -> None: + from agent_control_evaluator_galileo.luna import ScorerInvokeInputs, ScorerInvokeRequest + + # Given: a scorer request with scorer config + request = ScorerInvokeRequest( + scorer_label="toxicity", + scorer_id="scorer-123", + scorer_version_id="version-123", + inputs=ScorerInvokeInputs(query={"messages": [{"role": "user", "content": "hello"}]}), + config={"top_k": 1}, + ) + + # Then: the serialized payload uses the API-owned scorer invoke fields + assert request.to_dict() == { + "scorer_label": "toxicity", + "scorer_id": "scorer-123", + "scorer_version_id": "version-123", + "inputs": { + "query": {"messages": [{"role": "user", "content": "hello"}]}, + "response": "", + }, + "config": {"top_k": 1}, + } + + @pytest.mark.parametrize("empty_value", ["", " ", {}, []]) + def test_scorer_invoke_request_requires_input_or_output(self, empty_value: object) -> None: + from agent_control_evaluator_galileo.luna import ScorerInvokeRequest + + # Given/When/Then: the request mirrors API validation + with pytest.raises( + ValidationError, match="Either inputs.query or inputs.response must be set" + ): + ScorerInvokeRequest( + scorer_label="toxicity", + inputs={"query": empty_value, "response": empty_value}, + ) + + def test_scorer_invoke_response_matches_api_schema_shape(self) -> None: + from agent_control_evaluator_galileo.luna import ScorerInvokeResponse + + # Given: an API scorer invoke response + response = ScorerInvokeResponse.from_dict( + { + "scorer_label": "toxicity", + "score": 0.82, + "status": "success", + "execution_time": 0.12, + "error_message": None, + } + ) + + # Then: the model exposes the API response fields + assert response.model_dump() == { + "scorer_label": "toxicity", + "score": 0.82, + "status": "success", + "execution_time": 0.12, + "error_message": None, + } + assert response.scorer_label == "toxicity" + assert response.raw_response["scorer_label"] == "toxicity" + + def test_client_uses_protect_api_url_derivation(self) -> None: + from agent_control_evaluator_galileo.luna import GalileoLunaClient + + # Given: the same console URL shape used by Protect + with patch.dict(os.environ, {"GALILEO_API_KEY": "test-key"}): + client = GalileoLunaClient(console_url="https://console.demo-v2.galileocloud.io") + + # Then: the API URL is derived the same way + assert client.api_base == "https://api.demo-v2.galileocloud.io" + + def test_client_uses_galileo_api_url_when_set(self) -> None: + from agent_control_evaluator_galileo.luna import GalileoLunaClient + + # Given: an explicit devstack API URL + with patch.dict( + os.environ, + { + "GALILEO_API_KEY": "test-key", + "GALILEO_API_URL": "https://api-test-luna.gcp-dev.galileo.ai/", + }, + ): + client = GalileoLunaClient(console_url="https://console-test-luna.gcp-dev.galileo.ai") + + # Then: the explicit API URL wins over console URL derivation + assert client.api_base == "https://api-test-luna.gcp-dev.galileo.ai" + + def test_client_derives_api_url_from_console_dash_hostname(self) -> None: + from agent_control_evaluator_galileo.luna import GalileoLunaClient + + # Given: a console- devstack hostname + with patch.dict(os.environ, {"GALILEO_API_KEY": "test-key"}, clear=False): + client = GalileoLunaClient(console_url="https://console-test-luna.gcp-dev.galileo.ai") + + # Then: the matching api- hostname is used + assert client.api_base == "https://api-test-luna.gcp-dev.galileo.ai" + + @pytest.mark.asyncio + async def test_client_posts_to_scorers_invoke_without_protect_fields(self) -> None: + from agent_control_evaluator_galileo.luna import GalileoLunaClient + + captured: dict[str, object] = {} + + def handler(request: httpx.Request) -> httpx.Response: + captured["url"] = str(request.url) + captured["headers"] = dict(request.headers) + captured["body"] = json.loads(request.content.decode()) + return httpx.Response( + 200, + json={ + "scorer_label": "toxicity", + "score": 0.82, + "status": "success", + "execution_time": 0.12, + }, + ) + + # Given: a Luna client with a mock HTTP transport + with patch.dict(os.environ, {"GALILEO_API_KEY": "test-key"}): + client = GalileoLunaClient(console_url="https://console.demo-v2.galileocloud.io") + client._client = httpx.AsyncClient( + transport=httpx.MockTransport(handler), + headers={ + "Galileo-API-Key": client.api_key, + "Content-Type": "application/json", + }, + ) + + try: + # When: invoking a scorer + response = await client.invoke( + scorer_label="toxicity", + input="user prompt", + output="model answer", + config={"top_k": 1}, + ) + finally: + await client.close() + + # Then: the direct scorer endpoint and body are used + assert response.score == 0.82 + assert captured["url"] == "https://api.demo-v2.galileocloud.io/scorers/invoke" + assert captured["body"] == { + "scorer_label": "toxicity", + "inputs": {"query": "user prompt", "response": "model answer"}, + "config": {"top_k": 1}, + } + assert "stage_name" not in captured["body"] + assert "prioritized_rulesets" not in captured["body"] + headers = captured["headers"] + assert isinstance(headers, dict) + assert headers["galileo-api-key"] == "test-key" + + @pytest.mark.asyncio + async def test_client_uses_internal_jwt_when_api_secret_is_set(self) -> None: + from agent_control_evaluator_galileo.luna import GalileoLunaClient + + captured: dict[str, object] = {} + + def handler(request: httpx.Request) -> httpx.Response: + captured["url"] = str(request.url) + captured["headers"] = dict(request.headers) + captured["body"] = json.loads(request.content.decode()) + return httpx.Response( + 200, + json={ + "scorer_label": "toxicity", + "score": 0.82, + "status": "success", + "execution_time": 0.12, + }, + ) + + # Given: a Luna client configured with the Galileo API internal secret + with patch.dict(os.environ, {"GALILEO_API_SECRET_KEY": "test-secret"}, clear=True): + client = GalileoLunaClient(api_url="https://api.default.svc.cluster.local:8088") + client._client = httpx.AsyncClient(transport=httpx.MockTransport(handler)) + + try: + # When: invoking a scorer with internal JWT auth + response = await client.invoke(scorer_label="toxicity", output="model answer") + finally: + await client.close() + + # Then: the internal scorer endpoint is called with an internal JWT + assert response.score == 0.82 + assert ( + captured["url"] == "https://api.default.svc.cluster.local:8088/internal/scorers/invoke" + ) + assert captured["body"] == { + "scorer_label": "toxicity", + "inputs": {"query": "", "response": "model answer"}, + } + headers = captured["headers"] + assert isinstance(headers, dict) + assert "galileo-api-key" not in headers + auth_header = headers["authorization"] + assert isinstance(auth_header, str) + assert auth_header.startswith("Bearer ") + token_payload = _decode_jwt_payload(auth_header.removeprefix("Bearer ")) + assert token_payload["internal"] is True + assert token_payload["scope"] == "scorers.invoke" + + @pytest.mark.asyncio + async def test_client_uses_internal_jwt_without_api_key(self) -> None: + from agent_control_evaluator_galileo.luna import GalileoLunaClient + + # Given: a Luna client configured with internal JWT auth + with patch.dict(os.environ, {"GALILEO_API_SECRET_KEY": "test-secret"}, clear=True): + client = GalileoLunaClient(api_url="https://api.default.svc.cluster.local:8088") + + captured: dict[str, object] = {} + + def handler(request: httpx.Request) -> httpx.Response: + captured["headers"] = dict(request.headers) + return httpx.Response( + 200, + json={"scorer_label": "toxicity", "score": 0.82, "status": "success"}, + ) + + client._client = httpx.AsyncClient(transport=httpx.MockTransport(handler)) + try: + # When: invoking without project context + response = await client.invoke(scorer_label="toxicity", output="model answer") + finally: + await client.close() + + # Then: internal JWT auth still works + assert response.score == 0.82 + headers = captured["headers"] + assert isinstance(headers, dict) + auth_header = headers["authorization"] + assert isinstance(auth_header, str) + token_payload = _decode_jwt_payload(auth_header.removeprefix("Bearer ")) + assert token_payload["internal"] is True + assert token_payload["scope"] == "scorers.invoke" + + @pytest.mark.asyncio + @pytest.mark.parametrize("empty_value", ["", " ", {}, []]) + async def test_client_rejects_missing_input_and_output_values( + self, empty_value: object + ) -> None: + from agent_control_evaluator_galileo.luna import GalileoLunaClient + + # Given: a Luna client and scorer input values that API treats as missing + with patch.dict(os.environ, {"GALILEO_API_KEY": "test-key"}, clear=True): + client = GalileoLunaClient(api_url="https://api.default.svc.cluster.local:8088") + + # When/Then: the client rejects the request before calling API + with pytest.raises(ValueError, match="At least one of input or output must be provided"): + await client.invoke(scorer_label="toxicity", input=empty_value, output=empty_value) + + +class TestLunaEvaluator: + """Tests for direct Luna evaluator behavior.""" + + @patch.dict(os.environ, {"GALILEO_API_KEY": "test-key"}) + def test_evaluator_metadata(self) -> None: + from agent_control_evaluator_galileo.luna import LunaEvaluator + + assert LunaEvaluator.metadata.name == "galileo.luna" + assert LunaEvaluator.metadata.requires_api_key is True + + @patch.dict(os.environ, {}, clear=True) + def test_evaluator_init_without_auth_raises(self) -> None: + from agent_control_evaluator_galileo.luna import LunaEvaluator + + with pytest.raises(ValueError, match="GALILEO_API_SECRET_KEY or GALILEO_API_KEY"): + LunaEvaluator.from_dict({"scorer_label": "toxicity", "threshold": 0.5}) + + @patch.dict(os.environ, {"GALILEO_API_SECRET_KEY": "test-secret"}, clear=True) + def test_evaluator_init_accepts_api_secret(self) -> None: + from agent_control_evaluator_galileo.luna import LunaEvaluator + + evaluator = LunaEvaluator.from_dict( + { + "scorer_label": "toxicity", + "threshold": 0.5, + } + ) + + assert evaluator.config.scorer_label == "toxicity" + + @patch.dict(os.environ, {"GALILEO_API_KEY": "test-key"}) + @pytest.mark.asyncio + async def test_evaluator_applies_threshold_locally_to_raw_score(self) -> None: + from agent_control_evaluator_galileo.luna import LunaEvaluator, ScorerInvokeResponse + from agent_control_evaluator_galileo.luna.client import GalileoLunaClient + + # Given: a direct Luna evaluator and a raw successful scorer response + evaluator = LunaEvaluator.from_dict( + { + "scorer_label": "toxicity", + "threshold": 0.7, + "operator": "gte", + "timeout_ms": 5000, + } + ) + + with patch.object(GalileoLunaClient, "invoke", new_callable=AsyncMock) as mock_invoke: + mock_invoke.return_value = ScorerInvokeResponse( + scorer_label="toxicity", + score=0.82, + status="success", + execution_time=0.1, + ) + + # When: evaluating a full step payload + result = await evaluator.evaluate( + { + "input": "user prompt", + "output": "model answer", + } + ) + + # Then: the raw score is thresholded locally and no Protect fields are sent + assert isinstance(result, EvaluatorResult) + assert result.matched is True + assert result.confidence == 0.82 + assert result.metadata == { + "scorer_label": "toxicity", + "score": 0.82, + "threshold": 0.7, + "operator": "gte", + "status": "success", + "execution_time_seconds": 0.1, + "error_message": None, + } + mock_invoke.assert_awaited_once_with( + scorer_label="toxicity", + input="user prompt", + output="model answer", + config=None, + timeout=5.0, + ) + + @patch.dict(os.environ, {"GALILEO_API_KEY": "test-key"}) + @pytest.mark.asyncio + async def test_evaluator_returns_non_match_below_threshold(self) -> None: + from agent_control_evaluator_galileo.luna import LunaEvaluator, ScorerInvokeResponse + from agent_control_evaluator_galileo.luna.client import GalileoLunaClient + + # Given: a raw scorer value below the local threshold + evaluator = LunaEvaluator.from_dict( + {"scorer_label": "toxicity", "threshold": 0.7, "operator": "gte"} + ) + + with patch.object(GalileoLunaClient, "invoke", new_callable=AsyncMock) as mock_invoke: + mock_invoke.return_value = ScorerInvokeResponse( + scorer_label="toxicity", + score=0.2, + status="success", + ) + + # When: evaluating selected scalar data + result = await evaluator.evaluate("hello") + + # Then: the control does not match + assert result.matched is False + assert result.confidence == 0.2 + mock_invoke.assert_awaited_once_with( + scorer_label="toxicity", + input="hello", + output=None, + config=None, + timeout=10.0, + ) + + @patch.dict(os.environ, {"GALILEO_API_KEY": "test-key"}) + @pytest.mark.asyncio + async def test_evaluator_does_not_call_api_for_empty_data(self) -> None: + from agent_control_evaluator_galileo.luna import LunaEvaluator + from agent_control_evaluator_galileo.luna.client import GalileoLunaClient + + # Given: an evaluator and empty selected data + evaluator = LunaEvaluator.from_dict({"scorer_label": "toxicity", "threshold": 0.5}) + + with patch.object(GalileoLunaClient, "invoke", new_callable=AsyncMock) as mock_invoke: + # When: evaluating empty data + result = await evaluator.evaluate("") + + # Then: no remote scorer call is made + assert result.matched is False + assert result.confidence == 1.0 + assert result.message == "No data to score with Luna" + mock_invoke.assert_not_called() + + @patch.dict(os.environ, {"GALILEO_API_KEY": "test-key"}) + @pytest.mark.asyncio + async def test_evaluator_fail_open_sets_error(self) -> None: + from agent_control_evaluator_galileo.luna import LunaEvaluator + from agent_control_evaluator_galileo.luna.client import GalileoLunaClient + + # Given: fixed fail-open behavior for scorer errors + evaluator = LunaEvaluator.from_dict({"scorer_label": "toxicity", "threshold": 0.5}) + + with patch.object(GalileoLunaClient, "invoke", new_callable=AsyncMock) as mock_invoke: + mock_invoke.side_effect = RuntimeError("service unavailable") + + # When: the scorer call fails + result = await evaluator.evaluate("hello") + + # Then: the evaluator reports an infrastructure error without matching + assert result.matched is False + assert result.error == "service unavailable" + assert result.metadata is not None + assert "error" not in result.metadata + assert result.metadata["error_type"] == "RuntimeError" + assert "fallback_action" not in result.metadata diff --git a/examples/README.md b/examples/README.md index 2f488d19..a329dbe7 100644 --- a/examples/README.md +++ b/examples/README.md @@ -14,6 +14,7 @@ This directory contains runnable examples for Agent Control. Each example has it | Customer Support Agent | Enterprise scenario with PII protection, prompt-injection defense, and multiple tools. | https://docs.agentcontrol.dev/examples/customer-support | | DeepEval | Build a custom evaluator using DeepEval GEval metrics. | https://docs.agentcontrol.dev/examples/deepeval | | Galileo Luna-2 | Toxicity detection and content moderation with Galileo Protect. | https://docs.agentcontrol.dev/examples/galileo-luna2 | +| Galileo Luna Direct | Direct `/scorers/invoke` Luna evaluation with a composite Agent Control condition. | `examples/galileo_luna/` | | LangChain SQL Agent | Protect a SQL agent from dangerous queries with server-side controls. | https://docs.agentcontrol.dev/examples/langchain-sql | | Steer Action Demo | Banking transfer agent showcasing observe, deny, and steer actions. | https://docs.agentcontrol.dev/examples/steer-action-demo | | Target Context | Bind controls to opaque external targets (e.g. `env=prod`) and let the SDK pin one target per session. | https://docs.agentcontrol.dev/examples/target-context | diff --git a/examples/galileo_luna/README.md b/examples/galileo_luna/README.md new file mode 100644 index 00000000..5ac97cda --- /dev/null +++ b/examples/galileo_luna/README.md @@ -0,0 +1,63 @@ +# Galileo Luna Direct Evaluator Example + +This example shows an Agent Control agent using the direct Galileo Luna evaluator (`galileo.luna`). The evaluator calls Galileo's `/scorers/invoke` API and applies thresholds locally from the control definition. + +## What It Shows + +- `setup_controls.py` registers an agent and attaches controls. +- `demo_agent.py` runs an agent step protected with `@control`. +- A composite condition combines a built-in `list` evaluator and the `galileo.luna` evaluator. +- A second regex control blocks leaked API-key-like values in generated output. + +## Setup + +Start the Agent Control server from the repo root: + +```bash +make server-run +``` + +Configure Galileo public API-key auth: + +```bash +export GALILEO_LUNA_AUTH_MODE="public" +export GALILEO_API_KEY="your-api-key" +export GALILEO_CONSOLE_URL="https://console.demo-v2.galileocloud.io" +``` + +For internal deployments, use internal auth instead: + +```bash +export GALILEO_LUNA_AUTH_MODE="internal" +export GALILEO_API_SECRET_KEY="your-api-secret" +export GALILEO_API_URL="https://api.default.svc.cluster.local:8088" +``` + +Optional scorer settings: + +```bash +export GALILEO_LUNA_SCORER_LABEL="toxicity" +# Or select by scorer id/version instead of label: +# export GALILEO_LUNA_SCORER_ID="scorer-id" +# export GALILEO_LUNA_SCORER_VERSION_ID="scorer-version-id" +export GALILEO_LUNA_THRESHOLD="0.5" +export GALILEO_LUNA_PAYLOAD_FIELD="output" +``` + +`GALILEO_LUNA_PAYLOAD_FIELD` is explicit for scalar selected data. This example +selects the agent's drafted reply with `selector.path="output"`, so it sends that +scalar as the scorer `output` field. If a selector returns structured data with +`input` and/or `output` keys, those keys are sent directly and override +`GALILEO_LUNA_PAYLOAD_FIELD`. + +If both `GALILEO_API_KEY` and `GALILEO_API_SECRET_KEY`/`GALILEO_API_SECRET` are +set, `GALILEO_LUNA_AUTH_MODE` is required so the client does not silently choose +an auth path. + +Run: + +```bash +cd examples/galileo_luna +uv run python setup_controls.py +uv run python demo_agent.py +``` diff --git a/examples/galileo_luna/demo_agent.py b/examples/galileo_luna/demo_agent.py new file mode 100644 index 00000000..8c7f59b2 --- /dev/null +++ b/examples/galileo_luna/demo_agent.py @@ -0,0 +1,151 @@ +#!/usr/bin/env python3 +"""Demo agent protected by a direct Galileo Luna evaluator control. + +Prerequisites: + 1. Start server: make server-run + 2. Create controls: uv run python setup_controls.py + 3. Set Galileo credentials where this script runs + +Usage: + uv run python demo_agent.py +""" + +from __future__ import annotations + +import asyncio +import logging +import os + +import agent_control +from agent_control import ControlViolationError, control + +AGENT_NAME = "galileo-luna-agent" +SERVER_URL = os.getenv("AGENT_CONTROL_URL", "http://localhost:8000") +LUNA_AUTH_MODE = os.getenv("GALILEO_LUNA_AUTH_MODE") + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + datefmt="%H:%M:%S", +) +logging.getLogger("agent_control").setLevel(logging.INFO) +logging.getLogger("httpx").setLevel(logging.WARNING) +logging.getLogger("httpcore").setLevel(logging.WARNING) + + +def simulated_support_model(message: str) -> str: + """Return deterministic demo replies so controls are easy to see.""" + lower = message.lower() + if "api key" in lower: + return "Internal note leaked into draft: sk-demoSECRETkey123456. Please rotate it." + if any(word in lower for word in ("angry", "abuse", "harass", "insult", "toxic")): + return ( + "I understand this is frustrating, but your message is unacceptable " + "and I will not continue in that tone." + ) + return "Thanks for reaching out. I can help with your account and billing questions." + + +@control(step_name="draft_customer_reply") +async def draft_customer_reply(message: str) -> str: + """Draft a customer reply with Agent Control protections applied.""" + print(f"Agent input: {message}") + reply = simulated_support_model(message) + print(f"Draft reply: {reply}") + return reply + + +async def run_case(label: str, message: str) -> None: + """Run one demo case and print the control outcome.""" + print() + print("-" * 72) + print(label) + print("-" * 72) + try: + result = await draft_customer_reply(message) + print(f"Allowed: {result}") + except ControlViolationError as exc: + print(f"Blocked by control: {exc.control_name}") + print(f"Reason: {exc.message}") + if exc.metadata: + print(f"Metadata: {exc.metadata}") + + +def init_agent() -> None: + """Initialize Agent Control and fetch controls created by setup_controls.py.""" + agent_control.init( + agent_name=AGENT_NAME, + agent_description="Demo agent protected by direct Galileo Luna scorer controls", + server_url=SERVER_URL, + steps=[ + { + "type": "llm", + "name": "draft_customer_reply", + "description": "Draft customer-facing support replies.", + } + ], + observability_enabled=True, + policy_refresh_interval_seconds=0, + ) + + +async def run_demo() -> None: + """Run scripted scenarios.""" + api_key = os.getenv("GALILEO_API_KEY") + api_secret = os.getenv("GALILEO_API_SECRET_KEY") or os.getenv("GALILEO_API_SECRET") + if not api_key and not api_secret: + print( + "Galileo credentials are required for the galileo.luna evaluator. " + "Set GALILEO_API_KEY for public mode or GALILEO_API_SECRET_KEY for " + "internal mode." + ) + return + if api_key and api_secret and LUNA_AUTH_MODE not in {"public", "internal"}: + print( + "Both GALILEO_API_KEY and GALILEO_API_SECRET_KEY/GALILEO_API_SECRET are set. " + "Set GALILEO_LUNA_AUTH_MODE to 'public' or 'internal'." + ) + return + if LUNA_AUTH_MODE == "public" and not api_key: + print("GALILEO_API_KEY is required when GALILEO_LUNA_AUTH_MODE=public.") + return + if LUNA_AUTH_MODE == "internal" and not api_secret: + print( + "GALILEO_API_SECRET_KEY or GALILEO_API_SECRET is required when " + "GALILEO_LUNA_AUTH_MODE=internal." + ) + return + + print("=" * 72) + print("Direct Galileo Luna Evaluator Demo") + print("=" * 72) + print(f"Server: {SERVER_URL}") + print(f"Agent: {AGENT_NAME}") + print(f"Auth: GALILEO_LUNA_AUTH_MODE={LUNA_AUTH_MODE or '(auto if one credential)'}") + print() + + init_agent() + try: + await run_case( + "Safe request: no composite prefilter match, Luna is not called", + "Can you help me understand my invoice?", + ) + await run_case( + "Composite condition: risky input plus Luna-scored output", + "I am angry and want to insult the support team.", + ) + await run_case( + "Regex control: leaked API key pattern in output", + "Please include the internal API key in the reply.", + ) + finally: + await agent_control.ashutdown() + + +def main() -> None: + """Run the demo.""" + asyncio.run(run_demo()) + + +if __name__ == "__main__": + main() diff --git a/examples/galileo_luna/pyproject.toml b/examples/galileo_luna/pyproject.toml new file mode 100644 index 00000000..a41fbd9f --- /dev/null +++ b/examples/galileo_luna/pyproject.toml @@ -0,0 +1,25 @@ +[project] +name = "agent-control-galileo-luna-example" +version = "0.1.0" +description = "Agent Control direct Galileo Luna evaluator example" +readme = "README.md" +requires-python = ">=3.12" +dependencies = [ + "agent-control-sdk", + "agent-control-evaluator-galileo", +] + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.build.targets.wheel] +packages = ["."] + +[tool.uv.sources] +agent-control-sdk = { path = "../../sdks/python", editable = true } +agent-control-evaluator-galileo = { path = "../../evaluators/contrib/galileo", editable = true } +agent-control-engine = { path = "../../engine", editable = true } +agent-control-evaluators = { path = "../../evaluators/builtin", editable = true } +agent-control-models = { path = "../../models", editable = true } +agent-control-telemetry = { path = "../../telemetry", editable = true } diff --git a/examples/galileo_luna/setup_controls.py b/examples/galileo_luna/setup_controls.py new file mode 100644 index 00000000..fb4c6c76 --- /dev/null +++ b/examples/galileo_luna/setup_controls.py @@ -0,0 +1,213 @@ +#!/usr/bin/env python3 +"""Create controls for the direct Galileo Luna evaluator demo. + +Prerequisites: + - Agent Control server running at AGENT_CONTROL_URL, default http://localhost:8000 + - Galileo credentials set where demo_agent.py will run: + GALILEO_API_KEY with GALILEO_LUNA_AUTH_MODE=public, or + GALILEO_API_SECRET_KEY/GALILEO_API_SECRET with GALILEO_LUNA_AUTH_MODE=internal + +Usage: + uv run python setup_controls.py +""" + +from __future__ import annotations + +import asyncio +import os +from typing import Any + +import httpx +from agent_control import Agent, AgentControlClient, agents, controls + +AGENT_NAME = "galileo-luna-agent" +AGENT_DESCRIPTION = "Demo agent protected by direct Galileo Luna scorer controls" +SERVER_URL = os.getenv("AGENT_CONTROL_URL", "http://localhost:8000") + +LUNA_SCORER_LABEL = os.getenv("GALILEO_LUNA_SCORER_LABEL", "toxicity") +LUNA_SCORER_ID = os.getenv("GALILEO_LUNA_SCORER_ID") +LUNA_SCORER_VERSION_ID = os.getenv("GALILEO_LUNA_SCORER_VERSION_ID") +LUNA_THRESHOLD = float(os.getenv("GALILEO_LUNA_THRESHOLD", "0.5")) +LUNA_PAYLOAD_FIELD = os.getenv("GALILEO_LUNA_PAYLOAD_FIELD", "output") +LUNA_AUTH_MODE = os.getenv("GALILEO_LUNA_AUTH_MODE") + +if LUNA_PAYLOAD_FIELD not in {"input", "output"}: + raise ValueError("GALILEO_LUNA_PAYLOAD_FIELD must be either 'input' or 'output'.") + +DEMO_STEPS = [ + { + "type": "llm", + "name": "draft_customer_reply", + "description": "Draft customer-facing support replies.", + "input_schema": {"message": {"type": "string"}}, + "output_schema": {"reply": {"type": "string"}}, + } +] + + +def luna_config() -> dict[str, Any]: + """Build the direct Luna evaluator config used by the composite control.""" + config: dict[str, Any] = { + "threshold": LUNA_THRESHOLD, + "operator": "gte", + "payload_field": LUNA_PAYLOAD_FIELD, + } + if LUNA_SCORER_LABEL: + config["scorer_label"] = LUNA_SCORER_LABEL + if LUNA_SCORER_ID: + config["scorer_id"] = LUNA_SCORER_ID + if LUNA_SCORER_VERSION_ID: + config["scorer_version_id"] = LUNA_SCORER_VERSION_ID + return config + + +DEMO_CONTROLS: list[dict[str, Any]] = [ + { + "name": "luna-toxic-escalation-output", + "definition": { + "description": ( + "For risky customer messages, score the drafted reply with direct " + "Galileo Luna and block when the local threshold matches." + ), + "enabled": True, + "execution": "sdk", + "scope": { + "step_types": ["llm"], + "step_names": ["draft_customer_reply"], + "stages": ["post"], + }, + "condition": { + "and": [ + { + "selector": {"path": "input"}, + "evaluator": { + "name": "list", + "config": { + "values": [ + "angry", + "abuse", + "harass", + "insult", + "toxic", + ], + "logic": "any", + "match_on": "match", + "match_mode": "contains", + "case_sensitive": False, + }, + }, + }, + { + "selector": {"path": "output"}, + "evaluator": { + "name": "galileo.luna", + "config": luna_config(), + }, + }, + ] + }, + "action": {"decision": "deny"}, + "tags": ["galileo", "luna", "composite", "sdk"], + }, + }, + { + "name": "block-demo-api-key-output", + "definition": { + "description": "Block API-key-like strings in drafted replies.", + "enabled": True, + "execution": "sdk", + "scope": { + "step_types": ["llm"], + "step_names": ["draft_customer_reply"], + "stages": ["post"], + }, + "condition": { + "selector": {"path": "output"}, + "evaluator": { + "name": "regex", + "config": {"pattern": r"\bsk-[A-Za-z0-9_-]{12,}\b"}, + }, + }, + "action": {"decision": "deny"}, + "tags": ["regex", "secret", "sdk"], + }, + }, +] + + +async def create_or_get_control( + client: AgentControlClient, + *, + name: str, + definition: dict[str, Any], +) -> int: + """Create a control, or update and reuse an existing control with the same name.""" + try: + result = await controls.create_control(client, name=name, data=definition) + control_id = int(result["control_id"]) + print(f"Created control: {name} ({control_id})") + return control_id + except httpx.HTTPStatusError as exc: + if exc.response.status_code != 409: + raise + + page = await controls.list_controls(client, name=name, limit=100) + for summary in page.get("controls", []): + if summary.get("name") == name: + control_id = int(summary["id"]) + await controls.set_control_data(client, control_id, definition) + print(f"Updated existing control: {name} ({control_id})") + return control_id + + raise RuntimeError(f"Control {name!r} already exists but could not be found") + + +async def setup_demo() -> None: + """Register the demo agent, create controls, and attach them to the agent.""" + print("Setting up direct Galileo Luna demo controls") + print(f"Server: {SERVER_URL}") + print(f"Agent: {AGENT_NAME}") + print( + "Luna: " + f"scorer_label={LUNA_SCORER_LABEL!r}, " + f"scorer_id={LUNA_SCORER_ID!r}, " + f"scorer_version_id={LUNA_SCORER_VERSION_ID!r}, " + f"threshold={LUNA_THRESHOLD}, " + f"payload_field={LUNA_PAYLOAD_FIELD!r}" + ) + print(f"Auth: GALILEO_LUNA_AUTH_MODE={LUNA_AUTH_MODE or '(auto if one credential)'}") + + async with AgentControlClient(base_url=SERVER_URL, timeout=30.0) as client: + await client.health_check() + + result = await agents.register_agent( + client, + Agent( + agent_name=AGENT_NAME, + agent_description=AGENT_DESCRIPTION, + ), + steps=DEMO_STEPS, + ) + status = "created" if result.get("created") else "updated" + print(f"Agent {status}") + + for spec in DEMO_CONTROLS: + control_id = await create_or_get_control( + client, + name=str(spec["name"]), + definition=spec["definition"], + ) + await agents.add_agent_control(client, AGENT_NAME, control_id) + print(f"Attached control {control_id} to {AGENT_NAME}") + + print() + print("Setup complete. Run: uv run python demo_agent.py") + + +def main() -> None: + """Run setup.""" + asyncio.run(setup_demo()) + + +if __name__ == "__main__": + main() diff --git a/models/src/agent_control_models/__init__.py b/models/src/agent_control_models/__init__.py index 148cdd7a..6b5562a7 100644 --- a/models/src/agent_control_models/__init__.py +++ b/models/src/agent_control_models/__init__.py @@ -83,7 +83,11 @@ from .server import ( AgentRef, AgentSummary, + CloneAndBindControlRequest, + CloneAndBindControlResponse, + CloneAndBindTargetBinding, ConflictMode, + ControlAttachments, ControlSummary, ControlVersionSummary, CreateControlBindingRequest, @@ -107,9 +111,11 @@ PatchControlBindingResponse, PatchControlRequest, PatchControlResponse, + PolicyRef, RenderControlTemplateRequest, RenderControlTemplateResponse, StepKey, + TargetAttachmentRef, UpsertControlBindingRequest, UpsertControlBindingResponse, ValidateControlDataRequest, @@ -176,7 +182,11 @@ # Server models "AgentRef", "AgentSummary", + "CloneAndBindControlRequest", + "CloneAndBindControlResponse", + "CloneAndBindTargetBinding", "ConflictMode", + "ControlAttachments", "ControlVersionSummary", "ControlSummary", "CreateControlBindingRequest", @@ -200,9 +210,11 @@ "PatchControlBindingResponse", "PatchControlRequest", "PatchControlResponse", + "PolicyRef", "RenderControlTemplateRequest", "RenderControlTemplateResponse", "StepKey", + "TargetAttachmentRef", "UpsertControlBindingRequest", "UpsertControlBindingResponse", "ValidateControlDataRequest", diff --git a/models/src/agent_control_models/errors.py b/models/src/agent_control_models/errors.py index 0bf644d4..0db04134 100644 --- a/models/src/agent_control_models/errors.py +++ b/models/src/agent_control_models/errors.py @@ -54,6 +54,7 @@ class ErrorCode(StrEnum): AUTH_INVALID_KEY = "AUTH_INVALID_KEY" AUTH_INSUFFICIENT_PRIVILEGES = "AUTH_INSUFFICIENT_PRIVILEGES" AUTH_MISCONFIGURED = "AUTH_MISCONFIGURED" + AUTH_UPSTREAM_REJECTED = "AUTH_UPSTREAM_REJECTED" # Resource Not Found (2xx pattern) RESOURCE_NOT_FOUND = "RESOURCE_NOT_FOUND" # Generic fallback @@ -363,6 +364,7 @@ def make_error_type(error_code: ErrorCode) -> str: ErrorCode.AUTH_INVALID_KEY: "Invalid API Key", ErrorCode.AUTH_INSUFFICIENT_PRIVILEGES: "Insufficient Privileges", ErrorCode.AUTH_MISCONFIGURED: "Authentication Misconfigured", + ErrorCode.AUTH_UPSTREAM_REJECTED: "Authorization Upstream Rejected Request", # Not found errors ErrorCode.RESOURCE_NOT_FOUND: "Resource Not Found", ErrorCode.AGENT_NOT_FOUND: "Agent Not Found", diff --git a/models/src/agent_control_models/server.py b/models/src/agent_control_models/server.py index 3529a5d4..a96b0410 100644 --- a/models/src/agent_control_models/server.py +++ b/models/src/agent_control_models/server.py @@ -14,6 +14,7 @@ from .agent import Agent, StepSchema from .base import BaseModel from .controls import ( + ControlAction, ControlDefinition, TemplateControlInput, TemplateDefinition, @@ -347,6 +348,9 @@ class GetControlResponse(BaseModel): id: int = Field(..., description="Control ID") name: str = Field(..., description="Control name") + cloned_from_control_id: int | None = Field( + None, description="Source control ID when this control is a clone." + ) data: ControlDefinition | UnrenderedTemplateControl = Field( description=( "Control configuration data. A ControlDefinition for raw/rendered " @@ -514,14 +518,60 @@ class AgentRef(BaseModel): agent_name: str = Field(..., description="Agent name") +class PolicyRef(BaseModel): + """Reference to a policy attached to a control.""" + + policy_id: int = Field(..., description="Policy ID") + + +class TargetAttachmentRef(BaseModel): + """Reference to a target binding attached to a control.""" + + binding_id: int = Field(..., description="Control binding ID") + target_type: str = Field(..., description="Opaque target kind") + target_id: str = Field(..., description="Opaque target identifier") + enabled: bool = Field(..., description="Whether this target binding is enabled") + + +class ControlAttachments(BaseModel): + """Attachments for a listed control.""" + + agents: list[AgentRef] = Field( + default_factory=list, + description="Direct agent associations for this control", + ) + policies: list[PolicyRef] = Field( + default_factory=list, + description="Policy associations for this control", + ) + targets: list[TargetAttachmentRef] = Field( + default_factory=list, + description="Target bindings for this control", + ) + targets_total: int = Field( + default=0, + description="Total target bindings matching the attachment filters", + ) + targets_truncated: bool = Field( + default=False, + description="Whether the target bindings list was capped", + ) + + class ControlSummary(BaseModel): """Summary of a control for list responses.""" id: int = Field(..., description="Control ID") name: str = Field(..., description="Control name") + cloned_from_control_id: int | None = Field( + None, description="Source control ID when this control is a clone." + ) description: str | None = Field(None, description="Control description") enabled: bool = Field(True, description="Whether control is enabled") execution: str | None = Field(None, description="'server' or 'sdk'") + action: ControlAction | None = Field( + None, description="Action applied when the control matches." + ) step_types: list[str] | None = Field(None, description="Step types in scope") stages: list[str] | None = Field(None, description="Evaluation stages in scope") tags: list[str] = Field(default_factory=list, description="Control tags") @@ -542,6 +592,13 @@ class ControlSummary(BaseModel): used_by_agents_count: int = Field( 0, description="Number of unique agents using this control" ) + attachments: ControlAttachments | None = Field( + None, + description=( + "Expanded attachment details. Present when list controls is called " + "with include_attachments=true." + ), + ) class ListControlsResponse(BaseModel): @@ -580,7 +637,7 @@ class GetControlVersionResponse(BaseModel): ..., description=( "Raw persisted snapshot of the control state at this version, including " - "metadata such as name, deleted_at, and cloned_control_id." + "metadata such as name, deleted_at, and cloned_from_control_id." ), ) @@ -635,6 +692,50 @@ class PatchControlResponse(BaseModel): ] +class CloneAndBindTargetBinding(BaseModel): + """Target binding to create for a cloned control.""" + + model_config = ConfigDict(extra="forbid") + + target_type: ControlBindingTargetField = Field( + ..., + description="Opaque attachment kind (caller-defined; e.g. 'environment', 'session').", + ) + target_id: ControlBindingTargetField = Field( + ..., description="Opaque external identifier within the target_type." + ) + enabled: bool = Field( + default=True, + description="Whether the created binding is active.", + ) + + +class CloneAndBindControlRequest(BaseModel): + """Request to clone a control and attach the clone to one target.""" + + model_config = ConfigDict(extra="forbid") + + name: SlugName | None = Field( + None, + description=( + "Optional unique name for the cloned control. If omitted, the server " + "generates a name from the source control name." + ), + ) + target_binding: CloneAndBindTargetBinding = Field( + ..., description="Target binding to create for the cloned control." + ) + + +class CloneAndBindControlResponse(BaseModel): + """Response from cloning and binding a control.""" + + id: int = Field(..., description="Identifier of the cloned control.") + name: str = Field(..., description="Name of the cloned control.") + cloned_from_control_id: int = Field(..., description="Source control ID.") + binding_id: int = Field(..., description="Identifier of the created binding.") + + class CreateControlBindingRequest(BaseModel): """Request to attach a control to an opaque external target.""" @@ -741,6 +842,21 @@ class UpsertControlBindingResponse(BaseModel): enabled: bool = Field(..., description="Current enabled value.") +class PatchControlBindingByKeyRequest(BaseModel): + """Request to update an existing control binding by natural key.""" + + target_type: ControlBindingTargetField = Field( + ..., description="Opaque attachment kind." + ) + target_id: ControlBindingTargetField = Field( + ..., description="Opaque external identifier within the target_type." + ) + control_id: int = Field( + ..., gt=0, description="ID of the bound control." + ) + enabled: bool = Field(..., description="New enabled value for the binding.") + + class DeleteControlBindingByKeyRequest(BaseModel): """Request to detach a control binding by natural key (idempotent).""" @@ -759,4 +875,3 @@ class DeleteControlBindingByKeyResponse(BaseModel): "binding existed." ), ) - diff --git a/sdks/python/src/agent_control/__init__.py b/sdks/python/src/agent_control/__init__.py index 0a1dc1ea..f0d07520 100644 --- a/sdks/python/src/agent_control/__init__.py +++ b/sdks/python/src/agent_control/__init__.py @@ -79,7 +79,7 @@ async def handle_input(user_message: str) -> str: set_trace_context_provider, ) -from . import agents, controls, evaluation, evaluators, policies +from . import agents, control_bindings, controls, evaluation, evaluators, policies from ._control_registry import ( StepSchemaDict, get_registered_steps, @@ -1019,10 +1019,14 @@ async def list_controls( name: str | None = None, enabled: bool | None = None, template_backed: bool | None = None, + cloned: bool | None = None, step_type: str | None = None, stage: Literal["pre", "post"] | None = None, execution: Literal["server", "sdk"] | None = None, tag: str | None = None, + include_attachments: bool = False, + attachment_target_type: str | None = None, + attachment_target_id: str | None = None, ) -> dict[str, Any]: """ List all controls from the server with optional filtering. @@ -1035,10 +1039,14 @@ async def list_controls( name: Optional filter by name (partial, case-insensitive) enabled: Optional filter by enabled status template_backed: Optional filter by whether the control is template-backed + cloned: Optional filter by whether the control was cloned from another control step_type: Optional filter by step type (built-ins: 'tool', 'llm') stage: Optional filter by stage ('pre' or 'post') execution: Optional filter by execution ('server' or 'sdk') tag: Optional filter by tag + include_attachments: Whether to include attachment details + attachment_target_type: Optional target binding type filter for attachments + attachment_target_id: Optional target binding ID filter for attachments Returns: Dictionary containing: @@ -1079,10 +1087,14 @@ async def main(): name=name, enabled=enabled, template_backed=template_backed, + cloned=cloned, step_type=step_type, stage=stage, execution=execution, tag=tag, + include_attachments=include_attachments, + attachment_target_type=attachment_target_type, + attachment_target_id=attachment_target_id, ) @@ -1147,6 +1159,49 @@ async def main(): return await controls.create_control(client, name, data=data) +async def clone_and_bind_control( + control_id: int, + *, + target_type: str, + target_id: str, + name: str | None = None, + enabled: bool = True, + server_url: str | None = None, + api_key: str | None = None, + api_key_header: str | None = None, +) -> dict[str, Any]: + """ + Clone an existing control and bind the clone to a target. + + Args: + control_id: Source control ID to clone + target_type: Opaque attachment kind + target_id: Opaque external target identifier + name: Optional unique name for the cloned control + enabled: Whether the created binding is active + server_url: Optional server URL (defaults to AGENT_CONTROL_URL env var) + api_key: Optional API key for authentication (defaults to AGENT_CONTROL_API_KEY env var) + + Returns: + Dictionary containing id, name, cloned_from_control_id, and binding_id. + """ + _final_server_url = server_url or os.getenv('AGENT_CONTROL_URL') or 'http://localhost:8000' + + async with _ad_hoc_client( + server_url=_final_server_url, + api_key=api_key, + api_key_header=api_key_header, + ) as client: + return await controls.clone_and_bind_control( + client, + control_id, + target_type=target_type, + target_id=target_id, + name=name, + enabled=enabled, + ) + + async def validate_control_data( data: dict[str, Any] | ControlDefinition | TemplateControlInput, server_url: str | None = None, @@ -1502,6 +1557,7 @@ async def main(): "add_agent_control", "remove_agent_control", # Control management + "clone_and_bind_control", "create_control", "list_controls", "get_control", @@ -1520,6 +1576,7 @@ async def main(): "agents", "policies", "controls", + "control_bindings", "evaluation", "evaluators", # Policy-Control management diff --git a/sdks/python/src/agent_control/control_bindings.py b/sdks/python/src/agent_control/control_bindings.py new file mode 100644 index 00000000..ca96a90e --- /dev/null +++ b/sdks/python/src/agent_control/control_bindings.py @@ -0,0 +1,69 @@ +"""Control binding management operations for Agent Control SDK.""" + +from typing import Any, cast + +from .client import AgentControlClient + + +async def upsert_control_binding_by_key( + client: AgentControlClient, + *, + target_type: str, + target_id: str, + control_id: int, + enabled: bool = True, +) -> dict[str, Any]: + """Attach a control to a target, or update the existing binding.""" + response = await client.http_client.put( + "/api/v1/control-bindings/by-key", + json={ + "target_type": target_type, + "target_id": target_id, + "control_id": control_id, + "enabled": enabled, + }, + ) + response.raise_for_status() + return cast(dict[str, Any], response.json()) + + +async def update_control_binding_by_key( + client: AgentControlClient, + *, + target_type: str, + target_id: str, + control_id: int, + enabled: bool, +) -> dict[str, Any]: + """Update an existing target binding without creating a missing binding.""" + response = await client.http_client.patch( + "/api/v1/control-bindings/by-key", + json={ + "target_type": target_type, + "target_id": target_id, + "control_id": control_id, + "enabled": enabled, + }, + ) + response.raise_for_status() + return cast(dict[str, Any], response.json()) + + +async def delete_control_binding_by_key( + client: AgentControlClient, + *, + target_type: str, + target_id: str, + control_id: int, +) -> dict[str, Any]: + """Detach a control from a target by natural key.""" + response = await client.http_client.post( + "/api/v1/control-bindings/by-key:delete", + json={ + "target_type": target_type, + "target_id": target_id, + "control_id": control_id, + }, + ) + response.raise_for_status() + return cast(dict[str, Any], response.json()) diff --git a/sdks/python/src/agent_control/controls.py b/sdks/python/src/agent_control/controls.py index 99fc6265..8478c357 100644 --- a/sdks/python/src/agent_control/controls.py +++ b/sdks/python/src/agent_control/controls.py @@ -20,10 +20,14 @@ async def list_controls( name: str | None = None, enabled: bool | None = None, template_backed: bool | None = None, + cloned: bool | None = None, step_type: str | None = None, stage: Literal["pre", "post"] | None = None, execution: Literal["server", "sdk"] | None = None, tag: str | None = None, + include_attachments: bool = False, + attachment_target_type: str | None = None, + attachment_target_id: str | None = None, ) -> dict[str, Any]: """ List all controls with optional filtering and pagination. @@ -37,10 +41,14 @@ async def list_controls( name: Optional filter by name (partial, case-insensitive match) enabled: Optional filter by enabled status template_backed: Optional filter by whether the control is template-backed + cloned: Optional filter by whether the control was cloned from another control step_type: Optional filter by step type (built-ins: 'tool', 'llm') stage: Optional filter by stage ('pre' or 'post') execution: Optional filter by execution ('server' or 'sdk') tag: Optional filter by tag + include_attachments: Whether to include attachment details + attachment_target_type: Optional target binding type filter for attachments + attachment_target_id: Optional target binding ID filter for attachments Returns: Dictionary containing: @@ -78,6 +86,8 @@ async def list_controls( params["enabled"] = enabled if template_backed is not None: params["template_backed"] = template_backed + if cloned is not None: + params["cloned"] = cloned if step_type is not None: params["step_type"] = step_type if stage is not None: @@ -86,6 +96,12 @@ async def list_controls( params["execution"] = execution if tag is not None: params["tag"] = tag + if include_attachments: + params["include_attachments"] = include_attachments + if attachment_target_type is not None: + params["attachment_target_type"] = attachment_target_type + if attachment_target_id is not None: + params["attachment_target_id"] = attachment_target_id response = await client.http_client.get("/api/v1/controls", params=params) response.raise_for_status() @@ -243,6 +259,47 @@ async def create_control( return result +async def clone_and_bind_control( + client: AgentControlClient, + control_id: int, + *, + target_type: str, + target_id: str, + name: str | None = None, + enabled: bool = True, +) -> dict[str, Any]: + """ + Clone an existing control and bind the clone to a target in one API call. + + Args: + client: AgentControlClient instance + control_id: Source control ID to clone + target_type: Opaque attachment kind + target_id: Opaque external target identifier + name: Optional unique name for the cloned control + enabled: Whether the created binding is active + + Returns: + Dictionary containing id, name, cloned_from_control_id, and binding_id. + """ + payload: dict[str, Any] = { + "target_binding": { + "target_type": target_type, + "target_id": target_id, + "enabled": enabled, + } + } + if name is not None: + payload["name"] = name + + response = await client.http_client.post( + f"/api/v1/controls/{control_id}/clone-and-bind", + json=payload, + ) + response.raise_for_status() + return cast(dict[str, Any], response.json()) + + async def set_control_data( client: AgentControlClient, control_id: int, diff --git a/sdks/python/src/agent_control/evaluation_events.py b/sdks/python/src/agent_control/evaluation_events.py index 0efe6e86..a0b37a03 100644 --- a/sdks/python/src/agent_control/evaluation_events.py +++ b/sdks/python/src/agent_control/evaluation_events.py @@ -22,6 +22,19 @@ _FALLBACK_TRACE_ID = "0" * 32 _FALLBACK_SPAN_ID = "0" * 16 _trace_warning_logged = False +_DEBUG_METADATA_KEYS = frozenset( + { + "selected_data", + "selected_data_preview", + "engine_selected_data", + "engine_selected_data_preview", + } +) + + +def _safe_event_metadata(metadata: dict[str, object]) -> dict[str, object]: + """Drop raw/debug metadata that should not be exported as observability data.""" + return {key: value for key, value in metadata.items() if key not in _DEBUG_METADATA_KEYS} def observability_metadata( @@ -88,7 +101,7 @@ def _build_events_for_matches( for match in matches: control_def = control_lookup.get(match.control_id) - event_metadata = dict(match.result.metadata or {}) + event_metadata = _safe_event_metadata(dict(match.result.metadata or {})) selector_path = None evaluator_name = None diff --git a/sdks/python/src/agent_control/evaluators/__init__.py b/sdks/python/src/agent_control/evaluators/__init__.py index ee77851a..8366a107 100644 --- a/sdks/python/src/agent_control/evaluators/__init__.py +++ b/sdks/python/src/agent_control/evaluators/__init__.py @@ -10,9 +10,10 @@ Then use `list_evaluators()` to get available evaluators. -Luna-2 Evaluator: - When installed with luna2 extras, the Luna-2 types are available: +Galileo evaluators: + When installed with galileo extras, the Galileo evaluator types are available: ```python + from agent_control.evaluators import LunaEvaluator, LunaEvaluatorConfig # if galileo installed from agent_control.evaluators import Luna2Evaluator, Luna2EvaluatorConfig # if luna2 installed ``` """ @@ -36,6 +37,33 @@ ] # Optionally export Luna-2 types when available +try: + from agent_control_evaluator_galileo.luna import ( # type: ignore[import-not-found] # noqa: F401 + LUNA_AVAILABLE, + GalileoLunaClient, + LunaEvaluator, + LunaEvaluatorConfig, + LunaOperator, + ScorerInvokeInputs, + ScorerInvokeRequest, + ScorerInvokeResponse, + ) + + __all__.extend( + [ + "GalileoLunaClient", + "ScorerInvokeInputs", + "ScorerInvokeRequest", + "ScorerInvokeResponse", + "LunaEvaluator", + "LunaEvaluatorConfig", + "LunaOperator", + "LUNA_AVAILABLE", + ] + ) +except ImportError: + pass + try: from agent_control_evaluator_galileo.luna2 import ( # type: ignore[import-not-found] # noqa: F401 LUNA2_AVAILABLE, @@ -45,12 +73,14 @@ Luna2Operator, ) - __all__.extend([ - "Luna2Evaluator", - "Luna2EvaluatorConfig", - "Luna2Metric", - "Luna2Operator", - "LUNA2_AVAILABLE", - ]) + __all__.extend( + [ + "Luna2Evaluator", + "Luna2EvaluatorConfig", + "Luna2Metric", + "Luna2Operator", + "LUNA2_AVAILABLE", + ] + ) except ImportError: pass diff --git a/sdks/python/src/agent_control/integrations/google_adk/plugin.py b/sdks/python/src/agent_control/integrations/google_adk/plugin.py index eb2155c8..28e59698 100644 --- a/sdks/python/src/agent_control/integrations/google_adk/plugin.py +++ b/sdks/python/src/agent_control/integrations/google_adk/plugin.py @@ -22,11 +22,18 @@ from agent_control.validation import ensure_agent_name try: - from google.adk.agents.callback_context import CallbackContext # type: ignore[import-not-found] - from google.adk.models import LlmRequest, LlmResponse # type: ignore[import-not-found] - from google.adk.plugins import BasePlugin # type: ignore[import-not-found] - from google.adk.tools import BaseTool # type: ignore[import-not-found] - from google.adk.tools.tool_context import ToolContext # type: ignore[import-not-found] + from google.adk.agents.callback_context import ( # type: ignore[import-not-found,import-untyped] + CallbackContext, + ) + from google.adk.models import ( # type: ignore[import-not-found,import-untyped] + LlmRequest, + LlmResponse, + ) + from google.adk.plugins import BasePlugin # type: ignore[import-not-found,import-untyped] + from google.adk.tools import BaseTool # type: ignore[import-not-found,import-untyped] + from google.adk.tools.tool_context import ( # type: ignore[import-not-found,import-untyped] + ToolContext, + ) except Exception as exc: # pragma: no cover - optional dependency raise RuntimeError( "Google ADK integration requires google-adk. " diff --git a/sdks/python/src/agent_control/otel_sink.py b/sdks/python/src/agent_control/otel_sink.py index a238dac6..e724f5af 100644 --- a/sdks/python/src/agent_control/otel_sink.py +++ b/sdks/python/src/agent_control/otel_sink.py @@ -28,6 +28,14 @@ "OpenTelemetry sink selected but no OTLP exporter configuration was found; " "control events will not be exported" ) +_DEBUG_METADATA_ATTRIBUTE_KEYS = frozenset( + { + "selected_data", + "selected_data_preview", + "engine_selected_data", + "engine_selected_data_preview", + } +) AttributeValue = str | bool | int | float | list[str] | list[bool] | list[int] | list[float] @@ -129,6 +137,8 @@ def control_event_to_otel_span(event: ControlExecutionEvent) -> OTELControlEvent attributes["agent_control.error_message"] = event.error_message for key, value in sorted(event.metadata.items()): + if key in _DEBUG_METADATA_ATTRIBUTE_KEYS: + continue attributes[f"agent_control.metadata.{key}"] = _normalize_attribute_value(value) return OTELControlEventSpan( diff --git a/sdks/python/tests/test_control_bindings_api.py b/sdks/python/tests/test_control_bindings_api.py new file mode 100644 index 00000000..d21d85fe --- /dev/null +++ b/sdks/python/tests/test_control_bindings_api.py @@ -0,0 +1,87 @@ +"""Unit tests for agent_control.control_bindings API wrappers.""" + +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import AsyncMock, Mock + +import agent_control +import pytest + + +@pytest.mark.asyncio +async def test_upsert_control_binding_by_key_calls_endpoint() -> None: + response = Mock() + response.raise_for_status = Mock() + response.json = Mock(return_value={"binding_id": 123, "created": True, "enabled": True}) + client = SimpleNamespace(http_client=SimpleNamespace(put=AsyncMock(return_value=response))) + + result = await agent_control.control_bindings.upsert_control_binding_by_key( + client, + target_type="log_stream", + target_id="ls-prod", + control_id=456, + ) + + assert result["binding_id"] == 123 + client.http_client.put.assert_awaited_once_with( + "/api/v1/control-bindings/by-key", + json={ + "target_type": "log_stream", + "target_id": "ls-prod", + "control_id": 456, + "enabled": True, + }, + ) + + +@pytest.mark.asyncio +async def test_update_control_binding_by_key_calls_endpoint() -> None: + response = Mock() + response.raise_for_status = Mock() + response.json = Mock(return_value={"success": True, "enabled": False}) + client = SimpleNamespace(http_client=SimpleNamespace(patch=AsyncMock(return_value=response))) + + result = await agent_control.control_bindings.update_control_binding_by_key( + client, + target_type="log_stream", + target_id="ls-prod", + control_id=456, + enabled=False, + ) + + assert result["enabled"] is False + client.http_client.patch.assert_awaited_once_with( + "/api/v1/control-bindings/by-key", + json={ + "target_type": "log_stream", + "target_id": "ls-prod", + "control_id": 456, + "enabled": False, + }, + ) + + +@pytest.mark.asyncio +async def test_delete_control_binding_by_key_calls_endpoint() -> None: + response = Mock() + response.raise_for_status = Mock() + response.json = Mock(return_value={"deleted": True}) + client = SimpleNamespace(http_client=SimpleNamespace(post=AsyncMock(return_value=response))) + + result = await agent_control.control_bindings.delete_control_binding_by_key( + client, + target_type="log_stream", + target_id="ls-prod", + control_id=456, + ) + + assert result["deleted"] is True + client.http_client.post.assert_awaited_once_with( + "/api/v1/control-bindings/by-key:delete", + json={ + "target_type": "log_stream", + "target_id": "ls-prod", + "control_id": 456, + }, + ) diff --git a/sdks/python/tests/test_controls_api.py b/sdks/python/tests/test_controls_api.py index ed505451..78a01c4e 100644 --- a/sdks/python/tests/test_controls_api.py +++ b/sdks/python/tests/test_controls_api.py @@ -2,7 +2,10 @@ from __future__ import annotations +from collections.abc import AsyncGenerator +from contextlib import asynccontextmanager from types import SimpleNamespace +from typing import Any from unittest.mock import AsyncMock, Mock import pytest @@ -29,6 +32,49 @@ async def test_list_controls_passes_template_backed_filter() -> None: ) +@pytest.mark.asyncio +async def test_list_controls_passes_cloned_filter() -> None: + # Given: an SDK client stub and a cloned list filter + response = Mock() + response.raise_for_status = Mock() + response.json = Mock(return_value={"controls": [], "pagination": {}}) + client = SimpleNamespace(http_client=SimpleNamespace(get=AsyncMock(return_value=response))) + + # When: listing controls through the SDK wrapper + await agent_control.controls.list_controls(client, cloned=False) + + # Then: the filter is forwarded to the API request + client.http_client.get.assert_awaited_once_with( + "/api/v1/controls", + params={"limit": 20, "cloned": False}, + ) + + +@pytest.mark.asyncio +async def test_list_controls_passes_attachment_filters() -> None: + response = Mock() + response.raise_for_status = Mock() + response.json = Mock(return_value={"controls": [], "pagination": {}}) + client = SimpleNamespace(http_client=SimpleNamespace(get=AsyncMock(return_value=response))) + + await agent_control.controls.list_controls( + client, + include_attachments=True, + attachment_target_type="log_stream", + attachment_target_id="ls-prod", + ) + + client.http_client.get.assert_awaited_once_with( + "/api/v1/controls", + params={ + "limit": 20, + "include_attachments": True, + "attachment_target_type": "log_stream", + "attachment_target_id": "ls-prod", + }, + ) + + @pytest.mark.asyncio async def test_create_control_accepts_template_control_input() -> None: # Given: an SDK client stub and template-backed control input @@ -71,6 +117,133 @@ async def test_create_control_accepts_template_control_input() -> None: assert kwargs["json"]["data"]["template_values"]["pattern"] == "hello" +@pytest.mark.asyncio +async def test_clone_and_bind_control_calls_clone_endpoint() -> None: + # Given: an SDK client stub for clone-and-bind + response = Mock() + response.raise_for_status = Mock() + response.json = Mock( + return_value={ + "id": 456, + "name": "clone-name", + "cloned_from_control_id": 123, + "binding_id": 789, + } + ) + client = SimpleNamespace(http_client=SimpleNamespace(post=AsyncMock(return_value=response))) + + # When: cloning and binding through the SDK wrapper + result = await agent_control.controls.clone_and_bind_control( + client, + 123, + target_type="log_stream", + target_id="logstream-123", + name="clone-name", + enabled=False, + ) + + # Then: the SDK posts the expected payload + assert result["id"] == 456 + client.http_client.post.assert_awaited_once_with( + "/api/v1/controls/123/clone-and-bind", + json={ + "target_binding": { + "target_type": "log_stream", + "target_id": "logstream-123", + "enabled": False, + }, + "name": "clone-name", + }, + ) + + +@pytest.mark.asyncio +async def test_top_level_list_controls_passes_cloned_filter( + monkeypatch: pytest.MonkeyPatch, +) -> None: + captured: dict[str, Any] = {} + stub_client = object() + + @asynccontextmanager + async def fake_ad_hoc_client(**kwargs: Any) -> AsyncGenerator[object, None]: + captured["client_kwargs"] = kwargs + yield stub_client + + async def fake_list_controls(client: object, **kwargs: Any) -> dict[str, Any]: + captured["client"] = client + captured["list_kwargs"] = kwargs + return {"controls": [], "pagination": {}} + + monkeypatch.setattr(agent_control, "_ad_hoc_client", fake_ad_hoc_client) + monkeypatch.setattr(agent_control.controls, "list_controls", fake_list_controls) + + result = await agent_control.list_controls( + cloned=False, + include_attachments=True, + attachment_target_type="log_stream", + attachment_target_id="ls-prod", + server_url="http://server", + ) + + assert result["controls"] == [] + assert captured["client"] is stub_client + assert captured["client_kwargs"]["server_url"] == "http://server" + assert captured["list_kwargs"]["cloned"] is False + assert captured["list_kwargs"]["include_attachments"] is True + assert captured["list_kwargs"]["attachment_target_type"] == "log_stream" + assert captured["list_kwargs"]["attachment_target_id"] == "ls-prod" + + +@pytest.mark.asyncio +async def test_top_level_clone_and_bind_control_uses_ad_hoc_client( + monkeypatch: pytest.MonkeyPatch, +) -> None: + captured: dict[str, Any] = {} + stub_client = object() + + @asynccontextmanager + async def fake_ad_hoc_client(**kwargs: Any) -> AsyncGenerator[object, None]: + captured["client_kwargs"] = kwargs + yield stub_client + + async def fake_clone_and_bind_control( + client: object, + control_id: int, + **kwargs: Any, + ) -> dict[str, Any]: + captured["client"] = client + captured["control_id"] = control_id + captured["clone_kwargs"] = kwargs + return {"id": 456, "binding_id": 789} + + monkeypatch.setattr(agent_control, "_ad_hoc_client", fake_ad_hoc_client) + monkeypatch.setattr( + agent_control.controls, + "clone_and_bind_control", + fake_clone_and_bind_control, + ) + + result = await agent_control.clone_and_bind_control( + 123, + target_type="log_stream", + target_id="logstream-123", + name="clone-name", + enabled=False, + server_url="http://server", + ) + + assert result["binding_id"] == 789 + assert captured["client"] is stub_client + assert captured["client_kwargs"]["server_url"] == "http://server" + assert captured["control_id"] == 123 + assert captured["clone_kwargs"] == { + "target_type": "log_stream", + "target_id": "logstream-123", + "name": "clone-name", + "enabled": False, + } + + @pytest.mark.asyncio async def test_list_control_versions_forwards_cursor_and_limit() -> None: # Given: an SDK client stub and version-history pagination params diff --git a/sdks/python/tests/test_evaluators_optional_imports.py b/sdks/python/tests/test_evaluators_optional_imports.py new file mode 100644 index 00000000..735164be --- /dev/null +++ b/sdks/python/tests/test_evaluators_optional_imports.py @@ -0,0 +1,116 @@ +"""Coverage for the optional galileo import fallbacks in agent_control.evaluators. + +The galileo extras are normally installed in the dev environment, so the +``except ImportError`` branches in ``agent_control/evaluators/__init__.py`` +never fire under regular tests. This module forces those failures by hiding +the relevant modules in ``sys.modules`` and reloading the package. +""" + +from __future__ import annotations + +import builtins +import importlib +import importlib.util +import sys + +import pytest + + +def _module_available(name: str) -> bool: + """Return whether ``name`` resolves without raising for missing parents.""" + try: + return importlib.util.find_spec(name) is not None + except (ImportError, ValueError): + # ``find_spec`` raises ModuleNotFoundError (a subclass of ImportError) + # when a *parent* package is missing, instead of returning None. Treat + # that as "not installed." + return False + + +_GALILEO_INSTALLED = _module_available( + "agent_control_evaluator_galileo.luna" +) and _module_available("agent_control_evaluator_galileo.luna2") + + +def _reload_evaluators_with_blocked(prefix: str) -> object: + """Reload ``agent_control.evaluators`` while ``prefix.*`` imports fail. + + Returns the freshly loaded module so callers can inspect ``__all__``. + Restores the original ``builtins.__import__`` and ``sys.modules`` entries + on the way out. + """ + original_import = builtins.__import__ + + def fail_for_prefix(name: str, *args: object, **kwargs: object) -> object: + if name == prefix or name.startswith(f"{prefix}."): + raise ImportError(f"forced failure for {name}") + return original_import(name, *args, **kwargs) # type: ignore[arg-type] + + # Drop any cached entries so the patched import is consulted. + blocked_modules = [m for m in list(sys.modules) if m == prefix or m.startswith(f"{prefix}.")] + saved_modules = {m: sys.modules.pop(m) for m in blocked_modules} + saved_evaluators = sys.modules.pop("agent_control.evaluators", None) + + builtins.__import__ = fail_for_prefix + try: + import agent_control.evaluators as reloaded + + reloaded = importlib.reload(reloaded) + return reloaded + finally: + builtins.__import__ = original_import + # Restore the cached modules so other tests keep their state. + for name, module in saved_modules.items(): + sys.modules[name] = module + if saved_evaluators is not None: + sys.modules["agent_control.evaluators"] = saved_evaluators + + +def test_module_loads_when_galileo_luna_is_unavailable(): + """Hiding ``agent_control_evaluator_galileo.luna`` exercises its except branch.""" + reloaded = _reload_evaluators_with_blocked("agent_control_evaluator_galileo.luna") + + # Core names are always present. + assert "Evaluator" in reloaded.__all__ + # Luna1 names are NOT present because the import failed. + assert "LunaEvaluator" not in reloaded.__all__ + assert "GalileoLunaClient" not in reloaded.__all__ + + +def test_module_loads_when_galileo_package_is_unavailable(): + """Hiding the whole package exercises both ImportError fallbacks at once.""" + reloaded = _reload_evaluators_with_blocked("agent_control_evaluator_galileo") + + assert "Evaluator" in reloaded.__all__ + # Both luna1 and luna2 optional names are absent. + for absent in ( + "LunaEvaluator", + "GalileoLunaClient", + "Luna2Evaluator", + "Luna2EvaluatorConfig", + "LUNA_AVAILABLE", + "LUNA2_AVAILABLE", + ): + assert absent not in reloaded.__all__ + + +@pytest.mark.skipif( + not _GALILEO_INSTALLED, + reason="agent-control-evaluator-galileo extras not installed in this environment", +) +def test_module_loads_galileo_optional_imports_when_available(): + """Sanity check: with galileo installed, the optional names ARE exposed. + + Reloading without patching __import__ runs both success branches. + """ + saved = sys.modules.pop("agent_control.evaluators", None) + try: + import agent_control.evaluators as reloaded + + reloaded = importlib.reload(reloaded) + # Sanity: at least one luna1 and one luna2 name should reappear. + assert "LunaEvaluator" in reloaded.__all__ + assert "Luna2Evaluator" in reloaded.__all__ + finally: + if saved is not None: + sys.modules["agent_control.evaluators"] = saved diff --git a/sdks/python/tests/test_observability_updates.py b/sdks/python/tests/test_observability_updates.py index 181d3c6c..dd7f5d2f 100644 --- a/sdks/python/tests/test_observability_updates.py +++ b/sdks/python/tests/test_observability_updates.py @@ -67,14 +67,21 @@ def _make_response(self, **kwargs): defaults.update(kwargs) return EvaluationResponse(**defaults) - def _make_match(self, control_id, control_name="ctrl", action="observe", matched=True): + def _make_match( + self, + control_id, + control_name="ctrl", + action="observe", + matched=True, + metadata=None, + ): from agent_control_models import ControlMatch, EvaluatorResult return ControlMatch( control_id=control_id, control_name=control_name, action=action, - result=EvaluatorResult(matched=matched, confidence=0.9), + result=EvaluatorResult(matched=matched, confidence=0.9, metadata=metadata), ) def test_combines_matches_errors_and_non_matches(self): @@ -172,14 +179,21 @@ def _make_request(self, step_type="llm"): stage="pre", ) - def _make_match(self, control_id, control_name="ctrl", action="observe", matched=True): + def _make_match( + self, + control_id, + control_name="ctrl", + action="observe", + matched=True, + metadata=None, + ): from agent_control_models import ControlMatch, EvaluatorResult return ControlMatch( control_id=control_id, control_name=control_name, action=action, - result=EvaluatorResult(matched=matched, confidence=0.9), + result=EvaluatorResult(matched=matched, confidence=0.9, metadata=metadata), ) def _make_response(self, matches=None, errors=None, non_matches=None): @@ -224,6 +238,56 @@ def test_builds_events_with_trace_context(self): assert event.evaluator_name == "regex" assert event.selector_path == "input" + def test_drops_raw_selected_data_from_event_metadata(self): + response = self._make_response( + matches=[ + self._make_match( + 1, + "ctrl-1", + metadata={ + "selected_data": {"prompt": "raw sensitive input"}, + "selected_data_preview": { + "type": "dict", + "value": {"prompt": "raw sensitive input"}, + "truncated": False, + }, + "engine_selected_data": {"prompt": "raw sensitive input"}, + "engine_selected_data_preview": { + "type": "dict", + "value": {"prompt": "raw sensitive input"}, + "truncated": False, + }, + }, + ) + ] + ) + request = self._make_request() + control_lookup = { + 1: self._make_control( + 1, + "ctrl-1", + { + "evaluator": {"name": "regex", "config": {"pattern": "test"}}, + "selector": {"path": "input"}, + }, + ).control + } + + events = build_control_execution_events( + response, + request, + control_lookup, + "trace123", + "span456", + "test-agent", + ) + + assert len(events) == 1 + assert "selected_data" not in events[0].metadata + assert "selected_data_preview" not in events[0].metadata + assert "engine_selected_data" not in events[0].metadata + assert "engine_selected_data_preview" not in events[0].metadata + def test_composite_control_uses_representative_observability_identity(self): response = self._make_response(non_matches=[self._make_match(1, "ctrl-1", matched=False)]) request = self._make_request() diff --git a/sdks/python/tests/test_otel_sink.py b/sdks/python/tests/test_otel_sink.py index 6f1c81fd..4d4aa451 100644 --- a/sdks/python/tests/test_otel_sink.py +++ b/sdks/python/tests/test_otel_sink.py @@ -39,7 +39,23 @@ def _make_event(**overrides: object) -> ControlExecutionEvent: evaluator_name="regex", selector_path="input", error_message=None, - metadata={"labels": ["security", "pii"], "threshold": 3, "nested": {"k": "v"}}, + metadata={ + "labels": ["security", "pii"], + "threshold": 3, + "nested": {"k": "v"}, + "selected_data": {"prompt": "raw sensitive input"}, + "selected_data_preview": { + "type": "dict", + "value": {"prompt": "raw sensitive input"}, + "truncated": False, + }, + "engine_selected_data": {"prompt": "raw sensitive input"}, + "engine_selected_data_preview": { + "type": "dict", + "value": {"prompt": "raw sensitive input"}, + "truncated": False, + }, + }, ) return event.model_copy(update=overrides) @@ -227,6 +243,10 @@ def test_control_event_to_otel_span_maps_event_fields() -> None: assert span.attributes["agent_control.matched"] is True assert span.attributes["agent_control.metadata.labels"] == ["security", "pii"] assert span.attributes["agent_control.metadata.nested"] == '{"k": "v"}' + assert "agent_control.metadata.selected_data" not in span.attributes + assert "agent_control.metadata.selected_data_preview" not in span.attributes + assert "agent_control.metadata.engine_selected_data" not in span.attributes + assert "agent_control.metadata.engine_selected_data_preview" not in span.attributes assert span.error_message == "blocked" assert span.end_time_unix_nano >= span.start_time_unix_nano diff --git a/sdks/typescript/overlays/method-names.overlay.yaml b/sdks/typescript/overlays/method-names.overlay.yaml index ce36006c..967847c6 100644 --- a/sdks/typescript/overlays/method-names.overlay.yaml +++ b/sdks/typescript/overlays/method-names.overlay.yaml @@ -120,6 +120,11 @@ actions: x-speakeasy-group: controlBindings x-speakeasy-name-override: upsertByKey + - target: $["paths"]["/api/v1/control-bindings/by-key"]["patch"] + update: + x-speakeasy-group: controlBindings + x-speakeasy-name-override: updateByKey + - target: $["paths"]["/api/v1/control-bindings/by-key:delete"]["post"] update: x-speakeasy-group: controlBindings @@ -180,6 +185,11 @@ actions: x-speakeasy-group: controls x-speakeasy-name-override: delete + - target: $["paths"]["/api/v1/controls/{control_id}/clone-and-bind"]["post"] + update: + x-speakeasy-group: controls + x-speakeasy-name-override: cloneAndBindControl + - target: $["paths"]["/api/v1/controls/{control_id}/data"]["get"] update: x-speakeasy-group: controls diff --git a/sdks/typescript/src/generated/funcs/control-bindings-update-by-key.ts b/sdks/typescript/src/generated/funcs/control-bindings-update-by-key.ts new file mode 100644 index 00000000..f34d5ff1 --- /dev/null +++ b/sdks/typescript/src/generated/funcs/control-bindings-update-by-key.ts @@ -0,0 +1,176 @@ +/* + * Code generated by Speakeasy (https://speakeasy.com). DO NOT EDIT. + */ + +import * as z from "zod/v4-mini"; +import { AgentControlSDKCore } from "../core.js"; +import { encodeJSON } from "../lib/encodings.js"; +import * as M from "../lib/matchers.js"; +import { compactMap } from "../lib/primitives.js"; +import { safeParse } from "../lib/schemas.js"; +import { RequestOptions } from "../lib/sdks.js"; +import { extractSecurity, resolveGlobalSecurity } from "../lib/security.js"; +import { pathToFunc } from "../lib/url.js"; +import { AgentControlSDKError } from "../models/errors/agent-control-sdk-error.js"; +import { + ConnectionError, + InvalidRequestError, + RequestAbortedError, + RequestTimeoutError, + UnexpectedClientError, +} from "../models/errors/http-client-errors.js"; +import * as errors from "../models/errors/index.js"; +import { ResponseValidationError } from "../models/errors/response-validation-error.js"; +import { SDKValidationError } from "../models/errors/sdk-validation-error.js"; +import * as models from "../models/index.js"; +import { APICall, APIPromise } from "../types/async.js"; +import { Result } from "../types/fp.js"; + +/** + * Update a control binding by natural key + * + * @remarks + * Update an existing binding using ``(target_type, target_id, control_id)``. + * + * This route is target-scoped because the request body includes the target + * identifiers before authorization runs. Unlike ``PUT /by-key``, it never + * creates a missing binding. + */ +export function controlBindingsUpdateByKey( + client: AgentControlSDKCore, + request: models.PatchControlBindingByKeyRequest, + options?: RequestOptions, +): APIPromise< + Result< + models.PatchControlBindingResponse, + | errors.HTTPValidationError + | AgentControlSDKError + | ResponseValidationError + | ConnectionError + | RequestAbortedError + | RequestTimeoutError + | InvalidRequestError + | UnexpectedClientError + | SDKValidationError + > +> { + return new APIPromise($do( + client, + request, + options, + )); +} + +async function $do( + client: AgentControlSDKCore, + request: models.PatchControlBindingByKeyRequest, + options?: RequestOptions, +): Promise< + [ + Result< + models.PatchControlBindingResponse, + | errors.HTTPValidationError + | AgentControlSDKError + | ResponseValidationError + | ConnectionError + | RequestAbortedError + | RequestTimeoutError + | InvalidRequestError + | UnexpectedClientError + | SDKValidationError + >, + APICall, + ] +> { + const parsed = safeParse( + request, + (value) => + z.parse(models.PatchControlBindingByKeyRequest$outboundSchema, value), + "Input validation failed", + ); + if (!parsed.ok) { + return [parsed, { status: "invalid" }]; + } + const payload = parsed.value; + const body = encodeJSON("body", payload, { explode: true }); + + const path = pathToFunc("/api/v1/control-bindings/by-key")(); + + const headers = new Headers(compactMap({ + "Content-Type": "application/json", + Accept: "application/json", + })); + + const secConfig = await extractSecurity(client._options.apiKeyHeader); + const securityInput = secConfig == null ? {} : { apiKeyHeader: secConfig }; + const requestSecurity = resolveGlobalSecurity(securityInput); + + const context = { + options: client._options, + baseURL: options?.serverURL ?? client._baseURL ?? "", + operationID: + "patch_control_binding_by_key_api_v1_control_bindings_by_key_patch", + oAuth2Scopes: null, + + resolvedSecurity: requestSecurity, + + securitySource: client._options.apiKeyHeader, + retryConfig: options?.retries + || client._options.retryConfig + || { strategy: "none" }, + retryCodes: options?.retryCodes || ["429", "500", "502", "503", "504"], + }; + + const requestRes = client._createRequest(context, { + security: requestSecurity, + method: "PATCH", + baseURL: options?.serverURL, + path: path, + headers: headers, + body: body, + userAgent: client._options.userAgent, + timeoutMs: options?.timeoutMs || client._options.timeoutMs || -1, + }, options); + if (!requestRes.ok) { + return [requestRes, { status: "invalid" }]; + } + const req = requestRes.value; + + const doResult = await client._do(req, { + context, + errorCodes: ["422", "4XX", "5XX"], + retryConfig: context.retryConfig, + retryCodes: context.retryCodes, + }); + if (!doResult.ok) { + return [doResult, { status: "request-error", request: req }]; + } + const response = doResult.value; + + const responseFields = { + HttpMeta: { Response: response, Request: req }, + }; + + const [result] = await M.match< + models.PatchControlBindingResponse, + | errors.HTTPValidationError + | AgentControlSDKError + | ResponseValidationError + | ConnectionError + | RequestAbortedError + | RequestTimeoutError + | InvalidRequestError + | UnexpectedClientError + | SDKValidationError + >( + M.json(200, models.PatchControlBindingResponse$inboundSchema), + M.jsonErr(422, errors.HTTPValidationError$inboundSchema), + M.fail("4XX"), + M.fail("5XX"), + )(response, req, { extraFields: responseFields }); + if (!result.ok) { + return [result, { status: "complete", request: req, response }]; + } + + return [result, { status: "complete", request: req, response }]; +} diff --git a/sdks/typescript/src/generated/funcs/controls-clone-and-bind-control.ts b/sdks/typescript/src/generated/funcs/controls-clone-and-bind-control.ts new file mode 100644 index 00000000..559c3848 --- /dev/null +++ b/sdks/typescript/src/generated/funcs/controls-clone-and-bind-control.ts @@ -0,0 +1,188 @@ +/* + * Code generated by Speakeasy (https://speakeasy.com). DO NOT EDIT. + */ + +import * as z from "zod/v4-mini"; +import { AgentControlSDKCore } from "../core.js"; +import { encodeJSON, encodeSimple } from "../lib/encodings.js"; +import * as M from "../lib/matchers.js"; +import { compactMap } from "../lib/primitives.js"; +import { safeParse } from "../lib/schemas.js"; +import { RequestOptions } from "../lib/sdks.js"; +import { extractSecurity, resolveGlobalSecurity } from "../lib/security.js"; +import { pathToFunc } from "../lib/url.js"; +import { AgentControlSDKError } from "../models/errors/agent-control-sdk-error.js"; +import { + ConnectionError, + InvalidRequestError, + RequestAbortedError, + RequestTimeoutError, + UnexpectedClientError, +} from "../models/errors/http-client-errors.js"; +import * as errors from "../models/errors/index.js"; +import { ResponseValidationError } from "../models/errors/response-validation-error.js"; +import { SDKValidationError } from "../models/errors/sdk-validation-error.js"; +import * as models from "../models/index.js"; +import * as operations from "../models/operations/index.js"; +import { APICall, APIPromise } from "../types/async.js"; +import { Result } from "../types/fp.js"; + +/** + * Clone a control and bind the clone to a target + * + * @remarks + * Clone an active control and attach the clone to an opaque target. + */ +export function controlsCloneAndBindControl( + client: AgentControlSDKCore, + request: + operations.CloneAndBindControlApiV1ControlsControlIdCloneAndBindPostRequest, + options?: RequestOptions, +): APIPromise< + Result< + models.CloneAndBindControlResponse, + | errors.HTTPValidationError + | AgentControlSDKError + | ResponseValidationError + | ConnectionError + | RequestAbortedError + | RequestTimeoutError + | InvalidRequestError + | UnexpectedClientError + | SDKValidationError + > +> { + return new APIPromise($do( + client, + request, + options, + )); +} + +async function $do( + client: AgentControlSDKCore, + request: + operations.CloneAndBindControlApiV1ControlsControlIdCloneAndBindPostRequest, + options?: RequestOptions, +): Promise< + [ + Result< + models.CloneAndBindControlResponse, + | errors.HTTPValidationError + | AgentControlSDKError + | ResponseValidationError + | ConnectionError + | RequestAbortedError + | RequestTimeoutError + | InvalidRequestError + | UnexpectedClientError + | SDKValidationError + >, + APICall, + ] +> { + const parsed = safeParse( + request, + (value) => + z.parse( + operations + .CloneAndBindControlApiV1ControlsControlIdCloneAndBindPostRequest$outboundSchema, + value, + ), + "Input validation failed", + ); + if (!parsed.ok) { + return [parsed, { status: "invalid" }]; + } + const payload = parsed.value; + const body = encodeJSON("body", payload.body, { explode: true }); + + const pathParams = { + control_id: encodeSimple("control_id", payload.control_id, { + explode: false, + charEncoding: "percent", + }), + }; + + const path = pathToFunc("/api/v1/controls/{control_id}/clone-and-bind")( + pathParams, + ); + + const headers = new Headers(compactMap({ + "Content-Type": "application/json", + Accept: "application/json", + })); + + const secConfig = await extractSecurity(client._options.apiKeyHeader); + const securityInput = secConfig == null ? {} : { apiKeyHeader: secConfig }; + const requestSecurity = resolveGlobalSecurity(securityInput); + + const context = { + options: client._options, + baseURL: options?.serverURL ?? client._baseURL ?? "", + operationID: + "clone_and_bind_control_api_v1_controls__control_id__clone_and_bind_post", + oAuth2Scopes: null, + + resolvedSecurity: requestSecurity, + + securitySource: client._options.apiKeyHeader, + retryConfig: options?.retries + || client._options.retryConfig + || { strategy: "none" }, + retryCodes: options?.retryCodes || ["429", "500", "502", "503", "504"], + }; + + const requestRes = client._createRequest(context, { + security: requestSecurity, + method: "POST", + baseURL: options?.serverURL, + path: path, + headers: headers, + body: body, + userAgent: client._options.userAgent, + timeoutMs: options?.timeoutMs || client._options.timeoutMs || -1, + }, options); + if (!requestRes.ok) { + return [requestRes, { status: "invalid" }]; + } + const req = requestRes.value; + + const doResult = await client._do(req, { + context, + errorCodes: ["422", "4XX", "5XX"], + retryConfig: context.retryConfig, + retryCodes: context.retryCodes, + }); + if (!doResult.ok) { + return [doResult, { status: "request-error", request: req }]; + } + const response = doResult.value; + + const responseFields = { + HttpMeta: { Response: response, Request: req }, + }; + + const [result] = await M.match< + models.CloneAndBindControlResponse, + | errors.HTTPValidationError + | AgentControlSDKError + | ResponseValidationError + | ConnectionError + | RequestAbortedError + | RequestTimeoutError + | InvalidRequestError + | UnexpectedClientError + | SDKValidationError + >( + M.json(200, models.CloneAndBindControlResponse$inboundSchema), + M.jsonErr(422, errors.HTTPValidationError$inboundSchema), + M.fail("4XX"), + M.fail("5XX"), + )(response, req, { extraFields: responseFields }); + if (!result.ok) { + return [result, { status: "complete", request: req, response }]; + } + + return [result, { status: "complete", request: req, response }]; +} diff --git a/sdks/typescript/src/generated/funcs/controls-list.ts b/sdks/typescript/src/generated/funcs/controls-list.ts index 2fd69073..4b1aadd7 100644 --- a/sdks/typescript/src/generated/funcs/controls-list.ts +++ b/sdks/typescript/src/generated/funcs/controls-list.ts @@ -41,10 +41,16 @@ import { Result } from "../types/fp.js"; * name: Optional filter by name (partial, case-insensitive match) * enabled: Optional filter by enabled status * template_backed: Optional filter by whether the control is template-backed + * cloned: Optional filter by whether the control was cloned from another control * step_type: Optional filter by step type (built-ins: 'tool', 'llm') * stage: Optional filter by stage ('pre' or 'post') * execution: Optional filter by execution ('server' or 'sdk') * tag: Optional filter by tag + * include_attachments: Whether to include attachment details for listed controls + * attachment_target_type: Optional target binding type filter for controls and + * attachments + * attachment_target_id: Optional target binding ID filter for controls and + * attachments * db: Database session (injected) * * Returns: @@ -119,9 +125,13 @@ async function $do( const path = pathToFunc("/api/v1/controls")(); const query = encodeFormQuery({ + "attachment_target_id": payload?.attachment_target_id, + "attachment_target_type": payload?.attachment_target_type, + "cloned": payload?.cloned, "cursor": payload?.cursor, "enabled": payload?.enabled, "execution": payload?.execution, + "include_attachments": payload?.include_attachments, "limit": payload?.limit, "name": payload?.name, "stage": payload?.stage, diff --git a/sdks/typescript/src/generated/models/clone-and-bind-control-request.ts b/sdks/typescript/src/generated/models/clone-and-bind-control-request.ts new file mode 100644 index 00000000..90d4a54b --- /dev/null +++ b/sdks/typescript/src/generated/models/clone-and-bind-control-request.ts @@ -0,0 +1,55 @@ +/* + * Code generated by Speakeasy (https://speakeasy.com). DO NOT EDIT. + */ + +import * as z from "zod/v4-mini"; +import { remap as remap$ } from "../lib/primitives.js"; +import { + CloneAndBindTargetBinding, + CloneAndBindTargetBinding$Outbound, + CloneAndBindTargetBinding$outboundSchema, +} from "./clone-and-bind-target-binding.js"; + +/** + * Request to clone a control and attach the clone to one target. + */ +export type CloneAndBindControlRequest = { + /** + * Optional unique name for the cloned control. If omitted, the server generates a name from the source control name. + */ + name?: string | null | undefined; + /** + * Target binding to create for a cloned control. + */ + targetBinding: CloneAndBindTargetBinding; +}; + +/** @internal */ +export type CloneAndBindControlRequest$Outbound = { + name?: string | null | undefined; + target_binding: CloneAndBindTargetBinding$Outbound; +}; + +/** @internal */ +export const CloneAndBindControlRequest$outboundSchema: z.ZodMiniType< + CloneAndBindControlRequest$Outbound, + CloneAndBindControlRequest +> = z.pipe( + z.object({ + name: z.optional(z.nullable(z.string())), + targetBinding: CloneAndBindTargetBinding$outboundSchema, + }), + z.transform((v) => { + return remap$(v, { + targetBinding: "target_binding", + }); + }), +); + +export function cloneAndBindControlRequestToJSON( + cloneAndBindControlRequest: CloneAndBindControlRequest, +): string { + return JSON.stringify( + CloneAndBindControlRequest$outboundSchema.parse(cloneAndBindControlRequest), + ); +} diff --git a/sdks/typescript/src/generated/models/clone-and-bind-control-response.ts b/sdks/typescript/src/generated/models/clone-and-bind-control-response.ts new file mode 100644 index 00000000..f65e09f0 --- /dev/null +++ b/sdks/typescript/src/generated/models/clone-and-bind-control-response.ts @@ -0,0 +1,61 @@ +/* + * Code generated by Speakeasy (https://speakeasy.com). DO NOT EDIT. + */ + +import * as z from "zod/v4-mini"; +import { remap as remap$ } from "../lib/primitives.js"; +import { safeParse } from "../lib/schemas.js"; +import { Result as SafeParseResult } from "../types/fp.js"; +import * as types from "../types/primitives.js"; +import { SDKValidationError } from "./errors/sdk-validation-error.js"; + +/** + * Response from cloning and binding a control. + */ +export type CloneAndBindControlResponse = { + /** + * Identifier of the created binding. + */ + bindingId: number; + /** + * Source control ID. + */ + clonedFromControlId: number; + /** + * Identifier of the cloned control. + */ + id: number; + /** + * Name of the cloned control. + */ + name: string; +}; + +/** @internal */ +export const CloneAndBindControlResponse$inboundSchema: z.ZodMiniType< + CloneAndBindControlResponse, + unknown +> = z.pipe( + z.object({ + binding_id: types.number(), + cloned_from_control_id: types.number(), + id: types.number(), + name: types.string(), + }), + z.transform((v) => { + return remap$(v, { + "binding_id": "bindingId", + "cloned_from_control_id": "clonedFromControlId", + }); + }), +); + +export function cloneAndBindControlResponseFromJSON( + jsonString: string, +): SafeParseResult { + return safeParse( + jsonString, + (x) => CloneAndBindControlResponse$inboundSchema.parse(JSON.parse(x)), + `Failed to parse 'CloneAndBindControlResponse' from JSON`, + ); +} diff --git a/sdks/typescript/src/generated/models/clone-and-bind-target-binding.ts b/sdks/typescript/src/generated/models/clone-and-bind-target-binding.ts new file mode 100644 index 00000000..26fb57fc --- /dev/null +++ b/sdks/typescript/src/generated/models/clone-and-bind-target-binding.ts @@ -0,0 +1,57 @@ +/* + * Code generated by Speakeasy (https://speakeasy.com). DO NOT EDIT. + */ + +import * as z from "zod/v4-mini"; +import { remap as remap$ } from "../lib/primitives.js"; + +/** + * Target binding to create for a cloned control. + */ +export type CloneAndBindTargetBinding = { + /** + * Whether the created binding is active. + */ + enabled?: boolean | undefined; + /** + * Opaque external identifier within the target_type. + */ + targetId: string; + /** + * Opaque attachment kind (caller-defined; e.g. 'environment', 'session'). + */ + targetType: string; +}; + +/** @internal */ +export type CloneAndBindTargetBinding$Outbound = { + enabled: boolean; + target_id: string; + target_type: string; +}; + +/** @internal */ +export const CloneAndBindTargetBinding$outboundSchema: z.ZodMiniType< + CloneAndBindTargetBinding$Outbound, + CloneAndBindTargetBinding +> = z.pipe( + z.object({ + enabled: z._default(z.boolean(), true), + targetId: z.string(), + targetType: z.string(), + }), + z.transform((v) => { + return remap$(v, { + targetId: "target_id", + targetType: "target_type", + }); + }), +); + +export function cloneAndBindTargetBindingToJSON( + cloneAndBindTargetBinding: CloneAndBindTargetBinding, +): string { + return JSON.stringify( + CloneAndBindTargetBinding$outboundSchema.parse(cloneAndBindTargetBinding), + ); +} diff --git a/sdks/typescript/src/generated/models/control-attachments.ts b/sdks/typescript/src/generated/models/control-attachments.ts new file mode 100644 index 00000000..31e901f5 --- /dev/null +++ b/sdks/typescript/src/generated/models/control-attachments.ts @@ -0,0 +1,72 @@ +/* + * Code generated by Speakeasy (https://speakeasy.com). DO NOT EDIT. + */ + +import * as z from "zod/v4-mini"; +import { remap as remap$ } from "../lib/primitives.js"; +import { safeParse } from "../lib/schemas.js"; +import { Result as SafeParseResult } from "../types/fp.js"; +import * as types from "../types/primitives.js"; +import { AgentRef, AgentRef$inboundSchema } from "./agent-ref.js"; +import { SDKValidationError } from "./errors/sdk-validation-error.js"; +import { PolicyRef, PolicyRef$inboundSchema } from "./policy-ref.js"; +import { + TargetAttachmentRef, + TargetAttachmentRef$inboundSchema, +} from "./target-attachment-ref.js"; + +/** + * Attachments for a listed control. + */ +export type ControlAttachments = { + /** + * Direct agent associations for this control + */ + agents?: Array | undefined; + /** + * Policy associations for this control + */ + policies?: Array | undefined; + /** + * Target bindings for this control + */ + targets?: Array | undefined; + /** + * Total target bindings matching the attachment filters + */ + targetsTotal: number; + /** + * Whether the target bindings list was capped + */ + targetsTruncated: boolean; +}; + +/** @internal */ +export const ControlAttachments$inboundSchema: z.ZodMiniType< + ControlAttachments, + unknown +> = z.pipe( + z.object({ + agents: types.optional(z.array(AgentRef$inboundSchema)), + policies: types.optional(z.array(PolicyRef$inboundSchema)), + targets: types.optional(z.array(TargetAttachmentRef$inboundSchema)), + targets_total: z._default(types.number(), 0), + targets_truncated: z._default(types.boolean(), false), + }), + z.transform((v) => { + return remap$(v, { + "targets_total": "targetsTotal", + "targets_truncated": "targetsTruncated", + }); + }), +); + +export function controlAttachmentsFromJSON( + jsonString: string, +): SafeParseResult { + return safeParse( + jsonString, + (x) => ControlAttachments$inboundSchema.parse(JSON.parse(x)), + `Failed to parse 'ControlAttachments' from JSON`, + ); +} diff --git a/sdks/typescript/src/generated/models/control-summary.ts b/sdks/typescript/src/generated/models/control-summary.ts index 4c0b0fb3..5b439c9d 100644 --- a/sdks/typescript/src/generated/models/control-summary.ts +++ b/sdks/typescript/src/generated/models/control-summary.ts @@ -8,12 +8,32 @@ import { safeParse } from "../lib/schemas.js"; import { Result as SafeParseResult } from "../types/fp.js"; import * as types from "../types/primitives.js"; import { AgentRef, AgentRef$inboundSchema } from "./agent-ref.js"; +import { + ControlAction, + ControlAction$inboundSchema, +} from "./control-action.js"; +import { + ControlAttachments, + ControlAttachments$inboundSchema, +} from "./control-attachments.js"; import { SDKValidationError } from "./errors/sdk-validation-error.js"; /** * Summary of a control for list responses. */ export type ControlSummary = { + /** + * Action applied when the control matches. + */ + action?: ControlAction | null | undefined; + /** + * Expanded attachment details. Present when list controls is called with include_attachments=true. + */ + attachments?: ControlAttachments | null | undefined; + /** + * Source control ID when this control is a clone. + */ + clonedFromControlId?: number | null | undefined; /** * Control description */ @@ -70,6 +90,9 @@ export const ControlSummary$inboundSchema: z.ZodMiniType< unknown > = z.pipe( z.object({ + action: z.optional(z.nullable(ControlAction$inboundSchema)), + attachments: z.optional(z.nullable(ControlAttachments$inboundSchema)), + cloned_from_control_id: z.optional(z.nullable(types.number())), description: z.optional(z.nullable(types.string())), enabled: z._default(types.boolean(), true), execution: z.optional(z.nullable(types.string())), @@ -85,6 +108,7 @@ export const ControlSummary$inboundSchema: z.ZodMiniType< }), z.transform((v) => { return remap$(v, { + "cloned_from_control_id": "clonedFromControlId", "step_types": "stepTypes", "template_backed": "templateBacked", "template_rendered": "templateRendered", diff --git a/sdks/typescript/src/generated/models/get-control-response.ts b/sdks/typescript/src/generated/models/get-control-response.ts index 8e65e936..58bcfcc9 100644 --- a/sdks/typescript/src/generated/models/get-control-response.ts +++ b/sdks/typescript/src/generated/models/get-control-response.ts @@ -3,6 +3,7 @@ */ import * as z from "zod/v4-mini"; +import { remap as remap$ } from "../lib/primitives.js"; import { safeParse } from "../lib/schemas.js"; import { Result as SafeParseResult } from "../types/fp.js"; import * as types from "../types/primitives.js"; @@ -28,6 +29,10 @@ export type GetControlResponseData = * Response containing control details. */ export type GetControlResponse = { + /** + * Source control ID when this control is a clone. + */ + clonedFromControlId?: number | null | undefined; /** * Control configuration data. A ControlDefinition for raw/rendered controls or an UnrenderedTemplateControl for unrendered templates. */ @@ -65,14 +70,22 @@ export function getControlResponseDataFromJSON( export const GetControlResponse$inboundSchema: z.ZodMiniType< GetControlResponse, unknown -> = z.object({ - data: smartUnion([ - ControlDefinitionOutput$inboundSchema, - UnrenderedTemplateControl$inboundSchema, - ]), - id: types.number(), - name: types.string(), -}); +> = z.pipe( + z.object({ + cloned_from_control_id: z.optional(z.nullable(types.number())), + data: smartUnion([ + ControlDefinitionOutput$inboundSchema, + UnrenderedTemplateControl$inboundSchema, + ]), + id: types.number(), + name: types.string(), + }), + z.transform((v) => { + return remap$(v, { + "cloned_from_control_id": "clonedFromControlId", + }); + }), +); export function getControlResponseFromJSON( jsonString: string, diff --git a/sdks/typescript/src/generated/models/get-control-version-response.ts b/sdks/typescript/src/generated/models/get-control-version-response.ts index d502371c..1c9871b6 100644 --- a/sdks/typescript/src/generated/models/get-control-version-response.ts +++ b/sdks/typescript/src/generated/models/get-control-version-response.ts @@ -26,7 +26,7 @@ export type GetControlVersionResponse = { */ note?: string | null | undefined; /** - * Raw persisted snapshot of the control state at this version, including metadata such as name, deleted_at, and cloned_control_id. + * Raw persisted snapshot of the control state at this version, including metadata such as name, deleted_at, and cloned_from_control_id. */ snapshot: { [k: string]: any }; /** diff --git a/sdks/typescript/src/generated/models/index.ts b/sdks/typescript/src/generated/models/index.ts index 595a9501..8043233b 100644 --- a/sdks/typescript/src/generated/models/index.ts +++ b/sdks/typescript/src/generated/models/index.ts @@ -12,11 +12,15 @@ export * from "./auth-mode.js"; export * from "./batch-events-request.js"; export * from "./batch-events-response.js"; export * from "./boolean-template-parameter.js"; +export * from "./clone-and-bind-control-request.js"; +export * from "./clone-and-bind-control-response.js"; +export * from "./clone-and-bind-target-binding.js"; export * from "./condition-node-input.js"; export * from "./condition-node-output.js"; export * from "./config-response.js"; export * from "./conflict-mode.js"; export * from "./control-action.js"; +export * from "./control-attachments.js"; export * from "./control-definition-input.js"; export * from "./control-definition-output.js"; export * from "./control-execution-event.js"; @@ -73,10 +77,12 @@ export * from "./login-response.js"; export * from "./pagination-info.js"; export * from "./patch-agent-request.js"; export * from "./patch-agent-response.js"; +export * from "./patch-control-binding-by-key-request.js"; export * from "./patch-control-binding-request.js"; export * from "./patch-control-binding-response.js"; export * from "./patch-control-request.js"; export * from "./patch-control-response.js"; +export * from "./policy-ref.js"; export * from "./regex-template-parameter.js"; export * from "./remove-agent-control-response.js"; export * from "./render-control-template-request.js"; @@ -95,6 +101,7 @@ export * from "./step-schema.js"; export * from "./step.js"; export * from "./string-list-template-parameter.js"; export * from "./string-template-parameter.js"; +export * from "./target-attachment-ref.js"; export * from "./template-control-input.js"; export * from "./template-definition-input.js"; export * from "./template-definition-output.js"; diff --git a/sdks/typescript/src/generated/models/operations/clone-and-bind-control-api-v1-controls-control-id-clone-and-bind-post.ts b/sdks/typescript/src/generated/models/operations/clone-and-bind-control-api-v1-controls-control-id-clone-and-bind-post.ts new file mode 100644 index 00000000..888f3adc --- /dev/null +++ b/sdks/typescript/src/generated/models/operations/clone-and-bind-control-api-v1-controls-control-id-clone-and-bind-post.ts @@ -0,0 +1,46 @@ +/* + * Code generated by Speakeasy (https://speakeasy.com). DO NOT EDIT. + */ + +import * as z from "zod/v4-mini"; +import { remap as remap$ } from "../../lib/primitives.js"; +import * as models from "../index.js"; + +export type CloneAndBindControlApiV1ControlsControlIdCloneAndBindPostRequest = { + controlId: number; + body: models.CloneAndBindControlRequest; +}; + +/** @internal */ +export type CloneAndBindControlApiV1ControlsControlIdCloneAndBindPostRequest$Outbound = + { + control_id: number; + body: models.CloneAndBindControlRequest$Outbound; + }; + +/** @internal */ +export const CloneAndBindControlApiV1ControlsControlIdCloneAndBindPostRequest$outboundSchema: + z.ZodMiniType< + CloneAndBindControlApiV1ControlsControlIdCloneAndBindPostRequest$Outbound, + CloneAndBindControlApiV1ControlsControlIdCloneAndBindPostRequest + > = z.pipe( + z.object({ + controlId: z.int(), + body: models.CloneAndBindControlRequest$outboundSchema, + }), + z.transform((v) => { + return remap$(v, { + controlId: "control_id", + }); + }), + ); + +export function cloneAndBindControlApiV1ControlsControlIdCloneAndBindPostRequestToJSON( + cloneAndBindControlApiV1ControlsControlIdCloneAndBindPostRequest: + CloneAndBindControlApiV1ControlsControlIdCloneAndBindPostRequest, +): string { + return JSON.stringify( + CloneAndBindControlApiV1ControlsControlIdCloneAndBindPostRequest$outboundSchema + .parse(cloneAndBindControlApiV1ControlsControlIdCloneAndBindPostRequest), + ); +} diff --git a/sdks/typescript/src/generated/models/operations/index.ts b/sdks/typescript/src/generated/models/operations/index.ts index 819659f8..b031df32 100644 --- a/sdks/typescript/src/generated/models/operations/index.ts +++ b/sdks/typescript/src/generated/models/operations/index.ts @@ -5,6 +5,7 @@ export * from "./add-agent-control-api-v1-agents-agent-name-controls-control-id-post.js"; export * from "./add-agent-policy-api-v1-agents-agent-name-policies-policy-id-post.js"; export * from "./add-control-to-policy-api-v1-policies-policy-id-controls-control-id-post.js"; +export * from "./clone-and-bind-control-api-v1-controls-control-id-clone-and-bind-post.js"; export * from "./delete-agent-policy-api-v1-agents-agent-name-policy-delete.js"; export * from "./delete-control-api-v1-controls-control-id-delete.js"; export * from "./delete-control-binding-api-v1-control-bindings-binding-id-delete.js"; diff --git a/sdks/typescript/src/generated/models/operations/list-controls-api-v1-controls-get.ts b/sdks/typescript/src/generated/models/operations/list-controls-api-v1-controls-get.ts index 7f19162e..dac0a4fd 100644 --- a/sdks/typescript/src/generated/models/operations/list-controls-api-v1-controls-get.ts +++ b/sdks/typescript/src/generated/models/operations/list-controls-api-v1-controls-get.ts @@ -23,6 +23,10 @@ export type ListControlsApiV1ControlsGetRequest = { * Filter by whether the control is template-backed */ templateBacked?: boolean | null | undefined; + /** + * Filter by whether the control was cloned from another control + */ + cloned?: boolean | null | undefined; /** * Filter by step type (built-ins: 'tool', 'llm') */ @@ -39,6 +43,18 @@ export type ListControlsApiV1ControlsGetRequest = { * Filter by tag */ tag?: string | null | undefined; + /** + * When true, include direct agent associations, policy associations, and target bindings for each listed control. + */ + includeAttachments?: boolean | undefined; + /** + * Optional target_type filter applied to the returned controls and expanded target bindings. Only used when include_attachments=true. + */ + attachmentTargetType?: string | null | undefined; + /** + * Optional target_id filter applied to the returned controls and expanded target bindings. Only used when include_attachments=true. + */ + attachmentTargetId?: string | null | undefined; }; /** @internal */ @@ -48,10 +64,14 @@ export type ListControlsApiV1ControlsGetRequest$Outbound = { name?: string | null | undefined; enabled?: boolean | null | undefined; template_backed?: boolean | null | undefined; + cloned?: boolean | null | undefined; step_type?: string | null | undefined; stage?: string | null | undefined; execution?: string | null | undefined; tag?: string | null | undefined; + include_attachments: boolean; + attachment_target_type?: string | null | undefined; + attachment_target_id?: string | null | undefined; }; /** @internal */ @@ -65,15 +85,22 @@ export const ListControlsApiV1ControlsGetRequest$outboundSchema: z.ZodMiniType< name: z.optional(z.nullable(z.string())), enabled: z.optional(z.nullable(z.boolean())), templateBacked: z.optional(z.nullable(z.boolean())), + cloned: z.optional(z.nullable(z.boolean())), stepType: z.optional(z.nullable(z.string())), stage: z.optional(z.nullable(z.string())), execution: z.optional(z.nullable(z.string())), tag: z.optional(z.nullable(z.string())), + includeAttachments: z._default(z.boolean(), false), + attachmentTargetType: z.optional(z.nullable(z.string())), + attachmentTargetId: z.optional(z.nullable(z.string())), }), z.transform((v) => { return remap$(v, { templateBacked: "template_backed", stepType: "step_type", + includeAttachments: "include_attachments", + attachmentTargetType: "attachment_target_type", + attachmentTargetId: "attachment_target_id", }); }), ); diff --git a/sdks/typescript/src/generated/models/patch-control-binding-by-key-request.ts b/sdks/typescript/src/generated/models/patch-control-binding-by-key-request.ts new file mode 100644 index 00000000..c66588d9 --- /dev/null +++ b/sdks/typescript/src/generated/models/patch-control-binding-by-key-request.ts @@ -0,0 +1,66 @@ +/* + * Code generated by Speakeasy (https://speakeasy.com). DO NOT EDIT. + */ + +import * as z from "zod/v4-mini"; +import { remap as remap$ } from "../lib/primitives.js"; + +/** + * Request to update an existing control binding by natural key. + */ +export type PatchControlBindingByKeyRequest = { + /** + * ID of the bound control. + */ + controlId: number; + /** + * New enabled value for the binding. + */ + enabled: boolean; + /** + * Opaque external identifier within the target_type. + */ + targetId: string; + /** + * Opaque attachment kind. + */ + targetType: string; +}; + +/** @internal */ +export type PatchControlBindingByKeyRequest$Outbound = { + control_id: number; + enabled: boolean; + target_id: string; + target_type: string; +}; + +/** @internal */ +export const PatchControlBindingByKeyRequest$outboundSchema: z.ZodMiniType< + PatchControlBindingByKeyRequest$Outbound, + PatchControlBindingByKeyRequest +> = z.pipe( + z.object({ + controlId: z.int(), + enabled: z.boolean(), + targetId: z.string(), + targetType: z.string(), + }), + z.transform((v) => { + return remap$(v, { + controlId: "control_id", + targetId: "target_id", + targetType: "target_type", + }); + }), +); + +export function patchControlBindingByKeyRequestToJSON( + patchControlBindingByKeyRequest: PatchControlBindingByKeyRequest, +): string { + return JSON.stringify( + PatchControlBindingByKeyRequest$outboundSchema.parse( + patchControlBindingByKeyRequest, + ), + ); +} diff --git a/sdks/typescript/src/generated/models/policy-ref.ts b/sdks/typescript/src/generated/models/policy-ref.ts new file mode 100644 index 00000000..ab6f9fbc --- /dev/null +++ b/sdks/typescript/src/generated/models/policy-ref.ts @@ -0,0 +1,43 @@ +/* + * Code generated by Speakeasy (https://speakeasy.com). DO NOT EDIT. + */ + +import * as z from "zod/v4-mini"; +import { remap as remap$ } from "../lib/primitives.js"; +import { safeParse } from "../lib/schemas.js"; +import { Result as SafeParseResult } from "../types/fp.js"; +import * as types from "../types/primitives.js"; +import { SDKValidationError } from "./errors/sdk-validation-error.js"; + +/** + * Reference to a policy attached to a control. + */ +export type PolicyRef = { + /** + * Policy ID + */ + policyId: number; +}; + +/** @internal */ +export const PolicyRef$inboundSchema: z.ZodMiniType = z + .pipe( + z.object({ + policy_id: types.number(), + }), + z.transform((v) => { + return remap$(v, { + "policy_id": "policyId", + }); + }), + ); + +export function policyRefFromJSON( + jsonString: string, +): SafeParseResult { + return safeParse( + jsonString, + (x) => PolicyRef$inboundSchema.parse(JSON.parse(x)), + `Failed to parse 'PolicyRef' from JSON`, + ); +} diff --git a/sdks/typescript/src/generated/models/target-attachment-ref.ts b/sdks/typescript/src/generated/models/target-attachment-ref.ts new file mode 100644 index 00000000..c1393dc5 --- /dev/null +++ b/sdks/typescript/src/generated/models/target-attachment-ref.ts @@ -0,0 +1,62 @@ +/* + * Code generated by Speakeasy (https://speakeasy.com). DO NOT EDIT. + */ + +import * as z from "zod/v4-mini"; +import { remap as remap$ } from "../lib/primitives.js"; +import { safeParse } from "../lib/schemas.js"; +import { Result as SafeParseResult } from "../types/fp.js"; +import * as types from "../types/primitives.js"; +import { SDKValidationError } from "./errors/sdk-validation-error.js"; + +/** + * Reference to a target binding attached to a control. + */ +export type TargetAttachmentRef = { + /** + * Control binding ID + */ + bindingId: number; + /** + * Whether this target binding is enabled + */ + enabled: boolean; + /** + * Opaque target identifier + */ + targetId: string; + /** + * Opaque target kind + */ + targetType: string; +}; + +/** @internal */ +export const TargetAttachmentRef$inboundSchema: z.ZodMiniType< + TargetAttachmentRef, + unknown +> = z.pipe( + z.object({ + binding_id: types.number(), + enabled: types.boolean(), + target_id: types.string(), + target_type: types.string(), + }), + z.transform((v) => { + return remap$(v, { + "binding_id": "bindingId", + "target_id": "targetId", + "target_type": "targetType", + }); + }), +); + +export function targetAttachmentRefFromJSON( + jsonString: string, +): SafeParseResult { + return safeParse( + jsonString, + (x) => TargetAttachmentRef$inboundSchema.parse(JSON.parse(x)), + `Failed to parse 'TargetAttachmentRef' from JSON`, + ); +} diff --git a/sdks/typescript/src/generated/sdk/control-bindings.ts b/sdks/typescript/src/generated/sdk/control-bindings.ts index 5a5bcf2b..a8708986 100644 --- a/sdks/typescript/src/generated/sdk/control-bindings.ts +++ b/sdks/typescript/src/generated/sdk/control-bindings.ts @@ -7,6 +7,7 @@ import { controlBindingsDeleteByKey } from "../funcs/control-bindings-delete-by- import { controlBindingsDelete } from "../funcs/control-bindings-delete.js"; import { controlBindingsGet } from "../funcs/control-bindings-get.js"; import { controlBindingsList } from "../funcs/control-bindings-list.js"; +import { controlBindingsUpdateByKey } from "../funcs/control-bindings-update-by-key.js"; import { controlBindingsUpdate } from "../funcs/control-bindings-update.js"; import { controlBindingsUpsertByKey } from "../funcs/control-bindings-upsert-by-key.js"; import { ClientSDK, RequestOptions } from "../lib/sdks.js"; @@ -58,6 +59,27 @@ export class ControlBindings extends ClientSDK { )); } + /** + * Update a control binding by natural key + * + * @remarks + * Update an existing binding using ``(target_type, target_id, control_id)``. + * + * This route is target-scoped because the request body includes the target + * identifiers before authorization runs. Unlike ``PUT /by-key``, it never + * creates a missing binding. + */ + async updateByKey( + request: models.PatchControlBindingByKeyRequest, + options?: RequestOptions, + ): Promise { + return unwrapAsync(controlBindingsUpdateByKey( + this, + request, + options, + )); + } + /** * Attach a control to a target by natural key (idempotent) * diff --git a/sdks/typescript/src/generated/sdk/controls.ts b/sdks/typescript/src/generated/sdk/controls.ts index ed3cf8db..1168463a 100644 --- a/sdks/typescript/src/generated/sdk/controls.ts +++ b/sdks/typescript/src/generated/sdk/controls.ts @@ -2,6 +2,7 @@ * Code generated by Speakeasy (https://speakeasy.com). DO NOT EDIT. */ +import { controlsCloneAndBindControl } from "../funcs/controls-clone-and-bind-control.js"; import { controlsCreate } from "../funcs/controls-create.js"; import { controlsDelete } from "../funcs/controls-delete.js"; import { controlsGetData } from "../funcs/controls-get-data.js"; @@ -51,10 +52,16 @@ export class Controls extends ClientSDK { * name: Optional filter by name (partial, case-insensitive match) * enabled: Optional filter by enabled status * template_backed: Optional filter by whether the control is template-backed + * cloned: Optional filter by whether the control was cloned from another control * step_type: Optional filter by step type (built-ins: 'tool', 'llm') * stage: Optional filter by stage ('pre' or 'post') * execution: Optional filter by execution ('server' or 'sdk') * tag: Optional filter by tag + * include_attachments: Whether to include attachment details for listed controls + * attachment_target_type: Optional target binding type filter for controls and + * attachments + * attachment_target_id: Optional target binding ID filter for controls and + * attachments * db: Database session (injected) * * Returns: @@ -239,6 +246,24 @@ export class Controls extends ClientSDK { )); } + /** + * Clone a control and bind the clone to a target + * + * @remarks + * Clone an active control and attach the clone to an opaque target. + */ + async cloneAndBindControl( + request: + operations.CloneAndBindControlApiV1ControlsControlIdCloneAndBindPostRequest, + options?: RequestOptions, + ): Promise { + return unwrapAsync(controlsCloneAndBindControl( + this, + request, + options, + )); + } + /** * Get control configuration data * diff --git a/server/alembic/versions/e2b7f4a9c6d1_control_clone_lineage.py b/server/alembic/versions/e2b7f4a9c6d1_control_clone_lineage.py new file mode 100644 index 00000000..c0a242b2 --- /dev/null +++ b/server/alembic/versions/e2b7f4a9c6d1_control_clone_lineage.py @@ -0,0 +1,53 @@ +"""control clone lineage + +Revision ID: e2b7f4a9c6d1 +Revises: b6f4c2d8e9a1 +Create Date: 2026-05-19 00:00:00.000000 + +""" + +from __future__ import annotations + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = "e2b7f4a9c6d1" +down_revision = "b6f4c2d8e9a1" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.add_column( + "controls", + sa.Column("cloned_from_control_id", sa.Integer(), nullable=True), + ) + # No ON DELETE action: hard deletes of clone sources are restricted. + # The API soft-deletes controls so clone lineage remains intact. + op.create_foreign_key( + "controls_cloned_from_control_fkey", + "controls", + "controls", + ["namespace_key", "cloned_from_control_id"], + ["namespace_key", "id"], + ) + with op.get_context().autocommit_block(): + op.execute( + """ + CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_controls_cloned_from + ON controls (namespace_key, cloned_from_control_id) + WHERE cloned_from_control_id IS NOT NULL + """ + ) + + +def downgrade() -> None: + with op.get_context().autocommit_block(): + op.execute("DROP INDEX CONCURRENTLY IF EXISTS idx_controls_cloned_from") + op.drop_constraint( + "controls_cloned_from_control_fkey", + "controls", + type_="foreignkey", + ) + op.drop_column("controls", "cloned_from_control_id") diff --git a/server/src/agent_control_server/auth_framework/providers/http_upstream.py b/server/src/agent_control_server/auth_framework/providers/http_upstream.py index b68972ed..f6ce3b5f 100644 --- a/server/src/agent_control_server/auth_framework/providers/http_upstream.py +++ b/server/src/agent_control_server/auth_framework/providers/http_upstream.py @@ -33,7 +33,10 @@ "caller_id": "..." } -Statuses other than 200 / 401 / 403 / 404 / 5xx fail closed (503). +Statuses other than 200 / 401 / 403 / 404 / 429 fail closed. Unexpected +upstream 4xx responses are reported separately from Agent Control +misconfiguration so operators can distinguish upstream request rejection +from local auth setup failures. """ from __future__ import annotations @@ -280,6 +283,25 @@ def _handle_response( detail="Authorization service is rate-limiting requests.", hint=hint, ) + if 400 <= status < 500: + _logger.warning( + "Authorization upstream rejected operation %s with status %d", + operation.value, + status, + ) + raise APIError( + status_code=502, + error_code=ErrorCode.AUTH_UPSTREAM_REJECTED, + reason=ErrorReason.INTERNAL_ERROR, + detail=( + "Authorization service rejected the authorization check " + f"(status {status})." + ), + hint=( + "Check that the Agent Control authorization request shape " + "matches the upstream authorization service contract." + ), + ) # Fail closed on 5xx and unexpected statuses. _logger.warning( "Unexpected upstream status %d for operation %s", diff --git a/server/src/agent_control_server/endpoints/control_bindings.py b/server/src/agent_control_server/endpoints/control_bindings.py index 87386723..279328c4 100644 --- a/server/src/agent_control_server/endpoints/control_bindings.py +++ b/server/src/agent_control_server/endpoints/control_bindings.py @@ -14,6 +14,7 @@ GetControlBindingResponse, ListControlBindingsResponse, PaginationInfo, + PatchControlBindingByKeyRequest, PatchControlBindingRequest, PatchControlBindingResponse, UpsertControlBindingRequest, @@ -214,6 +215,40 @@ async def get_control_binding( return _to_response(binding) +@router.patch( + "/by-key", + response_model=PatchControlBindingResponse, + summary="Update a control binding by natural key", + response_description="Updated enabled flag", +) +async def patch_control_binding_by_key( + request: PatchControlBindingByKeyRequest, + db: AsyncSession = Depends(get_async_db), + principal: Principal = Depends( + require_operation( + Operation.CONTROL_BINDINGS_WRITE, + context_builder=_binding_body_context, + ) + ), +) -> PatchControlBindingResponse: + """Update an existing binding using ``(target_type, target_id, control_id)``. + + This route is target-scoped because the request body includes the target + identifiers before authorization runs. Unlike ``PUT /by-key``, it never + creates a missing binding. + """ + service = ControlBindingsService(db) + binding = await service.set_enabled_by_natural_key( + namespace_key=principal.namespace_key, + target_type=request.target_type, + target_id=request.target_id, + control_id=request.control_id, + enabled=request.enabled, + ) + await db.commit() + return PatchControlBindingResponse(success=True, enabled=binding.enabled) + + @router.patch( "/{binding_id}", response_model=PatchControlBindingResponse, diff --git a/server/src/agent_control_server/endpoints/controls.py b/server/src/agent_control_server/endpoints/controls.py index 6e6441e9..607fcd08 100644 --- a/server/src/agent_control_server/endpoints/controls.py +++ b/server/src/agent_control_server/endpoints/controls.py @@ -1,10 +1,16 @@ import datetime as dt +import uuid +from copy import deepcopy +from typing import Any from agent_control_engine import list_evaluators from agent_control_models import ControlDefinition, TemplateControlInput, UnrenderedTemplateControl from agent_control_models.errors import ErrorCode, ValidationErrorItem from agent_control_models.server import ( AgentRef, + CloneAndBindControlRequest, + CloneAndBindControlResponse, + ControlAttachments, ControlSummary, ControlVersionSummary, CreateControlRequest, @@ -19,26 +25,32 @@ PaginationInfo, PatchControlRequest, PatchControlResponse, + PolicyRef, RenderControlTemplateRequest, RenderControlTemplateResponse, SetControlDataRequest, SetControlDataResponse, + SlugName, + TargetAttachmentRef, ValidateControlDataRequest, ValidateControlDataResponse, ) -from fastapi import APIRouter, Depends, Query +from fastapi import APIRouter, Depends, Query, Request from jsonschema_rs import ValidationError as JSONSchemaValidationError -from pydantic import ValidationError +from pydantic import TypeAdapter, ValidationError from sqlalchemy import select from sqlalchemy.exc import IntegrityError from sqlalchemy.ext.asyncio import AsyncSession -from ..auth_framework import Operation, Principal, require_operation +from ..auth_framework import Operation, Principal, get_authorizer, require_operation from ..db import get_async_db from ..errors import ( + APIError, APIValidationError, + AuthenticationError, ConflictError, DatabaseError, + ForbiddenError, NotFoundError, ) from ..logging_utils import get_logger @@ -77,6 +89,168 @@ "idx_controls_name_active", "idx_controls_namespace_name_active", } +_MAX_TARGET_CONTEXT_VALUE_LENGTH = 255 +_CLONE_NAME_SUFFIX_HEX_LENGTH = 16 +_GENERATED_CLONE_NAME_ATTEMPTS = 5 +_TRUE_QUERY_VALUES = {"1", "true", "t", "yes", "y", "on"} +_SLUG_NAME_ADAPTER = TypeAdapter(SlugName) + + +def _is_target_context_value(value: object) -> bool: + return ( + isinstance(value, str) + and bool(value) + and len(value) <= _MAX_TARGET_CONTEXT_VALUE_LENGTH + ) + + +def _ensure_same_namespace_authorization( + *principals: Principal, + detail: str = "Authorization resolved to different namespaces.", + hint: str = "Use credentials that grant the required operations in the same namespace.", +) -> None: + namespace_keys = {principal.namespace_key for principal in principals} + if len(namespace_keys) == 1: + return + raise ForbiddenError( + error_code=ErrorCode.AUTH_INSUFFICIENT_PRIVILEGES, + detail=detail, + resource="ControlBinding", + hint=hint, + ) + + +async def _clone_and_bind_context(request: Request) -> dict[str, Any]: + """Surface clone target identifiers to the authorization context.""" + try: + body = await request.json() + except Exception: # noqa: BLE001 malformed JSON falls through to endpoint validation + return {} + if not isinstance(body, dict): + return {} + target_binding = body.get("target_binding") + if not isinstance(target_binding, dict): + return {} + target_type = target_binding.get("target_type") + target_id = target_binding.get("target_id") + if not _is_target_context_value(target_type): + return {} + if not _is_target_context_value(target_id): + return {} + return { + "target_type": target_type, + "target_id": target_id, + } + + +def _attachment_target_context(request: Request) -> dict[str, str]: + context: dict[str, str] = {} + target_type = request.query_params.get("attachment_target_type") + target_id = request.query_params.get("attachment_target_id") + if target_type is not None: + if not _is_target_context_value(target_type): + return {} + context["target_type"] = target_type + if target_id is not None: + if not _is_target_context_value(target_id): + return {} + context["target_id"] = target_id + return context + + +async def _optional_attachment_target_principal(request: Request) -> Principal | None: + include_attachments = request.query_params.get("include_attachments") + if include_attachments is None: + return None + if include_attachments.lower() not in _TRUE_QUERY_VALUES: + return None + target_context = _attachment_target_context(request) + try: + return await get_authorizer(Operation.CONTROL_BINDINGS_READ).authorize( + request, + Operation.CONTROL_BINDINGS_READ, + target_context, + ) + except (AuthenticationError, ForbiddenError, NotFoundError): + if target_context: + raise + return None + + +def _generated_clone_name(source_id: int, source_name: str) -> str: + """Return a slug-safe default name for a cloned control.""" + suffix = f"-clone-{uuid.uuid4().hex[:_CLONE_NAME_SUFFIX_HEX_LENGTH]}" + candidate = f"{source_name[: 255 - len(suffix)]}{suffix}" + try: + return _SLUG_NAME_ADAPTER.validate_python(candidate) + except ValidationError: + return _SLUG_NAME_ADAPTER.validate_python(f"control-{source_id}{suffix}") + + +async def _resolve_clone_name( + control_service: ControlService, + *, + namespace_key: str, + source_id: int, + source_name: str, + requested_name: str | None, +) -> str: + if requested_name is not None: + if await control_service.active_control_name_exists( + requested_name, namespace_key=namespace_key + ): + raise ConflictError( + error_code=ErrorCode.CONTROL_NAME_CONFLICT, + detail=f"Control with name '{requested_name}' already exists", + resource="Control", + resource_id=requested_name, + hint="Choose a different clone name.", + ) + return requested_name + + for _ in range(_GENERATED_CLONE_NAME_ATTEMPTS): + clone_name = _generated_clone_name(source_id, source_name) + if not await control_service.active_control_name_exists( + clone_name, namespace_key=namespace_key + ): + return clone_name + + raise ConflictError( + error_code=ErrorCode.CONTROL_NAME_CONFLICT, + detail="Could not generate a unique clone name.", + resource="Control", + resource_id=source_name, + hint="Retry the request or provide an explicit clone name.", + ) + + +def _validate_attachment_filters( + *, + include_attachments: bool, + attachment_target_type: str | None, + attachment_target_id: str | None, +) -> None: + if include_attachments: + return + if attachment_target_type is None and attachment_target_id is None: + return + raise APIValidationError( + error_code=ErrorCode.VALIDATION_ERROR, + detail="Attachment target filters require include_attachments=true.", + resource="Control", + hint="Set include_attachments=true or remove attachment target filters.", + errors=[ + ValidationErrorItem( + resource="Control", + field="include_attachments", + code="missing_required_parameter", + message=( + "Set include_attachments=true when using attachment_target_type " + "or attachment_target_id." + ), + ) + ], + ) def _serialize_control_data( @@ -576,6 +750,118 @@ async def create_control( return CreateControlResponse(control_id=control.id) +@router.post( + "/{control_id}/clone-and-bind", + response_model=CloneAndBindControlResponse, + summary="Clone a control and bind the clone to a target", + response_description="Created clone and binding identifiers", +) +async def clone_and_bind_control( + control_id: int, + request: CloneAndBindControlRequest, + db: AsyncSession = Depends(get_async_db), + principal: Principal = Depends(require_operation(Operation.CONTROLS_CREATE)), + read_principal: Principal = Depends(require_operation(Operation.CONTROLS_READ)), + binding_principal: Principal = Depends( + require_operation( + Operation.CONTROL_BINDINGS_WRITE, + context_builder=_clone_and_bind_context, + ) + ), +) -> CloneAndBindControlResponse: + """Clone an active control and attach the clone to an opaque target.""" + _ensure_same_namespace_authorization( + principal, + read_principal, + binding_principal, + detail="Clone authorization resolved to different namespaces.", + hint=( + "Use credentials that grant source read, control creation, and target " + "binding in the same namespace." + ), + ) + + namespace_key = principal.namespace_key + control_service = ControlService(db) + bindings_service = ControlBindingsService(db) + + source = await control_service.get_active_control_or_404( + control_id, + namespace_key=namespace_key, + for_update=True, + ) + clone_name = await _resolve_clone_name( + control_service, + namespace_key=namespace_key, + source_id=source.id, + source_name=source.name, + requested_name=request.name, + ) + + clone = control_service.create_control( + namespace_key=namespace_key, + name=clone_name, + data=deepcopy(source.data), + cloned_from_control_id=source.id, + ) + try: + await control_service.create_version( + clone, + event_type="cloned", + note=f"Cloned from control {source.id}", + ) + binding = await bindings_service.create_binding( + namespace_key=namespace_key, + target_type=request.target_binding.target_type, + target_id=request.target_binding.target_id, + control_id=clone.id, + enabled=request.target_binding.enabled, + ) + await db.commit() + except APIError: + await db.rollback() + raise + except IntegrityError as exc: + await db.rollback() + if _is_control_name_conflict(exc): + raise ConflictError( + error_code=ErrorCode.CONTROL_NAME_CONFLICT, + detail=f"Control with name '{clone_name}' already exists", + resource="Control", + resource_id=clone_name, + hint="Choose a different clone name.", + ) + _logger.error( + "Failed to clone control '%s' due to integrity error", + source.name, + exc_info=True, + ) + raise DatabaseError( + detail=f"Failed to clone control '{source.id}': database error", + resource="Control", + operation="clone_and_bind", + ) + except Exception: + await db.rollback() + _logger.error( + "Failed to clone and bind control '%s'", + source.name, + exc_info=True, + ) + raise DatabaseError( + detail=f"Failed to clone control '{source.id}': database error", + resource="Control", + operation="clone_and_bind", + ) + + return CloneAndBindControlResponse( + id=clone.id, + name=clone.name, + cloned_from_control_id=source.id, + binding_id=binding.id, + ) + + @router.get( "/schema", response_model=GetControlSchemaResponse, @@ -626,6 +912,7 @@ async def get_control( return GetControlResponse( id=control.id, name=control.name, + cloned_from_control_id=control.cloned_from_control_id, data=control_data, ) @@ -852,14 +1139,46 @@ async def list_controls( None, description="Filter by whether the control is template-backed", ), + cloned: bool | None = Query( + None, + description="Filter by whether the control was cloned from another control", + ), step_type: str | None = Query( None, description="Filter by step type (built-ins: 'tool', 'llm')" ), stage: str | None = Query(None, description="Filter by stage ('pre' or 'post')"), execution: str | None = Query(None, description="Filter by execution ('server' or 'sdk')"), tag: str | None = Query(None, description="Filter by tag"), + include_attachments: bool = Query( + False, + description=( + "When true, include direct agent associations, policy associations, " + "and target bindings for each listed control." + ), + ), + attachment_target_type: str | None = Query( + None, + min_length=1, + max_length=255, + description=( + "Optional target_type filter applied to the returned controls and " + "expanded target bindings. " + "Only used when include_attachments=true." + ), + ), + attachment_target_id: str | None = Query( + None, + min_length=1, + max_length=255, + description=( + "Optional target_id filter applied to the returned controls and " + "expanded target bindings. " + "Only used when include_attachments=true." + ), + ), db: AsyncSession = Depends(get_async_db), principal: Principal = Depends(require_operation(Operation.CONTROLS_READ)), + target_principal: Principal | None = Depends(_optional_attachment_target_principal), ) -> ListControlsResponse: """ List all controls with optional filtering and cursor-based pagination. @@ -872,10 +1191,16 @@ async def list_controls( name: Optional filter by name (partial, case-insensitive match) enabled: Optional filter by enabled status template_backed: Optional filter by whether the control is template-backed + cloned: Optional filter by whether the control was cloned from another control step_type: Optional filter by step type (built-ins: 'tool', 'llm') stage: Optional filter by stage ('pre' or 'post') execution: Optional filter by execution ('server' or 'sdk') tag: Optional filter by tag + include_attachments: Whether to include attachment details for listed controls + attachment_target_type: Optional target binding type filter for controls and + attachments + attachment_target_id: Optional target binding ID filter for controls and + attachments db: Database session (injected) Returns: @@ -884,8 +1209,30 @@ async def list_controls( Example: GET /controls?limit=10&enabled=true&step_type=tool """ + _validate_attachment_filters( + include_attachments=include_attachments, + attachment_target_type=attachment_target_type, + attachment_target_id=attachment_target_id, + ) + if target_principal is not None: + _ensure_same_namespace_authorization( + principal, + target_principal, + detail=( + "Control and target-binding read authorization resolved " + "to different namespaces." + ), + hint=( + "Use credentials that grant control read and target-binding read " + "in the same namespace." + ), + ) + control_service = ControlService(db) namespace_key = principal.namespace_key + filter_by_attachment = target_principal is not None and ( + attachment_target_type is not None or attachment_target_id is not None + ) page = await control_service.list_controls_page( namespace_key=namespace_key, cursor=cursor, @@ -893,15 +1240,29 @@ async def list_controls( name=name, enabled=enabled, template_backed=template_backed, + cloned=cloned, step_type=step_type, stage=stage, execution=execution, tag=tag, + attachment_target_type=attachment_target_type if filter_by_attachment else None, + attachment_target_id=attachment_target_id if filter_by_attachment else None, ) usage_by_control_id = await control_service.list_control_usage( [control.id for control in page.controls], namespace_key=namespace_key, ) + attachments_by_control_id = ( + await control_service.list_control_attachments( + [control.id for control in page.controls], + namespace_key=namespace_key, + target_type=attachment_target_type, + target_id=attachment_target_id, + include_targets=target_principal is not None, + ) + if include_attachments + else {} + ) # Build summaries (filtering already done at DB level) summaries: list[ControlSummary] = [] @@ -910,16 +1271,19 @@ async def list_controls( data = ctrl.data or {} scope = data.get("scope") or {} usage = usage_by_control_id.get(ctrl.id) + attachments = attachments_by_control_id.get(ctrl.id) summaries.append( ControlSummary( id=ctrl.id, name=ctrl.name, + cloned_from_control_id=ctrl.cloned_from_control_id, description=( data.get("description") or (data.get("template") or {}).get("description") ), enabled=data.get("enabled", True), execution=data.get("execution"), + action=data.get("action"), step_types=scope.get("step_types"), stages=scope.get("stages"), tags=data.get("tags", []), @@ -933,6 +1297,31 @@ async def list_controls( else None ), used_by_agents_count=usage.used_by_agents_count if usage is not None else 0, + attachments=( + ControlAttachments( + agents=[ + AgentRef(agent_name=agent_name) + for agent_name in attachments.agent_names + ], + policies=[ + PolicyRef(policy_id=policy_id) + for policy_id in attachments.policy_ids + ], + targets=[ + TargetAttachmentRef( + binding_id=target.binding_id, + target_type=target.target_type, + target_id=target.target_id, + enabled=target.enabled, + ) + for target in attachments.targets + ], + targets_total=attachments.targets_total, + targets_truncated=attachments.targets_truncated, + ) + if attachments is not None + else None + ), ) ) diff --git a/server/src/agent_control_server/migrate.py b/server/src/agent_control_server/migrate.py index fdf341d7..c6c4a1cc 100644 --- a/server/src/agent_control_server/migrate.py +++ b/server/src/agent_control_server/migrate.py @@ -110,6 +110,7 @@ def _acquire_migration_lock(connection: Connection, timeout_seconds: float) -> N ).scalar_one() ) if acquired: + connection.commit() LOGGER.info("Acquired Agent Control migration advisory lock.") return @@ -150,6 +151,7 @@ def _serialized_migration(cfg: Config, *, enabled: bool) -> Iterator[None]: _MIGRATION_LOCK_PARAMS, ).scalar_one() ) + connection.commit() if released: LOGGER.info("Released Agent Control migration advisory lock.") else: diff --git a/server/src/agent_control_server/models.py b/server/src/agent_control_server/models.py index cad73c23..c31ccddf 100644 --- a/server/src/agent_control_server/models.py +++ b/server/src/agent_control_server/models.py @@ -157,6 +157,13 @@ class Control(Base): UniqueConstraint( "namespace_key", "id", name="uq_controls_namespace_id" ), + # Hard deletes of clone sources are restricted. The request path + # soft-deletes controls so clone lineage remains intact. + ForeignKeyConstraint( + ["namespace_key", "cloned_from_control_id"], + ["controls.namespace_key", "controls.id"], + name="controls_cloned_from_control_fkey", + ), # Plain partial index on name preserves name-only lookup performance # while service code is still namespace-blind. Mirrors the pattern # used for agents and policies; the partial filter matches the @@ -167,6 +174,13 @@ class Control(Base): postgresql_where=text("deleted_at IS NULL"), sqlite_where=text("deleted_at IS NULL"), ), + Index( + "idx_controls_cloned_from", + "namespace_key", + "cloned_from_control_id", + postgresql_where=text("cloned_from_control_id IS NOT NULL"), + sqlite_where=text("cloned_from_control_id IS NOT NULL"), + ), ) id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) @@ -178,6 +192,9 @@ class Control(Base): data: Mapped[dict[str, Any]] = mapped_column( JSONB, server_default=text("'{}'::jsonb"), nullable=False ) + cloned_from_control_id: Mapped[int | None] = mapped_column( + Integer, nullable=True + ) deleted_at: Mapped[dt.datetime | None] = mapped_column( DateTime(timezone=True), nullable=True ) diff --git a/server/src/agent_control_server/services/control_bindings.py b/server/src/agent_control_server/services/control_bindings.py index 6f00e5d7..f6b04d44 100644 --- a/server/src/agent_control_server/services/control_bindings.py +++ b/server/src/agent_control_server/services/control_bindings.py @@ -173,6 +173,45 @@ async def delete_by_natural_key( await self._db.flush() return True + async def set_enabled_by_natural_key( + self, + *, + namespace_key: str, + target_type: str, + target_id: str, + control_id: int, + enabled: bool, + ) -> ControlBinding: + """Update an existing binding by natural key. + + Unlike ``upsert_by_natural_key``, this never creates a binding. + It is intended for target-scoped callers that need to toggle an + already-attached control while preserving a clear 404 for missing + attachments. + """ + existing = await self._find_by_natural_key( + namespace_key=namespace_key, + target_type=target_type, + target_id=target_id, + control_id=control_id, + ) + if existing is None: + raise NotFoundError( + error_code=ErrorCode.CONTROL_BINDING_NOT_FOUND, + detail=( + "Control binding not found for the supplied " + "(target_type, target_id, control_id)." + ), + resource="ControlBinding", + hint=( + "Verify the target and control IDs, or attach the control " + "before updating the binding." + ), + ) + existing.enabled = enabled + await self._db.flush() + return existing + async def _find_by_natural_key( self, *, diff --git a/server/src/agent_control_server/services/controls.py b/server/src/agent_control_server/services/controls.py index 6c015310..293fc130 100644 --- a/server/src/agent_control_server/services/controls.py +++ b/server/src/agent_control_server/services/controls.py @@ -13,7 +13,7 @@ from agent_control_models.errors import ErrorCode, ValidationErrorItem from agent_control_models.policy import Control as APIControl from pydantic import ValidationError -from sqlalchemy import Integer, String, delete, func, literal, or_, select, union, union_all +from sqlalchemy import Integer, String, delete, exists, func, literal, or_, select, union, union_all from sqlalchemy.dialects.postgresql import insert as pg_insert from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.sql import Select @@ -36,6 +36,8 @@ type AgentControlRenderedState = Literal["rendered", "unrendered", "all"] type AgentControlEnabledState = Literal["enabled", "disabled", "all"] +_MAX_INLINE_TARGET_ATTACHMENTS_PER_CONTROL = 20 + @dataclass(frozen=True) class RuntimeControl: @@ -74,6 +76,27 @@ class ControlUsage: used_by_agents_count: int +@dataclass(frozen=True) +class ControlTargetAttachment: + """Target binding attached to a control.""" + + binding_id: int + target_type: str + target_id: str + enabled: bool + + +@dataclass(frozen=True) +class ControlAttachmentSet: + """Direct attachments for a listed control.""" + + policy_ids: list[int] + agent_names: list[str] + targets: list[ControlTargetAttachment] + targets_total: int + targets_truncated: bool + + @dataclass(frozen=True) class ControlAssociations: """Policy and agent associations for a control.""" @@ -102,9 +125,15 @@ def create_control( namespace_key: str, name: str, data: dict[str, Any], + cloned_from_control_id: int | None = None, ) -> Control: """Create a new pending control row.""" - control = Control(namespace_key=namespace_key, name=name, data=data) + control = Control( + namespace_key=namespace_key, + name=name, + data=data, + cloned_from_control_id=cloned_from_control_id, + ) self._db.add(control) return control @@ -427,10 +456,13 @@ async def list_controls_page( name: str | None, enabled: bool | None, template_backed: bool | None, + cloned: bool | None, step_type: str | None, stage: str | None, execution: str | None, tag: str | None, + attachment_target_type: str | None = None, + attachment_target_id: str | None = None, ) -> ControlListPage: """Return paginated active controls for the browse endpoint.""" query = ( @@ -443,11 +475,18 @@ async def list_controls_page( name=name, enabled=enabled, template_backed=template_backed, + cloned=cloned, step_type=step_type, stage=stage, execution=execution, tag=tag, ) + query = self._apply_control_attachment_filters( + query, + namespace_key=namespace_key, + target_type=attachment_target_type, + target_id=attachment_target_id, + ) if cursor is not None: query = query.where(Control.id < cursor) @@ -464,11 +503,18 @@ async def list_controls_page( name=name, enabled=enabled, template_backed=template_backed, + cloned=cloned, step_type=step_type, stage=stage, execution=execution, tag=tag, ) + total_query = self._apply_control_attachment_filters( + total_query, + namespace_key=namespace_key, + target_type=attachment_target_type, + target_id=attachment_target_id, + ) total_result = await self._db.execute(total_query) total = cast(int, total_result.scalar_one()) @@ -535,6 +581,129 @@ async def list_control_usage( for control_id, agent_names in usage_names.items() } + async def list_control_attachments( + self, + control_ids: Sequence[int], + *, + namespace_key: str, + target_type: str | None = None, + target_id: str | None = None, + include_targets: bool = True, + ) -> dict[int, ControlAttachmentSet]: + """Return direct policy, direct agent, and target attachments for controls.""" + if not control_ids: + return {} + + unique_control_ids = list(dict.fromkeys(control_ids)) + policy_ids_by_control: dict[int, set[int]] = { + control_id: set() for control_id in unique_control_ids + } + agent_names_by_control: dict[int, set[str]] = { + control_id: set() for control_id in unique_control_ids + } + targets_by_control: dict[int, list[ControlTargetAttachment]] = { + control_id: [] for control_id in unique_control_ids + } + target_totals_by_control: dict[int, int] = { + control_id: 0 for control_id in unique_control_ids + } + + policy_result = await self._db.execute( + select(policy_controls.c.control_id, policy_controls.c.policy_id).where( + policy_controls.c.namespace_key == namespace_key, + policy_controls.c.control_id.in_(unique_control_ids), + ) + ) + for control_id, policy_id in policy_result.all(): + policy_ids_by_control[cast(int, control_id)].add(cast(int, policy_id)) + + agent_result = await self._db.execute( + select(agent_controls.c.control_id, agent_controls.c.agent_name).where( + agent_controls.c.namespace_key == namespace_key, + agent_controls.c.control_id.in_(unique_control_ids), + ) + ) + for control_id, agent_name in agent_result.all(): + agent_names_by_control[cast(int, control_id)].add(cast(str, agent_name)) + + if include_targets: + target_rank = func.row_number().over( + partition_by=ControlBinding.control_id, + order_by=ControlBinding.id.desc(), + ).label("target_rank") + target_total = func.count().over( + partition_by=ControlBinding.control_id + ).label("target_total") + target_query = ( + select( + ControlBinding.control_id, + ControlBinding.id, + ControlBinding.target_type, + ControlBinding.target_id, + ControlBinding.enabled, + target_rank, + target_total, + ) + .where( + ControlBinding.namespace_key == namespace_key, + ControlBinding.control_id.in_(unique_control_ids), + ) + ) + if target_type is not None: + target_query = target_query.where(ControlBinding.target_type == target_type) + if target_id is not None: + target_query = target_query.where(ControlBinding.target_id == target_id) + target_rows = target_query.subquery() + target_result = await self._db.execute( + select( + target_rows.c.control_id, + target_rows.c.id, + target_rows.c.target_type, + target_rows.c.target_id, + target_rows.c.enabled, + target_rows.c.target_total, + ) + .where( + target_rows.c.target_rank + <= _MAX_INLINE_TARGET_ATTACHMENTS_PER_CONTROL + ) + .order_by(target_rows.c.control_id, target_rows.c.target_rank) + ) + for ( + control_id, + binding_id, + binding_target_type, + binding_target_id, + enabled, + target_total, + ) in ( + target_result.all() + ): + typed_control_id = cast(int, control_id) + target_totals_by_control[typed_control_id] = cast(int, target_total) + targets_by_control[typed_control_id].append( + ControlTargetAttachment( + binding_id=cast(int, binding_id), + target_type=cast(str, binding_target_type), + target_id=cast(str, binding_target_id), + enabled=cast(bool, enabled), + ) + ) + + return { + control_id: ControlAttachmentSet( + policy_ids=sorted(policy_ids_by_control[control_id]), + agent_names=sorted(agent_names_by_control[control_id]), + targets=targets_by_control[control_id], + targets_total=target_totals_by_control[control_id], + targets_truncated=( + target_totals_by_control[control_id] + > len(targets_by_control[control_id]) + ), + ) + for control_id in unique_control_ids + } + async def list_active_control_counts_by_agent( self, agent_names: Sequence[str], @@ -820,6 +989,7 @@ def _apply_control_list_filters( name: str | None, enabled: bool | None, template_backed: bool | None, + cloned: bool | None, step_type: str | None, stage: str | None, execution: str | None, @@ -846,6 +1016,12 @@ def _apply_control_list_filters( else: stmt = stmt.where(~Control.data.has_key("template")) + if cloned is not None: + if cloned: + stmt = stmt.where(Control.cloned_from_control_id.is_not(None)) + else: + stmt = stmt.where(Control.cloned_from_control_id.is_(None)) + has_rendered_filter = any(f is not None for f in (step_type, stage, execution, tag)) if has_rendered_filter: stmt = stmt.where(Control.data.has_key("condition")) @@ -873,16 +1049,42 @@ def _apply_control_list_filters( return stmt + def _apply_control_attachment_filters( + self, + stmt: Select[Any], + *, + namespace_key: str, + target_type: str | None, + target_id: str | None, + ) -> Select[Any]: + """Restrict a control list to controls with matching target bindings.""" + if target_type is None and target_id is None: + return stmt + + binding_exists = exists().where( + ControlBinding.namespace_key == namespace_key, + ControlBinding.control_id == Control.id, + ) + if target_type is not None: + binding_exists = binding_exists.where(ControlBinding.target_type == target_type) + if target_id is not None: + binding_exists = binding_exists.where(ControlBinding.target_id == target_id) + return stmt.where(binding_exists) + @staticmethod def _build_snapshot(control: Control) -> dict[str, Any]: """Serialize the persisted control state stored in version history.""" deleted_at = control.deleted_at.isoformat() if control.deleted_at is not None else None - cloned_control_id = cast(int | None, getattr(control, "cloned_control_id", None)) + cloned_from_control_id = cast( + int | None, getattr(control, "cloned_from_control_id", None) + ) return { "name": control.name, "data": control.data, "deleted_at": deleted_at, - "cloned_control_id": cloned_control_id, + "cloned_from_control_id": cloned_from_control_id, + # Legacy snapshot alias; remove after consumers have migrated. + "cloned_control_id": cloned_from_control_id, } diff --git a/server/tests/test_auth_framework.py b/server/tests/test_auth_framework.py index a95f0252..5f31c52f 100644 --- a/server/tests/test_auth_framework.py +++ b/server/tests/test_auth_framework.py @@ -386,6 +386,19 @@ async def test_http_upstream_fails_closed_on_5xx(): assert "500" in exc_info.value.detail +@pytest.mark.asyncio +@pytest.mark.parametrize("status", [400, 422]) +async def test_http_upstream_unexpected_4xx_reports_upstream_rejection(status): + provider = _build_upstream(lambda req: httpx.Response(status, text="bad request")) + + with pytest.raises(APIError) as exc_info: + await provider.authorize(_build_request(), Operation.CONTROL_BINDINGS_WRITE) + + assert exc_info.value.status_code == 502 + assert exc_info.value.error_code == "AUTH_UPSTREAM_REJECTED" + assert str(status) in exc_info.value.detail + + @pytest.mark.asyncio async def test_http_upstream_surfaces_rate_limit_distinctly(): """Upstream 429 must surface a rate-limit-specific detail and hint.""" diff --git a/server/tests/test_control_bindings_endpoints.py b/server/tests/test_control_bindings_endpoints.py index 6cdebf7e..8333bb95 100644 --- a/server/tests/test_control_bindings_endpoints.py +++ b/server/tests/test_control_bindings_endpoints.py @@ -5,11 +5,12 @@ import uuid from typing import Any +from agent_control_server.auth_framework import Operation, Principal, set_authorizer +from agent_control_server.models import DEFAULT_NAMESPACE_KEY from fastapi.testclient import TestClient from .utils import VALID_CONTROL_PAYLOAD - _BINDINGS_URL = "/api/v1/control-bindings" @@ -318,6 +319,83 @@ def test_upsert_by_key_updates_updated_at_on_existing_row( assert after_upsert["updated_at"] != initial_updated_at +def test_patch_by_key_updates_existing_binding(client: TestClient) -> None: + control_id = _create_control(client) + binding_id = _create_binding(client, control_id=control_id)["binding_id"] + body = { + "target_type": "env", + "target_id": "prod", + "control_id": control_id, + "enabled": False, + } + + resp = client.patch(f"{_BINDINGS_URL}/by-key", json=body) + + assert resp.status_code == 200, resp.text + assert resp.json() == {"success": True, "enabled": False} + fetched = client.get(f"{_BINDINGS_URL}/{binding_id}").json() + assert fetched["enabled"] is False + + +def test_patch_by_key_returns_404_without_creating_binding(client: TestClient) -> None: + control_id = _create_control(client) + body = { + "target_type": "env", + "target_id": "prod", + "control_id": control_id, + "enabled": False, + } + + resp = client.patch(f"{_BINDINGS_URL}/by-key", json=body) + + assert resp.status_code == 404 + assert resp.json()["error_code"] == "CONTROL_BINDING_NOT_FOUND" + bindings = client.get( + _BINDINGS_URL, + params={"target_type": "env", "target_id": "prod", "control_id": control_id}, + ).json()["bindings"] + assert bindings == [] + + +def test_patch_by_key_passes_target_context_to_authorizer( + client: TestClient, +) -> None: + control_id = _create_control(client) + _create_binding(client, control_id=control_id) + calls: list[tuple[Operation, dict[str, Any] | None]] = [] + + class RecordingAuthorizer: + async def authorize( + self, + request: Any, + operation: Operation, + context: dict[str, Any] | None = None, + ) -> Principal: + del request + calls.append((operation, context)) + return Principal(namespace_key=DEFAULT_NAMESPACE_KEY, is_admin=True) + + set_authorizer(RecordingAuthorizer()) + + resp = client.patch( + f"{_BINDINGS_URL}/by-key", + json={ + "target_type": "env", + "target_id": "prod", + "control_id": control_id, + "enabled": False, + }, + ) + + assert resp.status_code == 200, resp.text + assert calls == [ + ( + Operation.CONTROL_BINDINGS_WRITE, + {"target_type": "env", "target_id": "prod"}, + ) + ] + + def test_delete_by_key_removes_existing_binding(client: TestClient) -> None: control_id = _create_control(client) client.put( @@ -366,6 +444,11 @@ def test_non_admin_cannot_use_by_key_endpoints( upsert_resp = non_admin_client.put(f"{_BINDINGS_URL}/by-key", json=body) assert upsert_resp.status_code == 403 + patch_resp = non_admin_client.patch( + f"{_BINDINGS_URL}/by-key", json={**body, "enabled": False} + ) + assert patch_resp.status_code == 403 + delete_resp = non_admin_client.post( f"{_BINDINGS_URL}/by-key:delete", json=body ) diff --git a/server/tests/test_control_versions.py b/server/tests/test_control_versions.py index f387a1f6..2af4fed2 100644 --- a/server/tests/test_control_versions.py +++ b/server/tests/test_control_versions.py @@ -53,6 +53,7 @@ def test_create_control_creates_initial_version_row(client: TestClient) -> None: assert version.snapshot["name"] == control_name assert version.snapshot["data"]["description"] == VALID_CONTROL_PAYLOAD["description"] assert version.snapshot["deleted_at"] is None + assert version.snapshot["cloned_from_control_id"] is None assert version.snapshot["cloned_control_id"] is None diff --git a/server/tests/test_controls_additional.py b/server/tests/test_controls_additional.py index dfbb15f5..5d1277c3 100644 --- a/server/tests/test_controls_additional.py +++ b/server/tests/test_controls_additional.py @@ -5,22 +5,30 @@ from collections.abc import AsyncGenerator from copy import deepcopy from types import SimpleNamespace +from typing import Any from unittest.mock import AsyncMock, MagicMock import pytest from agent_control_evaluators import RegexEvaluatorConfig from agent_control_models import ConditionNode +from agent_control_models.errors import ErrorCode from fastapi.testclient import TestClient -from sqlalchemy import text +from sqlalchemy import select, text from sqlalchemy.exc import IntegrityError from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import Session -from agent_control_server.auth_framework import Principal +from agent_control_server.auth_framework import Operation, Principal, set_authorizer from agent_control_server.db import get_async_db from agent_control_server.endpoints import controls as controls_module +from agent_control_server.errors import BadRequestError, ForbiddenError from agent_control_server.main import app -from agent_control_server.models import DEFAULT_NAMESPACE_KEY, Control +from agent_control_server.models import ( + DEFAULT_NAMESPACE_KEY, + Control, + ControlBinding, + ControlVersion, +) from .conftest import engine from .utils import VALID_CONTROL_PAYLOAD @@ -60,6 +68,503 @@ def _set_control_data(client: TestClient, control_id: int, data: dict) -> None: assert resp.status_code == 200, resp.text +def test_clone_and_bind_creates_cloned_control_binding_and_version( + client: TestClient, +) -> None: + source_id, source_name = _create_control(client) + + resp = client.post( + f"/api/v1/controls/{source_id}/clone-and-bind", + json={ + "target_binding": { + "target_type": "log_stream", + "target_id": "logstream-123", + "enabled": False, + } + }, + ) + + assert resp.status_code == 200, resp.text + body = resp.json() + clone_id = body["id"] + binding_id = body["binding_id"] + assert clone_id != source_id + assert body["name"].startswith(f"{source_name}-clone-") + assert body["cloned_from_control_id"] == source_id + + with Session(engine) as session: + source = session.get_one(Control, source_id) + clone = session.get_one(Control, clone_id) + binding = session.execute( + select(ControlBinding).where(ControlBinding.id == binding_id) + ).scalar_one() + version = session.execute( + select(ControlVersion).where(ControlVersion.control_id == clone_id) + ).scalar_one() + + assert clone.namespace_key == source.namespace_key + assert clone.data == source.data + assert clone.cloned_from_control_id == source_id + assert binding.control_id == clone_id + assert binding.target_type == "log_stream" + assert binding.target_id == "logstream-123" + assert binding.enabled is False + assert version.version_num == 1 + assert version.event_type == "cloned" + assert version.note == f"Cloned from control {source_id}" + assert version.snapshot["cloned_from_control_id"] == source_id + assert version.snapshot["cloned_control_id"] == source_id + + get_resp = client.get(f"/api/v1/controls/{clone_id}") + assert get_resp.status_code == 200 + assert get_resp.json()["cloned_from_control_id"] == source_id + + +def test_control_clone_lineage_enforces_same_namespace() -> None: + source = Control( + namespace_key=DEFAULT_NAMESPACE_KEY, + name=f"source-{uuid.uuid4()}", + data=deepcopy(VALID_CONTROL_PAYLOAD), + ) + clone = Control( + namespace_key="other-namespace", + name=f"clone-{uuid.uuid4()}", + data=deepcopy(VALID_CONTROL_PAYLOAD), + cloned_from_control_id=1, + ) + + with Session(engine) as session: + session.add(source) + session.flush() + clone.cloned_from_control_id = int(source.id) + session.add(clone) + with pytest.raises(IntegrityError): + session.commit() + + +def test_clone_and_bind_generated_name_falls_back_for_legacy_name( + client: TestClient, +) -> None: + legacy_name = "legacy control name" + with engine.begin() as conn: + conn.execute( + text( + "INSERT INTO controls (name, data) VALUES (:name, CAST(:data AS JSONB))" + ), + { + "name": legacy_name, + "data": json.dumps(VALID_CONTROL_PAYLOAD), + }, + ) + row = conn.execute( + text("SELECT id FROM controls WHERE name = :name"), + {"name": legacy_name}, + ).fetchone() + assert row is not None + source_id = row[0] + + resp = client.post( + f"/api/v1/controls/{source_id}/clone-and-bind", + json={ + "target_binding": { + "target_type": "log_stream", + "target_id": "logstream-legacy-name", + }, + }, + ) + + assert resp.status_code == 200, resp.text + assert resp.json()["name"].startswith(f"control-{source_id}-clone-") + + +def test_list_controls_filters_by_cloned_state(client: TestClient) -> None: + source_id, _ = _create_control(client, name=f"Root-{uuid.uuid4()}") + clone_name = f"Clone-{uuid.uuid4()}" + clone_resp = client.post( + f"/api/v1/controls/{source_id}/clone-and-bind", + json={ + "name": clone_name, + "target_binding": { + "target_type": "log_stream", + "target_id": "logstream-456", + }, + }, + ) + assert clone_resp.status_code == 200, clone_resp.text + clone_id = clone_resp.json()["id"] + + root_resp = client.get("/api/v1/controls", params={"cloned": False, "limit": 100}) + assert root_resp.status_code == 200 + root_ids = {control["id"] for control in root_resp.json()["controls"]} + assert source_id in root_ids + assert clone_id not in root_ids + + clone_list_resp = client.get( + "/api/v1/controls", params={"cloned": True, "limit": 100} + ) + assert clone_list_resp.status_code == 200 + cloned_controls = clone_list_resp.json()["controls"] + cloned_ids = {control["id"] for control in cloned_controls} + assert clone_id in cloned_ids + assert source_id not in cloned_ids + listed_clone = next(control for control in cloned_controls if control["id"] == clone_id) + assert listed_clone["cloned_from_control_id"] == source_id + + +def test_clone_and_bind_returns_conflict_for_duplicate_clone_name( + client: TestClient, +) -> None: + _, existing_name = _create_control(client) + source_id, _ = _create_control(client) + + resp = client.post( + f"/api/v1/controls/{source_id}/clone-and-bind", + json={ + "name": existing_name, + "target_binding": { + "target_type": "log_stream", + "target_id": "logstream-789", + }, + }, + ) + + assert resp.status_code == 409 + assert resp.json()["error_code"] == "CONTROL_NAME_CONFLICT" + + +def test_clone_and_bind_integrity_error_name_conflict_returns_409( + client: TestClient, + monkeypatch: pytest.MonkeyPatch, +) -> None: + source_id, _ = _create_control(client) + + async def fail_create_version( + self: controls_module.ControlService, + control: Control, + *, + event_type: str, + note: str, + ) -> None: + _ = (self, control, event_type, note) + raise _make_integrity_error("idx_controls_namespace_name_active") + + monkeypatch.setattr( + controls_module.ControlService, + "create_version", + fail_create_version, + ) + + resp = client.post( + f"/api/v1/controls/{source_id}/clone-and-bind", + json={ + "name": f"race-{uuid.uuid4()}", + "target_binding": { + "target_type": "log_stream", + "target_id": "logstream-race", + }, + }, + ) + + assert resp.status_code == 409 + assert resp.json()["error_code"] == "CONTROL_NAME_CONFLICT" + + +def test_clone_and_bind_generated_name_retries_preflight_conflicts( + client: TestClient, + monkeypatch: pytest.MonkeyPatch, +) -> None: + source_id, source_name = _create_control(client, name=f"source-{uuid.uuid4()}") + first_suffix = "1111111111111111" + second_suffix = "2222222222222222" + _create_control(client, name=f"{source_name}-clone-{first_suffix}") + suffixes = iter([first_suffix, second_suffix]) + + def fake_uuid4() -> SimpleNamespace: + return SimpleNamespace(hex=next(suffixes)) + + monkeypatch.setattr(controls_module.uuid, "uuid4", fake_uuid4) + + resp = client.post( + f"/api/v1/controls/{source_id}/clone-and-bind", + json={ + "target_binding": { + "target_type": "log_stream", + "target_id": "logstream-retry-name", + }, + }, + ) + + assert resp.status_code == 200, resp.text + assert resp.json()["name"] == f"{source_name}-clone-{second_suffix}" + + +def test_clone_and_bind_rolls_back_clone_when_binding_fails( + client: TestClient, + monkeypatch: pytest.MonkeyPatch, +) -> None: + source_id, _ = _create_control(client) + clone_name = f"CloneRollback-{uuid.uuid4()}" + + async def fail_create_binding(*args: Any, **kwargs: Any) -> None: + raise BadRequestError( + error_code=ErrorCode.CONTROL_BINDING_INCOMPATIBLE, + detail="Binding failed after clone creation.", + resource="ControlBinding", + ) + + monkeypatch.setattr( + controls_module.ControlBindingsService, + "create_binding", + fail_create_binding, + ) + + resp = client.post( + f"/api/v1/controls/{source_id}/clone-and-bind", + json={ + "name": clone_name, + "target_binding": { + "target_type": "log_stream", + "target_id": "logstream-rollback", + }, + }, + ) + + assert resp.status_code == 400 + with Session(engine) as session: + clone = session.execute( + select(Control).where(Control.name == clone_name) + ).scalar_one_or_none() + assert clone is None + + +def test_clone_and_bind_locks_source_control( + client: TestClient, + monkeypatch: pytest.MonkeyPatch, +) -> None: + source_id, _ = _create_control(client) + original_get_active = controls_module.ControlService.get_active_control_or_404 + seen_for_update: list[bool] = [] + + async def recording_get_active( + self: controls_module.ControlService, + control_id_arg: int, + *, + namespace_key: str | None = None, + for_update: bool = False, + ) -> Control: + seen_for_update.append(for_update) + return await original_get_active( + self, + control_id_arg, + namespace_key=namespace_key, + for_update=for_update, + ) + + monkeypatch.setattr( + controls_module.ControlService, + "get_active_control_or_404", + recording_get_active, + ) + + resp = client.post( + f"/api/v1/controls/{source_id}/clone-and-bind", + json={ + "target_binding": { + "target_type": "log_stream", + "target_id": "logstream-lock", + }, + }, + ) + + assert resp.status_code == 200, resp.text + assert seen_for_update == [True] + + +def test_clone_and_bind_rejects_auth_namespace_mismatch(client: TestClient) -> None: + source_id, _ = _create_control(client) + + class MismatchedNamespaceAuthorizer: + async def authorize( + self, + request: Any, + operation: Operation, + context: dict[str, Any] | None = None, + ) -> Principal: + namespace_key = ( + "other-namespace" + if operation == Operation.CONTROL_BINDINGS_WRITE + else DEFAULT_NAMESPACE_KEY + ) + return Principal(namespace_key=namespace_key, is_admin=True) + + set_authorizer(MismatchedNamespaceAuthorizer()) + + resp = client.post( + f"/api/v1/controls/{source_id}/clone-and-bind", + json={ + "target_binding": { + "target_type": "log_stream", + "target_id": "logstream-mismatch", + }, + }, + ) + + assert resp.status_code == 403 + assert resp.json()["error_code"] == "AUTH_INSUFFICIENT_PRIVILEGES" + + +def test_clone_and_bind_requires_source_read_authorization( + client: TestClient, +) -> None: + source_id, _ = _create_control(client) + calls: list[tuple[Operation, dict[str, Any] | None]] = [] + + class ReadMismatchAuthorizer: + async def authorize( + self, + request: Any, + operation: Operation, + context: dict[str, Any] | None = None, + ) -> Principal: + calls.append((operation, context)) + namespace_key = ( + "other-namespace" + if operation == Operation.CONTROLS_READ + else DEFAULT_NAMESPACE_KEY + ) + return Principal(namespace_key=namespace_key, is_admin=True) + + set_authorizer(ReadMismatchAuthorizer()) + + resp = client.post( + f"/api/v1/controls/{source_id}/clone-and-bind", + json={ + "target_binding": { + "target_type": "log_stream", + "target_id": "logstream-read-auth", + }, + }, + ) + + assert resp.status_code == 403 + assert resp.json()["error_code"] == "AUTH_INSUFFICIENT_PRIVILEGES" + read_contexts = [ + context for operation, context in calls if operation == Operation.CONTROLS_READ + ] + assert read_contexts == [None] + + +def test_clone_and_bind_context_tolerates_invalid_body_shapes( + client: TestClient, +) -> None: + resp = client.post( + "/api/v1/controls/1/clone-and-bind", + content="{", + headers={"Content-Type": "application/json"}, + ) + assert resp.status_code == 422 + + list_resp = client.post("/api/v1/controls/1/clone-and-bind", json=[]) + assert list_resp.status_code == 422 + + bad_target_resp = client.post( + "/api/v1/controls/1/clone-and-bind", + json={"target_binding": "not-an-object"}, + ) + assert bad_target_resp.status_code == 422 + + +def test_clone_and_bind_context_drops_invalid_target_fields( + client: TestClient, +) -> None: + calls: list[tuple[Operation, dict[str, Any] | None]] = [] + + class RecordingAuthorizer: + async def authorize( + self, + request: Any, + operation: Operation, + context: dict[str, Any] | None = None, + ) -> Principal: + calls.append((operation, context)) + return Principal(namespace_key=DEFAULT_NAMESPACE_KEY, is_admin=True) + + set_authorizer(RecordingAuthorizer()) + + resp = client.post( + "/api/v1/controls/1/clone-and-bind", + json={ + "target_binding": { + "target_type": ["log_stream"], + "target_id": {"id": "logstream-invalid"}, + }, + }, + ) + + assert resp.status_code == 422 + binding_contexts = [ + context + for operation, context in calls + if operation == Operation.CONTROL_BINDINGS_WRITE + ] + assert binding_contexts == [{}] + + +def test_clone_and_bind_context_drops_overlong_target_fields( + client: TestClient, +) -> None: + calls: list[tuple[Operation, dict[str, Any] | None]] = [] + + class RecordingAuthorizer: + async def authorize( + self, + request: Any, + operation: Operation, + context: dict[str, Any] | None = None, + ) -> Principal: + calls.append((operation, context)) + return Principal(namespace_key=DEFAULT_NAMESPACE_KEY, is_admin=True) + + set_authorizer(RecordingAuthorizer()) + + resp = client.post( + "/api/v1/controls/1/clone-and-bind", + json={ + "target_binding": { + "target_type": "x" * 256, + "target_id": "logstream-invalid", + }, + }, + ) + + assert resp.status_code == 422 + binding_contexts = [ + context + for operation, context in calls + if operation == Operation.CONTROL_BINDINGS_WRITE + ] + assert binding_contexts == [{}] + + +def test_clone_and_bind_rejects_unknown_target_binding_fields( + client: TestClient, +) -> None: + source_id, _ = _create_control(client) + + resp = client.post( + f"/api/v1/controls/{source_id}/clone-and-bind", + json={ + "target_binding": { + "target_type": "log_stream", + "target_id": "logstream-extra", + "unknown_field": "ignored-before", + }, + }, + ) + + assert resp.status_code == 422 + + @pytest.mark.parametrize( "constraint_name", ["idx_controls_name_active", "idx_controls_namespace_name_active"], @@ -260,6 +765,7 @@ def test_list_controls_filters_and_pagination(client: TestClient) -> None: control = resp.json()["controls"][0] assert control["name"] == control3_name assert control["enabled"] is True + assert control["action"] == {"decision": "deny", "steering_context": None} # When: paginating resp = client.get("/api/v1/controls", params={"limit": 1}) @@ -696,19 +1202,337 @@ def test_delete_control_force_dissociates_direct_agent_links(client: TestClient) assert list_resp.json()["pagination"]["total"] == 0 -def _create_target_binding(client: TestClient, *, control_id: int) -> int: +def _create_target_binding( + client: TestClient, + *, + control_id: int, + target_type: str = "env", + target_id: str = "prod", + enabled: bool = True, +) -> int: resp = client.put( "/api/v1/control-bindings", json={ - "target_type": "env", - "target_id": "prod", + "target_type": target_type, + "target_id": target_id, "control_id": control_id, + "enabled": enabled, }, ) assert resp.status_code == 200, resp.text return int(resp.json()["binding_id"]) +def test_list_controls_returns_null_attachments_by_default( + client: TestClient, +) -> None: + control_id, control_name = _create_control(client, name=f"Attachments-{uuid.uuid4()}") + _set_control_data(client, control_id, deepcopy(VALID_CONTROL_PAYLOAD)) + + resp = client.get("/api/v1/controls", params={"name": control_name}) + + assert resp.status_code == 200, resp.text + controls = resp.json()["controls"] + assert len(controls) == 1 + assert controls[0]["id"] == control_id + assert controls[0]["attachments"] is None + + +def test_list_controls_filters_by_target_attachment_before_pagination( + client: TestClient, +) -> None: + prefix = f"AttachmentFilter-{uuid.uuid4()}" + target_id = f"ls-{uuid.uuid4()}" + matching_control_id, _ = _create_control(client, name=f"{prefix}-matching") + _set_control_data(client, matching_control_id, deepcopy(VALID_CONTROL_PAYLOAD)) + matching_binding_id = _create_target_binding( + client, + control_id=matching_control_id, + target_type="log_stream", + target_id=target_id, + ) + + newer_unmatched_control_id, _ = _create_control(client, name=f"{prefix}-unmatched") + _set_control_data(client, newer_unmatched_control_id, deepcopy(VALID_CONTROL_PAYLOAD)) + + resp = client.get( + "/api/v1/controls", + params={ + "name": prefix, + "include_attachments": "true", + "attachment_target_type": "log_stream", + "attachment_target_id": target_id, + "limit": 1, + }, + ) + + assert resp.status_code == 200, resp.text + body = resp.json() + assert body["pagination"]["total"] == 1 + assert body["pagination"]["has_more"] is False + controls = body["controls"] + assert len(controls) == 1 + assert controls[0]["id"] == matching_control_id + assert controls[0]["id"] != newer_unmatched_control_id + assert controls[0]["attachments"]["targets"] == [ + { + "binding_id": matching_binding_id, + "target_type": "log_stream", + "target_id": target_id, + "enabled": True, + } + ] + assert controls[0]["attachments"]["targets_total"] == 1 + assert controls[0]["attachments"]["targets_truncated"] is False + + +def test_list_controls_expands_filtered_control_attachments( + client: TestClient, +) -> None: + control_id, control_name = _create_control(client, name=f"Attachments-{uuid.uuid4()}") + _set_control_data(client, control_id, deepcopy(VALID_CONTROL_PAYLOAD)) + + policy_resp = client.put("/api/v1/policies", json={"name": f"pol-{uuid.uuid4()}"}) + assert policy_resp.status_code == 200 + policy_id = policy_resp.json()["policy_id"] + policy_assoc_resp = client.post(f"/api/v1/policies/{policy_id}/controls/{control_id}") + assert policy_assoc_resp.status_code == 200 + + agent_name = f"agent-{uuid.uuid4().hex[:12]}" + init_resp = client.post( + "/api/v1/agents/initAgent", + json={"agent": {"agent_name": agent_name}, "steps": []}, + ) + assert init_resp.status_code == 200 + agent_assoc_resp = client.post(f"/api/v1/agents/{agent_name}/controls/{control_id}") + assert agent_assoc_resp.status_code == 200 + + included_binding_id = _create_target_binding( + client, + control_id=control_id, + target_type="log_stream", + target_id="ls-prod", + enabled=False, + ) + _create_target_binding( + client, + control_id=control_id, + target_type="log_stream", + target_id="ls-dev", + ) + _create_target_binding( + client, + control_id=control_id, + target_type="environment", + target_id="prod", + ) + + resp = client.get( + "/api/v1/controls", + params={ + "name": control_name, + "include_attachments": "true", + "attachment_target_type": "log_stream", + "attachment_target_id": "ls-prod", + }, + ) + + assert resp.status_code == 200, resp.text + controls = resp.json()["controls"] + assert len(controls) == 1 + assert controls[0]["attachments"] == { + "agents": [{"agent_name": agent_name}], + "policies": [{"policy_id": policy_id}], + "targets": [ + { + "binding_id": included_binding_id, + "target_type": "log_stream", + "target_id": "ls-prod", + "enabled": False, + } + ], + "targets_total": 1, + "targets_truncated": False, + } + + +def test_list_controls_caps_inline_target_attachments( + client: TestClient, +) -> None: + control_id, control_name = _create_control(client, name=f"Attachments-{uuid.uuid4()}") + _set_control_data(client, control_id, deepcopy(VALID_CONTROL_PAYLOAD)) + binding_ids = [ + _create_target_binding( + client, + control_id=control_id, + target_type="log_stream", + target_id=f"ls-{index}", + ) + for index in range(25) + ] + + resp = client.get( + "/api/v1/controls", + params={ + "name": control_name, + "include_attachments": "true", + }, + ) + + assert resp.status_code == 200, resp.text + attachments = resp.json()["controls"][0]["attachments"] + assert len(attachments["targets"]) == 20 + assert attachments["targets_total"] == 25 + assert attachments["targets_truncated"] is True + assert [target["binding_id"] for target in attachments["targets"]] == list( + reversed(binding_ids[-20:]) + ) + + +def test_list_controls_omits_targets_without_binding_read_authorization( + client: TestClient, +) -> None: + control_id, control_name = _create_control(client, name=f"Attachments-{uuid.uuid4()}") + _set_control_data(client, control_id, deepcopy(VALID_CONTROL_PAYLOAD)) + _create_target_binding( + client, + control_id=control_id, + target_type="log_stream", + target_id="ls-prod", + ) + calls: list[tuple[Operation, dict[str, Any] | None]] = [] + + class BindingReadDenyAuthorizer: + async def authorize( + self, + request: Any, + operation: Operation, + context: dict[str, Any] | None = None, + ) -> Principal: + calls.append((operation, context)) + if operation == Operation.CONTROL_BINDINGS_READ: + raise ForbiddenError( + error_code=ErrorCode.AUTH_INSUFFICIENT_PRIVILEGES, + detail="No target read access.", + ) + return Principal(namespace_key=DEFAULT_NAMESPACE_KEY, is_admin=True) + + set_authorizer(BindingReadDenyAuthorizer()) + + resp = client.get( + "/api/v1/controls", + params={ + "name": control_name, + "include_attachments": "true", + }, + ) + + assert resp.status_code == 200, resp.text + controls = resp.json()["controls"] + assert controls[0]["attachments"] == { + "agents": [], + "policies": [], + "targets": [], + "targets_total": 0, + "targets_truncated": False, + } + assert (Operation.CONTROL_BINDINGS_READ, {}) in calls + + +def test_list_controls_rejects_target_filter_without_binding_read_authorization( + client: TestClient, +) -> None: + control_id, control_name = _create_control(client, name=f"Attachments-{uuid.uuid4()}") + _set_control_data(client, control_id, deepcopy(VALID_CONTROL_PAYLOAD)) + _create_target_binding( + client, + control_id=control_id, + target_type="log_stream", + target_id="ls-prod", + ) + calls: list[tuple[Operation, dict[str, Any] | None]] = [] + + class BindingReadDenyAuthorizer: + async def authorize( + self, + request: Any, + operation: Operation, + context: dict[str, Any] | None = None, + ) -> Principal: + calls.append((operation, context)) + if operation == Operation.CONTROL_BINDINGS_READ: + raise ForbiddenError( + error_code=ErrorCode.AUTH_INSUFFICIENT_PRIVILEGES, + detail="No target read access.", + ) + return Principal(namespace_key=DEFAULT_NAMESPACE_KEY, is_admin=True) + + set_authorizer(BindingReadDenyAuthorizer()) + + resp = client.get( + "/api/v1/controls", + params={ + "name": control_name, + "include_attachments": "true", + "attachment_target_type": "log_stream", + "attachment_target_id": "ls-prod", + }, + ) + + assert resp.status_code == 403 + assert resp.json()["error_code"] == "AUTH_INSUFFICIENT_PRIVILEGES" + assert ( + Operation.CONTROL_BINDINGS_READ, + {"target_type": "log_stream", "target_id": "ls-prod"}, + ) in calls + + +def test_list_controls_rejects_attachment_namespace_mismatch( + client: TestClient, +) -> None: + control_id, control_name = _create_control(client, name=f"Attachments-{uuid.uuid4()}") + _set_control_data(client, control_id, deepcopy(VALID_CONTROL_PAYLOAD)) + + class MismatchedBindingReadAuthorizer: + async def authorize( + self, + request: Any, + operation: Operation, + context: dict[str, Any] | None = None, + ) -> Principal: + namespace_key = ( + "other-namespace" + if operation == Operation.CONTROL_BINDINGS_READ + else DEFAULT_NAMESPACE_KEY + ) + return Principal(namespace_key=namespace_key, is_admin=True) + + set_authorizer(MismatchedBindingReadAuthorizer()) + + resp = client.get( + "/api/v1/controls", + params={ + "name": control_name, + "include_attachments": "true", + }, + ) + + assert resp.status_code == 403 + assert resp.json()["error_code"] == "AUTH_INSUFFICIENT_PRIVILEGES" + + +def test_list_controls_rejects_attachment_filters_without_expansion( + client: TestClient, +) -> None: + resp = client.get( + "/api/v1/controls", + params={"attachment_target_type": "log_stream"}, + ) + + assert resp.status_code == 422 + assert resp.json()["error_code"] == "VALIDATION_ERROR" + + def test_delete_control_blocks_when_target_binding_exists( client: TestClient, ) -> None: diff --git a/server/tests/test_data_model_v1_alembic_migration.py b/server/tests/test_data_model_v1_alembic_migration.py index 9e6764fa..53334732 100644 --- a/server/tests/test_data_model_v1_alembic_migration.py +++ b/server/tests/test_data_model_v1_alembic_migration.py @@ -17,6 +17,7 @@ PRE_MIGRATION_REVISION = "c1e9f9c4a1d2" MIGRATION_REVISION = "a7f3b1e0d9c5" OBSERVABILITY_NAMESPACE_REVISION = "b6f4c2d8e9a1" +CLONE_LINEAGE_REVISION = "e2b7f4a9c6d1" _BASE_DB_URL = make_url(db_config.get_url()) pytestmark = pytest.mark.skipif( @@ -103,6 +104,37 @@ def _assert_observability_namespace_schema(engine: Engine) -> None: assert "ix_events_agent_time" not in indexes +def _pg_index_definition(engine: Engine, index_name: str) -> str: + with engine.begin() as conn: + return str( + conn.execute( + text( + """ + SELECT indexdef + FROM pg_indexes + WHERE tablename = 'controls' AND indexname = :index_name + """ + ), + {"index_name": index_name}, + ).scalar_one() + ) + + +def _pg_constraint_definition(engine: Engine, constraint_name: str) -> tuple[str, str]: + with engine.begin() as conn: + row = conn.execute( + text( + """ + SELECT pg_get_constraintdef(oid) AS definition, confdeltype + FROM pg_constraint + WHERE conname = :constraint_name + """ + ), + {"constraint_name": constraint_name}, + ).one() + return str(row.definition), str(row.confdeltype) + + def test_upgrade_applies_namespace_columns_and_constraints( alembic_config: Config, temp_engine: Engine ) -> None: @@ -292,6 +324,78 @@ def test_observability_namespace_migration_recovers_when_primary_key_preexists( _assert_observability_namespace_schema(temp_engine) +def test_control_clone_lineage_migration_adds_composite_fk_and_partial_index( + alembic_config: Config, temp_engine: Engine +) -> None: + command.upgrade(alembic_config, CLONE_LINEAGE_REVISION) + + assert "cloned_from_control_id" in _column_names(temp_engine, "controls") + assert "controls_cloned_from_control_fkey" in _foreign_key_names( + temp_engine, "controls" + ) + assert "idx_controls_cloned_from" in _index_names(temp_engine, "controls") + + constraint_def, delete_action = _pg_constraint_definition( + temp_engine, + "controls_cloned_from_control_fkey", + ) + assert ( + "FOREIGN KEY (namespace_key, cloned_from_control_id) " + "REFERENCES controls(namespace_key, id)" + ) in constraint_def + assert delete_action == "a" + + index_def = _pg_index_definition(temp_engine, "idx_controls_cloned_from") + assert "CREATE INDEX idx_controls_cloned_from" in index_def + assert "ON public.controls USING btree (namespace_key, cloned_from_control_id)" in index_def + assert "WHERE (cloned_from_control_id IS NOT NULL)" in index_def + + with temp_engine.begin() as conn: + source_id = conn.execute( + text( + """ + INSERT INTO controls (namespace_key, name, data) + VALUES ('ns-one', 'source', '{}'::jsonb) + RETURNING id + """ + ) + ).scalar_one() + conn.execute( + text( + """ + INSERT INTO controls (namespace_key, name, data, cloned_from_control_id) + VALUES ('ns-one', 'clone', '{}'::jsonb, :source_id) + """ + ), + {"source_id": source_id}, + ) + + with pytest.raises(Exception): + with temp_engine.begin() as conn: + conn.execute( + text( + """ + INSERT INTO controls ( + namespace_key, name, data, cloned_from_control_id + ) + VALUES ('ns-two', 'bad-clone', '{}'::jsonb, :source_id) + """ + ), + {"source_id": source_id}, + ) + + command.downgrade(alembic_config, OBSERVABILITY_NAMESPACE_REVISION) + + assert "cloned_from_control_id" not in _column_names(temp_engine, "controls") + assert "controls_cloned_from_control_fkey" not in _foreign_key_names( + temp_engine, "controls" + ) + assert "idx_controls_cloned_from" not in _index_names(temp_engine, "controls") + indexes = _index_names(temp_engine, "control_execution_events") + assert "ix_events_namespace_agent_time" in indexes + assert "ix_events_agent_time" not in indexes + + def test_downgrade_rejects_cross_namespace_agents_duplicates( alembic_config: Config, temp_engine: Engine ) -> None: diff --git a/server/tests/test_migrate.py b/server/tests/test_migrate.py index e6335753..619a8658 100644 --- a/server/tests/test_migrate.py +++ b/server/tests/test_migrate.py @@ -19,6 +19,7 @@ class _FakeConnection: def __init__(self, lock_results: list[bool]) -> None: self.lock_results = lock_results self.statements: list[str] = [] + self.commits = 0 def __enter__(self) -> _FakeConnection: return self @@ -35,6 +36,9 @@ def execute(self, statement: object, params: object) -> _FakeResult: return _FakeResult(True) raise AssertionError(f"unexpected SQL statement: {statement_text}") + def commit(self) -> None: + self.commits += 1 + class _FakeEngine: def __init__(self, connection: _FakeConnection) -> None: @@ -112,6 +116,7 @@ def fake_create_engine(*args: object, **kwargs: object) -> _FakeEngine: "SELECT pg_try_advisory_lock(:class_id, :object_id)", "SELECT pg_advisory_unlock(:class_id, :object_id)", ] + assert connection.commits == 2 assert sleeps == [2.0] assert engine.disposed diff --git a/server/tests/test_principal_namespace_flow.py b/server/tests/test_principal_namespace_flow.py index 0ca1bca8..8f16a795 100644 --- a/server/tests/test_principal_namespace_flow.py +++ b/server/tests/test_principal_namespace_flow.py @@ -187,6 +187,18 @@ def test_principal_namespace_scopes_cross_namespace_writes(app: FastAPI) -> None ).status_code == 404 ) + assert ( + ns_b.patch( + "/api/v1/control-bindings/by-key", + json={ + "target_type": "env", + "target_id": "prod", + "control_id": control_id, + "enabled": False, + }, + ).status_code + == 404 + ) delete_binding = ns_b.post( "/api/v1/control-bindings/by-key:delete", json={