diff --git a/docs/colab_notebooks/1-the-basics.ipynb b/docs/colab_notebooks/1-the-basics.ipynb index a55028a44..bef78d828 100644 --- a/docs/colab_notebooks/1-the-basics.ipynb +++ b/docs/colab_notebooks/1-the-basics.ipynb @@ -2,15 +2,17 @@ "cells": [ { "cell_type": "markdown", - "id": "e9bc2aab", - "metadata": {}, + "id": "f5bc03e0", + "metadata": { + "nemo_colab_inject": true + }, "source": [ "\"Open" ] }, { "cell_type": "markdown", - "id": "33dcb5be", + "id": "3454d676", "metadata": {}, "source": [ "# 🎨 Data Designer Tutorial: The Basics\n", @@ -22,7 +24,7 @@ }, { "cell_type": "markdown", - "id": "adb77b8d", + "id": "4737bc0d", "metadata": {}, "source": [ "### πŸ“¦ Import Data Designer\n", @@ -34,8 +36,10 @@ }, { "cell_type": "markdown", - "id": "170ce1ea", - "metadata": {}, + "id": "cf21b784", + "metadata": { + "nemo_colab_inject": true + }, "source": [ "### ⚑ Colab Setup\n", "\n", @@ -45,8 +49,10 @@ { "cell_type": "code", "execution_count": null, - "id": "67e478f9", - "metadata": {}, + "id": "a87ee38a", + "metadata": { + "nemo_colab_inject": true + }, "outputs": [], "source": [ "%%capture\n", @@ -56,8 +62,10 @@ { "cell_type": "code", "execution_count": null, - "id": "533fc40d", - "metadata": {}, + "id": "d6c4cc8a", + "metadata": { + "nemo_colab_inject": true + }, "outputs": [], "source": [ "import getpass\n", @@ -74,7 +82,7 @@ { "cell_type": "code", "execution_count": null, - "id": "9ad92889", + "id": "848dea00", "metadata": {}, "outputs": [], "source": [ @@ -84,7 +92,7 @@ }, { "cell_type": "markdown", - "id": "0232c4c6", + "id": "b97786a8", "metadata": {}, "source": [ "### βš™οΈ Initialize the Data Designer interface\n", @@ -97,7 +105,7 @@ { "cell_type": "code", "execution_count": null, - "id": "fbbd0cab", + "id": "b31c1fc9", "metadata": {}, "outputs": [], "source": [ @@ -106,7 +114,7 @@ }, { "cell_type": "markdown", - "id": "305f635e", + "id": "2fef5ae8", "metadata": {}, "source": [ "### πŸŽ›οΈ Define model configurations\n", @@ -123,7 +131,7 @@ { 
"cell_type": "code", "execution_count": null, - "id": "d0865b58", + "id": "7a9f6398", "metadata": {}, "outputs": [], "source": [ @@ -153,7 +161,7 @@ }, { "cell_type": "markdown", - "id": "6e1624f7", + "id": "1d0a178f", "metadata": {}, "source": [ "### πŸ—οΈ Initialize the Data Designer Config Builder\n", @@ -168,7 +176,7 @@ { "cell_type": "code", "execution_count": null, - "id": "33562cda", + "id": "aacc0ec5", "metadata": {}, "outputs": [], "source": [ @@ -177,7 +185,7 @@ }, { "cell_type": "markdown", - "id": "d8ec3063", + "id": "4be3497f", "metadata": {}, "source": [ "## 🎲 Getting started with sampler columns\n", @@ -194,7 +202,7 @@ { "cell_type": "code", "execution_count": null, - "id": "70de1b0a", + "id": "e212d83e", "metadata": {}, "outputs": [], "source": [ @@ -203,7 +211,7 @@ }, { "cell_type": "markdown", - "id": "991a8f34", + "id": "c28350d3", "metadata": {}, "source": [ "Let's start designing our product review dataset by adding product category and subcategory columns.\n" @@ -212,7 +220,7 @@ { "cell_type": "code", "execution_count": null, - "id": "222cbbcc", + "id": "070f14e7", "metadata": {}, "outputs": [], "source": [ @@ -293,7 +301,7 @@ }, { "cell_type": "markdown", - "id": "29ca2aa3", + "id": "e0d8497d", "metadata": {}, "source": [ "Next, let's add samplers to generate data related to the customer and their review.\n" @@ -302,7 +310,7 @@ { "cell_type": "code", "execution_count": null, - "id": "4ca9ba1c", + "id": "62e84282", "metadata": {}, "outputs": [], "source": [ @@ -339,7 +347,7 @@ }, { "cell_type": "markdown", - "id": "f4d54299", + "id": "8cb147fa", "metadata": {}, "source": [ "## 🦜 LLM-generated columns\n", @@ -354,7 +362,7 @@ { "cell_type": "code", "execution_count": null, - "id": "361b63b1", + "id": "37a4a6d0", "metadata": {}, "outputs": [], "source": [ @@ -390,7 +398,7 @@ }, { "cell_type": "markdown", - "id": "49ca028a", + "id": "49559576", "metadata": {}, "source": [ "### πŸ” Iteration is key – preview the dataset!\n", @@ -407,7 +415,7 @@ 
{ "cell_type": "code", "execution_count": null, - "id": "068ea8c3", + "id": "0d52b447", "metadata": {}, "outputs": [], "source": [ @@ -417,7 +425,7 @@ { "cell_type": "code", "execution_count": null, - "id": "bf196a77", + "id": "088a5004", "metadata": {}, "outputs": [], "source": [ @@ -428,7 +436,7 @@ { "cell_type": "code", "execution_count": null, - "id": "36ebb017", + "id": "9780021a", "metadata": {}, "outputs": [], "source": [ @@ -438,7 +446,7 @@ }, { "cell_type": "markdown", - "id": "1dcba545", + "id": "c9122bc6", "metadata": {}, "source": [ "### πŸ“Š Analyze the generated data\n", @@ -451,7 +459,7 @@ { "cell_type": "code", "execution_count": null, - "id": "e5164902", + "id": "4d6bb3c5", "metadata": {}, "outputs": [], "source": [ @@ -461,7 +469,7 @@ }, { "cell_type": "markdown", - "id": "cc433fae", + "id": "6003ae71", "metadata": {}, "source": [ "### πŸ†™ Scale up!\n", @@ -474,7 +482,7 @@ { "cell_type": "code", "execution_count": null, - "id": "17132fe2", + "id": "e343639d", "metadata": {}, "outputs": [], "source": [ @@ -484,7 +492,7 @@ { "cell_type": "code", "execution_count": null, - "id": "6669442a", + "id": "cd328abd", "metadata": {}, "outputs": [], "source": [ @@ -497,7 +505,7 @@ { "cell_type": "code", "execution_count": null, - "id": "ee689b41", + "id": "6a09793a", "metadata": {}, "outputs": [], "source": [ @@ -509,7 +517,7 @@ }, { "cell_type": "markdown", - "id": "6965e6ac", + "id": "769dd181", "metadata": {}, "source": [ "## ⏭️ Next Steps\n", diff --git a/docs/colab_notebooks/2-structured-outputs-and-jinja-expressions.ipynb b/docs/colab_notebooks/2-structured-outputs-and-jinja-expressions.ipynb index 77272fbb1..276bc86d7 100644 --- a/docs/colab_notebooks/2-structured-outputs-and-jinja-expressions.ipynb +++ b/docs/colab_notebooks/2-structured-outputs-and-jinja-expressions.ipynb @@ -2,15 +2,17 @@ "cells": [ { "cell_type": "markdown", - "id": "f4f854dd", - "metadata": {}, + "id": "f81a1643", + "metadata": { + "nemo_colab_inject": true + }, "source": [ 
"\"Open" ] }, { "cell_type": "markdown", - "id": "027ffdf3", + "id": "0c33bf13", "metadata": {}, "source": [ "# 🎨 Data Designer Tutorial: Structured Outputs, Jinja Expressions, and Conditional Generation\n", @@ -24,7 +26,7 @@ }, { "cell_type": "markdown", - "id": "158f95c6", + "id": "37d85d1d", "metadata": {}, "source": [ "### πŸ“¦ Import Data Designer\n", @@ -36,8 +38,10 @@ }, { "cell_type": "markdown", - "id": "459b2f2b", - "metadata": {}, + "id": "a3b60315", + "metadata": { + "nemo_colab_inject": true + }, "source": [ "### ⚑ Colab Setup\n", "\n", @@ -47,8 +51,10 @@ { "cell_type": "code", "execution_count": null, - "id": "2bdb065c", - "metadata": {}, + "id": "71ea617c", + "metadata": { + "nemo_colab_inject": true + }, "outputs": [], "source": [ "%%capture\n", @@ -58,8 +64,10 @@ { "cell_type": "code", "execution_count": null, - "id": "8ccc1e8f", - "metadata": {}, + "id": "ab7f4096", + "metadata": { + "nemo_colab_inject": true + }, "outputs": [], "source": [ "import getpass\n", @@ -76,7 +84,7 @@ { "cell_type": "code", "execution_count": null, - "id": "aeb8441e", + "id": "03e30510", "metadata": {}, "outputs": [], "source": [ @@ -86,7 +94,7 @@ }, { "cell_type": "markdown", - "id": "df989756", + "id": "31946b6a", "metadata": {}, "source": [ "### βš™οΈ Initialize the Data Designer interface\n", @@ -99,7 +107,7 @@ { "cell_type": "code", "execution_count": null, - "id": "2a8f113b", + "id": "edb53392", "metadata": {}, "outputs": [], "source": [ @@ -108,7 +116,7 @@ }, { "cell_type": "markdown", - "id": "b986772a", + "id": "7979bbb9", "metadata": {}, "source": [ "### πŸŽ›οΈ Define model configurations\n", @@ -125,7 +133,7 @@ { "cell_type": "code", "execution_count": null, - "id": "2ce9cf8c", + "id": "96d72c55", "metadata": {}, "outputs": [], "source": [ @@ -155,7 +163,7 @@ }, { "cell_type": "markdown", - "id": "6b5ab2ea", + "id": "ddd9d06f", "metadata": {}, "source": [ "### πŸ—οΈ Initialize the Data Designer Config Builder\n", @@ -170,7 +178,7 @@ { "cell_type": "code", 
"execution_count": null, - "id": "69a41c06", + "id": "96581bae", "metadata": {}, "outputs": [], "source": [ @@ -179,7 +187,7 @@ }, { "cell_type": "markdown", - "id": "b17aca77", + "id": "f8865cae", "metadata": {}, "source": [ "### πŸ§‘β€πŸŽ¨ Designing our data\n", @@ -206,7 +214,7 @@ { "cell_type": "code", "execution_count": null, - "id": "133df1c0", + "id": "85166774", "metadata": {}, "outputs": [], "source": [ @@ -234,7 +242,7 @@ }, { "cell_type": "markdown", - "id": "2535b9c0", + "id": "5a606cbe", "metadata": {}, "source": [ "Next, let's design our product review dataset using a few more tricks compared to the previous notebook.\n" @@ -243,7 +251,7 @@ { "cell_type": "code", "execution_count": null, - "id": "7d4d991d", + "id": "ca79722a", "metadata": {}, "outputs": [], "source": [ @@ -352,7 +360,7 @@ }, { "cell_type": "markdown", - "id": "afc66880", + "id": "3df06121", "metadata": {}, "source": [ "Next, we will use more advanced Jinja expressions to create new columns.\n", @@ -369,7 +377,7 @@ { "cell_type": "code", "execution_count": null, - "id": "5d8452b2", + "id": "932d9c49", "metadata": {}, "outputs": [], "source": [ @@ -422,7 +430,7 @@ }, { "cell_type": "markdown", - "id": "d7780299", + "id": "0ee33040", "metadata": {}, "source": [ "## 🚦 Conditional generation with `skip.when`\n", @@ -445,7 +453,7 @@ }, { "cell_type": "markdown", - "id": "794ac1aa", + "id": "f4749d4b", "metadata": {}, "source": [ "**Pattern 1 β€” Expression gate.** Only generate a detailed complaint analysis when the customer gave a low rating (1 or 2 stars).\n", @@ -455,7 +463,7 @@ { "cell_type": "code", "execution_count": null, - "id": "6d96baaa", + "id": "4b18aefc", "metadata": {}, "outputs": [], "source": [ @@ -478,7 +486,7 @@ }, { "cell_type": "markdown", - "id": "a3598079", + "id": "9f3bedb2", "metadata": {}, "source": [ "**Pattern 2 β€” Skip propagation.** `action_items` depends on `complaint_analysis`.\n", @@ -489,7 +497,7 @@ { "cell_type": "code", "execution_count": null, - "id": 
"59be7563", + "id": "a7407102", "metadata": {}, "outputs": [], "source": [ @@ -508,7 +516,7 @@ }, { "cell_type": "markdown", - "id": "44cfc2e8", + "id": "c3222b17", "metadata": {}, "source": [ "**Pattern 3 β€” Propagation opt-out.** `review_summary` also depends on `complaint_analysis`,\n", @@ -519,7 +527,7 @@ { "cell_type": "code", "execution_count": null, - "id": "a9cee7fe", + "id": "2fda072a", "metadata": {}, "outputs": [], "source": [ @@ -545,7 +553,7 @@ }, { "cell_type": "markdown", - "id": "67f39d99", + "id": "dfaf3d79", "metadata": {}, "source": [ "### πŸ” Iteration is key – preview the dataset!\n", @@ -562,7 +570,7 @@ { "cell_type": "code", "execution_count": null, - "id": "3aa1cd01", + "id": "cd207969", "metadata": {}, "outputs": [], "source": [ @@ -572,7 +580,7 @@ { "cell_type": "code", "execution_count": null, - "id": "5d78f540", + "id": "3466c5de", "metadata": {}, "outputs": [], "source": [ @@ -585,7 +593,7 @@ { "cell_type": "code", "execution_count": null, - "id": "86011901", + "id": "99ec7423", "metadata": {}, "outputs": [], "source": [ @@ -597,7 +605,7 @@ }, { "cell_type": "markdown", - "id": "8fa363ed", + "id": "c3b0c432", "metadata": {}, "source": [ "### πŸ“Š Analyze the generated data\n", @@ -610,7 +618,7 @@ { "cell_type": "code", "execution_count": null, - "id": "3dede878", + "id": "c9d31b2e", "metadata": {}, "outputs": [], "source": [ @@ -620,7 +628,7 @@ }, { "cell_type": "markdown", - "id": "38839a98", + "id": "bfe09a95", "metadata": {}, "source": [ "### πŸ†™ Scale up!\n", @@ -633,7 +641,7 @@ { "cell_type": "code", "execution_count": null, - "id": "8208f51b", + "id": "54e6a578", "metadata": {}, "outputs": [], "source": [ @@ -643,7 +651,7 @@ { "cell_type": "code", "execution_count": null, - "id": "2b07217f", + "id": "210d3b83", "metadata": {}, "outputs": [], "source": [ @@ -656,7 +664,7 @@ { "cell_type": "code", "execution_count": null, - "id": "7deaa6e2", + "id": "ba14deb4", "metadata": {}, "outputs": [], "source": [ @@ -668,7 +676,7 @@ }, { 
"cell_type": "markdown", - "id": "b4c1a576", + "id": "6a8319de", "metadata": {}, "source": [ "## ⏭️ Next Steps\n", diff --git a/docs/colab_notebooks/3-seeding-with-a-dataset.ipynb b/docs/colab_notebooks/3-seeding-with-a-dataset.ipynb index 7aab5eaa8..54530ee77 100644 --- a/docs/colab_notebooks/3-seeding-with-a-dataset.ipynb +++ b/docs/colab_notebooks/3-seeding-with-a-dataset.ipynb @@ -2,15 +2,17 @@ "cells": [ { "cell_type": "markdown", - "id": "21e9e0eb", - "metadata": {}, + "id": "0a01390d", + "metadata": { + "nemo_colab_inject": true + }, "source": [ "\"Open" ] }, { "cell_type": "markdown", - "id": "b185696e", + "id": "1c353f07", "metadata": {}, "source": [ "# 🎨 Data Designer Tutorial: Seeding Synthetic Data Generation with an External Dataset\n", @@ -24,7 +26,7 @@ }, { "cell_type": "markdown", - "id": "692c9796", + "id": "ffeca512", "metadata": {}, "source": [ "### πŸ“¦ Import Data Designer\n", @@ -36,8 +38,10 @@ }, { "cell_type": "markdown", - "id": "daa8cd50", - "metadata": {}, + "id": "bd06dd7b", + "metadata": { + "nemo_colab_inject": true + }, "source": [ "### ⚑ Colab Setup\n", "\n", @@ -47,8 +51,10 @@ { "cell_type": "code", "execution_count": null, - "id": "8848bd1e", - "metadata": {}, + "id": "09d07f44", + "metadata": { + "nemo_colab_inject": true + }, "outputs": [], "source": [ "%%capture\n", @@ -58,8 +64,10 @@ { "cell_type": "code", "execution_count": null, - "id": "317ce78f", - "metadata": {}, + "id": "8d2baac1", + "metadata": { + "nemo_colab_inject": true + }, "outputs": [], "source": [ "import getpass\n", @@ -76,7 +84,7 @@ { "cell_type": "code", "execution_count": null, - "id": "1cb2d5c8", + "id": "48b16d15", "metadata": {}, "outputs": [], "source": [ @@ -86,7 +94,7 @@ }, { "cell_type": "markdown", - "id": "8b49428f", + "id": "7930135e", "metadata": {}, "source": [ "### βš™οΈ Initialize the Data Designer interface\n", @@ -99,7 +107,7 @@ { "cell_type": "code", "execution_count": null, - "id": "69df6d66", + "id": "b033aa9d", "metadata": {}, "outputs": 
[], "source": [ @@ -108,7 +116,7 @@ }, { "cell_type": "markdown", - "id": "50378de0", + "id": "50c00422", "metadata": {}, "source": [ "### πŸŽ›οΈ Define model configurations\n", @@ -125,7 +133,7 @@ { "cell_type": "code", "execution_count": null, - "id": "e932a29e", + "id": "b503a010", "metadata": {}, "outputs": [], "source": [ @@ -155,7 +163,7 @@ }, { "cell_type": "markdown", - "id": "9487eecc", + "id": "efca2a84", "metadata": {}, "source": [ "### πŸ—οΈ Initialize the Data Designer Config Builder\n", @@ -170,7 +178,7 @@ { "cell_type": "code", "execution_count": null, - "id": "172f0df0", + "id": "45afdfd9", "metadata": {}, "outputs": [], "source": [ @@ -179,7 +187,7 @@ }, { "cell_type": "markdown", - "id": "54700574", + "id": "fdcfa350", "metadata": {}, "source": [ "## πŸ₯ Prepare a seed dataset\n", @@ -204,7 +212,7 @@ { "cell_type": "code", "execution_count": null, - "id": "7c1e1f69", + "id": "1cb526af", "metadata": {}, "outputs": [], "source": [ @@ -222,7 +230,7 @@ }, { "cell_type": "markdown", - "id": "bdd24ad6", + "id": "0fcacdcc", "metadata": {}, "source": [ "## 🎨 Designing our synthetic patient notes dataset\n", @@ -235,7 +243,7 @@ { "cell_type": "code", "execution_count": null, - "id": "2b33b6f6", + "id": "26eb22c4", "metadata": {}, "outputs": [], "source": [ @@ -316,7 +324,7 @@ }, { "cell_type": "markdown", - "id": "2d23d1c3", + "id": "667c9ec4", "metadata": {}, "source": [ "### πŸ” Iteration is key – preview the dataset!\n", @@ -333,7 +341,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d2e864ef", + "id": "1f00421c", "metadata": {}, "outputs": [], "source": [ @@ -343,7 +351,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d948d638", + "id": "c05b8619", "metadata": {}, "outputs": [], "source": [ @@ -354,7 +362,7 @@ { "cell_type": "code", "execution_count": null, - "id": "a5bb03c7", + "id": "f04b086a", "metadata": {}, "outputs": [], "source": [ @@ -364,7 +372,7 @@ }, { "cell_type": "markdown", - "id": "a6d81e80", + "id": 
"8426dafb", "metadata": {}, "source": [ "### πŸ“Š Analyze the generated data\n", @@ -377,7 +385,7 @@ { "cell_type": "code", "execution_count": null, - "id": "536d8500", + "id": "1f532c12", "metadata": {}, "outputs": [], "source": [ @@ -387,7 +395,7 @@ }, { "cell_type": "markdown", - "id": "e93e1239", + "id": "033d314c", "metadata": {}, "source": [ "### πŸ†™ Scale up!\n", @@ -400,7 +408,7 @@ { "cell_type": "code", "execution_count": null, - "id": "60a30857", + "id": "f4c27dd8", "metadata": {}, "outputs": [], "source": [ @@ -410,7 +418,7 @@ { "cell_type": "code", "execution_count": null, - "id": "b746c558", + "id": "6188ed81", "metadata": {}, "outputs": [], "source": [ @@ -423,7 +431,7 @@ { "cell_type": "code", "execution_count": null, - "id": "e8aa5c7e", + "id": "8e27cc08", "metadata": {}, "outputs": [], "source": [ @@ -435,7 +443,7 @@ }, { "cell_type": "markdown", - "id": "023fff7b", + "id": "44d280d6", "metadata": {}, "source": [ "## ⏭️ Next Steps\n", diff --git a/docs/colab_notebooks/4-providing-images-as-context.ipynb b/docs/colab_notebooks/4-providing-images-as-context.ipynb index ba225ac0a..6cd599e0c 100644 --- a/docs/colab_notebooks/4-providing-images-as-context.ipynb +++ b/docs/colab_notebooks/4-providing-images-as-context.ipynb @@ -2,15 +2,17 @@ "cells": [ { "cell_type": "markdown", - "id": "f7d47856", - "metadata": {}, + "id": "cd505b79", + "metadata": { + "nemo_colab_inject": true + }, "source": [ "\"Open" ] }, { "cell_type": "markdown", - "id": "e826ba2c", + "id": "ed119996", "metadata": {}, "source": [ "# 🎨 Data Designer Tutorial: Providing Images as Context for Vision-Based Data Generation" @@ -18,22 +20,24 @@ }, { "cell_type": "markdown", - "id": "4e0854f1", + "id": "d13a4cb5", "metadata": {}, "source": [ "#### πŸ“š What you'll learn\n", "\n", "This notebook demonstrates how to provide images as context to generate text descriptions using vision-language models.\n", + "The same `multi_modal_context` field can also carry audio or video context when the 
selected model supports those modalities.\n", "\n", "- ✨ **Visual Document Processing**: Converting images to chat-ready format for model consumption\n", "- πŸ” **Vision-Language Generation**: Using vision models to generate detailed summaries from images\n", + "- 🧩 **Media Context Pattern**: Understanding how `ImageContext`, `AudioContext`, and `VideoContext` fit into the same configuration field\n", "\n", "If this is your first time using Data Designer, we recommend starting with the [first notebook](https://nvidia-nemo.github.io/DataDesigner/latest/notebooks/1-the-basics/) in this tutorial series.\n" ] }, { "cell_type": "markdown", - "id": "adc08017", + "id": "2924c2d1", "metadata": {}, "source": [ "### πŸ“¦ Import Data Designer\n", @@ -45,8 +49,10 @@ }, { "cell_type": "markdown", - "id": "c68a6c2c", - "metadata": {}, + "id": "4c6e4f22", + "metadata": { + "nemo_colab_inject": true + }, "source": [ "### ⚑ Colab Setup\n", "\n", @@ -56,8 +62,10 @@ { "cell_type": "code", "execution_count": null, - "id": "67bf78ce", - "metadata": {}, + "id": "98151070", + "metadata": { + "nemo_colab_inject": true + }, "outputs": [], "source": [ "%%capture\n", @@ -67,8 +75,10 @@ { "cell_type": "code", "execution_count": null, - "id": "21bbf67b", - "metadata": {}, + "id": "5490b9a8", + "metadata": { + "nemo_colab_inject": true + }, "outputs": [], "source": [ "import getpass\n", @@ -85,7 +95,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d7056b4d", + "id": "7a66e1ce", "metadata": {}, "outputs": [], "source": [ @@ -108,7 +118,7 @@ }, { "cell_type": "markdown", - "id": "48235c24", + "id": "3e7a28c6", "metadata": {}, "source": [ "### βš™οΈ Initialize the Data Designer interface\n", @@ -121,7 +131,7 @@ { "cell_type": "code", "execution_count": null, - "id": "768218ca", + "id": "f31d6ac0", "metadata": {}, "outputs": [], "source": [ @@ -130,7 +140,7 @@ }, { "cell_type": "markdown", - "id": "ff4a52ed", + "id": "14b063e4", "metadata": {}, "source": [ "### πŸ—οΈ Initialize the 
Data Designer Config Builder\n", @@ -145,7 +155,7 @@ { "cell_type": "code", "execution_count": null, - "id": "42640912", + "id": "d8fd37ae", "metadata": {}, "outputs": [], "source": [ @@ -154,7 +164,7 @@ }, { "cell_type": "markdown", - "id": "4ecad6af", + "id": "3a7e0787", "metadata": {}, "source": [ "### 🌱 Seed Dataset Creation\n", @@ -171,7 +181,7 @@ { "cell_type": "code", "execution_count": null, - "id": "bafdf91f", + "id": "b01b5496", "metadata": {}, "outputs": [], "source": [ @@ -186,7 +196,7 @@ { "cell_type": "code", "execution_count": null, - "id": "dc5c92ac", + "id": "78b3b9ea", "metadata": {}, "outputs": [], "source": [ @@ -231,7 +241,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d4cde737", + "id": "7b6b2908", "metadata": {}, "outputs": [], "source": [ @@ -249,7 +259,7 @@ { "cell_type": "code", "execution_count": null, - "id": "39848e33", + "id": "e0ab09d5", "metadata": {}, "outputs": [], "source": [ @@ -259,7 +269,7 @@ { "cell_type": "code", "execution_count": null, - "id": "b94581da", + "id": "c9ce69ed", "metadata": {}, "outputs": [], "source": [ @@ -268,10 +278,46 @@ "config_builder.with_seed_dataset(dd.DataFrameSeedSource(df=df_seed))" ] }, + { + "cell_type": "markdown", + "id": "94528475", + "metadata": {}, + "source": [ + "### 🧩 Media context and model capabilities\n", + "\n", + "`multi_modal_context` accepts media context descriptors such as `ImageContext`, `AudioContext`, and `VideoContext`. Data Designer reads the referenced seed columns and serializes them for the model request, but the selected model still determines which modalities are valid.\n", + "\n", + "This notebook uses image context only because image-capable VLMs are broadly available. 
Before combining image, audio, and video in one column, choose a model alias backed by an omni or otherwise modality-compatible model, and check that the provider accepts every context type you send.\n", + "\n", + "For base64 seed columns, store the raw base64 payload without a `data:;base64,` prefix and specify the media format on the context object:\n", + "\n", + "```python\n", + "media_context = [\n", + " dd.ImageContext(\n", + " column_name=\"image_base64\",\n", + " data_type=dd.ModalityDataType.BASE64,\n", + " image_format=dd.ImageFormat.PNG,\n", + " ),\n", + " dd.AudioContext(\n", + " column_name=\"audio_base64\",\n", + " data_type=dd.ModalityDataType.BASE64,\n", + " audio_format=dd.AudioFormat.MP3,\n", + " ),\n", + " dd.VideoContext(\n", + " column_name=\"video_base64\",\n", + " data_type=dd.ModalityDataType.BASE64,\n", + " video_format=dd.VideoFormat.MP4,\n", + " ),\n", + "]\n", + "```\n", + "\n", + "URL-backed media can use `data_type=dd.ModalityDataType.URL`, subject to the provider's URL support and file-size limits. Local audio/video paths require explicit URL mode and require the model endpoint to have filesystem access to the same paths, typically a colocated vLLM server configured for local media access." 
+ ] + }, { "cell_type": "code", "execution_count": null, - "id": "7c561ff0", + "id": "bd8148f4", "metadata": {}, "outputs": [], "source": [ @@ -293,7 +339,7 @@ }, { "cell_type": "markdown", - "id": "99a5ad0c", + "id": "2150d704", "metadata": {}, "source": [ "### πŸ” Iteration is key – preview the dataset!\n", @@ -310,7 +356,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d32dcf48", + "id": "85cf2067", "metadata": {}, "outputs": [], "source": [ @@ -320,7 +366,7 @@ { "cell_type": "code", "execution_count": null, - "id": "70db2f87", + "id": "509f00ed", "metadata": {}, "outputs": [], "source": [ @@ -331,7 +377,7 @@ { "cell_type": "code", "execution_count": null, - "id": "8b65b184", + "id": "8b1a7d15", "metadata": {}, "outputs": [], "source": [ @@ -341,7 +387,7 @@ }, { "cell_type": "markdown", - "id": "58e3147f", + "id": "9bf4843c", "metadata": {}, "source": [ "### πŸ“Š Analyze the generated data\n", @@ -354,7 +400,7 @@ { "cell_type": "code", "execution_count": null, - "id": "82b01514", + "id": "d80d106d", "metadata": {}, "outputs": [], "source": [ @@ -364,7 +410,7 @@ }, { "cell_type": "markdown", - "id": "8274677b", + "id": "ed22e721", "metadata": {}, "source": [ "### πŸ”Ž Visual Inspection\n", @@ -375,7 +421,7 @@ { "cell_type": "code", "execution_count": null, - "id": "e7bd89dc", + "id": "f41c068a", "metadata": { "lines_to_next_cell": 2 }, @@ -399,7 +445,7 @@ }, { "cell_type": "markdown", - "id": "01f6d07d", + "id": "f096be05", "metadata": {}, "source": [ "### πŸ†™ Scale up!\n", @@ -412,7 +458,7 @@ { "cell_type": "code", "execution_count": null, - "id": "21981b68", + "id": "c2efd0f8", "metadata": {}, "outputs": [], "source": [ @@ -422,7 +468,7 @@ { "cell_type": "code", "execution_count": null, - "id": "7c655cea", + "id": "1f7b5f60", "metadata": {}, "outputs": [], "source": [ @@ -435,7 +481,7 @@ { "cell_type": "code", "execution_count": null, - "id": "291a3dfc", + "id": "dbb9ea18", "metadata": {}, "outputs": [], "source": [ @@ -447,7 +493,7 @@ }, { 
"cell_type": "markdown", - "id": "af7c69cc", + "id": "f7a1f3ba", "metadata": {}, "source": [ "## ⏭️ Next Steps\n", @@ -456,7 +502,7 @@ "\n", "- Experiment with different vision models for specific image types\n", "- Try different prompt variations to generate specialized descriptions (e.g., technical details, key findings)\n", - "- Combine vision-based descriptions with other column types for multi-modal workflows\n", + "- Combine image, audio, or video context with other column types after confirming your selected model supports those modalities\n", "- Apply this pattern to other vision tasks like image captioning, OCR validation, or visual question answering\n", "\n", "- [Generating images](https://nvidia-nemo.github.io/DataDesigner/latest/notebooks/5-generating-images/) with Data Designer\n" diff --git a/docs/colab_notebooks/5-generating-images.ipynb b/docs/colab_notebooks/5-generating-images.ipynb index efecb0387..76a933da0 100644 --- a/docs/colab_notebooks/5-generating-images.ipynb +++ b/docs/colab_notebooks/5-generating-images.ipynb @@ -2,15 +2,17 @@ "cells": [ { "cell_type": "markdown", - "id": "52eeca6e", - "metadata": {}, + "id": "66019c7e", + "metadata": { + "nemo_colab_inject": true + }, "source": [ "\"Open" ] }, { "cell_type": "markdown", - "id": "ea02d680", + "id": "267d3938", "metadata": {}, "source": [ "# 🎨 Data Designer Tutorial: Generating Images\n", @@ -32,7 +34,7 @@ }, { "cell_type": "markdown", - "id": "1c36e1cd", + "id": "486d74eb", "metadata": {}, "source": [ "### πŸ“¦ Import Data Designer\n", @@ -43,8 +45,10 @@ }, { "cell_type": "markdown", - "id": "4933a0df", - "metadata": {}, + "id": "9c888db8", + "metadata": { + "nemo_colab_inject": true + }, "source": [ "### ⚑ Colab Setup\n", "\n", @@ -54,8 +58,10 @@ { "cell_type": "code", "execution_count": null, - "id": "abe49f1b", - "metadata": {}, + "id": "4fcdfb3f", + "metadata": { + "nemo_colab_inject": true + }, "outputs": [], "source": [ "%%capture\n", @@ -65,8 +71,10 @@ { "cell_type": "code", 
"execution_count": null, - "id": "f6ffa0a4", - "metadata": {}, + "id": "6a87ecb2", + "metadata": { + "nemo_colab_inject": true + }, "outputs": [], "source": [ "import getpass\n", @@ -83,7 +91,7 @@ { "cell_type": "code", "execution_count": null, - "id": "f1de4914", + "id": "ec5dd8e7", "metadata": {}, "outputs": [], "source": [ @@ -96,7 +104,7 @@ }, { "cell_type": "markdown", - "id": "112c71f5", + "id": "651d9f3b", "metadata": {}, "source": [ "### βš™οΈ Initialize the Data Designer interface\n", @@ -107,7 +115,7 @@ { "cell_type": "code", "execution_count": null, - "id": "88c82623", + "id": "5fc8972e", "metadata": {}, "outputs": [], "source": [ @@ -116,7 +124,7 @@ }, { "cell_type": "markdown", - "id": "50ca5262", + "id": "dd50d576", "metadata": {}, "source": [ "### πŸŽ›οΈ Define an image-generation model\n", @@ -128,7 +136,7 @@ { "cell_type": "code", "execution_count": null, - "id": "49fdc61e", + "id": "03ca2abf", "metadata": {}, "outputs": [], "source": [ @@ -150,7 +158,7 @@ }, { "cell_type": "markdown", - "id": "6740ea52", + "id": "73bf1fa1", "metadata": {}, "source": [ "### πŸ—οΈ Build the config: samplers + image column\n", @@ -161,7 +169,7 @@ { "cell_type": "code", "execution_count": null, - "id": "7b89467a", + "id": "efa7ecf8", "metadata": {}, "outputs": [], "source": [ @@ -334,7 +342,7 @@ }, { "cell_type": "markdown", - "id": "ad84fd89", + "id": "e34da1ef", "metadata": {}, "source": [ "### πŸ” Preview: images as base64\n", @@ -345,7 +353,7 @@ { "cell_type": "code", "execution_count": null, - "id": "24ecd543", + "id": "e27fc9fd", "metadata": {}, "outputs": [], "source": [ @@ -355,7 +363,7 @@ { "cell_type": "code", "execution_count": null, - "id": "7540fc51", + "id": "437b1054", "metadata": {}, "outputs": [], "source": [ @@ -366,7 +374,7 @@ { "cell_type": "code", "execution_count": null, - "id": "8920f6c5", + "id": "5666999a", "metadata": {}, "outputs": [], "source": [ @@ -375,7 +383,7 @@ }, { "cell_type": "markdown", - "id": "5739eee6", + "id": "9e9b5c1b", 
"metadata": {}, "source": [ "### πŸ†™ Create: images saved to disk\n", @@ -386,7 +394,7 @@ { "cell_type": "code", "execution_count": null, - "id": "f5326cbb", + "id": "8adecae8", "metadata": {}, "outputs": [], "source": [ @@ -396,7 +404,7 @@ { "cell_type": "code", "execution_count": null, - "id": "506d537f", + "id": "92998c4c", "metadata": {}, "outputs": [], "source": [ @@ -407,7 +415,7 @@ { "cell_type": "code", "execution_count": null, - "id": "8dbd4874", + "id": "0ad903b0", "metadata": {}, "outputs": [], "source": [ @@ -423,7 +431,7 @@ }, { "cell_type": "markdown", - "id": "fa0307b2", + "id": "12134406", "metadata": {}, "source": [ "## ⏭️ Next steps\n", diff --git a/docs/colab_notebooks/6-editing-images-with-image-context.ipynb b/docs/colab_notebooks/6-editing-images-with-image-context.ipynb index 8a29e17af..023dd198c 100644 --- a/docs/colab_notebooks/6-editing-images-with-image-context.ipynb +++ b/docs/colab_notebooks/6-editing-images-with-image-context.ipynb @@ -2,15 +2,17 @@ "cells": [ { "cell_type": "markdown", - "id": "7348e00d", - "metadata": {}, + "id": "30e20568", + "metadata": { + "nemo_colab_inject": true + }, "source": [ "\"Open" ] }, { "cell_type": "markdown", - "id": "c5e18f66", + "id": "d63f4416", "metadata": {}, "source": [ "# 🎨 Data Designer Tutorial: Image-to-Image Editing\n", @@ -32,7 +34,7 @@ }, { "cell_type": "markdown", - "id": "daa7359c", + "id": "d3e60ea6", "metadata": {}, "source": [ "### πŸ“¦ Import Data Designer\n", @@ -43,8 +45,10 @@ }, { "cell_type": "markdown", - "id": "5bb9d062", - "metadata": {}, + "id": "2f1c15e7", + "metadata": { + "nemo_colab_inject": true + }, "source": [ "### ⚑ Colab Setup\n", "\n", @@ -54,8 +58,10 @@ { "cell_type": "code", "execution_count": null, - "id": "b03fb17a", - "metadata": {}, + "id": "143db4c6", + "metadata": { + "nemo_colab_inject": true + }, "outputs": [], "source": [ "%%capture\n", @@ -65,8 +71,10 @@ { "cell_type": "code", "execution_count": null, - "id": "e931d0de", - "metadata": {}, + "id": 
"d9115072", + "metadata": { + "nemo_colab_inject": true + }, "outputs": [], "source": [ "import getpass\n", @@ -83,7 +91,7 @@ { "cell_type": "code", "execution_count": null, - "id": "02e932f5", + "id": "dfb43d40", "metadata": {}, "outputs": [], "source": [ @@ -99,7 +107,7 @@ }, { "cell_type": "markdown", - "id": "369a04c5", + "id": "5f892cd5", "metadata": {}, "source": [ "### βš™οΈ Initialize the Data Designer interface\n", @@ -110,7 +118,7 @@ { "cell_type": "code", "execution_count": null, - "id": "070aaa15", + "id": "70b474a9", "metadata": {}, "outputs": [], "source": [ @@ -119,7 +127,7 @@ }, { "cell_type": "markdown", - "id": "142952fe", + "id": "f2aef849", "metadata": {}, "source": [ "### πŸŽ›οΈ Define an image model\n", @@ -135,7 +143,7 @@ { "cell_type": "code", "execution_count": null, - "id": "2d66a7c8", + "id": "aa2f73aa", "metadata": {}, "outputs": [], "source": [ @@ -157,7 +165,7 @@ }, { "cell_type": "markdown", - "id": "c4d0e592", + "id": "f19cf925", "metadata": {}, "source": [ "### πŸ—οΈ Build the configuration\n", @@ -172,7 +180,7 @@ { "cell_type": "code", "execution_count": null, - "id": "51a228bb", + "id": "d76b5043", "metadata": {}, "outputs": [], "source": [ @@ -270,7 +278,7 @@ }, { "cell_type": "markdown", - "id": "dc6d84fa", + "id": "c73e97f0", "metadata": {}, "source": [ "### πŸ” Preview: quick iteration\n", @@ -281,7 +289,7 @@ { "cell_type": "code", "execution_count": null, - "id": "05b58baa", + "id": "87f2ce90", "metadata": {}, "outputs": [], "source": [ @@ -291,7 +299,7 @@ { "cell_type": "code", "execution_count": null, - "id": "97e35ebb", + "id": "5032ba60", "metadata": {}, "outputs": [], "source": [ @@ -302,7 +310,7 @@ { "cell_type": "code", "execution_count": null, - "id": "345514ab", + "id": "b7806720", "metadata": {}, "outputs": [], "source": [ @@ -311,7 +319,7 @@ }, { "cell_type": "markdown", - "id": "15dfb8b7", + "id": "fb02667d", "metadata": { "lines_to_next_cell": 2 }, @@ -324,7 +332,7 @@ { "cell_type": "code", 
"execution_count": null, - "id": "13728788", + "id": "514fc44d", "metadata": {}, "outputs": [], "source": [ @@ -355,7 +363,7 @@ { "cell_type": "code", "execution_count": null, - "id": "6da35706", + "id": "27719b25", "metadata": {}, "outputs": [], "source": [ @@ -365,7 +373,7 @@ }, { "cell_type": "markdown", - "id": "59abd92b", + "id": "99c431db", "metadata": {}, "source": [ "### πŸ†™ Create at scale\n", @@ -376,7 +384,7 @@ { "cell_type": "code", "execution_count": null, - "id": "25be841b", + "id": "e8862095", "metadata": {}, "outputs": [], "source": [ @@ -386,7 +394,7 @@ { "cell_type": "code", "execution_count": null, - "id": "389cc5d2", + "id": "690c8016", "metadata": {}, "outputs": [], "source": [ @@ -397,7 +405,7 @@ { "cell_type": "code", "execution_count": null, - "id": "15002cbf", + "id": "6bd21f76", "metadata": {}, "outputs": [], "source": [ @@ -407,7 +415,7 @@ }, { "cell_type": "markdown", - "id": "ba28d5ee", + "id": "1d00589c", "metadata": {}, "source": [ "## ⏭️ Next steps\n", diff --git a/docs/notebook_source/4-providing-images-as-context.py b/docs/notebook_source/4-providing-images-as-context.py index 7d849c89a..301e90125 100644 --- a/docs/notebook_source/4-providing-images-as-context.py +++ b/docs/notebook_source/4-providing-images-as-context.py @@ -19,9 +19,11 @@ # #### πŸ“š What you'll learn # # This notebook demonstrates how to provide images as context to generate text descriptions using vision-language models. +# The same `multi_modal_context` field can also carry audio or video context when the selected model supports those modalities. 
# # - ✨ **Visual Document Processing**: Converting images to chat-ready format for model consumption # - 🔍 **Vision-Language Generation**: Using vision models to generate detailed summaries from images +# - 🧩 **Media Context Pattern**: Understanding how `ImageContext`, `AudioContext`, and `VideoContext` fit into the same configuration field # # If this is your first time using Data Designer, we recommend starting with the [first notebook](https://nvidia-nemo.github.io/DataDesigner/latest/notebooks/1-the-basics/) in this tutorial series. # @@ -153,6 +155,37 @@ def convert_image_to_chat_format(record, height: int) -> dict: df_seed = pd.DataFrame(img_dataset)[["uuid", "label", "base64_image"]] config_builder.with_seed_dataset(dd.DataFrameSeedSource(df=df_seed)) +# %% [markdown] +# ### 🧩 Media context and model capabilities +# +# `multi_modal_context` accepts media context descriptors such as `ImageContext`, `AudioContext`, and `VideoContext`. Data Designer reads the referenced seed columns and serializes them for the model request, but the selected model still determines which modalities are valid. +# +# This notebook uses image context only because image-capable VLMs are broadly available. Before combining image, audio, and video in one column, choose a model alias backed by an omni or otherwise modality-compatible model, and check that the provider accepts every context type you send. 
+# +# For base64 seed columns, store the raw base64 payload without a `data:;base64,` prefix and specify the media format on the context object: +# +# ```python +# media_context = [ +# dd.ImageContext( +# column_name="image_base64", +# data_type=dd.ModalityDataType.BASE64, +# image_format=dd.ImageFormat.PNG, +# ), +# dd.AudioContext( +# column_name="audio_base64", +# data_type=dd.ModalityDataType.BASE64, +# audio_format=dd.AudioFormat.MP3, +# ), +# dd.VideoContext( +# column_name="video_base64", +# data_type=dd.ModalityDataType.BASE64, +# video_format=dd.VideoFormat.MP4, +# ), +# ] +# ``` +# +# URL-backed media can use `data_type=dd.ModalityDataType.URL`, subject to the provider's URL support and file-size limits. Local audio/video paths require explicit URL mode and require the model endpoint to have filesystem access to the same paths, typically a colocated vLLM server configured for local media access. + # %% # Add a column to generate detailed image descriptions config_builder.add_column( @@ -257,7 +290,7 @@ def convert_image_to_chat_format(record, height: int) -> dict: # # - Experiment with different vision models for specific image types # - Try different prompt variations to generate specialized descriptions (e.g., technical details, key findings) -# - Combine vision-based descriptions with other column types for multi-modal workflows +# - Combine image, audio, or video context with other column types after confirming your selected model supports those modalities # - Apply this pattern to other vision tasks like image captioning, OCR validation, or visual question answering # # - [Generating images](https://nvidia-nemo.github.io/DataDesigner/latest/notebooks/5-generating-images/) with Data Designer diff --git a/docs/notebook_source/_README.md b/docs/notebook_source/_README.md index a51336e53..e65782de1 100644 --- a/docs/notebook_source/_README.md +++ b/docs/notebook_source/_README.md @@ -95,6 +95,7 @@ Learn how to use vision-language models to generate text 
descriptions from image - Processing and converting images to base64 format for model consumption - Using vision-language models (VLMs) to analyze visual documents +- Understanding how image, audio, and video context share the same `multi_modal_context` field, while still requiring model support for each modality - Generating detailed summaries from document images - Inspecting and validating vision-based generation results diff --git a/fern/versions/latest/pages/concepts/columns.mdx b/fern/versions/latest/pages/concepts/columns.mdx index daab64cfa..31fce016d 100644 --- a/fern/versions/latest/pages/concepts/columns.mdx +++ b/fern/versions/latest/pages/concepts/columns.mdx @@ -45,6 +45,8 @@ LLM-Text columns generate natural language text: product descriptions, customer Use **Jinja2 templating** in prompts to reference other columns. Data Designer automatically manages dependencies and injects the referenced column values into the prompt. +LLM-Text and LLM-Structured columns can also include `multi_modal_context` with `ImageContext`, `AudioContext`, or `VideoContext`. Data Designer reads the referenced seed columns and serializes the media blocks, but it does not make an image-only model understand audio or video. Choose a `model_alias` whose underlying provider/model supports every modality in the column. + Generation Traces LLM columns can optionally capture message traces in a separate `{column_name}__trace` column. Set `with_trace` on the column config to control what's captured: `TraceType.NONE` (default, no trace), `TraceType.LAST_MESSAGE` (final assistant message only), or `TraceType.ALL_MESSAGES` (full conversation history). The trace includes the ordered message history for the final generation attempt (system/user/assistant/tool calls/tool results), and may include model reasoning fields when the provider exposes them. @@ -126,11 +128,11 @@ Image columns require a model configured with `ImageInferenceParams`. 
Model-spec - **Preview** (`data_designer.preview()`): Images are stored as base64-encoded strings directly in the DataFrame for quick iteration - **Create** (`data_designer.create()`): Images are saved to disk in an `images//` folder with UUID filenames; the DataFrame stores relative paths -Image columns also support `multi_modal_context` for autoregressive models that accept image inputs, enabling image-to-image generation workflows. +Image columns also support `multi_modal_context` for autoregressive multimodal models that accept media inputs, enabling image-to-image and other media-conditioned image generation workflows. Diffusion image-generation routes do not consume multimodal context, and not every autoregressive image model accepts every media type. Tutorials -The image tutorials cover three workflows: [Providing Images as Context](/tutorials/providing-images-as-context) (image → text), [Generating Images](/tutorials/generating-images) (text → image), and [Editing Images with Image Context](/tutorials/image-to-image-editing) (image → image). +The image tutorials cover three workflows: [Providing Images as Context](/tutorials/providing-images-as-context) (image → text, with notes on audio/video-capable models), [Generating Images](/tutorials/generating-images) (text → image), and [Editing Images with Image Context](/tutorials/image-to-image-editing) (image → image). 
### 🧬 Embedding Columns diff --git a/fern/versions/latest/pages/concepts/models/default-model-settings.mdx b/fern/versions/latest/pages/concepts/models/default-model-settings.mdx index 803573770..b75d01c29 100644 --- a/fern/versions/latest/pages/concepts/models/default-model-settings.mdx +++ b/fern/versions/latest/pages/concepts/models/default-model-settings.mdx @@ -48,7 +48,7 @@ The following model configurations are automatically available when `NVIDIA_API_ |-------|-------|----------|---------------------| | `nvidia-text` | `nvidia/nemotron-3-nano-30b-a3b` | General text generation | `temperature=1.0, top_p=1.0` | | `nvidia-reasoning` | `nvidia/nemotron-3-super-120b-a12b` | Reasoning and analysis tasks | `temperature=1.0, top_p=0.95, extra_body={"reasoning_effort": "medium"}` | -| `nvidia-vision` | `nvidia/nemotron-nano-12b-v2-vl` | Vision and image understanding | `temperature=0.85, top_p=0.95` | +| `nvidia-vision` | `nvidia/nemotron-3-nano-omni-30b-a3b-reasoning` | Omni multimodal understanding for image, audio, and video inputs | `temperature=0.60, top_p=0.95` | | `nvidia-embedding` | `nvidia/llama-3.2-nv-embedqa-1b-v2` | Text embeddings | `encoding_format="float", extra_body={"input_type": "query"}` | @@ -59,8 +59,8 @@ The following model configurations are automatically available when `OPENAI_API_ | Alias | Model | Use Case | Inference Parameters | |-------|-------|----------|---------------------| | `openai-text` | `gpt-4.1` | General text generation | `temperature=0.85, top_p=0.95` | -| `openai-reasoning` | `gpt-5` | Reasoning and analysis tasks | `temperature=0.35, top_p=0.95` | -| `openai-vision` | `gpt-5` | Vision and image understanding | `temperature=0.85, top_p=0.95` | +| `openai-reasoning` | `gpt-5` | Reasoning and analysis tasks | `extra_body={"reasoning_effort": "medium"}` | +| `openai-vision` | `gpt-5` | Vision and image understanding | `extra_body={"reasoning_effort": "medium"}` | | `openai-embedding` | `text-embedding-3-large` | Text 
embeddings | `encoding_format="float"` | ### OpenRouter Models @@ -71,9 +71,13 @@ The following model configurations are automatically available when `OPENROUTER_ |-------|-------|----------|---------------------| | `openrouter-text` | `nvidia/nemotron-3-nano-30b-a3b` | General text generation | `temperature=1.0, top_p=1.0` | | `openrouter-reasoning` | `openai/gpt-oss-20b` | Reasoning and analysis tasks | `temperature=0.35, top_p=0.95` | -| `openrouter-vision` | `nvidia/nemotron-3-nano-omni-30b-a3b-reasoning:free` | Vision and image understanding | `temperature=0.60, top_p=0.95` | +| `openrouter-vision` | `nvidia/nemotron-3-nano-omni-30b-a3b-reasoning:free` | Omni multimodal understanding for image, audio, and video inputs, subject to OpenRouter model support | `temperature=0.60, top_p=0.95` | | `openrouter-embedding` | `openai/text-embedding-3-large` | Text embeddings | `encoding_format="float"` | + + The `multi_modal_context` field can include image, audio, and video contexts, but each model/provider combination has its own accepted input formats, media-size limits, and modality mix. Use an image-capable model for image-only workflows, and use an omni or otherwise multimodal model before sending audio or video context. Local audio/video paths require explicit URL mode (`data_type=url`) and require the model endpoint to have filesystem access to the same paths, typically a colocated vLLM server configured for local media access. + + ## Using Default Settings diff --git a/fern/versions/latest/pages/concepts/models/model-configs.mdx b/fern/versions/latest/pages/concepts/models/model-configs.mdx index 11834f035..3314b11e2 100644 --- a/fern/versions/latest/pages/concepts/models/model-configs.mdx +++ b/fern/versions/latest/pages/concepts/models/model-configs.mdx @@ -9,6 +9,8 @@ Model configurations define the specific models you use for synthetic data gener A `ModelConfig` specifies which LLM model to use and how it should behave during generation. 
When you create column configurations (like `LLMText`, `LLMCode`, or `LLMStructured`), you reference a model by its alias. Data Designer uses the model configuration to determine which model to call and with what parameters. +When a column includes `multi_modal_context`, the `ModelConfig` alias must point to a model that supports the media types you send. Data Designer can serialize image, audio, and video context blocks, but model capability is still provider-specific. Local audio/video paths require explicit URL mode (`data_type=url`) and require the model endpoint to have filesystem access to the same paths, typically a colocated vLLM server configured for local media access. + ## ModelConfig Structure The `ModelConfig` class has the following fields: @@ -81,13 +83,13 @@ model_configs = [ max_tokens=4096, ), ), - # Vision tasks + # Omni multimodal tasks dd.ModelConfig( alias="vision-model", - model="nvidia/nemotron-nano-12b-v2-vl", + model="nvidia/nemotron-3-nano-omni-30b-a3b-reasoning", provider="nvidia", inference_parameters=dd.ChatCompletionInferenceParams( - temperature=0.7, + temperature=0.60, top_p=0.95, max_tokens=2048, ), diff --git a/fern/versions/latest/pages/notebooks/4-providing-images-as-context.mdx b/fern/versions/latest/pages/notebooks/4-providing-images-as-context.mdx index 2913480cd..0672eee7d 100644 --- a/fern/versions/latest/pages/notebooks/4-providing-images-as-context.mdx +++ b/fern/versions/latest/pages/notebooks/4-providing-images-as-context.mdx @@ -1,6 +1,6 @@ --- title: "Providing Images as Context" -description: "Multimodal prompts with image inputs." +description: "Multimodal prompts with image inputs and notes for audio/video-capable models." 
position: 5 --- diff --git a/fern/versions/latest/pages/notebooks/README.mdx b/fern/versions/latest/pages/notebooks/README.mdx index 6a4b80054..bad4af083 100644 --- a/fern/versions/latest/pages/notebooks/README.mdx +++ b/fern/versions/latest/pages/notebooks/README.mdx @@ -11,6 +11,6 @@ These tutorials walk through Data Designer end-to-end with executable Jupyter no | [The Basics](/tutorials/the-basics) | Declare columns, generate your first dataset | | [Structured Outputs, Jinja Expressions, and Conditional Generation](/tutorials/structured-outputs-jinja-expressions-and-conditional-generation) | Schema-constrained outputs and dynamic prompts | | [Seeding with an External Dataset](/tutorials/seeding-with-an-external-dataset) | Use existing data as input for generation | -| [Providing Images as Context](/tutorials/providing-images-as-context) | Multimodal prompts with image inputs | +| [Providing Images as Context](/tutorials/providing-images-as-context) | Multimodal prompts with image inputs, plus the media-context pattern for models that support audio or video | | [Generating Images](/tutorials/generating-images) | Create image columns from text prompts | | [Image-to-Image Editing](/tutorials/image-to-image-editing) | Edit images using image context | diff --git a/packages/data-designer-config/src/data_designer/config/__init__.py b/packages/data-designer-config/src/data_designer/config/__init__.py index a8f683aa3..e608476b2 100644 --- a/packages/data-designer-config/src/data_designer/config/__init__.py +++ b/packages/data-designer-config/src/data_designer/config/__init__.py @@ -38,6 +38,7 @@ ToolConfig, ) from data_designer.config.models import ( # noqa: F401 + AudioContext, ChatCompletionInferenceParams, EmbeddingInferenceParams, GenerationType, @@ -52,6 +53,7 @@ ModelProvider, UniformDistribution, UniformDistributionParams, + VideoContext, ) from data_designer.config.processors import ( # noqa: F401 DropColumnsProcessorConfig, @@ -104,8 +106,8 @@ ) from 
data_designer.config.seed_source_dataframe import DataFrameSeedSource # noqa: F401 from data_designer.config.utils.code_lang import CodeLang # noqa: F401 - from data_designer.config.utils.image_helpers import ImageFormat # noqa: F401 from data_designer.config.utils.info import InfoType # noqa: F401 + from data_designer.config.utils.media_helpers import AudioFormat, ImageFormat, VideoFormat # noqa: F401 from data_designer.config.utils.trace_type import TraceType # noqa: F401 from data_designer.config.validator_params import ( # noqa: F401 CodeValidatorParams, @@ -161,11 +163,13 @@ "MCPProvider": (_MOD_MCP, "MCPProvider"), "ToolConfig": (_MOD_MCP, "ToolConfig"), # models + "AudioContext": (_MOD_MODELS, "AudioContext"), + "AudioFormat": (f"{_MOD_UTILS}.media_helpers", "AudioFormat"), "ChatCompletionInferenceParams": (_MOD_MODELS, "ChatCompletionInferenceParams"), "EmbeddingInferenceParams": (_MOD_MODELS, "EmbeddingInferenceParams"), "GenerationType": (_MOD_MODELS, "GenerationType"), "ImageContext": (_MOD_MODELS, "ImageContext"), - "ImageFormat": (f"{_MOD_UTILS}.image_helpers", "ImageFormat"), + "ImageFormat": (f"{_MOD_UTILS}.media_helpers", "ImageFormat"), "ImageInferenceParams": (_MOD_MODELS, "ImageInferenceParams"), "ManualDistribution": (_MOD_MODELS, "ManualDistribution"), "ManualDistributionParams": (_MOD_MODELS, "ManualDistributionParams"), @@ -176,6 +180,8 @@ "ModelProvider": (_MOD_MODELS, "ModelProvider"), "UniformDistribution": (_MOD_MODELS, "UniformDistribution"), "UniformDistributionParams": (_MOD_MODELS, "UniformDistributionParams"), + "VideoContext": (_MOD_MODELS, "VideoContext"), + "VideoFormat": (f"{_MOD_UTILS}.media_helpers", "VideoFormat"), # processors "DropColumnsProcessorConfig": (_MOD_PROCESSORS, "DropColumnsProcessorConfig"), "ProcessorType": (_MOD_PROCESSORS, "ProcessorType"), diff --git a/packages/data-designer-config/src/data_designer/config/column_configs.py b/packages/data-designer-config/src/data_designer/config/column_configs.py index 
a1016fec1..f7f569cfe 100644 --- a/packages/data-designer-config/src/data_designer/config/column_configs.py +++ b/packages/data-designer-config/src/data_designer/config/column_configs.py @@ -11,14 +11,17 @@ from data_designer.config.base import ConfigBase, SingleColumnConfig from data_designer.config.errors import InvalidConfigError -from data_designer.config.models import ImageContext +from data_designer.config.models import MultiModalContextT from data_designer.config.sampler_params import SamplerParamsT, SamplerType from data_designer.config.utils.code_lang import CodeLang from data_designer.config.utils.constants import REASONING_CONTENT_COLUMN_POSTFIX, TRACE_COLUMN_POSTFIX from data_designer.config.utils.misc import assert_valid_jinja2_template, extract_keywords_from_jinja2_template from data_designer.config.utils.trace_type import TraceType +from data_designer.config.utils.warning_helpers import warn_at_caller from data_designer.config.validator_params import ValidatorParamsT, ValidatorType +_NON_IMAGE_CONTEXT_KEYS = frozenset({"audio_format", "video_format"}) + class GenerationStrategy(str, Enum): """Strategy for custom column generation.""" @@ -139,8 +142,8 @@ class LLMTextColumnConfig(SingleColumnConfig): Do not put any output parsing instructions in the system prompt. Instead, use the appropriate column type for the output you want to generate - e.g., `LLMStructuredColumnConfig` for structured output, `LLMCodeColumnConfig` for code. - multi_modal_context: Optional list of image contexts for multi-modal generation. - Enables vision-capable models to generate text based on image inputs. + multi_modal_context: Optional list of multimodal contexts for generation. + Enables capable models to generate text based on image, audio, or video inputs. tool_alias: Optional alias of the tool configuration to use for MCP tool calls. Must match a tool alias defined when initializing the DataDesignerConfigBuilder. 
When provided, the model may call permitted tools during generation. @@ -166,8 +169,8 @@ class LLMTextColumnConfig(SingleColumnConfig): system_prompt: str | None = Field( default=None, description="Optional system prompt to set model behavior and constraints" ) - multi_modal_context: list[ImageContext] | None = Field( - default=None, description="Optional list of ImageContext for vision model inputs" + multi_modal_context: list[MultiModalContextT] | None = Field( + default=None, description="Optional list of multimodal context inputs" ) tool_alias: str | None = Field( default=None, description="Optional alias of the tool configuration to use for MCP tool calls" @@ -180,6 +183,12 @@ class LLMTextColumnConfig(SingleColumnConfig): ) column_type: Literal["llm-text"] = "llm-text" + @field_validator("multi_modal_context", mode="before") + @classmethod + def inject_legacy_image_context_modality(cls, value: Any) -> Any: + """Preserve legacy image-context dicts that predate the modality discriminator.""" + return _inject_legacy_image_context_modality(value) + @staticmethod def get_column_emoji() -> str: return "πŸ“" @@ -250,7 +259,7 @@ class LLMCodeColumnConfig(LLMTextColumnConfig): prompt (required): Prompt template for code generation (supports Jinja2). model_alias (required): Alias of the model configuration to use. system_prompt: Optional system prompt (supports Jinja2). - multi_modal_context: Optional image contexts for multi-modal generation. + multi_modal_context: Optional multimodal contexts for generation. tool_alias: Optional tool configuration alias for MCP tool calls. with_trace: Specifies what trace information to capture in a `{column_name}__trace` column. Options are `TraceType.NONE` (default), `TraceType.LAST_MESSAGE`, or @@ -288,7 +297,7 @@ class LLMStructuredColumnConfig(LLMTextColumnConfig): prompt (required): Prompt template for structured generation (supports Jinja2). model_alias (required): Alias of the model configuration to use. 
system_prompt: Optional system prompt (supports Jinja2). - multi_modal_context: Optional image contexts for multi-modal generation. + multi_modal_context: Optional multimodal contexts for generation. tool_alias: Optional tool configuration alias for MCP tool calls. with_trace: Specifies what trace information to capture in a `{column_name}__trace` column. Options are `TraceType.NONE` (default), `TraceType.LAST_MESSAGE`, or @@ -358,7 +367,7 @@ class LLMJudgeColumnConfig(LLMTextColumnConfig): prompt (required): Prompt template for the judge evaluation (supports Jinja2). model_alias (required): Alias of the model configuration to use. system_prompt: Optional system prompt (supports Jinja2). - multi_modal_context: Optional image contexts for multi-modal generation. + multi_modal_context: Optional multimodal contexts for generation. tool_alias: Optional tool configuration alias for MCP tool calls. with_trace: Specifies what trace information to capture in a `{column_name}__trace` column. Options are `TraceType.NONE` (default), `TraceType.LAST_MESSAGE`, or @@ -596,9 +605,9 @@ class ImageColumnConfig(SingleColumnConfig): reference other columns (e.g., "Generate an image of a {{ character_name }}"). Must be a valid Jinja2 template. model_alias (required): The model to use for image generation. - multi_modal_context: Optional list of image contexts for multi-modal generation. - Enables autoregressive multi-modal models to generate images based on image inputs. - Only works with autoregressive models that support image-to-image generation. + multi_modal_context: Optional list of multimodal contexts for generation. + Enables autoregressive multimodal models to generate images based on image, audio, or video inputs. + Ignored by diffusion image-generation routes, which do not consume multimodal context. Inherited Attributes: name (required): Unique name of the column to be generated. 
@@ -609,11 +618,17 @@ class ImageColumnConfig(SingleColumnConfig): description="Jinja2 template for the image generation prompt; can reference other columns via {{ column_name }}" ) model_alias: str = Field(description="Alias of the model to use for image generation") - multi_modal_context: list[ImageContext] | None = Field( - default=None, description="Optional list of ImageContext for multi-modal image-to-image generation" + multi_modal_context: list[MultiModalContextT] | None = Field( + default=None, description="Optional list of multimodal context inputs for image generation" ) column_type: Literal["image"] = "image" + @field_validator("multi_modal_context", mode="before") + @classmethod + def inject_legacy_image_context_modality(cls, value: Any) -> Any: + """Preserve legacy image-context dicts that predate the modality discriminator.""" + return _inject_legacy_image_context_modality(value) + @staticmethod def get_column_emoji() -> str: return "πŸ–ΌοΈ" @@ -731,3 +746,29 @@ def validate_generator_function(self) -> Self: f"Expected a function decorated with @custom_column_generator." ) return self + + +def _inject_legacy_image_context_modality(value: Any) -> Any: + if not isinstance(value, list): + return value + return [ + _inject_legacy_image_context_item(item) + if isinstance(item, dict) and _is_legacy_image_context_dict(item) + else item + for item in value + ] + + +def _inject_legacy_image_context_item(item: dict[str, Any]) -> dict[str, Any]: + warn_at_caller( + "Modality-less multi_modal_context dictionaries are treated as legacy ImageContext configs. 
" + "Set modality='image', modality='audio', or modality='video' explicitly for new configs.", + DeprecationWarning, + ) + return {"modality": "image", **item} + + +def _is_legacy_image_context_dict(value: dict[str, Any]) -> bool: + if "modality" in value: + return False + return not _NON_IMAGE_CONTEXT_KEYS.intersection(value) diff --git a/packages/data-designer-config/src/data_designer/config/models.py b/packages/data-designer-config/src/data_designer/config/models.py index 482f78308..e92014b47 100644 --- a/packages/data-designer-config/src/data_designer/config/models.py +++ b/packages/data-designer-config/src/data_designer/config/models.py @@ -3,7 +3,6 @@ from __future__ import annotations -import json import logging from abc import ABC, abstractmethod from enum import Enum @@ -22,15 +21,29 @@ MIN_TEMPERATURE, MIN_TOP_P, ) -from data_designer.config.utils.image_helpers import ( +from data_designer.config.utils.io_helpers import smart_load_yaml +from data_designer.config.utils.media_helpers import ( + AudioFormat, ImageFormat, + VideoFormat, + audio_format_from_mime_type, + audio_mime_type, decode_base64_image, detect_image_format, + get_media_base64_context, + get_media_url_context, + image_format_from_mime_type, + is_audio_path, is_image_path, is_image_url, + is_media_url, + is_video_path, load_image_path_to_base64, + normalize_media_context_values, + parse_base64_data_uri, + video_format_from_mime_type, + video_mime_type, ) -from data_designer.config.utils.io_helpers import smart_load_yaml from data_designer.config.utils.warning_helpers import warn_at_caller logger = logging.getLogger(__name__) @@ -40,6 +53,8 @@ class Modality(str, Enum): """Supported modality types for multimodal model data.""" IMAGE = "image" + AUDIO = "audio" + VIDEO = "video" class ModalityDataType(str, Enum): @@ -77,7 +92,7 @@ class ImageContext(ModalityContext): image_format: Image format (required when data_type is explicitly "base64"). 
""" - modality: Modality = Modality.IMAGE + modality: Literal[Modality.IMAGE] = Modality.IMAGE image_format: ImageFormat | None = None def get_contexts(self, record: dict, *, base_path: str | None = None) -> list[dict[str, Any]]: @@ -96,46 +111,19 @@ def get_contexts(self, record: dict, *, base_path: str | None = None) -> list[di Returns: A list of image contexts. """ - raw_value = record[self.column_name] - - # Normalize to list of strings - if isinstance(raw_value, str): - # Try to parse as JSON first - try: - parsed_value = json.loads(raw_value) - if isinstance(parsed_value, list): - context_values = parsed_value - else: - context_values = [raw_value] - except (json.JSONDecodeError, TypeError): - context_values = [raw_value] - elif isinstance(raw_value, list): - context_values = raw_value - elif hasattr(raw_value, "__iter__") and not isinstance(raw_value, (str, bytes, dict)): - # Handle array-like objects (numpy arrays, pandas Series, etc.) - context_values = list(raw_value) - else: - context_values = [raw_value] - - # Build context list - contexts = [] - for context_value in context_values: - context = dict(type="image_url") - if self.data_type is not None: - if self.data_type == ModalityDataType.URL: - context["image_url"] = {"url": context_value} - else: - context["image_url"] = { - "url": f"data:image/{self.image_format.value};base64,{context_value}", - } - else: - # Auto-detect: resolve file paths, pass through URLs, assume base64 otherwise - context["image_url"] = self._auto_resolve_context_value(context_value, base_path) - contexts.append(context) - - return contexts - - def _auto_resolve_context_value(self, context_value: str, base_path: str | None) -> dict[str, str]: + return [ + self._build_context(value, base_path=base_path) + for value in normalize_media_context_values(record[self.column_name]) + ] + + def _build_context(self, context_value: Any, *, base_path: str | None) -> dict[str, Any]: + if self.data_type == ModalityDataType.URL: + return 
get_media_url_context(Modality.IMAGE.value, context_value) + if self.data_type == ModalityDataType.BASE64: + return self._format_base64_context(context_value) + return self._auto_resolve_context_value(context_value, base_path) + + def _auto_resolve_context_value(self, context_value: Any, base_path: str | None) -> dict[str, Any]: """Auto-detect the format of a context value and resolve it. Resolution rules: @@ -149,22 +137,32 @@ def _auto_resolve_context_value(self, context_value: str, base_path: str | None) return self._format_base64_context(base64_data) if is_image_url(context_value): - return {"url": context_value} + return get_media_url_context(Modality.IMAGE.value, context_value) return self._format_base64_context(context_value) - def _format_base64_context(self, base64_data: str) -> dict[str, str]: - """Format base64 image data as an image_url context dict. + def _format_base64_context(self, base64_data: str) -> dict[str, Any]: + """Format base64 image data as a canonical image source dict. Uses self.image_format if set, otherwise detects from the image bytes. 
""" + parsed = parse_base64_data_uri(base64_data) + if parsed is not None: + media_type, data = parsed + detected_format = image_format_from_mime_type(media_type) + if detected_format is None: + raise ValueError(f"Unsupported image media type {media_type!r}") + if self.image_format is not None and not _image_formats_match(self.image_format, detected_format): + raise ValueError( + f"image_format {self.image_format.value!r} does not match data URI media type {media_type!r}" + ) + return get_media_base64_context(Modality.IMAGE.value, media_type, data) + image_format = self.image_format if image_format is None: image_bytes = decode_base64_image(base64_data) image_format = detect_image_format(image_bytes) - return { - "url": f"data:image/{image_format.value};base64,{base64_data}", - } + return get_media_base64_context(Modality.IMAGE.value, f"image/{image_format.value}", base64_data) @model_validator(mode="after") def _validate_image_format(self) -> Self: @@ -173,6 +171,146 @@ def _validate_image_format(self) -> Self: return self +def _image_formats_match(configured_format: ImageFormat, detected_format: ImageFormat) -> bool: + if configured_format == detected_format: + return True + return {configured_format, detected_format} == {ImageFormat.JPG, ImageFormat.JPEG} + + +class AudioContext(ModalityContext): + """Configuration for providing audio context to multimodal models. + + Audio context values are URL or base64 media values. Local paths may be + passed through only in explicit URL mode so colocated model endpoints can + read them directly. ``audio_format`` is consulted only for base64 sources. + """ + + modality: Literal[Modality.AUDIO] = Modality.AUDIO + audio_format: AudioFormat | None = None + + def get_contexts(self, record: dict, *, base_path: str | None = None) -> list[dict[str, Any]]: + """Get audio contexts. + + ``base_path`` is accepted for signature compatibility with ``ImageContext`` + but unused; audio contexts do not resolve local files to base64. 
+ """ + return [self._build_context(value) for value in normalize_media_context_values(record[self.column_name])] + + def _build_context(self, context_value: Any) -> dict[str, Any]: + if self.data_type == ModalityDataType.URL: + self._validate_url_context_value(context_value) + return get_media_url_context(Modality.AUDIO.value, context_value) + + if self.data_type is None and is_media_url(context_value): + return get_media_url_context(Modality.AUDIO.value, context_value) + + media_type, data = self._resolve_base64_parts(context_value) + return get_media_base64_context(Modality.AUDIO.value, media_type, data) + + def _resolve_base64_parts(self, context_value: Any) -> tuple[str, Any]: + parsed = parse_base64_data_uri(context_value) + if parsed is not None: + media_type, data = parsed + detected_format = audio_format_from_mime_type(media_type) + if detected_format is None: + raise ValueError(f"Unsupported audio media type {media_type!r}") + if self.audio_format is not None and self.audio_format != detected_format: + raise ValueError( + f"audio_format {self.audio_format.value!r} does not match data URI media type {media_type!r}" + ) + return media_type, data + + if is_audio_path(context_value): + raise ValueError( + "audio context values that look like local paths must use data_type=url; " + "otherwise provide base64 audio data" + ) + + if self.audio_format is None: + raise ValueError("audio_format is required for base64 audio context values") + return audio_mime_type(self.audio_format), context_value + + def _validate_url_context_value(self, context_value: Any) -> None: + if not is_media_url(context_value) and not is_audio_path(context_value): + raise ValueError("audio URL context values must be HTTP(S) URLs or local audio paths") + + @model_validator(mode="after") + def _validate_audio_format(self) -> Self: + if self.data_type == ModalityDataType.BASE64 and self.audio_format is None: + raise ValueError(f"audio_format is required when data_type is 
{self.data_type.value}") + return self + + +class VideoContext(ModalityContext): + """Configuration for providing video context to multimodal models. + + Video context values are URL or base64 media values. Local paths may be + passed through only in explicit URL mode so colocated model endpoints can + read them directly. ``video_format`` is consulted only for base64 sources. + """ + + modality: Literal[Modality.VIDEO] = Modality.VIDEO + video_format: VideoFormat | None = None + + def get_contexts(self, record: dict, *, base_path: str | None = None) -> list[dict[str, Any]]: + """Get video contexts. + + ``base_path`` is accepted for signature compatibility with ``ImageContext`` + but unused; video contexts do not resolve local files to base64. + """ + return [self._build_context(value) for value in normalize_media_context_values(record[self.column_name])] + + def _build_context(self, context_value: Any) -> dict[str, Any]: + if self.data_type == ModalityDataType.URL: + self._validate_url_context_value(context_value) + return get_media_url_context(Modality.VIDEO.value, context_value) + + if self.data_type is None and is_media_url(context_value): + return get_media_url_context(Modality.VIDEO.value, context_value) + + media_type, data = self._resolve_base64_parts(context_value) + return get_media_base64_context(Modality.VIDEO.value, media_type, data) + + def _resolve_base64_parts(self, context_value: Any) -> tuple[str, Any]: + parsed = parse_base64_data_uri(context_value) + if parsed is not None: + media_type, data = parsed + detected_format = video_format_from_mime_type(media_type) + if detected_format is None: + raise ValueError(f"Unsupported video media type {media_type!r}") + if self.video_format is not None and self.video_format != detected_format: + raise ValueError( + f"video_format {self.video_format.value!r} does not match data URI media type {media_type!r}" + ) + return media_type, data + + if is_video_path(context_value): + raise ValueError( + "video context 
values that look like local paths must use data_type=url; " + "otherwise provide base64 video data" + ) + + if self.video_format is None: + raise ValueError("video_format is required for base64 video context values") + return video_mime_type(self.video_format), context_value + + def _validate_url_context_value(self, context_value: Any) -> None: + if not is_media_url(context_value) and not is_video_path(context_value): + raise ValueError("video URL context values must be HTTP(S) URLs or local video paths") + + @model_validator(mode="after") + def _validate_video_format(self) -> Self: + if self.data_type == ModalityDataType.BASE64 and self.video_format is None: + raise ValueError(f"video_format is required when data_type is {self.data_type.value}") + return self + + +MultiModalContextT: TypeAlias = Annotated[ + ImageContext | AudioContext | VideoContext, + Field(discriminator="modality"), +] + + DistributionParamsT = TypeVar("DistributionParamsT", bound=ConfigBase) diff --git a/packages/data-designer-config/src/data_designer/config/utils/image_helpers.py b/packages/data-designer-config/src/data_designer/config/utils/image_helpers.py deleted file mode 100644 index 934be5b43..000000000 --- a/packages/data-designer-config/src/data_designer/config/utils/image_helpers.py +++ /dev/null @@ -1,295 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-# SPDX-License-Identifier: Apache-2.0 - -"""Helper utilities for working with images.""" - -from __future__ import annotations - -import base64 -import io -import re -from pathlib import Path - -import requests - -import data_designer.lazy_heavy_imports as lazy -from data_designer.config.utils.type_helpers import StrEnum - - -class ImageFormat(StrEnum): - """Supported image formats for image modality.""" - - PNG = "png" - JPG = "jpg" - JPEG = "jpeg" - GIF = "gif" - WEBP = "webp" - - -# Magic bytes for image format detection -IMAGE_FORMAT_MAGIC_BYTES = { - ImageFormat.PNG: b"\x89PNG\r\n\x1a\n", - ImageFormat.JPG: b"\xff\xd8\xff", - ImageFormat.GIF: b"GIF8", - # WEBP uses RIFF header - handled separately -} - -# Maps PIL format name (lowercase) to our ImageFormat enum. -# PIL reports "JPEG" (not "JPG"), so we normalize it here. -_PIL_FORMAT_TO_IMAGE_FORMAT: dict[str, ImageFormat] = { - "png": ImageFormat.PNG, - "jpeg": ImageFormat.JPG, - "jpg": ImageFormat.JPG, - "gif": ImageFormat.GIF, - "webp": ImageFormat.WEBP, -} - -_BASE64_PATTERN = re.compile(r"^[A-Za-z0-9+/=]+$") - -# Patterns for diffusion-based image models only (use image_generation API). -IMAGE_DIFFUSION_MODEL_PATTERNS = ( - "dall-e-", - "dalle", - "stable-diffusion", - "sd-", - "sd_", - "imagen", - "gpt-image-", -) - -SUPPORTED_IMAGE_EXTENSIONS = [f".{fmt.value.lower()}" for fmt in ImageFormat] - - -def is_image_diffusion_model(model_name: str) -> bool: - """Return True if the model is a diffusion-based image generation model. - - Args: - model_name: Model name or identifier (e.g. from provider). - - Returns: - True if the model is detected as diffusion-based, False otherwise. - """ - return any(pattern in model_name.lower() for pattern in IMAGE_DIFFUSION_MODEL_PATTERNS) - - -def extract_base64_from_data_uri(data: str) -> str: - """Extract base64 from data URI or return as-is. - - Handles data URIs like "data:image/png;base64,iVBORw0..." and returns - just the base64 portion. 
- - Args: - data: Data URI (e.g., "data:image/png;base64,XXX") or plain base64 - - Returns: - Base64 string without data URI prefix - - Raises: - ValueError: If data URI format is invalid - """ - if data.startswith("data:"): - if "," in data: - return data.split(",", 1)[1] - raise ValueError("Invalid data URI format: missing comma separator") - return data - - -def decode_base64_image(base64_data: str) -> bytes: - """Decode base64 string to image bytes. - - Automatically handles data URIs by extracting the base64 portion first. - - Args: - base64_data: Base64 string (with or without data URI prefix) - - Returns: - Decoded image bytes - - Raises: - ValueError: If base64 data is invalid - """ - # Remove data URI prefix if present - base64_data = extract_base64_from_data_uri(base64_data) - - try: - return base64.b64decode(base64_data, validate=True) - except Exception as e: - raise ValueError(f"Invalid base64 data: {e}") from e - - -def detect_image_format(image_bytes: bytes) -> ImageFormat: - """Detect image format from bytes. - - Uses magic bytes for fast detection, falls back to PIL for robust detection. 
- - Args: - image_bytes: Image data as bytes - - Returns: - Detected ImageFormat - - Raises: - ValueError: If the image format cannot be determined - """ - # Check magic bytes first (fast) - if image_bytes.startswith(IMAGE_FORMAT_MAGIC_BYTES[ImageFormat.PNG]): - return ImageFormat.PNG - elif image_bytes.startswith(IMAGE_FORMAT_MAGIC_BYTES[ImageFormat.JPG]): - return ImageFormat.JPG - elif image_bytes.startswith(IMAGE_FORMAT_MAGIC_BYTES[ImageFormat.GIF]): - return ImageFormat.GIF - elif image_bytes.startswith(b"RIFF") and b"WEBP" in image_bytes[:12]: - return ImageFormat.WEBP - - # Fallback to PIL for robust detection - try: - img = lazy.Image.open(io.BytesIO(image_bytes)) - format_str = img.format.lower() if img.format else None - if format_str in _PIL_FORMAT_TO_IMAGE_FORMAT: - return _PIL_FORMAT_TO_IMAGE_FORMAT[format_str] - except Exception: - pass - - raise ValueError( - f"Unable to detect image format (first 8 bytes: {image_bytes[:8]!r}). " - f"Supported formats: {', '.join(SUPPORTED_IMAGE_EXTENSIONS)}." - ) - - -def is_image_path(value: str) -> bool: - """Check if a string is an image file path. - - Args: - value: String to check - - Returns: - True if the string looks like an image file path, False otherwise - """ - if not isinstance(value, str): - return False - return any(value.lower().endswith(ext) for ext in SUPPORTED_IMAGE_EXTENSIONS) - - -def is_base64_image(value: str) -> bool: - """Check if a string is base64-encoded image data. 
- - Args: - value: String to check - - Returns: - True if the string looks like base64-encoded image data, False otherwise - """ - if not isinstance(value, str): - return False - # Check if it starts with data URI scheme - if value.startswith("data:image/"): - return True - # Check if it looks like base64 (at least 100 chars, contains only base64 chars) - if len(value) > 100 and _BASE64_PATTERN.match(value[:100]): - try: - # Try to decode a small portion to verify it's valid base64 - base64.b64decode(value[:100]) - return True - except Exception: - return False - return False - - -def is_image_url(value: str) -> bool: - """Check if a string is an image URL. - - Args: - value: String to check - - Returns: - True if the string looks like an image URL, False otherwise - """ - if not isinstance(value, str): - return False - return value.startswith(("http://", "https://")) and any(ext in value.lower() for ext in SUPPORTED_IMAGE_EXTENSIONS) - - -def load_image_path_to_base64(image_path: str, base_path: str | None = None) -> str | None: - """Load an image from a file path and return as base64. - - Args: - image_path: Relative or absolute path to the image file. - base_path: Optional base path to resolve relative paths from. - - Returns: - Base64-encoded image data or None if loading fails. - """ - try: - path = Path(image_path) - - # If path is not absolute, try to resolve it - if not path.is_absolute(): - if base_path: - path = Path(base_path) / path - # If still not found, try current working directory - if not path.exists(): - path = Path.cwd() / image_path - - # Check if file exists - if not path.exists(): - return None - - # Read image file and convert to base64 - with open(path, "rb") as f: - image_bytes = f.read() - return base64.b64encode(image_bytes).decode() - except Exception: - return None - - -def load_image_url_to_base64(url: str, timeout: int = 60) -> str: - """Download an image from a URL and return as base64. 
- - Args: - url: HTTP(S) URL pointing to an image. - timeout: Request timeout in seconds. - - Returns: - Base64-encoded image data. - - Raises: - requests.HTTPError: If the download fails with a non-2xx status. - """ - resp = requests.get(url, timeout=timeout) - resp.raise_for_status() - return base64.b64encode(resp.content).decode() - - -async def aload_image_url_to_base64(url: str, timeout: int = 60) -> str: - """Download an image from a URL asynchronously and return as base64. - - Args: - url: HTTP(S) URL pointing to an image. - timeout: Request timeout in seconds. - - Returns: - Base64-encoded image data. - - Raises: - httpx.HTTPStatusError: If the download fails with a non-2xx status. - """ - async with lazy.httpx.AsyncClient() as client: - resp = await client.get(url, timeout=timeout) - resp.raise_for_status() - return base64.b64encode(resp.content).decode() - - -def validate_image(image_path: Path) -> None: - """Validate that an image file is readable and not corrupted. - - Args: - image_path: Path to image file - - Raises: - ValueError: If image is corrupted or unreadable - """ - try: - with lazy.Image.open(image_path) as img: - img.verify() - except Exception as e: - raise ValueError(f"Image validation failed: {e}") from e diff --git a/packages/data-designer-config/src/data_designer/config/utils/media_helpers.py b/packages/data-designer-config/src/data_designer/config/utils/media_helpers.py new file mode 100644 index 000000000..998b81c05 --- /dev/null +++ b/packages/data-designer-config/src/data_designer/config/utils/media_helpers.py @@ -0,0 +1,335 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +"""Shared helpers for multimodal media context values.""" + +from __future__ import annotations + +import base64 +import io +import json +import re +from pathlib import Path +from typing import Any + +import requests + +import data_designer.lazy_heavy_imports as lazy +from data_designer.config.utils.type_helpers import StrEnum + +# --- Format enums and constants --- + +_BASE64_PATTERN = re.compile(r"^[A-Za-z0-9+/=]+$") +_DATA_URI_RE = re.compile(r"^data:(?P[^;]+);base64,(?P.+)$") + +_IMAGE_DIFFUSION_MODEL_PATTERNS = ( + "dall-e-", + "dalle", + "stable-diffusion", + "sd-", + "sd_", + "imagen", + "gpt-image-", +) + + +class ImageFormat(StrEnum): + """Supported image formats for image modality.""" + + PNG = "png" + JPG = "jpg" + JPEG = "jpeg" + GIF = "gif" + WEBP = "webp" + + +class AudioFormat(StrEnum): + """Supported audio formats for audio context.""" + + MP3 = "mp3" + WAV = "wav" + + +class VideoFormat(StrEnum): + """Supported video formats for video context.""" + + MP4 = "mp4" + MOV = "mov" + WEBM = "webm" + + +_SUPPORTED_IMAGE_EXTENSIONS = tuple(f".{fmt.value.lower()}" for fmt in ImageFormat) +_SUPPORTED_AUDIO_EXTENSIONS = tuple(f".{fmt.value.lower()}" for fmt in AudioFormat) +_SUPPORTED_VIDEO_EXTENSIONS = tuple(f".{fmt.value.lower()}" for fmt in VideoFormat) + +_IMAGE_FORMAT_MAGIC_BYTES = { + ImageFormat.PNG: b"\x89PNG\r\n\x1a\n", + ImageFormat.JPG: b"\xff\xd8\xff", + ImageFormat.GIF: b"GIF8", + # WEBP uses RIFF header - handled separately +} + +# Maps PIL format name (lowercase) to our ImageFormat enum. +# PIL reports "JPEG" (not "JPG"), so we normalize it here. 
+_PIL_FORMAT_TO_IMAGE_FORMAT: dict[str, ImageFormat] = { + "png": ImageFormat.PNG, + "jpeg": ImageFormat.JPG, + "jpg": ImageFormat.JPG, + "gif": ImageFormat.GIF, + "webp": ImageFormat.WEBP, +} + +_IMAGE_MIME_TYPE_TO_FORMAT: dict[str, ImageFormat] = { + "image/png": ImageFormat.PNG, + "image/jpeg": ImageFormat.JPG, + "image/jpg": ImageFormat.JPG, + "image/gif": ImageFormat.GIF, + "image/webp": ImageFormat.WEBP, +} +_AUDIO_FORMAT_TO_MIME_TYPE: dict[AudioFormat, str] = { + AudioFormat.MP3: "audio/mpeg", + AudioFormat.WAV: "audio/wav", +} +_VIDEO_FORMAT_TO_MIME_TYPE: dict[VideoFormat, str] = { + VideoFormat.MP4: "video/mp4", + VideoFormat.MOV: "video/quicktime", + VideoFormat.WEBM: "video/webm", +} +_AUDIO_MIME_TYPE_TO_FORMAT: dict[str, AudioFormat] = { + "audio/mpeg": AudioFormat.MP3, + "audio/mp3": AudioFormat.MP3, + "audio/wav": AudioFormat.WAV, + "audio/wave": AudioFormat.WAV, + "audio/x-wav": AudioFormat.WAV, + "audio/vnd.wave": AudioFormat.WAV, +} +_VIDEO_MIME_TYPE_TO_FORMAT: dict[str, VideoFormat] = { + "video/mp4": VideoFormat.MP4, + "video/quicktime": VideoFormat.MOV, + "video/webm": VideoFormat.WEBM, +} + + +# --- Image helpers --- + + +def is_image_diffusion_model(model_name: str) -> bool: + """Return True if the model is a diffusion-based image generation model.""" + return any(pattern in model_name.lower() for pattern in _IMAGE_DIFFUSION_MODEL_PATTERNS) + + +def extract_base64_from_data_uri(data: str) -> str: + """Extract base64 from data URI or return as-is.""" + if data.startswith("data:"): + if "," in data: + return data.split(",", 1)[1] + raise ValueError("Invalid data URI format: missing comma separator") + return data + + +def decode_base64_image(base64_data: str) -> bytes: + """Decode base64 string to image bytes.""" + base64_data = extract_base64_from_data_uri(base64_data) + + try: + return base64.b64decode(base64_data, validate=True) + except Exception as e: + raise ValueError(f"Invalid base64 data: {e}") from e + + +def 
detect_image_format(image_bytes: bytes) -> ImageFormat: + """Detect image format from bytes.""" + if image_bytes.startswith(_IMAGE_FORMAT_MAGIC_BYTES[ImageFormat.PNG]): + return ImageFormat.PNG + elif image_bytes.startswith(_IMAGE_FORMAT_MAGIC_BYTES[ImageFormat.JPG]): + return ImageFormat.JPG + elif image_bytes.startswith(_IMAGE_FORMAT_MAGIC_BYTES[ImageFormat.GIF]): + return ImageFormat.GIF + elif image_bytes.startswith(b"RIFF") and b"WEBP" in image_bytes[:12]: + return ImageFormat.WEBP + + try: + img = lazy.Image.open(io.BytesIO(image_bytes)) + format_str = img.format.lower() if img.format else None + if format_str in _PIL_FORMAT_TO_IMAGE_FORMAT: + return _PIL_FORMAT_TO_IMAGE_FORMAT[format_str] + except Exception: + pass + + raise ValueError( + f"Unable to detect image format (first 8 bytes: {image_bytes[:8]!r}). " + f"Supported formats: {', '.join(_SUPPORTED_IMAGE_EXTENSIONS)}." + ) + + +def is_image_path(value: str) -> bool: + """Check if a string is an image file path.""" + if not isinstance(value, str): + return False + return any(value.lower().endswith(ext) for ext in _SUPPORTED_IMAGE_EXTENSIONS) + + +def is_base64_image(value: str) -> bool: + """Check if a string is base64-encoded image data.""" + if not isinstance(value, str): + return False + if value.startswith("data:image/"): + return True + if len(value) > 100 and _BASE64_PATTERN.match(value[:100]): + try: + base64.b64decode(value[:100]) + return True + except Exception: + return False + return False + + +def is_image_url(value: str) -> bool: + """Check if a string is an image URL.""" + if not isinstance(value, str): + return False + return value.startswith(("http://", "https://")) and any( + ext in value.lower() for ext in _SUPPORTED_IMAGE_EXTENSIONS + ) + + +def load_image_path_to_base64(image_path: str, base_path: str | None = None) -> str | None: + """Load an image from a file path and return as base64.""" + try: + path = Path(image_path) + + if not path.is_absolute(): + if base_path: + path = 
Path(base_path) / path + if not path.exists(): + path = Path.cwd() / image_path + + if not path.exists(): + return None + + with open(path, "rb") as f: + image_bytes = f.read() + return base64.b64encode(image_bytes).decode() + except Exception: + return None + + +def load_image_url_to_base64(url: str, timeout: int = 60) -> str: + """Download an image from a URL and return as base64.""" + resp = requests.get(url, timeout=timeout) + resp.raise_for_status() + return base64.b64encode(resp.content).decode() + + +async def aload_image_url_to_base64(url: str, timeout: int = 60) -> str: + """Download an image from a URL asynchronously and return as base64.""" + async with lazy.httpx.AsyncClient() as client: + resp = await client.get(url, timeout=timeout) + resp.raise_for_status() + return base64.b64encode(resp.content).decode() + + +def validate_image(image_path: Path) -> None: + """Validate that an image file is readable and not corrupted.""" + try: + with lazy.Image.open(image_path) as img: + img.verify() + except Exception as e: + raise ValueError(f"Image validation failed: {e}") from e + + +# --- Canonical media blocks --- + + +def get_media_context(modality: str, source: dict[str, Any]) -> dict[str, Any]: + """Build a canonical media context block.""" + return {"type": modality, "source": source} + + +def get_media_url_context(modality: str, url: Any) -> dict[str, Any]: + """Build a canonical URL media context block.""" + return get_media_context(modality, {"type": "url", "url": url}) + + +def get_media_base64_context(modality: str, media_type: str, data: Any) -> dict[str, Any]: + """Build a canonical base64 media context block.""" + return get_media_context(modality, {"type": "base64", "media_type": media_type, "data": data}) + + +def normalize_media_context_values(raw_value: Any) -> list[Any]: + """Normalize scalar, JSON-list, list, and array-like media values.""" + if isinstance(raw_value, str): + try: + parsed_value = json.loads(raw_value) + if 
isinstance(parsed_value, list): + return parsed_value + except json.JSONDecodeError: + pass + return [raw_value] + + if isinstance(raw_value, list): + return raw_value + + if hasattr(raw_value, "__iter__") and not isinstance(raw_value, (str, bytes, dict)): + return list(raw_value) + + return [raw_value] + + +def parse_base64_data_uri(value: str) -> tuple[str, str] | None: + """Return ``(media_type, data)`` for a base64 data URI.""" + if not isinstance(value, str): + return None + match = _DATA_URI_RE.match(value) + if match is None: + return None + return match.group("media_type"), match.group("data") + + +# --- Audio/video helpers --- + + +def is_media_url(value: str) -> bool: + """Return whether a value is an HTTP(S) media URL.""" + return isinstance(value, str) and value.startswith(("http://", "https://")) + + +def is_audio_path(value: str) -> bool: + """Return whether a value looks like a local audio path.""" + return _has_path_extension(value, _SUPPORTED_AUDIO_EXTENSIONS) + + +def is_video_path(value: str) -> bool: + """Return whether a value looks like a local video path.""" + return _has_path_extension(value, _SUPPORTED_VIDEO_EXTENSIONS) + + +def audio_mime_type(audio_format: AudioFormat) -> str: + """Return the MIME type for an audio format.""" + return _AUDIO_FORMAT_TO_MIME_TYPE[audio_format] + + +def video_mime_type(video_format: VideoFormat) -> str: + """Return the MIME type for a video format.""" + return _VIDEO_FORMAT_TO_MIME_TYPE[video_format] + + +def image_format_from_mime_type(media_type: str) -> ImageFormat | None: + """Infer an image format from a MIME type.""" + return _IMAGE_MIME_TYPE_TO_FORMAT.get(media_type.lower()) + + +def audio_format_from_mime_type(media_type: str) -> AudioFormat | None: + """Infer an audio format from a MIME type.""" + return _AUDIO_MIME_TYPE_TO_FORMAT.get(media_type.lower()) + + +def video_format_from_mime_type(media_type: str) -> VideoFormat | None: + """Infer a video format from a MIME type.""" + return 
_VIDEO_MIME_TYPE_TO_FORMAT.get(media_type.lower()) + + +def _has_path_extension(value: str, supported_extensions: tuple[str, ...]) -> bool: + if not isinstance(value, str): + return False + return not is_media_url(value) and value.lower().endswith(supported_extensions) diff --git a/packages/data-designer-config/src/data_designer/config/utils/visualization.py b/packages/data-designer-config/src/data_designer/config/utils/visualization.py index f18b33e53..2bd9772ab 100644 --- a/packages/data-designer-config/src/data_designer/config/utils/visualization.py +++ b/packages/data-designer-config/src/data_designer/config/utils/visualization.py @@ -37,7 +37,7 @@ TRACE_COLUMN_POSTFIX, ) from data_designer.config.utils.errors import DatasetSampleDisplayError -from data_designer.config.utils.image_helpers import ( +from data_designer.config.utils.media_helpers import ( extract_base64_from_data_uri, is_base64_image, is_image_path, diff --git a/packages/data-designer-config/tests/config/test_columns.py b/packages/data-designer-config/tests/config/test_columns.py index 4937a8e37..baafb1c43 100644 --- a/packages/data-designer-config/tests/config/test_columns.py +++ b/packages/data-designer-config/tests/config/test_columns.py @@ -1,6 +1,8 @@ # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + from typing import Literal import pytest @@ -28,7 +30,7 @@ is_plugin_column_type, ) from data_designer.config.errors import InvalidConfigError -from data_designer.config.models import ImageContext +from data_designer.config.models import AudioContext, ImageContext, ModalityDataType, VideoContext from data_designer.config.sampler_params import ( CategorySamplerParams, GaussianSamplerParams, @@ -130,9 +132,13 @@ def test_llm_text_column_config_required_columns_includes_multi_modal_context(): name="test_llm_text", prompt="Classify this image: {{ description }}", model_alias=stub_model_alias, - multi_modal_context=[ImageContext(column_name="image_base64")], + multi_modal_context=[ + ImageContext(column_name="image_base64"), + AudioContext(column_name="audio_url", data_type=ModalityDataType.URL), + VideoContext(column_name="video_url", data_type=ModalityDataType.URL), + ], ) - assert set(config.required_columns) == {"description", "image_base64"} + assert set(config.required_columns) == {"description", "image_base64", "audio_url", "video_url"} def test_llm_text_column_config_required_columns_deduplicates_multi_modal_and_prompt(): @@ -150,9 +156,106 @@ def test_image_column_config_required_columns_includes_multi_modal_context(): name="test_image", prompt="Generate based on {{ style }}", model_alias=stub_model_alias, - multi_modal_context=[ImageContext(column_name="reference_image")], + multi_modal_context=[ + ImageContext(column_name="reference_image"), + AudioContext(column_name="reference_audio", data_type=ModalityDataType.URL), + VideoContext(column_name="reference_video", data_type=ModalityDataType.URL), + ], ) - assert set(config.required_columns) == {"style", "reference_image"} + assert set(config.required_columns) == {"style", "reference_image", "reference_audio", "reference_video"} + + +@pytest.mark.parametrize( + "config_cls,name", + [ + (LLMTextColumnConfig, "test_llm_text"), + 
(ImageColumnConfig, "test_image"), + ], +) +def test_multi_modal_context_round_trips_discriminated_union( + config_cls: type[LLMTextColumnConfig] | type[ImageColumnConfig], + name: str, +) -> None: + config = config_cls( + name=name, + prompt="Describe the context", + model_alias=stub_model_alias, + multi_modal_context=[ + ImageContext(column_name="image_url", data_type=ModalityDataType.URL), + AudioContext(column_name="audio_url", data_type=ModalityDataType.URL), + VideoContext(column_name="video_url", data_type=ModalityDataType.URL), + ], + ) + + round_tripped = config_cls(**config.model_dump()) + + assert round_tripped.multi_modal_context is not None + assert isinstance(round_tripped.multi_modal_context[0], ImageContext) + assert isinstance(round_tripped.multi_modal_context[1], AudioContext) + assert isinstance(round_tripped.multi_modal_context[2], VideoContext) + + +@pytest.mark.parametrize( + "config_cls,name", + [ + (LLMTextColumnConfig, "test_llm_text"), + (ImageColumnConfig, "test_image"), + ], +) +def test_column_config_accepts_legacy_image_context_dict( + config_cls: type[LLMTextColumnConfig] | type[ImageColumnConfig], + name: str, +) -> None: + with pytest.warns(DeprecationWarning, match="treated as legacy ImageContext configs"): + config = config_cls( + name=name, + prompt="Describe the image", + model_alias=stub_model_alias, + multi_modal_context=[{"column_name": "image_url", "data_type": "url"}], + ) + + assert config.multi_modal_context is not None + assert isinstance(config.multi_modal_context[0], ImageContext) + assert config.multi_modal_context[0].column_name == "image_url" + + +@pytest.mark.parametrize( + "context_dict", + [ + {"column_name": "audio_url", "data_type": "url"}, + {"column_name": "video_url", "data_type": "url"}, + ], + ids=["audio-url-shaped", "video-url-shaped"], +) +def test_column_config_warns_modality_less_url_context_is_legacy_image(context_dict: dict[str, str]) -> None: + with pytest.warns(DeprecationWarning, match="treated 
as legacy ImageContext configs"): + config = LLMTextColumnConfig( + name="test_llm_text", + prompt="Describe the context", + model_alias=stub_model_alias, + multi_modal_context=[context_dict], + ) + + assert config.multi_modal_context is not None + assert isinstance(config.multi_modal_context[0], ImageContext) + + +@pytest.mark.parametrize( + "context_dict", + [ + {"column_name": "audio_url", "data_type": "url", "audio_format": "mp3"}, + {"column_name": "video_url", "data_type": "url", "video_format": "mp4"}, + ], + ids=["audio-format", "video-format"], +) +def test_column_config_requires_modality_for_audio_video_specific_dicts(context_dict: dict[str, str]) -> None: + with pytest.raises(ValidationError, match="modality"): + LLMTextColumnConfig( + name="test_llm_text", + prompt="Describe the context", + model_alias=stub_model_alias, + multi_modal_context=[context_dict], + ) def test_llm_text_column_config_with_trace_serialization() -> None: diff --git a/packages/data-designer-config/tests/config/test_models.py b/packages/data-designer-config/tests/config/test_models.py index c5bacd818..ca50e94d4 100644 --- a/packages/data-designer-config/tests/config/test_models.py +++ b/packages/data-designer-config/tests/config/test_models.py @@ -1,6 +1,8 @@ # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + import base64 import json import tempfile @@ -12,9 +14,12 @@ import yaml from pydantic import ValidationError +import data_designer.config as dd import data_designer.lazy_heavy_imports as lazy from data_designer.config.errors import InvalidConfigError from data_designer.config.models import ( + AudioContext, + AudioFormat, ChatCompletionInferenceParams, EmbeddingInferenceParams, GenerationType, @@ -23,12 +28,32 @@ ImageInferenceParams, ManualDistribution, ManualDistributionParams, + Modality, ModalityDataType, ModelConfig, UniformDistribution, UniformDistributionParams, + VideoContext, + VideoFormat, load_model_configs, ) +from data_designer.config.utils.media_helpers import get_media_base64_context, get_media_url_context + + +def test_media_context_exports_are_available_on_config_namespace() -> None: + assert dd.ImageContext is ImageContext + assert dd.AudioContext is AudioContext + assert dd.VideoContext is VideoContext + assert dd.ImageFormat is ImageFormat + assert dd.AudioFormat is AudioFormat + assert dd.VideoFormat is VideoFormat + + assert "ImageContext" in dd.__all__ + assert "ImageFormat" in dd.__all__ + assert "AudioContext" in dd.__all__ + assert "AudioFormat" in dd.__all__ + assert "VideoContext" in dd.__all__ + assert "VideoFormat" in dd.__all__ def test_image_context_get_contexts_single_string(): @@ -37,18 +62,12 @@ def test_image_context_get_contexts_single_string(): column_name="image_base64", data_type=ModalityDataType.BASE64, image_format=ImageFormat.PNG ) assert image_context.get_contexts({"image_base64": "somebase64encodedimagestring"}) == [ - { - "type": "image_url", - "image_url": {"url": "data:image/png;base64,somebase64encodedimagestring"}, - } + get_media_base64_context(Modality.IMAGE.value, "image/png", "somebase64encodedimagestring") ] image_context = ImageContext(column_name="image_url", data_type=ModalityDataType.URL) assert 
image_context.get_contexts({"image_url": "https://example.com/examle_image.png"}) == [ - { - "type": "image_url", - "image_url": {"url": "https://example.com/examle_image.png"}, - } + get_media_url_context(Modality.IMAGE.value, "https://example.com/examle_image.png") ] @@ -58,32 +77,17 @@ def test_image_context_get_contexts_list_of_strings(): column_name="image_base64", data_type=ModalityDataType.BASE64, image_format=ImageFormat.PNG ) assert image_context.get_contexts({"image_base64": ["image1base64", "image2base64", "image3base64"]}) == [ - { - "type": "image_url", - "image_url": {"url": "data:image/png;base64,image1base64"}, - }, - { - "type": "image_url", - "image_url": {"url": "data:image/png;base64,image2base64"}, - }, - { - "type": "image_url", - "image_url": {"url": "data:image/png;base64,image3base64"}, - }, + get_media_base64_context(Modality.IMAGE.value, "image/png", "image1base64"), + get_media_base64_context(Modality.IMAGE.value, "image/png", "image2base64"), + get_media_base64_context(Modality.IMAGE.value, "image/png", "image3base64"), ] image_context = ImageContext(column_name="image_url", data_type=ModalityDataType.URL) assert image_context.get_contexts( {"image_url": ["https://example.com/image1.png", "https://example.com/image2.png"]} ) == [ - { - "type": "image_url", - "image_url": {"url": "https://example.com/image1.png"}, - }, - { - "type": "image_url", - "image_url": {"url": "https://example.com/image2.png"}, - }, + get_media_url_context(Modality.IMAGE.value, "https://example.com/image1.png"), + get_media_url_context(Modality.IMAGE.value, "https://example.com/image2.png"), ] @@ -94,27 +98,15 @@ def test_image_context_get_contexts_numpy_array(): ) numpy_array = lazy.np.array(["image1base64", "image2base64"]) assert image_context.get_contexts({"image_base64": numpy_array}) == [ - { - "type": "image_url", - "image_url": {"url": "data:image/png;base64,image1base64"}, - }, - { - "type": "image_url", - "image_url": {"url": 
"data:image/png;base64,image2base64"}, - }, + get_media_base64_context(Modality.IMAGE.value, "image/png", "image1base64"), + get_media_base64_context(Modality.IMAGE.value, "image/png", "image2base64"), ] image_context = ImageContext(column_name="image_url", data_type=ModalityDataType.URL) numpy_array = lazy.np.array(["https://example.com/image1.png", "https://example.com/image2.png"]) assert image_context.get_contexts({"image_url": numpy_array}) == [ - { - "type": "image_url", - "image_url": {"url": "https://example.com/image1.png"}, - }, - { - "type": "image_url", - "image_url": {"url": "https://example.com/image2.png"}, - }, + get_media_url_context(Modality.IMAGE.value, "https://example.com/image1.png"), + get_media_url_context(Modality.IMAGE.value, "https://example.com/image2.png"), ] @@ -125,27 +117,15 @@ def test_image_context_get_contexts_json_serialized_list(): ) json_str = json.dumps(["image1base64", "image2base64"]) assert image_context.get_contexts({"image_base64": json_str}) == [ - { - "type": "image_url", - "image_url": {"url": "data:image/png;base64,image1base64"}, - }, - { - "type": "image_url", - "image_url": {"url": "data:image/png;base64,image2base64"}, - }, + get_media_base64_context(Modality.IMAGE.value, "image/png", "image1base64"), + get_media_base64_context(Modality.IMAGE.value, "image/png", "image2base64"), ] image_context = ImageContext(column_name="image_url", data_type=ModalityDataType.URL) json_str = json.dumps(["https://example.com/image1.png", "https://example.com/image2.png"]) assert image_context.get_contexts({"image_url": json_str}) == [ - { - "type": "image_url", - "image_url": {"url": "https://example.com/image1.png"}, - }, - { - "type": "image_url", - "image_url": {"url": "https://example.com/image2.png"}, - }, + get_media_url_context(Modality.IMAGE.value, "https://example.com/image1.png"), + get_media_url_context(Modality.IMAGE.value, "https://example.com/image2.png"), ] @@ -154,10 +134,7 @@ def 
test_image_context_get_contexts_json_string_not_list(): image_context = ImageContext(column_name="image_url", data_type=ModalityDataType.URL) json_str = json.dumps({"nested": "object"}) assert image_context.get_contexts({"image_url": json_str}) == [ - { - "type": "image_url", - "image_url": {"url": json_str}, - } + get_media_url_context(Modality.IMAGE.value, json_str) ] @@ -166,10 +143,7 @@ def test_image_context_get_contexts_invalid_json(): image_context = ImageContext(column_name="image_url", data_type=ModalityDataType.URL) invalid_json = "not a valid json string" assert image_context.get_contexts({"image_url": invalid_json}) == [ - { - "type": "image_url", - "image_url": {"url": invalid_json}, - } + get_media_url_context(Modality.IMAGE.value, invalid_json) ] @@ -184,6 +158,21 @@ def test_image_context_validate_image_format(): ImageContext(column_name="image_base64", data_type=ModalityDataType.BASE64) +def test_image_context_validates_data_uri_media_type_against_image_format() -> None: + context = ImageContext(column_name="image_base64", image_format=ImageFormat.PNG) + + with pytest.raises(ValueError, match="image_format 'png' does not match data URI media type 'image/jpeg'"): + context.get_contexts({"image_base64": "data:image/jpeg;base64,image1base64"}) + + +def test_image_context_accepts_jpg_format_for_jpeg_data_uri() -> None: + context = ImageContext(column_name="image_base64", image_format=ImageFormat.JPG) + + assert context.get_contexts({"image_base64": "data:image/jpeg;base64,image1base64"}) == [ + get_media_base64_context(Modality.IMAGE.value, "image/jpeg", "image1base64") + ] + + def test_image_context_no_data_type_passes_validation() -> None: """Test that ImageContext without data_type passes validation.""" context = ImageContext(column_name="image_col") @@ -195,7 +184,7 @@ def test_image_context_auto_detect_url() -> None: """Test auto-detection with URL value (no data_type).""" context = ImageContext(column_name="image_col") result = 
context.get_contexts({"image_col": "https://example.com/image.png"}) - assert result == [{"type": "image_url", "image_url": {"url": "https://example.com/image.png"}}] + assert result == [get_media_url_context(Modality.IMAGE.value, "https://example.com/image.png")] def test_image_context_auto_detect_base64(minimal_png_base64: str) -> None: @@ -204,8 +193,7 @@ def test_image_context_auto_detect_base64(minimal_png_base64: str) -> None: context = ImageContext(column_name="image_col") result = context.get_contexts({"image_col": png_base64}) assert len(result) == 1 - assert result[0]["type"] == "image_url" - assert f"base64,{png_base64}" in result[0]["image_url"]["url"] + assert result[0] == get_media_base64_context(Modality.IMAGE.value, "image/png", png_base64) def test_image_context_auto_detect_file_path_resolved(tmp_path: Path) -> None: @@ -222,9 +210,8 @@ def test_image_context_auto_detect_file_path_resolved(tmp_path: Path) -> None: base_path=str(tmp_path), ) assert len(result) == 1 - assert result[0]["type"] == "image_url" expected_base64 = base64.b64encode(png_bytes).decode() - assert f"base64,{expected_base64}" in result[0]["image_url"]["url"] + assert result[0] == get_media_base64_context(Modality.IMAGE.value, "image/png", expected_base64) def test_image_context_auto_detect_file_path_not_resolved_without_base_path() -> None: @@ -244,6 +231,184 @@ def test_image_context_auto_detect_file_path_not_exists(tmp_path: Path) -> None: ) +def test_audio_context_get_contexts_single_string() -> None: + audio_context = AudioContext( + column_name="audio_base64", data_type=ModalityDataType.BASE64, audio_format=AudioFormat.MP3 + ) + assert audio_context.get_contexts({"audio_base64": "audio1base64"}) == [ + get_media_base64_context(Modality.AUDIO.value, "audio/mpeg", "audio1base64") + ] + + audio_context = AudioContext(column_name="audio_url", data_type=ModalityDataType.URL) + assert audio_context.get_contexts({"audio_url": "https://example.com/audio.mp3"}) == [ + 
get_media_url_context(Modality.AUDIO.value, "https://example.com/audio.mp3") + ] + assert audio_context.get_contexts({"audio_url": "recordings/speech.mp3"}) == [ + get_media_url_context(Modality.AUDIO.value, "recordings/speech.mp3") + ] + assert audio_context.get_contexts({"audio_url": "file:///data/recordings/speech.mp3"}) == [ + get_media_url_context(Modality.AUDIO.value, "file:///data/recordings/speech.mp3") + ] + + +def test_audio_context_get_contexts_list_json_and_numpy() -> None: + audio_context = AudioContext( + column_name="audio_base64", data_type=ModalityDataType.BASE64, audio_format=AudioFormat.WAV + ) + assert audio_context.get_contexts({"audio_base64": ["audio1", "audio2"]}) == [ + get_media_base64_context(Modality.AUDIO.value, "audio/wav", "audio1"), + get_media_base64_context(Modality.AUDIO.value, "audio/wav", "audio2"), + ] + + json_str = json.dumps(["https://example.com/audio1.mp3", "https://example.com/audio2.mp3"]) + url_context = AudioContext(column_name="audio_url", data_type=ModalityDataType.URL) + assert url_context.get_contexts({"audio_url": json_str}) == [ + get_media_url_context(Modality.AUDIO.value, "https://example.com/audio1.mp3"), + get_media_url_context(Modality.AUDIO.value, "https://example.com/audio2.mp3"), + ] + + numpy_array = lazy.np.array(["https://example.com/audio1.mp3", "https://example.com/audio2.mp3"]) + assert url_context.get_contexts({"audio_url": numpy_array}) == [ + get_media_url_context(Modality.AUDIO.value, "https://example.com/audio1.mp3"), + get_media_url_context(Modality.AUDIO.value, "https://example.com/audio2.mp3"), + ] + + +def test_audio_context_auto_detect_url_and_data_uri() -> None: + assert AudioContext(column_name="audio_col").get_contexts({"audio_col": "https://example.com/audio.mp3"}) == [ + get_media_url_context(Modality.AUDIO.value, "https://example.com/audio.mp3") + ] + + assert AudioContext(column_name="audio_col").get_contexts({"audio_col": "https://example.com/download?id=123"}) == [ + 
get_media_url_context(Modality.AUDIO.value, "https://example.com/download?id=123") + ] + + assert AudioContext(column_name="audio_col").get_contexts({"audio_col": "data:audio/mpeg;base64,audio1base64"}) == [ + get_media_base64_context(Modality.AUDIO.value, "audio/mpeg", "audio1base64") + ] + + +@pytest.mark.parametrize("audio_path", ["recordings/speech.wav", "file:///data/recordings/speech.mp3"]) +def test_audio_context_auto_detect_local_path_rejected(audio_path: str) -> None: + with pytest.raises(ValueError, match="audio context values that look like local paths must use data_type=url"): + AudioContext(column_name="audio_col").get_contexts({"audio_col": audio_path}) + + +def test_audio_context_validate_audio_format() -> None: + with pytest.raises(ValueError, match="audio_format is required when data_type is base64"): + AudioContext(column_name="audio_base64", data_type=ModalityDataType.BASE64) + + with pytest.raises(ValueError, match="audio URL context values must be HTTP"): + AudioContext(column_name="audio_url", data_type=ModalityDataType.URL).get_contexts({"audio_url": "not-a-url"}) + + with pytest.raises(ValueError, match="audio_format is required for base64 audio context values"): + AudioContext(column_name="audio_base64").get_contexts({"audio_base64": "audio1base64"}) + + with pytest.raises(ValueError, match="does not match data URI media type"): + AudioContext(column_name="audio_base64", audio_format=AudioFormat.WAV).get_contexts( + {"audio_base64": "data:audio/mpeg;base64,audio1base64"} + ) + + with pytest.raises(ValueError, match="audio context values that look like local paths must use data_type=url"): + AudioContext(column_name="audio_base64", audio_format=AudioFormat.MP3).get_contexts( + {"audio_base64": "screen_recording.mp3"} + ) + + with pytest.raises(ValueError, match="audio context values that look like local paths must use data_type=url"): + AudioContext( + column_name="audio_base64", data_type=ModalityDataType.BASE64, 
audio_format=AudioFormat.MP3 + ).get_contexts({"audio_base64": "screen_recording.mp3"}) + + +def test_video_context_get_contexts_single_string() -> None: + video_context = VideoContext( + column_name="video_base64", data_type=ModalityDataType.BASE64, video_format=VideoFormat.MP4 + ) + assert video_context.get_contexts({"video_base64": "video1base64"}) == [ + get_media_base64_context(Modality.VIDEO.value, "video/mp4", "video1base64") + ] + + video_context = VideoContext(column_name="video_url", data_type=ModalityDataType.URL) + assert video_context.get_contexts({"video_url": "https://example.com/video.mp4"}) == [ + get_media_url_context(Modality.VIDEO.value, "https://example.com/video.mp4") + ] + assert video_context.get_contexts({"video_url": "clips/screen_recording.mp4"}) == [ + get_media_url_context(Modality.VIDEO.value, "clips/screen_recording.mp4") + ] + assert video_context.get_contexts({"video_url": "file:///data/clips/screen_recording.mp4"}) == [ + get_media_url_context(Modality.VIDEO.value, "file:///data/clips/screen_recording.mp4") + ] + + +def test_video_context_get_contexts_list_json_and_numpy() -> None: + video_context = VideoContext( + column_name="video_base64", data_type=ModalityDataType.BASE64, video_format=VideoFormat.WEBM + ) + assert video_context.get_contexts({"video_base64": ["video1", "video2"]}) == [ + get_media_base64_context(Modality.VIDEO.value, "video/webm", "video1"), + get_media_base64_context(Modality.VIDEO.value, "video/webm", "video2"), + ] + + json_str = json.dumps(["https://example.com/video1.mp4", "https://example.com/video2.mp4"]) + url_context = VideoContext(column_name="video_url", data_type=ModalityDataType.URL) + assert url_context.get_contexts({"video_url": json_str}) == [ + get_media_url_context(Modality.VIDEO.value, "https://example.com/video1.mp4"), + get_media_url_context(Modality.VIDEO.value, "https://example.com/video2.mp4"), + ] + + numpy_array = lazy.np.array(["https://example.com/video1.mp4", 
"https://example.com/video2.mp4"]) + assert url_context.get_contexts({"video_url": numpy_array}) == [ + get_media_url_context(Modality.VIDEO.value, "https://example.com/video1.mp4"), + get_media_url_context(Modality.VIDEO.value, "https://example.com/video2.mp4"), + ] + + +def test_video_context_auto_detect_url_and_data_uri() -> None: + assert VideoContext(column_name="video_col").get_contexts({"video_col": "https://example.com/video.mp4"}) == [ + get_media_url_context(Modality.VIDEO.value, "https://example.com/video.mp4") + ] + + assert VideoContext(column_name="video_col").get_contexts({"video_col": "https://example.com/download?id=123"}) == [ + get_media_url_context(Modality.VIDEO.value, "https://example.com/download?id=123") + ] + + assert VideoContext(column_name="video_col").get_contexts({"video_col": "data:video/mp4;base64,video1base64"}) == [ + get_media_base64_context(Modality.VIDEO.value, "video/mp4", "video1base64") + ] + + +@pytest.mark.parametrize("video_path", ["clips/screen_recording.webm", "file:///data/clips/screen_recording.mp4"]) +def test_video_context_auto_detect_local_path_rejected(video_path: str) -> None: + with pytest.raises(ValueError, match="video context values that look like local paths must use data_type=url"): + VideoContext(column_name="video_col").get_contexts({"video_col": video_path}) + + +def test_video_context_validate_video_format() -> None: + with pytest.raises(ValueError, match="video_format is required when data_type is base64"): + VideoContext(column_name="video_base64", data_type=ModalityDataType.BASE64) + + with pytest.raises(ValueError, match="video URL context values must be HTTP"): + VideoContext(column_name="video_url", data_type=ModalityDataType.URL).get_contexts({"video_url": "not-a-url"}) + + with pytest.raises(ValueError, match="video_format is required for base64 video context values"): + VideoContext(column_name="video_base64").get_contexts({"video_base64": "video1base64"}) + + with pytest.raises(ValueError, 
match="does not match data URI media type"): + VideoContext(column_name="video_base64", video_format=VideoFormat.WEBM).get_contexts( + {"video_base64": "data:video/mp4;base64,video1base64"} + ) + + with pytest.raises(ValueError, match="video context values that look like local paths must use data_type=url"): + VideoContext(column_name="video_base64", video_format=VideoFormat.MP4).get_contexts( + {"video_base64": "screen_recording.mp4"} + ) + + with pytest.raises(ValueError, match="video context values that look like local paths must use data_type=url"): + VideoContext( + column_name="video_base64", data_type=ModalityDataType.BASE64, video_format=VideoFormat.MP4 + ).get_contexts({"video_base64": "screen_recording.mp4"}) + + def test_inference_parameters_default_construction(): empty_inference_parameters = ChatCompletionInferenceParams() assert empty_inference_parameters.generate_kwargs == {} diff --git a/packages/data-designer-config/tests/config/utils/test_image_helpers.py b/packages/data-designer-config/tests/config/utils/test_media_helpers.py similarity index 71% rename from packages/data-designer-config/tests/config/utils/test_image_helpers.py rename to packages/data-designer-config/tests/config/utils/test_media_helpers.py index e425582a2..f1e898d66 100644 --- a/packages/data-designer-config/tests/config/utils/test_image_helpers.py +++ b/packages/data-designer-config/tests/config/utils/test_media_helpers.py @@ -1,31 +1,101 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: Apache-2.0 from __future__ import annotations import base64 +import json +from collections.abc import Callable from pathlib import Path from unittest.mock import Mock, patch import pytest import data_designer.lazy_heavy_imports as lazy -from data_designer.config.models import ImageFormat -from data_designer.config.utils.image_helpers import ( +from data_designer.config.utils.media_helpers import ( + AudioFormat, + ImageFormat, + VideoFormat, + audio_format_from_mime_type, + audio_mime_type, decode_base64_image, detect_image_format, extract_base64_from_data_uri, + get_media_base64_context, + get_media_context, + get_media_url_context, + image_format_from_mime_type, + is_audio_path, is_base64_image, is_image_diffusion_model, is_image_path, is_image_url, + is_media_url, + is_video_path, load_image_path_to_base64, + normalize_media_context_values, + parse_base64_data_uri, validate_image, + video_format_from_mime_type, + video_mime_type, ) -# --------------------------------------------------------------------------- -# extract_base64_from_data_uri -# --------------------------------------------------------------------------- + +def test_media_context_builders() -> None: + assert get_media_context("image", {"type": "url", "url": "https://example.com/image.png"}) == { + "type": "image", + "source": {"type": "url", "url": "https://example.com/image.png"}, + } + assert get_media_url_context("audio", "https://example.com/audio.mp3") == { + "type": "audio", + "source": {"type": "url", "url": "https://example.com/audio.mp3"}, + } + assert get_media_base64_context("video", "video/mp4", "abc123") == { + "type": "video", + "source": {"type": "base64", "media_type": "video/mp4", "data": "abc123"}, + } + + +def test_normalize_media_context_values() -> None: + assert normalize_media_context_values("single") == ["single"] + assert normalize_media_context_values(["one", "two"]) == ["one", "two"] + assert normalize_media_context_values(json.dumps(["one", 
"two"])) == ["one", "two"] + assert normalize_media_context_values(json.dumps({"nested": "value"})) == [json.dumps({"nested": "value"})] + assert normalize_media_context_values(lazy.np.array(["one", "two"])) == ["one", "two"] + + +def test_parse_base64_data_uri() -> None: + assert parse_base64_data_uri("data:audio/mpeg;base64,abc123") == ("audio/mpeg", "abc123") + assert parse_base64_data_uri("abc123") is None + + +def test_media_url_detection() -> None: + assert is_media_url("https://example.com/download?id=123") is True + assert is_media_url("http://example.com/media") is True + assert is_media_url("ftp://example.com/media") is False + assert is_media_url(123) is False # type: ignore[arg-type] + + +def test_local_media_path_detection() -> None: + assert is_audio_path("screen_recording.mp3") is True + assert is_audio_path("nested/screen_recording.wav") is True + assert is_audio_path("https://example.com/audio.mp3") is False + assert is_video_path("screen_recording.mp4") is True + assert is_video_path("nested/screen_recording.webm") is True + assert is_video_path("https://example.com/video.mp4") is False + + +def test_media_format_mime_helpers() -> None: + assert ImageFormat.PNG.value == "png" + assert image_format_from_mime_type("image/png") == ImageFormat.PNG + assert image_format_from_mime_type("image/jpeg") == ImageFormat.JPG + assert audio_mime_type(AudioFormat.MP3) == "audio/mpeg" + assert audio_format_from_mime_type("audio/mpeg") == AudioFormat.MP3 + assert audio_format_from_mime_type("audio/mp3") == AudioFormat.MP3 + assert audio_format_from_mime_type("audio/x-wav") == AudioFormat.WAV + assert video_mime_type(VideoFormat.MP4) == "video/mp4" + assert video_format_from_mime_type("video/mp4") == VideoFormat.MP4 + assert video_format_from_mime_type("VIDEO/MP4") == VideoFormat.MP4 def test_extract_base64_from_data_uri_with_prefix() -> None: @@ -45,11 +115,6 @@ def test_extract_base64_invalid_data_uri_raises_error() -> None: 
extract_base64_from_data_uri("data:image/png;base64") -# --------------------------------------------------------------------------- -# decode_base64_image -# --------------------------------------------------------------------------- - - def test_decode_base64_image_valid() -> None: png_bytes = b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01" base64_data = base64.b64encode(png_bytes).decode() @@ -70,11 +135,6 @@ def test_decode_base64_image_invalid_raises_error() -> None: decode_base64_image("not-valid-base64!!!") -# --------------------------------------------------------------------------- -# detect_image_format (magic bytes) -# --------------------------------------------------------------------------- - - @pytest.mark.parametrize( "header_bytes,expected_format", [ @@ -112,11 +172,6 @@ def test_detect_image_format_unknown_raises_error() -> None: detect_image_format(unknown_bytes) -# --------------------------------------------------------------------------- -# is_image_path -# --------------------------------------------------------------------------- - - @pytest.mark.parametrize( "value,expected", [ @@ -134,11 +189,6 @@ def test_is_image_path(value: str, expected: bool) -> None: assert is_image_path(value) is expected -# --------------------------------------------------------------------------- -# is_image_url -# --------------------------------------------------------------------------- - - @pytest.mark.parametrize( "value,expected", [ @@ -154,11 +204,6 @@ def test_is_image_url(value: str, expected: bool) -> None: assert is_image_url(value) is expected -# --------------------------------------------------------------------------- -# is_base64_image -# --------------------------------------------------------------------------- - - def test_is_base64_image_data_uri() -> None: assert is_base64_image("data:image/png;base64,iVBORw0KGgo") is True @@ -177,26 +222,16 @@ def test_is_base64_image_invalid_base64_decode() -> None: assert 
is_base64_image(invalid_base64) is False -# --------------------------------------------------------------------------- -# Non-string guard (is_image_path, is_base64_image, is_image_url) -# --------------------------------------------------------------------------- - - @pytest.mark.parametrize( "func", [is_image_path, is_base64_image, is_image_url], ids=["is_image_path", "is_base64_image", "is_image_url"], ) @pytest.mark.parametrize("value", [123, None, []], ids=["int", "none", "list"]) -def test_non_string_input_returns_false(func: object, value: object) -> None: +def test_image_media_helpers_non_string_input_returns_false(func: Callable[..., bool], value: object) -> None: assert func(value) is False -# --------------------------------------------------------------------------- -# is_image_diffusion_model -# --------------------------------------------------------------------------- - - @pytest.mark.parametrize( "model_name,expected", [ @@ -232,11 +267,6 @@ def test_is_image_diffusion_model(model_name: str, expected: bool) -> None: assert is_image_diffusion_model(model_name) is expected -# --------------------------------------------------------------------------- -# validate_image -# --------------------------------------------------------------------------- - - def test_validate_image_valid_png(tmp_path: Path, sample_png_bytes: bytes) -> None: image_path = tmp_path / "test.png" image_path.write_bytes(sample_png_bytes) @@ -256,11 +286,6 @@ def test_validate_image_nonexistent_raises_error(tmp_path: Path) -> None: validate_image(image_path) -# --------------------------------------------------------------------------- -# load_image_path_to_base64 -# --------------------------------------------------------------------------- - - def test_load_image_path_to_base64_absolute_path(tmp_path: Path) -> None: img = lazy.Image.new("RGB", (1, 1), color="blue") image_path = tmp_path / "test.png" diff --git 
a/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/base.py b/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/base.py index ba432ce2c..fd002f9a6 100644 --- a/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/base.py +++ b/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/base.py @@ -279,9 +279,8 @@ def inference_parameters(self) -> BaseInferenceParams: def _build_multi_modal_context(self, record: dict) -> list[dict[str, Any]] | None: """Build multi-modal context from the config's multi_modal_context list. - Passes base_path to get_contexts() so that generated image file paths - (stored under base_dataset_path in create mode) can be resolved to base64 - before being sent to the model endpoint. + Passes base_path to get_contexts() so context types that support + artifact-relative resolution can use the dataset artifact directory. Args: record: The deserialized record containing column values. 
diff --git a/packages/data-designer-engine/src/data_designer/engine/mcp/io.py b/packages/data-designer-engine/src/data_designer/engine/mcp/io.py index 807e4b890..d9b2b8efd 100644 --- a/packages/data-designer-engine/src/data_designer/engine/mcp/io.py +++ b/packages/data-designer-engine/src/data_designer/engine/mcp/io.py @@ -43,7 +43,7 @@ from mcp.client.streamable_http import streamablehttp_client from data_designer.config.mcp import LocalStdioMCPProvider, MCPProvider, MCPProviderT -from data_designer.config.utils.image_helpers import ( +from data_designer.config.utils.media_helpers import ( decode_base64_image, detect_image_format, extract_base64_from_data_uri, diff --git a/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/anthropic.py b/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/anthropic.py index 204b46677..1f4d5081d 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/anthropic.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/anthropic.py @@ -6,6 +6,7 @@ from typing import Any from data_designer.engine.models.clients.adapters.anthropic_translation import ( + UnsupportedAnthropicMediaBlockError, build_anthropic_payload, parse_anthropic_response, ) @@ -106,6 +107,16 @@ async def agenerate_image(self, request: ImageGenerationRequest) -> ImageGenerat def _build_payload_or_raise(self, request: ChatCompletionRequest) -> dict[str, Any]: try: return build_anthropic_payload(request) + except UnsupportedAnthropicMediaBlockError as exc: + raise ProviderError.unsupported_capability( + provider_name=self.provider_name, + model_name=request.model, + operation=f"{exc.modality}-context", + message=( + f"Provider {self.provider_name!r} does not support {exc.modality} context " + f"for model {request.model!r}." 
+ ), + ) from exc except ValueError as exc: raise ProviderError( kind=ProviderErrorKind.BAD_REQUEST, diff --git a/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/anthropic_translation.py b/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/anthropic_translation.py index 21a959f40..5ba186add 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/anthropic_translation.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/anthropic_translation.py @@ -4,9 +4,14 @@ from __future__ import annotations import json -import re from typing import Any +from data_designer.config.models import Modality +from data_designer.config.utils.media_helpers import ( + get_media_base64_context, + get_media_url_context, + parse_base64_data_uri, +) from data_designer.engine.models.clients.parsing import extract_usage, fill_reasoning_token_count_from_content from data_designer.engine.models.clients.types import ( AssistantMessage, @@ -17,7 +22,24 @@ ) _DEFAULT_MAX_TOKENS = 4096 -_DATA_URI_RE = re.compile(r"^data:(?P<media_type>[^;]+);base64,(?P<data>.+)$") +# Include canonical blocks from *Context.get_contexts and provider-specific +# blocks that users may author directly in templates or tool-result content.
+_UNSUPPORTED_MEDIA_BLOCK_MODALITIES: dict[str, str] = { + "audio": "audio", + "audio_url": "audio", + "input_audio": "audio", + "video": "video", + "video_url": "video", + "input_video": "video", +} + + +class UnsupportedAnthropicMediaBlockError(ValueError): + """Raised when a canonical media block cannot be translated to Anthropic.""" + + def __init__(self, modality: str) -> None: + self.modality = modality + super().__init__(f"Anthropic adapter does not support {modality} context blocks.") def merge_system_parts(parts: list[str | list[dict[str, Any]]]) -> str | list[dict[str, Any]]: @@ -196,6 +218,12 @@ def translate_content_blocks(content: Any) -> list[dict[str, Any]]: if isinstance(block, dict) and block.get("type") == "image_url": translated.append(translate_image_url_block(block)) continue + if isinstance(block, dict) and block.get("type") == "image": + translated.append(translate_canonical_image_block(block)) + continue + block_type = block.get("type") if isinstance(block, dict) else None + if isinstance(block_type, str) and block_type in _UNSUPPORTED_MEDIA_BLOCK_MODALITIES: + raise UnsupportedAnthropicMediaBlockError(_UNSUPPORTED_MEDIA_BLOCK_MODALITIES[block_type]) # Anthropic rejects empty text blocks — drop them.
if isinstance(block, dict) and block.get("type") == "text" and not block.get("text"): continue @@ -326,18 +354,27 @@ def translate_image_url_block(block: dict[str, Any]) -> dict[str, Any]: url = image_url.get("url", "") - match = _DATA_URI_RE.match(url) - if match: - return { - "type": "image", - "source": { - "type": "base64", - "media_type": match.group("media_type"), - "data": match.group("data"), - }, - } + parsed = parse_base64_data_uri(url) + if parsed is not None: + media_type, data = parsed + return get_media_base64_context(Modality.IMAGE.value, media_type, data) - return { - "type": "image", - "source": {"type": "url", "url": url}, - } + return get_media_url_context(Modality.IMAGE.value, url) + + +def translate_canonical_image_block(block: dict[str, Any]) -> dict[str, Any]: + source = block.get("source") + if not isinstance(source, dict): + raise ValueError(f"Canonical image block must include a source object, got: {block!r}") + + source_type = source.get("type") + if source_type == "url": + return get_media_url_context(Modality.IMAGE.value, source.get("url", "")) + if source_type == "base64": + media_type = source.get("media_type") + data = source.get("data") + if not isinstance(media_type, str) or not isinstance(data, str): + raise ValueError(f"Canonical image base64 source must include media_type and data, got: {source!r}") + return get_media_base64_context(Modality.IMAGE.value, media_type, data) + + raise ValueError(f"Unsupported canonical image source type {source_type!r}") diff --git a/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/openai_compatible.py b/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/openai_compatible.py index 44ab1f1d5..e9f9228a1 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/openai_compatible.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/openai_compatible.py @@ -5,9 +5,11 @@ from 
typing import Any +from data_designer.config.utils.media_helpers import audio_format_from_mime_type from data_designer.engine.models.clients.adapters.http_model_client import ( HttpModelClient, ) +from data_designer.engine.models.clients.errors import ProviderError, ProviderErrorKind from data_designer.engine.models.clients.parsing import ( aextract_images_from_chat_response, aextract_images_from_image_response, @@ -61,13 +63,23 @@ def supports_image_generation(self) -> bool: def completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse: transport = TransportKwargs.from_request(request) - payload = {"model": request.model, "messages": request.messages, **transport.body} + messages = translate_openai_compatible_messages( + request.messages, + provider_name=self.provider_name, + model_name=request.model, + ) + payload = {"model": request.model, "messages": messages, **transport.body} response_json = self._post_sync(self._ROUTE_CHAT, payload, transport.headers, request.model, transport.timeout) return parse_chat_completion_response(response_json) async def acompletion(self, request: ChatCompletionRequest) -> ChatCompletionResponse: transport = TransportKwargs.from_request(request) - payload = {"model": request.model, "messages": request.messages, **transport.body} + messages = translate_openai_compatible_messages( + request.messages, + provider_name=self.provider_name, + model_name=request.model, + ) + payload = {"model": request.model, "messages": messages, **transport.body} response_json = await self._apost( self._ROUTE_CHAT, payload, transport.headers, request.model, transport.timeout ) @@ -101,7 +113,12 @@ def generate_image(self, request: ImageGenerationRequest) -> ImageGenerationResp transport = TransportKwargs.from_request(request, exclude=self._IMAGE_EXCLUDE) if request.messages is not None: route = self._ROUTE_CHAT - payload = {"model": request.model, "messages": request.messages, **transport.body} + messages = 
translate_openai_compatible_messages( + request.messages, + provider_name=self.provider_name, + model_name=request.model, + ) + payload = {"model": request.model, "messages": messages, **transport.body} else: route = self._ROUTE_IMAGE payload = {"model": request.model, "prompt": request.prompt, **transport.body} @@ -112,7 +129,12 @@ async def agenerate_image(self, request: ImageGenerationRequest) -> ImageGenerat transport = TransportKwargs.from_request(request, exclude=self._IMAGE_EXCLUDE) if request.messages is not None: route = self._ROUTE_CHAT - payload = {"model": request.model, "messages": request.messages, **transport.body} + messages = translate_openai_compatible_messages( + request.messages, + provider_name=self.provider_name, + model_name=request.model, + ) + payload = {"model": request.model, "messages": messages, **transport.body} else: route = self._ROUTE_IMAGE payload = {"model": request.model, "prompt": request.prompt, **transport.body} @@ -133,6 +155,114 @@ def _build_headers(self, extra_headers: dict[str, str]) -> dict[str, str]: # --------------------------------------------------------------------------- +def translate_openai_compatible_messages( + messages: list[dict[str, Any]], + *, + provider_name: str, + model_name: str, +) -> list[dict[str, Any]]: + """Translate canonical media blocks to OpenAI-compatible content blocks.""" + translated_messages: list[dict[str, Any]] = [] + for message in messages: + translated = dict(message) + if "content" in translated: + try: + translated["content"] = translate_openai_compatible_content_blocks(translated["content"]) + except ValueError as exc: + raise ProviderError( + kind=ProviderErrorKind.BAD_REQUEST, + message=str(exc), + provider_name=provider_name, + model_name=model_name, + cause=exc, + ) from exc + translated_messages.append(translated) + return translated_messages + + +def translate_openai_compatible_content_blocks(content: Any) -> Any: + if not isinstance(content, list): + return content + + 
return [translate_openai_compatible_content_block(block) for block in content] + + +def translate_openai_compatible_content_block(block: Any) -> Any: + if not isinstance(block, dict): + return block + + block_type = block.get("type") + if not isinstance(block_type, str): + return block + if block_type in {"audio_url", "image_url", "input_audio", "input_video", "text", "video_url"}: + return block + if block_type == "image": + return _translate_canonical_image_block(block) + if block_type == "audio": + return _translate_canonical_audio_block(block) + if block_type == "video": + return _translate_canonical_video_block(block) + return block + + +def _translate_canonical_image_block(block: dict[str, Any]) -> dict[str, Any]: + source = _get_media_source(block, modality="image") + source_type = source.get("type") + if source_type == "url": + return {"type": "image_url", "image_url": {"url": source.get("url", "")}} + if source_type == "base64": + media_type = source.get("media_type") + data = source.get("data") + if not isinstance(media_type, str) or not isinstance(data, str): + raise ValueError(f"Canonical image base64 source must include media_type and data, got: {source!r}") + return {"type": "image_url", "image_url": {"url": f"data:{media_type};base64,{data}"}} + raise ValueError(f"Unsupported canonical image source type {source_type!r}") + + +def _translate_canonical_audio_block(block: dict[str, Any]) -> dict[str, Any]: + source = _get_media_source(block, modality="audio") + source_type = source.get("type") + if source_type == "url": + # ``audio_url`` is an OpenAI-compatible extension used by providers such as vLLM/NVIDIA, + # not by OpenAI's hosted Chat Completions route. 
+ return {"type": "audio_url", "audio_url": {"url": source.get("url", "")}} + if source_type == "base64": + media_type = source.get("media_type") + data = source.get("data") + if not isinstance(media_type, str) or not isinstance(data, str): + raise ValueError(f"Canonical audio base64 source must include media_type and data, got: {source!r}") + audio_format = audio_format_from_mime_type(media_type) + if audio_format is None: + raise ValueError(f"Unsupported canonical audio media type {media_type!r}") + return {"type": "input_audio", "input_audio": {"data": data, "format": audio_format.value}} + raise ValueError(f"Unsupported canonical audio source type {source_type!r}") + + +def _translate_canonical_video_block(block: dict[str, Any]) -> dict[str, Any]: + source = _get_media_source(block, modality="video") + source_type = source.get("type") + if source_type == "url": + # ``video_url`` is an OpenAI-compatible extension used by providers such as vLLM/NVIDIA, + # not by OpenAI's hosted Chat Completions route. + return {"type": "video_url", "video_url": {"url": source.get("url", "")}} + if source_type == "base64": + media_type = source.get("media_type") + data = source.get("data") + if not isinstance(media_type, str) or not isinstance(data, str): + raise ValueError(f"Canonical video base64 source must include media_type and data, got: {source!r}") + # No widely supported ``input_video`` block exists; capable OpenAI-compatible + # providers may accept a data URI in ``video_url``. 
+ return {"type": "video_url", "video_url": {"url": f"data:{media_type};base64,{data}"}} + raise ValueError(f"Unsupported canonical video source type {source_type!r}") + + +def _get_media_source(block: dict[str, Any], *, modality: str) -> dict[str, Any]: + source = block.get("source") + if not isinstance(source, dict): + raise ValueError(f"Canonical {modality} block must include a source object, got: {block!r}") + return source + + def _parse_embedding_json(response_json: dict[str, Any]) -> EmbeddingResponse: data = response_json.get("data") or [] vectors = [extract_embedding_vector(item) for item in data] diff --git a/packages/data-designer-engine/src/data_designer/engine/models/clients/parsing.py b/packages/data-designer-engine/src/data_designer/engine/models/clients/parsing.py index 02a3223a2..05509799c 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/clients/parsing.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/clients/parsing.py @@ -11,7 +11,7 @@ from dataclasses import replace from typing import Any -from data_designer.config.utils.image_helpers import ( +from data_designer.config.utils.media_helpers import ( aload_image_url_to_base64, extract_base64_from_data_uri, is_base64_image, diff --git a/packages/data-designer-engine/src/data_designer/engine/models/facade.py b/packages/data-designer-engine/src/data_designer/engine/models/facade.py index 81a935282..2c0a7a9ab 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/facade.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/facade.py @@ -15,7 +15,7 @@ OPENROUTER_ATTRIBUTION_HEADERS, OPENROUTER_PROVIDER_NAME, ) -from data_designer.config.utils.image_helpers import is_image_diffusion_model +from data_designer.config.utils.media_helpers import is_image_diffusion_model from data_designer.engine.mcp.errors import MCPConfigurationError from data_designer.engine.model_provider import ModelProviderRegistry from 
data_designer.engine.models.clients.types import ( @@ -309,6 +309,7 @@ def generate( prompt. parser (func(str) -> Any): A function applied to the LLM response which processes an LLM response into some output object. Default: identity function. + multi_modal_context: Optional list of image, audio, or video context blocks. tool_alias (str | None): Optional tool configuration alias. When provided, the model may call permitted tools from the configured MCP providers. The alias must reference a ToolConfig registered in the MCPRegistry. @@ -627,7 +628,7 @@ def generate_image( Args: prompt: The prompt for image generation - multi_modal_context: Optional list of image contexts for multi-modal generation. + multi_modal_context: Optional list of image, audio, or video contexts for multi-modal generation. Only used with autoregressive models via chat completions API. skip_usage_tracking: Whether to skip usage tracking **kwargs: Additional arguments to pass to the model @@ -686,7 +687,7 @@ async def agenerate_image( Args: prompt: The prompt for image generation - multi_modal_context: Optional list of image contexts for multi-modal generation. + multi_modal_context: Optional list of image, audio, or video contexts for multi-modal generation. Only used with autoregressive models via chat completions API. skip_usage_tracking: Whether to skip usage tracking **kwargs: Additional arguments to pass to the model diff --git a/packages/data-designer-engine/src/data_designer/engine/models/utils.py b/packages/data-designer-engine/src/data_designer/engine/models/utils.py index f7183e83d..a4b51e4bd 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/utils.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/utils.py @@ -18,7 +18,7 @@ class ChatMessage: Attributes: role: The role of the message sender. One of 'user', 'assistant', 'system', or 'tool'. content: The message content. 
Can be a string or a list of content blocks - for multimodal messages (e.g., text + images). + for multimodal messages (e.g., text + image/audio/video context). reasoning_content: Optional reasoning/thinking content from the assistant, typically from extended thinking or chain-of-thought models. tool_calls: Optional list of tool calls requested by the assistant. diff --git a/packages/data-designer-engine/src/data_designer/engine/storage/media_storage.py b/packages/data-designer-engine/src/data_designer/engine/storage/media_storage.py index 1c887c808..6bd6e3dd9 100644 --- a/packages/data-designer-engine/src/data_designer/engine/storage/media_storage.py +++ b/packages/data-designer-engine/src/data_designer/engine/storage/media_storage.py @@ -6,7 +6,7 @@ import uuid from pathlib import Path -from data_designer.config.utils.image_helpers import decode_base64_image, detect_image_format, validate_image +from data_designer.config.utils.media_helpers import decode_base64_image, detect_image_format, validate_image from data_designer.config.utils.type_helpers import StrEnum IMAGES_SUBDIR = "images" diff --git a/packages/data-designer-engine/tests/engine/column_generators/generators/test_image.py b/packages/data-designer-engine/tests/engine/column_generators/generators/test_image.py index 6f5e8799e..bc5dc0c51 100644 --- a/packages/data-designer-engine/tests/engine/column_generators/generators/test_image.py +++ b/packages/data-designer-engine/tests/engine/column_generators/generators/test_image.py @@ -1,13 +1,23 @@ # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + import base64 from unittest.mock import Mock, patch import pytest from data_designer.config.column_configs import ImageColumnConfig -from data_designer.config.models import ImageContext, ImageFormat, ModalityDataType +from data_designer.config.models import ( + AudioContext, + ImageContext, + ImageFormat, + Modality, + ModalityDataType, + VideoContext, +) +from data_designer.config.utils.media_helpers import get_media_base64_context, get_media_url_context from data_designer.engine.column_generators.generators.base import GenerationStrategy from data_designer.engine.column_generators.generators.image import ImageCellGenerator from data_designer.engine.processing.ginja.exceptions import UserTemplateError @@ -175,8 +185,9 @@ def test_image_cell_generator_with_multi_modal_context(stub_resource_provider): assert call_args.kwargs["prompt"] == "Generate a similar image to the reference" assert call_args.kwargs["multi_modal_context"] is not None assert len(call_args.kwargs["multi_modal_context"]) == 1 - assert call_args.kwargs["multi_modal_context"][0]["type"] == "image_url" - assert call_args.kwargs["multi_modal_context"][0]["image_url"] == {"url": "https://example.com/image.png"} + assert call_args.kwargs["multi_modal_context"][0] == get_media_url_context( + Modality.IMAGE.value, "https://example.com/image.png" + ) def test_image_cell_generator_with_base64_multi_modal_context(stub_resource_provider): @@ -218,9 +229,47 @@ def test_image_cell_generator_with_base64_multi_modal_context(stub_resource_prov assert call_args.kwargs["prompt"] == "Generate a variation of this image" assert call_args.kwargs["multi_modal_context"] is not None assert len(call_args.kwargs["multi_modal_context"]) == 1 - assert call_args.kwargs["multi_modal_context"][0]["type"] == "image_url" - # Should be formatted as data URI - assert "data:image/png;base64," in call_args.kwargs["multi_modal_context"][0]["image_url"]["url"] + 
assert call_args.kwargs["multi_modal_context"][0] == get_media_base64_context( + Modality.IMAGE.value, "image/png", "iVBORw0KGgoAAAANS" + ) + + +def test_image_cell_generator_with_mixed_media_context(stub_resource_provider: Mock) -> None: + config = ImageColumnConfig( + name="test_image", + prompt="Generate a poster from this media", + model_alias="test_model", + multi_modal_context=[ + ImageContext(column_name="reference_image", data_type=ModalityDataType.URL), + AudioContext(column_name="reference_audio", data_type=ModalityDataType.URL), + VideoContext(column_name="reference_video", data_type=ModalityDataType.URL), + ], + ) + + mock_storage = Mock() + mock_storage.save_base64_image.return_value = "images/generated.png" + stub_resource_provider.artifact_storage.media_storage = mock_storage + + with patch.object( + stub_resource_provider.model_registry.get_model.return_value, + "generate_image", + return_value=["base64_generated_image"], + ) as mock_generate: + generator = ImageCellGenerator(config=config, resource_provider=stub_resource_provider) + generator.generate( + data={ + "reference_image": "https://example.com/image.png", + "reference_audio": "https://example.com/audio.mp3", + "reference_video": "https://example.com/video.mp4", + } + ) + + mock_generate.assert_called_once() + assert mock_generate.call_args.kwargs["multi_modal_context"] == [ + get_media_url_context(Modality.IMAGE.value, "https://example.com/image.png"), + get_media_url_context(Modality.AUDIO.value, "https://example.com/audio.mp3"), + get_media_url_context(Modality.VIDEO.value, "https://example.com/video.mp4"), + ] def test_image_cell_generator_build_multi_modal_context_returns_none_when_not_configured( @@ -274,10 +323,9 @@ def test_image_cell_generator_auto_resolves_generated_image_file_path(stub_resou context = call_args.kwargs["multi_modal_context"] assert context is not None assert len(context) == 1 - assert context[0]["type"] == "image_url" # Should contain base64 data, NOT the file 
path expected_b64 = base64.b64encode(png_bytes).decode() - assert expected_b64 in context[0]["image_url"]["url"] + assert context[0] == get_media_base64_context(Modality.IMAGE.value, "image/png", expected_b64) def test_image_cell_generator_auto_detect_passes_through_urls(stub_resource_provider: Mock) -> None: @@ -306,4 +354,4 @@ def test_image_cell_generator_auto_detect_passes_through_urls(stub_resource_prov mock_generate.assert_called_once() context = mock_generate.call_args.kwargs["multi_modal_context"] assert context is not None - assert context[0]["image_url"] == {"url": "https://example.com/image.png"} + assert context[0] == get_media_url_context(Modality.IMAGE.value, "https://example.com/image.png") diff --git a/packages/data-designer-engine/tests/engine/models/clients/test_anthropic.py b/packages/data-designer-engine/tests/engine/models/clients/test_anthropic.py index 1b1022d15..b0c71cc5a 100644 --- a/packages/data-designer-engine/tests/engine/models/clients/test_anthropic.py +++ b/packages/data-designer-engine/tests/engine/models/clients/test_anthropic.py @@ -555,6 +555,32 @@ def test_completion_preserves_non_image_content_blocks() -> None: assert content[1] == {"type": "custom_block", "data": "something"} +@pytest.mark.parametrize("modality", ["audio", "video"]) +def test_completion_rejects_audio_video_context_as_unsupported(modality: str) -> None: + sync_mock = make_mock_sync_client(_text_response()) + client = _make_client(sync_client=sync_mock) + + request = ChatCompletionRequest( + model=MODEL, + messages=[ + { + "role": "user", + "content": [ + {"type": modality, "source": {"type": "url", "url": "https://example.com/media"}}, + {"type": "text", "text": "Describe this."}, + ], + }, + ], + ) + + with pytest.raises(ProviderError) as exc_info: + client.completion(request) + + assert exc_info.value.kind == ProviderErrorKind.UNSUPPORTED_CAPABILITY + assert modality in exc_info.value.message + sync_mock.post.assert_not_called() + + def 
test_completion_passes_string_content_unchanged() -> None: sync_mock = make_mock_sync_client(_text_response()) client = _make_client(sync_client=sync_mock) diff --git a/packages/data-designer-engine/tests/engine/models/clients/test_anthropic_translation.py b/packages/data-designer-engine/tests/engine/models/clients/test_anthropic_translation.py index 2aad03ced..450a32b0d 100644 --- a/packages/data-designer-engine/tests/engine/models/clients/test_anthropic_translation.py +++ b/packages/data-designer-engine/tests/engine/models/clients/test_anthropic_translation.py @@ -7,8 +7,11 @@ import pytest +from data_designer.config.models import Modality +from data_designer.config.utils.media_helpers import get_media_base64_context, get_media_url_context from data_designer.engine.mcp.registry import MCPToolDefinition from data_designer.engine.models.clients.adapters.anthropic_translation import ( + UnsupportedAnthropicMediaBlockError, build_anthropic_payload, extract_system_content, merge_system_parts, @@ -67,7 +70,7 @@ def test_build_anthropic_payload_preserves_multimodal_system_content() -> None: assert payload["system"] == [ {"type": "text", "text": "Describe this image."}, - {"type": "image", "source": {"type": "url", "url": "https://example.com/reference.png"}}, + get_media_url_context(Modality.IMAGE.value, "https://example.com/reference.png"), ] @@ -168,14 +171,14 @@ def test_translate_request_messages_merges_parallel_tool_results() -> None: ], [ {"type": "text", "text": "Rule 1"}, - {"type": "image", "source": {"type": "url", "url": "https://example.com/reference.png"}}, + get_media_url_context(Modality.IMAGE.value, "https://example.com/reference.png"), {"type": "text", "text": "Rule 2"}, ], id="mixed-text-and-image-returns-blocks", ), pytest.param( [{"type": "image_url", "image_url": {"url": "https://example.com/reference.png"}}], - [{"type": "image", "source": {"type": "url", "url": "https://example.com/reference.png"}}], + [get_media_url_context(Modality.IMAGE.value, 
"https://example.com/reference.png")], id="image-only-returns-blocks", ), pytest.param( @@ -211,13 +214,13 @@ def test_extract_system_content_normalizes_supported_inputs( "Text preamble", [ {"type": "text", "text": "Rule 1"}, - {"type": "image", "source": {"type": "url", "url": "https://example.com/img.png"}}, + get_media_url_context(Modality.IMAGE.value, "https://example.com/img.png"), ], ], [ {"type": "text", "text": "Text preamble"}, {"type": "text", "text": "Rule 1"}, - {"type": "image", "source": {"type": "url", "url": "https://example.com/img.png"}}, + get_media_url_context(Modality.IMAGE.value, "https://example.com/img.png"), ], id="mixed-string-and-blocks", ), @@ -336,12 +339,46 @@ def test_translate_content_blocks_converts_images_and_preserves_other_blocks() - ) assert blocks == [ - {"type": "image", "source": {"type": "url", "url": "https://example.com/cat.png"}}, + get_media_url_context(Modality.IMAGE.value, "https://example.com/cat.png"), {"type": "text", "text": "Caption"}, {"type": "custom_block", "value": "kept"}, ] +def test_translate_content_blocks_converts_canonical_images() -> None: + blocks = translate_content_blocks( + [ + get_media_base64_context(Modality.IMAGE.value, "image/png", "iVBOR..."), + {"type": "text", "text": "Caption"}, + ] + ) + + assert blocks == [ + get_media_base64_context(Modality.IMAGE.value, "image/png", "iVBOR..."), + {"type": "text", "text": "Caption"}, + ] + + +@pytest.mark.parametrize("modality", ["audio", "video"]) +def test_translate_content_blocks_rejects_unsupported_media(modality: str) -> None: + with pytest.raises(UnsupportedAnthropicMediaBlockError, match=f"{modality} context"): + translate_content_blocks([{"type": modality, "source": {"type": "url", "url": "https://example.com/media"}}]) + + +@pytest.mark.parametrize( + ("block_type", "modality"), + [ + ("audio_url", "audio"), + ("input_audio", "audio"), + ("video_url", "video"), + ("input_video", "video"), + ], +) +def 
test_translate_content_blocks_rejects_provider_specific_media(block_type: str, modality: str) -> None: + with pytest.raises(UnsupportedAnthropicMediaBlockError, match=f"{modality} context"): + translate_content_blocks([{"type": block_type}]) + + def test_translate_content_blocks_rejects_malformed_image_url_block() -> None: with pytest.raises(TypeError, match="image_url block must contain a dict"): translate_content_blocks( @@ -357,12 +394,12 @@ def test_translate_content_blocks_rejects_malformed_image_url_block() -> None: [ pytest.param( {"type": "image_url", "image_url": {"url": "data:image/png;base64,iVBOR..."}}, - {"type": "image", "source": {"type": "base64", "media_type": "image/png", "data": "iVBOR..."}}, + get_media_base64_context(Modality.IMAGE.value, "image/png", "iVBOR..."), id="data-uri-dict", ), pytest.param( {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}}, - {"type": "image", "source": {"type": "url", "url": "https://example.com/cat.png"}}, + get_media_url_context(Modality.IMAGE.value, "https://example.com/cat.png"), id="remote-url-dict", ), ], @@ -491,7 +528,7 @@ def test_translate_tool_result_message_requires_tool_call_id(message: dict[str, {"type": "text", "text": "Caption"}, ], [ - {"type": "image", "source": {"type": "url", "url": "https://example.com/chart.png"}}, + get_media_url_context(Modality.IMAGE.value, "https://example.com/chart.png"), {"type": "text", "text": "Caption"}, ], id="mixed-blocks", @@ -503,14 +540,7 @@ def test_translate_tool_result_message_requires_tool_call_id(message: dict[str, ], [ {"type": "text", "text": "Rendered chart:"}, - { - "type": "image", - "source": { - "type": "base64", - "media_type": "image/png", - "data": "iVBORw0KGgo=", - }, - }, + get_media_base64_context(Modality.IMAGE.value, "image/png", "iVBORw0KGgo="), ], id="mixed-blocks-with-data-uri", ), diff --git a/packages/data-designer-engine/tests/engine/models/clients/test_openai_compatible.py 
b/packages/data-designer-engine/tests/engine/models/clients/test_openai_compatible.py index 3284d79b5..262d09c90 100644 --- a/packages/data-designer-engine/tests/engine/models/clients/test_openai_compatible.py +++ b/packages/data-designer-engine/tests/engine/models/clients/test_openai_compatible.py @@ -8,6 +8,8 @@ import pytest +from data_designer.config.models import Modality +from data_designer.config.utils.media_helpers import get_media_base64_context, get_media_url_context from data_designer.engine.models.clients.adapters.http_model_client import ClientConcurrencyMode from data_designer.engine.models.clients.adapters.openai_compatible import OpenAICompatibleClient from data_designer.engine.models.clients.errors import ProviderError, ProviderErrorKind @@ -317,6 +319,138 @@ def test_completion_forwards_multimodal_tool_result_content_unchanged() -> None: assert payload["messages"][0]["content"] == content +def test_completion_translates_canonical_image_blocks() -> None: + sync_mock = make_mock_sync_client(_chat_response()) + client = _make_client(sync_client=sync_mock) + + image_block = get_media_base64_context(Modality.IMAGE.value, "image/png", "iVBOR...") + request = ChatCompletionRequest( + model=MODEL, + messages=[{"role": "user", "content": [image_block, {"type": "text", "text": "What is this?"}]}], + ) + client.completion(request) + + payload = sync_mock.post.call_args.kwargs["json"] + content = payload["messages"][0]["content"] + assert content[0] == {"type": "image_url", "image_url": {"url": "data:image/png;base64,iVBOR..."}} + + +def test_completion_translates_base64_audio_blocks() -> None: + sync_mock = make_mock_sync_client(_chat_response()) + client = _make_client(sync_client=sync_mock) + + audio_block = get_media_base64_context(Modality.AUDIO.value, "audio/mpeg", "abc123") + request = ChatCompletionRequest( + model=MODEL, + messages=[{"role": "user", "content": [audio_block, {"type": "text", "text": "Transcribe this."}]}], + ) + 
client.completion(request) + + payload = sync_mock.post.call_args.kwargs["json"] + content = payload["messages"][0]["content"] + assert content[0] == {"type": "input_audio", "input_audio": {"data": "abc123", "format": "mp3"}} + + +def test_completion_rejects_unsupported_canonical_audio_media_type() -> None: + sync_mock = make_mock_sync_client(_chat_response()) + client = _make_client(sync_client=sync_mock) + + audio_block = get_media_base64_context(Modality.AUDIO.value, "audio/flac", "abc123") + request = ChatCompletionRequest( + model=MODEL, + messages=[{"role": "user", "content": [audio_block, {"type": "text", "text": "Transcribe this."}]}], + ) + + with pytest.raises(ProviderError) as exc_info: + client.completion(request) + + assert exc_info.value.kind == ProviderErrorKind.BAD_REQUEST + assert exc_info.value.provider_name == PROVIDER + assert exc_info.value.model_name == MODEL + assert "audio/flac" in exc_info.value.message + sync_mock.post.assert_not_called() + + +def test_completion_translates_audio_url_blocks() -> None: + sync_mock = make_mock_sync_client(_chat_response()) + client = _make_client(sync_client=sync_mock) + + audio_block = get_media_url_context(Modality.AUDIO.value, "https://example.com/download?id=123") + request = ChatCompletionRequest( + model=MODEL, + messages=[{"role": "user", "content": [audio_block, {"type": "text", "text": "Transcribe this."}]}], + ) + client.completion(request) + + payload = sync_mock.post.call_args.kwargs["json"] + content = payload["messages"][0]["content"] + assert content[0] == {"type": "audio_url", "audio_url": {"url": "https://example.com/download?id=123"}} + + +def test_completion_translates_local_audio_path_blocks() -> None: + sync_mock = make_mock_sync_client(_chat_response()) + client = _make_client(sync_client=sync_mock) + + audio_block = get_media_url_context(Modality.AUDIO.value, "recordings/speech.wav") + request = ChatCompletionRequest( + model=MODEL, + messages=[{"role": "user", "content": [audio_block, 
{"type": "text", "text": "Transcribe this."}]}], + ) + client.completion(request) + + payload = sync_mock.post.call_args.kwargs["json"] + content = payload["messages"][0]["content"] + assert content[0] == {"type": "audio_url", "audio_url": {"url": "recordings/speech.wav"}} + + +def test_completion_translates_video_blocks() -> None: + sync_mock = make_mock_sync_client(_chat_response()) + client = _make_client(sync_client=sync_mock) + + video_block = get_media_url_context(Modality.VIDEO.value, "https://example.com/download?id=123") + request = ChatCompletionRequest( + model=MODEL, + messages=[{"role": "user", "content": [video_block, {"type": "text", "text": "Describe this."}]}], + ) + client.completion(request) + + payload = sync_mock.post.call_args.kwargs["json"] + content = payload["messages"][0]["content"] + assert content[0] == {"type": "video_url", "video_url": {"url": "https://example.com/download?id=123"}} + + +def test_completion_translates_local_video_path_blocks() -> None: + sync_mock = make_mock_sync_client(_chat_response()) + client = _make_client(sync_client=sync_mock) + + video_block = get_media_url_context(Modality.VIDEO.value, "clips/screen_recording.mp4") + request = ChatCompletionRequest( + model=MODEL, + messages=[{"role": "user", "content": [video_block, {"type": "text", "text": "Describe this."}]}], + ) + client.completion(request) + + payload = sync_mock.post.call_args.kwargs["json"] + content = payload["messages"][0]["content"] + assert content[0] == {"type": "video_url", "video_url": {"url": "clips/screen_recording.mp4"}} + + +def test_completion_translates_base64_video_blocks() -> None: + sync_mock = make_mock_sync_client(_chat_response()) + client = _make_client(sync_client=sync_mock) + + video_block = get_media_base64_context(Modality.VIDEO.value, "video/mp4", "abc123") + request = ChatCompletionRequest( + model=MODEL, + messages=[{"role": "user", "content": [video_block, {"type": "text", "text": "Describe this."}]}], + ) + 
client.completion(request) + + payload = sync_mock.post.call_args.kwargs["json"] + content = payload["messages"][0]["content"] + assert content[0] == {"type": "video_url", "video_url": {"url": "data:video/mp4;base64,abc123"}} + + # --- Auth headers --- @@ -386,6 +520,25 @@ def test_http_error_maps_to_provider_error( assert exc_info.value.kind == expected_kind +def test_http_400_error_preserves_provider_message() -> None: + error_json = {"error": {"message": "Unsupported content type: audio_url"}} + sync_mock = make_mock_sync_client(error_json, status_code=400) + client = _make_client(sync_client=sync_mock) + + audio_block = get_media_url_context(Modality.AUDIO.value, "https://example.com/speech.mp3") + request = ChatCompletionRequest( + model=MODEL, + messages=[{"role": "user", "content": [audio_block, {"type": "text", "text": "Transcribe this."}]}], + ) + with pytest.raises(ProviderError) as exc_info: + client.completion(request) + + assert exc_info.value.kind == ProviderErrorKind.BAD_REQUEST + assert exc_info.value.status_code == 400 + assert "Unsupported content type: audio_url" in exc_info.value.message + sync_mock.post.assert_called_once() + + def test_transport_timeout_raises_timeout_error() -> None: sync_mock = MagicMock() sync_mock.post = MagicMock(side_effect=TimeoutError("timed out")) diff --git a/packages/data-designer-engine/tests/engine/models/test_model_utils.py b/packages/data-designer-engine/tests/engine/models/test_model_utils.py index c2f07c068..7149560ba 100644 --- a/packages/data-designer-engine/tests/engine/models/test_model_utils.py +++ b/packages/data-designer-engine/tests/engine/models/test_model_utils.py @@ -1,6 +1,10 @@ # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + +from data_designer.config.models import Modality +from data_designer.config.utils.media_helpers import get_media_base64_context, get_media_url_context from data_designer.engine.models.utils import ChatMessage, prompt_to_messages @@ -33,3 +37,15 @@ def test_chat_message_as_tool_accepts_multimodal_content() -> None: assert message.content == content assert message.to_dict()["content"] == content + + +def test_prompt_to_messages_preserves_mixed_media_context_order() -> None: + context = [ + get_media_url_context(Modality.IMAGE.value, "https://example.com/image.png"), + get_media_base64_context(Modality.AUDIO.value, "audio/mpeg", "abc123"), + get_media_url_context(Modality.VIDEO.value, "https://example.com/video.mp4"), + ] + + assert prompt_to_messages(user_prompt="describe", multi_modal_context=context) == [ + ChatMessage.as_user([*context, {"type": "text", "text": "describe"}]) + ] diff --git a/packages/data-designer-engine/tests/engine/test_validation.py b/packages/data-designer-engine/tests/engine/test_validation.py index 38af52947..1a1191a9e 100644 --- a/packages/data-designer-engine/tests/engine/test_validation.py +++ b/packages/data-designer-engine/tests/engine/test_validation.py @@ -1,6 +1,8 @@ # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + from unittest.mock import Mock, patch import pytest @@ -16,7 +18,7 @@ SeedDatasetColumnConfig, ValidationColumnConfig, ) -from data_designer.config.models import ImageContext, ModalityDataType +from data_designer.config.models import AudioContext, ImageContext, ModalityDataType from data_designer.config.processors import ( DropColumnsProcessorConfig, SchemaTransformProcessorConfig, @@ -248,6 +250,18 @@ def test_validate_column_config_with_multi_modal_context(): assert len(violations) == 0 +def test_validate_column_config_with_audio_multi_modal_context() -> None: + column = LLMTextColumnConfig( + name="audio_description", + prompt="Describe the audio.", + model_alias=STUB_MODEL_ALIAS, + multi_modal_context=[AudioContext(column_name="audio_url", data_type=ModalityDataType.URL)], + ) + + violations = validate_prompt_templates([column], [column.name]) + assert len(violations) == 0 + + def test_validate_columns_not_all_dropped(): violations = validate_columns_not_all_dropped( [