Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Removed
* Dropped support for Python 3.9 [gh-243](https://github.com/IntelPython/mkl_fft/pull/243)

### Fixed
* Fix `TypeError` exception raised with empty axes [gh-288](https://github.com/IntelPython/mkl_fft/pull/288)

## [2.1.2] - 2025-12-02

### Added
Expand Down
125 changes: 121 additions & 4 deletions mkl_fft/_fft_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,67 @@ def _init_nd_shape_and_axes(x, shape, axes):


def _iter_complementary(x, axes, func, kwargs, result):
"""
Apply FFT function by iterating over complementary axes.

This function applies an FFT operation to slices of the input array
by iterating over all axes that are NOT in the `axes` parameter
(the complementary axes). For each position in the complementary axes,
it applies the FFT function to a slice along the specified axes.

Parameters
----------
x : ndarray
Input array.
axes : int, sequence of ints, or None
Axes along which to perform the FFT operation. The function iterates
over the complementary axes (axes not in this parameter). If ``None``,
performs direct N-D FFT without iteration.
Default: None
func : callable
FFT function to apply to each slice. Should accept array input and
return transformed output.
kwargs : dict
Additional keyword arguments to pass to `func`.
result : ndarray
Pre-allocated output array where results are stored.

Returns
-------
ndarray
The transformed array (same as `result`).

Notes
-----
For complex input, the function uses in-place operations with the `out`
parameter passed for better performance. For real input, `np.copyto` is
used instead to avoid element ordering issues that can occur with the
`out` parameter in certain FFT operations.

Examples
--------
Consider an input array with shape (3, 4, 5) and performing FFT
along axis 2 only:

>>> x = np.random.random((3, 4, 5))
>>> result = np.empty((3, 4, 5), dtype=np.complex128)
>>> _iter_complementary(
... x, axes=(2,), func=_direct_fftnd,
... kwargs={'direction': 1, 'fsc': 1.0}, result=result
... )

The function will iterate over axes 0 and 1 (complementary axes)
and apply `_direct_fftnd` to each 1-D slice along axis 2:

- Iteration 0: func(x[0, 0, :]) -> result[0, 0, :]
- Iteration 1: func(x[0, 1, :]) -> result[0, 1, :]
- ...
- Iteration 11: func(x[2, 3, :]) -> result[2, 3, :]

Total: 3 * 4 = 12 FFT operations on arrays of shape (5,).

"""

if axes is None:
# s and axes are None, direct N-D FFT
return func(x, **kwargs, out=result)
Expand All @@ -233,8 +294,10 @@ def _iter_complementary(x, axes, func, kwargs, result):
m_ind = _flat_to_multi(ind, sub_shape)
for k1, k2 in zip(dual_ind, m_ind):
sl[k1] = k2
tsl = tuple(sl)

if np.issubdtype(x.dtype, np.complexfloating):
func(x[tuple(sl)], **kwargs, out=result[tuple(sl)])
func(x[tsl], **kwargs, out=result[tsl])
else:
# For c2c FFT, if the input is real, half of the output is the
# complex conjugate of the other half. Instead of upcasting the
Expand All @@ -247,7 +310,7 @@ def _iter_complementary(x, axes, func, kwargs, result):
# array appeared in the second half of the NumPy output array,
# while the equivalent element in the NumPy array was the conjugate
# of the mkl_fft output array.
np.copyto(result[tuple(sl)], func(x[tuple(sl)], **kwargs))
np.copyto(result[tsl], func(x[tsl], **kwargs))

return result

Expand All @@ -260,7 +323,49 @@ def _iter_fftnd(
direction=+1,
scale_function=lambda ind: 1.0,
):
a = np.asarray(a)
"""
Perform N-D FFT as a series of 1-D FFTs along specified axes.

This function implements N-D FFT by applying 1-D FFT iteratively along each
axis. The axes are processed in reverse order to end with the first axis
given.

Parameters
----------
a : ndarray
Input array.
s : sequence of ints, optional
Shape of the FFT output along each axis in `axes`. If not provided, the
shape is inferred from the input array.
Default: ``None``
axes : sequence of ints, optional
Axes along which to compute the FFT. If not provided, all axes are used.
Default: ``None``
out : ndarray, optional
Output array to store the result. Used for in-place operations when
possible.
Default: ``None``
direction : int, optional
FFT direction: ``+1`` for forward FFT, ``-1`` for inverse FFT.
Default: ``+1``
scale_function : callable, optional
Function that takes iteration index and returns the scaling factor for
that step. Used to apply normalization at specific iteration steps.
Default: ``lambda ind: 1.0``

Returns
-------
ndarray
The transformed array.

Notes
-----
The function optimizes memory usage by performing in-place calculations
when possible. In-place operations are used everywhere except when the
array size changes after the first FFT along an axis.

"""

s, axes = _init_nd_shape_and_axes(a, s, axes)

# Combine the two, but in reverse, to end with the first axis given.
Expand Down Expand Up @@ -412,8 +517,20 @@ def _c2c_fftnd_impl(
out=out,
)
else:
x = np.asarray(x)

# Fast path: FFT over no axes is complete identity (preserve dtype)
_, xa = _cook_nd_args(x, s, axes)
if len(xa) == 0:
if out is None:
out = x.copy()
else:
_validate_out_array(out, x, x.dtype)
np.copyto(out, x)
# No scaling applied - identity transform has no normalization
return out

if _complementary and x.dtype in valid_dtypes:
x = np.asarray(x)
if out is None:
res = np.empty_like(x, dtype=_output_dtype(x.dtype))
else:
Expand Down
45 changes: 45 additions & 0 deletions mkl_fft/tests/test_fftnd.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,3 +317,48 @@ def test_out_strided(axes, func):
expected = getattr(np.fft, func)(x, axes=axes, out=out)

assert_allclose(result, expected, strict=True)


@pytest.mark.parametrize(
"dtype", [np.float32, np.float64, np.complex64, np.complex128]
)
@pytest.mark.parametrize("shape", [(3, 4), (5,), (2, 3, 4), (10, 20)])
@pytest.mark.parametrize("norm", [None, "ortho", "forward", "backward"])
@pytest.mark.parametrize("func", ["fftn", "ifftn", "fft2", "ifft2"])
def test_empty_axes(dtype, shape, norm, func):
if np.issubdtype(dtype, np.complexfloating):
x = rnd.random(shape).astype(dtype) + 1j * rnd.random(shape).astype(
dtype
)
else:
x = rnd.random(shape).astype(dtype)

# Test fftn with axes=()
result = getattr(mkl_fft, func)(x, axes=(), norm=norm)
expected = getattr(np.fft, func)(x, axes=(), norm=norm)

rtol, atol = _get_rtol_atol(result)
assert_allclose(result, expected, rtol=rtol, atol=atol, strict=True)


@pytest.mark.parametrize(
"dtype", [np.float32, np.float64, np.complex64, np.complex128]
)
@pytest.mark.parametrize("func", ["fftn", "ifftn", "fft2", "ifft2"])
def test_empty_axes_with_out(dtype, func):
if np.issubdtype(dtype, np.complexfloating):
x = rnd.random((3, 4)).astype(dtype) + 1j * rnd.random((3, 4)).astype(
dtype
)
else:
x = rnd.random((3, 4)).astype(dtype)

# For axes=(), output dtype should match input dtype (identity transform)
out = np.empty_like(x, dtype=dtype)
result = getattr(mkl_fft, func)(x, axes=(), out=out)
expected = getattr(np.fft, func)(x, axes=())

# Result should be written to out
assert result is out
rtol, atol = _get_rtol_atol(result)
assert_allclose(result, expected, rtol=rtol, atol=atol, strict=True)
Loading