diff --git a/agentflow_cli/cli/commands/eval.py b/agentflow_cli/cli/commands/eval.py index 62852ea..3552085 100644 --- a/agentflow_cli/cli/commands/eval.py +++ b/agentflow_cli/cli/commands/eval.py @@ -39,6 +39,7 @@ class _PendingCase: case: Any # EvalCase evaluator: AgentEvaluator + config: Any # EvalConfig — the resolved config for this case file_name: str eval_set_id: str eval_set_name: str @@ -223,8 +224,8 @@ def _collect_from_file( """Load a module and return pending work for every eval case or simulation scenario. Config priority chain (highest → lowest): - 1. confeval.py — global_config (not None when confeval was found) - 2. Per-file — EVAL_CONFIG / get_eval_config() inside the file + 1. Per-file — EVAL_CONFIG / get_eval_config() inside the file + 2. confeval.py — global_config (fallback when no per-file config) 3. Built-in defaults — _default_config() Returns _PendingSimulation items when the file exposes get_scenarios() or SCENARIOS. @@ -253,8 +254,8 @@ def _collect_from_file( file_config = mod.get_eval_config() elif hasattr(mod, "EVAL_CONFIG"): file_config = mod.EVAL_CONFIG - # Priority: confeval > per-file > defaults - config = global_config or file_config or self._default_config() + # Priority: per-file > confeval > defaults + config = file_config or global_config or self._default_config() return self._make_pending(mod, mod.get_eval_set(), config, file_name) # pytest-style discovery @@ -267,8 +268,8 @@ def _collect_from_file( if hasattr(mod, "EVAL_CONFIG") else None ) - # Priority: confeval > per-file > defaults - config = global_config or file_config or self._default_config() + # Priority: per-file > confeval > defaults + config = file_config or global_config or self._default_config() pending: list[_PendingCase] = [] for _, es in eval_pairs: pending.extend(self._make_pending(mod, es, config, file_name)) @@ -329,6 +330,7 @@ def _make_pending( _PendingCase( case=c, evaluator=evaluator, + config=config, file_name=file_name, 
eval_set_id=eval_set.eval_set_id, eval_set_name=eval_set.name, @@ -531,7 +533,9 @@ async def _run_one( # Report merging # ------------------------------------------------------------------ - def _merge_reports(self, reports: list[EvalReport]) -> EvalReport: + def _merge_reports( + self, reports: list[EvalReport], base_config: Any = None + ) -> EvalReport: if len(reports) == 1: return reports[0] @@ -542,6 +546,7 @@ def _merge_reports(self, reports: list[EvalReport]) -> EvalReport: eval_set_id="combined_eval", eval_set_name="Combined Evaluation", results=all_results, + config_used=base_config.model_dump() if base_config else {}, ) # ------------------------------------------------------------------ @@ -673,6 +678,13 @@ def execute( # noqa: PLR0912, PLR0915 self.output.error("No results produced.") return 1 + # Build per-eval-set config map from pending before results are consumed + group_configs: dict[str, Any] = { + pc.eval_set_id: pc.config + for pc in pending + if isinstance(pc, _PendingCase) + } + # 7. Group by eval_set_id → one EvalReport per set groups: dict[str, tuple[str, list[EvalCaseResult]]] = defaultdict(lambda: ("", [])) for _, eval_set_id, eval_set_name, result in quads: @@ -681,17 +693,18 @@ def execute( # noqa: PLR0912, PLR0915 reports: list[EvalReport] = [] for eval_set_id, (eval_set_name, results) in groups.items(): + group_cfg = group_configs.get(eval_set_id) or confeval_config or self._default_config() reports.append( ER.create( eval_set_id=eval_set_id, eval_set_name=eval_set_name, results=results, - config_used=(confeval_config or self._default_config()).model_dump(), + config_used=group_cfg.model_dump(), ) ) # 8. Merge into a single report - merged = self._merge_reports(reports) + merged = self._merge_reports(reports, base_config=confeval_config or self._default_config()) # 9. 
Determine exit code if threshold is not None and merged.summary.pass_rate < threshold: diff --git a/agentflow_cli/cli/commands/skills.py b/agentflow_cli/cli/commands/skills.py index d1280cd..6c9ba24 100644 --- a/agentflow_cli/cli/commands/skills.py +++ b/agentflow_cli/cli/commands/skills.py @@ -92,6 +92,11 @@ def kind(self) -> str: source_relpath="agent-skills", manifest=True, ), + _InstallArtifact( + kind="file", + install_relpath=".github/skills/agentflow/SKILL.md", + source_relpath="copilot/SKILL.md", + ), ), ), ) diff --git a/agentflow_cli/cli/templates/skills/copilot/SKILL.md b/agentflow_cli/cli/templates/skills/copilot/SKILL.md new file mode 100644 index 0000000..a42d13d --- /dev/null +++ b/agentflow_cli/cli/templates/skills/copilot/SKILL.md @@ -0,0 +1,82 @@ +--- +name: agentflow +description: "Expert guidance for building, debugging, and extending applications with AgentFlow (10xscale-agentflow). TRIGGER when: code imports from agentflow (e.g. `from agentflow import`, `StateGraph`, `Agent`, `ToolNode`, `AgentState`); user references `agentflow.json` or CLI commands (`agentflow init`, `agentflow api`, `agentflow play`, `agentflow build`, `agentflow skills`); user is building graph-based multi-agent workflows, tools, memory, checkpointing, or streaming with this framework. SKIP: generic Python or multi-agent questions not referencing agentflow; other frameworks (LangGraph, CrewAI, AutoGen) unless comparing." +--- + +# Agentflow Project Skill + +Use this skill when working in an Agentflow project. Agentflow is a multi-agent framework that wraps official OpenAI and Google SDK capabilities behind a unified graph, agent, tool, state, storage, API, CLI, and TypeScript client interface. + +Treat https://agentflow.10xscale.ai/ as the first source of truth for public package names, install commands, and user-facing behavior. Use implementation source after the docs establish the intended API. + +## Workflow + +1. 
Identify the published package or docs surface involved: + - PyPI core Python SDK: `10xscale-agentflow` (`pip install 10xscale-agentflow`), source at https://github.com/10xHub/Agentflow/tree/main/agentflow/agentflow + - PyPI API/CLI SDK: `10xscale-agentflow-cli` (`pip install 10xscale-agentflow-cli`), source at https://github.com/10xHub/Agentflow/tree/main/agentflow-api/agentflow_cli + - npm TypeScript SDK: `@10xscale/agentflow-client` (`npm install @10xscale/agentflow-client`), source at https://github.com/10xHub/Agentflow/tree/main/agentflow-client/src + - Main docs: https://agentflow.10xscale.ai/ + - Playground/UI: run the `agentflow play` command after installing the CLI + +2. Read the matching reference file before changing behavior: + + ### Core Python SDK + - Architecture and package flow: `.github/skills/agentflow/references/architecture.md` + - Agent constructor, provider, reasoning, retry, fallback, output_schema: `.github/skills/agentflow/references/agents-and-tools.md` + - Graph construction, nodes, edges, compile, interrupts, config keys: `.github/skills/agentflow/references/state-graph.md` + - State, messages, and content blocks: `.github/skills/agentflow/references/state-and-messages.md` + - Thread and checkpointing: `.github/skills/agentflow/references/checkpointing-and-threads.md` + - Dependency injection (InjectQ): `.github/skills/agentflow/references/dependency-injection.md` + - Multimodal files and media stores: `.github/skills/agentflow/references/media-and-files.md` + - Long-term memory stores (MemoryConfig, QdrantStore, Mem0Store): `.github/skills/agentflow/references/memory-and-store.md` + - Streaming, StreamChunk, SSE, ResponseGranularity: `.github/skills/agentflow/references/streaming.md` + - Stream emitter for tool progress updates: `.github/skills/agentflow/references/stream-emitter.md` + - Observability hooks, validators, and runtime jumps: `.github/skills/agentflow/references/callbacks-and-command.md` + - Prebuilt agents (ReactAgent, 
PlanActReflectAgent, StructuredOutputAgent, SupervisorTeamAgent, SwarmAgent, RAGAgent) and tools: `.github/skills/agentflow/references/prebuilt-agents-and-tools.md` + - Event publishers and A2A/ACP runtime protocols: `.github/skills/agentflow/references/publishers-and-runtime-protocols.md` + - Context management, ID generation, and background tasks: `.github/skills/agentflow/references/context-id-background.md` + - Provider internals and adapters: `.github/skills/agentflow/references/providers-and-adapters.md` + - Prompt-injection and validation safety: `.github/skills/agentflow/references/security-and-validators.md` + + ### API/CLI SDK + - CLI commands and generated project files: `.github/skills/agentflow/references/cli-commands.md` + - `agentflow.json` and dependency loading: `.github/skills/agentflow/references/api-configuration.md` + - API auth and authorization: `.github/skills/agentflow/references/auth-and-authorization.md` + - API environment, settings, and middleware: `.github/skills/agentflow/references/api-settings-and-middleware.md` + - Rate limiting (config, backends, headers, custom backend): `.github/skills/agentflow/references/rate-limiting.md` + - REST routes and error behavior: `.github/skills/agentflow/references/rest-api-and-errors.md` + - API Snowflake IDs and thread naming: `.github/skills/agentflow/references/id-and-thread-name-generators.md` + - API server and deployment runtime: `.github/skills/agentflow/references/production-runtime.md` + + ### TypeScript client SDK + - REST and TypeScript client surface: `.github/skills/agentflow/references/api-client.md` + - Browser/client-side tool execution: `.github/skills/agentflow/references/remote-tools.md` + - TypeScript auth helpers and structured errors: `.github/skills/agentflow/references/client-auth-and-errors.md` + - TypeScript messages, invoke, and stream details: `.github/skills/agentflow/references/client-messages-invoke-stream.md` + - TypeScript thread, memory, and file APIs: 
`.github/skills/agentflow/references/client-threads-memory-files.md` + + ### Testing and QA + - Unit testing without LLM calls (TestAgent, QuickTest, MockToolRegistry, `agentflow test`): `.github/skills/agentflow/references/unit-testing.md` + - Evaluation framework (EvalSet, criteria, AgentEvaluator, QuickEval, UserSimulator, `agentflow eval`): `.github/skills/agentflow/references/evaluation.md` + - Testing helpers overview: `.github/skills/agentflow/references/testing-and-evaluation.md` + +3. Prefer existing Agentflow abstractions over new custom wiring: + - Build workflows with `StateGraph`, `Agent`, `ToolNode`, `AgentState`, and `Message`. + - Use prebuilt agents (`ReactAgent`, `PlanActReflectAgent`, `StructuredOutputAgent`, `SupervisorTeamAgent`, `SwarmAgent`, `RAGAgent`) for common patterns before hand-writing graph loops. + - Persist conversation state with checkpointers; use stores only for cross-thread memory. + - Put business services in `InjectQ` instead of global variables. + - Keep API/CLI graph modules storage-agnostic and wire dependencies through `agentflow.json`. + +4. Verify against source when implementation details matter. Public names and expected behavior should match https://agentflow.10xscale.ai/; source under https://github.com/10xHub/Agentflow (core), https://github.com/10xHub/agentflow-cli (API/CLI), and https://github.com/10xHub/agentflow-client (TypeScript) explains how that behavior is implemented. + +## Local Conventions + +- A compiled graph is normally loaded once by the API server and reused per request. +- Public package naming matters: use `10xscale-agentflow`, `10xscale-agentflow-cli`, and `@10xscale/agentflow-client` in user-facing docs and examples, not repository folder names. +- Every persisted interaction should include `config.thread_id`. +- Tools need docstrings and type annotations so model-facing schemas are useful. 
+- Injectable tool and node parameters (`state`, `config`, `tool_call_id`) are hidden from the model schema. +- For production, avoid process-local storage for shared state; use durable checkpointer/store backends. +- Add observability or audit side effects by registering a `GraphLifecycleHook` on `CallbackManager` — do not wrap `ainvoke()` / `astream()` calls in application code to achieve the same result. +- `reasoning_config` is on by default at medium effort; disable explicitly with `reasoning_config=None` when not needed. +- Provider is auto-detected from the model name; use `base_url` for third-party OpenAI-compatible APIs (Ollama, DeepSeek, OpenRouter). diff --git a/tests/cli/test_cli_main.py b/tests/cli/test_cli_main.py index f15a515..22d748d 100644 --- a/tests/cli/test_cli_main.py +++ b/tests/cli/test_cli_main.py @@ -1,7 +1,8 @@ from typer.testing import CliRunner - +import pytest +from unittest.mock import MagicMock, patch import agentflow_cli.cli.main as main_mod - +from agentflow_cli.cli.exceptions import PyagenityCLIError runner = CliRunner() @@ -27,8 +28,172 @@ def fake_execute(self, **kwargs): assert called["open_playground"] is True +def test_api_command(monkeypatch): + called = {} + monkeypatch.setattr(main_mod, "setup_cli_logging", lambda **kwargs: None) + monkeypatch.setattr( + main_mod.APICommand, + "execute", + lambda self, **kwargs: called.update(kwargs) or 0 + ) + result = runner.invoke(main_mod.app, ["api", "-c", "custom.json", "-H", "1.2.3.4", "-p", "8080", "--no-reload", "--verbose"]) + assert result.exit_code == 0 + assert called["config"] == "custom.json" + assert called["host"] == "1.2.3.4" + assert called["port"] == 8080 + assert called["reload"] is False + + +def test_version_command(monkeypatch): + called = [] + monkeypatch.setattr(main_mod, "setup_cli_logging", lambda **kwargs: None) + monkeypatch.setattr( + main_mod.VersionCommand, + "execute", + lambda self: called.append(True) or 0 + ) + result = runner.invoke(main_mod.app, 
["version"]) + assert result.exit_code == 0 + assert len(called) == 1 + + +def test_init_command(monkeypatch): + called = {} + monkeypatch.setattr(main_mod, "setup_cli_logging", lambda **kwargs: None) + monkeypatch.setattr( + main_mod.InitCommand, + "execute", + lambda self, **kwargs: called.update(kwargs) or 0 + ) + result = runner.invoke(main_mod.app, ["init", "--path", "test_path", "--force"]) + assert result.exit_code == 0 + assert called["path"] == "test_path" + assert called["force"] is True + + +def test_build_command(monkeypatch): + called = {} + monkeypatch.setattr(main_mod, "setup_cli_logging", lambda **kwargs: None) + monkeypatch.setattr( + main_mod.BuildCommand, + "execute", + lambda self, **kwargs: called.update(kwargs) or 0 + ) + result = runner.invoke( + main_mod.app, + ["build", "-o", "Dfile", "--force", "--python-version", "3.12", "-p", "5000", "--docker-compose"] + ) + assert result.exit_code == 0 + assert called["output_file"] == "Dfile" + assert called["force"] is True + assert called["python_version"] == "3.12" + assert called["port"] == 5000 + assert called["docker_compose"] is True + + +def test_skills_command(monkeypatch): + called = {} + monkeypatch.setattr(main_mod, "setup_cli_logging", lambda **kwargs: None) + monkeypatch.setattr( + main_mod.SkillsCommand, + "execute", + lambda self, **kwargs: called.update(kwargs) or 0 + ) + result = runner.invoke( + main_mod.app, + ["skills", "-a", "codex", "-p", "skills_path", "--force", "--all", "--list"] + ) + assert result.exit_code == 0 + assert called["agent"] == "codex" + assert called["path"] == "skills_path" + assert called["force"] is True + assert called["all_agents"] is True + assert called["list_agents"] is True + + +def test_test_command(monkeypatch): + called = {} + monkeypatch.setattr(main_mod, "setup_cli_logging", lambda **kwargs: None) + monkeypatch.setattr( + main_mod.TestCommand, + "execute", + lambda self, **kwargs: called.update(kwargs) or 0 + ) + result = runner.invoke( + 
main_mod.app, + ["test", "tests/foo.py", "--coverage", "--html", "-k", "foo_test", "--", "--lf", "-vv"] + ) + assert result.exit_code == 0 + assert called["path"] == "tests/foo.py" + assert called["coverage"] is True + assert called["html"] is True + assert called["keyword"] == "foo_test" + assert called["extra_args"] == ("--lf", "-vv") + + +def test_eval_command(monkeypatch): + called = {} + monkeypatch.setattr(main_mod, "setup_cli_logging", lambda **kwargs: None) + monkeypatch.setattr( + main_mod.EvalCommand, + "execute", + lambda self, **kwargs: called.update(kwargs) or 0 + ) + result = runner.invoke( + main_mod.app, + ["eval", "target_eval", "-o", "out_dir", "--no-report", "-t", "0.8", "--open", "--parallel", "-c", "8"] + ) + assert result.exit_code == 0 + assert called["target"] == "target_eval" + assert called["output_dir"] == "out_dir" + assert called["no_report"] is True + assert called["threshold"] == 0.8 + assert called["open_report"] is True + assert called["parallel"] is True + assert called["max_concurrency"] == 8 + + def test_a2a_command_is_not_exposed(): result = runner.invoke(main_mod.app, ["a2a"]) assert result.exit_code != 0 assert "No such command 'a2a'" in result.output + + +def test_handle_pyagenity_cli_error(monkeypatch): + monkeypatch.setattr(main_mod, "setup_cli_logging", lambda **kwargs: None) + monkeypatch.setattr( + main_mod.VersionCommand, + "execute", + lambda self: (_ for _ in ()).throw(PyagenityCLIError("Custom error message", exit_code=42)) + ) + result = runner.invoke(main_mod.app, ["version"]) + assert result.exit_code == 42 + + +def test_handle_generic_exception(monkeypatch): + monkeypatch.setattr(main_mod, "setup_cli_logging", lambda **kwargs: None) + monkeypatch.setattr( + main_mod.VersionCommand, + "execute", + lambda self: (_ for _ in ()).throw(ValueError("Some generic value error")) + ) + result = runner.invoke(main_mod.app, ["version"]) + assert result.exit_code == 1 + + +def test_main_keyboard_interrupt(monkeypatch): + 
monkeypatch.setattr(main_mod, "setup_cli_logging", lambda **kwargs: None) + with patch("agentflow_cli.cli.main.app", side_effect=KeyboardInterrupt): + with pytest.raises(SystemExit) as exc_info: + main_mod.main() + assert exc_info.value.code == 130 + + +def test_main_generic_exception(monkeypatch): + monkeypatch.setattr(main_mod, "setup_cli_logging", lambda **kwargs: None) + with patch("agentflow_cli.cli.main.app", side_effect=ValueError("Main error")): + with pytest.raises(SystemExit) as exc_info: + main_mod.main() + assert exc_info.value.code == 1 + diff --git a/tests/cli/test_eval_command.py b/tests/cli/test_eval_command.py new file mode 100644 index 0000000..b45c16e --- /dev/null +++ b/tests/cli/test_eval_command.py @@ -0,0 +1,560 @@ +import pytest +import asyncio +from pathlib import Path +from unittest.mock import MagicMock, patch +import types + +from agentflow_cli.cli.commands.eval import EvalCommand, _PendingCase, _PendingSimulation +from agentflow_cli.cli.core.output import OutputFormatter +from agentflow.qa.evaluation import EvalConfig, CriteriaConfig, CriterionConfig +from agentflow.qa.evaluation.eval_result import EvalCaseResult + +# Disable pytest collection for the imported EvalCommand class +EvalCommand.__test__ = False + + +class _SilentOutput(OutputFormatter): + def __init__(self) -> None: + super().__init__() + self.successes = [] + self.errors = [] + self.infos = [] + self.warnings = [] + + def success(self, message: str, emoji: bool = True) -> None: + self.successes.append(message) + + def error(self, message: str, emoji: bool = True) -> None: + self.errors.append(message) + + def info(self, message: str, emoji: bool = True) -> None: + self.infos.append(message) + + def warning(self, message: str, emoji: bool = True) -> None: + self.warnings.append(message) + + def print_banner(self, *args, **kwargs) -> None: + pass + + +@pytest.fixture +def cmd() -> EvalCommand: + return EvalCommand(output=_SilentOutput()) + + +def 
test_load_agent_from_config_success(cmd): + mock_cm = MagicMock() + mock_cm.auto_discover_config.return_value = "agentflow.json" + mock_cm.get_config_value.return_value = "my_module:my_agent" + + mock_agent = MagicMock() + mock_module = types.SimpleNamespace(my_agent=mock_agent) + + with patch("agentflow_cli.cli.commands.eval.ConfigManager", return_value=mock_cm), \ + patch("importlib.import_module", return_value=mock_module) as mock_import: + + agent = cmd._load_agent_from_config() + assert agent is mock_agent + mock_import.assert_called_once_with("my_module") + + +def test_load_agent_from_config_no_json(cmd): + mock_cm = MagicMock() + mock_cm.auto_discover_config.return_value = None + + with patch("agentflow_cli.cli.commands.eval.ConfigManager", return_value=mock_cm): + with pytest.raises(RuntimeError, match="No agentflow.json found"): + cmd._load_agent_from_config() + + +def test_load_agent_from_config_invalid_spec(cmd): + mock_cm = MagicMock() + mock_cm.auto_discover_config.return_value = "agentflow.json" + mock_cm.get_config_value.return_value = "invalid_spec" + + with patch("agentflow_cli.cli.commands.eval.ConfigManager", return_value=mock_cm): + with pytest.raises(RuntimeError, match="Invalid 'agent' field"): + cmd._load_agent_from_config() + + +def test_print_criteria_block(cmd, capsys): + cfg = EvalConfig( + criteria=CriteriaConfig( + tool_name_match=CriterionConfig.tool_name_match(threshold=1.0), + ) + ) + cmd._print_criteria_block(cfg, Path("some_dir")) + captured = capsys.readouterr() + assert "Criteria source:" in captured.out + assert "tool_name_match" in captured.out + + +def test_print_case_progress_passed(cmd, capsys): + res = EvalCaseResult.success( + eval_id="c1", + name="case1", + criterion_results=[], + actual_response="", + ) + res.duration_seconds = 1.23 + cmd._print_case_progress("file.py", "case1", res, 1, 10) + captured = capsys.readouterr() + assert "file.py::case1" in captured.out + assert "PASSED" in captured.out + + +def 
test_resolve_eval_dir(cmd): + mock_cm = MagicMock() + mock_cm.auto_discover_config.return_value = "agentflow.json" + mock_cm.get_evaluation_config.return_value = {"directory": "custom_evals"} + + with patch("agentflow_cli.cli.commands.eval.ConfigManager", return_value=mock_cm): + eval_dir = cmd._resolve_eval_dir() + assert eval_dir == Path.cwd() / "custom_evals" + + +def test_collect_simulations(cmd): + mod = types.SimpleNamespace(app=MagicMock(), SIMULATOR_CONFIG=MagicMock()) + scenarios = [MagicMock(scenario_id="s1", description="desc1")] + + with patch("agentflow_cli.cli.commands.eval.ConfigManager"): + res = cmd._collect_simulations(mod, scenarios, "file_eval.py") + assert len(res) == 1 + assert isinstance(res[0], _PendingSimulation) + assert res[0].eval_set_id == "file_eval_simulations" + + +@pytest.mark.asyncio +async def test_run_flat_pool_cases(cmd): + case1 = MagicMock() + case1.eval_id = "c1" + case1.name = "case_name" + evaluator = MagicMock() + + async def mock_evaluate_case(case, collector_override): + return EvalCaseResult.success( + eval_id="c1", + name="case_name", + criterion_results=[], + actual_response="resp", + ) + + evaluator._evaluate_case = mock_evaluate_case + evaluator.collector.capture_all_events = True + + pc = _PendingCase( + case=case1, + evaluator=evaluator, + config=EvalConfig(), + file_name="test_file.py", + eval_set_id="es1", + eval_set_name="Eval Set 1", + ) + + with patch("agentflow_cli.cli.commands.eval._reset_inject_proxy"), \ + patch("agentflow_cli.cli.commands.eval.override_dependency"): + + results = await cmd._run_flat_pool([pc], max_concurrency=4, parallel=False) + assert len(results) == 1 + assert results[0][0] == "test_file.py" + assert results[0][1] == "es1" + assert results[0][3].passed is True + + +@pytest.mark.asyncio +async def test_run_flat_pool_case_error(cmd): + case1 = MagicMock() + case1.eval_id = "c1" + case1.name = "case_name" + evaluator = MagicMock() + + async def mock_evaluate_case(case, 
collector_override): + raise ValueError("evaluation failed") + + evaluator._evaluate_case = mock_evaluate_case + evaluator.collector.capture_all_events = True + + pc = _PendingCase( + case=case1, + evaluator=evaluator, + config=EvalConfig(), + file_name="test_file.py", + eval_set_id="es1", + eval_set_name="Eval Set 1", + ) + + with patch("agentflow_cli.cli.commands.eval._reset_inject_proxy"), \ + patch("agentflow_cli.cli.commands.eval.override_dependency"): + + results = await cmd._run_flat_pool([pc], max_concurrency=4, parallel=False) + assert len(results) == 1 + assert results[0][3].passed is False + assert "evaluation failed" in results[0][3].error + + +@pytest.mark.asyncio +async def test_run_flat_pool_simulation(cmd): + from agentflow.qa.evaluation.token_usage import TokenUsage + + # Mock simulator + simulator = MagicMock() + mock_criterion = MagicMock() + mock_criterion.threshold = 0.7 + simulator.criteria = [mock_criterion] + + sim_result = MagicMock() + sim_result.completed = True + sim_result.criterion_results = [] + sim_result.criterion_scores = {"g1": 1.0} + sim_result.criterion_details = {} + sim_result.conversation = [{"role": "user", "content": "hello"}] + sim_result.simulator_token_usage = TokenUsage(input_tokens=10, output_tokens=5) + sim_result.turns = 2 + sim_result.goals_achieved = 1 + + async def mock_simulator_run(graph, scenario): + return sim_result + + simulator.run = mock_simulator_run + + ps = _PendingSimulation( + scenario=MagicMock(scenario_id="sc1", description="sc_desc"), + graph=MagicMock(), + simulator=simulator, + file_name="test_sim.py", + eval_set_id="es_sim", + eval_set_name="Sim Set", + ) + + results = await cmd._run_flat_pool([ps], max_concurrency=4, parallel=False) + assert len(results) == 1 + assert results[0][3].passed is True + assert "USER: hello" in results[0][3].actual_response + + +@pytest.mark.asyncio +async def test_run_flat_pool_simulation_error(cmd): + simulator = MagicMock() + mock_criterion = MagicMock() + 
mock_criterion.threshold = 0.7 + simulator.criteria = [mock_criterion] + + async def mock_simulator_run(graph, scenario): + raise ValueError("simulation failed") + simulator.run = mock_simulator_run + + ps = _PendingSimulation( + scenario=MagicMock(scenario_id="sc1", description="sc_desc"), + graph=MagicMock(), + simulator=simulator, + file_name="test_sim.py", + eval_set_id="es_sim", + eval_set_name="Sim Set", + ) + + results = await cmd._run_flat_pool([ps], max_concurrency=4, parallel=False) + assert len(results) == 1 + assert results[0][3].passed is False + assert "simulation failed" in results[0][3].error + + +def test_execute_target_not_found(cmd): + with patch("agentflow_cli.cli.commands.eval.ConfigManager"): + code = cmd.execute(target="non_existent_path") + assert code == 1 + assert len(cmd.output.errors) > 0 + + +def test_execute_no_files(cmd, tmp_path): + with patch("agentflow_cli.cli.commands.eval.ConfigManager"), \ + patch.object(cmd, "_discover", return_value=[]): + code = cmd.execute(target=str(tmp_path)) + assert code == 1 + assert len(cmd.output.errors) > 0 + + +def test_execute_success(cmd, tmp_path): + from agentflow.qa.evaluation.token_usage import TokenUsage + + # Mock ConfigManager and its returns + mock_cm = MagicMock() + mock_cm.auto_discover_config.return_value = "agentflow.json" + mock_cm.get_evaluation_config.return_value = {} + + # Mock discovery and collection + fake_case = MagicMock() + fake_case.config = EvalConfig() + + mock_report = MagicMock() + mock_report.summary.pass_rate = 1.0 + mock_report.summary.passed_cases = 1 + mock_report.summary.total_cases = 1 + mock_report.summary.total_token_usage = TokenUsage(input_tokens=100, output_tokens=50) + + mock_rep_mgr_res = MagicMock() + mock_rep_mgr_res.html_path = "/absolute/path/to/report.html" + mock_rep_mgr_res.json_path = "path/to/report.json" + mock_rep_mgr_res.has_errors = False + + mock_rep_mgr = MagicMock() + mock_rep_mgr.run_all.return_value = mock_rep_mgr_res + + # We run flat 
pool which returns quads + mock_case_result = EvalCaseResult.success( + eval_id="case_id", + name="case_name", + criterion_results=[], + actual_response="resp", + token_usage=TokenUsage(input_tokens=10, output_tokens=5), + ) + quads = [("test_eval.py", "eval_set_id", "eval_set_name", mock_case_result)] + + with patch("agentflow_cli.cli.commands.eval.ConfigManager", return_value=mock_cm), \ + patch.object(cmd, "_discover", return_value=[Path("test_eval.py")]), \ + patch.object(cmd, "_load_confeval", return_value=None), \ + patch.object(cmd, "_collect_from_file", return_value=[fake_case]), \ + patch.object(cmd, "_print_criteria_block"), \ + patch.object(cmd, "_run_flat_pool", return_value=quads), \ + patch.object(cmd, "_merge_reports", return_value=mock_report), \ + patch("agentflow_cli.cli.commands.eval.ReporterManager", return_value=mock_rep_mgr), \ + patch("webbrowser.open") as mock_web_open: + + code = cmd.execute(target=str(tmp_path), open_report=True) + assert code == 0 + mock_web_open.assert_called_once() + assert len(cmd.output.successes) > 0 + + +def test_execute_below_threshold(cmd, tmp_path): + from agentflow.qa.evaluation.token_usage import TokenUsage + + # Mock ConfigManager and its returns + mock_cm = MagicMock() + mock_cm.auto_discover_config.return_value = "agentflow.json" + mock_cm.get_evaluation_config.return_value = {} + + fake_case = MagicMock() + fake_case.config = EvalConfig() + + mock_report = MagicMock() + mock_report.summary.pass_rate = 0.5 + mock_report.summary.passed_cases = 1 + mock_report.summary.total_cases = 2 + mock_report.summary.total_token_usage = TokenUsage(input_tokens=100, output_tokens=50) + + mock_rep_mgr = MagicMock() + mock_rep_mgr.run_all.return_value = MagicMock(html_path=None, json_path=None, has_errors=False) + + mock_case_result = EvalCaseResult.failure( + eval_id="case_id", + name="case_name", + error="failed", + ) + quads = [("test_eval.py", "eval_set_id", "eval_set_name", mock_case_result)] + + with 
patch("agentflow_cli.cli.commands.eval.ConfigManager", return_value=mock_cm), \ + patch.object(cmd, "_discover", return_value=[Path("test_eval.py")]), \ + patch.object(cmd, "_load_confeval", return_value=None), \ + patch.object(cmd, "_collect_from_file", return_value=[fake_case]), \ + patch.object(cmd, "_print_criteria_block"), \ + patch.object(cmd, "_run_flat_pool", return_value=quads), \ + patch.object(cmd, "_merge_reports", return_value=mock_report), \ + patch("agentflow_cli.cli.commands.eval.ReporterManager", return_value=mock_rep_mgr): + + code = cmd.execute(target=str(tmp_path), threshold=0.8) + assert code == 1 + assert "below threshold" in cmd.output.errors[0] + + +def test_execute_collect_error(cmd, tmp_path): + with patch("agentflow_cli.cli.commands.eval.ConfigManager"), \ + patch.object(cmd, "_discover", return_value=[Path("test_eval.py")]), \ + patch.object(cmd, "_collect_from_file", side_effect=ValueError("load error")): + + code = cmd.execute(target=str(tmp_path)) + assert code == 1 + assert "Error loading test_eval.py: load error" in cmd.output.errors[0] + + +def test_reset_inject_proxy(): + from agentflow_cli.cli.commands.eval import _reset_inject_proxy + _reset_inject_proxy(None) + + +def test_load_module(cmd, tmp_path): + f = tmp_path / "dummy_mod.py" + f.write_text("x = 42") + mod = cmd._load_module(f) + assert mod.x == 42 + + +def test_load_confeval_func_error(cmd, tmp_path): + confeval = tmp_path / "confeval.py" + confeval.write_text("") + def raise_err(): + raise ValueError("config error") + fake_mod = types.SimpleNamespace(get_eval_config=raise_err, EVAL_CONFIG="some_config") + + with patch.object(cmd, "_load_module", return_value=fake_mod): + result = cmd._load_confeval(tmp_path) + assert result == "some_config" + + +def test_collect_eval_functions(cmd): + from agentflow.qa.evaluation import EvalConfig, EvalSet + + class FakeEvalSet(EvalSet): + def __init__(self): + super().__init__(eval_set_id="es1", name="Set 1", eval_cases=[]) + + class 
FakeEvalConfig(EvalConfig): + pass + + def func_eval() -> FakeEvalSet: + return FakeEvalSet() + + def func_config() -> FakeEvalConfig: + return FakeEvalConfig() + + fake_mod = types.SimpleNamespace(__name__="fake_mod") + func_eval.__module__ = "fake_mod" + func_config.__module__ = "fake_mod" + fake_mod.func_eval = func_eval + fake_mod.func_config = func_config + + eval_pairs, config = cmd._collect_eval_functions(fake_mod) + assert len(eval_pairs) == 1 + assert eval_pairs[0][0] == "func_eval" + assert isinstance(config, FakeEvalConfig) + + +def test_collect_from_file_scenarios_error(cmd, tmp_path): + p = tmp_path / "x_eval.py" + p.write_text("") + def raise_err(): + raise ValueError("scenarios error") + fake_mod = types.SimpleNamespace(get_scenarios=raise_err) + with patch.object(cmd, "_load_module", return_value=fake_mod): + res = cmd._collect_from_file(p, None) + assert res == [] + + +def test_make_pending_loads_agent_from_config(cmd): + fake_mod = types.SimpleNamespace() + eval_set = MagicMock() + eval_set.eval_cases = [MagicMock()] + config = EvalConfig() + + mock_agent = MagicMock() + with patch.object(cmd, "_load_agent_from_config", return_value=mock_agent): + res = cmd._make_pending(fake_mod, eval_set, config, "file.py") + assert len(res) == 1 + assert res[0].evaluator.graph is mock_agent + + +def test_print_criteria_block_custom(cmd, capsys): + cfg_match = CriterionConfig.tool_name_match(threshold=1.0) + cfg_match.num_samples = 3 + cfg_match.judge_model = "gpt-4" + cfg = EvalConfig( + criteria=CriteriaConfig( + tool_name_match=cfg_match, + ) + ) + cmd._print_criteria_block(cfg, Path("some_dir")) + captured = capsys.readouterr() + assert "samples=3" in captured.out + assert "judge=gpt-4" in captured.out + + +@pytest.mark.asyncio +async def test_run_flat_pool_simulation_no_criterion(cmd): + from agentflow.qa.evaluation.token_usage import TokenUsage + + simulator = MagicMock() + simulator.criteria = [] + + sim_result = MagicMock() + sim_result.completed = True 
+ sim_result.criterion_results = None + sim_result.criterion_scores = {} + sim_result.conversation = [{"role": "user", "content": "hello"}] + sim_result.simulator_token_usage = TokenUsage(input_tokens=10, output_tokens=5) + sim_result.turns = 2 + sim_result.goals_achieved = 1 + + async def mock_simulator_run(graph, scenario): + return sim_result + + simulator.run = mock_simulator_run + + ps = _PendingSimulation( + scenario=MagicMock(scenario_id="sc1", description="sc_desc"), + graph=MagicMock(), + simulator=simulator, + file_name="test_sim.py", + eval_set_id="es_sim", + eval_set_name="Sim Set", + ) + + results = await cmd._run_flat_pool([ps], max_concurrency=4, parallel=False) + assert len(results) == 1 + assert len(results[0][3].criterion_results) == 1 + assert results[0][3].criterion_results[0].criterion == "simulation_completed" + + +@pytest.mark.asyncio +async def test_run_flat_pool_parallel(cmd): + case1 = MagicMock() + case1.eval_id = "c1" + case1.name = "case_name" + evaluator = MagicMock() + + async def mock_evaluate_case(case, collector_override): + return EvalCaseResult.success( + eval_id="c1", + name="case_name", + criterion_results=[], + actual_response="resp", + ) + + evaluator._evaluate_case = mock_evaluate_case + evaluator.collector.capture_all_events = True + + pc = _PendingCase( + case=case1, + evaluator=evaluator, + config=EvalConfig(), + file_name="test_file.py", + eval_set_id="es1", + eval_set_name="Eval Set 1", + ) + + with patch("agentflow_cli.cli.commands.eval._reset_inject_proxy"), \ + patch("agentflow_cli.cli.commands.eval.override_dependency"): + + results = await cmd._run_flat_pool([pc], max_concurrency=2, parallel=True) + assert len(results) == 1 + assert results[0][3].passed is True + + +def test_resolve_eval_dir_error(cmd): + mock_cm = MagicMock() + mock_cm.auto_discover_config.return_value = "agentflow.json" + mock_cm.get_evaluation_config.side_effect = ValueError("config load error") + + with 
patch("agentflow_cli.cli.commands.eval.ConfigManager", return_value=mock_cm): + eval_dir = cmd._resolve_eval_dir() + assert eval_dir == Path.cwd() / "evals" + + +def test_execute_load_config_error(cmd, tmp_path): + mock_cm = MagicMock() + mock_cm.auto_discover_config.return_value = "agentflow.json" + mock_cm.load_config.side_effect = ValueError("corrupt json") + + with patch("agentflow_cli.cli.commands.eval.ConfigManager", return_value=mock_cm), \ + patch.object(cmd, "_discover", return_value=[]): + code = cmd.execute(target=str(tmp_path)) + assert code == 1 + diff --git a/tests/cli/test_eval_discovery.py b/tests/cli/test_eval_discovery.py index 6fc6820..a92ae52 100644 --- a/tests/cli/test_eval_discovery.py +++ b/tests/cli/test_eval_discovery.py @@ -173,7 +173,7 @@ def _fake_eval_set(self) -> MagicMock: es.cases = [MagicMock()] return es - def test_global_config_takes_priority_over_file_config( + def test_file_config_takes_priority_over_global_config( self, tmp_path: Path, cmd: EvalCommand ) -> None: from agentflow.qa.evaluation import EvalConfig @@ -194,7 +194,8 @@ def test_global_config_takes_priority_over_file_config( result = cmd._collect_from_file(self._dummy_path(tmp_path), global_cfg) _, _, used_config, _ = mock_make.call_args.args - assert used_config is global_cfg + assert used_config is file_cfg + assert used_config is not global_cfg assert result == fake_pending def test_file_config_used_when_no_global(self, tmp_path: Path, cmd: EvalCommand) -> None: diff --git a/tests/cli/test_eval_flat_pool.py b/tests/cli/test_eval_flat_pool.py index 504a3a2..e442042 100644 --- a/tests/cli/test_eval_flat_pool.py +++ b/tests/cli/test_eval_flat_pool.py @@ -103,7 +103,7 @@ def _dummy_path(self, tmp_path: Path) -> Path: p.write_text("") return p - def test_confeval_config_beats_per_file_eval_config( + def test_per_file_config_beats_confeval_config( self, tmp_path: Path, cmd: EvalCommand ) -> None: from agentflow.qa.evaluation import EvalConfig @@ -124,8 +124,8 @@ def 
test_confeval_config_beats_per_file_eval_config( cmd._collect_from_file(self._dummy_path(tmp_path), global_cfg) _, _, used_config, _ = mock_make.call_args.args - assert used_config is global_cfg - assert used_config is not per_file_cfg + assert used_config is per_file_cfg + assert used_config is not global_cfg def test_per_file_config_used_when_no_confeval( self, tmp_path: Path, cmd: EvalCommand diff --git a/tests/cli/test_test_command.py b/tests/cli/test_test_command.py new file mode 100644 index 0000000..0b594d8 --- /dev/null +++ b/tests/cli/test_test_command.py @@ -0,0 +1,105 @@ +"""Unit tests for TestCommand.""" + +from pathlib import Path +from unittest.mock import MagicMock, patch +import pytest + +from agentflow_cli.cli.commands.test import TestCommand +from agentflow_cli.cli.core.output import OutputFormatter + +# Disable pytest collection for the imported TestCommand class +TestCommand.__test__ = False + + +class _SilentOutput(OutputFormatter): + def __init__(self) -> None: + super().__init__() + self.successes = [] + self.errors = [] + self.infos = [] + + def success(self, message: str, emoji: bool = True) -> None: + self.successes.append(message) + + def error(self, message: str, emoji: bool = True) -> None: + self.errors.append(message) + + def info(self, message: str, emoji: bool = True) -> None: + self.infos.append(message) + + def print_banner(self, *args, **kwargs) -> None: + pass + + +@pytest.fixture +def cmd() -> TestCommand: + return TestCommand(output=_SilentOutput()) + + +def test_execute_simple_success(cmd): + mock_run_res = MagicMock() + mock_run_res.returncode = 0 + + with patch("subprocess.run", return_value=mock_run_res) as mock_run, \ + patch("agentflow_cli.cli.commands.test.ConfigManager.auto_discover_config", return_value=None): + + code = cmd.execute(path="tests/unit_tests") + assert code == 0 + mock_run.assert_called_once() + args = mock_run.call_args[0][0] + assert "pytest" in args + assert "tests/unit_tests" in args + + +def 
test_execute_failure(cmd): + mock_run_res = MagicMock() + mock_run_res.returncode = 1 + + with patch("subprocess.run", return_value=mock_run_res) as mock_run, \ + patch("agentflow_cli.cli.commands.test.ConfigManager.auto_discover_config", return_value=None): + + code = cmd.execute() + assert code == 1 + assert len(cmd.output.errors) > 0 + + +def test_execute_with_config_overrides(cmd): + mock_run_res = MagicMock() + mock_run_res.returncode = 0 + + # Mock ConfigManager to return test config + mock_cm = MagicMock() + mock_cm.auto_discover_config.return_value = "agentflow.json" + mock_cm.get_test_config.return_value = { + "path": "custom_tests", + "coverage": True, + "coverage_threshold": 90, + } + + with patch("subprocess.run", return_value=mock_run_res) as mock_run, \ + patch("agentflow_cli.cli.commands.test.ConfigManager", return_value=mock_cm), \ + patch("webbrowser.open") as mock_web_open: + + code = cmd.execute(coverage=False, html=True) # html=True requires coverage config override + assert code == 0 + args = mock_run.call_args[0][0] + assert "custom_tests" in args + assert "--cov=." 
in args + assert "--cov-fail-under=90" in args + mock_web_open.assert_called_once() + + +def test_execute_quiet_and_extra_args(cmd): + mock_run_res = MagicMock() + mock_run_res.returncode = 0 + + with patch("subprocess.run", return_value=mock_run_res) as mock_run, \ + patch("agentflow_cli.cli.commands.test.ConfigManager.auto_discover_config", return_value=None): + + cmd.execute(quiet=True, keyword="my_test", extra_args=("-x", "--lf")) + args = mock_run.call_args[0][0] + assert "-q" in args + assert "-k" in args + assert "my_test" in args + assert "-x" in args + assert "--lf" in args diff --git a/tests/unit_tests/test_graph_router.py b/tests/unit_tests/test_graph_router.py new file mode 100644 index 0000000..508001f --- /dev/null +++ b/tests/unit_tests/test_graph_router.py @@ -0,0 +1,162 @@ +"""Unit tests for the Graph API router endpoints.""" + +import pytest +from unittest.mock import AsyncMock, MagicMock, patch +from fastapi import Request + +from agentflow_cli.src.app.routers.graph.router import ( + invoke_graph, + stream_graph, + graph_details, + state_schema, + stop_graph, + setup_graph, + fix_graph, +) +from agentflow_cli.src.app.routers.graph.schemas.graph_schemas import ( + GraphInputSchema, + GraphStopSchema, + GraphSetupSchema, + FixGraphRequestSchema, +) + + +@pytest.fixture +def mock_request(): + request = MagicMock(spec=Request) + request.state = MagicMock() + request.state.request_id = "test-request-id" + request.state.timestamp = "2024-01-01T00:00:00Z" + return request + + +@pytest.fixture +def mock_service(): + service = AsyncMock() + # stream_graph returns an async generator/iterable, not a direct awaitable coroutine + # but we can mock it as returning an async iterable or mock object + service.stream_graph = MagicMock() + return service + + +@pytest.fixture +def mock_user(): + return {"user_id": "user-123", "role": "admin"} + + +@pytest.mark.asyncio +async def test_invoke_graph_endpoint(mock_request, mock_service, mock_user): + graph_input = 
GraphInputSchema(messages=[{"role": "user", "content": [{"type": "text", "text": "hi"}]}]) + mock_service.invoke_graph.return_value = {"messages": []} + + with patch("agentflow_cli.src.app.routers.graph.router.success_response") as mock_success: + mock_success.return_value = {"status": "success"} + res = await invoke_graph( + request=mock_request, + graph_input=graph_input, + service=mock_service, + user=mock_user, + ) + assert res == {"status": "success"} + mock_service.invoke_graph.assert_called_once_with(graph_input, mock_user) + mock_success.assert_called_once() + + +@pytest.mark.asyncio +async def test_stream_graph_endpoint(mock_service, mock_user): + graph_input = GraphInputSchema(messages=[{"role": "user", "content": [{"type": "text", "text": "hi"}]}]) + + mock_stream = MagicMock() + mock_service.stream_graph.return_value = mock_stream + + with patch("agentflow_cli.src.app.routers.graph.router.StreamingResponse") as mock_streaming_response: + mock_streaming_response.return_value = "streaming_response_obj" + res = await stream_graph( + graph_input=graph_input, + service=mock_service, + user=mock_user, + ) + assert res == "streaming_response_obj" + mock_service.stream_graph.assert_called_once_with(graph_input, mock_user) + mock_streaming_response.assert_called_once() + + +@pytest.mark.asyncio +async def test_graph_details_endpoint(mock_request, mock_service, mock_user): + mock_service.graph_details.return_value = {"info": {}} + + with patch("agentflow_cli.src.app.routers.graph.router.success_response") as mock_success: + mock_success.return_value = {"status": "success"} + res = await graph_details( + request=mock_request, + service=mock_service, + user=mock_user, + ) + assert res == {"status": "success"} + mock_service.graph_details.assert_called_once() + + +@pytest.mark.asyncio +async def test_state_schema_endpoint(mock_request, mock_service, mock_user): + mock_service.get_state_schema.return_value = {"schema": {}} + + with 
patch("agentflow_cli.src.app.routers.graph.router.success_response") as mock_success: + mock_success.return_value = {"status": "success"} + res = await state_schema( + request=mock_request, + service=mock_service, + user=mock_user, + ) + assert res == {"status": "success"} + mock_service.get_state_schema.assert_called_once() + + +@pytest.mark.asyncio +async def test_stop_graph_endpoint(mock_request, mock_service, mock_user): + stop_req = GraphStopSchema(thread_id="thread-abc", config={"force": True}) + mock_service.stop_graph.return_value = {"status": "stopped"} + + with patch("agentflow_cli.src.app.routers.graph.router.success_response") as mock_success: + mock_success.return_value = {"status": "success"} + res = await stop_graph( + request=mock_request, + stop_request=stop_req, + service=mock_service, + user=mock_user, + ) + assert res == {"status": "success"} + mock_service.stop_graph.assert_called_once_with("thread-abc", mock_user, {"force": True}) + + +@pytest.mark.asyncio +async def test_setup_graph_endpoint(mock_request, mock_service, mock_user): + setup_req = GraphSetupSchema(tools=[]) + mock_service.setup.return_value = {"status": "configured"} + + with patch("agentflow_cli.src.app.routers.graph.router.success_response") as mock_success: + mock_success.return_value = {"status": "success"} + res = await setup_graph( + request=mock_request, + setup_request=setup_req, + service=mock_service, + user=mock_user, + ) + assert res == {"status": "success"} + mock_service.setup.assert_called_once_with(setup_req) + + +@pytest.mark.asyncio +async def test_fix_graph_endpoint(mock_request, mock_service, mock_user): + fix_req = FixGraphRequestSchema(thread_id="thread-abc", config={"clean": True}) + mock_service.fix_graph.return_value = {"status": "fixed"} + + with patch("agentflow_cli.src.app.routers.graph.router.success_response") as mock_success: + mock_success.return_value = {"status": "success"} + res = await fix_graph( + request=mock_request, + fix_request=fix_req, 
+ service=mock_service, + user=mock_user, + ) + assert res == {"status": "success"} + mock_service.fix_graph.assert_called_once_with("thread-abc", mock_user, {"clean": True}) diff --git a/tests/unit_tests/test_graph_service.py b/tests/unit_tests/test_graph_service.py new file mode 100644 index 0000000..a18515a --- /dev/null +++ b/tests/unit_tests/test_graph_service.py @@ -0,0 +1,335 @@ +"""Unit tests for GraphService.""" + +import json +import pytest +from unittest.mock import AsyncMock, MagicMock, patch +from fastapi import HTTPException +from pydantic import BaseModel + +from agentflow.core.exceptions.media_exceptions import UnsupportedMediaInputError +from agentflow.core.state import AgentState, Message, StreamChunk, StreamEvent +from agentflow.storage.checkpointer import BaseCheckpointer +from agentflow_cli.src.app.routers.graph.services.graph_service import GraphService +from agentflow_cli.src.app.routers.graph.schemas.graph_schemas import ( + GraphInputSchema, + GraphSetupSchema, +) +from agentflow_cli.src.app.utils.thread_name_generator import ThreadNameGenerator +from agentflow_cli.src.app.core.config.graph_config import GraphConfig + + +class MockStateModel(BaseModel): + context: list = [] + context_summary: str = "" + + +class TestGraphServiceMethods: + """Test cases for GraphService methods.""" + + @pytest.fixture + def mock_graph(self): + graph = MagicMock() + graph.ainvoke = AsyncMock() + graph.astream = MagicMock() # Will be configured per test + graph.astop = AsyncMock() + graph.generate_graph = MagicMock() + + # for get_state_schema + class FakeState(BaseModel): + a: int + graph._state = FakeState + return graph + + @pytest.fixture + def mock_checkpointer(self): + checkpointer = MagicMock(spec=BaseCheckpointer) + checkpointer.aput_thread = AsyncMock(return_value=True) + return checkpointer + + @pytest.fixture + def mock_config(self): + config = MagicMock(spec=GraphConfig) + config.thread_name_generator_path = None + return config + + @pytest.fixture 
+ def mock_thread_name_generator(self): + generator = MagicMock(spec=ThreadNameGenerator) + generator.generate_name = AsyncMock(return_value="custom_thread_name") + return generator + + @pytest.fixture + def service(self, mock_graph, mock_checkpointer, mock_config, mock_thread_name_generator): + srv = GraphService.__new__(GraphService) + srv._graph = mock_graph + srv.checkpointer = mock_checkpointer + srv.config = mock_config + srv.thread_name_generator = mock_thread_name_generator + srv._media_service = None + return srv + + def test_media_service_property(self, service): + # By default InjectQ is not active or try_get fails, so media_service is None + assert service.media_service is None + + # Mock InjectQ + mock_container = MagicMock() + mock_media = MagicMock() + mock_container.try_get.return_value = mock_media + with patch("injectq.InjectQ.get_instance", return_value=mock_container): + # reset property cache + service._media_service = None + assert service.media_service == mock_media + + @pytest.mark.asyncio + async def test_save_thread_name(self, service, mock_checkpointer, mock_thread_name_generator): + # Generator configured + name = await service._save_thread_name({"thread_id": "1"}, 1, ["msg"]) + assert name == "custom_thread_name" + mock_thread_name_generator.generate_name.assert_called_once_with(["msg"]) + mock_checkpointer.aput_thread.assert_called_once() + + # No generator configured + service.thread_name_generator = None + mock_checkpointer.aput_thread.reset_mock() + name = await service._save_thread_name({"thread_id": "1"}, 1, ["msg"]) + assert isinstance(name, str) + assert len(name) > 0 + mock_checkpointer.aput_thread.assert_not_called() + + @pytest.mark.asyncio + async def test_stop_graph_success(self, service, mock_graph): + mock_graph.astop.return_value = {"status": "stopped"} + user = {"user_id": "123"} + result = await service.stop_graph("thread-123", user, {"extra": "val"}) + assert result == {"status": "stopped"} + 
mock_graph.astop.assert_called_once_with({ + "thread_id": "thread-123", + "user": user, + "extra": "val" + }) + + @pytest.mark.asyncio + async def test_stop_graph_validation_error(self, service, mock_graph): + mock_graph.astop.side_effect = ValueError("invalid input") + with pytest.raises(HTTPException) as exc: + await service.stop_graph("t", {}) + assert exc.value.status_code == 422 + assert "invalid input" in exc.value.detail + + @pytest.mark.asyncio + async def test_stop_graph_general_error(self, service, mock_graph): + mock_graph.astop.side_effect = Exception("db crash") + with pytest.raises(HTTPException) as exc: + await service.stop_graph("t", {}) + assert exc.value.status_code == 500 + assert "db crash" in exc.value.detail + + @pytest.mark.asyncio + async def test_prepare_input(self, service): + gi = GraphInputSchema( + messages=[{"role": "user", "content": [{"type": "text", "text": "hi"}]}], + recursion_limit=10 + ) + + # Test with thread_id set + gi.config = {"thread_id": "t1"} + input_data, config, meta = await service._prepare_input(gi) + assert config["thread_id"] == "t1" + assert config["recursion_limit"] == 10 + assert meta["thread_id"] == "t1" + assert meta["is_new_thread"] is False + assert len(input_data["messages"]) == 1 + + # Test with thread_id empty (generates one) + gi.config = {} + input_data, config, meta = await service._prepare_input(gi) + assert "thread_id" in config + assert meta["is_new_thread"] is True + + @pytest.mark.asyncio + async def test_invoke_graph_success(self, service, mock_graph, mock_config, mock_checkpointer): + gi = GraphInputSchema( + messages=[{"role": "user", "content": [{"type": "text", "text": "hi"}]}], + recursion_limit=10, + config={"thread_id": "t1"} + ) + user = {"user_id": "u1"} + + mock_state = MagicMock(spec=AgentState) + mock_state.model_dump.return_value = {"key": "val"} + mock_msg = MagicMock(spec=Message) + mock_msg.text.return_value = "msg_text" + + mock_graph.ainvoke.return_value = { + "messages": 
[mock_msg], + "state": mock_state, + "context": [mock_msg], + "context_summary": "summary" + } + + # Mock thread_name_generator_path to trigger save_thread_name + mock_config.thread_name_generator_path = "some_path" + + # Since _save_thread returns True, it's considered a new thread + mock_checkpointer.aput_thread.return_value = True + + result = await service.invoke_graph(gi, user) + assert result.messages == [mock_msg] + assert result.summary == "summary" + assert result.meta["thread_name"] == "custom_thread_name" + + @pytest.mark.asyncio + async def test_invoke_graph_errors(self, service, mock_graph): + gi = GraphInputSchema(messages=[{"role": "user", "content": [{"type": "text", "text": "hi"}]}]) + + # UnsupportedMediaInputError + mock_graph.ainvoke.side_effect = UnsupportedMediaInputError("provider", "model", "media_type", "source_kind") + with pytest.raises(HTTPException) as exc: + await service.invoke_graph(gi, {}) + assert exc.value.status_code == 422 + + # ValueError + mock_graph.ainvoke.side_effect = ValueError("bad format") + with pytest.raises(HTTPException) as exc: + await service.invoke_graph(gi, {}) + assert exc.value.status_code == 422 + + # Exception + mock_graph.ainvoke.side_effect = Exception("failed") + with pytest.raises(HTTPException) as exc: + await service.invoke_graph(gi, {}) + assert exc.value.status_code == 500 + + @pytest.mark.asyncio + async def test_stream_graph_success(self, service, mock_graph, mock_config): + gi = GraphInputSchema( + messages=[{"role": "user", "content": [{"type": "text", "text": "hi"}]}], + config={"thread_id": "t1"} + ) + + chunk = StreamChunk(event=StreamEvent.MESSAGE, data={"chunk": "x"}) + + # Mock generator to yield chunk + async def mock_stream(*args, **kwargs): + yield chunk + + mock_graph.astream = mock_stream + + mock_config.thread_name_generator_path = "some_path" + + chunks = [] + async for c in service.stream_graph(gi, {}): + chunks.append(c) + + assert len(chunks) == 2 # message chunk + final completed 
status chunk (due to thread name generator path branch) + + data0 = json.loads(chunks[0]) + assert data0["event"] == "message" + + data1 = json.loads(chunks[1]) + assert data1["event"] == "updates" + assert data1["data"]["status"] == "completed" + + @pytest.mark.asyncio + async def test_stream_graph_exception_handling(self, service, mock_graph): + gi = GraphInputSchema( + messages=[{"role": "user", "content": [{"type": "text", "text": "hi"}]}], + config={"thread_id": "t1"} + ) + + async def mock_stream_error(*args, **kwargs): + raise Exception("stream crash") + yield # make it a generator + + mock_graph.astream = mock_stream_error + + chunks = [] + async for c in service.stream_graph(gi, {}): + chunks.append(c) + + assert len(chunks) == 1 + data = json.loads(chunks[0]) + assert data["event"] == "error" + assert "stream crash" in data["data"]["reason"] + + @pytest.mark.asyncio + async def test_graph_details_success_and_errors(self, service, mock_graph): + mock_graph.generate_graph.return_value = { + "info": { + "node_count": 2, "edge_count": 1, "checkpointer": False, + "checkpointer_type": None, "publisher": False, "store": False, + "interrupt_before": None, "interrupt_after": None + }, + "nodes": [], "edges": [] + } + res = await service.graph_details() + assert res.info.node_count == 2 + + # ValueError + mock_graph.generate_graph.side_effect = ValueError("invalid format") + with pytest.raises(HTTPException) as exc: + await service.graph_details() + assert exc.value.status_code == 422 + + # Exception + mock_graph.generate_graph.side_effect = Exception("failed") + with pytest.raises(HTTPException) as exc: + await service.graph_details() + assert exc.value.status_code == 500 + + @pytest.mark.asyncio + async def test_get_state_schema_success_and_errors(self, service, mock_graph): + res = await service.get_state_schema() + assert "properties" in res + + # Exception + del mock_graph._state + with pytest.raises(HTTPException) as exc: + await service.get_state_schema() + 
assert exc.value.status_code == 500 + + @pytest.mark.asyncio + async def test_setup(self, service, mock_graph): + # Mock GraphSetupSchema data + class MockTool: + node_name = "n1" + name = "t1" + description = "desc" + parameters = {} + + class MockSetupData: + tools = [MockTool()] + + mock_graph.attach_remote_tools = MagicMock() + res = await service.setup(MockSetupData()) + assert res["status"] == "success" + mock_graph.attach_remote_tools.assert_called_once_with([ + { + "type": "function", + "function": { + "name": "t1", + "description": "desc", + "parameters": {} + } + } + ], "n1") + + def test_extract_context_info(self, service): + # Case 1: Result has values + c, s = service._extract_context_info(None, {"context": ["msg"], "context_summary": "sum"}) + assert c == ["msg"] + assert s == "sum" + + # Case 2: Result doesn't have values, reads from state dict + c, s = service._extract_context_info({"context": ["msg2"], "context_summary": "sum2"}, {}) + assert c == ["msg2"] + assert s == "sum2" + + # Case 3: Result doesn't have values, reads from state object + state_obj = MagicMock() + state_obj.context = ["msg3"] + state_obj.context_summary = "sum3" + c, s = service._extract_context_info(state_obj, {}) + assert c == ["msg3"] + assert s == "sum3" diff --git a/tests/unit_tests/test_loader.py b/tests/unit_tests/test_loader.py new file mode 100644 index 0000000..b4773f7 --- /dev/null +++ b/tests/unit_tests/test_loader.py @@ -0,0 +1,367 @@ +"""Unit tests for loader.py.""" + +from unittest.mock import AsyncMock, MagicMock, patch +import pytest +from pathlib import Path +from injectq import InjectQ + +from agentflow.core import CompiledGraph +from agentflow.storage.checkpointer import BaseCheckpointer +from agentflow.storage.store import BaseStore +from agentflow_cli import BaseAuth +from agentflow_cli.src.app.core.auth.authorization import AuthorizationBackend, DefaultAuthorizationBackend +from agentflow_cli.src.app.utils.thread_name_generator import ThreadNameGenerator 
+from agentflow_cli.src.app.core.config.graph_config import GraphConfig +from agentflow_cli.src.app.loader import ( + load_graph, + load_checkpointer, + load_store, + load_container, + load_auth, + load_authorization, + load_thread_name_generator, + load_and_bind_auth, + load_and_bind_authorization, + attach_all_modules, +) + + +@pytest.mark.asyncio +async def test_load_graph_invalid_format(): + with pytest.raises(ValueError, match="Invalid graph path format"): + await load_graph("invalid_path_no_colon") + + +@pytest.mark.asyncio +async def test_load_graph_success_callable(): + mock_graph = MagicMock(spec=CompiledGraph) + mock_callable = MagicMock(return_value=mock_graph) + + mock_module = MagicMock() + mock_module.my_graph = mock_callable + + with patch("importlib.import_module", return_value=mock_module) as mock_import: + result = await load_graph("my_module:my_graph") + assert result == mock_graph + mock_import.assert_called_once_with("my_module") + + +@pytest.mark.asyncio +async def test_load_graph_success_async_callable(): + mock_graph = MagicMock(spec=CompiledGraph) + mock_callable = AsyncMock(return_value=mock_graph) + + mock_module = MagicMock() + mock_module.my_graph = mock_callable + + with patch("importlib.import_module", return_value=mock_module) as mock_import: + result = await load_graph("my_module:my_graph") + assert result == mock_graph + + +@pytest.mark.asyncio +async def test_load_graph_success_non_callable(): + class NonCallableGraph(CompiledGraph): + def __init__(self): + pass + + non_callable = NonCallableGraph() + mock_module = MagicMock() + mock_module.my_graph = non_callable + + with patch("importlib.import_module", return_value=mock_module): + result = await load_graph("my_module:my_graph") + assert result == non_callable + + +@pytest.mark.asyncio +async def test_load_graph_errors(): + # RuntimeError: app is None + mock_module = MagicMock() + mock_module.my_graph = None + with patch("importlib.import_module", return_value=mock_module): + 
with pytest.raises(Exception, match="Failed to obtain a runnable graph"): + await load_graph("my_module:my_graph") + + # TypeError: Loaded object is not a CompiledGraph + mock_module.my_graph = "not a graph" + with patch("importlib.import_module", return_value=mock_module): + with pytest.raises(Exception, match="Loaded object is not a CompiledGraph"): + await load_graph("my_module:my_graph") + + # ModuleNotFoundError + with patch("importlib.import_module", side_effect=ModuleNotFoundError("No module named foo")): + with pytest.raises(ModuleNotFoundError): + await load_graph("foo:bar") + + # AttributeError + mock_module = MagicMock() + del mock_module.bar + with patch("importlib.import_module", return_value=mock_module): + with pytest.raises(AttributeError): + await load_graph("foo:bar") + + +def test_load_checkpointer(): + assert load_checkpointer(None) is None + + with pytest.raises(ValueError, match="Invalid checkpointer path format"): + load_checkpointer("invalid_no_colon") + + mock_cp = MagicMock(spec=BaseCheckpointer) + mock_module = MagicMock() + mock_module.cp = mock_cp + with patch("importlib.import_module", return_value=mock_module): + assert load_checkpointer("mod:cp") == mock_cp + + # RuntimeError + mock_module.cp = None + with patch("importlib.import_module", return_value=mock_module): + with pytest.raises(Exception, match="Failed to obtain a BaseCheckpointer"): + load_checkpointer("mod:cp") + + # TypeError + mock_module.cp = "not cp" + with patch("importlib.import_module", return_value=mock_module): + with pytest.raises(Exception, match="Loaded object is not a BaseCheckpointer"): + load_checkpointer("mod:cp") + + # ModuleNotFoundError + with patch("importlib.import_module", side_effect=ModuleNotFoundError()): + with pytest.raises(ModuleNotFoundError): + load_checkpointer("mod:cp") + + # AttributeError + mock_module = MagicMock() + del mock_module.cp + with patch("importlib.import_module", return_value=mock_module): + with pytest.raises(AttributeError): 
+ load_checkpointer("mod:cp") + + +def test_load_store(): + assert load_store(None) is None + + with pytest.raises(ValueError, match="Invalid store path format"): + load_store("invalid") + + mock_store = MagicMock(spec=BaseStore) + mock_module = MagicMock() + mock_module.store = mock_store + with patch("importlib.import_module", return_value=mock_module): + assert load_store("mod:store") == mock_store + + # RuntimeError + mock_module.store = None + with patch("importlib.import_module", return_value=mock_module): + with pytest.raises(Exception, match="Failed to obtain a BaseStore"): + load_store("mod:store") + + # TypeError + mock_module.store = "not store" + with patch("importlib.import_module", return_value=mock_module): + with pytest.raises(Exception, match="Loaded object is not a BaseStore"): + load_store("mod:store") + + # ModuleNotFoundError + with patch("importlib.import_module", side_effect=ModuleNotFoundError()): + with pytest.raises(ModuleNotFoundError): + load_store("mod:store") + + +def test_load_container(): + assert load_container(None) is None + + mock_container = MagicMock(spec=InjectQ) + mock_module = MagicMock() + mock_module.container = mock_container + with patch("importlib.import_module", return_value=mock_module): + assert load_container("mod:container") == mock_container + mock_container.activate.assert_called_once() + + # Exception cases + mock_module.container = "not container" + with patch("importlib.import_module", return_value=mock_module): + with pytest.raises(Exception, match="Failed to load InjectQ"): + load_container("mod:container") + + +def test_load_auth(): + assert load_auth(None) is None + + with pytest.raises(ValueError, match="Invalid auth path format"): + load_auth("invalid") + + class CustomAuth(BaseAuth): + def authenticate(self, *args, **kwargs): + pass + + # Class subclass of BaseAuth + mock_module = MagicMock() + mock_module.auth = CustomAuth + with patch("importlib.import_module", return_value=mock_module): + result = 
load_auth("mod:auth") + assert isinstance(result, CustomAuth) + + # Instance of BaseAuth + auth_instance = CustomAuth() + mock_module.auth = auth_instance + with patch("importlib.import_module", return_value=mock_module): + assert load_auth("mod:auth") == auth_instance + + # TypeError + mock_module.auth = "not auth" + with patch("importlib.import_module", return_value=mock_module): + with pytest.raises(Exception, match="Loaded object is not a subclass or instance of BaseAuth"): + load_auth("mod:auth") + + # ModuleNotFoundError + with patch("importlib.import_module", side_effect=ModuleNotFoundError()): + with pytest.raises(ModuleNotFoundError): + load_auth("mod:auth") + + +def test_load_authorization(): + assert load_authorization(None) is None + + class CustomAuthorization(AuthorizationBackend): + def authorize(self, *args, **kwargs): + pass + + # Class subclass + mock_module = MagicMock() + mock_module.authorization = CustomAuthorization + with patch("importlib.import_module", return_value=mock_module): + result = load_authorization("mod:authorization") + assert isinstance(result, CustomAuthorization) + + # Instance + auth_instance = CustomAuthorization() + mock_module.authorization = auth_instance + with patch("importlib.import_module", return_value=mock_module): + assert load_authorization("mod:authorization") == auth_instance + + # TypeError + mock_module.authorization = "not authorization" + with patch("importlib.import_module", return_value=mock_module): + with pytest.raises(Exception, match="Loaded object is not a subclass or instance of AuthorizationBackend"): + load_authorization("mod:authorization") + + +def test_load_thread_name_generator(): + assert load_thread_name_generator(None) is None + + class CustomGenerator(ThreadNameGenerator): + async def generate_name(self, messages): + return "name" + + # Class subclass + mock_module = MagicMock() + mock_module.generator = CustomGenerator + with patch("importlib.import_module", return_value=mock_module): + 
result = load_thread_name_generator("mod:generator") + assert isinstance(result, CustomGenerator) + + # Instance + gen_instance = CustomGenerator() + mock_module.generator = gen_instance + with patch("importlib.import_module", return_value=mock_module): + assert load_thread_name_generator("mod:generator") == gen_instance + + # TypeError + mock_module.generator = "not generator" + with patch("importlib.import_module", return_value=mock_module): + with pytest.raises(Exception, match="Loaded object is not a subclass or instance of ThreadNameGenerator"): + load_thread_name_generator("mod:generator") + + +def test_load_and_bind_auth(): + container = MagicMock(spec=InjectQ) + + # Missing method/path + with pytest.raises(ValueError, match="Both 'method' and 'path' must be specified"): + load_and_bind_auth(container, {"method": "custom"}) + + # Path existence check failure + with patch("pathlib.Path.exists", return_value=False): + with pytest.raises(ValueError, match="Custom auth path does not exist"): + load_and_bind_auth(container, {"method": "custom", "path": "custom_path.py:auth"}) + + # Dotted path conversion to py file check + with patch("pathlib.Path.exists", return_value=True): + with patch("agentflow_cli.src.app.loader.load_auth") as mock_load_auth: + mock_auth_instance = MagicMock(spec=BaseAuth) + mock_load_auth.return_value = mock_auth_instance + + # test "custom" method + load_and_bind_auth(container, {"method": "custom", "path": "my.auth.path:auth"}) + container.bind_instance.assert_called_with(BaseAuth, mock_auth_instance, allow_none=True) + + # test "jwt" method + from agentflow_cli.src.app.core.auth.jwt_auth import JwtAuth + load_and_bind_auth(container, {"method": "jwt", "path": "some_path.py:auth"}) + # JwtAuth is instantiated inside, so we check if standard JwtAuth was bound + args, kwargs = container.bind_instance.call_args + assert args[0] == BaseAuth + assert isinstance(args[1], JwtAuth) + + # test "none" method + load_and_bind_auth(container, 
{"method": "none", "path": "some_path.py:auth"}) + container.bind_instance.assert_called_with(BaseAuth, None, allow_none=True) + + +def test_load_and_bind_authorization(): + container = MagicMock(spec=InjectQ) + + # Path provided + mock_auth_backend = MagicMock(spec=AuthorizationBackend) + with patch("agentflow_cli.src.app.loader.load_authorization", return_value=mock_auth_backend): + load_and_bind_authorization(container, "mod:auth") + container.bind_instance.assert_called_once_with(AuthorizationBackend, mock_auth_backend) + + # Path is None + container.reset_mock() + load_and_bind_authorization(container, None) + args, kwargs = container.bind_instance.call_args + assert args[0] == AuthorizationBackend + assert isinstance(args[1], DefaultAuthorizationBackend) + + +@pytest.mark.asyncio +async def test_attach_all_modules(): + config = MagicMock(spec=GraphConfig) + config.graph_path = "mod:graph" + config.auth_config.return_value = {"method": "none", "path": "path.py:auth"} + config.thread_name_generator_path = "mod:generator" + config.authorization_path = "mod:authorization" + + container = MagicMock(spec=InjectQ) + + mock_graph = MagicMock(spec=CompiledGraph) + mock_generator = MagicMock(spec=ThreadNameGenerator) + mock_auth_backend = MagicMock(spec=AuthorizationBackend) + + with patch("agentflow_cli.src.app.loader.load_graph", return_value=mock_graph), \ + patch("agentflow_cli.src.app.loader.load_thread_name_generator", return_value=mock_generator), \ + patch("agentflow_cli.src.app.loader.load_authorization", return_value=mock_auth_backend), \ + patch("pathlib.Path.exists", return_value=True), \ + patch("agentflow_cli.src.app.core.config.media_settings.get_media_settings") as mock_get_media_settings: + + mock_get_media_settings.return_value = MagicMock() + + result = await attach_all_modules(config, container) + + assert result == mock_graph + # verify bindings + container.bind_instance.assert_any_call(BaseAuth, None, allow_none=True) + 
container.bind_instance.assert_any_call(ThreadNameGenerator, mock_generator) + container.bind_instance.assert_any_call(AuthorizationBackend, mock_auth_backend) + + # Test branch where config has no thread name generator and no auth config + config.thread_name_generator_path = None + config.auth_config.return_value = None + + container.reset_mock() + await attach_all_modules(config, container) + container.bind_instance.assert_any_call(ThreadNameGenerator, None, allow_none=True) + container.bind_instance.assert_any_call(BaseAuth, None, allow_none=True) diff --git a/tests/unit_tests/test_setup_middleware.py b/tests/unit_tests/test_setup_middleware.py index 4383d04..39a228d 100644 --- a/tests/unit_tests/test_setup_middleware.py +++ b/tests/unit_tests/test_setup_middleware.py @@ -1,7 +1,13 @@ +import sys +import pytest +from unittest.mock import MagicMock, patch from fastapi import FastAPI from fastapi.testclient import TestClient -from agentflow_cli.src.app.core.config.setup_middleware import setup_middleware +from agentflow_cli.src.app.core.config.setup_middleware import ( + setup_middleware, + SelectiveGZipMiddleware, +) HTTP_OK = 200 @@ -24,3 +30,361 @@ def echo(): # Ensure stable format (uuid length 36) and iso-like timestamp assert len(r.headers["X-Request-ID"]) >= MIN_REQUEST_ID_LEN assert "T" in r.headers["X-Timestamp"] + + +@pytest.mark.asyncio +async def test_selective_gzip_middleware_excludes(): + from unittest.mock import AsyncMock + called = [] + async def app(scope, receive, send): + called.append((scope, receive, send)) + + with patch("agentflow_cli.src.app.core.config.setup_middleware.GZipMiddleware") as MockGZipMiddleware: + mock_gzip_instance = AsyncMock() + MockGZipMiddleware.return_value = mock_gzip_instance + + middleware = SelectiveGZipMiddleware(app) + + # Test excluded path + scope = {"type": "http", "path": "/v1/graph/stream"} + receive = MagicMock() + send = MagicMock() + await middleware(scope, receive, send) + + assert len(called) == 1 + 
assert called[0][0] == scope + mock_gzip_instance.assert_not_called() + + called.clear() + mock_gzip_instance.reset_mock() + + # Test non-excluded path + scope2 = {"type": "http", "path": "/v1/graph/other"} + await middleware(scope2, receive, send) + + assert len(called) == 0 + mock_gzip_instance.assert_called_once_with(scope2, receive, send) + + +def test_setup_otel_import_error(): + from agentflow_cli.src.app.core.config.setup_middleware import _setup_otel + settings = MagicMock() + settings.OTEL_SERVICE_NAME = "test" + with patch.dict("sys.modules", {"opentelemetry": None}): + # setup_middleware should catch ImportError and return without raising + _setup_otel(MagicMock(), settings) + + +def test_setup_otel_with_endpoint(): + settings = MagicMock() + settings.OTEL_ENABLED = True + settings.OTEL_SERVICE_NAME = "test-service" + settings.OTEL_EXPORTER_OTLP_ENDPOINT = "http://localhost:4317" + + mock_trace = MagicMock() + mock_trace.trace = mock_trace + mock_trace.__path__ = [] + mock_instrumentor = MagicMock() + mock_resource = MagicMock() + mock_provider = MagicMock() + mock_exporter = MagicMock() + mock_processor = MagicMock() + + modules = { + "opentelemetry": mock_trace, + "opentelemetry.trace": mock_trace, + "opentelemetry.instrumentation.fastapi": mock_instrumentor, + "opentelemetry.sdk.resources": mock_resource, + "opentelemetry.sdk.trace": mock_provider, + "opentelemetry.sdk.trace.export": mock_processor, + "opentelemetry.exporter": MagicMock(__path__=[]), + "opentelemetry.exporter.otlp": MagicMock(__path__=[]), + "opentelemetry.exporter.otlp.proto": MagicMock(__path__=[]), + "opentelemetry.exporter.otlp.proto.grpc": MagicMock(__path__=[]), + "opentelemetry.exporter.otlp.proto.grpc.trace_exporter": mock_exporter, + } + + with patch.dict("sys.modules", modules): + mock_resource.Resource.create.return_value = "resource_obj" + mock_provider.TracerProvider.return_value = mock_provider + + mock_exporter.OTLPSpanExporter = MagicMock() + 
mock_processor.BatchSpanProcessor = MagicMock() + + from agentflow_cli.src.app.core.config.setup_middleware import _setup_otel + _setup_otel(MagicMock(), settings) + + mock_resource.Resource.create.assert_called_once_with({"service.name": "test-service"}) + mock_provider.TracerProvider.assert_called_once_with(resource="resource_obj") + mock_exporter.OTLPSpanExporter.assert_called_once_with(endpoint="http://localhost:4317") + mock_provider.add_span_processor.assert_called_once() + mock_trace.set_tracer_provider.assert_called_once_with(mock_provider) + mock_instrumentor.FastAPIInstrumentor.instrument_app.assert_called_once() + + +def test_setup_otel_no_endpoint(): + settings = MagicMock() + settings.OTEL_ENABLED = True + settings.OTEL_SERVICE_NAME = "test-service" + settings.OTEL_EXPORTER_OTLP_ENDPOINT = None + + mock_trace = MagicMock() + mock_instrumentor = MagicMock() + mock_resource = MagicMock() + mock_provider = MagicMock() + mock_processor = MagicMock() + + modules = { + "opentelemetry": mock_trace, + "opentelemetry.trace": mock_trace, + "opentelemetry.instrumentation.fastapi": mock_instrumentor, + "opentelemetry.sdk.resources": mock_resource, + "opentelemetry.sdk.trace": mock_provider, + "opentelemetry.sdk.trace.export": mock_processor, + } + + with patch.dict("sys.modules", modules): + mock_resource.Resource.create.return_value = "resource_obj" + mock_provider.TracerProvider.return_value = mock_provider + mock_processor.ConsoleSpanExporter = MagicMock() + mock_processor.SimpleSpanProcessor = MagicMock() + + from agentflow_cli.src.app.core.config.setup_middleware import _setup_otel + _setup_otel(MagicMock(), settings) + + mock_processor.ConsoleSpanExporter.assert_called_once() + mock_processor.SimpleSpanProcessor.assert_called_once() + mock_provider.add_span_processor.assert_called_once() + + +def test_setup_otel_grpc_exporter_import_error(): + settings = MagicMock() + settings.OTEL_ENABLED = True + settings.OTEL_SERVICE_NAME = "test-service" + 
settings.OTEL_EXPORTER_OTLP_ENDPOINT = "http://localhost:4317" + + mock_trace = MagicMock() + mock_instrumentor = MagicMock() + mock_resource = MagicMock() + mock_provider = MagicMock() + mock_processor = MagicMock() + + modules = { + "opentelemetry": mock_trace, + "opentelemetry.trace": mock_trace, + "opentelemetry.instrumentation.fastapi": mock_instrumentor, + "opentelemetry.sdk.resources": mock_resource, + "opentelemetry.sdk.trace": mock_provider, + "opentelemetry.sdk.trace.export": mock_processor, + "opentelemetry.exporter.otlp.proto.grpc.trace_exporter": None, + } + + with patch.dict("sys.modules", modules): + mock_resource.Resource.create.return_value = "resource_obj" + mock_provider.TracerProvider.return_value = mock_provider + + from agentflow_cli.src.app.core.config.setup_middleware import _setup_otel + _setup_otel(MagicMock(), settings) + # Should return gracefully on ImportError of grpc exporter + + +def test_attach_otel_publisher_import_error(): + with patch.dict("sys.modules", {"agentflow.runtime.publisher.base_publisher": None}): + from agentflow_cli.src.app.core.config.setup_middleware import _attach_otel_publisher + container = MagicMock() + _attach_otel_publisher(container, MagicMock()) + + +def test_attach_otel_publisher_value_error(): + class FakeObservabilityLevel: + STANDARD = "standard" + def __init__(self, val): + raise ValueError("invalid level") + + class FakeOtelPublisher: + def __init__(self, level): + self.level = level + + class FakeBasePublisher: + pass + + modules = { + "agentflow.runtime.publisher.base_publisher": MagicMock(BasePublisher=FakeBasePublisher), + "agentflow.runtime.publisher.composite_publisher": MagicMock(), + "agentflow.runtime.publisher.otel_publisher": MagicMock( + ObservabilityLevel=FakeObservabilityLevel, + OtelPublisher=FakeOtelPublisher + ) + } + with patch.dict("sys.modules", modules): + from agentflow_cli.src.app.core.config.setup_middleware import _attach_otel_publisher + container = MagicMock() + 
container.try_get.return_value = None + settings = MagicMock() + settings.OTEL_LEVEL = "invalid-level" + _attach_otel_publisher(container, settings) + container.bind_instance.assert_called_once() + + +def test_attach_otel_publisher_no_existing(): + class FakeObservabilityLevel: + STANDARD = "standard" + def __init__(self, val): + self.val = val + + class FakeOtelPublisher: + def __init__(self, level): + self.level = level + + class FakeBasePublisher: + pass + + modules = { + "agentflow.runtime.publisher.base_publisher": MagicMock(BasePublisher=FakeBasePublisher), + "agentflow.runtime.publisher.composite_publisher": MagicMock(), + "agentflow.runtime.publisher.otel_publisher": MagicMock( + ObservabilityLevel=FakeObservabilityLevel, + OtelPublisher=FakeOtelPublisher + ) + } + with patch.dict("sys.modules", modules): + from agentflow_cli.src.app.core.config.setup_middleware import _attach_otel_publisher + container = MagicMock() + container.try_get.return_value = None + settings = MagicMock() + settings.OTEL_LEVEL = "standard" + _attach_otel_publisher(container, settings) + container.bind_instance.assert_called_once() + + +def test_attach_otel_publisher_existing_composite(): + class FakeObservabilityLevel: + STANDARD = "standard" + def __init__(self, val): + self.val = val + + class FakeOtelPublisher: + def __init__(self, level): + self.level = level + + class FakeBasePublisher: + pass + + class FakeCompositePublisher: + def __init__(self, publishers=None): + self.publishers = publishers or [] + def add_publisher(self, pub): + self.publishers.append(pub) + + existing = FakeCompositePublisher() + + modules = { + "agentflow.runtime.publisher.base_publisher": MagicMock(BasePublisher=FakeBasePublisher), + "agentflow.runtime.publisher.composite_publisher": MagicMock(CompositePublisher=FakeCompositePublisher), + "agentflow.runtime.publisher.otel_publisher": MagicMock( + ObservabilityLevel=FakeObservabilityLevel, + OtelPublisher=FakeOtelPublisher + ) + } + with 
patch.dict("sys.modules", modules): + from agentflow_cli.src.app.core.config.setup_middleware import _attach_otel_publisher + container = MagicMock() + container.try_get.return_value = existing + settings = MagicMock() + settings.OTEL_LEVEL = "standard" + + with patch("agentflow_cli.src.app.core.config.setup_middleware.isinstance", return_value=True): + _attach_otel_publisher(container, settings) + + +def test_attach_otel_publisher_existing_single(): + class FakeObservabilityLevel: + STANDARD = "standard" + def __init__(self, val): + self.val = val + + class FakeOtelPublisher: + def __init__(self, level): + self.level = level + + class FakeBasePublisher: + pass + + class FakeCompositePublisher: + def __init__(self, publishers=None): + self.publishers = publishers or [] + + class SinglePublisher: + pass + + existing = SinglePublisher() + + modules = { + "agentflow.runtime.publisher.base_publisher": MagicMock(BasePublisher=FakeBasePublisher), + "agentflow.runtime.publisher.composite_publisher": MagicMock(CompositePublisher=FakeCompositePublisher), + "agentflow.runtime.publisher.otel_publisher": MagicMock( + ObservabilityLevel=FakeObservabilityLevel, + OtelPublisher=FakeOtelPublisher + ) + } + with patch.dict("sys.modules", modules): + from agentflow_cli.src.app.core.config.setup_middleware import _attach_otel_publisher + container = MagicMock() + container.try_get.return_value = existing + settings = MagicMock() + settings.OTEL_LEVEL = "standard" + + _attach_otel_publisher(container, settings) + container.bind_instance.assert_called_once() + + +class MockSettings: + OTEL_ENABLED = True + OTEL_SERVICE_NAME = "test-service" + OTEL_EXPORTER_OTLP_ENDPOINT = None + ORIGINS = "http://localhost,http://localhost:3000" + ALLOWED_HOST = "localhost,127.0.0.1" + MAX_REQUEST_SIZE = 1024 * 1024 + SECURITY_HEADERS_ENABLED = True + HSTS_ENABLED = True + HSTS_MAX_AGE = 31536000 + HSTS_INCLUDE_SUBDOMAINS = True + HSTS_PRELOAD = True + FRAME_OPTIONS = "DENY" + CONTENT_TYPE_OPTIONS = 
"nosniff" + XSS_PROTECTION = "1; mode=block" + REFERRER_POLICY = "no-referrer" + PERMISSIONS_POLICY = "geolocation=()" + CSP_POLICY = "default-src 'self'" + + +def test_setup_middleware_all(): + app = FastAPI() + settings = MockSettings() + + graph_config = MagicMock() + rate_limit_config = MagicMock() + rate_limit_config.backend = "memory" + rate_limit_config.requests = 100 + rate_limit_config.window = 60 + rate_limit_config.by = "ip" + rate_limit_config.exclude_paths = None + rate_limit_config.trusted_proxy_headers = True + graph_config.rate_limit = rate_limit_config + + container = MagicMock() + + with patch("agentflow_cli.src.app.core.config.setup_middleware.get_settings", return_value=settings), \ + patch("agentflow_cli.src.app.core.config.setup_middleware.init_sentry") as mock_init_sentry, \ + patch("agentflow_cli.src.app.core.config.setup_middleware.build_backend", return_value="mock_backend") as mock_build_backend, \ + patch("agentflow_cli.src.app.core.config.setup_middleware._setup_otel") as mock_setup_otel, \ + patch("agentflow_cli.src.app.core.config.setup_middleware._attach_otel_publisher") as mock_attach: + + setup_middleware(app, graph_config=graph_config, container=container) + + mock_init_sentry.assert_called_once_with(settings) + mock_build_backend.assert_called_once_with(rate_limit_config, container=container) + assert app.state.rate_limit_backend == "mock_backend" + mock_setup_otel.assert_called_once_with(app, settings) + mock_attach.assert_called_once_with(container, settings) +