diff --git a/packages/data-designer-config/src/data_designer/config/run_config.py b/packages/data-designer-config/src/data_designer/config/run_config.py index 502410ab1..6d38f6f15 100644 --- a/packages/data-designer-config/src/data_designer/config/run_config.py +++ b/packages/data-designer-config/src/data_designer/config/run_config.py @@ -24,6 +24,7 @@ class JinjaRenderingEngine(StrEnum): "RunConfig.throttle and ThrottleConfig are deprecated. Use RunConfig.request_admission with " "RequestAdmissionTuningConfig for supported advanced request-admission tuning." ) +_PROGRESS_BAR_DEPRECATION_MESSAGE = "RunConfig.progress_bar is deprecated. Use RunConfig.display_tui instead." class RequestAdmissionTuningConfig(ConfigBase): @@ -142,9 +143,9 @@ class RunConfig(ConfigBase): Default is 0. async_trace: If True, collect per-task tracing data when using the async engine (DATA_DESIGNER_ASYNC_ENGINE=1). Has no effect on the sync path. Default is False. - progress_bar: If True, display sticky ANSI progress bars instead of periodic log lines - during generation. Requires a TTY; falls back to log lines in non-TTY environments. - Default is False. + display_tui: If True, display the terminal throughput TUI instead of periodic + log lines during generation. Requires a TTY; falls back to log lines in + non-TTY environments. Default is True. progress_interval: How often (in seconds) the async progress reporter emits a consolidated log block. Must be > 0. Default is 5.0. preserve_dropped_columns: If True, write columns removed by drop processors to @@ -172,7 +173,7 @@ class RunConfig(ConfigBase): max_conversation_restarts: int = Field(default=5, ge=0) max_conversation_correction_steps: int = Field(default=0, ge=0) async_trace: bool = False - progress_bar: bool = False + display_tui: bool = True progress_interval: float = Field(default=5.0, gt=0.0) preserve_dropped_columns: bool = Field( default=True, @@ -191,9 +192,22 @@ class RunConfig(ConfigBase): @model_validator(mode="before") @classmethod - def translate_deprecated_throttle_config(cls, data: Any) -> Any: - if isinstance(data, dict) and "throttle" in data: - normalized = dict(data) + def translate_deprecated_fields(cls, data: Any) -> Any: + if not isinstance(data, dict): + return data + + normalized = dict(data) + + if "progress_bar" in normalized: + progress_bar = normalized.pop("progress_bar") + normalized.setdefault("display_tui", progress_bar) + warnings.warn( + _PROGRESS_BAR_DEPRECATION_MESSAGE, + DeprecationWarning, + stacklevel=2, + ) + + if "throttle" in normalized: throttle = normalized.pop("throttle") if normalized.get("request_admission") is not None: raise ValueError( @@ -211,7 +225,25 @@ def translate_deprecated_throttle_config(cls, data: Any) -> Any: stacklevel=2, ) return normalized - return data + return normalized + + @property + def progress_bar(self) -> bool: + warnings.warn( + _PROGRESS_BAR_DEPRECATION_MESSAGE, + DeprecationWarning, + stacklevel=2, + ) + return self.display_tui + + @progress_bar.setter + def progress_bar(self, value: bool) -> None: + warnings.warn( + _PROGRESS_BAR_DEPRECATION_MESSAGE, + DeprecationWarning, + stacklevel=2, + ) + self.display_tui = value @model_validator(mode="after") def normalize_shutdown_settings(self) -> Self: diff --git a/packages/data-designer-config/tests/config/test_run_config.py b/packages/data-designer-config/tests/config/test_run_config.py index 1d6efd9c1..46cd87879 100644 --- a/packages/data-designer-config/tests/config/test_run_config.py +++ b/packages/data-designer-config/tests/config/test_run_config.py @@ -24,6 +24,37 @@ def test_run_config_accepts_native_renderer() -> None: assert JinjaRenderingEngine(run_config.jinja_rendering_engine) == JinjaRenderingEngine.NATIVE +def test_run_config_defaults_to_display_tui_enabled() -> None: + assert RunConfig().display_tui is True + + +def test_run_config_accepts_display_tui() -> None: + assert RunConfig(display_tui=False).display_tui is False + + +def test_run_config_progress_bar_shim_translates_to_display_tui() -> None: + with pytest.warns(DeprecationWarning, match="RunConfig.progress_bar.*RunConfig.display_tui"): + run_config = RunConfig(progress_bar=False) + + assert run_config.display_tui is False + + +def test_run_config_progress_bar_property_getter_warns() -> None: + run_config = RunConfig(display_tui=False) + + with pytest.warns(DeprecationWarning, match="RunConfig.progress_bar.*RunConfig.display_tui"): + assert run_config.progress_bar is False + + +def test_run_config_progress_bar_property_setter_warns() -> None: + run_config = RunConfig(display_tui=False) + + with pytest.warns(DeprecationWarning, match="RunConfig.progress_bar.*RunConfig.display_tui"): + run_config.progress_bar = True + + assert run_config.display_tui is True + + def test_run_config_preserves_dropped_columns_by_default() -> None: assert RunConfig().preserve_dropped_columns is True diff --git a/packages/data-designer-engine/pyproject.toml b/packages/data-designer-engine/pyproject.toml index 87f19785b..83ea684a8 100644 --- a/packages/data-designer-engine/pyproject.toml +++ b/packages/data-designer-engine/pyproject.toml @@ -33,6 +33,7 @@ bump = true [tool.hatch.metadata.hooks.uv-dynamic-versioning] dependencies = [ "anyascii>=0.3.3,<1", + "asciichartpy>=1.5.25,<2", "chardet>=3.0.2,<6", # Pulled in by sqlfluff "cryptography>=46.0.7,<47", # 46.0.7 fixes CVE-2026-39892 pulled in by mcp "data-designer-config=={{ version }}", diff --git a/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/custom.py b/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/custom.py index 08c78120b..09c7ac314 100644 --- a/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/custom.py +++ b/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/custom.py @@ -7,6 +7,7 @@ import asyncio import concurrent.futures +import contextlib import inspect import logging from typing import TYPE_CHECKING, Any @@ -86,6 +87,11 @@ def generate(self, *args: Any, **kwargs: Any) -> tuple[Any, list]: "Use 'await model.agenerate()' in async custom columns." ) + from data_designer.engine.context import ( + current_generation_column, + current_run_cancel_event, + is_run_cancellation_requested, + ) from data_designer.engine.dataset_builders.utils.async_concurrency import ensure_async_engine_loop # Honor a per-call ``timeout=`` override (passed straight through to the @@ -99,10 +105,41 @@ def generate(self, *args: Any, **kwargs: Any) -> tuple[Any, list]: conversation_restarts = int(kwargs.get("max_conversation_restarts", 0) or 0) bridge_timeout = _compute_bridge_timeout(per_request_timeout, correction_steps, conversation_restarts) + if is_run_cancellation_requested(): + raise asyncio.CancelledError + + column = current_generation_column.get() + cancel_event = current_run_cancel_event.get() + + async def agenerate_with_bridge_context() -> tuple[Any, list]: + column_token = current_generation_column.set(column) + cancel_token = current_run_cancel_event.set(cancel_event) + try: + if is_run_cancellation_requested(): + raise asyncio.CancelledError + return await facade.agenerate(*args, **kwargs) + finally: + # Cross-thread cancellation can close the coroutine from a + # different context after it has started. The task context is + # being discarded either way, so avoid an unraisable reset error. + with contextlib.suppress(ValueError): + current_run_cancel_event.reset(cancel_token) + with contextlib.suppress(ValueError): + current_generation_column.reset(column_token) + loop = ensure_async_engine_loop() - future = asyncio.run_coroutine_threadsafe(facade.agenerate(*args, **kwargs), loop) + coro = agenerate_with_bridge_context() + try: + future = asyncio.run_coroutine_threadsafe(coro, loop) + except RuntimeError as exc: + coro.close() + if is_run_cancellation_requested() or "interpreter shutdown" in str(exc): + raise asyncio.CancelledError from exc + raise try: return future.result(timeout=bridge_timeout) + except concurrent.futures.CancelledError as exc: + raise asyncio.CancelledError from exc except concurrent.futures.TimeoutError as exc: future.cancel() # Demoted to debug: the raised ModelTimeoutError already surfaces diff --git a/packages/data-designer-engine/src/data_designer/engine/context.py b/packages/data-designer-engine/src/data_designer/engine/context.py index 500b6bb51..137a4f65a 100644 --- a/packages/data-designer-engine/src/data_designer/engine/context.py +++ b/packages/data-designer-engine/src/data_designer/engine/context.py @@ -4,11 +4,26 @@ from __future__ import annotations from contextvars import ContextVar +from threading import Event # Set by the async scheduler before executing each task. # Value: (current_rg_index, total_rg_count) or None. current_row_group: ContextVar[tuple[int, int] | None] = ContextVar("current_row_group", default=None) +# Set by the async scheduler before executing each task so model usage can be +# attributed even when scheduler telemetry context is not available. +current_generation_column: ContextVar[str | None] = ContextVar("current_generation_column", default=None) + +# Shared cancellation signal for sync generator work running in thread-pool +# workers. Context variables copy the Event object into worker threads, and the +# scheduler flips the Event on cancellation. +current_run_cancel_event: ContextVar[Event | None] = ContextVar("current_run_cancel_event", default=None) + + +def is_run_cancellation_requested() -> bool: + cancel_event = current_run_cancel_event.get() + return cancel_event.is_set() if cancel_event is not None else False + def format_row_group_tag() -> str: """Return a '(x/X) ' prefix if a row group context is active, else ''.""" diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py index 9109eafcc..8623e22cd 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py @@ -12,6 +12,7 @@ from collections import Counter, defaultdict, deque from collections.abc import Coroutine, Mapping from dataclasses import dataclass +from threading import Event from typing import TYPE_CHECKING, Any, Callable import data_designer.lazy_heavy_imports as lazy @@ -25,7 +26,7 @@ RequestAdmissionConfigSnapshot, RowGroupAdmission, ) -from data_designer.engine.context import current_row_group +from data_designer.engine.context import current_generation_column, current_row_group, current_run_cancel_event from data_designer.engine.dataset_builders.errors import DatasetGenerationError from data_designer.engine.dataset_builders.multi_column_configs import MultiColumnConfig from data_designer.engine.dataset_builders.scheduling.completion import CompletionTracker, FrontierDelta @@ -44,17 +45,11 @@ TaskAdmissionLease, ) from data_designer.engine.dataset_builders.scheduling.task_model import SliceRef, Task, TaskTrace -from data_designer.engine.dataset_builders.utils.async_progress_reporter import ( - DEFAULT_REPORT_INTERVAL, - AsyncProgressReporter, -) -from data_designer.engine.dataset_builders.utils.progress_tracker import ProgressTracker from data_designer.engine.dataset_builders.utils.skip_evaluator import should_skip_column_for_record from data_designer.engine.dataset_builders.utils.skip_tracker import ( apply_skip_to_record, strip_skip_metadata_from_records, ) -from data_designer.engine.dataset_builders.utils.sticky_progress_bar import StickyProgressBar from data_designer.engine.errors import DataDesignerError from data_designer.engine.models.clients.errors import ProviderError from data_designer.engine.models.errors import RETRYABLE_MODEL_ERRORS, GenerationValidationFailureError @@ -67,6 +62,12 @@ SchedulerAdmissionEventSink, runtime_correlation_provider, ) +from data_designer.engine.progress.reporter import ( + DEFAULT_REPORT_INTERVAL, + AsyncProgressReporter, +) +from data_designer.engine.progress.terminal.throughput_panel import TerminalThroughputPanel +from data_designer.engine.progress.tracker import ProgressTracker if TYPE_CHECKING: from data_designer.engine.column_generators.generators.base import ColumnGenerator @@ -161,7 +162,7 @@ def __init__( num_records: int = 0, buffer_size: int = 0, progress_interval: float | None = None, - progress_bar: bool = False, + display_tui: bool = False, scheduler_event_sink: SchedulerAdmissionEventSink | None = None, run_id: str | None = None, adaptive_row_group_admission: bool = False, @@ -273,6 +274,7 @@ def __init__( # engine drops rows and continues, losing the cause unless we capture it. self._first_non_retryable_error: Exception | None = None self._fatal_worker_error: BaseException | None = None + self._cancel_requested = Event() # Pre-compute row-group sizes for O(1) lookup self._rg_size_map: dict[int, int] = dict(row_groups) @@ -309,7 +311,7 @@ def __init__( self._seed_cols: tuple[str, ...] = tuple(c for c in graph.columns if not graph.get_upstream_columns(c)) # Per-column progress tracking (cell-by-cell only; full-column tasks are instant) - self._progress_bar = StickyProgressBar() if progress_bar else None + self._progress_bar = TerminalThroughputPanel() if display_tui else None self._reporter = self._setup_async_progress_reporter(num_records, buffer_size, progress_interval) def _setup_async_progress_reporter( @@ -340,6 +342,7 @@ def _setup_async_progress_reporter( trackers, report_interval=interval, progress_bar=self._progress_bar, + run_id=self._run_id, ) @property @@ -370,6 +373,10 @@ def first_non_retryable_error(self) -> Exception | None: """ return self._first_non_retryable_error + def request_cancel(self) -> None: + """Signal cancellation to scheduler tasks and bridged sync generator work.""" + self._cancel_requested.set() + def _raise_if_fatal_worker_error(self) -> None: if self._fatal_worker_error is None: return @@ -1004,50 +1011,58 @@ async def run(self) -> None: num_rgs = len(self._row_groups) with self._progress_bar or contextlib.nullcontext(): - if self._reporter: - self._reporter.log_start(num_row_groups=num_rgs) - - self._emit_scheduler_event("scheduler_job_started", diagnostics=self._scheduler_job_diagnostics()) - self._emit_scheduler_health_snapshot("start") - - # Launch admission as a background task so it interleaves with dispatch. - admission_task = asyncio.create_task(self._admit_row_groups()) - try: - # Main dispatch loop - await self._main_dispatch_loop(seed_cols, has_pre_batch, all_columns) - finally: - # Always cancel admission + drain in-flight workers, regardless - # of how the dispatch loop exited (normal, early shutdown, - # CancelledError, or processor failure). - if not admission_task.done(): - admission_task.cancel() - with contextlib.suppress(asyncio.CancelledError): - await admission_task - await asyncio.shield(self._cancel_workers()) - # Salvage partially-complete row groups left over from early - # shutdown. Must run AFTER _cancel_workers - in-flight tasks - # could otherwise write into a buffer that's being finalized. - if self._early_shutdown and self._rg_states: - self._finalize_after_shutdown(all_columns) - - # Reached only on the clean-exit path; an exception in the - # dispatch loop or the finally block propagates and skips this. - if self._reporter: - self._reporter.log_final() + if self._reporter: + self._reporter.log_start(num_row_groups=num_rgs) - self._emit_scheduler_health_snapshot("completed") - self._emit_scheduler_event( - "scheduler_job_completed", diagnostics=self._scheduler_health_diagnostics(reason="completed") - ) + self._emit_scheduler_event("scheduler_job_started", diagnostics=self._scheduler_job_diagnostics()) + self._emit_scheduler_health_snapshot("start") + + # Launch admission as a background task so it interleaves with dispatch. + admission_task = asyncio.create_task(self._admit_row_groups()) - if self._rg_states: - incomplete = list(self._rg_states) - logger.error( - f"Scheduler exited with {len(self._rg_states)} unfinished row group(s): {incomplete}. " - "These row groups were not checkpointed." + try: + # Main dispatch loop + try: + await self._main_dispatch_loop(seed_cols, has_pre_batch, all_columns) + except asyncio.CancelledError: + self.request_cancel() + raise + finally: + # Always cancel admission + drain in-flight workers, regardless + # of how the dispatch loop exited (normal, early shutdown, + # CancelledError, or processor failure). + if not admission_task.done(): + admission_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await admission_task + await asyncio.shield(self._cancel_workers()) + # Salvage partially-complete row groups left over from early + # shutdown. Must run AFTER _cancel_workers - in-flight tasks + # could otherwise write into a buffer that's being finalized. + if self._early_shutdown and self._rg_states: + self._finalize_after_shutdown(all_columns) + + # Reached only on the clean-exit path; an exception in the + # dispatch loop or the finally block propagates and skips this. + if self._reporter: + self._reporter.log_final() + + self._emit_scheduler_health_snapshot("completed") + self._emit_scheduler_event( + "scheduler_job_completed", diagnostics=self._scheduler_health_diagnostics(reason="completed") ) + if self._rg_states: + incomplete = list(self._rg_states) + logger.error( + f"Scheduler exited with {len(self._rg_states)} unfinished row group(s): {incomplete}. " + "These row groups were not checkpointed." + ) + finally: + if self._reporter: + self._reporter.close() + async def _main_dispatch_loop( self, seed_cols: tuple[str, ...], @@ -1446,6 +1461,8 @@ async def _execute_task_inner(self, task: Task, lease: TaskAdmissionLease, task_ """Core task execution logic.""" num_rgs = len(self._row_groups) token = current_row_group.set((task.row_group, num_rgs)) + column_token = current_generation_column.set(task.column) + cancel_token = current_run_cancel_event.set(self._cancel_requested) group = lease.item.group identity_hash = hashlib.sha1("\0".join(group.key.identity).encode()).hexdigest()[:16] correlation_token = runtime_correlation_provider.set( @@ -1463,6 +1480,8 @@ async def _execute_task_inner(self, task: Task, lease: TaskAdmissionLease, task_ await self._execute_task_inner_impl(task, lease, task_execution_id) finally: runtime_correlation_provider.reset(correlation_token) + current_run_cancel_event.reset(cancel_token) + current_generation_column.reset(column_token) current_row_group.reset(token) async def _execute_task_inner_impl(self, task: Task, lease: TaskAdmissionLease, task_execution_id: str) -> None: diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/dataset_builder.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/dataset_builder.py index 3abed6136..0a8da3808 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/dataset_builder.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/dataset_builder.py @@ -3,6 +3,7 @@ from __future__ import annotations +import concurrent.futures import contextlib import functools import json @@ -42,7 +43,6 @@ from data_designer.engine.dataset_builders.utils.dataset_batch_manager import DatasetBatchManager from data_designer.engine.dataset_builders.utils.execution_graph import ExecutionGraph from data_designer.engine.dataset_builders.utils.processor_runner import ProcessorRunner, ProcessorStage -from data_designer.engine.dataset_builders.utils.progress_tracker import ProgressTracker from data_designer.engine.dataset_builders.utils.skip_evaluator import should_skip_column_for_record from data_designer.engine.dataset_builders.utils.skip_tracker import ( SKIPPED_COLUMNS_RECORD_KEY, @@ -51,10 +51,11 @@ restore_skip_metadata, strip_skip_metadata_from_records, ) -from data_designer.engine.dataset_builders.utils.sticky_progress_bar import StickyProgressBar from data_designer.engine.models.telemetry import InferenceEvent, NemoSourceEnum, TaskStatusEnum, TelemetryHandler from data_designer.engine.processing.processors.base import Processor from data_designer.engine.processing.processors.drop_columns import DropColumnsProcessor +from data_designer.engine.progress.terminal.throughput_panel import TerminalThroughputPanel +from data_designer.engine.progress.tracker import ProgressTracker from data_designer.engine.registry.data_designer_registry import DataDesignerRegistry from data_designer.engine.resources.resource_provider import ResourceProvider from data_designer.engine.storage.artifact_storage import ( @@ -104,6 +105,23 @@ def _is_async_trace_enabled(settings: RunConfig) -> bool: return settings.async_trace or os.environ.get("DATA_DESIGNER_ASYNC_TRACE", "0") == "1" +def _await_async_scheduler_result(future: concurrent.futures.Future[Any], scheduler: Any) -> None: + try: + future.result() + except KeyboardInterrupt: + request_cancel = getattr(scheduler, "request_cancel", None) + if callable(request_cancel): + request_cancel() + future.cancel() + try: + future.result() + except concurrent.futures.CancelledError: + pass + except Exception: + logger.debug("Async scheduler raised while cancelling after KeyboardInterrupt", exc_info=True) + raise + + class _ConfigCompatibility(StrEnum): COMPATIBLE = "compatible" INCOMPATIBLE = "incompatible" @@ -641,7 +659,7 @@ def _build_async_preview(self, generators: list[ColumnGenerator], num_records: i loop = ensure_async_engine_loop() future = asyncio.run_coroutine_threadsafe(scheduler.run(), loop) try: - future.result() + _await_async_scheduler_result(future, scheduler) finally: self._task_traces = scheduler.traces self._early_shutdown = scheduler.early_shutdown @@ -937,7 +955,7 @@ def on_complete(final_path: Path | str | None) -> None: loop = ensure_async_engine_loop() future = asyncio.run_coroutine_threadsafe(scheduler.run(), loop) try: - future.result() + _await_async_scheduler_result(future, scheduler) finally: self._task_traces = scheduler.traces self._early_shutdown = scheduler.early_shutdown @@ -1084,7 +1102,7 @@ def on_before_checkpoint(rg_id: int, rg_size: int) -> None: num_records=num_records, buffer_size=buffer_size, progress_interval=self._resource_provider.run_config.progress_interval, - progress_bar=self._resource_provider.run_config.progress_bar, + display_tui=self._resource_provider.run_config.display_tui, request_pressure_provider=self._resource_provider.model_registry.request_admission, request_pressure_advisory=True, ) @@ -1367,7 +1385,7 @@ def _setup_fan_out( self, generator: ColumnGeneratorWithModelRegistry, max_workers: int, - progress_bar: StickyProgressBar | None = None, + progress_bar: TerminalThroughputPanel | None = None, ) -> tuple[ProgressTracker, dict[str, Any]]: if generator.get_generation_strategy() != GenerationStrategy.CELL_BY_CELL: raise DatasetGenerationError( @@ -1435,7 +1453,7 @@ def _finalize_fan_out(self, progress_tracker: ProgressTracker) -> None: def _fan_out_with_async(self, generator: ColumnGeneratorWithModelRegistry, max_workers: int) -> None: if getattr(generator.config, "tool_alias", None): logger.info("🛠️ Tool calling enabled") - bar = StickyProgressBar() if self._resource_provider.run_config.progress_bar else None + bar = TerminalThroughputPanel() if self._resource_provider.run_config.display_tui else None can_skip = self._column_can_skip(generator.config.name) with bar or contextlib.nullcontext(): progress_tracker, executor_kwargs = self._setup_fan_out(generator, max_workers, progress_bar=bar) @@ -1459,7 +1477,7 @@ def _fan_out_with_async(self, generator: ColumnGeneratorWithModelRegistry, max_w def _fan_out_with_threads(self, generator: ColumnGeneratorWithModelRegistry, max_workers: int) -> None: if getattr(generator.config, "tool_alias", None): logger.info("🛠️ Tool calling enabled") - bar = StickyProgressBar() if self._resource_provider.run_config.progress_bar else None + bar = TerminalThroughputPanel() if self._resource_provider.run_config.display_tui else None can_skip = self._column_can_skip(generator.config.name) with bar or contextlib.nullcontext(): progress_tracker, executor_kwargs = self._setup_fan_out(generator, max_workers, progress_bar=bar) diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/async_progress_reporter.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/async_progress_reporter.py deleted file mode 100644 index c394ae613..000000000 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/async_progress_reporter.py +++ /dev/null @@ -1,131 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -from __future__ import annotations - -import logging -import time -from typing import TYPE_CHECKING - -from data_designer.engine.dataset_builders.utils.progress_tracker import ProgressTracker -from data_designer.logging import LOG_INDENT - -if TYPE_CHECKING: - from data_designer.engine.dataset_builders.utils.sticky_progress_bar import StickyProgressBar - -logger = logging.getLogger(__name__) - -DEFAULT_REPORT_INTERVAL = 5.0 - - -class AsyncProgressReporter: - """Consolidated progress reporter for async generation. - - Owns per-column ProgressTracker instances (in quiet mode) and emits - a single grouped log block at most once per ``report_interval`` seconds. - """ - - def __init__( - self, - trackers: dict[str, ProgressTracker], - *, - report_interval: float = DEFAULT_REPORT_INTERVAL, - progress_bar: StickyProgressBar | None = None, - ) -> None: - self._trackers = trackers - self._report_interval = report_interval - self._start_time = time.perf_counter() - self._last_report_time: float = self._start_time - self._last_reported_total: int = -1 - self._bar = progress_bar - if self._bar is not None: - for col, tracker in trackers.items(): - self._bar.add_bar(col, f"column '{col}'", tracker.total_records) - - def log_start(self, num_row_groups: int) -> None: - cols = ", ".join(self._trackers) - total = sum(t.total_records for t in self._trackers.values()) - logger.info( - "⚡️ Async generation: %d column(s) (%s), %d tasks across %d row group(s)", - len(self._trackers), - cols, - total, - num_row_groups, - ) - - def record_success(self, column: str) -> None: - if tracker := self._trackers.get(column): - tracker.record_success() - self._maybe_report() - - def record_failure(self, column: str) -> None: - if tracker := self._trackers.get(column): - tracker.record_failure() - self._maybe_report() - - def record_skipped(self, column: str) -> None: - if tracker := self._trackers.get(column): - tracker.record_skipped() - self._maybe_report() - - def log_final(self) -> None: - if self._bar is not None and self._bar.is_active: - for col in self._trackers: - self._bar.remove_bar(col) - else: - self._emit() - elapsed = time.perf_counter() - self._start_time - snapshots = [tracker.get_snapshot(elapsed) for tracker in self._trackers.values()] - total_ok = sum(snapshot[2] for snapshot in snapshots) - total_fail = sum(snapshot[3] for snapshot in snapshots) - total_skipped = sum(snapshot[4] for snapshot in snapshots) - skipped_suffix = f", {total_skipped} skipped" if total_skipped else "" - logger.info( - "✅ Async generation complete [%.1fs]: %d ok, %d failed%s across %d column(s)", - elapsed, - total_ok, - total_fail, - skipped_suffix, - len(self._trackers), - ) - - def _maybe_report(self) -> None: - if self._bar is not None and self._bar.is_active: - self._update_bar() - return - now = time.perf_counter() - if now - self._last_report_time < self._report_interval: - return - self._last_report_time = now - self._emit() - - def _update_bar(self) -> None: - elapsed = time.perf_counter() - self._start_time - updates: dict[str, tuple[int, int, int]] = {} - for col, tracker in self._trackers.items(): - completed, _total, success, failed, _skipped, _pct, _rate, _emoji = tracker.get_snapshot(elapsed) - updates[col] = (completed, success, failed) - self._bar.update_many(updates) - - def _emit(self) -> None: - current_total = sum(tracker.get_snapshot(0.0)[0] for tracker in self._trackers.values()) - if current_total == self._last_reported_total: - return - self._last_reported_total = current_total - - elapsed = time.perf_counter() - self._start_time - logger.info("📊 Progress [%.1fs]:", elapsed) - for col, tracker in self._trackers.items(): - completed, total_records, _success, _failed, skipped, pct, rate, emoji = tracker.get_snapshot(elapsed) - skipped_suffix = f", {skipped} skipped" if skipped else "" - logger.info( - "%s%s %s: %d/%d (%.0f%%) %.1f rec/s%s", - LOG_INDENT, - emoji, - col, - completed, - total_records, - pct, - rate, - skipped_suffix, - ) diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/sticky_progress_bar.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/sticky_progress_bar.py deleted file mode 100644 index f4df06221..000000000 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/sticky_progress_bar.py +++ /dev/null @@ -1,214 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -from __future__ import annotations - -import logging -import re -import shutil -import sys -import time -from dataclasses import dataclass, field -from threading import Lock -from typing import TextIO - -BAR_FILLED = "█" -BAR_EMPTY = "░" -_ANSI_RE = re.compile(r"\033\[[0-9;]*m") - - -def _compute_stats_width(total: int) -> int: - """Compute the fixed width of the stats portion based on total records.""" - total_w = len(str(total)) - # " 100% | xxx/xxx | 9999.9 rec/s | eta 999s | xxx failed" - sample = f" 100% | {'9' * total_w}/{total} | 9999.9 rec/s | eta 999s | {'9' * total_w} failed" - return len(sample) - - -@dataclass -class _BarState: - label: str - total: int - completed: int = 0 - success: int = 0 - failed: int = 0 - start_time: float = field(default_factory=time.perf_counter) - stats_width: int = 0 - - def __post_init__(self) -> None: - self.stats_width = _compute_stats_width(self.total) - - -class StickyProgressBar: - """ANSI progress bar that sticks to the bottom of the terminal. - - Log messages (via standard ``logging``) are rendered above the bar - automatically. The bar redraws in-place after each update. - - Usage:: - - with StickyProgressBar() as bar: - bar.add_bar("col_a", "column 'a'", total=100) - for i in range(100): - bar.update("col_a", completed=i + 1, success=i + 1) - bar.remove_bar("col_a") - - Falls back to a no-op on non-TTY streams (CI, pipes, notebooks). - """ - - def __init__(self, stream: TextIO | None = None) -> None: - self._stream = stream or sys.stderr - self._is_tty = hasattr(self._stream, "isatty") and self._stream.isatty() - self._bars: dict[str, _BarState] = {} - self._lock = Lock() - self._drawn_lines = 0 - self._active = False - self._wrapped_handlers: list[tuple[logging.StreamHandler, object]] = [] - - @property - def is_active(self) -> bool: - return self._active - - @property - def drawn_lines(self) -> int: - return self._drawn_lines - - # -- context manager -- - - def __enter__(self) -> StickyProgressBar: - if self._is_tty: - self._active = True - self._wrap_handlers() - self._write("\033[?25l") # hide cursor - return self - - def __exit__(self, *args: object) -> None: - if self._active: - with self._lock: - self._clear_bars() - self._write("\033[?25h") # show cursor - self._unwrap_handlers() - self._active = False - - # -- public API -- - - def add_bar(self, key: str, label: str, total: int) -> None: - with self._lock: - self._bars[key] = _BarState(label=label, total=total) - if self._active: - self._redraw() - - def update( - self, - key: str, - *, - completed: int, - success: int = 0, - failed: int = 0, - ) -> None: - with self._lock: - if bar := self._bars.get(key): - bar.completed = completed - bar.success = success - bar.failed = failed - if self._active: - self._redraw() - - def update_many(self, updates: dict[str, tuple[int, int, int]]) -> None: - with self._lock: - for key, (completed, success, failed) in updates.items(): - if bar := self._bars.get(key): - bar.completed = completed - bar.success = success - bar.failed = failed - if self._active: - self._redraw() - - def remove_bar(self, key: str) -> None: - with self._lock: - self._bars.pop(key, None) - if self._active: - self._redraw() - - # -- handler wrapping -- - - def _wrap_handlers(self) -> None: - """Wrap stderr logging handlers so log lines render above the bars.""" - root = logging.getLogger() - for handler in root.handlers: - if not isinstance(handler, logging.StreamHandler): - continue - if getattr(handler, "stream", None) is not self._stream: - continue - original_emit = handler.emit - - def _make_wrapper(orig: object) -> object: - def wrapped_emit(record: logging.LogRecord) -> None: - with self._lock: - self._clear_bars() - orig(record) # type: ignore[operator] - self._redraw() - - return wrapped_emit - - handler.emit = _make_wrapper(original_emit) # type: ignore[assignment] - self._wrapped_handlers.append((handler, original_emit)) - - def _unwrap_handlers(self) -> None: - for handler, original_emit in self._wrapped_handlers: - handler.emit = original_emit # type: ignore[assignment] - self._wrapped_handlers.clear() - - # -- drawing -- - - def _clear_bars(self) -> None: - """Clear drawn bar lines from the terminal. Caller must hold the lock.""" - if self._drawn_lines > 0: - for _ in range(self._drawn_lines): - self._write("\033[A\033[2K") - self._write("\r\033[2K") - self._drawn_lines = 0 - - def _redraw(self) -> None: - """Redraw all bars. Caller must hold the lock.""" - self._clear_bars() - if not self._bars: - return - width = shutil.get_terminal_size().columns - max_label = max(len(b.label) for b in self._bars.values()) - for bar in self._bars.values(): - line = self._format_bar(bar, width, max_label) - self._write(line + "\n") - visible = len(_ANSI_RE.sub("", line)) - if width > 0 and visible > width: - self._drawn_lines += (visible + width - 1) // width - else: - self._drawn_lines += 1 - - def _format_bar(self, bar: _BarState, width: int, label_width: int) -> str: - completed = min(bar.completed, bar.total) - pct = (completed / bar.total * 100) if bar.total > 0 else 100.0 - elapsed = time.perf_counter() - bar.start_time - rate = min(bar.completed / elapsed if elapsed > 0 else 0.0, 9999.9) - remaining = max(0, bar.total - completed) - eta = f"{min(remaining / rate, 999):.0f}s" if rate > 0 else "?" - - label = bar.label.ljust(label_width) - total_w = len(str(bar.total)) - count_str = f"{completed:>{total_w}}/{bar.total}" - stats = f" {pct:3.0f}% | {count_str} | {rate:6.1f} rec/s | eta {eta:>4s} | {bar.failed:>{total_w}} failed" - stats = stats.ljust(bar.stats_width) - - bar_width = width - len(label) - bar.stats_width - 4 - if bar_width < 1: - return f" {label} {stats}"[: max(0, width - 1)] - - filled = int(bar_width * pct / 100) - empty = bar_width - filled - - colored_bar = f"\033[32m{BAR_FILLED * filled}\033[90m{BAR_EMPTY * empty}\033[0m" - return f" {label} {colored_bar}{stats}" - - def _write(self, text: str) -> None: - self._stream.write(text) - self._stream.flush() diff --git a/packages/data-designer-engine/src/data_designer/engine/models/facade.py b/packages/data-designer-engine/src/data_designer/engine/models/facade.py index 81a935282..3fe4738f4 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/facade.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/facade.py @@ -16,6 +16,7 @@ OPENROUTER_PROVIDER_NAME, ) from data_designer.config.utils.image_helpers import is_image_diffusion_model +from data_designer.engine.context import current_generation_column from data_designer.engine.mcp.errors import MCPConfigurationError from data_designer.engine.model_provider import ModelProviderRegistry from data_designer.engine.models.clients.types import ( @@ -42,7 +43,9 @@ RequestUsageStats, TokenUsageStats, ) +from data_designer.engine.models.usage_events import TokenUsageEvent, emit_token_usage_event from data_designer.engine.models.utils import ChatMessage, prompt_to_messages +from data_designer.engine.observability import runtime_correlation_provider if TYPE_CHECKING: from data_designer.engine.mcp.facade import MCPFacade @@ -851,9 +854,9 @@ def _track_usage(self, usage: Usage | None, *, is_request_successful: bool) -> N return token_usage = None - if usage is not None and usage.input_tokens is not None: + if usage is not None and (usage.input_tokens is not None or usage.output_tokens is not None): token_usage = TokenUsageStats( - input_tokens=usage.input_tokens, + input_tokens=usage.input_tokens or 0, output_tokens=usage.output_tokens or 0, reasoning_tokens=usage.reasoning_tokens, reasoning_token_count_source=usage.reasoning_token_count_source, @@ -863,3 +866,18 @@ def _track_usage(self, usage: Usage | None, *, is_request_successful: bool) -> N token_usage=token_usage, request_usage=RequestUsageStats(successful_requests=1, failed_requests=0), ) + if token_usage is not None: + correlation = runtime_correlation_provider.current() + column = current_generation_column.get() + if column is None and correlation is not None: + column = correlation.task_column + emit_token_usage_event( + TokenUsageEvent( + model_alias=self.model_alias, + model_name=self.model_name, + input_tokens=token_usage.input_tokens, + output_tokens=token_usage.output_tokens, + column=column, + correlation=correlation, + ) + ) diff --git a/packages/data-designer-engine/src/data_designer/engine/models/request_admission/controller.py b/packages/data-designer-engine/src/data_designer/engine/models/request_admission/controller.py index 38bbd0598..0cd3151af 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/request_admission/controller.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/request_admission/controller.py @@ -32,6 +32,7 @@ from data_designer.engine.observability import ( RequestAdmissionEvent, RequestAdmissionEventSink, + emit_request_admission_event, runtime_correlation_provider, ) @@ -778,11 +779,10 @@ def _request_event_locked( ) def _emit_events(self, events: list[RequestAdmissionEvent]) -> None: - if self._event_sink is None: - return for event in events: - try: - self._event_sink.emit_request_event(event) - except Exception: - logger.warning("Request admission event sink raised; dropping event.", exc_info=True) - continue + if self._event_sink is not None: + try: + self._event_sink.emit_request_event(event) + except Exception: + logger.warning("Request admission event sink raised; dropping event.", exc_info=True) + emit_request_admission_event(event) diff --git a/packages/data-designer-engine/src/data_designer/engine/models/usage_events.py b/packages/data-designer-engine/src/data_designer/engine/models/usage_events.py new file mode 100644 index 000000000..ea4ce8205 --- /dev/null +++ b/packages/data-designer-engine/src/data_designer/engine/models/usage_events.py @@ -0,0 +1,54 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import itertools +import logging +from collections.abc import Callable +from dataclasses import dataclass +from threading import Lock + +from data_designer.engine.observability import RuntimeCorrelation + +logger = logging.getLogger(__name__) + + +@dataclass(frozen=True) +class TokenUsageEvent: + model_alias: str + model_name: str + input_tokens: int + output_tokens: int + column: str | None = None + correlation: RuntimeCorrelation | None = None + + +TokenUsageCallback = Callable[[TokenUsageEvent], None] + +_callback_lock = Lock() +_callback_ids = itertools.count() +_callbacks: dict[int, TokenUsageCallback] = {} + + +def subscribe_token_usage(callback: TokenUsageCallback) -> Callable[[], None]: + callback_id = next(_callback_ids) + with _callback_lock: + _callbacks[callback_id] = callback + + def unsubscribe() -> None: + with _callback_lock: + _callbacks.pop(callback_id, None) + + return unsubscribe + + +def emit_token_usage_event(event: TokenUsageEvent) -> None: + with _callback_lock: + callbacks = tuple(_callbacks.values()) + + for callback in callbacks: + try: + callback(event) + except Exception: + logger.debug("Token usage event callback failed", exc_info=True) diff --git a/packages/data-designer-engine/src/data_designer/engine/observability.py b/packages/data-designer-engine/src/data_designer/engine/observability.py index a7a28c41b..174a387b9 100644 --- a/packages/data-designer-engine/src/data_designer/engine/observability.py +++ b/packages/data-designer-engine/src/data_designer/engine/observability.py @@ -4,12 +4,17 @@ from __future__ import annotations import contextvars +import itertools +import logging import math import time from collections.abc import Mapping from dataclasses import dataclass, field, fields, is_dataclass from enum import Enum -from typing import Literal, Protocol +from threading import Lock +from typing import Callable, Literal, Protocol + +logger = logging.getLogger(__name__) @dataclass(frozen=True) @@ -215,6 +220,36 @@ class RequestAdmissionEventSink(Protocol): def emit_request_event(self, event: RequestAdmissionEvent) -> None: ... +RequestAdmissionEventCallback = Callable[[RequestAdmissionEvent], None] + +_request_event_callback_lock = Lock() +_request_event_callback_ids = itertools.count() +_request_event_callbacks: dict[int, RequestAdmissionEventCallback] = {} + + +def subscribe_request_admission_events(callback: RequestAdmissionEventCallback) -> Callable[[], None]: + callback_id = next(_request_event_callback_ids) + with _request_event_callback_lock: + _request_event_callbacks[callback_id] = callback + + def unsubscribe() -> None: + with _request_event_callback_lock: + _request_event_callbacks.pop(callback_id, None) + + return unsubscribe + + +def emit_request_admission_event(event: RequestAdmissionEvent) -> None: + with _request_event_callback_lock: + callbacks = tuple(_request_event_callbacks.values()) + + for callback in callbacks: + try: + callback(event) + except Exception: + logger.debug("Request admission event callback failed", exc_info=True) + + class InMemoryAdmissionEventSink: """Small sink used by tests, diagnostics, and benchmark smoke runs.""" diff --git a/packages/data-designer-engine/src/data_designer/engine/progress/__init__.py b/packages/data-designer-engine/src/data_designer/engine/progress/__init__.py new file mode 100644 index 000000000..f1ea03ddb --- /dev/null +++ b/packages/data-designer-engine/src/data_designer/engine/progress/__init__.py @@ -0,0 +1,4 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations diff --git a/packages/data-designer-engine/src/data_designer/engine/progress/reporter.py b/packages/data-designer-engine/src/data_designer/engine/progress/reporter.py new file mode 100644 index 000000000..edb2f74a0 --- /dev/null +++ b/packages/data-designer-engine/src/data_designer/engine/progress/reporter.py @@ -0,0 +1,184 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import logging +import time +from collections.abc import Callable +from typing import TYPE_CHECKING + +from data_designer.engine.models.usage_events import TokenUsageEvent, subscribe_token_usage +from data_designer.engine.observability import RequestAdmissionEvent, subscribe_request_admission_events +from data_designer.engine.progress.tracker import ProgressTracker +from data_designer.logging import LOG_INDENT + +if TYPE_CHECKING: + from data_designer.engine.progress.terminal.throughput_panel import TerminalThroughputPanel + +logger = logging.getLogger(__name__) + +DEFAULT_REPORT_INTERVAL = 5.0 +DEFAULT_TTY_REPORT_INTERVAL = 0.75 +FEEDBACK_MARKER_EVENTS = frozenset({"request_rate_limited", "request_limit_decreased"}) + + +class AsyncProgressReporter: + """Consolidated progress reporter for async generation. + + Owns per-column ProgressTracker instances (in quiet mode) and emits + a single grouped log block at most once per ``report_interval`` seconds. + """ + + def __init__( + self, + trackers: dict[str, ProgressTracker], + *, + report_interval: float = DEFAULT_REPORT_INTERVAL, + progress_bar: TerminalThroughputPanel | None = None, + run_id: str | None = None, + ) -> None: + self._trackers = trackers + self._report_interval = report_interval + self._run_id = run_id + self._start_time = time.perf_counter() + self._last_report_time: float = self._start_time + self._last_bar_report_time: float = self._start_time + self._last_reported_total: int = -1 + self._bar = progress_bar + self._unsubscribe_token_usage: Callable[[], None] | None = None + self._unsubscribe_request_admission_events: Callable[[], None] | None = None + if self._bar is not None: + for col, tracker in trackers.items(): + self._bar.add_bar(col, col, tracker.total_records) + self._unsubscribe_token_usage = subscribe_token_usage(self._record_token_usage) + self._unsubscribe_request_admission_events = subscribe_request_admission_events( + self._record_request_admission_event + ) + + def log_start(self, num_row_groups: int) -> None: + cols = ", ".join(self._trackers) + total = sum(t.total_records for t in self._trackers.values()) + logger.info( + "⚡️ Async generation: %d column(s) (%s), %d tasks across %d row group(s)", + len(self._trackers), + cols, + total, + num_row_groups, + ) + + def record_success(self, column: str) -> None: + if tracker := self._trackers.get(column): + tracker.record_success() + self._maybe_report() + + def record_failure(self, column: str) -> None: + if tracker := self._trackers.get(column): + tracker.record_failure() + self._maybe_report() + + def record_skipped(self, column: str) -> None: + if tracker := self._trackers.get(column): + tracker.record_skipped() + self._maybe_report() + + def log_final(self) -> None: + try: + if self._bar is not None and self._bar.is_active: + self._update_bar(force=True) + else: + self._emit() + elapsed = time.perf_counter() - self._start_time + snapshots = [tracker.get_snapshot(elapsed) for tracker in self._trackers.values()] + total_ok = sum(snapshot[2] for snapshot in snapshots) + total_fail = sum(snapshot[3] for snapshot in snapshots) + total_skipped = sum(snapshot[4] for snapshot in snapshots) + skipped_suffix = f", {total_skipped} skipped" if total_skipped else "" + logger.info( + "✅ Async generation complete [%.1fs]: %d ok, %d failed%s across %d column(s)", + elapsed, + total_ok, + total_fail, + skipped_suffix, + len(self._trackers), + ) + finally: + self.close() + + def close(self) -> None: + if self._unsubscribe_token_usage is not None: + self._unsubscribe_token_usage() + self._unsubscribe_token_usage = None + if self._unsubscribe_request_admission_events is not None: + self._unsubscribe_request_admission_events() + self._unsubscribe_request_admission_events = None + + def _maybe_report(self) -> None: + now = time.perf_counter() + if self._bar is not None and self._bar.is_active: + if now - self._last_bar_report_time < DEFAULT_TTY_REPORT_INTERVAL: + return + self._last_bar_report_time = now + self._update_bar() + return + if now - self._last_report_time < self._report_interval: + return + self._last_report_time = now + self._emit() + + def _update_bar(self, *, force: bool = False) -> None: + elapsed = time.perf_counter() - self._start_time + updates: dict[str, tuple[int, int, int, int]] = {} + for col, tracker in self._trackers.items(): + completed, _total, success, failed, skipped, _pct, _rate, _emoji = tracker.get_snapshot(elapsed) + updates[col] = (completed, success, failed, skipped) + self._bar.update_many(updates, force=force) + + def _record_token_usage(self, event: TokenUsageEvent) -> None: + if self._bar is not None and self._matches_run(event.correlation): + self._bar.record_model_usage( + model_alias=event.model_alias, + model_name=event.model_name, + input_tokens=event.input_tokens, + output_tokens=event.output_tokens, + ) + + def _record_request_admission_event(self, event: RequestAdmissionEvent) -> None: + if ( + self._bar is not None + and event.event_kind in FEEDBACK_MARKER_EVENTS + and self._matches_run(event.captured_correlation) + ): + self._bar.record_feedback_signal(event_kind=event.event_kind) + + def _matches_run(self, correlation: object) -> bool: + if self._run_id is None: + return True + if correlation is None: + return False + if isinstance(correlation, dict): + return correlation.get("run_id") == self._run_id + return getattr(correlation, "run_id", None) == self._run_id + + def _emit(self) -> None: + current_total = sum(tracker.get_snapshot(0.0)[0] for tracker in self._trackers.values()) + if current_total == self._last_reported_total: + return + self._last_reported_total = current_total + + elapsed = time.perf_counter() - self._start_time + logger.info("📊 Progress [%.1fs]:", elapsed) + for col, tracker in self._trackers.items(): + completed, total_records, _success, _failed, skipped, pct, rate, emoji = tracker.get_snapshot(elapsed) + skipped_suffix = f", {skipped} skipped" if skipped else "" + logger.info( + "%s%s %s: %d/%d (%.0f%%) %.1f rec/s%s", + LOG_INDENT, + emoji, + col, + completed, + total_records, + pct, + rate, + skipped_suffix, + ) diff --git a/packages/data-designer-engine/src/data_designer/engine/progress/terminal/__init__.py b/packages/data-designer-engine/src/data_designer/engine/progress/terminal/__init__.py new file mode 100644 index 000000000..f1ea03ddb --- /dev/null +++ b/packages/data-designer-engine/src/data_designer/engine/progress/terminal/__init__.py @@ -0,0 +1,4 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations diff --git a/packages/data-designer-engine/src/data_designer/engine/progress/terminal/throughput_panel.py b/packages/data-designer-engine/src/data_designer/engine/progress/terminal/throughput_panel.py new file mode 100644 index 000000000..aba5d8fd8 --- /dev/null +++ b/packages/data-designer-engine/src/data_designer/engine/progress/terminal/throughput_panel.py @@ -0,0 +1,872 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import logging +import re +import shutil +import sys +import time +from dataclasses import dataclass, field +from threading import Lock +from typing import Sequence, TextIO + +import asciichartpy + +_ANSI_RE = re.compile(r"\033\[[0-9;?]*[a-zA-Z]") +_CONTROL_RE = re.compile(r"[\x00-\x1f\x7f-\x9f]") +_RESET = "\033[0m" +_BORDER = "\033[38;5;39m" +_TITLE = "\033[1;38;5;81m" +_MUTED = "\033[2;38;5;245m" +_FAILED = "\033[31m" +_OK = "\033[32m" +_TRACK = "\033[38;5;238m" +_FEEDBACK_MARKER = "\033[1;38;5;196m◆\033[0m" +_CURVE_COLORS = [ + asciichartpy.lightcyan, + asciichartpy.lightgreen, + asciichartpy.lightmagenta, + asciichartpy.lightyellow, + asciichartpy.lightblue, + asciichartpy.lightred, + asciichartpy.cyan, + asciichartpy.green, +] +_DEFAULT_PANEL_HEIGHT = 22 +_MIN_PANEL_HEIGHT = 9 +_MIN_TERMINAL_WIDTH = 30 +_MIN_REDRAW_INTERVAL_SECONDS = 0.75 +_RATE_SAMPLE_INTERVAL_SECONDS = 2.0 +_RATE_SMOOTHING_WINDOW = 3 +_MAX_RATE_SAMPLES = 7200 +_MAX_FEEDBACK_MARKERS = 512 +_RATE_FORMAT = "{:6.1f} " +_Y_AXIS_RESERVED = 12 +_CHART_LINE_COUNT = 9 +_MIN_CHART_LINE_COUNT = 3 +_MIN_LEGEND_LABEL_WIDTH = 8 +_MIN_MODEL_ALIAS_WIDTH = 10 +_MIN_MODEL_NAME_WIDTH = 10 +_RATE_COLUMN_WIDTH = 5 +_INPUT_TOKEN_RATE_WIDTH = 8 +_OUTPUT_TOKEN_RATE_WIDTH = 9 +_LEGEND_COLUMN_GAP = 2 +_MIN_PROGRESS_BAR_WIDTH = 6 +_PROGRESS_BAR_CHAR = "━" +_NOW_RATE_HEADER = "now rec/s" +_AVG_RATE_HEADER = "avg rec/s" + + +_ProgressUpdate = tuple[int, int, int, int] + + +def _visible_len(text: str) -> int: + return len(_ANSI_RE.sub("", text)) + + +def _fit_ansi(text: str, width: int) -> str: + visible = _visible_len(text) + if visible > width: + return _ANSI_RE.sub("", text)[:width] + return text + (" " * (width - visible)) + + +def _color(text: str, color: str) -> str: + return f"{color}{text}{_RESET}" + + +def _sanitize_label(label: str) -> str: + return _CONTROL_RE.sub("", _ANSI_RE.sub("", label)) + + +def _fit_plain(text: str, width: int) -> str: + clean = _sanitize_label(text) + return clean[:width].ljust(width) + + +def _average(values: Sequence[float]) -> float: + return sum(values) / len(values) if values else 0.0 + + +def _smooth_series(series: Sequence[float], window: int = _RATE_SMOOTHING_WINDOW) -> list[float]: + if window <= 1: + return list(series) + return [_average(series[max(0, i - window + 1) : i + 1]) for i in range(len(series))] + + +def _compress_series(series: Sequence[float], max_points: int) -> list[float]: + if max_points <= 0: + return [] + if len(series) <= max_points: + return list(series) or [0.0] + + compressed: list[float] = [] + count = len(series) + for bucket_index in range(max_points): + start = int(bucket_index * count / max_points) + end = int((bucket_index + 1) * count / max_points) + bucket = series[start : max(end, start + 1)] + compressed.append(_average(bucket)) + return compressed + + +def _expand_series(series: Sequence[float], point_count: int) -> list[float]: + if point_count <= 0: + return [] + if not series: + return [0.0] * point_count + if len(series) == 1: + return [series[0]] * point_count + + expanded: list[float] = [] + source_last_index = len(series) - 1 + target_last_index = max(1, point_count - 1) + for index in range(point_count): + position = index * source_last_index / target_last_index + left_index = int(position) + right_index = min(left_index + 1, source_last_index) + weight = position - left_index + expanded.append(series[left_index] * (1 - weight) + series[right_index] * weight) + return expanded + + +def _fit_series(series: Sequence[float], point_count: int) -> list[float]: + if len(series) > point_count: + return _compress_series(series, point_count) + return _expand_series(series, point_count) + + +def _visible_index_of_any(text: str, chars: str) -> int | None: + visible_index = 0 + index = 0 + while index < len(text): + if match := _ANSI_RE.match(text, index): + index = match.end() + continue + if text[index] in chars: + return visible_index + visible_index += 1 + index += 1 + return None + + +def _replace_visible_char(text: str, visible_index: int, replacement: str) -> str: + output: list[str] = [] + current_visible_index = 0 + index = 0 + replaced = False + while index < len(text): + if match := _ANSI_RE.match(text, index): + output.append(match.group()) + index = match.end() + continue + + if current_visible_index == visible_index: + output.append(replacement) + replaced = True + else: + output.append(text[index]) + + current_visible_index += 1 + index += 1 + + if not replaced and visible_index >= current_visible_index: + output.append(" " * (visible_index - current_visible_index)) + output.append(replacement) + return "".join(output) + + +@dataclass +class _BarState: + label: str + total: int + completed: int = 0 + success: int = 0 + failed: int = 0 + skipped: int = 0 + start_time: float = field(default_factory=time.perf_counter) + last_sample_time: float = field(default_factory=time.perf_counter) + last_completed: int = 0 + latest_rate: float = 0.0 + rates: list[float] = field(default_factory=lambda: [0.0]) + + def record_update( + self, + *, + completed: int, + success: int, + failed: int, + skipped: int, + now: float, + ) -> None: + bounded_completed = min(max(completed, 0), self.total) if self.total > 0 else max(completed, 0) + elapsed = now - self.last_sample_time + self.completed = bounded_completed + self.success = success + self.failed = failed + self.skipped = skipped + + should_sample = elapsed >= _RATE_SAMPLE_INTERVAL_SECONDS or bounded_completed >= self.total + if should_sample: + delta_completed = max(0, bounded_completed - self.last_completed) + sample_elapsed = max(elapsed, 0.001) + rate = delta_completed / sample_elapsed + self.rates.append(rate) + if len(self.rates) > _MAX_RATE_SAMPLES: + del self.rates[: len(self.rates) - _MAX_RATE_SAMPLES] + self.latest_rate = _average(self.rates[-_RATE_SMOOTHING_WINDOW:]) + self.last_completed = bounded_completed + self.last_sample_time = now + + def average_rate(self, now: float) -> float: + elapsed = max(now - self.start_time, 0.001) + return self.completed / elapsed if elapsed > 0 else 0.0 + + +@dataclass +class _ModelUsageState: + model_alias: str + model_name: str + start_time: float + request_count: int = 0 + input_tokens: int = 0 + output_tokens: int = 0 + + def record_usage(self, *, model_name: str, input_tokens: int, output_tokens: int) -> None: + self.model_name = model_name + self.request_count += 1 + self.input_tokens += max(0, input_tokens) + self.output_tokens += max(0, output_tokens) + + def rpm(self, now: float) -> float: + elapsed_minutes = max((now - self.start_time) / 60.0, 0.001) + return self.request_count / elapsed_minutes + + def input_token_rate(self, now: float) -> float: + elapsed = max(now - self.start_time, 0.001) + return self.input_tokens / elapsed if elapsed > 0 else 0.0 + + def output_token_rate(self, now: float) -> float: + elapsed = max(now - self.start_time, 0.001) + return self.output_tokens / elapsed if elapsed > 0 else 0.0 + + +@dataclass(frozen=True) +class _FeedbackMarker: + elapsed: float + value: float + event_kind: str + + +class TerminalThroughputPanel: + """ANSI throughput chart panel that sticks to the bottom of the terminal. + + Log messages (via standard ``logging``) are rendered above the panel + automatically. The panel redraws in-place after each update, tracks one + records-per-second curve per active generation column, and gives the chart + a stable height while the column and model tables grow to fit their rows. + + Usage:: + + with TerminalThroughputPanel() as bar: + bar.add_bar("col_a", "column 'a'", total=100) + for i in range(100): + bar.update("col_a", completed=i + 1, success=i + 1) + + Falls back to a no-op on non-TTY streams (CI, pipes, notebooks). + """ + + def __init__(self, stream: TextIO | None = None, *, panel_height: int = _DEFAULT_PANEL_HEIGHT) -> None: + self._stream = stream or sys.stderr + self._is_tty = hasattr(self._stream, "isatty") and self._stream.isatty() + self._bars: dict[str, _BarState] = {} + self._model_usage: dict[str, _ModelUsageState] = {} + self._feedback_markers: list[_FeedbackMarker] = [] + self._lock = Lock() + self._drawn_lines = 0 + self._active = False + self._wrapped_handlers: list[tuple[logging.StreamHandler, object]] = [] + self._panel_height = max(_MIN_PANEL_HEIGHT, panel_height) + self._start_time = time.perf_counter() + self._last_redraw_time: float = 0.0 + + @property + def is_active(self) -> bool: + return self._active + + @property + def drawn_lines(self) -> int: + return self._drawn_lines + + # -- context manager -- + + def __enter__(self) -> TerminalThroughputPanel: + if self._is_tty and shutil.get_terminal_size().columns >= _MIN_TERMINAL_WIDTH: + self._active = True + self._start_time = time.perf_counter() + self._last_redraw_time = 0.0 + self._wrap_handlers() + self._write("\033[?25l") # hide cursor + return self + + def __exit__(self, *args: object) -> None: + if self._active: + self._write("\033[?25h") # show cursor + self._unwrap_handlers() + self._active = False + self._drawn_lines = 0 + + # -- public API -- + + def add_bar(self, key: str, label: str, total: int) -> None: + with self._lock: + self._bars[key] = _BarState(label=_sanitize_label(label), total=total) + if self._active: + self._redraw(force=True) + + def update( + self, + key: str, + *, + completed: int, + success: int = 0, + failed: int = 0, + skipped: int = 0, + force: bool = False, + ) -> None: + with self._lock: + if bar := self._bars.get(key): + now = time.perf_counter() + bar.record_update( + completed=completed, + success=success, + failed=failed, + skipped=skipped, + now=now, + ) + if self._active: + self._redraw_if_due(now, force=force) + + def update_many(self, updates: dict[str, _ProgressUpdate], *, force: bool = False) -> None: + with self._lock: + now = time.perf_counter() + for key, update in updates.items(): + if bar := self._bars.get(key): + completed, success, failed, skipped = update + bar.record_update( + completed=completed, + success=success, + failed=failed, + skipped=skipped, + now=now, + ) + if self._active: + self._redraw_if_due(now, force=force) + + def record_model_usage( + self, + *, + model_alias: str, + model_name: str, + input_tokens: int, + output_tokens: int, + force: bool = False, + ) -> None: + with self._lock: + now = time.perf_counter() + alias = _sanitize_label(model_alias) or "(unknown)" + name = _sanitize_label(model_name) or "(unknown)" + if state := self._model_usage.get(alias): + state.record_usage(model_name=name, input_tokens=input_tokens, output_tokens=output_tokens) + else: + self._model_usage[alias] = _ModelUsageState( + model_alias=alias, + model_name=name, + start_time=self._start_time, + request_count=1, + input_tokens=max(0, input_tokens), + output_tokens=max(0, output_tokens), + ) + if self._active: + self._redraw_if_due(now, force=force) + + def record_feedback_signal(self, *, event_kind: str, force: bool = False) -> None: + with self._lock: + now = time.perf_counter() + self._feedback_markers.append( + _FeedbackMarker( + elapsed=max(now - self._start_time, 0.0), + value=max((bar.latest_rate for bar in self._bars.values()), default=0.0), + event_kind=_sanitize_label(event_kind), + ) + ) + if len(self._feedback_markers) > _MAX_FEEDBACK_MARKERS: + del self._feedback_markers[: len(self._feedback_markers) - _MAX_FEEDBACK_MARKERS] + if self._active: + self._redraw_if_due(now, force=force) + + def remove_bar(self, key: str) -> None: + with self._lock: + self._bars.pop(key, None) + if self._active: + self._redraw(force=True) + + # -- handler wrapping -- + + def _wrap_handlers(self) -> None: + """Wrap stderr logging handlers so log lines render above the bars.""" + root = logging.getLogger() + for handler in root.handlers: + if not isinstance(handler, logging.StreamHandler): + continue + if getattr(handler, "stream", None) is not self._stream: + continue + original_emit = handler.emit + + def _make_wrapper(orig: object) -> object: + def wrapped_emit(record: logging.LogRecord) -> None: + with self._lock: + self._clear_bars() + orig(record) # type: ignore[operator] + self._redraw(force=True) + + return wrapped_emit + + handler.emit = _make_wrapper(original_emit) # type: ignore[assignment] + self._wrapped_handlers.append((handler, original_emit)) + + def _unwrap_handlers(self) -> None: + for handler, original_emit in self._wrapped_handlers: + handler.emit = original_emit # type: ignore[assignment] + self._wrapped_handlers.clear() + + # -- drawing -- + + def _clear_bars(self) -> None: + """Clear drawn panel lines from the terminal. Caller must hold the lock.""" + if self._drawn_lines > 0: + for _ in range(self._drawn_lines): + self._write("\033[A\033[2K") + self._write("\r\033[2K") + self._drawn_lines = 0 + + def _redraw_if_due(self, now: float, *, force: bool = False) -> None: + if force or self._drawn_lines == 0 or now - self._last_redraw_time >= _MIN_REDRAW_INTERVAL_SECONDS: + self._redraw(force=True, now=now) + + def _redraw(self, *, force: bool = False, now: float | None = None) -> None: + """Redraw the chart panel. Caller must hold the lock.""" + if not force: + current_time = time.perf_counter() if now is None else now + if self._drawn_lines > 0 and current_time - self._last_redraw_time < _MIN_REDRAW_INTERVAL_SECONDS: + return + self._clear_bars() + if not self._bars: + return + lines = self._format_panel() + for line in lines: + self._write(line + "\n") + self._drawn_lines = len(lines) + self._last_redraw_time = time.perf_counter() if now is None else now + + def _format_panel(self) -> list[str]: + terminal_size = shutil.get_terminal_size() + panel_width = max(4, terminal_size.columns - 1) + panel_height = min(self._panel_height, max(_MIN_PANEL_HEIGHT, terminal_size.lines - 1)) + inner_width = panel_width - 2 + + body_capacity = max(1, panel_height - 4) + chart_line_count = min(_CHART_LINE_COUNT, max(_MIN_CHART_LINE_COUNT, body_capacity - 1)) + minimum_legend_capacity = max(1, body_capacity - chart_line_count) + chart_height = chart_line_count - 1 + + now = time.perf_counter() + bars = list(self._bars.values()) + model_usage = list(self._model_usage.values()) + chart_lines = self._format_chart_lines(bars, inner_width, chart_height, now) + legend_lines = self._format_legend_lines(bars, model_usage, now, minimum_legend_capacity, inner_width) + + lines = [ + self._border("╭", "─", "╮", panel_width), + self._panel_line(self._format_header(bars, now), inner_width), + ] + lines.extend(self._panel_line(line, inner_width) for line in chart_lines) + lines.append(self._border("├", "─", "┤", panel_width)) + lines.extend(self._panel_line(line, inner_width) for line in legend_lines) + lines.append(self._border("╰", "─", "╯", panel_width)) + return lines + + def _format_header(self, bars: list[_BarState], now: float) -> str: + elapsed = max(now - self._start_time, 0.0) + completed = sum(bar.completed for bar in bars) + total = sum(bar.total for bar in bars) + latest_rate = sum(bar.latest_rate for bar in bars) + failed = sum(bar.failed for bar in bars) + skipped = sum(bar.skipped for bar in bars) + failed_text = _color(f"{failed} failed", _FAILED) if failed else _color("0 failed", _OK) + skipped_text = f" | {skipped} skipped" if skipped else "" + return ( + f"{_TITLE}Throughput{_RESET} " + f"{_MUTED}rec/s | {elapsed:5.1f}s | {completed}/{total} | " + f"now {latest_rate:6.1f}{skipped_text} | {_RESET}{failed_text}" + ) + + def _format_chart_lines( + self, + bars: list[_BarState], + inner_width: int, + chart_height: int, + now: float, + ) -> list[str]: + max_points = max(2, inner_width - _Y_AXIS_RESERVED) + series = [_fit_series(_smooth_series(bar.rates), max_points) for bar in bars] + max_rate = max((max(points) for points in series if points), default=0.0) + chart_max = max(1.0, max_rate) + chart = asciichartpy.plot( + series, + { + "height": chart_height, + "min": 0.0, + "max": chart_max, + "format": _RATE_FORMAT, + "colors": [_CURVE_COLORS[i % len(_CURVE_COLORS)] for i in range(len(series))], + }, + ) + lines = chart.splitlines() + while len(lines) < chart_height + 1: + lines.append("") + return self._overlay_feedback_markers( + lines[: chart_height + 1], + current_elapsed=max(now - self._start_time, 0.001), + chart_max=chart_max, + point_count=max_points, + chart_height=chart_height, + ) + + def _overlay_feedback_markers( + self, + lines: list[str], + *, + current_elapsed: float, + chart_max: float, + point_count: int, + chart_height: int, + ) -> list[str]: + if not self._feedback_markers or not lines: + return lines + + marked_lines = list(lines) + plot_column_count = max(1, point_count - 1) + for marker in self._feedback_markers: + marker_elapsed = min(max(marker.elapsed, 0.0), current_elapsed) + x_index = int(round(marker_elapsed / current_elapsed * (plot_column_count - 1))) + y_value = min(max(marker.value, 0.0), chart_max) + row_index = int(round((chart_max - y_value) / chart_max * chart_height)) + row_index = min(max(row_index, 0), len(marked_lines) - 1) + axis_index = _visible_index_of_any(marked_lines[row_index], "┼┤") + if axis_index is None: + continue + marked_lines[row_index] = _replace_visible_char( + marked_lines[row_index], + axis_index + 1 + x_index, + _FEEDBACK_MARKER, + ) + return marked_lines + + def _format_legend_lines( + self, + bars: list[_BarState], + model_usage: list[_ModelUsageState], + now: float, + minimum_capacity: int, + inner_width: int, + ) -> list[str]: + lines = self._format_column_table_lines(bars, now, inner_width) + if model_usage: + lines.append("") + lines.extend(self._format_model_table_lines(model_usage, now, inner_width)) + + while len(lines) < minimum_capacity: + lines.append("") + return lines + + def _format_column_table_lines( + self, + bars: list[_BarState], + now: float, + inner_width: int, + ) -> list[str]: + lines: list[str] = [] + + include_status = any(bar.failed or bar.skipped for bar in bars) + label_width, done_width, rate_width, status_width, progress_width = self._column_table_widths( + bars, + now, + include_status=include_status, + inner_width=inner_width, + ) + lines.append( + self._format_legend_table_line( + marker="", + label="column", + done="done", + now_value=_NOW_RATE_HEADER, + avg_value=_AVG_RATE_HEADER, + status="status" if include_status else None, + label_width=label_width, + done_width=done_width, + rate_width=rate_width, + status_width=status_width, + progress_bar="", + progress_width=progress_width, + ) + ) + + for index, bar in enumerate(bars): + color = _CURVE_COLORS[index % len(_CURVE_COLORS)] + lines.append( + self._format_legend_table_line( + marker=_color("●", color), + label=bar.label, + done=self._format_done(bar), + now_value=f"{bar.latest_rate:.1f}", + avg_value=f"{bar.average_rate(now):.1f}", + status=self._format_status(bar) if include_status else None, + label_width=label_width, + done_width=done_width, + rate_width=rate_width, + status_width=status_width, + progress_bar=self._format_progress_bar(bar, progress_width, color), + progress_width=progress_width, + ) + ) + + return lines + + def _column_table_widths( + self, + bars: list[_BarState], + now: float, + *, + include_status: bool, + inner_width: int, + ) -> tuple[int, int, int, int, int]: + done_width = max(len("done"), *(len(self._format_done(bar)) for bar in bars)) + rate_width = max( + len(_NOW_RATE_HEADER), + len(_AVG_RATE_HEADER), + _RATE_COLUMN_WIDTH, + *(len(f"{value:.1f}") for bar in bars for value in (bar.latest_rate, bar.average_rate(now))), + ) + status_width = 0 + if include_status: + status_width = max(len("status"), *(len(self._format_status(bar)) for bar in bars)) + + separator_count = 4 + int(include_status) + fixed_width_without_label_or_progress = ( + 2 + (separator_count * _LEGEND_COLUMN_GAP) + done_width + (rate_width * 2) + status_width + ) + content_label_width = max(len("column"), *(len(_sanitize_label(bar.label)) for bar in bars)) + desired_label_width = max(_MIN_LEGEND_LABEL_WIDTH, content_label_width) + available_width = inner_width - fixed_width_without_label_or_progress + + if available_width >= desired_label_width + _MIN_PROGRESS_BAR_WIDTH: + label_width = desired_label_width + progress_width = available_width - label_width + elif available_width >= _MIN_LEGEND_LABEL_WIDTH + _MIN_PROGRESS_BAR_WIDTH: + progress_width = _MIN_PROGRESS_BAR_WIDTH + label_width = available_width - progress_width + else: + label_width = max(_MIN_LEGEND_LABEL_WIDTH, min(desired_label_width, max(0, available_width))) + progress_width = max(0, available_width - label_width) + + return label_width, done_width, rate_width, status_width, progress_width + + def _format_model_table_lines( + self, + model_usage: list[_ModelUsageState], + now: float, + inner_width: int, + ) -> list[str]: + lines: list[str] = [] + + alias_width, model_width, rpm_width, input_width, output_width = self._model_table_widths( + model_usage, + now, + inner_width, + ) + lines.append( + self._format_model_table_line( + model_alias="model alias", + model_name="model name", + rpm="rpm", + input_token_rate="in tok/s", + output_token_rate="out tok/s", + alias_width=alias_width, + model_width=model_width, + rpm_width=rpm_width, + input_width=input_width, + output_width=output_width, + header=True, + ) + ) + + for state in model_usage: + lines.append( + self._format_model_table_line( + model_alias=state.model_alias, + model_name=state.model_name, + rpm=f"{state.rpm(now):.1f}", + input_token_rate=f"{state.input_token_rate(now):.1f}", + output_token_rate=f"{state.output_token_rate(now):.1f}", + alias_width=alias_width, + model_width=model_width, + rpm_width=rpm_width, + input_width=input_width, + output_width=output_width, + header=False, + ) + ) + + return lines + + def _model_table_widths( + self, + model_usage: list[_ModelUsageState], + now: float, + inner_width: int, + ) -> tuple[int, int, int, int, int]: + rpm_width = max( + len("rpm"), + _RATE_COLUMN_WIDTH, + *(len(f"{state.rpm(now):.1f}") for state in model_usage), + ) + input_width = max( + len("in tok/s"), + _INPUT_TOKEN_RATE_WIDTH, + *(len(f"{state.input_token_rate(now):.1f}") for state in model_usage), + ) + output_width = max( + len("out tok/s"), + _OUTPUT_TOKEN_RATE_WIDTH, + *(len(f"{state.output_token_rate(now):.1f}") for state in model_usage), + ) + + fixed_width_without_text = 2 + (4 * _LEGEND_COLUMN_GAP) + rpm_width + input_width + output_width + available_text_width = inner_width - fixed_width_without_text + desired_alias_width = max( + _MIN_MODEL_ALIAS_WIDTH, + len("model alias"), + *(len(state.model_alias) for state in model_usage), + ) + desired_model_width = max( + _MIN_MODEL_NAME_WIDTH, + len("model name"), + *(len(state.model_name) for state in model_usage), + ) + + if available_text_width >= desired_alias_width + desired_model_width: + alias_width = desired_alias_width + model_width = desired_model_width + elif available_text_width >= _MIN_MODEL_ALIAS_WIDTH + _MIN_MODEL_NAME_WIDTH: + alias_width = min(desired_alias_width, max(_MIN_MODEL_ALIAS_WIDTH, available_text_width // 2)) + model_width = available_text_width - alias_width + else: + alias_width = _MIN_MODEL_ALIAS_WIDTH + model_width = _MIN_MODEL_NAME_WIDTH + + return alias_width, model_width, rpm_width, input_width, output_width + + def _format_progress_bar(self, bar: _BarState, width: int, color: str) -> str: + if width <= 0: + return "" + + if bar.total <= 0: + fraction = 1.0 + else: + fraction = min(max(bar.completed / bar.total, 0.0), 1.0) + + filled_width = min(width, int(round(width * fraction))) + if bar.completed > 0: + filled_width = max(1, filled_width) + empty_width = width - filled_width + return f"{color}{_PROGRESS_BAR_CHAR * filled_width}{_TRACK}{_PROGRESS_BAR_CHAR * empty_width}{_RESET}" + + def _format_legend_table_line( + self, + *, + marker: str, + label: str, + done: str, + now_value: str, + avg_value: str, + status: str | None, + label_width: int, + done_width: int, + rate_width: int, + status_width: int, + progress_bar: str, + progress_width: int, + ) -> str: + marker_text = f"{marker} " if marker else " " + gap = " " * _LEGEND_COLUMN_GAP + line = ( + f"{marker_text}{_fit_plain(label, label_width)}{gap}{now_value:>{rate_width}}{gap}{avg_value:>{rate_width}}" + ) + if status is not None: + line = f"{line}{gap}{status:>{status_width}}" + line = f"{line}{gap}{done:>{done_width}}" + if progress_width > 0: + line = f"{line}{gap}{_fit_ansi(progress_bar, progress_width)}" + if marker: + return line + return f"{_MUTED}{line}{_RESET}" + + def _format_model_table_line( + self, + *, + model_alias: str, + model_name: str, + rpm: str, + input_token_rate: str, + output_token_rate: str, + alias_width: int, + model_width: int, + rpm_width: int, + input_width: int, + output_width: int, + header: bool, + ) -> str: + gap = " " * _LEGEND_COLUMN_GAP + line = ( + f" {_fit_plain(model_alias, alias_width)}{gap}{_fit_plain(model_name, model_width)}" + f"{gap}{rpm:>{rpm_width}}{gap}{input_token_rate:>{input_width}}" + f"{gap}{output_token_rate:>{output_width}}" + ) + if header: + return f"{_MUTED}{line}{_RESET}" + return line + + def _format_done(self, bar: _BarState) -> str: + pct = (bar.completed / bar.total * 100) if bar.total > 0 else 100.0 + return f"{bar.completed}/{bar.total} {pct:3.0f}%" + + def _format_status(self, bar: _BarState) -> str: + parts: list[str] = [] + if bar.failed: + parts.append(f"{bar.failed} failed") + if bar.skipped: + parts.append(f"{bar.skipped} skipped") + return ", ".join(parts) if parts else "ok" + + def _panel_line(self, text: str, inner_width: int) -> str: + return f"{_BORDER}│{_RESET}{_fit_ansi(text, inner_width)}{_BORDER}│{_RESET}" + + def _border(self, left: str, fill: str, right: str, width: int) -> str: + return f"{_BORDER}{left}{fill * (width - 2)}{right}{_RESET}" + + def _write(self, text: str) -> None: + self._stream.write(text) + self._stream.flush() diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/progress_tracker.py b/packages/data-designer-engine/src/data_designer/engine/progress/tracker.py similarity index 91% rename from packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/progress_tracker.py rename to packages/data-designer-engine/src/data_designer/engine/progress/tracker.py index 73afa2e26..d8e1665d0 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/progress_tracker.py +++ b/packages/data-designer-engine/src/data_designer/engine/progress/tracker.py @@ -11,7 +11,7 @@ from data_designer.logging import LOG_INDENT, RandomEmoji if TYPE_CHECKING: - from data_designer.engine.dataset_builders.utils.sticky_progress_bar import StickyProgressBar + from data_designer.engine.progress.terminal.throughput_panel import TerminalThroughputPanel logger = logging.getLogger(__name__) @@ -42,7 +42,7 @@ def __init__( log_interval_percent: int = 10, *, quiet: bool = False, - progress_bar: StickyProgressBar | None = None, + progress_bar: TerminalThroughputPanel | None = None, progress_bar_key: str | None = None, ): self.total_records = total_records @@ -103,8 +103,16 @@ def get_snapshot(self, elapsed: float | None = None) -> tuple[int, int, int, int def log_final(self) -> None: """Log final progress summary.""" with self.lock: - if self._bar is not None: - self._bar.remove_bar(self._bar_key) + if self._bar is not None and self._bar.is_active: + self._bar.update( + self._bar_key, + completed=self.completed, + success=self.success, + failed=self.failed, + skipped=self.skipped, + force=True, + ) + return if self.completed > 0: self._log_progress_unlocked() @@ -143,6 +151,7 @@ def _log_progress_unlocked(self) -> None: completed=self.completed, success=self.success, failed=self.failed, + skipped=self.skipped, ) return diff --git a/packages/data-designer-engine/tests/engine/column_generators/generators/test_custom.py b/packages/data-designer-engine/tests/engine/column_generators/generators/test_custom.py index af7bc7109..2f3676a52 100644 --- a/packages/data-designer-engine/tests/engine/column_generators/generators/test_custom.py +++ b/packages/data-designer-engine/tests/engine/column_generators/generators/test_custom.py @@ -26,6 +26,7 @@ _compute_bridge_timeout, ) from data_designer.engine.column_generators.utils.errors import CustomColumnGenerationError +from data_designer.engine.context import current_generation_column, current_run_cancel_event from data_designer.engine.models.clients.errors import SyncClientUnavailableError from data_designer.engine.models.errors import RETRYABLE_MODEL_ERRORS, ModelTimeoutError from data_designer.engine.resources.resource_provider import ResourceProvider @@ -570,6 +571,63 @@ async def fake_agenerate(*args: Any, **kwargs: Any) -> tuple: engine_thread.join(timeout=5) +def test_async_bridge_preserves_generation_column_context() -> None: + """Bridged sync custom generators should still attribute model usage to their column.""" + facade = Mock() + facade.generate.side_effect = SyncClientUnavailableError( + "Sync methods are not available on an async-mode HttpModelClient." + ) + facade.request_timeout = 60.0 + + async def fake_agenerate(*args: Any, **kwargs: Any) -> tuple: + assert current_generation_column.get() == "intent_label" + return ("async_result", list(args), kwargs) + + facade.agenerate = fake_agenerate + proxy = _AsyncBridgedModelFacade(facade) + + engine_loop = asyncio.new_event_loop() + engine_thread = threading.Thread(target=engine_loop.run_forever, daemon=True) + engine_thread.start() + column_token = current_generation_column.set("intent_label") + + try: + with patch( + "data_designer.engine.dataset_builders.utils.async_concurrency.ensure_async_engine_loop", + return_value=engine_loop, + ): + result = proxy.generate("hello", parser=str) + finally: + current_generation_column.reset(column_token) + engine_loop.call_soon_threadsafe(engine_loop.stop) + engine_thread.join(timeout=5) + + assert result == ("async_result", ["hello"], {"parser": str}) + + +def test_async_bridge_obeys_run_cancellation_before_scheduling() -> None: + """Cancelled runs should not schedule new async model calls from worker threads.""" + facade = Mock() + facade.generate.side_effect = SyncClientUnavailableError( + "Sync methods are not available on an async-mode HttpModelClient." + ) + facade.request_timeout = 60.0 + facade.agenerate = Mock() + proxy = _AsyncBridgedModelFacade(facade) + + cancel_event = threading.Event() + cancel_event.set() + cancel_token = current_run_cancel_event.set(cancel_event) + + try: + with pytest.raises(asyncio.CancelledError): + proxy.generate("hello", parser=str) + finally: + current_run_cancel_event.reset(cancel_token) + + facade.agenerate.assert_not_called() + + def test_async_bridge_non_client_mode_errors_propagate() -> None: """Only SyncClientUnavailableError triggers bridging; other errors propagate.""" # ValueError - different type entirely diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/test_async_builder_integration.py b/packages/data-designer-engine/tests/engine/dataset_builders/test_async_builder_integration.py index f01dc1d91..3d0e4f79f 100644 --- a/packages/data-designer-engine/tests/engine/dataset_builders/test_async_builder_integration.py +++ b/packages/data-designer-engine/tests/engine/dataset_builders/test_async_builder_integration.py @@ -205,7 +205,7 @@ def __init__(self, **kwargs: object) -> None: model_registry.request_admission = request_admission provider = SimpleNamespace( model_registry=model_registry, - run_config=SimpleNamespace(progress_interval=5.0, progress_bar=False), + run_config=SimpleNamespace(progress_interval=5.0, display_tui=False), ) processor_runner = MagicMock() processor_runner.has_processors_for.return_value = False diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/test_dataset_builder.py b/packages/data-designer-engine/tests/engine/dataset_builders/test_dataset_builder.py index 5aea97420..deb5d6453 100644 --- a/packages/data-designer-engine/tests/engine/dataset_builders/test_dataset_builder.py +++ b/packages/data-designer-engine/tests/engine/dataset_builders/test_dataset_builder.py @@ -3,6 +3,7 @@ from __future__ import annotations +import concurrent.futures import logging from typing import TYPE_CHECKING from unittest.mock import Mock, patch @@ -720,6 +721,29 @@ def mock_run_coroutine_threadsafe(coro, loop): buffer_manager.free_row_group.assert_not_called() +def test_await_async_scheduler_result_cancels_scheduler_on_keyboard_interrupt() -> None: + class MockFuture: + def __init__(self) -> None: + self.result_calls = 0 + self.cancel = Mock() + + def result(self) -> None: + self.result_calls += 1 + if self.result_calls == 1: + raise KeyboardInterrupt + raise concurrent.futures.CancelledError + + scheduler = Mock() + future = MockFuture() + + with pytest.raises(KeyboardInterrupt): + builder_mod._await_async_scheduler_result(future, scheduler) + + scheduler.request_cancel.assert_called_once_with() + future.cancel.assert_called_once_with() + assert future.result_calls == 2 + + def test_reset_run_state_clears_per_run_signals(stub_resource_provider, stub_test_config_builder) -> None: """``_reset_run_state`` must clear all per-run state so reused builders don't leak.""" builder = DatasetBuilder( diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_sticky_progress_bar.py b/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_sticky_progress_bar.py deleted file mode 100644 index d155be394..000000000 --- a/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_sticky_progress_bar.py +++ /dev/null @@ -1,241 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -from __future__ import annotations - -import io -import logging -import os -import re -import shutil -from unittest.mock import patch - -import pytest - -from data_designer.engine.dataset_builders.utils.async_progress_reporter import AsyncProgressReporter -from data_designer.engine.dataset_builders.utils.progress_tracker import ProgressTracker -from data_designer.engine.dataset_builders.utils.sticky_progress_bar import ( - StickyProgressBar, -) - -CURSOR_UP_CLEAR = "\033[A\033[2K" -HIDE_CURSOR = "\033[?25l" -SHOW_CURSOR = "\033[?25h" -_ALL_ANSI_RE = re.compile(r"\033\[[0-9;?]*[a-zA-Z]") - - -class FakeTTY(io.StringIO): - """StringIO that reports itself as a TTY so StickyProgressBar activates.""" - - def isatty(self) -> bool: - return True - - -@pytest.fixture -def tty_stream() -> FakeTTY: - return FakeTTY() - - -def test_no_output_when_not_tty() -> None: - stream = io.StringIO() - with StickyProgressBar(stream=stream) as bar: - bar.add_bar("a", "col_a", 10) - bar.update("a", completed=5, success=5) - assert stream.getvalue() == "" - - -def test_hides_and_shows_cursor(tty_stream: FakeTTY) -> None: - with StickyProgressBar(stream=tty_stream): - pass - output = tty_stream.getvalue() - assert output.startswith(HIDE_CURSOR) - assert output.endswith(SHOW_CURSOR) - - -def test_drawn_lines_tracks_add_and_remove(tty_stream: FakeTTY) -> None: - with StickyProgressBar(stream=tty_stream) as bar: - bar.add_bar("a", "col_a", 10) - bar.add_bar("b", "col_b", 10) - bar.add_bar("c", "col_c", 10) - assert bar.drawn_lines == 3 - - bar.remove_bar("a") - assert bar.drawn_lines == 2 - - bar.add_bar("d", "col_d", 10) - assert bar.drawn_lines == 3 - - bar.update("b", completed=5, success=5) - assert bar.drawn_lines == 3 - - bar.remove_bar("b") - bar.remove_bar("c") - bar.remove_bar("d") - assert bar.drawn_lines == 0 - - -def test_drawn_lines_stable_across_many_updates(tty_stream: FakeTTY) -> None: - with StickyProgressBar(stream=tty_stream) as bar: - bar.add_bar("a", "col_a", 100) - bar.add_bar("b", "col_b", 100) - bar.add_bar("c", "col_c", 100) - for i in range(50): - bar.update("a", completed=i, success=i) - bar.update("b", completed=i, success=i) - bar.update("c", completed=i, success=i) - - snapshot = tty_stream.getvalue() - bar.update("a", completed=50, success=50) - assert tty_stream.getvalue()[len(snapshot) :].count(CURSOR_UP_CLEAR) == 3 - - -def test_log_interleaving_preserves_drawn_lines(tty_stream: FakeTTY) -> None: - root_logger = logging.getLogger() - handler = logging.StreamHandler(tty_stream) - handler.setFormatter(logging.Formatter("%(message)s")) - root_logger.addHandler(handler) - - try: - with StickyProgressBar(stream=tty_stream) as bar: - bar.add_bar("x", "col_x", 100) - bar.add_bar("y", "col_y", 100) - bar.add_bar("z", "col_z", 100) - - for i in range(20): - bar.update("x", completed=i, success=i) - root_logger.info("log at step %d", i) - bar.update("y", completed=i, success=i) - bar.update("z", completed=i, success=i) - - snapshot = tty_stream.getvalue() - bar.update("x", completed=20, success=20) - assert tty_stream.getvalue()[len(snapshot) :].count(CURSOR_UP_CLEAR) == 3 - finally: - root_logger.removeHandler(handler) - - -def test_wrapping_counts_physical_lines(tty_stream: FakeTTY) -> None: - narrow = os.terminal_size((40, 24)) - with patch.object(shutil, "get_terminal_size", return_value=narrow): - with StickyProgressBar(stream=tty_stream) as bar: - bar.add_bar("a", "col_a", 100) - bar.add_bar("b", "col_b", 100) - - original_format = bar._format_bar - - def oversized_format(b: object, width: int, label_width: int) -> str: - line = original_format(b, width, label_width) - return line + "X" * 20 - - with patch.object(bar, "_format_bar", side_effect=oversized_format): - bar.update("a", completed=5, success=5) - - snapshot = tty_stream.getvalue() - bar.update("b", completed=1, success=1) - assert tty_stream.getvalue()[len(snapshot) :].count(CURSOR_UP_CLEAR) > 2 - - -def test_wrapping_stable_across_updates(tty_stream: FakeTTY) -> None: - narrow = os.terminal_size((40, 24)) - with patch.object(shutil, "get_terminal_size", return_value=narrow): - with StickyProgressBar(stream=tty_stream) as bar: - bar.add_bar("a", "col_a", 100) - bar.add_bar("b", "col_b", 100) - - snapshot = tty_stream.getvalue() - bar.update("a", completed=0, success=0) - initial_clears = tty_stream.getvalue()[len(snapshot) :].count(CURSOR_UP_CLEAR) - - for i in range(1, 21): - bar.update("a", completed=i, success=i) - bar.update("b", completed=i, success=i) - - snapshot = tty_stream.getvalue() - bar.update("a", completed=21, success=21) - assert tty_stream.getvalue()[len(snapshot) :].count(CURSOR_UP_CLEAR) == initial_clears - - -def test_narrow_terminal_graceful_degradation(tty_stream: FakeTTY) -> None: - narrow = os.terminal_size((30, 24)) - with patch.object(shutil, "get_terminal_size", return_value=narrow): - with StickyProgressBar(stream=tty_stream) as bar: - bar.add_bar("a", "column 'verification_1'", 300) - bar.update("a", completed=50, success=50) - - snapshot = tty_stream.getvalue() - bar.update("a", completed=51, success=51) - assert tty_stream.getvalue()[len(snapshot) :].count(CURSOR_UP_CLEAR) == 1 - - output = tty_stream.getvalue() - clean = _ALL_ANSI_RE.sub("", output).replace("\r", "") - for line in clean.split("\n"): - assert len(line) <= 29 - - -def test_update_many_single_redraw(tty_stream: FakeTTY) -> None: - with StickyProgressBar(stream=tty_stream) as bar: - bar.add_bar("a", "col_a", 100) - bar.add_bar("b", "col_b", 100) - before = tty_stream.getvalue() - - bar.update_many({"a": (10, 10, 0), "b": (20, 20, 0)}) - after = tty_stream.getvalue() - - new_output = after[len(before) :] - assert new_output.count(CURSOR_UP_CLEAR) == 2 - - clean = _ALL_ANSI_RE.sub("", after) - assert "10/100" in clean - assert "20/100" in clean - - -def test_update_many_ignores_unknown_keys(tty_stream: FakeTTY) -> None: - with StickyProgressBar(stream=tty_stream) as bar: - bar.add_bar("a", "col_a", 100) - bar.update_many({"a": (10, 10, 0), "unknown": (5, 5, 0)}) - - clean = _ALL_ANSI_RE.sub("", tty_stream.getvalue()) - assert "10/100" in clean - assert "unknown" not in clean - - snapshot = tty_stream.getvalue() - bar.update("a", completed=11, success=11) - assert tty_stream.getvalue()[len(snapshot) :].count(CURSOR_UP_CLEAR) == 1 - - -def test_reporter_updates_and_logs_keep_drawn_lines_in_sync(tty_stream: FakeTTY) -> None: - root_logger = logging.getLogger() - handler = logging.StreamHandler(tty_stream) - handler.setFormatter(logging.Formatter("%(message)s")) - root_logger.addHandler(handler) - - try: - bar = StickyProgressBar(stream=tty_stream) - trackers = { - "col_a": ProgressTracker(total_records=100, label="column 'a'", quiet=True), - "col_b": ProgressTracker(total_records=100, label="column 'b'", quiet=True), - "col_c": ProgressTracker(total_records=100, label="column 'c'", quiet=True), - } - - with bar: - reporter = AsyncProgressReporter(trackers, report_interval=0.1, progress_bar=bar) - reporter.log_start(num_row_groups=1) - - snapshot = tty_stream.getvalue() - reporter.record_success("col_a") - assert tty_stream.getvalue()[len(snapshot) :].count(CURSOR_UP_CLEAR) == 3 - - for i in range(49): - if i % 10 == 0: - root_logger.info("Processing batch %d", i) - reporter.record_success("col_b") - reporter.record_success("col_c") - - snapshot = tty_stream.getvalue() - reporter.record_success("col_a") - assert tty_stream.getvalue()[len(snapshot) :].count(CURSOR_UP_CLEAR) == 3 - - reporter.log_final() - assert bar.drawn_lines == 0 - finally: - root_logger.removeHandler(handler) diff --git a/packages/data-designer-engine/tests/engine/models/request_admission/test_controller.py b/packages/data-designer-engine/tests/engine/models/request_admission/test_controller.py index af77f8c40..bd0370704 100644 --- a/packages/data-designer-engine/tests/engine/models/request_admission/test_controller.py +++ b/packages/data-designer-engine/tests/engine/models/request_admission/test_controller.py @@ -25,7 +25,7 @@ RequestGroupSpec, RequestResourceKey, ) -from data_designer.engine.observability import InMemoryAdmissionEventSink +from data_designer.engine.observability import InMemoryAdmissionEventSink, subscribe_request_admission_events def _item(domain: RequestDomain = RequestDomain.CHAT, timeout: float | None = None) -> RequestAdmissionItem: @@ -141,6 +141,30 @@ def test_request_admission_rate_limit_decreases_and_sets_cooldown() -> None: assert snapshot.cooldown_remaining_seconds > 0 +def test_request_admission_broadcasts_global_feedback_events() -> None: + events = [] + unsubscribe = subscribe_request_admission_events(events.append) + try: + controller = _controller( + cap=4, + config=RequestAdmissionConfig( + multiplicative_decrease_factor=0.5, + cooldown_seconds=10, + ), + ) + item = _item() + lease = controller.try_acquire(item) + assert isinstance(lease, RequestAdmissionLease) + + controller.release(lease, RequestReleaseOutcome(kind="rate_limited")) + finally: + unsubscribe() + + event_kinds = {event.event_kind for event in events} + assert "request_rate_limited" in event_kinds + assert "request_limit_decreased" in event_kinds + + def test_request_admission_rate_limit_burst_decreases_once_per_cascade() -> None: controller = _controller( cap=8, diff --git a/packages/data-designer-engine/tests/engine/models/test_facade.py b/packages/data-designer-engine/tests/engine/models/test_facade.py index 0be33bd02..abda8aeae 100644 --- a/packages/data-designer-engine/tests/engine/models/test_facade.py +++ b/packages/data-designer-engine/tests/engine/models/test_facade.py @@ -8,6 +8,7 @@ import pytest +from data_designer.engine.context import current_generation_column from data_designer.engine.mcp.errors import MCPConfigurationError, MCPToolError from data_designer.engine.models.clients.errors import ProviderError, ProviderErrorKind from data_designer.engine.models.clients.types import ( @@ -28,6 +29,7 @@ from data_designer.engine.models.facade import ModelFacade from data_designer.engine.models.parsers.errors import ParserException from data_designer.engine.models.usage import TokenCountSource +from data_designer.engine.models.usage_events import TokenUsageEvent, subscribe_token_usage from data_designer.engine.models.utils import ChatMessage from data_designer.engine.testing import StubMCPFacade, StubMCPRegistry, make_stub_completion_response @@ -344,6 +346,53 @@ def test_completion_tracks_reasoning_tokens_without_changing_output_tokens( assert token_usage.total_tokens == 18 +def test_completion_emits_token_usage_event( + stub_model_facade: ModelFacade, + stub_model_client: MagicMock, +) -> None: + events: list[TokenUsageEvent] = [] + unsubscribe = subscribe_token_usage(events.append) + stub_model_client.completion.return_value = ChatCompletionResponse( + message=AssistantMessage(content="ok"), + usage=Usage(input_tokens=10, output_tokens=8), + ) + token = current_generation_column.set("intent_label") + + try: + stub_model_facade.completion([ChatMessage.as_user("hi")]) + finally: + current_generation_column.reset(token) + unsubscribe() + + assert len(events) == 1 + assert events[0].model_alias == stub_model_facade.model_alias + assert events[0].model_name == stub_model_facade.model_name + assert events[0].input_tokens == 10 + assert events[0].output_tokens == 8 + assert events[0].column == "intent_label" + + +def test_completion_emits_token_usage_event_when_only_output_tokens_are_reported( + stub_model_facade: ModelFacade, + stub_model_client: MagicMock, +) -> None: + events: list[TokenUsageEvent] = [] + unsubscribe = subscribe_token_usage(events.append) + stub_model_client.completion.return_value = ChatCompletionResponse( + message=AssistantMessage(content="ok"), + usage=Usage(output_tokens=8), + ) + + try: + stub_model_facade.completion([ChatMessage.as_user("hi")]) + finally: + unsubscribe() + + assert len(events) == 1 + assert events[0].input_tokens == 0 + assert events[0].output_tokens == 8 + + def test_consolidate_kwargs(stub_model_configs: list[Any], stub_model_facade: ModelFacade) -> None: # Model config generate kwargs are used as base, and purpose is removed. # When telemetry is enabled (default), X-Title is injected. diff --git a/packages/data-designer-engine/tests/engine/progress/terminal/test_throughput_panel.py b/packages/data-designer-engine/tests/engine/progress/terminal/test_throughput_panel.py new file mode 100644 index 000000000..0fe574295 --- /dev/null +++ b/packages/data-designer-engine/tests/engine/progress/terminal/test_throughput_panel.py @@ -0,0 +1,478 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import io +import logging +import os +import re +import shutil +import time +from collections.abc import Iterator +from unittest.mock import patch + +import pytest + +from data_designer.engine.models.usage_events import TokenUsageEvent, emit_token_usage_event +from data_designer.engine.observability import ( + RequestAdmissionEvent, + RuntimeCorrelation, + emit_request_admission_event, +) +from data_designer.engine.progress.reporter import AsyncProgressReporter +from data_designer.engine.progress.terminal.throughput_panel import ( + _CHART_LINE_COUNT, + _MAX_RATE_SAMPLES, + _RATE_SAMPLE_INTERVAL_SECONDS, + TerminalThroughputPanel, + _BarState, + _fit_series, +) +from data_designer.engine.progress.tracker import ProgressTracker + +CURSOR_UP_CLEAR = "\033[A\033[2K" +HIDE_CURSOR = "\033[?25l" +SHOW_CURSOR = "\033[?25h" +_ALL_ANSI_RE = re.compile(r"\033\[[0-9;?]*[a-zA-Z]") + + +class FakeTTY(io.StringIO): + """StringIO that reports itself as a TTY so TerminalThroughputPanel activates.""" + + def isatty(self) -> bool: + return True + + +@pytest.fixture +def tty_stream() -> FakeTTY: + return FakeTTY() + + +@pytest.fixture(autouse=True) +def fixed_terminal_size() -> Iterator[None]: + with patch.object(shutil, "get_terminal_size", return_value=os.terminal_size((80, 24))): + yield + + +def _clean(text: str) -> str: + return _ALL_ANSI_RE.sub("", text).replace("\r", "") + + +def _correlation(run_id: str) -> RuntimeCorrelation: + return RuntimeCorrelation( + run_id=run_id, + row_group=0, + task_column="col_a", + task_type="cell", + scheduling_group_kind="model", + scheduling_group_identity_hash="hash", + task_execution_id="task-exec", + ) + + +def _last_panel_lines(output: str) -> list[str]: + clean = _clean(output) + panel_start = clean.rfind("\n╭") + panel_start = panel_start + 1 if panel_start >= 0 else clean.rfind("╭") + assert panel_start >= 0 + return clean[panel_start:].splitlines() + + +def _chart_lines(panel_lines: list[str]) -> list[str]: + separator_index = next(index for index, line in enumerate(panel_lines) if "├" in line) + return panel_lines[2:separator_index] + + +def _marker_positions(panel_lines: list[str]) -> list[tuple[int, int]]: + return [(row_index, line.index("◆")) for row_index, line in enumerate(_chart_lines(panel_lines)) if "◆" in line] + + +def test_no_output_when_not_tty() -> None: + stream = io.StringIO() + with TerminalThroughputPanel(stream=stream) as bar: + bar.add_bar("a", "col_a", 10) + bar.update("a", completed=5, success=5) + assert stream.getvalue() == "" + + +def test_hides_and_shows_cursor(tty_stream: FakeTTY) -> None: + with TerminalThroughputPanel(stream=tty_stream): + pass + output = tty_stream.getvalue() + assert output.startswith(HIDE_CURSOR) + assert output.endswith(SHOW_CURSOR) + + +def test_tiny_terminal_falls_back_to_no_panel(tty_stream: FakeTTY) -> None: + with patch.object(shutil, "get_terminal_size", return_value=os.terminal_size((20, 24))): + with TerminalThroughputPanel(stream=tty_stream) as bar: + assert bar.is_active is False + bar.add_bar("a", "col_a", 10) + bar.update("a", completed=5, success=5, force=True) + + assert tty_stream.getvalue() == "" + + +def test_renders_bounded_throughput_panel(tty_stream: FakeTTY) -> None: + with TerminalThroughputPanel(stream=tty_stream) as bar: + bar.add_bar("a", "column 'a'", 100) + bar.add_bar("b", "column 'b'", 100) + bar.update_many({"a": (10, 10, 0, 0), "b": (20, 20, 0, 0)}, force=True) + + assert bar.drawn_lines == 22 + panel_lines = _last_panel_lines(tty_stream.getvalue()) + panel = "\n".join(panel_lines) + assert "Throughput" in panel + assert "rec/s" in panel + assert "now rec/s" in panel + assert "avg rec/s" in panel + assert "column 'a'" in panel + assert "10/100" in panel + assert "column 'b'" in panel + assert "20/100" in panel + header = next(line for line in panel_lines if "now rec/s" in line) + row = next(line for line in panel_lines if "column 'a'" in line) + assert "|" not in header + assert "|" not in row + assert "in tok/s" not in panel + assert "out tok/s" not in panel + assert header.index("avg rec/s") < header.index("done") + assert "━" in row + assert row.rindex("0.0") < row.index("10/100") + assert row.index("10/100") < row.index("━") + assert "╭" in panel + assert "╰" in panel + + +def test_model_usage_rates_render_in_separate_table(tty_stream: FakeTTY) -> None: + with TerminalThroughputPanel(stream=tty_stream) as bar: + bar.add_bar("a", "column 'a'", 100) + bar.update("a", completed=10, success=10, force=True) + bar._start_time = time.perf_counter() - 10.0 # noqa: SLF001 + bar.record_model_usage( + model_alias="test", + model_name="test-model", + input_tokens=100, + output_tokens=25, + force=True, + ) + + panel = "\n".join(_last_panel_lines(tty_stream.getvalue())) + assert "model alias" in panel + assert "model name" in panel + assert "test" in panel + assert "test-model" in panel + assert "rpm" in panel + assert "in tok/s" in panel + assert "out tok/s" in panel + assert "6.0" in panel + assert "10.0" in panel + assert "2.5" in panel + + +def test_many_columns_and_models_do_not_shrink_chart(tty_stream: FakeTTY) -> None: + with TerminalThroughputPanel(stream=tty_stream) as bar: + for index in range(8): + bar.add_bar(f"col_{index}", f"column_{index}", 100) + bar.update_many( + {f"col_{index}": (index + 1, index + 1, 0, 0) for index in range(8)}, + force=True, + ) + for index in range(8): + bar.record_model_usage( + model_alias=f"model_{index}", + model_name=f"provider/model-{index}", + input_tokens=100 + index, + output_tokens=10 + index, + force=True, + ) + + panel_lines = _last_panel_lines(tty_stream.getvalue()) + panel = "\n".join(panel_lines) + assert len(_chart_lines(panel_lines)) == _CHART_LINE_COUNT + assert len(panel_lines) > 22 + assert "more column(s)" not in panel + assert "more model(s)" not in panel + for index in range(8): + assert f"column_{index}" in panel + assert f"model_{index}" in panel + + +def test_feedback_marker_reprojects_as_elapsed_time_grows(tty_stream: FakeTTY) -> None: + with TerminalThroughputPanel(stream=tty_stream) as bar: + bar.add_bar("a", "column_a", 100) + state = bar._bars["a"] # noqa: SLF001 + state.rates = [0.0, 10.0, 20.0] + state.latest_rate = 12.0 + bar._start_time = time.perf_counter() - 10.0 # noqa: SLF001 + + bar.record_feedback_signal(event_kind="request_rate_limited", force=True) + before_positions = _marker_positions(_last_panel_lines(tty_stream.getvalue())) + assert before_positions + + bar._start_time = time.perf_counter() - 100.0 # noqa: SLF001 + bar.update("a", completed=20, success=20, force=True) + after_positions = _marker_positions(_last_panel_lines(tty_stream.getvalue())) + + assert after_positions + assert after_positions[0][1] < before_positions[0][1] + + +def test_control_sequences_are_removed_from_labels(tty_stream: FakeTTY) -> None: + with TerminalThroughputPanel(stream=tty_stream) as bar: + bar.add_bar("a", "col\x1b[31m_a\nsuffix", 100) + bar.update("a", completed=10, success=10, force=True) + + clean = _clean(tty_stream.getvalue()) + assert "col_asuffix" in clean + + +def test_rate_samples_are_bounded() -> None: + state = _BarState(label="col_a", total=1_000_000, start_time=0.0, last_sample_time=0.0) + + for index in range(_MAX_RATE_SAMPLES + 5): + completed = (index + 1) * 10 + state.record_update( + completed=completed, + success=completed, + failed=0, + skipped=0, + now=(index + 1) * _RATE_SAMPLE_INTERVAL_SECONDS, + ) + + assert len(state.rates) == _MAX_RATE_SAMPLES + + +def test_sparse_rate_samples_span_chart_width() -> None: + fitted = _fit_series([0.0, 10.0, 5.0], 7) + + assert len(fitted) == 7 + assert fitted[0] == 0.0 + assert fitted[3] == pytest.approx(10.0) + assert fitted[-1] == 5.0 + + +def test_frequent_updates_are_redraw_throttled(tty_stream: FakeTTY) -> None: + with TerminalThroughputPanel(stream=tty_stream) as bar: + bar.add_bar("a", "col_a", 100) + bar.add_bar("b", "col_b", 100) + bar.update_many({"a": (1, 1, 0, 0), "b": (2, 2, 0, 0)}, force=True) + + snapshot = tty_stream.getvalue() + for i in range(50): + bar.update_many({"a": (i, i, 0, 0), "b": (i * 2, i * 2, 0, 0)}) + + assert tty_stream.getvalue()[len(snapshot) :].count(CURSOR_UP_CLEAR) == 0 + + bar.update("a", completed=50, success=50, force=True) + assert tty_stream.getvalue()[len(snapshot) :].count(CURSOR_UP_CLEAR) == 22 + assert bar.drawn_lines == 22 + + +def test_log_interleaving_preserves_panel_height(tty_stream: FakeTTY) -> None: + root_logger = logging.getLogger() + handler = logging.StreamHandler(tty_stream) + handler.setFormatter(logging.Formatter("%(message)s")) + root_logger.addHandler(handler) + + try: + with TerminalThroughputPanel(stream=tty_stream) as bar: + bar.add_bar("x", "col_x", 100) + bar.add_bar("y", "col_y", 100) + + for i in range(10): + bar.update("x", completed=i, success=i) + root_logger.info("log at step %d", i) + bar.update("y", completed=i, success=i) + + snapshot = tty_stream.getvalue() + bar.update("x", completed=20, success=20, force=True) + assert tty_stream.getvalue()[len(snapshot) :].count(CURSOR_UP_CLEAR) == 22 + finally: + root_logger.removeHandler(handler) + + +def test_narrow_terminal_keeps_panel_within_width(tty_stream: FakeTTY) -> None: + narrow = os.terminal_size((36, 24)) + with patch.object(shutil, "get_terminal_size", return_value=narrow): + with TerminalThroughputPanel(stream=tty_stream) as bar: + bar.add_bar("a", "column 'verification_1'", 300) + bar.update("a", completed=50, success=50, force=True) + + output = tty_stream.getvalue() + for line in _last_panel_lines(output): + assert len(line) <= 35 + + +def test_update_many_single_redraw(tty_stream: FakeTTY) -> None: + with TerminalThroughputPanel(stream=tty_stream) as bar: + bar.add_bar("a", "col_a", 100) + bar.add_bar("b", "col_b", 100) + before = tty_stream.getvalue() + + bar.update_many({"a": (10, 10, 0, 0), "b": (20, 20, 0, 0)}, force=True) + after = tty_stream.getvalue() + + new_output = after[len(before) :] + assert new_output.count(CURSOR_UP_CLEAR) == 22 + + clean = _clean(after) + assert "10/100" in clean + assert "20/100" in clean + + +def test_update_many_includes_failures_and_skips(tty_stream: FakeTTY) -> None: + with TerminalThroughputPanel(stream=tty_stream) as bar: + bar.add_bar("a", "col_a", 100) + bar.update_many({"a": (10, 7, 2, 1), "unknown": (5, 5, 0, 0)}, force=True) + + clean = _clean(tty_stream.getvalue()) + assert "10/100" in clean + assert "2 failed" in clean + assert "1 skipped" in clean + assert "unknown" not in clean + + +def test_remove_bar_redraws_panel(tty_stream: FakeTTY) -> None: + with TerminalThroughputPanel(stream=tty_stream) as bar: + bar.add_bar("a", "col_a", 100) + bar.add_bar("b", "col_b", 100) + + snapshot = tty_stream.getvalue() + bar.remove_bar("a") + + new_output = tty_stream.getvalue()[len(snapshot) :] + assert new_output.count(CURSOR_UP_CLEAR) == 22 + panel = "\n".join(_last_panel_lines(tty_stream.getvalue())) + assert "col_a" not in panel + assert "col_b" in panel + + +def test_reporter_updates_and_logs_keep_drawn_lines_in_sync(tty_stream: FakeTTY) -> None: + root_logger = logging.getLogger() + old_level = root_logger.level + root_logger.setLevel(logging.INFO) + handler = logging.StreamHandler(tty_stream) + handler.setFormatter(logging.Formatter("%(message)s")) + root_logger.addHandler(handler) + + try: + bar = TerminalThroughputPanel(stream=tty_stream) + trackers = { + "col_a": ProgressTracker(total_records=100, label="column 'a'", quiet=True), + "col_b": ProgressTracker(total_records=100, label="column 'b'", quiet=True), + "col_c": ProgressTracker(total_records=100, label="column 'c'", quiet=True), + } + + with bar: + reporter = AsyncProgressReporter(trackers, report_interval=0.1, progress_bar=bar) + reporter.log_start(num_row_groups=1) + panel = "\n".join(_last_panel_lines(tty_stream.getvalue())) + assert "col_a" in panel + assert "column 'a'" not in panel + + emit_token_usage_event( + TokenUsageEvent( + model_alias="test", + model_name="test-model", + input_tokens=120, + output_tokens=30, + ) + ) + assert bar._model_usage["test"].input_tokens == 120 # noqa: SLF001 + assert bar._model_usage["test"].output_tokens == 30 # noqa: SLF001 + + snapshot = tty_stream.getvalue() + reporter.record_success("col_a") + assert tty_stream.getvalue()[len(snapshot) :].count(CURSOR_UP_CLEAR) == 0 + + for i in range(49): + if i % 10 == 0: + root_logger.info("Processing batch %d", i) + reporter.record_success("col_b") + reporter.record_skipped("col_c") + + snapshot = tty_stream.getvalue() + reporter.log_final() + assert bar.drawn_lines == 22 + clear_count = tty_stream.getvalue()[len(snapshot) :].count(CURSOR_UP_CLEAR) + assert clear_count >= bar.drawn_lines + assert clear_count % bar.drawn_lines == 0 + finally: + root_logger.removeHandler(handler) + root_logger.setLevel(old_level) + + +def test_reporter_records_feedback_markers_from_request_events(tty_stream: FakeTTY) -> None: + trackers = {"col_a": ProgressTracker(total_records=100, label="column 'a'", quiet=True)} + + with TerminalThroughputPanel(stream=tty_stream) as bar: + reporter = AsyncProgressReporter(trackers, report_interval=0.1, progress_bar=bar) + try: + emit_request_admission_event( + RequestAdmissionEvent.capture("request_rate_limited", sequence=1), + ) + assert len(bar._feedback_markers) == 1 # noqa: SLF001 + + emit_request_admission_event( + RequestAdmissionEvent.capture("request_wait_started", sequence=2), + ) + assert len(bar._feedback_markers) == 1 # noqa: SLF001 + finally: + reporter.close() + + emit_request_admission_event( + RequestAdmissionEvent.capture("request_rate_limited", sequence=3), + ) + assert len(bar._feedback_markers) == 1 # noqa: SLF001 + + +def test_reporter_filters_global_events_by_run_id(tty_stream: FakeTTY) -> None: + trackers = {"col_a": ProgressTracker(total_records=100, label="column 'a'", quiet=True)} + + with TerminalThroughputPanel(stream=tty_stream) as bar: + reporter = AsyncProgressReporter(trackers, report_interval=0.1, progress_bar=bar, run_id="run-a") + try: + emit_token_usage_event( + TokenUsageEvent( + model_alias="other", + model_name="other-model", + input_tokens=100, + output_tokens=10, + correlation=_correlation("run-b"), + ) + ) + emit_token_usage_event( + TokenUsageEvent( + model_alias="uncorrelated", + model_name="uncorrelated-model", + input_tokens=100, + output_tokens=10, + ) + ) + assert not bar._model_usage # noqa: SLF001 + + emit_token_usage_event( + TokenUsageEvent( + model_alias="owned", + model_name="owned-model", + input_tokens=120, + output_tokens=30, + correlation=_correlation("run-a"), + ) + ) + assert set(bar._model_usage) == {"owned"} # noqa: SLF001 + + emit_request_admission_event( + RequestAdmissionEvent.capture("request_rate_limited", sequence=1, correlation=_correlation("run-b")) + ) + emit_request_admission_event(RequestAdmissionEvent.capture("request_rate_limited", sequence=2)) + assert not bar._feedback_markers # noqa: SLF001 + + emit_request_admission_event( + RequestAdmissionEvent.capture("request_rate_limited", sequence=3, correlation=_correlation("run-a")) + ) + assert len(bar._feedback_markers) == 1 # noqa: SLF001 + finally: + reporter.close() diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_progress_tracker.py b/packages/data-designer-engine/tests/engine/progress/test_tracker.py similarity index 98% rename from packages/data-designer-engine/tests/engine/dataset_builders/utils/test_progress_tracker.py rename to packages/data-designer-engine/tests/engine/progress/test_tracker.py index 13b698a22..47b45ca80 100644 --- a/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_progress_tracker.py +++ b/packages/data-designer-engine/tests/engine/progress/test_tracker.py @@ -1,12 +1,14 @@ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + import logging import threading import pytest -from data_designer.engine.dataset_builders.utils.progress_tracker import ProgressTracker +from data_designer.engine.progress.tracker import ProgressTracker @pytest.fixture diff --git a/packages/data-designer/src/data_designer/cli/commands/create.py b/packages/data-designer/src/data_designer/cli/commands/create.py index 5a739c25a..f99f7ec22 100644 --- a/packages/data-designer/src/data_designer/cli/commands/create.py +++ b/packages/data-designer/src/data_designer/cli/commands/create.py @@ -61,6 +61,14 @@ def create_command( "The file is written to //.." ), ), + tui: bool | None = typer.Option( + None, + "--tui/--no-tui", + help=( + "Force the terminal progress TUI on or off for this run. " + "When omitted, uses the configured RunConfig setting." + ), + ), ) -> None: """Create a full dataset and save results to disk. @@ -91,4 +99,5 @@ def create_command( artifact_path=artifact_path, resume=resume, output_format=output_format, + tui=tui, ) diff --git a/packages/data-designer/src/data_designer/cli/controllers/generation_controller.py b/packages/data-designer/src/data_designer/cli/controllers/generation_controller.py index 4a4231c41..09f0853d5 100644 --- a/packages/data-designer/src/data_designer/cli/controllers/generation_controller.py +++ b/packages/data-designer/src/data_designer/cli/controllers/generation_controller.py @@ -119,6 +119,7 @@ def run_create( artifact_path: str | None, resume: ResumeMode = ResumeMode.NEVER, output_format: str | None = None, + tui: bool | None = None, ) -> None: """Load config, create a full dataset, and save results to disk. @@ -130,6 +131,8 @@ def run_create( resume: Controls how interrupted runs are handled. output_format: If set, export the dataset to a single file in this format after generation. One of 'jsonl', 'csv', 'parquet'. + tui: If set, overrides the active RunConfig display_tui setting for this + create invocation's terminal UI. """ config_builder = self._load_config(config_source) @@ -144,6 +147,8 @@ def run_create( try: data_designer = DataDesigner(artifact_path=resolved_artifact_path) + if tui is not None: + data_designer.set_run_config(data_designer.run_config.model_copy(update={"display_tui": tui})) results = data_designer.create( config_builder, num_records=num_records, diff --git a/packages/data-designer/tests/cli/commands/test_create_command.py b/packages/data-designer/tests/cli/commands/test_create_command.py index 8b3335d4e..40e9fa63e 100644 --- a/packages/data-designer/tests/cli/commands/test_create_command.py +++ b/packages/data-designer/tests/cli/commands/test_create_command.py @@ -26,6 +26,7 @@ def test_create_command_delegates_to_controller(mock_ctrl_cls: MagicMock) -> Non artifact_path=None, resume=ResumeMode.NEVER, output_format=None, + tui=None, ) mock_ctrl_cls.assert_called_once() @@ -36,6 +37,7 @@ def test_create_command_delegates_to_controller(mock_ctrl_cls: MagicMock) -> Non artifact_path=None, resume=ResumeMode.NEVER, output_format=None, + tui=None, ) @@ -52,6 +54,7 @@ def test_create_command_passes_custom_options(mock_ctrl_cls: MagicMock) -> None: artifact_path="/custom/output", resume=ResumeMode.NEVER, output_format=None, + tui=None, ) mock_ctrl.run_create.assert_called_once_with( @@ -61,6 +64,7 @@ def test_create_command_passes_custom_options(mock_ctrl_cls: MagicMock) -> None: artifact_path="/custom/output", resume=ResumeMode.NEVER, output_format=None, + tui=None, ) @@ -77,6 +81,7 @@ def test_create_command_default_artifact_path_is_none(mock_ctrl_cls: MagicMock) artifact_path=None, resume=ResumeMode.NEVER, output_format=None, + tui=None, ) mock_ctrl.run_create.assert_called_once_with( @@ -86,6 +91,7 @@ def test_create_command_default_artifact_path_is_none(mock_ctrl_cls: MagicMock) artifact_path=None, resume=ResumeMode.NEVER, output_format=None, + tui=None, ) @@ -102,6 +108,7 @@ def test_create_command_passes_resume_always(mock_ctrl_cls: MagicMock) -> None: artifact_path=None, resume=ResumeMode.ALWAYS, output_format=None, + tui=None, ) mock_ctrl.run_create.assert_called_once_with( @@ -111,6 +118,7 @@ def test_create_command_passes_resume_always(mock_ctrl_cls: MagicMock) -> None: artifact_path=None, resume=ResumeMode.ALWAYS, output_format=None, + tui=None, ) @@ -127,6 +135,7 @@ def test_create_command_passes_resume_if_possible(mock_ctrl_cls: MagicMock) -> N artifact_path=None, resume=ResumeMode.IF_POSSIBLE, output_format=None, + tui=None, ) mock_ctrl.run_create.assert_called_once_with( @@ -136,6 +145,7 @@ def test_create_command_passes_resume_if_possible(mock_ctrl_cls: MagicMock) -> N artifact_path=None, resume=ResumeMode.IF_POSSIBLE, output_format=None, + tui=None, ) @@ -152,6 +162,7 @@ def test_create_command_passes_output_format(mock_ctrl_cls: MagicMock) -> None: artifact_path=None, resume=ResumeMode.NEVER, output_format="jsonl", + tui=None, ) mock_ctrl.run_create.assert_called_once_with( @@ -161,4 +172,32 @@ def test_create_command_passes_output_format(mock_ctrl_cls: MagicMock) -> None: artifact_path=None, resume=ResumeMode.NEVER, output_format="jsonl", + tui=None, + ) + + +@patch("data_designer.cli.commands.create.GenerationController") +def test_create_command_passes_tui_override(mock_ctrl_cls: MagicMock) -> None: + """Test create_command forwards explicit TUI override.""" + mock_ctrl = MagicMock() + mock_ctrl_cls.return_value = mock_ctrl + + create_command( + config_source="config.yaml", + num_records=10, + dataset_name="dataset", + artifact_path=None, + resume=ResumeMode.NEVER, + output_format=None, + tui=False, + ) + + mock_ctrl.run_create.assert_called_once_with( + config_source="config.yaml", + num_records=10, + dataset_name="dataset", + artifact_path=None, + resume=ResumeMode.NEVER, + output_format=None, + tui=False, ) diff --git a/packages/data-designer/tests/cli/controllers/test_generation_controller.py b/packages/data-designer/tests/cli/controllers/test_generation_controller.py index b8047641a..eef3a34fa 100644 --- a/packages/data-designer/tests/cli/controllers/test_generation_controller.py +++ b/packages/data-designer/tests/cli/controllers/test_generation_controller.py @@ -13,6 +13,7 @@ from data_designer.cli.utils.config_loader import ConfigLoadError from data_designer.config.config_builder import DataDesignerConfigBuilder from data_designer.config.errors import InvalidConfigError +from data_designer.config.run_config import RunConfig from data_designer.config.utils.constants import DEFAULT_DISPLAY_WIDTH from data_designer.engine.storage.artifact_storage import ResumeMode @@ -711,6 +712,33 @@ def test_run_create_custom_options(mock_load_config: MagicMock, mock_dd_cls: Mag ) +@pytest.mark.parametrize("tui", [True, False]) +@patch(f"{_CTRL}.DataDesigner") +@patch(f"{_CTRL}.load_config_builder") +def test_run_create_applies_tui_override(mock_load_config: MagicMock, mock_dd_cls: MagicMock, tui: bool) -> None: + """run_create applies explicit --tui/--no-tui override to RunConfig.""" + mock_load_config.return_value = MagicMock(spec=DataDesignerConfigBuilder) + mock_dd = MagicMock() + mock_dd.run_config = RunConfig(display_tui=not tui) + mock_dd_cls.return_value = mock_dd + mock_dd.create.return_value = _make_mock_create_results() + + controller = GenerationController() + controller.run_create( + config_source="config.yaml", + num_records=10, + dataset_name="dataset", + artifact_path=None, + tui=tui, + ) + + mock_dd.set_run_config.assert_called_once() + assert mock_dd.set_run_config.call_args.args[0].display_tui is tui + mock_dd.create.assert_called_once_with( + mock_load_config.return_value, num_records=10, dataset_name="dataset", resume=ResumeMode.NEVER + ) + + @patch(f"{_CTRL}.load_config_builder") def test_run_create_config_load_error(mock_load_config: MagicMock) -> None: """Test create exits with code 1 when config fails to load.""" diff --git a/packages/data-designer/tests/cli/test_main.py b/packages/data-designer/tests/cli/test_main.py index 928e85159..9e11a656c 100644 --- a/packages/data-designer/tests/cli/test_main.py +++ b/packages/data-designer/tests/cli/test_main.py @@ -177,6 +177,42 @@ def test_app_dispatches_lazy_create_command(mock_controller_cls: Mock) -> None: artifact_path=None, resume=ResumeMode.NEVER, output_format=None, + tui=None, + ) + + +@patch("data_designer.cli.commands.create.GenerationController") +def test_app_dispatches_create_tui_flags(mock_controller_cls: Mock) -> None: + """The create command parses --tui/--no-tui into the TUI override.""" + mock_controller = Mock() + mock_controller_cls.return_value = mock_controller + + result = runner.invoke(app, ["create", "config.yaml", "--no-tui"]) + + assert result.exit_code == 0 + mock_controller.run_create.assert_called_once_with( + config_source="config.yaml", + num_records=DEFAULT_NUM_RECORDS, + dataset_name="dataset", + artifact_path=None, + resume=ResumeMode.NEVER, + output_format=None, + tui=False, + ) + + mock_controller.reset_mock() + + result = runner.invoke(app, ["create", "config.yaml", "--tui"]) + + assert result.exit_code == 0 + mock_controller.run_create.assert_called_once_with( + config_source="config.yaml", + num_records=DEFAULT_NUM_RECORDS, + dataset_name="dataset", + artifact_path=None, + resume=ResumeMode.NEVER, + output_format=None, + tui=True, ) diff --git a/uv.lock b/uv.lock index 6d01ccc55..317034df3 100644 --- a/uv.lock +++ b/uv.lock @@ -269,6 +269,30 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ed/c9/d7977eaacb9df673210491da99e6a247e93df98c715fc43fd136ce1d3d33/arrow-1.4.0-py3-none-any.whl", hash = "sha256:749f0769958ebdc79c173ff0b0670d59051a535fa26e8eba02953dc19eb43205", size = 68797, upload-time = "2025-10-18T17:46:45.663Z" }, ] +[[package]] +name = "asciichartpy" +version = "1.5.25" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "setuptools" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/41/3a/b01436be647f881515ec2f253616bf4a40c1d27d02a69e7f038e27fcdf81/asciichartpy-1.5.25.tar.gz", hash = "sha256:63a305302b2aad51da288b58226009b7b0313eba7d8e2452d5a21a13fcf44d74", size = 8201, upload-time = "2020-08-17T02:07:18.292Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/3f/d0/7b958df957e4827837b590944008f0b28078f552b451f7407b4b3d54f574/asciichartpy-1.5.25-py2.py3-none-any.whl", hash = "sha256:33c417a3c8ef7d0a11b98eb9ea6dd9b2c1b17559e539b207a17d26d4302d0258", size = 7228, upload-time = "2020-08-17T02:07:16.386Z" }, +] + +[[package]] +name = "astroid" +version = "3.3.11" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "typing-extensions", marker = "python_full_version < '3.11'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/18/74/dfb75f9ccd592bbedb175d4a32fc643cf569d7c218508bfbd6ea7ef9c091/astroid-3.3.11.tar.gz", hash = "sha256:1e5a5011af2920c7c67a53f65d536d65bfa7116feeaf2354d8b94f29573bb0ce", size = 400439, upload-time = "2025-07-13T18:04:23.177Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/af/0f/3b8fdc946b4d9cc8cc1e8af42c4e409468c84441b933d037e101b3d72d86/astroid-3.3.11-py3-none-any.whl", hash = "sha256:54c760ae8322ece1abd213057c4b5bba7c49818853fc901ef09719a60dbf9dec", size = 275612, upload-time = "2025-07-13T18:04:21.07Z" }, +] + [[package]] name = "asttokens" version = "3.0.1" @@ -847,6 +871,7 @@ name = "data-designer-engine" source = { editable = "packages/data-designer-engine" } dependencies = [ { name = "anyascii" }, + { name = "asciichartpy" }, { name = "chardet" }, { name = "cryptography" }, { name = "data-designer-config" }, @@ -877,6 +902,7 @@ dependencies = [ [package.metadata] requires-dist = [ { name = "anyascii", specifier = ">=0.3.3,<1" }, + { name = "asciichartpy", specifier = ">=1.5.25,<2" }, { name = "chardet", specifier = ">=3.0.2,<6" }, { name = "cryptography", specifier = ">=46.0.7,<47" }, { name = "data-designer-config", editable = "packages/data-designer-config" },