Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,6 @@ node_modules/
uv.lock
pixi.lock

# Claude Code local files
.claude/
spatialdata_pr_context.md
7 changes: 6 additions & 1 deletion src/spatialdata/_core/query/relational_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -1069,7 +1069,12 @@ def get_values(
if origin == "obs":
df = obs[value_key_values].copy()
if origin == "var":
matched_table.obs = pd.DataFrame(obs)
# When the table came from anndata.experimental.read_lazy, obs is a Dataset2D, not a
# DataFrame, and pd.DataFrame(obs) returns a malformed frame. Materialize via to_memory().
if isinstance(obs, pd.DataFrame):
matched_table.obs = pd.DataFrame(obs)
else:
matched_table.obs = obs.to_memory()
if table_layer is None:
x = matched_table[:, value_key_values].X
else:
Expand Down
61 changes: 28 additions & 33 deletions src/spatialdata/_core/spatialdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,7 @@

if TYPE_CHECKING:
from spatialdata._core.query.spatial_query import BaseSpatialRequest
from spatialdata._io.format import (
SpatialDataContainerFormatType,
SpatialDataFormatType,
)
from spatialdata._io.format import SpatialDataContainerFormatType, SpatialDataFormatType


class SpatialData:
Expand Down Expand Up @@ -232,9 +229,7 @@ def get_annotated_regions(table: AnnData) -> list[str]:
-------
The annotated regions.
"""
from spatialdata.models.models import (
_get_region_metadata_from_region_key_column,
)
from spatialdata.models.models import _get_region_metadata_from_region_key_column

return _get_region_metadata_from_region_key_column(table)

Expand Down Expand Up @@ -705,9 +700,7 @@ def _filter_tables(
if table is not None and len(table) != 0:
tables[table_name] = table
elif by == "elements":
from spatialdata._core.query.relational_query import (
_filter_table_by_elements,
)
from spatialdata._core.query.relational_query import _filter_table_by_elements

assert elements_dict is not None
table = _filter_table_by_elements(table, elements_dict=elements_dict)
Expand All @@ -732,10 +725,7 @@ def rename_coordinate_systems(self, rename_dict: dict[str, str]) -> None:
The method does not allow to rename a coordinate system into an existing one, unless the existing one is also
renamed in the same call.
"""
from spatialdata.transformations.operations import (
get_transformation,
set_transformation,
)
from spatialdata.transformations.operations import get_transformation, set_transformation

# check that the rename_dict is valid
old_names = self.coordinate_systems
Expand Down Expand Up @@ -1111,7 +1101,7 @@ def write(
overwrite: bool = False,
consolidate_metadata: bool = True,
update_sdata_path: bool = True,
sdata_formats: SpatialDataFormatType | list[SpatialDataFormatType] | None = None,
sdata_formats: (SpatialDataFormatType | list[SpatialDataFormatType] | None) = None,
shapes_geometry_encoding: Literal["WKB", "geoarrow"] | None = None,
raster_compressor: dict[Literal["lz4", "zstd"], int] | None = None,
) -> None:
Expand Down Expand Up @@ -1225,15 +1215,12 @@ def _write_element(
)

root_group, element_type_group, element_group = _get_groups_for_element(
zarr_path=zarr_container_path, element_type=element_type, element_name=element_name, use_consolidated=False
)
from spatialdata._io import (
write_image,
write_labels,
write_points,
write_shapes,
write_table,
zarr_path=zarr_container_path,
element_type=element_type,
element_name=element_name,
use_consolidated=False,
)
from spatialdata._io import write_image, write_labels, write_points, write_shapes, write_table
from spatialdata._io.format import _parse_formats

if parsed_formats is None:
Expand Down Expand Up @@ -1287,7 +1274,7 @@ def write_element(
self,
element_name: str | list[str],
overwrite: bool = False,
sdata_formats: SpatialDataFormatType | list[SpatialDataFormatType] | None = None,
sdata_formats: (SpatialDataFormatType | list[SpatialDataFormatType] | None) = None,
shapes_geometry_encoding: Literal["WKB", "geoarrow"] | None = None,
raster_compressor: dict[Literal["lz4", "zstd"], int] | None = None,
) -> None:
Expand Down Expand Up @@ -1573,7 +1560,10 @@ def write_channel_names(self, element_name: str | None = None) -> None:
# Mypy does not understand that path is not None so we have the check in the conditional
if element_type == "images" and self.path is not None:
_, _, element_group = _get_groups_for_element(
zarr_path=Path(self.path), element_type=element_type, element_name=element_name, use_consolidated=False
zarr_path=Path(self.path),
element_type=element_type,
element_name=element_name,
use_consolidated=False,
)

from spatialdata._io._utils import overwrite_channel_names
Expand Down Expand Up @@ -1624,19 +1614,18 @@ def write_transformations(self, element_name: str | None = None) -> None:
)
axes = get_axes_names(element)
if isinstance(element, DataArray | DataTree):
from spatialdata._io._utils import (
overwrite_coordinate_transformations_raster,
)
from spatialdata._io._utils import overwrite_coordinate_transformations_raster
from spatialdata._io.format import RasterFormats

raster_format = RasterFormats[element_group.metadata.attributes["spatialdata_attrs"]["version"]]
overwrite_coordinate_transformations_raster(
group=element_group, axes=axes, transformations=transformations, raster_format=raster_format
group=element_group,
axes=axes,
transformations=transformations,
raster_format=raster_format,
)
elif isinstance(element, DaskDataFrame | GeoDataFrame | AnnData):
from spatialdata._io._utils import (
overwrite_coordinate_transformations_non_raster,
)
from spatialdata._io._utils import overwrite_coordinate_transformations_non_raster

overwrite_coordinate_transformations_non_raster(
group=element_group,
Expand Down Expand Up @@ -1855,6 +1844,7 @@ def read(
file_path: str | Path | UPath | zarr.Group,
selection: tuple[str] | None = None,
reconsolidate_metadata: bool = False,
lazy: bool = False,
) -> SpatialData:
"""
Read a SpatialData object from a Zarr storage (on-disk or remote).
Expand All @@ -1867,6 +1857,11 @@ def read(
The elements to read (images, labels, points, shapes, table). If None, all elements are read.
reconsolidate_metadata
If the consolidated metadata store got corrupted this can lead to errors when trying to read the data.
lazy
If True, read tables lazily using anndata.experimental.read_lazy.
This keeps large tables out of memory until needed. Requires anndata >= 0.12.
Note: Images, labels, and points are always read lazily (using Dask).
This parameter only affects tables, which are normally loaded into memory.

Returns
-------
Expand All @@ -1879,7 +1874,7 @@ def read(

_write_consolidated_metadata(file_path)

return read_zarr(file_path, selection=selection)
return read_zarr(file_path, selection=selection, lazy=lazy)

@property
def images(self) -> Images:
Expand Down
38 changes: 29 additions & 9 deletions src/spatialdata/_io/io_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,38 @@
from anndata._io.specs import write_elem as write_adata
from ome_zarr.format import Format

from spatialdata._io.format import (
CurrentTablesFormat,
TablesFormats,
TablesFormatV01,
TablesFormatV02,
_parse_version,
)
from spatialdata._io.format import CurrentTablesFormat, TablesFormats, TablesFormatV01, TablesFormatV02, _parse_version
from spatialdata.models import TableModel, get_table_keys


def _read_table(store: str | Path) -> AnnData:
table = read_anndata_zarr(str(store))
def _read_table(store: str | Path, lazy: bool = False) -> AnnData:
"""
Read a table from a zarr store.

Parameters
----------
store
Path to the zarr store containing the table.
lazy
If True, read the table lazily using ``anndata.experimental.read_lazy``.
This keeps large matrices (X, layers) as dask arrays backed by zarr,
so they are only loaded into memory on demand. Requires anndata >= 0.12.

Returns
-------
The AnnData table, either lazily loaded or in-memory.

Raises
------
ImportError
If ``lazy=True`` but anndata >= 0.12 is not installed.
"""
if lazy:
from anndata.experimental import read_lazy

table = read_lazy(str(store))
else:
table = read_anndata_zarr(str(store))

f = zarr.open(Path(store), mode="r") # Path avoids zarr v3 URL-parsing special chars (e.g. #) in names
version = _parse_version(f, expect_attrs_key=False)
Expand Down
21 changes: 11 additions & 10 deletions src/spatialdata/_io/io_zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import os
import warnings
from collections.abc import Callable
from functools import partial
from json import JSONDecodeError
from pathlib import Path
from typing import Any, Literal, cast
Expand All @@ -17,11 +18,7 @@
from zarr.errors import ArrayNotFoundError

from spatialdata._core.spatialdata import SpatialData
from spatialdata._io._utils import (
BadFileHandleMethod,
_resolve_zarr_store,
handle_read_errors,
)
from spatialdata._io._utils import BadFileHandleMethod, _resolve_zarr_store, handle_read_errors
from spatialdata._io.io_points import _read_points
from spatialdata._io.io_raster import _read_multiscale
from spatialdata._io.io_shapes import _read_shapes
Expand Down Expand Up @@ -106,10 +103,7 @@ def get_raster_format_for_read(
-------
The ome-zarr format to use for reading the raster element.
"""
from spatialdata._io.format import (
sdata_zarr_version_to_ome_zarr_format,
sdata_zarr_version_to_raster_format,
)
from spatialdata._io.format import sdata_zarr_version_to_ome_zarr_format, sdata_zarr_version_to_raster_format

if sdata_version == "0.1":
group_version = group.metadata.attributes["multiscales"][0]["version"]
Expand All @@ -126,6 +120,7 @@ def read_zarr(
store: str | Path | UPath | zarr.Group,
selection: None | tuple[str] = None,
on_bad_files: Literal[BadFileHandleMethod.ERROR, BadFileHandleMethod.WARN] = BadFileHandleMethod.ERROR,
lazy: bool = False,
) -> SpatialData:
"""
Read a SpatialData dataset from a zarr store (on-disk or remote).
Expand All @@ -149,6 +144,12 @@ def read_zarr(
object is returned containing only elements that could be read. Failures can only be
determined from the warnings.

lazy
If True, read tables lazily using anndata.experimental.read_lazy.
This keeps large tables out of memory until needed. Requires anndata >= 0.12.
Note: Images, labels, and points are always read lazily (using Dask).
This parameter only affects tables, which are normally loaded into memory.

Returns
-------
A SpatialData object.
Expand Down Expand Up @@ -195,7 +196,7 @@ def read_zarr(
"labels": (_read_multiscale, "labels", labels),
"points": (_read_points, "points", points),
"shapes": (_read_shapes, "shapes", shapes),
"tables": (_read_table, "tables", tables),
"tables": (partial(_read_table, lazy=lazy), "tables", tables),
}
for group_name, (
read_func,
Expand Down
10 changes: 8 additions & 2 deletions src/spatialdata/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,11 +225,17 @@ def _inplace_fix_subset_categorical_obs(subset_adata: AnnData, original_adata: A
"""
if not hasattr(subset_adata, "obs") or not hasattr(original_adata, "obs"):
return
obs = pd.DataFrame(subset_adata.obs)
# Tables read via anndata.experimental.read_lazy have a Dataset2D obs instead of a DataFrame;
# pd.DataFrame() would silently malform it, so materialize with to_memory() in that case.
obs = pd.DataFrame(subset_adata.obs) if isinstance(subset_adata.obs, pd.DataFrame) else subset_adata.obs.to_memory()
original_obs = (
original_adata.obs if isinstance(original_adata.obs, pd.DataFrame) else original_adata.obs.to_memory()
)

for column in obs.columns:
is_categorical = isinstance(obs[column].dtype, pd.CategoricalDtype)
if is_categorical:
c = obs[column].cat.set_categories(original_adata.obs[column].cat.categories)
c = obs[column].cat.set_categories(original_obs[column].cat.categories)
obs[column] = c
subset_adata.obs = obs

Expand Down
36 changes: 34 additions & 2 deletions src/spatialdata/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,25 @@
ATTRS_KEY = "spatialdata_attrs"


def _is_lazy_anndata(adata: AnnData) -> bool:
"""Check if an AnnData object is lazily loaded.

Lazy AnnData objects (from anndata.experimental.read_lazy) have obs/var
stored as xarray Dataset2D instead of pandas DataFrame.

Parameters
----------
adata
The AnnData object to check.

Returns
-------
True if the AnnData is lazily loaded, False otherwise.
"""
# Check if obs is not a pandas DataFrame (lazy AnnData uses xarray Dataset2D)
return not isinstance(adata.obs, pd.DataFrame)


def _parse_transformations(element: SpatialElement, transformations: MappingToCoordinateSystem_t | None = None) -> None:
_validate_mapping_to_coordinate_system_type(transformations)
transformations_in_element = _get_transformations(element)
Expand Down Expand Up @@ -1068,6 +1087,13 @@ def _validate_table_annotation_metadata(cls, data: AnnData) -> None:
raise ValueError(f"`{attr[cls.REGION_KEY_KEY]}` not found in `adata.obs`. Please create the column.")
if attr[cls.INSTANCE_KEY] not in data.obs:
raise ValueError(f"`{attr[cls.INSTANCE_KEY]}` not found in `adata.obs`. Please create the column.")

# Skip detailed dtype/value validation for lazy-loaded AnnData
# These checks would trigger data loading, defeating the purpose of lazy loading
# Validation will occur when data is actually computed/accessed
if _is_lazy_anndata(data):
return

instance_col = data.obs[attr[cls.INSTANCE_KEY]]
dtype = instance_col.dtype

Expand Down Expand Up @@ -1137,14 +1163,19 @@ def validate(
if ATTRS_KEY not in data.uns:
return data

# Check if this is a lazy-loaded AnnData (from anndata.experimental.read_lazy)
# Lazy AnnData has xarray-based obs/var, which requires different validation
is_lazy = _is_lazy_anndata(data)

_, region_key, instance_key = get_table_keys(data)
if region_key is not None:
if region_key not in data.obs:
raise ValueError(
f"Region key `{region_key}` not in `adata.obs`. Please create the column and parse "
f"using TableModel.parse(adata)."
)
if not isinstance(data.obs[region_key].dtype, CategoricalDtype):
# Skip dtype validation for lazy tables (would require loading data)
if not is_lazy and not isinstance(data.obs[region_key].dtype, CategoricalDtype):
raise ValueError(
f"`table.obs[{region_key}]` must be of type `categorical`, not `{type(data.obs[region_key])}`."
)
Expand All @@ -1154,7 +1185,8 @@ def validate(
f"Instance key `{instance_key}` not in `adata.obs`. Please create the column and parse"
f" using TableModel.parse(adata)."
)
if data.obs[instance_key].isnull().values.any():
# Skip null check for lazy tables (would require loading data)
if not is_lazy and data.obs[instance_key].isnull().values.any():
raise ValueError("`table.obs[instance_key]` must not contain null values, but it does.")

cls._validate_table_annotation_metadata(data)
Expand Down
Loading
Loading