From a1f80237de7576aee5ee6d099c5bb86f68817455 Mon Sep 17 00:00:00 2001 From: Mohamed Salah Date: Sat, 14 Feb 2026 22:25:40 +0200 Subject: [PATCH] Add missing channel_wise parameter to RandScaleIntensityFixedMean (#8363) The channel_wise option was documented in docstrings but not actually implemented in RandScaleIntensityFixedMean and its dictionary variant RandScaleIntensityFixedMeand. This adds the parameter following the existing pattern from RandScaleIntensity, generating per-channel random scale factors when channel_wise=True. Fixes #8363 --- monai/transforms/intensity/array.py | 33 ++++++++++++++++--- monai/transforms/intensity/dictionary.py | 21 +++++++++--- .../test_rand_scale_intensity_fixed_mean.py | 29 ++++++++++++++++ .../test_rand_scale_intensity_fixed_meand.py | 17 ++++++++++ 4 files changed, 91 insertions(+), 9 deletions(-) diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index 0421d34492..28b19b864e 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -601,6 +601,7 @@ def __init__( factors: Sequence[float] | float = 0, fixed_mean: bool = True, preserve_range: bool = False, + channel_wise: bool = False, dtype: DtypeLike = np.float32, ) -> None: """ @@ -611,8 +612,8 @@ def __init__( fixed_mean: subtract the mean intensity before scaling with `factor`, then add the same value after scaling to ensure that the output has the same mean as the input. channel_wise: if True, scale on each channel separately. `preserve_range` and `fixed_mean` are also applied - on each channel separately if `channel_wise` is True. Please ensure that the first dimension represents the - channel of the image if True. + on each channel separately if `channel_wise` is True. Please ensure that the first dimension represents the + channel of the image if True. dtype: output data type, if None, same as input image. defaults to float32. """ @@ -626,17 +627,25 @@ def __init__( self.factor = self.factors[0] self.fixed_mean = fixed_mean self.preserve_range = preserve_range + self.channel_wise = channel_wise self.dtype = dtype self.scaler = ScaleIntensityFixedMean( - factor=self.factor, fixed_mean=self.fixed_mean, preserve_range=self.preserve_range, dtype=self.dtype + factor=self.factor, + fixed_mean=self.fixed_mean, + preserve_range=self.preserve_range, + channel_wise=self.channel_wise, + dtype=self.dtype, ) def randomize(self, data: Any | None = None) -> None: super().randomize(None) if not self._do_transform: return None - self.factor = self.R.uniform(low=self.factors[0], high=self.factors[1]) + if self.channel_wise: + self.factor = [self.R.uniform(low=self.factors[0], high=self.factors[1]) for _ in range(data.shape[0])] # type: ignore + else: + self.factor = self.R.uniform(low=self.factors[0], high=self.factors[1]) def __call__(self, img: NdarrayOrTensor, randomize: bool = True) -> NdarrayOrTensor: """ @@ -644,11 +653,25 @@ def __call__(self, img: NdarrayOrTensor, randomize: bool = True) -> NdarrayOrTen """ img = convert_to_tensor(img, track_meta=get_track_meta()) if randomize: - self.randomize() + self.randomize(img) if not self._do_transform: return convert_data_type(img, dtype=self.dtype)[0] + if self.channel_wise: + out = [] + for i, d in enumerate(img): + out_channel = ScaleIntensityFixedMean( + factor=self.factor[i], # type: ignore + fixed_mean=self.fixed_mean, + preserve_range=self.preserve_range, + dtype=self.dtype, + )(d[None])[0] + out.append(out_channel) + ret: NdarrayOrTensor = torch.stack(out) + ret = convert_to_dst_type(ret, dst=img, dtype=self.dtype or img.dtype)[0] + return ret + return self.scaler(img, self.factor) diff --git a/monai/transforms/intensity/dictionary.py b/monai/transforms/intensity/dictionary.py index 3d29b3031d..c24d8c8954 100644 --- a/monai/transforms/intensity/dictionary.py +++ b/monai/transforms/intensity/dictionary.py @@ -669,6 +669,7 @@ def __init__( factors: Sequence[float] | float, fixed_mean: bool = True, preserve_range: bool = False, + channel_wise: bool = False, prob: float = 0.1, dtype: DtypeLike = np.float32, allow_missing_keys: bool = False, @@ -683,8 +684,8 @@ def __init__( fixed_mean: subtract the mean intensity before scaling with `factor`, then add the same value after scaling to ensure that the output has the same mean as the input. channel_wise: if True, scale on each channel separately. `preserve_range` and `fixed_mean` are also applied - on each channel separately if `channel_wise` is True. Please ensure that the first dimension represents the - channel of the image if True. + on each channel separately if `channel_wise` is True. Please ensure that the first dimension represents the + channel of the image if True. dtype: output data type, if None, same as input image. defaults to float32. allow_missing_keys: don't raise exception if key is missing. @@ -694,7 +695,12 @@ def __init__( self.fixed_mean = fixed_mean self.preserve_range = preserve_range self.scaler = RandScaleIntensityFixedMean( - factors=factors, fixed_mean=self.fixed_mean, preserve_range=preserve_range, dtype=dtype, prob=1.0 + factors=factors, + fixed_mean=self.fixed_mean, + preserve_range=preserve_range, + channel_wise=channel_wise, + dtype=dtype, + prob=1.0, ) def set_random_state( @@ -712,8 +718,15 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N d[key] = convert_to_tensor(d[key], track_meta=get_track_meta()) return d + # expect all the specified keys have same spatial shape and share same random factors + first_key: Hashable = self.first_key(d) + if first_key == (): + for key in self.key_iterator(d): + d[key] = convert_to_tensor(d[key], track_meta=get_track_meta()) + return d + # all the keys share the same random scale factor - self.scaler.randomize(None) + self.scaler.randomize(d[first_key]) for key in self.key_iterator(d): d[key] = self.scaler(d[key], randomize=False) return d diff --git a/tests/transforms/test_rand_scale_intensity_fixed_mean.py b/tests/transforms/test_rand_scale_intensity_fixed_mean.py index ac45a9d463..2649830267 100644 --- a/tests/transforms/test_rand_scale_intensity_fixed_mean.py +++ b/tests/transforms/test_rand_scale_intensity_fixed_mean.py @@ -36,6 +36,35 @@ def test_value(self, p): expected = expected + mn assert_allclose(result, expected, type_test="tensor", atol=1e-7) + @parameterized.expand([[p] for p in TEST_NDARRAYS]) + def test_channel_wise(self, p): + scaler = RandScaleIntensityFixedMean(prob=1.0, factors=0.5, channel_wise=True) + scaler.set_random_state(seed=0) + im = p(self.imt) + result = scaler(im) + np.random.seed(0) + # simulate the randomize() of transform + np.random.random() + channel_num = self.imt.shape[0] + factor = [np.random.uniform(low=-0.5, high=0.5) for _ in range(channel_num)] + expected = np.stack( + [np.asarray((self.imt[i] - self.imt[i].mean()) * (1 + factor[i]) + self.imt[i].mean()) for i in range(channel_num)] + ).astype(np.float32) + assert_allclose(result, p(expected), atol=1e-4, rtol=1e-4, type_test=False) + + @parameterized.expand([[p] for p in TEST_NDARRAYS]) + def test_channel_wise_preserve_range(self, p): + scaler = RandScaleIntensityFixedMean( + prob=1.0, factors=0.5, channel_wise=True, preserve_range=True, fixed_mean=True + ) + scaler.set_random_state(seed=0) + im = p(self.imt) + result = scaler(im) + # verify output is within input range per channel + for c in range(self.imt.shape[0]): + assert float(result[c].min()) >= float(im[c].min()) - 1e-6 + assert float(result[c].max()) <= float(im[c].max()) + 1e-6 + if __name__ == "__main__": unittest.main() diff --git a/tests/transforms/test_rand_scale_intensity_fixed_meand.py b/tests/transforms/test_rand_scale_intensity_fixed_meand.py index 55111a4c2e..29592281b3 100644 --- a/tests/transforms/test_rand_scale_intensity_fixed_meand.py +++ b/tests/transforms/test_rand_scale_intensity_fixed_meand.py @@ -36,6 +36,23 @@ def test_value(self): expected = expected + mn assert_allclose(result[key], p(expected), type_test="tensor", atol=1e-6) + def test_channel_wise(self): + key = "img" + for p in TEST_NDARRAYS: + scaler = RandScaleIntensityFixedMeand(keys=[key], factors=0.5, prob=1.0, channel_wise=True) + scaler.set_random_state(seed=0) + im = p(self.imt) + result = scaler({key: im}) + np.random.seed(0) + # simulate the randomize function of transform + np.random.random() + channel_num = self.imt.shape[0] + factor = [np.random.uniform(low=-0.5, high=0.5) for _ in range(channel_num)] + expected = np.stack( + [np.asarray((self.imt[i] - self.imt[i].mean()) * (1 + factor[i]) + self.imt[i].mean()) for i in range(channel_num)] + ).astype(np.float32) + assert_allclose(result[key], p(expected), atol=1e-4, rtol=1e-4, type_test=False) + if __name__ == "__main__": unittest.main()