From d4ce4ba373ef0e63e6620e6ca7cae90309aa3e69 Mon Sep 17 00:00:00 2001
From: Derrick Williams
Date: Wed, 8 Apr 2026 18:25:31 +0000
Subject: [PATCH 1/3] change runinference yaml name to vertexai

---
 .../yaml/tests/{runinference.yaml => runinference_vertexai.yaml}  | 0
 1 file changed, 0 insertions(+), 0 deletions(-)
 rename sdks/python/apache_beam/yaml/tests/{runinference.yaml => runinference_vertexai.yaml} (100%)

diff --git a/sdks/python/apache_beam/yaml/tests/runinference.yaml b/sdks/python/apache_beam/yaml/tests/runinference_vertexai.yaml
similarity index 100%
rename from sdks/python/apache_beam/yaml/tests/runinference.yaml
rename to sdks/python/apache_beam/yaml/tests/runinference_vertexai.yaml

From f10c5eeb1d5cf8077431ae34e944b15fa58e4a16 Mon Sep 17 00:00:00 2001
From: Derrick Williams
Date: Wed, 8 Apr 2026 18:26:05 +0000
Subject: [PATCH 2/3] add huggingface support

---
 sdks/python/apache_beam/yaml/yaml_ml.py | 46 +++++++++++++++++++++++++
 1 file changed, 46 insertions(+)

diff --git a/sdks/python/apache_beam/yaml/yaml_ml.py b/sdks/python/apache_beam/yaml/yaml_ml.py
index 51f18c733046..88ac2b52d3bc 100644
--- a/sdks/python/apache_beam/yaml/yaml_ml.py
+++ b/sdks/python/apache_beam/yaml/yaml_ml.py
@@ -282,6 +282,52 @@ def inference_output_type(self):
         ('model_id', Optional[str])])
 
 
+@ModelHandlerProvider.register_handler_type('HuggingFacePipeline')
+class HuggingFacePipelineProvider(ModelHandlerProvider):
+  def __init__(
+      self,
+      task: str = "",
+      model: str = "",
+      preprocess: Optional[dict[str, str]] = None,
+      postprocess: Optional[dict[str, str]] = None,
+      device: Optional[str] = None,
+      inference_fn: Optional[dict[str, str]] = None,
+      load_pipeline_args: Optional[dict[str, Any]] = None,
+      **kwargs):
+    try:
+      from apache_beam.ml.inference.huggingface_inference import HuggingFacePipelineModelHandler
+    except ImportError:
+      raise ValueError(
+          'Unable to import HuggingFacePipelineModelHandler. Please '
+          'install transformers dependencies.')
+
+    kwargs = {k: v for k, v in kwargs.items() if not k.startswith('_')}
+
+    inference_fn_obj = self.parse_processing_transform(
+        inference_fn, 'inference_fn') if inference_fn else None
+
+    handler_kwargs = {}
+    if inference_fn_obj:
+      handler_kwargs['inference_fn'] = inference_fn_obj
+
+    _handler = HuggingFacePipelineModelHandler(
+        task=task,
+        model=model,
+        device=device,
+        load_pipeline_args=load_pipeline_args,
+        **handler_kwargs,
+        **kwargs)
+
+    super().__init__(_handler, preprocess, postprocess)
+
+  @staticmethod
+  def validate(model_handler_spec):
+    pass
+
+  def inference_output_type(self):
+    return Any
+
+
 @beam.ptransform.ptransform_fn
 def run_inference(
     pcoll,

From 1e9bf433d37d0dae249336d6b9bc4514e2f2587f Mon Sep 17 00:00:00 2001
From: Derrick Williams
Date: Wed, 8 Apr 2026 18:27:20 +0000
Subject: [PATCH 3/3] add huggingface test

---
 .../yaml/tests/runinference_huggingface.yaml  | 62 +++++++++++++++++++
 1 file changed, 62 insertions(+)
 create mode 100644 sdks/python/apache_beam/yaml/tests/runinference_huggingface.yaml

diff --git a/sdks/python/apache_beam/yaml/tests/runinference_huggingface.yaml b/sdks/python/apache_beam/yaml/tests/runinference_huggingface.yaml
new file mode 100644
index 000000000000..7c429e6067a1
--- /dev/null
+++ b/sdks/python/apache_beam/yaml/tests/runinference_huggingface.yaml
@@ -0,0 +1,62 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+fixtures:
+  - name: mock_pipeline
+    type: unittest.mock.patch
+    config:
+      target: transformers.pipeline
+
+pipelines:
+  - pipeline:
+      type: chain
+      transforms:
+        - type: Create
+          config:
+            elements:
+              - text: "Hello world"
+              - text: "Bye world"
+        - type: RunInference
+          config:
+            model_handler:
+              type: "HuggingFacePipeline"
+              config:
+                task: "text-classification"
+                model: "unused"
+                inference_fn:
+                  callable: |
+                    def mock_inference(batch, pipeline, inference_args):
+                      return [[dict(label='POSITIVE', score=0.9)] for _ in batch]
+                preprocess:
+                  callable: 'lambda x: x.text'
+        - type: MapToFields
+          config:
+            language: python
+            fields:
+              text: text
+              inference:
+                callable: |
+                  def get_json(x):
+                    import json
+                    return json.dumps(x.inference.inference, indent=0).strip()
+        - type: AssertEqual
+          config:
+            elements:
+              - text: "Hello world"
+                inference: "[\n{\n\"label\": \"POSITIVE\",\n\"score\": 0.9\n}\n]"
+              - text: "Bye world"
+                inference: "[\n{\n\"label\": \"POSITIVE\",\n\"score\": 0.9\n}\n]"
+    options:
+      yaml_experimental_features: ['ML']