From 4c98a6b6c68b5c098ce37c4cd99cf4e02cbab06d Mon Sep 17 00:00:00 2001 From: Eric Windmill Date: Wed, 11 Mar 2026 11:39:36 -0700 Subject: [PATCH 1/7] copy runner into repo --- packages/dash_evals/README.md | 3 + packages/dash_evals/pyproject.toml | 72 +++++ packages/dash_evals/pyrefly.toml | 4 + .../dash_evals/src/dash_evals/__init__.py | 12 + packages/dash_evals/src/dash_evals/main.py | 118 +++++++ .../src/dash_evals/runner/__init__.py | 7 + .../src/dash_evals/runner/args_runner.py | 73 +++++ .../src/dash_evals/runner/json_runner.py | 213 +++++++++++++ .../dash_evals/runner/sandboxes/__init__.py | 5 + .../runner/sandboxes/podman/__init__.py | 1 + .../runner/sandboxes/podman/podman.py | 301 ++++++++++++++++++ .../dash_evals/runner/sandboxes/provider.py | 4 + .../src/dash_evals/runner/scorers/__init__.py | 20 ++ .../dash_evals/runner/scorers/code_quality.py | 132 ++++++++ .../dash_evals/runner/scorers/dart_analyze.py | 136 ++++++++ .../runner/scorers/export_workspace.py | 131 ++++++++ .../dash_evals/runner/scorers/flutter_code.py | 97 ++++++ .../runner/scorers/flutter_output_parser.py | 102 ++++++ .../runner/scorers/flutter_scoring.py | 152 +++++++++ .../dash_evals/runner/scorers/flutter_test.py | 113 +++++++ .../runner/scorers/mcp_tool_usage.py | 162 ++++++++++ .../dash_evals/runner/scorers/skill_usage.py | 67 ++++ .../src/dash_evals/runner/solvers/__init__.py | 17 + .../runner/solvers/add_system_message.py | 27 ++ .../runner/solvers/context_injector.py | 42 +++ .../dash_evals/runner/solvers/extract_code.py | 34 ++ .../runner/solvers/inject_test_files.py | 64 ++++ .../runner/solvers/setup_workspace.py | 189 +++++++++++ .../runner/solvers/write_to_sandbox.py | 112 +++++++ .../src/dash_evals/runner/tasks/__init__.py | 17 + .../runner/tasks/analyze_codebase.py | 114 +++++++ .../src/dash_evals/runner/tasks/bug_fix.py | 191 +++++++++++ .../src/dash_evals/runner/tasks/code_gen.py | 261 +++++++++++++++ .../src/dash_evals/runner/tasks/mcp_tool.py | 81 +++++ 
.../runner/tasks/question_answer.py | 57 ++++ .../src/dash_evals/runner/tasks/skill_test.py | 121 +++++++ .../dash_evals/runner/tasks/task_helpers.py | 171 ++++++++++ .../src/dash_evals/utils/__init__.py | 3 + .../src/dash_evals/utils/logging.py | 123 +++++++ .../src/dash_evals/utils/markdown.py | 44 +++ packages/dash_evals/tests/__init__.py | 0 .../dash_evals/tests/test_export_workspace.py | 253 +++++++++++++++ .../tests/test_flutter_code_execution.py | 227 +++++++++++++ packages/dash_evals/tests/test_models.py | 42 +++ packages/dash_evals/tests/test_scorers.py | 147 +++++++++ packages/dash_evals/tests/test_solvers.py | 55 ++++ packages/dash_evals/tests/test_utils.py | 77 +++++ .../dataset_config/lib/src/models/task.dart | 2 +- .../lib/src/models/task.freezed.dart | 4 +- .../lib/src/commands/doctor_command.dart | 20 +- .../lib/src/commands/run_command.dart | 8 +- .../test/commands/doctor_command_test.dart | 2 +- 52 files changed, 4412 insertions(+), 18 deletions(-) create mode 100644 packages/dash_evals/README.md create mode 100644 packages/dash_evals/pyproject.toml create mode 100644 packages/dash_evals/pyrefly.toml create mode 100644 packages/dash_evals/src/dash_evals/__init__.py create mode 100644 packages/dash_evals/src/dash_evals/main.py create mode 100644 packages/dash_evals/src/dash_evals/runner/__init__.py create mode 100644 packages/dash_evals/src/dash_evals/runner/args_runner.py create mode 100644 packages/dash_evals/src/dash_evals/runner/json_runner.py create mode 100644 packages/dash_evals/src/dash_evals/runner/sandboxes/__init__.py create mode 100644 packages/dash_evals/src/dash_evals/runner/sandboxes/podman/__init__.py create mode 100644 packages/dash_evals/src/dash_evals/runner/sandboxes/podman/podman.py create mode 100644 packages/dash_evals/src/dash_evals/runner/sandboxes/provider.py create mode 100644 packages/dash_evals/src/dash_evals/runner/scorers/__init__.py create mode 100644 packages/dash_evals/src/dash_evals/runner/scorers/code_quality.py 
create mode 100644 packages/dash_evals/src/dash_evals/runner/scorers/dart_analyze.py create mode 100644 packages/dash_evals/src/dash_evals/runner/scorers/export_workspace.py create mode 100644 packages/dash_evals/src/dash_evals/runner/scorers/flutter_code.py create mode 100644 packages/dash_evals/src/dash_evals/runner/scorers/flutter_output_parser.py create mode 100644 packages/dash_evals/src/dash_evals/runner/scorers/flutter_scoring.py create mode 100644 packages/dash_evals/src/dash_evals/runner/scorers/flutter_test.py create mode 100644 packages/dash_evals/src/dash_evals/runner/scorers/mcp_tool_usage.py create mode 100644 packages/dash_evals/src/dash_evals/runner/scorers/skill_usage.py create mode 100644 packages/dash_evals/src/dash_evals/runner/solvers/__init__.py create mode 100644 packages/dash_evals/src/dash_evals/runner/solvers/add_system_message.py create mode 100644 packages/dash_evals/src/dash_evals/runner/solvers/context_injector.py create mode 100644 packages/dash_evals/src/dash_evals/runner/solvers/extract_code.py create mode 100644 packages/dash_evals/src/dash_evals/runner/solvers/inject_test_files.py create mode 100644 packages/dash_evals/src/dash_evals/runner/solvers/setup_workspace.py create mode 100644 packages/dash_evals/src/dash_evals/runner/solvers/write_to_sandbox.py create mode 100644 packages/dash_evals/src/dash_evals/runner/tasks/__init__.py create mode 100644 packages/dash_evals/src/dash_evals/runner/tasks/analyze_codebase.py create mode 100644 packages/dash_evals/src/dash_evals/runner/tasks/bug_fix.py create mode 100644 packages/dash_evals/src/dash_evals/runner/tasks/code_gen.py create mode 100644 packages/dash_evals/src/dash_evals/runner/tasks/mcp_tool.py create mode 100644 packages/dash_evals/src/dash_evals/runner/tasks/question_answer.py create mode 100644 packages/dash_evals/src/dash_evals/runner/tasks/skill_test.py create mode 100644 packages/dash_evals/src/dash_evals/runner/tasks/task_helpers.py create mode 100644 
packages/dash_evals/src/dash_evals/utils/__init__.py create mode 100644 packages/dash_evals/src/dash_evals/utils/logging.py create mode 100644 packages/dash_evals/src/dash_evals/utils/markdown.py create mode 100644 packages/dash_evals/tests/__init__.py create mode 100644 packages/dash_evals/tests/test_export_workspace.py create mode 100644 packages/dash_evals/tests/test_flutter_code_execution.py create mode 100644 packages/dash_evals/tests/test_models.py create mode 100644 packages/dash_evals/tests/test_scorers.py create mode 100644 packages/dash_evals/tests/test_solvers.py create mode 100644 packages/dash_evals/tests/test_utils.py diff --git a/packages/dash_evals/README.md b/packages/dash_evals/README.md new file mode 100644 index 0000000..774ff45 --- /dev/null +++ b/packages/dash_evals/README.md @@ -0,0 +1,3 @@ +# dash_evals + +Python package for running LLM evaluations on Dart and Flutter tasks using [Inspect AI](https://inspect.aisi.org.uk/). diff --git a/packages/dash_evals/pyproject.toml b/packages/dash_evals/pyproject.toml new file mode 100644 index 0000000..ac2f9f2 --- /dev/null +++ b/packages/dash_evals/pyproject.toml @@ -0,0 +1,72 @@ +[project] +name = "dash-evals" +version = "0.1.0" +description = "" +authors = [{ name = "Eric Windmill", email = "eric@ericwindmill.com" }] +readme = "README.md" +requires-python = ">=3.13,<4.0.0" +dependencies = [ + "inspect-ai>=0.3.142,<0.4.0", + "pyyaml>=6.0.3,<7.0.0", + "google-genai>=1.47.0,<2.0.0", + "mcp>=1.20.0,<2.0.0", + "python-dotenv>=1.2.1,<2.0.0", + "anthropic>=0.75.0,<0.81.0", + "openai>=2.8.1,<3.0.0", + "firebase-admin>=6.0.0,<8.0.0", + "pydantic>=2.0.0,<3.0.0", +] + +[project.optional-dependencies] +dev = [ + "pytest>=8.0.0", + "pytest-mock>=3.12.0", + "pytest-cov>=4.1.0", + "pylint>=3.0.0", +] + +[project.scripts] +run-evals = "dash_evals.main:main" + +[build-system] +requires = ["setuptools>=61.0"] +build-backend = "setuptools.build_meta" + +# Register podman sandbox with inspect_ai 
+[project.entry-points.inspect_ai] +dash_evals = "dash_evals.runner.sandboxes" + +[tool.setuptools.packages.find] +where = ["src"] + +[tool.setuptools.package-data] +dash_evals = ["data/*.yaml"] + +[tool.pytest.ini_options] +testpaths = ["tests"] +python_files = ["test_*.py"] +python_classes = ["Test*"] +python_functions = ["test_*"] + +[tool.coverage.run] +omit = [ + "src/dash_evals/main.py", + "src/dash_evals/uploader.py", + "src/dash_evals/uploader_aggregates.py", + "src/dash_evals/tasks/*", +] + +[tool.pylint.messages_control] +disable = [ + "logging-fstring-interpolation", # Allow f-strings in logging (modern Python standard) +] + +[tool.pylint.format] +max-line-length = 100 + +[tool.ruff] +line-length = 100 + +[tool.ruff.lint] +select = ["E", "F", "W", "I"] +ignore = ["E501"] # Line too long (handled by formatter) diff --git a/packages/dash_evals/pyrefly.toml b/packages/dash_evals/pyrefly.toml new file mode 100644 index 0000000..6370db3 --- /dev/null +++ b/packages/dash_evals/pyrefly.toml @@ -0,0 +1,4 @@ +# Pyrefly configuration +# Tell Pyrefly to use the local venv Python interpreter + +python-interpreter = "../../.venv/bin/python" diff --git a/packages/dash_evals/src/dash_evals/__init__.py b/packages/dash_evals/src/dash_evals/__init__.py new file mode 100644 index 0000000..d7e7db7 --- /dev/null +++ b/packages/dash_evals/src/dash_evals/__init__.py @@ -0,0 +1,12 @@ +"""dash_evals - Evaluation framework for Dart and Flutter AI assistants. + +This package provides tools for running evaluations using Inspect AI +to measure model performance on Dart/Flutter tasks. + +Configuration is resolved by the Dart CLI (devals) and emitted as JSONL +datasets + a run manifest. The Python package reads the manifest and +calls eval_set() directly. 
+ +Main entry point: + run-evals --manifest +""" diff --git a/packages/dash_evals/src/dash_evals/main.py b/packages/dash_evals/src/dash_evals/main.py new file mode 100644 index 0000000..6c6fc67 --- /dev/null +++ b/packages/dash_evals/src/dash_evals/main.py @@ -0,0 +1,118 @@ +# Copyright 2025 The Flutter Authors. All rights reserved. +# Use of this source code is governed by a BSD-style license that can be +# found in the LICENSE file. + +"""CLI entry point for running evaluations. + +Usage: + run-evals --json ./eval_set.json + run-evals --task my_task --model openai/gpt-4o --dataset samples.jsonl +""" + +import argparse +import logging +import sys +from pathlib import Path + +from dotenv import load_dotenv + +# Import sandbox environments to register them with InspectAI +# The @sandboxenv decorator registers the sandbox type when the module is imported +import dash_evals.runner.sandboxes.podman.podman # noqa: F401 # Registers 'podman' +from dash_evals.runner.args_runner import _run_from_args +from dash_evals.runner.json_runner import run_from_json + +# Basic console logger for early startup messages +logging.basicConfig(level=logging.INFO, format="%(message)s") +_startup_logger = logging.getLogger("startup") + + +def main(): + """Parse command-line arguments and run evaluations.""" + # Load .env from the repo root (walks up from cwd). + # This populates os.environ with API keys, credentials, etc. + # System env vars take precedence over .env values (python-dotenv default). 
+ load_dotenv(override=False) + + parser = argparse.ArgumentParser( + description="Run Inspect AI evaluations for the Dart/Flutter plugin.", + epilog="Example: run-evals --json ./eval_set.json", + ) + + # ---------- JSON mode (mutually exclusive with direct args) ---------- + parser.add_argument( + "--json", + type=Path, + help="Path to eval_set.json (emitted by Dart CLI).", + ) + + # ---------- Direct-args mode ---------- + parser.add_argument( + "--task", + type=str, + help="Task function name (e.g. 'flutter_code_gen' or dotted path).", + ) + parser.add_argument( + "--model", + type=str, + action="append", + help="Model to evaluate (can be repeated). Example: openai/gpt-4o", + ) + parser.add_argument( + "--dataset", + type=Path, + help="Path to a dataset file (JSON/JSONL/CSV).", + ) + parser.add_argument( + "--log-dir", + type=Path, + help="Directory to write evaluation logs.", + ) + parser.add_argument( + "--sandbox", + type=str, + nargs=2, + metavar=("TYPE", "CONFIG"), + help="Sandbox type and config path. Example: podman compose.yaml", + ) + parser.add_argument( + "--max-connections", + type=int, + help="Maximum concurrent model connections.", + ) + parser.add_argument( + "--max-samples", + type=int, + help="Maximum concurrent samples per task.", + ) + parser.add_argument( + "--fail-on-error", + type=float, + help="Proportion of sample errors to tolerate (0.0-1.0).", + ) + + args = parser.parse_args() + + # Ensure either --json or direct args are provided, but not both. + direct_args_provided = any([args.task, args.model, args.dataset]) + if args.json and direct_args_provided: + parser.error( + "Cannot combine --json with --task/--model/--dataset. Use one mode or the other." 
def _run_from_args(args: argparse.Namespace) -> bool:
    """Build and run an ``eval_set`` call from individual CLI arguments.

    Args:
        args: Parsed CLI namespace. Reads ``task``, ``model``, ``dataset``,
            ``log_dir``, ``sandbox``, ``max_connections``, ``max_samples``
            and ``fail_on_error``.

    Returns:
        True if any tasks failed, False if all succeeded.

    Note:
        Calls ``sys.exit(1)`` when ``task`` or ``model`` is missing.
    """
    import inspect_ai

    from dash_evals.runner.json_runner import _resolve_task_func
    from dash_evals.utils.logging import setup_logging

    if not args.task:
        logger.error("--task is required in direct-args mode.")
        sys.exit(1)
    if not args.model:
        logger.error("--model is required in direct-args mode.")
        sys.exit(1)

    # Resolve task function
    task_func = _resolve_task_func(args.task)

    # Build dataset. The CLI advertises JSON/JSONL/CSV for --dataset, so pick
    # the loader from the file extension instead of always assuming JSON.
    dataset = None
    if args.dataset:
        from inspect_ai.dataset import csv_dataset, json_dataset

        dataset_path = Path(args.dataset)
        load_dataset = csv_dataset if dataset_path.suffix.lower() == ".csv" else json_dataset
        dataset = load_dataset(str(dataset_path))

    # Build the task definition with the same keys the JSON runner injects:
    # task functions read "task_name" (not "name") and "sandbox_type".
    task_def = {"name": args.task, "task_name": args.task}
    if args.sandbox:
        task_def["sandbox_type"] = args.sandbox[0]

    # A missing or empty dataset is handed to the task function as None
    # (same behaviour as before, without the duplicated call).
    task_instance = task_func(dataset or None, task_def)

    # Set up logging
    log_dir = args.log_dir or Path("./logs")
    log_dir.mkdir(parents=True, exist_ok=True)
    setup_logging(log_dir, name="dash_evals")

    # Build eval_set kwargs; optional settings are only forwarded when given
    # so eval_set() keeps its own defaults otherwise.
    eval_kwargs: dict = {
        "log_dir": str(log_dir),
        "model": args.model,
    }
    if args.sandbox:
        # eval_set expects a ("type", "config") tuple, argparse yields a list.
        eval_kwargs["sandbox"] = tuple(args.sandbox)
    for option in ("max_connections", "max_samples", "fail_on_error"):
        value = getattr(args, option)
        if value is not None:
            eval_kwargs[option] = value

    logger.info(
        f"\n{'=' * 70}\n🚀 RUNNING task '{args.task}' with model(s): "
        f"{', '.join(args.model)}\n{'=' * 70}"
    )

    try:
        success, _ = inspect_ai.eval_set(
            tasks=[task_instance],
            **eval_kwargs,
        )
        return not success
    except Exception as e:
        logger.error(f"Evaluation failed: {e}")
        return True
def _is_target_module_missing(exc: ModuleNotFoundError, module_path: str) -> bool:
    """Return True when *exc* reports *module_path* (or one of its parents) as missing.

    A ModuleNotFoundError can also come from an import statement *inside* a
    module that does exist; that case must not be mistaken for "task not found".
    """
    missing = exc.name
    if missing is None:
        # No module name attached: cannot tell, keep the lenient behaviour.
        return True
    return missing == module_path or module_path.startswith(f"{missing}.")


def _resolve_task_func(name: str):
    """Resolve a task function by name using importlib.

    Supports:
    - Short names: "flutter_code_gen" -> dash_evals.runner.tasks.flutter_code_gen
    - Dotted paths: "dash_evals.runner.tasks.flutter_code_gen.flutter_code_gen"

    For short names, first tries to import a module with the same name.
    If that module does not exist (or lacks the function), falls back to
    looking the function up in the tasks package's __init__ (e.g.
    flutter_bug_fix is exported from bug_fix.py via __init__.py).

    Args:
        name: Short task name or fully dotted ``module.function`` path.

    Returns:
        The callable task function.

    Raises:
        ValueError: If the module or function cannot be found, or if the
            module exists but fails to import because one of its own
            dependencies is missing (the underlying error is chained).
    """
    if "." not in name:
        # Short name: try module with the same name first
        module_path = f"dash_evals.runner.tasks.{name}"
        try:
            module = importlib.import_module(module_path)
        except ModuleNotFoundError as exc:
            if not _is_target_module_missing(exc, module_path):
                # The task module exists but one of its imports is broken;
                # report the real cause instead of "task not found".
                raise ValueError(
                    f"Failed to import module '{module_path}' for task function '{name}': {exc}"
                ) from exc
        else:
            func = getattr(module, name, None)
            if func is not None:
                return func

        # Fall back to the tasks package __init__ (handles re-exports
        # like flutter_bug_fix from bug_fix.py)
        package = importlib.import_module("dash_evals.runner.tasks")
        func = getattr(package, name, None)
        if func is not None:
            return func

        raise ValueError(
            f"Could not find task function '{name}'. "
            f"Check that the function exists in dash_evals.runner.tasks "
            f"and is exported in __init__.py."
        )

    # Dotted path: last segment is the function name
    module_path, _, func_name = name.rpartition(".")
    if not module_path or not func_name:
        raise ValueError(
            f"Invalid task function path '{name}'. "
            f"Expected a short name or 'package.module.function'."
        )

    try:
        module = importlib.import_module(module_path)
    except ModuleNotFoundError as exc:
        if _is_target_module_missing(exc, module_path):
            raise ValueError(
                f"Could not find module '{module_path}' for task function '{name}'. "
                f"Check that the module exists and is importable."
            ) from exc
        raise ValueError(
            f"Failed to import module '{module_path}' for task function '{name}': {exc}"
        ) from exc

    func = getattr(module, func_name, None)
    if func is None:
        raise ValueError(f"Module '{module_path}' does not have a function '{func_name}'.")
    return func
def _run_single_manifest(manifest: dict) -> bool:
    """Run a single InspectEvalSet entry.

    Builds one Task per entry in ``manifest["tasks"]`` and passes every other
    manifest key straight through to ``inspect_ai.eval_set()``.

    Args:
        manifest: One eval-set definition. Must contain ``log_dir`` and
            ``tasks``; all remaining keys are forwarded as eval_set() kwargs.

    Returns:
        True if any task failed. That includes tasks that were skipped
        because they could not be resolved or instantiated. False only when
        every listed task was built and eval_set() reported success.
    """
    log_dir = manifest["log_dir"]
    Path(log_dir).mkdir(parents=True, exist_ok=True)
    job_logger, log_file_path = setup_logging(Path(log_dir), name="dash_evals")

    # Build Task objects from inline datasets
    task_instances: list[inspect_ai.Task] = []
    skipped = 0  # tasks listed in the manifest that never made it to eval_set()

    for raw_task_def in manifest["tasks"]:
        # Work on a shallow copy so the injected keys below do not leak back
        # into the caller's manifest.
        task_def = dict(raw_task_def)
        task_func_name = task_def.get("task_func")
        task_name = task_def.get("name", task_func_name or "(unknown)")

        if not task_func_name:
            # Mode 2: hydrate directly from JSON (future)
            job_logger.warning(
                f"  ⚠ {task_name}: no task_func — Mode 2 hydration not yet supported"
            )
            skipped += 1
            continue

        try:
            task_func = _resolve_task_func(task_func_name)
        except ValueError as e:
            job_logger.warning(f"  ✗ {task_name}: {e}")
            skipped += 1
            continue

        # Build inline dataset
        dataset = _build_dataset_from_inline(task_def)

        # Inject task_name into the config for task functions that expect it.
        # The Dart CLI emits "name" but task functions use "task_name".
        if "task_name" not in task_def and "name" in task_def:
            task_def["task_name"] = task_def["name"]

        # Inject sandbox_type for task functions that check it.
        # The Dart CLI emits "sandbox" as ["type", "path"] or a string,
        # but task functions check "sandbox_type".
        if "sandbox_type" not in task_def:
            sandbox = task_def.get("sandbox") or manifest.get("sandbox")
            if isinstance(sandbox, list) and len(sandbox) >= 1:
                task_def["sandbox_type"] = sandbox[0]
            elif isinstance(sandbox, str) and sandbox != "local":
                task_def["sandbox_type"] = sandbox

        try:
            task_instance = task_func(dataset, task_def)
        except Exception as e:
            job_logger.error(f"  ✗ {task_name}: {e}")
            skipped += 1
            continue

        task_instances.append(task_instance)
        job_logger.info(f"  ✓ {task_name} ({len(dataset)} samples)")

    if not task_instances:
        job_logger.warning("No valid tasks to run")
        return True

    if skipped:
        job_logger.warning(f"{skipped} task(s) could not be built and count as failures")

    # Build eval_set kwargs from remaining manifest keys
    eval_set_kwargs = {k: v for k, v in manifest.items() if k not in _NON_EVAL_SET_KEYS}

    # Convert sandbox list to tuple (eval_set expects tuple for ("type", "path"))
    sandbox = eval_set_kwargs.get("sandbox")
    if isinstance(sandbox, list) and len(sandbox) == 2:
        eval_set_kwargs["sandbox"] = tuple(sandbox)

    job_logger.info(f"\n{'=' * 70}\n🚀 RUNNING {len(task_instances)} TASKS\n{'=' * 70}")

    try:
        with capture_output(log_file_path):
            success, _ = inspect_ai.eval_set(
                tasks=task_instances,
                **eval_set_kwargs,
            )
    except Exception as e:
        job_logger.error(f"Evaluation failed: {e}")
        return True

    # A run is only clean when eval_set succeeded AND nothing was skipped.
    return (not success) or skipped > 0
@sandboxenv(name="podman")
class PodmanSandboxEnvironment(SandboxEnvironment):
    """Simple Podman-based sandbox environment.

    Each sample gets its own detached container started with ``podman run``.
    Commands and file transfers are performed through ``podman exec``.
    """

    def __init__(self, container_id: str, working_dir: str = "/workspace"):
        """Wrap an already-running container.

        Args:
            container_id: ID of the container that commands are executed in.
            working_dir: Directory used for ``exec`` when no ``cwd`` is given.
        """
        super().__init__()
        self.container_id = container_id
        self._working_dir = working_dir

    @classmethod
    def config_files(cls) -> list[str]:
        return ["compose.yaml", "Containerfile", "Dockerfile"]

    @classmethod
    def default_concurrency(cls) -> int | None:
        return (os.cpu_count() or 1) * 2

    @staticmethod
    async def _force_remove(container_id: str) -> None:
        """Run ``podman rm -f`` for *container_id* and wait for it to finish."""
        proc = await asyncio.create_subprocess_exec(
            "podman",
            "rm",
            "-f",
            container_id,
            stdout=asyncio.subprocess.DEVNULL,
            stderr=asyncio.subprocess.DEVNULL,
        )
        # Without waiting, cleanup would return while the removal is still in
        # flight and the child process would never be reaped.
        await proc.wait()

    @staticmethod
    def _raise_read_error(file: str, stderr: str) -> None:
        """Translate a failed read command into the matching exception."""
        if "No such file" in stderr:
            raise FileNotFoundError(f"File not found: {file}")
        if "permission denied" in stderr.lower():
            raise PermissionError(f"Permission denied reading {file}")
        raise RuntimeError(f"Failed to read file {file}: {stderr}")

    @staticmethod
    def _raise_write_error(file: str, stderr: str) -> None:
        """Translate a failed write command into the matching exception."""
        if "permission denied" in stderr.lower():
            raise PermissionError(f"Permission denied writing {file}")
        raise RuntimeError(f"Failed to write file {file}: {stderr}")

    @classmethod
    async def task_init(cls, task_name: str, config: SandboxEnvironmentConfigType | None) -> None:
        """Validate podman is available.

        Raises:
            RuntimeError: If the podman executable is missing or
                ``podman --version`` exits with a non-zero status.
        """
        try:
            proc = await asyncio.create_subprocess_exec(
                "podman",
                "--version",
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE,
            )
        except FileNotFoundError as exc:
            raise RuntimeError(
                "Podman executable not found. Please ensure podman is installed and in your PATH."
            ) from exc

        _, stderr = await proc.communicate()
        if proc.returncode != 0:
            raise RuntimeError(f"Podman check failed: {stderr.decode()}")

    @classmethod
    async def sample_init(
        cls,
        task_name: str,
        config: SandboxEnvironmentConfigType | None,
        metadata: dict[str, str],
    ) -> dict[str, SandboxEnvironment]:
        """Start a container for this sample.

        Returns:
            A mapping with a single ``"default"`` environment.

        Raises:
            RuntimeError: If podman is missing or the container fails to start.
        """
        # Determine image from config or use default. A string config that is
        # not a YAML path is treated as an image name.
        image = DEFAULT_IMAGE
        if isinstance(config, str) and not config.endswith((".yaml", ".yml")):
            image = config

        # Start container (no TTY to avoid control chars, sleep to keep running)
        # Mount /tmp so workspace files copied by setup_workspace are accessible
        tmp_dir = tempfile.gettempdir()
        cmd = [
            "podman",
            "run",
            "-d",
            "--rm",
            "-v",
            f"{tmp_dir}:{tmp_dir}",  # Mount temp dir for workspace sharing
            image,
            "sleep",
            "infinity",
        ]

        try:
            proc = await asyncio.create_subprocess_exec(
                *cmd,
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE,
            )
        except FileNotFoundError as exc:
            raise RuntimeError(
                "Podman executable not found. Please ensure podman is installed and in your PATH."
            ) from exc

        stdout, stderr = await proc.communicate()
        if proc.returncode != 0:
            raise RuntimeError(f"Failed to start podman container: {stderr.decode()}")

        container_id = stdout.decode().strip()
        logger.info(f"Started podman container: {container_id[:12]}")

        return {"default": cls(container_id=container_id)}

    @classmethod
    async def sample_cleanup(
        cls,
        task_name: str,
        config: SandboxEnvironmentConfigType | None,
        environments: dict[str, SandboxEnvironment],
        interrupted: bool,
    ) -> None:
        """Stop and remove containers."""
        for env in environments.values():
            if isinstance(env, PodmanSandboxEnvironment):
                logger.info(f"Cleaning up container: {env.container_id[:12]}")
                await cls._force_remove(env.container_id)

    @classmethod
    async def task_cleanup(
        cls, task_name: str, config: SandboxEnvironmentConfigType | None, cleanup: bool
    ) -> None:
        """No task-level cleanup needed - containers are removed per-sample."""
        pass

    @classmethod
    async def cli_cleanup(cls, id: str | None) -> None:
        """CLI cleanup for orphaned containers."""
        if id:
            await cls._force_remove(id)

    async def exec(
        self,
        cmd: list[str],
        input: str | bytes | None = None,
        cwd: str | None = None,
        env: dict[str, str] | None = None,
        user: str | None = None,
        timeout: int | None = None,
        timeout_retry: bool = True,
        concurrency: bool = True,
        truncate: bool = True,
    ) -> ExecResult[str]:
        """Execute command inside the container.

        Args:
            cmd: Command and arguments to run.
            input: Optional data written to the command's stdin.
            cwd: Working directory; defaults to the sandbox working dir.
            env: Extra environment variables for the command.
            user: User to run the command as.
            timeout: Seconds to wait before killing the command.
            timeout_retry: Accepted for interface compatibility; not used here.
            concurrency: Accepted for interface compatibility; not used here.
            truncate: Cut stdout/stderr at the sandbox output size limit.

        Returns:
            ExecResult with decoded stdout/stderr and the exit status.

        Raises:
            TimeoutError: If the command does not finish within ``timeout``.
        """
        podman_cmd = ["podman", "exec", "-i"]

        # Working directory
        podman_cmd.extend(["--workdir", cwd if cwd else self._working_dir])

        # User
        if user:
            podman_cmd.extend(["--user", user])

        # Environment variables
        for key, value in (env or {}).items():
            podman_cmd.extend(["--env", f"{key}={value}"])

        podman_cmd.append(self.container_id)
        podman_cmd.extend(cmd)

        proc = await asyncio.create_subprocess_exec(
            *podman_cmd,
            stdin=asyncio.subprocess.PIPE,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
        )

        stdin_data = input.encode() if isinstance(input, str) else input
        try:
            stdout, stderr = await asyncio.wait_for(
                proc.communicate(input=stdin_data),
                timeout=timeout,
            )
        except asyncio.TimeoutError:
            try:
                proc.kill()
            except ProcessLookupError:
                pass
            # Reap the killed process so it does not linger as a zombie.
            await proc.wait()
            raise TimeoutError(f"Command timed out after {timeout}s") from None

        stdout_decoded = stdout.decode("utf-8", errors="replace")
        stderr_decoded = stderr.decode("utf-8", errors="replace")

        # Truncate if too large
        if truncate:
            limit = SandboxEnvironmentLimits.MAX_EXEC_OUTPUT_SIZE
            if len(stdout_decoded) > limit:
                stdout_decoded = stdout_decoded[:limit] + "...[TRUNCATED]"
            if len(stderr_decoded) > limit:
                stderr_decoded = stderr_decoded[:limit] + "...[TRUNCATED]"

        return ExecResult(
            success=(proc.returncode == 0),
            returncode=proc.returncode or 0,
            stdout=stdout_decoded,
            stderr=stderr_decoded,
        )

    async def write_file(self, file: str, contents: str | bytes) -> None:
        """Write file to container.

        Raises:
            PermissionError: If the container reports "permission denied".
            RuntimeError: For any other failure to create the directory or file.
        """
        # Ensure directory exists
        dir_path = Path(file).parent.as_posix()
        if dir_path and dir_path != ".":
            mkdir_result = await self.exec(["mkdir", "-p", "--", dir_path])
            if not mkdir_result.success:
                self._raise_write_error(file, mkdir_result.stderr)

        # The path is passed as a positional argument ("$1") rather than being
        # interpolated into the script, so quotes, "$" or backticks in a file
        # name cannot alter the shell command.
        if isinstance(contents, bytes):
            # Binary content travels base64-encoded via stdin (not command line)
            result = await self.exec(
                ["sh", "-c", 'base64 -d > "$1"', "sh", file],
                input=base64.b64encode(contents).decode("ascii"),
            )
        else:
            result = await self.exec(
                ["sh", "-c", 'cat > "$1"', "sh", file],
                input=contents,
            )

        if not result.success:
            self._raise_write_error(file, result.stderr)

    @overload
    async def read_file(self, file: str, text: Literal[True] = True) -> str: ...

    @overload
    async def read_file(self, file: str, text: Literal[False]) -> bytes: ...

    async def read_file(self, file: str, text: bool = True) -> Union[str, bytes]:
        """Read file from container.

        Args:
            file: Path of the file inside the container.
            text: Return ``str`` when True, raw ``bytes`` when False.

        Raises:
            FileNotFoundError: If the file does not exist.
            PermissionError: If the container reports "permission denied".
            RuntimeError: On any other failure or when the size limit is exceeded.
        """
        if text:
            # Text mode: use cat directly ("--" keeps a leading "-" in the
            # file name from being parsed as an option)
            result = await self.exec(["cat", "--", file], truncate=False)
            if not result.success:
                self._raise_read_error(file, result.stderr)

            if len(result.stdout) > SandboxEnvironmentLimits.MAX_READ_FILE_SIZE:
                raise RuntimeError(f"File {file} exceeds size limit")
            return result.stdout

        # Binary mode: use base64 to transfer safely (-w0 disables line wrapping)
        result = await self.exec(
            ["sh", "-c", 'base64 -w0 -- "$1"', "sh", file],
            truncate=False,
        )
        if not result.success:
            self._raise_read_error(file, result.stderr)

        decoded = base64.b64decode(result.stdout.strip())
        if len(decoded) > SandboxEnvironmentLimits.MAX_READ_FILE_SIZE:
            raise RuntimeError(f"File {file} exceeds size limit")
        return decoded

    async def connection(self, *, user: str | None = None) -> SandboxConnection:
        """Get connection info for debugging."""
        cmd_parts = ["podman", "exec", "-it"]
        if user:
            cmd_parts.extend(["--user", user])
        cmd_parts.extend([self.container_id, "/bin/bash"])

        return SandboxConnection(
            type="podman",
            command=" ".join(cmd_parts),
            vscode_command=None,
            ports=None,
            container=self.container_id,
        )

    def default_polling_interval(self) -> float:
        return 0.2
def podman():
    """Return the ``PodmanSandboxEnvironment`` class.

    The import happens inside the function, so the podman sandbox module is
    only loaded when this provider is actually called.
    """
    # NOTE(review): presumably registered as a sandbox provider hook — confirm
    # against the package's entry-point configuration.
    from dash_evals.runner.sandboxes.podman.podman import (
        PodmanSandboxEnvironment as environment_cls,
    )

    return environment_cls
@scorer(metrics=[mean(), stderr()])
def code_quality_scorer(rubric: str | None = None, model: str | None = None) -> Scorer:
    """Score code quality using LLM judgment.

    The grader model is prompted with a rubric plus the task's completion and
    is expected to answer with a JSON object holding ``minimality``,
    ``elegance`` and ``robustness`` (each 0-3) and a ``reasoning`` string.
    The three values are summed and divided by 9 to give a 0-1 score.

    Args:
        rubric: Custom rubric prompt. If None (or empty), ``DEFAULT_RUBRIC``
            is used.
        model: Model to use for grading. If None, ``get_model()`` is called
            without arguments.

    Returns:
        A Scorer whose value is in the range 0.0-1.0. Unparseable or failed
        grading yields 0.0 with the reason in the explanation.
    """
    grading_rubric = rubric or DEFAULT_RUBRIC
    criteria = ("minimality", "elegance", "robustness")
    max_per_criterion = 3

    def _criterion_value(raw: object) -> int | float:
        """Coerce one grader-supplied value to a number clamped to 0..3."""
        if raw is None or isinstance(raw, bool):
            return 0
        if isinstance(raw, str):
            # Graders sometimes quote the numbers ("3"); accept that form.
            try:
                raw = float(raw.strip())
            except ValueError:
                return 0
        if not isinstance(raw, (int, float)) or raw != raw:  # raw != raw: NaN
            return 0
        # Clamp so an out-of-rubric value cannot push the score outside 0-1.
        return min(max(raw, 0), max_per_criterion)

    async def score(state: TaskState, target: Target) -> Score:
        code = state.output.completion

        # Build grading prompt
        prompt = f"{grading_rubric}\n\nCode to evaluate:\n```dart\n{code}\n```"

        # Get grader model
        grader = get_model(model) if model else get_model()

        try:
            result = await grader.generate(prompt)
            response_text = result.completion

            # Parse JSON from response
            scores = _parse_json_response(response_text)

            # The parser can hand back any JSON value (list, number, ...);
            # only an object is a usable rubric answer.
            if not isinstance(scores, dict):
                return Score(
                    value=0.0,
                    explanation=f"Failed to parse grader response: {response_text[:500]}",
                    metadata={"raw_response": response_text},
                )

            values = {name: _criterion_value(scores.get(name)) for name in criteria}
            # Max possible is 9 (3 + 3 + 3)
            normalized_score = sum(values.values()) / float(
                max_per_criterion * len(criteria)
            )

            return Score(
                value=normalized_score,
                explanation=scores.get("reasoning", "No reasoning provided"),
                metadata={**values, "raw_response": response_text},
            )

        except Exception as e:
            # Deliberate best-effort boundary: a grader failure must not
            # abort the evaluation, so it is reported as a zero score.
            return Score(
                value=0.0,
                explanation=f"Grading failed: {e!s}",
                metadata={"error": str(e)},
            )

    return score
@scorer(metrics=[accuracy()])
def dart_analyze_scorer(strict: bool = False, project_dir: str | None = None) -> Scorer:
    """Score based on dart static analysis results.

    Scoping behavior (in priority order):

    1. If ``project_dir`` argument is set, analyze only that subdirectory.
    2. If ``state.metadata["project_dir"]`` exists, use that.
    3. Fall back to auto-discovering all ``pubspec.yaml`` files.

    Scores:
        - CORRECT if no errors in any project (and no warnings if strict=True)
        - INCORRECT if any project has errors, if ``dart analyze`` exits
          non-zero, or (strict only) if any project has warnings

    Info-level issues never fail the score; they are counted and reported.

    Args:
        strict: If True, also fail on warnings. Default False (only errors fail).
        project_dir: Optional subdirectory to scope analysis to.
            Relative to the workspace root.

    Returns:
        A Scorer that evaluates Dart code quality via static analysis.
    """

    async def score(state: TaskState, target: Target) -> Score:
        sb = sandbox()
        workspace = state.metadata.get("workspace")

        if not workspace:
            return Score(
                value=INCORRECT,
                explanation="No workspace found - setup may have failed",
            )

        # Determine target project directory(ies)
        scope = project_dir or state.metadata.get("project_dir")

        if scope:
            # Scoped to a specific project subdirectory
            project_dirs = [scope]
        else:
            # Discover all Dart/Flutter projects by finding pubspec.yaml files
            # (hidden directories are skipped).
            find_result = await sb.exec(
                ["find", ".", "-name", "pubspec.yaml", "-not", "-path", "*/.*"],
                cwd=workspace,
                timeout=30,
            )

            pubspec_paths = [
                p.strip() for p in (find_result.stdout or "").splitlines() if p.strip()
            ]

            if not pubspec_paths:
                # Fallback: try analyzing workspace root directly
                pubspec_paths = ["."]

            # Derive project directories from pubspec.yaml paths
            project_dirs = sorted({os.path.dirname(p) or "." for p in pubspec_paths})

        # `dart analyze` treats warnings as fatal by default, which would make
        # the exit code non-zero for warning-only output and fail non-strict
        # runs. Disable that so the exit code reflects errors (or a tool
        # failure) only; warnings are detected from the output text instead,
        # which is what the strict check below relies on.
        analyze_cmd = ["dart", "analyze", "--no-fatal-warnings", "."]

        # Run dart analyze in each project directory
        all_outputs: list[str] = []
        has_errors = False
        has_warnings = False

        for proj_dir in project_dirs:
            project_cwd = os.path.join(workspace, proj_dir)

            result = await sb.exec(analyze_cmd, cwd=project_cwd, timeout=60)

            output = (result.stdout or "") + (result.stderr or "")
            output_lower = output.lower()

            # Tag output with the project directory for clarity
            all_outputs.append(f"[{proj_dir}] {output.strip()}")

            if "error •" in output_lower or result.returncode != 0:
                has_errors = True
            if "warning •" in output_lower:
                has_warnings = True

        combined = "\n\n".join(all_outputs)

        if has_errors:
            return Score(
                value=INCORRECT,
                explanation=f"Static analysis failed:\n{combined[:2000]}",
                metadata={
                    "analyze_output": combined,
                    "projects_analyzed": project_dirs,
                },
            )

        if strict and has_warnings:
            return Score(
                value=INCORRECT,
                explanation=f"Static analysis has warnings (strict mode):\n{combined[:2000]}",
                metadata={
                    "analyze_output": combined,
                    "projects_analyzed": project_dirs,
                },
            )

        # Count info-level issues across all projects
        info_count = combined.lower().count("info •")

        return Score(
            value=CORRECT,
            explanation=f"Static analysis passed across {len(project_dirs)} project(s) "
            f"({info_count} info-level issues)",
            metadata={
                "analyze_output": combined,
                "info_count": info_count,
                "projects_analyzed": project_dirs,
            },
        )

    return score
@scorer(metrics=[])
def export_workspace() -> Scorer:
    """Copy the final workspace to the examples directory.

    Reads ``workspace``, ``examples_dir`` and ``task_variant`` from
    ``state.metadata`` and exports the workspace to::

        <examples_dir>/<task_variant>/<sample_id>/

    ``task_variant`` and the sample id each fall back to ``"unknown"`` when
    absent.  The copy itself is delegated to ``_export_via_sandbox``.

    The returned score is always CORRECT: a missing ``examples_dir`` or
    ``workspace`` skips the export, and an export error is recorded in the
    score's explanation/metadata rather than raised.
    """

    async def score(state: TaskState, target: Target) -> Score:
        meta = state.metadata
        workspace = meta.get("workspace")

        if not meta.get("examples_dir"):
            return Score(
                value=CORRECT,
                explanation="examples_dir not set in metadata — skipping export",
            )

        if not workspace:
            return Score(
                value=CORRECT,
                explanation="No workspace in metadata — nothing to export",
            )

        # Destination: <examples_dir>/<task_variant>/<sample_id>
        variant_name = meta.get("task_variant", "unknown")
        sample_name = "unknown" if state.sample_id is None else str(state.sample_id)
        dest = Path(meta.get("examples_dir")) / variant_name / sample_name

        try:
            dest.mkdir(parents=True, exist_ok=True)
            await _export_via_sandbox(workspace, dest)
        except Exception as e:
            # Export is a side effect only; never let it fail the sample.
            return Score(
                value=CORRECT,
                explanation=f"Export failed (non-fatal): {e}",
                metadata={"export_error": str(e)},
            )
        else:
            return Score(
                value=CORRECT,
                explanation=f"Workspace exported to {dest}",
                metadata={"exported_to": str(dest)},
            )

    return score
@scorer(metrics=[accuracy()])
def flutter_code_scorer() -> Scorer:
    """Score Flutter code from analysis, tests and structure.

    The scorer combines three signals via ``calculate_final_score``:

    1. Code analysis (``flutter analyze --no-pub``)
    2. Test results (``flutter test --no-pub``)
    3. Code structure validation against ``required_widgets`` metadata

    Default weights (see ``calculate_final_score``): analyzer 30%,
    tests 50%, structure 20%.

    Returns:
        A Scorer whose value is the weighted score (0.0-1.0). A recorded
        ``setup_error`` or a missing workspace yields 0.0.
    """

    async def score(state: TaskState, target: Target) -> Score:
        # Check for setup errors first
        if setup_error := state.metadata.get("setup_error"):
            return Score(
                value=0.0,
                answer="",
                explanation=f"āœ— Setup failed: {setup_error}",
                metadata={"setup_error": setup_error},
            )

        sb = sandbox()
        workspace = state.metadata.get("workspace")

        if not workspace:
            return Score(value=0.0, explanation="No workspace found - setup may have failed")

        explanation_parts = []

        # 1. Run flutter analyze
        analyze_result = await sb.exec(["flutter", "analyze", "--no-pub"], cwd=workspace)
        analyze_output = (analyze_result.stdout or "") + (analyze_result.stderr or "")
        analyzer_result = parse_analyzer_output(analyze_output)

        # `flutter analyze` exits non-zero when it reports any issue, including
        # warnings and infos. Gating purely on the exit status therefore made
        # the graded tiers unreachable for anything but a clean run. Grade the
        # parsed output whenever the run succeeded or reported only non-error
        # issues; reported errors (or an unexplained failure) still score zero.
        only_minor_issues = analyzer_result.error_count == 0 and (
            analyzer_result.warning_count > 0 or analyzer_result.info_count > 0
        )

        if analyze_result.success or only_minor_issues:
            analyzer_score, analyzer_explanation = calculate_analyzer_score(analyzer_result)
            explanation_parts.append(analyzer_explanation)
        else:
            analyzer_score = 0.0
            explanation_parts.append("āœ— Code analysis failed (syntax errors)")

        # 2. Run flutter test
        test_result = await sb.exec(["flutter", "test", "--no-pub"], cwd=workspace)
        test_output = (test_result.stdout or "") + (test_result.stderr or "")
        test_result_parsed = parse_test_output(test_output, test_result.success)
        test_score, test_explanation = calculate_test_score(test_result_parsed)
        explanation_parts.append(test_explanation)

        # 3. Validate code structure
        code = state.metadata.get("generated_code", "")
        required_widgets = state.metadata.get("required_widgets", [])
        structure_score, structure_explanation = validate_code_structure(code, required_widgets)
        explanation_parts.append(structure_explanation)

        # Calculate final score
        final_score = calculate_final_score(analyzer_score, test_score, structure_score)

        return Score(
            value=final_score,  # Return actual weighted score (0.0-1.0)
            answer=state.output.completion[:200] + "...",
            explanation="\n".join(explanation_parts),
            metadata={
                "analyzer_score": analyzer_score,
                "test_score": test_score,
                "structure_score": structure_score,
                "final_score": final_score,
                "analyzer_output": analyze_result.stdout if analyze_result else "",
                "test_output": test_result.stdout if test_result else "",
            },
        )

    return score
def parse_test_output(output: str, success: bool) -> TestResult:
    """Parse flutter test output to determine test results.

    A run counts as passed when the command succeeded or the output contains
    "All tests passed" / "all tests passed".  Otherwise the pass/fail counts
    are taken from the LAST progress counter in the output (for example the
    ``+3 -2`` in ``00:05 +3 -2: Some tests failed.``), so the initial
    ``+0: loading ...`` line no longer hides the final totals.

    Args:
        output: Combined stdout and stderr from flutter test
        success: Whether the test command succeeded

    Returns:
        TestResult with pass/fail status, counts (0 when none can be found)
        and the raw output.

    Examples:
        >>> result = parse_test_output("All tests passed!", success=True)
        >>> result.passed
        True
        >>> result = parse_test_output("00:00 +0: loading\\n00:02 +3 -2: Some tests failed.", False)
        >>> (result.passed_count, result.failed_count)
        (3, 2)
    """
    import re

    if success or "All tests passed" in output or "all tests passed" in output:
        return TestResult(passed=True, raw_output=output)

    if "No tests" in output:
        return TestResult(passed=False, passed_count=0, failed_count=0, raw_output=output)

    # The reporter rewrites its progress line with carriage returns; normalize
    # them so every counter is visible to the regex.
    normalized = output.replace("\r\n", "\n").replace("\r", "\n")

    # Progress counter: "+<passed>", optional "~<skipped>", optional
    # "-<failed>".  The lookbehind/lookahead keep it from matching inside
    # words, versions or dates (e.g. "foo-1", "2024-01").
    counters = re.findall(
        r"(?<![\w.])\+(\d+)(?:[ \t]+~\d+)?(?:[ \t]+-(\d+))?(?=:|\s|$)",
        normalized,
    )

    passed_count = 0
    failed_count = 0
    if counters:
        # The last counter carries the final totals for the run.
        passed_text, failed_text = counters[-1]
        passed_count = int(passed_text)
        failed_count = int(failed_text) if failed_text else 0

    return TestResult(
        passed=False,
        passed_count=passed_count,
        failed_count=failed_count,
        raw_output=output,
    )
def calculate_test_score(result: TestResult) -> Tuple[float, str]:
    """Turn a parsed test result into a score and a one-line explanation.

    A result flagged ``passed`` scores 1.0 outright.  Otherwise the score is
    the fraction ``passed_count / (passed_count + failed_count)``, or 0.0 when
    no tests were counted at all.

    Args:
        result: Parsed test output

    Returns:
        Tuple of (score, explanation) where score is in the range 0.0-1.0

    Examples:
        >>> from dash_evals.runner.scorers.flutter_output_parser import TestResult
        >>> result = TestResult(passed=True, raw_output="")
        >>> score, explanation = calculate_test_score(result)
        >>> score
        1.0
    """
    if result.passed:
        return 1.0, "āœ“ All tests passed"

    # Work from the explicit counts when the run was not a clean pass.
    executed = result.passed_count + result.failed_count
    if executed == 0:
        return 0.0, "āœ— No tests found or executed"

    ratio = result.passed_count / executed

    if ratio == 1.0:
        return 1.0, f"āœ“ All tests passed ({result.passed_count}/{executed})"
    if ratio == 0.0:
        return 0.0, f"āœ— All tests failed (0/{executed})"

    summary = f"⚠ {result.passed_count}/{executed} tests passed ({ratio:.0%})"
    return ratio, summary
def calculate_final_score(
    analyzer_score: float,
    test_score: float,
    structure_score: float,
    weights: dict | None = None,
) -> float:
    """Combine the three component scores into one weighted score.

    Args:
        analyzer_score: Code quality score (0.0-1.0)
        test_score: Test pass rate (0.0-1.0)
        structure_score: Code structure score (0.0-1.0)
        weights: Optional mapping with the keys ``"analyzer"``, ``"test"`` and
            ``"structure"``.  When None, 0.3 / 0.5 / 0.2 is used.

    Returns:
        The sum of each score multiplied by its weight.

    Raises:
        KeyError: If a custom ``weights`` mapping lacks one of the three keys.

    Examples:
        >>> calculate_final_score(1.0, 1.0, 1.0)
        1.0
        >>> calculate_final_score(0.0, 1.0, 0.0)  # Test score is 50% of total
        0.5
    """
    if weights is None:
        weights = {"analyzer": 0.3, "test": 0.5, "structure": 0.2}

    # Accumulate in a fixed order (analyzer, test, structure) so the floating
    # point result does not depend on summation strategy.
    weighted_total = analyzer_score * weights["analyzer"]
    weighted_total = weighted_total + test_score * weights["test"]
    weighted_total = weighted_total + structure_score * weights["structure"]
    return weighted_total
+ """ + + async def score(state: TaskState, target: Target) -> Score: + sb = sandbox() + workspace = state.metadata.get("workspace") + + if not workspace: + return Score( + value=0.0, + explanation="No workspace found - setup may have failed", + ) + + # Run flutter test + # Scope to project_dir if set in metadata (for multi-project repos) + cwd = workspace + metadata_project_dir = state.metadata.get("project_dir") + if metadata_project_dir: + import os + + cwd = os.path.join(workspace, metadata_project_dir) + + result = await sb.exec( + ["flutter", "test", test_path, "--no-pub"], + cwd=cwd, + timeout=180, + ) + + stdout = result.stdout or "" + stderr = result.stderr or "" + output = stdout + stderr + + # Parse test results + test_info = _parse_test_output(output) + total_tests = test_info["passed"] + test_info["failed"] + + if total_tests == 0: + return Score( + value=0.0, + explanation="No tests found or executed", + metadata={"test_output": output, "passed": 0, "failed": 0}, + ) + + pass_rate = test_info["passed"] / total_tests + + if result.returncode == 0: + return Score( + value=1.0, + explanation=f"All tests passed ({test_info['passed']} tests)", + metadata={ + "test_output": output, + "passed": test_info["passed"], + "failed": 0, + "pass_rate": 1.0, + }, + ) + else: + return Score( + value=pass_rate, # Return actual percentage + explanation=f"{test_info['passed']}/{total_tests} tests passed ({pass_rate:.0%}):\n{output[:1500]}", + metadata={ + "test_output": output, + "passed": test_info["passed"], + "failed": test_info["failed"], + "pass_rate": pass_rate, + }, + ) + + return score + + +def _parse_test_output(output: str) -> dict: + """Parse flutter test output to extract pass/fail counts.""" + import re + + # Normalize carriage returns to make regex work on all line endings + output = output.replace("\r\n", "\n").replace("\r", "\n") + + # Look for patterns like "+3 -1" or "+5" + # Format: "00:04 +3 -1: Some tests failed" (find the LAST occurrence) + matches 
= re.findall(r"\+(\d+)(?:\s+-(\d+))?:", output) + + if matches: + # Take the last match - this gives the final test counts + last_match = matches[-1] + passed = int(last_match[0]) + failed = int(last_match[1]) if last_match[1] else 0 + return {"passed": passed, "failed": failed} + + return {"passed": 0, "failed": 0} diff --git a/packages/dash_evals/src/dash_evals/runner/scorers/mcp_tool_usage.py b/packages/dash_evals/src/dash_evals/runner/scorers/mcp_tool_usage.py new file mode 100644 index 0000000..6fdce4a --- /dev/null +++ b/packages/dash_evals/src/dash_evals/runner/scorers/mcp_tool_usage.py @@ -0,0 +1,162 @@ +"""Scorer for verifying MCP tool usage during evaluations.""" + +from inspect_ai.scorer import Score, Scorer, Target, accuracy, scorer +from inspect_ai.solver import TaskState + +# Complete list of Dart MCP server tools from: +# https://github.com/dart-lang/ai/tree/main/pkgs/dart_mcp_server +DART_MCP_TOOLS: set[str] = { + "add_roots", + "analyze_files", + "connect_dart_tooling_daemon", + "create_project", + "dart_fix", + "dart_format", + "flutter_driver", + "get_active_location", + "get_app_logs", + "get_runtime_errors", + "get_selected_widget", + "get_widget_tree", + "hot_reload", + "hot_restart", + "hover", + "launch_app", + "list_devices", + "list_running_apps", + "pub", + "pub_dev_search", + "read_package_uris", + "remove_roots", + "resolve_workspace_symbol", + "rip_grep_packages", + "run_tests", + "set_widget_selection_mode", + "signature_help", + "stop_app", +} + + +@scorer(metrics=[accuracy()]) +def mcp_tool_usage( + mcp_server_name: str = "Dart", + mcp_tool_names: list[str] | None = None, + required_tools: list[str] | None = None, +) -> Scorer: + """ + Scorer that checks if an MCP tool from the specified server was called. + + This scorer examines the message history to determine whether the model + actually used an MCP tool (vs. answering from its training data). + + Args: + mcp_server_name: The name prefix of the MCP server tools. 
Tools matching + "{mcp_server_name}_*" pattern will be identified as + MCP tools. + mcp_tool_names: Optional list of specific tool names to identify as MCP + tools. If not provided and mcp_server_name is "Dart", + defaults to the full DART_MCP_TOOLS list. + required_tools: Optional list of specific MCP tool names that MUST have + been called for a passing score. If provided, the scorer + checks that every tool in this list was used. If not + provided, any MCP tool usage counts as a pass. + + Returns: + A Scorer that returns "C" if MCP tool(s) were used as required, "I" otherwise. + + Example:: + + from dash_evals.scorers import mcp_tool_usage + + Task( + dataset=my_dataset, + solver=react(), + tools=[dart_mcp_server], + scorer=[ + includes(ignore_case=True), # Check answer correctness + mcp_tool_usage(), # Uses DART_MCP_TOOLS by default + # Or check specific tools: + # mcp_tool_usage(required_tools=["create_project"]), + ], + ) + """ + # Default to DART_MCP_TOOLS for Dart server, otherwise use provided list + if mcp_tool_names is not None: + known_mcp_tools = set(mcp_tool_names) + elif mcp_server_name == "Dart": + known_mcp_tools = DART_MCP_TOOLS + else: + known_mcp_tools = set() + + async def score(state: TaskState, target: Target) -> Score: + # Track all tools called and whether MCP tool was used + tools_called: list[str] = [] + mcp_tool_used = False + mcp_tools_called: list[str] = [] + + # Look through all messages for tool calls + for message in state.messages: + # Check if message has tool_calls attribute (assistant messages with tool use) + if hasattr(message, "tool_calls") and message.tool_calls: + for tool_call in message.tool_calls: + tool_name = tool_call.function + tools_called.append(tool_name) + + # Check if this is an MCP tool: + # 1. Prefixed with server name (e.g., "Dart_search_packages") + # 2. 
OR in the explicit list of known MCP tool names + is_mcp_tool = tool_name.startswith(f"{mcp_server_name}_") or ( + tool_name in known_mcp_tools + ) + if is_mcp_tool: + mcp_tool_used = True + mcp_tools_called.append(tool_name) + + # Check required_tools if specified + if required_tools: + mcp_tools_called_set = set(mcp_tools_called) + missing_tools = [t for t in required_tools if t not in mcp_tools_called_set] + if missing_tools: + explanation = ( + f"Required MCP tool(s) NOT used: {missing_tools}. " + f"MCP tools called: {mcp_tools_called if mcp_tools_called else 'none'}. " + f"All tools called: {tools_called if tools_called else 'none'}" + ) + return Score( + value="I", + answer=", ".join(mcp_tools_called) if mcp_tools_called else "none", + explanation=explanation, + metadata={ + "mcp_server_name": mcp_server_name, + "mcp_tool_used": mcp_tool_used, + "mcp_tools_called": mcp_tools_called, + "all_tools_called": tools_called, + "required_tools": required_tools, + "missing_tools": missing_tools, + }, + ) + + # Build explanation + if mcp_tool_used: + explanation = ( + f"MCP tool(s) from '{mcp_server_name}' server were used: {mcp_tools_called}" + ) + else: + explanation = ( + f"MCP tool from '{mcp_server_name}' server was NOT used. 
@scorer(metrics=[accuracy()])
def skill_usage_scorer() -> Scorer:
    """Build a scorer that passes when the agent called the skill tool.

    The scorer walks the message history, collects the name of every tool
    call it finds, and counts how many of them are ``SKILL_TOOL_NAME``.

    Returns:
        A Scorer whose value is "C" when at least one skill call was made and
        "I" otherwise. The score metadata records whether the skill tool was
        used, how many times, and every tool name that was called.

    Example::

        from dash_evals.runner.scorers import skill_usage_scorer

        Task(
            dataset=my_dataset,
            solver=react(tools=[skill_tool, bash(timeout=120)]),
            scorer=[
                model_graded_fact(),   # Check answer correctness
                skill_usage_scorer(),  # Check skill tool was used
            ],
        )
    """

    async def score(state: TaskState, target: Target) -> Score:
        # Flatten every tool call in the transcript into a list of tool names.
        # Messages without a (non-empty) tool_calls attribute contribute nothing.
        invoked: list[str] = [
            call.function
            for msg in state.messages
            for call in (getattr(msg, "tool_calls", None) or [])
        ]

        n_skill_calls = invoked.count(SKILL_TOOL_NAME)
        used = n_skill_calls > 0

        if used:
            explanation = f"Skill tool was used ({n_skill_calls} call(s))"
        else:
            explanation = (
                f"Skill tool was NOT used. "
                f"All tools called: {invoked or 'none'}"
            )

        return Score(
            value="C" if used else "I",
            answer=f"{n_skill_calls} skill call(s)",
            explanation=explanation,
            metadata={
                "skill_tool_used": used,
                "skill_call_count": n_skill_calls,
                "all_tools_called": invoked,
            },
        )

    return score
@solver
def context_injector(context_files: list[dict]) -> Solver:
    """
    Inject context files into the conversation.

    This solver renders the context files (like Dart/Flutter best practices)
    into one user message and inserts it after the leading system message(s)
    but before the main prompt.

    Args:
        context_files: List of context file dicts with 'title', 'version',
            'content' keys. Missing keys fall back to "Untitled", "0.0" and
            an empty string respectively.

    Returns:
        A solver that injects context files into the conversation. When
        ``context_files`` is empty the state is returned untouched.
    """

    async def solve(state: TaskState, generate: Generate) -> TaskState:
        if not context_files:
            return state

        # Build context content from all context files
        context_parts = ["## Additional Guidelines\n"]
        for cf in context_files:
            title = cf.get("title", "Untitled")
            version = cf.get("version", "0.0")
            context_parts.append(f"\n### {title} (v{version})\n")
            context_parts.append(cf.get("content", ""))

        context_message = "\n".join(context_parts)

        # Find the slot right after the leading system message(s). A fixed
        # index of 1 would place the context *after* the user prompt whenever
        # the conversation has no system message.
        insert_at = 0
        for message in state.messages:
            if getattr(message, "role", None) != "system":
                break
            insert_at += 1

        state.messages.insert(insert_at, ChatMessageUser(content=context_message))

        return state

    return solve
@solver
def inject_test_files() -> Solver:
    """Inject test files into the workspace.

    Reads the path stored in ``metadata["tests"]`` on the host (glob patterns
    such as ``tests/*`` are expanded) and writes every matching regular file
    to ``<workspace>/test/`` using the sandbox API.

    Returns:
        A solver that returns the state unchanged when a previous solver set
        ``metadata["setup_error"]`` or when no tests are configured, and that
        records its own ``setup_error`` when the workspace is missing, the
        test path does not exist, or a test file cannot be read.
    """

    async def solve(state: TaskState, generate) -> TaskState:
        # Early termination if previous solver failed
        if state.metadata.get("setup_error"):
            return state

        tests_path_str = state.metadata.get("tests")
        workspace_path_str = state.metadata.get("workspace")

        if not tests_path_str:
            return state

        if not workspace_path_str:
            state.metadata["setup_error"] = "No workspace path in metadata"
            return state

        # Expand glob patterns. Sorting makes the write order deterministic:
        # glob order is filesystem dependent, and two matches with the same
        # file name map to the same target, so the last write wins.
        test_files = sorted(glob.glob(tests_path_str))
        if not test_files:
            # Try as a literal path if glob returns nothing (e.g. the path
            # itself contains glob metacharacters).
            tests_path = Path(tests_path_str)
            if not tests_path.exists():
                state.metadata["setup_error"] = f"Test file not found: {tests_path}"
                return state
            test_files = [tests_path_str]

        sb = sandbox()

        for test_file_path in test_files:
            tests_path = Path(test_file_path)
            if not tests_path.is_file():
                continue

            # Decode explicitly as UTF-8 rather than the host locale encoding,
            # and report unreadable files through setup_error like every other
            # failure in this solver instead of raising.
            try:
                test_content = tests_path.read_text(encoding="utf-8")
            except (OSError, UnicodeDecodeError) as err:
                state.metadata["setup_error"] = (
                    f"Could not read test file {tests_path}: {err}"
                )
                return state

            # Prefix with 'sample_' to avoid overwriting existing workspace tests.
            # Files already prefixed are left as-is.
            filename = tests_path.name
            if not filename.startswith("sample_"):
                filename = f"sample_{filename}"

            # Write test file to workspace using sandbox API
            target_path = f"{workspace_path_str}/test/{filename}"
            await sb.write_file(target_path, test_content)

        return state

    return solve
async def _setup_local_workspace(state: TaskState) -> TaskState:
    """Handle local sandbox setup.

    For local sandboxes:
    1. Copy template to temp directory (isolation)
    2. Run the dependency install command (``metadata["dep_install_cmd"]``,
       default ``flutter pub get``) in the copy
    3. Update metadata['workspace'] to point to the copy

    Args:
        state: Task state whose ``metadata["workspace"]`` holds the host path
            of the template directory.

    Returns:
        The same state. On success ``metadata["workspace"]`` points at the
        copy and ``metadata["workspace_template"]`` keeps the original path.
        On failure ``metadata["setup_error"]`` describes what went wrong.
    """
    template_path_str = state.metadata.get("workspace")

    # Validate explicitly instead of asserting: asserts vanish under
    # ``python -O`` and this guard protects the original template.
    if not template_path_str:
        state.metadata["setup_error"] = "No workspace path in metadata"
        return state

    template_path = Path(template_path_str)

    # Validate template exists and is something copytree can handle
    if not template_path.exists():
        state.metadata["setup_error"] = f"Template not found: {template_path}"
        return state
    if not template_path.is_dir():
        state.metadata["setup_error"] = f"Template is not a directory: {template_path}"
        return state

    # Create temp directory and copy template into it
    # Ignore build artifacts and other generated files
    temp_dir = tempfile.mkdtemp(prefix="eval_workspace_")
    workspace_copy = Path(temp_dir) / template_path.name

    ignore_patterns = shutil.ignore_patterns(
        "build",  # Flutter/Dart build output
        ".dart_tool",  # Dart SDK internal
        ".packages",  # Legacy package resolution
        ".flutter-plugins*",  # Flutter plugin resolution
        "*.iml",  # IDE files
        ".idea",  # IntelliJ IDEA
        ".vscode",  # VS Code settings
        "*.g.dart",  # Generated code
        "*.freezed.dart",  # Freezed generated
        "*.mocks.dart",  # Mockito generated
    )
    try:
        shutil.copytree(template_path, workspace_copy, ignore=ignore_patterns)
    except OSError as err:  # shutil.Error is an OSError subclass
        # Do not leave a half-copied temp directory behind, and keep
        # metadata['workspace'] untouched so nothing treats it as a copy.
        shutil.rmtree(temp_dir, ignore_errors=True)
        state.metadata["setup_error"] = f"Failed to copy template {template_path}: {err}"
        return state

    # Update metadata to point to the copy
    state.metadata["workspace"] = str(workspace_copy)
    state.metadata["workspace_template"] = template_path_str  # Keep original reference

    # Run dependency install command to resolve deps in the copied workspace.
    # This is required because the .dart_tool and .packages files contain
    # absolute paths that are invalid after copying to a new location.
    dep_cmd = state.metadata.get("dep_install_cmd", ["flutter", "pub", "get"])
    sb = sandbox()
    dep_result = await sb.exec(
        dep_cmd,
        cwd=str(workspace_copy),
    )

    if not dep_result.success:
        cmd_str = " ".join(dep_cmd)
        state.metadata["setup_error"] = f"{cmd_str} failed: {dep_result.stderr}"
        return state

    return state
+ """ + git_url = state.metadata.get("workspace_git") + git_ref = state.metadata.get("workspace_git_ref") + workspace = "/workspace" + + # Type guard: git_url should always be set when this function is called + if not git_url or not isinstance(git_url, str): + state.metadata["setup_error"] = "workspace_git not set or invalid" + return state + + sb = sandbox() + + # Clone repository + result = await sb.exec(["git", "clone", git_url, workspace]) + if not result.success: + state.metadata["setup_error"] = f"git clone failed: {result.stderr}" + return state + + # Checkout specific ref if provided + if git_ref and isinstance(git_ref, str): + result = await sb.exec(["git", "checkout", git_ref], cwd=workspace) + if not result.success: + state.metadata["setup_error"] = f"git checkout failed: {result.stderr}" + return state + + # Set workspace path in metadata for downstream solvers/scorers + state.metadata["workspace"] = workspace + return state diff --git a/packages/dash_evals/src/dash_evals/runner/solvers/write_to_sandbox.py b/packages/dash_evals/src/dash_evals/runner/solvers/write_to_sandbox.py new file mode 100644 index 0000000..6566af1 --- /dev/null +++ b/packages/dash_evals/src/dash_evals/runner/solvers/write_to_sandbox.py @@ -0,0 +1,112 @@ +"""Solver to write extracted code to the workspace filesystem.""" + +import tempfile +from pathlib import Path + +from inspect_ai.solver import Generate, Solver, TaskState, solver +from inspect_ai.util import sandbox + + +def _get_sandbox_type(state: TaskState) -> str: + """Determine sandbox type from workspace path in metadata. 
async def _write_to_local(state: TaskState, code: str, target_path: str) -> TaskState:
    """Write code to local sandbox temp directory.

    SAFETY: Only writes to temp directories to prevent accidental
    modification of source templates.

    Args:
        state: Task state whose ``metadata["workspace"]`` is the host path of
            the workspace copy.
        code: File content to write.
        target_path: Path of the file relative to the workspace.

    Returns:
        The same state. When the workspace or the resulting file path lies
        outside the system temp directory nothing is written and
        ``metadata["setup_error"]`` explains why.
    """
    workspace_path_str = state.metadata.get("workspace")

    # Validate explicitly: an assert would be stripped under ``python -O``
    # and this function is the last line of defence for the templates.
    if not workspace_path_str:
        state.metadata["setup_error"] = "No workspace path in metadata"
        return state

    # SAFETY CHECK: Verify workspace is in temp directory.
    # Compare resolved path components, not string prefixes: a plain
    # startswith() would accept "/tmp_evil/..." as being inside "/tmp".
    temp_dir = tempfile.gettempdir()
    temp_root = Path(temp_dir).resolve()
    workspace_path = Path(workspace_path_str).resolve()
    if not workspace_path.is_relative_to(temp_root):
        state.metadata["setup_error"] = (
            f"SAFETY: Refusing to write to non-temp directory: {workspace_path_str}. "
            f"Expected path starting with {temp_dir}"
        )
        return state

    # The final file must stay under the temp root too, so "../" segments in
    # target_path cannot walk back out to a real source tree.
    destination = (workspace_path / target_path).resolve()
    if not destination.is_relative_to(temp_root):
        state.metadata["setup_error"] = (
            f"SAFETY: Refusing to write outside temp directory: {destination}. "
            f"Expected path starting with {temp_dir}"
        )
        return state

    sb = sandbox()
    full_path = f"{workspace_path_str}/{target_path}"
    await sb.write_file(full_path, code)
    return state
+ + Dispatches to the appropriate handler based on sandbox type: + - Container (podman/docker): Write directly to /workspace (ephemeral) + - Local: Write to temp directory (with safety validation) + + Args: + target_path: Relative path within workspace to write the code. + Default is "lib/main.dart" for Flutter projects. + + Returns: + A solver that writes extracted code to the workspace. + """ + + async def solve(state: TaskState, generate: Generate) -> TaskState: + # Early termination if previous solver failed + if state.metadata.get("setup_error"): + return state + + # Get extracted code from store (set by extract_code solver) + code = state.store.get("extracted_code", "") + + if not code: + # Fallback to metadata for backward compatibility + code = state.metadata.get("generated_code", "") + + if not code: + return state + + sandbox_type = _get_sandbox_type(state) + + if sandbox_type == "none": + # No workspace configured, nothing to do + return state + elif sandbox_type == "container": + return await _write_to_container(state, code, target_path) + elif sandbox_type == "local": + return await _write_to_local(state, code, target_path) + else: + state.metadata["setup_error"] = f"Unknown sandbox type: {sandbox_type}" + return state + + return solve diff --git a/packages/dash_evals/src/dash_evals/runner/tasks/__init__.py b/packages/dash_evals/src/dash_evals/runner/tasks/__init__.py new file mode 100644 index 0000000..4430e07 --- /dev/null +++ b/packages/dash_evals/src/dash_evals/runner/tasks/__init__.py @@ -0,0 +1,17 @@ +from .analyze_codebase import analyze_codebase +from .bug_fix import bug_fix, flutter_bug_fix +from .code_gen import code_gen, flutter_code_gen +from .mcp_tool import mcp_tool +from .question_answer import question_answer +from .skill_test import skill_test + +__all__ = [ + "analyze_codebase", + "bug_fix", + "code_gen", + "flutter_bug_fix", + "flutter_code_gen", + "mcp_tool", + "question_answer", + "skill_test", +] diff --git 
a/packages/dash_evals/src/dash_evals/runner/tasks/analyze_codebase.py b/packages/dash_evals/src/dash_evals/runner/tasks/analyze_codebase.py new file mode 100644 index 0000000..6d5933b --- /dev/null +++ b/packages/dash_evals/src/dash_evals/runner/tasks/analyze_codebase.py @@ -0,0 +1,114 @@ +""" +Analyze Codebase Task + +Evaluates LLM ability to explore and answer questions about an existing codebase. +The model gets read-only access to workspace files via bash commands, +but is instructed not to modify any files. +""" + +from textwrap import dedent +from typing import cast + +from inspect_ai import Task, task +from inspect_ai.agent import react +from inspect_ai.dataset import Dataset +from inspect_ai.model import ChatMessageSystem +from inspect_ai.scorer import model_graded_fact +from inspect_ai.solver import Generate, Solver, TaskState, solver +from inspect_ai.tool import bash + +from dash_evals.runner.scorers import export_workspace +from dash_evals.runner.solvers import setup_workspace + +from .task_helpers import ( + append_context_injection, + build_task_metadata, + get_skill_tool, +) + +DEFAULT_ANALYZE_SYSTEM_MESSAGE = dedent("""\ + You are an expert code reviewer analyzing a codebase. + + Your task is to: + + 1. Explore the codebase at {workspace} using the available tools + 2. Understand the project structure, dependencies, and architecture + 3. Answer the user's question based on what you find in the code + + Important guidelines: + - Use bash commands (cat, find, grep, ls, head, tail, etc.) 
@solver
def _add_workspace_system_message(template: str) -> Solver:
    """Add system message with workspace path substituted from metadata.

    Args:
        template: System message text. ``{workspace}`` is replaced with
            ``metadata["workspace"]`` (``/workspace`` when unset).

    Returns:
        A solver that inserts the rendered text as the first message.
    """

    async def solve(state: TaskState, generate: Generate) -> TaskState:
        workspace = state.metadata.get("workspace", "/workspace")
        try:
            message = template.format(workspace=workspace)
        except (KeyError, IndexError, ValueError):
            # A custom system message may contain literal braces (code
            # samples, JSON) that str.format rejects. Fall back to replacing
            # only the {workspace} placeholder and leave the rest verbatim.
            message = template.replace("{workspace}", str(workspace))
        state.messages.insert(0, ChatMessageSystem(content=message))
        return state

    return solve
+ """ + system_message = config.get("system_message") or DEFAULT_ANALYZE_SYSTEM_MESSAGE + + solver_chain = _build_solver_chain(config, system_message) + + scorers: list = [model_graded_fact()] + if config.get("save_examples"): + scorers.append(export_workspace()) + + return Task( + name=config["task_name"], + dataset=dataset, + setup=[setup_workspace()], + solver=solver_chain, + scorer=scorers, + time_limit=300, + metadata=build_task_metadata(config), + ) diff --git a/packages/dash_evals/src/dash_evals/runner/tasks/bug_fix.py b/packages/dash_evals/src/dash_evals/runner/tasks/bug_fix.py new file mode 100644 index 0000000..f193e39 --- /dev/null +++ b/packages/dash_evals/src/dash_evals/runner/tasks/bug_fix.py @@ -0,0 +1,191 @@ +""" +Bug Fix Task (Generic, Agentic) + +Evaluates LLM ability to diagnose and fix bugs in existing code using +an agentic approach where the model can explore files and make edits. + +This task behaves similarly to real-world AI coding assistants like Claude Code +and Cursor. Language-specific behavior (system message, scorers) is controlled +via the config dict or thin wrappers (e.g., `flutter_bug_fix`). 
+""" + +from textwrap import dedent +from typing import cast + +from inspect_ai import Task, task +from inspect_ai.agent import react +from inspect_ai.dataset import Dataset +from inspect_ai.solver import Solver +from inspect_ai.tool import bash_session, text_editor + +from dash_evals.runner.scorers import ( + code_quality_scorer, + dart_analyze_scorer, + export_workspace, + flutter_test_scorer, +) +from dash_evals.runner.solvers import add_system_message, inject_test_files, setup_workspace + +from .task_helpers import ( + append_context_injection, + build_task_metadata, + get_skill_tool, + validate_sandbox_tools, +) + +# ============================================================================ +# Default System Messages +# ============================================================================ + +DEFAULT_BUG_FIX_PROMPT = dedent("""\ + You are an expert developer debugging a production issue. + + Your task is to: + + 1. Explore the codebase to understand the structure + 2. Identify the root cause of the bug + 3. Fix the bug by editing the necessary file(s) + 4. Verify your fix passes any tests and static analysis + 5. If there are any errors or warnings at all, fix them + 6. When done, call submit() with a brief explanation of what you fixed +""") + +FLUTTER_BUG_FIX_PROMPT = dedent("""\ + You are an expert Flutter developer debugging a production issue. + + Your task is to: + + 1. Explore the codebase to understand the structure + 2. Identify the root cause of the bug + 3. Fix the bug by editing the necessary file(s) + 4. Verify your fix passes any tests and static analysis. Be sure to run + dart analyze in the directory containing the pubspec.yaml for the + package you modified, not the workspace root. + 5. If there are any errors or warnings at all, fix them. + 6. 
@task
def bug_fix(dataset: Dataset, config: dict) -> Task:
    """
    Generic task for evaluating LLM ability to diagnose and fix bugs.

    The config dict controls language-specific behavior:
        - task_name: Name of the resulting task (required)
        - system_message: Custom system prompt (optional)
        - agent_name: Name for the react agent (default: "debugger")
        - agent_description: Description for the react agent (optional)
        - scorers: List of scorer instances (optional, defaults to
          dart_analyze_scorer and code_quality_scorer)
        - save_examples: When truthy, export_workspace() is added as a scorer
        - time_limit: Task time limit (default: 600)

    Args:
        dataset: Inspect dataset loaded from JSONL.
        config: Task manifest entry with variant, system_message, etc.

    Returns:
        The configured Task. ``config`` and its ``scorers`` list are not
        modified.
    """
    system_message = config.get("system_message") or DEFAULT_BUG_FIX_PROMPT

    validate_sandbox_tools(config, ["bash_session", "text_editor"])

    solver_chain = _build_solver_chain(config, system_message)

    configured_scorers = config.get("scorers")
    if configured_scorers is None:
        scorers: list = [
            dart_analyze_scorer(),
            code_quality_scorer(),
        ]
    else:
        # Copy so that appending below never grows the caller's list (which
        # would add another export scorer every time the config is reused).
        scorers = list(configured_scorers)
    if config.get("save_examples"):
        scorers.append(export_workspace())

    return Task(
        name=config["task_name"],
        dataset=dataset,
        setup=[setup_workspace(), inject_test_files()],
        solver=solver_chain,
        scorer=scorers,
        time_limit=config.get("time_limit", 600),
        metadata=build_task_metadata(config),
    )


# ============================================================================
# Flutter Thin Wrapper
# ============================================================================


@task
def flutter_bug_fix(dataset: Dataset, config: dict) -> Task:
    """
    Flutter-specific bug fix task.

    Thin wrapper around bug_fix() with Flutter defaults:
        - Flutter system message
        - Flutter-specific scorers (dart_analyze, flutter_test, code_quality)

    Values present in ``config`` take precedence over these defaults.

    Args:
        dataset: Inspect dataset loaded from JSONL.
        config: Task manifest entry with variant, system_message, etc.

    Returns:
        The Task built by bug_fix() from the merged configuration.
    """
    flutter_config = {
        "system_message": FLUTTER_BUG_FIX_PROMPT,
        "agent_name": "flutter_debugger",
        "agent_description": "Expert Flutter developer who debugs and fixes code issues.",
        **config,
    }
    # bug_fix() treats an empty/None system_message as "not set" and would
    # fall back to the generic prompt; keep the Flutter prompt in that case.
    if not flutter_config.get("system_message"):
        flutter_config["system_message"] = FLUTTER_BUG_FIX_PROMPT

    if "scorers" not in config:
        # export_workspace() is deliberately not added here: bug_fix() appends
        # it when save_examples is set, and adding it in both places would
        # register the export scorer twice.
        flutter_config["scorers"] = [
            dart_analyze_scorer(),
            flutter_test_scorer(),
            code_quality_scorer(),
        ]

    return bug_fix(dataset, flutter_config)
+""" + +from inspect_ai import Task, task +from inspect_ai.dataset import Dataset +from inspect_ai.model import GenerateConfig, ResponseSchema +from inspect_ai.solver import Generate, Solver, TaskState, chain_of_thought, solver +from inspect_ai.util import json_schema +from pydantic import BaseModel, Field + +from dash_evals.runner.scorers import export_workspace, flutter_code_scorer +from dash_evals.runner.solvers import ( + add_system_message, + inject_test_files, + setup_workspace, + write_to_sandbox, +) +from dash_evals.runner.tasks.task_helpers import validate_sandbox_tools + +from .task_helpers import ( + append_context_injection, + append_model_interaction, + build_task_metadata, +) + +# ============================================================================ +# Structured Output Models +# ============================================================================ + + +class FlutterCodeResponse(BaseModel): + """Structured response for Flutter code generation.""" + + main_dart: str = Field( + description="Complete Dart code for lib/main.dart. Must include all imports and a MyApp class." 
@solver
def parse_structured_code(
    code_field: str = "main_dart", response_model: type[BaseModel] | None = None
) -> Solver:
    """
    Pull the code out of a structured JSON completion and stash it.

    The completion is validated against ``response_model`` and the attribute
    named ``code_field`` is stored in state.store["extracted_code"] (and
    mirrored in state.metadata["generated_code"]) for write_to_sandbox.

    Args:
        code_field: The field name containing the code in the response model.
        response_model: Optional Pydantic model to validate against. If None,
                        the raw completion is stored as the code.
    """

    async def solve(state: TaskState, generate: Generate) -> TaskState:
        def remember(code) -> None:
            # Keep the store and the metadata copy in sync.
            state.store.set("extracted_code", code)
            state.metadata["generated_code"] = code

        try:
            if not response_model:
                remember(state.output.completion)
            else:
                parsed = response_model.model_validate_json(state.output.completion)
                remember(getattr(parsed, code_field))
                why = getattr(parsed, "reasoning", "")
                if why:
                    state.metadata["model_reasoning"] = why
        except Exception as e:
            # If parsing fails, try to use the raw output as code
            # This handles cases where the model ignores the schema
            remember(state.output.completion)
            state.metadata["structured_parse_error"] = str(e)
        return state

    return solve
def _build_solver_with_tools(config: dict, system_msg: str):
    """Assemble the prompt and model-interaction part of the solver chain.

    The chain is built in this order: system message, optional context
    injection, chain_of_thought(), then whatever append_model_interaction()
    adds for this config.

    Args:
        config: Task manifest entry, forwarded to the task_helpers functions.
        system_msg: System prompt placed at the start of the chain.

    Returns:
        List of solvers to splice into a Task's solver list.
    """
    solver_chain = [add_system_message(system_msg)]
    append_context_injection(solver_chain, config)
    solver_chain.append(chain_of_thought())
    append_model_interaction(solver_chain, config)

    return solver_chain


# ============================================================================
# Generic Task
# ============================================================================


@task
def code_gen(dataset: Dataset, config: dict) -> Task:
    """
    Generic task for evaluating LLM code generation.

    The config dict controls language-specific behavior:
        - task_name: Name given to the Task (required)
        - system_message: Custom system prompt (optional)
        - target_path: Where to write generated code (default: "lib/main.dart")
        - response_schema_name: Key into RESPONSE_SCHEMAS (default: "generic");
          unknown names fall back to GenericCodeResponse
        - response_schema_description: Description for the schema (optional)
        - code_field: Field name for code in structured output (default: "code")
        - scorers: List of scorers (default: [flutter_code_scorer()])
        - save_examples: If truthy, export_workspace() is appended to the scorers
        - time_limit: Task time limit (default: 300)

    Args:
        dataset: Inspect dataset loaded from JSONL.
        config: Task manifest entry with variant, system_message, etc.

    Returns:
        The configured Task.

    Raises:
        ValueError: From validate_sandbox_tools() when the sandbox type cannot
            host bash_session/text_editor.
        KeyError: If ``task_name`` is missing from config.
    """
    validate_sandbox_tools(config, ["bash_session", "text_editor"])

    system_msg = config.get("system_message") or DEFAULT_CODE_GEN_SYSTEM_MESSAGE
    target_path = config.get("target_path", "lib/main.dart")
    schema_name = config.get("response_schema_name", "generic")
    code_field = config.get("code_field", "code")

    response_model = RESPONSE_SCHEMAS.get(schema_name, GenericCodeResponse)

    solver_chain = _build_solver_with_tools(config, system_msg)

    # Work on a copy: appending to the caller's list would grow
    # config["scorers"] every time a task is built from the same config.
    configured_scorers = config.get("scorers")
    scorers: list = (
        list(configured_scorers)
        if configured_scorers is not None
        else [flutter_code_scorer()]
    )
    if config.get("save_examples"):
        scorers.append(export_workspace())

    schema_description = config.get(
        "response_schema_description",
        f"Source code for {target_path}",
    )

    return Task(
        name=config["task_name"],
        dataset=dataset,
        setup=[
            setup_workspace(),
            inject_test_files(),
        ],
        solver=[
            *solver_chain,
            parse_structured_code(
                code_field=code_field,
                response_model=response_model,
            ),
            write_to_sandbox(target_path=target_path),
        ],
        scorer=scorers,
        config=GenerateConfig(
            response_schema=ResponseSchema(
                name="generated_code",
                json_schema=json_schema(response_model),
                description=schema_description,
            )
        ),
        time_limit=config.get("time_limit", 300),
        metadata=build_task_metadata(config),
    )


# ============================================================================
# Flutter Thin Wrapper
# ============================================================================


@task
def flutter_code_gen(dataset: Dataset, config: dict) -> Task:
    """
    Flutter-specific code generation task.

    Thin wrapper around code_gen() with Flutter defaults:
        - Flutter system message
        - FlutterCodeResponse schema ("flutter") with code in ``main_dart``
        - target_path = "lib/main.dart"

    Scorer selection is left entirely to code_gen(), which already defaults to
    flutter_code_scorer() and appends export_workspace() when save_examples is
    set. Adding export_workspace() here as well registered it twice.

    Args:
        dataset: Inspect dataset loaded from JSONL.
        config: Task manifest entry with variant, system_message, etc.

    Returns:
        The Task built by code_gen() from the merged config.
    """
    flutter_defaults = {
        "system_message": FLUTTER_CODE_GEN_SYSTEM_MESSAGE,
        "target_path": "lib/main.dart",
        "response_schema_name": "flutter",
        "response_schema_description": "Flutter application code for lib/main.dart",
        "code_field": "main_dart",
    }

    # Copy so the caller's config is never modified; user values win.
    flutter_config = dict(config)
    for key, default in flutter_defaults.items():
        # A missing key and an explicit None both mean "not configured";
        # letting None through would discard the Flutter default.
        if flutter_config.get(key) is None:
            flutter_config[key] = default

    return code_gen(dataset, flutter_config)
+ config: Task manifest entry with: + - required_tools: list of MCP tool names the agent should call + - inject_temp_dir: if True, replaces {root_path} in inputs + - system_message: custom system prompt (optional) + """ + required_tools = config.get("required_tools", []) + inject_temp_dir = config.get("inject_temp_dir", False) + + # Pre-process samples if temp directory injection is needed + active_dataset = dataset + if inject_temp_dir: + temp_root = tempfile.mkdtemp(prefix="mcp_tool_") + processed_samples = [] + for sample in dataset: + input_str = sample.input if isinstance(sample.input, str) else str(sample.input) + processed_samples.append( + Sample( + input=input_str.replace("{root_path}", temp_root), + target=sample.target, + id=sample.id, + metadata=sample.metadata, + ) + ) + active_dataset = MemoryDataset( + samples=processed_samples, + name=config.get("task_name", "mcp_tool"), + ) + + # Build solver chain + system_msg = config.get("system_message", "You are a helpful assistant.") + solver_chain = [add_system_message(system_msg)] + append_context_injection(solver_chain, config) + append_model_interaction(solver_chain, config) + + return Task( + name=config["task_name"], + dataset=active_dataset, + solver=solver_chain, + scorer=[ + includes(ignore_case=True), + mcp_tool_usage(required_tools=required_tools if required_tools else None), + ], + time_limit=config.get("time_limit", 300), + message_limit=config.get("message_limit", 50), + metadata=build_task_metadata(config), + ) diff --git a/packages/dash_evals/src/dash_evals/runner/tasks/question_answer.py b/packages/dash_evals/src/dash_evals/runner/tasks/question_answer.py new file mode 100644 index 0000000..9a7ea0d --- /dev/null +++ b/packages/dash_evals/src/dash_evals/runner/tasks/question_answer.py @@ -0,0 +1,57 @@ +"""QA tasks for evaluating model Q&A capabilities.""" + +from textwrap import dedent + +from inspect_ai import Task, task +from inspect_ai.dataset import Dataset +from inspect_ai.scorer import 
@task
def question_answer(dataset: Dataset, config: dict) -> Task:
    """
    Generic QA task for evaluating model Q&A capabilities.

    Answers are graded with model_graded_fact().

    Config keys read here:
        - task_name: Name given to the Task (required)
        - system_message: Custom system prompt; falls back to
          DEFAULT_QA_SYSTEM_MESSAGE when missing or empty
        - time_limit: Task time limit (default: 300)

    Args:
        dataset: Inspect dataset loaded from JSONL.
        config: Task manifest entry with variant, system_message, etc.

    Returns:
        The configured Task.

    Raises:
        KeyError: If ``task_name`` is missing from config.
    """
    system_msg = config.get("system_message") or DEFAULT_QA_SYSTEM_MESSAGE

    solver = _build_qa_solver(system_msg, config)

    return Task(
        name=config["task_name"],
        dataset=dataset,
        solver=solver,
        scorer=model_graded_fact(),
        # Honor the manifest's time_limit like the other tasks do; 300 stays
        # the default so existing manifests behave the same.
        time_limit=config.get("time_limit", 300),
        metadata=build_task_metadata(config),
    )
+""" + +from textwrap import dedent +from typing import cast + +from inspect_ai import Task, task +from inspect_ai.agent import react +from inspect_ai.dataset import Dataset +from inspect_ai.scorer import model_graded_fact +from inspect_ai.solver import Solver +from inspect_ai.tool import bash + +from dash_evals.runner.scorers import export_workspace, skill_usage_scorer +from dash_evals.runner.solvers import add_system_message, setup_workspace + +from .task_helpers import ( + append_context_injection, + build_task_metadata, + get_skill_tool, +) + +DEFAULT_SKILL_TEST_SYSTEM_MESSAGE = dedent("""\ + You are an expert developer solving a task. + + You MAY be provided with agent skills — folders of instructions, + scripts, and resources that will help you complete this task more + accurately and efficiently. + + Your approach should be: + + 1. First, use the skill tool to discover available skills + 2. Read the SKILL.md file(s) to understand the guidance provided + 3. Apply the skill instructions to complete the user's task + 4. 
def _build_solver_chain(config: dict, system_message: str) -> list:
    """Build the solver chain for skill test tasks.

    The chain is: system message, optional context injection, then a react
    agent named "skill_tester".

    Args:
        config: Task manifest entry; its variant decides context files and skills.
        system_message: System prompt placed at the start of the chain.

    Returns:
        List of solvers for the Task.
    """
    solver_chain = []

    solver_chain.append(add_system_message(system_message))

    append_context_injection(solver_chain, config)

    # bash is always offered; the skill tool is optional and only present
    # when get_skill_tool() finds skills configured for this variant.
    skill_tool = get_skill_tool(config)

    tools = [bash(timeout=120)]
    if skill_tool is not None:
        tools.append(skill_tool)

    # react() returns an agent; cast it so it can sit in a solver list.
    solver_chain.append(
        cast(
            Solver,
            react(
                name="skill_tester",
                description="Developer who discovers and applies agent skills to complete tasks.",
                tools=tools,
            ),
        )
    )

    return solver_chain


@task
def skill_test(dataset: Dataset, config: dict) -> Task:
    """Task for evaluating whether an agent correctly uses provided skills.

    This task is specifically designed to test skill discovery and application.
    If skills are provided, the agent may:

    1. Use the skill tool to discover available skills
    2. Read and understand the skill instructions
    3. Apply the skill guidance to answer the user's question
    4. Submit an answer that reflects skill-based knowledge

    The scorer checks both answer quality (model_graded_fact) and, if skills
    are present, whether the skill tool was actually used (skill_usage_scorer).
    export_workspace() is added as a further scorer when save_examples is set.

    Config keys read here:
        - task_name: Name given to the Task (required)
        - system_message: Custom system prompt; falls back to
          DEFAULT_SKILL_TEST_SYSTEM_MESSAGE when missing or empty
        - save_examples: If truthy, export the workspace after scoring
        - time_limit: Task time limit (default: 300)

    Args:
        dataset: Inspect dataset loaded from JSONL.
        config: Task manifest entry with variant, system_message, etc.

    Returns:
        The configured Task.

    Raises:
        KeyError: If ``task_name`` is missing from config.
    """
    system_message = config.get("system_message") or DEFAULT_SKILL_TEST_SYSTEM_MESSAGE

    solver_chain = _build_solver_chain(config, system_message)

    scorers: list = [
        model_graded_fact(),  # Answer quality
    ]
    if get_skill_tool(config) is not None:
        scorers.append(skill_usage_scorer())  # Verify skill tool was actually used
    if config.get("save_examples"):
        scorers.append(export_workspace())

    return Task(
        name=config["task_name"],
        dataset=dataset,
        setup=[setup_workspace()],
        solver=solver_chain,
        scorer=scorers,
        # Honor the manifest's time_limit like the other tasks do; 300 stays
        # the default so existing manifests behave the same.
        time_limit=config.get("time_limit", 300),
        metadata=build_task_metadata(config),
    )
INJECTION_TOOLS = frozenset({"bash_session", "text_editor"})


def validate_sandbox_tools(config: dict, tool_names: list[str]) -> None:
    """Validate that the requested tools are compatible with the sandbox type.

    Only the local sandbox is checked; any other sandbox type passes.

    Args:
        config: Task manifest entry with 'sandbox_type' and 'task_name' keys.
            A missing, null or empty 'sandbox_type' is treated as "local".
        tool_names: Names of tools this task will use.

    Raises:
        ValueError: If local sandbox is used with injection-requiring tools.
    """
    # `.get(key, "local")` would let an explicit null through and silently
    # skip validation, so normalise every "unset" spelling to the default.
    sandbox_type = config.get("sandbox_type") or "local"
    if sandbox_type != "local":
        return

    conflicting = INJECTION_TOOLS.intersection(tool_names)
    if not conflicting:
        return

    # Sorted so the message is stable regardless of set ordering.
    tool_list = "\n".join(f" • {t}()" for t in sorted(conflicting))
    task_name = config.get("task_name", "unknown")
    raise ValueError(
        f"\n{'=' * 60}\n"
        f"Task '{task_name}' cannot run on a local sandbox.\n\n"
        f"The following tools require a Linux container (Docker/Podman):\n"
        f"{tool_list}\n\n"
        f"These tools inject helper scripts into the sandbox, which is\n"
        f"not supported on macOS.\n\n"
        f"To fix this, either:\n"
        f" 1. Set sandbox_type to 'docker' or 'podman' in your job YAML\n"
        f" 2. Use a task that supports local execution (e.g. 'analyze_codebase')\n"
        f"{'=' * 60}"
    )
def build_task_metadata(config: dict) -> dict:
    """Derive the Task metadata dictionary from a manifest entry.

    Args:
        config: Task manifest entry. Reads 'variant', 'save_examples',
            'examples_dir' and 'task_name'.

    Returns:
        A new dict. It holds 'variant_config' when the variant is non-empty,
        and 'save_examples', 'examples_dir' and 'task_variant' when both
        'save_examples' and 'examples_dir' are truthy.
    """
    task_metadata: dict[str, Any] = {}

    variant_config = config.get("variant", {})
    if variant_config:
        task_metadata["variant_config"] = variant_config

    # Exporting needs both the flag and a destination directory.
    export_requested = config.get("save_examples") and config.get("examples_dir")
    if export_requested:
        task_metadata.update(
            save_examples=True,
            examples_dir=config["examples_dir"],
            task_variant=config.get("task_name", "unknown"),
        )

    return task_metadata
def append_model_interaction(
    solver_chain: list,
    config: dict,
    *,
    extra_tools: list | None = None,
) -> None:
    """Finish a solver chain with a react agent or a plain generate() step.

    Tools are gathered in a fixed order: the MCP server (when the variant
    lists ``mcp_servers``), the skill tool (when get_skill_tool() returns
    one), then ``extra_tools``. A react agent is appended when at least one
    tool was gathered; otherwise generate() is appended.

    Args:
        solver_chain: The solver chain list to append to (modified in place).
        config: Task manifest entry with 'variant' key.
        extra_tools: Additional tools to include alongside MCP (optional).
    """
    agent_tools: list[Tool | MCPServer] = []

    variant = config.get("variant", {})
    if variant.get("mcp_servers"):
        agent_tools.append(create_mcp_server(config))

    skill_tool = get_skill_tool(config)
    if skill_tool:
        agent_tools.append(skill_tool)

    if extra_tools:
        agent_tools.extend(extra_tools)

    # react() returns an agent; cast it so it can sit in a solver list.
    final_step = cast(Solver, react(tools=agent_tools)) if agent_tools else generate()
    solver_chain.append(final_step)
class TeeStream:
    """File-like wrapper that mirrors every write into a log file.

    Text goes to the wrapped stream unchanged and to the log file with ANSI
    escape sequences removed. fileno() and isatty() report the wrapped
    stream's values.
    """

    def __init__(self, original: TextIO, log_file: TextIO):
        self.original = original
        self.log_file = log_file

    def write(self, text: str) -> int:
        """Write *text* to both targets and return len(text)."""
        self.original.write(text)
        # The log copy is stripped of ANSI codes so the file stays readable.
        self.log_file.write(_strip_ansi(text))
        return len(text)

    def flush(self) -> None:
        """Flush the wrapped stream, then the log file."""
        for target in (self.original, self.log_file):
            target.flush()

    def fileno(self) -> int:
        """Return the file descriptor of the wrapped stream."""
        return self.original.fileno()

    def isatty(self) -> bool:
        """Return whether the wrapped stream is a tty."""
        return self.original.isatty()


def _strip_ansi(text: str) -> str:
    """Return *text* with ANSI escape sequences removed."""
    import re

    return re.sub(r"\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])", "", text)
def setup_logging(log_dir: Path, name: str = "dash_evals") -> tuple[logging.Logger, Path]:
    """Configure logging to both console and file.

    The logger gets a console handler at INFO and a file handler at DEBUG
    writing to ``runner_<UTC timestamp>.log`` inside ``log_dir``. The function
    may be called repeatedly for the same name: handlers from the previous
    call are removed and closed first.

    Args:
        log_dir: Directory to write log files (created if missing)
        name: Logger name (default: dash_evals)

    Returns:
        Tuple of (configured logger instance, log file path)
    """
    # Create logger
    logger = logging.getLogger(name)
    logger.setLevel(logging.DEBUG)
    logger.propagate = False  # Prevent duplicate output to root logger

    # Drop handlers left by an earlier call. They must be closed as well as
    # removed: clearing the list alone leaves the old log file open.
    for stale_handler in list(logger.handlers):
        logger.removeHandler(stale_handler)
        stale_handler.close()

    # Console handler (INFO level)
    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.INFO)
    console_format = logging.Formatter(
        "%(asctime)s - %(levelname)s - %(message)s",
        datefmt="%H:%M:%S",
    )
    console_handler.setFormatter(console_format)
    logger.addHandler(console_handler)

    # File handler (DEBUG level - more verbose)
    log_dir.mkdir(parents=True, exist_ok=True)
    timestamp = datetime.now(timezone.utc).strftime("%Y-%m-%d_%H-%M-%S")
    log_file = log_dir / f"runner_{timestamp}.log"

    file_handler = logging.FileHandler(log_file, encoding="utf-8")
    file_handler.setLevel(logging.DEBUG)
    file_format = logging.Formatter(
        "%(asctime)s - %(name)s - %(levelname)s - %(funcName)s:%(lineno)d - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )
    file_handler.setFormatter(file_format)
    logger.addHandler(file_handler)

    logger.info(f"📝 Runner log: {log_file}")
    return logger, log_file
a/packages/dash_evals/src/dash_evals/utils/markdown.py b/packages/dash_evals/src/dash_evals/utils/markdown.py new file mode 100644 index 0000000..8c7fd13 --- /dev/null +++ b/packages/dash_evals/src/dash_evals/utils/markdown.py @@ -0,0 +1,44 @@ +"""Utilities for working with markdown or text.""" + + +def extract_code_from_markdown(text: str, language: str | None = None) -> str: + """Extract code from markdown code blocks. + + Args: + text: Text that may contain markdown code blocks. + language: Optional language identifier (e.g., 'dart', 'python'). + + Returns: + Extracted code, or original text if no code blocks found. + + Example: + >>> extract_code_from_markdown("```dart\\ncode\\n```") + 'code' + """ + # Try language-specific code block first if language is provided + if language: + marker = f"```{language}" + if marker in text: + start = text.find(marker) + len(marker) + end = text.find("```", start) + if end != -1: + return text[start:end].strip() + + # Try generic language-specific code blocks (e.g., ```dart, ```python) + if "```" in text: + # Look for language-specific blocks + for lang in ["dart", "python", "javascript", "typescript", "java", "kotlin"]: + marker = f"```{lang}" + if marker in text: + start = text.find(marker) + len(marker) + end = text.find("```", start) + if end != -1: + return text[start:end].strip() + + # Fall back to generic code block + start = text.find("```") + 3 + end = text.find("```", start) + if end != -1: + return text[start:end].strip() + + return text diff --git a/packages/dash_evals/tests/__init__.py b/packages/dash_evals/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/packages/dash_evals/tests/test_export_workspace.py b/packages/dash_evals/tests/test_export_workspace.py new file mode 100644 index 0000000..ecaf946 --- /dev/null +++ b/packages/dash_evals/tests/test_export_workspace.py @@ -0,0 +1,253 @@ +"""Tests for the export_workspace scorer. 
+ +These tests mock the sandbox API so that the scorer's tar-based export logic +can be exercised without a real container. +""" + +import asyncio +import io +import tarfile +from pathlib import Path +from unittest.mock import AsyncMock, MagicMock, patch + +from dash_evals.runner.scorers.export_workspace import export_workspace + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_state( + workspace: str | None, + examples_dir: str | None, + task_variant: str = "my-task:baseline", +) -> MagicMock: + """Build a minimal mock TaskState for the scorer tests.""" + state = MagicMock() + state.metadata = { + k: v + for k, v in { + "workspace": workspace, + "examples_dir": examples_dir, + "save_examples": True, + "task_variant": task_variant, + }.items() + if v is not None + } + state.sample_id = "sample-001" + return state + + +def _build_mock_sandbox(workspace_path: str, excludes: list[str] | None = None): + """Build a mock sandbox whose exec() creates a real tar from the host filesystem. + + This simulates what happens inside the container: tar -cf the workspace, + then read_file returns the tar bytes. + """ + mock_sb = AsyncMock() + + # The tar bytes produced by exec() — stored here so read_file() can return them. + _tar_bytes: dict[str, bytes] = {} + + async def fake_exec(cmd: list[str]): + """Simulate: tar -cf --exclude ... -C .""" + result = MagicMock() + # Parse the command to find -C and the archive path. 
async def _run_scorer(workspace, examples_dir, task_variant="my-task:baseline", mock_sb=None):
    """Run export_workspace() against a mock state and return its Score.

    The module-level ``sandbox`` lookup of the scorer is always patched. With
    ``mock_sb`` the patch returns that sandbox. Without it, the patch returns
    a stand-in whose exec() raises RuntimeError("no sandbox"); that path is
    meant for tests expected to return before touching the sandbox (e.g.
    missing metadata keys).
    """
    state = _make_state(workspace, examples_dir, task_variant)
    target = MagicMock()
    scorer = export_workspace()

    if mock_sb is not None:
        sandbox_double = mock_sb
    else:
        # Never fall through to a real sandbox lookup.
        sandbox_double = AsyncMock()
        sandbox_double.exec = AsyncMock(side_effect=RuntimeError("no sandbox"))

    with patch(
        "dash_evals.runner.scorers.export_workspace.sandbox",
        return_value=sandbox_double,
    ):
        return await scorer(state, target)
asyncio.run(_run_scorer(workspace="/tmp/some_workspace", examples_dir=None)) + + assert score is not None + assert "examples_dir not set" in (score.explanation or "") + + def test_returns_score_when_no_workspace(self, tmp_path: Path): + """Should return a score gracefully when workspace is not in metadata.""" + score = asyncio.run(_run_scorer(workspace=None, examples_dir=str(tmp_path / "examples"))) + + assert score is not None + assert "No workspace" in (score.explanation or "") + + def test_returns_score_when_workspace_path_missing(self, tmp_path: Path): + """Should return a score gracefully when workspace path doesn't exist on disk.""" + mock_sb = _build_mock_sandbox(str(tmp_path / "nonexistent_workspace")) + score = asyncio.run( + _run_scorer( + workspace=str(tmp_path / "nonexistent_workspace"), + examples_dir=str(tmp_path / "examples"), + mock_sb=mock_sb, + ) + ) + + assert score is not None + assert "failed" in (score.explanation or "").lower() or "No such file" in ( + score.explanation or "" + ) + + def test_excludes_build_artifacts(self, tmp_path: Path): + """Build artifacts should not be copied to the examples directory.""" + workspace = tmp_path / "eval_workspace_abc" / "workspace" + workspace.mkdir(parents=True) + (workspace / "pubspec.yaml").write_text("name: app") + (workspace / "build").mkdir() + (workspace / "build" / "big_artifact.txt").write_text("should be excluded") + (workspace / ".dart_tool").mkdir() + (workspace / ".dart_tool" / "config.json").write_text("{}") + + examples_dir = tmp_path / "examples" + mock_sb = _build_mock_sandbox(str(workspace)) + + asyncio.run(_run_scorer(str(workspace), str(examples_dir), mock_sb=mock_sb)) + + dest = examples_dir / "my-task:baseline" / "sample-001" + assert not (dest / "build").exists() + assert not (dest / ".dart_tool").exists() + assert (dest / "pubspec.yaml").exists() + + def test_multiple_samples_get_separate_dirs(self, tmp_path: Path): + """Each sample should land in its own subdirectory under the 
task_variant dir.""" + workspace = tmp_path / "eval_workspace" / "my_app" + workspace.mkdir(parents=True) + (workspace / "main.dart").write_text("// code") + + examples_dir = tmp_path / "examples" + mock_sb = _build_mock_sandbox(str(workspace)) + + for sample_id in ["sample-001", "sample-002"]: + state = _make_state(str(workspace), str(examples_dir)) + state.sample_id = sample_id + target = MagicMock() + scorer = export_workspace() + with patch( + "dash_evals.runner.scorers.export_workspace.sandbox", + return_value=mock_sb, + ): + asyncio.run(scorer(state, target)) + + assert (examples_dir / "my-task:baseline" / "sample-001").exists() + assert (examples_dir / "my-task:baseline" / "sample-002").exists() diff --git a/packages/dash_evals/tests/test_flutter_code_execution.py b/packages/dash_evals/tests/test_flutter_code_execution.py new file mode 100644 index 0000000..8922ce2 --- /dev/null +++ b/packages/dash_evals/tests/test_flutter_code_execution.py @@ -0,0 +1,227 @@ +""" +Unit tests for flutter_code_execution task and related utilities. + +Tests the pure functions that don't have side effects. +""" + +import pytest + +# Test Flutter parsers +from dash_evals.runner.scorers.flutter_output_parser import ( + AnalyzerResult, + TestResult, + parse_analyzer_output, + parse_test_output, +) + +# Test Flutter scoring +from dash_evals.runner.scorers.flutter_scoring import ( + calculate_analyzer_score, + calculate_final_score, + calculate_test_score, + validate_code_structure, +) + + +class TestAnalyzerParsing: + """Test analyzer output parsing.""" + + def test_parse_analyzer_output_clean(self): + """Test parsing output with no issues.""" + output = "Analyzing project...\nNo issues found!" 
+ result = parse_analyzer_output(output) + + assert result.error_count == 0 + assert result.warning_count == 0 + assert result.info_count == 0 + assert result.raw_output == output + + def test_parse_analyzer_output_with_issues(self): + """Test parsing output with errors and warnings.""" + output = """ + error • The method 'foo' isn't defined + warning • Unused import + warning • Missing const + info • Consider using final + """ + result = parse_analyzer_output(output) + + assert result.error_count == 1 + assert result.warning_count == 2 + assert result.info_count == 1 + + def test_calculate_analyzer_score_perfect(self): + """Test score calculation with no issues.""" + result = AnalyzerResult(0, 0, 0, "") + score, explanation = calculate_analyzer_score(result) + assert score == 1.0 + assert "No analyzer issues" in explanation + + def test_calculate_analyzer_score_minor(self): + """Test score calculation with minor issues.""" + result = AnalyzerResult(0, 2, 1, "") + score, explanation = calculate_analyzer_score(result) + assert score == 0.7 + assert "Minor issues" in explanation + + def test_calculate_analyzer_score_major(self): + """Test score calculation with major issues.""" + result = AnalyzerResult(3, 5, 2, "") + score, explanation = calculate_analyzer_score(result) + assert score == 0.3 + assert "Multiple issues" in explanation + + +class TestTestParsing: + """Test test output parsing.""" + + def test_parse_test_output_success(self): + """Test parsing successful test output.""" + output = "All tests passed!" + result = parse_test_output(output, success=True) + + assert result.passed is True + assert result.raw_output == output + + def test_parse_test_output_all_passed_in_output(self): + """Test parsing when 'All tests passed' is in output.""" + output = "Running tests... All tests passed! 
(5 tests)" + result = parse_test_output(output, success=False) + + assert result.passed is True + + def test_parse_test_output_no_tests(self): + """Test parsing when no tests passed.""" + output = "+0 tests passed" + result = parse_test_output(output, success=False) + + assert result.passed is False + assert result.passed_count == 0 + + def test_parse_test_output_partial(self): + """Test parsing when some tests failed.""" + output = "+3 -2: Some tests failed" + result = parse_test_output(output, success=False) + + assert result.passed is False + assert result.passed_count == 3 + assert result.failed_count == 2 + + def test_calculate_test_score_success(self): + """Test score calculation for successful tests.""" + result = TestResult(passed=True, raw_output="") + score, explanation = calculate_test_score(result) + assert score == 1.0 + assert "All tests passed" in explanation + + def test_calculate_test_score_no_tests(self): + """Test score calculation when no tests passed.""" + result = TestResult(passed=False, passed_count=0, raw_output="") + score, explanation = calculate_test_score(result) + assert score == 0.0 + assert "No tests found" in explanation + + def test_calculate_test_score_partial(self): + """Test score calculation for partial success.""" + result = TestResult(passed=False, passed_count=3, failed_count=2, raw_output="") + score, explanation = calculate_test_score(result) + # 3 passed / 5 total = 0.6 + assert score == 0.6 + assert "3/5" in explanation + + +class TestCodeStructureValidation: + """Test code structure validation.""" + + def test_validate_code_structure_complete(self): + """Test validation with all required elements.""" + code = """ + import 'package:flutter/material.dart'; + + void main() => runApp(MyApp()); + + class MyApp extends StatelessWidget { + @override + Widget build(BuildContext context) { + return MaterialApp( + home: Scaffold( + body: TextField(), + ), + ); + } + } + """ + required_widgets = ["TextField"] + score, explanation 
= validate_code_structure(code, required_widgets) + + assert score == 1.0 + assert "Contains required elements" in explanation + assert "MyApp class" in explanation + + def test_validate_code_structure_partial(self): + """Test validation with some missing elements.""" + code = """ + import 'package:flutter/material.dart'; + + void main() => runApp(MyApp()); + + class MyApp extends StatelessWidget { + @override + Widget build(BuildContext context) { + return Container(); + } + } + """ + required_widgets = ["TextField"] + score, explanation = validate_code_structure(code, required_widgets) + + # Has MyApp and StatelessWidget but missing MaterialApp and TextField + assert score == 0.7 + assert "Missing some elements" in explanation + + def test_validate_code_structure_minimal(self): + """Test validation with minimal code.""" + code = "void main() {}" + required_widgets = ["TextField", "MaterialApp"] + score, explanation = validate_code_structure(code, required_widgets) + + assert score == 0.3 + assert "Missing required elements" in explanation + + +class TestScoreCalculation: + """Test final score calculation.""" + + def test_calculate_final_score_perfect(self): + """Test perfect score.""" + score = calculate_final_score(1.0, 1.0, 1.0) + assert score == 1.0 + + def test_calculate_final_score_weighted(self): + """Test weighted score calculation.""" + # 30% analyzer (1.0) + 50% test (0.5) + 20% structure (1.0) + score = calculate_final_score(1.0, 0.5, 1.0) + expected = 0.3 * 1.0 + 0.5 * 0.5 + 0.2 * 1.0 + assert score == pytest.approx(expected) + + def test_calculate_final_score_zero(self): + """Test zero score.""" + score = calculate_final_score(0.0, 0.0, 0.0) + assert score == 0.0 + + def test_calculate_final_score_test_heavy(self): + """Test that test score is weighted most heavily.""" + # Test score is 50% of total + score_high_test = calculate_final_score(0.0, 1.0, 0.0) + score_high_analyzer = calculate_final_score(1.0, 0.0, 0.0) + score_high_structure = 
calculate_final_score(0.0, 0.0, 1.0) + + assert score_high_test > score_high_analyzer + assert score_high_test > score_high_structure + + def test_calculate_final_score_custom_weights(self): + """Test custom weights.""" + weights = {"analyzer": 0.5, "test": 0.3, "structure": 0.2} + score = calculate_final_score(1.0, 0.5, 1.0, weights=weights) + expected = 0.5 * 1.0 + 0.3 * 0.5 + 0.2 * 1.0 + assert score == pytest.approx(expected) diff --git a/packages/dash_evals/tests/test_models.py b/packages/dash_evals/tests/test_models.py new file mode 100644 index 0000000..12e3ae8 --- /dev/null +++ b/packages/dash_evals/tests/test_models.py @@ -0,0 +1,42 @@ +""" +Tests for task helpers. + +Tests validate_sandbox_tools helper from the manifest-based task system. +""" + +import pytest + +from dash_evals.runner.tasks.task_helpers import validate_sandbox_tools + + +class TestValidateSandboxTools: + """Tests for validate_sandbox_tools helper.""" + + def _make_config(self, sandbox_type: str = "local") -> dict: + """Create a minimal manifest config dict for testing.""" + return { + "task_name": "test_task:baseline", + "sandbox_type": sandbox_type, + } + + def test_raises_for_local_with_injection_tools(self): + """Local sandbox + bash_session/text_editor should raise ValueError.""" + config = self._make_config(sandbox_type="local") + with pytest.raises(ValueError, match="cannot run on a local sandbox"): + validate_sandbox_tools(config, ["bash_session", "text_editor"]) + + def test_passes_for_docker(self): + """Docker sandbox should never raise, regardless of tools.""" + config = self._make_config(sandbox_type="docker") + validate_sandbox_tools(config, ["bash_session", "text_editor"]) + + def test_passes_for_local_with_safe_tools(self): + """Local sandbox with non-injection tools should not raise.""" + config = self._make_config(sandbox_type="local") + validate_sandbox_tools(config, ["bash"]) + + def test_raises_for_single_injection_tool(self): + """Even one injection tool should be 
enough to raise.""" + config = self._make_config(sandbox_type="local") + with pytest.raises(ValueError, match="bash_session"): + validate_sandbox_tools(config, ["bash_session"]) diff --git a/packages/dash_evals/tests/test_scorers.py b/packages/dash_evals/tests/test_scorers.py new file mode 100644 index 0000000..ce96a45 --- /dev/null +++ b/packages/dash_evals/tests/test_scorers.py @@ -0,0 +1,147 @@ +""" +Tests for scorer modules. + +Tests the pure-function helpers and the scorer factory functions where possible +without spinning up an actual Inspect AI evaluation. + +Covered: +- code_quality._parse_json_response +- mcp_tool_usage (DART_MCP_TOOLS constant, scorer factory) +- skill_usage (SKILL_TOOL_NAME constant) +""" + +import json +from unittest.mock import MagicMock + +import pytest + +from dash_evals.runner.scorers.code_quality import _parse_json_response +from dash_evals.runner.scorers.mcp_tool_usage import DART_MCP_TOOLS, mcp_tool_usage +from dash_evals.runner.scorers.skill_usage import SKILL_TOOL_NAME, skill_usage_scorer + + +# ────────────────────────────────────────────── +# _parse_json_response (code_quality helper) +# ────────────────────────────────────────────── +class TestParseJsonResponse: + """Tests for the _parse_json_response helper.""" + + def test_parse_plain_json(self): + """Test parsing a plain JSON string.""" + text = '{"minimality": 3, "elegance": 2, "robustness": 1, "reasoning": "Good"}' + result = _parse_json_response(text) + assert result is not None + assert result["minimality"] == 3 + assert result["reasoning"] == "Good" + + def test_parse_json_in_markdown_block(self): + """Test extracting JSON from a markdown code block.""" + text = '```json\n{"minimality": 2, "elegance": 2, "robustness": 2, "reasoning": "OK"}\n```' + result = _parse_json_response(text) + assert result is not None + assert result["minimality"] == 2 + + def test_parse_json_in_generic_block(self): + """Test extracting JSON from a generic code block.""" + text = 
'```\n{"minimality": 1, "elegance": 1, "robustness": 1, "reasoning": "Meh"}\n```' + result = _parse_json_response(text) + assert result is not None + assert result["reasoning"] == "Meh" + + def test_parse_json_embedded_in_text(self): + """Test extracting JSON object embedded in prose.""" + text = ( + "The code is decent. " + '{"minimality": 2, "elegance": 3, "robustness": 2, "reasoning": "Solid"}' + ) + result = _parse_json_response(text) + assert result is not None + assert result["elegance"] == 3 + + def test_parse_invalid_json_returns_none(self): + """Test that invalid/unparseable text returns None.""" + result = _parse_json_response("This is not JSON at all.") + assert result is None + + def test_parse_empty_string_returns_none(self): + """Test that empty string returns None.""" + result = _parse_json_response("") + assert result is None + + def test_parse_json_with_whitespace(self): + """Test parsing JSON with extra whitespace.""" + text = ' \n {"minimality": 1, "elegance": 1, "robustness": 1, "reasoning": "OK"} \n ' + result = _parse_json_response(text) + assert result is not None + assert result["minimality"] == 1 + + +# ────────────────────────────────────────────── +# DART_MCP_TOOLS constant +# ────────────────────────────────────────────── +class TestDartMCPTools: + """Tests for the DART_MCP_TOOLS constant.""" + + def test_is_set(self): + """DART_MCP_TOOLS should be a set.""" + assert isinstance(DART_MCP_TOOLS, set) + + def test_not_empty(self): + """DART_MCP_TOOLS should not be empty.""" + assert len(DART_MCP_TOOLS) > 0 + + def test_contains_known_tools(self): + """Should contain well-known Dart MCP server tools.""" + expected = ["analyze_files", "pub", "run_tests", "hot_reload", "launch_app"] + for tool in expected: + assert tool in DART_MCP_TOOLS, f"Missing tool: {tool}" + + def test_all_entries_are_strings(self): + """All tools should be string names.""" + for tool in DART_MCP_TOOLS: + assert isinstance(tool, str) + + +# 
────────────────────────────────────────────── +# SKILL_TOOL_NAME constant +# ────────────────────────────────────────────── +class TestSkillUsageConstants: + """Tests for skill_usage constants.""" + + def test_skill_tool_name(self): + """SKILL_TOOL_NAME should be 'skill'.""" + assert SKILL_TOOL_NAME == "skill" + + +# ────────────────────────────────────────────── +# mcp_tool_usage scorer factory +# ────────────────────────────────────────────── +class TestMCPToolUsageScorerFactory: + """Tests for the mcp_tool_usage scorer factory function.""" + + def test_returns_callable(self): + """mcp_tool_usage() should return a callable scorer.""" + result = mcp_tool_usage() + assert callable(result) + + def test_accepts_custom_server_name(self): + """mcp_tool_usage() should accept a custom MCP server name.""" + result = mcp_tool_usage(mcp_server_name="Firebase") + assert callable(result) + + def test_accepts_custom_tool_names(self): + """mcp_tool_usage() should accept a custom tool names list.""" + result = mcp_tool_usage(mcp_tool_names=["my_tool_1", "my_tool_2"]) + assert callable(result) + + +# ────────────────────────────────────────────── +# skill_usage_scorer factory +# ────────────────────────────────────────────── +class TestSkillUsageScorerFactory: + """Tests for the skill_usage_scorer factory function.""" + + def test_returns_callable(self): + """skill_usage_scorer() should return a callable scorer.""" + result = skill_usage_scorer() + assert callable(result) diff --git a/packages/dash_evals/tests/test_solvers.py b/packages/dash_evals/tests/test_solvers.py new file mode 100644 index 0000000..0a7c08b --- /dev/null +++ b/packages/dash_evals/tests/test_solvers.py @@ -0,0 +1,55 @@ +""" +Tests for solver modules. + +Tests the extract_code, add_system_message, and context_injector solvers. +These are thin wrappers around pure logic, so we test the factory functions +and (where possible) the underlying logic without a full Inspect AI runtime. 
+""" + +from dash_evals.runner.solvers.add_system_message import add_system_message +from dash_evals.runner.solvers.context_injector import context_injector +from dash_evals.runner.solvers.extract_code import extract_code + + +class TestExtractCodeFactory: + """Tests for the extract_code solver factory.""" + + def test_returns_callable(self): + """extract_code() should return a callable solver.""" + result = extract_code() + assert callable(result) + + def test_accepts_language_arg(self): + """extract_code() should accept a custom language argument.""" + result = extract_code(language="python") + assert callable(result) + + +class TestAddSystemMessageFactory: + """Tests for the add_system_message solver factory.""" + + def test_returns_callable(self): + """add_system_message() should return a callable solver.""" + result = add_system_message("Hello system") + assert callable(result) + + def test_accepts_message_with_curly_braces(self): + """add_system_message() should handle messages with curly braces (code).""" + # This is the reason this solver exists — to avoid template formatting + result = add_system_message("void main() { print('hello'); }") + assert callable(result) + + +class TestContextInjectorFactory: + """Tests for the context_injector solver factory.""" + + def test_returns_callable_with_empty_list(self): + """context_injector() with empty list should return callable.""" + result = context_injector(context_files=[]) + assert callable(result) + + def test_returns_callable_with_files(self): + """context_injector() with ContextFile list should return callable.""" + # ContextFile construction may need real file paths, test the factory only + result = context_injector(context_files=[]) + assert callable(result) diff --git a/packages/dash_evals/tests/test_utils.py b/packages/dash_evals/tests/test_utils.py new file mode 100644 index 0000000..4ce3e85 --- /dev/null +++ b/packages/dash_evals/tests/test_utils.py @@ -0,0 +1,77 @@ +""" +Tests for utility modules: 
markdown extraction and YAML loading. +""" + +from pathlib import Path + +import pytest + +from dash_evals.utils.markdown import extract_code_from_markdown + + +class TestExtractCodeFromMarkdown: + """Tests for extract_code_from_markdown().""" + + def test_extract_dart_code_block(self): + """Test extracting dart code from a language-specific block.""" + text = "Here's the code:\n```dart\nvoid main() {}\n```\nDone." + result = extract_code_from_markdown(text, language="dart") + assert result == "void main() {}" + + def test_extract_python_code_block(self): + """Test extracting python code.""" + text = "```python\nprint('hello')\n```" + result = extract_code_from_markdown(text, language="python") + assert result == "print('hello')" + + def test_extract_generic_code_block(self): + """Test extracting from a generic (no language) code block.""" + text = "```\nsome code\n```" + result = extract_code_from_markdown(text) + assert result == "some code" + + def test_no_code_block_returns_original(self): + """Test that text without code blocks is returned as-is.""" + text = "Just some plain text without code blocks." 
+ result = extract_code_from_markdown(text) + assert result == text + + def test_language_specific_fallback(self): + """Test that dart block is found even when different language requested.""" + text = "```dart\nvoid main() {}\n```" + # Requesting "python" but only dart block exists + result = extract_code_from_markdown(text, language="python") + # Should fallback to finding the dart block + assert result == "void main() {}" + + def test_multiple_code_blocks_extracts_first(self): + """Test that the first matching code block is extracted.""" + text = "```dart\nfirst()\n```\n\n```dart\nsecond()\n```" + result = extract_code_from_markdown(text, language="dart") + assert result == "first()" + + def test_multiline_code_block(self): + """Test extracting multiline code.""" + text = """```dart +import 'package:flutter/material.dart'; + +void main() => runApp(MyApp()); + +class MyApp extends StatelessWidget { + @override + Widget build(BuildContext context) { + return Container(); + } +} +```""" + result = extract_code_from_markdown(text, language="dart") + assert "import 'package:flutter/material.dart';" in result + assert "class MyApp" in result + + def test_code_block_with_surrounding_text(self): + """Test extraction when code block is embedded in explanatory text.""" + text = ( + "Here is the solution:\n\n```dart\nvoid main() {}\n```\n\nThis creates a minimal app." + ) + result = extract_code_from_markdown(text, language="dart") + assert result == "void main() {}" diff --git a/packages/dataset_config/lib/src/models/task.dart b/packages/dataset_config/lib/src/models/task.dart index 6f43641..5a2d2d1 100644 --- a/packages/dataset_config/lib/src/models/task.dart +++ b/packages/dataset_config/lib/src/models/task.dart @@ -93,7 +93,7 @@ sealed class Task with _$Task { /// /// When present, the Python runner uses this to look up a pre-built /// `@task` function (e.g. `"flutter_code_gen"` or - /// `"eval_runner.runner.tasks.flutter_code_gen"`). 
+ /// `"dash_evals.runner.tasks.flutter_code_gen"`). /// When absent, the runner hydrates directly from JSON (Mode 2 — future). @JsonKey(name: 'task_func') String? taskFunc, diff --git a/packages/dataset_config/lib/src/models/task.freezed.dart b/packages/dataset_config/lib/src/models/task.freezed.dart index fb9e212..94a4a37 100644 --- a/packages/dataset_config/lib/src/models/task.freezed.dart +++ b/packages/dataset_config/lib/src/models/task.freezed.dart @@ -55,7 +55,7 @@ mixin _$Task { /// /// When present, the Python runner uses this to look up a pre-built /// `@task` function (e.g. `"flutter_code_gen"` or -/// `"eval_runner.runner.tasks.flutter_code_gen"`). +/// `"dash_evals.runner.tasks.flutter_code_gen"`). /// When absent, the runner hydrates directly from JSON (Mode 2 — future). @JsonKey(name: 'task_func') String? get taskFunc;/// Task name. /// @@ -346,7 +346,7 @@ class _Task implements Task { /// /// When present, the Python runner uses this to look up a pre-built /// `@task` function (e.g. `"flutter_code_gen"` or -/// `"eval_runner.runner.tasks.flutter_code_gen"`). +/// `"dash_evals.runner.tasks.flutter_code_gen"`). /// When absent, the runner hydrates directly from JSON (Mode 2 — future). @override@JsonKey(name: 'task_func') final String? taskFunc; /// Task name. diff --git a/packages/devals_cli/lib/src/commands/doctor_command.dart b/packages/devals_cli/lib/src/commands/doctor_command.dart index eb174c0..c597388 100644 --- a/packages/devals_cli/lib/src/commands/doctor_command.dart +++ b/packages/devals_cli/lib/src/commands/doctor_command.dart @@ -48,7 +48,7 @@ typedef ProcessRunner = /// Command that checks whether prerequisites are installed. /// /// Similar to `flutter doctor`, this verifies the tools needed -/// for the CLI, eval_runner, and eval_explorer. +/// for the CLI, dash_evals, and eval_explorer. class DoctorCommand extends Command { DoctorCommand({ProcessRunner? processRunner}) : _runProcess = processRunner ?? 
Process.run; @@ -61,7 +61,7 @@ class DoctorCommand extends Command { @override String get description => 'Check that all prerequisites are installed for ' - 'the CLI, eval_runner, and eval_explorer.'; + 'the CLI, dash_evals, and eval_explorer.'; @override Future run() async { @@ -151,19 +151,19 @@ List buildChecks({ProcessRunner? processRunner}) { ), DoctorCheck( name: 'Python', - component: 'eval_runner', + component: 'dash_evals', isRequired: true, check: () => _checkPython(run), ), DoctorCheck( - name: 'eval_runner installed', - component: 'eval_runner', + name: 'dash_evals installed', + component: 'dash_evals', isRequired: true, - check: () => _checkEvalRunner(run), + check: () => _checkDashEvals(run), ), DoctorCheck( name: 'Podman', - component: 'eval_runner', + component: 'dash_evals', check: () => _checkPodman(run), ), DoctorCheck( @@ -179,7 +179,7 @@ List buildChecks({ProcessRunner? processRunner}) { ), DoctorCheck( name: 'API keys', - component: 'eval_runner', + component: 'dash_evals', isRequired: true, check: () => _checkApiKeys(), ), @@ -254,13 +254,13 @@ Future _checkPython(ProcessRunner run) async { return CheckResult(status: CheckStatus.ok, version: version); } -Future _checkEvalRunner(ProcessRunner run) async { +Future _checkDashEvals(ProcessRunner run) async { final output = await _tryRun(run, 'run-evals', ['--help']); if (output == null) { return const CheckResult( status: CheckStatus.error, message: 'not found', - fix: 'cd path/to/eval_runner && pip install -e .', + fix: 'cd path/to/dash_evals && pip install -e .', ); } return const CheckResult(status: CheckStatus.ok); diff --git a/packages/devals_cli/lib/src/commands/run_command.dart b/packages/devals_cli/lib/src/commands/run_command.dart index 883c036..b7e159e 100644 --- a/packages/devals_cli/lib/src/commands/run_command.dart +++ b/packages/devals_cli/lib/src/commands/run_command.dart @@ -7,7 +7,7 @@ import 'package:devals/src/dataset/filesystem_utils.dart'; import 'package:howdy/howdy.dart'; 
import 'package:path/path.dart' as p; -/// Command to run evaluations using the Python eval_runner. +/// Command to run evaluations using the Python dash_evals package. /// /// Config resolution and dry-run happen entirely in Dart. For actual runs, /// Dart writes an EvalSet JSON file, then Python reads it and calls @@ -25,7 +25,7 @@ class RunCommand extends Command { String get name => 'run'; @override - String get description => 'Run evaluations using the eval_runner.'; + String get description => 'Run evaluations using dash_evals.'; @override String get invocation => '${runner?.executableName} run '; @@ -82,8 +82,8 @@ class RunCommand extends Command { if (e.errorCode == 2) { Text.error( 'Command "run-evals" not found.\n' - 'Please install the eval_runner Python package:\n' - ' pip install -e /pkgs/eval_runner', + 'Please install the dash_evals Python package:\n' + ' pip install -e /packages/dash_evals', ); return 1; } diff --git a/packages/devals_cli/test/commands/doctor_command_test.dart b/packages/devals_cli/test/commands/doctor_command_test.dart index 0d5ec3a..15f3e75 100644 --- a/packages/devals_cli/test/commands/doctor_command_test.dart +++ b/packages/devals_cli/test/commands/doctor_command_test.dart @@ -85,7 +85,7 @@ void main() { }); }); - group('eval_runner check', () { + group('dash_evals check', () { test('succeeds when installed', () async { final checks = buildChecks( processRunner: mockProcessRunner({ From 7d9889b0e503f3d5fda180bebf6ebd21c81b0fc5 Mon Sep 17 00:00:00 2001 From: Eric Windmill Date: Wed, 11 Mar 2026 11:42:20 -0700 Subject: [PATCH 2/7] github workflow --- .github/workflows/dash_evals_module_tests.yml | 44 +++++++++++++++++++ 1 file changed, 44 insertions(+) create mode 100644 .github/workflows/dash_evals_module_tests.yml diff --git a/.github/workflows/dash_evals_module_tests.yml b/.github/workflows/dash_evals_module_tests.yml new file mode 100644 index 0000000..d8dd65c --- /dev/null +++ b/.github/workflows/dash_evals_module_tests.yml 
@@ -0,0 +1,44 @@ +name: dash_evals module - Python tests + +on: + pull_request: + paths: + - 'packages/dash_evals/**' + - '.github/workflows/dash_evals_module_tests.yml' + push: + branches: + - main + paths: + - 'packages/dash_evals/**' + - '.github/workflows/dash_evals_module_tests.yml' + +jobs: + runner-tests: + runs-on: ubuntu-latest + timeout-minutes: 15 + + steps: + - name: Checkout repository + uses: actions/checkout@v6 + + - name: Set up Python + uses: actions/setup-python@v6 + with: + python-version: '3.13' + + - name: Create virtual environment + working-directory: packages/dash_evals + run: python -m venv .venv + + - name: Install dependencies + working-directory: packages/dash_evals + run: | + source .venv/bin/activate + pip install --upgrade pip + pip install -e ".[dev]" + + - name: Run tests + working-directory: packages/dash_evals + run: | + source .venv/bin/activate + pytest -v From 92faf3a00d7c6216baf8437ad10fc18b72b2830b Mon Sep 17 00:00:00 2001 From: Eric Windmill Date: Wed, 11 Mar 2026 12:25:13 -0700 Subject: [PATCH 3/7] add python config for yardstick --- .gitignore | 1 + packages/dash_evals/pyrefly.toml | 2 +- .../.gitignore | 0 .../CHANGELOG.md | 0 .../README.md | 0 .../analysis_options.yaml | 0 .../lib/dataset_config_dart.dart} | 0 .../lib/src/config_resolver.dart | 0 .../lib/src/models/context_file.dart | 0 .../lib/src/models/context_file.freezed.dart | 0 .../lib/src/models/context_file.g.dart | 0 .../lib/src/models/dataset.dart | 0 .../lib/src/models/dataset.freezed.dart | 0 .../lib/src/models/dataset.g.dart | 0 .../lib/src/models/eval_log.dart | 0 .../lib/src/models/eval_log.freezed.dart | 0 .../lib/src/models/eval_log.g.dart | 0 .../lib/src/models/eval_set.dart | 2 +- .../lib/src/models/eval_set.freezed.dart | 0 .../lib/src/models/eval_set.g.dart | 0 .../lib/src/models/field_spec.dart | 0 .../lib/src/models/field_spec.freezed.dart | 0 .../lib/src/models/field_spec.g.dart | 0 .../lib/src/models/job.dart | 0 
.../lib/src/models/job.freezed.dart | 0 .../lib/src/models/job.g.dart | 0 .../lib/src/models/models.dart | 0 .../lib/src/models/sample.dart | 0 .../lib/src/models/sample.freezed.dart | 0 .../lib/src/models/sample.g.dart | 0 .../lib/src/models/task.dart | 2 +- .../lib/src/models/task.freezed.dart | 0 .../lib/src/models/task.g.dart | 0 .../lib/src/models/task_info.dart | 0 .../lib/src/models/task_info.freezed.dart | 0 .../lib/src/models/task_info.g.dart | 0 .../lib/src/models/variant.dart | 0 .../lib/src/models/variant.freezed.dart | 0 .../lib/src/models/variant.g.dart | 0 .../lib/src/parsed_task.dart | 0 .../lib/src/parsers/json_parser.dart | 0 .../lib/src/parsers/parser.dart | 0 .../lib/src/parsers/yaml_parser.dart | 0 .../lib/src/resolvers/eval_set_resolver.dart | 0 .../lib/src/runner_config_exception.dart | 0 .../lib/src/utils/yaml_utils.dart | 0 .../lib/src/writers/eval_set_writer.dart | 0 .../pubspec.yaml | 2 +- .../test/eval_set_resolver_test.dart | 2 +- .../test/eval_set_writer_test.dart | 2 +- .../test/json_parser_test.dart | 2 +- .../test/parsed_task_test.dart | 2 +- .../test/yaml_utils_test.dart | 2 +- packages/dataset_config_python/README.md | 17 + packages/dataset_config_python/pyproject.toml | 37 ++ .../src/dataset_config_python/__init__.py | 12 + .../dataset_config_python/models/__init__.py | 21 + .../models/context_file.py | 72 +++ .../dataset_config_python/models/dataset.py | 17 + .../dataset_config_python/models/eval_set.py | 95 +++ .../src/dataset_config_python/models/job.py | 102 +++ .../dataset_config_python/models/sample.py | 38 ++ .../src/dataset_config_python/models/task.py | 77 +++ .../dataset_config_python/models/variant.py | 36 ++ .../src/dataset_config_python/parser.py | 582 ++++++++++++++++++ .../src/dataset_config_python/resolver.py | 503 +++++++++++++++ .../src/dataset_config_python/writer.py | 29 + .../dataset_config_python/tests/__init__.py | 0 .../tests/test_config.py | 363 +++++++++++ packages/devals_cli/example/pubspec.lock | 8 +- 
.../lib/src/commands/create_job_command.dart | 2 +- .../src/commands/create_pipeline_command.dart | 2 +- .../lib/src/commands/run_command.dart | 2 +- .../devals_cli/lib/src/dataset/dry_run.dart | 2 +- packages/devals_cli/pubspec.yaml | 4 +- pubspec.yaml | 2 +- 76 files changed, 2022 insertions(+), 20 deletions(-) rename packages/{dataset_config => dataset_config_dart}/.gitignore (100%) rename packages/{dataset_config => dataset_config_dart}/CHANGELOG.md (100%) rename packages/{dataset_config => dataset_config_dart}/README.md (100%) rename packages/{dataset_config => dataset_config_dart}/analysis_options.yaml (100%) rename packages/{dataset_config/lib/dataset_config.dart => dataset_config_dart/lib/dataset_config_dart.dart} (100%) rename packages/{dataset_config => dataset_config_dart}/lib/src/config_resolver.dart (100%) rename packages/{dataset_config => dataset_config_dart}/lib/src/models/context_file.dart (100%) rename packages/{dataset_config => dataset_config_dart}/lib/src/models/context_file.freezed.dart (100%) rename packages/{dataset_config => dataset_config_dart}/lib/src/models/context_file.g.dart (100%) rename packages/{dataset_config => dataset_config_dart}/lib/src/models/dataset.dart (100%) rename packages/{dataset_config => dataset_config_dart}/lib/src/models/dataset.freezed.dart (100%) rename packages/{dataset_config => dataset_config_dart}/lib/src/models/dataset.g.dart (100%) rename packages/{dataset_config => dataset_config_dart}/lib/src/models/eval_log.dart (100%) rename packages/{dataset_config => dataset_config_dart}/lib/src/models/eval_log.freezed.dart (100%) rename packages/{dataset_config => dataset_config_dart}/lib/src/models/eval_log.g.dart (100%) rename packages/{dataset_config => dataset_config_dart}/lib/src/models/eval_set.dart (99%) rename packages/{dataset_config => dataset_config_dart}/lib/src/models/eval_set.freezed.dart (100%) rename packages/{dataset_config => dataset_config_dart}/lib/src/models/eval_set.g.dart (100%) rename 
packages/{dataset_config => dataset_config_dart}/lib/src/models/field_spec.dart (100%) rename packages/{dataset_config => dataset_config_dart}/lib/src/models/field_spec.freezed.dart (100%) rename packages/{dataset_config => dataset_config_dart}/lib/src/models/field_spec.g.dart (100%) rename packages/{dataset_config => dataset_config_dart}/lib/src/models/job.dart (100%) rename packages/{dataset_config => dataset_config_dart}/lib/src/models/job.freezed.dart (100%) rename packages/{dataset_config => dataset_config_dart}/lib/src/models/job.g.dart (100%) rename packages/{dataset_config => dataset_config_dart}/lib/src/models/models.dart (100%) rename packages/{dataset_config => dataset_config_dart}/lib/src/models/sample.dart (100%) rename packages/{dataset_config => dataset_config_dart}/lib/src/models/sample.freezed.dart (100%) rename packages/{dataset_config => dataset_config_dart}/lib/src/models/sample.g.dart (100%) rename packages/{dataset_config => dataset_config_dart}/lib/src/models/task.dart (98%) rename packages/{dataset_config => dataset_config_dart}/lib/src/models/task.freezed.dart (100%) rename packages/{dataset_config => dataset_config_dart}/lib/src/models/task.g.dart (100%) rename packages/{dataset_config => dataset_config_dart}/lib/src/models/task_info.dart (100%) rename packages/{dataset_config => dataset_config_dart}/lib/src/models/task_info.freezed.dart (100%) rename packages/{dataset_config => dataset_config_dart}/lib/src/models/task_info.g.dart (100%) rename packages/{dataset_config => dataset_config_dart}/lib/src/models/variant.dart (100%) rename packages/{dataset_config => dataset_config_dart}/lib/src/models/variant.freezed.dart (100%) rename packages/{dataset_config => dataset_config_dart}/lib/src/models/variant.g.dart (100%) rename packages/{dataset_config => dataset_config_dart}/lib/src/parsed_task.dart (100%) rename packages/{dataset_config => dataset_config_dart}/lib/src/parsers/json_parser.dart (100%) rename packages/{dataset_config => 
dataset_config_dart}/lib/src/parsers/parser.dart (100%) rename packages/{dataset_config => dataset_config_dart}/lib/src/parsers/yaml_parser.dart (100%) rename packages/{dataset_config => dataset_config_dart}/lib/src/resolvers/eval_set_resolver.dart (100%) rename packages/{dataset_config => dataset_config_dart}/lib/src/runner_config_exception.dart (100%) rename packages/{dataset_config => dataset_config_dart}/lib/src/utils/yaml_utils.dart (100%) rename packages/{dataset_config => dataset_config_dart}/lib/src/writers/eval_set_writer.dart (100%) rename packages/{dataset_config => dataset_config_dart}/pubspec.yaml (92%) rename packages/{dataset_config => dataset_config_dart}/test/eval_set_resolver_test.dart (99%) rename packages/{dataset_config => dataset_config_dart}/test/eval_set_writer_test.dart (97%) rename packages/{dataset_config => dataset_config_dart}/test/json_parser_test.dart (99%) rename packages/{dataset_config => dataset_config_dart}/test/parsed_task_test.dart (98%) rename packages/{dataset_config => dataset_config_dart}/test/yaml_utils_test.dart (98%) create mode 100644 packages/dataset_config_python/README.md create mode 100644 packages/dataset_config_python/pyproject.toml create mode 100644 packages/dataset_config_python/src/dataset_config_python/__init__.py create mode 100644 packages/dataset_config_python/src/dataset_config_python/models/__init__.py create mode 100644 packages/dataset_config_python/src/dataset_config_python/models/context_file.py create mode 100644 packages/dataset_config_python/src/dataset_config_python/models/dataset.py create mode 100644 packages/dataset_config_python/src/dataset_config_python/models/eval_set.py create mode 100644 packages/dataset_config_python/src/dataset_config_python/models/job.py create mode 100644 packages/dataset_config_python/src/dataset_config_python/models/sample.py create mode 100644 packages/dataset_config_python/src/dataset_config_python/models/task.py create mode 100644 
packages/dataset_config_python/src/dataset_config_python/models/variant.py create mode 100644 packages/dataset_config_python/src/dataset_config_python/parser.py create mode 100644 packages/dataset_config_python/src/dataset_config_python/resolver.py create mode 100644 packages/dataset_config_python/src/dataset_config_python/writer.py create mode 100644 packages/dataset_config_python/tests/__init__.py create mode 100644 packages/dataset_config_python/tests/test_config.py diff --git a/.gitignore b/.gitignore index 286021d..87eb4c9 100644 --- a/.gitignore +++ b/.gitignore @@ -15,6 +15,7 @@ coverage /docs/_build /docs/dart_docs logs/ +**/pyrefly.toml ## diff --git a/packages/dash_evals/pyrefly.toml b/packages/dash_evals/pyrefly.toml index 6370db3..d35f141 100644 --- a/packages/dash_evals/pyrefly.toml +++ b/packages/dash_evals/pyrefly.toml @@ -1,4 +1,4 @@ # Pyrefly configuration -# Tell Pyrefly to use the local venv Python interpreter +# Tell Pyrefly to use the repo-root venv Python interpreter python-interpreter = "../../.venv/bin/python" diff --git a/packages/dataset_config/.gitignore b/packages/dataset_config_dart/.gitignore similarity index 100% rename from packages/dataset_config/.gitignore rename to packages/dataset_config_dart/.gitignore diff --git a/packages/dataset_config/CHANGELOG.md b/packages/dataset_config_dart/CHANGELOG.md similarity index 100% rename from packages/dataset_config/CHANGELOG.md rename to packages/dataset_config_dart/CHANGELOG.md diff --git a/packages/dataset_config/README.md b/packages/dataset_config_dart/README.md similarity index 100% rename from packages/dataset_config/README.md rename to packages/dataset_config_dart/README.md diff --git a/packages/dataset_config/analysis_options.yaml b/packages/dataset_config_dart/analysis_options.yaml similarity index 100% rename from packages/dataset_config/analysis_options.yaml rename to packages/dataset_config_dart/analysis_options.yaml diff --git a/packages/dataset_config/lib/dataset_config.dart 
b/packages/dataset_config_dart/lib/dataset_config_dart.dart similarity index 100% rename from packages/dataset_config/lib/dataset_config.dart rename to packages/dataset_config_dart/lib/dataset_config_dart.dart diff --git a/packages/dataset_config/lib/src/config_resolver.dart b/packages/dataset_config_dart/lib/src/config_resolver.dart similarity index 100% rename from packages/dataset_config/lib/src/config_resolver.dart rename to packages/dataset_config_dart/lib/src/config_resolver.dart diff --git a/packages/dataset_config/lib/src/models/context_file.dart b/packages/dataset_config_dart/lib/src/models/context_file.dart similarity index 100% rename from packages/dataset_config/lib/src/models/context_file.dart rename to packages/dataset_config_dart/lib/src/models/context_file.dart diff --git a/packages/dataset_config/lib/src/models/context_file.freezed.dart b/packages/dataset_config_dart/lib/src/models/context_file.freezed.dart similarity index 100% rename from packages/dataset_config/lib/src/models/context_file.freezed.dart rename to packages/dataset_config_dart/lib/src/models/context_file.freezed.dart diff --git a/packages/dataset_config/lib/src/models/context_file.g.dart b/packages/dataset_config_dart/lib/src/models/context_file.g.dart similarity index 100% rename from packages/dataset_config/lib/src/models/context_file.g.dart rename to packages/dataset_config_dart/lib/src/models/context_file.g.dart diff --git a/packages/dataset_config/lib/src/models/dataset.dart b/packages/dataset_config_dart/lib/src/models/dataset.dart similarity index 100% rename from packages/dataset_config/lib/src/models/dataset.dart rename to packages/dataset_config_dart/lib/src/models/dataset.dart diff --git a/packages/dataset_config/lib/src/models/dataset.freezed.dart b/packages/dataset_config_dart/lib/src/models/dataset.freezed.dart similarity index 100% rename from packages/dataset_config/lib/src/models/dataset.freezed.dart rename to 
packages/dataset_config_dart/lib/src/models/dataset.freezed.dart diff --git a/packages/dataset_config/lib/src/models/dataset.g.dart b/packages/dataset_config_dart/lib/src/models/dataset.g.dart similarity index 100% rename from packages/dataset_config/lib/src/models/dataset.g.dart rename to packages/dataset_config_dart/lib/src/models/dataset.g.dart diff --git a/packages/dataset_config/lib/src/models/eval_log.dart b/packages/dataset_config_dart/lib/src/models/eval_log.dart similarity index 100% rename from packages/dataset_config/lib/src/models/eval_log.dart rename to packages/dataset_config_dart/lib/src/models/eval_log.dart diff --git a/packages/dataset_config/lib/src/models/eval_log.freezed.dart b/packages/dataset_config_dart/lib/src/models/eval_log.freezed.dart similarity index 100% rename from packages/dataset_config/lib/src/models/eval_log.freezed.dart rename to packages/dataset_config_dart/lib/src/models/eval_log.freezed.dart diff --git a/packages/dataset_config/lib/src/models/eval_log.g.dart b/packages/dataset_config_dart/lib/src/models/eval_log.g.dart similarity index 100% rename from packages/dataset_config/lib/src/models/eval_log.g.dart rename to packages/dataset_config_dart/lib/src/models/eval_log.g.dart diff --git a/packages/dataset_config/lib/src/models/eval_set.dart b/packages/dataset_config_dart/lib/src/models/eval_set.dart similarity index 99% rename from packages/dataset_config/lib/src/models/eval_set.dart rename to packages/dataset_config_dart/lib/src/models/eval_set.dart index 94c8ddf..ce3c6fd 100644 --- a/packages/dataset_config/lib/src/models/eval_set.dart +++ b/packages/dataset_config_dart/lib/src/models/eval_set.dart @@ -1,5 +1,5 @@ import 'package:freezed_annotation/freezed_annotation.dart'; -import 'package:dataset_config/src/models/models.dart'; +import 'package:dataset_config_dart/src/models/models.dart'; part 'eval_set.freezed.dart'; part 'eval_set.g.dart'; diff --git a/packages/dataset_config/lib/src/models/eval_set.freezed.dart 
b/packages/dataset_config_dart/lib/src/models/eval_set.freezed.dart similarity index 100% rename from packages/dataset_config/lib/src/models/eval_set.freezed.dart rename to packages/dataset_config_dart/lib/src/models/eval_set.freezed.dart diff --git a/packages/dataset_config/lib/src/models/eval_set.g.dart b/packages/dataset_config_dart/lib/src/models/eval_set.g.dart similarity index 100% rename from packages/dataset_config/lib/src/models/eval_set.g.dart rename to packages/dataset_config_dart/lib/src/models/eval_set.g.dart diff --git a/packages/dataset_config/lib/src/models/field_spec.dart b/packages/dataset_config_dart/lib/src/models/field_spec.dart similarity index 100% rename from packages/dataset_config/lib/src/models/field_spec.dart rename to packages/dataset_config_dart/lib/src/models/field_spec.dart diff --git a/packages/dataset_config/lib/src/models/field_spec.freezed.dart b/packages/dataset_config_dart/lib/src/models/field_spec.freezed.dart similarity index 100% rename from packages/dataset_config/lib/src/models/field_spec.freezed.dart rename to packages/dataset_config_dart/lib/src/models/field_spec.freezed.dart diff --git a/packages/dataset_config/lib/src/models/field_spec.g.dart b/packages/dataset_config_dart/lib/src/models/field_spec.g.dart similarity index 100% rename from packages/dataset_config/lib/src/models/field_spec.g.dart rename to packages/dataset_config_dart/lib/src/models/field_spec.g.dart diff --git a/packages/dataset_config/lib/src/models/job.dart b/packages/dataset_config_dart/lib/src/models/job.dart similarity index 100% rename from packages/dataset_config/lib/src/models/job.dart rename to packages/dataset_config_dart/lib/src/models/job.dart diff --git a/packages/dataset_config/lib/src/models/job.freezed.dart b/packages/dataset_config_dart/lib/src/models/job.freezed.dart similarity index 100% rename from packages/dataset_config/lib/src/models/job.freezed.dart rename to packages/dataset_config_dart/lib/src/models/job.freezed.dart diff --git 
a/packages/dataset_config/lib/src/models/job.g.dart b/packages/dataset_config_dart/lib/src/models/job.g.dart similarity index 100% rename from packages/dataset_config/lib/src/models/job.g.dart rename to packages/dataset_config_dart/lib/src/models/job.g.dart diff --git a/packages/dataset_config/lib/src/models/models.dart b/packages/dataset_config_dart/lib/src/models/models.dart similarity index 100% rename from packages/dataset_config/lib/src/models/models.dart rename to packages/dataset_config_dart/lib/src/models/models.dart diff --git a/packages/dataset_config/lib/src/models/sample.dart b/packages/dataset_config_dart/lib/src/models/sample.dart similarity index 100% rename from packages/dataset_config/lib/src/models/sample.dart rename to packages/dataset_config_dart/lib/src/models/sample.dart diff --git a/packages/dataset_config/lib/src/models/sample.freezed.dart b/packages/dataset_config_dart/lib/src/models/sample.freezed.dart similarity index 100% rename from packages/dataset_config/lib/src/models/sample.freezed.dart rename to packages/dataset_config_dart/lib/src/models/sample.freezed.dart diff --git a/packages/dataset_config/lib/src/models/sample.g.dart b/packages/dataset_config_dart/lib/src/models/sample.g.dart similarity index 100% rename from packages/dataset_config/lib/src/models/sample.g.dart rename to packages/dataset_config_dart/lib/src/models/sample.g.dart diff --git a/packages/dataset_config/lib/src/models/task.dart b/packages/dataset_config_dart/lib/src/models/task.dart similarity index 98% rename from packages/dataset_config/lib/src/models/task.dart rename to packages/dataset_config_dart/lib/src/models/task.dart index 5a2d2d1..ccb568b 100644 --- a/packages/dataset_config/lib/src/models/task.dart +++ b/packages/dataset_config_dart/lib/src/models/task.dart @@ -1,5 +1,5 @@ import 'package:freezed_annotation/freezed_annotation.dart'; -import 'package:dataset_config/src/models/models.dart'; +import 'package:dataset_config_dart/src/models/models.dart'; part 
'task.freezed.dart'; part 'task.g.dart'; diff --git a/packages/dataset_config/lib/src/models/task.freezed.dart b/packages/dataset_config_dart/lib/src/models/task.freezed.dart similarity index 100% rename from packages/dataset_config/lib/src/models/task.freezed.dart rename to packages/dataset_config_dart/lib/src/models/task.freezed.dart diff --git a/packages/dataset_config/lib/src/models/task.g.dart b/packages/dataset_config_dart/lib/src/models/task.g.dart similarity index 100% rename from packages/dataset_config/lib/src/models/task.g.dart rename to packages/dataset_config_dart/lib/src/models/task.g.dart diff --git a/packages/dataset_config/lib/src/models/task_info.dart b/packages/dataset_config_dart/lib/src/models/task_info.dart similarity index 100% rename from packages/dataset_config/lib/src/models/task_info.dart rename to packages/dataset_config_dart/lib/src/models/task_info.dart diff --git a/packages/dataset_config/lib/src/models/task_info.freezed.dart b/packages/dataset_config_dart/lib/src/models/task_info.freezed.dart similarity index 100% rename from packages/dataset_config/lib/src/models/task_info.freezed.dart rename to packages/dataset_config_dart/lib/src/models/task_info.freezed.dart diff --git a/packages/dataset_config/lib/src/models/task_info.g.dart b/packages/dataset_config_dart/lib/src/models/task_info.g.dart similarity index 100% rename from packages/dataset_config/lib/src/models/task_info.g.dart rename to packages/dataset_config_dart/lib/src/models/task_info.g.dart diff --git a/packages/dataset_config/lib/src/models/variant.dart b/packages/dataset_config_dart/lib/src/models/variant.dart similarity index 100% rename from packages/dataset_config/lib/src/models/variant.dart rename to packages/dataset_config_dart/lib/src/models/variant.dart diff --git a/packages/dataset_config/lib/src/models/variant.freezed.dart b/packages/dataset_config_dart/lib/src/models/variant.freezed.dart similarity index 100% rename from 
packages/dataset_config/lib/src/models/variant.freezed.dart rename to packages/dataset_config_dart/lib/src/models/variant.freezed.dart diff --git a/packages/dataset_config/lib/src/models/variant.g.dart b/packages/dataset_config_dart/lib/src/models/variant.g.dart similarity index 100% rename from packages/dataset_config/lib/src/models/variant.g.dart rename to packages/dataset_config_dart/lib/src/models/variant.g.dart diff --git a/packages/dataset_config/lib/src/parsed_task.dart b/packages/dataset_config_dart/lib/src/parsed_task.dart similarity index 100% rename from packages/dataset_config/lib/src/parsed_task.dart rename to packages/dataset_config_dart/lib/src/parsed_task.dart diff --git a/packages/dataset_config/lib/src/parsers/json_parser.dart b/packages/dataset_config_dart/lib/src/parsers/json_parser.dart similarity index 100% rename from packages/dataset_config/lib/src/parsers/json_parser.dart rename to packages/dataset_config_dart/lib/src/parsers/json_parser.dart diff --git a/packages/dataset_config/lib/src/parsers/parser.dart b/packages/dataset_config_dart/lib/src/parsers/parser.dart similarity index 100% rename from packages/dataset_config/lib/src/parsers/parser.dart rename to packages/dataset_config_dart/lib/src/parsers/parser.dart diff --git a/packages/dataset_config/lib/src/parsers/yaml_parser.dart b/packages/dataset_config_dart/lib/src/parsers/yaml_parser.dart similarity index 100% rename from packages/dataset_config/lib/src/parsers/yaml_parser.dart rename to packages/dataset_config_dart/lib/src/parsers/yaml_parser.dart diff --git a/packages/dataset_config/lib/src/resolvers/eval_set_resolver.dart b/packages/dataset_config_dart/lib/src/resolvers/eval_set_resolver.dart similarity index 100% rename from packages/dataset_config/lib/src/resolvers/eval_set_resolver.dart rename to packages/dataset_config_dart/lib/src/resolvers/eval_set_resolver.dart diff --git a/packages/dataset_config/lib/src/runner_config_exception.dart 
b/packages/dataset_config_dart/lib/src/runner_config_exception.dart similarity index 100% rename from packages/dataset_config/lib/src/runner_config_exception.dart rename to packages/dataset_config_dart/lib/src/runner_config_exception.dart diff --git a/packages/dataset_config/lib/src/utils/yaml_utils.dart b/packages/dataset_config_dart/lib/src/utils/yaml_utils.dart similarity index 100% rename from packages/dataset_config/lib/src/utils/yaml_utils.dart rename to packages/dataset_config_dart/lib/src/utils/yaml_utils.dart diff --git a/packages/dataset_config/lib/src/writers/eval_set_writer.dart b/packages/dataset_config_dart/lib/src/writers/eval_set_writer.dart similarity index 100% rename from packages/dataset_config/lib/src/writers/eval_set_writer.dart rename to packages/dataset_config_dart/lib/src/writers/eval_set_writer.dart diff --git a/packages/dataset_config/pubspec.yaml b/packages/dataset_config_dart/pubspec.yaml similarity index 92% rename from packages/dataset_config/pubspec.yaml rename to packages/dataset_config_dart/pubspec.yaml index 2404ff7..cc76a7a 100644 --- a/packages/dataset_config/pubspec.yaml +++ b/packages/dataset_config_dart/pubspec.yaml @@ -1,4 +1,4 @@ -name: dataset_config +name: dataset_config_dart description: Core library for resolving eval dataset YAML into run manifests. 
version: 0.0.1 publish_to: none diff --git a/packages/dataset_config/test/eval_set_resolver_test.dart b/packages/dataset_config_dart/test/eval_set_resolver_test.dart similarity index 99% rename from packages/dataset_config/test/eval_set_resolver_test.dart rename to packages/dataset_config_dart/test/eval_set_resolver_test.dart index 03d4c02..d982b58 100644 --- a/packages/dataset_config/test/eval_set_resolver_test.dart +++ b/packages/dataset_config_dart/test/eval_set_resolver_test.dart @@ -1,4 +1,4 @@ -import 'package:dataset_config/dataset_config.dart'; +import 'package:dataset_config_dart/dataset_config_dart.dart'; import 'package:test/test.dart'; void main() { diff --git a/packages/dataset_config/test/eval_set_writer_test.dart b/packages/dataset_config_dart/test/eval_set_writer_test.dart similarity index 97% rename from packages/dataset_config/test/eval_set_writer_test.dart rename to packages/dataset_config_dart/test/eval_set_writer_test.dart index 947e93b..2ef58e5 100644 --- a/packages/dataset_config/test/eval_set_writer_test.dart +++ b/packages/dataset_config_dart/test/eval_set_writer_test.dart @@ -1,7 +1,7 @@ import 'dart:convert'; import 'dart:io'; -import 'package:dataset_config/dataset_config.dart'; +import 'package:dataset_config_dart/dataset_config_dart.dart'; import 'package:test/test.dart'; void main() { diff --git a/packages/dataset_config/test/json_parser_test.dart b/packages/dataset_config_dart/test/json_parser_test.dart similarity index 99% rename from packages/dataset_config/test/json_parser_test.dart rename to packages/dataset_config_dart/test/json_parser_test.dart index d7d9d34..f09520c 100644 --- a/packages/dataset_config/test/json_parser_test.dart +++ b/packages/dataset_config_dart/test/json_parser_test.dart @@ -1,4 +1,4 @@ -import 'package:dataset_config/dataset_config.dart'; +import 'package:dataset_config_dart/dataset_config_dart.dart'; import 'package:test/test.dart'; void main() { diff --git 
a/packages/dataset_config/test/parsed_task_test.dart b/packages/dataset_config_dart/test/parsed_task_test.dart similarity index 98% rename from packages/dataset_config/test/parsed_task_test.dart rename to packages/dataset_config_dart/test/parsed_task_test.dart index 6f35924..4921e30 100644 --- a/packages/dataset_config/test/parsed_task_test.dart +++ b/packages/dataset_config_dart/test/parsed_task_test.dart @@ -1,4 +1,4 @@ -import 'package:dataset_config/dataset_config.dart'; +import 'package:dataset_config_dart/dataset_config_dart.dart'; import 'package:test/test.dart'; void main() { diff --git a/packages/dataset_config/test/yaml_utils_test.dart b/packages/dataset_config_dart/test/yaml_utils_test.dart similarity index 98% rename from packages/dataset_config/test/yaml_utils_test.dart rename to packages/dataset_config_dart/test/yaml_utils_test.dart index 55fe43f..ec79202 100644 --- a/packages/dataset_config/test/yaml_utils_test.dart +++ b/packages/dataset_config_dart/test/yaml_utils_test.dart @@ -1,6 +1,6 @@ import 'dart:io'; -import 'package:dataset_config/dataset_config.dart'; +import 'package:dataset_config_dart/dataset_config_dart.dart'; import 'package:test/test.dart'; import 'package:yaml/yaml.dart'; diff --git a/packages/dataset_config_python/README.md b/packages/dataset_config_python/README.md new file mode 100644 index 0000000..40b60de --- /dev/null +++ b/packages/dataset_config_python/README.md @@ -0,0 +1,17 @@ +# dataset_config_python + +Configuration resolver for [dash-evals](../dash_evals/). Reads YAML config +files (jobs, tasks, samples) and produces the EvalSet JSON that `dash_evals` +consumes. + +**No Dart SDK or Inspect AI dependency required** — install this package alone +to resolve configs from Python. 
+ +## Quick start + +```python +from dataset_config_python import resolve, write_eval_sets + +eval_sets = resolve(dataset_path="./my_dataset", job_names=["local_dev"]) +write_eval_sets(eval_sets, output_dir=".devals-tool/local_dev") +``` diff --git a/packages/dataset_config_python/pyproject.toml b/packages/dataset_config_python/pyproject.toml new file mode 100644 index 0000000..553eb3f --- /dev/null +++ b/packages/dataset_config_python/pyproject.toml @@ -0,0 +1,37 @@ +[project] +name = "dataset-config-python" +version = "0.1.0" +description = "Configuration resolver for dash-evals: reads YAML configs and produces EvalSet JSON." +authors = [{ name = "Eric Windmill", email = "eric@ericwindmill.com" }] +readme = "README.md" +requires-python = ">=3.13,<4.0.0" +dependencies = [ + "pyyaml>=6.0.3,<7.0.0", + "pydantic>=2.0.0,<3.0.0", +] + +[project.optional-dependencies] +dev = [ + "pytest>=8.0.0", + "pytest-mock>=3.12.0", +] + +[build-system] +requires = ["setuptools>=61.0"] +build-backend = "setuptools.build_meta" + +[tool.setuptools.packages.find] +where = ["src"] + +[tool.pytest.ini_options] +testpaths = ["tests"] +python_files = ["test_*.py"] +python_classes = ["Test*"] +python_functions = ["test_*"] + +[tool.ruff] +line-length = 100 + +[tool.ruff.lint] +select = ["E", "F", "W", "I"] +ignore = ["E501"] diff --git a/packages/dataset_config_python/src/dataset_config_python/__init__.py b/packages/dataset_config_python/src/dataset_config_python/__init__.py new file mode 100644 index 0000000..135b4cb --- /dev/null +++ b/packages/dataset_config_python/src/dataset_config_python/__init__.py @@ -0,0 +1,12 @@ +"""dataset_config_python — Configuration resolver for dash-evals. + +Reads YAML config files (jobs, tasks, samples) and produces the +EvalSet JSON that dash_evals consumes. + +No Dart SDK or Inspect AI dependency required. 
+""" + +from dataset_config_python.resolver import resolve +from dataset_config_python.writer import write_eval_sets + +__all__ = ["resolve", "write_eval_sets"] diff --git a/packages/dataset_config_python/src/dataset_config_python/models/__init__.py b/packages/dataset_config_python/src/dataset_config_python/models/__init__.py new file mode 100644 index 0000000..a90aaad --- /dev/null +++ b/packages/dataset_config_python/src/dataset_config_python/models/__init__.py @@ -0,0 +1,21 @@ +"""Pydantic models for dash-evals configuration.""" + +from dataset_config_python.models.context_file import ContextFile, ContextFileMetadata +from dataset_config_python.models.dataset import Dataset +from dataset_config_python.models.eval_set import EvalSet +from dataset_config_python.models.job import Job, JobTask +from dataset_config_python.models.sample import Sample +from dataset_config_python.models.task import Task +from dataset_config_python.models.variant import Variant + +__all__ = [ + "ContextFile", + "ContextFileMetadata", + "Dataset", + "EvalSet", + "Job", + "JobTask", + "Sample", + "Task", + "Variant", +] diff --git a/packages/dataset_config_python/src/dataset_config_python/models/context_file.py b/packages/dataset_config_python/src/dataset_config_python/models/context_file.py new file mode 100644 index 0000000..1dcd3d1 --- /dev/null +++ b/packages/dataset_config_python/src/dataset_config_python/models/context_file.py @@ -0,0 +1,72 @@ +"""Context file model — parsed YAML frontmatter + content.""" + +from __future__ import annotations + +import os + +import yaml +from pydantic import BaseModel + + +class ContextFileMetadata(BaseModel): + """Metadata parsed from a context file's YAML frontmatter.""" + + title: str + version: str + description: str + dart_version: str | None = None + flutter_version: str | None = None + updated: str | None = None + + +class ContextFile(BaseModel): + """A context file with parsed YAML frontmatter and markdown content.""" + + metadata: 
ContextFileMetadata + """Parsed frontmatter metadata.""" + + content: str + """File content after the frontmatter section.""" + + file_path: str + """Absolute path to the context file on disk.""" + + @staticmethod + def load(file_path: str) -> ContextFile: + """Load a context file from disk, parsing its YAML frontmatter. + + The file must begin with ``---`` and contain valid YAML frontmatter + followed by a closing ``---`` delimiter. + """ + if not os.path.isfile(file_path): + raise FileNotFoundError(f"Context file not found: {file_path}") + + with open(file_path) as f: + text = f.read() + + if not text.startswith("---"): + raise ValueError(f"Context file must have YAML frontmatter: {file_path}") + + parts = text.split("---") + if len(parts) < 3: + raise ValueError(f"Invalid frontmatter in {file_path}") + + # parts[0] is empty (before first ---), parts[1] is frontmatter, + # parts[2..] is content (rejoin in case content contains ---) + yaml_content = yaml.safe_load(parts[1]) + content = "---".join(parts[2:]).strip() + + metadata = ContextFileMetadata( + title=yaml_content["title"], + version=str(yaml_content["version"]), + description=yaml_content["description"], + dart_version=str(yaml_content["dart_version"]) if "dart_version" in yaml_content else None, + flutter_version=str(yaml_content["flutter_version"]) if "flutter_version" in yaml_content else None, + updated=str(yaml_content["updated"]) if "updated" in yaml_content else None, + ) + + return ContextFile( + metadata=metadata, + content=content, + file_path=file_path, + ) diff --git a/packages/dataset_config_python/src/dataset_config_python/models/dataset.py b/packages/dataset_config_python/src/dataset_config_python/models/dataset.py new file mode 100644 index 0000000..4892eaa --- /dev/null +++ b/packages/dataset_config_python/src/dataset_config_python/models/dataset.py @@ -0,0 +1,17 @@ +"""Dataset model — mirrors Inspect AI's Dataset/MemoryDataset.""" + +from __future__ import annotations + +from pydantic 
import BaseModel + +from dataset_config_python.models.sample import Sample + + +class Dataset(BaseModel): + """A named collection of samples.""" + + samples: list[Sample] = [] + """The sample records in this dataset.""" + + name: str = "" + """Display name for the dataset.""" diff --git a/packages/dataset_config_python/src/dataset_config_python/models/eval_set.py b/packages/dataset_config_python/src/dataset_config_python/models/eval_set.py new file mode 100644 index 0000000..c4fa2fe --- /dev/null +++ b/packages/dataset_config_python/src/dataset_config_python/models/eval_set.py @@ -0,0 +1,95 @@ +"""EvalSet model — the output shape consumed by dash_evals.""" + +from __future__ import annotations + +from typing import Any + +from pydantic import BaseModel, Field + +from dataset_config_python.models.task import Task + + +class EvalSet(BaseModel): + """Resolved evaluation set ready for JSON serialization. + + The ``.model_dump()`` output of this model matches the JSON shape + that ``dash_evals.runner.json_runner.run_from_json()`` consumes. 
+ """ + + tasks: list[Task] + """Task(s) to evaluate with inline datasets.""" + + log_dir: str + """Output path for logging results.""" + + model: list[str] | None = None + """Model(s) for evaluation.""" + + sandbox: Any | None = None + """Sandbox environment type.""" + + # Retry settings + retry_attempts: int | None = None + retry_wait: float | None = None + retry_connections: float | None = None + retry_cleanup: bool | None = None + retry_on_error: int | None = None + + # Error handling + fail_on_error: float | None = None + continue_on_fail: bool | None = None + debug_errors: bool | None = None + + # Concurrency + max_samples: int | None = None + max_tasks: int | None = None + max_subprocesses: int | None = None + max_sandboxes: int | None = None + + # Logging + log_level: str | None = None + log_level_transcript: str | None = None + log_format: str | None = None + log_samples: bool | None = None + log_realtime: bool | None = None + log_images: bool | None = None + log_buffer: int | None = None + log_shared: int | None = None + log_dir_allow_dirty: bool | None = None + + # Model config + model_base_url: str | None = None + model_args: dict[str, Any] = Field(default_factory=dict) + model_roles: dict[str, str] | None = None + task_args: dict[str, Any] = Field(default_factory=dict) + model_cost_config: dict[str, Any] | None = None + + # Sandbox + sandbox_cleanup: bool | None = None + + # Sample control + limit: Any | None = None + sample_id: Any | None = None + sample_shuffle: Any | None = None + epochs: Any | None = None + + # Misc + tags: list[str] | None = None + metadata: dict[str, Any] | None = None + trace: bool | None = None + display: str | None = None + approval: Any | None = None + solver: Any | None = None + score: bool = True + + # Limits + message_limit: int | None = None + token_limit: int | None = None + time_limit: int | None = None + working_limit: int | None = None + cost_limit: float | None = None + + # Bundling + bundle_dir: str | None = None + 
class JobTask(BaseModel):
    """Per-task configuration within a job."""

    id: str
    """Task identifier matching a task directory name."""

    include_samples: list[str] | None = None
    """Only run these sample IDs."""

    exclude_samples: list[str] | None = None
    """Exclude these sample IDs."""

    system_message: str | None = None
    """Override system message for this task."""

    @staticmethod
    def from_yaml(task_id: str, data: dict[str, Any] | None) -> JobTask:
        """Create a JobTask from the parsed YAML mapping of one inline task.

        The historical key spellings (``include-samples``,
        ``exclude-samples``, ``system_message``) are read first. The
        alternate hyphen/underscore spelling of each key is accepted as a
        fallback so that a consistently spelled file is not silently ignored.

        Args:
            task_id: Identifier of the task this entry configures.
            data: Mapping of per-task options, or ``None`` when the task is
                listed without options.

        Returns:
            A JobTask carrying *task_id* and whichever options were present.
        """
        if data is None:
            return JobTask(id=task_id)

        def option(primary: str, alternate: str) -> Any:
            # The original spelling wins; the alternate is only consulted
            # when the original key is absent or null.
            value = data.get(primary)
            return data.get(alternate) if value is None else value

        return JobTask(
            id=task_id,
            include_samples=option("include-samples", "include_samples"),
            exclude_samples=option("exclude-samples", "exclude_samples"),
            system_message=option("system_message", "system-message"),
        )
debug_errors: bool | None = None + max_samples: int | None = None + max_tasks: int | None = None + max_subprocesses: int | None = None + max_sandboxes: int | None = None + log_level: str | None = None + log_level_transcript: str | None = None + log_format: str | None = None + tags: list[str] | None = None + metadata: dict[str, Any] | None = None + trace: bool | None = None + display: str | None = None + score: bool | None = None + limit: Any | None = None + sample_id: Any | None = None + sample_shuffle: Any | None = None + epochs: Any | None = None + approval: Any | None = None + solver: Any | None = None + sandbox_cleanup: bool | None = None + model_base_url: str | None = None + model_args: dict[str, Any] | None = None + model_roles: dict[str, str] | None = None + task_args: dict[str, Any] | None = None + message_limit: int | None = None + token_limit: int | None = None + time_limit: int | None = None + working_limit: int | None = None + cost_limit: float | None = None + model_cost_config: dict[str, Any] | None = None + log_samples: bool | None = None + log_realtime: bool | None = None + log_images: bool | None = None + log_buffer: int | None = None + log_shared: int | None = None + bundle_dir: str | None = None + bundle_overwrite: bool | None = None + log_dir_allow_dirty: bool | None = None + eval_set_id: str | None = None + + # Pass-through overrides + eval_set_overrides: dict[str, Any] | None = None + task_defaults: dict[str, Any] | None = None diff --git a/packages/dataset_config_python/src/dataset_config_python/models/sample.py b/packages/dataset_config_python/src/dataset_config_python/models/sample.py new file mode 100644 index 0000000..d6fe2ce --- /dev/null +++ b/packages/dataset_config_python/src/dataset_config_python/models/sample.py @@ -0,0 +1,38 @@ +"""Sample model — mirrors Inspect AI's Sample.""" + +from __future__ import annotations + +from typing import Any + +from pydantic import BaseModel + + +class Sample(BaseModel): + """A sample for an 
evaluation task. + + Maps to Inspect AI's ``Sample`` class. + """ + + input: str + """The input to be submitted to the model.""" + + target: str = "" + """Ideal target output.""" + + id: str | None = None + """Unique identifier for the sample.""" + + choices: list[str] | None = None + """Available answer choices (multiple-choice evals only).""" + + metadata: dict[str, Any] | None = None + """Arbitrary metadata associated with the sample.""" + + sandbox: Any | None = None + """Sandbox environment type and optional config file.""" + + files: dict[str, str] | None = None + """Files that go along with the sample (copied to SandboxEnvironment).""" + + setup: str | None = None + """Setup script to run for sample.""" diff --git a/packages/dataset_config_python/src/dataset_config_python/models/task.py b/packages/dataset_config_python/src/dataset_config_python/models/task.py new file mode 100644 index 0000000..aeaf471 --- /dev/null +++ b/packages/dataset_config_python/src/dataset_config_python/models/task.py @@ -0,0 +1,77 @@ +"""Task model — mirrors Inspect AI's Task configuration.""" + +from __future__ import annotations + +from typing import Any + +from pydantic import BaseModel + +from dataset_config_python.models.dataset import Dataset + + +class Task(BaseModel): + """A single evaluation task with inline dataset. + + Maps to the task definitions in the EvalSet JSON consumed by + ``dash_evals.runner.json_runner``. + """ + + name: str = "" + """Task name (e.g. ``"dart_qa:baseline"``).""" + + task_func: str | None = None + """Task function identifier for hydration (e.g. 
``"question_answer"``).""" + + dataset: Dataset | None = None + """Inline dataset with samples.""" + + sandbox: Any | None = None + """Sandbox environment type.""" + + metadata: dict[str, Any] | None = None + """Task-level metadata (variant config, system message, etc.).""" + + model: str | None = None + """Default model for this task.""" + + config: dict[str, Any] | None = None + """Model generation config.""" + + model_roles: dict[str, str] | None = None + """Named roles for use in get_model().""" + + approval: Any | None = None + """Tool use approval policies.""" + + epochs: Any | None = None + """Epochs to repeat samples for.""" + + fail_on_error: Any | None = None + """Fail on sample errors.""" + + continue_on_fail: bool | None = None + """Continue running if fail_on_error condition is met.""" + + message_limit: int | None = None + """Limit on total messages per sample.""" + + token_limit: int | None = None + """Limit on total tokens per sample.""" + + time_limit: int | None = None + """Limit on clock time (in seconds) per sample.""" + + working_limit: int | None = None + """Limit on working time (in seconds) per sample.""" + + cost_limit: float | None = None + """Limit on total cost (in dollars) per sample.""" + + early_stopping: Any | None = None + """Early stopping callbacks.""" + + display_name: str | None = None + """Task display name.""" + + version: Any = 0 + """Version of task.""" diff --git a/packages/dataset_config_python/src/dataset_config_python/models/variant.py b/packages/dataset_config_python/src/dataset_config_python/models/variant.py new file mode 100644 index 0000000..690e675 --- /dev/null +++ b/packages/dataset_config_python/src/dataset_config_python/models/variant.py @@ -0,0 +1,36 @@ +"""Variant model — evaluation variant configuration.""" + +from __future__ import annotations + +from pydantic import BaseModel, Field + +from dataset_config_python.models.context_file import ContextFile + + +class Variant(BaseModel): + """A configuration 
class ParsedTask:
    """Plain container pairing a task's samples with its task-level settings.

    Produced by the parser and consumed by the resolver, which turns these
    intermediate objects into the final Task models.
    """

    def __init__(
        self,
        *,
        id: str,
        task_func: str,
        samples: list[Sample],
        variant: Variant | None = None,
        sandbox_type: str = "local",
        system_message: str | None = None,
        allowed_variants: list[str] | None = None,
        save_examples: bool = False,
        examples_dir: str | None = None,
        # Task-level settings
        model: str | None = None,
        config: dict[str, Any] | None = None,
        model_roles: dict[str, str] | None = None,
        sandbox: Any | None = None,
        approval: Any | None = None,
        epochs: Any | None = None,
        fail_on_error: Any | None = None,
        continue_on_fail: bool | None = None,
        message_limit: int | None = None,
        token_limit: int | None = None,
        time_limit: int | None = None,
        working_limit: int | None = None,
        cost_limit: float | None = None,
        early_stopping: Any | None = None,
        display_name: str | None = None,
        version: Any | None = None,
        metadata: dict[str, Any] | None = None,
    ):
        # Identity and payload.
        self.id = id
        self.task_func = task_func
        self.samples = samples

        # A missing (or falsy) variant falls back to a default Variant().
        self.variant = variant if variant else Variant()

        # Run-level knobs.
        self.sandbox_type = sandbox_type
        self.system_message = system_message
        self.allowed_variants = allowed_variants
        self.save_examples = save_examples
        self.examples_dir = examples_dir

        # Task-level settings, stored exactly as given.
        self.model = model
        self.config = config
        self.model_roles = model_roles
        self.sandbox = sandbox
        self.approval = approval
        self.epochs = epochs
        self.fail_on_error = fail_on_error
        self.continue_on_fail = continue_on_fail
        self.message_limit = message_limit
        self.token_limit = token_limit
        self.time_limit = time_limit
        self.working_limit = working_limit
        self.cost_limit = cost_limit
        self.early_stopping = early_stopping
        self.display_name = display_name
        self.version = version
        self.metadata = metadata

    # Sentinel meaning "argument not supplied"; lets copy_with() distinguish
    # an omitted argument from an explicit None.
    _UNSET: Any = object()

    def copy_with(
        self,
        *,
        id: str | None = _UNSET,
        task_func: str | None = _UNSET,
        samples: list[Sample] | None = _UNSET,
        variant: Variant | None = _UNSET,
        sandbox_type: str | None = _UNSET,
        system_message: str | None = _UNSET,
        allowed_variants: list[str] | None = _UNSET,
        save_examples: bool | None = _UNSET,
        examples_dir: str | None = _UNSET,
        model: str | None = _UNSET,
        config: dict[str, Any] | None = _UNSET,
        model_roles: dict[str, str] | None = _UNSET,
        sandbox: Any = _UNSET,
        approval: Any = _UNSET,
        epochs: Any = _UNSET,
        fail_on_error: Any = _UNSET,
        continue_on_fail: bool | None = _UNSET,
        message_limit: int | None = _UNSET,
        token_limit: int | None = _UNSET,
        time_limit: int | None = _UNSET,
        working_limit: int | None = _UNSET,
        cost_limit: float | None = _UNSET,
        early_stopping: Any = _UNSET,
        display_name: str | None = _UNSET,
        version: Any = _UNSET,
        metadata: dict[str, Any] | None = _UNSET,
    ) -> ParsedTask:
        """Return a new ParsedTask with the supplied fields replaced.

        Fields that are not passed keep this instance's current value.
        Passing ``None`` explicitly is treated as an override, not as
        "keep the current value".
        """
        unset = ParsedTask._UNSET

        def pick(override: Any, current: Any) -> Any:
            # Identity comparison against the sentinel, so that None, False,
            # 0 and empty containers all count as real overrides.
            if override is unset:
                return current
            return override

        return ParsedTask(
            id=pick(id, self.id),
            task_func=pick(task_func, self.task_func),
            samples=pick(samples, self.samples),
            variant=pick(variant, self.variant),
            sandbox_type=pick(sandbox_type, self.sandbox_type),
            system_message=pick(system_message, self.system_message),
            allowed_variants=pick(allowed_variants, self.allowed_variants),
            save_examples=pick(save_examples, self.save_examples),
            examples_dir=pick(examples_dir, self.examples_dir),
            model=pick(model, self.model),
            config=pick(config, self.config),
            model_roles=pick(model_roles, self.model_roles),
            sandbox=pick(sandbox, self.sandbox),
            approval=pick(approval, self.approval),
            epochs=pick(epochs, self.epochs),
            fail_on_error=pick(fail_on_error, self.fail_on_error),
            continue_on_fail=pick(continue_on_fail, self.continue_on_fail),
            message_limit=pick(message_limit, self.message_limit),
            token_limit=pick(token_limit, self.token_limit),
            time_limit=pick(time_limit, self.time_limit),
            working_limit=pick(working_limit, self.working_limit),
            cost_limit=pick(cost_limit, self.cost_limit),
            early_stopping=pick(early_stopping, self.early_stopping),
            display_name=pick(display_name, self.display_name),
            version=pick(version, self.version),
            metadata=pick(metadata, self.metadata),
        )
def parse_tasks(dataset_root: str) -> list[ParsedTask]:
    """Collect the ParsedTasks defined under ``<dataset_root>/tasks``.

    Each immediate sub-directory of ``tasks/`` that contains a ``task.yaml``
    contributes the tasks loaded from that file. Sub-directories are visited
    in sorted name order; plain files and directories without a
    ``task.yaml`` are ignored.

    Returns:
        The loaded tasks, or an empty list when ``tasks/`` does not exist.
    """
    tasks_root = os.path.join(dataset_root, "tasks")
    if not os.path.isdir(tasks_root):
        return []

    # Candidate task.yaml paths, one per sub-directory, in sorted order.
    candidate_files = (
        os.path.join(tasks_root, name, "task.yaml")
        for name in sorted(os.listdir(tasks_root))
        if os.path.isdir(os.path.join(tasks_root, name))
    )

    collected: list[ParsedTask] = []
    for task_file in candidate_files:
        if os.path.isfile(task_file):
            collected.extend(_load_task_file(task_file, dataset_root))
    return collected
data.get("tests") + system_message = data.get("system_message") + + task_workspace = _pre_resolve_to_abs(task_workspace_raw, task_dir) + task_tests = _pre_resolve_to_abs(task_tests_raw, task_dir) + + allowed_variants = data.get("allowed_variants") + + # Parse samples section + samples_raw = data.get("samples") + if not isinstance(samples_raw, dict): + raise ValueError( + f"Task '{task_id}': 'samples' must be a dict with 'inline' and/or " + f"'paths' keys, got {type(samples_raw).__name__}" + ) + samples = _load_samples_section(samples_raw, dataset_root, task_workspace, task_tests, task_dir) + + return [ + ParsedTask( + id=task_id, + task_func=task_func, + variant=Variant(), + samples=samples, + system_message=system_message, + allowed_variants=allowed_variants, + model=data.get("model"), + config=data.get("config") if isinstance(data.get("config"), dict) else None, + model_roles=data.get("model_roles") if isinstance(data.get("model_roles"), dict) else None, + sandbox=data.get("sandbox"), + approval=data.get("approval"), + epochs=data.get("epochs"), + fail_on_error=data.get("fail_on_error"), + continue_on_fail=data.get("continue_on_fail"), + message_limit=data.get("message_limit"), + token_limit=data.get("token_limit"), + time_limit=data.get("time_limit"), + working_limit=data.get("working_limit"), + cost_limit=float(data["cost_limit"]) if data.get("cost_limit") is not None else None, + early_stopping=data.get("early_stopping"), + display_name=data.get("display_name"), + version=data.get("version"), + metadata=data.get("metadata") if isinstance(data.get("metadata"), dict) else None, + ) + ] + + +# --------------------------------------------------------------------------- +# Sample loading +# --------------------------------------------------------------------------- + + +def _load_samples_section( + samples_map: dict[str, Any], + dataset_root: str, + task_workspace: Any, + task_tests: Any, + task_dir: str, +) -> list[Sample]: + """Load samples from 'paths' and 
def _resolve_sample(
    doc: dict[str, Any],
    base_dir: str,
    dataset_root: str,
    task_workspace: Any,
    task_tests: Any,
) -> Sample:
    """Resolve a single sample dict into a Sample.

    Args:
        doc: Parsed sample mapping; must contain ``id``, ``input`` and
            ``target``.
        base_dir: Directory that sample-level ``workspace`` / ``tests``
            references are resolved against.
        dataset_root: Directory that task-level ``workspace`` / ``tests``
            references are resolved against.
        task_workspace: Task-level workspace, used when the sample does not
            define its own.
        task_tests: Task-level tests reference, used when the sample does
            not define its own.

    Returns:
        A Sample whose metadata carries ``difficulty``, ``tags`` and, when
        they resolve, ``workspace``, ``tests``, ``workspace_git`` and
        ``workspace_git_ref``.

    Raises:
        ValueError: If a required field is missing from *doc*.
    """
    for field in ("id", "input", "target"):
        if field not in doc:
            raise ValueError(
                f"Sample '{doc.get('id', 'unknown')}' missing required field: {field}"
            )

    sample_workspace = doc.get("workspace")
    sample_tests = doc.get("tests")

    # A sample-level workspace overrides the task-level one.
    effective_workspace = sample_workspace if sample_workspace is not None else task_workspace

    workspace = None
    workspace_git = None
    workspace_git_ref = None

    if effective_workspace is not None:
        if isinstance(effective_workspace, dict) and "git" in effective_workspace:
            # Git workspaces are passed through as URL + optional ref.
            workspace_git = effective_workspace.get("git")
            workspace_git_ref = effective_workspace.get("ref")
        else:
            resolve_dir = base_dir if sample_workspace is not None else dataset_root
            workspace = _resolve_resource_path(effective_workspace, resolve_dir)

    tests = None
    if sample_tests is not None:
        tests = _resolve_resource_path(sample_tests, base_dir)
    elif task_tests is not None:
        tests = _resolve_resource_path(task_tests, dataset_root)

    # Normalize tags to a list.
    raw_tags = doc.get("tags")
    if isinstance(raw_tags, str):
        # Drop empty entries so "" or a trailing comma never yields a "" tag.
        tags = [tag for tag in (part.strip() for part in raw_tags.split(",")) if tag]
    elif isinstance(raw_tags, list):
        # Copy so the metadata does not alias the parsed document.
        tags = list(raw_tags)
    else:
        tags = []

    # Build metadata: user-supplied metadata first, derived keys on top.
    meta: dict[str, Any] = {**(doc.get("metadata") or {})}
    meta["difficulty"] = doc.get("difficulty", "medium")
    meta["tags"] = tags
    if workspace is not None:
        meta["workspace"] = workspace
    if tests is not None:
        meta["tests"] = tests
    if workspace_git is not None:
        meta["workspace_git"] = workspace_git
    if workspace_git_ref is not None:
        meta["workspace_git_ref"] = workspace_git_ref

    return Sample(
        id=doc["id"],
        input=doc["input"],
        target=doc["target"],
        metadata=meta,
        choices=doc.get("choices"),
        sandbox=doc.get("sandbox"),
        files=doc.get("files"),
        setup=doc.get("setup"),
    )
+ raise FileNotFoundError(f"Job file not found: {job_path}") + + data = _read_yaml_file(job_path) + + logs_dir = data.get("logs_dir") or _DEFAULT_LOGS_DIR + log_dir = _resolve_log_dir(logs_dir, dataset_root) + + sandbox_type = data.get("sandbox_type") or "local" + max_connections = data.get("max_connections") or 10 + + # Parse task filters + task_paths = None + tasks = None + tasks_raw = data.get("tasks") + if isinstance(tasks_raw, dict): + task_paths = tasks_raw.get("paths") + inline_tasks = tasks_raw.get("inline") + if isinstance(inline_tasks, dict): + tasks = {} + for tid, tdata in inline_tasks.items(): + tasks[tid] = JobTask.from_yaml(tid, tdata) + + # Parse variants + variants = None + variants_raw = data.get("variants") + if isinstance(variants_raw, dict): + variants = {} + for key, value in variants_raw.items(): + if isinstance(value, dict): + variants[str(key)] = dict(value) + else: + variants[str(key)] = {} + + return Job( + log_dir=log_dir, + sandbox_type=sandbox_type, + max_connections=max_connections, + models=data.get("models"), + variants=variants, + task_paths=task_paths, + tasks=tasks, + save_examples=data.get("save_examples") is True, + retry_attempts=data.get("retry_attempts"), + max_retries=data.get("max_retries"), + retry_wait=float(data["retry_wait"]) if data.get("retry_wait") is not None else None, + retry_connections=( + float(data["retry_connections"]) if data.get("retry_connections") is not None else None + ), + retry_cleanup=data.get("retry_cleanup"), + fail_on_error=( + float(data["fail_on_error"]) if data.get("fail_on_error") is not None else None + ), + continue_on_fail=data.get("continue_on_fail"), + retry_on_error=data.get("retry_on_error"), + debug_errors=data.get("debug_errors"), + max_samples=data.get("max_samples"), + max_tasks=data.get("max_tasks"), + max_subprocesses=data.get("max_subprocesses"), + max_sandboxes=data.get("max_sandboxes"), + log_level=data.get("log_level"), + log_level_transcript=data.get("log_level_transcript"), 
def find_job_file(dataset_root: str, job: str) -> str:
    """Find a job file by name or path.

    Lookup order:

    1. If *job* looks like a path (contains ``/`` or ends with ``.yaml``) it
       is used as given when absolute, otherwise resolved relative to
       *dataset_root*. A bare file name such as ``nightly.yaml`` that is not
       found there is also tried inside ``jobs/``.
    2. Otherwise *job* is a job name looked up in ``<dataset_root>/jobs`` as
       ``<job>.yaml``, then ``<job>`` verbatim, then ``<job>.yml``.

    Args:
        dataset_root: Dataset root directory (the parent of ``jobs/``).
        job: Job name or path to a job file.

    Returns:
        The normalized path of the job file.

    Raises:
        FileNotFoundError: If no matching file exists, or a job name is
            given but the ``jobs/`` directory is missing.
    """
    jobs_dir = os.path.join(dataset_root, "jobs")

    if "/" in job or job.endswith(".yaml"):
        job_path = job if os.path.isabs(job) else os.path.join(dataset_root, job)
        if os.path.isfile(job_path):
            return os.path.normpath(job_path)
        if "/" not in job:
            # Bare "name.yaml": fall back to the jobs/ directory before
            # giving up, since that is where job files normally live.
            in_jobs_dir = os.path.join(jobs_dir, job)
            if os.path.isfile(in_jobs_dir):
                return os.path.normpath(in_jobs_dir)
        raise FileNotFoundError(f"Job file not found: {job_path}")

    if not os.path.isdir(jobs_dir):
        raise FileNotFoundError(
            "Jobs directory not found. "
            "Create it or specify a full path to the job file."
        )

    # ".yaml" and the verbatim name keep their original precedence; ".yml"
    # is tried last so existing lookups resolve exactly as before.
    for file_name in (f"{job}.yaml", job, f"{job}.yml"):
        candidate = os.path.join(jobs_dir, file_name)
        if os.path.isfile(candidate):
            return os.path.normpath(candidate)

    available = [
        os.path.splitext(f)[0]
        for f in sorted(os.listdir(jobs_dir))
        if f.endswith((".yaml", ".yml"))
    ]
    raise FileNotFoundError(
        f"Job '{job}' not found in {jobs_dir}. "
        f"Available jobs: {available or '(none)'}"
    )
+DEFAULT_MODELS: list[str] = [ + "anthropic/claude-haiku-4-5", + "anthropic/claude-sonnet-4-5", + "anthropic/claude-opus-4-6", + "google/gemini-2.5-flash", + "google/gemini-3-pro-preview", + "google/gemini-3-flash-preview", + "openai/gpt-5-mini", + "openai/gpt-5-nano", + "openai/gpt-5", + "openai/gpt-5-pro", +] + +# Available sandbox configurations. +SANDBOX_REGISTRY: dict[str, dict[str, str]] = { + "podman": {"name": "podman", "path": "./sandboxes/podman/compose.yaml"}, + "podman-beta": {"name": "podman", "path": "./sandboxes/podman/compose-beta.yaml"}, + "podman-main": {"name": "podman", "path": "./sandboxes/podman/compose-main.yaml"}, +} + +# Maps Flutter SDK channel names to sandbox registry keys. +SDK_CHANNELS: dict[str, str] = { + "stable": "podman", + "beta": "podman-beta", + "main": "podman-main", +} + + +def _is_glob(pattern: str) -> bool: + return "*" in pattern or "?" in pattern or "[" in pattern + + +def resolve( + dataset_path: str, + job_names: list[str], +) -> list[EvalSet]: + """Resolve dataset + job(s) into EvalSet objects. + + This is the main public API of the package. + + Args: + dataset_path: Root directory containing ``tasks/`` and ``jobs/``. + job_names: Job names (looked up in ``jobs/``) or paths. + + Returns: + A list of EvalSet objects ready for JSON serialization. 
+ """ + task_configs = parse_tasks(dataset_path) + results: list[EvalSet] = [] + + for job_name in job_names: + job_path = find_job_file(dataset_path, job_name) + job = parse_job(job_path, dataset_path) + results.extend(_resolve_job(task_configs, job, dataset_path)) + + return results + + +def _resolve_job( + dataset_tasks: list[ParsedTask], + job: Any, + dataset_root: str, +) -> list[EvalSet]: + """Resolve task configs and job into EvalSet objects.""" + models = job.models if job.models else list(DEFAULT_MODELS) + sandbox_type_str = job.sandbox_type + + expanded_tasks = _expand_task_configs(dataset_tasks, job, sandbox_type_str, dataset_root) + + # Group by flutter channel + groups: dict[str | None, list[ParsedTask]] = {} + for tc in expanded_tasks: + key = tc.variant.flutter_channel + groups.setdefault(key, []).append(tc) + + return [ + _build_eval_set( + task_configs=group, + log_dir=job.log_dir, + models=models, + sandbox=_resolve_sandbox(dataset_root, job, flutter_channel=channel), + job=job, + ) + for channel, group in groups.items() + ] + + +# --------------------------------------------------------------------------- +# EvalSet building +# --------------------------------------------------------------------------- + + +def _build_eval_set( + *, + task_configs: list[ParsedTask], + log_dir: str, + models: list[str], + sandbox: Any, + job: Any, +) -> EvalSet: + """Build an EvalSet from resolved ParsedTasks.""" + inspect_tasks: list[Task] = [] + is_container = job.sandbox_type and job.sandbox_type != "local" + task_defaults = job.task_defaults or {} + + for tc in task_configs: + # Enrich each sample with task-level metadata + inspect_samples: list[Sample] = [] + for sample in tc.samples: + enriched: dict[str, Any] = {**(sample.metadata or {})} + + if tc.save_examples: + enriched["save_examples"] = True + if tc.examples_dir is not None: + enriched["examples_dir"] = tc.examples_dir + enriched["task_variant"] = f"{tc.id}:{tc.variant.name}" + + # Build files + setup 
for sandbox provisioning + files = dict(sample.files) if sample.files else None + setup = sample.setup + workspace = (sample.metadata or {}).get("workspace") + workspace_git = (sample.metadata or {}).get("workspace_git") + workspace_git_ref = (sample.metadata or {}).get("workspace_git_ref") + + if workspace is not None and is_container: + files = {**(files or {}), "/workspace": workspace} + setup = setup or "cd /workspace && flutter pub get" + enriched["workspace"] = "/workspace" + if workspace_git is not None: + enriched["workspace_git"] = workspace_git + if workspace_git_ref is not None: + enriched["workspace_git_ref"] = workspace_git_ref + + inspect_samples.append( + Sample( + id=sample.id, + input=sample.input, + target=sample.target, + metadata=enriched, + choices=sample.choices, + sandbox=sample.sandbox, + files=files, + setup=setup, + ) + ) + + dataset = Dataset( + samples=inspect_samples, + name=f"{tc.id}:{tc.variant.name}", + ) + + # Task metadata (variant config, system message, etc.) 
+ task_metadata: dict[str, Any] = {"variant": tc.variant.name} + if tc.variant.context_files: + task_metadata["variant_config"] = { + "context_files": [ + { + "title": cf.metadata.title, + "version": cf.metadata.version, + "content": cf.content, + } + for cf in tc.variant.context_files + ], + "mcp_servers": tc.variant.mcp_servers, + "skill_paths": tc.variant.skill_paths, + } + elif tc.variant.mcp_servers or tc.variant.skill_paths: + task_metadata["variant_config"] = { + "mcp_servers": tc.variant.mcp_servers, + "skill_paths": tc.variant.skill_paths, + } + if tc.system_message is not None: + task_metadata["system_message"] = tc.system_message + if tc.save_examples: + task_metadata["save_examples"] = True + if tc.examples_dir is not None: + task_metadata["examples_dir"] = tc.examples_dir + if tc.metadata: + task_metadata.update(tc.metadata) + + # Determine sandbox for this task + task_sandbox = None + if tc.sandbox is not None: + task_sandbox = tc.sandbox + elif tc.sandbox_type and tc.sandbox_type != "local": + task_sandbox = _serialize_sandbox(sandbox) + + # Resolve task-level settings with precedence: + # task.yaml > task_defaults > hardcoded defaults + resolved_time_limit = ( + tc.time_limit + or task_defaults.get("time_limit") + or (300 if job.sandbox_type != "local" else None) + ) + + inspect_tasks.append( + Task( + name=f"{tc.id}:{tc.variant.name}", + task_func=tc.task_func, + dataset=dataset, + sandbox=task_sandbox, + metadata=task_metadata, + model=tc.model or task_defaults.get("model"), + config=tc.config or task_defaults.get("config"), + model_roles=tc.model_roles or task_defaults.get("model_roles"), + approval=tc.approval or task_defaults.get("approval"), + epochs=tc.epochs or task_defaults.get("epochs"), + fail_on_error=tc.fail_on_error or task_defaults.get("fail_on_error"), + continue_on_fail=tc.continue_on_fail if tc.continue_on_fail is not None else task_defaults.get("continue_on_fail"), + message_limit=tc.message_limit or 
task_defaults.get("message_limit"), + token_limit=tc.token_limit or task_defaults.get("token_limit"), + time_limit=resolved_time_limit, + working_limit=tc.working_limit or task_defaults.get("working_limit"), + cost_limit=tc.cost_limit if tc.cost_limit is not None else ( + float(task_defaults["cost_limit"]) if task_defaults.get("cost_limit") is not None else None + ), + early_stopping=tc.early_stopping or task_defaults.get("early_stopping"), + display_name=tc.display_name or task_defaults.get("display_name"), + version=tc.version if tc.version is not None else (task_defaults.get("version") or 0), + ) + ) + + # Build EvalSet with all job-level parameters + overrides = job.eval_set_overrides or {} + + return EvalSet( + tasks=inspect_tasks, + log_dir=log_dir, + model=models, + sandbox=_serialize_sandbox(sandbox), + # Retry + retry_attempts=job.retry_attempts or overrides.get("retry_attempts") or 10, + retry_wait=job.retry_wait or overrides.get("retry_wait") or 60.0, + retry_connections=job.retry_connections or overrides.get("retry_connections") or 0.5, + retry_cleanup=job.retry_cleanup if job.retry_cleanup is not None else overrides.get("retry_cleanup"), + retry_on_error=job.retry_on_error or job.max_retries or overrides.get("retry_on_error"), + # Error handling + fail_on_error=job.fail_on_error if job.fail_on_error is not None else (overrides.get("fail_on_error") or 0.05), + continue_on_fail=job.continue_on_fail if job.continue_on_fail is not None else overrides.get("continue_on_fail"), + debug_errors=job.debug_errors if job.debug_errors is not None else overrides.get("debug_errors"), + # Concurrency + max_samples=job.max_samples or overrides.get("max_samples"), + max_tasks=job.max_tasks or overrides.get("max_tasks"), + max_subprocesses=job.max_subprocesses or overrides.get("max_subprocesses"), + max_sandboxes=job.max_sandboxes or overrides.get("max_sandboxes"), + # Logging + log_level=job.log_level or overrides.get("log_level") or "info", + 
log_level_transcript=job.log_level_transcript or overrides.get("log_level_transcript"), + log_format=job.log_format or overrides.get("log_format") or "json", + log_samples=job.log_samples if job.log_samples is not None else overrides.get("log_samples"), + log_realtime=job.log_realtime if job.log_realtime is not None else overrides.get("log_realtime"), + log_images=job.log_images if job.log_images is not None else overrides.get("log_images"), + log_buffer=job.log_buffer or overrides.get("log_buffer"), + log_shared=job.log_shared or overrides.get("log_shared"), + log_dir_allow_dirty=job.log_dir_allow_dirty if job.log_dir_allow_dirty is not None else overrides.get("log_dir_allow_dirty"), + # Model config + model_base_url=job.model_base_url or overrides.get("model_base_url"), + model_args=job.model_args or overrides.get("model_args") or {}, + model_roles=job.model_roles or overrides.get("model_roles"), + task_args=job.task_args or overrides.get("task_args") or {}, + model_cost_config=job.model_cost_config or overrides.get("model_cost_config"), + # Sandbox + sandbox_cleanup=job.sandbox_cleanup if job.sandbox_cleanup is not None else overrides.get("sandbox_cleanup"), + # Sample control + limit=job.limit or overrides.get("limit"), + sample_id=job.sample_id or overrides.get("sample_id"), + sample_shuffle=job.sample_shuffle or overrides.get("sample_shuffle"), + epochs=job.epochs or overrides.get("epochs"), + # Misc + tags=job.tags or overrides.get("tags"), + metadata=job.metadata or overrides.get("metadata"), + trace=job.trace if job.trace is not None else overrides.get("trace"), + display=job.display or overrides.get("display"), + approval=job.approval or overrides.get("approval"), + solver=job.solver or overrides.get("solver"), + score=job.score if job.score is not None else (overrides.get("score") if overrides.get("score") is not None else True), + # Limits + message_limit=job.message_limit or overrides.get("message_limit"), + token_limit=job.token_limit or 
overrides.get("token_limit"), + time_limit=job.time_limit or overrides.get("time_limit"), + working_limit=job.working_limit or overrides.get("working_limit"), + cost_limit=job.cost_limit if job.cost_limit is not None else ( + float(overrides["cost_limit"]) if overrides.get("cost_limit") is not None else None + ), + # Bundling + bundle_dir=job.bundle_dir or overrides.get("bundle_dir"), + bundle_overwrite=job.bundle_overwrite if job.bundle_overwrite is not None else (overrides.get("bundle_overwrite") or False), + eval_set_id=job.eval_set_id or overrides.get("eval_set_id"), + ) + + +# --------------------------------------------------------------------------- +# Model resolution +# --------------------------------------------------------------------------- + + +def _resolve_models(job: Any) -> list[str]: + if job.models: + return job.models + return list(DEFAULT_MODELS) + + +# --------------------------------------------------------------------------- +# Sandbox resolution +# --------------------------------------------------------------------------- + + +def _resolve_sandbox( + dataset_root: str, + job: Any, + *, + flutter_channel: str | None = None, +) -> Any: + """Resolve sandbox spec for a given config.""" + sandbox_type = job.sandbox_type + if not sandbox_type or sandbox_type == "local": + return "local" + + # Channel override + if flutter_channel and flutter_channel in SDK_CHANNELS: + registry_key = SDK_CHANNELS[flutter_channel] + if registry_key in SANDBOX_REGISTRY: + defn = SANDBOX_REGISTRY[registry_key] + sandbox_path = defn["path"] + if not os.path.isabs(sandbox_path): + sandbox_path = os.path.normpath(os.path.join(dataset_root, sandbox_path)) + return {"type": defn["name"], "path": sandbox_path} + + # Named sandbox from registry + if sandbox_type in SANDBOX_REGISTRY: + defn = SANDBOX_REGISTRY[sandbox_type] + sandbox_path = defn["path"] + if not os.path.isabs(sandbox_path): + sandbox_path = os.path.normpath(os.path.join(dataset_root, sandbox_path)) + return 
{"type": defn["name"], "path": sandbox_path} + + return "local" + + +# --------------------------------------------------------------------------- +# Task Ɨ variant expansion +# --------------------------------------------------------------------------- + + +def _expand_task_configs( + dataset_tasks: list[ParsedTask], + job: Any, + sandbox_type: str, + dataset_root: str, +) -> list[ParsedTask]: + """Expand task Ɨ variant combinations.""" + job_variants = job.variants or {"baseline": {}} + expanded: list[ParsedTask] = [] + + for tc in dataset_tasks: + task_id = tc.id + + # Filter by job.tasks + if job.tasks is not None and task_id not in job.tasks: + continue + + # Determine effective variants (intersection) + effective_variants: dict[str, dict[str, Any]] = {} + for vname, vdef in job_variants.items(): + if tc.allowed_variants is None or vname in tc.allowed_variants: + effective_variants[vname] = vdef + + # Get job-level task overrides + job_task = job.tasks.get(task_id) if job.tasks else None + + # Apply sample filtering + samples = tc.samples + if job_task is not None: + if job_task.include_samples: + samples = [s for s in samples if s.id in job_task.include_samples] + if job_task.exclude_samples: + samples = [s for s in samples if s.id not in job_task.exclude_samples] + + # Apply system_message override + system_message = tc.system_message + if job_task and job_task.system_message is not None: + system_message = job_task.system_message + + # Create one ParsedTask per effective variant + for vname, vdef in effective_variants.items(): + variant = _resolve_variant(vname, vdef, dataset_root) + + examples_dir = None + if job.save_examples: + examples_dir = os.path.join(job.log_dir, "examples") + + expanded.append( + tc.copy_with( + samples=samples, + variant=variant, + sandbox_type=sandbox_type, + system_message=system_message, + allowed_variants=None, + save_examples=job.save_examples, + examples_dir=examples_dir, + ) + ) + + return expanded + + +# 
--------------------------------------------------------------------------- +# Variant resolution +# --------------------------------------------------------------------------- + + +def _resolve_variant( + name: str, + vdef: dict[str, Any], + dataset_root: str, +) -> Variant: + """Resolve a variant dict into a fully-resolved Variant.""" + if not vdef: + return Variant(name=name) + + # Load context files (with glob support) + context_files: list[ContextFile] = [] + cf_paths: list[str] = vdef.get("context_files") or [] + for cf_path in cf_paths: + if _is_glob(cf_path): + full_pattern = os.path.join(dataset_root, cf_path) + matched = sorted( + f + for f in globmod.glob(full_pattern, recursive=True) + if os.path.isfile(f) and (f.endswith(".yaml") or f.endswith(".yml") or f.endswith(".md")) + ) + if not matched: + raise FileNotFoundError(f"No context files matched pattern: {cf_path}") + for f in matched: + context_files.append(ContextFile.load(f)) + else: + full_path = os.path.normpath(os.path.join(dataset_root, cf_path)) + context_files.append(ContextFile.load(full_path)) + + # Resolve skill paths (with glob support) + skill_paths: list[str] = [] + raw_skills: list[str] = vdef.get("skills") or vdef.get("skill_paths") or [] + for skill_path_str in raw_skills: + if _is_glob(skill_path_str): + full_pattern = os.path.join(dataset_root, skill_path_str) + matched_dirs = sorted( + d + for d in globmod.glob(full_pattern, recursive=True) + if os.path.isdir(d) + ) + valid_dirs = [d for d in matched_dirs if os.path.isfile(os.path.join(d, "SKILL.md"))] + if not valid_dirs: + raise FileNotFoundError(f"No skill directories matched pattern: {skill_path_str}") + skill_paths.extend(valid_dirs) + else: + skill_dir = os.path.normpath(os.path.join(dataset_root, skill_path_str)) + if not os.path.isdir(skill_dir): + raise FileNotFoundError(f"Skill directory not found: {skill_dir}") + if not os.path.isfile(os.path.join(skill_dir, "SKILL.md")): + raise FileNotFoundError( + f"SKILL.md not 
found in {skill_dir}. " + "Each skill directory must contain a SKILL.md file." + ) + skill_paths.append(skill_dir) + + return Variant( + name=name, + context_files=context_files, + mcp_servers=vdef.get("mcp_servers") or [], + skill_paths=skill_paths, + flutter_channel=vdef.get("flutter_channel"), + ) + + +# --------------------------------------------------------------------------- +# Serialization helpers +# --------------------------------------------------------------------------- + + +def _serialize_sandbox(sandbox: Any) -> Any: + """Serialize sandbox to eval_set()-compatible format.""" + if isinstance(sandbox, str): + return None if sandbox == "local" else sandbox + if isinstance(sandbox, dict): + return [sandbox["type"], sandbox["path"]] + return None diff --git a/packages/dataset_config_python/src/dataset_config_python/writer.py b/packages/dataset_config_python/src/dataset_config_python/writer.py new file mode 100644 index 0000000..fc40128 --- /dev/null +++ b/packages/dataset_config_python/src/dataset_config_python/writer.py @@ -0,0 +1,29 @@ +"""Writer — serializes EvalSet objects to JSON files.""" + +from __future__ import annotations + +import json +import os + +from dataset_config_python.models.eval_set import EvalSet + + +def write_eval_sets(eval_sets: list[EvalSet], output_dir: str) -> str: + """Write EvalSet JSON for the given resolved configs. + + Files are written to *output_dir*. Returns the path to the JSON file. + + Single config → single JSON object; multiple → JSON array. 
+ """ + os.makedirs(output_dir, exist_ok=True) + json_path = os.path.join(output_dir, "eval_set.json") + + if len(eval_sets) == 1: + json_content = eval_sets[0].model_dump(exclude_none=True) + else: + json_content = [es.model_dump(exclude_none=True) for es in eval_sets] + + with open(json_path, "w") as f: + json.dump(json_content, f, indent=2) + + return json_path diff --git a/packages/dataset_config_python/tests/__init__.py b/packages/dataset_config_python/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/packages/dataset_config_python/tests/test_config.py b/packages/dataset_config_python/tests/test_config.py new file mode 100644 index 0000000..b79f9a9 --- /dev/null +++ b/packages/dataset_config_python/tests/test_config.py @@ -0,0 +1,363 @@ +"""Tests for the dataset_config_python package.""" + +from __future__ import annotations + +import json +import os + +import pytest + +from dataset_config_python import resolve, write_eval_sets +from dataset_config_python.models import ( + ContextFile, + Dataset, + EvalSet, + JobTask, + Sample, + Task, + Variant, +) +from dataset_config_python.parser import find_job_file, parse_job, parse_tasks + +# --------------------------------------------------------------------------- +# Fixtures: create a minimal dataset directory structure +# --------------------------------------------------------------------------- + + +@pytest.fixture +def dataset_dir(tmp_path): + """Create a minimal dataset directory with tasks and jobs.""" + # tasks/dart_qa/task.yaml + task_dir = tmp_path / "tasks" / "dart_qa" + task_dir.mkdir(parents=True) + task_yaml = task_dir / "task.yaml" + task_yaml.write_text( + """ +id: dart_qa +func: question_answer +system_message: "You are an expert." +samples: + inline: + - id: sample_1 + input: "What is Dart?" + target: "A programming language." + difficulty: easy + - id: sample_2 + input: "What is Flutter?" + target: "A UI framework." 
+ difficulty: medium + tags: ui, framework +""" + ) + + # tasks/code_gen/task.yaml + code_gen_dir = tmp_path / "tasks" / "code_gen" + code_gen_dir.mkdir(parents=True) + code_gen_yaml = code_gen_dir / "task.yaml" + code_gen_yaml.write_text( + """ +id: code_gen +func: flutter_code_gen +time_limit: 600 +allowed_variants: + - baseline + - context_only +samples: + inline: + - id: sample_1 + input: "Create a counter app." + target: "A working counter app." +""" + ) + + # jobs/local_dev.yaml + jobs_dir = tmp_path / "jobs" + jobs_dir.mkdir() + job_yaml = jobs_dir / "local_dev.yaml" + job_yaml.write_text( + """ +logs_dir: ./logs +sandbox_type: local +max_connections: 5 +models: + - google/gemini-2.5-flash +variants: + baseline: {} + context_only: + context_files: [] +""" + ) + + return tmp_path + + +@pytest.fixture +def dataset_dir_with_sample_files(tmp_path): + """Create a dataset directory with external sample files.""" + task_dir = tmp_path / "tasks" / "qa" + task_dir.mkdir(parents=True) + + # External sample file + samples_dir = task_dir / "samples" + samples_dir.mkdir() + sample_file = samples_dir / "basics.yaml" + sample_file.write_text( + """ +id: basic_1 +input: "Explain null safety." +target: "Null safety prevents null pointer exceptions." +--- +id: basic_2 +input: "What are isolates?" +target: "Isolates are Dart's concurrency model." 
+""" + ) + + task_yaml = task_dir / "task.yaml" + task_yaml.write_text( + """ +id: qa +func: question_answer +samples: + paths: + - samples/basics.yaml +""" + ) + + jobs_dir = tmp_path / "jobs" + jobs_dir.mkdir() + (jobs_dir / "default.yaml").write_text( + """ +logs_dir: ./logs +""" + ) + + return tmp_path + + +# --------------------------------------------------------------------------- +# Model tests +# --------------------------------------------------------------------------- + + +class TestModels: + def test_sample_creation(self): + s = Sample(input="test", target="expected", id="s1") + assert s.input == "test" + assert s.target == "expected" + assert s.id == "s1" + + def test_sample_defaults(self): + s = Sample(input="test") + assert s.target == "" + assert s.id is None + assert s.metadata is None + + def test_dataset_creation(self): + samples = [Sample(input="a", target="b", id="1")] + ds = Dataset(samples=samples, name="test_ds") + assert len(ds.samples) == 1 + assert ds.name == "test_ds" + + def test_variant_defaults(self): + v = Variant() + assert v.name == "baseline" + assert v.context_files == [] + assert v.mcp_servers == [] + assert v.skill_paths == [] + assert v.flutter_channel is None + + def test_job_task_from_yaml_none(self): + jt = JobTask.from_yaml("my_task", None) + assert jt.id == "my_task" + assert jt.include_samples is None + + def test_job_task_from_yaml_with_data(self): + jt = JobTask.from_yaml("my_task", {"include-samples": ["s1", "s2"]}) + assert jt.include_samples == ["s1", "s2"] + + def test_eval_set_serialization(self): + es = EvalSet( + tasks=[Task(name="test:baseline", task_func="qa")], + log_dir="/tmp/logs", + model=["google/gemini-2.5-flash"], + ) + data = es.model_dump(exclude_none=True) + assert data["log_dir"] == "/tmp/logs" + assert len(data["tasks"]) == 1 + assert data["tasks"][0]["name"] == "test:baseline" + + +# --------------------------------------------------------------------------- +# Parser tests +# 
--------------------------------------------------------------------------- + + +class TestParser: + def test_parse_tasks(self, dataset_dir): + tasks = parse_tasks(str(dataset_dir)) + assert len(tasks) == 2 + task_ids = {t.id for t in tasks} + assert "dart_qa" in task_ids + assert "code_gen" in task_ids + + def test_parse_tasks_samples(self, dataset_dir): + tasks = parse_tasks(str(dataset_dir)) + dart_qa = next(t for t in tasks if t.id == "dart_qa") + assert len(dart_qa.samples) == 2 + assert dart_qa.samples[0].id == "sample_1" + + def test_parse_tasks_metadata(self, dataset_dir): + tasks = parse_tasks(str(dataset_dir)) + dart_qa = next(t for t in tasks if t.id == "dart_qa") + # Check tags normalization + s2 = next(s for s in dart_qa.samples if s.id == "sample_2") + assert s2.metadata is not None + assert s2.metadata["tags"] == ["ui", "framework"] + assert s2.metadata["difficulty"] == "medium" + + def test_parse_tasks_allowed_variants(self, dataset_dir): + tasks = parse_tasks(str(dataset_dir)) + code_gen = next(t for t in tasks if t.id == "code_gen") + assert code_gen.allowed_variants == ["baseline", "context_only"] + + def test_parse_tasks_time_limit(self, dataset_dir): + tasks = parse_tasks(str(dataset_dir)) + code_gen = next(t for t in tasks if t.id == "code_gen") + assert code_gen.time_limit == 600 + + def test_parse_job(self, dataset_dir): + job_path = os.path.join(str(dataset_dir), "jobs", "local_dev.yaml") + job = parse_job(job_path, str(dataset_dir)) + assert job.sandbox_type == "local" + assert job.max_connections == 5 + assert job.models == ["google/gemini-2.5-flash"] + + def test_find_job_file(self, dataset_dir): + path = find_job_file(str(dataset_dir), "local_dev") + assert path.endswith("local_dev.yaml") + + def test_find_job_file_not_found(self, dataset_dir): + with pytest.raises(FileNotFoundError): + find_job_file(str(dataset_dir), "nonexistent") + + def test_parse_tasks_with_sample_files(self, dataset_dir_with_sample_files): + """Test parsing tasks 
with external sample files (multi-doc YAML).""" + tasks = parse_tasks(str(dataset_dir_with_sample_files)) + assert len(tasks) == 1 + qa = tasks[0] + assert qa.id == "qa" + assert len(qa.samples) == 2 + assert qa.samples[0].id == "basic_1" + assert qa.samples[1].id == "basic_2" + + def test_parse_tasks_empty_dir(self, tmp_path): + tasks = parse_tasks(str(tmp_path)) + assert tasks == [] + + +# --------------------------------------------------------------------------- +# Resolver tests +# --------------------------------------------------------------------------- + + +class TestResolver: + def test_resolve_basic(self, dataset_dir): + eval_sets = resolve(dataset_path=str(dataset_dir), job_names=["local_dev"]) + assert len(eval_sets) == 1 + es = eval_sets[0] + assert es.model == ["google/gemini-2.5-flash"] + assert es.log_level == "info" + + def test_resolve_task_variant_expansion(self, dataset_dir): + eval_sets = resolve(dataset_path=str(dataset_dir), job_names=["local_dev"]) + es = eval_sets[0] + # dart_qa has 2 variants (baseline, context_only), code_gen has 2 allowed + task_names = [t.name for t in es.tasks] + assert "dart_qa:baseline" in task_names + assert "dart_qa:context_only" in task_names + assert "code_gen:baseline" in task_names + assert "code_gen:context_only" in task_names + + def test_resolve_inline_datasets(self, dataset_dir): + eval_sets = resolve(dataset_path=str(dataset_dir), job_names=["local_dev"]) + es = eval_sets[0] + dart_qa_baseline = next(t for t in es.tasks if t.name == "dart_qa:baseline") + assert dart_qa_baseline.dataset is not None + assert len(dart_qa_baseline.dataset.samples) == 2 + + def test_resolve_sandbox_local(self, dataset_dir): + eval_sets = resolve(dataset_path=str(dataset_dir), job_names=["local_dev"]) + es = eval_sets[0] + assert es.sandbox is None # 'local' serializes to None + + +# --------------------------------------------------------------------------- +# Writer tests +# 
--------------------------------------------------------------------------- + + +class TestWriter: + def test_write_single(self, dataset_dir, tmp_path): + eval_sets = resolve(dataset_path=str(dataset_dir), job_names=["local_dev"]) + output_dir = str(tmp_path / "output") + json_path = write_eval_sets(eval_sets, output_dir) + assert os.path.isfile(json_path) + + with open(json_path) as f: + data = json.load(f) + assert isinstance(data, dict) + assert "tasks" in data + assert "log_dir" in data + + def test_write_multiple(self, tmp_path): + es1 = EvalSet( + tasks=[Task(name="t1:baseline", task_func="qa")], + log_dir="/tmp/logs1", + ) + es2 = EvalSet( + tasks=[Task(name="t2:baseline", task_func="qa")], + log_dir="/tmp/logs2", + ) + output_dir = str(tmp_path / "output") + json_path = write_eval_sets([es1, es2], output_dir) + + with open(json_path) as f: + data = json.load(f) + assert isinstance(data, list) + assert len(data) == 2 + + +# --------------------------------------------------------------------------- +# Context file tests +# --------------------------------------------------------------------------- + + +class TestContextFile: + def test_load(self, tmp_path): + cf = tmp_path / "context.md" + cf.write_text( + """--- +title: Flutter Guide +version: "1.0" +description: A guide to Flutter +--- +# Content starts here + +Some markdown content. 
+""" + ) + loaded = ContextFile.load(str(cf)) + assert loaded.metadata.title == "Flutter Guide" + assert loaded.metadata.version == "1.0" + assert "Content starts here" in loaded.content + + def test_load_not_found(self): + with pytest.raises(FileNotFoundError): + ContextFile.load("/nonexistent/file.md") + + def test_load_no_frontmatter(self, tmp_path): + cf = tmp_path / "bad.md" + cf.write_text("No frontmatter here") + with pytest.raises(ValueError): + ContextFile.load(str(cf)) diff --git a/packages/devals_cli/example/pubspec.lock b/packages/devals_cli/example/pubspec.lock index 5a4eaa7..e0444fb 100644 --- a/packages/devals_cli/example/pubspec.lock +++ b/packages/devals_cli/example/pubspec.lock @@ -111,10 +111,10 @@ packages: dependency: transitive description: name: matcher - sha256: "12956d0ad8390bbcc63ca2e1469c0619946ccb52809807067a7020d57e647aa6" + sha256: dc0b7dc7651697ea4ff3e69ef44b0407ea32c487a39fff6a4004fa585e901861 url: "https://pub.dev" source: hosted - version: "0.12.18" + version: "0.12.19" material_color_utilities: dependency: transitive description: @@ -188,10 +188,10 @@ packages: dependency: transitive description: name: test_api - sha256: "93167629bfc610f71560ab9312acdda4959de4df6fac7492c89ff0d3886f6636" + sha256: "8161c84903fd860b26bfdefb7963b3f0b68fee7adea0f59ef805ecca346f0c7a" url: "https://pub.dev" source: hosted - version: "0.7.9" + version: "0.7.10" vector_math: dependency: transitive description: diff --git a/packages/devals_cli/lib/src/commands/create_job_command.dart b/packages/devals_cli/lib/src/commands/create_job_command.dart index 3a47c0d..ba4dc03 100644 --- a/packages/devals_cli/lib/src/commands/create_job_command.dart +++ b/packages/devals_cli/lib/src/commands/create_job_command.dart @@ -1,5 +1,5 @@ import 'package:args/command_runner.dart'; -import 'package:dataset_config/dataset_config.dart'; +import 'package:dataset_config_dart/dataset_config_dart.dart'; import 'package:devals/src/dataset/dataset_reader.dart'; import 
'package:devals/src/dataset/eval_writer.dart'; import 'package:devals/src/dataset/file_templates/job_template.dart'; diff --git a/packages/devals_cli/lib/src/commands/create_pipeline_command.dart b/packages/devals_cli/lib/src/commands/create_pipeline_command.dart index e632581..22b1e61 100644 --- a/packages/devals_cli/lib/src/commands/create_pipeline_command.dart +++ b/packages/devals_cli/lib/src/commands/create_pipeline_command.dart @@ -1,7 +1,7 @@ import 'dart:io'; import 'package:args/command_runner.dart'; -import 'package:dataset_config/dataset_config.dart'; +import 'package:dataset_config_dart/dataset_config_dart.dart'; import 'package:devals/src/cli_exception.dart'; import 'package:devals/src/dataset/eval_writer.dart'; import 'package:devals/src/dataset/file_templates/job_template.dart'; diff --git a/packages/devals_cli/lib/src/commands/run_command.dart b/packages/devals_cli/lib/src/commands/run_command.dart index b7e159e..a9fde02 100644 --- a/packages/devals_cli/lib/src/commands/run_command.dart +++ b/packages/devals_cli/lib/src/commands/run_command.dart @@ -1,7 +1,7 @@ import 'dart:io'; import 'package:args/command_runner.dart'; -import 'package:dataset_config/dataset_config.dart'; +import 'package:dataset_config_dart/dataset_config_dart.dart'; import 'package:devals/src/dataset/dry_run.dart'; import 'package:devals/src/dataset/filesystem_utils.dart'; import 'package:howdy/howdy.dart'; diff --git a/packages/devals_cli/lib/src/dataset/dry_run.dart b/packages/devals_cli/lib/src/dataset/dry_run.dart index 651614a..891f700 100644 --- a/packages/devals_cli/lib/src/dataset/dry_run.dart +++ b/packages/devals_cli/lib/src/dataset/dry_run.dart @@ -1,4 +1,4 @@ -import 'package:dataset_config/dataset_config.dart'; +import 'package:dataset_config_dart/dataset_config_dart.dart'; /// Preview resolved config without running evaluations. 
/// diff --git a/packages/devals_cli/pubspec.yaml b/packages/devals_cli/pubspec.yaml index 0a723c8..40bc053 100644 --- a/packages/devals_cli/pubspec.yaml +++ b/packages/devals_cli/pubspec.yaml @@ -21,8 +21,8 @@ dependencies: git: url: https://github.com/ericwindmill/howdy.git path: packages/howdy-cli - dataset_config: - path: ../dataset_config + dataset_config_dart: + path: ../dataset_config_dart path: ^1.9.0 yaml: ^3.1.0 yaml_edit: ^2.2.0 diff --git a/pubspec.yaml b/pubspec.yaml index 6e22516..b62d219 100644 --- a/pubspec.yaml +++ b/pubspec.yaml @@ -8,7 +8,7 @@ environment: workspace: - packages/devals_cli - - packages/dataset_config + - packages/dataset_config_dart dependencies: args: ^2.7.0 From a476a6434ff64b2b7eb919b3f1650b51eb30b579 Mon Sep 17 00:00:00 2001 From: Eric Windmill Date: Wed, 11 Mar 2026 13:02:16 -0700 Subject: [PATCH 4/7] initial comp script --- .../dataset_config_python/models/dataset.py | 6 + .../dataset_config_python/models/sample.py | 15 +- .../src/dataset_config_python/models/task.py | 15 + tool/bin/resolve_dart.dart | 32 ++ tool/bin/resolve_python.py | 38 +++ tool/fixtures/basic/jobs/local_dev.yaml | 7 + tool/fixtures/basic/tasks/dart_qa/task.yaml | 14 + tool/fixtures/multi_variant/jobs/dev.yaml | 11 + .../multi_variant/tasks/code_gen/task.yaml | 11 + .../multi_variant/tasks/dart_qa/task.yaml | 7 + tool/verify_config_parity.dart | 298 ++++++++++++++++++ 11 files changed, 450 insertions(+), 4 deletions(-) create mode 100644 tool/bin/resolve_dart.dart create mode 100644 tool/bin/resolve_python.py create mode 100644 tool/fixtures/basic/jobs/local_dev.yaml create mode 100644 tool/fixtures/basic/tasks/dart_qa/task.yaml create mode 100644 tool/fixtures/multi_variant/jobs/dev.yaml create mode 100644 tool/fixtures/multi_variant/tasks/code_gen/task.yaml create mode 100644 tool/fixtures/multi_variant/tasks/dart_qa/task.yaml create mode 100644 tool/verify_config_parity.dart diff --git 
a/packages/dataset_config_python/src/dataset_config_python/models/dataset.py b/packages/dataset_config_python/src/dataset_config_python/models/dataset.py index 4892eaa..b04ceb5 100644 --- a/packages/dataset_config_python/src/dataset_config_python/models/dataset.py +++ b/packages/dataset_config_python/src/dataset_config_python/models/dataset.py @@ -15,3 +15,9 @@ class Dataset(BaseModel): name: str = "" """Display name for the dataset.""" + + location: str | None = None + """Dataset location (file path or remote URL).""" + + shuffled: bool = False + """Whether the dataset was shuffled after reading.""" diff --git a/packages/dataset_config_python/src/dataset_config_python/models/sample.py b/packages/dataset_config_python/src/dataset_config_python/models/sample.py index d6fe2ce..442ebce 100644 --- a/packages/dataset_config_python/src/dataset_config_python/models/sample.py +++ b/packages/dataset_config_python/src/dataset_config_python/models/sample.py @@ -13,11 +13,18 @@ class Sample(BaseModel): Maps to Inspect AI's ``Sample`` class. """ - input: str - """The input to be submitted to the model.""" + input: str | list[Any] + """The input to be submitted to the model. - target: str = "" - """Ideal target output.""" + Can be a simple string or a list of ChatMessage-like objects. + """ + + target: str | list[str] = "" + """Ideal target output. + + May be a literal value or narrative text to be used by a model grader. + Can be a single string or a list of strings. 
+ """ id: str | None = None """Unique identifier for the sample.""" diff --git a/packages/dataset_config_python/src/dataset_config_python/models/task.py b/packages/dataset_config_python/src/dataset_config_python/models/task.py index aeaf471..cafbbe3 100644 --- a/packages/dataset_config_python/src/dataset_config_python/models/task.py +++ b/packages/dataset_config_python/src/dataset_config_python/models/task.py @@ -25,6 +25,21 @@ class Task(BaseModel): dataset: Dataset | None = None """Inline dataset with samples.""" + setup: Any | None = None + """Setup step (always run even when the main solver is replaced).""" + + solver: Any | None = None + """Solver or list of solvers. Defaults to ``generate()``.""" + + cleanup: Any | None = None + """Optional cleanup function for task.""" + + scorer: Any | None = None + """Scorer used to evaluate model output.""" + + metrics: Any | None = None + """Alternative metrics (overrides the metrics provided by the scorer).""" + sandbox: Any | None = None """Sandbox environment type.""" diff --git a/tool/bin/resolve_dart.dart b/tool/bin/resolve_dart.dart new file mode 100644 index 0000000..8233ec2 --- /dev/null +++ b/tool/bin/resolve_dart.dart @@ -0,0 +1,32 @@ +import 'dart:convert'; + +import 'package:dataset_config_dart/dataset_config_dart.dart'; + +/// Thin CLI wrapper that resolves a dataset + job using dataset_config_dart +/// and prints the resulting EvalSet JSON to stdout. +/// +/// Usage: +/// dart run tool/bin/resolve_dart.dart [datasetPath] [jobName] +void main(List args) { + if (args.length != 2) { + throw ArgumentError( + 'Usage: dart run tool/bin/resolve_dart.dart ', + ); + } + + final datasetPath = args[0]; + final jobName = args[1]; + + final resolver = ConfigResolver(); + final evalSets = resolver.resolve(datasetPath, [jobName]); + + // Match the writer's convention: single → object, multiple → array + final jsonContent = evalSets.length == 1 + ? 
def main() -> None:
    """Resolve one job of a dataset and print the resulting EvalSet JSON.

    Expects exactly two positional command-line arguments: the dataset path
    and the job name. The resolved eval set(s) are written to stdout as
    indented JSON with sorted keys so the output is stable across runs.

    Raises:
        SystemExit: With a usage message when the argument count is wrong.
    """
    if len(sys.argv) != 3:
        # Name the expected arguments; a bare script path gives the caller
        # no hint about what is missing.
        raise SystemExit(
            "Usage: python tool/bin/resolve_python.py <dataset_path> <job_name>"
        )

    dataset_path, job_name = sys.argv[1:]

    eval_sets = resolve(dataset_path=dataset_path, job_names=[job_name])

    # mode="json" makes pydantic emit only JSON-compatible values, so
    # json.dumps below cannot trip over richer Python types in a field.
    dumped = [
        es.model_dump(mode="json", exclude_none=True) for es in eval_sets
    ]

    # Match the writer's convention: single → object, multiple → array
    json_content = dumped[0] if len(dumped) == 1 else dumped

    # Sort keys for stable comparison
    print(json.dumps(json_content, indent=2, sort_keys=True))
+samples: + inline: + - id: sample_1 + input: "What is Dart?" + target: "A programming language." + difficulty: easy + - id: sample_2 + input: "What is Flutter?" + target: "A UI framework." + difficulty: medium + tags: ui, framework diff --git a/tool/fixtures/multi_variant/jobs/dev.yaml b/tool/fixtures/multi_variant/jobs/dev.yaml new file mode 100644 index 0000000..e0e884b --- /dev/null +++ b/tool/fixtures/multi_variant/jobs/dev.yaml @@ -0,0 +1,11 @@ +logs_dir: ./logs +sandbox_type: local +models: + - google/gemini-2.5-flash +variants: + baseline: {} + context_only: + context_files: [] + full_mcp: + mcp_servers: + - my_server diff --git a/tool/fixtures/multi_variant/tasks/code_gen/task.yaml b/tool/fixtures/multi_variant/tasks/code_gen/task.yaml new file mode 100644 index 0000000..d004f01 --- /dev/null +++ b/tool/fixtures/multi_variant/tasks/code_gen/task.yaml @@ -0,0 +1,11 @@ +id: code_gen +func: flutter_code_gen +time_limit: 600 +allowed_variants: + - baseline + - context_only +samples: + inline: + - id: sample_1 + input: "Create a counter app." + target: "A working counter app." diff --git a/tool/fixtures/multi_variant/tasks/dart_qa/task.yaml b/tool/fixtures/multi_variant/tasks/dart_qa/task.yaml new file mode 100644 index 0000000..38751a1 --- /dev/null +++ b/tool/fixtures/multi_variant/tasks/dart_qa/task.yaml @@ -0,0 +1,7 @@ +id: dart_qa +func: question_answer +samples: + inline: + - id: sample_1 + input: "Explain null safety." + target: "Null safety prevents null pointer exceptions." diff --git a/tool/verify_config_parity.dart b/tool/verify_config_parity.dart new file mode 100644 index 0000000..5f82578 --- /dev/null +++ b/tool/verify_config_parity.dart @@ -0,0 +1,298 @@ +import 'dart:convert'; +import 'dart:io'; + +import 'package:path/path.dart' as p; + +/// Cross-language config parity verification. 
///
/// For each fixture in `tool/fixtures/`, runs both the Dart and Python config
/// resolvers with the same YAML input and verifies they produce identical
/// JSON output.
///
/// Usage:
///   dart run tool/verify_config_parity.dart
///
/// Exit codes:
///   0 — all fixtures match
///   1 — one or more fixtures diverge (diff printed to stderr)
void main() async {
  final repoRoot = _findRepoRoot();
  final fixturesDir = Directory(p.join(repoRoot, 'tool', 'fixtures'));

  if (!fixturesDir.existsSync()) {
    stderr.writeln('ERROR: fixtures directory not found: ${fixturesDir.path}');
    exit(1);
  }

  // The explicit type argument matters: a bare `whereType()` filters nothing,
  // so stray files in the fixtures folder would be treated as fixtures.
  final fixtureDirs = fixturesDir.listSync().whereType<Directory>().toList()
    ..sort((a, b) => a.path.compareTo(b.path));

  if (fixtureDirs.isEmpty) {
    stderr.writeln(
      'ERROR: no fixture directories found in ${fixturesDir.path}',
    );
    exit(1);
  }

  stdout.writeln('Config Parity Verification');
  stdout.writeln('=' * 60);
  stdout.writeln('');

  var allPassed = true;

  for (final fixtureDir in fixtureDirs) {
    final fixtureName = p.basename(fixtureDir.path);
    final jobsDir = Directory(p.join(fixtureDir.path, 'jobs'));

    if (!jobsDir.existsSync()) {
      stderr.writeln(' SKIP $fixtureName — no jobs/ directory');
      continue;
    }

    // Only regular YAML files are jobs. Sorted so the report order does not
    // depend on the filesystem's directory listing order.
    final jobFiles = jobsDir
        .listSync()
        .whereType<File>()
        .where((f) => f.path.endsWith('.yaml') || f.path.endsWith('.yml'))
        .toList()
      ..sort((a, b) => a.path.compareTo(b.path));

    for (final jobFile in jobFiles) {
      final jobName = p.basenameWithoutExtension(jobFile.path);
      final label = '$fixtureName / $jobName';

      stdout.write(' $label ... ');

      try {
        final passed = await _verifyFixture(
          repoRoot: repoRoot,
          datasetPath: fixtureDir.path,
          jobName: jobName,
          label: label,
        );
        if (passed) {
          stdout.writeln('✅ PASS');
        } else {
          stdout.writeln('❌ FAIL');
          allPassed = false;
        }
      } catch (e) {
        // Anything thrown while verifying (e.g. unparseable resolver output)
        // counts as a failure for this job but does not stop the run.
        stdout.writeln('💥 ERROR');
        stderr.writeln(' $e');
        allPassed = false;
      }
    }
  }

  stdout.writeln('');
  if (allPassed) {
    stdout.writeln('All fixtures passed! 🎉');
  } else {
    stdout.writeln('Some fixtures FAILED. See errors above.');
    exit(1);
  }
}

/// Run both resolvers on the given fixture and compare JSON output.
///
/// Returns `true` when the normalized outputs are deeply equal. Returns
/// `false` (after writing details to stderr) when either resolver exits
/// non-zero or the outputs differ. [label] is only used in the diff header.
Future<bool> _verifyFixture({
  required String repoRoot,
  required String datasetPath,
  required String jobName,
  required String label,
}) async {
  // Run Dart resolver
  final dartResult = await Process.run(
    'dart',
    [
      'run',
      p.join(repoRoot, 'tool', 'bin', 'resolve_dart.dart'),
      datasetPath,
      jobName,
    ],
    workingDirectory: repoRoot,
  );

  if (dartResult.exitCode != 0) {
    stderr.writeln(' Dart resolver failed (exit ${dartResult.exitCode}):');
    stderr.writeln(_indent(dartResult.stderr.toString()));
    return false;
  }

  // Run Python resolver
  // Use the venv Python if available, otherwise fall back to system python3.
  final pythonBin = _findPython(repoRoot);
  final pythonResult = await Process.run(
    pythonBin,
    [
      p.join(repoRoot, 'tool', 'bin', 'resolve_python.py'),
      datasetPath,
      jobName,
    ],
    workingDirectory: repoRoot,
    // Point Python at the in-repo package sources so the resolver script can
    // import dataset_config_python from this checkout.
    environment: {
      'PYTHONPATH': p.join(
        repoRoot,
        'packages',
        'dataset_config_python',
        'src',
      ),
    },
  );

  if (pythonResult.exitCode != 0) {
    stderr.writeln(
      ' Python resolver failed (exit ${pythonResult.exitCode}):',
    );
    stderr.writeln(_indent(pythonResult.stderr.toString()));
    return false;
  }

  // Parse JSON outputs
  final dartJson = _parseAndNormalize(dartResult.stdout.toString().trim());
  final pythonJson = _parseAndNormalize(pythonResult.stdout.toString().trim());

  // Deep compare
  if (_deepEquals(dartJson, pythonJson)) {
    return true;
  }

  // Print diff on failure
  final dartPretty = const JsonEncoder.withIndent(' ').convert(dartJson);
  final pythonPretty = const JsonEncoder.withIndent(' ').convert(pythonJson);

  stderr.writeln(' JSON output differs for $label:');
  stderr.writeln('');
  _printDiff(dartPretty, pythonPretty);

  return false;
}

/// Regex to match timestamped log_dir suffixes (e.g. /2026-03-11_19-52-13).
final _timestampSuffix = RegExp(r'/\d{4}-\d{2}-\d{2}_\d{2}-\d{2}-\d{2}$');

/// Parse JSON string and normalize for comparison:
/// - Sort all map keys recursively
/// - Remove null values (Dart includes them, Python excludes with exclude_none)
/// - Remove empty maps/lists that one side might include but the other omits
/// - Normalize timestamped log_dir paths
///
/// Throws a [FormatException] (from `json.decode`) when [jsonStr] is not
/// valid JSON.
dynamic _parseAndNormalize(String jsonStr) {
  final parsed = json.decode(jsonStr);
  return _normalize(parsed);
}

/// Recursively normalize a JSON value for comparison.
dynamic _normalize(dynamic value) {
  // Lists: normalize every element, keep order and length.
  if (value is List) {
    return [for (final item in value) _normalize(item)];
  }

  // Numbers: collapse whole-valued doubles to ints so 0 and 0.0 compare equal.
  if (value is num) {
    final truncated = value.toInt();
    return value == truncated ? truncated : value.toDouble();
  }

  // Scalars other than numbers (strings, bools, null) pass through untouched.
  if (value is! Map) {
    return value;
  }

  // Maps: rebuild with string keys in sorted order, dropping entries that one
  // resolver may emit while the other omits them.
  final ordered = value.entries.toList()
    ..sort((l, r) => '${l.key}'.compareTo('${r.key}'));

  final result = <String, dynamic>{};
  for (final entry in ordered) {
    final key = '${entry.key}';
    var normalized = _normalize(entry.value);

    // Null values: one side serializes them, the other leaves them out.
    if (normalized == null) continue;

    // Empty containers may likewise be present on only one side.
    final isEmptyMap = normalized is Map && normalized.isEmpty;
    final isEmptyList = normalized is List && normalized.isEmpty;
    if (isEmptyMap || isEmptyList) continue;

    // A `shuffled: false` default is treated as equivalent to the key being
    // absent.
    if (key == 'shuffled' && normalized == false) continue;

    // Both sides append a timestamp to log_dir at slightly different times;
    // strip it so only the stable part of the path is compared.
    if (key == 'log_dir' && normalized is String) {
      normalized = normalized.replaceAll(_timestampSuffix, '/');
    }

    result[key] = normalized;
  }
  return result;
}

/// Deep equality check for JSON-like structures.
bool _deepEquals(dynamic a, dynamic b) {
  // Maps match when they have the same size and every key of `a` exists in
  // `b` with a deeply-equal value.
  if (a is Map && b is Map) {
    return a.length == b.length &&
        a.keys.every(
          (key) => b.containsKey(key) && _deepEquals(a[key], b[key]),
        );
  }

  // Lists match when they have the same length and are element-wise equal.
  if (a is List && b is List) {
    if (a.length != b.length) return false;
    var i = 0;
    while (i < a.length && _deepEquals(a[i], b[i])) {
      i++;
    }
    return i == a.length;
  }

  // Everything else (scalars, or mismatched container kinds) uses `==`.
  return a == b;
}

/// Print a line-by-line diff between two strings.
///
/// Lines are compared by position only; for every index where the two sides
/// differ, both versions are written to stderr. A side that has run out of
/// lines is shown as an empty string.
void _printDiff(String a, String b) {
  final dartLines = a.split('\n');
  final pythonLines = b.split('\n');

  String lineAt(List<String> lines, int index) =>
      index < lines.length ? lines[index] : '';

  for (var i = 0; i < dartLines.length || i < pythonLines.length; i++) {
    final dartLine = lineAt(dartLines, i);
    final pythonLine = lineAt(pythonLines, i);
    if (dartLine == pythonLine) continue;

    stderr.writeln(' dart: $dartLine');
    stderr.writeln(' python: $pythonLine');
    stderr.writeln('');
  }
}

/// Find the repo root by looking for pubspec.yaml.
///
/// Walks upward from the current directory and returns the first directory
/// that contains both a `pubspec.yaml` file and a `packages` directory. If
/// the filesystem root is reached without a match, the current directory is
/// returned instead.
String _findRepoRoot() {
  var candidate = Directory.current;
  for (;;) {
    final hasPubspec =
        File(p.join(candidate.path, 'pubspec.yaml')).existsSync();
    if (hasPubspec &&
        Directory(p.join(candidate.path, 'packages')).existsSync()) {
      return candidate.path;
    }

    final above = candidate.parent;
    if (above.path == candidate.path) {
      // Fallback to current directory
      return Directory.current.path;
    }
    candidate = above;
  }
}

/// Find the best Python executable. Prefers the repo's venv if it exists.
///
/// Checks `.venv/bin/python` and then `.venv/bin/python3` under [repoRoot];
/// when neither file exists, returns the bare command name `python3`.
String _findPython(String repoRoot) {
  for (final executable in const ['python', 'python3']) {
    final venvCandidate = p.join(repoRoot, '.venv', 'bin', executable);
    if (File(venvCandidate).existsSync()) return venvCandidate;
  }
  return 'python3';
}

/// Indent every line for nested error output.
String _indent(String text, {String prefix = ' '}) {
  // Prefix each '\n'-separated line (including empty ones) and stitch the
  // lines back together with the same separator.
  final indented = [for (final line in text.split('\n')) prefix + line];
  return indented.join('\n');
}
tool/config_parity/bin/config_partiy.dart index 5f82578..640abfc 100644 --- a/tool/verify_config_parity.dart +++ b/tool/config_parity/bin/config_partiy.dart @@ -17,7 +17,7 @@ import 'package:path/path.dart' as p; /// 1 — one or more fixtures diverge (diff printed to stderr) void main() async { final repoRoot = _findRepoRoot(); - final fixturesDir = Directory(p.join(repoRoot, 'tool', 'fixtures')); + final fixturesDir = Directory(p.join(repoRoot, 'tool', 'config_parity', 'fixtures')); if (!fixturesDir.existsSync()) { stderr.writeln('ERROR: fixtures directory not found: ${fixturesDir.path}'); @@ -103,7 +103,7 @@ Future _verifyFixture({ 'dart', [ 'run', - p.join(repoRoot, 'tool', 'bin', 'resolve_dart.dart'), + p.join(repoRoot, 'tool', 'config_parity', 'bin', 'resolve_dart.dart'), datasetPath, jobName, ], @@ -122,7 +122,7 @@ Future _verifyFixture({ final pythonResult = await Process.run( pythonBin, [ - p.join(repoRoot, 'tool', 'bin', 'resolve_python.py'), + p.join(repoRoot, 'tool', 'config_parity', 'bin', 'resolve_python.py'), datasetPath, jobName, ], @@ -194,16 +194,6 @@ dynamic _normalize(dynamic value) { if (v is List && v.isEmpty) return true; return false; }); - // Strip Dart-only fields that have false/zero defaults and are absent - // from Python models (e.g. Dataset.shuffled, Dataset.location). 
- const dartOnlyDefaults = { - 'shuffled': false, - }; - for (final entry in dartOnlyDefaults.entries) { - if (sorted[entry.key] == entry.value) { - sorted.remove(entry.key); - } - } // Normalize timestamped log_dir paths — both sides append timestamps // but at slightly different times; strip the timestamp for comparison if (sorted.containsKey('log_dir') && sorted['log_dir'] is String) { diff --git a/tool/bin/resolve_dart.dart b/tool/config_parity/bin/resolve_dart.dart similarity index 100% rename from tool/bin/resolve_dart.dart rename to tool/config_parity/bin/resolve_dart.dart diff --git a/tool/bin/resolve_python.py b/tool/config_parity/bin/resolve_python.py similarity index 100% rename from tool/bin/resolve_python.py rename to tool/config_parity/bin/resolve_python.py diff --git a/tool/fixtures/basic/jobs/local_dev.yaml b/tool/config_parity/fixtures/basic/jobs/local_dev.yaml similarity index 100% rename from tool/fixtures/basic/jobs/local_dev.yaml rename to tool/config_parity/fixtures/basic/jobs/local_dev.yaml diff --git a/tool/fixtures/basic/tasks/dart_qa/task.yaml b/tool/config_parity/fixtures/basic/tasks/dart_qa/task.yaml similarity index 100% rename from tool/fixtures/basic/tasks/dart_qa/task.yaml rename to tool/config_parity/fixtures/basic/tasks/dart_qa/task.yaml diff --git a/tool/fixtures/multi_variant/jobs/dev.yaml b/tool/config_parity/fixtures/multi_variant/jobs/dev.yaml similarity index 100% rename from tool/fixtures/multi_variant/jobs/dev.yaml rename to tool/config_parity/fixtures/multi_variant/jobs/dev.yaml diff --git a/tool/fixtures/multi_variant/tasks/code_gen/task.yaml b/tool/config_parity/fixtures/multi_variant/tasks/code_gen/task.yaml similarity index 100% rename from tool/fixtures/multi_variant/tasks/code_gen/task.yaml rename to tool/config_parity/fixtures/multi_variant/tasks/code_gen/task.yaml diff --git a/tool/fixtures/multi_variant/tasks/dart_qa/task.yaml b/tool/config_parity/fixtures/multi_variant/tasks/dart_qa/task.yaml similarity index 
100% rename from tool/fixtures/multi_variant/tasks/dart_qa/task.yaml rename to tool/config_parity/fixtures/multi_variant/tasks/dart_qa/task.yaml diff --git a/tool/config_parity/pubspec.lock b/tool/config_parity/pubspec.lock new file mode 100644 index 0000000..dd2733b --- /dev/null +++ b/tool/config_parity/pubspec.lock @@ -0,0 +1,108 @@ +# Generated by pub +# See https://dart.dev/tools/pub/glossary#lockfile +packages: + async: + dependency: transitive + description: + name: async + sha256: "758e6d74e971c3e5aceb4110bfd6698efc7f501675bcfe0c775459a8140750eb" + url: "https://pub.dev" + source: hosted + version: "2.13.0" + collection: + dependency: transitive + description: + name: collection + sha256: "2f5709ae4d3d59dd8f7cd309b4e023046b57d8a6c82130785d2b0e5868084e76" + url: "https://pub.dev" + source: hosted + version: "1.19.1" + dataset_config_dart: + dependency: "direct main" + description: + path: "../../packages/dataset_config_dart" + relative: true + source: path + version: "0.0.1" + file: + dependency: transitive + description: + name: file + sha256: a3b4f84adafef897088c160faf7dfffb7696046cb13ae90b508c2cbc95d3b8d4 + url: "https://pub.dev" + source: hosted + version: "7.0.1" + freezed_annotation: + dependency: transitive + description: + name: freezed_annotation + sha256: "7294967ff0a6d98638e7acb774aac3af2550777accd8149c90af5b014e6d44d8" + url: "https://pub.dev" + source: hosted + version: "3.1.0" + glob: + dependency: transitive + description: + name: glob + sha256: c3f1ee72c96f8f78935e18aa8cecced9ab132419e8625dc187e1c2408efc20de + url: "https://pub.dev" + source: hosted + version: "2.1.3" + json_annotation: + dependency: transitive + description: + name: json_annotation + sha256: cb09e7dac6210041fad964ed7fbee004f14258b4eca4040f72d1234062ace4c8 + url: "https://pub.dev" + source: hosted + version: "4.11.0" + meta: + dependency: transitive + description: + name: meta + sha256: "9f29b9bcc8ee287b1a31e0d01be0eae99a930dbffdaecf04b3f3d82a969f296f" + url: 
"https://pub.dev" + source: hosted + version: "1.18.1" + path: + dependency: "direct main" + description: + name: path + sha256: "75cca69d1490965be98c73ceaea117e8a04dd21217b37b292c9ddbec0d955bc5" + url: "https://pub.dev" + source: hosted + version: "1.9.1" + source_span: + dependency: transitive + description: + name: source_span + sha256: "56a02f1f4cd1a2d96303c0144c93bd6d909eea6bee6bf5a0e0b685edbd4c47ab" + url: "https://pub.dev" + source: hosted + version: "1.10.2" + string_scanner: + dependency: transitive + description: + name: string_scanner + sha256: "921cd31725b72fe181906c6a94d987c78e3b98c2e205b397ea399d4054872b43" + url: "https://pub.dev" + source: hosted + version: "1.4.1" + term_glyph: + dependency: transitive + description: + name: term_glyph + sha256: "7f554798625ea768a7518313e58f83891c7f5024f88e46e7182a4558850a4b8e" + url: "https://pub.dev" + source: hosted + version: "1.2.2" + yaml: + dependency: transitive + description: + name: yaml + sha256: b9da305ac7c39faa3f030eccd175340f968459dae4af175130b3fc47e40d76ce + url: "https://pub.dev" + source: hosted + version: "3.1.3" +sdks: + dart: ">=3.10.0 <4.0.0" diff --git a/tool/config_parity/pubspec.yaml b/tool/config_parity/pubspec.yaml new file mode 100644 index 0000000..ddbf86a --- /dev/null +++ b/tool/config_parity/pubspec.yaml @@ -0,0 +1,13 @@ +name: config_parity +publish_to: none +description: Scripts that keep python-config and dart-config aligned. 
+version: 0.0.1 +resolution: workspace + +environment: + sdk: ^3.10.0 + +dependencies: + path: ^1.9.1 + dataset_config_dart: + path: ../../packages/dataset_config_dart \ No newline at end of file From 9061e34f2aae2dc769cc77a8ce5140199c122c72 Mon Sep 17 00:00:00 2001 From: Eric Windmill Date: Wed, 11 Mar 2026 13:19:05 -0700 Subject: [PATCH 6/7] add to ci --- .github/workflows/config_parity.yml | 46 ++++++++++++++++++++++++ tool/config_parity/bin/resolve_python.py | 2 +- 2 files changed, 47 insertions(+), 1 deletion(-) create mode 100644 .github/workflows/config_parity.yml diff --git a/.github/workflows/config_parity.yml b/.github/workflows/config_parity.yml new file mode 100644 index 0000000..a2338af --- /dev/null +++ b/.github/workflows/config_parity.yml @@ -0,0 +1,46 @@ +name: Config Parity + +on: + pull_request: + paths: + - 'packages/dataset_config_dart/**' + - 'packages/dataset_config_python/**' + - 'tool/config_parity/**' + - '.github/workflows/config_parity.yml' + push: + branches: + - main + paths: + - 'packages/dataset_config_dart/**' + - 'packages/dataset_config_python/**' + - 'tool/config_parity/**' + - '.github/workflows/config_parity.yml' + +jobs: + config-parity: + runs-on: ubuntu-latest + timeout-minutes: 10 + + steps: + - name: Checkout repository + uses: actions/checkout@v6 + + - name: Install Flutter + run: | + git clone https://github.com/flutter/flutter.git --depth 1 -b stable $HOME/flutter + echo "$HOME/flutter/bin" >> $GITHUB_PATH + echo "$HOME/.pub-cache/bin" >> $GITHUB_PATH + + - name: Install Dart dependencies + run: flutter pub get + + - name: Set up Python + uses: actions/setup-python@v6 + with: + python-version: '3.13' + + - name: Install Python config package + run: pip install -e packages/dataset_config_python + + - name: Verify config parity + run: dart run tool/config_parity/bin/config_partiy.dart diff --git a/tool/config_parity/bin/resolve_python.py b/tool/config_parity/bin/resolve_python.py index e5f8217..cd5514b 100644 --- 
a/tool/config_parity/bin/resolve_python.py +++ b/tool/config_parity/bin/resolve_python.py @@ -10,7 +10,7 @@ import json import sys -from dataset_config_python import resolve +from dataset_config_python import resolve # pyrefly: ignore def main() -> None: From 3be68eb9fb02ad658f411142ccf930c446318d9e Mon Sep 17 00:00:00 2001 From: Eric Windmill Date: Wed, 11 Mar 2026 13:23:24 -0700 Subject: [PATCH 7/7] cleanup ci --- .../{config_tests.yml => config_dart_tests.yml} | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) rename .github/workflows/{config_tests.yml => config_dart_tests.yml} (69%) diff --git a/.github/workflows/config_tests.yml b/.github/workflows/config_dart_tests.yml similarity index 69% rename from .github/workflows/config_tests.yml rename to .github/workflows/config_dart_tests.yml index a32649b..b3fed8a 100644 --- a/.github/workflows/config_tests.yml +++ b/.github/workflows/config_dart_tests.yml @@ -3,14 +3,14 @@ name: Config Tests on: pull_request: paths: - - 'packages/dataset_config/**' - - '.github/workflows/config_tests.yml' + - 'packages/dataset_config_dart/**' + - '.github/workflows/config_dart_tests.yml' push: branches: - main paths: - - 'packages/dataset_config/**' - - '.github/workflows/config_tests.yml' + - 'packages/dataset_config_dart/**' + - '.github/workflows/config_dart_tests.yml' jobs: config-tests: @@ -31,9 +31,9 @@ jobs: run: flutter pub get - name: Analyze - working-directory: packages/dataset_config + working-directory: packages/dataset_config_dart run: dart analyze --fatal-infos - name: Run tests - working-directory: packages/dataset_config + working-directory: packages/dataset_config_dart run: dart test