86 changes: 86 additions & 0 deletions flink-python/pyflink/datastream/data_stream.py
@@ -455,6 +455,92 @@ def process_element(self, value, ctx: 'ProcessFunction.Context'):
        return self.process(FilterProcessFunctionAdapter(func), output_type=output_type) \
            .name("Filter")

    def infer(
            self,
            model: str,
            input_col: str,
            output_col: str,
            batch_size: int = 32,
            max_batch_timeout_ms: int = 100,
            model_warmup: bool = True,
            device: str = "cpu",
            num_workers: int = 1,
            task_type: str = "embedding",
            **kwargs
    ) -> 'DataStream':
        """
        Applies AI model inference on a DataStream.

        This method provides an easy-to-use interface for performing machine learning
        inference on streaming data. It handles model lifecycle management, batching, and
        resource optimization automatically.

        Example::

            >>> # Text embedding
            >>> embeddings = data_stream.infer(
            ...     model="sentence-transformers/all-MiniLM-L6-v2",
            ...     input_col="text",
            ...     output_col="embedding",
            ...     device="cuda:0"
            ... )
            >>>
            >>> # Sentiment classification
            >>> sentiments = data_stream.infer(
            ...     model="distilbert-base-uncased-finetuned-sst-2-english",
            ...     input_col="text",
            ...     output_col="sentiment",
            ...     task_type="classification"
            ... )

        :param model: Model name (HuggingFace) or local path. Examples:
                      - "sentence-transformers/all-MiniLM-L6-v2"
                      - "bert-base-uncased"
                      - "/path/to/local/model"
        :param input_col: Name of the input column to perform inference on.
        :param output_col: Name of the output column to store inference results.
        :param batch_size: Number of records to batch together for inference (default: 32).
        :param max_batch_timeout_ms: Maximum time to wait for a batch to fill, in
                                     milliseconds (default: 100).
        :param model_warmup: Whether to warm up the model on initialization (default: True).
        :param device: Device for inference: "cpu", "cuda:0", "cuda:1", etc. (default: "cpu").
        :param num_workers: Number of Python worker processes (default: 1).
        :param task_type: Type of inference task: "embedding", "classification" or
                          "generation" (default: "embedding").
        :param kwargs: Additional configuration options.
        :return: A new DataStream containing the inference results.

        .. note::
            This feature requires the following Python packages to be installed:

            - torch>=2.0.0
            - transformers>=4.30.0

            For GPU inference, CUDA must be properly configured.

        .. versionadded:: 1.20
        """
        from pyflink.ml.inference import InferenceConfig, InferenceFunction

        # Create inference configuration
        config = InferenceConfig(
            model=model,
            input_col=input_col,
            output_col=output_col,
            batch_size=batch_size,
            max_batch_timeout_ms=max_batch_timeout_ms,
            warmup_enabled=model_warmup,
            device=device,
            num_workers=num_workers,
            task_type=task_type,
            **kwargs
        )

        # Create inference function
        inference_func = InferenceFunction(config)

        # Apply inference. The result reuses this stream's type information, which
        # assumes the inference function emits records compatible with the input type.
        return self.map(inference_func, output_type=self.get_type()) \
            .name(f"Inference[{model}]")

    def window_all(self, window_assigner: WindowAssigner) -> 'AllWindowedStream':
        """
        Windows this data stream to an AllWindowedStream, which evaluates windows over a non key
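The InferenceConfig and InferenceFunction classes used above come from the new
pyflink.ml.inference package; their module files are not part of this diff excerpt.
Purely as orientation (the HuggingFace pipeline call and every name below are
assumptions, not the PR's actual implementation), a function of this shape could be
written as an ordinary MapFunction that loads the model once per worker:

    from pyflink.datastream.functions import MapFunction


    class InferenceFunctionSketch(MapFunction):
        """Hypothetical sketch: load the model once per worker, then map records."""

        def __init__(self, config):
            self._config = config    # assumed to carry the fields built in infer()
            self._pipeline = None    # created lazily in open(); never pickled

        def open(self, runtime_context):
            # Deferred import so the job graph can be built without torch installed.
            from transformers import pipeline
            self._pipeline = pipeline(
                "feature-extraction",  # assumed mapping for task_type="embedding"
                model=self._config.model,
                device=self._config.device)

        def map(self, value):
            # Attach the model output to the record under the configured column.
            value[self._config.output_col] = self._pipeline(value[self._config.input_col])
            return value

Loading the model in open() rather than __init__ matters because the function object is
pickled and shipped to each worker. The batching promised by batch_size and
max_batch_timeout_ms would live in the real BatchInferenceExecutor, not in this
per-record sketch.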
206 changes: 206 additions & 0 deletions flink-python/pyflink/examples/datastream/text_embedding_example.py
@@ -0,0 +1,206 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

################################################################################
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################

"""
Example: Text Embedding with PyFlink AI Inference

This example demonstrates how to use the infer() method to generate text embeddings
using a pre-trained model from HuggingFace.
"""

from pyflink.datastream import StreamExecutionEnvironment


def text_embedding_example():
    """Example of generating text embeddings."""

    # Create execution environment
    env = StreamExecutionEnvironment.get_execution_environment()
    env.set_parallelism(1)

    # Sample text data
    text_data = [
        {"id": 1, "text": "Apache Flink is a framework for stateful computations"},
        {"id": 2, "text": "PyFlink brings Python support to Apache Flink"},
        {"id": 3, "text": "Machine learning inference in streaming applications"},
        {"id": 4, "text": "Real-time AI model serving with Flink"},
        {"id": 5, "text": "Distributed machine learning at scale"}
    ]

    # Create data stream
    data_stream = env.from_collection(text_data)

    # Apply inference - generate embeddings
    # Using a small, fast model for demonstration
    embeddings = data_stream.infer(
        model="sentence-transformers/all-MiniLM-L6-v2",
        input_col="text",
        output_col="embedding",
        batch_size=2,
        device="cpu",
        model_warmup=True
    )

    # Print results
    embeddings.print()

    # Execute
    env.execute("Text Embedding Example")


def sentiment_classification_example():
    """Example of sentiment classification."""

    env = StreamExecutionEnvironment.get_execution_environment()
    env.set_parallelism(1)

    # Sample reviews
    reviews = [
        {"id": 1, "review": "This product is amazing! I love it!"},
        {"id": 2, "review": "Terrible quality, very disappointed."},
        {"id": 3, "review": "It's okay, nothing special."},
        {"id": 4, "review": "Exceeded my expectations! Highly recommended."},
        {"id": 5, "review": "Waste of money, do not buy."}
    ]

    data_stream = env.from_collection(reviews)

    # Apply sentiment classification
    sentiments = data_stream.infer(
        model="distilbert-base-uncased-finetuned-sst-2-english",
        input_col="review",
        output_col="sentiment",
        task_type="classification",
        batch_size=2,
        device="cpu"
    )

    sentiments.print()

    env.execute("Sentiment Classification Example")


def kafka_realtime_inference_example():
    """Example of real-time inference with Kafka."""
    from pyflink.datastream.connectors.kafka import FlinkKafkaConsumer, FlinkKafkaProducer
    from pyflink.common.serialization import SimpleStringSchema
    import json

    env = StreamExecutionEnvironment.get_execution_environment()
    env.set_parallelism(1)

    # Kafka consumer configuration
    kafka_props = {
        'bootstrap.servers': 'localhost:9092',
        'group.id': 'flink-inference-group'
    }

    # Create Kafka consumer
    kafka_consumer = FlinkKafkaConsumer(
        topics='input-texts',
        deserialization_schema=SimpleStringSchema(),
        properties=kafka_props
    )

    # Read from Kafka
    kafka_stream = env.add_source(kafka_consumer)

    # Parse JSON
    parsed_stream = kafka_stream.map(lambda x: json.loads(x))

    # Apply inference
    inference_result = parsed_stream.infer(
        model="sentence-transformers/all-MiniLM-L6-v2",
        input_col="text",
        output_col="embedding",
        batch_size=32,
        max_batch_timeout_ms=50,
        device="cuda:0",  # Use GPU for faster inference
        model_warmup=True
    )

    # Convert back to JSON
    output_stream = inference_result.map(lambda x: json.dumps(x))

    # Write to Kafka
    kafka_producer = FlinkKafkaProducer(
        topic='output-embeddings',
        serialization_schema=SimpleStringSchema(),
        producer_config=kafka_props
    )

    output_stream.add_sink(kafka_producer)

    env.execute("Kafka Real-time Inference")


def custom_model_example():
    """Example using a local custom model."""

    env = StreamExecutionEnvironment.get_execution_environment()
    env.set_parallelism(1)

    data = [
        {"id": 1, "text": "Example text 1"},
        {"id": 2, "text": "Example text 2"}
    ]

    data_stream = env.from_collection(data)

    # Use local model path
    result = data_stream.infer(
        model="/path/to/local/model",  # Local model directory
        input_col="text",
        output_col="features",
        device="cpu",
        model_cache_dir="/tmp/model-cache"  # Extra option, forwarded via **kwargs
    )

    result.print()

    env.execute("Custom Model Example")


if __name__ == '__main__':
    import sys

    if len(sys.argv) < 2:
        print("Usage: python text_embedding_example.py <example>")
        print("Examples:")
        print("  - embedding: Text embedding generation")
        print("  - sentiment: Sentiment classification")
        print("  - kafka: Kafka real-time inference")
        print("  - custom: Custom local model")
        sys.exit(1)

    example = sys.argv[1]

    if example == "embedding":
        text_embedding_example()
    elif example == "sentiment":
        sentiment_classification_example()
    elif example == "kafka":
        kafka_realtime_inference_example()
    elif example == "custom":
        custom_model_example()
    else:
        print(f"Unknown example: {example}")
        sys.exit(1)
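With the packages from the docstring note installed, each example in this script is
selected by name on the command line (the kafka example additionally needs a reachable
broker on localhost:9092 and the Flink Kafka connector JAR on the classpath):

    $ pip install "torch>=2.0.0" "transformers>=4.30.0"
    $ python text_embedding_example.py embedding
    $ python text_embedding_example.py sentiment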
21 changes: 21 additions & 0 deletions flink-python/pyflink/ml/__init__.py
@@ -0,0 +1,21 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
PyFlink ML module for machine learning and AI inference.
"""

__all__ = ['inference']
49 changes: 49 additions & 0 deletions flink-python/pyflink/ml/inference/__init__.py
@@ -0,0 +1,49 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
PyFlink AI inference module.

This module provides easy-to-use AI model inference capabilities for PyFlink,
including model lifecycle management, batch inference, and resource optimization.

Example:
    >>> from pyflink.datastream import StreamExecutionEnvironment
    >>>
    >>> env = StreamExecutionEnvironment.get_execution_environment()
    >>> data = env.from_collection([{"text": "hello"}, {"text": "world"}])
    >>>
    >>> result = data.infer(
    ...     model="sentence-transformers/all-MiniLM-L6-v2",
    ...     input_col="text",
    ...     output_col="embedding"
    ... )
"""

from pyflink.ml.inference.config import InferenceConfig
from pyflink.ml.inference.function import InferenceFunction, BatchInferenceFunction
from pyflink.ml.inference.lifecycle import ModelLifecycleManager
from pyflink.ml.inference.executor import BatchInferenceExecutor
from pyflink.ml.inference.metrics import InferenceMetrics

__all__ = [
    'InferenceConfig',
    'InferenceFunction',
    'BatchInferenceFunction',
    'ModelLifecycleManager',
    'BatchInferenceExecutor',
    'InferenceMetrics',
]
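The config, function, lifecycle, executor and metrics modules imported above are not
included in this excerpt. As a rough illustration of the surface implied by
DataStream.infer(), and with the dataclass form and the catch-all field being
assumptions rather than the PR's actual code, InferenceConfig might look like:

    from dataclasses import dataclass, field
    from typing import Any, Dict


    @dataclass
    class InferenceConfigSketch:
        """Hypothetical shape mirroring the keyword arguments of DataStream.infer()."""
        model: str
        input_col: str
        output_col: str
        batch_size: int = 32
        max_batch_timeout_ms: int = 100
        warmup_enabled: bool = True
        device: str = "cpu"
        num_workers: int = 1
        task_type: str = "embedding"
        # Extra options such as model_cache_dir from the examples; the real class
        # presumably accepts them as **kwargs rather than a dict.
        extra: Dict[str, Any] = field(default_factory=dict)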