diff --git a/flink-python/pyflink/datastream/data_stream.py b/flink-python/pyflink/datastream/data_stream.py index 8be48e0640f32..169cfd61da919 100644 --- a/flink-python/pyflink/datastream/data_stream.py +++ b/flink-python/pyflink/datastream/data_stream.py @@ -18,7 +18,7 @@ import typing import uuid from enum import Enum -from typing import Callable, Union, List, cast, Optional, overload +from typing import Callable, Generic, Iterable, Iterator, List, cast, Optional, overload, TypeVar, Union from pyflink.util.java_utils import get_j_env_configuration @@ -65,8 +65,12 @@ WINDOW_STATE_NAME = 'window-contents' +# TypeVar for generic DataStream typing support (FLINK-37912) +T = TypeVar('T') +OUT = TypeVar('OUT') +KEY = TypeVar('KEY') -class DataStream(object): +class DataStream(Generic[T]): """ A DataStream represents a stream of elements of the same type. A DataStream can be transformed into another DataStream by applying a transformation as for example: @@ -271,8 +275,8 @@ def set_description(self, description: str) -> 'DataStream': self._j_data_stream.setDescription(description) return self - def map(self, func: Union[Callable, MapFunction], output_type: TypeInformation = None) \ - -> 'DataStream': + def map(self, func: Union[Callable[[T], OUT], 'MapFunction[T, OUT]'], output_type: TypeInformation = None) \ + -> 'DataStream[OUT]': """ Applies a Map transformation on a DataStream. The transformation calls a MapFunction for each element of the DataStream. Each MapFunction call returns exactly one element. @@ -314,8 +318,8 @@ def process_element(self, value, ctx: 'ProcessFunction.Context'): .name("Map") def flat_map(self, - func: Union[Callable, FlatMapFunction], - output_type: TypeInformation = None) -> 'DataStream': + func: Union[Callable[[T], Iterable[OUT]], FlatMapFunction[T, OUT]], + output_type: TypeInformation = None) -> 'DataStream[OUT]': """ Applies a FlatMap transformation on a DataStream. 
The transformation calls a FlatMapFunction for each element of the DataStream. Each FlatMapFunction call can return any number of @@ -356,8 +360,8 @@ def process_element(self, value, ctx: 'ProcessFunction.Context'): .name("FlatMap") def key_by(self, - key_selector: Union[Callable, KeySelector], - key_type: TypeInformation = None) -> 'KeyedStream': + key_selector: Union[Callable[[T], KEY], KeySelector[T, KEY]], + key_type: TypeInformation = None) -> 'KeyedStream[T, KEY]': """ Creates a new KeyedStream that uses the provided key for partitioning its operator states. @@ -413,7 +417,7 @@ def process_element(self, value, ctx: 'ProcessFunction.Context'): self) return key_stream - def filter(self, func: Union[Callable, FilterFunction]) -> 'DataStream': + def filter(self, func: Union[Callable[[T], bool], 'FilterFunction[T]']) -> 'DataStream[T]': """ Applies a Filter transformation on a DataStream. The transformation calls a FilterFunction for each element of the DataStream and retains only those element for which the function @@ -1090,7 +1094,7 @@ def slot_sharing_group(self, slot_sharing_group: Union[str, SlotSharingGroup]) \ return self -class KeyedStream(DataStream): +class KeyedStream(DataStream[T], Generic[T, KEY]): """ A KeyedStream represents a DataStream on which operator state is partitioned by key using a provided KeySelector. Typical operations supported by a DataStream are also possible on a @@ -1111,8 +1115,8 @@ def __init__(self, j_keyed_stream, original_data_type_info, origin_stream: DataS self._original_data_type_info = original_data_type_info self._origin_stream = origin_stream - def map(self, func: Union[Callable, MapFunction], output_type: TypeInformation = None) \ - -> 'DataStream': + def map(self, func: Union[Callable[[T], OUT], 'MapFunction[T, OUT]'], output_type: TypeInformation = None) \ + -> 'DataStream[OUT]': """ Applies a Map transformation on a KeyedStream. The transformation calls a MapFunction for each element of the DataStream. 
Each MapFunction call returns exactly one element. @@ -1154,8 +1158,8 @@ def process_element(self, value, ctx: 'KeyedProcessFunction.Context'): .name("Map") # type: ignore def flat_map(self, - func: Union[Callable, FlatMapFunction], - output_type: TypeInformation = None) -> 'DataStream': + func: Union[Callable[[T], Iterable[OUT]], 'FlatMapFunction[T, OUT]'], + output_type: TypeInformation = None) -> 'DataStream[OUT]': """ Applies a FlatMap transformation on a KeyedStream. The transformation calls a FlatMapFunction for each element of the DataStream. Each FlatMapFunction call can return @@ -1195,7 +1199,7 @@ def process_element(self, value, ctx: 'KeyedProcessFunction.Context'): return self.process(FlatMapKeyedProcessFunctionAdapter(func), output_type) \ .name("FlatMap") - def reduce(self, func: Union[Callable, ReduceFunction]) -> 'DataStream': + def reduce(self, func: Union[Callable[[T, T], T], 'ReduceFunction[T]']) -> 'DataStream[T]': """ Applies a reduce transformation on the grouped data stream grouped on by the given key position. The `ReduceFunction` will receive input values based on the key value. 
@@ -1280,7 +1284,7 @@ def on_timer(self, timestamp: int, ctx: 'KeyedProcessFunction.OnTimerContext'): return self.process(ReduceProcessKeyedProcessFunctionAdapter(func), output_type) \ .name("Reduce") - def filter(self, func: Union[Callable, FilterFunction]) -> 'DataStream': + def filter(self, func: Union[Callable[[T], bool], 'FilterFunction[T]']) -> 'DataStream[T]': if not isinstance(func, FilterFunction) and not callable(func): raise TypeError("The input must be a FilterFunction or a callable function") @@ -1610,12 +1614,12 @@ def max_by(self, position_to_max_by: Union[int, str] = 0) -> 'DataStream': def add_sink(self, sink_func: SinkFunction) -> 'DataStreamSink': return self._values().add_sink(sink_func) - def key_by(self, key_selector: Union[Callable, KeySelector], - key_type: TypeInformation = None) -> 'KeyedStream': + def key_by(self, key_selector: Union[Callable[[T], KEY], 'KeySelector[T, KEY]'], + key_type: TypeInformation = None) -> 'KeyedStream[T, KEY]': return self._origin_stream.key_by(key_selector, key_type) def process(self, func: KeyedProcessFunction, # type: ignore - output_type: TypeInformation = None) -> 'DataStream': + output_type: TypeInformation = None) -> 'DataStream[typing.Any]': """ Applies the given ProcessFunction on the input stream, thereby creating a transformed output stream. @@ -1920,7 +1924,7 @@ def reduce(self, >>> ds.key_by(lambda x: x[1]) \\ ... .window(TumblingEventTimeWindows.of(Time.seconds(5))) \\ - ... .reduce(lambda a, b: a[0] + b[0], b[1]) + ... .reduce(lambda a, b: (a[0] + b[0], b[1])) :param reduce_function: The reduce function. :param window_function: The window function. 
diff --git a/flink-python/pyflink/datastream/functions.py b/flink-python/pyflink/datastream/functions.py index db4366390e1ab..e02a0346539ac 100644 --- a/flink-python/pyflink/datastream/functions.py +++ b/flink-python/pyflink/datastream/functions.py @@ -19,7 +19,7 @@ from enum import Enum from py4j.java_gateway import JavaObject -from typing import Union, Any, Generic, TypeVar, Iterable, List, Callable, Optional +from typing import Union, Any, Generic, TypeVar, Iterable, Iterator, List, Callable, Optional from pyflink.datastream.state import ValueState, ValueStateDescriptor, ListStateDescriptor, \ ListState, MapStateDescriptor, MapState, ReducingStateDescriptor, ReducingState, \ @@ -203,14 +203,14 @@ class Function(ABC): """ The base class for all user-defined functions. """ - def open(self, runtime_context: RuntimeContext): + def open(self, runtime_context: RuntimeContext) -> None: pass - def close(self): + def close(self) -> None: pass -class MapFunction(Function): +class MapFunction(Function, Generic[IN, OUT]): """ Base class for Map functions. Map functions take elements and transform them, element wise. A Map function always produces a single result element for each input element. Typical @@ -225,7 +225,7 @@ class MapFunction(Function): """ @abstractmethod - def map(self, value): + def map(self, value: IN) -> OUT: """ The mapping method. Takes an element from the input data and transforms it into exactly one element. @@ -236,7 +236,7 @@ def map(self, value): pass -class CoMapFunction(Function): +class CoMapFunction(Function, Generic[IN, OUT]): """ A CoMapFunction implements a map() transformation over two connected streams. @@ -252,7 +252,7 @@ class CoMapFunction(Function): """ @abstractmethod - def map1(self, value): + def map1(self, value: IN) -> OUT: """ This method is called for each element in the first of the connected streams. 
@@ -262,7 +262,7 @@ def map1(self, value): pass @abstractmethod - def map2(self, value): + def map2(self, value: IN) -> OUT: """ This method is called for each element in the second of the connected streams. @@ -272,13 +272,13 @@ def map2(self, value): pass -class FlatMapFunction(Function): +class FlatMapFunction(Function, Generic[IN, OUT]): """ Base class for flatMap functions. FlatMap functions take elements and transform them, into zero, one, or more elements. Typical applications can be splitting elements, or unnesting lists and arrays. Operations that produce multiple strictly one result element per input element can also use the MapFunction. - The basic syntax for using a MapFUnction is as follows: + The basic syntax for using a MapFunction is as follows: :: >>> ds = ... @@ -286,17 +286,17 @@ class FlatMapFunction(Function): """ @abstractmethod - def flat_map(self, value): + def flat_map(self, value: IN) -> Iterator[OUT]: """ - The core mthod of the FlatMapFunction. Takes an element from the input data and transforms + The core method of the FlatMapFunction. Takes an element from the input data and transforms it into zero, one, or more elements. A basic implementation of flat map is as follows: :: >>> class MyFlatMapFunction(FlatMapFunction): - >>> def flat_map(self, value): - >>> for i in range(value): - >>> yield i + ... def flat_map(self, value: IN) -> Iterator[OUT]: + ... for i in range(value): + ... yield i :param value: The input value. :return: A generator @@ -304,7 +304,7 @@ def flat_map(self, value): pass -class CoFlatMapFunction(Function): +class CoFlatMapFunction(Function, Generic[IN, OUT]): """ A CoFlatMapFunction implements a flat-map transformation over two connected streams. @@ -336,7 +336,7 @@ class CoFlatMapFunction(Function): """ @abstractmethod - def flat_map1(self, value): + def flat_map1(self, value: IN) -> Iterator[OUT]: """ This method is called for each element in the first of the connected streams. 
@@ -346,7 +346,7 @@ def flat_map1(self, value): pass @abstractmethod - def flat_map2(self, value): + def flat_map2(self, value: IN) -> Iterator[OUT]: """ This method is called for each element in the second of the connected streams. @@ -356,7 +356,7 @@ def flat_map2(self, value): pass -class ReduceFunction(Function): +class ReduceFunction(Function, Generic[IN]): """ Base interface for Reduce functions. Reduce functions combine groups of elements to a single value, by taking always two elements and combining them into one. Reduce functions may be @@ -371,7 +371,7 @@ class ReduceFunction(Function): """ @abstractmethod - def reduce(self, value1, value2): + def reduce(self, value1: IN, value2: IN) -> IN: """ The core method of ReduceFunction, combining two values into one value of the same type. The reduce function is consecutively applied to all values of a group until only a single @@ -461,7 +461,7 @@ def merge(self, acc_a, acc_b): pass -class KeySelector(Function): +class KeySelector(Function, Generic[IN, KEY]): """ The KeySelector allows to use deterministic objects for operations such as reduce, reduceGroup, join coGroup, etc. If invoked multiple times on the same object, the returned key must be the @@ -469,7 +469,7 @@ class KeySelector(Function): """ @abstractmethod - def get_key(self, value): + def get_key(self, value: IN) -> KEY: """ User-defined function that deterministically extracts the key from an object. @@ -489,7 +489,7 @@ def get_key(self, value): return 0 -class FilterFunction(Function): +class FilterFunction(Function, Generic[IN]): """ A filter function is a predicate applied individually to each record. The predicate decides whether to keep the element, or to discard it. @@ -505,7 +505,7 @@ class FilterFunction(Function): """ @abstractmethod - def filter(self, value): + def filter(self, value: IN) -> bool: """ The filter function that evaluates the predicate. 
@@ -655,7 +655,7 @@ def timestamp(self) -> int: pass @abstractmethod - def process_element(self, value, ctx: 'ProcessFunction.Context'): + def process_element(self, value: IN, ctx: 'ProcessFunction.Context') -> Iterator[OUT]: """ Process one element from the input stream. @@ -716,7 +716,7 @@ def time_domain(self) -> TimeDomain: pass @abstractmethod - def process_element(self, value, ctx: 'KeyedProcessFunction.Context'): + def process_element(self, value: IN, ctx: 'KeyedProcessFunction.Context') -> Iterator[OUT]: """ Process one element from the input stream. @@ -780,7 +780,7 @@ def timestamp(self) -> int: pass @abstractmethod - def process_element1(self, value, ctx: 'CoProcessFunction.Context'): + def process_element1(self, value: IN, ctx: 'CoProcessFunction.Context') -> Iterator[OUT]: """ This method is called for each element in the first of the connected streams. @@ -795,7 +795,7 @@ def process_element1(self, value, ctx: 'CoProcessFunction.Context'): pass @abstractmethod - def process_element2(self, value, ctx: 'CoProcessFunction.Context'): + def process_element2(self, value: IN, ctx: 'CoProcessFunction.Context') -> Iterator[OUT]: """ This method is called for each element in the second of the connected streams.