Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 61 additions & 17 deletions dreadnode/scorers/judge.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import typing as t

import rigging as rg
from loguru import logger

from dreadnode.common_types import AnyDict
from dreadnode.meta import Config
Expand Down Expand Up @@ -42,6 +43,7 @@ def llm_judge(
input: t.Any | None = None,
expected_output: t.Any | None = None,
model_params: rg.GenerateParams | AnyDict | None = None,
fallback_model: str | rg.Generator | None = None,
passing: t.Callable[[float], bool] | None = None,
min_score: float | None = None,
max_score: float | None = None,
Expand All @@ -57,6 +59,7 @@ def llm_judge(
input: The input which produced the output for context, if applicable.
expected_output: The expected output to compare against, if applicable.
model_params: Optional parameters for the model.
fallback_model: Optional fallback model to use if the primary model fails.
passing: Optional callback to determine if the score is passing based on the score value - overrides any model-specified value.
min_score: Optional minimum score for the judgement - if provided, the score will be clamped to this value.
max_score: Optional maximum score for the judgement - if provided, the score will be clamped to this value.
Expand All @@ -74,36 +77,68 @@ async def evaluate(
input: t.Any | None = input,
expected_output: t.Any | None = expected_output,
model_params: rg.GenerateParams | AnyDict | None = model_params,
fallback_model: str | rg.Generator | None = fallback_model,
min_score: float | None = min_score,
max_score: float | None = max_score,
system_prompt: str | None = system_prompt,
) -> list[Metric]:
generator: rg.Generator
if isinstance(model, str):
generator = rg.get_generator(
model,
params=model_params
if isinstance(model_params, rg.GenerateParams)
else rg.GenerateParams.model_validate(model_params)
if model_params
else None,
)
elif isinstance(model, rg.Generator):
generator = model
else:
def _create_generator(
    model: str | rg.Generator,
    params: rg.GenerateParams | AnyDict | None,
) -> rg.Generator:
    """Resolve *model* into a usable ``rg.Generator``.

    A ``Generator`` instance is passed through unchanged. A string
    identifier is resolved via ``rg.get_generator``, with *params*
    normalized to ``GenerateParams`` first (``None`` when no params
    were supplied). Anything else is rejected.
    """
    if isinstance(model, rg.Generator):
        return model
    if isinstance(model, str):
        # Accept either a ready-made GenerateParams or a plain dict;
        # a falsy/absent value means "no params".
        if isinstance(params, rg.GenerateParams):
            normalized = params
        elif params:
            normalized = rg.GenerateParams.model_validate(params)
        else:
            normalized = None
        return rg.get_generator(model, params=normalized)
    raise TypeError("Model must be a string identifier or a Generator instance.")

generator = _create_generator(model, model_params)

input_data = JudgeInput(
input=str(input) if input is not None else None,
expected_output=str(expected_output) if expected_output is not None else None,
output=str(data),
rubric=rubric,
)

pipeline = generator.chat([])
if system_prompt:
pipeline.chat.inject_system_content(system_prompt)
judgement = await judge.bind(pipeline)(input_data)
# Track fallback usage for observability
used_fallback = False
primary_error: str | None = None

# Try primary model, fallback if needed
try:
pipeline = generator.chat([])
if system_prompt:
pipeline.chat.inject_system_content(system_prompt)
judgement = await judge.bind(pipeline)(input_data)
except Exception as e:
if fallback_model is None:
raise
# Log primary model failure and fallback usage
used_fallback = True
primary_error = f"{type(e).__name__}: {e}"
primary_model_name = model if isinstance(model, str) else type(model).__name__
fallback_model_name = (
fallback_model if isinstance(fallback_model, str) else type(fallback_model).__name__
)
logger.warning(
f"Primary model '{primary_model_name}' failed with {primary_error}. "
f"Using fallback model '{fallback_model_name}'."
)
# Use fallback model
generator = _create_generator(fallback_model, model_params)
pipeline = generator.chat([])
if system_prompt:
pipeline.chat.inject_system_content(system_prompt)
judgement = await judge.bind(pipeline)(input_data)

if min_score is not None:
judgement.score = max(min_score, judgement.score)
Expand All @@ -117,6 +152,15 @@ async def evaluate(
value=judgement.score,
attributes={
"reason": judgement.reason,
"used_fallback": used_fallback,
"fallback_model": (
str(fallback_model)
if isinstance(fallback_model, str)
else type(fallback_model).__name__
)
if used_fallback
else None,
"primary_error": primary_error,
},
)
pass_metric = Metric(value=float(judgement.passing))
Expand Down