From 8a641a98a1449ba2f683cc0a065fcde87509aae7 Mon Sep 17 00:00:00 2001 From: Mirja Granfors <95694095+mirjagranfors@users.noreply.github.com> Date: Wed, 5 Nov 2025 16:15:19 +0100 Subject: [PATCH 01/24] Update README (#442) * Update README Added links to the tutorials for creating custom scatterers. * Update README * Update README.md --------- Co-authored-by: Alex <95913221+Pwhsky@users.noreply.github.com> --- README.md | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/README.md b/README.md index 674d41d3..a9f507c2 100644 --- a/README.md +++ b/README.md @@ -97,6 +97,14 @@ Here you find a series of notebooks that give you an overview of the core featur Using PyTorch gradients to fit a Gaussian generated by a DeepTrack2 pipeline. +- DTGS171A **[Creating Custom Scatterers](https://github.com/DeepTrackAI/DeepTrack2/blob/develop/tutorials/1-getting-started/DTGS171A_custom_scatterers.ipynb)** + + Creating custom scatterers of arbitrary shapes. + +- DTGS171B **[Creating Custom Scatterers: Bacteria](https://github.com/DeepTrackAI/DeepTrack2/blob/develop/tutorials/1-getting-started/DTGS171B_custom_scatterers_bacteria.ipynb)** + + Creating custom scatterers in the shape of bacteria. + # Examples These are examples of how DeepTrack2 can be used on real datasets: From 5e402dba9807de374fceea5e7843a5c0677070a8 Mon Sep 17 00:00:00 2001 From: github-actions Date: Wed, 5 Nov 2025 15:15:38 +0000 Subject: [PATCH 02/24] Auto-update README-pypi.md --- README-pypi.md | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/README-pypi.md b/README-pypi.md index be0de197..f6173669 100644 --- a/README-pypi.md +++ b/README-pypi.md @@ -93,6 +93,14 @@ Here you find a series of notebooks that give you an overview of the core featur Using PyTorch gradients to fit a Gaussian generated by a DeepTrack2 pipeline. +- DTGS171A **[Creating Custom Scatterers](https://github.com/DeepTrackAI/DeepTrack2/blob/develop/tutorials/1-getting-started/DTGS171A_custom_scatterers.ipynb)** + + Creating custom scatterers of arbitrary shapes. + +- DTGS171B **[Creating Custom Scatterers: Bacteria](https://github.com/DeepTrackAI/DeepTrack2/blob/develop/tutorials/1-getting-started/DTGS171B_custom_scatterers_bacteria.ipynb)** + + Creating custom scatterers in the shape of bacteria. + # Examples These are examples of how DeepTrack2 can be used on real datasets: From a045633e747dde090c2b84bd8c6b3ced0a485d5c Mon Sep 17 00:00:00 2001 From: Carlo Date: Fri, 2 Jan 2026 20:33:56 +0100 Subject: [PATCH 03/24] removed Image --- deeptrack/features.py | 10 ++++---- deeptrack/optics.py | 55 ++++++++++++++++++----------------------- deeptrack/scatterers.py | 2 +- 3 files changed, 30 insertions(+), 37 deletions(-) diff --git a/deeptrack/features.py b/deeptrack/features.py index 4bdfad38..26464d4c 100644 --- a/deeptrack/features.py +++ b/deeptrack/features.py @@ -8093,12 +8093,12 @@ def get( with units.context(ctx): image = self.feature(image) - # Downscale the result to the original resolution. - import skimage.measure + # # Downscale the result to the original resolution. 
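        # # (In this series, the downscaling step is taken over by
        # # Microscope.get in optics.py, which applies the center-preserving
        # # pooling helpers added in a later patch.)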
+ # import skimage.measure - image = skimage.measure.block_reduce( - image, (factor[0], factor[1]) + (1,) * (image.ndim - 2), np.mean - ) + # image = skimage.measure.block_reduce( + # image, (factor[0], factor[1]) + (1,) * (image.ndim - 2), np.mean + # ) return image diff --git a/deeptrack/optics.py b/deeptrack/optics.py index 5149bdae..be1217b2 100644 --- a/deeptrack/optics.py +++ b/deeptrack/optics.py @@ -344,11 +344,11 @@ def get( volume_samples, **additional_sample_kwargs, ) - sample_volume = Image(sample_volume) + # sample_volume = Image(sample_volume) # Merge all properties into the volume. - for scatterer in volume_samples + field_samples: - sample_volume.merge_properties_from(scatterer) + # for scatterer in volume_samples + field_samples: + # sample_volume.merge_properties_from(scatterer) # Let the objective know about the limits of the volume and all the fields. propagate_data_to_dependencies( @@ -365,18 +365,18 @@ def get( imaged_sample ) - # Merge with input - if not image: - if not self._wrap_array_with_image and isinstance(imaged_sample, Image): - return imaged_sample._value - else: - return imaged_sample + # # Merge with input + # if not image: + # if not self._wrap_array_with_image and isinstance(imaged_sample, Image): + # return imaged_sample._value + # else: + # return imaged_sample - if not isinstance(image, list): - image = [image] - for i in range(len(image)): - image[i].merge_properties_from(imaged_sample) - return image + # if not isinstance(image, list): + # image = [image] + # for i in range(len(image)): + # image[i].merge_properties_from(imaged_sample) + # return image # def _no_wrap_format_input(self, *args, **kwargs) -> list: # return self._image_wrapped_format_input(*args, **kwargs) @@ -757,19 +757,18 @@ def _pupil( W, H = np.meshgrid(y, x) RHO = (W ** 2 + H ** 2).astype(complex) - pupil_function = Image((RHO < 1) + 0.0j, copy=False) + pupil_function = (RHO < 1) + 0.0j # Defocus - z_shift = Image( + z_shift = ( 2 * np.pi * refractive_index_medium / wavelength * voxel_size[2] - * np.sqrt(1 - (NA / refractive_index_medium) ** 2 * RHO), - copy=False, + * np.sqrt(1 - (NA / refractive_index_medium) ** 2 * RHO) ) - z_shift._value[z_shift._value.imag != 0] = 0 + z_shift[z_shift.imag != 0] = 0 try: z_shift = np.nan_to_num(z_shift, False, 0, 0, 0) @@ -1118,9 +1117,7 @@ def get( ] z_limits = limits[2, :] - output_image = Image( - np.zeros((*padded_volume.shape[0:2], 1)), copy=False - ) + output_image = np.zeros((*padded_volume.shape[0:2], 1)) index_iterator = range(padded_volume.shape[2]) @@ -1156,7 +1153,7 @@ def get( field = np.fft.ifft2(convolved_fourier_field) # # Discard remaining imaginary part (should be 0 up to rounding error) field = np.real(field) - output_image._value[:, :, 0] += field[ + output_image[:, :, 0] += field[ : padded_volume.shape[0], : padded_volume.shape[1] ] @@ -1353,9 +1350,7 @@ def get( ] z_limits = limits[2, :] - output_image = Image( - np.zeros((*padded_volume.shape[0:2], 1)) - ) + output_image = np.zeros((*padded_volume.shape[0:2], 1)) index_iterator = range(padded_volume.shape[2]) z_iterator = np.linspace( @@ -1426,7 +1421,7 @@ def get( : padded_volume.shape[0], : padded_volume.shape[1] ] output_image = np.expand_dims(output_image, axis=-1) - output_image = Image(output_image[pad[0] : -pad[2], pad[1] : -pad[3]]) + output_image = output_image[pad[0] : -pad[2], pad[1] : -pad[3]] if not kwargs.get("return_field", False): output_image = np.square(np.abs(output_image)) @@ -1436,7 +1431,7 @@ def get( # output_image = output_image * np.exp(1j * 
-np.pi / 4) # output_image = output_image + 1 - output_image.properties = illuminated_volume.properties + # output_image.properties = illuminated_volume.properties return output_image @@ -1959,14 +1954,12 @@ def _create_volume( ): continue - padded_scatterer = Image( - np.pad( + padded_scatterer = np.pad( scatterer, [(2, 2), (2, 2), (2, 2)], "constant", constant_values=0, ) - ) padded_scatterer.merge_properties_from(scatterer) scatterer = padded_scatterer diff --git a/deeptrack/scatterers.py b/deeptrack/scatterers.py index 04a7c5ea..ed77e5ed 100644 --- a/deeptrack/scatterers.py +++ b/deeptrack/scatterers.py @@ -333,7 +333,7 @@ def _process_and_get( new_image = new_image[:, ~np.all(new_image == 0, axis=(0, 2))] new_image = new_image[:, :, ~np.all(new_image == 0, axis=(0, 1))] - return [Image(new_image)] + return [new_image] def _no_wrap_format_input( self, From d61ef780a50b9e72dab03035b4b9d1d7ccf6c381 Mon Sep 17 00:00:00 2001 From: Carlo Date: Sat, 3 Jan 2026 02:40:04 +0100 Subject: [PATCH 04/24] ok --- deeptrack/features.py | 2 +- deeptrack/optics.py | 311 ++++++++++++++++++++++++++++------------ deeptrack/scatterers.py | 136 +++++++++++++----- 3 files changed, 326 insertions(+), 123 deletions(-) diff --git a/deeptrack/features.py b/deeptrack/features.py index 26464d4c..f76245b8 100644 --- a/deeptrack/features.py +++ b/deeptrack/features.py @@ -8082,7 +8082,7 @@ def get( # Ensure factor is a tuple of three integers. if np.size(factor) == 1: - factor = (factor,) * 3 + factor = (factor, factor, 1) elif len(factor) != 3: raise ValueError( "Factor must be an integer or a tuple of three integers." diff --git a/deeptrack/optics.py b/deeptrack/optics.py index be1217b2..5564e0b2 100644 --- a/deeptrack/optics.py +++ b/deeptrack/optics.py @@ -137,11 +137,13 @@ def _pad_volume( from __future__ import annotations from pint import Quantity -from typing import Any +from typing import Any, TYPE_CHECKING import warnings import numpy as np -from scipy.ndimage import convolve +from scipy.ndimage import convolve # might be removed later +import torch +import torch.nn.functional as F from deeptrack.backend.units import ( ConversionTable, @@ -158,6 +160,16 @@ def _pad_volume( from deeptrack import image from deeptrack import units_registry as u +from deeptrack import TORCH_AVAILABLE, image +from deeptrack.backend import xp +from deeptrack.scatterers import ScatteredVolume, ScatteredField + +if TORCH_AVAILABLE: + import torch + +if TYPE_CHECKING: + import torch + #TODO ***??*** revise Microscope - torch, typing, docstring, unit test class Microscope(StructuralFeature): @@ -280,16 +292,16 @@ def get( # Grab properties from the objective to pass to the sample additional_sample_kwargs = self._objective.properties() - # Calculate required output image for the given upscale - # This way of providing the upscale will be deprecated in the future - # in favor of dt.Upscale(). - _upscale_given_by_optics = additional_sample_kwargs["upscale"] - if np.array(_upscale_given_by_optics).size == 1: - _upscale_given_by_optics = (_upscale_given_by_optics,) * 3 + # # Calculate required output image for the given upscale + # # This way of providing the upscale will be deprecated in the future + # # in favor of dt.Upscale(). 
+ # _upscale_given_by_optics = additional_sample_kwargs["upscale"] + # if np.array(_upscale_given_by_optics).size == 1: + # _upscale_given_by_optics = (_upscale_given_by_optics,) * 3 with u.context( create_context( - *additional_sample_kwargs["voxel_size"], *_upscale_given_by_optics + *additional_sample_kwargs["voxel_size"]#, *_upscale_given_by_optics ) ): @@ -329,14 +341,14 @@ def get( volume_samples = [ scatterer for scatterer in list_of_scatterers - if not scatterer.get_property("is_field", default=False) + if isinstance(scatterer, ScatteredVolume) ] # All scatterers that are defined as fields. field_samples = [ scatterer for scatterer in list_of_scatterers - if scatterer.get_property("is_field", default=False) + if isinstance(scatterer, ScatteredField) ] # Merge all volumes into a single volume. @@ -359,33 +371,32 @@ def get( imaged_sample = self._objective.resolve(sample_volume) - # Upscale given by the optics needs to be handled separately. - if _upscale_given_by_optics != (1, 1, 1): - imaged_sample = AveragePooling((*_upscale_given_by_optics[:2], 1))( - imaged_sample - ) - # # Merge with input - # if not image: - # if not self._wrap_array_with_image and isinstance(imaged_sample, Image): - # return imaged_sample._value - # else: - # return imaged_sample + # Collect main_property from scatterers + main_properties = { + s.main_property + for s in list_of_scatterers + if hasattr(s, "main_property") + } - # if not isinstance(image, list): - # image = [image] - # for i in range(len(image)): - # image[i].merge_properties_from(imaged_sample) - # return image + if len(main_properties) != 1: + raise ValueError( + f"Inconsistent main_property across scatterers: {main_properties}" + ) - # def _no_wrap_format_input(self, *args, **kwargs) -> list: - # return self._image_wrapped_format_input(*args, **kwargs) + main_property = main_properties.pop() - # def _no_wrap_process_and_get(self, *args, **feature_input) -> list: - # return self._image_wrapped_process_and_get(*args, **feature_input) + # Handling upscale from dt.Upscale() here to eliminate Image + # wrapping issues. 
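        # Rationale, as sketched by the pooling helpers below: intensity is
        # an extensive quantity, so downscaling should preserve the total
        # signal (sum pooling), while contrast quantities such as refractive
        # index are intensive and should be averaged instead.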
+ if np.any(np.array(upscale) != 1): + ux, uy = upscale[:2] + if main_property == "intensity": + print("Using sum pooling for intensity downscaling.") + imaged_sample = SumPoolingCM((ux, uy, 1))(imaged_sample) + else: + imaged_sample = AveragePoolingCM((ux, uy, 1))(imaged_sample) - # def _no_wrap_process_output(self, *args, **feature_input): - # return self._image_wrapped_process_output(*args, **feature_input) + return imaged_sample #TODO ***??*** revise Optics - torch, typing, docstring, unit test @@ -1158,7 +1169,7 @@ def get( ] output_image = output_image[pad[0] : -pad[2], pad[1] : -pad[3]] - output_image.properties = illuminated_volume.properties + pupils.properties + # output_image.properties = illuminated_volume.properties + pupils.properties return output_image @@ -1797,7 +1808,7 @@ def get( #TODO ***??*** revise _get_position - torch, typing, docstring, unit test def _get_position( - image: Image, + scatterer: ScatteredVolume, mode: str = "corner", return_z: bool = False, ) -> np.ndarray: @@ -1821,26 +1832,23 @@ def _get_position( num_outputs = 2 + return_z - if mode == "corner" and image.size > 0: + if mode == "corner" and scatterer.array.size > 0: import scipy.ndimage - image = image.to_numpy() - - shift = scipy.ndimage.center_of_mass(np.abs(image)) + shift = scipy.ndimage.center_of_mass(np.abs(scatterer.array)) if np.isnan(shift).any(): - shift = np.array(image.shape) / 2 + shift = np.array(scatterer.array.shape) / 2 else: shift = np.zeros((num_outputs)) - position = np.array(image.get_property("position", default=None)) + position = np.array(scatterer.get_property("position", default=None)) if position is None: return position scale = np.array(get_active_scale()) - if len(position) == 3: position = position * scale + 0.5 * (scale - 1) if return_z: @@ -1851,7 +1859,7 @@ def _get_position( elif len(position) == 2: if return_z: outp = ( - np.array([position[0], position[1], image.get_property("z", default=0)]) + np.array([position[0], position[1], scatterer.get_property("z", default=0)]) * scale - shift + 0.5 * (scale - 1) @@ -1863,6 +1871,58 @@ def _get_position( return position +def _bilinear_interpolate_numpy( + scatterer: np.ndarray, x_off: float, y_off: float +) -> np.ndarray: + """Apply bilinear subpixel interpolation in the x–y plane (NumPy).""" + kernel = np.array( + [ + [0.0, 0.0, 0.0], + [0.0, (1 - x_off) * (1 - y_off), (1 - x_off) * y_off], + [0.0, x_off * (1 - y_off), x_off * y_off], + ] + ) + out = np.zeros_like(scatterer) + for z in range(scatterer.shape[2]): + if np.iscomplexobj(scatterer): + out[:, :, z] = ( + convolve(np.real(scatterer[:, :, z]), kernel, mode="constant") + + 1j + * convolve(np.imag(scatterer[:, :, z]), kernel, mode="constant") + ) + else: + out[:, :, z] = convolve(scatterer[:, :, z], kernel, mode="constant") + return out + + +def _bilinear_interpolate_torch( + scatterer: torch.Tensor, x_off: float, y_off: float +) -> torch.Tensor: + """Apply bilinear subpixel interpolation in the x–y plane (Torch). + + Uses grid_sample for autograd-friendly interpolation. 
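
    With align_corners=True, grid coordinates -1 and 1 map to the centers
    of the first and last pixels, so a one-pixel shift corresponds to
    2 / (size - 1) in normalized units; the x_shift and y_shift expressions
    below follow from this convention.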
+ """ + H, W, D = scatterer.shape + + # Normalized shifts in [-1,1] + x_shift = 2 * x_off / (W - 1) + y_shift = 2 * y_off / (H - 1) + + yy, xx = torch.meshgrid( + torch.linspace(-1, 1, H, device=scatterer.device, dtype=scatterer.dtype), + torch.linspace(-1, 1, W, device=scatterer.device, dtype=scatterer.dtype), + indexing="ij", + ) + grid = torch.stack((xx + x_shift, yy + y_shift), dim=-1) # (H,W,2) + grid = grid.unsqueeze(0).repeat(D, 1, 1, 1) # (D,H,W,2) + + inp = scatterer.permute(2, 0, 1).unsqueeze(1) # (D,1,H,W) + + out = F.grid_sample(inp, grid, mode="bilinear", + padding_mode="zeros", align_corners=True) + return out.squeeze(1).permute(1, 2, 0) # (H,W,D) + + #TODO ***??*** revise _create_volume - torch, typing, docstring, unit test def _create_volume( list_of_scatterers: list, @@ -1922,24 +1982,20 @@ def _create_volume( # This accounts for upscale doing AveragePool instead of SumPool. This is # a bit of a hack, but it works for now. - fudge_factor = scale[0] * scale[1] / scale[2] + # fudge_factor = scale[0] * scale[1] / scale[2] for scatterer in list_of_scatterers: position = _get_position(scatterer, mode="corner", return_z=True) - - if scatterer.get_property("intensity", None) is not None: - intensity = scatterer.get_property("intensity") - scatterer_value = intensity * fudge_factor - elif scatterer.get_property("refractive_index", None) is not None: - refractive_index = scatterer.get_property("refractive_index") - scatterer_value = ( - refractive_index - refractive_index_medium - ) - else: + if scatterer.main_property == "intensity": + scatterer_value = scatterer.get_property("intensity") #* fudge_factor + elif scatterer.main_property == "refractive_index": + scatterer_value = scatterer.get_property("refractive_index") - refractive_index_medium + else: # fallback to generic value scatterer_value = scatterer.get_property("value") - scatterer = scatterer * scatterer_value + # Scale the array accordingly + scatterer.array = scatterer.array * scatterer_value if limits is None: limits = np.zeros((3, 2), dtype=np.int32) @@ -1947,24 +2003,23 @@ def _create_volume( limits[:, 1] = np.floor(position).astype(np.int32) + 1 if ( - position[0] + scatterer.shape[0] < OR[0] + position[0] + scatterer.array.shape[0] < OR[0] or position[0] > OR[2] - or position[1] + scatterer.shape[1] < OR[1] + or position[1] + scatterer.array.shape[1] < OR[1] or position[1] > OR[3] ): continue - padded_scatterer = np.pad( - scatterer, + # Pad scatterer to avoid edge effects during interpolation + padded_scatterer = scatterer + padded_scatterer.array = np.pad( + scatterer.array._value, # this is a temporary fix, Image should be removed from features [(2, 2), (2, 2), (2, 2)], "constant", constant_values=0, ) - padded_scatterer.merge_properties_from(scatterer) - - scatterer = padded_scatterer - position = _get_position(scatterer, mode="corner", return_z=True) - shape = np.array(scatterer.shape) + position = _get_position(padded_scatterer, mode="corner", return_z=True) + shape = np.array(padded_scatterer.array.shape) if position is None: RuntimeWarning( @@ -1973,36 +2028,44 @@ def _create_volume( ) continue - splined_scatterer = np.zeros_like(scatterer) - x_off = position[0] - np.floor(position[0]) y_off = position[1] - np.floor(position[1]) - kernel = np.array( - [ - [0, 0, 0], - [0, (1 - x_off) * (1 - y_off), (1 - x_off) * y_off], - [0, x_off * (1 - y_off), x_off * y_off], - ] - ) + + if isinstance(padded_scatterer.array, np.ndarray): # get_backend is a method of Features and not exposed + splined_scatterer = 
_bilinear_interpolate_numpy(padded_scatterer.array, x_off, y_off) + elif isinstance(padded_scatterer.array, torch.Tensor): + splined_scatterer = _bilinear_interpolate_torch(padded_scatterer.array, x_off, y_off) + else: + raise TypeError( + f"Unsupported array type {type(padded_scatterer.array)}. " + "Expected np.ndarray or torch.Tensor." + ) - for z in range(scatterer.shape[2]): - if splined_scatterer.dtype == complex: - splined_scatterer[:, :, z] = ( - convolve( - np.real(scatterer[:, :, z]), kernel, mode="constant" - ) - + convolve( - np.imag(scatterer[:, :, z]), kernel, mode="constant" - ) - * 1j - ) - else: - splined_scatterer[:, :, z] = convolve( - scatterer[:, :, z], kernel, mode="constant" - ) + # kernel = np.array( + # [ + # [0, 0, 0], + # [0, (1 - x_off) * (1 - y_off), (1 - x_off) * y_off], + # [0, x_off * (1 - y_off), x_off * y_off], + # ] + # ) + + # for z in range(padded_scatterer.array.shape[2]): + # if splined_scatterer.dtype == complex: + # splined_scatterer[:, :, z] = ( + # convolve( + # np.real(padded_scatterer.array[:, :, z]), kernel, mode="constant" + # ) + # + convolve( + # np.imag(padded_scatterer.array[:, :, z]), kernel, mode="constant" + # ) + # * 1j + # ) + # else: + # splined_scatterer[:, :, z] = convolve( + # padded_scatterer.array[:, :, z], kernel, mode="constant" + # ) - scatterer = splined_scatterer position = np.floor(position) new_limits = np.zeros(limits.shape, dtype=np.int32) for i in range(3): @@ -2032,6 +2095,7 @@ def _create_volume( within_volume_position = position - limits[:, 0] # NOTE: Maybe shouldn't be additive. + # give options: sum default, but also sum, mean, max, min volume[ int(within_volume_position[0]) : int(within_volume_position[0] + shape[0]), @@ -2041,5 +2105,72 @@ def _create_volume( int(within_volume_position[2]) : int(within_volume_position[2] + shape[2]), - ] += scatterer + ] += splined_scatterer return volume, limits + +# this should be moved to math +class _CenteredPoolingBase: + def __init__(self, pool_size: tuple[int, int, int]): + px, py, pz = pool_size + if pz != 1: + raise ValueError("Only pz=1 supported.") + self.px = int(px) + self.py = int(py) + + def _crop_center(self, array): + H, W = array.shape[:2] + px, py = self.px, self.py + + crop_h = (H // px) * px + crop_w = (W // py) * py + + off_h = (H - crop_h) // 2 + off_w = (W - crop_w) // 2 + + return array[off_h:off_h+crop_h, off_w:off_w+crop_w, ...] 
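
    # _crop_center example, with hypothetical sizes: for H = 10 rows and
    # px = 4, the method keeps (10 // 4) * 4 = 8 rows starting at row
    # (10 - 8) // 2 = 1, so the pooling windows stay centered on the image
    # rather than anchored at the top-left corner.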
+ + def _pool_numpy(self, array, func): + import skimage.measure + array = self._crop_center(array) + pool_shape = (self.px, self.py) + (1,) * (array.ndim - 2) + return skimage.measure.block_reduce(array, pool_shape, func) + + def _pool_torch(self, array, sum_pool=False): + px, py = self.px, self.py + array = self._crop_center(array) + + extra = array.shape[2:] + C = int(np.prod(extra)) if extra else 1 + x = array.reshape(1, C, array.shape[0], array.shape[1]) + + pooled = torch.nn.functional.avg_pool2d( + x, kernel_size=(px, py), stride=(px, py) + ) + if sum_pool: + pooled = pooled * (px * py) + + return pooled.reshape( + (pooled.shape[2], pooled.shape[3]) + extra + ) + +class AveragePoolingCM(_CenteredPoolingBase): + """Center-preserving average pooling (intensive quantities).""" + + def __call__(self, array): + if isinstance(array, np.ndarray): + return self._pool_numpy(array, np.mean) + elif TORCH_AVAILABLE and isinstance(array, torch.Tensor): + return self._pool_torch(array, sum_pool=False) + else: + raise TypeError("Unsupported array type.") + +class SumPoolingCM(_CenteredPoolingBase): + """Center-preserving sum pooling (extensive quantities).""" + + def __call__(self, array): + if isinstance(array, np.ndarray): + return self._pool_numpy(array, np.sum) + elif TORCH_AVAILABLE and isinstance(array, torch.Tensor): + return self._pool_torch(array, sum_pool=True) + else: + raise TypeError("Unsupported array type.") \ No newline at end of file diff --git a/deeptrack/scatterers.py b/deeptrack/scatterers.py index ed77e5ed..f3daa9bf 100644 --- a/deeptrack/scatterers.py +++ b/deeptrack/scatterers.py @@ -166,6 +166,7 @@ import numpy as np from numpy.typing import NDArray from pint import Quantity +from dataclasses import dataclass, field from deeptrack.holography import get_propagation_matrix from deeptrack.backend.units import ( @@ -246,6 +247,9 @@ class Scatterer(Feature): voxel_size=(u.meter, u.meter), ) + #: Default property name (subclasses override this) + main_property: str = "value" + def __init__( self, position: ArrayLike[float] = (32, 32), @@ -258,7 +262,7 @@ def __init__( **kwargs, ) -> None: # Ignore warning to help with comparison with arrays. - if upsample is not 1: # noqa: F632 + if upsample != 1: # noqa: F632 warnings.warn( f"Setting upsample != 1 is deprecated. " f"Please, instead use dt.Upscale(f, factor={upsample})" @@ -310,7 +314,7 @@ def _process_and_get( voxel_size = get_active_voxel_size() # Calls parent _process_and_get. 
- new_image = super()._process_and_get( + new_image = super(Scatterer, self)._process_and_get( *args, voxel_size=voxel_size, upsample=upsample, @@ -333,32 +337,41 @@ def _process_and_get( new_image = new_image[:, ~np.all(new_image == 0, axis=(0, 2))] new_image = new_image[:, :, ~np.all(new_image == 0, axis=(0, 1))] - return [new_image] - - def _no_wrap_format_input( - self, - *args, - **kwargs - ) -> list: - return self._image_wrapped_format_input(*args, **kwargs) - - def _no_wrap_process_and_get( - self, - *args, - **feature_input - ) -> list: - return self._image_wrapped_process_and_get(*args, **feature_input) - - def _no_wrap_process_output( - self, - *args, - **feature_input - ) -> list: - return self._image_wrapped_process_output(*args, **feature_input) + # # Copy properties + # props = kwargs.copy() + return [self._wrap_output(new_image, kwargs)] + + def _wrap_output(self, array, props) -> ScatteredBase: + """Must be overridden in subclasses to wrap output correctly.""" + raise NotImplementedError + +class VolumeScatterer(Scatterer): + """Abstract scatterer producing ScatteredVolume outputs.""" + def _wrap_output(self, array, props) -> ScatteredVolume: + return [ScatteredVolume( + array=array, + position=props.get("position", (0, 0)), + z=props.get("z", 0.0), + value=props.get("value", 1.0), + intensity=props.get("intensity", None), + refractive_index=props.get("refractive_index", None), + properties=props.copy(), + main_property=self.main_property, + )] + +class FieldScatterer(Scatterer): + def _wrap_output(self, array, props) -> ScatteredField: + return [ScatteredField( + array=array, + position=props.get("position", (0, 0)), + wavelength=props.get("wavelength", 0.0), + properties=props.copy(), + main_property=self.main_property, + )] #TODO ***??*** revise PointParticle - torch, typing, docstring, unit test -class PointParticle(Scatterer): +class PointParticle(VolumeScatterer): """Generate a diffraction-limited point particle. A point particle is approximated by the size of a single pixel or voxel. @@ -382,6 +395,8 @@ class PointParticle(Scatterer): """ + main_property = "intensity" + def __init__( self: PointParticle, **kwargs: Any, @@ -405,7 +420,7 @@ def get( #TODO ***??*** revise Ellipse - torch, typing, docstring, unit test -class Ellipse(Scatterer): +class Ellipse(VolumeScatterer): """Generates an elliptical disk scatterer Parameters @@ -446,6 +461,8 @@ class Ellipse(Scatterer): rotation=(u.radian, u.radian), ) + main_property = "refractive_index" + def __init__( self, radius: float = 1e-6, @@ -519,7 +536,7 @@ def get( #TODO ***??*** revise Sphere - torch, typing, docstring, unit test -class Sphere(Scatterer): +class Sphere(VolumeScatterer): """Generates a spherical scatterer Parameters @@ -550,6 +567,8 @@ class Sphere(Scatterer): radius=(u.meter, u.meter), ) + main_property = "refractive_index" + def __init__( self, radius: float = 1e-6, @@ -584,7 +603,7 @@ def get( #TODO ***??*** revise Ellipsoid - torch, typing, docstring, unit test -class Ellipsoid(Scatterer): +class Ellipsoid(VolumeScatterer): """Generates an ellipsoidal scatterer Parameters @@ -625,6 +644,8 @@ class Ellipsoid(Scatterer): rotation=(u.radian, u.radian), ) + main_property = "refractive_index" + def __init__( self, radius: float = 1e-6, @@ -741,7 +762,7 @@ def get( #TODO ***??*** revise MieScatterer - torch, typing, docstring, unit test -class MieScatterer(Scatterer): +class MieScatterer(FieldScatterer): """Base implementation of a Mie particle. 
New Mie-theory scatterers can be implemented by extending this class, and @@ -835,6 +856,8 @@ class MieScatterer(Scatterer): coherence_length=(u.meter, u.pixel), ) + main_property = "wavelength" + def __init__( self, coefficients, @@ -864,11 +887,11 @@ def __init__( "Please use input_polarization instead" ) input_polarization = polarization_angle - kwargs.pop("is_field", None) + kwargs.pop("is_field", None) # remove kwargs.pop("crop_empty", None) super().__init__( - is_field=True, + is_field=True, # remove crop_empty=False, L=L, offset_z=offset_z, @@ -1188,7 +1211,6 @@ def get( -mask.shape[1] // 2 : mask.shape[1] // 2, ] mask = np.exp(-0.5 * (x ** 2 + y ** 2) / ((sigma) ** 2)) - arr = arr * mask fourier_field = np.fft.fft2(arr) @@ -1412,3 +1434,53 @@ def inner( refractive_index=refractive_index, **kwargs, ) + + +@dataclass +class ScatteredBase: + """Base class for scatterers (volumes and fields).""" + + array: ArrayLike + position: np.ndarray + z: float = 0.0 + properties: dict[str, Any] = field(default_factory=dict) + main_property: str = None + + def __post_init__(self): + self.position = np.array(self.position, dtype=float).reshape(-1)[:2] + self.z = float(np.atleast_1d(self.z).squeeze()) + + @property + def pos3d(self) -> np.ndarray: + return np.array([*self.position, self.z], dtype=float) + + def as_array(self) -> ArrayLike: + """Return the underlying array. + + Notes + ----- + The raw array is also directly available as ``scatterer.array``. + This method exists mainly for API compatibility and clarity. + + """ + + return self.array + + def get_property(self, key: str, default: Any = None) -> Any: + return getattr(self, key, self.properties.get(key, default)) + + +@dataclass +class ScatteredVolume(ScatteredBase): + """Volumetric object: intensity sources or refractive index contrasts.""" + + refractive_index: float | None = None + intensity: float | None = None + value: float | None = None + + +@dataclass +class ScatteredField(ScatteredBase): + """Complex wavefield (already propagated or emitted).""" + + wavelength: float = 500e-9 \ No newline at end of file From 36af2e4f909ef80673015c853ca6ee7f4c6390d8 Mon Sep 17 00:00:00 2001 From: Carlo Date: Sun, 4 Jan 2026 00:37:23 +0100 Subject: [PATCH 05/24] removed Image from optics and scatterers --- deeptrack/optics.py | 32 +++++++++++++++++++++++++++++++- 1 file changed, 31 insertions(+), 1 deletion(-) diff --git a/deeptrack/optics.py b/deeptrack/optics.py index 5564e0b2..4c23469f 100644 --- a/deeptrack/optics.py +++ b/deeptrack/optics.py @@ -351,6 +351,10 @@ def get( if isinstance(scatterer, ScatteredField) ] + warn_upscale_fields = False + if field_samples and np.any(upscale != 1): + warn_upscale_fields = True + # Merge all volumes into a single volume. sample_volume, limits = _create_volume( volume_samples, @@ -371,6 +375,14 @@ def get( imaged_sample = self._objective.resolve(sample_volume) + if warn_upscale_fields: + warnings.warn( + "dt.Upscale is active while FieldScatterers are present. " + "Coherent fields are injected without resampling, so the " + "physical interpretation may change with Upscale. 
" + "This behavior is currently undefined.", + UserWarning, + ) # Collect main_property from scatterers main_properties = { @@ -1420,7 +1432,25 @@ def get( light_in_focus = light_in * shifted_pupil if len(fields) > 0: - field = np.sum(fields, axis=0) + # field = np.sum(fields, axis=0) + field_arrays = [] + + for fs in fields: + # fs is a ScatteredField + arr = fs.array + + # Enforce (H, W, 1) shape + if arr.ndim == 2: + arr = arr[..., None] + + if arr.ndim != 3 or arr.shape[-1] != 1: + raise ValueError( + f"Expected field of shape (H, W, 1), got {arr.shape}" + ) + + field_arrays.append(arr) + + field = np.sum(field_arrays, axis=0) light_in_focus += field[..., 0] shifted_pupil = np.fft.fftshift(pupils[-1]) light_in_focus = light_in_focus * shifted_pupil From 27ebcb19e61431e3e1c4cd6877804e09ce013679 Mon Sep 17 00:00:00 2001 From: Carlo Date: Sun, 4 Jan 2026 01:58:58 +0100 Subject: [PATCH 06/24] Daniel's Mie scatterer. --- deeptrack/scatterers.py | 63 +++++++++++++++++++++++++++-------------- 1 file changed, 42 insertions(+), 21 deletions(-) diff --git a/deeptrack/scatterers.py b/deeptrack/scatterers.py index f3daa9bf..b52408e2 100644 --- a/deeptrack/scatterers.py +++ b/deeptrack/scatterers.py @@ -879,6 +879,7 @@ def __init__( illumination_angle: float=0, amp_factor: float=1, phase_shift_correction: bool=False, + pupil: ArrayLike=[], **kwargs, ) -> None: if polarization_angle is not None: @@ -887,7 +888,7 @@ def __init__( "Please use input_polarization instead" ) input_polarization = polarization_angle - kwargs.pop("is_field", None) # remove + # kwargs.pop("is_field", None) # remove kwargs.pop("crop_empty", None) super().__init__( @@ -912,6 +913,7 @@ def __init__( illumination_angle=illumination_angle, amp_factor=amp_factor, phase_shift_correction=phase_shift_correction, + pupil=pupil, **kwargs, ) @@ -1037,7 +1039,8 @@ def get_plane_in_polar_coords( shape: int, voxel_size: ArrayLike[float], plane_position: float, - illumination_angle: float + illumination_angle: float, + k: float, ) -> tuple[float, float, float, float]: """Computes the coordinates of the plane in polar form.""" @@ -1050,15 +1053,22 @@ def get_plane_in_polar_coords( R2_squared = X ** 2 + Y ** 2 R3 = np.sqrt(R2_squared + Z ** 2) # Might be +z instead of -z. + Q = np.sqrt(R2_squared)/voxel_size[0]**2*2*np.pi/shape[0] + sin_theta=Q/(k) + pupil_mask=sin_theta<=1 + + cos_theta=np.zeros(sin_theta.shape) + cos_theta[pupil_mask]=np.sqrt(1-sin_theta[pupil_mask]**2) # Fet the angles. - cos_theta = Z / R3 + # cos_theta = Z / R3 + illumination_cos_theta = ( np.cos(np.arccos(cos_theta) + illumination_angle) ) phi = np.arctan2(Y, X) - return R3, cos_theta, illumination_cos_theta, phi + return R3, cos_theta, illumination_cos_theta, phi, pupil_mask def get( self, @@ -1083,6 +1093,7 @@ def get( illumination_angle: float, amp_factor: float, phase_shift_correction: bool, + pupil: ArrayLike, **kwargs, ) -> ArrayLike[float]: """Abstract method to initialize the Mie scatterer""" @@ -1099,6 +1110,10 @@ def get( ratio = offset_z / (working_distance - z) + # Wave vector. + k = 2 * np.pi / wavelength * refractive_index_medium + + # Position of pbjective relative particle. relative_position = np.array( ( @@ -1109,11 +1124,12 @@ def get( ) # Get field evaluation plane at offset_z. 
- R3_field, cos_theta_field, illumination_angle_field, phi_field =\ + R3_field, cos_theta_field, illumination_angle_field, phi_field, pupil_mask =\ self.get_plane_in_polar_coords( arr.shape, voxel_size, relative_position * ratio, - illumination_angle + illumination_angle, + k ) cos_phi_field, sin_phi_field = np.cos(phi_field), np.sin(phi_field) @@ -1132,9 +1148,9 @@ def get( ) # If the beam is within the pupil. - pupil_mask = (x_farfield - position_objective[0]) ** 2 + ( - y_farfield - position_objective[1] - ) ** 2 < (pupil_physical_size / 2) ** 2 + # pupil_mask = (x_farfield - position_objective[0]) ** 2 + ( + # y_farfield - position_objective[1] + # ) ** 2 < (pupil_physical_size / 2) ** 2 R3_field = R3_field[pupil_mask] cos_theta_field = cos_theta_field[pupil_mask] @@ -1169,9 +1185,6 @@ def get( * illumination_angle_field ) - # Wave vector. - k = 2 * np.pi / wavelength * refractive_index_medium - # Harmonics. A, B = coefficients(L) PI, TAU = mie.harmonics(illumination_angle_field, L) @@ -1188,12 +1201,14 @@ def get( [E[i] * B[i] * PI[i] + E[i] * A[i] * TAU[i] for i in range(0, L)] ) - arr[pupil_mask] = ( - -1j - / (k * R3_field) - * np.exp(1j * k * R3_field) - * (S2 * S2_coef + S1 * S1_coef) - ) / amp_factor + # arr[pupil_mask] = ( + # -1j + # / (k * R3_field) + # * np.exp(1j * k * R3_field) + # * (S2 * S2_coef + S1 * S1_coef) + # ) / amp_factor + arr[pupil_mask] = (S2 * S2_coef + S1 * S1_coef)/amp_factor + # For phase shift correction (a multiplication of the field # by exp(1j * k * z)). @@ -1213,13 +1228,19 @@ def get( mask = np.exp(-0.5 * (x ** 2 + y ** 2) / ((sigma) ** 2)) arr = arr * mask - fourier_field = np.fft.fft2(arr) + if len(pupil)>0: + c_pix=[arr.shape[0]//2,arr.shape[1]//2] + + arr[c_pix[0]-pupil.shape[0]//2:c_pix[0]+pupil.shape[0]//2,c_pix[1]-pupil.shape[1]//2:c_pix[1]+pupil.shape[1]//2]*=pupil + fourier_field = -np.fft.ifft2(np.fft.fftshift(np.fft.fft2(np.fft.fftshift(arr)))) + # fourier_field = np.fft.fft2(arr) propagation_matrix = get_propagation_matrix( fourier_field.shape, pixel_size=voxel_size[2], wavelength=wavelength / refractive_index_medium, - to_z=(-offset_z - z), + # to_z=(-offset_z - z), + to_z=(-z), dy=( relative_position[0] * ratio + position[0] @@ -1232,7 +1253,7 @@ def get( ), ) fourier_field = ( - fourier_field * propagation_matrix * np.exp(-1j * k * offset_z) + fourier_field * propagation_matrix #* np.exp(-1j * k * offset_z) ) if return_fft: From 2b56f74e7420032db5d1efc3fd7dafdad4640261 Mon Sep 17 00:00:00 2001 From: Carlo Date: Sun, 4 Jan 2026 02:55:21 +0100 Subject: [PATCH 07/24] u --- deeptrack/optics.py | 1 - deeptrack/scatterers.py | 7 +++++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/deeptrack/optics.py b/deeptrack/optics.py index 4c23469f..bfccb4f7 100644 --- a/deeptrack/optics.py +++ b/deeptrack/optics.py @@ -2015,7 +2015,6 @@ def _create_volume( # fudge_factor = scale[0] * scale[1] / scale[2] for scatterer in list_of_scatterers: - position = _get_position(scatterer, mode="corner", return_z=True) if scatterer.main_property == "intensity": scatterer_value = scatterer.get_property("intensity") #* fudge_factor diff --git a/deeptrack/scatterers.py b/deeptrack/scatterers.py index b52408e2..408506d6 100644 --- a/deeptrack/scatterers.py +++ b/deeptrack/scatterers.py @@ -1054,9 +1054,9 @@ def get_plane_in_polar_coords( R2_squared = X ** 2 + Y ** 2 R3 = np.sqrt(R2_squared + Z ** 2) # Might be +z instead of -z. Q = np.sqrt(R2_squared)/voxel_size[0]**2*2*np.pi/shape[0] + # is dimensionally ok? Doesn't look like it. 
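        # A quick unit check suggests it is, assuming X and Y from get_XY
        # are in meters: sqrt(R2_squared) is a length [m], dividing by
        # voxel_size[0]**2 [m^2] gives [1/m], and 2*np.pi/shape[0] is
        # dimensionless, so Q is in rad/m -- the same units as k, which
        # makes sin_theta dimensionless.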
sin_theta=Q/(k) - pupil_mask=sin_theta<=1 - + pupil_mask=sin_theta<1 cos_theta=np.zeros(sin_theta.shape) cos_theta[pupil_mask]=np.sqrt(1-sin_theta[pupil_mask]**2) @@ -1103,6 +1103,9 @@ def get( voxel_size = get_active_voxel_size() arr = pad_image_to_fft(np.zeros((xSize, ySize))).astype(complex) position = np.array(position) * voxel_size[: len(position)] + print(xSize, ySize) + print(padding) + print(position) pupil_physical_size = working_distance * np.tan(collection_angle) * 2 From 2da3587e45ffd7f41232cf6dcfd3db61eb3d0318 Mon Sep 17 00:00:00 2001 From: Carlo Date: Mon, 5 Jan 2026 03:16:12 +0100 Subject: [PATCH 08/24] fixed upscale with field scatterers --- deeptrack/holography.py | 4 ++-- deeptrack/optics.py | 24 ++++++++++++------------ deeptrack/scatterers.py | 21 +++++++++++---------- 3 files changed, 25 insertions(+), 24 deletions(-) diff --git a/deeptrack/holography.py b/deeptrack/holography.py index 380969cf..39fc4e43 100644 --- a/deeptrack/holography.py +++ b/deeptrack/holography.py @@ -146,8 +146,8 @@ def get_propagation_matrix( x = np.arange(0, xr, 1) - xr / 2 + (xr % 2) / 2 y = np.arange(0, yr, 1) - yr / 2 + (yr % 2) / 2 - x = 2 * np.pi / pixel_size * x / xr - y = 2 * np.pi / pixel_size * y / yr + x = 2 * np.pi / pixel_size[0] * x / xr + y = 2 * np.pi / pixel_size[1] * y / yr KXk, KYk = np.meshgrid(x, y) KXk = KXk.astype(complex) diff --git a/deeptrack/optics.py b/deeptrack/optics.py index bfccb4f7..092696a5 100644 --- a/deeptrack/optics.py +++ b/deeptrack/optics.py @@ -351,9 +351,9 @@ def get( if isinstance(scatterer, ScatteredField) ] - warn_upscale_fields = False - if field_samples and np.any(upscale != 1): - warn_upscale_fields = True + # warn_upscale_fields = False + # if field_samples and np.any(upscale != 1): + # warn_upscale_fields = True # Merge all volumes into a single volume. sample_volume, limits = _create_volume( @@ -375,14 +375,14 @@ def get( imaged_sample = self._objective.resolve(sample_volume) - if warn_upscale_fields: - warnings.warn( - "dt.Upscale is active while FieldScatterers are present. " - "Coherent fields are injected without resampling, so the " - "physical interpretation may change with Upscale. " - "This behavior is currently undefined.", - UserWarning, - ) + # if warn_upscale_fields: + # warnings.warn( + # "dt.Upscale is active while FieldScatterers are present. " + # "Coherent fields are injected without resampling, so the " + # "physical interpretation may change with Upscale. " + # "This behavior is currently undefined.", + # UserWarning, + # ) # Collect main_property from scatterers main_properties = { @@ -1365,7 +1365,7 @@ def get( if output_region[3] is None else int(output_region[3] - limits[1, 0] + pad[3]) ) - + padded_volume = padded_volume[ output_region[0] : output_region[2], output_region[1] : output_region[3], diff --git a/deeptrack/scatterers.py b/deeptrack/scatterers.py index 408506d6..eba0566b 100644 --- a/deeptrack/scatterers.py +++ b/deeptrack/scatterers.py @@ -1002,8 +1002,10 @@ def get_XY( The meshgrid of X and Y coordinates. """ - x = np.arange(shape[0]) - shape[0] / 2 - y = np.arange(shape[1]) - shape[1] / 2 + # x = np.arange(shape[0]) - shape[0] / 2 + # y = np.arange(shape[1]) - shape[1] / 2 + x = np.arange(shape[0]) - (shape[0] - 1) / 2 + y = np.arange(shape[1]) - (shape[1] - 1) / 2 return np.meshgrid(x * voxel_size[0], y * voxel_size[1], indexing="ij") def get_detector_mask( @@ -1101,11 +1103,9 @@ def get( # Get size of the output. 
xSize, ySize = self.get_xy_size(output_region, padding) voxel_size = get_active_voxel_size() + scale = get_active_scale() arr = pad_image_to_fft(np.zeros((xSize, ySize))).astype(complex) - position = np.array(position) * voxel_size[: len(position)] - print(xSize, ySize) - print(padding) - print(position) + position = np.array(position) * scale[: len(position)] * voxel_size[: len(position)] pupil_physical_size = working_distance * np.tan(collection_angle) * 2 @@ -1117,7 +1117,7 @@ def get( k = 2 * np.pi / wavelength * refractive_index_medium - # Position of pbjective relative particle. + # Position of objective relative particle. relative_position = np.array( ( position_objective[0] - position[0], @@ -1236,11 +1236,11 @@ def get( arr[c_pix[0]-pupil.shape[0]//2:c_pix[0]+pupil.shape[0]//2,c_pix[1]-pupil.shape[1]//2:c_pix[1]+pupil.shape[1]//2]*=pupil fourier_field = -np.fft.ifft2(np.fft.fftshift(np.fft.fft2(np.fft.fftshift(arr)))) - # fourier_field = np.fft.fft2(arr) + # fourier_field = np.fft.fft2(np.fft.fftshift(arr)) propagation_matrix = get_propagation_matrix( fourier_field.shape, - pixel_size=voxel_size[2], + pixel_size=voxel_size[:2], # this needs a double check wavelength=wavelength / refractive_index_medium, # to_z=(-offset_z - z), to_z=(-z), @@ -1252,9 +1252,10 @@ def get( dx=( relative_position[1] * ratio + position[1] - + (padding[1] - arr.shape[1] / 2) * voxel_size[1] + + (padding[2] - arr.shape[1] / 2) * voxel_size[1] # check if padding is top, bottom, left, right ), ) + fourier_field = ( fourier_field * propagation_matrix #* np.exp(-1j * k * offset_z) ) From cdaf8de6d0e5dbb0d57577285ac9638fbe233c39 Mon Sep 17 00:00:00 2001 From: Carlo Date: Wed, 7 Jan 2026 10:54:33 +0100 Subject: [PATCH 09/24] u --- deeptrack/scatterers.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/deeptrack/scatterers.py b/deeptrack/scatterers.py index eba0566b..70084488 100644 --- a/deeptrack/scatterers.py +++ b/deeptrack/scatterers.py @@ -1002,10 +1002,10 @@ def get_XY( The meshgrid of X and Y coordinates. """ - # x = np.arange(shape[0]) - shape[0] / 2 - # y = np.arange(shape[1]) - shape[1] / 2 - x = np.arange(shape[0]) - (shape[0] - 1) / 2 - y = np.arange(shape[1]) - (shape[1] - 1) / 2 + x = np.arange(shape[0]) - shape[0] / 2 + y = np.arange(shape[1]) - shape[1] / 2 + # x = np.arange(shape[0]) - (shape[0] - 1) / 2 + # y = np.arange(shape[1]) - (shape[1] - 1) / 2 return np.meshgrid(x * voxel_size[0], y * voxel_size[1], indexing="ij") def get_detector_mask( From 191d3a51c13b015c0867bcd116cabe61a9f84757 Mon Sep 17 00:00:00 2001 From: Carlo Date: Wed, 7 Jan 2026 13:26:38 +0100 Subject: [PATCH 10/24] rolled back momentarily removed changed to MieScatetterer suggested by Daniel. 
Now Upscale doesn't work with MieScatterers --- deeptrack/scatterers.py | 76 ++++++++++++++++++++++------------------- 1 file changed, 40 insertions(+), 36 deletions(-) diff --git a/deeptrack/scatterers.py b/deeptrack/scatterers.py index 70084488..a9681953 100644 --- a/deeptrack/scatterers.py +++ b/deeptrack/scatterers.py @@ -879,7 +879,7 @@ def __init__( illumination_angle: float=0, amp_factor: float=1, phase_shift_correction: bool=False, - pupil: ArrayLike=[], + # pupil: ArrayLike=[], # Daniel **kwargs, ) -> None: if polarization_angle is not None: @@ -913,7 +913,7 @@ def __init__( illumination_angle=illumination_angle, amp_factor=amp_factor, phase_shift_correction=phase_shift_correction, - pupil=pupil, + # pupil=pupil, # Daniel **kwargs, ) @@ -1004,8 +1004,6 @@ def get_XY( """ x = np.arange(shape[0]) - shape[0] / 2 y = np.arange(shape[1]) - shape[1] / 2 - # x = np.arange(shape[0]) - (shape[0] - 1) / 2 - # y = np.arange(shape[1]) - (shape[1] - 1) / 2 return np.meshgrid(x * voxel_size[0], y * voxel_size[1], indexing="ij") def get_detector_mask( @@ -1042,7 +1040,7 @@ def get_plane_in_polar_coords( voxel_size: ArrayLike[float], plane_position: float, illumination_angle: float, - k: float, + # k: float, # Daniel ) -> tuple[float, float, float, float]: """Computes the coordinates of the plane in polar form.""" @@ -1055,22 +1053,24 @@ def get_plane_in_polar_coords( R2_squared = X ** 2 + Y ** 2 R3 = np.sqrt(R2_squared + Z ** 2) # Might be +z instead of -z. - Q = np.sqrt(R2_squared)/voxel_size[0]**2*2*np.pi/shape[0] - # is dimensionally ok? Doesn't look like it. - sin_theta=Q/(k) - pupil_mask=sin_theta<1 - cos_theta=np.zeros(sin_theta.shape) - cos_theta[pupil_mask]=np.sqrt(1-sin_theta[pupil_mask]**2) + + # # DANIEL + # Q = np.sqrt(R2_squared)/voxel_size[0]**2*2*np.pi/shape[0] + # # is dimensionally ok? + # sin_theta=Q/(k) + # pupil_mask=sin_theta<1 + # cos_theta=np.zeros(sin_theta.shape) + # cos_theta[pupil_mask]=np.sqrt(1-sin_theta[pupil_mask]**2) # Fet the angles. - # cos_theta = Z / R3 + cos_theta = Z / R3 illumination_cos_theta = ( np.cos(np.arccos(cos_theta) + illumination_angle) ) phi = np.arctan2(Y, X) - return R3, cos_theta, illumination_cos_theta, phi, pupil_mask + return R3, cos_theta, illumination_cos_theta, phi#, pupil_mask # Daniel def get( self, @@ -1095,7 +1095,7 @@ def get( illumination_angle: float, amp_factor: float, phase_shift_correction: bool, - pupil: ArrayLike, + # pupil: ArrayLike, # Daniel **kwargs, ) -> ArrayLike[float]: """Abstract method to initialize the Mie scatterer""" @@ -1126,13 +1126,13 @@ def get( ) ) - # Get field evaluation plane at offset_z. - R3_field, cos_theta_field, illumination_angle_field, phi_field, pupil_mask =\ + # Get field evaluation plane at offset_z. # , pupil_mask # Daniel + R3_field, cos_theta_field, illumination_angle_field, phi_field =\ self.get_plane_in_polar_coords( arr.shape, voxel_size, relative_position * ratio, illumination_angle, - k + # k # Daniel ) cos_phi_field, sin_phi_field = np.cos(phi_field), np.sin(phi_field) @@ -1150,10 +1150,10 @@ def get( sin_phi_field / ratio ) - # If the beam is within the pupil. - # pupil_mask = (x_farfield - position_objective[0]) ** 2 + ( - # y_farfield - position_objective[1] - # ) ** 2 < (pupil_physical_size / 2) ** 2 + # If the beam is within the pupil. 
Remove if Daniel + pupil_mask = (x_farfield - position_objective[0]) ** 2 + ( + y_farfield - position_objective[1] + ) ** 2 < (pupil_physical_size / 2) ** 2 R3_field = R3_field[pupil_mask] cos_theta_field = cos_theta_field[pupil_mask] @@ -1204,13 +1204,14 @@ def get( [E[i] * B[i] * PI[i] + E[i] * A[i] * TAU[i] for i in range(0, L)] ) - # arr[pupil_mask] = ( - # -1j - # / (k * R3_field) - # * np.exp(1j * k * R3_field) - # * (S2 * S2_coef + S1 * S1_coef) - # ) / amp_factor - arr[pupil_mask] = (S2 * S2_coef + S1 * S1_coef)/amp_factor + # Daniel + # arr[pupil_mask] = (S2 * S2_coef + S1 * S1_coef)/amp_factor + arr[pupil_mask] = ( + -1j + / (k * R3_field) + * np.exp(1j * k * R3_field) + * (S2 * S2_coef + S1 * S1_coef) + ) / amp_factor # For phase shift correction (a multiplication of the field @@ -1231,19 +1232,22 @@ def get( mask = np.exp(-0.5 * (x ** 2 + y ** 2) / ((sigma) ** 2)) arr = arr * mask - if len(pupil)>0: - c_pix=[arr.shape[0]//2,arr.shape[1]//2] + # Not sure if needed... CM + # if len(pupil)>0: + # c_pix=[arr.shape[0]//2,arr.shape[1]//2] - arr[c_pix[0]-pupil.shape[0]//2:c_pix[0]+pupil.shape[0]//2,c_pix[1]-pupil.shape[1]//2:c_pix[1]+pupil.shape[1]//2]*=pupil - fourier_field = -np.fft.ifft2(np.fft.fftshift(np.fft.fft2(np.fft.fftshift(arr)))) - # fourier_field = np.fft.fft2(np.fft.fftshift(arr)) + # arr[c_pix[0]-pupil.shape[0]//2:c_pix[0]+pupil.shape[0]//2,c_pix[1]-pupil.shape[1]//2:c_pix[1]+pupil.shape[1]//2]*=pupil + + # Daniel + # fourier_field = -np.fft.ifft2(np.fft.fftshift(np.fft.fft2(np.fft.fftshift(arr)))) + fourier_field = np.fft.fft2(arr) propagation_matrix = get_propagation_matrix( fourier_field.shape, pixel_size=voxel_size[:2], # this needs a double check wavelength=wavelength / refractive_index_medium, - # to_z=(-offset_z - z), - to_z=(-z), + # to_z=(-z), # Daniel + to_z=(-offset_z - z), dy=( relative_position[0] * ratio + position[0] @@ -1257,7 +1261,7 @@ def get( ) fourier_field = ( - fourier_field * propagation_matrix #* np.exp(-1j * k * offset_z) + fourier_field * propagation_matrix * np.exp(-1j * k * offset_z) # Remove last part (from exp)) if Daniel ) if return_fft: From 6dde1adfb0741edec8a15003873295d67b9380b7 Mon Sep 17 00:00:00 2001 From: Carlo Date: Thu, 8 Jan 2026 00:00:39 +0100 Subject: [PATCH 11/24] added contrast type included objective attribute (contrast_type) to decide wether to use intensity or refractive index --- deeptrack/holography.py | 18 ++++++++--- deeptrack/optics.py | 72 ++++++++++++++++++++++++++++++----------- deeptrack/scatterers.py | 8 ++--- 3 files changed, 70 insertions(+), 28 deletions(-) diff --git a/deeptrack/holography.py b/deeptrack/holography.py index 39fc4e43..141cc540 100644 --- a/deeptrack/holography.py +++ b/deeptrack/holography.py @@ -101,7 +101,7 @@ def get_propagation_matrix( def get_propagation_matrix( shape: tuple[int, int], to_z: float, - pixel_size: float, + pixel_size: float | tuple[float, float], wavelength: float, dx: float = 0, dy: float = 0 @@ -118,8 +118,8 @@ def get_propagation_matrix( The dimensions of the optical field (height, width). to_z: float Propagation distance along the z-axis. - pixel_size: float - The physical size of each pixel in the optical field. + pixel_size: float | tuple[float, float] + Physical pixel size. If scalar, isotropic pixels are assumed. wavelength: float The wavelength of the optical field. 
dx: float, optional @@ -140,14 +140,22 @@ def get_propagation_matrix( """ + if pixel_size is None: + pixel_size = get_active_voxel_size() + + if np.isscalar(pixel_size): + pixel_size = (pixel_size, pixel_size) + + px, py = pixel_size + k = 2 * np.pi / wavelength yr, xr, *_ = shape x = np.arange(0, xr, 1) - xr / 2 + (xr % 2) / 2 y = np.arange(0, yr, 1) - yr / 2 + (yr % 2) / 2 - x = 2 * np.pi / pixel_size[0] * x / xr - y = 2 * np.pi / pixel_size[1] * y / yr + x = 2 * np.pi / px * x / xr + y = 2 * np.pi / py * y / yr KXk, KYk = np.meshgrid(x, y) KXk = KXk.astype(complex) diff --git a/deeptrack/optics.py b/deeptrack/optics.py index 092696a5..3990e3c1 100644 --- a/deeptrack/optics.py +++ b/deeptrack/optics.py @@ -291,6 +291,15 @@ def get( # Grab properties from the objective to pass to the sample additional_sample_kwargs = self._objective.properties() + contrast_type = getattr(self._objective, "contrast_type", None) + if contrast_type is None: + raise RuntimeError( + f"{self._objective.__class__.__name__} must define `contrast_type` " + "(e.g. 'intensity' or 'refractive_index')." + ) + + additional_sample_kwargs["contrast_type"] = contrast_type + # # Calculate required output image for the given upscale # # This way of providing the upscale will be deprecated in the future @@ -384,25 +393,25 @@ def get( # UserWarning, # ) - # Collect main_property from scatterers - main_properties = { - s.main_property - for s in list_of_scatterers - if hasattr(s, "main_property") - } + # # Collect main_property from scatterers + # main_properties = { + # s.main_property + # for s in list_of_scatterers + # if hasattr(s, "main_property") + # } - if len(main_properties) != 1: - raise ValueError( - f"Inconsistent main_property across scatterers: {main_properties}" - ) + # if len(main_properties) != 1: + # raise ValueError( + # f"Inconsistent main_property across scatterers: {main_properties}" + # ) - main_property = main_properties.pop() + # main_property = main_properties.pop() # Handling upscale from dt.Upscale() here to eliminate Image # wrapping issues. if np.any(np.array(upscale) != 1): ux, uy = upscale[:2] - if main_property == "intensity": + if contrast_type == "intensity": print("Using sum pooling for intensity downscaling.") imaged_sample = SumPoolingCM((ux, uy, 1))(imaged_sample) else: @@ -1045,6 +1054,7 @@ class Fluorescence(Optics): 1.4 """ + contrast_type = "intensity" def get( self: Fluorescence, @@ -1270,6 +1280,8 @@ class Brightfield(Optics): """ + contrast_type = "refractive_index" + __conversion_table__ = ConversionTable( working_distance=(u.meter, u.meter), ) @@ -1988,6 +2000,12 @@ def _create_volume( Spatial limits of the volume. """ + contrast_type = kwargs.get("contrast_type", None) + if contrast_type is None: + raise RuntimeError( + "_create_volume requires a contrast_type " + "(e.g. 
'intensity' or 'refractive_index')" + ) if not isinstance(list_of_scatterers, list): list_of_scatterers = [list_of_scatterers] @@ -2016,12 +2034,30 @@ def _create_volume( for scatterer in list_of_scatterers: position = _get_position(scatterer, mode="corner", return_z=True) - if scatterer.main_property == "intensity": - scatterer_value = scatterer.get_property("intensity") #* fudge_factor - elif scatterer.main_property == "refractive_index": - scatterer_value = scatterer.get_property("refractive_index") - refractive_index_medium - else: # fallback to generic value - scatterer_value = scatterer.get_property("value") + # if scatterer.main_property == "intensity": + # scatterer_value = scatterer.get_property("intensity") #* fudge_factor + # elif scatterer.main_property == "refractive_index": + # scatterer_value = scatterer.get_property("refractive_index") - refractive_index_medium + # else: # fallback to generic value + # scatterer_value = scatterer.get_property("value") + + # # Scale the array accordingly + # scatterer.array = scatterer.array * scatterer_value + + if contrast_type == "intensity": + value = scatterer.get_property("intensity", None) + if value is None: + raise ValueError("Scatterer has no intensity.") + scatterer_value = value + + elif contrast_type == "refractive_index": + ri = scatterer.get_property("refractive_index", None) + if ri is None: + raise ValueError("Scatterer has no refractive_index.") + scatterer_value = ri - refractive_index_medium + + else: + raise RuntimeError(f"Unknown contrast_type: {contrast_type}") # Scale the array accordingly scatterer.array = scatterer.array * scatterer_value diff --git a/deeptrack/scatterers.py b/deeptrack/scatterers.py index a9681953..a93a732e 100644 --- a/deeptrack/scatterers.py +++ b/deeptrack/scatterers.py @@ -353,8 +353,8 @@ def _wrap_output(self, array, props) -> ScatteredVolume: position=props.get("position", (0, 0)), z=props.get("z", 0.0), value=props.get("value", 1.0), - intensity=props.get("intensity", None), - refractive_index=props.get("refractive_index", None), + intensity=props.get("intensity", 1.0), + refractive_index=props.get("refractive_index", 1.59), properties=props.copy(), main_property=self.main_property, )] @@ -364,7 +364,7 @@ def _wrap_output(self, array, props) -> ScatteredField: return [ScatteredField( array=array, position=props.get("position", (0, 0)), - wavelength=props.get("wavelength", 0.0), + wavelength=props.get("wavelength", 532.0), properties=props.copy(), main_property=self.main_property, )] @@ -394,8 +394,6 @@ class PointParticle(VolumeScatterer): for `Brightfield` and `intensity` for `Fluorescence`). """ - - main_property = "intensity" def __init__( self: PointParticle, From 4527e7edcf50f664336e24c09cdfecb6471194a5 Mon Sep 17 00:00:00 2001 From: Carlo Date: Thu, 8 Jan 2026 00:08:26 +0100 Subject: [PATCH 12/24] u --- deeptrack/optics.py | 23 ----------------------- 1 file changed, 23 deletions(-) diff --git a/deeptrack/optics.py b/deeptrack/optics.py index 3990e3c1..98026db4 100644 --- a/deeptrack/optics.py +++ b/deeptrack/optics.py @@ -393,20 +393,6 @@ def get( # UserWarning, # ) - # # Collect main_property from scatterers - # main_properties = { - # s.main_property - # for s in list_of_scatterers - # if hasattr(s, "main_property") - # } - - # if len(main_properties) != 1: - # raise ValueError( - # f"Inconsistent main_property across scatterers: {main_properties}" - # ) - - # main_property = main_properties.pop() - # Handling upscale from dt.Upscale() here to eliminate Image # wrapping issues. 
if np.any(np.array(upscale) != 1): @@ -2034,15 +2020,6 @@ def _create_volume( for scatterer in list_of_scatterers: position = _get_position(scatterer, mode="corner", return_z=True) - # if scatterer.main_property == "intensity": - # scatterer_value = scatterer.get_property("intensity") #* fudge_factor - # elif scatterer.main_property == "refractive_index": - # scatterer_value = scatterer.get_property("refractive_index") - refractive_index_medium - # else: # fallback to generic value - # scatterer_value = scatterer.get_property("value") - - # # Scale the array accordingly - # scatterer.array = scatterer.array * scatterer_value if contrast_type == "intensity": value = scatterer.get_property("intensity", None) From 93a0ef4bcfafca040f07e68c7504dbc9770bacef Mon Sep 17 00:00:00 2001 From: Carlo Date: Thu, 8 Jan 2026 00:36:08 +0100 Subject: [PATCH 13/24] rebased ok --- deeptrack/optics.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/deeptrack/optics.py b/deeptrack/optics.py index 98026db4..9d10cbe5 100644 --- a/deeptrack/optics.py +++ b/deeptrack/optics.py @@ -250,7 +250,7 @@ def __init__( self._sample = self.add_feature(sample) self._objective = self.add_feature(objective) - self._sample.store_properties() + # self._sample.store_properties() def get( self: Microscope, @@ -1047,7 +1047,7 @@ def get( illuminated_volume: ArrayLike[complex], limits: ArrayLike[int], **kwargs: Any, - ) -> Image: + ) -> ArrayLike[complex]: """Simulates the imaging process using a fluorescence microscope. This method convolves the 3D illuminated volume with a pupil function @@ -2055,7 +2055,7 @@ def _create_volume( # Pad scatterer to avoid edge effects during interpolation padded_scatterer = scatterer padded_scatterer.array = np.pad( - scatterer.array._value, # this is a temporary fix, Image should be removed from features + scatterer.array, [(2, 2), (2, 2), (2, 2)], "constant", constant_values=0, From dab26360fe6e7d156eb182fb6d1bef493c3240dc Mon Sep 17 00:00:00 2001 From: Carlo Date: Thu, 8 Jan 2026 11:47:31 +0100 Subject: [PATCH 14/24] u --- deeptrack/optics.py | 54 +++++++++++------------------------------ deeptrack/scatterers.py | 10 ++++---- 2 files changed, 19 insertions(+), 45 deletions(-) diff --git a/deeptrack/optics.py b/deeptrack/optics.py index 9d10cbe5..a7c44e8c 100644 --- a/deeptrack/optics.py +++ b/deeptrack/optics.py @@ -154,7 +154,7 @@ def _pad_volume( from deeptrack.math import AveragePooling from deeptrack.features import propagate_data_to_dependencies from deeptrack.features import DummyFeature, Feature, StructuralFeature -from deeptrack.image import Image, pad_image_to_fft +from deeptrack.image import pad_image_to_fft from deeptrack.types import ArrayLike, PropertyLike from deeptrack import image @@ -198,7 +198,7 @@ class Microscope(StructuralFeature): Methods ------- - `get(image: Image or None, **kwargs: Any) -> Image` + `get(image: np.ndarray or None, **kwargs: Any) -> np.ndarray` Simulates the imaging process using the defined optical system and returns the resulting image. @@ -254,9 +254,9 @@ def __init__( def get( self: Microscope, - image: Image | None, + image: np.ndarray | None, **kwargs: Any, - ) -> Image: + ) -> np.ndarray: """Generate an image of the sample using the defined optical system. This method processes the sample through the optical system to @@ -264,14 +264,14 @@ def get( Parameters ---------- - image: Image | None + image: np.ndarray | None The input image to be processed. If None, a new image is created. 
**kwargs: Any Additional parameters for the imaging process. Returns ------- - Image: Image + image: np.ndarray The processed image after applying the optical system. Examples @@ -300,14 +300,6 @@ def get( additional_sample_kwargs["contrast_type"] = contrast_type - - # # Calculate required output image for the given upscale - # # This way of providing the upscale will be deprecated in the future - # # in favor of dt.Upscale(). - # _upscale_given_by_optics = additional_sample_kwargs["upscale"] - # if np.array(_upscale_given_by_optics).size == 1: - # _upscale_given_by_optics = (_upscale_given_by_optics,) * 3 - with u.context( create_context( *additional_sample_kwargs["voxel_size"]#, *_upscale_given_by_optics @@ -359,21 +351,12 @@ def get( for scatterer in list_of_scatterers if isinstance(scatterer, ScatteredField) ] - - # warn_upscale_fields = False - # if field_samples and np.any(upscale != 1): - # warn_upscale_fields = True # Merge all volumes into a single volume. sample_volume, limits = _create_volume( volume_samples, **additional_sample_kwargs, ) - # sample_volume = Image(sample_volume) - - # Merge all properties into the volume. - # for scatterer in volume_samples + field_samples: - # sample_volume.merge_properties_from(scatterer) # Let the objective know about the limits of the volume and all the fields. propagate_data_to_dependencies( @@ -384,15 +367,6 @@ def get( imaged_sample = self._objective.resolve(sample_volume) - # if warn_upscale_fields: - # warnings.warn( - # "dt.Upscale is active while FieldScatterers are present. " - # "Coherent fields are injected without resampling, so the " - # "physical interpretation may change with Upscale. " - # "This behavior is currently undefined.", - # UserWarning, - # ) - # Handling upscale from dt.Upscale() here to eliminate Image # wrapping issues. if np.any(np.array(upscale) != 1): @@ -1024,7 +998,7 @@ class Fluorescence(Optics): Methods ------- - `get(illuminated_volume: array_like[complex], limits: array_like[int, int], **kwargs: Any) -> Image` + `get(illuminated_volume: array_like[complex], limits: array_like[int, int], **kwargs: Any) -> np.ndarray` Simulates the imaging process using a fluorescence microscope. Examples @@ -1066,7 +1040,7 @@ def get( Returns ------- - Image: Image + image: np.ndarray A 2D image object representing the fluorescence projection. Notes @@ -1084,7 +1058,7 @@ def get( >>> optics = dt.Fluorescence( ... NA=1.4, wavelength=0.52e-6, magnification=60, ... ) - >>> volume = dt.Image(np.ones((128, 128, 10), dtype=complex)) + >>> volume = np.ones((128, 128, 10), dtype=complex) >>> limits = np.array([[0, 128], [0, 128], [0, 10]]) >>> properties = optics.properties() >>> filtered_properties = { @@ -1250,7 +1224,7 @@ class Brightfield(Optics): ------- `get(illuminated_volume: array_like[complex], limits: array_like[int, int], fields: array_like[complex], - **kwargs: Any) -> Image` + **kwargs: Any) -> np.ndarray` Simulates imaging with brightfield microscopy. @@ -1278,7 +1252,7 @@ def get( limits: ArrayLike[int], fields: ArrayLike[complex], **kwargs: Any, - ) -> Image: + ) -> np.ndarray: """Simulates imaging with brightfield microscopy. This method propagates light through the given volume, applying @@ -1303,7 +1277,7 @@ def get( Returns ------- - Image: Image + image: np.ndarray Processed image after simulating the brightfield imaging process. Examples @@ -1318,7 +1292,7 @@ def get( ... wavelength=0.52e-6, ... magnification=60, ... 
) - >>> volume = dt.Image(np.ones((128, 128, 10), dtype=complex)) + >>> volume = np.ones((128, 128, 10), dtype=complex) >>> limits = np.array([[0, 128], [0, 128], [0, 10]]) >>> fields = np.array([np.ones((162, 162), dtype=complex)]) >>> properties = optics.properties() @@ -1665,7 +1639,7 @@ def get( limits: ArrayLike[int], fields: ArrayLike[complex], **kwargs: Any, - ) -> Image: + ) -> np.ndarray: """Retrieve the darkfield image of the illuminated volume. Parameters diff --git a/deeptrack/scatterers.py b/deeptrack/scatterers.py index a93a732e..ea0e140d 100644 --- a/deeptrack/scatterers.py +++ b/deeptrack/scatterers.py @@ -176,7 +176,7 @@ ) from deeptrack.backend import mie from deeptrack.features import Feature, MERGE_STRATEGY_APPEND -from deeptrack.image import pad_image_to_fft, Image +from deeptrack.image import pad_image_to_fft from deeptrack.types import ArrayLike from deeptrack import units_registry as u @@ -300,7 +300,7 @@ def _process_and_get( upsample_axes=None, crop_empty=True, **kwargs - ) -> list[Image] | list[np.ndarray]: + ) -> list[np.ndarray]: # Post processes the created object to handle upsampling, # as well as cropping empty slices. if not self._processed_properties: @@ -407,7 +407,7 @@ def __init__( def get( self: PointParticle, - image: Image | np.ndarray, + image: np.ndarray, **kwarg: Any, ) -> NDArray[Any] | torch.Tensor: """Evaluate and return the scatterer volume.""" @@ -576,7 +576,7 @@ def __init__( def get( self, - image: Image | np.ndarray, + image: np.ndarray, radius: float, voxel_size: float, **kwargs @@ -713,7 +713,7 @@ def _process_properties( def get( self, - image: Image | np.ndarray, + image: np.ndarray, radius: float, rotation: ArrayLike[float] | float, voxel_size: float, From e04589f0efb3ae940459f966b15070d5f956f097 Mon Sep 17 00:00:00 2001 From: Carlo Date: Fri, 9 Jan 2026 01:40:22 +0100 Subject: [PATCH 15/24] fixed sampletomask --- deeptrack/features.py | 232 ++++++++++++++++++++++++++++------------ deeptrack/optics.py | 6 +- deeptrack/scatterers.py | 60 ++++++----- 3 files changed, 204 insertions(+), 94 deletions(-) diff --git a/deeptrack/features.py b/deeptrack/features.py index f76245b8..8a031051 100644 --- a/deeptrack/features.py +++ b/deeptrack/features.py @@ -171,7 +171,7 @@ from deeptrack.backend import config, TORCH_AVAILABLE, xp from deeptrack.backend.core import DeepTrackNode from deeptrack.backend.units import ConversionTable, create_context -from deeptrack.image import Image #TODO TBE +# from deeptrack.image import Image #TODO TBE from deeptrack.properties import PropertyDict, SequentialProperty from deeptrack.sources import SourceItem from deeptrack.types import ArrayLike, PropertyLike @@ -218,11 +218,11 @@ "OneOf", "OneOfDict", "LoadImage", - "SampleToMasks", # TODO ***CM*** revise this after elimination of Image + "SampleToMasks", "AsType", "ChannelFirst2d", - "Upscale", # TODO ***CM*** revise and check PyTorch afrer elimin. Image - "NonOverlapping", # TODO ***CM*** revise + PyTorch afrer elimin. Image + "Upscale", + "NonOverlapping", "Store", "Squeeze", "Unsqueeze", @@ -7493,7 +7493,7 @@ def __init__( def get( self: Feature, - image: np.ndarray | Image, + image: np.ndarray, transformation_function: Callable[[Image], Image], **kwargs: Any, ) -> Image: @@ -7515,7 +7515,7 @@ def get( """ - return transformation_function(image) + return transformation_function(image.array) def _process_and_get( self: Feature, @@ -7540,26 +7540,33 @@ def _process_and_get( """ # Handle list of images. 
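With `get` now unwrapping the scatterer before the call (`transformation_function(image.array)`), a `SampleToMasks` transform becomes plain array-in/array-out; the hunk continues below. A minimal sketch of such a transform (the function name is illustrative, not library API):

```python
import numpy as np

def to_binary_mask(volume: np.ndarray) -> np.ndarray:
    # Non-zero voxels of the scatterer volume become mask pixels.
    return (volume != 0).astype(np.float32)
```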
- if isinstance(images, list) and len(images) != 1: - list_of_labels = super()._process_and_get(images, **kwargs) - if not self._wrap_array_with_image: - for idx, (label, image) in enumerate(zip(list_of_labels, - images)): - list_of_labels[idx] = \ - Image(label, copy=False).merge_properties_from(image) - else: - if isinstance(images, list): - images = images[0] - list_of_labels = [] - for prop in images.properties: - - if "position" in prop: + # if isinstance(images, list) and len(images) != 1: + list_of_labels = super()._process_and_get(images, **kwargs) + # print(len(list_of_labels)) + # print(list_of_labels[0].shape) + + from deeptrack.scatterers import ScatteredVolume + # if not self._wrap_array_with_image: + for idx, (label, image) in enumerate(zip(list_of_labels, + images)): + list_of_labels[idx] = \ + ScatteredVolume(array=label, properties=image.properties.copy()) + # Image(label, copy=False).merge_properties_from(image) + # else: + # if isinstance(images, list): + # images = images[0] + # list_of_labels = [] + # for prop in images.properties: + + # if "position" in prop: + + # inp = Image(np.array(images)) + # inp.append(prop) + # out = Image(self.get(inp, **kwargs)) + # out.merge_properties_from(inp) + # list_of_labels.append(out) - inp = Image(np.array(images)) - inp.append(prop) - out = Image(self.get(inp, **kwargs)) - out.merge_properties_from(inp) - list_of_labels.append(out) + # Create an empty output image. output_region = kwargs["output_region"] @@ -7574,8 +7581,10 @@ def _process_and_get( from deeptrack.optics import _get_position # Merge masks into the output. - for label in list_of_labels: - position = _get_position(label) + for volume in list_of_labels: + label = volume.array + position = _get_position(volume) + p0 = np.round(position - output_region[0:2]) if np.any(p0 > output.shape[0:2]) or \ @@ -7657,11 +7666,11 @@ def _process_and_get( labelarg[..., label_index], ) - if not self._wrap_array_with_image: - return output - output = Image(output) - for label in list_of_labels: - output.merge_properties_from(label) + # if not self._wrap_array_with_image: + # return output + # output = Image(output) + # for label in list_of_labels: + # output.merge_properties_from(label) return output @@ -8087,7 +8096,7 @@ def get( raise ValueError( "Factor must be an integer or a tuple of three integers." ) - + # Create a context for upscaling and perform computation. ctx = create_context(None, None, None, *factor) with units.context(ctx): @@ -8356,7 +8365,7 @@ def get( list_of_volumes = [list_of_volumes] for _ in range(max_iters): - + list_of_volumes = [ self._resample_volume_position(volume) for volume in list_of_volumes @@ -8411,32 +8420,40 @@ def _check_non_overlapping( - If bounding cubes overlap, voxel-level checks are performed. 
""" + from deeptrack.scatterers import ScatteredVolume - from skimage.morphology import isotropic_erosion, isotropic_dilation - - from deeptrack.augmentations import CropTight, Pad + from deeptrack.augmentations import CropTight, Pad # these are not compatibles with torch backend from deeptrack.optics import _get_position min_distance = self.min_distance() crop = CropTight() + + new_volumes = [] - if min_distance < 0: - list_of_volumes = [ - Image( - crop(isotropic_erosion(volume != 0, -min_distance/2)), - copy=False, - ).merge_properties_from(volume) - for volume in list_of_volumes - ] - else: - pad = Pad(px = [int(np.ceil(min_distance/2))]*6, keep_size=True) - list_of_volumes = [ - Image( - crop(isotropic_dilation(pad(volume) != 0, min_distance/2)), - copy=False, - ).merge_properties_from(volume) - for volume in list_of_volumes - ] + for volume in list_of_volumes: + arr = volume.array + mask = arr != 0 + + if min_distance < 0: + new_arr = isotropic_erosion(mask, -min_distance / 2, backend=self.get_backend()) + else: + pad = Pad(px=[int(np.ceil(min_distance / 2))] * 6, keep_size=True) + new_arr = isotropic_dilation(pad(mask) != 0 , min_distance / 2, backend=self.get_backend()) + new_arr = crop(new_arr) + + if self.get_backend() == "torch": + new_arr = new_arr.to(dtype=arr.dtype) + else: + new_arr = new_arr.astype(arr.dtype) + + new_volume = ScatteredVolume( + array=new_arr, + properties=volume.properties.copy(), + ) + + new_volumes.append(new_volume) + + list_of_volumes = new_volumes min_distance = 1 # The position of the top left corner of each volume (index (0, 0, 0)). @@ -8472,10 +8489,10 @@ def _check_non_overlapping( volume_bounding_cube[i], volume_bounding_cube[j] ) overlapping_volume_1 = self._get_overlapping_volume( - list_of_volumes[i], volume_bounding_cube[i], overlapping_cube + list_of_volumes[i].array, volume_bounding_cube[i], overlapping_cube ) overlapping_volume_2 = self._get_overlapping_volume( - list_of_volumes[j], volume_bounding_cube[j], overlapping_cube + list_of_volumes[j].array, volume_bounding_cube[j], overlapping_cube ) # If either the overlapping regions are empty, the volumes do not @@ -8710,8 +8727,12 @@ def _check_volumes_non_overlapping( """ # Get the positions of the non-zero voxels of each volume. - positions_1 = np.argwhere(volume_1) - positions_2 = np.argwhere(volume_2) + if self.get_backend() == "torch": + positions_1 = torch.nonzero(volume_1, as_tuple=False) + positions_2 = torch.nonzero(volume_2, as_tuple=False) + else: + positions_1 = np.argwhere(volume_1) + positions_2 = np.argwhere(volume_2) # if positions_1.size == 0 or positions_2.size == 0: # return True # If either volume is empty, they are "non-overlapping" @@ -8732,9 +8753,14 @@ def _check_volumes_non_overlapping( # Check that the non-zero voxels of the volumes are at least # min_distance apart. - return np.all( - cdist(positions_1, positions_2) > min_distance - ) + if self.get_backend() == "torch": + dist = torch.cdist( + positions_1.float(), + positions_2.float(), + ) + return bool((dist > min_distance).all()) + else: + return np.all(cdist(positions_1, positions_2) > min_distance) def _resample_volume_position( self: NonOverlapping, @@ -8750,7 +8776,7 @@ def _resample_volume_position( Parameters ---------- - volume: np.ndarray or Image + volume: np.ndarray The 3D volume whose position is to be resampled. The volume must have a `properties` attribute containing dictionaries with `position` and `_position_sampler` keys. 
@@ -8771,12 +8797,12 @@ def _resample_volume_position( """ - for pdict in volume.properties: - if "position" in pdict and "_position_sampler" in pdict: - new_position = pdict["_position_sampler"]() - if isinstance(new_position, Quantity): - new_position = new_position.to("pixel").magnitude - pdict["position"] = new_position + pdict = volume.properties + if "position" in pdict and "_position_sampler" in pdict: + new_position = pdict["_position_sampler"]() + if isinstance(new_position, Quantity): + new_position = new_position.to("pixel").magnitude + pdict["position"] = new_position return volume @@ -9594,3 +9620,73 @@ def get( res = res[0] return res + +### Move to math? +def isotropic_dilation( + mask, + radius: float, + *, + backend: str, + device=None, + dtype=None, +): + if radius <= 0: + return mask + + if backend == "numpy": + from skimage.morphology import isotropic_dilation + return isotropic_dilation(mask, radius) + + # torch backend + import torch + + r = int(np.ceil(radius)) + kernel = torch.ones( + (1, 1, 2 * r + 1, 2 * r + 1, 2 * r + 1), + device=device or mask.device, + dtype=dtype or torch.float32, + ) + + x = mask.to(dtype=kernel.dtype)[None, None] + y = torch.nn.functional.conv3d( + x, + kernel, + padding=r, + ) + + return (y[0, 0] > 0) + + +def isotropic_erosion( + mask, + radius: float, + *, + backend: str, + device=None, + dtype=None, +): + if radius <= 0: + return mask + + if backend == "numpy": + from skimage.morphology import isotropic_erosion + return isotropic_erosion(mask, radius) + + import torch + + r = int(np.ceil(radius)) + kernel = torch.ones( + (1, 1, 2 * r + 1, 2 * r + 1, 2 * r + 1), + device=device or mask.device, + dtype=dtype or torch.float32, + ) + + x = mask.to(dtype=kernel.dtype)[None, None] + y = torch.nn.functional.conv3d( + x, + kernel, + padding=r, + ) + + required = kernel.numel() + return (y[0, 0] >= required) diff --git a/deeptrack/optics.py b/deeptrack/optics.py index a7c44e8c..a0ad2be9 100644 --- a/deeptrack/optics.py +++ b/deeptrack/optics.py @@ -2027,13 +2027,15 @@ def _create_volume( continue # Pad scatterer to avoid edge effects during interpolation - padded_scatterer = scatterer - padded_scatterer.array = np.pad( + padded_scatterer_arr = np.pad( #torch? 
scatterer.array, [(2, 2), (2, 2), (2, 2)], "constant", constant_values=0, ) + padded_scatterer = ScatteredVolume( + array=padded_scatterer_arr,properties=scatterer.properties.copy() + ) position = _get_position(padded_scatterer, mode="corner", return_z=True) shape = np.array(padded_scatterer.array.shape) diff --git a/deeptrack/scatterers.py b/deeptrack/scatterers.py index ea0e140d..40d3fce3 100644 --- a/deeptrack/scatterers.py +++ b/deeptrack/scatterers.py @@ -348,26 +348,26 @@ def _wrap_output(self, array, props) -> ScatteredBase: class VolumeScatterer(Scatterer): """Abstract scatterer producing ScatteredVolume outputs.""" def _wrap_output(self, array, props) -> ScatteredVolume: - return [ScatteredVolume( + return ScatteredVolume( array=array, - position=props.get("position", (0, 0)), - z=props.get("z", 0.0), - value=props.get("value", 1.0), - intensity=props.get("intensity", 1.0), - refractive_index=props.get("refractive_index", 1.59), + # position=props.get("position", (0, 0)), + # z=props.get("z", 0.0), + # value=props.get("value", 1.0), + # intensity=props.get("intensity", 1.0), + # refractive_index=props.get("refractive_index", 1.59), properties=props.copy(), - main_property=self.main_property, - )] + # position_sampler=props.get("_position_sampler", None), + ) class FieldScatterer(Scatterer): def _wrap_output(self, array, props) -> ScatteredField: - return [ScatteredField( + return ScatteredField( array=array, - position=props.get("position", (0, 0)), - wavelength=props.get("wavelength", 532.0), + # position=props.get("position", (0, 0)), + # wavelength=props.get("wavelength", 532.0), properties=props.copy(), - main_property=self.main_property, - )] + # position_sampler=props.get("_position_sampler", None), + ) #TODO ***??*** revise PointParticle - torch, typing, docstring, unit test @@ -1468,14 +1468,23 @@ class ScatteredBase: """Base class for scatterers (volumes and fields).""" array: ArrayLike - position: np.ndarray - z: float = 0.0 + # position: np.ndarray + # z: float = 0.0 properties: dict[str, Any] = field(default_factory=dict) - main_property: str = None - def __post_init__(self): - self.position = np.array(self.position, dtype=float).reshape(-1)[:2] - self.z = float(np.atleast_1d(self.z).squeeze()) + # def __post_init__(self): + # self.position = np.array(self.position, dtype=float).reshape(-1)[:2] + # self.z = float(np.atleast_1d(self.z).squeeze()) + + @property + def ndim(self) -> int: + """Number of dimensions of the underlying array.""" + return self.array.ndim + + @property + def shape(self) -> int: + """Number of dimensions of the underlying array.""" + return self.array.shape @property def pos3d(self) -> np.ndarray: @@ -1501,13 +1510,16 @@ def get_property(self, key: str, default: Any = None) -> Any: class ScatteredVolume(ScatteredBase): """Volumetric object: intensity sources or refractive index contrasts.""" - refractive_index: float | None = None - intensity: float | None = None - value: float | None = None - + # refractive_index: float | None = None + # intensity: float | None = None + # value: float | None = None + # position_sampler: Optional[Callable[[], np.ndarray]] = None + pass @dataclass class ScatteredField(ScatteredBase): """Complex wavefield (already propagated or emitted).""" - wavelength: float = 500e-9 \ No newline at end of file + # wavelength: float = 500e-9 + # position_sampler: Optional[Callable[[], np.ndarray]] = None + pass \ No newline at end of file From dcb6fed9022ea1e788d5f3e32b63b5d53042ba9e Mon Sep 17 00:00:00 2001 From: Carlo Date: Fri, 
9 Jan 2026 01:53:36 +0100 Subject: [PATCH 16/24] u --- deeptrack/scatterers.py | 20 -------------------- 1 file changed, 20 deletions(-) diff --git a/deeptrack/scatterers.py b/deeptrack/scatterers.py index 40d3fce3..52cbd576 100644 --- a/deeptrack/scatterers.py +++ b/deeptrack/scatterers.py @@ -247,9 +247,6 @@ class Scatterer(Feature): voxel_size=(u.meter, u.meter), ) - #: Default property name (subclasses override this) - main_property: str = "value" - def __init__( self, position: ArrayLike[float] = (32, 32), @@ -350,23 +347,14 @@ class VolumeScatterer(Scatterer): def _wrap_output(self, array, props) -> ScatteredVolume: return ScatteredVolume( array=array, - # position=props.get("position", (0, 0)), - # z=props.get("z", 0.0), - # value=props.get("value", 1.0), - # intensity=props.get("intensity", 1.0), - # refractive_index=props.get("refractive_index", 1.59), properties=props.copy(), - # position_sampler=props.get("_position_sampler", None), ) class FieldScatterer(Scatterer): def _wrap_output(self, array, props) -> ScatteredField: return ScatteredField( array=array, - # position=props.get("position", (0, 0)), - # wavelength=props.get("wavelength", 532.0), properties=props.copy(), - # position_sampler=props.get("_position_sampler", None), ) @@ -459,8 +447,6 @@ class Ellipse(VolumeScatterer): rotation=(u.radian, u.radian), ) - main_property = "refractive_index" - def __init__( self, radius: float = 1e-6, @@ -565,8 +551,6 @@ class Sphere(VolumeScatterer): radius=(u.meter, u.meter), ) - main_property = "refractive_index" - def __init__( self, radius: float = 1e-6, @@ -642,8 +626,6 @@ class Ellipsoid(VolumeScatterer): rotation=(u.radian, u.radian), ) - main_property = "refractive_index" - def __init__( self, radius: float = 1e-6, @@ -854,8 +836,6 @@ class MieScatterer(FieldScatterer): coherence_length=(u.meter, u.pixel), ) - main_property = "wavelength" - def __init__( self, coefficients, From c5ab7b00c9ad716eb6fe4eeb5d8f5a32898f4f1b Mon Sep 17 00:00:00 2001 From: Carlo Date: Fri, 9 Jan 2026 09:52:44 +0100 Subject: [PATCH 17/24] sampletomask torch compatible --- deeptrack/features.py | 78 ++++++++++++++++++------------------------- deeptrack/optics.py | 4 +-- 2 files changed, 35 insertions(+), 47 deletions(-) diff --git a/deeptrack/features.py b/deeptrack/features.py index 8a031051..1a669887 100644 --- a/deeptrack/features.py +++ b/deeptrack/features.py @@ -171,7 +171,7 @@ from deeptrack.backend import config, TORCH_AVAILABLE, xp from deeptrack.backend.core import DeepTrackNode from deeptrack.backend.units import ConversionTable, create_context -# from deeptrack.image import Image #TODO TBE +from deeptrack.image import Image #TODO TBE from deeptrack.properties import PropertyDict, SequentialProperty from deeptrack.sources import SourceItem from deeptrack.types import ArrayLike, PropertyLike @@ -7398,7 +7398,7 @@ class SampleToMasks(Feature): Returns ------- - Image or np.ndarray + np.ndarray The final mask image with the specified number of layers. 
Raises
 
@@ -7460,7 +7460,7 @@ class SampleToMasks(Feature):
 
     def __init__(
         self: Feature,
-        transformation_function: Callable[[Image], Image],
+        transformation_function: Callable[[np.ndarray | torch.Tensor], np.ndarray | torch.Tensor],
         number_of_masks: PropertyLike[int] = 1,
         output_region: PropertyLike[tuple[int, int, int, int]] = None,
         merge_method: PropertyLike[str | Callable | list[str | Callable]] = "add",
@@ -7494,16 +7494,16 @@ def __init__(
     def get(
         self: Feature,
         image: np.ndarray,
-        transformation_function: Callable[[Image], Image],
+        transformation_function: Callable[[np.ndarray | torch.Tensor], np.ndarray | torch.Tensor],
         **kwargs: Any,
-    ) -> Image:
+    ) -> np.ndarray:
         """Apply the transformation function to a single image.
 
         Parameters
         ----------
-        image: np.ndarray | Image
+        image: np.ndarray
             The input image.
-        transformation_function: Callable[[Image], Image]
+        transformation_function: Callable[[np.ndarray | torch.Tensor], np.ndarray | torch.Tensor]
             Function to transform the image.
         **kwargs: dict[str, Any]
             Additional parameters.
@@ -7519,9 +7519,9 @@ def get(
 
     def _process_and_get(
         self: Feature,
-        images: list[np.ndarray] | np.ndarray | list[Image] | Image,
+        images: list[np.ndarray] | np.ndarray | list[torch.Tensor] | torch.Tensor,
         **kwargs: Any,
-    ) -> Image | np.ndarray:
+    ) -> np.ndarray:
         """Process a list of images and generate a multi-layer mask.
 
         Parameters
@@ -7542,40 +7542,21 @@ def _process_and_get(
         # Handle list of images.
         # if isinstance(images, list) and len(images) != 1:
         list_of_labels = super()._process_and_get(images, **kwargs)
-        # print(len(list_of_labels))
-        # print(list_of_labels[0].shape)
 
         from deeptrack.scatterers import ScatteredVolume
-        # if not self._wrap_array_with_image:
-        for idx, (label, image) in enumerate(zip(list_of_labels,
-                                                 images)):
+        for idx, (label, image) in enumerate(zip(list_of_labels, images)):
             list_of_labels[idx] = \
-                ScatteredVolume(array=label, properties=image.properties.copy())
-            # Image(label, copy=False).merge_properties_from(image)
-        # else:
-        #     if isinstance(images, list):
-        #         images = images[0]
-        #     list_of_labels = []
-        #     for prop in images.properties:
-
-        #         if "position" in prop:
-
-        #             inp = Image(np.array(images))
-        #             inp.append(prop)
-        #             out = Image(self.get(inp, **kwargs))
-        #             out.merge_properties_from(inp)
-        #             list_of_labels.append(out)
+            ScatteredVolume(array=label, properties=image.properties.copy())
 
         # Create an empty output image.
output_region = kwargs["output_region"] - output = np.zeros( + output = xp.zeros( ( output_region[2] - output_region[0], output_region[3] - output_region[1], kwargs["number_of_masks"], - ) + ), + dtype=list_of_labels[0].array.dtype, ) from deeptrack.optics import _get_position @@ -7585,14 +7566,22 @@ def _process_and_get( label = volume.array position = _get_position(volume) - p0 = np.round(position - output_region[0:2]) + # p0 = np.round(position - output_region[0:2]) + p0 = xp.round(position - xp.asarray(output_region[0:2])) + p0 = p0.astype(xp.int64) + - if np.any(p0 > output.shape[0:2]) or \ - np.any(p0 + label.shape[0:2] < 0): + # if np.any(p0 > output.shape[0:2]) or \ + # np.any(p0 + label.shape[0:2] < 0): + if xp.any(p0 > xp.asarray(output.shape[:2])) or \ + xp.any(p0 + xp.asarray(label.shape[:2]) < 0): continue - crop_x = int(-np.min([p0[0], 0])) - crop_y = int(-np.min([p0[1], 0])) + # crop_x = int(-np.min([p0[0], 0])) + # crop_y = int(-np.min([p0[1], 0])) + crop_x = (-xp.minimum(p0[0], 0)).item() + crop_y = (-xp.minimum(p0[1], 0)).item() + crop_x_end = int( label.shape[0] - np.max([p0[0] + label.shape[0] - output.shape[0], 0]) @@ -7644,9 +7633,13 @@ def _process_and_get( p0[0] : p0[0] + labelarg.shape[0], p0[1] : p0[1] + labelarg.shape[1], label_index, - ] = (output_slice[..., label_index] != 0) | ( + ] = xp.logical_or( + output_slice[..., label_index] != 0, labelarg[..., label_index] != 0 - ) + ) + # (output_slice[..., label_index] != 0) | ( + # labelarg[..., label_index] != 0 + # ) elif merge == "mul": output[ @@ -7666,11 +7659,6 @@ def _process_and_get( labelarg[..., label_index], ) - # if not self._wrap_array_with_image: - # return output - # output = Image(output) - # for label in list_of_labels: - # output.merge_properties_from(label) return output diff --git a/deeptrack/optics.py b/deeptrack/optics.py index a0ad2be9..c2c87b14 100644 --- a/deeptrack/optics.py +++ b/deeptrack/optics.py @@ -2027,14 +2027,14 @@ def _create_volume( continue # Pad scatterer to avoid edge effects during interpolation - padded_scatterer_arr = np.pad( #torch? + padded_scatterer_arr = np.pad( #Use Pad instead and make it torch-compatible? 
scatterer.array, [(2, 2), (2, 2), (2, 2)], "constant", constant_values=0, ) padded_scatterer = ScatteredVolume( - array=padded_scatterer_arr,properties=scatterer.properties.copy() + array=padded_scatterer_arr, properties=scatterer.properties.copy() ) position = _get_position(padded_scatterer, mode="corner", return_z=True) shape = np.array(padded_scatterer.array.shape) From fec2411db3bfdfe928e51c935bca921f8f4704ef Mon Sep 17 00:00:00 2001 From: Carlo Date: Fri, 9 Jan 2026 11:03:02 +0100 Subject: [PATCH 18/24] unified field and volume wrapper --- deeptrack/features.py | 18 +++----- deeptrack/optics.py | 38 ++++------------- deeptrack/scatterers.py | 93 +++++++++++++++++++++++------------------ 3 files changed, 65 insertions(+), 84 deletions(-) diff --git a/deeptrack/features.py b/deeptrack/features.py index 1a669887..be82b558 100644 --- a/deeptrack/features.py +++ b/deeptrack/features.py @@ -7543,10 +7543,11 @@ def _process_and_get( # if isinstance(images, list) and len(images) != 1: list_of_labels = super()._process_and_get(images, **kwargs) - from deeptrack.scatterers import ScatteredVolume + from deeptrack.scatterers import ScatteredObject + for idx, (label, image) in enumerate(zip(list_of_labels, images)): list_of_labels[idx] = \ - ScatteredVolume(array=label, properties=image.properties.copy()) + ScatteredObject(array=label, properties=image.properties.copy(), role=image.role) # Create an empty output image. output_region = kwargs["output_region"] @@ -7566,19 +7567,14 @@ def _process_and_get( label = volume.array position = _get_position(volume) - # p0 = np.round(position - output_region[0:2]) p0 = xp.round(position - xp.asarray(output_region[0:2])) p0 = p0.astype(xp.int64) - # if np.any(p0 > output.shape[0:2]) or \ - # np.any(p0 + label.shape[0:2] < 0): if xp.any(p0 > xp.asarray(output.shape[:2])) or \ xp.any(p0 + xp.asarray(label.shape[:2]) < 0): continue - # crop_x = int(-np.min([p0[0], 0])) - # crop_y = int(-np.min([p0[1], 0])) crop_x = (-xp.minimum(p0[0], 0)).item() crop_y = (-xp.minimum(p0[1], 0)).item() @@ -7637,9 +7633,6 @@ def _process_and_get( output_slice[..., label_index] != 0, labelarg[..., label_index] != 0 ) - # (output_slice[..., label_index] != 0) | ( - # labelarg[..., label_index] != 0 - # ) elif merge == "mul": output[ @@ -8408,7 +8401,7 @@ def _check_non_overlapping( - If bounding cubes overlap, voxel-level checks are performed. """ - from deeptrack.scatterers import ScatteredVolume + from deeptrack.scatterers import ScatteredObject from deeptrack.augmentations import CropTight, Pad # these are not compatibles with torch backend from deeptrack.optics import _get_position @@ -8434,9 +8427,10 @@ def _check_non_overlapping( else: new_arr = new_arr.astype(arr.dtype) - new_volume = ScatteredVolume( + new_volume = ScatteredObject( array=new_arr, properties=volume.properties.copy(), + role=volume.role, ) new_volumes.append(new_volume) diff --git a/deeptrack/optics.py b/deeptrack/optics.py index c2c87b14..d0b5d66b 100644 --- a/deeptrack/optics.py +++ b/deeptrack/optics.py @@ -162,7 +162,7 @@ def _pad_volume( from deeptrack import TORCH_AVAILABLE, image from deeptrack.backend import xp -from deeptrack.scatterers import ScatteredVolume, ScatteredField +from deeptrack.scatterers import ScatteredObject if TORCH_AVAILABLE: import torch @@ -342,14 +342,14 @@ def get( volume_samples = [ scatterer for scatterer in list_of_scatterers - if isinstance(scatterer, ScatteredVolume) + if scatterer.role == "volume" ] # All scatterers that are defined as fields. 
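PATCH 18 above collapses the two wrapper classes into a single `ScatteredObject` tagged by `role`; the comprehension just below filters fields by the same tag. A minimal construction sketch (the property values are illustrative):

```python
import numpy as np
from deeptrack.scatterers import ScatteredObject

obj = ScatteredObject(
    array=np.ones((5, 5, 5)),
    properties={"position": (32, 32), "z": 0.0, "intensity": 1.0},
    role="volume",
)
# ndim and get_property come from the dataclass defined in this patch.
assert obj.ndim == 3
assert obj.get_property("intensity") == 1.0
```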
field_samples = [ scatterer for scatterer in list_of_scatterers - if isinstance(scatterer, ScatteredField) + if scatterer.role == "field" ] # Merge all volumes into a single volume. @@ -1810,7 +1810,7 @@ def get( #TODO ***??*** revise _get_position - torch, typing, docstring, unit test def _get_position( - scatterer: ScatteredVolume, + scatterer: ScatteredObject, mode: str = "corner", return_z: bool = False, ) -> np.ndarray: @@ -2033,8 +2033,8 @@ def _create_volume( "constant", constant_values=0, ) - padded_scatterer = ScatteredVolume( - array=padded_scatterer_arr, properties=scatterer.properties.copy() + padded_scatterer = ScatteredObject( + array=padded_scatterer_arr, properties=scatterer.properties.copy(), role=scatterer.role, ) position = _get_position(padded_scatterer, mode="corner", return_z=True) shape = np.array(padded_scatterer.array.shape) @@ -2060,30 +2060,6 @@ def _create_volume( "Expected np.ndarray or torch.Tensor." ) - # kernel = np.array( - # [ - # [0, 0, 0], - # [0, (1 - x_off) * (1 - y_off), (1 - x_off) * y_off], - # [0, x_off * (1 - y_off), x_off * y_off], - # ] - # ) - - # for z in range(padded_scatterer.array.shape[2]): - # if splined_scatterer.dtype == complex: - # splined_scatterer[:, :, z] = ( - # convolve( - # np.real(padded_scatterer.array[:, :, z]), kernel, mode="constant" - # ) - # + convolve( - # np.imag(padded_scatterer.array[:, :, z]), kernel, mode="constant" - # ) - # * 1j - # ) - # else: - # splined_scatterer[:, :, z] = convolve( - # padded_scatterer.array[:, :, z], kernel, mode="constant" - # ) - position = np.floor(position) new_limits = np.zeros(limits.shape, dtype=np.int32) for i in range(3): @@ -2113,7 +2089,7 @@ def _create_volume( within_volume_position = position - limits[:, 0] # NOTE: Maybe shouldn't be additive. 
- # give options: sum default, but also sum, mean, max, min + # give options: sum default, but also mean, max, min, or volume[ int(within_volume_position[0]) : int(within_volume_position[0] + shape[0]), diff --git a/deeptrack/scatterers.py b/deeptrack/scatterers.py index 52cbd576..9e45664f 100644 --- a/deeptrack/scatterers.py +++ b/deeptrack/scatterers.py @@ -338,28 +338,33 @@ def _process_and_get( # props = kwargs.copy() return [self._wrap_output(new_image, kwargs)] - def _wrap_output(self, array, props) -> ScatteredBase: - """Must be overridden in subclasses to wrap output correctly.""" - raise NotImplementedError - -class VolumeScatterer(Scatterer): - """Abstract scatterer producing ScatteredVolume outputs.""" - def _wrap_output(self, array, props) -> ScatteredVolume: - return ScatteredVolume( + def _wrap_output(self, array, props) -> ScatteredObject: + # """Must be overridden in subclasses to wrap output correctly.""" + # raise NotImplementedError + return ScatteredObject( array=array, properties=props.copy(), + role = self.role, ) -class FieldScatterer(Scatterer): - def _wrap_output(self, array, props) -> ScatteredField: - return ScatteredField( - array=array, - properties=props.copy(), - ) +# class VolumeScatterer(Scatterer): +# """Abstract scatterer producing ScatteredVolume outputs.""" +# def _wrap_output(self, array, props) -> ScatteredVolume: +# return ScatteredVolume( +# array=array, +# properties=props.copy(), +# ) + +# class FieldScatterer(Scatterer): +# def _wrap_output(self, array, props) -> ScatteredField: +# return ScatteredField( +# array=array, +# properties=props.copy(), +# ) #TODO ***??*** revise PointParticle - torch, typing, docstring, unit test -class PointParticle(VolumeScatterer): +class PointParticle(Scatterer): """Generate a diffraction-limited point particle. A point particle is approximated by the size of a single pixel or voxel. @@ -382,7 +387,8 @@ class PointParticle(VolumeScatterer): for `Brightfield` and `intensity` for `Fluorescence`). """ - + role = "volume" + def __init__( self: PointParticle, **kwargs: Any, @@ -406,7 +412,7 @@ def get( #TODO ***??*** revise Ellipse - torch, typing, docstring, unit test -class Ellipse(VolumeScatterer): +class Ellipse(Scatterer): """Generates an elliptical disk scatterer Parameters @@ -441,6 +447,7 @@ class Ellipse(VolumeScatterer): before rotation. """ + role = "volume" __conversion_table__ = ConversionTable( radius=(u.meter, u.meter), @@ -520,7 +527,7 @@ def get( #TODO ***??*** revise Sphere - torch, typing, docstring, unit test -class Sphere(VolumeScatterer): +class Sphere(Scatterer): """Generates a spherical scatterer Parameters @@ -546,6 +553,7 @@ class Sphere(VolumeScatterer): Upsamples the calculations of the pixel occupancy fraction. """ + role = "volume" __conversion_table__ = ConversionTable( radius=(u.meter, u.meter), @@ -585,7 +593,7 @@ def get( #TODO ***??*** revise Ellipsoid - torch, typing, docstring, unit test -class Ellipsoid(VolumeScatterer): +class Ellipsoid(Scatterer): """Generates an ellipsoidal scatterer Parameters @@ -621,6 +629,8 @@ class Ellipsoid(VolumeScatterer): """ + role = "volume" + __conversion_table__ = ConversionTable( radius=(u.meter, u.meter), rotation=(u.radian, u.radian), @@ -742,7 +752,7 @@ def get( #TODO ***??*** revise MieScatterer - torch, typing, docstring, unit test -class MieScatterer(FieldScatterer): +class MieScatterer(Scatterer): """Base implementation of a Mie particle. 
New Mie-theory scatterers can be implemented by extending this class, and @@ -827,6 +837,8 @@ class MieScatterer(FieldScatterer): """ + role = "field" + __conversion_table__ = ConversionTable( radius=(u.meter, u.meter), polarization_angle=(u.radian, u.radian), @@ -1304,6 +1316,8 @@ class MieSphere(MieScatterer): """ + role = "field" + def __init__( self, radius: float = 1e-6, @@ -1406,6 +1420,8 @@ class MieStratifiedSphere(MieScatterer): """ + role = "field" + def __init__( self, radius: ArrayLike[float] = [1e-6], @@ -1444,17 +1460,12 @@ def inner( @dataclass -class ScatteredBase: +class ScatteredObject: """Base class for scatterers (volumes and fields).""" array: ArrayLike - # position: np.ndarray - # z: float = 0.0 properties: dict[str, Any] = field(default_factory=dict) - - # def __post_init__(self): - # self.position = np.array(self.position, dtype=float).reshape(-1)[:2] - # self.z = float(np.atleast_1d(self.z).squeeze()) + role: Literal["volume", "field"] = "volume" @property def ndim(self) -> int: @@ -1486,20 +1497,20 @@ def get_property(self, key: str, default: Any = None) -> Any: return getattr(self, key, self.properties.get(key, default)) -@dataclass -class ScatteredVolume(ScatteredBase): - """Volumetric object: intensity sources or refractive index contrasts.""" +# @dataclass +# class ScatteredVolume(ScatteredBase): +# """Volumetric object: intensity sources or refractive index contrasts.""" - # refractive_index: float | None = None - # intensity: float | None = None - # value: float | None = None - # position_sampler: Optional[Callable[[], np.ndarray]] = None - pass +# # refractive_index: float | None = None +# # intensity: float | None = None +# # value: float | None = None +# # position_sampler: Optional[Callable[[], np.ndarray]] = None +# pass -@dataclass -class ScatteredField(ScatteredBase): - """Complex wavefield (already propagated or emitted).""" +# @dataclass +# class ScatteredField(ScatteredBase): +# """Complex wavefield (already propagated or emitted).""" - # wavelength: float = 500e-9 - # position_sampler: Optional[Callable[[], np.ndarray]] = None - pass \ No newline at end of file +# # wavelength: float = 500e-9 +# # position_sampler: Optional[Callable[[], np.ndarray]] = None +# pass \ No newline at end of file From ac248c7b1382f9fa04ef6c65d812feecd5b6b6ff Mon Sep 17 00:00:00 2001 From: Carlo Date: Fri, 9 Jan 2026 14:18:15 +0100 Subject: [PATCH 19/24] u --- deeptrack/scatterers.py | 29 ++++++++++------------------- 1 file changed, 10 insertions(+), 19 deletions(-) diff --git a/deeptrack/scatterers.py b/deeptrack/scatterers.py index 9e45664f..b942888c 100644 --- a/deeptrack/scatterers.py +++ b/deeptrack/scatterers.py @@ -1481,6 +1481,16 @@ def shape(self) -> int: def pos3d(self) -> np.ndarray: return np.array([*self.position, self.z], dtype=float) + @property + def position(self) -> np.ndarray: + pos = self.properties.get("position", None) + if pos is None: + return None + pos = np.asarray(pos, dtype=float) + if pos.ndim == 2 and pos.shape[0] == 1: + pos = pos[0] + return pos + def as_array(self) -> ArrayLike: """Return the underlying array. 
@@ -1495,22 +1505,3 @@ def get_property(self, key: str, default: Any = None) -> Any:
 
     def get_property(self, key: str, default: Any = None) -> Any:
         return getattr(self, key, self.properties.get(key, default))
-
-
-# @dataclass
-# class ScatteredVolume(ScatteredBase):
-#     """Volumetric object: intensity sources or refractive index contrasts."""
-
-#     # refractive_index: float | None = None
-#     # intensity: float | None = None
-#     # value: float | None = None
-#     # position_sampler: Optional[Callable[[], np.ndarray]] = None
-#     pass
-
-# @dataclass
-# class ScatteredField(ScatteredBase):
-#     """Complex wavefield (already propagated or emitted)."""
-
-#     # wavelength: float = 500e-9
-#     # position_sampler: Optional[Callable[[], np.ndarray]] = None
-#     pass
\ No newline at end of file
From 10e15523523a0ef6c65d812feecd5b6b6ff Mon Sep 17 00:00:00 2001
From: Carlo
Date: Thu, 15 Jan 2026 01:51:27 +0100
Subject: [PATCH 20/24] u

---
 deeptrack/features.py   |  97 ++-----------
 deeptrack/math.py       | 291 ++++++++++++++++++++++++++++++++++++++-
 deeptrack/optics.py     | 293 +++++++++++++++++++++++-----------------
 deeptrack/scatterers.py | 116 ++++++++++------
 4 files changed, 548 insertions(+), 249 deletions(-)

diff --git a/deeptrack/features.py b/deeptrack/features.py
index be82b558..e3092698 100644
--- a/deeptrack/features.py
+++ b/deeptrack/features.py
@@ -7543,11 +7543,11 @@ def _process_and_get(
     # if isinstance(images, list) and len(images) != 1:
     list_of_labels = super()._process_and_get(images, **kwargs)
 
-        from deeptrack.scatterers import ScatteredObject
+        from deeptrack.scatterers import ScatteredVolume
 
         for idx, (label, image) in enumerate(zip(list_of_labels, images)):
             list_of_labels[idx] = \
-                ScatteredObject(array=label, properties=image.properties.copy(), role=image.role)
+                ScatteredVolume(array=label, properties=image.properties.copy())
 
     # Create an empty output image.
@@ -8080,15 +8080,18 @@ def get(
     # Create a context for upscaling and perform computation.
     ctx = create_context(None, None, None, *factor)
+
+    print('before:', image)
     with units.context(ctx):
         image = self.feature(image)
 
+    print('after:', image)
+    # Downscale the result to the original resolution.
+    import skimage.measure
 
-    # # Downscale the result to the original resolution.
-    # import skimage.measure
-
-    # image = skimage.measure.block_reduce(
-    #     image, (factor[0], factor[1]) + (1,) * (image.ndim - 2), np.mean
-    # )
+    image = skimage.measure.block_reduce(
+        image, (factor[0], factor[1]) + (1,) * (image.ndim - 2), np.mean
+    )
 
     return image
@@ -8404,10 +8404,11 @@ def _check_non_overlapping(
     - If bounding cubes overlap, voxel-level checks are performed.
""" - from deeptrack.scatterers import ScatteredObject + from deeptrack.scatterers import ScatteredVolume from deeptrack.augmentations import CropTight, Pad # these are not compatibles with torch backend from deeptrack.optics import _get_position + from deeptrack.math import isotropic_erosion, isotropic_dilation min_distance = self.min_distance() crop = CropTight() @@ -8427,10 +8431,9 @@ def _check_non_overlapping( else: new_arr = new_arr.astype(arr.dtype) - new_volume = ScatteredObject( + new_volume = ScatteredVolume( array=new_arr, properties=volume.properties.copy(), - role=volume.role, ) new_volumes.append(new_volume) @@ -9601,74 +9604,4 @@ def get( if len(res) == 1: res = res[0] - return res - -### Move to math? -def isotropic_dilation( - mask, - radius: float, - *, - backend: str, - device=None, - dtype=None, -): - if radius <= 0: - return mask - - if backend == "numpy": - from skimage.morphology import isotropic_dilation - return isotropic_dilation(mask, radius) - - # torch backend - import torch - - r = int(np.ceil(radius)) - kernel = torch.ones( - (1, 1, 2 * r + 1, 2 * r + 1, 2 * r + 1), - device=device or mask.device, - dtype=dtype or torch.float32, - ) - - x = mask.to(dtype=kernel.dtype)[None, None] - y = torch.nn.functional.conv3d( - x, - kernel, - padding=r, - ) - - return (y[0, 0] > 0) - - -def isotropic_erosion( - mask, - radius: float, - *, - backend: str, - device=None, - dtype=None, -): - if radius <= 0: - return mask - - if backend == "numpy": - from skimage.morphology import isotropic_erosion - return isotropic_erosion(mask, radius) - - import torch - - r = int(np.ceil(radius)) - kernel = torch.ones( - (1, 1, 2 * r + 1, 2 * r + 1, 2 * r + 1), - device=device or mask.device, - dtype=dtype or torch.float32, - ) - - x = mask.to(dtype=kernel.dtype)[None, None] - y = torch.nn.functional.conv3d( - x, - kernel, - padding=r, - ) - - required = kernel.numel() - return (y[0, 0] >= required) + return res \ No newline at end of file diff --git a/deeptrack/math.py b/deeptrack/math.py index 05cbf311..0af95c05 100644 --- a/deeptrack/math.py +++ b/deeptrack/math.py @@ -93,7 +93,7 @@ from __future__ import annotations -from typing import Any, Callable, TYPE_CHECKING +from typing import Any, Callable, Dict, Tuple, TYPE_CHECKING import array_api_compat as apc import numpy as np @@ -110,6 +110,7 @@ if TORCH_AVAILABLE: import torch + import torch.nn.functional as F if OPENCV_AVAILABLE: import cv2 @@ -129,11 +130,15 @@ "MaxPooling", "MinPooling", "MedianPooling", + "PoolV2", + "AveragePoolingV2", + "MaxPoolingV2", + "MinPoolingV2", + "MedianPoolingV2", "BlurCV2", "BilateralBlur", ] - if TYPE_CHECKING: import torch @@ -1663,6 +1668,218 @@ def __init__( super().__init__(np.median, ksize=ksize, **kwargs) +class PoolV2: + """ + DeepTrack v2 replacement for Pool. + + Generic, center-preserving block pooling with NumPy and Torch backends. + Public API matches v1: a single integer ksize. 
+ + Pool size semantics: + - 2D input -> (ksize, ksize, 1) + - 3D input -> (ksize, ksize, ksize) + """ + + _TORCH_REDUCERS_2D: Dict[Callable, Callable] = { + np.mean: lambda x, k, s: F.avg_pool2d(x, k, s), + np.sum: lambda x, k, s: F.avg_pool2d(x, k, s) * (k[0] * k[1]), + np.max: lambda x, k, s: F.max_pool2d(x, k, s), + np.min: lambda x, k, s: -F.max_pool2d(-x, k, s), + } + + _TORCH_REDUCERS_3D: Dict[Callable, Callable] = { + np.mean: lambda x, k, s: F.avg_pool3d(x, k, s), + np.sum: lambda x, k, s: F.avg_pool3d(x, k, s) * (k[0] * k[1] * k[2]), + np.max: lambda x, k, s: F.max_pool3d(x, k, s), + np.min: lambda x, k, s: -F.max_pool3d(-x, k, s), + } + + def __init__( + self, + pooling_function: Callable, + ksize: int = 2, + ): + if pooling_function not in ( + np.mean, np.sum, np.min, np.max, np.median + ): + raise ValueError( + "Unsupported pooling_function. " + "Use one of: np.mean, np.sum, np.min, np.max, np.median." + ) + + if not isinstance(ksize, int) or ksize < 1: + raise ValueError("ksize must be a positive integer.") + + self.pooling_function = pooling_function + self.ksize = int(ksize) + + def _get_pool_size(self, array) -> Tuple[int, int, int]: + """ + Determine pooling kernel size based on semantic dimensionality. + + - 2D images: (Nx, Ny) or (Nx, Ny, C) -> pool in x,y only + - 3D volumes: (Nx, Ny, Nz) or (Nx, Ny, Nz, C) -> pool in x,y,z + - Never pool over channels + """ + k = self.ksize + + # 2D image + if array.ndim == 2: + return k, k, 1 + + # 3D array: could be (x, y, z) or (x, y, c) + if array.ndim == 3: + # Heuristic: small last dim → channels + if array.shape[-1] <= 4: + return k, k, 1 + return k, k, k + + # 4D array: (x, y, z, c) + if array.ndim == 4: + return k, k, k + + raise ValueError( + f"Unsupported array shape {array.shape} for pooling." + ) + + def _crop_center(self, array): + px, py, pz = self._get_pool_size(array) + + # 2D (or effectively 2D) + if array.ndim < 3 or pz == 1: + H, W = array.shape[:2] + crop_h = (H // px) * px + crop_w = (W // py) * py + off_h = (H - crop_h) // 2 + off_w = (W - crop_w) // 2 + return array[ + off_h : off_h + crop_h, + off_w : off_w + crop_w, + ... + ] + + # 3D + Z, H, W = array.shape[:3] + crop_z = (Z // pz) * pz + crop_h = (H // px) * px + crop_w = (W // py) * py + off_z = (Z - crop_z) // 2 + off_h = (H - crop_h) // 2 + off_w = (W - crop_w) // 2 + return array[ + off_z : off_z + crop_z, + off_h : off_h + crop_h, + off_w : off_w + crop_w, + ... 
+ ] + + def _pool_numpy(self, array: np.ndarray) -> np.ndarray: + array = self._crop_center(array) + px, py, pz = self._get_pool_size(array) + + if array.ndim < 3 or pz == 1: + pool_shape = (px, py) + (1,) * (array.ndim - 2) + else: + pool_shape = (pz, px, py) + (1,) * (array.ndim - 3) + + return skimage.measure.block_reduce( + array, + block_size=pool_shape, + func=self.pooling_function, + ) + + def _pool_torch(self, array: torch.Tensor) -> torch.Tensor: + array = self._crop_center(array) + px, py, pz = self._get_pool_size(array) + + is_3d = array.ndim >= 3 and pz > 1 + + if not is_3d: + extra = array.shape[2:] + C = int(np.prod(extra)) if extra else 1 + x = array.reshape(1, C, array.shape[0], array.shape[1]) + kernel = (px, py) + stride = (px, py) + reducers = self._TORCH_REDUCERS_2D + else: + extra = array.shape[3:] + C = int(np.prod(extra)) if extra else 1 + x = array.reshape( + 1, C, array.shape[0], array.shape[1], array.shape[2] + ) + kernel = (pz, px, py) + stride = (pz, px, py) + reducers = self._TORCH_REDUCERS_3D + + # Median: explicit unfolding + if self.pooling_function is np.median: + if is_3d: + x_u = ( + x.unfold(2, pz, pz) + .unfold(3, px, px) + .unfold(4, py, py) + ) + x_u = x_u.contiguous().view( + 1, C, + x_u.shape[2], + x_u.shape[3], + x_u.shape[4], + -1, + ) + pooled = x_u.median(dim=-1).values + else: + x_u = x.unfold(2, px, px).unfold(3, py, py) + x_u = x_u.contiguous().view( + 1, C, + x_u.shape[2], + x_u.shape[3], + -1, + ) + pooled = x_u.median(dim=-1).values + else: + reducer = reducers[self.pooling_function] + pooled = reducer(x, kernel, stride) + + return pooled.reshape(pooled.shape[2:] + extra) + + def __call__(self, array): + if isinstance(array, np.ndarray): + return self._pool_numpy(array) + + if TORCH_AVAILABLE and isinstance(array, torch.Tensor): + return self._pool_torch(array) + + raise TypeError( + "PoolV2 only supports np.ndarray or torch.Tensor inputs." + ) + + +class AveragePoolingV2(PoolV2): + def __init__(self, ksize: int = 2): + super().__init__(np.mean, ksize) + + +class SumPoolingV2(PoolV2): + def __init__(self, ksize: int = 2): + super().__init__(np.sum, ksize) + + +class MinPoolingV2(PoolV2): + def __init__(self, ksize: int = 2): + super().__init__(np.min, ksize) + + +class MaxPoolingV2(PoolV2): + def __init__(self, ksize: int = 2): + super().__init__(np.max, ksize) + + +class MedianPoolingV2(PoolV2): + def __init__(self, ksize: int = 2): + super().__init__(np.median, ksize) + + + class Resize(Feature): """Resize an image to a specified size. 
@@ -2059,3 +2276,73 @@ def __init__( sigmaSpace=sigma_space, **kwargs, ) + + +def isotropic_dilation( + mask, + radius: float, + *, + backend: str, + device=None, + dtype=None, +): + if radius <= 0: + return mask + + if backend == "numpy": + from skimage.morphology import isotropic_dilation + return isotropic_dilation(mask, radius) + + # torch backend + import torch + + r = int(np.ceil(radius)) + kernel = torch.ones( + (1, 1, 2 * r + 1, 2 * r + 1, 2 * r + 1), + device=device or mask.device, + dtype=dtype or torch.float32, + ) + + x = mask.to(dtype=kernel.dtype)[None, None] + y = torch.nn.functional.conv3d( + x, + kernel, + padding=r, + ) + + return (y[0, 0] > 0) + + +def isotropic_erosion( + mask, + radius: float, + *, + backend: str, + device=None, + dtype=None, +): + if radius <= 0: + return mask + + if backend == "numpy": + from skimage.morphology import isotropic_erosion + return isotropic_erosion(mask, radius) + + import torch + + r = int(np.ceil(radius)) + kernel = torch.ones( + (1, 1, 2 * r + 1, 2 * r + 1, 2 * r + 1), + device=device or mask.device, + dtype=dtype or torch.float32, + ) + + x = mask.to(dtype=kernel.dtype)[None, None] + y = torch.nn.functional.conv3d( + x, + kernel, + padding=r, + ) + + required = kernel.numel() + return (y[0, 0] >= required) \ No newline at end of file diff --git a/deeptrack/optics.py b/deeptrack/optics.py index d0b5d66b..0788d7d8 100644 --- a/deeptrack/optics.py +++ b/deeptrack/optics.py @@ -151,7 +151,7 @@ def _pad_volume( get_active_scale, get_active_voxel_size, ) -from deeptrack.math import AveragePooling +from deeptrack.math import AveragePoolingV2, SumPoolingV2 from deeptrack.features import propagate_data_to_dependencies from deeptrack.features import DummyFeature, Feature, StructuralFeature from deeptrack.image import pad_image_to_fft @@ -162,7 +162,7 @@ def _pad_volume( from deeptrack import TORCH_AVAILABLE, image from deeptrack.backend import xp -from deeptrack.scatterers import ScatteredObject +from deeptrack.scatterers import ScatteredVolume, ScatteredField if TORCH_AVAILABLE: import torch @@ -175,9 +175,13 @@ def _pad_volume( class Microscope(StructuralFeature): """Simulates imaging of a sample using an optical system. - This class combines a feature-set that defines the sample to be imaged with - a feature-set defining the optical system, enabling the simulation of - optical imaging processes. + This class combines the sample to be imaged with the optical system, + enabling the simulation of optical imaging processes. + A Microscope: + - validates the semantic compatibility between scatterers and optics + - interprets volume-based scatterers into scalar fields when needed + - delegates numerical propagation to the objective (Optics) + - performs detector downscaling according to its physical semantics Parameters ---------- @@ -202,6 +206,12 @@ class Microscope(StructuralFeature): Simulates the imaging process using the defined optical system and returns the resulting image. + Notes + ----- + All volume scatterers imaged by a Microscope instance are assumed to + share the same contrast mechanism (e.g. refractive index or fluorescence). + Mixing contrast types is not supported. 
+ Examples -------- Simulating an image using a brightfield optical system: @@ -250,13 +260,40 @@ def __init__( self._sample = self.add_feature(sample) self._objective = self.add_feature(objective) - # self._sample.store_properties() + + def _validate_input(self, scattered): + if hasattr(self._objective, "validate_input"): + self._objective.validate_input(scattered) + + def _extract_contrast_volume(self, scattered): + if hasattr(self._objective, "extract_contrast_volume"): + return self._objective.extract_contrast_volume(scattered) + + # default: geometry-only + return scattered.array + + def _downscale_image(self, image, upscale): + if hasattr(self._objective, "downscale_image"): + return self._objective.downscale_image(image, upscale) + + if not np.any(np.array(upscale) != 1): + return image + + ux, uy = upscale[:2] + if ux != uy: + raise ValueError( + f"Energy-conserving detector integration requires ux == uy, " + f"got ux={ux}, uy={uy}." + ) + if isinstance(ux, float) and ux.is_integer(): + ux = int(ux) + return AveragePoolingV2(ux)(image) def get( self: Microscope, - image: np.ndarray | None, + image: np.ndarray | torch.Tensor | None = None, **kwargs: Any, - ) -> np.ndarray: + ) -> np.ndarray | torch.Tensor: """Generate an image of the sample using the defined optical system. This method processes the sample through the optical system to @@ -264,14 +301,14 @@ def get( Parameters ---------- - image: np.ndarray | None + image: np.ndarray | torch.Tensor | None The input image to be processed. If None, a new image is created. **kwargs: Any Additional parameters for the imaging process. Returns ------- - image: np.ndarray + image: np.ndarray | torch.Tensor The processed image after applying the optical system. Examples @@ -291,18 +328,14 @@ def get( # Grab properties from the objective to pass to the sample additional_sample_kwargs = self._objective.properties() - contrast_type = getattr(self._objective, "contrast_type", None) - if contrast_type is None: - raise RuntimeError( - f"{self._objective.__class__.__name__} must define `contrast_type` " - "(e.g. 'intensity' or 'refractive_index')." - ) - additional_sample_kwargs["contrast_type"] = contrast_type + _upscale_given_by_optics = additional_sample_kwargs["upscale"] + if np.array(_upscale_given_by_optics).size == 1: + _upscale_given_by_optics = (_upscale_given_by_optics,) * 3 with u.context( create_context( - *additional_sample_kwargs["voxel_size"]#, *_upscale_given_by_optics + *additional_sample_kwargs["voxel_size"], *_upscale_given_by_optics ) ): @@ -338,18 +371,22 @@ def get( if not isinstance(list_of_scatterers, list): list_of_scatterers = [list_of_scatterers] + # Semantic validation (per scatterer) + for scattered in list_of_scatterers: + self._validate_input(scattered) + # All scatterers that are defined as volumes. volume_samples = [ scatterer for scatterer in list_of_scatterers - if scatterer.role == "volume" + if isinstance(scatterer, ScatteredVolume) ] # All scatterers that are defined as fields. field_samples = [ scatterer for scatterer in list_of_scatterers - if scatterer.role == "field" + if isinstance(scatterer, ScatteredField) ] # Merge all volumes into a single volume. @@ -358,24 +395,33 @@ def get( **additional_sample_kwargs, ) + # Interpret the merged volume semantically + sample_volume = self._extract_contrast_volume( + ScatteredVolume( + array=sample_volume, + properties=volume_samples[0].properties, + ) + ) + # Let the objective know about the limits of the volume and all the fields. 
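The three hooks dispatched above (`validate_input`, `extract_contrast_volume`, `downscale_image`) define the modality protocol; the call just below hands the limits and fields to the objective. Two caveats worth noting: the base-class stub is spelled `validate_scattered` while the lookup uses `validate_input`, and the base stubs return `None`, so a modality must override all three. A hypothetical custom modality, sketched against the hook names in this patch:

```python
from deeptrack.optics import Optics
from deeptrack.scatterers import ScatteredField
from deeptrack.math import AveragePoolingV2

class MyOptics(Optics):  # hypothetical example, not part of this series
    def validate_input(self, scattered):
        # Reject coherent fields, as the Fluorescence hook below does.
        if isinstance(scattered, ScatteredField):
            raise TypeError("MyOptics cannot image coherent fields.")

    def extract_contrast_volume(self, scattered):
        # Interpret the geometry with a generic 'value' contrast.
        return scattered.array * scattered.get_property("value", 1.0)

    def downscale_image(self, image, upscale):
        ux = int(upscale[0])
        return image if ux == 1 else AveragePoolingV2(ux)(image)
```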
propagate_data_to_dependencies( self._objective, limits=limits, - fields=field_samples, + fields=field_samples, # should We add upscale? ) imaged_sample = self._objective.resolve(sample_volume) - # Handling upscale from dt.Upscale() here to eliminate Image - # wrapping issues. - if np.any(np.array(upscale) != 1): - ux, uy = upscale[:2] - if contrast_type == "intensity": - print("Using sum pooling for intensity downscaling.") - imaged_sample = SumPoolingCM((ux, uy, 1))(imaged_sample) - else: - imaged_sample = AveragePoolingCM((ux, uy, 1))(imaged_sample) + imaged_sample = self._downscale_image(imaged_sample, upscale) + # # Handling upscale from dt.Upscale() here to eliminate Image + # # wrapping issues. + # if np.any(np.array(upscale) != 1): + # ux, uy = upscale[:2] + # if contrast_type == "intensity": + # print("Using sum pooling for intensity downscaling.") + # imaged_sample = SumPoolingCM((ux, uy, 1))(imaged_sample) + # else: + # imaged_sample = AveragePoolingCM((ux, uy, 1))(imaged_sample) return imaged_sample @@ -561,6 +607,15 @@ def __init__( """ + def validate_scattered(self, scattered): + pass + + def extract_contrast_volume(self, scattered): + pass + + def downscale_image(self, image, upscale): + pass + def get_voxel_size( resolution: float | ArrayLike[float], magnification: float, @@ -665,6 +720,7 @@ def _process_properties( wavelength = propertydict["wavelength"] voxel_size = get_active_voxel_size() radius = NA / wavelength * np.array(voxel_size) + print('Pupil radius (in pixels):', radius) if np.any(radius[:2] > 0.5): required_upscale = np.max(np.ceil(radius[:2] * 2)) @@ -1014,7 +1070,66 @@ class Fluorescence(Optics): 1.4 """ - contrast_type = "intensity" + + + def validate_input(self, scattered): + """Semantic validation for fluorescence microscopy.""" + + # Fluorescence cannot operate on coherent fields + if isinstance(scattered, ScatteredField): + raise TypeError( + "Fluorescence microscope cannot operate on ScatteredField." + ) + + # Fluorescence must not use refractive index + if isinstance(scattered, ScatteredVolume): + if scattered.get_property("refractive_index", None) is not None: + raise ValueError( + "Fluorescence does not use refractive index. " + "Found 'refractive_index' in scatterer properties." + ) + + + def extract_contrast_volume(self, scattered: ScatteredVolume) -> np.ndarray: + """Contrast extraction (semantic interpretation)""" + intensity = scattered.get_property("intensity", None) + + if intensity is None: + intensity = scattered.get_property("value", None) + if intensity is None: + raise ValueError( + "Fluorescence requires 'intensity' or 'value'." + ) + + warnings.warn( + "Using 'value' as fluorescence intensity is ambiguous. " + "Please use 'intensity' explicitly to avoid ambiguity.", + UserWarning, + ) + + voxel_size = np.asarray(get_active_voxel_size(), dtype=float) + voxel_volume = float(np.prod(voxel_size)) + + return scattered.array * intensity * voxel_volume + + + def downscale_image(self, image: np.ndarray, upscale): + """Detector downscaling (energy conserving)""" + if not np.any(np.array(upscale) != 1): + return image + + ux, uy = upscale[:2] + if ux != uy: + raise ValueError( + f"Energy-conserving detector integration requires ux == uy, " + f"got ux={ux}, uy={uy}." 
+ ) + if isinstance(ux, float) and ux.is_integer(): + ux = int(ux) + + # Energy-conserving detector integration + return SumPoolingV2(ux)(image) + def get( self: Fluorescence, @@ -1240,7 +1355,6 @@ class Brightfield(Optics): """ - contrast_type = "refractive_index" __conversion_table__ = ConversionTable( working_distance=(u.meter, u.meter), @@ -1960,12 +2074,12 @@ def _create_volume( Spatial limits of the volume. """ - contrast_type = kwargs.get("contrast_type", None) - if contrast_type is None: - raise RuntimeError( - "_create_volume requires a contrast_type " - "(e.g. 'intensity' or 'refractive_index')" - ) + # contrast_type = kwargs.get("contrast_type", None) + # if contrast_type is None: + # raise RuntimeError( + # "_create_volume requires a contrast_type " + # "(e.g. 'intensity' or 'refractive_index')" + # ) if not isinstance(list_of_scatterers, list): list_of_scatterers = [list_of_scatterers] @@ -1995,23 +2109,23 @@ def _create_volume( for scatterer in list_of_scatterers: position = _get_position(scatterer, mode="corner", return_z=True) - if contrast_type == "intensity": - value = scatterer.get_property("intensity", None) - if value is None: - raise ValueError("Scatterer has no intensity.") - scatterer_value = value + # if contrast_type == "intensity": + # value = scatterer.get_property("intensity", None) + # if value is None: + # raise ValueError("Scatterer has no intensity.") + # scatterer_value = value - elif contrast_type == "refractive_index": - ri = scatterer.get_property("refractive_index", None) - if ri is None: - raise ValueError("Scatterer has no refractive_index.") - scatterer_value = ri - refractive_index_medium + # elif contrast_type == "refractive_index": + # ri = scatterer.get_property("refractive_index", None) + # if ri is None: + # raise ValueError("Scatterer has no refractive_index.") + # scatterer_value = ri - refractive_index_medium - else: - raise RuntimeError(f"Unknown contrast_type: {contrast_type}") + # else: + # raise RuntimeError(f"Unknown contrast_type: {contrast_type}") - # Scale the array accordingly - scatterer.array = scatterer.array * scatterer_value + # # Scale the array accordingly + # scatterer.array = scatterer.array * scatterer_value if limits is None: limits = np.zeros((3, 2), dtype=np.int32) @@ -2033,8 +2147,8 @@ def _create_volume( "constant", constant_values=0, ) - padded_scatterer = ScatteredObject( - array=padded_scatterer_arr, properties=scatterer.properties.copy(), role=scatterer.role, + padded_scatterer = ScatteredVolume( + array=padded_scatterer_arr, properties=scatterer.properties.copy(), ) position = _get_position(padded_scatterer, mode="corner", return_z=True) shape = np.array(padded_scatterer.array.shape) @@ -2088,7 +2202,7 @@ def _create_volume( within_volume_position = position - limits[:, 0] - # NOTE: Maybe shouldn't be additive. + # NOTE: Maybe shouldn't be ONLY additive. 
# give options: sum default, but also mean, max, min, or volume[ int(within_volume_position[0]) : @@ -2100,71 +2214,4 @@ def _create_volume( int(within_volume_position[2]) : int(within_volume_position[2] + shape[2]), ] += splined_scatterer - return volume, limits - -# this should be moved to math -class _CenteredPoolingBase: - def __init__(self, pool_size: tuple[int, int, int]): - px, py, pz = pool_size - if pz != 1: - raise ValueError("Only pz=1 supported.") - self.px = int(px) - self.py = int(py) - - def _crop_center(self, array): - H, W = array.shape[:2] - px, py = self.px, self.py - - crop_h = (H // px) * px - crop_w = (W // py) * py - - off_h = (H - crop_h) // 2 - off_w = (W - crop_w) // 2 - - return array[off_h:off_h+crop_h, off_w:off_w+crop_w, ...] - - def _pool_numpy(self, array, func): - import skimage.measure - array = self._crop_center(array) - pool_shape = (self.px, self.py) + (1,) * (array.ndim - 2) - return skimage.measure.block_reduce(array, pool_shape, func) - - def _pool_torch(self, array, sum_pool=False): - px, py = self.px, self.py - array = self._crop_center(array) - - extra = array.shape[2:] - C = int(np.prod(extra)) if extra else 1 - x = array.reshape(1, C, array.shape[0], array.shape[1]) - - pooled = torch.nn.functional.avg_pool2d( - x, kernel_size=(px, py), stride=(px, py) - ) - if sum_pool: - pooled = pooled * (px * py) - - return pooled.reshape( - (pooled.shape[2], pooled.shape[3]) + extra - ) - -class AveragePoolingCM(_CenteredPoolingBase): - """Center-preserving average pooling (intensive quantities).""" - - def __call__(self, array): - if isinstance(array, np.ndarray): - return self._pool_numpy(array, np.mean) - elif TORCH_AVAILABLE and isinstance(array, torch.Tensor): - return self._pool_torch(array, sum_pool=False) - else: - raise TypeError("Unsupported array type.") - -class SumPoolingCM(_CenteredPoolingBase): - """Center-preserving sum pooling (extensive quantities).""" - - def __call__(self, array): - if isinstance(array, np.ndarray): - return self._pool_numpy(array, np.sum) - elif TORCH_AVAILABLE and isinstance(array, torch.Tensor): - return self._pool_torch(array, sum_pool=True) - else: - raise TypeError("Unsupported array type.") \ No newline at end of file + return volume, limits \ No newline at end of file diff --git a/deeptrack/scatterers.py b/deeptrack/scatterers.py index b942888c..2ba202e0 100644 --- a/deeptrack/scatterers.py +++ b/deeptrack/scatterers.py @@ -175,12 +175,14 @@ get_active_voxel_size, ) from deeptrack.backend import mie +from deeptrack.math import AveragePoolingV2 from deeptrack.features import Feature, MERGE_STRATEGY_APPEND from deeptrack.image import pad_image_to_fft from deeptrack.types import ArrayLike from deeptrack import units_registry as u + __all__ = [ "Scatterer", "PointParticle", @@ -239,7 +241,7 @@ class Scatterer(Feature): """ - __list_merge_strategy__ = MERGE_STRATEGY_APPEND + __list_merge_strategy__ = MERGE_STRATEGY_APPEND ### Not clear why needed __distributed__ = False __conversion_table__ = ConversionTable( position=(u.pixel, u.pixel), @@ -279,6 +281,21 @@ def __init__( **kwargs, ) + def _antialias_volume(self, volume, factor: int): + """Geometry-only supersampling anti-aliasing. + + Assumes `volume` was generated on a grid oversampled by `factor` + and downsamples it back by average pooling. 
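[Editor's note, not part of the patch series: the surrounding hunks pair two different poolings. `Fluorescence.downscale_image` integrates the detector with sum pooling (extensive, photon-conserving), while `_antialias_volume`, whose body follows in the next hunk lines, averages a supersampled mask back down (intensive, occupancy-conserving). A plain-NumPy sketch of the distinction, standing in for `SumPoolingV2`/`AveragePoolingV2`, whose exact call signatures are assumed here:]

```python
import numpy as np

rng = np.random.default_rng(0)
fine = rng.poisson(5.0, size=(8, 8)).astype(float)  # photon counts on a 2x grid
f = 2

blocks = fine.reshape(4, f, 4, f)
summed = blocks.sum(axis=(1, 3))     # extensive: total photon count conserved
averaged = blocks.mean(axis=(1, 3))  # intensive: mean level / occupancy conserved

print(fine.sum() == summed.sum())                 # True: energy-conserving binning
print(np.isclose(fine.mean(), averaged.mean()))   # True: contrast level preserved
```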
+ """ + if factor == 1: + return volume + + # average pooling conserves fractional occupancy + return AveragePoolingV2( + factor + )(volume) + + def _process_properties( self, properties: dict @@ -308,16 +325,31 @@ def _process_and_get( + "Optics.upscale != 1." ) - voxel_size = get_active_voxel_size() - # Calls parent _process_and_get. + voxel_size = np.asarray(get_active_voxel_size(), float) + + apply_supersampling = upsample > 1 and isinstance(self, VolumeScatterer) + + if upsample > 1 and not apply_supersampling: + warnings.warn( + "Geometry supersampling (upsample) is ignored for " + "FieldScatterers.", + UserWarning, + ) + + if apply_supersampling: + voxel_size /= float(upsample) + new_image = super(Scatterer, self)._process_and_get( *args, voxel_size=voxel_size, upsample=upsample, **kwargs, - ) - new_image = new_image[0] + )[0] + + if apply_supersampling: + new_image = self._antialias_volume(new_image, factor=upsample) + if new_image.size == 0: warnings.warn( @@ -338,33 +370,31 @@ def _process_and_get( # props = kwargs.copy() return [self._wrap_output(new_image, kwargs)] - def _wrap_output(self, array, props) -> ScatteredObject: - # """Must be overridden in subclasses to wrap output correctly.""" - # raise NotImplementedError - return ScatteredObject( + def _wrap_output(self, array, props): + raise NotImplementedError( + f"{self.__class__.__name__} must implement _wrap_output()" + ) + + +class VolumeScatterer(Scatterer): + """Abstract scatterer producing ScatteredVolume outputs.""" + def _wrap_output(self, array, props) -> ScatteredVolume: + return ScatteredVolume( array=array, properties=props.copy(), - role = self.role, ) -# class VolumeScatterer(Scatterer): -# """Abstract scatterer producing ScatteredVolume outputs.""" -# def _wrap_output(self, array, props) -> ScatteredVolume: -# return ScatteredVolume( -# array=array, -# properties=props.copy(), -# ) -# class FieldScatterer(Scatterer): -# def _wrap_output(self, array, props) -> ScatteredField: -# return ScatteredField( -# array=array, -# properties=props.copy(), -# ) +class FieldScatterer(Scatterer): + def _wrap_output(self, array, props) -> ScatteredField: + return ScatteredField( + array=array, + properties=props.copy(), + ) #TODO ***??*** revise PointParticle - torch, typing, docstring, unit test -class PointParticle(Scatterer): +class PointParticle(VolumeScatterer): """Generate a diffraction-limited point particle. A point particle is approximated by the size of a single pixel or voxel. @@ -387,7 +417,6 @@ class PointParticle(Scatterer): for `Brightfield` and `intensity` for `Fluorescence`). """ - role = "volume" def __init__( self: PointParticle, @@ -396,7 +425,7 @@ def __init__( """ """ - + kwargs.pop("upsample", None) super().__init__(upsample=1, upsample_axes=(), **kwargs) def get( @@ -412,7 +441,7 @@ def get( #TODO ***??*** revise Ellipse - torch, typing, docstring, unit test -class Ellipse(Scatterer): +class Ellipse(VolumeScatterer): """Generates an elliptical disk scatterer Parameters @@ -447,7 +476,7 @@ class Ellipse(Scatterer): before rotation. """ - role = "volume" + __conversion_table__ = ConversionTable( radius=(u.meter, u.meter), @@ -527,7 +556,7 @@ def get( #TODO ***??*** revise Sphere - torch, typing, docstring, unit test -class Sphere(Scatterer): +class Sphere(VolumeScatterer): """Generates a spherical scatterer Parameters @@ -553,7 +582,6 @@ class Sphere(Scatterer): Upsamples the calculations of the pixel occupancy fraction. 
""" - role = "volume" __conversion_table__ = ConversionTable( radius=(u.meter, u.meter), @@ -593,7 +621,7 @@ def get( #TODO ***??*** revise Ellipsoid - torch, typing, docstring, unit test -class Ellipsoid(Scatterer): +class Ellipsoid(VolumeScatterer): """Generates an ellipsoidal scatterer Parameters @@ -629,8 +657,6 @@ class Ellipsoid(Scatterer): """ - role = "volume" - __conversion_table__ = ConversionTable( radius=(u.meter, u.meter), rotation=(u.radian, u.radian), @@ -752,7 +778,7 @@ def get( #TODO ***??*** revise MieScatterer - torch, typing, docstring, unit test -class MieScatterer(Scatterer): +class MieScatterer(FieldScatterer): """Base implementation of a Mie particle. New Mie-theory scatterers can be implemented by extending this class, and @@ -837,7 +863,6 @@ class MieScatterer(Scatterer): """ - role = "field" __conversion_table__ = ConversionTable( radius=(u.meter, u.meter), @@ -878,7 +903,6 @@ def __init__( "Please use input_polarization instead" ) input_polarization = polarization_angle - # kwargs.pop("is_field", None) # remove kwargs.pop("crop_empty", None) super().__init__( @@ -1106,7 +1130,6 @@ def get( # Wave vector. k = 2 * np.pi / wavelength * refractive_index_medium - # Position of objective relative particle. relative_position = np.array( ( @@ -1316,7 +1339,6 @@ class MieSphere(MieScatterer): """ - role = "field" def __init__( self, @@ -1420,7 +1442,6 @@ class MieStratifiedSphere(MieScatterer): """ - role = "field" def __init__( self, @@ -1460,12 +1481,11 @@ def inner( @dataclass -class ScatteredObject: +class ScatteredBase: """Base class for scatterers (volumes and fields).""" - array: ArrayLike + array: np.ndarray | torch.Tensor properties: dict[str, Any] = field(default_factory=dict) - role: Literal["volume", "field"] = "volume" @property def ndim(self) -> int: @@ -1505,3 +1525,15 @@ def as_array(self) -> ArrayLike: def get_property(self, key: str, default: Any = None) -> Any: return getattr(self, key, self.properties.get(key, default)) + + +@dataclass +class ScatteredVolume(ScatteredBase): + """Voxelized volume produced by a VolumeScatterer.""" + pass + + +@dataclass +class ScatteredField(ScatteredBase): + """Complex field produced by a FieldScatterer.""" + pass \ No newline at end of file From 58a1b8a312929128488efe96637cbdc9b8104784 Mon Sep 17 00:00:00 2001 From: Carlo Date: Thu, 15 Jan 2026 16:45:00 +0100 Subject: [PATCH 21/24] u --- deeptrack/features.py | 1320 +++++------------------------------------ deeptrack/optics.py | 1141 ++++++++++++++++++++++++++++++++++- 2 files changed, 1277 insertions(+), 1184 deletions(-) diff --git a/deeptrack/features.py b/deeptrack/features.py index e3092698..901664a9 100644 --- a/deeptrack/features.py +++ b/deeptrack/features.py @@ -218,11 +218,8 @@ "OneOf", "OneOfDict", "LoadImage", - "SampleToMasks", "AsType", "ChannelFirst2d", - "Upscale", - "NonOverlapping", "Store", "Squeeze", "Unsqueeze", @@ -7359,302 +7356,6 @@ def get( return image -class SampleToMasks(Feature): - """Create a mask from a list of images. - - This feature applies a transformation function to each input image and - merges the resulting masks into a single multi-layer image. Each input - image must have a `position` property that determines its placement within - the final mask. When used with scatterers, the `voxel_size` property must - be provided for correct object sizing. - - Parameters - ---------- - transformation_function: Callable[[Image], Image] - A function that transforms each input image into a mask with - `number_of_masks` layers. 
- number_of_masks: PropertyLike[int], optional - The number of mask layers to generate. Default is 1. - output_region: PropertyLike[tuple[int, int, int, int]], optional - The size and position of the output mask, typically aligned with - `optics.output_region`. - merge_method: PropertyLike[str | Callable | list[str | Callable]], optional - Method for merging individual masks into the final image. Can be: - - "add" (default): Sum the masks. - - "overwrite": Later masks overwrite earlier masks. - - "or": Combine masks using a logical OR operation. - - "mul": Multiply masks. - - Function: Custom function taking two images and merging them. - - **kwargs: dict[str, Any] - Additional keyword arguments passed to the parent `Feature` class. - - Methods - ------- - `get(image, transformation_function, **kwargs) -> Image` - Applies the transformation function to the input image. - `_process_and_get(images, **kwargs) -> Image | np.ndarray` - Processes a list of images and generates a multi-layer mask. - - Returns - ------- - np.ndarray - The final mask image with the specified number of layers. - - Raises - ------ - ValueError - If `merge_method` is invalid. - - Examples - ------- - >>> import deeptrack as dt - - Define number of particles: - - >>> n_particles = 12 - - Define optics and particles: - - >>> import numpy as np - >>> - >>> optics = dt.Fluorescence(output_region=(0, 0, 64, 64)) - >>> particle = dt.PointParticle( - >>> position=lambda: np.random.uniform(5, 55, size=2), - >>> ) - >>> particles = particle ^ n_particles - - Define pipelines: - - >>> sim_im_pip = optics(particles) - >>> sim_mask_pip = particles >> dt.SampleToMasks( - ... lambda: lambda particles: particles > 0, - ... output_region=optics.output_region, - ... merge_method="or", - ... ) - >>> pipeline = sim_im_pip & sim_mask_pip - >>> pipeline.store_properties() - - Generate image and mask: - - >>> image, mask = pipeline.update()() - - Get particle positions: - - >>> positions = np.array(image.get_property("position", get_one=False)) - - Visualize results: - - >>> import matplotlib.pyplot as plt - >>> - >>> plt.subplot(1, 2, 1) - >>> plt.imshow(image, cmap="gray") - >>> plt.title("Original Image") - >>> plt.subplot(1, 2, 2) - >>> plt.imshow(mask, cmap="gray") - >>> plt.scatter(positions[:,1], positions[:,0], c="y", marker="x", s = 50) - >>> plt.title("Mask") - >>> plt.show() - - """ - - def __init__( - self: Feature, - transformation_function: Callable[[np.ndarray], np.ndarray, torch.Tensor], - number_of_masks: PropertyLike[int] = 1, - output_region: PropertyLike[tuple[int, int, int, int]] = None, - merge_method: PropertyLike[str | Callable | list[str | Callable]] = "add", - **kwargs: Any, - ): - """Initialize the SampleToMasks feature. - - Parameters - ---------- - transformation_function: Callable[[Image], Image] - Function to transform input images into masks. - number_of_masks: PropertyLike[int], optional - Number of mask layers. Default is 1. - output_region: PropertyLike[tuple[int, int, int, int]], optional - Output region of the mask. Default is None. - merge_method: PropertyLike[str | Callable | list[str | Callable]], optional - Method to merge masks. Defaults to "add". - **kwargs: dict[str, Any] - Additional keyword arguments passed to the parent class. 
- - """ - - super().__init__( - transformation_function=transformation_function, - number_of_masks=number_of_masks, - output_region=output_region, - merge_method=merge_method, - **kwargs, - ) - - def get( - self: Feature, - image: np.ndarray, - transformation_function: Callable[list[np.ndarray] | np.ndarray | torch.Tensor], - **kwargs: Any, - ) -> np.ndarray: - """Apply the transformation function to a single image. - - Parameters - ---------- - image: np.ndarray - The input image. - transformation_function: Callable[[np.ndarray], np.ndarray] - Function to transform the image. - **kwargs: dict[str, Any] - Additional parameters. - - Returns - ------- - Image - The transformed image. - - """ - - return transformation_function(image.array) - - def _process_and_get( - self: Feature, - images: list[np.ndarray] | np.ndarray | list[torch.Tensor] | torch.Tensor, - **kwargs: Any, - ) -> np.ndarray: - """Process a list of images and generate a multi-layer mask. - - Parameters - ---------- - images: np.ndarray or list[np.ndarrray] or Image or list[Image] - List of input images or a single image. - **kwargs: dict[str, Any] - Additional parameters including `output_region`, `number_of_masks`, - and `merge_method`. - - Returns - ------- - Image or np.ndarray - The final mask image. - - """ - - # Handle list of images. - # if isinstance(images, list) and len(images) != 1: - list_of_labels = super()._process_and_get(images, **kwargs) - - from deeptrack.scatterers import ScatteredVolume - - for idx, (label, image) in enumerate(zip(list_of_labels, images)): - list_of_labels[idx] = \ - ScatteredVolume(array=label, properties=image.properties.copy()) - - # Create an empty output image. - output_region = kwargs["output_region"] - output = xp.zeros( - ( - output_region[2] - output_region[0], - output_region[3] - output_region[1], - kwargs["number_of_masks"], - ), - dtype=list_of_labels[0].array.dtype, - ) - - from deeptrack.optics import _get_position - - # Merge masks into the output. 
- for volume in list_of_labels: - label = volume.array - position = _get_position(volume) - - p0 = xp.round(position - xp.asarray(output_region[0:2])) - p0 = p0.astype(xp.int64) - - - if xp.any(p0 > xp.asarray(output.shape[:2])) or \ - xp.any(p0 + xp.asarray(label.shape[:2]) < 0): - continue - - crop_x = (-xp.minimum(p0[0], 0)).item() - crop_y = (-xp.minimum(p0[1], 0)).item() - - crop_x_end = int( - label.shape[0] - - np.max([p0[0] + label.shape[0] - output.shape[0], 0]) - ) - crop_y_end = int( - label.shape[1] - - np.max([p0[1] + label.shape[1] - output.shape[1], 0]) - ) - - labelarg = label[crop_x:crop_x_end, crop_y:crop_y_end, :] - - p0[0] = np.max([p0[0], 0]) - p0[1] = np.max([p0[1], 0]) - - p0 = p0.astype(int) - - output_slice = output[ - p0[0] : p0[0] + labelarg.shape[0], - p0[1] : p0[1] + labelarg.shape[1], - ] - - for label_index in range(kwargs["number_of_masks"]): - - if isinstance(kwargs["merge_method"], list): - merge = kwargs["merge_method"][label_index] - else: - merge = kwargs["merge_method"] - - if merge == "add": - output[ - p0[0] : p0[0] + labelarg.shape[0], - p0[1] : p0[1] + labelarg.shape[1], - label_index, - ] += labelarg[..., label_index] - - elif merge == "overwrite": - output_slice[ - labelarg[..., label_index] != 0, label_index - ] = labelarg[labelarg[..., label_index] != 0, \ - label_index] - output[ - p0[0] : p0[0] + labelarg.shape[0], - p0[1] : p0[1] + labelarg.shape[1], - label_index, - ] = output_slice[..., label_index] - - elif merge == "or": - output[ - p0[0] : p0[0] + labelarg.shape[0], - p0[1] : p0[1] + labelarg.shape[1], - label_index, - ] = xp.logical_or( - output_slice[..., label_index] != 0, - labelarg[..., label_index] != 0 - ) - - elif merge == "mul": - output[ - p0[0] : p0[0] + labelarg.shape[0], - p0[1] : p0[1] + labelarg.shape[1], - label_index, - ] *= labelarg[..., label_index] - - else: - # No match, assume function - output[ - p0[0] : p0[0] + labelarg.shape[0], - p0[1] : p0[1] + labelarg.shape[1], - label_index, - ] = merge( - output_slice[..., label_index], - labelarg[..., label_index], - ) - - return output - - class AsType(Feature): """Convert the data type of arrays. @@ -7920,876 +7621,181 @@ def get( return array -class Upscale(Feature): - """Simulate a pipeline at a higher resolution. - - This feature scales up the resolution of the input pipeline by a specified - factor, performs computations at the higher resolution, and then - downsamples the result back to the original size. This is useful for - simulating effects at a finer resolution while preserving compatibility - with lower-resolution pipelines. - - Internally, this feature redefines the scale of physical units (e.g., - `units.pixel`) to achieve the effect of upscaling. Therefore, it does not - resize the input image itself but affects only features that rely on - physical units. - - Parameters - ---------- - feature: Feature - The pipeline or feature to resolve at a higher resolution. - factor: int or tuple[int, int, int], optional - The factor by which to upscale the simulation. If a single integer is - provided, it is applied uniformly across all axes. If a tuple of three - integers is provided, each axis is scaled individually. Defaults to 1. - **kwargs: Any - Additional keyword arguments passed to the parent `Feature` class. - - Attributes - ---------- - __distributed__: bool - Always `False` for `Upscale`, indicating that this feature’s `.get()` - method processes the entire input at once even if it is a list, rather - than distributing calls for each item of the list. 
- - Methods - ------- - `get(image, factor, **kwargs) -> np.ndarray | torch.tensor` - Simulates the pipeline at a higher resolution and returns the result at - the original resolution. - - Notes - ----- - - This feature does not directly resize the image. Instead, it modifies the - unit conversions within the pipeline, making physical units smaller, - which results in more detail being simulated. - - The final output is downscaled back to the original resolution using - `block_reduce` from `skimage.measure`. - - The effect is only noticeable if features use physical units (e.g., - `units.pixel`, `units.meter`). Otherwise, the result will be identical. - - Examples - -------- - >>> import deeptrack as dt - - Define an optical pipeline and a spherical particle: - - >>> optics = dt.Fluorescence() - >>> particle = dt.Sphere() - >>> simple_pipeline = optics(particle) - - Create an upscaled pipeline with a factor of 4: - - >>> upscaled_pipeline = dt.Upscale(optics(particle), factor=4) - - Resolve the pipelines: - - >>> image = simple_pipeline() - >>> upscaled_image = upscaled_pipeline() - - Visualize the images: - - >>> import matplotlib.pyplot as plt - >>> - >>> plt.subplot(1, 2, 1) - >>> plt.imshow(image, cmap="gray") - >>> plt.title("Original Image") - >>> - >>> plt.subplot(1, 2, 2) - >>> plt.imshow(upscaled_image, cmap="gray") - >>> plt.title("Simulated at Higher Resolution") - >>> - >>> plt.show() - - Compare the shapes (both are the same due to downscaling): - - >>> print(image.shape) - (128, 128, 1) - >>> print(upscaled_image.shape) - (128, 128, 1) - - """ - - __distributed__: bool = False - - feature: Feature - - def __init__( - self: Feature, - feature: Feature, - factor: int | tuple[int, int, int] = 1, - **kwargs: Any, - ) -> None: - """Initialize the Upscale feature. - - Parameters - ---------- - feature: Feature - The pipeline or feature to resolve at a higher resolution. - factor: int or tuple[int, int, int], optional - The factor by which to upscale the simulation. If a single integer - is provided, it is applied uniformly across all axes. If a tuple of - three integers is provided, each axis is scaled individually. - Defaults to 1. - **kwargs: Any - Additional keyword arguments passed to the parent `Feature` class. - - """ - - super().__init__(factor=factor, **kwargs) - self.feature = self.add_feature(feature) - - def get( - self: Feature, - image: np.ndarray | torch.Tensor, - factor: int | tuple[int, int, int], - **kwargs: Any, - ) -> np.ndarray | torch.Tensor: - """Simulate the pipeline at a higher resolution and return result. - - Parameters - ---------- - image: np.ndarray or torch.Tensor - The input image to process. - factor: int or tuple[int, int, int] - The factor by which to upscale the simulation. If a single integer - is provided, it is applied uniformly across all axes. If a tuple of - three integers is provided, each axis is scaled individually. - **kwargs: Any - Additional keyword arguments passed to the feature. - - Returns - ------- - np.ndarray or torch.Tensor - The processed image at the original resolution. +# class Upscale(Feature): +# """Simulate a pipeline at a higher resolution. - Raises - ------ - ValueError - If the input `factor` is not a valid integer or tuple of integers. - - """ - - # Ensure factor is a tuple of three integers. - if np.size(factor) == 1: - factor = (factor, factor, 1) - elif len(factor) != 3: - raise ValueError( - "Factor must be an integer or a tuple of three integers." - ) - - # Create a context for upscaling and perform computation. 
- ctx = create_context(None, None, None, *factor) - - print('before:', image) - with units.context(ctx): - image = self.feature(image) - - print('after:', image) - # Downscale the result to the original resolution. - import skimage.measure - - image = skimage.measure.block_reduce( - image, (factor[0], factor[1]) + (1,) * (image.ndim - 2), np.mean - ) - - return image - - -class NonOverlapping(Feature): - """Ensure volumes are placed non-overlapping in a 3D space. - - This feature ensures that a list of 3D volumes are positioned such that - their non-zero voxels do not overlap. If volumes overlap, their positions - are resampled until they are non-overlapping. If the maximum number of - attempts is exceeded, the feature regenerates the list of volumes and - raises a warning if non-overlapping placement cannot be achieved. - - Note: `min_distance` refers to the distance between the edges of volumes, - not their centers. Due to the way volumes are calculated, slight rounding - errors may affect the final distance. +# This feature scales up the resolution of the input pipeline by a specified +# factor, performs computations at the higher resolution, and then +# downsamples the result back to the original size. This is useful for +# simulating effects at a finer resolution while preserving compatibility +# with lower-resolution pipelines. - This feature is incompatible with non-volumetric scatterers such as - `MieScatterers`. +# Internally, this feature redefines the scale of physical units (e.g., +# `units.pixel`) to achieve the effect of upscaling. Therefore, it does not +# resize the input image itself but affects only features that rely on +# physical units. + +# Parameters +# ---------- +# feature: Feature +# The pipeline or feature to resolve at a higher resolution. +# factor: int or tuple[int, int, int], optional +# The factor by which to upscale the simulation. If a single integer is +# provided, it is applied uniformly across all axes. If a tuple of three +# integers is provided, each axis is scaled individually. Defaults to 1. +# **kwargs: Any +# Additional keyword arguments passed to the parent `Feature` class. + +# Attributes +# ---------- +# __distributed__: bool +# Always `False` for `Upscale`, indicating that this feature’s `.get()` +# method processes the entire input at once even if it is a list, rather +# than distributing calls for each item of the list. + +# Methods +# ------- +# `get(image, factor, **kwargs) -> np.ndarray | torch.tensor` +# Simulates the pipeline at a higher resolution and returns the result at +# the original resolution. + +# Notes +# ----- +# - This feature does not directly resize the image. Instead, it modifies the +# unit conversions within the pipeline, making physical units smaller, +# which results in more detail being simulated. +# - The final output is downscaled back to the original resolution using +# `block_reduce` from `skimage.measure`. +# - The effect is only noticeable if features use physical units (e.g., +# `units.pixel`, `units.meter`). Otherwise, the result will be identical. + +# Examples +# -------- +# >>> import deeptrack as dt + +# Define an optical pipeline and a spherical particle: + +# >>> optics = dt.Fluorescence() +# >>> particle = dt.Sphere() +# >>> simple_pipeline = optics(particle) + +# Create an upscaled pipeline with a factor of 4: + +# >>> upscaled_pipeline = dt.Upscale(optics(particle), factor=4) - Parameters - ---------- - feature: Feature - The feature that generates the list of volumes to place - non-overlapping. 
- min_distance: float, optional - The minimum distance between volumes in pixels. It can be negative to - allow for partial overlap. Defaults to 1. - max_attempts: int, optional - The maximum number of attempts to place volumes without overlap. - Defaults to 5. - max_iters: int, optional - The maximum number of resamplings. If this number is exceeded, a new - list of volumes is generated. Defaults to 100. - - Attributes - ---------- - __distributed__: bool - Always `False` for `NonOverlapping`, indicating that this feature’s - `.get()` method processes the entire input at once even if it is a - list, rather than distributing calls for each item of the list.N - - Methods - ------- - `get(*_, min_distance, max_attempts, **kwargs) -> array` - Generate a list of non-overlapping 3D volumes. - `_check_non_overlapping(list_of_volumes) -> bool` - Check if all volumes in the list are non-overlapping. - `_check_bounding_cubes_non_overlapping(...) -> bool` - Check if two bounding cubes are non-overlapping. - `_get_overlapping_cube(...) -> list[int]` - Get the overlapping cube between two bounding cubes. - `_get_overlapping_volume(...) -> array` - Get the overlapping volume between a volume and a bounding cube. - `_check_volumes_non_overlapping(...) -> bool` - Check if two volumes are non-overlapping. - `_resample_volume_position(volume) -> Image` - Resample the position of a volume to avoid overlap. +# Resolve the pipelines: + +# >>> image = simple_pipeline() +# >>> upscaled_image = upscaled_pipeline() + +# Visualize the images: + +# >>> import matplotlib.pyplot as plt +# >>> +# >>> plt.subplot(1, 2, 1) +# >>> plt.imshow(image, cmap="gray") +# >>> plt.title("Original Image") +# >>> +# >>> plt.subplot(1, 2, 2) +# >>> plt.imshow(upscaled_image, cmap="gray") +# >>> plt.title("Simulated at Higher Resolution") +# >>> +# >>> plt.show() - Notes - ----- - - This feature performs bounding cube checks first to quickly reject - obvious overlaps before voxel-level checks. - - If the bounding cubes overlap, precise voxel-based checks are performed. - - Examples - --------- - >>> import deeptrack as dt - - Define an ellipse scatterer with randomly positioned objects: - - >>> import numpy as np - >>> - >>> scatterer = dt.Ellipse( - >>> radius= 13 * dt.units.pixels, - >>> position=lambda: np.random.uniform(5, 115, size=2)* dt.units.pixels, - >>> ) +# Compare the shapes (both are the same due to downscaling): - Create multiple scatterers: - - >>> scatterers = (scatterer ^ 8) - - Define the optics and create the image with possible overlap: - - >>> optics = dt.Fluorescence() - >>> im_with_overlap = optics(scatterers) - >>> im_with_overlap.store_properties() - >>> im_with_overlap_resolved = image_with_overlap() - - Gather position from image: - - >>> pos_with_overlap = np.array( - >>> im_with_overlap_resolved.get_property( - >>> "position", - >>> get_one=False - >>> ) - >>> ) - - Enforce non-overlapping and create the image without overlap: +# >>> print(image.shape) +# (128, 128, 1) +# >>> print(upscaled_image.shape) +# (128, 128, 1) - >>> non_overlapping_scatterers = dt.NonOverlapping( - ... scatterers, - ... min_distance=4, - ... 
) - >>> im_without_overlap = optics(non_overlapping_scatterers) - >>> im_without_overlap.store_properties() - >>> im_without_overlap_resolved = im_without_overlap() - - Gather position from image: - - >>> pos_without_overlap = np.array( - >>> im_without_overlap_resolved.get_property( - >>> "position", - >>> get_one=False - >>> ) - >>> ) - - Create a figure with two subplots to visualize the difference: - - >>> import matplotlib.pyplot as plt - >>> - >>> fig, axes = plt.subplots(1, 2, figsize=(10, 5)) - >>> - >>> axes[0].imshow(im_with_overlap_resolved, cmap="gray") - >>> axes[0].scatter(pos_with_overlap[:,1],pos_with_overlap[:,0]) - >>> axes[0].set_title("Overlapping Objects") - >>> axes[0].axis("off") - >>> - >>> axes[1].imshow(im_without_overlap_resolved, cmap="gray") - >>> axes[1].scatter(pos_without_overlap[:,1],pos_without_overlap[:,0]) - >>> axes[1].set_title("Non-Overlapping Objects") - >>> axes[1].axis("off") - >>> plt.tight_layout() - >>> - >>> plt.show() - - Define function to calculate minimum distance: - - >>> def calculate_min_distance(positions): - >>> distances = [ - >>> np.linalg.norm(positions[i] - positions[j]) - >>> for i in range(len(positions)) - >>> for j in range(i + 1, len(positions)) - >>> ] - >>> return min(distances) - - Print minimum distances with and without overlap: - - >>> print(calculate_min_distance(pos_with_overlap)) - 10.768742383382174 - - >>> print(calculate_min_distance(pos_without_overlap)) - 30.82531120942446 - - """ - - __distributed__: bool = False - - def __init__( - self: NonOverlapping, - feature: Feature, - min_distance: float = 1, - max_attempts: int = 5, - max_iters: int = 100, - **kwargs: Any, - ): - """Initializes the NonOverlapping feature. - - Ensures that volumes are placed **non-overlapping** by iteratively - resampling their positions. If the maximum number of attempts is - exceeded, the feature regenerates the list of volumes. - - Parameters - ---------- - feature: Feature - The feature that generates the list of volumes. - min_distance: float, optional - The minimum separation distance **between volume edges**, in - pixels. It defaults to `1`. Negative values allow for partial - overlap. - max_attempts: int, optional - The maximum number of attempts to place the volumes without - overlap. It defaults to `5`. - max_iters: int, optional - The maximum number of resampling iterations per attempt. If - exceeded, a new list of volumes is generated. It defaults to `100`. - - """ - - super().__init__( - min_distance=min_distance, - max_attempts=max_attempts, - max_iters=max_iters, - **kwargs, - ) - self.feature = self.add_feature(feature, **kwargs) - - def get( - self: NonOverlapping, - *_: Any, - min_distance: float, - max_attempts: int, - max_iters: int, - **kwargs: Any, - ) -> list[np.ndarray]: - """Generates a list of non-overlapping 3D volumes within a defined - field of view (FOV). - - This method **iteratively** attempts to place volumes while ensuring - they maintain at least `min_distance` separation. If non-overlapping - placement is not achieved within `max_attempts`, a warning is issued, - and the best available configuration is returned. - - Parameters - ---------- - _: Any - Placeholder parameter, typically for an input image. - min_distance: float - The minimum required separation distance between volumes, in - pixels. - max_attempts: int - The maximum number of attempts to generate a valid non-overlapping - configuration. - max_iters: int - The maximum number of resampling iterations per attempt. 
- **kwargs: Any - Additional parameters that may be used by subclasses. - - Returns - ------- - list[np.ndarray] - A list of 3D volumes represented as NumPy arrays. If - non-overlapping placement is unsuccessful, the best available - configuration is returned. - - Warns - ----- - UserWarning - If non-overlapping placement is **not** achieved within - `max_attempts`, suggesting parameter adjustments such as increasing - the FOV or reducing `min_distance`. - - Notes - ----- - - The placement process prioritizes bounding cube checks for - efficiency. - - If bounding cubes overlap, voxel-based overlap checks are performed. +# """ + +# __distributed__: bool = False + +# feature: Feature + +# def __init__( +# self: Feature, +# feature: Feature, +# factor: int | tuple[int, int, int] = 1, +# **kwargs: Any, +# ) -> None: +# """Initialize the Upscale feature. + +# Parameters +# ---------- +# feature: Feature +# The pipeline or feature to resolve at a higher resolution. +# factor: int or tuple[int, int, int], optional +# The factor by which to upscale the simulation. If a single integer +# is provided, it is applied uniformly across all axes. If a tuple of +# three integers is provided, each axis is scaled individually. +# Defaults to 1. +# **kwargs: Any +# Additional keyword arguments passed to the parent `Feature` class. + +# """ + +# super().__init__(factor=factor, **kwargs) +# self.feature = self.add_feature(feature) + +# def get( +# self: Feature, +# image: np.ndarray | torch.Tensor, +# factor: int | tuple[int, int, int], +# **kwargs: Any, +# ) -> np.ndarray | torch.Tensor: +# """Simulate the pipeline at a higher resolution and return result. + +# Parameters +# ---------- +# image: np.ndarray or torch.Tensor +# The input image to process. +# factor: int or tuple[int, int, int] +# The factor by which to upscale the simulation. If a single integer +# is provided, it is applied uniformly across all axes. If a tuple of +# three integers is provided, each axis is scaled individually. +# **kwargs: Any +# Additional keyword arguments passed to the feature. + +# Returns +# ------- +# np.ndarray or torch.Tensor +# The processed image at the original resolution. + +# Raises +# ------ +# ValueError +# If the input `factor` is not a valid integer or tuple of integers. + +# """ + +# # Ensure factor is a tuple of three integers. +# if np.size(factor) == 1: +# factor = (factor, factor, 1) +# elif len(factor) != 3: +# raise ValueError( +# "Factor must be an integer or a tuple of three integers." +# ) - """ - - for _ in range(max_attempts): - list_of_volumes = self.feature() - - if not isinstance(list_of_volumes, list): - list_of_volumes = [list_of_volumes] +# # Create a context for upscaling and perform computation. +# ctx = create_context(None, None, None, *factor) - for _ in range(max_iters): - - list_of_volumes = [ - self._resample_volume_position(volume) - for volume in list_of_volumes - ] +# print('before:', image) +# with units.context(ctx): +# image = self.feature(image) - if self._check_non_overlapping(list_of_volumes): - return list_of_volumes +# print('after:', image) +# # Downscale the result to the original resolution. +# import skimage.measure - # Generate a new list of volumes if max_attempts is exceeded. - self.feature.update() - - warnings.warn( - "Non-overlapping placement could not be achieved. 
Consider " - "adjusting parameters: reduce object radius, increase FOV, " - "or decrease min_distance.", - UserWarning, - ) - return list_of_volumes - - def _check_non_overlapping( - self: NonOverlapping, - list_of_volumes: list[np.ndarray], - ) -> bool: - """Determines whether all volumes in the provided list are - non-overlapping. - - This method verifies that the non-zero voxels of each 3D volume in - `list_of_volumes` are at least `min_distance` apart. It first checks - bounding boxes for early rejection and then examines actual voxel - overlap when necessary. Volumes are assumed to have a `position` - attribute indicating their placement in 3D space. - - Parameters - ---------- - list_of_volumes: list[np.ndarray] - A list of 3D arrays representing the volumes to be checked for - overlap. Each volume is expected to have a position attribute. - - Returns - ------- - bool - `True` if all volumes are non-overlapping, otherwise `False`. - - Notes - ----- - - If `min_distance` is negative, volumes are shrunk using isotropic - erosion before checking overlap. - - If `min_distance` is positive, volumes are padded and expanded using - isotropic dilation. - - Overlapping checks are first performed on bounding cubes for - efficiency. - - If bounding cubes overlap, voxel-level checks are performed. - - """ - from deeptrack.scatterers import ScatteredVolume - - from deeptrack.augmentations import CropTight, Pad # these are not compatibles with torch backend - from deeptrack.optics import _get_position - from deeptrack.math import isotropic_erosion, isotropic_dilation - - min_distance = self.min_distance() - crop = CropTight() - - new_volumes = [] - - for volume in list_of_volumes: - arr = volume.array - mask = arr != 0 - - if min_distance < 0: - new_arr = isotropic_erosion(mask, -min_distance / 2, backend=self.get_backend()) - else: - pad = Pad(px=[int(np.ceil(min_distance / 2))] * 6, keep_size=True) - new_arr = isotropic_dilation(pad(mask) != 0 , min_distance / 2, backend=self.get_backend()) - new_arr = crop(new_arr) - - if self.get_backend() == "torch": - new_arr = new_arr.to(dtype=arr.dtype) - else: - new_arr = new_arr.astype(arr.dtype) - - new_volume = ScatteredVolume( - array=new_arr, - properties=volume.properties.copy(), - ) - - new_volumes.append(new_volume) - - list_of_volumes = new_volumes - min_distance = 1 - - # The position of the top left corner of each volume (index (0, 0, 0)). - volume_positions_1 = [ - _get_position(volume, mode="corner", return_z=True).astype(int) - for volume in list_of_volumes - ] - - # The position of the bottom right corner of each volume - # (index (-1, -1, -1)). - volume_positions_2 = [ - p0 + np.array(v.shape) - for v, p0 in zip(list_of_volumes, volume_positions_1) - ] - - # (x1, y1, z1, x2, y2, z2) for each volume. - volume_bounding_cube = [ - [*p0, *p1] - for p0, p1 in zip(volume_positions_1, volume_positions_2) - ] - - for i, j in itertools.combinations(range(len(list_of_volumes)), 2): - - # If the bounding cubes do not overlap, the volumes do not overlap. - if self._check_bounding_cubes_non_overlapping( - volume_bounding_cube[i], volume_bounding_cube[j], min_distance - ): - continue - - # If the bounding cubes overlap, get the overlapping region of each - # volume. 
- overlapping_cube = self._get_overlapping_cube( - volume_bounding_cube[i], volume_bounding_cube[j] - ) - overlapping_volume_1 = self._get_overlapping_volume( - list_of_volumes[i].array, volume_bounding_cube[i], overlapping_cube - ) - overlapping_volume_2 = self._get_overlapping_volume( - list_of_volumes[j].array, volume_bounding_cube[j], overlapping_cube - ) - - # If either the overlapping regions are empty, the volumes do not - # overlap (done for speed). - if (np.all(overlapping_volume_1 == 0) - or np.all(overlapping_volume_2 == 0)): - continue - - # If products of overlapping regions are non-zero, return False. - # if np.any(overlapping_volume_1 * overlapping_volume_2): - # return False - - # Finally, check that the non-zero voxels of the volumes are at - # least min_distance apart. - if not self._check_volumes_non_overlapping( - overlapping_volume_1, overlapping_volume_2, min_distance - ): - return False - - return True - - def _check_bounding_cubes_non_overlapping( - self: NonOverlapping, - bounding_cube_1: list[int], - bounding_cube_2: list[int], - min_distance: float, - ) -> bool: - """Determines whether two 3D bounding cubes are non-overlapping. - - This method checks whether the bounding cubes of two volumes are - **separated by at least** `min_distance` along **any** spatial axis. - - Parameters - ---------- - bounding_cube_1: list[int] - A list of six integers `[x1, y1, z1, x2, y2, z2]` representing - the first bounding cube. - bounding_cube_2: list[int] - A list of six integers `[x1, y1, z1, x2, y2, z2]` representing - the second bounding cube. - min_distance: float - The required **minimum separation distance** between the two - bounding cubes. - - Returns - ------- - bool - `True` if the bounding cubes are non-overlapping (separated by at - least `min_distance` along **at least one axis**), otherwise - `False`. - - Notes - ----- - - This function **only checks bounding cubes**, **not actual voxel - data**. - - If the bounding cubes are non-overlapping, the corresponding - **volumes are also non-overlapping**. - - This check is much **faster** than full voxel-based comparisons. - - """ - - # bounding_cube_1 and bounding_cube_2 are (x1, y1, z1, x2, y2, z2). - # Check that the bounding cubes are non-overlapping. - return ( - (bounding_cube_1[0] >= bounding_cube_2[3] + min_distance) or - (bounding_cube_2[0] >= bounding_cube_1[3] + min_distance) or - (bounding_cube_1[1] >= bounding_cube_2[4] + min_distance) or - (bounding_cube_2[1] >= bounding_cube_1[4] + min_distance) or - (bounding_cube_1[2] >= bounding_cube_2[5] + min_distance) or - (bounding_cube_2[2] >= bounding_cube_1[5] + min_distance) - ) - - def _get_overlapping_cube( - self: NonOverlapping, - bounding_cube_1: list[int], - bounding_cube_2: list[int], - ) -> list[int]: - """Computes the overlapping region between two 3D bounding cubes. - - This method calculates the coordinates of the intersection of two - axis-aligned bounding cubes, each represented as a list of six - integers: - - - `[x1, y1, z1]`: Coordinates of the **top-left-front** corner. - - `[x2, y2, z2]`: Coordinates of the **bottom-right-back** corner. - - The resulting overlapping region is determined by: - - Taking the **maximum** of the starting coordinates (`x1, y1, z1`). - - Taking the **minimum** of the ending coordinates (`x2, y2, z2`). - - If the cubes **do not** overlap, the resulting coordinates will not - form a valid cube (i.e., `x1 > x2`, `y1 > y2`, or `z1 > z2`). 
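[Editor's note, not part of the patch series: the corner rule described above in one concrete case, maximum of the front corners and minimum of the back corners; the deleted docstring resumes below.]

```python
# Two axis-aligned cubes as (x1, y1, z1, x2, y2, z2).
cube_a = [0, 0, 0, 10, 10, 10]
cube_b = [6, 6, 6, 14, 14, 14]

overlap = [
    max(cube_a[0], cube_b[0]), max(cube_a[1], cube_b[1]), max(cube_a[2], cube_b[2]),
    min(cube_a[3], cube_b[3]), min(cube_a[4], cube_b[4]), min(cube_a[5], cube_b[5]),
]
print(overlap)  # [6, 6, 6, 10, 10, 10] -> a valid 4x4x4 overlap region
```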
- - Parameters - ---------- - bounding_cube_1: list[int] - The first bounding cube, formatted as `[x1, y1, z1, x2, y2, z2]`. - bounding_cube_2: list[int] - The second bounding cube, formatted as `[x1, y1, z1, x2, y2, z2]`. - - Returns - ------- - list[int] - A list of six integers `[x1, y1, z1, x2, y2, z2]` representing the - overlapping bounding cube. If no overlap exists, the coordinates - will **not** define a valid cube. - - Notes - ----- - - This function does **not** check for valid input or ensure the - resulting cube is well-formed. - - If no overlap exists, downstream functions must handle the invalid - result. - - """ - - return [ - max(bounding_cube_1[0], bounding_cube_2[0]), - max(bounding_cube_1[1], bounding_cube_2[1]), - max(bounding_cube_1[2], bounding_cube_2[2]), - min(bounding_cube_1[3], bounding_cube_2[3]), - min(bounding_cube_1[4], bounding_cube_2[4]), - min(bounding_cube_1[5], bounding_cube_2[5]), - ] - - def _get_overlapping_volume( - self: NonOverlapping, - volume: np.ndarray, # 3D array. - bounding_cube: tuple[float, float, float, float, float, float], - overlapping_cube: tuple[float, float, float, float, float, float], - ) -> np.ndarray: - """Extracts the overlapping region of a 3D volume within the specified - overlapping cube. - - This method identifies and returns the subregion of `volume` that - lies within the `overlapping_cube`. The bounding information of the - volume is provided via `bounding_cube`. - - Parameters - ---------- - volume: np.ndarray - A 3D NumPy array representing the volume from which the - overlapping region is extracted. - bounding_cube: tuple[float, float, float, float, float, float] - The bounding cube of the volume, given as a tuple of six floats: - `(x1, y1, z1, x2, y2, z2)`. The first three values define the - **top-left-front** corner, while the last three values define the - **bottom-right-back** corner. - overlapping_cube: tuple[float, float, float, float, float, float] - The overlapping region between the volume and another volume, - represented in the same format as `bounding_cube`. - - Returns - ------- - np.ndarray - A 3D NumPy array representing the portion of `volume` that - lies within `overlapping_cube`. If the overlap does not exist, - an empty array may be returned. - - Notes - ----- - - The method computes the relative indices of `overlapping_cube` - within `volume` by subtracting the bounding cube's starting - position. - - The extracted region is determined by integer indices, meaning - coordinates are implicitly **floored to integers**. - - If `overlapping_cube` extends beyond `volume` boundaries, the - returned subregion is **cropped** to fit within `volume`. 
- - """ - - # The position of the top left corner of the overlapping cube in the volume - overlapping_cube_position = np.array(overlapping_cube[:3]) - np.array( - bounding_cube[:3] - ) - - # The position of the bottom right corner of the overlapping cube in the volume - overlapping_cube_end_position = np.array( - overlapping_cube[3:] - ) - np.array(bounding_cube[:3]) - - # cast to int - overlapping_cube_position = overlapping_cube_position.astype(int) - overlapping_cube_end_position = overlapping_cube_end_position.astype(int) - - return volume[ - overlapping_cube_position[0] : overlapping_cube_end_position[0], - overlapping_cube_position[1] : overlapping_cube_end_position[1], - overlapping_cube_position[2] : overlapping_cube_end_position[2], - ] - - def _check_volumes_non_overlapping( - self: NonOverlapping, - volume_1: np.ndarray, - volume_2: np.ndarray, - min_distance: float, - ) -> bool: - """Determines whether the non-zero voxels in two 3D volumes are at - least `min_distance` apart. - - This method checks whether the active regions (non-zero voxels) in - `volume_1` and `volume_2` maintain a minimum separation of - `min_distance`. If the volumes differ in size, the positions of their - non-zero voxels are adjusted accordingly to ensure a fair comparison. - - Parameters - ---------- - volume_1: np.ndarray - A 3D NumPy array representing the first volume. - volume_2: np.ndarray - A 3D NumPy array representing the second volume. - min_distance: float - The minimum Euclidean distance required between any two non-zero - voxels in the two volumes. - - Returns - ------- - bool - `True` if all non-zero voxels in `volume_1` and `volume_2` are at - least `min_distance` apart, otherwise `False`. - - Notes - ----- - - This function assumes both volumes are correctly aligned within a - shared coordinate space. - - If the volumes are of different sizes, voxel positions are scaled - or adjusted for accurate distance measurement. - - Uses **Euclidean distance** for separation checking. - - If either volume is empty (i.e., no non-zero voxels), they are - considered non-overlapping. - - """ - - # Get the positions of the non-zero voxels of each volume. - if self.get_backend() == "torch": - positions_1 = torch.nonzero(volume_1, as_tuple=False) - positions_2 = torch.nonzero(volume_2, as_tuple=False) - else: - positions_1 = np.argwhere(volume_1) - positions_2 = np.argwhere(volume_2) - - # if positions_1.size == 0 or positions_2.size == 0: - # return True # If either volume is empty, they are "non-overlapping" - - # # If the volumes are not the same size, the positions of the non-zero - # # voxels of each volume need to be scaled. - # if positions_1.size == 0 or positions_2.size == 0: - # return True # If either volume is empty, they are "non-overlapping" - - # If the volumes are not the same size, the positions of the non-zero - # voxels of each volume need to be scaled. - if volume_1.shape != volume_2.shape: - positions_1 = ( - positions_1 * np.array(volume_2.shape) - / np.array(volume_1.shape) - ) - positions_1 = positions_1.astype(int) - - # Check that the non-zero voxels of the volumes are at least - # min_distance apart. 
-        if self.get_backend() == "torch":
-            dist = torch.cdist(
-                positions_1.float(),
-                positions_2.float(),
-            )
-            return bool((dist > min_distance).all())
-        else:
-            return np.all(cdist(positions_1, positions_2) > min_distance)
-
-    def _resample_volume_position(
-        self: NonOverlapping,
-        volume: np.ndarray | Image,
-    ) -> Image:
-        """Resamples the position of a 3D volume using its internal position
-        sampler.
-
-        This method updates the `position` property of the given `volume` by
-        drawing a new position from the `_position_sampler` stored in the
-        volume's `properties`. If the sampled position is a `Quantity`, it is
-        converted to pixel units.
-
-        Parameters
-        ----------
-        volume: np.ndarray
-            The 3D volume whose position is to be resampled. The volume must
-            have a `properties` attribute containing dictionaries with
-            `position` and `_position_sampler` keys.
-
-        Returns
-        -------
-        Image
-            The same input volume with its `position` property updated to the
-            newly sampled value.
-
-        Notes
-        -----
-        - The `_position_sampler` function is expected to return a **tuple of
-          three floats** (e.g., `(x, y, z)`).
-        - If the sampled position is a `Quantity`, it is converted to pixels.
-        - **Only** dictionaries in `volume.properties` that contain both
-          `position` and `_position_sampler` keys are modified.
-
-        """
-
-        pdict = volume.properties
-        if "position" in pdict and "_position_sampler" in pdict:
-            new_position = pdict["_position_sampler"]()
-            if isinstance(new_position, Quantity):
-                new_position = new_position.to("pixel").magnitude
-            pdict["position"] = new_position
-
-        return volume
-
 
 class Store(Feature):
diff --git a/deeptrack/optics.py b/deeptrack/optics.py
index 0788d7d8..bb3cf40f 100644
--- a/deeptrack/optics.py
+++ b/deeptrack/optics.py
@@ -267,9 +267,10 @@ def _validate_input(self, scattered):
 
     def _extract_contrast_volume(self, scattered):
         if hasattr(self._objective, "extract_contrast_volume"):
-            return self._objective.extract_contrast_volume(scattered)
-
-        # default: geometry-only
+            return self._objective.extract_contrast_volume(
+                scattered,
+                **self._objective.properties(),
+            )
         return scattered.array
 
     def _downscale_image(self, image, upscale):
@@ -395,12 +396,12 @@ def get(
             **additional_sample_kwargs,
         )
 
         # Interpret the merged volume semantically
         sample_volume = self._extract_contrast_volume(
             ScatteredVolume(
                 array=sample_volume,
                 properties=volume_samples[0].properties,
-            )
+            ),
         )
 
         # Let the objective know about the limits of the volume and all the fields.
@@ -1081,37 +1084,43 @@ def validate_input(self, scattered):
                 "Fluorescence microscope cannot operate on ScatteredField."
             )
 
-        # Fluorescence must not use refractive index
-        if isinstance(scattered, ScatteredVolume):
-            if scattered.get_property("refractive_index", None) is not None:
-                raise ValueError(
-                    "Fluorescence does not use refractive index. "
-                    "Found 'refractive_index' in scatterer properties."
-                )
-
-
-    def extract_contrast_volume(self, scattered: ScatteredVolume) -> np.ndarray:
-        """Contrast extraction (semantic interpretation)"""
-        intensity = scattered.get_property("intensity", None)
-
-        if intensity is None:
-            intensity = scattered.get_property("value", None)
-            if intensity is None:
-                raise ValueError(
-                    "Fluorescence requires 'intensity' or 'value'."
-                )
-
-            warnings.warn(
-                "Using 'value' as fluorescence intensity is ambiguous. "
-                "Please use 'intensity' explicitly to avoid ambiguity.",
-                UserWarning,
-            )
-
-        voxel_size = np.asarray(get_active_voxel_size(), dtype=float)
-        voxel_volume = float(np.prod(voxel_size))
-
-        return scattered.array * intensity * voxel_volume
+
+    def extract_contrast_volume(
+        self,
+        scattered: ScatteredVolume,
+        **kwargs: Any,
+    ) -> np.ndarray:
+        voxel_size = np.asarray(get_active_voxel_size(), float)
+        voxel_volume = np.prod(voxel_size)
+
+        intensity = scattered.get_property("intensity", None)
+        value = scattered.get_property("value", None)
+        ri = scattered.get_property("refractive_index", None)
+
+        # Refractive index is always ignored in fluorescence
+        if ri is not None:
+            warnings.warn(
+                "Scatterer defines 'refractive_index', which is ignored in "
+                "fluorescence microscopy.",
+                UserWarning,
+            )
+
+        # Preferred, physically meaningful case
+        if intensity is not None:
+            return intensity * voxel_volume * scattered.array
+
+        if value is None:
+            raise ValueError(
+                "Fluorescence requires 'intensity' or 'value'."
+            )
+
+        # Fallback: legacy / dimensionless brightness
+        warnings.warn(
+            "Fluorescence scatterer has no 'intensity'. Interpreting 'value' as a "
+            "non-physical brightness factor. Quantitative interpretation is invalid. "
+            "Define 'intensity' to model physical fluorescence emission.",
+            UserWarning,
+        )
+        return value * scattered.array
 
     def downscale_image(self, image: np.ndarray, upscale):
         """Detector downscaling (energy-conserving)."""
@@ -1357,8 +1359,47 @@ class Brightfield(Optics):
 
 
     __conversion_table__ = ConversionTable(
         working_distance=(u.meter, u.meter),
     )
 
+    def validate_input(self, scattered):
+        """Semantic validation for brightfield microscopy."""
+
+        if isinstance(scattered, ScatteredVolume):
+            warnings.warn(
+                "Brightfield imaging from ScatteredVolume assumes a "
+                "weak-phase / projection approximation. "
+                "Use ScatteredField for physically accurate brightfield simulations.",
+                UserWarning,
+            )
+
+    def extract_contrast_volume(
+        self,
+        scattered: ScatteredVolume,
+        refractive_index_medium: float,
+        **kwargs: Any,
+    ) -> np.ndarray:
+        ri = scattered.get_property("refractive_index", None)
+        value = scattered.get_property("value", None)
+        intensity = scattered.get_property("intensity", None)
+
+        if intensity is not None:
+            warnings.warn(
+                "Scatterer defines 'intensity', which is ignored in "
+                "brightfield microscopy.",
+                UserWarning,
+            )
+
+        if ri is not None:
+            return (ri - refractive_index_medium) * scattered.array
+
+        if value is None:
+            raise ValueError(
+                "Brightfield requires 'refractive_index' or 'value'."
+            )
+
+        warnings.warn(
+            "No 'refractive_index' specified; using 'value' as a non-physical "
+            "brightfield contrast. Results are not physically calibrated. "
+            "Define 'refractive_index' for physically meaningful contrast.",
+            UserWarning,
+        )
+
+        return value * scattered.array
 
     def get(
         self: Brightfield,
@@ -1746,6 +1790,59 @@ def __init__(
             illumination_angle=illumination_angle,
             **kwargs)
 
+    def validate_input(self, scattered):
+        if isinstance(scattered, ScatteredVolume):
+            warnings.warn(
+                "Darkfield imaging from ScatteredVolume is a very rough "
+                "approximation. Use ScatteredField for physically meaningful "
+                "darkfield simulations.",
+                UserWarning,
+            )
+
+    def extract_contrast_volume(
+        self,
+        scattered: ScatteredVolume,
+        refractive_index_medium: float,
+        **kwargs: Any,
+    ) -> np.ndarray:
+        """
+        Approximate darkfield contrast from a volume (toy model).
+
+        This is a non-physical approximation intended for qualitative simulations.
+        """
+
+        ri = scattered.get_property("refractive_index", None)
+        value = scattered.get_property("value", None)
+        intensity = scattered.get_property("intensity", None)
+
+        # Intensity has no meaning here
+        if intensity is not None:
+            warnings.warn(
+                "Scatterer defines 'intensity', which is ignored in "
+                "darkfield microscopy.",
+                UserWarning,
+            )
+
+        if ri is not None:
+            delta_n = ri - refractive_index_medium
+            warnings.warn(
+                "Approximating darkfield contrast from refractive index. "
+                "Result is non-physical and qualitative only.",
+                UserWarning,
+            )
+            return (delta_n ** 2) * scattered.array
+
+        warnings.warn(
+            "No 'refractive_index' specified; using 'value' as a non-physical "
+            "darkfield scattering strength. Results are qualitative only.",
+            UserWarning,
+        )
+
+        return (value ** 2) * scattered.array
+
+
+    # Delegate to the superclass implementation of get.
     def get(
         self: Darkfield,
 
@@ -1922,6 +2017,998 @@ def get(
         return image
 
 
+class NonOverlapping(Feature):
+    """Ensure volumes are placed non-overlapping in a 3D space.
+
+    This feature ensures that a list of 3D volumes are positioned such that
+    their non-zero voxels do not overlap. If volumes overlap, their positions
+    are resampled until they are non-overlapping. If the maximum number of
+    attempts is exceeded, the feature regenerates the list of volumes and
+    raises a warning if non-overlapping placement cannot be achieved.
+
+    Note: `min_distance` refers to the distance between the edges of volumes,
+    not their centers. Due to the way volumes are calculated, slight rounding
+    errors may affect the final distance.
+
+    This feature is incompatible with non-volumetric scatterers such as
+    `MieScatterers`.
+
+    Parameters
+    ----------
+    feature: Feature
+        The feature that generates the list of volumes to place
+        non-overlapping.
+    min_distance: float, optional
+        The minimum distance between volumes in pixels. It can be negative to
+        allow for partial overlap. Defaults to 1.
+    max_attempts: int, optional
+        The maximum number of attempts to place volumes without overlap.
+        Defaults to 5.
+    max_iters: int, optional
+        The maximum number of resamplings. If this number is exceeded, a new
+        list of volumes is generated. Defaults to 100.
+
+    Attributes
+    ----------
+    __distributed__: bool
+        Always `False` for `NonOverlapping`, indicating that this feature’s
+        `.get()` method processes the entire input at once even if it is a
+        list, rather than distributing calls for each item of the list.
+
+    Methods
+    -------
+    `get(*_, min_distance, max_attempts, max_iters, **kwargs) -> list[np.ndarray]`
+        Generate a list of non-overlapping 3D volumes.
+    `_check_non_overlapping(list_of_volumes) -> bool`
+        Check if all volumes in the list are non-overlapping.
+    `_check_bounding_cubes_non_overlapping(...) -> bool`
+        Check if two bounding cubes are non-overlapping.
+    `_get_overlapping_cube(...) -> list[int]`
+        Get the overlapping cube between two bounding cubes.
+    `_get_overlapping_volume(...) -> array`
+        Get the overlapping volume between a volume and a bounding cube.
+    `_check_volumes_non_overlapping(...) -> bool`
+        Check if two volumes are non-overlapping.
+    `_resample_volume_position(volume) -> Image`
+        Resample the position of a volume to avoid overlap.
+
+    Notes
+    -----
+    - This feature performs bounding cube checks first to quickly reject
+      obvious overlaps before voxel-level checks.
+    - If the bounding cubes overlap, precise voxel-based checks are performed.
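+    - Overlap checks are evaluated pairwise over all combinations of volumes,
+      so the cost of each resampling iteration grows quadratically with the
+      number of volumes.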
+
+    Examples
+    --------
+    >>> import deeptrack as dt
+
+    Define an ellipse scatterer with randomly positioned objects:
+
+    >>> import numpy as np
+    >>>
+    >>> scatterer = dt.Ellipse(
+    >>>     radius=13 * dt.units.pixels,
+    >>>     position=lambda: np.random.uniform(5, 115, size=2) * dt.units.pixels,
+    >>> )
+
+    Create multiple scatterers:
+
+    >>> scatterers = (scatterer ^ 8)
+
+    Define the optics and create the image with possible overlap:
+
+    >>> optics = dt.Fluorescence()
+    >>> im_with_overlap = optics(scatterers)
+    >>> im_with_overlap.store_properties()
+    >>> im_with_overlap_resolved = im_with_overlap()
+
+    Gather position from image:
+
+    >>> pos_with_overlap = np.array(
+    >>>     im_with_overlap_resolved.get_property(
+    >>>         "position",
+    >>>         get_one=False
+    >>>     )
+    >>> )
+
+    Enforce non-overlapping and create the image without overlap:
+
+    >>> non_overlapping_scatterers = dt.NonOverlapping(
+    ...     scatterers,
+    ...     min_distance=4,
+    ... )
+    >>> im_without_overlap = optics(non_overlapping_scatterers)
+    >>> im_without_overlap.store_properties()
+    >>> im_without_overlap_resolved = im_without_overlap()
+
+    Gather position from image:
+
+    >>> pos_without_overlap = np.array(
+    >>>     im_without_overlap_resolved.get_property(
+    >>>         "position",
+    >>>         get_one=False
+    >>>     )
+    >>> )
+
+    Create a figure with two subplots to visualize the difference:
+
+    >>> import matplotlib.pyplot as plt
+    >>>
+    >>> fig, axes = plt.subplots(1, 2, figsize=(10, 5))
+    >>>
+    >>> axes[0].imshow(im_with_overlap_resolved, cmap="gray")
+    >>> axes[0].scatter(pos_with_overlap[:,1],pos_with_overlap[:,0])
+    >>> axes[0].set_title("Overlapping Objects")
+    >>> axes[0].axis("off")
+    >>>
+    >>> axes[1].imshow(im_without_overlap_resolved, cmap="gray")
+    >>> axes[1].scatter(pos_without_overlap[:,1],pos_without_overlap[:,0])
+    >>> axes[1].set_title("Non-Overlapping Objects")
+    >>> axes[1].axis("off")
+    >>> plt.tight_layout()
+    >>>
+    >>> plt.show()
+
+    Define function to calculate minimum distance:
+
+    >>> def calculate_min_distance(positions):
+    >>>     distances = [
+    >>>         np.linalg.norm(positions[i] - positions[j])
+    >>>         for i in range(len(positions))
+    >>>         for j in range(i + 1, len(positions))
+    >>>     ]
+    >>>     return min(distances)
+
+    Print minimum distances with and without overlap:
+
+    >>> print(calculate_min_distance(pos_with_overlap))
+    10.768742383382174
+
+    >>> print(calculate_min_distance(pos_without_overlap))
+    30.82531120942446
+
+    """
+
+    __distributed__: bool = False
+
+    def __init__(
+        self: NonOverlapping,
+        feature: Feature,
+        min_distance: float = 1,
+        max_attempts: int = 5,
+        max_iters: int = 100,
+        **kwargs: Any,
+    ):
+        """Initializes the NonOverlapping feature.
+
+        Ensures that volumes are placed **non-overlapping** by iteratively
+        resampling their positions. If the maximum number of attempts is
+        exceeded, the feature regenerates the list of volumes.
+
+        Parameters
+        ----------
+        feature: Feature
+            The feature that generates the list of volumes.
+        min_distance: float, optional
+            The minimum separation distance **between volume edges**, in
+            pixels. It defaults to `1`. Negative values allow for partial
+            overlap.
+        max_attempts: int, optional
+            The maximum number of attempts to place the volumes without
+            overlap. It defaults to `5`.
+        max_iters: int, optional
+            The maximum number of resampling iterations per attempt. If
+            exceeded, a new list of volumes is generated. It defaults to `100`.
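+
+        Notes
+        -----
+        In the worst case, up to `max_attempts * max_iters` resampling
+        rounds are performed; if no non-overlapping configuration is found,
+        the last configuration is returned together with a warning.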
+ + """ + + super().__init__( + min_distance=min_distance, + max_attempts=max_attempts, + max_iters=max_iters, + **kwargs, + ) + self.feature = self.add_feature(feature, **kwargs) + + def get( + self: NonOverlapping, + *_: Any, + min_distance: float, + max_attempts: int, + max_iters: int, + **kwargs: Any, + ) -> list[np.ndarray]: + """Generates a list of non-overlapping 3D volumes within a defined + field of view (FOV). + + This method **iteratively** attempts to place volumes while ensuring + they maintain at least `min_distance` separation. If non-overlapping + placement is not achieved within `max_attempts`, a warning is issued, + and the best available configuration is returned. + + Parameters + ---------- + _: Any + Placeholder parameter, typically for an input image. + min_distance: float + The minimum required separation distance between volumes, in + pixels. + max_attempts: int + The maximum number of attempts to generate a valid non-overlapping + configuration. + max_iters: int + The maximum number of resampling iterations per attempt. + **kwargs: Any + Additional parameters that may be used by subclasses. + + Returns + ------- + list[np.ndarray] + A list of 3D volumes represented as NumPy arrays. If + non-overlapping placement is unsuccessful, the best available + configuration is returned. + + Warns + ----- + UserWarning + If non-overlapping placement is **not** achieved within + `max_attempts`, suggesting parameter adjustments such as increasing + the FOV or reducing `min_distance`. + + Notes + ----- + - The placement process prioritizes bounding cube checks for + efficiency. + - If bounding cubes overlap, voxel-based overlap checks are performed. + + """ + + for _ in range(max_attempts): + list_of_volumes = self.feature() + + if not isinstance(list_of_volumes, list): + list_of_volumes = [list_of_volumes] + + for _ in range(max_iters): + + list_of_volumes = [ + self._resample_volume_position(volume) + for volume in list_of_volumes + ] + + if self._check_non_overlapping(list_of_volumes): + return list_of_volumes + + # Generate a new list of volumes if max_attempts is exceeded. + self.feature.update() + + warnings.warn( + "Non-overlapping placement could not be achieved. Consider " + "adjusting parameters: reduce object radius, increase FOV, " + "or decrease min_distance.", + UserWarning, + ) + return list_of_volumes + + def _check_non_overlapping( + self: NonOverlapping, + list_of_volumes: list[np.ndarray], + ) -> bool: + """Determines whether all volumes in the provided list are + non-overlapping. + + This method verifies that the non-zero voxels of each 3D volume in + `list_of_volumes` are at least `min_distance` apart. It first checks + bounding boxes for early rejection and then examines actual voxel + overlap when necessary. Volumes are assumed to have a `position` + attribute indicating their placement in 3D space. + + Parameters + ---------- + list_of_volumes: list[np.ndarray] + A list of 3D arrays representing the volumes to be checked for + overlap. Each volume is expected to have a position attribute. + + Returns + ------- + bool + `True` if all volumes are non-overlapping, otherwise `False`. + + Notes + ----- + - If `min_distance` is negative, volumes are shrunk using isotropic + erosion before checking overlap. + - If `min_distance` is positive, volumes are padded and expanded using + isotropic dilation. + - Overlapping checks are first performed on bounding cubes for + efficiency. + - If bounding cubes overlap, voxel-level checks are performed. 
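+        - Worked example of the distance handling: with `min_distance = 4`,
+          each mask is padded by 2 px and isotropically dilated by 2 px, so
+          the voxel check only needs to enforce a gap greater than one voxel
+          between the dilated masks. With `min_distance = -4`, each mask is
+          instead eroded by 2 px, allowing roughly 4 px of overlap between
+          the original volumes.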
+
+        """
+        from deeptrack.scatterers import ScatteredVolume
+
+        from deeptrack.augmentations import CropTight, Pad  # these are not compatible with the torch backend
+        from deeptrack.optics import _get_position
+        from deeptrack.math import isotropic_erosion, isotropic_dilation
+
+        min_distance = self.min_distance()
+        crop = CropTight()
+
+        new_volumes = []
+
+        for volume in list_of_volumes:
+            arr = volume.array
+            mask = arr != 0
+
+            if min_distance < 0:
+                new_arr = isotropic_erosion(mask, -min_distance / 2, backend=self.get_backend())
+            else:
+                pad = Pad(px=[int(np.ceil(min_distance / 2))] * 6, keep_size=True)
+                new_arr = isotropic_dilation(pad(mask) != 0, min_distance / 2, backend=self.get_backend())
+                new_arr = crop(new_arr)
+
+            if self.get_backend() == "torch":
+                new_arr = new_arr.to(dtype=arr.dtype)
+            else:
+                new_arr = new_arr.astype(arr.dtype)
+
+            new_volume = ScatteredVolume(
+                array=new_arr,
+                properties=volume.properties.copy(),
+            )
+
+            new_volumes.append(new_volume)
+
+        list_of_volumes = new_volumes
+        min_distance = 1
+
+        # The position of the top left corner of each volume (index (0, 0, 0)).
+        volume_positions_1 = [
+            _get_position(volume, mode="corner", return_z=True).astype(int)
+            for volume in list_of_volumes
+        ]
+
+        # The position of the bottom right corner of each volume
+        # (index (-1, -1, -1)).
+        volume_positions_2 = [
+            p0 + np.array(v.shape)
+            for v, p0 in zip(list_of_volumes, volume_positions_1)
+        ]
+
+        # (x1, y1, z1, x2, y2, z2) for each volume.
+        volume_bounding_cube = [
+            [*p0, *p1]
+            for p0, p1 in zip(volume_positions_1, volume_positions_2)
+        ]
+
+        for i, j in itertools.combinations(range(len(list_of_volumes)), 2):
+
+            # If the bounding cubes do not overlap, the volumes do not overlap.
+            if self._check_bounding_cubes_non_overlapping(
+                volume_bounding_cube[i], volume_bounding_cube[j], min_distance
+            ):
+                continue
+
+            # If the bounding cubes overlap, get the overlapping region of each
+            # volume.
+            overlapping_cube = self._get_overlapping_cube(
+                volume_bounding_cube[i], volume_bounding_cube[j]
+            )
+            overlapping_volume_1 = self._get_overlapping_volume(
+                list_of_volumes[i].array, volume_bounding_cube[i], overlapping_cube
+            )
+            overlapping_volume_2 = self._get_overlapping_volume(
+                list_of_volumes[j].array, volume_bounding_cube[j], overlapping_cube
+            )
+
+            # If either of the overlapping regions is empty, the volumes do
+            # not overlap (done for speed).
+            if (np.all(overlapping_volume_1 == 0)
+                or np.all(overlapping_volume_2 == 0)):
+                continue
+
+            # If products of overlapping regions are non-zero, return False.
+            # if np.any(overlapping_volume_1 * overlapping_volume_2):
+            #     return False
+
+            # Finally, check that the non-zero voxels of the volumes are at
+            # least min_distance apart.
+            if not self._check_volumes_non_overlapping(
+                overlapping_volume_1, overlapping_volume_2, min_distance
+            ):
+                return False
+
+        return True
+
+    def _check_bounding_cubes_non_overlapping(
+        self: NonOverlapping,
+        bounding_cube_1: list[int],
+        bounding_cube_2: list[int],
+        min_distance: float,
+    ) -> bool:
+        """Determines whether two 3D bounding cubes are non-overlapping.
+
+        This method checks whether the bounding cubes of two volumes are
+        **separated by at least** `min_distance` along **any** spatial axis.
+
+        Parameters
+        ----------
+        bounding_cube_1: list[int]
+            A list of six integers `[x1, y1, z1, x2, y2, z2]` representing
+            the first bounding cube.
+        bounding_cube_2: list[int]
+            A list of six integers `[x1, y1, z1, x2, y2, z2]` representing
+            the second bounding cube.
+ min_distance: float + The required **minimum separation distance** between the two + bounding cubes. + + Returns + ------- + bool + `True` if the bounding cubes are non-overlapping (separated by at + least `min_distance` along **at least one axis**), otherwise + `False`. + + Notes + ----- + - This function **only checks bounding cubes**, **not actual voxel + data**. + - If the bounding cubes are non-overlapping, the corresponding + **volumes are also non-overlapping**. + - This check is much **faster** than full voxel-based comparisons. + + """ + + # bounding_cube_1 and bounding_cube_2 are (x1, y1, z1, x2, y2, z2). + # Check that the bounding cubes are non-overlapping. + return ( + (bounding_cube_1[0] >= bounding_cube_2[3] + min_distance) or + (bounding_cube_2[0] >= bounding_cube_1[3] + min_distance) or + (bounding_cube_1[1] >= bounding_cube_2[4] + min_distance) or + (bounding_cube_2[1] >= bounding_cube_1[4] + min_distance) or + (bounding_cube_1[2] >= bounding_cube_2[5] + min_distance) or + (bounding_cube_2[2] >= bounding_cube_1[5] + min_distance) + ) + + def _get_overlapping_cube( + self: NonOverlapping, + bounding_cube_1: list[int], + bounding_cube_2: list[int], + ) -> list[int]: + """Computes the overlapping region between two 3D bounding cubes. + + This method calculates the coordinates of the intersection of two + axis-aligned bounding cubes, each represented as a list of six + integers: + + - `[x1, y1, z1]`: Coordinates of the **top-left-front** corner. + - `[x2, y2, z2]`: Coordinates of the **bottom-right-back** corner. + + The resulting overlapping region is determined by: + - Taking the **maximum** of the starting coordinates (`x1, y1, z1`). + - Taking the **minimum** of the ending coordinates (`x2, y2, z2`). + + If the cubes **do not** overlap, the resulting coordinates will not + form a valid cube (i.e., `x1 > x2`, `y1 > y2`, or `z1 > z2`). + + Parameters + ---------- + bounding_cube_1: list[int] + The first bounding cube, formatted as `[x1, y1, z1, x2, y2, z2]`. + bounding_cube_2: list[int] + The second bounding cube, formatted as `[x1, y1, z1, x2, y2, z2]`. + + Returns + ------- + list[int] + A list of six integers `[x1, y1, z1, x2, y2, z2]` representing the + overlapping bounding cube. If no overlap exists, the coordinates + will **not** define a valid cube. + + Notes + ----- + - This function does **not** check for valid input or ensure the + resulting cube is well-formed. + - If no overlap exists, downstream functions must handle the invalid + result. + + """ + + return [ + max(bounding_cube_1[0], bounding_cube_2[0]), + max(bounding_cube_1[1], bounding_cube_2[1]), + max(bounding_cube_1[2], bounding_cube_2[2]), + min(bounding_cube_1[3], bounding_cube_2[3]), + min(bounding_cube_1[4], bounding_cube_2[4]), + min(bounding_cube_1[5], bounding_cube_2[5]), + ] + + def _get_overlapping_volume( + self: NonOverlapping, + volume: np.ndarray, # 3D array. + bounding_cube: tuple[float, float, float, float, float, float], + overlapping_cube: tuple[float, float, float, float, float, float], + ) -> np.ndarray: + """Extracts the overlapping region of a 3D volume within the specified + overlapping cube. + + This method identifies and returns the subregion of `volume` that + lies within the `overlapping_cube`. The bounding information of the + volume is provided via `bounding_cube`. + + Parameters + ---------- + volume: np.ndarray + A 3D NumPy array representing the volume from which the + overlapping region is extracted. 
+ bounding_cube: tuple[float, float, float, float, float, float] + The bounding cube of the volume, given as a tuple of six floats: + `(x1, y1, z1, x2, y2, z2)`. The first three values define the + **top-left-front** corner, while the last three values define the + **bottom-right-back** corner. + overlapping_cube: tuple[float, float, float, float, float, float] + The overlapping region between the volume and another volume, + represented in the same format as `bounding_cube`. + + Returns + ------- + np.ndarray + A 3D NumPy array representing the portion of `volume` that + lies within `overlapping_cube`. If the overlap does not exist, + an empty array may be returned. + + Notes + ----- + - The method computes the relative indices of `overlapping_cube` + within `volume` by subtracting the bounding cube's starting + position. + - The extracted region is determined by integer indices, meaning + coordinates are implicitly **floored to integers**. + - If `overlapping_cube` extends beyond `volume` boundaries, the + returned subregion is **cropped** to fit within `volume`. + + """ + + # The position of the top left corner of the overlapping cube in the volume + overlapping_cube_position = np.array(overlapping_cube[:3]) - np.array( + bounding_cube[:3] + ) + + # The position of the bottom right corner of the overlapping cube in the volume + overlapping_cube_end_position = np.array( + overlapping_cube[3:] + ) - np.array(bounding_cube[:3]) + + # cast to int + overlapping_cube_position = overlapping_cube_position.astype(int) + overlapping_cube_end_position = overlapping_cube_end_position.astype(int) + + return volume[ + overlapping_cube_position[0] : overlapping_cube_end_position[0], + overlapping_cube_position[1] : overlapping_cube_end_position[1], + overlapping_cube_position[2] : overlapping_cube_end_position[2], + ] + + def _check_volumes_non_overlapping( + self: NonOverlapping, + volume_1: np.ndarray, + volume_2: np.ndarray, + min_distance: float, + ) -> bool: + """Determines whether the non-zero voxels in two 3D volumes are at + least `min_distance` apart. + + This method checks whether the active regions (non-zero voxels) in + `volume_1` and `volume_2` maintain a minimum separation of + `min_distance`. If the volumes differ in size, the positions of their + non-zero voxels are adjusted accordingly to ensure a fair comparison. + + Parameters + ---------- + volume_1: np.ndarray + A 3D NumPy array representing the first volume. + volume_2: np.ndarray + A 3D NumPy array representing the second volume. + min_distance: float + The minimum Euclidean distance required between any two non-zero + voxels in the two volumes. + + Returns + ------- + bool + `True` if all non-zero voxels in `volume_1` and `volume_2` are at + least `min_distance` apart, otherwise `False`. + + Notes + ----- + - This function assumes both volumes are correctly aligned within a + shared coordinate space. + - If the volumes are of different sizes, voxel positions are scaled + or adjusted for accurate distance measurement. + - Uses **Euclidean distance** for separation checking. + - If either volume is empty (i.e., no non-zero voxels), they are + considered non-overlapping. + + """ + + # Get the positions of the non-zero voxels of each volume. 
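+        # Both torch.nonzero and np.argwhere return an (N, 3) array of voxel
+        # indices, so the pairwise distance test below operates on these
+        # index lists rather than on the dense volumes.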
+        if self.get_backend() == "torch":
+            positions_1 = torch.nonzero(volume_1, as_tuple=False)
+            positions_2 = torch.nonzero(volume_2, as_tuple=False)
+        else:
+            positions_1 = np.argwhere(volume_1)
+            positions_2 = np.argwhere(volume_2)
+
+        # if positions_1.size == 0 or positions_2.size == 0:
+        #     return True  # If either volume is empty, they are "non-overlapping"
+
+        # If the volumes are not the same size, the positions of the non-zero
+        # voxels of each volume need to be scaled.
+        if volume_1.shape != volume_2.shape:
+            positions_1 = (
+                positions_1 * np.array(volume_2.shape)
+                / np.array(volume_1.shape)
+            )
+            positions_1 = positions_1.astype(int)
+
+        # Check that the non-zero voxels of the volumes are at least
+        # min_distance apart.
+        if self.get_backend() == "torch":
+            dist = torch.cdist(
+                positions_1.float(),
+                positions_2.float(),
+            )
+            return bool((dist > min_distance).all())
+        else:
+            return np.all(cdist(positions_1, positions_2) > min_distance)
+
+    def _resample_volume_position(
+        self: NonOverlapping,
+        volume: np.ndarray | Image,
+    ) -> Image:
+        """Resamples the position of a 3D volume using its internal position
+        sampler.
+
+        This method updates the `position` property of the given `volume` by
+        drawing a new position from the `_position_sampler` stored in the
+        volume's `properties`. If the sampled position is a `Quantity`, it is
+        converted to pixel units.
+
+        Parameters
+        ----------
+        volume: np.ndarray
+            The 3D volume whose position is to be resampled. The volume must
+            have a `properties` attribute containing dictionaries with
+            `position` and `_position_sampler` keys.
+
+        Returns
+        -------
+        Image
+            The same input volume with its `position` property updated to the
+            newly sampled value.
+
+        Notes
+        -----
+        - The `_position_sampler` function is expected to return a **tuple of
+          three floats** (e.g., `(x, y, z)`).
+        - If the sampled position is a `Quantity`, it is converted to pixels.
+        - **Only** dictionaries in `volume.properties` that contain both
+          `position` and `_position_sampler` keys are modified.
+
+        """
+
+        pdict = volume.properties
+        if "position" in pdict and "_position_sampler" in pdict:
+            new_position = pdict["_position_sampler"]()
+            if isinstance(new_position, Quantity):
+                new_position = new_position.to("pixel").magnitude
+            pdict["position"] = new_position
+
+        return volume
+
+
+class SampleToMasks(Feature):
+    """Create a mask from a list of images.
+
+    This feature applies a transformation function to each input image and
+    merges the resulting masks into a single multi-layer image. Each input
+    image must have a `position` property that determines its placement within
+    the final mask. When used with scatterers, the `voxel_size` property must
+    be provided for correct object sizing.
+
+    Parameters
+    ----------
+    transformation_function: Callable[[Image], Image]
+        A function that transforms each input image into a mask with
+        `number_of_masks` layers.
+    number_of_masks: PropertyLike[int], optional
+        The number of mask layers to generate. Default is 1.
+    output_region: PropertyLike[tuple[int, int, int, int]], optional
+        The size and position of the output mask, typically aligned with
+        `optics.output_region`.
+    merge_method: PropertyLike[str | Callable | list[str | Callable]], optional
+        Method for merging individual masks into the final image. Can be:
+        - "add" (default): Sum the masks.
+        - "overwrite": Later masks overwrite earlier masks.
+        - "or": Combine masks using a logical OR operation.
+        - "mul": Multiply masks.
+        - Function: Custom function taking two images and merging them.
+
+    **kwargs: dict[str, Any]
+        Additional keyword arguments passed to the parent `Feature` class.
+
+    Methods
+    -------
+    `get(image, transformation_function, **kwargs) -> Image`
+        Applies the transformation function to the input image.
+    `_process_and_get(images, **kwargs) -> Image | np.ndarray`
+        Processes a list of images and generates a multi-layer mask.
+
+    Returns
+    -------
+    np.ndarray
+        The final mask image with the specified number of layers.
+
+    Raises
+    ------
+    ValueError
+        If `merge_method` is invalid.
+
+    Examples
+    --------
+    >>> import deeptrack as dt
+
+    Define number of particles:
+
+    >>> n_particles = 12
+
+    Define optics and particles:
+
+    >>> import numpy as np
+    >>>
+    >>> optics = dt.Fluorescence(output_region=(0, 0, 64, 64))
+    >>> particle = dt.PointParticle(
+    >>>     position=lambda: np.random.uniform(5, 55, size=2),
+    >>> )
+    >>> particles = particle ^ n_particles
+
+    Define pipelines:
+
+    >>> sim_im_pip = optics(particles)
+    >>> sim_mask_pip = particles >> dt.SampleToMasks(
+    ...     lambda: lambda particles: particles > 0,
+    ...     output_region=optics.output_region,
+    ...     merge_method="or",
+    ... )
+    >>> pipeline = sim_im_pip & sim_mask_pip
+    >>> pipeline.store_properties()
+
+    Generate image and mask:
+
+    >>> image, mask = pipeline.update()()
+
+    Get particle positions:
+
+    >>> positions = np.array(image.get_property("position", get_one=False))
+
+    Visualize results:
+
+    >>> import matplotlib.pyplot as plt
+    >>>
+    >>> plt.subplot(1, 2, 1)
+    >>> plt.imshow(image, cmap="gray")
+    >>> plt.title("Original Image")
+    >>> plt.subplot(1, 2, 2)
+    >>> plt.imshow(mask, cmap="gray")
+    >>> plt.scatter(positions[:,1], positions[:,0], c="y", marker="x", s=50)
+    >>> plt.title("Mask")
+    >>> plt.show()
+
+    """
+
+    def __init__(
+        self: Feature,
+        transformation_function: Callable[[np.ndarray | torch.Tensor], np.ndarray | torch.Tensor],
+        number_of_masks: PropertyLike[int] = 1,
+        output_region: PropertyLike[tuple[int, int, int, int]] = None,
+        merge_method: PropertyLike[str | Callable | list[str | Callable]] = "add",
+        **kwargs: Any,
+    ):
+        """Initialize the SampleToMasks feature.
+
+        Parameters
+        ----------
+        transformation_function: Callable[[Image], Image]
+            Function to transform input images into masks.
+        number_of_masks: PropertyLike[int], optional
+            Number of mask layers. Default is 1.
+        output_region: PropertyLike[tuple[int, int, int, int]], optional
+            Output region of the mask. Default is None.
+        merge_method: PropertyLike[str | Callable | list[str | Callable]], optional
+            Method to merge masks. Defaults to "add".
+        **kwargs: dict[str, Any]
+            Additional keyword arguments passed to the parent class.
+
+        """
+
+        super().__init__(
+            transformation_function=transformation_function,
+            number_of_masks=number_of_masks,
+            output_region=output_region,
+            merge_method=merge_method,
+            **kwargs,
+        )
+
+    def get(
+        self: Feature,
+        image: np.ndarray,
+        transformation_function: Callable[[np.ndarray | torch.Tensor], np.ndarray | torch.Tensor],
+        **kwargs: Any,
+    ) -> np.ndarray:
+        """Apply the transformation function to a single image.
+
+        Parameters
+        ----------
+        image: np.ndarray
+            The input image.
+ transformation_function: Callable[[np.ndarray], np.ndarray] + Function to transform the image. + **kwargs: dict[str, Any] + Additional parameters. + + Returns + ------- + Image + The transformed image. + + """ + + return transformation_function(image.array) + + def _process_and_get( + self: Feature, + images: list[np.ndarray] | np.ndarray | list[torch.Tensor] | torch.Tensor, + **kwargs: Any, + ) -> np.ndarray: + """Process a list of images and generate a multi-layer mask. + + Parameters + ---------- + images: np.ndarray or list[np.ndarrray] or Image or list[Image] + List of input images or a single image. + **kwargs: dict[str, Any] + Additional parameters including `output_region`, `number_of_masks`, + and `merge_method`. + + Returns + ------- + Image or np.ndarray + The final mask image. + + """ + + # Handle list of images. + # if isinstance(images, list) and len(images) != 1: + list_of_labels = super()._process_and_get(images, **kwargs) + + from deeptrack.scatterers import ScatteredVolume + + for idx, (label, image) in enumerate(zip(list_of_labels, images)): + list_of_labels[idx] = \ + ScatteredVolume(array=label, properties=image.properties.copy()) + + # Create an empty output image. + output_region = kwargs["output_region"] + output = xp.zeros( + ( + output_region[2] - output_region[0], + output_region[3] - output_region[1], + kwargs["number_of_masks"], + ), + dtype=list_of_labels[0].array.dtype, + ) + + from deeptrack.optics import _get_position + + # Merge masks into the output. + for volume in list_of_labels: + label = volume.array + position = _get_position(volume) + + p0 = xp.round(position - xp.asarray(output_region[0:2])) + p0 = p0.astype(xp.int64) + + + if xp.any(p0 > xp.asarray(output.shape[:2])) or \ + xp.any(p0 + xp.asarray(label.shape[:2]) < 0): + continue + + crop_x = (-xp.minimum(p0[0], 0)).item() + crop_y = (-xp.minimum(p0[1], 0)).item() + + crop_x_end = int( + label.shape[0] + - np.max([p0[0] + label.shape[0] - output.shape[0], 0]) + ) + crop_y_end = int( + label.shape[1] + - np.max([p0[1] + label.shape[1] - output.shape[1], 0]) + ) + + labelarg = label[crop_x:crop_x_end, crop_y:crop_y_end, :] + + p0[0] = np.max([p0[0], 0]) + p0[1] = np.max([p0[1], 0]) + + p0 = p0.astype(int) + + output_slice = output[ + p0[0] : p0[0] + labelarg.shape[0], + p0[1] : p0[1] + labelarg.shape[1], + ] + + for label_index in range(kwargs["number_of_masks"]): + + if isinstance(kwargs["merge_method"], list): + merge = kwargs["merge_method"][label_index] + else: + merge = kwargs["merge_method"] + + if merge == "add": + output[ + p0[0] : p0[0] + labelarg.shape[0], + p0[1] : p0[1] + labelarg.shape[1], + label_index, + ] += labelarg[..., label_index] + + elif merge == "overwrite": + output_slice[ + labelarg[..., label_index] != 0, label_index + ] = labelarg[labelarg[..., label_index] != 0, \ + label_index] + output[ + p0[0] : p0[0] + labelarg.shape[0], + p0[1] : p0[1] + labelarg.shape[1], + label_index, + ] = output_slice[..., label_index] + + elif merge == "or": + output[ + p0[0] : p0[0] + labelarg.shape[0], + p0[1] : p0[1] + labelarg.shape[1], + label_index, + ] = xp.logical_or( + output_slice[..., label_index] != 0, + labelarg[..., label_index] != 0 + ) + + elif merge == "mul": + output[ + p0[0] : p0[0] + labelarg.shape[0], + p0[1] : p0[1] + labelarg.shape[1], + label_index, + ] *= labelarg[..., label_index] + + else: + # No match, assume function + output[ + p0[0] : p0[0] + labelarg.shape[0], + p0[1] : p0[1] + labelarg.shape[1], + label_index, + ] = merge( + output_slice[..., label_index], + 
labelarg[..., label_index], + ) + + return output + + #TODO ***??*** revise _get_position - torch, typing, docstring, unit test def _get_position( scatterer: ScatteredObject, From e006806829692d1823b41fd9248dd7413963fb4f Mon Sep 17 00:00:00 2001 From: Carlo Date: Thu, 15 Jan 2026 18:13:58 +0100 Subject: [PATCH 22/24] u --- deeptrack/features.py | 5 +- deeptrack/math.py | 1375 ++++++++++++++--------------------------- deeptrack/optics.py | 20 +- 3 files changed, 483 insertions(+), 917 deletions(-) diff --git a/deeptrack/features.py b/deeptrack/features.py index 901664a9..702e7362 100644 --- a/deeptrack/features.py +++ b/deeptrack/features.py @@ -1,6 +1,6 @@ """Core features for building and processing pipelines in DeepTrack2. -The `feasture.py` module defines the core classes and utilities used to create +The `feature.py` module defines the core classes and utilities used to create and manipulate features in DeepTrack2, enabling users to build sophisticated data processing pipelines with modular, reusable, and composable components. @@ -80,11 +80,8 @@ - `OneOf`: Resolve one feature from a given collection. - `OneOfDict`: Resolve one feature from a dictionary and apply it to an input. - `LoadImage`: Load an image from disk and preprocess it. -- `SampleToMasks`: Create a mask from a list of images. - `AsType`: Convert the data type of the input. - `ChannelFirst2d`: DEPRECATED Convert an image to a channel-first format. -- `Upscale`: Simulate a pipeline at a higher resolution. -- `NonOverlapping`: Ensure volumes are placed non-overlapping in a 3D space. - `Store`: Store the output of a feature for reuse. - `Squeeze`: Squeeze the input to the smallest possible dimension. - `Unsqueeze`: Unsqueeze the input. diff --git a/deeptrack/math.py b/deeptrack/math.py index 0af95c05..5845fb90 100644 --- a/deeptrack/math.py +++ b/deeptrack/math.py @@ -97,15 +97,15 @@ import array_api_compat as apc import numpy as np -from numpy.typing import NDArray +from numpy.typing import NDArray #TODO TBE from scipy import ndimage import skimage import skimage.measure from deeptrack import utils, OPENCV_AVAILABLE, TORCH_AVAILABLE from deeptrack.features import Feature -from deeptrack.image import Image, strip -from deeptrack.types import ArrayLike, PropertyLike +from deeptrack.image import Image, strip #TODO TBE +from deeptrack.types import PropertyLike from deeptrack.backend import xp if TORCH_AVAILABLE: @@ -130,11 +130,6 @@ "MaxPooling", "MinPooling", "MedianPooling", - "PoolV2", - "AveragePoolingV2", - "MaxPoolingV2", - "MinPoolingV2", - "MedianPoolingV2", "BlurCV2", "BilateralBlur", ] @@ -232,10 +227,10 @@ def __init__( def get( self: Average, - images: list[NDArray[Any] | torch.Tensor | Image], + images: list[np.ndarray | torch.Tensor], axis: int | tuple[int], **kwargs: Any, - ) -> NDArray[Any] | torch.Tensor | Image: + ) -> np.ndarray | torch.Tensor: """Compute the average of input images along the specified axis(es). This method computes the average of the input images along the @@ -323,11 +318,11 @@ def __init__( def get( self: Clip, - image: NDArray[Any] | torch.Tensor | Image, + image: np.ndarray | torch.Tensor, min: float, max: float, **kwargs: Any, - ) -> NDArray[Any] | torch.Tensor | Image: + ) -> np.ndarray | torch.Tensor: """Clips the input image within the specified values. This method clips the input image within the specified minimum and @@ -368,8 +363,7 @@ class NormalizeMinMax(Feature): max: float, optional Upper bound of the transformation. It defaults to 1. 
featurewise: bool, optional - Whether to normalize each feature independently. It default to `True`, - which is the only behavior currently implemented. + Whether to normalize each feature independently. It default to `True`. Methods ------- @@ -395,8 +389,6 @@ class NormalizeMinMax(Feature): """ - #TODO ___??___ Implement the `featurewise=False` option - def __init__( self: NormalizeMinMax, min: PropertyLike[float] = 0, @@ -423,32 +415,47 @@ def __init__( def get( self: NormalizeMinMax, - image: ArrayLike, + image: np.ndarray | torch.Tensor, min: float, max: float, + featurewise: bool = True, **kwargs: Any, - ) -> ArrayLike: + ) -> np.ndarray | torch.Tensor: """Normalize the input to fall between `min` and `max`. Parameters ---------- - image: array + image: np.ndarray or torch.Tensor Input image to normalize. min: float Lower bound of the output range. max: float Upper bound of the output range. + featurewise: bool + Whether to normalize each feature (channel) independently. Returns ------- - array + np.ndarray or torch.Tensor Min-max normalized image. """ - ptp = xp.max(image) - xp.min(image) - image = image / ptp * (max - min) - image = image - xp.min(image) + min + if featurewise: + # Normalize per feature (last axis) + axis = tuple(range(image.ndim - 1)) + + img_min = xp.min(image, axis=axis, keepdims=True) + img_max = xp.max(image, axis=axis, keepdims=True) + else: + # Normalize globally + img_min = xp.min(image) + img_max = xp.max(image) + + ptp = img_max - img_min + + # Avoid division by zero + image = (image - img_min) / ptp * (max - min) + min try: image[xp.isnan(image)] = 0 @@ -492,8 +499,6 @@ class NormalizeStandard(Feature): """ - #TODO ___??___ Implement the `featurewise=False` option - def __init__( self: NormalizeStandard, featurewise: PropertyLike[bool] = True, @@ -516,33 +521,52 @@ def __init__( def get( self: NormalizeStandard, - image: NDArray[Any] | torch.Tensor | Image, + image: np.ndarray | torch.Tensor, + featurewise: bool, **kwargs: Any, - ) -> NDArray[Any] | torch.Tensor | Image: + ) -> np.ndarray | torch.Tensor: """Normalizes the input image to have mean 0 and standard deviation 1. - This method normalizes the input image to have mean 0 and standard - deviation 1. - Parameters ---------- - image: array + image: np.ndarray or torch.Tensor The input image to normalize. + featurewise: bool + Whether to normalize each feature (channel) independently. Returns ------- - array - The normalized image. - + np.ndarray or torch.Tensor + The standardized image. 
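+
+        Notes
+        -----
+        Quick sanity check (a sketch, not recorded library output): for
+        `image = np.array([[1., 3.], [5., 7.]])` normalized globally
+        (`featurewise=False`), the mean is 4 and the population standard
+        deviation is sqrt(5), so the result is `(image - 4) / sqrt(5)`.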
""" - if apc.is_torch_array(image): - # By default, torch.std() is unbiased, i.e., divides by N-1 - return ( - (image - torch.mean(image)) / torch.std(image, unbiased=False) - ) + if featurewise: + # Normalize per feature (last axis) + axis = tuple(range(image.ndim - 1)) + + mean = xp.mean(image, axis=axis, keepdims=True) + + if apc.is_torch_array(image): + std = torch.std(image, dim=axis, keepdim=True, unbiased=False) + else: + std = xp.std(image, axis=axis) + else: + # Normalize globally + mean = xp.mean(image) - return (image - xp.mean(image)) / xp.std(image) + if apc.is_torch_array(image): + std = torch.std(image, unbiased=False) + else: + std = xp.std(image) + + image = (image - mean) / std + + try: + image[xp.isnan(image)] = 0 + except TypeError: + pass + + return image class NormalizeQuantile(Feature): @@ -614,158 +638,164 @@ def __init__( def get( self: NormalizeQuantile, - image: NDArray[Any] | torch.Tensor | Image, - quantiles: tuple[float, float] = None, + image: np.ndarray | torch.Tensor, + quantiles: tuple[float, float], + featurewise: bool, **kwargs: Any, - ) -> NDArray[Any] | torch.Tensor | Image: + ) -> np.ndarray | torch.Tensor: """Normalize the input image based on the specified quantiles. - This method normalizes the input image based on the specified - quantiles. - Parameters ---------- - image: array + image: np.ndarray or torch.Tensor The input image to normalize. quantiles: tuple[float, float] Quantile range to calculate scaling factor. + featurewise: bool + Whether to normalize each feature (channel) independently. Returns ------- - array - The normalized image. - + np.ndarray or torch.Tensor + The quantile-normalized image. """ - if apc.is_torch_array(image): - q_tensor = torch.tensor( - [*quantiles, 0.5], - device=image.device, - dtype=image.dtype, - ) - q_low, q_high, median = torch.quantile( - image, q_tensor, dim=None, keepdim=False, - ) - else: # NumPy - q_low, q_high, median = xp.quantile(image, (*quantiles, 0.5)) + q_low_val, q_high_val = quantiles - return (image - median) / (q_high - q_low) * 2.0 + if featurewise: + # Per-feature normalization (last axis) + axis = tuple(range(image.ndim - 1)) + if apc.is_torch_array(image): + q = torch.tensor( + [q_low_val, q_high_val, 0.5], + device=image.device, + dtype=image.dtype, + ) + q_low, q_high, median = torch.quantile( + image, q, dim=axis, keepdim=True + ) + else: + q_low, q_high, median = xp.quantile( + image, (q_low_val, q_high_val, 0.5), + axis=axis, + keepdims=True, + ) + else: + # Global normalization + if apc.is_torch_array(image): + q = torch.tensor( + [q_low_val, q_high_val, 0.5], + device=image.device, + dtype=image.dtype, + ) + q_low, q_high, median = torch.quantile( + image, q, dim=None, keepdim=False + ) + else: + q_low, q_high, median = xp.quantile( + image, (q_low_val, q_high_val, 0.5) + ) -#TODO ***JH*** revise Blur - torch, typing, docstring, unit test -class Blur(Feature): - """Apply a blurring filter to an image. - - This class applies a blurring filter to an image. The filter function - must be a function that takes an input image and returns a blurred - image. - - Parameters - ---------- - filter_function: Callable - The blurring function to apply. This function must accept the input - image as a keyword argument named `input`. If using OpenCV functions - (e.g., `cv2.GaussianBlur`), use `BlurCV2` instead. - mode: str - Border mode for handling boundaries (e.g., 'reflect'). 
- - Methods - ------- - `get(image: np.ndarray | Image, **kwargs: Any) --> np.ndarray` - Applies the blurring filter to the input image. + image = (image - median) / (q_high - q_low) * 2.0 - Examples - -------- - >>> import deeptrack as dt - >>> import numpy as np - >>> from scipy.ndimage import convolve + try: + image[xp.isnan(image)] = 0 + except TypeError: + pass - Create an input image: - >>> input_image = np.random.rand(32, 32) + return image - Define a Gaussian kernel for blurring: - >>> gaussian_kernel = np.array([ - ... [1, 4, 6, 4, 1], - ... [4, 16, 24, 16, 4], - ... [6, 24, 36, 24, 6], - ... [4, 16, 24, 16, 4], - ... [1, 4, 6, 4, 1] - ... ], dtype=float) - >>> gaussian_kernel /= np.sum(gaussian_kernel) - Define a blur function using the Gaussian kernel: - >>> def gaussian_blur(input, **kwargs): - ... return convolve(input, gaussian_kernel, mode='reflect') +#TODO ***CM*** revise typing, docstring, unit test +class Blur(Feature): + """Apply a blurring filter to an image. - Define a blur feature using the Gaussian blur function: - >>> blur = dt.Blur(filter_function=gaussian_blur) - >>> output_image = blur(input_image) - >>> print(output_image.shape) - (32, 32) + This class acts as a backend-dispatching blur operator. Subclasses must + implement backend-specific logic via `_get_numpy` and optionally + `_get_torch`. Notes ----- - Calling this feature returns a `np.ndarray` by default. If - `store_properties` is set to `True`, the returned array will be - automatically wrapped in an `Image` object. This behavior is handled - internally and does not affect the return type of the `get()` method. - The filter_function must accept the input image as a keyword argument named - input. This is required because it is called via utils.safe_call. If you - are using functions that do not support input=... (such as OpenCV filters - like cv2.GaussianBlur), consider using BlurCV2 instead. + - NumPy execution is always supported. + - Torch execution is only supported if `_get_torch` is implemented. + - Generic `filter_function`-based blurs are NumPy-only by design. """ def __init__( - self: Blur, - filter_function: Callable, + self, + filter_function: Callable | None = None, mode: PropertyLike[str] = "reflect", **kwargs: Any, ): - """Initialize the parameters for blurring input features. - - This constructor initializes the parameters for blurring input - features. + """Initialize the blur feature. Parameters ---------- - filter_function: Callable - The blurring function to apply. - mode: str - Border mode for handling boundaries (e.g., 'reflect'). - **kwargs: Any - Additional keyword arguments. - + filter_function : Callable or None + NumPy-based blurring function. Must accept the input image as a + keyword argument named `input`. If `None`, the subclass must + implement `_get_numpy`. + mode : str + Border mode for NumPy-based filters. + **kwargs : Any + Additional keyword arguments passed to Feature. """ - self.filter = filter_function - super().__init__(borderType=mode, **kwargs) + self.mode = mode + super().__init__(**kwargs) - def get(self: Blur, image: np.ndarray | Image, **kwargs: Any) -> np.ndarray: - """Applies the blurring filter to the input image. + def __call__( + self, + image: np.ndarray | torch.Tensor, + **kwargs: Any, + ) -> np.ndarray | torch.Tensor: + if isinstance(image, np.ndarray): + return self._get_numpy(image, **kwargs) - This method applies the blurring filter to the input image. 
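+        # Torch tensors fall through to the branch below, which requires the
+        # subclass to provide a _get_torch implementation; the base class
+        # raises TypeError for them.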
+ if TORCH_AVAILABLE and isinstance(image, torch.Tensor): + return self._get_torch(image, **kwargs) - Parameters - ---------- - image: np.ndarray - The input image to blur. - **kwargs: dict[str, Any] - Additional keyword arguments. + raise TypeError( + "Blur only supports numpy.ndarray or torch.Tensor inputs." + ) - Returns - ------- - np.ndarray - The blurred image. + def _get_numpy( + self, + image: np.ndarray, + **kwargs: Any, + ) -> np.ndarray: + if self.filter is None: + raise NotImplementedError( + f"{self.__class__.__name__} does not implement a NumPy backend." + ) - """ + # Avoid passing conflicting keywords + kwargs = dict(kwargs) + kwargs.pop("input", None) + + return utils.safe_call( + self.filter, + input=image, + mode=self.mode, + **kwargs, + ) + + def _get_torch( + self, + image: torch.Tensor, + **kwargs: Any, + ) -> torch.Tensor: + raise TypeError( + f"{self.__class__.__name__} does not support torch.Tensor inputs. " + "Use a Torch-enabled blur (e.g. AverageBlur or a V2 blur class)." + ) - kwargs.pop("input", False) - return utils.safe_call(self.filter, input=image, **kwargs) -#TODO ***JH*** revise AverageBlur - torch, typing, docstring, unit test +#TODO ***CM*** revise AverageBlur - torch, typing, docstring, unit test class AverageBlur(Blur): """Blur an image by computing simple means over neighbourhoods. @@ -779,7 +809,7 @@ class AverageBlur(Blur): Methods ------- - `get(image: np.ndarray | Image, ksize: int, **kwargs: Any) --> np.ndarray` + `get(image: np.ndarray | torch.Tensor, ksize: int, **kwargs: Any) --> np.ndarray | torch.Tensor` Applies the average blurring filter to the input image. Examples @@ -796,13 +826,6 @@ class AverageBlur(Blur): >>> print(output_image.shape) (32, 32) - Notes - ----- - Calling this feature returns a `np.ndarray` by default. If - `store_properties` is set to `True`, the returned array will be - automatically wrapped in an `Image` object. This behavior is handled - internally and does not affect the return type of the `get()` method. 
- """ def __init__( @@ -833,800 +856,285 @@ def _kernel_shape(self, shape: tuple[int, ...], ksize: int) -> tuple[int, ...]: def _get_numpy( self, input: np.ndarray, ksize: tuple[int, ...], **kwargs: Any - ) -> np.ndarray: - return ndimage.uniform_filter( - input, - size=ksize, - mode=kwargs.get("mode", "reflect"), - cval=kwargs.get("cval", 0), - origin=kwargs.get("origin", 0), - axes=tuple(range(0, len(ksize))), - ) - - def _get_torch( - self, input: torch.Tensor, ksize: tuple[int, ...], **kwargs: Any - ) -> np.ndarray: - F = xp.nn.functional - - last_dim_is_channel = len(ksize) < input.ndim - if last_dim_is_channel: - # permute to first dim - input = input.movedim(-1, 0) - else: - input = input.unsqueeze(0) - - # add batch dimension - input = input.unsqueeze(0) - - # pad input - input = F.pad( - input, - (ksize[0] // 2, ksize[0] // 2, ksize[1] // 2, ksize[1] // 2), - mode=kwargs.get("mode", "reflect"), - value=kwargs.get("cval", 0), - ) - if input.ndim == 3: - x = F.avg_pool1d( - input, - kernel_size=ksize, - stride=1, - padding=0, - ceil_mode=False, - count_include_pad=False, - ) - elif input.ndim == 4: - x = F.avg_pool2d( - input, - kernel_size=ksize, - stride=1, - padding=0, - ceil_mode=False, - count_include_pad=False, - ) - elif input.ndim == 5: - x = F.avg_pool3d( - input, - kernel_size=ksize, - stride=1, - padding=0, - ceil_mode=False, - count_include_pad=False, - ) - else: - raise NotImplementedError( - f"Input dimension {input.ndim - 2} not supported for torch backend" - ) - - # restore layout - x = x.squeeze(0) - if last_dim_is_channel: - x = x.movedim(0, -1) - else: - x = x.squeeze(0) - - return x - - def get( - self: AverageBlur, - input: ArrayLike, - ksize: int, - **kwargs: Any, - ) -> np.ndarray: - """Applies the average blurring filter to the input image. - - This method applies the average blurring filter to the input image. - - Parameters - ---------- - input: np.ndarray - The input image to blur. - ksize: int - Kernel size for the pooling operation. - **kwargs: dict[str, Any] - Additional keyword arguments. - - Returns - ------- - np.ndarray - The blurred image. - - """ - - k = self._kernel_shape(input.shape, ksize) - - if self.backend == "numpy": - return self._get_numpy(input, k, **kwargs) - elif self.backend == "torch": - return self._get_torch(input, k, **kwargs) - else: - raise NotImplementedError(f"Backend {self.backend} not supported") - - -#TODO ***JH*** revise GaussianBlur - torch, typing, docstring, unit test -class GaussianBlur(Blur): - """Applies a Gaussian blur to images using Gaussian kernels. - - This class blurs images by convolving them with a Gaussian filter, which - smooths the image and reduces high-frequency details. The level of blurring - is controlled by the standard deviation (`sigma`) of the Gaussian kernel. - - Parameters - ---------- - sigma: float - Standard deviation of the Gaussian kernel. 
- - Examples - -------- - >>> import deeptrack as dt - >>> import numpy as np - >>> import matplotlib.pyplot as plt - - Create an input image: - >>> input_image = np.random.rand(32, 32) - - Define a Gaussian blur feature: - >>> gaussian_blur = dt.GaussianBlur(sigma=2) - >>> output_image = gaussian_blur(input_image) - >>> print(output_image.shape) - (32, 32) - - Visualize the input and output images: - >>> plt.figure(figsize=(8, 4)) - >>> plt.subplot(1, 2, 1) - >>> plt.imshow(input_image, cmap='gray') - >>> plt.subplot(1, 2, 2) - >>> plt.imshow(output_image, cmap='gray') - >>> plt.show() - - Notes - ----- - Calling this feature returns a `np.ndarray` by default. If - `store_properties` is set to `True`, the returned array will be - automatically wrapped in an `Image` object. This behavior is handled - internally and does not affect the return type of the `get()` method. - - """ - - def __init__(self: GaussianBlur, sigma: PropertyLike[float] = 2, **kwargs: Any): - """Initialize the parameters for Gaussian blurring. - - This constructor initializes the parameters for Gaussian blurring. - - Parameters - ---------- - sigma: float - Standard deviation of the Gaussian kernel. - **kwargs: Any - Additional keyword arguments. - - """ - - super().__init__(ndimage.gaussian_filter, sigma=sigma, **kwargs) - - -#TODO ***JH*** revise MedianBlur - torch, typing, docstring, unit test -class MedianBlur(Blur): - """Applies a median blur. - - This class replaces each pixel of the input image with the median value of - its neighborhood. The `ksize` parameter determines the size of the - neighborhood used to calculate the median filter. The median filter is - useful for reducing noise while preserving edges. It is particularly - effective for removing salt-and-pepper noise from images. - - Parameters - ---------- - ksize: int - Kernel size. - **kwargs: dict - Additional parameters sent to the blurring function. - - Examples - -------- - >>> import deeptrack as dt - >>> import numpy as np - >>> import matplotlib.pyplot as plt - - Create an input image: - >>> input_image = np.random.rand(32, 32) - - Define a median blur feature: - >>> median_blur = dt.MedianBlur(ksize=3) - >>> output_image = median_blur(input_image) - >>> print(output_image.shape) - (32, 32) - - Visualize the input and output images: - >>> plt.figure(figsize=(8, 4)) - >>> plt.subplot(1, 2, 1) - >>> plt.imshow(input_image, cmap='gray') - >>> plt.subplot(1, 2, 2) - >>> plt.imshow(output_image, cmap='gray') - >>> plt.show() - - Notes - ----- - Calling this feature returns a `np.ndarray` by default. If - `store_properties` is set to `True`, the returned array will be - automatically wrapped in an `Image` object. This behavior is handled - internally and does not affect the return type of the `get()` method. - - """ - - def __init__( - self: MedianBlur, - ksize: PropertyLike[int] = 3, - **kwargs: Any, - ): - """Initialize the parameters for median blurring. - - This constructor initializes the parameters for median blurring. - - Parameters - ---------- - ksize: int - Kernel size. - **kwargs: Any - Additional keyword arguments. - - """ - - super().__init__(ndimage.median_filter, size=ksize, **kwargs) - - -#TODO ***AL*** revise Pool - torch, typing, docstring, unit test -class Pool(Feature): - """Downsamples the image by applying a function to local regions of the - image. - - This class reduces the resolution of an image by dividing it into - non-overlapping blocks of size `ksize` and applying the specified pooling - function to each block. 
The result is a downsampled image where each pixel - value represents the result of the pooling function applied to the - corresponding block. - - Parameters - ---------- - pooling_function: function - A function that is applied to each local region of the image. - DOES NOT NEED TO BE WRAPPED IN ANOTHER FUNCTION. - The `pooling_function` must accept the input image as a keyword argument - named `input`, as it is called via `utils.safe_call`. - Examples include `np.mean`, `np.max`, `np.min`, etc. - ksize: int - Size of the pooling kernel. - **kwargs: Any - Additional parameters sent to the pooling function. - - Methods - ------- - `get(image: np.ndarray | Image, ksize: int, **kwargs: Any) --> np.ndarray` - Applies the pooling function to the input image. - - Examples - -------- - >>> import deeptrack as dt - >>> import numpy as np - - Create an input image: - >>> input_image = np.random.rand(32, 32) - - Define a pooling feature: - >>> pooling_feature = dt.Pool(pooling_function=np.mean, ksize=4) - >>> output_image = pooling_feature.get(input_image, ksize=4) - >>> print(output_image.shape) - (8, 8) - - Notes - ----- - Calling this feature returns a `np.ndarray` by default. If - `store_properties` is set to `True`, the returned array will be - automatically wrapped in an `Image` object. This behavior is handled - internally and does not affect the return type of the `get()` method. - The filter_function must accept the input image as a keyword argument named - input. This is required because it is called via utils.safe_call. If you - are using functions that do not support input=... (such as OpenCV filters - like cv2.GaussianBlur), consider using BlurCV2 instead. - - """ - - def __init__( - self: Pool, - pooling_function: Callable, - ksize: PropertyLike[int] = 3, - **kwargs: Any, - ): - """Initialize the parameters for pooling input features. - - This constructor initializes the parameters for pooling input - features. - - Parameters - ---------- - pooling_function: Callable - The pooling function to apply. - ksize: int - Size of the pooling kernel. - **kwargs: Any - Additional keyword arguments. - - """ - - self.pooling = pooling_function - super().__init__(ksize=ksize, **kwargs) - - def get( - self: Pool, - image: np.ndarray | Image, - ksize: int, - **kwargs: Any, - ) -> np.ndarray: - """Applies the pooling function to the input image. - - This method applies the pooling function to the input image. - - Parameters - ---------- - image: np.ndarray - The input image to pool. - ksize: int - Size of the pooling kernel. - **kwargs: dict[str, Any] - Additional keyword arguments. - - Returns - ------- - np.ndarray - The pooled image. - - """ - - kwargs.pop("func", False) - kwargs.pop("image", False) - kwargs.pop("block_size", False) - return utils.safe_call( - skimage.measure.block_reduce, - image=image, - func=self.pooling, - block_size=ksize, - **kwargs, - ) - - -#TODO ***AL*** revise AveragePooling - torch, typing, docstring, unit test -class AveragePooling(Pool): - """Apply average pooling to an image. - - This class reduces the resolution of an image by dividing it into - non-overlapping blocks of size `ksize` and applying the average function to - each block. The result is a downsampled image where each pixel value - represents the average value within the corresponding block of the - original image. - - Parameters - ---------- - ksize: int - Size of the pooling kernel. - **kwargs: dict - Additional parameters sent to the pooling function. 
- - Examples - -------- - >>> import deeptrack as dt - >>> import numpy as np - - Create an input image: - >>> input_image = np.random.rand(32, 32) - - Define an average pooling feature: - >>> average_pooling = dt.AveragePooling(ksize=4) - >>> output_image = average_pooling(input_image) - >>> print(output_image.shape) - (8, 8) - - Notes - ----- - Calling this feature returns a `np.ndarray` by default. If - `store_properties` is set to `True`, the returned array will be - automatically wrapped in an `Image` object. This behavior is handled - internally and does not affect the return type of the `get()` method. - - """ - - def __init__( - self: Pool, - ksize: PropertyLike[int] = 3, - **kwargs: Any, - ): - """Initialize the parameters for average pooling. - - This constructor initializes the parameters for average pooling. - - Parameters - ---------- - ksize: int - Size of the pooling kernel. - **kwargs: Any - Additional keyword arguments. - - """ - - super().__init__(np.mean, ksize=ksize, **kwargs) - - -class MaxPooling(Pool): - """Apply max-pooling to images. - - `MaxPooling` reduces the resolution of an image by dividing it into - non-overlapping blocks of size `ksize` and applying the `max` function - to each block. The result is a downsampled image where each pixel value - represents the maximum value within the corresponding block of the - original image. This is useful for reducing the size of an image while - retaining the most significant features. - - If the backend is NumPy, the downsampling is performed using - `skimage.measure.block_reduce`. - - If the backend is PyTorch, the downsampling is performed using - `torch.nn.functional.max_pool2d`. - - Parameters - ---------- - ksize: int - Size of the pooling kernel. - **kwargs: Any - Additional parameters sent to the pooling function. - - Examples - -------- - >>> import deeptrack as dt - - Create an input image: - >>> import numpy as np - >>> - >>> input_image = np.random.rand(32, 32) - - Define and use a max-pooling feature: - - >>> max_pooling = dt.MaxPooling(ksize=8) - >>> output_image = max_pooling(input_image) - >>> output_image.shape - (4, 4) - - """ - - def __init__( - self: MaxPooling, - ksize: PropertyLike[int] = 3, - **kwargs: Any, - ): - """Initialize the parameters for max-pooling. - - This constructor initializes the parameters for max-pooling. - - Parameters - ---------- - ksize: int - Size of the pooling kernel. - **kwargs: Any - Additional keyword arguments. - - """ - - super().__init__(np.max, ksize=ksize, **kwargs) - - def get( - self: MaxPooling, - image: NDArray[Any] | torch.Tensor, - ksize: int=3, - **kwargs: Any, - ) -> NDArray[Any] | torch.Tensor: - """Max-pooling of input. - - Checks the current backend and chooses the appropriate function to pool - the input image, either `._get_torch()` or `._get_numpy()`. - - Parameters - ---------- - image: array or tensor - Input array or tensor be pooled. - ksize: int - Kernel size of the pooling operation. - - Returns - ------- - array or tensor - The pooled input as `NDArray` or `torch.Tensor` depending on - the backend. 
- - """ + ) -> np.ndarray: + return ndimage.uniform_filter( + input, + size=ksize, + mode=kwargs.get("mode", "reflect"), + cval=kwargs.get("cval", 0), + origin=kwargs.get("origin", 0), + axes=tuple(range(0, len(ksize))), + ) - if self.get_backend() == "numpy": - return self._get_numpy(image, ksize, **kwargs) + def _get_torch( + self, input: torch.Tensor, ksize: tuple[int, ...], **kwargs: Any + ) -> torch.Tensor: - if self.get_backend() == "torch": - return self._get_torch(image, ksize, **kwargs) + last_dim_is_channel = len(ksize) < input.ndim + if last_dim_is_channel: + input = input.movedim(-1, 0) + else: + input = input.unsqueeze(0) - raise NotImplementedError(f"Backend {self.backend} not supported") + # add batch dimension + input = input.unsqueeze(0) - def _get_numpy( - self: MaxPooling, - image: NDArray[Any], - ksize: int=3, - **kwargs: Any, - ) -> NDArray[Any]: - """Max-pooling pooling with the NumPy backend enabled. + # dynamic padding + pad = [] + for k in reversed(ksize): + p = k // 2 + pad.extend([p, p]) + pad = tuple(pad) - Returns the result of the input array passed to the scikit image - `block_reduce()` function with `np.max()` as the pooling function. + input = F.pad( + input, + pad, + mode=kwargs.get("mode", "reflect"), + value=kwargs.get("cval", 0), + ) - Parameters - ---------- - image: array - Input array to be pooled. - ksize: int - Kernel size of the pooling operation. + if input.ndim == 3: + x = F.avg_pool1d(input, kernel_size=ksize, stride=1) + elif input.ndim == 4: + x = F.avg_pool2d(input, kernel_size=ksize, stride=1) + elif input.ndim == 5: + x = F.avg_pool3d(input, kernel_size=ksize, stride=1) + else: + raise NotImplementedError( + f"Input dimension {input.ndim - 2} not supported for torch backend" + ) - Returns - ------- - array - The pooled image as a NumPy array. - - """ + # restore layout + x = x.squeeze(0) + if last_dim_is_channel: + x = x.movedim(0, -1) + else: + x = x.squeeze(0) - return utils.safe_call( - skimage.measure.block_reduce, - image=image, - func=np.max, - block_size=ksize, - **kwargs, - ) + return x - def _get_torch( - self: MaxPooling, - image: torch.Tensor, - ksize: int=3, + def get( + self: AverageBlur, + input: np.ndarray | torch.Tensor, + ksize: int, **kwargs: Any, - ) -> torch.Tensor: - """Max-pooling with the PyTorch backend enabled. - + ) -> np.ndarray | torch.Tensor: + """Applies the average blurring filter to the input image. - Returns the result of the tensor passed to a PyTorch max - pooling layer. + This method applies the average blurring filter to the input image. Parameters ---------- - image: torch.Tensor - Input tensor to be pooled. + input: np.ndarray + The input image to blur. ksize: int - Kernel size of the pooling operation. + Kernel size for the pooling operation. + **kwargs: dict[str, Any] + Additional keyword arguments. Returns ------- - torch.Tensor - The pooled image as a `torch.Tensor`. + np.ndarray + The blurred image. """ - # If input tensor is 2D - if len(image.shape) == 2: - # Add batch dimension for max-pooling - expanded_image = image.unsqueeze(0) - - pooled_image = torch.nn.functional.max_pool2d( - expanded_image, kernel_size=ksize, - ) - # Remove the expanded dim - return pooled_image.squeeze(0) - - return torch.nn.functional.max_pool2d( - image, - kernel_size=ksize, - ) - + k = self._kernel_shape(input.shape, ksize) -class MinPooling(Pool): - """Apply min-pooling to images. 
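+        # Dispatch on the active backend: the NumPy path wraps
+        # scipy.ndimage.uniform_filter, while the Torch path emulates it
+        # with a stride-1 average pool over a reflect-padded tensor.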
+ if self.backend == "numpy": + return self._get_numpy(input, k, **kwargs) + elif self.backend == "torch": + return self._get_torch(input, k, **kwargs) + else: + raise NotImplementedError(f"Backend {self.backend} not supported") - `MinPooling` reduces the resolution of an image by dividing it into - non-overlapping blocks of size `ksize` and applying the `min` function to - each block. The result is a downsampled image where each pixel value - represents the minimum value within the corresponding block of the original - image. - If the backend is NumPy, the downsampling is performed using - `skimage.measure.block_reduce`. +#TODO ***CM*** revise typing, docstring, unit test +class GaussianBlur(Blur): + """Applies a Gaussian blur to images using Gaussian kernels. - If the backend is PyTorch, the downsampling is performed using the inverse - of `torch.nn.functional.max_pool2d` by changing the sign of the input. + This class blurs images by convolving them with a Gaussian filter, which + smooths the image and reduces high-frequency details. The level of blurring + is controlled by the standard deviation (`sigma`) of the Gaussian kernel. Parameters ---------- - ksize: int - Size of the pooling kernel. - **kwargs: Any - Additional parameters sent to the pooling function. + sigma: float + Standard deviation of the Gaussian kernel. Examples -------- >>> import deeptrack as dt + >>> import numpy as np + >>> import matplotlib.pyplot as plt Create an input image: - >>> import numpy as np - >>> >>> input_image = np.random.rand(32, 32) - Define and use a min-pooling feature: - >>> min_pooling = dt.MinPooling(ksize=4) - >>> output_image = min_pooling(input_image) - >>> output_image.shape - (8, 8) + Define a Gaussian blur feature: + >>> gaussian_blur = dt.GaussianBlur(sigma=2) + >>> output_image = gaussian_blur(input_image) + >>> print(output_image.shape) + (32, 32) + + Visualize the input and output images: + >>> plt.figure(figsize=(8, 4)) + >>> plt.subplot(1, 2, 1) + >>> plt.imshow(input_image, cmap='gray') + >>> plt.subplot(1, 2, 2) + >>> plt.imshow(output_image, cmap='gray') + >>> plt.show() """ - def __init__( - self: MinPooling, - ksize: PropertyLike[int] = 3, - **kwargs: Any, - ): - """Initialize the parameters for min-pooling. + def __init__(self: GaussianBlur, sigma: PropertyLike[float] = 2, **kwargs: Any): + """Initialize the parameters for Gaussian blurring. - This constructor initializes the parameters for min-pooling and checks - whether to use the NumPy or PyTorch implementation, defaults to NumPy. + This constructor initializes the parameters for Gaussian blurring. Parameters ---------- - ksize: int - Size of the pooling kernel. + sigma: float + Standard deviation of the Gaussian kernel. **kwargs: Any Additional keyword arguments. """ - super().__init__(np.min, ksize=ksize, **kwargs) - - def get( - self: MinPooling, - image: NDArray[Any] | torch.Tensor, - ksize: int=3, - **kwargs: Any, - ) -> NDArray[Any] | torch.Tensor: - """Min pooling of input. - - Checks the current backend and chooses the appropriate function to pool - the input image, either `._get_torch()` or `._get_numpy()`. - - Parameters - ---------- - image: array or tensor - Input array or tensor to be pooled. - ksize: int - Kernel size of the pooling operation. - - Returns - ------- - array or tensor - The pooled image as `NDArray` or `torch.Tensor` depending on the - backend. 
- - """ - - if self.get_backend() == "numpy": - return self._get_numpy(image, ksize, **kwargs) - - if self.get_backend() == "torch": - return self._get_torch(image, ksize, **kwargs) - - raise NotImplementedError(f"Backend {self.backend} not supported") + self.sigma = float(sigma) + super().__init__(None, **kwargs) def _get_numpy( - self: MinPooling, - image: NDArray[Any], - ksize: int=3, + self, + input: np.ndarray, **kwargs: Any, - ) -> NDArray[Any]: - """Min-pooling with the NumPy backend. - - Returns the result of the input array passed to the scikit - `image block_reduce()` function with `np.min()` as the pooling - function. - - Parameters - ---------- - image: NDArray - Input image to be pooled. - ksize: int - Kernel size of the pooling operation. - - Returns - ------- - NDArray - The pooled image as a `NDArray`. - - """ + ) -> np.ndarray: + return ndimage.gaussian_filter( + input, + sigma=self.sigma, + mode=kwargs.get("mode", "reflect"), + cval=kwargs.get("cval", 0), + ) - return utils.safe_call( - skimage.measure.block_reduce, - image=image, - func=np.min, - block_size=ksize, - **kwargs, + def _gaussian_kernel_1d( + self, + sigma: float, + device, + dtype, + ) -> torch.Tensor: + radius = int(np.ceil(3 * sigma)) + x = torch.arange( + -radius, radius + 1, + device=device, + dtype=dtype, ) + kernel = torch.exp(-(x ** 2) / (2 * sigma ** 2)) + kernel /= kernel.sum() + return kernel def _get_torch( - self: MinPooling, - image: torch.Tensor, - ksize: int=3, + self, + input: torch.Tensor, **kwargs: Any, ) -> torch.Tensor: - """Min-pooling with the PyTorch backend. + import torch.nn.functional as F - As PyTorch does not have a min-pooling layer, the equivalent operation - is to first multiply the input tensor with `-1`, then perform - max-pooling, and finally multiply the max pooled tensor with `-1`. + sigma = self.sigma + kernel_1d = self._gaussian_kernel_1d( + sigma, + device=input.device, + dtype=input.dtype, + ) - Parameters - ---------- - image: torch.Tensor - Input tensor to be pooled. - ksize: int - Kernel size of the pooling operation. + last_dim_is_channel = input.ndim >= 3 + if last_dim_is_channel: + input = input.movedim(-1, 0) # C, ... + else: + input = input.unsqueeze(0) # 1, ... - Returns - ------- - torch.Tensor - The pooled image as a `torch.Tensor`. + # add batch dimension + input = input.unsqueeze(0) # 1, C, ... 
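+
+        # Separable convolution: the same 1D Gaussian kernel is applied
+        # along each spatial axis in turn, which is equivalent to a full
+        # N-D Gaussian kernel but much cheaper to compute.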
- """ + spatial_dims = input.ndim - 2 + C = input.shape[1] + + for d in range(spatial_dims): + k = kernel_1d + shape = [1] * spatial_dims + shape[d] = -1 + k = k.view(1, 1, *shape) + k = k.repeat(C, 1, *([1] * spatial_dims)) - # If input tensor is 2D - if len(image.shape) == 2: - # Add batch dimension for min-pooling - expanded_image = image.unsqueeze(0) + pad = [0, 0] * spatial_dims + radius = k.shape[2 + d] // 2 + pad[-(2 * d + 2)] = radius + pad[-(2 * d + 1)] = radius + pad = tuple(pad) - pooled_image = - torch.nn.functional.max_pool2d( - expanded_image * (-1), - kernel_size=ksize, + input = F.pad( + input, + pad, + mode=kwargs.get("mode", "reflect"), ) - # Remove the expanded dim - return pooled_image.squeeze(0) + if spatial_dims == 1: + input = F.conv1d(input, k, groups=C) + elif spatial_dims == 2: + input = F.conv2d(input, k, groups=C) + elif spatial_dims == 3: + input = F.conv3d(input, k, groups=C) + else: + raise NotImplementedError( + f"{spatial_dims}D Gaussian blur not supported" + ) + + # restore layout + input = input.squeeze(0) + if last_dim_is_channel: + input = input.movedim(0, -1) + else: + input = input.squeeze(0) - return -torch.nn.functional.max_pool2d( - image * (-1), - kernel_size=ksize, - ) + return input -#TODO ***AL*** revise MedianPooling - torch, typing, docstring, unit test -class MedianPooling(Pool): - """Apply median pooling to images. +#TODO ***JH*** revise MedianBlur - torch, typing, docstring, unit test +class MedianBlur(Blur): + """Applies a median blur. + + This class replaces each pixel of the input image with the median value of + its neighborhood. The `ksize` parameter determines the size of the + neighborhood used to calculate the median filter. The median filter is + useful for reducing noise while preserving edges. It is particularly + effective for removing salt-and-pepper noise from images. - This class reduces the resolution of an image by dividing it into - non-overlapping blocks of size `ksize` and applying the median function to - each block. The result is a downsampled image where each pixel value - represents the median value within the corresponding block of the - original image. This is useful for reducing the size of an image while - retaining the most significant features. + - NumPy backend: `scipy.ndimage.median_filter` + - Torch backend: explicit unfolding followed by `torch.median` Parameters ---------- ksize: int - Size of the pooling kernel. - **kwargs: Any - Additional parameters sent to the pooling function. + Kernel size. + **kwargs: dict + Additional parameters sent to the blurring function. + + Notes + ----- + Torch median blurring is significantly more expensive than mean or + Gaussian blurring due to explicit tensor unfolding. Examples -------- >>> import deeptrack as dt >>> import numpy as np + >>> import matplotlib.pyplot as plt Create an input image: >>> input_image = np.random.rand(32, 32) - Define a median pooling feature: - >>> median_pooling = dt.MedianPooling(ksize=3) - >>> output_image = median_pooling(input_image) + Define a median blur feature: + >>> median_blur = dt.MedianBlur(ksize=3) + >>> output_image = median_blur(input_image) >>> print(output_image.shape) (32, 32) @@ -1638,37 +1146,99 @@ class MedianPooling(Pool): >>> plt.imshow(output_image, cmap='gray') >>> plt.show() - Notes - ----- - Calling this feature returns a `np.ndarray` by default. If - `store_properties` is set to `True`, the returned array will be - automatically wrapped in an `Image` object. 
This behavior is handled - internally and does not affect the return type of the `get()` method. - """ def __init__( - self: MedianPooling, + self: MedianBlur, ksize: PropertyLike[int] = 3, **kwargs: Any, ): - """Initialize the parameters for median pooling. + self.ksize = int(ksize) + super().__init__(None, **kwargs) + + def _get_numpy( + self, + input: np.ndarray, + **kwargs: Any, + ) -> np.ndarray: + return ndimage.median_filter( + input, + size=self.ksize, + mode=kwargs.get("mode", "reflect"), + cval=kwargs.get("cval", 0), + ) - This constructor initializes the parameters for median pooling. + def _get_torch( + self, + input: torch.Tensor, + **kwargs: Any, + ) -> torch.Tensor: + import torch.nn.functional as F - Parameters - ---------- - ksize: int - Size of the pooling kernel. - **kwargs: Any - Additional keyword arguments. + k = self.ksize + if k % 2 == 0: + raise ValueError("MedianBlur requires an odd kernel size.") - """ + last_dim_is_channel = input.ndim >= 3 + if last_dim_is_channel: + input = input.movedim(-1, 0) # C, ... + else: + input = input.unsqueeze(0) # 1, ... + + # add batch dimension + input = input.unsqueeze(0) # 1, C, ... - super().__init__(np.median, ksize=ksize, **kwargs) + spatial_dims = input.ndim - 2 + pad = k // 2 + # padding + pad_tuple = [] + for _ in range(spatial_dims): + pad_tuple.extend([pad, pad]) + pad_tuple = tuple(reversed(pad_tuple)) -class PoolV2: + input = F.pad( + input, + pad_tuple, + mode=kwargs.get("mode", "reflect"), + ) + + # unfold spatial dimensions + if spatial_dims == 1: + x = input.unfold(2, k, 1) + elif spatial_dims == 2: + x = ( + input + .unfold(2, k, 1) + .unfold(3, k, 1) + ) + elif spatial_dims == 3: + x = ( + input + .unfold(2, k, 1) + .unfold(3, k, 1) + .unfold(4, k, 1) + ) + else: + raise NotImplementedError( + f"{spatial_dims}D median blur not supported" + ) + + # flatten neighborhood and take median + x = x.contiguous().view(*x.shape[:-spatial_dims], -1) + x = x.median(dim=-1).values + + # restore layout + x = x.squeeze(0) + if last_dim_is_channel: + x = x.movedim(0, -1) + else: + x = x.squeeze(0) + + return x + +#TODO ***CM*** revise typing, docstring, unit test +class Pool: """ DeepTrack v2 replacement for Pool. @@ -1850,31 +1420,31 @@ def __call__(self, array): return self._pool_torch(array) raise TypeError( - "PoolV2 only supports np.ndarray or torch.Tensor inputs." + "Pool only supports np.ndarray or torch.Tensor inputs." ) -class AveragePoolingV2(PoolV2): +class AveragePooling(Pool): def __init__(self, ksize: int = 2): super().__init__(np.mean, ksize) -class SumPoolingV2(PoolV2): +class SumPooling(Pool): def __init__(self, ksize: int = 2): super().__init__(np.sum, ksize) -class MinPoolingV2(PoolV2): +class MinPooling(Pool): def __init__(self, ksize: int = 2): super().__init__(np.min, ksize) -class MaxPoolingV2(PoolV2): +class MaxPooling(Pool): def __init__(self, ksize: int = 2): super().__init__(np.max, ksize) -class MedianPoolingV2(PoolV2): +class MedianPooling(Pool): def __init__(self, ksize: int = 2): super().__init__(np.median, ksize) @@ -1955,10 +1525,10 @@ def __init__( def get( self: Resize, - image: NDArray | torch.Tensor, + image: np.ndarray | torch.Tensor, dsize: tuple[int, int], **kwargs: Any, - ) -> NDArray | torch.Tensor: + ) -> np.ndarray | torch.Tensor: """Resize the input image to the specified size. 
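
        Dispatches to OpenCV (`cv2.resize`) for NumPy arrays and to
        `torch.nn.functional.interpolate` for PyTorch tensors.
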
Parameters @@ -1991,9 +1561,6 @@ def get( """ - if self._wrap_array_with_image: - image = strip(image) - if apc.is_torch_array(image): original_shape = image.shape @@ -2080,13 +1647,6 @@ class BlurCV2(Feature): >>> print(output_image.shape) (32, 32) - Notes - ----- - Calling this feature returns a `np.ndarray` by default. If - `store_properties` is set to `True`, the returned array will be - automatically wrapped in an `Image` object. This behavior is handled - internally and does not affect the return type of the `get()` method. - """ def __new__( @@ -2230,13 +1790,6 @@ class BilateralBlur(BlurCV2): >>> print(output_image.shape) (32, 32) - Notes - ----- - Calling this feature returns a `np.ndarray` by default. If - `store_properties` is set to `True`, the returned array will be - automatically wrapped in an `Image` object. This behavior is handled - internally and does not affect the return type of the `get()` method. - """ def __init__( diff --git a/deeptrack/optics.py b/deeptrack/optics.py index bb3cf40f..eeb2fca1 100644 --- a/deeptrack/optics.py +++ b/deeptrack/optics.py @@ -151,7 +151,7 @@ def _pad_volume( get_active_scale, get_active_voxel_size, ) -from deeptrack.math import AveragePoolingV2, SumPoolingV2 +from deeptrack.math import AveragePooling, SumPooling from deeptrack.features import propagate_data_to_dependencies from deeptrack.features import DummyFeature, Feature, StructuralFeature from deeptrack.image import pad_image_to_fft @@ -1840,6 +1840,22 @@ def extract_contrast_volume( return (value ** 2) * scattered.array + def downscale_image(self, image: np.ndarray, upscale): + """Detector downscaling (energy conserving)""" + if not np.any(np.array(upscale) != 1): + return image + + ux, uy = upscale[:2] + if ux != uy: + raise ValueError( + f"Energy-conserving detector integration requires ux == uy, " + f"got ux={ux}, uy={uy}." + ) + if isinstance(ux, float) and ux.is_integer(): + ux = int(ux) + + # Energy-conserving detector integration + return SumPoolingV2(ux)(image) #Retrieve get as super def get( @@ -3007,7 +3023,7 @@ def _process_and_get( ) return output - + #TODO ***??*** revise _get_position - torch, typing, docstring, unit test def _get_position( From bee5de01eee6e12d14a1bd87c9f762a4cb61ee24 Mon Sep 17 00:00:00 2001 From: Carlo Date: Fri, 16 Jan 2026 11:22:49 +0100 Subject: [PATCH 23/24] u --- deeptrack/math.py | 203 +++++++++++++++++++++++++++------------- deeptrack/optics.py | 6 +- deeptrack/scatterers.py | 14 +-- 3 files changed, 147 insertions(+), 76 deletions(-) diff --git a/deeptrack/math.py b/deeptrack/math.py index 5845fb90..2f8b7339 100644 --- a/deeptrack/math.py +++ b/deeptrack/math.py @@ -1453,54 +1453,67 @@ def __init__(self, ksize: int = 2): class Resize(Feature): """Resize an image to a specified size. - `Resize` resizes an image using: - - OpenCV (`cv2.resize`) for NumPy arrays. - - PyTorch (`torch.nn.functional.interpolate`) for PyTorch tensors. + `Resize` resizes images following the channels-last semantic + convention. - The interpretation of the `dsize` parameter follows the convention - of the underlying backend: - - **NumPy (OpenCV)**: `dsize` is given as `(width, height)` to match - OpenCV’s default. - - **PyTorch**: `dsize` is given as `(height, width)`. + The operation supports both NumPy arrays and PyTorch tensors: + - NumPy arrays are resized using OpenCV (`cv2.resize`). + - PyTorch tensors are resized using `torch.nn.functional.interpolate`. 
+ + In all cases, the input is interpreted as having spatial dimensions + first and an optional channel dimension last. Parameters ---------- - dsize: PropertyLike[tuple[int, int]] - The target size. Format depends on backend: `(width, height)` for - NumPy, `(height, width)` for PyTorch. - **kwargs: Any - Additional parameters sent to the underlying resize function: - - NumPy: passed to `cv2.resize`. - - PyTorch: passed to `torch.nn.functional.interpolate`. + dsize : PropertyLike[tuple[int, int]] + Target output size given as (width, height). This convention is + backend-independent and applies equally to NumPy and PyTorch inputs. + + **kwargs : Any + Additional keyword arguments forwarded to the underlying resize + implementation: + - NumPy backend: passed to `cv2.resize`. + - PyTorch backend: passed to + `torch.nn.functional.interpolate`. Methods ------- get( - image: np.ndarray | torch.Tensor, dsize: tuple[int, int], **kwargs + image: np.ndarray | torch.Tensor, + dsize: tuple[int, int], + **kwargs ) -> np.ndarray | torch.Tensor Resize the input image to the specified size. Examples -------- - >>> import deeptrack as dt + NumPy example: - Numpy example: >>> import numpy as np - >>> - >>> input_image = np.random.rand(16, 16) # Create image - >>> feature = dt.math.Resize(dsize=(8, 4)) # (width=8, height=4) - >>> resized_image = feature.resolve(input_image) # Resize it to (4, 8) - >>> print(resized_image.shape) + >>> input_image = np.random.rand(16, 16) + >>> feature = dt.math.Resize(dsize=(8, 4)) # (width=8, height=4) + >>> resized_image = feature.resolve(input_image) + >>> resized_image.shape (4, 8) PyTorch example: + >>> import torch - >>> - >>> input_image = torch.rand(1, 1, 16, 16) # Create image - >>> feature = dt.math.Resize(dsize=(4, 8)) # (height=4, width=8) - >>> resized_image = feature.resolve(input_image) # Resize it to (4, 8) - >>> print(resized_image.shape) - torch.Size([1, 1, 4, 8]) + >>> input_image = torch.rand(16, 16) # channels-last + >>> feature = dt.math.Resize(dsize=(8, 4)) + >>> resized_image = feature.resolve(input_image) + >>> resized_image.shape + torch.Size([4, 8]) + + Notes + ----- + - Resize follows channels-last semantics, consistent with other features + such as Pool and Blur. + - Torch tensors with channels-first layout (e.g. (C, H, W) or + (N, C, H, W)) are not supported and must be converted to + channels-last format before resizing. + - For PyTorch tensors, bilinear interpolation is used with + `align_corners=False`, closely matching OpenCV’s default behavior. """ @@ -1533,67 +1546,109 @@ def get( Parameters ---------- - image: np.ndarray or torch.Tensor - The input image to resize. - - NumPy arrays may be grayscale (H, W) or color (H, W, C). - - Torch tensors are expected in one of the following formats: - (N, C, H, W), (C, H, W), or (H, W). - dsize: tuple[int, int] - Desired output size of the image. - - NumPy: (width, height) - - PyTorch: (height, width) - **kwargs: Any - Additional keyword arguments passed to the underlying resize - function (`cv2.resize` or `torch.nn.functional.interpolate`). + image : np.ndarray or torch.Tensor + Input image following channels-last semantics. + + Supported shapes are: + - (H, W) + - (H, W, C) + - (Z, H, W) + - (Z, H, W, C) + + For PyTorch tensors, channels-first layouts such as (C, H, W) or + (N, C, H, W) are not supported and must be converted to + channels-last format before calling `Resize`. + + dsize : tuple[int, int] + Desired output size given as (width, height). 
This convention is + backend-independent and applies to both NumPy and PyTorch inputs. + + **kwargs : Any + Additional keyword arguments passed to the underlying resize + implementation: + - NumPy backend: forwarded to `cv2.resize`. + - PyTorch backend: forwarded to `torch.nn.functional.interpolate`. Returns ------- np.ndarray or torch.Tensor - The resized image in the same type and dimensionality format as - input. + The resized image, with the same type and dimensionality layout as + the input image. Notes ----- + - Resize follows the same channels-last semantic convention as other + features in `deeptrack.math`. - For PyTorch tensors, resizing uses bilinear interpolation with - `align_corners=False`. This choice matches OpenCV’s `cv2.resize` - default behavior when resizing NumPy arrays, aiming to produce nearly - identical results between both backends. + `align_corners=False`, which closely matches OpenCV’s default behavior. """ + target_w, target_h = dsize + + # Torch backend if apc.is_torch_array(image): - original_shape = image.shape - - # Reshape input to (N, C, H, W) - if image.ndim == 2: # (H, W) - image = image.unsqueeze(0).unsqueeze(0) - elif image.ndim == 3: # (C, H, W) - image = image.unsqueeze(0) - elif image.ndim != 4: + import torch.nn.functional as F + + original_ndim = image.ndim + has_channels = ( + image.ndim >= 3 and image.shape[-1] <= 4 + ) + + # Bring to (N, C, H, W) + if image.ndim == 2: + # (H, W) -> (1, 1, H, W) + x = image.unsqueeze(0).unsqueeze(0) + + elif image.ndim == 3 and has_channels: + # (H, W, C) -> (1, C, H, W) + x = image.permute(2, 0, 1).unsqueeze(0) + + elif image.ndim == 3: + # (Z, H, W) -> treat Z as batch + x = image.unsqueeze(1) + + elif image.ndim == 4 and has_channels: + # (Z, H, W, C) -> (Z, C, H, W) + x = image.permute(0, 3, 1, 2) + + else: raise ValueError( - "Resize only supports tensors with shape (N, C, H, W), " - "(C, H, W), or (H, W)." + f"Unsupported tensor shape {image.shape} for Resize." ) - resized = torch.nn.functional.interpolate( - image, - size=dsize, + # Resize spatial dimensions + resized = F.interpolate( + x, + size=(target_h, target_w), mode="bilinear", align_corners=False, ) - # Restore original dimensionality - if len(original_shape) == 2: - resized = resized.squeeze(0).squeeze(0) - elif len(original_shape) == 3: - resized = resized.squeeze(0) + # Restore original layout + if original_ndim == 2: + return resized.squeeze(0).squeeze(0) + + if original_ndim == 3 and has_channels: + return resized.squeeze(0).permute(1, 2, 0) + + if original_ndim == 3: + return resized.squeeze(1) + + if original_ndim == 4: + return resized.permute(0, 2, 3, 1) - return resized + raise RuntimeError("Unexpected shape restoration path.") + # NumPy / OpenCV backend else: import cv2 + + # OpenCV expects (width, height) return utils.safe_call( - cv2.resize, positional_args=[image, dsize], **kwargs + cv2.resize, + positional_args=[image, (target_w, target_h)], + **kwargs, ) @@ -1647,6 +1702,12 @@ class BlurCV2(Feature): >>> print(output_image.shape) (32, 32) + Notes + ----- + BlurCV2 is NumPy-only and does not support PyTorch tensors. + This class is intended for OpenCV-specific filters that are + not available in the backend-agnostic math layer. + """ def __new__( @@ -1741,6 +1802,12 @@ def get( """ + if apc.is_torch_array(image): + raise TypeError( + "BlurCV2 only supports NumPy arrays. " + "For Torch tensors, use Blur or GaussianBlur instead." 
+ ) + kwargs.pop("name", None) result = self.filter(src=image, **kwargs) return result @@ -1790,6 +1857,10 @@ class BilateralBlur(BlurCV2): >>> print(output_image.shape) (32, 32) + Notes + ----- + BilateralBlur is NumPy-only and does not support PyTorch tensors. + """ def __init__( diff --git a/deeptrack/optics.py b/deeptrack/optics.py index eeb2fca1..a7394689 100644 --- a/deeptrack/optics.py +++ b/deeptrack/optics.py @@ -288,7 +288,7 @@ def _downscale_image(self, image, upscale): ) if isinstance(ux, float) and ux.is_integer(): ux = int(ux) - return AveragePoolingV2(ux)(image) + return AveragePooling(ux)(image) def get( self: Microscope, @@ -1130,7 +1130,7 @@ def downscale_image(self, image: np.ndarray, upscale): ux = int(ux) # Energy-conserving detector integration - return SumPoolingV2(ux)(image) + return SumPooling(ux)(image) def get( @@ -1855,7 +1855,7 @@ def downscale_image(self, image: np.ndarray, upscale): ux = int(ux) # Energy-conserving detector integration - return SumPoolingV2(ux)(image) + return SumPooling(ux)(image) #Retrieve get as super def get( diff --git a/deeptrack/scatterers.py b/deeptrack/scatterers.py index 2ba202e0..b7e1b70a 100644 --- a/deeptrack/scatterers.py +++ b/deeptrack/scatterers.py @@ -175,7 +175,7 @@ get_active_voxel_size, ) from deeptrack.backend import mie -from deeptrack.math import AveragePoolingV2 +from deeptrack.math import AveragePooling from deeptrack.features import Feature, MERGE_STRATEGY_APPEND from deeptrack.image import pad_image_to_fft from deeptrack.types import ArrayLike @@ -261,11 +261,11 @@ def __init__( **kwargs, ) -> None: # Ignore warning to help with comparison with arrays. - if upsample != 1: # noqa: F632 - warnings.warn( - f"Setting upsample != 1 is deprecated. " - f"Please, instead use dt.Upscale(f, factor={upsample})" - ) + # if upsample != 1: # noqa: F632 + # warnings.warn( + # f"Setting upsample != 1 is deprecated. " + # f"Please, instead use dt.Upscale(f, factor={upsample})" + # ) self._processed_properties = False @@ -291,7 +291,7 @@ def _antialias_volume(self, volume, factor: int): return volume # average pooling conserves fractional occupancy - return AveragePoolingV2( + return AveragePooling( factor )(volume) From f18884cacc74d12823902fee48796c58d91d5f4d Mon Sep 17 00:00:00 2001 From: Carlo Date: Sun, 18 Jan 2026 04:13:54 +0100 Subject: [PATCH 24/24] implemented BackendDispatched --- deeptrack/features.py | 32 + deeptrack/math.py | 1389 ++++++++++++++++++++++++++--------------- deeptrack/optics.py | 33 +- 3 files changed, 919 insertions(+), 535 deletions(-) diff --git a/deeptrack/features.py b/deeptrack/features.py index 702e7362..9572ea82 100644 --- a/deeptrack/features.py +++ b/deeptrack/features.py @@ -52,6 +52,11 @@ hierarchical or logical structures in the pipeline without input transformations. +- `BackendDispatched`: Mixin class for backend-specific implementations. + + Provides mechanisms for dispatching feature methods based on the + computational backend (e.g., NumPy, PyTorch). + - `ArithmeticOperationFeature`: Apply arithmetic operation element-wise. 
A parent class for features performing arithmetic operations like addition, @@ -180,6 +185,7 @@ __all__ = [ "Feature", "StructuralFeature", + "BackendDispatched", "Chain", "Branch", "DummyFeature", @@ -3947,6 +3953,32 @@ class StructuralFeature(Feature): __distributed__: bool = False # Process the entire image list in one call +class BackendDispatched: + """Mixin for Feature.get() methods with backend-specific implementations.""" + + _NUMPY_IMPL: str | None = None + _TORCH_IMPL: str | None = None + + def _dispatch_backend(self, *args, **kwargs): + backend = self.get_backend() + + if backend == "numpy": + if self._NUMPY_IMPL is None: + raise NotImplementedError( + f"{self.__class__.__name__} does not support NumPy backend." + ) + return getattr(self, self._NUMPY_IMPL)(*args, **kwargs) + + if backend == "torch": + if self._TORCH_IMPL is None: + raise NotImplementedError( + f"{self.__class__.__name__} does not support Torch backend." + ) + return getattr(self, self._TORCH_IMPL)(*args, **kwargs) + + raise RuntimeError(f"Unknown backend {backend}") + + class Chain(StructuralFeature): """Resolve two features sequentially. diff --git a/deeptrack/math.py b/deeptrack/math.py index 2f8b7339..40a8adc7 100644 --- a/deeptrack/math.py +++ b/deeptrack/math.py @@ -59,6 +59,8 @@ - `MinPooling`: Apply min-pooling to the image. +- `SumPooling`: Apply sum pooling to the image. + - `MedianPooling`: Apply median pooling to the image. - `Resize`: Resize the image to a specified size. @@ -97,14 +99,12 @@ import array_api_compat as apc import numpy as np -from numpy.typing import NDArray #TODO TBE from scipy import ndimage import skimage import skimage.measure from deeptrack import utils, OPENCV_AVAILABLE, TORCH_AVAILABLE -from deeptrack.features import Feature -from deeptrack.image import Image, strip #TODO TBE +from deeptrack.features import Feature, BackendDispatched from deeptrack.types import PropertyLike from deeptrack.backend import xp @@ -129,7 +129,9 @@ "AveragePooling", "MaxPooling", "MinPooling", + "SumPooling", "MedianPooling", + "Resize", "BlurCV2", "BilateralBlur", ] @@ -297,8 +299,8 @@ class Clip(Feature): def __init__( self: Clip, - min: PropertyLike[float] = -np.inf, - max: PropertyLike[float] = +np.inf, + min: PropertyLike[float] = -xp.inf, + max: PropertyLike[float] = +xp.inf, **kwargs: Any, ): """Initialize the clipping range. @@ -306,9 +308,9 @@ def __init__( Parameters ---------- min: float, optional - Minimum allowed value. It defaults to `-np.inf`. + Minimum allowed value. It defaults to `-xp.inf`. max: float, optional - Maximum allowed value. It defaults to `+np.inf`. + Maximum allowed value. It defaults to `+xp.inf`. **kwargs: Any Additional keyword arguments. 
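# Illustrative sketch, not part of the diff: the intended usage pattern for
# the BackendDispatched mixin added in features.py above. A feature names its
# backend implementations via _NUMPY_IMPL/_TORCH_IMPL and routes get()
# through _dispatch_backend(); `Doubler` and its bodies are hypothetical.
#
#   class Doubler(BackendDispatched, Feature):
#       _NUMPY_IMPL = "_get_numpy"
#       _TORCH_IMPL = "_get_torch"
#
#       def get(self, image, **kwargs):
#           return self._dispatch_backend(image, **kwargs)
#
#       def _get_numpy(self, image, **kwargs):
#           return image * 2  # plain np.ndarray arithmetic
#
#       def _get_torch(self, image, **kwargs):
#           return image * 2  # torch.Tensor arithmetic, same semantics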
@@ -441,31 +443,31 @@ def get( """ - if featurewise: - # Normalize per feature (last axis) - axis = tuple(range(image.ndim - 1)) + has_channels = image.ndim >= 3 and image.shape[-1] <= 4 + if featurewise and has_channels: + # reduce over spatial dimensions only + axis = tuple(range(image.ndim - 1)) img_min = xp.min(image, axis=axis, keepdims=True) img_max = xp.max(image, axis=axis, keepdims=True) else: - # Normalize globally + # global normalization img_min = xp.min(image) img_max = xp.max(image) ptp = img_max - img_min + eps = xp.asarray(1e-8, dtype=image.dtype) + ptp = xp.maximum(ptp, eps) - # Avoid division by zero - image = (image - img_min) / ptp * (max - min) + min - - try: - image[xp.isnan(image)] = 0 - except TypeError: - pass + image = (image - img_min) / ptp + image = image * (max - min) + min + image = xp.where(xp.isnan(image), xp.zeros_like(image), image) return image -class NormalizeStandard(Feature): + +class NormalizeStandard(BackendDispatched, Feature): """Image normalization using standardization. Standardizes the input image to have zero mean and unit standard @@ -499,6 +501,9 @@ class NormalizeStandard(Feature): """ + _NUMPY_IMPL = "_get_numpy" + _TORCH_IMPL = "_get_torch" + def __init__( self: NormalizeStandard, featurewise: PropertyLike[bool] = True, @@ -539,37 +544,63 @@ def get( np.ndarray or torch.Tensor The standardized image. """ + return self._dispatch_backend( + image, + featurewise=featurewise, + ) - if featurewise: - # Normalize per feature (last axis) - axis = tuple(range(image.ndim - 1)) + # ------ NumPy backend ------ + + def _get_numpy( + self, + image: np.ndarray, + featurewise: bool, + ) -> np.ndarray: - mean = xp.mean(image, axis=axis, keepdims=True) + has_channels = image.ndim >= 3 and image.shape[-1] <= 4 - if apc.is_torch_array(image): - std = torch.std(image, dim=axis, keepdim=True, unbiased=False) - else: - std = xp.std(image, axis=axis) + if featurewise and has_channels: + axis = tuple(range(image.ndim - 1)) + mean = np.mean(image, axis=axis, keepdims=True) + std = np.std(image, axis=axis, keepdims=True) # population std else: - # Normalize globally - mean = xp.mean(image) + mean = np.mean(image) + std = np.std(image) - if apc.is_torch_array(image): - std = torch.std(image, unbiased=False) - else: - std = xp.std(image) + std = np.maximum(std, 1e-8) - image = (image - mean) / std + out = (image - mean) / std + out = np.where(np.isnan(out), 0.0, out) - try: - image[xp.isnan(image)] = 0 - except TypeError: - pass + return out - return image + # ------ Torch backend ------ + + def _get_torch( + self, + image: torch.Tensor, + featurewise: bool, + ) -> torch.Tensor: + + has_channels = image.ndim >= 3 and image.shape[-1] <= 4 + + if featurewise and has_channels: + axis = tuple(range(image.ndim - 1)) + mean = image.mean(dim=axis, keepdim=True) + std = image.std(dim=axis, keepdim=True, unbiased=False) + else: + mean = image.mean() + std = image.std(unbiased=False) + std = torch.clamp(std, min=1e-8) -class NormalizeQuantile(Feature): + out = (image - mean) / std + out = torch.nan_to_num(out, nan=0.0) + + return out + + +class NormalizeQuantile(BackendDispatched, Feature): """Image normalization using quantiles. Centers the image at the median and scales it such that the values at the @@ -589,6 +620,12 @@ class NormalizeQuantile(Feature): get(image: array, quantiles: tuple[float, float], **kwargs) -> array Normalizes the input based on the given quantile range. + Notes + ----- + This operation is not differentiable. 
When used inside a gradient-based + model, it will block gradient flow. Use with care if end-to-end + differentiability is required. + Examples -------- >>> import deeptrack as dt @@ -607,7 +644,8 @@ class NormalizeQuantile(Feature): """ - #TODO ___??___ Implement the `featurewise=False` option + _NUMPY_IMPL = "_get_numpy" + _TORCH_IMPL = "_get_torch" def __init__( self: NormalizeQuantile, @@ -637,12 +675,25 @@ def __init__( ) def get( - self: NormalizeQuantile, + self, image: np.ndarray | torch.Tensor, quantiles: tuple[float, float], featurewise: bool, **kwargs: Any, - ) -> np.ndarray | torch.Tensor: + ): + return self._dispatch_backend( + image, + quantiles=quantiles, + featurewise=featurewise, + ) + + def _get_numpy( + self: NormalizeQuantile, + image: np.ndarray, + quantiles: tuple[float, float], + featurewise: bool, + **kwargs: Any, + ) -> np.ndarray: """Normalize the input image based on the specified quantiles. Parameters @@ -658,140 +709,127 @@ def get( ------- np.ndarray or torch.Tensor The quantile-normalized image. + """ q_low_val, q_high_val = quantiles - if featurewise: - # Per-feature normalization (last axis) + has_channels = image.ndim >= 3 and image.shape[-1] <= 4 + + if featurewise and has_channels: axis = tuple(range(image.ndim - 1)) + q_low, q_high, median = np.quantile( + image, + (q_low_val, q_high_val, 0.5), + axis=axis, + keepdims=True, + ) + else: + q_low, q_high, median = np.quantile( + image, + (q_low_val, q_high_val, 0.5), + ) + + scale = q_high - q_low + eps = np.asarray(1e-8, dtype=image.dtype) + scale = np.maximum(scale, eps) + + image = (image - median) / scale + image = np.where(np.isnan(image), np.zeros_like(image), image) + return image + + def _get_torch( + self, + image: torch.Tensor, + quantiles: tuple[float, float], + featurewise: bool, + ): + q_low_val, q_high_val = quantiles - if apc.is_torch_array(image): + if featurewise: + if image.ndim < 3: + # No channels → global quantile q = torch.tensor( [q_low_val, q_high_val, 0.5], device=image.device, dtype=image.dtype, ) - q_low, q_high, median = torch.quantile( - image, q, dim=axis, keepdim=True - ) + q_low, q_high, median = torch.quantile(image, q) else: - q_low, q_high, median = xp.quantile( - image, (q_low_val, q_high_val, 0.5), - axis=axis, - keepdims=True, - ) - else: - # Global normalization - if apc.is_torch_array(image): + # channels-last: (..., C) + spatial_dims = image.ndim - 1 + C = image.shape[-1] + + # flatten spatial dims + x = image.reshape(-1, C) # (N, C) + q = torch.tensor( [q_low_val, q_high_val, 0.5], device=image.device, dtype=image.dtype, ) - q_low, q_high, median = torch.quantile( - image, q, dim=None, keepdim=False - ) - else: - q_low, q_high, median = xp.quantile( - image, (q_low_val, q_high_val, 0.5) - ) - image = (image - median) / (q_high - q_low) * 2.0 + q_vals = torch.quantile(x, q, dim=0) + q_low, q_high, median = q_vals - try: - image[xp.isnan(image)] = 0 - except TypeError: - pass + # reshape for broadcasting + shape = [1] * image.ndim + shape[-1] = C + q_low = q_low.view(shape) + q_high = q_high.view(shape) + median = median.view(shape) - return image + else: + q = torch.tensor( + [q_low_val, q_high_val, 0.5], + device=image.device, + dtype=image.dtype, + ) + q_low, q_high, median = torch.quantile(image, q) + scale = q_high - q_low + scale = torch.clamp(scale, min=1e-8) + image = (image - median) / scale + image = torch.nan_to_num(image) -#TODO ***CM*** revise typing, docstring, unit test -class Blur(Feature): - """Apply a blurring filter to an image. 
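+        # Degenerate (constant) inputs are handled by the 1e-8 clamp on
+        # `scale`, so the division cannot produce infinities here.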
+ return image - This class acts as a backend-dispatching blur operator. Subclasses must - implement backend-specific logic via `_get_numpy` and optionally - `_get_torch`. - Notes - ----- - - NumPy execution is always supported. - - Torch execution is only supported if `_get_torch` is implemented. - - Generic `filter_function`-based blurs are NumPy-only by design. +#TODO ***CM*** revise typing, docstring, unit test +class Blur(BackendDispatched, Feature): + """Abstract blur feature with backend-dispatched implementations. + + This class serves as a base for blur features that support multiple + backends (e.g., NumPy, Torch). Subclasses should implement backend-specific + blurring logic via `_get_numpy` and/or `_get_torch` methods. + + Methods + ------- + get(image: np.ndarray | torch.Tensor, **kwargs) -> np.ndarray | torch.Tensor + Applies the appropriate backend-specific blurring method. + + _blur(xp, image: array, **kwargs) -> array + Internal method that dispatches to the correct backend-specific blur + implementation. """ - def __init__( - self, - filter_function: Callable | None = None, - mode: PropertyLike[str] = "reflect", - **kwargs: Any, - ): - """Initialize the blur feature. - - Parameters - ---------- - filter_function : Callable or None - NumPy-based blurring function. Must accept the input image as a - keyword argument named `input`. If `None`, the subclass must - implement `_get_numpy`. - mode : str - Border mode for NumPy-based filters. - **kwargs : Any - Additional keyword arguments passed to Feature. - """ - self.filter = filter_function - self.mode = mode - super().__init__(**kwargs) + _NUMPY_IMPL = "_get_numpy" + _TORCH_IMPL = "_get_torch" - def __call__( + def get( self, image: np.ndarray | torch.Tensor, - **kwargs: Any, - ) -> np.ndarray | torch.Tensor: - if isinstance(image, np.ndarray): - return self._get_numpy(image, **kwargs) - - if TORCH_AVAILABLE and isinstance(image, torch.Tensor): - return self._get_torch(image, **kwargs) - - raise TypeError( - "Blur only supports numpy.ndarray or torch.Tensor inputs." - ) - - def _get_numpy( - self, - image: np.ndarray, - **kwargs: Any, - ) -> np.ndarray: - if self.filter is None: - raise NotImplementedError( - f"{self.__class__.__name__} does not implement a NumPy backend." - ) - - # Avoid passing conflicting keywords - kwargs = dict(kwargs) - kwargs.pop("input", None) + **kwargs, + ): + return self._dispatch_backend(image, **kwargs) - return utils.safe_call( - self.filter, - input=image, - mode=self.mode, - **kwargs, - ) + def _get_numpy(self, image: np.ndarray, **kwargs): + raise NotImplementedError - def _get_torch( - self, - image: torch.Tensor, - **kwargs: Any, - ) -> torch.Tensor: - raise TypeError( - f"{self.__class__.__name__} does not support torch.Tensor inputs. " - "Use a Torch-enabled blur (e.g. AverageBlur or a V2 blur class)." - ) + def _get_torch(self, image: torch.Tensor, **kwargs): + raise NotImplementedError @@ -829,10 +867,10 @@ class AverageBlur(Blur): """ def __init__( - self: AverageBlur, - ksize: PropertyLike[int] = 3, - **kwargs: Any, - ): + self: AverageBlur, + ksize: int = 3, + **kwargs: Any + ) -> None: """Initialize the parameters for averaging input features. 
This constructor initializes the parameters for averaging input @@ -847,106 +885,119 @@ def __init__( """ - super().__init__(None, ksize=ksize, **kwargs) + self.ksize = int(ksize) + super().__init__(**kwargs) - def _kernel_shape(self, shape: tuple[int, ...], ksize: int) -> tuple[int, ...]: + @staticmethod + def _kernel_shape(shape: tuple[int, ...], ksize: int) -> tuple[int, ...]: + # If last dim is channel and smaller than kernel, do not blur channels if shape[-1] < ksize: return (ksize,) * (len(shape) - 1) + (1,) return (ksize,) * len(shape) + # ---------- NumPy backend ---------- def _get_numpy( - self, input: np.ndarray, ksize: tuple[int, ...], **kwargs: Any + self: AverageBlur, + image: np.ndarray, + **kwargs: Any ) -> np.ndarray: + """Apply average blurring using SciPy's uniform_filter. + + This method applies average blurring to the input image using + SciPy's `uniform_filter`. + + Parameters + ---------- + image: np.ndarray + The input image to blur. + **kwargs: dict[str, Any] + Additional keyword arguments for `uniform_filter`. + + Returns + ------- + np.ndarray + The blurred image. + + """ + + k = self._kernel_shape(image.shape, self.ksize) return ndimage.uniform_filter( - input, - size=ksize, + image, + size=k, mode=kwargs.get("mode", "reflect"), cval=kwargs.get("cval", 0), origin=kwargs.get("origin", 0), - axes=tuple(range(0, len(ksize))), + axes=tuple(range(len(k))), ) + # ---------- Torch backend ---------- def _get_torch( - self, input: torch.Tensor, ksize: tuple[int, ...], **kwargs: Any + self: AverageBlur, + image: torch.Tensor, + **kwargs: Any ) -> torch.Tensor: + """Apply average blurring using PyTorch's avg_pool. - last_dim_is_channel = len(ksize) < input.ndim + This method applies average blurring to the input image using + PyTorch's `avg_pool` functions. + + Parameters + ---------- + image: torch.Tensor + The input image to blur. + **kwargs: dict[str, Any] + Additional keyword arguments for padding. + + Returns + ------- + torch.Tensor + The blurred image. + + """ + + k = self._kernel_shape(tuple(image.shape), self.ksize) + + last_dim_is_channel = len(k) < image.ndim if last_dim_is_channel: - input = input.movedim(-1, 0) + image = image.movedim(-1, 0) # C, ... else: - input = input.unsqueeze(0) + image = image.unsqueeze(0) # 1, ... # add batch dimension - input = input.unsqueeze(0) + image = image.unsqueeze(0) # 1, C, ... 
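+
+        # Reflect-pad by half the kernel in each pooled dimension so the
+        # stride-1 average pool below preserves the input's spatial size.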
- # dynamic padding + # symmetric padding pad = [] - for k in reversed(ksize): - p = k // 2 + for kk in reversed(k): + p = kk // 2 pad.extend([p, p]) - pad = tuple(pad) - - input = F.pad( - input, - pad, + image = F.pad( + image, + tuple(pad), mode=kwargs.get("mode", "reflect"), value=kwargs.get("cval", 0), ) - if input.ndim == 3: - x = F.avg_pool1d(input, kernel_size=ksize, stride=1) - elif input.ndim == 4: - x = F.avg_pool2d(input, kernel_size=ksize, stride=1) - elif input.ndim == 5: - x = F.avg_pool3d(input, kernel_size=ksize, stride=1) + # pooling by dimensionality + if image.ndim == 3: + out = F.avg_pool1d(image, kernel_size=k, stride=1) + elif image.ndim == 4: + out = F.avg_pool2d(image, kernel_size=k, stride=1) + elif image.ndim == 5: + out = F.avg_pool3d(image, kernel_size=k, stride=1) else: raise NotImplementedError( - f"Input dimension {input.ndim - 2} not supported for torch backend" + f"Input dimensionality {image.ndim - 2} not supported" ) # restore layout - x = x.squeeze(0) + out = out.squeeze(0) if last_dim_is_channel: - x = x.movedim(0, -1) + out = out.movedim(0, -1) else: - x = x.squeeze(0) - - return x - - def get( - self: AverageBlur, - input: np.ndarray | torch.Tensor, - ksize: int, - **kwargs: Any, - ) -> np.ndarray | torch.Tensor: - """Applies the average blurring filter to the input image. - - This method applies the average blurring filter to the input image. - - Parameters - ---------- - input: np.ndarray - The input image to blur. - ksize: int - Kernel size for the pooling operation. - **kwargs: dict[str, Any] - Additional keyword arguments. + out = out.squeeze(0) - Returns - ------- - np.ndarray - The blurred image. - - """ - - k = self._kernel_shape(input.shape, ksize) - - if self.backend == "numpy": - return self._get_numpy(input, k, **kwargs) - elif self.backend == "torch": - return self._get_torch(input, k, **kwargs) - else: - raise NotImplementedError(f"Backend {self.backend} not supported") + return out #TODO ***CM*** revise typing, docstring, unit test @@ -1004,27 +1055,32 @@ def __init__(self: GaussianBlur, sigma: PropertyLike[float] = 2, **kwargs: Any): self.sigma = float(sigma) super().__init__(None, **kwargs) + # ---------- NumPy backend ---------- + def _get_numpy( self, - input: np.ndarray, + image: np.ndarray, **kwargs: Any, ) -> np.ndarray: return ndimage.gaussian_filter( - input, + image, sigma=self.sigma, mode=kwargs.get("mode", "reflect"), cval=kwargs.get("cval", 0), ) + # ---------- Torch backend ---------- + + @staticmethod def _gaussian_kernel_1d( - self, sigma: float, device, dtype, ) -> torch.Tensor: radius = int(np.ceil(3 * sigma)) x = torch.arange( - -radius, radius + 1, + -radius, + radius + 1, device=device, dtype=dtype, ) @@ -1034,29 +1090,29 @@ def _gaussian_kernel_1d( def _get_torch( self, - input: torch.Tensor, + image: torch.Tensor, **kwargs: Any, ) -> torch.Tensor: import torch.nn.functional as F - sigma = self.sigma kernel_1d = self._gaussian_kernel_1d( - sigma, - device=input.device, - dtype=input.dtype, + self.sigma, + device=image.device, + dtype=image.dtype, ) - last_dim_is_channel = input.ndim >= 3 + # channel-last handling + last_dim_is_channel = image.ndim >= 3 if last_dim_is_channel: - input = input.movedim(-1, 0) # C, ... + image = image.movedim(-1, 0) # C, ... else: - input = input.unsqueeze(0) # 1, ... + image = image.unsqueeze(0) # 1, ... # add batch dimension - input = input.unsqueeze(0) # 1, C, ... + image = image.unsqueeze(0) # 1, C, ... 
- spatial_dims = input.ndim - 2 - C = input.shape[1] + spatial_dims = image.ndim - 2 + C = image.shape[1] for d in range(spatial_dims): k = kernel_1d @@ -1071,31 +1127,31 @@ def _get_torch( pad[-(2 * d + 1)] = radius pad = tuple(pad) - input = F.pad( - input, + image = F.pad( + image, pad, mode=kwargs.get("mode", "reflect"), ) if spatial_dims == 1: - input = F.conv1d(input, k, groups=C) + image = F.conv1d(image, k, groups=C) elif spatial_dims == 2: - input = F.conv2d(input, k, groups=C) + image = F.conv2d(image, k, groups=C) elif spatial_dims == 3: - input = F.conv3d(input, k, groups=C) + image = F.conv3d(image, k, groups=C) else: raise NotImplementedError( f"{spatial_dims}D Gaussian blur not supported" ) # restore layout - input = input.squeeze(0) + image = image.squeeze(0) if last_dim_is_channel: - input = input.movedim(0, -1) + image = image.movedim(0, -1) else: - input = input.squeeze(0) + image = image.squeeze(0) - return input + return image #TODO ***JH*** revise MedianBlur - torch, typing, docstring, unit test @@ -1123,6 +1179,10 @@ class MedianBlur(Blur): Torch median blurring is significantly more expensive than mean or Gaussian blurring due to explicit tensor unfolding. + Median blur is not differentiable. This is typically acceptable, as the + operation is intended for denoising and preprocessing rather than as a + trainable network layer. + Examples -------- >>> import deeptrack as dt @@ -1156,21 +1216,25 @@ def __init__( self.ksize = int(ksize) super().__init__(None, **kwargs) + # ---------- NumPy backend ---------- + def _get_numpy( self, - input: np.ndarray, + image: np.ndarray, **kwargs: Any, ) -> np.ndarray: return ndimage.median_filter( - input, + image, size=self.ksize, mode=kwargs.get("mode", "reflect"), cval=kwargs.get("cval", 0), ) + # ---------- Torch backend ---------- + def _get_torch( self, - input: torch.Tensor, + image: torch.Tensor, **kwargs: Any, ) -> torch.Tensor: import torch.nn.functional as F @@ -1179,42 +1243,36 @@ def _get_torch( if k % 2 == 0: raise ValueError("MedianBlur requires an odd kernel size.") - last_dim_is_channel = input.ndim >= 3 + last_dim_is_channel = image.ndim >= 3 if last_dim_is_channel: - input = input.movedim(-1, 0) # C, ... + image = image.movedim(-1, 0) # C, ... else: - input = input.unsqueeze(0) # 1, ... + image = image.unsqueeze(0) # 1, ... # add batch dimension - input = input.unsqueeze(0) # 1, C, ... + image = image.unsqueeze(0) # 1, C, ... - spatial_dims = input.ndim - 2 + spatial_dims = image.ndim - 2 pad = k // 2 - # padding pad_tuple = [] for _ in range(spatial_dims): pad_tuple.extend([pad, pad]) pad_tuple = tuple(reversed(pad_tuple)) - input = F.pad( - input, + image = F.pad( + image, pad_tuple, mode=kwargs.get("mode", "reflect"), ) - # unfold spatial dimensions if spatial_dims == 1: - x = input.unfold(2, k, 1) + x = image.unfold(2, k, 1) elif spatial_dims == 2: - x = ( - input - .unfold(2, k, 1) - .unfold(3, k, 1) - ) + x = image.unfold(2, k, 1).unfold(3, k, 1) elif spatial_dims == 3: x = ( - input + image .unfold(2, k, 1) .unfold(3, k, 1) .unfold(4, k, 1) @@ -1224,11 +1282,9 @@ def _get_torch( f"{spatial_dims}D median blur not supported" ) - # flatten neighborhood and take median x = x.contiguous().view(*x.shape[:-spatial_dims], -1) x = x.median(dim=-1).values - # restore layout x = x.squeeze(0) if last_dim_is_channel: x = x.movedim(0, -1) @@ -1238,219 +1294,479 @@ def _get_torch( return x #TODO ***CM*** revise typing, docstring, unit test -class Pool: - """ - DeepTrack v2 replacement for Pool. 
- - Generic, center-preserving block pooling with NumPy and Torch backends. - Public API matches v1: a single integer ksize. - - Pool size semantics: - - 2D input -> (ksize, ksize, 1) - - 3D input -> (ksize, ksize, ksize) - """ - - _TORCH_REDUCERS_2D: Dict[Callable, Callable] = { - np.mean: lambda x, k, s: F.avg_pool2d(x, k, s), - np.sum: lambda x, k, s: F.avg_pool2d(x, k, s) * (k[0] * k[1]), - np.max: lambda x, k, s: F.max_pool2d(x, k, s), - np.min: lambda x, k, s: -F.max_pool2d(-x, k, s), - } +class Pool(BackendDispatched, Feature): + """Abstract base class for pooling features.""" - _TORCH_REDUCERS_3D: Dict[Callable, Callable] = { - np.mean: lambda x, k, s: F.avg_pool3d(x, k, s), - np.sum: lambda x, k, s: F.avg_pool3d(x, k, s) * (k[0] * k[1] * k[2]), - np.max: lambda x, k, s: F.max_pool3d(x, k, s), - np.min: lambda x, k, s: -F.max_pool3d(-x, k, s), - } + _NUMPY_IMPL = "_get_numpy" + _TORCH_IMPL = "_get_torch" def __init__( self, - pooling_function: Callable, - ksize: int = 2, + ksize: PropertyLike[int] = 2, + **kwargs: Any, ): - if pooling_function not in ( - np.mean, np.sum, np.min, np.max, np.median - ): - raise ValueError( - "Unsupported pooling_function. " - "Use one of: np.mean, np.sum, np.min, np.max, np.median." - ) - - if not isinstance(ksize, int) or ksize < 1: - raise ValueError("ksize must be a positive integer.") - - self.pooling_function = pooling_function self.ksize = int(ksize) + super().__init__(**kwargs) - def _get_pool_size(self, array) -> Tuple[int, int, int]: - """ - Determine pooling kernel size based on semantic dimensionality. + def get( + self, + image: np.ndarray | torch.Tensor, + **kwargs: Any, + ) -> np.ndarray | torch.Tensor: + return self._dispatch_backend(image, **kwargs) - - 2D images: (Nx, Ny) or (Nx, Ny, C) -> pool in x,y only - - 3D volumes: (Nx, Ny, Nz) or (Nx, Ny, Nz, C) -> pool in x,y,z - - Never pool over channels - """ + # ---------- shared helpers ---------- + + def _get_pool_size(self, array) -> tuple[int, int, int]: k = self.ksize - # 2D image if array.ndim == 2: return k, k, 1 - # 3D array: could be (x, y, z) or (x, y, c) if array.ndim == 3: - # Heuristic: small last dim → channels - if array.shape[-1] <= 4: + if array.shape[-1] <= 4: # channel heuristic return k, k, 1 return k, k, k - # 4D array: (x, y, z, c) if array.ndim == 4: return k, k, k - raise ValueError( - f"Unsupported array shape {array.shape} for pooling." - ) + raise ValueError(f"Unsupported array shape {array.shape}") def _crop_center(self, array): px, py, pz = self._get_pool_size(array) - # 2D (or effectively 2D) + # 2D or effectively 2D (channels-last) if array.ndim < 3 or pz == 1: H, W = array.shape[:2] crop_h = (H // px) * px crop_w = (W // py) * py - off_h = (H - crop_h) // 2 - off_w = (W - crop_w) // 2 - return array[ - off_h : off_h + crop_h, - off_w : off_w + crop_w, - ... - ] - - # 3D + return array[:crop_h, :crop_w, ...] + + # 3D volume Z, H, W = array.shape[:3] crop_z = (Z // pz) * pz crop_h = (H // px) * px crop_w = (W // py) * py - off_z = (Z - crop_z) // 2 - off_h = (H - crop_h) // 2 - off_w = (W - crop_w) // 2 - return array[ - off_z : off_z + crop_z, - off_h : off_h + crop_h, - off_w : off_w + crop_w, - ... - ] - - def _pool_numpy(self, array: np.ndarray) -> np.ndarray: - array = self._crop_center(array) - px, py, pz = self._get_pool_size(array) + return array[:crop_z, :crop_h, :crop_w, ...] 
- if array.ndim < 3 or pz == 1: - pool_shape = (px, py) + (1,) * (array.ndim - 2) + # ---------- abstract backends ---------- + + def _get_numpy(self, image: np.ndarray, **kwargs): + raise NotImplementedError + + def _get_torch(self, image: torch.Tensor, **kwargs): + raise NotImplementedError + + +class AveragePooling(Pool): + """Average pooling feature. + + Downsamples the input by applying mean pooling over non-overlapping + blocks of size `ksize`, preserving the center of the image and never + pooling over channel dimensions. + + Works with NumPy and PyTorch backends. + """ + + # ---------- NumPy backend ---------- + + def _get_numpy( + self, + image: np.ndarray, + **kwargs: Any, + ) -> np.ndarray: + image = self._crop_center(image) + px, py, pz = self._get_pool_size(image) + + # 2D or effectively 2D (channels-last) + if image.ndim < 3 or pz == 1: + block_size = (px, py) + (1,) * (image.ndim - 2) else: - pool_shape = (pz, px, py) + (1,) * (array.ndim - 3) + # 3D volume (optionally with channels) + block_size = (pz, px, py) + (1,) * (image.ndim - 3) return skimage.measure.block_reduce( - array, - block_size=pool_shape, - func=self.pooling_function, + image, + block_size=block_size, + func=np.mean, ) - def _pool_torch(self, array: torch.Tensor) -> torch.Tensor: - array = self._crop_center(array) - px, py, pz = self._get_pool_size(array) + # ---------- Torch backend ---------- - is_3d = array.ndim >= 3 and pz > 1 + def _get_torch( + self, + image: torch.Tensor, + **kwargs: Any, + ) -> torch.Tensor: + import torch.nn.functional as F + + image = self._crop_center(image) + px, py, pz = self._get_pool_size(image) + is_3d = image.ndim >= 3 and pz > 1 + + # Flatten extra (channel / feature) dimensions into C if not is_3d: - extra = array.shape[2:] + extra = image.shape[2:] C = int(np.prod(extra)) if extra else 1 - x = array.reshape(1, C, array.shape[0], array.shape[1]) + x = image.reshape(1, C, image.shape[0], image.shape[1]) kernel = (px, py) stride = (px, py) - reducers = self._TORCH_REDUCERS_2D + pooled = F.avg_pool2d(x, kernel, stride) else: - extra = array.shape[3:] + extra = image.shape[3:] C = int(np.prod(extra)) if extra else 1 - x = array.reshape( - 1, C, array.shape[0], array.shape[1], array.shape[2] + x = image.reshape( + 1, C, + image.shape[0], + image.shape[1], + image.shape[2], ) kernel = (pz, px, py) stride = (pz, px, py) - reducers = self._TORCH_REDUCERS_3D - - # Median: explicit unfolding - if self.pooling_function is np.median: - if is_3d: - x_u = ( - x.unfold(2, pz, pz) - .unfold(3, px, px) - .unfold(4, py, py) - ) - x_u = x_u.contiguous().view( - 1, C, - x_u.shape[2], - x_u.shape[3], - x_u.shape[4], - -1, - ) - pooled = x_u.median(dim=-1).values - else: - x_u = x.unfold(2, px, px).unfold(3, py, py) - x_u = x_u.contiguous().view( - 1, C, - x_u.shape[2], - x_u.shape[3], - -1, - ) - pooled = x_u.median(dim=-1).values + pooled = F.avg_pool3d(x, kernel, stride) + + # Restore original layout + return pooled.reshape(pooled.shape[2:] + extra) + + +class MaxPooling(Pool): + """Max pooling feature. + + Downsamples the input by applying max pooling over non-overlapping + blocks of size `ksize`, preserving the center of the image and never + pooling over channel dimensions. + + Works with NumPy and PyTorch backends. 
+ """ + + # ---------- NumPy backend ---------- + + def _get_numpy( + self, + image: np.ndarray, + **kwargs: Any, + ) -> np.ndarray: + image = self._crop_center(image) + px, py, pz = self._get_pool_size(image) + + # 2D or effectively 2D (channels-last) + if image.ndim < 3 or pz == 1: + block_size = (px, py) + (1,) * (image.ndim - 2) + else: + # 3D volume (optionally with channels) + block_size = (pz, px, py) + (1,) * (image.ndim - 3) + + return skimage.measure.block_reduce( + image, + block_size=block_size, + func=np.max, + ) + + # ---------- Torch backend ---------- + + def _get_torch( + self, + image: torch.Tensor, + **kwargs: Any, + ) -> torch.Tensor: + import torch.nn.functional as F + + image = self._crop_center(image) + px, py, pz = self._get_pool_size(image) + + is_3d = image.ndim >= 3 and pz > 1 + + # Flatten extra (channel / feature) dimensions into C + if not is_3d: + extra = image.shape[2:] + C = int(np.prod(extra)) if extra else 1 + x = image.reshape(1, C, image.shape[0], image.shape[1]) + kernel = (px, py) + stride = (px, py) + pooled = F.max_pool2d(x, kernel, stride) else: - reducer = reducers[self.pooling_function] - pooled = reducer(x, kernel, stride) + extra = image.shape[3:] + C = int(np.prod(extra)) if extra else 1 + x = image.reshape( + 1, C, + image.shape[0], + image.shape[1], + image.shape[2], + ) + kernel = (pz, px, py) + stride = (pz, px, py) + pooled = F.max_pool3d(x, kernel, stride) + # Restore original layout return pooled.reshape(pooled.shape[2:] + extra) - def __call__(self, array): - if isinstance(array, np.ndarray): - return self._pool_numpy(array) - if TORCH_AVAILABLE and isinstance(array, torch.Tensor): - return self._pool_torch(array) +class MinPooling(Pool): + """Min pooling feature. + + Downsamples the input by applying min pooling over non-overlapping + blocks of size `ksize`, preserving the center of the image and never + pooling over channel dimensions. + + Works with NumPy and PyTorch backends. + + """ + + # ---------- NumPy backend ---------- + + def _get_numpy( + self, + image: np.ndarray, + **kwargs: Any, + ) -> np.ndarray: + image = self._crop_center(image) + px, py, pz = self._get_pool_size(image) + + # 2D or effectively 2D (channels-last) + if image.ndim < 3 or pz == 1: + block_size = (px, py) + (1,) * (image.ndim - 2) + else: + # 3D volume (optionally with channels) + block_size = (pz, px, py) + (1,) * (image.ndim - 3) - raise TypeError( - "Pool only supports np.ndarray or torch.Tensor inputs." 
+ return skimage.measure.block_reduce( + image, + block_size=block_size, + func=np.min, ) + # ---------- Torch backend ---------- -class AveragePooling(Pool): - def __init__(self, ksize: int = 2): - super().__init__(np.mean, ksize) + def _get_torch( + self, + image: torch.Tensor, + **kwargs: Any, + ) -> torch.Tensor: + import torch.nn.functional as F + + image = self._crop_center(image) + px, py, pz = self._get_pool_size(image) + + is_3d = image.ndim >= 3 and pz > 1 + + # Flatten extra (channel / feature) dimensions into C + if not is_3d: + extra = image.shape[2:] + C = int(np.prod(extra)) if extra else 1 + x = image.reshape(1, C, image.shape[0], image.shape[1]) + kernel = (px, py) + stride = (px, py) + + # min(x) = -max(-x) + pooled = -F.max_pool2d(-x, kernel, stride) + else: + extra = image.shape[3:] + C = int(np.prod(extra)) if extra else 1 + x = image.reshape( + 1, C, + image.shape[0], + image.shape[1], + image.shape[2], + ) + kernel = (pz, px, py) + stride = (pz, px, py) + + pooled = -F.max_pool3d(-x, kernel, stride) + + # Restore original layout + return pooled.reshape(pooled.shape[2:] + extra) class SumPooling(Pool): - def __init__(self, ksize: int = 2): - super().__init__(np.sum, ksize) + """Sum pooling feature. + Downsamples the input by applying sum pooling over non-overlapping + blocks of size `ksize`, preserving the center of the image and never + pooling over channel dimensions. -class MinPooling(Pool): - def __init__(self, ksize: int = 2): - super().__init__(np.min, ksize) + Works with NumPy and PyTorch backends. + """ + # ---------- NumPy backend ---------- -class MaxPooling(Pool): - def __init__(self, ksize: int = 2): - super().__init__(np.max, ksize) + def _get_numpy( + self, + image: np.ndarray, + **kwargs: Any, + ) -> np.ndarray: + image = self._crop_center(image) + px, py, pz = self._get_pool_size(image) + + # 2D or effectively 2D (channels-last) + if image.ndim < 3 or pz == 1: + block_size = (px, py) + (1,) * (image.ndim - 2) + else: + # 3D volume (optionally with channels) + block_size = (pz, px, py) + (1,) * (image.ndim - 3) + + return skimage.measure.block_reduce( + image, + block_size=block_size, + func=np.sum, + ) + + # ---------- Torch backend ---------- + + def _get_torch( + self, + image: torch.Tensor, + **kwargs: Any, + ) -> torch.Tensor: + import torch.nn.functional as F + + image = self._crop_center(image) + px, py, pz = self._get_pool_size(image) + + is_3d = image.ndim >= 3 and pz > 1 + + # Flatten extra (channel / feature) dimensions into C + if not is_3d: + extra = image.shape[2:] + C = int(np.prod(extra)) if extra else 1 + x = image.reshape(1, C, image.shape[0], image.shape[1]) + kernel = (px, py) + stride = (px, py) + pooled = F.avg_pool2d(x, kernel, stride) * (px * py) + else: + extra = image.shape[3:] + C = int(np.prod(extra)) if extra else 1 + x = image.reshape( + 1, C, + image.shape[0], + image.shape[1], + image.shape[2], + ) + kernel = (pz, px, py) + stride = (pz, px, py) + pooled = F.avg_pool3d(x, kernel, stride) * (pz * px * py) + + # Restore original layout + return pooled.reshape(pooled.shape[2:] + extra) class MedianPooling(Pool): - def __init__(self, ksize: int = 2): - super().__init__(np.median, ksize) + """Median pooling feature. + Downsamples the input by applying median pooling over non-overlapping + blocks of size `ksize`, preserving the center of the image and never + pooling over channel dimensions. 
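+
+    Works with NumPy and PyTorch backends.
+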
+    Notes
+    -----
+    - NumPy backend uses `skimage.measure.block_reduce`.
+    - Torch backend performs explicit unfolding followed by `median`.
+    - Torch median pooling is significantly more expensive than mean/max.
+
+    Median pooling is not differentiable and should not be used inside
+    trainable neural networks requiring gradient-based optimization.
+
+    """
+
+    # Emit the Torch-backend warning at most once per instance.
+    _warned: bool = False
+
+    # ---------- NumPy backend ----------
+
+    def _get_numpy(
+        self,
+        image: np.ndarray,
+        **kwargs: Any,
+    ) -> np.ndarray:
+        image = self._crop_center(image)
+        px, py, pz = self._get_pool_size(image)
+
+        if image.ndim < 3 or pz == 1:
+            block_size = (px, py) + (1,) * (image.ndim - 2)
+        else:
+            block_size = (pz, px, py) + (1,) * (image.ndim - 3)
+
+        return skimage.measure.block_reduce(
+            image,
+            block_size=block_size,
+            func=np.median,
+        )
 
-class Resize(Feature):
+    # ---------- Torch backend ----------
+
+    def _get_torch(
+        self,
+        image: torch.Tensor,
+        **kwargs: Any,
+    ) -> torch.Tensor:
+
+        if not self._warned:
+            warnings.warn(
+                "MedianPooling is not differentiable and is expensive on the "
+                "Torch backend. Avoid using it inside trainable models.",
+                UserWarning,
+                stacklevel=2,
+            )
+            self._warned = True
+
+        image = self._crop_center(image)
+        px, py, pz = self._get_pool_size(image)
+
+        is_3d = image.ndim >= 3 and pz > 1
+
+        if not is_3d:
+            # 2D case (with optional channels)
+            extra = image.shape[2:]
+            C = int(np.prod(extra)) if extra else 1
+
+            x = image.reshape(1, C, image.shape[0], image.shape[1])
+
+            # unfold: (B, C, H', W', px, py)
+            x_u = (
+                x.unfold(2, px, px)
+                .unfold(3, py, py)
+            )
+
+            x_u = x_u.contiguous().view(
+                1, C,
+                x_u.shape[2],
+                x_u.shape[3],
+                -1,
+            )
+
+            pooled = x_u.median(dim=-1).values
+
+        else:
+            # 3D case (with optional channels)
+            extra = image.shape[3:]
+            C = int(np.prod(extra)) if extra else 1
+
+            x = image.reshape(
+                1, C,
+                image.shape[0],
+                image.shape[1],
+                image.shape[2],
+            )
+
+            # unfold: (B, C, Z', Y', X', pz, px, py)
+            x_u = (
+                x.unfold(2, pz, pz)
+                .unfold(3, px, px)
+                .unfold(4, py, py)
+            )
+
+            x_u = x_u.contiguous().view(
+                1, C,
+                x_u.shape[2],
+                x_u.shape[3],
+                x_u.shape[4],
+                -1,
+            )
+
+            pooled = x_u.median(dim=-1).values
+
+        return pooled.reshape(pooled.shape[2:] + extra)
+
+
+class Resize(BackendDispatched, Feature):
     """Resize an image to a specified size.
 
     `Resize` resizes images following the channels-last semantic
@@ -1517,6 +1833,9 @@ class Resize(Feature):
 
     """
 
+    _NUMPY_IMPL = "_get_numpy"
+    _TORCH_IMPL = "_get_torch"
+
     def __init__(
         self: Resize,
         dsize: PropertyLike[tuple[int, int]] = (256, 256),
@@ -1527,8 +1846,8 @@ def __init__(
 
         Parameters
         ----------
         dsize: PropertyLike[tuple[int, int]]
-            The target size. Format depends on backend: `(width, height)` for
-            NumPy, `(height, width)` for PyTorch. Default is (256, 256).
+            The target size, as `(width, height)` on both backends.
+            Default is (256, 256).
         **kwargs: Any
             Additional arguments passed to the parent `Feature` class.
 
@@ -1583,83 +1902,102 @@ def get(
         `align_corners=False`, which closely matches OpenCV’s default
         behavior.
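+
+        Examples
+        --------
+        A minimal sketch (`dsize` is `(width, height)` on both backends):
+
+        >>> import numpy as np
+        >>> Resize(dsize=(128, 64))(np.zeros((256, 256))).shape
+        (64, 128)
+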
""" + return self._dispatch_backend(image, dsize=dsize, **kwargs) - target_w, target_h = dsize + # ---------- NumPy backend (OpenCV) ---------- - # Torch backend - if apc.is_torch_array(image): - import torch.nn.functional as F + def _get_numpy( + self, + image: np.ndarray, + dsize: tuple[int, int], + **kwargs: Any, + ) -> np.ndarray: + + target_w, target_h = dsize - original_ndim = image.ndim - has_channels = ( - image.ndim >= 3 and image.shape[-1] <= 4 + # Prefer OpenCV if available + if OPENCV_AVAILABLE: + import cv2 + return utils.safe_call( + cv2.resize, + positional_args=[image, (target_w, target_h)], + **kwargs, ) + if not OPENCV_AVAILABLE and kwargs: + warnings.warn("OpenCV not available: resize kwargs may be ignored.", UserWarning) - # Bring to (N, C, H, W) - if image.ndim == 2: - # (H, W) -> (1, 1, H, W) - x = image.unsqueeze(0).unsqueeze(0) + # Fallback: skimage (always available in DT) + from skimage.transform import resize as sk_resize - elif image.ndim == 3 and has_channels: - # (H, W, C) -> (1, C, H, W) - x = image.permute(2, 0, 1).unsqueeze(0) + if image.ndim == 2: + out_shape = (target_h, target_w) + else: + out_shape = (target_h, target_w) + image.shape[2:] - elif image.ndim == 3: - # (Z, H, W) -> treat Z as batch - x = image.unsqueeze(1) + out = sk_resize( + image, + out_shape, + preserve_range=True, + anti_aliasing=True, + ) - elif image.ndim == 4 and has_channels: - # (Z, H, W, C) -> (Z, C, H, W) - x = image.permute(0, 3, 1, 2) + return out.astype(image.dtype, copy=False) - else: - raise ValueError( - f"Unsupported tensor shape {image.shape} for Resize." - ) + # ---------- Torch backend ---------- - # Resize spatial dimensions - resized = F.interpolate( - x, - size=(target_h, target_w), - mode="bilinear", - align_corners=False, - ) + def _get_torch( + self, + image: torch.Tensor, + dsize: tuple[int, int], + **kwargs: Any, + ) -> torch.Tensor: + import torch.nn.functional as F - # Restore original layout - if original_ndim == 2: - return resized.squeeze(0).squeeze(0) + target_w, target_h = dsize - if original_ndim == 3 and has_channels: - return resized.squeeze(0).permute(1, 2, 0) + original_ndim = image.ndim + has_channels = image.ndim >= 3 and image.shape[-1] <= 4 - if original_ndim == 3: - return resized.squeeze(1) + # Convert to (N, C, H, W) + if image.ndim == 2: + x = image.unsqueeze(0).unsqueeze(0) # (1, 1, H, W) - if original_ndim == 4: - return resized.permute(0, 2, 3, 1) + elif image.ndim == 3 and has_channels: + x = image.permute(2, 0, 1).unsqueeze(0) # (1, C, H, W) - raise RuntimeError("Unexpected shape restoration path.") + elif image.ndim == 3: + x = image.unsqueeze(1) # (Z, 1, H, W) - # NumPy / OpenCV backend - else: - import cv2 + elif image.ndim == 4 and has_channels: + x = image.permute(0, 3, 1, 2) # (Z, C, H, W) - # OpenCV expects (width, height) - return utils.safe_call( - cv2.resize, - positional_args=[image, (target_w, target_h)], - **kwargs, + else: + raise ValueError( + f"Unsupported tensor shape {image.shape} for Resize." 
             )
 
+        # Resize spatial dimensions
+        x = F.interpolate(
+            x,
+            size=(target_h, target_w),
+            mode="bilinear",
+            align_corners=False,
+        )
 
-if OPENCV_AVAILABLE:
-    _map_mode_to_cv2_borderType = {
-        "reflect": cv2.BORDER_REFLECT,
-        "wrap": cv2.BORDER_WRAP,
-        "constant": cv2.BORDER_CONSTANT,
-        "mirror": cv2.BORDER_REFLECT_101,
-        "nearest": cv2.BORDER_REPLICATE,
-    }
+        # Restore original layout
+        if original_ndim == 2:
+            return x.squeeze(0).squeeze(0)
+
+        if original_ndim == 3 and has_channels:
+            return x.squeeze(0).permute(1, 2, 0)
+
+        if original_ndim == 3:
+            return x.squeeze(1)
+
+        if original_ndim == 4:
+            return x.permute(0, 2, 3, 1)
+
+        raise RuntimeError("Unexpected shape restoration path.")
 
 
 #TODO ***JH*** revise BlurCV2 - torch, typing, docstring, unit test
@@ -1679,7 +2017,7 @@ class BlurCV2(Feature):
 
     Methods
     -------
-    `get(image: np.ndarray | Image, **kwargs: Any) --> np.ndarray`
+    `get(image: np.ndarray, **kwargs: Any) -> np.ndarray`
         Applies the blurring filter to the input image.
 
     Examples
@@ -1710,52 +2048,17 @@ class BlurCV2(Feature):
 
     """
 
-    def __new__(
-        cls: type,
-        *args: tuple,
-        **kwargs: Any,
-    ):
-        """Ensures that OpenCV (cv2) is available before instantiating the
-        class.
-
-        Overrides the default object creation process to check that the `cv2`
-        module is available before creating the class. If OpenCV is not
-        installed, it raises an ImportError with instructions for installation.
-
-        Parameters
-        ----------
-        *args : tuple
-            Positional arguments passed to the class constructor.
-        **kwargs : dict
-            Keyword arguments passed to the class constructor.
-
-        Returns
-        -------
-        BlurCV2
-            An instance of the BlurCV2 feature class.
-
-        Raises
-        ------
-        ImportError
-            If the OpenCV (`cv2`) module is not available in the current
-            environment.
-
-        """
-
-        print(cls.__name__)
-
-        if not OPENCV_AVAILABLE:
-            raise ImportError(
-                "OpenCV not installed on device. Since OpenCV is an optional "
-                f"dependency of DeepTrack2. To use {cls.__name__}, "
-                "you need to install it manually."
-            )
-
-        return super().__new__(cls)
+    _MODE_TO_BORDER = {
+        "reflect": "BORDER_REFLECT",
+        "wrap": "BORDER_WRAP",
+        "constant": "BORDER_CONSTANT",
+        "mirror": "BORDER_REFLECT_101",
+        "nearest": "BORDER_REPLICATE",
+    }
 
     def __init__(
         self: BlurCV2,
-        filter_function: Callable,
+        filter_function: Callable | str,
         mode: PropertyLike[str] = "reflect",
         **kwargs: Any,
     ):
@@ -1775,13 +2078,20 @@ def __init__(
 
         """
 
+        if not OPENCV_AVAILABLE:
+            raise ImportError(
+                "OpenCV is not installed. OpenCV is an optional dependency "
+                f"of DeepTrack2; to use {self.__class__.__name__}, "
+                "you need to install it manually."
+            )
+
         self.filter = filter_function
-        borderType = _map_mode_to_cv2_borderType[mode]
-        super().__init__(borderType=borderType, **kwargs)
+        self.mode = mode
+        super().__init__(**kwargs)
 
     def get(
         self: BlurCV2,
-        image: np.ndarray | Image,
+        image: np.ndarray,
         **kwargs: Any,
     ) -> np.ndarray:
         """Applies the blurring filter to the input image.
 
 
         Parameters
         ----------
-        image: np.ndarray | Image
-            The input image to blur. Can be a NumPy array or DeepTrack Image.
+        image: np.ndarray
+            The input image to blur. Must be a NumPy array.
         **kwargs: Any
             Additional parameters for the blurring function.
 
@@ -1805,12 +2115,31 @@ def get(
 
         if apc.is_torch_array(image):
             raise TypeError(
                 "BlurCV2 only supports NumPy arrays. "
-                "For Torch tensors, use Blur or GaussianBlur instead."
+                "Use GaussianBlur / AverageBlur for Torch."
) + import cv2 + + filter_fn = getattr(cv2, self.filter) if isinstance(self.filter, str) else self.filter + + try: + border_attr = self._MODE_TO_BORDER[self.mode] + except KeyError as e: + raise ValueError(f"Unsupported border mode '{self.mode}'") from e + + try: + border = getattr(cv2, border_attr) + except AttributeError as e: + raise RuntimeError(f"OpenCV missing border constant '{border_attr}'") from e + + # preserve legacy behavior kwargs.pop("name", None) - result = self.filter(src=image, **kwargs) - return result + + return filter_fn( + src=image, + borderType=border, + **kwargs, + ) #TODO ***JH*** revise BilateralBlur - torch, typing, docstring, unit test @@ -1894,7 +2223,7 @@ def __init__( """ super().__init__( - cv2.bilateralFilter, + filter_function="bilateralFilter", d=d, sigmaColor=sigma_color, sigmaSpace=sigma_space, @@ -1903,13 +2232,25 @@ def __init__( def isotropic_dilation( - mask, + mask: np.ndarray | torch.Tensor, radius: float, *, - backend: str, + backend: Literal["numpy", "torch"], device=None, dtype=None, -): +) -> np.ndarray | torch.Tensor: + """ + Binary dilation using an isotropic (NumPy) or box-shaped (Torch) kernel. + + Notes + ----- + - NumPy backend uses a true Euclidean ball. + - Torch backend uses a cubic structuring element (approximate). + - Torch backend supports 3D masks only. + - Operation is non-differentiable. + + """ + if radius <= 0: return mask @@ -1938,13 +2279,25 @@ def isotropic_dilation( def isotropic_erosion( - mask, + mask: np.ndarray | torch.Tensor, radius: float, *, - backend: str, + backend: Literal["numpy", "torch"], device=None, dtype=None, -): +) -> np.ndarray | torch.Tensor: + """ + Binary erosion using an isotropic (NumPy) or box-shaped (Torch) kernel. + + Notes + ----- + - NumPy backend uses a true Euclidean ball. + - Torch backend uses a cubic structuring element (approximate). + - Torch backend supports 3D masks only. + - Operation is non-differentiable. + + """ + if radius <= 0: return mask diff --git a/deeptrack/optics.py b/deeptrack/optics.py index a7394689..70a54de6 100644 --- a/deeptrack/optics.py +++ b/deeptrack/optics.py @@ -395,16 +395,15 @@ def get( volume_samples, **additional_sample_kwargs, ) + if volume_samples: + # Interpret the merged volume semantically + sample_volume = self._extract_contrast_volume( + ScatteredVolume( + array=sample_volume, + properties=volume_samples[0].properties, + ), + ) - print('prop', volume_samples[0].properties) - - # Interpret the merged volume semantically - sample_volume = self._extract_contrast_volume( - ScatteredVolume( - array=sample_volume, - properties=volume_samples[0].properties, - ), - ) # Let the objective know about the limits of the volume and all the fields. 
propagate_data_to_dependencies( @@ -723,7 +722,6 @@ def _process_properties( wavelength = propertydict["wavelength"] voxel_size = get_active_voxel_size() radius = NA / wavelength * np.array(voxel_size) - print('Pupil radius (in pixels):', radius) if np.any(radius[:2] > 0.5): required_upscale = np.max(np.ceil(radius[:2] * 2)) @@ -1074,7 +1072,6 @@ class Fluorescence(Optics): """ - def validate_input(self, scattered): """Semantic validation for fluorescence microscopy.""" @@ -1085,9 +1082,9 @@ def validate_input(self, scattered): ) - def extract_contrast_volume(self, scattered: ScatteredVolume) -> np.ndarray: - voxel_size = np.asarray(get_active_voxel_size(), float) - voxel_volume = np.prod(voxel_size) + def extract_contrast_volume(self, scattered: ScatteredVolume, **kwargs) -> np.ndarray: + scale = np.asarray(get_active_scale(), float) + scale_volume = np.prod(scale) intensity = scattered.get_property("intensity", None) value = scattered.get_property("value", None) @@ -1103,7 +1100,7 @@ def extract_contrast_volume(self, scattered: ScatteredVolume) -> np.ndarray: # Preferred, physically meaningful case if intensity is not None: - return intensity * voxel_volume * scattered.array + return intensity * scale_volume * scattered.array # Fallback: legacy / dimensionless brightness warnings.warn( @@ -1379,7 +1376,6 @@ def extract_contrast_volume( refractive_index_medium: float, **kwargs: Any, ) -> np.ndarray: - print('ri_medium', refractive_index_medium) ri = scattered.get_property("refractive_index", None) value = scattered.get_property("value", None) @@ -2093,6 +2089,9 @@ class NonOverlapping(Feature): - This feature performs bounding cube checks first to quickly reject obvious overlaps before voxel-level checks. - If the bounding cubes overlap, precise voxel-based checks are performed. + - The feature may be computationally intensive for large numbers of volumes + or high-density placements. + - The feature is not differentiable. Examples --------- @@ -2375,7 +2374,7 @@ def _check_non_overlapping( new_volumes.append(new_volume) - list_of_volumes = new_volumes + list_of_volumes = new_volumes min_distance = 1 # The position of the top left corner of each volume (index (0, 0, 0)).