diff --git a/CHANGELOG.md b/CHANGELOG.md index d20d4ef..8565870 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added * Added `mkl_fft` patching for NumPy, with `mkl_fft` context manager, `is_patched` query, and `patch_numpy_fft` and `restore_numpy_fft` calls to replace `numpy.fft` calls with calls from `mkl_fft.interfaces.numpy_fft` [gh-224](https://github.com/IntelPython/mkl_fft/pull/224) +### Changed +* In `mkl_fft.fftn` and `mkl_fft.ifftn`, improved checking of the shape argument `s` to use faster direct transforms more often. This makes performance more consistent between `mkl_fft.fftn/ifftn` and `mkl_fft.interfaces`. [gh-283](https://github.com/IntelPython/mkl_fft/pull/283) + ### Removed * Dropped support for Python 3.9 [gh-243](https://github.com/IntelPython/mkl_fft/pull/243) diff --git a/mkl_fft/_fft_utils.py b/mkl_fft/_fft_utils.py index ad6a055..c0fda37 100644 --- a/mkl_fft/_fft_utils.py +++ b/mkl_fft/_fft_utils.py @@ -43,11 +43,12 @@ def _check_norm(norm): ) -def _check_shapes_for_direct(xs, shape, axes): +def _check_shapes_for_direct(xs, shape, axes, check_complimentary=False): if len(axes) > 7: # Intel MKL supports up to 7D return False - if not (len(xs) == len(shape)): - # full-dimensional transform + if not (len(xs) == len(shape)) and not check_complimentary: + # full-dimensional transform is required for direct, + # but less than full is OK for complementary. 
return False if not (len(set(axes)) == len(axes)): # repeated axes @@ -382,6 +383,7 @@ def _c2c_fftnd_impl( if direction not in [-1, +1]: raise ValueError("Direction of FFT should +1 or -1") + _complementary = s is None valid_dtypes = [np.complex64, np.complex128, np.float32, np.float64] # _direct_fftnd requires complex type, and full-dimensional transform if isinstance(x, np.ndarray) and x.size != 0 and x.ndim > 1: @@ -392,6 +394,12 @@ def _c2c_fftnd_impl( xs, xa = _cook_nd_args(x, s, axes) if _check_shapes_for_direct(xs, x.shape, xa): _direct = True + # See if s matches the shape of x along the given axes. + # If it does, we can use _iter_complementary rather than _iter_fftnd. + if _check_shapes_for_direct( + xs, x.shape, xa, check_complimentary=True + ): + _complementary = True _direct = _direct and x.dtype in valid_dtypes else: _direct = False @@ -404,7 +412,7 @@ def _c2c_fftnd_impl( out=out, ) else: - if s is None and x.dtype in valid_dtypes: + if _complementary and x.dtype in valid_dtypes: x = np.asarray(x) if out is None: res = np.empty_like(x, dtype=_output_dtype(x.dtype)) diff --git a/mkl_fft/tests/test_fftnd.py b/mkl_fft/tests/test_fftnd.py index 4f04a11..c1808aa 100644 --- a/mkl_fft/tests/test_fftnd.py +++ b/mkl_fft/tests/test_fftnd.py @@ -264,6 +264,30 @@ def test_s_axes_out(dtype, s, axes, func): assert_allclose(r1, r2, rtol=rtol, atol=atol) +@requires_numpy_2 +@pytest.mark.parametrize("dtype", [complex, float]) +@pytest.mark.parametrize("axes", [(1, 2, 3), (-1, -2, -3), [2, 1, 3]]) +@pytest.mark.parametrize("func", ["fftn", "ifftn", "rfftn"]) +def test_s_none_vs_s_full(dtype, axes, func): + shape = (2, 30, 20, 10) + if dtype is complex and func != "rfftn": + x = np.random.random(shape) + 1j * np.random.random(shape) + else: + x = np.random.random(shape) + + implied_s = [shape[ax] for ax in axes] + if func == "irfftn": + implied_s[-1] = 2 * (implied_s[-1] - 1) + + r1 = getattr(np.fft, func)(x, axes=axes) + r2 = getattr(mkl_fft, func)(x, axes=axes) + r3 
= getattr(mkl_fft, func)(x, s=implied_s, axes=axes) + + rtol, atol = _get_rtol_atol(x) + assert_allclose(r1, r2, rtol=rtol, atol=atol) + assert_allclose(r1, r3, rtol=rtol, atol=atol) + + @pytest.mark.parametrize("dtype", [complex, float]) @pytest.mark.parametrize("axes", [(2, 0, 2, 0), (0, 1, 1), (2, 0, 1, 3, 2, 1)]) @pytest.mark.parametrize("func", ["rfftn", "irfftn"])