From f82a18f35c38e777363c4fe0d43ee3ccc0dbbc4d Mon Sep 17 00:00:00 2001 From: Nabin Mulepati Date: Mon, 18 May 2026 14:18:01 -0600 Subject: [PATCH 01/15] feat: add audio and video context Add audio/video context config models and canonical media helpers. Translate canonical media blocks for OpenAI-compatible clients while preserving URL media as URLs. Reject unsupported audio/video blocks in the Anthropic adapter. Refs #671 --- .../src/data_designer/config/__init__.py | 9 ++ .../data_designer/config/column_configs.py | 24 +-- .../src/data_designer/config/models.py | 126 ++++++++++++++- .../config/utils/media_helpers.py | 115 ++++++++++++++ .../tests/config/test_columns.py | 37 ++++- .../tests/config/test_models.py | 149 ++++++++++++++++++ .../tests/config/utils/test_media_helpers.py | 54 +++++++ .../column_generators/generators/base.py | 5 +- .../models/clients/adapters/anthropic.py | 11 ++ .../clients/adapters/anthropic_translation.py | 31 ++++ .../clients/adapters/openai_compatible.py | 137 +++++++++++++++- .../src/data_designer/engine/models/utils.py | 2 +- .../engine/models/clients/test_anthropic.py | 26 +++ .../clients/test_anthropic_translation.py | 21 +++ .../models/clients/test_openai_compatible.py | 83 ++++++++++ .../tests/engine/models/test_model_utils.py | 17 ++ .../tests/engine/test_validation.py | 14 +- 17 files changed, 834 insertions(+), 27 deletions(-) create mode 100644 packages/data-designer-config/src/data_designer/config/utils/media_helpers.py create mode 100644 packages/data-designer-config/tests/config/utils/test_media_helpers.py diff --git a/packages/data-designer-config/src/data_designer/config/__init__.py b/packages/data-designer-config/src/data_designer/config/__init__.py index a8f683aa3..cf2a4c4d4 100644 --- a/packages/data-designer-config/src/data_designer/config/__init__.py +++ b/packages/data-designer-config/src/data_designer/config/__init__.py @@ -38,6 +38,7 @@ ToolConfig, ) from data_designer.config.models import ( # noqa: F401 + 
AudioContext, ChatCompletionInferenceParams, EmbeddingInferenceParams, GenerationType, @@ -50,8 +51,10 @@ ModalityDataType, ModelConfig, ModelProvider, + MultiModalContextT, UniformDistribution, UniformDistributionParams, + VideoContext, ) from data_designer.config.processors import ( # noqa: F401 DropColumnsProcessorConfig, @@ -106,6 +109,7 @@ from data_designer.config.utils.code_lang import CodeLang # noqa: F401 from data_designer.config.utils.image_helpers import ImageFormat # noqa: F401 from data_designer.config.utils.info import InfoType # noqa: F401 + from data_designer.config.utils.media_helpers import AudioFormat, VideoFormat # noqa: F401 from data_designer.config.utils.trace_type import TraceType # noqa: F401 from data_designer.config.validator_params import ( # noqa: F401 CodeValidatorParams, @@ -161,6 +165,8 @@ "MCPProvider": (_MOD_MCP, "MCPProvider"), "ToolConfig": (_MOD_MCP, "ToolConfig"), # models + "AudioContext": (_MOD_MODELS, "AudioContext"), + "AudioFormat": (f"{_MOD_UTILS}.media_helpers", "AudioFormat"), "ChatCompletionInferenceParams": (_MOD_MODELS, "ChatCompletionInferenceParams"), "EmbeddingInferenceParams": (_MOD_MODELS, "EmbeddingInferenceParams"), "GenerationType": (_MOD_MODELS, "GenerationType"), @@ -174,8 +180,11 @@ "ModalityDataType": (_MOD_MODELS, "ModalityDataType"), "ModelConfig": (_MOD_MODELS, "ModelConfig"), "ModelProvider": (_MOD_MODELS, "ModelProvider"), + "MultiModalContextT": (_MOD_MODELS, "MultiModalContextT"), "UniformDistribution": (_MOD_MODELS, "UniformDistribution"), "UniformDistributionParams": (_MOD_MODELS, "UniformDistributionParams"), + "VideoContext": (_MOD_MODELS, "VideoContext"), + "VideoFormat": (f"{_MOD_UTILS}.media_helpers", "VideoFormat"), # processors "DropColumnsProcessorConfig": (_MOD_PROCESSORS, "DropColumnsProcessorConfig"), "ProcessorType": (_MOD_PROCESSORS, "ProcessorType"), diff --git a/packages/data-designer-config/src/data_designer/config/column_configs.py 
b/packages/data-designer-config/src/data_designer/config/column_configs.py index a1016fec1..c770d6e65 100644 --- a/packages/data-designer-config/src/data_designer/config/column_configs.py +++ b/packages/data-designer-config/src/data_designer/config/column_configs.py @@ -11,7 +11,7 @@ from data_designer.config.base import ConfigBase, SingleColumnConfig from data_designer.config.errors import InvalidConfigError -from data_designer.config.models import ImageContext +from data_designer.config.models import MultiModalContextT from data_designer.config.sampler_params import SamplerParamsT, SamplerType from data_designer.config.utils.code_lang import CodeLang from data_designer.config.utils.constants import REASONING_CONTENT_COLUMN_POSTFIX, TRACE_COLUMN_POSTFIX @@ -139,8 +139,8 @@ class LLMTextColumnConfig(SingleColumnConfig): Do not put any output parsing instructions in the system prompt. Instead, use the appropriate column type for the output you want to generate - e.g., `LLMStructuredColumnConfig` for structured output, `LLMCodeColumnConfig` for code. - multi_modal_context: Optional list of image contexts for multi-modal generation. - Enables vision-capable models to generate text based on image inputs. + multi_modal_context: Optional list of multimodal contexts for generation. + Enables capable models to generate text based on image, audio, or video inputs. tool_alias: Optional alias of the tool configuration to use for MCP tool calls. Must match a tool alias defined when initializing the DataDesignerConfigBuilder. When provided, the model may call permitted tools during generation. 
@@ -166,8 +166,8 @@ class LLMTextColumnConfig(SingleColumnConfig): system_prompt: str | None = Field( default=None, description="Optional system prompt to set model behavior and constraints" ) - multi_modal_context: list[ImageContext] | None = Field( - default=None, description="Optional list of ImageContext for vision model inputs" + multi_modal_context: list[MultiModalContextT] | None = Field( + default=None, description="Optional list of multimodal context inputs" ) tool_alias: str | None = Field( default=None, description="Optional alias of the tool configuration to use for MCP tool calls" @@ -250,7 +250,7 @@ class LLMCodeColumnConfig(LLMTextColumnConfig): prompt (required): Prompt template for code generation (supports Jinja2). model_alias (required): Alias of the model configuration to use. system_prompt: Optional system prompt (supports Jinja2). - multi_modal_context: Optional image contexts for multi-modal generation. + multi_modal_context: Optional multimodal contexts for generation. tool_alias: Optional tool configuration alias for MCP tool calls. with_trace: Specifies what trace information to capture in a `{column_name}__trace` column. Options are `TraceType.NONE` (default), `TraceType.LAST_MESSAGE`, or @@ -288,7 +288,7 @@ class LLMStructuredColumnConfig(LLMTextColumnConfig): prompt (required): Prompt template for structured generation (supports Jinja2). model_alias (required): Alias of the model configuration to use. system_prompt: Optional system prompt (supports Jinja2). - multi_modal_context: Optional image contexts for multi-modal generation. + multi_modal_context: Optional multimodal contexts for generation. tool_alias: Optional tool configuration alias for MCP tool calls. with_trace: Specifies what trace information to capture in a `{column_name}__trace` column. 
Options are `TraceType.NONE` (default), `TraceType.LAST_MESSAGE`, or @@ -358,7 +358,7 @@ class LLMJudgeColumnConfig(LLMTextColumnConfig): prompt (required): Prompt template for the judge evaluation (supports Jinja2). model_alias (required): Alias of the model configuration to use. system_prompt: Optional system prompt (supports Jinja2). - multi_modal_context: Optional image contexts for multi-modal generation. + multi_modal_context: Optional multimodal contexts for generation. tool_alias: Optional tool configuration alias for MCP tool calls. with_trace: Specifies what trace information to capture in a `{column_name}__trace` column. Options are `TraceType.NONE` (default), `TraceType.LAST_MESSAGE`, or @@ -596,8 +596,8 @@ class ImageColumnConfig(SingleColumnConfig): reference other columns (e.g., "Generate an image of a {{ character_name }}"). Must be a valid Jinja2 template. model_alias (required): The model to use for image generation. - multi_modal_context: Optional list of image contexts for multi-modal generation. - Enables autoregressive multi-modal models to generate images based on image inputs. + multi_modal_context: Optional list of multimodal contexts for generation. + Enables autoregressive multi-modal models to generate images based on media inputs. Only works with autoregressive models that support image-to-image generation. 
Inherited Attributes: @@ -609,8 +609,8 @@ class ImageColumnConfig(SingleColumnConfig): description="Jinja2 template for the image generation prompt; can reference other columns via {{ column_name }}" ) model_alias: str = Field(description="Alias of the model to use for image generation") - multi_modal_context: list[ImageContext] | None = Field( - default=None, description="Optional list of ImageContext for multi-modal image-to-image generation" + multi_modal_context: list[MultiModalContextT] | None = Field( + default=None, description="Optional list of multimodal context inputs for image generation" ) column_type: Literal["image"] = "image" diff --git a/packages/data-designer-config/src/data_designer/config/models.py b/packages/data-designer-config/src/data_designer/config/models.py index 482f78308..b361e185f 100644 --- a/packages/data-designer-config/src/data_designer/config/models.py +++ b/packages/data-designer-config/src/data_designer/config/models.py @@ -31,6 +31,18 @@ load_image_path_to_base64, ) from data_designer.config.utils.io_helpers import smart_load_yaml +from data_designer.config.utils.media_helpers import ( + AudioFormat, + VideoFormat, + audio_format_from_mime_type, + audio_mime_type, + is_audio_url, + is_video_url, + normalize_media_context_values, + parse_base64_data_uri, + video_format_from_mime_type, + video_mime_type, +) from data_designer.config.utils.warning_helpers import warn_at_caller logger = logging.getLogger(__name__) @@ -40,6 +52,8 @@ class Modality(str, Enum): """Supported modality types for multimodal model data.""" IMAGE = "image" + AUDIO = "audio" + VIDEO = "video" class ModalityDataType(str, Enum): @@ -77,7 +91,7 @@ class ImageContext(ModalityContext): image_format: Image format (required when data_type is explicitly "base64"). 
""" - modality: Modality = Modality.IMAGE + modality: Literal[Modality.IMAGE] = Modality.IMAGE image_format: ImageFormat | None = None def get_contexts(self, record: dict, *, base_path: str | None = None) -> list[dict[str, Any]]: @@ -173,6 +187,116 @@ def _validate_image_format(self) -> Self: return self +class AudioContext(ModalityContext): + """Configuration for providing audio context to multimodal models. + + Audio context values are URL or base64 media values. Unlike ``ImageContext``, + this class does not resolve local file paths to base64. + """ + + modality: Literal[Modality.AUDIO] = Modality.AUDIO + audio_format: AudioFormat | None = None + + def get_contexts(self, record: dict, *, base_path: str | None = None) -> list[dict[str, Any]]: + """Get the contexts for the audio modality.""" + del base_path + return [self._build_context(value) for value in normalize_media_context_values(record[self.column_name])] + + def _build_context(self, context_value: Any) -> dict[str, Any]: + if self.data_type == ModalityDataType.URL or (self.data_type is None and is_audio_url(context_value)): + source: dict[str, Any] = {"type": "url", "url": context_value} + if self.audio_format is not None: + source["format"] = self.audio_format.value + return {"type": "audio", "source": source} + + media_type, data, audio_format = self._resolve_base64_parts(context_value) + return { + "type": "audio", + "source": { + "type": "base64", + "media_type": media_type, + "data": data, + "format": audio_format.value, + }, + } + + def _resolve_base64_parts(self, context_value: Any) -> tuple[str, Any, AudioFormat]: + parsed = parse_base64_data_uri(context_value) + if parsed is not None: + media_type, data = parsed + detected_format = audio_format_from_mime_type(media_type) + if detected_format is None: + raise ValueError(f"Unsupported audio media type {media_type!r}") + audio_format = self.audio_format or detected_format + if audio_format != detected_format: + raise ValueError( + f"audio_format 
{audio_format.value!r} does not match data URI media type {media_type!r}" + ) + return media_type, data, audio_format + + if self.audio_format is None: + raise ValueError("audio_format is required for base64 audio context values") + return audio_mime_type(self.audio_format), context_value, self.audio_format + + @model_validator(mode="after") + def _validate_audio_format(self) -> Self: + if self.data_type == ModalityDataType.BASE64 and self.audio_format is None: + raise ValueError(f"audio_format is required when data_type is {self.data_type.value}") + return self + + +class VideoContext(ModalityContext): + """Configuration for providing video context to multimodal models. + + Video context values are URL or base64 media values. Local file path + resolution is intentionally out of scope for this context type. + """ + + modality: Literal[Modality.VIDEO] = Modality.VIDEO + video_format: VideoFormat | None = None + + def get_contexts(self, record: dict, *, base_path: str | None = None) -> list[dict[str, Any]]: + """Get the contexts for the video modality.""" + del base_path + return [self._build_context(value) for value in normalize_media_context_values(record[self.column_name])] + + def _build_context(self, context_value: Any) -> dict[str, Any]: + if self.data_type == ModalityDataType.URL or (self.data_type is None and is_video_url(context_value)): + return {"type": "video", "source": {"type": "url", "url": context_value}} + + media_type, data = self._resolve_base64_parts(context_value) + return {"type": "video", "source": {"type": "base64", "media_type": media_type, "data": data}} + + def _resolve_base64_parts(self, context_value: Any) -> tuple[str, Any]: + parsed = parse_base64_data_uri(context_value) + if parsed is not None: + media_type, data = parsed + detected_format = video_format_from_mime_type(media_type) + if detected_format is None: + raise ValueError(f"Unsupported video media type {media_type!r}") + if self.video_format is not None and self.video_format != 
detected_format: + raise ValueError( + f"video_format {self.video_format.value!r} does not match data URI media type {media_type!r}" + ) + return media_type, data + + if self.video_format is None: + raise ValueError("video_format is required for base64 video context values") + return video_mime_type(self.video_format), context_value + + @model_validator(mode="after") + def _validate_video_format(self) -> Self: + if self.data_type == ModalityDataType.BASE64 and self.video_format is None: + raise ValueError(f"video_format is required when data_type is {self.data_type.value}") + return self + + +MultiModalContextT: TypeAlias = Annotated[ + ImageContext | AudioContext | VideoContext, + Field(discriminator="modality"), +] + + DistributionParamsT = TypeVar("DistributionParamsT", bound=ConfigBase) diff --git a/packages/data-designer-config/src/data_designer/config/utils/media_helpers.py b/packages/data-designer-config/src/data_designer/config/utils/media_helpers.py new file mode 100644 index 000000000..0cb493b01 --- /dev/null +++ b/packages/data-designer-config/src/data_designer/config/utils/media_helpers.py @@ -0,0 +1,115 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +"""Shared helpers for multimodal media context values.""" + +from __future__ import annotations + +import json +import re +from typing import Any + +from data_designer.config.utils.type_helpers import StrEnum + + +class AudioFormat(StrEnum): + """Supported audio formats for audio context.""" + + MP3 = "mp3" + WAV = "wav" + + +class VideoFormat(StrEnum): + """Supported video formats for video context.""" + + MP4 = "mp4" + MOV = "mov" + WEBM = "webm" + + +SUPPORTED_AUDIO_EXTENSIONS = [f".{fmt.value.lower()}" for fmt in AudioFormat] +SUPPORTED_VIDEO_EXTENSIONS = [f".{fmt.value.lower()}" for fmt in VideoFormat] + +_DATA_URI_RE = re.compile(r"^data:(?P<media_type>[^;]+);base64,(?P<data>.+)$") + +_AUDIO_FORMAT_TO_MIME_TYPE: dict[AudioFormat, str] = { + AudioFormat.MP3: "audio/mpeg", + AudioFormat.WAV: "audio/wav", +} +_VIDEO_FORMAT_TO_MIME_TYPE: dict[VideoFormat, str] = { + VideoFormat.MP4: "video/mp4", + VideoFormat.MOV: "video/quicktime", + VideoFormat.WEBM: "video/webm", +} +_AUDIO_MIME_TYPE_TO_FORMAT: dict[str, AudioFormat] = { + mime_type: audio_format for audio_format, mime_type in _AUDIO_FORMAT_TO_MIME_TYPE.items() +} +_VIDEO_MIME_TYPE_TO_FORMAT: dict[str, VideoFormat] = { + mime_type: video_format for video_format, mime_type in _VIDEO_FORMAT_TO_MIME_TYPE.items() +} + + +def normalize_media_context_values(raw_value: Any) -> list[Any]: + """Normalize scalar, JSON-list, list, and array-like media values.""" + if isinstance(raw_value, str): + try: + parsed_value = json.loads(raw_value) + if isinstance(parsed_value, list): + return parsed_value + except (json.JSONDecodeError, TypeError): + pass + return [raw_value] + + if isinstance(raw_value, list): + return raw_value + + if hasattr(raw_value, "__iter__") and not isinstance(raw_value, (str, bytes, dict)): + return list(raw_value) + + return [raw_value] + + +def parse_base64_data_uri(value: str) -> tuple[str, str] | None: + """Return ``(media_type, data)`` for a base64 data URI.""" + if not 
isinstance(value, str): + return None + match = _DATA_URI_RE.match(value) + if match is None: + return None + return match.group("media_type"), match.group("data") + + +def is_audio_url(value: str) -> bool: + """Return whether a value looks like an audio URL.""" + return _is_media_url(value, SUPPORTED_AUDIO_EXTENSIONS) + + +def is_video_url(value: str) -> bool: + """Return whether a value looks like a video URL.""" + return _is_media_url(value, SUPPORTED_VIDEO_EXTENSIONS) + + +def audio_mime_type(audio_format: AudioFormat) -> str: + """Return the MIME type for an audio format.""" + return _AUDIO_FORMAT_TO_MIME_TYPE[audio_format] + + +def video_mime_type(video_format: VideoFormat) -> str: + """Return the MIME type for a video format.""" + return _VIDEO_FORMAT_TO_MIME_TYPE[video_format] + + +def audio_format_from_mime_type(media_type: str) -> AudioFormat | None: + """Infer an audio format from a MIME type.""" + return _AUDIO_MIME_TYPE_TO_FORMAT.get(media_type.lower()) + + +def video_format_from_mime_type(media_type: str) -> VideoFormat | None: + """Infer a video format from a MIME type.""" + return _VIDEO_MIME_TYPE_TO_FORMAT.get(media_type.lower()) + + +def _is_media_url(value: str, supported_extensions: list[str]) -> bool: + if not isinstance(value, str): + return False + return value.startswith(("http://", "https://")) and any(ext in value.lower() for ext in supported_extensions) diff --git a/packages/data-designer-config/tests/config/test_columns.py b/packages/data-designer-config/tests/config/test_columns.py index 4937a8e37..fa0ce6882 100644 --- a/packages/data-designer-config/tests/config/test_columns.py +++ b/packages/data-designer-config/tests/config/test_columns.py @@ -28,7 +28,7 @@ is_plugin_column_type, ) from data_designer.config.errors import InvalidConfigError -from data_designer.config.models import ImageContext +from data_designer.config.models import AudioContext, ImageContext, ModalityDataType, VideoContext from data_designer.config.sampler_params 
import ( CategorySamplerParams, GaussianSamplerParams, @@ -130,9 +130,13 @@ def test_llm_text_column_config_required_columns_includes_multi_modal_context(): name="test_llm_text", prompt="Classify this image: {{ description }}", model_alias=stub_model_alias, - multi_modal_context=[ImageContext(column_name="image_base64")], + multi_modal_context=[ + ImageContext(column_name="image_base64"), + AudioContext(column_name="audio_url", data_type=ModalityDataType.URL), + VideoContext(column_name="video_url", data_type=ModalityDataType.URL), + ], ) - assert set(config.required_columns) == {"description", "image_base64"} + assert set(config.required_columns) == {"description", "image_base64", "audio_url", "video_url"} def test_llm_text_column_config_required_columns_deduplicates_multi_modal_and_prompt(): @@ -150,9 +154,32 @@ def test_image_column_config_required_columns_includes_multi_modal_context(): name="test_image", prompt="Generate based on {{ style }}", model_alias=stub_model_alias, - multi_modal_context=[ImageContext(column_name="reference_image")], + multi_modal_context=[ + ImageContext(column_name="reference_image"), + AudioContext(column_name="reference_audio", data_type=ModalityDataType.URL), + ], ) - assert set(config.required_columns) == {"style", "reference_image"} + assert set(config.required_columns) == {"style", "reference_image", "reference_audio"} + + +def test_multi_modal_context_round_trips_discriminated_union() -> None: + config = LLMTextColumnConfig( + name="test_llm_text", + prompt="Describe the context", + model_alias=stub_model_alias, + multi_modal_context=[ + ImageContext(column_name="image_url", data_type=ModalityDataType.URL), + AudioContext(column_name="audio_url", data_type=ModalityDataType.URL), + VideoContext(column_name="video_url", data_type=ModalityDataType.URL), + ], + ) + + round_tripped = LLMTextColumnConfig(**config.model_dump()) + + assert round_tripped.multi_modal_context is not None + assert 
isinstance(round_tripped.multi_modal_context[0], ImageContext) + assert isinstance(round_tripped.multi_modal_context[1], AudioContext) + assert isinstance(round_tripped.multi_modal_context[2], VideoContext) def test_llm_text_column_config_with_trace_serialization() -> None: diff --git a/packages/data-designer-config/tests/config/test_models.py b/packages/data-designer-config/tests/config/test_models.py index c5bacd818..4cb62b649 100644 --- a/packages/data-designer-config/tests/config/test_models.py +++ b/packages/data-designer-config/tests/config/test_models.py @@ -15,6 +15,8 @@ import data_designer.lazy_heavy_imports as lazy from data_designer.config.errors import InvalidConfigError from data_designer.config.models import ( + AudioContext, + AudioFormat, ChatCompletionInferenceParams, EmbeddingInferenceParams, GenerationType, @@ -27,6 +29,8 @@ ModelConfig, UniformDistribution, UniformDistributionParams, + VideoContext, + VideoFormat, load_model_configs, ) @@ -244,6 +248,151 @@ def test_image_context_auto_detect_file_path_not_exists(tmp_path: Path) -> None: ) +def test_audio_context_get_contexts_single_string() -> None: + audio_context = AudioContext( + column_name="audio_base64", data_type=ModalityDataType.BASE64, audio_format=AudioFormat.MP3 + ) + assert audio_context.get_contexts({"audio_base64": "audio1base64"}) == [ + { + "type": "audio", + "source": { + "type": "base64", + "media_type": "audio/mpeg", + "data": "audio1base64", + "format": "mp3", + }, + } + ] + + audio_context = AudioContext(column_name="audio_url", data_type=ModalityDataType.URL) + assert audio_context.get_contexts({"audio_url": "https://example.com/audio.mp3"}) == [ + {"type": "audio", "source": {"type": "url", "url": "https://example.com/audio.mp3"}} + ] + + +def test_audio_context_get_contexts_list_json_and_numpy() -> None: + audio_context = AudioContext( + column_name="audio_base64", data_type=ModalityDataType.BASE64, audio_format=AudioFormat.WAV + ) + assert 
audio_context.get_contexts({"audio_base64": ["audio1", "audio2"]}) == [ + { + "type": "audio", + "source": {"type": "base64", "media_type": "audio/wav", "data": "audio1", "format": "wav"}, + }, + { + "type": "audio", + "source": {"type": "base64", "media_type": "audio/wav", "data": "audio2", "format": "wav"}, + }, + ] + + json_str = json.dumps(["https://example.com/audio1.mp3", "https://example.com/audio2.mp3"]) + url_context = AudioContext(column_name="audio_url", data_type=ModalityDataType.URL) + assert url_context.get_contexts({"audio_url": json_str}) == [ + {"type": "audio", "source": {"type": "url", "url": "https://example.com/audio1.mp3"}}, + {"type": "audio", "source": {"type": "url", "url": "https://example.com/audio2.mp3"}}, + ] + + numpy_array = lazy.np.array(["https://example.com/audio1.mp3", "https://example.com/audio2.mp3"]) + assert url_context.get_contexts({"audio_url": numpy_array}) == [ + {"type": "audio", "source": {"type": "url", "url": "https://example.com/audio1.mp3"}}, + {"type": "audio", "source": {"type": "url", "url": "https://example.com/audio2.mp3"}}, + ] + + +def test_audio_context_auto_detect_url_and_data_uri() -> None: + assert AudioContext(column_name="audio_col").get_contexts({"audio_col": "https://example.com/audio.mp3"}) == [ + {"type": "audio", "source": {"type": "url", "url": "https://example.com/audio.mp3"}} + ] + + assert AudioContext(column_name="audio_col").get_contexts({"audio_col": "data:audio/mpeg;base64,audio1base64"}) == [ + { + "type": "audio", + "source": { + "type": "base64", + "media_type": "audio/mpeg", + "data": "audio1base64", + "format": "mp3", + }, + } + ] + + +def test_audio_context_validate_audio_format() -> None: + with pytest.raises(ValueError, match="audio_format is required when data_type is base64"): + AudioContext(column_name="audio_base64", data_type=ModalityDataType.BASE64) + + with pytest.raises(ValueError, match="audio_format is required for base64 audio context values"): + 
AudioContext(column_name="audio_base64").get_contexts({"audio_base64": "audio1base64"}) + + with pytest.raises(ValueError, match="does not match data URI media type"): + AudioContext(column_name="audio_base64", audio_format=AudioFormat.WAV).get_contexts( + {"audio_base64": "data:audio/mpeg;base64,audio1base64"} + ) + + +def test_video_context_get_contexts_single_string() -> None: + video_context = VideoContext( + column_name="video_base64", data_type=ModalityDataType.BASE64, video_format=VideoFormat.MP4 + ) + assert video_context.get_contexts({"video_base64": "video1base64"}) == [ + { + "type": "video", + "source": {"type": "base64", "media_type": "video/mp4", "data": "video1base64"}, + } + ] + + video_context = VideoContext(column_name="video_url", data_type=ModalityDataType.URL) + assert video_context.get_contexts({"video_url": "https://example.com/video.mp4"}) == [ + {"type": "video", "source": {"type": "url", "url": "https://example.com/video.mp4"}} + ] + + +def test_video_context_get_contexts_list_json_and_numpy() -> None: + video_context = VideoContext( + column_name="video_base64", data_type=ModalityDataType.BASE64, video_format=VideoFormat.WEBM + ) + assert video_context.get_contexts({"video_base64": ["video1", "video2"]}) == [ + {"type": "video", "source": {"type": "base64", "media_type": "video/webm", "data": "video1"}}, + {"type": "video", "source": {"type": "base64", "media_type": "video/webm", "data": "video2"}}, + ] + + json_str = json.dumps(["https://example.com/video1.mp4", "https://example.com/video2.mp4"]) + url_context = VideoContext(column_name="video_url", data_type=ModalityDataType.URL) + assert url_context.get_contexts({"video_url": json_str}) == [ + {"type": "video", "source": {"type": "url", "url": "https://example.com/video1.mp4"}}, + {"type": "video", "source": {"type": "url", "url": "https://example.com/video2.mp4"}}, + ] + + numpy_array = lazy.np.array(["https://example.com/video1.mp4", "https://example.com/video2.mp4"]) + assert 
url_context.get_contexts({"video_url": numpy_array}) == [ + {"type": "video", "source": {"type": "url", "url": "https://example.com/video1.mp4"}}, + {"type": "video", "source": {"type": "url", "url": "https://example.com/video2.mp4"}}, + ] + + +def test_video_context_auto_detect_url_and_data_uri() -> None: + assert VideoContext(column_name="video_col").get_contexts({"video_col": "https://example.com/video.mp4"}) == [ + {"type": "video", "source": {"type": "url", "url": "https://example.com/video.mp4"}} + ] + + assert VideoContext(column_name="video_col").get_contexts({"video_col": "data:video/mp4;base64,video1base64"}) == [ + {"type": "video", "source": {"type": "base64", "media_type": "video/mp4", "data": "video1base64"}} + ] + + +def test_video_context_validate_video_format() -> None: + with pytest.raises(ValueError, match="video_format is required when data_type is base64"): + VideoContext(column_name="video_base64", data_type=ModalityDataType.BASE64) + + with pytest.raises(ValueError, match="video_format is required for base64 video context values"): + VideoContext(column_name="video_base64").get_contexts({"video_base64": "video1base64"}) + + with pytest.raises(ValueError, match="does not match data URI media type"): + VideoContext(column_name="video_base64", video_format=VideoFormat.WEBM).get_contexts( + {"video_base64": "data:video/mp4;base64,video1base64"} + ) + + def test_inference_parameters_default_construction(): empty_inference_parameters = ChatCompletionInferenceParams() assert empty_inference_parameters.generate_kwargs == {} diff --git a/packages/data-designer-config/tests/config/utils/test_media_helpers.py b/packages/data-designer-config/tests/config/utils/test_media_helpers.py new file mode 100644 index 000000000..184f1c8d5 --- /dev/null +++ b/packages/data-designer-config/tests/config/utils/test_media_helpers.py @@ -0,0 +1,54 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import json + +import data_designer.lazy_heavy_imports as lazy +from data_designer.config.utils.media_helpers import ( + AudioFormat, + VideoFormat, + audio_format_from_mime_type, + audio_mime_type, + is_audio_url, + is_video_url, + normalize_media_context_values, + parse_base64_data_uri, + video_format_from_mime_type, + video_mime_type, +) + + +def test_normalize_media_context_values() -> None: + assert normalize_media_context_values("single") == ["single"] + assert normalize_media_context_values(["one", "two"]) == ["one", "two"] + assert normalize_media_context_values(json.dumps(["one", "two"])) == ["one", "two"] + assert normalize_media_context_values(json.dumps({"nested": "value"})) == [json.dumps({"nested": "value"})] + assert normalize_media_context_values(lazy.np.array(["one", "two"])) == ["one", "two"] + + +def test_parse_base64_data_uri() -> None: + assert parse_base64_data_uri("data:audio/mpeg;base64,abc123") == ("audio/mpeg", "abc123") + assert parse_base64_data_uri("abc123") is None + + +def test_audio_url_detection() -> None: + assert is_audio_url("https://example.com/audio.mp3") is True + assert is_audio_url("https://example.com/audio.wav?download=1") is True + assert is_audio_url("https://example.com/image.png") is False + assert is_audio_url(123) is False # type: ignore[arg-type] + + +def test_video_url_detection() -> None: + assert is_video_url("https://example.com/video.mp4") is True + assert is_video_url("https://example.com/video.webm?download=1") is True + assert is_video_url("https://example.com/audio.mp3") is False + assert is_video_url(123) is False # type: ignore[arg-type] + + +def test_media_format_mime_helpers() -> None: + assert audio_mime_type(AudioFormat.MP3) == "audio/mpeg" + assert audio_format_from_mime_type("audio/mpeg") == AudioFormat.MP3 + assert video_mime_type(VideoFormat.MP4) == "video/mp4" + assert video_format_from_mime_type("video/mp4") == 
VideoFormat.MP4 diff --git a/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/base.py b/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/base.py index ba432ce2c..fd002f9a6 100644 --- a/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/base.py +++ b/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/base.py @@ -279,9 +279,8 @@ def inference_parameters(self) -> BaseInferenceParams: def _build_multi_modal_context(self, record: dict) -> list[dict[str, Any]] | None: """Build multi-modal context from the config's multi_modal_context list. - Passes base_path to get_contexts() so that generated image file paths - (stored under base_dataset_path in create mode) can be resolved to base64 - before being sent to the model endpoint. + Passes base_path to get_contexts() so context types that support + artifact-relative resolution can use the dataset artifact directory. Args: record: The deserialized record containing column values. 
diff --git a/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/anthropic.py b/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/anthropic.py index 204b46677..1f4d5081d 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/anthropic.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/anthropic.py @@ -6,6 +6,7 @@ from typing import Any from data_designer.engine.models.clients.adapters.anthropic_translation import ( + UnsupportedAnthropicMediaBlockError, build_anthropic_payload, parse_anthropic_response, ) @@ -106,6 +107,16 @@ async def agenerate_image(self, request: ImageGenerationRequest) -> ImageGenerat def _build_payload_or_raise(self, request: ChatCompletionRequest) -> dict[str, Any]: try: return build_anthropic_payload(request) + except UnsupportedAnthropicMediaBlockError as exc: + raise ProviderError.unsupported_capability( + provider_name=self.provider_name, + model_name=request.model, + operation=f"{exc.modality}-context", + message=( + f"Provider {self.provider_name!r} does not support {exc.modality} context " + f"for model {request.model!r}." 
+ ), + ) from exc except ValueError as exc: raise ProviderError( kind=ProviderErrorKind.BAD_REQUEST, diff --git a/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/anthropic_translation.py b/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/anthropic_translation.py index 21a959f40..debe277b9 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/anthropic_translation.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/anthropic_translation.py @@ -20,6 +20,14 @@ _DATA_URI_RE = re.compile(r"^data:(?P<media_type>[^;]+);base64,(?P<data>.+)$") +class UnsupportedAnthropicMediaBlockError(ValueError): + """Raised when a canonical media block cannot be translated to Anthropic.""" + + def __init__(self, modality: str) -> None: + self.modality = modality + super().__init__(f"Anthropic adapter does not support {modality} context blocks.") + + def merge_system_parts(parts: list[str | list[dict[str, Any]]]) -> str | list[dict[str, Any]]: """Merge system parts into a single string or Anthropic block list. @@ -196,6 +204,11 @@ def translate_content_blocks(content: Any) -> list[dict[str, Any]]: if isinstance(block, dict) and block.get("type") == "image_url": translated.append(translate_image_url_block(block)) continue + if isinstance(block, dict) and block.get("type") == "image": + translated.append(translate_canonical_image_block(block)) + continue + if isinstance(block, dict) and block.get("type") in {"audio", "video"}: + raise UnsupportedAnthropicMediaBlockError(block["type"]) # Anthropic rejects empty text blocks — drop them.
if isinstance(block, dict) and block.get("type") == "text" and not block.get("text"): continue @@ -341,3 +354,21 @@ def translate_image_url_block(block: dict[str, Any]) -> dict[str, Any]: "type": "image", "source": {"type": "url", "url": url}, } + + +def translate_canonical_image_block(block: dict[str, Any]) -> dict[str, Any]: + source = block.get("source") + if not isinstance(source, dict): + raise ValueError(f"Canonical image block must include a source object, got: {block!r}") + + source_type = source.get("type") + if source_type == "url": + return {"type": "image", "source": {"type": "url", "url": source.get("url", "")}} + if source_type == "base64": + media_type = source.get("media_type") + data = source.get("data") + if not isinstance(media_type, str) or not isinstance(data, str): + raise ValueError(f"Canonical image base64 source must include media_type and data, got: {source!r}") + return {"type": "image", "source": {"type": "base64", "media_type": media_type, "data": data}} + + raise ValueError(f"Unsupported canonical image source type {source_type!r}") diff --git a/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/openai_compatible.py b/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/openai_compatible.py index 44ab1f1d5..81ac9a2a9 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/openai_compatible.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/openai_compatible.py @@ -61,13 +61,23 @@ def supports_image_generation(self) -> bool: def completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse: transport = TransportKwargs.from_request(request) - payload = {"model": request.model, "messages": request.messages, **transport.body} + messages = translate_openai_compatible_messages( + request.messages, + provider_name=self.provider_name, + model_name=request.model, + ) + payload = {"model": request.model, 
"messages": messages, **transport.body} response_json = self._post_sync(self._ROUTE_CHAT, payload, transport.headers, request.model, transport.timeout) return parse_chat_completion_response(response_json) async def acompletion(self, request: ChatCompletionRequest) -> ChatCompletionResponse: transport = TransportKwargs.from_request(request) - payload = {"model": request.model, "messages": request.messages, **transport.body} + messages = translate_openai_compatible_messages( + request.messages, + provider_name=self.provider_name, + model_name=request.model, + ) + payload = {"model": request.model, "messages": messages, **transport.body} response_json = await self._apost( self._ROUTE_CHAT, payload, transport.headers, request.model, transport.timeout ) @@ -101,7 +111,12 @@ def generate_image(self, request: ImageGenerationRequest) -> ImageGenerationResp transport = TransportKwargs.from_request(request, exclude=self._IMAGE_EXCLUDE) if request.messages is not None: route = self._ROUTE_CHAT - payload = {"model": request.model, "messages": request.messages, **transport.body} + messages = translate_openai_compatible_messages( + request.messages, + provider_name=self.provider_name, + model_name=request.model, + ) + payload = {"model": request.model, "messages": messages, **transport.body} else: route = self._ROUTE_IMAGE payload = {"model": request.model, "prompt": request.prompt, **transport.body} @@ -112,7 +127,12 @@ async def agenerate_image(self, request: ImageGenerationRequest) -> ImageGenerat transport = TransportKwargs.from_request(request, exclude=self._IMAGE_EXCLUDE) if request.messages is not None: route = self._ROUTE_CHAT - payload = {"model": request.model, "messages": request.messages, **transport.body} + messages = translate_openai_compatible_messages( + request.messages, + provider_name=self.provider_name, + model_name=request.model, + ) + payload = {"model": request.model, "messages": messages, **transport.body} else: route = self._ROUTE_IMAGE payload = 
{"model": request.model, "prompt": request.prompt, **transport.body} @@ -133,6 +153,115 @@ def _build_headers(self, extra_headers: dict[str, str]) -> dict[str, str]: # --------------------------------------------------------------------------- +def translate_openai_compatible_messages( + messages: list[dict[str, Any]], + *, + provider_name: str, + model_name: str, +) -> list[dict[str, Any]]: + """Translate canonical media blocks to OpenAI-compatible content blocks.""" + translated_messages: list[dict[str, Any]] = [] + for message in messages: + translated = dict(message) + if "content" in translated: + translated["content"] = translate_openai_compatible_content_blocks( + translated["content"], + provider_name=provider_name, + model_name=model_name, + ) + translated_messages.append(translated) + return translated_messages + + +def translate_openai_compatible_content_blocks( + content: Any, + *, + provider_name: str, + model_name: str, +) -> Any: + if not isinstance(content, list): + return content + + return [ + translate_openai_compatible_content_block( + block, + provider_name=provider_name, + model_name=model_name, + ) + for block in content + ] + + +def translate_openai_compatible_content_block( + block: Any, + *, + provider_name: str, + model_name: str, +) -> Any: + if not isinstance(block, dict): + return block + + block_type = block.get("type") + if block_type in {"image_url", "input_audio", "text"}: + return block + if block_type == "image": + return _translate_canonical_image_block(block) + if block_type == "audio": + return _translate_canonical_audio_block(block) + if block_type == "video": + return _translate_canonical_video_block(block) + return block + + +def _translate_canonical_image_block(block: dict[str, Any]) -> dict[str, Any]: + source = _get_media_source(block, modality="image") + source_type = source.get("type") + if source_type == "url": + return {"type": "image_url", "image_url": {"url": source.get("url", "")}} + if source_type == "base64": + 
media_type = source.get("media_type") + data = source.get("data") + if not isinstance(media_type, str) or not isinstance(data, str): + raise ValueError(f"Canonical image base64 source must include media_type and data, got: {source!r}") + return {"type": "image_url", "image_url": {"url": f"data:{media_type};base64,{data}"}} + raise ValueError(f"Unsupported canonical image source type {source_type!r}") + + +def _translate_canonical_audio_block(block: dict[str, Any]) -> dict[str, Any]: + source = _get_media_source(block, modality="audio") + source_type = source.get("type") + if source_type == "url": + return {"type": "audio_url", "audio_url": {"url": source.get("url", "")}} + if source_type == "base64": + data = source.get("data") + audio_format = source.get("format") + if not isinstance(data, str) or not isinstance(audio_format, str): + raise ValueError(f"Canonical audio base64 source must include data and format, got: {source!r}") + return {"type": "input_audio", "input_audio": {"data": data, "format": audio_format}} + raise ValueError(f"Unsupported canonical audio source type {source_type!r}") + + +def _translate_canonical_video_block(block: dict[str, Any]) -> dict[str, Any]: + source = _get_media_source(block, modality="video") + source_type = source.get("type") + if source_type == "url": + return {"type": "video_url", "video_url": {"url": source.get("url", "")}} + if source_type == "base64": + media_type = source.get("media_type") + data = source.get("data") + if not isinstance(media_type, str) or not isinstance(data, str): + raise ValueError(f"Canonical video base64 source must include media_type and data, got: {source!r}") + return {"type": "video_url", "video_url": {"url": f"data:{media_type};base64,{data}"}} + raise ValueError(f"Unsupported canonical video source type {source_type!r}") + + +def _get_media_source(block: dict[str, Any], *, modality: str) -> dict[str, Any]: + source = block.get("source") + if not isinstance(source, dict): + raise 
ValueError(f"Canonical {modality} block must include a source object, got: {block!r}") + return source + + def _parse_embedding_json(response_json: dict[str, Any]) -> EmbeddingResponse: data = response_json.get("data") or [] vectors = [extract_embedding_vector(item) for item in data] diff --git a/packages/data-designer-engine/src/data_designer/engine/models/utils.py b/packages/data-designer-engine/src/data_designer/engine/models/utils.py index f7183e83d..a4b51e4bd 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/utils.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/utils.py @@ -18,7 +18,7 @@ class ChatMessage: Attributes: role: The role of the message sender. One of 'user', 'assistant', 'system', or 'tool'. content: The message content. Can be a string or a list of content blocks - for multimodal messages (e.g., text + images). + for multimodal messages (e.g., text + image/audio/video context). reasoning_content: Optional reasoning/thinking content from the assistant, typically from extended thinking or chain-of-thought models. tool_calls: Optional list of tool calls requested by the assistant. 
diff --git a/packages/data-designer-engine/tests/engine/models/clients/test_anthropic.py b/packages/data-designer-engine/tests/engine/models/clients/test_anthropic.py index 1b1022d15..b0c71cc5a 100644 --- a/packages/data-designer-engine/tests/engine/models/clients/test_anthropic.py +++ b/packages/data-designer-engine/tests/engine/models/clients/test_anthropic.py @@ -555,6 +555,32 @@ def test_completion_preserves_non_image_content_blocks() -> None: assert content[1] == {"type": "custom_block", "data": "something"} +@pytest.mark.parametrize("modality", ["audio", "video"]) +def test_completion_rejects_audio_video_context_as_unsupported(modality: str) -> None: + sync_mock = make_mock_sync_client(_text_response()) + client = _make_client(sync_client=sync_mock) + + request = ChatCompletionRequest( + model=MODEL, + messages=[ + { + "role": "user", + "content": [ + {"type": modality, "source": {"type": "url", "url": "https://example.com/media"}}, + {"type": "text", "text": "Describe this."}, + ], + }, + ], + ) + + with pytest.raises(ProviderError) as exc_info: + client.completion(request) + + assert exc_info.value.kind == ProviderErrorKind.UNSUPPORTED_CAPABILITY + assert modality in exc_info.value.message + sync_mock.post.assert_not_called() + + def test_completion_passes_string_content_unchanged() -> None: sync_mock = make_mock_sync_client(_text_response()) client = _make_client(sync_client=sync_mock) diff --git a/packages/data-designer-engine/tests/engine/models/clients/test_anthropic_translation.py b/packages/data-designer-engine/tests/engine/models/clients/test_anthropic_translation.py index 2aad03ced..b60b00f67 100644 --- a/packages/data-designer-engine/tests/engine/models/clients/test_anthropic_translation.py +++ b/packages/data-designer-engine/tests/engine/models/clients/test_anthropic_translation.py @@ -9,6 +9,7 @@ from data_designer.engine.mcp.registry import MCPToolDefinition from data_designer.engine.models.clients.adapters.anthropic_translation import ( + 
UnsupportedAnthropicMediaBlockError, build_anthropic_payload, extract_system_content, merge_system_parts, @@ -342,6 +343,26 @@ def test_translate_content_blocks_converts_images_and_preserves_other_blocks() - ] +def test_translate_content_blocks_converts_canonical_images() -> None: + blocks = translate_content_blocks( + [ + {"type": "image", "source": {"type": "base64", "media_type": "image/png", "data": "iVBOR..."}}, + {"type": "text", "text": "Caption"}, + ] + ) + + assert blocks == [ + {"type": "image", "source": {"type": "base64", "media_type": "image/png", "data": "iVBOR..."}}, + {"type": "text", "text": "Caption"}, + ] + + +@pytest.mark.parametrize("modality", ["audio", "video"]) +def test_translate_content_blocks_rejects_unsupported_media(modality: str) -> None: + with pytest.raises(UnsupportedAnthropicMediaBlockError, match=f"{modality} context"): + translate_content_blocks([{"type": modality, "source": {"type": "url", "url": "https://example.com/media"}}]) + + def test_translate_content_blocks_rejects_malformed_image_url_block() -> None: with pytest.raises(TypeError, match="image_url block must contain a dict"): translate_content_blocks( diff --git a/packages/data-designer-engine/tests/engine/models/clients/test_openai_compatible.py b/packages/data-designer-engine/tests/engine/models/clients/test_openai_compatible.py index 3284d79b5..3e0ec82e9 100644 --- a/packages/data-designer-engine/tests/engine/models/clients/test_openai_compatible.py +++ b/packages/data-designer-engine/tests/engine/models/clients/test_openai_compatible.py @@ -317,6 +317,89 @@ def test_completion_forwards_multimodal_tool_result_content_unchanged() -> None: assert payload["messages"][0]["content"] == content +def test_completion_translates_canonical_image_blocks() -> None: + sync_mock = make_mock_sync_client(_chat_response()) + client = _make_client(sync_client=sync_mock) + + image_block = {"type": "image", "source": {"type": "base64", "media_type": "image/png", "data": "iVBOR..."}} + 
request = ChatCompletionRequest( + model=MODEL, + messages=[{"role": "user", "content": [image_block, {"type": "text", "text": "What is this?"}]}], + ) + client.completion(request) + + payload = sync_mock.post.call_args.kwargs["json"] + content = payload["messages"][0]["content"] + assert content[0] == {"type": "image_url", "image_url": {"url": "data:image/png;base64,iVBOR..."}} + + +def test_completion_translates_base64_audio_blocks() -> None: + sync_mock = make_mock_sync_client(_chat_response()) + client = _make_client(sync_client=sync_mock) + + audio_block = { + "type": "audio", + "source": {"type": "base64", "media_type": "audio/mpeg", "data": "abc123", "format": "mp3"}, + } + request = ChatCompletionRequest( + model=MODEL, + messages=[{"role": "user", "content": [audio_block, {"type": "text", "text": "Transcribe this."}]}], + ) + client.completion(request) + + payload = sync_mock.post.call_args.kwargs["json"] + content = payload["messages"][0]["content"] + assert content[0] == {"type": "input_audio", "input_audio": {"data": "abc123", "format": "mp3"}} + + +def test_completion_translates_audio_url_blocks() -> None: + sync_mock = make_mock_sync_client(_chat_response()) + client = _make_client(sync_client=sync_mock) + + audio_block = {"type": "audio", "source": {"type": "url", "url": "https://example.com/audio.mp3"}} + request = ChatCompletionRequest( + model=MODEL, + messages=[{"role": "user", "content": [audio_block, {"type": "text", "text": "Transcribe this."}]}], + ) + client.completion(request) + + payload = sync_mock.post.call_args.kwargs["json"] + content = payload["messages"][0]["content"] + assert content[0] == {"type": "audio_url", "audio_url": {"url": "https://example.com/audio.mp3"}} + + +def test_completion_translates_video_blocks() -> None: + sync_mock = make_mock_sync_client(_chat_response()) + client = _make_client(sync_client=sync_mock) + + video_block = {"type": "video", "source": {"type": "url", "url": "https://example.com/video.mp4"}} + 
request = ChatCompletionRequest( + model=MODEL, + messages=[{"role": "user", "content": [video_block, {"type": "text", "text": "Describe this."}]}], + ) + client.completion(request) + + payload = sync_mock.post.call_args.kwargs["json"] + content = payload["messages"][0]["content"] + assert content[0] == {"type": "video_url", "video_url": {"url": "https://example.com/video.mp4"}} + + +def test_completion_translates_base64_video_blocks() -> None: + sync_mock = make_mock_sync_client(_chat_response()) + client = _make_client(sync_client=sync_mock) + + video_block = {"type": "video", "source": {"type": "base64", "media_type": "video/mp4", "data": "abc123"}} + request = ChatCompletionRequest( + model=MODEL, + messages=[{"role": "user", "content": [video_block, {"type": "text", "text": "Describe this."}]}], + ) + client.completion(request) + + payload = sync_mock.post.call_args.kwargs["json"] + content = payload["messages"][0]["content"] + assert content[0] == {"type": "video_url", "video_url": {"url": "data:video/mp4;base64,abc123"}} + + # --- Auth headers --- diff --git a/packages/data-designer-engine/tests/engine/models/test_model_utils.py b/packages/data-designer-engine/tests/engine/models/test_model_utils.py index c2f07c068..65ca83afb 100644 --- a/packages/data-designer-engine/tests/engine/models/test_model_utils.py +++ b/packages/data-designer-engine/tests/engine/models/test_model_utils.py @@ -1,6 +1,8 @@ # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + from data_designer.engine.models.utils import ChatMessage, prompt_to_messages @@ -33,3 +35,18 @@ def test_chat_message_as_tool_accepts_multimodal_content() -> None: assert message.content == content assert message.to_dict()["content"] == content + + +def test_prompt_to_messages_preserves_mixed_media_context_order() -> None: + context = [ + {"type": "image", "source": {"type": "url", "url": "https://example.com/image.png"}}, + { + "type": "audio", + "source": {"type": "base64", "media_type": "audio/mpeg", "data": "abc123", "format": "mp3"}, + }, + {"type": "video", "source": {"type": "url", "url": "https://example.com/video.mp4"}}, + ] + + assert prompt_to_messages(user_prompt="describe", multi_modal_context=context) == [ + ChatMessage.as_user([*context, {"type": "text", "text": "describe"}]) + ] diff --git a/packages/data-designer-engine/tests/engine/test_validation.py b/packages/data-designer-engine/tests/engine/test_validation.py index 38af52947..efea50fa1 100644 --- a/packages/data-designer-engine/tests/engine/test_validation.py +++ b/packages/data-designer-engine/tests/engine/test_validation.py @@ -16,7 +16,7 @@ SeedDatasetColumnConfig, ValidationColumnConfig, ) -from data_designer.config.models import ImageContext, ModalityDataType +from data_designer.config.models import AudioContext, ImageContext, ModalityDataType from data_designer.config.processors import ( DropColumnsProcessorConfig, SchemaTransformProcessorConfig, @@ -248,6 +248,18 @@ def test_validate_column_config_with_multi_modal_context(): assert len(violations) == 0 +def test_validate_column_config_with_audio_multi_modal_context(): + column = LLMTextColumnConfig( + name="audio_description", + prompt="Describe the audio.", + model_alias=STUB_MODEL_ALIAS, + multi_modal_context=[AudioContext(column_name="audio_url", data_type=ModalityDataType.URL)], + ) + + violations = validate_prompt_templates([column], [column.name]) + 
assert len(violations) == 0 + + def test_validate_columns_not_all_dropped(): violations = validate_columns_not_all_dropped( [ From cda2a4fb11423cfb85d995e7df1ee38ab922f444 Mon Sep 17 00:00:00 2001 From: Nabin Mulepati Date: Mon, 18 May 2026 14:27:01 -0600 Subject: [PATCH 02/15] fix: harden media context review gaps Preserve extensionless HTTP(S) audio and video URLs as URL media, reject local path-looking audio/video context values, and reject provider-specific audio/video blocks in the Anthropic adapter. Refs #671 --- .../src/data_designer/config/models.py | 15 +++++-- .../config/utils/media_helpers.py | 44 +++++++++++++++---- .../tests/config/test_models.py | 18 ++++++++ .../tests/config/utils/test_media_helpers.py | 17 +++++++ .../clients/adapters/anthropic_translation.py | 12 ++++- .../src/data_designer/engine/models/facade.py | 5 ++- .../clients/test_anthropic_translation.py | 14 ++++++ .../models/clients/test_openai_compatible.py | 8 ++-- 8 files changed, 113 insertions(+), 20 deletions(-) diff --git a/packages/data-designer-config/src/data_designer/config/models.py b/packages/data-designer-config/src/data_designer/config/models.py index b361e185f..4a8abaa6f 100644 --- a/packages/data-designer-config/src/data_designer/config/models.py +++ b/packages/data-designer-config/src/data_designer/config/models.py @@ -36,8 +36,9 @@ VideoFormat, audio_format_from_mime_type, audio_mime_type, - is_audio_url, - is_video_url, + is_audio_path, + is_media_url, + is_video_path, normalize_media_context_values, parse_base64_data_uri, video_format_from_mime_type, @@ -203,7 +204,7 @@ def get_contexts(self, record: dict, *, base_path: str | None = None) -> list[di return [self._build_context(value) for value in normalize_media_context_values(record[self.column_name])] def _build_context(self, context_value: Any) -> dict[str, Any]: - if self.data_type == ModalityDataType.URL or (self.data_type is None and is_audio_url(context_value)): + if self.data_type == ModalityDataType.URL or 
(self.data_type is None and is_media_url(context_value)): source: dict[str, Any] = {"type": "url", "url": context_value} if self.audio_format is not None: source["format"] = self.audio_format.value @@ -234,6 +235,9 @@ def _resolve_base64_parts(self, context_value: Any) -> tuple[str, Any, AudioForm ) return media_type, data, audio_format + if is_audio_path(context_value): + raise ValueError("Local audio paths are not supported; provide an audio URL or base64 audio data") + if self.audio_format is None: raise ValueError("audio_format is required for base64 audio context values") return audio_mime_type(self.audio_format), context_value, self.audio_format @@ -261,7 +265,7 @@ def get_contexts(self, record: dict, *, base_path: str | None = None) -> list[di return [self._build_context(value) for value in normalize_media_context_values(record[self.column_name])] def _build_context(self, context_value: Any) -> dict[str, Any]: - if self.data_type == ModalityDataType.URL or (self.data_type is None and is_video_url(context_value)): + if self.data_type == ModalityDataType.URL or (self.data_type is None and is_media_url(context_value)): return {"type": "video", "source": {"type": "url", "url": context_value}} media_type, data = self._resolve_base64_parts(context_value) @@ -280,6 +284,9 @@ def _resolve_base64_parts(self, context_value: Any) -> tuple[str, Any]: ) return media_type, data + if is_video_path(context_value): + raise ValueError("Local video paths are not supported; provide a video URL or base64 video data") + if self.video_format is None: raise ValueError("video_format is required for base64 video context values") return video_mime_type(self.video_format), context_value diff --git a/packages/data-designer-config/src/data_designer/config/utils/media_helpers.py b/packages/data-designer-config/src/data_designer/config/utils/media_helpers.py index 0cb493b01..35435af91 100644 --- a/packages/data-designer-config/src/data_designer/config/utils/media_helpers.py +++ 
b/packages/data-designer-config/src/data_designer/config/utils/media_helpers.py @@ -27,8 +27,8 @@ class VideoFormat(StrEnum): WEBM = "webm" -SUPPORTED_AUDIO_EXTENSIONS = [f".{fmt.value.lower()}" for fmt in AudioFormat] -SUPPORTED_VIDEO_EXTENSIONS = [f".{fmt.value.lower()}" for fmt in VideoFormat] +SUPPORTED_AUDIO_EXTENSIONS = tuple(f".{fmt.value.lower()}" for fmt in AudioFormat) +SUPPORTED_VIDEO_EXTENSIONS = tuple(f".{fmt.value.lower()}" for fmt in VideoFormat) _DATA_URI_RE = re.compile(r"^data:(?P<media_type>[^;]+);base64,(?P<data>.+)$") @@ -42,10 +42,17 @@ class VideoFormat(StrEnum): VideoFormat.WEBM: "video/webm", } _AUDIO_MIME_TYPE_TO_FORMAT: dict[str, AudioFormat] = { - mime_type: audio_format for audio_format, mime_type in _AUDIO_FORMAT_TO_MIME_TYPE.items() + "audio/mpeg": AudioFormat.MP3, + "audio/mp3": AudioFormat.MP3, + "audio/wav": AudioFormat.WAV, + "audio/wave": AudioFormat.WAV, + "audio/x-wav": AudioFormat.WAV, + "audio/vnd.wave": AudioFormat.WAV, } _VIDEO_MIME_TYPE_TO_FORMAT: dict[str, VideoFormat] = { - mime_type: video_format for video_format, mime_type in _VIDEO_FORMAT_TO_MIME_TYPE.items() + "video/mp4": VideoFormat.MP4, + "video/quicktime": VideoFormat.MOV, + "video/webm": VideoFormat.WEBM, } @@ -79,14 +86,29 @@ def parse_base64_data_uri(value: str) -> tuple[str, str] | None: return match.group("media_type"), match.group("data") +def is_media_url(value: str) -> bool: + """Return whether a value is an HTTP(S) media URL.""" + return isinstance(value, str) and value.startswith(("http://", "https://")) + + def is_audio_url(value: str) -> bool: """Return whether a value looks like an audio URL.""" - return _is_media_url(value, SUPPORTED_AUDIO_EXTENSIONS) + return is_media_url(value) and _has_media_extension(value, SUPPORTED_AUDIO_EXTENSIONS) def is_video_url(value: str) -> bool: """Return whether a value looks like a video URL.""" - return _is_media_url(value, SUPPORTED_VIDEO_EXTENSIONS) + return is_media_url(value) and _has_media_extension(value,
SUPPORTED_VIDEO_EXTENSIONS) + + +def is_audio_path(value: str) -> bool: + """Return whether a value looks like a local audio path.""" + return _has_path_extension(value, SUPPORTED_AUDIO_EXTENSIONS) + + +def is_video_path(value: str) -> bool: + """Return whether a value looks like a local video path.""" + return _has_path_extension(value, SUPPORTED_VIDEO_EXTENSIONS) def audio_mime_type(audio_format: AudioFormat) -> str: @@ -109,7 +131,13 @@ def video_format_from_mime_type(media_type: str) -> VideoFormat | None: return _VIDEO_MIME_TYPE_TO_FORMAT.get(media_type.lower()) -def _is_media_url(value: str, supported_extensions: list[str]) -> bool: +def _has_media_extension(value: str, supported_extensions: tuple[str, ...]) -> bool: + if not isinstance(value, str): + return False + return any(ext in value.lower() for ext in supported_extensions) + + +def _has_path_extension(value: str, supported_extensions: tuple[str, ...]) -> bool: if not isinstance(value, str): return False - return value.startswith(("http://", "https://")) and any(ext in value.lower() for ext in supported_extensions) + return not is_media_url(value) and value.lower().endswith(supported_extensions) diff --git a/packages/data-designer-config/tests/config/test_models.py b/packages/data-designer-config/tests/config/test_models.py index 4cb62b649..dbdf9cfc9 100644 --- a/packages/data-designer-config/tests/config/test_models.py +++ b/packages/data-designer-config/tests/config/test_models.py @@ -304,6 +304,10 @@ def test_audio_context_auto_detect_url_and_data_uri() -> None: {"type": "audio", "source": {"type": "url", "url": "https://example.com/audio.mp3"}} ] + assert AudioContext(column_name="audio_col").get_contexts({"audio_col": "https://example.com/download?id=123"}) == [ + {"type": "audio", "source": {"type": "url", "url": "https://example.com/download?id=123"}} + ] + assert AudioContext(column_name="audio_col").get_contexts({"audio_col": "data:audio/mpeg;base64,audio1base64"}) == [ { "type": "audio", @@ 
-329,6 +333,11 @@ def test_audio_context_validate_audio_format() -> None: {"audio_base64": "data:audio/mpeg;base64,audio1base64"} ) + with pytest.raises(ValueError, match="Local audio paths are not supported"): + AudioContext(column_name="audio_base64", audio_format=AudioFormat.MP3).get_contexts( + {"audio_base64": "screen_recording.mp3"} + ) + def test_video_context_get_contexts_single_string() -> None: video_context = VideoContext( @@ -375,6 +384,10 @@ def test_video_context_auto_detect_url_and_data_uri() -> None: {"type": "video", "source": {"type": "url", "url": "https://example.com/video.mp4"}} ] + assert VideoContext(column_name="video_col").get_contexts({"video_col": "https://example.com/download?id=123"}) == [ + {"type": "video", "source": {"type": "url", "url": "https://example.com/download?id=123"}} + ] + assert VideoContext(column_name="video_col").get_contexts({"video_col": "data:video/mp4;base64,video1base64"}) == [ {"type": "video", "source": {"type": "base64", "media_type": "video/mp4", "data": "video1base64"}} ] @@ -392,6 +405,11 @@ def test_video_context_validate_video_format() -> None: {"video_base64": "data:video/mp4;base64,video1base64"} ) + with pytest.raises(ValueError, match="Local video paths are not supported"): + VideoContext(column_name="video_base64", video_format=VideoFormat.MP4).get_contexts( + {"video_base64": "screen_recording.mp4"} + ) + def test_inference_parameters_default_construction(): empty_inference_parameters = ChatCompletionInferenceParams() diff --git a/packages/data-designer-config/tests/config/utils/test_media_helpers.py b/packages/data-designer-config/tests/config/utils/test_media_helpers.py index 184f1c8d5..fbb32011d 100644 --- a/packages/data-designer-config/tests/config/utils/test_media_helpers.py +++ b/packages/data-designer-config/tests/config/utils/test_media_helpers.py @@ -11,7 +11,10 @@ VideoFormat, audio_format_from_mime_type, audio_mime_type, + is_audio_path, is_audio_url, + is_media_url, + is_video_path, 
is_video_url, normalize_media_context_values, parse_base64_data_uri, @@ -34,6 +37,7 @@ def test_parse_base64_data_uri() -> None: def test_audio_url_detection() -> None: + assert is_media_url("https://example.com/download?id=123") is True assert is_audio_url("https://example.com/audio.mp3") is True assert is_audio_url("https://example.com/audio.wav?download=1") is True assert is_audio_url("https://example.com/image.png") is False @@ -41,14 +45,27 @@ def test_audio_url_detection() -> None: def test_video_url_detection() -> None: + assert is_media_url("https://example.com/download?id=123") is True assert is_video_url("https://example.com/video.mp4") is True assert is_video_url("https://example.com/video.webm?download=1") is True assert is_video_url("https://example.com/audio.mp3") is False assert is_video_url(123) is False # type: ignore[arg-type] +def test_local_media_path_detection() -> None: + assert is_audio_path("screen_recording.mp3") is True + assert is_audio_path("nested/screen_recording.wav") is True + assert is_audio_path("https://example.com/audio.mp3") is False + assert is_video_path("screen_recording.mp4") is True + assert is_video_path("nested/screen_recording.webm") is True + assert is_video_path("https://example.com/video.mp4") is False + + def test_media_format_mime_helpers() -> None: assert audio_mime_type(AudioFormat.MP3) == "audio/mpeg" assert audio_format_from_mime_type("audio/mpeg") == AudioFormat.MP3 + assert audio_format_from_mime_type("audio/mp3") == AudioFormat.MP3 + assert audio_format_from_mime_type("audio/x-wav") == AudioFormat.WAV assert video_mime_type(VideoFormat.MP4) == "video/mp4" assert video_format_from_mime_type("video/mp4") == VideoFormat.MP4 + assert video_format_from_mime_type("VIDEO/MP4") == VideoFormat.MP4 diff --git a/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/anthropic_translation.py b/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/anthropic_translation.py 
index debe277b9..d2e8b66ca 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/anthropic_translation.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/anthropic_translation.py @@ -18,6 +18,14 @@ _DEFAULT_MAX_TOKENS = 4096 _DATA_URI_RE = re.compile(r"^data:(?P[^;]+);base64,(?P.+)$") +_UNSUPPORTED_MEDIA_BLOCK_MODALITIES: dict[str, str] = { + "audio": "audio", + "audio_url": "audio", + "input_audio": "audio", + "video": "video", + "video_url": "video", + "input_video": "video", +} class UnsupportedAnthropicMediaBlockError(ValueError): @@ -207,8 +215,8 @@ def translate_content_blocks(content: Any) -> list[dict[str, Any]]: if isinstance(block, dict) and block.get("type") == "image": translated.append(translate_canonical_image_block(block)) continue - if isinstance(block, dict) and block.get("type") in {"audio", "video"}: - raise UnsupportedAnthropicMediaBlockError(block["type"]) + if isinstance(block, dict) and block.get("type") in _UNSUPPORTED_MEDIA_BLOCK_MODALITIES: + raise UnsupportedAnthropicMediaBlockError(_UNSUPPORTED_MEDIA_BLOCK_MODALITIES[block["type"]]) # Anthropic rejects empty text blocks β€” drop them. if isinstance(block, dict) and block.get("type") == "text" and not block.get("text"): continue diff --git a/packages/data-designer-engine/src/data_designer/engine/models/facade.py b/packages/data-designer-engine/src/data_designer/engine/models/facade.py index 81a935282..39b667e22 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/facade.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/facade.py @@ -309,6 +309,7 @@ def generate( prompt. parser (func(str) -> Any): A function applied to the LLM response which processes an LLM response into some output object. Default: identity function. + multi_modal_context: Optional list of image, audio, or video context blocks. tool_alias (str | None): Optional tool configuration alias. 
When provided, the model may call permitted tools from the configured MCP providers. The alias must reference a ToolConfig registered in the MCPRegistry. @@ -627,7 +628,7 @@ def generate_image( Args: prompt: The prompt for image generation - multi_modal_context: Optional list of image contexts for multi-modal generation. + multi_modal_context: Optional list of image, audio, or video contexts for multi-modal generation. Only used with autoregressive models via chat completions API. skip_usage_tracking: Whether to skip usage tracking **kwargs: Additional arguments to pass to the model @@ -686,7 +687,7 @@ async def agenerate_image( Args: prompt: The prompt for image generation - multi_modal_context: Optional list of image contexts for multi-modal generation. + multi_modal_context: Optional list of image, audio, or video contexts for multi-modal generation. Only used with autoregressive models via chat completions API. skip_usage_tracking: Whether to skip usage tracking **kwargs: Additional arguments to pass to the model diff --git a/packages/data-designer-engine/tests/engine/models/clients/test_anthropic_translation.py b/packages/data-designer-engine/tests/engine/models/clients/test_anthropic_translation.py index b60b00f67..ecab9cc90 100644 --- a/packages/data-designer-engine/tests/engine/models/clients/test_anthropic_translation.py +++ b/packages/data-designer-engine/tests/engine/models/clients/test_anthropic_translation.py @@ -363,6 +363,20 @@ def test_translate_content_blocks_rejects_unsupported_media(modality: str) -> No translate_content_blocks([{"type": modality, "source": {"type": "url", "url": "https://example.com/media"}}]) +@pytest.mark.parametrize( + ("block_type", "modality"), + [ + ("audio_url", "audio"), + ("input_audio", "audio"), + ("video_url", "video"), + ("input_video", "video"), + ], +) +def test_translate_content_blocks_rejects_provider_specific_media(block_type: str, modality: str) -> None: + with 
pytest.raises(UnsupportedAnthropicMediaBlockError, match=f"{modality} context"): + translate_content_blocks([{"type": block_type}]) + + def test_translate_content_blocks_rejects_malformed_image_url_block() -> None: with pytest.raises(TypeError, match="image_url block must contain a dict"): translate_content_blocks( diff --git a/packages/data-designer-engine/tests/engine/models/clients/test_openai_compatible.py b/packages/data-designer-engine/tests/engine/models/clients/test_openai_compatible.py index 3e0ec82e9..47af95e9c 100644 --- a/packages/data-designer-engine/tests/engine/models/clients/test_openai_compatible.py +++ b/packages/data-designer-engine/tests/engine/models/clients/test_openai_compatible.py @@ -356,7 +356,7 @@ def test_completion_translates_audio_url_blocks() -> None: sync_mock = make_mock_sync_client(_chat_response()) client = _make_client(sync_client=sync_mock) - audio_block = {"type": "audio", "source": {"type": "url", "url": "https://example.com/audio.mp3"}} + audio_block = {"type": "audio", "source": {"type": "url", "url": "https://example.com/download?id=123"}} request = ChatCompletionRequest( model=MODEL, messages=[{"role": "user", "content": [audio_block, {"type": "text", "text": "Transcribe this."}]}], @@ -365,14 +365,14 @@ def test_completion_translates_audio_url_blocks() -> None: payload = sync_mock.post.call_args.kwargs["json"] content = payload["messages"][0]["content"] - assert content[0] == {"type": "audio_url", "audio_url": {"url": "https://example.com/audio.mp3"}} + assert content[0] == {"type": "audio_url", "audio_url": {"url": "https://example.com/download?id=123"}} def test_completion_translates_video_blocks() -> None: sync_mock = make_mock_sync_client(_chat_response()) client = _make_client(sync_client=sync_mock) - video_block = {"type": "video", "source": {"type": "url", "url": "https://example.com/video.mp4"}} + video_block = {"type": "video", "source": {"type": "url", "url": "https://example.com/download?id=123"}} request = 
ChatCompletionRequest( model=MODEL, messages=[{"role": "user", "content": [video_block, {"type": "text", "text": "Describe this."}]}], @@ -381,7 +381,7 @@ def test_completion_translates_video_blocks() -> None: payload = sync_mock.post.call_args.kwargs["json"] content = payload["messages"][0]["content"] - assert content[0] == {"type": "video_url", "video_url": {"url": "https://example.com/video.mp4"}} + assert content[0] == {"type": "video_url", "video_url": {"url": "https://example.com/download?id=123"}} def test_completion_translates_base64_video_blocks() -> None: From 07d26dcbb339578d7a2c83e4114ff1987129d7db Mon Sep 17 00:00:00 2001 From: Nabin Mulepati Date: Mon, 18 May 2026 14:48:43 -0600 Subject: [PATCH 03/15] test: add audio video context smoke notebook Add a Jupytext source notebook and generated Colab artifact that exercise audio/video context URL, base64, local path rejection, OpenAI-compatible payload translation, and Anthropic unsupported-media handling. Refs #671 --- .../7-audio-video-context-smoke-test.ipynb | 405 ++++++++++++++++++ .../7-audio-video-context-smoke-test.py | 250 +++++++++++ 2 files changed, 655 insertions(+) create mode 100644 docs/colab_notebooks/7-audio-video-context-smoke-test.ipynb create mode 100644 docs/notebook_source/7-audio-video-context-smoke-test.py diff --git a/docs/colab_notebooks/7-audio-video-context-smoke-test.ipynb b/docs/colab_notebooks/7-audio-video-context-smoke-test.ipynb new file mode 100644 index 000000000..a3f12590a --- /dev/null +++ b/docs/colab_notebooks/7-audio-video-context-smoke-test.ipynb @@ -0,0 +1,405 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "46e1962e", + "metadata": { + "nemo_colab_inject": true + }, + "source": [ + "\"Open" + ] + }, + { + "cell_type": "markdown", + "id": "140da62d", + "metadata": {}, + "source": [ + "# Audio and Video Context Smoke Test" + ] + }, + { + "cell_type": "markdown", + "id": "b0e37630", + "metadata": {}, + "source": [ + "This notebook verifies the audio and 
video context flow without calling a real model endpoint.\n", + "\n", + "It checks that:\n", + "\n", + "- `AudioContext` and `VideoContext` produce canonical media blocks.\n", + "- HTTP(S) URLs stay as URLs, including URLs without file extensions.\n", + "- Base64 media values keep the required media metadata.\n", + "- Local path-looking values are rejected instead of being resolved in the config layer.\n", + "- The OpenAI-compatible adapter forwards URL media as URL blocks in the endpoint payload." + ] + }, + { + "cell_type": "markdown", + "id": "27fda0b6", + "metadata": { + "nemo_colab_inject": true + }, + "source": [ + "### ⚑ Colab Setup\n", + "\n", + "Run the cells below to install the dependencies and set up the API key. If you don't have an API key, you can generate one from [build.nvidia.com](https://build.nvidia.com).\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3438b6b3", + "metadata": { + "nemo_colab_inject": true + }, + "outputs": [], + "source": [ + "%%capture\n", + "!pip install -U data-designer" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "af547269", + "metadata": { + "nemo_colab_inject": true + }, + "outputs": [], + "source": [ + "import getpass\n", + "import os\n", + "\n", + "from google.colab import userdata\n", + "\n", + "try:\n", + " os.environ[\"NVIDIA_API_KEY\"] = userdata.get(\"NVIDIA_API_KEY\")\n", + "except userdata.SecretNotFoundError:\n", + " os.environ[\"NVIDIA_API_KEY\"] = getpass.getpass(\"Enter your NVIDIA API key: \")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "08f0fe7e", + "metadata": {}, + "outputs": [], + "source": [ + "from __future__ import annotations\n", + "\n", + "import json\n", + "from collections.abc import Callable\n", + "from typing import Any\n", + "from unittest.mock import MagicMock\n", + "\n", + "import data_designer.config as dd\n", + "from data_designer.engine.models.clients.adapters.anthropic_translation import (\n", + " 
UnsupportedAnthropicMediaBlockError,\n", + " translate_content_blocks,\n", + ")\n", + "from data_designer.engine.models.clients.adapters.http_model_client import ClientConcurrencyMode\n", + "from data_designer.engine.models.clients.adapters.openai_compatible import OpenAICompatibleClient\n", + "from data_designer.engine.models.clients.types import ChatCompletionRequest" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "85337ecd", + "metadata": {}, + "outputs": [], + "source": [ + "AUDIO_URL = \"https://example.com/download?id=audio-123\"\n", + "VIDEO_URL = \"https://example.com/download?id=video-456\"\n", + "AUDIO_DATA_URI = \"data:audio/mpeg;base64,YXVkaW8=\"\n", + "VIDEO_DATA_URI = \"data:video/mp4;base64,dmlkZW8=\"\n", + "\n", + "record = {\n", + " \"audio_url\": AUDIO_URL,\n", + " \"video_url\": VIDEO_URL,\n", + " \"audio_data_uri\": AUDIO_DATA_URI,\n", + " \"video_data_uri\": VIDEO_DATA_URI,\n", + "}" + ] + }, + { + "cell_type": "markdown", + "id": "c7396746", + "metadata": {}, + "source": [ + "## Config blocks keep URLs and base64 media distinct" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "07c16ecd", + "metadata": {}, + "outputs": [], + "source": [ + "url_context_blocks = [\n", + " *dd.AudioContext(column_name=\"audio_url\").get_contexts(record),\n", + " *dd.VideoContext(column_name=\"video_url\").get_contexts(record),\n", + "]\n", + "\n", + "assert url_context_blocks == [\n", + " {\"type\": \"audio\", \"source\": {\"type\": \"url\", \"url\": AUDIO_URL}},\n", + " {\"type\": \"video\", \"source\": {\"type\": \"url\", \"url\": VIDEO_URL}},\n", + "]\n", + "\n", + "url_context_blocks" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d5f4936d", + "metadata": {}, + "outputs": [], + "source": [ + "base64_context_blocks = [\n", + " *dd.AudioContext(column_name=\"audio_data_uri\").get_contexts(record),\n", + " *dd.VideoContext(column_name=\"video_data_uri\").get_contexts(record),\n", + "]\n", + 
"\n", + "assert base64_context_blocks == [\n", + " {\n", + " \"type\": \"audio\",\n", + " \"source\": {\n", + " \"type\": \"base64\",\n", + " \"media_type\": \"audio/mpeg\",\n", + " \"data\": \"YXVkaW8=\",\n", + " \"format\": \"mp3\",\n", + " },\n", + " },\n", + " {\"type\": \"video\", \"source\": {\"type\": \"base64\", \"media_type\": \"video/mp4\", \"data\": \"dmlkZW8=\"}},\n", + "]\n", + "\n", + "base64_context_blocks" + ] + }, + { + "cell_type": "markdown", + "id": "afee0a9a", + "metadata": { + "lines_to_next_cell": 2 + }, + "source": [ + "## Local file names are not resolved by audio/video context" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e8c792fc", + "metadata": {}, + "outputs": [], + "source": [ + "def assert_raises_message(callback: Callable[[], Any], expected_message: str) -> str:\n", + " try:\n", + " callback()\n", + " except ValueError as exc:\n", + " message = str(exc)\n", + " assert expected_message in message\n", + " return message\n", + " raise AssertionError(f\"Expected ValueError containing {expected_message!r}\")\n", + "\n", + "\n", + "audio_path_error = assert_raises_message(\n", + " lambda: dd.AudioContext(column_name=\"audio_path\", audio_format=dd.AudioFormat.MP3).get_contexts(\n", + " {\"audio_path\": \"screen_recording.mp3\"}\n", + " ),\n", + " \"Local audio paths are not supported\",\n", + ")\n", + "video_path_error = assert_raises_message(\n", + " lambda: dd.VideoContext(column_name=\"video_path\", video_format=dd.VideoFormat.MP4).get_contexts(\n", + " {\"video_path\": \"screen_recording.mp4\"}\n", + " ),\n", + " \"Local video paths are not supported\",\n", + ")\n", + "\n", + "audio_path_error, video_path_error" + ] + }, + { + "cell_type": "markdown", + "id": "d157d387", + "metadata": {}, + "source": [ + "## Column config round-trips mixed media context" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ebc939de", + "metadata": {}, + "outputs": [], + "source": [ + "column_config = 
dd.LLMTextColumnConfig(\n", + " name=\"media_summary\",\n", + " prompt=\"Summarize the audio and video context.\",\n", + " model_alias=\"mock-multimodal-model\",\n", + " multi_modal_context=[\n", + " dd.AudioContext(column_name=\"audio_url\"),\n", + " dd.VideoContext(column_name=\"video_url\"),\n", + " ],\n", + ")\n", + "\n", + "round_tripped = dd.LLMTextColumnConfig(**column_config.model_dump())\n", + "\n", + "assert [type(context).__name__ for context in round_tripped.multi_modal_context or []] == [\n", + " \"AudioContext\",\n", + " \"VideoContext\",\n", + "]\n", + "assert set(round_tripped.required_columns) == {\"audio_url\", \"video_url\"}\n", + "\n", + "round_tripped" + ] + }, + { + "cell_type": "markdown", + "id": "928d988c", + "metadata": { + "lines_to_next_cell": 2 + }, + "source": [ + "## OpenAI-compatible payloads send URL media as URLs" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6260e981", + "metadata": {}, + "outputs": [], + "source": [ + "def mock_httpx_response(json_data: dict[str, Any], status_code: int = 200) -> MagicMock:\n", + " response = MagicMock()\n", + " response.status_code = status_code\n", + " response.json.return_value = json_data\n", + " response.text = json.dumps(json_data)\n", + " response.headers = {}\n", + " return response\n", + "\n", + "\n", + "def make_mock_sync_client(response_json: dict[str, Any]) -> MagicMock:\n", + " client = MagicMock()\n", + " client.post = MagicMock(return_value=mock_httpx_response(response_json))\n", + " return client\n", + "\n", + "\n", + "def chat_response(content: str = \"ok\") -> dict[str, Any]:\n", + " return {\n", + " \"choices\": [{\"index\": 0, \"message\": {\"role\": \"assistant\", \"content\": content}, \"finish_reason\": \"stop\"}],\n", + " \"usage\": {\"prompt_tokens\": 3, \"completion_tokens\": 1, \"total_tokens\": 4},\n", + " }\n", + "\n", + "\n", + "sync_client = make_mock_sync_client(chat_response())\n", + "client = OpenAICompatibleClient(\n", + " 
provider_name=\"smoke-provider\",\n", + " endpoint=\"https://api.example.com/v1\",\n", + " api_key=\"not-used\",\n", + " concurrency_mode=ClientConcurrencyMode.SYNC,\n", + " sync_client=sync_client,\n", + ")\n", + "\n", + "client.completion(\n", + " ChatCompletionRequest(\n", + " model=\"mock-multimodal-model\",\n", + " messages=[\n", + " {\n", + " \"role\": \"user\",\n", + " \"content\": [*url_context_blocks, {\"type\": \"text\", \"text\": \"Summarize the media.\"}],\n", + " }\n", + " ],\n", + " )\n", + ")\n", + "\n", + "url_payload_blocks = sync_client.post.call_args.kwargs[\"json\"][\"messages\"][0][\"content\"]\n", + "\n", + "assert url_payload_blocks[:2] == [\n", + " {\"type\": \"audio_url\", \"audio_url\": {\"url\": AUDIO_URL}},\n", + " {\"type\": \"video_url\", \"video_url\": {\"url\": VIDEO_URL}},\n", + "]\n", + "\n", + "url_payload_blocks" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0e6d78eb", + "metadata": {}, + "outputs": [], + "source": [ + "sync_client = make_mock_sync_client(chat_response())\n", + "client = OpenAICompatibleClient(\n", + " provider_name=\"smoke-provider\",\n", + " endpoint=\"https://api.example.com/v1\",\n", + " api_key=\"not-used\",\n", + " concurrency_mode=ClientConcurrencyMode.SYNC,\n", + " sync_client=sync_client,\n", + ")\n", + "\n", + "client.completion(\n", + " ChatCompletionRequest(\n", + " model=\"mock-multimodal-model\",\n", + " messages=[\n", + " {\n", + " \"role\": \"user\",\n", + " \"content\": [*base64_context_blocks, {\"type\": \"text\", \"text\": \"Summarize the media.\"}],\n", + " }\n", + " ],\n", + " )\n", + ")\n", + "\n", + "base64_payload_blocks = sync_client.post.call_args.kwargs[\"json\"][\"messages\"][0][\"content\"]\n", + "\n", + "assert base64_payload_blocks[:2] == [\n", + " {\"type\": \"input_audio\", \"input_audio\": {\"data\": \"YXVkaW8=\", \"format\": \"mp3\"}},\n", + " {\"type\": \"video_url\", \"video_url\": {\"url\": \"data:video/mp4;base64,dmlkZW8=\"}},\n", + "]\n", + "\n", + 
"base64_payload_blocks" + ] + }, + { + "cell_type": "markdown", + "id": "6ebd8210", + "metadata": {}, + "source": [ + "## Anthropic rejects unsupported audio/video context before an HTTP call" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "137027ab", + "metadata": {}, + "outputs": [], + "source": [ + "for block in url_context_blocks:\n", + " try:\n", + " translate_content_blocks([block])\n", + " except UnsupportedAnthropicMediaBlockError as exc:\n", + " assert exc.modality in {\"audio\", \"video\"}\n", + " else:\n", + " raise AssertionError(f\"Expected Anthropic to reject {block['type']} context\")\n", + "\n", + "\"All audio/video context smoke checks passed.\"" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/notebook_source/7-audio-video-context-smoke-test.py b/docs/notebook_source/7-audio-video-context-smoke-test.py new file mode 100644 index 000000000..4f468b86d --- /dev/null +++ b/docs/notebook_source/7-audio-video-context-smoke-test.py @@ -0,0 +1,250 @@ +# --- +# jupyter: +# jupytext: +# text_representation: +# extension: .py +# format_name: percent +# format_version: '1.3' +# jupytext_version: 1.18.1 +# kernelspec: +# display_name: .venv +# language: python +# name: python3 +# --- + +# %% [markdown] +# # Audio and Video Context Smoke Test + +# %% [markdown] +# This notebook verifies the audio and video context flow without calling a real model endpoint. +# +# It checks that: +# +# - `AudioContext` and `VideoContext` produce canonical media blocks. +# - HTTP(S) URLs stay as URLs, including URLs without file extensions. +# - Base64 media values keep the required media metadata. +# - Local path-looking values are rejected instead of being resolved in the config layer. +# - The OpenAI-compatible adapter forwards URL media as URL blocks in the endpoint payload. 
+ +# %% +from __future__ import annotations + +import json +from collections.abc import Callable +from typing import Any +from unittest.mock import MagicMock + +import data_designer.config as dd +from data_designer.engine.models.clients.adapters.anthropic_translation import ( + UnsupportedAnthropicMediaBlockError, + translate_content_blocks, +) +from data_designer.engine.models.clients.adapters.http_model_client import ClientConcurrencyMode +from data_designer.engine.models.clients.adapters.openai_compatible import OpenAICompatibleClient +from data_designer.engine.models.clients.types import ChatCompletionRequest + +# %% +AUDIO_URL = "https://example.com/download?id=audio-123" +VIDEO_URL = "https://example.com/download?id=video-456" +AUDIO_DATA_URI = "data:audio/mpeg;base64,YXVkaW8=" +VIDEO_DATA_URI = "data:video/mp4;base64,dmlkZW8=" + +record = { + "audio_url": AUDIO_URL, + "video_url": VIDEO_URL, + "audio_data_uri": AUDIO_DATA_URI, + "video_data_uri": VIDEO_DATA_URI, +} + +# %% [markdown] +# ## Config blocks keep URLs and base64 media distinct + +# %% +url_context_blocks = [ + *dd.AudioContext(column_name="audio_url").get_contexts(record), + *dd.VideoContext(column_name="video_url").get_contexts(record), +] + +assert url_context_blocks == [ + {"type": "audio", "source": {"type": "url", "url": AUDIO_URL}}, + {"type": "video", "source": {"type": "url", "url": VIDEO_URL}}, +] + +url_context_blocks + +# %% +base64_context_blocks = [ + *dd.AudioContext(column_name="audio_data_uri").get_contexts(record), + *dd.VideoContext(column_name="video_data_uri").get_contexts(record), +] + +assert base64_context_blocks == [ + { + "type": "audio", + "source": { + "type": "base64", + "media_type": "audio/mpeg", + "data": "YXVkaW8=", + "format": "mp3", + }, + }, + {"type": "video", "source": {"type": "base64", "media_type": "video/mp4", "data": "dmlkZW8="}}, +] + +base64_context_blocks + +# %% [markdown] +# ## Local file names are not resolved by audio/video context + + +# %% +def 
assert_raises_message(callback: Callable[[], Any], expected_message: str) -> str: + try: + callback() + except ValueError as exc: + message = str(exc) + assert expected_message in message + return message + raise AssertionError(f"Expected ValueError containing {expected_message!r}") + + +audio_path_error = assert_raises_message( + lambda: dd.AudioContext(column_name="audio_path", audio_format=dd.AudioFormat.MP3).get_contexts( + {"audio_path": "screen_recording.mp3"} + ), + "Local audio paths are not supported", +) +video_path_error = assert_raises_message( + lambda: dd.VideoContext(column_name="video_path", video_format=dd.VideoFormat.MP4).get_contexts( + {"video_path": "screen_recording.mp4"} + ), + "Local video paths are not supported", +) + +audio_path_error, video_path_error + +# %% [markdown] +# ## Column config round-trips mixed media context + +# %% +column_config = dd.LLMTextColumnConfig( + name="media_summary", + prompt="Summarize the audio and video context.", + model_alias="mock-multimodal-model", + multi_modal_context=[ + dd.AudioContext(column_name="audio_url"), + dd.VideoContext(column_name="video_url"), + ], +) + +round_tripped = dd.LLMTextColumnConfig(**column_config.model_dump()) + +assert [type(context).__name__ for context in round_tripped.multi_modal_context or []] == [ + "AudioContext", + "VideoContext", +] +assert set(round_tripped.required_columns) == {"audio_url", "video_url"} + +round_tripped + +# %% [markdown] +# ## OpenAI-compatible payloads send URL media as URLs + + +# %% +def mock_httpx_response(json_data: dict[str, Any], status_code: int = 200) -> MagicMock: + response = MagicMock() + response.status_code = status_code + response.json.return_value = json_data + response.text = json.dumps(json_data) + response.headers = {} + return response + + +def make_mock_sync_client(response_json: dict[str, Any]) -> MagicMock: + client = MagicMock() + client.post = MagicMock(return_value=mock_httpx_response(response_json)) + return client + + +def 
chat_response(content: str = "ok") -> dict[str, Any]: + return { + "choices": [{"index": 0, "message": {"role": "assistant", "content": content}, "finish_reason": "stop"}], + "usage": {"prompt_tokens": 3, "completion_tokens": 1, "total_tokens": 4}, + } + + +sync_client = make_mock_sync_client(chat_response()) +client = OpenAICompatibleClient( + provider_name="smoke-provider", + endpoint="https://api.example.com/v1", + api_key="not-used", + concurrency_mode=ClientConcurrencyMode.SYNC, + sync_client=sync_client, +) + +client.completion( + ChatCompletionRequest( + model="mock-multimodal-model", + messages=[ + { + "role": "user", + "content": [*url_context_blocks, {"type": "text", "text": "Summarize the media."}], + } + ], + ) +) + +url_payload_blocks = sync_client.post.call_args.kwargs["json"]["messages"][0]["content"] + +assert url_payload_blocks[:2] == [ + {"type": "audio_url", "audio_url": {"url": AUDIO_URL}}, + {"type": "video_url", "video_url": {"url": VIDEO_URL}}, +] + +url_payload_blocks + +# %% +sync_client = make_mock_sync_client(chat_response()) +client = OpenAICompatibleClient( + provider_name="smoke-provider", + endpoint="https://api.example.com/v1", + api_key="not-used", + concurrency_mode=ClientConcurrencyMode.SYNC, + sync_client=sync_client, +) + +client.completion( + ChatCompletionRequest( + model="mock-multimodal-model", + messages=[ + { + "role": "user", + "content": [*base64_context_blocks, {"type": "text", "text": "Summarize the media."}], + } + ], + ) +) + +base64_payload_blocks = sync_client.post.call_args.kwargs["json"]["messages"][0]["content"] + +assert base64_payload_blocks[:2] == [ + {"type": "input_audio", "input_audio": {"data": "YXVkaW8=", "format": "mp3"}}, + {"type": "video_url", "video_url": {"url": "data:video/mp4;base64,dmlkZW8="}}, +] + +base64_payload_blocks + +# %% [markdown] +# ## Anthropic rejects unsupported audio/video context before an HTTP call + +# %% +for block in url_context_blocks: + try: + 
translate_content_blocks([block]) + except UnsupportedAnthropicMediaBlockError as exc: + assert exc.modality in {"audio", "video"} + else: + raise AssertionError(f"Expected Anthropic to reject {block['type']} context") + +"All audio/video context smoke checks passed." From ac9b4356173829fe720fdcf03b6209c461840291 Mon Sep 17 00:00:00 2001 From: Nabin Mulepati Date: Mon, 18 May 2026 14:54:48 -0600 Subject: [PATCH 04/15] test: make media context notebook end to end Rewrite the audio/video smoke notebook to run a full Data Designer preview against a local OpenAI-compatible HTTP server. Assert the generated dataset, captured endpoint payload, URL/base64 translation, and local path rejection through the interface pipeline. Refs #671 --- .../7-audio-video-context-smoke-test.ipynb | 466 +++++++++--------- .../7-audio-video-context-smoke-test.py | 392 ++++++++------- 2 files changed, 453 insertions(+), 405 deletions(-) diff --git a/docs/colab_notebooks/7-audio-video-context-smoke-test.ipynb b/docs/colab_notebooks/7-audio-video-context-smoke-test.ipynb index a3f12590a..1fbdb1d33 100644 --- a/docs/colab_notebooks/7-audio-video-context-smoke-test.ipynb +++ b/docs/colab_notebooks/7-audio-video-context-smoke-test.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "46e1962e", + "id": "adb8cb68", "metadata": { "nemo_colab_inject": true }, @@ -12,31 +12,29 @@ }, { "cell_type": "markdown", - "id": "140da62d", + "id": "2bd587c5", "metadata": {}, "source": [ - "# Audio and Video Context Smoke Test" + "# Audio and Video Context End-to-End Smoke Test" ] }, { "cell_type": "markdown", - "id": "b0e37630", + "id": "d6fcf989", "metadata": {}, "source": [ - "This notebook verifies the audio and video context flow without calling a real model endpoint.\n", + "This notebook verifies audio and video context through a full Data Designer preview pipeline.\n", "\n", - "It checks that:\n", + "It starts a tiny local OpenAI-compatible HTTP server, configures Data Designer to use that 
server as a model\n", + "provider, builds a seeded dataset with audio/video context columns, runs `DataDesigner.preview(...)`, and asserts on\n", + "the payload received by the endpoint.\n", "\n", - "- `AudioContext` and `VideoContext` produce canonical media blocks.\n", - "- HTTP(S) URLs stay as URLs, including URLs without file extensions.\n", - "- Base64 media values keep the required media metadata.\n", - "- Local path-looking values are rejected instead of being resolved in the config layer.\n", - "- The OpenAI-compatible adapter forwards URL media as URL blocks in the endpoint payload." + "No external model API key is required." ] }, { "cell_type": "markdown", - "id": "27fda0b6", + "id": "0192ed76", "metadata": { "nemo_colab_inject": true }, @@ -49,7 +47,7 @@ { "cell_type": "code", "execution_count": null, - "id": "3438b6b3", + "id": "9a4a01df", "metadata": { "nemo_colab_inject": true }, @@ -62,7 +60,7 @@ { "cell_type": "code", "execution_count": null, - "id": "af547269", + "id": "8d91f318", "metadata": { "nemo_colab_inject": true }, @@ -82,31 +80,37 @@ { "cell_type": "code", "execution_count": null, - "id": "08f0fe7e", + "id": "3565c497", "metadata": {}, "outputs": [], "source": [ "from __future__ import annotations\n", "\n", "import json\n", + "import tempfile\n", + "import threading\n", "from collections.abc import Callable\n", + "from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer\n", "from typing import Any\n", - "from unittest.mock import MagicMock\n", + "\n", + "import pandas as pd\n", "\n", "import data_designer.config as dd\n", - "from data_designer.engine.models.clients.adapters.anthropic_translation import (\n", - " UnsupportedAnthropicMediaBlockError,\n", - " translate_content_blocks,\n", - ")\n", - "from data_designer.engine.models.clients.adapters.http_model_client import ClientConcurrencyMode\n", - "from data_designer.engine.models.clients.adapters.openai_compatible import OpenAICompatibleClient\n", - "from 
data_designer.engine.models.clients.types import ChatCompletionRequest" + "from data_designer.interface import DataDesigner" + ] + }, + { + "cell_type": "markdown", + "id": "dbb38552", + "metadata": {}, + "source": [ + "## Seed media values" ] }, { "cell_type": "code", "execution_count": null, - "id": "85337ecd", + "id": "4a4d5fb8", "metadata": {}, "outputs": [], "source": [ @@ -115,281 +119,295 @@ "AUDIO_DATA_URI = \"data:audio/mpeg;base64,YXVkaW8=\"\n", "VIDEO_DATA_URI = \"data:video/mp4;base64,dmlkZW8=\"\n", "\n", - "record = {\n", - " \"audio_url\": AUDIO_URL,\n", - " \"video_url\": VIDEO_URL,\n", - " \"audio_data_uri\": AUDIO_DATA_URI,\n", - " \"video_data_uri\": VIDEO_DATA_URI,\n", - "}" + "seed_df = pd.DataFrame(\n", + " [\n", + " {\n", + " \"record_id\": \"row-1\",\n", + " \"audio_url\": AUDIO_URL,\n", + " \"video_url\": VIDEO_URL,\n", + " \"audio_data_uri\": AUDIO_DATA_URI,\n", + " \"video_data_uri\": VIDEO_DATA_URI,\n", + " }\n", + " ]\n", + ")\n", + "\n", + "seed_df" ] }, { "cell_type": "markdown", - "id": "c7396746", - "metadata": {}, + "id": "7af2eb26", + "metadata": { + "lines_to_next_cell": 2 + }, "source": [ - "## Config blocks keep URLs and base64 media distinct" + "## Local OpenAI-compatible test endpoint" ] }, { "cell_type": "code", "execution_count": null, - "id": "07c16ecd", - "metadata": {}, - "outputs": [], - "source": [ - "url_context_blocks = [\n", - " *dd.AudioContext(column_name=\"audio_url\").get_contexts(record),\n", - " *dd.VideoContext(column_name=\"video_url\").get_contexts(record),\n", - "]\n", - "\n", - "assert url_context_blocks == [\n", - " {\"type\": \"audio\", \"source\": {\"type\": \"url\", \"url\": AUDIO_URL}},\n", - " {\"type\": \"video\", \"source\": {\"type\": \"url\", \"url\": VIDEO_URL}},\n", - "]\n", - "\n", - "url_context_blocks" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "d5f4936d", - "metadata": {}, + "id": "747d9851", + "metadata": { + "lines_to_next_cell": 2 + }, "outputs": [], "source": 
[ - "base64_context_blocks = [\n", - " *dd.AudioContext(column_name=\"audio_data_uri\").get_contexts(record),\n", - " *dd.VideoContext(column_name=\"video_data_uri\").get_contexts(record),\n", - "]\n", + "class RecordingOpenAIHandler(BaseHTTPRequestHandler):\n", + " captured_requests: list[dict[str, Any]] = []\n", "\n", - "assert base64_context_blocks == [\n", - " {\n", - " \"type\": \"audio\",\n", - " \"source\": {\n", - " \"type\": \"base64\",\n", - " \"media_type\": \"audio/mpeg\",\n", - " \"data\": \"YXVkaW8=\",\n", - " \"format\": \"mp3\",\n", - " },\n", - " },\n", - " {\"type\": \"video\", \"source\": {\"type\": \"base64\", \"media_type\": \"video/mp4\", \"data\": \"dmlkZW8=\"}},\n", - "]\n", + " def do_POST(self) -> None:\n", + " content_length = int(self.headers.get(\"Content-Length\", \"0\"))\n", + " raw_body = self.rfile.read(content_length)\n", + " payload = json.loads(raw_body.decode(\"utf-8\"))\n", "\n", - "base64_context_blocks" + " self.captured_requests.append(\n", + " {\n", + " \"path\": self.path,\n", + " \"headers\": dict(self.headers),\n", + " \"json\": payload,\n", + " }\n", + " )\n", + "\n", + " content_blocks = payload[\"messages\"][0][\"content\"]\n", + " media_blocks = [\n", + " block\n", + " for block in content_blocks\n", + " if isinstance(block, dict) and block.get(\"type\") in {\"audio_url\", \"input_audio\", \"video_url\"}\n", + " ]\n", + " response_json = {\n", + " \"id\": \"chatcmpl-audio-video-smoke\",\n", + " \"object\": \"chat.completion\",\n", + " \"choices\": [\n", + " {\n", + " \"index\": 0,\n", + " \"message\": {\n", + " \"role\": \"assistant\",\n", + " \"content\": f\"mock summary received {len(media_blocks)} media blocks\",\n", + " },\n", + " \"finish_reason\": \"stop\",\n", + " }\n", + " ],\n", + " \"usage\": {\"prompt_tokens\": 10, \"completion_tokens\": 5, \"total_tokens\": 15},\n", + " }\n", + " response_body = json.dumps(response_json).encode(\"utf-8\")\n", + "\n", + " self.send_response(200)\n", + " 
self.send_header(\"Content-Type\", \"application/json\")\n", + " self.send_header(\"Content-Length\", str(len(response_body)))\n", + " self.end_headers()\n", + " self.wfile.write(response_body)\n", + "\n", + " def log_message(self, format: str, *args: Any) -> None:\n", + " return\n", + "\n", + "\n", + "class LocalOpenAIServer:\n", + " def __init__(self) -> None:\n", + " self._server = ThreadingHTTPServer((\"127.0.0.1\", 0), RecordingOpenAIHandler)\n", + " self._thread = threading.Thread(target=self._server.serve_forever, daemon=True)\n", + "\n", + " @property\n", + " def endpoint(self) -> str:\n", + " return f\"http://127.0.0.1:{self._server.server_port}/v1\"\n", + "\n", + " @property\n", + " def captured_requests(self) -> list[dict[str, Any]]:\n", + " return RecordingOpenAIHandler.captured_requests\n", + "\n", + " def __enter__(self) -> LocalOpenAIServer:\n", + " RecordingOpenAIHandler.captured_requests = []\n", + " self._thread.start()\n", + " return self\n", + "\n", + " def __exit__(self, exc_type: object, exc: object, traceback: object) -> None:\n", + " self._server.shutdown()\n", + " self._server.server_close()\n", + " self._thread.join(timeout=5)" ] }, { "cell_type": "markdown", - "id": "afee0a9a", + "id": "236e2a16", "metadata": { "lines_to_next_cell": 2 }, "source": [ - "## Local file names are not resolved by audio/video context" + "## Build and run a full Data Designer preview" ] }, { "cell_type": "code", "execution_count": null, - "id": "e8c792fc", + "id": "d93a87e4", "metadata": {}, "outputs": [], "source": [ - "def assert_raises_message(callback: Callable[[], Any], expected_message: str) -> str:\n", - " try:\n", - " callback()\n", - " except ValueError as exc:\n", - " message = str(exc)\n", - " assert expected_message in message\n", - " return message\n", - " raise AssertionError(f\"Expected ValueError containing {expected_message!r}\")\n", + "def build_audio_video_config(model_alias: str, provider_name: str) -> dd.DataDesignerConfigBuilder:\n", + " 
config_builder = dd.DataDesignerConfigBuilder(\n", + " model_configs=[\n", + " dd.ModelConfig(\n", + " alias=model_alias,\n", + " model=\"local-audio-video-model\",\n", + " provider=provider_name,\n", + " inference_parameters=dd.ChatCompletionInferenceParams(\n", + " temperature=0.0,\n", + " max_tokens=64,\n", + " max_parallel_requests=1,\n", + " timeout=10,\n", + " ),\n", + " skip_health_check=True,\n", + " )\n", + " ]\n", + " )\n", + " config_builder.with_seed_dataset(dd.DataFrameSeedSource(df=seed_df))\n", + " config_builder.add_column(\n", + " dd.LLMTextColumnConfig(\n", + " name=\"media_summary\",\n", + " model_alias=model_alias,\n", + " prompt=\"Summarize the audio and video context for {{ record_id }}.\",\n", + " multi_modal_context=[\n", + " dd.AudioContext(column_name=\"audio_url\"),\n", + " dd.VideoContext(column_name=\"video_url\"),\n", + " dd.AudioContext(column_name=\"audio_data_uri\"),\n", + " dd.VideoContext(column_name=\"video_data_uri\"),\n", + " ],\n", + " )\n", + " )\n", + " return config_builder\n", "\n", "\n", - "audio_path_error = assert_raises_message(\n", - " lambda: dd.AudioContext(column_name=\"audio_path\", audio_format=dd.AudioFormat.MP3).get_contexts(\n", - " {\"audio_path\": \"screen_recording.mp3\"}\n", - " ),\n", - " \"Local audio paths are not supported\",\n", - ")\n", - "video_path_error = assert_raises_message(\n", - " lambda: dd.VideoContext(column_name=\"video_path\", video_format=dd.VideoFormat.MP4).get_contexts(\n", - " {\"video_path\": \"screen_recording.mp4\"}\n", - " ),\n", - " \"Local video paths are not supported\",\n", - ")\n", + "with LocalOpenAIServer() as local_server, tempfile.TemporaryDirectory() as artifact_dir:\n", + " provider = dd.ModelProvider(\n", + " name=\"local-openai\",\n", + " endpoint=local_server.endpoint,\n", + " provider_type=\"openai\",\n", + " api_key=\"local-test-key\",\n", + " )\n", + " data_designer = DataDesigner(artifact_path=artifact_dir, model_providers=[provider])\n", + " config_builder = 
build_audio_video_config(model_alias=\"local-multimodal\", provider_name=provider.name)\n", + "\n", + " preview = data_designer.preview(config_builder, num_records=1)\n", + " captured_requests = local_server.captured_requests\n", "\n", - "audio_path_error, video_path_error" + "preview.dataset" ] }, { "cell_type": "markdown", - "id": "d157d387", + "id": "09aaa62b", "metadata": {}, "source": [ - "## Column config round-trips mixed media context" + "## Verify the preview output and endpoint payload" ] }, { "cell_type": "code", "execution_count": null, - "id": "ebc939de", + "id": "f23fc721", "metadata": {}, "outputs": [], "source": [ - "column_config = dd.LLMTextColumnConfig(\n", - " name=\"media_summary\",\n", - " prompt=\"Summarize the audio and video context.\",\n", - " model_alias=\"mock-multimodal-model\",\n", - " multi_modal_context=[\n", - " dd.AudioContext(column_name=\"audio_url\"),\n", - " dd.VideoContext(column_name=\"video_url\"),\n", - " ],\n", - ")\n", + "assert len(preview.dataset) == 1\n", + "assert preview.dataset.loc[0, \"media_summary\"] == \"mock summary received 4 media blocks\"\n", + "assert len(captured_requests) == 1\n", "\n", - "round_tripped = dd.LLMTextColumnConfig(**column_config.model_dump())\n", + "request = captured_requests[0]\n", + "payload = request[\"json\"]\n", + "content_blocks = payload[\"messages\"][0][\"content\"]\n", "\n", - "assert [type(context).__name__ for context in round_tripped.multi_modal_context or []] == [\n", - " \"AudioContext\",\n", - " \"VideoContext\",\n", + "expected_media_blocks = [\n", + " {\"type\": \"audio_url\", \"audio_url\": {\"url\": AUDIO_URL}},\n", + " {\"type\": \"video_url\", \"video_url\": {\"url\": VIDEO_URL}},\n", + " {\"type\": \"input_audio\", \"input_audio\": {\"data\": \"YXVkaW8=\", \"format\": \"mp3\"}},\n", + " {\"type\": \"video_url\", \"video_url\": {\"url\": \"data:video/mp4;base64,dmlkZW8=\"}},\n", "]\n", - "assert set(round_tripped.required_columns) == {\"audio_url\", \"video_url\"}\n", 
"\n", - "round_tripped" + "assert request[\"path\"] == \"/v1/chat/completions\"\n", + "assert payload[\"model\"] == \"local-audio-video-model\"\n", + "assert content_blocks[:4] == expected_media_blocks\n", + "assert content_blocks[4] == {\"type\": \"text\", \"text\": \"Summarize the audio and video context for row-1.\"}\n", + "\n", + "content_blocks" ] }, { "cell_type": "markdown", - "id": "928d988c", + "id": "821c0007", "metadata": { "lines_to_next_cell": 2 }, "source": [ - "## OpenAI-compatible payloads send URL media as URLs" + "## Verify local path rejection through the pipeline" ] }, { "cell_type": "code", "execution_count": null, - "id": "6260e981", + "id": "259da710", "metadata": {}, "outputs": [], "source": [ - "def mock_httpx_response(json_data: dict[str, Any], status_code: int = 200) -> MagicMock:\n", - " response = MagicMock()\n", - " response.status_code = status_code\n", - " response.json.return_value = json_data\n", - " response.text = json.dumps(json_data)\n", - " response.headers = {}\n", - " return response\n", - "\n", - "\n", - "def make_mock_sync_client(response_json: dict[str, Any]) -> MagicMock:\n", - " client = MagicMock()\n", - " client.post = MagicMock(return_value=mock_httpx_response(response_json))\n", - " return client\n", - "\n", - "\n", - "def chat_response(content: str = \"ok\") -> dict[str, Any]:\n", - " return {\n", - " \"choices\": [{\"index\": 0, \"message\": {\"role\": \"assistant\", \"content\": content}, \"finish_reason\": \"stop\"}],\n", - " \"usage\": {\"prompt_tokens\": 3, \"completion_tokens\": 1, \"total_tokens\": 4},\n", - " }\n", - "\n", - "\n", - "sync_client = make_mock_sync_client(chat_response())\n", - "client = OpenAICompatibleClient(\n", - " provider_name=\"smoke-provider\",\n", - " endpoint=\"https://api.example.com/v1\",\n", - " api_key=\"not-used\",\n", - " concurrency_mode=ClientConcurrencyMode.SYNC,\n", - " sync_client=sync_client,\n", - ")\n", - "\n", - "client.completion(\n", - " ChatCompletionRequest(\n", - 
" model=\"mock-multimodal-model\",\n", - " messages=[\n", - " {\n", - " \"role\": \"user\",\n", - " \"content\": [*url_context_blocks, {\"type\": \"text\", \"text\": \"Summarize the media.\"}],\n", - " }\n", - " ],\n", - " )\n", - ")\n", - "\n", - "url_payload_blocks = sync_client.post.call_args.kwargs[\"json\"][\"messages\"][0][\"content\"]\n", - "\n", - "assert url_payload_blocks[:2] == [\n", - " {\"type\": \"audio_url\", \"audio_url\": {\"url\": AUDIO_URL}},\n", - " {\"type\": \"video_url\", \"video_url\": {\"url\": VIDEO_URL}},\n", - "]\n", - "\n", - "url_payload_blocks" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "0e6d78eb", - "metadata": {}, - "outputs": [], - "source": [ - "sync_client = make_mock_sync_client(chat_response())\n", - "client = OpenAICompatibleClient(\n", - " provider_name=\"smoke-provider\",\n", - " endpoint=\"https://api.example.com/v1\",\n", - " api_key=\"not-used\",\n", - " concurrency_mode=ClientConcurrencyMode.SYNC,\n", - " sync_client=sync_client,\n", - ")\n", - "\n", - "client.completion(\n", - " ChatCompletionRequest(\n", - " model=\"mock-multimodal-model\",\n", - " messages=[\n", - " {\n", - " \"role\": \"user\",\n", - " \"content\": [*base64_context_blocks, {\"type\": \"text\", \"text\": \"Summarize the media.\"}],\n", - " }\n", - " ],\n", - " )\n", - ")\n", - "\n", - "base64_payload_blocks = sync_client.post.call_args.kwargs[\"json\"][\"messages\"][0][\"content\"]\n", - "\n", - "assert base64_payload_blocks[:2] == [\n", - " {\"type\": \"input_audio\", \"input_audio\": {\"data\": \"YXVkaW8=\", \"format\": \"mp3\"}},\n", - " {\"type\": \"video_url\", \"video_url\": {\"url\": \"data:video/mp4;base64,dmlkZW8=\"}},\n", - "]\n", - "\n", - "base64_payload_blocks" - ] - }, - { - "cell_type": "markdown", - "id": "6ebd8210", - "metadata": {}, - "source": [ - "## Anthropic rejects unsupported audio/video context before an HTTP call" + "def assert_raises_message(callback: Callable[[], Any], expected_message: str) -> 
str:\n", + " try:\n", + " callback()\n", + " except Exception as exc:\n", + " message = str(exc)\n", + " assert expected_message in message\n", + " return message\n", + " raise AssertionError(f\"Expected exception containing {expected_message!r}\")\n", + "\n", + "\n", + "bad_seed_df = pd.DataFrame([{\"record_id\": \"bad-row\", \"video_path\": \"screen_recording.mp4\"}])\n", + "\n", + "\n", + "def run_bad_path_preview() -> None:\n", + " with LocalOpenAIServer() as local_server, tempfile.TemporaryDirectory() as artifact_dir:\n", + " provider = dd.ModelProvider(\n", + " name=\"local-openai\",\n", + " endpoint=local_server.endpoint,\n", + " provider_type=\"openai\",\n", + " api_key=\"local-test-key\",\n", + " )\n", + " data_designer = DataDesigner(artifact_path=artifact_dir, model_providers=[provider])\n", + " config_builder = dd.DataDesignerConfigBuilder(\n", + " model_configs=[\n", + " dd.ModelConfig(\n", + " alias=\"local-multimodal\",\n", + " model=\"local-audio-video-model\",\n", + " provider=provider.name,\n", + " inference_parameters=dd.ChatCompletionInferenceParams(max_parallel_requests=1, timeout=10),\n", + " skip_health_check=True,\n", + " )\n", + " ]\n", + " )\n", + " config_builder.with_seed_dataset(dd.DataFrameSeedSource(df=bad_seed_df))\n", + " config_builder.add_column(\n", + " dd.LLMTextColumnConfig(\n", + " name=\"media_summary\",\n", + " model_alias=\"local-multimodal\",\n", + " prompt=\"Summarize the video for {{ record_id }}.\",\n", + " multi_modal_context=[\n", + " dd.VideoContext(column_name=\"video_path\", video_format=dd.VideoFormat.MP4),\n", + " ],\n", + " )\n", + " )\n", + " data_designer.preview(config_builder, num_records=1)\n", + "\n", + "\n", + "path_error = assert_raises_message(run_bad_path_preview, \"Local video paths are not supported\")\n", + "\n", + "path_error" ] }, { "cell_type": "code", "execution_count": null, - "id": "137027ab", + "id": "3df8cbf2", "metadata": {}, "outputs": [], "source": [ - "for block in 
url_context_blocks:\n", - " try:\n", - " translate_content_blocks([block])\n", - " except UnsupportedAnthropicMediaBlockError as exc:\n", - " assert exc.modality in {\"audio\", \"video\"}\n", - " else:\n", - " raise AssertionError(f\"Expected Anthropic to reject {block['type']} context\")\n", - "\n", - "\"All audio/video context smoke checks passed.\"" + "\"Full Data Designer audio/video context smoke test passed.\"" ] } ], diff --git a/docs/notebook_source/7-audio-video-context-smoke-test.py b/docs/notebook_source/7-audio-video-context-smoke-test.py index 4f468b86d..ab40d47e7 100644 --- a/docs/notebook_source/7-audio-video-context-smoke-test.py +++ b/docs/notebook_source/7-audio-video-context-smoke-test.py @@ -13,35 +13,34 @@ # --- # %% [markdown] -# # Audio and Video Context Smoke Test +# # Audio and Video Context End-to-End Smoke Test # %% [markdown] -# This notebook verifies the audio and video context flow without calling a real model endpoint. +# This notebook verifies audio and video context through a full Data Designer preview pipeline. # -# It checks that: +# It starts a tiny local OpenAI-compatible HTTP server, configures Data Designer to use that server as a model +# provider, builds a seeded dataset with audio/video context columns, runs `DataDesigner.preview(...)`, and asserts on +# the payload received by the endpoint. # -# - `AudioContext` and `VideoContext` produce canonical media blocks. -# - HTTP(S) URLs stay as URLs, including URLs without file extensions. -# - Base64 media values keep the required media metadata. -# - Local path-looking values are rejected instead of being resolved in the config layer. -# - The OpenAI-compatible adapter forwards URL media as URL blocks in the endpoint payload. +# No external model API key is required. 
# %% from __future__ import annotations import json +import tempfile +import threading from collections.abc import Callable +from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer from typing import Any -from unittest.mock import MagicMock + +import pandas as pd import data_designer.config as dd -from data_designer.engine.models.clients.adapters.anthropic_translation import ( - UnsupportedAnthropicMediaBlockError, - translate_content_blocks, -) -from data_designer.engine.models.clients.adapters.http_model_client import ClientConcurrencyMode -from data_designer.engine.models.clients.adapters.openai_compatible import OpenAICompatibleClient -from data_designer.engine.models.clients.types import ChatCompletionRequest +from data_designer.interface import DataDesigner + +# %% [markdown] +# ## Seed media values # %% AUDIO_URL = "https://example.com/download?id=audio-123" @@ -49,202 +48,233 @@ AUDIO_DATA_URI = "data:audio/mpeg;base64,YXVkaW8=" VIDEO_DATA_URI = "data:video/mp4;base64,dmlkZW8=" -record = { - "audio_url": AUDIO_URL, - "video_url": VIDEO_URL, - "audio_data_uri": AUDIO_DATA_URI, - "video_data_uri": VIDEO_DATA_URI, -} - -# %% [markdown] -# ## Config blocks keep URLs and base64 media distinct - -# %% -url_context_blocks = [ - *dd.AudioContext(column_name="audio_url").get_contexts(record), - *dd.VideoContext(column_name="video_url").get_contexts(record), -] - -assert url_context_blocks == [ - {"type": "audio", "source": {"type": "url", "url": AUDIO_URL}}, - {"type": "video", "source": {"type": "url", "url": VIDEO_URL}}, -] - -url_context_blocks - -# %% -base64_context_blocks = [ - *dd.AudioContext(column_name="audio_data_uri").get_contexts(record), - *dd.VideoContext(column_name="video_data_uri").get_contexts(record), -] - -assert base64_context_blocks == [ - { - "type": "audio", - "source": { - "type": "base64", - "media_type": "audio/mpeg", - "data": "YXVkaW8=", - "format": "mp3", - }, - }, - {"type": "video", "source": {"type": "base64", 
"media_type": "video/mp4", "data": "dmlkZW8="}}, -] +seed_df = pd.DataFrame( + [ + { + "record_id": "row-1", + "audio_url": AUDIO_URL, + "video_url": VIDEO_URL, + "audio_data_uri": AUDIO_DATA_URI, + "video_data_uri": VIDEO_DATA_URI, + } + ] +) -base64_context_blocks +seed_df # %% [markdown] -# ## Local file names are not resolved by audio/video context +# ## Local OpenAI-compatible test endpoint # %% -def assert_raises_message(callback: Callable[[], Any], expected_message: str) -> str: - try: - callback() - except ValueError as exc: - message = str(exc) - assert expected_message in message - return message - raise AssertionError(f"Expected ValueError containing {expected_message!r}") +class RecordingOpenAIHandler(BaseHTTPRequestHandler): + captured_requests: list[dict[str, Any]] = [] + def do_POST(self) -> None: + content_length = int(self.headers.get("Content-Length", "0")) + raw_body = self.rfile.read(content_length) + payload = json.loads(raw_body.decode("utf-8")) -audio_path_error = assert_raises_message( - lambda: dd.AudioContext(column_name="audio_path", audio_format=dd.AudioFormat.MP3).get_contexts( - {"audio_path": "screen_recording.mp3"} - ), - "Local audio paths are not supported", -) -video_path_error = assert_raises_message( - lambda: dd.VideoContext(column_name="video_path", video_format=dd.VideoFormat.MP4).get_contexts( - {"video_path": "screen_recording.mp4"} - ), - "Local video paths are not supported", -) + self.captured_requests.append( + { + "path": self.path, + "headers": dict(self.headers), + "json": payload, + } + ) + + content_blocks = payload["messages"][0]["content"] + media_blocks = [ + block + for block in content_blocks + if isinstance(block, dict) and block.get("type") in {"audio_url", "input_audio", "video_url"} + ] + response_json = { + "id": "chatcmpl-audio-video-smoke", + "object": "chat.completion", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": f"mock summary received {len(media_blocks)} media 
blocks", + }, + "finish_reason": "stop", + } + ], + "usage": {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15}, + } + response_body = json.dumps(response_json).encode("utf-8") + + self.send_response(200) + self.send_header("Content-Type", "application/json") + self.send_header("Content-Length", str(len(response_body))) + self.end_headers() + self.wfile.write(response_body) + + def log_message(self, format: str, *args: Any) -> None: + return + + +class LocalOpenAIServer: + def __init__(self) -> None: + self._server = ThreadingHTTPServer(("127.0.0.1", 0), RecordingOpenAIHandler) + self._thread = threading.Thread(target=self._server.serve_forever, daemon=True) + + @property + def endpoint(self) -> str: + return f"http://127.0.0.1:{self._server.server_port}/v1" + + @property + def captured_requests(self) -> list[dict[str, Any]]: + return RecordingOpenAIHandler.captured_requests + + def __enter__(self) -> LocalOpenAIServer: + RecordingOpenAIHandler.captured_requests = [] + self._thread.start() + return self + + def __exit__(self, exc_type: object, exc: object, traceback: object) -> None: + self._server.shutdown() + self._server.server_close() + self._thread.join(timeout=5) -audio_path_error, video_path_error # %% [markdown] -# ## Column config round-trips mixed media context +# ## Build and run a full Data Designer preview + # %% -column_config = dd.LLMTextColumnConfig( - name="media_summary", - prompt="Summarize the audio and video context.", - model_alias="mock-multimodal-model", - multi_modal_context=[ - dd.AudioContext(column_name="audio_url"), - dd.VideoContext(column_name="video_url"), - ], -) +def build_audio_video_config(model_alias: str, provider_name: str) -> dd.DataDesignerConfigBuilder: + config_builder = dd.DataDesignerConfigBuilder( + model_configs=[ + dd.ModelConfig( + alias=model_alias, + model="local-audio-video-model", + provider=provider_name, + inference_parameters=dd.ChatCompletionInferenceParams( + temperature=0.0, + max_tokens=64, 
+ max_parallel_requests=1, + timeout=10, + ), + skip_health_check=True, + ) + ] + ) + config_builder.with_seed_dataset(dd.DataFrameSeedSource(df=seed_df)) + config_builder.add_column( + dd.LLMTextColumnConfig( + name="media_summary", + model_alias=model_alias, + prompt="Summarize the audio and video context for {{ record_id }}.", + multi_modal_context=[ + dd.AudioContext(column_name="audio_url"), + dd.VideoContext(column_name="video_url"), + dd.AudioContext(column_name="audio_data_uri"), + dd.VideoContext(column_name="video_data_uri"), + ], + ) + ) + return config_builder -round_tripped = dd.LLMTextColumnConfig(**column_config.model_dump()) -assert [type(context).__name__ for context in round_tripped.multi_modal_context or []] == [ - "AudioContext", - "VideoContext", -] -assert set(round_tripped.required_columns) == {"audio_url", "video_url"} +with LocalOpenAIServer() as local_server, tempfile.TemporaryDirectory() as artifact_dir: + provider = dd.ModelProvider( + name="local-openai", + endpoint=local_server.endpoint, + provider_type="openai", + api_key="local-test-key", + ) + data_designer = DataDesigner(artifact_path=artifact_dir, model_providers=[provider]) + config_builder = build_audio_video_config(model_alias="local-multimodal", provider_name=provider.name) -round_tripped + preview = data_designer.preview(config_builder, num_records=1) + captured_requests = local_server.captured_requests -# %% [markdown] -# ## OpenAI-compatible payloads send URL media as URLs +preview.dataset +# %% [markdown] +# ## Verify the preview output and endpoint payload # %% -def mock_httpx_response(json_data: dict[str, Any], status_code: int = 200) -> MagicMock: - response = MagicMock() - response.status_code = status_code - response.json.return_value = json_data - response.text = json.dumps(json_data) - response.headers = {} - return response - - -def make_mock_sync_client(response_json: dict[str, Any]) -> MagicMock: - client = MagicMock() - client.post = 
MagicMock(return_value=mock_httpx_response(response_json)) - return client - - -def chat_response(content: str = "ok") -> dict[str, Any]: - return { - "choices": [{"index": 0, "message": {"role": "assistant", "content": content}, "finish_reason": "stop"}], - "usage": {"prompt_tokens": 3, "completion_tokens": 1, "total_tokens": 4}, - } - - -sync_client = make_mock_sync_client(chat_response()) -client = OpenAICompatibleClient( - provider_name="smoke-provider", - endpoint="https://api.example.com/v1", - api_key="not-used", - concurrency_mode=ClientConcurrencyMode.SYNC, - sync_client=sync_client, -) - -client.completion( - ChatCompletionRequest( - model="mock-multimodal-model", - messages=[ - { - "role": "user", - "content": [*url_context_blocks, {"type": "text", "text": "Summarize the media."}], - } - ], - ) -) +assert len(preview.dataset) == 1 +assert preview.dataset.loc[0, "media_summary"] == "mock summary received 4 media blocks" +assert len(captured_requests) == 1 -url_payload_blocks = sync_client.post.call_args.kwargs["json"]["messages"][0]["content"] +request = captured_requests[0] +payload = request["json"] +content_blocks = payload["messages"][0]["content"] -assert url_payload_blocks[:2] == [ +expected_media_blocks = [ {"type": "audio_url", "audio_url": {"url": AUDIO_URL}}, {"type": "video_url", "video_url": {"url": VIDEO_URL}}, -] - -url_payload_blocks - -# %% -sync_client = make_mock_sync_client(chat_response()) -client = OpenAICompatibleClient( - provider_name="smoke-provider", - endpoint="https://api.example.com/v1", - api_key="not-used", - concurrency_mode=ClientConcurrencyMode.SYNC, - sync_client=sync_client, -) - -client.completion( - ChatCompletionRequest( - model="mock-multimodal-model", - messages=[ - { - "role": "user", - "content": [*base64_context_blocks, {"type": "text", "text": "Summarize the media."}], - } - ], - ) -) - -base64_payload_blocks = sync_client.post.call_args.kwargs["json"]["messages"][0]["content"] - -assert 
base64_payload_blocks[:2] == [ {"type": "input_audio", "input_audio": {"data": "YXVkaW8=", "format": "mp3"}}, {"type": "video_url", "video_url": {"url": "data:video/mp4;base64,dmlkZW8="}}, ] -base64_payload_blocks +assert request["path"] == "/v1/chat/completions" +assert payload["model"] == "local-audio-video-model" +assert content_blocks[:4] == expected_media_blocks +assert content_blocks[4] == {"type": "text", "text": "Summarize the audio and video context for row-1."} + +content_blocks # %% [markdown] -# ## Anthropic rejects unsupported audio/video context before an HTTP call +# ## Verify local path rejection through the pipeline + # %% -for block in url_context_blocks: +def assert_raises_message(callback: Callable[[], Any], expected_message: str) -> str: try: - translate_content_blocks([block]) - except UnsupportedAnthropicMediaBlockError as exc: - assert exc.modality in {"audio", "video"} - else: - raise AssertionError(f"Expected Anthropic to reject {block['type']} context") + callback() + except Exception as exc: + message = str(exc) + assert expected_message in message + return message + raise AssertionError(f"Expected exception containing {expected_message!r}") + + +bad_seed_df = pd.DataFrame([{"record_id": "bad-row", "video_path": "screen_recording.mp4"}]) + + +def run_bad_path_preview() -> None: + with LocalOpenAIServer() as local_server, tempfile.TemporaryDirectory() as artifact_dir: + provider = dd.ModelProvider( + name="local-openai", + endpoint=local_server.endpoint, + provider_type="openai", + api_key="local-test-key", + ) + data_designer = DataDesigner(artifact_path=artifact_dir, model_providers=[provider]) + config_builder = dd.DataDesignerConfigBuilder( + model_configs=[ + dd.ModelConfig( + alias="local-multimodal", + model="local-audio-video-model", + provider=provider.name, + inference_parameters=dd.ChatCompletionInferenceParams(max_parallel_requests=1, timeout=10), + skip_health_check=True, + ) + ] + ) + 
config_builder.with_seed_dataset(dd.DataFrameSeedSource(df=bad_seed_df)) + config_builder.add_column( + dd.LLMTextColumnConfig( + name="media_summary", + model_alias="local-multimodal", + prompt="Summarize the video for {{ record_id }}.", + multi_modal_context=[ + dd.VideoContext(column_name="video_path", video_format=dd.VideoFormat.MP4), + ], + ) + ) + data_designer.preview(config_builder, num_records=1) + + +path_error = assert_raises_message(run_bad_path_preview, "Local video paths are not supported") + +path_error -"All audio/video context smoke checks passed." +# %% +"Full Data Designer audio/video context smoke test passed." From 2820afd2553952b2e75c240f325518f37cd3d4e8 Mon Sep 17 00:00:00 2001 From: Nabin Mulepati Date: Mon, 18 May 2026 15:05:44 -0600 Subject: [PATCH 05/15] test: remove media context notebook from docs Move the generated audio/video context E2E notebook out of the PR docs surface and keep it locally under the main checkout's .scratch directory. Refs #671 --- .../7-audio-video-context-smoke-test.ipynb | 423 ------------------ .../7-audio-video-context-smoke-test.py | 280 ------------ 2 files changed, 703 deletions(-) delete mode 100644 docs/colab_notebooks/7-audio-video-context-smoke-test.ipynb delete mode 100644 docs/notebook_source/7-audio-video-context-smoke-test.py diff --git a/docs/colab_notebooks/7-audio-video-context-smoke-test.ipynb b/docs/colab_notebooks/7-audio-video-context-smoke-test.ipynb deleted file mode 100644 index 1fbdb1d33..000000000 --- a/docs/colab_notebooks/7-audio-video-context-smoke-test.ipynb +++ /dev/null @@ -1,423 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "adb8cb68", - "metadata": { - "nemo_colab_inject": true - }, - "source": [ - "\"Open" - ] - }, - { - "cell_type": "markdown", - "id": "2bd587c5", - "metadata": {}, - "source": [ - "# Audio and Video Context End-to-End Smoke Test" - ] - }, - { - "cell_type": "markdown", - "id": "d6fcf989", - "metadata": {}, - "source": [ - "This notebook 
verifies audio and video context through a full Data Designer preview pipeline.\n", - "\n", - "It starts a tiny local OpenAI-compatible HTTP server, configures Data Designer to use that server as a model\n", - "provider, builds a seeded dataset with audio/video context columns, runs `DataDesigner.preview(...)`, and asserts on\n", - "the payload received by the endpoint.\n", - "\n", - "No external model API key is required." - ] - }, - { - "cell_type": "markdown", - "id": "0192ed76", - "metadata": { - "nemo_colab_inject": true - }, - "source": [ - "### ⚑ Colab Setup\n", - "\n", - "Run the cells below to install the dependencies and set up the API key. If you don't have an API key, you can generate one from [build.nvidia.com](https://build.nvidia.com).\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "9a4a01df", - "metadata": { - "nemo_colab_inject": true - }, - "outputs": [], - "source": [ - "%%capture\n", - "!pip install -U data-designer" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "8d91f318", - "metadata": { - "nemo_colab_inject": true - }, - "outputs": [], - "source": [ - "import getpass\n", - "import os\n", - "\n", - "from google.colab import userdata\n", - "\n", - "try:\n", - " os.environ[\"NVIDIA_API_KEY\"] = userdata.get(\"NVIDIA_API_KEY\")\n", - "except userdata.SecretNotFoundError:\n", - " os.environ[\"NVIDIA_API_KEY\"] = getpass.getpass(\"Enter your NVIDIA API key: \")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "3565c497", - "metadata": {}, - "outputs": [], - "source": [ - "from __future__ import annotations\n", - "\n", - "import json\n", - "import tempfile\n", - "import threading\n", - "from collections.abc import Callable\n", - "from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer\n", - "from typing import Any\n", - "\n", - "import pandas as pd\n", - "\n", - "import data_designer.config as dd\n", - "from data_designer.interface import DataDesigner" - ] - }, - 
{ - "cell_type": "markdown", - "id": "dbb38552", - "metadata": {}, - "source": [ - "## Seed media values" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "4a4d5fb8", - "metadata": {}, - "outputs": [], - "source": [ - "AUDIO_URL = \"https://example.com/download?id=audio-123\"\n", - "VIDEO_URL = \"https://example.com/download?id=video-456\"\n", - "AUDIO_DATA_URI = \"data:audio/mpeg;base64,YXVkaW8=\"\n", - "VIDEO_DATA_URI = \"data:video/mp4;base64,dmlkZW8=\"\n", - "\n", - "seed_df = pd.DataFrame(\n", - " [\n", - " {\n", - " \"record_id\": \"row-1\",\n", - " \"audio_url\": AUDIO_URL,\n", - " \"video_url\": VIDEO_URL,\n", - " \"audio_data_uri\": AUDIO_DATA_URI,\n", - " \"video_data_uri\": VIDEO_DATA_URI,\n", - " }\n", - " ]\n", - ")\n", - "\n", - "seed_df" - ] - }, - { - "cell_type": "markdown", - "id": "7af2eb26", - "metadata": { - "lines_to_next_cell": 2 - }, - "source": [ - "## Local OpenAI-compatible test endpoint" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "747d9851", - "metadata": { - "lines_to_next_cell": 2 - }, - "outputs": [], - "source": [ - "class RecordingOpenAIHandler(BaseHTTPRequestHandler):\n", - " captured_requests: list[dict[str, Any]] = []\n", - "\n", - " def do_POST(self) -> None:\n", - " content_length = int(self.headers.get(\"Content-Length\", \"0\"))\n", - " raw_body = self.rfile.read(content_length)\n", - " payload = json.loads(raw_body.decode(\"utf-8\"))\n", - "\n", - " self.captured_requests.append(\n", - " {\n", - " \"path\": self.path,\n", - " \"headers\": dict(self.headers),\n", - " \"json\": payload,\n", - " }\n", - " )\n", - "\n", - " content_blocks = payload[\"messages\"][0][\"content\"]\n", - " media_blocks = [\n", - " block\n", - " for block in content_blocks\n", - " if isinstance(block, dict) and block.get(\"type\") in {\"audio_url\", \"input_audio\", \"video_url\"}\n", - " ]\n", - " response_json = {\n", - " \"id\": \"chatcmpl-audio-video-smoke\",\n", - " \"object\": 
\"chat.completion\",\n", - " \"choices\": [\n", - " {\n", - " \"index\": 0,\n", - " \"message\": {\n", - " \"role\": \"assistant\",\n", - " \"content\": f\"mock summary received {len(media_blocks)} media blocks\",\n", - " },\n", - " \"finish_reason\": \"stop\",\n", - " }\n", - " ],\n", - " \"usage\": {\"prompt_tokens\": 10, \"completion_tokens\": 5, \"total_tokens\": 15},\n", - " }\n", - " response_body = json.dumps(response_json).encode(\"utf-8\")\n", - "\n", - " self.send_response(200)\n", - " self.send_header(\"Content-Type\", \"application/json\")\n", - " self.send_header(\"Content-Length\", str(len(response_body)))\n", - " self.end_headers()\n", - " self.wfile.write(response_body)\n", - "\n", - " def log_message(self, format: str, *args: Any) -> None:\n", - " return\n", - "\n", - "\n", - "class LocalOpenAIServer:\n", - " def __init__(self) -> None:\n", - " self._server = ThreadingHTTPServer((\"127.0.0.1\", 0), RecordingOpenAIHandler)\n", - " self._thread = threading.Thread(target=self._server.serve_forever, daemon=True)\n", - "\n", - " @property\n", - " def endpoint(self) -> str:\n", - " return f\"http://127.0.0.1:{self._server.server_port}/v1\"\n", - "\n", - " @property\n", - " def captured_requests(self) -> list[dict[str, Any]]:\n", - " return RecordingOpenAIHandler.captured_requests\n", - "\n", - " def __enter__(self) -> LocalOpenAIServer:\n", - " RecordingOpenAIHandler.captured_requests = []\n", - " self._thread.start()\n", - " return self\n", - "\n", - " def __exit__(self, exc_type: object, exc: object, traceback: object) -> None:\n", - " self._server.shutdown()\n", - " self._server.server_close()\n", - " self._thread.join(timeout=5)" - ] - }, - { - "cell_type": "markdown", - "id": "236e2a16", - "metadata": { - "lines_to_next_cell": 2 - }, - "source": [ - "## Build and run a full Data Designer preview" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "d93a87e4", - "metadata": {}, - "outputs": [], - "source": [ - "def 
build_audio_video_config(model_alias: str, provider_name: str) -> dd.DataDesignerConfigBuilder:\n", - " config_builder = dd.DataDesignerConfigBuilder(\n", - " model_configs=[\n", - " dd.ModelConfig(\n", - " alias=model_alias,\n", - " model=\"local-audio-video-model\",\n", - " provider=provider_name,\n", - " inference_parameters=dd.ChatCompletionInferenceParams(\n", - " temperature=0.0,\n", - " max_tokens=64,\n", - " max_parallel_requests=1,\n", - " timeout=10,\n", - " ),\n", - " skip_health_check=True,\n", - " )\n", - " ]\n", - " )\n", - " config_builder.with_seed_dataset(dd.DataFrameSeedSource(df=seed_df))\n", - " config_builder.add_column(\n", - " dd.LLMTextColumnConfig(\n", - " name=\"media_summary\",\n", - " model_alias=model_alias,\n", - " prompt=\"Summarize the audio and video context for {{ record_id }}.\",\n", - " multi_modal_context=[\n", - " dd.AudioContext(column_name=\"audio_url\"),\n", - " dd.VideoContext(column_name=\"video_url\"),\n", - " dd.AudioContext(column_name=\"audio_data_uri\"),\n", - " dd.VideoContext(column_name=\"video_data_uri\"),\n", - " ],\n", - " )\n", - " )\n", - " return config_builder\n", - "\n", - "\n", - "with LocalOpenAIServer() as local_server, tempfile.TemporaryDirectory() as artifact_dir:\n", - " provider = dd.ModelProvider(\n", - " name=\"local-openai\",\n", - " endpoint=local_server.endpoint,\n", - " provider_type=\"openai\",\n", - " api_key=\"local-test-key\",\n", - " )\n", - " data_designer = DataDesigner(artifact_path=artifact_dir, model_providers=[provider])\n", - " config_builder = build_audio_video_config(model_alias=\"local-multimodal\", provider_name=provider.name)\n", - "\n", - " preview = data_designer.preview(config_builder, num_records=1)\n", - " captured_requests = local_server.captured_requests\n", - "\n", - "preview.dataset" - ] - }, - { - "cell_type": "markdown", - "id": "09aaa62b", - "metadata": {}, - "source": [ - "## Verify the preview output and endpoint payload" - ] - }, - { - "cell_type": "code", - 
"execution_count": null, - "id": "f23fc721", - "metadata": {}, - "outputs": [], - "source": [ - "assert len(preview.dataset) == 1\n", - "assert preview.dataset.loc[0, \"media_summary\"] == \"mock summary received 4 media blocks\"\n", - "assert len(captured_requests) == 1\n", - "\n", - "request = captured_requests[0]\n", - "payload = request[\"json\"]\n", - "content_blocks = payload[\"messages\"][0][\"content\"]\n", - "\n", - "expected_media_blocks = [\n", - " {\"type\": \"audio_url\", \"audio_url\": {\"url\": AUDIO_URL}},\n", - " {\"type\": \"video_url\", \"video_url\": {\"url\": VIDEO_URL}},\n", - " {\"type\": \"input_audio\", \"input_audio\": {\"data\": \"YXVkaW8=\", \"format\": \"mp3\"}},\n", - " {\"type\": \"video_url\", \"video_url\": {\"url\": \"data:video/mp4;base64,dmlkZW8=\"}},\n", - "]\n", - "\n", - "assert request[\"path\"] == \"/v1/chat/completions\"\n", - "assert payload[\"model\"] == \"local-audio-video-model\"\n", - "assert content_blocks[:4] == expected_media_blocks\n", - "assert content_blocks[4] == {\"type\": \"text\", \"text\": \"Summarize the audio and video context for row-1.\"}\n", - "\n", - "content_blocks" - ] - }, - { - "cell_type": "markdown", - "id": "821c0007", - "metadata": { - "lines_to_next_cell": 2 - }, - "source": [ - "## Verify local path rejection through the pipeline" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "259da710", - "metadata": {}, - "outputs": [], - "source": [ - "def assert_raises_message(callback: Callable[[], Any], expected_message: str) -> str:\n", - " try:\n", - " callback()\n", - " except Exception as exc:\n", - " message = str(exc)\n", - " assert expected_message in message\n", - " return message\n", - " raise AssertionError(f\"Expected exception containing {expected_message!r}\")\n", - "\n", - "\n", - "bad_seed_df = pd.DataFrame([{\"record_id\": \"bad-row\", \"video_path\": \"screen_recording.mp4\"}])\n", - "\n", - "\n", - "def run_bad_path_preview() -> None:\n", - " with 
LocalOpenAIServer() as local_server, tempfile.TemporaryDirectory() as artifact_dir:\n", - " provider = dd.ModelProvider(\n", - " name=\"local-openai\",\n", - " endpoint=local_server.endpoint,\n", - " provider_type=\"openai\",\n", - " api_key=\"local-test-key\",\n", - " )\n", - " data_designer = DataDesigner(artifact_path=artifact_dir, model_providers=[provider])\n", - " config_builder = dd.DataDesignerConfigBuilder(\n", - " model_configs=[\n", - " dd.ModelConfig(\n", - " alias=\"local-multimodal\",\n", - " model=\"local-audio-video-model\",\n", - " provider=provider.name,\n", - " inference_parameters=dd.ChatCompletionInferenceParams(max_parallel_requests=1, timeout=10),\n", - " skip_health_check=True,\n", - " )\n", - " ]\n", - " )\n", - " config_builder.with_seed_dataset(dd.DataFrameSeedSource(df=bad_seed_df))\n", - " config_builder.add_column(\n", - " dd.LLMTextColumnConfig(\n", - " name=\"media_summary\",\n", - " model_alias=\"local-multimodal\",\n", - " prompt=\"Summarize the video for {{ record_id }}.\",\n", - " multi_modal_context=[\n", - " dd.VideoContext(column_name=\"video_path\", video_format=dd.VideoFormat.MP4),\n", - " ],\n", - " )\n", - " )\n", - " data_designer.preview(config_builder, num_records=1)\n", - "\n", - "\n", - "path_error = assert_raises_message(run_bad_path_preview, \"Local video paths are not supported\")\n", - "\n", - "path_error" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "3df8cbf2", - "metadata": {}, - "outputs": [], - "source": [ - "\"Full Data Designer audio/video context smoke test passed.\"" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": ".venv", - "language": "python", - "name": "python3" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/docs/notebook_source/7-audio-video-context-smoke-test.py b/docs/notebook_source/7-audio-video-context-smoke-test.py deleted file mode 100644 index ab40d47e7..000000000 --- a/docs/notebook_source/7-audio-video-context-smoke-test.py +++ 
/dev/null @@ -1,280 +0,0 @@ -# --- -# jupyter: -# jupytext: -# text_representation: -# extension: .py -# format_name: percent -# format_version: '1.3' -# jupytext_version: 1.18.1 -# kernelspec: -# display_name: .venv -# language: python -# name: python3 -# --- - -# %% [markdown] -# # Audio and Video Context End-to-End Smoke Test - -# %% [markdown] -# This notebook verifies audio and video context through a full Data Designer preview pipeline. -# -# It starts a tiny local OpenAI-compatible HTTP server, configures Data Designer to use that server as a model -# provider, builds a seeded dataset with audio/video context columns, runs `DataDesigner.preview(...)`, and asserts on -# the payload received by the endpoint. -# -# No external model API key is required. - -# %% -from __future__ import annotations - -import json -import tempfile -import threading -from collections.abc import Callable -from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer -from typing import Any - -import pandas as pd - -import data_designer.config as dd -from data_designer.interface import DataDesigner - -# %% [markdown] -# ## Seed media values - -# %% -AUDIO_URL = "https://example.com/download?id=audio-123" -VIDEO_URL = "https://example.com/download?id=video-456" -AUDIO_DATA_URI = "data:audio/mpeg;base64,YXVkaW8=" -VIDEO_DATA_URI = "data:video/mp4;base64,dmlkZW8=" - -seed_df = pd.DataFrame( - [ - { - "record_id": "row-1", - "audio_url": AUDIO_URL, - "video_url": VIDEO_URL, - "audio_data_uri": AUDIO_DATA_URI, - "video_data_uri": VIDEO_DATA_URI, - } - ] -) - -seed_df - -# %% [markdown] -# ## Local OpenAI-compatible test endpoint - - -# %% -class RecordingOpenAIHandler(BaseHTTPRequestHandler): - captured_requests: list[dict[str, Any]] = [] - - def do_POST(self) -> None: - content_length = int(self.headers.get("Content-Length", "0")) - raw_body = self.rfile.read(content_length) - payload = json.loads(raw_body.decode("utf-8")) - - self.captured_requests.append( - { - "path": self.path, 
- "headers": dict(self.headers), - "json": payload, - } - ) - - content_blocks = payload["messages"][0]["content"] - media_blocks = [ - block - for block in content_blocks - if isinstance(block, dict) and block.get("type") in {"audio_url", "input_audio", "video_url"} - ] - response_json = { - "id": "chatcmpl-audio-video-smoke", - "object": "chat.completion", - "choices": [ - { - "index": 0, - "message": { - "role": "assistant", - "content": f"mock summary received {len(media_blocks)} media blocks", - }, - "finish_reason": "stop", - } - ], - "usage": {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15}, - } - response_body = json.dumps(response_json).encode("utf-8") - - self.send_response(200) - self.send_header("Content-Type", "application/json") - self.send_header("Content-Length", str(len(response_body))) - self.end_headers() - self.wfile.write(response_body) - - def log_message(self, format: str, *args: Any) -> None: - return - - -class LocalOpenAIServer: - def __init__(self) -> None: - self._server = ThreadingHTTPServer(("127.0.0.1", 0), RecordingOpenAIHandler) - self._thread = threading.Thread(target=self._server.serve_forever, daemon=True) - - @property - def endpoint(self) -> str: - return f"http://127.0.0.1:{self._server.server_port}/v1" - - @property - def captured_requests(self) -> list[dict[str, Any]]: - return RecordingOpenAIHandler.captured_requests - - def __enter__(self) -> LocalOpenAIServer: - RecordingOpenAIHandler.captured_requests = [] - self._thread.start() - return self - - def __exit__(self, exc_type: object, exc: object, traceback: object) -> None: - self._server.shutdown() - self._server.server_close() - self._thread.join(timeout=5) - - -# %% [markdown] -# ## Build and run a full Data Designer preview - - -# %% -def build_audio_video_config(model_alias: str, provider_name: str) -> dd.DataDesignerConfigBuilder: - config_builder = dd.DataDesignerConfigBuilder( - model_configs=[ - dd.ModelConfig( - alias=model_alias, - 
model="local-audio-video-model", - provider=provider_name, - inference_parameters=dd.ChatCompletionInferenceParams( - temperature=0.0, - max_tokens=64, - max_parallel_requests=1, - timeout=10, - ), - skip_health_check=True, - ) - ] - ) - config_builder.with_seed_dataset(dd.DataFrameSeedSource(df=seed_df)) - config_builder.add_column( - dd.LLMTextColumnConfig( - name="media_summary", - model_alias=model_alias, - prompt="Summarize the audio and video context for {{ record_id }}.", - multi_modal_context=[ - dd.AudioContext(column_name="audio_url"), - dd.VideoContext(column_name="video_url"), - dd.AudioContext(column_name="audio_data_uri"), - dd.VideoContext(column_name="video_data_uri"), - ], - ) - ) - return config_builder - - -with LocalOpenAIServer() as local_server, tempfile.TemporaryDirectory() as artifact_dir: - provider = dd.ModelProvider( - name="local-openai", - endpoint=local_server.endpoint, - provider_type="openai", - api_key="local-test-key", - ) - data_designer = DataDesigner(artifact_path=artifact_dir, model_providers=[provider]) - config_builder = build_audio_video_config(model_alias="local-multimodal", provider_name=provider.name) - - preview = data_designer.preview(config_builder, num_records=1) - captured_requests = local_server.captured_requests - -preview.dataset - -# %% [markdown] -# ## Verify the preview output and endpoint payload - -# %% -assert len(preview.dataset) == 1 -assert preview.dataset.loc[0, "media_summary"] == "mock summary received 4 media blocks" -assert len(captured_requests) == 1 - -request = captured_requests[0] -payload = request["json"] -content_blocks = payload["messages"][0]["content"] - -expected_media_blocks = [ - {"type": "audio_url", "audio_url": {"url": AUDIO_URL}}, - {"type": "video_url", "video_url": {"url": VIDEO_URL}}, - {"type": "input_audio", "input_audio": {"data": "YXVkaW8=", "format": "mp3"}}, - {"type": "video_url", "video_url": {"url": "data:video/mp4;base64,dmlkZW8="}}, -] - -assert request["path"] == 
"/v1/chat/completions" -assert payload["model"] == "local-audio-video-model" -assert content_blocks[:4] == expected_media_blocks -assert content_blocks[4] == {"type": "text", "text": "Summarize the audio and video context for row-1."} - -content_blocks - -# %% [markdown] -# ## Verify local path rejection through the pipeline - - -# %% -def assert_raises_message(callback: Callable[[], Any], expected_message: str) -> str: - try: - callback() - except Exception as exc: - message = str(exc) - assert expected_message in message - return message - raise AssertionError(f"Expected exception containing {expected_message!r}") - - -bad_seed_df = pd.DataFrame([{"record_id": "bad-row", "video_path": "screen_recording.mp4"}]) - - -def run_bad_path_preview() -> None: - with LocalOpenAIServer() as local_server, tempfile.TemporaryDirectory() as artifact_dir: - provider = dd.ModelProvider( - name="local-openai", - endpoint=local_server.endpoint, - provider_type="openai", - api_key="local-test-key", - ) - data_designer = DataDesigner(artifact_path=artifact_dir, model_providers=[provider]) - config_builder = dd.DataDesignerConfigBuilder( - model_configs=[ - dd.ModelConfig( - alias="local-multimodal", - model="local-audio-video-model", - provider=provider.name, - inference_parameters=dd.ChatCompletionInferenceParams(max_parallel_requests=1, timeout=10), - skip_health_check=True, - ) - ] - ) - config_builder.with_seed_dataset(dd.DataFrameSeedSource(df=bad_seed_df)) - config_builder.add_column( - dd.LLMTextColumnConfig( - name="media_summary", - model_alias="local-multimodal", - prompt="Summarize the video for {{ record_id }}.", - multi_modal_context=[ - dd.VideoContext(column_name="video_path", video_format=dd.VideoFormat.MP4), - ], - ) - ) - data_designer.preview(config_builder, num_records=1) - - -path_error = assert_raises_message(run_bad_path_preview, "Local video paths are not supported") - -path_error - -# %% -"Full Data Designer audio/video context smoke test passed." 
From 253418b933cc384a5d95579d8866306ed15c9a17 Mon Sep 17 00:00:00 2001 From: Nabin Mulepati Date: Thu, 21 May 2026 10:51:36 -0600 Subject: [PATCH 06/15] harden multimodal media context handling --- .../data_designer/config/column_configs.py | 30 ++- .../src/data_designer/config/models.py | 130 ++++++------- .../config/utils/media_helpers.py | 15 ++ .../tests/config/test_columns.py | 46 ++++- .../tests/config/test_models.py | 182 ++++++------------ .../tests/config/utils/test_media_helpers.py | 18 ++ .../clients/adapters/anthropic_translation.py | 25 +-- .../clients/adapters/openai_compatible.py | 53 +++-- .../generators/test_image.py | 66 ++++++- .../clients/test_anthropic_translation.py | 33 ++-- .../models/clients/test_openai_compatible.py | 35 +++- .../tests/engine/models/test_model_utils.py | 11 +- .../tests/engine/test_validation.py | 4 +- 13 files changed, 361 insertions(+), 287 deletions(-) diff --git a/packages/data-designer-config/src/data_designer/config/column_configs.py b/packages/data-designer-config/src/data_designer/config/column_configs.py index c770d6e65..fd9d3761b 100644 --- a/packages/data-designer-config/src/data_designer/config/column_configs.py +++ b/packages/data-designer-config/src/data_designer/config/column_configs.py @@ -180,6 +180,17 @@ class LLMTextColumnConfig(SingleColumnConfig): ) column_type: Literal["llm-text"] = "llm-text" + @field_validator("multi_modal_context", mode="before") + @classmethod + def inject_legacy_image_context_modality(cls, value: Any) -> Any: + """Preserve legacy image-context dicts that predate the modality discriminator.""" + if not isinstance(value, list): + return value + return [ + {"modality": "image", **item} if isinstance(item, dict) and _is_legacy_image_context_dict(item) else item + for item in value + ] + @staticmethod def get_column_emoji() -> str: return "πŸ“" @@ -597,8 +608,8 @@ class ImageColumnConfig(SingleColumnConfig): Must be a valid Jinja2 template. 
model_alias (required): The model to use for image generation. multi_modal_context: Optional list of multimodal contexts for generation. - Enables autoregressive multi-modal models to generate images based on media inputs. - Only works with autoregressive models that support image-to-image generation. + Enables autoregressive multimodal models to generate images based on image, audio, or video inputs. + Ignored by diffusion image-generation routes, which do not consume multimodal context. Inherited Attributes: name (required): Unique name of the column to be generated. @@ -614,6 +625,17 @@ class ImageColumnConfig(SingleColumnConfig): ) column_type: Literal["image"] = "image" + @field_validator("multi_modal_context", mode="before") + @classmethod + def inject_legacy_image_context_modality(cls, value: Any) -> Any: + """Preserve legacy image-context dicts that predate the modality discriminator.""" + if not isinstance(value, list): + return value + return [ + {"modality": "image", **item} if isinstance(item, dict) and _is_legacy_image_context_dict(item) else item + for item in value + ] + @staticmethod def get_column_emoji() -> str: return "πŸ–ΌοΈ" @@ -731,3 +753,7 @@ def validate_generator_function(self) -> Self: f"Expected a function decorated with @custom_column_generator." 
) return self + + +def _is_legacy_image_context_dict(value: dict[str, Any]) -> bool: + return "modality" not in value diff --git a/packages/data-designer-config/src/data_designer/config/models.py b/packages/data-designer-config/src/data_designer/config/models.py index 4a8abaa6f..bd614c66e 100644 --- a/packages/data-designer-config/src/data_designer/config/models.py +++ b/packages/data-designer-config/src/data_designer/config/models.py @@ -3,7 +3,6 @@ from __future__ import annotations -import json import logging from abc import ABC, abstractmethod from enum import Enum @@ -36,6 +35,8 @@ VideoFormat, audio_format_from_mime_type, audio_mime_type, + get_media_base64_context, + get_media_url_context, is_audio_path, is_media_url, is_video_path, @@ -111,46 +112,19 @@ def get_contexts(self, record: dict, *, base_path: str | None = None) -> list[di Returns: A list of image contexts. """ - raw_value = record[self.column_name] - - # Normalize to list of strings - if isinstance(raw_value, str): - # Try to parse as JSON first - try: - parsed_value = json.loads(raw_value) - if isinstance(parsed_value, list): - context_values = parsed_value - else: - context_values = [raw_value] - except (json.JSONDecodeError, TypeError): - context_values = [raw_value] - elif isinstance(raw_value, list): - context_values = raw_value - elif hasattr(raw_value, "__iter__") and not isinstance(raw_value, (str, bytes, dict)): - # Handle array-like objects (numpy arrays, pandas Series, etc.) 
- context_values = list(raw_value) - else: - context_values = [raw_value] - - # Build context list - contexts = [] - for context_value in context_values: - context = dict(type="image_url") - if self.data_type is not None: - if self.data_type == ModalityDataType.URL: - context["image_url"] = {"url": context_value} - else: - context["image_url"] = { - "url": f"data:image/{self.image_format.value};base64,{context_value}", - } - else: - # Auto-detect: resolve file paths, pass through URLs, assume base64 otherwise - context["image_url"] = self._auto_resolve_context_value(context_value, base_path) - contexts.append(context) - - return contexts - - def _auto_resolve_context_value(self, context_value: str, base_path: str | None) -> dict[str, str]: + return [ + self._build_context(value, base_path=base_path) + for value in normalize_media_context_values(record[self.column_name]) + ] + + def _build_context(self, context_value: Any, *, base_path: str | None) -> dict[str, Any]: + if self.data_type == ModalityDataType.URL: + return get_media_url_context(Modality.IMAGE.value, context_value) + if self.data_type == ModalityDataType.BASE64: + return self._format_base64_context(context_value) + return self._auto_resolve_context_value(context_value, base_path) + + def _auto_resolve_context_value(self, context_value: Any, base_path: str | None) -> dict[str, Any]: """Auto-detect the format of a context value and resolve it. Resolution rules: @@ -164,22 +138,27 @@ def _auto_resolve_context_value(self, context_value: str, base_path: str | None) return self._format_base64_context(base64_data) if is_image_url(context_value): - return {"url": context_value} + return get_media_url_context(Modality.IMAGE.value, context_value) return self._format_base64_context(context_value) - def _format_base64_context(self, base64_data: str) -> dict[str, str]: - """Format base64 image data as an image_url context dict. 
+ def _format_base64_context(self, base64_data: str) -> dict[str, Any]: + """Format base64 image data as a canonical image source dict. Uses self.image_format if set, otherwise detects from the image bytes. """ + parsed = parse_base64_data_uri(base64_data) + if parsed is not None: + media_type, data = parsed + if not media_type.startswith("image/"): + raise ValueError(f"Unsupported image media type {media_type!r}") + return get_media_base64_context(Modality.IMAGE.value, media_type, data) + image_format = self.image_format if image_format is None: image_bytes = decode_base64_image(base64_data) image_format = detect_image_format(image_bytes) - return { - "url": f"data:image/{image_format.value};base64,{base64_data}", - } + return get_media_base64_context(Modality.IMAGE.value, f"image/{image_format.value}", base64_data) @model_validator(mode="after") def _validate_image_format(self) -> Self: @@ -200,28 +179,20 @@ class AudioContext(ModalityContext): def get_contexts(self, record: dict, *, base_path: str | None = None) -> list[dict[str, Any]]: """Get the contexts for the audio modality.""" - del base_path return [self._build_context(value) for value in normalize_media_context_values(record[self.column_name])] def _build_context(self, context_value: Any) -> dict[str, Any]: - if self.data_type == ModalityDataType.URL or (self.data_type is None and is_media_url(context_value)): - source: dict[str, Any] = {"type": "url", "url": context_value} - if self.audio_format is not None: - source["format"] = self.audio_format.value - return {"type": "audio", "source": source} - - media_type, data, audio_format = self._resolve_base64_parts(context_value) - return { - "type": "audio", - "source": { - "type": "base64", - "media_type": media_type, - "data": data, - "format": audio_format.value, - }, - } - - def _resolve_base64_parts(self, context_value: Any) -> tuple[str, Any, AudioFormat]: + if self.data_type == ModalityDataType.URL: + self._validate_url_context_value(context_value) + 
return get_media_url_context(Modality.AUDIO.value, context_value) + + if self.data_type is None and is_media_url(context_value): + return get_media_url_context(Modality.AUDIO.value, context_value) + + media_type, data = self._resolve_base64_parts(context_value) + return get_media_base64_context(Modality.AUDIO.value, media_type, data) + + def _resolve_base64_parts(self, context_value: Any) -> tuple[str, Any]: parsed = parse_base64_data_uri(context_value) if parsed is not None: media_type, data = parsed @@ -233,14 +204,20 @@ def _resolve_base64_parts(self, context_value: Any) -> tuple[str, Any, AudioForm raise ValueError( f"audio_format {audio_format.value!r} does not match data URI media type {media_type!r}" ) - return media_type, data, audio_format + return media_type, data if is_audio_path(context_value): raise ValueError("Local audio paths are not supported; provide an audio URL or base64 audio data") if self.audio_format is None: raise ValueError("audio_format is required for base64 audio context values") - return audio_mime_type(self.audio_format), context_value, self.audio_format + return audio_mime_type(self.audio_format), context_value + + def _validate_url_context_value(self, context_value: Any) -> None: + if is_audio_path(context_value): + raise ValueError("Local audio paths are not supported; provide an audio URL or base64 audio data") + if not is_media_url(context_value): + raise ValueError("audio URL context values must be HTTP(S) URLs") @model_validator(mode="after") def _validate_audio_format(self) -> Self: @@ -261,15 +238,18 @@ class VideoContext(ModalityContext): def get_contexts(self, record: dict, *, base_path: str | None = None) -> list[dict[str, Any]]: """Get the contexts for the video modality.""" - del base_path return [self._build_context(value) for value in normalize_media_context_values(record[self.column_name])] def _build_context(self, context_value: Any) -> dict[str, Any]: - if self.data_type == ModalityDataType.URL or (self.data_type is 
None and is_media_url(context_value)): - return {"type": "video", "source": {"type": "url", "url": context_value}} + if self.data_type == ModalityDataType.URL: + self._validate_url_context_value(context_value) + return get_media_url_context(Modality.VIDEO.value, context_value) + + if self.data_type is None and is_media_url(context_value): + return get_media_url_context(Modality.VIDEO.value, context_value) media_type, data = self._resolve_base64_parts(context_value) - return {"type": "video", "source": {"type": "base64", "media_type": media_type, "data": data}} + return get_media_base64_context(Modality.VIDEO.value, media_type, data) def _resolve_base64_parts(self, context_value: Any) -> tuple[str, Any]: parsed = parse_base64_data_uri(context_value) @@ -291,6 +271,12 @@ def _resolve_base64_parts(self, context_value: Any) -> tuple[str, Any]: raise ValueError("video_format is required for base64 video context values") return video_mime_type(self.video_format), context_value + def _validate_url_context_value(self, context_value: Any) -> None: + if is_video_path(context_value): + raise ValueError("Local video paths are not supported; provide a video URL or base64 video data") + if not is_media_url(context_value): + raise ValueError("video URL context values must be HTTP(S) URLs") + @model_validator(mode="after") def _validate_video_format(self) -> Self: if self.data_type == ModalityDataType.BASE64 and self.video_format is None: diff --git a/packages/data-designer-config/src/data_designer/config/utils/media_helpers.py b/packages/data-designer-config/src/data_designer/config/utils/media_helpers.py index 35435af91..4af00d872 100644 --- a/packages/data-designer-config/src/data_designer/config/utils/media_helpers.py +++ b/packages/data-designer-config/src/data_designer/config/utils/media_helpers.py @@ -56,6 +56,21 @@ class VideoFormat(StrEnum): } +def get_media_context(modality: str, source: dict[str, Any]) -> dict[str, Any]: + """Build a canonical media context block.""" + 
return {"type": modality, "source": source} + + +def get_media_url_context(modality: str, url: Any) -> dict[str, Any]: + """Build a canonical URL media context block.""" + return get_media_context(modality, {"type": "url", "url": url}) + + +def get_media_base64_context(modality: str, media_type: str, data: Any) -> dict[str, Any]: + """Build a canonical base64 media context block.""" + return get_media_context(modality, {"type": "base64", "media_type": media_type, "data": data}) + + def normalize_media_context_values(raw_value: Any) -> list[Any]: """Normalize scalar, JSON-list, list, and array-like media values.""" if isinstance(raw_value, str): diff --git a/packages/data-designer-config/tests/config/test_columns.py b/packages/data-designer-config/tests/config/test_columns.py index fa0ce6882..ee87ffdbd 100644 --- a/packages/data-designer-config/tests/config/test_columns.py +++ b/packages/data-designer-config/tests/config/test_columns.py @@ -1,6 +1,8 @@ # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + from typing import Literal import pytest @@ -157,14 +159,25 @@ def test_image_column_config_required_columns_includes_multi_modal_context(): multi_modal_context=[ ImageContext(column_name="reference_image"), AudioContext(column_name="reference_audio", data_type=ModalityDataType.URL), + VideoContext(column_name="reference_video", data_type=ModalityDataType.URL), ], ) - assert set(config.required_columns) == {"style", "reference_image", "reference_audio"} + assert set(config.required_columns) == {"style", "reference_image", "reference_audio", "reference_video"} -def test_multi_modal_context_round_trips_discriminated_union() -> None: - config = LLMTextColumnConfig( - name="test_llm_text", +@pytest.mark.parametrize( + "config_cls,name", + [ + (LLMTextColumnConfig, "test_llm_text"), + (ImageColumnConfig, "test_image"), + ], +) +def test_multi_modal_context_round_trips_discriminated_union( + config_cls: type[LLMTextColumnConfig] | type[ImageColumnConfig], + name: str, +) -> None: + config = config_cls( + name=name, prompt="Describe the context", model_alias=stub_model_alias, multi_modal_context=[ @@ -174,7 +187,7 @@ def test_multi_modal_context_round_trips_discriminated_union() -> None: ], ) - round_tripped = LLMTextColumnConfig(**config.model_dump()) + round_tripped = config_cls(**config.model_dump()) assert round_tripped.multi_modal_context is not None assert isinstance(round_tripped.multi_modal_context[0], ImageContext) @@ -182,6 +195,29 @@ def test_multi_modal_context_round_trips_discriminated_union() -> None: assert isinstance(round_tripped.multi_modal_context[2], VideoContext) +@pytest.mark.parametrize( + "config_cls,name", + [ + (LLMTextColumnConfig, "test_llm_text"), + (ImageColumnConfig, "test_image"), + ], +) +def test_column_config_accepts_legacy_image_context_dict( + config_cls: type[LLMTextColumnConfig] | type[ImageColumnConfig], + name: str, +) -> None: + config = config_cls( + 
name=name, + prompt="Describe the image", + model_alias=stub_model_alias, + multi_modal_context=[{"column_name": "image_url", "data_type": "url"}], + ) + + assert config.multi_modal_context is not None + assert isinstance(config.multi_modal_context[0], ImageContext) + assert config.multi_modal_context[0].column_name == "image_url" + + def test_llm_text_column_config_with_trace_serialization() -> None: """Test that with_trace field serializes and deserializes correctly.""" config = LLMTextColumnConfig( diff --git a/packages/data-designer-config/tests/config/test_models.py b/packages/data-designer-config/tests/config/test_models.py index dbdf9cfc9..aa60c954d 100644 --- a/packages/data-designer-config/tests/config/test_models.py +++ b/packages/data-designer-config/tests/config/test_models.py @@ -1,6 +1,8 @@ # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + import base64 import json import tempfile @@ -25,6 +27,7 @@ ImageInferenceParams, ManualDistribution, ManualDistributionParams, + Modality, ModalityDataType, ModelConfig, UniformDistribution, @@ -33,6 +36,7 @@ VideoFormat, load_model_configs, ) +from data_designer.config.utils.media_helpers import get_media_base64_context, get_media_url_context def test_image_context_get_contexts_single_string(): @@ -41,18 +45,12 @@ def test_image_context_get_contexts_single_string(): column_name="image_base64", data_type=ModalityDataType.BASE64, image_format=ImageFormat.PNG ) assert image_context.get_contexts({"image_base64": "somebase64encodedimagestring"}) == [ - { - "type": "image_url", - "image_url": {"url": "data:image/png;base64,somebase64encodedimagestring"}, - } + get_media_base64_context(Modality.IMAGE.value, "image/png", "somebase64encodedimagestring") ] image_context = ImageContext(column_name="image_url", data_type=ModalityDataType.URL) assert image_context.get_contexts({"image_url": 
"https://example.com/examle_image.png"}) == [ - { - "type": "image_url", - "image_url": {"url": "https://example.com/examle_image.png"}, - } + get_media_url_context(Modality.IMAGE.value, "https://example.com/examle_image.png") ] @@ -62,32 +60,17 @@ def test_image_context_get_contexts_list_of_strings(): column_name="image_base64", data_type=ModalityDataType.BASE64, image_format=ImageFormat.PNG ) assert image_context.get_contexts({"image_base64": ["image1base64", "image2base64", "image3base64"]}) == [ - { - "type": "image_url", - "image_url": {"url": "data:image/png;base64,image1base64"}, - }, - { - "type": "image_url", - "image_url": {"url": "data:image/png;base64,image2base64"}, - }, - { - "type": "image_url", - "image_url": {"url": "data:image/png;base64,image3base64"}, - }, + get_media_base64_context(Modality.IMAGE.value, "image/png", "image1base64"), + get_media_base64_context(Modality.IMAGE.value, "image/png", "image2base64"), + get_media_base64_context(Modality.IMAGE.value, "image/png", "image3base64"), ] image_context = ImageContext(column_name="image_url", data_type=ModalityDataType.URL) assert image_context.get_contexts( {"image_url": ["https://example.com/image1.png", "https://example.com/image2.png"]} ) == [ - { - "type": "image_url", - "image_url": {"url": "https://example.com/image1.png"}, - }, - { - "type": "image_url", - "image_url": {"url": "https://example.com/image2.png"}, - }, + get_media_url_context(Modality.IMAGE.value, "https://example.com/image1.png"), + get_media_url_context(Modality.IMAGE.value, "https://example.com/image2.png"), ] @@ -98,27 +81,15 @@ def test_image_context_get_contexts_numpy_array(): ) numpy_array = lazy.np.array(["image1base64", "image2base64"]) assert image_context.get_contexts({"image_base64": numpy_array}) == [ - { - "type": "image_url", - "image_url": {"url": "data:image/png;base64,image1base64"}, - }, - { - "type": "image_url", - "image_url": {"url": "data:image/png;base64,image2base64"}, - }, + 
get_media_base64_context(Modality.IMAGE.value, "image/png", "image1base64"), + get_media_base64_context(Modality.IMAGE.value, "image/png", "image2base64"), ] image_context = ImageContext(column_name="image_url", data_type=ModalityDataType.URL) numpy_array = lazy.np.array(["https://example.com/image1.png", "https://example.com/image2.png"]) assert image_context.get_contexts({"image_url": numpy_array}) == [ - { - "type": "image_url", - "image_url": {"url": "https://example.com/image1.png"}, - }, - { - "type": "image_url", - "image_url": {"url": "https://example.com/image2.png"}, - }, + get_media_url_context(Modality.IMAGE.value, "https://example.com/image1.png"), + get_media_url_context(Modality.IMAGE.value, "https://example.com/image2.png"), ] @@ -129,27 +100,15 @@ def test_image_context_get_contexts_json_serialized_list(): ) json_str = json.dumps(["image1base64", "image2base64"]) assert image_context.get_contexts({"image_base64": json_str}) == [ - { - "type": "image_url", - "image_url": {"url": "data:image/png;base64,image1base64"}, - }, - { - "type": "image_url", - "image_url": {"url": "data:image/png;base64,image2base64"}, - }, + get_media_base64_context(Modality.IMAGE.value, "image/png", "image1base64"), + get_media_base64_context(Modality.IMAGE.value, "image/png", "image2base64"), ] image_context = ImageContext(column_name="image_url", data_type=ModalityDataType.URL) json_str = json.dumps(["https://example.com/image1.png", "https://example.com/image2.png"]) assert image_context.get_contexts({"image_url": json_str}) == [ - { - "type": "image_url", - "image_url": {"url": "https://example.com/image1.png"}, - }, - { - "type": "image_url", - "image_url": {"url": "https://example.com/image2.png"}, - }, + get_media_url_context(Modality.IMAGE.value, "https://example.com/image1.png"), + get_media_url_context(Modality.IMAGE.value, "https://example.com/image2.png"), ] @@ -158,10 +117,7 @@ def test_image_context_get_contexts_json_string_not_list(): image_context = 
ImageContext(column_name="image_url", data_type=ModalityDataType.URL) json_str = json.dumps({"nested": "object"}) assert image_context.get_contexts({"image_url": json_str}) == [ - { - "type": "image_url", - "image_url": {"url": json_str}, - } + get_media_url_context(Modality.IMAGE.value, json_str) ] @@ -170,10 +126,7 @@ def test_image_context_get_contexts_invalid_json(): image_context = ImageContext(column_name="image_url", data_type=ModalityDataType.URL) invalid_json = "not a valid json string" assert image_context.get_contexts({"image_url": invalid_json}) == [ - { - "type": "image_url", - "image_url": {"url": invalid_json}, - } + get_media_url_context(Modality.IMAGE.value, invalid_json) ] @@ -199,7 +152,7 @@ def test_image_context_auto_detect_url() -> None: """Test auto-detection with URL value (no data_type).""" context = ImageContext(column_name="image_col") result = context.get_contexts({"image_col": "https://example.com/image.png"}) - assert result == [{"type": "image_url", "image_url": {"url": "https://example.com/image.png"}}] + assert result == [get_media_url_context(Modality.IMAGE.value, "https://example.com/image.png")] def test_image_context_auto_detect_base64(minimal_png_base64: str) -> None: @@ -208,8 +161,7 @@ def test_image_context_auto_detect_base64(minimal_png_base64: str) -> None: context = ImageContext(column_name="image_col") result = context.get_contexts({"image_col": png_base64}) assert len(result) == 1 - assert result[0]["type"] == "image_url" - assert f"base64,{png_base64}" in result[0]["image_url"]["url"] + assert result[0] == get_media_base64_context(Modality.IMAGE.value, "image/png", png_base64) def test_image_context_auto_detect_file_path_resolved(tmp_path: Path) -> None: @@ -226,9 +178,8 @@ def test_image_context_auto_detect_file_path_resolved(tmp_path: Path) -> None: base_path=str(tmp_path), ) assert len(result) == 1 - assert result[0]["type"] == "image_url" expected_base64 = base64.b64encode(png_bytes).decode() - assert 
f"base64,{expected_base64}" in result[0]["image_url"]["url"] + assert result[0] == get_media_base64_context(Modality.IMAGE.value, "image/png", expected_base64) def test_image_context_auto_detect_file_path_not_resolved_without_base_path() -> None: @@ -253,20 +204,12 @@ def test_audio_context_get_contexts_single_string() -> None: column_name="audio_base64", data_type=ModalityDataType.BASE64, audio_format=AudioFormat.MP3 ) assert audio_context.get_contexts({"audio_base64": "audio1base64"}) == [ - { - "type": "audio", - "source": { - "type": "base64", - "media_type": "audio/mpeg", - "data": "audio1base64", - "format": "mp3", - }, - } + get_media_base64_context(Modality.AUDIO.value, "audio/mpeg", "audio1base64") ] audio_context = AudioContext(column_name="audio_url", data_type=ModalityDataType.URL) assert audio_context.get_contexts({"audio_url": "https://example.com/audio.mp3"}) == [ - {"type": "audio", "source": {"type": "url", "url": "https://example.com/audio.mp3"}} + get_media_url_context(Modality.AUDIO.value, "https://example.com/audio.mp3") ] @@ -275,49 +218,35 @@ def test_audio_context_get_contexts_list_json_and_numpy() -> None: column_name="audio_base64", data_type=ModalityDataType.BASE64, audio_format=AudioFormat.WAV ) assert audio_context.get_contexts({"audio_base64": ["audio1", "audio2"]}) == [ - { - "type": "audio", - "source": {"type": "base64", "media_type": "audio/wav", "data": "audio1", "format": "wav"}, - }, - { - "type": "audio", - "source": {"type": "base64", "media_type": "audio/wav", "data": "audio2", "format": "wav"}, - }, + get_media_base64_context(Modality.AUDIO.value, "audio/wav", "audio1"), + get_media_base64_context(Modality.AUDIO.value, "audio/wav", "audio2"), ] json_str = json.dumps(["https://example.com/audio1.mp3", "https://example.com/audio2.mp3"]) url_context = AudioContext(column_name="audio_url", data_type=ModalityDataType.URL) assert url_context.get_contexts({"audio_url": json_str}) == [ - {"type": "audio", "source": {"type": "url", 
"url": "https://example.com/audio1.mp3"}}, - {"type": "audio", "source": {"type": "url", "url": "https://example.com/audio2.mp3"}}, + get_media_url_context(Modality.AUDIO.value, "https://example.com/audio1.mp3"), + get_media_url_context(Modality.AUDIO.value, "https://example.com/audio2.mp3"), ] numpy_array = lazy.np.array(["https://example.com/audio1.mp3", "https://example.com/audio2.mp3"]) assert url_context.get_contexts({"audio_url": numpy_array}) == [ - {"type": "audio", "source": {"type": "url", "url": "https://example.com/audio1.mp3"}}, - {"type": "audio", "source": {"type": "url", "url": "https://example.com/audio2.mp3"}}, + get_media_url_context(Modality.AUDIO.value, "https://example.com/audio1.mp3"), + get_media_url_context(Modality.AUDIO.value, "https://example.com/audio2.mp3"), ] def test_audio_context_auto_detect_url_and_data_uri() -> None: assert AudioContext(column_name="audio_col").get_contexts({"audio_col": "https://example.com/audio.mp3"}) == [ - {"type": "audio", "source": {"type": "url", "url": "https://example.com/audio.mp3"}} + get_media_url_context(Modality.AUDIO.value, "https://example.com/audio.mp3") ] assert AudioContext(column_name="audio_col").get_contexts({"audio_col": "https://example.com/download?id=123"}) == [ - {"type": "audio", "source": {"type": "url", "url": "https://example.com/download?id=123"}} + get_media_url_context(Modality.AUDIO.value, "https://example.com/download?id=123") ] assert AudioContext(column_name="audio_col").get_contexts({"audio_col": "data:audio/mpeg;base64,audio1base64"}) == [ - { - "type": "audio", - "source": { - "type": "base64", - "media_type": "audio/mpeg", - "data": "audio1base64", - "format": "mp3", - }, - } + get_media_base64_context(Modality.AUDIO.value, "audio/mpeg", "audio1base64") ] @@ -325,6 +254,14 @@ def test_audio_context_validate_audio_format() -> None: with pytest.raises(ValueError, match="audio_format is required when data_type is base64"): AudioContext(column_name="audio_base64", 
data_type=ModalityDataType.BASE64) + with pytest.raises(ValueError, match="Local audio paths are not supported"): + AudioContext(column_name="audio_url", data_type=ModalityDataType.URL).get_contexts( + {"audio_url": "screen_recording.mp3"} + ) + + with pytest.raises(ValueError, match="audio URL context values must be HTTP"): + AudioContext(column_name="audio_url", data_type=ModalityDataType.URL).get_contexts({"audio_url": "not-a-url"}) + with pytest.raises(ValueError, match="audio_format is required for base64 audio context values"): AudioContext(column_name="audio_base64").get_contexts({"audio_base64": "audio1base64"}) @@ -344,15 +281,12 @@ def test_video_context_get_contexts_single_string() -> None: column_name="video_base64", data_type=ModalityDataType.BASE64, video_format=VideoFormat.MP4 ) assert video_context.get_contexts({"video_base64": "video1base64"}) == [ - { - "type": "video", - "source": {"type": "base64", "media_type": "video/mp4", "data": "video1base64"}, - } + get_media_base64_context(Modality.VIDEO.value, "video/mp4", "video1base64") ] video_context = VideoContext(column_name="video_url", data_type=ModalityDataType.URL) assert video_context.get_contexts({"video_url": "https://example.com/video.mp4"}) == [ - {"type": "video", "source": {"type": "url", "url": "https://example.com/video.mp4"}} + get_media_url_context(Modality.VIDEO.value, "https://example.com/video.mp4") ] @@ -361,35 +295,35 @@ def test_video_context_get_contexts_list_json_and_numpy() -> None: column_name="video_base64", data_type=ModalityDataType.BASE64, video_format=VideoFormat.WEBM ) assert video_context.get_contexts({"video_base64": ["video1", "video2"]}) == [ - {"type": "video", "source": {"type": "base64", "media_type": "video/webm", "data": "video1"}}, - {"type": "video", "source": {"type": "base64", "media_type": "video/webm", "data": "video2"}}, + get_media_base64_context(Modality.VIDEO.value, "video/webm", "video1"), + get_media_base64_context(Modality.VIDEO.value, 
"video/webm", "video2"), ] json_str = json.dumps(["https://example.com/video1.mp4", "https://example.com/video2.mp4"]) url_context = VideoContext(column_name="video_url", data_type=ModalityDataType.URL) assert url_context.get_contexts({"video_url": json_str}) == [ - {"type": "video", "source": {"type": "url", "url": "https://example.com/video1.mp4"}}, - {"type": "video", "source": {"type": "url", "url": "https://example.com/video2.mp4"}}, + get_media_url_context(Modality.VIDEO.value, "https://example.com/video1.mp4"), + get_media_url_context(Modality.VIDEO.value, "https://example.com/video2.mp4"), ] numpy_array = lazy.np.array(["https://example.com/video1.mp4", "https://example.com/video2.mp4"]) assert url_context.get_contexts({"video_url": numpy_array}) == [ - {"type": "video", "source": {"type": "url", "url": "https://example.com/video1.mp4"}}, - {"type": "video", "source": {"type": "url", "url": "https://example.com/video2.mp4"}}, + get_media_url_context(Modality.VIDEO.value, "https://example.com/video1.mp4"), + get_media_url_context(Modality.VIDEO.value, "https://example.com/video2.mp4"), ] def test_video_context_auto_detect_url_and_data_uri() -> None: assert VideoContext(column_name="video_col").get_contexts({"video_col": "https://example.com/video.mp4"}) == [ - {"type": "video", "source": {"type": "url", "url": "https://example.com/video.mp4"}} + get_media_url_context(Modality.VIDEO.value, "https://example.com/video.mp4") ] assert VideoContext(column_name="video_col").get_contexts({"video_col": "https://example.com/download?id=123"}) == [ - {"type": "video", "source": {"type": "url", "url": "https://example.com/download?id=123"}} + get_media_url_context(Modality.VIDEO.value, "https://example.com/download?id=123") ] assert VideoContext(column_name="video_col").get_contexts({"video_col": "data:video/mp4;base64,video1base64"}) == [ - {"type": "video", "source": {"type": "base64", "media_type": "video/mp4", "data": "video1base64"}} + 
get_media_base64_context(Modality.VIDEO.value, "video/mp4", "video1base64") ] @@ -397,6 +331,14 @@ def test_video_context_validate_video_format() -> None: with pytest.raises(ValueError, match="video_format is required when data_type is base64"): VideoContext(column_name="video_base64", data_type=ModalityDataType.BASE64) + with pytest.raises(ValueError, match="Local video paths are not supported"): + VideoContext(column_name="video_url", data_type=ModalityDataType.URL).get_contexts( + {"video_url": "screen_recording.mp4"} + ) + + with pytest.raises(ValueError, match="video URL context values must be HTTP"): + VideoContext(column_name="video_url", data_type=ModalityDataType.URL).get_contexts({"video_url": "not-a-url"}) + with pytest.raises(ValueError, match="video_format is required for base64 video context values"): VideoContext(column_name="video_base64").get_contexts({"video_base64": "video1base64"}) diff --git a/packages/data-designer-config/tests/config/utils/test_media_helpers.py b/packages/data-designer-config/tests/config/utils/test_media_helpers.py index fbb32011d..b762c6654 100644 --- a/packages/data-designer-config/tests/config/utils/test_media_helpers.py +++ b/packages/data-designer-config/tests/config/utils/test_media_helpers.py @@ -11,6 +11,9 @@ VideoFormat, audio_format_from_mime_type, audio_mime_type, + get_media_base64_context, + get_media_context, + get_media_url_context, is_audio_path, is_audio_url, is_media_url, @@ -23,6 +26,21 @@ ) +def test_media_context_builders() -> None: + assert get_media_context("image", {"type": "url", "url": "https://example.com/image.png"}) == { + "type": "image", + "source": {"type": "url", "url": "https://example.com/image.png"}, + } + assert get_media_url_context("audio", "https://example.com/audio.mp3") == { + "type": "audio", + "source": {"type": "url", "url": "https://example.com/audio.mp3"}, + } + assert get_media_base64_context("video", "video/mp4", "abc123") == { + "type": "video", + "source": {"type": "base64", 
"media_type": "video/mp4", "data": "abc123"}, + } + + def test_normalize_media_context_values() -> None: assert normalize_media_context_values("single") == ["single"] assert normalize_media_context_values(["one", "two"]) == ["one", "two"] diff --git a/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/anthropic_translation.py b/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/anthropic_translation.py index d2e8b66ca..ef6d1dcf4 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/anthropic_translation.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/anthropic_translation.py @@ -7,6 +7,8 @@ import re from typing import Any +from data_designer.config.models import Modality +from data_designer.config.utils.media_helpers import get_media_base64_context, get_media_url_context from data_designer.engine.models.clients.parsing import extract_usage, fill_reasoning_token_count_from_content from data_designer.engine.models.clients.types import ( AssistantMessage, @@ -215,8 +217,9 @@ def translate_content_blocks(content: Any) -> list[dict[str, Any]]: if isinstance(block, dict) and block.get("type") == "image": translated.append(translate_canonical_image_block(block)) continue - if isinstance(block, dict) and block.get("type") in _UNSUPPORTED_MEDIA_BLOCK_MODALITIES: - raise UnsupportedAnthropicMediaBlockError(_UNSUPPORTED_MEDIA_BLOCK_MODALITIES[block["type"]]) + block_type = block.get("type") if isinstance(block, dict) else None + if isinstance(block_type, str) and block_type in _UNSUPPORTED_MEDIA_BLOCK_MODALITIES: + raise UnsupportedAnthropicMediaBlockError(_UNSUPPORTED_MEDIA_BLOCK_MODALITIES[block_type]) # Anthropic rejects empty text blocks β€” drop them. 
if isinstance(block, dict) and block.get("type") == "text" and not block.get("text"): continue @@ -349,19 +352,9 @@ def translate_image_url_block(block: dict[str, Any]) -> dict[str, Any]: match = _DATA_URI_RE.match(url) if match: - return { - "type": "image", - "source": { - "type": "base64", - "media_type": match.group("media_type"), - "data": match.group("data"), - }, - } + return get_media_base64_context(Modality.IMAGE.value, match.group("media_type"), match.group("data")) - return { - "type": "image", - "source": {"type": "url", "url": url}, - } + return get_media_url_context(Modality.IMAGE.value, url) def translate_canonical_image_block(block: dict[str, Any]) -> dict[str, Any]: @@ -371,12 +364,12 @@ def translate_canonical_image_block(block: dict[str, Any]) -> dict[str, Any]: source_type = source.get("type") if source_type == "url": - return {"type": "image", "source": {"type": "url", "url": source.get("url", "")}} + return get_media_url_context(Modality.IMAGE.value, source.get("url", "")) if source_type == "base64": media_type = source.get("media_type") data = source.get("data") if not isinstance(media_type, str) or not isinstance(data, str): raise ValueError(f"Canonical image base64 source must include media_type and data, got: {source!r}") - return {"type": "image", "source": {"type": "base64", "media_type": media_type, "data": data}} + return get_media_base64_context(Modality.IMAGE.value, media_type, data) raise ValueError(f"Unsupported canonical image source type {source_type!r}") diff --git a/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/openai_compatible.py b/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/openai_compatible.py index 81ac9a2a9..94a294717 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/openai_compatible.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/openai_compatible.py @@ -5,9 +5,11 @@ from 
typing import Any +from data_designer.config.utils.media_helpers import audio_format_from_mime_type from data_designer.engine.models.clients.adapters.http_model_client import ( HttpModelClient, ) +from data_designer.engine.models.clients.errors import ProviderError, ProviderErrorKind from data_designer.engine.models.clients.parsing import ( aextract_images_from_chat_response, aextract_images_from_image_response, @@ -164,44 +166,34 @@ def translate_openai_compatible_messages( for message in messages: translated = dict(message) if "content" in translated: - translated["content"] = translate_openai_compatible_content_blocks( - translated["content"], - provider_name=provider_name, - model_name=model_name, - ) + try: + translated["content"] = translate_openai_compatible_content_blocks(translated["content"]) + except ValueError as exc: + raise ProviderError( + kind=ProviderErrorKind.BAD_REQUEST, + message=str(exc), + provider_name=provider_name, + model_name=model_name, + cause=exc, + ) from exc translated_messages.append(translated) return translated_messages -def translate_openai_compatible_content_blocks( - content: Any, - *, - provider_name: str, - model_name: str, -) -> Any: +def translate_openai_compatible_content_blocks(content: Any) -> Any: if not isinstance(content, list): return content - return [ - translate_openai_compatible_content_block( - block, - provider_name=provider_name, - model_name=model_name, - ) - for block in content - ] + return [translate_openai_compatible_content_block(block) for block in content] -def translate_openai_compatible_content_block( - block: Any, - *, - provider_name: str, - model_name: str, -) -> Any: +def translate_openai_compatible_content_block(block: Any) -> Any: if not isinstance(block, dict): return block block_type = block.get("type") + if not isinstance(block_type, str): + return block if block_type in {"image_url", "input_audio", "text"}: return block if block_type == "image": @@ -233,11 +225,14 @@ def 
_translate_canonical_audio_block(block: dict[str, Any]) -> dict[str, Any]: if source_type == "url": return {"type": "audio_url", "audio_url": {"url": source.get("url", "")}} if source_type == "base64": + media_type = source.get("media_type") data = source.get("data") - audio_format = source.get("format") - if not isinstance(data, str) or not isinstance(audio_format, str): - raise ValueError(f"Canonical audio base64 source must include data and format, got: {source!r}") - return {"type": "input_audio", "input_audio": {"data": data, "format": audio_format}} + if not isinstance(media_type, str) or not isinstance(data, str): + raise ValueError(f"Canonical audio base64 source must include media_type and data, got: {source!r}") + audio_format = audio_format_from_mime_type(media_type) + if audio_format is None: + raise ValueError(f"Unsupported canonical audio media type {media_type!r}") + return {"type": "input_audio", "input_audio": {"data": data, "format": audio_format.value}} raise ValueError(f"Unsupported canonical audio source type {source_type!r}") diff --git a/packages/data-designer-engine/tests/engine/column_generators/generators/test_image.py b/packages/data-designer-engine/tests/engine/column_generators/generators/test_image.py index 6f5e8799e..bc5dc0c51 100644 --- a/packages/data-designer-engine/tests/engine/column_generators/generators/test_image.py +++ b/packages/data-designer-engine/tests/engine/column_generators/generators/test_image.py @@ -1,13 +1,23 @@ # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + import base64 from unittest.mock import Mock, patch import pytest from data_designer.config.column_configs import ImageColumnConfig -from data_designer.config.models import ImageContext, ImageFormat, ModalityDataType +from data_designer.config.models import ( + AudioContext, + ImageContext, + ImageFormat, + Modality, + ModalityDataType, + VideoContext, +) +from data_designer.config.utils.media_helpers import get_media_base64_context, get_media_url_context from data_designer.engine.column_generators.generators.base import GenerationStrategy from data_designer.engine.column_generators.generators.image import ImageCellGenerator from data_designer.engine.processing.ginja.exceptions import UserTemplateError @@ -175,8 +185,9 @@ def test_image_cell_generator_with_multi_modal_context(stub_resource_provider): assert call_args.kwargs["prompt"] == "Generate a similar image to the reference" assert call_args.kwargs["multi_modal_context"] is not None assert len(call_args.kwargs["multi_modal_context"]) == 1 - assert call_args.kwargs["multi_modal_context"][0]["type"] == "image_url" - assert call_args.kwargs["multi_modal_context"][0]["image_url"] == {"url": "https://example.com/image.png"} + assert call_args.kwargs["multi_modal_context"][0] == get_media_url_context( + Modality.IMAGE.value, "https://example.com/image.png" + ) def test_image_cell_generator_with_base64_multi_modal_context(stub_resource_provider): @@ -218,9 +229,47 @@ def test_image_cell_generator_with_base64_multi_modal_context(stub_resource_prov assert call_args.kwargs["prompt"] == "Generate a variation of this image" assert call_args.kwargs["multi_modal_context"] is not None assert len(call_args.kwargs["multi_modal_context"]) == 1 - assert call_args.kwargs["multi_modal_context"][0]["type"] == "image_url" - # Should be formatted as data URI - assert "data:image/png;base64," in call_args.kwargs["multi_modal_context"][0]["image_url"]["url"] + 
assert call_args.kwargs["multi_modal_context"][0] == get_media_base64_context( + Modality.IMAGE.value, "image/png", "iVBORw0KGgoAAAANS" + ) + + +def test_image_cell_generator_with_mixed_media_context(stub_resource_provider: Mock) -> None: + config = ImageColumnConfig( + name="test_image", + prompt="Generate a poster from this media", + model_alias="test_model", + multi_modal_context=[ + ImageContext(column_name="reference_image", data_type=ModalityDataType.URL), + AudioContext(column_name="reference_audio", data_type=ModalityDataType.URL), + VideoContext(column_name="reference_video", data_type=ModalityDataType.URL), + ], + ) + + mock_storage = Mock() + mock_storage.save_base64_image.return_value = "images/generated.png" + stub_resource_provider.artifact_storage.media_storage = mock_storage + + with patch.object( + stub_resource_provider.model_registry.get_model.return_value, + "generate_image", + return_value=["base64_generated_image"], + ) as mock_generate: + generator = ImageCellGenerator(config=config, resource_provider=stub_resource_provider) + generator.generate( + data={ + "reference_image": "https://example.com/image.png", + "reference_audio": "https://example.com/audio.mp3", + "reference_video": "https://example.com/video.mp4", + } + ) + + mock_generate.assert_called_once() + assert mock_generate.call_args.kwargs["multi_modal_context"] == [ + get_media_url_context(Modality.IMAGE.value, "https://example.com/image.png"), + get_media_url_context(Modality.AUDIO.value, "https://example.com/audio.mp3"), + get_media_url_context(Modality.VIDEO.value, "https://example.com/video.mp4"), + ] def test_image_cell_generator_build_multi_modal_context_returns_none_when_not_configured( @@ -274,10 +323,9 @@ def test_image_cell_generator_auto_resolves_generated_image_file_path(stub_resou context = call_args.kwargs["multi_modal_context"] assert context is not None assert len(context) == 1 - assert context[0]["type"] == "image_url" # Should contain base64 data, NOT the file 
path expected_b64 = base64.b64encode(png_bytes).decode() - assert expected_b64 in context[0]["image_url"]["url"] + assert context[0] == get_media_base64_context(Modality.IMAGE.value, "image/png", expected_b64) def test_image_cell_generator_auto_detect_passes_through_urls(stub_resource_provider: Mock) -> None: @@ -306,4 +354,4 @@ def test_image_cell_generator_auto_detect_passes_through_urls(stub_resource_prov mock_generate.assert_called_once() context = mock_generate.call_args.kwargs["multi_modal_context"] assert context is not None - assert context[0]["image_url"] == {"url": "https://example.com/image.png"} + assert context[0] == get_media_url_context(Modality.IMAGE.value, "https://example.com/image.png") diff --git a/packages/data-designer-engine/tests/engine/models/clients/test_anthropic_translation.py b/packages/data-designer-engine/tests/engine/models/clients/test_anthropic_translation.py index ecab9cc90..450a32b0d 100644 --- a/packages/data-designer-engine/tests/engine/models/clients/test_anthropic_translation.py +++ b/packages/data-designer-engine/tests/engine/models/clients/test_anthropic_translation.py @@ -7,6 +7,8 @@ import pytest +from data_designer.config.models import Modality +from data_designer.config.utils.media_helpers import get_media_base64_context, get_media_url_context from data_designer.engine.mcp.registry import MCPToolDefinition from data_designer.engine.models.clients.adapters.anthropic_translation import ( UnsupportedAnthropicMediaBlockError, @@ -68,7 +70,7 @@ def test_build_anthropic_payload_preserves_multimodal_system_content() -> None: assert payload["system"] == [ {"type": "text", "text": "Describe this image."}, - {"type": "image", "source": {"type": "url", "url": "https://example.com/reference.png"}}, + get_media_url_context(Modality.IMAGE.value, "https://example.com/reference.png"), ] @@ -169,14 +171,14 @@ def test_translate_request_messages_merges_parallel_tool_results() -> None: ], [ {"type": "text", "text": "Rule 1"}, - {"type": 
"image", "source": {"type": "url", "url": "https://example.com/reference.png"}}, + get_media_url_context(Modality.IMAGE.value, "https://example.com/reference.png"), {"type": "text", "text": "Rule 2"}, ], id="mixed-text-and-image-returns-blocks", ), pytest.param( [{"type": "image_url", "image_url": {"url": "https://example.com/reference.png"}}], - [{"type": "image", "source": {"type": "url", "url": "https://example.com/reference.png"}}], + [get_media_url_context(Modality.IMAGE.value, "https://example.com/reference.png")], id="image-only-returns-blocks", ), pytest.param( @@ -212,13 +214,13 @@ def test_extract_system_content_normalizes_supported_inputs( "Text preamble", [ {"type": "text", "text": "Rule 1"}, - {"type": "image", "source": {"type": "url", "url": "https://example.com/img.png"}}, + get_media_url_context(Modality.IMAGE.value, "https://example.com/img.png"), ], ], [ {"type": "text", "text": "Text preamble"}, {"type": "text", "text": "Rule 1"}, - {"type": "image", "source": {"type": "url", "url": "https://example.com/img.png"}}, + get_media_url_context(Modality.IMAGE.value, "https://example.com/img.png"), ], id="mixed-string-and-blocks", ), @@ -337,7 +339,7 @@ def test_translate_content_blocks_converts_images_and_preserves_other_blocks() - ) assert blocks == [ - {"type": "image", "source": {"type": "url", "url": "https://example.com/cat.png"}}, + get_media_url_context(Modality.IMAGE.value, "https://example.com/cat.png"), {"type": "text", "text": "Caption"}, {"type": "custom_block", "value": "kept"}, ] @@ -346,13 +348,13 @@ def test_translate_content_blocks_converts_images_and_preserves_other_blocks() - def test_translate_content_blocks_converts_canonical_images() -> None: blocks = translate_content_blocks( [ - {"type": "image", "source": {"type": "base64", "media_type": "image/png", "data": "iVBOR..."}}, + get_media_base64_context(Modality.IMAGE.value, "image/png", "iVBOR..."), {"type": "text", "text": "Caption"}, ] ) assert blocks == [ - {"type": "image", 
"source": {"type": "base64", "media_type": "image/png", "data": "iVBOR..."}}, + get_media_base64_context(Modality.IMAGE.value, "image/png", "iVBOR..."), {"type": "text", "text": "Caption"}, ] @@ -392,12 +394,12 @@ def test_translate_content_blocks_rejects_malformed_image_url_block() -> None: [ pytest.param( {"type": "image_url", "image_url": {"url": "data:image/png;base64,iVBOR..."}}, - {"type": "image", "source": {"type": "base64", "media_type": "image/png", "data": "iVBOR..."}}, + get_media_base64_context(Modality.IMAGE.value, "image/png", "iVBOR..."), id="data-uri-dict", ), pytest.param( {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}}, - {"type": "image", "source": {"type": "url", "url": "https://example.com/cat.png"}}, + get_media_url_context(Modality.IMAGE.value, "https://example.com/cat.png"), id="remote-url-dict", ), ], @@ -526,7 +528,7 @@ def test_translate_tool_result_message_requires_tool_call_id(message: dict[str, {"type": "text", "text": "Caption"}, ], [ - {"type": "image", "source": {"type": "url", "url": "https://example.com/chart.png"}}, + get_media_url_context(Modality.IMAGE.value, "https://example.com/chart.png"), {"type": "text", "text": "Caption"}, ], id="mixed-blocks", @@ -538,14 +540,7 @@ def test_translate_tool_result_message_requires_tool_call_id(message: dict[str, ], [ {"type": "text", "text": "Rendered chart:"}, - { - "type": "image", - "source": { - "type": "base64", - "media_type": "image/png", - "data": "iVBORw0KGgo=", - }, - }, + get_media_base64_context(Modality.IMAGE.value, "image/png", "iVBORw0KGgo="), ], id="mixed-blocks-with-data-uri", ), diff --git a/packages/data-designer-engine/tests/engine/models/clients/test_openai_compatible.py b/packages/data-designer-engine/tests/engine/models/clients/test_openai_compatible.py index 47af95e9c..5107fa9ad 100644 --- a/packages/data-designer-engine/tests/engine/models/clients/test_openai_compatible.py +++ 
b/packages/data-designer-engine/tests/engine/models/clients/test_openai_compatible.py @@ -8,6 +8,8 @@ import pytest +from data_designer.config.models import Modality +from data_designer.config.utils.media_helpers import get_media_base64_context, get_media_url_context from data_designer.engine.models.clients.adapters.http_model_client import ClientConcurrencyMode from data_designer.engine.models.clients.adapters.openai_compatible import OpenAICompatibleClient from data_designer.engine.models.clients.errors import ProviderError, ProviderErrorKind @@ -321,7 +323,7 @@ def test_completion_translates_canonical_image_blocks() -> None: sync_mock = make_mock_sync_client(_chat_response()) client = _make_client(sync_client=sync_mock) - image_block = {"type": "image", "source": {"type": "base64", "media_type": "image/png", "data": "iVBOR..."}} + image_block = get_media_base64_context(Modality.IMAGE.value, "image/png", "iVBOR...") request = ChatCompletionRequest( model=MODEL, messages=[{"role": "user", "content": [image_block, {"type": "text", "text": "What is this?"}]}], @@ -337,10 +339,7 @@ def test_completion_translates_base64_audio_blocks() -> None: sync_mock = make_mock_sync_client(_chat_response()) client = _make_client(sync_client=sync_mock) - audio_block = { - "type": "audio", - "source": {"type": "base64", "media_type": "audio/mpeg", "data": "abc123", "format": "mp3"}, - } + audio_block = get_media_base64_context(Modality.AUDIO.value, "audio/mpeg", "abc123") request = ChatCompletionRequest( model=MODEL, messages=[{"role": "user", "content": [audio_block, {"type": "text", "text": "Transcribe this."}]}], @@ -352,11 +351,31 @@ def test_completion_translates_base64_audio_blocks() -> None: assert content[0] == {"type": "input_audio", "input_audio": {"data": "abc123", "format": "mp3"}} +def test_completion_rejects_unsupported_canonical_audio_media_type() -> None: + sync_mock = make_mock_sync_client(_chat_response()) + client = _make_client(sync_client=sync_mock) + + 
audio_block = get_media_base64_context(Modality.AUDIO.value, "audio/flac", "abc123") + request = ChatCompletionRequest( + model=MODEL, + messages=[{"role": "user", "content": [audio_block, {"type": "text", "text": "Transcribe this."}]}], + ) + + with pytest.raises(ProviderError) as exc_info: + client.completion(request) + + assert exc_info.value.kind == ProviderErrorKind.BAD_REQUEST + assert exc_info.value.provider_name == PROVIDER + assert exc_info.value.model_name == MODEL + assert "audio/flac" in exc_info.value.message + sync_mock.post.assert_not_called() + + def test_completion_translates_audio_url_blocks() -> None: sync_mock = make_mock_sync_client(_chat_response()) client = _make_client(sync_client=sync_mock) - audio_block = {"type": "audio", "source": {"type": "url", "url": "https://example.com/download?id=123"}} + audio_block = get_media_url_context(Modality.AUDIO.value, "https://example.com/download?id=123") request = ChatCompletionRequest( model=MODEL, messages=[{"role": "user", "content": [audio_block, {"type": "text", "text": "Transcribe this."}]}], @@ -372,7 +391,7 @@ def test_completion_translates_video_blocks() -> None: sync_mock = make_mock_sync_client(_chat_response()) client = _make_client(sync_client=sync_mock) - video_block = {"type": "video", "source": {"type": "url", "url": "https://example.com/download?id=123"}} + video_block = get_media_url_context(Modality.VIDEO.value, "https://example.com/download?id=123") request = ChatCompletionRequest( model=MODEL, messages=[{"role": "user", "content": [video_block, {"type": "text", "text": "Describe this."}]}], @@ -388,7 +407,7 @@ def test_completion_translates_base64_video_blocks() -> None: sync_mock = make_mock_sync_client(_chat_response()) client = _make_client(sync_client=sync_mock) - video_block = {"type": "video", "source": {"type": "base64", "media_type": "video/mp4", "data": "abc123"}} + video_block = get_media_base64_context(Modality.VIDEO.value, "video/mp4", "abc123") request = 
ChatCompletionRequest( model=MODEL, messages=[{"role": "user", "content": [video_block, {"type": "text", "text": "Describe this."}]}], diff --git a/packages/data-designer-engine/tests/engine/models/test_model_utils.py b/packages/data-designer-engine/tests/engine/models/test_model_utils.py index 65ca83afb..7149560ba 100644 --- a/packages/data-designer-engine/tests/engine/models/test_model_utils.py +++ b/packages/data-designer-engine/tests/engine/models/test_model_utils.py @@ -3,6 +3,8 @@ from __future__ import annotations +from data_designer.config.models import Modality +from data_designer.config.utils.media_helpers import get_media_base64_context, get_media_url_context from data_designer.engine.models.utils import ChatMessage, prompt_to_messages @@ -39,12 +41,9 @@ def test_chat_message_as_tool_accepts_multimodal_content() -> None: def test_prompt_to_messages_preserves_mixed_media_context_order() -> None: context = [ - {"type": "image", "source": {"type": "url", "url": "https://example.com/image.png"}}, - { - "type": "audio", - "source": {"type": "base64", "media_type": "audio/mpeg", "data": "abc123", "format": "mp3"}, - }, - {"type": "video", "source": {"type": "url", "url": "https://example.com/video.mp4"}}, + get_media_url_context(Modality.IMAGE.value, "https://example.com/image.png"), + get_media_base64_context(Modality.AUDIO.value, "audio/mpeg", "abc123"), + get_media_url_context(Modality.VIDEO.value, "https://example.com/video.mp4"), ] assert prompt_to_messages(user_prompt="describe", multi_modal_context=context) == [ diff --git a/packages/data-designer-engine/tests/engine/test_validation.py b/packages/data-designer-engine/tests/engine/test_validation.py index efea50fa1..1a1191a9e 100644 --- a/packages/data-designer-engine/tests/engine/test_validation.py +++ b/packages/data-designer-engine/tests/engine/test_validation.py @@ -1,6 +1,8 @@ # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + from unittest.mock import Mock, patch import pytest @@ -248,7 +250,7 @@ def test_validate_column_config_with_multi_modal_context(): assert len(violations) == 0 -def test_validate_column_config_with_audio_multi_modal_context(): +def test_validate_column_config_with_audio_multi_modal_context() -> None: column = LLMTextColumnConfig( name="audio_description", prompt="Describe the audio.", From 725205c77f92561c855f27db16177372c01430d9 Mon Sep 17 00:00:00 2001 From: Nabin Mulepati Date: Thu, 21 May 2026 12:16:09 -0600 Subject: [PATCH 07/15] address media context review notes Remove unused URL-specific media helpers, share the base64 data URI parser in Anthropic translation, align AudioContext validation messaging, and update config docs for audio/video contexts. Refs #671 --- docs/code_reference/config/models.md | 2 +- .../pages/code_reference/config/models.mdx | 2 +- .../src/data_designer/config/models.py | 5 ++--- .../config/utils/media_helpers.py | 16 ---------------- .../tests/config/utils/test_media_helpers.py | 19 ++++--------------- .../clients/adapters/anthropic_translation.py | 15 +++++++++------ 6 files changed, 17 insertions(+), 42 deletions(-) diff --git a/docs/code_reference/config/models.md b/docs/code_reference/config/models.md index e14e8cfdb..22f0b1172 100644 --- a/docs/code_reference/config/models.md +++ b/docs/code_reference/config/models.md @@ -1,6 +1,6 @@ # Models -[ModelProvider](#data_designer.config.models.ModelProvider) stores connection and authentication details for model providers. [ModelConfig](#data_designer.config.models.ModelConfig) stores a model alias, model identifier, provider settings, and inference parameters. [Inference Parameters](../../concepts/models/inference-parameters.md) control model behavior. Chat-completion parameters include `temperature`, `top_p`, and `max_tokens`; `temperature` and `top_p` can be fixed values or configured distributions. 
[ImageContext](#data_designer.config.models.ImageContext) provides image inputs to multimodal models, and [ImageInferenceParams](#data_designer.config.models.ImageInferenceParams) configures image generation models. +[ModelProvider](#data_designer.config.models.ModelProvider) stores connection and authentication details for model providers. [ModelConfig](#data_designer.config.models.ModelConfig) stores a model alias, model identifier, provider settings, and inference parameters. [Inference Parameters](../../concepts/models/inference-parameters.md) control model behavior. Chat-completion parameters include `temperature`, `top_p`, and `max_tokens`; `temperature` and `top_p` can be fixed values or configured distributions. [ImageContext](#data_designer.config.models.ImageContext), [AudioContext](#data_designer.config.models.AudioContext), and [VideoContext](#data_designer.config.models.VideoContext) provide image, audio, and video inputs to multimodal models, and [ImageInferenceParams](#data_designer.config.models.ImageInferenceParams) configures image generation models. Related guides: diff --git a/fern/versions/latest/pages/code_reference/config/models.mdx b/fern/versions/latest/pages/code_reference/config/models.mdx index a9c6da403..7ac0c09b6 100644 --- a/fern/versions/latest/pages/code_reference/config/models.mdx +++ b/fern/versions/latest/pages/code_reference/config/models.mdx @@ -3,7 +3,7 @@ title: "Models" description: "" position: 1 --- -The `models` module defines configuration objects for model-based generation. [ModelProvider](#data_designer.config.models.ModelProvider) specifies connection and authentication details for custom providers. [ModelConfig](#data_designer.config.models.ModelConfig) encapsulates model details including the model alias, identifier, and inference parameters. 
[Inference Parameters](/concepts/models/inference-parameters) controls model behavior through settings like `temperature`, `top_p`, and `max_tokens`, with support for both fixed values and distribution-based sampling. The module includes [ImageContext](#data_designer.config.models.ImageContext) for providing image inputs to multimodal models, and [ImageInferenceParams](#data_designer.config.models.ImageInferenceParams) for configuring image generation models. +The `models` module defines configuration objects for model-based generation. [ModelProvider](#data_designer.config.models.ModelProvider) specifies connection and authentication details for custom providers. [ModelConfig](#data_designer.config.models.ModelConfig) encapsulates model details including the model alias, identifier, and inference parameters. [Inference Parameters](/concepts/models/inference-parameters) controls model behavior through settings like `temperature`, `top_p`, and `max_tokens`, with support for both fixed values and distribution-based sampling. The module includes [ImageContext](#data_designer.config.models.ImageContext), [AudioContext](#data_designer.config.models.AudioContext), and [VideoContext](#data_designer.config.models.VideoContext) for providing image, audio, and video inputs to multimodal models, and [ImageInferenceParams](#data_designer.config.models.ImageInferenceParams) for configuring image generation models. 
For more information on how they are used, see below: diff --git a/packages/data-designer-config/src/data_designer/config/models.py b/packages/data-designer-config/src/data_designer/config/models.py index bd614c66e..80e3752e5 100644 --- a/packages/data-designer-config/src/data_designer/config/models.py +++ b/packages/data-designer-config/src/data_designer/config/models.py @@ -199,10 +199,9 @@ def _resolve_base64_parts(self, context_value: Any) -> tuple[str, Any]: detected_format = audio_format_from_mime_type(media_type) if detected_format is None: raise ValueError(f"Unsupported audio media type {media_type!r}") - audio_format = self.audio_format or detected_format - if audio_format != detected_format: + if self.audio_format is not None and self.audio_format != detected_format: raise ValueError( - f"audio_format {audio_format.value!r} does not match data URI media type {media_type!r}" + f"audio_format {self.audio_format.value!r} does not match data URI media type {media_type!r}" ) return media_type, data diff --git a/packages/data-designer-config/src/data_designer/config/utils/media_helpers.py b/packages/data-designer-config/src/data_designer/config/utils/media_helpers.py index 4af00d872..98d40a256 100644 --- a/packages/data-designer-config/src/data_designer/config/utils/media_helpers.py +++ b/packages/data-designer-config/src/data_designer/config/utils/media_helpers.py @@ -106,16 +106,6 @@ def is_media_url(value: str) -> bool: return isinstance(value, str) and value.startswith(("http://", "https://")) -def is_audio_url(value: str) -> bool: - """Return whether a value looks like an audio URL.""" - return is_media_url(value) and _has_media_extension(value, SUPPORTED_AUDIO_EXTENSIONS) - - -def is_video_url(value: str) -> bool: - """Return whether a value looks like a video URL.""" - return is_media_url(value) and _has_media_extension(value, SUPPORTED_VIDEO_EXTENSIONS) - - def is_audio_path(value: str) -> bool: """Return whether a value looks like a local audio 
path.""" return _has_path_extension(value, SUPPORTED_AUDIO_EXTENSIONS) @@ -146,12 +136,6 @@ def video_format_from_mime_type(media_type: str) -> VideoFormat | None: return _VIDEO_MIME_TYPE_TO_FORMAT.get(media_type.lower()) -def _has_media_extension(value: str, supported_extensions: tuple[str, ...]) -> bool: - if not isinstance(value, str): - return False - return any(ext in value.lower() for ext in supported_extensions) - - def _has_path_extension(value: str, supported_extensions: tuple[str, ...]) -> bool: if not isinstance(value, str): return False diff --git a/packages/data-designer-config/tests/config/utils/test_media_helpers.py b/packages/data-designer-config/tests/config/utils/test_media_helpers.py index b762c6654..e5751a55b 100644 --- a/packages/data-designer-config/tests/config/utils/test_media_helpers.py +++ b/packages/data-designer-config/tests/config/utils/test_media_helpers.py @@ -15,10 +15,8 @@ get_media_context, get_media_url_context, is_audio_path, - is_audio_url, is_media_url, is_video_path, - is_video_url, normalize_media_context_values, parse_base64_data_uri, video_format_from_mime_type, @@ -54,20 +52,11 @@ def test_parse_base64_data_uri() -> None: assert parse_base64_data_uri("abc123") is None -def test_audio_url_detection() -> None: +def test_media_url_detection() -> None: assert is_media_url("https://example.com/download?id=123") is True - assert is_audio_url("https://example.com/audio.mp3") is True - assert is_audio_url("https://example.com/audio.wav?download=1") is True - assert is_audio_url("https://example.com/image.png") is False - assert is_audio_url(123) is False # type: ignore[arg-type] - - -def test_video_url_detection() -> None: - assert is_media_url("https://example.com/download?id=123") is True - assert is_video_url("https://example.com/video.mp4") is True - assert is_video_url("https://example.com/video.webm?download=1") is True - assert is_video_url("https://example.com/audio.mp3") is False - assert is_video_url(123) is False # 
type: ignore[arg-type] + assert is_media_url("http://example.com/media") is True + assert is_media_url("ftp://example.com/media") is False + assert is_media_url(123) is False # type: ignore[arg-type] def test_local_media_path_detection() -> None: diff --git a/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/anthropic_translation.py b/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/anthropic_translation.py index ef6d1dcf4..62dc1d4e7 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/anthropic_translation.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/anthropic_translation.py @@ -4,11 +4,14 @@ from __future__ import annotations import json -import re from typing import Any from data_designer.config.models import Modality -from data_designer.config.utils.media_helpers import get_media_base64_context, get_media_url_context +from data_designer.config.utils.media_helpers import ( + get_media_base64_context, + get_media_url_context, + parse_base64_data_uri, +) from data_designer.engine.models.clients.parsing import extract_usage, fill_reasoning_token_count_from_content from data_designer.engine.models.clients.types import ( AssistantMessage, @@ -19,7 +22,6 @@ ) _DEFAULT_MAX_TOKENS = 4096 -_DATA_URI_RE = re.compile(r"^data:(?P<media_type>[^;]+);base64,(?P<data>.+)$") _UNSUPPORTED_MEDIA_BLOCK_MODALITIES: dict[str, str] = { "audio": "audio", "audio_url": "audio", @@ -350,9 +352,10 @@ def translate_image_url_block(block: dict[str, Any]) -> dict[str, Any]: url = image_url.get("url", "") - match = _DATA_URI_RE.match(url) - if match: - return get_media_base64_context(Modality.IMAGE.value, match.group("media_type"), match.group("data")) + parsed = parse_base64_data_uri(url) + if parsed is not None: + media_type, data = parsed + return get_media_base64_context(Modality.IMAGE.value, media_type, data) return get_media_url_context(Modality.IMAGE.value, url)
From 58058a59ee43dcae82c0c86c6fb1614e756c8ca8 Mon Sep 17 00:00:00 2001 From: Nabin Mulepati Date: Fri, 22 May 2026 09:00:03 -0600 Subject: [PATCH 08/15] docs: update media context guidance --- .../4-providing-images-as-context.ipynb | 40 ++++++++++++++++++- .../4-providing-images-as-context.py | 35 +++++++++++++++- docs/notebook_source/_README.md | 1 + .../pages/code_reference/config/models.mdx | 4 +- .../latest/pages/concepts/columns.mdx | 6 ++- .../models/default-model-settings.mdx | 12 ++++-- .../pages/concepts/models/model-configs.mdx | 8 ++-- .../4-providing-images-as-context.mdx | 2 +- .../latest/pages/notebooks/README.mdx | 2 +- 9 files changed, 96 insertions(+), 14 deletions(-) diff --git a/docs/colab_notebooks/4-providing-images-as-context.ipynb b/docs/colab_notebooks/4-providing-images-as-context.ipynb index ba225ac0a..5accf02e4 100644 --- a/docs/colab_notebooks/4-providing-images-as-context.ipynb +++ b/docs/colab_notebooks/4-providing-images-as-context.ipynb @@ -24,9 +24,11 @@ "#### 📚 What you'll learn\n", "\n", "This notebook demonstrates how to provide images as context to generate text descriptions using vision-language models.\n", + "The same `multi_modal_context` field can also carry audio or video context when the selected model supports those modalities.\n", "\n", "- ✨ **Visual Document Processing**: Converting images to chat-ready format for model consumption\n", "- 🔍 **Vision-Language Generation**: Using vision models to generate detailed summaries from images\n", + "- 🧩 **Media Context Pattern**: Understanding how `ImageContext`, `AudioContext`, and `VideoContext` fit into the same configuration field\n", "\n", "If this is your first time using Data Designer, we recommend starting with the [first notebook](https://nvidia-nemo.github.io/DataDesigner/latest/notebooks/1-the-basics/) in this tutorial series.\n" ] @@ -268,6 +270,42 @@ "config_builder.with_seed_dataset(dd.DataFrameSeedSource(df=df_seed))" ] }, + { + "cell_type": "markdown", +
"id": "media-context-capabilities", + "metadata": {}, + "source": [ + "### 🧩 Media context and model capabilities\n", + "\n", + "`multi_modal_context` accepts media context descriptors such as `ImageContext`, `AudioContext`, and `VideoContext`. Data Designer reads the referenced seed columns and serializes them for the model request, but the selected model still determines which modalities are valid.\n", + "\n", + "This notebook uses image context only because image-capable VLMs are broadly available. Before combining image, audio, and video in one column, choose a model alias backed by an omni or otherwise modality-compatible model, and check that the provider accepts every context type you send.\n", + "\n", + "For base64 seed columns, store the raw base64 payload without a `data:;base64,` prefix and specify the media format on the context object:\n", + "\n", + "```python\n", + "media_context = [\n", + "    dd.ImageContext(\n", + "        column_name=\"image_base64\",\n", + "        data_type=dd.ModalityDataType.BASE64,\n", + "        image_format=dd.ImageFormat.PNG,\n", + "    ),\n", + "    dd.AudioContext(\n", + "        column_name=\"audio_base64\",\n", + "        data_type=dd.ModalityDataType.BASE64,\n", + "        audio_format=dd.AudioFormat.MP3,\n", + "    ),\n", + "    dd.VideoContext(\n", + "        column_name=\"video_base64\",\n", + "        data_type=dd.ModalityDataType.BASE64,\n", + "        video_format=dd.VideoFormat.MP4,\n", + "    ),\n", + "]\n", + "```\n", + "\n", + "URL-backed media can use `data_type=dd.ModalityDataType.URL`, subject to the provider's URL support and file-size limits."
+ ] + }, { "cell_type": "code", "execution_count": null, @@ -456,7 +494,7 @@ "\n", "- Experiment with different vision models for specific image types\n", "- Try different prompt variations to generate specialized descriptions (e.g., technical details, key findings)\n", - "- Combine vision-based descriptions with other column types for multi-modal workflows\n", + "- Combine image, audio, or video context with other column types after confirming your selected model supports those modalities\n", "- Apply this pattern to other vision tasks like image captioning, OCR validation, or visual question answering\n", "\n", "- [Generating images](https://nvidia-nemo.github.io/DataDesigner/latest/notebooks/5-generating-images/) with Data Designer\n" diff --git a/docs/notebook_source/4-providing-images-as-context.py b/docs/notebook_source/4-providing-images-as-context.py index 7d849c89a..6e7526e53 100644 --- a/docs/notebook_source/4-providing-images-as-context.py +++ b/docs/notebook_source/4-providing-images-as-context.py @@ -19,9 +19,11 @@ # #### 📚 What you'll learn # # This notebook demonstrates how to provide images as context to generate text descriptions using vision-language models. +# The same `multi_modal_context` field can also carry audio or video context when the selected model supports those modalities. # # - ✨ **Visual Document Processing**: Converting images to chat-ready format for model consumption # - 🔍 **Vision-Language Generation**: Using vision models to generate detailed summaries from images +# - 🧩 **Media Context Pattern**: Understanding how `ImageContext`, `AudioContext`, and `VideoContext` fit into the same configuration field # # If this is your first time using Data Designer, we recommend starting with the [first notebook](https://nvidia-nemo.github.io/DataDesigner/latest/notebooks/1-the-basics/) in this tutorial series.
# @@ -153,6 +155,37 @@ def convert_image_to_chat_format(record, height: int) -> dict: df_seed = pd.DataFrame(img_dataset)[["uuid", "label", "base64_image"]] config_builder.with_seed_dataset(dd.DataFrameSeedSource(df=df_seed)) +# %% [markdown] +# ### 🧩 Media context and model capabilities +# +# `multi_modal_context` accepts media context descriptors such as `ImageContext`, `AudioContext`, and `VideoContext`. Data Designer reads the referenced seed columns and serializes them for the model request, but the selected model still determines which modalities are valid. +# +# This notebook uses image context only because image-capable VLMs are broadly available. Before combining image, audio, and video in one column, choose a model alias backed by an omni or otherwise modality-compatible model, and check that the provider accepts every context type you send. +# +# For base64 seed columns, store the raw base64 payload without a `data:;base64,` prefix and specify the media format on the context object: +# +# ```python +# media_context = [ +#     dd.ImageContext( +#         column_name="image_base64", +#         data_type=dd.ModalityDataType.BASE64, +#         image_format=dd.ImageFormat.PNG, +#     ), +#     dd.AudioContext( +#         column_name="audio_base64", +#         data_type=dd.ModalityDataType.BASE64, +#         audio_format=dd.AudioFormat.MP3, +#     ), +#     dd.VideoContext( +#         column_name="video_base64", +#         data_type=dd.ModalityDataType.BASE64, +#         video_format=dd.VideoFormat.MP4, +#     ), +# ] +# ``` +# +# URL-backed media can use `data_type=dd.ModalityDataType.URL`, subject to the provider's URL support and file-size limits.
+ # %% # Add a column to generate detailed image descriptions config_builder.add_column( @@ -257,7 +290,7 @@ def convert_image_to_chat_format(record, height: int) -> dict: # # - Experiment with different vision models for specific image types # - Try different prompt variations to generate specialized descriptions (e.g., technical details, key findings) -# - Combine vision-based descriptions with other column types for multi-modal workflows +# - Combine image, audio, or video context with other column types after confirming your selected model supports those modalities # - Apply this pattern to other vision tasks like image captioning, OCR validation, or visual question answering # # - [Generating images](https://nvidia-nemo.github.io/DataDesigner/latest/notebooks/5-generating-images/) with Data Designer diff --git a/docs/notebook_source/_README.md b/docs/notebook_source/_README.md index 97bcdf8cb..68b459eec 100644 --- a/docs/notebook_source/_README.md +++ b/docs/notebook_source/_README.md @@ -95,6 +95,7 @@ Learn how to use vision-language models to generate text descriptions from image - Processing and converting images to base64 format for model consumption - Using vision-language models (VLMs) to analyze visual documents +- Understanding how image, audio, and video context share the same `multi_modal_context` field, while still requiring model support for each modality - Generating detailed summaries from document images - Inspecting and validating vision-based generation results diff --git a/fern/versions/latest/pages/code_reference/config/models.mdx b/fern/versions/latest/pages/code_reference/config/models.mdx index 7ac0c09b6..7f1f64adf 100644 --- a/fern/versions/latest/pages/code_reference/config/models.mdx +++ b/fern/versions/latest/pages/code_reference/config/models.mdx @@ -5,9 +5,11 @@ position: 1 --- The `models` module defines configuration objects for model-based generation. 
[ModelProvider](#data_designer.config.models.ModelProvider) specifies connection and authentication details for custom providers. [ModelConfig](#data_designer.config.models.ModelConfig) encapsulates model details including the model alias, identifier, and inference parameters. [Inference Parameters](/concepts/models/inference-parameters) controls model behavior through settings like `temperature`, `top_p`, and `max_tokens`, with support for both fixed values and distribution-based sampling. The module includes [ImageContext](#data_designer.config.models.ImageContext), [AudioContext](#data_designer.config.models.AudioContext), and [VideoContext](#data_designer.config.models.VideoContext) for providing image, audio, and video inputs to multimodal models, and [ImageInferenceParams](#data_designer.config.models.ImageInferenceParams) for configuring image generation models. +`ImageContext`, `AudioContext`, and `VideoContext` describe the media blocks that Data Designer should send. They do not override provider limitations: the selected model must support every modality and media format included in a column's `multi_modal_context`. + For more information on how they are used, see below: - **[Model Providers](/concepts/models/model-providers)** - **[Model Configurations](/concepts/models/model-configs)** -- **[Image Context](/tutorials/providing-images-as-context)** +- **[Image and Media Context](/tutorials/providing-images-as-context)** - **[Generating Images](/tutorials/generating-images)** diff --git a/fern/versions/latest/pages/concepts/columns.mdx b/fern/versions/latest/pages/concepts/columns.mdx index 9f4e82527..fff571a98 100644 --- a/fern/versions/latest/pages/concepts/columns.mdx +++ b/fern/versions/latest/pages/concepts/columns.mdx @@ -45,6 +45,8 @@ LLM-Text columns generate natural language text: product descriptions, customer Use **Jinja2 templating** in prompts to reference other columns. 
Data Designer automatically manages dependencies and injects the referenced column values into the prompt. +LLM-Text and LLM-Structured columns can also include `multi_modal_context` with `ImageContext`, `AudioContext`, or `VideoContext`. Data Designer reads the referenced seed columns and serializes the media blocks, but it does not make an image-only model understand audio or video. Choose a `model_alias` whose underlying provider/model supports every modality in the column. + Generation Traces LLM columns can optionally capture message traces in a separate `{column_name}__trace` column. Set `with_trace` on the column config to control what's captured: `TraceType.NONE` (default, no trace), `TraceType.LAST_MESSAGE` (final assistant message only), or `TraceType.ALL_MESSAGES` (full conversation history). The trace includes the ordered message history for the final generation attempt (system/user/assistant/tool calls/tool results), and may include model reasoning fields when the provider exposes them. @@ -126,11 +128,11 @@ Image columns require a model configured with `ImageInferenceParams`. Model-spec - **Preview** (`data_designer.preview()`): Images are stored as base64-encoded strings directly in the DataFrame for quick iteration - **Create** (`data_designer.create()`): Images are saved to disk in an `images//` folder with UUID filenames; the DataFrame stores relative paths -Image columns also support `multi_modal_context` for autoregressive models that accept image inputs, enabling image-to-image generation workflows. +Image columns also support `multi_modal_context` for autoregressive multimodal models that accept media inputs, enabling image-to-image and other media-conditioned image generation workflows. Diffusion image-generation routes do not consume multimodal context, and not every autoregressive image model accepts every media type. 
Tutorials -The image tutorials cover three workflows: [Providing Images as Context](/tutorials/providing-images-as-context) (image → text), [Generating Images](/tutorials/generating-images) (text → image), and [Editing Images with Image Context](/tutorials/image-to-image-editing) (image → image). +The image tutorials cover three workflows: [Providing Images as Context](/tutorials/providing-images-as-context) (image → text, with notes on audio/video-capable models), [Generating Images](/tutorials/generating-images) (text → image), and [Editing Images with Image Context](/tutorials/image-to-image-editing) (image → image). ### 🧬 Embedding Columns diff --git a/fern/versions/latest/pages/concepts/models/default-model-settings.mdx b/fern/versions/latest/pages/concepts/models/default-model-settings.mdx index 803573770..5b960d951 100644 --- a/fern/versions/latest/pages/concepts/models/default-model-settings.mdx +++ b/fern/versions/latest/pages/concepts/models/default-model-settings.mdx @@ -48,7 +48,7 @@ The following model configurations are automatically available when `NVIDIA_API_ |-------|-------|----------|---------------------| | `nvidia-text` | `nvidia/nemotron-3-nano-30b-a3b` | General text generation | `temperature=1.0, top_p=1.0` | | `nvidia-reasoning` | `nvidia/nemotron-3-super-120b-a12b` | Reasoning and analysis tasks | `temperature=1.0, top_p=0.95, extra_body={"reasoning_effort": "medium"}` | -| `nvidia-vision` | `nvidia/nemotron-nano-12b-v2-vl` | Vision and image understanding | `temperature=0.85, top_p=0.95` | +| `nvidia-vision` | `nvidia/nemotron-3-nano-omni-30b-a3b-reasoning` | Omni multimodal understanding for image, audio, and video inputs | `temperature=0.60, top_p=0.95` | | `nvidia-embedding` | `nvidia/llama-3.2-nv-embedqa-1b-v2` | Text embeddings | `encoding_format="float", extra_body={"input_type": "query"}` | @@ -59,8 +59,8 @@ The following model configurations are automatically available when `OPENAI_API_ | Alias | Model | Use Case |
Inference Parameters | |-------|-------|----------|---------------------| | `openai-text` | `gpt-4.1` | General text generation | `temperature=0.85, top_p=0.95` | -| `openai-reasoning` | `gpt-5` | Reasoning and analysis tasks | `temperature=0.35, top_p=0.95` | -| `openai-vision` | `gpt-5` | Vision and image understanding | `temperature=0.85, top_p=0.95` | +| `openai-reasoning` | `gpt-5` | Reasoning and analysis tasks | `extra_body={"reasoning_effort": "medium"}` | +| `openai-vision` | `gpt-5` | Vision and image understanding | `extra_body={"reasoning_effort": "medium"}` | | `openai-embedding` | `text-embedding-3-large` | Text embeddings | `encoding_format="float"` | ### OpenRouter Models @@ -71,9 +71,13 @@ The following model configurations are automatically available when `OPENROUTER_ |-------|-------|----------|---------------------| | `openrouter-text` | `nvidia/nemotron-3-nano-30b-a3b` | General text generation | `temperature=1.0, top_p=1.0` | | `openrouter-reasoning` | `openai/gpt-oss-20b` | Reasoning and analysis tasks | `temperature=0.35, top_p=0.95` | -| `openrouter-vision` | `nvidia/nemotron-3-nano-omni-30b-a3b-reasoning:free` | Vision and image understanding | `temperature=0.60, top_p=0.95` | +| `openrouter-vision` | `nvidia/nemotron-3-nano-omni-30b-a3b-reasoning:free` | Omni multimodal understanding for image, audio, and video inputs, subject to OpenRouter model support | `temperature=0.60, top_p=0.95` | | `openrouter-embedding` | `openai/text-embedding-3-large` | Text embeddings | `encoding_format="float"` | + + The `multi_modal_context` field can include image, audio, and video contexts, but each model/provider combination has its own accepted input formats, media-size limits, and modality mix. Use an image-capable model for image-only workflows, and use an omni or otherwise multimodal model before sending audio or video context. 
+ + ## Using Default Settings diff --git a/fern/versions/latest/pages/concepts/models/model-configs.mdx b/fern/versions/latest/pages/concepts/models/model-configs.mdx index a784b5746..23e3b2a4b 100644 --- a/fern/versions/latest/pages/concepts/models/model-configs.mdx +++ b/fern/versions/latest/pages/concepts/models/model-configs.mdx @@ -9,6 +9,8 @@ Model configurations define the specific models you use for synthetic data gener A `ModelConfig` specifies which LLM model to use and how it should behave during generation. When you create column configurations (like `LLMText`, `LLMCode`, or `LLMStructured`), you reference a model by its alias. Data Designer uses the model configuration to determine which model to call and with what parameters. +When a column includes `multi_modal_context`, the `ModelConfig` alias must point to a model that supports the media types you send. Data Designer can serialize image, audio, and video context blocks, but model capability is still provider-specific. + ## ModelConfig Structure The `ModelConfig` class has the following fields: @@ -81,13 +83,13 @@ model_configs = [ max_tokens=4096, ), ), - # Vision tasks + # Omni multimodal tasks dd.ModelConfig( alias="vision-model", - model="nvidia/nemotron-nano-12b-v2-vl", + model="nvidia/nemotron-3-nano-omni-30b-a3b-reasoning", provider="nvidia", inference_parameters=dd.ChatCompletionInferenceParams( - temperature=0.7, + temperature=0.60, top_p=0.95, max_tokens=2048, ), diff --git a/fern/versions/latest/pages/notebooks/4-providing-images-as-context.mdx b/fern/versions/latest/pages/notebooks/4-providing-images-as-context.mdx index 2913480cd..0672eee7d 100644 --- a/fern/versions/latest/pages/notebooks/4-providing-images-as-context.mdx +++ b/fern/versions/latest/pages/notebooks/4-providing-images-as-context.mdx @@ -1,6 +1,6 @@ --- title: "Providing Images as Context" -description: "Multimodal prompts with image inputs." 
+description: "Multimodal prompts with image inputs and notes for audio/video-capable models." position: 5 --- diff --git a/fern/versions/latest/pages/notebooks/README.mdx b/fern/versions/latest/pages/notebooks/README.mdx index 6a4b80054..bad4af083 100644 --- a/fern/versions/latest/pages/notebooks/README.mdx +++ b/fern/versions/latest/pages/notebooks/README.mdx @@ -11,6 +11,6 @@ These tutorials walk through Data Designer end-to-end with executable Jupyter no | [The Basics](/tutorials/the-basics) | Declare columns, generate your first dataset | | [Structured Outputs, Jinja Expressions, and Conditional Generation](/tutorials/structured-outputs-jinja-expressions-and-conditional-generation) | Schema-constrained outputs and dynamic prompts | | [Seeding with an External Dataset](/tutorials/seeding-with-an-external-dataset) | Use existing data as input for generation | -| [Providing Images as Context](/tutorials/providing-images-as-context) | Multimodal prompts with image inputs | +| [Providing Images as Context](/tutorials/providing-images-as-context) | Multimodal prompts with image inputs, plus the media-context pattern for models that support audio or video | | [Generating Images](/tutorials/generating-images) | Create image columns from text prompts | | [Image-to-Image Editing](/tutorials/image-to-image-editing) | Edit images using image context | From de841c28388687885c16a1642f7b938f4c8a4369 Mon Sep 17 00:00:00 2001 From: Nabin Mulepati Date: Fri, 22 May 2026 09:19:20 -0600 Subject: [PATCH 09/15] refactor: consolidate media helpers --- .../src/data_designer/config/__init__.py | 5 +- .../src/data_designer/config/models.py | 14 +- .../config/utils/image_helpers.py | 295 ---------------- .../config/utils/media_helpers.py | 178 +++++++++- .../config/utils/visualization.py | 2 +- .../tests/config/test_models.py | 17 + .../tests/config/utils/test_image_helpers.py | 324 ------------------ .../tests/config/utils/test_media_helpers.py | 268 +++++++++++++++ 
.../src/data_designer/engine/mcp/io.py | 2 +- .../engine/models/clients/parsing.py | 2 +- .../src/data_designer/engine/models/facade.py | 2 +- .../engine/storage/media_storage.py | 2 +- 12 files changed, 472 insertions(+), 639 deletions(-) delete mode 100644 packages/data-designer-config/src/data_designer/config/utils/image_helpers.py delete mode 100644 packages/data-designer-config/tests/config/utils/test_image_helpers.py diff --git a/packages/data-designer-config/src/data_designer/config/__init__.py b/packages/data-designer-config/src/data_designer/config/__init__.py index cf2a4c4d4..65c5cd17e 100644 --- a/packages/data-designer-config/src/data_designer/config/__init__.py +++ b/packages/data-designer-config/src/data_designer/config/__init__.py @@ -107,9 +107,8 @@ ) from data_designer.config.seed_source_dataframe import DataFrameSeedSource # noqa: F401 from data_designer.config.utils.code_lang import CodeLang # noqa: F401 - from data_designer.config.utils.image_helpers import ImageFormat # noqa: F401 from data_designer.config.utils.info import InfoType # noqa: F401 - from data_designer.config.utils.media_helpers import AudioFormat, VideoFormat # noqa: F401 + from data_designer.config.utils.media_helpers import AudioFormat, ImageFormat, VideoFormat # noqa: F401 from data_designer.config.utils.trace_type import TraceType # noqa: F401 from data_designer.config.validator_params import ( # noqa: F401 CodeValidatorParams, @@ -171,7 +170,7 @@ "EmbeddingInferenceParams": (_MOD_MODELS, "EmbeddingInferenceParams"), "GenerationType": (_MOD_MODELS, "GenerationType"), "ImageContext": (_MOD_MODELS, "ImageContext"), - "ImageFormat": (f"{_MOD_UTILS}.image_helpers", "ImageFormat"), + "ImageFormat": (f"{_MOD_UTILS}.media_helpers", "ImageFormat"), "ImageInferenceParams": (_MOD_MODELS, "ImageInferenceParams"), "ManualDistribution": (_MOD_MODELS, "ManualDistribution"), "ManualDistributionParams": (_MOD_MODELS, "ManualDistributionParams"), diff --git 
a/packages/data-designer-config/src/data_designer/config/models.py b/packages/data-designer-config/src/data_designer/config/models.py index 80e3752e5..93be1935d 100644 --- a/packages/data-designer-config/src/data_designer/config/models.py +++ b/packages/data-designer-config/src/data_designer/config/models.py @@ -21,25 +21,23 @@ MIN_TEMPERATURE, MIN_TOP_P, ) -from data_designer.config.utils.image_helpers import ( - ImageFormat, - decode_base64_image, - detect_image_format, - is_image_path, - is_image_url, - load_image_path_to_base64, -) from data_designer.config.utils.io_helpers import smart_load_yaml from data_designer.config.utils.media_helpers import ( AudioFormat, + ImageFormat, VideoFormat, audio_format_from_mime_type, audio_mime_type, + decode_base64_image, + detect_image_format, get_media_base64_context, get_media_url_context, is_audio_path, + is_image_path, + is_image_url, is_media_url, is_video_path, + load_image_path_to_base64, normalize_media_context_values, parse_base64_data_uri, video_format_from_mime_type, diff --git a/packages/data-designer-config/src/data_designer/config/utils/image_helpers.py b/packages/data-designer-config/src/data_designer/config/utils/image_helpers.py deleted file mode 100644 index 934be5b43..000000000 --- a/packages/data-designer-config/src/data_designer/config/utils/image_helpers.py +++ /dev/null @@ -1,295 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-# SPDX-License-Identifier: Apache-2.0 - -"""Helper utilities for working with images.""" - -from __future__ import annotations - -import base64 -import io -import re -from pathlib import Path - -import requests - -import data_designer.lazy_heavy_imports as lazy -from data_designer.config.utils.type_helpers import StrEnum - - -class ImageFormat(StrEnum): - """Supported image formats for image modality.""" - - PNG = "png" - JPG = "jpg" - JPEG = "jpeg" - GIF = "gif" - WEBP = "webp" - - -# Magic bytes for image format detection -IMAGE_FORMAT_MAGIC_BYTES = { - ImageFormat.PNG: b"\x89PNG\r\n\x1a\n", - ImageFormat.JPG: b"\xff\xd8\xff", - ImageFormat.GIF: b"GIF8", - # WEBP uses RIFF header - handled separately -} - -# Maps PIL format name (lowercase) to our ImageFormat enum. -# PIL reports "JPEG" (not "JPG"), so we normalize it here. -_PIL_FORMAT_TO_IMAGE_FORMAT: dict[str, ImageFormat] = { - "png": ImageFormat.PNG, - "jpeg": ImageFormat.JPG, - "jpg": ImageFormat.JPG, - "gif": ImageFormat.GIF, - "webp": ImageFormat.WEBP, -} - -_BASE64_PATTERN = re.compile(r"^[A-Za-z0-9+/=]+$") - -# Patterns for diffusion-based image models only (use image_generation API). -IMAGE_DIFFUSION_MODEL_PATTERNS = ( - "dall-e-", - "dalle", - "stable-diffusion", - "sd-", - "sd_", - "imagen", - "gpt-image-", -) - -SUPPORTED_IMAGE_EXTENSIONS = [f".{fmt.value.lower()}" for fmt in ImageFormat] - - -def is_image_diffusion_model(model_name: str) -> bool: - """Return True if the model is a diffusion-based image generation model. - - Args: - model_name: Model name or identifier (e.g. from provider). - - Returns: - True if the model is detected as diffusion-based, False otherwise. - """ - return any(pattern in model_name.lower() for pattern in IMAGE_DIFFUSION_MODEL_PATTERNS) - - -def extract_base64_from_data_uri(data: str) -> str: - """Extract base64 from data URI or return as-is. - - Handles data URIs like "data:image/png;base64,iVBORw0..." and returns - just the base64 portion. 
- - Args: - data: Data URI (e.g., "data:image/png;base64,XXX") or plain base64 - - Returns: - Base64 string without data URI prefix - - Raises: - ValueError: If data URI format is invalid - """ - if data.startswith("data:"): - if "," in data: - return data.split(",", 1)[1] - raise ValueError("Invalid data URI format: missing comma separator") - return data - - -def decode_base64_image(base64_data: str) -> bytes: - """Decode base64 string to image bytes. - - Automatically handles data URIs by extracting the base64 portion first. - - Args: - base64_data: Base64 string (with or without data URI prefix) - - Returns: - Decoded image bytes - - Raises: - ValueError: If base64 data is invalid - """ - # Remove data URI prefix if present - base64_data = extract_base64_from_data_uri(base64_data) - - try: - return base64.b64decode(base64_data, validate=True) - except Exception as e: - raise ValueError(f"Invalid base64 data: {e}") from e - - -def detect_image_format(image_bytes: bytes) -> ImageFormat: - """Detect image format from bytes. - - Uses magic bytes for fast detection, falls back to PIL for robust detection. 
- - Args: - image_bytes: Image data as bytes - - Returns: - Detected ImageFormat - - Raises: - ValueError: If the image format cannot be determined - """ - # Check magic bytes first (fast) - if image_bytes.startswith(IMAGE_FORMAT_MAGIC_BYTES[ImageFormat.PNG]): - return ImageFormat.PNG - elif image_bytes.startswith(IMAGE_FORMAT_MAGIC_BYTES[ImageFormat.JPG]): - return ImageFormat.JPG - elif image_bytes.startswith(IMAGE_FORMAT_MAGIC_BYTES[ImageFormat.GIF]): - return ImageFormat.GIF - elif image_bytes.startswith(b"RIFF") and b"WEBP" in image_bytes[:12]: - return ImageFormat.WEBP - - # Fallback to PIL for robust detection - try: - img = lazy.Image.open(io.BytesIO(image_bytes)) - format_str = img.format.lower() if img.format else None - if format_str in _PIL_FORMAT_TO_IMAGE_FORMAT: - return _PIL_FORMAT_TO_IMAGE_FORMAT[format_str] - except Exception: - pass - - raise ValueError( - f"Unable to detect image format (first 8 bytes: {image_bytes[:8]!r}). " - f"Supported formats: {', '.join(SUPPORTED_IMAGE_EXTENSIONS)}." - ) - - -def is_image_path(value: str) -> bool: - """Check if a string is an image file path. - - Args: - value: String to check - - Returns: - True if the string looks like an image file path, False otherwise - """ - if not isinstance(value, str): - return False - return any(value.lower().endswith(ext) for ext in SUPPORTED_IMAGE_EXTENSIONS) - - -def is_base64_image(value: str) -> bool: - """Check if a string is base64-encoded image data. 
- - Args: - value: String to check - - Returns: - True if the string looks like base64-encoded image data, False otherwise - """ - if not isinstance(value, str): - return False - # Check if it starts with data URI scheme - if value.startswith("data:image/"): - return True - # Check if it looks like base64 (at least 100 chars, contains only base64 chars) - if len(value) > 100 and _BASE64_PATTERN.match(value[:100]): - try: - # Try to decode a small portion to verify it's valid base64 - base64.b64decode(value[:100]) - return True - except Exception: - return False - return False - - -def is_image_url(value: str) -> bool: - """Check if a string is an image URL. - - Args: - value: String to check - - Returns: - True if the string looks like an image URL, False otherwise - """ - if not isinstance(value, str): - return False - return value.startswith(("http://", "https://")) and any(ext in value.lower() for ext in SUPPORTED_IMAGE_EXTENSIONS) - - -def load_image_path_to_base64(image_path: str, base_path: str | None = None) -> str | None: - """Load an image from a file path and return as base64. - - Args: - image_path: Relative or absolute path to the image file. - base_path: Optional base path to resolve relative paths from. - - Returns: - Base64-encoded image data or None if loading fails. - """ - try: - path = Path(image_path) - - # If path is not absolute, try to resolve it - if not path.is_absolute(): - if base_path: - path = Path(base_path) / path - # If still not found, try current working directory - if not path.exists(): - path = Path.cwd() / image_path - - # Check if file exists - if not path.exists(): - return None - - # Read image file and convert to base64 - with open(path, "rb") as f: - image_bytes = f.read() - return base64.b64encode(image_bytes).decode() - except Exception: - return None - - -def load_image_url_to_base64(url: str, timeout: int = 60) -> str: - """Download an image from a URL and return as base64. 
- - Args: - url: HTTP(S) URL pointing to an image. - timeout: Request timeout in seconds. - - Returns: - Base64-encoded image data. - - Raises: - requests.HTTPError: If the download fails with a non-2xx status. - """ - resp = requests.get(url, timeout=timeout) - resp.raise_for_status() - return base64.b64encode(resp.content).decode() - - -async def aload_image_url_to_base64(url: str, timeout: int = 60) -> str: - """Download an image from a URL asynchronously and return as base64. - - Args: - url: HTTP(S) URL pointing to an image. - timeout: Request timeout in seconds. - - Returns: - Base64-encoded image data. - - Raises: - httpx.HTTPStatusError: If the download fails with a non-2xx status. - """ - async with lazy.httpx.AsyncClient() as client: - resp = await client.get(url, timeout=timeout) - resp.raise_for_status() - return base64.b64encode(resp.content).decode() - - -def validate_image(image_path: Path) -> None: - """Validate that an image file is readable and not corrupted. - - Args: - image_path: Path to image file - - Raises: - ValueError: If image is corrupted or unreadable - """ - try: - with lazy.Image.open(image_path) as img: - img.verify() - except Exception as e: - raise ValueError(f"Image validation failed: {e}") from e diff --git a/packages/data-designer-config/src/data_designer/config/utils/media_helpers.py b/packages/data-designer-config/src/data_designer/config/utils/media_helpers.py index 98d40a256..e36e59d28 100644 --- a/packages/data-designer-config/src/data_designer/config/utils/media_helpers.py +++ b/packages/data-designer-config/src/data_designer/config/utils/media_helpers.py @@ -5,13 +5,29 @@ from __future__ import annotations +import base64 +import io import json import re +from pathlib import Path from typing import Any +import requests + +import data_designer.lazy_heavy_imports as lazy from data_designer.config.utils.type_helpers import StrEnum +class ImageFormat(StrEnum): + """Supported image formats for image modality.""" + + PNG = "png" 
+ JPG = "jpg" + JPEG = "jpeg" + GIF = "gif" + WEBP = "webp" + + class AudioFormat(StrEnum): """Supported audio formats for audio context.""" @@ -27,11 +43,40 @@ class VideoFormat(StrEnum): WEBM = "webm" -SUPPORTED_AUDIO_EXTENSIONS = tuple(f".{fmt.value.lower()}" for fmt in AudioFormat) -SUPPORTED_VIDEO_EXTENSIONS = tuple(f".{fmt.value.lower()}" for fmt in VideoFormat) +_SUPPORTED_IMAGE_EXTENSIONS = tuple(f".{fmt.value.lower()}" for fmt in ImageFormat) +_SUPPORTED_AUDIO_EXTENSIONS = tuple(f".{fmt.value.lower()}" for fmt in AudioFormat) +_SUPPORTED_VIDEO_EXTENSIONS = tuple(f".{fmt.value.lower()}" for fmt in VideoFormat) +_BASE64_PATTERN = re.compile(r"^[A-Za-z0-9+/=]+$") _DATA_URI_RE = re.compile(r"^data:(?P<mime_type>[^;]+);base64,(?P<data>.+)$") +_IMAGE_DIFFUSION_MODEL_PATTERNS = ( + "dall-e-", + "dalle", + "stable-diffusion", + "sd-", + "sd_", + "imagen", + "gpt-image-", +) + +_IMAGE_FORMAT_MAGIC_BYTES = { + ImageFormat.PNG: b"\x89PNG\r\n\x1a\n", + ImageFormat.JPG: b"\xff\xd8\xff", + ImageFormat.GIF: b"GIF8", + # WEBP uses RIFF header - handled separately +} + +# Maps PIL format name (lowercase) to our ImageFormat enum. +# PIL reports "JPEG" (not "JPG"), so we normalize it here. 
+_PIL_FORMAT_TO_IMAGE_FORMAT: dict[str, ImageFormat] = { + "png": ImageFormat.PNG, + "jpeg": ImageFormat.JPG, + "jpg": ImageFormat.JPG, + "gif": ImageFormat.GIF, + "webp": ImageFormat.WEBP, +} + _AUDIO_FORMAT_TO_MIME_TYPE: dict[AudioFormat, str] = { AudioFormat.MP3: "audio/mpeg", AudioFormat.WAV: "audio/wav", @@ -56,6 +101,131 @@ class VideoFormat(StrEnum): } +def is_image_diffusion_model(model_name: str) -> bool: + """Return True if the model is a diffusion-based image generation model.""" + return any(pattern in model_name.lower() for pattern in _IMAGE_DIFFUSION_MODEL_PATTERNS) + + +def extract_base64_from_data_uri(data: str) -> str: + """Extract base64 from data URI or return as-is.""" + if data.startswith("data:"): + if "," in data: + return data.split(",", 1)[1] + raise ValueError("Invalid data URI format: missing comma separator") + return data + + +def decode_base64_image(base64_data: str) -> bytes: + """Decode base64 string to image bytes.""" + base64_data = extract_base64_from_data_uri(base64_data) + + try: + return base64.b64decode(base64_data, validate=True) + except Exception as e: + raise ValueError(f"Invalid base64 data: {e}") from e + + +def detect_image_format(image_bytes: bytes) -> ImageFormat: + """Detect image format from bytes.""" + if image_bytes.startswith(_IMAGE_FORMAT_MAGIC_BYTES[ImageFormat.PNG]): + return ImageFormat.PNG + elif image_bytes.startswith(_IMAGE_FORMAT_MAGIC_BYTES[ImageFormat.JPG]): + return ImageFormat.JPG + elif image_bytes.startswith(_IMAGE_FORMAT_MAGIC_BYTES[ImageFormat.GIF]): + return ImageFormat.GIF + elif image_bytes.startswith(b"RIFF") and b"WEBP" in image_bytes[:12]: + return ImageFormat.WEBP + + try: + img = lazy.Image.open(io.BytesIO(image_bytes)) + format_str = img.format.lower() if img.format else None + if format_str in _PIL_FORMAT_TO_IMAGE_FORMAT: + return _PIL_FORMAT_TO_IMAGE_FORMAT[format_str] + except Exception: + pass + + raise ValueError( + f"Unable to detect image format (first 8 bytes: 
{image_bytes[:8]!r}). " + f"Supported formats: {', '.join(_SUPPORTED_IMAGE_EXTENSIONS)}." + ) + + +def is_image_path(value: str) -> bool: + """Check if a string is an image file path.""" + if not isinstance(value, str): + return False + return any(value.lower().endswith(ext) for ext in _SUPPORTED_IMAGE_EXTENSIONS) + + +def is_base64_image(value: str) -> bool: + """Check if a string is base64-encoded image data.""" + if not isinstance(value, str): + return False + if value.startswith("data:image/"): + return True + if len(value) > 100 and _BASE64_PATTERN.match(value[:100]): + try: + base64.b64decode(value[:100]) + return True + except Exception: + return False + return False + + +def is_image_url(value: str) -> bool: + """Check if a string is an image URL.""" + if not isinstance(value, str): + return False + return value.startswith(("http://", "https://")) and any( + ext in value.lower() for ext in _SUPPORTED_IMAGE_EXTENSIONS + ) + + +def load_image_path_to_base64(image_path: str, base_path: str | None = None) -> str | None: + """Load an image from a file path and return as base64.""" + try: + path = Path(image_path) + + if not path.is_absolute(): + if base_path: + path = Path(base_path) / path + if not path.exists(): + path = Path.cwd() / image_path + + if not path.exists(): + return None + + with open(path, "rb") as f: + image_bytes = f.read() + return base64.b64encode(image_bytes).decode() + except Exception: + return None + + +def load_image_url_to_base64(url: str, timeout: int = 60) -> str: + """Download an image from a URL and return as base64.""" + resp = requests.get(url, timeout=timeout) + resp.raise_for_status() + return base64.b64encode(resp.content).decode() + + +async def aload_image_url_to_base64(url: str, timeout: int = 60) -> str: + """Download an image from a URL asynchronously and return as base64.""" + async with lazy.httpx.AsyncClient() as client: + resp = await client.get(url, timeout=timeout) + resp.raise_for_status() + return 
base64.b64encode(resp.content).decode() + + +def validate_image(image_path: Path) -> None: + """Validate that an image file is readable and not corrupted.""" + try: + with lazy.Image.open(image_path) as img: + img.verify() + except Exception as e: + raise ValueError(f"Image validation failed: {e}") from e + + def get_media_context(modality: str, source: dict[str, Any]) -> dict[str, Any]: """Build a canonical media context block.""" return {"type": modality, "source": source} @@ -108,12 +278,12 @@ def is_media_url(value: str) -> bool: def is_audio_path(value: str) -> bool: """Return whether a value looks like a local audio path.""" - return _has_path_extension(value, SUPPORTED_AUDIO_EXTENSIONS) + return _has_path_extension(value, _SUPPORTED_AUDIO_EXTENSIONS) def is_video_path(value: str) -> bool: """Return whether a value looks like a local video path.""" - return _has_path_extension(value, SUPPORTED_VIDEO_EXTENSIONS) + return _has_path_extension(value, _SUPPORTED_VIDEO_EXTENSIONS) def audio_mime_type(audio_format: AudioFormat) -> str: diff --git a/packages/data-designer-config/src/data_designer/config/utils/visualization.py b/packages/data-designer-config/src/data_designer/config/utils/visualization.py index f18b33e53..2bd9772ab 100644 --- a/packages/data-designer-config/src/data_designer/config/utils/visualization.py +++ b/packages/data-designer-config/src/data_designer/config/utils/visualization.py @@ -37,7 +37,7 @@ TRACE_COLUMN_POSTFIX, ) from data_designer.config.utils.errors import DatasetSampleDisplayError -from data_designer.config.utils.image_helpers import ( +from data_designer.config.utils.media_helpers import ( extract_base64_from_data_uri, is_base64_image, is_image_path, diff --git a/packages/data-designer-config/tests/config/test_models.py b/packages/data-designer-config/tests/config/test_models.py index aa60c954d..74f69185c 100644 --- a/packages/data-designer-config/tests/config/test_models.py +++ 
b/packages/data-designer-config/tests/config/test_models.py @@ -14,6 +14,7 @@ import yaml from pydantic import ValidationError +import data_designer.config as dd import data_designer.lazy_heavy_imports as lazy from data_designer.config.errors import InvalidConfigError from data_designer.config.models import ( @@ -39,6 +40,22 @@ from data_designer.config.utils.media_helpers import get_media_base64_context, get_media_url_context +def test_media_context_exports_are_available_on_config_namespace() -> None: + assert dd.ImageContext is ImageContext + assert dd.AudioContext is AudioContext + assert dd.VideoContext is VideoContext + assert dd.ImageFormat is ImageFormat + assert dd.AudioFormat is AudioFormat + assert dd.VideoFormat is VideoFormat + + assert "ImageContext" in dd.__all__ + assert "ImageFormat" in dd.__all__ + assert "AudioContext" in dd.__all__ + assert "AudioFormat" in dd.__all__ + assert "VideoContext" in dd.__all__ + assert "VideoFormat" in dd.__all__ + + def test_image_context_get_contexts_single_string(): """Test get_contexts with a single string value.""" image_context = ImageContext( diff --git a/packages/data-designer-config/tests/config/utils/test_image_helpers.py b/packages/data-designer-config/tests/config/utils/test_image_helpers.py deleted file mode 100644 index e425582a2..000000000 --- a/packages/data-designer-config/tests/config/utils/test_image_helpers.py +++ /dev/null @@ -1,324 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-# SPDX-License-Identifier: Apache-2.0 - -from __future__ import annotations - -import base64 -from pathlib import Path -from unittest.mock import Mock, patch - -import pytest - -import data_designer.lazy_heavy_imports as lazy -from data_designer.config.models import ImageFormat -from data_designer.config.utils.image_helpers import ( - decode_base64_image, - detect_image_format, - extract_base64_from_data_uri, - is_base64_image, - is_image_diffusion_model, - is_image_path, - is_image_url, - load_image_path_to_base64, - validate_image, -) - -# --------------------------------------------------------------------------- -# extract_base64_from_data_uri -# --------------------------------------------------------------------------- - - -def test_extract_base64_from_data_uri_with_prefix() -> None: - data_uri = "data:image/png;base64,iVBORw0KGgoAAAANS" - result = extract_base64_from_data_uri(data_uri) - assert result == "iVBORw0KGgoAAAANS" - - -def test_extract_base64_plain_base64_without_prefix() -> None: - plain_base64 = "iVBORw0KGgoAAAANS" - result = extract_base64_from_data_uri(plain_base64) - assert result == plain_base64 - - -def test_extract_base64_invalid_data_uri_raises_error() -> None: - with pytest.raises(ValueError, match="Invalid data URI format: missing comma separator"): - extract_base64_from_data_uri("data:image/png;base64") - - -# --------------------------------------------------------------------------- -# decode_base64_image -# --------------------------------------------------------------------------- - - -def test_decode_base64_image_valid() -> None: - png_bytes = b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01" - base64_data = base64.b64encode(png_bytes).decode() - result = decode_base64_image(base64_data) - assert result == png_bytes - - -def test_decode_base64_image_with_data_uri() -> None: - png_bytes = b"\x89PNG\r\n\x1a\n" - base64_data = base64.b64encode(png_bytes).decode() - data_uri = f"data:image/png;base64,{base64_data}" - result = 
decode_base64_image(data_uri) - assert result == png_bytes - - -def test_decode_base64_image_invalid_raises_error() -> None: - with pytest.raises(ValueError, match="Invalid base64 data"): - decode_base64_image("not-valid-base64!!!") - - -# --------------------------------------------------------------------------- -# detect_image_format (magic bytes) -# --------------------------------------------------------------------------- - - -@pytest.mark.parametrize( - "header_bytes,expected_format", - [ - (b"\x89PNG\r\n\x1a\n" + b"\x00" * 10, ImageFormat.PNG), - (b"\xff\xd8\xff" + b"\x00" * 10, ImageFormat.JPG), - (b"RIFF" + b"\x00" * 4 + b"WEBP", ImageFormat.WEBP), - ], - ids=["png", "jpg", "webp"], -) -def test_detect_image_format_magic_bytes(header_bytes: bytes, expected_format: ImageFormat) -> None: - assert detect_image_format(header_bytes) == expected_format - - -def test_detect_image_format_gif_magic_bytes(tmp_path: Path) -> None: - img = lazy.Image.new("RGB", (1, 1), color="red") - gif_path = tmp_path / "test.gif" - img.save(gif_path, format="GIF") - gif_bytes = gif_path.read_bytes() - assert detect_image_format(gif_bytes) == ImageFormat.GIF - - -def test_detect_image_format_with_pil_fallback_jpeg() -> None: - mock_img = Mock() - mock_img.format = "JPEG" - test_bytes = b"\x00\x00\x00\x00" - - with patch.object(lazy.Image, "open", return_value=mock_img): - result = detect_image_format(test_bytes) - assert result == ImageFormat.JPG - - -def test_detect_image_format_unknown_raises_error() -> None: - unknown_bytes = b"\x00\x00\x00\x00" + b"\x00" * 10 - with pytest.raises(ValueError, match="Unable to detect image format"): - detect_image_format(unknown_bytes) - - -# --------------------------------------------------------------------------- -# is_image_path -# --------------------------------------------------------------------------- - - -@pytest.mark.parametrize( - "value,expected", - [ - ("/path/to/image.png", True), - ("image.PNG", True), - ("image.jpg", True), - 
("image.jpeg", True), - ("/path/to/file.txt", False), - ("document.pdf", False), - ("/some.png/file.txt", False), - ], - ids=["png", "png-upper", "jpg", "jpeg", "txt", "pdf", "ext-in-dir"], -) -def test_is_image_path(value: str, expected: bool) -> None: - assert is_image_path(value) is expected - - -# --------------------------------------------------------------------------- -# is_image_url -# --------------------------------------------------------------------------- - - -@pytest.mark.parametrize( - "value,expected", - [ - ("http://example.com/image.png", True), - ("https://example.com/photo.jpg", True), - ("https://example.com/image.png?size=large", True), - ("https://example.com/page.html", False), - ("ftp://example.com/image.png", False), - ], - ids=["http", "https", "query-params", "non-image-ext", "ftp"], -) -def test_is_image_url(value: str, expected: bool) -> None: - assert is_image_url(value) is expected - - -# --------------------------------------------------------------------------- -# is_base64_image -# --------------------------------------------------------------------------- - - -def test_is_base64_image_data_uri() -> None: - assert is_base64_image("data:image/png;base64,iVBORw0KGgo") is True - - -def test_is_base64_image_long_valid_base64() -> None: - long_base64 = base64.b64encode(b"x" * 100).decode() - assert is_base64_image(long_base64) is True - - -def test_is_base64_image_short_string() -> None: - assert is_base64_image("short") is False - - -def test_is_base64_image_invalid_base64_decode() -> None: - invalid_base64 = "A" * 50 + "=" + "A" * 49 + "more text" - assert is_base64_image(invalid_base64) is False - - -# --------------------------------------------------------------------------- -# Non-string guard (is_image_path, is_base64_image, is_image_url) -# --------------------------------------------------------------------------- - - -@pytest.mark.parametrize( - "func", - [is_image_path, is_base64_image, is_image_url], - 
ids=["is_image_path", "is_base64_image", "is_image_url"], -) -@pytest.mark.parametrize("value", [123, None, []], ids=["int", "none", "list"]) -def test_non_string_input_returns_false(func: object, value: object) -> None: - assert func(value) is False - - -# --------------------------------------------------------------------------- -# is_image_diffusion_model -# --------------------------------------------------------------------------- - - -@pytest.mark.parametrize( - "model_name,expected", - [ - ("dall-e-3", True), - ("DALL-E-2", True), - ("openai/dalle-2", True), - ("stable-diffusion-xl", True), - ("sd-2.1", True), - ("sd_1.5", True), - ("imagen-3", True), - ("google/imagen", True), - ("gpt-image-1", True), - ("gemini-3-pro-image-preview", False), - ("gpt-5-image", False), - ("flux.2-pro", False), - ], - ids=[ - "dall-e-3", - "DALL-E-2", - "dalle-2", - "stable-diffusion-xl", - "sd-2.1", - "sd_1.5", - "imagen-3", - "google-imagen", - "gpt-image-1", - "gemini-not-diffusion", - "gpt-5-not-diffusion", - "flux-not-diffusion", - ], -) -def test_is_image_diffusion_model(model_name: str, expected: bool) -> None: - assert is_image_diffusion_model(model_name) is expected - - -# --------------------------------------------------------------------------- -# validate_image -# --------------------------------------------------------------------------- - - -def test_validate_image_valid_png(tmp_path: Path, sample_png_bytes: bytes) -> None: - image_path = tmp_path / "test.png" - image_path.write_bytes(sample_png_bytes) - validate_image(image_path) - - -def test_validate_image_corrupted_raises_error(tmp_path: Path) -> None: - image_path = tmp_path / "corrupted.png" - image_path.write_bytes(b"not a valid image") - with pytest.raises(ValueError, match="Image validation failed"): - validate_image(image_path) - - -def test_validate_image_nonexistent_raises_error(tmp_path: Path) -> None: - image_path = tmp_path / "nonexistent.png" - with pytest.raises(ValueError, match="Image 
validation failed"): - validate_image(image_path) - - -# --------------------------------------------------------------------------- -# load_image_path_to_base64 -# --------------------------------------------------------------------------- - - -def test_load_image_path_to_base64_absolute_path(tmp_path: Path) -> None: - img = lazy.Image.new("RGB", (1, 1), color="blue") - image_path = tmp_path / "test.png" - img.save(image_path) - - result = load_image_path_to_base64(str(image_path)) - assert result is not None - assert len(result) > 0 - decoded = base64.b64decode(result) - assert len(decoded) > 0 - - -def test_load_image_path_to_base64_relative_with_base_path(tmp_path: Path) -> None: - img = lazy.Image.new("RGB", (1, 1), color="green") - image_path = tmp_path / "subdir" / "test.png" - image_path.parent.mkdir(exist_ok=True) - img.save(image_path) - - result = load_image_path_to_base64("subdir/test.png", base_path=str(tmp_path)) - assert result is not None - assert len(result) > 0 - - -def test_load_image_path_to_base64_nonexistent_file() -> None: - result = load_image_path_to_base64("/nonexistent/path/to/image.png") - assert result is None - - -def test_load_image_path_to_base64_relative_with_cwd_fallback(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: - monkeypatch.chdir(tmp_path) - - img = lazy.Image.new("RGB", (1, 1), color="yellow") - image_path = tmp_path / "test_cwd.png" - img.save(image_path) - - result = load_image_path_to_base64("test_cwd.png") - assert result is not None - assert len(result) > 0 - - -def test_load_image_path_to_base64_base_path_fallback_to_cwd(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: - monkeypatch.chdir(tmp_path) - - img = lazy.Image.new("RGB", (1, 1), color="red") - image_path = tmp_path / "test.png" - img.save(image_path) - - wrong_base = tmp_path / "wrong" - wrong_base.mkdir() - - result = load_image_path_to_base64("test.png", base_path=str(wrong_base)) - assert result is not None - assert len(result) > 0 - - 
-def test_load_image_path_to_base64_exception_handling(tmp_path: Path) -> None: - dir_path = tmp_path / "directory" - dir_path.mkdir() - - result = load_image_path_to_base64(str(dir_path)) - assert result is None diff --git a/packages/data-designer-config/tests/config/utils/test_media_helpers.py b/packages/data-designer-config/tests/config/utils/test_media_helpers.py index e5751a55b..9d8af4050 100644 --- a/packages/data-designer-config/tests/config/utils/test_media_helpers.py +++ b/packages/data-designer-config/tests/config/utils/test_media_helpers.py @@ -3,22 +3,38 @@ from __future__ import annotations +import base64 import json +from collections.abc import Callable +from pathlib import Path +from unittest.mock import Mock, patch + +import pytest import data_designer.lazy_heavy_imports as lazy from data_designer.config.utils.media_helpers import ( AudioFormat, + ImageFormat, VideoFormat, audio_format_from_mime_type, audio_mime_type, + decode_base64_image, + detect_image_format, + extract_base64_from_data_uri, get_media_base64_context, get_media_context, get_media_url_context, is_audio_path, + is_base64_image, + is_image_diffusion_model, + is_image_path, + is_image_url, is_media_url, is_video_path, + load_image_path_to_base64, normalize_media_context_values, parse_base64_data_uri, + validate_image, video_format_from_mime_type, video_mime_type, ) @@ -69,6 +85,7 @@ def test_local_media_path_detection() -> None: def test_media_format_mime_helpers() -> None: + assert ImageFormat.PNG.value == "png" assert audio_mime_type(AudioFormat.MP3) == "audio/mpeg" assert audio_format_from_mime_type("audio/mpeg") == AudioFormat.MP3 assert audio_format_from_mime_type("audio/mp3") == AudioFormat.MP3 @@ -76,3 +93,254 @@ def test_media_format_mime_helpers() -> None: assert video_mime_type(VideoFormat.MP4) == "video/mp4" assert video_format_from_mime_type("video/mp4") == VideoFormat.MP4 assert video_format_from_mime_type("VIDEO/MP4") == VideoFormat.MP4 + + +def 
test_extract_base64_from_data_uri_with_prefix() -> None: + data_uri = "data:image/png;base64,iVBORw0KGgoAAAANS" + result = extract_base64_from_data_uri(data_uri) + assert result == "iVBORw0KGgoAAAANS" + + +def test_extract_base64_plain_base64_without_prefix() -> None: + plain_base64 = "iVBORw0KGgoAAAANS" + result = extract_base64_from_data_uri(plain_base64) + assert result == plain_base64 + + +def test_extract_base64_invalid_data_uri_raises_error() -> None: + with pytest.raises(ValueError, match="Invalid data URI format: missing comma separator"): + extract_base64_from_data_uri("data:image/png;base64") + + +def test_decode_base64_image_valid() -> None: + png_bytes = b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01" + base64_data = base64.b64encode(png_bytes).decode() + result = decode_base64_image(base64_data) + assert result == png_bytes + + +def test_decode_base64_image_with_data_uri() -> None: + png_bytes = b"\x89PNG\r\n\x1a\n" + base64_data = base64.b64encode(png_bytes).decode() + data_uri = f"data:image/png;base64,{base64_data}" + result = decode_base64_image(data_uri) + assert result == png_bytes + + +def test_decode_base64_image_invalid_raises_error() -> None: + with pytest.raises(ValueError, match="Invalid base64 data"): + decode_base64_image("not-valid-base64!!!") + + +@pytest.mark.parametrize( + "header_bytes,expected_format", + [ + (b"\x89PNG\r\n\x1a\n" + b"\x00" * 10, ImageFormat.PNG), + (b"\xff\xd8\xff" + b"\x00" * 10, ImageFormat.JPG), + (b"RIFF" + b"\x00" * 4 + b"WEBP", ImageFormat.WEBP), + ], + ids=["png", "jpg", "webp"], +) +def test_detect_image_format_magic_bytes(header_bytes: bytes, expected_format: ImageFormat) -> None: + assert detect_image_format(header_bytes) == expected_format + + +def test_detect_image_format_gif_magic_bytes(tmp_path: Path) -> None: + img = lazy.Image.new("RGB", (1, 1), color="red") + gif_path = tmp_path / "test.gif" + img.save(gif_path, format="GIF") + gif_bytes = gif_path.read_bytes() + assert 
detect_image_format(gif_bytes) == ImageFormat.GIF + + +def test_detect_image_format_with_pil_fallback_jpeg() -> None: + mock_img = Mock() + mock_img.format = "JPEG" + test_bytes = b"\x00\x00\x00\x00" + + with patch.object(lazy.Image, "open", return_value=mock_img): + result = detect_image_format(test_bytes) + assert result == ImageFormat.JPG + + +def test_detect_image_format_unknown_raises_error() -> None: + unknown_bytes = b"\x00\x00\x00\x00" + b"\x00" * 10 + with pytest.raises(ValueError, match="Unable to detect image format"): + detect_image_format(unknown_bytes) + + +@pytest.mark.parametrize( + "value,expected", + [ + ("/path/to/image.png", True), + ("image.PNG", True), + ("image.jpg", True), + ("image.jpeg", True), + ("/path/to/file.txt", False), + ("document.pdf", False), + ("/some.png/file.txt", False), + ], + ids=["png", "png-upper", "jpg", "jpeg", "txt", "pdf", "ext-in-dir"], +) +def test_is_image_path(value: str, expected: bool) -> None: + assert is_image_path(value) is expected + + +@pytest.mark.parametrize( + "value,expected", + [ + ("http://example.com/image.png", True), + ("https://example.com/photo.jpg", True), + ("https://example.com/image.png?size=large", True), + ("https://example.com/page.html", False), + ("ftp://example.com/image.png", False), + ], + ids=["http", "https", "query-params", "non-image-ext", "ftp"], +) +def test_is_image_url(value: str, expected: bool) -> None: + assert is_image_url(value) is expected + + +def test_is_base64_image_data_uri() -> None: + assert is_base64_image("data:image/png;base64,iVBORw0KGgo") is True + + +def test_is_base64_image_long_valid_base64() -> None: + long_base64 = base64.b64encode(b"x" * 100).decode() + assert is_base64_image(long_base64) is True + + +def test_is_base64_image_short_string() -> None: + assert is_base64_image("short") is False + + +def test_is_base64_image_invalid_base64_decode() -> None: + invalid_base64 = "A" * 50 + "=" + "A" * 49 + "more text" + assert is_base64_image(invalid_base64) is 
False + + +@pytest.mark.parametrize( + "func", + [is_image_path, is_base64_image, is_image_url], + ids=["is_image_path", "is_base64_image", "is_image_url"], +) +@pytest.mark.parametrize("value", [123, None, []], ids=["int", "none", "list"]) +def test_image_media_helpers_non_string_input_returns_false(func: Callable[..., bool], value: object) -> None: + assert func(value) is False + + +@pytest.mark.parametrize( + "model_name,expected", + [ + ("dall-e-3", True), + ("DALL-E-2", True), + ("openai/dalle-2", True), + ("stable-diffusion-xl", True), + ("sd-2.1", True), + ("sd_1.5", True), + ("imagen-3", True), + ("google/imagen", True), + ("gpt-image-1", True), + ("gemini-3-pro-image-preview", False), + ("gpt-5-image", False), + ("flux.2-pro", False), + ], + ids=[ + "dall-e-3", + "DALL-E-2", + "dalle-2", + "stable-diffusion-xl", + "sd-2.1", + "sd_1.5", + "imagen-3", + "google-imagen", + "gpt-image-1", + "gemini-not-diffusion", + "gpt-5-not-diffusion", + "flux-not-diffusion", + ], +) +def test_is_image_diffusion_model(model_name: str, expected: bool) -> None: + assert is_image_diffusion_model(model_name) is expected + + +def test_validate_image_valid_png(tmp_path: Path, sample_png_bytes: bytes) -> None: + image_path = tmp_path / "test.png" + image_path.write_bytes(sample_png_bytes) + validate_image(image_path) + + +def test_validate_image_corrupted_raises_error(tmp_path: Path) -> None: + image_path = tmp_path / "corrupted.png" + image_path.write_bytes(b"not a valid image") + with pytest.raises(ValueError, match="Image validation failed"): + validate_image(image_path) + + +def test_validate_image_nonexistent_raises_error(tmp_path: Path) -> None: + image_path = tmp_path / "nonexistent.png" + with pytest.raises(ValueError, match="Image validation failed"): + validate_image(image_path) + + +def test_load_image_path_to_base64_absolute_path(tmp_path: Path) -> None: + img = lazy.Image.new("RGB", (1, 1), color="blue") + image_path = tmp_path / "test.png" + img.save(image_path) + + 
result = load_image_path_to_base64(str(image_path)) + assert result is not None + assert len(result) > 0 + decoded = base64.b64decode(result) + assert len(decoded) > 0 + + +def test_load_image_path_to_base64_relative_with_base_path(tmp_path: Path) -> None: + img = lazy.Image.new("RGB", (1, 1), color="green") + image_path = tmp_path / "subdir" / "test.png" + image_path.parent.mkdir(exist_ok=True) + img.save(image_path) + + result = load_image_path_to_base64("subdir/test.png", base_path=str(tmp_path)) + assert result is not None + assert len(result) > 0 + + +def test_load_image_path_to_base64_nonexistent_file() -> None: + result = load_image_path_to_base64("/nonexistent/path/to/image.png") + assert result is None + + +def test_load_image_path_to_base64_relative_with_cwd_fallback(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.chdir(tmp_path) + + img = lazy.Image.new("RGB", (1, 1), color="yellow") + image_path = tmp_path / "test_cwd.png" + img.save(image_path) + + result = load_image_path_to_base64("test_cwd.png") + assert result is not None + assert len(result) > 0 + + +def test_load_image_path_to_base64_base_path_fallback_to_cwd(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.chdir(tmp_path) + + img = lazy.Image.new("RGB", (1, 1), color="red") + image_path = tmp_path / "test.png" + img.save(image_path) + + wrong_base = tmp_path / "wrong" + wrong_base.mkdir() + + result = load_image_path_to_base64("test.png", base_path=str(wrong_base)) + assert result is not None + assert len(result) > 0 + + +def test_load_image_path_to_base64_exception_handling(tmp_path: Path) -> None: + dir_path = tmp_path / "directory" + dir_path.mkdir() + + result = load_image_path_to_base64(str(dir_path)) + assert result is None diff --git a/packages/data-designer-engine/src/data_designer/engine/mcp/io.py b/packages/data-designer-engine/src/data_designer/engine/mcp/io.py index 807e4b890..d9b2b8efd 100644 --- 
a/packages/data-designer-engine/src/data_designer/engine/mcp/io.py +++ b/packages/data-designer-engine/src/data_designer/engine/mcp/io.py @@ -43,7 +43,7 @@ from mcp.client.streamable_http import streamablehttp_client from data_designer.config.mcp import LocalStdioMCPProvider, MCPProvider, MCPProviderT -from data_designer.config.utils.image_helpers import ( +from data_designer.config.utils.media_helpers import ( decode_base64_image, detect_image_format, extract_base64_from_data_uri, diff --git a/packages/data-designer-engine/src/data_designer/engine/models/clients/parsing.py b/packages/data-designer-engine/src/data_designer/engine/models/clients/parsing.py index 02a3223a2..05509799c 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/clients/parsing.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/clients/parsing.py @@ -11,7 +11,7 @@ from dataclasses import replace from typing import Any -from data_designer.config.utils.image_helpers import ( +from data_designer.config.utils.media_helpers import ( aload_image_url_to_base64, extract_base64_from_data_uri, is_base64_image, diff --git a/packages/data-designer-engine/src/data_designer/engine/models/facade.py b/packages/data-designer-engine/src/data_designer/engine/models/facade.py index 39b667e22..2c0a7a9ab 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/facade.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/facade.py @@ -15,7 +15,7 @@ OPENROUTER_ATTRIBUTION_HEADERS, OPENROUTER_PROVIDER_NAME, ) -from data_designer.config.utils.image_helpers import is_image_diffusion_model +from data_designer.config.utils.media_helpers import is_image_diffusion_model from data_designer.engine.mcp.errors import MCPConfigurationError from data_designer.engine.model_provider import ModelProviderRegistry from data_designer.engine.models.clients.types import ( diff --git 
a/packages/data-designer-engine/src/data_designer/engine/storage/media_storage.py b/packages/data-designer-engine/src/data_designer/engine/storage/media_storage.py index 1c887c808..6bd6e3dd9 100644 --- a/packages/data-designer-engine/src/data_designer/engine/storage/media_storage.py +++ b/packages/data-designer-engine/src/data_designer/engine/storage/media_storage.py @@ -6,7 +6,7 @@ import uuid from pathlib import Path -from data_designer.config.utils.image_helpers import decode_base64_image, detect_image_format, validate_image +from data_designer.config.utils.media_helpers import decode_base64_image, detect_image_format, validate_image from data_designer.config.utils.type_helpers import StrEnum IMAGES_SUBDIR = "images" From 2720e59c3e6244be795f6988b21b99cc8bb1eb98 Mon Sep 17 00:00:00 2001 From: Nabin Mulepati Date: Fri, 22 May 2026 09:34:31 -0600 Subject: [PATCH 10/15] support local audio and video paths --- .../src/data_designer/config/__init__.py | 2 - .../src/data_designer/config/models.py | 36 ++++++++------ .../config/utils/media_helpers.py | 26 +++++----- .../tests/config/test_models.py | 48 ++++++++++++------- .../models/clients/test_openai_compatible.py | 32 +++++++++++++ 5 files changed, 97 insertions(+), 47 deletions(-) diff --git a/packages/data-designer-config/src/data_designer/config/__init__.py b/packages/data-designer-config/src/data_designer/config/__init__.py index 65c5cd17e..e608476b2 100644 --- a/packages/data-designer-config/src/data_designer/config/__init__.py +++ b/packages/data-designer-config/src/data_designer/config/__init__.py @@ -51,7 +51,6 @@ ModalityDataType, ModelConfig, ModelProvider, - MultiModalContextT, UniformDistribution, UniformDistributionParams, VideoContext, @@ -179,7 +178,6 @@ "ModalityDataType": (_MOD_MODELS, "ModalityDataType"), "ModelConfig": (_MOD_MODELS, "ModelConfig"), "ModelProvider": (_MOD_MODELS, "ModelProvider"), - "MultiModalContextT": (_MOD_MODELS, "MultiModalContextT"), "UniformDistribution": (_MOD_MODELS, 
"UniformDistribution"), "UniformDistributionParams": (_MOD_MODELS, "UniformDistributionParams"), "VideoContext": (_MOD_MODELS, "VideoContext"), diff --git a/packages/data-designer-config/src/data_designer/config/models.py b/packages/data-designer-config/src/data_designer/config/models.py index 93be1935d..81cb10e9d 100644 --- a/packages/data-designer-config/src/data_designer/config/models.py +++ b/packages/data-designer-config/src/data_designer/config/models.py @@ -168,8 +168,8 @@ def _validate_image_format(self) -> Self: class AudioContext(ModalityContext): """Configuration for providing audio context to multimodal models. - Audio context values are URL or base64 media values. Unlike ``ImageContext``, - this class does not resolve local file paths to base64. + Audio context values are URL, local path, or base64 media values. Local + paths are passed through so colocated vLLM servers can read them directly. """ modality: Literal[Modality.AUDIO] = Modality.AUDIO @@ -184,6 +184,9 @@ def _build_context(self, context_value: Any) -> dict[str, Any]: self._validate_url_context_value(context_value) return get_media_url_context(Modality.AUDIO.value, context_value) + if self.data_type is None and is_audio_path(context_value): + return get_media_url_context(Modality.AUDIO.value, context_value) + if self.data_type is None and is_media_url(context_value): return get_media_url_context(Modality.AUDIO.value, context_value) @@ -204,17 +207,18 @@ def _resolve_base64_parts(self, context_value: Any) -> tuple[str, Any]: return media_type, data if is_audio_path(context_value): - raise ValueError("Local audio paths are not supported; provide an audio URL or base64 audio data") + raise ValueError( + "audio base64 context values must be base64 audio data; use data_type=url " + "or omit data_type to pass local audio paths through" + ) if self.audio_format is None: raise ValueError("audio_format is required for base64 audio context values") return audio_mime_type(self.audio_format), 
context_value def _validate_url_context_value(self, context_value: Any) -> None: - if is_audio_path(context_value): - raise ValueError("Local audio paths are not supported; provide an audio URL or base64 audio data") - if not is_media_url(context_value): - raise ValueError("audio URL context values must be HTTP(S) URLs") + if not is_media_url(context_value) and not is_audio_path(context_value): + raise ValueError("audio URL context values must be HTTP(S) URLs or local audio paths") @model_validator(mode="after") def _validate_audio_format(self) -> Self: @@ -226,8 +230,8 @@ def _validate_audio_format(self) -> Self: class VideoContext(ModalityContext): """Configuration for providing video context to multimodal models. - Video context values are URL or base64 media values. Local file path - resolution is intentionally out of scope for this context type. + Video context values are URL, local path, or base64 media values. Local + paths are passed through so colocated vLLM servers can read them directly. 
""" modality: Literal[Modality.VIDEO] = Modality.VIDEO @@ -242,6 +246,9 @@ def _build_context(self, context_value: Any) -> dict[str, Any]: self._validate_url_context_value(context_value) return get_media_url_context(Modality.VIDEO.value, context_value) + if self.data_type is None and is_video_path(context_value): + return get_media_url_context(Modality.VIDEO.value, context_value) + if self.data_type is None and is_media_url(context_value): return get_media_url_context(Modality.VIDEO.value, context_value) @@ -262,17 +269,18 @@ def _resolve_base64_parts(self, context_value: Any) -> tuple[str, Any]: return media_type, data if is_video_path(context_value): - raise ValueError("Local video paths are not supported; provide a video URL or base64 video data") + raise ValueError( + "video base64 context values must be base64 video data; use data_type=url " + "or omit data_type to pass local video paths through" + ) if self.video_format is None: raise ValueError("video_format is required for base64 video context values") return video_mime_type(self.video_format), context_value def _validate_url_context_value(self, context_value: Any) -> None: - if is_video_path(context_value): - raise ValueError("Local video paths are not supported; provide a video URL or base64 video data") - if not is_media_url(context_value): - raise ValueError("video URL context values must be HTTP(S) URLs") + if not is_media_url(context_value) and not is_video_path(context_value): + raise ValueError("video URL context values must be HTTP(S) URLs or local video paths") @model_validator(mode="after") def _validate_video_format(self) -> Self: diff --git a/packages/data-designer-config/src/data_designer/config/utils/media_helpers.py b/packages/data-designer-config/src/data_designer/config/utils/media_helpers.py index e36e59d28..9c293746b 100644 --- a/packages/data-designer-config/src/data_designer/config/utils/media_helpers.py +++ 
b/packages/data-designer-config/src/data_designer/config/utils/media_helpers.py @@ -17,6 +17,19 @@ import data_designer.lazy_heavy_imports as lazy from data_designer.config.utils.type_helpers import StrEnum +_BASE64_PATTERN = re.compile(r"^[A-Za-z0-9+/=]+$") +_DATA_URI_RE = re.compile(r"^data:(?P[^;]+);base64,(?P.+)$") + +_IMAGE_DIFFUSION_MODEL_PATTERNS = ( + "dall-e-", + "dalle", + "stable-diffusion", + "sd-", + "sd_", + "imagen", + "gpt-image-", +) + class ImageFormat(StrEnum): """Supported image formats for image modality.""" @@ -47,19 +60,6 @@ class VideoFormat(StrEnum): _SUPPORTED_AUDIO_EXTENSIONS = tuple(f".{fmt.value.lower()}" for fmt in AudioFormat) _SUPPORTED_VIDEO_EXTENSIONS = tuple(f".{fmt.value.lower()}" for fmt in VideoFormat) -_BASE64_PATTERN = re.compile(r"^[A-Za-z0-9+/=]+$") -_DATA_URI_RE = re.compile(r"^data:(?P[^;]+);base64,(?P.+)$") - -_IMAGE_DIFFUSION_MODEL_PATTERNS = ( - "dall-e-", - "dalle", - "stable-diffusion", - "sd-", - "sd_", - "imagen", - "gpt-image-", -) - _IMAGE_FORMAT_MAGIC_BYTES = { ImageFormat.PNG: b"\x89PNG\r\n\x1a\n", ImageFormat.JPG: b"\xff\xd8\xff", diff --git a/packages/data-designer-config/tests/config/test_models.py b/packages/data-designer-config/tests/config/test_models.py index 74f69185c..7c39bf533 100644 --- a/packages/data-designer-config/tests/config/test_models.py +++ b/packages/data-designer-config/tests/config/test_models.py @@ -228,6 +228,9 @@ def test_audio_context_get_contexts_single_string() -> None: assert audio_context.get_contexts({"audio_url": "https://example.com/audio.mp3"}) == [ get_media_url_context(Modality.AUDIO.value, "https://example.com/audio.mp3") ] + assert audio_context.get_contexts({"audio_url": "recordings/speech.mp3"}) == [ + get_media_url_context(Modality.AUDIO.value, "recordings/speech.mp3") + ] def test_audio_context_get_contexts_list_json_and_numpy() -> None: @@ -258,6 +261,10 @@ def test_audio_context_auto_detect_url_and_data_uri() -> None: get_media_url_context(Modality.AUDIO.value, 
"https://example.com/audio.mp3") ] + assert AudioContext(column_name="audio_col").get_contexts({"audio_col": "recordings/speech.wav"}) == [ + get_media_url_context(Modality.AUDIO.value, "recordings/speech.wav") + ] + assert AudioContext(column_name="audio_col").get_contexts({"audio_col": "https://example.com/download?id=123"}) == [ get_media_url_context(Modality.AUDIO.value, "https://example.com/download?id=123") ] @@ -271,11 +278,6 @@ def test_audio_context_validate_audio_format() -> None: with pytest.raises(ValueError, match="audio_format is required when data_type is base64"): AudioContext(column_name="audio_base64", data_type=ModalityDataType.BASE64) - with pytest.raises(ValueError, match="Local audio paths are not supported"): - AudioContext(column_name="audio_url", data_type=ModalityDataType.URL).get_contexts( - {"audio_url": "screen_recording.mp3"} - ) - with pytest.raises(ValueError, match="audio URL context values must be HTTP"): AudioContext(column_name="audio_url", data_type=ModalityDataType.URL).get_contexts({"audio_url": "not-a-url"}) @@ -287,10 +289,14 @@ def test_audio_context_validate_audio_format() -> None: {"audio_base64": "data:audio/mpeg;base64,audio1base64"} ) - with pytest.raises(ValueError, match="Local audio paths are not supported"): - AudioContext(column_name="audio_base64", audio_format=AudioFormat.MP3).get_contexts( - {"audio_base64": "screen_recording.mp3"} - ) + assert AudioContext(column_name="audio_base64", audio_format=AudioFormat.MP3).get_contexts( + {"audio_base64": "screen_recording.mp3"} + ) == [get_media_url_context(Modality.AUDIO.value, "screen_recording.mp3")] + + with pytest.raises(ValueError, match="audio base64 context values must be base64 audio data"): + AudioContext( + column_name="audio_base64", data_type=ModalityDataType.BASE64, audio_format=AudioFormat.MP3 + ).get_contexts({"audio_base64": "screen_recording.mp3"}) def test_video_context_get_contexts_single_string() -> None: @@ -305,6 +311,9 @@ def 
test_video_context_get_contexts_single_string() -> None: assert video_context.get_contexts({"video_url": "https://example.com/video.mp4"}) == [ get_media_url_context(Modality.VIDEO.value, "https://example.com/video.mp4") ] + assert video_context.get_contexts({"video_url": "clips/screen_recording.mp4"}) == [ + get_media_url_context(Modality.VIDEO.value, "clips/screen_recording.mp4") + ] def test_video_context_get_contexts_list_json_and_numpy() -> None: @@ -335,6 +344,10 @@ def test_video_context_auto_detect_url_and_data_uri() -> None: get_media_url_context(Modality.VIDEO.value, "https://example.com/video.mp4") ] + assert VideoContext(column_name="video_col").get_contexts({"video_col": "clips/screen_recording.webm"}) == [ + get_media_url_context(Modality.VIDEO.value, "clips/screen_recording.webm") + ] + assert VideoContext(column_name="video_col").get_contexts({"video_col": "https://example.com/download?id=123"}) == [ get_media_url_context(Modality.VIDEO.value, "https://example.com/download?id=123") ] @@ -348,11 +361,6 @@ def test_video_context_validate_video_format() -> None: with pytest.raises(ValueError, match="video_format is required when data_type is base64"): VideoContext(column_name="video_base64", data_type=ModalityDataType.BASE64) - with pytest.raises(ValueError, match="Local video paths are not supported"): - VideoContext(column_name="video_url", data_type=ModalityDataType.URL).get_contexts( - {"video_url": "screen_recording.mp4"} - ) - with pytest.raises(ValueError, match="video URL context values must be HTTP"): VideoContext(column_name="video_url", data_type=ModalityDataType.URL).get_contexts({"video_url": "not-a-url"}) @@ -364,10 +372,14 @@ def test_video_context_validate_video_format() -> None: {"video_base64": "data:video/mp4;base64,video1base64"} ) - with pytest.raises(ValueError, match="Local video paths are not supported"): - VideoContext(column_name="video_base64", video_format=VideoFormat.MP4).get_contexts( - {"video_base64": 
"screen_recording.mp4"} - ) + assert VideoContext(column_name="video_base64", video_format=VideoFormat.MP4).get_contexts( + {"video_base64": "screen_recording.mp4"} + ) == [get_media_url_context(Modality.VIDEO.value, "screen_recording.mp4")] + + with pytest.raises(ValueError, match="video base64 context values must be base64 video data"): + VideoContext( + column_name="video_base64", data_type=ModalityDataType.BASE64, video_format=VideoFormat.MP4 + ).get_contexts({"video_base64": "screen_recording.mp4"}) def test_inference_parameters_default_construction(): diff --git a/packages/data-designer-engine/tests/engine/models/clients/test_openai_compatible.py b/packages/data-designer-engine/tests/engine/models/clients/test_openai_compatible.py index 5107fa9ad..0e5272534 100644 --- a/packages/data-designer-engine/tests/engine/models/clients/test_openai_compatible.py +++ b/packages/data-designer-engine/tests/engine/models/clients/test_openai_compatible.py @@ -387,6 +387,22 @@ def test_completion_translates_audio_url_blocks() -> None: assert content[0] == {"type": "audio_url", "audio_url": {"url": "https://example.com/download?id=123"}} +def test_completion_translates_local_audio_path_blocks() -> None: + sync_mock = make_mock_sync_client(_chat_response()) + client = _make_client(sync_client=sync_mock) + + audio_block = get_media_url_context(Modality.AUDIO.value, "recordings/speech.wav") + request = ChatCompletionRequest( + model=MODEL, + messages=[{"role": "user", "content": [audio_block, {"type": "text", "text": "Transcribe this."}]}], + ) + client.completion(request) + + payload = sync_mock.post.call_args.kwargs["json"] + content = payload["messages"][0]["content"] + assert content[0] == {"type": "audio_url", "audio_url": {"url": "recordings/speech.wav"}} + + def test_completion_translates_video_blocks() -> None: sync_mock = make_mock_sync_client(_chat_response()) client = _make_client(sync_client=sync_mock) @@ -403,6 +419,22 @@ def 
test_completion_translates_video_blocks() -> None: assert content[0] == {"type": "video_url", "video_url": {"url": "https://example.com/download?id=123"}} +def test_completion_translates_local_video_path_blocks() -> None: + sync_mock = make_mock_sync_client(_chat_response()) + client = _make_client(sync_client=sync_mock) + + video_block = get_media_url_context(Modality.VIDEO.value, "clips/screen_recording.mp4") + request = ChatCompletionRequest( + model=MODEL, + messages=[{"role": "user", "content": [video_block, {"type": "text", "text": "Describe this."}]}], + ) + client.completion(request) + + payload = sync_mock.post.call_args.kwargs["json"] + content = payload["messages"][0]["content"] + assert content[0] == {"type": "video_url", "video_url": {"url": "clips/screen_recording.mp4"}} + + def test_completion_translates_base64_video_blocks() -> None: sync_mock = make_mock_sync_client(_chat_response()) client = _make_client(sync_client=sync_mock) From 55793bcadb85dabee8898b603ce78839a605cb04 Mon Sep 17 00:00:00 2001 From: Nabin Mulepati Date: Fri, 22 May 2026 09:38:35 -0600 Subject: [PATCH 11/15] refactor: combine media path checks --- .../src/data_designer/config/models.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/packages/data-designer-config/src/data_designer/config/models.py b/packages/data-designer-config/src/data_designer/config/models.py index 81cb10e9d..ef445e47c 100644 --- a/packages/data-designer-config/src/data_designer/config/models.py +++ b/packages/data-designer-config/src/data_designer/config/models.py @@ -184,10 +184,7 @@ def _build_context(self, context_value: Any) -> dict[str, Any]: self._validate_url_context_value(context_value) return get_media_url_context(Modality.AUDIO.value, context_value) - if self.data_type is None and is_audio_path(context_value): - return get_media_url_context(Modality.AUDIO.value, context_value) - - if self.data_type is None and is_media_url(context_value): + if self.data_type is None and 
(is_audio_path(context_value) or is_media_url(context_value)): return get_media_url_context(Modality.AUDIO.value, context_value) media_type, data = self._resolve_base64_parts(context_value) @@ -246,10 +243,7 @@ def _build_context(self, context_value: Any) -> dict[str, Any]: self._validate_url_context_value(context_value) return get_media_url_context(Modality.VIDEO.value, context_value) - if self.data_type is None and is_video_path(context_value): - return get_media_url_context(Modality.VIDEO.value, context_value) - - if self.data_type is None and is_media_url(context_value): + if self.data_type is None and (is_video_path(context_value) or is_media_url(context_value)): return get_media_url_context(Modality.VIDEO.value, context_value) media_type, data = self._resolve_base64_parts(context_value) From 00cc30a82b43d35e19851815649de512971641cc Mon Sep 17 00:00:00 2001 From: Nabin Mulepati Date: Fri, 22 May 2026 10:03:03 -0600 Subject: [PATCH 12/15] address media context review feedback --- .../4-providing-images-as-context.ipynb | 2 +- .../4-providing-images-as-context.py | 2 +- .../models/default-model-settings.mdx | 2 +- .../pages/concepts/models/model-configs.mdx | 2 +- .../data_designer/config/column_configs.py | 41 ++++--- .../src/data_designer/config/models.py | 30 ++++- .../config/utils/media_helpers.py | 25 ++++- .../tests/config/test_columns.py | 52 ++++++++- .../tests/config/test_models.py | 15 +++ .../tests/config/utils/test_media_helpers.py | 3 + .../clients/adapters/anthropic_translation.py | 2 + .../clients/adapters/openai_compatible.py | 104 ++++++++++++++++-- .../models/clients/test_openai_compatible.py | 60 +++++++++- 13 files changed, 304 insertions(+), 36 deletions(-) diff --git a/docs/colab_notebooks/4-providing-images-as-context.ipynb b/docs/colab_notebooks/4-providing-images-as-context.ipynb index 5accf02e4..3355cafc1 100644 --- a/docs/colab_notebooks/4-providing-images-as-context.ipynb +++ b/docs/colab_notebooks/4-providing-images-as-context.ipynb 
@@ -303,7 +303,7 @@ "]\n", "```\n", "\n", - "URL-backed media can use `data_type=dd.ModalityDataType.URL`, subject to the provider's URL support and file-size limits." + "URL-backed media can use `data_type=dd.ModalityDataType.URL`, subject to the provider's URL support and file-size limits. Local audio/video paths in URL mode require the model endpoint to have filesystem access to the same paths, typically a colocated vLLM server configured for local media access." ] }, { diff --git a/docs/notebook_source/4-providing-images-as-context.py b/docs/notebook_source/4-providing-images-as-context.py index 6e7526e53..dbdc65cd6 100644 --- a/docs/notebook_source/4-providing-images-as-context.py +++ b/docs/notebook_source/4-providing-images-as-context.py @@ -184,7 +184,7 @@ def convert_image_to_chat_format(record, height: int) -> dict: # ] # ``` # -# URL-backed media can use `data_type=dd.ModalityDataType.URL`, subject to the provider's URL support and file-size limits. +# URL-backed media can use `data_type=dd.ModalityDataType.URL`, subject to the provider's URL support and file-size limits. Local audio/video paths in URL mode require the model endpoint to have filesystem access to the same paths, typically a colocated vLLM server configured for local media access. 
# %% # Add a column to generate detailed image descriptions diff --git a/fern/versions/latest/pages/concepts/models/default-model-settings.mdx b/fern/versions/latest/pages/concepts/models/default-model-settings.mdx index 5b960d951..9afcdc989 100644 --- a/fern/versions/latest/pages/concepts/models/default-model-settings.mdx +++ b/fern/versions/latest/pages/concepts/models/default-model-settings.mdx @@ -75,7 +75,7 @@ The following model configurations are automatically available when `OPENROUTER_ | `openrouter-embedding` | `openai/text-embedding-3-large` | Text embeddings | `encoding_format="float"` | - The `multi_modal_context` field can include image, audio, and video contexts, but each model/provider combination has its own accepted input formats, media-size limits, and modality mix. Use an image-capable model for image-only workflows, and use an omni or otherwise multimodal model before sending audio or video context. + The `multi_modal_context` field can include image, audio, and video contexts, but each model/provider combination has its own accepted input formats, media-size limits, and modality mix. Use an image-capable model for image-only workflows, and use an omni or otherwise multimodal model before sending audio or video context. Local audio/video paths in URL mode require the model endpoint to have filesystem access to the same paths, typically a colocated vLLM server configured for local media access. diff --git a/fern/versions/latest/pages/concepts/models/model-configs.mdx b/fern/versions/latest/pages/concepts/models/model-configs.mdx index 23e3b2a4b..ec9335ef5 100644 --- a/fern/versions/latest/pages/concepts/models/model-configs.mdx +++ b/fern/versions/latest/pages/concepts/models/model-configs.mdx @@ -9,7 +9,7 @@ Model configurations define the specific models you use for synthetic data gener A `ModelConfig` specifies which LLM model to use and how it should behave during generation. 
When you create column configurations (like `LLMText`, `LLMCode`, or `LLMStructured`), you reference a model by its alias. Data Designer uses the model configuration to determine which model to call and with what parameters. -When a column includes `multi_modal_context`, the `ModelConfig` alias must point to a model that supports the media types you send. Data Designer can serialize image, audio, and video context blocks, but model capability is still provider-specific. +When a column includes `multi_modal_context`, the `ModelConfig` alias must point to a model that supports the media types you send. Data Designer can serialize image, audio, and video context blocks, but model capability is still provider-specific. Local audio/video paths in URL mode require the model endpoint to have filesystem access to the same paths, typically a colocated vLLM server configured for local media access. ## ModelConfig Structure diff --git a/packages/data-designer-config/src/data_designer/config/column_configs.py b/packages/data-designer-config/src/data_designer/config/column_configs.py index fd9d3761b..f7f569cfe 100644 --- a/packages/data-designer-config/src/data_designer/config/column_configs.py +++ b/packages/data-designer-config/src/data_designer/config/column_configs.py @@ -17,8 +17,11 @@ from data_designer.config.utils.constants import REASONING_CONTENT_COLUMN_POSTFIX, TRACE_COLUMN_POSTFIX from data_designer.config.utils.misc import assert_valid_jinja2_template, extract_keywords_from_jinja2_template from data_designer.config.utils.trace_type import TraceType +from data_designer.config.utils.warning_helpers import warn_at_caller from data_designer.config.validator_params import ValidatorParamsT, ValidatorType +_NON_IMAGE_CONTEXT_KEYS = frozenset({"audio_format", "video_format"}) + class GenerationStrategy(str, Enum): """Strategy for custom column generation.""" @@ -184,12 +187,7 @@ class LLMTextColumnConfig(SingleColumnConfig): @classmethod def 
inject_legacy_image_context_modality(cls, value: Any) -> Any: """Preserve legacy image-context dicts that predate the modality discriminator.""" - if not isinstance(value, list): - return value - return [ - {"modality": "image", **item} if isinstance(item, dict) and _is_legacy_image_context_dict(item) else item - for item in value - ] + return _inject_legacy_image_context_modality(value) @staticmethod def get_column_emoji() -> str: @@ -629,12 +627,7 @@ class ImageColumnConfig(SingleColumnConfig): @classmethod def inject_legacy_image_context_modality(cls, value: Any) -> Any: """Preserve legacy image-context dicts that predate the modality discriminator.""" - if not isinstance(value, list): - return value - return [ - {"modality": "image", **item} if isinstance(item, dict) and _is_legacy_image_context_dict(item) else item - for item in value - ] + return _inject_legacy_image_context_modality(value) @staticmethod def get_column_emoji() -> str: @@ -755,5 +748,27 @@ def validate_generator_function(self) -> Self: return self +def _inject_legacy_image_context_modality(value: Any) -> Any: + if not isinstance(value, list): + return value + return [ + _inject_legacy_image_context_item(item) + if isinstance(item, dict) and _is_legacy_image_context_dict(item) + else item + for item in value + ] + + +def _inject_legacy_image_context_item(item: dict[str, Any]) -> dict[str, Any]: + warn_at_caller( + "Modality-less multi_modal_context dictionaries are treated as legacy ImageContext configs. 
" + "Set modality='image', modality='audio', or modality='video' explicitly for new configs.", + DeprecationWarning, + ) + return {"modality": "image", **item} + + def _is_legacy_image_context_dict(value: dict[str, Any]) -> bool: - return "modality" not in value + if "modality" in value: + return False + return not _NON_IMAGE_CONTEXT_KEYS.intersection(value) diff --git a/packages/data-designer-config/src/data_designer/config/models.py b/packages/data-designer-config/src/data_designer/config/models.py index ef445e47c..6a4748e79 100644 --- a/packages/data-designer-config/src/data_designer/config/models.py +++ b/packages/data-designer-config/src/data_designer/config/models.py @@ -32,6 +32,7 @@ detect_image_format, get_media_base64_context, get_media_url_context, + image_format_from_mime_type, is_audio_path, is_image_path, is_image_url, @@ -148,8 +149,13 @@ def _format_base64_context(self, base64_data: str) -> dict[str, Any]: parsed = parse_base64_data_uri(base64_data) if parsed is not None: media_type, data = parsed - if not media_type.startswith("image/"): + detected_format = image_format_from_mime_type(media_type) + if detected_format is None: raise ValueError(f"Unsupported image media type {media_type!r}") + if self.image_format is not None and not _image_formats_match(self.image_format, detected_format): + raise ValueError( + f"image_format {self.image_format.value!r} does not match data URI media type {media_type!r}" + ) return get_media_base64_context(Modality.IMAGE.value, media_type, data) image_format = self.image_format @@ -165,18 +171,30 @@ def _validate_image_format(self) -> Self: return self +def _image_formats_match(configured_format: ImageFormat, detected_format: ImageFormat) -> bool: + if configured_format == detected_format: + return True + return {configured_format, detected_format} == {ImageFormat.JPG, ImageFormat.JPEG} + + class AudioContext(ModalityContext): """Configuration for providing audio context to multimodal models. 
Audio context values are URL, local path, or base64 media values. Local paths are passed through so colocated vLLM servers can read them directly. + ``audio_format`` is consulted only for base64 sources; URL and local-path + sources are passed through unchanged. """ modality: Literal[Modality.AUDIO] = Modality.AUDIO audio_format: AudioFormat | None = None def get_contexts(self, record: dict, *, base_path: str | None = None) -> list[dict[str, Any]]: - """Get the contexts for the audio modality.""" + """Get audio contexts. + + ``base_path`` is accepted for signature compatibility with ``ImageContext`` + but unused; local audio paths are passed through unchanged. + """ return [self._build_context(value) for value in normalize_media_context_values(record[self.column_name])] def _build_context(self, context_value: Any) -> dict[str, Any]: @@ -229,13 +247,19 @@ class VideoContext(ModalityContext): Video context values are URL, local path, or base64 media values. Local paths are passed through so colocated vLLM servers can read them directly. + ``video_format`` is consulted only for base64 sources; URL and local-path + sources are passed through unchanged. """ modality: Literal[Modality.VIDEO] = Modality.VIDEO video_format: VideoFormat | None = None def get_contexts(self, record: dict, *, base_path: str | None = None) -> list[dict[str, Any]]: - """Get the contexts for the video modality.""" + """Get video contexts. + + ``base_path`` is accepted for signature compatibility with ``ImageContext`` + but unused; local video paths are passed through unchanged. 
+ """ return [self._build_context(value) for value in normalize_media_context_values(record[self.column_name])] def _build_context(self, context_value: Any) -> dict[str, Any]: diff --git a/packages/data-designer-config/src/data_designer/config/utils/media_helpers.py b/packages/data-designer-config/src/data_designer/config/utils/media_helpers.py index 9c293746b..998b81c05 100644 --- a/packages/data-designer-config/src/data_designer/config/utils/media_helpers.py +++ b/packages/data-designer-config/src/data_designer/config/utils/media_helpers.py @@ -17,6 +17,8 @@ import data_designer.lazy_heavy_imports as lazy from data_designer.config.utils.type_helpers import StrEnum +# --- Format enums and constants --- + _BASE64_PATTERN = re.compile(r"^[A-Za-z0-9+/=]+$") _DATA_URI_RE = re.compile(r"^data:(?P<media_type>[^;]+);base64,(?P<data>.+)$") @@ -77,6 +79,13 @@ class VideoFormat(StrEnum): "webp": ImageFormat.WEBP, } +_IMAGE_MIME_TYPE_TO_FORMAT: dict[str, ImageFormat] = { + "image/png": ImageFormat.PNG, + "image/jpeg": ImageFormat.JPG, + "image/jpg": ImageFormat.JPG, + "image/gif": ImageFormat.GIF, + "image/webp": ImageFormat.WEBP, +} _AUDIO_FORMAT_TO_MIME_TYPE: dict[AudioFormat, str] = { AudioFormat.MP3: "audio/mpeg", AudioFormat.WAV: "audio/wav", @@ -101,6 +110,9 @@ class VideoFormat(StrEnum): } +# --- Image helpers --- + + def is_image_diffusion_model(model_name: str) -> bool: """Return True if the model is a diffusion-based image generation model.""" return any(pattern in model_name.lower() for pattern in _IMAGE_DIFFUSION_MODEL_PATTERNS) @@ -226,6 +238,9 @@ def validate_image(image_path: Path) -> None: raise ValueError(f"Image validation failed: {e}") from e +# --- Canonical media blocks --- + + def get_media_context(modality: str, source: dict[str, Any]) -> dict[str, Any]: """Build a canonical media context block.""" return {"type": modality, "source": source} @@ -248,7 +263,7 @@ def normalize_media_context_values(raw_value: Any) -> list[Any]: parsed_value = json.loads(raw_value) if
isinstance(parsed_value, list): return parsed_value - except (json.JSONDecodeError, TypeError): + except json.JSONDecodeError: pass return [raw_value] @@ -271,6 +286,9 @@ def parse_base64_data_uri(value: str) -> tuple[str, str] | None: return match.group("media_type"), match.group("data") +# --- Audio/video helpers --- + + def is_media_url(value: str) -> bool: """Return whether a value is an HTTP(S) media URL.""" return isinstance(value, str) and value.startswith(("http://", "https://")) @@ -296,6 +314,11 @@ def video_mime_type(video_format: VideoFormat) -> str: return _VIDEO_FORMAT_TO_MIME_TYPE[video_format] +def image_format_from_mime_type(media_type: str) -> ImageFormat | None: + """Infer an image format from a MIME type.""" + return _IMAGE_MIME_TYPE_TO_FORMAT.get(media_type.lower()) + + def audio_format_from_mime_type(media_type: str) -> AudioFormat | None: """Infer an audio format from a MIME type.""" return _AUDIO_MIME_TYPE_TO_FORMAT.get(media_type.lower()) diff --git a/packages/data-designer-config/tests/config/test_columns.py b/packages/data-designer-config/tests/config/test_columns.py index ee87ffdbd..baafb1c43 100644 --- a/packages/data-designer-config/tests/config/test_columns.py +++ b/packages/data-designer-config/tests/config/test_columns.py @@ -206,18 +206,58 @@ def test_column_config_accepts_legacy_image_context_dict( config_cls: type[LLMTextColumnConfig] | type[ImageColumnConfig], name: str, ) -> None: - config = config_cls( - name=name, - prompt="Describe the image", - model_alias=stub_model_alias, - multi_modal_context=[{"column_name": "image_url", "data_type": "url"}], - ) + with pytest.warns(DeprecationWarning, match="treated as legacy ImageContext configs"): + config = config_cls( + name=name, + prompt="Describe the image", + model_alias=stub_model_alias, + multi_modal_context=[{"column_name": "image_url", "data_type": "url"}], + ) assert config.multi_modal_context is not None assert isinstance(config.multi_modal_context[0], ImageContext) 
assert config.multi_modal_context[0].column_name == "image_url" +@pytest.mark.parametrize( + "context_dict", + [ + {"column_name": "audio_url", "data_type": "url"}, + {"column_name": "video_url", "data_type": "url"}, + ], + ids=["audio-url-shaped", "video-url-shaped"], +) +def test_column_config_warns_modality_less_url_context_is_legacy_image(context_dict: dict[str, str]) -> None: + with pytest.warns(DeprecationWarning, match="treated as legacy ImageContext configs"): + config = LLMTextColumnConfig( + name="test_llm_text", + prompt="Describe the context", + model_alias=stub_model_alias, + multi_modal_context=[context_dict], + ) + + assert config.multi_modal_context is not None + assert isinstance(config.multi_modal_context[0], ImageContext) + + +@pytest.mark.parametrize( + "context_dict", + [ + {"column_name": "audio_url", "data_type": "url", "audio_format": "mp3"}, + {"column_name": "video_url", "data_type": "url", "video_format": "mp4"}, + ], + ids=["audio-format", "video-format"], +) +def test_column_config_requires_modality_for_audio_video_specific_dicts(context_dict: dict[str, str]) -> None: + with pytest.raises(ValidationError, match="modality"): + LLMTextColumnConfig( + name="test_llm_text", + prompt="Describe the context", + model_alias=stub_model_alias, + multi_modal_context=[context_dict], + ) + + def test_llm_text_column_config_with_trace_serialization() -> None: """Test that with_trace field serializes and deserializes correctly.""" config = LLMTextColumnConfig( diff --git a/packages/data-designer-config/tests/config/test_models.py b/packages/data-designer-config/tests/config/test_models.py index 7c39bf533..ebe720cf4 100644 --- a/packages/data-designer-config/tests/config/test_models.py +++ b/packages/data-designer-config/tests/config/test_models.py @@ -158,6 +158,21 @@ def test_image_context_validate_image_format(): ImageContext(column_name="image_base64", data_type=ModalityDataType.BASE64) +def 
test_image_context_validates_data_uri_media_type_against_image_format() -> None: + context = ImageContext(column_name="image_base64", image_format=ImageFormat.PNG) + + with pytest.raises(ValueError, match="image_format 'png' does not match data URI media type 'image/jpeg'"): + context.get_contexts({"image_base64": "data:image/jpeg;base64,image1base64"}) + + +def test_image_context_accepts_jpg_format_for_jpeg_data_uri() -> None: + context = ImageContext(column_name="image_base64", image_format=ImageFormat.JPG) + + assert context.get_contexts({"image_base64": "data:image/jpeg;base64,image1base64"}) == [ + get_media_base64_context(Modality.IMAGE.value, "image/jpeg", "image1base64") + ] + + def test_image_context_no_data_type_passes_validation() -> None: """Test that ImageContext without data_type passes validation.""" context = ImageContext(column_name="image_col") diff --git a/packages/data-designer-config/tests/config/utils/test_media_helpers.py b/packages/data-designer-config/tests/config/utils/test_media_helpers.py index 9d8af4050..f1e898d66 100644 --- a/packages/data-designer-config/tests/config/utils/test_media_helpers.py +++ b/packages/data-designer-config/tests/config/utils/test_media_helpers.py @@ -24,6 +24,7 @@ get_media_base64_context, get_media_context, get_media_url_context, + image_format_from_mime_type, is_audio_path, is_base64_image, is_image_diffusion_model, @@ -86,6 +87,8 @@ def test_local_media_path_detection() -> None: def test_media_format_mime_helpers() -> None: assert ImageFormat.PNG.value == "png" + assert image_format_from_mime_type("image/png") == ImageFormat.PNG + assert image_format_from_mime_type("image/jpeg") == ImageFormat.JPG assert audio_mime_type(AudioFormat.MP3) == "audio/mpeg" assert audio_format_from_mime_type("audio/mpeg") == AudioFormat.MP3 assert audio_format_from_mime_type("audio/mp3") == AudioFormat.MP3 diff --git a/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/anthropic_translation.py 
b/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/anthropic_translation.py index 62dc1d4e7..5ba186add 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/anthropic_translation.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/anthropic_translation.py @@ -22,6 +22,8 @@ ) _DEFAULT_MAX_TOKENS = 4096 +# Include canonical blocks from *Context.get_contexts and provider-specific +# blocks that users may author directly in templates or tool-result content. _UNSUPPORTED_MEDIA_BLOCK_MODALITIES: dict[str, str] = { "audio": "audio", "audio_url": "audio", diff --git a/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/openai_compatible.py b/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/openai_compatible.py index 94a294717..594b9a347 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/openai_compatible.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/openai_compatible.py @@ -5,6 +5,7 @@ from typing import Any +from data_designer.config.utils.constants import OPENAI_PROVIDER_NAME from data_designer.config.utils.media_helpers import audio_format_from_mime_type from data_designer.engine.models.clients.adapters.http_model_client import ( HttpModelClient, @@ -167,7 +168,11 @@ def translate_openai_compatible_messages( translated = dict(message) if "content" in translated: try: - translated["content"] = translate_openai_compatible_content_blocks(translated["content"]) + translated["content"] = translate_openai_compatible_content_blocks( + translated["content"], + provider_name=provider_name, + model_name=model_name, + ) except ValueError as exc: raise ProviderError( kind=ProviderErrorKind.BAD_REQUEST, @@ -180,14 +185,27 @@ def translate_openai_compatible_messages( return translated_messages -def 
translate_openai_compatible_content_blocks(content: Any) -> Any: +def translate_openai_compatible_content_blocks( + content: Any, + *, + provider_name: str, + model_name: str, +) -> Any: if not isinstance(content, list): return content - return [translate_openai_compatible_content_block(block) for block in content] + return [ + translate_openai_compatible_content_block(block, provider_name=provider_name, model_name=model_name) + for block in content + ] -def translate_openai_compatible_content_block(block: Any) -> Any: +def translate_openai_compatible_content_block( + block: Any, + *, + provider_name: str, + model_name: str, +) -> Any: if not isinstance(block, dict): return block @@ -196,12 +214,28 @@ def translate_openai_compatible_content_block(block: Any) -> Any: return block if block_type in {"image_url", "input_audio", "text"}: return block + if block_type == "audio_url": + _ensure_extension_media_block_supported( + provider_name=provider_name, + model_name=model_name, + modality="audio", + source_type="url", + ) + return block + if block_type in {"video_url", "input_video"}: + _ensure_extension_media_block_supported( + provider_name=provider_name, + model_name=model_name, + modality="video", + source_type="url", + ) + return block if block_type == "image": return _translate_canonical_image_block(block) if block_type == "audio": - return _translate_canonical_audio_block(block) + return _translate_canonical_audio_block(block, provider_name=provider_name, model_name=model_name) if block_type == "video": - return _translate_canonical_video_block(block) + return _translate_canonical_video_block(block, provider_name=provider_name, model_name=model_name) return block @@ -219,10 +253,23 @@ def _translate_canonical_image_block(block: dict[str, Any]) -> dict[str, Any]: raise ValueError(f"Unsupported canonical image source type {source_type!r}") -def _translate_canonical_audio_block(block: dict[str, Any]) -> dict[str, Any]: +def _translate_canonical_audio_block( + 
block: dict[str, Any], + *, + provider_name: str, + model_name: str, +) -> dict[str, Any]: source = _get_media_source(block, modality="audio") source_type = source.get("type") if source_type == "url": + _ensure_extension_media_block_supported( + provider_name=provider_name, + model_name=model_name, + modality="audio", + source_type="url", + ) + # ``audio_url`` is an OpenAI-compatible extension used by providers such as vLLM/NVIDIA, + # not by OpenAI's hosted Chat Completions route. return {"type": "audio_url", "audio_url": {"url": source.get("url", "")}} if source_type == "base64": media_type = source.get("media_type") @@ -236,20 +283,61 @@ def _translate_canonical_audio_block(block: dict[str, Any]) -> dict[str, Any]: raise ValueError(f"Unsupported canonical audio source type {source_type!r}") -def _translate_canonical_video_block(block: dict[str, Any]) -> dict[str, Any]: +def _translate_canonical_video_block( + block: dict[str, Any], + *, + provider_name: str, + model_name: str, +) -> dict[str, Any]: source = _get_media_source(block, modality="video") source_type = source.get("type") if source_type == "url": + _ensure_extension_media_block_supported( + provider_name=provider_name, + model_name=model_name, + modality="video", + source_type="url", + ) + # ``video_url`` is an OpenAI-compatible extension used by providers such as vLLM/NVIDIA, + # not by OpenAI's hosted Chat Completions route. 
return {"type": "video_url", "video_url": {"url": source.get("url", "")}} if source_type == "base64": + _ensure_extension_media_block_supported( + provider_name=provider_name, + model_name=model_name, + modality="video", + source_type="base64", + ) media_type = source.get("media_type") data = source.get("data") if not isinstance(media_type, str) or not isinstance(data, str): raise ValueError(f"Canonical video base64 source must include media_type and data, got: {source!r}") + # No widely supported ``input_video`` block exists; capable OpenAI-compatible + # providers may accept a data URI in ``video_url``. return {"type": "video_url", "video_url": {"url": f"data:{media_type};base64,{data}"}} raise ValueError(f"Unsupported canonical video source type {source_type!r}") +def _ensure_extension_media_block_supported( + *, + provider_name: str, + model_name: str, + modality: str, + source_type: str, +) -> None: + if provider_name.lower() != OPENAI_PROVIDER_NAME: + return + raise ProviderError.unsupported_capability( + provider_name=provider_name, + model_name=model_name, + operation=f"{modality}-{source_type} context", + message=( + f"Provider {provider_name!r} does not support {modality} {source_type} context blocks " + "through the OpenAI-compatible chat-completions adapter." 
+ ), + ) + + def _get_media_source(block: dict[str, Any], *, modality: str) -> dict[str, Any]: source = block.get("source") if not isinstance(source, dict): diff --git a/packages/data-designer-engine/tests/engine/models/clients/test_openai_compatible.py b/packages/data-designer-engine/tests/engine/models/clients/test_openai_compatible.py index 0e5272534..b03bbb532 100644 --- a/packages/data-designer-engine/tests/engine/models/clients/test_openai_compatible.py +++ b/packages/data-designer-engine/tests/engine/models/clients/test_openai_compatible.py @@ -22,6 +22,7 @@ from tests.engine.models.clients.conftest import make_mock_async_client, make_mock_sync_client PROVIDER = "test-provider" +OPENAI_PROVIDER = "openai" MODEL = "gpt-test" ENDPOINT = "https://api.example.com/v1" @@ -31,10 +32,11 @@ def _make_client( sync_client: MagicMock | None = None, async_client: MagicMock | None = None, api_key: str | None = "sk-test-key", + provider_name: str = PROVIDER, ) -> OpenAICompatibleClient: concurrency_mode = ClientConcurrencyMode.ASYNC if async_client is not None else ClientConcurrencyMode.SYNC return OpenAICompatibleClient( - provider_name=PROVIDER, + provider_name=provider_name, endpoint=ENDPOINT, api_key=api_key, concurrency_mode=concurrency_mode, @@ -387,6 +389,33 @@ def test_completion_translates_audio_url_blocks() -> None: assert content[0] == {"type": "audio_url", "audio_url": {"url": "https://example.com/download?id=123"}} +@pytest.mark.parametrize( + "audio_block", + [ + get_media_url_context(Modality.AUDIO.value, "https://example.com/download?id=123"), + {"type": "audio_url", "audio_url": {"url": "https://example.com/download?id=123"}}, + ], + ids=["canonical", "provider-specific"], +) +def test_completion_rejects_audio_url_blocks_for_openai_provider(audio_block: dict[str, Any]) -> None: + sync_mock = make_mock_sync_client(_chat_response()) + client = _make_client(sync_client=sync_mock, provider_name=OPENAI_PROVIDER) + + request = ChatCompletionRequest( + 
model=MODEL, + messages=[{"role": "user", "content": [audio_block, {"type": "text", "text": "Transcribe this."}]}], + ) + + with pytest.raises(ProviderError) as exc_info: + client.completion(request) + + assert exc_info.value.kind == ProviderErrorKind.UNSUPPORTED_CAPABILITY + assert exc_info.value.provider_name == OPENAI_PROVIDER + assert exc_info.value.model_name == MODEL + assert "audio url context" in exc_info.value.message + sync_mock.post.assert_not_called() + + def test_completion_translates_local_audio_path_blocks() -> None: sync_mock = make_mock_sync_client(_chat_response()) client = _make_client(sync_client=sync_mock) @@ -451,6 +480,35 @@ def test_completion_translates_base64_video_blocks() -> None: assert content[0] == {"type": "video_url", "video_url": {"url": "data:video/mp4;base64,abc123"}} +@pytest.mark.parametrize( + "video_block", + [ + get_media_url_context(Modality.VIDEO.value, "https://example.com/download?id=123"), + get_media_base64_context(Modality.VIDEO.value, "video/mp4", "abc123"), + {"type": "video_url", "video_url": {"url": "https://example.com/download?id=123"}}, + {"type": "input_video", "input_video": {"data": "abc123", "format": "mp4"}}, + ], + ids=["canonical-url", "canonical-base64", "provider-video-url", "provider-input-video"], +) +def test_completion_rejects_video_blocks_for_openai_provider(video_block: dict[str, Any]) -> None: + sync_mock = make_mock_sync_client(_chat_response()) + client = _make_client(sync_client=sync_mock, provider_name=OPENAI_PROVIDER) + + request = ChatCompletionRequest( + model=MODEL, + messages=[{"role": "user", "content": [video_block, {"type": "text", "text": "Describe this."}]}], + ) + + with pytest.raises(ProviderError) as exc_info: + client.completion(request) + + assert exc_info.value.kind == ProviderErrorKind.UNSUPPORTED_CAPABILITY + assert exc_info.value.provider_name == OPENAI_PROVIDER + assert exc_info.value.model_name == MODEL + assert "video" in exc_info.value.message + 
sync_mock.post.assert_not_called() + + # --- Auth headers --- From badf4e11f46ba5f62ba3bdd7ebd4dc72fd42a32e Mon Sep 17 00:00:00 2001 From: Nabin Mulepati Date: Fri, 22 May 2026 10:28:02 -0600 Subject: [PATCH 13/15] remove openai media preflight --- .../clients/adapters/openai_compatible.py | 100 ++---------------- .../models/clients/test_openai_compatible.py | 79 ++++---------- 2 files changed, 29 insertions(+), 150 deletions(-) diff --git a/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/openai_compatible.py b/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/openai_compatible.py index 594b9a347..e9f9228a1 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/openai_compatible.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/openai_compatible.py @@ -5,7 +5,6 @@ from typing import Any -from data_designer.config.utils.constants import OPENAI_PROVIDER_NAME from data_designer.config.utils.media_helpers import audio_format_from_mime_type from data_designer.engine.models.clients.adapters.http_model_client import ( HttpModelClient, @@ -168,11 +167,7 @@ def translate_openai_compatible_messages( translated = dict(message) if "content" in translated: try: - translated["content"] = translate_openai_compatible_content_blocks( - translated["content"], - provider_name=provider_name, - model_name=model_name, - ) + translated["content"] = translate_openai_compatible_content_blocks(translated["content"]) except ValueError as exc: raise ProviderError( kind=ProviderErrorKind.BAD_REQUEST, @@ -185,57 +180,28 @@ def translate_openai_compatible_messages( return translated_messages -def translate_openai_compatible_content_blocks( - content: Any, - *, - provider_name: str, - model_name: str, -) -> Any: +def translate_openai_compatible_content_blocks(content: Any) -> Any: if not isinstance(content, list): return content - return [ - 
translate_openai_compatible_content_block(block, provider_name=provider_name, model_name=model_name) - for block in content - ] + return [translate_openai_compatible_content_block(block) for block in content] -def translate_openai_compatible_content_block( - block: Any, - *, - provider_name: str, - model_name: str, -) -> Any: +def translate_openai_compatible_content_block(block: Any) -> Any: if not isinstance(block, dict): return block block_type = block.get("type") if not isinstance(block_type, str): return block - if block_type in {"image_url", "input_audio", "text"}: - return block - if block_type == "audio_url": - _ensure_extension_media_block_supported( - provider_name=provider_name, - model_name=model_name, - modality="audio", - source_type="url", - ) - return block - if block_type in {"video_url", "input_video"}: - _ensure_extension_media_block_supported( - provider_name=provider_name, - model_name=model_name, - modality="video", - source_type="url", - ) + if block_type in {"audio_url", "image_url", "input_audio", "input_video", "text", "video_url"}: return block if block_type == "image": return _translate_canonical_image_block(block) if block_type == "audio": - return _translate_canonical_audio_block(block, provider_name=provider_name, model_name=model_name) + return _translate_canonical_audio_block(block) if block_type == "video": - return _translate_canonical_video_block(block, provider_name=provider_name, model_name=model_name) + return _translate_canonical_video_block(block) return block @@ -253,21 +219,10 @@ def _translate_canonical_image_block(block: dict[str, Any]) -> dict[str, Any]: raise ValueError(f"Unsupported canonical image source type {source_type!r}") -def _translate_canonical_audio_block( - block: dict[str, Any], - *, - provider_name: str, - model_name: str, -) -> dict[str, Any]: +def _translate_canonical_audio_block(block: dict[str, Any]) -> dict[str, Any]: source = _get_media_source(block, modality="audio") source_type = source.get("type") 
if source_type == "url": - _ensure_extension_media_block_supported( - provider_name=provider_name, - model_name=model_name, - modality="audio", - source_type="url", - ) # ``audio_url`` is an OpenAI-compatible extension used by providers such as vLLM/NVIDIA, # not by OpenAI's hosted Chat Completions route. return {"type": "audio_url", "audio_url": {"url": source.get("url", "")}} @@ -283,31 +238,14 @@ def _translate_canonical_audio_block( raise ValueError(f"Unsupported canonical audio source type {source_type!r}") -def _translate_canonical_video_block( - block: dict[str, Any], - *, - provider_name: str, - model_name: str, -) -> dict[str, Any]: +def _translate_canonical_video_block(block: dict[str, Any]) -> dict[str, Any]: source = _get_media_source(block, modality="video") source_type = source.get("type") if source_type == "url": - _ensure_extension_media_block_supported( - provider_name=provider_name, - model_name=model_name, - modality="video", - source_type="url", - ) # ``video_url`` is an OpenAI-compatible extension used by providers such as vLLM/NVIDIA, # not by OpenAI's hosted Chat Completions route. 
return {"type": "video_url", "video_url": {"url": source.get("url", "")}} if source_type == "base64": - _ensure_extension_media_block_supported( - provider_name=provider_name, - model_name=model_name, - modality="video", - source_type="base64", - ) media_type = source.get("media_type") data = source.get("data") if not isinstance(media_type, str) or not isinstance(data, str): @@ -318,26 +256,6 @@ def _translate_canonical_video_block( raise ValueError(f"Unsupported canonical video source type {source_type!r}") -def _ensure_extension_media_block_supported( - *, - provider_name: str, - model_name: str, - modality: str, - source_type: str, -) -> None: - if provider_name.lower() != OPENAI_PROVIDER_NAME: - return - raise ProviderError.unsupported_capability( - provider_name=provider_name, - model_name=model_name, - operation=f"{modality}-{source_type} context", - message=( - f"Provider {provider_name!r} does not support {modality} {source_type} context blocks " - "through the OpenAI-compatible chat-completions adapter." 
- ), - ) - - def _get_media_source(block: dict[str, Any], *, modality: str) -> dict[str, Any]: source = block.get("source") if not isinstance(source, dict): diff --git a/packages/data-designer-engine/tests/engine/models/clients/test_openai_compatible.py b/packages/data-designer-engine/tests/engine/models/clients/test_openai_compatible.py index b03bbb532..262d09c90 100644 --- a/packages/data-designer-engine/tests/engine/models/clients/test_openai_compatible.py +++ b/packages/data-designer-engine/tests/engine/models/clients/test_openai_compatible.py @@ -22,7 +22,6 @@ from tests.engine.models.clients.conftest import make_mock_async_client, make_mock_sync_client PROVIDER = "test-provider" -OPENAI_PROVIDER = "openai" MODEL = "gpt-test" ENDPOINT = "https://api.example.com/v1" @@ -32,11 +31,10 @@ def _make_client( sync_client: MagicMock | None = None, async_client: MagicMock | None = None, api_key: str | None = "sk-test-key", - provider_name: str = PROVIDER, ) -> OpenAICompatibleClient: concurrency_mode = ClientConcurrencyMode.ASYNC if async_client is not None else ClientConcurrencyMode.SYNC return OpenAICompatibleClient( - provider_name=provider_name, + provider_name=PROVIDER, endpoint=ENDPOINT, api_key=api_key, concurrency_mode=concurrency_mode, @@ -389,33 +387,6 @@ def test_completion_translates_audio_url_blocks() -> None: assert content[0] == {"type": "audio_url", "audio_url": {"url": "https://example.com/download?id=123"}} -@pytest.mark.parametrize( - "audio_block", - [ - get_media_url_context(Modality.AUDIO.value, "https://example.com/download?id=123"), - {"type": "audio_url", "audio_url": {"url": "https://example.com/download?id=123"}}, - ], - ids=["canonical", "provider-specific"], -) -def test_completion_rejects_audio_url_blocks_for_openai_provider(audio_block: dict[str, Any]) -> None: - sync_mock = make_mock_sync_client(_chat_response()) - client = _make_client(sync_client=sync_mock, provider_name=OPENAI_PROVIDER) - - request = ChatCompletionRequest( - 
model=MODEL, - messages=[{"role": "user", "content": [audio_block, {"type": "text", "text": "Transcribe this."}]}], - ) - - with pytest.raises(ProviderError) as exc_info: - client.completion(request) - - assert exc_info.value.kind == ProviderErrorKind.UNSUPPORTED_CAPABILITY - assert exc_info.value.provider_name == OPENAI_PROVIDER - assert exc_info.value.model_name == MODEL - assert "audio url context" in exc_info.value.message - sync_mock.post.assert_not_called() - - def test_completion_translates_local_audio_path_blocks() -> None: sync_mock = make_mock_sync_client(_chat_response()) client = _make_client(sync_client=sync_mock) @@ -480,35 +451,6 @@ def test_completion_translates_base64_video_blocks() -> None: assert content[0] == {"type": "video_url", "video_url": {"url": "data:video/mp4;base64,abc123"}} -@pytest.mark.parametrize( - "video_block", - [ - get_media_url_context(Modality.VIDEO.value, "https://example.com/download?id=123"), - get_media_base64_context(Modality.VIDEO.value, "video/mp4", "abc123"), - {"type": "video_url", "video_url": {"url": "https://example.com/download?id=123"}}, - {"type": "input_video", "input_video": {"data": "abc123", "format": "mp4"}}, - ], - ids=["canonical-url", "canonical-base64", "provider-video-url", "provider-input-video"], -) -def test_completion_rejects_video_blocks_for_openai_provider(video_block: dict[str, Any]) -> None: - sync_mock = make_mock_sync_client(_chat_response()) - client = _make_client(sync_client=sync_mock, provider_name=OPENAI_PROVIDER) - - request = ChatCompletionRequest( - model=MODEL, - messages=[{"role": "user", "content": [video_block, {"type": "text", "text": "Describe this."}]}], - ) - - with pytest.raises(ProviderError) as exc_info: - client.completion(request) - - assert exc_info.value.kind == ProviderErrorKind.UNSUPPORTED_CAPABILITY - assert exc_info.value.provider_name == OPENAI_PROVIDER - assert exc_info.value.model_name == MODEL - assert "video" in exc_info.value.message - 
sync_mock.post.assert_not_called() - - # --- Auth headers --- @@ -578,6 +520,25 @@ def test_http_error_maps_to_provider_error( assert exc_info.value.kind == expected_kind +def test_http_400_error_preserves_provider_message() -> None: + error_json = {"error": {"message": "Unsupported content type: audio_url"}} + sync_mock = make_mock_sync_client(error_json, status_code=400) + client = _make_client(sync_client=sync_mock) + + audio_block = get_media_url_context(Modality.AUDIO.value, "https://example.com/speech.mp3") + request = ChatCompletionRequest( + model=MODEL, + messages=[{"role": "user", "content": [audio_block, {"type": "text", "text": "Transcribe this."}]}], + ) + with pytest.raises(ProviderError) as exc_info: + client.completion(request) + + assert exc_info.value.kind == ProviderErrorKind.BAD_REQUEST + assert exc_info.value.status_code == 400 + assert "Unsupported content type: audio_url" in exc_info.value.message + sync_mock.post.assert_called_once() + + def test_transport_timeout_raises_timeout_error() -> None: sync_mock = MagicMock() sync_mock.post = MagicMock(side_effect=TimeoutError("timed out")) From 4d981fbe2ca550498ec8e551e249728562d019e7 Mon Sep 17 00:00:00 2001 From: Nabin Mulepati Date: Fri, 22 May 2026 11:07:00 -0600 Subject: [PATCH 14/15] sync generated colab notebooks --- docs/colab_notebooks/1-the-basics.ipynb | 80 +++++++++-------- ...ctured-outputs-and-jinja-expressions.ipynb | 90 ++++++++++--------- .../3-seeding-with-a-dataset.ipynb | 72 ++++++++------- .../4-providing-images-as-context.ipynb | 82 +++++++++-------- .../colab_notebooks/5-generating-images.ipynb | 60 +++++++------ .../6-editing-images-with-image-context.ipynb | 66 ++++++++------ 6 files changed, 249 insertions(+), 201 deletions(-) diff --git a/docs/colab_notebooks/1-the-basics.ipynb b/docs/colab_notebooks/1-the-basics.ipynb index a55028a44..bef78d828 100644 --- a/docs/colab_notebooks/1-the-basics.ipynb +++ b/docs/colab_notebooks/1-the-basics.ipynb @@ -2,15 +2,17 @@ "cells": [ 
{ "cell_type": "markdown", - "id": "e9bc2aab", - "metadata": {}, + "id": "f5bc03e0", + "metadata": { + "nemo_colab_inject": true + }, "source": [ "\"Open" ] }, { "cell_type": "markdown", - "id": "33dcb5be", + "id": "3454d676", "metadata": {}, "source": [ "# 🎨 Data Designer Tutorial: The Basics\n", @@ -22,7 +24,7 @@ }, { "cell_type": "markdown", - "id": "adb77b8d", + "id": "4737bc0d", "metadata": {}, "source": [ "### πŸ“¦ Import Data Designer\n", @@ -34,8 +36,10 @@ }, { "cell_type": "markdown", - "id": "170ce1ea", - "metadata": {}, + "id": "cf21b784", + "metadata": { + "nemo_colab_inject": true + }, "source": [ "### ⚑ Colab Setup\n", "\n", @@ -45,8 +49,10 @@ { "cell_type": "code", "execution_count": null, - "id": "67e478f9", - "metadata": {}, + "id": "a87ee38a", + "metadata": { + "nemo_colab_inject": true + }, "outputs": [], "source": [ "%%capture\n", @@ -56,8 +62,10 @@ { "cell_type": "code", "execution_count": null, - "id": "533fc40d", - "metadata": {}, + "id": "d6c4cc8a", + "metadata": { + "nemo_colab_inject": true + }, "outputs": [], "source": [ "import getpass\n", @@ -74,7 +82,7 @@ { "cell_type": "code", "execution_count": null, - "id": "9ad92889", + "id": "848dea00", "metadata": {}, "outputs": [], "source": [ @@ -84,7 +92,7 @@ }, { "cell_type": "markdown", - "id": "0232c4c6", + "id": "b97786a8", "metadata": {}, "source": [ "### βš™οΈ Initialize the Data Designer interface\n", @@ -97,7 +105,7 @@ { "cell_type": "code", "execution_count": null, - "id": "fbbd0cab", + "id": "b31c1fc9", "metadata": {}, "outputs": [], "source": [ @@ -106,7 +114,7 @@ }, { "cell_type": "markdown", - "id": "305f635e", + "id": "2fef5ae8", "metadata": {}, "source": [ "### πŸŽ›οΈ Define model configurations\n", @@ -123,7 +131,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d0865b58", + "id": "7a9f6398", "metadata": {}, "outputs": [], "source": [ @@ -153,7 +161,7 @@ }, { "cell_type": "markdown", - "id": "6e1624f7", + "id": "1d0a178f", "metadata": {}, "source": [ "### πŸ—οΈ 
Initialize the Data Designer Config Builder\n", @@ -168,7 +176,7 @@ { "cell_type": "code", "execution_count": null, - "id": "33562cda", + "id": "aacc0ec5", "metadata": {}, "outputs": [], "source": [ @@ -177,7 +185,7 @@ }, { "cell_type": "markdown", - "id": "d8ec3063", + "id": "4be3497f", "metadata": {}, "source": [ "## 🎲 Getting started with sampler columns\n", @@ -194,7 +202,7 @@ { "cell_type": "code", "execution_count": null, - "id": "70de1b0a", + "id": "e212d83e", "metadata": {}, "outputs": [], "source": [ @@ -203,7 +211,7 @@ }, { "cell_type": "markdown", - "id": "991a8f34", + "id": "c28350d3", "metadata": {}, "source": [ "Let's start designing our product review dataset by adding product category and subcategory columns.\n" @@ -212,7 +220,7 @@ { "cell_type": "code", "execution_count": null, - "id": "222cbbcc", + "id": "070f14e7", "metadata": {}, "outputs": [], "source": [ @@ -293,7 +301,7 @@ }, { "cell_type": "markdown", - "id": "29ca2aa3", + "id": "e0d8497d", "metadata": {}, "source": [ "Next, let's add samplers to generate data related to the customer and their review.\n" @@ -302,7 +310,7 @@ { "cell_type": "code", "execution_count": null, - "id": "4ca9ba1c", + "id": "62e84282", "metadata": {}, "outputs": [], "source": [ @@ -339,7 +347,7 @@ }, { "cell_type": "markdown", - "id": "f4d54299", + "id": "8cb147fa", "metadata": {}, "source": [ "## 🦜 LLM-generated columns\n", @@ -354,7 +362,7 @@ { "cell_type": "code", "execution_count": null, - "id": "361b63b1", + "id": "37a4a6d0", "metadata": {}, "outputs": [], "source": [ @@ -390,7 +398,7 @@ }, { "cell_type": "markdown", - "id": "49ca028a", + "id": "49559576", "metadata": {}, "source": [ "### πŸ” Iteration is key – preview the dataset!\n", @@ -407,7 +415,7 @@ { "cell_type": "code", "execution_count": null, - "id": "068ea8c3", + "id": "0d52b447", "metadata": {}, "outputs": [], "source": [ @@ -417,7 +425,7 @@ { "cell_type": "code", "execution_count": null, - "id": "bf196a77", + "id": "088a5004", "metadata": {}, 
"outputs": [], "source": [ @@ -428,7 +436,7 @@ { "cell_type": "code", "execution_count": null, - "id": "36ebb017", + "id": "9780021a", "metadata": {}, "outputs": [], "source": [ @@ -438,7 +446,7 @@ }, { "cell_type": "markdown", - "id": "1dcba545", + "id": "c9122bc6", "metadata": {}, "source": [ "### πŸ“Š Analyze the generated data\n", @@ -451,7 +459,7 @@ { "cell_type": "code", "execution_count": null, - "id": "e5164902", + "id": "4d6bb3c5", "metadata": {}, "outputs": [], "source": [ @@ -461,7 +469,7 @@ }, { "cell_type": "markdown", - "id": "cc433fae", + "id": "6003ae71", "metadata": {}, "source": [ "### πŸ†™ Scale up!\n", @@ -474,7 +482,7 @@ { "cell_type": "code", "execution_count": null, - "id": "17132fe2", + "id": "e343639d", "metadata": {}, "outputs": [], "source": [ @@ -484,7 +492,7 @@ { "cell_type": "code", "execution_count": null, - "id": "6669442a", + "id": "cd328abd", "metadata": {}, "outputs": [], "source": [ @@ -497,7 +505,7 @@ { "cell_type": "code", "execution_count": null, - "id": "ee689b41", + "id": "6a09793a", "metadata": {}, "outputs": [], "source": [ @@ -509,7 +517,7 @@ }, { "cell_type": "markdown", - "id": "6965e6ac", + "id": "769dd181", "metadata": {}, "source": [ "## ⏭️ Next Steps\n", diff --git a/docs/colab_notebooks/2-structured-outputs-and-jinja-expressions.ipynb b/docs/colab_notebooks/2-structured-outputs-and-jinja-expressions.ipynb index 77272fbb1..276bc86d7 100644 --- a/docs/colab_notebooks/2-structured-outputs-and-jinja-expressions.ipynb +++ b/docs/colab_notebooks/2-structured-outputs-and-jinja-expressions.ipynb @@ -2,15 +2,17 @@ "cells": [ { "cell_type": "markdown", - "id": "f4f854dd", - "metadata": {}, + "id": "f81a1643", + "metadata": { + "nemo_colab_inject": true + }, "source": [ "\"Open" ] }, { "cell_type": "markdown", - "id": "027ffdf3", + "id": "0c33bf13", "metadata": {}, "source": [ "# 🎨 Data Designer Tutorial: Structured Outputs, Jinja Expressions, and Conditional Generation\n", @@ -24,7 +26,7 @@ }, { "cell_type": "markdown", - 
"id": "158f95c6", + "id": "37d85d1d", "metadata": {}, "source": [ "### πŸ“¦ Import Data Designer\n", @@ -36,8 +38,10 @@ }, { "cell_type": "markdown", - "id": "459b2f2b", - "metadata": {}, + "id": "a3b60315", + "metadata": { + "nemo_colab_inject": true + }, "source": [ "### ⚑ Colab Setup\n", "\n", @@ -47,8 +51,10 @@ { "cell_type": "code", "execution_count": null, - "id": "2bdb065c", - "metadata": {}, + "id": "71ea617c", + "metadata": { + "nemo_colab_inject": true + }, "outputs": [], "source": [ "%%capture\n", @@ -58,8 +64,10 @@ { "cell_type": "code", "execution_count": null, - "id": "8ccc1e8f", - "metadata": {}, + "id": "ab7f4096", + "metadata": { + "nemo_colab_inject": true + }, "outputs": [], "source": [ "import getpass\n", @@ -76,7 +84,7 @@ { "cell_type": "code", "execution_count": null, - "id": "aeb8441e", + "id": "03e30510", "metadata": {}, "outputs": [], "source": [ @@ -86,7 +94,7 @@ }, { "cell_type": "markdown", - "id": "df989756", + "id": "31946b6a", "metadata": {}, "source": [ "### βš™οΈ Initialize the Data Designer interface\n", @@ -99,7 +107,7 @@ { "cell_type": "code", "execution_count": null, - "id": "2a8f113b", + "id": "edb53392", "metadata": {}, "outputs": [], "source": [ @@ -108,7 +116,7 @@ }, { "cell_type": "markdown", - "id": "b986772a", + "id": "7979bbb9", "metadata": {}, "source": [ "### πŸŽ›οΈ Define model configurations\n", @@ -125,7 +133,7 @@ { "cell_type": "code", "execution_count": null, - "id": "2ce9cf8c", + "id": "96d72c55", "metadata": {}, "outputs": [], "source": [ @@ -155,7 +163,7 @@ }, { "cell_type": "markdown", - "id": "6b5ab2ea", + "id": "ddd9d06f", "metadata": {}, "source": [ "### πŸ—οΈ Initialize the Data Designer Config Builder\n", @@ -170,7 +178,7 @@ { "cell_type": "code", "execution_count": null, - "id": "69a41c06", + "id": "96581bae", "metadata": {}, "outputs": [], "source": [ @@ -179,7 +187,7 @@ }, { "cell_type": "markdown", - "id": "b17aca77", + "id": "f8865cae", "metadata": {}, "source": [ "### πŸ§‘β€πŸŽ¨ Designing our 
data\n", @@ -206,7 +214,7 @@ { "cell_type": "code", "execution_count": null, - "id": "133df1c0", + "id": "85166774", "metadata": {}, "outputs": [], "source": [ @@ -234,7 +242,7 @@ }, { "cell_type": "markdown", - "id": "2535b9c0", + "id": "5a606cbe", "metadata": {}, "source": [ "Next, let's design our product review dataset using a few more tricks compared to the previous notebook.\n" @@ -243,7 +251,7 @@ { "cell_type": "code", "execution_count": null, - "id": "7d4d991d", + "id": "ca79722a", "metadata": {}, "outputs": [], "source": [ @@ -352,7 +360,7 @@ }, { "cell_type": "markdown", - "id": "afc66880", + "id": "3df06121", "metadata": {}, "source": [ "Next, we will use more advanced Jinja expressions to create new columns.\n", @@ -369,7 +377,7 @@ { "cell_type": "code", "execution_count": null, - "id": "5d8452b2", + "id": "932d9c49", "metadata": {}, "outputs": [], "source": [ @@ -422,7 +430,7 @@ }, { "cell_type": "markdown", - "id": "d7780299", + "id": "0ee33040", "metadata": {}, "source": [ "## 🚦 Conditional generation with `skip.when`\n", @@ -445,7 +453,7 @@ }, { "cell_type": "markdown", - "id": "794ac1aa", + "id": "f4749d4b", "metadata": {}, "source": [ "**Pattern 1 β€” Expression gate.** Only generate a detailed complaint analysis when the customer gave a low rating (1 or 2 stars).\n", @@ -455,7 +463,7 @@ { "cell_type": "code", "execution_count": null, - "id": "6d96baaa", + "id": "4b18aefc", "metadata": {}, "outputs": [], "source": [ @@ -478,7 +486,7 @@ }, { "cell_type": "markdown", - "id": "a3598079", + "id": "9f3bedb2", "metadata": {}, "source": [ "**Pattern 2 β€” Skip propagation.** `action_items` depends on `complaint_analysis`.\n", @@ -489,7 +497,7 @@ { "cell_type": "code", "execution_count": null, - "id": "59be7563", + "id": "a7407102", "metadata": {}, "outputs": [], "source": [ @@ -508,7 +516,7 @@ }, { "cell_type": "markdown", - "id": "44cfc2e8", + "id": "c3222b17", "metadata": {}, "source": [ "**Pattern 3 β€” Propagation opt-out.** `review_summary` also 
depends on `complaint_analysis`,\n", @@ -519,7 +527,7 @@ { "cell_type": "code", "execution_count": null, - "id": "a9cee7fe", + "id": "2fda072a", "metadata": {}, "outputs": [], "source": [ @@ -545,7 +553,7 @@ }, { "cell_type": "markdown", - "id": "67f39d99", + "id": "dfaf3d79", "metadata": {}, "source": [ "### πŸ” Iteration is key – preview the dataset!\n", @@ -562,7 +570,7 @@ { "cell_type": "code", "execution_count": null, - "id": "3aa1cd01", + "id": "cd207969", "metadata": {}, "outputs": [], "source": [ @@ -572,7 +580,7 @@ { "cell_type": "code", "execution_count": null, - "id": "5d78f540", + "id": "3466c5de", "metadata": {}, "outputs": [], "source": [ @@ -585,7 +593,7 @@ { "cell_type": "code", "execution_count": null, - "id": "86011901", + "id": "99ec7423", "metadata": {}, "outputs": [], "source": [ @@ -597,7 +605,7 @@ }, { "cell_type": "markdown", - "id": "8fa363ed", + "id": "c3b0c432", "metadata": {}, "source": [ "### πŸ“Š Analyze the generated data\n", @@ -610,7 +618,7 @@ { "cell_type": "code", "execution_count": null, - "id": "3dede878", + "id": "c9d31b2e", "metadata": {}, "outputs": [], "source": [ @@ -620,7 +628,7 @@ }, { "cell_type": "markdown", - "id": "38839a98", + "id": "bfe09a95", "metadata": {}, "source": [ "### πŸ†™ Scale up!\n", @@ -633,7 +641,7 @@ { "cell_type": "code", "execution_count": null, - "id": "8208f51b", + "id": "54e6a578", "metadata": {}, "outputs": [], "source": [ @@ -643,7 +651,7 @@ { "cell_type": "code", "execution_count": null, - "id": "2b07217f", + "id": "210d3b83", "metadata": {}, "outputs": [], "source": [ @@ -656,7 +664,7 @@ { "cell_type": "code", "execution_count": null, - "id": "7deaa6e2", + "id": "ba14deb4", "metadata": {}, "outputs": [], "source": [ @@ -668,7 +676,7 @@ }, { "cell_type": "markdown", - "id": "b4c1a576", + "id": "6a8319de", "metadata": {}, "source": [ "## ⏭️ Next Steps\n", diff --git a/docs/colab_notebooks/3-seeding-with-a-dataset.ipynb b/docs/colab_notebooks/3-seeding-with-a-dataset.ipynb index 
7aab5eaa8..54530ee77 100644 --- a/docs/colab_notebooks/3-seeding-with-a-dataset.ipynb +++ b/docs/colab_notebooks/3-seeding-with-a-dataset.ipynb @@ -2,15 +2,17 @@ "cells": [ { "cell_type": "markdown", - "id": "21e9e0eb", - "metadata": {}, + "id": "0a01390d", + "metadata": { + "nemo_colab_inject": true + }, "source": [ "\"Open" ] }, { "cell_type": "markdown", - "id": "b185696e", + "id": "1c353f07", "metadata": {}, "source": [ "# 🎨 Data Designer Tutorial: Seeding Synthetic Data Generation with an External Dataset\n", @@ -24,7 +26,7 @@ }, { "cell_type": "markdown", - "id": "692c9796", + "id": "ffeca512", "metadata": {}, "source": [ "### πŸ“¦ Import Data Designer\n", @@ -36,8 +38,10 @@ }, { "cell_type": "markdown", - "id": "daa8cd50", - "metadata": {}, + "id": "bd06dd7b", + "metadata": { + "nemo_colab_inject": true + }, "source": [ "### ⚑ Colab Setup\n", "\n", @@ -47,8 +51,10 @@ { "cell_type": "code", "execution_count": null, - "id": "8848bd1e", - "metadata": {}, + "id": "09d07f44", + "metadata": { + "nemo_colab_inject": true + }, "outputs": [], "source": [ "%%capture\n", @@ -58,8 +64,10 @@ { "cell_type": "code", "execution_count": null, - "id": "317ce78f", - "metadata": {}, + "id": "8d2baac1", + "metadata": { + "nemo_colab_inject": true + }, "outputs": [], "source": [ "import getpass\n", @@ -76,7 +84,7 @@ { "cell_type": "code", "execution_count": null, - "id": "1cb2d5c8", + "id": "48b16d15", "metadata": {}, "outputs": [], "source": [ @@ -86,7 +94,7 @@ }, { "cell_type": "markdown", - "id": "8b49428f", + "id": "7930135e", "metadata": {}, "source": [ "### βš™οΈ Initialize the Data Designer interface\n", @@ -99,7 +107,7 @@ { "cell_type": "code", "execution_count": null, - "id": "69df6d66", + "id": "b033aa9d", "metadata": {}, "outputs": [], "source": [ @@ -108,7 +116,7 @@ }, { "cell_type": "markdown", - "id": "50378de0", + "id": "50c00422", "metadata": {}, "source": [ "### πŸŽ›οΈ Define model configurations\n", @@ -125,7 +133,7 @@ { "cell_type": "code", "execution_count": 
null, - "id": "e932a29e", + "id": "b503a010", "metadata": {}, "outputs": [], "source": [ @@ -155,7 +163,7 @@ }, { "cell_type": "markdown", - "id": "9487eecc", + "id": "efca2a84", "metadata": {}, "source": [ "### πŸ—οΈ Initialize the Data Designer Config Builder\n", @@ -170,7 +178,7 @@ { "cell_type": "code", "execution_count": null, - "id": "172f0df0", + "id": "45afdfd9", "metadata": {}, "outputs": [], "source": [ @@ -179,7 +187,7 @@ }, { "cell_type": "markdown", - "id": "54700574", + "id": "fdcfa350", "metadata": {}, "source": [ "## πŸ₯ Prepare a seed dataset\n", @@ -204,7 +212,7 @@ { "cell_type": "code", "execution_count": null, - "id": "7c1e1f69", + "id": "1cb526af", "metadata": {}, "outputs": [], "source": [ @@ -222,7 +230,7 @@ }, { "cell_type": "markdown", - "id": "bdd24ad6", + "id": "0fcacdcc", "metadata": {}, "source": [ "## 🎨 Designing our synthetic patient notes dataset\n", @@ -235,7 +243,7 @@ { "cell_type": "code", "execution_count": null, - "id": "2b33b6f6", + "id": "26eb22c4", "metadata": {}, "outputs": [], "source": [ @@ -316,7 +324,7 @@ }, { "cell_type": "markdown", - "id": "2d23d1c3", + "id": "667c9ec4", "metadata": {}, "source": [ "### πŸ” Iteration is key – preview the dataset!\n", @@ -333,7 +341,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d2e864ef", + "id": "1f00421c", "metadata": {}, "outputs": [], "source": [ @@ -343,7 +351,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d948d638", + "id": "c05b8619", "metadata": {}, "outputs": [], "source": [ @@ -354,7 +362,7 @@ { "cell_type": "code", "execution_count": null, - "id": "a5bb03c7", + "id": "f04b086a", "metadata": {}, "outputs": [], "source": [ @@ -364,7 +372,7 @@ }, { "cell_type": "markdown", - "id": "a6d81e80", + "id": "8426dafb", "metadata": {}, "source": [ "### πŸ“Š Analyze the generated data\n", @@ -377,7 +385,7 @@ { "cell_type": "code", "execution_count": null, - "id": "536d8500", + "id": "1f532c12", "metadata": {}, "outputs": [], "source": [ @@ -387,7 
+395,7 @@ }, { "cell_type": "markdown", - "id": "e93e1239", + "id": "033d314c", "metadata": {}, "source": [ "### πŸ†™ Scale up!\n", @@ -400,7 +408,7 @@ { "cell_type": "code", "execution_count": null, - "id": "60a30857", + "id": "f4c27dd8", "metadata": {}, "outputs": [], "source": [ @@ -410,7 +418,7 @@ { "cell_type": "code", "execution_count": null, - "id": "b746c558", + "id": "6188ed81", "metadata": {}, "outputs": [], "source": [ @@ -423,7 +431,7 @@ { "cell_type": "code", "execution_count": null, - "id": "e8aa5c7e", + "id": "8e27cc08", "metadata": {}, "outputs": [], "source": [ @@ -435,7 +443,7 @@ }, { "cell_type": "markdown", - "id": "023fff7b", + "id": "44d280d6", "metadata": {}, "source": [ "## ⏭️ Next Steps\n", diff --git a/docs/colab_notebooks/4-providing-images-as-context.ipynb b/docs/colab_notebooks/4-providing-images-as-context.ipynb index 3355cafc1..7aa71227c 100644 --- a/docs/colab_notebooks/4-providing-images-as-context.ipynb +++ b/docs/colab_notebooks/4-providing-images-as-context.ipynb @@ -2,15 +2,17 @@ "cells": [ { "cell_type": "markdown", - "id": "f7d47856", - "metadata": {}, + "id": "cd505b79", + "metadata": { + "nemo_colab_inject": true + }, "source": [ "\"Open" ] }, { "cell_type": "markdown", - "id": "e826ba2c", + "id": "ed119996", "metadata": {}, "source": [ "# 🎨 Data Designer Tutorial: Providing Images as Context for Vision-Based Data Generation" @@ -18,7 +20,7 @@ }, { "cell_type": "markdown", - "id": "4e0854f1", + "id": "d13a4cb5", "metadata": {}, "source": [ "#### πŸ“š What you'll learn\n", @@ -35,7 +37,7 @@ }, { "cell_type": "markdown", - "id": "adc08017", + "id": "2924c2d1", "metadata": {}, "source": [ "### πŸ“¦ Import Data Designer\n", @@ -47,8 +49,10 @@ }, { "cell_type": "markdown", - "id": "c68a6c2c", - "metadata": {}, + "id": "4c6e4f22", + "metadata": { + "nemo_colab_inject": true + }, "source": [ "### ⚑ Colab Setup\n", "\n", @@ -58,8 +62,10 @@ { "cell_type": "code", "execution_count": null, - "id": "67bf78ce", - "metadata": {}, + "id": 
"98151070", + "metadata": { + "nemo_colab_inject": true + }, "outputs": [], "source": [ "%%capture\n", @@ -69,8 +75,10 @@ { "cell_type": "code", "execution_count": null, - "id": "21bbf67b", - "metadata": {}, + "id": "5490b9a8", + "metadata": { + "nemo_colab_inject": true + }, "outputs": [], "source": [ "import getpass\n", @@ -87,7 +95,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d7056b4d", + "id": "7a66e1ce", "metadata": {}, "outputs": [], "source": [ @@ -110,7 +118,7 @@ }, { "cell_type": "markdown", - "id": "48235c24", + "id": "3e7a28c6", "metadata": {}, "source": [ "### βš™οΈ Initialize the Data Designer interface\n", @@ -123,7 +131,7 @@ { "cell_type": "code", "execution_count": null, - "id": "768218ca", + "id": "f31d6ac0", "metadata": {}, "outputs": [], "source": [ @@ -132,7 +140,7 @@ }, { "cell_type": "markdown", - "id": "ff4a52ed", + "id": "14b063e4", "metadata": {}, "source": [ "### πŸ—οΈ Initialize the Data Designer Config Builder\n", @@ -147,7 +155,7 @@ { "cell_type": "code", "execution_count": null, - "id": "42640912", + "id": "d8fd37ae", "metadata": {}, "outputs": [], "source": [ @@ -156,7 +164,7 @@ }, { "cell_type": "markdown", - "id": "4ecad6af", + "id": "3a7e0787", "metadata": {}, "source": [ "### 🌱 Seed Dataset Creation\n", @@ -173,7 +181,7 @@ { "cell_type": "code", "execution_count": null, - "id": "bafdf91f", + "id": "b01b5496", "metadata": {}, "outputs": [], "source": [ @@ -188,7 +196,7 @@ { "cell_type": "code", "execution_count": null, - "id": "dc5c92ac", + "id": "78b3b9ea", "metadata": {}, "outputs": [], "source": [ @@ -233,7 +241,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d4cde737", + "id": "7b6b2908", "metadata": {}, "outputs": [], "source": [ @@ -251,7 +259,7 @@ { "cell_type": "code", "execution_count": null, - "id": "39848e33", + "id": "e0ab09d5", "metadata": {}, "outputs": [], "source": [ @@ -261,7 +269,7 @@ { "cell_type": "code", "execution_count": null, - "id": "b94581da", + "id": "c9ce69ed", 
"metadata": {}, "outputs": [], "source": [ @@ -272,7 +280,7 @@ }, { "cell_type": "markdown", - "id": "media-context-capabilities", + "id": "94528475", "metadata": {}, "source": [ "### 🧩 Media context and model capabilities\n", @@ -309,7 +317,7 @@ { "cell_type": "code", "execution_count": null, - "id": "7c561ff0", + "id": "bd8148f4", "metadata": {}, "outputs": [], "source": [ @@ -331,7 +339,7 @@ }, { "cell_type": "markdown", - "id": "99a5ad0c", + "id": "2150d704", "metadata": {}, "source": [ "### πŸ” Iteration is key – preview the dataset!\n", @@ -348,7 +356,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d32dcf48", + "id": "85cf2067", "metadata": {}, "outputs": [], "source": [ @@ -358,7 +366,7 @@ { "cell_type": "code", "execution_count": null, - "id": "70db2f87", + "id": "509f00ed", "metadata": {}, "outputs": [], "source": [ @@ -369,7 +377,7 @@ { "cell_type": "code", "execution_count": null, - "id": "8b65b184", + "id": "8b1a7d15", "metadata": {}, "outputs": [], "source": [ @@ -379,7 +387,7 @@ }, { "cell_type": "markdown", - "id": "58e3147f", + "id": "9bf4843c", "metadata": {}, "source": [ "### πŸ“Š Analyze the generated data\n", @@ -392,7 +400,7 @@ { "cell_type": "code", "execution_count": null, - "id": "82b01514", + "id": "d80d106d", "metadata": {}, "outputs": [], "source": [ @@ -402,7 +410,7 @@ }, { "cell_type": "markdown", - "id": "8274677b", + "id": "ed22e721", "metadata": {}, "source": [ "### πŸ”Ž Visual Inspection\n", @@ -413,7 +421,7 @@ { "cell_type": "code", "execution_count": null, - "id": "e7bd89dc", + "id": "f41c068a", "metadata": { "lines_to_next_cell": 2 }, @@ -437,7 +445,7 @@ }, { "cell_type": "markdown", - "id": "01f6d07d", + "id": "f096be05", "metadata": {}, "source": [ "### πŸ†™ Scale up!\n", @@ -450,7 +458,7 @@ { "cell_type": "code", "execution_count": null, - "id": "21981b68", + "id": "c2efd0f8", "metadata": {}, "outputs": [], "source": [ @@ -460,7 +468,7 @@ { "cell_type": "code", "execution_count": null, - "id": "7c655cea", + 
"id": "1f7b5f60", "metadata": {}, "outputs": [], "source": [ @@ -473,7 +481,7 @@ { "cell_type": "code", "execution_count": null, - "id": "291a3dfc", + "id": "dbb9ea18", "metadata": {}, "outputs": [], "source": [ @@ -485,7 +493,7 @@ }, { "cell_type": "markdown", - "id": "af7c69cc", + "id": "f7a1f3ba", "metadata": {}, "source": [ "## ⏭️ Next Steps\n", diff --git a/docs/colab_notebooks/5-generating-images.ipynb b/docs/colab_notebooks/5-generating-images.ipynb index efecb0387..76a933da0 100644 --- a/docs/colab_notebooks/5-generating-images.ipynb +++ b/docs/colab_notebooks/5-generating-images.ipynb @@ -2,15 +2,17 @@ "cells": [ { "cell_type": "markdown", - "id": "52eeca6e", - "metadata": {}, + "id": "66019c7e", + "metadata": { + "nemo_colab_inject": true + }, "source": [ "\"Open" ] }, { "cell_type": "markdown", - "id": "ea02d680", + "id": "267d3938", "metadata": {}, "source": [ "# 🎨 Data Designer Tutorial: Generating Images\n", @@ -32,7 +34,7 @@ }, { "cell_type": "markdown", - "id": "1c36e1cd", + "id": "486d74eb", "metadata": {}, "source": [ "### πŸ“¦ Import Data Designer\n", @@ -43,8 +45,10 @@ }, { "cell_type": "markdown", - "id": "4933a0df", - "metadata": {}, + "id": "9c888db8", + "metadata": { + "nemo_colab_inject": true + }, "source": [ "### ⚑ Colab Setup\n", "\n", @@ -54,8 +58,10 @@ { "cell_type": "code", "execution_count": null, - "id": "abe49f1b", - "metadata": {}, + "id": "4fcdfb3f", + "metadata": { + "nemo_colab_inject": true + }, "outputs": [], "source": [ "%%capture\n", @@ -65,8 +71,10 @@ { "cell_type": "code", "execution_count": null, - "id": "f6ffa0a4", - "metadata": {}, + "id": "6a87ecb2", + "metadata": { + "nemo_colab_inject": true + }, "outputs": [], "source": [ "import getpass\n", @@ -83,7 +91,7 @@ { "cell_type": "code", "execution_count": null, - "id": "f1de4914", + "id": "ec5dd8e7", "metadata": {}, "outputs": [], "source": [ @@ -96,7 +104,7 @@ }, { "cell_type": "markdown", - "id": "112c71f5", + "id": "651d9f3b", "metadata": {}, "source": [ "### βš™οΈ 
Initialize the Data Designer interface\n", @@ -107,7 +115,7 @@ { "cell_type": "code", "execution_count": null, - "id": "88c82623", + "id": "5fc8972e", "metadata": {}, "outputs": [], "source": [ @@ -116,7 +124,7 @@ }, { "cell_type": "markdown", - "id": "50ca5262", + "id": "dd50d576", "metadata": {}, "source": [ "### 🎛️ Define an image-generation model\n", @@ -128,7 +136,7 @@ { "cell_type": "code", "execution_count": null, - "id": "49fdc61e", + "id": "03ca2abf", "metadata": {}, "outputs": [], "source": [ @@ -150,7 +158,7 @@ }, { "cell_type": "markdown", - "id": "6740ea52", + "id": "73bf1fa1", "metadata": {}, "source": [ "### 🏗️ Build the config: samplers + image column\n", @@ -161,7 +169,7 @@ { "cell_type": "code", "execution_count": null, - "id": "7b89467a", + "id": "efa7ecf8", "metadata": {}, "outputs": [], "source": [ @@ -334,7 +342,7 @@ }, { "cell_type": "markdown", - "id": "ad84fd89", + "id": "e34da1ef", "metadata": {}, "source": [ "### 🔁 Preview: images as base64\n", @@ -345,7 +353,7 @@ { "cell_type": "code", "execution_count": null, - "id": "24ecd543", + "id": "e27fc9fd", "metadata": {}, "outputs": [], "source": [ @@ -355,7 +363,7 @@ { "cell_type": "code", "execution_count": null, - "id": "7540fc51", + "id": "437b1054", "metadata": {}, "outputs": [], "source": [ @@ -366,7 +374,7 @@ { "cell_type": "code", "execution_count": null, - "id": "8920f6c5", + "id": "5666999a", "metadata": {}, "outputs": [], "source": [ @@ -375,7 +383,7 @@ }, { "cell_type": "markdown", - "id": "5739eee6", + "id": "9e9b5c1b", "metadata": {}, "source": [ "### 🆙 Create: images saved to disk\n", @@ -386,7 +394,7 @@ { "cell_type": "code", "execution_count": null, - "id": "f5326cbb", + "id": "8adecae8", "metadata": {}, "outputs": [], "source": [ @@ -396,7 +404,7 @@ { "cell_type": "code", "execution_count": null, - "id": "506d537f", + "id": "92998c4c", "metadata": {}, "outputs": [], "source": [ @@ -407,7 +415,7 @@ { "cell_type": "code", "execution_count": null, - "id":
"8dbd4874", + "id": "0ad903b0", "metadata": {}, "outputs": [], "source": [ @@ -423,7 +431,7 @@ }, { "cell_type": "markdown", - "id": "fa0307b2", + "id": "12134406", "metadata": {}, "source": [ "## ⏭️ Next steps\n", diff --git a/docs/colab_notebooks/6-editing-images-with-image-context.ipynb b/docs/colab_notebooks/6-editing-images-with-image-context.ipynb index 8a29e17af..023dd198c 100644 --- a/docs/colab_notebooks/6-editing-images-with-image-context.ipynb +++ b/docs/colab_notebooks/6-editing-images-with-image-context.ipynb @@ -2,15 +2,17 @@ "cells": [ { "cell_type": "markdown", - "id": "7348e00d", - "metadata": {}, + "id": "30e20568", + "metadata": { + "nemo_colab_inject": true + }, "source": [ "\"Open" ] }, { "cell_type": "markdown", - "id": "c5e18f66", + "id": "d63f4416", "metadata": {}, "source": [ "# 🎨 Data Designer Tutorial: Image-to-Image Editing\n", @@ -32,7 +34,7 @@ }, { "cell_type": "markdown", - "id": "daa7359c", + "id": "d3e60ea6", "metadata": {}, "source": [ "### 📦 Import Data Designer\n", @@ -43,8 +45,10 @@ }, { "cell_type": "markdown", - "id": "5bb9d062", - "metadata": {}, + "id": "2f1c15e7", + "metadata": { + "nemo_colab_inject": true + }, "source": [ "### ⚡ Colab Setup\n", "\n", @@ -54,8 +58,10 @@ { "cell_type": "code", "execution_count": null, - "id": "b03fb17a", - "metadata": {}, + "id": "143db4c6", + "metadata": { + "nemo_colab_inject": true + }, "outputs": [], "source": [ "%%capture\n", @@ -65,8 +71,10 @@ { "cell_type": "code", "execution_count": null, - "id": "e931d0de", - "metadata": {}, + "id": "d9115072", + "metadata": { + "nemo_colab_inject": true + }, "outputs": [], "source": [ "import getpass\n", @@ -83,7 +91,7 @@ { "cell_type": "code", "execution_count": null, - "id": "02e932f5", + "id": "dfb43d40", "metadata": {}, "outputs": [], "source": [ @@ -99,7 +107,7 @@ }, { "cell_type": "markdown", - "id": "369a04c5", + "id": "5f892cd5", "metadata": {}, "source": [ "### ⚙️ Initialize the Data Designer interface\n", @@ -110,7 +118,7 @@ {
"cell_type": "code", "execution_count": null, - "id": "070aaa15", + "id": "70b474a9", "metadata": {}, "outputs": [], "source": [ @@ -119,7 +127,7 @@ }, { "cell_type": "markdown", - "id": "142952fe", + "id": "f2aef849", "metadata": {}, "source": [ "### 🎛️ Define an image model\n", @@ -135,7 +143,7 @@ { "cell_type": "code", "execution_count": null, - "id": "2d66a7c8", + "id": "aa2f73aa", "metadata": {}, "outputs": [], "source": [ @@ -157,7 +165,7 @@ }, { "cell_type": "markdown", - "id": "c4d0e592", + "id": "f19cf925", "metadata": {}, "source": [ "### 🏗️ Build the configuration\n", @@ -172,7 +180,7 @@ { "cell_type": "code", "execution_count": null, - "id": "51a228bb", + "id": "d76b5043", "metadata": {}, "outputs": [], "source": [ @@ -270,7 +278,7 @@ }, { "cell_type": "markdown", - "id": "dc6d84fa", + "id": "c73e97f0", "metadata": {}, "source": [ "### 🔁 Preview: quick iteration\n", @@ -281,7 +289,7 @@ { "cell_type": "code", "execution_count": null, - "id": "05b58baa", + "id": "87f2ce90", "metadata": {}, "outputs": [], "source": [ @@ -291,7 +299,7 @@ { "cell_type": "code", "execution_count": null, - "id": "97e35ebb", + "id": "5032ba60", "metadata": {}, "outputs": [], "source": [ @@ -302,7 +310,7 @@ { "cell_type": "code", "execution_count": null, - "id": "345514ab", + "id": "b7806720", "metadata": {}, "outputs": [], "source": [ @@ -311,7 +319,7 @@ }, { "cell_type": "markdown", - "id": "15dfb8b7", + "id": "fb02667d", "metadata": { "lines_to_next_cell": 2 }, @@ -324,7 +332,7 @@ { "cell_type": "code", "execution_count": null, - "id": "13728788", + "id": "514fc44d", "metadata": {}, "outputs": [], "source": [ @@ -355,7 +363,7 @@ { "cell_type": "code", "execution_count": null, - "id": "6da35706", + "id": "27719b25", "metadata": {}, "outputs": [], "source": [ @@ -365,7 +373,7 @@ }, { "cell_type": "markdown", - "id": "59abd92b", + "id": "99c431db", "metadata": {}, "source": [ "### 🆙 Create at scale\n", @@ -376,7 +384,7 @@ { "cell_type": "code", "execution_count":
null, - "id": "25be841b", + "id": "e8862095", "metadata": {}, "outputs": [], "source": [ @@ -386,7 +394,7 @@ { "cell_type": "code", "execution_count": null, - "id": "389cc5d2", + "id": "690c8016", "metadata": {}, "outputs": [], "source": [ @@ -397,7 +405,7 @@ { "cell_type": "code", "execution_count": null, - "id": "15002cbf", + "id": "6bd21f76", "metadata": {}, "outputs": [], "source": [ @@ -407,7 +415,7 @@ }, { "cell_type": "markdown", - "id": "ba28d5ee", + "id": "1d00589c", "metadata": {}, "source": [ "## ⏭️ Next steps\n", From f482dfb7e6298ccac96e2dd0da5de1e2b470fbb4 Mon Sep 17 00:00:00 2001 From: Nabin Mulepati Date: Fri, 22 May 2026 11:24:16 -0600 Subject: [PATCH 15/15] align media local path autodetection --- .../4-providing-images-as-context.ipynb | 2 +- .../4-providing-images-as-context.py | 2 +- .../models/default-model-settings.mdx | 2 +- .../pages/concepts/models/model-configs.mdx | 2 +- .../src/data_designer/config/models.py | 30 ++++++------- .../tests/config/test_models.py | 44 ++++++++++++------- 6 files changed, 46 insertions(+), 36 deletions(-) diff --git a/docs/colab_notebooks/4-providing-images-as-context.ipynb b/docs/colab_notebooks/4-providing-images-as-context.ipynb index 7aa71227c..6cd599e0c 100644 --- a/docs/colab_notebooks/4-providing-images-as-context.ipynb +++ b/docs/colab_notebooks/4-providing-images-as-context.ipynb @@ -311,7 +311,7 @@ "]\n", "```\n", "\n", - "URL-backed media can use `data_type=dd.ModalityDataType.URL`, subject to the provider's URL support and file-size limits. Local audio/video paths in URL mode require the model endpoint to have filesystem access to the same paths, typically a colocated vLLM server configured for local media access." + "URL-backed media can use `data_type=dd.ModalityDataType.URL`, subject to the provider's URL support and file-size limits. 
Local audio/video paths require explicit URL mode and require the model endpoint to have filesystem access to the same paths, typically a colocated vLLM server configured for local media access." ] }, { diff --git a/docs/notebook_source/4-providing-images-as-context.py b/docs/notebook_source/4-providing-images-as-context.py index dbdc65cd6..301e90125 100644 --- a/docs/notebook_source/4-providing-images-as-context.py +++ b/docs/notebook_source/4-providing-images-as-context.py @@ -184,7 +184,7 @@ def convert_image_to_chat_format(record, height: int) -> dict: # ] # ``` # -# URL-backed media can use `data_type=dd.ModalityDataType.URL`, subject to the provider's URL support and file-size limits. Local audio/video paths in URL mode require the model endpoint to have filesystem access to the same paths, typically a colocated vLLM server configured for local media access. +# URL-backed media can use `data_type=dd.ModalityDataType.URL`, subject to the provider's URL support and file-size limits. Local audio/video paths require explicit URL mode and require the model endpoint to have filesystem access to the same paths, typically a colocated vLLM server configured for local media access. # %% # Add a column to generate detailed image descriptions diff --git a/fern/versions/latest/pages/concepts/models/default-model-settings.mdx b/fern/versions/latest/pages/concepts/models/default-model-settings.mdx index 9afcdc989..b75d01c29 100644 --- a/fern/versions/latest/pages/concepts/models/default-model-settings.mdx +++ b/fern/versions/latest/pages/concepts/models/default-model-settings.mdx @@ -75,7 +75,7 @@ The following model configurations are automatically available when `OPENROUTER_ | `openrouter-embedding` | `openai/text-embedding-3-large` | Text embeddings | `encoding_format="float"` | - The `multi_modal_context` field can include image, audio, and video contexts, but each model/provider combination has its own accepted input formats, media-size limits, and modality mix. 
Use an image-capable model for image-only workflows, and use an omni or otherwise multimodal model before sending audio or video context. Local audio/video paths in URL mode require the model endpoint to have filesystem access to the same paths, typically a colocated vLLM server configured for local media access. + The `multi_modal_context` field can include image, audio, and video contexts, but each model/provider combination has its own accepted input formats, media-size limits, and modality mix. Use an image-capable model for image-only workflows, and use an omni or otherwise multimodal model before sending audio or video context. Local audio/video paths require explicit URL mode (`data_type=url`) and require the model endpoint to have filesystem access to the same paths, typically a colocated vLLM server configured for local media access. diff --git a/fern/versions/latest/pages/concepts/models/model-configs.mdx b/fern/versions/latest/pages/concepts/models/model-configs.mdx index 15b6c4c70..3314b11e2 100644 --- a/fern/versions/latest/pages/concepts/models/model-configs.mdx +++ b/fern/versions/latest/pages/concepts/models/model-configs.mdx @@ -9,7 +9,7 @@ Model configurations define the specific models you use for synthetic data gener A `ModelConfig` specifies which LLM model to use and how it should behave during generation. When you create column configurations (like `LLMText`, `LLMCode`, or `LLMStructured`), you reference a model by its alias. Data Designer uses the model configuration to determine which model to call and with what parameters. -When a column includes `multi_modal_context`, the `ModelConfig` alias must point to a model that supports the media types you send. Data Designer can serialize image, audio, and video context blocks, but model capability is still provider-specific. 
Local audio/video paths in URL mode require the model endpoint to have filesystem access to the same paths, typically a colocated vLLM server configured for local media access. +When a column includes `multi_modal_context`, the `ModelConfig` alias must point to a model that supports the media types you send. Data Designer can serialize image, audio, and video context blocks, but model capability is still provider-specific. Local audio/video paths require explicit URL mode (`data_type=url`) and require the model endpoint to have filesystem access to the same paths, typically a colocated vLLM server configured for local media access. ## ModelConfig Structure diff --git a/packages/data-designer-config/src/data_designer/config/models.py b/packages/data-designer-config/src/data_designer/config/models.py index 6a4748e79..e92014b47 100644 --- a/packages/data-designer-config/src/data_designer/config/models.py +++ b/packages/data-designer-config/src/data_designer/config/models.py @@ -180,10 +180,9 @@ def _image_formats_match(configured_format: ImageFormat, detected_format: ImageF class AudioContext(ModalityContext): """Configuration for providing audio context to multimodal models. - Audio context values are URL, local path, or base64 media values. Local - paths are passed through so colocated vLLM servers can read them directly. - ``audio_format`` is consulted only for base64 sources; URL and local-path - sources are passed through unchanged. + Audio context values are URL or base64 media values. Local paths may be + passed through only in explicit URL mode so colocated model endpoints can + read them directly. ``audio_format`` is consulted only for base64 sources. """ modality: Literal[Modality.AUDIO] = Modality.AUDIO @@ -193,7 +192,7 @@ def get_contexts(self, record: dict, *, base_path: str | None = None) -> list[di """Get audio contexts. 
``base_path`` is accepted for signature compatibility with ``ImageContext`` - but unused; local audio paths are passed through unchanged. + but unused; audio contexts do not resolve local files to base64. """ return [self._build_context(value) for value in normalize_media_context_values(record[self.column_name])] @@ -202,7 +201,7 @@ def _build_context(self, context_value: Any) -> dict[str, Any]: self._validate_url_context_value(context_value) return get_media_url_context(Modality.AUDIO.value, context_value) - if self.data_type is None and (is_audio_path(context_value) or is_media_url(context_value)): + if self.data_type is None and is_media_url(context_value): return get_media_url_context(Modality.AUDIO.value, context_value) media_type, data = self._resolve_base64_parts(context_value) @@ -223,8 +222,8 @@ def _resolve_base64_parts(self, context_value: Any) -> tuple[str, Any]: if is_audio_path(context_value): raise ValueError( - "audio base64 context values must be base64 audio data; use data_type=url " - "or omit data_type to pass local audio paths through" + "audio context values that look like local paths must use data_type=url; " + "otherwise provide base64 audio data" ) if self.audio_format is None: @@ -245,10 +244,9 @@ def _validate_audio_format(self) -> Self: class VideoContext(ModalityContext): """Configuration for providing video context to multimodal models. - Video context values are URL, local path, or base64 media values. Local - paths are passed through so colocated vLLM servers can read them directly. - ``video_format`` is consulted only for base64 sources; URL and local-path - sources are passed through unchanged. + Video context values are URL or base64 media values. Local paths may be + passed through only in explicit URL mode so colocated model endpoints can + read them directly. ``video_format`` is consulted only for base64 sources. 
""" modality: Literal[Modality.VIDEO] = Modality.VIDEO @@ -258,7 +256,7 @@ def get_contexts(self, record: dict, *, base_path: str | None = None) -> list[di """Get video contexts. ``base_path`` is accepted for signature compatibility with ``ImageContext`` - but unused; local video paths are passed through unchanged. + but unused; video contexts do not resolve local files to base64. """ return [self._build_context(value) for value in normalize_media_context_values(record[self.column_name])] @@ -267,7 +265,7 @@ def _build_context(self, context_value: Any) -> dict[str, Any]: self._validate_url_context_value(context_value) return get_media_url_context(Modality.VIDEO.value, context_value) - if self.data_type is None and (is_video_path(context_value) or is_media_url(context_value)): + if self.data_type is None and is_media_url(context_value): return get_media_url_context(Modality.VIDEO.value, context_value) media_type, data = self._resolve_base64_parts(context_value) @@ -288,8 +286,8 @@ def _resolve_base64_parts(self, context_value: Any) -> tuple[str, Any]: if is_video_path(context_value): raise ValueError( - "video base64 context values must be base64 video data; use data_type=url " - "or omit data_type to pass local video paths through" + "video context values that look like local paths must use data_type=url; " + "otherwise provide base64 video data" ) if self.video_format is None: diff --git a/packages/data-designer-config/tests/config/test_models.py b/packages/data-designer-config/tests/config/test_models.py index ebe720cf4..ca50e94d4 100644 --- a/packages/data-designer-config/tests/config/test_models.py +++ b/packages/data-designer-config/tests/config/test_models.py @@ -246,6 +246,9 @@ def test_audio_context_get_contexts_single_string() -> None: assert audio_context.get_contexts({"audio_url": "recordings/speech.mp3"}) == [ get_media_url_context(Modality.AUDIO.value, "recordings/speech.mp3") ] + assert audio_context.get_contexts({"audio_url": 
"file:///data/recordings/speech.mp3"}) == [ + get_media_url_context(Modality.AUDIO.value, "file:///data/recordings/speech.mp3") + ] def test_audio_context_get_contexts_list_json_and_numpy() -> None: @@ -276,10 +279,6 @@ def test_audio_context_auto_detect_url_and_data_uri() -> None: get_media_url_context(Modality.AUDIO.value, "https://example.com/audio.mp3") ] - assert AudioContext(column_name="audio_col").get_contexts({"audio_col": "recordings/speech.wav"}) == [ - get_media_url_context(Modality.AUDIO.value, "recordings/speech.wav") - ] - assert AudioContext(column_name="audio_col").get_contexts({"audio_col": "https://example.com/download?id=123"}) == [ get_media_url_context(Modality.AUDIO.value, "https://example.com/download?id=123") ] @@ -289,6 +288,12 @@ def test_audio_context_auto_detect_url_and_data_uri() -> None: ] +@pytest.mark.parametrize("audio_path", ["recordings/speech.wav", "file:///data/recordings/speech.mp3"]) +def test_audio_context_auto_detect_local_path_rejected(audio_path: str) -> None: + with pytest.raises(ValueError, match="audio context values that look like local paths must use data_type=url"): + AudioContext(column_name="audio_col").get_contexts({"audio_col": audio_path}) + + def test_audio_context_validate_audio_format() -> None: with pytest.raises(ValueError, match="audio_format is required when data_type is base64"): AudioContext(column_name="audio_base64", data_type=ModalityDataType.BASE64) @@ -304,11 +309,12 @@ def test_audio_context_validate_audio_format() -> None: {"audio_base64": "data:audio/mpeg;base64,audio1base64"} ) - assert AudioContext(column_name="audio_base64", audio_format=AudioFormat.MP3).get_contexts( - {"audio_base64": "screen_recording.mp3"} - ) == [get_media_url_context(Modality.AUDIO.value, "screen_recording.mp3")] + with pytest.raises(ValueError, match="audio context values that look like local paths must use data_type=url"): + AudioContext(column_name="audio_base64", audio_format=AudioFormat.MP3).get_contexts( + 
{"audio_base64": "screen_recording.mp3"} + ) - with pytest.raises(ValueError, match="audio base64 context values must be base64 audio data"): + with pytest.raises(ValueError, match="audio context values that look like local paths must use data_type=url"): AudioContext( column_name="audio_base64", data_type=ModalityDataType.BASE64, audio_format=AudioFormat.MP3 ).get_contexts({"audio_base64": "screen_recording.mp3"}) @@ -329,6 +335,9 @@ def test_video_context_get_contexts_single_string() -> None: assert video_context.get_contexts({"video_url": "clips/screen_recording.mp4"}) == [ get_media_url_context(Modality.VIDEO.value, "clips/screen_recording.mp4") ] + assert video_context.get_contexts({"video_url": "file:///data/clips/screen_recording.mp4"}) == [ + get_media_url_context(Modality.VIDEO.value, "file:///data/clips/screen_recording.mp4") + ] def test_video_context_get_contexts_list_json_and_numpy() -> None: @@ -359,10 +368,6 @@ def test_video_context_auto_detect_url_and_data_uri() -> None: get_media_url_context(Modality.VIDEO.value, "https://example.com/video.mp4") ] - assert VideoContext(column_name="video_col").get_contexts({"video_col": "clips/screen_recording.webm"}) == [ - get_media_url_context(Modality.VIDEO.value, "clips/screen_recording.webm") - ] - assert VideoContext(column_name="video_col").get_contexts({"video_col": "https://example.com/download?id=123"}) == [ get_media_url_context(Modality.VIDEO.value, "https://example.com/download?id=123") ] @@ -372,6 +377,12 @@ def test_video_context_auto_detect_url_and_data_uri() -> None: ] +@pytest.mark.parametrize("video_path", ["clips/screen_recording.webm", "file:///data/clips/screen_recording.mp4"]) +def test_video_context_auto_detect_local_path_rejected(video_path: str) -> None: + with pytest.raises(ValueError, match="video context values that look like local paths must use data_type=url"): + VideoContext(column_name="video_col").get_contexts({"video_col": video_path}) + + def 
test_video_context_validate_video_format() -> None: with pytest.raises(ValueError, match="video_format is required when data_type is base64"): VideoContext(column_name="video_base64", data_type=ModalityDataType.BASE64) @@ -387,11 +398,12 @@ def test_video_context_validate_video_format() -> None: {"video_base64": "data:video/mp4;base64,video1base64"} ) - assert VideoContext(column_name="video_base64", video_format=VideoFormat.MP4).get_contexts( - {"video_base64": "screen_recording.mp4"} - ) == [get_media_url_context(Modality.VIDEO.value, "screen_recording.mp4")] + with pytest.raises(ValueError, match="video context values that look like local paths must use data_type=url"): + VideoContext(column_name="video_base64", video_format=VideoFormat.MP4).get_contexts( + {"video_base64": "screen_recording.mp4"} + ) - with pytest.raises(ValueError, match="video base64 context values must be base64 video data"): + with pytest.raises(ValueError, match="video context values that look like local paths must use data_type=url"): VideoContext( column_name="video_base64", data_type=ModalityDataType.BASE64, video_format=VideoFormat.MP4 ).get_contexts({"video_base64": "screen_recording.mp4"})