Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 28 additions & 5 deletions monai/transforms/intensity/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -601,6 +601,7 @@ def __init__(
factors: Sequence[float] | float = 0,
fixed_mean: bool = True,
preserve_range: bool = False,
channel_wise: bool = False,
dtype: DtypeLike = np.float32,
) -> None:
"""
Expand All @@ -611,8 +612,8 @@ def __init__(
fixed_mean: subtract the mean intensity before scaling with `factor`, then add the same value after scaling
to ensure that the output has the same mean as the input.
channel_wise: if True, scale on each channel separately. `preserve_range` and `fixed_mean` are also applied
on each channel separately if `channel_wise` is True. Please ensure that the first dimension represents the
channel of the image if True.
on each channel separately if `channel_wise` is True. Please ensure that the first dimension represents the
channel of the image if True.
dtype: output data type, if None, same as input image. defaults to float32.

"""
Expand All @@ -626,29 +627,51 @@ def __init__(
self.factor = self.factors[0]
self.fixed_mean = fixed_mean
self.preserve_range = preserve_range
self.channel_wise = channel_wise
self.dtype = dtype

self.scaler = ScaleIntensityFixedMean(
factor=self.factor, fixed_mean=self.fixed_mean, preserve_range=self.preserve_range, dtype=self.dtype
factor=self.factor,
fixed_mean=self.fixed_mean,
preserve_range=self.preserve_range,
channel_wise=self.channel_wise,
dtype=self.dtype,
)

def randomize(self, data: Any | None = None) -> None:
super().randomize(None)
if not self._do_transform:
return None
self.factor = self.R.uniform(low=self.factors[0], high=self.factors[1])
if self.channel_wise:
self.factor = [self.R.uniform(low=self.factors[0], high=self.factors[1]) for _ in range(data.shape[0])] # type: ignore
else:
self.factor = self.R.uniform(low=self.factors[0], high=self.factors[1])

def __call__(self, img: NdarrayOrTensor, randomize: bool = True) -> NdarrayOrTensor:
"""
Apply the transform to `img`.
"""
img = convert_to_tensor(img, track_meta=get_track_meta())
if randomize:
self.randomize()
self.randomize(img)

if not self._do_transform:
return convert_data_type(img, dtype=self.dtype)[0]

if self.channel_wise:
out = []
for i, d in enumerate(img):
out_channel = ScaleIntensityFixedMean(
factor=self.factor[i], # type: ignore
fixed_mean=self.fixed_mean,
preserve_range=self.preserve_range,
dtype=self.dtype,
)(d[None])[0]
out.append(out_channel)
ret: NdarrayOrTensor = torch.stack(out)
ret = convert_to_dst_type(ret, dst=img, dtype=self.dtype or img.dtype)[0]
return ret

return self.scaler(img, self.factor)


Expand Down
21 changes: 17 additions & 4 deletions monai/transforms/intensity/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -669,6 +669,7 @@ def __init__(
factors: Sequence[float] | float,
fixed_mean: bool = True,
preserve_range: bool = False,
channel_wise: bool = False,
prob: float = 0.1,
dtype: DtypeLike = np.float32,
allow_missing_keys: bool = False,
Expand All @@ -683,8 +684,8 @@ def __init__(
fixed_mean: subtract the mean intensity before scaling with `factor`, then add the same value after scaling
to ensure that the output has the same mean as the input.
channel_wise: if True, scale on each channel separately. `preserve_range` and `fixed_mean` are also applied
on each channel separately if `channel_wise` is True. Please ensure that the first dimension represents the
channel of the image if True.
on each channel separately if `channel_wise` is True. Please ensure that the first dimension represents the
channel of the image if True.
dtype: output data type, if None, same as input image. defaults to float32.
allow_missing_keys: don't raise exception if key is missing.

Expand All @@ -694,7 +695,12 @@ def __init__(
self.fixed_mean = fixed_mean
self.preserve_range = preserve_range
self.scaler = RandScaleIntensityFixedMean(
factors=factors, fixed_mean=self.fixed_mean, preserve_range=preserve_range, dtype=dtype, prob=1.0
factors=factors,
fixed_mean=self.fixed_mean,
preserve_range=preserve_range,
channel_wise=channel_wise,
dtype=dtype,
prob=1.0,
)

def set_random_state(
Expand All @@ -712,8 +718,15 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N
d[key] = convert_to_tensor(d[key], track_meta=get_track_meta())
return d

# expect all the specified keys have same spatial shape and share same random factors
first_key: Hashable = self.first_key(d)
if first_key == ():
for key in self.key_iterator(d):
d[key] = convert_to_tensor(d[key], track_meta=get_track_meta())
return d

# all the keys share the same random scale factor
self.scaler.randomize(None)
self.scaler.randomize(d[first_key])
for key in self.key_iterator(d):
d[key] = self.scaler(d[key], randomize=False)
return d
Expand Down
29 changes: 29 additions & 0 deletions tests/transforms/test_rand_scale_intensity_fixed_mean.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,35 @@ def test_value(self, p):
expected = expected + mn
assert_allclose(result, expected, type_test="tensor", atol=1e-7)

@parameterized.expand([[p] for p in TEST_NDARRAYS])
def test_channel_wise(self, p):
scaler = RandScaleIntensityFixedMean(prob=1.0, factors=0.5, channel_wise=True)
scaler.set_random_state(seed=0)
im = p(self.imt)
result = scaler(im)
np.random.seed(0)
# simulate the randomize() of transform
np.random.random()
channel_num = self.imt.shape[0]
factor = [np.random.uniform(low=-0.5, high=0.5) for _ in range(channel_num)]
expected = np.stack(
[np.asarray((self.imt[i] - self.imt[i].mean()) * (1 + factor[i]) + self.imt[i].mean()) for i in range(channel_num)]
).astype(np.float32)
assert_allclose(result, p(expected), atol=1e-4, rtol=1e-4, type_test=False)

@parameterized.expand([[p] for p in TEST_NDARRAYS])
def test_channel_wise_preserve_range(self, p):
scaler = RandScaleIntensityFixedMean(
prob=1.0, factors=0.5, channel_wise=True, preserve_range=True, fixed_mean=True
)
scaler.set_random_state(seed=0)
im = p(self.imt)
result = scaler(im)
# verify output is within input range per channel
for c in range(self.imt.shape[0]):
assert float(result[c].min()) >= float(im[c].min()) - 1e-6
assert float(result[c].max()) <= float(im[c].max()) + 1e-6


if __name__ == "__main__":
unittest.main()
17 changes: 17 additions & 0 deletions tests/transforms/test_rand_scale_intensity_fixed_meand.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,23 @@ def test_value(self):
expected = expected + mn
assert_allclose(result[key], p(expected), type_test="tensor", atol=1e-6)

def test_channel_wise(self):
key = "img"
for p in TEST_NDARRAYS:
scaler = RandScaleIntensityFixedMeand(keys=[key], factors=0.5, prob=1.0, channel_wise=True)
scaler.set_random_state(seed=0)
im = p(self.imt)
result = scaler({key: im})
np.random.seed(0)
# simulate the randomize function of transform
np.random.random()
channel_num = self.imt.shape[0]
factor = [np.random.uniform(low=-0.5, high=0.5) for _ in range(channel_num)]
expected = np.stack(
[np.asarray((self.imt[i] - self.imt[i].mean()) * (1 + factor[i]) + self.imt[i].mean()) for i in range(channel_num)]
).astype(np.float32)
assert_allclose(result[key], p(expected), atol=1e-4, rtol=1e-4, type_test=False)


if __name__ == "__main__":
unittest.main()
Loading