Skip to content

Commit 5c7d07f

Browse files
feat: Update Sourcer interface for propagating totalPartitions (#343)
Signed-off-by: Vaibhav Tiwari <vaibhav.tiwari33@gmail.com> Co-authored-by: Vigith Maurice <vigith@gmail.com>
1 parent fd0f49a commit 5c7d07f

File tree

11 files changed

+234
-141
lines changed

11 files changed

+234
-141
lines changed

packages/pynumaflow/examples/source/simple_source/example.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ async def pending_handler(self) -> PendingResponse:
9090
"""
9191
return PendingResponse(count=0)
9292

93-
async def partitions_handler(self) -> PartitionsResponse:
93+
async def active_partitions_handler(self) -> PartitionsResponse:
9494
"""
9595
The simple source always returns default partitions.
9696
"""

packages/pynumaflow/pynumaflow/proto/sourcer/source.proto

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -190,8 +190,10 @@ message PendingResponse {
190190
*/
191191
message PartitionsResponse {
192192
message Result {
193-
// Required field holding the list of partitions.
193+
// Required field holding the list of active partitions.
194194
repeated int32 partitions = 1;
195+
// Total number of partitions in the source.
196+
optional int32 total_partitions = 2;
195197
}
196198
// Required field holding the result.
197199
Result result = 1;

packages/pynumaflow/pynumaflow/proto/sourcer/source_pb2.py

Lines changed: 9 additions & 9 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

packages/pynumaflow/pynumaflow/proto/sourcer/source_pb2.pyi

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -156,10 +156,12 @@ class PendingResponse(_message.Message):
156156
class PartitionsResponse(_message.Message):
157157
__slots__ = ("result",)
158158
class Result(_message.Message):
159-
__slots__ = ("partitions",)
159+
__slots__ = ("partitions", "total_partitions")
160160
PARTITIONS_FIELD_NUMBER: _ClassVar[int]
161+
TOTAL_PARTITIONS_FIELD_NUMBER: _ClassVar[int]
161162
partitions: _containers.RepeatedScalarFieldContainer[int]
162-
def __init__(self, partitions: _Optional[_Iterable[int]] = ...) -> None: ...
163+
total_partitions: int
164+
def __init__(self, partitions: _Optional[_Iterable[int]] = ..., total_partitions: _Optional[int] = ...) -> None: ...
163165
RESULT_FIELD_NUMBER: _ClassVar[int]
164166
result: PartitionsResponse.Result
165167
def __init__(self, result: _Optional[_Union[PartitionsResponse.Result, _Mapping]] = ...) -> None: ...

packages/pynumaflow/pynumaflow/sourcer/_dtypes.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -240,11 +240,10 @@ def count(self) -> int:
240240
class PartitionsResponse:
241241
"""
242242
PartitionsResponse is the response for the partition request.
243-
It indicates the number of partitions at the user defined source.
244-
A negative count indicates that the partition information is not available.
243+
It indicates the active partitions at the user-defined source.
245244
246245
Args:
247-
count: the number of partitions.
246+
partitions: the list of active partitions.
248247
"""
249248

250249
_partitions: list[int]
@@ -256,7 +255,7 @@ def __init__(self, partitions: list[int]):
256255

257256
@property
258257
def partitions(self) -> list[int]:
259-
"""Returns the list of partitions"""
258+
"""Returns the list of active partitions"""
260259
return self._partitions
261260

262261

@@ -298,12 +297,22 @@ async def pending_handler(self) -> PendingResponse:
298297
pass
299298

300299
@abstractmethod
301-
async def partitions_handler(self) -> PartitionsResponse:
300+
async def active_partitions_handler(self) -> PartitionsResponse:
302301
"""
303-
The simple source always returns zero to indicate there is no pending record.
302+
Returns the active partitions associated with the source, used by the platform
303+
to determine the partitions to which the watermark should be published.
304304
"""
305305
pass
306306

307+
async def total_partitions_handler(self) -> int | None:
308+
"""
309+
Returns the total number of partitions in the source.
310+
Used by the platform for watermark progression to know when all
311+
processors have reported in.
312+
Returns None by default, indicating the source does not report total partitions.
313+
"""
314+
return None
315+
307316

308317
# Create default partition id from the environment variable "NUMAFLOW_REPLICA"
309318
DefaultPartitionId = int(os.getenv("NUMAFLOW_REPLICA", "0"))

packages/pynumaflow/pynumaflow/sourcer/async_server.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ async def pending_handler(self) -> PendingResponse:
133133
'''
134134
return PendingResponse(count=0)
135135
136-
async def partitions_handler(self) -> PartitionsResponse:
136+
async def active_partitions_handler(self) -> PartitionsResponse:
137137
'''
138138
The simple source always returns default partitions.
139139
'''

packages/pynumaflow/pynumaflow/sourcer/servicer/async_servicer.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,8 @@ def __initialize_handlers(self):
8484
self.__source_ack_handler = self.source_handler.ack_handler
8585
self.__source_nack_handler = self.source_handler.nack_handler
8686
self.__source_pending_handler = self.source_handler.pending_handler
87-
self.__source_partitions_handler = self.source_handler.partitions_handler
87+
self.__source_active_partitions_handler = self.source_handler.active_partitions_handler
88+
self.__source_total_partitions_handler = self.source_handler.total_partitions_handler
8889

8990
async def ReadFn(
9091
self,
@@ -278,10 +279,11 @@ async def PartitionsFn(
278279
self, request: _empty_pb2.Empty, context: NumaflowServicerContext
279280
) -> source_pb2.PartitionsResponse:
280281
"""
281-
PartitionsFn returns the partitions of the user defined source.
282+
PartitionsFn returns the active partitions and total partitions of the user-defined source.
282283
"""
283284
try:
284-
partitions = await self.__source_partitions_handler()
285+
partitions = await self.__source_active_partitions_handler()
286+
total_partitions = await self.__source_total_partitions_handler()
285287
except asyncio.CancelledError:
286288
# Task cancelled during shutdown (e.g. SIGTERM) — not a UDF fault.
287289
_LOGGER.info("Server shutting down, cancelling RPC.")
@@ -301,8 +303,10 @@ async def PartitionsFn(
301303
return source_pb2.PartitionsResponse(
302304
result=source_pb2.PartitionsResponse.Result(partitions=[])
303305
)
304-
resp = source_pb2.PartitionsResponse.Result(partitions=partitions.partitions)
305-
return source_pb2.PartitionsResponse(result=resp)
306+
result = source_pb2.PartitionsResponse.Result(
307+
partitions=partitions.partitions, total_partitions=total_partitions
308+
)
309+
return source_pb2.PartitionsResponse(result=result)
306310

307311
def clean_background(self, task):
308312
"""

packages/pynumaflow/tests/source/test_async_source.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
ack_req_source_fn,
1818
mock_partitions,
1919
AsyncSource,
20+
AsyncSourceWithTotalPartitions,
2021
mock_offset,
2122
nack_req_source_fn,
2223
)
@@ -194,6 +195,66 @@ def test_partitions(async_source_server) -> None:
194195
assert response.result.partitions == mock_partitions()
195196

196197

198+
def test_partitions_default_total_partitions_is_none(async_source_server) -> None:
199+
"""
200+
Verify total_partitions is not set when the source doesn't override
201+
total_partitions_handler.
202+
"""
203+
with grpc.insecure_channel(server_port) as channel:
204+
stub = source_pb2_grpc.SourceStub(channel)
205+
request = _empty_pb2.Empty()
206+
response = stub.PartitionsFn(request=request)
207+
208+
assert response.result.partitions == mock_partitions()
209+
assert not response.result.HasField("total_partitions")
210+
211+
212+
server_port_tp = "unix:///tmp/async_source_tp.sock"
213+
214+
215+
def NewAsyncSourcerWithTotalPartitions():
216+
class_instance = AsyncSourceWithTotalPartitions()
217+
server = SourceAsyncServer(sourcer_instance=class_instance)
218+
udfs = server.servicer
219+
return udfs
220+
221+
222+
async def start_server_tp(udfs):
223+
server = grpc.aio.server()
224+
source_pb2_grpc.add_SourceServicer_to_server(udfs, server)
225+
listen_addr = server_port_tp
226+
server.add_insecure_port(listen_addr)
227+
logging.info("Starting server on %s", listen_addr)
228+
await server.start()
229+
return server, listen_addr
230+
231+
232+
@pytest.fixture(scope="module")
233+
def async_source_server_with_total_partitions():
234+
"""Module-scoped fixture: starts an async gRPC source server with total partitions."""
235+
loop = create_async_loop()
236+
237+
udfs = NewAsyncSourcerWithTotalPartitions()
238+
server = start_async_server(loop, start_server_tp(udfs))
239+
240+
yield loop
241+
242+
teardown_async_server(loop, server)
243+
244+
245+
def test_partitions_with_total_partitions(async_source_server_with_total_partitions) -> None:
246+
"""
247+
Verify total_partitions flows through gRPC when the source implements total_partitions_handler.
248+
"""
249+
with grpc.insecure_channel(server_port_tp) as channel:
250+
stub = source_pb2_grpc.SourceStub(channel)
251+
request = _empty_pb2.Empty()
252+
response = stub.PartitionsFn(request=request)
253+
254+
assert response.result.partitions == mock_partitions()
255+
assert response.result.total_partitions == 10
256+
257+
197258
@pytest.mark.parametrize(
198259
"max_threads_arg,expected",
199260
[

packages/pynumaflow/tests/source/test_async_source_shutdown.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ async def nack_handler(self, nack_request: NackRequest):
4141
async def pending_handler(self) -> PendingResponse:
4242
return PendingResponse(count=0)
4343

44-
async def partitions_handler(self) -> PartitionsResponse:
44+
async def active_partitions_handler(self) -> PartitionsResponse:
4545
return PartitionsResponse(partitions=[])
4646

4747

@@ -194,7 +194,7 @@ async def _run():
194194
async def _cancelled_partitions():
195195
raise asyncio.CancelledError()
196196

197-
handler.partitions_handler = _cancelled_partitions
197+
handler.active_partitions_handler = _cancelled_partitions
198198

199199
servicer = AsyncSourceServicer(source_handler=handler)
200200
shutdown_event = asyncio.Event()

packages/pynumaflow/tests/source/utils.py

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,10 +50,44 @@ async def nack_handler(self, nack_request: NackRequest):
5050
async def pending_handler(self) -> PendingResponse:
5151
return PendingResponse(count=10)
5252

53-
async def partitions_handler(self) -> PartitionsResponse:
53+
async def active_partitions_handler(self) -> PartitionsResponse:
5454
return PartitionsResponse(partitions=mock_partitions())
5555

5656

57+
class AsyncSourceWithTotalPartitions(Sourcer):
58+
"""A test source that implements active_partitions_handler and total_partitions_handler."""
59+
60+
async def read_handler(self, datum: ReadRequest, output: NonBlockingIterator):
61+
payload = b"payload:test_mock_message"
62+
keys = ["test_key"]
63+
offset = mock_offset()
64+
event_time = mock_event_time()
65+
for i in range(10):
66+
await output.put(
67+
Message(
68+
payload=payload,
69+
keys=keys,
70+
offset=offset,
71+
event_time=event_time,
72+
)
73+
)
74+
75+
async def ack_handler(self, ack_request: AckRequest):
76+
return
77+
78+
async def nack_handler(self, nack_request: NackRequest):
79+
return
80+
81+
async def pending_handler(self) -> PendingResponse:
82+
return PendingResponse(count=10)
83+
84+
async def active_partitions_handler(self) -> PartitionsResponse:
85+
return PartitionsResponse(partitions=mock_partitions())
86+
87+
async def total_partitions_handler(self) -> int | None:
88+
return 10
89+
90+
5791
def read_req_source_fn() -> ReadRequest:
5892
request = source_pb2.ReadRequest.Request(
5993
num_records=10,
@@ -102,5 +136,5 @@ async def nack_handler(self, nack_request: NackRequest):
102136
async def pending_handler(self) -> PendingResponse:
103137
raise RuntimeError("Got a runtime error from pending handler.")
104138

105-
async def partitions_handler(self) -> PartitionsResponse:
139+
async def active_partitions_handler(self) -> PartitionsResponse:
106140
raise RuntimeError("Got a runtime error from partition handler.")

0 commit comments

Comments
 (0)