Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 25 additions & 12 deletions ydb/_topic_writer/topic_writer_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -658,24 +658,37 @@ async def _send_loop(self, writer: "WriterAsyncIOStream"):
messages = list(self._messages)

last_seq_no = 0
for m in messages:
writer.write([m])
if messages:
writer.write(messages)
last_seq_no = messages[-1].seq_no
logger.debug(
"writer reconnector %s sent buffered message seqno=%s",
"writer reconnector %s sent %s buffered messages seqno=%s..%s",
self._id,
m.seq_no,
len(messages),
messages[0].seq_no,
messages[-1].seq_no,
)
last_seq_no = m.seq_no

while True:
new_msg: InternalMessage = await self._new_messages.get()
if new_msg.seq_no > last_seq_no:
writer.write([new_msg])
logger.debug(
"writer reconnector %s sent message seqno=%s",
self._id,
new_msg.seq_no,
)
if new_msg.seq_no <= last_seq_no:
continue

batch = [new_msg]
while not self._new_messages.empty():
next_msg = self._new_messages.get_nowait()
if next_msg.seq_no > last_seq_no:
batch.append(next_msg)

writer.write(batch)
last_seq_no = batch[-1].seq_no
logger.debug(
"writer reconnector %s sent %s messages seqno=%s..%s",
self._id,
len(batch),
batch[0].seq_no,
batch[-1].seq_no,
)
except asyncio.CancelledError:
# the loop task was cancelled by parent code, for example for reconnection
# no need to stop all work.
Expand Down
60 changes: 53 additions & 7 deletions ydb/_topic_writer/topic_writer_asyncio_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,9 +439,7 @@ async def test_reconnect_and_resent_non_acked_messages_on_retriable_error(
stream_writer = get_stream_writer()

messages = await stream_writer.from_client.get()
assert [InternalMessage(message1)] == messages
messages = await stream_writer.from_client.get()
assert [InternalMessage(message2)] == messages
assert [InternalMessage(message1), InternalMessage(message2)] == messages

# ack first message
stream_writer.from_server.put_nowait(self.make_default_ack_message(seq_no=1))
Expand Down Expand Up @@ -529,16 +527,64 @@ async def test_auto_seq_no(self, default_driver, default_settings, get_stream_wr
stream_writer = get_stream_writer()

sent = await stream_writer.from_client.get()
assert [InternalMessage(PublicMessage(seqno=last_seq_no + 1, data="123"))] == sent

sent = await stream_writer.from_client.get()
assert [InternalMessage(PublicMessage(seqno=last_seq_no + 2, data="456"))] == sent
assert [
InternalMessage(PublicMessage(seqno=last_seq_no + 1, data="123")),
InternalMessage(PublicMessage(seqno=last_seq_no + 2, data="456")),
] == sent

with pytest.raises(TopicWriterError):
await reconnector.write_with_ack_future([PublicMessage(seqno=last_seq_no + 3, data="123")])

await reconnector.close(flush=False)

async def test_write_multiple_messages_batched_into_single_send(
    self, reconnector: WriterAsyncIOReconnector, get_stream_writer
):
    """Several messages written at once must reach the stream as one batch."""
    writer = get_stream_writer()
    outgoing = [
        PublicMessage(data="msg1", seqno=1),
        PublicMessage(data="msg2", seqno=2),
        PublicMessage(data="msg3", seqno=3),
    ]
    await reconnector.write_with_ack_future(outgoing)

    # exactly one send must appear on the client->server queue
    batch = await asyncio.wait_for(writer.from_client.get(), 1)
    assert [InternalMessage(msg) for msg in outgoing] == batch
    assert writer.from_client.empty()

    await reconnector.close(flush=False)

async def test_buffered_messages_on_reconnect_sent_as_single_batch(
    self,
    reconnector: WriterAsyncIOReconnector,
    get_stream_writer,
):
    """After a retriable error, unacked messages must be resent in one batch."""
    first_writer = get_stream_writer()
    outgoing = [
        PublicMessage(data="msg1", seqno=1),
        PublicMessage(data="msg2", seqno=2),
        PublicMessage(data="msg3", seqno=3),
    ]
    await reconnector.write_with_ack_future(outgoing)

    initial_batch = await asyncio.wait_for(first_writer.from_client.get(), 1)
    assert len(initial_batch) == 3

    # ack first message, then trigger retriable error
    first_writer.from_server.put_nowait(self.make_default_ack_message(seq_no=1))
    first_writer.from_server.put_nowait(issues.Overloaded("test"))

    second_writer = get_stream_writer()
    resent_batch = await asyncio.wait_for(second_writer.from_client.get(), 1)

    # msg2 and msg3 must arrive as a single batch, not two separate sends
    assert [InternalMessage(outgoing[1]), InternalMessage(outgoing[2])] == resent_batch
    assert second_writer.from_client.empty()

    second_writer.from_server.put_nowait(self.make_default_ack_message(seq_no=2))
    second_writer.from_server.put_nowait(self.make_default_ack_message(seq_no=3))
    await reconnector.close(flush=True)

async def test_deny_double_seqno(self, reconnector: WriterAsyncIOReconnector, get_stream_writer):
writer = get_stream_writer()

Expand Down
120 changes: 119 additions & 1 deletion ydb/_topic_writer/topic_writer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,14 @@

import pytest

from .topic_writer import _split_messages_by_size
from .._grpc.grpcwrapper.ydb_topic_public_types import PublicCodec
from .topic_writer import (
InternalMessage,
PublicMessage,
_split_messages_by_size,
_split_messages_for_send,
messages_to_proto_requests,
)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -48,3 +55,114 @@
def test_split_messages_by_size(messages: List[int], split_size: int, expected: List[List[int]]):
res = _split_messages_by_size(messages, split_size, lambda x: x) # noqa
assert res == expected


def _make_msg(data: bytes, codec: PublicCodec = PublicCodec.RAW, seqno: int = 1) -> InternalMessage:
    """Build an InternalMessage test fixture, forcing the given codec on it."""
    internal = InternalMessage(PublicMessage(data=data, seqno=seqno))
    internal.codec = codec
    return internal


class TestSplitMessagesForSend:
    """Tests for _split_messages_for_send: grouping by codec runs and size limit."""

    def test_empty(self):
        assert _split_messages_for_send([]) == []

    def test_single_message(self):
        single = _make_msg(b"hello")
        assert _split_messages_for_send([single]) == [[single]]

    def test_same_codec_kept_together(self):
        batch = [_make_msg(b"a", PublicCodec.RAW, seqno) for seqno in range(1, 4)]
        assert _split_messages_for_send(batch) == [batch]

    def test_different_codecs_split_into_separate_groups(self):
        raw_msg = _make_msg(b"a", PublicCodec.RAW, seqno=1)
        gzip_msg = _make_msg(b"b", PublicCodec.GZIP, seqno=2)
        assert _split_messages_for_send([raw_msg, gzip_msg]) == [[raw_msg], [gzip_msg]]

    def test_alternating_codecs_each_run_is_own_group(self):
        # RAW, GZIP, RAW must produce 3 separate groups, not 2
        first_raw = _make_msg(b"a", PublicCodec.RAW, seqno=1)
        middle_gzip = _make_msg(b"b", PublicCodec.GZIP, seqno=2)
        second_raw = _make_msg(b"c", PublicCodec.RAW, seqno=3)
        groups = _split_messages_for_send([first_raw, middle_gzip, second_raw])
        assert groups == [[first_raw], [middle_gzip], [second_raw]]

    def test_size_limit_splits_same_codec_group(self, monkeypatch):
        from .topic_writer import _message_data_overhead
        from .. import connection

        # patch limit so that one 5-byte message fits but two do not
        monkeypatch.setattr(connection, "_DEFAULT_MAX_GRPC_MESSAGE_SIZE", _message_data_overhead + 5)
        big = _make_msg(b"hello", PublicCodec.RAW, seqno=1)
        small = _make_msg(b"x", PublicCodec.RAW, seqno=2)
        assert _split_messages_for_send([big, small]) == [[big], [small]]

    def test_size_limit_not_exceeded_keeps_group(self):
        first = _make_msg(b"hello", PublicCodec.RAW, seqno=1)
        second = _make_msg(b"world", PublicCodec.RAW, seqno=2)
        assert _split_messages_for_send([first, second]) == [[first, second]]

    def test_codec_split_takes_priority_over_size(self):
        # tiny messages but different codecs — must still split by codec
        raw_msg = _make_msg(b"a", PublicCodec.RAW, seqno=1)
        gzip_msg = _make_msg(b"b", PublicCodec.GZIP, seqno=2)
        groups = _split_messages_for_send([raw_msg, gzip_msg])
        assert len(groups) == 2
        assert groups[0][0].codec == PublicCodec.RAW
        assert groups[1][0].codec == PublicCodec.GZIP


class TestMessagesToProtoRequests:
    """Tests for messages_to_proto_requests: one request per codec/size group."""

    def test_empty(self):
        assert messages_to_proto_requests([], tx_identity=None) == []

    def test_single_message_produces_one_request(self):
        single = _make_msg(b"hello", PublicCodec.RAW, seqno=1)
        assert len(messages_to_proto_requests([single], tx_identity=None)) == 1

    def test_request_codec_matches_messages(self):
        # each codec must be reflected on the resulting request
        for codec, payload in ((PublicCodec.RAW, b"a"), (PublicCodec.GZIP, b"b")):
            requests = messages_to_proto_requests([_make_msg(payload, codec, seqno=1)], tx_identity=None)
            assert requests[0].value.codec == codec

    def test_same_codec_produces_single_request(self):
        batch = [_make_msg(b"x", PublicCodec.RAW, seqno=seqno) for seqno in range(1, 4)]
        requests = messages_to_proto_requests(batch, tx_identity=None)
        assert len(requests) == 1
        assert len(requests[0].value.messages) == 3

    def test_different_codecs_produce_separate_requests(self):
        raw_msg = _make_msg(b"a", PublicCodec.RAW, seqno=1)
        gzip_msg = _make_msg(b"b", PublicCodec.GZIP, seqno=2)
        requests = messages_to_proto_requests([raw_msg, gzip_msg], tx_identity=None)
        assert len(requests) == 2
        assert requests[0].value.codec == PublicCodec.RAW
        assert requests[1].value.codec == PublicCodec.GZIP

    def test_size_exceeded_produces_multiple_requests(self, monkeypatch):
        from .topic_writer import _message_data_overhead
        from .. import connection

        monkeypatch.setattr(connection, "_DEFAULT_MAX_GRPC_MESSAGE_SIZE", _message_data_overhead + 5)
        big = _make_msg(b"hello", PublicCodec.RAW, seqno=1)
        small = _make_msg(b"x", PublicCodec.RAW, seqno=2)
        requests = messages_to_proto_requests([big, small], tx_identity=None)
        assert len(requests) == 2
        assert requests[0].value.messages[0].seq_no == 1
        assert requests[1].value.messages[0].seq_no == 2

    def test_messages_order_preserved_within_request(self):
        ordered = [_make_msg(f"msg{seqno}".encode(), PublicCodec.RAW, seqno=seqno) for seqno in range(1, 5)]
        requests = messages_to_proto_requests(ordered, tx_identity=None)
        assert len(requests) == 1
        assert [m.seq_no for m in requests[0].value.messages] == [1, 2, 3, 4]
Loading