From 68ae214980664803c0799c198c99468bacc2ba02 Mon Sep 17 00:00:00 2001 From: chillenb Date: Mon, 2 Mar 2026 21:21:11 -0500 Subject: [PATCH 1/6] take fast path if c2c transform does not need padding or trimming --- mkl_fft/_fft_utils.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/mkl_fft/_fft_utils.py b/mkl_fft/_fft_utils.py index ad6a055..6bec3ea 100644 --- a/mkl_fft/_fft_utils.py +++ b/mkl_fft/_fft_utils.py @@ -62,6 +62,17 @@ def _check_shapes_for_direct(xs, shape, axes): return False return True +def _check_shapes_equiv_s_none(s, shape, axes): + for si, ai in zip(s, axes): + try: + sh_ai = shape[ai] + except IndexError: + raise ValueError("Invalid axis (%d) specified" % ai) + + if si != sh_ai: + return False + return True + def _compute_fwd_scale(norm, n, shape): _check_norm(norm) @@ -382,6 +393,7 @@ def _c2c_fftnd_impl( if direction not in [-1, +1]: raise ValueError("Direction of FFT should +1 or -1") + s_equiv_to_none = s is None valid_dtypes = [np.complex64, np.complex128, np.float32, np.float64] # _direct_fftnd requires complex type, and full-dimensional transform if isinstance(x, np.ndarray) and x.size != 0 and x.ndim > 1: @@ -392,6 +404,10 @@ def _c2c_fftnd_impl( xs, xa = _cook_nd_args(x, s, axes) if _check_shapes_for_direct(xs, x.shape, xa): _direct = True + # See if s matches the shape of x along the given axes. + # If it does, we can use _iter_complementary rather than _iter_fftnd. + if _check_shapes_equiv_s_none(xs, x.shape, xa): + s_equiv_to_none = True _direct = _direct and x.dtype in valid_dtypes else: _direct = False @@ -404,7 +420,7 @@ def _c2c_fftnd_impl( out=out, ) else: - if s is None and x.dtype in valid_dtypes: + if s_equiv_to_none and x.dtype in valid_dtypes: x = np.asarray(x) if out is None: res = np.empty_like(x, dtype=_output_dtype(x.dtype)) From 1c6c6f88bd2a250552b1517fbca10f8be31fa35a Mon Sep 17 00:00:00 2001 From: Christopher Hillenbrand Date: Mon, 2 Mar 2026 22:20:00 -0500 Subject: [PATCH 2/6] Satisfy linter --- mkl_fft/_fft_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/mkl_fft/_fft_utils.py b/mkl_fft/_fft_utils.py index 6bec3ea..1762db7 100644 --- a/mkl_fft/_fft_utils.py +++ b/mkl_fft/_fft_utils.py @@ -62,6 +62,7 @@ def _check_shapes_for_direct(xs, shape, axes): return False return True + def _check_shapes_equiv_s_none(s, shape, axes): for si, ai in zip(s, axes): try: From c56754e21ced6f16af0db14a860a5748c814096e Mon Sep 17 00:00:00 2001 From: chillenb Date: Tue, 3 Mar 2026 11:19:38 -0500 Subject: [PATCH 3/6] add test for s=None vs equivalent s --- mkl_fft/tests/test_fftnd.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/mkl_fft/tests/test_fftnd.py b/mkl_fft/tests/test_fftnd.py index 4f04a11..e52163a 100644 --- a/mkl_fft/tests/test_fftnd.py +++ b/mkl_fft/tests/test_fftnd.py @@ -264,6 +264,30 @@ def test_s_axes_out(dtype, s, axes, func): assert_allclose(r1, r2, rtol=rtol, atol=atol) +@requires_numpy_2 +@pytest.mark.parametrize("dtype", [complex, float]) +@pytest.mark.parametrize("axes", [(0, 1, 2), (-1, -2, -3), [1, 0, 2]]) +@pytest.mark.parametrize("func", ["fftn", "ifftn", "rfftn"]) +def test_s_none_vs_s_full(dtype, axes, func): + shape = (30, 20, 10) + if dtype is complex and func != "rfftn": + x = np.random.random(shape) + 1j * np.random.random(shape) + else: + x = np.random.random(shape) + + implied_s = [shape[ax] for ax in axes] + if func == "irfftn": + implied_s[-1] = 2 * (implied_s[-1] - 1) + + r1 = getattr(np.fft, func)(x, axes=axes) + r2 = getattr(mkl_fft, func)(x, axes=axes) + r3 = getattr(mkl_fft, func)(x, s=implied_s, axes=axes) + + rtol, atol = _get_rtol_atol(x) + assert_allclose(r1, r2, rtol=rtol, atol=atol) + assert_allclose(r1, r3, rtol=rtol, atol=atol) + + @pytest.mark.parametrize("dtype", [complex, float]) @pytest.mark.parametrize("axes", [(2, 0, 2, 0), (0, 1, 1), (2, 0, 1, 3, 2, 1)]) @pytest.mark.parametrize("func", ["rfftn", "irfftn"]) From 2b05835d7ddc6c1b874a3d02927b0f01062d7248 Mon Sep 17 00:00:00 2001 From: chillenb Date: Fri, 6 Mar 2026 11:04:43 -0500 Subject: [PATCH 4/6] make sure test_s_none_vs_s_full actually uses iter_complementary --- mkl_fft/tests/test_fftnd.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mkl_fft/tests/test_fftnd.py b/mkl_fft/tests/test_fftnd.py index e52163a..c1808aa 100644 --- a/mkl_fft/tests/test_fftnd.py +++ b/mkl_fft/tests/test_fftnd.py @@ -266,10 +266,10 @@ def test_s_axes_out(dtype, s, axes, func): @requires_numpy_2 @pytest.mark.parametrize("dtype", [complex, float]) -@pytest.mark.parametrize("axes", [(0, 1, 2), (-1, -2, -3), [1, 0, 2]]) +@pytest.mark.parametrize("axes", [(1, 2, 3), (-1, -2, -3), [2, 1, 3]]) @pytest.mark.parametrize("func", ["fftn", "ifftn", "rfftn"]) def test_s_none_vs_s_full(dtype, axes, func): - shape = (30, 20, 10) + shape = (2, 30, 20, 10) if dtype is complex and func != "rfftn": x = np.random.random(shape) + 1j * np.random.random(shape) else: From 98b19648acdcddae863551b64028262477c47ee2 Mon Sep 17 00:00:00 2001 From: chillenb Date: Fri, 6 Mar 2026 11:10:10 -0500 Subject: [PATCH 5/6] rename s_equiv_to_none and consolidate shape-checking logic --- mkl_fft/_fft_utils.py | 29 ++++++++++------------------- 1 file changed, 10 insertions(+), 19 deletions(-) diff --git a/mkl_fft/_fft_utils.py b/mkl_fft/_fft_utils.py index 1762db7..c0fda37 100644 --- a/mkl_fft/_fft_utils.py +++ b/mkl_fft/_fft_utils.py @@ -43,11 +43,12 @@ def _check_norm(norm): ) -def _check_shapes_for_direct(xs, shape, axes): +def _check_shapes_for_direct(xs, shape, axes, check_complimentary=False): if len(axes) > 7: # Intel MKL supports up to 7D return False - if not (len(xs) == len(shape)): - # full-dimensional transform + if not (len(xs) == len(shape)) and not check_complimentary: + # full-dimensional transform is required for direct, + # but less than full is OK for complimentary. return False if not (len(set(axes)) == len(axes)): # repeated axes @@ -63,18 +64,6 @@ def _check_shapes_for_direct(xs, shape, axes): return True -def _check_shapes_equiv_s_none(s, shape, axes): - for si, ai in zip(s, axes): - try: - sh_ai = shape[ai] - except IndexError: - raise ValueError("Invalid axis (%d) specified" % ai) - - if si != sh_ai: - return False - return True - - def _compute_fwd_scale(norm, n, shape): _check_norm(norm) if norm in (None, "backward"): @@ -394,7 +383,7 @@ def _c2c_fftnd_impl( if direction not in [-1, +1]: raise ValueError("Direction of FFT should +1 or -1") - s_equiv_to_none = s is None + _complementary = s is None valid_dtypes = [np.complex64, np.complex128, np.float32, np.float64] # _direct_fftnd requires complex type, and full-dimensional transform if isinstance(x, np.ndarray) and x.size != 0 and x.ndim > 1: @@ -407,8 +396,10 @@ def _c2c_fftnd_impl( _direct = True # See if s matches the shape of x along the given axes. # If it does, we can use _iter_complementary rather than _iter_fftnd. - if _check_shapes_equiv_s_none(xs, x.shape, xa): - s_equiv_to_none = True + if _check_shapes_for_direct( + xs, x.shape, xa, check_complimentary=True + ): + _complementary = True _direct = _direct and x.dtype in valid_dtypes else: _direct = False @@ -421,7 +412,7 @@ def _c2c_fftnd_impl( out=out, ) else: - if s_equiv_to_none and x.dtype in valid_dtypes: + if _complementary and x.dtype in valid_dtypes: x = np.asarray(x) if out is None: res = np.empty_like(x, dtype=_output_dtype(x.dtype)) From 4a66ad5ea5b93a2e90d99a91d5a1da0c5ab0a002 Mon Sep 17 00:00:00 2001 From: chillenb Date: Fri, 6 Mar 2026 11:34:08 -0500 Subject: [PATCH 6/6] add this to changelog --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index d20d4ef..8565870 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added * Added `mkl_fft` patching for NumPy, with `mkl_fft` context manager, `is_patched` query, and `patch_numpy_fft` and `restore_numpy_fft` calls to replace `numpy.fft` calls with calls from `mkl_fft.interfaces.numpy_fft` [gh-224](https://github.com/IntelPython/mkl_fft/pull/224) +### Changed +* In `mkl_fft.fftn` and `mkl_fft.ifftn`, improved checking of the shape argument `s` to use faster direct transforms more often. This makes performance more consistent between `mkl_fft.fftn/ifftn` and `mkl.interfaces`. [gh-283](https://github.com/IntelPython/mkl_fft/pull/283) + ### Removed * Dropped support for Python 3.9 [gh-243](https://github.com/IntelPython/mkl_fft/pull/243)