Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added
* Added `mkl_fft` patching for NumPy, with `mkl_fft` context manager, `is_patched` query, and `patch_numpy_fft` and `restore_numpy_fft` calls to replace `numpy.fft` calls with calls from `mkl_fft.interfaces.numpy_fft` [gh-224](https://github.com/IntelPython/mkl_fft/pull/224)

### Changed
* In `mkl_fft.fftn` and `mkl_fft.ifftn`, improved checking of the shape argument `s` to use faster direct transforms more often. This makes performance more consistent between `mkl_fft.fftn/ifftn` and `mkl.interfaces`. [gh-283](https://github.com/IntelPython/mkl_fft/pull/283)

### Removed
* Dropped support for Python 3.9 [gh-243](https://github.com/IntelPython/mkl_fft/pull/243)

Expand Down
16 changes: 12 additions & 4 deletions mkl_fft/_fft_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,12 @@ def _check_norm(norm):
)


def _check_shapes_for_direct(xs, shape, axes):
def _check_shapes_for_direct(xs, shape, axes, check_complimentary=False):
if len(axes) > 7: # Intel MKL supports up to 7D
return False
if not (len(xs) == len(shape)):
# full-dimensional transform
if not (len(xs) == len(shape)) and not check_complimentary:
# full-dimensional transform is required for direct,
# but less than full is OK for complimentary.
return False
if not (len(set(axes)) == len(axes)):
# repeated axes
Expand Down Expand Up @@ -382,6 +383,7 @@ def _c2c_fftnd_impl(
if direction not in [-1, +1]:
raise ValueError("Direction of FFT should +1 or -1")

_complementary = s is None
valid_dtypes = [np.complex64, np.complex128, np.float32, np.float64]
# _direct_fftnd requires complex type, and full-dimensional transform
if isinstance(x, np.ndarray) and x.size != 0 and x.ndim > 1:
Expand All @@ -392,6 +394,12 @@ def _c2c_fftnd_impl(
xs, xa = _cook_nd_args(x, s, axes)
if _check_shapes_for_direct(xs, x.shape, xa):
_direct = True
# See if s matches the shape of x along the given axes.
# If it does, we can use _iter_complementary rather than _iter_fftnd.
if _check_shapes_for_direct(
xs, x.shape, xa, check_complimentary=True
):
_complementary = True
_direct = _direct and x.dtype in valid_dtypes
else:
_direct = False
Expand All @@ -404,7 +412,7 @@ def _c2c_fftnd_impl(
out=out,
)
else:
if s is None and x.dtype in valid_dtypes:
if _complementary and x.dtype in valid_dtypes:
x = np.asarray(x)
if out is None:
res = np.empty_like(x, dtype=_output_dtype(x.dtype))
Expand Down
24 changes: 24 additions & 0 deletions mkl_fft/tests/test_fftnd.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,30 @@ def test_s_axes_out(dtype, s, axes, func):
assert_allclose(r1, r2, rtol=rtol, atol=atol)


@requires_numpy_2
@pytest.mark.parametrize("dtype", [complex, float])
@pytest.mark.parametrize("axes", [(1, 2, 3), (-1, -2, -3), [2, 1, 3]])
@pytest.mark.parametrize("func", ["fftn", "ifftn", "rfftn"])
def test_s_none_vs_s_full(dtype, axes, func):
shape = (2, 30, 20, 10)
if dtype is complex and func != "rfftn":
x = np.random.random(shape) + 1j * np.random.random(shape)
else:
x = np.random.random(shape)

implied_s = [shape[ax] for ax in axes]
if func == "irfftn":
implied_s[-1] = 2 * (implied_s[-1] - 1)

r1 = getattr(np.fft, func)(x, axes=axes)
r2 = getattr(mkl_fft, func)(x, axes=axes)
r3 = getattr(mkl_fft, func)(x, s=implied_s, axes=axes)

rtol, atol = _get_rtol_atol(x)
assert_allclose(r1, r2, rtol=rtol, atol=atol)
assert_allclose(r1, r3, rtol=rtol, atol=atol)


@pytest.mark.parametrize("dtype", [complex, float])
@pytest.mark.parametrize("axes", [(2, 0, 2, 0), (0, 1, 1), (2, 0, 1, 3, 2, 1)])
@pytest.mark.parametrize("func", ["rfftn", "irfftn"])
Expand Down
Loading