Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 97 additions & 0 deletions docs/docs/pypaimon/pytorch.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,3 +58,100 @@ When the `streaming` parameter is true, it will iteratively read;
when it is false, it will read the full amount of data into memory.

**`prefetch_concurrency`** (default: 1): When streaming is true, number of threads used for parallel prefetch within each DataLoader worker. Set to a value greater than 1 to partition splits across threads and increase read throughput. Has no effect when streaming is false.

## Shuffle

PyPaimon supports streaming shuffle for PyTorch `IterableDataset`. The shuffle
pipeline can be composed of three layers:

1. **Chunk shuffle**: split files into row chunks during scan planning and
shuffle the generated chunk splits. This is enabled by
`TableScan.with_chunk_shuffle(seed, chunk_size)`.
2. **Split interleave**: read from multiple splits in round-robin order inside
each DataLoader worker.
3. **Buffer shuffle**: apply a reservoir-style row shuffle buffer before rows
are yielded to PyTorch.

Chunk shuffle is a scan planning feature for append tables, including
Data Evolution append tables. For Data Evolution tables, chunk shuffle keeps
row-id-aligned data files and sidecar files together while slicing by row-id
range. Chunk shuffle should be used with file formats that **support random
access**. Currently, the random-access file formats are Lance, Vortex, Row, and
Blob. Primary-key tables and deletion-vector scans are not supported by
`with_chunk_shuffle`.

The second and third layers are Dataset features. They work on the splits you
pass to `to_torch`, so they can be used with either normal splits or
chunk-shuffled splits.

### Use Dataset Shuffle Only

Use this when normal scan splits are enough and you only want split interleave
plus row buffer shuffle:

```python
from torch.utils.data import DataLoader

table_scan = read_builder.new_scan()
table_read = read_builder.new_read()
splits = table_scan.plan().splits()

dataset = table_read.to_torch(
splits,
streaming=True,
shuffle=True,
seed=42,
buffer_size=1000,
max_buffer_input_splits=10,
)

dataloader = DataLoader(
dataset,
batch_size=32,
num_workers=2,
shuffle=False,
)
```

`buffer_size` controls the row shuffle buffer. Larger values produce a better
approximation of global shuffle, at the cost of more memory. If
`max_buffer_input_splits` is `1`, split interleave is skipped and only buffer
shuffle is applied. `shuffle=True` requires `streaming=True` and does not
support `prefetch_concurrency > 1`.

### Use All Three Layers

For append tables, enable chunk shuffle during scan planning, then enable
Dataset shuffle when converting to PyTorch:

```python
from torch.utils.data import DataLoader

seed = 42

table_scan = read_builder.new_scan().with_chunk_shuffle(
seed=seed,
chunk_size=1000,
)
table_read = read_builder.new_read()
splits = table_scan.plan().splits()

dataset = table_read.to_torch(
splits,
streaming=True,
shuffle=True,
seed=seed,
buffer_size=1000,
max_buffer_input_splits=10,
)

dataloader = DataLoader(
dataset,
batch_size=32,
num_workers=2,
shuffle=False,
)
```

Call `dataset.set_epoch(epoch)` before creating or iterating a DataLoader for a
new training epoch if you want a different buffer-shuffle order for each epoch.
224 changes: 182 additions & 42 deletions paimon-python/pypaimon/read/datasource/torch_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,9 @@
Module to read a Paimon table into PyTorch Dataset.
"""
import queue
import random
import threading
from typing import List
from typing import Iterator, List

import torch
from torch.utils.data import Dataset, IterableDataset
Expand All @@ -29,6 +30,12 @@
from pypaimon.read.table_read import TableRead


def _share_epoch_with_torch_workers(value):
    """
    Return ``value`` as a tensor whose storage has been moved to shared memory.

    An existing tensor is shared in place and returned as-is; any other value
    is first wrapped in a new ``torch.long`` tensor. The shared tensor is
    intended to hold the epoch counter so that an in-place update can be seen
    by DataLoader worker processes.
    """
    tensor = (
        value
        if isinstance(value, torch.Tensor)
        else torch.tensor(value, dtype=torch.long)
    )
    return tensor.share_memory_()


class TorchDataset(Dataset):
"""
PyTorch Dataset implementation for reading Paimon table data.
Expand Down Expand Up @@ -76,7 +83,44 @@ def __getitem__(self, index: int):
return self._data[index]


class TorchIterDataset(IterableDataset):
class _BaseTorchIterDataset(IterableDataset):
    """
    Common base for streaming PyTorch datasets that read rows from Paimon splits.

    Stores the table reader, the planned splits and the output column names,
    and offers two helpers to subclasses: converting a row into a dict and
    selecting the splits that belong to the current DataLoader worker.
    """

    def __init__(self, table_read: TableRead, splits: List[Split]):
        """
        Args:
            table_read: TableRead used to read the splits.
            splits: Splits to read.
        """
        self.table_read = table_read
        self.splits = splits
        # Column names, in read_type order; used as keys of the produced dicts.
        self.field_names = [field.name for field in table_read.read_type]

    def _row_to_dict(self, offset_row) -> dict:
        """Build a ``{column name: value}`` dict by reading each field by position."""
        return {
            column: offset_row.get_field(position)
            for position, column in enumerate(self.field_names)
        }

    def _worker_splits(self, worker_info) -> List[Split]:
        """
        Select the splits assigned to one DataLoader worker.

        Args:
            worker_info: Result of ``torch.utils.data.get_worker_info()``;
                ``None`` means there is no worker partitioning.

        Returns:
            All splits when ``worker_info`` is ``None``; otherwise a contiguous
            slice of ``self.splits``. Splits are divided as evenly as possible
            and the lowest-numbered workers absorb the remainder, one extra
            split each.
        """
        if worker_info is None:
            return self.splits

        rank = worker_info.id
        base, extra = divmod(len(self.splits), worker_info.num_workers)
        # Every worker before this one owns `base` splits, plus one more for
        # each of them that falls within the first `extra` workers.
        begin = rank * base + min(rank, extra)
        count = base + 1 if rank < extra else base
        return self.splits[begin:begin + count]


class TorchIterDataset(_BaseTorchIterDataset):
"""
PyTorch IterableDataset implementation for reading Paimon table data.

Expand Down Expand Up @@ -104,18 +148,8 @@ def __init__(self, table_read: TableRead, splits: List[Split], prefetch_concurre
this worker (default 1). When > 1, splits are partitioned across
threads to increase read throughput.
"""
self.table_read = table_read
self.splits = splits
super().__init__(table_read, splits)
self.prefetch_concurrency = max(1, int(prefetch_concurrency))
# Get field names from read_type
self.field_names = [field.name for field in table_read.read_type]

def _row_to_dict(self, offset_row) -> dict:
row_dict = {}
for i, field_name in enumerate(self.field_names):
value = offset_row.get_field(i)
row_dict[field_name] = value
return row_dict

def __iter__(self):
"""
Expand All @@ -128,30 +162,7 @@ def __iter__(self):
row data of dict type, where keys are column names
"""
worker_info = torch.utils.data.get_worker_info()

if worker_info is None:
# Single-process data loading, iterate over all splits
splits_to_process = self.splits
else:
# Multi-process data loading, partition splits across workers
worker_id = worker_info.id
num_workers = worker_info.num_workers

# Calculate start and end indices for this worker
# Distribute splits evenly by slicing
total_splits = len(self.splits)
splits_per_worker = total_splits // num_workers
remainder = total_splits % num_workers

# Workers with id < remainder get one extra split
if worker_id < remainder:
start_idx = worker_id * (splits_per_worker + 1)
end_idx = start_idx + splits_per_worker + 1
else:
start_idx = worker_id * splits_per_worker + remainder
end_idx = start_idx + splits_per_worker

splits_to_process = self.splits[start_idx:end_idx]
splits_to_process = self._worker_splits(worker_info)

if self.prefetch_concurrency > 1:
for row in self._iter_rows(splits_to_process):
Expand All @@ -161,11 +172,7 @@ def __iter__(self):
worker_iterator = self.table_read.to_iterator(splits_to_process)

for offset_row in worker_iterator:
row_dict = {}
for i, field_name in enumerate(self.field_names):
value = offset_row.get_field(i)
row_dict[field_name] = value
yield row_dict
yield self._row_to_dict(offset_row)

def _iter_rows(self, splits: List[Split]):
n = min(self.prefetch_concurrency, len(splits))
Expand Down Expand Up @@ -221,3 +228,136 @@ def producer(split_group: List):
stop.set()
for t in threads:
t.join(timeout=self._PREFETCH_JOIN_TIMEOUT_SEC)


class TorchShuffledIterDataset(_BaseTorchIterDataset):
"""
PyTorch IterableDataset with Paimon-controlled streaming shuffle.

This dataset consumes pre-planned splits, then mixes rows with split
interleaving and a shuffle buffer. Chunk-level shuffle, when needed,
stays in TableScan.with_chunk_shuffle().
"""

def __init__(
    self,
    table_read: TableRead,
    splits: List[Split],
    seed: int = 0,
    buffer_size: int = 1000,
    max_buffer_input_splits: int = 10,
):
    """
    Args:
        table_read: TableRead used to read the splits.
        splits: Pre-planned splits to read.
        seed: Base seed of the buffer-shuffle random generator.
        buffer_size: Capacity, in rows, of the shuffle buffer.
        max_buffer_input_splits: Maximum number of splits read in
            round-robin at the same time; ``1`` disables interleaving.

    Raises:
        ValueError: If ``seed`` is not an int, or if ``buffer_size`` or
            ``max_buffer_input_splits`` is not a positive int.
    """
    super().__init__(table_read, splits)
    self.seed = self._require_int(seed, "seed")
    # Validated in declaration order so the first invalid option is reported.
    positive_options = (
        ("buffer_size", buffer_size),
        ("max_buffer_input_splits", max_buffer_input_splits),
    )
    for option_name, option_value in positive_options:
        setattr(
            self,
            option_name,
            self._require_positive_int(option_value, option_name),
        )
    # The epoch lives in a shared-memory tensor; it starts at 0.
    self._epoch = _share_epoch_with_torch_workers(0)

def __setstate__(self, state):
    """
    Restore the instance from ``state`` and put the epoch back in shared memory.

    NOTE(review): presumably needed because an unpickled epoch tensor may no
    longer be backed by shared memory -- confirm against the pickling path.
    """
    self.__dict__ = state
    restored_epoch = self._epoch
    self._epoch = _share_epoch_with_torch_workers(restored_epoch)

@property
def epoch(self) -> int:
return int(self._epoch)

@epoch.setter
def epoch(self, epoch: int) -> None:
epoch = self._require_int(epoch, "epoch")
self._epoch += epoch - self._epoch

@staticmethod
def _require_int(value: int, name: str) -> int:
if not isinstance(value, int):
raise ValueError("%s must be an int" % name)
return value

@staticmethod
def _require_positive_int(value: int, name: str) -> int:
if not isinstance(value, int) or value <= 0:
raise ValueError("%s must be a positive int" % name)
return value

def set_epoch(self, epoch: int) -> "TorchShuffledIterDataset":
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

set_epoch() does not work once the dataset is already held by persistent DataLoader workers. With DataLoader(..., num_workers>0, persistent_workers=True), the worker processes keep their own Dataset instances alive across epochs, so calling dataset.set_epoch(1) in the parent process only updates the parent copy. The workers still use the old self.epoch here when building the buffer-shuffle RNG, which makes the shuffle order repeat even though the docs say callers can set the epoch before iterating the DataLoader for the next epoch. Could we either propagate epoch through worker-visible shared state, or document/reject persistent workers and require rebuilding the DataLoader? It would be good to add a test for this mode as well.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! I looked into the DataLoader mechanism, which works as shown below:
image

The main process creates multiple worker processes through Linux `fork`; they share the same memory, but under copy-on-write (COW) protection.
If `persistent_workers` is true, the workers are reused across epochs. If the main process then changes the epoch, the workers won't see the change because of COW.

Following the approach used by the Hugging Face Dataset, I now use `torch.Tensor.share_memory_()`, which stores the shared data in a special file so that changes are visible to all processes.

self.epoch = epoch
return self

def __iter__(self):
    """
    Yield this worker's rows as dicts, in shuffled order.

    Rows come from the splits assigned to the current DataLoader worker.
    They are read split by split when ``max_buffer_input_splits`` is 1 and
    interleaved across splits otherwise, then passed through the shuffle
    buffer.
    """
    worker_info = torch.utils.data.get_worker_info()
    # Outside a DataLoader worker, worker id 0 is used for the RNG seed.
    worker_id = 0 if worker_info is None else worker_info.id
    splits_to_process = self._worker_splits(worker_info)

    read_rows = (
        self._iter_ordered_rows
        if self.max_buffer_input_splits == 1
        else self._iter_interleaved_rows
    )
    source_rows = read_rows(splits_to_process)
    yield from self._iter_buffer_shuffled_rows(source_rows, worker_id)

def _iter_ordered_rows(self, splits: List[Split]) -> Iterator[dict]:
    """Yield the rows of ``splits`` as dicts, in the order the reader returns them."""
    row_source = self.table_read.to_iterator(splits)
    yield from map(self._row_to_dict, row_source)

def _iter_interleaved_rows(self, splits: List[Split]) -> Iterator[dict]:
    """
    Yield rows as dicts, taking one row at a time from each open split in turn.

    At most ``max_buffer_input_splits`` split iterators are open at once.
    When one is exhausted it is closed and removed, and the next unread
    split (if any) is opened at the end of the rotation. Iterators still
    open when this generator finishes or is closed early are closed too.
    """
    if not splits:
        return

    pending = iter(splits)
    readers: List[Iterator] = []

    def open_next_reader() -> bool:
        # Open an iterator over the next unread split; False when none is left.
        try:
            upcoming = next(pending)
        except StopIteration:
            return False
        readers.append(self.table_read.to_iterator([upcoming]))
        return True

    initial_count = min(self.max_buffer_input_splits, len(splits))
    for _ in range(initial_count):
        open_next_reader()

    cursor = 0
    try:
        while readers:
            # The cursor never exceeds len(readers), so the modulo only wraps
            # it back to the first reader after passing the last one.
            cursor %= len(readers)
            current = readers[cursor]
            try:
                offset_row = next(current)
            except StopIteration:
                self._close_iterator(current)
                del readers[cursor]
                open_next_reader()
                # Cursor is left alone: it now points at the reader that
                # followed the removed one.
                continue

            yield self._row_to_dict(offset_row)
            cursor += 1
    finally:
        for leftover in readers:
            self._close_iterator(leftover)

@staticmethod
def _close_iterator(row_iter) -> None:
close = getattr(row_iter, "close", None)
if close is not None:
close()

def _iter_buffer_shuffled_rows(
self,
rows: Iterator[dict],
worker_id: int,
) -> Iterator[dict]:
rng = random.Random(self.seed + self.epoch * 1000003 + worker_id)
buffer = []
for row in rows:
if len(buffer) < self.buffer_size:
buffer.append(row)
continue
idx = rng.randint(0, self.buffer_size - 1)
yield buffer[idx]
buffer[idx] = row

rng.shuffle(buffer)
for row in buffer:
yield row
Loading
Loading