diff --git a/common/workflow-operator/src/main/scala/org/apache/texera/amber/operator/LogicalOp.scala b/common/workflow-operator/src/main/scala/org/apache/texera/amber/operator/LogicalOp.scala index 4e9d6c6e2cd..55e241ecaf3 100644 --- a/common/workflow-operator/src/main/scala/org/apache/texera/amber/operator/LogicalOp.scala +++ b/common/workflow-operator/src/main/scala/org/apache/texera/amber/operator/LogicalOp.scala @@ -43,6 +43,7 @@ import org.apache.texera.amber.operator.dummy.DummyOpDesc import org.apache.texera.amber.operator.filter.SpecializedFilterOpDesc import org.apache.texera.amber.operator.hashJoin.HashJoinOpDesc import org.apache.texera.amber.operator.huggingFace.{ + HuggingFaceInferenceOpDesc, HuggingFaceIrisLogisticRegressionOpDesc, HuggingFaceSentimentAnalysisOpDesc, HuggingFaceSpamSMSDetectionOpDesc, @@ -396,6 +397,7 @@ trait StateTransferFunc ), new Type(value = classOf[SklearnDummyClassifierOpDesc], name = "SklearnDummyClassifier"), new Type(value = classOf[SklearnPredictionOpDesc], name = "SklearnPrediction"), + new Type(value = classOf[HuggingFaceInferenceOpDesc], name = "HuggingFace"), new Type( value = classOf[HuggingFaceSentimentAnalysisOpDesc], name = "HuggingFaceSentimentAnalysis" diff --git a/common/workflow-operator/src/main/scala/org/apache/texera/amber/operator/huggingFace/HuggingFaceInferenceOpDesc.scala b/common/workflow-operator/src/main/scala/org/apache/texera/amber/operator/huggingFace/HuggingFaceInferenceOpDesc.scala new file mode 100644 index 00000000000..f70c4409b4c --- /dev/null +++ b/common/workflow-operator/src/main/scala/org/apache/texera/amber/operator/huggingFace/HuggingFaceInferenceOpDesc.scala @@ -0,0 +1,258 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
package org.apache.texera.amber.operator.huggingFace

import com.fasterxml.jackson.annotation.{JsonProperty, JsonPropertyDescription}
import com.kjetland.jackson.jsonSchema.annotations.JsonSchemaTitle
import org.apache.texera.amber.core.tuple.{AttributeType, Schema}
import org.apache.texera.amber.core.workflow.{InputPort, OutputPort, PortIdentity}
import org.apache.texera.amber.operator.PythonOperatorDescriptor
import org.apache.texera.amber.operator.huggingFace.codegen.{
  AudioTaskCodegen,
  CodegenContext,
  ImageTaskCodegen,
  MediaGenCodegen,
  PythonCodegenBase,
  QaRankingCodegen,
  TaskCodegen,
  TextGenCodegen
}
import org.apache.texera.amber.operator.metadata.annotations.AutofillAttributeName
import org.apache.texera.amber.operator.metadata.{OperatorGroupConstants, OperatorInfo}
import org.apache.texera.amber.pybuilder.PyStringTypes.EncodableString

/**
 * Generic Hugging Face inference operator.
 *
 * A single operator covers several Hugging Face pipeline task families.
 * Each family is implemented by a `TaskCodegen` and looked up by task name
 * through `registeredCodegens`: text generation, image tasks, audio tasks,
 * media generation, and QA / ranking tasks.
 *
 * The Python script that runs at execution time is assembled by
 * `PythonCodegenBase.render(ctx, codegen)`, which composes the shared
 * provider-fallback / request-loop infrastructure with the per-task
 * payload + parse snippets supplied by the selected `TaskCodegen`.
 *
 * User-provided string fields are typed as [[EncodableString]] so the
 * `pyb"..."` macro inside `PythonCodegenBase` emits them as
 * base64-decoded expressions at runtime instead of raw Python literals —
 * this is what allows the operator to satisfy
 * `PythonCodeRawInvalidTextSpec`'s contract that arbitrary `@JsonProperty`
 * values must not leak into generated source.
 */
class HuggingFaceInferenceOpDesc extends PythonOperatorDescriptor {

  @JsonProperty(value = "hfApiToken", required = true)
  @JsonSchemaTitle("HF API Token")
  @JsonPropertyDescription(
    "Your Hugging Face API token (from https://huggingface.co/settings/tokens)"
  )
  var hfApiToken: EncodableString = ""

  @JsonProperty(value = "task", required = true, defaultValue = "text-generation")
  @JsonSchemaTitle("Task")
  @JsonPropertyDescription("The Hugging Face pipeline task type")
  var task: EncodableString = "text-generation"

  @JsonProperty(
    value = "modelId",
    required = true,
    defaultValue = "Qwen/Qwen2.5-72B-Instruct"
  )
  @JsonSchemaTitle("Model")
  @JsonPropertyDescription("Select a Hugging Face model")
  var modelId: EncodableString = "Qwen/Qwen2.5-72B-Instruct"

  @JsonProperty(value = "promptColumn", required = true)
  @JsonSchemaTitle("Prompt Column")
  @JsonPropertyDescription("Column in the input table to use as the user prompt")
  @AutofillAttributeName
  var promptColumn: EncodableString = ""

  @JsonProperty(value = "imageInput", required = false)
  @JsonSchemaTitle("Image Upload")
  @JsonPropertyDescription("Upload an image for Hugging Face image tasks")
  var imageInput: EncodableString = ""

  @JsonProperty(value = "inputImageColumn", required = false)
  @JsonSchemaTitle("Input Image Column")
  @JsonPropertyDescription("Column containing image data from the input table")
  @AutofillAttributeName
  var inputImageColumn: EncodableString = ""

  @JsonProperty(value = "audioInput", required = false)
  @JsonSchemaTitle("Audio Upload")
  @JsonPropertyDescription("Upload audio for Hugging Face audio tasks")
  var audioInput: EncodableString = ""

  @JsonProperty(value = "inputAudioColumn", required = false)
  @JsonSchemaTitle("Input Audio Column")
  @JsonPropertyDescription("Column containing audio data from the input table")
  @AutofillAttributeName
  var inputAudioColumn: EncodableString = ""

  @JsonProperty(value = "contextColumn", required = false)
  @JsonSchemaTitle("Context Column")
  @JsonPropertyDescription("Column containing the context passage for question answering")
  @AutofillAttributeName
  var contextColumn: EncodableString = ""

  @JsonProperty(value = "candidateLabels", required = false)
  @JsonSchemaTitle("Candidate Labels")
  @JsonPropertyDescription("Comma-separated candidate labels for zero-shot classification")
  var candidateLabels: EncodableString = ""

  @JsonProperty(value = "sentencesColumn", required = false)
  @JsonSchemaTitle("Sentences Column")
  @JsonPropertyDescription(
    "Column with comma-separated sentences for sentence similarity and text ranking"
  )
  @AutofillAttributeName
  var sentencesColumn: EncodableString = ""

  @JsonProperty(
    value = "systemPrompt",
    required = false,
    defaultValue = "You are a helpful assistant."
  )
  @JsonSchemaTitle("System Prompt")
  @JsonPropertyDescription("Optional system message to set model behavior")
  var systemPrompt: EncodableString = "You are a helpful assistant."

  @JsonProperty(value = "maxNewTokens", required = false, defaultValue = "256")
  @JsonSchemaTitle("Max New Tokens")
  @JsonPropertyDescription("Maximum number of tokens to generate (1-4096)")
  var maxNewTokens: java.lang.Integer = 256

  @JsonProperty(value = "temperature", required = false)
  @JsonSchemaTitle("Temperature")
  @JsonPropertyDescription("Sampling temperature (0.0 = deterministic, up to 2.0)")
  var temperature: java.lang.Double = 0.7

  @JsonProperty(
    value = "resultColumn",
    required = false,
    defaultValue = "hf_response"
  )
  @JsonSchemaTitle("Result Column Name")
  @JsonPropertyDescription("Name of the new column added to the output table")
  var resultColumn: EncodableString = "hf_response"

  /**
   * Task name -> code generator lookup.
   *
   * `TextGenCodegen` is registered under its single `task` key; every other
   * family is registered once per entry of its `tasks` set. Entries are
   * applied in the order text-gen, image, audio, media-gen, QA/ranking, so a
   * later family wins if two families ever declare the same task name.
   *
   * An unrecognized task string falls back to [[TextGenCodegen]] (see
   * `codegenForTask`). This keeps `generatePythonCode` total: it never
   * throws on arbitrary input, which the original design note ties to
   * `PythonCodeRawInvalidTextSpec`.
   */
  private val registeredCodegens: Map[String, TaskCodegen] = {
    val multiTaskFamilies: Seq[TaskCodegen] =
      Seq(ImageTaskCodegen, AudioTaskCodegen, MediaGenCodegen, QaRankingCodegen)
    val textGenEntry: Seq[(String, TaskCodegen)] = Seq(TextGenCodegen.task -> TextGenCodegen)
    val familyEntries: Seq[(String, TaskCodegen)] =
      multiTaskFamilies.flatMap(family => family.tasks.toSeq.map(name => name -> family))
    // `toMap` keeps the last binding for a duplicated key, matching the
    // insertion-order override semantics of the previous mutable builder.
    (textGenEntry ++ familyEntries).toMap
  }

  /** Resolves the codegen for `t`, defaulting to text generation for unknown tasks. */
  private def codegenForTask(t: String): TaskCodegen =
    registeredCodegens.getOrElse(t, TextGenCodegen)

  /** Null-safe view of an optional user string: `null` becomes the empty string. */
  private def orEmpty(value: EncodableString): EncodableString =
    if (value == null) "" else value

  /**
   * Result column name actually used by this operator: the configured name,
   * or `hf_response` when it is null or blank. Shared by code generation and
   * schema propagation so the two can never disagree.
   */
  private def effectiveResultColumn: EncodableString =
    if (resultColumn == null || resultColumn.trim.isEmpty) "hf_response" else resultColumn

  /**
   * Builds the Python source for this operator.
   *
   * Every user-supplied value is normalised first (nulls replaced, numeric
   * settings clamped to their documented ranges) and handed to
   * `PythonCodegenBase.render` together with the codegen selected for the
   * task.
   *
   * @return the generated Python script
   */
  override def generatePythonCode(): String = {
    // The task is trimmed before use: the blank check already ignores
    // surrounding whitespace, and a padded name would otherwise miss both the
    // registry lookup below and the task comparisons in the generated Python.
    val safeTask: EncodableString =
      if (task == null || task.trim.isEmpty) "text-generation" else task.trim
    val safeModelId: EncodableString =
      if (modelId == null) "" else modelId.trim

    // Clamp to the ranges advertised in the property descriptions.
    val requestedMaxTokens = if (maxNewTokens != null) maxNewTokens.intValue else 256
    val safeMaxTokens = math.max(1, math.min(requestedMaxTokens, 4096))
    val requestedTemperature = if (temperature != null) temperature.doubleValue else 0.7
    val safeTemp = math.max(0.0, math.min(requestedTemperature, 2.0))

    val ctx = CodegenContext(
      hfApiToken = orEmpty(hfApiToken),
      modelId = safeModelId,
      promptColumn = orEmpty(promptColumn),
      resultColumn = effectiveResultColumn,
      task = safeTask,
      systemPrompt = orEmpty(systemPrompt),
      safeMaxTokens = safeMaxTokens,
      safeTemp = safeTemp,
      imageInput = orEmpty(imageInput),
      inputImageColumn = orEmpty(inputImageColumn),
      audioInput = orEmpty(audioInput),
      inputAudioColumn = orEmpty(inputAudioColumn),
      contextColumn = orEmpty(contextColumn),
      candidateLabels = orEmpty(candidateLabels),
      sentencesColumn = orEmpty(sentencesColumn)
    )

    PythonCodegenBase.render(ctx, codegenForTask(safeTask))
  }

  /** Operator metadata: one input port, one output port, Hugging Face group. */
  override def operatorInfo: OperatorInfo =
    OperatorInfo(
      "Hugging Face",
      "Call a Hugging Face model via the Inference API",
      OperatorGroupConstants.HUGGINGFACE_GROUP,
      inputPorts = List(InputPort()),
      outputPorts = List(OutputPort())
    )

  /**
   * Output schema = the first input schema plus one STRING column named by
   * `effectiveResultColumn`.
   *
   * @param inputSchemas schemas keyed by input port; the first value is used
   * @return the schema for this operator's single output port
   */
  override def getOutputSchemas(
      inputSchemas: Map[PortIdentity, Schema]
  ): Map[PortIdentity, Schema] = {
    val outputSchema = inputSchemas.values.head.add(effectiveResultColumn, AttributeType.STRING)
    Map(operatorInfo.outputPorts.head.id -> outputSchema)
  }
}
package org.apache.texera.amber.operator.huggingFace.codegen

/**
 * Codegen for Hugging Face audio task families.
 *
 * ASR and audio-classification send audio bytes as the raw request body.
 * Text-to-speech is prompt-driven and sends a JSON payload; its providers
 * return either audio bytes directly or a JSON envelope pointing to audio.
 *
 * Both methods return Python statement fragments that are spliced into the
 * script template; they ignore `ctx` and rely on names such as `task`,
 * `body`, `prompt_value` and `current_audio_bytes` being in scope at the
 * splice point.
 */
object AudioTaskCodegen extends TaskCodegen {

  /** Primary registration key for this codegen. */
  override val task: String = "automatic-speech-recognition"

  /** All HF tasks routed through this codegen. */
  override val tasks: Set[String] = Set(
    "automatic-speech-recognition",
    "audio-classification",
    "text-to-speech"
  )

  // NOTE(review): the leading indentation of both snippets must match the
  // splice point inside PythonCodegenBase's template. The reviewed source had
  // its whitespace collapsed, so the base indents used here (12 spaces for the
  // payload snippet, 8 for the parse snippet) are assumptions — confirm.

  /**
   * Python fragment that sets `payload` (and, for raw-audio tasks,
   * `use_raw_binary_body` / `raw_binary_headers`) for the current row.
   */
  override def payloadPython(ctx: CodegenContext): String =
    """            if task in audio_only_tasks:
      |                payload = current_audio_bytes
      |                use_raw_binary_body = True
      |                raw_binary_headers = audio_headers
      |            elif task == "text-to-speech":
      |                payload = {"inputs": prompt_value}
      |            else:
      |                payload = {"inputs": prompt_value}""".stripMargin

  /**
   * Python fragment that turns a decoded JSON `body` into the result string.
   *
   * Unknown response shapes fall through to `json.dumps(body)`. The
   * `output` and `data` lookups check for a non-empty list before indexing,
   * so an empty list or a dict-shaped value falls through instead of raising.
   */
  override def parsePython(ctx: CodegenContext): String =
    """        if task == "text-to-speech":
      |            if isinstance(body, dict):
      |                if "output" in body:
      |                    out = body["output"]
      |                    url = out[0] if isinstance(out, list) and out else out
      |                    if isinstance(url, str) and url.startswith("http"):
      |                        return self._audio_url_to_data_url(url)
      |                if "audio" in body:
      |                    audio = body["audio"]
      |                    if isinstance(audio, dict):
      |                        if "url" in audio:
      |                            return self._audio_url_to_data_url(audio["url"])
      |                        if "b64_json" in audio:
      |                            return f"data:audio/mpeg;base64,{audio['b64_json']}"
      |                if "data" in body:
      |                    data = body["data"]
      |                    if isinstance(data, list) and data and isinstance(data[0], dict):
      |                        if "url" in data[0]:
      |                            return self._audio_url_to_data_url(data[0]["url"])
      |                        if "b64_json" in data[0]:
      |                            return f"data:audio/mpeg;base64,{data[0]['b64_json']}"
      |            return json.dumps(body)
      |        elif task == "automatic-speech-recognition":
      |            if isinstance(body, dict):
      |                if "text" in body:
      |                    return body["text"]
      |                if "generated_text" in body:
      |                    return body["generated_text"]
      |            return json.dumps(body)
      |        elif task == "audio-classification":
      |            return json.dumps(body)""".stripMargin
}
package org.apache.texera.amber.operator.huggingFace.codegen

/**
 * Codegen for the Hugging Face image-pipeline task family.
 *
 * Splits into two sub-families:
 *   - Raw-bytes tasks send the image bytes as the request body:
 *     the tasks in the template's `image_only_tasks` tuple, plus
 *     image-to-image (which has its own branch below but builds the same
 *     raw-bytes request).
 *   - JSON tasks bundle a base64 image with text in a JSON payload:
 *     visual-question-answering, document-question-answering,
 *     zero-shot-image-classification, image-text-to-text.
 *
 * Per-row `current_image_bytes` is resolved upstream in
 * [[PythonCodegenBase]]'s `process_table` (either from the operator's
 * uploaded image or from `INPUT_IMAGE_COLUMN`). The image helpers
 * (`_read_image_input`, `_compress_image_bytes`, `_image_input_as_base64`,
 * `_read_binary_value`, `_looks_like_html`, `_html_to_image_bytes`,
 * `_extract_json_arg`) live in PythonCodegenBase alongside the per-task
 * tuples (`image_only_tasks`, `image_prompt_tasks`, `image_tasks`).
 */
object ImageTaskCodegen extends TaskCodegen {

  /** Primary key for registration; the dispatcher maps every task in
   * [[tasks]] to this codegen.
   */
  override val task: String = "image-classification"

  /** All HF tasks routed through this codegen. */
  override val tasks: Set[String] = Set(
    // image-only
    "image-classification",
    "object-detection",
    "image-segmentation",
    "image-to-text",
    // image + text
    "visual-question-answering",
    "document-question-answering",
    "zero-shot-image-classification",
    "image-text-to-text",
    // raw image bytes in, generated image out
    "image-to-image"
  )

  // NOTE(review): the leading indentation of both snippets must match the
  // splice point inside PythonCodegenBase's template. The reviewed source had
  // its whitespace collapsed, so the base indents used here (12 spaces for the
  // payload snippet, 8 for the parse snippet) are assumptions — confirm.

  /**
   * Python fragment that sets `payload` (and, for raw-bytes tasks,
   * `use_raw_binary_body` / `raw_binary_headers`) for the current row.
   *
   * Zero-shot image classification takes its labels from the operator's
   * Candidate Labels field (`self.CANDIDATE_LABELS`) when that is non-blank
   * and otherwise from the prompt value, which keeps workflows that stored
   * the label list in the prompt column working.
   * NOTE(review): any prompt-column validation performed before this snippet
   * runs is not visible here — confirm it does not reject rows that rely on
   * Candidate Labels alone.
   */
  override def payloadPython(ctx: CodegenContext): String =
    """            if task in image_only_tasks:
      |                payload = current_image_bytes
      |                use_raw_binary_body = True
      |                raw_binary_headers = image_headers
      |            elif task in ("visual-question-answering", "document-question-answering"):
      |                payload = {
      |                    "inputs": {
      |                        "image": self._image_input_as_base64(current_image_bytes),
      |                        "question": prompt_value,
      |                    }
      |                }
      |            elif task == "image-text-to-text":
      |                img_b64 = self._image_input_as_base64(current_image_bytes)
      |                payload = {
      |                    "model": self.MODEL_ID,
      |                    "messages": [{
      |                        "role": "user",
      |                        "content": [
      |                            {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{img_b64}"}},
      |                            {"type": "text", "text": prompt_value if prompt_value else "Describe this image."},
      |                        ],
      |                    }],
      |                    "max_tokens": self.MAX_NEW_TOKENS,
      |                }
      |            elif task == "image-to-image":
      |                payload = current_image_bytes
      |                use_raw_binary_body = True
      |                raw_binary_headers = image_headers
      |            elif task == "zero-shot-image-classification":
      |                # Zero-shot requires the caller to supply candidate labels.
      |                # Prefer the operator's Candidate Labels field; fall back to
      |                # the prompt value as a comma-separated label list.
      |                labels_source = self.CANDIDATE_LABELS if (self.CANDIDATE_LABELS or "").strip() else prompt_value
      |                labels = [s.strip() for s in (labels_source or "").split(",") if s.strip()]
      |                payload = {
      |                    "inputs": self._image_input_as_base64(current_image_bytes),
      |                    "parameters": {"candidate_labels": labels},
      |                }
      |            else:
      |                payload = {"inputs": prompt_value}""".stripMargin

  /**
   * Python fragment that turns a decoded JSON `body` into the result string.
   *
   * Unknown response shapes fall through to `json.dumps(body)`. An empty
   * `output` list is treated as "no URL" instead of being indexed.
   */
  override def parsePython(ctx: CodegenContext): String =
    """        if task == "image-to-text":
      |            if isinstance(body, dict):
      |                if "md_results" in body:
      |                    return body["md_results"]
      |                if "choices" in body:
      |                    return body["choices"][0]["message"]["content"]
      |            if isinstance(body, list) and body and isinstance(body[0], dict):
      |                return body[0].get("generated_text", json.dumps(body))
      |            return json.dumps(body)
      |        elif task in ("visual-question-answering", "document-question-answering"):
      |            if isinstance(body, dict):
      |                return body.get("answer", json.dumps(body))
      |            return json.dumps(body)
      |        elif task == "image-text-to-text":
      |            if isinstance(body, dict) and "choices" in body:
      |                return body["choices"][0]["message"]["content"]
      |            if isinstance(body, list) and body and isinstance(body[0], dict):
      |                return body[0].get("generated_text", json.dumps(body))
      |            return json.dumps(body)
      |        elif task == "image-to-image":
      |            if isinstance(body, dict):
      |                if "output" in body:
      |                    out = body["output"]
      |                    url = out[0] if isinstance(out, list) and out else out
      |                    if isinstance(url, str) and url.startswith("http"):
      |                        return self._url_to_data_url(url)
      |                if "images" in body:
      |                    images = body["images"]
      |                    if images and isinstance(images[0], dict) and "url" in images[0]:
      |                        return self._url_to_data_url(images[0]["url"])
      |                if "data" in body:
      |                    data = body["data"]
      |                    if isinstance(data, dict) and "outputs" in data:
      |                        outputs = data["outputs"]
      |                        if outputs and isinstance(outputs[0], str) and outputs[0].startswith("http"):
      |                            return self._url_to_data_url(outputs[0])
      |                    if isinstance(data, list) and data and isinstance(data[0], dict):
      |                        if "b64_json" in data[0]:
      |                            return f"data:image/png;base64,{data[0]['b64_json']}"
      |                        if "url" in data[0]:
      |                            return self._url_to_data_url(data[0]["url"])
      |            return json.dumps(body)
      |        elif task in ("image-classification", "object-detection", "image-segmentation", "zero-shot-image-classification"):
      |            return json.dumps(body)""".stripMargin
}
/**
 * Codegen for prompt-driven media generation tasks.
 *
 * Providers return media in several shapes: raw bytes, OpenAI-style
 * b64_json, or URLs. URL responses are normalized to data URLs by the
 * shared `_url_to_data_url` helper so downstream result rendering receives
 * a stable string format.
 *
 * Both methods return Python statement fragments that are spliced into the
 * script template; they ignore `ctx` and rely on `task`, `body` and
 * `prompt_value` being in scope at the splice point.
 */
object MediaGenCodegen extends TaskCodegen {

  /** Primary registration key for this codegen. */
  override val task: String = "text-to-image"

  /** All HF tasks routed through this codegen. */
  override val tasks: Set[String] = Set(
    "text-to-image",
    "text-to-video"
  )

  // NOTE(review): the leading indentation of both snippets must match the
  // splice point inside PythonCodegenBase's template. The reviewed source had
  // its whitespace collapsed, so the base indents used here (12 spaces for the
  // payload snippet, 8 for the parse snippet) are assumptions — confirm.

  /**
   * Python fragment that sets `payload` for the current row.
   *
   * Every task handled here sends the same `{"inputs": prompt}` JSON body,
   * so the fragment is a single unconditional assignment (the previous
   * `if`/`else` assigned the identical value in both branches).
   */
  override def payloadPython(ctx: CodegenContext): String =
    """            payload = {"inputs": prompt_value}"""

  /**
   * Python fragment that turns a decoded JSON `body` into the result string.
   *
   * Unknown response shapes fall through to `json.dumps(body)`. An empty
   * `output` list is treated as "no URL" instead of being indexed.
   */
  override def parsePython(ctx: CodegenContext): String =
    """        if task == "text-to-image":
      |            if isinstance(body, dict):
      |                if "output" in body:
      |                    out = body["output"]
      |                    url = out[0] if isinstance(out, list) and out else out
      |                    if isinstance(url, str) and url.startswith("http"):
      |                        return self._url_to_data_url(url)
      |                if "images" in body:
      |                    images = body["images"]
      |                    if images and isinstance(images[0], dict) and "url" in images[0]:
      |                        return self._url_to_data_url(images[0]["url"])
      |                if "data" in body:
      |                    data = body["data"]
      |                    if isinstance(data, dict) and "outputs" in data:
      |                        outputs = data["outputs"]
      |                        if outputs and isinstance(outputs[0], str) and outputs[0].startswith("http"):
      |                            return self._url_to_data_url(outputs[0])
      |                    if isinstance(data, list) and data and isinstance(data[0], dict):
      |                        if "b64_json" in data[0]:
      |                            return f"data:image/png;base64,{data[0]['b64_json']}"
      |                        if "url" in data[0]:
      |                            return self._url_to_data_url(data[0]["url"])
      |            return json.dumps(body)
      |        elif task == "text-to-video":
      |            if isinstance(body, dict):
      |                if "output" in body:
      |                    out = body["output"]
      |                    url = out[0] if isinstance(out, list) and out else out
      |                    if isinstance(url, str) and url.startswith("http"):
      |                        return self._url_to_data_url(url)
      |                if "video" in body:
      |                    video = body["video"]
      |                    if isinstance(video, dict) and "url" in video:
      |                        return self._url_to_data_url(video["url"])
      |            return json.dumps(body)""".stripMargin
}
b/common/workflow-operator/src/main/scala/org/apache/texera/amber/operator/huggingFace/codegen/PythonCodegenBase.scala new file mode 100644 index 00000000000..0304f2c01dd --- /dev/null +++ b/common/workflow-operator/src/main/scala/org/apache/texera/amber/operator/huggingFace/codegen/PythonCodegenBase.scala @@ -0,0 +1,897 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.texera.amber.operator.huggingFace.codegen + +import org.apache.texera.amber.pybuilder.PythonTemplateBuilder.PythonTemplateBuilderStringContext + +/** + * Builds the Python script emitted by HuggingFaceInferenceOpDesc. + * + * The script defines a `ProcessTableOperator` class with: + * - Per-instance configuration set in `open(self)` from base64-encoded + * values that the `pyb"..."` macro decodes at runtime (so user-input + * strings never appear as raw Python literals in the source). + * - A provider-fallback system that walks the HF Hub's inference-provider + * list cheapest-first and tries each provider's native chat-completions + * route, with HF Inference Router as the default. 
+ * - A `process_table` loop that validates the prompt column, builds the + * per-row payload via the per-task codegen, posts to the resolved + * provider, and parses the response. + * - A `_parse_response` task switch whose branches are provided by the + * per-task codegen. + * + * Per-task variation lives in `TaskCodegen` implementations. This class + * holds only what is shared across all HF tasks; per-task helpers (image + * loading, audio MIME inference, media-URL fetching, etc.) will be added + * in subsequent PRs as the corresponding task families land. + */ +object PythonCodegenBase { + + def render(ctx: CodegenContext, codegen: TaskCodegen): String = { + val payload = codegen.payloadPython(ctx) + val parse = codegen.parsePython(ctx) + val hfApiToken = ctx.hfApiToken + val modelId = ctx.modelId + val promptColumn = ctx.promptColumn + val resultColumn = ctx.resultColumn + val task = ctx.task + val systemPrompt = ctx.systemPrompt + val maxNewTokens = ctx.safeMaxTokens + val temperature = ctx.safeTemp + val imageInput = ctx.imageInput + val inputImageColumn = ctx.inputImageColumn + val audioInput = ctx.audioInput + val inputAudioColumn = ctx.inputAudioColumn + val contextColumn = ctx.contextColumn + val candidateLabels = ctx.candidateLabels + val sentencesColumn = ctx.sentencesColumn + pyb"""import os + |import re + |import json + |import base64 + |import requests + |import pandas as pd + |from urllib.parse import urlparse + |from pytexera import * + | + |# Defensive format check for MODEL_ID before it is interpolated into + |# HF URL paths. The base host is hardcoded so the worst case isn't + |# SSRF, but rejecting `..` segments / query strings / fragments / + |# control chars keeps the operator's request shape predictable. + |_HF_MODEL_ID_PATTERN = re.compile(r"^[A-Za-z0-9][A-Za-z0-9._-]*(/[A-Za-z0-9._-]+)+$$") + | + |class ProcessTableOperator(UDFTableOperator): + | + | # Providers ranked cheapest-first (lower index = cheaper). 
+ | # Unknown providers are appended at the end. + | PROVIDER_COST_PRIORITY = [ + | "hf-inference", + | "cerebras", + | "sambanova", + | "groq", + | "novita", + | "nebius", + | "fireworks-ai", + | "together", + | "hyperbolic", + | "scaleway", + | "nscale", + | "ovhcloud", + | "deepinfra", + | "featherless-ai", + | "baseten", + | "publicai", + | "nvidia", + | "openai", + | "cohere", + | "clarifai", + | ] + | + | # Per-provider chat-completions route overrides. Providers not listed + | # here use the default `v1/chat/completions` path. Single source of + | # truth for both _post_with_fallback (text-gen) and _call_provider + | # (OpenAI-compatible fallback) so the two stay in sync as providers + | # are added. + | CHAT_ROUTES = { + | "groq": "openai/v1/chat/completions", + | "fireworks-ai": "inference/v1/chat/completions", + | "cohere": "compatibility/v1/chat/completions", + | "clarifai": "v2/ext/openai/v1/chat/completions", + | "deepinfra": "v1/openai/chat/completions", + | } + | + | # Third-party providers that speak the OpenAI chat-completions + | # protocol. Used by _call_provider's OpenAI-compatible branch. + | OPENAI_COMPATIBLE_PROVIDERS = ( + | "cerebras", "sambanova", "groq", "novita", "nebius", + | "fireworks-ai", "together", "hyperbolic", "cohere", "clarifai", + | "deepinfra", "featherless-ai", "nscale", "nvidia", "openai", + | "ovhcloud", "publicai", "scaleway", "baseten", + | ) + | + | def open(self): + | # User-provided strings reach the operator via base64-encoded + | # decode expressions so they cannot break Python syntax or + | # leak raw text into the generated source. 
+ | self.HF_API_TOKEN = $hfApiToken + | self.MODEL_ID = $modelId + | self.PROMPT_COLUMN = $promptColumn + | self.RESULT_COLUMN = $resultColumn + | self.TASK = $task + | self.SYSTEM_PROMPT = $systemPrompt + | self.MAX_NEW_TOKENS = $maxNewTokens + | self.TEMPERATURE = $temperature + | self.IMAGE_INPUT = $imageInput + | self.INPUT_IMAGE_COLUMN = $inputImageColumn + | self.AUDIO_INPUT = $audioInput + | self.INPUT_AUDIO_COLUMN = $inputAudioColumn + | self.CONTEXT_COLUMN = $contextColumn + | self.CANDIDATE_LABELS = $candidateLabels + | self.SENTENCES_COLUMN = $sentencesColumn + | + | def _resolve_providers(self, token): + | '''Query the HF Hub API for inference providers serving this model. + | Returns a list of dicts with 'name' and 'providerId' sorted + | cheapest-first. Falls back to hf-inference if anything goes wrong. + | ''' + | try: + | resp = requests.get( + | f"https://huggingface.co/api/models/{self.MODEL_ID}", + | headers={"Authorization": f"Bearer {token}"}, + | params={"expand[]": "inferenceProviderMapping"}, + | timeout=30, + | ) + | if resp.status_code == 200: + | data = resp.json() + | mapping = ( + | data.get("inferenceProviderMapping") + | or data.get("inference_provider_mapping") + | or {} + | ) + | if mapping: + | live = [ + | { + | "name": p, + | "providerId": v.get("providerId", self.MODEL_ID), + | "task": v.get("task", ""), + | "isModelAuthor": v.get("isModelAuthor", False), + | } + | for p, v in mapping.items() + | if isinstance(v, dict) and v.get("status") == "live" + | ] + | if live: + | priority = {name: idx for idx, name in enumerate(self.PROVIDER_COST_PRIORITY)} + | live.sort(key=lambda prov: priority.get(prov["name"], len(self.PROVIDER_COST_PRIORITY))) + | return live + | except Exception: + | pass + | return [{"name": "hf-inference", "providerId": self.MODEL_ID}] + | + | def _post_with_fallback(self, providers, json_headers, raw_binary_headers, pipeline_payload, use_raw_binary_body, prompt_value): + | '''Try providers in order, using the 
correct API route for each. + | Returns (response, provider_summary). provider_summary is None on + | success or a string describing what failed. + | ''' + | RETRYABLE = (400, 404, 422, 429, 502, 503) + | last_resp = None + | errors = [] + | for prov in providers: + | provider_name = prov["name"] + | provider_id = prov["providerId"] + | is_model_author = prov.get("isModelAuthor", False) + | prov_task = prov.get("task", "") + | try: + | if self.TASK in ("text-generation", "image-text-to-text"): + | route = self.CHAT_ROUTES.get(provider_name, "v1/chat/completions") + | url = f"https://router.huggingface.co/{provider_name}/{route}" + | resp = requests.post(url, headers=json_headers, json=pipeline_payload, timeout=120) + | elif is_model_author and prov_task in ("image-to-text", "image-text-to-text") and provider_name not in ("zai-org",): + | url = f"https://router.huggingface.co/{provider_name}/v1/chat/completions" + | img_b64 = "" + | if use_raw_binary_body and isinstance(pipeline_payload, bytes): + | img_b64 = base64.b64encode(pipeline_payload).decode("utf-8") + | chat_payload = { + | "model": provider_id, + | "messages": [{ + | "role": "user", + | "content": [ + | {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{img_b64}"}} if img_b64 else None, + | {"type": "text", "text": prompt_value if prompt_value else "What is in this image?"}, + | ], + | }], + | } + | chat_payload["messages"][0]["content"] = [c for c in chat_payload["messages"][0]["content"] if c is not None] + | resp = requests.post(url, headers=json_headers, json=chat_payload, timeout=120) + | elif provider_name == "hf-inference": + | url = f"https://router.huggingface.co/hf-inference/models/{self.MODEL_ID}" + | if use_raw_binary_body: + | resp = requests.post(url, headers=raw_binary_headers, data=pipeline_payload, timeout=120) + | else: + | resp = requests.post(url, headers=json_headers, json=pipeline_payload, timeout=120) + | else: + | resp = self._call_provider(provider_name, 
provider_id, json_headers, raw_binary_headers, pipeline_payload, use_raw_binary_body, prompt_value) + | except Exception as e: + | errors.append(f"{provider_name}: {type(e).__name__}") + | continue + | if resp.status_code in (200, 201): + | return resp, None + | if resp.status_code == 401: + | return resp, None + | try: + | detail = resp.json().get("error", resp.text[:200]) + | except Exception: + | detail = resp.text[:200] if resp.text else "no details" + | errors.append(f"{provider_name}: HTTP {resp.status_code} - {detail}") + | last_resp = resp + | if resp.status_code not in RETRYABLE: + | return resp, "; ".join(errors) + | summary = "; ".join(errors) if errors else "no providers available" + | return last_resp, summary + | + | def _call_provider(self, provider_name, provider_id, json_headers, raw_binary_headers, pipeline_payload, use_raw_binary_body, prompt_value): + | '''Route to a third-party provider using its native API format. + | Handles OpenAI-compatible chat providers for text-gen, zai-org's + | custom API, Replicate / Fal-ai / Wavespeed for media-generation + | and image-to-image, and an unknown-provider fallback that tries + | the pipeline format then chat completions. + | ''' + | base = f"https://router.huggingface.co/{provider_name}" + | task = self.TASK + | img_b64 = "" + | if use_raw_binary_body and isinstance(pipeline_payload, bytes): + | img_b64 = base64.b64encode(pipeline_payload).decode("utf-8") + | + | # zai-org: custom /api/paas/v4/ surface. 
+ | if provider_name == "zai-org": + | zai_headers = {**json_headers, "x-source-channel": "hugging_face", "accept-language": "en-US,en"} + | if task in ("image-to-text", "image-text-to-text"): + | url = f"{base}/api/paas/v4/layout_parsing" + | file_data = f"data:image/png;base64,{img_b64}" if img_b64 else "" + | return requests.post(url, headers=zai_headers, json={"model": provider_id, "file": file_data}, timeout=120) + | url = f"{base}/api/paas/v4/chat/completions" + | messages = [{"role": "user", "content": prompt_value}] + | if img_b64: + | messages = [{"role": "user", "content": [ + | {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{img_b64}"}}, + | {"type": "text", "text": prompt_value if prompt_value else "What is in this image?"}, + | ]}] + | return requests.post(url, headers=zai_headers, json={"model": provider_id, "messages": messages}, timeout=120) + | + | # Replicate: synchronous predictions endpoint with polling fallback. + | if provider_name == "replicate": + | url = f"{base}/v1/models/{provider_id}/predictions" + | hdrs = {**json_headers, "Prefer": "wait"} + | if task == "text-to-speech": + | inp = {"text": prompt_value} + | elif task in ("text-to-image", "text-to-video"): + | inp = {"prompt": prompt_value} + | elif task == "automatic-speech-recognition" and img_b64: + | inp = {"audio": f"data:audio/wav;base64,{img_b64}"} + | elif task == "image-to-image" and img_b64: + | data_url = f"data:image/png;base64,{img_b64}" + | inp = {"image": data_url, "images": [data_url], "input_image": data_url, "prompt": prompt_value} + | elif img_b64: + | inp = {"image": f"data:image/png;base64,{img_b64}", "prompt": prompt_value} + | else: + | inp = {"prompt": prompt_value} + | resp = requests.post(url, headers=hdrs, json={"input": inp}, timeout=120) + | if resp.status_code == 202: + | import time as _time + | pred = resp.json() + | poll_url = pred.get("urls", {}).get("get", "") + | if not poll_url: + | return resp + | from urllib.parse import 
urlparse as _urlparse + | poll_path = _urlparse(poll_url).path + | poll_url = f"{base}{poll_path}" + | # Worst case: 300 polls × 2s = ~10 minutes per row before we give + | # up. Sized for text-to-video which legitimately takes minutes on + | # Replicate. process_table is synchronous, so emit a progress + | # line every 30 polls (~1 min) to distinguish slow work from a + | # hang in the worker log. + | for poll_idx in range(300): + | _time.sleep(2) + | poll_resp = requests.get(poll_url, headers=json_headers, timeout=30) + | if poll_resp.status_code != 200: + | continue + | status = poll_resp.json().get("status", "") + | if status in ("succeeded", "failed", "canceled"): + | return poll_resp + | if (poll_idx + 1) % 30 == 0: + | print(f"[hf] Replicate still running for model '{self.MODEL_ID}' after {(poll_idx + 1) * 2}s; will wait up to 600s.") + | return poll_resp + | return resp + | + | # Fal-ai: per-model endpoint. + | if provider_name == "fal-ai": + | url = f"{base}/{provider_id}" + | if task == "text-to-speech": + | return requests.post(url, headers=json_headers, json={"text": prompt_value}, timeout=120) + | if task in ("text-to-image", "text-to-video"): + | return requests.post(url, headers=json_headers, json={"prompt": prompt_value}, timeout=120) + | if task == "image-to-image" and img_b64: + | data_url = f"data:image/png;base64,{img_b64}" + | return requests.post(url, headers=json_headers, json={"image_url": data_url, "image_urls": [data_url], "prompt": prompt_value}, timeout=120) + | if img_b64: + | return requests.post(url, headers=json_headers, json={"image_url": f"data:image/png;base64,{img_b64}", "prompt": prompt_value}, timeout=120) + | return requests.post(url, headers=json_headers, json={"prompt": prompt_value}, timeout=120) + | + | # Wavespeed: async submit + poll. 
+ | if provider_name == "wavespeed": + | url = f"{base}/api/v3/{provider_id}" + | payload = {"prompt": prompt_value} + | if img_b64: + | payload["image"] = img_b64 + | payload["images"] = [img_b64] + | submit_resp = requests.post(url, headers=json_headers, json=payload, timeout=120) + | if submit_resp.status_code not in (200, 201): + | return submit_resp + | get_path = submit_resp.json().get("data", {}).get("urls", {}).get("get", "") + | if not get_path: + | return submit_resp + | from urllib.parse import urlparse as _urlparse + | result_url = f"{base}{_urlparse(get_path).path}" + | import time as _time + | poll_resp = submit_resp + | # Worst case: 120 polls × 1s = ~2 minutes per row. Emit a progress + | # line every 30 polls (~30 s) so the worker log distinguishes slow + | # work from a hang. + | for poll_idx in range(120): + | _time.sleep(1) + | poll_resp = requests.get(result_url, headers=json_headers, timeout=30) + | if poll_resp.status_code != 200: + | continue + | status = poll_resp.json().get("data", {}).get("status", "") + | if status in ("completed", "failed"): + | return poll_resp + | if (poll_idx + 1) % 30 == 0: + | print(f"[hf] Wavespeed still running for model '{self.MODEL_ID}' after {poll_idx + 1}s; will wait up to 120s.") + | return poll_resp + | + | if provider_name in self.OPENAI_COMPATIBLE_PROVIDERS: + | if task == "text-to-image": + | url = f"{base}/v1/images/generations" + | return requests.post(url, headers=json_headers, json={"model": provider_id, "prompt": prompt_value}, timeout=120) + | if task == "text-to-speech": + | url = f"{base}/v1/audio/speech" + | return requests.post(url, headers=json_headers, json={"model": provider_id, "input": prompt_value}, timeout=120) + | url = f"{base}/{self.CHAT_ROUTES.get(provider_name, 'v1/chat/completions')}" + | messages = [{"role": "user", "content": prompt_value}] + | if img_b64: + | messages = [{"role": "user", "content": [ + | {"type": "image_url", "image_url": {"url": 
f"data:image/png;base64,{img_b64}"}}, + | {"type": "text", "text": prompt_value if prompt_value else "What is in this image?"}, + | ]}] + | return requests.post( + | url, + | headers=json_headers, + | json={"model": provider_id, "messages": messages}, + | timeout=120, + | ) + | + | # Unknown provider: try pipeline format, then chat completions. + | url = f"{base}/{provider_id}" + | if use_raw_binary_body: + | resp = requests.post(url, headers=raw_binary_headers, data=pipeline_payload, timeout=120) + | else: + | resp = requests.post(url, headers=json_headers, json=pipeline_payload, timeout=120) + | if resp.status_code in (400, 404, 422): + | url = f"{base}/v1/chat/completions" + | messages = [{"role": "user", "content": prompt_value}] + | if img_b64: + | messages = [{"role": "user", "content": [ + | {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{img_b64}"}}, + | {"type": "text", "text": prompt_value if prompt_value else "Describe this image."}, + | ]}] + | resp2 = requests.post( + | url, + | headers=json_headers, + | json={"model": provider_id, "messages": messages}, + | timeout=120, + | ) + | if resp2.status_code == 200: + | return resp2 + | return resp + | + | @overrides + | def process_table(self, table: Table, port: int) -> Iterator[Optional[TableLike]]: + | prompt_col = self.PROMPT_COLUMN + | result_col = self.RESULT_COLUMN + | task = self.TASK + | image_only_tasks = ("image-classification", "object-detection", "image-segmentation", "image-to-text") + | image_prompt_tasks = ("visual-question-answering", "document-question-answering", "zero-shot-image-classification", "image-text-to-text", "image-to-image") + | image_tasks = image_only_tasks + image_prompt_tasks + | audio_only_tasks = ("automatic-speech-recognition", "audio-classification") + | + | # --- validate MODEL_ID format before any HF URL is built --- + | if not _HF_MODEL_ID_PATTERN.match(self.MODEL_ID or ""): + | raise ValueError( + | f"Invalid Hugging Face model ID 
'{self.MODEL_ID}'. " + | f"Expected format like 'org/model-name' or 'org/model-name/revision'." + | ) + | + | # --- resolve API token --- + | token = self.HF_API_TOKEN if self.HF_API_TOKEN else os.environ.get("HF_TOKEN", "") + | if not token: + | raise ValueError( + | "Hugging Face API token is not set. " + | "Provide it in the operator config or via HF_TOKEN env var." + | ) + | + | # --- resolve all available inference providers for this model (tried in order) --- + | providers = self._resolve_providers(token) + | + | # --- validate prompt column exists (skipped for binary-only tasks) --- + | if task not in image_only_tasks and task not in audio_only_tasks: + | assert prompt_col in table.columns, ( + | f"Prompt column '{prompt_col}' not found in input table. " + | f"Available columns: {list(table.columns)}" + | ) + | if task == "question-answering": + | ctx_col = self.CONTEXT_COLUMN + | assert ctx_col and ctx_col in table.columns, ( + | f"Context column '{ctx_col}' not found in input table. " + | f"Available columns: {list(table.columns)}" + | ) + | if task in ("sentence-similarity", "text-ranking"): + | sent_col = self.SENTENCES_COLUMN + | assert sent_col and sent_col in table.columns, ( + | f"Sentences column '{sent_col}' not found in input table. 
" + | f"Available columns: {list(table.columns)}" + | ) + | + | # --- handle empty table --- + | if table.empty: + | table[result_col] = pd.Series(dtype="object") + | yield table + | return + | + | json_headers = { + | "Authorization": f"Bearer {token}", + | "Content-Type": "application/json", + | } + | image_headers = { + | "Authorization": f"Bearer {token}", + | "Content-Type": "application/octet-stream", + | } + | audio_headers = { + | "Authorization": f"Bearer {token}", + | "Content-Type": self._get_audio_content_type(), + | } + | + | # --- pre-compute table dict for table-question-answering --- + | table_dict = None + | if task == "table-question-answering": + | table_dict = {} + | for col in table.columns: + | if col != prompt_col and col != result_col: + | table_dict[col] = [ + | str(v) if not pd.isna(v) else "" for v in table[col].tolist() + | ] + | + | # --- resolve image source (upload or column) for image tasks --- + | has_image_upload = bool(self.IMAGE_INPUT) and bool(str(self.IMAGE_INPUT).strip()) + | use_image_column = not has_image_upload and bool(self.INPUT_IMAGE_COLUMN) and self.INPUT_IMAGE_COLUMN in table.columns + | image_bytes = None + | image_error = None + | has_audio_upload = bool(self.AUDIO_INPUT) and bool(str(self.AUDIO_INPUT).strip()) + | use_audio_column = not has_audio_upload and bool(self.INPUT_AUDIO_COLUMN) and self.INPUT_AUDIO_COLUMN in table.columns + | audio_bytes = None + | audio_error = None + | if task in image_tasks and not use_image_column: + | if not has_image_upload: + | image_error = "No image source. Set an Input Image Column or upload an image." + | else: + | try: + | image_bytes = self._read_image_input() + | except Exception as e: + | image_error = f"Could not read image input ({type(e).__name__}: {e})" + | if task in audio_only_tasks and not use_audio_column: + | if not has_audio_upload: + | audio_error = "No audio source. Set an Input Audio Column or upload audio." 
+ | else: + | try: + | audio_bytes = self._read_audio_input() + | except Exception as e: + | audio_error = f"Could not read audio input ({type(e).__name__}: {e})" + | + | results = [] + | for idx, row in table.iterrows(): + | if image_error is not None: + | results.append(self._format_error("Image task configuration error", image_error)) + | continue + | if audio_error is not None: + | results.append(self._format_error("Audio task configuration error", audio_error)) + | continue + | + | if task in image_only_tasks: + | prompt_value = "" + | elif task in audio_only_tasks: + | prompt_value = "" + | elif task in image_prompt_tasks and prompt_col not in table.columns: + | prompt_value = "What is shown in this image?" + | else: + | prompt_value = row[prompt_col] + | if pd.isna(prompt_value): + | prompt_value = "" + | else: + | prompt_value = str(prompt_value) + | + | # --- resolve per-row image bytes from column --- + | current_image_bytes = image_bytes + | if task in image_tasks and use_image_column: + | try: + | raw = self._read_binary_value(row[self.INPUT_IMAGE_COLUMN]) + | if raw is None: + | results.append(self._format_error("Image data error", f"Row {idx}: image column is empty")) + | continue + | current_image_bytes = self._compress_image_bytes(raw) + | except Exception as e: + | results.append(self._format_error("Image data error", f"Row {idx}: {type(e).__name__}: {e}")) + | continue + | + | # --- resolve per-row audio bytes from column --- + | current_audio_bytes = audio_bytes + | if task in audio_only_tasks and use_audio_column: + | try: + | current_audio_bytes = self._read_binary_value(row[self.INPUT_AUDIO_COLUMN]) + | if current_audio_bytes is None: + | results.append(self._format_error("Audio data error", f"Row {idx}: audio column is empty")) + | continue + | except Exception as e: + | results.append(self._format_error("Audio data error", f"Row {idx}: {type(e).__name__}: {e}")) + | continue + | + | # --- build task-specific payload (provided by per-task 
codegen) --- + | use_raw_binary_body = False + | raw_binary_headers = image_headers + |${payload} + | + | try: + | resp, provider_summary = self._post_with_fallback( + | providers, json_headers, raw_binary_headers, payload, use_raw_binary_body, prompt_value + | ) + | + | if resp is None: + | results.append( + | self._format_error( + | "All inference providers failed", + | f"No provider could serve model '{self.MODEL_ID}'. " + | f"Tried: {provider_summary}" + | ) + | ) + | continue + | + | if resp.status_code == 429: + | results.append( + | self._format_http_error( + | "HF API rate limit hit, retry later", resp.status_code, resp.text + | ) + | ) + | continue + | if resp.status_code == 401: + | results.append( + | self._format_http_error("Invalid HF API token", resp.status_code, resp.text) + | ) + | continue + | if resp.status_code not in (200, 201): + | results.append( + | self._format_error( + | "All inference providers failed", + | f"No provider could serve model '{self.MODEL_ID}'. " + | f"Tried: {provider_summary}" + | ) + | ) + | continue + | + | content_type = resp.headers.get("Content-Type", "") + | if content_type.startswith("image/"): + | b64 = base64.b64encode(resp.content).decode("utf-8") + | results.append(f"data:{content_type};base64,{b64}") + | continue + | if content_type.startswith("audio/") or content_type.startswith("video/"): + | b64 = base64.b64encode(resp.content).decode("utf-8") + | results.append(f"data:{content_type};base64,{b64}") + | continue + | + | try: + | body = resp.json() + | except ValueError: + | body = resp.text + | content = self._parse_response(body) + | results.append(content) + | + | except Exception as e: + | import warnings + | warnings.warn( + | f"Row {idx}: request failed ({type(e).__name__}: {e}), " + | f"setting result to readable error text." 
+ | ) + | results.append(self._format_error("Request failed", f"{type(e).__name__}: {e}")) + | + | table[result_col] = results + | yield table + | + | def _format_error(self, title, detail): + | return f"{title}: {detail}" + | + | def _format_http_error(self, title, status_code, response_text): + | # Cap at 200 chars to match the truncation in _post_with_fallback's + | # error-detail extraction; a large body / HTML error page would + | # otherwise land verbatim in the result cell. + | detail = response_text.strip()[:200] + | if not detail: + | detail = "" + | return f"{title} [status={status_code}] response={detail}" + | + | # ────────────────────────────────────────────────────────────────── + | # Image-task helpers (used by ImageTaskCodegen and image-related + | # branches of _call_provider). + | # ────────────────────────────────────────────────────────────────── + | + | def _read_image_input(self): + | image_input = str(self.IMAGE_INPUT or "").strip() + | if image_input.startswith("data:"): + | _, encoded = image_input.split(",", 1) + | return base64.b64decode(encoded) + | if image_input.startswith("http://") or image_input.startswith("https://"): + | resp = requests.get(image_input, timeout=120) + | resp.raise_for_status() + | return resp.content + | if not os.path.exists(image_input): + | raise FileNotFoundError(f"Image file not found at path: {image_input}") + | if not os.path.isfile(image_input): + | raise ValueError(f"Image input path is not a file: {image_input}") + | with open(image_input, "rb") as image_file: + | return image_file.read() + | + | def _compress_image_bytes(self, image_bytes, max_bytes=33000): + | from io import BytesIO + | from PIL import Image as PILImage + | if len(image_bytes) <= max_bytes: + | return image_bytes + | try: + | img = PILImage.open(BytesIO(image_bytes)) + | img = img.convert("RGB") + | max_dim = 512 + | quality = 75 + | while max_dim >= 160: + | scale = min(1, max_dim / max(img.width, img.height)) + | w = max(1, 
round(img.width * scale)) + | h = max(1, round(img.height * scale)) + | resized = img.resize((w, h), PILImage.LANCZOS) + | q = quality + | while q >= 35: + | buf = BytesIO() + | resized.save(buf, format="JPEG", quality=q) + | if buf.tell() <= max_bytes: + | return buf.getvalue() + | q -= 10 + | max_dim = int(max_dim * 0.75) + | buf = BytesIO() + | resized.save(buf, format="JPEG", quality=35) + | return buf.getvalue() + | except Exception: + | return image_bytes + | + | def _image_input_as_base64(self, image_bytes): + | return base64.b64encode(image_bytes).decode("utf-8") + | + | def _read_audio_input(self): + | audio_input = str(self.AUDIO_INPUT or "").strip() + | if audio_input.startswith("data:"): + | _, encoded = audio_input.split(",", 1) + | return base64.b64decode(encoded) + | if audio_input.startswith("http://") or audio_input.startswith("https://"): + | resp = requests.get(audio_input, timeout=120) + | resp.raise_for_status() + | return resp.content + | if not os.path.exists(audio_input): + | raise FileNotFoundError(f"Audio file not found at path: {audio_input}") + | if not os.path.isfile(audio_input): + | raise ValueError(f"Audio input path is not a file: {audio_input}") + | with open(audio_input, "rb") as audio_file: + | return audio_file.read() + | + | def _read_binary_value(self, value): + | if value is None or (isinstance(value, float) and pd.isna(value)): + | return None + | if isinstance(value, bytes): + | return value + | val = str(value).strip() + | if not val: + | return None + | if self._looks_like_html(val): + | return self._html_to_image_bytes(val) + | if val.startswith("data:"): + | _, encoded = val.split(",", 1) + | return base64.b64decode(encoded) + | if val.startswith("http://") or val.startswith("https://"): + | resp = requests.get(val, timeout=120) + | resp.raise_for_status() + | return resp.content + | if os.path.exists(val) and os.path.isfile(val): + | with open(val, "rb") as f: + | return f.read() + | try: + | return 
base64.b64decode(val) + | except Exception: + | return val.encode("utf-8") + | + | def _looks_like_html(self, val): + | s = val.lstrip()[:200].lower() + | if s.startswith("= len(text): + | return None, start_pos + | ch = text[start_pos] + | openers = {"[": "]", "{": "}"} + | if ch not in openers: + | return None, start_pos + | closer = openers[ch] + | depth = 1 + | pos = start_pos + 1 + | in_string = False + | while pos < len(text) and depth > 0: + | c = text[pos] + | if in_string: + | if c == "\\\\": + | pos += 2 + | continue + | if c == '"': + | in_string = False + | else: + | if c == '"': + | in_string = True + | elif c == ch: + | depth += 1 + | elif c == closer: + | depth -= 1 + | pos += 1 + | if depth == 0: + | return text[start_pos:pos], pos + | return None, start_pos + | + | def _get_audio_content_type(self): + | audio_input = str(self.AUDIO_INPUT or "").strip().lower() + | if audio_input.startswith("data:"): + | header = audio_input.split(",", 1)[0] + | if ";" in header: + | return header[5:header.index(";")] + | return header[5:] + | extension_map = { + | ".mp3": "audio/mpeg", + | ".mpeg": "audio/mpeg", + | ".wav": "audio/wav", + | ".flac": "audio/flac", + | ".ogg": "audio/ogg", + | ".oga": "audio/ogg", + | ".webm": "audio/webm", + | ".opus": "audio/webm;codecs=opus", + | ".amr": "audio/amr", + | ".m4a": "audio/m4a", + | } + | _, ext = os.path.splitext(audio_input) + | return extension_map.get(ext, "audio/mpeg") + | + | def _audio_url_to_data_url(self, url): + | resp = requests.get(url, timeout=120) + | resp.raise_for_status() + | content_type = resp.headers.get("Content-Type", "").strip() + | if not content_type or content_type == "application/octet-stream": + | parsed = urlparse(url) + | _, ext = os.path.splitext(parsed.path.lower()) + | extension_map = { + | ".mp3": "audio/mpeg", + | ".mpeg": "audio/mpeg", + | ".wav": "audio/wav", + | ".flac": "audio/flac", + | ".ogg": "audio/ogg", + | ".oga": "audio/ogg", + | ".webm": "audio/webm", + | ".opus": 
"audio/webm;codecs=opus", + | ".amr": "audio/amr", + | ".m4a": "audio/m4a", + | } + | content_type = extension_map.get(ext, "audio/mpeg") + | b64 = base64.b64encode(resp.content).decode("utf-8") + | return f"data:{content_type};base64,{b64}" + | + | def _url_to_data_url(self, url): + | '''Fetch a URL and return a data URL with the correct MIME type.''' + | resp = requests.get(url, timeout=120) + | resp.raise_for_status() + | content_type = resp.headers.get("Content-Type", "").split(";")[0].strip() + | if not content_type or content_type == "application/octet-stream": + | from urllib.parse import urlparse as _urlparse + | ext = os.path.splitext(_urlparse(url).path.lower())[1] + | mime_map = {".png": "image/png", ".jpg": "image/jpeg", ".jpeg": "image/jpeg", ".gif": "image/gif", ".webp": "image/webp", ".svg": "image/svg+xml", ".mp4": "video/mp4", ".webm": "video/webm"} + | guessed = mime_map.get(ext, "") + | if guessed: + | content_type = guessed + | else: + | task_mime = {"image-to-image": "image/png", "text-to-image": "image/png", "text-to-video": "video/mp4"} + | content_type = task_mime.get(self.TASK, "application/octet-stream") + | b64 = base64.b64encode(resp.content).decode("utf-8") + | return f"data:{content_type};base64,{b64}" + | + | def _parse_response(self, body): + | task = self.TASK + | try: + | if isinstance(body, str): + | return body + |${parse} + | else: + | return json.dumps(body) + | except (KeyError, IndexError, TypeError): + | return json.dumps(body) + |""".encode + } +} diff --git a/common/workflow-operator/src/main/scala/org/apache/texera/amber/operator/huggingFace/codegen/QaRankingCodegen.scala b/common/workflow-operator/src/main/scala/org/apache/texera/amber/operator/huggingFace/codegen/QaRankingCodegen.scala new file mode 100644 index 00000000000..a1785e23e65 --- /dev/null +++ b/common/workflow-operator/src/main/scala/org/apache/texera/amber/operator/huggingFace/codegen/QaRankingCodegen.scala @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache 
Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.texera.amber.operator.huggingFace.codegen + +/** + * Codegen for question-answering, zero-shot, similarity, and ranking tasks. + * + * These tasks are prompt-driven but need extra per-row or per-operator + * inputs: context text, candidate labels, table contents, or a list of + * comparison sentences/documents. 
+ */ +object QaRankingCodegen extends TaskCodegen { + + override val task: String = "question-answering" + + override val tasks: Set[String] = Set( + "question-answering", + "table-question-answering", + "zero-shot-classification", + "sentence-similarity", + "text-ranking" + ) + + override def payloadPython(ctx: CodegenContext): String = + """ if task == "question-answering": + | ctx_val = row[self.CONTEXT_COLUMN] + | ctx_val = "" if pd.isna(ctx_val) else str(ctx_val) + | payload = {"inputs": {"question": prompt_value, "context": ctx_val}} + | elif task == "table-question-answering": + | payload = {"inputs": {"query": prompt_value, "table": table_dict}} + | elif task == "zero-shot-classification": + | labels = [l.strip() for l in self.CANDIDATE_LABELS.split(",") if l.strip()] + | payload = { + | "inputs": prompt_value, + | "parameters": {"candidate_labels": labels}, + | } + | elif task in ("sentence-similarity", "text-ranking"): + | sent_val = row[self.SENTENCES_COLUMN] + | sent_val = "" if pd.isna(sent_val) else str(sent_val) + | sentences_list = [s.strip() for s in sent_val.split(",") if s.strip()] + | payload = { + | "inputs": { + | "source_sentence": prompt_value, + | "sentences": sentences_list, + | } + | } + | else: + | payload = {"inputs": prompt_value}""".stripMargin + + override def parsePython(ctx: CodegenContext): String = + """ if task == "question-answering": + | return body.get("answer", json.dumps(body)) + | elif task == "table-question-answering": + | return body.get("answer", json.dumps(body)) + | elif task in ("zero-shot-classification", "sentence-similarity", "text-ranking"): + | return json.dumps(body)""".stripMargin +} diff --git a/common/workflow-operator/src/main/scala/org/apache/texera/amber/operator/huggingFace/codegen/TaskCodegen.scala b/common/workflow-operator/src/main/scala/org/apache/texera/amber/operator/huggingFace/codegen/TaskCodegen.scala new file mode 100644 index 00000000000..8abcef721b5 --- /dev/null +++ 
b/common/workflow-operator/src/main/scala/org/apache/texera/amber/operator/huggingFace/codegen/TaskCodegen.scala @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.texera.amber.operator.huggingFace.codegen + +import org.apache.texera.amber.pybuilder.PyStringTypes.EncodableString + +/** + * Inputs the dispatcher passes through to each TaskCodegen. + * + * User-provided string fields are typed as [[EncodableString]] so the + * `pyb"..."` macro in [[PythonCodegenBase]] emits them as base64-decoded + * runtime expressions rather than raw Python string literals — required to + * pass `PythonCodeRawInvalidTextSpec`'s leakage check. 
+ */ +final case class CodegenContext( + hfApiToken: EncodableString, + modelId: EncodableString, + promptColumn: EncodableString, + resultColumn: EncodableString, + task: EncodableString, + systemPrompt: EncodableString, + safeMaxTokens: Int, /* NOTE(review): the "safe" prefix suggests the caller has already clamped this value — confirm against the dispatcher. */ + safeTemp: Double, /* NOTE(review): presumably clamped by the caller as well — confirm. */ + imageInput: EncodableString = "", /* This and every following field defaults to the empty string, so callers may omit them. */ + inputImageColumn: EncodableString = "", + audioInput: EncodableString = "", + inputAudioColumn: EncodableString = "", + contextColumn: EncodableString = "", + candidateLabels: EncodableString = "", + sentencesColumn: EncodableString = "" +) + +/** + * A bundle of Python snippets that customize generated inference code for + * one Hugging Face pipeline task family. + * + * Concrete implementations are `object`s registered in + * `HuggingFaceInferenceOpDesc.registeredCodegens`. Each task family + * (text generation, image, audio, media generation, QA/ranking) is its own + * `*Codegen` object added to that map. + * + * Snippets returned by these methods are Python source spliced into the + * shared template assembled by [[PythonCodegenBase.render]]. Snippets must + * NOT directly inline user-provided strings — reference the per-instance + * attributes `self.HF_API_TOKEN`, `self.MODEL_ID`, `self.PROMPT_COLUMN`, + * etc. that the base class initializes from `CodegenContext` via the + * `pyb` macro's safe encoding. The snippet author is responsible for the + * correct indentation column (see existing implementations). + */ +trait TaskCodegen { + + /** Canonical Hugging Face pipeline task string used as the primary key for + * registration, e.g. "text-generation". Codegens that handle multiple + * task strings (image, audio, …) override [[tasks]] to enumerate all of + * them — the operator's dispatcher registers an entry per task. + */ + def task: String + + /** All Hugging Face pipeline task strings handled by this codegen. + * Defaults to the singleton `Set(task)` for codegens that handle one + * task; multi-task codegens override this.
+ */ + def tasks: Set[String] = Set(task) + + /** Python text that assigns `payload = …` for one row inside + * `process_table`'s per-row loop. The snippet supplies its own leading + * `if`/`elif task == "...":` opener and any `else` fallback. `ctx` carries the inputs forwarded by the dispatcher; an implementation may ignore it. + */ + def payloadPython(ctx: CodegenContext): String + + /** Python text for the body of `_parse_response`'s task switch. The + * snippet supplies its own leading `if`/`elif task == "...":` opener. + * The base class wraps the result in the try/except matching the + * source layout. `ctx` is the same dispatcher-supplied context passed to `payloadPython`. + */ + def parsePython(ctx: CodegenContext): String +} diff --git a/common/workflow-operator/src/main/scala/org/apache/texera/amber/operator/huggingFace/codegen/TextGenCodegen.scala b/common/workflow-operator/src/main/scala/org/apache/texera/amber/operator/huggingFace/codegen/TextGenCodegen.scala new file mode 100644 index 00000000000..b836de9e121 --- /dev/null +++ b/common/workflow-operator/src/main/scala/org/apache/texera/amber/operator/huggingFace/codegen/TextGenCodegen.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.texera.amber.operator.huggingFace.codegen + +/** + * Codegen for the `text-generation` Hugging Face pipeline task.
+ * + * The payload is the OpenAI chat-completions shape — `messages` with a + * system + user pair plus `max_tokens` / `temperature` knobs — which is + * what the HF router and every OpenAI-compatible third-party provider + * (Cerebras, Groq, Sambanova, Together, …) accepts. + * + * The parse step pulls `body["choices"][0]["message"]["content"]` out of + * the response. Neither snippet reads `ctx`: both are fixed strings that refer only to `self.*` attributes and local names, and the payload snippet's `else` branch sends `{"inputs": prompt_value}` for any other task value. + */ +object TextGenCodegen extends TaskCodegen { + + override val task: String = "text-generation" + + override def payloadPython(ctx: CodegenContext): String = + """ if task == "text-generation": + | payload = { + | "model": self.MODEL_ID, + | "messages": [ + | {"role": "system", "content": self.SYSTEM_PROMPT}, + | {"role": "user", "content": prompt_value}, + | ], + | "max_tokens": self.MAX_NEW_TOKENS, + | "temperature": self.TEMPERATURE, + | } + | else: + | payload = {"inputs": prompt_value}""".stripMargin + + override def parsePython(ctx: CodegenContext): String = + """ if task == "text-generation": + | return body["choices"][0]["message"]["content"]""".stripMargin +} diff --git a/common/workflow-operator/src/test/scala/org/apache/texera/amber/operator/huggingFace/HuggingFaceInferenceOpDescSpec.scala b/common/workflow-operator/src/test/scala/org/apache/texera/amber/operator/huggingFace/HuggingFaceInferenceOpDescSpec.scala new file mode 100644 index 00000000000..624958fa356 --- /dev/null +++ b/common/workflow-operator/src/test/scala/org/apache/texera/amber/operator/huggingFace/HuggingFaceInferenceOpDescSpec.scala @@ -0,0 +1,407 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.texera.amber.operator.huggingFace + +import org.apache.texera.amber.core.tuple.{AttributeType, Schema} +import org.apache.texera.amber.core.workflow.PortIdentity +import org.apache.texera.amber.operator.huggingFace.codegen.{ + AudioTaskCodegen, + CodegenContext, + MediaGenCodegen, + QaRankingCodegen, + TextGenCodegen +} +import org.apache.texera.amber.operator.metadata.OperatorGroupConstants +import org.apache.texera.amber.pybuilder.PyStringTypes.EncodableString +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.matchers.should.Matchers +/** Unit tests for the Python code and output schema produced by [[HuggingFaceInferenceOpDesc]]; generated code is checked by substring only and is never executed here. */ +class HuggingFaceInferenceOpDescSpec extends AnyFlatSpec with Matchers { +/** Builds a descriptor with every field assigned. The defaults describe a text-generation setup, so each test overrides only the fields it exercises. */ + private def makeDesc( + token: EncodableString = "token", + modelId: EncodableString = "Qwen/Qwen2.5-72B-Instruct", + promptColumn: EncodableString = "prompt", + task: EncodableString = "text-generation", + systemPrompt: EncodableString = "You are a helpful assistant.", + maxNewTokens: Int = 256, + temperature: Double = 0.7, + resultColumn: EncodableString = "hf_response", + imageInput: EncodableString = "", + inputImageColumn: EncodableString = "", + audioInput: EncodableString = "", + inputAudioColumn: EncodableString = "", + contextColumn: EncodableString = "", + candidateLabels: EncodableString = "", + sentencesColumn: EncodableString = "" + ): HuggingFaceInferenceOpDesc = { + val desc = new HuggingFaceInferenceOpDesc() + desc.hfApiToken = token + desc.modelId = modelId + desc.promptColumn = promptColumn + desc.task = task + desc.systemPrompt = systemPrompt + desc.maxNewTokens = maxNewTokens +
desc.temperature = temperature + desc.resultColumn = resultColumn + desc.imageInput = imageInput + desc.inputImageColumn = inputImageColumn + desc.audioInput = audioInput + desc.inputAudioColumn = inputAudioColumn + desc.contextColumn = contextColumn + desc.candidateLabels = candidateLabels + desc.sentencesColumn = sentencesColumn + desc + } + + "HuggingFaceInferenceOpDesc.operatorInfo" should + "advertise the user-friendly name, HuggingFace group, and one input/output port" in { + val info = (new HuggingFaceInferenceOpDesc).operatorInfo + info.userFriendlyName shouldBe "Hugging Face" + info.operatorGroupName shouldBe OperatorGroupConstants.HUGGINGFACE_GROUP + info.inputPorts.size shouldBe 1 + info.outputPorts.size shouldBe 1 + } + + "generatePythonCode" should + "fall back to the text-gen codegen on an unrecognized task (HF reports the real error at runtime)" in { + // generatePythonCode must be total — never throw on arbitrary @JsonProperty + // values — per the PythonCodeRawInvalidTextSpec contract. An unknown task + // routes through TextGenCodegen, whose payload `if/else` hits the generic + // `{"inputs": prompt_value}` branch at runtime. + val code = makeDesc(task = "not-a-real-task").generatePythonCode() + code should include("""payload = {"inputs": prompt_value}""") + } + + it should "emit a ProcessTableOperator that initializes config in open()" in { + val code = makeDesc().generatePythonCode() + code should include("class ProcessTableOperator(UDFTableOperator):") + code should include("def open(self):") + // User-input strings are decoded at runtime, not embedded as literals.
+ code should include("self.HF_API_TOKEN = self.decode_python_template(") + code should include("self.MODEL_ID = self.decode_python_template(") + code should include("self.PROMPT_COLUMN = self.decode_python_template(") + code should include("self.TASK = self.decode_python_template(") + code should include("self.SYSTEM_PROMPT = self.decode_python_template(") + } + + it should "wire the text-gen payload and response parse correctly" in { + val code = makeDesc().generatePythonCode() + // Payload — chat-completions shape against the configured model + system prompt. + code should include("self.MODEL_ID") + code should include("self.SYSTEM_PROMPT") + code should include("self.MAX_NEW_TOKENS") + code should include("self.TEMPERATURE") + // Parse — text-gen pulls choices[0].message.content out of the response. + code should include("""body["choices"][0]["message"]["content"]""") + } + + it should + "emit a runtime check that rejects malformed MODEL_ID values before any HF URL is built" in { + val code = makeDesc().generatePythonCode() + // Pattern that fences MODEL_ID to org/model-name (allowing org/model-name/revision). + code should include("_HF_MODEL_ID_PATTERN = re.compile(") + // Runtime fail-fast inside process_table — happens before _resolve_providers + // composes the URL, so a malformed value never escapes into a request. + code should include("if not _HF_MODEL_ID_PATTERN.match(") + code should include("raise ValueError(") + code should include("Invalid Hugging Face model ID") + } + + it should "not leak raw user-input strings into the generated Python source" in { + // Sentinel value chosen to be distinctive and non-overlapping with anything + // else in the template. If our encoding regressed back to raw literals + // (e.g. `MODEL_ID = "MARKER_zXyq42"`), this assertion would fail.
+ val marker = "MARKER_zXyq42" + val code = + makeDesc(modelId = marker, promptColumn = marker, token = marker).generatePythonCode() + code should not include marker + } + + it should "clamp maxNewTokens into the 1-4096 range" in { + makeDesc(maxNewTokens = -5).generatePythonCode() should include( + "self.MAX_NEW_TOKENS = 1" + ) + makeDesc(maxNewTokens = 99999).generatePythonCode() should include( + "self.MAX_NEW_TOKENS = 4096" + ) + } + + it should "clamp temperature into the 0.0-2.0 range" in { + makeDesc(temperature = -1.0).generatePythonCode() should include( + "self.TEMPERATURE = 0.0" + ) + makeDesc(temperature = 5.0).generatePythonCode() should include( + "self.TEMPERATURE = 2.0" + ) + } + + it should "tolerate null @JsonProperty values and fall back to safe defaults" in { + // Every user-input field can land as null when the JSON deserializer is + // handed a workflow that omits the field. generatePythonCode must not + // throw on any combination — and the generated Python must still parse. + val desc = new HuggingFaceInferenceOpDesc() + desc.hfApiToken = null + desc.modelId = null + desc.promptColumn = null + desc.systemPrompt = null + desc.resultColumn = null + desc.task = null + desc.maxNewTokens = null + desc.temperature = null + desc.imageInput = null + desc.inputImageColumn = null + desc.audioInput = null + desc.inputAudioColumn = null + desc.contextColumn = null + desc.candidateLabels = null + desc.sentencesColumn = null + val code = desc.generatePythonCode() + code should include("class ProcessTableOperator(UDFTableOperator):") + code should include("def open(self):") + // System-prompt default is the empty-string sentinel (no fallback string + // injected) but the operator class still initializes the constant. + code should include("self.SYSTEM_PROMPT = ") + // maxNewTokens null path defaults to 256. + code should include("self.MAX_NEW_TOKENS = 256") + // temperature null path defaults to 0.7.
+ code should include("self.TEMPERATURE = 0.7") + } + + "TextGenCodegen" should "advertise text-generation as its canonical task" in { + TextGenCodegen.task shouldBe "text-generation" + } + + it should + "emit payload and parse snippets that don't depend on the CodegenContext" in { + // For text-generation, the codegen's only inputs to Python are static + // strings referencing self.* attributes — exercising both methods + // confirms they don't accidentally consume ctx fields (a future + // refactor regression would surface here). + val ctx = CodegenContext( + hfApiToken = "irrelevant", + modelId = "irrelevant", + promptColumn = "irrelevant", + resultColumn = "irrelevant", + task = "irrelevant", + systemPrompt = "irrelevant", + safeMaxTokens = 0, + safeTemp = 0.0 + ) + TextGenCodegen.payloadPython(ctx) should include("self.MODEL_ID") + TextGenCodegen.parsePython(ctx) should include("""body["choices"][0]["message"]["content"]""") + } + + "image task family" should + "route image-only tasks through ImageTaskCodegen (raw binary payload + image headers)" in { + val code = + makeDesc(task = "image-classification", inputImageColumn = "img").generatePythonCode() + code should include("self.IMAGE_INPUT = ") + code should include("self.INPUT_IMAGE_COLUMN = ") + code should include("if task in image_only_tasks:") + code should include("payload = current_image_bytes") + code should include("use_raw_binary_body = True") + code should include("raw_binary_headers = image_headers") + // image bytes resolution + image content-type response handling exist + code should include("self._read_image_input()") + code should include("self._read_binary_value") + code should include("self._compress_image_bytes") + code should include("""if content_type.startswith("image/"):""") + } + + it should "route VQA / document-QA through ImageTaskCodegen (base64 image + question payload)" in { + val code = makeDesc(task = "visual-question-answering").generatePythonCode() + code should include( +
"""elif task in ("visual-question-answering", "document-question-answering"):""" + ) + code should include("self._image_input_as_base64(current_image_bytes)") + code should include(""""question": prompt_value""") + } + + it should "route image-text-to-text through chat completions with embedded base64 image" in { + val code = makeDesc(task = "image-text-to-text").generatePythonCode() + code should include("""elif task == "image-text-to-text":""") + code should include("""data:image/png;base64,{img_b64}""") + code should include("self.MODEL_ID") + } + + it should "route image-to-image as raw binary and parse via _url_to_data_url on JSON response" in { + val code = makeDesc(task = "image-to-image").generatePythonCode() + code should include("""elif task == "image-to-image":""") + code should include("self._url_to_data_url(") + } + + it should + "register all 9 image task strings under the dispatcher (image-only + image+prompt)" in { + // Each image task should pull in ImageTaskCodegen's branch chain.
+ val imageTasks = Seq( + "image-classification", + "object-detection", + "image-segmentation", + "image-to-text", + "visual-question-answering", + "document-question-answering", + "zero-shot-image-classification", + "image-text-to-text", + "image-to-image" + ) + imageTasks.foreach { t => + val code = makeDesc(task = t).generatePythonCode() + code should include("if task in image_only_tasks:") + } + } + + "audio task family" should + "route ASR and audio-classification through AudioTaskCodegen as raw binary payloads" in { + val code = + makeDesc(task = "automatic-speech-recognition", inputAudioColumn = "audio") + .generatePythonCode() + code should include("self.AUDIO_INPUT = ") + code should include("self.INPUT_AUDIO_COLUMN = ") + code should include( + """audio_only_tasks = ("automatic-speech-recognition", "audio-classification")""" + ) + code should include("payload = current_audio_bytes") + code should include("raw_binary_headers = audio_headers") + code should include("self._read_audio_input()") + code should include( + """if content_type.startswith("audio/") or content_type.startswith("video/"):""" + ) + } + + it should "route text-to-speech through AudioTaskCodegen and normalize audio URLs" in { + val code = makeDesc(task = "text-to-speech").generatePythonCode() + code should include("""elif task == "text-to-speech":""") + code should include("""payload = {"inputs": prompt_value}""") + code should include("self._audio_url_to_data_url(") + code should include("data:audio/mpeg;base64") + } + + it should "register all audio task strings under the dispatcher" in { + AudioTaskCodegen.tasks should contain allOf ( + "automatic-speech-recognition", + "audio-classification", + "text-to-speech" + ) + AudioTaskCodegen.tasks.foreach { t => + val code = makeDesc(task = t, inputAudioColumn = "audio").generatePythonCode() + code should include("if task in audio_only_tasks:") + } + } + + "media generation task family" should + "route text-to-image through MediaGenCodegen
and parse URL or b64 responses as data URLs" in { + val code = makeDesc(task = "text-to-image").generatePythonCode() + code should include("""if task in ("text-to-image", "text-to-video"):""") + code should include("""payload = {"inputs": prompt_value}""") + code should include("""if task == "text-to-image":""") + code should include("self._url_to_data_url(") + code should include("data:image/png;base64") + } + + it should "route text-to-video through MediaGenCodegen and normalize remote video URLs" in { + val code = makeDesc(task = "text-to-video").generatePythonCode() + code should include("""elif task == "text-to-video":""") + code should include("self._url_to_data_url(") + code should include("video/mp4") + } + + it should "register all media generation task strings under the dispatcher" in { + MediaGenCodegen.tasks should contain allOf ("text-to-image", "text-to-video") + MediaGenCodegen.tasks.foreach { t => + val code = makeDesc(task = t).generatePythonCode() + code should include("""if task in ("text-to-image", "text-to-video"):""") + } + } + + "qa and ranking task family" should + "route question-answering through QaRankingCodegen with context-column validation" in { + val code = makeDesc(task = "question-answering", contextColumn = "context").generatePythonCode() + code should include("self.CONTEXT_COLUMN = ") + code should include("""if task == "question-answering":""") + code should include("ctx_col = self.CONTEXT_COLUMN") + code should include("Context column") + code should include("""payload = {"inputs": {"question": prompt_value, "context": ctx_val}}""") + code should include("""return body.get("answer", json.dumps(body))""") + } + + it should "route table-question-answering with a precomputed table payload" in { + val code = makeDesc(task = "table-question-answering").generatePythonCode() + code should include("""if task == "table-question-answering":""") + code should include("table_dict = {}") + code should include("""payload = {"inputs":
{"query": prompt_value, "table": table_dict}}""") + code should include("""return body.get("answer", json.dumps(body))""") + } + + it should "route zero-shot-classification with candidate labels" in { + val code = + makeDesc(task = "zero-shot-classification", candidateLabels = "positive,negative") + .generatePythonCode() + code should include("self.CANDIDATE_LABELS = ") + code should include("""elif task == "zero-shot-classification":""") + code should include("labels = [l.strip() for l in self.CANDIDATE_LABELS.split") + code should include(""""parameters": {"candidate_labels": labels}""") + } + + it should "route sentence-similarity and text-ranking with sentences-column validation" in { + Seq("sentence-similarity", "text-ranking").foreach { taskName => + val code = makeDesc(task = taskName, sentencesColumn = "sentences").generatePythonCode() + code should include("self.SENTENCES_COLUMN = ") + code should include("""elif task in ("sentence-similarity", "text-ranking"):""") + code should include("sent_col = self.SENTENCES_COLUMN") + code should include("Sentences column") + code should include(""""source_sentence": prompt_value""") + code should include(""""sentences": sentences_list""") + } + } + + it should "register all qa and ranking task strings under the dispatcher" in { + QaRankingCodegen.tasks should contain allOf ( + "question-answering", + "table-question-answering", + "zero-shot-classification", + "sentence-similarity", + "text-ranking" + ) + QaRankingCodegen.tasks.foreach { t => + val code = makeDesc(task = t, contextColumn = "context", sentencesColumn = "sentences") + .generatePythonCode() + code should include("""if task == "question-answering":""") + } + } + + "getOutputSchemas" should "add the result column as a STRING to the inherited schema" in { + val desc = makeDesc(resultColumn = "answer") + val inputSchema = Schema().add("prompt", AttributeType.STRING) + val out = desc.getOutputSchemas(Map(PortIdentity(0) -> inputSchema)) + val outSchema =
out(desc.operatorInfo.outputPorts.head.id) + outSchema.getAttributeNames.contains("prompt") shouldBe true + outSchema.getAttributeNames.contains("answer") shouldBe true + outSchema.getAttribute("answer").getType shouldBe AttributeType.STRING + } + + it should "fall back to the default 'hf_response' name when resultColumn is empty" in { + val desc = makeDesc(resultColumn = "") + val inputSchema = Schema().add("prompt", AttributeType.STRING) + val out = desc.getOutputSchemas(Map(PortIdentity(0) -> inputSchema)) + val outSchema = out(desc.operatorInfo.outputPorts.head.id) + outSchema.getAttributeNames.contains("hf_response") shouldBe true + } +} diff --git a/frontend/src/app/app.module.ts b/frontend/src/app/app.module.ts index 511395365df..d9eea71b186 100644 --- a/frontend/src/app/app.module.ts +++ b/frontend/src/app/app.module.ts @@ -105,6 +105,9 @@ import { CoeditorUserIconComponent } from "./workspace/component/menu/coeditor-u import { AgentPanelComponent } from "./workspace/component/agent/agent-panel/agent-panel.component"; import { AgentChatComponent } from "./workspace/component/agent/agent-panel/agent-chat/agent-chat.component"; import { AgentRegistrationComponent } from "./workspace/component/agent/agent-panel/agent-registration/agent-registration.component"; +import { HuggingFaceImageUploadComponent } from "./workspace/component/hugging-face-image-upload/hugging-face-image-upload.component"; +import { HuggingFaceComponent } from "./workspace/component/hugging-face/hugging-face.component"; +import { HuggingFaceAudioUploadComponent } from "./workspace/component/hugging-face-audio-upload/hugging-face-audio-upload.component"; import { DatasetFileSelectorComponent } from "./workspace/component/dataset-file-selector/dataset-file-selector.component"; import { DatasetVersionSelectorComponent } from "./workspace/component/dataset-version-selector/dataset-version-selector.component"; import { DatasetSelectionModalComponent } from
"./workspace/component/dataset-selection-modal/dataset-selection-modal.component"; @@ -329,6 +332,9 @@ registerLocaleData(en); AgentChatComponent, AgentRegistrationComponent, AgentInteractionComponent, + HuggingFaceComponent, + HuggingFaceAudioUploadComponent, + HuggingFaceImageUploadComponent, DatasetFileSelectorComponent, DatasetVersionSelectorComponent, DatasetSelectionModalComponent, diff --git a/frontend/src/app/common/formly/formly-config.ts b/frontend/src/app/common/formly/formly-config.ts index 707ddfa7975..c4fc54fd77f 100644 --- a/frontend/src/app/common/formly/formly-config.ts +++ b/frontend/src/app/common/formly/formly-config.ts @@ -29,6 +29,9 @@ import { CollabWrapperComponent } from "./collab-wrapper/collab-wrapper/collab-w import { FormlyRepeatDndComponent } from "./repeat-dnd/repeat-dnd.component"; import { UiUdfParametersComponent } from "../../workspace/component/ui-udf-parameters/ui-udf-parameters.component"; import { DatasetVersionSelectorComponent } from "../../workspace/component/dataset-version-selector/dataset-version-selector.component"; +import { HuggingFaceImageUploadComponent } from "../../workspace/component/hugging-face-image-upload/hugging-face-image-upload.component"; +import { HuggingFaceComponent } from "../../workspace/component/hugging-face/hugging-face.component"; +import { HuggingFaceAudioUploadComponent } from "../../workspace/component/hugging-face-audio-upload/hugging-face-audio-upload.component"; /** * Configuration for using Json Schema with Formly. 
@@ -80,6 +83,9 @@ export const TEXERA_FORMLY_CONFIG = { { name: "codearea", component: CodeareaCustomTemplateComponent }, { name: "inputautocomplete", component: DatasetFileSelectorComponent, wrappers: ["form-field"] }, { name: "datasetversionselector", component: DatasetVersionSelectorComponent, wrappers: ["form-field"] }, + { name: "huggingface", component: HuggingFaceComponent, wrappers: ["form-field"] }, + { name: "huggingface-audio-upload", component: HuggingFaceAudioUploadComponent, wrappers: ["form-field"] }, + { name: "huggingface-image-upload", component: HuggingFaceImageUploadComponent, wrappers: ["form-field"] }, { name: "repeat-section-dnd", component: FormlyRepeatDndComponent }, { name: "ui-udf-parameters", component: UiUdfParametersComponent, wrappers: ["form-field"] }, ], diff --git a/frontend/src/app/workspace/component/hugging-face-audio-upload/hugging-face-audio-upload.component.html b/frontend/src/app/workspace/component/hugging-face-audio-upload/hugging-face-audio-upload.component.html new file mode 100644 index 00000000000..a2b5f4d5133 --- /dev/null +++ b/frontend/src/app/workspace/component/hugging-face-audio-upload/hugging-face-audio-upload.component.html @@ -0,0 +1,61 @@ + + +
+
+ Audio files are uploaded to temporary backend storage and referenced from the operator, so larger clips can be used + without bloating the workflow JSON. +
+ + + +
+ +
+ {{ fileName || "Selected audio" }} + Uploading... + +
+
+ +
+ {{ errorMessage }} +
+
diff --git a/frontend/src/app/workspace/component/hugging-face-audio-upload/hugging-face-audio-upload.component.scss b/frontend/src/app/workspace/component/hugging-face-audio-upload/hugging-face-audio-upload.component.scss new file mode 100644 index 00000000000..0757524e04f --- /dev/null +++ b/frontend/src/app/workspace/component/hugging-face-audio-upload/hugging-face-audio-upload.component.scss @@ -0,0 +1,68 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +/* Root container: a vertical flex column with an 8px gap between children. */ +.hf-audio-upload { + display: flex; + flex-direction: column; + gap: 8px; +} + +.hf-audio-guidance { + color: #595959; + font-size: 12px; + line-height: 1.4; +} + +.hf-audio-upload-input { + width: 100%; +} + +.hf-audio-preview { + border: 1px solid #d9d9d9; + border-radius: 4px; + padding: 8px; +} +/* Make the audio element a block that fills the preview box's width. */ +.hf-audio-preview audio { + display: block; + width: 100%; +} + +.hf-audio-meta { + display: flex; + align-items: center; + justify-content: space-between; + gap: 8px; + margin-top: 8px; +} +/* Keep the text on one line and truncate overflow with an ellipsis. */ +.hf-audio-meta span { + overflow: hidden; + text-overflow: ellipsis; + white-space: nowrap; +} + +.hf-audio-status { + color: #595959; + font-size: 12px; +} + +.hf-audio-error { + color: #cf1322; +} diff --git a/frontend/src/app/workspace/component/hugging-face-audio-upload/hugging-face-audio-upload.component.spec.ts b/frontend/src/app/workspace/component/hugging-face-audio-upload/hugging-face-audio-upload.component.spec.ts new file mode 100644 index 00000000000..c94766a8dad --- /dev/null +++ b/frontend/src/app/workspace/component/hugging-face-audio-upload/hugging-face-audio-upload.component.spec.ts @@ -0,0 +1,33 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License.
+ */ + +import { HuggingFaceAudioUploadComponent } from "./hugging-face-audio-upload.component"; +/* Smoke tests only: they check that the component class can be imported, without creating it through Angular's TestBed. */ +describe("HuggingFaceAudioUploadComponent (unit)", () => { + it("should be defined", () => { + expect(HuggingFaceAudioUploadComponent).toBeDefined(); + }); + + it("should expose a prototype", () => { + /* The selector lives in @Component decorator metadata, which this spec + * does not read; the test only verifies that the imported class + * has a prototype. */ + expect(HuggingFaceAudioUploadComponent.prototype).toBeDefined(); + }); +}); diff --git a/frontend/src/app/workspace/component/hugging-face-audio-upload/hugging-face-audio-upload.component.ts b/frontend/src/app/workspace/component/hugging-face-audio-upload/hugging-face-audio-upload.component.ts new file mode 100644 index 00000000000..1a10602ce4c --- /dev/null +++ b/frontend/src/app/workspace/component/hugging-face-audio-upload/hugging-face-audio-upload.component.ts @@ -0,0 +1,153 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License.
+ */ + +import { Component, OnDestroy, OnInit } from "@angular/core"; +import { CommonModule } from "@angular/common"; +import { FieldType, FieldTypeConfig } from "@ngx-formly/core"; +import { HttpClient } from "@angular/common/http"; +import { NzButtonModule } from "ng-zorro-antd/button"; +import { firstValueFrom } from "rxjs"; +import { AppSettings } from "../../../common/app-setting"; +/** Response body of the upload-audio endpoint, limited to the fields this component reads. */ +interface HuggingFaceAudioUploadResponse { + path: string; + fileName: string; +} +/** Formly field type that uploads a chosen audio file to the backend and stores the returned path as the field value. */ +@Component({ + selector: "texera-hugging-face-audio-upload", + templateUrl: "./hugging-face-audio-upload.component.html", + styleUrls: ["./hugging-face-audio-upload.component.scss"], + imports: [CommonModule, NzButtonModule], +}) +export class HuggingFaceAudioUploadComponent extends FieldType<FieldTypeConfig> implements OnInit, OnDestroy { + fileName = ""; + errorMessage = ""; + isUploading = false; + private localPreviewUrl = ""; + + ngOnInit(): void { + if (typeof this.formControl.value === "string" && this.formControl.value.trim().length > 0) { + this.fileName = this.getDisplayName(this.formControl.value); + } + } + + constructor(private http: HttpClient) { + super(); + } +/** Preview source: the local object URL of a just-selected file if one exists; otherwise the stored value itself when it is an audio data URL, or the backend audio-preview endpoint for a stored path; empty string when nothing is stored. */ + get previewSrc(): string { + if (this.localPreviewUrl) { + return this.localPreviewUrl; + } + const value = this.formControl.value; + if (typeof value !== "string" || value.trim().length === 0) { + return ""; + } + if (value.startsWith("data:audio/")) { + return value; + } + return `${AppSettings.getApiEndpoint()}/huggingface/audio-preview?path=${encodeURIComponent(value)}`; + } + + ngOnDestroy(): void { + this.revokePreviewUrl(); + } +/** Rejects non-audio files, previews the selection locally, uploads the raw file bytes, and on success writes the returned path into the form control and model; on failure the field is cleared and an error message is shown. */ + async onFileSelected(event: Event): Promise<void> { + this.errorMessage = ""; + const input = event.target as HTMLInputElement; + const file = input.files?.[0]; + + if (!file) { + return; + } + if (!file.type.startsWith("audio/")) { + this.errorMessage = "Choose an audio file."; + input.value = ""; + return; + } + this.revokePreviewUrl(); + this.localPreviewUrl = URL.createObjectURL(file); + this.isUploading
= true; + + try { + const response = await firstValueFrom( + this.http.post<HuggingFaceAudioUploadResponse>( + `${AppSettings.getApiEndpoint()}/huggingface/upload-audio?filename=${encodeURIComponent(file.name)}`, + file, + { + headers: { + "Content-Type": "application/octet-stream", + }, + } + ) + ); + this.fileName = response.fileName || file.name; + this.formControl.setValue(response.path); + if (typeof this.key === "string" && this.model) { + this.model[this.key] = response.path; + } + this.formControl.markAsDirty(); + this.formControl.markAsTouched(); + this.formControl.updateValueAndValidity(); + } catch { + this.clearAudio(input, false); + this.errorMessage = "Could not upload this audio file."; + } finally { + this.isUploading = false; + } + } +/** Resets the field, the file input and the local preview to empty; `clearError` controls whether the current error message is cleared as well. */ + clearAudio(input: HTMLInputElement, clearError: boolean = true): void { + this.fileName = ""; + if (clearError) { + this.errorMessage = ""; + } + this.isUploading = false; + this.revokePreviewUrl(); + input.value = ""; + this.formControl.setValue(""); + if (typeof this.key === "string" && this.model) { + this.model[this.key] = ""; + } + this.formControl.markAsDirty(); + this.formControl.markAsTouched(); + this.formControl.updateValueAndValidity(); + } + + private revokePreviewUrl(): void { + if (this.localPreviewUrl) { + URL.revokeObjectURL(this.localPreviewUrl); + this.localPreviewUrl = ""; + } + } +/** Label for a stored value: a fixed placeholder for audio data URLs, otherwise the last segment of the path split on "/" or "\". */ + private getDisplayName(value: string): string { + const trimmedValue = value.trim(); + if (!trimmedValue) { + return ""; + } + if (trimmedValue.startsWith("data:audio/")) { + return "Selected audio"; + } + const segments = trimmedValue.split(/[\\/]/); + return segments[segments.length - 1] || "Selected audio"; + } +} diff --git a/frontend/src/app/workspace/component/hugging-face-image-upload/hugging-face-image-upload.component.html b/frontend/src/app/workspace/component/hugging-face-image-upload/hugging-face-image-upload.component.html new file mode 100644 index 00000000000..441c71f9c93 --- /dev/null +++
b/frontend/src/app/workspace/component/hugging-face-image-upload/hugging-face-image-upload.component.html @@ -0,0 +1,51 @@ + + +
+ + +
+ Uploaded Hugging Face task input +
+ {{ displayFileName || "Selected image" }} + +
+
+ +
+ {{ errorMessage }} +
+
diff --git a/frontend/src/app/workspace/component/hugging-face-image-upload/hugging-face-image-upload.component.scss b/frontend/src/app/workspace/component/hugging-face-image-upload/hugging-face-image-upload.component.scss new file mode 100644 index 00000000000..b292d8131ec --- /dev/null +++ b/frontend/src/app/workspace/component/hugging-face-image-upload/hugging-face-image-upload.component.scss @@ -0,0 +1,60 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +.hf-image-upload { + display: flex; + flex-direction: column; + gap: 8px; +} + +.hf-image-upload-input { + width: 100%; +} + +.hf-image-preview { + border: 1px solid #d9d9d9; + border-radius: 4px; + padding: 8px; +} + +.hf-image-preview img { + display: block; + width: 100%; + max-height: 220px; + object-fit: contain; + background: #f5f5f5; +} + +.hf-image-meta { + display: flex; + align-items: center; + justify-content: space-between; + gap: 8px; + margin-top: 8px; +} + +.hf-image-meta span { + overflow: hidden; + text-overflow: ellipsis; + white-space: nowrap; +} + +.hf-image-error { + color: #cf1322; +} diff --git a/frontend/src/app/workspace/component/hugging-face-image-upload/hugging-face-image-upload.component.spec.ts b/frontend/src/app/workspace/component/hugging-face-image-upload/hugging-face-image-upload.component.spec.ts new file mode 100644 index 00000000000..6bd947ef0ee --- /dev/null +++ b/frontend/src/app/workspace/component/hugging-face-image-upload/hugging-face-image-upload.component.spec.ts @@ -0,0 +1,146 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +import { ComponentFixture, TestBed } from "@angular/core/testing"; +import { FormControl } from "@angular/forms"; +import { HuggingFaceImageUploadComponent } from "./hugging-face-image-upload.component"; +import { commonTestProviders } from "../../../common/testing/test-utils"; + +describe("HuggingFaceImageUploadComponent", () => { + let component: HuggingFaceImageUploadComponent; + let fixture: ComponentFixture; + + beforeEach(async () => { + await TestBed.configureTestingModule({ + imports: [HuggingFaceImageUploadComponent], + providers: [...commonTestProviders], + }).compileComponents(); + + fixture = TestBed.createComponent(HuggingFaceImageUploadComponent); + component = fixture.componentInstance; + component.field = { + props: {}, + formControl: new FormControl(""), + key: "image", + model: {}, + } as any; + fixture.detectChanges(); + }); + + it("should create", () => { + expect(component).toBeTruthy(); + }); + + describe("derived view state", () => { + it("reports no image when formControl is empty", () => { + expect(component.hasImage).toBe(false); + expect(component.previewSrc).toBe(""); + expect(component.displayFileName).toBe(""); + }); + + it("reports an image when formControl holds a data URL", () => { + component.formControl.setValue("data:image/jpeg;base64,AAA"); + expect(component.hasImage).toBe(true); + expect(component.previewSrc).toBe("data:image/jpeg;base64,AAA"); + expect(component.displayFileName).toBe("Uploaded image"); + }); + + it("prefers the explicit filename over the fallback label", () => { + component.formControl.setValue("data:image/jpeg;base64,AAA"); + component.fileName = "cat.jpg"; + expect(component.displayFileName).toBe("cat.jpg"); + }); + }); + + describe("onFileSelected", () => { + function makeFileInput(file?: File): HTMLInputElement { + const input = document.createElement("input"); + input.type = "file"; + if (file) { + Object.defineProperty(input, "files", { + value: [file] as unknown as FileList, + configurable: 
true, + }); + } + return input; + } + + it("clears prior error and returns early when no file is provided", async () => { + component.errorMessage = "previous error"; + const input = makeFileInput(); + await component.onFileSelected({ target: input } as unknown as Event); + expect(component.errorMessage).toBe(""); + expect(component.formControl.value).toBe(""); + }); + + it("rejects non-image files and resets the input", async () => { + const txtFile = new File(["hi"], "note.txt", { type: "text/plain" }); + const input = makeFileInput(txtFile); + await component.onFileSelected({ target: input } as unknown as Event); + expect(component.errorMessage).toBe("Choose an image file."); + expect(component.hasImage).toBe(false); + }); + + it("reports an error when image compression fails", async () => { + // jsdom's Image never fires onload/onerror, so compressImage would hang + // forever. Stub FileReader so it synchronously fires onerror, which + // makes compressImage reject and exercises the catch branch. + const realFileReader = globalThis.FileReader; + class FailingFileReader { + onload: ((e: Event) => void) | null = null; + onerror: ((e: Event) => void) | null = null; + readAsDataURL() { + queueMicrotask(() => this.onerror?.(new Event("error"))); + } + } + (globalThis as any).FileReader = FailingFileReader; + try { + const imgFile = new File(["fake"], "broken.png", { type: "image/png" }); + const input = makeFileInput(imgFile); + await component.onFileSelected({ target: input } as unknown as Event); + expect(component.errorMessage).toBe("Could not prepare this image. 
Try a smaller image file."); + expect(component.hasImage).toBe(false); + } finally { + (globalThis as any).FileReader = realFileReader; + } + }); + }); + + describe("clearImage", () => { + it("resets file state, the form control, and any model value", () => { + (component.field as any).model = { image: "data:image/jpeg;base64,AAA" }; + component.formControl.setValue("data:image/jpeg;base64,AAA"); + component.fileName = "cat.jpg"; + component.errorMessage = "some error"; + + const input = document.createElement("input"); + input.type = "file"; + + component.clearImage(input); + + expect(component.fileName).toBe(""); + expect(component.errorMessage).toBe(""); + expect(input.value).toBe(""); + expect(component.formControl.value).toBe(""); + expect(component.formControl.dirty).toBe(true); + expect(component.formControl.touched).toBe(true); + expect((component.model as any).image).toBe(""); + }); + }); +}); diff --git a/frontend/src/app/workspace/component/hugging-face-image-upload/hugging-face-image-upload.component.ts b/frontend/src/app/workspace/component/hugging-face-image-upload/hugging-face-image-upload.component.ts new file mode 100644 index 00000000000..4b72e14aa57 --- /dev/null +++ b/frontend/src/app/workspace/component/hugging-face-image-upload/hugging-face-image-upload.component.ts @@ -0,0 +1,162 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */

import { Component } from "@angular/core";
import { CommonModule } from "@angular/common";
import { FieldType, FieldTypeConfig } from "@ngx-formly/core";
import { NzButtonModule } from "ng-zorro-antd/button";

// Keep in sync with PythonCodegenBase._compress_image_bytes(max_bytes) on the backend:
// the uploaded data URL must stay within the size the inference helpers expect.
const MAX_DATA_URL_LENGTH = 45000;
const INITIAL_MAX_DIMENSION = 512;
const MIN_MAX_DIMENSION = 160;
const INITIAL_JPEG_QUALITY = 0.75;
const MIN_JPEG_QUALITY = 0.35;
const JPEG_QUALITY_STEP = 0.1;
const DIMENSION_SHRINK_FACTOR = 0.75;

/**
 * JPEG qualities to try, from INITIAL_JPEG_QUALITY down to MIN_JPEG_QUALITY.
 * Each entry is derived from the step index and rounded to two decimals rather
 * than produced by repeated subtraction, so accumulated floating-point error
 * cannot decide whether the minimum quality is attempted.
 */
function buildJpegQualityLadder(): number[] {
  const ladder: number[] = [];
  for (let step = 0; ; step++) {
    const quality = Math.round((INITIAL_JPEG_QUALITY - step * JPEG_QUALITY_STEP) * 100) / 100;
    if (quality < MIN_JPEG_QUALITY) {
      break;
    }
    ladder.push(quality);
  }
  return ladder;
}

const JPEG_QUALITY_LADDER = buildJpegQualityLadder();

/**
 * Formly field that lets the user pick an image, re-encodes it in the browser
 * as a JPEG data URL no longer than MAX_DATA_URL_LENGTH characters, and stores
 * that data URL as the field value.
 */
@Component({
  selector: "texera-hugging-face-image-upload",
  templateUrl: "./hugging-face-image-upload.component.html",
  styleUrls: ["./hugging-face-image-upload.component.scss"],
  imports: [CommonModule, NzButtonModule],
})
export class HuggingFaceImageUploadComponent extends FieldType<FieldTypeConfig> {
  fileName = "";
  errorMessage = "";

  /**
   * Id of the most recent selection / clear. A compression whose id no longer
   * matches has been superseded and must not touch component state.
   */
  private latestRequestId = 0;

  /** True when the form value is an image data URL. */
  get hasImage(): boolean {
    const value = this.formControl.value;
    return typeof value === "string" && value.startsWith("data:image/");
  }

  /** The data URL to preview, or "" when no image is set. */
  get previewSrc(): string {
    return this.hasImage ? this.formControl.value : "";
  }

  /** The selected file's name, a generic label for a value without a known name, or "". */
  get displayFileName(): string {
    if (this.fileName) return this.fileName;
    if (this.hasImage) return "Uploaded image";
    return "";
  }

  /**
   * Handle a change of the file input: validate the MIME type, compress the
   * image and commit the resulting data URL to the form.
   *
   * If another file is selected or the field is cleared before compression
   * finishes, the late result (or failure) is ignored.
   */
  async onFileSelected(event: Event): Promise<void> {
    this.errorMessage = "";
    const input = event.target as HTMLInputElement;
    const file = input.files?.[0];

    if (!file) {
      return;
    }
    if (!file.type.startsWith("image/")) {
      this.errorMessage = "Choose an image file.";
      input.value = "";
      return;
    }

    const requestId = ++this.latestRequestId;
    try {
      const dataUrl = await this.compressImage(file);
      if (requestId !== this.latestRequestId) {
        return; // superseded while decoding/compressing
      }
      this.fileName = file.name;
      this.commitValue(dataUrl);
    } catch {
      if (requestId !== this.latestRequestId) {
        return; // a stale failure must not disturb the newer selection
      }
      this.errorMessage = "Could not prepare this image. Try a smaller image file.";
      input.value = "";
    }
  }

  /**
   * Read `file`, decode it as an image and return a compressed JPEG data URL.
   * Rejects when the file cannot be read or decoded, or when no attempt fits
   * within MAX_DATA_URL_LENGTH.
   */
  private compressImage(file: File): Promise<string> {
    const reader = new FileReader();
    const image = new Image();

    return new Promise<string>((resolve, reject) => {
      reader.onload = () => {
        if (typeof reader.result !== "string") {
          reject(new Error("FileReader did not produce a data URL"));
          return;
        }
        image.onload = () => {
          const compressed = this.renderCompressedDataUrl(image);
          if (!compressed.startsWith("data:image/") || compressed.length > MAX_DATA_URL_LENGTH) {
            reject(new Error("Image could not be compressed within the size limit"));
            return;
          }
          resolve(compressed);
        };
        image.onerror = () => reject(new Error("Image could not be decoded"));
        image.src = reader.result;
      };
      reader.onerror = () => reject(new Error("File could not be read"));
      reader.readAsDataURL(file);
    });
  }

  /**
   * Try progressively smaller dimensions, and for each the full quality ladder,
   * returning the first JPEG data URL that fits MAX_DATA_URL_LENGTH. If none
   * fits, the last attempt is returned so the caller can reject it; "" is
   * returned when no 2D canvas context is available on the first attempt.
   */
  private renderCompressedDataUrl(image: HTMLImageElement): string {
    let bestDataUrl = "";

    for (
      let maxDimension = INITIAL_MAX_DIMENSION;
      maxDimension >= MIN_MAX_DIMENSION;
      maxDimension = Math.floor(maxDimension * DIMENSION_SHRINK_FACTOR)
    ) {
      // Never upscale: the scale factor is capped at 1.
      const scale = Math.min(1, maxDimension / Math.max(image.width, image.height));
      const width = Math.max(1, Math.round(image.width * scale));
      const height = Math.max(1, Math.round(image.height * scale));
      const canvas = document.createElement("canvas");
      canvas.width = width;
      canvas.height = height;
      const ctx = canvas.getContext("2d");

      if (!ctx) {
        return bestDataUrl;
      }

      ctx.drawImage(image, 0, 0, width, height);

      for (const quality of JPEG_QUALITY_LADDER) {
        const dataUrl = canvas.toDataURL("image/jpeg", quality);
        bestDataUrl = dataUrl;
        if (dataUrl.length <= MAX_DATA_URL_LENGTH) {
          return dataUrl;
        }
      }
    }

    return bestDataUrl;
  }

  /** Reset the field to the empty state and cancel the effect of any pending compression. */
  clearImage(input: HTMLInputElement): void {
    this.latestRequestId++;
    this.fileName = "";
    this.errorMessage = "";
    input.value = "";
    this.commitValue("");
  }

  /** Write `value` to the form control and the backing model, and mark the control as edited. */
  private commitValue(value: string): void {
    this.formControl.setValue(value);
    if (typeof this.key === "string" && this.model) {
      this.model[this.key] = value;
    }
    this.formControl.markAsDirty();
    this.formControl.markAsTouched();
    this.formControl.updateValueAndValidity();
  }
}
diff --git a/frontend/src/app/workspace/component/hugging-face/hugging-face.component.html b/frontend/src/app/workspace/component/hugging-face/hugging-face.component.html new file mode 100644 index 00000000000..833231f3ab7 --- /dev/null +++ b/frontend/src/app/workspace/component/hugging-face/hugging-face.component.html @@ -0,0 +1,196 @@ + +
+ + + + + + + + +
+ {{ tasksError }} + +
+ + + + + + + + + + + + + +
+ + Loading models... +
+ + +
+ {{ errorMessage }} + +
+ + +
+ +
+ Selected: + {{ formControl.value }} + +
+ + +
+ {{ isSearching ? 'No models found for "' + searchText + '".' : 'No models available.' }} +
+ + +
+ {{ model.id }} + + + + {{ model.downloads | number }} + + + + {{ model.likes | number }} + + +
+
+ + +
+ + Page {{ currentPage + 1 }} of {{ totalPages }} + +
+
+ + diff --git a/frontend/src/app/workspace/component/hugging-face/hugging-face.component.scss b/frontend/src/app/workspace/component/hugging-face/hugging-face.component.scss new file mode 100644 index 00000000000..70cf372296a --- /dev/null +++ b/frontend/src/app/workspace/component/hugging-face/hugging-face.component.scss @@ -0,0 +1,155 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +.hf-model-select-container { + width: 100%; +} + +.hf-section-label { + display: block; + font-size: 14px; + font-weight: normal; + color: rgba(0, 0, 0, 0.85); + line-height: 32px; + margin-top: 8px; + + .hf-required { + display: inline-block; + color: #ff4d4f; + font-size: 14px; + font-family: SimSun, sans-serif; + line-height: 1; + margin-right: 4px; + } +} + +.hf-loading { + display: flex; + align-items: center; + gap: 8px; + padding: 4px 0; + + .loading-text { + font-size: 12px; + color: #999; + } +} + +.hf-error { + display: flex; + align-items: center; + gap: 4px; + padding: 4px 0; + + .error-text { + font-size: 12px; + color: #ff4d4f; + } +} + +.hf-model-list { + border: 1px solid #d9d9d9; + border-radius: 4px; + max-height: 360px; + overflow-y: auto; +} + +.hf-selected-model { + display: flex; + align-items: center; + padding: 6px 10px; + background: #e6f7ff; + border-bottom: 1px solid #d9d9d9; + font-size: 12px; + + .hf-selected-label { + font-weight: 500; + margin-right: 6px; + color: rgba(0, 0, 0, 0.65); + } + + .hf-selected-value { + color: #1890ff; + font-weight: 500; + flex: 1; + overflow: hidden; + text-overflow: ellipsis; + white-space: nowrap; + } +} + +.hf-empty { + padding: 16px; + text-align: center; + color: #999; + font-size: 12px; +} + +.hf-model-item { + display: flex; + align-items: center; + justify-content: space-between; + padding: 6px 10px; + cursor: pointer; + border-bottom: 1px solid #f0f0f0; + transition: background 0.15s; + + &:last-child { + border-bottom: none; + } + + &:hover { + background: #fafafa; + } + + &.hf-model-item-selected { + background: #e6f7ff; + } + + .hf-model-id { + font-size: 12px; + color: rgba(0, 0, 0, 0.85); + flex: 1; + overflow: hidden; + text-overflow: ellipsis; + white-space: nowrap; + margin-right: 8px; + } + + .hf-model-meta { + font-size: 11px; + color: #999; + white-space: nowrap; + flex-shrink: 0; + } +} + +.hf-pagination { + display: flex; + align-items: center; + justify-content: center; + 
gap: 12px; + padding: 8px 0; + margin-top: 4px; + + .hf-page-info { + font-size: 12px; + color: rgba(0, 0, 0, 0.65); + } +} diff --git a/frontend/src/app/workspace/component/hugging-face/hugging-face.component.spec.ts b/frontend/src/app/workspace/component/hugging-face/hugging-face.component.spec.ts new file mode 100644 index 00000000000..c60d92bc19b --- /dev/null +++ b/frontend/src/app/workspace/component/hugging-face/hugging-face.component.spec.ts @@ -0,0 +1,73 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +import { + HuggingFaceComponent, + HuggingFaceModelOption, + STATIC_TASK_OPTIONS, + invalidateHuggingFaceModelCache, +} from "./hugging-face.component"; + +describe("HuggingFaceComponent (unit)", () => { + beforeEach(() => { + invalidateHuggingFaceModelCache(); + }); + + it("should export a non-empty static task list", () => { + expect(STATIC_TASK_OPTIONS.length).toBeGreaterThan(0); + }); + + it("should include text-generation in static task options", () => { + const textGen = STATIC_TASK_OPTIONS.find(t => t.tag === "text-generation"); + expect(textGen).toBeTruthy(); + expect(textGen!.label).toBe("Text Generation"); + }); + + it("should include image tasks in static task options", () => { + const imageTasks = STATIC_TASK_OPTIONS.filter(t => + ["image-classification", "object-detection", "image-segmentation", "image-to-text"].includes(t.tag) + ); + expect(imageTasks.length).toBe(4); + }); + + it("should include audio tasks in static task options", () => { + const audioTasks = STATIC_TASK_OPTIONS.filter(t => + ["automatic-speech-recognition", "audio-classification", "text-to-speech"].includes(t.tag) + ); + expect(audioTasks.length).toBe(3); + }); + + it("should include QA/ranking tasks in static task options", () => { + const qaTasks = STATIC_TASK_OPTIONS.filter(t => + ["question-answering", "zero-shot-classification", "sentence-similarity", "text-ranking"].includes(t.tag) + ); + expect(qaTasks.length).toBe(4); + }); + + it("should clear caches on invalidateHuggingFaceModelCache", () => { + // Just verify it doesn't throw — the function clears module-level Maps + expect(() => invalidateHuggingFaceModelCache()).not.toThrow(); + }); + + it("should have unique tags in static task options", () => { + const tags = STATIC_TASK_OPTIONS.map(t => t.tag); + const uniqueTags = new Set(tags); + expect(uniqueTags.size).toBe(tags.length); + }); +}); diff --git a/frontend/src/app/workspace/component/hugging-face/hugging-face.component.ts 
b/frontend/src/app/workspace/component/hugging-face/hugging-face.component.ts new file mode 100644 index 00000000000..a28cf907c20 --- /dev/null +++ b/frontend/src/app/workspace/component/hugging-face/hugging-face.component.ts @@ -0,0 +1,543 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +import { Component, OnInit, OnDestroy, ChangeDetectorRef } from "@angular/core"; +import { CommonModule } from "@angular/common"; +import { FormsModule } from "@angular/forms"; +import { FieldType, FieldTypeConfig, FormlyModule } from "@ngx-formly/core"; +import { HttpClient } from "@angular/common/http"; +import { NzSelectModule } from "ng-zorro-antd/select"; +import { NzInputModule } from "ng-zorro-antd/input"; +import { NzSpinModule } from "ng-zorro-antd/spin"; +import { NzButtonModule } from "ng-zorro-antd/button"; +import { NzIconModule } from "ng-zorro-antd/icon"; +import { AppSettings } from "../../../common/app-setting"; +import { Subject, Subscription } from "rxjs"; +import { takeUntil } from "rxjs/operators"; + +export interface HuggingFaceModelOption { + id: string; + label: string; + pipeline_tag?: string; + downloads?: number; + likes?: number; +} + +export interface HuggingFaceTaskOption { + tag: string; + label: string; +} + +// ── Static fallback task list (used when the dynamic fetch fails) ── +export const STATIC_TASK_OPTIONS: HuggingFaceTaskOption[] = [ + { tag: "text-generation", label: "Text Generation" }, + { tag: "automatic-speech-recognition", label: "Automatic Speech Recognition" }, + { tag: "audio-classification", label: "Audio Classification" }, + { tag: "text-classification", label: "Text Classification" }, + { tag: "text-to-speech", label: "Text to Speech" }, + { tag: "token-classification", label: "Token Classification" }, + { tag: "question-answering", label: "Question Answering" }, + { tag: "table-question-answering", label: "Table Question Answering" }, + { tag: "zero-shot-classification", label: "Zero-Shot Classification" }, + { tag: "translation", label: "Translation" }, + { tag: "summarization", label: "Summarization" }, + { tag: "feature-extraction", label: "Feature Extraction" }, + { tag: "fill-mask", label: "Fill-Mask" }, + { tag: "sentence-similarity", label: "Sentence Similarity" }, + { tag: "text-ranking", label: 
"Text Ranking" }, + { tag: "image-classification", label: "Image Classification" }, + { tag: "object-detection", label: "Object Detection" }, + { tag: "image-segmentation", label: "Image Segmentation" }, + { tag: "image-to-text", label: "Image to Text" }, + { tag: "visual-question-answering", label: "Visual Question Answering" }, + { tag: "document-question-answering", label: "Document Question Answering" }, + { tag: "zero-shot-image-classification", label: "Zero-Shot Image Classification" }, +]; + +// Keep legacy export for any other code that imports it +export const TASK_TAG_MAP: Record = {}; +for (const { tag, label } of STATIC_TASK_OPTIONS) { + TASK_TAG_MAP[label] = tag; +} +export const TASK_NAMES = STATIC_TASK_OPTIONS.map(t => t.label); + +const PAGE_SIZE = 50; + +// ── Module-level caches (reused across component instances) ── +const allModelsByTag: Map = new Map(); +const inFlightByTag: Map = new Map(); +const errorByTag: Map = new Map(); + +let cachedTaskOptions: HuggingFaceTaskOption[] | null = null; +let tasksFetchSubscription: Subscription | null = null; +let tasksFetchError: string | null = null; + +/** Clear all cached data (useful for tests or manual invalidation). 
*/ +export function invalidateHuggingFaceModelCache(): void { + allModelsByTag.clear(); + errorByTag.clear(); + inFlightByTag.forEach(sub => sub.unsubscribe()); + inFlightByTag.clear(); + cachedTaskOptions = null; + tasksFetchError = null; + tasksFetchSubscription?.unsubscribe(); + tasksFetchSubscription = null; +} + +@Component({ + selector: "texera-hugging-face-model-select", + templateUrl: "./hugging-face.component.html", + styleUrls: ["hugging-face.component.scss"], + imports: [ + CommonModule, + FormsModule, + NzSelectModule, + NzInputModule, + NzSpinModule, + NzButtonModule, + NzIconModule, + FormlyModule, + ], +}) +export class HuggingFaceComponent extends FieldType implements OnInit, OnDestroy { + private readonly taskScopedKeys = [ + "modelId", + "promptColumn", + "imageInput", + "audioInput", + "inputImageColumn", + "inputAudioColumn", + "candidateLabels", + "sentencesColumn", + "contextColumn", + "systemPrompt", + "maxNewTokens", + "temperature", + ] as const; + private readonly taskStateByTag = new Map>>(); + // ── Task state ── + taskOptions: HuggingFaceTaskOption[] = cachedTaskOptions ?? STATIC_TASK_OPTIONS; + selectedTaskTag = "text-generation"; + tasksLoading = false; + tasksError: string | null = null; + + // ── All models for the current task (fetched once from backend, cached) ── + private allModels: HuggingFaceModelOption[] = []; + + // ── Displayed state ── + pagedModels: HuggingFaceModelOption[] = []; + currentPage = 0; + totalPages = 0; + + loading = false; + errorMessage: string | null = null; + + // ── Search state (client-side filtering over ALL models) ── + searchText = ""; + private filteredModels: HuggingFaceModelOption[] | null = null; + + private readonly destroy$ = new Subject(); + private subscription: Subscription | null = null; + + constructor( + private http: HttpClient, + private cdr: ChangeDetectorRef + ) { + super(); + } + + ngOnInit(): void { + const savedTag = this.getCurrentTaskTag(); + this.selectedTaskTag = savedTag ?? 
this.selectedTaskTag;
+    this.syncTaskSelection(this.selectedTaskTag, false);
+    this.loadTasks();
+    this.loadAllModels();
+    // Formly can attach sibling controls after this field initializes.
+    // Re-sync once the control tree settles so a fresh operator starts in a valid task state.
+    setTimeout(() => this.syncTaskSelection(this.getCurrentTaskTag() ?? this.selectedTaskTag, false), 0);
+  }
+
+  /**
+   * Signals `destroy$` (which ends every stream piped through `takeUntil(this.destroy$)`)
+   * and drops the model-list request held in `subscription`, if any.
+   */
+  ngOnDestroy(): void {
+    this.destroy$.next();
+    this.destroy$.complete();
+    this.subscription?.unsubscribe();
+  }
+
+  // ── Task loading ──
+
+  /**
+   * Fetch available pipeline tags from the backend, which proxies HuggingFace's /api/tasks.
+   * Falls back to STATIC_TASK_OPTIONS if the fetch fails.
+   *
+   * The result (or the error) is kept in module-level state so it is shared by every
+   * instance of this component:
+   *  - `cachedTaskOptions` set      → reuse it, no request.
+   *  - `tasksFetchError` set        → show the static list, no automatic retry.
+   *  - `tasksFetchSubscription` set → another instance is fetching; poll until it finishes.
+   *  - otherwise                    → this instance issues the request.
+   */
+  private loadTasks(): void {
+    // Already fetched and cached
+    if (cachedTaskOptions !== null) {
+      this.taskOptions = cachedTaskOptions;
+      return;
+    }
+
+    // Previous fetch errored — show static list, don't retry automatically
+    if (tasksFetchError !== null) {
+      this.tasksError = tasksFetchError;
+      this.taskOptions = STATIC_TASK_OPTIONS;
+      return;
+    }
+
+    // Another component instance already has a fetch in flight — wait for it
+    if (tasksFetchSubscription !== null) {
+      this.tasksLoading = true;
+      // Poll for completion (the module-level cache will be set when done)
+      const poll = setInterval(() => {
+        if (cachedTaskOptions !== null || tasksFetchError !== null) {
+          clearInterval(poll);
+          destroyWatcher.unsubscribe();
+          this.tasksLoading = false;
+          this.taskOptions = cachedTaskOptions ?? STATIC_TASK_OPTIONS;
+          if (tasksFetchError) this.tasksError = tasksFetchError;
+          this.cdr.detectChanges();
+        } else if (tasksFetchSubscription === null) {
+          // The instance that owned the request went away before a response arrived
+          // (its takeUntil completed the stream without a result). Take the fetch over
+          // instead of waiting forever.
+          clearInterval(poll);
+          destroyWatcher.unsubscribe();
+          this.loadTasks();
+        }
+      }, 200);
+      // Stop polling when this instance is destroyed; otherwise the timer would keep
+      // running and eventually call detectChanges() on a destroyed view.
+      const destroyWatcher = this.destroy$.subscribe(() => clearInterval(poll));
+      return;
+    }
+
+    this.tasksLoading = true;
+    this.tasksError = null;
+    this.cdr.detectChanges();
+
+    // NOTE(review): the response type argument was lost in this copy of the source; it is
+    // written as `typeof STATIC_TASK_OPTIONS` because the result is stored in the same
+    // variables as that constant — confirm against the declared task-option type.
+    const request = this.http
+      .get<typeof STATIC_TASK_OPTIONS>(`${AppSettings.getApiEndpoint()}/huggingface/tasks`)
+      .pipe(takeUntil(this.destroy$))
+      .subscribe({
+        next: tasks => {
+          tasksFetchSubscription = null;
+          cachedTaskOptions = tasks.length > 0 ? tasks : STATIC_TASK_OPTIONS;
+          this.taskOptions = cachedTaskOptions;
+          this.tasksLoading = false;
+          this.cdr.detectChanges();
+        },
+        error: (err: unknown) => {
+          console.error("Failed to load HuggingFace tasks:", err);
+          tasksFetchSubscription = null;
+          tasksFetchError = "Could not load tasks from Hugging Face. Using default list.";
+          this.tasksError = tasksFetchError;
+          this.taskOptions = STATIC_TASK_OPTIONS;
+          this.tasksLoading = false;
+          this.cdr.detectChanges();
+        },
+        complete: () => {
+          // Also reached when takeUntil(destroy$) ends the stream before any response:
+          // clear the shared marker so other instances are not left waiting on a
+          // request that no longer exists.
+          tasksFetchSubscription = null;
+        },
+      });
+    // If the stream already finished synchronously the handlers above have run; do not
+    // publish a finished subscription as "in flight".
+    tasksFetchSubscription = request.closed ? null : request;
+  }
+
+  /** Clears the remembered task-fetch error (shared and local) and runs `loadTasks()` again. */
+  retryTasksLoad(): void {
+    tasksFetchError = null;
+    this.tasksError = null;
+    this.loadTasks();
+  }
+
+  // ── Task selection ──
+
+  /**
+   * Switches the operator to task `tag`.
+   *
+   * Order matters: the outgoing task's field values are snapshotted first, then the new
+   * task is applied (resetting task-scoped fields if `tag` has no snapshot yet), then any
+   * earlier snapshot for `tag` is written back. Finally the search box is cleared and the
+   * model list for `tag` is loaded.
+   */
+  onTaskSelected(tag: string): void {
+    const previousTask = this.getCurrentTaskTag() ?? this.selectedTaskTag;
+    this.snapshotTaskState(previousTask);
+    this.syncTaskSelection(tag, true);
+    this.restoreTaskState(tag);
+    this.searchText = "";
+    this.filteredModels = null;
+    this.loadAllModels();
+  }
+
+  // ── Data loading ──
+
+  /**
+   * Fetch ALL models for the selected task.
+   * The backend paginates through HF Hub internally and caches the result.
+   * The first request per task may be slow; subsequent requests are instant.
+   *
+   * Results and errors are remembered per task tag in `allModelsByTag` / `errorByTag`,
+   * so a tag is requested at most once until `retryLoad()` clears its error.
+   */
+  private loadAllModels(): void {
+    const tag = this.selectedTaskTag || "text-generation";
+
+    // Cancel the request of a previously selected task BEFORE taking any fast path below.
+    // Otherwise its late response would replace the list shown for `tag`.
+    const previous = this.subscription;
+    if (previous) {
+      previous.unsubscribe();
+      // A cancelled request is no longer in flight; drop its bookkeeping entry.
+      inFlightByTag.forEach((pending, pendingTag) => {
+        if (pending === previous) {
+          inFlightByTag.delete(pendingTag);
+        }
+      });
+      this.subscription = null;
+    }
+
+    this.loading = false;
+    this.errorMessage = null;
+
+    // Fast path: cached on the frontend
+    if (allModelsByTag.has(tag)) {
+      this.showModels(allModelsByTag.get(tag)!);
+      return;
+    }
+
+    // Previous error
+    if (errorByTag.has(tag)) {
+      this.errorMessage = errorByTag.get(tag)!;
+      this.allModels = [];
+      this.pagedModels = [];
+      this.totalPages = 0;
+      // Every other exit of this method triggers change detection; do the same here so
+      // the remembered error is rendered.
+      this.cdr.detectChanges();
+      return;
+    }
+
+    this.allModels = [];
+    this.pagedModels = [];
+    this.totalPages = 0;
+
+    // Show spinner immediately for the initial fetch — it can take a while
+    // as the backend pages through HF Hub for the first time.
+    this.loading = true;
+    this.cdr.detectChanges();
+
+    const request = this.http
+      .get<HuggingFaceModelOption[]>(
+        `${AppSettings.getApiEndpoint()}/huggingface/models?task=${encodeURIComponent(tag)}`
+      )
+      .pipe(takeUntil(this.destroy$))
+      .subscribe({
+        next: models => {
+          allModelsByTag.set(tag, models);
+          inFlightByTag.delete(tag);
+          this.loading = false;
+          this.showModels(models);
+        },
+        error: (err: unknown) => {
+          console.error(`Failed to load HuggingFace models for task '${tag}':`, err);
+          const msg = "Failed to load models. Click retry to try again.";
+          errorByTag.set(tag, msg);
+          inFlightByTag.delete(tag);
+          this.loading = false;
+          this.errorMessage = msg;
+          this.cdr.detectChanges();
+        },
+      });
+
+    this.subscription = request;
+    // Only register the request while it is still running; if it finished synchronously
+    // the handlers above already removed the entry and it must not be re-added.
+    if (!request.closed) {
+      inFlightByTag.set(tag, request);
+    }
+  }
+
+  /**
+   * Makes `models` the full list and shows its first page, re-applying the current search
+   * text so a query typed before the list arrived filters the new data instead of leaving
+   * an empty result behind.
+   */
+  private showModels(models: HuggingFaceModelOption[]): void {
+    this.allModels = models;
+    this.onSearchInput(this.searchText ?? "");
+  }
+
+  // ── Pagination (client-side over the active list) ──
+
+  /** The list being paged: the search result while a search is active, otherwise all models. */
+  private get activeList(): HuggingFaceModelOption[] {
+    return this.filteredModels !== null ? this.filteredModels : this.allModels;
+  }
+
+  /**
+   * Shows page `page` (0-based) of the active list and triggers change detection.
+   * `totalPages` is recomputed (never below 1) and the requested page is clamped into
+   * `[0, totalPages - 1]`; a non-integer request falls back to page 0.
+   */
+  goToPage(page: number): void {
+    const list = this.activeList;
+    this.totalPages = Math.max(1, Math.ceil(list.length / PAGE_SIZE));
+    const requested = Number.isInteger(page) ? page : 0;
+    this.currentPage = Math.max(0, Math.min(requested, this.totalPages - 1));
+    const start = this.currentPage * PAGE_SIZE;
+    this.pagedModels = list.slice(start, start + PAGE_SIZE);
+    this.cdr.detectChanges();
+  }
+
+  /** Moves one page back; does nothing on the first page. */
+  prevPage(): void {
+    if (this.currentPage > 0) {
+      this.goToPage(this.currentPage - 1);
+    }
+  }
+
+  /** Moves one page forward; does nothing on the last page. */
+  nextPage(): void {
+    if (this.currentPage < this.totalPages - 1) {
+      this.goToPage(this.currentPage + 1);
+    }
+  }
+
+  /** True while there is a page after the current one. */
+  get hasNextPage(): boolean {
+    return this.currentPage < this.totalPages - 1;
+  }
+
+  /** Forgets the remembered load error for the current task and requests its models again. */
+  retryLoad(): void {
+    const tag = this.selectedTaskTag || "text-generation";
+    errorByTag.delete(tag);
+    this.loadAllModels();
+  }
+
+  // ── Search (client-side filter over ALL cached models) ──
+
+  /**
+   * Filters `allModels` by case-insensitive substring match on the model id and shows the
+   * first page of the result. A blank query clears the filter. Surrounding whitespace is
+   * ignored for matching, consistent with how blankness is decided.
+   */
+  onSearchInput(query: string): void {
+    this.searchText = query;
+    const needle = query.trim().toLowerCase();
+    if (!needle) {
+      this.filteredModels = null;
+    } else {
+      this.filteredModels = this.allModels.filter(m => m.id.toLowerCase().includes(needle));
+    }
+    this.goToPage(0);
+  }
+
+  /** Empties the search box, removes the filter and returns to the first page. */
+  clearSearch(): void {
+    this.searchText = "";
+    this.filteredModels = null;
+    this.goToPage(0);
+  }
+
+  /** True while a search filter is applied. */
+  get isSearching(): boolean {
+    return this.filteredModels !== null;
+  }
+
+  // ── Model selection ──
+
+  /** Writes the chosen model id into this field's form control. */
+  onModelSelected(modelId: string): void {
+    this.formControl.setValue(modelId);
+  }
+
+  // ── Private helpers ──
+
+  /**
+   * Returns the first non-blank task tag found, looking at the backing model, then the
+   * "task" control of this control's parent, then the "task" control of the field's form.
+   * Returns `undefined` when none of them holds a non-blank string.
+   */
+  private getCurrentTaskTag(): string | undefined {
+    const fromModel = this.model?.task;
+    if (typeof fromModel === "string" && fromModel.trim().length > 0) {
+      return fromModel;
+    }
+    const fromParentControl = this.formControl?.parent?.get("task")?.value;
+    if (typeof fromParentControl === "string" && fromParentControl.trim().length > 0) {
+      return fromParentControl;
+    }
+    const fromFieldForm = this.field.form?.get("task")?.value;
+    if (typeof fromFieldForm === "string" && fromFieldForm.trim().length > 0) {
+      return fromFieldForm;
+    }
+    return undefined;
+  }
+
+  /**
+   * Stores `tag` as the operator's task in the backing model and in the "task" form
+   * control(s), then asks formly to re-evaluate field expressions.
+   */
+  private persistTaskSelection(tag: string): void {
+    // 1. Update the backing model FIRST so expression functions read the new value.
+    if (this.model) {
+      this.model.task = tag;
+    }
+
+    // 2. Update the hidden task form control. Using emitEvent: true (default)
+    //    ensures formly picks up the change and re-evaluates all sibling expressions.
+    const taskControlFromField = this.field.form?.get("task");
+    if (taskControlFromField) {
+      taskControlFromField.setValue(tag);
+    }
+
+    // Only write through the parent lookup when it resolves to a different control,
+    // so the same control is not set twice.
+    const taskControlFromParent = this.formControl?.parent?.get("task");
+    if (taskControlFromParent && taskControlFromParent !== taskControlFromField) {
+      taskControlFromParent.setValue(tag);
+    }
+
+    // 3. Force formly to re-evaluate ALL field expressions (not just this field's subtree).
+    //    this.field is the modelId field; its parent covers all sibling fields.
+    const rootField = this.field.parent ?? this.field;
+    this.field.options?.detectChanges?.(rootField);
+  }
+
+  /**
+   * Makes `tag` the selected task: remembers it locally, optionally resets the
+   * task-scoped fields (only for a tag without a snapshot), persists it to the
+   * model/form, and refreshes validity of the task-scoped controls.
+   */
+  private syncTaskSelection(tag: string, resetTaskSpecificFields: boolean): void {
+    this.selectedTaskTag = tag;
+    if (resetTaskSpecificFields) {
+      this.resetTaskStateForFirstVisit(tag);
+    }
+    this.persistTaskSelection(tag);
+    this.refreshTaskScopedValidity();
+  }
+
+  /**
+   * Re-runs validation on each listed control and on the containing form(s) without
+   * emitting events, then emits once from the parent group.
+   */
+  private refreshTaskScopedValidity(): void {
+    const keys = [
+      "task",
+      "modelId",
+      "promptColumn",
+      "imageInput",
+      "audioInput",
+      "inputImageColumn",
+      "inputAudioColumn",
+      "candidateLabels",
+      "sentencesColumn",
+      "contextColumn",
+      "systemPrompt",
+      "maxNewTokens",
+      "temperature",
+    ];
+    for (const key of keys) {
+      const control = this.field.form?.get(key) ?? this.formControl?.parent?.get(key);
+      control?.updateValueAndValidity({ emitEvent: false });
+    }
+    this.field.form?.updateValueAndValidity({ emitEvent: false });
+    this.formControl?.parent?.updateValueAndValidity({ emitEvent: false });
+
+    // Emit a single value change after all fields are settled so the
+    // workflow action service picks up the new operator properties.
+    this.formControl?.parent?.updateValueAndValidity({ emitEvent: true });
+  }
+
+  /**
+   * Remembers the current value of every key in `taskScopedKeys` under `tag`, replacing
+   * any earlier snapshot for that tag. An empty tag is ignored.
+   */
+  private snapshotTaskState(tag: string): void {
+    if (!tag) {
+      return;
+    }
+    // NOTE(review): the type arguments of this annotation were lost in this copy of the
+    // source; reconstructed from the key type used by readFieldValue/writeFieldValue.
+    const snapshot: Partial<Record<(typeof this.taskScopedKeys)[number], unknown>> = {};
+    for (const key of this.taskScopedKeys) {
+      snapshot[key] = this.readFieldValue(key);
+    }
+    this.taskStateByTag.set(tag, snapshot);
+  }
+
+  /**
+   * Writes the snapshot stored for `tag` back into the fields and refreshes validity.
+   * Does nothing when no snapshot exists for `tag`.
+   */
+  private restoreTaskState(tag: string): void {
+    const snapshot = this.taskStateByTag.get(tag);
+    if (!snapshot) {
+      return;
+    }
+    for (const key of this.taskScopedKeys) {
+      // Only restore keys the snapshot actually recorded.
+      if (Object.prototype.hasOwnProperty.call(snapshot, key)) {
+        this.writeFieldValue(key, snapshot[key]);
+      }
+    }
+    this.refreshTaskScopedValidity();
+  }
+
+  /**
+   * Resets every key in `taskScopedKeys` to its default when `tag` has no snapshot yet.
+   * Keys without an entry in `defaults` are reset to the empty string. A tag that already
+   * has a snapshot is left untouched.
+   */
+  private resetTaskStateForFirstVisit(tag: string): void {
+    if (this.taskStateByTag.has(tag)) {
+      return;
+    }
+    // NOTE(review): type arguments reconstructed, see snapshotTaskState.
+    const defaults: Partial<Record<(typeof this.taskScopedKeys)[number], unknown>> = {
+      modelId: "",
+      promptColumn: "",
+      imageInput: "",
+      audioInput: "",
+      inputImageColumn: "",
+      inputAudioColumn: "",
+      candidateLabels: "",
+      sentencesColumn: "",
+      contextColumn: "",
+      systemPrompt: "You are a helpful assistant.",
+      maxNewTokens: 256,
+      temperature: 0.7,
+    };
+    for (const key of this.taskScopedKeys) {
+      this.writeFieldValue(key, defaults[key] ?? "");
+    }
+  }
+
+  /**
+   * Reads the value for `key` from its form control (field form first, then the parent
+   * group); falls back to the backing model when no control exists.
+   */
+  private readFieldValue(key: (typeof this.taskScopedKeys)[number]): unknown {
+    const control = this.field.form?.get(key) ?? this.formControl?.parent?.get(key);
+    if (control) {
+      return control.value;
+    }
+    return this.model?.[key];
+  }
+
+  /**
+   * Writes `value` for `key` into its form control (without emitting events, marking it
+   * dirty and re-validating it) and into the backing model, whichever of the two exist.
+   */
+  private writeFieldValue(key: (typeof this.taskScopedKeys)[number], value: unknown): void {
+    const control = this.field.form?.get(key) ?? this.formControl?.parent?.get(key);
+    if (control) {
+      control.setValue(value, { emitEvent: false });
+      control.markAsDirty();
+      control.updateValueAndValidity({ emitEvent: false });
+    }
+    if (this.model) {
+      // NOTE(review): cast type arguments reconstructed as <string, unknown> — confirm.
+      (this.model as Record<string, unknown>)[key] = value;
+    }
+  }
+}
diff --git a/frontend/src/assets/operator_images/HuggingFace.png b/frontend/src/assets/operator_images/HuggingFace.png
new file mode 100644
index 00000000000..673b8ea9077
Binary files /dev/null and b/frontend/src/assets/operator_images/HuggingFace.png differ
diff --git a/frontend/src/assets/sample-image.png b/frontend/src/assets/sample-image.png
new file mode 100644
index 00000000000..c28d120ab74
Binary files /dev/null and b/frontend/src/assets/sample-image.png differ