Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 58 additions & 2 deletions python/python/lance/torch/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,52 @@ def __call__(
) -> Union[dict[str, torch.Tensor], torch.Tensor]: ...


def _is_bfloat16_type(t: pa.DataType) -> bool:
    """Return True when *t* is the lance bfloat16 Arrow extension type."""
    if not isinstance(t, pa.ExtensionType):
        return False
    return t.extension_name == "lance.bfloat16"


def _bf16_to_tensor(arr: pa.Array) -> torch.Tensor:
    """Convert a bfloat16 extension array to a ``torch.bfloat16`` tensor.

    The raw 2-byte values are reinterpreted as ``uint16`` and then viewed as
    bfloat16 — the two share an identical memory layout — so the conversion
    is zero-copy when the array has no nulls.  Null slots are rewritten to
    NaN, which forces a single clone of the data.
    """
    # TODO(review): consider a sanity check that the storage is a 16-bit
    # fixed-width type before reinterpreting the buffer.
    values = arr.storage if isinstance(arr.type, pa.ExtensionType) else arr
    data_buf = values.buffers()[1]
    byte_offset = values.offset * 2  # each bf16 value occupies 2 bytes
    try:
        with warnings.catch_warnings():
            # torch warns because the Arrow buffer is read-only; we never
            # write through this view, so suppress the noise.
            warnings.filterwarnings(
                "ignore",
                message="The given buffer is not writable",
                category=UserWarning,
            )
            out = torch.frombuffer(
                memoryview(data_buf),
                dtype=torch.uint16,
                count=len(values),
                offset=byte_offset,
            ).view(torch.bfloat16)
    except (AttributeError, RuntimeError, TypeError):
        # Fallback for torch builds where frombuffer is unavailable or
        # rejects this buffer: go through a zero-copy numpy view instead.
        as_u16 = np.frombuffer(
            data_buf, dtype=np.uint16, count=len(values), offset=byte_offset
        )
        with warnings.catch_warnings():
            warnings.filterwarnings(
                "ignore",
                message="The given NumPy array is not writable",
                category=UserWarning,
            )
            out = torch.from_numpy(as_u16).view(torch.bfloat16)
    if arr.null_count > 0:
        # Writing NaN into null slots needs a writable tensor, so the copy
        # happens only on this (null-bearing) path.
        out = out.clone()
        mask = torch.from_numpy(arr.is_null().to_numpy(zero_copy_only=False))
        out[mask] = float("nan")
    return out


# Convert an Arrow FSL array into a 2D torch tensor
def _fsl_to_tensor(arr: pa.FixedSizeListArray, dimension: int) -> torch.Tensor:
# Note: FixedSizeListArray.values does not take offset/len into account and
Expand Down Expand Up @@ -104,6 +150,14 @@ def _to_tensor(
or pa.types.is_integer(arr.type.value_type)
):
tensor = _fsl_to_tensor(arr, arr.type.list_size)
elif pa.types.is_fixed_size_list(arr.type) and _is_bfloat16_type(
arr.type.value_type
):
values = arr.values
start = arr.offset * arr.type.list_size
num_vals = len(arr) * arr.type.list_size
values = values.slice(start, num_vals)
tensor = _bf16_to_tensor(values).view(-1, arr.type.list_size)
elif (
pa.types.is_integer(arr.type)
or pa.types.is_floating(arr.type)
Expand All @@ -113,13 +167,15 @@ def _to_tensor(

if uint64_as_int64 and tensor.dtype == torch.uint64:
tensor = tensor.to(torch.int64)
elif _is_bfloat16_type(arr.type):
tensor = _bf16_to_tensor(arr)
elif hf_converter is not None:
tensor = hf_converter.to_pytorch(col, arr)

if tensor is None:
raise ValueError(
"Only support FixedSizeList<f16/f32/f64> or "
+ f"numeric values, got: {arr.type}"
"Only support FixedSizeList<f16/bf16/f32/f64> or "
+ f"numeric/bfloat16 values, got: {arr.type}"
)

del arr
Expand Down
95 changes: 94 additions & 1 deletion python/python/tests/torch_tests/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,11 @@
from lance.sampler import ShardedBatchSampler, ShardedFragmentSampler

torch = pytest.importorskip("torch")
from lance.torch.data import LanceDataset, SafeLanceDataset # noqa: E402
from lance.torch.data import ( # noqa: E402
LanceDataset,
SafeLanceDataset,
_bf16_to_tensor,
)


def test_iter_over_dataset_fixed_shape_tensor(tmp_path):
Expand Down Expand Up @@ -326,6 +330,95 @@ def to_tensor_fn(batch, *args, **kwargs):
assert first["val"].shape == (4, 100)


def test_iter_over_dataset_bfloat16(tmp_path):
    """bfloat16 vector columns should come back as torch.bfloat16 tensors."""
    ml_dtypes = pytest.importorskip("ml_dtypes")
    from lance.arrow import BFloat16Array

    dim = 32
    num_rows = 128
    # Random vectors, produced as float32 and then narrowed to bfloat16.
    raw = np.random.random(num_rows * dim).astype("f")
    narrowed = raw.astype(ml_dtypes.bfloat16)

    # Assemble a FixedSizeList<bf16> column alongside an int32 id column.
    flat = BFloat16Array.from_numpy(narrowed)
    vectors = pa.FixedSizeListArray.from_arrays(flat, dim)
    ids = pa.array(range(num_rows), type=pa.int32())
    tbl = pa.Table.from_arrays([ids, vectors], ["ids", "vec"])

    ds = lance.write_dataset(tbl, tmp_path / "data.lance", max_rows_per_group=32)

    torch_ds = LanceDataset(ds, batch_size=16, columns=["ids", "vec"])

    seen = 0
    for batch in torch_ds:
        assert set(batch.keys()) == {"ids", "vec"}
        assert batch["vec"].dtype == torch.bfloat16
        assert batch["vec"].shape[1] == dim
        assert batch["ids"].dtype == torch.int32
        seen += batch["vec"].shape[0]
    assert seen == num_rows


def test_scalar_bfloat16_column(tmp_path):
    """A flat (non-list) bfloat16 column should yield torch.bfloat16 batches."""
    ml_dtypes = pytest.importorskip("ml_dtypes")
    from lance.arrow import BFloat16Array

    num_rows = 64
    # float32 randoms narrowed to bfloat16.
    values = np.random.random(num_rows).astype("f").astype(ml_dtypes.bfloat16)

    col = BFloat16Array.from_numpy(values)
    tbl = pa.Table.from_arrays([col], ["val"])

    ds = lance.write_dataset(tbl, tmp_path / "data.lance")

    torch_ds = LanceDataset(ds, batch_size=16, columns=["val"])

    seen = 0
    for batch in torch_ds:
        assert batch.dtype == torch.bfloat16
        seen += batch.shape[0]
    assert seen == num_rows


def test_bf16_to_tensor_zero_copy_without_nulls():
    """Non-null bf16 arrays should alias the Arrow data buffer."""
    ml_dtypes = pytest.importorskip("ml_dtypes")
    from lance.arrow import BFloat16Array

    source = np.array([1.0, 2.0, 3.0, 4.0], dtype=ml_dtypes.bfloat16)
    sliced = BFloat16Array.from_numpy(source).slice(1, 2)

    result = _bf16_to_tensor(sliced)

    assert result.dtype == torch.bfloat16
    expected = torch.tensor([2.0, 3.0], dtype=torch.float32)
    assert torch.equal(result.to(torch.float32), expected)
    # Zero-copy: the tensor points at the sliced position inside the
    # Arrow data buffer (offset is in values; 2 bytes per bf16).
    base = sliced.storage.buffers()[1].address
    assert result.data_ptr() == base + sliced.storage.offset * 2


def test_bf16_to_tensor_clones_when_nulls_present():
    """Null replacement requires a writable tensor, so the Arrow buffer is cloned."""
    arr = lance.arrow.bfloat16_array([1.0, None, 3.0])

    result = _bf16_to_tensor(arr)

    assert result.dtype == torch.bfloat16
    # A clone means the data pointer no longer aliases the Arrow buffer.
    buffer_addr = arr.storage.buffers()[1].address + arr.storage.offset * 2
    assert result.data_ptr() != buffer_addr
    assert result[0].to(torch.float32).item() == pytest.approx(1.0)
    assert torch.isnan(result[1])
    assert result[2].to(torch.float32).item() == pytest.approx(3.0)


def test_safe_lance_dataset_worker_uses_dataset_options(tmp_path: Path):
"""Worker processes must reopen the dataset with dataset_options.

Expand Down
Loading