Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,10 @@ class RunConfig(ConfigBase):
buffer_size: Number of records to process in each batch during dataset generation.
A batch is processed end-to-end (column generation, post-batch processors, and writing the batch
to artifact storage) before moving on to the next batch. Must be > 0. Default is 1000.
max_in_flight_tasks: Maximum number of async scheduler tasks that may hold task
leases at once. Tasks may be executing, awaiting I/O, or waiting on model
request admission. Model API request concurrency is controlled separately by
``max_parallel_requests``. Must be >= 1. Default is 1024.
non_inference_max_parallel_workers: Maximum number of worker threads used for non-inference
cell-by-cell generators. Must be >= 1. Default is 4.
max_conversation_restarts: Maximum number of full conversation restarts permitted when
Expand Down Expand Up @@ -168,6 +172,14 @@ class RunConfig(ConfigBase):
shutdown_error_rate: float = Field(default=0.5, ge=0.0, le=1.0)
shutdown_error_window: int = Field(default=10, ge=1)
buffer_size: int = Field(default=1000, gt=0)
max_in_flight_tasks: int = Field(
default=1024,
ge=1,
description=(
"Maximum number of async scheduler tasks that may hold task leases at once. "
"Model API request concurrency is controlled separately by max_parallel_requests."
),
)
non_inference_max_parallel_workers: int = Field(default=4, ge=1)
max_conversation_restarts: int = Field(default=5, ge=0)
max_conversation_correction_steps: int = Field(default=0, ge=0)
Expand Down
15 changes: 15 additions & 0 deletions packages/data-designer-config/tests/config/test_run_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,21 @@ def test_run_config_accepts_disabled_dropped_column_preservation() -> None:
assert run_config.preserve_dropped_columns is False


def test_run_config_defaults_max_in_flight_tasks_to_1024() -> None:
    """A RunConfig built with no arguments reports max_in_flight_tasks of 1024."""
    default_config = RunConfig()
    assert default_config.max_in_flight_tasks == 1024


def test_run_config_accepts_custom_max_in_flight_tasks() -> None:
    """An explicitly supplied max_in_flight_tasks value is kept as given."""
    configured_cap = RunConfig(max_in_flight_tasks=2048).max_in_flight_tasks
    assert configured_cap == 2048


def test_run_config_rejects_invalid_max_in_flight_tasks() -> None:
    """Constructing RunConfig with max_in_flight_tasks=0 raises a ValidationError naming the field."""
    # 0 sits just below the field's declared lower bound (ge=1).
    below_minimum = {"max_in_flight_tasks": 0}
    with pytest.raises(ValidationError, match="max_in_flight_tasks"):
        RunConfig(**below_minimum)


def test_run_config_throttle_shim_rejects_unknown_legacy_fields() -> None:
with pytest.raises(ValidationError, match="max_concurrent_requests"):
RunConfig(throttle={"max_concurrent_requests": 1})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
stable_task_id,
)
from data_designer.engine.dataset_builders.scheduling.task_admission import (
DEFAULT_IN_FLIGHT_TASK_CAPACITY,
TaskAdmissionConfig,
TaskAdmissionController,
TaskAdmissionDenied,
Expand Down Expand Up @@ -76,8 +77,6 @@

logger = logging.getLogger(__name__)

DEFAULT_TASK_POOL_SIZE: int = 256
MODEL_TASK_ADMISSION_HEADROOM_MULTIPLIER: int = 2
MODEL_GROUP_ADMISSION_BACKLOG_MULTIPLIER: int = 2

# Degraded-provider WARN: emit at most one warning per interval when the
Expand Down Expand Up @@ -144,8 +143,8 @@ def __init__(
buffer_manager: RowGroupBufferManager | None = None,
*,
max_concurrent_row_groups: int = 3,
max_submitted_tasks: int = DEFAULT_TASK_POOL_SIZE,
max_model_task_admission: int = DEFAULT_TASK_POOL_SIZE,
max_in_flight_tasks: int = DEFAULT_IN_FLIGHT_TASK_CAPACITY,
max_model_task_admission: int = DEFAULT_IN_FLIGHT_TASK_CAPACITY,
task_admission_config: TaskAdmissionConfig | None = None,
salvage_max_rounds: int = 2,
on_finalize_row_group: Callable[[int], None] | None = None,
Expand Down Expand Up @@ -183,8 +182,8 @@ def __init__(
model_group_limit_cap=max_model_task_admission,
)
admission_config = task_admission_config or TaskAdmissionConfig(
submission_capacity=max_submitted_tasks,
resource_limits={"llm_wait": max_model_task_admission, "local": max_submitted_tasks},
submission_capacity=max_in_flight_tasks,
resource_limits={"llm_wait": max_model_task_admission},
)
self._task_admission = TaskAdmissionController(admission_config)
self._task_admission_config = admission_config
Expand Down Expand Up @@ -277,7 +276,7 @@ def __init__(
# Pre-compute row-group sizes for O(1) lookup
self._rg_size_map: dict[int, int] = dict(row_groups)
self._max_concurrent_row_groups = max_concurrent_row_groups
self._max_submitted_tasks = max_submitted_tasks
self._max_in_flight_tasks = max_in_flight_tasks
self._max_model_task_admission = max_model_task_admission
self._num_records = num_records
self._buffer_size = buffer_size
Expand Down Expand Up @@ -910,7 +909,7 @@ def _adaptive_row_group_block_reason(self) -> str | None:
if not self._row_group_row_guard_allows(next_size):
return "max_admitted_rows"
queue_view = self._fair_queue.view()
queue_guard = max(self._max_submitted_tasks * 4, self._max_model_task_admission * 2)
queue_guard = self._max_in_flight_tasks * 4
if queue_view.queued_total >= queue_guard:
return "queued_task_guardrail"
task_view = self._task_admission.view()
Expand Down Expand Up @@ -1907,7 +1906,7 @@ def capacity_plan(self) -> AsyncCapacityPlan:
max_admitted_rows=self._adaptive_max_admitted_rows,
blocked_reasons=dict(self._row_group_admission_blocked_reasons),
),
submission_capacity=CapacityValue(value=self._max_submitted_tasks, source="dataset_builder"),
submission_capacity=CapacityValue(value=self._max_in_flight_tasks, source="run_config"),
task_resource_limits=CapacityValue(
value=dict(self._task_admission_config.resource_limits),
source="engine_internal_config",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,6 @@
import asyncio

from data_designer.engine.dataset_builders.async_scheduler import (
DEFAULT_TASK_POOL_SIZE,
MODEL_TASK_ADMISSION_HEADROOM_MULTIPLIER,
AsyncTaskScheduler,
)
from data_designer.engine.dataset_builders.scheduling.completion import CompletionTracker, FrontierDelta
Expand Down Expand Up @@ -1055,19 +1053,17 @@ def on_before_checkpoint(rg_id: int, rg_size: int) -> None:
df = self._processor_runner.run_post_batch(df, current_batch_number=rg_id, strict_row_count=True)
buffer_manager.replace_dataframe(rg_id, df)

# Coarse upper bound used only for scheduler task-stage model admission.
# Concrete provider/model request capacity is enforced by request admission
# at the model-call boundary.
aggregate = self._resource_provider.model_registry.get_aggregate_max_parallel_requests()
max_in_flight_tasks = self._resource_provider.run_config.max_in_flight_tasks
max_model_task_admission = max_in_flight_tasks

scheduler = AsyncTaskScheduler(
generators=gen_map,
graph=graph,
tracker=tracker,
row_groups=row_groups,
buffer_manager=buffer_manager,
max_submitted_tasks=DEFAULT_TASK_POOL_SIZE,
max_model_task_admission=max(DEFAULT_TASK_POOL_SIZE, MODEL_TASK_ADMISSION_HEADROOM_MULTIPLIER * aggregate),
max_in_flight_tasks=max_in_flight_tasks,
max_model_task_admission=max_model_task_admission,
on_finalize_row_group=on_finalize_row_group,
on_seeds_complete=(
on_seeds_complete if self._processor_runner.has_processors_for(ProcessorStage.PRE_BATCH) else None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,14 @@
"unknown_lease",
]
RELEASED_TASK_LEASE_HISTORY_LIMIT = 8192
DEFAULT_IN_FLIGHT_TASK_CAPACITY = 1024


@dataclass(frozen=True)
class TaskAdmissionConfig:
"""Engine-internal scheduler task-stage admission configuration."""

submission_capacity: int = 256
submission_capacity: int = DEFAULT_IN_FLIGHT_TASK_CAPACITY
resource_limits: Mapping[SchedulerResourceKey, int] = field(default_factory=dict)
bounded_borrow: BoundedBorrowTaskAdmissionPolicyConfig | None = None

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -201,11 +201,13 @@ def __init__(self, **kwargs: object) -> None:
monkeypatch.setattr(builder_mod, "AsyncTaskScheduler", _SpyScheduler)
request_admission = object()
model_registry = MagicMock()
model_registry.get_aggregate_max_parallel_requests.return_value = 2
model_registry.get_aggregate_max_parallel_requests.side_effect = AssertionError(
"model task admission should follow max_in_flight_tasks directly"
)
model_registry.request_admission = request_admission
provider = SimpleNamespace(
model_registry=model_registry,
run_config=SimpleNamespace(progress_interval=5.0, progress_bar=False),
run_config=SimpleNamespace(max_in_flight_tasks=64, progress_interval=5.0, progress_bar=False),
)
processor_runner = MagicMock()
processor_runner.has_processors_for.return_value = False
Expand All @@ -222,6 +224,8 @@ def __init__(self, **kwargs: object) -> None:

assert captured_kwargs["request_pressure_provider"] is request_admission
assert captured_kwargs["request_pressure_advisory"] is True
assert captured_kwargs["max_in_flight_tasks"] == 64
assert captured_kwargs["max_model_task_admission"] == 64


# -- Test that existing sync path is unaffected --------------------------------
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -856,8 +856,8 @@ async def test_scheduler_stateful_generator_serializes() -> None:


@pytest.mark.asyncio(loop_scope="session")
async def test_scheduler_bounded_submission() -> None:
"""Submitted task count respects max_submitted_tasks."""
async def test_scheduler_bounded_in_flight_tasks() -> None:
"""In-flight task count respects max_in_flight_tasks."""
provider = _mock_provider()

# Use a pipeline with many cells and low submission limit
Expand All @@ -883,7 +883,7 @@ async def test_scheduler_bounded_submission() -> None:
graph=graph,
tracker=tracker,
row_groups=row_groups,
max_submitted_tasks=2,
max_in_flight_tasks=2,
)
await scheduler.run()

Expand Down Expand Up @@ -1821,22 +1821,22 @@ async def test_scheduler_llm_bound_one_way_handoff() -> None:
row_groups = [(0, 3)]
tracker = CompletionTracker.with_graph(graph, row_groups)

max_submitted = 2
max_in_flight = 2
max_llm_wait = 2
scheduler = AsyncTaskScheduler(
generators=generators,
graph=graph,
tracker=tracker,
row_groups=row_groups,
max_submitted_tasks=max_submitted,
max_in_flight_tasks=max_in_flight,
max_model_task_admission=max_llm_wait,
)
await scheduler.run()

assert tracker.is_row_group_complete(0, 3, ["seed", "llm_col"])

snapshot = scheduler.task_admission_snapshot()
assert snapshot.resources_available["submission"] == max_submitted
assert snapshot.resources_available["submission"] == max_in_flight
assert snapshot.resources_available["llm_wait"] == max_llm_wait


Expand Down Expand Up @@ -1867,7 +1867,7 @@ async def test_scheduler_non_llm_holds_submission_slot() -> None:
graph=graph,
tracker=tracker,
row_groups=row_groups,
max_submitted_tasks=2,
max_in_flight_tasks=2,
max_model_task_admission=max_llm_wait,
)
await scheduler.run()
Expand All @@ -1880,7 +1880,7 @@ async def test_scheduler_non_llm_holds_submission_slot() -> None:

@pytest.mark.asyncio(loop_scope="session")
async def test_scheduler_deadlock_regression() -> None:
"""max_submitted_tasks=1, max_model_task_admission=1, two ready LLM tasks completes without deadlock."""
"""max_in_flight_tasks=1, max_model_task_admission=1, two ready LLM tasks completes without deadlock."""
provider = _mock_provider()
configs = [
SamplerColumnConfig(name="seed", sampler_type=SamplerType.CATEGORY, params={"values": ["A"]}),
Expand All @@ -1904,7 +1904,7 @@ async def test_scheduler_deadlock_regression() -> None:
graph=graph,
tracker=tracker,
row_groups=row_groups,
max_submitted_tasks=1,
max_in_flight_tasks=1,
max_model_task_admission=1,
)

Expand Down Expand Up @@ -2379,23 +2379,23 @@ async def test_scheduler_llm_bound_429_retried_in_salvage() -> None:
storage.move_partial_result_to_final_file_path.return_value = "/fake_final.parquet"
buffer_mgr = RowGroupBufferManager(storage)

max_submitted = 4
max_in_flight = 4
max_llm_wait = 2
scheduler = AsyncTaskScheduler(
generators=generators,
graph=graph,
tracker=tracker,
row_groups=row_groups,
buffer_manager=buffer_mgr,
max_submitted_tasks=max_submitted,
max_in_flight_tasks=max_in_flight,
max_model_task_admission=max_llm_wait,
)
await scheduler.run()

assert tracker.is_row_group_complete(0, num_records, ["seed", "llm_col"])

snapshot = scheduler.task_admission_snapshot()
assert snapshot.resources_available["submission"] == max_submitted
assert snapshot.resources_available["submission"] == max_in_flight
assert snapshot.resources_available["llm_wait"] == max_llm_wait


Expand Down Expand Up @@ -2441,15 +2441,15 @@ async def agenerate(self, data: dict) -> dict:
row_groups = [(0, 2)]
tracker = CompletionTracker.with_graph(graph, row_groups)

max_submitted = 4
max_in_flight = 4
max_llm_wait = 2
sink = InMemoryAdmissionEventSink()
scheduler = AsyncTaskScheduler(
generators=generators,
graph=graph,
tracker=tracker,
row_groups=row_groups,
max_submitted_tasks=max_submitted,
max_in_flight_tasks=max_in_flight,
max_model_task_admission=max_llm_wait,
scheduler_event_sink=sink,
)
Expand All @@ -2462,7 +2462,7 @@ async def agenerate(self, data: dict) -> dict:
await run_task

snapshot = scheduler.task_admission_snapshot()
assert snapshot.resources_available["submission"] == max_submitted
assert snapshot.resources_available["submission"] == max_in_flight
assert snapshot.resources_available["llm_wait"] == max_llm_wait
assert "cancelled" in [event.event_kind for event in sink.scheduler_events]
assert all(event.snapshot is not None for event in sink.scheduler_events)
Expand Down Expand Up @@ -2684,7 +2684,7 @@ async def test_scheduler_fair_admission_across_ready_columns() -> None:
graph=graph,
tracker=tracker,
row_groups=row_groups,
max_submitted_tasks=4,
max_in_flight_tasks=4,
trace=True,
)

Expand Down Expand Up @@ -2758,7 +2758,7 @@ async def agenerate_from_scratch(self, num_records: int) -> lazy.pd.DataFrame:
graph=graph,
tracker=tracker,
row_groups=row_groups,
max_submitted_tasks=8,
max_in_flight_tasks=8,
max_concurrent_row_groups=2,
trace=True,
)
Expand Down Expand Up @@ -2806,7 +2806,7 @@ async def test_scheduler_fair_llm_group_cap_preserves_peer_admission() -> None:
graph=graph,
tracker=tracker,
row_groups=row_groups,
max_submitted_tasks=4,
max_in_flight_tasks=4,
max_model_task_admission=4,
trace=True,
)
Expand Down Expand Up @@ -2877,7 +2877,7 @@ async def test_scheduler_downstream_interleaves_with_upstream() -> None:
tracker=tracker,
row_groups=row_groups,
buffer_manager=buffer_manager,
max_submitted_tasks=4,
max_in_flight_tasks=4,
trace=True,
)
await asyncio.wait_for(scheduler.run(), timeout=10.0)
Expand Down Expand Up @@ -2925,7 +2925,7 @@ async def test_scheduler_capacity_plan_observes_buffer_backpressure() -> None:
tracker=tracker,
row_groups=row_groups,
max_concurrent_row_groups=2,
max_submitted_tasks=2,
max_in_flight_tasks=2,
trace=True,
num_records=12,
buffer_size=3,
Expand Down Expand Up @@ -3023,7 +3023,7 @@ async def test_scheduler_emits_job_health_and_row_group_telemetry() -> None:
tracker=tracker,
row_groups=row_groups,
max_concurrent_row_groups=1,
max_submitted_tasks=2,
max_in_flight_tasks=2,
max_model_task_admission=1,
scheduler_event_sink=sink,
num_records=2,
Expand Down Expand Up @@ -3089,7 +3089,7 @@ async def test_scheduler_adaptive_row_group_admission_expands_target_for_horizon
tracker=tracker,
row_groups=row_groups,
max_concurrent_row_groups=4,
max_submitted_tasks=4,
max_in_flight_tasks=4,
max_model_task_admission=4,
adaptive_row_group_admission=True,
adaptive_row_group_initial_target=1,
Expand Down Expand Up @@ -3189,6 +3189,17 @@ def test_scheduler_adaptive_row_group_block_reason_prefers_llm_saturation() -> N
assert scheduler._adaptive_row_group_block_reason() == "llm_wait_saturated"


def test_scheduler_adaptive_row_group_queue_guard_uses_in_flight_task_cap() -> None:
    """The queued-task guardrail is reported when the queue reaches 4x the in-flight cap.

    The model task admission limit is deliberately set much higher than the
    in-flight cap; the expected block reason must still be the queue guardrail.
    """

    def _stub_queue_view() -> SimpleNamespace:
        # A fresh view per call: 8 queued tasks and no peer demand on any resource.
        return SimpleNamespace(queued_total=8, queued_peer_demand_by_resource={})

    scheduler, _tracker = _build_simple_pipeline(num_records=2, buffer_size=1)
    scheduler._max_in_flight_tasks = 2
    scheduler._max_model_task_admission = 100
    # Replace the fair queue with a stub that only exposes the view() accessor.
    scheduler._fair_queue = SimpleNamespace(view=_stub_queue_view)

    block_reason = scheduler._adaptive_row_group_block_reason()
    assert block_reason == "queued_task_guardrail"


@pytest.mark.asyncio(loop_scope="session")
async def test_scheduler_raises_when_ready_frontier_blocked_without_in_flight() -> None:
provider = _mock_provider()
Expand Down
Loading
Loading