diff --git a/README-pypi.md b/README-pypi.md index be0de197..f6173669 100644 --- a/README-pypi.md +++ b/README-pypi.md @@ -93,6 +93,14 @@ Here you find a series of notebooks that give you an overview of the core featur Using PyTorch gradients to fit a Gaussian generated by a DeepTrack2 pipeline. +- DTGS171A **[Creating Custom Scatterers](https://github.com/DeepTrackAI/DeepTrack2/blob/develop/tutorials/1-getting-started/DTGS171A_custom_scatterers.ipynb)** + + Creating custom scatterers of arbitrary shapes. + +- DTGS171B **[Creating Custom Scatterers: Bacteria](https://github.com/DeepTrackAI/DeepTrack2/blob/develop/tutorials/1-getting-started/DTGS171B_custom_scatterers_bacteria.ipynb)** + + Creating custom scatterers in the shape of bacteria. + # Examples These are examples of how DeepTrack2 can be used on real datasets: diff --git a/README.md b/README.md index 674d41d3..a9f507c2 100644 --- a/README.md +++ b/README.md @@ -97,6 +97,14 @@ Here you find a series of notebooks that give you an overview of the core featur Using PyTorch gradients to fit a Gaussian generated by a DeepTrack2 pipeline. +- DTGS171A **[Creating Custom Scatterers](https://github.com/DeepTrackAI/DeepTrack2/blob/develop/tutorials/1-getting-started/DTGS171A_custom_scatterers.ipynb)** + + Creating custom scatterers of arbitrary shapes. + +- DTGS171B **[Creating Custom Scatterers: Bacteria](https://github.com/DeepTrackAI/DeepTrack2/blob/develop/tutorials/1-getting-started/DTGS171B_custom_scatterers_bacteria.ipynb)** + + Creating custom scatterers in the shape of bacteria. + # Examples These are examples of how DeepTrack2 can be used on real datasets: diff --git a/deeptrack/features.py b/deeptrack/features.py index 4bdfad38..9572ea82 100644 --- a/deeptrack/features.py +++ b/deeptrack/features.py @@ -1,6 +1,6 @@ """Core features for building and processing pipelines in DeepTrack2. -The `feasture.py` module defines the core classes and utilities used to create +The `feature.py` module defines the core classes and utilities used to create and manipulate features in DeepTrack2, enabling users to build sophisticated data processing pipelines with modular, reusable, and composable components. @@ -52,6 +52,11 @@ hierarchical or logical structures in the pipeline without input transformations. +- `BackendDispatched`: Mixin class for backend-specific implementations. + + Provides mechanisms for dispatching feature methods based on the + computational backend (e.g., NumPy, PyTorch). + - `ArithmeticOperationFeature`: Apply arithmetic operation element-wise. A parent class for features performing arithmetic operations like addition, @@ -80,11 +85,8 @@ - `OneOf`: Resolve one feature from a given collection. - `OneOfDict`: Resolve one feature from a dictionary and apply it to an input. - `LoadImage`: Load an image from disk and preprocess it. -- `SampleToMasks`: Create a mask from a list of images. - `AsType`: Convert the data type of the input. - `ChannelFirst2d`: DEPRECATED Convert an image to a channel-first format. -- `Upscale`: Simulate a pipeline at a higher resolution. -- `NonOverlapping`: Ensure volumes are placed non-overlapping in a 3D space. - `Store`: Store the output of a feature for reuse. - `Squeeze`: Squeeze the input to the smallest possible dimension. - `Unsqueeze`: Unsqueeze the input. @@ -96,7 +98,7 @@ - `TakeProperties`: Extract all instances of properties from a pipeline. Arithmetic Feature Classes: -- `Add`: Add a value to the input. 
+- `Add`: Add a value to the input.
 - `Subtract`: Subtract a value from the input.
 - `Multiply`: Multiply the input by a value.
 - `Divide`: Divide the input by a value.
@@ -183,6 +185,7 @@
 __all__ = [
     "Feature",
     "StructuralFeature",
+    "BackendDispatched",
     "Chain",
     "Branch",
     "DummyFeature",
@@ -218,11 +221,8 @@
     "OneOf",
     "OneOfDict",
     "LoadImage",
-    "SampleToMasks",  # TODO ***CM*** revise this after elimination of Image
     "AsType",
     "ChannelFirst2d",
-    "Upscale",  # TODO ***CM*** revise and check PyTorch afrer elimin. Image
-    "NonOverlapping",  # TODO ***CM*** revise + PyTorch afrer elimin. Image
     "Store",
     "Squeeze",
     "Unsqueeze",
@@ -3953,6 +3953,32 @@ class StructuralFeature(Feature):
     __distributed__: bool = False  # Process the entire image list in one call
 
 
+class BackendDispatched:
+    """Mixin for Feature.get() methods with backend-specific implementations."""
+
+    _NUMPY_IMPL: str | None = None
+    _TORCH_IMPL: str | None = None
+
+    def _dispatch_backend(self, *args, **kwargs):
+        backend = self.get_backend()
+
+        if backend == "numpy":
+            if self._NUMPY_IMPL is None:
+                raise NotImplementedError(
+                    f"{self.__class__.__name__} does not support NumPy backend."
+                )
+            return getattr(self, self._NUMPY_IMPL)(*args, **kwargs)
+
+        if backend == "torch":
+            if self._TORCH_IMPL is None:
+                raise NotImplementedError(
+                    f"{self.__class__.__name__} does not support Torch backend."
+                )
+            return getattr(self, self._TORCH_IMPL)(*args, **kwargs)
+
+        raise RuntimeError(f"Unknown backend: {backend!r}.")
+
+
 class Chain(StructuralFeature):
     """Resolve two features sequentially.
 
@@ -7359,312 +7385,6 @@ def get(
 
         return image
 
 
-class SampleToMasks(Feature):
-    """Create a mask from a list of images.
-
-    This feature applies a transformation function to each input image and
-    merges the resulting masks into a single multi-layer image. Each input
-    image must have a `position` property that determines its placement within
-    the final mask. When used with scatterers, the `voxel_size` property must
-    be provided for correct object sizing.
-
-    Parameters
-    ----------
-    transformation_function: Callable[[Image], Image]
-        A function that transforms each input image into a mask with
-        `number_of_masks` layers.
-    number_of_masks: PropertyLike[int], optional
-        The number of mask layers to generate. Default is 1.
-    output_region: PropertyLike[tuple[int, int, int, int]], optional
-        The size and position of the output mask, typically aligned with
-        `optics.output_region`.
-    merge_method: PropertyLike[str | Callable | list[str | Callable]], optional
-        Method for merging individual masks into the final image. Can be:
-        - "add" (default): Sum the masks.
-        - "overwrite": Later masks overwrite earlier masks.
-        - "or": Combine masks using a logical OR operation.
-        - "mul": Multiply masks.
-        - Function: Custom function taking two images and merging them.
-
-    **kwargs: dict[str, Any]
-        Additional keyword arguments passed to the parent `Feature` class.
-
-    Methods
-    -------
-    `get(image, transformation_function, **kwargs) -> Image`
-        Applies the transformation function to the input image.
-    `_process_and_get(images, **kwargs) -> Image | np.ndarray`
-        Processes a list of images and generates a multi-layer mask.
-
-    Returns
-    -------
-    Image or np.ndarray
-        The final mask image with the specified number of layers.
-
-    Raises
-    ------
-    ValueError
-        If `merge_method` is invalid.
- - Examples - ------- - >>> import deeptrack as dt - - Define number of particles: - - >>> n_particles = 12 - - Define optics and particles: - - >>> import numpy as np - >>> - >>> optics = dt.Fluorescence(output_region=(0, 0, 64, 64)) - >>> particle = dt.PointParticle( - >>> position=lambda: np.random.uniform(5, 55, size=2), - >>> ) - >>> particles = particle ^ n_particles - - Define pipelines: - - >>> sim_im_pip = optics(particles) - >>> sim_mask_pip = particles >> dt.SampleToMasks( - ... lambda: lambda particles: particles > 0, - ... output_region=optics.output_region, - ... merge_method="or", - ... ) - >>> pipeline = sim_im_pip & sim_mask_pip - >>> pipeline.store_properties() - - Generate image and mask: - - >>> image, mask = pipeline.update()() - - Get particle positions: - - >>> positions = np.array(image.get_property("position", get_one=False)) - - Visualize results: - - >>> import matplotlib.pyplot as plt - >>> - >>> plt.subplot(1, 2, 1) - >>> plt.imshow(image, cmap="gray") - >>> plt.title("Original Image") - >>> plt.subplot(1, 2, 2) - >>> plt.imshow(mask, cmap="gray") - >>> plt.scatter(positions[:,1], positions[:,0], c="y", marker="x", s = 50) - >>> plt.title("Mask") - >>> plt.show() - - """ - - def __init__( - self: Feature, - transformation_function: Callable[[Image], Image], - number_of_masks: PropertyLike[int] = 1, - output_region: PropertyLike[tuple[int, int, int, int]] = None, - merge_method: PropertyLike[str | Callable | list[str | Callable]] = "add", - **kwargs: Any, - ): - """Initialize the SampleToMasks feature. - - Parameters - ---------- - transformation_function: Callable[[Image], Image] - Function to transform input images into masks. - number_of_masks: PropertyLike[int], optional - Number of mask layers. Default is 1. - output_region: PropertyLike[tuple[int, int, int, int]], optional - Output region of the mask. Default is None. - merge_method: PropertyLike[str | Callable | list[str | Callable]], optional - Method to merge masks. Defaults to "add". - **kwargs: dict[str, Any] - Additional keyword arguments passed to the parent class. - - """ - - super().__init__( - transformation_function=transformation_function, - number_of_masks=number_of_masks, - output_region=output_region, - merge_method=merge_method, - **kwargs, - ) - - def get( - self: Feature, - image: np.ndarray | Image, - transformation_function: Callable[[Image], Image], - **kwargs: Any, - ) -> Image: - """Apply the transformation function to a single image. - - Parameters - ---------- - image: np.ndarray | Image - The input image. - transformation_function: Callable[[Image], Image] - Function to transform the image. - **kwargs: dict[str, Any] - Additional parameters. - - Returns - ------- - Image - The transformed image. - - """ - - return transformation_function(image) - - def _process_and_get( - self: Feature, - images: list[np.ndarray] | np.ndarray | list[Image] | Image, - **kwargs: Any, - ) -> Image | np.ndarray: - """Process a list of images and generate a multi-layer mask. - - Parameters - ---------- - images: np.ndarray or list[np.ndarrray] or Image or list[Image] - List of input images or a single image. - **kwargs: dict[str, Any] - Additional parameters including `output_region`, `number_of_masks`, - and `merge_method`. - - Returns - ------- - Image or np.ndarray - The final mask image. - - """ - - # Handle list of images. 
- if isinstance(images, list) and len(images) != 1: - list_of_labels = super()._process_and_get(images, **kwargs) - if not self._wrap_array_with_image: - for idx, (label, image) in enumerate(zip(list_of_labels, - images)): - list_of_labels[idx] = \ - Image(label, copy=False).merge_properties_from(image) - else: - if isinstance(images, list): - images = images[0] - list_of_labels = [] - for prop in images.properties: - - if "position" in prop: - - inp = Image(np.array(images)) - inp.append(prop) - out = Image(self.get(inp, **kwargs)) - out.merge_properties_from(inp) - list_of_labels.append(out) - - # Create an empty output image. - output_region = kwargs["output_region"] - output = np.zeros( - ( - output_region[2] - output_region[0], - output_region[3] - output_region[1], - kwargs["number_of_masks"], - ) - ) - - from deeptrack.optics import _get_position - - # Merge masks into the output. - for label in list_of_labels: - position = _get_position(label) - p0 = np.round(position - output_region[0:2]) - - if np.any(p0 > output.shape[0:2]) or \ - np.any(p0 + label.shape[0:2] < 0): - continue - - crop_x = int(-np.min([p0[0], 0])) - crop_y = int(-np.min([p0[1], 0])) - crop_x_end = int( - label.shape[0] - - np.max([p0[0] + label.shape[0] - output.shape[0], 0]) - ) - crop_y_end = int( - label.shape[1] - - np.max([p0[1] + label.shape[1] - output.shape[1], 0]) - ) - - labelarg = label[crop_x:crop_x_end, crop_y:crop_y_end, :] - - p0[0] = np.max([p0[0], 0]) - p0[1] = np.max([p0[1], 0]) - - p0 = p0.astype(int) - - output_slice = output[ - p0[0] : p0[0] + labelarg.shape[0], - p0[1] : p0[1] + labelarg.shape[1], - ] - - for label_index in range(kwargs["number_of_masks"]): - - if isinstance(kwargs["merge_method"], list): - merge = kwargs["merge_method"][label_index] - else: - merge = kwargs["merge_method"] - - if merge == "add": - output[ - p0[0] : p0[0] + labelarg.shape[0], - p0[1] : p0[1] + labelarg.shape[1], - label_index, - ] += labelarg[..., label_index] - - elif merge == "overwrite": - output_slice[ - labelarg[..., label_index] != 0, label_index - ] = labelarg[labelarg[..., label_index] != 0, \ - label_index] - output[ - p0[0] : p0[0] + labelarg.shape[0], - p0[1] : p0[1] + labelarg.shape[1], - label_index, - ] = output_slice[..., label_index] - - elif merge == "or": - output[ - p0[0] : p0[0] + labelarg.shape[0], - p0[1] : p0[1] + labelarg.shape[1], - label_index, - ] = (output_slice[..., label_index] != 0) | ( - labelarg[..., label_index] != 0 - ) - - elif merge == "mul": - output[ - p0[0] : p0[0] + labelarg.shape[0], - p0[1] : p0[1] + labelarg.shape[1], - label_index, - ] *= labelarg[..., label_index] - - else: - # No match, assume function - output[ - p0[0] : p0[0] + labelarg.shape[0], - p0[1] : p0[1] + labelarg.shape[1], - label_index, - ] = merge( - output_slice[..., label_index], - labelarg[..., label_index], - ) - - if not self._wrap_array_with_image: - return output - output = Image(output) - for label in list_of_labels: - output.merge_properties_from(label) - return output - - class AsType(Feature): """Convert the data type of arrays. @@ -7930,855 +7650,181 @@ def get( return array -class Upscale(Feature): - """Simulate a pipeline at a higher resolution. - - This feature scales up the resolution of the input pipeline by a specified - factor, performs computations at the higher resolution, and then - downsamples the result back to the original size. This is useful for - simulating effects at a finer resolution while preserving compatibility - with lower-resolution pipelines. 
- - Internally, this feature redefines the scale of physical units (e.g., - `units.pixel`) to achieve the effect of upscaling. Therefore, it does not - resize the input image itself but affects only features that rely on - physical units. - - Parameters - ---------- - feature: Feature - The pipeline or feature to resolve at a higher resolution. - factor: int or tuple[int, int, int], optional - The factor by which to upscale the simulation. If a single integer is - provided, it is applied uniformly across all axes. If a tuple of three - integers is provided, each axis is scaled individually. Defaults to 1. - **kwargs: Any - Additional keyword arguments passed to the parent `Feature` class. - - Attributes - ---------- - __distributed__: bool - Always `False` for `Upscale`, indicating that this feature’s `.get()` - method processes the entire input at once even if it is a list, rather - than distributing calls for each item of the list. - - Methods - ------- - `get(image, factor, **kwargs) -> np.ndarray | torch.tensor` - Simulates the pipeline at a higher resolution and returns the result at - the original resolution. - - Notes - ----- - - This feature does not directly resize the image. Instead, it modifies the - unit conversions within the pipeline, making physical units smaller, - which results in more detail being simulated. - - The final output is downscaled back to the original resolution using - `block_reduce` from `skimage.measure`. - - The effect is only noticeable if features use physical units (e.g., - `units.pixel`, `units.meter`). Otherwise, the result will be identical. - - Examples - -------- - >>> import deeptrack as dt - - Define an optical pipeline and a spherical particle: - - >>> optics = dt.Fluorescence() - >>> particle = dt.Sphere() - >>> simple_pipeline = optics(particle) - - Create an upscaled pipeline with a factor of 4: - - >>> upscaled_pipeline = dt.Upscale(optics(particle), factor=4) - - Resolve the pipelines: - - >>> image = simple_pipeline() - >>> upscaled_image = upscaled_pipeline() - - Visualize the images: - - >>> import matplotlib.pyplot as plt - >>> - >>> plt.subplot(1, 2, 1) - >>> plt.imshow(image, cmap="gray") - >>> plt.title("Original Image") - >>> - >>> plt.subplot(1, 2, 2) - >>> plt.imshow(upscaled_image, cmap="gray") - >>> plt.title("Simulated at Higher Resolution") - >>> - >>> plt.show() - - Compare the shapes (both are the same due to downscaling): - - >>> print(image.shape) - (128, 128, 1) - >>> print(upscaled_image.shape) - (128, 128, 1) - - """ - - __distributed__: bool = False - - feature: Feature - - def __init__( - self: Feature, - feature: Feature, - factor: int | tuple[int, int, int] = 1, - **kwargs: Any, - ) -> None: - """Initialize the Upscale feature. - - Parameters - ---------- - feature: Feature - The pipeline or feature to resolve at a higher resolution. - factor: int or tuple[int, int, int], optional - The factor by which to upscale the simulation. If a single integer - is provided, it is applied uniformly across all axes. If a tuple of - three integers is provided, each axis is scaled individually. - Defaults to 1. - **kwargs: Any - Additional keyword arguments passed to the parent `Feature` class. - - """ - - super().__init__(factor=factor, **kwargs) - self.feature = self.add_feature(feature) - - def get( - self: Feature, - image: np.ndarray | torch.Tensor, - factor: int | tuple[int, int, int], - **kwargs: Any, - ) -> np.ndarray | torch.Tensor: - """Simulate the pipeline at a higher resolution and return result. 
- - Parameters - ---------- - image: np.ndarray or torch.Tensor - The input image to process. - factor: int or tuple[int, int, int] - The factor by which to upscale the simulation. If a single integer - is provided, it is applied uniformly across all axes. If a tuple of - three integers is provided, each axis is scaled individually. - **kwargs: Any - Additional keyword arguments passed to the feature. - - Returns - ------- - np.ndarray or torch.Tensor - The processed image at the original resolution. - - Raises - ------ - ValueError - If the input `factor` is not a valid integer or tuple of integers. - - """ - - # Ensure factor is a tuple of three integers. - if np.size(factor) == 1: - factor = (factor,) * 3 - elif len(factor) != 3: - raise ValueError( - "Factor must be an integer or a tuple of three integers." - ) - - # Create a context for upscaling and perform computation. - ctx = create_context(None, None, None, *factor) - with units.context(ctx): - image = self.feature(image) - - # Downscale the result to the original resolution. - import skimage.measure - - image = skimage.measure.block_reduce( - image, (factor[0], factor[1]) + (1,) * (image.ndim - 2), np.mean - ) - - return image - - -class NonOverlapping(Feature): - """Ensure volumes are placed non-overlapping in a 3D space. +# class Upscale(Feature): +# """Simulate a pipeline at a higher resolution. - This feature ensures that a list of 3D volumes are positioned such that - their non-zero voxels do not overlap. If volumes overlap, their positions - are resampled until they are non-overlapping. If the maximum number of - attempts is exceeded, the feature regenerates the list of volumes and - raises a warning if non-overlapping placement cannot be achieved. +# This feature scales up the resolution of the input pipeline by a specified +# factor, performs computations at the higher resolution, and then +# downsamples the result back to the original size. This is useful for +# simulating effects at a finer resolution while preserving compatibility +# with lower-resolution pipelines. - Note: `min_distance` refers to the distance between the edges of volumes, - not their centers. Due to the way volumes are calculated, slight rounding - errors may affect the final distance. +# Internally, this feature redefines the scale of physical units (e.g., +# `units.pixel`) to achieve the effect of upscaling. Therefore, it does not +# resize the input image itself but affects only features that rely on +# physical units. + +# Parameters +# ---------- +# feature: Feature +# The pipeline or feature to resolve at a higher resolution. +# factor: int or tuple[int, int, int], optional +# The factor by which to upscale the simulation. If a single integer is +# provided, it is applied uniformly across all axes. If a tuple of three +# integers is provided, each axis is scaled individually. Defaults to 1. +# **kwargs: Any +# Additional keyword arguments passed to the parent `Feature` class. + +# Attributes +# ---------- +# __distributed__: bool +# Always `False` for `Upscale`, indicating that this feature’s `.get()` +# method processes the entire input at once even if it is a list, rather +# than distributing calls for each item of the list. + +# Methods +# ------- +# `get(image, factor, **kwargs) -> np.ndarray | torch.tensor` +# Simulates the pipeline at a higher resolution and returns the result at +# the original resolution. + +# Notes +# ----- +# - This feature does not directly resize the image. 
Instead, it modifies the +# unit conversions within the pipeline, making physical units smaller, +# which results in more detail being simulated. +# - The final output is downscaled back to the original resolution using +# `block_reduce` from `skimage.measure`. +# - The effect is only noticeable if features use physical units (e.g., +# `units.pixel`, `units.meter`). Otherwise, the result will be identical. + +# Examples +# -------- +# >>> import deeptrack as dt + +# Define an optical pipeline and a spherical particle: + +# >>> optics = dt.Fluorescence() +# >>> particle = dt.Sphere() +# >>> simple_pipeline = optics(particle) + +# Create an upscaled pipeline with a factor of 4: + +# >>> upscaled_pipeline = dt.Upscale(optics(particle), factor=4) - This feature is incompatible with non-volumetric scatterers such as - `MieScatterers`. +# Resolve the pipelines: + +# >>> image = simple_pipeline() +# >>> upscaled_image = upscaled_pipeline() + +# Visualize the images: + +# >>> import matplotlib.pyplot as plt +# >>> +# >>> plt.subplot(1, 2, 1) +# >>> plt.imshow(image, cmap="gray") +# >>> plt.title("Original Image") +# >>> +# >>> plt.subplot(1, 2, 2) +# >>> plt.imshow(upscaled_image, cmap="gray") +# >>> plt.title("Simulated at Higher Resolution") +# >>> +# >>> plt.show() - Parameters - ---------- - feature: Feature - The feature that generates the list of volumes to place - non-overlapping. - min_distance: float, optional - The minimum distance between volumes in pixels. It can be negative to - allow for partial overlap. Defaults to 1. - max_attempts: int, optional - The maximum number of attempts to place volumes without overlap. - Defaults to 5. - max_iters: int, optional - The maximum number of resamplings. If this number is exceeded, a new - list of volumes is generated. Defaults to 100. +# Compare the shapes (both are the same due to downscaling): - Attributes - ---------- - __distributed__: bool - Always `False` for `NonOverlapping`, indicating that this feature’s - `.get()` method processes the entire input at once even if it is a - list, rather than distributing calls for each item of the list.N - - Methods - ------- - `get(*_, min_distance, max_attempts, **kwargs) -> array` - Generate a list of non-overlapping 3D volumes. - `_check_non_overlapping(list_of_volumes) -> bool` - Check if all volumes in the list are non-overlapping. - `_check_bounding_cubes_non_overlapping(...) -> bool` - Check if two bounding cubes are non-overlapping. - `_get_overlapping_cube(...) -> list[int]` - Get the overlapping cube between two bounding cubes. - `_get_overlapping_volume(...) -> array` - Get the overlapping volume between a volume and a bounding cube. - `_check_volumes_non_overlapping(...) -> bool` - Check if two volumes are non-overlapping. - `_resample_volume_position(volume) -> Image` - Resample the position of a volume to avoid overlap. +# >>> print(image.shape) +# (128, 128, 1) +# >>> print(upscaled_image.shape) +# (128, 128, 1) - Notes - ----- - - This feature performs bounding cube checks first to quickly reject - obvious overlaps before voxel-level checks. - - If the bounding cubes overlap, precise voxel-based checks are performed. 
- - Examples - --------- - >>> import deeptrack as dt - - Define an ellipse scatterer with randomly positioned objects: - - >>> import numpy as np - >>> - >>> scatterer = dt.Ellipse( - >>> radius= 13 * dt.units.pixels, - >>> position=lambda: np.random.uniform(5, 115, size=2)* dt.units.pixels, - >>> ) - - Create multiple scatterers: - - >>> scatterers = (scatterer ^ 8) - - Define the optics and create the image with possible overlap: - - >>> optics = dt.Fluorescence() - >>> im_with_overlap = optics(scatterers) - >>> im_with_overlap.store_properties() - >>> im_with_overlap_resolved = image_with_overlap() - - Gather position from image: - - >>> pos_with_overlap = np.array( - >>> im_with_overlap_resolved.get_property( - >>> "position", - >>> get_one=False - >>> ) - >>> ) - - Enforce non-overlapping and create the image without overlap: - - >>> non_overlapping_scatterers = dt.NonOverlapping( - ... scatterers, - ... min_distance=4, - ... ) - >>> im_without_overlap = optics(non_overlapping_scatterers) - >>> im_without_overlap.store_properties() - >>> im_without_overlap_resolved = im_without_overlap() - - Gather position from image: - - >>> pos_without_overlap = np.array( - >>> im_without_overlap_resolved.get_property( - >>> "position", - >>> get_one=False - >>> ) - >>> ) - - Create a figure with two subplots to visualize the difference: - - >>> import matplotlib.pyplot as plt - >>> - >>> fig, axes = plt.subplots(1, 2, figsize=(10, 5)) - >>> - >>> axes[0].imshow(im_with_overlap_resolved, cmap="gray") - >>> axes[0].scatter(pos_with_overlap[:,1],pos_with_overlap[:,0]) - >>> axes[0].set_title("Overlapping Objects") - >>> axes[0].axis("off") - >>> - >>> axes[1].imshow(im_without_overlap_resolved, cmap="gray") - >>> axes[1].scatter(pos_without_overlap[:,1],pos_without_overlap[:,0]) - >>> axes[1].set_title("Non-Overlapping Objects") - >>> axes[1].axis("off") - >>> plt.tight_layout() - >>> - >>> plt.show() - - Define function to calculate minimum distance: - - >>> def calculate_min_distance(positions): - >>> distances = [ - >>> np.linalg.norm(positions[i] - positions[j]) - >>> for i in range(len(positions)) - >>> for j in range(i + 1, len(positions)) - >>> ] - >>> return min(distances) - - Print minimum distances with and without overlap: - - >>> print(calculate_min_distance(pos_with_overlap)) - 10.768742383382174 - - >>> print(calculate_min_distance(pos_without_overlap)) - 30.82531120942446 - - """ - - __distributed__: bool = False - - def __init__( - self: NonOverlapping, - feature: Feature, - min_distance: float = 1, - max_attempts: int = 5, - max_iters: int = 100, - **kwargs: Any, - ): - """Initializes the NonOverlapping feature. - - Ensures that volumes are placed **non-overlapping** by iteratively - resampling their positions. If the maximum number of attempts is - exceeded, the feature regenerates the list of volumes. - - Parameters - ---------- - feature: Feature - The feature that generates the list of volumes. - min_distance: float, optional - The minimum separation distance **between volume edges**, in - pixels. It defaults to `1`. Negative values allow for partial - overlap. - max_attempts: int, optional - The maximum number of attempts to place the volumes without - overlap. It defaults to `5`. - max_iters: int, optional - The maximum number of resampling iterations per attempt. If - exceeded, a new list of volumes is generated. It defaults to `100`. 
- - """ - - super().__init__( - min_distance=min_distance, - max_attempts=max_attempts, - max_iters=max_iters, - **kwargs, - ) - self.feature = self.add_feature(feature, **kwargs) - - def get( - self: NonOverlapping, - *_: Any, - min_distance: float, - max_attempts: int, - max_iters: int, - **kwargs: Any, - ) -> list[np.ndarray]: - """Generates a list of non-overlapping 3D volumes within a defined - field of view (FOV). - - This method **iteratively** attempts to place volumes while ensuring - they maintain at least `min_distance` separation. If non-overlapping - placement is not achieved within `max_attempts`, a warning is issued, - and the best available configuration is returned. - - Parameters - ---------- - _: Any - Placeholder parameter, typically for an input image. - min_distance: float - The minimum required separation distance between volumes, in - pixels. - max_attempts: int - The maximum number of attempts to generate a valid non-overlapping - configuration. - max_iters: int - The maximum number of resampling iterations per attempt. - **kwargs: Any - Additional parameters that may be used by subclasses. - - Returns - ------- - list[np.ndarray] - A list of 3D volumes represented as NumPy arrays. If - non-overlapping placement is unsuccessful, the best available - configuration is returned. - - Warns - ----- - UserWarning - If non-overlapping placement is **not** achieved within - `max_attempts`, suggesting parameter adjustments such as increasing - the FOV or reducing `min_distance`. - - Notes - ----- - - The placement process prioritizes bounding cube checks for - efficiency. - - If bounding cubes overlap, voxel-based overlap checks are performed. +# """ + +# __distributed__: bool = False + +# feature: Feature + +# def __init__( +# self: Feature, +# feature: Feature, +# factor: int | tuple[int, int, int] = 1, +# **kwargs: Any, +# ) -> None: +# """Initialize the Upscale feature. + +# Parameters +# ---------- +# feature: Feature +# The pipeline or feature to resolve at a higher resolution. +# factor: int or tuple[int, int, int], optional +# The factor by which to upscale the simulation. If a single integer +# is provided, it is applied uniformly across all axes. If a tuple of +# three integers is provided, each axis is scaled individually. +# Defaults to 1. +# **kwargs: Any +# Additional keyword arguments passed to the parent `Feature` class. + +# """ + +# super().__init__(factor=factor, **kwargs) +# self.feature = self.add_feature(feature) + +# def get( +# self: Feature, +# image: np.ndarray | torch.Tensor, +# factor: int | tuple[int, int, int], +# **kwargs: Any, +# ) -> np.ndarray | torch.Tensor: +# """Simulate the pipeline at a higher resolution and return result. + +# Parameters +# ---------- +# image: np.ndarray or torch.Tensor +# The input image to process. +# factor: int or tuple[int, int, int] +# The factor by which to upscale the simulation. If a single integer +# is provided, it is applied uniformly across all axes. If a tuple of +# three integers is provided, each axis is scaled individually. +# **kwargs: Any +# Additional keyword arguments passed to the feature. + +# Returns +# ------- +# np.ndarray or torch.Tensor +# The processed image at the original resolution. + +# Raises +# ------ +# ValueError +# If the input `factor` is not a valid integer or tuple of integers. + +# """ + +# # Ensure factor is a tuple of three integers. 
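+#         # Note: unlike the removed implementation, which used
+#         # (factor,) * 3, this commented-out draft pins the z factor to 1.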
+# if np.size(factor) == 1: +# factor = (factor, factor, 1) +# elif len(factor) != 3: +# raise ValueError( +# "Factor must be an integer or a tuple of three integers." +# ) - """ - - for _ in range(max_attempts): - list_of_volumes = self.feature() - - if not isinstance(list_of_volumes, list): - list_of_volumes = [list_of_volumes] - - for _ in range(max_iters): - - list_of_volumes = [ - self._resample_volume_position(volume) - for volume in list_of_volumes - ] +# # Create a context for upscaling and perform computation. +# ctx = create_context(None, None, None, *factor) - if self._check_non_overlapping(list_of_volumes): - return list_of_volumes +# print('before:', image) +# with units.context(ctx): +# image = self.feature(image) - # Generate a new list of volumes if max_attempts is exceeded. - self.feature.update() +# print('after:', image) +# # Downscale the result to the original resolution. +# import skimage.measure - warnings.warn( - "Non-overlapping placement could not be achieved. Consider " - "adjusting parameters: reduce object radius, increase FOV, " - "or decrease min_distance.", - UserWarning, - ) - return list_of_volumes - - def _check_non_overlapping( - self: NonOverlapping, - list_of_volumes: list[np.ndarray], - ) -> bool: - """Determines whether all volumes in the provided list are - non-overlapping. - - This method verifies that the non-zero voxels of each 3D volume in - `list_of_volumes` are at least `min_distance` apart. It first checks - bounding boxes for early rejection and then examines actual voxel - overlap when necessary. Volumes are assumed to have a `position` - attribute indicating their placement in 3D space. - - Parameters - ---------- - list_of_volumes: list[np.ndarray] - A list of 3D arrays representing the volumes to be checked for - overlap. Each volume is expected to have a position attribute. - - Returns - ------- - bool - `True` if all volumes are non-overlapping, otherwise `False`. - - Notes - ----- - - If `min_distance` is negative, volumes are shrunk using isotropic - erosion before checking overlap. - - If `min_distance` is positive, volumes are padded and expanded using - isotropic dilation. - - Overlapping checks are first performed on bounding cubes for - efficiency. - - If bounding cubes overlap, voxel-level checks are performed. - - """ - - from skimage.morphology import isotropic_erosion, isotropic_dilation - - from deeptrack.augmentations import CropTight, Pad - from deeptrack.optics import _get_position - - min_distance = self.min_distance() - crop = CropTight() - - if min_distance < 0: - list_of_volumes = [ - Image( - crop(isotropic_erosion(volume != 0, -min_distance/2)), - copy=False, - ).merge_properties_from(volume) - for volume in list_of_volumes - ] - else: - pad = Pad(px = [int(np.ceil(min_distance/2))]*6, keep_size=True) - list_of_volumes = [ - Image( - crop(isotropic_dilation(pad(volume) != 0, min_distance/2)), - copy=False, - ).merge_properties_from(volume) - for volume in list_of_volumes - ] - min_distance = 1 - - # The position of the top left corner of each volume (index (0, 0, 0)). - volume_positions_1 = [ - _get_position(volume, mode="corner", return_z=True).astype(int) - for volume in list_of_volumes - ] - - # The position of the bottom right corner of each volume - # (index (-1, -1, -1)). - volume_positions_2 = [ - p0 + np.array(v.shape) - for v, p0 in zip(list_of_volumes, volume_positions_1) - ] - - # (x1, y1, z1, x2, y2, z2) for each volume. 
- volume_bounding_cube = [ - [*p0, *p1] - for p0, p1 in zip(volume_positions_1, volume_positions_2) - ] - - for i, j in itertools.combinations(range(len(list_of_volumes)), 2): - - # If the bounding cubes do not overlap, the volumes do not overlap. - if self._check_bounding_cubes_non_overlapping( - volume_bounding_cube[i], volume_bounding_cube[j], min_distance - ): - continue - - # If the bounding cubes overlap, get the overlapping region of each - # volume. - overlapping_cube = self._get_overlapping_cube( - volume_bounding_cube[i], volume_bounding_cube[j] - ) - overlapping_volume_1 = self._get_overlapping_volume( - list_of_volumes[i], volume_bounding_cube[i], overlapping_cube - ) - overlapping_volume_2 = self._get_overlapping_volume( - list_of_volumes[j], volume_bounding_cube[j], overlapping_cube - ) - - # If either the overlapping regions are empty, the volumes do not - # overlap (done for speed). - if (np.all(overlapping_volume_1 == 0) - or np.all(overlapping_volume_2 == 0)): - continue - - # If products of overlapping regions are non-zero, return False. - # if np.any(overlapping_volume_1 * overlapping_volume_2): - # return False - - # Finally, check that the non-zero voxels of the volumes are at - # least min_distance apart. - if not self._check_volumes_non_overlapping( - overlapping_volume_1, overlapping_volume_2, min_distance - ): - return False - - return True - - def _check_bounding_cubes_non_overlapping( - self: NonOverlapping, - bounding_cube_1: list[int], - bounding_cube_2: list[int], - min_distance: float, - ) -> bool: - """Determines whether two 3D bounding cubes are non-overlapping. - - This method checks whether the bounding cubes of two volumes are - **separated by at least** `min_distance` along **any** spatial axis. - - Parameters - ---------- - bounding_cube_1: list[int] - A list of six integers `[x1, y1, z1, x2, y2, z2]` representing - the first bounding cube. - bounding_cube_2: list[int] - A list of six integers `[x1, y1, z1, x2, y2, z2]` representing - the second bounding cube. - min_distance: float - The required **minimum separation distance** between the two - bounding cubes. - - Returns - ------- - bool - `True` if the bounding cubes are non-overlapping (separated by at - least `min_distance` along **at least one axis**), otherwise - `False`. - - Notes - ----- - - This function **only checks bounding cubes**, **not actual voxel - data**. - - If the bounding cubes are non-overlapping, the corresponding - **volumes are also non-overlapping**. - - This check is much **faster** than full voxel-based comparisons. - - """ - - # bounding_cube_1 and bounding_cube_2 are (x1, y1, z1, x2, y2, z2). - # Check that the bounding cubes are non-overlapping. - return ( - (bounding_cube_1[0] >= bounding_cube_2[3] + min_distance) or - (bounding_cube_2[0] >= bounding_cube_1[3] + min_distance) or - (bounding_cube_1[1] >= bounding_cube_2[4] + min_distance) or - (bounding_cube_2[1] >= bounding_cube_1[4] + min_distance) or - (bounding_cube_1[2] >= bounding_cube_2[5] + min_distance) or - (bounding_cube_2[2] >= bounding_cube_1[5] + min_distance) - ) - - def _get_overlapping_cube( - self: NonOverlapping, - bounding_cube_1: list[int], - bounding_cube_2: list[int], - ) -> list[int]: - """Computes the overlapping region between two 3D bounding cubes. - - This method calculates the coordinates of the intersection of two - axis-aligned bounding cubes, each represented as a list of six - integers: - - - `[x1, y1, z1]`: Coordinates of the **top-left-front** corner. 
- - `[x2, y2, z2]`: Coordinates of the **bottom-right-back** corner. - - The resulting overlapping region is determined by: - - Taking the **maximum** of the starting coordinates (`x1, y1, z1`). - - Taking the **minimum** of the ending coordinates (`x2, y2, z2`). - - If the cubes **do not** overlap, the resulting coordinates will not - form a valid cube (i.e., `x1 > x2`, `y1 > y2`, or `z1 > z2`). - - Parameters - ---------- - bounding_cube_1: list[int] - The first bounding cube, formatted as `[x1, y1, z1, x2, y2, z2]`. - bounding_cube_2: list[int] - The second bounding cube, formatted as `[x1, y1, z1, x2, y2, z2]`. - - Returns - ------- - list[int] - A list of six integers `[x1, y1, z1, x2, y2, z2]` representing the - overlapping bounding cube. If no overlap exists, the coordinates - will **not** define a valid cube. - - Notes - ----- - - This function does **not** check for valid input or ensure the - resulting cube is well-formed. - - If no overlap exists, downstream functions must handle the invalid - result. - - """ - - return [ - max(bounding_cube_1[0], bounding_cube_2[0]), - max(bounding_cube_1[1], bounding_cube_2[1]), - max(bounding_cube_1[2], bounding_cube_2[2]), - min(bounding_cube_1[3], bounding_cube_2[3]), - min(bounding_cube_1[4], bounding_cube_2[4]), - min(bounding_cube_1[5], bounding_cube_2[5]), - ] - - def _get_overlapping_volume( - self: NonOverlapping, - volume: np.ndarray, # 3D array. - bounding_cube: tuple[float, float, float, float, float, float], - overlapping_cube: tuple[float, float, float, float, float, float], - ) -> np.ndarray: - """Extracts the overlapping region of a 3D volume within the specified - overlapping cube. - - This method identifies and returns the subregion of `volume` that - lies within the `overlapping_cube`. The bounding information of the - volume is provided via `bounding_cube`. - - Parameters - ---------- - volume: np.ndarray - A 3D NumPy array representing the volume from which the - overlapping region is extracted. - bounding_cube: tuple[float, float, float, float, float, float] - The bounding cube of the volume, given as a tuple of six floats: - `(x1, y1, z1, x2, y2, z2)`. The first three values define the - **top-left-front** corner, while the last three values define the - **bottom-right-back** corner. - overlapping_cube: tuple[float, float, float, float, float, float] - The overlapping region between the volume and another volume, - represented in the same format as `bounding_cube`. - - Returns - ------- - np.ndarray - A 3D NumPy array representing the portion of `volume` that - lies within `overlapping_cube`. If the overlap does not exist, - an empty array may be returned. - - Notes - ----- - - The method computes the relative indices of `overlapping_cube` - within `volume` by subtracting the bounding cube's starting - position. - - The extracted region is determined by integer indices, meaning - coordinates are implicitly **floored to integers**. - - If `overlapping_cube` extends beyond `volume` boundaries, the - returned subregion is **cropped** to fit within `volume`. 
- - """ - - # The position of the top left corner of the overlapping cube in the volume - overlapping_cube_position = np.array(overlapping_cube[:3]) - np.array( - bounding_cube[:3] - ) - - # The position of the bottom right corner of the overlapping cube in the volume - overlapping_cube_end_position = np.array( - overlapping_cube[3:] - ) - np.array(bounding_cube[:3]) - - # cast to int - overlapping_cube_position = overlapping_cube_position.astype(int) - overlapping_cube_end_position = overlapping_cube_end_position.astype(int) - - return volume[ - overlapping_cube_position[0] : overlapping_cube_end_position[0], - overlapping_cube_position[1] : overlapping_cube_end_position[1], - overlapping_cube_position[2] : overlapping_cube_end_position[2], - ] - - def _check_volumes_non_overlapping( - self: NonOverlapping, - volume_1: np.ndarray, - volume_2: np.ndarray, - min_distance: float, - ) -> bool: - """Determines whether the non-zero voxels in two 3D volumes are at - least `min_distance` apart. - - This method checks whether the active regions (non-zero voxels) in - `volume_1` and `volume_2` maintain a minimum separation of - `min_distance`. If the volumes differ in size, the positions of their - non-zero voxels are adjusted accordingly to ensure a fair comparison. - - Parameters - ---------- - volume_1: np.ndarray - A 3D NumPy array representing the first volume. - volume_2: np.ndarray - A 3D NumPy array representing the second volume. - min_distance: float - The minimum Euclidean distance required between any two non-zero - voxels in the two volumes. - - Returns - ------- - bool - `True` if all non-zero voxels in `volume_1` and `volume_2` are at - least `min_distance` apart, otherwise `False`. - - Notes - ----- - - This function assumes both volumes are correctly aligned within a - shared coordinate space. - - If the volumes are of different sizes, voxel positions are scaled - or adjusted for accurate distance measurement. - - Uses **Euclidean distance** for separation checking. - - If either volume is empty (i.e., no non-zero voxels), they are - considered non-overlapping. - - """ - - # Get the positions of the non-zero voxels of each volume. - positions_1 = np.argwhere(volume_1) - positions_2 = np.argwhere(volume_2) - - # if positions_1.size == 0 or positions_2.size == 0: - # return True # If either volume is empty, they are "non-overlapping" - - # # If the volumes are not the same size, the positions of the non-zero - # # voxels of each volume need to be scaled. - # if positions_1.size == 0 or positions_2.size == 0: - # return True # If either volume is empty, they are "non-overlapping" - - # If the volumes are not the same size, the positions of the non-zero - # voxels of each volume need to be scaled. - if volume_1.shape != volume_2.shape: - positions_1 = ( - positions_1 * np.array(volume_2.shape) - / np.array(volume_1.shape) - ) - positions_1 = positions_1.astype(int) - - # Check that the non-zero voxels of the volumes are at least - # min_distance apart. - return np.all( - cdist(positions_1, positions_2) > min_distance - ) - - def _resample_volume_position( - self: NonOverlapping, - volume: np.ndarray | Image, - ) -> Image: - """Resamples the position of a 3D volume using its internal position - sampler. - - This method updates the `position` property of the given `volume` by - drawing a new position from the `_position_sampler` stored in the - volume's `properties`. If the sampled position is a `Quantity`, it is - converted to pixel units. 
-
-        Parameters
-        ----------
-        volume: np.ndarray or Image
-            The 3D volume whose position is to be resampled. The volume must
-            have a `properties` attribute containing dictionaries with
-            `position` and `_position_sampler` keys.
-
-        Returns
-        -------
-        Image
-            The same input volume with its `position` property updated to the
-            newly sampled value.
-
-        Notes
-        -----
-        - The `_position_sampler` function is expected to return a **tuple of
-          three floats** (e.g., `(x, y, z)`).
-        - If the sampled position is a `Quantity`, it is converted to pixels.
-        - **Only** dictionaries in `volume.properties` that contain both
-          `position` and `_position_sampler` keys are modified.
-
-        """
-
-        for pdict in volume.properties:
-            if "position" in pdict and "_position_sampler" in pdict:
-                new_position = pdict["_position_sampler"]()
-                if isinstance(new_position, Quantity):
-                    new_position = new_position.to("pixel").magnitude
-                pdict["position"] = new_position
-
-        return volume
-
 
 class Store(Feature):
diff --git a/deeptrack/holography.py b/deeptrack/holography.py
index 380969cf..141cc540 100644
--- a/deeptrack/holography.py
+++ b/deeptrack/holography.py
@@ -101,7 +101,7 @@ def get_propagation_matrix(
 def get_propagation_matrix(
     shape: tuple[int, int],
     to_z: float,
-    pixel_size: float,
+    pixel_size: float | tuple[float, float] | None,
     wavelength: float,
     dx: float = 0,
     dy: float = 0
@@ -118,8 +118,9 @@
     The dimensions of the optical field (height, width).
     to_z: float
         Propagation distance along the z-axis.
-    pixel_size: float
-        The physical size of each pixel in the optical field.
+    pixel_size: float | tuple[float, float] | None
+        Physical pixel size. If scalar, isotropic pixels are assumed. If
+        `None`, the active voxel size is used.
     wavelength: float
         The wavelength of the optical field.
     dx: float, optional
@@ -140,14 +140,22 @@
 
     """
 
+    if pixel_size is None:
+        pixel_size = get_active_voxel_size()
+
+    if np.isscalar(pixel_size):
+        pixel_size = (pixel_size, pixel_size)
+
+    px, py = pixel_size
+
     k = 2 * np.pi / wavelength
 
     yr, xr, *_ = shape
 
     x = np.arange(0, xr, 1) - xr / 2 + (xr % 2) / 2
     y = np.arange(0, yr, 1) - yr / 2 + (yr % 2) / 2
 
-    x = 2 * np.pi / pixel_size * x / xr
-    y = 2 * np.pi / pixel_size * y / yr
+    x = 2 * np.pi / px * x / xr
+    y = 2 * np.pi / py * y / yr
 
     KXk, KYk = np.meshgrid(x, y)
     KXk = KXk.astype(complex)
diff --git a/deeptrack/math.py b/deeptrack/math.py
index 05cbf311..40a8adc7 100644
--- a/deeptrack/math.py
+++ b/deeptrack/math.py
@@ -59,6 +59,8 @@
 
 - `MinPooling`: Apply min-pooling to the image.
 
+- `SumPooling`: Apply sum pooling to the image.
+
 - `MedianPooling`: Apply median pooling to the image.
 
 - `Resize`: Resize the image to a specified size.
@@ -93,23 +95,22 @@
 
 from __future__ import annotations
 
-from typing import Any, Callable, TYPE_CHECKING
+from typing import Any, Callable, Dict, Tuple, TYPE_CHECKING
 
 import array_api_compat as apc
 import numpy as np
-from numpy.typing import NDArray
 from scipy import ndimage
 import skimage
 import skimage.measure
 
 from deeptrack import utils, OPENCV_AVAILABLE, TORCH_AVAILABLE
-from deeptrack.features import Feature
-from deeptrack.image import Image, strip
-from deeptrack.types import ArrayLike, PropertyLike
+from deeptrack.features import Feature, BackendDispatched
+from deeptrack.types import PropertyLike
 from deeptrack.backend import xp
 
 if TORCH_AVAILABLE:
     import torch
+    import torch.nn.functional as F
 
 if OPENCV_AVAILABLE:
     import cv2
@@ -128,12 +129,13 @@
     "AveragePooling",
     "MaxPooling",
     "MinPooling",
+    "SumPooling",
     "MedianPooling",
+    "Resize",
     "BlurCV2",
     "BilateralBlur",
 ]
 
-
 if TYPE_CHECKING:
     import torch
@@ -227,10 +229,10 @@
 
     def get(
         self: Average,
-        images: list[NDArray[Any] | torch.Tensor | Image],
+        images: list[np.ndarray | torch.Tensor],
         axis: int | tuple[int],
         **kwargs: Any,
-    ) -> NDArray[Any] | torch.Tensor | Image:
+    ) -> np.ndarray | torch.Tensor:
         """Compute the average of input images along the specified axis(es).
 
         This method computes the average of the input images along the
@@ -297,8 +299,8 @@
 
     def __init__(
        self: Clip,
-        min: PropertyLike[float] = -np.inf,
-        max: PropertyLike[float] = +np.inf,
+        min: PropertyLike[float] = -xp.inf,
+        max: PropertyLike[float] = +xp.inf,
         **kwargs: Any,
     ):
         """Initialize the clipping range.
@@ -306,9 +308,9 @@
         Parameters
         ----------
         min: float, optional
-            Minimum allowed value. It defaults to `-np.inf`.
+            Minimum allowed value. It defaults to `-xp.inf`.
         max: float, optional
-            Maximum allowed value. It defaults to `+np.inf`.
+            Maximum allowed value. It defaults to `+xp.inf`.
         **kwargs: Any
             Additional keyword arguments.
 
@@ -318,11 +320,11 @@
 
     def get(
         self: Clip,
-        image: NDArray[Any] | torch.Tensor | Image,
+        image: np.ndarray | torch.Tensor,
         min: float,
         max: float,
         **kwargs: Any,
-    ) -> NDArray[Any] | torch.Tensor | Image:
+    ) -> np.ndarray | torch.Tensor:
         """Clips the input image within the specified values.
 
         This method clips the input image within the specified minimum and
@@ -363,8 +365,7 @@
     max: float, optional
         Upper bound of the transformation. It defaults to 1.
     featurewise: bool, optional
-        Whether to normalize each feature independently. It default to `True`,
-        which is the only behavior currently implemented.
+        Whether to normalize each feature independently. It defaults to `True`.
 
     Methods
     -------
@@ -390,8 +391,6 @@
 
     """
 
-    #TODO ___??___ Implement the `featurewise=False` option
-
     def __init__(
         self: NormalizeMinMax,
         min: PropertyLike[float] = 0,
@@ -418,42 +417,57 @@
 
     def get(
         self: NormalizeMinMax,
-        image: ArrayLike,
+        image: np.ndarray | torch.Tensor,
         min: float,
         max: float,
+        featurewise: bool = True,
         **kwargs: Any,
-    ) -> ArrayLike:
+    ) -> np.ndarray | torch.Tensor:
         """Normalize the input to fall between `min` and `max`.
 
         Parameters
         ----------
-        image: array
+        image: np.ndarray or torch.Tensor
             Input image to normalize.
         min: float
             Lower bound of the output range.
         max: float
             Upper bound of the output range.
+        featurewise: bool
+            Whether to normalize each feature (channel) independently.
 
         Returns
         -------
-        array
+        np.ndarray or torch.Tensor
             Min-max normalized image.
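+
+        Examples
+        --------
+        Illustrative sketch, assuming a channels-last NumPy array as input:
+
+        >>> import numpy as np
+        >>> import deeptrack as dt
+        >>> arr = np.random.rand(32, 32, 2)
+        >>> out = dt.NormalizeMinMax(min=0, max=1)(arr)
+        >>> print(out.min(), out.max())
+        0.0 1.0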
""" - ptp = xp.max(image) - xp.min(image) - image = image / ptp * (max - min) - image = image - xp.min(image) + min + has_channels = image.ndim >= 3 and image.shape[-1] <= 4 - try: - image[xp.isnan(image)] = 0 - except TypeError: - pass + if featurewise and has_channels: + # reduce over spatial dimensions only + axis = tuple(range(image.ndim - 1)) + img_min = xp.min(image, axis=axis, keepdims=True) + img_max = xp.max(image, axis=axis, keepdims=True) + else: + # global normalization + img_min = xp.min(image) + img_max = xp.max(image) + + ptp = img_max - img_min + eps = xp.asarray(1e-8, dtype=image.dtype) + ptp = xp.maximum(ptp, eps) + image = (image - img_min) / ptp + image = image * (max - min) + min + + image = xp.where(xp.isnan(image), xp.zeros_like(image), image) return image -class NormalizeStandard(Feature): + +class NormalizeStandard(BackendDispatched, Feature): """Image normalization using standardization. Standardizes the input image to have zero mean and unit standard @@ -487,7 +501,8 @@ class NormalizeStandard(Feature): """ - #TODO ___??___ Implement the `featurewise=False` option + _NUMPY_IMPL = "_get_numpy" + _TORCH_IMPL = "_get_torch" def __init__( self: NormalizeStandard, @@ -511,36 +526,81 @@ def __init__( def get( self: NormalizeStandard, - image: NDArray[Any] | torch.Tensor | Image, + image: np.ndarray | torch.Tensor, + featurewise: bool, **kwargs: Any, - ) -> NDArray[Any] | torch.Tensor | Image: + ) -> np.ndarray | torch.Tensor: """Normalizes the input image to have mean 0 and standard deviation 1. - This method normalizes the input image to have mean 0 and standard - deviation 1. - Parameters ---------- - image: array + image: np.ndarray or torch.Tensor The input image to normalize. + featurewise: bool + Whether to normalize each feature (channel) independently. Returns ------- - array - The normalized image. - + np.ndarray or torch.Tensor + The standardized image. """ + return self._dispatch_backend( + image, + featurewise=featurewise, + ) - if apc.is_torch_array(image): - # By default, torch.std() is unbiased, i.e., divides by N-1 - return ( - (image - torch.mean(image)) / torch.std(image, unbiased=False) - ) + # ------ NumPy backend ------ + + def _get_numpy( + self, + image: np.ndarray, + featurewise: bool, + ) -> np.ndarray: + + has_channels = image.ndim >= 3 and image.shape[-1] <= 4 + + if featurewise and has_channels: + axis = tuple(range(image.ndim - 1)) + mean = np.mean(image, axis=axis, keepdims=True) + std = np.std(image, axis=axis, keepdims=True) # population std + else: + mean = np.mean(image) + std = np.std(image) + + std = np.maximum(std, 1e-8) + + out = (image - mean) / std + out = np.where(np.isnan(out), 0.0, out) + + return out + + # ------ Torch backend ------ + + def _get_torch( + self, + image: torch.Tensor, + featurewise: bool, + ) -> torch.Tensor: + + has_channels = image.ndim >= 3 and image.shape[-1] <= 4 + + if featurewise and has_channels: + axis = tuple(range(image.ndim - 1)) + mean = image.mean(dim=axis, keepdim=True) + std = image.std(dim=axis, keepdim=True, unbiased=False) + else: + mean = image.mean() + std = image.std(unbiased=False) - return (image - xp.mean(image)) / xp.std(image) + std = torch.clamp(std, min=1e-8) + out = (image - mean) / std + out = torch.nan_to_num(out, nan=0.0) -class NormalizeQuantile(Feature): + return out + + +class NormalizeQuantile(BackendDispatched, Feature): """Image normalization using quantiles. 
Centers the image at the median and scales it such that the values at the @@ -560,6 +620,12 @@ class NormalizeQuantile(Feature): get(image: array, quantiles: tuple[float, float], **kwargs) -> array Normalizes the input based on the given quantile range. + Notes + ----- + This operation is not differentiable. When used inside a gradient-based + model, it will block gradient flow. Use with care if end-to-end + differentiability is required. + Examples -------- >>> import deeptrack as dt @@ -578,7 +644,8 @@ class NormalizeQuantile(Feature): """ - #TODO ___??___ Implement the `featurewise=False` option + _NUMPY_IMPL = "_get_numpy" + _TORCH_IMPL = "_get_torch" def __init__( self: NormalizeQuantile, @@ -608,159 +675,165 @@ def __init__( ) def get( + self, + image: np.ndarray | torch.Tensor, + quantiles: tuple[float, float], + featurewise: bool, + **kwargs: Any, + ): + return self._dispatch_backend( + image, + quantiles=quantiles, + featurewise=featurewise, + ) + + def _get_numpy( self: NormalizeQuantile, - image: NDArray[Any] | torch.Tensor | Image, - quantiles: tuple[float, float] = None, + image: np.ndarray, + quantiles: tuple[float, float], + featurewise: bool, **kwargs: Any, - ) -> NDArray[Any] | torch.Tensor | Image: + ) -> np.ndarray: """Normalize the input image based on the specified quantiles. - This method normalizes the input image based on the specified - quantiles. - Parameters ---------- - image: array + image: np.ndarray or torch.Tensor The input image to normalize. quantiles: tuple[float, float] Quantile range to calculate scaling factor. + featurewise: bool + Whether to normalize each feature (channel) independently. Returns ------- - array - The normalized image. - + np.ndarray or torch.Tensor + The quantile-normalized image. + """ - if apc.is_torch_array(image): - q_tensor = torch.tensor( - [*quantiles, 0.5], - device=image.device, - dtype=image.dtype, - ) - q_low, q_high, median = torch.quantile( - image, q_tensor, dim=None, keepdim=False, - ) - else: # NumPy - q_low, q_high, median = xp.quantile(image, (*quantiles, 0.5)) - - return (image - median) / (q_high - q_low) * 2.0 + q_low_val, q_high_val = quantiles + has_channels = image.ndim >= 3 and image.shape[-1] <= 4 -#TODO ***JH*** revise Blur - torch, typing, docstring, unit test -class Blur(Feature): - """Apply a blurring filter to an image. + if featurewise and has_channels: + axis = tuple(range(image.ndim - 1)) + q_low, q_high, median = np.quantile( + image, + (q_low_val, q_high_val, 0.5), + axis=axis, + keepdims=True, + ) + else: + q_low, q_high, median = np.quantile( + image, + (q_low_val, q_high_val, 0.5), + ) - This class applies a blurring filter to an image. The filter function - must be a function that takes an input image and returns a blurred - image. + scale = q_high - q_low + eps = np.asarray(1e-8, dtype=image.dtype) + scale = np.maximum(scale, eps) - Parameters - ---------- - filter_function: Callable - The blurring function to apply. This function must accept the input - image as a keyword argument named `input`. If using OpenCV functions - (e.g., `cv2.GaussianBlur`), use `BlurCV2` instead. - mode: str - Border mode for handling boundaries (e.g., 'reflect'). 
+        image = (image - median) / scale
+        image = np.where(np.isnan(image), np.zeros_like(image), image)
+        return image
+
+    def _get_torch(
+        self,
+        image: torch.Tensor,
+        quantiles: tuple[float, float],
+        featurewise: bool,
+    ):
+        q_low_val, q_high_val = quantiles
+
+        if featurewise:
+            if image.ndim < 3:
+                # No channels → global quantile
+                q = torch.tensor(
+                    [q_low_val, q_high_val, 0.5],
+                    device=image.device,
+                    dtype=image.dtype,
+                )
+                q_low, q_high, median = torch.quantile(image, q)
+            else:
+                # channels-last: (..., C)
+                spatial_dims = image.ndim - 1
+                C = image.shape[-1]
+
+                # flatten spatial dims
+                x = image.reshape(-1, C)  # (N, C)
+
+                q = torch.tensor(
+                    [q_low_val, q_high_val, 0.5],
+                    device=image.device,
+                    dtype=image.dtype,
+                )
-
-    Methods
-    -------
-    `get(image: np.ndarray | Image, **kwargs: Any) --> np.ndarray`
-        Applies the blurring filter to the input image.
+                q_vals = torch.quantile(x, q, dim=0)
+                q_low, q_high, median = q_vals
-
-    Examples
-    --------
-    >>> import deeptrack as dt
-    >>> import numpy as np
-    >>> from scipy.ndimage import convolve
+                # reshape for broadcasting
+                shape = [1] * image.ndim
+                shape[-1] = C
+                q_low = q_low.view(shape)
+                q_high = q_high.view(shape)
+                median = median.view(shape)
-
-    Create an input image:
-    >>> input_image = np.random.rand(32, 32)
+        else:
+            q = torch.tensor(
+                [q_low_val, q_high_val, 0.5],
+                device=image.device,
+                dtype=image.dtype,
+            )
+            q_low, q_high, median = torch.quantile(image, q)
-
-    Define a Gaussian kernel for blurring:
-    >>> gaussian_kernel = np.array([
-    ...     [1, 4, 6, 4, 1],
-    ...     [4, 16, 24, 16, 4],
-    ...     [6, 24, 36, 24, 6],
-    ...     [4, 16, 24, 16, 4],
-    ...     [1, 4, 6, 4, 1]
-    ... ], dtype=float)
-    >>> gaussian_kernel /= np.sum(gaussian_kernel)
+        scale = q_high - q_low
+        scale = torch.clamp(scale, min=1e-8)
+        image = (image - median) / scale
+        image = torch.nan_to_num(image)
-
-    Define a blur function using the Gaussian kernel:
-    >>> def gaussian_blur(input, **kwargs):
-    ...     return convolve(input, gaussian_kernel, mode='reflect')
+        return image
-
-    Define a blur feature using the Gaussian blur function:
-    >>> blur = dt.Blur(filter_function=gaussian_blur)
-    >>> output_image = blur(input_image)
-    >>> print(output_image.shape)
-    (32, 32)
-
-    Notes
-    -----
-    Calling this feature returns a `np.ndarray` by default. If
-    `store_properties` is set to `True`, the returned array will be
-    automatically wrapped in an `Image` object. This behavior is handled
-    internally and does not affect the return type of the `get()` method.
-    The filter_function must accept the input image as a keyword argument named
-    input. This is required because it is called via utils.safe_call. If you
-    are using functions that do not support input=... (such as OpenCV filters
-    like cv2.GaussianBlur), consider using BlurCV2 instead.
+
+
+#TODO ***CM*** revise typing, docstring, unit test
+class Blur(BackendDispatched, Feature):
+    """Abstract blur feature with backend-dispatched implementations.
+
+    This class serves as a base for blur features that support multiple
+    backends (e.g., NumPy, Torch). Subclasses should implement backend-specific
+    blurring logic via `_get_numpy` and/or `_get_torch` methods.
+
+    Methods
+    -------
+    get(image: np.ndarray | torch.Tensor, **kwargs) -> np.ndarray | torch.Tensor
+        Applies the appropriate backend-specific blurring method.
+
+    _get_numpy(image: np.ndarray, **kwargs) -> np.ndarray
+    _get_torch(image: torch.Tensor, **kwargs) -> torch.Tensor
+        Backend-specific implementations, to be provided by subclasses.
""" - def __init__( - self: Blur, - filter_function: Callable, - mode: PropertyLike[str] = "reflect", - **kwargs: Any, - ): - """Initialize the parameters for blurring input features. - - This constructor initializes the parameters for blurring input - features. - - Parameters - ---------- - filter_function: Callable - The blurring function to apply. - mode: str - Border mode for handling boundaries (e.g., 'reflect'). - **kwargs: Any - Additional keyword arguments. - - """ - - self.filter = filter_function - super().__init__(borderType=mode, **kwargs) - - def get(self: Blur, image: np.ndarray | Image, **kwargs: Any) -> np.ndarray: - """Applies the blurring filter to the input image. - - This method applies the blurring filter to the input image. + _NUMPY_IMPL = "_get_numpy" + _TORCH_IMPL = "_get_torch" - Parameters - ---------- - image: np.ndarray - The input image to blur. - **kwargs: dict[str, Any] - Additional keyword arguments. + def get( + self, + image: np.ndarray | torch.Tensor, + **kwargs, + ): + return self._dispatch_backend(image, **kwargs) - Returns - ------- - np.ndarray - The blurred image. + def _get_numpy(self, image: np.ndarray, **kwargs): + raise NotImplementedError - """ + def _get_torch(self, image: torch.Tensor, **kwargs): + raise NotImplementedError - kwargs.pop("input", False) - return utils.safe_call(self.filter, input=image, **kwargs) -#TODO ***JH*** revise AverageBlur - torch, typing, docstring, unit test +#TODO ***CM*** revise AverageBlur - torch, typing, docstring, unit test class AverageBlur(Blur): """Blur an image by computing simple means over neighbourhoods. @@ -774,7 +847,7 @@ class AverageBlur(Blur): Methods ------- - `get(image: np.ndarray | Image, ksize: int, **kwargs: Any) --> np.ndarray` + `get(image: np.ndarray | torch.Tensor, ksize: int, **kwargs: Any) --> np.ndarray | torch.Tensor` Applies the average blurring filter to the input image. Examples @@ -791,20 +864,13 @@ class AverageBlur(Blur): >>> print(output_image.shape) (32, 32) - Notes - ----- - Calling this feature returns a `np.ndarray` by default. If - `store_properties` is set to `True`, the returned array will be - automatically wrapped in an `Image` object. This behavior is handled - internally and does not affect the return type of the `get()` method. - """ def __init__( - self: AverageBlur, - ksize: PropertyLike[int] = 3, - **kwargs: Any, - ): + self: AverageBlur, + ksize: int = 3, + **kwargs: Any + ) -> None: """Initialize the parameters for averaging input features. This constructor initializes the parameters for averaging input @@ -819,125 +885,122 @@ def __init__( """ - super().__init__(None, ksize=ksize, **kwargs) + self.ksize = int(ksize) + super().__init__(**kwargs) - def _kernel_shape(self, shape: tuple[int, ...], ksize: int) -> tuple[int, ...]: + @staticmethod + def _kernel_shape(shape: tuple[int, ...], ksize: int) -> tuple[int, ...]: + # If last dim is channel and smaller than kernel, do not blur channels if shape[-1] < ksize: return (ksize,) * (len(shape) - 1) + (1,) return (ksize,) * len(shape) + # ---------- NumPy backend ---------- def _get_numpy( - self, input: np.ndarray, ksize: tuple[int, ...], **kwargs: Any + self: AverageBlur, + image: np.ndarray, + **kwargs: Any ) -> np.ndarray: + """Apply average blurring using SciPy's uniform_filter. + + This method applies average blurring to the input image using + SciPy's `uniform_filter`. + + Parameters + ---------- + image: np.ndarray + The input image to blur. 
+        **kwargs: dict[str, Any]
+            Additional keyword arguments for `uniform_filter`.
+
+        Returns
+        -------
+        np.ndarray
+            The blurred image.
+
+        """
+
+        k = self._kernel_shape(image.shape, self.ksize)
         return ndimage.uniform_filter(
-            input,
-            size=ksize,
+            image,
+            size=k,
             mode=kwargs.get("mode", "reflect"),
             cval=kwargs.get("cval", 0),
             origin=kwargs.get("origin", 0),
-            axes=tuple(range(0, len(ksize))),
+            axes=tuple(range(len(k))),
         )

+    # ---------- Torch backend ----------
     def _get_torch(
-        self, input: torch.Tensor, ksize: tuple[int, ...], **kwargs: Any
-    ) -> np.ndarray:
-        F = xp.nn.functional
+        self: AverageBlur,
+        image: torch.Tensor,
+        **kwargs: Any
+    ) -> torch.Tensor:
+        """Apply average blurring using PyTorch's avg_pool.
+
+        This method applies average blurring to the input image using
+        PyTorch's `avg_pool` functions.
+
+        Parameters
+        ----------
+        image: torch.Tensor
+            The input image to blur.
+        **kwargs: dict[str, Any]
+            Additional keyword arguments for padding.
+
+        Returns
+        -------
+        torch.Tensor
+            The blurred image.
+
+        """
+
+        import torch.nn.functional as F
+
+        k = self._kernel_shape(tuple(image.shape), self.ksize)

-        last_dim_is_channel = len(ksize) < input.ndim
+        last_dim_is_channel = len(k) < image.ndim
         if last_dim_is_channel:
-            # permute to first dim
-            input = input.movedim(-1, 0)
+            image = image.movedim(-1, 0)  # C, ...
         else:
-            input = input.unsqueeze(0)
+            image = image.unsqueeze(0)  # 1, ...

         # add batch dimension
-        input = input.unsqueeze(0)
-
-        # pad input
-        input = F.pad(
-            input,
-            (ksize[0] // 2, ksize[0] // 2, ksize[1] // 2, ksize[1] // 2),
+        image = image.unsqueeze(0)  # 1, C, ...
+
+        # symmetric padding
+        pad = []
+        for kk in reversed(k):
+            p = kk // 2
+            pad.extend([p, p])
+        image = F.pad(
+            image,
+            tuple(pad),
             mode=kwargs.get("mode", "reflect"),
             value=kwargs.get("cval", 0),
         )

-        if input.ndim == 3:
-            x = F.avg_pool1d(
-                input,
-                kernel_size=ksize,
-                stride=1,
-                padding=0,
-                ceil_mode=False,
-                count_include_pad=False,
-            )
-        elif input.ndim == 4:
-            x = F.avg_pool2d(
-                input,
-                kernel_size=ksize,
-                stride=1,
-                padding=0,
-                ceil_mode=False,
-                count_include_pad=False,
-            )
-        elif input.ndim == 5:
-            x = F.avg_pool3d(
-                input,
-                kernel_size=ksize,
-                stride=1,
-                padding=0,
-                ceil_mode=False,
-                count_include_pad=False,
-            )
+
+        # pooling by dimensionality
+        if image.ndim == 3:
+            out = F.avg_pool1d(image, kernel_size=k, stride=1)
+        elif image.ndim == 4:
+            out = F.avg_pool2d(image, kernel_size=k, stride=1)
+        elif image.ndim == 5:
+            out = F.avg_pool3d(image, kernel_size=k, stride=1)
         else:
             raise NotImplementedError(
-                f"Input dimension {input.ndim - 2} not supported for torch backend"
+                f"Input dimensionality {image.ndim - 2} not supported"
             )

         # restore layout
-        x = x.squeeze(0)
+        out = out.squeeze(0)
         if last_dim_is_channel:
-            x = x.movedim(0, -1)
+            out = out.movedim(0, -1)
         else:
-            x = x.squeeze(0)
-
-        return x
-
-    def get(
-        self: AverageBlur,
-        input: ArrayLike,
-        ksize: int,
-        **kwargs: Any,
-    ) -> np.ndarray:
-        """Applies the average blurring filter to the input image.
-
-        This method applies the average blurring filter to the input image.
-
-        Parameters
-        ----------
-        input: np.ndarray
-            The input image to blur.
-        ksize: int
-            Kernel size for the pooling operation.
-        **kwargs: dict[str, Any]
-            Additional keyword arguments.
-
-        Returns
-        -------
-        np.ndarray
-            The blurred image.
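
The pattern AverageBlur follows here generalizes to any feature. A minimal sketch of a hypothetical subclass using the `BackendDispatched` mixin added in this PR (the class `InvertContrast` is illustrative, not part of the diff):

```python
from deeptrack.features import BackendDispatched, Feature

class InvertContrast(BackendDispatched, Feature):
    """Hypothetical example feature, for illustration only."""

    _NUMPY_IMPL = "_get_numpy"
    _TORCH_IMPL = "_get_torch"

    def get(self, image, **kwargs):
        # Route to the implementation matching the active backend.
        return self._dispatch_backend(image, **kwargs)

    def _get_numpy(self, image, **kwargs):
        return image.max() - image

    def _get_torch(self, image, **kwargs):
        return image.max() - image
```
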
- - """ - - k = self._kernel_shape(input.shape, ksize) + out = out.squeeze(0) - if self.backend == "numpy": - return self._get_numpy(input, k, **kwargs) - elif self.backend == "torch": - return self._get_torch(input, k, **kwargs) - else: - raise NotImplementedError(f"Backend {self.backend} not supported") + return out -#TODO ***JH*** revise GaussianBlur - torch, typing, docstring, unit test +#TODO ***CM*** revise typing, docstring, unit test class GaussianBlur(Blur): """Applies a Gaussian blur to images using Gaussian kernels. @@ -973,13 +1036,6 @@ class GaussianBlur(Blur): >>> plt.imshow(output_image, cmap='gray') >>> plt.show() - Notes - ----- - Calling this feature returns a `np.ndarray` by default. If - `store_properties` is set to `True`, the returned array will be - automatically wrapped in an `Image` object. This behavior is handled - internally and does not affect the return type of the `get()` method. - """ def __init__(self: GaussianBlur, sigma: PropertyLike[float] = 2, **kwargs: Any): @@ -996,727 +1052,790 @@ def __init__(self: GaussianBlur, sigma: PropertyLike[float] = 2, **kwargs: Any): """ - super().__init__(ndimage.gaussian_filter, sigma=sigma, **kwargs) + self.sigma = float(sigma) + super().__init__(None, **kwargs) + # ---------- NumPy backend ---------- -#TODO ***JH*** revise MedianBlur - torch, typing, docstring, unit test -class MedianBlur(Blur): - """Applies a median blur. + def _get_numpy( + self, + image: np.ndarray, + **kwargs: Any, + ) -> np.ndarray: + return ndimage.gaussian_filter( + image, + sigma=self.sigma, + mode=kwargs.get("mode", "reflect"), + cval=kwargs.get("cval", 0), + ) - This class replaces each pixel of the input image with the median value of - its neighborhood. The `ksize` parameter determines the size of the - neighborhood used to calculate the median filter. The median filter is - useful for reducing noise while preserving edges. It is particularly - effective for removing salt-and-pepper noise from images. + # ---------- Torch backend ---------- - Parameters - ---------- - ksize: int - Kernel size. - **kwargs: dict - Additional parameters sent to the blurring function. + @staticmethod + def _gaussian_kernel_1d( + sigma: float, + device, + dtype, + ) -> torch.Tensor: + radius = int(np.ceil(3 * sigma)) + x = torch.arange( + -radius, + radius + 1, + device=device, + dtype=dtype, + ) + kernel = torch.exp(-(x ** 2) / (2 * sigma ** 2)) + kernel /= kernel.sum() + return kernel - Examples - -------- - >>> import deeptrack as dt - >>> import numpy as np - >>> import matplotlib.pyplot as plt + def _get_torch( + self, + image: torch.Tensor, + **kwargs: Any, + ) -> torch.Tensor: + import torch.nn.functional as F - Create an input image: - >>> input_image = np.random.rand(32, 32) + kernel_1d = self._gaussian_kernel_1d( + self.sigma, + device=image.device, + dtype=image.dtype, + ) - Define a median blur feature: - >>> median_blur = dt.MedianBlur(ksize=3) - >>> output_image = median_blur(input_image) - >>> print(output_image.shape) - (32, 32) + # channel-last handling + last_dim_is_channel = image.ndim >= 3 + if last_dim_is_channel: + image = image.movedim(-1, 0) # C, ... + else: + image = image.unsqueeze(0) # 1, ... - Visualize the input and output images: - >>> plt.figure(figsize=(8, 4)) - >>> plt.subplot(1, 2, 1) - >>> plt.imshow(input_image, cmap='gray') - >>> plt.subplot(1, 2, 2) - >>> plt.imshow(output_image, cmap='gray') - >>> plt.show() + # add batch dimension + image = image.unsqueeze(0) # 1, C, ... 
- Notes - ----- - Calling this feature returns a `np.ndarray` by default. If - `store_properties` is set to `True`, the returned array will be - automatically wrapped in an `Image` object. This behavior is handled - internally and does not affect the return type of the `get()` method. + spatial_dims = image.ndim - 2 + C = image.shape[1] - """ + for d in range(spatial_dims): + k = kernel_1d + shape = [1] * spatial_dims + shape[d] = -1 + k = k.view(1, 1, *shape) + k = k.repeat(C, 1, *([1] * spatial_dims)) - def __init__( - self: MedianBlur, - ksize: PropertyLike[int] = 3, - **kwargs: Any, - ): - """Initialize the parameters for median blurring. + pad = [0, 0] * spatial_dims + radius = k.shape[2 + d] // 2 + pad[-(2 * d + 2)] = radius + pad[-(2 * d + 1)] = radius + pad = tuple(pad) - This constructor initializes the parameters for median blurring. + image = F.pad( + image, + pad, + mode=kwargs.get("mode", "reflect"), + ) - Parameters - ---------- - ksize: int - Kernel size. - **kwargs: Any - Additional keyword arguments. + if spatial_dims == 1: + image = F.conv1d(image, k, groups=C) + elif spatial_dims == 2: + image = F.conv2d(image, k, groups=C) + elif spatial_dims == 3: + image = F.conv3d(image, k, groups=C) + else: + raise NotImplementedError( + f"{spatial_dims}D Gaussian blur not supported" + ) - """ + # restore layout + image = image.squeeze(0) + if last_dim_is_channel: + image = image.movedim(0, -1) + else: + image = image.squeeze(0) - super().__init__(ndimage.median_filter, size=ksize, **kwargs) + return image -#TODO ***AL*** revise Pool - torch, typing, docstring, unit test -class Pool(Feature): - """Downsamples the image by applying a function to local regions of the - image. +#TODO ***JH*** revise MedianBlur - torch, typing, docstring, unit test +class MedianBlur(Blur): + """Applies a median blur. - This class reduces the resolution of an image by dividing it into - non-overlapping blocks of size `ksize` and applying the specified pooling - function to each block. The result is a downsampled image where each pixel - value represents the result of the pooling function applied to the - corresponding block. + This class replaces each pixel of the input image with the median value of + its neighborhood. The `ksize` parameter determines the size of the + neighborhood used to calculate the median filter. The median filter is + useful for reducing noise while preserving edges. It is particularly + effective for removing salt-and-pepper noise from images. + + - NumPy backend: `scipy.ndimage.median_filter` + - Torch backend: explicit unfolding followed by `torch.median` Parameters ---------- - pooling_function: function - A function that is applied to each local region of the image. - DOES NOT NEED TO BE WRAPPED IN ANOTHER FUNCTION. - The `pooling_function` must accept the input image as a keyword argument - named `input`, as it is called via `utils.safe_call`. - Examples include `np.mean`, `np.max`, `np.min`, etc. ksize: int - Size of the pooling kernel. - **kwargs: Any - Additional parameters sent to the pooling function. + Kernel size. + **kwargs: dict + Additional parameters sent to the blurring function. - Methods - ------- - `get(image: np.ndarray | Image, ksize: int, **kwargs: Any) --> np.ndarray` - Applies the pooling function to the input image. + Notes + ----- + Torch median blurring is significantly more expensive than mean or + Gaussian blurring due to explicit tensor unfolding. + + Median blur is not differentiable. 
This is typically acceptable, as the
+    operation is intended for denoising and preprocessing rather than as a
+    trainable network layer.

     Examples
     --------
     >>> import deeptrack as dt
     >>> import numpy as np
+    >>> import matplotlib.pyplot as plt

     Create an input image:
     >>> input_image = np.random.rand(32, 32)

-    Define a pooling feature:
-    >>> pooling_feature = dt.Pool(pooling_function=np.mean, ksize=4)
-    >>> output_image = pooling_feature.get(input_image, ksize=4)
+    Define a median blur feature:
+    >>> median_blur = dt.MedianBlur(ksize=3)
+    >>> output_image = median_blur(input_image)
     >>> print(output_image.shape)
-    (8, 8)
+    (32, 32)

-    Notes
-    -----
-    Calling this feature returns a `np.ndarray` by default. If
-    `store_properties` is set to `True`, the returned array will be
-    automatically wrapped in an `Image` object. This behavior is handled
-    internally and does not affect the return type of the `get()` method.
-    The filter_function must accept the input image as a keyword argument named
-    input. This is required because it is called via utils.safe_call. If you
-    are using functions that do not support input=... (such as OpenCV filters
-    like cv2.GaussianBlur), consider using BlurCV2 instead.
+    Visualize the input and output images:
+    >>> plt.figure(figsize=(8, 4))
+    >>> plt.subplot(1, 2, 1)
+    >>> plt.imshow(input_image, cmap='gray')
+    >>> plt.subplot(1, 2, 2)
+    >>> plt.imshow(output_image, cmap='gray')
+    >>> plt.show()

     """

     def __init__(
-        self: Pool,
-        pooling_function: Callable,
+        self: MedianBlur,
         ksize: PropertyLike[int] = 3,
         **kwargs: Any,
     ):
-        """Initialize the parameters for pooling input features.
+        self.ksize = int(ksize)
+        super().__init__(**kwargs)

-        This constructor initializes the parameters for pooling input
-        features.
+    # ---------- NumPy backend ----------

-        Parameters
-        ----------
-        pooling_function: Callable
-            The pooling function to apply.
-        ksize: int
-            Size of the pooling kernel.
-        **kwargs: Any
-            Additional keyword arguments.
-
-        """
-
-        self.pooling = pooling_function
-        super().__init__(ksize=ksize, **kwargs)
-
-    def get(
-        self: Pool,
-        image: np.ndarray | Image,
-        ksize: int,
+    def _get_numpy(
+        self,
+        image: np.ndarray,
         **kwargs: Any,
     ) -> np.ndarray:
-        """Applies the pooling function to the input image.
+        return ndimage.median_filter(
+            image,
+            size=self.ksize,
+            mode=kwargs.get("mode", "reflect"),
+            cval=kwargs.get("cval", 0),
+        )

-        This method applies the pooling function to the input image.
+    # ---------- Torch backend ----------

-        Parameters
-        ----------
-        image: np.ndarray
-            The input image to pool.
-        ksize: int
-            Size of the pooling kernel.
-        **kwargs: dict[str, Any]
-            Additional keyword arguments.
+    def _get_torch(
+        self,
+        image: torch.Tensor,
+        **kwargs: Any,
+    ) -> torch.Tensor:
+        import torch.nn.functional as F

-        Returns
-        -------
-        np.ndarray
-            The pooled image.
+        k = self.ksize
+        if k % 2 == 0:
+            raise ValueError("MedianBlur requires an odd kernel size.")

-        """
+        last_dim_is_channel = image.ndim >= 3
+        if last_dim_is_channel:
+            image = image.movedim(-1, 0)  # C, ...
+        else:
+            image = image.unsqueeze(0)  # 1, ...

-        kwargs.pop("func", False)
-        kwargs.pop("image", False)
-        kwargs.pop("block_size", False)
-        return utils.safe_call(
-            skimage.measure.block_reduce,
-            image=image,
-            func=self.pooling,
-            block_size=ksize,
-            **kwargs,
-        )
+        # add batch dimension
+        image = image.unsqueeze(0)  # 1, C, ...
+ spatial_dims = image.ndim - 2 + pad = k // 2 -#TODO ***AL*** revise AveragePooling - torch, typing, docstring, unit test -class AveragePooling(Pool): - """Apply average pooling to an image. + pad_tuple = [] + for _ in range(spatial_dims): + pad_tuple.extend([pad, pad]) + pad_tuple = tuple(reversed(pad_tuple)) - This class reduces the resolution of an image by dividing it into - non-overlapping blocks of size `ksize` and applying the average function to - each block. The result is a downsampled image where each pixel value - represents the average value within the corresponding block of the - original image. + image = F.pad( + image, + pad_tuple, + mode=kwargs.get("mode", "reflect"), + ) - Parameters - ---------- - ksize: int - Size of the pooling kernel. - **kwargs: dict - Additional parameters sent to the pooling function. + if spatial_dims == 1: + x = image.unfold(2, k, 1) + elif spatial_dims == 2: + x = image.unfold(2, k, 1).unfold(3, k, 1) + elif spatial_dims == 3: + x = ( + image + .unfold(2, k, 1) + .unfold(3, k, 1) + .unfold(4, k, 1) + ) + else: + raise NotImplementedError( + f"{spatial_dims}D median blur not supported" + ) - Examples - -------- - >>> import deeptrack as dt - >>> import numpy as np + x = x.contiguous().view(*x.shape[:-spatial_dims], -1) + x = x.median(dim=-1).values - Create an input image: - >>> input_image = np.random.rand(32, 32) + x = x.squeeze(0) + if last_dim_is_channel: + x = x.movedim(0, -1) + else: + x = x.squeeze(0) - Define an average pooling feature: - >>> average_pooling = dt.AveragePooling(ksize=4) - >>> output_image = average_pooling(input_image) - >>> print(output_image.shape) - (8, 8) + return x - Notes - ----- - Calling this feature returns a `np.ndarray` by default. If - `store_properties` is set to `True`, the returned array will be - automatically wrapped in an `Image` object. This behavior is handled - internally and does not affect the return type of the `get()` method. +#TODO ***CM*** revise typing, docstring, unit test +class Pool(BackendDispatched, Feature): + """Abstract base class for pooling features.""" - """ + _NUMPY_IMPL = "_get_numpy" + _TORCH_IMPL = "_get_torch" def __init__( - self: Pool, - ksize: PropertyLike[int] = 3, + self, + ksize: PropertyLike[int] = 2, **kwargs: Any, ): - """Initialize the parameters for average pooling. + self.ksize = int(ksize) + super().__init__(**kwargs) - This constructor initializes the parameters for average pooling. + def get( + self, + image: np.ndarray | torch.Tensor, + **kwargs: Any, + ) -> np.ndarray | torch.Tensor: + return self._dispatch_backend(image, **kwargs) - Parameters - ---------- - ksize: int - Size of the pooling kernel. - **kwargs: Any - Additional keyword arguments. + # ---------- shared helpers ---------- - """ + def _get_pool_size(self, array) -> tuple[int, int, int]: + k = self.ksize - super().__init__(np.mean, ksize=ksize, **kwargs) + if array.ndim == 2: + return k, k, 1 + if array.ndim == 3: + if array.shape[-1] <= 4: # channel heuristic + return k, k, 1 + return k, k, k -class MaxPooling(Pool): - """Apply max-pooling to images. + if array.ndim == 4: + return k, k, k - `MaxPooling` reduces the resolution of an image by dividing it into - non-overlapping blocks of size `ksize` and applying the `max` function - to each block. The result is a downsampled image where each pixel value - represents the maximum value within the corresponding block of the - original image. This is useful for reducing the size of an image while - retaining the most significant features. 
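
A standalone illustration of the unfold-then-median pattern the Torch MedianBlur above relies on (2D, single channel, kernel size 3; self-contained check, not library code):

```python
import torch
import torch.nn.functional as F

x = torch.rand(1, 1, 8, 8)
k = 3

# Reflect-pad so the output keeps the input resolution.
padded = F.pad(x, (1, 1, 1, 1), mode="reflect")

# Extract every k*k neighborhood, then take its median.
patches = padded.unfold(2, k, 1).unfold(3, k, 1)      # (1, 1, 8, 8, 3, 3)
patches = patches.contiguous().view(1, 1, 8, 8, -1)   # flatten each window
out = patches.median(dim=-1).values

print(out.shape)  # torch.Size([1, 1, 8, 8])
```
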
+        raise ValueError(f"Unsupported array shape {array.shape}")

-    If the backend is NumPy, the downsampling is performed using
-    `skimage.measure.block_reduce`.
+    def _crop_center(self, array):
+        px, py, pz = self._get_pool_size(array)

-    If the backend is PyTorch, the downsampling is performed using
-    `torch.nn.functional.max_pool2d`.
+        # 2D or effectively 2D (channels-last): crop symmetrically so the
+        # retained region is centered and divisible by the kernel size
+        if array.ndim < 3 or pz == 1:
+            H, W = array.shape[:2]
+            crop_h = (H // px) * px
+            crop_w = (W // py) * py
+            off_h = (H - crop_h) // 2
+            off_w = (W - crop_w) // 2
+            return array[off_h:off_h + crop_h, off_w:off_w + crop_w, ...]

-    Parameters
-    ----------
-    ksize: int
-        Size of the pooling kernel.
-    **kwargs: Any
-        Additional parameters sent to the pooling function.
+        # 3D volume
+        Z, H, W = array.shape[:3]
+        crop_z = (Z // pz) * pz
+        crop_h = (H // px) * px
+        crop_w = (W // py) * py
+        off_z = (Z - crop_z) // 2
+        off_h = (H - crop_h) // 2
+        off_w = (W - crop_w) // 2
+        return array[
+            off_z:off_z + crop_z,
+            off_h:off_h + crop_h,
+            off_w:off_w + crop_w,
+            ...,
+        ]

-    Examples
-    --------
-    >>> import deeptrack as dt
+    # ---------- abstract backends ----------

-    Create an input image:
-    >>> import numpy as np
-    >>>
-    >>> input_image = np.random.rand(32, 32)
+    def _get_numpy(self, image: np.ndarray, **kwargs):
+        raise NotImplementedError

-    Define and use a max-pooling feature:
+    def _get_torch(self, image: torch.Tensor, **kwargs):
+        raise NotImplementedError

-    >>> max_pooling = dt.MaxPooling(ksize=8)
-    >>> output_image = max_pooling(input_image)
-    >>> output_image.shape
-    (4, 4)

+class AveragePooling(Pool):
+    """Average pooling feature.
+
+    Downsamples the input by applying mean pooling over non-overlapping
+    blocks of size `ksize`, preserving the center of the image and never
+    pooling over channel dimensions.
+
+    Works with NumPy and PyTorch backends.
     """

-    def __init__(
-        self: MaxPooling,
-        ksize: PropertyLike[int] = 3,
-        **kwargs: Any,
-    ):
-        """Initialize the parameters for max-pooling.
+    # ---------- NumPy backend ----------

-        This constructor initializes the parameters for max-pooling.
+    def _get_numpy(
+        self,
+        image: np.ndarray,
+        **kwargs: Any,
+    ) -> np.ndarray:
+        image = self._crop_center(image)
+        px, py, pz = self._get_pool_size(image)

-        Parameters
-        ----------
-        ksize: int
-            Size of the pooling kernel.
-        **kwargs: Any
-            Additional keyword arguments.
+        # 2D or effectively 2D (channels-last)
+        if image.ndim < 3 or pz == 1:
+            block_size = (px, py) + (1,) * (image.ndim - 2)
+        else:
+            # 3D volume (optionally with channels)
+            block_size = (pz, px, py) + (1,) * (image.ndim - 3)

-        """
+        return skimage.measure.block_reduce(
+            image,
+            block_size=block_size,
+            func=np.mean,
+        )

-        super().__init__(np.max, ksize=ksize, **kwargs)
+    # ---------- Torch backend ----------

-    def get(
-        self: MaxPooling,
-        image: NDArray[Any] | torch.Tensor,
-        ksize: int=3,
+    def _get_torch(
+        self,
+        image: torch.Tensor,
         **kwargs: Any,
-    ) -> NDArray[Any] | torch.Tensor:
-        """Max-pooling of input.
+    ) -> torch.Tensor:
+        import torch.nn.functional as F

-        Checks the current backend and chooses the appropriate function to pool
-        the input image, either `._get_torch()` or `._get_numpy()`.
+        image = self._crop_center(image)
+        px, py, pz = self._get_pool_size(image)

-        Parameters
-        ----------
-        image: array or tensor
-            Input array or tensor be pooled.
-        ksize: int
-            Kernel size of the pooling operation.
+        is_3d = image.ndim >= 3 and pz > 1

-        Returns
-        -------
-        array or tensor
-            The pooled input as `NDArray` or `torch.Tensor` depending on
-            the backend.
+        # Flatten extra (channel / feature) dimensions into C
+        if not is_3d:
+            extra = image.shape[2:]
+            C = int(np.prod(extra)) if extra else 1
+            # channels-last -> (1, C, H, W); permute keeps values aligned
+            x = image.reshape(*image.shape[:2], C)
+            x = x.permute(2, 0, 1).unsqueeze(0)
+            kernel = (px, py)
+            stride = (px, py)
+            pooled = F.avg_pool2d(x, kernel, stride)
+            out = pooled.squeeze(0).permute(1, 2, 0)
+        else:
+            extra = image.shape[3:]
+            C = int(np.prod(extra)) if extra else 1
+            x = image.reshape(*image.shape[:3], C)
+            x = x.permute(3, 0, 1, 2).unsqueeze(0)
+            kernel = (pz, px, py)
+            stride = (pz, px, py)
+            pooled = F.avg_pool3d(x, kernel, stride)
+            out = pooled.squeeze(0).permute(1, 2, 3, 0)
-        """
+        # Restore original (channels-last) layout
+        if extra:
+            return out.reshape(out.shape[:-1] + extra)
+        return out.squeeze(-1)

-        if self.get_backend() == "numpy":
-            return self._get_numpy(image, ksize, **kwargs)

-        if self.get_backend() == "torch":
-            return self._get_torch(image, ksize, **kwargs)

-        raise NotImplementedError(f"Backend {self.backend} not supported")

 class MaxPooling(Pool):
-    """Apply max-pooling to images.
+    """Max pooling feature.

-    `MaxPooling` reduces the resolution of an image by dividing it into
-    non-overlapping blocks of size `ksize` and applying the `max` function
-    to each block. The result is a downsampled image where each pixel value
-    represents the maximum value within the corresponding block of the
-    original image. This is useful for reducing the size of an image while
-    retaining the most significant features.
+    Downsamples the input by applying max pooling over non-overlapping
+    blocks of size `ksize`, preserving the center of the image and never
+    pooling over channel dimensions.

-    If the backend is NumPy, the downsampling is performed using
-    `skimage.measure.block_reduce`.
+    Works with NumPy and PyTorch backends.

-    If the backend is PyTorch, the downsampling is performed using
-    `torch.nn.functional.max_pool2d`.
     """

-    Parameters
-    ----------
-    ksize: int
-        Size of the pooling kernel.
-    **kwargs: Any
-        Additional parameters sent to the pooling function.
+    # ---------- NumPy backend ----------

-    Examples
-    --------
-    >>> import deeptrack as dt
+    def _get_numpy(
+        self,
+        image: np.ndarray,
+        **kwargs: Any,
+    ) -> np.ndarray:
+        image = self._crop_center(image)
+        px, py, pz = self._get_pool_size(image)

-    Create an input image:
-    >>> import numpy as np
-    >>>
-    >>> input_image = np.random.rand(32, 32)
+        # 2D or effectively 2D (channels-last)
+        if image.ndim < 3 or pz == 1:
+            block_size = (px, py) + (1,) * (image.ndim - 2)
+        else:
+            # 3D volume (optionally with channels)
+            block_size = (pz, px, py) + (1,) * (image.ndim - 3)

-    Define and use a max-pooling feature:
-    >>> max_pooling = dt.MaxPooling(ksize=8)
-    >>> output_image = max_pooling(input_image)
-    >>> output_image.shape
-    (4, 4)
+        return skimage.measure.block_reduce(
+            image,
+            block_size=block_size,
             func=np.max,
-            block_size=ksize,
-            **kwargs,
         )

+    # ---------- Torch backend ----------

     def _get_torch(
-        self: MaxPooling,
+        self,
         image: torch.Tensor,
-        ksize: int=3,
         **kwargs: Any,
     ) -> torch.Tensor:
-        """Max-pooling with the PyTorch backend enabled.
-
-
-        Returns the result of the tensor passed to a PyTorch max
-        pooling layer.
+        import torch.nn.functional as F

-        Parameters
-        ----------
-        image: torch.Tensor
-            Input tensor to be pooled.
-        ksize: int
-            Kernel size of the pooling operation.

-        Returns
-        -------
-        torch.Tensor
-            The pooled image as a `torch.Tensor`.
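
A brief usage sketch of the rewritten pooling features (hedged: it assumes the default channels-last call interface shown in the docstrings of this diff):

```python
import numpy as np
import deeptrack as dt

# The pooling features assume channels-last inputs and never pool
# across the channel axis.
image = np.random.rand(32, 32, 3)

pooled = dt.AveragePooling(ksize=4)(image)
print(pooled.shape)  # (8, 8, 3)
```
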
+        image = self._crop_center(image)
+        px, py, pz = self._get_pool_size(image)
-        """
-
-        # If input tensor is 2D
-        if len(image.shape) == 2:
-            # Add batch dimension for max-pooling
-            expanded_image = image.unsqueeze(0)
+        is_3d = image.ndim >= 3 and pz > 1

-            pooled_image = torch.nn.functional.max_pool2d(
-                expanded_image, kernel_size=ksize,
+        # Flatten extra (channel / feature) dimensions into C
+        if not is_3d:
+            extra = image.shape[2:]
+            C = int(np.prod(extra)) if extra else 1
+            # channels-last -> (1, C, H, W); permute keeps values aligned
+            x = image.reshape(*image.shape[:2], C)
+            x = x.permute(2, 0, 1).unsqueeze(0)
+            kernel = (px, py)
+            stride = (px, py)
+            pooled = F.max_pool2d(x, kernel, stride)
+            out = pooled.squeeze(0).permute(1, 2, 0)
+        else:
+            extra = image.shape[3:]
+            C = int(np.prod(extra)) if extra else 1
+            x = image.reshape(*image.shape[:3], C)
+            x = x.permute(3, 0, 1, 2).unsqueeze(0)
             )
-            # Remove the expanded dim
-            return pooled_image.squeeze(0)
+            kernel = (pz, px, py)
+            stride = (pz, px, py)
+            pooled = F.max_pool3d(x, kernel, stride)
+            out = pooled.squeeze(0).permute(1, 2, 3, 0)

-        return torch.nn.functional.max_pool2d(
-            image,
-            kernel_size=ksize,
-        )
+        # Restore original (channels-last) layout
+        if extra:
+            return out.reshape(out.shape[:-1] + extra)
+        return out.squeeze(-1)


 class MinPooling(Pool):
-    """Apply min-pooling to images.
+    """Min pooling feature.

-    `MinPooling` reduces the resolution of an image by dividing it into
-    non-overlapping blocks of size `ksize` and applying the `min` function to
-    each block. The result is a downsampled image where each pixel value
-    represents the minimum value within the corresponding block of the original
-    image.
+    Downsamples the input by applying min pooling over non-overlapping
+    blocks of size `ksize`, preserving the center of the image and never
+    pooling over channel dimensions.

-    If the backend is NumPy, the downsampling is performed using
-    `skimage.measure.block_reduce`.
+    Works with NumPy and PyTorch backends.

-    If the backend is PyTorch, the downsampling is performed using the inverse
-    of `torch.nn.functional.max_pool2d` by changing the sign of the input.
+    """

-    Parameters
-    ----------
-    ksize: int
-        Size of the pooling kernel.
-    **kwargs: Any
-        Additional parameters sent to the pooling function.
+    # ---------- NumPy backend ----------

-    Examples
-    --------
-    >>> import deeptrack as dt
+    def _get_numpy(
+        self,
+        image: np.ndarray,
+        **kwargs: Any,
+    ) -> np.ndarray:
+        image = self._crop_center(image)
+        px, py, pz = self._get_pool_size(image)

-    Create an input image:
-    >>> import numpy as np
-    >>>
-    >>> input_image = np.random.rand(32, 32)
+        # 2D or effectively 2D (channels-last)
+        if image.ndim < 3 or pz == 1:
+            block_size = (px, py) + (1,) * (image.ndim - 2)
+        else:
+            # 3D volume (optionally with channels)
+            block_size = (pz, px, py) + (1,) * (image.ndim - 3)

-    Define and use a min-pooling feature:
-    >>> min_pooling = dt.MinPooling(ksize=4)
-    >>> output_image = min_pooling(input_image)
-    >>> output_image.shape
-    (8, 8)
+        return skimage.measure.block_reduce(
+            image,
+            block_size=block_size,
+            func=np.min,
+        )

-    """
+    # ---------- Torch backend ----------

-    def __init__(
-        self: MinPooling,
-        ksize: PropertyLike[int] = 3,
+    def _get_torch(
+        self,
+        image: torch.Tensor,
         **kwargs: Any,
-    ):
-        """Initialize the parameters for min-pooling.
+    ) -> torch.Tensor:
+        import torch.nn.functional as F

-        This constructor initializes the parameters for min-pooling and checks
-        whether to use the NumPy or PyTorch implementation, defaults to NumPy.
+        image = self._crop_center(image)
+        px, py, pz = self._get_pool_size(image)

-        Parameters
-        ----------
-        ksize: int
-            Size of the pooling kernel.
-        **kwargs: Any
-            Additional keyword arguments.
+        is_3d = image.ndim >= 3 and pz > 1

-        """
-
-        super().__init__(np.min, ksize=ksize, **kwargs)
+        # Flatten extra (channel / feature) dimensions into C
+        if not is_3d:
+            extra = image.shape[2:]
+            C = int(np.prod(extra)) if extra else 1
+            # channels-last -> (1, C, H, W); permute keeps values aligned
+            x = image.reshape(*image.shape[:2], C)
+            x = x.permute(2, 0, 1).unsqueeze(0)
+            kernel = (px, py)
+            stride = (px, py)

-    def get(
-        self: MinPooling,
-        image: NDArray[Any] | torch.Tensor,
-        ksize: int=3,
-        **kwargs: Any,
-    ) -> NDArray[Any] | torch.Tensor:
-        """Min pooling of input.
+            # min(x) = -max(-x)
+            pooled = -F.max_pool2d(-x, kernel, stride)
+            out = pooled.squeeze(0).permute(1, 2, 0)
+        else:
+            extra = image.shape[3:]
+            C = int(np.prod(extra)) if extra else 1
+            x = image.reshape(*image.shape[:3], C)
+            x = x.permute(3, 0, 1, 2).unsqueeze(0)
+            kernel = (pz, px, py)
+            stride = (pz, px, py)

-        Checks the current backend and chooses the appropriate function to pool
-        the input image, either `._get_torch()` or `._get_numpy()`.
+            pooled = -F.max_pool3d(-x, kernel, stride)
+            out = pooled.squeeze(0).permute(1, 2, 3, 0)

-        Parameters
-        ----------
-        image: array or tensor
-            Input array or tensor to be pooled.
-        ksize: int
-            Kernel size of the pooling operation.
+        # Restore original (channels-last) layout
+        if extra:
+            return out.reshape(out.shape[:-1] + extra)
+        return out.squeeze(-1)

-        Returns
-        -------
-        array or tensor
-            The pooled image as `NDArray` or `torch.Tensor` depending on the
-            backend.

-        """
+class SumPooling(Pool):
+    """Sum pooling feature.

-        if self.get_backend() == "numpy":
-            return self._get_numpy(image, ksize, **kwargs)
+    Downsamples the input by applying sum pooling over non-overlapping
+    blocks of size `ksize`, preserving the center of the image and never
+    pooling over channel dimensions.

-        if self.get_backend() == "torch":
-            return self._get_torch(image, ksize, **kwargs)
+    Works with NumPy and PyTorch backends.
+    """

-        raise NotImplementedError(f"Backend {self.backend} not supported")
+    # ---------- NumPy backend ----------

     def _get_numpy(
-        self: MinPooling,
-        image: NDArray[Any],
-        ksize: int=3,
+        self,
+        image: np.ndarray,
         **kwargs: Any,
-    ) -> NDArray[Any]:
-        """Min-pooling with the NumPy backend.
-
-        Returns the result of the input array passed to the scikit
-        `image block_reduce()` function with `np.min()` as the pooling
-        function.
-
-        Parameters
-        ----------
-        image: NDArray
-            Input image to be pooled.
-        ksize: int
-            Kernel size of the pooling operation.
-
-        Returns
-        -------
-        NDArray
-            The pooled image as a `NDArray`.
+    ) -> np.ndarray:
+        image = self._crop_center(image)
+        px, py, pz = self._get_pool_size(image)

-        """
+        # 2D or effectively 2D (channels-last)
+        if image.ndim < 3 or pz == 1:
+            block_size = (px, py) + (1,) * (image.ndim - 2)
+        else:
+            # 3D volume (optionally with channels)
+            block_size = (pz, px, py) + (1,) * (image.ndim - 3)

-        return utils.safe_call(
-            skimage.measure.block_reduce,
-            image=image,
-            func=np.min,
-            block_size=ksize,
-            **kwargs,
+        return skimage.measure.block_reduce(
+            image,
+            block_size=block_size,
+            func=np.sum,
         )

+    # ---------- Torch backend ----------

     def _get_torch(
-        self: MinPooling,
+        self,
         image: torch.Tensor,
-        ksize: int=3,
         **kwargs: Any,
     ) -> torch.Tensor:
-        """Min-pooling with the PyTorch backend.
+        import torch.nn.functional as F

-        As PyTorch does not have a min-pooling layer, the equivalent operation
-        is to first multiply the input tensor with `-1`, then perform
-        max-pooling, and finally multiply the max pooled tensor with `-1`.
+        image = self._crop_center(image)
+        px, py, pz = self._get_pool_size(image)

-        Parameters
-        ----------
-        image: torch.Tensor
-            Input tensor to be pooled.
-        ksize: int
-            Kernel size of the pooling operation.
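
A standalone check of the min-pool identity used by the Torch backend above, min(x) = -max(-x) over each window (illustrative, not library code):

```python
import torch
import torch.nn.functional as F

x = torch.rand(1, 1, 8, 8)

# Min pooling via negated max pooling, as in the Torch MinPooling backend.
min_pool = -F.max_pool2d(-x, kernel_size=2, stride=2)

# Reference: reshape into 2x2 blocks and take the minimum of each block.
reference = x.reshape(1, 1, 4, 2, 4, 2).amin(dim=(3, 5))

print(torch.allclose(min_pool, reference))  # True
```
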
+        is_3d = image.ndim >= 3 and pz > 1

-        Returns
-        -------
-        torch.Tensor
-            The pooled image.
+        # Flatten extra (channel / feature) dimensions into C
+        if not is_3d:
+            extra = image.shape[2:]
+            C = int(np.prod(extra)) if extra else 1
+            # channels-last -> (1, C, H, W); permute keeps values aligned
+            x = image.reshape(*image.shape[:2], C)
+            x = x.permute(2, 0, 1).unsqueeze(0)
+            kernel = (px, py)
+            stride = (px, py)
+            pooled = F.avg_pool2d(x, kernel, stride) * (px * py)
+            out = pooled.squeeze(0).permute(1, 2, 0)
+        else:
+            extra = image.shape[3:]
+            C = int(np.prod(extra)) if extra else 1
+            x = image.reshape(*image.shape[:3], C)
+            x = x.permute(3, 0, 1, 2).unsqueeze(0)
+            kernel = (pz, px, py)
+            stride = (pz, px, py)
+            pooled = F.avg_pool3d(x, kernel, stride) * (pz * px * py)
+            out = pooled.squeeze(0).permute(1, 2, 3, 0)
-        """
-        # If input tensor is 2D
-        if len(image.shape) == 2:
-            # Add batch dimension for min-pooling
-            expanded_image = image.unsqueeze(0)

-            pooled_image = - torch.nn.functional.max_pool2d(
-                expanded_image * (-1),
-                kernel_size=ksize,
-            )
+        # Restore original (channels-last) layout
+        if extra:
+            return out.reshape(out.shape[:-1] + extra)
+        return out.squeeze(-1)

-            # Remove the expanded dim
-            return pooled_image.squeeze(0)

-        return -torch.nn.functional.max_pool2d(
-            image * (-1),
-            kernel_size=ksize,
-        )

+class MedianPooling(Pool):
+    """Median pooling feature.
+
+    Downsamples the input by applying median pooling over non-overlapping
+    blocks of size `ksize`, preserving the center of the image and never
+    pooling over channel dimensions.
+
+    Notes
+    -----
+    - NumPy backend uses `skimage.measure.block_reduce`
+    - Torch backend performs explicit unfolding followed by `median`
+    - Torch median pooling is significantly more expensive than mean/max
+
+    Median pooling is not differentiable and should not be used inside
+    trainable neural networks requiring gradient-based optimization.
+
+    """

-#TODO ***AL*** revise MedianPooling - torch, typing, docstring, unit test
-class MedianPooling(Pool):
-    """Apply median pooling to images.
+    # ---------- NumPy backend ----------

-    This class reduces the resolution of an image by dividing it into
-    non-overlapping blocks of size `ksize` and applying the median function to
-    each block. The result is a downsampled image where each pixel value
-    represents the median value within the corresponding block of the
-    original image. This is useful for reducing the size of an image while
-    retaining the most significant features.
+    def _get_numpy(
+        self,
+        image: np.ndarray,
+        **kwargs: Any,
+    ) -> np.ndarray:
+        image = self._crop_center(image)
+        px, py, pz = self._get_pool_size(image)
+
+        if image.ndim < 3 or pz == 1:
+            block_size = (px, py) + (1,) * (image.ndim - 2)
+        else:
+            block_size = (pz, px, py) + (1,) * (image.ndim - 3)
+
+        return skimage.measure.block_reduce(
+            image,
+            block_size=block_size,
+            func=np.median,
+        )

+    # ---------- Torch backend ----------

+    def _get_torch(
+        self,
+        image: torch.Tensor,
+        **kwargs: Any,
+    ) -> torch.Tensor:
+
+        if not getattr(self, "_warned", False):
+            warnings.warn(
+                "MedianPooling is not differentiable and is expensive on the "
+                "Torch backend. Avoid using it inside trainable models.",
+                UserWarning,
+                stacklevel=2,
+            )
+            self._warned = True

-    Parameters
-    ----------
-    ksize: int
-        Size of the pooling kernel.
-    **kwargs: Any
-        Additional parameters sent to the pooling function.
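
Why SumPooling can reuse `avg_pool`: the sum over a k*k window equals the window mean times k*k. A standalone check of that trick (illustrative, not library code):

```python
import torch
import torch.nn.functional as F

x = torch.rand(1, 1, 8, 8)
k = 2

# Sum pooling expressed via average pooling, as in this diff.
sum_pool = F.avg_pool2d(x, k, k) * (k * k)

# Reference: reshape into 2x2 blocks and sum each block.
reference = x.reshape(1, 1, 4, 2, 4, 2).sum(dim=(3, 5))

print(torch.allclose(sum_pool, reference))  # True
```
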
+        image = self._crop_center(image)
-    Examples
-    --------
-    >>> import deeptrack as dt
-    >>> import numpy as np
+        px, py, pz = self._get_pool_size(image)

-    Create an input image:
-    >>> input_image = np.random.rand(32, 32)
+        is_3d = image.ndim >= 3 and pz > 1

-    Define a median pooling feature:
-    >>> median_pooling = dt.MedianPooling(ksize=3)
-    >>> output_image = median_pooling(input_image)
-    >>> print(output_image.shape)
-    (32, 32)
+        if not is_3d:
+            # 2D case (with optional channels)
+            extra = image.shape[2:]
+            C = int(np.prod(extra)) if extra else 1

-    Visualize the input and output images:
-    >>> plt.figure(figsize=(8, 4))
-    >>> plt.subplot(1, 2, 1)
-    >>> plt.imshow(input_image, cmap='gray')
-    >>> plt.subplot(1, 2, 2)
-    >>> plt.imshow(output_image, cmap='gray')
-    >>> plt.show()
+            # channels-last -> (1, C, H, W); permute keeps values aligned
+            x = image.reshape(*image.shape[:2], C)
+            x = x.permute(2, 0, 1).unsqueeze(0)

-    Notes
-    -----
-    Calling this feature returns a `np.ndarray` by default. If
-    `store_properties` is set to `True`, the returned array will be
-    automatically wrapped in an `Image` object. This behavior is handled
-    internally and does not affect the return type of the `get()` method.
+            # unfold: (B, C, H', W', px, py)
+            x_u = (
+                x.unfold(2, px, px)
+                .unfold(3, py, py)
+            )

-    """
+            x_u = x_u.contiguous().view(
+                1, C,
+                x_u.shape[2],
+                x_u.shape[3],
+                -1,
+            )

-    def __init__(
-        self: MedianPooling,
-        ksize: PropertyLike[int] = 3,
-        **kwargs: Any,
-    ):
-        """Initialize the parameters for median pooling.
+            pooled = x_u.median(dim=-1).values
+            out = pooled.squeeze(0).permute(1, 2, 0)

-        This constructor initializes the parameters for median pooling.
+        else:
+            # 3D case (with optional channels)
+            extra = image.shape[3:]
+            C = int(np.prod(extra)) if extra else 1
+
+            x = image.reshape(*image.shape[:3], C)
+            x = x.permute(3, 0, 1, 2).unsqueeze(0)

-        Parameters
-        ----------
-        ksize: int
-            Size of the pooling kernel.
-        **kwargs: Any
-            Additional keyword arguments.
+            # unfold: (B, C, Z', Y', X', pz, px, py)
+            x_u = (
+                x.unfold(2, pz, pz)
+                .unfold(3, px, px)
+                .unfold(4, py, py)
+            )

-        """
+            x_u = x_u.contiguous().view(
+                1, C,
+                x_u.shape[2],
+                x_u.shape[3],
+                x_u.shape[4],
+                -1,
+            )

-        super().__init__(np.median, ksize=ksize, **kwargs)
+            pooled = x_u.median(dim=-1).values
+            out = pooled.squeeze(0).permute(1, 2, 3, 0)

+        # Restore original (channels-last) layout
+        if extra:
+            return out.reshape(out.shape[:-1] + extra)
+        return out.squeeze(-1)


-class Resize(Feature):
+class Resize(BackendDispatched, Feature):
     """Resize an image to a specified size.

-    `Resize` resizes an image using:
-    - OpenCV (`cv2.resize`) for NumPy arrays.
-    - PyTorch (`torch.nn.functional.interpolate`) for PyTorch tensors.
+    `Resize` resizes images following the channels-last semantic
+    convention.

-    The interpretation of the `dsize` parameter follows the convention
-    of the underlying backend:
-    - **NumPy (OpenCV)**: `dsize` is given as `(width, height)` to match
-      OpenCV’s default.
-    - **PyTorch**: `dsize` is given as `(height, width)`.
+    The operation supports both NumPy arrays and PyTorch tensors:
+    - NumPy arrays are resized using OpenCV (`cv2.resize`).
+    - PyTorch tensors are resized using `torch.nn.functional.interpolate`.
+
+    In all cases, the input is interpreted as having spatial dimensions
+    first and an optional channel dimension last.

     Parameters
     ----------
-    dsize: PropertyLike[tuple[int, int]]
-        The target size. Format depends on backend: `(width, height)` for
-        NumPy, `(height, width)` for PyTorch.
-    **kwargs: Any
-        Additional parameters sent to the underlying resize function:
-        - NumPy: passed to `cv2.resize`.
-        - PyTorch: passed to `torch.nn.functional.interpolate`.
+ dsize : PropertyLike[tuple[int, int]] + Target output size given as (width, height). This convention is + backend-independent and applies equally to NumPy and PyTorch inputs. + + **kwargs : Any + Additional keyword arguments forwarded to the underlying resize + implementation: + - NumPy backend: passed to `cv2.resize`. + - PyTorch backend: passed to + `torch.nn.functional.interpolate`. Methods ------- get( - image: np.ndarray | torch.Tensor, dsize: tuple[int, int], **kwargs + image: np.ndarray | torch.Tensor, + dsize: tuple[int, int], + **kwargs ) -> np.ndarray | torch.Tensor Resize the input image to the specified size. Examples -------- - >>> import deeptrack as dt + NumPy example: - Numpy example: >>> import numpy as np - >>> - >>> input_image = np.random.rand(16, 16) # Create image - >>> feature = dt.math.Resize(dsize=(8, 4)) # (width=8, height=4) - >>> resized_image = feature.resolve(input_image) # Resize it to (4, 8) - >>> print(resized_image.shape) + >>> input_image = np.random.rand(16, 16) + >>> feature = dt.math.Resize(dsize=(8, 4)) # (width=8, height=4) + >>> resized_image = feature.resolve(input_image) + >>> resized_image.shape (4, 8) PyTorch example: + >>> import torch - >>> - >>> input_image = torch.rand(1, 1, 16, 16) # Create image - >>> feature = dt.math.Resize(dsize=(4, 8)) # (height=4, width=8) - >>> resized_image = feature.resolve(input_image) # Resize it to (4, 8) - >>> print(resized_image.shape) - torch.Size([1, 1, 4, 8]) + >>> input_image = torch.rand(16, 16) # channels-last + >>> feature = dt.math.Resize(dsize=(8, 4)) + >>> resized_image = feature.resolve(input_image) + >>> resized_image.shape + torch.Size([4, 8]) + + Notes + ----- + - Resize follows channels-last semantics, consistent with other features + such as Pool and Blur. + - Torch tensors with channels-first layout (e.g. (C, H, W) or + (N, C, H, W)) are not supported and must be converted to + channels-last format before resizing. + - For PyTorch tensors, bilinear interpolation is used with + `align_corners=False`, closely matching OpenCV’s default behavior. """ + _NUMPY_IMPL = "_get_numpy" + _TORCH_IMPL = "_get_torch" + def __init__( self: Resize, dsize: PropertyLike[tuple[int, int]] = (256, 256), @@ -1727,8 +1846,8 @@ def __init__( Parameters ---------- dsize: PropertyLike[tuple[int, int]] - The target size. Format depends on backend: `(width, height)` for - NumPy, `(height, width)` for PyTorch. Default is (256, 256). + The target size. dsize is always (width, height) for both backends. + Default is (256, 256). **kwargs: Any Additional arguments passed to the parent `Feature` class. @@ -1738,89 +1857,147 @@ def __init__( def get( self: Resize, - image: NDArray | torch.Tensor, + image: np.ndarray | torch.Tensor, dsize: tuple[int, int], **kwargs: Any, - ) -> NDArray | torch.Tensor: + ) -> np.ndarray | torch.Tensor: """Resize the input image to the specified size. Parameters ---------- - image: np.ndarray or torch.Tensor - The input image to resize. - - NumPy arrays may be grayscale (H, W) or color (H, W, C). - - Torch tensors are expected in one of the following formats: - (N, C, H, W), (C, H, W), or (H, W). - dsize: tuple[int, int] - Desired output size of the image. - - NumPy: (width, height) - - PyTorch: (height, width) - **kwargs: Any - Additional keyword arguments passed to the underlying resize - function (`cv2.resize` or `torch.nn.functional.interpolate`). + image : np.ndarray or torch.Tensor + Input image following channels-last semantics. 
+ + Supported shapes are: + - (H, W) + - (H, W, C) + - (Z, H, W) + - (Z, H, W, C) + + For PyTorch tensors, channels-first layouts such as (C, H, W) or + (N, C, H, W) are not supported and must be converted to + channels-last format before calling `Resize`. + + dsize : tuple[int, int] + Desired output size given as (width, height). This convention is + backend-independent and applies to both NumPy and PyTorch inputs. + + **kwargs : Any + Additional keyword arguments passed to the underlying resize + implementation: + - NumPy backend: forwarded to `cv2.resize`. + - PyTorch backend: forwarded to `torch.nn.functional.interpolate`. Returns ------- np.ndarray or torch.Tensor - The resized image in the same type and dimensionality format as - input. + The resized image, with the same type and dimensionality layout as + the input image. Notes ----- + - Resize follows the same channels-last semantic convention as other + features in `deeptrack.math`. - For PyTorch tensors, resizing uses bilinear interpolation with - `align_corners=False`. This choice matches OpenCV’s `cv2.resize` - default behavior when resizing NumPy arrays, aiming to produce nearly - identical results between both backends. + `align_corners=False`, which closely matches OpenCV’s default behavior. """ + return self._dispatch_backend(image, dsize=dsize, **kwargs) - if self._wrap_array_with_image: - image = strip(image) + # ---------- NumPy backend (OpenCV) ---------- - if apc.is_torch_array(image): - original_shape = image.shape - - # Reshape input to (N, C, H, W) - if image.ndim == 2: # (H, W) - image = image.unsqueeze(0).unsqueeze(0) - elif image.ndim == 3: # (C, H, W) - image = image.unsqueeze(0) - elif image.ndim != 4: - raise ValueError( - "Resize only supports tensors with shape (N, C, H, W), " - "(C, H, W), or (H, W)." 
- ) + def _get_numpy( + self, + image: np.ndarray, + dsize: tuple[int, int], + **kwargs: Any, + ) -> np.ndarray: - resized = torch.nn.functional.interpolate( - image, - size=dsize, - mode="bilinear", - align_corners=False, + target_w, target_h = dsize + + # Prefer OpenCV if available + if OPENCV_AVAILABLE: + import cv2 + return utils.safe_call( + cv2.resize, + positional_args=[image, (target_w, target_h)], + **kwargs, ) + if not OPENCV_AVAILABLE and kwargs: + warnings.warn("OpenCV not available: resize kwargs may be ignored.", UserWarning) + + # Fallback: skimage (always available in DT) + from skimage.transform import resize as sk_resize + + if image.ndim == 2: + out_shape = (target_h, target_w) + else: + out_shape = (target_h, target_w) + image.shape[2:] + + out = sk_resize( + image, + out_shape, + preserve_range=True, + anti_aliasing=True, + ) + + return out.astype(image.dtype, copy=False) + + # ---------- Torch backend ---------- + + def _get_torch( + self, + image: torch.Tensor, + dsize: tuple[int, int], + **kwargs: Any, + ) -> torch.Tensor: + import torch.nn.functional as F + + target_w, target_h = dsize + + original_ndim = image.ndim + has_channels = image.ndim >= 3 and image.shape[-1] <= 4 + + # Convert to (N, C, H, W) + if image.ndim == 2: + x = image.unsqueeze(0).unsqueeze(0) # (1, 1, H, W) - # Restore original dimensionality - if len(original_shape) == 2: - resized = resized.squeeze(0).squeeze(0) - elif len(original_shape) == 3: - resized = resized.squeeze(0) + elif image.ndim == 3 and has_channels: + x = image.permute(2, 0, 1).unsqueeze(0) # (1, C, H, W) - return resized + elif image.ndim == 3: + x = image.unsqueeze(1) # (Z, 1, H, W) + + elif image.ndim == 4 and has_channels: + x = image.permute(0, 3, 1, 2) # (Z, C, H, W) else: - import cv2 - return utils.safe_call( - cv2.resize, positional_args=[image, dsize], **kwargs + raise ValueError( + f"Unsupported tensor shape {image.shape} for Resize." ) + # Resize spatial dimensions + x = F.interpolate( + x, + size=(target_h, target_w), + mode="bilinear", + align_corners=False, + ) -if OPENCV_AVAILABLE: - _map_mode_to_cv2_borderType = { - "reflect": cv2.BORDER_REFLECT, - "wrap": cv2.BORDER_WRAP, - "constant": cv2.BORDER_CONSTANT, - "mirror": cv2.BORDER_REFLECT_101, - "nearest": cv2.BORDER_REPLICATE, - } + # Restore original layout + if original_ndim == 2: + return x.squeeze(0).squeeze(0) + + if original_ndim == 3 and has_channels: + return x.squeeze(0).permute(1, 2, 0) + + if original_ndim == 3: + return x.squeeze(1) + + if original_ndim == 4: + return x.permute(0, 2, 3, 1) + + raise RuntimeError("Unexpected shape restoration path.") #TODO ***JH*** revise BlurCV2 - torch, typing, docstring, unit test @@ -1840,7 +2017,7 @@ class BlurCV2(Feature): Methods ------- - `get(image: np.ndarray | Image, **kwargs: Any) --> np.ndarray` + `get(image: np.ndarray, **kwargs: Any) --> np.ndarray` Applies the blurring filter to the input image. Examples @@ -1865,59 +2042,23 @@ class BlurCV2(Feature): Notes ----- - Calling this feature returns a `np.ndarray` by default. If - `store_properties` is set to `True`, the returned array will be - automatically wrapped in an `Image` object. This behavior is handled - internally and does not affect the return type of the `get()` method. + BlurCV2 is NumPy-only and does not support PyTorch tensors. + This class is intended for OpenCV-specific filters that are + not available in the backend-agnostic math layer. 
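
A short usage sketch of the unified `dsize` convention (grounded in the docstring examples in this diff; note that `dsize` is given as (width, height) while the output array is indexed (height, width) as usual):

```python
import numpy as np
import deeptrack as dt

feature = dt.math.Resize(dsize=(8, 4))  # width=8, height=4

out = feature.resolve(np.random.rand(16, 16))
print(out.shape)  # (4, 8)
```
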
""" - def __new__( - cls: type, - *args: tuple, - **kwargs: Any, - ): - """Ensures that OpenCV (cv2) is available before instantiating the - class. - - Overrides the default object creation process to check that the `cv2` - module is available before creating the class. If OpenCV is not - installed, it raises an ImportError with instructions for installation. - - Parameters - ---------- - *args : tuple - Positional arguments passed to the class constructor. - **kwargs : dict - Keyword arguments passed to the class constructor. - - Returns - ------- - BlurCV2 - An instance of the BlurCV2 feature class. - - Raises - ------ - ImportError - If the OpenCV (`cv2`) module is not available in the current - environment. - - """ - - print(cls.__name__) - - if not OPENCV_AVAILABLE: - raise ImportError( - "OpenCV not installed on device. Since OpenCV is an optional " - f"dependency of DeepTrack2. To use {cls.__name__}, " - "you need to install it manually." - ) - - return super().__new__(cls) + _MODE_TO_BORDER = { + "reflect": "BORDER_REFLECT", + "wrap": "BORDER_WRAP", + "constant": "BORDER_CONSTANT", + "mirror": "BORDER_REFLECT_101", + "nearest": "BORDER_REPLICATE", + } def __init__( self: BlurCV2, - filter_function: Callable, + filter_function: Callable | str, mode: PropertyLike[str] = "reflect", **kwargs: Any, ): @@ -1937,13 +2078,20 @@ def __init__( """ + if not OPENCV_AVAILABLE: + raise ImportError( + "OpenCV not installed on device. Since OpenCV is an optional " + f"dependency of DeepTrack2. To use {self.__class__.__name__}, " + "you need to install it manually." + ) + self.filter = filter_function - borderType = _map_mode_to_cv2_borderType[mode] - super().__init__(borderType=borderType, **kwargs) + self.mode = mode + super().__init__(**kwargs) def get( self: BlurCV2, - image: np.ndarray | Image, + image: np.ndarray, **kwargs: Any, ) -> np.ndarray: """Applies the blurring filter to the input image. @@ -1952,8 +2100,8 @@ def get( Parameters ---------- - image: np.ndarray | Image - The input image to blur. Can be a NumPy array or DeepTrack Image. + image: np.ndarray + The input image to blur. Must be a NumPy array. **kwargs: Any Additional parameters for the blurring function. @@ -1964,9 +2112,34 @@ def get( """ + if apc.is_torch_array(image): + raise TypeError( + "BlurCV2 only supports NumPy arrays. " + "Use GaussianBlur / AverageBlur for Torch." + ) + + import cv2 + + filter_fn = getattr(cv2, self.filter) if isinstance(self.filter, str) else self.filter + + try: + border_attr = self._MODE_TO_BORDER[self.mode] + except KeyError as e: + raise ValueError(f"Unsupported border mode '{self.mode}'") from e + + try: + border = getattr(cv2, border_attr) + except AttributeError as e: + raise RuntimeError(f"OpenCV missing border constant '{border_attr}'") from e + + # preserve legacy behavior kwargs.pop("name", None) - result = self.filter(src=image, **kwargs) - return result + + return filter_fn( + src=image, + borderType=border, + **kwargs, + ) #TODO ***JH*** revise BilateralBlur - torch, typing, docstring, unit test @@ -2015,10 +2188,7 @@ class BilateralBlur(BlurCV2): Notes ----- - Calling this feature returns a `np.ndarray` by default. If - `store_properties` is set to `True`, the returned array will be - automatically wrapped in an `Image` object. This behavior is handled - internally and does not affect the return type of the `get()` method. + BilateralBlur is NumPy-only and does not support PyTorch tensors. 
""" @@ -2053,9 +2223,103 @@ def __init__( """ super().__init__( - cv2.bilateralFilter, + filter_function="bilateralFilter", d=d, sigmaColor=sigma_color, sigmaSpace=sigma_space, **kwargs, ) + + +def isotropic_dilation( + mask: np.ndarray | torch.Tensor, + radius: float, + *, + backend: Literal["numpy", "torch"], + device=None, + dtype=None, +) -> np.ndarray | torch.Tensor: + """ + Binary dilation using an isotropic (NumPy) or box-shaped (Torch) kernel. + + Notes + ----- + - NumPy backend uses a true Euclidean ball. + - Torch backend uses a cubic structuring element (approximate). + - Torch backend supports 3D masks only. + - Operation is non-differentiable. + + """ + + if radius <= 0: + return mask + + if backend == "numpy": + from skimage.morphology import isotropic_dilation + return isotropic_dilation(mask, radius) + + # torch backend + import torch + + r = int(np.ceil(radius)) + kernel = torch.ones( + (1, 1, 2 * r + 1, 2 * r + 1, 2 * r + 1), + device=device or mask.device, + dtype=dtype or torch.float32, + ) + + x = mask.to(dtype=kernel.dtype)[None, None] + y = torch.nn.functional.conv3d( + x, + kernel, + padding=r, + ) + + return (y[0, 0] > 0) + + +def isotropic_erosion( + mask: np.ndarray | torch.Tensor, + radius: float, + *, + backend: Literal["numpy", "torch"], + device=None, + dtype=None, +) -> np.ndarray | torch.Tensor: + """ + Binary erosion using an isotropic (NumPy) or box-shaped (Torch) kernel. + + Notes + ----- + - NumPy backend uses a true Euclidean ball. + - Torch backend uses a cubic structuring element (approximate). + - Torch backend supports 3D masks only. + - Operation is non-differentiable. + + """ + + if radius <= 0: + return mask + + if backend == "numpy": + from skimage.morphology import isotropic_erosion + return isotropic_erosion(mask, radius) + + import torch + + r = int(np.ceil(radius)) + kernel = torch.ones( + (1, 1, 2 * r + 1, 2 * r + 1, 2 * r + 1), + device=device or mask.device, + dtype=dtype or torch.float32, + ) + + x = mask.to(dtype=kernel.dtype)[None, None] + y = torch.nn.functional.conv3d( + x, + kernel, + padding=r, + ) + + required = kernel.numel() + return (y[0, 0] >= required) \ No newline at end of file diff --git a/deeptrack/optics.py b/deeptrack/optics.py index 5149bdae..70a54de6 100644 --- a/deeptrack/optics.py +++ b/deeptrack/optics.py @@ -137,11 +137,13 @@ def _pad_volume( from __future__ import annotations from pint import Quantity -from typing import Any +from typing import Any, TYPE_CHECKING import warnings import numpy as np -from scipy.ndimage import convolve +from scipy.ndimage import convolve # might be removed later +import torch +import torch.nn.functional as F from deeptrack.backend.units import ( ConversionTable, @@ -149,23 +151,37 @@ def _pad_volume( get_active_scale, get_active_voxel_size, ) -from deeptrack.math import AveragePooling +from deeptrack.math import AveragePooling, SumPooling from deeptrack.features import propagate_data_to_dependencies from deeptrack.features import DummyFeature, Feature, StructuralFeature -from deeptrack.image import Image, pad_image_to_fft +from deeptrack.image import pad_image_to_fft from deeptrack.types import ArrayLike, PropertyLike from deeptrack import image from deeptrack import units_registry as u +from deeptrack import TORCH_AVAILABLE, image +from deeptrack.backend import xp +from deeptrack.scatterers import ScatteredVolume, ScatteredField + +if TORCH_AVAILABLE: + import torch + +if TYPE_CHECKING: + import torch + #TODO ***??*** revise Microscope - torch, typing, docstring, unit test class 
Microscope(StructuralFeature): """Simulates imaging of a sample using an optical system. - This class combines a feature-set that defines the sample to be imaged with - a feature-set defining the optical system, enabling the simulation of - optical imaging processes. + This class combines the sample to be imaged with the optical system, + enabling the simulation of optical imaging processes. + A Microscope: + - validates the semantic compatibility between scatterers and optics + - interprets volume-based scatterers into scalar fields when needed + - delegates numerical propagation to the objective (Optics) + - performs detector downscaling according to its physical semantics Parameters ---------- @@ -186,10 +202,16 @@ class Microscope(StructuralFeature): Methods ------- - `get(image: Image or None, **kwargs: Any) -> Image` + `get(image: np.ndarray or None, **kwargs: Any) -> np.ndarray` Simulates the imaging process using the defined optical system and returns the resulting image. + Notes + ----- + All volume scatterers imaged by a Microscope instance are assumed to + share the same contrast mechanism (e.g. refractive index or fluorescence). + Mixing contrast types is not supported. + Examples -------- Simulating an image using a brightfield optical system: @@ -238,13 +260,41 @@ def __init__( self._sample = self.add_feature(sample) self._objective = self.add_feature(objective) - self._sample.store_properties() + + def _validate_input(self, scattered): + if hasattr(self._objective, "validate_input"): + self._objective.validate_input(scattered) + + def _extract_contrast_volume(self, scattered): + if hasattr(self._objective, "extract_contrast_volume"): + return self._objective.extract_contrast_volume( + scattered, + **self._objective.properties(), + ) + return scattered.array + + def _downscale_image(self, image, upscale): + if hasattr(self._objective, "downscale_image"): + return self._objective.downscale_image(image, upscale) + + if not np.any(np.array(upscale) != 1): + return image + + ux, uy = upscale[:2] + if ux != uy: + raise ValueError( + f"Energy-conserving detector integration requires ux == uy, " + f"got ux={ux}, uy={uy}." + ) + if isinstance(ux, float) and ux.is_integer(): + ux = int(ux) + return AveragePooling(ux)(image) def get( self: Microscope, - image: Image | None, + image: np.ndarray | torch.Tensor | None = None, **kwargs: Any, - ) -> Image: + ) -> np.ndarray | torch.Tensor: """Generate an image of the sample using the defined optical system. This method processes the sample through the optical system to @@ -252,14 +302,14 @@ def get( Parameters ---------- - image: Image | None + image: np.ndarray | torch.Tensor | None The input image to be processed. If None, a new image is created. **kwargs: Any Additional parameters for the imaging process. Returns ------- - Image: Image + image: np.ndarray | torch.Tensor The processed image after applying the optical system. Examples @@ -280,9 +330,6 @@ def get( # Grab properties from the objective to pass to the sample additional_sample_kwargs = self._objective.properties() - # Calculate required output image for the given upscale - # This way of providing the upscale will be deprecated in the future - # in favor of dt.Upscale(). 
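The `hasattr` probes above define an informal protocol: an optics class opts into semantic handling by implementing up to three hooks, and `Microscope` falls back to `scattered.array` and `AveragePooling` when a hook is missing. A minimal sketch of the protocol under those assumptions; `ToyOptics` is hypothetical, skips the actual propagation step, and assumes the `get_property` accessor the diff uses on scattered objects.

```python
# Hypothetical optics backend illustrating the three opt-in hooks that
# Microscope probes with hasattr(); not a real DeepTrack2 class.

class ToyOptics:

    def validate_input(self, scattered):
        # Reject or warn about scatterer types this modality cannot image.
        pass

    def extract_contrast_volume(self, scattered,
                                refractive_index_medium=1.33, **kwargs):
        # Interpret the merged volume under this modality's contrast model.
        ri = scattered.get_property("refractive_index", None)
        contrast = (ri - refractive_index_medium) if ri is not None else 1.0
        return contrast * scattered.array

    def downscale_image(self, image, upscale):
        # Detector semantics; identity when no upscaling was requested.
        return image
```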
_upscale_given_by_optics = additional_sample_kwargs["upscale"] if np.array(_upscale_given_by_optics).size == 1: _upscale_given_by_optics = (_upscale_given_by_optics,) * 3 @@ -325,67 +372,60 @@ def get( if not isinstance(list_of_scatterers, list): list_of_scatterers = [list_of_scatterers] + # Semantic validation (per scatterer) + for scattered in list_of_scatterers: + self._validate_input(scattered) + # All scatterers that are defined as volumes. volume_samples = [ scatterer for scatterer in list_of_scatterers - if not scatterer.get_property("is_field", default=False) + if isinstance(scatterer, ScatteredVolume) ] # All scatterers that are defined as fields. field_samples = [ scatterer for scatterer in list_of_scatterers - if scatterer.get_property("is_field", default=False) + if isinstance(scatterer, ScatteredField) ] - + # Merge all volumes into a single volume. sample_volume, limits = _create_volume( volume_samples, **additional_sample_kwargs, ) - sample_volume = Image(sample_volume) + if volume_samples: + # Interpret the merged volume semantically + sample_volume = self._extract_contrast_volume( + ScatteredVolume( + array=sample_volume, + properties=volume_samples[0].properties, + ), + ) - # Merge all properties into the volume. - for scatterer in volume_samples + field_samples: - sample_volume.merge_properties_from(scatterer) # Let the objective know about the limits of the volume and all the fields. propagate_data_to_dependencies( self._objective, limits=limits, - fields=field_samples, + fields=field_samples, # should We add upscale? ) imaged_sample = self._objective.resolve(sample_volume) - # Upscale given by the optics needs to be handled separately. - if _upscale_given_by_optics != (1, 1, 1): - imaged_sample = AveragePooling((*_upscale_given_by_optics[:2], 1))( - imaged_sample - ) - - # Merge with input - if not image: - if not self._wrap_array_with_image and isinstance(imaged_sample, Image): - return imaged_sample._value - else: - return imaged_sample - - if not isinstance(image, list): - image = [image] - for i in range(len(image)): - image[i].merge_properties_from(imaged_sample) - return image - - # def _no_wrap_format_input(self, *args, **kwargs) -> list: - # return self._image_wrapped_format_input(*args, **kwargs) + imaged_sample = self._downscale_image(imaged_sample, upscale) + # # Handling upscale from dt.Upscale() here to eliminate Image + # # wrapping issues. 
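The rewritten `get()` now routes scatterer outputs by type rather than by an `is_field` property. A short sketch of that partitioning step, assuming the dataclass-style `ScatteredVolume`/`ScatteredField` constructors used elsewhere in this diff; the arrays and properties are illustrative.

```python
# Type-based partitioning of scatterer outputs, as in Microscope.get().
import numpy as np
from deeptrack.scatterers import ScatteredField, ScatteredVolume

scattered = [
    ScatteredVolume(array=np.ones((8, 8, 4)),
                    properties={"refractive_index": 1.45}),
    ScatteredField(array=np.ones((8, 8, 1), dtype=complex), properties={}),
]

volume_samples = [s for s in scattered if isinstance(s, ScatteredVolume)]
field_samples = [s for s in scattered if isinstance(s, ScatteredField)]

# Volumes are merged and interpreted via extract_contrast_volume();
# fields are handed to the objective and summed in the focal plane.
assert len(volume_samples) == 1 and len(field_samples) == 1
```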
+ # if np.any(np.array(upscale) != 1): + # ux, uy = upscale[:2] + # if contrast_type == "intensity": + # print("Using sum pooling for intensity downscaling.") + # imaged_sample = SumPoolingCM((ux, uy, 1))(imaged_sample) + # else: + # imaged_sample = AveragePoolingCM((ux, uy, 1))(imaged_sample) - # def _no_wrap_process_and_get(self, *args, **feature_input) -> list: - # return self._image_wrapped_process_and_get(*args, **feature_input) - - # def _no_wrap_process_output(self, *args, **feature_input): - # return self._image_wrapped_process_output(*args, **feature_input) + return imaged_sample #TODO ***??*** revise Optics - torch, typing, docstring, unit test @@ -569,6 +609,15 @@ def __init__( """ + def validate_scattered(self, scattered): + pass + + def extract_contrast_volume(self, scattered): + pass + + def downscale_image(self, image, upscale): + pass + def get_voxel_size( resolution: float | ArrayLike[float], magnification: float, @@ -757,19 +806,18 @@ def _pupil( W, H = np.meshgrid(y, x) RHO = (W ** 2 + H ** 2).astype(complex) - pupil_function = Image((RHO < 1) + 0.0j, copy=False) + pupil_function = (RHO < 1) + 0.0j # Defocus - z_shift = Image( + z_shift = ( 2 * np.pi * refractive_index_medium / wavelength * voxel_size[2] - * np.sqrt(1 - (NA / refractive_index_medium) ** 2 * RHO), - copy=False, + * np.sqrt(1 - (NA / refractive_index_medium) ** 2 * RHO) ) - z_shift._value[z_shift._value.imag != 0] = 0 + z_shift[z_shift.imag != 0] = 0 try: z_shift = np.nan_to_num(z_shift, False, 0, 0, 0) @@ -1007,7 +1055,7 @@ class Fluorescence(Optics): Methods ------- - `get(illuminated_volume: array_like[complex], limits: array_like[int, int], **kwargs: Any) -> Image` + `get(illuminated_volume: array_like[complex], limits: array_like[int, int], **kwargs: Any) -> np.ndarray` Simulates the imaging process using a fluorescence microscope. Examples @@ -1024,12 +1072,70 @@ class Fluorescence(Optics): """ + def validate_input(self, scattered): + """Semantic validation for fluorescence microscopy.""" + + # Fluorescence cannot operate on coherent fields + if isinstance(scattered, ScatteredField): + raise TypeError( + "Fluorescence microscope cannot operate on ScatteredField." + ) + + + def extract_contrast_volume(self, scattered: ScatteredVolume, **kwargs) -> np.ndarray: + scale = np.asarray(get_active_scale(), float) + scale_volume = np.prod(scale) + + intensity = scattered.get_property("intensity", None) + value = scattered.get_property("value", None) + ri = scattered.get_property("refractive_index", None) + + # Refractive index is always ignored in fluorescence + if ri is not None: + warnings.warn( + "Scatterer defines 'refractive_index', which is ignored in " + "fluorescence microscopy.", + UserWarning, + ) + + # Preferred, physically meaningful case + if intensity is not None: + return intensity * scale_volume * scattered.array + + # Fallback: legacy / dimensionless brightness + warnings.warn( + "Fluorescence scatterer has no 'intensity'. Interpreting 'value' as a " + "non-physical brightness factor. Quantitative interpretation is invalid. " + "Define 'intensity' to model physical fluorescence emission.", + UserWarning, + ) + + return value * scattered.array + + def downscale_image(self, image: np.ndarray, upscale): + """Detector downscaling (energy conserving)""" + if not np.any(np.array(upscale) != 1): + return image + + ux, uy = upscale[:2] + if ux != uy: + raise ValueError( + f"Energy-conserving detector integration requires ux == uy, " + f"got ux={ux}, uy={uy}." 
+ ) + if isinstance(ux, float) and ux.is_integer(): + ux = int(ux) + + # Energy-conserving detector integration + return SumPooling(ux)(image) + + def get( self: Fluorescence, illuminated_volume: ArrayLike[complex], limits: ArrayLike[int], **kwargs: Any, - ) -> Image: + ) -> ArrayLike[complex]: """Simulates the imaging process using a fluorescence microscope. This method convolves the 3D illuminated volume with a pupil function @@ -1048,7 +1154,7 @@ def get( Returns ------- - Image: Image + image: np.ndarray A 2D image object representing the fluorescence projection. Notes @@ -1066,7 +1172,7 @@ def get( >>> optics = dt.Fluorescence( ... NA=1.4, wavelength=0.52e-6, magnification=60, ... ) - >>> volume = dt.Image(np.ones((128, 128, 10), dtype=complex)) + >>> volume = np.ones((128, 128, 10), dtype=complex) >>> limits = np.array([[0, 128], [0, 128], [0, 10]]) >>> properties = optics.properties() >>> filtered_properties = { @@ -1118,9 +1224,7 @@ def get( ] z_limits = limits[2, :] - output_image = Image( - np.zeros((*padded_volume.shape[0:2], 1)), copy=False - ) + output_image = np.zeros((*padded_volume.shape[0:2], 1)) index_iterator = range(padded_volume.shape[2]) @@ -1156,12 +1260,12 @@ def get( field = np.fft.ifft2(convolved_fourier_field) # # Discard remaining imaginary part (should be 0 up to rounding error) field = np.real(field) - output_image._value[:, :, 0] += field[ + output_image[:, :, 0] += field[ : padded_volume.shape[0], : padded_volume.shape[1] ] output_image = output_image[pad[0] : -pad[2], pad[1] : -pad[3]] - output_image.properties = illuminated_volume.properties + pupils.properties + # output_image.properties = illuminated_volume.properties + pupils.properties return output_image @@ -1234,7 +1338,7 @@ class Brightfield(Optics): ------- `get(illuminated_volume: array_like[complex], limits: array_like[int, int], fields: array_like[complex], - **kwargs: Any) -> Image` + **kwargs: Any) -> np.ndarray` Simulates imaging with brightfield microscopy. @@ -1250,9 +1354,51 @@ class Brightfield(Optics): """ + __conversion_table__ = ConversionTable( - working_distance=(u.meter, u.meter), - ) + working_distance=(u.meter, u.meter), +) + + def validate_input(self, scattered): + """Semantic validation for brightfield microscopy.""" + + if isinstance(scattered, ScatteredVolume): + warnings.warn( + "Brightfield imaging from ScatteredVolume assumes a " + "weak-phase / projection approximation. " + "Use ScatteredField for physically accurate brightfield simulations.", + UserWarning, + ) + + def extract_contrast_volume( + self, + scattered: ScatteredVolume, + refractive_index_medium: float, + **kwargs: Any, + ) -> np.ndarray: + + ri = scattered.get_property("refractive_index", None) + value = scattered.get_property("value", None) + intensity = scattered.get_property("intensity", None) + + if intensity is not None: + warnings.warn( + "Scatterer defines 'intensity', which is ignored in " + "brightfield microscopy.", + UserWarning, + ) + + if ri is not None: + return (ri - refractive_index_medium) * scattered.array + + warnings.warn( + "No 'refractive_index' specified; using 'value' as a non-physical " + "brightfield contrast. Results are not physically calibrated. " + "Define 'refractive_index' for physically meaningful contrast.", + UserWarning, + ) + + return value * scattered.array def get( self: Brightfield, @@ -1260,7 +1406,7 @@ def get( limits: ArrayLike[int], fields: ArrayLike[complex], **kwargs: Any, - ) -> Image: + ) -> np.ndarray: """Simulates imaging with brightfield microscopy. 
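The switch from `AveragePooling` to `SumPooling` in the fluorescence `downscale_image` hook encodes detector physics: binning sub-pixels should integrate photon counts, not average them. A standalone NumPy sketch of the distinction (DeepTrack2's pooling features are deliberately not used here):

```python
# Why SumPooling for fluorescence detectors: integrating 2x2 blocks of
# sub-pixels conserves total signal; averaging rescales it by 1/4.
import numpy as np

fine = np.random.rand(64, 64)  # image simulated on a 2x-oversampled grid
u = 2

blocks = fine.reshape(32, u, 32, u)
sum_pooled = blocks.sum(axis=(1, 3))    # detector integration
avg_pooled = blocks.mean(axis=(1, 3))   # plain downsampling

assert np.isclose(sum_pooled.sum(), fine.sum())          # energy conserved
assert np.isclose(avg_pooled.sum(), fine.sum() / u**2)   # energy scaled down
```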
This method propagates light through the given volume, applying @@ -1285,7 +1431,7 @@ def get( Returns ------- - Image: Image + image: np.ndarray Processed image after simulating the brightfield imaging process. Examples @@ -1300,7 +1446,7 @@ def get( ... wavelength=0.52e-6, ... magnification=60, ... ) - >>> volume = dt.Image(np.ones((128, 128, 10), dtype=complex)) + >>> volume = np.ones((128, 128, 10), dtype=complex) >>> limits = np.array([[0, 128], [0, 128], [0, 10]]) >>> fields = np.array([np.ones((162, 162), dtype=complex)]) >>> properties = optics.properties() @@ -1345,7 +1491,7 @@ def get( if output_region[3] is None else int(output_region[3] - limits[1, 0] + pad[3]) ) - + padded_volume = padded_volume[ output_region[0] : output_region[2], output_region[1] : output_region[3], @@ -1353,9 +1499,7 @@ def get( ] z_limits = limits[2, :] - output_image = Image( - np.zeros((*padded_volume.shape[0:2], 1)) - ) + output_image = np.zeros((*padded_volume.shape[0:2], 1)) index_iterator = range(padded_volume.shape[2]) z_iterator = np.linspace( @@ -1414,7 +1558,25 @@ def get( light_in_focus = light_in * shifted_pupil if len(fields) > 0: - field = np.sum(fields, axis=0) + # field = np.sum(fields, axis=0) + field_arrays = [] + + for fs in fields: + # fs is a ScatteredField + arr = fs.array + + # Enforce (H, W, 1) shape + if arr.ndim == 2: + arr = arr[..., None] + + if arr.ndim != 3 or arr.shape[-1] != 1: + raise ValueError( + f"Expected field of shape (H, W, 1), got {arr.shape}" + ) + + field_arrays.append(arr) + + field = np.sum(field_arrays, axis=0) light_in_focus += field[..., 0] shifted_pupil = np.fft.fftshift(pupils[-1]) light_in_focus = light_in_focus * shifted_pupil @@ -1426,7 +1588,7 @@ def get( : padded_volume.shape[0], : padded_volume.shape[1] ] output_image = np.expand_dims(output_image, axis=-1) - output_image = Image(output_image[pad[0] : -pad[2], pad[1] : -pad[3]]) + output_image = output_image[pad[0] : -pad[2], pad[1] : -pad[3]] if not kwargs.get("return_field", False): output_image = np.square(np.abs(output_image)) @@ -1436,7 +1598,7 @@ def get( # output_image = output_image * np.exp(1j * -np.pi / 4) # output_image = output_image + 1 - output_image.properties = illuminated_volume.properties + # output_image.properties = illuminated_volume.properties return output_image @@ -1624,6 +1786,73 @@ def __init__( illumination_angle=illumination_angle, **kwargs) + def validate_input(self, scattered): + if isinstance(scattered, ScatteredVolume): + warnings.warn( + "Darkfield imaging from ScatteredVolume is a very rough " + "approximation. Use ScatteredField for physically meaningful " + "darkfield simulations.", + UserWarning, + ) + + def extract_contrast_volume( + self, + scattered: ScatteredVolume, + refractive_index_medium: float, + **kwargs: Any, + ) -> np.ndarray: + """ + Approximate darkfield contrast from a volume (toy model). + + This is a non-physical approximation intended for qualitative simulations. + """ + + ri = scattered.get_property("refractive_index", None) + value = scattered.get_property("value", None) + intensity = scattered.get_property("intensity", None) + + # Intensity has no meaning here + if intensity is not None: + warnings.warn( + "Scatterer defines 'intensity', which is ignored in " + "darkfield microscopy.", + UserWarning, + ) + + if ri is not None: + delta_n = ri - refractive_index_medium + warnings.warn( + "Approximating darkfield contrast from refractive index. 
" + "Result is non-physical and qualitative only.", + UserWarning, + ) + return (delta_n ** 2) * scattered.array + + warnings.warn( + "No 'refractive_index' specified; using 'value' as a non-physical " + "darkfield scattering strength. Results are qualitative only.", + UserWarning, + ) + + return (value ** 2) * scattered.array + + def downscale_image(self, image: np.ndarray, upscale): + """Detector downscaling (energy conserving)""" + if not np.any(np.array(upscale) != 1): + return image + + ux, uy = upscale[:2] + if ux != uy: + raise ValueError( + f"Energy-conserving detector integration requires ux == uy, " + f"got ux={ux}, uy={uy}." + ) + if isinstance(ux, float) and ux.is_integer(): + ux = int(ux) + + # Energy-conserving detector integration + return SumPooling(ux)(image) + #Retrieve get as super def get( self: Darkfield, @@ -1631,7 +1860,7 @@ def get( limits: ArrayLike[int], fields: ArrayLike[complex], **kwargs: Any, - ) -> Image: + ) -> np.ndarray: """Retrieve the darkfield image of the illuminated volume. Parameters @@ -1800,9 +2029,1004 @@ def get( return image +class NonOverlapping(Feature): + """Ensure volumes are placed non-overlapping in a 3D space. + + This feature ensures that a list of 3D volumes are positioned such that + their non-zero voxels do not overlap. If volumes overlap, their positions + are resampled until they are non-overlapping. If the maximum number of + attempts is exceeded, the feature regenerates the list of volumes and + raises a warning if non-overlapping placement cannot be achieved. + + Note: `min_distance` refers to the distance between the edges of volumes, + not their centers. Due to the way volumes are calculated, slight rounding + errors may affect the final distance. + + This feature is incompatible with non-volumetric scatterers such as + `MieScatterers`. + + Parameters + ---------- + feature: Feature + The feature that generates the list of volumes to place + non-overlapping. + min_distance: float, optional + The minimum distance between volumes in pixels. It can be negative to + allow for partial overlap. Defaults to 1. + max_attempts: int, optional + The maximum number of attempts to place volumes without overlap. + Defaults to 5. + max_iters: int, optional + The maximum number of resamplings. If this number is exceeded, a new + list of volumes is generated. Defaults to 100. + + Attributes + ---------- + __distributed__: bool + Always `False` for `NonOverlapping`, indicating that this feature’s + `.get()` method processes the entire input at once even if it is a + list, rather than distributing calls for each item of the list.N + + Methods + ------- + `get(*_, min_distance, max_attempts, **kwargs) -> array` + Generate a list of non-overlapping 3D volumes. + `_check_non_overlapping(list_of_volumes) -> bool` + Check if all volumes in the list are non-overlapping. + `_check_bounding_cubes_non_overlapping(...) -> bool` + Check if two bounding cubes are non-overlapping. + `_get_overlapping_cube(...) -> list[int]` + Get the overlapping cube between two bounding cubes. + `_get_overlapping_volume(...) -> array` + Get the overlapping volume between a volume and a bounding cube. + `_check_volumes_non_overlapping(...) -> bool` + Check if two volumes are non-overlapping. + `_resample_volume_position(volume) -> Image` + Resample the position of a volume to avoid overlap. + + Notes + ----- + - This feature performs bounding cube checks first to quickly reject + obvious overlaps before voxel-level checks. 
+    - If the bounding cubes overlap, precise voxel-based checks are performed.
+    - The feature may be computationally intensive for large numbers of volumes
+      or high-density placements.
+    - The feature is not differentiable.
+
+    Examples
+    --------
+    >>> import deeptrack as dt
+
+    Define an ellipse scatterer with randomly positioned objects:
+
+    >>> import numpy as np
+    >>>
+    >>> scatterer = dt.Ellipse(
+    ...     radius=13 * dt.units.pixels,
+    ...     position=lambda: np.random.uniform(5, 115, size=2) * dt.units.pixels,
+    ... )
+
+    Create multiple scatterers:
+
+    >>> scatterers = (scatterer ^ 8)
+
+    Define the optics and create the image with possible overlap:
+
+    >>> optics = dt.Fluorescence()
+    >>> im_with_overlap = optics(scatterers)
+    >>> im_with_overlap.store_properties()
+    >>> im_with_overlap_resolved = im_with_overlap()
+
+    Gather positions from the image:
+
+    >>> pos_with_overlap = np.array(
+    ...     im_with_overlap_resolved.get_property(
+    ...         "position",
+    ...         get_one=False,
+    ...     )
+    ... )
+
+    Enforce non-overlapping and create the image without overlap:
+
+    >>> non_overlapping_scatterers = dt.NonOverlapping(
+    ...     scatterers,
+    ...     min_distance=4,
+    ... )
+    >>> im_without_overlap = optics(non_overlapping_scatterers)
+    >>> im_without_overlap.store_properties()
+    >>> im_without_overlap_resolved = im_without_overlap()
+
+    Gather positions from the image:
+
+    >>> pos_without_overlap = np.array(
+    ...     im_without_overlap_resolved.get_property(
+    ...         "position",
+    ...         get_one=False,
+    ...     )
+    ... )
+
+    Create a figure with two subplots to visualize the difference:
+
+    >>> import matplotlib.pyplot as plt
+    >>>
+    >>> fig, axes = plt.subplots(1, 2, figsize=(10, 5))
+    >>>
+    >>> axes[0].imshow(im_with_overlap_resolved, cmap="gray")
+    >>> axes[0].scatter(pos_with_overlap[:, 1], pos_with_overlap[:, 0])
+    >>> axes[0].set_title("Overlapping Objects")
+    >>> axes[0].axis("off")
+    >>>
+    >>> axes[1].imshow(im_without_overlap_resolved, cmap="gray")
+    >>> axes[1].scatter(pos_without_overlap[:, 1], pos_without_overlap[:, 0])
+    >>> axes[1].set_title("Non-Overlapping Objects")
+    >>> axes[1].axis("off")
+    >>> plt.tight_layout()
+    >>>
+    >>> plt.show()
+
+    Define a function to calculate the minimum pairwise distance:
+
+    >>> def calculate_min_distance(positions):
+    ...     distances = [
+    ...         np.linalg.norm(positions[i] - positions[j])
+    ...         for i in range(len(positions))
+    ...         for j in range(i + 1, len(positions))
+    ...     ]
+    ...     return min(distances)
+
+    Print the minimum distances with and without overlap:
+
+    >>> print(calculate_min_distance(pos_with_overlap))
+    10.768742383382174
+
+    >>> print(calculate_min_distance(pos_without_overlap))
+    30.82531120942446
+
+    """
+
+    __distributed__: bool = False
+
+    def __init__(
+        self: NonOverlapping,
+        feature: Feature,
+        min_distance: float = 1,
+        max_attempts: int = 5,
+        max_iters: int = 100,
+        **kwargs: Any,
+    ):
+        """Initializes the NonOverlapping feature.
+
+        Ensures that volumes are placed **non-overlapping** by iteratively
+        resampling their positions. If the maximum number of attempts is
+        exceeded, the feature regenerates the list of volumes.
+
+        Parameters
+        ----------
+        feature: Feature
+            The feature that generates the list of volumes.
+        min_distance: float, optional
+            The minimum separation distance **between volume edges**, in
+            pixels. It defaults to `1`. Negative values allow for partial
+            overlap.
+        max_attempts: int, optional
+            The maximum number of attempts to place the volumes without
+            overlap. It defaults to `5`.
+ max_iters: int, optional + The maximum number of resampling iterations per attempt. If + exceeded, a new list of volumes is generated. It defaults to `100`. + + """ + + super().__init__( + min_distance=min_distance, + max_attempts=max_attempts, + max_iters=max_iters, + **kwargs, + ) + self.feature = self.add_feature(feature, **kwargs) + + def get( + self: NonOverlapping, + *_: Any, + min_distance: float, + max_attempts: int, + max_iters: int, + **kwargs: Any, + ) -> list[np.ndarray]: + """Generates a list of non-overlapping 3D volumes within a defined + field of view (FOV). + + This method **iteratively** attempts to place volumes while ensuring + they maintain at least `min_distance` separation. If non-overlapping + placement is not achieved within `max_attempts`, a warning is issued, + and the best available configuration is returned. + + Parameters + ---------- + _: Any + Placeholder parameter, typically for an input image. + min_distance: float + The minimum required separation distance between volumes, in + pixels. + max_attempts: int + The maximum number of attempts to generate a valid non-overlapping + configuration. + max_iters: int + The maximum number of resampling iterations per attempt. + **kwargs: Any + Additional parameters that may be used by subclasses. + + Returns + ------- + list[np.ndarray] + A list of 3D volumes represented as NumPy arrays. If + non-overlapping placement is unsuccessful, the best available + configuration is returned. + + Warns + ----- + UserWarning + If non-overlapping placement is **not** achieved within + `max_attempts`, suggesting parameter adjustments such as increasing + the FOV or reducing `min_distance`. + + Notes + ----- + - The placement process prioritizes bounding cube checks for + efficiency. + - If bounding cubes overlap, voxel-based overlap checks are performed. + + """ + + for _ in range(max_attempts): + list_of_volumes = self.feature() + + if not isinstance(list_of_volumes, list): + list_of_volumes = [list_of_volumes] + + for _ in range(max_iters): + + list_of_volumes = [ + self._resample_volume_position(volume) + for volume in list_of_volumes + ] + + if self._check_non_overlapping(list_of_volumes): + return list_of_volumes + + # Generate a new list of volumes if max_attempts is exceeded. + self.feature.update() + + warnings.warn( + "Non-overlapping placement could not be achieved. Consider " + "adjusting parameters: reduce object radius, increase FOV, " + "or decrease min_distance.", + UserWarning, + ) + return list_of_volumes + + def _check_non_overlapping( + self: NonOverlapping, + list_of_volumes: list[np.ndarray], + ) -> bool: + """Determines whether all volumes in the provided list are + non-overlapping. + + This method verifies that the non-zero voxels of each 3D volume in + `list_of_volumes` are at least `min_distance` apart. It first checks + bounding boxes for early rejection and then examines actual voxel + overlap when necessary. Volumes are assumed to have a `position` + attribute indicating their placement in 3D space. + + Parameters + ---------- + list_of_volumes: list[np.ndarray] + A list of 3D arrays representing the volumes to be checked for + overlap. Each volume is expected to have a position attribute. + + Returns + ------- + bool + `True` if all volumes are non-overlapping, otherwise `False`. + + Notes + ----- + - If `min_distance` is negative, volumes are shrunk using isotropic + erosion before checking overlap. + - If `min_distance` is positive, volumes are padded and expanded using + isotropic dilation. 
+ - Overlapping checks are first performed on bounding cubes for + efficiency. + - If bounding cubes overlap, voxel-level checks are performed. + + """ + from deeptrack.scatterers import ScatteredVolume + + from deeptrack.augmentations import CropTight, Pad # these are not compatibles with torch backend + from deeptrack.optics import _get_position + from deeptrack.math import isotropic_erosion, isotropic_dilation + + min_distance = self.min_distance() + crop = CropTight() + + new_volumes = [] + + for volume in list_of_volumes: + arr = volume.array + mask = arr != 0 + + if min_distance < 0: + new_arr = isotropic_erosion(mask, -min_distance / 2, backend=self.get_backend()) + else: + pad = Pad(px=[int(np.ceil(min_distance / 2))] * 6, keep_size=True) + new_arr = isotropic_dilation(pad(mask) != 0 , min_distance / 2, backend=self.get_backend()) + new_arr = crop(new_arr) + + if self.get_backend() == "torch": + new_arr = new_arr.to(dtype=arr.dtype) + else: + new_arr = new_arr.astype(arr.dtype) + + new_volume = ScatteredVolume( + array=new_arr, + properties=volume.properties.copy(), + ) + + new_volumes.append(new_volume) + + list_of_volumes = new_volumes + min_distance = 1 + + # The position of the top left corner of each volume (index (0, 0, 0)). + volume_positions_1 = [ + _get_position(volume, mode="corner", return_z=True).astype(int) + for volume in list_of_volumes + ] + + # The position of the bottom right corner of each volume + # (index (-1, -1, -1)). + volume_positions_2 = [ + p0 + np.array(v.shape) + for v, p0 in zip(list_of_volumes, volume_positions_1) + ] + + # (x1, y1, z1, x2, y2, z2) for each volume. + volume_bounding_cube = [ + [*p0, *p1] + for p0, p1 in zip(volume_positions_1, volume_positions_2) + ] + + for i, j in itertools.combinations(range(len(list_of_volumes)), 2): + + # If the bounding cubes do not overlap, the volumes do not overlap. + if self._check_bounding_cubes_non_overlapping( + volume_bounding_cube[i], volume_bounding_cube[j], min_distance + ): + continue + + # If the bounding cubes overlap, get the overlapping region of each + # volume. + overlapping_cube = self._get_overlapping_cube( + volume_bounding_cube[i], volume_bounding_cube[j] + ) + overlapping_volume_1 = self._get_overlapping_volume( + list_of_volumes[i].array, volume_bounding_cube[i], overlapping_cube + ) + overlapping_volume_2 = self._get_overlapping_volume( + list_of_volumes[j].array, volume_bounding_cube[j], overlapping_cube + ) + + # If either the overlapping regions are empty, the volumes do not + # overlap (done for speed). + if (np.all(overlapping_volume_1 == 0) + or np.all(overlapping_volume_2 == 0)): + continue + + # If products of overlapping regions are non-zero, return False. + # if np.any(overlapping_volume_1 * overlapping_volume_2): + # return False + + # Finally, check that the non-zero voxels of the volumes are at + # least min_distance apart. + if not self._check_volumes_non_overlapping( + overlapping_volume_1, overlapping_volume_2, min_distance + ): + return False + + return True + + def _check_bounding_cubes_non_overlapping( + self: NonOverlapping, + bounding_cube_1: list[int], + bounding_cube_2: list[int], + min_distance: float, + ) -> bool: + """Determines whether two 3D bounding cubes are non-overlapping. + + This method checks whether the bounding cubes of two volumes are + **separated by at least** `min_distance` along **any** spatial axis. + + Parameters + ---------- + bounding_cube_1: list[int] + A list of six integers `[x1, y1, z1, x2, y2, z2]` representing + the first bounding cube. 
+ bounding_cube_2: list[int] + A list of six integers `[x1, y1, z1, x2, y2, z2]` representing + the second bounding cube. + min_distance: float + The required **minimum separation distance** between the two + bounding cubes. + + Returns + ------- + bool + `True` if the bounding cubes are non-overlapping (separated by at + least `min_distance` along **at least one axis**), otherwise + `False`. + + Notes + ----- + - This function **only checks bounding cubes**, **not actual voxel + data**. + - If the bounding cubes are non-overlapping, the corresponding + **volumes are also non-overlapping**. + - This check is much **faster** than full voxel-based comparisons. + + """ + + # bounding_cube_1 and bounding_cube_2 are (x1, y1, z1, x2, y2, z2). + # Check that the bounding cubes are non-overlapping. + return ( + (bounding_cube_1[0] >= bounding_cube_2[3] + min_distance) or + (bounding_cube_2[0] >= bounding_cube_1[3] + min_distance) or + (bounding_cube_1[1] >= bounding_cube_2[4] + min_distance) or + (bounding_cube_2[1] >= bounding_cube_1[4] + min_distance) or + (bounding_cube_1[2] >= bounding_cube_2[5] + min_distance) or + (bounding_cube_2[2] >= bounding_cube_1[5] + min_distance) + ) + + def _get_overlapping_cube( + self: NonOverlapping, + bounding_cube_1: list[int], + bounding_cube_2: list[int], + ) -> list[int]: + """Computes the overlapping region between two 3D bounding cubes. + + This method calculates the coordinates of the intersection of two + axis-aligned bounding cubes, each represented as a list of six + integers: + + - `[x1, y1, z1]`: Coordinates of the **top-left-front** corner. + - `[x2, y2, z2]`: Coordinates of the **bottom-right-back** corner. + + The resulting overlapping region is determined by: + - Taking the **maximum** of the starting coordinates (`x1, y1, z1`). + - Taking the **minimum** of the ending coordinates (`x2, y2, z2`). + + If the cubes **do not** overlap, the resulting coordinates will not + form a valid cube (i.e., `x1 > x2`, `y1 > y2`, or `z1 > z2`). + + Parameters + ---------- + bounding_cube_1: list[int] + The first bounding cube, formatted as `[x1, y1, z1, x2, y2, z2]`. + bounding_cube_2: list[int] + The second bounding cube, formatted as `[x1, y1, z1, x2, y2, z2]`. + + Returns + ------- + list[int] + A list of six integers `[x1, y1, z1, x2, y2, z2]` representing the + overlapping bounding cube. If no overlap exists, the coordinates + will **not** define a valid cube. + + Notes + ----- + - This function does **not** check for valid input or ensure the + resulting cube is well-formed. + - If no overlap exists, downstream functions must handle the invalid + result. + + """ + + return [ + max(bounding_cube_1[0], bounding_cube_2[0]), + max(bounding_cube_1[1], bounding_cube_2[1]), + max(bounding_cube_1[2], bounding_cube_2[2]), + min(bounding_cube_1[3], bounding_cube_2[3]), + min(bounding_cube_1[4], bounding_cube_2[4]), + min(bounding_cube_1[5], bounding_cube_2[5]), + ] + + def _get_overlapping_volume( + self: NonOverlapping, + volume: np.ndarray, # 3D array. + bounding_cube: tuple[float, float, float, float, float, float], + overlapping_cube: tuple[float, float, float, float, float, float], + ) -> np.ndarray: + """Extracts the overlapping region of a 3D volume within the specified + overlapping cube. + + This method identifies and returns the subregion of `volume` that + lies within the `overlapping_cube`. The bounding information of the + volume is provided via `bounding_cube`. 
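Both geometric helpers reduce to elementwise max/min plus a separating-axis test. A standalone sketch with illustrative coordinates:

```python
# Bounding-cube intersection (as in _get_overlapping_cube) and the
# separating-axis rejection test (as in
# _check_bounding_cubes_non_overlapping), with illustrative coordinates.
cube_1 = [0, 0, 0, 10, 10, 10]   # (x1, y1, z1, x2, y2, z2)
cube_2 = [6, 4, 8, 16, 14, 18]

overlap = [
    max(cube_1[0], cube_2[0]),  # start corner: elementwise max
    max(cube_1[1], cube_2[1]),
    max(cube_1[2], cube_2[2]),
    min(cube_1[3], cube_2[3]),  # end corner: elementwise min
    min(cube_1[4], cube_2[4]),
    min(cube_1[5], cube_2[5]),
]
assert overlap == [6, 4, 8, 10, 10, 10]

# Cubes are separated iff, along some axis, one cube's start lies at least
# min_distance past the other cube's end.
min_distance = 1
separated = any(
    cube_1[axis] >= cube_2[axis + 3] + min_distance
    or cube_2[axis] >= cube_1[axis + 3] + min_distance
    for axis in range(3)
)
assert not separated  # these two cubes overlap, so voxel checks follow
```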
+ + Parameters + ---------- + volume: np.ndarray + A 3D NumPy array representing the volume from which the + overlapping region is extracted. + bounding_cube: tuple[float, float, float, float, float, float] + The bounding cube of the volume, given as a tuple of six floats: + `(x1, y1, z1, x2, y2, z2)`. The first three values define the + **top-left-front** corner, while the last three values define the + **bottom-right-back** corner. + overlapping_cube: tuple[float, float, float, float, float, float] + The overlapping region between the volume and another volume, + represented in the same format as `bounding_cube`. + + Returns + ------- + np.ndarray + A 3D NumPy array representing the portion of `volume` that + lies within `overlapping_cube`. If the overlap does not exist, + an empty array may be returned. + + Notes + ----- + - The method computes the relative indices of `overlapping_cube` + within `volume` by subtracting the bounding cube's starting + position. + - The extracted region is determined by integer indices, meaning + coordinates are implicitly **floored to integers**. + - If `overlapping_cube` extends beyond `volume` boundaries, the + returned subregion is **cropped** to fit within `volume`. + + """ + + # The position of the top left corner of the overlapping cube in the volume + overlapping_cube_position = np.array(overlapping_cube[:3]) - np.array( + bounding_cube[:3] + ) + + # The position of the bottom right corner of the overlapping cube in the volume + overlapping_cube_end_position = np.array( + overlapping_cube[3:] + ) - np.array(bounding_cube[:3]) + + # cast to int + overlapping_cube_position = overlapping_cube_position.astype(int) + overlapping_cube_end_position = overlapping_cube_end_position.astype(int) + + return volume[ + overlapping_cube_position[0] : overlapping_cube_end_position[0], + overlapping_cube_position[1] : overlapping_cube_end_position[1], + overlapping_cube_position[2] : overlapping_cube_end_position[2], + ] + + def _check_volumes_non_overlapping( + self: NonOverlapping, + volume_1: np.ndarray, + volume_2: np.ndarray, + min_distance: float, + ) -> bool: + """Determines whether the non-zero voxels in two 3D volumes are at + least `min_distance` apart. + + This method checks whether the active regions (non-zero voxels) in + `volume_1` and `volume_2` maintain a minimum separation of + `min_distance`. If the volumes differ in size, the positions of their + non-zero voxels are adjusted accordingly to ensure a fair comparison. + + Parameters + ---------- + volume_1: np.ndarray + A 3D NumPy array representing the first volume. + volume_2: np.ndarray + A 3D NumPy array representing the second volume. + min_distance: float + The minimum Euclidean distance required between any two non-zero + voxels in the two volumes. + + Returns + ------- + bool + `True` if all non-zero voxels in `volume_1` and `volume_2` are at + least `min_distance` apart, otherwise `False`. + + Notes + ----- + - This function assumes both volumes are correctly aligned within a + shared coordinate space. + - If the volumes are of different sizes, voxel positions are scaled + or adjusted for accurate distance measurement. + - Uses **Euclidean distance** for separation checking. + - If either volume is empty (i.e., no non-zero voxels), they are + considered non-overlapping. + + """ + + # Get the positions of the non-zero voxels of each volume. 
+ if self.get_backend() == "torch": + positions_1 = torch.nonzero(volume_1, as_tuple=False) + positions_2 = torch.nonzero(volume_2, as_tuple=False) + else: + positions_1 = np.argwhere(volume_1) + positions_2 = np.argwhere(volume_2) + + # if positions_1.size == 0 or positions_2.size == 0: + # return True # If either volume is empty, they are "non-overlapping" + + # # If the volumes are not the same size, the positions of the non-zero + # # voxels of each volume need to be scaled. + # if positions_1.size == 0 or positions_2.size == 0: + # return True # If either volume is empty, they are "non-overlapping" + + # If the volumes are not the same size, the positions of the non-zero + # voxels of each volume need to be scaled. + if volume_1.shape != volume_2.shape: + positions_1 = ( + positions_1 * np.array(volume_2.shape) + / np.array(volume_1.shape) + ) + positions_1 = positions_1.astype(int) + + # Check that the non-zero voxels of the volumes are at least + # min_distance apart. + if self.get_backend() == "torch": + dist = torch.cdist( + positions_1.float(), + positions_2.float(), + ) + return bool((dist > min_distance).all()) + else: + return np.all(cdist(positions_1, positions_2) > min_distance) + + def _resample_volume_position( + self: NonOverlapping, + volume: np.ndarray | Image, + ) -> Image: + """Resamples the position of a 3D volume using its internal position + sampler. + + This method updates the `position` property of the given `volume` by + drawing a new position from the `_position_sampler` stored in the + volume's `properties`. If the sampled position is a `Quantity`, it is + converted to pixel units. + + Parameters + ---------- + volume: np.ndarray + The 3D volume whose position is to be resampled. The volume must + have a `properties` attribute containing dictionaries with + `position` and `_position_sampler` keys. + + Returns + ------- + Image + The same input volume with its `position` property updated to the + newly sampled value. + + Notes + ----- + - The `_position_sampler` function is expected to return a **tuple of + three floats** (e.g., `(x, y, z)`). + - If the sampled position is a `Quantity`, it is converted to pixels. + - **Only** dictionaries in `volume.properties` that contain both + `position` and `_position_sampler` keys are modified. + + """ + + pdict = volume.properties + if "position" in pdict and "_position_sampler" in pdict: + new_position = pdict["_position_sampler"]() + if isinstance(new_position, Quantity): + new_position = new_position.to("pixel").magnitude + pdict["position"] = new_position + + return volume + + +class SampleToMasks(Feature): + """Create a mask from a list of images. + + This feature applies a transformation function to each input image and + merges the resulting masks into a single multi-layer image. Each input + image must have a `position` property that determines its placement within + the final mask. When used with scatterers, the `voxel_size` property must + be provided for correct object sizing. + + Parameters + ---------- + transformation_function: Callable[[Image], Image] + A function that transforms each input image into a mask with + `number_of_masks` layers. + number_of_masks: PropertyLike[int], optional + The number of mask layers to generate. Default is 1. + output_region: PropertyLike[tuple[int, int, int, int]], optional + The size and position of the output mask, typically aligned with + `optics.output_region`. 
+    merge_method: PropertyLike[str | Callable | list[str | Callable]], optional
+        Method for merging individual masks into the final image. Can be:
+        - "add" (default): Sum the masks.
+        - "overwrite": Later masks overwrite earlier masks.
+        - "or": Combine masks using a logical OR operation.
+        - "mul": Multiply masks.
+        - Function: Custom function taking two images and merging them.
+
+    **kwargs: dict[str, Any]
+        Additional keyword arguments passed to the parent `Feature` class.
+
+    Methods
+    -------
+    `get(image, transformation_function, **kwargs) -> np.ndarray`
+        Applies the transformation function to the input image.
+    `_process_and_get(images, **kwargs) -> np.ndarray`
+        Processes a list of images and generates a multi-layer mask.
+
+    Returns
+    -------
+    np.ndarray
+        The final mask image with the specified number of layers.
+
+    Raises
+    ------
+    ValueError
+        If `merge_method` is invalid.
+
+    Examples
+    --------
+    >>> import deeptrack as dt
+
+    Define the number of particles:
+
+    >>> n_particles = 12
+
+    Define optics and particles:
+
+    >>> import numpy as np
+    >>>
+    >>> optics = dt.Fluorescence(output_region=(0, 0, 64, 64))
+    >>> particle = dt.PointParticle(
+    ...     position=lambda: np.random.uniform(5, 55, size=2),
+    ... )
+    >>> particles = particle ^ n_particles
+
+    Define pipelines:
+
+    >>> sim_im_pip = optics(particles)
+    >>> sim_mask_pip = particles >> dt.SampleToMasks(
+    ...     lambda: lambda particles: particles > 0,
+    ...     output_region=optics.output_region,
+    ...     merge_method="or",
+    ... )
+    >>> pipeline = sim_im_pip & sim_mask_pip
+    >>> pipeline.store_properties()
+
+    Generate image and mask:
+
+    >>> image, mask = pipeline.update()()
+
+    Get particle positions:
+
+    >>> positions = np.array(image.get_property("position", get_one=False))
+
+    Visualize results:
+
+    >>> import matplotlib.pyplot as plt
+    >>>
+    >>> plt.subplot(1, 2, 1)
+    >>> plt.imshow(image, cmap="gray")
+    >>> plt.title("Original Image")
+    >>> plt.subplot(1, 2, 2)
+    >>> plt.imshow(mask, cmap="gray")
+    >>> plt.scatter(positions[:, 1], positions[:, 0], c="y", marker="x", s=50)
+    >>> plt.title("Mask")
+    >>> plt.show()
+
+    """
+
+    def __init__(
+        self: SampleToMasks,
+        transformation_function: Callable[
+            [np.ndarray | torch.Tensor], np.ndarray | torch.Tensor
+        ],
+        number_of_masks: PropertyLike[int] = 1,
+        output_region: PropertyLike[tuple[int, int, int, int]] = None,
+        merge_method: PropertyLike[str | Callable | list[str | Callable]] = "add",
+        **kwargs: Any,
+    ):
+        """Initialize the SampleToMasks feature.
+
+        Parameters
+        ----------
+        transformation_function: Callable[[np.ndarray | torch.Tensor], np.ndarray | torch.Tensor]
+            Function to transform input images into masks.
+        number_of_masks: PropertyLike[int], optional
+            Number of mask layers. Default is 1.
+        output_region: PropertyLike[tuple[int, int, int, int]], optional
+            Output region of the mask. Default is None.
+        merge_method: PropertyLike[str | Callable | list[str | Callable]], optional
+            Method to merge masks. Defaults to "add".
+        **kwargs: dict[str, Any]
+            Additional keyword arguments passed to the parent class.
+
+        """
+
+        super().__init__(
+            transformation_function=transformation_function,
+            number_of_masks=number_of_masks,
+            output_region=output_region,
+            merge_method=merge_method,
+            **kwargs,
+        )
+
+    def get(
+        self: SampleToMasks,
+        image: np.ndarray,
+        transformation_function: Callable[
+            [np.ndarray | torch.Tensor], np.ndarray | torch.Tensor
+        ],
+        **kwargs: Any,
+    ) -> np.ndarray:
+        """Apply the transformation function to a single image.
+
+        Parameters
+        ----------
+        image: np.ndarray
+            The input image.
+ transformation_function: Callable[[np.ndarray], np.ndarray] + Function to transform the image. + **kwargs: dict[str, Any] + Additional parameters. + + Returns + ------- + Image + The transformed image. + + """ + + return transformation_function(image.array) + + def _process_and_get( + self: Feature, + images: list[np.ndarray] | np.ndarray | list[torch.Tensor] | torch.Tensor, + **kwargs: Any, + ) -> np.ndarray: + """Process a list of images and generate a multi-layer mask. + + Parameters + ---------- + images: np.ndarray or list[np.ndarrray] or Image or list[Image] + List of input images or a single image. + **kwargs: dict[str, Any] + Additional parameters including `output_region`, `number_of_masks`, + and `merge_method`. + + Returns + ------- + Image or np.ndarray + The final mask image. + + """ + + # Handle list of images. + # if isinstance(images, list) and len(images) != 1: + list_of_labels = super()._process_and_get(images, **kwargs) + + from deeptrack.scatterers import ScatteredVolume + + for idx, (label, image) in enumerate(zip(list_of_labels, images)): + list_of_labels[idx] = \ + ScatteredVolume(array=label, properties=image.properties.copy()) + + # Create an empty output image. + output_region = kwargs["output_region"] + output = xp.zeros( + ( + output_region[2] - output_region[0], + output_region[3] - output_region[1], + kwargs["number_of_masks"], + ), + dtype=list_of_labels[0].array.dtype, + ) + + from deeptrack.optics import _get_position + + # Merge masks into the output. + for volume in list_of_labels: + label = volume.array + position = _get_position(volume) + + p0 = xp.round(position - xp.asarray(output_region[0:2])) + p0 = p0.astype(xp.int64) + + + if xp.any(p0 > xp.asarray(output.shape[:2])) or \ + xp.any(p0 + xp.asarray(label.shape[:2]) < 0): + continue + + crop_x = (-xp.minimum(p0[0], 0)).item() + crop_y = (-xp.minimum(p0[1], 0)).item() + + crop_x_end = int( + label.shape[0] + - np.max([p0[0] + label.shape[0] - output.shape[0], 0]) + ) + crop_y_end = int( + label.shape[1] + - np.max([p0[1] + label.shape[1] - output.shape[1], 0]) + ) + + labelarg = label[crop_x:crop_x_end, crop_y:crop_y_end, :] + + p0[0] = np.max([p0[0], 0]) + p0[1] = np.max([p0[1], 0]) + + p0 = p0.astype(int) + + output_slice = output[ + p0[0] : p0[0] + labelarg.shape[0], + p0[1] : p0[1] + labelarg.shape[1], + ] + + for label_index in range(kwargs["number_of_masks"]): + + if isinstance(kwargs["merge_method"], list): + merge = kwargs["merge_method"][label_index] + else: + merge = kwargs["merge_method"] + + if merge == "add": + output[ + p0[0] : p0[0] + labelarg.shape[0], + p0[1] : p0[1] + labelarg.shape[1], + label_index, + ] += labelarg[..., label_index] + + elif merge == "overwrite": + output_slice[ + labelarg[..., label_index] != 0, label_index + ] = labelarg[labelarg[..., label_index] != 0, \ + label_index] + output[ + p0[0] : p0[0] + labelarg.shape[0], + p0[1] : p0[1] + labelarg.shape[1], + label_index, + ] = output_slice[..., label_index] + + elif merge == "or": + output[ + p0[0] : p0[0] + labelarg.shape[0], + p0[1] : p0[1] + labelarg.shape[1], + label_index, + ] = xp.logical_or( + output_slice[..., label_index] != 0, + labelarg[..., label_index] != 0 + ) + + elif merge == "mul": + output[ + p0[0] : p0[0] + labelarg.shape[0], + p0[1] : p0[1] + labelarg.shape[1], + label_index, + ] *= labelarg[..., label_index] + + else: + # No match, assume function + output[ + p0[0] : p0[0] + labelarg.shape[0], + p0[1] : p0[1] + labelarg.shape[1], + label_index, + ] = merge( + output_slice[..., label_index], + 
labelarg[..., label_index], + ) + + return output + + #TODO ***??*** revise _get_position - torch, typing, docstring, unit test def _get_position( - image: Image, + scatterer: ScatteredObject, mode: str = "corner", return_z: bool = False, ) -> np.ndarray: @@ -1826,26 +3050,23 @@ def _get_position( num_outputs = 2 + return_z - if mode == "corner" and image.size > 0: + if mode == "corner" and scatterer.array.size > 0: import scipy.ndimage - image = image.to_numpy() - - shift = scipy.ndimage.center_of_mass(np.abs(image)) + shift = scipy.ndimage.center_of_mass(np.abs(scatterer.array)) if np.isnan(shift).any(): - shift = np.array(image.shape) / 2 + shift = np.array(scatterer.array.shape) / 2 else: shift = np.zeros((num_outputs)) - position = np.array(image.get_property("position", default=None)) + position = np.array(scatterer.get_property("position", default=None)) if position is None: return position scale = np.array(get_active_scale()) - if len(position) == 3: position = position * scale + 0.5 * (scale - 1) if return_z: @@ -1856,7 +3077,7 @@ def _get_position( elif len(position) == 2: if return_z: outp = ( - np.array([position[0], position[1], image.get_property("z", default=0)]) + np.array([position[0], position[1], scatterer.get_property("z", default=0)]) * scale - shift + 0.5 * (scale - 1) @@ -1868,6 +3089,58 @@ def _get_position( return position +def _bilinear_interpolate_numpy( + scatterer: np.ndarray, x_off: float, y_off: float +) -> np.ndarray: + """Apply bilinear subpixel interpolation in the x–y plane (NumPy).""" + kernel = np.array( + [ + [0.0, 0.0, 0.0], + [0.0, (1 - x_off) * (1 - y_off), (1 - x_off) * y_off], + [0.0, x_off * (1 - y_off), x_off * y_off], + ] + ) + out = np.zeros_like(scatterer) + for z in range(scatterer.shape[2]): + if np.iscomplexobj(scatterer): + out[:, :, z] = ( + convolve(np.real(scatterer[:, :, z]), kernel, mode="constant") + + 1j + * convolve(np.imag(scatterer[:, :, z]), kernel, mode="constant") + ) + else: + out[:, :, z] = convolve(scatterer[:, :, z], kernel, mode="constant") + return out + + +def _bilinear_interpolate_torch( + scatterer: torch.Tensor, x_off: float, y_off: float +) -> torch.Tensor: + """Apply bilinear subpixel interpolation in the x–y plane (Torch). + + Uses grid_sample for autograd-friendly interpolation. + """ + H, W, D = scatterer.shape + + # Normalized shifts in [-1,1] + x_shift = 2 * x_off / (W - 1) + y_shift = 2 * y_off / (H - 1) + + yy, xx = torch.meshgrid( + torch.linspace(-1, 1, H, device=scatterer.device, dtype=scatterer.dtype), + torch.linspace(-1, 1, W, device=scatterer.device, dtype=scatterer.dtype), + indexing="ij", + ) + grid = torch.stack((xx + x_shift, yy + y_shift), dim=-1) # (H,W,2) + grid = grid.unsqueeze(0).repeat(D, 1, 1, 1) # (D,H,W,2) + + inp = scatterer.permute(2, 0, 1).unsqueeze(1) # (D,1,H,W) + + out = F.grid_sample(inp, grid, mode="bilinear", + padding_mode="zeros", align_corners=True) + return out.squeeze(1).permute(1, 2, 0) # (H,W,D) + + #TODO ***??*** revise _create_volume - torch, typing, docstring, unit test def _create_volume( list_of_scatterers: list, @@ -1903,6 +3176,12 @@ def _create_volume( Spatial limits of the volume. """ + # contrast_type = kwargs.get("contrast_type", None) + # if contrast_type is None: + # raise RuntimeError( + # "_create_volume requires a contrast_type " + # "(e.g. 
'intensity' or 'refractive_index')" + # ) if not isinstance(list_of_scatterers, list): list_of_scatterers = [list_of_scatterers] @@ -1927,24 +3206,28 @@ def _create_volume( # This accounts for upscale doing AveragePool instead of SumPool. This is # a bit of a hack, but it works for now. - fudge_factor = scale[0] * scale[1] / scale[2] + # fudge_factor = scale[0] * scale[1] / scale[2] for scatterer in list_of_scatterers: - position = _get_position(scatterer, mode="corner", return_z=True) - if scatterer.get_property("intensity", None) is not None: - intensity = scatterer.get_property("intensity") - scatterer_value = intensity * fudge_factor - elif scatterer.get_property("refractive_index", None) is not None: - refractive_index = scatterer.get_property("refractive_index") - scatterer_value = ( - refractive_index - refractive_index_medium - ) - else: - scatterer_value = scatterer.get_property("value") + # if contrast_type == "intensity": + # value = scatterer.get_property("intensity", None) + # if value is None: + # raise ValueError("Scatterer has no intensity.") + # scatterer_value = value + + # elif contrast_type == "refractive_index": + # ri = scatterer.get_property("refractive_index", None) + # if ri is None: + # raise ValueError("Scatterer has no refractive_index.") + # scatterer_value = ri - refractive_index_medium + + # else: + # raise RuntimeError(f"Unknown contrast_type: {contrast_type}") - scatterer = scatterer * scatterer_value + # # Scale the array accordingly + # scatterer.array = scatterer.array * scatterer_value if limits is None: limits = np.zeros((3, 2), dtype=np.int32) @@ -1952,26 +3235,25 @@ def _create_volume( limits[:, 1] = np.floor(position).astype(np.int32) + 1 if ( - position[0] + scatterer.shape[0] < OR[0] + position[0] + scatterer.array.shape[0] < OR[0] or position[0] > OR[2] - or position[1] + scatterer.shape[1] < OR[1] + or position[1] + scatterer.array.shape[1] < OR[1] or position[1] > OR[3] ): continue - padded_scatterer = Image( - np.pad( - scatterer, + # Pad scatterer to avoid edge effects during interpolation + padded_scatterer_arr = np.pad( #Use Pad instead and make it torch-compatible? 
+ scatterer.array, [(2, 2), (2, 2), (2, 2)], "constant", constant_values=0, ) - ) - padded_scatterer.merge_properties_from(scatterer) - - scatterer = padded_scatterer - position = _get_position(scatterer, mode="corner", return_z=True) - shape = np.array(scatterer.shape) + padded_scatterer = ScatteredVolume( + array=padded_scatterer_arr, properties=scatterer.properties.copy(), + ) + position = _get_position(padded_scatterer, mode="corner", return_z=True) + shape = np.array(padded_scatterer.array.shape) if position is None: RuntimeWarning( @@ -1980,36 +3262,20 @@ def _create_volume( ) continue - splined_scatterer = np.zeros_like(scatterer) - x_off = position[0] - np.floor(position[0]) y_off = position[1] - np.floor(position[1]) - kernel = np.array( - [ - [0, 0, 0], - [0, (1 - x_off) * (1 - y_off), (1 - x_off) * y_off], - [0, x_off * (1 - y_off), x_off * y_off], - ] - ) - - for z in range(scatterer.shape[2]): - if splined_scatterer.dtype == complex: - splined_scatterer[:, :, z] = ( - convolve( - np.real(scatterer[:, :, z]), kernel, mode="constant" - ) - + convolve( - np.imag(scatterer[:, :, z]), kernel, mode="constant" - ) - * 1j - ) - else: - splined_scatterer[:, :, z] = convolve( - scatterer[:, :, z], kernel, mode="constant" - ) + + if isinstance(padded_scatterer.array, np.ndarray): # get_backend is a method of Features and not exposed + splined_scatterer = _bilinear_interpolate_numpy(padded_scatterer.array, x_off, y_off) + elif isinstance(padded_scatterer.array, torch.Tensor): + splined_scatterer = _bilinear_interpolate_torch(padded_scatterer.array, x_off, y_off) + else: + raise TypeError( + f"Unsupported array type {type(padded_scatterer.array)}. " + "Expected np.ndarray or torch.Tensor." + ) - scatterer = splined_scatterer position = np.floor(position) new_limits = np.zeros(limits.shape, dtype=np.int32) for i in range(3): @@ -2038,7 +3304,8 @@ def _create_volume( within_volume_position = position - limits[:, 0] - # NOTE: Maybe shouldn't be additive. + # NOTE: Maybe shouldn't be ONLY additive. 
+ # give options: sum default, but also mean, max, min, or volume[ int(within_volume_position[0]) : int(within_volume_position[0] + shape[0]), @@ -2048,5 +3315,5 @@ def _create_volume( int(within_volume_position[2]) : int(within_volume_position[2] + shape[2]), - ] += scatterer - return volume, limits + ] += splined_scatterer + return volume, limits \ No newline at end of file diff --git a/deeptrack/scatterers.py b/deeptrack/scatterers.py index 04a7c5ea..b7e1b70a 100644 --- a/deeptrack/scatterers.py +++ b/deeptrack/scatterers.py @@ -166,6 +166,7 @@ import numpy as np from numpy.typing import NDArray from pint import Quantity +from dataclasses import dataclass, field from deeptrack.holography import get_propagation_matrix from deeptrack.backend.units import ( @@ -174,12 +175,14 @@ get_active_voxel_size, ) from deeptrack.backend import mie +from deeptrack.math import AveragePooling from deeptrack.features import Feature, MERGE_STRATEGY_APPEND -from deeptrack.image import pad_image_to_fft, Image +from deeptrack.image import pad_image_to_fft from deeptrack.types import ArrayLike from deeptrack import units_registry as u + __all__ = [ "Scatterer", "PointParticle", @@ -238,7 +241,7 @@ class Scatterer(Feature): """ - __list_merge_strategy__ = MERGE_STRATEGY_APPEND + __list_merge_strategy__ = MERGE_STRATEGY_APPEND ### Not clear why needed __distributed__ = False __conversion_table__ = ConversionTable( position=(u.pixel, u.pixel), @@ -258,11 +261,11 @@ def __init__( **kwargs, ) -> None: # Ignore warning to help with comparison with arrays. - if upsample is not 1: # noqa: F632 - warnings.warn( - f"Setting upsample != 1 is deprecated. " - f"Please, instead use dt.Upscale(f, factor={upsample})" - ) + # if upsample != 1: # noqa: F632 + # warnings.warn( + # f"Setting upsample != 1 is deprecated. " + # f"Please, instead use dt.Upscale(f, factor={upsample})" + # ) self._processed_properties = False @@ -278,6 +281,21 @@ def __init__( **kwargs, ) + def _antialias_volume(self, volume, factor: int): + """Geometry-only supersampling anti-aliasing. + + Assumes `volume` was generated on a grid oversampled by `factor` + and downsamples it back by average pooling. + """ + if factor == 1: + return volume + + # average pooling conserves fractional occupancy + return AveragePooling( + factor + )(volume) + + def _process_properties( self, properties: dict @@ -296,7 +314,7 @@ def _process_and_get( upsample_axes=None, crop_empty=True, **kwargs - ) -> list[Image] | list[np.ndarray]: + ) -> list[np.ndarray]: # Post processes the created object to handle upsampling, # as well as cropping empty slices. if not self._processed_properties: @@ -307,16 +325,31 @@ def _process_and_get( + "Optics.upscale != 1." ) - voxel_size = get_active_voxel_size() - # Calls parent _process_and_get. 
-        new_image = super()._process_and_get(
+        voxel_size = np.asarray(get_active_voxel_size(), float)
+
+        # Geometry supersampling only applies to voxelized volumes; field
+        # scatterers are evaluated analytically and need no anti-aliasing.
+        apply_supersampling = upsample > 1 and isinstance(self, VolumeScatterer)
+
+        if upsample > 1 and not apply_supersampling:
+            warnings.warn(
+                "Geometry supersampling (upsample) is ignored for "
+                "FieldScatterers.",
+                UserWarning,
+            )
+
+        if apply_supersampling:
+            voxel_size /= float(upsample)
+
+        new_image = super(Scatterer, self)._process_and_get(
             *args,
             voxel_size=voxel_size,
             upsample=upsample,
             **kwargs,
-        )
-        new_image = new_image[0]
+        )[0]
+
+        if apply_supersampling:
+            new_image = self._antialias_volume(new_image, factor=upsample)

         if new_image.size == 0:
             warnings.warn(
@@ -333,32 +366,35 @@ def _process_and_get(
             new_image = new_image[:, ~np.all(new_image == 0, axis=(0, 2))]
             new_image = new_image[:, :, ~np.all(new_image == 0, axis=(0, 1))]

-        return [Image(new_image)]
+        return [self._wrap_output(new_image, kwargs)]

-    def _no_wrap_format_input(
-        self,
-        *args,
-        **kwargs
-    ) -> list:
-        return self._image_wrapped_format_input(*args, **kwargs)
+    def _wrap_output(self, array, props):
+        raise NotImplementedError(
+            f"{self.__class__.__name__} must implement _wrap_output()"
+        )

-    def _no_wrap_process_and_get(
-        self,
-        *args,
-        **feature_input
-    ) -> list:
-        return self._image_wrapped_process_and_get(*args, **feature_input)

-    def _no_wrap_process_output(
-        self,
-        *args,
-        **feature_input
-    ) -> list:
-        return self._image_wrapped_process_output(*args, **feature_input)
+class VolumeScatterer(Scatterer):
+    """Abstract scatterer producing ScatteredVolume outputs."""
+
+    def _wrap_output(self, array, props) -> ScatteredVolume:
+        return ScatteredVolume(
+            array=array,
+            properties=props.copy(),
+        )
+
+
+class FieldScatterer(Scatterer):
+    """Abstract scatterer producing ScatteredField outputs."""
+
+    def _wrap_output(self, array, props) -> ScatteredField:
+        return ScatteredField(
+            array=array,
+            properties=props.copy(),
+        )


 #TODO ***??*** revise PointParticle - torch, typing, docstring, unit test
-class PointParticle(Scatterer):
+class PointParticle(VolumeScatterer):
     """Generate a diffraction-limited point particle.

     A point particle is approximated by the size of a single pixel or voxel.
@@ -389,12 +425,12 @@ def __init__(

         """
         """
-
+        kwargs.pop("upsample", None)
         super().__init__(upsample=1, upsample_axes=(), **kwargs)

     def get(
         self: PointParticle,
-        image: Image | np.ndarray,
+        image: np.ndarray,
         **kwarg: Any,
     ) -> NDArray[Any] | torch.Tensor:
         """Evaluate and return the scatterer volume."""
@@ -405,7 +441,7 @@ def get(


 #TODO ***??*** revise Ellipse - torch, typing, docstring, unit test
-class Ellipse(Scatterer):
+class Ellipse(VolumeScatterer):
     """Generates an elliptical disk scatterer

     Parameters
@@ -441,6 +477,7 @@ class Ellipse(Scatterer):

     """

+
     __conversion_table__ = ConversionTable(
         radius=(u.meter, u.meter),
         rotation=(u.radian, u.radian),
@@ -519,7 +556,7 @@ def get(


 #TODO ***??*** revise Sphere - torch, typing, docstring, unit test
-class Sphere(Scatterer):
+class Sphere(VolumeScatterer):
     """Generates a spherical scatterer

     Parameters
@@ -559,7 +596,7 @@ def __init__(

     def get(
         self,
-        image: Image | np.ndarray,
+        image: np.ndarray,
         radius: float,
         voxel_size: float,
         **kwargs
@@ -584,7 +621,7 @@ def get(


 #TODO ***??*** revise Ellipsoid - torch, typing, docstring, unit test
-class Ellipsoid(Scatterer):
+class Ellipsoid(VolumeScatterer):
     """Generates an ellipsoidal scatterer

     Parameters
@@ -694,7 +731,7 @@ def _process_properties(

     def get(
         self,
-        image: Image | np.ndarray,
+        image: np.ndarray,
         radius: float,
         rotation: ArrayLike[float] | float,
         voxel_size: float,
@@ -741,7 +778,7 @@ def get(


 #TODO ***??*** revise MieScatterer - torch, typing, docstring, unit test
-class MieScatterer(Scatterer):
+class MieScatterer(FieldScatterer):
     """Base implementation of a Mie particle.

     New Mie-theory scatterers can be implemented by extending this class, and
@@ -826,6 +863,7 @@ class MieScatterer(Scatterer):

     """

+
     __conversion_table__ = ConversionTable(
         radius=(u.meter, u.meter),
         polarization_angle=(u.radian, u.radian),
@@ -856,6 +894,7 @@ def __init__(
         illumination_angle: float=0,
         amp_factor: float=1,
         phase_shift_correction: bool=False,
+        # pupil: ArrayLike = [],  # TODO ***Daniel*** consider an explicit pupil parameter.
         **kwargs,
     ) -> None:
         if polarization_angle is not None:
@@ -864,11 +903,10 @@ def __init__(
                 "Please use input_polarization instead"
             )
             input_polarization = polarization_angle
-        kwargs.pop("is_field", None)
         kwargs.pop("crop_empty", None)
         super().__init__(
-            is_field=True,
+            is_field=True,  # TODO: remove; FieldScatterer now encodes this.
             crop_empty=False,
             L=L,
             offset_z=offset_z,
@@ -889,6 +927,7 @@ def __init__(
             illumination_angle=illumination_angle,
             amp_factor=amp_factor,
             phase_shift_correction=phase_shift_correction,
+            # pupil=pupil,  # TODO ***Daniel*** see above.
             **kwargs,
         )
@@ -1014,7 +1053,8 @@ def get_plane_in_polar_coords(
         shape: int,
         voxel_size: ArrayLike[float],
         plane_position: float,
-        illumination_angle: float
+        illumination_angle: float,
+        # k: float,  # TODO ***Daniel*** needed for the sin(theta) pupil mask below.
     ) -> tuple[float, float, float, float]:
         """Computes the coordinates of the plane in polar form."""
@@ -1027,15 +1067,24 @@ def get_plane_in_polar_coords(

         R2_squared = X ** 2 + Y ** 2
         R3 = np.sqrt(R2_squared + Z ** 2)
         # Might be +z instead of -z.
+
+        # TODO ***Daniel*** Alternative pupil mask from the scattering angle:
+        #     Q = np.sqrt(R2_squared) / voxel_size[0] ** 2 * 2 * np.pi / shape[0]
+        #     sin_theta = Q / k  # Check that this is dimensionally consistent.
+        #     pupil_mask = sin_theta < 1
+        #     cos_theta = np.zeros(sin_theta.shape)
+        #     cos_theta[pupil_mask] = np.sqrt(1 - sin_theta[pupil_mask] ** 2)

         # Get the angles.
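+        # cos_theta is the axial direction cosine of each grid point; phi
+        # (computed below) is the azimuthal angle in the xy-plane.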
         cos_theta = Z / R3
+
         illumination_cos_theta = (
             np.cos(np.arccos(cos_theta) + illumination_angle)
         )
         phi = np.arctan2(Y, X)

-        return R3, cos_theta, illumination_cos_theta, phi
+        # TODO ***Daniel*** also return pupil_mask once implemented.
+        return R3, cos_theta, illumination_cos_theta, phi

     def get(
         self,
@@ -1060,6 +1109,7 @@ def get(
         illumination_angle: float,
         amp_factor: float,
         phase_shift_correction: bool,
+        # pupil: ArrayLike,  # TODO ***Daniel*** see __init__ above.
         **kwargs,
     ) -> ArrayLike[float]:
         """Abstract method to initialize the Mie scatterer"""

         # Get size of the output.
         xSize, ySize = self.get_xy_size(output_region, padding)
         voxel_size = get_active_voxel_size()
+        scale = get_active_scale()
         arr = pad_image_to_fft(np.zeros((xSize, ySize))).astype(complex)
-        position = np.array(position) * voxel_size[: len(position)]
+        position = (
+            np.array(position)
+            * scale[: len(position)]
+            * voxel_size[: len(position)]
+        )

         pupil_physical_size = working_distance * np.tan(collection_angle) * 2
@@ -1076,7 +1127,10 @@ def get(

         ratio = offset_z / (working_distance - z)

-        # Position of pbjective relative particle.
+        # Wave vector.
+        k = 2 * np.pi / wavelength * refractive_index_medium
+
+        # Position of the objective relative to the particle.
         relative_position = np.array(
             (
                 position_objective[0] - position[0],
@@ -1085,12 +1139,13 @@ def get(
             )
         )

-        # Get field evaluation plane at offset_z.
+        # Get field evaluation plane at offset_z.
+        # TODO ***Daniel*** also unpack pupil_mask here once it is returned.
         R3_field, cos_theta_field, illumination_angle_field, phi_field =\
             self.get_plane_in_polar_coords(
                 arr.shape,
                 voxel_size,
                 relative_position * ratio,
-                illumination_angle
+                illumination_angle,
+                # k,  # TODO ***Daniel*** see get_plane_in_polar_coords.
             )

         cos_phi_field, sin_phi_field = np.cos(phi_field), np.sin(phi_field)
@@ -1108,7 +1163,7 @@ def get(
             sin_phi_field / ratio
         )

-        # If the beam is within the pupil.
+        # If the beam is within the pupil. TODO ***Daniel*** remove once
+        # the sin(theta)-based pupil mask is adopted.
         pupil_mask = (x_farfield - position_objective[0]) ** 2 + (
             y_farfield - position_objective[1]
         ) ** 2 < (pupil_physical_size / 2) ** 2
@@ -1146,9 +1201,6 @@ def get(
             * illumination_angle_field
         )

-        # Wave vector.
-        k = 2 * np.pi / wavelength * refractive_index_medium
-
         # Harmonics.
         A, B = coefficients(L)
         PI, TAU = mie.harmonics(illumination_angle_field, L)
@@ -1165,12 +1217,15 @@ def get(
             [E[i] * B[i] * PI[i] + E[i] * A[i] * TAU[i] for i in range(0, L)]
         )

+        # TODO ***Daniel*** simpler alternative without the spherical-wave factor:
+        #     arr[pupil_mask] = (S2 * S2_coef + S1 * S1_coef) / amp_factor
         arr[pupil_mask] = (
             -1j / (k * R3_field) * np.exp(1j * k * R3_field)
             * (S2 * S2_coef + S1 * S1_coef)
         ) / amp_factor
+
         # For phase shift correction (a multiplication of the field
         # by exp(1j * k * z)).
@@ -1188,15 +1243,23 @@ def get(
             -mask.shape[1] // 2 : mask.shape[1] // 2,
         ]
         mask = np.exp(-0.5 * (x ** 2 + y ** 2) / ((sigma) ** 2))
-            arr = arr * mask
+        # TODO ***CM*** Not sure whether this is needed:
+        # if len(pupil) > 0:
+        #     c_pix = [arr.shape[0] // 2, arr.shape[1] // 2]
+        #     arr[
+        #         c_pix[0] - pupil.shape[0] // 2 : c_pix[0] + pupil.shape[0] // 2,
+        #         c_pix[1] - pupil.shape[1] // 2 : c_pix[1] + pupil.shape[1] // 2,
+        #     ] *= pupil
+
+        # TODO ***Daniel*** alternative FFT convention:
+        #     fourier_field = -np.fft.ifft2(np.fft.fftshift(np.fft.fft2(np.fft.fftshift(arr))))
         fourier_field = np.fft.fft2(arr)

         propagation_matrix = get_propagation_matrix(
             fourier_field.shape,
-            pixel_size=voxel_size[2],
+            pixel_size=voxel_size[:2],  # TODO: double-check; was voxel_size[2].
             wavelength=wavelength / refractive_index_medium,
+            # to_z=(-z),  # TODO ***Daniel*** alternative propagation distance.
             to_z=(-offset_z - z),
             dy=(
                 relative_position[0] * ratio
@@ -1206,11 +1269,12 @@ def get(
             dx=(
                 relative_position[1] * ratio
                 + position[1]
-                + (padding[1] - arr.shape[1] / 2) * voxel_size[1]
+                # TODO: verify the padding order (top, bottom, left, right).
+                + (padding[2] - arr.shape[1] / 2) * voxel_size[1]
             ),
         )
+
+        # TODO ***Daniel*** drop the exp(-1j * k * offset_z) factor if the
+        # alternative FFT convention above is adopted.
         fourier_field = (
             fourier_field * propagation_matrix * np.exp(-1j * k * offset_z)
         )

         if return_fft:
@@ -1275,6 +1339,7 @@ class MieSphere(MieScatterer):

     """

+
     def __init__(
         self,
         radius: float = 1e-6,
@@ -1377,6 +1442,7 @@ class MieStratifiedSphere(MieScatterer):

     """

+
     def __init__(
         self,
         radius: ArrayLike[float] = [1e-6],
@@ -1412,3 +1478,62 @@ def inner(
             refractive_index=refractive_index,
             **kwargs,
         )
+
+
+@dataclass
+class ScatteredBase:
+    """Base container for scatterer outputs (volumes and fields)."""
+
+    array: np.ndarray | torch.Tensor
+    properties: dict[str, Any] = field(default_factory=dict)
+
+    @property
+    def ndim(self) -> int:
+        """Number of dimensions of the underlying array."""
+        return self.array.ndim
+
+    @property
+    def shape(self) -> tuple[int, ...]:
+        """Shape of the underlying array."""
+        return self.array.shape
+
+    @property
+    def z(self) -> float:
+        """Axial position, read from the `z` property (0.0 if absent)."""
+        return float(self.properties.get("z", 0.0))
+
+    @property
+    def pos3d(self) -> np.ndarray:
+        """Position as (x, y, z), combining `position` and `z`."""
+        return np.array([*self.position, self.z], dtype=float)
+
+    @property
+    def position(self) -> np.ndarray | None:
+        """Lateral position from the properties dict, or None if unset."""
+        pos = self.properties.get("position", None)
+        if pos is None:
+            return None
+        pos = np.asarray(pos, dtype=float)
+        if pos.ndim == 2 and pos.shape[0] == 1:
+            pos = pos[0]
+        return pos
+
+    def as_array(self) -> ArrayLike:
+        """Return the underlying array.
+
+        Notes
+        -----
+        The raw array is also directly available as ``scatterer.array``.
+        This method exists mainly for API compatibility and clarity.
+
+        """
+
+        return self.array
+
+    def get_property(self, key: str, default: Any = None) -> Any:
+        """Look up `key` as an attribute first, then in the properties dict."""
+        return getattr(self, key, self.properties.get(key, default))
+
+
+@dataclass
+class ScatteredVolume(ScatteredBase):
+    """Voxelized volume produced by a VolumeScatterer."""
+
+
+@dataclass
+class ScatteredField(ScatteredBase):
+    """Complex field produced by a FieldScatterer."""
\ No newline at end of file
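---

Reviewer note: a minimal usage sketch of the new container types, assuming the `z` accessor reads from the `properties` dict as `get_property` does; the property values below are illustrative only.

```python
import numpy as np
from deeptrack.scatterers import ScatteredVolume

# A 3x3x3 refractive-index volume placed at (x, y) = (16, 16), z = 2.
vol = ScatteredVolume(
    array=np.full((3, 3, 3), 1.45),
    properties={"position": [16.0, 16.0], "z": 2.0},
)

print(vol.shape)      # (3, 3, 3), forwarded from vol.array
print(vol.position)   # [16. 16.]
print(vol.pos3d)      # [16. 16.  2.]
print(vol.get_property("upsample", default=1))  # 1, falls back to the default
```

Reviewer note: `_create_volume` above dispatches to `_bilinear_interpolate_numpy` / `_bilinear_interpolate_torch`, which are not shown in this diff. Below is a sketch of the expected NumPy behaviour, assuming it reproduces the removed 3x3 kernel convolution; the helper name and signature are taken from the call site.

```python
import numpy as np

def _bilinear_interpolate_numpy(volume, x_off, y_off):
    """Shift `volume` by the subpixel offset (x_off, y_off) in the xy-plane.

    Each output voxel is the weighted sum of its four xy-neighbours, with
    weights (1-x)(1-y), (1-x)y, x(1-y), and xy, and zero padding at the
    border, matching the removed per-slice convolution (mode="constant").
    """
    out = np.zeros_like(volume)
    out += (1 - x_off) * (1 - y_off) * volume           # in[i, j]
    out[:, 1:] += (1 - x_off) * y_off * volume[:, :-1]  # in[i, j-1]
    out[1:, :] += x_off * (1 - y_off) * volume[:-1, :]  # in[i-1, j]
    out[1:, 1:] += x_off * y_off * volume[:-1, :-1]     # in[i-1, j-1]
    return out
```

Unlike the removed loop, this operates on all z-slices at once and handles complex dtypes directly, since no real/imaginary split is needed for plain array arithmetic.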