47 commits
501b730  test: Capture baseline metric routing behavior (vivekkalyan, Mar 5, 2026)
6d4f689  feat: Add section-aware metric routing and W&B taxonomy registration (vivekkalyan, Mar 5, 2026)
498a0cd  feat: Add MetricsBuilder with hierarchical cost rollups (vivekkalyan, Mar 5, 2026)
c4e848c  test: Capture baseline train trajectory metric routing (vivekkalyan, Mar 5, 2026)
1d9bf26  feat: Route train trajectory metrics and log costs via MetricsBuilder (vivekkalyan, Mar 5, 2026)
89a58e1  feat: Rename train metrics to reward, loss, and throughput sections (vivekkalyan, Mar 5, 2026)
20f9967  feat: Persist MetricsBuilder cumulative state across resume (vivekkalyan, Mar 5, 2026)
3e5ab1b  feat: Emit canonical train metric keys at source (vivekkalyan, Mar 5, 2026)
1ba0931  docs: Add metrics taxonomy guide and smoke example (vivekkalyan, Mar 5, 2026)
75068fd  fix: Bind nested cost metrics to training_step in W&B (vivekkalyan, Mar 5, 2026)
f958e3c  feat: Add API cost decorator and metrics context wiring (vivekkalyan, Mar 5, 2026)
7294638  test: Add coverage for API cost decorator and context routing (vivekkalyan, Mar 5, 2026)
c91cf27  docs: Add API cost decorator guide and smoke demo (vivekkalyan, Mar 5, 2026)
6fb0d8c  fix: Parse entity and project in metrics smoke config (vivekkalyan, Mar 5, 2026)
754ef57  test: Cover metrics builder resume and cumulative routing (vivekkalyan, Mar 9, 2026)
4659a5b  fix: Restore MetricsBuilder cumulative state and routing (vivekkalyan, Mar 9, 2026)
26c7406  test: Cover taxonomy timing and data metrics (vivekkalyan, Mar 9, 2026)
02a3c58  feat: Emit time and data metrics across training flows (vivekkalyan, Mar 9, 2026)
59de04d  docs: Document auto-emitted taxonomy metrics (vivekkalyan, Mar 9, 2026)
a9ce32f  feat: Add yes-no-maybe metrics example (vivekkalyan, Mar 9, 2026)
68281b5  fix: Load MetricsBuilder state before builder access (vivekkalyan, Mar 9, 2026)
7b8d8f9  fix: Preserve gradient step metrics in train outputs (vivekkalyan, Mar 9, 2026)
004d610  fix: Skip stale MetricsBuilder flush outputs (vivekkalyan, Mar 9, 2026)
b973366  fix: Normalize Model.log inputs once (vivekkalyan, Mar 9, 2026)
18417e3  refactor: Share training metric aggregation helpers (vivekkalyan, Mar 9, 2026)
ffe815e  refactor: Simplify MetricsBuilder state access (vivekkalyan, Mar 9, 2026)
9ee7a59  refactor: Extract API cost tracking helpers (vivekkalyan, Mar 9, 2026)
29b7836  fix: Simplify Metrics Logging And Cumulative Naming (vivekkalyan, Mar 9, 2026)
8c2042c  fix: Require Model-Aware Api Cost Pricing (vivekkalyan, Mar 9, 2026)
d2e9213  fix: Normalize Unsloth Eval Metric Routing (vivekkalyan, Mar 9, 2026)
5c46148  fix: Align Wandb Logging With Training Step (vivekkalyan, Mar 9, 2026)
ad51d34  refactor: Use Backend Train In Metrics Demo (vivekkalyan, Mar 9, 2026)
fe4a06b  feat: Add LocalBackend wall time and GPU cost metrics (vivekkalyan, Mar 9, 2026)
096c042  refactor: Rely On LocalBackend metrics in demo (vivekkalyan, Mar 9, 2026)
67ff726  fix: preserve out-of-order wandb metric logging (vivekkalyan, Mar 10, 2026)
fab7907  fix: account for cached API token pricing (vivekkalyan, Mar 10, 2026)
9a48f2e  test: add live API cost smoke tests (vivekkalyan, Mar 10, 2026)
84328ce  refactor: Rename API cost module (vivekkalyan, Mar 10, 2026)
7c0a86f  docs: Remove metrics taxonomy smoke example (vivekkalyan, Mar 10, 2026)
57644a0  refactor: Simplify metrics cost helpers (vivekkalyan, Mar 10, 2026)
401547b  refactor: Simplify metric taxonomy key handling (vivekkalyan, Mar 10, 2026)
92384a3  refactor: Canonicalize cost and throughput keys (vivekkalyan, Mar 10, 2026)
50f071b  Merge branch 'main' into feat/improved-metrics (vivekkalyan, Mar 10, 2026)
3b943dc  refactor: Require explicit API cost provider and model (vivekkalyan, Mar 10, 2026)
a032994  docs: Replace metrics taxonomy note (vivekkalyan, Mar 10, 2026)
5e3c812  style: Apply ruff format (vivekkalyan, Mar 10, 2026)
b934c25  fix: Resolve ty failures in API cost and Unsloth (vivekkalyan, Mar 10, 2026)
259 changes: 259 additions & 0 deletions dev/yes-no-maybe-metrics.py
@@ -0,0 +1,259 @@
"""Yes-no-maybe metrics demo for the LocalBackend `model.train()` path.

This keeps the same prompt family, rollout structure, and reward ordering as
`dev/yes-no-maybe.py` while adding explicit metrics taxonomy instrumentation
for actor/eval timing and data metrics, and relies on LocalBackend for
automatic step wall time and GPU cost logging.
"""

from __future__ import annotations

import asyncio
from itertools import permutations
import os
import time

from dotenv import load_dotenv
import openai

try:
    import unsloth  # noqa: F401
except ImportError:
    pass

import art
from art.local import LocalBackend


async def create_chat_completion(
    client: openai.AsyncOpenAI,
    *,
    model_name: str,
    messages: art.Messages,
    max_tokens: int,
    timeout: float,
) -> openai.types.chat.chat_completion.ChatCompletion:
    return await client.chat.completions.create(
        messages=messages,
        model=model_name,
        max_tokens=max_tokens,
        timeout=timeout,
    )


def with_quotes(word: str) -> str:
    return f"'{word}'"


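# Builds every prompt variant: two prefixes ("respond", "just respond"),
# quoted and unquoted word lists, and all ordered 3- and 2-word permutations
# of "yes"/"no"/"maybe". The 2-word branch joins with " or " and skips quoting.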
def build_prompts() -> list[str]:
    return [
        f"{prefix} with {', '.join([with_quotes(word) if use_quotes else word for word in words]) if len(words) == 3 else f'{words[0]}' + (f' or {words[1]}' if len(words) > 1 else '')}"
        for prefix in ["respond", "just respond"]
        for use_quotes in [True, False]
        for words in (
            list(permutation)
            for length in [3, 2]
            for permutation in permutations(["yes", "no", "maybe"], length)
        )
    ]


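# Reward ordering matches dev/yes-no-maybe.py: "maybe" > "no" > "yes",
# and anything else earns zero.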
def reward_for_answer(content: str | None) -> float:
    if content == "yes":
        return 0.5
    if content == "no":
        return 0.75
    if content == "maybe":
        return 1.0
    return 0.0


def scenario_id_for_prompt(prompt: str) -> str:
    return prompt.replace(" ", "_").replace("'", "")


def response_total_tokens(
    response: openai.types.chat.chat_completion.ChatCompletion,
) -> int:
    usage = response.usage
    if usage is None:
        return 0
    prompt_tokens = int(usage.prompt_tokens or 0)
    completion_tokens = int(usage.completion_tokens or 0)
    return prompt_tokens + completion_tokens


def total_actor_tokens(groups: list[art.TrajectoryGroup]) -> int:
    return sum(
        int(trajectory.metadata.get("actor_total_tokens", 0) or 0)
        for group in groups
        for trajectory in group.trajectories
    )


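# One rollout sends a single user prompt, scores the reply, and stashes the
# token count in trajectory metadata so total_actor_tokens() can aggregate it.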
async def rollout(
    client: openai.AsyncOpenAI,
    model: art.TrainableModel,
    prompt: str,
    *,
    max_tokens: int,
    timeout: float,
) -> art.Trajectory:
    messages: art.Messages = [{"role": "user", "content": prompt}]
    chat_completion = await create_chat_completion(
        client,
        model_name=model.get_inference_name(),
        messages=messages,
        max_tokens=max_tokens,
        timeout=timeout,
    )
    choice = chat_completion.choices[0]
    content = choice.message.content
    return art.Trajectory(
        messages_and_choices=[*messages, choice],
        reward=reward_for_answer(content),
        metadata={
            "scenario_id": scenario_id_for_prompt(prompt),
            "actor_total_tokens": response_total_tokens(chat_completion),
        },
        metrics={
            "valid_answer": reward_for_answer(content) > 0.0,
        },
    )


async def evaluate(
    client: openai.AsyncOpenAI,
    model: art.TrainableModel,
    prompts: list[str],
    *,
    max_tokens: int,
    timeout: float,
) -> list[art.TrajectoryGroup]:
    groups = await art.gather_trajectory_groups(
        art.TrajectoryGroup(
            [
                rollout(
                    client,
                    model,
                    prompt,
                    max_tokens=max_tokens,
                    timeout=timeout,
                )
            ],
            metadata={"scenario_id": scenario_id_for_prompt(prompt)},
        )
        for prompt in prompts
    )
    return groups


def print_history_summary(model: art.TrainableModel) -> None:
    history_path = (
        model.base_path + f"/{model.project}/models/{model.name}/history.jsonl"
    )
    print(f"History: {history_path}")


def build_internal_config() -> art.dev.InternalModelConfig:
    return art.dev.InternalModelConfig(
        engine_args=art.dev.EngineArgs(
            gpu_memory_utilization=float(
                os.environ.get("GPU_MEMORY_UTILIZATION", "0.85")
            ),
            max_model_len=int(os.environ.get("MAX_MODEL_LEN", "4096")),
        )
    )


async def main() -> None:
    load_dotenv()

    backend = LocalBackend()
    base_model = os.environ.get("BASE_MODEL", "Qwen/Qwen3-30B-A3B-Instruct-2507")
    project = os.environ.get("PROJECT", "yes-no-maybe-metrics")
    model = art.TrainableModel(
        name=os.environ.get("MODEL_NAME", f"yes-no-maybe-metrics-{int(time.time())}"),
        project=project,
        base_model=base_model,
        report_metrics=["wandb"],
        _internal_config=build_internal_config(),
    )
    try:
        await model.register(backend)

        prompts = build_prompts()
        eval_prompts = prompts[: int(os.environ.get("EVAL_PROMPTS", "12"))]
        openai_client = model.openai_client()
        max_steps = int(os.environ.get("NUM_STEPS", "20"))
        rollouts_per_prompt = int(os.environ.get("ROLLOUTS_PER_PROMPT", "32"))
        max_tokens = int(os.environ.get("MAX_TOKENS", "100"))
        timeout = float(os.environ.get("TIMEOUT", "100"))
        eval_every_n_steps = int(os.environ.get("EVAL_EVERY_N_STEPS", "1"))
        learning_rate = float(os.environ.get("LEARNING_RATE", "1e-4"))

        start_step = await model.get_step()
        for offset in range(max_steps):
            current_step = start_step + offset

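            # Periodic eval pass: time the rollouts and record token counts
            # under the "eval" metrics section.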
            if (
                eval_every_n_steps > 0
                and (current_step - start_step) % eval_every_n_steps == 0
            ):
                eval_builder = model.metrics_builder("eval")
                with eval_builder.activate_context():
                    with eval_builder.measure("time/step_eval_s"):
                        val_groups = await evaluate(
                            openai_client,
                            model,
                            eval_prompts,
                            max_tokens=max_tokens,
                            timeout=timeout,
                        )
                    eval_builder.add_data(
                        step_actor_tokens=total_actor_tokens(val_groups)
                    )
                    await model.log(val_groups, split="val", step=current_step)

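            # Train pass: time actor rollouts, record token counts, and train
            # under the "train" metrics section.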
            train_builder = model.metrics_builder("train")
            with train_builder.activate_context():
                with train_builder.measure("time/step_actor_s"):
                    train_groups = await art.gather_trajectory_groups(
                        (
                            art.TrajectoryGroup(
                                rollout(
                                    openai_client,
                                    model,
                                    prompt,
                                    max_tokens=max_tokens,
                                    timeout=timeout,
                                )
                                for _ in range(rollouts_per_prompt)
                            )
                            for prompt in prompts
                        )
                    )
                train_builder.add_data(
                    step_actor_tokens=total_actor_tokens(train_groups)
                )
                result = await backend.train(
                    model,
                    train_groups,
                    learning_rate=learning_rate,
                )

                await model.log(
                    split="train",
                    step=result.step,
                    trajectories=train_groups,
                    metrics=result.metrics,
                )
                print(f"step {result.step} complete")

        print_history_summary(model)
    finally:
        await backend.close()


if __name__ == "__main__":
    asyncio.run(main())
3 changes: 2 additions & 1 deletion docs/docs.json
@@ -67,6 +67,7 @@
       "features/checkpoint-forking",
       "features/checkpoint-deletion",
       "features/additional-histories",
+      "features/tracking-metrics",
       "features/mcp-rl"
     ]
   },
@@ -106,4 +107,4 @@
     "bluesky": "https://bsky.app/profile/openpipe.bsky.social",
     "github": "https://github.com/openpipe/ART"
   }
-}
+}