From c9eb77d6958c6bda76b9edc6023420dcba26c0f8 Mon Sep 17 00:00:00 2001 From: Oleg Ovcharuk Date: Wed, 1 Apr 2026 14:32:24 +0300 Subject: [PATCH] Send writers messages as one batch --- ydb/_topic_writer/topic_writer_asyncio.py | 37 ++++-- .../topic_writer_asyncio_test.py | 60 ++++++++- ydb/_topic_writer/topic_writer_test.py | 120 +++++++++++++++++- 3 files changed, 197 insertions(+), 20 deletions(-) diff --git a/ydb/_topic_writer/topic_writer_asyncio.py b/ydb/_topic_writer/topic_writer_asyncio.py index b80537dc..fd9a03a3 100644 --- a/ydb/_topic_writer/topic_writer_asyncio.py +++ b/ydb/_topic_writer/topic_writer_asyncio.py @@ -658,24 +658,37 @@ async def _send_loop(self, writer: "WriterAsyncIOStream"): messages = list(self._messages) last_seq_no = 0 - for m in messages: - writer.write([m]) + if messages: + writer.write(messages) + last_seq_no = messages[-1].seq_no logger.debug( - "writer reconnector %s sent buffered message seqno=%s", + "writer reconnector %s sent %s buffered messages seqno=%s..%s", self._id, - m.seq_no, + len(messages), + messages[0].seq_no, + messages[-1].seq_no, ) - last_seq_no = m.seq_no while True: new_msg: InternalMessage = await self._new_messages.get() - if new_msg.seq_no > last_seq_no: - writer.write([new_msg]) - logger.debug( - "writer reconnector %s sent message seqno=%s", - self._id, - new_msg.seq_no, - ) + if new_msg.seq_no <= last_seq_no: + continue + + batch = [new_msg] + while not self._new_messages.empty(): + next_msg = self._new_messages.get_nowait() + if next_msg.seq_no > last_seq_no: + batch.append(next_msg) + + writer.write(batch) + last_seq_no = batch[-1].seq_no + logger.debug( + "writer reconnector %s sent %s messages seqno=%s..%s", + self._id, + len(batch), + batch[0].seq_no, + batch[-1].seq_no, + ) except asyncio.CancelledError: # the loop task cancelled be parent code, for example for reconnection # no need to stop all work. diff --git a/ydb/_topic_writer/topic_writer_asyncio_test.py b/ydb/_topic_writer/topic_writer_asyncio_test.py index a616b0b6..06b49a10 100644 --- a/ydb/_topic_writer/topic_writer_asyncio_test.py +++ b/ydb/_topic_writer/topic_writer_asyncio_test.py @@ -439,9 +439,7 @@ async def test_reconnect_and_resent_non_acked_messages_on_retriable_error( stream_writer = get_stream_writer() messages = await stream_writer.from_client.get() - assert [InternalMessage(message1)] == messages - messages = await stream_writer.from_client.get() - assert [InternalMessage(message2)] == messages + assert [InternalMessage(message1), InternalMessage(message2)] == messages # ack first message stream_writer.from_server.put_nowait(self.make_default_ack_message(seq_no=1)) @@ -529,16 +527,64 @@ async def test_auto_seq_no(self, default_driver, default_settings, get_stream_wr stream_writer = get_stream_writer() sent = await stream_writer.from_client.get() - assert [InternalMessage(PublicMessage(seqno=last_seq_no + 1, data="123"))] == sent - - sent = await stream_writer.from_client.get() - assert [InternalMessage(PublicMessage(seqno=last_seq_no + 2, data="456"))] == sent + assert [ + InternalMessage(PublicMessage(seqno=last_seq_no + 1, data="123")), + InternalMessage(PublicMessage(seqno=last_seq_no + 2, data="456")), + ] == sent with pytest.raises(TopicWriterError): await reconnector.write_with_ack_future([PublicMessage(seqno=last_seq_no + 3, data="123")]) await reconnector.close(flush=False) + async def test_write_multiple_messages_batched_into_single_send( + self, reconnector: WriterAsyncIOReconnector, get_stream_writer + ): + stream_writer = get_stream_writer() + messages = [ + PublicMessage(data="msg1", seqno=1), + PublicMessage(data="msg2", seqno=2), + PublicMessage(data="msg3", seqno=3), + ] + await reconnector.write_with_ack_future(messages) + + sent = await asyncio.wait_for(stream_writer.from_client.get(), 1) + assert sent == [InternalMessage(m) for m in messages] + assert stream_writer.from_client.empty() + + await reconnector.close(flush=False) + + async def test_buffered_messages_on_reconnect_sent_as_single_batch( + self, + reconnector: WriterAsyncIOReconnector, + get_stream_writer, + ): + stream_writer = get_stream_writer() + messages = [ + PublicMessage(data="msg1", seqno=1), + PublicMessage(data="msg2", seqno=2), + PublicMessage(data="msg3", seqno=3), + ] + await reconnector.write_with_ack_future(messages) + + sent = await asyncio.wait_for(stream_writer.from_client.get(), 1) + assert len(sent) == 3 + + # ack first message, then trigger retriable error + stream_writer.from_server.put_nowait(self.make_default_ack_message(seq_no=1)) + stream_writer.from_server.put_nowait(issues.Overloaded("test")) + + second_writer = get_stream_writer() + resent = await asyncio.wait_for(second_writer.from_client.get(), 1) + + # msg2 and msg3 must arrive as a single batch, not two separate sends + assert resent == [InternalMessage(messages[1]), InternalMessage(messages[2])] + assert second_writer.from_client.empty() + + second_writer.from_server.put_nowait(self.make_default_ack_message(seq_no=2)) + second_writer.from_server.put_nowait(self.make_default_ack_message(seq_no=3)) + await reconnector.close(flush=True) + async def test_deny_double_seqno(self, reconnector: WriterAsyncIOReconnector, get_stream_writer): writer = get_stream_writer() diff --git a/ydb/_topic_writer/topic_writer_test.py b/ydb/_topic_writer/topic_writer_test.py index 0e829255..215b6c02 100644 --- a/ydb/_topic_writer/topic_writer_test.py +++ b/ydb/_topic_writer/topic_writer_test.py @@ -2,7 +2,14 @@ import pytest -from .topic_writer import _split_messages_by_size +from .._grpc.grpcwrapper.ydb_topic_public_types import PublicCodec +from .topic_writer import ( + InternalMessage, + PublicMessage, + _split_messages_by_size, + _split_messages_for_send, + messages_to_proto_requests, +) @pytest.mark.parametrize( @@ -48,3 +55,114 @@ def test_split_messages_by_size(messages: List[int], split_size: int, expected: List[List[int]]): res = _split_messages_by_size(messages, split_size, lambda x: x) # noqa assert res == expected + + +def _make_msg(data: bytes, codec: PublicCodec = PublicCodec.RAW, seqno: int = 1) -> InternalMessage: + msg = InternalMessage(PublicMessage(data=data, seqno=seqno)) + msg.codec = codec + return msg + + +class TestSplitMessagesForSend: + def test_empty(self): + assert _split_messages_for_send([]) == [] + + def test_single_message(self): + msg = _make_msg(b"hello") + assert _split_messages_for_send([msg]) == [[msg]] + + def test_same_codec_kept_together(self): + msgs = [_make_msg(b"a", PublicCodec.RAW, i) for i in range(1, 4)] + assert _split_messages_for_send(msgs) == [msgs] + + def test_different_codecs_split_into_separate_groups(self): + raw = _make_msg(b"a", PublicCodec.RAW, seqno=1) + gzip = _make_msg(b"b", PublicCodec.GZIP, seqno=2) + result = _split_messages_for_send([raw, gzip]) + assert result == [[raw], [gzip]] + + def test_alternating_codecs_each_run_is_own_group(self): + # RAW, GZIP, RAW must produce 3 separate groups, not 2 + r1 = _make_msg(b"a", PublicCodec.RAW, seqno=1) + g1 = _make_msg(b"b", PublicCodec.GZIP, seqno=2) + r2 = _make_msg(b"c", PublicCodec.RAW, seqno=3) + result = _split_messages_for_send([r1, g1, r2]) + assert result == [[r1], [g1], [r2]] + + def test_size_limit_splits_same_codec_group(self, monkeypatch): + from .topic_writer import _message_data_overhead + from .. import connection + + # patch limit so that one 5-byte message fits but two do not + monkeypatch.setattr(connection, "_DEFAULT_MAX_GRPC_MESSAGE_SIZE", _message_data_overhead + 5) + m1 = _make_msg(b"hello", PublicCodec.RAW, seqno=1) + m2 = _make_msg(b"x", PublicCodec.RAW, seqno=2) + result = _split_messages_for_send([m1, m2]) + assert result == [[m1], [m2]] + + def test_size_limit_not_exceeded_keeps_group(self): + m1 = _make_msg(b"hello", PublicCodec.RAW, seqno=1) + m2 = _make_msg(b"world", PublicCodec.RAW, seqno=2) + result = _split_messages_for_send([m1, m2]) + assert result == [[m1, m2]] + + def test_codec_split_takes_priority_over_size(self): + # tiny messages but different codecs — must still split by codec + r = _make_msg(b"a", PublicCodec.RAW, seqno=1) + g = _make_msg(b"b", PublicCodec.GZIP, seqno=2) + result = _split_messages_for_send([r, g]) + assert len(result) == 2 + assert result[0][0].codec == PublicCodec.RAW + assert result[1][0].codec == PublicCodec.GZIP + + +class TestMessagesToProtoRequests: + def test_empty(self): + assert messages_to_proto_requests([], tx_identity=None) == [] + + def test_single_message_produces_one_request(self): + msg = _make_msg(b"hello", PublicCodec.RAW, seqno=1) + requests = messages_to_proto_requests([msg], tx_identity=None) + assert len(requests) == 1 + + def test_request_codec_matches_messages(self): + raw = _make_msg(b"a", PublicCodec.RAW, seqno=1) + requests = messages_to_proto_requests([raw], tx_identity=None) + assert requests[0].value.codec == PublicCodec.RAW + + gzip = _make_msg(b"b", PublicCodec.GZIP, seqno=1) + requests = messages_to_proto_requests([gzip], tx_identity=None) + assert requests[0].value.codec == PublicCodec.GZIP + + def test_same_codec_produces_single_request(self): + msgs = [_make_msg(b"x", PublicCodec.RAW, seqno=i) for i in range(1, 4)] + requests = messages_to_proto_requests(msgs, tx_identity=None) + assert len(requests) == 1 + assert len(requests[0].value.messages) == 3 + + def test_different_codecs_produce_separate_requests(self): + raw = _make_msg(b"a", PublicCodec.RAW, seqno=1) + gzip = _make_msg(b"b", PublicCodec.GZIP, seqno=2) + requests = messages_to_proto_requests([raw, gzip], tx_identity=None) + assert len(requests) == 2 + assert requests[0].value.codec == PublicCodec.RAW + assert requests[1].value.codec == PublicCodec.GZIP + + def test_size_exceeded_produces_multiple_requests(self, monkeypatch): + from .topic_writer import _message_data_overhead + from .. import connection + + monkeypatch.setattr(connection, "_DEFAULT_MAX_GRPC_MESSAGE_SIZE", _message_data_overhead + 5) + m1 = _make_msg(b"hello", PublicCodec.RAW, seqno=1) + m2 = _make_msg(b"x", PublicCodec.RAW, seqno=2) + requests = messages_to_proto_requests([m1, m2], tx_identity=None) + assert len(requests) == 2 + assert requests[0].value.messages[0].seq_no == 1 + assert requests[1].value.messages[0].seq_no == 2 + + def test_messages_order_preserved_within_request(self): + msgs = [_make_msg(f"msg{i}".encode(), PublicCodec.RAW, seqno=i) for i in range(1, 5)] + requests = messages_to_proto_requests(msgs, tx_identity=None) + assert len(requests) == 1 + seq_nos = [m.seq_no for m in requests[0].value.messages] + assert seq_nos == [1, 2, 3, 4]