Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 26 additions & 20 deletions flink-python/pyflink/datastream/data_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import typing
import uuid
from enum import Enum
from typing import Callable, Union, List, cast, Optional, overload
from typing import Callable, Generic, Iterable, Iterator, List, cast, Optional, overload, TypeVar, Union

from pyflink.util.java_utils import get_j_env_configuration

Expand Down Expand Up @@ -65,8 +65,12 @@

WINDOW_STATE_NAME = 'window-contents'

# TypeVar for generic DataStream typing support (FLINK-37912)
T = TypeVar('T')
OUT = TypeVar('OUT')
KEY = TypeVar('KEY')

class DataStream(object):
class DataStream(Generic[T]):
"""
A DataStream represents a stream of elements of the same type. A DataStream can be transformed
into another DataStream by applying a transformation as for example:
Expand Down Expand Up @@ -271,8 +275,8 @@ def set_description(self, description: str) -> 'DataStream':
self._j_data_stream.setDescription(description)
return self

def map(self, func: Union[Callable, MapFunction], output_type: TypeInformation = None) \
-> 'DataStream':
def map(self, func: Union[Callable[[T], OUT], 'MapFunction[T, OUT]'], output_type: TypeInformation = None) \
-> 'DataStream[OUT]':
"""
Applies a Map transformation on a DataStream. The transformation calls a MapFunction for
each element of the DataStream. Each MapFunction call returns exactly one element.
Expand Down Expand Up @@ -314,8 +318,8 @@ def process_element(self, value, ctx: 'ProcessFunction.Context'):
.name("Map")

def flat_map(self,
func: Union[Callable, FlatMapFunction],
output_type: TypeInformation = None) -> 'DataStream':
func: Union[Callable[[T], Iterable[OUT]], FlatMapFunction[T, OUT]],
output_type: TypeInformation = None) -> 'DataStream[OUT]':
"""
Applies a FlatMap transformation on a DataStream. The transformation calls a FlatMapFunction
for each element of the DataStream. Each FlatMapFunction call can return any number of
Expand Down Expand Up @@ -356,8 +360,8 @@ def process_element(self, value, ctx: 'ProcessFunction.Context'):
.name("FlatMap")

def key_by(self,
key_selector: Union[Callable, KeySelector],
key_type: TypeInformation = None) -> 'KeyedStream':
key_selector: Union[Callable[[T], KEY], KeySelector[T, KEY]],
key_type: TypeInformation = None) -> 'KeyedStream[T, KEY]':
"""
Creates a new KeyedStream that uses the provided key for partitioning its operator states.

Expand Down Expand Up @@ -413,7 +417,7 @@ def process_element(self, value, ctx: 'ProcessFunction.Context'):
self)
return key_stream

def filter(self, func: Union[Callable, FilterFunction]) -> 'DataStream':
def filter(self, func: Union[Callable[[T], bool], 'FilterFunction[T]']) -> 'DataStream[T]':
"""
Applies a Filter transformation on a DataStream. The transformation calls a FilterFunction
for each element of the DataStream and retains only those element for which the function
Expand Down Expand Up @@ -1090,7 +1094,7 @@ def slot_sharing_group(self, slot_sharing_group: Union[str, SlotSharingGroup]) \
return self


class KeyedStream(DataStream):
class KeyedStream(DataStream[T]):
"""
A KeyedStream represents a DataStream on which operator state is partitioned by key using a
provided KeySelector. Typical operations supported by a DataStream are also possible on a
Expand All @@ -1111,8 +1115,8 @@ def __init__(self, j_keyed_stream, original_data_type_info, origin_stream: DataS
self._original_data_type_info = original_data_type_info
self._origin_stream = origin_stream

def map(self, func: Union[Callable, MapFunction], output_type: TypeInformation = None) \
-> 'DataStream':
def map(self, func: Union[Callable[[T], OUT], 'MapFunction[T, OUT]'], output_type: TypeInformation = None) \
-> 'DataStream[OUT]':
"""
Applies a Map transformation on a KeyedStream. The transformation calls a MapFunction for
each element of the DataStream. Each MapFunction call returns exactly one element.
Expand Down Expand Up @@ -1154,8 +1158,8 @@ def process_element(self, value, ctx: 'KeyedProcessFunction.Context'):
.name("Map") # type: ignore

def flat_map(self,
func: Union[Callable, FlatMapFunction],
output_type: TypeInformation = None) -> 'DataStream':
func: Union[Callable[[T], Iterable[OUT]], 'FlatMapFunction[T, OUT]'],
output_type: TypeInformation = None) -> 'DataStream[OUT]':
"""
Applies a FlatMap transformation on a KeyedStream. The transformation calls a
FlatMapFunction for each element of the DataStream. Each FlatMapFunction call can return
Expand Down Expand Up @@ -1195,7 +1199,7 @@ def process_element(self, value, ctx: 'KeyedProcessFunction.Context'):
return self.process(FlatMapKeyedProcessFunctionAdapter(func), output_type) \
.name("FlatMap")

def reduce(self, func: Union[Callable, ReduceFunction]) -> 'DataStream':
def reduce(self, func: Union[Callable[[T, T], T], 'ReduceFunction[T]']) -> 'DataStream[T]':
"""
Applies a reduce transformation on the grouped data stream grouped on by the given
key position. The `ReduceFunction` will receive input values based on the key value.
Expand Down Expand Up @@ -1280,7 +1284,7 @@ def on_timer(self, timestamp: int, ctx: 'KeyedProcessFunction.OnTimerContext'):
return self.process(ReduceProcessKeyedProcessFunctionAdapter(func), output_type) \
.name("Reduce")

def filter(self, func: Union[Callable, FilterFunction]) -> 'DataStream':
def filter(self, func: Union[Callable[[T], bool], 'FilterFunction[T]']) -> 'DataStream[T]':
if not isinstance(func, FilterFunction) and not callable(func):
raise TypeError("The input must be a FilterFunction or a callable function")

Expand Down Expand Up @@ -1610,12 +1614,12 @@ def max_by(self, position_to_max_by: Union[int, str] = 0) -> 'DataStream':
def add_sink(self, sink_func: SinkFunction) -> 'DataStreamSink':
return self._values().add_sink(sink_func)

def key_by(self, key_selector: Union[Callable, KeySelector],
key_type: TypeInformation = None) -> 'KeyedStream':
def key_by(self, key_selector: Union[Callable[[T], KEY], 'KeySelector[T, KEY]'],
key_type: TypeInformation = None) -> 'KeyedStream[T, KEY]':
return self._origin_stream.key_by(key_selector, key_type)

def process(self, func: KeyedProcessFunction, # type: ignore
output_type: TypeInformation = None) -> 'DataStream':
output_type: TypeInformation = None) -> 'DataStream[Any]':
"""
Applies the given ProcessFunction on the input stream, thereby creating a transformed output
stream.
Expand Down Expand Up @@ -1920,7 +1924,9 @@ def reduce(self,

>>> ds.key_by(lambda x: x[1]) \\
... .window(TumblingEventTimeWindows.of(Time.seconds(5))) \\
... .reduce(lambda a, b: a[0] + b[0], b[1])
# UPDATED
def key_by(self, key_selector: Union[Callable[[T], KEY], 'KeySelector[T, KEY]'],
key_type: TypeInformation = None) -> 'KeyedStream[T, KEY]': ... .reduce(lambda a, b: a[0] + b[0], b[1])

:param reduce_function: The reduce function.
:param window_function: The window function.
Expand Down
56 changes: 28 additions & 28 deletions flink-python/pyflink/datastream/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from enum import Enum

from py4j.java_gateway import JavaObject
from typing import Union, Any, Generic, TypeVar, Iterable, List, Callable, Optional
from typing import Union, Any, Generic, TypeVar, Iterable, Iterator, List, Callable, Optional

from pyflink.datastream.state import ValueState, ValueStateDescriptor, ListStateDescriptor, \
ListState, MapStateDescriptor, MapState, ReducingStateDescriptor, ReducingState, \
Expand Down Expand Up @@ -203,14 +203,14 @@ class Function(ABC):
"""
The base class for all user-defined functions.
"""
def open(self, runtime_context: RuntimeContext):
def open(self, runtime_context: RuntimeContext) -> None:
pass

def close(self):
def close(self) -> None:
pass


class MapFunction(Function):
class MapFunction(Function, Generic[IN, OUT]):
"""
Base class for Map functions. Map functions take elements and transform them, element wise. A
Map function always produces a single result element for each input element. Typical
Expand All @@ -225,7 +225,7 @@ class MapFunction(Function):
"""

@abstractmethod
def map(self, value):
def map(self, value: IN) -> OUT:
"""
The mapping method. Takes an element from the input data and transforms it into exactly one
element.
Expand All @@ -236,7 +236,7 @@ def map(self, value):
pass


class CoMapFunction(Function):
class CoMapFunction(Function, Generic[IN, OUT]):
"""
A CoMapFunction implements a map() transformation over two connected streams.

Expand All @@ -252,7 +252,7 @@ class CoMapFunction(Function):
"""

@abstractmethod
def map1(self, value):
def map1(self, value: IN) -> OUT:
"""
This method is called for each element in the first of the connected streams.

Expand All @@ -262,7 +262,7 @@ def map1(self, value):
pass

@abstractmethod
def map2(self, value):
def map2(self, value: IN) -> OUT:
"""
This method is called for each element in the second of the connected streams.

Expand All @@ -272,39 +272,39 @@ def map2(self, value):
pass


class FlatMapFunction(Function):
class FlatMapFunction(Function, Generic[IN, OUT]):
"""
Base class for flatMap functions. FlatMap functions take elements and transform them, into zero,
one, or more elements. Typical applications can be splitting elements, or unnesting lists and
arrays. Operations that produce multiple strictly one result element per input element can also
use the MapFunction.
The basic syntax for using a MapFUnction is as follows:
The basic syntax for using a MapFunction is as follows:

::
>>> ds = ...
>>> new_ds = ds.flat_map(MyFlatMapFunction())
"""

@abstractmethod
def flat_map(self, value):
def flat_map(self, value: IN) -> Iterator[OUT]:
"""
The core mthod of the FlatMapFunction. Takes an element from the input data and transforms
The core method of the FlatMapFunction. Takes an element from the input data and transforms
it into zero, one, or more elements.
A basic implementation of flat map is as follows:

::
>>> class MyFlatMapFunction(FlatMapFunction):
>>> def flat_map(self, value):
>>> for i in range(value):
>>> yield i
... def flat_map(self, value: IN) -> Iterator[OUT]:
... for i in range(value):
... yield i

:param value: The input value.
:return: A generator
"""
pass


class CoFlatMapFunction(Function):
class CoFlatMapFunction(Function, Generic[IN, OUT]):
"""
A CoFlatMapFunction implements a flat-map transformation over two connected streams.

Expand Down Expand Up @@ -336,7 +336,7 @@ class CoFlatMapFunction(Function):
"""

@abstractmethod
def flat_map1(self, value):
def flat_map1(self, value: IN) -> Iterator[OUT]:
"""
This method is called for each element in the first of the connected streams.

Expand All @@ -346,7 +346,7 @@ def flat_map1(self, value):
pass

@abstractmethod
def flat_map2(self, value):
def flat_map2(self, value: IN) -> Iterator[OUT]:
"""
This method is called for each element in the second of the connected streams.

Expand All @@ -356,7 +356,7 @@ def flat_map2(self, value):
pass


class ReduceFunction(Function):
class ReduceFunction(Function, Generic[IN]):
"""
Base interface for Reduce functions. Reduce functions combine groups of elements to a single
value, by taking always two elements and combining them into one. Reduce functions may be
Expand All @@ -371,7 +371,7 @@ class ReduceFunction(Function):
"""

@abstractmethod
def reduce(self, value1, value2):
def reduce(self, value1: IN, value2: IN) -> IN:
"""
The core method of ReduceFunction, combining two values into one value of the same type.
The reduce function is consecutively applied to all values of a group until only a single
Expand Down Expand Up @@ -461,15 +461,15 @@ def merge(self, acc_a, acc_b):
pass


class KeySelector(Function):
class KeySelector(Function, Generic[IN, KEY]):
"""
The KeySelector allows to use deterministic objects for operations such as reduce, reduceGroup,
join coGroup, etc. If invoked multiple times on the same object, the returned key must be the
same. The extractor takes an object an returns the deterministic key for that object.
"""

@abstractmethod
def get_key(self, value):
def get_key(self, value: IN) -> KEY:
"""
User-defined function that deterministically extracts the key from an object.

Expand All @@ -489,7 +489,7 @@ def get_key(self, value):
return 0


class FilterFunction(Function):
class FilterFunction(Function, Generic[IN]):
"""
A filter function is a predicate applied individually to each record. The predicate decides
whether to keep the element, or to discard it.
Expand All @@ -505,7 +505,7 @@ class FilterFunction(Function):
"""

@abstractmethod
def filter(self, value):
def filter(self, value: IN) -> bool:
"""
The filter function that evaluates the predicate.

Expand Down Expand Up @@ -655,7 +655,7 @@ def timestamp(self) -> int:
pass

@abstractmethod
def process_element(self, value, ctx: 'ProcessFunction.Context'):
def process_element(self, value: IN, ctx: 'ProcessFunction.Context') -> Iterator[OUT]:
"""
Process one element from the input stream.

Expand Down Expand Up @@ -716,7 +716,7 @@ def time_domain(self) -> TimeDomain:
pass

@abstractmethod
def process_element(self, value, ctx: 'KeyedProcessFunction.Context'):
def process_element(self, value: IN, ctx: 'KeyedProcessFunction.Context') -> Iterator[OUT]:
"""
Process one element from the input stream.

Expand Down Expand Up @@ -780,7 +780,7 @@ def timestamp(self) -> int:
pass

@abstractmethod
def process_element1(self, value, ctx: 'CoProcessFunction.Context'):
def process_element1(self, value: IN, ctx: 'CoProcessFunction.Context') -> Iterator[OUT]:
"""
This method is called for each element in the first of the connected streams.

Expand All @@ -795,7 +795,7 @@ def process_element1(self, value, ctx: 'CoProcessFunction.Context'):
pass

@abstractmethod
def process_element2(self, value, ctx: 'CoProcessFunction.Context'):
def process_element2(self, value: IN, ctx: 'CoProcessFunction.Context') -> Iterator[OUT]:
"""
This method is called for each element in the second of the connected streams.

Expand Down