diff --git a/src/array_api_extra/_lib/_testing.py b/src/array_api_extra/_lib/_testing.py index cc2306bd..d2a57c86 100644 --- a/src/array_api_extra/_lib/_testing.py +++ b/src/array_api_extra/_lib/_testing.py @@ -37,7 +37,7 @@ def _check_ns_shape_dtype( check_dtype: bool, check_shape: bool, check_scalar: bool, -) -> ModuleType: # numpydoc ignore=RT03 +) -> tuple[Array, Array, ModuleType]: # numpydoc ignore=RT03 """ Assert that namespace, shape and dtype of the two arrays match. @@ -55,24 +55,35 @@ def _check_ns_shape_dtype( Returns ------- - Arrays namespace. + Actual array, desired array, and their namespace. """ - actual_xp = array_namespace(actual) # Raises on scalars and lists + actual_xp = array_namespace(actual) # Raises on Python scalars and lists desired_xp = array_namespace(desired) msg = f"namespaces do not match: {actual_xp} != f{desired_xp}" assert actual_xp == desired_xp, msg + if is_numpy_namespace(actual_xp) and check_scalar: + # only NumPy distinguishes between scalars and arrays; we do if check_scalar. + _msg = ( + "array-ness does not match:\n Actual: " + f"{type(actual)}\n Desired: {type(desired)}" + ) + assert np.isscalar(actual) == np.isscalar(desired), _msg + # Dask uses nan instead of None for unknown shapes actual_shape = cast(tuple[float, ...], actual.shape) desired_shape = cast(tuple[float, ...], desired.shape) assert None not in actual_shape # Requires explicit support assert None not in desired_shape + if is_dask_namespace(desired_xp): if any(math.isnan(i) for i in actual_shape): - actual_shape = actual.compute().shape # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue] + actual.compute_chunk_sizes() # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue] + actual_shape = cast(tuple[float, ...], actual.shape) if any(math.isnan(i) for i in desired_shape): - desired_shape = desired.compute().shape # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue] + desired.compute_chunk_sizes() # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue] + desired_shape = cast(tuple[float, ...], desired.shape) if check_shape: msg = f"shapes do not match: {actual_shape} != f{desired_shape}" @@ -82,24 +93,16 @@ def _check_ns_shape_dtype( # np.testing.assert_array_equal etc even when strict=False, but not for # non-materializable arrays. # This check excludes 0d arrays as they are special-cased in NumPy. - actual_size = math.prod(actual_shape) # pyright: ignore[reportUnknownArgumentType] - desired_size = math.prod(desired_shape) # pyright: ignore[reportUnknownArgumentType] + actual_size = math.prod(actual_shape) + desired_size = math.prod(desired_shape) msg = f"sizes do not match: {actual_size} != f{desired_size}" assert actual_size == desired_size, msg if check_dtype: msg = f"dtypes do not match: {actual.dtype} != {desired.dtype}" assert actual.dtype == desired.dtype, msg - - if is_numpy_namespace(actual_xp) and check_scalar: - # only NumPy distinguishes between scalars and arrays; we do if check_scalar. - _msg = ( - "array-ness does not match:\n Actual: " - f"{type(actual)}\n Desired: {type(desired)}" - ) - assert np.isscalar(actual) == np.isscalar(desired), _msg - - return desired_xp + desired = desired_xp.broadcast_to(desired, actual_shape) + return actual, desired, desired_xp def _is_materializable(x: Array) -> bool: @@ -169,7 +172,9 @@ def xp_assert_equal( xp_assert_close : Similar function for inexact equality checks. numpy.testing.assert_array_equal : Similar function for NumPy arrays. """ - xp = _check_ns_shape_dtype(actual, desired, check_dtype, check_shape, check_scalar) + actual, desired, xp = _check_ns_shape_dtype( + actual, desired, check_dtype, check_shape, check_scalar + ) if not _is_materializable(actual): return actual_np = as_numpy_array(actual, xp=xp) @@ -211,7 +216,7 @@ def xp_assert_less( xp_assert_close : Similar function for inexact equality checks. numpy.testing.assert_array_equal : Similar function for NumPy arrays. """ - xp = _check_ns_shape_dtype(x, y, check_dtype, check_shape, check_scalar) + x, y, xp = _check_ns_shape_dtype(x, y, check_dtype, check_shape, check_scalar) if not _is_materializable(x): return x_np = as_numpy_array(x, xp=xp) @@ -267,7 +272,9 @@ def xp_assert_close( ----- The default `atol` and `rtol` differ from `xp.all(xpx.isclose(a, b))`. """ - xp = _check_ns_shape_dtype(actual, desired, check_dtype, check_shape, check_scalar) + actual, desired, xp = _check_ns_shape_dtype( + actual, desired, check_dtype, check_shape, check_scalar + ) if not _is_materializable(actual): return