From f6b14cbc169889d96530525dbee0269d36b527d4 Mon Sep 17 00:00:00 2001 From: quettabit <27509167+quettabit@users.noreply.github.com> Date: Thu, 21 May 2026 00:11:41 -0600 Subject: [PATCH] initial commit --- tests/test_correctness.py | 65 ++++++++++++++++++++++++--------------- 1 file changed, 41 insertions(+), 24 deletions(-) diff --git a/tests/test_correctness.py b/tests/test_correctness.py index c871952..5293ca2 100644 --- a/tests/test_correctness.py +++ b/tests/test_correctness.py @@ -3,9 +3,10 @@ import pytest -from s2_sdk import Batching, ReadLimit, Record, Retry, S2Stream, SeqNum +from s2_sdk import Batching, Record, Retry, S2Stream, SeqNum, SequencedRecord -TOTAL_RECORDS = 1024 +_NUM_RECORDS = 1024 +_RECORD_IDX_HEADER = b"record-idx" @pytest.fixture(scope="session") @@ -20,18 +21,16 @@ def basin_prefix() -> str: @pytest.mark.correctness @pytest.mark.asyncio -async def test_concurrent_producer_and_consumer_remain_gapless(stream: S2Stream): +async def test_gapless_seq_nums_and_record_order_during_concurrent_append_and_read( + stream: S2Stream, +): async def read_records() -> None: - highest_contiguous_index = -1 + next_record_idx = 0 last_seq_num: int | None = None - observed_records = 0 + num_records_read = 0 - async for batch in stream.read_session( - start=SeqNum(0), limit=ReadLimit(count=TOTAL_RECORDS), wait=60 - ): + async for batch in stream.read_session(start=SeqNum(0), wait=60): for record in batch.records: - assert observed_records < TOTAL_RECORDS - seq_num = record.seq_num if last_seq_num is None: assert seq_num == 0 @@ -39,30 +38,48 @@ async def read_records() -> None: assert seq_num == last_seq_num + 1 last_seq_num = seq_num - body = record.body.decode() - index = int(body) - assert 0 <= index < TOTAL_RECORDS - assert index <= highest_contiguous_index + 1 + record_idx = _record_idx(record) + assert 0 <= record_idx < _NUM_RECORDS + assert record_idx <= next_record_idx + + if record_idx == next_record_idx: + next_record_idx = record_idx + 1 + num_records_read += 1 - if index == highest_contiguous_index + 1: - highest_contiguous_index = index - observed_records += 1 + if next_record_idx == _NUM_RECORDS: + assert last_seq_num + 1 == num_records_read + assert num_records_read >= _NUM_RECORDS + return - assert highest_contiguous_index == TOTAL_RECORDS - 1 - assert last_seq_num == TOTAL_RECORDS - 1 - assert observed_records == TOTAL_RECORDS + pytest.fail( + "read session ended before all records were read: " + f"next_record_idx={next_record_idx}, " + f"num_records_read={num_records_read}" + ) async def append_records() -> None: async with stream.producer(batching=Batching(max_records=16)) as producer: tickets = [] - for i in range(TOTAL_RECORDS): - ticket = await producer.submit(Record(body=str(i).encode())) + for idx in range(_NUM_RECORDS): + ticket = await producer.submit(_indexed_record(idx)) tickets.append(ticket) for ticket in tickets: - ack = await ticket - assert ack.seq_num >= 0 + await ticket async with asyncio.TaskGroup() as task_group: task_group.create_task(read_records()) task_group.create_task(append_records()) + + +def _indexed_record(idx: int) -> Record: + return Record( + body=b"", + headers=[(_RECORD_IDX_HEADER, str(idx).encode())], + ) + + +def _record_idx(record: SequencedRecord) -> int: + values = [value for key, value in record.headers if key == _RECORD_IDX_HEADER] + assert len(values) == 1 + return int(values[0])