Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 41 additions & 24 deletions tests/test_correctness.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@

import pytest

from s2_sdk import Batching, ReadLimit, Record, Retry, S2Stream, SeqNum
from s2_sdk import Batching, Record, Retry, S2Stream, SeqNum, SequencedRecord

TOTAL_RECORDS = 1024
_NUM_RECORDS = 1024
_RECORD_IDX_HEADER = b"record-idx"


@pytest.fixture(scope="session")
Expand All @@ -20,49 +21,65 @@ def basin_prefix() -> str:

@pytest.mark.correctness
@pytest.mark.asyncio
async def test_concurrent_producer_and_consumer_remain_gapless(stream: S2Stream):
async def test_gapless_seq_nums_and_record_order_during_concurrent_append_and_read(
stream: S2Stream,
):
async def read_records() -> None:
highest_contiguous_index = -1
next_record_idx = 0
last_seq_num: int | None = None
observed_records = 0
num_records_read = 0

async for batch in stream.read_session(
start=SeqNum(0), limit=ReadLimit(count=TOTAL_RECORDS), wait=60
):
async for batch in stream.read_session(start=SeqNum(0), wait=60):
for record in batch.records:
assert observed_records < TOTAL_RECORDS

seq_num = record.seq_num
if last_seq_num is None:
assert seq_num == 0
else:
assert seq_num == last_seq_num + 1
last_seq_num = seq_num

body = record.body.decode()
index = int(body)
assert 0 <= index < TOTAL_RECORDS
assert index <= highest_contiguous_index + 1
record_idx = _record_idx(record)
assert 0 <= record_idx < _NUM_RECORDS
assert record_idx <= next_record_idx

if record_idx == next_record_idx:
next_record_idx = record_idx + 1
num_records_read += 1

if index == highest_contiguous_index + 1:
highest_contiguous_index = index
observed_records += 1
if next_record_idx == _NUM_RECORDS:
assert last_seq_num + 1 == num_records_read
assert num_records_read >= _NUM_RECORDS
return
Comment thread
quettabit marked this conversation as resolved.

assert highest_contiguous_index == TOTAL_RECORDS - 1
assert last_seq_num == TOTAL_RECORDS - 1
assert observed_records == TOTAL_RECORDS
pytest.fail(
"read session ended before all records were read: "
f"next_record_idx={next_record_idx}, "
f"num_records_read={num_records_read}"
)

async def append_records() -> None:
async with stream.producer(batching=Batching(max_records=16)) as producer:
tickets = []
for i in range(TOTAL_RECORDS):
ticket = await producer.submit(Record(body=str(i).encode()))
for idx in range(_NUM_RECORDS):
ticket = await producer.submit(_indexed_record(idx))
tickets.append(ticket)

for ticket in tickets:
ack = await ticket
assert ack.seq_num >= 0
await ticket

async with asyncio.TaskGroup() as task_group:
task_group.create_task(read_records())
task_group.create_task(append_records())


def _indexed_record(idx: int) -> Record:
    """Build a record that carries ``idx`` in its headers.

    The record body is empty; the index travels in a single header keyed by
    ``_RECORD_IDX_HEADER`` whose value is the encoded decimal string of
    ``idx``.
    """
    encoded_idx = str(idx).encode()
    idx_header = (_RECORD_IDX_HEADER, encoded_idx)
    return Record(body=b"", headers=[idx_header])


def _record_idx(record: SequencedRecord) -> int:
    """Return the integer index stored in the record's index header.

    Collects every header value whose key equals ``_RECORD_IDX_HEADER`` and
    asserts that exactly one such header is present before parsing it.
    """
    matches = []
    for name, raw_value in record.headers:
        if name == _RECORD_IDX_HEADER:
            matches.append(raw_value)

    # Exactly one index header is expected on every record.
    assert len(matches) == 1
    return int(matches[0])
Loading