From 8a641a98a1449ba2f683cc0a065fcde87509aae7 Mon Sep 17 00:00:00 2001
From: Mirja Granfors <95694095+mirjagranfors@users.noreply.github.com>
Date: Wed, 5 Nov 2025 16:15:19 +0100
Subject: [PATCH 01/24] Update README (#442)
* Update README
Added links to the tutorials for creating custom scatterers.
* Update README
* Update README.md
---------
Co-authored-by: Alex <95913221+Pwhsky@users.noreply.github.com>
---
README.md | 8 ++++++++
1 file changed, 8 insertions(+)
diff --git a/README.md b/README.md
index 674d41d3..a9f507c2 100644
--- a/README.md
+++ b/README.md
@@ -97,6 +97,14 @@ Here you find a series of notebooks that give you an overview of the core featur
Using PyTorch gradients to fit a Gaussian generated by a DeepTrack2 pipeline.
+- DTGS171A **[Creating Custom Scatterers](https://github.com/DeepTrackAI/DeepTrack2/blob/develop/tutorials/1-getting-started/DTGS171A_custom_scatterers.ipynb)**
+
+ Creating custom scatterers of arbitrary shapes.
+
+- DTGS171B **[Creating Custom Scatterers: Bacteria](https://github.com/DeepTrackAI/DeepTrack2/blob/develop/tutorials/1-getting-started/DTGS171B_custom_scatterers_bacteria.ipynb)**
+
+ Creating custom scatterers in the shape of bacteria.
+
# Examples
These are examples of how DeepTrack2 can be used on real datasets:
From 5e402dba9807de374fceea5e7843a5c0677070a8 Mon Sep 17 00:00:00 2001
From: github-actions
Date: Wed, 5 Nov 2025 15:15:38 +0000
Subject: [PATCH 02/24] Auto-update README-pypi.md
---
README-pypi.md | 8 ++++++++
1 file changed, 8 insertions(+)
diff --git a/README-pypi.md b/README-pypi.md
index be0de197..f6173669 100644
--- a/README-pypi.md
+++ b/README-pypi.md
@@ -93,6 +93,14 @@ Here you find a series of notebooks that give you an overview of the core featur
Using PyTorch gradients to fit a Gaussian generated by a DeepTrack2 pipeline.
+- DTGS171A **[Creating Custom Scatterers](https://github.com/DeepTrackAI/DeepTrack2/blob/develop/tutorials/1-getting-started/DTGS171A_custom_scatterers.ipynb)**
+
+ Creating custom scatterers of arbitrary shapes.
+
+- DTGS171B **[Creating Custom Scatterers: Bacteria](https://github.com/DeepTrackAI/DeepTrack2/blob/develop/tutorials/1-getting-started/DTGS171B_custom_scatterers_bacteria.ipynb)**
+
+ Creating custom scatterers in the shape of bacteria.
+
# Examples
These are examples of how DeepTrack2 can be used on real datasets:
From a045633e747dde090c2b84bd8c6b3ced0a485d5c Mon Sep 17 00:00:00 2001
From: Carlo
Date: Fri, 2 Jan 2026 20:33:56 +0100
Subject: [PATCH 03/24] removed Image
---
deeptrack/features.py | 10 ++++----
deeptrack/optics.py | 55 ++++++++++++++++++-----------------------
deeptrack/scatterers.py | 2 +-
3 files changed, 30 insertions(+), 37 deletions(-)
diff --git a/deeptrack/features.py b/deeptrack/features.py
index 4bdfad38..26464d4c 100644
--- a/deeptrack/features.py
+++ b/deeptrack/features.py
@@ -8093,12 +8093,12 @@ def get(
with units.context(ctx):
image = self.feature(image)
- # Downscale the result to the original resolution.
- import skimage.measure
+ # # Downscale the result to the original resolution.
+ # import skimage.measure
- image = skimage.measure.block_reduce(
- image, (factor[0], factor[1]) + (1,) * (image.ndim - 2), np.mean
- )
+ # image = skimage.measure.block_reduce(
+ # image, (factor[0], factor[1]) + (1,) * (image.ndim - 2), np.mean
+ # )
return image
diff --git a/deeptrack/optics.py b/deeptrack/optics.py
index 5149bdae..be1217b2 100644
--- a/deeptrack/optics.py
+++ b/deeptrack/optics.py
@@ -344,11 +344,11 @@ def get(
volume_samples,
**additional_sample_kwargs,
)
- sample_volume = Image(sample_volume)
+ # sample_volume = Image(sample_volume)
# Merge all properties into the volume.
- for scatterer in volume_samples + field_samples:
- sample_volume.merge_properties_from(scatterer)
+ # for scatterer in volume_samples + field_samples:
+ # sample_volume.merge_properties_from(scatterer)
# Let the objective know about the limits of the volume and all the fields.
propagate_data_to_dependencies(
@@ -365,18 +365,18 @@ def get(
imaged_sample
)
- # Merge with input
- if not image:
- if not self._wrap_array_with_image and isinstance(imaged_sample, Image):
- return imaged_sample._value
- else:
- return imaged_sample
+ # # Merge with input
+ # if not image:
+ # if not self._wrap_array_with_image and isinstance(imaged_sample, Image):
+ # return imaged_sample._value
+ # else:
+ # return imaged_sample
- if not isinstance(image, list):
- image = [image]
- for i in range(len(image)):
- image[i].merge_properties_from(imaged_sample)
- return image
+ # if not isinstance(image, list):
+ # image = [image]
+ # for i in range(len(image)):
+ # image[i].merge_properties_from(imaged_sample)
+ # return image
# def _no_wrap_format_input(self, *args, **kwargs) -> list:
# return self._image_wrapped_format_input(*args, **kwargs)
@@ -757,19 +757,18 @@ def _pupil(
W, H = np.meshgrid(y, x)
RHO = (W ** 2 + H ** 2).astype(complex)
- pupil_function = Image((RHO < 1) + 0.0j, copy=False)
+ pupil_function = (RHO < 1) + 0.0j
# Defocus
- z_shift = Image(
+ z_shift = (
2
* np.pi
* refractive_index_medium
/ wavelength
* voxel_size[2]
- * np.sqrt(1 - (NA / refractive_index_medium) ** 2 * RHO),
- copy=False,
+ * np.sqrt(1 - (NA / refractive_index_medium) ** 2 * RHO)
)
- z_shift._value[z_shift._value.imag != 0] = 0
+ z_shift[z_shift.imag != 0] = 0
try:
z_shift = np.nan_to_num(z_shift, False, 0, 0, 0)
@@ -1118,9 +1117,7 @@ def get(
]
z_limits = limits[2, :]
- output_image = Image(
- np.zeros((*padded_volume.shape[0:2], 1)), copy=False
- )
+ output_image = np.zeros((*padded_volume.shape[0:2], 1))
index_iterator = range(padded_volume.shape[2])
@@ -1156,7 +1153,7 @@ def get(
field = np.fft.ifft2(convolved_fourier_field)
# # Discard remaining imaginary part (should be 0 up to rounding error)
field = np.real(field)
- output_image._value[:, :, 0] += field[
+ output_image[:, :, 0] += field[
: padded_volume.shape[0], : padded_volume.shape[1]
]
@@ -1353,9 +1350,7 @@ def get(
]
z_limits = limits[2, :]
- output_image = Image(
- np.zeros((*padded_volume.shape[0:2], 1))
- )
+ output_image = np.zeros((*padded_volume.shape[0:2], 1))
index_iterator = range(padded_volume.shape[2])
z_iterator = np.linspace(
@@ -1426,7 +1421,7 @@ def get(
: padded_volume.shape[0], : padded_volume.shape[1]
]
output_image = np.expand_dims(output_image, axis=-1)
- output_image = Image(output_image[pad[0] : -pad[2], pad[1] : -pad[3]])
+ output_image = output_image[pad[0] : -pad[2], pad[1] : -pad[3]]
if not kwargs.get("return_field", False):
output_image = np.square(np.abs(output_image))
@@ -1436,7 +1431,7 @@ def get(
# output_image = output_image * np.exp(1j * -np.pi / 4)
# output_image = output_image + 1
- output_image.properties = illuminated_volume.properties
+ # output_image.properties = illuminated_volume.properties
return output_image
@@ -1959,14 +1954,12 @@ def _create_volume(
):
continue
- padded_scatterer = Image(
- np.pad(
+ padded_scatterer = np.pad(
scatterer,
[(2, 2), (2, 2), (2, 2)],
"constant",
constant_values=0,
)
- )
padded_scatterer.merge_properties_from(scatterer)
scatterer = padded_scatterer
diff --git a/deeptrack/scatterers.py b/deeptrack/scatterers.py
index 04a7c5ea..ed77e5ed 100644
--- a/deeptrack/scatterers.py
+++ b/deeptrack/scatterers.py
@@ -333,7 +333,7 @@ def _process_and_get(
new_image = new_image[:, ~np.all(new_image == 0, axis=(0, 2))]
new_image = new_image[:, :, ~np.all(new_image == 0, axis=(0, 1))]
- return [Image(new_image)]
+ return [new_image]
def _no_wrap_format_input(
self,
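
For reference, a minimal standalone sketch (illustrative values, not part of the patch) of the un-wrapped defocus masking that this patch moves from Image._value to plain NumPy arrays:

import numpy as np

# Sketch with illustrative values: with Image removed, the evanescent
# components of the defocus phase are zeroed with ordinary boolean
# indexing instead of writing through Image._value.
RHO = np.linspace(0.0, 2.0, 5).astype(complex)   # normalized radius squared
z_shift = np.sqrt(1 - RHO)                       # complex where RHO > 1
z_shift[z_shift.imag != 0] = 0                   # drop evanescent part
z_shift = np.nan_to_num(z_shift, nan=0.0, posinf=0.0, neginf=0.0)
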
From d61ef780a50b9e72dab03035b4b9d1d7ccf6c381 Mon Sep 17 00:00:00 2001
From: Carlo
Date: Sat, 3 Jan 2026 02:40:04 +0100
Subject: [PATCH 04/24] Replace Image wrapping with ScatteredVolume/ScatteredField and centered pooling
---
deeptrack/features.py | 2 +-
deeptrack/optics.py | 311 ++++++++++++++++++++++++++++------------
deeptrack/scatterers.py | 136 +++++++++++++-----
3 files changed, 326 insertions(+), 123 deletions(-)
diff --git a/deeptrack/features.py b/deeptrack/features.py
index 26464d4c..f76245b8 100644
--- a/deeptrack/features.py
+++ b/deeptrack/features.py
@@ -8082,7 +8082,7 @@ def get(
# Ensure factor is a tuple of three integers.
if np.size(factor) == 1:
- factor = (factor,) * 3
+ factor = (factor, factor, 1)
elif len(factor) != 3:
raise ValueError(
"Factor must be an integer or a tuple of three integers."
diff --git a/deeptrack/optics.py b/deeptrack/optics.py
index be1217b2..5564e0b2 100644
--- a/deeptrack/optics.py
+++ b/deeptrack/optics.py
@@ -137,11 +137,13 @@ def _pad_volume(
from __future__ import annotations
from pint import Quantity
-from typing import Any
+from typing import Any, TYPE_CHECKING
import warnings
import numpy as np
-from scipy.ndimage import convolve
+from scipy.ndimage import convolve # might be removed later
from deeptrack.backend.units import (
ConversionTable,
@@ -158,6 +160,16 @@ def _pad_volume(
from deeptrack import image
from deeptrack import units_registry as u
+from deeptrack import TORCH_AVAILABLE, image
+from deeptrack.backend import xp
+from deeptrack.scatterers import ScatteredVolume, ScatteredField
+
+if TORCH_AVAILABLE:
+ import torch
+ import torch.nn.functional as F
+
+if TYPE_CHECKING:
+ import torch
+
#TODO ***??*** revise Microscope - torch, typing, docstring, unit test
class Microscope(StructuralFeature):
@@ -280,16 +292,16 @@ def get(
# Grab properties from the objective to pass to the sample
additional_sample_kwargs = self._objective.properties()
- # Calculate required output image for the given upscale
- # This way of providing the upscale will be deprecated in the future
- # in favor of dt.Upscale().
- _upscale_given_by_optics = additional_sample_kwargs["upscale"]
- if np.array(_upscale_given_by_optics).size == 1:
- _upscale_given_by_optics = (_upscale_given_by_optics,) * 3
+ # # Calculate required output image for the given upscale
+ # # This way of providing the upscale will be deprecated in the future
+ # # in favor of dt.Upscale().
+ # _upscale_given_by_optics = additional_sample_kwargs["upscale"]
+ # if np.array(_upscale_given_by_optics).size == 1:
+ # _upscale_given_by_optics = (_upscale_given_by_optics,) * 3
with u.context(
create_context(
- *additional_sample_kwargs["voxel_size"], *_upscale_given_by_optics
+ *additional_sample_kwargs["voxel_size"]#, *_upscale_given_by_optics
)
):
@@ -329,14 +341,14 @@ def get(
volume_samples = [
scatterer
for scatterer in list_of_scatterers
- if not scatterer.get_property("is_field", default=False)
+ if isinstance(scatterer, ScatteredVolume)
]
# All scatterers that are defined as fields.
field_samples = [
scatterer
for scatterer in list_of_scatterers
- if scatterer.get_property("is_field", default=False)
+ if isinstance(scatterer, ScatteredField)
]
# Merge all volumes into a single volume.
@@ -359,33 +371,32 @@ def get(
imaged_sample = self._objective.resolve(sample_volume)
- # Upscale given by the optics needs to be handled separately.
- if _upscale_given_by_optics != (1, 1, 1):
- imaged_sample = AveragePooling((*_upscale_given_by_optics[:2], 1))(
- imaged_sample
- )
- # # Merge with input
- # if not image:
- # if not self._wrap_array_with_image and isinstance(imaged_sample, Image):
- # return imaged_sample._value
- # else:
- # return imaged_sample
+ # Collect main_property from scatterers
+ main_properties = {
+ s.main_property
+ for s in list_of_scatterers
+ if hasattr(s, "main_property")
+ }
- # if not isinstance(image, list):
- # image = [image]
- # for i in range(len(image)):
- # image[i].merge_properties_from(imaged_sample)
- # return image
+ if len(main_properties) != 1:
+ raise ValueError(
+ f"Inconsistent main_property across scatterers: {main_properties}"
+ )
- # def _no_wrap_format_input(self, *args, **kwargs) -> list:
- # return self._image_wrapped_format_input(*args, **kwargs)
+ main_property = main_properties.pop()
- # def _no_wrap_process_and_get(self, *args, **feature_input) -> list:
- # return self._image_wrapped_process_and_get(*args, **feature_input)
+ # Handling upscale from dt.Upscale() here to eliminate Image
+ # wrapping issues.
+ if np.any(np.array(upscale) != 1):
+ ux, uy = upscale[:2]
+ if main_property == "intensity":
+ print("Using sum pooling for intensity downscaling.")
+ imaged_sample = SumPoolingCM((ux, uy, 1))(imaged_sample)
+ else:
+ imaged_sample = AveragePoolingCM((ux, uy, 1))(imaged_sample)
- # def _no_wrap_process_output(self, *args, **feature_input):
- # return self._image_wrapped_process_output(*args, **feature_input)
+ return imaged_sample
#TODO ***??*** revise Optics - torch, typing, docstring, unit test
@@ -1158,7 +1169,7 @@ def get(
]
output_image = output_image[pad[0] : -pad[2], pad[1] : -pad[3]]
- output_image.properties = illuminated_volume.properties + pupils.properties
+ # output_image.properties = illuminated_volume.properties + pupils.properties
return output_image
@@ -1797,7 +1808,7 @@ def get(
#TODO ***??*** revise _get_position - torch, typing, docstring, unit test
def _get_position(
- image: Image,
+ scatterer: ScatteredVolume,
mode: str = "corner",
return_z: bool = False,
) -> np.ndarray:
@@ -1821,26 +1832,23 @@ def _get_position(
num_outputs = 2 + return_z
- if mode == "corner" and image.size > 0:
+ if mode == "corner" and scatterer.array.size > 0:
import scipy.ndimage
- image = image.to_numpy()
-
- shift = scipy.ndimage.center_of_mass(np.abs(image))
+ shift = scipy.ndimage.center_of_mass(np.abs(scatterer.array))
if np.isnan(shift).any():
- shift = np.array(image.shape) / 2
+ shift = np.array(scatterer.array.shape) / 2
else:
shift = np.zeros((num_outputs))
- position = np.array(image.get_property("position", default=None))
+ position = np.array(scatterer.get_property("position", default=None))
if position is None:
return position
scale = np.array(get_active_scale())
-
if len(position) == 3:
position = position * scale + 0.5 * (scale - 1)
if return_z:
@@ -1851,7 +1859,7 @@ def _get_position(
elif len(position) == 2:
if return_z:
outp = (
- np.array([position[0], position[1], image.get_property("z", default=0)])
+ np.array([position[0], position[1], scatterer.get_property("z", default=0)])
* scale
- shift
+ 0.5 * (scale - 1)
@@ -1863,6 +1871,58 @@ def _get_position(
return position
+def _bilinear_interpolate_numpy(
+ scatterer: np.ndarray, x_off: float, y_off: float
+) -> np.ndarray:
+ """Apply bilinear subpixel interpolation in the x–y plane (NumPy)."""
+ kernel = np.array(
+ [
+ [0.0, 0.0, 0.0],
+ [0.0, (1 - x_off) * (1 - y_off), (1 - x_off) * y_off],
+ [0.0, x_off * (1 - y_off), x_off * y_off],
+ ]
+ )
+ out = np.zeros_like(scatterer)
+ for z in range(scatterer.shape[2]):
+ if np.iscomplexobj(scatterer):
+ out[:, :, z] = (
+ convolve(np.real(scatterer[:, :, z]), kernel, mode="constant")
+ + 1j
+ * convolve(np.imag(scatterer[:, :, z]), kernel, mode="constant")
+ )
+ else:
+ out[:, :, z] = convolve(scatterer[:, :, z], kernel, mode="constant")
+ return out
+
+
+def _bilinear_interpolate_torch(
+ scatterer: torch.Tensor, x_off: float, y_off: float
+) -> torch.Tensor:
+ """Apply bilinear subpixel interpolation in the x–y plane (Torch).
+
+ Uses grid_sample for autograd-friendly interpolation.
+ """
+ H, W, D = scatterer.shape
+
+ # Normalized shifts in [-1,1]
+ x_shift = 2 * x_off / (W - 1)
+ y_shift = 2 * y_off / (H - 1)
+
+ yy, xx = torch.meshgrid(
+ torch.linspace(-1, 1, H, device=scatterer.device, dtype=scatterer.dtype),
+ torch.linspace(-1, 1, W, device=scatterer.device, dtype=scatterer.dtype),
+ indexing="ij",
+ )
+ grid = torch.stack((xx + x_shift, yy + y_shift), dim=-1) # (H,W,2)
+ grid = grid.unsqueeze(0).repeat(D, 1, 1, 1) # (D,H,W,2)
+
+ inp = scatterer.permute(2, 0, 1).unsqueeze(1) # (D,1,H,W)
+
+ out = F.grid_sample(inp, grid, mode="bilinear",
+ padding_mode="zeros", align_corners=True)
+ return out.squeeze(1).permute(1, 2, 0) # (H,W,D)
+
+
#TODO ***??*** revise _create_volume - torch, typing, docstring, unit test
def _create_volume(
list_of_scatterers: list,
@@ -1922,24 +1982,20 @@ def _create_volume(
# This accounts for upscale doing AveragePool instead of SumPool. This is
# a bit of a hack, but it works for now.
- fudge_factor = scale[0] * scale[1] / scale[2]
+ # fudge_factor = scale[0] * scale[1] / scale[2]
for scatterer in list_of_scatterers:
position = _get_position(scatterer, mode="corner", return_z=True)
-
- if scatterer.get_property("intensity", None) is not None:
- intensity = scatterer.get_property("intensity")
- scatterer_value = intensity * fudge_factor
- elif scatterer.get_property("refractive_index", None) is not None:
- refractive_index = scatterer.get_property("refractive_index")
- scatterer_value = (
- refractive_index - refractive_index_medium
- )
- else:
+ if scatterer.main_property == "intensity":
+ scatterer_value = scatterer.get_property("intensity") #* fudge_factor
+ elif scatterer.main_property == "refractive_index":
+ scatterer_value = scatterer.get_property("refractive_index") - refractive_index_medium
+ else: # fallback to generic value
scatterer_value = scatterer.get_property("value")
- scatterer = scatterer * scatterer_value
+ # Scale the array accordingly
+ scatterer.array = scatterer.array * scatterer_value
if limits is None:
limits = np.zeros((3, 2), dtype=np.int32)
@@ -1947,24 +2003,23 @@ def _create_volume(
limits[:, 1] = np.floor(position).astype(np.int32) + 1
if (
- position[0] + scatterer.shape[0] < OR[0]
+ position[0] + scatterer.array.shape[0] < OR[0]
or position[0] > OR[2]
- or position[1] + scatterer.shape[1] < OR[1]
+ or position[1] + scatterer.array.shape[1] < OR[1]
or position[1] > OR[3]
):
continue
- padded_scatterer = np.pad(
- scatterer,
+ # Pad scatterer to avoid edge effects during interpolation
+ padded_scatterer = scatterer
+ padded_scatterer.array = np.pad(
+ scatterer.array._value, # this is a temporary fix, Image should be removed from features
[(2, 2), (2, 2), (2, 2)],
"constant",
constant_values=0,
)
- padded_scatterer.merge_properties_from(scatterer)
-
- scatterer = padded_scatterer
- position = _get_position(scatterer, mode="corner", return_z=True)
- shape = np.array(scatterer.shape)
+ position = _get_position(padded_scatterer, mode="corner", return_z=True)
+ shape = np.array(padded_scatterer.array.shape)
if position is None:
RuntimeWarning(
@@ -1973,36 +2028,44 @@ def _create_volume(
)
continue
- splined_scatterer = np.zeros_like(scatterer)
-
x_off = position[0] - np.floor(position[0])
y_off = position[1] - np.floor(position[1])
- kernel = np.array(
- [
- [0, 0, 0],
- [0, (1 - x_off) * (1 - y_off), (1 - x_off) * y_off],
- [0, x_off * (1 - y_off), x_off * y_off],
- ]
- )
+
+ if isinstance(padded_scatterer.array, np.ndarray): # get_backend is a method of Features and not exposed
+ splined_scatterer = _bilinear_interpolate_numpy(padded_scatterer.array, x_off, y_off)
+ elif TORCH_AVAILABLE and isinstance(padded_scatterer.array, torch.Tensor):
+ splined_scatterer = _bilinear_interpolate_torch(padded_scatterer.array, x_off, y_off)
+ else:
+ raise TypeError(
+ f"Unsupported array type {type(padded_scatterer.array)}. "
+ "Expected np.ndarray or torch.Tensor."
+ )
- for z in range(scatterer.shape[2]):
- if splined_scatterer.dtype == complex:
- splined_scatterer[:, :, z] = (
- convolve(
- np.real(scatterer[:, :, z]), kernel, mode="constant"
- )
- + convolve(
- np.imag(scatterer[:, :, z]), kernel, mode="constant"
- )
- * 1j
- )
- else:
- splined_scatterer[:, :, z] = convolve(
- scatterer[:, :, z], kernel, mode="constant"
- )
+ # kernel = np.array(
+ # [
+ # [0, 0, 0],
+ # [0, (1 - x_off) * (1 - y_off), (1 - x_off) * y_off],
+ # [0, x_off * (1 - y_off), x_off * y_off],
+ # ]
+ # )
+
+ # for z in range(padded_scatterer.array.shape[2]):
+ # if splined_scatterer.dtype == complex:
+ # splined_scatterer[:, :, z] = (
+ # convolve(
+ # np.real(padded_scatterer.array[:, :, z]), kernel, mode="constant"
+ # )
+ # + convolve(
+ # np.imag(padded_scatterer.array[:, :, z]), kernel, mode="constant"
+ # )
+ # * 1j
+ # )
+ # else:
+ # splined_scatterer[:, :, z] = convolve(
+ # padded_scatterer.array[:, :, z], kernel, mode="constant"
+ # )
- scatterer = splined_scatterer
position = np.floor(position)
new_limits = np.zeros(limits.shape, dtype=np.int32)
for i in range(3):
@@ -2032,6 +2095,7 @@ def _create_volume(
within_volume_position = position - limits[:, 0]
# NOTE: Maybe shouldn't be additive.
+ # give options: sum default, but also sum, mean, max, min
volume[
int(within_volume_position[0]) :
int(within_volume_position[0] + shape[0]),
@@ -2041,5 +2105,72 @@ def _create_volume(
int(within_volume_position[2]) :
int(within_volume_position[2] + shape[2]),
- ] += scatterer
+ ] += splined_scatterer
return volume, limits
+
+# this should be moved to math
+class _CenteredPoolingBase:
+ def __init__(self, pool_size: tuple[int, int, int]):
+ px, py, pz = pool_size
+ if pz != 1:
+ raise ValueError("Only pz=1 supported.")
+ self.px = int(px)
+ self.py = int(py)
+
+ def _crop_center(self, array):
+ H, W = array.shape[:2]
+ px, py = self.px, self.py
+
+ crop_h = (H // px) * px
+ crop_w = (W // py) * py
+
+ off_h = (H - crop_h) // 2
+ off_w = (W - crop_w) // 2
+
+ return array[off_h:off_h+crop_h, off_w:off_w+crop_w, ...]
+
+ def _pool_numpy(self, array, func):
+ import skimage.measure
+ array = self._crop_center(array)
+ pool_shape = (self.px, self.py) + (1,) * (array.ndim - 2)
+ return skimage.measure.block_reduce(array, pool_shape, func)
+
+ def _pool_torch(self, array, sum_pool=False):
+ px, py = self.px, self.py
+ array = self._crop_center(array)
+
+ extra = array.shape[2:]
+ C = int(np.prod(extra)) if extra else 1
+ x = array.reshape(1, C, array.shape[0], array.shape[1])
+
+ pooled = torch.nn.functional.avg_pool2d(
+ x, kernel_size=(px, py), stride=(px, py)
+ )
+ if sum_pool:
+ pooled = pooled * (px * py)
+
+ return pooled.reshape(
+ (pooled.shape[2], pooled.shape[3]) + extra
+ )
+
+class AveragePoolingCM(_CenteredPoolingBase):
+ """Center-preserving average pooling (intensive quantities)."""
+
+ def __call__(self, array):
+ if isinstance(array, np.ndarray):
+ return self._pool_numpy(array, np.mean)
+ elif TORCH_AVAILABLE and isinstance(array, torch.Tensor):
+ return self._pool_torch(array, sum_pool=False)
+ else:
+ raise TypeError("Unsupported array type.")
+
+class SumPoolingCM(_CenteredPoolingBase):
+ """Center-preserving sum pooling (extensive quantities)."""
+
+ def __call__(self, array):
+ if isinstance(array, np.ndarray):
+ return self._pool_numpy(array, np.sum)
+ elif TORCH_AVAILABLE and isinstance(array, torch.Tensor):
+ return self._pool_torch(array, sum_pool=True)
+ else:
+ raise TypeError("Unsupported array type.")
\ No newline at end of file
diff --git a/deeptrack/scatterers.py b/deeptrack/scatterers.py
index ed77e5ed..f3daa9bf 100644
--- a/deeptrack/scatterers.py
+++ b/deeptrack/scatterers.py
@@ -166,6 +166,7 @@
import numpy as np
from numpy.typing import NDArray
from pint import Quantity
+from dataclasses import dataclass, field
from deeptrack.holography import get_propagation_matrix
from deeptrack.backend.units import (
@@ -246,6 +247,9 @@ class Scatterer(Feature):
voxel_size=(u.meter, u.meter),
)
+ #: Default property name (subclasses override this)
+ main_property: str = "value"
+
def __init__(
self,
position: ArrayLike[float] = (32, 32),
@@ -258,7 +262,7 @@ def __init__(
**kwargs,
) -> None:
# Ignore warning to help with comparison with arrays.
- if upsample is not 1: # noqa: F632
+ if upsample != 1: # noqa: F632
warnings.warn(
f"Setting upsample != 1 is deprecated. "
f"Please, instead use dt.Upscale(f, factor={upsample})"
@@ -310,7 +314,7 @@ def _process_and_get(
voxel_size = get_active_voxel_size()
# Calls parent _process_and_get.
- new_image = super()._process_and_get(
+ new_image = super(Scatterer, self)._process_and_get(
*args,
voxel_size=voxel_size,
upsample=upsample,
@@ -333,32 +337,41 @@ def _process_and_get(
new_image = new_image[:, ~np.all(new_image == 0, axis=(0, 2))]
new_image = new_image[:, :, ~np.all(new_image == 0, axis=(0, 1))]
- return [new_image]
-
- def _no_wrap_format_input(
- self,
- *args,
- **kwargs
- ) -> list:
- return self._image_wrapped_format_input(*args, **kwargs)
-
- def _no_wrap_process_and_get(
- self,
- *args,
- **feature_input
- ) -> list:
- return self._image_wrapped_process_and_get(*args, **feature_input)
-
- def _no_wrap_process_output(
- self,
- *args,
- **feature_input
- ) -> list:
- return self._image_wrapped_process_output(*args, **feature_input)
+ # # Copy properties
+ # props = kwargs.copy()
+ return self._wrap_output(new_image, kwargs)
+
+ def _wrap_output(self, array, props) -> ScatteredBase:
+ """Must be overridden in subclasses to wrap output correctly."""
+ raise NotImplementedError
+
+class VolumeScatterer(Scatterer):
+ """Abstract scatterer producing ScatteredVolume outputs."""
+ def _wrap_output(self, array, props) -> ScatteredVolume:
+ return [ScatteredVolume(
+ array=array,
+ position=props.get("position", (0, 0)),
+ z=props.get("z", 0.0),
+ value=props.get("value", 1.0),
+ intensity=props.get("intensity", None),
+ refractive_index=props.get("refractive_index", None),
+ properties=props.copy(),
+ main_property=self.main_property,
+ )]
+
+class FieldScatterer(Scatterer):
+ def _wrap_output(self, array, props) -> ScatteredField:
+ return [ScatteredField(
+ array=array,
+ position=props.get("position", (0, 0)),
+ wavelength=props.get("wavelength", 0.0),
+ properties=props.copy(),
+ main_property=self.main_property,
+ )]
#TODO ***??*** revise PointParticle - torch, typing, docstring, unit test
-class PointParticle(Scatterer):
+class PointParticle(VolumeScatterer):
"""Generate a diffraction-limited point particle.
A point particle is approximated by the size of a single pixel or voxel.
@@ -382,6 +395,8 @@ class PointParticle(Scatterer):
"""
+ main_property = "intensity"
+
def __init__(
self: PointParticle,
**kwargs: Any,
@@ -405,7 +420,7 @@ def get(
#TODO ***??*** revise Ellipse - torch, typing, docstring, unit test
-class Ellipse(Scatterer):
+class Ellipse(VolumeScatterer):
"""Generates an elliptical disk scatterer
Parameters
@@ -446,6 +461,8 @@ class Ellipse(Scatterer):
rotation=(u.radian, u.radian),
)
+ main_property = "refractive_index"
+
def __init__(
self,
radius: float = 1e-6,
@@ -519,7 +536,7 @@ def get(
#TODO ***??*** revise Sphere - torch, typing, docstring, unit test
-class Sphere(Scatterer):
+class Sphere(VolumeScatterer):
"""Generates a spherical scatterer
Parameters
@@ -550,6 +567,8 @@ class Sphere(Scatterer):
radius=(u.meter, u.meter),
)
+ main_property = "refractive_index"
+
def __init__(
self,
radius: float = 1e-6,
@@ -584,7 +603,7 @@ def get(
#TODO ***??*** revise Ellipsoid - torch, typing, docstring, unit test
-class Ellipsoid(Scatterer):
+class Ellipsoid(VolumeScatterer):
"""Generates an ellipsoidal scatterer
Parameters
@@ -625,6 +644,8 @@ class Ellipsoid(Scatterer):
rotation=(u.radian, u.radian),
)
+ main_property = "refractive_index"
+
def __init__(
self,
radius: float = 1e-6,
@@ -741,7 +762,7 @@ def get(
#TODO ***??*** revise MieScatterer - torch, typing, docstring, unit test
-class MieScatterer(Scatterer):
+class MieScatterer(FieldScatterer):
"""Base implementation of a Mie particle.
New Mie-theory scatterers can be implemented by extending this class, and
@@ -835,6 +856,8 @@ class MieScatterer(Scatterer):
coherence_length=(u.meter, u.pixel),
)
+ main_property = "wavelength"
+
def __init__(
self,
coefficients,
@@ -864,11 +887,11 @@ def __init__(
"Please use input_polarization instead"
)
input_polarization = polarization_angle
- kwargs.pop("is_field", None)
+ kwargs.pop("is_field", None) # remove
kwargs.pop("crop_empty", None)
super().__init__(
- is_field=True,
+ is_field=True, # remove
crop_empty=False,
L=L,
offset_z=offset_z,
@@ -1188,7 +1211,6 @@ def get(
-mask.shape[1] // 2 : mask.shape[1] // 2,
]
mask = np.exp(-0.5 * (x ** 2 + y ** 2) / ((sigma) ** 2))
-
arr = arr * mask
fourier_field = np.fft.fft2(arr)
@@ -1412,3 +1434,53 @@ def inner(
refractive_index=refractive_index,
**kwargs,
)
+
+
+@dataclass
+class ScatteredBase:
+ """Base class for scatterers (volumes and fields)."""
+
+ array: ArrayLike
+ position: np.ndarray
+ z: float = 0.0
+ properties: dict[str, Any] = field(default_factory=dict)
+ main_property: str | None = None
+
+ def __post_init__(self):
+ self.position = np.array(self.position, dtype=float).reshape(-1)[:2]
+ self.z = float(np.atleast_1d(self.z).squeeze())
+
+ @property
+ def pos3d(self) -> np.ndarray:
+ return np.array([*self.position, self.z], dtype=float)
+
+ def as_array(self) -> ArrayLike:
+ """Return the underlying array.
+
+ Notes
+ -----
+ The raw array is also directly available as ``scatterer.array``.
+ This method exists mainly for API compatibility and clarity.
+
+ """
+
+ return self.array
+
+ def get_property(self, key: str, default: Any = None) -> Any:
+ return getattr(self, key, self.properties.get(key, default))
+
+
+@dataclass
+class ScatteredVolume(ScatteredBase):
+ """Volumetric object: intensity sources or refractive index contrasts."""
+
+ refractive_index: float | None = None
+ intensity: float | None = None
+ value: float | None = None
+
+
+@dataclass
+class ScatteredField(ScatteredBase):
+ """Complex wavefield (already propagated or emitted)."""
+
+ wavelength: float = 500e-9
\ No newline at end of file
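
A minimal standalone sketch (illustrative input, not the patch's exact call site) of the 3x3 kernel applied per z-slice by _bilinear_interpolate_numpy: the four non-zero taps are the bilinear weights for the fractional offsets, so a unit voxel is split over its four neighbours:

import numpy as np
from scipy.ndimage import convolve

x_off, y_off = 0.25, 0.5                  # fractional subpixel offsets
kernel = np.array([
    [0.0, 0.0, 0.0],
    [0.0, (1 - x_off) * (1 - y_off), (1 - x_off) * y_off],
    [0.0, x_off * (1 - y_off), x_off * y_off],
])
plane = np.zeros((5, 5))
plane[2, 2] = 1.0                         # a single bright voxel
shifted = convolve(plane, kernel, mode="constant")
print(shifted.round(3))                   # weights sum to 1, mass preserved
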
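A standalone NumPy sketch of the center-preserving pooling added as _CenteredPoolingBase: crop symmetrically so the spatial dimensions divide the pool size, then block-reduce (SumPoolingCM uses np.sum, AveragePoolingCM uses np.mean; the torch path applies the same crop before avg_pool2d):

import numpy as np
import skimage.measure

def centered_pool(array, px, py, func=np.mean):
    # Symmetric center crop so H and W are divisible by (px, py).
    H, W = array.shape[:2]
    crop_h, crop_w = (H // px) * px, (W // py) * py
    off_h, off_w = (H - crop_h) // 2, (W - crop_w) // 2
    cropped = array[off_h:off_h + crop_h, off_w:off_w + crop_w, ...]
    # Pool only the two leading spatial axes.
    pool_shape = (px, py) + (1,) * (cropped.ndim - 2)
    return skimage.measure.block_reduce(cropped, pool_shape, func)

img = np.arange(49, dtype=float).reshape(7, 7)
print(centered_pool(img, 2, 2).shape)     # (3, 3): 6x6 centered crop, pooled
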
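And a hedged mini-version of the property lookup on the new dataclasses (a stand-in class, not the real ScatteredBase): dataclass fields win over the loose properties dict, which wins over the caller's default:

import numpy as np
from dataclasses import dataclass, field

@dataclass
class Scattered:
    array: np.ndarray
    z: float = 0.0
    properties: dict = field(default_factory=dict)

    def get_property(self, key, default=None):
        # Attribute first, then the properties dict, then the default.
        return getattr(self, key, self.properties.get(key, default))

s = Scattered(np.zeros((2, 2)), properties={"intensity": 3.0})
print(s.get_property("z"), s.get_property("intensity"), s.get_property("foo", -1))
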
From 36af2e4f909ef80673015c853ca6ee7f4c6390d8 Mon Sep 17 00:00:00 2001
From: Carlo
Date: Sun, 4 Jan 2026 00:37:23 +0100
Subject: [PATCH 05/24] removed Image from optics and scatterers
---
deeptrack/optics.py | 32 +++++++++++++++++++++++++++++++-
1 file changed, 31 insertions(+), 1 deletion(-)
diff --git a/deeptrack/optics.py b/deeptrack/optics.py
index 5564e0b2..4c23469f 100644
--- a/deeptrack/optics.py
+++ b/deeptrack/optics.py
@@ -351,6 +351,10 @@ def get(
if isinstance(scatterer, ScatteredField)
]
+ warn_upscale_fields = False
+ if field_samples and np.any(upscale != 1):
+ warn_upscale_fields = True
+
# Merge all volumes into a single volume.
sample_volume, limits = _create_volume(
volume_samples,
@@ -371,6 +375,14 @@ def get(
imaged_sample = self._objective.resolve(sample_volume)
+ if warn_upscale_fields:
+ warnings.warn(
+ "dt.Upscale is active while FieldScatterers are present. "
+ "Coherent fields are injected without resampling, so the "
+ "physical interpretation may change with Upscale. "
+ "This behavior is currently undefined.",
+ UserWarning,
+ )
# Collect main_property from scatterers
main_properties = {
@@ -1420,7 +1432,25 @@ def get(
light_in_focus = light_in * shifted_pupil
if len(fields) > 0:
- field = np.sum(fields, axis=0)
+ # field = np.sum(fields, axis=0)
+ field_arrays = []
+
+ for fs in fields:
+ # fs is a ScatteredField
+ arr = fs.array
+
+ # Enforce (H, W, 1) shape
+ if arr.ndim == 2:
+ arr = arr[..., None]
+
+ if arr.ndim != 3 or arr.shape[-1] != 1:
+ raise ValueError(
+ f"Expected field of shape (H, W, 1), got {arr.shape}"
+ )
+
+ field_arrays.append(arr)
+
+ field = np.sum(field_arrays, axis=0)
light_in_focus += field[..., 0]
shifted_pupil = np.fft.fftshift(pupils[-1])
light_in_focus = light_in_focus * shifted_pupil
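
A self-contained sketch (placeholder arrays) of the shape normalization this patch applies to each ScatteredField before the coherent sum:

import numpy as np

fields = [np.ones((4, 4), dtype=complex), np.ones((4, 4, 1), dtype=complex)]
field_arrays = []
for arr in fields:
    if arr.ndim == 2:                     # promote (H, W) to (H, W, 1)
        arr = arr[..., None]
    if arr.ndim != 3 or arr.shape[-1] != 1:
        raise ValueError(f"Expected field of shape (H, W, 1), got {arr.shape}")
    field_arrays.append(arr)
field = np.sum(field_arrays, axis=0)      # coherent sum, shape (4, 4, 1)
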
From 27ebcb19e61431e3e1c4cd6877804e09ce013679 Mon Sep 17 00:00:00 2001
From: Carlo
Date: Sun, 4 Jan 2026 01:58:58 +0100
Subject: [PATCH 06/24] Daniel's Mie scatterer.
---
deeptrack/scatterers.py | 63 +++++++++++++++++++++++++++--------------
1 file changed, 42 insertions(+), 21 deletions(-)
diff --git a/deeptrack/scatterers.py b/deeptrack/scatterers.py
index f3daa9bf..b52408e2 100644
--- a/deeptrack/scatterers.py
+++ b/deeptrack/scatterers.py
@@ -879,6 +879,7 @@ def __init__(
illumination_angle: float=0,
amp_factor: float=1,
phase_shift_correction: bool=False,
+ pupil: ArrayLike=[],
**kwargs,
) -> None:
if polarization_angle is not None:
@@ -887,7 +888,7 @@ def __init__(
"Please use input_polarization instead"
)
input_polarization = polarization_angle
- kwargs.pop("is_field", None) # remove
+ # kwargs.pop("is_field", None) # remove
kwargs.pop("crop_empty", None)
super().__init__(
@@ -912,6 +913,7 @@ def __init__(
illumination_angle=illumination_angle,
amp_factor=amp_factor,
phase_shift_correction=phase_shift_correction,
+ pupil=pupil,
**kwargs,
)
@@ -1037,7 +1039,8 @@ def get_plane_in_polar_coords(
shape: int,
voxel_size: ArrayLike[float],
plane_position: float,
- illumination_angle: float
+ illumination_angle: float,
+ k: float,
) -> tuple[float, float, float, float]:
"""Computes the coordinates of the plane in polar form."""
@@ -1050,15 +1053,22 @@ def get_plane_in_polar_coords(
R2_squared = X ** 2 + Y ** 2
R3 = np.sqrt(R2_squared + Z ** 2) # Might be +z instead of -z.
+ Q = np.sqrt(R2_squared)/voxel_size[0]**2*2*np.pi/shape[0]
+ sin_theta=Q/(k)
+ pupil_mask=sin_theta<=1
+
+ cos_theta=np.zeros(sin_theta.shape)
+ cos_theta[pupil_mask]=np.sqrt(1-sin_theta[pupil_mask]**2)
# Get the angles.
- cos_theta = Z / R3
+ # cos_theta = Z / R3
+
illumination_cos_theta = (
np.cos(np.arccos(cos_theta) + illumination_angle)
)
phi = np.arctan2(Y, X)
- return R3, cos_theta, illumination_cos_theta, phi
+ return R3, cos_theta, illumination_cos_theta, phi, pupil_mask
def get(
self,
@@ -1083,6 +1093,7 @@ def get(
illumination_angle: float,
amp_factor: float,
phase_shift_correction: bool,
+ pupil: ArrayLike,
**kwargs,
) -> ArrayLike[float]:
"""Abstract method to initialize the Mie scatterer"""
@@ -1099,6 +1110,10 @@ def get(
ratio = offset_z / (working_distance - z)
+ # Wave vector.
+ k = 2 * np.pi / wavelength * refractive_index_medium
+
+
# Position of pbjective relative particle.
relative_position = np.array(
(
@@ -1109,11 +1124,12 @@ def get(
)
# Get field evaluation plane at offset_z.
- R3_field, cos_theta_field, illumination_angle_field, phi_field =\
+ R3_field, cos_theta_field, illumination_angle_field, phi_field, pupil_mask =\
self.get_plane_in_polar_coords(
arr.shape, voxel_size,
relative_position * ratio,
- illumination_angle
+ illumination_angle,
+ k
)
cos_phi_field, sin_phi_field = np.cos(phi_field), np.sin(phi_field)
@@ -1132,9 +1148,9 @@ def get(
)
# If the beam is within the pupil.
- pupil_mask = (x_farfield - position_objective[0]) ** 2 + (
- y_farfield - position_objective[1]
- ) ** 2 < (pupil_physical_size / 2) ** 2
+ # pupil_mask = (x_farfield - position_objective[0]) ** 2 + (
+ # y_farfield - position_objective[1]
+ # ) ** 2 < (pupil_physical_size / 2) ** 2
R3_field = R3_field[pupil_mask]
cos_theta_field = cos_theta_field[pupil_mask]
@@ -1169,9 +1185,6 @@ def get(
* illumination_angle_field
)
- # Wave vector.
- k = 2 * np.pi / wavelength * refractive_index_medium
-
# Harmonics.
A, B = coefficients(L)
PI, TAU = mie.harmonics(illumination_angle_field, L)
@@ -1188,12 +1201,14 @@ def get(
[E[i] * B[i] * PI[i] + E[i] * A[i] * TAU[i] for i in range(0, L)]
)
- arr[pupil_mask] = (
- -1j
- / (k * R3_field)
- * np.exp(1j * k * R3_field)
- * (S2 * S2_coef + S1 * S1_coef)
- ) / amp_factor
+ # arr[pupil_mask] = (
+ # -1j
+ # / (k * R3_field)
+ # * np.exp(1j * k * R3_field)
+ # * (S2 * S2_coef + S1 * S1_coef)
+ # ) / amp_factor
+ arr[pupil_mask] = (S2 * S2_coef + S1 * S1_coef)/amp_factor
+
# For phase shift correction (a multiplication of the field
# by exp(1j * k * z)).
@@ -1213,13 +1228,19 @@ def get(
mask = np.exp(-0.5 * (x ** 2 + y ** 2) / ((sigma) ** 2))
arr = arr * mask
- fourier_field = np.fft.fft2(arr)
+ if len(pupil)>0:
+ c_pix=[arr.shape[0]//2,arr.shape[1]//2]
+
+ arr[c_pix[0]-pupil.shape[0]//2:c_pix[0]+pupil.shape[0]//2,c_pix[1]-pupil.shape[1]//2:c_pix[1]+pupil.shape[1]//2]*=pupil
+ fourier_field = -np.fft.ifft2(np.fft.fftshift(np.fft.fft2(np.fft.fftshift(arr))))
+ # fourier_field = np.fft.fft2(arr)
propagation_matrix = get_propagation_matrix(
fourier_field.shape,
pixel_size=voxel_size[2],
wavelength=wavelength / refractive_index_medium,
- to_z=(-offset_z - z),
+ # to_z=(-offset_z - z),
+ to_z=(-z),
dy=(
relative_position[0] * ratio
+ position[0]
@@ -1232,7 +1253,7 @@ def get(
),
)
fourier_field = (
- fourier_field * propagation_matrix * np.exp(-1j * k * offset_z)
+ fourier_field * propagation_matrix #* np.exp(-1j * k * offset_z)
)
if return_fft:
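
For orientation, a standalone sketch of the angular-spectrum pupil-mask idea used here (grid construction and scaling are illustrative, not the patch's exact expression for Q): the radial spatial frequency is compared with the wavenumber, and only propagating components get a real cos_theta:

import numpy as np

wavelength, n_medium = 500e-9, 1.33
k = 2 * np.pi / wavelength * n_medium         # wavenumber in the medium
q = np.fft.fftfreq(64, d=100e-9) * 2 * np.pi  # angular spatial frequencies
QX, QY = np.meshgrid(q, q, indexing="ij")
Q = np.sqrt(QX**2 + QY**2)
sin_theta = Q / k
pupil_mask = sin_theta <= 1                   # propagating components only
cos_theta = np.zeros_like(sin_theta)
cos_theta[pupil_mask] = np.sqrt(1 - sin_theta[pupil_mask]**2)
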
From 2b56f74e7420032db5d1efc3fd7dafdad4640261 Mon Sep 17 00:00:00 2001
From: Carlo
Date: Sun, 4 Jan 2026 02:55:21 +0100
Subject: [PATCH 07/24] Tighten Mie pupil mask and add debug prints
---
deeptrack/optics.py | 1 -
deeptrack/scatterers.py | 7 +++++--
2 files changed, 5 insertions(+), 3 deletions(-)
diff --git a/deeptrack/optics.py b/deeptrack/optics.py
index 4c23469f..bfccb4f7 100644
--- a/deeptrack/optics.py
+++ b/deeptrack/optics.py
@@ -2015,7 +2015,6 @@ def _create_volume(
# fudge_factor = scale[0] * scale[1] / scale[2]
for scatterer in list_of_scatterers:
-
position = _get_position(scatterer, mode="corner", return_z=True)
if scatterer.main_property == "intensity":
scatterer_value = scatterer.get_property("intensity") #* fudge_factor
diff --git a/deeptrack/scatterers.py b/deeptrack/scatterers.py
index b52408e2..408506d6 100644
--- a/deeptrack/scatterers.py
+++ b/deeptrack/scatterers.py
@@ -1054,9 +1054,9 @@ def get_plane_in_polar_coords(
R2_squared = X ** 2 + Y ** 2
R3 = np.sqrt(R2_squared + Z ** 2) # Might be +z instead of -z.
Q = np.sqrt(R2_squared)/voxel_size[0]**2*2*np.pi/shape[0]
+ # is dimensionally ok? Doesn't look like it.
sin_theta=Q/(k)
- pupil_mask=sin_theta<=1
-
+ pupil_mask=sin_theta<1
cos_theta=np.zeros(sin_theta.shape)
cos_theta[pupil_mask]=np.sqrt(1-sin_theta[pupil_mask]**2)
@@ -1103,6 +1103,9 @@ def get(
voxel_size = get_active_voxel_size()
arr = pad_image_to_fft(np.zeros((xSize, ySize))).astype(complex)
position = np.array(position) * voxel_size[: len(position)]
+ print(xSize, ySize)
+ print(padding)
+ print(position)
pupil_physical_size = working_distance * np.tan(collection_angle) * 2
From 2da3587e45ffd7f41232cf6dcfd3db61eb3d0318 Mon Sep 17 00:00:00 2001
From: Carlo
Date: Mon, 5 Jan 2026 03:16:12 +0100
Subject: [PATCH 08/24] fixed upscale with field scatterers
---
deeptrack/holography.py | 4 ++--
deeptrack/optics.py | 24 ++++++++++++------------
deeptrack/scatterers.py | 21 +++++++++++----------
3 files changed, 25 insertions(+), 24 deletions(-)
diff --git a/deeptrack/holography.py b/deeptrack/holography.py
index 380969cf..39fc4e43 100644
--- a/deeptrack/holography.py
+++ b/deeptrack/holography.py
@@ -146,8 +146,8 @@ def get_propagation_matrix(
x = np.arange(0, xr, 1) - xr / 2 + (xr % 2) / 2
y = np.arange(0, yr, 1) - yr / 2 + (yr % 2) / 2
- x = 2 * np.pi / pixel_size * x / xr
- y = 2 * np.pi / pixel_size * y / yr
+ x = 2 * np.pi / pixel_size[0] * x / xr
+ y = 2 * np.pi / pixel_size[1] * y / yr
KXk, KYk = np.meshgrid(x, y)
KXk = KXk.astype(complex)
diff --git a/deeptrack/optics.py b/deeptrack/optics.py
index bfccb4f7..092696a5 100644
--- a/deeptrack/optics.py
+++ b/deeptrack/optics.py
@@ -351,9 +351,9 @@ def get(
if isinstance(scatterer, ScatteredField)
]
- warn_upscale_fields = False
- if field_samples and np.any(upscale != 1):
- warn_upscale_fields = True
+ # warn_upscale_fields = False
+ # if field_samples and np.any(upscale != 1):
+ # warn_upscale_fields = True
# Merge all volumes into a single volume.
sample_volume, limits = _create_volume(
@@ -375,14 +375,14 @@ def get(
imaged_sample = self._objective.resolve(sample_volume)
- if warn_upscale_fields:
- warnings.warn(
- "dt.Upscale is active while FieldScatterers are present. "
- "Coherent fields are injected without resampling, so the "
- "physical interpretation may change with Upscale. "
- "This behavior is currently undefined.",
- UserWarning,
- )
+ # if warn_upscale_fields:
+ # warnings.warn(
+ # "dt.Upscale is active while FieldScatterers are present. "
+ # "Coherent fields are injected without resampling, so the "
+ # "physical interpretation may change with Upscale. "
+ # "This behavior is currently undefined.",
+ # UserWarning,
+ # )
# Collect main_property from scatterers
main_properties = {
@@ -1365,7 +1365,7 @@ def get(
if output_region[3] is None
else int(output_region[3] - limits[1, 0] + pad[3])
)
-
+
padded_volume = padded_volume[
output_region[0] : output_region[2],
output_region[1] : output_region[3],
diff --git a/deeptrack/scatterers.py b/deeptrack/scatterers.py
index 408506d6..eba0566b 100644
--- a/deeptrack/scatterers.py
+++ b/deeptrack/scatterers.py
@@ -1002,8 +1002,10 @@ def get_XY(
The meshgrid of X and Y coordinates.
"""
- x = np.arange(shape[0]) - shape[0] / 2
- y = np.arange(shape[1]) - shape[1] / 2
+ # x = np.arange(shape[0]) - shape[0] / 2
+ # y = np.arange(shape[1]) - shape[1] / 2
+ x = np.arange(shape[0]) - (shape[0] - 1) / 2
+ y = np.arange(shape[1]) - (shape[1] - 1) / 2
return np.meshgrid(x * voxel_size[0], y * voxel_size[1], indexing="ij")
def get_detector_mask(
@@ -1101,11 +1103,9 @@ def get(
# Get size of the output.
xSize, ySize = self.get_xy_size(output_region, padding)
voxel_size = get_active_voxel_size()
+ scale = get_active_scale()
arr = pad_image_to_fft(np.zeros((xSize, ySize))).astype(complex)
- position = np.array(position) * voxel_size[: len(position)]
- print(xSize, ySize)
- print(padding)
- print(position)
+ position = np.array(position) * scale[: len(position)] * voxel_size[: len(position)]
pupil_physical_size = working_distance * np.tan(collection_angle) * 2
@@ -1117,7 +1117,7 @@ def get(
k = 2 * np.pi / wavelength * refractive_index_medium
- # Position of pbjective relative particle.
+ # Position of objective relative particle.
relative_position = np.array(
(
position_objective[0] - position[0],
@@ -1236,11 +1236,11 @@ def get(
arr[c_pix[0]-pupil.shape[0]//2:c_pix[0]+pupil.shape[0]//2,c_pix[1]-pupil.shape[1]//2:c_pix[1]+pupil.shape[1]//2]*=pupil
fourier_field = -np.fft.ifft2(np.fft.fftshift(np.fft.fft2(np.fft.fftshift(arr))))
- # fourier_field = np.fft.fft2(arr)
+ # fourier_field = np.fft.fft2(np.fft.fftshift(arr))
propagation_matrix = get_propagation_matrix(
fourier_field.shape,
- pixel_size=voxel_size[2],
+ pixel_size=voxel_size[:2], # this needs a double check
wavelength=wavelength / refractive_index_medium,
# to_z=(-offset_z - z),
to_z=(-z),
@@ -1252,9 +1252,10 @@ def get(
dx=(
relative_position[1] * ratio
+ position[1]
- + (padding[1] - arr.shape[1] / 2) * voxel_size[1]
+ + (padding[2] - arr.shape[1] / 2) * voxel_size[1] # check if padding is top, bottom, left, right
),
)
+
fourier_field = (
fourier_field * propagation_matrix #* np.exp(-1j * k * offset_z)
)
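
A minimal sketch of the frequency-grid change in get_propagation_matrix (values illustrative): with a two-component pixel size, the x and y spatial-frequency axes are scaled independently:

import numpy as np

pixel_size = (100e-9, 200e-9)             # (px, py), illustrative
yr, xr = 64, 64
x = np.arange(0, xr, 1) - xr / 2 + (xr % 2) / 2
y = np.arange(0, yr, 1) - yr / 2 + (yr % 2) / 2
x = 2 * np.pi / pixel_size[0] * x / xr    # per-axis frequency scaling
y = 2 * np.pi / pixel_size[1] * y / yr
KXk, KYk = np.meshgrid(x, y)
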
From cdaf8de6d0e5dbb0d57577285ac9638fbe233c39 Mon Sep 17 00:00:00 2001
From: Carlo
Date: Wed, 7 Jan 2026 10:54:33 +0100
Subject: [PATCH 09/24] Revert half-pixel centering in get_XY
---
deeptrack/scatterers.py | 8 ++++----
1 file changed, 4 insertions(+), 4 deletions(-)
diff --git a/deeptrack/scatterers.py b/deeptrack/scatterers.py
index eba0566b..70084488 100644
--- a/deeptrack/scatterers.py
+++ b/deeptrack/scatterers.py
@@ -1002,10 +1002,10 @@ def get_XY(
The meshgrid of X and Y coordinates.
"""
- # x = np.arange(shape[0]) - shape[0] / 2
- # y = np.arange(shape[1]) - shape[1] / 2
- x = np.arange(shape[0]) - (shape[0] - 1) / 2
- y = np.arange(shape[1]) - (shape[1] - 1) / 2
+ x = np.arange(shape[0]) - shape[0] / 2
+ y = np.arange(shape[1]) - shape[1] / 2
+ # x = np.arange(shape[0]) - (shape[0] - 1) / 2
+ # y = np.arange(shape[1]) - (shape[1] - 1) / 2
return np.meshgrid(x * voxel_size[0], y * voxel_size[1], indexing="ij")
def get_detector_mask(
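
The two centering conventions toggled here differ by half a pixel; a tiny sketch:

import numpy as np

N = 4
print(np.arange(N) - N / 2)        # [-2. -1.  0.  1.]  zero on a sample, asymmetric
print(np.arange(N) - (N - 1) / 2)  # [-1.5 -0.5  0.5  1.5]  symmetric, zero between samples
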
From 191d3a51c13b015c0867bcd116cabe61a9f84757 Mon Sep 17 00:00:00 2001
From: Carlo
Date: Wed, 7 Jan 2026 13:26:38 +0100
Subject: [PATCH 10/24] rolled back
Momentarily removed the changes to MieScatterer suggested by Daniel. Now Upscale doesn't work with MieScatterers.
---
deeptrack/scatterers.py | 76 ++++++++++++++++++++++-------------------
1 file changed, 40 insertions(+), 36 deletions(-)
diff --git a/deeptrack/scatterers.py b/deeptrack/scatterers.py
index 70084488..a9681953 100644
--- a/deeptrack/scatterers.py
+++ b/deeptrack/scatterers.py
@@ -879,7 +879,7 @@ def __init__(
illumination_angle: float=0,
amp_factor: float=1,
phase_shift_correction: bool=False,
- pupil: ArrayLike=[],
+ # pupil: ArrayLike=[], # Daniel
**kwargs,
) -> None:
if polarization_angle is not None:
@@ -913,7 +913,7 @@ def __init__(
illumination_angle=illumination_angle,
amp_factor=amp_factor,
phase_shift_correction=phase_shift_correction,
- pupil=pupil,
+ # pupil=pupil, # Daniel
**kwargs,
)
@@ -1004,8 +1004,6 @@ def get_XY(
"""
x = np.arange(shape[0]) - shape[0] / 2
y = np.arange(shape[1]) - shape[1] / 2
- # x = np.arange(shape[0]) - (shape[0] - 1) / 2
- # y = np.arange(shape[1]) - (shape[1] - 1) / 2
return np.meshgrid(x * voxel_size[0], y * voxel_size[1], indexing="ij")
def get_detector_mask(
@@ -1042,7 +1040,7 @@ def get_plane_in_polar_coords(
voxel_size: ArrayLike[float],
plane_position: float,
illumination_angle: float,
- k: float,
+ # k: float, # Daniel
) -> tuple[float, float, float, float]:
"""Computes the coordinates of the plane in polar form."""
@@ -1055,22 +1053,24 @@ def get_plane_in_polar_coords(
R2_squared = X ** 2 + Y ** 2
R3 = np.sqrt(R2_squared + Z ** 2) # Might be +z instead of -z.
- Q = np.sqrt(R2_squared)/voxel_size[0]**2*2*np.pi/shape[0]
- # is dimensionally ok? Doesn't look like it.
- sin_theta=Q/(k)
- pupil_mask=sin_theta<1
- cos_theta=np.zeros(sin_theta.shape)
- cos_theta[pupil_mask]=np.sqrt(1-sin_theta[pupil_mask]**2)
+
+ # # DANIEL
+ # Q = np.sqrt(R2_squared)/voxel_size[0]**2*2*np.pi/shape[0]
+ # # is dimensionally ok?
+ # sin_theta=Q/(k)
+ # pupil_mask=sin_theta<1
+ # cos_theta=np.zeros(sin_theta.shape)
+ # cos_theta[pupil_mask]=np.sqrt(1-sin_theta[pupil_mask]**2)
# Get the angles.
- # cos_theta = Z / R3
+ cos_theta = Z / R3
illumination_cos_theta = (
np.cos(np.arccos(cos_theta) + illumination_angle)
)
phi = np.arctan2(Y, X)
- return R3, cos_theta, illumination_cos_theta, phi, pupil_mask
+ return R3, cos_theta, illumination_cos_theta, phi#, pupil_mask # Daniel
def get(
self,
@@ -1095,7 +1095,7 @@ def get(
illumination_angle: float,
amp_factor: float,
phase_shift_correction: bool,
- pupil: ArrayLike,
+ # pupil: ArrayLike, # Daniel
**kwargs,
) -> ArrayLike[float]:
"""Abstract method to initialize the Mie scatterer"""
@@ -1126,13 +1126,13 @@ def get(
)
)
- # Get field evaluation plane at offset_z.
- R3_field, cos_theta_field, illumination_angle_field, phi_field, pupil_mask =\
+ # Get field evaluation plane at offset_z. # , pupil_mask # Daniel
+ R3_field, cos_theta_field, illumination_angle_field, phi_field =\
self.get_plane_in_polar_coords(
arr.shape, voxel_size,
relative_position * ratio,
illumination_angle,
- k
+ # k # Daniel
)
cos_phi_field, sin_phi_field = np.cos(phi_field), np.sin(phi_field)
@@ -1150,10 +1150,10 @@ def get(
sin_phi_field / ratio
)
- # If the beam is within the pupil.
- # pupil_mask = (x_farfield - position_objective[0]) ** 2 + (
- # y_farfield - position_objective[1]
- # ) ** 2 < (pupil_physical_size / 2) ** 2
+ # If the beam is within the pupil. Remove if Daniel
+ pupil_mask = (x_farfield - position_objective[0]) ** 2 + (
+ y_farfield - position_objective[1]
+ ) ** 2 < (pupil_physical_size / 2) ** 2
R3_field = R3_field[pupil_mask]
cos_theta_field = cos_theta_field[pupil_mask]
@@ -1204,13 +1204,14 @@ def get(
[E[i] * B[i] * PI[i] + E[i] * A[i] * TAU[i] for i in range(0, L)]
)
- # arr[pupil_mask] = (
- # -1j
- # / (k * R3_field)
- # * np.exp(1j * k * R3_field)
- # * (S2 * S2_coef + S1 * S1_coef)
- # ) / amp_factor
- arr[pupil_mask] = (S2 * S2_coef + S1 * S1_coef)/amp_factor
+ # Daniel
+ # arr[pupil_mask] = (S2 * S2_coef + S1 * S1_coef)/amp_factor
+ arr[pupil_mask] = (
+ -1j
+ / (k * R3_field)
+ * np.exp(1j * k * R3_field)
+ * (S2 * S2_coef + S1 * S1_coef)
+ ) / amp_factor
# For phase shift correction (a multiplication of the field
@@ -1231,19 +1232,22 @@ def get(
mask = np.exp(-0.5 * (x ** 2 + y ** 2) / ((sigma) ** 2))
arr = arr * mask
- if len(pupil)>0:
- c_pix=[arr.shape[0]//2,arr.shape[1]//2]
+ # Not sure if needed... CM
+ # if len(pupil)>0:
+ # c_pix=[arr.shape[0]//2,arr.shape[1]//2]
- arr[c_pix[0]-pupil.shape[0]//2:c_pix[0]+pupil.shape[0]//2,c_pix[1]-pupil.shape[1]//2:c_pix[1]+pupil.shape[1]//2]*=pupil
- fourier_field = -np.fft.ifft2(np.fft.fftshift(np.fft.fft2(np.fft.fftshift(arr))))
- # fourier_field = np.fft.fft2(np.fft.fftshift(arr))
+ # arr[c_pix[0]-pupil.shape[0]//2:c_pix[0]+pupil.shape[0]//2,c_pix[1]-pupil.shape[1]//2:c_pix[1]+pupil.shape[1]//2]*=pupil
+
+ # Daniel
+ # fourier_field = -np.fft.ifft2(np.fft.fftshift(np.fft.fft2(np.fft.fftshift(arr))))
+ fourier_field = np.fft.fft2(arr)
propagation_matrix = get_propagation_matrix(
fourier_field.shape,
pixel_size=voxel_size[:2], # this needs a double check
wavelength=wavelength / refractive_index_medium,
- # to_z=(-offset_z - z),
- to_z=(-z),
+ # to_z=(-z), # Daniel
+ to_z=(-offset_z - z),
dy=(
relative_position[0] * ratio
+ position[0]
@@ -1257,7 +1261,7 @@ def get(
)
fourier_field = (
- fourier_field * propagation_matrix #* np.exp(-1j * k * offset_z)
+ fourier_field * propagation_matrix * np.exp(-1j * k * offset_z) # Remove last part (from exp)) if Daniel
)
if return_fft:
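
For reference, a standalone sketch (all inputs are placeholders) of the restored far-field expression: the summed scattering amplitudes are multiplied back by the outgoing spherical-wave factor -1j / (k R) * exp(1j k R):

import numpy as np

k = 2 * np.pi / 500e-9 * 1.33             # wavenumber, illustrative
R3_field = np.linspace(1e-3, 2e-3, 5)     # distances to the evaluation plane
S1 = S2 = np.ones_like(R3_field)          # placeholder Mie sums
S1_coef = S2_coef = 0.5                   # placeholder polarization weights
amp_factor = 1.0
field = (
    -1j / (k * R3_field)
    * np.exp(1j * k * R3_field)
    * (S2 * S2_coef + S1 * S1_coef)
) / amp_factor
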
From 6dde1adfb0741edec8a15003873295d67b9380b7 Mon Sep 17 00:00:00 2001
From: Carlo
Date: Thu, 8 Jan 2026 00:00:39 +0100
Subject: [PATCH 11/24] added contrast type
Included an objective attribute (contrast_type) to decide whether to use intensity or refractive index.
---
deeptrack/holography.py | 18 ++++++++---
deeptrack/optics.py | 72 ++++++++++++++++++++++++++++++-----------
deeptrack/scatterers.py | 8 ++---
3 files changed, 70 insertions(+), 28 deletions(-)
diff --git a/deeptrack/holography.py b/deeptrack/holography.py
index 39fc4e43..141cc540 100644
--- a/deeptrack/holography.py
+++ b/deeptrack/holography.py
@@ -101,7 +101,7 @@ def get_propagation_matrix(
def get_propagation_matrix(
shape: tuple[int, int],
to_z: float,
- pixel_size: float,
+ pixel_size: float | tuple[float, float],
wavelength: float,
dx: float = 0,
dy: float = 0
@@ -118,8 +118,8 @@ def get_propagation_matrix(
The dimensions of the optical field (height, width).
to_z: float
Propagation distance along the z-axis.
- pixel_size: float
- The physical size of each pixel in the optical field.
+ pixel_size: float | tuple[float, float]
+ Physical pixel size. If scalar, isotropic pixels are assumed.
wavelength: float
The wavelength of the optical field.
dx: float, optional
@@ -140,14 +140,22 @@ def get_propagation_matrix(
"""
+ if pixel_size is None:
+ pixel_size = get_active_voxel_size()
+
+ if np.isscalar(pixel_size):
+ pixel_size = (pixel_size, pixel_size)
+
+ px, py = pixel_size
+
k = 2 * np.pi / wavelength
yr, xr, *_ = shape
x = np.arange(0, xr, 1) - xr / 2 + (xr % 2) / 2
y = np.arange(0, yr, 1) - yr / 2 + (yr % 2) / 2
- x = 2 * np.pi / pixel_size[0] * x / xr
- y = 2 * np.pi / pixel_size[1] * y / yr
+ x = 2 * np.pi / px * x / xr
+ y = 2 * np.pi / py * y / yr
KXk, KYk = np.meshgrid(x, y)
KXk = KXk.astype(complex)
diff --git a/deeptrack/optics.py b/deeptrack/optics.py
index 092696a5..3990e3c1 100644
--- a/deeptrack/optics.py
+++ b/deeptrack/optics.py
@@ -291,6 +291,15 @@ def get(
# Grab properties from the objective to pass to the sample
additional_sample_kwargs = self._objective.properties()
+ contrast_type = getattr(self._objective, "contrast_type", None)
+ if contrast_type is None:
+ raise RuntimeError(
+ f"{self._objective.__class__.__name__} must define `contrast_type` "
+ "(e.g. 'intensity' or 'refractive_index')."
+ )
+
+ additional_sample_kwargs["contrast_type"] = contrast_type
+
# # Calculate required output image for the given upscale
# # This way of providing the upscale will be deprecated in the future
@@ -384,25 +393,25 @@ def get(
# UserWarning,
# )
- # Collect main_property from scatterers
- main_properties = {
- s.main_property
- for s in list_of_scatterers
- if hasattr(s, "main_property")
- }
+ # # Collect main_property from scatterers
+ # main_properties = {
+ # s.main_property
+ # for s in list_of_scatterers
+ # if hasattr(s, "main_property")
+ # }
- if len(main_properties) != 1:
- raise ValueError(
- f"Inconsistent main_property across scatterers: {main_properties}"
- )
+ # if len(main_properties) != 1:
+ # raise ValueError(
+ # f"Inconsistent main_property across scatterers: {main_properties}"
+ # )
- main_property = main_properties.pop()
+ # main_property = main_properties.pop()
# Handling upscale from dt.Upscale() here to eliminate Image
# wrapping issues.
if np.any(np.array(upscale) != 1):
ux, uy = upscale[:2]
- if main_property == "intensity":
+ if contrast_type == "intensity":
print("Using sum pooling for intensity downscaling.")
imaged_sample = SumPoolingCM((ux, uy, 1))(imaged_sample)
else:
@@ -1045,6 +1054,7 @@ class Fluorescence(Optics):
1.4
"""
+ contrast_type = "intensity"
def get(
self: Fluorescence,
@@ -1270,6 +1280,8 @@ class Brightfield(Optics):
"""
+ contrast_type = "refractive_index"
+
__conversion_table__ = ConversionTable(
working_distance=(u.meter, u.meter),
)
@@ -1988,6 +2000,12 @@ def _create_volume(
Spatial limits of the volume.
"""
+ contrast_type = kwargs.get("contrast_type", None)
+ if contrast_type is None:
+ raise RuntimeError(
+ "_create_volume requires a contrast_type "
+ "(e.g. 'intensity' or 'refractive_index')"
+ )
if not isinstance(list_of_scatterers, list):
list_of_scatterers = [list_of_scatterers]
@@ -2016,12 +2034,30 @@ def _create_volume(
for scatterer in list_of_scatterers:
position = _get_position(scatterer, mode="corner", return_z=True)
- if scatterer.main_property == "intensity":
- scatterer_value = scatterer.get_property("intensity") #* fudge_factor
- elif scatterer.main_property == "refractive_index":
- scatterer_value = scatterer.get_property("refractive_index") - refractive_index_medium
- else: # fallback to generic value
- scatterer_value = scatterer.get_property("value")
+ # if scatterer.main_property == "intensity":
+ # scatterer_value = scatterer.get_property("intensity") #* fudge_factor
+ # elif scatterer.main_property == "refractive_index":
+ # scatterer_value = scatterer.get_property("refractive_index") - refractive_index_medium
+ # else: # fallback to generic value
+ # scatterer_value = scatterer.get_property("value")
+
+ # # Scale the array accordingly
+ # scatterer.array = scatterer.array * scatterer_value
+
+ if contrast_type == "intensity":
+ value = scatterer.get_property("intensity", None)
+ if value is None:
+ raise ValueError("Scatterer has no intensity.")
+ scatterer_value = value
+
+ elif contrast_type == "refractive_index":
+ ri = scatterer.get_property("refractive_index", None)
+ if ri is None:
+ raise ValueError("Scatterer has no refractive_index.")
+ scatterer_value = ri - refractive_index_medium
+
+ else:
+ raise RuntimeError(f"Unknown contrast_type: {contrast_type}")
# Scale the array accordingly
scatterer.array = scatterer.array * scatterer_value
diff --git a/deeptrack/scatterers.py b/deeptrack/scatterers.py
index a9681953..a93a732e 100644
--- a/deeptrack/scatterers.py
+++ b/deeptrack/scatterers.py
@@ -353,8 +353,8 @@ def _wrap_output(self, array, props) -> ScatteredVolume:
position=props.get("position", (0, 0)),
z=props.get("z", 0.0),
value=props.get("value", 1.0),
- intensity=props.get("intensity", None),
- refractive_index=props.get("refractive_index", None),
+ intensity=props.get("intensity", 1.0),
+ refractive_index=props.get("refractive_index", 1.59),
properties=props.copy(),
main_property=self.main_property,
)]
@@ -364,7 +364,7 @@ def _wrap_output(self, array, props) -> ScatteredField:
return [ScatteredField(
array=array,
position=props.get("position", (0, 0)),
- wavelength=props.get("wavelength", 0.0),
+ wavelength=props.get("wavelength", 532e-9),
properties=props.copy(),
main_property=self.main_property,
)]
@@ -394,8 +394,6 @@ class PointParticle(VolumeScatterer):
for `Brightfield` and `intensity` for `Fluorescence`).
"""
-
- main_property = "intensity"
def __init__(
self: PointParticle,
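
A minimal standalone sketch of the contrast_type dispatch this patch adds (a plain dict stands in for the scatterer; not the real deeptrack classes):

def scatterer_contrast(scatterer, contrast_type, refractive_index_medium=1.33):
    # Mirrors the branch added to _create_volume: the objective's
    # contrast_type decides how scatterer values enter the volume.
    if contrast_type == "intensity":
        value = scatterer.get("intensity")
        if value is None:
            raise ValueError("Scatterer has no intensity.")
        return value
    if contrast_type == "refractive_index":
        ri = scatterer.get("refractive_index")
        if ri is None:
            raise ValueError("Scatterer has no refractive_index.")
        return ri - refractive_index_medium
    raise RuntimeError(f"Unknown contrast_type: {contrast_type}")

print(scatterer_contrast({"refractive_index": 1.45}, "refractive_index"))
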
From 4527e7edcf50f664336e24c09cdfecb6471194a5 Mon Sep 17 00:00:00 2001
From: Carlo
Date: Thu, 8 Jan 2026 00:08:26 +0100
Subject: [PATCH 12/24] Remove commented-out main_property code
---
deeptrack/optics.py | 23 -----------------------
1 file changed, 23 deletions(-)
diff --git a/deeptrack/optics.py b/deeptrack/optics.py
index 3990e3c1..98026db4 100644
--- a/deeptrack/optics.py
+++ b/deeptrack/optics.py
@@ -393,20 +393,6 @@ def get(
# UserWarning,
# )
- # # Collect main_property from scatterers
- # main_properties = {
- # s.main_property
- # for s in list_of_scatterers
- # if hasattr(s, "main_property")
- # }
-
- # if len(main_properties) != 1:
- # raise ValueError(
- # f"Inconsistent main_property across scatterers: {main_properties}"
- # )
-
- # main_property = main_properties.pop()
-
# Handling upscale from dt.Upscale() here to eliminate Image
# wrapping issues.
if np.any(np.array(upscale) != 1):
@@ -2034,15 +2020,6 @@ def _create_volume(
for scatterer in list_of_scatterers:
position = _get_position(scatterer, mode="corner", return_z=True)
- # if scatterer.main_property == "intensity":
- # scatterer_value = scatterer.get_property("intensity") #* fudge_factor
- # elif scatterer.main_property == "refractive_index":
- # scatterer_value = scatterer.get_property("refractive_index") - refractive_index_medium
- # else: # fallback to generic value
- # scatterer_value = scatterer.get_property("value")
-
- # # Scale the array accordingly
- # scatterer.array = scatterer.array * scatterer_value
if contrast_type == "intensity":
value = scatterer.get_property("intensity", None)
From 93a0ef4bcfafca040f07e68c7504dbc9770bacef Mon Sep 17 00:00:00 2001
From: Carlo
Date: Thu, 8 Jan 2026 00:36:08 +0100
Subject: [PATCH 13/24] rebased ok
---
deeptrack/optics.py | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/deeptrack/optics.py b/deeptrack/optics.py
index 98026db4..9d10cbe5 100644
--- a/deeptrack/optics.py
+++ b/deeptrack/optics.py
@@ -250,7 +250,7 @@ def __init__(
self._sample = self.add_feature(sample)
self._objective = self.add_feature(objective)
- self._sample.store_properties()
+ # self._sample.store_properties()
def get(
self: Microscope,
@@ -1047,7 +1047,7 @@ def get(
illuminated_volume: ArrayLike[complex],
limits: ArrayLike[int],
**kwargs: Any,
- ) -> Image:
+ ) -> ArrayLike[complex]:
"""Simulates the imaging process using a fluorescence microscope.
This method convolves the 3D illuminated volume with a pupil function
@@ -2055,7 +2055,7 @@ def _create_volume(
# Pad scatterer to avoid edge effects during interpolation
padded_scatterer = scatterer
padded_scatterer.array = np.pad(
- scatterer.array._value, # this is a temporary fix, Image should be removed from features
+ scatterer.array,
[(2, 2), (2, 2), (2, 2)],
"constant",
constant_values=0,
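
The last hunk above drops the Image wrapper and pads the raw array directly.
A standalone sketch of that step (shape illustrative):

import numpy as np

arr = np.random.rand(8, 8, 8)
# Two voxels of zeros on every side give the interpolation near the
# volume borders full support.
padded = np.pad(arr, [(2, 2), (2, 2), (2, 2)], "constant", constant_values=0)
assert padded.shape == (12, 12, 12)
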
From dab26360fe6e7d156eb182fb6d1bef493c3240dc Mon Sep 17 00:00:00 2001
From: Carlo
Date: Thu, 8 Jan 2026 11:47:31 +0100
Subject: [PATCH 14/24] u
---
deeptrack/optics.py | 54 +++++++++++------------------------------
deeptrack/scatterers.py | 10 ++++----
2 files changed, 19 insertions(+), 45 deletions(-)
diff --git a/deeptrack/optics.py b/deeptrack/optics.py
index 9d10cbe5..a7c44e8c 100644
--- a/deeptrack/optics.py
+++ b/deeptrack/optics.py
@@ -154,7 +154,7 @@ def _pad_volume(
from deeptrack.math import AveragePooling
from deeptrack.features import propagate_data_to_dependencies
from deeptrack.features import DummyFeature, Feature, StructuralFeature
-from deeptrack.image import Image, pad_image_to_fft
+from deeptrack.image import pad_image_to_fft
from deeptrack.types import ArrayLike, PropertyLike
from deeptrack import image
@@ -198,7 +198,7 @@ class Microscope(StructuralFeature):
Methods
-------
- `get(image: Image or None, **kwargs: Any) -> Image`
+ `get(image: np.ndarray or None, **kwargs: Any) -> np.ndarray`
Simulates the imaging process using the defined optical system and
returns the resulting image.
@@ -254,9 +254,9 @@ def __init__(
def get(
self: Microscope,
- image: Image | None,
+ image: np.ndarray | None,
**kwargs: Any,
- ) -> Image:
+ ) -> np.ndarray:
"""Generate an image of the sample using the defined optical system.
This method processes the sample through the optical system to
@@ -264,14 +264,14 @@ def get(
Parameters
----------
- image: Image | None
+ image: np.ndarray | None
The input image to be processed. If None, a new image is created.
**kwargs: Any
Additional parameters for the imaging process.
Returns
-------
- Image: Image
+ image: np.ndarray
The processed image after applying the optical system.
Examples
@@ -300,14 +300,6 @@ def get(
additional_sample_kwargs["contrast_type"] = contrast_type
-
- # # Calculate required output image for the given upscale
- # # This way of providing the upscale will be deprecated in the future
- # # in favor of dt.Upscale().
- # _upscale_given_by_optics = additional_sample_kwargs["upscale"]
- # if np.array(_upscale_given_by_optics).size == 1:
- # _upscale_given_by_optics = (_upscale_given_by_optics,) * 3
-
with u.context(
create_context(
*additional_sample_kwargs["voxel_size"]#, *_upscale_given_by_optics
@@ -359,21 +351,12 @@ def get(
for scatterer in list_of_scatterers
if isinstance(scatterer, ScatteredField)
]
-
- # warn_upscale_fields = False
- # if field_samples and np.any(upscale != 1):
- # warn_upscale_fields = True
# Merge all volumes into a single volume.
sample_volume, limits = _create_volume(
volume_samples,
**additional_sample_kwargs,
)
- # sample_volume = Image(sample_volume)
-
- # Merge all properties into the volume.
- # for scatterer in volume_samples + field_samples:
- # sample_volume.merge_properties_from(scatterer)
# Let the objective know about the limits of the volume and all the fields.
propagate_data_to_dependencies(
@@ -384,15 +367,6 @@ def get(
imaged_sample = self._objective.resolve(sample_volume)
- # if warn_upscale_fields:
- # warnings.warn(
- # "dt.Upscale is active while FieldScatterers are present. "
- # "Coherent fields are injected without resampling, so the "
- # "physical interpretation may change with Upscale. "
- # "This behavior is currently undefined.",
- # UserWarning,
- # )
-
# Handling upscale from dt.Upscale() here to eliminate Image
# wrapping issues.
if np.any(np.array(upscale) != 1):
@@ -1024,7 +998,7 @@ class Fluorescence(Optics):
Methods
-------
- `get(illuminated_volume: array_like[complex], limits: array_like[int, int], **kwargs: Any) -> Image`
+ `get(illuminated_volume: array_like[complex], limits: array_like[int, int], **kwargs: Any) -> np.ndarray`
Simulates the imaging process using a fluorescence microscope.
Examples
@@ -1066,7 +1040,7 @@ def get(
Returns
-------
- Image: Image
+ image: np.ndarray
A 2D image object representing the fluorescence projection.
Notes
@@ -1084,7 +1058,7 @@ def get(
>>> optics = dt.Fluorescence(
... NA=1.4, wavelength=0.52e-6, magnification=60,
... )
- >>> volume = dt.Image(np.ones((128, 128, 10), dtype=complex))
+ >>> volume = np.ones((128, 128, 10), dtype=complex)
>>> limits = np.array([[0, 128], [0, 128], [0, 10]])
>>> properties = optics.properties()
>>> filtered_properties = {
@@ -1250,7 +1224,7 @@ class Brightfield(Optics):
-------
`get(illuminated_volume: array_like[complex],
limits: array_like[int, int], fields: array_like[complex],
- **kwargs: Any) -> Image`
+ **kwargs: Any) -> np.ndarray`
Simulates imaging with brightfield microscopy.
@@ -1278,7 +1252,7 @@ def get(
limits: ArrayLike[int],
fields: ArrayLike[complex],
**kwargs: Any,
- ) -> Image:
+ ) -> np.ndarray:
"""Simulates imaging with brightfield microscopy.
This method propagates light through the given volume, applying
@@ -1303,7 +1277,7 @@ def get(
Returns
-------
- Image: Image
+ image: np.ndarray
Processed image after simulating the brightfield imaging process.
Examples
@@ -1318,7 +1292,7 @@ def get(
... wavelength=0.52e-6,
... magnification=60,
... )
- >>> volume = dt.Image(np.ones((128, 128, 10), dtype=complex))
+ >>> volume = np.ones((128, 128, 10), dtype=complex)
>>> limits = np.array([[0, 128], [0, 128], [0, 10]])
>>> fields = np.array([np.ones((162, 162), dtype=complex)])
>>> properties = optics.properties()
@@ -1665,7 +1639,7 @@ def get(
limits: ArrayLike[int],
fields: ArrayLike[complex],
**kwargs: Any,
- ) -> Image:
+ ) -> np.ndarray:
"""Retrieve the darkfield image of the illuminated volume.
Parameters
diff --git a/deeptrack/scatterers.py b/deeptrack/scatterers.py
index a93a732e..ea0e140d 100644
--- a/deeptrack/scatterers.py
+++ b/deeptrack/scatterers.py
@@ -176,7 +176,7 @@
)
from deeptrack.backend import mie
from deeptrack.features import Feature, MERGE_STRATEGY_APPEND
-from deeptrack.image import pad_image_to_fft, Image
+from deeptrack.image import pad_image_to_fft
from deeptrack.types import ArrayLike
from deeptrack import units_registry as u
@@ -300,7 +300,7 @@ def _process_and_get(
upsample_axes=None,
crop_empty=True,
**kwargs
- ) -> list[Image] | list[np.ndarray]:
+ ) -> list[np.ndarray]:
# Post processes the created object to handle upsampling,
# as well as cropping empty slices.
if not self._processed_properties:
@@ -407,7 +407,7 @@ def __init__(
def get(
self: PointParticle,
- image: Image | np.ndarray,
+ image: np.ndarray,
**kwarg: Any,
) -> NDArray[Any] | torch.Tensor:
"""Evaluate and return the scatterer volume."""
@@ -576,7 +576,7 @@ def __init__(
def get(
self,
- image: Image | np.ndarray,
+ image: np.ndarray,
radius: float,
voxel_size: float,
**kwargs
@@ -713,7 +713,7 @@ def _process_properties(
def get(
self,
- image: Image | np.ndarray,
+ image: np.ndarray,
radius: float,
rotation: ArrayLike[float] | float,
voxel_size: float,
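
After this patch the optics and scatterer get() methods take and return plain
arrays. A sketch of the inputs, mirroring the updated docstring examples
(shapes are illustrative):

import numpy as np

volume = np.ones((128, 128, 10), dtype=complex)  # previously dt.Image(...)
limits = np.array([[0, 128], [0, 128], [0, 10]])
fields = np.array([np.ones((162, 162), dtype=complex)])
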
From e04589f0efb3ae940459f966b15070d5f956f097 Mon Sep 17 00:00:00 2001
From: Carlo
Date: Fri, 9 Jan 2026 01:40:22 +0100
Subject: [PATCH 15/24] fixed sampletomask
---
deeptrack/features.py | 232 ++++++++++++++++++++++++++++------------
deeptrack/optics.py | 6 +-
deeptrack/scatterers.py | 60 ++++++-----
3 files changed, 204 insertions(+), 94 deletions(-)
diff --git a/deeptrack/features.py b/deeptrack/features.py
index f76245b8..8a031051 100644
--- a/deeptrack/features.py
+++ b/deeptrack/features.py
@@ -171,7 +171,7 @@
from deeptrack.backend import config, TORCH_AVAILABLE, xp
from deeptrack.backend.core import DeepTrackNode
from deeptrack.backend.units import ConversionTable, create_context
-from deeptrack.image import Image #TODO TBE
+# from deeptrack.image import Image #TODO TBE
from deeptrack.properties import PropertyDict, SequentialProperty
from deeptrack.sources import SourceItem
from deeptrack.types import ArrayLike, PropertyLike
@@ -218,11 +218,11 @@
"OneOf",
"OneOfDict",
"LoadImage",
- "SampleToMasks", # TODO ***CM*** revise this after elimination of Image
+ "SampleToMasks",
"AsType",
"ChannelFirst2d",
- "Upscale", # TODO ***CM*** revise and check PyTorch afrer elimin. Image
- "NonOverlapping", # TODO ***CM*** revise + PyTorch afrer elimin. Image
+ "Upscale",
+ "NonOverlapping",
"Store",
"Squeeze",
"Unsqueeze",
@@ -7493,7 +7493,7 @@ def __init__(
def get(
self: Feature,
- image: np.ndarray | Image,
+ image: np.ndarray,
transformation_function: Callable[[Image], Image],
**kwargs: Any,
) -> Image:
@@ -7515,7 +7515,7 @@ def get(
"""
- return transformation_function(image)
+ return transformation_function(image.array)
def _process_and_get(
self: Feature,
@@ -7540,26 +7540,33 @@ def _process_and_get(
"""
# Handle list of images.
- if isinstance(images, list) and len(images) != 1:
- list_of_labels = super()._process_and_get(images, **kwargs)
- if not self._wrap_array_with_image:
- for idx, (label, image) in enumerate(zip(list_of_labels,
- images)):
- list_of_labels[idx] = \
- Image(label, copy=False).merge_properties_from(image)
- else:
- if isinstance(images, list):
- images = images[0]
- list_of_labels = []
- for prop in images.properties:
-
- if "position" in prop:
+ # if isinstance(images, list) and len(images) != 1:
+ list_of_labels = super()._process_and_get(images, **kwargs)
+ # print(len(list_of_labels))
+ # print(list_of_labels[0].shape)
+
+ from deeptrack.scatterers import ScatteredVolume
+ # if not self._wrap_array_with_image:
+ for idx, (label, image) in enumerate(zip(list_of_labels,
+ images)):
+ list_of_labels[idx] = \
+ ScatteredVolume(array=label, properties=image.properties.copy())
+ # Image(label, copy=False).merge_properties_from(image)
+ # else:
+ # if isinstance(images, list):
+ # images = images[0]
+ # list_of_labels = []
+ # for prop in images.properties:
+
+ # if "position" in prop:
+
+ # inp = Image(np.array(images))
+ # inp.append(prop)
+ # out = Image(self.get(inp, **kwargs))
+ # out.merge_properties_from(inp)
+ # list_of_labels.append(out)
- inp = Image(np.array(images))
- inp.append(prop)
- out = Image(self.get(inp, **kwargs))
- out.merge_properties_from(inp)
- list_of_labels.append(out)
+
# Create an empty output image.
output_region = kwargs["output_region"]
@@ -7574,8 +7581,10 @@ def _process_and_get(
from deeptrack.optics import _get_position
# Merge masks into the output.
- for label in list_of_labels:
- position = _get_position(label)
+ for volume in list_of_labels:
+ label = volume.array
+ position = _get_position(volume)
+
p0 = np.round(position - output_region[0:2])
if np.any(p0 > output.shape[0:2]) or \
@@ -7657,11 +7666,11 @@ def _process_and_get(
labelarg[..., label_index],
)
- if not self._wrap_array_with_image:
- return output
- output = Image(output)
- for label in list_of_labels:
- output.merge_properties_from(label)
+ # if not self._wrap_array_with_image:
+ # return output
+ # output = Image(output)
+ # for label in list_of_labels:
+ # output.merge_properties_from(label)
return output
@@ -8087,7 +8096,7 @@ def get(
raise ValueError(
"Factor must be an integer or a tuple of three integers."
)
-
+
# Create a context for upscaling and perform computation.
ctx = create_context(None, None, None, *factor)
with units.context(ctx):
@@ -8356,7 +8365,7 @@ def get(
list_of_volumes = [list_of_volumes]
for _ in range(max_iters):
-
+
list_of_volumes = [
self._resample_volume_position(volume)
for volume in list_of_volumes
@@ -8411,32 +8420,40 @@ def _check_non_overlapping(
- If bounding cubes overlap, voxel-level checks are performed.
"""
+ from deeptrack.scatterers import ScatteredVolume
- from skimage.morphology import isotropic_erosion, isotropic_dilation
-
- from deeptrack.augmentations import CropTight, Pad
+    from deeptrack.augmentations import CropTight, Pad  # these are not compatible with the torch backend
from deeptrack.optics import _get_position
min_distance = self.min_distance()
crop = CropTight()
+
+ new_volumes = []
- if min_distance < 0:
- list_of_volumes = [
- Image(
- crop(isotropic_erosion(volume != 0, -min_distance/2)),
- copy=False,
- ).merge_properties_from(volume)
- for volume in list_of_volumes
- ]
- else:
- pad = Pad(px = [int(np.ceil(min_distance/2))]*6, keep_size=True)
- list_of_volumes = [
- Image(
- crop(isotropic_dilation(pad(volume) != 0, min_distance/2)),
- copy=False,
- ).merge_properties_from(volume)
- for volume in list_of_volumes
- ]
+ for volume in list_of_volumes:
+ arr = volume.array
+ mask = arr != 0
+
+ if min_distance < 0:
+ new_arr = isotropic_erosion(mask, -min_distance / 2, backend=self.get_backend())
+ else:
+ pad = Pad(px=[int(np.ceil(min_distance / 2))] * 6, keep_size=True)
+                new_arr = isotropic_dilation(pad(mask) != 0, min_distance / 2, backend=self.get_backend())
+ new_arr = crop(new_arr)
+
+ if self.get_backend() == "torch":
+ new_arr = new_arr.to(dtype=arr.dtype)
+ else:
+ new_arr = new_arr.astype(arr.dtype)
+
+ new_volume = ScatteredVolume(
+ array=new_arr,
+ properties=volume.properties.copy(),
+ )
+
+ new_volumes.append(new_volume)
+
+ list_of_volumes = new_volumes
min_distance = 1
# The position of the top left corner of each volume (index (0, 0, 0)).
@@ -8472,10 +8489,10 @@ def _check_non_overlapping(
volume_bounding_cube[i], volume_bounding_cube[j]
)
overlapping_volume_1 = self._get_overlapping_volume(
- list_of_volumes[i], volume_bounding_cube[i], overlapping_cube
+ list_of_volumes[i].array, volume_bounding_cube[i], overlapping_cube
)
overlapping_volume_2 = self._get_overlapping_volume(
- list_of_volumes[j], volume_bounding_cube[j], overlapping_cube
+ list_of_volumes[j].array, volume_bounding_cube[j], overlapping_cube
)
# If either the overlapping regions are empty, the volumes do not
@@ -8710,8 +8727,12 @@ def _check_volumes_non_overlapping(
"""
# Get the positions of the non-zero voxels of each volume.
- positions_1 = np.argwhere(volume_1)
- positions_2 = np.argwhere(volume_2)
+ if self.get_backend() == "torch":
+ positions_1 = torch.nonzero(volume_1, as_tuple=False)
+ positions_2 = torch.nonzero(volume_2, as_tuple=False)
+ else:
+ positions_1 = np.argwhere(volume_1)
+ positions_2 = np.argwhere(volume_2)
# if positions_1.size == 0 or positions_2.size == 0:
# return True # If either volume is empty, they are "non-overlapping"
@@ -8732,9 +8753,14 @@ def _check_volumes_non_overlapping(
# Check that the non-zero voxels of the volumes are at least
# min_distance apart.
- return np.all(
- cdist(positions_1, positions_2) > min_distance
- )
+ if self.get_backend() == "torch":
+ dist = torch.cdist(
+ positions_1.float(),
+ positions_2.float(),
+ )
+ return bool((dist > min_distance).all())
+ else:
+ return np.all(cdist(positions_1, positions_2) > min_distance)
def _resample_volume_position(
self: NonOverlapping,
@@ -8750,7 +8776,7 @@ def _resample_volume_position(
Parameters
----------
- volume: np.ndarray or Image
+ volume: np.ndarray
The 3D volume whose position is to be resampled. The volume must
have a `properties` attribute containing dictionaries with
`position` and `_position_sampler` keys.
@@ -8771,12 +8797,12 @@ def _resample_volume_position(
"""
- for pdict in volume.properties:
- if "position" in pdict and "_position_sampler" in pdict:
- new_position = pdict["_position_sampler"]()
- if isinstance(new_position, Quantity):
- new_position = new_position.to("pixel").magnitude
- pdict["position"] = new_position
+ pdict = volume.properties
+ if "position" in pdict and "_position_sampler" in pdict:
+ new_position = pdict["_position_sampler"]()
+ if isinstance(new_position, Quantity):
+ new_position = new_position.to("pixel").magnitude
+ pdict["position"] = new_position
return volume
@@ -9594,3 +9620,73 @@ def get(
res = res[0]
return res
+
+### Move to math?
+def isotropic_dilation(
+ mask,
+ radius: float,
+ *,
+ backend: str,
+ device=None,
+ dtype=None,
+):
+ if radius <= 0:
+ return mask
+
+ if backend == "numpy":
+ from skimage.morphology import isotropic_dilation
+ return isotropic_dilation(mask, radius)
+
+ # torch backend
+ import torch
+
+ r = int(np.ceil(radius))
+ kernel = torch.ones(
+ (1, 1, 2 * r + 1, 2 * r + 1, 2 * r + 1),
+ device=device or mask.device,
+ dtype=dtype or torch.float32,
+ )
+
+ x = mask.to(dtype=kernel.dtype)[None, None]
+ y = torch.nn.functional.conv3d(
+ x,
+ kernel,
+ padding=r,
+ )
+
+ return (y[0, 0] > 0)
+
+
+def isotropic_erosion(
+ mask,
+ radius: float,
+ *,
+ backend: str,
+ device=None,
+ dtype=None,
+):
+ if radius <= 0:
+ return mask
+
+ if backend == "numpy":
+ from skimage.morphology import isotropic_erosion
+ return isotropic_erosion(mask, radius)
+
+ import torch
+
+ r = int(np.ceil(radius))
+ kernel = torch.ones(
+ (1, 1, 2 * r + 1, 2 * r + 1, 2 * r + 1),
+ device=device or mask.device,
+ dtype=dtype or torch.float32,
+ )
+
+ x = mask.to(dtype=kernel.dtype)[None, None]
+ y = torch.nn.functional.conv3d(
+ x,
+ kernel,
+ padding=r,
+ )
+
+ required = kernel.numel()
+ return (y[0, 0] >= required)
diff --git a/deeptrack/optics.py b/deeptrack/optics.py
index a7c44e8c..a0ad2be9 100644
--- a/deeptrack/optics.py
+++ b/deeptrack/optics.py
@@ -2027,13 +2027,15 @@ def _create_volume(
continue
# Pad scatterer to avoid edge effects during interpolation
- padded_scatterer = scatterer
- padded_scatterer.array = np.pad(
+ padded_scatterer_arr = np.pad( #torch?
scatterer.array,
[(2, 2), (2, 2), (2, 2)],
"constant",
constant_values=0,
)
+ padded_scatterer = ScatteredVolume(
+ array=padded_scatterer_arr,properties=scatterer.properties.copy()
+ )
position = _get_position(padded_scatterer, mode="corner", return_z=True)
shape = np.array(padded_scatterer.array.shape)
diff --git a/deeptrack/scatterers.py b/deeptrack/scatterers.py
index ea0e140d..40d3fce3 100644
--- a/deeptrack/scatterers.py
+++ b/deeptrack/scatterers.py
@@ -348,26 +348,26 @@ def _wrap_output(self, array, props) -> ScatteredBase:
class VolumeScatterer(Scatterer):
"""Abstract scatterer producing ScatteredVolume outputs."""
def _wrap_output(self, array, props) -> ScatteredVolume:
- return [ScatteredVolume(
+ return ScatteredVolume(
array=array,
- position=props.get("position", (0, 0)),
- z=props.get("z", 0.0),
- value=props.get("value", 1.0),
- intensity=props.get("intensity", 1.0),
- refractive_index=props.get("refractive_index", 1.59),
+ # position=props.get("position", (0, 0)),
+ # z=props.get("z", 0.0),
+ # value=props.get("value", 1.0),
+ # intensity=props.get("intensity", 1.0),
+ # refractive_index=props.get("refractive_index", 1.59),
properties=props.copy(),
- main_property=self.main_property,
- )]
+ # position_sampler=props.get("_position_sampler", None),
+ )
class FieldScatterer(Scatterer):
def _wrap_output(self, array, props) -> ScatteredField:
- return [ScatteredField(
+ return ScatteredField(
array=array,
- position=props.get("position", (0, 0)),
- wavelength=props.get("wavelength", 532.0),
+ # position=props.get("position", (0, 0)),
+ # wavelength=props.get("wavelength", 532.0),
properties=props.copy(),
- main_property=self.main_property,
- )]
+ # position_sampler=props.get("_position_sampler", None),
+ )
#TODO ***??*** revise PointParticle - torch, typing, docstring, unit test
@@ -1468,14 +1468,23 @@ class ScatteredBase:
"""Base class for scatterers (volumes and fields)."""
array: ArrayLike
- position: np.ndarray
- z: float = 0.0
+ # position: np.ndarray
+ # z: float = 0.0
properties: dict[str, Any] = field(default_factory=dict)
- main_property: str = None
- def __post_init__(self):
- self.position = np.array(self.position, dtype=float).reshape(-1)[:2]
- self.z = float(np.atleast_1d(self.z).squeeze())
+ # def __post_init__(self):
+ # self.position = np.array(self.position, dtype=float).reshape(-1)[:2]
+ # self.z = float(np.atleast_1d(self.z).squeeze())
+
+ @property
+ def ndim(self) -> int:
+ """Number of dimensions of the underlying array."""
+ return self.array.ndim
+
+    @property
+    def shape(self) -> tuple[int, ...]:
+        """Shape of the underlying array."""
+        return self.array.shape
@property
def pos3d(self) -> np.ndarray:
@@ -1501,13 +1510,16 @@ def get_property(self, key: str, default: Any = None) -> Any:
class ScatteredVolume(ScatteredBase):
"""Volumetric object: intensity sources or refractive index contrasts."""
- refractive_index: float | None = None
- intensity: float | None = None
- value: float | None = None
-
+ # refractive_index: float | None = None
+ # intensity: float | None = None
+ # value: float | None = None
+ # position_sampler: Optional[Callable[[], np.ndarray]] = None
+ pass
@dataclass
class ScatteredField(ScatteredBase):
"""Complex wavefield (already propagated or emitted)."""
- wavelength: float = 500e-9
\ No newline at end of file
+ # wavelength: float = 500e-9
+ # position_sampler: Optional[Callable[[], np.ndarray]] = None
+ pass
\ No newline at end of file
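
A usage sketch for the backend-dispatching morphology helpers added to
features.py above, assuming scikit-image is installed and the two definitions
are in scope:

import numpy as np

mask = np.zeros((9, 9, 9), dtype=bool)
mask[4, 4, 4] = True

# The NumPy branch defers to skimage; the torch branch approximates the
# ball with a cubic conv3d kernel and thresholds the response.
dilated = isotropic_dilation(mask, 2, backend="numpy")
eroded = isotropic_erosion(dilated, 1, backend="numpy")
assert dilated.sum() > eroded.sum()
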
From dcb6fed9022ea1e788d5f3e32b63b5d53042ba9e Mon Sep 17 00:00:00 2001
From: Carlo
Date: Fri, 9 Jan 2026 01:53:36 +0100
Subject: [PATCH 16/24] u
---
deeptrack/scatterers.py | 20 --------------------
1 file changed, 20 deletions(-)
diff --git a/deeptrack/scatterers.py b/deeptrack/scatterers.py
index 40d3fce3..52cbd576 100644
--- a/deeptrack/scatterers.py
+++ b/deeptrack/scatterers.py
@@ -247,9 +247,6 @@ class Scatterer(Feature):
voxel_size=(u.meter, u.meter),
)
- #: Default property name (subclasses override this)
- main_property: str = "value"
-
def __init__(
self,
position: ArrayLike[float] = (32, 32),
@@ -350,23 +347,14 @@ class VolumeScatterer(Scatterer):
def _wrap_output(self, array, props) -> ScatteredVolume:
return ScatteredVolume(
array=array,
- # position=props.get("position", (0, 0)),
- # z=props.get("z", 0.0),
- # value=props.get("value", 1.0),
- # intensity=props.get("intensity", 1.0),
- # refractive_index=props.get("refractive_index", 1.59),
properties=props.copy(),
- # position_sampler=props.get("_position_sampler", None),
)
class FieldScatterer(Scatterer):
def _wrap_output(self, array, props) -> ScatteredField:
return ScatteredField(
array=array,
- # position=props.get("position", (0, 0)),
- # wavelength=props.get("wavelength", 532.0),
properties=props.copy(),
- # position_sampler=props.get("_position_sampler", None),
)
@@ -459,8 +447,6 @@ class Ellipse(VolumeScatterer):
rotation=(u.radian, u.radian),
)
- main_property = "refractive_index"
-
def __init__(
self,
radius: float = 1e-6,
@@ -565,8 +551,6 @@ class Sphere(VolumeScatterer):
radius=(u.meter, u.meter),
)
- main_property = "refractive_index"
-
def __init__(
self,
radius: float = 1e-6,
@@ -642,8 +626,6 @@ class Ellipsoid(VolumeScatterer):
rotation=(u.radian, u.radian),
)
- main_property = "refractive_index"
-
def __init__(
self,
radius: float = 1e-6,
@@ -854,8 +836,6 @@ class MieScatterer(FieldScatterer):
coherence_length=(u.meter, u.pixel),
)
- main_property = "wavelength"
-
def __init__(
self,
coefficients,
From c5ab7b00c9ad716eb6fe4eeb5d8f5a32898f4f1b Mon Sep 17 00:00:00 2001
From: Carlo
Date: Fri, 9 Jan 2026 09:52:44 +0100
Subject: [PATCH 17/24] sampletomask torch compatible
---
deeptrack/features.py | 78 ++++++++++++++++++-------------------------
deeptrack/optics.py | 4 +--
2 files changed, 35 insertions(+), 47 deletions(-)
diff --git a/deeptrack/features.py b/deeptrack/features.py
index 8a031051..1a669887 100644
--- a/deeptrack/features.py
+++ b/deeptrack/features.py
@@ -171,7 +171,7 @@
from deeptrack.backend import config, TORCH_AVAILABLE, xp
from deeptrack.backend.core import DeepTrackNode
from deeptrack.backend.units import ConversionTable, create_context
-# from deeptrack.image import Image #TODO TBE
+from deeptrack.image import Image #TODO TBE
from deeptrack.properties import PropertyDict, SequentialProperty
from deeptrack.sources import SourceItem
from deeptrack.types import ArrayLike, PropertyLike
@@ -7398,7 +7398,7 @@ class SampleToMasks(Feature):
Returns
-------
- Image or np.ndarray
+ np.ndarray
The final mask image with the specified number of layers.
Raises
@@ -7460,7 +7460,7 @@ class SampleToMasks(Feature):
def __init__(
self: Feature,
- transformation_function: Callable[[Image], Image],
+        transformation_function: Callable[[np.ndarray | torch.Tensor], np.ndarray | torch.Tensor],
number_of_masks: PropertyLike[int] = 1,
output_region: PropertyLike[tuple[int, int, int, int]] = None,
merge_method: PropertyLike[str | Callable | list[str | Callable]] = "add",
@@ -7494,16 +7494,16 @@ def __init__(
def get(
self: Feature,
image: np.ndarray,
- transformation_function: Callable[[Image], Image],
+        transformation_function: Callable[[np.ndarray | torch.Tensor], np.ndarray | torch.Tensor],
**kwargs: Any,
- ) -> Image:
+ ) -> np.ndarray:
"""Apply the transformation function to a single image.
Parameters
----------
- image: np.ndarray | Image
+ image: np.ndarray
The input image.
- transformation_function: Callable[[Image], Image]
+        transformation_function: Callable[[np.ndarray | torch.Tensor], np.ndarray | torch.Tensor]
Function to transform the image.
**kwargs: dict[str, Any]
Additional parameters.
@@ -7519,9 +7519,9 @@ def get(
def _process_and_get(
self: Feature,
- images: list[np.ndarray] | np.ndarray | list[Image] | Image,
+ images: list[np.ndarray] | np.ndarray | list[torch.Tensor] | torch.Tensor,
**kwargs: Any,
- ) -> Image | np.ndarray:
+ ) -> np.ndarray:
"""Process a list of images and generate a multi-layer mask.
Parameters
@@ -7542,40 +7542,21 @@ def _process_and_get(
# Handle list of images.
# if isinstance(images, list) and len(images) != 1:
list_of_labels = super()._process_and_get(images, **kwargs)
- # print(len(list_of_labels))
- # print(list_of_labels[0].shape)
from deeptrack.scatterers import ScatteredVolume
- # if not self._wrap_array_with_image:
- for idx, (label, image) in enumerate(zip(list_of_labels,
- images)):
+ for idx, (label, image) in enumerate(zip(list_of_labels, images)):
list_of_labels[idx] = \
- ScatteredVolume(array=label, properties=image.properties.copy())
- # Image(label, copy=False).merge_properties_from(image)
- # else:
- # if isinstance(images, list):
- # images = images[0]
- # list_of_labels = []
- # for prop in images.properties:
-
- # if "position" in prop:
-
- # inp = Image(np.array(images))
- # inp.append(prop)
- # out = Image(self.get(inp, **kwargs))
- # out.merge_properties_from(inp)
- # list_of_labels.append(out)
-
-
+ ScatteredVolume(array=label, properties=image.properties.copy())
# Create an empty output image.
output_region = kwargs["output_region"]
- output = np.zeros(
+ output = xp.zeros(
(
output_region[2] - output_region[0],
output_region[3] - output_region[1],
kwargs["number_of_masks"],
- )
+ ),
+ dtype=list_of_labels[0].array.dtype,
)
from deeptrack.optics import _get_position
@@ -7585,14 +7566,22 @@ def _process_and_get(
label = volume.array
position = _get_position(volume)
- p0 = np.round(position - output_region[0:2])
+ # p0 = np.round(position - output_region[0:2])
+ p0 = xp.round(position - xp.asarray(output_region[0:2]))
+ p0 = p0.astype(xp.int64)
+
- if np.any(p0 > output.shape[0:2]) or \
- np.any(p0 + label.shape[0:2] < 0):
+ # if np.any(p0 > output.shape[0:2]) or \
+ # np.any(p0 + label.shape[0:2] < 0):
+ if xp.any(p0 > xp.asarray(output.shape[:2])) or \
+ xp.any(p0 + xp.asarray(label.shape[:2]) < 0):
continue
- crop_x = int(-np.min([p0[0], 0]))
- crop_y = int(-np.min([p0[1], 0]))
+ # crop_x = int(-np.min([p0[0], 0]))
+ # crop_y = int(-np.min([p0[1], 0]))
+ crop_x = (-xp.minimum(p0[0], 0)).item()
+ crop_y = (-xp.minimum(p0[1], 0)).item()
+
crop_x_end = int(
label.shape[0]
- np.max([p0[0] + label.shape[0] - output.shape[0], 0])
@@ -7644,9 +7633,13 @@ def _process_and_get(
p0[0] : p0[0] + labelarg.shape[0],
p0[1] : p0[1] + labelarg.shape[1],
label_index,
- ] = (output_slice[..., label_index] != 0) | (
+ ] = xp.logical_or(
+ output_slice[..., label_index] != 0,
labelarg[..., label_index] != 0
- )
+ )
+ # (output_slice[..., label_index] != 0) | (
+ # labelarg[..., label_index] != 0
+ # )
elif merge == "mul":
output[
@@ -7666,11 +7659,6 @@ def _process_and_get(
labelarg[..., label_index],
)
- # if not self._wrap_array_with_image:
- # return output
- # output = Image(output)
- # for label in list_of_labels:
- # output.merge_properties_from(label)
return output
diff --git a/deeptrack/optics.py b/deeptrack/optics.py
index a0ad2be9..c2c87b14 100644
--- a/deeptrack/optics.py
+++ b/deeptrack/optics.py
@@ -2027,14 +2027,14 @@ def _create_volume(
continue
# Pad scatterer to avoid edge effects during interpolation
- padded_scatterer_arr = np.pad( #torch?
+ padded_scatterer_arr = np.pad( #Use Pad instead and make it torch-compatible?
scatterer.array,
[(2, 2), (2, 2), (2, 2)],
"constant",
constant_values=0,
)
padded_scatterer = ScatteredVolume(
- array=padded_scatterer_arr,properties=scatterer.properties.copy()
+ array=padded_scatterer_arr, properties=scatterer.properties.copy()
)
position = _get_position(padded_scatterer, mode="corner", return_z=True)
shape = np.array(padded_scatterer.array.shape)
From fec2411db3bfdfe928e51c935bca921f8f4704ef Mon Sep 17 00:00:00 2001
From: Carlo
Date: Fri, 9 Jan 2026 11:03:02 +0100
Subject: [PATCH 18/24] unified field and volume wrapper
---
deeptrack/features.py | 18 +++-----
deeptrack/optics.py | 38 ++++-------------
deeptrack/scatterers.py | 93 +++++++++++++++++++++++------------------
3 files changed, 65 insertions(+), 84 deletions(-)
diff --git a/deeptrack/features.py b/deeptrack/features.py
index 1a669887..be82b558 100644
--- a/deeptrack/features.py
+++ b/deeptrack/features.py
@@ -7543,10 +7543,11 @@ def _process_and_get(
# if isinstance(images, list) and len(images) != 1:
list_of_labels = super()._process_and_get(images, **kwargs)
- from deeptrack.scatterers import ScatteredVolume
+ from deeptrack.scatterers import ScatteredObject
+
for idx, (label, image) in enumerate(zip(list_of_labels, images)):
list_of_labels[idx] = \
- ScatteredVolume(array=label, properties=image.properties.copy())
+ ScatteredObject(array=label, properties=image.properties.copy(), role=image.role)
# Create an empty output image.
output_region = kwargs["output_region"]
@@ -7566,19 +7567,14 @@ def _process_and_get(
label = volume.array
position = _get_position(volume)
- # p0 = np.round(position - output_region[0:2])
p0 = xp.round(position - xp.asarray(output_region[0:2]))
p0 = p0.astype(xp.int64)
- # if np.any(p0 > output.shape[0:2]) or \
- # np.any(p0 + label.shape[0:2] < 0):
if xp.any(p0 > xp.asarray(output.shape[:2])) or \
xp.any(p0 + xp.asarray(label.shape[:2]) < 0):
continue
- # crop_x = int(-np.min([p0[0], 0]))
- # crop_y = int(-np.min([p0[1], 0]))
crop_x = (-xp.minimum(p0[0], 0)).item()
crop_y = (-xp.minimum(p0[1], 0)).item()
@@ -7637,9 +7633,6 @@ def _process_and_get(
output_slice[..., label_index] != 0,
labelarg[..., label_index] != 0
)
- # (output_slice[..., label_index] != 0) | (
- # labelarg[..., label_index] != 0
- # )
elif merge == "mul":
output[
@@ -8408,7 +8401,7 @@ def _check_non_overlapping(
- If bounding cubes overlap, voxel-level checks are performed.
"""
- from deeptrack.scatterers import ScatteredVolume
+ from deeptrack.scatterers import ScatteredObject
from deeptrack.augmentations import CropTight, Pad  # these are not compatible with the torch backend
from deeptrack.optics import _get_position
@@ -8434,9 +8427,10 @@ def _check_non_overlapping(
else:
new_arr = new_arr.astype(arr.dtype)
- new_volume = ScatteredVolume(
+ new_volume = ScatteredObject(
array=new_arr,
properties=volume.properties.copy(),
+ role=volume.role,
)
new_volumes.append(new_volume)
diff --git a/deeptrack/optics.py b/deeptrack/optics.py
index c2c87b14..d0b5d66b 100644
--- a/deeptrack/optics.py
+++ b/deeptrack/optics.py
@@ -162,7 +162,7 @@ def _pad_volume(
from deeptrack import TORCH_AVAILABLE, image
from deeptrack.backend import xp
-from deeptrack.scatterers import ScatteredVolume, ScatteredField
+from deeptrack.scatterers import ScatteredObject
if TORCH_AVAILABLE:
import torch
@@ -342,14 +342,14 @@ def get(
volume_samples = [
scatterer
for scatterer in list_of_scatterers
- if isinstance(scatterer, ScatteredVolume)
+ if scatterer.role == "volume"
]
# All scatterers that are defined as fields.
field_samples = [
scatterer
for scatterer in list_of_scatterers
- if isinstance(scatterer, ScatteredField)
+ if scatterer.role == "field"
]
# Merge all volumes into a single volume.
@@ -1810,7 +1810,7 @@ def get(
#TODO ***??*** revise _get_position - torch, typing, docstring, unit test
def _get_position(
- scatterer: ScatteredVolume,
+ scatterer: ScatteredObject,
mode: str = "corner",
return_z: bool = False,
) -> np.ndarray:
@@ -2033,8 +2033,8 @@ def _create_volume(
"constant",
constant_values=0,
)
- padded_scatterer = ScatteredVolume(
- array=padded_scatterer_arr, properties=scatterer.properties.copy()
+ padded_scatterer = ScatteredObject(
+ array=padded_scatterer_arr, properties=scatterer.properties.copy(), role=scatterer.role,
)
position = _get_position(padded_scatterer, mode="corner", return_z=True)
shape = np.array(padded_scatterer.array.shape)
@@ -2060,30 +2060,6 @@ def _create_volume(
"Expected np.ndarray or torch.Tensor."
)
- # kernel = np.array(
- # [
- # [0, 0, 0],
- # [0, (1 - x_off) * (1 - y_off), (1 - x_off) * y_off],
- # [0, x_off * (1 - y_off), x_off * y_off],
- # ]
- # )
-
- # for z in range(padded_scatterer.array.shape[2]):
- # if splined_scatterer.dtype == complex:
- # splined_scatterer[:, :, z] = (
- # convolve(
- # np.real(padded_scatterer.array[:, :, z]), kernel, mode="constant"
- # )
- # + convolve(
- # np.imag(padded_scatterer.array[:, :, z]), kernel, mode="constant"
- # )
- # * 1j
- # )
- # else:
- # splined_scatterer[:, :, z] = convolve(
- # padded_scatterer.array[:, :, z], kernel, mode="constant"
- # )
-
position = np.floor(position)
new_limits = np.zeros(limits.shape, dtype=np.int32)
for i in range(3):
@@ -2113,7 +2089,7 @@ def _create_volume(
within_volume_position = position - limits[:, 0]
# NOTE: Maybe shouldn't be additive.
- # give options: sum default, but also sum, mean, max, min
+    # give options: sum by default, but also mean, max, min
volume[
int(within_volume_position[0]) :
int(within_volume_position[0] + shape[0]),
diff --git a/deeptrack/scatterers.py b/deeptrack/scatterers.py
index 52cbd576..9e45664f 100644
--- a/deeptrack/scatterers.py
+++ b/deeptrack/scatterers.py
@@ -338,28 +338,33 @@ def _process_and_get(
# props = kwargs.copy()
return [self._wrap_output(new_image, kwargs)]
- def _wrap_output(self, array, props) -> ScatteredBase:
- """Must be overridden in subclasses to wrap output correctly."""
- raise NotImplementedError
-
-class VolumeScatterer(Scatterer):
- """Abstract scatterer producing ScatteredVolume outputs."""
- def _wrap_output(self, array, props) -> ScatteredVolume:
- return ScatteredVolume(
+ def _wrap_output(self, array, props) -> ScatteredObject:
+ # """Must be overridden in subclasses to wrap output correctly."""
+ # raise NotImplementedError
+ return ScatteredObject(
array=array,
properties=props.copy(),
+            role=self.role,
)
-class FieldScatterer(Scatterer):
- def _wrap_output(self, array, props) -> ScatteredField:
- return ScatteredField(
- array=array,
- properties=props.copy(),
- )
+# class VolumeScatterer(Scatterer):
+# """Abstract scatterer producing ScatteredVolume outputs."""
+# def _wrap_output(self, array, props) -> ScatteredVolume:
+# return ScatteredVolume(
+# array=array,
+# properties=props.copy(),
+# )
+
+# class FieldScatterer(Scatterer):
+# def _wrap_output(self, array, props) -> ScatteredField:
+# return ScatteredField(
+# array=array,
+# properties=props.copy(),
+# )
#TODO ***??*** revise PointParticle - torch, typing, docstring, unit test
-class PointParticle(VolumeScatterer):
+class PointParticle(Scatterer):
"""Generate a diffraction-limited point particle.
A point particle is approximated by the size of a single pixel or voxel.
@@ -382,7 +387,8 @@ class PointParticle(VolumeScatterer):
for `Brightfield` and `intensity` for `Fluorescence`).
"""
-
+ role = "volume"
+
def __init__(
self: PointParticle,
**kwargs: Any,
@@ -406,7 +412,7 @@ def get(
#TODO ***??*** revise Ellipse - torch, typing, docstring, unit test
-class Ellipse(VolumeScatterer):
+class Ellipse(Scatterer):
"""Generates an elliptical disk scatterer
Parameters
@@ -441,6 +447,7 @@ class Ellipse(VolumeScatterer):
before rotation.
"""
+ role = "volume"
__conversion_table__ = ConversionTable(
radius=(u.meter, u.meter),
@@ -520,7 +527,7 @@ def get(
#TODO ***??*** revise Sphere - torch, typing, docstring, unit test
-class Sphere(VolumeScatterer):
+class Sphere(Scatterer):
"""Generates a spherical scatterer
Parameters
@@ -546,6 +553,7 @@ class Sphere(VolumeScatterer):
Upsamples the calculations of the pixel occupancy fraction.
"""
+ role = "volume"
__conversion_table__ = ConversionTable(
radius=(u.meter, u.meter),
@@ -585,7 +593,7 @@ def get(
#TODO ***??*** revise Ellipsoid - torch, typing, docstring, unit test
-class Ellipsoid(VolumeScatterer):
+class Ellipsoid(Scatterer):
"""Generates an ellipsoidal scatterer
Parameters
@@ -621,6 +629,8 @@ class Ellipsoid(VolumeScatterer):
"""
+ role = "volume"
+
__conversion_table__ = ConversionTable(
radius=(u.meter, u.meter),
rotation=(u.radian, u.radian),
@@ -742,7 +752,7 @@ def get(
#TODO ***??*** revise MieScatterer - torch, typing, docstring, unit test
-class MieScatterer(FieldScatterer):
+class MieScatterer(Scatterer):
"""Base implementation of a Mie particle.
New Mie-theory scatterers can be implemented by extending this class, and
@@ -827,6 +837,8 @@ class MieScatterer(FieldScatterer):
"""
+ role = "field"
+
__conversion_table__ = ConversionTable(
radius=(u.meter, u.meter),
polarization_angle=(u.radian, u.radian),
@@ -1304,6 +1316,8 @@ class MieSphere(MieScatterer):
"""
+ role = "field"
+
def __init__(
self,
radius: float = 1e-6,
@@ -1406,6 +1420,8 @@ class MieStratifiedSphere(MieScatterer):
"""
+ role = "field"
+
def __init__(
self,
radius: ArrayLike[float] = [1e-6],
@@ -1444,17 +1460,12 @@ def inner(
@dataclass
-class ScatteredBase:
+class ScatteredObject:
"""Base class for scatterers (volumes and fields)."""
array: ArrayLike
- # position: np.ndarray
- # z: float = 0.0
properties: dict[str, Any] = field(default_factory=dict)
-
- # def __post_init__(self):
- # self.position = np.array(self.position, dtype=float).reshape(-1)[:2]
- # self.z = float(np.atleast_1d(self.z).squeeze())
+ role: Literal["volume", "field"] = "volume"
@property
def ndim(self) -> int:
@@ -1486,20 +1497,20 @@ def get_property(self, key: str, default: Any = None) -> Any:
return getattr(self, key, self.properties.get(key, default))
-@dataclass
-class ScatteredVolume(ScatteredBase):
- """Volumetric object: intensity sources or refractive index contrasts."""
+# @dataclass
+# class ScatteredVolume(ScatteredBase):
+# """Volumetric object: intensity sources or refractive index contrasts."""
- # refractive_index: float | None = None
- # intensity: float | None = None
- # value: float | None = None
- # position_sampler: Optional[Callable[[], np.ndarray]] = None
- pass
+# # refractive_index: float | None = None
+# # intensity: float | None = None
+# # value: float | None = None
+# # position_sampler: Optional[Callable[[], np.ndarray]] = None
+# pass
-@dataclass
-class ScatteredField(ScatteredBase):
- """Complex wavefield (already propagated or emitted)."""
+# @dataclass
+# class ScatteredField(ScatteredBase):
+# """Complex wavefield (already propagated or emitted)."""
- # wavelength: float = 500e-9
- # position_sampler: Optional[Callable[[], np.ndarray]] = None
- pass
\ No newline at end of file
+# # wavelength: float = 500e-9
+# # position_sampler: Optional[Callable[[], np.ndarray]] = None
+# pass
\ No newline at end of file
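
A minimal sketch of the unified wrapper this patch converges on; the dataclass
below mirrors ScatteredObject with its role discriminator (names and values
illustrative):

from dataclasses import dataclass, field
from typing import Any, Literal

import numpy as np

@dataclass
class ScatteredObjectSketch:
    array: np.ndarray
    properties: dict[str, Any] = field(default_factory=dict)
    role: Literal["volume", "field"] = "volume"

scatterers = [
    ScatteredObjectSketch(np.ones((4, 4, 4)), {"position": (8, 8)}),
    ScatteredObjectSketch(np.ones((8, 8), dtype=complex), role="field"),
]
# The Microscope now splits the stream by role rather than by subclass.
volume_samples = [s for s in scatterers if s.role == "volume"]
field_samples = [s for s in scatterers if s.role == "field"]
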
From ac248c7b1382f9fa04ef6c65d812feecd5b6b6ff Mon Sep 17 00:00:00 2001
From: Carlo
Date: Fri, 9 Jan 2026 14:18:15 +0100
Subject: [PATCH 19/24] u
---
deeptrack/scatterers.py | 29 ++++++++++-------------------
1 file changed, 10 insertions(+), 19 deletions(-)
diff --git a/deeptrack/scatterers.py b/deeptrack/scatterers.py
index 9e45664f..b942888c 100644
--- a/deeptrack/scatterers.py
+++ b/deeptrack/scatterers.py
@@ -1481,6 +1481,16 @@ def shape(self) -> tuple[int, ...]:
def pos3d(self) -> np.ndarray:
return np.array([*self.position, self.z], dtype=float)
+ @property
+ def position(self) -> np.ndarray:
+ pos = self.properties.get("position", None)
+ if pos is None:
+ return None
+ pos = np.asarray(pos, dtype=float)
+ if pos.ndim == 2 and pos.shape[0] == 1:
+ pos = pos[0]
+ return pos
+
def as_array(self) -> ArrayLike:
"""Return the underlying array.
@@ -1495,22 +1505,3 @@ def as_array(self) -> ArrayLike:
def get_property(self, key: str, default: Any = None) -> Any:
return getattr(self, key, self.properties.get(key, default))
-
-
-# @dataclass
-# class ScatteredVolume(ScatteredBase):
-# """Volumetric object: intensity sources or refractive index contrasts."""
-
-# # refractive_index: float | None = None
-# # intensity: float | None = None
-# # value: float | None = None
-# # position_sampler: Optional[Callable[[], np.ndarray]] = None
-# pass
-
-# @dataclass
-# class ScatteredField(ScatteredBase):
-# """Complex wavefield (already propagated or emitted)."""
-
-# # wavelength: float = 500e-9
-# # position_sampler: Optional[Callable[[], np.ndarray]] = None
-# pass
\ No newline at end of file
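
A sketch of the normalization the new position property performs: properties
may store the position as a (1, 2) array, which is flattened to (2,):

import numpy as np

pos = np.asarray([[12.0, 34.0]], dtype=float)
if pos.ndim == 2 and pos.shape[0] == 1:
    pos = pos[0]
assert pos.shape == (2,)
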
From 10e15523cd5a0eef9278b0b2c3bfc6bb5f4a8284 Mon Sep 17 00:00:00 2001
From: Carlo
Date: Thu, 15 Jan 2026 01:51:27 +0100
Subject: [PATCH 20/24] u
---
deeptrack/features.py | 97 ++-----------
deeptrack/math.py | 291 ++++++++++++++++++++++++++++++++++++++-
deeptrack/optics.py | 293 +++++++++++++++++++++++-----------------
deeptrack/scatterers.py | 116 ++++++++++------
4 files changed, 548 insertions(+), 249 deletions(-)
diff --git a/deeptrack/features.py b/deeptrack/features.py
index be82b558..e3092698 100644
--- a/deeptrack/features.py
+++ b/deeptrack/features.py
@@ -96,7 +96,7 @@
- `TakeProperties`: Extract all instances of properties from a pipeline.
Arithmetic Feature Classes:
-- `Add`: Add a value to the input.
+- `Add`: Add a value to the input.
- `Subtract`: Subtract a value from the input.
- `Multiply`: Multiply the input by a value.
- `Divide`: Divide the input by a value.
@@ -7543,11 +7543,11 @@ def _process_and_get(
# if isinstance(images, list) and len(images) != 1:
list_of_labels = super()._process_and_get(images, **kwargs)
- from deeptrack.scatterers import ScatteredObject
+ from deeptrack.scatterers import ScatteredVolume
for idx, (label, image) in enumerate(zip(list_of_labels, images)):
list_of_labels[idx] = \
- ScatteredObject(array=label, properties=image.properties.copy(), role=image.role)
+ ScatteredVolume(array=label, properties=image.properties.copy())
# Create an empty output image.
output_region = kwargs["output_region"]
@@ -8080,15 +8080,18 @@ def get(
# Create a context for upscaling and perform computation.
ctx = create_context(None, None, None, *factor)
+
+ print('before:', image)
with units.context(ctx):
image = self.feature(image)
- # # Downscale the result to the original resolution.
- # import skimage.measure
+ print('after:', image)
+ # Downscale the result to the original resolution.
+ import skimage.measure
- # image = skimage.measure.block_reduce(
- # image, (factor[0], factor[1]) + (1,) * (image.ndim - 2), np.mean
- # )
+ image = skimage.measure.block_reduce(
+ image, (factor[0], factor[1]) + (1,) * (image.ndim - 2), np.mean
+ )
return image
@@ -8401,10 +8404,11 @@ def _check_non_overlapping(
- If bounding cubes overlap, voxel-level checks are performed.
"""
- from deeptrack.scatterers import ScatteredObject
+ from deeptrack.scatterers import ScatteredVolume
from deeptrack.augmentations import CropTight, Pad  # these are not compatible with the torch backend
from deeptrack.optics import _get_position
+ from deeptrack.math import isotropic_erosion, isotropic_dilation
min_distance = self.min_distance()
crop = CropTight()
@@ -8427,10 +8431,9 @@ def _check_non_overlapping(
else:
new_arr = new_arr.astype(arr.dtype)
- new_volume = ScatteredObject(
+ new_volume = ScatteredVolume(
array=new_arr,
properties=volume.properties.copy(),
- role=volume.role,
)
new_volumes.append(new_volume)
@@ -9601,74 +9604,4 @@ def get(
if len(res) == 1:
res = res[0]
- return res
-
-### Move to math?
-def isotropic_dilation(
- mask,
- radius: float,
- *,
- backend: str,
- device=None,
- dtype=None,
-):
- if radius <= 0:
- return mask
-
- if backend == "numpy":
- from skimage.morphology import isotropic_dilation
- return isotropic_dilation(mask, radius)
-
- # torch backend
- import torch
-
- r = int(np.ceil(radius))
- kernel = torch.ones(
- (1, 1, 2 * r + 1, 2 * r + 1, 2 * r + 1),
- device=device or mask.device,
- dtype=dtype or torch.float32,
- )
-
- x = mask.to(dtype=kernel.dtype)[None, None]
- y = torch.nn.functional.conv3d(
- x,
- kernel,
- padding=r,
- )
-
- return (y[0, 0] > 0)
-
-
-def isotropic_erosion(
- mask,
- radius: float,
- *,
- backend: str,
- device=None,
- dtype=None,
-):
- if radius <= 0:
- return mask
-
- if backend == "numpy":
- from skimage.morphology import isotropic_erosion
- return isotropic_erosion(mask, radius)
-
- import torch
-
- r = int(np.ceil(radius))
- kernel = torch.ones(
- (1, 1, 2 * r + 1, 2 * r + 1, 2 * r + 1),
- device=device or mask.device,
- dtype=dtype or torch.float32,
- )
-
- x = mask.to(dtype=kernel.dtype)[None, None]
- y = torch.nn.functional.conv3d(
- x,
- kernel,
- padding=r,
- )
-
- required = kernel.numel()
- return (y[0, 0] >= required)
+ return res
\ No newline at end of file
diff --git a/deeptrack/math.py b/deeptrack/math.py
index 05cbf311..0af95c05 100644
--- a/deeptrack/math.py
+++ b/deeptrack/math.py
@@ -93,7 +93,7 @@
from __future__ import annotations
-from typing import Any, Callable, TYPE_CHECKING
+from typing import Any, Callable, Dict, Tuple, TYPE_CHECKING
import array_api_compat as apc
import numpy as np
@@ -110,6 +110,7 @@
if TORCH_AVAILABLE:
import torch
+ import torch.nn.functional as F
if OPENCV_AVAILABLE:
import cv2
@@ -129,11 +130,15 @@
"MaxPooling",
"MinPooling",
"MedianPooling",
+ "PoolV2",
+ "AveragePoolingV2",
+ "MaxPoolingV2",
+ "MinPoolingV2",
+ "MedianPoolingV2",
"BlurCV2",
"BilateralBlur",
]
-
if TYPE_CHECKING:
import torch
@@ -1663,6 +1668,218 @@ def __init__(
super().__init__(np.median, ksize=ksize, **kwargs)
+class PoolV2:
+ """
+ DeepTrack v2 replacement for Pool.
+
+ Generic, center-preserving block pooling with NumPy and Torch backends.
+ Public API matches v1: a single integer ksize.
+
+ Pool size semantics:
+ - 2D input -> (ksize, ksize, 1)
+ - 3D input -> (ksize, ksize, ksize)
+ """
+
+ _TORCH_REDUCERS_2D: Dict[Callable, Callable] = {
+ np.mean: lambda x, k, s: F.avg_pool2d(x, k, s),
+ np.sum: lambda x, k, s: F.avg_pool2d(x, k, s) * (k[0] * k[1]),
+ np.max: lambda x, k, s: F.max_pool2d(x, k, s),
+ np.min: lambda x, k, s: -F.max_pool2d(-x, k, s),
+ }
+
+ _TORCH_REDUCERS_3D: Dict[Callable, Callable] = {
+ np.mean: lambda x, k, s: F.avg_pool3d(x, k, s),
+ np.sum: lambda x, k, s: F.avg_pool3d(x, k, s) * (k[0] * k[1] * k[2]),
+ np.max: lambda x, k, s: F.max_pool3d(x, k, s),
+ np.min: lambda x, k, s: -F.max_pool3d(-x, k, s),
+ }
+
+ def __init__(
+ self,
+ pooling_function: Callable,
+ ksize: int = 2,
+ ):
+ if pooling_function not in (
+ np.mean, np.sum, np.min, np.max, np.median
+ ):
+ raise ValueError(
+ "Unsupported pooling_function. "
+ "Use one of: np.mean, np.sum, np.min, np.max, np.median."
+ )
+
+ if not isinstance(ksize, int) or ksize < 1:
+ raise ValueError("ksize must be a positive integer.")
+
+ self.pooling_function = pooling_function
+ self.ksize = int(ksize)
+
+ def _get_pool_size(self, array) -> Tuple[int, int, int]:
+ """
+ Determine pooling kernel size based on semantic dimensionality.
+
+ - 2D images: (Nx, Ny) or (Nx, Ny, C) -> pool in x,y only
+ - 3D volumes: (Nx, Ny, Nz) or (Nx, Ny, Nz, C) -> pool in x,y,z
+ - Never pool over channels
+ """
+ k = self.ksize
+
+ # 2D image
+ if array.ndim == 2:
+ return k, k, 1
+
+ # 3D array: could be (x, y, z) or (x, y, c)
+ if array.ndim == 3:
+ # Heuristic: small last dim → channels
+ if array.shape[-1] <= 4:
+ return k, k, 1
+ return k, k, k
+
+ # 4D array: (x, y, z, c)
+ if array.ndim == 4:
+ return k, k, k
+
+ raise ValueError(
+ f"Unsupported array shape {array.shape} for pooling."
+ )
+
+ def _crop_center(self, array):
+ px, py, pz = self._get_pool_size(array)
+
+ # 2D (or effectively 2D)
+ if array.ndim < 3 or pz == 1:
+ H, W = array.shape[:2]
+ crop_h = (H // px) * px
+ crop_w = (W // py) * py
+ off_h = (H - crop_h) // 2
+ off_w = (W - crop_w) // 2
+ return array[
+ off_h : off_h + crop_h,
+ off_w : off_w + crop_w,
+ ...
+ ]
+
+ # 3D
+ Z, H, W = array.shape[:3]
+ crop_z = (Z // pz) * pz
+ crop_h = (H // px) * px
+ crop_w = (W // py) * py
+ off_z = (Z - crop_z) // 2
+ off_h = (H - crop_h) // 2
+ off_w = (W - crop_w) // 2
+ return array[
+ off_z : off_z + crop_z,
+ off_h : off_h + crop_h,
+ off_w : off_w + crop_w,
+ ...
+ ]
+
+ def _pool_numpy(self, array: np.ndarray) -> np.ndarray:
+ array = self._crop_center(array)
+ px, py, pz = self._get_pool_size(array)
+
+ if array.ndim < 3 or pz == 1:
+ pool_shape = (px, py) + (1,) * (array.ndim - 2)
+ else:
+ pool_shape = (pz, px, py) + (1,) * (array.ndim - 3)
+
+ return skimage.measure.block_reduce(
+ array,
+ block_size=pool_shape,
+ func=self.pooling_function,
+ )
+
+ def _pool_torch(self, array: torch.Tensor) -> torch.Tensor:
+ array = self._crop_center(array)
+ px, py, pz = self._get_pool_size(array)
+
+ is_3d = array.ndim >= 3 and pz > 1
+
+ if not is_3d:
+ extra = array.shape[2:]
+ C = int(np.prod(extra)) if extra else 1
+ x = array.reshape(1, C, array.shape[0], array.shape[1])
+ kernel = (px, py)
+ stride = (px, py)
+ reducers = self._TORCH_REDUCERS_2D
+ else:
+ extra = array.shape[3:]
+ C = int(np.prod(extra)) if extra else 1
+ x = array.reshape(
+ 1, C, array.shape[0], array.shape[1], array.shape[2]
+ )
+ kernel = (pz, px, py)
+ stride = (pz, px, py)
+ reducers = self._TORCH_REDUCERS_3D
+
+ # Median: explicit unfolding
+ if self.pooling_function is np.median:
+ if is_3d:
+ x_u = (
+ x.unfold(2, pz, pz)
+ .unfold(3, px, px)
+ .unfold(4, py, py)
+ )
+ x_u = x_u.contiguous().view(
+ 1, C,
+ x_u.shape[2],
+ x_u.shape[3],
+ x_u.shape[4],
+ -1,
+ )
+ pooled = x_u.median(dim=-1).values
+ else:
+ x_u = x.unfold(2, px, px).unfold(3, py, py)
+ x_u = x_u.contiguous().view(
+ 1, C,
+ x_u.shape[2],
+ x_u.shape[3],
+ -1,
+ )
+ pooled = x_u.median(dim=-1).values
+ else:
+ reducer = reducers[self.pooling_function]
+ pooled = reducer(x, kernel, stride)
+
+ return pooled.reshape(pooled.shape[2:] + extra)
+
+ def __call__(self, array):
+ if isinstance(array, np.ndarray):
+ return self._pool_numpy(array)
+
+ if TORCH_AVAILABLE and isinstance(array, torch.Tensor):
+ return self._pool_torch(array)
+
+ raise TypeError(
+ "PoolV2 only supports np.ndarray or torch.Tensor inputs."
+ )
+
+
+class AveragePoolingV2(PoolV2):
+ def __init__(self, ksize: int = 2):
+ super().__init__(np.mean, ksize)
+
+
+class SumPoolingV2(PoolV2):
+ def __init__(self, ksize: int = 2):
+ super().__init__(np.sum, ksize)
+
+
+class MinPoolingV2(PoolV2):
+ def __init__(self, ksize: int = 2):
+ super().__init__(np.min, ksize)
+
+
+class MaxPoolingV2(PoolV2):
+ def __init__(self, ksize: int = 2):
+ super().__init__(np.max, ksize)
+
+
+class MedianPoolingV2(PoolV2):
+ def __init__(self, ksize: int = 2):
+ super().__init__(np.median, ksize)
+
+
+
class Resize(Feature):
"""Resize an image to a specified size.
@@ -2059,3 +2276,73 @@ def __init__(
sigmaSpace=sigma_space,
**kwargs,
)
+
+
+def isotropic_dilation(
+ mask,
+ radius: float,
+ *,
+ backend: str,
+ device=None,
+ dtype=None,
+):
+ if radius <= 0:
+ return mask
+
+ if backend == "numpy":
+ from skimage.morphology import isotropic_dilation
+ return isotropic_dilation(mask, radius)
+
+ # torch backend
+ import torch
+
+ r = int(np.ceil(radius))
+ kernel = torch.ones(
+ (1, 1, 2 * r + 1, 2 * r + 1, 2 * r + 1),
+ device=device or mask.device,
+ dtype=dtype or torch.float32,
+ )
+
+ x = mask.to(dtype=kernel.dtype)[None, None]
+ y = torch.nn.functional.conv3d(
+ x,
+ kernel,
+ padding=r,
+ )
+
+ return (y[0, 0] > 0)
+
+
+def isotropic_erosion(
+ mask,
+ radius: float,
+ *,
+ backend: str,
+ device=None,
+ dtype=None,
+):
+ if radius <= 0:
+ return mask
+
+ if backend == "numpy":
+ from skimage.morphology import isotropic_erosion
+ return isotropic_erosion(mask, radius)
+
+ import torch
+
+ r = int(np.ceil(radius))
+ kernel = torch.ones(
+ (1, 1, 2 * r + 1, 2 * r + 1, 2 * r + 1),
+ device=device or mask.device,
+ dtype=dtype or torch.float32,
+ )
+
+ x = mask.to(dtype=kernel.dtype)[None, None]
+ y = torch.nn.functional.conv3d(
+ x,
+ kernel,
+ padding=r,
+ )
+
+ required = kernel.numel()
+ return (y[0, 0] >= required)
\ No newline at end of file
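
A usage sketch for the PoolV2 family defined above, assuming the classes are
in scope and scikit-image is installed (NumPy inputs shown; torch tensors
follow the same call shape):

import numpy as np

img = np.arange(16, dtype=float).reshape(4, 4)  # 2D image: pooled in x, y
vol = np.ones((8, 8, 8))                        # 3D volume: pooled in x, y, z
rgb = np.ones((8, 8, 3))                        # last dim <= 4: channels

assert AveragePoolingV2(2)(img).shape == (2, 2)
assert MaxPoolingV2(2)(vol).shape == (4, 4, 4)
assert MedianPoolingV2(2)(rgb).shape == (4, 4, 3)
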
diff --git a/deeptrack/optics.py b/deeptrack/optics.py
index d0b5d66b..0788d7d8 100644
--- a/deeptrack/optics.py
+++ b/deeptrack/optics.py
@@ -151,7 +151,7 @@ def _pad_volume(
get_active_scale,
get_active_voxel_size,
)
-from deeptrack.math import AveragePooling
+from deeptrack.math import AveragePoolingV2, SumPoolingV2
from deeptrack.features import propagate_data_to_dependencies
from deeptrack.features import DummyFeature, Feature, StructuralFeature
from deeptrack.image import pad_image_to_fft
@@ -162,7 +162,7 @@ def _pad_volume(
from deeptrack import TORCH_AVAILABLE, image
from deeptrack.backend import xp
-from deeptrack.scatterers import ScatteredObject
+from deeptrack.scatterers import ScatteredVolume, ScatteredField
if TORCH_AVAILABLE:
import torch
@@ -175,9 +175,13 @@ def _pad_volume(
class Microscope(StructuralFeature):
"""Simulates imaging of a sample using an optical system.
- This class combines a feature-set that defines the sample to be imaged with
- a feature-set defining the optical system, enabling the simulation of
- optical imaging processes.
+ This class combines the sample to be imaged with the optical system,
+ enabling the simulation of optical imaging processes.
+ A Microscope:
+ - validates the semantic compatibility between scatterers and optics
+ - interprets volume-based scatterers into scalar fields when needed
+ - delegates numerical propagation to the objective (Optics)
+ - performs detector downscaling according to its physical semantics
Parameters
----------
@@ -202,6 +206,12 @@ class Microscope(StructuralFeature):
Simulates the imaging process using the defined optical system and
returns the resulting image.
+ Notes
+ -----
+ All volume scatterers imaged by a Microscope instance are assumed to
+ share the same contrast mechanism (e.g. refractive index or fluorescence).
+ Mixing contrast types is not supported.
+
Examples
--------
Simulating an image using a brightfield optical system:
@@ -250,13 +260,40 @@ def __init__(
self._sample = self.add_feature(sample)
self._objective = self.add_feature(objective)
- # self._sample.store_properties()
+
+ def _validate_input(self, scattered):
+ if hasattr(self._objective, "validate_input"):
+ self._objective.validate_input(scattered)
+
+ def _extract_contrast_volume(self, scattered):
+ if hasattr(self._objective, "extract_contrast_volume"):
+ return self._objective.extract_contrast_volume(scattered)
+
+ # default: geometry-only
+ return scattered.array
+
+ def _downscale_image(self, image, upscale):
+ if hasattr(self._objective, "downscale_image"):
+ return self._objective.downscale_image(image, upscale)
+
+ if not np.any(np.array(upscale) != 1):
+ return image
+
+ ux, uy = upscale[:2]
+ if ux != uy:
+ raise ValueError(
+ f"Energy-conserving detector integration requires ux == uy, "
+ f"got ux={ux}, uy={uy}."
+ )
+ if isinstance(ux, float) and ux.is_integer():
+ ux = int(ux)
+ return AveragePoolingV2(ux)(image)
def get(
self: Microscope,
- image: np.ndarray | None,
+ image: np.ndarray | torch.Tensor | None = None,
**kwargs: Any,
- ) -> np.ndarray:
+ ) -> np.ndarray | torch.Tensor:
"""Generate an image of the sample using the defined optical system.
This method processes the sample through the optical system to
@@ -264,14 +301,14 @@ def get(
Parameters
----------
- image: np.ndarray | None
+ image: np.ndarray | torch.Tensor | None
The input image to be processed. If None, a new image is created.
**kwargs: Any
Additional parameters for the imaging process.
Returns
-------
- image: np.ndarray
+ image: np.ndarray | torch.Tensor
The processed image after applying the optical system.
Examples
@@ -291,18 +328,14 @@ def get(
# Grab properties from the objective to pass to the sample
additional_sample_kwargs = self._objective.properties()
- contrast_type = getattr(self._objective, "contrast_type", None)
- if contrast_type is None:
- raise RuntimeError(
- f"{self._objective.__class__.__name__} must define `contrast_type` "
- "(e.g. 'intensity' or 'refractive_index')."
- )
- additional_sample_kwargs["contrast_type"] = contrast_type
+ _upscale_given_by_optics = additional_sample_kwargs["upscale"]
+ if np.array(_upscale_given_by_optics).size == 1:
+ _upscale_given_by_optics = (_upscale_given_by_optics,) * 3
with u.context(
create_context(
- *additional_sample_kwargs["voxel_size"]#, *_upscale_given_by_optics
+ *additional_sample_kwargs["voxel_size"], *_upscale_given_by_optics
)
):
@@ -338,18 +371,22 @@ def get(
if not isinstance(list_of_scatterers, list):
list_of_scatterers = [list_of_scatterers]
+ # Semantic validation (per scatterer)
+ for scattered in list_of_scatterers:
+ self._validate_input(scattered)
+
# All scatterers that are defined as volumes.
volume_samples = [
scatterer
for scatterer in list_of_scatterers
- if scatterer.role == "volume"
+ if isinstance(scatterer, ScatteredVolume)
]
# All scatterers that are defined as fields.
field_samples = [
scatterer
for scatterer in list_of_scatterers
- if scatterer.role == "field"
+ if isinstance(scatterer, ScatteredField)
]
# Merge all volumes into a single volume.
@@ -358,24 +395,33 @@ def get(
**additional_sample_kwargs,
)
+ # Interpret the merged volume semantically
+ sample_volume = self._extract_contrast_volume(
+ ScatteredVolume(
+ array=sample_volume,
+ properties=volume_samples[0].properties,
+ )
+ )
+
# Let the objective know about the limits of the volume and all the fields.
propagate_data_to_dependencies(
self._objective,
limits=limits,
- fields=field_samples,
+ fields=field_samples,  # TODO: should we also pass upscale?
)
imaged_sample = self._objective.resolve(sample_volume)
- # Handling upscale from dt.Upscale() here to eliminate Image
- # wrapping issues.
- if np.any(np.array(upscale) != 1):
- ux, uy = upscale[:2]
- if contrast_type == "intensity":
- print("Using sum pooling for intensity downscaling.")
- imaged_sample = SumPoolingCM((ux, uy, 1))(imaged_sample)
- else:
- imaged_sample = AveragePoolingCM((ux, uy, 1))(imaged_sample)
+ imaged_sample = self._downscale_image(imaged_sample, upscale)
+ # # Handling upscale from dt.Upscale() here to eliminate Image
+ # # wrapping issues.
+ # if np.any(np.array(upscale) != 1):
+ # ux, uy = upscale[:2]
+ # if contrast_type == "intensity":
+ # print("Using sum pooling for intensity downscaling.")
+ # imaged_sample = SumPoolingCM((ux, uy, 1))(imaged_sample)
+ # else:
+ # imaged_sample = AveragePoolingCM((ux, uy, 1))(imaged_sample)
return imaged_sample
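The imaging flow above (validate, interpret, propagate, downscale) can be summarized in a standalone sketch. The toy classes and names below are illustrative stand-ins, not the DeepTrack2 API; only the hasattr-based dispatch mirrors the code in this hunk.

    import numpy as np

    class _ToyVolume:  # stand-in for ScatteredVolume
        def __init__(self, array, properties):
            self.array = array
            self.properties = properties

    class _ToyOptics:  # stand-in for an Optics subclass
        def validate_input(self, scattered):
            pass  # this sketch accepts every scatterer

        def extract_contrast_volume(self, scattered):
            # Interpret geometry as contrast: scale by an 'intensity' property.
            return scattered.array * scattered.properties.get("intensity", 1.0)

    def image_sample(optics, scattered):
        # Mirrors Microscope.get(): validate, then interpret, then propagate.
        if hasattr(optics, "validate_input"):
            optics.validate_input(scattered)
        if hasattr(optics, "extract_contrast_volume"):
            return optics.extract_contrast_volume(scattered)
        return scattered.array  # geometry-only fallback

    volume = _ToyVolume(np.ones((4, 4, 4)), {"intensity": 2.0})
    print(image_sample(_ToyOptics(), volume).sum())  # 128.0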
@@ -561,6 +607,15 @@ def __init__(
"""
+ def validate_input(self, scattered):
+ # Optional hook: subclasses override to restrict which scatterer
+ # types they accept.
+ pass
+
+ # `extract_contrast_volume` and `downscale_image` are optional hooks
+ # and are deliberately not defined here: no-op stubs would shadow the
+ # hasattr-based fallbacks in Microscope and make them return None.
+
def get_voxel_size(
resolution: float | ArrayLike[float],
magnification: float,
@@ -665,6 +720,7 @@ def _process_properties(
wavelength = propertydict["wavelength"]
voxel_size = get_active_voxel_size()
radius = NA / wavelength * np.array(voxel_size)
+ print('Pupil radius (in pixels):', radius)
if np.any(radius[:2] > 0.5):
required_upscale = np.max(np.ceil(radius[:2] * 2))
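For concreteness, the pupil-radius guard above can be checked with example numbers (the NA, wavelength, and voxel size below are assumptions for illustration, not defaults):

    import numpy as np

    NA, wavelength = 1.4, 500e-9                     # example values, in meters
    voxel_size = np.array([200e-9, 200e-9, 500e-9])  # example (x, y, z) voxel size

    radius = NA / wavelength * voxel_size            # pupil radius in pixels
    print(radius[:2])                                # [0.56 0.56] -> exceeds 0.5

    if np.any(radius[:2] > 0.5):
        required_upscale = np.max(np.ceil(radius[:2] * 2))
        print(required_upscale)                      # 2.0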
@@ -1014,7 +1070,66 @@ class Fluorescence(Optics):
1.4
"""
- contrast_type = "intensity"
+
+
+ def validate_input(self, scattered):
+ """Semantic validation for fluorescence microscopy."""
+
+ # Fluorescence cannot operate on coherent fields
+ if isinstance(scattered, ScatteredField):
+ raise TypeError(
+ "Fluorescence microscope cannot operate on ScatteredField."
+ )
+
+ # Fluorescence must not use refractive index
+ if isinstance(scattered, ScatteredVolume):
+ if scattered.get_property("refractive_index", None) is not None:
+ raise ValueError(
+ "Fluorescence does not use refractive index. "
+ "Found 'refractive_index' in scatterer properties."
+ )
+
+
+ def extract_contrast_volume(self, scattered: ScatteredVolume) -> np.ndarray:
+ """Contrast extraction (semantic interpretation)"""
+ intensity = scattered.get_property("intensity", None)
+
+ if intensity is None:
+ intensity = scattered.get_property("value", None)
+ if intensity is None:
+ raise ValueError(
+ "Fluorescence requires 'intensity' or 'value'."
+ )
+
+ warnings.warn(
+ "Using 'value' as fluorescence intensity is ambiguous. "
+ "Please use 'intensity' explicitly to avoid ambiguity.",
+ UserWarning,
+ )
+
+ voxel_size = np.asarray(get_active_voxel_size(), dtype=float)
+ voxel_volume = float(np.prod(voxel_size))
+
+ return scattered.array * intensity * voxel_volume
+
+
+ def downscale_image(self, image: np.ndarray, upscale):
+ """Detector downscaling (energy conserving)"""
+ if not np.any(np.array(upscale) != 1):
+ return image
+
+ ux, uy = upscale[:2]
+ if ux != uy:
+ raise ValueError(
+ f"Energy-conserving detector integration requires ux == uy, "
+ f"got ux={ux}, uy={uy}."
+ )
+ if isinstance(ux, float) and ux.is_integer():
+ ux = int(ux)
+
+ # Energy-conserving detector integration
+ return SumPoolingV2(ux)(image)
+
def get(
self: Fluorescence,
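The choice of SumPoolingV2 here follows from the physics: fluorescence intensity is extensive, so binning detector pixels must sum, not average. A minimal sketch using skimage.measure.block_reduce (a stand-in for the pooling features used above, with made-up data) shows which statistic each pooling mode conserves:

    import numpy as np
    import skimage.measure

    rng = np.random.default_rng(0)
    photons = rng.poisson(10.0, size=(64, 64)).astype(float)  # detector counts

    summed = skimage.measure.block_reduce(photons, (4, 4), np.sum)
    averaged = skimage.measure.block_reduce(photons, (4, 4), np.mean)

    print(photons.sum(), summed.sum())      # equal: sum pooling conserves energy
    print(photons.mean(), averaged.mean())  # equal: average pooling conserves the mean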
@@ -1240,7 +1355,6 @@ class Brightfield(Optics):
"""
- contrast_type = "refractive_index"
__conversion_table__ = ConversionTable(
working_distance=(u.meter, u.meter),
@@ -1960,12 +2074,12 @@ def _create_volume(
Spatial limits of the volume.
"""
- contrast_type = kwargs.get("contrast_type", None)
- if contrast_type is None:
- raise RuntimeError(
- "_create_volume requires a contrast_type "
- "(e.g. 'intensity' or 'refractive_index')"
- )
+ # contrast_type = kwargs.get("contrast_type", None)
+ # if contrast_type is None:
+ # raise RuntimeError(
+ # "_create_volume requires a contrast_type "
+ # "(e.g. 'intensity' or 'refractive_index')"
+ # )
if not isinstance(list_of_scatterers, list):
list_of_scatterers = [list_of_scatterers]
@@ -1995,23 +2109,23 @@ def _create_volume(
for scatterer in list_of_scatterers:
position = _get_position(scatterer, mode="corner", return_z=True)
- if contrast_type == "intensity":
- value = scatterer.get_property("intensity", None)
- if value is None:
- raise ValueError("Scatterer has no intensity.")
- scatterer_value = value
+ # if contrast_type == "intensity":
+ # value = scatterer.get_property("intensity", None)
+ # if value is None:
+ # raise ValueError("Scatterer has no intensity.")
+ # scatterer_value = value
- elif contrast_type == "refractive_index":
- ri = scatterer.get_property("refractive_index", None)
- if ri is None:
- raise ValueError("Scatterer has no refractive_index.")
- scatterer_value = ri - refractive_index_medium
+ # elif contrast_type == "refractive_index":
+ # ri = scatterer.get_property("refractive_index", None)
+ # if ri is None:
+ # raise ValueError("Scatterer has no refractive_index.")
+ # scatterer_value = ri - refractive_index_medium
- else:
- raise RuntimeError(f"Unknown contrast_type: {contrast_type}")
+ # else:
+ # raise RuntimeError(f"Unknown contrast_type: {contrast_type}")
- # Scale the array accordingly
- scatterer.array = scatterer.array * scatterer_value
+ # # Scale the array accordingly
+ # scatterer.array = scatterer.array * scatterer_value
if limits is None:
limits = np.zeros((3, 2), dtype=np.int32)
@@ -2033,8 +2147,8 @@ def _create_volume(
"constant",
constant_values=0,
)
- padded_scatterer = ScatteredObject(
- array=padded_scatterer_arr, properties=scatterer.properties.copy(), role=scatterer.role,
+ padded_scatterer = ScatteredVolume(
+ array=padded_scatterer_arr, properties=scatterer.properties.copy(),
)
position = _get_position(padded_scatterer, mode="corner", return_z=True)
shape = np.array(padded_scatterer.array.shape)
@@ -2088,7 +2202,7 @@ def _create_volume(
within_volume_position = position - limits[:, 0]
- # NOTE: Maybe shouldn't be additive.
+ # NOTE: Maybe shouldn't be ONLY additive.
# give options: sum default, but also mean, max, min, or
volume[
int(within_volume_position[0]) :
@@ -2100,71 +2214,4 @@ def _create_volume(
int(within_volume_position[2]) :
int(within_volume_position[2] + shape[2]),
] += splined_scatterer
- return volume, limits
-
-# this should be moved to math
-class _CenteredPoolingBase:
- def __init__(self, pool_size: tuple[int, int, int]):
- px, py, pz = pool_size
- if pz != 1:
- raise ValueError("Only pz=1 supported.")
- self.px = int(px)
- self.py = int(py)
-
- def _crop_center(self, array):
- H, W = array.shape[:2]
- px, py = self.px, self.py
-
- crop_h = (H // px) * px
- crop_w = (W // py) * py
-
- off_h = (H - crop_h) // 2
- off_w = (W - crop_w) // 2
-
- return array[off_h:off_h+crop_h, off_w:off_w+crop_w, ...]
-
- def _pool_numpy(self, array, func):
- import skimage.measure
- array = self._crop_center(array)
- pool_shape = (self.px, self.py) + (1,) * (array.ndim - 2)
- return skimage.measure.block_reduce(array, pool_shape, func)
-
- def _pool_torch(self, array, sum_pool=False):
- px, py = self.px, self.py
- array = self._crop_center(array)
-
- extra = array.shape[2:]
- C = int(np.prod(extra)) if extra else 1
- x = array.reshape(1, C, array.shape[0], array.shape[1])
-
- pooled = torch.nn.functional.avg_pool2d(
- x, kernel_size=(px, py), stride=(px, py)
- )
- if sum_pool:
- pooled = pooled * (px * py)
-
- return pooled.reshape(
- (pooled.shape[2], pooled.shape[3]) + extra
- )
-
-class AveragePoolingCM(_CenteredPoolingBase):
- """Center-preserving average pooling (intensive quantities)."""
-
- def __call__(self, array):
- if isinstance(array, np.ndarray):
- return self._pool_numpy(array, np.mean)
- elif TORCH_AVAILABLE and isinstance(array, torch.Tensor):
- return self._pool_torch(array, sum_pool=False)
- else:
- raise TypeError("Unsupported array type.")
-
-class SumPoolingCM(_CenteredPoolingBase):
- """Center-preserving sum pooling (extensive quantities)."""
-
- def __call__(self, array):
- if isinstance(array, np.ndarray):
- return self._pool_numpy(array, np.sum)
- elif TORCH_AVAILABLE and isinstance(array, torch.Tensor):
- return self._pool_torch(array, sum_pool=True)
- else:
- raise TypeError("Unsupported array type.")
\ No newline at end of file
+ return volume, limits
\ No newline at end of file
diff --git a/deeptrack/scatterers.py b/deeptrack/scatterers.py
index b942888c..2ba202e0 100644
--- a/deeptrack/scatterers.py
+++ b/deeptrack/scatterers.py
@@ -175,12 +175,14 @@
get_active_voxel_size,
)
from deeptrack.backend import mie
+from deeptrack.math import AveragePoolingV2
from deeptrack.features import Feature, MERGE_STRATEGY_APPEND
from deeptrack.image import pad_image_to_fft
from deeptrack.types import ArrayLike
from deeptrack import units_registry as u
+
__all__ = [
"Scatterer",
"PointParticle",
@@ -239,7 +241,7 @@ class Scatterer(Feature):
"""
- __list_merge_strategy__ = MERGE_STRATEGY_APPEND
+ __list_merge_strategy__ = MERGE_STRATEGY_APPEND  # TODO: document why this is needed
__distributed__ = False
__conversion_table__ = ConversionTable(
position=(u.pixel, u.pixel),
@@ -279,6 +281,21 @@ def __init__(
**kwargs,
)
+ def _antialias_volume(self, volume, factor: int):
+ """Geometry-only supersampling anti-aliasing.
+
+ Assumes `volume` was generated on a grid oversampled by `factor`
+ and downsamples it back by average pooling.
+ """
+ if factor == 1:
+ return volume
+
+ # average pooling conserves fractional occupancy
+ return AveragePoolingV2(factor)(volume)
+
+
def _process_properties(
self,
properties: dict
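The occupancy-conserving behavior that _antialias_volume relies on can be checked in isolation. The sketch below uses skimage.measure.block_reduce with np.mean as a stand-in for AveragePoolingV2, on a hand-rolled oversampled disk; all values are illustrative:

    import numpy as np
    import skimage.measure

    factor, n = 4, 16                       # supersampling factor, target grid size
    yy, xx = np.mgrid[:n * factor, :n * factor]
    c = (n * factor - 1) / 2.0
    disk = (((yy - c) ** 2 + (xx - c) ** 2) <= (5 * factor) ** 2).astype(float)

    # Average pooling back to the target grid: interior pixels stay 1.0, edge
    # pixels become fractional, and the mean occupancy (area) is conserved.
    pooled = skimage.measure.block_reduce(disk, (factor, factor), np.mean)
    print(disk.mean(), pooled.mean())       # identical occupancy
    print(pooled.min(), pooled.max())       # 0.0 ... 1.0 with fractional edges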
@@ -308,16 +325,31 @@ def _process_and_get(
+ "Optics.upscale != 1."
)
- voxel_size = get_active_voxel_size()
- # Calls parent _process_and_get.
+ voxel_size = np.asarray(get_active_voxel_size(), float)
+
+ apply_supersampling = upsample > 1 and isinstance(self, VolumeScatterer)
+
+ if upsample > 1 and not apply_supersampling:
+ warnings.warn(
+ "Geometry supersampling (upsample) is ignored for "
+ "FieldScatterers.",
+ UserWarning,
+ )
+
+ if apply_supersampling:
+ voxel_size /= float(upsample)
+
new_image = super(Scatterer, self)._process_and_get(
*args,
voxel_size=voxel_size,
upsample=upsample,
**kwargs,
- )
- new_image = new_image[0]
+ )[0]
+
+ if apply_supersampling:
+ new_image = self._antialias_volume(new_image, factor=upsample)
+
if new_image.size == 0:
warnings.warn(
@@ -338,33 +370,31 @@ def _process_and_get(
# props = kwargs.copy()
return [self._wrap_output(new_image, kwargs)]
- def _wrap_output(self, array, props) -> ScatteredObject:
- # """Must be overridden in subclasses to wrap output correctly."""
- # raise NotImplementedError
- return ScatteredObject(
+ def _wrap_output(self, array, props):
+ raise NotImplementedError(
+ f"{self.__class__.__name__} must implement _wrap_output()"
+ )
+
+
+class VolumeScatterer(Scatterer):
+ """Abstract scatterer producing ScatteredVolume outputs."""
+ def _wrap_output(self, array, props) -> ScatteredVolume:
+ return ScatteredVolume(
array=array,
properties=props.copy(),
- role = self.role,
)
-# class VolumeScatterer(Scatterer):
-# """Abstract scatterer producing ScatteredVolume outputs."""
-# def _wrap_output(self, array, props) -> ScatteredVolume:
-# return ScatteredVolume(
-# array=array,
-# properties=props.copy(),
-# )
-# class FieldScatterer(Scatterer):
-# def _wrap_output(self, array, props) -> ScatteredField:
-# return ScatteredField(
-# array=array,
-# properties=props.copy(),
-# )
+class FieldScatterer(Scatterer):
+ def _wrap_output(self, array, props) -> ScatteredField:
+ return ScatteredField(
+ array=array,
+ properties=props.copy(),
+ )
#TODO ***??*** revise PointParticle - torch, typing, docstring, unit test
-class PointParticle(Scatterer):
+class PointParticle(VolumeScatterer):
"""Generate a diffraction-limited point particle.
A point particle is approximated by the size of a single pixel or voxel.
@@ -387,7 +417,6 @@ class PointParticle(Scatterer):
for `Brightfield` and `intensity` for `Fluorescence`).
"""
- role = "volume"
def __init__(
self: PointParticle,
@@ -396,7 +425,7 @@ def __init__(
"""
"""
-
+ kwargs.pop("upsample", None)
super().__init__(upsample=1, upsample_axes=(), **kwargs)
def get(
@@ -412,7 +441,7 @@ def get(
#TODO ***??*** revise Ellipse - torch, typing, docstring, unit test
-class Ellipse(Scatterer):
+class Ellipse(VolumeScatterer):
"""Generates an elliptical disk scatterer
Parameters
@@ -447,7 +476,7 @@ class Ellipse(Scatterer):
before rotation.
"""
- role = "volume"
+
__conversion_table__ = ConversionTable(
radius=(u.meter, u.meter),
@@ -527,7 +556,7 @@ def get(
#TODO ***??*** revise Sphere - torch, typing, docstring, unit test
-class Sphere(Scatterer):
+class Sphere(VolumeScatterer):
"""Generates a spherical scatterer
Parameters
@@ -553,7 +582,6 @@ class Sphere(Scatterer):
Upsamples the calculations of the pixel occupancy fraction.
"""
- role = "volume"
__conversion_table__ = ConversionTable(
radius=(u.meter, u.meter),
@@ -593,7 +621,7 @@ def get(
#TODO ***??*** revise Ellipsoid - torch, typing, docstring, unit test
-class Ellipsoid(Scatterer):
+class Ellipsoid(VolumeScatterer):
"""Generates an ellipsoidal scatterer
Parameters
@@ -629,8 +657,6 @@ class Ellipsoid(Scatterer):
"""
- role = "volume"
-
__conversion_table__ = ConversionTable(
radius=(u.meter, u.meter),
rotation=(u.radian, u.radian),
@@ -752,7 +778,7 @@ def get(
#TODO ***??*** revise MieScatterer - torch, typing, docstring, unit test
-class MieScatterer(Scatterer):
+class MieScatterer(FieldScatterer):
"""Base implementation of a Mie particle.
New Mie-theory scatterers can be implemented by extending this class, and
@@ -837,7 +863,6 @@ class MieScatterer(Scatterer):
"""
- role = "field"
__conversion_table__ = ConversionTable(
radius=(u.meter, u.meter),
@@ -878,7 +903,6 @@ def __init__(
"Please use input_polarization instead"
)
input_polarization = polarization_angle
- # kwargs.pop("is_field", None) # remove
kwargs.pop("crop_empty", None)
super().__init__(
@@ -1106,7 +1130,6 @@ def get(
# Wave vector.
k = 2 * np.pi / wavelength * refractive_index_medium
-
# Position of objective relative particle.
relative_position = np.array(
(
@@ -1316,7 +1339,6 @@ class MieSphere(MieScatterer):
"""
- role = "field"
def __init__(
self,
@@ -1420,7 +1442,6 @@ class MieStratifiedSphere(MieScatterer):
"""
- role = "field"
def __init__(
self,
@@ -1460,12 +1481,11 @@ def inner(
@dataclass
-class ScatteredObject:
+class ScatteredBase:
"""Base class for scatterers (volumes and fields)."""
- array: ArrayLike
+ array: np.ndarray | torch.Tensor
properties: dict[str, Any] = field(default_factory=dict)
- role: Literal["volume", "field"] = "volume"
@property
def ndim(self) -> int:
@@ -1505,3 +1525,15 @@ def as_array(self) -> ArrayLike:
def get_property(self, key: str, default: Any = None) -> Any:
return getattr(self, key, self.properties.get(key, default))
+
+
+@dataclass
+class ScatteredVolume(ScatteredBase):
+ """Voxelized volume produced by a VolumeScatterer."""
+ pass
+
+
+@dataclass
+class ScatteredField(ScatteredBase):
+ """Complex field produced by a FieldScatterer."""
+ pass
\ No newline at end of file
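The move from a `role` string to ScatteredVolume/ScatteredField subclasses makes the volume/field split checkable with isinstance instead of typo-prone string comparison, and new categories can be added by subclassing. A minimal standalone sketch of the same pattern (toy names, not the DeepTrack2 classes):

    from dataclasses import dataclass, field
    from typing import Any

    import numpy as np

    @dataclass
    class _Scattered:  # stand-in for ScatteredBase
        array: np.ndarray
        properties: dict[str, Any] = field(default_factory=dict)

    @dataclass
    class _Volume(_Scattered):  # stand-in for ScatteredVolume
        pass

    @dataclass
    class _Field(_Scattered):  # stand-in for ScatteredField
        pass

    scattered = [_Volume(np.zeros((2, 2, 2))), _Field(np.zeros((8, 8)))]
    volumes = [s for s in scattered if isinstance(s, _Volume)]
    fields = [s for s in scattered if isinstance(s, _Field)]
    print(len(volumes), len(fields))  # 1 1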
From 58a1b8a312929128488efe96637cbdc9b8104784 Mon Sep 17 00:00:00 2001
From: Carlo
Date: Thu, 15 Jan 2026 16:45:00 +0100
Subject: [PATCH 21/24] u
---
deeptrack/features.py | 1320 +++++------------------------------------
deeptrack/optics.py | 1141 ++++++++++++++++++++++++++++++++++-
2 files changed, 1277 insertions(+), 1184 deletions(-)
diff --git a/deeptrack/features.py b/deeptrack/features.py
index e3092698..901664a9 100644
--- a/deeptrack/features.py
+++ b/deeptrack/features.py
@@ -218,11 +218,8 @@
"OneOf",
"OneOfDict",
"LoadImage",
- "SampleToMasks",
"AsType",
"ChannelFirst2d",
- "Upscale",
- "NonOverlapping",
"Store",
"Squeeze",
"Unsqueeze",
@@ -7359,302 +7356,6 @@ def get(
return image
-class SampleToMasks(Feature):
- """Create a mask from a list of images.
-
- This feature applies a transformation function to each input image and
- merges the resulting masks into a single multi-layer image. Each input
- image must have a `position` property that determines its placement within
- the final mask. When used with scatterers, the `voxel_size` property must
- be provided for correct object sizing.
-
- Parameters
- ----------
- transformation_function: Callable[[Image], Image]
- A function that transforms each input image into a mask with
- `number_of_masks` layers.
- number_of_masks: PropertyLike[int], optional
- The number of mask layers to generate. Default is 1.
- output_region: PropertyLike[tuple[int, int, int, int]], optional
- The size and position of the output mask, typically aligned with
- `optics.output_region`.
- merge_method: PropertyLike[str | Callable | list[str | Callable]], optional
- Method for merging individual masks into the final image. Can be:
- - "add" (default): Sum the masks.
- - "overwrite": Later masks overwrite earlier masks.
- - "or": Combine masks using a logical OR operation.
- - "mul": Multiply masks.
- - Function: Custom function taking two images and merging them.
-
- **kwargs: dict[str, Any]
- Additional keyword arguments passed to the parent `Feature` class.
-
- Methods
- -------
- `get(image, transformation_function, **kwargs) -> Image`
- Applies the transformation function to the input image.
- `_process_and_get(images, **kwargs) -> Image | np.ndarray`
- Processes a list of images and generates a multi-layer mask.
-
- Returns
- -------
- np.ndarray
- The final mask image with the specified number of layers.
-
- Raises
- ------
- ValueError
- If `merge_method` is invalid.
-
- Examples
- -------
- >>> import deeptrack as dt
-
- Define number of particles:
-
- >>> n_particles = 12
-
- Define optics and particles:
-
- >>> import numpy as np
- >>>
- >>> optics = dt.Fluorescence(output_region=(0, 0, 64, 64))
- >>> particle = dt.PointParticle(
- >>> position=lambda: np.random.uniform(5, 55, size=2),
- >>> )
- >>> particles = particle ^ n_particles
-
- Define pipelines:
-
- >>> sim_im_pip = optics(particles)
- >>> sim_mask_pip = particles >> dt.SampleToMasks(
- ... lambda: lambda particles: particles > 0,
- ... output_region=optics.output_region,
- ... merge_method="or",
- ... )
- >>> pipeline = sim_im_pip & sim_mask_pip
- >>> pipeline.store_properties()
-
- Generate image and mask:
-
- >>> image, mask = pipeline.update()()
-
- Get particle positions:
-
- >>> positions = np.array(image.get_property("position", get_one=False))
-
- Visualize results:
-
- >>> import matplotlib.pyplot as plt
- >>>
- >>> plt.subplot(1, 2, 1)
- >>> plt.imshow(image, cmap="gray")
- >>> plt.title("Original Image")
- >>> plt.subplot(1, 2, 2)
- >>> plt.imshow(mask, cmap="gray")
- >>> plt.scatter(positions[:,1], positions[:,0], c="y", marker="x", s = 50)
- >>> plt.title("Mask")
- >>> plt.show()
-
- """
-
- def __init__(
- self: Feature,
- transformation_function: Callable[[np.ndarray], np.ndarray, torch.Tensor],
- number_of_masks: PropertyLike[int] = 1,
- output_region: PropertyLike[tuple[int, int, int, int]] = None,
- merge_method: PropertyLike[str | Callable | list[str | Callable]] = "add",
- **kwargs: Any,
- ):
- """Initialize the SampleToMasks feature.
-
- Parameters
- ----------
- transformation_function: Callable[[Image], Image]
- Function to transform input images into masks.
- number_of_masks: PropertyLike[int], optional
- Number of mask layers. Default is 1.
- output_region: PropertyLike[tuple[int, int, int, int]], optional
- Output region of the mask. Default is None.
- merge_method: PropertyLike[str | Callable | list[str | Callable]], optional
- Method to merge masks. Defaults to "add".
- **kwargs: dict[str, Any]
- Additional keyword arguments passed to the parent class.
-
- """
-
- super().__init__(
- transformation_function=transformation_function,
- number_of_masks=number_of_masks,
- output_region=output_region,
- merge_method=merge_method,
- **kwargs,
- )
-
- def get(
- self: Feature,
- image: np.ndarray,
- transformation_function: Callable[list[np.ndarray] | np.ndarray | torch.Tensor],
- **kwargs: Any,
- ) -> np.ndarray:
- """Apply the transformation function to a single image.
-
- Parameters
- ----------
- image: np.ndarray
- The input image.
- transformation_function: Callable[[np.ndarray], np.ndarray]
- Function to transform the image.
- **kwargs: dict[str, Any]
- Additional parameters.
-
- Returns
- -------
- Image
- The transformed image.
-
- """
-
- return transformation_function(image.array)
-
- def _process_and_get(
- self: Feature,
- images: list[np.ndarray] | np.ndarray | list[torch.Tensor] | torch.Tensor,
- **kwargs: Any,
- ) -> np.ndarray:
- """Process a list of images and generate a multi-layer mask.
-
- Parameters
- ----------
- images: np.ndarray or list[np.ndarrray] or Image or list[Image]
- List of input images or a single image.
- **kwargs: dict[str, Any]
- Additional parameters including `output_region`, `number_of_masks`,
- and `merge_method`.
-
- Returns
- -------
- Image or np.ndarray
- The final mask image.
-
- """
-
- # Handle list of images.
- # if isinstance(images, list) and len(images) != 1:
- list_of_labels = super()._process_and_get(images, **kwargs)
-
- from deeptrack.scatterers import ScatteredVolume
-
- for idx, (label, image) in enumerate(zip(list_of_labels, images)):
- list_of_labels[idx] = \
- ScatteredVolume(array=label, properties=image.properties.copy())
-
- # Create an empty output image.
- output_region = kwargs["output_region"]
- output = xp.zeros(
- (
- output_region[2] - output_region[0],
- output_region[3] - output_region[1],
- kwargs["number_of_masks"],
- ),
- dtype=list_of_labels[0].array.dtype,
- )
-
- from deeptrack.optics import _get_position
-
- # Merge masks into the output.
- for volume in list_of_labels:
- label = volume.array
- position = _get_position(volume)
-
- p0 = xp.round(position - xp.asarray(output_region[0:2]))
- p0 = p0.astype(xp.int64)
-
-
- if xp.any(p0 > xp.asarray(output.shape[:2])) or \
- xp.any(p0 + xp.asarray(label.shape[:2]) < 0):
- continue
-
- crop_x = (-xp.minimum(p0[0], 0)).item()
- crop_y = (-xp.minimum(p0[1], 0)).item()
-
- crop_x_end = int(
- label.shape[0]
- - np.max([p0[0] + label.shape[0] - output.shape[0], 0])
- )
- crop_y_end = int(
- label.shape[1]
- - np.max([p0[1] + label.shape[1] - output.shape[1], 0])
- )
-
- labelarg = label[crop_x:crop_x_end, crop_y:crop_y_end, :]
-
- p0[0] = np.max([p0[0], 0])
- p0[1] = np.max([p0[1], 0])
-
- p0 = p0.astype(int)
-
- output_slice = output[
- p0[0] : p0[0] + labelarg.shape[0],
- p0[1] : p0[1] + labelarg.shape[1],
- ]
-
- for label_index in range(kwargs["number_of_masks"]):
-
- if isinstance(kwargs["merge_method"], list):
- merge = kwargs["merge_method"][label_index]
- else:
- merge = kwargs["merge_method"]
-
- if merge == "add":
- output[
- p0[0] : p0[0] + labelarg.shape[0],
- p0[1] : p0[1] + labelarg.shape[1],
- label_index,
- ] += labelarg[..., label_index]
-
- elif merge == "overwrite":
- output_slice[
- labelarg[..., label_index] != 0, label_index
- ] = labelarg[labelarg[..., label_index] != 0, \
- label_index]
- output[
- p0[0] : p0[0] + labelarg.shape[0],
- p0[1] : p0[1] + labelarg.shape[1],
- label_index,
- ] = output_slice[..., label_index]
-
- elif merge == "or":
- output[
- p0[0] : p0[0] + labelarg.shape[0],
- p0[1] : p0[1] + labelarg.shape[1],
- label_index,
- ] = xp.logical_or(
- output_slice[..., label_index] != 0,
- labelarg[..., label_index] != 0
- )
-
- elif merge == "mul":
- output[
- p0[0] : p0[0] + labelarg.shape[0],
- p0[1] : p0[1] + labelarg.shape[1],
- label_index,
- ] *= labelarg[..., label_index]
-
- else:
- # No match, assume function
- output[
- p0[0] : p0[0] + labelarg.shape[0],
- p0[1] : p0[1] + labelarg.shape[1],
- label_index,
- ] = merge(
- output_slice[..., label_index],
- labelarg[..., label_index],
- )
-
- return output
-
-
class AsType(Feature):
"""Convert the data type of arrays.
@@ -7920,876 +7621,181 @@ def get(
return array
-class Upscale(Feature):
- """Simulate a pipeline at a higher resolution.
-
- This feature scales up the resolution of the input pipeline by a specified
- factor, performs computations at the higher resolution, and then
- downsamples the result back to the original size. This is useful for
- simulating effects at a finer resolution while preserving compatibility
- with lower-resolution pipelines.
-
- Internally, this feature redefines the scale of physical units (e.g.,
- `units.pixel`) to achieve the effect of upscaling. Therefore, it does not
- resize the input image itself but affects only features that rely on
- physical units.
-
- Parameters
- ----------
- feature: Feature
- The pipeline or feature to resolve at a higher resolution.
- factor: int or tuple[int, int, int], optional
- The factor by which to upscale the simulation. If a single integer is
- provided, it is applied uniformly across all axes. If a tuple of three
- integers is provided, each axis is scaled individually. Defaults to 1.
- **kwargs: Any
- Additional keyword arguments passed to the parent `Feature` class.
-
- Attributes
- ----------
- __distributed__: bool
- Always `False` for `Upscale`, indicating that this feature’s `.get()`
- method processes the entire input at once even if it is a list, rather
- than distributing calls for each item of the list.
-
- Methods
- -------
- `get(image, factor, **kwargs) -> np.ndarray | torch.tensor`
- Simulates the pipeline at a higher resolution and returns the result at
- the original resolution.
-
- Notes
- -----
- - This feature does not directly resize the image. Instead, it modifies the
- unit conversions within the pipeline, making physical units smaller,
- which results in more detail being simulated.
- - The final output is downscaled back to the original resolution using
- `block_reduce` from `skimage.measure`.
- - The effect is only noticeable if features use physical units (e.g.,
- `units.pixel`, `units.meter`). Otherwise, the result will be identical.
-
- Examples
- --------
- >>> import deeptrack as dt
-
- Define an optical pipeline and a spherical particle:
-
- >>> optics = dt.Fluorescence()
- >>> particle = dt.Sphere()
- >>> simple_pipeline = optics(particle)
-
- Create an upscaled pipeline with a factor of 4:
-
- >>> upscaled_pipeline = dt.Upscale(optics(particle), factor=4)
-
- Resolve the pipelines:
-
- >>> image = simple_pipeline()
- >>> upscaled_image = upscaled_pipeline()
-
- Visualize the images:
-
- >>> import matplotlib.pyplot as plt
- >>>
- >>> plt.subplot(1, 2, 1)
- >>> plt.imshow(image, cmap="gray")
- >>> plt.title("Original Image")
- >>>
- >>> plt.subplot(1, 2, 2)
- >>> plt.imshow(upscaled_image, cmap="gray")
- >>> plt.title("Simulated at Higher Resolution")
- >>>
- >>> plt.show()
-
- Compare the shapes (both are the same due to downscaling):
-
- >>> print(image.shape)
- (128, 128, 1)
- >>> print(upscaled_image.shape)
- (128, 128, 1)
-
- """
-
- __distributed__: bool = False
-
- feature: Feature
-
- def __init__(
- self: Feature,
- feature: Feature,
- factor: int | tuple[int, int, int] = 1,
- **kwargs: Any,
- ) -> None:
- """Initialize the Upscale feature.
-
- Parameters
- ----------
- feature: Feature
- The pipeline or feature to resolve at a higher resolution.
- factor: int or tuple[int, int, int], optional
- The factor by which to upscale the simulation. If a single integer
- is provided, it is applied uniformly across all axes. If a tuple of
- three integers is provided, each axis is scaled individually.
- Defaults to 1.
- **kwargs: Any
- Additional keyword arguments passed to the parent `Feature` class.
-
- """
-
- super().__init__(factor=factor, **kwargs)
- self.feature = self.add_feature(feature)
-
- def get(
- self: Feature,
- image: np.ndarray | torch.Tensor,
- factor: int | tuple[int, int, int],
- **kwargs: Any,
- ) -> np.ndarray | torch.Tensor:
- """Simulate the pipeline at a higher resolution and return result.
-
- Parameters
- ----------
- image: np.ndarray or torch.Tensor
- The input image to process.
- factor: int or tuple[int, int, int]
- The factor by which to upscale the simulation. If a single integer
- is provided, it is applied uniformly across all axes. If a tuple of
- three integers is provided, each axis is scaled individually.
- **kwargs: Any
- Additional keyword arguments passed to the feature.
-
- Returns
- -------
- np.ndarray or torch.Tensor
- The processed image at the original resolution.
+# class Upscale(Feature):
+# """Simulate a pipeline at a higher resolution.
- Raises
- ------
- ValueError
- If the input `factor` is not a valid integer or tuple of integers.
-
- """
-
- # Ensure factor is a tuple of three integers.
- if np.size(factor) == 1:
- factor = (factor, factor, 1)
- elif len(factor) != 3:
- raise ValueError(
- "Factor must be an integer or a tuple of three integers."
- )
-
- # Create a context for upscaling and perform computation.
- ctx = create_context(None, None, None, *factor)
-
- print('before:', image)
- with units.context(ctx):
- image = self.feature(image)
-
- print('after:', image)
- # Downscale the result to the original resolution.
- import skimage.measure
-
- image = skimage.measure.block_reduce(
- image, (factor[0], factor[1]) + (1,) * (image.ndim - 2), np.mean
- )
-
- return image
-
-
-class NonOverlapping(Feature):
- """Ensure volumes are placed non-overlapping in a 3D space.
-
- This feature ensures that a list of 3D volumes are positioned such that
- their non-zero voxels do not overlap. If volumes overlap, their positions
- are resampled until they are non-overlapping. If the maximum number of
- attempts is exceeded, the feature regenerates the list of volumes and
- raises a warning if non-overlapping placement cannot be achieved.
-
- Note: `min_distance` refers to the distance between the edges of volumes,
- not their centers. Due to the way volumes are calculated, slight rounding
- errors may affect the final distance.
+# This feature scales up the resolution of the input pipeline by a specified
+# factor, performs computations at the higher resolution, and then
+# downsamples the result back to the original size. This is useful for
+# simulating effects at a finer resolution while preserving compatibility
+# with lower-resolution pipelines.
- This feature is incompatible with non-volumetric scatterers such as
- `MieScatterers`.
+# Internally, this feature redefines the scale of physical units (e.g.,
+# `units.pixel`) to achieve the effect of upscaling. Therefore, it does not
+# resize the input image itself but affects only features that rely on
+# physical units.
+
+# Parameters
+# ----------
+# feature: Feature
+# The pipeline or feature to resolve at a higher resolution.
+# factor: int or tuple[int, int, int], optional
+# The factor by which to upscale the simulation. If a single integer is
+# provided, it is applied uniformly across all axes. If a tuple of three
+# integers is provided, each axis is scaled individually. Defaults to 1.
+# **kwargs: Any
+# Additional keyword arguments passed to the parent `Feature` class.
+
+# Attributes
+# ----------
+# __distributed__: bool
+# Always `False` for `Upscale`, indicating that this feature’s `.get()`
+# method processes the entire input at once even if it is a list, rather
+# than distributing calls for each item of the list.
+
+# Methods
+# -------
+# `get(image, factor, **kwargs) -> np.ndarray | torch.tensor`
+# Simulates the pipeline at a higher resolution and returns the result at
+# the original resolution.
+
+# Notes
+# -----
+# - This feature does not directly resize the image. Instead, it modifies the
+# unit conversions within the pipeline, making physical units smaller,
+# which results in more detail being simulated.
+# - The final output is downscaled back to the original resolution using
+# `block_reduce` from `skimage.measure`.
+# - The effect is only noticeable if features use physical units (e.g.,
+# `units.pixel`, `units.meter`). Otherwise, the result will be identical.
+
+# Examples
+# --------
+# >>> import deeptrack as dt
+
+# Define an optical pipeline and a spherical particle:
+
+# >>> optics = dt.Fluorescence()
+# >>> particle = dt.Sphere()
+# >>> simple_pipeline = optics(particle)
+
+# Create an upscaled pipeline with a factor of 4:
+
+# >>> upscaled_pipeline = dt.Upscale(optics(particle), factor=4)
- Parameters
- ----------
- feature: Feature
- The feature that generates the list of volumes to place
- non-overlapping.
- min_distance: float, optional
- The minimum distance between volumes in pixels. It can be negative to
- allow for partial overlap. Defaults to 1.
- max_attempts: int, optional
- The maximum number of attempts to place volumes without overlap.
- Defaults to 5.
- max_iters: int, optional
- The maximum number of resamplings. If this number is exceeded, a new
- list of volumes is generated. Defaults to 100.
-
- Attributes
- ----------
- __distributed__: bool
- Always `False` for `NonOverlapping`, indicating that this feature’s
- `.get()` method processes the entire input at once even if it is a
- list, rather than distributing calls for each item of the list.N
-
- Methods
- -------
- `get(*_, min_distance, max_attempts, **kwargs) -> array`
- Generate a list of non-overlapping 3D volumes.
- `_check_non_overlapping(list_of_volumes) -> bool`
- Check if all volumes in the list are non-overlapping.
- `_check_bounding_cubes_non_overlapping(...) -> bool`
- Check if two bounding cubes are non-overlapping.
- `_get_overlapping_cube(...) -> list[int]`
- Get the overlapping cube between two bounding cubes.
- `_get_overlapping_volume(...) -> array`
- Get the overlapping volume between a volume and a bounding cube.
- `_check_volumes_non_overlapping(...) -> bool`
- Check if two volumes are non-overlapping.
- `_resample_volume_position(volume) -> Image`
- Resample the position of a volume to avoid overlap.
+# Resolve the pipelines:
+
+# >>> image = simple_pipeline()
+# >>> upscaled_image = upscaled_pipeline()
+
+# Visualize the images:
+
+# >>> import matplotlib.pyplot as plt
+# >>>
+# >>> plt.subplot(1, 2, 1)
+# >>> plt.imshow(image, cmap="gray")
+# >>> plt.title("Original Image")
+# >>>
+# >>> plt.subplot(1, 2, 2)
+# >>> plt.imshow(upscaled_image, cmap="gray")
+# >>> plt.title("Simulated at Higher Resolution")
+# >>>
+# >>> plt.show()
- Notes
- -----
- - This feature performs bounding cube checks first to quickly reject
- obvious overlaps before voxel-level checks.
- - If the bounding cubes overlap, precise voxel-based checks are performed.
-
- Examples
- ---------
- >>> import deeptrack as dt
-
- Define an ellipse scatterer with randomly positioned objects:
-
- >>> import numpy as np
- >>>
- >>> scatterer = dt.Ellipse(
- >>> radius= 13 * dt.units.pixels,
- >>> position=lambda: np.random.uniform(5, 115, size=2)* dt.units.pixels,
- >>> )
+# Compare the shapes (both are the same due to downscaling):
- Create multiple scatterers:
-
- >>> scatterers = (scatterer ^ 8)
-
- Define the optics and create the image with possible overlap:
-
- >>> optics = dt.Fluorescence()
- >>> im_with_overlap = optics(scatterers)
- >>> im_with_overlap.store_properties()
- >>> im_with_overlap_resolved = image_with_overlap()
-
- Gather position from image:
-
- >>> pos_with_overlap = np.array(
- >>> im_with_overlap_resolved.get_property(
- >>> "position",
- >>> get_one=False
- >>> )
- >>> )
-
- Enforce non-overlapping and create the image without overlap:
+# >>> print(image.shape)
+# (128, 128, 1)
+# >>> print(upscaled_image.shape)
+# (128, 128, 1)
- >>> non_overlapping_scatterers = dt.NonOverlapping(
- ... scatterers,
- ... min_distance=4,
- ... )
- >>> im_without_overlap = optics(non_overlapping_scatterers)
- >>> im_without_overlap.store_properties()
- >>> im_without_overlap_resolved = im_without_overlap()
-
- Gather position from image:
-
- >>> pos_without_overlap = np.array(
- >>> im_without_overlap_resolved.get_property(
- >>> "position",
- >>> get_one=False
- >>> )
- >>> )
-
- Create a figure with two subplots to visualize the difference:
-
- >>> import matplotlib.pyplot as plt
- >>>
- >>> fig, axes = plt.subplots(1, 2, figsize=(10, 5))
- >>>
- >>> axes[0].imshow(im_with_overlap_resolved, cmap="gray")
- >>> axes[0].scatter(pos_with_overlap[:,1],pos_with_overlap[:,0])
- >>> axes[0].set_title("Overlapping Objects")
- >>> axes[0].axis("off")
- >>>
- >>> axes[1].imshow(im_without_overlap_resolved, cmap="gray")
- >>> axes[1].scatter(pos_without_overlap[:,1],pos_without_overlap[:,0])
- >>> axes[1].set_title("Non-Overlapping Objects")
- >>> axes[1].axis("off")
- >>> plt.tight_layout()
- >>>
- >>> plt.show()
-
- Define function to calculate minimum distance:
-
- >>> def calculate_min_distance(positions):
- >>> distances = [
- >>> np.linalg.norm(positions[i] - positions[j])
- >>> for i in range(len(positions))
- >>> for j in range(i + 1, len(positions))
- >>> ]
- >>> return min(distances)
-
- Print minimum distances with and without overlap:
-
- >>> print(calculate_min_distance(pos_with_overlap))
- 10.768742383382174
-
- >>> print(calculate_min_distance(pos_without_overlap))
- 30.82531120942446
-
- """
-
- __distributed__: bool = False
-
- def __init__(
- self: NonOverlapping,
- feature: Feature,
- min_distance: float = 1,
- max_attempts: int = 5,
- max_iters: int = 100,
- **kwargs: Any,
- ):
- """Initializes the NonOverlapping feature.
-
- Ensures that volumes are placed **non-overlapping** by iteratively
- resampling their positions. If the maximum number of attempts is
- exceeded, the feature regenerates the list of volumes.
-
- Parameters
- ----------
- feature: Feature
- The feature that generates the list of volumes.
- min_distance: float, optional
- The minimum separation distance **between volume edges**, in
- pixels. It defaults to `1`. Negative values allow for partial
- overlap.
- max_attempts: int, optional
- The maximum number of attempts to place the volumes without
- overlap. It defaults to `5`.
- max_iters: int, optional
- The maximum number of resampling iterations per attempt. If
- exceeded, a new list of volumes is generated. It defaults to `100`.
-
- """
-
- super().__init__(
- min_distance=min_distance,
- max_attempts=max_attempts,
- max_iters=max_iters,
- **kwargs,
- )
- self.feature = self.add_feature(feature, **kwargs)
-
- def get(
- self: NonOverlapping,
- *_: Any,
- min_distance: float,
- max_attempts: int,
- max_iters: int,
- **kwargs: Any,
- ) -> list[np.ndarray]:
- """Generates a list of non-overlapping 3D volumes within a defined
- field of view (FOV).
-
- This method **iteratively** attempts to place volumes while ensuring
- they maintain at least `min_distance` separation. If non-overlapping
- placement is not achieved within `max_attempts`, a warning is issued,
- and the best available configuration is returned.
-
- Parameters
- ----------
- _: Any
- Placeholder parameter, typically for an input image.
- min_distance: float
- The minimum required separation distance between volumes, in
- pixels.
- max_attempts: int
- The maximum number of attempts to generate a valid non-overlapping
- configuration.
- max_iters: int
- The maximum number of resampling iterations per attempt.
- **kwargs: Any
- Additional parameters that may be used by subclasses.
-
- Returns
- -------
- list[np.ndarray]
- A list of 3D volumes represented as NumPy arrays. If
- non-overlapping placement is unsuccessful, the best available
- configuration is returned.
-
- Warns
- -----
- UserWarning
- If non-overlapping placement is **not** achieved within
- `max_attempts`, suggesting parameter adjustments such as increasing
- the FOV or reducing `min_distance`.
-
- Notes
- -----
- - The placement process prioritizes bounding cube checks for
- efficiency.
- - If bounding cubes overlap, voxel-based overlap checks are performed.
+# """
+
+# __distributed__: bool = False
+
+# feature: Feature
+
+# def __init__(
+# self: Feature,
+# feature: Feature,
+# factor: int | tuple[int, int, int] = 1,
+# **kwargs: Any,
+# ) -> None:
+# """Initialize the Upscale feature.
+
+# Parameters
+# ----------
+# feature: Feature
+# The pipeline or feature to resolve at a higher resolution.
+# factor: int or tuple[int, int, int], optional
+# The factor by which to upscale the simulation. If a single integer
+# is provided, it is applied uniformly across all axes. If a tuple of
+# three integers is provided, each axis is scaled individually.
+# Defaults to 1.
+# **kwargs: Any
+# Additional keyword arguments passed to the parent `Feature` class.
+
+# """
+
+# super().__init__(factor=factor, **kwargs)
+# self.feature = self.add_feature(feature)
+
+# def get(
+# self: Feature,
+# image: np.ndarray | torch.Tensor,
+# factor: int | tuple[int, int, int],
+# **kwargs: Any,
+# ) -> np.ndarray | torch.Tensor:
+# """Simulate the pipeline at a higher resolution and return result.
+
+# Parameters
+# ----------
+# image: np.ndarray or torch.Tensor
+# The input image to process.
+# factor: int or tuple[int, int, int]
+# The factor by which to upscale the simulation. If a single integer
+# is provided, it is applied uniformly across all axes. If a tuple of
+# three integers is provided, each axis is scaled individually.
+# **kwargs: Any
+# Additional keyword arguments passed to the feature.
+
+# Returns
+# -------
+# np.ndarray or torch.Tensor
+# The processed image at the original resolution.
+
+# Raises
+# ------
+# ValueError
+# If the input `factor` is not a valid integer or tuple of integers.
+
+# """
+
+# # Ensure factor is a tuple of three integers.
+# if np.size(factor) == 1:
+# factor = (factor, factor, 1)
+# elif len(factor) != 3:
+# raise ValueError(
+# "Factor must be an integer or a tuple of three integers."
+# )
- """
-
- for _ in range(max_attempts):
- list_of_volumes = self.feature()
-
- if not isinstance(list_of_volumes, list):
- list_of_volumes = [list_of_volumes]
+# # Create a context for upscaling and perform computation.
+# ctx = create_context(None, None, None, *factor)
- for _ in range(max_iters):
-
- list_of_volumes = [
- self._resample_volume_position(volume)
- for volume in list_of_volumes
- ]
+# print('before:', image)
+# with units.context(ctx):
+# image = self.feature(image)
- if self._check_non_overlapping(list_of_volumes):
- return list_of_volumes
+# print('after:', image)
+# # Downscale the result to the original resolution.
+# import skimage.measure
- # Generate a new list of volumes if max_attempts is exceeded.
- self.feature.update()
-
- warnings.warn(
- "Non-overlapping placement could not be achieved. Consider "
- "adjusting parameters: reduce object radius, increase FOV, "
- "or decrease min_distance.",
- UserWarning,
- )
- return list_of_volumes
-
- def _check_non_overlapping(
- self: NonOverlapping,
- list_of_volumes: list[np.ndarray],
- ) -> bool:
- """Determines whether all volumes in the provided list are
- non-overlapping.
-
- This method verifies that the non-zero voxels of each 3D volume in
- `list_of_volumes` are at least `min_distance` apart. It first checks
- bounding boxes for early rejection and then examines actual voxel
- overlap when necessary. Volumes are assumed to have a `position`
- attribute indicating their placement in 3D space.
-
- Parameters
- ----------
- list_of_volumes: list[np.ndarray]
- A list of 3D arrays representing the volumes to be checked for
- overlap. Each volume is expected to have a position attribute.
-
- Returns
- -------
- bool
- `True` if all volumes are non-overlapping, otherwise `False`.
-
- Notes
- -----
- - If `min_distance` is negative, volumes are shrunk using isotropic
- erosion before checking overlap.
- - If `min_distance` is positive, volumes are padded and expanded using
- isotropic dilation.
- - Overlapping checks are first performed on bounding cubes for
- efficiency.
- - If bounding cubes overlap, voxel-level checks are performed.
-
- """
- from deeptrack.scatterers import ScatteredVolume
-
- from deeptrack.augmentations import CropTight, Pad # these are not compatibles with torch backend
- from deeptrack.optics import _get_position
- from deeptrack.math import isotropic_erosion, isotropic_dilation
-
- min_distance = self.min_distance()
- crop = CropTight()
-
- new_volumes = []
-
- for volume in list_of_volumes:
- arr = volume.array
- mask = arr != 0
-
- if min_distance < 0:
- new_arr = isotropic_erosion(mask, -min_distance / 2, backend=self.get_backend())
- else:
- pad = Pad(px=[int(np.ceil(min_distance / 2))] * 6, keep_size=True)
- new_arr = isotropic_dilation(pad(mask) != 0 , min_distance / 2, backend=self.get_backend())
- new_arr = crop(new_arr)
-
- if self.get_backend() == "torch":
- new_arr = new_arr.to(dtype=arr.dtype)
- else:
- new_arr = new_arr.astype(arr.dtype)
-
- new_volume = ScatteredVolume(
- array=new_arr,
- properties=volume.properties.copy(),
- )
-
- new_volumes.append(new_volume)
-
- list_of_volumes = new_volumes
- min_distance = 1
-
- # The position of the top left corner of each volume (index (0, 0, 0)).
- volume_positions_1 = [
- _get_position(volume, mode="corner", return_z=True).astype(int)
- for volume in list_of_volumes
- ]
-
- # The position of the bottom right corner of each volume
- # (index (-1, -1, -1)).
- volume_positions_2 = [
- p0 + np.array(v.shape)
- for v, p0 in zip(list_of_volumes, volume_positions_1)
- ]
-
- # (x1, y1, z1, x2, y2, z2) for each volume.
- volume_bounding_cube = [
- [*p0, *p1]
- for p0, p1 in zip(volume_positions_1, volume_positions_2)
- ]
-
- for i, j in itertools.combinations(range(len(list_of_volumes)), 2):
-
- # If the bounding cubes do not overlap, the volumes do not overlap.
- if self._check_bounding_cubes_non_overlapping(
- volume_bounding_cube[i], volume_bounding_cube[j], min_distance
- ):
- continue
-
- # If the bounding cubes overlap, get the overlapping region of each
- # volume.
- overlapping_cube = self._get_overlapping_cube(
- volume_bounding_cube[i], volume_bounding_cube[j]
- )
- overlapping_volume_1 = self._get_overlapping_volume(
- list_of_volumes[i].array, volume_bounding_cube[i], overlapping_cube
- )
- overlapping_volume_2 = self._get_overlapping_volume(
- list_of_volumes[j].array, volume_bounding_cube[j], overlapping_cube
- )
-
- # If either the overlapping regions are empty, the volumes do not
- # overlap (done for speed).
- if (np.all(overlapping_volume_1 == 0)
- or np.all(overlapping_volume_2 == 0)):
- continue
-
- # If products of overlapping regions are non-zero, return False.
- # if np.any(overlapping_volume_1 * overlapping_volume_2):
- # return False
-
- # Finally, check that the non-zero voxels of the volumes are at
- # least min_distance apart.
- if not self._check_volumes_non_overlapping(
- overlapping_volume_1, overlapping_volume_2, min_distance
- ):
- return False
-
- return True
-
- def _check_bounding_cubes_non_overlapping(
- self: NonOverlapping,
- bounding_cube_1: list[int],
- bounding_cube_2: list[int],
- min_distance: float,
- ) -> bool:
- """Determines whether two 3D bounding cubes are non-overlapping.
-
- This method checks whether the bounding cubes of two volumes are
- **separated by at least** `min_distance` along **any** spatial axis.
-
- Parameters
- ----------
- bounding_cube_1: list[int]
- A list of six integers `[x1, y1, z1, x2, y2, z2]` representing
- the first bounding cube.
- bounding_cube_2: list[int]
- A list of six integers `[x1, y1, z1, x2, y2, z2]` representing
- the second bounding cube.
- min_distance: float
- The required **minimum separation distance** between the two
- bounding cubes.
-
- Returns
- -------
- bool
- `True` if the bounding cubes are non-overlapping (separated by at
- least `min_distance` along **at least one axis**), otherwise
- `False`.
-
- Notes
- -----
- - This function **only checks bounding cubes**, **not actual voxel
- data**.
- - If the bounding cubes are non-overlapping, the corresponding
- **volumes are also non-overlapping**.
- - This check is much **faster** than full voxel-based comparisons.
-
- """
-
- # bounding_cube_1 and bounding_cube_2 are (x1, y1, z1, x2, y2, z2).
- # Check that the bounding cubes are non-overlapping.
- return (
- (bounding_cube_1[0] >= bounding_cube_2[3] + min_distance) or
- (bounding_cube_2[0] >= bounding_cube_1[3] + min_distance) or
- (bounding_cube_1[1] >= bounding_cube_2[4] + min_distance) or
- (bounding_cube_2[1] >= bounding_cube_1[4] + min_distance) or
- (bounding_cube_1[2] >= bounding_cube_2[5] + min_distance) or
- (bounding_cube_2[2] >= bounding_cube_1[5] + min_distance)
- )
-
- def _get_overlapping_cube(
- self: NonOverlapping,
- bounding_cube_1: list[int],
- bounding_cube_2: list[int],
- ) -> list[int]:
- """Computes the overlapping region between two 3D bounding cubes.
-
- This method calculates the coordinates of the intersection of two
- axis-aligned bounding cubes, each represented as a list of six
- integers:
-
- - `[x1, y1, z1]`: Coordinates of the **top-left-front** corner.
- - `[x2, y2, z2]`: Coordinates of the **bottom-right-back** corner.
-
- The resulting overlapping region is determined by:
- - Taking the **maximum** of the starting coordinates (`x1, y1, z1`).
- - Taking the **minimum** of the ending coordinates (`x2, y2, z2`).
-
- If the cubes **do not** overlap, the resulting coordinates will not
- form a valid cube (i.e., `x1 > x2`, `y1 > y2`, or `z1 > z2`).
-
- Parameters
- ----------
- bounding_cube_1: list[int]
- The first bounding cube, formatted as `[x1, y1, z1, x2, y2, z2]`.
- bounding_cube_2: list[int]
- The second bounding cube, formatted as `[x1, y1, z1, x2, y2, z2]`.
-
- Returns
- -------
- list[int]
- A list of six integers `[x1, y1, z1, x2, y2, z2]` representing the
- overlapping bounding cube. If no overlap exists, the coordinates
- will **not** define a valid cube.
-
- Notes
- -----
- - This function does **not** check for valid input or ensure the
- resulting cube is well-formed.
- - If no overlap exists, downstream functions must handle the invalid
- result.
-
- """
-
- return [
- max(bounding_cube_1[0], bounding_cube_2[0]),
- max(bounding_cube_1[1], bounding_cube_2[1]),
- max(bounding_cube_1[2], bounding_cube_2[2]),
- min(bounding_cube_1[3], bounding_cube_2[3]),
- min(bounding_cube_1[4], bounding_cube_2[4]),
- min(bounding_cube_1[5], bounding_cube_2[5]),
- ]
-
- def _get_overlapping_volume(
- self: NonOverlapping,
- volume: np.ndarray, # 3D array.
- bounding_cube: tuple[float, float, float, float, float, float],
- overlapping_cube: tuple[float, float, float, float, float, float],
- ) -> np.ndarray:
- """Extracts the overlapping region of a 3D volume within the specified
- overlapping cube.
-
- This method identifies and returns the subregion of `volume` that
- lies within the `overlapping_cube`. The bounding information of the
- volume is provided via `bounding_cube`.
-
- Parameters
- ----------
- volume: np.ndarray
- A 3D NumPy array representing the volume from which the
- overlapping region is extracted.
- bounding_cube: tuple[float, float, float, float, float, float]
- The bounding cube of the volume, given as a tuple of six floats:
- `(x1, y1, z1, x2, y2, z2)`. The first three values define the
- **top-left-front** corner, while the last three values define the
- **bottom-right-back** corner.
- overlapping_cube: tuple[float, float, float, float, float, float]
- The overlapping region between the volume and another volume,
- represented in the same format as `bounding_cube`.
-
- Returns
- -------
- np.ndarray
- A 3D NumPy array representing the portion of `volume` that
- lies within `overlapping_cube`. If the overlap does not exist,
- an empty array may be returned.
-
- Notes
- -----
- - The method computes the relative indices of `overlapping_cube`
- within `volume` by subtracting the bounding cube's starting
- position.
- - The extracted region is determined by integer indices, meaning
- coordinates are implicitly **floored to integers**.
- - If `overlapping_cube` extends beyond `volume` boundaries, the
- returned subregion is **cropped** to fit within `volume`.
-
- """
-
- # The position of the top left corner of the overlapping cube in the volume
- overlapping_cube_position = np.array(overlapping_cube[:3]) - np.array(
- bounding_cube[:3]
- )
-
- # The position of the bottom right corner of the overlapping cube in the volume
- overlapping_cube_end_position = np.array(
- overlapping_cube[3:]
- ) - np.array(bounding_cube[:3])
-
- # cast to int
- overlapping_cube_position = overlapping_cube_position.astype(int)
- overlapping_cube_end_position = overlapping_cube_end_position.astype(int)
-
- return volume[
- overlapping_cube_position[0] : overlapping_cube_end_position[0],
- overlapping_cube_position[1] : overlapping_cube_end_position[1],
- overlapping_cube_position[2] : overlapping_cube_end_position[2],
- ]
-
- def _check_volumes_non_overlapping(
- self: NonOverlapping,
- volume_1: np.ndarray,
- volume_2: np.ndarray,
- min_distance: float,
- ) -> bool:
- """Determines whether the non-zero voxels in two 3D volumes are at
- least `min_distance` apart.
-
- This method checks whether the active regions (non-zero voxels) in
- `volume_1` and `volume_2` maintain a minimum separation of
- `min_distance`. If the volumes differ in size, the positions of their
- non-zero voxels are adjusted accordingly to ensure a fair comparison.
-
- Parameters
- ----------
- volume_1: np.ndarray
- A 3D NumPy array representing the first volume.
- volume_2: np.ndarray
- A 3D NumPy array representing the second volume.
- min_distance: float
- The minimum Euclidean distance required between any two non-zero
- voxels in the two volumes.
-
- Returns
- -------
- bool
- `True` if all non-zero voxels in `volume_1` and `volume_2` are at
- least `min_distance` apart, otherwise `False`.
-
- Notes
- -----
- - This function assumes both volumes are correctly aligned within a
- shared coordinate space.
- - If the volumes are of different sizes, voxel positions are scaled
- or adjusted for accurate distance measurement.
- - Uses **Euclidean distance** for separation checking.
- - If either volume is empty (i.e., no non-zero voxels), they are
- considered non-overlapping.
-
- """
-
- # Get the positions of the non-zero voxels of each volume.
- if self.get_backend() == "torch":
- positions_1 = torch.nonzero(volume_1, as_tuple=False)
- positions_2 = torch.nonzero(volume_2, as_tuple=False)
- else:
- positions_1 = np.argwhere(volume_1)
- positions_2 = np.argwhere(volume_2)
-
- # if positions_1.size == 0 or positions_2.size == 0:
- # return True # If either volume is empty, they are "non-overlapping"
-
- # # If the volumes are not the same size, the positions of the non-zero
- # # voxels of each volume need to be scaled.
- # if positions_1.size == 0 or positions_2.size == 0:
- # return True # If either volume is empty, they are "non-overlapping"
-
- # If the volumes are not the same size, the positions of the non-zero
- # voxels of each volume need to be scaled.
- if volume_1.shape != volume_2.shape:
- positions_1 = (
- positions_1 * np.array(volume_2.shape)
- / np.array(volume_1.shape)
- )
- positions_1 = positions_1.astype(int)
-
- # Check that the non-zero voxels of the volumes are at least
- # min_distance apart.
- if self.get_backend() == "torch":
- dist = torch.cdist(
- positions_1.float(),
- positions_2.float(),
- )
- return bool((dist > min_distance).all())
- else:
- return np.all(cdist(positions_1, positions_2) > min_distance)
-
- def _resample_volume_position(
- self: NonOverlapping,
- volume: np.ndarray | Image,
- ) -> Image:
- """Resamples the position of a 3D volume using its internal position
- sampler.
-
- This method updates the `position` property of the given `volume` by
- drawing a new position from the `_position_sampler` stored in the
- volume's `properties`. If the sampled position is a `Quantity`, it is
- converted to pixel units.
-
- Parameters
- ----------
- volume: np.ndarray
- The 3D volume whose position is to be resampled. The volume must
- have a `properties` attribute containing dictionaries with
- `position` and `_position_sampler` keys.
-
- Returns
- -------
- Image
- The same input volume with its `position` property updated to the
- newly sampled value.
-
- Notes
- -----
- - The `_position_sampler` function is expected to return a **tuple of
- three floats** (e.g., `(x, y, z)`).
- - If the sampled position is a `Quantity`, it is converted to pixels.
- - **Only** dictionaries in `volume.properties` that contain both
- `position` and `_position_sampler` keys are modified.
-
- """
- pdict = volume.properties
- if "position" in pdict and "_position_sampler" in pdict:
- new_position = pdict["_position_sampler"]()
- if isinstance(new_position, Quantity):
- new_position = new_position.to("pixel").magnitude
- pdict["position"] = new_position
- return volume
class Store(Feature):
diff --git a/deeptrack/optics.py b/deeptrack/optics.py
index 0788d7d8..bb3cf40f 100644
--- a/deeptrack/optics.py
+++ b/deeptrack/optics.py
@@ -267,9 +267,10 @@ def _validate_input(self, scattered):
def _extract_contrast_volume(self, scattered):
if hasattr(self._objective, "extract_contrast_volume"):
- return self._objective.extract_contrast_volume(scattered)
-
- # default: geometry-only
+ return self._objective.extract_contrast_volume(
+ scattered,
+ **self._objective.properties(),
+ )
return scattered.array
def _downscale_image(self, image, upscale):
@@ -395,12 +396,14 @@ def get(
**additional_sample_kwargs,
)
# Interpret the merged volume semantically
sample_volume = self._extract_contrast_volume(
ScatteredVolume(
array=sample_volume,
properties=volume_samples[0].properties,
- )
+ ),
)
# Let the objective know about the limits of the volume and all the fields.
@@ -1081,37 +1084,36 @@ def validate_input(self, scattered):
"Fluorescence microscope cannot operate on ScatteredField."
)
- # Fluorescence must not use refractive index
- if isinstance(scattered, ScatteredVolume):
- if scattered.get_property("refractive_index", None) is not None:
- raise ValueError(
- "Fluorescence does not use refractive index. "
- "Found 'refractive_index' in scatterer properties."
- )
-
def extract_contrast_volume(self, scattered: ScatteredVolume) -> np.ndarray:
- """Contrast extraction (semantic interpretation)"""
- intensity = scattered.get_property("intensity", None)
+ voxel_size = np.asarray(get_active_voxel_size(), float)
+ voxel_volume = np.prod(voxel_size)
- if intensity is None:
- intensity = scattered.get_property("value", None)
- if intensity is None:
- raise ValueError(
- "Fluorescence requires 'intensity' or 'value'."
- )
+ intensity = scattered.get_property("intensity", None)
+ value = scattered.get_property("value", None)
+ ri = scattered.get_property("refractive_index", None)
+ # Refractive index is always ignored in fluorescence
+ if ri is not None:
warnings.warn(
- "Using 'value' as fluorescence intensity is ambiguous. "
- "Please use 'intensity' explicitly to avoid ambiguity.",
+ "Scatterer defines 'refractive_index', which is ignored in "
+ "fluorescence microscopy.",
UserWarning,
)
- voxel_size = np.asarray(get_active_voxel_size(), dtype=float)
- voxel_volume = float(np.prod(voxel_size))
+ # Preferred, physically meaningful case
+ if intensity is not None:
+ return intensity * voxel_volume * scattered.array
- return scattered.array * intensity * voxel_volume
+ # Fallback: legacy / dimensionless brightness
+ warnings.warn(
+ "Fluorescence scatterer has no 'intensity'. Interpreting 'value' as a "
+ "non-physical brightness factor. Quantitative interpretation is invalid. "
+ "Define 'intensity' to model physical fluorescence emission.",
+ UserWarning,
+ )
+ return value * scattered.array
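+
+    # Minimal sketch of the contrast rule above, assuming a scatterer that
+    # defines 'intensity' (names below are illustrative, not API):
+    #
+    #     voxel_volume = float(np.prod(get_active_voxel_size()))
+    #     contrast = intensity * voxel_volume * occupancy_array
+    #
+    # Multiplying by the voxel volume makes the total emitted signal
+    # independent of the voxel discretization.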
def downscale_image(self, image: np.ndarray, upscale):
"""Detector downscaling (energy conserving)"""
@@ -1357,8 +1359,50 @@ class Brightfield(Optics):
__conversion_table__ = ConversionTable(
- working_distance=(u.meter, u.meter),
- )
+        working_distance=(u.meter, u.meter),
+    )
+
+ def validate_input(self, scattered):
+ """Semantic validation for brightfield microscopy."""
+
+ if isinstance(scattered, ScatteredVolume):
+ warnings.warn(
+ "Brightfield imaging from ScatteredVolume assumes a "
+ "weak-phase / projection approximation. "
+ "Use ScatteredField for physically accurate brightfield simulations.",
+ UserWarning,
+ )
+
+ def extract_contrast_volume(
+ self,
+ scattered: ScatteredVolume,
+ refractive_index_medium: float,
+ **kwargs: Any,
+ ) -> np.ndarray:
+
+ ri = scattered.get_property("refractive_index", None)
+ value = scattered.get_property("value", None)
+ intensity = scattered.get_property("intensity", None)
+
+ if intensity is not None:
+ warnings.warn(
+ "Scatterer defines 'intensity', which is ignored in "
+ "brightfield microscopy.",
+ UserWarning,
+ )
+
+ if ri is not None:
+ return (ri - refractive_index_medium) * scattered.array
+
+ warnings.warn(
+ "No 'refractive_index' specified; using 'value' as a non-physical "
+ "brightfield contrast. Results are not physically calibrated. "
+ "Define 'refractive_index' for physically meaningful contrast.",
+ UserWarning,
+ )
+
+ return value * scattered.array
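+
+    # Sketch of the contrast rule above, assuming a scatterer of refractive
+    # index 1.45 imaged in water (illustrative values, not API):
+    #
+    #     delta_n = 1.45 - 1.33              # scatterer minus medium
+    #     contrast = delta_n * occupancy_array
+    #
+    # Contrast is linear in the index difference to the medium, consistent
+    # with the weak-phase approximation mentioned in validate_input.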
def get(
self: Brightfield,
@@ -1746,6 +1790,57 @@ def __init__(
illumination_angle=illumination_angle,
**kwargs)
+ def validate_input(self, scattered):
+ if isinstance(scattered, ScatteredVolume):
+ warnings.warn(
+ "Darkfield imaging from ScatteredVolume is a very rough "
+ "approximation. Use ScatteredField for physically meaningful "
+ "darkfield simulations.",
+ UserWarning,
+ )
+
+ def extract_contrast_volume(
+ self,
+ scattered: ScatteredVolume,
+ refractive_index_medium: float,
+ **kwargs: Any,
+ ) -> np.ndarray:
+        """Approximate darkfield contrast from a volume (toy model).
+
+ This is a non-physical approximation intended for qualitative simulations.
+ """
+
+ ri = scattered.get_property("refractive_index", None)
+ value = scattered.get_property("value", None)
+ intensity = scattered.get_property("intensity", None)
+
+ # Intensity has no meaning here
+ if intensity is not None:
+ warnings.warn(
+ "Scatterer defines 'intensity', which is ignored in "
+ "darkfield microscopy.",
+ UserWarning,
+ )
+
+ if ri is not None:
+ delta_n = ri - refractive_index_medium
+ warnings.warn(
+ "Approximating darkfield contrast from refractive index. "
+ "Result is non-physical and qualitative only.",
+ UserWarning,
+ )
+ return (delta_n ** 2) * scattered.array
+
+ warnings.warn(
+ "No 'refractive_index' specified; using 'value' as a non-physical "
+ "darkfield scattering strength. Results are qualitative only.",
+ UserWarning,
+ )
+
+ return (value ** 2) * scattered.array
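+
+    # Sketch of the toy scaling above (illustrative names): darkfield
+    # collects only scattered light, and in the weak-scattering limit the
+    # scattered intensity grows as the square of the index contrast:
+    #
+    #     delta_n = ri - refractive_index_medium
+    #     contrast = delta_n ** 2 * occupancy_array
+    #
+    # As the warnings state, the result is qualitative only.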
+
+
    # Reuse `get` from the parent class.
def get(
self: Darkfield,
@@ -1922,6 +2017,998 @@ def get(
return image
+class NonOverlapping(Feature):
+ """Ensure volumes are placed non-overlapping in a 3D space.
+
+ This feature ensures that a list of 3D volumes are positioned such that
+ their non-zero voxels do not overlap. If volumes overlap, their positions
+ are resampled until they are non-overlapping. If the maximum number of
+ attempts is exceeded, the feature regenerates the list of volumes and
+ raises a warning if non-overlapping placement cannot be achieved.
+
+ Note: `min_distance` refers to the distance between the edges of volumes,
+ not their centers. Due to the way volumes are calculated, slight rounding
+ errors may affect the final distance.
+
+ This feature is incompatible with non-volumetric scatterers such as
+ `MieScatterers`.
+
+ Parameters
+ ----------
+ feature: Feature
+ The feature that generates the list of volumes to place
+ non-overlapping.
+ min_distance: float, optional
+ The minimum distance between volumes in pixels. It can be negative to
+ allow for partial overlap. Defaults to 1.
+ max_attempts: int, optional
+ The maximum number of attempts to place volumes without overlap.
+ Defaults to 5.
+ max_iters: int, optional
+ The maximum number of resamplings. If this number is exceeded, a new
+ list of volumes is generated. Defaults to 100.
+
+ Attributes
+ ----------
+ __distributed__: bool
+ Always `False` for `NonOverlapping`, indicating that this feature’s
+ `.get()` method processes the entire input at once even if it is a
+        list, rather than distributing calls for each item of the list.
+
+ Methods
+ -------
+    `get(*_, min_distance, max_attempts, max_iters, **kwargs) -> list[np.ndarray]`
+ Generate a list of non-overlapping 3D volumes.
+ `_check_non_overlapping(list_of_volumes) -> bool`
+ Check if all volumes in the list are non-overlapping.
+ `_check_bounding_cubes_non_overlapping(...) -> bool`
+ Check if two bounding cubes are non-overlapping.
+ `_get_overlapping_cube(...) -> list[int]`
+ Get the overlapping cube between two bounding cubes.
+ `_get_overlapping_volume(...) -> array`
+ Get the overlapping volume between a volume and a bounding cube.
+ `_check_volumes_non_overlapping(...) -> bool`
+ Check if two volumes are non-overlapping.
+ `_resample_volume_position(volume) -> Image`
+ Resample the position of a volume to avoid overlap.
+
+ Notes
+ -----
+ - This feature performs bounding cube checks first to quickly reject
+ obvious overlaps before voxel-level checks.
+ - If the bounding cubes overlap, precise voxel-based checks are performed.
+
+ Examples
+    --------
+ >>> import deeptrack as dt
+
+ Define an ellipse scatterer with randomly positioned objects:
+
+ >>> import numpy as np
+ >>>
+ >>> scatterer = dt.Ellipse(
+    >>>     radius=13 * dt.units.pixels,
+    >>>     position=lambda: np.random.uniform(5, 115, size=2) * dt.units.pixels,
+ >>> )
+
+ Create multiple scatterers:
+
+ >>> scatterers = (scatterer ^ 8)
+
+ Define the optics and create the image with possible overlap:
+
+ >>> optics = dt.Fluorescence()
+ >>> im_with_overlap = optics(scatterers)
+ >>> im_with_overlap.store_properties()
+    >>> im_with_overlap_resolved = im_with_overlap()
+
+ Gather position from image:
+
+ >>> pos_with_overlap = np.array(
+ >>> im_with_overlap_resolved.get_property(
+ >>> "position",
+ >>> get_one=False
+ >>> )
+ >>> )
+
+ Enforce non-overlapping and create the image without overlap:
+
+ >>> non_overlapping_scatterers = dt.NonOverlapping(
+ ... scatterers,
+ ... min_distance=4,
+ ... )
+ >>> im_without_overlap = optics(non_overlapping_scatterers)
+ >>> im_without_overlap.store_properties()
+ >>> im_without_overlap_resolved = im_without_overlap()
+
+ Gather position from image:
+
+ >>> pos_without_overlap = np.array(
+ >>> im_without_overlap_resolved.get_property(
+ >>> "position",
+ >>> get_one=False
+ >>> )
+ >>> )
+
+ Create a figure with two subplots to visualize the difference:
+
+ >>> import matplotlib.pyplot as plt
+ >>>
+ >>> fig, axes = plt.subplots(1, 2, figsize=(10, 5))
+ >>>
+ >>> axes[0].imshow(im_with_overlap_resolved, cmap="gray")
+    >>> axes[0].scatter(pos_with_overlap[:, 1], pos_with_overlap[:, 0])
+ >>> axes[0].set_title("Overlapping Objects")
+ >>> axes[0].axis("off")
+ >>>
+ >>> axes[1].imshow(im_without_overlap_resolved, cmap="gray")
+    >>> axes[1].scatter(pos_without_overlap[:, 1], pos_without_overlap[:, 0])
+ >>> axes[1].set_title("Non-Overlapping Objects")
+ >>> axes[1].axis("off")
+ >>> plt.tight_layout()
+ >>>
+ >>> plt.show()
+
+ Define function to calculate minimum distance:
+
+ >>> def calculate_min_distance(positions):
+ >>> distances = [
+ >>> np.linalg.norm(positions[i] - positions[j])
+ >>> for i in range(len(positions))
+ >>> for j in range(i + 1, len(positions))
+ >>> ]
+ >>> return min(distances)
+
+ Print minimum distances with and without overlap:
+
+ >>> print(calculate_min_distance(pos_with_overlap))
+ 10.768742383382174
+
+ >>> print(calculate_min_distance(pos_without_overlap))
+ 30.82531120942446
+
+ """
+
+ __distributed__: bool = False
+
+ def __init__(
+ self: NonOverlapping,
+ feature: Feature,
+ min_distance: float = 1,
+ max_attempts: int = 5,
+ max_iters: int = 100,
+ **kwargs: Any,
+ ):
+ """Initializes the NonOverlapping feature.
+
+ Ensures that volumes are placed **non-overlapping** by iteratively
+ resampling their positions. If the maximum number of attempts is
+ exceeded, the feature regenerates the list of volumes.
+
+ Parameters
+ ----------
+ feature: Feature
+ The feature that generates the list of volumes.
+ min_distance: float, optional
+ The minimum separation distance **between volume edges**, in
+ pixels. It defaults to `1`. Negative values allow for partial
+ overlap.
+ max_attempts: int, optional
+ The maximum number of attempts to place the volumes without
+ overlap. It defaults to `5`.
+ max_iters: int, optional
+ The maximum number of resampling iterations per attempt. If
+ exceeded, a new list of volumes is generated. It defaults to `100`.
+
+ """
+
+ super().__init__(
+ min_distance=min_distance,
+ max_attempts=max_attempts,
+ max_iters=max_iters,
+ **kwargs,
+ )
+ self.feature = self.add_feature(feature, **kwargs)
+
+ def get(
+ self: NonOverlapping,
+ *_: Any,
+ min_distance: float,
+ max_attempts: int,
+ max_iters: int,
+ **kwargs: Any,
+ ) -> list[np.ndarray]:
+ """Generates a list of non-overlapping 3D volumes within a defined
+ field of view (FOV).
+
+ This method **iteratively** attempts to place volumes while ensuring
+ they maintain at least `min_distance` separation. If non-overlapping
+ placement is not achieved within `max_attempts`, a warning is issued,
+ and the best available configuration is returned.
+
+ Parameters
+ ----------
+ _: Any
+ Placeholder parameter, typically for an input image.
+ min_distance: float
+ The minimum required separation distance between volumes, in
+ pixels.
+ max_attempts: int
+ The maximum number of attempts to generate a valid non-overlapping
+ configuration.
+ max_iters: int
+ The maximum number of resampling iterations per attempt.
+ **kwargs: Any
+ Additional parameters that may be used by subclasses.
+
+ Returns
+ -------
+ list[np.ndarray]
+ A list of 3D volumes represented as NumPy arrays. If
+ non-overlapping placement is unsuccessful, the best available
+ configuration is returned.
+
+ Warns
+ -----
+ UserWarning
+ If non-overlapping placement is **not** achieved within
+ `max_attempts`, suggesting parameter adjustments such as increasing
+ the FOV or reducing `min_distance`.
+
+ Notes
+ -----
+ - The placement process prioritizes bounding cube checks for
+ efficiency.
+ - If bounding cubes overlap, voxel-based overlap checks are performed.
+
+ """
+
+ for _ in range(max_attempts):
+ list_of_volumes = self.feature()
+
+ if not isinstance(list_of_volumes, list):
+ list_of_volumes = [list_of_volumes]
+
+ for _ in range(max_iters):
+
+ list_of_volumes = [
+ self._resample_volume_position(volume)
+ for volume in list_of_volumes
+ ]
+
+ if self._check_non_overlapping(list_of_volumes):
+ return list_of_volumes
+
+            # Generate a new list of volumes if max_iters is exceeded.
+ self.feature.update()
+
+ warnings.warn(
+ "Non-overlapping placement could not be achieved. Consider "
+ "adjusting parameters: reduce object radius, increase FOV, "
+ "or decrease min_distance.",
+ UserWarning,
+ )
+ return list_of_volumes
+
+ def _check_non_overlapping(
+ self: NonOverlapping,
+ list_of_volumes: list[np.ndarray],
+ ) -> bool:
+ """Determines whether all volumes in the provided list are
+ non-overlapping.
+
+ This method verifies that the non-zero voxels of each 3D volume in
+ `list_of_volumes` are at least `min_distance` apart. It first checks
+ bounding boxes for early rejection and then examines actual voxel
+ overlap when necessary. Volumes are assumed to have a `position`
+ attribute indicating their placement in 3D space.
+
+ Parameters
+ ----------
+ list_of_volumes: list[np.ndarray]
+ A list of 3D arrays representing the volumes to be checked for
+ overlap. Each volume is expected to have a position attribute.
+
+ Returns
+ -------
+ bool
+ `True` if all volumes are non-overlapping, otherwise `False`.
+
+ Notes
+ -----
+ - If `min_distance` is negative, volumes are shrunk using isotropic
+ erosion before checking overlap.
+ - If `min_distance` is positive, volumes are padded and expanded using
+ isotropic dilation.
+ - Overlapping checks are first performed on bounding cubes for
+ efficiency.
+ - If bounding cubes overlap, voxel-level checks are performed.
+
+ """
+ from deeptrack.scatterers import ScatteredVolume
+
+        # CropTight and Pad are not compatible with the torch backend.
+        from deeptrack.augmentations import CropTight, Pad
+ from deeptrack.optics import _get_position
+ from deeptrack.math import isotropic_erosion, isotropic_dilation
+
+ min_distance = self.min_distance()
+ crop = CropTight()
+
+ new_volumes = []
+
+ for volume in list_of_volumes:
+ arr = volume.array
+ mask = arr != 0
+
+ if min_distance < 0:
+                new_arr = isotropic_erosion(
+                    mask, -min_distance / 2, backend=self.get_backend(),
+                )
+ else:
+ pad = Pad(px=[int(np.ceil(min_distance / 2))] * 6, keep_size=True)
+                new_arr = isotropic_dilation(
+                    pad(mask) != 0, min_distance / 2,
+                    backend=self.get_backend(),
+                )
+ new_arr = crop(new_arr)
+
+ if self.get_backend() == "torch":
+ new_arr = new_arr.to(dtype=arr.dtype)
+ else:
+ new_arr = new_arr.astype(arr.dtype)
+
+ new_volume = ScatteredVolume(
+ array=new_arr,
+ properties=volume.properties.copy(),
+ )
+
+ new_volumes.append(new_volume)
+
+ list_of_volumes = new_volumes
+ min_distance = 1
+
+ # The position of the top left corner of each volume (index (0, 0, 0)).
+ volume_positions_1 = [
+ _get_position(volume, mode="corner", return_z=True).astype(int)
+ for volume in list_of_volumes
+ ]
+
+ # The position of the bottom right corner of each volume
+ # (index (-1, -1, -1)).
+ volume_positions_2 = [
+ p0 + np.array(v.shape)
+ for v, p0 in zip(list_of_volumes, volume_positions_1)
+ ]
+
+ # (x1, y1, z1, x2, y2, z2) for each volume.
+ volume_bounding_cube = [
+ [*p0, *p1]
+ for p0, p1 in zip(volume_positions_1, volume_positions_2)
+ ]
+
+ for i, j in itertools.combinations(range(len(list_of_volumes)), 2):
+
+ # If the bounding cubes do not overlap, the volumes do not overlap.
+ if self._check_bounding_cubes_non_overlapping(
+ volume_bounding_cube[i], volume_bounding_cube[j], min_distance
+ ):
+ continue
+
+ # If the bounding cubes overlap, get the overlapping region of each
+ # volume.
+ overlapping_cube = self._get_overlapping_cube(
+ volume_bounding_cube[i], volume_bounding_cube[j]
+ )
+ overlapping_volume_1 = self._get_overlapping_volume(
+ list_of_volumes[i].array, volume_bounding_cube[i], overlapping_cube
+ )
+ overlapping_volume_2 = self._get_overlapping_volume(
+ list_of_volumes[j].array, volume_bounding_cube[j], overlapping_cube
+ )
+
+ # If either the overlapping regions are empty, the volumes do not
+ # overlap (done for speed).
+ if (np.all(overlapping_volume_1 == 0)
+ or np.all(overlapping_volume_2 == 0)):
+ continue
+
+ # Finally, check that the non-zero voxels of the volumes are at
+ # least min_distance apart.
+ if not self._check_volumes_non_overlapping(
+ overlapping_volume_1, overlapping_volume_2, min_distance
+ ):
+ return False
+
+ return True
+
+ def _check_bounding_cubes_non_overlapping(
+ self: NonOverlapping,
+ bounding_cube_1: list[int],
+ bounding_cube_2: list[int],
+ min_distance: float,
+ ) -> bool:
+ """Determines whether two 3D bounding cubes are non-overlapping.
+
+ This method checks whether the bounding cubes of two volumes are
+ **separated by at least** `min_distance` along **any** spatial axis.
+
+ Parameters
+ ----------
+ bounding_cube_1: list[int]
+ A list of six integers `[x1, y1, z1, x2, y2, z2]` representing
+ the first bounding cube.
+ bounding_cube_2: list[int]
+ A list of six integers `[x1, y1, z1, x2, y2, z2]` representing
+ the second bounding cube.
+ min_distance: float
+ The required **minimum separation distance** between the two
+ bounding cubes.
+
+ Returns
+ -------
+ bool
+ `True` if the bounding cubes are non-overlapping (separated by at
+ least `min_distance` along **at least one axis**), otherwise
+ `False`.
+
+ Notes
+ -----
+ - This function **only checks bounding cubes**, **not actual voxel
+ data**.
+ - If the bounding cubes are non-overlapping, the corresponding
+ **volumes are also non-overlapping**.
+ - This check is much **faster** than full voxel-based comparisons.
+
+ """
+
+ # bounding_cube_1 and bounding_cube_2 are (x1, y1, z1, x2, y2, z2).
+ # Check that the bounding cubes are non-overlapping.
+ return (
+ (bounding_cube_1[0] >= bounding_cube_2[3] + min_distance) or
+ (bounding_cube_2[0] >= bounding_cube_1[3] + min_distance) or
+ (bounding_cube_1[1] >= bounding_cube_2[4] + min_distance) or
+ (bounding_cube_2[1] >= bounding_cube_1[4] + min_distance) or
+ (bounding_cube_1[2] >= bounding_cube_2[5] + min_distance) or
+ (bounding_cube_2[2] >= bounding_cube_1[5] + min_distance)
+ )
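+
+    # Worked example with hypothetical cubes: [0, 0, 0, 1, 1, 1] and
+    # [3, 0, 0, 4, 1, 1] are non-overlapping for min_distance = 1, since
+    # 3 >= 1 + 1 along the x axis; a single separating axis suffices.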
+
+ def _get_overlapping_cube(
+ self: NonOverlapping,
+ bounding_cube_1: list[int],
+ bounding_cube_2: list[int],
+ ) -> list[int]:
+ """Computes the overlapping region between two 3D bounding cubes.
+
+ This method calculates the coordinates of the intersection of two
+ axis-aligned bounding cubes, each represented as a list of six
+ integers:
+
+ - `[x1, y1, z1]`: Coordinates of the **top-left-front** corner.
+ - `[x2, y2, z2]`: Coordinates of the **bottom-right-back** corner.
+
+ The resulting overlapping region is determined by:
+ - Taking the **maximum** of the starting coordinates (`x1, y1, z1`).
+ - Taking the **minimum** of the ending coordinates (`x2, y2, z2`).
+
+ If the cubes **do not** overlap, the resulting coordinates will not
+ form a valid cube (i.e., `x1 > x2`, `y1 > y2`, or `z1 > z2`).
+
+ Parameters
+ ----------
+ bounding_cube_1: list[int]
+ The first bounding cube, formatted as `[x1, y1, z1, x2, y2, z2]`.
+ bounding_cube_2: list[int]
+ The second bounding cube, formatted as `[x1, y1, z1, x2, y2, z2]`.
+
+ Returns
+ -------
+ list[int]
+ A list of six integers `[x1, y1, z1, x2, y2, z2]` representing the
+ overlapping bounding cube. If no overlap exists, the coordinates
+ will **not** define a valid cube.
+
+ Notes
+ -----
+ - This function does **not** check for valid input or ensure the
+ resulting cube is well-formed.
+ - If no overlap exists, downstream functions must handle the invalid
+ result.
+
+ """
+
+ return [
+ max(bounding_cube_1[0], bounding_cube_2[0]),
+ max(bounding_cube_1[1], bounding_cube_2[1]),
+ max(bounding_cube_1[2], bounding_cube_2[2]),
+ min(bounding_cube_1[3], bounding_cube_2[3]),
+ min(bounding_cube_1[4], bounding_cube_2[4]),
+ min(bounding_cube_1[5], bounding_cube_2[5]),
+ ]
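+
+    # Worked example with hypothetical cubes: intersecting
+    # [0, 0, 0, 4, 4, 4] with [2, 2, 2, 6, 6, 6] yields [2, 2, 2, 4, 4, 4]:
+    # the maximum of the start corners and the minimum of the end corners.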
+
+ def _get_overlapping_volume(
+ self: NonOverlapping,
+ volume: np.ndarray, # 3D array.
+ bounding_cube: tuple[float, float, float, float, float, float],
+ overlapping_cube: tuple[float, float, float, float, float, float],
+ ) -> np.ndarray:
+ """Extracts the overlapping region of a 3D volume within the specified
+ overlapping cube.
+
+ This method identifies and returns the subregion of `volume` that
+ lies within the `overlapping_cube`. The bounding information of the
+ volume is provided via `bounding_cube`.
+
+ Parameters
+ ----------
+ volume: np.ndarray
+ A 3D NumPy array representing the volume from which the
+ overlapping region is extracted.
+ bounding_cube: tuple[float, float, float, float, float, float]
+ The bounding cube of the volume, given as a tuple of six floats:
+ `(x1, y1, z1, x2, y2, z2)`. The first three values define the
+ **top-left-front** corner, while the last three values define the
+ **bottom-right-back** corner.
+ overlapping_cube: tuple[float, float, float, float, float, float]
+ The overlapping region between the volume and another volume,
+ represented in the same format as `bounding_cube`.
+
+ Returns
+ -------
+ np.ndarray
+ A 3D NumPy array representing the portion of `volume` that
+ lies within `overlapping_cube`. If the overlap does not exist,
+ an empty array may be returned.
+
+ Notes
+ -----
+ - The method computes the relative indices of `overlapping_cube`
+ within `volume` by subtracting the bounding cube's starting
+ position.
+ - The extracted region is determined by integer indices, meaning
+ coordinates are implicitly **floored to integers**.
+ - If `overlapping_cube` extends beyond `volume` boundaries, the
+ returned subregion is **cropped** to fit within `volume`.
+
+ """
+
+ # The position of the top left corner of the overlapping cube in the volume
+ overlapping_cube_position = np.array(overlapping_cube[:3]) - np.array(
+ bounding_cube[:3]
+ )
+
+ # The position of the bottom right corner of the overlapping cube in the volume
+ overlapping_cube_end_position = np.array(
+ overlapping_cube[3:]
+ ) - np.array(bounding_cube[:3])
+
+ # cast to int
+ overlapping_cube_position = overlapping_cube_position.astype(int)
+ overlapping_cube_end_position = overlapping_cube_end_position.astype(int)
+
+ return volume[
+ overlapping_cube_position[0] : overlapping_cube_end_position[0],
+ overlapping_cube_position[1] : overlapping_cube_end_position[1],
+ overlapping_cube_position[2] : overlapping_cube_end_position[2],
+ ]
+
+ def _check_volumes_non_overlapping(
+ self: NonOverlapping,
+ volume_1: np.ndarray,
+ volume_2: np.ndarray,
+ min_distance: float,
+ ) -> bool:
+ """Determines whether the non-zero voxels in two 3D volumes are at
+ least `min_distance` apart.
+
+ This method checks whether the active regions (non-zero voxels) in
+ `volume_1` and `volume_2` maintain a minimum separation of
+ `min_distance`. If the volumes differ in size, the positions of their
+ non-zero voxels are adjusted accordingly to ensure a fair comparison.
+
+ Parameters
+ ----------
+ volume_1: np.ndarray
+ A 3D NumPy array representing the first volume.
+ volume_2: np.ndarray
+ A 3D NumPy array representing the second volume.
+ min_distance: float
+ The minimum Euclidean distance required between any two non-zero
+ voxels in the two volumes.
+
+ Returns
+ -------
+ bool
+ `True` if all non-zero voxels in `volume_1` and `volume_2` are at
+ least `min_distance` apart, otherwise `False`.
+
+ Notes
+ -----
+ - This function assumes both volumes are correctly aligned within a
+ shared coordinate space.
+ - If the volumes are of different sizes, voxel positions are scaled
+ or adjusted for accurate distance measurement.
+ - Uses **Euclidean distance** for separation checking.
+ - If either volume is empty (i.e., no non-zero voxels), they are
+ considered non-overlapping.
+
+ """
+
+ # Get the positions of the non-zero voxels of each volume.
+ if self.get_backend() == "torch":
+ positions_1 = torch.nonzero(volume_1, as_tuple=False)
+ positions_2 = torch.nonzero(volume_2, as_tuple=False)
+ else:
+ positions_1 = np.argwhere(volume_1)
+ positions_2 = np.argwhere(volume_2)
+
+        # If either volume is empty, they are trivially non-overlapping.
+        if len(positions_1) == 0 or len(positions_2) == 0:
+            return True
+
+        # If the volumes are not the same size, the positions of the
+        # non-zero voxels of each volume need to be scaled.
+        if volume_1.shape != volume_2.shape:
+            scale = np.array(volume_2.shape) / np.array(volume_1.shape)
+            if self.get_backend() == "torch":
+                positions_1 = (
+                    positions_1 * torch.as_tensor(scale)
+                ).to(torch.int64)
+            else:
+                positions_1 = (positions_1 * scale).astype(int)
+
+ # Check that the non-zero voxels of the volumes are at least
+ # min_distance apart.
+ if self.get_backend() == "torch":
+ dist = torch.cdist(
+ positions_1.float(),
+ positions_2.float(),
+ )
+ return bool((dist > min_distance).all())
+ else:
+ return np.all(cdist(positions_1, positions_2) > min_distance)
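+
+    # Worked example with hypothetical voxels: for voxels at (0, 0, 0) and
+    # (0, 0, 3), cdist gives 3.0, so min_distance = 2 passes while
+    # min_distance = 3 fails (the comparison is strict: dist > min_distance).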
+
+ def _resample_volume_position(
+ self: NonOverlapping,
+ volume: np.ndarray | Image,
+ ) -> Image:
+ """Resamples the position of a 3D volume using its internal position
+ sampler.
+
+ This method updates the `position` property of the given `volume` by
+ drawing a new position from the `_position_sampler` stored in the
+ volume's `properties`. If the sampled position is a `Quantity`, it is
+ converted to pixel units.
+
+ Parameters
+ ----------
+ volume: np.ndarray
+ The 3D volume whose position is to be resampled. The volume must
+ have a `properties` attribute containing dictionaries with
+ `position` and `_position_sampler` keys.
+
+ Returns
+ -------
+ Image
+ The same input volume with its `position` property updated to the
+ newly sampled value.
+
+ Notes
+ -----
+ - The `_position_sampler` function is expected to return a **tuple of
+ three floats** (e.g., `(x, y, z)`).
+ - If the sampled position is a `Quantity`, it is converted to pixels.
+ - **Only** dictionaries in `volume.properties` that contain both
+ `position` and `_position_sampler` keys are modified.
+
+ """
+
+ pdict = volume.properties
+ if "position" in pdict and "_position_sampler" in pdict:
+ new_position = pdict["_position_sampler"]()
+ if isinstance(new_position, Quantity):
+ new_position = new_position.to("pixel").magnitude
+ pdict["position"] = new_position
+
+ return volume
+
+
+class SampleToMasks(Feature):
+ """Create a mask from a list of images.
+
+ This feature applies a transformation function to each input image and
+ merges the resulting masks into a single multi-layer image. Each input
+ image must have a `position` property that determines its placement within
+ the final mask. When used with scatterers, the `voxel_size` property must
+ be provided for correct object sizing.
+
+ Parameters
+ ----------
+    transformation_function: Callable[[np.ndarray | torch.Tensor], np.ndarray | torch.Tensor]
+ A function that transforms each input image into a mask with
+ `number_of_masks` layers.
+ number_of_masks: PropertyLike[int], optional
+ The number of mask layers to generate. Default is 1.
+ output_region: PropertyLike[tuple[int, int, int, int]], optional
+ The size and position of the output mask, typically aligned with
+ `optics.output_region`.
+ merge_method: PropertyLike[str | Callable | list[str | Callable]], optional
+ Method for merging individual masks into the final image. Can be:
+ - "add" (default): Sum the masks.
+ - "overwrite": Later masks overwrite earlier masks.
+ - "or": Combine masks using a logical OR operation.
+ - "mul": Multiply masks.
+ - Function: Custom function taking two images and merging them.
+
+ **kwargs: dict[str, Any]
+ Additional keyword arguments passed to the parent `Feature` class.
+
+ Methods
+ -------
+ `get(image, transformation_function, **kwargs) -> Image`
+ Applies the transformation function to the input image.
+ `_process_and_get(images, **kwargs) -> Image | np.ndarray`
+ Processes a list of images and generates a multi-layer mask.
+
+ Returns
+ -------
+ np.ndarray
+ The final mask image with the specified number of layers.
+
+ Raises
+ ------
+ ValueError
+ If `merge_method` is invalid.
+
+ Examples
+    --------
+ >>> import deeptrack as dt
+
+ Define number of particles:
+
+ >>> n_particles = 12
+
+ Define optics and particles:
+
+ >>> import numpy as np
+ >>>
+ >>> optics = dt.Fluorescence(output_region=(0, 0, 64, 64))
+ >>> particle = dt.PointParticle(
+ >>> position=lambda: np.random.uniform(5, 55, size=2),
+ >>> )
+ >>> particles = particle ^ n_particles
+
+ Define pipelines:
+
+ >>> sim_im_pip = optics(particles)
+ >>> sim_mask_pip = particles >> dt.SampleToMasks(
+ ... lambda: lambda particles: particles > 0,
+ ... output_region=optics.output_region,
+ ... merge_method="or",
+ ... )
+ >>> pipeline = sim_im_pip & sim_mask_pip
+ >>> pipeline.store_properties()
+
+ Generate image and mask:
+
+ >>> image, mask = pipeline.update()()
+
+ Get particle positions:
+
+ >>> positions = np.array(image.get_property("position", get_one=False))
+
+ Visualize results:
+
+ >>> import matplotlib.pyplot as plt
+ >>>
+ >>> plt.subplot(1, 2, 1)
+ >>> plt.imshow(image, cmap="gray")
+ >>> plt.title("Original Image")
+ >>> plt.subplot(1, 2, 2)
+ >>> plt.imshow(mask, cmap="gray")
+    >>> plt.scatter(positions[:, 1], positions[:, 0], c="y", marker="x", s=50)
+ >>> plt.title("Mask")
+ >>> plt.show()
+
+ """
+
+ def __init__(
+ self: Feature,
+        transformation_function: Callable[
+            [np.ndarray | torch.Tensor], np.ndarray | torch.Tensor
+        ],
+ number_of_masks: PropertyLike[int] = 1,
+ output_region: PropertyLike[tuple[int, int, int, int]] = None,
+ merge_method: PropertyLike[str | Callable | list[str | Callable]] = "add",
+ **kwargs: Any,
+ ):
+ """Initialize the SampleToMasks feature.
+
+ Parameters
+ ----------
+        transformation_function: Callable[[np.ndarray | torch.Tensor], np.ndarray | torch.Tensor]
+ Function to transform input images into masks.
+ number_of_masks: PropertyLike[int], optional
+ Number of mask layers. Default is 1.
+ output_region: PropertyLike[tuple[int, int, int, int]], optional
+ Output region of the mask. Default is None.
+ merge_method: PropertyLike[str | Callable | list[str | Callable]], optional
+ Method to merge masks. Defaults to "add".
+ **kwargs: dict[str, Any]
+ Additional keyword arguments passed to the parent class.
+
+ """
+
+ super().__init__(
+ transformation_function=transformation_function,
+ number_of_masks=number_of_masks,
+ output_region=output_region,
+ merge_method=merge_method,
+ **kwargs,
+ )
+
+ def get(
+ self: Feature,
+ image: np.ndarray,
+        transformation_function: Callable[
+            [np.ndarray | torch.Tensor], np.ndarray | torch.Tensor
+        ],
+ **kwargs: Any,
+ ) -> np.ndarray:
+ """Apply the transformation function to a single image.
+
+ Parameters
+ ----------
+ image: np.ndarray
+ The input image.
+ transformation_function: Callable[[np.ndarray], np.ndarray]
+ Function to transform the image.
+ **kwargs: dict[str, Any]
+ Additional parameters.
+
+ Returns
+ -------
+        np.ndarray or torch.Tensor
+            The transformed image.
+
+ """
+
+ return transformation_function(image.array)
+
+ def _process_and_get(
+ self: Feature,
+ images: list[np.ndarray] | np.ndarray | list[torch.Tensor] | torch.Tensor,
+ **kwargs: Any,
+ ) -> np.ndarray:
+ """Process a list of images and generate a multi-layer mask.
+
+ Parameters
+ ----------
+        images: np.ndarray or torch.Tensor, or a list of such arrays
+ List of input images or a single image.
+ **kwargs: dict[str, Any]
+ Additional parameters including `output_region`, `number_of_masks`,
+ and `merge_method`.
+
+ Returns
+ -------
+        np.ndarray or torch.Tensor
+ The final mask image.
+
+ """
+
+        # Transform each input image into a label array.
+ list_of_labels = super()._process_and_get(images, **kwargs)
+
+ from deeptrack.scatterers import ScatteredVolume
+
+ for idx, (label, image) in enumerate(zip(list_of_labels, images)):
+            list_of_labels[idx] = ScatteredVolume(
+                array=label, properties=image.properties.copy(),
+            )
+
+ # Create an empty output image.
+ output_region = kwargs["output_region"]
+ output = xp.zeros(
+ (
+ output_region[2] - output_region[0],
+ output_region[3] - output_region[1],
+ kwargs["number_of_masks"],
+ ),
+ dtype=list_of_labels[0].array.dtype,
+ )
+
+ from deeptrack.optics import _get_position
+
+ # Merge masks into the output.
+ for volume in list_of_labels:
+ label = volume.array
+ position = _get_position(volume)
+
+ p0 = xp.round(position - xp.asarray(output_region[0:2]))
+ p0 = p0.astype(xp.int64)
+
+ if xp.any(p0 > xp.asarray(output.shape[:2])) or \
+ xp.any(p0 + xp.asarray(label.shape[:2]) < 0):
+ continue
+
+ crop_x = (-xp.minimum(p0[0], 0)).item()
+ crop_y = (-xp.minimum(p0[1], 0)).item()
+
+ crop_x_end = int(
+ label.shape[0]
+ - np.max([p0[0] + label.shape[0] - output.shape[0], 0])
+ )
+ crop_y_end = int(
+ label.shape[1]
+ - np.max([p0[1] + label.shape[1] - output.shape[1], 0])
+ )
+
+ labelarg = label[crop_x:crop_x_end, crop_y:crop_y_end, :]
+
+ p0[0] = np.max([p0[0], 0])
+ p0[1] = np.max([p0[1], 0])
+
+ p0 = p0.astype(int)
+
+ output_slice = output[
+ p0[0] : p0[0] + labelarg.shape[0],
+ p0[1] : p0[1] + labelarg.shape[1],
+ ]
+
+ for label_index in range(kwargs["number_of_masks"]):
+
+ if isinstance(kwargs["merge_method"], list):
+ merge = kwargs["merge_method"][label_index]
+ else:
+ merge = kwargs["merge_method"]
+
+ if merge == "add":
+ output[
+ p0[0] : p0[0] + labelarg.shape[0],
+ p0[1] : p0[1] + labelarg.shape[1],
+ label_index,
+ ] += labelarg[..., label_index]
+
+ elif merge == "overwrite":
+ output_slice[
+ labelarg[..., label_index] != 0, label_index
+                    ] = labelarg[
+                        labelarg[..., label_index] != 0, label_index
+                    ]
+ output[
+ p0[0] : p0[0] + labelarg.shape[0],
+ p0[1] : p0[1] + labelarg.shape[1],
+ label_index,
+ ] = output_slice[..., label_index]
+
+ elif merge == "or":
+ output[
+ p0[0] : p0[0] + labelarg.shape[0],
+ p0[1] : p0[1] + labelarg.shape[1],
+ label_index,
+ ] = xp.logical_or(
+ output_slice[..., label_index] != 0,
+ labelarg[..., label_index] != 0
+ )
+
+ elif merge == "mul":
+ output[
+ p0[0] : p0[0] + labelarg.shape[0],
+ p0[1] : p0[1] + labelarg.shape[1],
+ label_index,
+ ] *= labelarg[..., label_index]
+
+ else:
+ # No match, assume function
+ output[
+ p0[0] : p0[0] + labelarg.shape[0],
+ p0[1] : p0[1] + labelarg.shape[1],
+ label_index,
+ ] = merge(
+ output_slice[..., label_index],
+ labelarg[..., label_index],
+ )
+
+ return output
+
+
#TODO ***??*** revise _get_position - torch, typing, docstring, unit test
def _get_position(
scatterer: ScatteredObject,
From e006806829692d1823b41fd9248dd7413963fb4f Mon Sep 17 00:00:00 2001
From: Carlo
Date: Thu, 15 Jan 2026 18:13:58 +0100
Subject: [PATCH 22/24] u
---
deeptrack/features.py | 5 +-
deeptrack/math.py | 1375 ++++++++++++++---------------------------
deeptrack/optics.py | 20 +-
3 files changed, 483 insertions(+), 917 deletions(-)
diff --git a/deeptrack/features.py b/deeptrack/features.py
index 901664a9..702e7362 100644
--- a/deeptrack/features.py
+++ b/deeptrack/features.py
@@ -1,6 +1,6 @@
"""Core features for building and processing pipelines in DeepTrack2.
-The `feasture.py` module defines the core classes and utilities used to create
+The `features.py` module defines the core classes and utilities used to create
and manipulate features in DeepTrack2, enabling users to build sophisticated
data processing pipelines with modular, reusable, and composable components.
@@ -80,11 +80,8 @@
- `OneOf`: Resolve one feature from a given collection.
- `OneOfDict`: Resolve one feature from a dictionary and apply it to an input.
- `LoadImage`: Load an image from disk and preprocess it.
-- `SampleToMasks`: Create a mask from a list of images.
- `AsType`: Convert the data type of the input.
- `ChannelFirst2d`: DEPRECATED Convert an image to a channel-first format.
-- `Upscale`: Simulate a pipeline at a higher resolution.
-- `NonOverlapping`: Ensure volumes are placed non-overlapping in a 3D space.
- `Store`: Store the output of a feature for reuse.
- `Squeeze`: Squeeze the input to the smallest possible dimension.
- `Unsqueeze`: Unsqueeze the input.
diff --git a/deeptrack/math.py b/deeptrack/math.py
index 0af95c05..5845fb90 100644
--- a/deeptrack/math.py
+++ b/deeptrack/math.py
@@ -97,15 +97,15 @@
import array_api_compat as apc
import numpy as np
-from numpy.typing import NDArray
+from numpy.typing import NDArray #TODO TBE
from scipy import ndimage
import skimage
import skimage.measure
from deeptrack import utils, OPENCV_AVAILABLE, TORCH_AVAILABLE
from deeptrack.features import Feature
-from deeptrack.image import Image, strip
-from deeptrack.types import ArrayLike, PropertyLike
+from deeptrack.image import Image, strip #TODO TBE
+from deeptrack.types import PropertyLike
from deeptrack.backend import xp
if TORCH_AVAILABLE:
@@ -130,11 +130,6 @@
"MaxPooling",
"MinPooling",
"MedianPooling",
- "PoolV2",
- "AveragePoolingV2",
- "MaxPoolingV2",
- "MinPoolingV2",
- "MedianPoolingV2",
"BlurCV2",
"BilateralBlur",
]
@@ -232,10 +227,10 @@ def __init__(
def get(
self: Average,
- images: list[NDArray[Any] | torch.Tensor | Image],
+ images: list[np.ndarray | torch.Tensor],
axis: int | tuple[int],
**kwargs: Any,
- ) -> NDArray[Any] | torch.Tensor | Image:
+ ) -> np.ndarray | torch.Tensor:
"""Compute the average of input images along the specified axis(es).
This method computes the average of the input images along the
@@ -323,11 +318,11 @@ def __init__(
def get(
self: Clip,
- image: NDArray[Any] | torch.Tensor | Image,
+ image: np.ndarray | torch.Tensor,
min: float,
max: float,
**kwargs: Any,
- ) -> NDArray[Any] | torch.Tensor | Image:
+ ) -> np.ndarray | torch.Tensor:
"""Clips the input image within the specified values.
This method clips the input image within the specified minimum and
@@ -368,8 +363,7 @@ class NormalizeMinMax(Feature):
max: float, optional
Upper bound of the transformation. It defaults to 1.
featurewise: bool, optional
- Whether to normalize each feature independently. It default to `True`,
- which is the only behavior currently implemented.
+        Whether to normalize each feature independently. It defaults to `True`.
Methods
-------
@@ -395,8 +389,6 @@ class NormalizeMinMax(Feature):
"""
- #TODO ___??___ Implement the `featurewise=False` option
-
def __init__(
self: NormalizeMinMax,
min: PropertyLike[float] = 0,
@@ -423,32 +415,47 @@ def __init__(
def get(
self: NormalizeMinMax,
- image: ArrayLike,
+ image: np.ndarray | torch.Tensor,
min: float,
max: float,
+ featurewise: bool = True,
**kwargs: Any,
- ) -> ArrayLike:
+ ) -> np.ndarray | torch.Tensor:
"""Normalize the input to fall between `min` and `max`.
Parameters
----------
- image: array
+ image: np.ndarray or torch.Tensor
Input image to normalize.
min: float
Lower bound of the output range.
max: float
Upper bound of the output range.
+ featurewise: bool
+ Whether to normalize each feature (channel) independently.
Returns
-------
- array
+ np.ndarray or torch.Tensor
Min-max normalized image.
"""
- ptp = xp.max(image) - xp.min(image)
- image = image / ptp * (max - min)
- image = image - xp.min(image) + min
+ if featurewise:
+ # Normalize per feature (last axis)
+ axis = tuple(range(image.ndim - 1))
+
+ img_min = xp.min(image, axis=axis, keepdims=True)
+ img_max = xp.max(image, axis=axis, keepdims=True)
+ else:
+ # Normalize globally
+ img_min = xp.min(image)
+ img_max = xp.max(image)
+
+ ptp = img_max - img_min
+
+        # A zero range produces NaNs here; they are zeroed just below.
+        image = (image - img_min) / ptp * (max - min) + min
try:
image[xp.isnan(image)] = 0
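+
+        # Featurewise sketch with a hypothetical (H, W, C) input: axis is
+        # (0, 1), so img_min and img_max have shape (1, 1, C) and each
+        # channel is rescaled to [min, max] independently. For example,
+        #
+        #     x = np.stack([np.zeros((4, 4)), np.arange(16.).reshape(4, 4)], -1)
+        #     y = NormalizeMinMax(min=0, max=1)(x)
+        #
+        # maps channel 1 to [0, 1], while the constant channel 0 becomes
+        # all zeros via the NaN handling above.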
@@ -492,8 +499,6 @@ class NormalizeStandard(Feature):
"""
- #TODO ___??___ Implement the `featurewise=False` option
-
def __init__(
self: NormalizeStandard,
featurewise: PropertyLike[bool] = True,
@@ -516,33 +521,52 @@ def __init__(
def get(
self: NormalizeStandard,
- image: NDArray[Any] | torch.Tensor | Image,
+ image: np.ndarray | torch.Tensor,
+ featurewise: bool,
**kwargs: Any,
- ) -> NDArray[Any] | torch.Tensor | Image:
+ ) -> np.ndarray | torch.Tensor:
"""Normalizes the input image to have mean 0 and standard deviation 1.
- This method normalizes the input image to have mean 0 and standard
- deviation 1.
-
Parameters
----------
- image: array
+ image: np.ndarray or torch.Tensor
The input image to normalize.
+ featurewise: bool
+ Whether to normalize each feature (channel) independently.
Returns
-------
- array
- The normalized image.
-
+ np.ndarray or torch.Tensor
+ The standardized image.
"""
- if apc.is_torch_array(image):
- # By default, torch.std() is unbiased, i.e., divides by N-1
- return (
- (image - torch.mean(image)) / torch.std(image, unbiased=False)
- )
+ if featurewise:
+ # Normalize per feature (last axis)
+ axis = tuple(range(image.ndim - 1))
+
+ mean = xp.mean(image, axis=axis, keepdims=True)
+
+ if apc.is_torch_array(image):
+ std = torch.std(image, dim=axis, keepdim=True, unbiased=False)
+ else:
+                    std = xp.std(image, axis=axis, keepdims=True)
+ else:
+ # Normalize globally
+ mean = xp.mean(image)
- return (image - xp.mean(image)) / xp.std(image)
+ if apc.is_torch_array(image):
+ std = torch.std(image, unbiased=False)
+ else:
+ std = xp.std(image)
+
+ image = (image - mean) / std
+
+ try:
+ image[xp.isnan(image)] = 0
+ except TypeError:
+ pass
+
+ return image
class NormalizeQuantile(Feature):
@@ -614,158 +638,164 @@ def __init__(
def get(
self: NormalizeQuantile,
- image: NDArray[Any] | torch.Tensor | Image,
- quantiles: tuple[float, float] = None,
+ image: np.ndarray | torch.Tensor,
+ quantiles: tuple[float, float],
+ featurewise: bool,
**kwargs: Any,
- ) -> NDArray[Any] | torch.Tensor | Image:
+ ) -> np.ndarray | torch.Tensor:
"""Normalize the input image based on the specified quantiles.
- This method normalizes the input image based on the specified
- quantiles.
-
Parameters
----------
- image: array
+ image: np.ndarray or torch.Tensor
The input image to normalize.
quantiles: tuple[float, float]
Quantile range to calculate scaling factor.
+ featurewise: bool
+ Whether to normalize each feature (channel) independently.
Returns
-------
- array
- The normalized image.
-
+ np.ndarray or torch.Tensor
+ The quantile-normalized image.
"""
- if apc.is_torch_array(image):
- q_tensor = torch.tensor(
- [*quantiles, 0.5],
- device=image.device,
- dtype=image.dtype,
- )
- q_low, q_high, median = torch.quantile(
- image, q_tensor, dim=None, keepdim=False,
- )
- else: # NumPy
- q_low, q_high, median = xp.quantile(image, (*quantiles, 0.5))
+ q_low_val, q_high_val = quantiles
- return (image - median) / (q_high - q_low) * 2.0
+ if featurewise:
+ # Per-feature normalization (last axis)
+ axis = tuple(range(image.ndim - 1))
+ if apc.is_torch_array(image):
+ q = torch.tensor(
+ [q_low_val, q_high_val, 0.5],
+ device=image.device,
+ dtype=image.dtype,
+ )
+                # torch.quantile cannot reduce over a tuple of dims, so
+                # flatten the leading axes and reduce along dim 0.
+                q_low, q_high, median = torch.quantile(
+                    image.reshape(-1, image.shape[-1]), q, dim=0
+                )
+ else:
+ q_low, q_high, median = xp.quantile(
+ image, (q_low_val, q_high_val, 0.5),
+ axis=axis,
+ keepdims=True,
+ )
+ else:
+ # Global normalization
+ if apc.is_torch_array(image):
+ q = torch.tensor(
+ [q_low_val, q_high_val, 0.5],
+ device=image.device,
+ dtype=image.dtype,
+ )
+ q_low, q_high, median = torch.quantile(
+ image, q, dim=None, keepdim=False
+ )
+ else:
+ q_low, q_high, median = xp.quantile(
+ image, (q_low_val, q_high_val, 0.5)
+ )
-#TODO ***JH*** revise Blur - torch, typing, docstring, unit test
-class Blur(Feature):
- """Apply a blurring filter to an image.
-
- This class applies a blurring filter to an image. The filter function
- must be a function that takes an input image and returns a blurred
- image.
-
- Parameters
- ----------
- filter_function: Callable
- The blurring function to apply. This function must accept the input
- image as a keyword argument named `input`. If using OpenCV functions
- (e.g., `cv2.GaussianBlur`), use `BlurCV2` instead.
- mode: str
- Border mode for handling boundaries (e.g., 'reflect').
-
- Methods
- -------
- `get(image: np.ndarray | Image, **kwargs: Any) --> np.ndarray`
- Applies the blurring filter to the input image.
+ image = (image - median) / (q_high - q_low) * 2.0
- Examples
- --------
- >>> import deeptrack as dt
- >>> import numpy as np
- >>> from scipy.ndimage import convolve
+ try:
+ image[xp.isnan(image)] = 0
+ except TypeError:
+ pass
- Create an input image:
- >>> input_image = np.random.rand(32, 32)
+ return image
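+
+        # Sketch of the scaling above with a hypothetical 1D input and
+        # quantiles = (0.25, 0.75): for image = [0, 1, 2, 3, 4], the median
+        # is 2 and q_high - q_low = 3 - 1 = 2, so
+        #
+        #     (image - 2) / 2 * 2.0  ->  [-2, -1, 0, 1, 2]
+        #
+        # i.e. the interquartile range maps to a width of 2, centered on
+        # the median.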
- Define a Gaussian kernel for blurring:
- >>> gaussian_kernel = np.array([
- ... [1, 4, 6, 4, 1],
- ... [4, 16, 24, 16, 4],
- ... [6, 24, 36, 24, 6],
- ... [4, 16, 24, 16, 4],
- ... [1, 4, 6, 4, 1]
- ... ], dtype=float)
- >>> gaussian_kernel /= np.sum(gaussian_kernel)
- Define a blur function using the Gaussian kernel:
- >>> def gaussian_blur(input, **kwargs):
- ... return convolve(input, gaussian_kernel, mode='reflect')
+#TODO ***CM*** revise typing, docstring, unit test
+class Blur(Feature):
+ """Apply a blurring filter to an image.
- Define a blur feature using the Gaussian blur function:
- >>> blur = dt.Blur(filter_function=gaussian_blur)
- >>> output_image = blur(input_image)
- >>> print(output_image.shape)
- (32, 32)
+ This class acts as a backend-dispatching blur operator. Subclasses must
+ implement backend-specific logic via `_get_numpy` and optionally
+ `_get_torch`.
Notes
-----
- Calling this feature returns a `np.ndarray` by default. If
- `store_properties` is set to `True`, the returned array will be
- automatically wrapped in an `Image` object. This behavior is handled
- internally and does not affect the return type of the `get()` method.
- The filter_function must accept the input image as a keyword argument named
- input. This is required because it is called via utils.safe_call. If you
- are using functions that do not support input=... (such as OpenCV filters
- like cv2.GaussianBlur), consider using BlurCV2 instead.
+ - NumPy execution is always supported.
+ - Torch execution is only supported if `_get_torch` is implemented.
+ - Generic `filter_function`-based blurs are NumPy-only by design.
"""
def __init__(
- self: Blur,
- filter_function: Callable,
+ self,
+ filter_function: Callable | None = None,
mode: PropertyLike[str] = "reflect",
**kwargs: Any,
):
- """Initialize the parameters for blurring input features.
-
- This constructor initializes the parameters for blurring input
- features.
+ """Initialize the blur feature.
Parameters
----------
- filter_function: Callable
- The blurring function to apply.
- mode: str
- Border mode for handling boundaries (e.g., 'reflect').
- **kwargs: Any
- Additional keyword arguments.
-
+        filter_function: Callable or None
+            NumPy-based blurring function. Must accept the input image as a
+            keyword argument named `input`. If `None`, the subclass must
+            implement `_get_numpy`.
+        mode: str
+            Border mode for NumPy-based filters.
+        **kwargs: Any
+            Additional keyword arguments passed to `Feature`.
"""
-
self.filter = filter_function
- super().__init__(borderType=mode, **kwargs)
+ self.mode = mode
+ super().__init__(**kwargs)
- def get(self: Blur, image: np.ndarray | Image, **kwargs: Any) -> np.ndarray:
- """Applies the blurring filter to the input image.
+ def __call__(
+ self,
+ image: np.ndarray | torch.Tensor,
+ **kwargs: Any,
+ ) -> np.ndarray | torch.Tensor:
+ if isinstance(image, np.ndarray):
+ return self._get_numpy(image, **kwargs)
- This method applies the blurring filter to the input image.
+ if TORCH_AVAILABLE and isinstance(image, torch.Tensor):
+ return self._get_torch(image, **kwargs)
- Parameters
- ----------
- image: np.ndarray
- The input image to blur.
- **kwargs: dict[str, Any]
- Additional keyword arguments.
+ raise TypeError(
+ "Blur only supports numpy.ndarray or torch.Tensor inputs."
+ )
- Returns
- -------
- np.ndarray
- The blurred image.
+ def _get_numpy(
+ self,
+ image: np.ndarray,
+ **kwargs: Any,
+ ) -> np.ndarray:
+ if self.filter is None:
+ raise NotImplementedError(
+ f"{self.__class__.__name__} does not implement a NumPy backend."
+ )
- """
+ # Avoid passing conflicting keywords
+ kwargs = dict(kwargs)
+ kwargs.pop("input", None)
+
+ return utils.safe_call(
+ self.filter,
+ input=image,
+ mode=self.mode,
+ **kwargs,
+ )
+
+ def _get_torch(
+ self,
+ image: torch.Tensor,
+ **kwargs: Any,
+ ) -> torch.Tensor:
+ raise TypeError(
+ f"{self.__class__.__name__} does not support torch.Tensor inputs. "
+            "Use a Torch-enabled blur such as AverageBlur."
+ )
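+
+
+# Minimal usage sketch for the dispatching Blur above; `my_blur` is an
+# illustrative helper, not part of the API:
+#
+#     import numpy as np
+#     from scipy import ndimage
+#
+#     def my_blur(input, mode="reflect", **kwargs):
+#         return ndimage.uniform_filter(input, size=3, mode=mode)
+#
+#     blur = Blur(filter_function=my_blur)
+#     out = blur(np.random.rand(32, 32))   # shape (32, 32)
+#
+# A torch.Tensor input raises TypeError unless a subclass overrides
+# `_get_torch`.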
- kwargs.pop("input", False)
- return utils.safe_call(self.filter, input=image, **kwargs)
-#TODO ***JH*** revise AverageBlur - torch, typing, docstring, unit test
+#TODO ***CM*** revise AverageBlur - torch, typing, docstring, unit test
class AverageBlur(Blur):
"""Blur an image by computing simple means over neighbourhoods.
@@ -779,7 +809,7 @@ class AverageBlur(Blur):
Methods
-------
- `get(image: np.ndarray | Image, ksize: int, **kwargs: Any) --> np.ndarray`
+ `get(image: np.ndarray | torch.Tensor, ksize: int, **kwargs: Any) --> np.ndarray | torch.Tensor`
Applies the average blurring filter to the input image.
Examples
@@ -796,13 +826,6 @@ class AverageBlur(Blur):
>>> print(output_image.shape)
(32, 32)
- Notes
- -----
- Calling this feature returns a `np.ndarray` by default. If
- `store_properties` is set to `True`, the returned array will be
- automatically wrapped in an `Image` object. This behavior is handled
- internally and does not affect the return type of the `get()` method.
-
"""
def __init__(
@@ -833,800 +856,285 @@ def _kernel_shape(self, shape: tuple[int, ...], ksize: int) -> tuple[int, ...]:
def _get_numpy(
self, input: np.ndarray, ksize: tuple[int, ...], **kwargs: Any
- ) -> np.ndarray:
- return ndimage.uniform_filter(
- input,
- size=ksize,
- mode=kwargs.get("mode", "reflect"),
- cval=kwargs.get("cval", 0),
- origin=kwargs.get("origin", 0),
- axes=tuple(range(0, len(ksize))),
- )
-
- def _get_torch(
- self, input: torch.Tensor, ksize: tuple[int, ...], **kwargs: Any
- ) -> np.ndarray:
- F = xp.nn.functional
-
- last_dim_is_channel = len(ksize) < input.ndim
- if last_dim_is_channel:
- # permute to first dim
- input = input.movedim(-1, 0)
- else:
- input = input.unsqueeze(0)
-
- # add batch dimension
- input = input.unsqueeze(0)
-
- # pad input
- input = F.pad(
- input,
- (ksize[0] // 2, ksize[0] // 2, ksize[1] // 2, ksize[1] // 2),
- mode=kwargs.get("mode", "reflect"),
- value=kwargs.get("cval", 0),
- )
- if input.ndim == 3:
- x = F.avg_pool1d(
- input,
- kernel_size=ksize,
- stride=1,
- padding=0,
- ceil_mode=False,
- count_include_pad=False,
- )
- elif input.ndim == 4:
- x = F.avg_pool2d(
- input,
- kernel_size=ksize,
- stride=1,
- padding=0,
- ceil_mode=False,
- count_include_pad=False,
- )
- elif input.ndim == 5:
- x = F.avg_pool3d(
- input,
- kernel_size=ksize,
- stride=1,
- padding=0,
- ceil_mode=False,
- count_include_pad=False,
- )
- else:
- raise NotImplementedError(
- f"Input dimension {input.ndim - 2} not supported for torch backend"
- )
-
- # restore layout
- x = x.squeeze(0)
- if last_dim_is_channel:
- x = x.movedim(0, -1)
- else:
- x = x.squeeze(0)
-
- return x
-
- def get(
- self: AverageBlur,
- input: ArrayLike,
- ksize: int,
- **kwargs: Any,
- ) -> np.ndarray:
- """Applies the average blurring filter to the input image.
-
- This method applies the average blurring filter to the input image.
-
- Parameters
- ----------
- input: np.ndarray
- The input image to blur.
- ksize: int
- Kernel size for the pooling operation.
- **kwargs: dict[str, Any]
- Additional keyword arguments.
-
- Returns
- -------
- np.ndarray
- The blurred image.
-
- """
-
- k = self._kernel_shape(input.shape, ksize)
-
- if self.backend == "numpy":
- return self._get_numpy(input, k, **kwargs)
- elif self.backend == "torch":
- return self._get_torch(input, k, **kwargs)
- else:
- raise NotImplementedError(f"Backend {self.backend} not supported")
-
-
-#TODO ***JH*** revise GaussianBlur - torch, typing, docstring, unit test
-class GaussianBlur(Blur):
- """Applies a Gaussian blur to images using Gaussian kernels.
-
- This class blurs images by convolving them with a Gaussian filter, which
- smooths the image and reduces high-frequency details. The level of blurring
- is controlled by the standard deviation (`sigma`) of the Gaussian kernel.
-
- Parameters
- ----------
- sigma: float
- Standard deviation of the Gaussian kernel.
-
- Examples
- --------
- >>> import deeptrack as dt
- >>> import numpy as np
- >>> import matplotlib.pyplot as plt
-
- Create an input image:
- >>> input_image = np.random.rand(32, 32)
-
- Define a Gaussian blur feature:
- >>> gaussian_blur = dt.GaussianBlur(sigma=2)
- >>> output_image = gaussian_blur(input_image)
- >>> print(output_image.shape)
- (32, 32)
-
- Visualize the input and output images:
- >>> plt.figure(figsize=(8, 4))
- >>> plt.subplot(1, 2, 1)
- >>> plt.imshow(input_image, cmap='gray')
- >>> plt.subplot(1, 2, 2)
- >>> plt.imshow(output_image, cmap='gray')
- >>> plt.show()
-
- Notes
- -----
- Calling this feature returns a `np.ndarray` by default. If
- `store_properties` is set to `True`, the returned array will be
- automatically wrapped in an `Image` object. This behavior is handled
- internally and does not affect the return type of the `get()` method.
-
- """
-
- def __init__(self: GaussianBlur, sigma: PropertyLike[float] = 2, **kwargs: Any):
- """Initialize the parameters for Gaussian blurring.
-
- This constructor initializes the parameters for Gaussian blurring.
-
- Parameters
- ----------
- sigma: float
- Standard deviation of the Gaussian kernel.
- **kwargs: Any
- Additional keyword arguments.
-
- """
-
- super().__init__(ndimage.gaussian_filter, sigma=sigma, **kwargs)
-
-
-#TODO ***JH*** revise MedianBlur - torch, typing, docstring, unit test
-class MedianBlur(Blur):
- """Applies a median blur.
-
- This class replaces each pixel of the input image with the median value of
- its neighborhood. The `ksize` parameter determines the size of the
- neighborhood used to calculate the median filter. The median filter is
- useful for reducing noise while preserving edges. It is particularly
- effective for removing salt-and-pepper noise from images.
-
- Parameters
- ----------
- ksize: int
- Kernel size.
- **kwargs: dict
- Additional parameters sent to the blurring function.
-
- Examples
- --------
- >>> import deeptrack as dt
- >>> import numpy as np
- >>> import matplotlib.pyplot as plt
-
- Create an input image:
- >>> input_image = np.random.rand(32, 32)
-
- Define a median blur feature:
- >>> median_blur = dt.MedianBlur(ksize=3)
- >>> output_image = median_blur(input_image)
- >>> print(output_image.shape)
- (32, 32)
-
- Visualize the input and output images:
- >>> plt.figure(figsize=(8, 4))
- >>> plt.subplot(1, 2, 1)
- >>> plt.imshow(input_image, cmap='gray')
- >>> plt.subplot(1, 2, 2)
- >>> plt.imshow(output_image, cmap='gray')
- >>> plt.show()
-
- Notes
- -----
- Calling this feature returns a `np.ndarray` by default. If
- `store_properties` is set to `True`, the returned array will be
- automatically wrapped in an `Image` object. This behavior is handled
- internally and does not affect the return type of the `get()` method.
-
- """
-
- def __init__(
- self: MedianBlur,
- ksize: PropertyLike[int] = 3,
- **kwargs: Any,
- ):
- """Initialize the parameters for median blurring.
-
- This constructor initializes the parameters for median blurring.
-
- Parameters
- ----------
- ksize: int
- Kernel size.
- **kwargs: Any
- Additional keyword arguments.
-
- """
-
- super().__init__(ndimage.median_filter, size=ksize, **kwargs)
-
-
-#TODO ***AL*** revise Pool - torch, typing, docstring, unit test
-class Pool(Feature):
- """Downsamples the image by applying a function to local regions of the
- image.
-
- This class reduces the resolution of an image by dividing it into
- non-overlapping blocks of size `ksize` and applying the specified pooling
- function to each block. The result is a downsampled image where each pixel
- value represents the result of the pooling function applied to the
- corresponding block.
-
- Parameters
- ----------
- pooling_function: function
- A function that is applied to each local region of the image.
- DOES NOT NEED TO BE WRAPPED IN ANOTHER FUNCTION.
- The `pooling_function` must accept the input image as a keyword argument
- named `input`, as it is called via `utils.safe_call`.
- Examples include `np.mean`, `np.max`, `np.min`, etc.
- ksize: int
- Size of the pooling kernel.
- **kwargs: Any
- Additional parameters sent to the pooling function.
-
- Methods
- -------
- `get(image: np.ndarray | Image, ksize: int, **kwargs: Any) --> np.ndarray`
- Applies the pooling function to the input image.
-
- Examples
- --------
- >>> import deeptrack as dt
- >>> import numpy as np
-
- Create an input image:
- >>> input_image = np.random.rand(32, 32)
-
- Define a pooling feature:
- >>> pooling_feature = dt.Pool(pooling_function=np.mean, ksize=4)
- >>> output_image = pooling_feature.get(input_image, ksize=4)
- >>> print(output_image.shape)
- (8, 8)
-
- Notes
- -----
- Calling this feature returns a `np.ndarray` by default. If
- `store_properties` is set to `True`, the returned array will be
- automatically wrapped in an `Image` object. This behavior is handled
- internally and does not affect the return type of the `get()` method.
- The filter_function must accept the input image as a keyword argument named
- input. This is required because it is called via utils.safe_call. If you
- are using functions that do not support input=... (such as OpenCV filters
- like cv2.GaussianBlur), consider using BlurCV2 instead.
-
- """
-
- def __init__(
- self: Pool,
- pooling_function: Callable,
- ksize: PropertyLike[int] = 3,
- **kwargs: Any,
- ):
- """Initialize the parameters for pooling input features.
-
- This constructor initializes the parameters for pooling input
- features.
-
- Parameters
- ----------
- pooling_function: Callable
- The pooling function to apply.
- ksize: int
- Size of the pooling kernel.
- **kwargs: Any
- Additional keyword arguments.
-
- """
-
- self.pooling = pooling_function
- super().__init__(ksize=ksize, **kwargs)
-
- def get(
- self: Pool,
- image: np.ndarray | Image,
- ksize: int,
- **kwargs: Any,
- ) -> np.ndarray:
- """Applies the pooling function to the input image.
-
- This method applies the pooling function to the input image.
-
- Parameters
- ----------
- image: np.ndarray
- The input image to pool.
- ksize: int
- Size of the pooling kernel.
- **kwargs: dict[str, Any]
- Additional keyword arguments.
-
- Returns
- -------
- np.ndarray
- The pooled image.
-
- """
-
- kwargs.pop("func", False)
- kwargs.pop("image", False)
- kwargs.pop("block_size", False)
- return utils.safe_call(
- skimage.measure.block_reduce,
- image=image,
- func=self.pooling,
- block_size=ksize,
- **kwargs,
- )
-
-
-#TODO ***AL*** revise AveragePooling - torch, typing, docstring, unit test
-class AveragePooling(Pool):
- """Apply average pooling to an image.
-
- This class reduces the resolution of an image by dividing it into
- non-overlapping blocks of size `ksize` and applying the average function to
- each block. The result is a downsampled image where each pixel value
- represents the average value within the corresponding block of the
- original image.
-
- Parameters
- ----------
- ksize: int
- Size of the pooling kernel.
- **kwargs: dict
- Additional parameters sent to the pooling function.
-
- Examples
- --------
- >>> import deeptrack as dt
- >>> import numpy as np
-
- Create an input image:
- >>> input_image = np.random.rand(32, 32)
-
- Define an average pooling feature:
- >>> average_pooling = dt.AveragePooling(ksize=4)
- >>> output_image = average_pooling(input_image)
- >>> print(output_image.shape)
- (8, 8)
-
- Notes
- -----
- Calling this feature returns a `np.ndarray` by default. If
- `store_properties` is set to `True`, the returned array will be
- automatically wrapped in an `Image` object. This behavior is handled
- internally and does not affect the return type of the `get()` method.
-
- """
-
- def __init__(
- self: Pool,
- ksize: PropertyLike[int] = 3,
- **kwargs: Any,
- ):
- """Initialize the parameters for average pooling.
-
- This constructor initializes the parameters for average pooling.
-
- Parameters
- ----------
- ksize: int
- Size of the pooling kernel.
- **kwargs: Any
- Additional keyword arguments.
-
- """
-
- super().__init__(np.mean, ksize=ksize, **kwargs)
-
-
-class MaxPooling(Pool):
- """Apply max-pooling to images.
-
- `MaxPooling` reduces the resolution of an image by dividing it into
- non-overlapping blocks of size `ksize` and applying the `max` function
- to each block. The result is a downsampled image where each pixel value
- represents the maximum value within the corresponding block of the
- original image. This is useful for reducing the size of an image while
- retaining the most significant features.
-
- If the backend is NumPy, the downsampling is performed using
- `skimage.measure.block_reduce`.
-
- If the backend is PyTorch, the downsampling is performed using
- `torch.nn.functional.max_pool2d`.
-
- Parameters
- ----------
- ksize: int
- Size of the pooling kernel.
- **kwargs: Any
- Additional parameters sent to the pooling function.
-
- Examples
- --------
- >>> import deeptrack as dt
-
- Create an input image:
- >>> import numpy as np
- >>>
- >>> input_image = np.random.rand(32, 32)
-
- Define and use a max-pooling feature:
-
- >>> max_pooling = dt.MaxPooling(ksize=8)
- >>> output_image = max_pooling(input_image)
- >>> output_image.shape
- (4, 4)
-
- """
-
- def __init__(
- self: MaxPooling,
- ksize: PropertyLike[int] = 3,
- **kwargs: Any,
- ):
- """Initialize the parameters for max-pooling.
-
- This constructor initializes the parameters for max-pooling.
-
- Parameters
- ----------
- ksize: int
- Size of the pooling kernel.
- **kwargs: Any
- Additional keyword arguments.
-
- """
-
- super().__init__(np.max, ksize=ksize, **kwargs)
-
- def get(
- self: MaxPooling,
- image: NDArray[Any] | torch.Tensor,
- ksize: int=3,
- **kwargs: Any,
- ) -> NDArray[Any] | torch.Tensor:
- """Max-pooling of input.
-
- Checks the current backend and chooses the appropriate function to pool
- the input image, either `._get_torch()` or `._get_numpy()`.
-
- Parameters
- ----------
- image: array or tensor
- Input array or tensor be pooled.
- ksize: int
- Kernel size of the pooling operation.
-
- Returns
- -------
- array or tensor
- The pooled input as `NDArray` or `torch.Tensor` depending on
- the backend.
-
- """
+ ) -> np.ndarray:
+ return ndimage.uniform_filter(
+ input,
+ size=ksize,
+ mode=kwargs.get("mode", "reflect"),
+ cval=kwargs.get("cval", 0),
+ origin=kwargs.get("origin", 0),
+ axes=tuple(range(0, len(ksize))),
+ )
- if self.get_backend() == "numpy":
- return self._get_numpy(image, ksize, **kwargs)
+    def _get_torch(
+        self, input: torch.Tensor, ksize: tuple[int, ...], **kwargs: Any
+    ) -> torch.Tensor:
+        import torch.nn.functional as F
- if self.get_backend() == "torch":
- return self._get_torch(image, ksize, **kwargs)
+ last_dim_is_channel = len(ksize) < input.ndim
+ if last_dim_is_channel:
+ input = input.movedim(-1, 0)
+ else:
+ input = input.unsqueeze(0)
- raise NotImplementedError(f"Backend {self.backend} not supported")
+ # add batch dimension
+ input = input.unsqueeze(0)
- def _get_numpy(
- self: MaxPooling,
- image: NDArray[Any],
- ksize: int=3,
- **kwargs: Any,
- ) -> NDArray[Any]:
- """Max-pooling pooling with the NumPy backend enabled.
+ # dynamic padding
+ pad = []
+ for k in reversed(ksize):
+ p = k // 2
+ pad.extend([p, p])
+ pad = tuple(pad)
- Returns the result of the input array passed to the scikit image
- `block_reduce()` function with `np.max()` as the pooling function.
+ input = F.pad(
+ input,
+ pad,
+ mode=kwargs.get("mode", "reflect"),
+ value=kwargs.get("cval", 0),
+ )
- Parameters
- ----------
- image: array
- Input array to be pooled.
- ksize: int
- Kernel size of the pooling operation.
+ if input.ndim == 3:
+ x = F.avg_pool1d(input, kernel_size=ksize, stride=1)
+ elif input.ndim == 4:
+ x = F.avg_pool2d(input, kernel_size=ksize, stride=1)
+ elif input.ndim == 5:
+ x = F.avg_pool3d(input, kernel_size=ksize, stride=1)
+ else:
+ raise NotImplementedError(
+ f"Input dimension {input.ndim - 2} not supported for torch backend"
+ )
- Returns
- -------
- array
- The pooled image as a NumPy array.
-
- """
+ # restore layout
+ x = x.squeeze(0)
+ if last_dim_is_channel:
+ x = x.movedim(0, -1)
+ else:
+ x = x.squeeze(0)
- return utils.safe_call(
- skimage.measure.block_reduce,
- image=image,
- func=np.max,
- block_size=ksize,
- **kwargs,
- )
+ return x
- def _get_torch(
- self: MaxPooling,
- image: torch.Tensor,
- ksize: int=3,
+ def get(
+ self: AverageBlur,
+ input: np.ndarray | torch.Tensor,
+ ksize: int,
**kwargs: Any,
- ) -> torch.Tensor:
- """Max-pooling with the PyTorch backend enabled.
-
+ ) -> np.ndarray | torch.Tensor:
+ """Applies the average blurring filter to the input image.
- Returns the result of the tensor passed to a PyTorch max
- pooling layer.
+ This method applies the average blurring filter to the input image.
Parameters
----------
- image: torch.Tensor
- Input tensor to be pooled.
+ input: np.ndarray
+ The input image to blur.
ksize: int
- Kernel size of the pooling operation.
+ Kernel size for the pooling operation.
+ **kwargs: dict[str, Any]
+ Additional keyword arguments.
Returns
-------
- torch.Tensor
- The pooled image as a `torch.Tensor`.
+ np.ndarray
+ The blurred image.
"""
- # If input tensor is 2D
- if len(image.shape) == 2:
- # Add batch dimension for max-pooling
- expanded_image = image.unsqueeze(0)
-
- pooled_image = torch.nn.functional.max_pool2d(
- expanded_image, kernel_size=ksize,
- )
- # Remove the expanded dim
- return pooled_image.squeeze(0)
-
- return torch.nn.functional.max_pool2d(
- image,
- kernel_size=ksize,
- )
-
+ k = self._kernel_shape(input.shape, ksize)
-class MinPooling(Pool):
- """Apply min-pooling to images.
+ if self.backend == "numpy":
+ return self._get_numpy(input, k, **kwargs)
+ elif self.backend == "torch":
+ return self._get_torch(input, k, **kwargs)
+ else:
+ raise NotImplementedError(f"Backend {self.backend} not supported")
- `MinPooling` reduces the resolution of an image by dividing it into
- non-overlapping blocks of size `ksize` and applying the `min` function to
- each block. The result is a downsampled image where each pixel value
- represents the minimum value within the corresponding block of the original
- image.
- If the backend is NumPy, the downsampling is performed using
- `skimage.measure.block_reduce`.
+#TODO ***CM*** revise typing, docstring, unit test
+class GaussianBlur(Blur):
+ """Applies a Gaussian blur to images using Gaussian kernels.
- If the backend is PyTorch, the downsampling is performed using the inverse
- of `torch.nn.functional.max_pool2d` by changing the sign of the input.
+ This class blurs images by convolving them with a Gaussian filter, which
+ smooths the image and reduces high-frequency details. The level of blurring
+ is controlled by the standard deviation (`sigma`) of the Gaussian kernel.
Parameters
----------
- ksize: int
- Size of the pooling kernel.
- **kwargs: Any
- Additional parameters sent to the pooling function.
+ sigma: float
+ Standard deviation of the Gaussian kernel.
Examples
--------
>>> import deeptrack as dt
+ >>> import numpy as np
+ >>> import matplotlib.pyplot as plt
Create an input image:
- >>> import numpy as np
- >>>
>>> input_image = np.random.rand(32, 32)
- Define and use a min-pooling feature:
- >>> min_pooling = dt.MinPooling(ksize=4)
- >>> output_image = min_pooling(input_image)
- >>> output_image.shape
- (8, 8)
+ Define a Gaussian blur feature:
+ >>> gaussian_blur = dt.GaussianBlur(sigma=2)
+ >>> output_image = gaussian_blur(input_image)
+ >>> print(output_image.shape)
+ (32, 32)
+
+ Visualize the input and output images:
+ >>> plt.figure(figsize=(8, 4))
+ >>> plt.subplot(1, 2, 1)
+ >>> plt.imshow(input_image, cmap='gray')
+ >>> plt.subplot(1, 2, 2)
+ >>> plt.imshow(output_image, cmap='gray')
+ >>> plt.show()
"""
- def __init__(
- self: MinPooling,
- ksize: PropertyLike[int] = 3,
- **kwargs: Any,
- ):
- """Initialize the parameters for min-pooling.
+ def __init__(self: GaussianBlur, sigma: PropertyLike[float] = 2, **kwargs: Any):
+ """Initialize the parameters for Gaussian blurring.
- This constructor initializes the parameters for min-pooling and checks
- whether to use the NumPy or PyTorch implementation, defaults to NumPy.
+ This constructor initializes the parameters for Gaussian blurring.
Parameters
----------
- ksize: int
- Size of the pooling kernel.
+ sigma: float
+ Standard deviation of the Gaussian kernel.
**kwargs: Any
Additional keyword arguments.
"""
- super().__init__(np.min, ksize=ksize, **kwargs)
-
- def get(
- self: MinPooling,
- image: NDArray[Any] | torch.Tensor,
- ksize: int=3,
- **kwargs: Any,
- ) -> NDArray[Any] | torch.Tensor:
- """Min pooling of input.
-
- Checks the current backend and chooses the appropriate function to pool
- the input image, either `._get_torch()` or `._get_numpy()`.
-
- Parameters
- ----------
- image: array or tensor
- Input array or tensor to be pooled.
- ksize: int
- Kernel size of the pooling operation.
-
- Returns
- -------
- array or tensor
- The pooled image as `NDArray` or `torch.Tensor` depending on the
- backend.
-
- """
-
- if self.get_backend() == "numpy":
- return self._get_numpy(image, ksize, **kwargs)
-
- if self.get_backend() == "torch":
- return self._get_torch(image, ksize, **kwargs)
-
- raise NotImplementedError(f"Backend {self.backend} not supported")
+ self.sigma = float(sigma)
+ super().__init__(None, **kwargs)
def _get_numpy(
- self: MinPooling,
- image: NDArray[Any],
- ksize: int=3,
+ self,
+ input: np.ndarray,
**kwargs: Any,
- ) -> NDArray[Any]:
- """Min-pooling with the NumPy backend.
-
- Returns the result of the input array passed to the scikit
- `image block_reduce()` function with `np.min()` as the pooling
- function.
-
- Parameters
- ----------
- image: NDArray
- Input image to be pooled.
- ksize: int
- Kernel size of the pooling operation.
-
- Returns
- -------
- NDArray
- The pooled image as a `NDArray`.
-
- """
+ ) -> np.ndarray:
+ return ndimage.gaussian_filter(
+ input,
+ sigma=self.sigma,
+ mode=kwargs.get("mode", "reflect"),
+ cval=kwargs.get("cval", 0),
+ )
- return utils.safe_call(
- skimage.measure.block_reduce,
- image=image,
- func=np.min,
- block_size=ksize,
- **kwargs,
+ def _gaussian_kernel_1d(
+ self,
+ sigma: float,
+ device,
+ dtype,
+ ) -> torch.Tensor:
+ radius = int(np.ceil(3 * sigma))
+ x = torch.arange(
+ -radius, radius + 1,
+ device=device,
+ dtype=dtype,
)
+ kernel = torch.exp(-(x ** 2) / (2 * sigma ** 2))
+ kernel /= kernel.sum()
+ return kernel
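
A quick standalone sanity check of the kernel construction (a sketch, not part of the patch): truncating at radius ceil(3*sigma) captures essentially all of the Gaussian mass, and renormalizing makes the truncated kernel sum to one.

    import numpy as np

    sigma = 2.0
    radius = int(np.ceil(3 * sigma))             # same truncation rule as above
    x = np.arange(-radius, radius + 1)
    kernel = np.exp(-(x ** 2) / (2 * sigma ** 2))
    kernel /= kernel.sum()                       # renormalize the truncated tail

    assert kernel.shape == (2 * radius + 1,)
    assert np.isclose(kernel.sum(), 1.0)
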
def _get_torch(
- self: MinPooling,
- image: torch.Tensor,
- ksize: int=3,
+ self,
+ input: torch.Tensor,
**kwargs: Any,
) -> torch.Tensor:
- """Min-pooling with the PyTorch backend.
+ import torch.nn.functional as F
- As PyTorch does not have a min-pooling layer, the equivalent operation
- is to first multiply the input tensor with `-1`, then perform
- max-pooling, and finally multiply the max pooled tensor with `-1`.
+ sigma = self.sigma
+ kernel_1d = self._gaussian_kernel_1d(
+ sigma,
+ device=input.device,
+ dtype=input.dtype,
+ )
- Parameters
- ----------
- image: torch.Tensor
- Input tensor to be pooled.
- ksize: int
- Kernel size of the pooling operation.
+ last_dim_is_channel = input.ndim >= 3
+ if last_dim_is_channel:
+ input = input.movedim(-1, 0) # C, ...
+ else:
+ input = input.unsqueeze(0) # 1, ...
- Returns
- -------
- torch.Tensor
- The pooled image as a `torch.Tensor`.
+ # add batch dimension
+ input = input.unsqueeze(0) # 1, C, ...
- """
+ spatial_dims = input.ndim - 2
+ C = input.shape[1]
+
+ for d in range(spatial_dims):
+ k = kernel_1d
+ shape = [1] * spatial_dims
+ shape[d] = -1
+ k = k.view(1, 1, *shape)
+ k = k.repeat(C, 1, *([1] * spatial_dims))
- # If input tensor is 2D
- if len(image.shape) == 2:
- # Add batch dimension for min-pooling
- expanded_image = image.unsqueeze(0)
+ pad = [0, 0] * spatial_dims
+ radius = k.shape[2 + d] // 2
+ pad[-(2 * d + 2)] = radius
+ pad[-(2 * d + 1)] = radius
+ pad = tuple(pad)
- pooled_image = - torch.nn.functional.max_pool2d(
- expanded_image * (-1),
- kernel_size=ksize,
+ input = F.pad(
+ input,
+ pad,
+ mode=kwargs.get("mode", "reflect"),
)
- # Remove the expanded dim
- return pooled_image.squeeze(0)
+ if spatial_dims == 1:
+ input = F.conv1d(input, k, groups=C)
+ elif spatial_dims == 2:
+ input = F.conv2d(input, k, groups=C)
+ elif spatial_dims == 3:
+ input = F.conv3d(input, k, groups=C)
+ else:
+ raise NotImplementedError(
+ f"{spatial_dims}D Gaussian blur not supported"
+ )
+
+ # restore layout
+ input = input.squeeze(0)
+ if last_dim_is_channel:
+ input = input.movedim(0, -1)
+ else:
+ input = input.squeeze(0)
- return -torch.nn.functional.max_pool2d(
- image * (-1),
- kernel_size=ksize,
- )
+ return input
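
The per-dimension loop above relies on the separability of the Gaussian: an n-D Gaussian blur equals n successive 1-D convolutions. A minimal check of this property, using SciPy rather than the torch path for brevity:

    import numpy as np
    from scipy import ndimage

    image = np.random.rand(32, 32)
    full = ndimage.gaussian_filter(image, sigma=2)
    separable = ndimage.gaussian_filter1d(
        ndimage.gaussian_filter1d(image, sigma=2, axis=0),
        sigma=2,
        axis=1,
    )
    assert np.allclose(full, separable)
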
-#TODO ***AL*** revise MedianPooling - torch, typing, docstring, unit test
-class MedianPooling(Pool):
- """Apply median pooling to images.
+#TODO ***JH*** revise MedianBlur - torch, typing, docstring, unit test
+class MedianBlur(Blur):
+ """Applies a median blur.
+
+ This class replaces each pixel of the input image with the median value of
+ its neighborhood. The `ksize` parameter determines the size of the
+ neighborhood used to calculate the median filter. The median filter is
+ useful for reducing noise while preserving edges. It is particularly
+ effective for removing salt-and-pepper noise from images.
- This class reduces the resolution of an image by dividing it into
- non-overlapping blocks of size `ksize` and applying the median function to
- each block. The result is a downsampled image where each pixel value
- represents the median value within the corresponding block of the
- original image. This is useful for reducing the size of an image while
- retaining the most significant features.
+ - NumPy backend: `scipy.ndimage.median_filter`
+ - Torch backend: explicit unfolding followed by `torch.median`
Parameters
----------
ksize: int
- Size of the pooling kernel.
- **kwargs: Any
- Additional parameters sent to the pooling function.
+ Kernel size.
+ **kwargs: dict
+ Additional parameters sent to the blurring function.
+
+ Notes
+ -----
+ Torch median blurring is significantly more expensive than mean or
+ Gaussian blurring due to explicit tensor unfolding.
Examples
--------
>>> import deeptrack as dt
>>> import numpy as np
+ >>> import matplotlib.pyplot as plt
Create an input image:
>>> input_image = np.random.rand(32, 32)
- Define a median pooling feature:
- >>> median_pooling = dt.MedianPooling(ksize=3)
- >>> output_image = median_pooling(input_image)
+ Define a median blur feature:
+ >>> median_blur = dt.MedianBlur(ksize=3)
+ >>> output_image = median_blur(input_image)
>>> print(output_image.shape)
(32, 32)
@@ -1638,37 +1146,99 @@ class MedianPooling(Pool):
>>> plt.imshow(output_image, cmap='gray')
>>> plt.show()
- Notes
- -----
- Calling this feature returns a `np.ndarray` by default. If
- `store_properties` is set to `True`, the returned array will be
- automatically wrapped in an `Image` object. This behavior is handled
- internally and does not affect the return type of the `get()` method.
-
"""
def __init__(
- self: MedianPooling,
+ self: MedianBlur,
ksize: PropertyLike[int] = 3,
**kwargs: Any,
):
- """Initialize the parameters for median pooling.
+ self.ksize = int(ksize)
+ super().__init__(None, **kwargs)
+
+ def _get_numpy(
+ self,
+ input: np.ndarray,
+ **kwargs: Any,
+ ) -> np.ndarray:
+ return ndimage.median_filter(
+ input,
+ size=self.ksize,
+ mode=kwargs.get("mode", "reflect"),
+ cval=kwargs.get("cval", 0),
+ )
- This constructor initializes the parameters for median pooling.
+ def _get_torch(
+ self,
+ input: torch.Tensor,
+ **kwargs: Any,
+ ) -> torch.Tensor:
+ import torch.nn.functional as F
- Parameters
- ----------
- ksize: int
- Size of the pooling kernel.
- **kwargs: Any
- Additional keyword arguments.
+ k = self.ksize
+ if k % 2 == 0:
+ raise ValueError("MedianBlur requires an odd kernel size.")
- """
+ last_dim_is_channel = input.ndim >= 3
+ if last_dim_is_channel:
+ input = input.movedim(-1, 0) # C, ...
+ else:
+ input = input.unsqueeze(0) # 1, ...
+
+ # add batch dimension
+ input = input.unsqueeze(0) # 1, C, ...
- super().__init__(np.median, ksize=ksize, **kwargs)
+ spatial_dims = input.ndim - 2
+ pad = k // 2
+ # padding
+ pad_tuple = []
+ for _ in range(spatial_dims):
+ pad_tuple.extend([pad, pad])
+ pad_tuple = tuple(reversed(pad_tuple))
-class PoolV2:
+ input = F.pad(
+ input,
+ pad_tuple,
+ mode=kwargs.get("mode", "reflect"),
+ )
+
+ # unfold spatial dimensions
+ if spatial_dims == 1:
+ x = input.unfold(2, k, 1)
+ elif spatial_dims == 2:
+ x = (
+ input
+ .unfold(2, k, 1)
+ .unfold(3, k, 1)
+ )
+ elif spatial_dims == 3:
+ x = (
+ input
+ .unfold(2, k, 1)
+ .unfold(3, k, 1)
+ .unfold(4, k, 1)
+ )
+ else:
+ raise NotImplementedError(
+ f"{spatial_dims}D median blur not supported"
+ )
+
+ # flatten neighborhood and take median
+ x = x.contiguous().view(*x.shape[:-spatial_dims], -1)
+ x = x.median(dim=-1).values
+
+ # restore layout
+ x = x.squeeze(0)
+ if last_dim_is_channel:
+ x = x.movedim(0, -1)
+ else:
+ x = x.squeeze(0)
+
+ return x
+
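
For reference, the unfold-then-median trick can be exercised in isolation. A minimal sketch for a single-channel 2-D image with an odd kernel, assuming only torch:

    import torch
    import torch.nn.functional as F

    k = 3
    image = torch.rand(16, 16)
    x = F.pad(image[None, None], (k // 2,) * 4, mode="reflect")
    x = x.unfold(2, k, 1).unfold(3, k, 1)        # (1, 1, H, W, k, k)
    x = x.contiguous().view(1, 1, 16, 16, -1)    # flatten the neighborhood
    median = x.median(dim=-1).values[0, 0]       # back to (H, W)
    assert median.shape == (16, 16)
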
+#TODO ***CM*** revise typing, docstring, unit test
+class Pool:
"""
DeepTrack v2 replacement for Pool.
@@ -1850,31 +1420,31 @@ def __call__(self, array):
return self._pool_torch(array)
raise TypeError(
- "PoolV2 only supports np.ndarray or torch.Tensor inputs."
+ "Pool only supports np.ndarray or torch.Tensor inputs."
)
-class AveragePoolingV2(PoolV2):
+class AveragePooling(Pool):
def __init__(self, ksize: int = 2):
super().__init__(np.mean, ksize)
-class SumPoolingV2(PoolV2):
+class SumPooling(Pool):
def __init__(self, ksize: int = 2):
super().__init__(np.sum, ksize)
-class MinPoolingV2(PoolV2):
+class MinPooling(Pool):
def __init__(self, ksize: int = 2):
super().__init__(np.min, ksize)
-class MaxPoolingV2(PoolV2):
+class MaxPooling(Pool):
def __init__(self, ksize: int = 2):
super().__init__(np.max, ksize)
-class MedianPoolingV2(PoolV2):
+class MedianPooling(Pool):
def __init__(self, ksize: int = 2):
super().__init__(np.median, ksize)
@@ -1955,10 +1525,10 @@ def __init__(
def get(
self: Resize,
- image: NDArray | torch.Tensor,
+ image: np.ndarray | torch.Tensor,
dsize: tuple[int, int],
**kwargs: Any,
- ) -> NDArray | torch.Tensor:
+ ) -> np.ndarray | torch.Tensor:
"""Resize the input image to the specified size.
Parameters
@@ -1991,9 +1561,6 @@ def get(
"""
- if self._wrap_array_with_image:
- image = strip(image)
-
if apc.is_torch_array(image):
original_shape = image.shape
@@ -2080,13 +1647,6 @@ class BlurCV2(Feature):
>>> print(output_image.shape)
(32, 32)
- Notes
- -----
- Calling this feature returns a `np.ndarray` by default. If
- `store_properties` is set to `True`, the returned array will be
- automatically wrapped in an `Image` object. This behavior is handled
- internally and does not affect the return type of the `get()` method.
-
"""
def __new__(
@@ -2230,13 +1790,6 @@ class BilateralBlur(BlurCV2):
>>> print(output_image.shape)
(32, 32)
- Notes
- -----
- Calling this feature returns a `np.ndarray` by default. If
- `store_properties` is set to `True`, the returned array will be
- automatically wrapped in an `Image` object. This behavior is handled
- internally and does not affect the return type of the `get()` method.
-
"""
def __init__(
diff --git a/deeptrack/optics.py b/deeptrack/optics.py
index bb3cf40f..eeb2fca1 100644
--- a/deeptrack/optics.py
+++ b/deeptrack/optics.py
@@ -151,7 +151,7 @@ def _pad_volume(
get_active_scale,
get_active_voxel_size,
)
-from deeptrack.math import AveragePoolingV2, SumPoolingV2
+from deeptrack.math import AveragePooling, SumPooling
from deeptrack.features import propagate_data_to_dependencies
from deeptrack.features import DummyFeature, Feature, StructuralFeature
from deeptrack.image import pad_image_to_fft
@@ -1840,6 +1840,22 @@ def extract_contrast_volume(
return (value ** 2) * scattered.array
+ def downscale_image(self, image: np.ndarray, upscale):
+        """Downscale the image via energy-conserving detector integration."""
+ if not np.any(np.array(upscale) != 1):
+ return image
+
+ ux, uy = upscale[:2]
+ if ux != uy:
+ raise ValueError(
+ f"Energy-conserving detector integration requires ux == uy, "
+ f"got ux={ux}, uy={uy}."
+ )
+ if isinstance(ux, float) and ux.is_integer():
+ ux = int(ux)
+
+ # Energy-conserving detector integration
+ return SumPoolingV2(ux)(image)
#Retrieve get as super
def get(
@@ -3007,7 +3023,7 @@ def _process_and_get(
)
return output
-
+
#TODO ***??*** revise _get_position - torch, typing, docstring, unit test
def _get_position(
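
The choice of pooling reduction in `downscale_image` above is the substantive point of this hunk: sum pooling conserves the total signal (photon counts on a detector), whereas average pooling conserves the mean. A standalone sketch, with `skimage.measure.block_reduce` standing in for the pooling features:

    import numpy as np
    import skimage.measure

    upscale = 2
    fine = np.random.rand(64, 64)     # field rendered at 2x detector resolution

    # Energy-conserving detector integration: the total signal is preserved.
    coarse_sum = skimage.measure.block_reduce(fine, (upscale, upscale), np.sum)
    assert np.isclose(coarse_sum.sum(), fine.sum())

    # Average pooling preserves the mean intensity instead.
    coarse_mean = skimage.measure.block_reduce(fine, (upscale, upscale), np.mean)
    assert np.isclose(coarse_mean.mean(), fine.mean())
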
From bee5de01eee6e12d14a1bd87c9f762a4cb61ee24 Mon Sep 17 00:00:00 2001
From: Carlo
Date: Fri, 16 Jan 2026 11:22:49 +0100
Subject: [PATCH 23/24] u
---
deeptrack/math.py | 203 +++++++++++++++++++++++++++-------------
deeptrack/optics.py | 6 +-
deeptrack/scatterers.py | 14 +--
3 files changed, 147 insertions(+), 76 deletions(-)
diff --git a/deeptrack/math.py b/deeptrack/math.py
index 5845fb90..2f8b7339 100644
--- a/deeptrack/math.py
+++ b/deeptrack/math.py
@@ -1453,54 +1453,67 @@ def __init__(self, ksize: int = 2):
class Resize(Feature):
"""Resize an image to a specified size.
- `Resize` resizes an image using:
- - OpenCV (`cv2.resize`) for NumPy arrays.
- - PyTorch (`torch.nn.functional.interpolate`) for PyTorch tensors.
+ `Resize` resizes images following the channels-last semantic
+ convention.
- The interpretation of the `dsize` parameter follows the convention
- of the underlying backend:
- - **NumPy (OpenCV)**: `dsize` is given as `(width, height)` to match
- OpenCV’s default.
- - **PyTorch**: `dsize` is given as `(height, width)`.
+ The operation supports both NumPy arrays and PyTorch tensors:
+ - NumPy arrays are resized using OpenCV (`cv2.resize`).
+ - PyTorch tensors are resized using `torch.nn.functional.interpolate`.
+
+ In all cases, the input is interpreted as having spatial dimensions
+ first and an optional channel dimension last.
Parameters
----------
- dsize: PropertyLike[tuple[int, int]]
- The target size. Format depends on backend: `(width, height)` for
- NumPy, `(height, width)` for PyTorch.
- **kwargs: Any
- Additional parameters sent to the underlying resize function:
- - NumPy: passed to `cv2.resize`.
- - PyTorch: passed to `torch.nn.functional.interpolate`.
+ dsize : PropertyLike[tuple[int, int]]
+ Target output size given as (width, height). This convention is
+ backend-independent and applies equally to NumPy and PyTorch inputs.
+
+ **kwargs : Any
+ Additional keyword arguments forwarded to the underlying resize
+ implementation:
+ - NumPy backend: passed to `cv2.resize`.
+ - PyTorch backend: passed to
+ `torch.nn.functional.interpolate`.
Methods
-------
get(
- image: np.ndarray | torch.Tensor, dsize: tuple[int, int], **kwargs
+ image: np.ndarray | torch.Tensor,
+ dsize: tuple[int, int],
+ **kwargs
) -> np.ndarray | torch.Tensor
Resize the input image to the specified size.
Examples
--------
- >>> import deeptrack as dt
+    NumPy example:
+
+    >>> import deeptrack as dt
- Numpy example:
>>> import numpy as np
- >>>
- >>> input_image = np.random.rand(16, 16) # Create image
- >>> feature = dt.math.Resize(dsize=(8, 4)) # (width=8, height=4)
- >>> resized_image = feature.resolve(input_image) # Resize it to (4, 8)
- >>> print(resized_image.shape)
+ >>> input_image = np.random.rand(16, 16)
+ >>> feature = dt.math.Resize(dsize=(8, 4)) # (width=8, height=4)
+ >>> resized_image = feature.resolve(input_image)
+ >>> resized_image.shape
(4, 8)
PyTorch example:
+
>>> import torch
- >>>
- >>> input_image = torch.rand(1, 1, 16, 16) # Create image
- >>> feature = dt.math.Resize(dsize=(4, 8)) # (height=4, width=8)
- >>> resized_image = feature.resolve(input_image) # Resize it to (4, 8)
- >>> print(resized_image.shape)
- torch.Size([1, 1, 4, 8])
+ >>> input_image = torch.rand(16, 16) # channels-last
+ >>> feature = dt.math.Resize(dsize=(8, 4))
+ >>> resized_image = feature.resolve(input_image)
+ >>> resized_image.shape
+ torch.Size([4, 8])
+
+ Notes
+ -----
+ - Resize follows channels-last semantics, consistent with other features
+ such as Pool and Blur.
+ - Torch tensors with channels-first layout (e.g. (C, H, W) or
+ (N, C, H, W)) are not supported and must be converted to
+ channels-last format before resizing.
+ - For PyTorch tensors, bilinear interpolation is used with
+ `align_corners=False`, closely matching OpenCV’s default behavior.
"""
@@ -1533,67 +1546,109 @@ def get(
Parameters
----------
- image: np.ndarray or torch.Tensor
- The input image to resize.
- - NumPy arrays may be grayscale (H, W) or color (H, W, C).
- - Torch tensors are expected in one of the following formats:
- (N, C, H, W), (C, H, W), or (H, W).
- dsize: tuple[int, int]
- Desired output size of the image.
- - NumPy: (width, height)
- - PyTorch: (height, width)
- **kwargs: Any
- Additional keyword arguments passed to the underlying resize
- function (`cv2.resize` or `torch.nn.functional.interpolate`).
+ image : np.ndarray or torch.Tensor
+ Input image following channels-last semantics.
+
+ Supported shapes are:
+ - (H, W)
+ - (H, W, C)
+ - (Z, H, W)
+ - (Z, H, W, C)
+
+ For PyTorch tensors, channels-first layouts such as (C, H, W) or
+ (N, C, H, W) are not supported and must be converted to
+ channels-last format before calling `Resize`.
+
+ dsize : tuple[int, int]
+ Desired output size given as (width, height). This convention is
+ backend-independent and applies to both NumPy and PyTorch inputs.
+
+ **kwargs : Any
+ Additional keyword arguments passed to the underlying resize
+ implementation:
+ - NumPy backend: forwarded to `cv2.resize`.
+ - PyTorch backend: forwarded to `torch.nn.functional.interpolate`.
Returns
-------
np.ndarray or torch.Tensor
- The resized image in the same type and dimensionality format as
- input.
+ The resized image, with the same type and dimensionality layout as
+ the input image.
Notes
-----
+ - Resize follows the same channels-last semantic convention as other
+ features in `deeptrack.math`.
- For PyTorch tensors, resizing uses bilinear interpolation with
- `align_corners=False`. This choice matches OpenCV’s `cv2.resize`
- default behavior when resizing NumPy arrays, aiming to produce nearly
- identical results between both backends.
+ `align_corners=False`, which closely matches OpenCV’s default behavior.
"""
+ target_w, target_h = dsize
+
+ # Torch backend
if apc.is_torch_array(image):
- original_shape = image.shape
-
- # Reshape input to (N, C, H, W)
- if image.ndim == 2: # (H, W)
- image = image.unsqueeze(0).unsqueeze(0)
- elif image.ndim == 3: # (C, H, W)
- image = image.unsqueeze(0)
- elif image.ndim != 4:
+ import torch.nn.functional as F
+
+ original_ndim = image.ndim
+ has_channels = (
+ image.ndim >= 3 and image.shape[-1] <= 4
+ )
+
+ # Bring to (N, C, H, W)
+ if image.ndim == 2:
+ # (H, W) -> (1, 1, H, W)
+ x = image.unsqueeze(0).unsqueeze(0)
+
+ elif image.ndim == 3 and has_channels:
+ # (H, W, C) -> (1, C, H, W)
+ x = image.permute(2, 0, 1).unsqueeze(0)
+
+ elif image.ndim == 3:
+ # (Z, H, W) -> treat Z as batch
+ x = image.unsqueeze(1)
+
+ elif image.ndim == 4 and has_channels:
+ # (Z, H, W, C) -> (Z, C, H, W)
+ x = image.permute(0, 3, 1, 2)
+
+ else:
raise ValueError(
- "Resize only supports tensors with shape (N, C, H, W), "
- "(C, H, W), or (H, W)."
+ f"Unsupported tensor shape {image.shape} for Resize."
)
- resized = torch.nn.functional.interpolate(
- image,
- size=dsize,
+ # Resize spatial dimensions
+ resized = F.interpolate(
+ x,
+ size=(target_h, target_w),
mode="bilinear",
align_corners=False,
)
- # Restore original dimensionality
- if len(original_shape) == 2:
- resized = resized.squeeze(0).squeeze(0)
- elif len(original_shape) == 3:
- resized = resized.squeeze(0)
+ # Restore original layout
+ if original_ndim == 2:
+ return resized.squeeze(0).squeeze(0)
+
+ if original_ndim == 3 and has_channels:
+ return resized.squeeze(0).permute(1, 2, 0)
+
+ if original_ndim == 3:
+ return resized.squeeze(1)
+
+ if original_ndim == 4:
+ return resized.permute(0, 2, 3, 1)
- return resized
+ raise RuntimeError("Unexpected shape restoration path.")
+ # NumPy / OpenCV backend
else:
import cv2
+
+ # OpenCV expects (width, height)
return utils.safe_call(
- cv2.resize, positional_args=[image, dsize], **kwargs
+ cv2.resize,
+ positional_args=[image, (target_w, target_h)],
+ **kwargs,
)
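
The `(width, height)` convention for `dsize` is the usual OpenCV pitfall that this rewrite standardizes on: the output array shape is reversed relative to the argument. A minimal illustration with plain OpenCV:

    import numpy as np
    import cv2

    image = np.random.rand(16, 16).astype(np.float32)
    resized = cv2.resize(image, (8, 4))    # dsize is (width=8, height=4)
    assert resized.shape == (4, 8)         # array shape is (height, width)
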
@@ -1647,6 +1702,12 @@ class BlurCV2(Feature):
>>> print(output_image.shape)
(32, 32)
+ Notes
+ -----
+ BlurCV2 is NumPy-only and does not support PyTorch tensors.
+ This class is intended for OpenCV-specific filters that are
+ not available in the backend-agnostic math layer.
+
"""
def __new__(
@@ -1741,6 +1802,12 @@ def get(
"""
+ if apc.is_torch_array(image):
+ raise TypeError(
+ "BlurCV2 only supports NumPy arrays. "
+ "For Torch tensors, use Blur or GaussianBlur instead."
+ )
+
kwargs.pop("name", None)
result = self.filter(src=image, **kwargs)
return result
@@ -1790,6 +1857,10 @@ class BilateralBlur(BlurCV2):
>>> print(output_image.shape)
(32, 32)
+ Notes
+ -----
+ BilateralBlur is NumPy-only and does not support PyTorch tensors.
+
"""
def __init__(
diff --git a/deeptrack/optics.py b/deeptrack/optics.py
index eeb2fca1..a7394689 100644
--- a/deeptrack/optics.py
+++ b/deeptrack/optics.py
@@ -288,7 +288,7 @@ def _downscale_image(self, image, upscale):
)
if isinstance(ux, float) and ux.is_integer():
ux = int(ux)
- return AveragePoolingV2(ux)(image)
+ return AveragePooling(ux)(image)
def get(
self: Microscope,
@@ -1130,7 +1130,7 @@ def downscale_image(self, image: np.ndarray, upscale):
ux = int(ux)
# Energy-conserving detector integration
- return SumPoolingV2(ux)(image)
+ return SumPooling(ux)(image)
def get(
@@ -1855,7 +1855,7 @@ def downscale_image(self, image: np.ndarray, upscale):
ux = int(ux)
# Energy-conserving detector integration
- return SumPoolingV2(ux)(image)
+ return SumPooling(ux)(image)
#Retrieve get as super
def get(
diff --git a/deeptrack/scatterers.py b/deeptrack/scatterers.py
index 2ba202e0..b7e1b70a 100644
--- a/deeptrack/scatterers.py
+++ b/deeptrack/scatterers.py
@@ -175,7 +175,7 @@
get_active_voxel_size,
)
from deeptrack.backend import mie
-from deeptrack.math import AveragePoolingV2
+from deeptrack.math import AveragePooling
from deeptrack.features import Feature, MERGE_STRATEGY_APPEND
from deeptrack.image import pad_image_to_fft
from deeptrack.types import ArrayLike
@@ -261,11 +261,11 @@ def __init__(
**kwargs,
) -> None:
# Ignore warning to help with comparison with arrays.
- if upsample != 1: # noqa: F632
- warnings.warn(
- f"Setting upsample != 1 is deprecated. "
- f"Please, instead use dt.Upscale(f, factor={upsample})"
- )
+ # if upsample != 1: # noqa: F632
+ # warnings.warn(
+ # f"Setting upsample != 1 is deprecated. "
+ # f"Please, instead use dt.Upscale(f, factor={upsample})"
+ # )
self._processed_properties = False
@@ -291,7 +291,7 @@ def _antialias_volume(self, volume, factor: int):
return volume
# average pooling conserves fractional occupancy
- return AveragePoolingV2(
+ return AveragePooling(
factor
)(volume)
From f18884cacc74d12823902fee48796c58d91d5f4d Mon Sep 17 00:00:00 2001
From: Carlo
Date: Sun, 18 Jan 2026 04:13:54 +0100
Subject: [PATCH 24/24] implemented BackendDispatched
---
deeptrack/features.py | 32 +
deeptrack/math.py | 1389 ++++++++++++++++++++++++++---------------
deeptrack/optics.py | 33 +-
3 files changed, 919 insertions(+), 535 deletions(-)
diff --git a/deeptrack/features.py b/deeptrack/features.py
index 702e7362..9572ea82 100644
--- a/deeptrack/features.py
+++ b/deeptrack/features.py
@@ -52,6 +52,11 @@
hierarchical or logical structures in the pipeline without input
transformations.
+- `BackendDispatched`: Mixin class for backend-specific implementations.
+
+ Provides mechanisms for dispatching feature methods based on the
+ computational backend (e.g., NumPy, PyTorch).
+
- `ArithmeticOperationFeature`: Apply arithmetic operation element-wise.
A parent class for features performing arithmetic operations like addition,
@@ -180,6 +185,7 @@
__all__ = [
"Feature",
"StructuralFeature",
+ "BackendDispatched",
"Chain",
"Branch",
"DummyFeature",
@@ -3947,6 +3953,32 @@ class StructuralFeature(Feature):
__distributed__: bool = False # Process the entire image list in one call
+class BackendDispatched:
+ """Mixin for Feature.get() methods with backend-specific implementations."""
+
+ _NUMPY_IMPL: str | None = None
+ _TORCH_IMPL: str | None = None
+
+ def _dispatch_backend(self, *args, **kwargs):
+ backend = self.get_backend()
+
+ if backend == "numpy":
+ if self._NUMPY_IMPL is None:
+ raise NotImplementedError(
+ f"{self.__class__.__name__} does not support NumPy backend."
+ )
+ return getattr(self, self._NUMPY_IMPL)(*args, **kwargs)
+
+ if backend == "torch":
+ if self._TORCH_IMPL is None:
+ raise NotImplementedError(
+ f"{self.__class__.__name__} does not support Torch backend."
+ )
+ return getattr(self, self._TORCH_IMPL)(*args, **kwargs)
+
+ raise RuntimeError(f"Unknown backend {backend}")
+
+
class Chain(StructuralFeature):
"""Resolve two features sequentially.
diff --git a/deeptrack/math.py b/deeptrack/math.py
index 2f8b7339..40a8adc7 100644
--- a/deeptrack/math.py
+++ b/deeptrack/math.py
@@ -59,6 +59,8 @@
- `MinPooling`: Apply min-pooling to the image.
+- `SumPooling`: Apply sum pooling to the image.
+
- `MedianPooling`: Apply median pooling to the image.
- `Resize`: Resize the image to a specified size.
@@ -97,14 +99,12 @@
import array_api_compat as apc
import numpy as np
-from numpy.typing import NDArray #TODO TBE
from scipy import ndimage
import skimage
import skimage.measure
from deeptrack import utils, OPENCV_AVAILABLE, TORCH_AVAILABLE
-from deeptrack.features import Feature
-from deeptrack.image import Image, strip #TODO TBE
+from deeptrack.features import Feature, BackendDispatched
from deeptrack.types import PropertyLike
from deeptrack.backend import xp
@@ -129,7 +129,9 @@
"AveragePooling",
"MaxPooling",
"MinPooling",
+ "SumPooling",
"MedianPooling",
+ "Resize",
"BlurCV2",
"BilateralBlur",
]
@@ -297,8 +299,8 @@ class Clip(Feature):
def __init__(
self: Clip,
- min: PropertyLike[float] = -np.inf,
- max: PropertyLike[float] = +np.inf,
+ min: PropertyLike[float] = -xp.inf,
+ max: PropertyLike[float] = +xp.inf,
**kwargs: Any,
):
"""Initialize the clipping range.
@@ -306,9 +308,9 @@ def __init__(
Parameters
----------
min: float, optional
- Minimum allowed value. It defaults to `-np.inf`.
+ Minimum allowed value. It defaults to `-xp.inf`.
max: float, optional
- Maximum allowed value. It defaults to `+np.inf`.
+ Maximum allowed value. It defaults to `+xp.inf`.
**kwargs: Any
Additional keyword arguments.
@@ -441,31 +443,31 @@ def get(
"""
- if featurewise:
- # Normalize per feature (last axis)
- axis = tuple(range(image.ndim - 1))
+ has_channels = image.ndim >= 3 and image.shape[-1] <= 4
+ if featurewise and has_channels:
+ # reduce over spatial dimensions only
+ axis = tuple(range(image.ndim - 1))
img_min = xp.min(image, axis=axis, keepdims=True)
img_max = xp.max(image, axis=axis, keepdims=True)
else:
- # Normalize globally
+ # global normalization
img_min = xp.min(image)
img_max = xp.max(image)
ptp = img_max - img_min
+ eps = xp.asarray(1e-8, dtype=image.dtype)
+ ptp = xp.maximum(ptp, eps)
- # Avoid division by zero
- image = (image - img_min) / ptp * (max - min) + min
-
- try:
- image[xp.isnan(image)] = 0
- except TypeError:
- pass
+ image = (image - img_min) / ptp
+ image = image * (max - min) + min
+ image = xp.where(xp.isnan(image), xp.zeros_like(image), image)
return image
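
Clamping the peak-to-peak range replaces the old NaN-masking `try/except`: for a constant image the division would otherwise be 0/0. A minimal NumPy check of the guarded rescaling:

    import numpy as np

    image = np.full((8, 8), 3.0)           # constant image, so ptp == 0
    ptp = np.maximum(image.max() - image.min(), 1e-8)
    out = (image - image.min()) / ptp      # all zeros, no NaNs
    assert np.isfinite(out).all()
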
-class NormalizeStandard(Feature):
+
+class NormalizeStandard(BackendDispatched, Feature):
"""Image normalization using standardization.
Standardizes the input image to have zero mean and unit standard
@@ -499,6 +501,9 @@ class NormalizeStandard(Feature):
"""
+ _NUMPY_IMPL = "_get_numpy"
+ _TORCH_IMPL = "_get_torch"
+
def __init__(
self: NormalizeStandard,
featurewise: PropertyLike[bool] = True,
@@ -539,37 +544,63 @@ def get(
np.ndarray or torch.Tensor
The standardized image.
"""
+ return self._dispatch_backend(
+ image,
+ featurewise=featurewise,
+ )
- if featurewise:
- # Normalize per feature (last axis)
- axis = tuple(range(image.ndim - 1))
+ # ------ NumPy backend ------
+
+ def _get_numpy(
+ self,
+ image: np.ndarray,
+ featurewise: bool,
+ ) -> np.ndarray:
- mean = xp.mean(image, axis=axis, keepdims=True)
+ has_channels = image.ndim >= 3 and image.shape[-1] <= 4
- if apc.is_torch_array(image):
- std = torch.std(image, dim=axis, keepdim=True, unbiased=False)
- else:
- std = xp.std(image, axis=axis)
+ if featurewise and has_channels:
+ axis = tuple(range(image.ndim - 1))
+ mean = np.mean(image, axis=axis, keepdims=True)
+ std = np.std(image, axis=axis, keepdims=True) # population std
else:
- # Normalize globally
- mean = xp.mean(image)
+ mean = np.mean(image)
+ std = np.std(image)
- if apc.is_torch_array(image):
- std = torch.std(image, unbiased=False)
- else:
- std = xp.std(image)
+ std = np.maximum(std, 1e-8)
- image = (image - mean) / std
+ out = (image - mean) / std
+ out = np.where(np.isnan(out), 0.0, out)
- try:
- image[xp.isnan(image)] = 0
- except TypeError:
- pass
+ return out
- return image
+ # ------ Torch backend ------
+
+ def _get_torch(
+ self,
+ image: torch.Tensor,
+ featurewise: bool,
+ ) -> torch.Tensor:
+
+ has_channels = image.ndim >= 3 and image.shape[-1] <= 4
+
+ if featurewise and has_channels:
+ axis = tuple(range(image.ndim - 1))
+ mean = image.mean(dim=axis, keepdim=True)
+ std = image.std(dim=axis, keepdim=True, unbiased=False)
+ else:
+ mean = image.mean()
+ std = image.std(unbiased=False)
+ std = torch.clamp(std, min=1e-8)
-class NormalizeQuantile(Feature):
+ out = (image - mean) / std
+ out = torch.nan_to_num(out, nan=0.0)
+
+ return out
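
`unbiased=False` matters here: torch defaults to the sample (Bessel-corrected) standard deviation, while NumPy defaults to the population one, so the flag keeps the two backends numerically aligned. A quick check:

    import numpy as np
    import torch

    x = np.random.rand(100).astype(np.float32)
    t = torch.from_numpy(x)
    assert np.isclose(np.std(x), t.std(unbiased=False).item(), atol=1e-6)
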
+
+
+class NormalizeQuantile(BackendDispatched, Feature):
"""Image normalization using quantiles.
Centers the image at the median and scales it such that the values at the
@@ -589,6 +620,12 @@ class NormalizeQuantile(Feature):
get(image: array, quantiles: tuple[float, float], **kwargs) -> array
Normalizes the input based on the given quantile range.
+ Notes
+ -----
+ This operation is not differentiable. When used inside a gradient-based
+ model, it will block gradient flow. Use with care if end-to-end
+ differentiability is required.
+
Examples
--------
>>> import deeptrack as dt
@@ -607,7 +644,8 @@ class NormalizeQuantile(Feature):
"""
- #TODO ___??___ Implement the `featurewise=False` option
+ _NUMPY_IMPL = "_get_numpy"
+ _TORCH_IMPL = "_get_torch"
def __init__(
self: NormalizeQuantile,
@@ -637,12 +675,25 @@ def __init__(
)
def get(
- self: NormalizeQuantile,
+ self,
image: np.ndarray | torch.Tensor,
quantiles: tuple[float, float],
featurewise: bool,
**kwargs: Any,
- ) -> np.ndarray | torch.Tensor:
+ ):
+ return self._dispatch_backend(
+ image,
+ quantiles=quantiles,
+ featurewise=featurewise,
+ )
+
+ def _get_numpy(
+ self: NormalizeQuantile,
+ image: np.ndarray,
+ quantiles: tuple[float, float],
+ featurewise: bool,
+ **kwargs: Any,
+ ) -> np.ndarray:
"""Normalize the input image based on the specified quantiles.
Parameters
@@ -658,140 +709,127 @@ def get(
-------
np.ndarray or torch.Tensor
The quantile-normalized image.
+
"""
q_low_val, q_high_val = quantiles
- if featurewise:
- # Per-feature normalization (last axis)
+ has_channels = image.ndim >= 3 and image.shape[-1] <= 4
+
+ if featurewise and has_channels:
axis = tuple(range(image.ndim - 1))
+ q_low, q_high, median = np.quantile(
+ image,
+ (q_low_val, q_high_val, 0.5),
+ axis=axis,
+ keepdims=True,
+ )
+ else:
+ q_low, q_high, median = np.quantile(
+ image,
+ (q_low_val, q_high_val, 0.5),
+ )
+
+ scale = q_high - q_low
+ eps = np.asarray(1e-8, dtype=image.dtype)
+ scale = np.maximum(scale, eps)
+
+ image = (image - median) / scale
+ image = np.where(np.isnan(image), np.zeros_like(image), image)
+ return image
+
+ def _get_torch(
+ self,
+ image: torch.Tensor,
+ quantiles: tuple[float, float],
+ featurewise: bool,
+ ):
+ q_low_val, q_high_val = quantiles
- if apc.is_torch_array(image):
+ if featurewise:
+ if image.ndim < 3:
+ # No channels → global quantile
q = torch.tensor(
[q_low_val, q_high_val, 0.5],
device=image.device,
dtype=image.dtype,
)
- q_low, q_high, median = torch.quantile(
- image, q, dim=axis, keepdim=True
- )
+ q_low, q_high, median = torch.quantile(image, q)
else:
- q_low, q_high, median = xp.quantile(
- image, (q_low_val, q_high_val, 0.5),
- axis=axis,
- keepdims=True,
- )
- else:
- # Global normalization
- if apc.is_torch_array(image):
+ # channels-last: (..., C)
+ spatial_dims = image.ndim - 1
+ C = image.shape[-1]
+
+ # flatten spatial dims
+ x = image.reshape(-1, C) # (N, C)
+
q = torch.tensor(
[q_low_val, q_high_val, 0.5],
device=image.device,
dtype=image.dtype,
)
- q_low, q_high, median = torch.quantile(
- image, q, dim=None, keepdim=False
- )
- else:
- q_low, q_high, median = xp.quantile(
- image, (q_low_val, q_high_val, 0.5)
- )
- image = (image - median) / (q_high - q_low) * 2.0
+ q_vals = torch.quantile(x, q, dim=0)
+ q_low, q_high, median = q_vals
- try:
- image[xp.isnan(image)] = 0
- except TypeError:
- pass
+ # reshape for broadcasting
+ shape = [1] * image.ndim
+ shape[-1] = C
+ q_low = q_low.view(shape)
+ q_high = q_high.view(shape)
+ median = median.view(shape)
- return image
+ else:
+ q = torch.tensor(
+ [q_low_val, q_high_val, 0.5],
+ device=image.device,
+ dtype=image.dtype,
+ )
+ q_low, q_high, median = torch.quantile(image, q)
+ scale = q_high - q_low
+ scale = torch.clamp(scale, min=1e-8)
+ image = (image - median) / scale
+ image = torch.nan_to_num(image)
-#TODO ***CM*** revise typing, docstring, unit test
-class Blur(Feature):
- """Apply a blurring filter to an image.
+ return image
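
The flatten-then-quantile approach above computes one quantile triple per channel over all spatial positions. A standalone sketch of the same idea for an (H, W, C) tensor:

    import torch

    image = torch.rand(32, 32, 3)                          # channels-last
    x = image.reshape(-1, image.shape[-1])                 # (N, C)
    q = torch.tensor([0.1, 0.9, 0.5], dtype=image.dtype)
    q_low, q_high, median = torch.quantile(x, q, dim=0)    # each of shape (C,)

    out = (image - median) / torch.clamp(q_high - q_low, min=1e-8)
    assert out.shape == image.shape
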
- This class acts as a backend-dispatching blur operator. Subclasses must
- implement backend-specific logic via `_get_numpy` and optionally
- `_get_torch`.
- Notes
- -----
- - NumPy execution is always supported.
- - Torch execution is only supported if `_get_torch` is implemented.
- - Generic `filter_function`-based blurs are NumPy-only by design.
+#TODO ***CM*** revise typing, docstring, unit test
+class Blur(BackendDispatched, Feature):
+ """Abstract blur feature with backend-dispatched implementations.
+
+ This class serves as a base for blur features that support multiple
+ backends (e.g., NumPy, Torch). Subclasses should implement backend-specific
+ blurring logic via `_get_numpy` and/or `_get_torch` methods.
+
+ Methods
+ -------
+ get(image: np.ndarray | torch.Tensor, **kwargs) -> np.ndarray | torch.Tensor
+ Applies the appropriate backend-specific blurring method.
+
+    _dispatch_backend(image: array, **kwargs) -> array
+        Inherited from `BackendDispatched`; routes to the backend-specific
+        blur implementation (`_get_numpy` or `_get_torch`).
"""
- def __init__(
- self,
- filter_function: Callable | None = None,
- mode: PropertyLike[str] = "reflect",
- **kwargs: Any,
- ):
- """Initialize the blur feature.
-
- Parameters
- ----------
- filter_function : Callable or None
- NumPy-based blurring function. Must accept the input image as a
- keyword argument named `input`. If `None`, the subclass must
- implement `_get_numpy`.
- mode : str
- Border mode for NumPy-based filters.
- **kwargs : Any
- Additional keyword arguments passed to Feature.
- """
- self.filter = filter_function
- self.mode = mode
- super().__init__(**kwargs)
+ _NUMPY_IMPL = "_get_numpy"
+ _TORCH_IMPL = "_get_torch"
- def __call__(
+ def get(
self,
image: np.ndarray | torch.Tensor,
- **kwargs: Any,
- ) -> np.ndarray | torch.Tensor:
- if isinstance(image, np.ndarray):
- return self._get_numpy(image, **kwargs)
-
- if TORCH_AVAILABLE and isinstance(image, torch.Tensor):
- return self._get_torch(image, **kwargs)
-
- raise TypeError(
- "Blur only supports numpy.ndarray or torch.Tensor inputs."
- )
-
- def _get_numpy(
- self,
- image: np.ndarray,
- **kwargs: Any,
- ) -> np.ndarray:
- if self.filter is None:
- raise NotImplementedError(
- f"{self.__class__.__name__} does not implement a NumPy backend."
- )
-
- # Avoid passing conflicting keywords
- kwargs = dict(kwargs)
- kwargs.pop("input", None)
+ **kwargs,
+ ):
+ return self._dispatch_backend(image, **kwargs)
- return utils.safe_call(
- self.filter,
- input=image,
- mode=self.mode,
- **kwargs,
- )
+ def _get_numpy(self, image: np.ndarray, **kwargs):
+ raise NotImplementedError
- def _get_torch(
- self,
- image: torch.Tensor,
- **kwargs: Any,
- ) -> torch.Tensor:
- raise TypeError(
- f"{self.__class__.__name__} does not support torch.Tensor inputs. "
- "Use a Torch-enabled blur (e.g. AverageBlur or a V2 blur class)."
- )
+ def _get_torch(self, image: torch.Tensor, **kwargs):
+ raise NotImplementedError
@@ -829,10 +867,10 @@ class AverageBlur(Blur):
"""
def __init__(
- self: AverageBlur,
- ksize: PropertyLike[int] = 3,
- **kwargs: Any,
- ):
+ self: AverageBlur,
+ ksize: int = 3,
+ **kwargs: Any
+ ) -> None:
"""Initialize the parameters for averaging input features.
This constructor initializes the parameters for averaging input
@@ -847,106 +885,119 @@ def __init__(
"""
- super().__init__(None, ksize=ksize, **kwargs)
+ self.ksize = int(ksize)
+ super().__init__(**kwargs)
- def _kernel_shape(self, shape: tuple[int, ...], ksize: int) -> tuple[int, ...]:
+ @staticmethod
+ def _kernel_shape(shape: tuple[int, ...], ksize: int) -> tuple[int, ...]:
+ # If last dim is channel and smaller than kernel, do not blur channels
if shape[-1] < ksize:
return (ksize,) * (len(shape) - 1) + (1,)
return (ksize,) * len(shape)
+ # ---------- NumPy backend ----------
def _get_numpy(
- self, input: np.ndarray, ksize: tuple[int, ...], **kwargs: Any
+ self: AverageBlur,
+ image: np.ndarray,
+ **kwargs: Any
) -> np.ndarray:
+ """Apply average blurring using SciPy's uniform_filter.
+
+ This method applies average blurring to the input image using
+ SciPy's `uniform_filter`.
+
+ Parameters
+ ----------
+ image: np.ndarray
+ The input image to blur.
+ **kwargs: dict[str, Any]
+ Additional keyword arguments for `uniform_filter`.
+
+ Returns
+ -------
+ np.ndarray
+ The blurred image.
+
+ """
+
+ k = self._kernel_shape(image.shape, self.ksize)
return ndimage.uniform_filter(
- input,
- size=ksize,
+ image,
+ size=k,
mode=kwargs.get("mode", "reflect"),
cval=kwargs.get("cval", 0),
origin=kwargs.get("origin", 0),
- axes=tuple(range(0, len(ksize))),
+ axes=tuple(range(len(k))),
)
+ # ---------- Torch backend ----------
def _get_torch(
- self, input: torch.Tensor, ksize: tuple[int, ...], **kwargs: Any
+ self: AverageBlur,
+ image: torch.Tensor,
+ **kwargs: Any
) -> torch.Tensor:
+ """Apply average blurring using PyTorch's avg_pool.
- last_dim_is_channel = len(ksize) < input.ndim
+ This method applies average blurring to the input image using
+ PyTorch's `avg_pool` functions.
+
+ Parameters
+ ----------
+ image: torch.Tensor
+ The input image to blur.
+ **kwargs: dict[str, Any]
+ Additional keyword arguments for padding.
+
+ Returns
+ -------
+ torch.Tensor
+ The blurred image.
+
+ """
+
+        import torch.nn.functional as F
+
+        k = self._kernel_shape(tuple(image.shape), self.ksize)
+
+ last_dim_is_channel = len(k) < image.ndim
if last_dim_is_channel:
- input = input.movedim(-1, 0)
+ image = image.movedim(-1, 0) # C, ...
else:
- input = input.unsqueeze(0)
+ image = image.unsqueeze(0) # 1, ...
# add batch dimension
- input = input.unsqueeze(0)
+ image = image.unsqueeze(0) # 1, C, ...
- # dynamic padding
+ # symmetric padding
pad = []
- for k in reversed(ksize):
- p = k // 2
+ for kk in reversed(k):
+ p = kk // 2
pad.extend([p, p])
- pad = tuple(pad)
-
- input = F.pad(
- input,
- pad,
+ image = F.pad(
+ image,
+ tuple(pad),
mode=kwargs.get("mode", "reflect"),
value=kwargs.get("cval", 0),
)
- if input.ndim == 3:
- x = F.avg_pool1d(input, kernel_size=ksize, stride=1)
- elif input.ndim == 4:
- x = F.avg_pool2d(input, kernel_size=ksize, stride=1)
- elif input.ndim == 5:
- x = F.avg_pool3d(input, kernel_size=ksize, stride=1)
+ # pooling by dimensionality
+ if image.ndim == 3:
+ out = F.avg_pool1d(image, kernel_size=k, stride=1)
+ elif image.ndim == 4:
+ out = F.avg_pool2d(image, kernel_size=k, stride=1)
+ elif image.ndim == 5:
+ out = F.avg_pool3d(image, kernel_size=k, stride=1)
else:
raise NotImplementedError(
- f"Input dimension {input.ndim - 2} not supported for torch backend"
+ f"Input dimensionality {image.ndim - 2} not supported"
)
# restore layout
- x = x.squeeze(0)
+ out = out.squeeze(0)
if last_dim_is_channel:
- x = x.movedim(0, -1)
+ out = out.movedim(0, -1)
else:
- x = x.squeeze(0)
-
- return x
-
- def get(
- self: AverageBlur,
- input: np.ndarray | torch.Tensor,
- ksize: int,
- **kwargs: Any,
- ) -> np.ndarray | torch.Tensor:
- """Applies the average blurring filter to the input image.
-
- This method applies the average blurring filter to the input image.
-
- Parameters
- ----------
- input: np.ndarray
- The input image to blur.
- ksize: int
- Kernel size for the pooling operation.
- **kwargs: dict[str, Any]
- Additional keyword arguments.
+ out = out.squeeze(0)
- Returns
- -------
- np.ndarray
- The blurred image.
-
- """
-
- k = self._kernel_shape(input.shape, ksize)
-
- if self.backend == "numpy":
- return self._get_numpy(input, k, **kwargs)
- elif self.backend == "torch":
- return self._get_torch(input, k, **kwargs)
- else:
- raise NotImplementedError(f"Backend {self.backend} not supported")
+ return out
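
A cross-backend sanity check of this approach (a sketch, not a test from the patch): stride-1 `avg_pool2d` on a padded tensor reproduces `scipy.ndimage.uniform_filter`. Note that torch's "reflect" padding corresponds to SciPy's "mirror" mode, so the comparison uses that:

    import numpy as np
    import torch
    import torch.nn.functional as F
    from scipy import ndimage

    k = 3
    image = np.random.rand(16, 16).astype(np.float32)
    ref = ndimage.uniform_filter(image, size=k, mode="mirror")

    t = torch.from_numpy(image)[None, None]           # (1, 1, H, W)
    t = F.pad(t, (k // 2,) * 4, mode="reflect")
    out = F.avg_pool2d(t, kernel_size=k, stride=1)[0, 0].numpy()
    assert np.allclose(ref, out, atol=1e-5)
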
#TODO ***CM*** revise typing, docstring, unit test
@@ -1004,27 +1055,32 @@ def __init__(self: GaussianBlur, sigma: PropertyLike[float] = 2, **kwargs: Any):
self.sigma = float(sigma)
super().__init__(None, **kwargs)
+ # ---------- NumPy backend ----------
+
def _get_numpy(
self,
- input: np.ndarray,
+ image: np.ndarray,
**kwargs: Any,
) -> np.ndarray:
return ndimage.gaussian_filter(
- input,
+ image,
sigma=self.sigma,
mode=kwargs.get("mode", "reflect"),
cval=kwargs.get("cval", 0),
)
+ # ---------- Torch backend ----------
+
+ @staticmethod
def _gaussian_kernel_1d(
- self,
sigma: float,
device,
dtype,
) -> torch.Tensor:
radius = int(np.ceil(3 * sigma))
x = torch.arange(
- -radius, radius + 1,
+ -radius,
+ radius + 1,
device=device,
dtype=dtype,
)
@@ -1034,29 +1090,29 @@ def _gaussian_kernel_1d(
def _get_torch(
self,
- input: torch.Tensor,
+ image: torch.Tensor,
**kwargs: Any,
) -> torch.Tensor:
import torch.nn.functional as F
- sigma = self.sigma
kernel_1d = self._gaussian_kernel_1d(
- sigma,
- device=input.device,
- dtype=input.dtype,
+ self.sigma,
+ device=image.device,
+ dtype=image.dtype,
)
- last_dim_is_channel = input.ndim >= 3
+ # channel-last handling
+ last_dim_is_channel = image.ndim >= 3
if last_dim_is_channel:
- input = input.movedim(-1, 0) # C, ...
+ image = image.movedim(-1, 0) # C, ...
else:
- input = input.unsqueeze(0) # 1, ...
+ image = image.unsqueeze(0) # 1, ...
# add batch dimension
- input = input.unsqueeze(0) # 1, C, ...
+ image = image.unsqueeze(0) # 1, C, ...
- spatial_dims = input.ndim - 2
- C = input.shape[1]
+ spatial_dims = image.ndim - 2
+ C = image.shape[1]
for d in range(spatial_dims):
k = kernel_1d
@@ -1071,31 +1127,31 @@ def _get_torch(
pad[-(2 * d + 1)] = radius
pad = tuple(pad)
- input = F.pad(
- input,
+ image = F.pad(
+ image,
pad,
mode=kwargs.get("mode", "reflect"),
)
if spatial_dims == 1:
- input = F.conv1d(input, k, groups=C)
+ image = F.conv1d(image, k, groups=C)
elif spatial_dims == 2:
- input = F.conv2d(input, k, groups=C)
+ image = F.conv2d(image, k, groups=C)
elif spatial_dims == 3:
- input = F.conv3d(input, k, groups=C)
+ image = F.conv3d(image, k, groups=C)
else:
raise NotImplementedError(
f"{spatial_dims}D Gaussian blur not supported"
)
# restore layout
- input = input.squeeze(0)
+ image = image.squeeze(0)
if last_dim_is_channel:
- input = input.movedim(0, -1)
+ image = image.movedim(0, -1)
else:
- input = input.squeeze(0)
+ image = image.squeeze(0)
- return input
+ return image
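+
+    # A standalone sketch of the truncated, normalized 1D kernel that
+    # `_gaussian_kernel_1d` is assumed to produce (radius = ceil(3 * sigma)
+    # matches the code above; the normalization shown is an assumption):
+    #
+    #     >>> import torch
+    #     >>> sigma = 2.0
+    #     >>> radius = int(torch.ceil(torch.tensor(3 * sigma)))
+    #     >>> x = torch.arange(-radius, radius + 1, dtype=torch.float32)
+    #     >>> kernel = torch.exp(-0.5 * (x / sigma) ** 2)
+    #     >>> kernel = kernel / kernel.sum()  # unit mass
+    #     >>> bool(torch.isclose(kernel.sum(), torch.tensor(1.0)))
+    #     True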
#TODO ***JH*** revise MedianBlur - torch, typing, docstring, unit test
@@ -1123,6 +1179,10 @@ class MedianBlur(Blur):
Torch median blurring is significantly more expensive than mean or
Gaussian blurring due to explicit tensor unfolding.
+ Median blur is not differentiable. This is typically acceptable, as the
+ operation is intended for denoising and preprocessing rather than as a
+ trainable network layer.
+
Examples
--------
>>> import deeptrack as dt
@@ -1156,21 +1216,25 @@ def __init__(
self.ksize = int(ksize)
super().__init__(None, **kwargs)
+ # ---------- NumPy backend ----------
+
def _get_numpy(
self,
- input: np.ndarray,
+ image: np.ndarray,
**kwargs: Any,
) -> np.ndarray:
return ndimage.median_filter(
- input,
+ image,
size=self.ksize,
mode=kwargs.get("mode", "reflect"),
cval=kwargs.get("cval", 0),
)
+ # ---------- Torch backend ----------
+
def _get_torch(
self,
- input: torch.Tensor,
+ image: torch.Tensor,
**kwargs: Any,
) -> torch.Tensor:
import torch.nn.functional as F
@@ -1179,42 +1243,36 @@ def _get_torch(
if k % 2 == 0:
raise ValueError("MedianBlur requires an odd kernel size.")
- last_dim_is_channel = input.ndim >= 3
+ last_dim_is_channel = image.ndim >= 3
if last_dim_is_channel:
- input = input.movedim(-1, 0) # C, ...
+ image = image.movedim(-1, 0) # C, ...
else:
- input = input.unsqueeze(0) # 1, ...
+ image = image.unsqueeze(0) # 1, ...
# add batch dimension
- input = input.unsqueeze(0) # 1, C, ...
+ image = image.unsqueeze(0) # 1, C, ...
- spatial_dims = input.ndim - 2
+ spatial_dims = image.ndim - 2
pad = k // 2
- # padding
pad_tuple = []
for _ in range(spatial_dims):
pad_tuple.extend([pad, pad])
pad_tuple = tuple(reversed(pad_tuple))
- input = F.pad(
- input,
+ image = F.pad(
+ image,
pad_tuple,
mode=kwargs.get("mode", "reflect"),
)
- # unfold spatial dimensions
if spatial_dims == 1:
- x = input.unfold(2, k, 1)
+ x = image.unfold(2, k, 1)
elif spatial_dims == 2:
- x = (
- input
- .unfold(2, k, 1)
- .unfold(3, k, 1)
- )
+ x = image.unfold(2, k, 1).unfold(3, k, 1)
elif spatial_dims == 3:
x = (
- input
+ image
.unfold(2, k, 1)
.unfold(3, k, 1)
.unfold(4, k, 1)
@@ -1224,11 +1282,9 @@ def _get_torch(
f"{spatial_dims}D median blur not supported"
)
- # flatten neighborhood and take median
x = x.contiguous().view(*x.shape[:-spatial_dims], -1)
x = x.median(dim=-1).values
- # restore layout
x = x.squeeze(0)
if last_dim_is_channel:
x = x.movedim(0, -1)
@@ -1238,219 +1294,479 @@ def _get_torch(
return x
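+
+    # The unfold-and-median mechanics in isolation for the 2D case (an
+    # unpadded 8x8 input, so the output shrinks by k - 1 per axis):
+    #
+    #     >>> import torch
+    #     >>> k = 3
+    #     >>> x = torch.rand(1, 1, 8, 8)                   # (B, C, H, W)
+    #     >>> patches = x.unfold(2, k, 1).unfold(3, k, 1)  # (1, 1, 6, 6, 3, 3)
+    #     >>> flat = patches.contiguous().view(1, 1, 6, 6, -1)
+    #     >>> flat.median(dim=-1).values.shape
+    #     torch.Size([1, 1, 6, 6])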
#TODO ***CM*** revise typing, docstring, unit test
-class Pool:
- """
- DeepTrack v2 replacement for Pool.
-
- Generic, center-preserving block pooling with NumPy and Torch backends.
- Public API matches v1: a single integer ksize.
-
- Pool size semantics:
- - 2D input -> (ksize, ksize, 1)
- - 3D input -> (ksize, ksize, ksize)
- """
-
- _TORCH_REDUCERS_2D: Dict[Callable, Callable] = {
- np.mean: lambda x, k, s: F.avg_pool2d(x, k, s),
- np.sum: lambda x, k, s: F.avg_pool2d(x, k, s) * (k[0] * k[1]),
- np.max: lambda x, k, s: F.max_pool2d(x, k, s),
- np.min: lambda x, k, s: -F.max_pool2d(-x, k, s),
- }
+class Pool(BackendDispatched, Feature):
+ """Abstract base class for pooling features."""
- _TORCH_REDUCERS_3D: Dict[Callable, Callable] = {
- np.mean: lambda x, k, s: F.avg_pool3d(x, k, s),
- np.sum: lambda x, k, s: F.avg_pool3d(x, k, s) * (k[0] * k[1] * k[2]),
- np.max: lambda x, k, s: F.max_pool3d(x, k, s),
- np.min: lambda x, k, s: -F.max_pool3d(-x, k, s),
- }
+ _NUMPY_IMPL = "_get_numpy"
+ _TORCH_IMPL = "_get_torch"
def __init__(
self,
- pooling_function: Callable,
- ksize: int = 2,
+ ksize: PropertyLike[int] = 2,
+ **kwargs: Any,
):
- if pooling_function not in (
- np.mean, np.sum, np.min, np.max, np.median
- ):
- raise ValueError(
- "Unsupported pooling_function. "
- "Use one of: np.mean, np.sum, np.min, np.max, np.median."
- )
-
- if not isinstance(ksize, int) or ksize < 1:
- raise ValueError("ksize must be a positive integer.")
-
- self.pooling_function = pooling_function
self.ksize = int(ksize)
+ super().__init__(**kwargs)
- def _get_pool_size(self, array) -> Tuple[int, int, int]:
- """
- Determine pooling kernel size based on semantic dimensionality.
+ def get(
+ self,
+ image: np.ndarray | torch.Tensor,
+ **kwargs: Any,
+ ) -> np.ndarray | torch.Tensor:
+ return self._dispatch_backend(image, **kwargs)
- - 2D images: (Nx, Ny) or (Nx, Ny, C) -> pool in x,y only
- - 3D volumes: (Nx, Ny, Nz) or (Nx, Ny, Nz, C) -> pool in x,y,z
- - Never pool over channels
- """
+ # ---------- shared helpers ----------
+
+ def _get_pool_size(self, array) -> tuple[int, int, int]:
k = self.ksize
- # 2D image
if array.ndim == 2:
return k, k, 1
- # 3D array: could be (x, y, z) or (x, y, c)
if array.ndim == 3:
- # Heuristic: small last dim → channels
- if array.shape[-1] <= 4:
+ if array.shape[-1] <= 4: # channel heuristic
return k, k, 1
return k, k, k
- # 4D array: (x, y, z, c)
if array.ndim == 4:
return k, k, k
- raise ValueError(
- f"Unsupported array shape {array.shape} for pooling."
- )
+ raise ValueError(f"Unsupported array shape {array.shape}")
def _crop_center(self, array):
px, py, pz = self._get_pool_size(array)
- # 2D (or effectively 2D)
+ # 2D or effectively 2D (channels-last)
if array.ndim < 3 or pz == 1:
H, W = array.shape[:2]
crop_h = (H // px) * px
crop_w = (W // py) * py
- off_h = (H - crop_h) // 2
- off_w = (W - crop_w) // 2
- return array[
- off_h : off_h + crop_h,
- off_w : off_w + crop_w,
- ...
- ]
-
- # 3D
+            off_h = (H - crop_h) // 2
+            off_w = (W - crop_w) // 2
+            return array[off_h:off_h + crop_h, off_w:off_w + crop_w, ...]
+
+ # 3D volume
Z, H, W = array.shape[:3]
crop_z = (Z // pz) * pz
crop_h = (H // px) * px
crop_w = (W // py) * py
- off_z = (Z - crop_z) // 2
- off_h = (H - crop_h) // 2
- off_w = (W - crop_w) // 2
- return array[
- off_z : off_z + crop_z,
- off_h : off_h + crop_h,
- off_w : off_w + crop_w,
- ...
- ]
-
- def _pool_numpy(self, array: np.ndarray) -> np.ndarray:
- array = self._crop_center(array)
- px, py, pz = self._get_pool_size(array)
+        off_z = (Z - crop_z) // 2
+        off_h = (H - crop_h) // 2
+        off_w = (W - crop_w) // 2
+        return array[
+            off_z:off_z + crop_z,
+            off_h:off_h + crop_h,
+            off_w:off_w + crop_w,
+            ...
+        ]
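+
+    # Crop sketch: a 5x5 image with ksize=2 keeps a 4x4 block so the
+    # pooling grid divides evenly (example input is hypothetical):
+    #
+    #     >>> import numpy as np
+    #     >>> pool = AveragePooling(ksize=2)
+    #     >>> pool._crop_center(np.arange(25).reshape(5, 5)).shape
+    #     (4, 4)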
- if array.ndim < 3 or pz == 1:
- pool_shape = (px, py) + (1,) * (array.ndim - 2)
+ # ---------- abstract backends ----------
+
+ def _get_numpy(self, image: np.ndarray, **kwargs):
+ raise NotImplementedError
+
+ def _get_torch(self, image: torch.Tensor, **kwargs):
+ raise NotImplementedError
+
+
+class AveragePooling(Pool):
+ """Average pooling feature.
+
+ Downsamples the input by applying mean pooling over non-overlapping
+ blocks of size `ksize`, preserving the center of the image and never
+ pooling over channel dimensions.
+
+ Works with NumPy and PyTorch backends.
+ """
+
+ # ---------- NumPy backend ----------
+
+ def _get_numpy(
+ self,
+ image: np.ndarray,
+ **kwargs: Any,
+ ) -> np.ndarray:
+ image = self._crop_center(image)
+ px, py, pz = self._get_pool_size(image)
+
+ # 2D or effectively 2D (channels-last)
+ if image.ndim < 3 or pz == 1:
+ block_size = (px, py) + (1,) * (image.ndim - 2)
else:
- pool_shape = (pz, px, py) + (1,) * (array.ndim - 3)
+ # 3D volume (optionally with channels)
+ block_size = (pz, px, py) + (1,) * (image.ndim - 3)
return skimage.measure.block_reduce(
- array,
- block_size=pool_shape,
- func=self.pooling_function,
+ image,
+ block_size=block_size,
+ func=np.mean,
)
- def _pool_torch(self, array: torch.Tensor) -> torch.Tensor:
- array = self._crop_center(array)
- px, py, pz = self._get_pool_size(array)
+ # ---------- Torch backend ----------
- is_3d = array.ndim >= 3 and pz > 1
+ def _get_torch(
+ self,
+ image: torch.Tensor,
+ **kwargs: Any,
+ ) -> torch.Tensor:
+ import torch.nn.functional as F
+
+ image = self._crop_center(image)
+ px, py, pz = self._get_pool_size(image)
+ is_3d = image.ndim >= 3 and pz > 1
+
+ # Flatten extra (channel / feature) dimensions into C
if not is_3d:
- extra = array.shape[2:]
+ extra = image.shape[2:]
C = int(np.prod(extra)) if extra else 1
- x = array.reshape(1, C, array.shape[0], array.shape[1])
+ x = image.reshape(1, C, image.shape[0], image.shape[1])
kernel = (px, py)
stride = (px, py)
- reducers = self._TORCH_REDUCERS_2D
+ pooled = F.avg_pool2d(x, kernel, stride)
else:
- extra = array.shape[3:]
+ extra = image.shape[3:]
C = int(np.prod(extra)) if extra else 1
- x = array.reshape(
- 1, C, array.shape[0], array.shape[1], array.shape[2]
+ x = image.reshape(
+ 1, C,
+ image.shape[0],
+ image.shape[1],
+ image.shape[2],
)
kernel = (pz, px, py)
stride = (pz, px, py)
- reducers = self._TORCH_REDUCERS_3D
-
- # Median: explicit unfolding
- if self.pooling_function is np.median:
- if is_3d:
- x_u = (
- x.unfold(2, pz, pz)
- .unfold(3, px, px)
- .unfold(4, py, py)
- )
- x_u = x_u.contiguous().view(
- 1, C,
- x_u.shape[2],
- x_u.shape[3],
- x_u.shape[4],
- -1,
- )
- pooled = x_u.median(dim=-1).values
- else:
- x_u = x.unfold(2, px, px).unfold(3, py, py)
- x_u = x_u.contiguous().view(
- 1, C,
- x_u.shape[2],
- x_u.shape[3],
- -1,
- )
- pooled = x_u.median(dim=-1).values
+ pooled = F.avg_pool3d(x, kernel, stride)
+
+ # Restore original layout
+ return pooled.reshape(pooled.shape[2:] + extra)
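+
+    # Usage sketch (NumPy backend shown; a Torch tensor input follows the
+    # same semantics, assuming the feature-call convention used elsewhere
+    # in DeepTrack2):
+    #
+    #     >>> import numpy as np
+    #     >>> AveragePooling(ksize=2)(np.ones((8, 8))).shape
+    #     (4, 4)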
+
+
+class MaxPooling(Pool):
+ """Max pooling feature.
+
+ Downsamples the input by applying max pooling over non-overlapping
+ blocks of size `ksize`, preserving the center of the image and never
+ pooling over channel dimensions.
+
+ Works with NumPy and PyTorch backends.
+ """
+
+ # ---------- NumPy backend ----------
+
+ def _get_numpy(
+ self,
+ image: np.ndarray,
+ **kwargs: Any,
+ ) -> np.ndarray:
+ image = self._crop_center(image)
+ px, py, pz = self._get_pool_size(image)
+
+ # 2D or effectively 2D (channels-last)
+ if image.ndim < 3 or pz == 1:
+ block_size = (px, py) + (1,) * (image.ndim - 2)
+ else:
+ # 3D volume (optionally with channels)
+ block_size = (pz, px, py) + (1,) * (image.ndim - 3)
+
+ return skimage.measure.block_reduce(
+ image,
+ block_size=block_size,
+ func=np.max,
+ )
+
+ # ---------- Torch backend ----------
+
+ def _get_torch(
+ self,
+ image: torch.Tensor,
+ **kwargs: Any,
+ ) -> torch.Tensor:
+ import torch.nn.functional as F
+
+ image = self._crop_center(image)
+ px, py, pz = self._get_pool_size(image)
+
+ is_3d = image.ndim >= 3 and pz > 1
+
+ # Flatten extra (channel / feature) dimensions into C
+ if not is_3d:
+ extra = image.shape[2:]
+ C = int(np.prod(extra)) if extra else 1
+ x = image.reshape(1, C, image.shape[0], image.shape[1])
+ kernel = (px, py)
+ stride = (px, py)
+ pooled = F.max_pool2d(x, kernel, stride)
else:
- reducer = reducers[self.pooling_function]
- pooled = reducer(x, kernel, stride)
+ extra = image.shape[3:]
+ C = int(np.prod(extra)) if extra else 1
+ x = image.reshape(
+ 1, C,
+ image.shape[0],
+ image.shape[1],
+ image.shape[2],
+ )
+ kernel = (pz, px, py)
+ stride = (pz, px, py)
+ pooled = F.max_pool3d(x, kernel, stride)
+ # Restore original layout
return pooled.reshape(pooled.shape[2:] + extra)
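+
+    # Usage sketch (mirrors AveragePooling above; input is hypothetical):
+    #
+    #     >>> import numpy as np
+    #     >>> MaxPooling(ksize=2)(np.arange(16).reshape(4, 4)).shape
+    #     (2, 2)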
- def __call__(self, array):
- if isinstance(array, np.ndarray):
- return self._pool_numpy(array)
- if TORCH_AVAILABLE and isinstance(array, torch.Tensor):
- return self._pool_torch(array)
+class MinPooling(Pool):
+ """Min pooling feature.
+
+ Downsamples the input by applying min pooling over non-overlapping
+ blocks of size `ksize`, preserving the center of the image and never
+ pooling over channel dimensions.
+
+ Works with NumPy and PyTorch backends.
+
+ """
+
+ # ---------- NumPy backend ----------
+
+ def _get_numpy(
+ self,
+ image: np.ndarray,
+ **kwargs: Any,
+ ) -> np.ndarray:
+ image = self._crop_center(image)
+ px, py, pz = self._get_pool_size(image)
+
+ # 2D or effectively 2D (channels-last)
+ if image.ndim < 3 or pz == 1:
+ block_size = (px, py) + (1,) * (image.ndim - 2)
+ else:
+ # 3D volume (optionally with channels)
+ block_size = (pz, px, py) + (1,) * (image.ndim - 3)
- raise TypeError(
- "Pool only supports np.ndarray or torch.Tensor inputs."
+ return skimage.measure.block_reduce(
+ image,
+ block_size=block_size,
+ func=np.min,
)
+ # ---------- Torch backend ----------
-class AveragePooling(Pool):
- def __init__(self, ksize: int = 2):
- super().__init__(np.mean, ksize)
+ def _get_torch(
+ self,
+ image: torch.Tensor,
+ **kwargs: Any,
+ ) -> torch.Tensor:
+ import torch.nn.functional as F
+
+ image = self._crop_center(image)
+ px, py, pz = self._get_pool_size(image)
+
+ is_3d = image.ndim >= 3 and pz > 1
+
+ # Flatten extra (channel / feature) dimensions into C
+ if not is_3d:
+ extra = image.shape[2:]
+ C = int(np.prod(extra)) if extra else 1
+ x = image.reshape(1, C, image.shape[0], image.shape[1])
+ kernel = (px, py)
+ stride = (px, py)
+
+ # min(x) = -max(-x)
+ pooled = -F.max_pool2d(-x, kernel, stride)
+ else:
+ extra = image.shape[3:]
+ C = int(np.prod(extra)) if extra else 1
+ x = image.reshape(
+ 1, C,
+ image.shape[0],
+ image.shape[1],
+ image.shape[2],
+ )
+ kernel = (pz, px, py)
+ stride = (pz, px, py)
+
+ pooled = -F.max_pool3d(-x, kernel, stride)
+
+ # Restore original layout
+ return pooled.reshape(pooled.shape[2:] + extra)
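+
+    # The min(x) = -max(-x) identity used above, checked in isolation:
+    #
+    #     >>> import torch
+    #     >>> import torch.nn.functional as F
+    #     >>> x = torch.rand(1, 1, 8, 8)
+    #     >>> a = -F.max_pool2d(-x, 2, 2)
+    #     >>> b = x.reshape(1, 1, 4, 2, 4, 2).amin(dim=(3, 5))
+    #     >>> bool(torch.allclose(a, b))
+    #     True
+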
class SumPooling(Pool):
- def __init__(self, ksize: int = 2):
- super().__init__(np.sum, ksize)
+ """Sum pooling feature.
+
+    Downsamples the input by applying sum pooling over non-overlapping
+ blocks of size `ksize`, preserving the center of the image and never
+ pooling over channel dimensions.
-class MinPooling(Pool):
- def __init__(self, ksize: int = 2):
- super().__init__(np.min, ksize)
+
+    Works with NumPy and PyTorch backends.
+ """
+ # ---------- NumPy backend ----------
-class MaxPooling(Pool):
- def __init__(self, ksize: int = 2):
- super().__init__(np.max, ksize)
+ def _get_numpy(
+ self,
+ image: np.ndarray,
+ **kwargs: Any,
+ ) -> np.ndarray:
+ image = self._crop_center(image)
+ px, py, pz = self._get_pool_size(image)
+
+ # 2D or effectively 2D (channels-last)
+ if image.ndim < 3 or pz == 1:
+ block_size = (px, py) + (1,) * (image.ndim - 2)
+ else:
+ # 3D volume (optionally with channels)
+ block_size = (pz, px, py) + (1,) * (image.ndim - 3)
+
+ return skimage.measure.block_reduce(
+ image,
+ block_size=block_size,
+ func=np.sum,
+ )
+
+ # ---------- Torch backend ----------
+
+ def _get_torch(
+ self,
+ image: torch.Tensor,
+ **kwargs: Any,
+ ) -> torch.Tensor:
+ import torch.nn.functional as F
+
+ image = self._crop_center(image)
+ px, py, pz = self._get_pool_size(image)
+
+ is_3d = image.ndim >= 3 and pz > 1
+
+ # Flatten extra (channel / feature) dimensions into C
+ if not is_3d:
+ extra = image.shape[2:]
+ C = int(np.prod(extra)) if extra else 1
+ x = image.reshape(1, C, image.shape[0], image.shape[1])
+ kernel = (px, py)
+ stride = (px, py)
+ pooled = F.avg_pool2d(x, kernel, stride) * (px * py)
+ else:
+ extra = image.shape[3:]
+ C = int(np.prod(extra)) if extra else 1
+ x = image.reshape(
+ 1, C,
+ image.shape[0],
+ image.shape[1],
+ image.shape[2],
+ )
+ kernel = (pz, px, py)
+ stride = (pz, px, py)
+ pooled = F.avg_pool3d(x, kernel, stride) * (pz * px * py)
+
+ # Restore original layout
+ return pooled.reshape(pooled.shape[2:] + extra)
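+
+    # Sum pooling as rescaled average pooling, checked in isolation:
+    #
+    #     >>> import torch
+    #     >>> import torch.nn.functional as F
+    #     >>> x = torch.rand(1, 1, 8, 8)
+    #     >>> s = F.avg_pool2d(x, 2, 2) * 4
+    #     >>> r = x.reshape(1, 1, 4, 2, 4, 2).sum(dim=(3, 5))
+    #     >>> bool(torch.allclose(s, r))
+    #     True
+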
class MedianPooling(Pool):
- def __init__(self, ksize: int = 2):
- super().__init__(np.median, ksize)
+ """Median pooling feature.
+
+    Downsamples the input by applying median pooling over non-overlapping
+ blocks of size `ksize`, preserving the center of the image and never
+ pooling over channel dimensions.
+
+    Notes
+ -----
+    - NumPy backend uses `skimage.measure.block_reduce`.
+    - Torch backend performs explicit unfolding followed by `median`.
+    - Torch median pooling is significantly more expensive than mean or
+      max pooling.
+
+ Median pooling is not differentiable and should not be used inside
+ trainable neural networks requiring gradient-based optimization.
+
+ """
+
+ # ---------- NumPy backend ----------
+
+ def _get_numpy(
+ self,
+ image: np.ndarray,
+ **kwargs: Any,
+ ) -> np.ndarray:
+ image = self._crop_center(image)
+ px, py, pz = self._get_pool_size(image)
+
+ if image.ndim < 3 or pz == 1:
+ block_size = (px, py) + (1,) * (image.ndim - 2)
+ else:
+ block_size = (pz, px, py) + (1,) * (image.ndim - 3)
+
+ return skimage.measure.block_reduce(
+ image,
+ block_size=block_size,
+ func=np.median,
+ )
-class Resize(Feature):
+ # ---------- Torch backend ----------
+
+ def _get_torch(
+ self,
+ image: torch.Tensor,
+ **kwargs: Any,
+ ) -> torch.Tensor:
+
+        # Warn once per feature instance (the attribute is created lazily,
+        # so guard with getattr rather than assuming it exists).
+        if not getattr(self, "_warned", False):
+ warnings.warn(
+ "MedianPooling is not differentiable and is expensive on the "
+ "Torch backend. Avoid using it inside trainable models.",
+ UserWarning,
+ stacklevel=2,
+ )
+ self._warned = True
+
+ image = self._crop_center(image)
+ px, py, pz = self._get_pool_size(image)
+
+ is_3d = image.ndim >= 3 and pz > 1
+
+ if not is_3d:
+ # 2D case (with optional channels)
+ extra = image.shape[2:]
+ C = int(np.prod(extra)) if extra else 1
+
+ x = image.reshape(1, C, image.shape[0], image.shape[1])
+
+ # unfold: (B, C, H', W', px, py)
+ x_u = (
+ x.unfold(2, px, px)
+ .unfold(3, py, py)
+ )
+
+ x_u = x_u.contiguous().view(
+ 1, C,
+ x_u.shape[2],
+ x_u.shape[3],
+ -1,
+ )
+
+ pooled = x_u.median(dim=-1).values
+
+ else:
+ # 3D case (with optional channels)
+ extra = image.shape[3:]
+ C = int(np.prod(extra)) if extra else 1
+
+ x = image.reshape(
+ 1, C,
+ image.shape[0],
+ image.shape[1],
+ image.shape[2],
+ )
+
+ # unfold: (B, C, Z', Y', X', pz, px, py)
+ x_u = (
+ x.unfold(2, pz, pz)
+ .unfold(3, px, px)
+ .unfold(4, py, py)
+ )
+
+ x_u = x_u.contiguous().view(
+ 1, C,
+ x_u.shape[2],
+ x_u.shape[3],
+ x_u.shape[4],
+ -1,
+ )
+
+ pooled = x_u.median(dim=-1).values
+
+ return pooled.reshape(pooled.shape[2:] + extra)
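+
+    # Usage sketch (NumPy backend; the Torch path additionally emits the
+    # one-time cost warning above):
+    #
+    #     >>> import numpy as np
+    #     >>> MedianPooling(ksize=2)(np.ones((8, 8))).shape
+    #     (4, 4)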
+
+
+class Resize(BackendDispatched, Feature):
"""Resize an image to a specified size.
`Resize` resizes images following the channels-last semantic
@@ -1517,6 +1833,9 @@ class Resize(Feature):
"""
+ _NUMPY_IMPL = "_get_numpy"
+ _TORCH_IMPL = "_get_torch"
+
def __init__(
self: Resize,
dsize: PropertyLike[tuple[int, int]] = (256, 256),
@@ -1527,8 +1846,8 @@ def __init__(
Parameters
----------
dsize: PropertyLike[tuple[int, int]]
- The target size. Format depends on backend: `(width, height)` for
- NumPy, `(height, width)` for PyTorch. Default is (256, 256).
+        The target size, given as `(width, height)` for both backends.
+        Default is (256, 256).
**kwargs: Any
Additional arguments passed to the parent `Feature` class.
@@ -1583,83 +1902,102 @@ def get(
`align_corners=False`, which closely matches OpenCV’s default behavior.
"""
+ return self._dispatch_backend(image, dsize=dsize, **kwargs)
- target_w, target_h = dsize
+ # ---------- NumPy backend (OpenCV) ----------
- # Torch backend
- if apc.is_torch_array(image):
- import torch.nn.functional as F
+ def _get_numpy(
+ self,
+ image: np.ndarray,
+ dsize: tuple[int, int],
+ **kwargs: Any,
+ ) -> np.ndarray:
+
+ target_w, target_h = dsize
- original_ndim = image.ndim
- has_channels = (
- image.ndim >= 3 and image.shape[-1] <= 4
+ # Prefer OpenCV if available
+ if OPENCV_AVAILABLE:
+ import cv2
+ return utils.safe_call(
+ cv2.resize,
+ positional_args=[image, (target_w, target_h)],
+ **kwargs,
)
+        if kwargs:
+            warnings.warn(
+                "OpenCV not available: resize kwargs may be ignored.",
+                UserWarning,
+            )
- # Bring to (N, C, H, W)
- if image.ndim == 2:
- # (H, W) -> (1, 1, H, W)
- x = image.unsqueeze(0).unsqueeze(0)
+        # Fallback: scikit-image (always available in DeepTrack2)
+ from skimage.transform import resize as sk_resize
- elif image.ndim == 3 and has_channels:
- # (H, W, C) -> (1, C, H, W)
- x = image.permute(2, 0, 1).unsqueeze(0)
+ if image.ndim == 2:
+ out_shape = (target_h, target_w)
+ else:
+ out_shape = (target_h, target_w) + image.shape[2:]
- elif image.ndim == 3:
- # (Z, H, W) -> treat Z as batch
- x = image.unsqueeze(1)
+ out = sk_resize(
+ image,
+ out_shape,
+ preserve_range=True,
+ anti_aliasing=True,
+ )
- elif image.ndim == 4 and has_channels:
- # (Z, H, W, C) -> (Z, C, H, W)
- x = image.permute(0, 3, 1, 2)
+ return out.astype(image.dtype, copy=False)
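+
+    # Fallback-path sketch: the (width, height) `dsize` maps to skimage's
+    # (rows, cols) output shape (sizes are hypothetical):
+    #
+    #     >>> import numpy as np
+    #     >>> from skimage.transform import resize as sk_resize
+    #     >>> img = np.zeros((120, 80))        # (H, W)
+    #     >>> target_w, target_h = 40, 60      # dsize = (width, height)
+    #     >>> sk_resize(img, (target_h, target_w), preserve_range=True).shape
+    #     (60, 40)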
- else:
- raise ValueError(
- f"Unsupported tensor shape {image.shape} for Resize."
- )
+ # ---------- Torch backend ----------
- # Resize spatial dimensions
- resized = F.interpolate(
- x,
- size=(target_h, target_w),
- mode="bilinear",
- align_corners=False,
- )
+ def _get_torch(
+ self,
+ image: torch.Tensor,
+ dsize: tuple[int, int],
+ **kwargs: Any,
+ ) -> torch.Tensor:
+ import torch.nn.functional as F
- # Restore original layout
- if original_ndim == 2:
- return resized.squeeze(0).squeeze(0)
+ target_w, target_h = dsize
- if original_ndim == 3 and has_channels:
- return resized.squeeze(0).permute(1, 2, 0)
+ original_ndim = image.ndim
+ has_channels = image.ndim >= 3 and image.shape[-1] <= 4
- if original_ndim == 3:
- return resized.squeeze(1)
+ # Convert to (N, C, H, W)
+ if image.ndim == 2:
+ x = image.unsqueeze(0).unsqueeze(0) # (1, 1, H, W)
- if original_ndim == 4:
- return resized.permute(0, 2, 3, 1)
+ elif image.ndim == 3 and has_channels:
+ x = image.permute(2, 0, 1).unsqueeze(0) # (1, C, H, W)
- raise RuntimeError("Unexpected shape restoration path.")
+ elif image.ndim == 3:
+ x = image.unsqueeze(1) # (Z, 1, H, W)
- # NumPy / OpenCV backend
- else:
- import cv2
+ elif image.ndim == 4 and has_channels:
+ x = image.permute(0, 3, 1, 2) # (Z, C, H, W)
- # OpenCV expects (width, height)
- return utils.safe_call(
- cv2.resize,
- positional_args=[image, (target_w, target_h)],
- **kwargs,
+ else:
+ raise ValueError(
+ f"Unsupported tensor shape {image.shape} for Resize."
)
+ # Resize spatial dimensions
+ x = F.interpolate(
+ x,
+ size=(target_h, target_w),
+ mode="bilinear",
+ align_corners=False,
+ )
-if OPENCV_AVAILABLE:
- _map_mode_to_cv2_borderType = {
- "reflect": cv2.BORDER_REFLECT,
- "wrap": cv2.BORDER_WRAP,
- "constant": cv2.BORDER_CONSTANT,
- "mirror": cv2.BORDER_REFLECT_101,
- "nearest": cv2.BORDER_REPLICATE,
- }
+ # Restore original layout
+ if original_ndim == 2:
+ return x.squeeze(0).squeeze(0)
+
+ if original_ndim == 3 and has_channels:
+ return x.squeeze(0).permute(1, 2, 0)
+
+ if original_ndim == 3:
+ return x.squeeze(1)
+
+ if original_ndim == 4:
+ return x.permute(0, 2, 3, 1)
+
+ raise RuntimeError("Unexpected shape restoration path.")
#TODO ***JH*** revise BlurCV2 - torch, typing, docstring, unit test
@@ -1679,7 +2017,7 @@ class BlurCV2(Feature):
Methods
-------
- `get(image: np.ndarray | Image, **kwargs: Any) --> np.ndarray`
+    `get(image: np.ndarray, **kwargs: Any) -> np.ndarray`
Applies the blurring filter to the input image.
Examples
@@ -1710,52 +2048,17 @@ class BlurCV2(Feature):
"""
- def __new__(
- cls: type,
- *args: tuple,
- **kwargs: Any,
- ):
- """Ensures that OpenCV (cv2) is available before instantiating the
- class.
-
- Overrides the default object creation process to check that the `cv2`
- module is available before creating the class. If OpenCV is not
- installed, it raises an ImportError with instructions for installation.
-
- Parameters
- ----------
- *args : tuple
- Positional arguments passed to the class constructor.
- **kwargs : dict
- Keyword arguments passed to the class constructor.
-
- Returns
- -------
- BlurCV2
- An instance of the BlurCV2 feature class.
-
- Raises
- ------
- ImportError
- If the OpenCV (`cv2`) module is not available in the current
- environment.
-
- """
-
- print(cls.__name__)
-
- if not OPENCV_AVAILABLE:
- raise ImportError(
- "OpenCV not installed on device. Since OpenCV is an optional "
- f"dependency of DeepTrack2. To use {cls.__name__}, "
- "you need to install it manually."
- )
-
- return super().__new__(cls)
+ _MODE_TO_BORDER = {
+ "reflect": "BORDER_REFLECT",
+ "wrap": "BORDER_WRAP",
+ "constant": "BORDER_CONSTANT",
+ "mirror": "BORDER_REFLECT_101",
+ "nearest": "BORDER_REPLICATE",
+ }
def __init__(
self: BlurCV2,
- filter_function: Callable,
+ filter_function: Callable | str,
mode: PropertyLike[str] = "reflect",
**kwargs: Any,
):
@@ -1775,13 +2078,20 @@ def __init__(
"""
+ if not OPENCV_AVAILABLE:
+ raise ImportError(
+ "OpenCV not installed on device. Since OpenCV is an optional "
+ f"dependency of DeepTrack2. To use {self.__class__.__name__}, "
+ "you need to install it manually."
+ )
+
self.filter = filter_function
- borderType = _map_mode_to_cv2_borderType[mode]
- super().__init__(borderType=borderType, **kwargs)
+ self.mode = mode
+ super().__init__(**kwargs)
def get(
self: BlurCV2,
- image: np.ndarray | Image,
+ image: np.ndarray,
**kwargs: Any,
) -> np.ndarray:
"""Applies the blurring filter to the input image.
@@ -1790,8 +2100,8 @@ def get(
Parameters
----------
- image: np.ndarray | Image
- The input image to blur. Can be a NumPy array or DeepTrack Image.
+ image: np.ndarray
+ The input image to blur. Must be a NumPy array.
**kwargs: Any
Additional parameters for the blurring function.
@@ -1805,12 +2115,31 @@ def get(
if apc.is_torch_array(image):
raise TypeError(
"BlurCV2 only supports NumPy arrays. "
- "For Torch tensors, use Blur or GaussianBlur instead."
+ "Use GaussianBlur / AverageBlur for Torch."
)
+ import cv2
+
+        filter_fn = (
+            getattr(cv2, self.filter)
+            if isinstance(self.filter, str)
+            else self.filter
+        )
+
+ try:
+ border_attr = self._MODE_TO_BORDER[self.mode]
+ except KeyError as e:
+ raise ValueError(f"Unsupported border mode '{self.mode}'") from e
+
+ try:
+ border = getattr(cv2, border_attr)
+ except AttributeError as e:
+ raise RuntimeError(f"OpenCV missing border constant '{border_attr}'") from e
+
+        # Preserve legacy behavior: drop the DeepTrack-internal "name"
+        # kwarg before forwarding the remaining kwargs to OpenCV.
kwargs.pop("name", None)
- result = self.filter(src=image, **kwargs)
- return result
+
+ return filter_fn(
+ src=image,
+ borderType=border,
+ **kwargs,
+ )
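+
+    # Hypothetical usage sketch (assumes OpenCV is installed; "blur" is
+    # OpenCV's box filter, resolved via getattr above, and `ksize` is
+    # forwarded to it as a feature property):
+    #
+    #     >>> import numpy as np
+    #     >>> feature = BlurCV2(filter_function="blur", ksize=(3, 3))
+    #     >>> feature(np.random.rand(16, 16).astype(np.float32)).shape
+    #     (16, 16)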
#TODO ***JH*** revise BilateralBlur - torch, typing, docstring, unit test
@@ -1894,7 +2223,7 @@ def __init__(
"""
super().__init__(
- cv2.bilateralFilter,
+ filter_function="bilateralFilter",
d=d,
sigmaColor=sigma_color,
sigmaSpace=sigma_space,
@@ -1903,13 +2232,25 @@ def __init__(
def isotropic_dilation(
- mask,
+ mask: np.ndarray | torch.Tensor,
radius: float,
*,
- backend: str,
+ backend: Literal["numpy", "torch"],
device=None,
dtype=None,
-):
+) -> np.ndarray | torch.Tensor:
+ """
+ Binary dilation using an isotropic (NumPy) or box-shaped (Torch) kernel.
+
+ Notes
+ -----
+ - NumPy backend uses a true Euclidean ball.
+ - Torch backend uses a cubic structuring element (approximate).
+ - Torch backend supports 3D masks only.
+ - Operation is non-differentiable.
+
+ """
+
if radius <= 0:
return mask
@@ -1938,13 +2279,25 @@ def isotropic_dilation(
def isotropic_erosion(
- mask,
+ mask: np.ndarray | torch.Tensor,
radius: float,
*,
- backend: str,
+ backend: Literal["numpy", "torch"],
device=None,
dtype=None,
-):
+) -> np.ndarray | torch.Tensor:
+ """
+ Binary erosion using an isotropic (NumPy) or box-shaped (Torch) kernel.
+
+ Notes
+ -----
+ - NumPy backend uses a true Euclidean ball.
+ - Torch backend uses a cubic structuring element (approximate).
+ - Torch backend supports 3D masks only.
+ - Operation is non-differentiable.
+
+ """
+
if radius <= 0:
return mask
diff --git a/deeptrack/optics.py b/deeptrack/optics.py
index a7394689..70a54de6 100644
--- a/deeptrack/optics.py
+++ b/deeptrack/optics.py
@@ -395,16 +395,15 @@ def get(
volume_samples,
**additional_sample_kwargs,
)
+ if volume_samples:
+ # Interpret the merged volume semantically
+ sample_volume = self._extract_contrast_volume(
+ ScatteredVolume(
+ array=sample_volume,
+ properties=volume_samples[0].properties,
+ ),
+ )
- print('prop', volume_samples[0].properties)
-
- # Interpret the merged volume semantically
- sample_volume = self._extract_contrast_volume(
- ScatteredVolume(
- array=sample_volume,
- properties=volume_samples[0].properties,
- ),
- )
# Let the objective know about the limits of the volume and all the fields.
propagate_data_to_dependencies(
@@ -723,7 +722,6 @@ def _process_properties(
wavelength = propertydict["wavelength"]
voxel_size = get_active_voxel_size()
radius = NA / wavelength * np.array(voxel_size)
- print('Pupil radius (in pixels):', radius)
if np.any(radius[:2] > 0.5):
required_upscale = np.max(np.ceil(radius[:2] * 2))
@@ -1074,7 +1072,6 @@ class Fluorescence(Optics):
"""
-
def validate_input(self, scattered):
"""Semantic validation for fluorescence microscopy."""
@@ -1085,9 +1082,9 @@ def validate_input(self, scattered):
)
- def extract_contrast_volume(self, scattered: ScatteredVolume) -> np.ndarray:
- voxel_size = np.asarray(get_active_voxel_size(), float)
- voxel_volume = np.prod(voxel_size)
+    def extract_contrast_volume(
+        self, scattered: ScatteredVolume, **kwargs
+    ) -> np.ndarray:
+ scale = np.asarray(get_active_scale(), float)
+ scale_volume = np.prod(scale)
intensity = scattered.get_property("intensity", None)
value = scattered.get_property("value", None)
@@ -1103,7 +1100,7 @@ def extract_contrast_volume(self, scattered: ScatteredVolume) -> np.ndarray:
# Preferred, physically meaningful case
if intensity is not None:
- return intensity * voxel_volume * scattered.array
+ return intensity * scale_volume * scattered.array
# Fallback: legacy / dimensionless brightness
warnings.warn(
@@ -1379,7 +1376,6 @@ def extract_contrast_volume(
refractive_index_medium: float,
**kwargs: Any,
) -> np.ndarray:
- print('ri_medium', refractive_index_medium)
ri = scattered.get_property("refractive_index", None)
value = scattered.get_property("value", None)
@@ -2093,6 +2089,9 @@ class NonOverlapping(Feature):
- This feature performs bounding cube checks first to quickly reject
obvious overlaps before voxel-level checks.
- If the bounding cubes overlap, precise voxel-based checks are performed.
+    - The feature may be computationally intensive for large numbers of
+      volumes or high-density placements.
+ - The feature is not differentiable.
Examples
---------
@@ -2375,7 +2374,7 @@ def _check_non_overlapping(
new_volumes.append(new_volume)
- list_of_volumes = new_volumes
+ list_of_volumes = new_volumes
min_distance = 1
# The position of the top left corner of each volume (index (0, 0, 0)).