feat: per-dataset max_new_tokens override#356
Conversation
Add optional `max_new_tokens` to the Dataset config so performance and accuracy datasets can use different per-request max_tokens within a single `--mode both` run. The client sends model_params.max_new_tokens as the OpenAI completions `max_tokens`. A large global value (e.g. 32768) inflates the server-side per-request decode KV reservation; at high concurrency this starves the disaggregated ctx->gen KV-cache transfer and triggers KV-cache-transfer timeout storms. This change lets a perf dataset use a small cap (avoid the overload) while accuracy datasets keep a large cap (avoid truncating long reasoning output). Falls back to model_params.max_new_tokens when unset; applied per-dataset in execute.py via model_params.model_copy(update=...). Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Cover Dataset.max_new_tokens: defaults to None, accepts a per-dataset override, and rejects non-positive values. All test_schema.py tests pass. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
|
MLCommons CLA bot All contributors have signed the MLCommons CLA ✍️ ✅ |
There was a problem hiding this comment.
Code Review
This pull request introduces a per-dataset max_new_tokens override capability to allow performance and accuracy datasets to use different token limits, falling back to the global model_params when unset. The feedback suggests encapsulating the override logic into a helper method get_model_params on the Dataset class to eliminate code duplication across the accuracy and performance dataset loading paths, and adding corresponding unit tests for this helper.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
| max_new_tokens: int | None = Field( | ||
| None, | ||
| gt=0, | ||
| description=( | ||
| "Per-dataset override of model_params.max_new_tokens (sent as the " | ||
| "per-request max_tokens). Lets a performance dataset use a small cap " | ||
| "(to avoid server-side KV over-reservation/overload at high concurrency) " | ||
| "while accuracy datasets use a larger cap (to avoid truncating long " | ||
| "reasoning output). Falls back to model_params.max_new_tokens when unset." | ||
| ), | ||
| ) |
There was a problem hiding this comment.
To avoid duplicating the max_new_tokens override logic across different dataset loading paths, we can encapsulate this behavior as a helper method on the Dataset class itself. This improves maintainability and makes the code more robust to future changes.
| max_new_tokens: int | None = Field( | |
| None, | |
| gt=0, | |
| description=( | |
| "Per-dataset override of model_params.max_new_tokens (sent as the " | |
| "per-request max_tokens). Lets a performance dataset use a small cap " | |
| "(to avoid server-side KV over-reservation/overload at high concurrency) " | |
| "while accuracy datasets use a larger cap (to avoid truncating long " | |
| "reasoning output). Falls back to model_params.max_new_tokens when unset." | |
| ), | |
| ) | |
| max_new_tokens: int | None = Field( | |
| None, | |
| gt=0, | |
| description=( | |
| "Per-dataset override of model_params.max_new_tokens (sent as the " | |
| "per-request max_tokens). Lets a performance dataset use a small cap " | |
| "(to avoid server-side KV over-reservation/overload at high concurrency) " | |
| "while accuracy datasets use a larger cap (to avoid truncating long " | |
| "reasoning output). Falls back to model_params.max_new_tokens when unset." | |
| ), | |
| ) | |
| def get_model_params(self, global_params: ModelParams) -> ModelParams: | |
| """Get model params with per-dataset max_new_tokens override applied if set.""" | |
| if self.max_new_tokens is None: | |
| return global_params | |
| return global_params.model_copy(update={"max_new_tokens": self.max_new_tokens}) |
| # Per-dataset max_new_tokens override (falls back to global model_params). | ||
| acc_model_params = ( | ||
| config.model_params | ||
| if acc_cfg.max_new_tokens is None | ||
| else config.model_params.model_copy( | ||
| update={"max_new_tokens": acc_cfg.max_new_tokens} | ||
| ) | ||
| ) |
There was a problem hiding this comment.
Use the new get_model_params helper method on the Dataset configuration model to simplify the override logic and eliminate duplication.
| # Per-dataset max_new_tokens override (falls back to global model_params). | |
| acc_model_params = ( | |
| config.model_params | |
| if acc_cfg.max_new_tokens is None | |
| else config.model_params.model_copy( | |
| update={"max_new_tokens": acc_cfg.max_new_tokens} | |
| ) | |
| ) | |
| acc_model_params = acc_cfg.get_model_params(config.model_params) |
| # Per-dataset max_new_tokens override (falls back to global model_params). | ||
| perf_model_params = ( | ||
| config.model_params | ||
| if perf_cfg.max_new_tokens is None | ||
| else config.model_params.model_copy( | ||
| update={"max_new_tokens": perf_cfg.max_new_tokens} | ||
| ) | ||
| ) |
There was a problem hiding this comment.
Use the new get_model_params helper method on the Dataset configuration model to simplify the override logic and eliminate duplication.
| # Per-dataset max_new_tokens override (falls back to global model_params). | |
| perf_model_params = ( | |
| config.model_params | |
| if perf_cfg.max_new_tokens is None | |
| else config.model_params.model_copy( | |
| update={"max_new_tokens": perf_cfg.max_new_tokens} | |
| ) | |
| ) | |
| perf_model_params = perf_cfg.get_model_params(config.model_params) |
| @pytest.mark.unit | ||
| def test_max_new_tokens_defaults_none(self): | ||
| ds = Dataset(name="perf", type=DatasetType.PERFORMANCE, path="data.jsonl") | ||
| assert ds.max_new_tokens is None | ||
|
|
||
| @pytest.mark.unit | ||
| def test_per_dataset_max_new_tokens_override(self): | ||
| ds = Dataset( | ||
| name="aime25", | ||
| type=DatasetType.ACCURACY, | ||
| path="aime25.jsonl", | ||
| eval_method=EvalMethod.EXACT_MATCH, | ||
| max_new_tokens=32768, | ||
| ) | ||
| assert ds.max_new_tokens == 32768 | ||
|
|
||
| @pytest.mark.unit | ||
| def test_max_new_tokens_rejects_non_positive(self): | ||
| with pytest.raises(ValueError, match="greater than 0"): | ||
| Dataset( | ||
| name="perf", | ||
| type=DatasetType.PERFORMANCE, | ||
| path="data.jsonl", | ||
| max_new_tokens=0, | ||
| ) |
There was a problem hiding this comment.
Add unit tests to verify the correctness of the new get_model_params helper method on the Dataset configuration model.
@pytest.mark.unit
def test_max_new_tokens_defaults_none(self):
ds = Dataset(name="perf", type=DatasetType.PERFORMANCE, path="data.jsonl")
assert ds.max_new_tokens is None
@pytest.mark.unit
def test_per_dataset_max_new_tokens_override(self):
ds = Dataset(
name="aime25",
type=DatasetType.ACCURACY,
path="aime25.jsonl",
eval_method=EvalMethod.EXACT_MATCH,
max_new_tokens=32768,
)
assert ds.max_new_tokens == 32768
@pytest.mark.unit
def test_max_new_tokens_rejects_non_positive(self):
with pytest.raises(ValueError, match="greater than 0"):
Dataset(
name="perf",
type=DatasetType.PERFORMANCE,
path="data.jsonl",
max_new_tokens=0,
)
@pytest.mark.unit
def test_get_model_params_override(self):
global_params = ModelParams(name="test", max_new_tokens=1024)
ds_no_override = Dataset(name="perf", type=DatasetType.PERFORMANCE, path="data.jsonl")
assert ds_no_override.get_model_params(global_params).max_new_tokens == 1024
ds_with_override = Dataset(
name="aime25",
type=DatasetType.ACCURACY,
path="aime25.jsonl",
max_new_tokens=32768,
)
assert ds_with_override.get_model_params(global_params).max_new_tokens == 32768
What does this PR do?
When running a combined performance + accuracy benchmark in a single --mode both invocation, the two phases want opposite generation caps, but today the harness only exposes one global model_params.max_new_tokens:
Type of change
Related issues
N/A
Testing
Checklist