From f2c41d1fcf83466c3b2fcda061f8934296654f2e Mon Sep 17 00:00:00 2001 From: Tristan Simas Date: Tue, 28 Apr 2026 20:32:56 -0400 Subject: [PATCH 1/8] Support standard raster image formats --- pyproject.toml | 3 ++- src/polystore/disk.py | 15 +++++++++++++-- src/polystore/formats.py | 11 ++++++++++- src/polystore/virtual_workspace.py | 6 +++++- 4 files changed, 30 insertions(+), 5 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 1dd9dfc..08c8602 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,6 +40,7 @@ dependencies = [ "numpy>=1.26.0", "portalocker>=2.8.0", # Cross-platform file locking "metaclass-registry", + "imageio>=2.37.0", "zarr>=2.18.0,<3.0", # Required for ZarrStorageBackend "ome-zarr>=0.11.0", # Required for OME-ZARR HCS compliance ] @@ -197,4 +198,4 @@ ignore = [ ] [tool.ruff.per-file-ignores] -"__init__.py" = ["F401"] # unused imports \ No newline at end of file +"__init__.py" = ["F401"] # unused imports diff --git a/src/polystore/disk.py b/src/polystore/disk.py index 40c33d9..fe82b86 100644 --- a/src/polystore/disk.py +++ b/src/polystore/disk.py @@ -9,6 +9,7 @@ import logging import os import shutil +import importlib from pathlib import Path from typing import Any, Callable, Dict, List, Optional, Set, Union @@ -23,7 +24,7 @@ def optional_import(module_name): try: - return __import__(module_name) + return importlib.import_module(module_name) except ImportError: return None @@ -44,6 +45,7 @@ def optional_import(module_name): cupy = get_cupy() tf = get_tf() tifffile = optional_import("tifffile") +imageio = optional_import("imageio.v3") # Optional arraybridge integration for memory conversion try: @@ -99,6 +101,7 @@ def _register_formats(self): # Complex formats - use custom handlers (FileFormat.TIFF, tifffile, self._tiff_writer, self._tiff_reader), + (FileFormat.RASTER_IMAGE, imageio, self._image_writer, self._image_reader), (FileFormat.TEXT, True, self._text_writer, self._text_reader), (FileFormat.JSON, True, self._json_writer, self._json_reader), (FileFormat.CSV, True, self._csv_writer, self._csv_reader), @@ -164,6 +167,14 @@ def _tiff_reader(self, path): else: return tifffile.imread(str(path)) + def _image_writer(self, path, data, **kwargs): + """Write standard raster images using imageio.""" + imageio.imwrite(path, np.asarray(data)) + + def _image_reader(self, path): + """Read standard raster images using imageio.""" + return imageio.imread(path) + def _text_writer(self, path, data, **kwargs): """Write text data to file. Accepts and ignores extra kwargs for compatibility.""" path.write_text(str(data)) @@ -261,7 +272,7 @@ def load(self, file_path: Union[str, Path], **kwargs) -> Any: ext = disk_path.suffix.lower() if not self.format_registry.is_registered(ext): - raise ValueError(f"No writer registered for extension '{ext}'") + raise ValueError(f"No reader registered for extension '{ext}'") try: reader = self.format_registry.get_reader(ext) diff --git a/src/polystore/formats.py b/src/polystore/formats.py index ddfb9a5..3643361 100644 --- a/src/polystore/formats.py +++ b/src/polystore/formats.py @@ -20,6 +20,7 @@ class FileFormat(Enum): # Image formats TIFF = "tiff" + RASTER_IMAGE = "raster_image" # Data formats CSV = "csv" @@ -44,6 +45,7 @@ def extensions(self): FileFormat.TENSORFLOW: [".tf"], FileFormat.ZARR: [".zarr"], FileFormat.TIFF: [".tif", ".tiff"], + FileFormat.RASTER_IMAGE: [".bmp", ".gif", ".jpeg", ".jpg", ".png"], FileFormat.CSV: [".csv"], FileFormat.JSON: [".json"], FileFormat.TEXT: [".txt"], @@ -51,7 +53,14 @@ def extensions(self): } # Default image extensions -DEFAULT_IMAGE_EXTENSIONS = {".tif", ".tiff", ".TIF", ".TIFF"} +DEFAULT_IMAGE_EXTENSIONS = { + extension + for extensions in ( + FILE_FORMAT_EXTENSIONS[FileFormat.TIFF], + FILE_FORMAT_EXTENSIONS[FileFormat.RASTER_IMAGE], + ) + for extension in extensions +} def get_format_from_extension(ext: str) -> FileFormat: diff --git a/src/polystore/virtual_workspace.py b/src/polystore/virtual_workspace.py index 45081a3..c7bc61b 100644 --- a/src/polystore/virtual_workspace.py +++ b/src/polystore/virtual_workspace.py @@ -210,6 +210,10 @@ def list_files(self, directory: Union[str, Path], pattern: Optional[str] = None, logger.info(f" relative_dir_str='{relative_dir_str}'") logger.info(f" mapping has {len(self._mapping_cache)} entries") + lowercase_extensions = ( + None if extensions is None else {ext.lower() for ext in extensions} + ) + # Filter paths in this directory results = [] for virtual_relative in self._mapping_cache.keys(): @@ -230,7 +234,7 @@ def list_files(self, directory: Union[str, Path], pattern: Optional[str] = None, vpath = Path(virtual_relative) if pattern and not fnmatch(vpath.name, pattern): continue - if extensions and vpath.suffix not in extensions: + if lowercase_extensions and vpath.suffix.lower() not in lowercase_extensions: continue # Return absolute path From 637107ed87707a0577c96f7cf755272a88170d1a Mon Sep 17 00:00:00 2001 From: Tristan Simas Date: Wed, 29 Apr 2026 05:49:00 -0400 Subject: [PATCH 2/8] Make memory extension filtering case-insensitive --- src/polystore/memory.py | 8 +++++++- tests/test_memory_backend.py | 11 +++++++++++ 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/src/polystore/memory.py b/src/polystore/memory.py index a59114f..5f3f6df 100644 --- a/src/polystore/memory.py +++ b/src/polystore/memory.py @@ -139,6 +139,9 @@ def list_files( if self._memory_store[dir_key] is not None: raise NotADirectoryError(f"Path is not a directory: {directory}") + lowercase_extensions = ( + None if extensions is None else {extension.lower() for extension in extensions} + ) result = [] dir_prefix = dir_key + "/" if not dir_key.endswith("/") else dir_key @@ -159,7 +162,10 @@ def list_files( filename = Path(rel_path).name # If pattern is None, match all files if pattern is None or fnmatch(filename, pattern): - if not extensions or Path(filename).suffix in extensions: + if ( + lowercase_extensions is None + or Path(filename).suffix.lower() in lowercase_extensions + ): # Calculate depth for breadth-first sorting depth = rel_path.count('/') result.append((Path(path), depth)) diff --git a/tests/test_memory_backend.py b/tests/test_memory_backend.py index f55996b..ec8a080 100644 --- a/tests/test_memory_backend.py +++ b/tests/test_memory_backend.py @@ -109,6 +109,17 @@ def test_list_files_with_extension_filter(self): npy_files = self.backend.list_files("/test", extensions={".npy"}) assert len(npy_files) == 2 + def test_list_files_extension_filter_is_case_insensitive(self): + """Test extension filtering matches backend contract case-insensitively.""" + self.backend.save(np.array([1]), "/test/image.TIF") + self.backend.save(np.array([2]), "/test/image.tif") + self.backend.save("text", "/test/notes.TXT") + + tif_files = self.backend.list_files("/test", extensions={".tif"}) + + assert len(tif_files) == 2 + assert {path.name for path in tif_files} == {"image.TIF", "image.tif"} + def test_list_files_recursive(self): """Test recursive file listing.""" # Create files in multiple levels From 3365356d20eed6af0ff6ddde9cccd9bf02d595e6 Mon Sep 17 00:00:00 2001 From: Tristan Simas Date: Mon, 4 May 2026 17:03:42 -0400 Subject: [PATCH 3/8] Reduce VFS debug log noise --- src/polystore/base.py | 13 +++++++------ src/polystore/virtual_workspace.py | 20 +++++++++++++------- 2 files changed, 20 insertions(+), 13 deletions(-) diff --git a/src/polystore/base.py b/src/polystore/base.py index 2b033fc..e18849e 100644 --- a/src/polystore/base.py +++ b/src/polystore/base.py @@ -546,15 +546,16 @@ def reset_memory_backend() -> None: # Clear files from existing memory backend while preserving directories memory_backend = storage_registry[Backend.MEMORY.value] - # DEBUG: Log what's in memory before clearing existing_keys = list(memory_backend._memory_store.keys()) - logger.info(f"🔍 VFS_CLEAR: Memory backend has {len(existing_keys)} entries BEFORE clear") - logger.info(f"🔍 VFS_CLEAR: First 10 keys: {existing_keys[:10]}") + logger.debug("Memory backend has %s entries before clear", len(existing_keys)) + logger.debug("First memory backend keys before clear: %s", existing_keys[:10]) memory_backend.clear_files_only() - # DEBUG: Log what's in memory after clearing remaining_keys = list(memory_backend._memory_store.keys()) - logger.info(f"🔍 VFS_CLEAR: Memory backend has {len(remaining_keys)} entries AFTER clear (directories only)") - logger.info(f"🔍 VFS_CLEAR: First 10 remaining keys: {remaining_keys[:10]}") + logger.debug( + "Memory backend has %s entries after clear (directories only)", + len(remaining_keys), + ) + logger.debug("First memory backend keys after clear: %s", remaining_keys[:10]) logger.info("Memory backend reset - files cleared, directories preserved") diff --git a/src/polystore/virtual_workspace.py b/src/polystore/virtual_workspace.py index c7bc61b..bec8be5 100644 --- a/src/polystore/virtual_workspace.py +++ b/src/polystore/virtual_workspace.py @@ -205,10 +205,16 @@ def list_files(self, directory: Union[str, Path], pattern: Optional[str] = None, if self._mapping_cache is None: self._load_mapping() - logger.info(f"VirtualWorkspace.list_files called: directory={directory}, recursive={recursive}, pattern={pattern}, extensions={extensions}") - logger.info(f" plate_root={self.plate_root}") - logger.info(f" relative_dir_str='{relative_dir_str}'") - logger.info(f" mapping has {len(self._mapping_cache)} entries") + logger.debug( + "VirtualWorkspace.list_files directory=%s recursive=%s pattern=%s extensions=%s", + directory, + recursive, + pattern, + extensions, + ) + logger.debug(" plate_root=%s", self.plate_root) + logger.debug(" relative_dir_str=%r", relative_dir_str) + logger.debug(" mapping has %s entries", len(self._mapping_cache)) lowercase_extensions = ( None if extensions is None else {ext.lower() for ext in extensions} @@ -240,14 +246,14 @@ def list_files(self, directory: Union[str, Path], pattern: Optional[str] = None, # Return absolute path results.append(str(self.plate_root / virtual_relative)) - logger.info(f" VirtualWorkspace.list_files returning {len(results)} files") + logger.debug(" VirtualWorkspace.list_files returning %s files", len(results)) if len(results) == 0 and len(self._mapping_cache) > 0: # Log first few mapping keys to help debug sample_keys = list(self._mapping_cache.keys())[:3] - logger.info(f" Sample mapping keys: {sample_keys}") + logger.debug(" Sample mapping keys: %s", sample_keys) if not recursive and relative_dir_str == '': sample_parents = [str(Path(k).parent).replace('\\', '/') for k in sample_keys] - logger.info(f" Sample parent dirs: {sample_parents}") + logger.debug(" Sample parent dirs: %s", sample_parents) logger.info(f" Expected parent to match: '{relative_dir_str}'") return sorted(results) From 000ab4ea425ec2f8367d4f1c3d35a69450848b71 Mon Sep 17 00:00:00 2001 From: Tristan Simas Date: Wed, 20 May 2026 17:00:26 -0400 Subject: [PATCH 4/8] Move napari batching to Qt event loop --- .../napari/napari_batch_processor.py | 40 +++++++------------ 1 file changed, 14 insertions(+), 26 deletions(-) diff --git a/src/polystore/streaming/receivers/napari/napari_batch_processor.py b/src/polystore/streaming/receivers/napari/napari_batch_processor.py index b8dcbdd..ad485af 100644 --- a/src/polystore/streaming/receivers/napari/napari_batch_processor.py +++ b/src/polystore/streaming/receivers/napari/napari_batch_processor.py @@ -1,20 +1,16 @@ import logging from typing import Any, Dict, List, Optional -from polystore.streaming.receivers.core import DebouncedBatchEngine - logger = logging.getLogger(__name__) class NapariBatchProcessor: """ - Batch processor for Napari viewer with configurable batching strategies. - - Accumulates items and displays them based on batch_size configuration: - - None: Wait for all items in operation, then display once - - N: Display every N items incrementally - - Uses debouncing to collect items arriving in rapid succession. + Batch processor for Napari viewer display operations. + + Napari layer mutation must run on the Qt event-loop thread. OpenHCS owns that + Qt-thread debounce before this processor is called, so this class only + adapts batch payloads into the server display operation. """ def __init__( @@ -29,22 +25,15 @@ def __init__( Args: napari_server: Reference to NapariViewerServer for display operations - batch_size: Number of items to batch before displaying - None = wait for all (default), N = display every N items - debounce_delay_ms: Wait time after last item before processing (ms) - max_debounce_wait_ms: Maximum total wait time before forcing display (ms) + batch_size: Reserved for compatibility with viewer configuration + debounce_delay_ms: Qt-thread debounce delay owned by the caller + max_debounce_wait_ms: Reserved for compatibility with viewer configuration """ self.napari_server = napari_server self.batch_size = batch_size self.debounce_delay_ms = debounce_delay_ms self.max_debounce_wait_ms = max_debounce_wait_ms - self._engine = DebouncedBatchEngine( - process_fn=self._process_batch, - debounce_delay_ms=debounce_delay_ms, - max_debounce_wait_ms=max_debounce_wait_ms, - ) - logger.info( f"NapariBatchProcessor: Created with batch_size={batch_size}, " f"debounce={debounce_delay_ms}ms, max_wait={max_debounce_wait_ms}ms" @@ -58,7 +47,7 @@ def add_items( component_names_metadata: Dict[str, Any], ): """ - Add items to the batch for processing. + Display items already released by the Qt-thread debounce. Args: layer_key: Unique identifier for the layer @@ -66,9 +55,9 @@ def add_items( display_config: Display configuration dict component_names_metadata: Component name mappings for dimension labels """ - self._engine.enqueue( - items=items, - context={ + self._process_batch( + items, + { "display_config": display_config, "component_names_metadata": component_names_metadata, "layer_key": layer_key, @@ -81,12 +70,11 @@ def add_items( ) def flush(self) -> None: - """Force immediate processing of the pending batch.""" - self._engine.flush() + """Compatibility no-op; OpenHCS owns the Qt-thread debounce timer.""" def _process_batch(self, items: List[Dict[str, Any]], context: Dict[str, Any]) -> None: """Process callback used by shared debounced batch engine.""" - self.napari_server._display_layer_batch( + self.napari_server.display_layer_batch( layer_key=context["layer_key"], items=items, display_config=context["display_config"], From e5909ba51def2e703bb0a9ccb6360262972416b3 Mon Sep 17 00:00:00 2001 From: Tristan Simas Date: Wed, 20 May 2026 18:22:26 -0400 Subject: [PATCH 5/8] Handle unparsed streaming artifact metadata --- src/polystore/__init__.py | 6 +- src/polystore/disk.py | 3 + src/polystore/memory.py | 3 + src/polystore/streaming/_streaming_backend.py | 34 ++++++-- tests/test_streaming_metadata.py | 78 +++++++++++++++++++ 5 files changed, 115 insertions(+), 9 deletions(-) create mode 100644 tests/test_streaming_metadata.py diff --git a/src/polystore/__init__.py b/src/polystore/__init__.py index 5c38d68..123c449 100644 --- a/src/polystore/__init__.py +++ b/src/polystore/__init__.py @@ -26,10 +26,10 @@ get_backend, ) from .constants import Backend, MemoryType, TransportMode -from .disk import DiskStorageBackend +from .disk import DiskBackend, DiskStorageBackend from .filemanager import FileManager from .formats import FileFormat, DEFAULT_IMAGE_EXTENSIONS -from .memory import MemoryStorageBackend +from .memory import MemoryBackend, MemoryStorageBackend from .metadata_writer import ( AtomicMetadataWriter, MetadataWriteError, @@ -76,7 +76,9 @@ "register_cleanup_callback", "STORAGE_BACKENDS", "DiskStorageBackend", + "DiskBackend", "MemoryStorageBackend", + "MemoryBackend", "FileManager", "file_lock", "atomic_write_json", diff --git a/src/polystore/disk.py b/src/polystore/disk.py index fe82b86..ca24e7c 100644 --- a/src/polystore/disk.py +++ b/src/polystore/disk.py @@ -834,3 +834,6 @@ def _save_rois(self, rois: List, output_path: Path, images_dir: str = None, **kw logger.info(f"Saved {roi_count} ROIs to .roi.zip archive: {output_path}") return str(output_path) + + +DiskBackend = DiskStorageBackend diff --git a/src/polystore/memory.py b/src/polystore/memory.py index 5f3f6df..872d581 100644 --- a/src/polystore/memory.py +++ b/src/polystore/memory.py @@ -657,3 +657,6 @@ def __init__(self, target: str): def __repr__(self): return f"" + + +MemoryBackend = MemoryStorageBackend diff --git a/src/polystore/streaming/_streaming_backend.py b/src/polystore/streaming/_streaming_backend.py index 417baa2..932a2c4 100644 --- a/src/polystore/streaming/_streaming_backend.py +++ b/src/polystore/streaming/_streaming_backend.py @@ -9,8 +9,9 @@ import os import time import uuid +from dataclasses import dataclass from pathlib import Path -from typing import Any, Callable, List, Set, Union +from typing import Any, Callable, List, Mapping, Set, Union import numpy as np from ..base import DataSink @@ -24,6 +25,27 @@ logger = logging.getLogger(__name__) +@dataclass(frozen=True) +class StreamingComponentMetadata: + """Message metadata for one streamed item.""" + + parsed_filename_metadata: Mapping[str, Any] | None + source: str + + def to_payload(self) -> dict[str, Any]: + if self.parsed_filename_metadata is None: + metadata: dict[str, Any] = {} + elif isinstance(self.parsed_filename_metadata, Mapping): + metadata = dict(self.parsed_filename_metadata) + else: + raise TypeError( + "Streaming filename parser must return a mapping or None, " + f"got {type(self.parsed_filename_metadata).__name__}." + ) + metadata["source"] = self.source + return metadata + + class StreamingBackend(DataSink): """ Abstract base class for ZeroMQ-based streaming backends. @@ -165,12 +187,10 @@ def _parse_component_metadata(self, file_path: Union[str, Path], microscope_hand Component metadata dict with source added """ filename = os.path.basename(str(file_path)) - component_metadata = microscope_handler.parser.parse_filename(filename) - - # Add pre-built source value directly - component_metadata['source'] = source - - return component_metadata + return StreamingComponentMetadata( + microscope_handler.parser.parse_filename(filename), + source, + ).to_payload() def _detect_data_type(self, data: Any): """ diff --git a/tests/test_streaming_metadata.py b/tests/test_streaming_metadata.py new file mode 100644 index 0000000..8141f1d --- /dev/null +++ b/tests/test_streaming_metadata.py @@ -0,0 +1,78 @@ +from types import SimpleNamespace + +import pytest + +from polystore.streaming._streaming_backend import StreamingBackend + + +class MetadataProbeStreamingBackend(StreamingBackend): + VIEWER_TYPE = "probe" + SHM_PREFIX = "probe_" + + def save_batch(self, data_list, file_paths, **kwargs): + raise NotImplementedError + + +def test_streaming_component_metadata_accepts_unparsed_artifact_filename() -> None: + backend = MetadataProbeStreamingBackend() + microscope_handler = SimpleNamespace( + parser=SimpleNamespace(parse_filename=lambda _filename: None) + ) + + metadata = backend._parse_component_metadata( + "A01_s001_w1_z001_t001_Nuclei_step3_rois.roi.zip", + microscope_handler, + source="IdentifyPrimaryObjects", + ) + + assert metadata == {"source": "IdentifyPrimaryObjects"} + + +def test_streaming_batch_items_accept_unparsed_artifact_filename() -> None: + backend = MetadataProbeStreamingBackend() + microscope_handler = SimpleNamespace( + parser=SimpleNamespace(parse_filename=lambda _filename: None) + ) + + batch_images, image_ids = backend._prepare_batch_items( + [object()], + ["A01_s001_w1_z001_t001_Nuclei_step3_rois.roi.zip"], + microscope_handler, + "IdentifyPrimaryObjects", + lambda _data, _path, _data_type: ({"payload": "ok"}, "image"), + ) + + assert len(image_ids) == 1 + assert batch_images[0]["metadata"] == {"source": "IdentifyPrimaryObjects"} + assert batch_images[0]["payload"] == "ok" + + +def test_streaming_component_metadata_preserves_parsed_filename_fields() -> None: + backend = MetadataProbeStreamingBackend() + microscope_handler = SimpleNamespace( + parser=SimpleNamespace( + parse_filename=lambda _filename: {"well": "A01", "channel": 1} + ) + ) + + metadata = backend._parse_component_metadata( + "A01_s001_w1_z001_t001.TIF", + microscope_handler, + source="Crop", + ) + + assert metadata == {"well": "A01", "channel": 1, "source": "Crop"} + + +def test_streaming_component_metadata_rejects_invalid_parser_result() -> None: + backend = MetadataProbeStreamingBackend() + microscope_handler = SimpleNamespace( + parser=SimpleNamespace(parse_filename=lambda _filename: ["not", "metadata"]) + ) + + with pytest.raises(TypeError, match="mapping or None"): + backend._parse_component_metadata( + "A01_s001_w1_z001_t001.TIF", + microscope_handler, + source="Crop", + ) From 28f670c623e8ec1f53cd93fc10e17096c47f4a06 Mon Sep 17 00:00:00 2001 From: Tristan Simas Date: Thu, 21 May 2026 00:05:51 -0400 Subject: [PATCH 6/8] Improve ROI streaming metadata handling --- pyproject.toml | 1 + src/polystore/fiji_stream.py | 5 +- src/polystore/napari_stream.py | 8 +- src/polystore/roi.py | 142 +++++++--- src/polystore/roi_converters.py | 247 ++++++++++++------ src/polystore/streaming/_streaming_backend.py | 180 ++++++------- .../streaming/receivers/napari/layer_key.py | 15 +- src/polystore/streaming_constants.py | 15 ++ tests/test_roi.py | 79 ++++++ tests/test_streaming_metadata.py | 62 +++-- 10 files changed, 501 insertions(+), 253 deletions(-) create mode 100644 tests/test_roi.py diff --git a/pyproject.toml b/pyproject.toml index 08c8602..5d6f7da 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,6 +37,7 @@ classifiers = [ ] dependencies = [ + "arraybridge>=0.2.9", "numpy>=1.26.0", "portalocker>=2.8.0", # Cross-platform file locking "metaclass-registry", diff --git a/src/polystore/fiji_stream.py b/src/polystore/fiji_stream.py index 4d52817..08132bc 100644 --- a/src/polystore/fiji_stream.py +++ b/src/polystore/fiji_stream.py @@ -31,12 +31,9 @@ class FijiStreamingBackend(StreamingBackend): """Fiji streaming backend with ZMQ publisher pattern (matches Napari architecture).""" _backend_type = Backend.FIJI_STREAM.value - # Configure ABC attributes VIEWER_TYPE = 'fiji' SHM_PREFIX = 'fiji_' - # __init__, _get_publisher, save, cleanup now inherited from ABC - def _prepare_rois_data(self, data: Any, file_path: Union[str, Path]) -> dict: """ Prepare ROIs data for transmission. @@ -90,6 +87,7 @@ def save_batch(self, data_list: List[Any], file_paths: List[Union[str, Path]], * source = kwargs.get('source', 'unknown_source') # Pre-built source value images_dir = kwargs.get('images_dir') # Source image subdirectory for ROI mapping plate_path = kwargs.get('plate_path') + component_metadata = kwargs.get('component_metadata') logger.info(f"🏷️ FIJI BACKEND: plate_path = {plate_path}") logger.info(f"🏷️ FIJI BACKEND: microscope_handler = {microscope_handler}") display_payload_extra = { @@ -108,6 +106,7 @@ def save_batch(self, data_list: List[Any], file_paths: List[Union[str, Path]], * display_config, self._prepare_batch_item, plate_path=plate_path, + component_metadata=component_metadata, component_names_kwargs={"log_prefix": "🏷️ FIJI BACKEND", "verbose": True}, display_payload_extra=display_payload_extra, message_extra=message_extra, diff --git a/src/polystore/napari_stream.py b/src/polystore/napari_stream.py index 630bcc8..d762cd6 100644 --- a/src/polystore/napari_stream.py +++ b/src/polystore/napari_stream.py @@ -20,7 +20,6 @@ import zmq from .constants import Backend, TransportMode -from .streaming_constants import StreamingDataType from .streaming import StreamingBackend from .roi_converters import NapariROIConverter from zmqruntime.transport import get_zmq_transport_url, coerce_transport_mode @@ -32,12 +31,9 @@ class NapariStreamingBackend(StreamingBackend): """Napari streaming backend with automatic registration.""" _backend_type = Backend.NAPARI_STREAM.value - # Configure ABC attributes VIEWER_TYPE = 'napari' SHM_PREFIX = 'napari_' - # __init__, _get_publisher, save, cleanup now inherited from ABC - def _prepare_shapes_data(self, data: Any, file_path: Union[str, Path]) -> dict: """ Prepare shapes data for transmission. @@ -57,7 +53,7 @@ def _prepare_shapes_data(self, data: Any, file_path: Union[str, Path]) -> dict: } def _prepare_batch_item(self, data: Any, file_path: Union[str, Path], data_type): - if data_type in (StreamingDataType.SHAPES, StreamingDataType.POINTS): + if data_type.uses_napari_vector_payload: item_data = self._prepare_shapes_data(data, file_path) data_type_value = data_type.value else: @@ -88,6 +84,7 @@ def save_batch(self, data_list: List[Any], file_paths: List[Union[str, Path]], * microscope_handler = kwargs['microscope_handler'] source = kwargs.get('source', 'unknown_source') # Pre-built source value plate_path = kwargs.get('plate_path') + component_metadata = kwargs.get('component_metadata') display_payload_extra = { "colormap": display_config.get_colormap_name(), "variable_size_handling": display_config.variable_size_handling.value @@ -103,6 +100,7 @@ def save_batch(self, data_list: List[Any], file_paths: List[Union[str, Path]], * display_config, self._prepare_batch_item, plate_path=plate_path, + component_metadata=component_metadata, display_payload_extra=display_payload_extra, ) diff --git a/src/polystore/roi.py b/src/polystore/roi.py index fb6bdb6..d841591 100644 --- a/src/polystore/roi.py +++ b/src/polystore/roi.py @@ -6,12 +6,14 @@ """ import logging +from abc import ABC, abstractmethod from dataclasses import dataclass, field from enum import Enum from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, ClassVar, Dict, List, Optional, Tuple, Union import numpy as np +from metaclass_registry import AutoRegisterMeta from .constants import Backend @@ -27,8 +29,14 @@ class ShapeType(Enum): ELLIPSE = "ellipse" +class ROIShape(ABC): + """Nominal base for all ROI shape records.""" + + shape_type: ShapeType + + @dataclass(frozen=True) -class PolygonShape: +class PolygonShape(ROIShape): """Polygon ROI shape defined by vertex coordinates.""" coordinates: np.ndarray # Nx2 array of (y, x) coordinates shape_type: ShapeType = field(default=ShapeType.POLYGON, init=False) @@ -41,7 +49,7 @@ def __post_init__(self): @dataclass(frozen=True) -class PolylineShape: +class PolylineShape(ROIShape): """Polyline ROI shape defined by path coordinates (open path, not closed polygon).""" coordinates: np.ndarray # Nx2 array of (y, x) coordinates shape_type: ShapeType = field(default=ShapeType.POLYLINE, init=False) @@ -54,7 +62,7 @@ def __post_init__(self): @dataclass(frozen=True) -class MaskShape: +class MaskShape(ROIShape): """Binary mask ROI shape.""" mask: np.ndarray # 2D boolean array bbox: Tuple[int, int, int, int] # (min_y, min_x, max_y, max_x) @@ -68,7 +76,7 @@ def __post_init__(self): @dataclass(frozen=True) -class PointShape: +class PointShape(ROIShape): """Point ROI shape.""" y: float x: float @@ -76,7 +84,7 @@ class PointShape: @dataclass(frozen=True) -class EllipseShape: +class EllipseShape(ROIShape): """Ellipse ROI shape.""" center_y: float center_x: float @@ -95,14 +103,82 @@ def __post_init__(self): if not self.shapes: raise ValueError("ROI must have at least one shape") for shape in self.shapes: - if not hasattr(shape, "shape_type"): - raise ValueError(f"Shape {shape} must have shape_type attribute") + if not isinstance(shape, ROIShape): + raise ValueError(f"Shape {shape} must be an ROIShape") + + +class ROIJsonShapeDecoder(ABC, metaclass=AutoRegisterMeta): + """Decode one serialized ROI shape variant.""" + + __registry_key__ = "shape_type" + __skip_if_no_key__ = True + + shape_type: ClassVar[ShapeType | None] = None + + @classmethod + def for_serialized_shape(cls, shape_dict: Dict[str, Any]) -> "ROIJsonShapeDecoder | None": + shape_type = shape_dict.get("type") + try: + shape_key = ShapeType(shape_type) + except ValueError: + logger.warning(f"Unknown shape type: {shape_type}, skipping") + return None + return cls.__registry__[shape_key]() + + @abstractmethod + def decode(self, shape_dict: Dict[str, Any]) -> Any: + """Return the concrete ROI shape represented by ``shape_dict``.""" + + +class PolygonROIJsonShapeDecoder(ROIJsonShapeDecoder): + shape_type = ShapeType.POLYGON + + def decode(self, shape_dict: Dict[str, Any]) -> PolygonShape: + return PolygonShape(coordinates=np.array(shape_dict["coordinates"])) + + +class PolylineROIJsonShapeDecoder(ROIJsonShapeDecoder): + shape_type = ShapeType.POLYLINE + + def decode(self, shape_dict: Dict[str, Any]) -> PolylineShape: + return PolylineShape(coordinates=np.array(shape_dict["coordinates"])) + + +class MaskROIJsonShapeDecoder(ROIJsonShapeDecoder): + shape_type = ShapeType.MASK + + def decode(self, shape_dict: Dict[str, Any]) -> MaskShape: + return MaskShape( + mask=np.array(shape_dict["mask"], dtype=bool), + bbox=tuple(shape_dict["bbox"]), + ) + + +class PointROIJsonShapeDecoder(ROIJsonShapeDecoder): + shape_type = ShapeType.POINT + + def decode(self, shape_dict: Dict[str, Any]) -> PointShape: + return PointShape(y=shape_dict["y"], x=shape_dict["x"]) + + +class EllipseROIJsonShapeDecoder(ROIJsonShapeDecoder): + shape_type = ShapeType.ELLIPSE + + def decode(self, shape_dict: Dict[str, Any]) -> EllipseShape: + return EllipseShape( + center_y=shape_dict["center_y"], + center_x=shape_dict["center_x"], + radius_y=shape_dict["radius_y"], + radius_x=shape_dict["radius_x"], + ) def extract_rois_from_labeled_mask( labeled_mask: np.ndarray, min_area: int = 10, extract_contours: bool = True, + spatial_origin_yx: Optional[Tuple[int, int]] = None, + source_spatial_shape_yx: Optional[Tuple[int, int]] = None, ) -> List[ROI]: """Extract ROIs from a labeled segmentation mask.""" from skimage import measure @@ -117,19 +193,33 @@ def extract_rois_from_labeled_mask( regions = regionprops(labeled_mask) slices = find_objects(labeled_mask) + origin_y, origin_x = spatial_origin_yx or (0, 0) rois = [] for region in regions: if region.area < min_area: continue + min_y, min_x, max_y, max_x = region.bbox metadata = { "label": int(region.label), "area": float(region.area), "perimeter": float(region.perimeter), - "centroid": tuple(float(c) for c in region.centroid), - "bbox": tuple(int(b) for b in region.bbox), + "centroid": ( + float(region.centroid[0] + origin_y), + float(region.centroid[1] + origin_x), + ), + "bbox": ( + int(min_y + origin_y), + int(min_x + origin_x), + int(max_y + origin_y), + int(max_x + origin_x), + ), } + if source_spatial_shape_yx is not None: + metadata["source_spatial_shape_yx"] = tuple( + int(value) for value in source_spatial_shape_yx + ) shapes = [] if extract_contours: @@ -142,14 +232,14 @@ def extract_rois_from_labeled_mask( contours = measure.find_contours(padded_mask, level=0.5) offset_y = slice_y.start offset_x = slice_x.start - padding_offset = np.array([offset_y, offset_x]) - 1 + padding_offset = np.array([offset_y + origin_y, offset_x + origin_x]) - 1 for contour in contours: if len(contour) >= 3: contour_full = contour + padding_offset shapes.append(PolygonShape(coordinates=contour_full)) else: binary_mask = (labeled_mask == region.label) - shapes.append(MaskShape(mask=binary_mask, bbox=region.bbox)) + shapes.append(MaskShape(mask=binary_mask, bbox=metadata["bbox"])) if shapes: rois.append(ROI(shapes=shapes, metadata=metadata)) @@ -203,31 +293,9 @@ def load_rois_from_json(json_path: Path) -> List[ROI]: metadata = roi_dict.get("metadata", {}) shapes = [] for shape_dict in roi_dict.get("shapes", []): - shape_type = shape_dict.get("type") - - if shape_type == "polygon": - coordinates = np.array(shape_dict["coordinates"]) - shapes.append(PolygonShape(coordinates=coordinates)) - elif shape_type == "polyline": - coordinates = np.array(shape_dict["coordinates"]) - shapes.append(PolylineShape(coordinates=coordinates)) - elif shape_type == "mask": - mask = np.array(shape_dict["mask"], dtype=bool) - bbox = tuple(shape_dict["bbox"]) - shapes.append(MaskShape(mask=mask, bbox=bbox)) - elif shape_type == "point": - shapes.append(PointShape(y=shape_dict["y"], x=shape_dict["x"])) - elif shape_type == "ellipse": - shapes.append( - EllipseShape( - center_y=shape_dict["center_y"], - center_x=shape_dict["center_x"], - radius_y=shape_dict["radius_y"], - radius_x=shape_dict["radius_x"], - ) - ) - else: - logger.warning(f"Unknown shape type: {shape_type}, skipping") + decoder = ROIJsonShapeDecoder.for_serialized_shape(shape_dict) + if decoder is not None: + shapes.append(decoder.decode(shape_dict)) if shapes: rois.append(ROI(shapes=shapes, metadata=metadata)) diff --git a/src/polystore/roi_converters.py b/src/polystore/roi_converters.py index 46e8631..616e4c4 100644 --- a/src/polystore/roi_converters.py +++ b/src/polystore/roi_converters.py @@ -7,63 +7,184 @@ """ import logging -from typing import Any, Dict, List, Tuple +from abc import ABC, abstractmethod +from dataclasses import dataclass +from enum import Enum +from typing import Any, ClassVar, Dict, List, Tuple import numpy as np +from metaclass_registry import AutoRegisterMeta -from .roi import EllipseShape, PointShape, PolygonShape, PolylineShape, ROI -from .streaming_constants import NapariShapeType +from .roi import EllipseShape, PointShape, PolygonShape, PolylineShape, ROI, ShapeType logger = logging.getLogger(__name__) -class NapariROIConverter: - """Convert ROI objects to Napari shapes format.""" +@dataclass(frozen=True, slots=True) +class NapariShapeTypeAlias: + """Inert alias from Napari wire shape names to ROI shape types.""" + + alias: str + shape_type: ShapeType + + +NAPARI_SHAPE_TYPE_ALIASES = ( + NapariShapeTypeAlias("path", ShapeType.POLYLINE), + NapariShapeTypeAlias("points", ShapeType.POINT), +) + + +class NapariShapeConverter(ABC, metaclass=AutoRegisterMeta): + """Registered conversion behavior for one ROI shape type.""" + + __registry_key__ = "shape_type" + __skip_if_no_key__ = True + + shape_type: ClassVar[ShapeType | None] = None + + @classmethod + def for_shape_dict(cls, shape_dict: Dict[str, Any]) -> "NapariShapeConverter": + return cls.__registry__[_shape_type_from_napari(shape_dict["type"])]() + + def append_common_properties( + self, + metadata: Dict[str, Any], + properties: dict[str, list[Any]], + centroid: tuple[Any, Any], + *, + area: Any | None = None, + ) -> None: + properties["label"].append(metadata.get("label", "")) + properties["area"].append(metadata.get("area", 0) if area is None else area) + properties["centroid_y"].append(centroid[0]) + properties["centroid_x"].append(centroid[1]) + + @abstractmethod + def add_dimensions(self, shape_dict: Dict[str, Any], prepend_dims: np.ndarray) -> np.ndarray: + """Add dimensions to a 2D shape to make it nD.""" + + @abstractmethod + def append_napari_format( + self, + shape_dict: Dict[str, Any], + napari_shapes: list[np.ndarray], + shape_types: list[str], + properties: dict[str, list[Any]], + ) -> None: + """Append this shape to a Napari layer payload.""" + + +def _shape_type_from_napari(shape_type: object) -> ShapeType: + if isinstance(shape_type, ShapeType): + return shape_type + value = str(shape_type.value) if isinstance(shape_type, Enum) else str(shape_type) + for alias in NAPARI_SHAPE_TYPE_ALIASES: + if alias.alias == value: + return alias.shape_type + return ShapeType(value) + + +class CoordinateNapariShapeConverter(NapariShapeConverter): + """Shared converter for coordinate-list shapes.""" - _SHAPE_DIMENSION_HANDLERS = { - "polygon": lambda shape_dict, prepend_dims: np.hstack( - [np.tile(prepend_dims, (len(shape_dict["coordinates"]), 1)), np.array(shape_dict["coordinates"])] - ), - "polyline": lambda shape_dict, prepend_dims: np.hstack( - [np.tile(prepend_dims, (len(shape_dict["coordinates"]), 1)), np.array(shape_dict["coordinates"])] - ), - "ellipse": lambda shape_dict, prepend_dims: np.hstack( + napari_shape_type: ClassVar[str] + + def add_dimensions(self, shape_dict: Dict[str, Any], prepend_dims: np.ndarray) -> np.ndarray: + coordinates = np.array(shape_dict["coordinates"]) + return np.hstack([np.tile(prepend_dims, (len(coordinates), 1)), coordinates]) + + def append_napari_format( + self, + shape_dict: Dict[str, Any], + napari_shapes: list[np.ndarray], + shape_types: list[str], + properties: dict[str, list[Any]], + ) -> None: + metadata = shape_dict.get("metadata", {}) + napari_shapes.append(np.array(shape_dict["coordinates"])) + shape_types.append(self.napari_shape_type) + self.append_common_properties( + metadata, + properties, + metadata.get("centroid", (0, 0)), + ) + + +class PolygonNapariShapeConverter(CoordinateNapariShapeConverter): + shape_type = ShapeType.POLYGON + napari_shape_type = "polygon" + + +class PolylineNapariShapeConverter(CoordinateNapariShapeConverter): + shape_type = ShapeType.POLYLINE + napari_shape_type = "path" + + +class EllipseNapariShapeConverter(NapariShapeConverter): + shape_type = ShapeType.ELLIPSE + + def add_dimensions(self, shape_dict: Dict[str, Any], prepend_dims: np.ndarray) -> np.ndarray: + center = shape_dict["center"] + radii = shape_dict["radii"] + corners = np.array( [ - np.tile(prepend_dims, (4, 1)), - np.array( - [ - [ - shape_dict["center"][0] - shape_dict["radii"][0], - shape_dict["center"][1] - shape_dict["radii"][1], - ], - [ - shape_dict["center"][0] - shape_dict["radii"][0], - shape_dict["center"][1] + shape_dict["radii"][1], - ], - [ - shape_dict["center"][0] + shape_dict["radii"][0], - shape_dict["center"][1] + shape_dict["radii"][1], - ], - [ - shape_dict["center"][0] + shape_dict["radii"][0], - shape_dict["center"][1] - shape_dict["radii"][1], - ], - ] - ), + [center[0] - radii[0], center[1] - radii[1]], + [center[0] - radii[0], center[1] + radii[1]], + [center[0] + radii[0], center[1] + radii[1]], + [center[0] + radii[0], center[1] - radii[1]], ] - ), - "point": lambda shape_dict, prepend_dims: np.concatenate([prepend_dims, shape_dict["coordinates"]]).reshape(1, -1), - } + ) + return np.hstack([np.tile(prepend_dims, (4, 1)), corners]) + + def append_napari_format( + self, + shape_dict: Dict[str, Any], + napari_shapes: list[np.ndarray], + shape_types: list[str], + properties: dict[str, list[Any]], + ) -> None: + metadata = shape_dict.get("metadata", {}) + center = np.array(shape_dict["center"]) + radii = np.array(shape_dict["radii"]) + napari_shapes.append(np.array([center - radii, center + radii])) + shape_types.append("ellipse") + self.append_common_properties( + metadata, + properties, + metadata.get("centroid", (0, 0)), + ) + + +class PointNapariShapeConverter(NapariShapeConverter): + shape_type = ShapeType.POINT + + def add_dimensions(self, shape_dict: Dict[str, Any], prepend_dims: np.ndarray) -> np.ndarray: + return np.concatenate([prepend_dims, shape_dict["coordinates"]]).reshape(1, -1) + + def append_napari_format( + self, + shape_dict: Dict[str, Any], + napari_shapes: list[np.ndarray], + shape_types: list[str], + properties: dict[str, list[Any]], + ) -> None: + metadata = shape_dict.get("metadata", {}) + coordinates = shape_dict["coordinates"] + napari_shapes.append(np.array([coordinates])) + shape_types.append("point") + self.append_common_properties(metadata, properties, coordinates, area=0) + + +class NapariROIConverter: + """Convert ROI objects to Napari shapes format.""" @staticmethod def add_dimensions_to_shape(shape_dict: Dict[str, Any], prepend_dims: List[float]) -> np.ndarray: """Add dimensions to a 2D shape to make it nD.""" - shape_type = shape_dict["type"] - shape_type_enum = NapariShapeType(shape_type) if isinstance(shape_type, str) else shape_type - handler = NapariROIConverter._SHAPE_DIMENSION_HANDLERS.get(shape_type_enum.value) - if handler is None: - raise ValueError(f"Unsupported shape type: {shape_type}") - return handler(shape_dict, np.array(prepend_dims)) + return NapariShapeConverter.for_shape_dict(shape_dict).add_dimensions( + shape_dict, + np.array(prepend_dims), + ) @staticmethod def rois_to_shapes(rois: List[ROI]) -> List[Dict[str, Any]]: @@ -104,40 +225,12 @@ def shapes_to_napari_format(shapes_data: List[Dict]) -> Tuple[List[np.ndarray], properties = {"label": [], "area": [], "centroid_y": [], "centroid_x": []} for shape_dict in shapes_data: - shape_type = shape_dict.get("type") - metadata = shape_dict.get("metadata", {}) - - if shape_type == "polygon": - coords = np.array(shape_dict["coordinates"]) - napari_shapes.append(coords) - shape_types.append("polygon") - centroid = metadata.get("centroid", (0, 0)) - properties["label"].append(metadata.get("label", "")) - properties["area"].append(metadata.get("area", 0)) - properties["centroid_y"].append(centroid[0]) - properties["centroid_x"].append(centroid[1]) - - elif shape_type == "ellipse": - center = np.array(shape_dict["center"]) - radii = np.array(shape_dict["radii"]) - corners = np.array([center - radii, center + radii]) - napari_shapes.append(corners) - shape_types.append("ellipse") - centroid = metadata.get("centroid", (0, 0)) - properties["label"].append(metadata.get("label", "")) - properties["area"].append(metadata.get("area", 0)) - properties["centroid_y"].append(centroid[0]) - properties["centroid_x"].append(centroid[1]) - - elif shape_type == "point": - coords = np.array([shape_dict["coordinates"]]) - napari_shapes.append(coords) - shape_types.append("point") - point_coords = shape_dict["coordinates"] - properties["label"].append(metadata.get("label", "")) - properties["area"].append(0) - properties["centroid_y"].append(point_coords[0]) - properties["centroid_x"].append(point_coords[1]) + NapariShapeConverter.for_shape_dict(shape_dict).append_napari_format( + shape_dict, + napari_shapes, + shape_types, + properties, + ) return napari_shapes, shape_types, properties diff --git a/src/polystore/streaming/_streaming_backend.py b/src/polystore/streaming/_streaming_backend.py index 932a2c4..1f3a70a 100644 --- a/src/polystore/streaming/_streaming_backend.py +++ b/src/polystore/streaming/_streaming_backend.py @@ -13,6 +13,8 @@ from pathlib import Path from typing import Any, Callable, List, Mapping, Set, Union import numpy as np +from arraybridge import convert_memory, detect_memory_type +from arraybridge.types import MemoryType as ArrayBridgeMemoryType from ..base import DataSink from ..constants import TransportMode @@ -20,32 +22,62 @@ from ..roi import ROI, PointShape from ..zmq_config import POLYSTORE_ZMQ_CONFIG from zmqruntime.ack_listener import GlobalAckListener -from zmqruntime.transport import coerce_transport_mode, get_zmq_transport_url +from zmqruntime.transport import coerce_transport_mode logger = logging.getLogger(__name__) +PrepareStreamingItem = Callable[[Any, Union[str, Path], Any], tuple[dict, str]] + + @dataclass(frozen=True) class StreamingComponentMetadata: """Message metadata for one streamed item.""" - parsed_filename_metadata: Mapping[str, Any] | None + parsed_filename_metadata: Mapping[str, Any] source: str def to_payload(self) -> dict[str, Any]: - if self.parsed_filename_metadata is None: - metadata: dict[str, Any] = {} - elif isinstance(self.parsed_filename_metadata, Mapping): + if isinstance(self.parsed_filename_metadata, Mapping): metadata = dict(self.parsed_filename_metadata) else: raise TypeError( - "Streaming filename parser must return a mapping or None, " + "Streaming component metadata must be a mapping, " f"got {type(self.parsed_filename_metadata).__name__}." ) metadata["source"] = self.source return metadata +@dataclass(frozen=True) +class StreamingBatchRequest: + """Shared provenance for one streaming batch.""" + + data_list: List[Any] + file_paths: List[Union[str, Path]] + microscope_handler: Any + source: str + prepare_item: PrepareStreamingItem + component_metadata: Mapping[str, Any] | None = None + + +class StreamingPayloadMemoryAuthority: + """Memory conversion authority for streamable image payloads.""" + + @staticmethod + def to_numpy(data: Any) -> np.ndarray: + if isinstance(data, np.ndarray): + return data + if isinstance(data, (list, tuple)): + return np.asarray(data) + return convert_memory( + data, + detect_memory_type(data), + ArrayBridgeMemoryType.NUMPY.value, + gpu_id=0, + ) + + class StreamingBackend(DataSink): """ Abstract base class for ZeroMQ-based streaming backends. @@ -126,55 +158,13 @@ def __init__(self, transport_config=None): self._shared_memory_blocks = {} self._transport_config = transport_config or POLYSTORE_ZMQ_CONFIG - def _get_publisher(self, host: str, port: int, transport_mode: TransportMode, transport_config=None): - """ - Lazy initialization of ZeroMQ publisher (common for all streaming backends). - - Uses REQ socket for Fiji (synchronous request/reply with blocking) - and PUB socket for Napari (broadcast pattern). - - Args: - host: Host to connect to (ignored for IPC mode) - port: Port to connect to - transport_mode: IPC or TCP transport (required - comes from config) - - Returns: - ZeroMQ publisher socket - """ - # Generate transport URL using centralized function - transport_config = transport_config or self._transport_config - url = get_zmq_transport_url( - port, - host=host, - mode=coerce_transport_mode(transport_mode), - config=transport_config, - ) - - key = url # Use URL as key instead of host:port - if key not in self._publishers: - try: - import zmq - if self._context is None: - self._context = zmq.Context() - - # Use REQ socket for all viewers (synchronous request/reply) - # All viewers must send acknowledgment after processing - publisher = self._context.socket(zmq.REQ) - - publisher.connect(url) - socket_name = "REQ" - logger.info(f"{self.VIEWER_TYPE} streaming {socket_name} socket connected to {url}") - time.sleep(0.1) - self._publishers[key] = publisher - - except ImportError: - logger.error("ZeroMQ not available - streaming disabled") - raise RuntimeError("ZeroMQ required for streaming") - - return self._publishers[key] - - def _parse_component_metadata(self, file_path: Union[str, Path], microscope_handler, - source: str) -> dict: + def _parse_component_metadata( + self, + file_path: Union[str, Path], + microscope_handler, + source: str, + component_metadata: Mapping[str, Any] | None = None, + ) -> dict: """ Parse component metadata from filename (common for all streaming backends). @@ -187,10 +177,17 @@ def _parse_component_metadata(self, file_path: Union[str, Path], microscope_hand Component metadata dict with source added """ filename = os.path.basename(str(file_path)) - return StreamingComponentMetadata( - microscope_handler.parser.parse_filename(filename), - source, - ).to_payload() + parsed_metadata = ( + component_metadata + if component_metadata is not None + else microscope_handler.parser.parse_filename(filename) + ) + if parsed_metadata is None: + raise ValueError( + "Streaming component metadata requires explicit component_metadata " + f"or a parser-readable filename; got {filename!r}." + ) + return StreamingComponentMetadata(parsed_metadata, source).to_payload() def _detect_data_type(self, data: Any): """ @@ -226,9 +223,7 @@ def _create_shared_memory(self, data: Any, file_path: Union[str, Path]) -> dict: Returns: Dict with shared memory metadata """ - # Convert to numpy - np_data = data.cpu().numpy() if hasattr(data, 'cpu') else \ - data.get() if hasattr(data, 'get') else np.asarray(data) + np_data = StreamingPayloadMemoryAuthority.to_numpy(data) # Create shared memory with hash-based naming to avoid "File name too long" errors # Hash the timestamp and object ID to create a short, unique name @@ -289,13 +284,7 @@ def _register_with_queue_tracker( tracker.register_sent(image_id) def _build_component_modes(self, display_config) -> dict: - component_modes = {} - for comp_name in display_config.COMPONENT_ORDER: - mode_field = f"{comp_name}_mode" - if hasattr(display_config, mode_field): - mode = getattr(display_config, mode_field) - component_modes[comp_name] = mode.value - return component_modes + return display_config.component_modes() def _build_display_config_base(self, display_config, component_modes: dict) -> dict: return { @@ -324,20 +313,14 @@ def _collect_component_names_metadata( try: for comp_name in component_names: - method_name = f"get_{comp_name}_values" - method = getattr(microscope_handler.metadata_handler, method_name, None) - if callable(method): - try: - metadata = method(plate_path) - if verbose and log_prefix: - logger.info(f"{log_prefix}: Got {comp_name} metadata: {metadata}") - if metadata: - component_names_metadata[comp_name] = metadata - except Exception as e: - if verbose and log_prefix: - logger.warning(f"{log_prefix}: Could not get {comp_name} metadata: {e}", exc_info=True) - elif verbose and log_prefix: - logger.info(f"{log_prefix}: No method {method_name} on metadata_handler") + metadata = microscope_handler.metadata_handler.get_component_values( + plate_path, + comp_name, + ) + if verbose and log_prefix: + logger.info(f"{log_prefix}: Got {comp_name} metadata: {metadata}") + if metadata: + component_names_metadata[comp_name] = metadata except Exception as e: if verbose and log_prefix: logger.warning(f"{log_prefix}: Could not get component metadata: {e}", exc_info=True) @@ -346,24 +329,23 @@ def _collect_component_names_metadata( def _prepare_batch_items( self, - data_list: List[Any], - file_paths: List[Union[str, Path]], - microscope_handler, - source: str, - prepare_item: Callable[[Any, Union[str, Path], Any], tuple[dict, str]], + request: StreamingBatchRequest, ) -> tuple[list[dict], list[str]]: batch_images = [] image_ids = [] - for data, file_path in zip(data_list, file_paths): + for data, file_path in zip(request.data_list, request.file_paths): image_id = str(uuid.uuid4()) image_ids.append(image_id) data_type = self._detect_data_type(data) component_metadata = self._parse_component_metadata( - file_path, microscope_handler, source + file_path, + request.microscope_handler, + request.source, + request.component_metadata, ) - item_data, data_type_value = prepare_item(data, file_path, data_type) + item_data, data_type_value = request.prepare_item(data, file_path, data_type) batch_images.append( { @@ -383,9 +365,10 @@ def _build_batch_message( microscope_handler, source: str, display_config, - prepare_item: Callable[[Any, Union[str, Path], Any], tuple[dict, str]], + prepare_item: PrepareStreamingItem, plate_path: Union[str, Path, None] = None, component_names_kwargs: dict | None = None, + component_metadata: Mapping[str, Any] | None = None, display_payload_extra: dict | None = None, message_extra: dict | None = None, ) -> tuple[dict, list[dict], list[str]]: @@ -393,11 +376,14 @@ def _build_batch_message( raise ValueError("data_list and file_paths must have the same length") batch_images, image_ids = self._prepare_batch_items( - data_list, - file_paths, - microscope_handler, - source, - prepare_item, + StreamingBatchRequest( + data_list=data_list, + file_paths=file_paths, + microscope_handler=microscope_handler, + source=source, + prepare_item=prepare_item, + component_metadata=component_metadata, + ) ) component_modes = self._build_component_modes(display_config) diff --git a/src/polystore/streaming/receivers/napari/layer_key.py b/src/polystore/streaming/receivers/napari/layer_key.py index dec6fff..51b7d67 100644 --- a/src/polystore/streaming/receivers/napari/layer_key.py +++ b/src/polystore/streaming/receivers/napari/layer_key.py @@ -14,13 +14,7 @@ def normalize_component_layout(display_config: Any) -> tuple[dict[str, str], lis component_order = display_config["component_order"] return component_modes, component_order - component_order = list(display_config.COMPONENT_ORDER) - component_modes: dict[str, str] = {} - for component in component_order: - mode_field = f"{component}_mode" - mode_value = display_config.__getattribute__(mode_field) - component_modes[component] = mode_value.value - return component_modes, component_order + return display_config.component_modes(), list(display_config.COMPONENT_ORDER) def build_layer_key( @@ -38,9 +32,4 @@ def build_layer_key( layer_key = "_".join(layer_key_parts) if layer_key_parts else "default_layer" - if data_type == StreamingDataType.SHAPES: - return f"{layer_key}_shapes" - if data_type == StreamingDataType.POINTS: - return f"{layer_key}_points" - return layer_key - + return f"{layer_key}{data_type.napari_layer_suffix}" diff --git a/src/polystore/streaming_constants.py b/src/polystore/streaming_constants.py index d7f0596..05c834c 100644 --- a/src/polystore/streaming_constants.py +++ b/src/polystore/streaming_constants.py @@ -15,6 +15,21 @@ class StreamingDataType(Enum): POINTS = "points" # Napari points layer (e.g., skeleton tracings) ROIS = "rois" # Fiji ROI payloads + @property + def uses_napari_vector_payload(self) -> bool: + """Whether napari should receive this type through vector layer payloads.""" + return self in (type(self).SHAPES, type(self).POINTS) + + @property + def napari_layer_suffix(self) -> str: + """Layer-key suffix contributed by this data type.""" + return { + type(self).IMAGE: "", + type(self).SHAPES: "_shapes", + type(self).POINTS: "_points", + type(self).ROIS: "", + }[self] + class NapariShapeType(Enum): """Napari shape types for ROI visualization.""" diff --git a/tests/test_roi.py b/tests/test_roi.py new file mode 100644 index 0000000..565022f --- /dev/null +++ b/tests/test_roi.py @@ -0,0 +1,79 @@ +import numpy as np + +from polystore.roi import MaskShape +from polystore.roi import PolygonShape +from polystore.roi import load_rois_from_json +from polystore.roi import extract_rois_from_labeled_mask + + +def test_extract_rois_from_labeled_mask_applies_spatial_origin_to_polygons(): + labels = np.zeros((8, 8), dtype=np.int32) + labels[2:6, 3:7] = 1 + + rois = extract_rois_from_labeled_mask( + labels, + min_area=0, + extract_contours=True, + spatial_origin_yx=(10, 20), + ) + + assert len(rois) == 1 + assert rois[0].metadata["bbox"] == (12, 23, 16, 27) + assert rois[0].metadata["centroid"] == (13.5, 24.5) + assert isinstance(rois[0].shapes[0], PolygonShape) + assert float(rois[0].shapes[0].coordinates[:, 0].min()) >= 11.5 + assert float(rois[0].shapes[0].coordinates[:, 1].min()) >= 22.5 + + +def test_extract_rois_from_labeled_mask_applies_spatial_origin_to_mask_bbox(): + labels = np.zeros((8, 8), dtype=np.int32) + labels[2:6, 3:7] = 1 + + rois = extract_rois_from_labeled_mask( + labels, + min_area=0, + extract_contours=False, + spatial_origin_yx=(10, 20), + ) + + assert len(rois) == 1 + assert isinstance(rois[0].shapes[0], MaskShape) + assert rois[0].shapes[0].bbox == (12, 23, 16, 27) + + +def test_extract_rois_from_labeled_mask_records_source_canvas_shape(): + labels = np.zeros((8, 8), dtype=np.int32) + labels[2:6, 3:7] = 1 + + rois = extract_rois_from_labeled_mask( + labels, + min_area=0, + source_spatial_shape_yx=(100, 200), + ) + + assert len(rois) == 1 + assert rois[0].metadata["source_spatial_shape_yx"] == (100, 200) + + +def test_load_rois_from_json_decodes_shapes_through_nominal_registry(tmp_path): + roi_path = tmp_path / "rois.json" + roi_path.write_text( + """ + [ + { + "metadata": {"label": 1}, + "shapes": [ + {"type": "polygon", "coordinates": [[1, 2], [3, 4], [5, 6]]}, + {"type": "mask", "mask": [[true, false], [false, true]], "bbox": [10, 20, 12, 22]} + ] + } + ] + """ + ) + + rois = load_rois_from_json(roi_path) + + assert len(rois) == 1 + assert isinstance(rois[0].shapes[0], PolygonShape) + assert isinstance(rois[0].shapes[1], MaskShape) + assert rois[0].shapes[1].bbox == (10, 20, 12, 22) diff --git a/tests/test_streaming_metadata.py b/tests/test_streaming_metadata.py index 8141f1d..95d3b00 100644 --- a/tests/test_streaming_metadata.py +++ b/tests/test_streaming_metadata.py @@ -3,6 +3,7 @@ import pytest from polystore.streaming._streaming_backend import StreamingBackend +from polystore.streaming._streaming_backend import StreamingBatchRequest class MetadataProbeStreamingBackend(StreamingBackend): @@ -13,38 +14,36 @@ def save_batch(self, data_list, file_paths, **kwargs): raise NotImplementedError -def test_streaming_component_metadata_accepts_unparsed_artifact_filename() -> None: +def test_streaming_component_metadata_rejects_unparsed_artifact_filename() -> None: backend = MetadataProbeStreamingBackend() microscope_handler = SimpleNamespace( parser=SimpleNamespace(parse_filename=lambda _filename: None) ) - metadata = backend._parse_component_metadata( - "A01_s001_w1_z001_t001_Nuclei_step3_rois.roi.zip", - microscope_handler, - source="IdentifyPrimaryObjects", - ) - - assert metadata == {"source": "IdentifyPrimaryObjects"} + with pytest.raises(ValueError, match="explicit component_metadata"): + backend._parse_component_metadata( + "A01_s001_w1_z001_t001_Nuclei_step3_rois.roi.zip", + microscope_handler, + source="IdentifyPrimaryObjects", + ) -def test_streaming_batch_items_accept_unparsed_artifact_filename() -> None: +def test_streaming_batch_items_reject_unparsed_artifact_filename() -> None: backend = MetadataProbeStreamingBackend() microscope_handler = SimpleNamespace( parser=SimpleNamespace(parse_filename=lambda _filename: None) ) - batch_images, image_ids = backend._prepare_batch_items( - [object()], - ["A01_s001_w1_z001_t001_Nuclei_step3_rois.roi.zip"], - microscope_handler, - "IdentifyPrimaryObjects", - lambda _data, _path, _data_type: ({"payload": "ok"}, "image"), - ) - - assert len(image_ids) == 1 - assert batch_images[0]["metadata"] == {"source": "IdentifyPrimaryObjects"} - assert batch_images[0]["payload"] == "ok" + with pytest.raises(ValueError, match="explicit component_metadata"): + backend._prepare_batch_items( + StreamingBatchRequest( + data_list=[object()], + file_paths=["A01_s001_w1_z001_t001_Nuclei_step3_rois.roi.zip"], + microscope_handler=microscope_handler, + source="IdentifyPrimaryObjects", + prepare_item=lambda _data, _path, _data_type: ({"payload": "ok"}, "image"), + ) + ) def test_streaming_component_metadata_preserves_parsed_filename_fields() -> None: @@ -64,13 +63,34 @@ def test_streaming_component_metadata_preserves_parsed_filename_fields() -> None assert metadata == {"well": "A01", "channel": 1, "source": "Crop"} +def test_streaming_component_metadata_prefers_explicit_metadata() -> None: + backend = MetadataProbeStreamingBackend() + microscope_handler = SimpleNamespace( + parser=SimpleNamespace(parse_filename=lambda _filename: None) + ) + + metadata = backend._parse_component_metadata( + "A01_s001_w1_z001_t001_Nuclei_step3_rois.roi.zip", + microscope_handler, + source="IdentifyPrimaryObjects", + component_metadata={"well": "A01", "site": 1, "channel": 1}, + ) + + assert metadata == { + "well": "A01", + "site": 1, + "channel": 1, + "source": "IdentifyPrimaryObjects", + } + + def test_streaming_component_metadata_rejects_invalid_parser_result() -> None: backend = MetadataProbeStreamingBackend() microscope_handler = SimpleNamespace( parser=SimpleNamespace(parse_filename=lambda _filename: ["not", "metadata"]) ) - with pytest.raises(TypeError, match="mapping or None"): + with pytest.raises(TypeError, match="must be a mapping"): backend._parse_component_metadata( "A01_s001_w1_z001_t001.TIF", microscope_handler, From ce8dc58fd180064dd6a5bdd9054210a7d8f96780 Mon Sep 17 00:00:00 2001 From: Tristan Simas Date: Sun, 24 May 2026 13:57:49 -0400 Subject: [PATCH 7/8] Support stacked labeled mask ROI extraction --- src/polystore/roi.py | 234 +++++++++++++++++++++++++++++++------------ 1 file changed, 169 insertions(+), 65 deletions(-) diff --git a/src/polystore/roi.py b/src/polystore/roi.py index d841591..26c1ef1 100644 --- a/src/polystore/roi.py +++ b/src/polystore/roi.py @@ -107,6 +107,167 @@ def __post_init__(self): raise ValueError(f"Shape {shape} must be an ROIShape") +@dataclass(frozen=True, slots=True) +class LabeledMaskROIExtractionRequest: + """Request to extract ROIs from a labeled mask or stack.""" + + labeled_mask: np.ndarray + min_area: int = 10 + extract_contours: bool = True + spatial_origin_yx: Optional[Tuple[int, int]] = None + source_spatial_shape_yx: Optional[Tuple[int, int]] = None + + +class LabeledMaskROIExtractor(ABC, metaclass=AutoRegisterMeta): + """Registered extraction behavior for one labeled-mask dimensional family.""" + + __registry_key__ = "__name__" + __skip_if_no_key__ = True + + @classmethod + def for_request( + cls, + request: LabeledMaskROIExtractionRequest, + ) -> "LabeledMaskROIExtractor": + for extractor_type in cls.__registry__.values(): + extractor = extractor_type() + if extractor.accepts(request.labeled_mask): + return extractor + raise ValueError( + "No ROI extractor registered for labeled mask shape " + f"{request.labeled_mask.shape}." + ) + + @abstractmethod + def accepts(self, labeled_mask: np.ndarray) -> bool: + """Return whether this extractor owns the mask dimensionality.""" + + @abstractmethod + def extract(self, request: LabeledMaskROIExtractionRequest) -> List[ROI]: + """Extract ROIs from the request.""" + + +class TwoDimensionalLabeledMaskROIExtractor(LabeledMaskROIExtractor): + """Extract ROIs from a single 2D labeled mask.""" + + def accepts(self, labeled_mask: np.ndarray) -> bool: + return labeled_mask.ndim == 2 + + def extract(self, request: LabeledMaskROIExtractionRequest) -> List[ROI]: + from skimage import measure + from skimage.measure import regionprops + from scipy.ndimage import find_objects + + labeled_mask = request.labeled_mask + if not np.issubdtype(labeled_mask.dtype, np.integer): + labeled_mask = labeled_mask.astype(np.int32) + + regions = regionprops(labeled_mask) + slices = find_objects(labeled_mask) + origin_y, origin_x = request.spatial_origin_yx or (0, 0) + + rois = [] + for region in regions: + if region.area < request.min_area: + continue + min_y, min_x, max_y, max_x = region.bbox + + metadata = { + "label": int(region.label), + "area": float(region.area), + "perimeter": float(region.perimeter), + "centroid": ( + float(region.centroid[0] + origin_y), + float(region.centroid[1] + origin_x), + ), + "bbox": ( + int(min_y + origin_y), + int(min_x + origin_x), + int(max_y + origin_y), + int(max_x + origin_x), + ), + } + if request.source_spatial_shape_yx is not None: + metadata["source_spatial_shape_yx"] = tuple( + int(value) for value in request.source_spatial_shape_yx + ) + + shapes = [] + if request.extract_contours: + label_idx = region.label - 1 + if label_idx < len(slices) and slices[label_idx] is not None: + slice_y, slice_x = slices[label_idx] + cropped_mask = labeled_mask[slice_y, slice_x] + binary_mask = (cropped_mask == region.label).astype(np.uint8) + padded_mask = np.pad(binary_mask, pad_width=1, mode="constant", constant_values=0) + contours = measure.find_contours(padded_mask, level=0.5) + offset_y = slice_y.start + offset_x = slice_x.start + padding_offset = np.array([offset_y + origin_y, offset_x + origin_x]) - 1 + for contour in contours: + if len(contour) >= 3: + contour_full = contour + padding_offset + shapes.append(PolygonShape(coordinates=contour_full)) + else: + binary_mask = labeled_mask == region.label + shapes.append(MaskShape(mask=binary_mask, bbox=metadata["bbox"])) + + if shapes: + rois.append(ROI(shapes=shapes, metadata=metadata)) + + logger.info(f"Extracted {len(rois)} ROIs from labeled mask") + return rois + + +class NonSpatialLabeledMaskROIExtractor(LabeledMaskROIExtractor): + """Treat scalar and otherwise non-spatial label payloads as empty ROI sets.""" + + def accepts(self, labeled_mask: np.ndarray) -> bool: + return labeled_mask.ndim < 2 + + def extract(self, request: LabeledMaskROIExtractionRequest) -> List[ROI]: + return [] + + +class StackedLabeledMaskROIExtractor(LabeledMaskROIExtractor): + """Extract ROIs from all 2D planes in a labeled-mask stack.""" + + def accepts(self, labeled_mask: np.ndarray) -> bool: + return labeled_mask.ndim > 2 + + def extract(self, request: LabeledMaskROIExtractionRequest) -> List[ROI]: + stack = request.labeled_mask + plane_shape = stack.shape[-2:] + leading_shape = stack.shape[:-2] + rois: list[ROI] = [] + for plane_indices in np.ndindex(leading_shape): + plane_request = LabeledMaskROIExtractionRequest( + labeled_mask=stack[plane_indices], + min_area=request.min_area, + extract_contours=request.extract_contours, + spatial_origin_yx=request.spatial_origin_yx, + source_spatial_shape_yx=request.source_spatial_shape_yx or plane_shape, + ) + for roi in TwoDimensionalLabeledMaskROIExtractor().extract(plane_request): + rois.append(self._with_plane_metadata(roi, plane_indices, leading_shape)) + return rois + + @staticmethod + def _with_plane_metadata( + roi: ROI, + plane_indices: tuple[int, ...], + leading_shape: tuple[int, ...], + ) -> ROI: + return ROI( + shapes=roi.shapes, + metadata={ + **roi.metadata, + "plane_indices": tuple(int(index) for index in plane_indices), + "plane_shape": tuple(int(size) for size in leading_shape), + }, + ) + + class ROIJsonShapeDecoder(ABC, metaclass=AutoRegisterMeta): """Decode one serialized ROI shape variant.""" @@ -181,71 +342,14 @@ def extract_rois_from_labeled_mask( source_spatial_shape_yx: Optional[Tuple[int, int]] = None, ) -> List[ROI]: """Extract ROIs from a labeled segmentation mask.""" - from skimage import measure - from skimage.measure import regionprops - from scipy.ndimage import find_objects - - if labeled_mask.ndim != 2: - raise ValueError(f"Labeled mask must be 2D, got shape {labeled_mask.shape}") - - if not np.issubdtype(labeled_mask.dtype, np.integer): - labeled_mask = labeled_mask.astype(np.int32) - - regions = regionprops(labeled_mask) - slices = find_objects(labeled_mask) - origin_y, origin_x = spatial_origin_yx or (0, 0) - - rois = [] - for region in regions: - if region.area < min_area: - continue - min_y, min_x, max_y, max_x = region.bbox - - metadata = { - "label": int(region.label), - "area": float(region.area), - "perimeter": float(region.perimeter), - "centroid": ( - float(region.centroid[0] + origin_y), - float(region.centroid[1] + origin_x), - ), - "bbox": ( - int(min_y + origin_y), - int(min_x + origin_x), - int(max_y + origin_y), - int(max_x + origin_x), - ), - } - if source_spatial_shape_yx is not None: - metadata["source_spatial_shape_yx"] = tuple( - int(value) for value in source_spatial_shape_yx - ) - - shapes = [] - if extract_contours: - label_idx = region.label - 1 - if label_idx < len(slices) and slices[label_idx] is not None: - slice_y, slice_x = slices[label_idx] - cropped_mask = labeled_mask[slice_y, slice_x] - binary_mask = (cropped_mask == region.label).astype(np.uint8) - padded_mask = np.pad(binary_mask, pad_width=1, mode="constant", constant_values=0) - contours = measure.find_contours(padded_mask, level=0.5) - offset_y = slice_y.start - offset_x = slice_x.start - padding_offset = np.array([offset_y + origin_y, offset_x + origin_x]) - 1 - for contour in contours: - if len(contour) >= 3: - contour_full = contour + padding_offset - shapes.append(PolygonShape(coordinates=contour_full)) - else: - binary_mask = (labeled_mask == region.label) - shapes.append(MaskShape(mask=binary_mask, bbox=metadata["bbox"])) - - if shapes: - rois.append(ROI(shapes=shapes, metadata=metadata)) - - logger.info(f"Extracted {len(rois)} ROIs from labeled mask") - return rois + request = LabeledMaskROIExtractionRequest( + labeled_mask=np.asarray(labeled_mask), + min_area=min_area, + extract_contours=extract_contours, + spatial_origin_yx=spatial_origin_yx, + source_spatial_shape_yx=source_spatial_shape_yx, + ) + return LabeledMaskROIExtractor.for_request(request).extract(request) def _get_backend_from_filemanager(filemanager: Any, backend: Union[str, Backend]): From 924b950d5de6dc8019cd8f2e3ed78f191b553092 Mon Sep 17 00:00:00 2001 From: Tristan Simas Date: Tue, 26 May 2026 18:17:32 -0400 Subject: [PATCH 8/8] Add Bio-Formats storage backend --- src/polystore/backend_registry.py | 3 +- src/polystore/bioformats_java.py | 223 ++++++++++++++++++++++++ src/polystore/bioformats_storage.py | 258 ++++++++++++++++++++++++++++ src/polystore/constants.py | 1 + 4 files changed, 483 insertions(+), 2 deletions(-) create mode 100644 src/polystore/bioformats_java.py create mode 100644 src/polystore/bioformats_storage.py diff --git a/src/polystore/backend_registry.py b/src/polystore/backend_registry.py index ad8ac52..eb4cb21 100644 --- a/src/polystore/backend_registry.py +++ b/src/polystore/backend_registry.py @@ -74,7 +74,7 @@ def create_storage_registry() -> Dict[str, DataSink]: # Backends that require context-specific initialization (e.g., plate_root) # These are registered lazily when needed, not at startup - SKIP_BACKENDS = {'virtual_workspace', 'omero_local'} + SKIP_BACKENDS = {'virtual_workspace', 'omero_local', 'bioformats'} registry = {} for backend_type in STORAGE_BACKENDS.keys(): @@ -157,4 +157,3 @@ def cleanup_all_backends() -> None: _backend_instances.clear() logger.info("All backend instances cleaned up") - diff --git a/src/polystore/bioformats_java.py b/src/polystore/bioformats_java.py new file mode 100644 index 0000000..41c7824 --- /dev/null +++ b/src/polystore/bioformats_java.py @@ -0,0 +1,223 @@ +"""Shared Java Bio-Formats bridge for metadata discovery and plane loading.""" + +from __future__ import annotations + +from dataclasses import dataclass +from pathlib import Path +from threading import Lock +from typing import Any, Callable + +import numpy as np + + +class BioFormatsJavaUnavailableError(RuntimeError): + """Raised when the Java Bio-Formats runtime cannot be initialized.""" + + +@dataclass(frozen=True, slots=True) +class BioFormatsOpenedReader: + """Open Bio-Formats reader plus its OME metadata store.""" + + reader: Any + metadata: Any + + def close(self) -> None: + self.reader.close() + + +class BioFormatsJavaContext: + """Lazy JVM/ImageJ context for Bio-Formats Java access.""" + + _lock = Lock() + _instance: "BioFormatsJavaContext | None" = None + + def __init__(self, imagej_module: Any, scyjava_module: Any): + self.imagej = imagej_module + self.scyjava = scyjava_module + self.ij = None + self.ImageReader = None + self.MetadataTools = None + self.FormatTools = None + + @classmethod + def instance(cls) -> "BioFormatsJavaContext": + with cls._lock: + if cls._instance is None: + cls._instance = cls._create() + return cls._instance + + @classmethod + def _create(cls) -> "BioFormatsJavaContext": + try: + import imagej + import scyjava + except ImportError as exc: + raise BioFormatsJavaUnavailableError( + "Bio-Formats support requires the optional bioformats/fiji dependencies." + ) from exc + return cls(imagej, scyjava) + + def ensure_initialized(self) -> None: + if self.ij is not None: + return + try: + self.ij = self.imagej.init("sc.fiji:fiji", mode="headless") + self.ImageReader = self.scyjava.jimport("loci.formats.ImageReader") + self.MetadataTools = self.scyjava.jimport("loci.formats.MetadataTools") + self.FormatTools = self.scyjava.jimport("loci.formats.FormatTools") + except Exception as exc: + raise BioFormatsJavaUnavailableError( + "Could not initialize Fiji/Bio-Formats through pyimagej." + ) from exc + + def open_reader(self, source_path: str | Path) -> BioFormatsOpenedReader: + self.ensure_initialized() + metadata = self.MetadataTools.createOMEXMLMetadata() + reader = self.ImageReader() + try: + reader.setMetadataStore(metadata) + reader.setId(str(source_path)) + return BioFormatsOpenedReader(reader=reader, metadata=metadata) + except Exception: + reader.close() + raise + + +def java_int(value: Any) -> int | None: + """Convert nullable Java primitive wrappers to Python int.""" + return OptionalJavaScalar.from_java(value, JAVA_SCALAR_PROJECTOR.readers).convert(int) + + +def java_float(value: Any) -> float | None: + """Convert nullable Java numeric wrappers to Python float.""" + return OptionalJavaScalar.from_java(value, JAVA_SCALAR_PROJECTOR.readers).convert(float) + + +def java_str(value: Any) -> str | None: + """Convert nullable Java strings to Python strings.""" + if value is None: + return None + return str(value) + + +def _read_java_value(value: Any) -> Any: + return value.value() + + +def _read_java_get_value(value: Any) -> Any: + return value.getValue() + + +@dataclass(frozen=True, slots=True) +class JavaScalarProjector: + """Project nullable Java scalar wrappers to Python scalar values.""" + + readers: tuple[Callable[[Any], Any], ...] + + def unwrap(self, value: Any) -> Any: + for reader in self.readers: + try: + return reader(value) + except AttributeError: + continue + return value + + +@dataclass(frozen=True, slots=True) +class OptionalJavaScalar: + """Nullable Java scalar after wrapper unwrapping.""" + + value: Any | None + + @classmethod + def from_java( + cls, + value: Any, + readers: tuple[Callable[[Any], Any], ...], + ) -> "OptionalJavaScalar": + if value is None: + return cls(None) + return cls(JavaScalarProjector(readers).unwrap(value)) + + def convert(self, converter: Callable[[Any], Any]) -> Any | None: + if self.value is None: + return None + return converter(self.value) + + +JAVA_SCALAR_PROJECTOR = JavaScalarProjector( + readers=( + _read_java_value, + _read_java_get_value, + ) +) + + +def load_bioformats_plane( + *, + source_path: Path, + series_index: int, + plane_index: int, +) -> np.ndarray: + """Load a single 2D Bio-Formats plane through the Java ImageReader.""" + context = BioFormatsJavaContext.instance() + opened = context.open_reader(source_path) + reader = opened.reader + try: + reader.setSeries(series_index) + if reader.getRGBChannelCount() != 1: + raise ValueError( + "Bio-Formats RGB/interleaved planes are not yet representable as " + "OpenHCS scalar channel planes." + ) + raw = bytes(reader.openBytes(plane_index)) + dtype = PixelDtypeCatalog.from_format_tools(context.FormatTools).dtype( + pixel_type=int(reader.getPixelType()), + little_endian=bool(reader.isLittleEndian()), + ) + array = np.frombuffer(raw, dtype=dtype) + return array.reshape((int(reader.getSizeY()), int(reader.getSizeX()))) + finally: + opened.close() + + +@dataclass(frozen=True, slots=True) +class PixelDtypeSpec: + """NumPy dtype projection for one Bio-Formats pixel type.""" + + key: int + dtype_code: str + endian_sensitive: bool = True + + def dtype(self, *, little_endian: bool) -> np.dtype: + if not self.endian_sensitive: + return np.dtype(self.dtype_code) + endian = "<" if little_endian else ">" + return np.dtype(endian + self.dtype_code) + + +@dataclass(frozen=True, slots=True) +class PixelDtypeCatalog: + """Authoritative Bio-Formats pixel-type to NumPy dtype mapping.""" + + specs_by_key: dict[int, PixelDtypeSpec] + + @classmethod + def from_format_tools(cls, format_tools: Any) -> "PixelDtypeCatalog": + specs = ( + PixelDtypeSpec(int(format_tools.INT8), "i1", endian_sensitive=False), + PixelDtypeSpec(int(format_tools.UINT8), "u1", endian_sensitive=False), + PixelDtypeSpec(int(format_tools.INT16), "i2"), + PixelDtypeSpec(int(format_tools.UINT16), "u2"), + PixelDtypeSpec(int(format_tools.INT32), "i4"), + PixelDtypeSpec(int(format_tools.UINT32), "u4"), + PixelDtypeSpec(int(format_tools.FLOAT), "f4"), + PixelDtypeSpec(int(format_tools.DOUBLE), "f8"), + ) + return cls({spec.key: spec for spec in specs}) + + def dtype(self, *, pixel_type: int, little_endian: bool) -> np.dtype: + try: + return self.specs_by_key[pixel_type].dtype(little_endian=little_endian) + except KeyError as exc: + raise ValueError(f"Unsupported Bio-Formats pixel type: {pixel_type}") from exc diff --git a/src/polystore/bioformats_storage.py b/src/polystore/bioformats_storage.py new file mode 100644 index 0000000..ba17dcf --- /dev/null +++ b/src/polystore/bioformats_storage.py @@ -0,0 +1,258 @@ +"""Structured-reference backend for Bio-Formats-backed virtual workspaces.""" + +from __future__ import annotations + +import json +from dataclasses import dataclass +from fnmatch import fnmatch +from pathlib import Path +from typing import Any, Dict, List, Optional, Set, Union + +from .base import PicklableBackend, ReadOnlyBackend +from .constants import Backend +from .exceptions import StorageResolutionError +from .metadata_writer import get_metadata_path + + +@dataclass(frozen=True, slots=True) +class BioFormatsPlaneRef: + """Serializable reference to one Bio-Formats image plane.""" + + source_path: Path + series_index: int + plane_index: int + c: int + z: int + t: int + reader: str = "bioformats" + + @classmethod + def from_mapping( + cls, + payload: Dict[str, Any], + *, + plate_root: Path, + ) -> "BioFormatsPlaneRef": + source_path = Path(payload["source_path"]) + if not source_path.is_absolute(): + source_path = plate_root / source_path + return cls( + source_path=source_path, + series_index=int(payload.get("series_index", 0)), + plane_index=int(payload["plane_index"]), + c=int(payload["c"]), + z=int(payload["z"]), + t=int(payload["t"]), + reader=str(payload.get("reader", "bioformats")), + ) + + +class BioFormatsStorageBackend(ReadOnlyBackend, PicklableBackend): + """Load normalized virtual source keys from structured Bio-Formats refs.""" + + _backend_type = Backend.BIOFORMATS.value + + def __init__(self, plate_root: Path | None = None): + self.plate_root = None if plate_root is None else Path(plate_root) + self._mapping_cache: Optional[Dict[str, Dict[str, Any]]] = None + self._cache_mtime: Optional[float] = None + + def get_connection_params(self) -> Optional[Dict[str, Any]]: + if self.plate_root is None: + return None + return {"plate_root": str(self.plate_root)} + + def set_connection_params(self, params: Optional[Dict[str, Any]]) -> None: + if not params: + self.plate_root = None + self._mapping_cache = None + self._cache_mtime = None + return + self.plate_root = Path(params["plate_root"]) + self._mapping_cache = None + self._cache_mtime = None + + def load(self, file_path: Union[str, Path], **kwargs) -> Any: + ref = self._resolve_ref(file_path) + if ref.reader == "npy": + return _load_npy_plane(ref) + if ref.reader != "bioformats": + raise BioFormatsReaderUnavailableError( + f"Unsupported Bio-Formats reader {ref.reader!r}." + ) + from .bioformats_java import load_bioformats_plane + + return load_bioformats_plane( + source_path=ref.source_path, + series_index=ref.series_index, + plane_index=ref.plane_index, + ) + + def load_batch(self, file_paths: List[Union[str, Path]], **kwargs) -> List[Any]: + return [self.load(file_path, **kwargs) for file_path in file_paths] + + def list_files( + self, + directory: Union[str, Path], + pattern: Optional[str] = None, + extensions: Optional[Set[str]] = None, + recursive: bool = False, + **kwargs, + ) -> List[str]: + plate_root = self._require_plate_root() + relative_dir = self.relative_to_root(directory) + normalized_dir = _normalize_relative_path(str(relative_dir)) + lowercase_extensions = ( + None if extensions is None else {extension.lower() for extension in extensions} + ) + results = [] + for virtual_path in self._load_mapping().keys(): + if not _virtual_path_in_directory( + virtual_path, + normalized_dir=normalized_dir, + recursive=recursive, + ): + continue + path = Path(virtual_path) + if lowercase_extensions is not None and path.suffix.lower() not in lowercase_extensions: + continue + if pattern is not None and not fnmatch(path.name, pattern): + continue + results.append(str(plate_root / virtual_path)) + return results + + def exists(self, path: Union[str, Path]) -> bool: + try: + relative = self.normalized_relative_path(path) + except StorageResolutionError: + return False + if not relative: + return True + mapping = self._load_mapping() + return relative in mapping or any( + virtual_path.startswith(relative + "/") + for virtual_path in mapping + ) + + def is_file(self, path: Union[str, Path]) -> bool: + try: + relative = self.normalized_relative_path(path) + except StorageResolutionError: + return False + return relative in self._load_mapping() + + def is_dir(self, path: Union[str, Path]) -> bool: + try: + relative = self.normalized_relative_path(path) + except StorageResolutionError: + return False + return not relative or any( + virtual_path.startswith(relative + "/") + for virtual_path in self._load_mapping() + ) + + def list_dir(self, path: Union[str, Path]) -> List[str]: + relative = self.normalized_relative_path(path) + prefix = "" if not relative else relative + "/" + names = set() + for virtual_path in self._load_mapping(): + if not virtual_path.startswith(prefix): + continue + remainder = virtual_path[len(prefix):] + if remainder: + names.add(remainder.split("/", 1)[0]) + return sorted(names) + + def _resolve_ref(self, path: Union[str, Path]) -> BioFormatsPlaneRef: + plate_root = self._require_plate_root() + relative_path = self.normalized_relative_path(path) + mapping = self._load_mapping() + try: + payload = mapping[relative_path] + except KeyError as exc: + raise StorageResolutionError( + f"Path not in Bio-Formats workspace mapping: {relative_path}" + ) from exc + if not isinstance(payload, dict): + raise StorageResolutionError( + f"Bio-Formats workspace mapping for {relative_path!r} is not structured." + ) + return BioFormatsPlaneRef.from_mapping(payload, plate_root=plate_root) + + def _load_mapping(self) -> Dict[str, Dict[str, Any]]: + plate_root = self._require_plate_root() + metadata_path = get_metadata_path(plate_root) + if not metadata_path.exists(): + raise FileNotFoundError(f"Metadata not found: {metadata_path}") + current_mtime = metadata_path.stat().st_mtime + if self._mapping_cache is not None and self._cache_mtime == current_mtime: + return self._mapping_cache + metadata = json.loads(metadata_path.read_text(encoding="utf-8")) + combined_mapping: Dict[str, Dict[str, Any]] = {} + for subdirectory in metadata.get("subdirectories", {}).values(): + if Backend.BIOFORMATS.value not in subdirectory.get("available_backends", {}): + continue + workspace_mapping = subdirectory.get("workspace_mapping", {}) + for virtual_path, ref_payload in workspace_mapping.items(): + if isinstance(ref_payload, dict): + combined_mapping[_normalize_relative_path(str(virtual_path))] = ref_payload + if not combined_mapping: + raise ValueError(f"No Bio-Formats workspace_mapping in {metadata_path}") + self._mapping_cache = combined_mapping + self._cache_mtime = current_mtime + return combined_mapping + + def _require_plate_root(self) -> Path: + if self.plate_root is None: + raise StorageResolutionError("BioFormatsStorageBackend requires plate_root.") + return self.plate_root + + def relative_to_root(self, path: Union[str, Path]) -> Path: + plate_root = self._require_plate_root() + path_obj = Path(path) + if not path_obj.is_absolute(): + return path_obj + try: + return path_obj.relative_to(plate_root) + except ValueError as exc: + raise StorageResolutionError( + f"Path {path_obj} is outside Bio-Formats plate root {plate_root}." + ) from exc + + def normalized_relative_path(self, path: Union[str, Path]) -> str: + return _normalize_relative_path(str(self.relative_to_root(path))) + + +class BioFormatsReaderUnavailableError(RuntimeError): + """Raised when a production Bio-Formats reader has not been configured.""" + + +def _load_npy_plane(ref: BioFormatsPlaneRef) -> Any: + import numpy as np + + array = np.load(ref.source_path) + if array.ndim == 2: + return array + if array.ndim == 5: + return array[ref.t - 1, ref.z - 1, ref.c - 1] + if array.ndim == 3: + return array[ref.plane_index] + raise ValueError( + f"Unsupported npy Bio-Formats fixture shape {array.shape} for {ref.source_path}." + ) + + +def _normalize_relative_path(path: str) -> str: + normalized = path.replace("\\", "/") + return "" if normalized == "." else normalized + + +def _virtual_path_in_directory( + virtual_path: str, + *, + normalized_dir: str, + recursive: bool, +) -> bool: + if recursive: + return not normalized_dir or virtual_path.startswith(normalized_dir + "/") + return _normalize_relative_path(str(Path(virtual_path).parent)) == normalized_dir diff --git a/src/polystore/constants.py b/src/polystore/constants.py index 3a27cfb..0103236 100644 --- a/src/polystore/constants.py +++ b/src/polystore/constants.py @@ -19,6 +19,7 @@ class Backend(Enum): FIJI_STREAM = "fiji_stream" OMERO_LOCAL = "omero_local" VIRTUAL_WORKSPACE = "virtual_workspace" + BIOFORMATS = "bioformats" class TransportMode(Enum):