diff --git a/dpctl_ext/tensor/CMakeLists.txt b/dpctl_ext/tensor/CMakeLists.txt index 864e34ddaba4..cf55035c23d9 100644 --- a/dpctl_ext/tensor/CMakeLists.txt +++ b/dpctl_ext/tensor/CMakeLists.txt @@ -63,6 +63,46 @@ set(_tensor_impl_sources ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/repeat.cpp ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/clip.cpp ) +set(_accumulator_sources + ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/accumulators/accumulators_common.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/accumulators/cumulative_logsumexp.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/accumulators/cumulative_prod.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/accumulators/cumulative_sum.cpp +) +set(_reduction_sources + ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/reductions/reduction_common.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/reductions/all.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/reductions/any.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/reductions/argmax.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/reductions/argmin.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/reductions/logsumexp.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/reductions/max.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/reductions/min.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/reductions/prod.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/reductions/reduce_hypot.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/reductions/sum.cpp +) +set(_sorting_sources + ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/isin.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/merge_sort.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/merge_argsort.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/radix_sort.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/radix_argsort.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/searchsorted.cpp + 
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/topk.cpp +) +set(_tensor_accumulation_impl_sources + ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/tensor_accumulation.cpp + ${_accumulator_sources} +) +set(_tensor_reductions_impl_sources + ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/tensor_reductions.cpp + ${_reduction_sources} +) +set(_tensor_sorting_impl_sources + ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/tensor_sorting.cpp + ${_sorting_sources} +) set(_static_lib_trgt simplify_iteration_space) @@ -85,6 +125,24 @@ add_sycl_to_target(TARGET ${python_module_name} SOURCES ${_tensor_impl_sources}) target_link_libraries(${python_module_name} PRIVATE ${_static_lib_trgt}) list(APPEND _py_trgts ${python_module_name}) +set(python_module_name _tensor_accumulation_impl) +pybind11_add_module(${python_module_name} MODULE ${_tensor_accumulation_impl_sources}) +add_sycl_to_target(TARGET ${python_module_name} SOURCES ${_tensor_accumulation_impl_sources}) +target_link_libraries(${python_module_name} PRIVATE ${_static_lib_trgt}) +list(APPEND _py_trgts ${python_module_name}) + +set(python_module_name _tensor_reductions_impl) +pybind11_add_module(${python_module_name} MODULE ${_tensor_reductions_impl_sources}) +add_sycl_to_target(TARGET ${python_module_name} SOURCES ${_tensor_reductions_impl_sources}) +target_link_libraries(${python_module_name} PRIVATE ${_static_lib_trgt}) +list(APPEND _py_trgts ${python_module_name}) + +set(python_module_name _tensor_sorting_impl) +pybind11_add_module(${python_module_name} MODULE ${_tensor_sorting_impl_sources}) +add_sycl_to_target(TARGET ${python_module_name} SOURCES ${_tensor_sorting_impl_sources}) +target_link_libraries(${python_module_name} PRIVATE ${_static_lib_trgt}) +list(APPEND _py_trgts ${python_module_name}) + set(_clang_prefix "") if(WIN32) set(_clang_prefix "/clang:") @@ -97,14 +155,14 @@ set(_no_fast_math_sources ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/clip.cpp ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/where.cpp ) 
-#list( -#APPEND _no_fast_math_sources -# ${_elementwise_sources} -# ${_reduction_sources} -# ${_sorting_sources} -# ${_linalg_sources} -# ${_accumulator_sources} -#) +list( + APPEND _no_fast_math_sources + # ${_elementwise_sources} + ${_reduction_sources} + ${_sorting_sources} + # ${_linalg_sources} + ${_accumulator_sources} +) foreach(_src_fn ${_no_fast_math_sources}) get_source_file_property(_cmpl_options_prop ${_src_fn} COMPILE_OPTIONS) diff --git a/dpctl_ext/tensor/__init__.py b/dpctl_ext/tensor/__init__.py index 9d4013e146a7..ac24151bedfe 100644 --- a/dpctl_ext/tensor/__init__.py +++ b/dpctl_ext/tensor/__init__.py @@ -27,6 +27,7 @@ # ***************************************************************************** +from ._accumulation import cumulative_logsumexp, cumulative_prod, cumulative_sum from ._clip import clip from ._copy_utils import ( asnumpy, @@ -77,12 +78,38 @@ tile, unstack, ) +from ._reduction import ( + argmax, + argmin, + count_nonzero, + logsumexp, + max, + min, + prod, + reduce_hypot, + sum, +) from ._reshape import reshape from ._search_functions import where +from ._searchsorted import searchsorted +from ._set_functions import ( + isin, + unique_all, + unique_counts, + unique_inverse, + unique_values, +) +from ._sorting import argsort, sort, top_k from ._type_utils import can_cast, finfo, iinfo, isdtype, result_type +from ._utility_functions import all, any, diff __all__ = [ + "all", + "any", "arange", + "argmax", + "argmin", + "argsort", "asarray", "asnumpy", "astype", @@ -91,7 +118,12 @@ "can_cast", "concat", "copy", + "count_nonzero", "clip", + "cumulative_logsumexp", + "cumulative_prod", + "cumulative_sum", + "diff", "empty", "empty_like", "extract", @@ -104,29 +136,43 @@ "full_like", "iinfo", "isdtype", + "isin", "linspace", + "logsumexp", + "max", "meshgrid", + "min", "moveaxis", "permute_dims", "nonzero", "ones", "ones_like", "place", + "prod", "put", "put_along_axis", + "reduce_hypot", "repeat", "reshape", "result_type", "roll", + 
"searchsorted", + "sort", "squeeze", "stack", + "sum", "swapaxes", "take", "take_along_axis", "tile", + "top_k", "to_numpy", "tril", "triu", + "unique_all", + "unique_counts", + "unique_inverse", + "unique_values", "unstack", "where", "zeros", diff --git a/dpctl_ext/tensor/_accumulation.py b/dpctl_ext/tensor/_accumulation.py new file mode 100644 index 000000000000..2dfe9656e198 --- /dev/null +++ b/dpctl_ext/tensor/_accumulation.py @@ -0,0 +1,470 @@ +# ***************************************************************************** +# Copyright (c) 2026, Intel Corporation +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# - Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# - Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# - Neither the name of the copyright holder nor the names of its contributors +# may be used to endorse or promote products derived from this software +# without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +# THE POSSIBILITY OF SUCH DAMAGE. +# ***************************************************************************** + +import dpctl +import dpctl.tensor as dpt +from dpctl.utils import ExecutionPlacementError, SequentialOrderManager + +# TODO: revert to `import dpctl.tensor...` +# when dpnp fully migrates dpctl/tensor +import dpctl_ext.tensor as dpt_ext +import dpctl_ext.tensor._tensor_accumulation_impl as tai +import dpctl_ext.tensor._tensor_impl as ti + +from ._numpy_helper import normalize_axis_index +from ._type_utils import ( + _default_accumulation_dtype, + _default_accumulation_dtype_fp_types, + _to_device_supported_dtype, +) + + +def _accumulate_common( + x, + axis, + dtype, + include_initial, + out, + _accumulate_fn, + _accumulate_include_initial_fn, + _dtype_supported, + _default_accumulation_type_fn, +): + if not isinstance(x, dpt.usm_ndarray): + raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x)}") + appended_axis = False + if x.ndim == 0: + x = x[dpt.newaxis] + appended_axis = True + nd = x.ndim + if axis is None: + if nd > 1: + raise ValueError( + "`axis` cannot be `None` for array of dimension `{}`".format(nd) + ) + axis = 0 + else: + axis = normalize_axis_index(axis, nd, "axis") + sh = x.shape + res_sh = ( + sh[:axis] + (sh[axis] + 1,) + sh[axis + 1 :] if include_initial else sh + ) + a1 = axis + 1 + if a1 == nd: + perm = list(range(nd)) + arr = x + else: + perm = [i for i in range(nd) if i != axis] + [ + axis, + ] + arr = 
dpt_ext.permute_dims(x, perm) + q = x.sycl_queue + inp_dt = x.dtype + res_usm_type = x.usm_type + if dtype is None: + res_dt = _default_accumulation_type_fn(inp_dt, q) + else: + res_dt = dpt.dtype(dtype) + res_dt = _to_device_supported_dtype(res_dt, q.sycl_device) + + # checking now avoids unnecessary allocations + implemented_types = _dtype_supported(inp_dt, res_dt) + if dtype is None and not implemented_types: + raise RuntimeError( + "Automatically determined accumulation data type does not " + "have direct implementation" + ) + orig_out = out + if out is not None: + if not isinstance(out, dpt.usm_ndarray): + raise TypeError( + f"output array must be of usm_ndarray type, got {type(out)}" + ) + if not out.flags.writable: + raise ValueError("provided `out` array is read-only") + out_sh = out.shape + # append an axis to `out` if scalar + if appended_axis and not include_initial: + out = out[dpt.newaxis, ...] + orig_out = out + final_res_sh = res_sh[1:] + else: + final_res_sh = res_sh + if not out_sh == final_res_sh: + raise ValueError( + "The shape of input and output arrays are inconsistent. 
" + f"Expected output shape is {final_res_sh}, got {out_sh}" + ) + if res_dt != out.dtype: + raise ValueError( + f"Output array of type {res_dt} is needed, " f"got {out.dtype}" + ) + if dpctl.utils.get_execution_queue((q, out.sycl_queue)) is None: + raise ExecutionPlacementError( + "Input and output allocation queues are not compatible" + ) + # permute out array dims if necessary + if a1 != nd: + out = dpt_ext.permute_dims(out, perm) + orig_out = out + if ti._array_overlap(x, out) and implemented_types: + out = dpt_ext.empty_like(out) + else: + out = dpt_ext.empty( + res_sh, dtype=res_dt, usm_type=res_usm_type, sycl_queue=q + ) + if a1 != nd: + out = dpt_ext.permute_dims(out, perm) + + _manager = SequentialOrderManager[q] + depends = _manager.submitted_events + if implemented_types: + if not include_initial: + ht_e, acc_ev = _accumulate_fn( + src=arr, + trailing_dims_to_accumulate=1, + dst=out, + sycl_queue=q, + depends=depends, + ) + else: + ht_e, acc_ev = _accumulate_include_initial_fn( + src=arr, dst=out, sycl_queue=q, depends=depends + ) + _manager.add_event_pair(ht_e, acc_ev) + if not (orig_out is None or out is orig_out): + # Copy the out data from temporary buffer to original memory + ht_e_cpy, cpy_e = ti._copy_usm_ndarray_into_usm_ndarray( + src=out, dst=orig_out, sycl_queue=q, depends=[acc_ev] + ) + _manager.add_event_pair(ht_e_cpy, cpy_e) + out = orig_out + else: + if _dtype_supported(res_dt, res_dt): + tmp = dpt_ext.empty( + arr.shape, dtype=res_dt, usm_type=res_usm_type, sycl_queue=q + ) + ht_e_cpy, cpy_e = ti._copy_usm_ndarray_into_usm_ndarray( + src=arr, dst=tmp, sycl_queue=q, depends=depends + ) + _manager.add_event_pair(ht_e_cpy, cpy_e) + if not include_initial: + ht_e, acc_ev = _accumulate_fn( + src=tmp, + trailing_dims_to_accumulate=1, + dst=out, + sycl_queue=q, + depends=[cpy_e], + ) + else: + ht_e, acc_ev = _accumulate_include_initial_fn( + src=tmp, + dst=out, + sycl_queue=q, + depends=[cpy_e], + ) + _manager.add_event_pair(ht_e, acc_ev) + else: 
+ buf_dt = _default_accumulation_type_fn(inp_dt, q) + tmp = dpt_ext.empty( + arr.shape, dtype=buf_dt, usm_type=res_usm_type, sycl_queue=q + ) + ht_e_cpy, cpy_e = ti._copy_usm_ndarray_into_usm_ndarray( + src=arr, dst=tmp, sycl_queue=q, depends=depends + ) + _manager.add_event_pair(ht_e_cpy, cpy_e) + tmp_res = dpt_ext.empty( + res_sh, dtype=buf_dt, usm_type=res_usm_type, sycl_queue=q + ) + if a1 != nd: + tmp_res = dpt_ext.permute_dims(tmp_res, perm) + if not include_initial: + ht_e, acc_ev = _accumulate_fn( + src=tmp, + trailing_dims_to_accumulate=1, + dst=tmp_res, + sycl_queue=q, + depends=[cpy_e], + ) + else: + ht_e, acc_ev = _accumulate_include_initial_fn( + src=tmp, + dst=tmp_res, + sycl_queue=q, + depends=[cpy_e], + ) + _manager.add_event_pair(ht_e, acc_ev) + ht_e_cpy2, cpy_e2 = ti._copy_usm_ndarray_into_usm_ndarray( + src=tmp_res, dst=out, sycl_queue=q, depends=[acc_ev] + ) + _manager.add_event_pair(ht_e_cpy2, cpy_e2) + + if appended_axis: + out = dpt_ext.squeeze(out) + if a1 != nd: + inv_perm = sorted(range(nd), key=lambda d: perm[d]) + out = dpt_ext.permute_dims(out, inv_perm) + + return out + + +def cumulative_sum( + x, /, *, axis=None, dtype=None, include_initial=False, out=None +): + """ + cumulative_sum(x, /, *, axis=None, dtype=None, include_initial=False, + out=None) + + Calculates the cumulative sum of elements in the input array `x`. + + Args: + x (usm_ndarray): + input array. + axis (Optional[int]): + axis along which cumulative sum must be computed. + If `None`, the sum is computed over the entire array. + If `x` is a one-dimensional array, providing an `axis` is optional; + however, if `x` has more than one dimension, providing an `axis` + is required. + Default: `None`. + dtype (Optional[dtype]): + data type of the returned array. If `None`, the default data + type is inferred from the "kind" of the input array data type. 
+ + * If `x` has a real- or complex-valued floating-point data + type, the returned array will have the same data type as + `x`. + * If `x` has signed integral data type, the returned array + will have the default signed integral type for the device + where input array `x` is allocated. + * If `x` has unsigned integral data type, the returned array + will have the default unsigned integral type for the device + where input array `x` is allocated. + * If `x` has a boolean data type, the returned array will + have the default signed integral type for the device + where input array `x` is allocated. + + If the data type (either specified or resolved) differs from the + data type of `x`, the input array elements are cast to the + specified data type before computing the cumulative sum. + Default: `None`. + include_initial (bool): + boolean indicating whether to include the initial value (i.e., the + additive identity, zero) as the first value along the provided axis + in the output. Default: `False`. + out (Optional[usm_ndarray]): + the array into which the result is written. + The data type of `out` must match the expected shape and the + expected data type of the result or (if provided) `dtype`. + If `None` then a new array is returned. Default: `None`. + + Returns: + usm_ndarray: + an array containing cumulative sums. The returned array has the data + type as described in the `dtype` parameter description above. + + The returned array shape is determined as follows: + + * If `include_initial` is `False`, the returned array will + have the same shape as `x` + * If `include_initial` is `True`, the returned array will + have the same shape as `x` except the axis along which the + cumulative sum is calculated, which will have size `N+1` + + where `N` is the size of the axis the cumulative sums are computed + along. 
+ """ + return _accumulate_common( + x, + axis, + dtype, + include_initial, + out, + tai._cumsum_over_axis, + tai._cumsum_final_axis_include_initial, + tai._cumsum_dtype_supported, + _default_accumulation_dtype, + ) + + +def cumulative_prod( + x, /, *, axis=None, dtype=None, include_initial=False, out=None +): + """ + cumulative_prod(x, /, *, axis=None, dtype=None, include_initial=False, + out=None) + + Calculates the cumulative product of elements in the input array `x`. + + Args: + x (usm_ndarray): + input array. + axis (Optional[int]): + axis along which cumulative product must be computed. + If `None`, the product is computed over the entire array. + If `x` is a one-dimensional array, providing an `axis` is optional; + however, if `x` has more than one dimension, providing an `axis` + is required. + Default: `None`. + dtype (Optional[dtype]): + data type of the returned array. If `None`, the default data + type is inferred from the "kind" of the input array data type. + + * If `x` has a real- or complex-valued floating-point data + type, the returned array will have the same data type as + `x`. + * If `x` has signed integral data type, the returned array + will have the default signed integral type for the device + where input array `x` is allocated. + * If `x` has unsigned integral data type, the returned array + will have the default unsigned integral type for the device + where input array `x` is allocated. + * If `x` has a boolean data type, the returned array will + have the default signed integral type for the device + where input array `x` is allocated. + + If the data type (either specified or resolved) differs from the + data type of `x`, the input array elements are cast to the + specified data type before computing the cumulative product. + Default: `None`. + include_initial (bool): + boolean indicating whether to include the initial value (i.e., the + multiplicative identity, one) as the first value along the provided + axis in the output. 
Default: `False`. + out (Optional[usm_ndarray]): + the array into which the result is written. + The data type of `out` must match the expected shape and the + expected data type of the result or (if provided) `dtype`. + If `None` then a new array is returned. Default: `None`. + + Returns: + usm_ndarray: + an array containing cumulative products. The returned array has + the data type as described in the `dtype` parameter description + above. + + The returned array shape is determined as follows: + + * If `include_initial` is `False`, the returned array will + have the same shape as `x` + * If `include_initial` is `True`, the returned array will + have the same shape as `x` except the axis along which the + cumulative product is calculated, which will have size `N+1` + + where `N` is the size of the axis the cumulative products are + computed along. + """ + return _accumulate_common( + x, + axis, + dtype, + include_initial, + out, + tai._cumprod_over_axis, + tai._cumprod_final_axis_include_initial, + tai._cumprod_dtype_supported, + _default_accumulation_dtype, + ) + + +def cumulative_logsumexp( + x, /, *, axis=None, dtype=None, include_initial=False, out=None +): + """ + cumulative_logsumexp(x, /, *, axis=None, dtype=None, include_initial=False, + out=None) + + Calculates the cumulative logsumexp of elements in the input array `x`. + + Args: + x (usm_ndarray): + input array. + axis (Optional[int]): + axis along which cumulative logsumexp must be computed. + If `None`, the logsumexp is computed over the entire array. + If `x` is a one-dimensional array, providing an `axis` is optional; + however, if `x` has more than one dimension, providing an `axis` + is required. + Default: `None`. + dtype (Optional[dtype]): + data type of the returned array. If `None`, the default data + type is inferred from the "kind" of the input array data type. + + * If `x` has a real- or complex-valued floating-point data + type, the returned array will have the same data type as + `x`. 
+ * If `x` has signed integral data type, the returned array + will have the default signed integral type for the device + where input array `x` is allocated. + * If `x` has unsigned integral data type, the returned array + will have the default unsigned integral type for the device + where input array `x` is allocated. + * If `x` has a boolean data type, the returned array will + have the default signed integral type for the device + where input array `x` is allocated. + + If the data type (either specified or resolved) differs from the + data type of `x`, the input array elements are cast to the + specified data type before computing the cumulative logsumexp. + Default: `None`. + include_initial (bool): + boolean indicating whether to include the initial value (i.e., the + identity of logsumexp, negative infinity) as the first value along the provided axis + in the output. Default: `False`. + out (Optional[usm_ndarray]): + the array into which the result is written. + The data type of `out` must match the expected shape and the + expected data type of the result or (if provided) `dtype`. + If `None` then a new array is returned. Default: `None`. + + Returns: + usm_ndarray: + an array containing cumulative logsumexp results. The returned + array has the data type as described in the `dtype` parameter + description above. 
+ + The returned array shape is determined as follows: + + * If `include_initial` is `False`, the returned array will + have the same shape as `x` + * If `include_initial` is `True`, the returned array will + have the same shape as `x` except the axis along which the + cumulative logsumexp is calculated, which will have size + `N+1` + """ + return _accumulate_common( + x, + axis, + dtype, + include_initial, + out, + tai._cumlogsumexp_over_axis, + tai._cumlogsumexp_final_axis_include_initial, + tai._cumlogsumexp_dtype_supported, + _default_accumulation_dtype_fp_types, + ) diff --git a/dpctl_ext/tensor/_manipulation_functions.py b/dpctl_ext/tensor/_manipulation_functions.py index 08459dcaea76..e2d55c533bc0 100644 --- a/dpctl_ext/tensor/_manipulation_functions.py +++ b/dpctl_ext/tensor/_manipulation_functions.py @@ -624,7 +624,7 @@ def repeat(x, repeats, /, *, axis=None): "'repeats' array must be broadcastable to the size of " "the repeated axis" ) - if not dpt.all(repeats >= 0): + if not dpt_ext.all(repeats >= 0): raise ValueError("'repeats' elements must be positive") elif isinstance(repeats, (tuple, list, range)): @@ -646,7 +646,7 @@ def repeat(x, repeats, /, *, axis=None): repeats = dpt_ext.asarray( repeats, dtype=dpt.int64, usm_type=usm_type, sycl_queue=exec_q ) - if not dpt.all(repeats >= 0): + if not dpt_ext.all(repeats >= 0): raise ValueError("`repeats` elements must be positive") else: raise TypeError( diff --git a/dpctl_ext/tensor/_reduction.py b/dpctl_ext/tensor/_reduction.py new file mode 100644 index 000000000000..b8fdcf4f37e6 --- /dev/null +++ b/dpctl_ext/tensor/_reduction.py @@ -0,0 +1,834 @@ +# ***************************************************************************** +# Copyright (c) 2026, Intel Corporation +# All rights reserved. 
+# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# - Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# - Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# - Neither the name of the copyright holder nor the names of its contributors +# may be used to endorse or promote products derived from this software +# without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +# THE POSSIBILITY OF SUCH DAMAGE. 
+# ***************************************************************************** + +import dpctl +import dpctl.tensor as dpt +from dpctl.utils import ExecutionPlacementError, SequentialOrderManager + +# TODO: revert to `import dpctl.tensor...` +# when dpnp fully migrates dpctl/tensor +import dpctl_ext.tensor as dpt_ext +import dpctl_ext.tensor._tensor_impl as ti +import dpctl_ext.tensor._tensor_reductions_impl as tri + +from ._numpy_helper import normalize_axis_tuple +from ._type_utils import ( + _default_accumulation_dtype, + _default_accumulation_dtype_fp_types, + _to_device_supported_dtype, +) + + +def _comparison_over_axis(x, axis, keepdims, out, _reduction_fn): + if not isinstance(x, dpt.usm_ndarray): + raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x)}") + + nd = x.ndim + if axis is None: + axis = tuple(range(nd)) + perm = list(axis) + x_tmp = x + else: + if not isinstance(axis, (tuple, list)): + axis = (axis,) + axis = normalize_axis_tuple(axis, nd, "axis") + perm = [i for i in range(nd) if i not in axis] + list(axis) + x_tmp = dpt_ext.permute_dims(x, perm) + red_nd = len(axis) + if any([x_tmp.shape[i] == 0 for i in range(-red_nd, 0)]): + raise ValueError("reduction cannot be performed over zero-size axes") + res_shape = x_tmp.shape[: nd - red_nd] + exec_q = x.sycl_queue + res_dt = x.dtype + res_usm_type = x.usm_type + + orig_out = out + if out is not None: + if not isinstance(out, dpt.usm_ndarray): + raise TypeError( + f"output array must be of usm_ndarray type, got {type(out)}" + ) + if not out.flags.writable: + raise ValueError("provided `out` array is read-only") + if not keepdims: + final_res_shape = res_shape + else: + inp_shape = x.shape + final_res_shape = tuple( + inp_shape[i] if i not in axis else 1 for i in range(nd) + ) + if not out.shape == final_res_shape: + raise ValueError( + "The shape of input and output arrays are inconsistent. 
" + f"Expected output shape is {final_res_shape}, got {out.shape}" + ) + if res_dt != out.dtype: + raise ValueError( + f"Output array of type {res_dt} is needed, got {out.dtype}" + ) + if dpctl.utils.get_execution_queue((exec_q, out.sycl_queue)) is None: + raise ExecutionPlacementError( + "Input and output allocation queues are not compatible" + ) + if keepdims: + out = dpt_ext.squeeze(out, axis=axis) + orig_out = out + if ti._array_overlap(x, out): + out = dpt_ext.empty_like(out) + else: + out = dpt_ext.empty( + res_shape, dtype=res_dt, usm_type=res_usm_type, sycl_queue=exec_q + ) + + _manager = SequentialOrderManager[exec_q] + dep_evs = _manager.submitted_events + if red_nd == 0: + ht_e_cpy, cpy_e = ti._copy_usm_ndarray_into_usm_ndarray( + src=x_tmp, dst=out, sycl_queue=exec_q, depends=dep_evs + ) + _manager.add_event_pair(ht_e_cpy, cpy_e) + if not (orig_out is None or orig_out is out): + ht_e_cpy2, cpy2_e = ti._copy_usm_ndarray_into_usm_ndarray( + src=out, dst=orig_out, sycl_queue=exec_q, depends=[cpy_e] + ) + _manager.add_event_pair(ht_e_cpy2, cpy2_e) + out = orig_out + return out + + hev, red_ev = _reduction_fn( + src=x_tmp, + trailing_dims_to_reduce=red_nd, + dst=out, + sycl_queue=exec_q, + depends=dep_evs, + ) + _manager.add_event_pair(hev, red_ev) + if not (orig_out is None or orig_out is out): + ht_e_cpy2, cpy2_e = ti._copy_usm_ndarray_into_usm_ndarray( + src=out, dst=orig_out, sycl_queue=exec_q, depends=[red_ev] + ) + _manager.add_event_pair(ht_e_cpy2, cpy2_e) + out = orig_out + + if keepdims: + res_shape = res_shape + (1,) * red_nd + inv_perm = sorted(range(nd), key=lambda d: perm[d]) + out = dpt_ext.permute_dims(dpt_ext.reshape(out, res_shape), inv_perm) + return out + + +def _reduction_over_axis( + x, + axis, + dtype, + keepdims, + out, + _reduction_fn, + _dtype_supported, + _default_reduction_type_fn, +): + if not isinstance(x, dpt.usm_ndarray): + raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x)}") + nd = x.ndim + if axis is None: + 
axis = tuple(range(nd)) + perm = list(axis) + arr = x + else: + if not isinstance(axis, (tuple, list)): + axis = (axis,) + axis = normalize_axis_tuple(axis, nd, "axis") + perm = [i for i in range(nd) if i not in axis] + list(axis) + arr = dpt_ext.permute_dims(x, perm) + red_nd = len(axis) + res_shape = arr.shape[: nd - red_nd] + q = x.sycl_queue + inp_dt = x.dtype + if dtype is None: + res_dt = _default_reduction_type_fn(inp_dt, q) + else: + res_dt = dpt.dtype(dtype) + res_dt = _to_device_supported_dtype(res_dt, q.sycl_device) + + res_usm_type = x.usm_type + + implemented_types = _dtype_supported(inp_dt, res_dt, res_usm_type, q) + if dtype is None and not implemented_types: + raise RuntimeError( + "Automatically determined reduction data type does not " + "have direct implementation" + ) + orig_out = out + if out is not None: + if not isinstance(out, dpt.usm_ndarray): + raise TypeError( + f"output array must be of usm_ndarray type, got {type(out)}" + ) + if not out.flags.writable: + raise ValueError("provided `out` array is read-only") + if not keepdims: + final_res_shape = res_shape + else: + inp_shape = x.shape + final_res_shape = tuple( + inp_shape[i] if i not in axis else 1 for i in range(nd) + ) + if not out.shape == final_res_shape: + raise ValueError( + "The shape of input and output arrays are inconsistent. 
" + f"Expected output shape is {final_res_shape}, got {out.shape}" + ) + if res_dt != out.dtype: + raise ValueError( + f"Output array of type {res_dt} is needed, got {out.dtype}" + ) + if dpctl.utils.get_execution_queue((q, out.sycl_queue)) is None: + raise ExecutionPlacementError( + "Input and output allocation queues are not compatible" + ) + if keepdims: + out = dpt_ext.squeeze(out, axis=axis) + orig_out = out + if ti._array_overlap(x, out) and implemented_types: + out = dpt_ext.empty_like(out) + else: + out = dpt_ext.empty( + res_shape, dtype=res_dt, usm_type=res_usm_type, sycl_queue=q + ) + + _manager = SequentialOrderManager[q] + dep_evs = _manager.submitted_events + if red_nd == 0: + ht_e_cpy, cpy_e = ti._copy_usm_ndarray_into_usm_ndarray( + src=arr, dst=out, sycl_queue=q, depends=dep_evs + ) + _manager.add_event_pair(ht_e_cpy, cpy_e) + if not (orig_out is None or orig_out is out): + ht_e_cpy2, cpy2_e = ti._copy_usm_ndarray_into_usm_ndarray( + src=out, dst=orig_out, sycl_queue=q, depends=[cpy_e] + ) + _manager.add_event_pair(ht_e_cpy2, cpy2_e) + out = orig_out + return out + + if implemented_types: + ht_e, red_e = _reduction_fn( + src=arr, + trailing_dims_to_reduce=red_nd, + dst=out, + sycl_queue=q, + depends=dep_evs, + ) + _manager.add_event_pair(ht_e, red_e) + if not (orig_out is None or orig_out is out): + ht_e_cpy, cpy_e = ti._copy_usm_ndarray_into_usm_ndarray( + src=out, dst=orig_out, sycl_queue=q, depends=[red_e] + ) + _manager.add_event_pair(ht_e_cpy, cpy_e) + out = orig_out + else: + if _dtype_supported(res_dt, res_dt, res_usm_type, q): + tmp = dpt_ext.empty( + arr.shape, dtype=res_dt, usm_type=res_usm_type, sycl_queue=q + ) + ht_e_cpy, cpy_e = ti._copy_usm_ndarray_into_usm_ndarray( + src=arr, dst=tmp, sycl_queue=q, depends=dep_evs + ) + _manager.add_event_pair(ht_e_cpy, cpy_e) + ht_e_red, red_ev = _reduction_fn( + src=tmp, + trailing_dims_to_reduce=red_nd, + dst=out, + sycl_queue=q, + depends=[cpy_e], + ) + _manager.add_event_pair(ht_e_red, red_ev) 
def _search_over_axis(x, axis, keepdims, out, _reduction_fn):
    """Drive an index-search reduction kernel over one axis or all axes.

    Shared implementation behind ``argmax`` and ``argmin``: validates the
    arguments, moves the searched axis to the trailing position, allocates
    (or validates) the destination, and submits ``_reduction_fn``.

    Args:
        x (usm_ndarray):
            input array.
        axis (Optional[int]):
            axis to search along. ``None`` searches over every axis.
        keepdims (bool):
            if ``True`` the reduced axes are kept as singleton dimensions
            in the result.
        out (Optional[usm_ndarray]):
            destination array. Must be writable, have the expected shape,
            the default device index data type, and a compatible queue.
        _reduction_fn (callable):
            kernel launcher called with keywords ``src``,
            ``trailing_dims_to_reduce``, ``dst``, ``sycl_queue`` and
            ``depends``; its two return values are registered with the
            order manager as an event pair.

    Returns:
        usm_ndarray:
            array of indices with the default index data type of the device
            of ``x``.

    Raises:
        TypeError: if ``x`` or ``out`` is not a ``usm_ndarray``, or ``axis``
            is neither an ``int`` nor ``None``.
        ValueError: if a searched axis has zero size, or ``out`` is
            read-only, has the wrong shape, or has the wrong data type.
        ExecutionPlacementError: if the queues of ``x`` and ``out`` are not
            compatible.
    """
    if not isinstance(x, dpt.usm_ndarray):
        raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x)}")

    nd = x.ndim
    if axis is None:
        # every axis is searched; no permutation is required
        axis = tuple(range(nd))
        perm = list(axis)
        x_tmp = x
    elif isinstance(axis, int):
        # normalize once (handles negative values and range checking)
        axis = normalize_axis_tuple((axis,), nd, "axis")
        # move the searched axis to the end, keeping the other axes in order
        perm = [i for i in range(nd) if i not in axis] + list(axis)
        x_tmp = dpt_ext.permute_dims(x, perm)
    else:
        raise TypeError(
            f"'axis' argument expected to have type 'int' "
            r"or be `None`, "
            f"got type {type(axis)}"
        )
    red_nd = len(axis)
    if any(x_tmp.shape[i] == 0 for i in range(-red_nd, 0)):
        raise ValueError("reduction cannot be performed over zero-size axes")
    res_shape = x_tmp.shape[: nd - red_nd]
    exec_q = x.sycl_queue
    res_dt = ti.default_device_index_type(exec_q.sycl_device)
    res_usm_type = x.usm_type

    orig_out = out
    if out is not None:
        if not isinstance(out, dpt.usm_ndarray):
            raise TypeError(
                f"output array must be of usm_ndarray type, got {type(out)}"
            )
        if not out.flags.writable:
            raise ValueError("provided `out` array is read-only")
        if not keepdims:
            final_res_shape = res_shape
        else:
            inp_shape = x.shape
            final_res_shape = tuple(
                inp_shape[i] if i not in axis else 1 for i in range(nd)
            )
        if out.shape != final_res_shape:
            raise ValueError(
                "The shape of input and output arrays are inconsistent. "
                f"Expected output shape is {final_res_shape}, got {out.shape}"
            )
        if res_dt != out.dtype:
            raise ValueError(
                f"Output array of type {res_dt} is needed, got {out.dtype}"
            )
        if dpctl.utils.get_execution_queue((exec_q, out.sycl_queue)) is None:
            raise ExecutionPlacementError(
                "Input and output allocation queues are not compatible"
            )
        if keepdims:
            # the kernel writes into the squeezed view; singleton axes are
            # re-inserted just before returning
            out = dpt_ext.squeeze(out, axis=axis)
            orig_out = out
        if ti._array_overlap(x, out) and red_nd > 0:
            # avoid reading and writing the same memory: reduce into a
            # temporary and copy back into ``orig_out`` afterwards
            out = dpt_ext.empty_like(out)
    else:
        out = dpt_ext.empty(
            res_shape, dtype=res_dt, usm_type=res_usm_type, sycl_queue=exec_q
        )

    _manager = SequentialOrderManager[exec_q]
    dep_evs = _manager.submitted_events
    if red_nd == 0:
        # nothing to search over: the only valid index is 0
        ht_e_fill, fill_ev = ti._full_usm_ndarray(
            fill_value=0, dst=out, sycl_queue=exec_q, depends=dep_evs
        )
        _manager.add_event_pair(ht_e_fill, fill_ev)
        return out

    hev, red_ev = _reduction_fn(
        src=x_tmp,
        trailing_dims_to_reduce=red_nd,
        dst=out,
        sycl_queue=exec_q,
        depends=dep_evs,
    )
    _manager.add_event_pair(hev, red_ev)
    if not (orig_out is None or orig_out is out):
        # a temporary was used because of overlap; copy the result back
        ht_e_cpy2, cpy2_e = ti._copy_usm_ndarray_into_usm_ndarray(
            src=out, dst=orig_out, sycl_queue=exec_q, depends=[red_ev]
        )
        _manager.add_event_pair(ht_e_cpy2, cpy2_e)
        out = orig_out

    if keepdims:
        # restore reduced axes as singletons in their original positions
        res_shape = res_shape + (1,) * red_nd
        inv_perm = sorted(range(nd), key=lambda d: perm[d])
        out = dpt_ext.permute_dims(dpt_ext.reshape(out, res_shape), inv_perm)
    return out
the maximum values of the input array ``x`` along a + specified axis. + + When the maximum value occurs multiple times, the indices corresponding to + the first occurrence are returned. + + Args: + x (usm_ndarray): + input array. + axis (Optional[int]): + axis along which to search. If ``None``, returns the index of the + maximum value of the flattened array. + Default: ``None``. + keepdims (Optional[bool]): + if ``True``, the reduced axes (dimensions) are included in the + result as singleton dimensions, so that the returned array remains + compatible with the input arrays according to Array Broadcasting + rules. Otherwise, if ``False``, the reduced axes are not included + in the returned array. Default: ``False``. + out (Optional[usm_ndarray]): + the array into which the result is written. + The data type of ``out`` must match the expected shape and the + expected data type of the result. + If ``None`` then a new array is returned. Default: ``None``. + + Returns: + usm_ndarray: + an array containing the indices of the first occurrence of the + maximum values. If the entire array was searched, a + zero-dimensional array is returned. The returned array has the + default array index data type for the device of ``x``. + """ + return _search_over_axis(x, axis, keepdims, out, tri._argmax_over_axis) + + +def argmin(x, /, *, axis=None, keepdims=False, out=None): + """ + Returns the indices of the minimum values of the input array ``x`` along a + specified axis. + + When the minimum value occurs multiple times, the indices corresponding to + the first occurrence are returned. + + Args: + x (usm_ndarray): + input array. + axis (Optional[int]): + axis along which to search. If ``None``, returns the index of the + minimum value of the flattened array. + Default: ``None``. 
def count_nonzero(x, /, *, axis=None, keepdims=False, out=None):
    """
    Counts the number of elements in the input array ``x`` which are non-zero.

    Args:
        x (usm_ndarray):
            input array.
        axis (Optional[int, Tuple[int, ...]]):
            axis or axes along which to count. If a tuple of unique integers,
            the number of non-zero values are computed over multiple axes.
            If ``None``, the number of non-zero values is computed over the
            entire array.
            Default: ``None``.
        keepdims (Optional[bool]):
            if ``True``, the reduced axes (dimensions) are included in the
            result as singleton dimensions, so that the returned array remains
            compatible with the input arrays according to Array Broadcasting
            rules. Otherwise, if ``False``, the reduced axes are not included
            in the returned array. Default: ``False``.
        out (Optional[usm_ndarray]):
            the array into which the result is written.
            The shape and data type of ``out`` must match the expected shape
            and data type of the result.
            If ``None`` then a new array is returned. Default: ``None``.

    Returns:
        usm_ndarray:
            an array containing the count of non-zero values. If the count
            was computed over the entire array, a zero-dimensional array is
            returned. The returned array will have the default array index
            data type.

    Raises:
        TypeError: if ``x`` is not a ``usm_ndarray``.
    """
    # validate up front so that a bad argument produces the same TypeError as
    # the other reductions instead of an AttributeError on ``x.dtype``
    if not isinstance(x, dpt.usm_ndarray):
        raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x)}")
    if x.dtype != dpt.bool:
        # non-zero elements become True; counting is then a sum of booleans
        x = dpt.astype(x, dpt.bool, copy=False)
    # NOTE: ``sum`` here is this module's reduction, not the Python builtin
    return sum(
        x,
        axis=axis,
        dtype=ti.default_device_index_type(x.sycl_device),
        keepdims=keepdims,
        out=out,
    )
+ out (Optional[usm_ndarray]): + the array into which the result is written. + The data type of ``out`` must match the expected shape and the + expected data type of the result or (if provided) ``dtype``. + If ``None`` then a new array is returned. Default: ``None``. + + Returns: + usm_ndarray: + an array containing the results. If the result was computed over + the entire array, a zero-dimensional array is returned. + The returned array has the data type as described in the + ``dtype`` parameter description above. + """ + return _reduction_over_axis( + x, + axis, + dtype, + keepdims, + out, + tri._logsumexp_over_axis, + lambda inp_dt, res_dt, *_: tri._logsumexp_over_axis_dtype_supported( + inp_dt, res_dt + ), + _default_accumulation_dtype_fp_types, + ) + + +def max(x, /, *, axis=None, keepdims=False, out=None): + """ + Calculates the maximum value of the input array ``x``. + + Args: + x (usm_ndarray): + input array. + axis (Optional[int, Tuple[int, ...]]): + axis or axes along which maxima must be computed. If a tuple + of unique integers, the maxima are computed over multiple axes. + If ``None``, the max is computed over the entire array. + Default: ``None``. + keepdims (Optional[bool]): + if ``True``, the reduced axes (dimensions) are included in the + result as singleton dimensions, so that the returned array remains + compatible with the input arrays according to Array Broadcasting + rules. Otherwise, if ``False``, the reduced axes are not included + in the returned array. Default: ``False``. + out (Optional[usm_ndarray]): + the array into which the result is written. + The data type of ``out`` must match the expected shape and the + expected data type of the result. + If ``None`` then a new array is returned. Default: ``None``. + + Returns: + usm_ndarray: + an array containing the maxima. If the max was computed over the + entire array, a zero-dimensional array is returned. The returned + array has the same data type as ``x``. 
def min(x, /, *, axis=None, keepdims=False, out=None):
    """
    Computes the minimum of the elements of the input array ``x``.

    Args:
        x (usm_ndarray):
            input array.
        axis (Optional[int, Tuple[int, ...]]):
            axis or axes over which the minima are taken. A tuple of unique
            integers requests a reduction over several axes at once.
            ``None`` reduces over the entire array.
            Default: ``None``.
        keepdims (Optional[bool]):
            when ``True``, reduced axes are retained in the result as
            singleton dimensions so the result still broadcasts against
            the input. When ``False``, reduced axes are dropped.
            Default: ``False``.
        out (Optional[usm_ndarray]):
            array that receives the result. Its shape and data type must
            match the expected shape and data type of the result.
            When ``None`` a new array is returned. Default: ``None``.

    Returns:
        usm_ndarray:
            array holding the minima, with the same data type as ``x``.
            A reduction over the entire array yields a zero-dimensional
            array.
    """
    # delegate to the shared comparison-reduction driver with the min kernel
    min_kernel = tri._min_over_axis
    return _comparison_over_axis(x, axis, keepdims, out, min_kernel)
+ + * If ``x`` has a real- or complex-valued floating-point data + type, the returned array will have the same data type as + ``x``. + * If ``x`` has signed integral data type, the returned array + will have the default signed integral type for the device + where input array ``x`` is allocated. + * If ``x`` has unsigned integral data type, the returned array + will have the default unsigned integral type for the device + where input array ``x`` is allocated. + * If ``x`` has a boolean data type, the returned array will + have the default signed integral type for the device + where input array ``x`` is allocated. + + If the data type (either specified or resolved) differs from the + data type of ``x``, the input array elements are cast to the + specified data type before computing the product. + Default: ``None``. + keepdims (Optional[bool]): + if ``True``, the reduced axes (dimensions) are included in the + result as singleton dimensions, so that the returned array remains + compatible with the input arrays according to Array Broadcasting + rules. Otherwise, if ``False``, the reduced axes are not included + in the returned array. Default: ``False``. + out (Optional[usm_ndarray]): + the array into which the result is written. + The data type of ``out`` must match the expected shape and the + expected data type of the result or (if provided) ``dtype``. + If ``None`` then a new array is returned. Default: ``None``. + + Returns: + usm_ndarray: + an array containing the products. If the product was computed over + the entire array, a zero-dimensional array is returned. The + returned array has the data type as described in the ``dtype`` + parameter description above. 
+ """ + return _reduction_over_axis( + x, + axis, + dtype, + keepdims, + out, + tri._prod_over_axis, + tri._prod_over_axis_dtype_supported, + _default_accumulation_dtype, + ) + + +def reduce_hypot(x, /, *, axis=None, dtype=None, keepdims=False, out=None): + """ + Calculates the square root of the sum of squares of elements in the input + array ``x``. + + Args: + x (usm_ndarray): + input array. + axis (Optional[int, Tuple[int, ...]]): + axis or axes along which values must be computed. If a tuple + of unique integers, values are computed over multiple axes. + If ``None``, the result is computed over the entire array. + Default: ``None``. + dtype (Optional[dtype]): + data type of the returned array. If ``None``, the default data + type is inferred from the "kind" of the input array data type. + + * If ``x`` has a real-valued floating-point data type, the + returned array will have the same data type as ``x``. + * If ``x`` has a boolean or integral data type, the returned array + will have the default floating point data type for the device + where input array ``x`` is allocated. + * If ``x`` has a complex-valued floating-point data type, + an error is raised. + + If the data type (either specified or resolved) differs from the + data type of ``x``, the input array elements are cast to the + specified data type before computing the result. Default: ``None``. + keepdims (Optional[bool]): + if ``True``, the reduced axes (dimensions) are included in the + result as singleton dimensions, so that the returned array remains + compatible with the input arrays according to Array Broadcasting + rules. Otherwise, if ``False``, the reduced axes are not included + in the returned array. Default: ``False``. + out (Optional[usm_ndarray]): + the array into which the result is written. + The data type of ``out`` must match the expected shape and the + expected data type of the result or (if provided) ``dtype``. + If ``None`` then a new array is returned. Default: ``None``. 
def sum(x, /, *, axis=None, dtype=None, keepdims=False, out=None):
    """
    Calculates the sum of elements in the input array ``x``.

    Args:
        x (usm_ndarray):
            input array.
        axis (Optional[int, Tuple[int, ...]]):
            axis or axes along which sums must be computed. If a tuple
            of unique integers, sums are computed over multiple axes.
            If ``None``, the sum is computed over the entire array.
            Default: ``None``.
        dtype (Optional[dtype]):
            data type of the returned array. If ``None``, the default data
            type is inferred from the "kind" of the input array data type.

            * If ``x`` has a real- or complex-valued floating-point data
              type, the returned array will have the same data type as
              ``x``.
            * If ``x`` has signed integral data type, the returned array
              will have the default signed integral type for the device
              where input array ``x`` is allocated.
            * If ``x`` has unsigned integral data type, the returned array
              will have the default unsigned integral type for the device
              where input array ``x`` is allocated.
            * If ``x`` has a boolean data type, the returned array will
              have the default signed integral type for the device
              where input array ``x`` is allocated.

            If the data type (either specified or resolved) differs from the
            data type of ``x``, the input array elements are cast to the
            specified data type before computing the sum.
            Default: ``None``.
        keepdims (Optional[bool]):
            if ``True``, the reduced axes (dimensions) are included in the
            result as singleton dimensions, so that the returned array remains
            compatible with the input arrays according to Array Broadcasting
            rules. Otherwise, if ``False``, the reduced axes are not included
            in the returned array. Default: ``False``.
        out (Optional[usm_ndarray]):
            the array into which the result is written.
            The shape and data type of ``out`` must match the expected shape
            and the expected data type of the result or (if provided)
            ``dtype``.
            If ``None`` then a new array is returned. Default: ``None``.

    Returns:
        usm_ndarray:
            an array containing the sums. If the sum was computed over the
            entire array, a zero-dimensional array is returned. The returned
            array has the data type as described in the ``dtype`` parameter
            description above.
    """
    # NOTE: this definition shadows the Python builtin ``sum`` inside this
    # module; other functions here that call ``sum`` get this reduction.
    return _reduction_over_axis(
        x,
        axis,
        dtype,
        keepdims,
        out,
        tri._sum_over_axis,
        tri._sum_over_axis_dtype_supported,
        _default_accumulation_dtype,
    )
def searchsorted(
    x1: usm_ndarray,
    x2: usm_ndarray,
    /,
    *,
    side: Literal["left", "right"] = "left",
    sorter: Union[usm_ndarray, None] = None,
) -> usm_ndarray:
    """searchsorted(x1, x2, side='left', sorter=None)

    Finds the indices into `x1` such that, if the corresponding elements
    in `x2` were inserted before the indices, the order of `x1`, when sorted
    in ascending order, would be preserved.

    Args:
        x1 (usm_ndarray):
            input array. Must be a one-dimensional array. If `sorter` is
            `None`, must be sorted in ascending order; otherwise, `sorter` must
            be an array of indices that sort `x1` in ascending order.
        x2 (usm_ndarray):
            array containing search values.
        side (Literal["left", "right"]):
            argument controlling which index is returned if a value lands
            exactly on an edge. If `x2` is an array of rank `N` where
            `v = x2[n, m, ..., j]`, the element `ret[n, m, ..., j]` in the
            return array `ret` contains the position `i` such that
            if `side="left"`, it is the first index such that
            `x1[i-1] < v <= x1[i]`, `0` if `v <= x1[0]`, and `x1.size`
            if `v > x1[-1]`;
            and if `side="right"`, it is the first position `i` such that
            `x1[i-1] <= v < x1[i]`, `0` if `v < x1[0]`, and `x1.size`
            if `v >= x1[-1]`. Default: `"left"`.
        sorter (Optional[usm_ndarray]):
            array of indices that sort `x1` in ascending order. The array must
            have the same shape as `x1` and have an integral data type.
            Out of bound index values of `sorter` array are treated using
            `"wrap"` mode documented in :py:func:`dpctl.tensor.take`.
            Default: `None`.

    Returns:
        usm_ndarray:
            array of insertion positions laid out like `x2`, with the
            default index data type for the execution queue.

    Raises:
        TypeError: if `x1`, `x2`, or a provided `sorter` is not a
            `usm_ndarray`.
        ValueError: if `side` is not `"left"` or `"right"`, if `x1` is not
            one-dimensional, or if `sorter` has a non-integral data type or
            a shape different from `x1`.
        ExecutionPlacementError: if a common execution queue cannot be
            determined from the input arrays.
    """
    if not isinstance(x1, usm_ndarray):
        raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x1)}")
    if not isinstance(x2, usm_ndarray):
        raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x2)}")
    if sorter is not None and not isinstance(sorter, usm_ndarray):
        raise TypeError(
            f"Expected dpctl.tensor.usm_ndarray, got {type(sorter)}"
        )

    if side not in ("left", "right"):
        raise ValueError(
            "Unrecognized value of 'side' keyword argument. "
            "Expected either 'left' or 'right'"
        )

    # all participating arrays must agree on an execution queue
    queues = [x1.sycl_queue, x2.sycl_queue]
    if sorter is not None:
        queues.append(sorter.sycl_queue)
    q = du.get_execution_queue(queues)
    if q is None:
        raise du.ExecutionPlacementError(
            "Execution placement can not be unambiguously "
            "inferred from input arguments."
        )

    if x1.ndim != 1:
        raise ValueError("First argument array must be one-dimensional")

    x1_dt = x1.dtype
    x2_dt = x2.dtype

    _manager = du.SequentialOrderManager[q]
    dep_evs = _manager.submitted_events
    if sorter is not None:
        if not isdtype(sorter.dtype, "integral"):
            raise ValueError(
                f"Sorter array must have integral data type, got {sorter.dtype}"
            )
        if x1.shape != sorter.shape:
            raise ValueError(
                "Sorter array must be one-dimension with the same "
                "shape as the first argument array"
            )
        # materialize x1[sorter] so the search kernel sees sorted data
        res = empty(x1.shape, dtype=x1_dt, usm_type=x1.usm_type, sycl_queue=q)
        ind = (sorter,)
        axis = 0
        wrap_out_of_bound_indices_mode = 0
        ht_ev, ev = ti_take(
            x1,
            ind,
            res,
            axis,
            wrap_out_of_bound_indices_mode,
            sycl_queue=q,
            depends=dep_evs,
        )
        x1 = res
        _manager.add_event_pair(ht_ev, ev)

    if x1_dt != x2_dt:
        # bring both operands to a common data type before comparing
        dt = result_type(x1, x2)
        if x1_dt != dt:
            x1_buf = _empty_like_orderK(x1, dt)
            dep_evs = _manager.submitted_events
            ht_ev, ev = ti_copy(
                src=x1, dst=x1_buf, sycl_queue=q, depends=dep_evs
            )
            _manager.add_event_pair(ht_ev, ev)
            x1 = x1_buf
        if x2_dt != dt:
            x2_buf = _empty_like_orderK(x2, dt)
            dep_evs = _manager.submitted_events
            ht_ev, ev = ti_copy(
                src=x2, dst=x2_buf, sycl_queue=q, depends=dep_evs
            )
            _manager.add_event_pair(ht_ev, ev)
            x2 = x2_buf

    dst_usm_type = du.get_coerced_usm_type([x1.usm_type, x2.usm_type])
    index_dt = ti_default_device_index_type(q)

    dst = _empty_like_orderK(x2, index_dt, usm_type=dst_usm_type)

    # re-read submitted events so the search waits on any copies made above
    dep_evs = _manager.submitted_events
    if side == "left":
        ht_ev, s_ev = _searchsorted_left(
            hay=x1,
            needles=x2,
            positions=dst,
            sycl_queue=q,
            depends=dep_evs,
        )
    else:
        ht_ev, s_ev = _searchsorted_right(
            hay=x1, needles=x2, positions=dst, sycl_queue=q, depends=dep_evs
        )
    _manager.add_event_pair(ht_ev, s_ev)
    return dst
def unique_values(x: dpt.usm_ndarray) -> dpt.usm_ndarray:
    """unique_values(x)

    Returns the distinct elements found in the input array `x`.

    Args:
        x (usm_ndarray):
            input array. Arrays that are not one-dimensional are flattened
            before processing.
    Returns:
        usm_ndarray
            array holding the set of unique elements of `x`, with the same
            data type as `x`.
    """
    if not isinstance(x, dpt.usm_ndarray):
        raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x)}")
    exec_q = x.device.sycl_queue
    flat = x if x.ndim == 1 else dpt_ext.reshape(x, (x.size,), order="C")
    if flat.size == 0:
        return flat

    sorted_vals = dpt_ext.empty_like(flat, order="C")
    _manager = du.SequentialOrderManager[exec_q]
    pending = _manager.submitted_events

    # the sort is given a C-contiguous source; copy first when it is not
    if flat.flags.c_contiguous:
        sort_src = flat
        sort_deps = pending
    else:
        contig = dpt_ext.empty_like(flat, order="C")
        ht_ev, copy_ev = _copy_usm_ndarray_into_usm_ndarray(
            src=flat, dst=contig, sycl_queue=exec_q, depends=pending
        )
        _manager.add_event_pair(ht_ev, copy_ev)
        sort_src = contig
        sort_deps = [copy_ev]
    ht_ev, sort_ev = _sort_ascending(
        src=sort_src,
        trailing_dims_to_sort=1,
        dst=sorted_vals,
        sycl_queue=exec_q,
        depends=sort_deps,
    )
    _manager.add_event_pair(ht_ev, sort_ev)

    # mask[i] is True where sorted element i starts a new run of equal values
    first_of_run = dpt_ext.empty(flat.shape, dtype="?", sycl_queue=exec_q)
    ht_ev, uneq_ev = _not_equal(
        src1=sorted_vals[:-1],
        src2=sorted_vals[1:],
        dst=first_of_run[1:],
        sycl_queue=exec_q,
        depends=[sort_ev],
    )
    _manager.add_event_pair(ht_ev, uneq_ev)
    # the mask is a fresh allocation, so this fill has no dependencies
    ht_ev, one_ev = _full_usm_ndarray(
        fill_value=True, dst=first_of_run[0], sycl_queue=exec_q
    )
    _manager.add_event_pair(ht_ev, one_ev)

    positions = dpt_ext.empty(
        sorted_vals.shape, dtype=dpt.int64, sycl_queue=exec_q
    )
    # synchronizing call
    n_uniques = mask_positions(
        first_of_run, positions, sycl_queue=exec_q, depends=[one_ev, uneq_ev]
    )
    if n_uniques == flat.size:
        # every element is already distinct; the sorted array is the answer
        return sorted_vals

    result = dpt_ext.empty(
        n_uniques, dtype=x.dtype, usm_type=x.usm_type, sycl_queue=exec_q
    )
    ht_ev, ex_e = _extract(
        src=sorted_vals,
        cumsum=positions,
        axis_start=0,
        axis_end=1,
        dst=result,
        sycl_queue=exec_q,
    )
    _manager.add_event_pair(ht_ev, ex_e)
    return result
+ """ + if not isinstance(x, dpt.usm_ndarray): + raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x)}") + array_api_dev = x.device + exec_q = array_api_dev.sycl_queue + x_usm_type = x.usm_type + if x.ndim == 1: + fx = x + else: + fx = dpt_ext.reshape(x, (x.size,), order="C") + ind_dt = default_device_index_type(exec_q) + if fx.size == 0: + return UniqueCountsResult(fx, dpt_ext.empty_like(fx, dtype=ind_dt)) + s = dpt_ext.empty_like(fx, order="C") + + _manager = du.SequentialOrderManager[exec_q] + dep_evs = _manager.submitted_events + if fx.flags.c_contiguous: + ht_ev, sort_ev = _sort_ascending( + src=fx, + trailing_dims_to_sort=1, + dst=s, + sycl_queue=exec_q, + depends=dep_evs, + ) + _manager.add_event_pair(ht_ev, sort_ev) + else: + tmp = dpt_ext.empty_like(fx, order="C") + ht_ev, copy_ev = _copy_usm_ndarray_into_usm_ndarray( + src=fx, dst=tmp, sycl_queue=exec_q, depends=dep_evs + ) + _manager.add_event_pair(ht_ev, copy_ev) + ht_ev, sort_ev = _sort_ascending( + src=tmp, + dst=s, + trailing_dims_to_sort=1, + sycl_queue=exec_q, + depends=[copy_ev], + ) + _manager.add_event_pair(ht_ev, sort_ev) + unique_mask = dpt_ext.empty(s.shape, dtype="?", sycl_queue=exec_q) + ht_ev, uneq_ev = _not_equal( + src1=s[:-1], + src2=s[1:], + dst=unique_mask[1:], + sycl_queue=exec_q, + depends=[sort_ev], + ) + _manager.add_event_pair(ht_ev, uneq_ev) + # no dependency, since we write into new allocation + ht_ev, one_ev = _full_usm_ndarray( + fill_value=True, dst=unique_mask[0], sycl_queue=exec_q + ) + _manager.add_event_pair(ht_ev, one_ev) + cumsum = dpt_ext.empty( + unique_mask.shape, dtype=dpt.int64, sycl_queue=exec_q + ) + # synchronizing call + n_uniques = mask_positions( + unique_mask, cumsum, sycl_queue=exec_q, depends=[one_ev, uneq_ev] + ) + if n_uniques == fx.size: + return UniqueCountsResult( + s, + dpt_ext.ones( + n_uniques, dtype=ind_dt, usm_type=x_usm_type, sycl_queue=exec_q + ), + ) + unique_vals = dpt_ext.empty( + n_uniques, dtype=x.dtype, usm_type=x_usm_type, 
def unique_inverse(x):
    """unique_inverse(x)

    Returns the unique elements of an input array `x` and the indices from the
    set of unique elements that reconstruct `x`.

    Args:
        x (usm_ndarray):
            input array. Inputs with more than one dimension are flattened.
    Returns:
        tuple[usm_ndarray, usm_ndarray]
            a namedtuple `(values, inverse_indices)` whose

            * first element has the field name `values` and is an array
              containing the unique elements of `x`. The array has the same
              data type as `x`.
            * second element has the field name `inverse_indices` and is an
              array containing the indices of values that reconstruct `x`.
              The array has the same shape as `x` and has the default array
              index data type.

    Raises:
        TypeError: if `x` is not a `usm_ndarray`.
    """
    if not isinstance(x, dpt.usm_ndarray):
        raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x)}")
    array_api_dev = x.device
    exec_q = array_api_dev.sycl_queue
    x_usm_type = x.usm_type
    ind_dt = default_device_index_type(exec_q)
    if x.ndim == 1:
        fx = x
    else:
        fx = dpt_ext.reshape(x, (x.size,), order="C")
    sorting_ids = dpt_ext.empty_like(fx, dtype=ind_dt, order="C")
    unsorting_ids = dpt_ext.empty_like(sorting_ids, dtype=ind_dt, order="C")
    if fx.size == 0:
        return UniqueInverseResult(fx, dpt_ext.reshape(unsorting_ids, x.shape))

    _manager = du.SequentialOrderManager[exec_q]
    dep_evs = _manager.submitted_events
    if fx.flags.c_contiguous:
        ht_ev, sort_ev = _argsort_ascending(
            src=fx,
            trailing_dims_to_sort=1,
            dst=sorting_ids,
            sycl_queue=exec_q,
            depends=dep_evs,
        )
        _manager.add_event_pair(ht_ev, sort_ev)
    else:
        # argsort a C-contiguous copy of the flattened input
        tmp = dpt_ext.empty_like(fx, order="C")
        ht_ev, copy_ev = _copy_usm_ndarray_into_usm_ndarray(
            src=fx, dst=tmp, sycl_queue=exec_q, depends=dep_evs
        )
        _manager.add_event_pair(ht_ev, copy_ev)
        ht_ev, sort_ev = _argsort_ascending(
            src=tmp,
            trailing_dims_to_sort=1,
            dst=sorting_ids,
            sycl_queue=exec_q,
            depends=[copy_ev],
        )
        _manager.add_event_pair(ht_ev, sort_ev)
    # argsort of the sorting permutation; returned as the inverse indices
    # when every element of `x` turns out to be distinct
    ht_ev, argsort_ev = _argsort_ascending(
        src=sorting_ids,
        trailing_dims_to_sort=1,
        dst=unsorting_ids,
        sycl_queue=exec_q,
        depends=[sort_ev],
    )
    _manager.add_event_pair(ht_ev, argsort_ev)
    s = dpt_ext.empty_like(fx)
    # s = fx[sorting_ids]
    ht_ev, take_ev = _take(
        src=fx,
        ind=(sorting_ids,),
        dst=s,
        axis_start=0,
        mode=0,
        sycl_queue=exec_q,
        depends=[sort_ev],
    )
    _manager.add_event_pair(ht_ev, take_ev)
    # unique_mask[i] is True where s[i] differs from s[i - 1]
    unique_mask = dpt_ext.empty(fx.shape, dtype="?", sycl_queue=exec_q)
    ht_ev, uneq_ev = _not_equal(
        src1=s[:-1],
        src2=s[1:],
        dst=unique_mask[1:],
        sycl_queue=exec_q,
        depends=[take_ev],
    )
    _manager.add_event_pair(ht_ev, uneq_ev)
    # no dependency, since we write into a new allocation
    ht_ev, one_ev = _full_usm_ndarray(
        fill_value=True, dst=unique_mask[0], sycl_queue=exec_q
    )
    _manager.add_event_pair(ht_ev, one_ev)
    cumsum = dpt_ext.empty(
        unique_mask.shape, dtype=dpt.int64, sycl_queue=exec_q
    )
    # synchronizing call
    n_uniques = mask_positions(
        unique_mask, cumsum, sycl_queue=exec_q, depends=[uneq_ev, one_ev]
    )
    if n_uniques == fx.size:
        return UniqueInverseResult(s, dpt_ext.reshape(unsorting_ids, x.shape))
    unique_vals = dpt_ext.empty(
        n_uniques, dtype=x.dtype, usm_type=x_usm_type, sycl_queue=exec_q
    )
    ht_ev, uv_ev = _extract(
        src=s,
        cumsum=cumsum,
        axis_start=0,
        axis_end=1,
        dst=unique_vals,
        sycl_queue=exec_q,
    )
    _manager.add_event_pair(ht_ev, uv_ev)

    # Per-value counts are not part of this function's result, so only the
    # inverse indices are computed: each element of `x` is looked up in the
    # sorted unique values.
    inv = dpt_ext.empty_like(x, dtype=ind_dt, order="C")
    ht_ev, ssl_ev = _searchsorted_left(
        hay=unique_vals,
        needles=x,
        positions=inv,
        sycl_queue=exec_q,
        depends=[
            uv_ev,
        ],
    )
    _manager.add_event_pair(ht_ev, ssl_ev)

    return UniqueInverseResult(unique_vals, inv)
occurring + indices for each unique element in `x`, the indices from the set of unique + elements that reconstruct `x`, and the corresponding counts for each + unique element in `x`. + + Args: + x (usm_ndarray): + input array. Inputs with more than one dimension are flattened. + Returns: + tuple[usm_ndarray, usm_ndarray, usm_ndarray, usm_ndarray] + a namedtuple `(values, indices, inverse_indices, counts)` whose + + * first element has the field name `values` and is an array + containing the unique elements of `x`. The array has the same + data type as `x`. + * second element has the field name `indices` and is an array + the indices (of first occurrences) of `x` that result in + `values`. The array has the same shape as `values` and has the + default array index data type. + * third element has the field name `inverse_indices` and is an + array containing the indices of values that reconstruct `x`. + The array has the same shape as `x` and has the default array + index data type. + * fourth element has the field name `counts` and is an array + containing the number of times each unique element occurs in `x`. + This array has the same shape as `values` and has the default + array index data type. 
+ """ + if not isinstance(x, dpt.usm_ndarray): + raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x)}") + array_api_dev = x.device + exec_q = array_api_dev.sycl_queue + x_usm_type = x.usm_type + ind_dt = default_device_index_type(exec_q) + if x.ndim == 1: + fx = x + else: + fx = dpt_ext.reshape(x, (x.size,), order="C") + sorting_ids = dpt_ext.empty_like(fx, dtype=ind_dt, order="C") + unsorting_ids = dpt_ext.empty_like(sorting_ids, dtype=ind_dt, order="C") + if fx.size == 0: + # original array contains no data + # so it can be safely returned as values + return UniqueAllResult( + fx, + sorting_ids, + dpt_ext.reshape(unsorting_ids, x.shape), + dpt_ext.empty_like(fx, dtype=ind_dt), + ) + _manager = du.SequentialOrderManager[exec_q] + dep_evs = _manager.submitted_events + if fx.flags.c_contiguous: + ht_ev, sort_ev = _argsort_ascending( + src=fx, + trailing_dims_to_sort=1, + dst=sorting_ids, + sycl_queue=exec_q, + depends=dep_evs, + ) + _manager.add_event_pair(ht_ev, sort_ev) + else: + tmp = dpt_ext.empty_like(fx, order="C") + ht_ev, copy_ev = _copy_usm_ndarray_into_usm_ndarray( + src=fx, dst=tmp, sycl_queue=exec_q, depends=dep_evs + ) + _manager.add_event_pair(ht_ev, copy_ev) + ht_ev, sort_ev = _argsort_ascending( + src=tmp, + trailing_dims_to_sort=1, + dst=sorting_ids, + sycl_queue=exec_q, + depends=[copy_ev], + ) + _manager.add_event_pair(ht_ev, sort_ev) + ht_ev, args_ev = _argsort_ascending( + src=sorting_ids, + trailing_dims_to_sort=1, + dst=unsorting_ids, + sycl_queue=exec_q, + depends=[sort_ev], + ) + _manager.add_event_pair(ht_ev, args_ev) + s = dpt_ext.empty_like(fx) + # s = fx[sorting_ids] + ht_ev, take_ev = _take( + src=fx, + ind=(sorting_ids,), + dst=s, + axis_start=0, + mode=0, + sycl_queue=exec_q, + depends=[sort_ev], + ) + _manager.add_event_pair(ht_ev, take_ev) + unique_mask = dpt_ext.empty(fx.shape, dtype="?", sycl_queue=exec_q) + ht_ev, uneq_ev = _not_equal( + src1=s[:-1], + src2=s[1:], + dst=unique_mask[1:], + sycl_queue=exec_q, + 
depends=[take_ev], + ) + _manager.add_event_pair(ht_ev, uneq_ev) + ht_ev, one_ev = _full_usm_ndarray( + fill_value=True, dst=unique_mask[0], sycl_queue=exec_q + ) + _manager.add_event_pair(ht_ev, one_ev) + cumsum = dpt_ext.empty( + unique_mask.shape, dtype=dpt.int64, sycl_queue=exec_q + ) + # synchronizing call + n_uniques = mask_positions( + unique_mask, cumsum, sycl_queue=exec_q, depends=[uneq_ev, one_ev] + ) + if n_uniques == fx.size: + _counts = dpt_ext.ones( + n_uniques, dtype=ind_dt, usm_type=x_usm_type, sycl_queue=exec_q + ) + return UniqueAllResult( + s, + sorting_ids, + dpt_ext.reshape(unsorting_ids, x.shape), + _counts, + ) + unique_vals = dpt_ext.empty( + n_uniques, dtype=x.dtype, usm_type=x_usm_type, sycl_queue=exec_q + ) + ht_ev, uv_ev = _extract( + src=s, + cumsum=cumsum, + axis_start=0, + axis_end=1, + dst=unique_vals, + sycl_queue=exec_q, + ) + _manager.add_event_pair(ht_ev, uv_ev) + cum_unique_counts = dpt_ext.empty( + n_uniques + 1, dtype=ind_dt, usm_type=x_usm_type, sycl_queue=exec_q + ) + idx = dpt_ext.empty(x.size, dtype=ind_dt, sycl_queue=exec_q) + ht_ev, id_ev = _linspace_step(start=0, dt=1, dst=idx, sycl_queue=exec_q) + _manager.add_event_pair(ht_ev, id_ev) + ht_ev, extr_ev = _extract( + src=idx, + cumsum=cumsum, + axis_start=0, + axis_end=1, + dst=cum_unique_counts[:-1], + sycl_queue=exec_q, + depends=[id_ev], + ) + _manager.add_event_pair(ht_ev, extr_ev) + ht_ev, set_ev = _full_usm_ndarray( + x.size, dst=cum_unique_counts[-1], sycl_queue=exec_q + ) + _manager.add_event_pair(ht_ev, set_ev) + _counts = dpt_ext.empty_like(cum_unique_counts[1:]) + ht_ev, sub_ev = _subtract( + src1=cum_unique_counts[1:], + src2=cum_unique_counts[:-1], + dst=_counts, + sycl_queue=exec_q, + depends=[set_ev, extr_ev], + ) + _manager.add_event_pair(ht_ev, sub_ev) + + inv = dpt_ext.empty_like(x, dtype=ind_dt, order="C") + ht_ev, ssl_ev = _searchsorted_left( + hay=unique_vals, + needles=x, + positions=inv, + sycl_queue=exec_q, + depends=[ + uv_ev, + ], + ) + 
def isin(
    x: Union[dpt.usm_ndarray, int, float, complex, bool],
    test_elements: Union[dpt.usm_ndarray, int, float, complex, bool],
    /,
    *,
    invert: Optional[bool] = False,
) -> dpt.usm_ndarray:
    """isin(x, test_elements, /, *, invert=False)

    Element-wise membership test: for each element of `x`, reports whether
    it occurs in `test_elements`.

    Args:
        x (Union[usm_ndarray, bool, int, float, complex]):
            input element or elements.
        test_elements (Union[usm_ndarray, bool, int, float, complex]):
            elements against which to test each value of `x`.
        invert (Optional[bool]):
            if `True`, the output results are inverted, i.e., are equivalent to
            testing `x not in test_elements` for each element of `x`.
            Default: `False`.

    Returns:
        usm_ndarray:
            an array of the inclusion test results. The returned array has a
            boolean data type and the same shape as `x`.

    Raises:
        ExecutionPlacementError: if neither argument carries a queue, or
            the two queues cannot be reconciled.
        TypeError: if `invert` is not a `bool`.
        ValueError: if an operand has an unsupported data type.
    """
    x_q, x_usm_type = _get_queue_usm_type(x)
    test_q, test_usm_type = _get_queue_usm_type(test_elements)
    if x_q is None and test_q is None:
        raise du.ExecutionPlacementError(
            "Execution placement can not be unambiguously inferred "
            "from input arguments. "
            "One of the arguments must represent USM allocation and "
            "expose `__sycl_usm_array_interface__` property"
        )
    if x_q is None:
        exec_q, res_usm_type = test_q, test_usm_type
    elif test_q is None:
        exec_q, res_usm_type = x_q, x_usm_type
    else:
        exec_q = du.get_execution_queue((x_q, test_q))
        if exec_q is None:
            raise du.ExecutionPlacementError(
                "Execution placement can not be unambiguously inferred "
                "from input arguments."
            )
        res_usm_type = du.get_coerced_usm_type((x_usm_type, test_usm_type))
    du.validate_usm_type(res_usm_type, allow_none=False)
    sycl_dev = exec_q.sycl_device

    if not isinstance(invert, bool):
        raise TypeError(
            "`invert` keyword argument must be of boolean type, "
            f"got {type(invert)}"
        )

    x_dt = _get_dtype(x, sycl_dev)
    test_dt = _get_dtype(test_elements, sycl_dev)
    if not (_validate_dtype(x_dt) and _validate_dtype(test_dt)):
        raise ValueError("Operands have unsupported data types")

    x_sh = _get_shape(x)
    if isinstance(test_elements, dpt.usm_ndarray) and test_elements.size == 0:
        # nothing to match against: the answer is constant over `x`
        constant_fill = dpt_ext.ones if invert else dpt_ext.zeros
        return constant_fill(
            x_sh, dtype=dpt.bool, usm_type=res_usm_type, sycl_queue=exec_q
        )

    dt1, dt2 = _resolve_weak_types_all_py_ints(x_dt, test_dt, sycl_dev)
    dt = _to_device_supported_dtype(dpt_ext.result_type(dt1, dt2), sycl_dev)

    x_arr = (
        x
        if isinstance(x, dpt.usm_ndarray)
        else dpt_ext.asarray(
            x, dtype=dt1, usm_type=res_usm_type, sycl_queue=exec_q
        )
    )
    test_arr = (
        test_elements
        if isinstance(test_elements, dpt.usm_ndarray)
        else dpt_ext.asarray(
            test_elements, dtype=dt2, usm_type=res_usm_type, sycl_queue=exec_q
        )
    )

    _manager = du.SequentialOrderManager[exec_q]
    pending = _manager.submitted_events

    # bring both operands to the common data type `dt` where needed
    x_buf = x_arr
    if x_dt != dt:
        x_buf = _empty_like_orderK(x_arr, dt, res_usm_type, exec_q)
        ht_ev, cast_ev = _copy_usm_ndarray_into_usm_ndarray(
            src=x_arr, dst=x_buf, sycl_queue=exec_q, depends=pending
        )
        _manager.add_event_pair(ht_ev, cast_ev)

    test_buf = test_arr
    if test_dt != dt:
        # copy into C-contiguous memory, because the array will be flattened
        test_buf = dpt_ext.empty_like(
            test_arr, dtype=dt, order="C", usm_type=res_usm_type
        )
        ht_ev, cast_ev = _copy_usm_ndarray_into_usm_ndarray(
            src=test_arr, dst=test_buf, sycl_queue=exec_q, depends=pending
        )
        _manager.add_event_pair(ht_ev, cast_ev)

    # the candidates are flattened and sorted before being handed to `_isin`
    test_buf = dpt_ext.sort(dpt_ext.reshape(test_buf, -1))

    dst = dpt_ext.empty_like(
        x_buf, dtype=dpt.bool, usm_type=res_usm_type, order="C"
    )

    ht_ev, isin_ev = _isin(
        needles=x_buf,
        hay=test_buf,
        dst=dst,
        sycl_queue=exec_q,
        invert=invert,
        depends=_manager.submitted_events,
    )
    _manager.add_event_pair(ht_ev, isin_ev)
    return dst
def sort(x, /, *, axis=-1, descending=False, stable=True, kind=None):
    """sort(x, axis=-1, descending=False, stable=True)

    Produces a sorted copy of the input array `x`.

    Args:
        x (usm_ndarray):
            input array.
        axis (Optional[int]):
            axis along which to sort. If set to `-1`, the function
            must sort along the last axis. Default: `-1`.
        descending (Optional[bool]):
            sort order. If `True`, the array must be sorted in descending
            order (by value). If `False`, the array must be sorted in
            ascending order (by value). Default: `False`.
        stable (Optional[bool]):
            sort stability. If `True`, the returned array must maintain the
            relative order of `x` values which compare as equal. If `False`,
            the returned array may or may not maintain the relative order of
            `x` values which compare as equal. Default: `True`.
        kind (Optional[Literal["stable", "mergesort", "radixsort"]]):
            Sorting algorithm. The default is `"stable"`, which uses parallel
            merge-sort or parallel radix-sort algorithms depending on the
            array data type.
    Returns:
        usm_ndarray:
            a sorted array. The returned array has the same data type and
            the same shape as the input array `x`.

    Raises:
        TypeError: if `x` is not a `usm_ndarray`.
        ValueError: if `kind` is not a supported name, or `"radixsort"` is
            requested for a data type it does not support.
    """
    if not isinstance(x, dpt.usm_ndarray):
        raise TypeError(
            f"Expected type dpctl.tensor.usm_ndarray, got {type(x)}"
        )
    nd = x.ndim
    if nd == 0:
        # `axis` is still validated, as if `x` had one dimension
        normalize_axis_index(axis, ndim=1, msg_prefix="axis")
        return dpt_ext.copy(x, order="C")

    axis = normalize_axis_index(axis, ndim=nd, msg_prefix="axis")
    sorts_last_axis = axis + 1 == nd
    if sorts_last_axis:
        arr = x
    else:
        # move the axis being sorted to the trailing position
        perm = [d for d in range(nd) if d != axis]
        perm.append(axis)
        arr = dpt_ext.permute_dims(x, perm)

    if kind is None:
        kind = "stable"
    if not (
        isinstance(kind, str) and kind in ("stable", "radixsort", "mergesort")
    ):
        raise ValueError(
            "Unsupported kind value. Expected 'stable', 'mergesort', "
            f"or 'radixsort', but got '{kind}'"
        )

    if kind == "mergesort":
        use_radix = False
    elif kind == "radixsort":
        if not _radix_sort_dtype_supported(x.dtype.num):
            raise ValueError(f"Radix sort is not supported for {x.dtype}")
        use_radix = True
    else:
        # "stable": radix sort for these data types, merge sort otherwise
        use_radix = x.dtype in [
            dpt.bool,
            dpt.uint8,
            dpt.int8,
            dpt.int16,
            dpt.uint16,
        ]
    if use_radix:
        impl_fn = _get_radixsort_impl_fn(descending)
    else:
        impl_fn = _get_mergesort_impl_fn(descending)

    exec_q = x.sycl_queue
    _manager = du.SequentialOrderManager[exec_q]
    kernel_src = arr
    kernel_deps = _manager.submitted_events
    if not arr.flags.c_contiguous:
        # stage a C-contiguous copy and sort that copy instead
        kernel_src = dpt_ext.empty_like(arr, order="C")
        ht_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
            src=arr, dst=kernel_src, sycl_queue=exec_q, depends=kernel_deps
        )
        _manager.add_event_pair(ht_ev, copy_ev)
        kernel_deps = [copy_ev]
    res = dpt_ext.empty_like(arr, order="C")
    ht_ev, impl_ev = impl_fn(
        src=kernel_src,
        trailing_dims_to_sort=1,
        dst=res,
        sycl_queue=exec_q,
        depends=kernel_deps,
    )
    _manager.add_event_pair(ht_ev, impl_ev)

    if sorts_last_axis:
        return res
    # undo the permutation so the result has the axis order of `x`
    inv_perm = [0] * nd
    for position, dim in enumerate(perm):
        inv_perm[dim] = position
    return dpt_ext.permute_dims(res, inv_perm)
If set to `-1`, the function + must sort along the last axis. Default: `-1`. + descending (Optional[bool]): + sort order. If `True`, the array must be sorted in descending + order (by value). If `False`, the array must be sorted in + ascending order (by value). Default: `False`. + stable (Optional[bool]): + sort stability. If `True`, the returned array must maintain the + relative order of `x` values which compare as equal. If `False`, + the returned array may or may not maintain the relative order of + `x` values which compare as equal. Default: `True`. + kind (Optional[Literal["stable", "mergesort", "radixsort"]]): + Sorting algorithm. The default is `"stable"`, which uses parallel + merge-sort or parallel radix-sort algorithms depending on the + array data type. + + Returns: + usm_ndarray: + an array of indices. The returned array has the same shape as + the input array `x`. The return array has default array index + data type. + """ + if not isinstance(x, dpt.usm_ndarray): + raise TypeError( + f"Expected type dpctl.tensor.usm_ndarray, got {type(x)}" + ) + nd = x.ndim + if nd == 0: + axis = normalize_axis_index(axis, ndim=1, msg_prefix="axis") + return dpt_ext.zeros_like( + x, dtype=ti.default_device_index_type(x.sycl_queue), order="C" + ) + else: + axis = normalize_axis_index(axis, ndim=nd, msg_prefix="axis") + a1 = axis + 1 + if a1 == nd: + perm = list(range(nd)) + arr = x + else: + perm = [i for i in range(nd) if i != axis] + [ + axis, + ] + arr = dpt_ext.permute_dims(x, perm) + if kind is None: + kind = "stable" + if not isinstance(kind, str) or kind not in [ + "stable", + "radixsort", + "mergesort", + ]: + raise ValueError( + "Unsupported kind value. 
Expected 'stable', 'mergesort', " + f"or 'radixsort', but got '{kind}'" + ) + if kind == "mergesort": + impl_fn = _get_mergeargsort_impl_fn(descending) + elif kind == "radixsort": + if _radix_sort_dtype_supported(x.dtype.num): + impl_fn = _get_radixargsort_impl_fn(descending) + else: + raise ValueError(f"Radix sort is not supported for {x.dtype}") + else: + dt = x.dtype + if dt in [dpt.bool, dpt.uint8, dpt.int8, dpt.int16, dpt.uint16]: + impl_fn = _get_radixargsort_impl_fn(descending) + else: + impl_fn = _get_mergeargsort_impl_fn(descending) + exec_q = x.sycl_queue + _manager = du.SequentialOrderManager[exec_q] + dep_evs = _manager.submitted_events + index_dt = ti.default_device_index_type(exec_q) + if arr.flags.c_contiguous: + res = dpt_ext.empty_like(arr, dtype=index_dt, order="C") + ht_ev, impl_ev = impl_fn( + src=arr, + trailing_dims_to_sort=1, + dst=res, + sycl_queue=exec_q, + depends=dep_evs, + ) + _manager.add_event_pair(ht_ev, impl_ev) + else: + tmp = dpt_ext.empty_like(arr, order="C") + ht_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( + src=arr, dst=tmp, sycl_queue=exec_q, depends=dep_evs + ) + _manager.add_event_pair(ht_ev, copy_ev) + res = dpt_ext.empty_like(arr, dtype=index_dt, order="C") + ht_ev, impl_ev = impl_fn( + src=tmp, + trailing_dims_to_sort=1, + dst=res, + sycl_queue=exec_q, + depends=[copy_ev], + ) + _manager.add_event_pair(ht_ev, impl_ev) + if a1 != nd: + inv_perm = sorted(range(nd), key=lambda d: perm[d]) + res = dpt_ext.permute_dims(res, inv_perm) + return res + + +def _get_top_k_largest(mode): + modes = {"largest": True, "smallest": False} + try: + return modes[mode] + except KeyError: + raise ValueError( + f"`mode` must be `largest` or `smallest`. Got `{mode}`." 
+ ) + + +class TopKResult(NamedTuple): + values: dpt.usm_ndarray + indices: dpt.usm_ndarray + + +def top_k(x, k, /, *, axis=None, mode="largest"): + """top_k(x, k, axis=None, mode="largest") + + Returns the `k` largest or smallest values and their indices in the input + array `x` along the specified axis `axis`. + + Args: + x (usm_ndarray): + input array. + k (int): + number of elements to find. Must be a positive integer value. + axis (Optional[int]): + axis along which to search. If `None`, the search will be performed + over the flattened array. Default: ``None``. + mode (Literal["largest", "smallest"]): + search mode. Must be one of the following modes: + + - `"largest"`: return the `k` largest elements. + - `"smallest"`: return the `k` smallest elements. + + Default: `"largest"`. + + Returns: + tuple[usm_ndarray, usm_ndarray] + a namedtuple `(values, indices)` whose + + * first element `values` will be an array containing the `k` + largest or smallest elements of `x`. The array has the same data + type as `x`. If `axis` was `None`, `values` will be a + one-dimensional array with shape `(k,)` and otherwise, `values` + will have shape `x.shape[:axis] + (k,) + x.shape[axis+1:]` + * second element `indices` will be an array containing indices of + `x` that result in `values`. The array will have the same shape + as `values` and will have the default array index data type. 
+ """ + largest = _get_top_k_largest(mode) + if not isinstance(x, dpt.usm_ndarray): + raise TypeError( + f"Expected type dpctl.tensor.usm_ndarray, got {type(x)}" + ) + + k = operator.index(k) + if k < 0: + raise ValueError("`k` must be a positive integer value") + + nd = x.ndim + if axis is None: + sz = x.size + if nd == 0: + if k > 1: + raise ValueError(f"`k`={k} is out of bounds 1") + return TopKResult( + dpt_ext.copy(x, order="C"), + dpt_ext.zeros_like( + x, dtype=ti.default_device_index_type(x.sycl_queue) + ), + ) + arr = x + n_search_dims = None + res_sh = k + else: + axis = normalize_axis_index(axis, ndim=nd, msg_prefix="axis") + sz = x.shape[axis] + a1 = axis + 1 + if a1 == nd: + perm = list(range(nd)) + arr = x + else: + perm = [i for i in range(nd) if i != axis] + [ + axis, + ] + arr = dpt_ext.permute_dims(x, perm) + n_search_dims = 1 + res_sh = arr.shape[: nd - 1] + (k,) + + if k > sz: + raise ValueError(f"`k`={k} is out of bounds {sz}") + + exec_q = x.sycl_queue + _manager = du.SequentialOrderManager[exec_q] + dep_evs = _manager.submitted_events + + res_usm_type = arr.usm_type + if arr.flags.c_contiguous: + vals = dpt_ext.empty( + res_sh, + dtype=arr.dtype, + usm_type=res_usm_type, + order="C", + sycl_queue=exec_q, + ) + inds = dpt_ext.empty( + res_sh, + dtype=ti.default_device_index_type(exec_q), + usm_type=res_usm_type, + order="C", + sycl_queue=exec_q, + ) + ht_ev, impl_ev = _topk( + src=arr, + trailing_dims_to_search=n_search_dims, + k=k, + largest=largest, + vals=vals, + inds=inds, + sycl_queue=exec_q, + depends=dep_evs, + ) + _manager.add_event_pair(ht_ev, impl_ev) + else: + tmp = dpt_ext.empty_like(arr, order="C") + ht_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( + src=arr, dst=tmp, sycl_queue=exec_q, depends=dep_evs + ) + _manager.add_event_pair(ht_ev, copy_ev) + vals = dpt_ext.empty( + res_sh, + dtype=arr.dtype, + usm_type=res_usm_type, + order="C", + sycl_queue=exec_q, + ) + inds = dpt_ext.empty( + res_sh, + 
dtype=ti.default_device_index_type(exec_q), + usm_type=res_usm_type, + order="C", + sycl_queue=exec_q, + ) + ht_ev, impl_ev = _topk( + src=tmp, + trailing_dims_to_search=n_search_dims, + k=k, + largest=largest, + vals=vals, + inds=inds, + sycl_queue=exec_q, + depends=[copy_ev], + ) + _manager.add_event_pair(ht_ev, impl_ev) + if axis is not None and a1 != nd: + inv_perm = sorted(range(nd), key=lambda d: perm[d]) + vals = dpt_ext.permute_dims(vals, inv_perm) + inds = dpt_ext.permute_dims(inds, inv_perm) + + return TopKResult(vals, inds) diff --git a/dpctl_ext/tensor/_utility_functions.py b/dpctl_ext/tensor/_utility_functions.py new file mode 100644 index 000000000000..a122ac3d6cea --- /dev/null +++ b/dpctl_ext/tensor/_utility_functions.py @@ -0,0 +1,509 @@ +# ***************************************************************************** +# Copyright (c) 2026, Intel Corporation +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# - Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# - Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# - Neither the name of the copyright holder nor the names of its contributors +# may be used to endorse or promote products derived from this software +# without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. 
def _boolean_reduction(x, axis, keepdims, func):
    """
    Shared driver behind the boolean reductions `all` and `any`.

    Args:
        x (usm_ndarray): Array to reduce.
        axis (Optional[Union[int, Tuple[int,...], List[int]]]): Axis (or
            axes) to reduce over. `None` reduces over every dimension.
        keepdims (bool): If true, the reduced axes are kept in the result
            as dimensions of size one.
        func (callable): Reduction launcher. It is called with keyword
            arguments `src`, `trailing_dims_to_reduce`, `dst`,
            `sycl_queue` and `depends`, and its return value is unpacked
            into a pair of events.

    Returns:
        usm_ndarray:
            An array with a data type of `bool` holding the reduction.

    Raises:
        TypeError: If `x` is not a `dpctl.tensor.usm_ndarray`.
    """
    if not isinstance(x, dpt.usm_ndarray):
        raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x)}")

    ndim = x.ndim
    reduce_everything = axis is None
    if reduce_everything:
        n_reduced = ndim
    else:
        if not isinstance(axis, (tuple, list)):
            axis = (axis,)
        axis = normalize_axis_tuple(axis, ndim, "axis")
        n_reduced = len(axis)

    # Nothing to reduce: either a 0-d input with axis=None, or axis=().
    if n_reduced == 0:
        return dpt_ext.astype(x, dpt.bool)

    if reduce_everything:
        perm = list(range(ndim))
        src = x
        out_shape = ()
    else:
        # Move the reduced axes to the trailing positions.
        perm = [d for d in range(ndim) if d not in axis] + list(axis)
        src = dpt_ext.permute_dims(x, perm)
        out_shape = src.shape[: ndim - n_reduced]

    queue = x.sycl_queue
    out_usm_type = x.usm_type

    order_manager = du.SequentialOrderManager[queue]
    pending_events = order_manager.submitted_events
    # The scratch array is always int32 in usm-device memory to ensure
    # that atomic updates are supported.
    scratch = dpt_ext.empty(
        out_shape,
        dtype=dpt.int32,
        usm_type="device",
        sycl_queue=queue,
    )
    reduce_host_ev, reduce_ev = func(
        src=src,
        trailing_dims_to_reduce=n_reduced,
        dst=scratch,
        sycl_queue=queue,
        depends=pending_events,
    )
    order_manager.add_event_pair(reduce_host_ev, reduce_ev)

    # Convert the integer scratch values into the boolean result array.
    result = dpt_ext.empty(
        out_shape,
        dtype=dpt.bool,
        usm_type=out_usm_type,
        sycl_queue=queue,
    )
    copy_host_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
        src=scratch, dst=result, sycl_queue=queue, depends=[reduce_ev]
    )
    order_manager.add_event_pair(copy_host_ev, copy_ev)

    if not keepdims:
        return result

    # Re-insert the reduced axes as singleton dimensions at the end, then
    # undo the permutation so they land at their original positions.
    kept_shape = out_shape + (1,) * n_reduced
    inverse_perm = sorted(range(ndim), key=perm.__getitem__)
    return dpt_ext.permute_dims(
        dpt_ext.reshape(result, kept_shape), inverse_perm
    )


def all(x, /, *, axis=None, keepdims=False):
    """
    all(x, axis=None, keepdims=False)

    Tests whether all input array elements evaluate to True along a given axis.

    Args:
        x (usm_ndarray): Input array.
        axis (Optional[Union[int, Tuple[int,...]]]): Axis (or axes)
            along which the logical AND reduction is performed.
            `None` reduces over all dimensions of `x`.
            A negative axis is counted from the last dimension
            to the first.
            Default: `None`.
        keepdims (bool, optional): If `True`, the reduced axes are kept
            in the result as singleton dimensions, which makes the result
            broadcastable to the shape of the input array.
            If `False`, the reduced axes are dropped from the result.
            Default: `False`.

    Returns:
        usm_ndarray:
            An array with a data type of `bool`
            containing the results of the logical AND reduction.
    """
    return _boolean_reduction(x, axis=axis, keepdims=keepdims, func=tri._all)
def any(x, /, *, axis=None, keepdims=False):
    """
    any(x, axis=None, keepdims=False)

    Tests whether any input array elements evaluate to True along a given axis.

    Args:
        x (usm_ndarray): Input array.
        axis (Optional[Union[int, Tuple[int,...]]]): Axis (or axes)
            along which the logical OR reduction is performed.
            `None` reduces over all dimensions of `x`.
            A negative axis is counted from the last dimension
            to the first.
            Default: `None`.
        keepdims (bool, optional): If `True`, the reduced axes are kept
            in the result as singleton dimensions, which makes the result
            broadcastable to the shape of the input array.
            If `False`, the reduced axes are dropped from the result.
            Default: `False`.

    Returns:
        usm_ndarray:
            An array with a data type of `bool`
            containing the results of the logical OR reduction.
    """
    return _boolean_reduction(x, axis=axis, keepdims=keepdims, func=tri._any)


def _validate_diff_shape(sh1, sh2, axis):
    """
    Report whether an operand of shape `sh2` can be concatenated
    with an array of shape `sh1` along `axis`.

    An empty `sh2` (a scalar) is always accepted. Otherwise both shapes
    must have the same length and agree in every position except `axis`.
    """
    # Scalars are always accepted.
    if not sh2:
        return True
    if len(sh1) != len(sh2):
        return False
    # `builtins.all` is spelled out because this module defines its own
    # `all`.
    return builtins.all(
        dim1 == dim2
        for pos, (dim1, dim2) in enumerate(zip(sh1, sh2))
        if pos != axis
    )
def _resolve_diff_placement(arr, operands):
    """
    Pick the execution queue and USM type for `arr` combined with `operands`.

    Operands for which `_get_queue_usm_type` reports no queue take no part
    in the resolution. When none of them has a queue, the queue and USM
    type of `arr` are used directly.

    Returns:
        tuple: `(exec_q, coerced_usm_type)`.

    Raises:
        dpctl.utils.ExecutionPlacementError: If no single execution queue
            can be inferred from the participating queues.
    """
    queues = [arr.sycl_queue]
    usm_types = [arr.usm_type]
    for operand in operands:
        operand_q, operand_usm_type = _get_queue_usm_type(operand)
        if operand_q is not None:
            queues.append(operand_q)
            usm_types.append(operand_usm_type)

    if len(queues) == 1:
        exec_q, coerced_usm_type = queues[0], usm_types[0]
    else:
        exec_q = du.get_execution_queue(tuple(queues))
        if exec_q is None:
            raise du.ExecutionPlacementError(
                "Execution placement can not be unambiguously inferred "
                "from input arguments."
            )
        coerced_usm_type = du.get_coerced_usm_type(tuple(usm_types))
    du.validate_usm_type(coerced_usm_type, allow_none=False)
    return exec_q, coerced_usm_type


def _concat_diff_input(arr, axis, prepend, append):
    """
    Concatenates `arr`, `prepend` and, `append` along `axis`,
    where `arr` is an array and `prepend` and `append` are
    any mixture of arrays and scalars.

    Args:
        arr (usm_ndarray): Array the operands are attached to.
        axis (int): Axis of concatenation. It is used directly for
            slicing `arr.shape`.
        prepend: Operand placed before `arr`, or `None` to skip it.
        append: Operand placed after `arr`, or `None` to skip it.

    Returns:
        usm_ndarray: `arr` itself when both `prepend` and `append` are
        `None`, otherwise the concatenation of the supplied operands
        with `arr`.

    Raises:
        dpctl.utils.ExecutionPlacementError: If the execution queue can
            not be inferred from the inputs.
        TypeError: If the shape of an operand can not be inferred.
        ValueError: If an operand has a shape incompatible with `arr`
            or an unsupported data type.
    """
    supplied = [
        (name, value)
        for name, value in (("prepend", prepend), ("append", append))
        if value is not None
    ]
    if not supplied:
        return arr

    names = [name for name, _ in supplied]
    values = [value for _, value in supplied]
    # Two operands use the plural error messages and the two-weak-types
    # resolution; a single operand uses the singular counterparts.
    both = len(supplied) == 2

    exec_q, coerced_usm_type = _resolve_diff_placement(arr, values)

    arr_shape = arr.shape
    shapes = [_get_shape(value) for value in values]
    # `builtins.all` is spelled out because this module defines its own
    # `all`.
    if not builtins.all(isinstance(sh, (tuple, list)) for sh in shapes):
        if both:
            raise TypeError(
                "Shape of arguments can not be inferred. "
                "Arguments are expected to be "
                "lists, tuples, or both"
            )
        raise TypeError(
            "Shape of argument can not be inferred. "
            "Argument is expected to be a "
            "list or tuple"
        )
    for name, shape in zip(names, shapes):
        if not _validate_diff_shape(arr_shape, shape, axis):
            raise ValueError(
                f"`diff` argument `{name}` with shape {shape} is "
                f"invalid for first input with shape {arr_shape}"
            )

    sycl_dev = exec_q.sycl_device
    arr_dtype = arr.dtype
    dtypes = [_get_dtype(value, sycl_dev) for value in values]
    if not builtins.all(_validate_dtype(dt) for dt in dtypes):
        if both:
            raise ValueError("Operands have unsupported data types")
        raise ValueError("Operand has unsupported data type")
    if both:
        first_dtype, second_dtype = _resolve_one_strong_two_weak_types(
            arr_dtype, dtypes[0], dtypes[1], sycl_dev
        )
        dtypes = [first_dtype, second_dtype]
    else:
        dtypes = [
            _resolve_one_strong_one_weak_types(
                arr_dtype, dtypes[0], sycl_dev
            )
        ]

    # Operands that are already usm_ndarray are used as they are; anything
    # else is converted with the resolved dtype on the execution queue.
    arrays = [
        (
            value
            if isinstance(value, dpt.usm_ndarray)
            else dpt_ext.asarray(
                value,
                dtype=dt,
                usm_type=coerced_usm_type,
                sycl_queue=exec_q,
            )
        )
        for value, dt in zip(values, dtypes)
    ]

    # An operand with an empty shape is broadcast to the shape of `arr`
    # with extent one along `axis`.
    pieces = {}
    for name, operand, shape in zip(names, arrays, shapes):
        if not shape:
            operand = dpt_ext.broadcast_to(
                operand, arr_shape[:axis] + (1,) + arr_shape[axis + 1 :]
            )
        pieces[name] = operand

    parts = []
    if "prepend" in pieces:
        parts.append(pieces["prepend"])
    parts.append(arr)
    if "append" in pieces:
        parts.append(pieces["append"])
    return dpt_ext.concat(tuple(parts), axis=axis)
def diff(x, /, *, axis=-1, n=1, prepend=None, append=None):
    """
    Calculates the `n`-th discrete forward difference of `x` along `axis`.

    Args:
        x (usm_ndarray):
            input array.
        axis (int):
            axis along which to compute the difference. A valid axis must be on
            the interval `[-N, N)`, where `N` is the rank (number of
            dimensions) of `x`.
            Default: `-1`
        n (int):
            number of times to recursively compute the difference. Must be
            non-negative; for `n == 0` the input, with `prepend` and
            `append` attached when given, is returned without differencing.
            Default: `1`.
        prepend (Union[usm_ndarray, bool, int, float, complex]):
            value or values to prepend to the specified axis before taking the
            difference.
            Must have the same shape as `x` except along `axis`, where it
            can have any size.
            Default: `None`.
        append (Union[usm_ndarray, bool, int, float, complex]):
            value or values to append to the specified axis before taking the
            difference.
            Must have the same shape as `x` except along `axis`, where it
            can have any size.
            Default: `None`.

    Returns:
        usm_ndarray:
            an array containing the `n`-th differences. The array will have the
            same shape as `x`, except along `axis`, which will have shape:
            ``prepend.shape[axis] + x.shape[axis] + append.shape[axis] - n``

            The data type of the returned array is determined by the Type
            Promotion Rules.

    Raises:
        TypeError: if `x` is not a `dpctl.tensor.usm_ndarray`.
        ValueError: if `n` is negative.
    """

    if not isinstance(x, dpt.usm_ndarray):
        raise TypeError(
            "Expecting dpctl.tensor.usm_ndarray type, " f"got {type(x)}"
        )
    x_nd = x.ndim
    axis = normalize_axis_index(operator.index(axis), x_nd)
    n = operator.index(n)
    if n < 0:
        # n == 0 is accepted below, so the requirement is non-negativity.
        raise ValueError(f"`n` must be non-negative, got {n}")
    arr = _concat_diff_input(x, axis, prepend, append)
    if n == 0:
        return arr
    # sl0 drops the first element along `axis`, sl1 drops the last one, so
    # op(arr[sl0], arr[sl1]) pairs every element with its predecessor.
    sl0 = tuple(
        slice(None) if i != axis else slice(1, None) for i in range(x_nd)
    )
    sl1 = tuple(
        slice(None) if i != axis else slice(None, -1) for i in range(x_nd)
    )

    # NOTE(review): the operator is selected from `x.dtype`, not from the
    # dtype of the concatenated `arr` - confirm this is intended when
    # `prepend`/`append` are combined with a boolean `x`.
    diff_op = dpt.not_equal if x.dtype == dpt.bool else dpt.subtract
    if n > 1:
        arr_tmp0 = diff_op(arr[sl0], arr[sl1])
        arr_tmp1 = diff_op(arr_tmp0[sl0], arr_tmp0[sl1])
        n = n - 2
        if n > 0:
            # Each further difference is written into a view of the older
            # of the two buffers, trimmed by two along `axis`, and the
            # buffers then swap roles.
            sl3 = tuple(
                slice(None) if i != axis else slice(None, -2)
                for i in range(x_nd)
            )
            for _ in range(n):
                arr_tmp0_sliced = arr_tmp0[sl3]
                diff_op(arr_tmp1[sl0], arr_tmp1[sl1], out=arr_tmp0_sliced)
                arr_tmp0, arr_tmp1 = arr_tmp1, arr_tmp0_sliced
        arr = arr_tmp1
    else:
        arr = diff_op(arr[sl0], arr[sl1])
    return arr
b/dpctl_ext/tensor/libtensor/include/kernels/reductions.hpp @@ -0,0 +1,3323 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file defines kernels for tensor reduction along axis. 
+//===----------------------------------------------------------------------===// + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include "dpctl_tensor_types.hpp" +#include "utils/math_utils.hpp" +#include "utils/offset_utils.hpp" +#include "utils/sycl_alloc_utils.hpp" +#include "utils/sycl_utils.hpp" +#include "utils/type_utils.hpp" + +namespace dpctl::tensor::kernels +{ + +using dpctl::tensor::ssize_t; +namespace su_ns = dpctl::tensor::sycl_utils; + +namespace reduction_detail +{ + +inline std::size_t get_work_group_size(const sycl::device &d) +{ + // prevents running out of resources on CPU + return std::min( + 2048, d.get_info() / 2); +} + +} // namespace reduction_detail + +template +struct needs_workaround +{ + static constexpr bool value = + (std::is_same_v> && + (std::is_same_v || + std::is_same_v)) || + (__LIBSYCL_MAJOR_VERSION < 7 && std::is_same_v && + std::is_same_v>); +}; + +template +struct can_use_reduce_over_group +{ + static constexpr bool value = + sycl::has_known_identity::value && + !needs_workaround::value; +}; + +template +struct SequentialReduction +{ +private: + const argT *inp_ = nullptr; + outT *out_ = nullptr; + ReductionOp reduction_op_; + outT identity_; + InputOutputIterIndexerT inp_out_iter_indexer_; + InputRedIndexerT inp_reduced_dims_indexer_; + std::size_t reduction_max_gid_ = 0; + +public: + SequentialReduction(const argT *inp, + outT *res, + const ReductionOp &reduction_op, + const outT &identity_val, + const InputOutputIterIndexerT &arg_res_iter_indexer, + const InputRedIndexerT &arg_reduced_dims_indexer, + std::size_t reduction_size) + : inp_(inp), out_(res), reduction_op_(reduction_op), + identity_(identity_val), inp_out_iter_indexer_(arg_res_iter_indexer), + inp_reduced_dims_indexer_(arg_reduced_dims_indexer), + reduction_max_gid_(reduction_size) + { + } + + void operator()(sycl::id<1> id) const + { + + auto const &inp_out_iter_offsets_ = 
inp_out_iter_indexer_(id[0]); + const ssize_t &inp_iter_offset = + inp_out_iter_offsets_.get_first_offset(); + const ssize_t &out_iter_offset = + inp_out_iter_offsets_.get_second_offset(); + + outT red_val(identity_); + for (std::size_t m = 0; m < reduction_max_gid_; ++m) { + const ssize_t inp_reduction_offset = inp_reduced_dims_indexer_(m); + const ssize_t inp_offset = inp_iter_offset + inp_reduction_offset; + + using dpctl::tensor::type_utils::convert_impl; + outT val; + if constexpr (su_ns::IsLogicalAnd::value || + su_ns::IsLogicalOr::value) + { + val = convert_impl(inp_[inp_offset]); + } + else { + val = convert_impl(inp_[inp_offset]); + } + red_val = reduction_op_(red_val, val); + } + + out_[out_iter_offset] = red_val; + } +}; + +/* === Reduction, using sycl::reduce_over_group, and sycl::atomic_ref === */ + +/* + This kernel only works for outT with sizeof(outT) == 4, or sizeof(outT) == 8 + if the device has aspect atomic64 and only with those supported by + sycl::atomic_ref +*/ +template +struct ReductionOverGroupWithAtomicFunctor +{ +private: + const argT *inp_ = nullptr; + outT *out_ = nullptr; + ReductionOp reduction_op_; + outT identity_; + InputOutputIterIndexerT inp_out_iter_indexer_; + InputRedIndexerT inp_reduced_dims_indexer_; + std::size_t reduction_max_gid_ = 0; + std::size_t iter_gws_ = 1; + std::size_t reductions_per_wi = 16; + +public: + ReductionOverGroupWithAtomicFunctor( + const argT *data, + outT *res, + const ReductionOp &reduction_op, + const outT &identity_val, + const InputOutputIterIndexerT &arg_res_iter_indexer, + const InputRedIndexerT &arg_reduced_dims_indexer, + std::size_t reduction_size, + std::size_t iteration_size, + std::size_t reduction_size_per_wi) + : inp_(data), out_(res), reduction_op_(reduction_op), + identity_(identity_val), inp_out_iter_indexer_(arg_res_iter_indexer), + inp_reduced_dims_indexer_(arg_reduced_dims_indexer), + reduction_max_gid_(reduction_size), iter_gws_(iteration_size), + 
reductions_per_wi(reduction_size_per_wi) + { + } + + void operator()(sycl::nd_item<1> it) const + { + const std::size_t iter_gid = it.get_group(0) % iter_gws_; + const std::size_t reduction_batch_id = it.get_group(0) / iter_gws_; + + const std::size_t reduction_lid = it.get_local_id(0); + const std::size_t wg = + it.get_local_range(0); // 0 <= reduction_lid < wg + + // work-items operate over input with indices + // inp_data_id = reduction_batch_id * wg * reductions_per_wi + m * wg + // + reduction_lid + // for 0 <= m < reductions_per_wi + + const auto &inp_out_iter_offsets_ = inp_out_iter_indexer_(iter_gid); + const auto &inp_iter_offset = inp_out_iter_offsets_.get_first_offset(); + const auto &out_iter_offset = inp_out_iter_offsets_.get_second_offset(); + + outT local_red_val(identity_); + std::size_t arg_reduce_gid0 = + reduction_lid + reduction_batch_id * wg * reductions_per_wi; + std::size_t arg_reduce_gid_max = std::min( + reduction_max_gid_, arg_reduce_gid0 + reductions_per_wi * wg); + + for (std::size_t arg_reduce_gid = arg_reduce_gid0; + arg_reduce_gid < arg_reduce_gid_max; arg_reduce_gid += wg) + { + auto inp_reduction_offset = + inp_reduced_dims_indexer_(arg_reduce_gid); + auto inp_offset = inp_iter_offset + inp_reduction_offset; + + using dpctl::tensor::type_utils::convert_impl; + outT val; + if constexpr (su_ns::IsLogicalAnd::value || + su_ns::IsLogicalOr::value) + { + // handle nans + val = convert_impl(inp_[inp_offset]); + } + else { + val = convert_impl(inp_[inp_offset]); + } + + local_red_val = reduction_op_(local_red_val, val); + } + + auto work_group = it.get_group(); + // This only works if reduction_op_ is from small set of operators + outT red_val_over_wg; + if constexpr (su_ns::IsLogicalAnd::value) { + red_val_over_wg = static_cast( + sycl::all_of_group(work_group, local_red_val)); + } + else if constexpr (su_ns::IsLogicalOr::value) { + red_val_over_wg = static_cast( + sycl::any_of_group(work_group, local_red_val)); + } + else { + 
red_val_over_wg = sycl::reduce_over_group(work_group, local_red_val, + identity_, reduction_op_); + } + + if (work_group.leader()) { + sycl::atomic_ref + res_ref(out_[out_iter_offset]); + if constexpr (su_ns::IsPlus::value) { + res_ref += red_val_over_wg; + } + else if constexpr (su_ns::IsMaximum::value) { + res_ref.fetch_max(red_val_over_wg); + } + else if constexpr (su_ns::IsMinimum::value) { + res_ref.fetch_min(red_val_over_wg); + } + else if constexpr (su_ns::IsLogicalAnd::value) { + res_ref.fetch_and(red_val_over_wg); + } + else if constexpr (su_ns::IsLogicalOr::value) { + res_ref.fetch_or(red_val_over_wg); + } + else { + outT read_val = res_ref.load(); + outT new_val{}; + do { + new_val = reduction_op_(read_val, red_val_over_wg); + } while (!res_ref.compare_exchange_strong(read_val, new_val)); + } + } + } +}; + +/* === Reduction, using custom_reduce_over_group, and sycl::atomic_ref === */ + +template +struct CustomReductionOverGroupWithAtomicFunctor +{ +private: + const argT *inp_ = nullptr; + outT *out_ = nullptr; + ReductionOp reduction_op_; + outT identity_; + InputOutputIterIndexerT inp_out_iter_indexer_; + InputRedIndexerT inp_reduced_dims_indexer_; + SlmT local_mem_; + std::size_t reduction_max_gid_ = 0; + std::size_t iter_gws_ = 1; + std::size_t reductions_per_wi = 16; + +public: + CustomReductionOverGroupWithAtomicFunctor( + const argT *data, + outT *res, + const ReductionOp &reduction_op, + const outT &identity_val, + const InputOutputIterIndexerT &arg_res_iter_indexer, + const InputRedIndexerT &arg_reduced_dims_indexer, + SlmT local_mem, + std::size_t reduction_size, + std::size_t iteration_size, + std::size_t reduction_size_per_wi) + : inp_(data), out_(res), reduction_op_(reduction_op), + identity_(identity_val), inp_out_iter_indexer_(arg_res_iter_indexer), + inp_reduced_dims_indexer_(arg_reduced_dims_indexer), + local_mem_(local_mem), reduction_max_gid_(reduction_size), + iter_gws_(iteration_size), reductions_per_wi(reduction_size_per_wi) + { + } 
+ + void operator()(sycl::nd_item<1> it) const + { + const std::size_t iter_gid = it.get_group(0) % iter_gws_; + const std::size_t reduction_batch_id = it.get_group(0) / iter_gws_; + + const std::size_t reduction_lid = it.get_local_id(0); + const std::size_t wg = + it.get_local_range(0); // 0 <= reduction_lid < wg + + // work-items operate over input with indices + // inp_data_id = reduction_batch_id * wg * reductions_per_wi + m * wg + // + reduction_lid + // for 0 <= m < reductions_per_wi + + const auto &inp_out_iter_offsets_ = inp_out_iter_indexer_(iter_gid); + const auto &inp_iter_offset = inp_out_iter_offsets_.get_first_offset(); + const auto &out_iter_offset = inp_out_iter_offsets_.get_second_offset(); + + outT local_red_val(identity_); + std::size_t arg_reduce_gid0 = + reduction_lid + reduction_batch_id * wg * reductions_per_wi; + std::size_t arg_reduce_gid_max = std::min( + reduction_max_gid_, arg_reduce_gid0 + reductions_per_wi * wg); + + for (std::size_t arg_reduce_gid = arg_reduce_gid0; + arg_reduce_gid < arg_reduce_gid_max; arg_reduce_gid += wg) + { + auto inp_reduction_offset = + inp_reduced_dims_indexer_(arg_reduce_gid); + auto inp_offset = inp_iter_offset + inp_reduction_offset; + + using dpctl::tensor::type_utils::convert_impl; + outT val; + if constexpr (su_ns::IsLogicalAnd::value || + su_ns::IsLogicalOr::value) + { + // handle nans + val = convert_impl(inp_[inp_offset]); + } + else { + val = convert_impl(inp_[inp_offset]); + } + + local_red_val = reduction_op_(local_red_val, val); + } + + auto work_group = it.get_group(); + outT red_val_over_wg = su_ns::custom_reduce_over_group( + work_group, local_mem_, local_red_val, reduction_op_); + + if (work_group.leader()) { + sycl::atomic_ref + res_ref(out_[out_iter_offset]); + // retain these checks in case a reduce_over_group work-around is + // needed + if constexpr (su_ns::IsSyclPlus::value) { + res_ref += red_val_over_wg; + } + else if constexpr (su_ns::IsSyclMaximum::value) { + 
res_ref.fetch_max(red_val_over_wg); + } + else if constexpr (su_ns::IsSyclMinimum::value) { + res_ref.fetch_min(red_val_over_wg); + } + else if constexpr (su_ns::IsSyclLogicalAnd::value) { + res_ref.fetch_and(red_val_over_wg); + } + else if constexpr (su_ns::IsSyclLogicalOr::value) + { + res_ref.fetch_or(red_val_over_wg); + } + else { + outT read_val = res_ref.load(); + outT new_val{}; + do { + new_val = reduction_op_(read_val, red_val_over_wg); + } while (!res_ref.compare_exchange_strong(read_val, new_val)); + } + } + } +}; + +template +struct ReductionOverGroupNoAtomicFunctor +{ +private: + const argT *inp_ = nullptr; + outT *out_ = nullptr; + ReductionOp reduction_op_; + outT identity_; + InputOutputIterIndexerT inp_out_iter_indexer_; + InputRedIndexerT inp_reduced_dims_indexer_; + std::size_t reduction_max_gid_ = 0; + std::size_t iter_gws_ = 1; + std::size_t reductions_per_wi = 16; + +public: + ReductionOverGroupNoAtomicFunctor( + const argT *data, + outT *res, + const ReductionOp &reduction_op, + const outT &identity_val, + const InputOutputIterIndexerT &arg_res_iter_indexer, + const InputRedIndexerT &arg_reduced_dims_indexer, + std::size_t reduction_size, + std::size_t iteration_size, + std::size_t reduction_size_per_wi) + : inp_(data), out_(res), reduction_op_(reduction_op), + identity_(identity_val), inp_out_iter_indexer_(arg_res_iter_indexer), + inp_reduced_dims_indexer_(arg_reduced_dims_indexer), + reduction_max_gid_(reduction_size), iter_gws_(iteration_size), + reductions_per_wi(reduction_size_per_wi) + { + } + + void operator()(sycl::nd_item<1> it) const + { + const std::size_t reduction_lid = it.get_local_id(0); + const std::size_t wg = + it.get_local_range(0); // 0 <= reduction_lid < wg + + const std::size_t iter_gid = it.get_group(0) % iter_gws_; + const std::size_t reduction_batch_id = it.get_group(0) / iter_gws_; + const std::size_t n_reduction_groups = + it.get_group_range(0) / iter_gws_; + + // work-items operates over input with indices + // 
inp_data_id = reduction_batch_id * wg * reductions_per_wi + m * wg + // + reduction_lid + // for 0 <= m < reductions_per_wi + + const auto &inp_out_iter_offsets_ = inp_out_iter_indexer_(iter_gid); + const auto &inp_iter_offset = inp_out_iter_offsets_.get_first_offset(); + const auto &out_iter_offset = inp_out_iter_offsets_.get_second_offset(); + + outT local_red_val(identity_); + std::size_t arg_reduce_gid0 = + reduction_lid + reduction_batch_id * wg * reductions_per_wi; + for (std::size_t m = 0; m < reductions_per_wi; ++m) { + std::size_t arg_reduce_gid = arg_reduce_gid0 + m * wg; + + if (arg_reduce_gid < reduction_max_gid_) { + auto inp_reduction_offset = + inp_reduced_dims_indexer_(arg_reduce_gid); + auto inp_offset = inp_iter_offset + inp_reduction_offset; + + using dpctl::tensor::type_utils::convert_impl; + outT val; + if constexpr (su_ns::IsLogicalAnd::value || + su_ns::IsLogicalOr::value) + { + // handle nans + val = convert_impl(inp_[inp_offset]); + } + else { + val = convert_impl(inp_[inp_offset]); + } + + local_red_val = reduction_op_(local_red_val, val); + } + } + + auto work_group = it.get_group(); + // This only works if reduction_op_ is from small set of operators + outT red_val_over_wg; + if constexpr (su_ns::IsLogicalAnd::value) { + red_val_over_wg = sycl::all_of_group(work_group, local_red_val); + } + else if constexpr (su_ns::IsLogicalOr::value) { + red_val_over_wg = sycl::any_of_group(work_group, local_red_val); + } + else { + red_val_over_wg = sycl::reduce_over_group(work_group, local_red_val, + identity_, reduction_op_); + } + + if (work_group.leader()) { + // each group writes to a different memory location + out_[out_iter_offset * n_reduction_groups + reduction_batch_id] = + red_val_over_wg; + } + } +}; + +/* = Reduction, using custom_reduce_over_group and not using atomic_ref*/ + +template +struct CustomReductionOverGroupNoAtomicFunctor +{ +private: + const argT *inp_ = nullptr; + outT *out_ = nullptr; + ReductionOp reduction_op_; + outT 
identity_; + InputOutputIterIndexerT inp_out_iter_indexer_; + InputRedIndexerT inp_reduced_dims_indexer_; + SlmT local_mem_; + std::size_t reduction_max_gid_ = 0; + std::size_t iter_gws_ = 1; + std::size_t reductions_per_wi = 16; + +public: + CustomReductionOverGroupNoAtomicFunctor( + const argT *data, + outT *res, + const ReductionOp &reduction_op, + const outT &identity_val, + const InputOutputIterIndexerT &arg_res_iter_indexer, + const InputRedIndexerT &arg_reduced_dims_indexer, + SlmT local_mem, + std::size_t reduction_size, + std::size_t iteration_size, + std::size_t reduction_size_per_wi) + : inp_(data), out_(res), reduction_op_(reduction_op), + identity_(identity_val), inp_out_iter_indexer_(arg_res_iter_indexer), + inp_reduced_dims_indexer_(arg_reduced_dims_indexer), + local_mem_(local_mem), reduction_max_gid_(reduction_size), + iter_gws_(iteration_size), reductions_per_wi(reduction_size_per_wi) + { + } + + void operator()(sycl::nd_item<1> it) const + { + const std::size_t reduction_lid = it.get_local_id(0); + const std::size_t wg = + it.get_local_range(0); // 0 <= reduction_lid < wg + + const std::size_t iter_gid = it.get_group(0) % iter_gws_; + const std::size_t reduction_batch_id = it.get_group(0) / iter_gws_; + const std::size_t n_reduction_groups = + it.get_group_range(0) / iter_gws_; + + // work-items operates over input with indices + // inp_data_id = reduction_batch_id * wg * reductions_per_wi + m * wg + // + reduction_lid + // for 0 <= m < reductions_per_wi + + auto inp_out_iter_offsets_ = inp_out_iter_indexer_(iter_gid); + const auto &inp_iter_offset = inp_out_iter_offsets_.get_first_offset(); + const auto &out_iter_offset = inp_out_iter_offsets_.get_second_offset(); + + outT local_red_val(identity_); + std::size_t arg_reduce_gid0 = + reduction_lid + reduction_batch_id * wg * reductions_per_wi; + for (std::size_t m = 0; m < reductions_per_wi; ++m) { + std::size_t arg_reduce_gid = arg_reduce_gid0 + m * wg; + + if (arg_reduce_gid < 
reduction_max_gid_) { + auto inp_reduction_offset = + inp_reduced_dims_indexer_(arg_reduce_gid); + auto inp_offset = inp_iter_offset + inp_reduction_offset; + + using dpctl::tensor::type_utils::convert_impl; + outT val; + if constexpr (std::is_same_v> || + std::is_same_v>) + { + // handle nans + val = convert_impl(inp_[inp_offset]); + } + else { + val = convert_impl(inp_[inp_offset]); + } + + local_red_val = reduction_op_(local_red_val, val); + } + } + + auto work_group = it.get_group(); + // This only works if reduction_op_ is from small set of operators + outT red_val_over_wg = su_ns::custom_reduce_over_group( + work_group, local_mem_, local_red_val, reduction_op_); + + if (work_group.leader()) { + // each group writes to a different memory location + out_[out_iter_offset * n_reduction_groups + reduction_batch_id] = + red_val_over_wg; + } + } +}; + +template < + typename argTy, + typename resTy, + typename ReductionOpT, + typename InputOutputIterIndexerT, + typename ReductionIndexerT, + template + class kernel_name_token> +sycl::event + sequential_reduction(sycl::queue &exec_q, + const argTy *arg, + resTy *res, + resTy identity_val, + std::size_t iter_nelems, + std::size_t reduction_nelems, + const InputOutputIterIndexerT &in_out_iter_indexer, + const ReductionIndexerT &reduction_indexer, + const std::vector &depends) +{ + sycl::event red_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using KernelName = + class kernel_name_token; + + cgh.parallel_for( + sycl::range<1>(iter_nelems), + SequentialReduction( + arg, res, ReductionOpT(), identity_val, in_out_iter_indexer, + reduction_indexer, reduction_nelems)); + }); + + return red_ev; +} + +template +class custom_reduction_wrapper; + +template < + typename argTy, + typename resTy, + typename ReductionOpT, + typename InputOutputIterIndexerT, + typename ReductionIndexerT, + template + class kernel_name_token> +sycl::event + submit_atomic_reduction(sycl::queue &exec_q, + const argTy *arg, + 
resTy *res, + resTy identity_val, + std::size_t wg, + std::size_t iter_nelems, + std::size_t reduction_nelems, + std::size_t reductions_per_wi, + std::size_t reduction_groups, + const InputOutputIterIndexerT &in_out_iter_indexer, + const ReductionIndexerT &reduction_indexer, + const std::vector &depends) +{ + sycl::event red_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + auto globalRange = sycl::range<1>{iter_nelems * reduction_groups * wg}; + auto localRange = sycl::range<1>{wg}; + auto ndRange = sycl::nd_range<1>(globalRange, localRange); + + if constexpr (can_use_reduce_over_group::value) { + using KernelName = + class kernel_name_token; + + cgh.parallel_for( + ndRange, + ReductionOverGroupWithAtomicFunctor( + arg, res, ReductionOpT(), identity_val, in_out_iter_indexer, + reduction_indexer, reduction_nelems, iter_nelems, + reductions_per_wi)); + } + else { + using SlmT = sycl::local_accessor; + SlmT local_memory = SlmT(localRange, cgh); + + using KernelName = class custom_reduction_wrapper< + kernel_name_token>; + + cgh.parallel_for( + ndRange, + CustomReductionOverGroupWithAtomicFunctor< + argTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT, SlmT>( + arg, res, ReductionOpT(), identity_val, in_out_iter_indexer, + reduction_indexer, local_memory, reduction_nelems, + iter_nelems, reductions_per_wi)); + } + }); + return red_ev; +} + +template +class reduction_over_group_with_atomics_init_krn; + +template +class reduction_seq_krn; + +template +class reduction_over_group_with_atomics_krn; + +typedef sycl::event (*reduction_strided_impl_fn_ptr)( + sycl::queue &, + std::size_t, + std::size_t, + const char *, + char *, + int, + const ssize_t *, + ssize_t, + ssize_t, + int, + const ssize_t *, + ssize_t, + const std::vector &); + +using dpctl::tensor::sycl_utils::choose_workgroup_size; + +template +sycl::event reduction_over_group_with_atomics_strided_impl( + sycl::queue &exec_q, + std::size_t iter_nelems, // number of 
reductions (num. of rows in a + // matrix when reducing over rows) + std::size_t reduction_nelems, // size of each reduction (length of rows, + // i.e. number of columns) + const char *arg_cp, + char *res_cp, + int iter_nd, + const ssize_t *iter_shape_and_strides, + ssize_t iter_arg_offset, + ssize_t iter_res_offset, + int red_nd, + const ssize_t *reduction_shape_stride, + ssize_t reduction_arg_offset, + const std::vector &depends) +{ + const argTy *arg_tp = reinterpret_cast(arg_cp); + resTy *res_tp = reinterpret_cast(res_cp); + + static constexpr resTy identity_val = + su_ns::Identity::value; + + const sycl::device &d = exec_q.get_device(); + const auto &sg_sizes = d.get_info(); + std::size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes); + + if (reduction_nelems < wg) { + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_StridedIndexer; + using ReductionIndexerT = dpctl::tensor::offset_utils::StridedIndexer; + + const InputOutputIterIndexerT in_out_iter_indexer{ + iter_nd, iter_arg_offset, iter_res_offset, iter_shape_and_strides}; + const ReductionIndexerT reduction_indexer{red_nd, reduction_arg_offset, + reduction_shape_stride}; + + sycl::event comp_ev = + sequential_reduction( + exec_q, arg_tp, res_tp, identity_val, iter_nelems, + reduction_nelems, in_out_iter_indexer, reduction_indexer, + depends); + + return comp_ev; + } + else { + sycl::event res_init_ev = exec_q.submit([&](sycl::handler &cgh) { + using IndexerT = + dpctl::tensor::offset_utils::UnpackedStridedIndexer; + + const ssize_t *const &res_shape = iter_shape_and_strides; + const ssize_t *const &res_strides = + iter_shape_and_strides + 2 * iter_nd; + const IndexerT res_indexer(iter_nd, iter_res_offset, res_shape, + res_strides); + using InitKernelName = + class reduction_over_group_with_atomics_init_krn; + cgh.depends_on(depends); + + cgh.parallel_for( + sycl::range<1>(iter_nelems), [=](sycl::id<1> id) { + auto res_offset = res_indexer(id[0]); + res_tp[res_offset] = 
identity_val; + }); + }); + + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_StridedIndexer; + using ReductionIndexerT = dpctl::tensor::offset_utils::StridedIndexer; + + const InputOutputIterIndexerT in_out_iter_indexer{ + iter_nd, iter_arg_offset, iter_res_offset, iter_shape_and_strides}; + const ReductionIndexerT reduction_indexer{red_nd, reduction_arg_offset, + reduction_shape_stride}; + + static constexpr std::size_t preferred_reductions_per_wi = 8; + std::size_t reductions_per_wi = + (reduction_nelems < preferred_reductions_per_wi * wg) + ? std::max(1, (reduction_nelems + wg - 1) / wg) + : preferred_reductions_per_wi; + + std::size_t reduction_groups = + (reduction_nelems + reductions_per_wi * wg - 1) / + (reductions_per_wi * wg); + + sycl::event comp_ev = + submit_atomic_reduction( + exec_q, arg_tp, res_tp, identity_val, wg, iter_nelems, + reduction_nelems, reductions_per_wi, reduction_groups, + in_out_iter_indexer, reduction_indexer, {res_init_ev}); + + return comp_ev; + } +} + +// Contig + +typedef sycl::event (*reduction_contig_impl_fn_ptr)( + sycl::queue &, + std::size_t, + std::size_t, + const char *, + char *, + ssize_t, + ssize_t, + ssize_t, + const std::vector &); + +/* @brief Reduce rows in a matrix */ +template +sycl::event reduction_axis1_over_group_with_atomics_contig_impl( + sycl::queue &exec_q, + std::size_t iter_nelems, // number of reductions (num. of rows in a + // matrix when reducing over rows) + std::size_t reduction_nelems, // size of each reduction (length of rows, + // i.e. 
number of columns) + const char *arg_cp, + char *res_cp, + ssize_t iter_arg_offset, + ssize_t iter_res_offset, + ssize_t reduction_arg_offset, + const std::vector &depends) +{ + const argTy *arg_tp = reinterpret_cast(arg_cp) + + iter_arg_offset + reduction_arg_offset; + resTy *res_tp = reinterpret_cast(res_cp) + iter_res_offset; + + static constexpr resTy identity_val = + su_ns::Identity::value; + + const sycl::device &d = exec_q.get_device(); + const auto &sg_sizes = d.get_info(); + std::size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes); + + if (reduction_nelems < wg) { + using InputIterIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer; + using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + InputIterIndexerT, NoOpIndexerT>; + using ReductionIndexerT = NoOpIndexerT; + + const InputOutputIterIndexerT in_out_iter_indexer{ + InputIterIndexerT{/* size */ iter_nelems, + /* step */ reduction_nelems}, + NoOpIndexerT{}}; + static constexpr ReductionIndexerT reduction_indexer{}; + + sycl::event comp_ev = + sequential_reduction( + exec_q, arg_tp, res_tp, identity_val, iter_nelems, + reduction_nelems, in_out_iter_indexer, reduction_indexer, + depends); + + return comp_ev; + } + else { + sycl::event res_init_ev = exec_q.fill( + res_tp, resTy(identity_val), iter_nelems, depends); + + using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using RowsIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + RowsIndexerT, NoOpIndexerT>; + using ReductionIndexerT = NoOpIndexerT; + + const RowsIndexerT rows_indexer{/* size */ iter_nelems, + /* step */ reduction_nelems}; + static constexpr NoOpIndexerT result_indexer{}; + const InputOutputIterIndexerT in_out_iter_indexer{rows_indexer, + result_indexer}; + static constexpr ReductionIndexerT 
reduction_indexer{}; + + static constexpr std::size_t preferred_reductions_per_wi = 8; + std::size_t reductions_per_wi = + (reduction_nelems < preferred_reductions_per_wi * wg) + ? std::max(1, (reduction_nelems + wg - 1) / wg) + : preferred_reductions_per_wi; + + std::size_t reduction_groups = + (reduction_nelems + reductions_per_wi * wg - 1) / + (reductions_per_wi * wg); + + sycl::event comp_ev = + submit_atomic_reduction( + exec_q, arg_tp, res_tp, identity_val, wg, iter_nelems, + reduction_nelems, reductions_per_wi, reduction_groups, + in_out_iter_indexer, reduction_indexer, {res_init_ev}); + + return comp_ev; + } +} + +/* @brief Reduce columns in a matrix */ +template +sycl::event reduction_axis0_over_group_with_atomics_contig_impl( + sycl::queue &exec_q, + std::size_t iter_nelems, // number of reductions (num. of cols in a + // matrix when reducing over cols) + std::size_t reduction_nelems, // size of each reduction (length of cols, + // i.e. number of rows) + const char *arg_cp, + char *res_cp, + ssize_t iter_arg_offset, + ssize_t iter_res_offset, + ssize_t reduction_arg_offset, + const std::vector &depends) +{ + const argTy *arg_tp = reinterpret_cast(arg_cp) + + iter_arg_offset + reduction_arg_offset; + resTy *res_tp = reinterpret_cast(res_cp) + iter_res_offset; + + static constexpr resTy identity_val = + su_ns::Identity::value; + + const sycl::device &d = exec_q.get_device(); + const auto &sg_sizes = d.get_info(); + std::size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes); + + if (reduction_nelems < wg) { + using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + NoOpIndexerT, NoOpIndexerT>; + using ReductionIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer; + + const InputOutputIterIndexerT in_out_iter_indexer{NoOpIndexerT{}, + NoOpIndexerT{}}; + const ReductionIndexerT reduction_indexer{/* size */ reduction_nelems, + /* step */ 
iter_nelems}; + + sycl::event comp_ev = + sequential_reduction( + exec_q, arg_tp, res_tp, identity_val, iter_nelems, + reduction_nelems, in_out_iter_indexer, reduction_indexer, + depends); + + return comp_ev; + } + else { + sycl::event res_init_ev = exec_q.fill( + res_tp, resTy(identity_val), iter_nelems, depends); + + using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using ColsIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + NoOpIndexerT, NoOpIndexerT>; + using ReductionIndexerT = ColsIndexerT; + + static constexpr NoOpIndexerT columns_indexer{}; + static constexpr NoOpIndexerT result_indexer{}; + const InputOutputIterIndexerT in_out_iter_indexer{columns_indexer, + result_indexer}; + const ReductionIndexerT reduction_indexer{/* size */ reduction_nelems, + /* step */ iter_nelems}; + + static constexpr std::size_t preferred_reductions_per_wi = 8; + std::size_t reductions_per_wi = + (reduction_nelems < preferred_reductions_per_wi * wg) + ? 
std::max(1, (reduction_nelems + wg - 1) / wg) + : preferred_reductions_per_wi; + + std::size_t reduction_groups = + (reduction_nelems + reductions_per_wi * wg - 1) / + (reductions_per_wi * wg); + + sycl::event comp_ev = + submit_atomic_reduction( + exec_q, arg_tp, res_tp, identity_val, wg, iter_nelems, + reduction_nelems, reductions_per_wi, reduction_groups, + in_out_iter_indexer, reduction_indexer, {res_init_ev}); + + return comp_ev; + } +} + +/* = Reduction, using sycl::reduce_over_group, but not using atomic_ref = */ + +template < + typename argTy, + typename resTy, + typename ReductionOpT, + typename InputOutputIterIndexerT, + typename ReductionIndexerT, + template + class kernel_name_token> +sycl::event submit_no_atomic_reduction( + sycl::queue &exec_q, + const argTy *arg, + resTy *res, + resTy identity_val, + std::size_t wg, + std::size_t iter_nelems, + std::size_t reduction_nelems, + std::size_t reductions_per_wi, + std::size_t reduction_groups, + const InputOutputIterIndexerT &in_out_iter_indexer, + const ReductionIndexerT &reduction_indexer, + const std::vector &depends) +{ + sycl::event red_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + auto globalRange = sycl::range<1>{iter_nelems * reduction_groups * wg}; + auto localRange = sycl::range<1>{wg}; + auto ndRange = sycl::nd_range<1>(globalRange, localRange); + + if constexpr (can_use_reduce_over_group::value) { + using KernelName = + class kernel_name_token; + + cgh.parallel_for( + ndRange, + ReductionOverGroupNoAtomicFunctor( + arg, res, ReductionOpT(), identity_val, in_out_iter_indexer, + reduction_indexer, reduction_nelems, iter_nelems, + reductions_per_wi)); + } + else { + using SlmT = sycl::local_accessor; + SlmT local_memory = SlmT(localRange, cgh); + using KernelName = class custom_reduction_wrapper< + kernel_name_token>; + + cgh.parallel_for( + ndRange, + CustomReductionOverGroupNoAtomicFunctor< + argTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT, 
SlmT>( + arg, res, ReductionOpT(), identity_val, in_out_iter_indexer, + reduction_indexer, local_memory, reduction_nelems, + iter_nelems, reductions_per_wi)); + } + }); + return red_ev; +} + +template +class reduction_over_group_temps_krn; + +typedef sycl::event (*reduction_strided_impl_fn_ptr)( + sycl::queue &, + std::size_t, + std::size_t, + const char *, + char *, + int, + const ssize_t *, + ssize_t, + ssize_t, + int, + const ssize_t *, + ssize_t, + const std::vector &); + +template +class reduction_over_group_temps_empty_krn; + +template +sycl::event reduction_over_group_temps_strided_impl( + sycl::queue &exec_q, + std::size_t iter_nelems, // number of reductions (num. of rows in a + // matrix when reducing over rows) + std::size_t reduction_nelems, // size of each reduction (length of rows, + // i.e. number of columns) + const char *arg_cp, + char *res_cp, + int iter_nd, + const ssize_t *iter_shape_and_strides, + ssize_t iter_arg_offset, + ssize_t iter_res_offset, + int red_nd, + const ssize_t *reduction_shape_stride, + ssize_t reduction_arg_offset, + const std::vector &depends) +{ + const argTy *arg_tp = reinterpret_cast(arg_cp); + resTy *res_tp = reinterpret_cast(res_cp); + + static constexpr resTy identity_val = + su_ns::Identity::value; + + if (reduction_nelems == 0) { + sycl::event res_init_ev = exec_q.submit([&](sycl::handler &cgh) { + using IndexerT = + dpctl::tensor::offset_utils::UnpackedStridedIndexer; + + const ssize_t *const &res_shape = iter_shape_and_strides; + const ssize_t *const &res_strides = + iter_shape_and_strides + 2 * iter_nd; + const IndexerT res_indexer(iter_nd, iter_res_offset, res_shape, + res_strides); + using InitKernelName = + class reduction_over_group_temps_empty_krn; + cgh.depends_on(depends); + + cgh.parallel_for( + sycl::range<1>(iter_nelems), [=](sycl::id<1> id) { + auto res_offset = res_indexer(id[0]); + res_tp[res_offset] = identity_val; + }); + }); + + return res_init_ev; + } + + const sycl::device &d = 
exec_q.get_device(); + const auto &sg_sizes = d.get_info(); + std::size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes); + + if (reduction_nelems < wg) { + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_StridedIndexer; + using ReductionIndexerT = dpctl::tensor::offset_utils::StridedIndexer; + + const InputOutputIterIndexerT in_out_iter_indexer{ + iter_nd, iter_arg_offset, iter_res_offset, iter_shape_and_strides}; + const ReductionIndexerT reduction_indexer{red_nd, reduction_arg_offset, + reduction_shape_stride}; + + sycl::event comp_ev = + sequential_reduction( + exec_q, arg_tp, res_tp, identity_val, iter_nelems, + reduction_nelems, in_out_iter_indexer, reduction_indexer, + depends); + + return comp_ev; + } + + static constexpr std::size_t preferred_reductions_per_wi = 8; + // prevents running out of resources on CPU + std::size_t max_wg = reduction_detail::get_work_group_size(d); + + std::size_t reductions_per_wi(preferred_reductions_per_wi); + if (reduction_nelems <= preferred_reductions_per_wi * max_wg) { + // Perform reduction using one work-group per iteration, + // can output directly to res + + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_StridedIndexer; + using ReductionIndexerT = dpctl::tensor::offset_utils::StridedIndexer; + + const InputOutputIterIndexerT in_out_iter_indexer{ + iter_nd, iter_arg_offset, iter_res_offset, iter_shape_and_strides}; + const ReductionIndexerT reduction_indexer{red_nd, reduction_arg_offset, + reduction_shape_stride}; + + if (iter_nelems == 1) { + // increase GPU occupancy + wg = max_wg; + } + reductions_per_wi = + std::max(1, (reduction_nelems + wg - 1) / wg); + + std::size_t reduction_groups = + (reduction_nelems + reductions_per_wi * wg - 1) / + (reductions_per_wi * wg); + assert(reduction_groups == 1); + + sycl::event comp_ev = submit_no_atomic_reduction< + argTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT, 
reduction_over_group_temps_krn>( + exec_q, arg_tp, res_tp, identity_val, wg, iter_nelems, + reduction_nelems, reductions_per_wi, reduction_groups, + in_out_iter_indexer, reduction_indexer, depends); + + return comp_ev; + } + else { + // more than one work-group is needed, requires a temporary + std::size_t reduction_groups = + (reduction_nelems + preferred_reductions_per_wi * wg - 1) / + (preferred_reductions_per_wi * wg); + assert(reduction_groups > 1); + + std::size_t second_iter_reduction_groups_ = + (reduction_groups + preferred_reductions_per_wi * wg - 1) / + (preferred_reductions_per_wi * wg); + + const std::size_t tmp_alloc_size = + iter_nelems * (reduction_groups + second_iter_reduction_groups_); + auto tmp_owner = dpctl::tensor::alloc_utils::smart_malloc_device( + tmp_alloc_size, exec_q); + + resTy *partially_reduced_tmp = tmp_owner.get(); + resTy *partially_reduced_tmp2 = + partially_reduced_tmp + reduction_groups * iter_nelems; + ; + + sycl::event first_reduction_ev; + { + using InputIndexerT = dpctl::tensor::offset_utils::StridedIndexer; + using ResIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + InputIndexerT, ResIndexerT>; + using ReductionIndexerT = + dpctl::tensor::offset_utils::StridedIndexer; + + // Only 2*iter_nd entries describing shape and strides of + // iterated dimensions of input array from + // iter_shape_and_strides are going to be accessed by + // inp_indexer + const InputIndexerT inp_indexer(iter_nd, iter_arg_offset, + iter_shape_and_strides); + static constexpr ResIndexerT noop_tmp_indexer{}; + + const InputOutputIterIndexerT in_out_iter_indexer{inp_indexer, + noop_tmp_indexer}; + const ReductionIndexerT reduction_indexer{ + red_nd, reduction_arg_offset, reduction_shape_stride}; + + first_reduction_ev = submit_no_atomic_reduction< + argTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT, reduction_over_group_temps_krn>( 
+ exec_q, arg_tp, partially_reduced_tmp, identity_val, wg, + iter_nelems, reduction_nelems, preferred_reductions_per_wi, + reduction_groups, in_out_iter_indexer, reduction_indexer, + depends); + } + + std::size_t remaining_reduction_nelems = reduction_groups; + + resTy *temp_arg = partially_reduced_tmp; + resTy *temp2_arg = partially_reduced_tmp2; + sycl::event dependent_ev = first_reduction_ev; + + while (remaining_reduction_nelems > + preferred_reductions_per_wi * max_wg) { + std::size_t reduction_groups_ = + (remaining_reduction_nelems + preferred_reductions_per_wi * wg - + 1) / + (preferred_reductions_per_wi * wg); + assert(reduction_groups_ > 1); + + // keep reducing + sycl::event partial_reduction_ev; + { + using InputIndexerT = + dpctl::tensor::offset_utils::Strided1DIndexer; + using ResIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + InputIndexerT, ResIndexerT>; + using ReductionIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + + const InputIndexerT inp_indexer{/* size */ iter_nelems, + /* step */ reduction_groups_}; + static constexpr ResIndexerT res_iter_indexer{}; + + const InputOutputIterIndexerT in_out_iter_indexer{ + inp_indexer, res_iter_indexer}; + static constexpr ReductionIndexerT reduction_indexer{}; + + partial_reduction_ev = submit_no_atomic_reduction< + resTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT, reduction_over_group_temps_krn>( + exec_q, temp_arg, temp2_arg, identity_val, wg, iter_nelems, + remaining_reduction_nelems, preferred_reductions_per_wi, + reduction_groups_, in_out_iter_indexer, reduction_indexer, + {dependent_ev}); + } + + remaining_reduction_nelems = reduction_groups_; + std::swap(temp_arg, temp2_arg); + dependent_ev = std::move(partial_reduction_ev); + } + + // final reduction to res + using InputIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer; + using ResIndexerT = 
dpctl::tensor::offset_utils::UnpackedStridedIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + InputIndexerT, ResIndexerT>; + using ReductionIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + + const InputIndexerT inp_indexer{/* size */ iter_nelems, + /* step */ remaining_reduction_nelems}; + const ResIndexerT res_iter_indexer{ + iter_nd, iter_res_offset, + /* shape */ iter_shape_and_strides, + /* strides */ iter_shape_and_strides + 2 * iter_nd}; + + const InputOutputIterIndexerT in_out_iter_indexer{inp_indexer, + res_iter_indexer}; + static constexpr ReductionIndexerT reduction_indexer{}; + + wg = max_wg; + reductions_per_wi = std::max( + 1, (remaining_reduction_nelems + wg - 1) / wg); + + reduction_groups = + (remaining_reduction_nelems + reductions_per_wi * wg - 1) / + (reductions_per_wi * wg); + assert(reduction_groups == 1); + + sycl::event final_reduction_ev = submit_no_atomic_reduction< + resTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT, reduction_over_group_temps_krn>( + exec_q, temp_arg, res_tp, identity_val, wg, iter_nelems, + remaining_reduction_nelems, reductions_per_wi, reduction_groups, + in_out_iter_indexer, reduction_indexer, {dependent_ev}); + + sycl::event cleanup_host_task_event = + dpctl::tensor::alloc_utils::async_smart_free( + exec_q, {final_reduction_ev}, tmp_owner); + + // FIXME: do not return host-task event + // Instead collect all host-tasks to a list + + return cleanup_host_task_event; + } +} + +template +sycl::event reduction_axis1_over_group_temps_contig_impl( + sycl::queue &exec_q, + std::size_t iter_nelems, // number of reductions (num. of rows in a + // matrix when reducing over rows) + std::size_t reduction_nelems, // size of each reduction (length of rows, + // i.e. 
number of columns) + const char *arg_cp, + char *res_cp, + ssize_t iter_arg_offset, + ssize_t iter_res_offset, + ssize_t reduction_arg_offset, + const std::vector &depends) +{ + const argTy *arg_tp = reinterpret_cast(arg_cp) + + iter_arg_offset + reduction_arg_offset; + resTy *res_tp = reinterpret_cast(res_cp) + iter_res_offset; + + static constexpr resTy identity_val = + su_ns::Identity::value; + + if (reduction_nelems == 0) { + sycl::event res_init_ev = exec_q.fill( + res_tp, resTy(identity_val), iter_nelems, depends); + + return res_init_ev; + } + + const sycl::device &d = exec_q.get_device(); + const auto &sg_sizes = d.get_info(); + std::size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes); + + if (reduction_nelems < wg) { + using InputIterIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer; + using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + InputIterIndexerT, NoOpIndexerT>; + using ReductionIndexerT = NoOpIndexerT; + + const InputOutputIterIndexerT in_out_iter_indexer{ + InputIterIndexerT{/* size */ iter_nelems, + /* step */ reduction_nelems}, + NoOpIndexerT{}}; + static constexpr ReductionIndexerT reduction_indexer{}; + + sycl::event comp_ev = + sequential_reduction( + exec_q, arg_tp, res_tp, identity_val, iter_nelems, + reduction_nelems, in_out_iter_indexer, reduction_indexer, + depends); + + return comp_ev; + } + + static constexpr std::size_t preferred_reductions_per_wi = 8; + // prevents running out of resources on CPU + std::size_t max_wg = reduction_detail::get_work_group_size(d); + + std::size_t reductions_per_wi(preferred_reductions_per_wi); + if (reduction_nelems <= preferred_reductions_per_wi * max_wg) { + // Perform reduction using one work-group per iteration, + // can output directly to res + + using InputIterIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer; + using NoOpIndexerT = 
dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + InputIterIndexerT, NoOpIndexerT>; + using ReductionIndexerT = NoOpIndexerT; + + const InputOutputIterIndexerT in_out_iter_indexer{ + InputIterIndexerT{/* size */ iter_nelems, + /* step */ reduction_nelems}, + NoOpIndexerT{}}; + static constexpr ReductionIndexerT reduction_indexer{}; + + if (iter_nelems == 1) { + // increase GPU occupancy + wg = max_wg; + } + reductions_per_wi = + std::max(1, (reduction_nelems + wg - 1) / wg); + + std::size_t reduction_groups = + (reduction_nelems + reductions_per_wi * wg - 1) / + (reductions_per_wi * wg); + assert(reduction_groups == 1); + + sycl::event comp_ev = submit_no_atomic_reduction< + argTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT, reduction_over_group_temps_krn>( + exec_q, arg_tp, res_tp, identity_val, wg, iter_nelems, + reduction_nelems, reductions_per_wi, reduction_groups, + in_out_iter_indexer, reduction_indexer, depends); + + return comp_ev; + } + else { + // more than one work-group is needed, requires a temporary + std::size_t reduction_groups = + (reduction_nelems + preferred_reductions_per_wi * wg - 1) / + (preferred_reductions_per_wi * wg); + assert(reduction_groups > 1); + + std::size_t second_iter_reduction_groups_ = + (reduction_groups + preferred_reductions_per_wi * wg - 1) / + (preferred_reductions_per_wi * wg); + + const std::size_t tmp_alloc_size = + iter_nelems * (reduction_groups + second_iter_reduction_groups_); + auto tmp_owner = dpctl::tensor::alloc_utils::smart_malloc_device( + tmp_alloc_size, exec_q); + resTy *partially_reduced_tmp = tmp_owner.get(); + resTy *partially_reduced_tmp2 = + partially_reduced_tmp + reduction_groups * iter_nelems; + + sycl::event first_reduction_ev; + { + using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using RowsIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer; + using 
InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + RowsIndexerT, NoOpIndexerT>; + using ReductionIndexerT = NoOpIndexerT; + + const RowsIndexerT rows_indexer{/* size */ iter_nelems, + /* step */ reduction_nelems}; + static constexpr NoOpIndexerT noop_tmp_indexer{}; + const InputOutputIterIndexerT in_out_iter_indexer{rows_indexer, + noop_tmp_indexer}; + static constexpr ReductionIndexerT reduction_indexer{}; + + first_reduction_ev = submit_no_atomic_reduction< + argTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT, reduction_over_group_temps_krn>( + exec_q, arg_tp, partially_reduced_tmp, identity_val, wg, + iter_nelems, reduction_nelems, preferred_reductions_per_wi, + reduction_groups, in_out_iter_indexer, reduction_indexer, + depends); + } + + std::size_t remaining_reduction_nelems = reduction_groups; + + resTy *temp_arg = partially_reduced_tmp; + resTy *temp2_arg = partially_reduced_tmp2; + sycl::event dependent_ev = first_reduction_ev; + + while (remaining_reduction_nelems > + preferred_reductions_per_wi * max_wg) { + std::size_t reduction_groups_ = + (remaining_reduction_nelems + preferred_reductions_per_wi * wg - + 1) / + (preferred_reductions_per_wi * wg); + assert(reduction_groups_ > 1); + + // keep reducing + using InputIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer; + using ResIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + InputIndexerT, ResIndexerT>; + using ReductionIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + + const InputIndexerT inp_indexer{/* size */ iter_nelems, + /* step */ reduction_groups_}; + static constexpr ResIndexerT res_iter_indexer{}; + + const InputOutputIterIndexerT in_out_iter_indexer{inp_indexer, + res_iter_indexer}; + static constexpr ReductionIndexerT reduction_indexer{}; + + sycl::event partial_reduction_ev = submit_no_atomic_reduction< + resTy, resTy, 
ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT, reduction_over_group_temps_krn>( + exec_q, temp_arg, temp2_arg, identity_val, wg, iter_nelems, + remaining_reduction_nelems, preferred_reductions_per_wi, + reduction_groups_, in_out_iter_indexer, reduction_indexer, + {dependent_ev}); + + remaining_reduction_nelems = reduction_groups_; + std::swap(temp_arg, temp2_arg); + dependent_ev = std::move(partial_reduction_ev); + } + + // final reduction to res + using InputIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer; + using ResIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + InputIndexerT, ResIndexerT>; + using ReductionIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + + const InputIndexerT inp_indexer{/* size */ iter_nelems, + /* step */ remaining_reduction_nelems}; + static constexpr ResIndexerT res_iter_indexer{}; + + const InputOutputIterIndexerT in_out_iter_indexer{inp_indexer, + res_iter_indexer}; + static constexpr ReductionIndexerT reduction_indexer{}; + + wg = max_wg; + reductions_per_wi = std::max( + 1, (remaining_reduction_nelems + wg - 1) / wg); + + reduction_groups = + (remaining_reduction_nelems + reductions_per_wi * wg - 1) / + (reductions_per_wi * wg); + assert(reduction_groups == 1); + + sycl::event final_reduction_ev = submit_no_atomic_reduction< + resTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT, reduction_over_group_temps_krn>( + exec_q, temp_arg, res_tp, identity_val, wg, iter_nelems, + remaining_reduction_nelems, reductions_per_wi, reduction_groups, + in_out_iter_indexer, reduction_indexer, {dependent_ev}); + + sycl::event cleanup_host_task_event = + dpctl::tensor::alloc_utils::async_smart_free( + exec_q, {final_reduction_ev}, tmp_owner); + + // FIXME: do not return host-task event + // Instead collect all host-tasks to a list + + return cleanup_host_task_event; + } +} + +template +sycl::event 
reduction_axis0_over_group_temps_contig_impl( + sycl::queue &exec_q, + std::size_t iter_nelems, // number of reductions (num. of rows in a + // matrix when reducing over rows) + std::size_t reduction_nelems, // size of each reduction (length of rows, + // i.e. number of columns) + const char *arg_cp, + char *res_cp, + ssize_t iter_arg_offset, + ssize_t iter_res_offset, + ssize_t reduction_arg_offset, + const std::vector &depends) +{ + const argTy *arg_tp = reinterpret_cast(arg_cp) + + iter_arg_offset + reduction_arg_offset; + resTy *res_tp = reinterpret_cast(res_cp) + iter_res_offset; + + static constexpr resTy identity_val = + su_ns::Identity::value; + + if (reduction_nelems == 0) { + sycl::event res_init_ev = exec_q.fill( + res_tp, resTy(identity_val), iter_nelems, depends); + + return res_init_ev; + } + + const sycl::device &d = exec_q.get_device(); + const auto &sg_sizes = d.get_info(); + std::size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes); + + if (reduction_nelems < wg) { + using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + NoOpIndexerT, NoOpIndexerT>; + using ReductionIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer; + + const InputOutputIterIndexerT in_out_iter_indexer{NoOpIndexerT{}, + NoOpIndexerT{}}; + const ReductionIndexerT reduction_indexer{/* size */ reduction_nelems, + /* step */ iter_nelems}; + + sycl::event comp_ev = + sequential_reduction( + exec_q, arg_tp, res_tp, identity_val, iter_nelems, + reduction_nelems, in_out_iter_indexer, reduction_indexer, + depends); + + return comp_ev; + } + + static constexpr std::size_t preferred_reductions_per_wi = 8; + // prevents running out of resources on CPU + std::size_t max_wg = reduction_detail::get_work_group_size(d); + + std::size_t reductions_per_wi(preferred_reductions_per_wi); + if (reduction_nelems <= preferred_reductions_per_wi * max_wg) { + // Perform reduction using 
one 1 work-group per iteration, + // can output directly to res + + using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using ColsIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + NoOpIndexerT, NoOpIndexerT>; + using ReductionIndexerT = ColsIndexerT; + + static constexpr NoOpIndexerT columns_indexer{}; + static constexpr NoOpIndexerT result_indexer{}; + const InputOutputIterIndexerT in_out_iter_indexer{columns_indexer, + result_indexer}; + const ReductionIndexerT reduction_indexer{/* size */ reduction_nelems, + /* step */ iter_nelems}; + + if (iter_nelems == 1) { + // increase GPU occupancy + wg = max_wg; + } + reductions_per_wi = + std::max(1, (reduction_nelems + wg - 1) / wg); + + std::size_t reduction_groups = + (reduction_nelems + reductions_per_wi * wg - 1) / + (reductions_per_wi * wg); + assert(reduction_groups == 1); + + sycl::event comp_ev = submit_no_atomic_reduction< + argTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT, reduction_over_group_temps_krn>( + exec_q, arg_tp, res_tp, identity_val, wg, iter_nelems, + reduction_nelems, reductions_per_wi, reduction_groups, + in_out_iter_indexer, reduction_indexer, depends); + + return comp_ev; + } + else { + // more than one work-groups is needed, requires a temporary + std::size_t reduction_groups = + (reduction_nelems + preferred_reductions_per_wi * wg - 1) / + (preferred_reductions_per_wi * wg); + assert(reduction_groups > 1); + + std::size_t second_iter_reduction_groups_ = + (reduction_groups + preferred_reductions_per_wi * wg - 1) / + (preferred_reductions_per_wi * wg); + + const std::size_t tmp_alloc_size = + iter_nelems * (reduction_groups + second_iter_reduction_groups_); + + auto tmp_owner = dpctl::tensor::alloc_utils::smart_malloc_device( + tmp_alloc_size, exec_q); + + resTy *partially_reduced_tmp = tmp_owner.get(); + resTy *partially_reduced_tmp2 = + 
partially_reduced_tmp + reduction_groups * iter_nelems; + + sycl::event first_reduction_ev; + { + using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using ColsIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + NoOpIndexerT, NoOpIndexerT>; + using ReductionIndexerT = ColsIndexerT; + + static constexpr NoOpIndexerT columns_indexer{}; + static constexpr NoOpIndexerT noop_tmp_indexer{}; + const InputOutputIterIndexerT in_out_iter_indexer{columns_indexer, + noop_tmp_indexer}; + const ReductionIndexerT reduction_indexer{ + /* size */ reduction_nelems, + /* step */ iter_nelems}; + + first_reduction_ev = submit_no_atomic_reduction< + argTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT, reduction_over_group_temps_krn>( + exec_q, arg_tp, partially_reduced_tmp, identity_val, wg, + iter_nelems, reduction_nelems, preferred_reductions_per_wi, + reduction_groups, in_out_iter_indexer, reduction_indexer, + depends); + } + + std::size_t remaining_reduction_nelems = reduction_groups; + + resTy *temp_arg = partially_reduced_tmp; + resTy *temp2_arg = partially_reduced_tmp2; + sycl::event dependent_ev = first_reduction_ev; + + while (remaining_reduction_nelems > + preferred_reductions_per_wi * max_wg) { + std::size_t reduction_groups_ = + (remaining_reduction_nelems + preferred_reductions_per_wi * wg - + 1) / + (preferred_reductions_per_wi * wg); + assert(reduction_groups_ > 1); + + // keep reducing + using InputIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer; + using ResIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + InputIndexerT, ResIndexerT>; + using ReductionIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + + const InputIndexerT inp_indexer{/* size */ iter_nelems, + /* step */ reduction_groups_}; + static constexpr ResIndexerT 
res_iter_indexer{}; + + const InputOutputIterIndexerT in_out_iter_indexer{inp_indexer, + res_iter_indexer}; + static constexpr ReductionIndexerT reduction_indexer{}; + + sycl::event partial_reduction_ev = submit_no_atomic_reduction< + resTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT, reduction_over_group_temps_krn>( + exec_q, temp_arg, temp2_arg, identity_val, wg, iter_nelems, + remaining_reduction_nelems, preferred_reductions_per_wi, + reduction_groups_, in_out_iter_indexer, reduction_indexer, + {dependent_ev}); + + remaining_reduction_nelems = reduction_groups_; + std::swap(temp_arg, temp2_arg); + dependent_ev = std::move(partial_reduction_ev); + } + + // final reduction to res + using InputIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer; + using ResIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + InputIndexerT, ResIndexerT>; + using ReductionIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + + const InputIndexerT inp_indexer{/* size */ iter_nelems, + /* step */ remaining_reduction_nelems}; + static constexpr ResIndexerT res_iter_indexer{}; + + const InputOutputIterIndexerT in_out_iter_indexer{inp_indexer, + res_iter_indexer}; + static constexpr ReductionIndexerT reduction_indexer{}; + + wg = max_wg; + reductions_per_wi = std::max( + 1, (remaining_reduction_nelems + wg - 1) / wg); + + reduction_groups = + (remaining_reduction_nelems + reductions_per_wi * wg - 1) / + (reductions_per_wi * wg); + assert(reduction_groups == 1); + + sycl::event final_reduction_ev = submit_no_atomic_reduction< + resTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT, reduction_over_group_temps_krn>( + exec_q, temp_arg, res_tp, identity_val, wg, iter_nelems, + remaining_reduction_nelems, reductions_per_wi, reduction_groups, + in_out_iter_indexer, reduction_indexer, {dependent_ev}); + + sycl::event cleanup_host_task_event = + 
dpctl::tensor::alloc_utils::async_smart_free( + exec_q, {final_reduction_ev}, tmp_owner); + + // FIXME: do not return host-task event + // Instead collect all host-tasks to a list + + return cleanup_host_task_event; + } +} + +// Argmax and Argmin + +/* Sequential search reduction */ + +template +struct SequentialSearchReduction +{ +private: + const argT *inp_ = nullptr; + outT *out_ = nullptr; + ReductionOp reduction_op_; + argT identity_; + IdxReductionOp idx_reduction_op_; + outT idx_identity_; + InputOutputIterIndexerT inp_out_iter_indexer_; + InputRedIndexerT inp_reduced_dims_indexer_; + std::size_t reduction_max_gid_ = 0; + +public: + SequentialSearchReduction( + const argT *inp, + outT *res, + const ReductionOp &reduction_op, + const argT &identity_val, + const IdxReductionOp &idx_reduction_op, + const outT &idx_identity_val, + const InputOutputIterIndexerT &arg_res_iter_indexer, + const InputRedIndexerT &arg_reduced_dims_indexer, + std::size_t reduction_size) + : inp_(inp), out_(res), reduction_op_(reduction_op), + identity_(identity_val), idx_reduction_op_(idx_reduction_op), + idx_identity_(idx_identity_val), + inp_out_iter_indexer_(arg_res_iter_indexer), + inp_reduced_dims_indexer_(arg_reduced_dims_indexer), + reduction_max_gid_(reduction_size) + { + } + + void operator()(sycl::id<1> id) const + { + + auto const &inp_out_iter_offsets_ = inp_out_iter_indexer_(id[0]); + const ssize_t &inp_iter_offset = + inp_out_iter_offsets_.get_first_offset(); + const ssize_t &out_iter_offset = + inp_out_iter_offsets_.get_second_offset(); + + argT red_val(identity_); + outT idx_val(idx_identity_); + for (std::size_t m = 0; m < reduction_max_gid_; ++m) { + const ssize_t inp_reduction_offset = inp_reduced_dims_indexer_(m); + const ssize_t inp_offset = inp_iter_offset + inp_reduction_offset; + + argT val = inp_[inp_offset]; + if (val == red_val) { + idx_val = idx_reduction_op_(idx_val, static_cast(m)); + } + else { + if constexpr (su_ns::IsMinimum::value) { + using 
dpctl::tensor::type_utils::is_complex; + if constexpr (is_complex::value) { + using dpctl::tensor::math_utils::less_complex; + // less_complex always returns false for NaNs, so check + if (less_complex(val, red_val) || + std::isnan(std::real(val)) || + std::isnan(std::imag(val))) + { + red_val = val; + idx_val = static_cast(m); + } + } + else if constexpr (std::is_floating_point_v || + std::is_same_v) + { + if (val < red_val || std::isnan(val)) { + red_val = val; + idx_val = static_cast(m); + } + } + else { + if (val < red_val) { + red_val = val; + idx_val = static_cast(m); + } + } + } + else if constexpr (su_ns::IsMaximum::value) { + using dpctl::tensor::type_utils::is_complex; + if constexpr (is_complex::value) { + using dpctl::tensor::math_utils::greater_complex; + if (greater_complex(val, red_val) || + std::isnan(std::real(val)) || + std::isnan(std::imag(val))) + { + red_val = val; + idx_val = static_cast(m); + } + } + else if constexpr (std::is_floating_point_v || + std::is_same_v) + { + if (val > red_val || std::isnan(val)) { + red_val = val; + idx_val = static_cast(m); + } + } + else { + if (val > red_val) { + red_val = val; + idx_val = static_cast(m); + } + } + } + } + } + out_[out_iter_offset] = idx_val; + } +}; + +/* = Search reduction using reduce_over_group*/ + +template +struct SearchReduction +{ +private: + const argT *inp_ = nullptr; + argT *vals_ = nullptr; + const outT *inds_ = nullptr; + outT *out_ = nullptr; + ReductionOp reduction_op_; + argT identity_; + IdxReductionOp idx_reduction_op_; + outT idx_identity_; + InputOutputIterIndexerT inp_out_iter_indexer_; + InputRedIndexerT inp_reduced_dims_indexer_; + std::size_t reduction_max_gid_ = 0; + std::size_t iter_gws_ = 1; + std::size_t reductions_per_wi = 16; + +public: + SearchReduction(const argT *data, + argT *vals, + const outT *inds, + outT *res, + const ReductionOp &reduction_op, + const argT &identity_val, + const IdxReductionOp &idx_reduction_op, + const outT &idx_identity_val, + const 
InputOutputIterIndexerT &arg_res_iter_indexer, + const InputRedIndexerT &arg_reduced_dims_indexer, + std::size_t reduction_size, + std::size_t iteration_size, + std::size_t reduction_size_per_wi) + : inp_(data), vals_(vals), inds_(inds), out_(res), + reduction_op_(reduction_op), identity_(identity_val), + idx_reduction_op_(idx_reduction_op), idx_identity_(idx_identity_val), + inp_out_iter_indexer_(arg_res_iter_indexer), + inp_reduced_dims_indexer_(arg_reduced_dims_indexer), + reduction_max_gid_(reduction_size), iter_gws_(iteration_size), + reductions_per_wi(reduction_size_per_wi) + { + } + + void operator()(sycl::nd_item<1> it) const + { + const std::size_t reduction_lid = it.get_local_id(0); + const std::size_t wg = + it.get_local_range(0); // 0 <= reduction_lid < wg + + const std::size_t iter_gid = it.get_group(0) % iter_gws_; + const std::size_t reduction_batch_id = it.get_group(0) / iter_gws_; + const std::size_t n_reduction_groups = + it.get_group_range(0) / iter_gws_; + + // work-items operates over input with indices + // inp_data_id = reduction_batch_id * wg * reductions_per_wi + m * wg + // + reduction_lid + // for 0 <= m < reductions_per_wi + + const auto &inp_out_iter_offsets_ = inp_out_iter_indexer_(iter_gid); + const auto &inp_iter_offset = inp_out_iter_offsets_.get_first_offset(); + const auto &out_iter_offset = inp_out_iter_offsets_.get_second_offset(); + + argT local_red_val(identity_); + outT local_idx(idx_identity_); + std::size_t arg_reduce_gid0 = + reduction_lid + reduction_batch_id * wg * reductions_per_wi; + for (std::size_t m = 0; m < reductions_per_wi; ++m) { + std::size_t arg_reduce_gid = arg_reduce_gid0 + m * wg; + + if (arg_reduce_gid < reduction_max_gid_) { + auto inp_reduction_offset = + inp_reduced_dims_indexer_(arg_reduce_gid); + auto inp_offset = inp_iter_offset + inp_reduction_offset; + + argT val = inp_[inp_offset]; + if (val == local_red_val) { + if constexpr (!First) { + local_idx = + idx_reduction_op_(local_idx, 
inds_[inp_offset]); + } + else { + local_idx = idx_reduction_op_( + local_idx, static_cast(arg_reduce_gid)); + } + } + else { + if constexpr (su_ns::IsMinimum::value) { + if (val < local_red_val) { + local_red_val = val; + if constexpr (!First) { + local_idx = inds_[inp_offset]; + } + else { + local_idx = static_cast(arg_reduce_gid); + } + } + } + else if constexpr (su_ns::IsMaximum::value) { + if (val > local_red_val) { + local_red_val = val; + if constexpr (!First) { + local_idx = inds_[inp_offset]; + } + else { + local_idx = static_cast(arg_reduce_gid); + } + } + } + } + } + } + + auto work_group = it.get_group(); + // This only works if reduction_op_ is from small set of operators + argT red_val_over_wg = sycl::reduce_over_group( + work_group, local_red_val, identity_, reduction_op_); + + if constexpr (std::is_integral_v) { + local_idx = + (red_val_over_wg == local_red_val) ? local_idx : idx_identity_; + } + else { + local_idx = + (red_val_over_wg == local_red_val || + std::isnan(red_val_over_wg) || std::isnan(local_red_val)) + ? 
local_idx + : idx_identity_; + } + outT idx_over_wg = sycl::reduce_over_group( + work_group, local_idx, idx_identity_, idx_reduction_op_); + + if (work_group.leader()) { + // each group writes to a different memory location + if constexpr (!Last) { + // if not the final reduction, write value corresponding to + // an index to a temporary + vals_[out_iter_offset * n_reduction_groups + + reduction_batch_id] = red_val_over_wg; + } + out_[out_iter_offset * n_reduction_groups + reduction_batch_id] = + idx_over_wg; + } + } +}; + +/* = Search reduction using custom_reduce_over_group*/ + +template +struct CustomSearchReduction +{ +private: + const argT *inp_ = nullptr; + argT *vals_ = nullptr; + const outT *inds_ = nullptr; + outT *out_ = nullptr; + ReductionOp reduction_op_; + argT identity_; + IdxReductionOp idx_reduction_op_; + outT idx_identity_; + InputOutputIterIndexerT inp_out_iter_indexer_; + InputRedIndexerT inp_reduced_dims_indexer_; + SlmT local_mem_; + std::size_t reduction_max_gid_ = 0; + std::size_t iter_gws_ = 1; + std::size_t reductions_per_wi = 16; + +public: + CustomSearchReduction(const argT *data, + argT *vals, + outT *inds, + outT *res, + const ReductionOp &reduction_op, + const argT &identity_val, + const IdxReductionOp &idx_reduction_op, + const outT &idx_identity_val, + const InputOutputIterIndexerT &arg_res_iter_indexer, + const InputRedIndexerT &arg_reduced_dims_indexer, + SlmT local_mem, + std::size_t reduction_size, + std::size_t iteration_size, + std::size_t reduction_size_per_wi) + : inp_(data), vals_(vals), inds_(inds), out_(res), + reduction_op_(reduction_op), identity_(identity_val), + idx_reduction_op_(idx_reduction_op), idx_identity_(idx_identity_val), + inp_out_iter_indexer_(arg_res_iter_indexer), + inp_reduced_dims_indexer_(arg_reduced_dims_indexer), + local_mem_(local_mem), reduction_max_gid_(reduction_size), + iter_gws_(iteration_size), reductions_per_wi(reduction_size_per_wi) + { + } + + void operator()(sycl::nd_item<1> it) const + { 
+ const std::size_t reduction_lid = it.get_local_id(0); + const std::size_t wg = + it.get_local_range(0); // 0 <= reduction_lid < wg + + const std::size_t iter_gid = it.get_group(0) % iter_gws_; + const std::size_t reduction_batch_id = it.get_group(0) / iter_gws_; + const std::size_t n_reduction_groups = + it.get_group_range(0) / iter_gws_; + + // work-items operates over input with indices + // inp_data_id = reduction_batch_id * wg * reductions_per_wi + m * wg + // + reduction_lid + // for 0 <= m < reductions_per_wi + + const auto &inp_out_iter_offsets_ = inp_out_iter_indexer_(iter_gid); + const auto &inp_iter_offset = inp_out_iter_offsets_.get_first_offset(); + const auto &out_iter_offset = inp_out_iter_offsets_.get_second_offset(); + + argT local_red_val(identity_); + outT local_idx(idx_identity_); + std::size_t arg_reduce_gid0 = + reduction_lid + reduction_batch_id * wg * reductions_per_wi; + for (std::size_t m = 0; m < reductions_per_wi; ++m) { + std::size_t arg_reduce_gid = arg_reduce_gid0 + m * wg; + + if (arg_reduce_gid < reduction_max_gid_) { + auto inp_reduction_offset = + inp_reduced_dims_indexer_(arg_reduce_gid); + auto inp_offset = inp_iter_offset + inp_reduction_offset; + + argT val = inp_[inp_offset]; + if (val == local_red_val) { + if constexpr (!First) { + local_idx = + idx_reduction_op_(local_idx, inds_[inp_offset]); + } + else { + local_idx = idx_reduction_op_( + local_idx, static_cast(arg_reduce_gid)); + } + } + else { + if constexpr (su_ns::IsMinimum::value) { + using dpctl::tensor::type_utils::is_complex; + if constexpr (is_complex::value) { + using dpctl::tensor::math_utils::less_complex; + // less_complex always returns false for NaNs, so + // check + if (less_complex(val, local_red_val) || + std::isnan(std::real(val)) || + std::isnan(std::imag(val))) + { + local_red_val = val; + if constexpr (!First) { + local_idx = inds_[inp_offset]; + } + else { + local_idx = + static_cast(arg_reduce_gid); + } + } + } + else if constexpr 
(std::is_floating_point_v || + std::is_same_v) + { + if (val < local_red_val || std::isnan(val)) { + local_red_val = val; + if constexpr (!First) { + local_idx = inds_[inp_offset]; + } + else { + local_idx = + static_cast(arg_reduce_gid); + } + } + } + else { + if (val < local_red_val) { + local_red_val = val; + if constexpr (!First) { + local_idx = inds_[inp_offset]; + } + else { + local_idx = + static_cast(arg_reduce_gid); + } + } + } + } + else if constexpr (su_ns::IsMaximum::value) { + using dpctl::tensor::type_utils::is_complex; + if constexpr (is_complex::value) { + using dpctl::tensor::math_utils::greater_complex; + if (greater_complex(val, local_red_val) || + std::isnan(std::real(val)) || + std::isnan(std::imag(val))) + { + local_red_val = val; + if constexpr (!First) { + local_idx = inds_[inp_offset]; + } + else { + local_idx = + static_cast(arg_reduce_gid); + } + } + } + else if constexpr (std::is_floating_point_v || + std::is_same_v) + { + if (val > local_red_val || std::isnan(val)) { + local_red_val = val; + if constexpr (!First) { + local_idx = inds_[inp_offset]; + } + else { + local_idx = + static_cast(arg_reduce_gid); + } + } + } + else { + if (val > local_red_val) { + local_red_val = val; + if constexpr (!First) { + local_idx = inds_[inp_offset]; + } + else { + local_idx = + static_cast(arg_reduce_gid); + } + } + } + } + } + } + } + + auto work_group = it.get_group(); + // This only works if reduction_op_ is from small set of operators + argT red_val_over_wg = su_ns::custom_reduce_over_group( + work_group, local_mem_, local_red_val, reduction_op_); + + using dpctl::tensor::type_utils::is_complex; + if constexpr (is_complex::value) { + // equality does not hold for NaNs, so check here + local_idx = (red_val_over_wg == local_red_val || + std::isnan(std::real(local_red_val)) || + std::isnan(std::imag(local_red_val))) + ? 
local_idx + : idx_identity_; + } + else if constexpr (std::is_floating_point_v || + std::is_same_v) + { + // equality does not hold for NaNs, so check here + local_idx = + (red_val_over_wg == local_red_val || std::isnan(local_red_val)) + ? local_idx + : idx_identity_; + } + else { + local_idx = + red_val_over_wg == local_red_val ? local_idx : idx_identity_; + } + outT idx_over_wg = sycl::reduce_over_group( + work_group, local_idx, idx_identity_, idx_reduction_op_); + if (work_group.leader()) { + // each group writes to a different memory location + if constexpr (!Last) { + // if not the final reduction, write value corresponding to + // an index to a temporary + vals_[out_iter_offset * n_reduction_groups + + reduction_batch_id] = red_val_over_wg; + } + out_[out_iter_offset * n_reduction_groups + reduction_batch_id] = + idx_over_wg; + } + } +}; + +typedef sycl::event (*search_strided_impl_fn_ptr)( + sycl::queue &, + std::size_t, + std::size_t, + const char *, + char *, + int, + const ssize_t *, + ssize_t, + ssize_t, + int, + const ssize_t *, + ssize_t, + const std::vector &); + +template +class search_seq_strided_krn; + +template +class search_seq_contig_krn; + +template +class search_over_group_krn; + +template +class custom_search_over_group_krn; + +template +class search_empty_krn; + +template +sycl::event + submit_search_reduction(sycl::queue &exec_q, + const argTy *arg, + argTy *arg_tmp, + resTy *res_tmp, + resTy *res, + argTy identity_val, + resTy idx_identity_val, + std::size_t wg, + std::size_t iter_nelems, + std::size_t reduction_nelems, + std::size_t reductions_per_wi, + std::size_t reduction_groups, + const InputOutputIterIndexerT &in_out_iter_indexer, + const ReductionIndexerT &reduction_indexer, + const std::vector &depends) +{ + sycl::event red_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + auto globalRange = sycl::range<1>{iter_nelems * reduction_groups * wg}; + auto localRange = sycl::range<1>{wg}; + auto ndRange = 
sycl::nd_range<1>(globalRange, localRange); + + if constexpr (can_use_reduce_over_group::value) { + using KernelName = + class search_over_group_krn; + cgh.parallel_for( + ndRange, SearchReduction( + arg, arg_tmp, res_tmp, res, ReductionOpT(), + identity_val, IndexOpT(), idx_identity_val, + in_out_iter_indexer, reduction_indexer, + reduction_nelems, iter_nelems, reductions_per_wi)); + } + else { + using SlmT = sycl::local_accessor; + SlmT local_memory = SlmT(localRange, cgh); + using KernelName = class custom_search_over_group_krn< + argTy, resTy, ReductionOpT, IndexOpT, InputOutputIterIndexerT, + ReductionIndexerT, SlmT, First, Last>; + cgh.parallel_for( + ndRange, + CustomSearchReduction( + arg, arg_tmp, res_tmp, res, ReductionOpT(), identity_val, + IndexOpT(), idx_identity_val, in_out_iter_indexer, + reduction_indexer, local_memory, reduction_nelems, + iter_nelems, reductions_per_wi)); + } + }); + return red_ev; +} + +template +sycl::event search_over_group_temps_strided_impl( + sycl::queue &exec_q, + std::size_t iter_nelems, // number of reductions (num. of rows in a + // matrix when reducing over rows) + std::size_t reduction_nelems, // size of each reduction (length of rows, + // i.e. 
number of columns) + const char *arg_cp, + char *res_cp, + int iter_nd, + const ssize_t *iter_shape_and_strides, + ssize_t iter_arg_offset, + ssize_t iter_res_offset, + int red_nd, + const ssize_t *reduction_shape_stride, + ssize_t reduction_arg_offset, + const std::vector &depends) +{ + const argTy *arg_tp = reinterpret_cast(arg_cp); + resTy *res_tp = reinterpret_cast(res_cp); + + static constexpr argTy identity_val = + su_ns::Identity::value; + static constexpr resTy idx_identity_val = + su_ns::Identity::value; + + if (reduction_nelems == 0) { + sycl::event res_init_ev = exec_q.submit([&](sycl::handler &cgh) { + using IndexerT = + dpctl::tensor::offset_utils::UnpackedStridedIndexer; + + const ssize_t *const &res_shape = iter_shape_and_strides; + const ssize_t *const &res_strides = + iter_shape_and_strides + 2 * iter_nd; + const IndexerT res_indexer(iter_nd, iter_res_offset, res_shape, + res_strides); + using InitKernelName = + class search_empty_krn; + cgh.depends_on(depends); + + cgh.parallel_for( + sycl::range<1>(iter_nelems), [=](sycl::id<1> id) { + auto res_offset = res_indexer(id[0]); + res_tp[res_offset] = idx_identity_val; + }); + }); + + return res_init_ev; + } + + const sycl::device &d = exec_q.get_device(); + const auto &sg_sizes = d.get_info(); + std::size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes); + + if (reduction_nelems < wg) { + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_StridedIndexer; + using ReductionIndexerT = dpctl::tensor::offset_utils::StridedIndexer; + + const InputOutputIterIndexerT in_out_iter_indexer{ + iter_nd, iter_arg_offset, iter_res_offset, iter_shape_and_strides}; + const ReductionIndexerT reduction_indexer{red_nd, reduction_arg_offset, + reduction_shape_stride}; + + sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + cgh.parallel_for>( + sycl::range<1>(iter_nelems), + SequentialSearchReduction( + arg_tp, res_tp, ReductionOpT(), identity_val, 
IndexOpT(), + idx_identity_val, in_out_iter_indexer, reduction_indexer, + reduction_nelems)); + }); + + return comp_ev; + } + + static constexpr std::size_t preferred_reductions_per_wi = 4; + // prevents running out of resources on CPU + std::size_t max_wg = reduction_detail::get_work_group_size(d); + + std::size_t reductions_per_wi(preferred_reductions_per_wi); + if (reduction_nelems <= preferred_reductions_per_wi * max_wg) { + // Perform reduction using one 1 work-group per iteration, + // can output directly to res + + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_StridedIndexer; + using ReductionIndexerT = dpctl::tensor::offset_utils::StridedIndexer; + + const InputOutputIterIndexerT in_out_iter_indexer{ + iter_nd, iter_arg_offset, iter_res_offset, iter_shape_and_strides}; + const ReductionIndexerT reduction_indexer{red_nd, reduction_arg_offset, + reduction_shape_stride}; + + if (iter_nelems == 1) { + // increase GPU occupancy + wg = max_wg; + } + reductions_per_wi = + std::max(1, (reduction_nelems + wg - 1) / wg); + + std::size_t reduction_groups = + (reduction_nelems + reductions_per_wi * wg - 1) / + (reductions_per_wi * wg); + assert(reduction_groups == 1); + + sycl::event comp_ev = + submit_search_reduction( + exec_q, arg_tp, nullptr, nullptr, res_tp, identity_val, + idx_identity_val, wg, iter_nelems, reduction_nelems, + reductions_per_wi, reduction_groups, in_out_iter_indexer, + reduction_indexer, depends); + + return comp_ev; + } + else { + // more than one work-groups is needed, requires a temporary + std::size_t reduction_groups = + (reduction_nelems + preferred_reductions_per_wi * wg - 1) / + (preferred_reductions_per_wi * wg); + assert(reduction_groups > 1); + + std::size_t second_iter_reduction_groups_ = + (reduction_groups + preferred_reductions_per_wi * wg - 1) / + (preferred_reductions_per_wi * wg); + + const std::size_t tmp_alloc_size = + iter_nelems * (reduction_groups + second_iter_reduction_groups_); + auto 
tmp_owner = dpctl::tensor::alloc_utils::smart_malloc_device( + tmp_alloc_size, exec_q); + + resTy *partially_reduced_tmp = tmp_owner.get(); + resTy *partially_reduced_tmp2 = + partially_reduced_tmp + reduction_groups * iter_nelems; + + auto val_tmp_owner = + dpctl::tensor::alloc_utils::smart_malloc_device( + tmp_alloc_size, exec_q); + + argTy *partially_reduced_vals_tmp = val_tmp_owner.get(); + argTy *partially_reduced_vals_tmp2 = + partially_reduced_vals_tmp + reduction_groups * iter_nelems; + + sycl::event first_reduction_ev; + { + using InputIndexerT = dpctl::tensor::offset_utils::StridedIndexer; + using ResIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + InputIndexerT, ResIndexerT>; + using ReductionIndexerT = + dpctl::tensor::offset_utils::StridedIndexer; + + // Only 2*iter_nd entries describing shape and strides of iterated + // dimensions of input array from iter_shape_and_strides are going + // to be accessed by inp_indexer + const InputIndexerT inp_indexer(iter_nd, iter_arg_offset, + iter_shape_and_strides); + static constexpr ResIndexerT noop_tmp_indexer{}; + + const InputOutputIterIndexerT in_out_iter_indexer{inp_indexer, + noop_tmp_indexer}; + const ReductionIndexerT reduction_indexer{ + red_nd, reduction_arg_offset, reduction_shape_stride}; + + first_reduction_ev = + submit_search_reduction( + exec_q, arg_tp, partially_reduced_vals_tmp, nullptr, + partially_reduced_tmp, identity_val, idx_identity_val, wg, + iter_nelems, reduction_nelems, reductions_per_wi, + reduction_groups, in_out_iter_indexer, reduction_indexer, + depends); + } + + std::size_t remaining_reduction_nelems = reduction_groups; + + resTy *temp_arg = partially_reduced_tmp; + resTy *temp2_arg = partially_reduced_tmp2; + + argTy *vals_temp_arg = partially_reduced_vals_tmp; + argTy *vals_temp2_arg = partially_reduced_vals_tmp2; + + sycl::event dependent_ev = first_reduction_ev; + + while 
(remaining_reduction_nelems > + preferred_reductions_per_wi * max_wg) { + std::size_t reduction_groups_ = + (remaining_reduction_nelems + preferred_reductions_per_wi * wg - + 1) / + (preferred_reductions_per_wi * wg); + assert(reduction_groups_ > 1); + + // keep reducing + using InputIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer; + using ResIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + InputIndexerT, ResIndexerT>; + using ReductionIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + + const InputIndexerT inp_indexer{/* size */ iter_nelems, + /* step */ reduction_groups_}; + static constexpr ResIndexerT res_iter_indexer{}; + + const InputOutputIterIndexerT in_out_iter_indexer{inp_indexer, + res_iter_indexer}; + static constexpr ReductionIndexerT reduction_indexer{}; + + sycl::event partial_reduction_ev = + submit_search_reduction( + exec_q, vals_temp_arg, vals_temp2_arg, temp_arg, temp2_arg, + identity_val, idx_identity_val, wg, iter_nelems, + remaining_reduction_nelems, preferred_reductions_per_wi, + reduction_groups_, in_out_iter_indexer, reduction_indexer, + {dependent_ev}); + + remaining_reduction_nelems = reduction_groups_; + std::swap(temp_arg, temp2_arg); + std::swap(vals_temp_arg, vals_temp2_arg); + dependent_ev = partial_reduction_ev; + } + + // final reduction to res + using InputIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer; + using ResIndexerT = dpctl::tensor::offset_utils::UnpackedStridedIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + InputIndexerT, ResIndexerT>; + using ReductionIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + + const InputIndexerT inp_indexer{/* size */ iter_nelems, + /* step */ remaining_reduction_nelems}; + const ResIndexerT res_iter_indexer{ + iter_nd, iter_res_offset, + /* shape */ iter_shape_and_strides, + /* strides */ iter_shape_and_strides + 
2 * iter_nd}; + + const InputOutputIterIndexerT in_out_iter_indexer{inp_indexer, + res_iter_indexer}; + static constexpr ReductionIndexerT reduction_indexer{}; + + wg = max_wg; + reductions_per_wi = std::max( + 1, (remaining_reduction_nelems + wg - 1) / wg); + + reduction_groups = + (remaining_reduction_nelems + reductions_per_wi * wg - 1) / + (reductions_per_wi * wg); + assert(reduction_groups == 1); + + sycl::event final_reduction_ev = + submit_search_reduction( + exec_q, vals_temp_arg, nullptr, temp_arg, res_tp, identity_val, + idx_identity_val, wg, iter_nelems, remaining_reduction_nelems, + reductions_per_wi, reduction_groups, in_out_iter_indexer, + reduction_indexer, {dependent_ev}); + + sycl::event cleanup_host_task_event = + dpctl::tensor::alloc_utils::async_smart_free( + exec_q, {final_reduction_ev}, tmp_owner, val_tmp_owner); + + // FIXME: do not return host-task event + // Instead collect all host-tasks to a list + + return cleanup_host_task_event; + } +} + +typedef sycl::event (*search_contig_impl_fn_ptr)( + sycl::queue &, + std::size_t, + std::size_t, + const char *, + char *, + ssize_t, + ssize_t, + ssize_t, + const std::vector &); + +template +sycl::event search_axis1_over_group_temps_contig_impl( + sycl::queue &exec_q, + std::size_t iter_nelems, // number of reductions (num. of rows in a + // matrix when reducing over rows) + std::size_t reduction_nelems, // size of each reduction (length of rows, + // i.e. 
number of columns) + const char *arg_cp, + char *res_cp, + ssize_t iter_arg_offset, + ssize_t iter_res_offset, + ssize_t reduction_arg_offset, + const std::vector &depends) +{ + const argTy *arg_tp = reinterpret_cast(arg_cp) + + iter_arg_offset + reduction_arg_offset; + resTy *res_tp = reinterpret_cast(res_cp) + iter_res_offset; + + static constexpr argTy identity_val = + su_ns::Identity::value; + static constexpr resTy idx_identity_val = + su_ns::Identity::value; + + if (reduction_nelems == 0) { + sycl::event res_init_ev = exec_q.fill( + res_tp, resTy(idx_identity_val), iter_nelems, depends); + + return res_init_ev; + } + + const sycl::device &d = exec_q.get_device(); + const auto &sg_sizes = d.get_info(); + std::size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes); + + if (reduction_nelems < wg) { + using InputIterIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer; + using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + InputIterIndexerT, NoOpIndexerT>; + using ReductionIndexerT = NoOpIndexerT; + + const InputOutputIterIndexerT in_out_iter_indexer{ + InputIterIndexerT{/* size */ iter_nelems, + /* step */ reduction_nelems}, + NoOpIndexerT{}}; + static constexpr ReductionIndexerT reduction_indexer{}; + + sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + cgh.parallel_for>( + sycl::range<1>(iter_nelems), + SequentialSearchReduction( + arg_tp, res_tp, ReductionOpT(), identity_val, IndexOpT(), + idx_identity_val, in_out_iter_indexer, reduction_indexer, + reduction_nelems)); + }); + + return comp_ev; + } + + static constexpr std::size_t preferred_reductions_per_wi = 8; + // prevents running out of resources on CPU + std::size_t max_wg = reduction_detail::get_work_group_size(d); + + std::size_t reductions_per_wi(preferred_reductions_per_wi); + if (reduction_nelems <= preferred_reductions_per_wi * max_wg) { + 
// Perform reduction using one 1 work-group per iteration, + // can output directly to res + using InputIterIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer; + using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + InputIterIndexerT, NoOpIndexerT>; + using ReductionIndexerT = NoOpIndexerT; + + const InputOutputIterIndexerT in_out_iter_indexer{ + InputIterIndexerT{/* size */ iter_nelems, + /* step */ reduction_nelems}, + NoOpIndexerT{}}; + static constexpr ReductionIndexerT reduction_indexer{}; + + if (iter_nelems == 1) { + // increase GPU occupancy + wg = max_wg; + } + reductions_per_wi = + std::max(1, (reduction_nelems + wg - 1) / wg); + + std::size_t reduction_groups = + (reduction_nelems + reductions_per_wi * wg - 1) / + (reductions_per_wi * wg); + assert(reduction_groups == 1); + + sycl::event comp_ev = + submit_search_reduction( + exec_q, arg_tp, nullptr, nullptr, res_tp, identity_val, + idx_identity_val, wg, iter_nelems, reduction_nelems, + reductions_per_wi, reduction_groups, in_out_iter_indexer, + reduction_indexer, depends); + + return comp_ev; + } + else { + // more than one work-groups is needed, requires a temporary + std::size_t reduction_groups = + (reduction_nelems + preferred_reductions_per_wi * wg - 1) / + (preferred_reductions_per_wi * wg); + assert(reduction_groups > 1); + + std::size_t second_iter_reduction_groups_ = + (reduction_groups + preferred_reductions_per_wi * wg - 1) / + (preferred_reductions_per_wi * wg); + + const std::size_t tmp_alloc_size = + iter_nelems * (reduction_groups + second_iter_reduction_groups_); + auto tmp_owner = dpctl::tensor::alloc_utils::smart_malloc_device( + tmp_alloc_size, exec_q); + resTy *partially_reduced_tmp = tmp_owner.get(); + resTy *partially_reduced_tmp2 = + partially_reduced_tmp + reduction_groups * iter_nelems; + + auto val_tmp_owner = + dpctl::tensor::alloc_utils::smart_malloc_device( + 
tmp_alloc_size, exec_q); + argTy *partially_reduced_vals_tmp = val_tmp_owner.get(); + argTy *partially_reduced_vals_tmp2 = + partially_reduced_vals_tmp + reduction_groups * iter_nelems; + + sycl::event first_reduction_ev; + { + using InputIterIndexerT = + dpctl::tensor::offset_utils::Strided1DIndexer; + using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + InputIterIndexerT, NoOpIndexerT>; + using ReductionIndexerT = NoOpIndexerT; + + const InputOutputIterIndexerT in_out_iter_indexer{ + InputIterIndexerT{/* size */ iter_nelems, + /* step */ reduction_nelems}, + NoOpIndexerT{}}; + static constexpr ReductionIndexerT reduction_indexer{}; + + first_reduction_ev = + submit_search_reduction( + exec_q, arg_tp, partially_reduced_vals_tmp, nullptr, + partially_reduced_tmp, identity_val, idx_identity_val, wg, + iter_nelems, reduction_nelems, preferred_reductions_per_wi, + reduction_groups, in_out_iter_indexer, reduction_indexer, + depends); + } + + std::size_t remaining_reduction_nelems = reduction_groups; + + resTy *temp_arg = partially_reduced_tmp; + resTy *temp2_arg = partially_reduced_tmp2; + + argTy *vals_temp_arg = partially_reduced_vals_tmp; + argTy *vals_temp2_arg = partially_reduced_vals_tmp2; + + sycl::event dependent_ev = first_reduction_ev; + + while (remaining_reduction_nelems > + preferred_reductions_per_wi * max_wg) { + std::size_t reduction_groups_ = + (remaining_reduction_nelems + preferred_reductions_per_wi * wg - + 1) / + (preferred_reductions_per_wi * wg); + assert(reduction_groups_ > 1); + + // keep reducing + using InputIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer; + using ResIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + InputIndexerT, ResIndexerT>; + using ReductionIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + + const 
InputIndexerT inp_indexer{/* size */ iter_nelems, + /* step */ reduction_groups_}; + static constexpr ResIndexerT res_iter_indexer{}; + + const InputOutputIterIndexerT in_out_iter_indexer{inp_indexer, + res_iter_indexer}; + static constexpr ReductionIndexerT reduction_indexer{}; + + sycl::event partial_reduction_ev = + submit_search_reduction( + exec_q, vals_temp_arg, vals_temp2_arg, temp_arg, temp2_arg, + identity_val, idx_identity_val, wg, iter_nelems, + remaining_reduction_nelems, preferred_reductions_per_wi, + reduction_groups_, in_out_iter_indexer, reduction_indexer, + {dependent_ev}); + + remaining_reduction_nelems = reduction_groups_; + std::swap(temp_arg, temp2_arg); + std::swap(vals_temp_arg, vals_temp2_arg); + dependent_ev = partial_reduction_ev; + } + + // final reduction to res + using InputIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer; + using ResIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + InputIndexerT, ResIndexerT>; + using ReductionIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + + const InputIndexerT inp_indexer{/* size */ iter_nelems, + /* step */ remaining_reduction_nelems}; + static constexpr ResIndexerT res_iter_indexer{}; + + const InputOutputIterIndexerT in_out_iter_indexer{inp_indexer, + res_iter_indexer}; + static constexpr ReductionIndexerT reduction_indexer{}; + + wg = max_wg; + reductions_per_wi = std::max( + 1, (remaining_reduction_nelems + wg - 1) / wg); + + reduction_groups = + (remaining_reduction_nelems + reductions_per_wi * wg - 1) / + (reductions_per_wi * wg); + assert(reduction_groups == 1); + + sycl::event final_reduction_ev = + submit_search_reduction( + exec_q, vals_temp_arg, nullptr, temp_arg, res_tp, identity_val, + idx_identity_val, wg, iter_nelems, remaining_reduction_nelems, + reductions_per_wi, reduction_groups, in_out_iter_indexer, + reduction_indexer, {dependent_ev}); + + sycl::event 
cleanup_host_task_event =
+            dpctl::tensor::alloc_utils::async_smart_free(
+                exec_q, {final_reduction_ev}, tmp_owner, val_tmp_owner);
+
+        // FIXME: do not return host-task event
+        //   Instead collect all host-tasks to a list
+
+        return cleanup_host_task_event;
+    }
+}
+
+// Search reduction (a value reduction by ReductionOpT paired with an index
+// reduction by IndexOpT) over axis 0 of contiguous input.
+//
+// Addressing, as set up by the indexers below: iteration `i` starts at
+// element `i` (NoOpIndexer) and walks its reduction with step `iter_nelems`.
+//
+// Three strategies are chosen from `reduction_nelems`:
+//   * fewer elements than the work-group size `wg`: one work-item per
+//     iteration runs SequentialSearchReduction;
+//   * at most `preferred_reductions_per_wi * max_wg` elements: a single
+//     work-group per iteration writes directly into `res`;
+//   * otherwise: partial results go to device temporaries that are reduced
+//     repeatedly until one work-group per iteration can finish the job.
+//
+// When `reduction_nelems == 0` the result is filled with the index
+// identity value.
+//
+// Returns the event of the last submitted task. On the multi-pass path that
+// is the event returned by async_smart_free for the temporaries (see the
+// FIXME at the end).
+template
+sycl::event search_axis0_over_group_temps_contig_impl(
+    sycl::queue &exec_q,
+    std::size_t iter_nelems, // number of independent reductions; also the
+                             // step between consecutive elements of one
+                             // reduction (see reduction_indexer below)
+    std::size_t reduction_nelems, // number of elements folded into each
+                                  // result
+    const char *arg_cp,
+    char *res_cp,
+    ssize_t iter_arg_offset,
+    ssize_t iter_res_offset,
+    ssize_t reduction_arg_offset,
+    const std::vector &depends)
+{
+    // Offsets are applied in units of elements, after the cast.
+    const argTy *arg_tp = reinterpret_cast(arg_cp) +
+                          iter_arg_offset + reduction_arg_offset;
+    resTy *res_tp = reinterpret_cast(res_cp) + iter_res_offset;
+
+    static constexpr argTy identity_val =
+        su_ns::Identity::value;
+    static constexpr resTy idx_identity_val =
+        su_ns::Identity::value;
+
+    // Empty reduction: every result is the index identity.
+    if (reduction_nelems == 0) {
+        sycl::event res_init_ev = exec_q.fill(
+            res_tp, resTy(idx_identity_val), iter_nelems, depends);
+
+        return res_init_ev;
+    }
+
+    const sycl::device &d = exec_q.get_device();
+    const auto &sg_sizes = d.get_info();
+    std::size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes);
+
+    // Path 1: short reductions, one work-item per iteration.
+    if (reduction_nelems < wg) {
+        using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer;
+        using InputOutputIterIndexerT =
+            dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer<
+                NoOpIndexerT, NoOpIndexerT>;
+        using ReductionIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer;
+
+        const InputOutputIterIndexerT in_out_iter_indexer{NoOpIndexerT{},
+                                                          NoOpIndexerT{}};
+        const ReductionIndexerT reduction_indexer{/* size */ reduction_nelems,
+                                                  /* step */ iter_nelems};
+
+        sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
+            cgh.depends_on(depends);
+
+            using KernelName =
+                class search_seq_contig_krn;
+
+            sycl::range<1> iter_range{iter_nelems};
+
+            cgh.parallel_for(
+                iter_range,
+                SequentialSearchReduction(
+                    arg_tp, res_tp, ReductionOpT(), identity_val, IndexOpT(),
+                    idx_identity_val, in_out_iter_indexer, reduction_indexer,
+                    reduction_nelems));
+        });
+
+        return comp_ev;
+    }
+
+    static constexpr std::size_t preferred_reductions_per_wi = 8;
+    // prevents running out of resources on CPU
+    std::size_t max_wg = reduction_detail::get_work_group_size(d);
+
+    std::size_t reductions_per_wi(preferred_reductions_per_wi);
+    if (reduction_nelems <= preferred_reductions_per_wi * max_wg) {
+        // Path 2: perform the reduction using one work-group per iteration,
+        // can output directly to res
+        using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer;
+        using ColsIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer;
+        using InputOutputIterIndexerT =
+            dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer<
+                NoOpIndexerT, NoOpIndexerT>;
+        using ReductionIndexerT = ColsIndexerT;
+
+        static constexpr NoOpIndexerT columns_indexer{};
+        static constexpr NoOpIndexerT result_indexer{};
+        const InputOutputIterIndexerT in_out_iter_indexer{columns_indexer,
+                                                          result_indexer};
+        const ReductionIndexerT reduction_indexer{/* size */ reduction_nelems,
+                                                  /* step */ iter_nelems};
+
+        if (iter_nelems == 1) {
+            // increase GPU occupancy
+            wg = max_wg;
+        }
+        // Each work-item takes ceil(reduction_nelems / wg) elements so that
+        // a single work-group covers the whole reduction.
+        reductions_per_wi =
+            std::max(1, (reduction_nelems + wg - 1) / wg);
+
+        std::size_t reduction_groups =
+            (reduction_nelems + reductions_per_wi * wg - 1) /
+            (reductions_per_wi * wg);
+        assert(reduction_groups == 1);
+
+        sycl::event comp_ev =
+            submit_search_reduction(
+                exec_q, arg_tp, nullptr, nullptr, res_tp, identity_val,
+                idx_identity_val, wg, iter_nelems, reduction_nelems,
+                reductions_per_wi, reduction_groups, in_out_iter_indexer,
+                reduction_indexer, depends);
+
+        return comp_ev;
+    }
+    else {
+        // Path 3: more than one work-group is needed, requires a temporary
+        std::size_t reduction_groups =
+            (reduction_nelems + preferred_reductions_per_wi * wg - 1) /
+            (preferred_reductions_per_wi * wg);
+        assert(reduction_groups > 1);
+
+        std::size_t second_iter_reduction_groups_ =
+            (reduction_groups + preferred_reductions_per_wi * wg - 1) /
+            (preferred_reductions_per_wi * wg);
+
+        // One allocation holds both ping-pong buffers: the first-pass
+        // output followed by the second-pass output.
+        const std::size_t tmp_alloc_size =
+            iter_nelems * (reduction_groups + second_iter_reduction_groups_);
+        auto tmp_owner = dpctl::tensor::alloc_utils::smart_malloc_device(
+            tmp_alloc_size, exec_q);
+
+        resTy *partially_reduced_tmp = tmp_owner.get();
+        resTy *partially_reduced_tmp2 =
+            partially_reduced_tmp + reduction_groups * iter_nelems;
+
+        // Same layout again for the partially reduced values that accompany
+        // the partially reduced indices.
+        auto vals_tmp_owner =
+            dpctl::tensor::alloc_utils::smart_malloc_device(
+                tmp_alloc_size, exec_q);
+        argTy *partially_reduced_vals_tmp = vals_tmp_owner.get();
+        argTy *partially_reduced_vals_tmp2 =
+            partially_reduced_vals_tmp + reduction_groups * iter_nelems;
+
+        // First pass: read the input array, write value/index partials.
+        sycl::event first_reduction_ev;
+        {
+            using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer;
+            using ColsIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer;
+            using InputOutputIterIndexerT =
+                dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer<
+                    NoOpIndexerT, NoOpIndexerT>;
+            using ReductionIndexerT = ColsIndexerT;
+
+            static constexpr NoOpIndexerT columns_indexer{};
+            static constexpr NoOpIndexerT result_indexer{};
+            const InputOutputIterIndexerT in_out_iter_indexer{columns_indexer,
+                                                              result_indexer};
+            const ReductionIndexerT reduction_indexer{
+                /* size */ reduction_nelems,
+                /* step */ iter_nelems};
+
+            first_reduction_ev =
+                submit_search_reduction(
+                    exec_q, arg_tp, partially_reduced_vals_tmp, nullptr,
+                    partially_reduced_tmp, identity_val, idx_identity_val, wg,
+                    iter_nelems, reduction_nelems, preferred_reductions_per_wi,
+                    reduction_groups, in_out_iter_indexer, reduction_indexer,
+                    depends);
+        }
+
+        std::size_t remaining_reduction_nelems = reduction_groups;
+
+        resTy *temp_arg = partially_reduced_tmp;
+        resTy *temp2_arg = partially_reduced_tmp2;
+
+        argTy *vals_temp_arg = partially_reduced_vals_tmp;
+        argTy *vals_temp2_arg = partially_reduced_vals_tmp2;
+
+        sycl::event dependent_ev = first_reduction_ev;
+
+        // Intermediate passes: keep folding the partials, swapping the
+        // ping-pong buffers, until one work-group per iteration suffices.
+        while (remaining_reduction_nelems >
+               preferred_reductions_per_wi * max_wg) {
+            std::size_t reduction_groups_ =
+                (remaining_reduction_nelems + preferred_reductions_per_wi * wg -
+                 1) /
+                (preferred_reductions_per_wi * wg);
+            assert(reduction_groups_ > 1);
+
+            // keep reducing
+            using InputIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer;
+            using ResIndexerT = dpctl::tensor::offset_utils::NoOpIndexer;
+            using InputOutputIterIndexerT =
+                dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer<
+                    InputIndexerT, ResIndexerT>;
+            using ReductionIndexerT = dpctl::tensor::offset_utils::NoOpIndexer;
+
+            // NOTE(review): this pass reads temporaries holding
+            // remaining_reduction_nelems partials per iteration, yet the
+            // input step is reduction_groups_ (the count this pass
+            // produces). The final pass below uses
+            // remaining_reduction_nelems as its step -- confirm against
+            // submit_search_reduction which row length is intended here.
+            const InputIndexerT inp_indexer{/* size */ iter_nelems,
+                                            /* step */ reduction_groups_};
+            static constexpr ResIndexerT res_iter_indexer{};
+
+            const InputOutputIterIndexerT in_out_iter_indexer{inp_indexer,
+                                                              res_iter_indexer};
+            static constexpr ReductionIndexerT reduction_indexer{};
+
+            sycl::event partial_reduction_ev =
+                submit_search_reduction(
+                    exec_q, vals_temp_arg, vals_temp2_arg, temp_arg, temp2_arg,
+                    identity_val, idx_identity_val, wg, iter_nelems,
+                    remaining_reduction_nelems, preferred_reductions_per_wi,
+                    reduction_groups_, in_out_iter_indexer, reduction_indexer,
+                    {dependent_ev});
+
+            remaining_reduction_nelems = reduction_groups_;
+            std::swap(temp_arg, temp2_arg);
+            std::swap(vals_temp_arg, vals_temp2_arg);
+            dependent_ev = partial_reduction_ev;
+        }
+
+        // final reduction to res
+        using InputIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer;
+        using ResIndexerT = dpctl::tensor::offset_utils::NoOpIndexer;
+        using InputOutputIterIndexerT =
+            dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer<
+                InputIndexerT, ResIndexerT>;
+        using ReductionIndexerT = dpctl::tensor::offset_utils::NoOpIndexer;
+
+        const InputIndexerT inp_indexer{/* size */ iter_nelems,
+                                        /* step */ remaining_reduction_nelems};
+        static constexpr ResIndexerT res_iter_indexer{};
+
+        const InputOutputIterIndexerT in_out_iter_indexer{inp_indexer,
+                                                          res_iter_indexer};
+        static constexpr ReductionIndexerT reduction_indexer{};
+
+        // Use the largest work-group and size the per-work-item load so
+        // that exactly one group finishes each iteration.
+        wg = max_wg;
+        reductions_per_wi = std::max(
+            1, (remaining_reduction_nelems + wg - 1) / wg);
+
+        reduction_groups =
+            (remaining_reduction_nelems + reductions_per_wi * wg - 1) /
+            (reductions_per_wi * wg);
+        assert(reduction_groups == 1);
+
+        sycl::event final_reduction_ev =
+            submit_search_reduction(
+                exec_q, vals_temp_arg, nullptr, temp_arg, res_tp, identity_val,
+                idx_identity_val, wg, iter_nelems, remaining_reduction_nelems,
+                reductions_per_wi, reduction_groups, in_out_iter_indexer,
+                reduction_indexer, {dependent_ev});
+
+        // Release both temporaries once the final pass has completed.
+        sycl::event cleanup_host_task_event =
+            dpctl::tensor::alloc_utils::async_smart_free(
+                exec_q, {final_reduction_ev}, tmp_owner, vals_tmp_owner);
+
+        // FIXME: do not return host-task event
+        //   Instead collect all host-tasks to a list
+
+        return cleanup_host_task_event;
+    }
+}
+
+} // namespace dpctl::tensor::kernels
diff --git a/dpctl_ext/tensor/libtensor/include/kernels/sorting/isin.hpp b/dpctl_ext/tensor/libtensor/include/kernels/sorting/isin.hpp
new file mode 100644
index 000000000000..847fa96ecdff
--- /dev/null
+++ b/dpctl_ext/tensor/libtensor/include/kernels/sorting/isin.hpp
@@ -0,0 +1,245 @@
+//*****************************************************************************
+// Copyright (c) 2026, Intel Corporation
+// All rights reserved.
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are met:
+// - Redistributions of source code must retain the above copyright notice,
+//   this list of conditions and the following disclaimer.
+// - Redistributions in binary form must reproduce the above copyright notice,
+//   this list of conditions and the following disclaimer in the documentation
+//   and/or other materials provided with the distribution.
+// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file defines kernels for tensor membership operations. 
+//===----------------------------------------------------------------------===//
+
+#pragma once
+
+#include
+#include
+
+#include
+
+#include "kernels/dpctl_tensor_types.hpp"
+#include "kernels/sorting/search_sorted_detail.hpp"
+#include "utils/offset_utils.hpp"
+#include "utils/rich_comparisons.hpp"
+
+namespace dpctl::tensor::kernels
+{
+
+using dpctl::tensor::ssize_t;
+
+/*! @brief Kernel functor computing, for every needle, whether it occurs in
+ *  the hay array.
+ *
+ *  Work-item `i` reads needle `needles_tp[needles_indexer(i)]`, locates it
+ *  in the hay with a lower-bound search (so the hay is expected to be
+ *  ordered consistently with AscendingSorter), and stores the membership
+ *  flag, negated when `invert` is set, at `out_tp[out_indexer(i)]`.
+ *
+ *  All hay accesses go through `hay_indexer`; `hay_tp` itself may be an
+ *  un-offset base pointer.
+ */
+template
+struct IsinFunctor
+{
+private:
+    bool invert;
+    const T *hay_tp;
+    const T *needles_tp;
+    bool *out_tp;
+    std::size_t hay_nelems;
+    HayIndexerT hay_indexer;
+    NeedlesIndexerT needles_indexer;
+    OutIndexerT out_indexer;
+
+public:
+    IsinFunctor(const bool invert_,
+                const T *hay_,
+                const T *needles_,
+                bool *out_,
+                const std::size_t hay_nelems_,
+                const HayIndexerT &hay_indexer_,
+                const NeedlesIndexerT &needles_indexer_,
+                const OutIndexerT &out_indexer_)
+        : invert(invert_), hay_tp(hay_), needles_tp(needles_), out_tp(out_),
+          hay_nelems(hay_nelems_), hay_indexer(hay_indexer_),
+          needles_indexer(needles_indexer_), out_indexer(out_indexer_)
+    {
+    }
+
+    void operator()(sycl::id<1> id) const
+    {
+        using Compare =
+            typename dpctl::tensor::rich_comparisons::AscendingSorter::type;
+        static constexpr Compare comp{};
+
+        const std::size_t i = id[0];
+        const T needle_v = needles_tp[needles_indexer(i)];
+
+        // position of the needle_v in the hay array
+        std::size_t pos{};
+
+        static constexpr std::size_t zero(0);
+        // search in hay in left-closed interval, give `pos` such that
+        // hay[pos - 1] < needle_v <= hay[pos]
+
+        // lower_bound returns the first pos such that bool(hay[pos] <
+        // needle_v) is false, i.e. needle_v <= hay[pos]
+        pos = search_sorted_detail::lower_bound_indexed_impl(
+            hay_tp, zero, hay_nelems, needle_v, comp, hay_indexer);
+        // `pos` is a logical position in the hay, so it has to be mapped
+        // through hay_indexer exactly as the search above does. Indexing
+        // hay_tp with the raw `pos` reads the wrong element whenever the
+        // hay has a non-zero offset or a non-unit stride.
+        bool out =
+            (pos == hay_nelems ? false
+                               : hay_tp[hay_indexer(pos)] == needle_v);
+        out_tp[out_indexer(i)] = (invert) ? !out : out;
+    }
+};
+
+// Signature of the contiguous implementation:
+// (queue, invert, hay_nelems, needles_nelems, hay, hay_offset, needles,
+//  needles_offset, out, out_offset, depends)
+typedef sycl::event (*isin_contig_impl_fp_ptr_t)(
+    sycl::queue &,
+    const bool,
+    const std::size_t,
+    const std::size_t,
+    const char *,
+    const ssize_t,
+    const char *,
+    const ssize_t,
+    char *,
+    const ssize_t,
+    const std::vector &);
+
+template
+class isin_contig_impl_krn;
+
+/*! @brief Membership test for contiguous hay, needles and output.
+ *
+ *  The element offsets are folded into the typed pointers up front, so the
+ *  kernel uses no-op indexers for all three arrays. One work-item is
+ *  launched per needle.
+ *
+ *  @return event of the submitted kernel
+ */
+template
+sycl::event isin_contig_impl(sycl::queue &exec_q,
+                             const bool invert,
+                             const std::size_t hay_nelems,
+                             const std::size_t needles_nelems,
+                             const char *hay_cp,
+                             const ssize_t hay_offset,
+                             const char *needles_cp,
+                             const ssize_t needles_offset,
+                             char *out_cp,
+                             const ssize_t out_offset,
+                             const std::vector &depends)
+{
+    const T *hay_tp = reinterpret_cast(hay_cp) + hay_offset;
+    const T *needles_tp =
+        reinterpret_cast(needles_cp) + needles_offset;
+
+    bool *out_tp = reinterpret_cast(out_cp) + out_offset;
+
+    sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
+        cgh.depends_on(depends);
+
+        using KernelName = class isin_contig_impl_krn;
+
+        sycl::range<1> gRange(needles_nelems);
+
+        using TrivialIndexerT = dpctl::tensor::offset_utils::NoOpIndexer;
+
+        static constexpr TrivialIndexerT hay_indexer{};
+        static constexpr TrivialIndexerT needles_indexer{};
+        static constexpr TrivialIndexerT out_indexer{};
+
+        const auto fnctr =
+            IsinFunctor(
+                invert, hay_tp, needles_tp, out_tp, hay_nelems, hay_indexer,
+                needles_indexer, out_indexer);
+
+        cgh.parallel_for(gRange, fnctr);
+    });
+
+    return comp_ev;
+}
+
+// Signature of the strided implementation:
+// (queue, invert, hay_nelems, needles_nelems, hay, hay_offset, hay_stride,
+//  needles, needles_offset, out, out_offset, needles_nd,
+//  packed_shape_strides, depends)
+typedef sycl::event (*isin_strided_impl_fp_ptr_t)(
+    sycl::queue &,
+    const bool,
+    const std::size_t,
+    const std::size_t,
+    const char *,
+    const ssize_t,
+    const ssize_t,
+    const char *,
+    const ssize_t,
+    char *,
+    const ssize_t,
+    int,
+    const ssize_t *,
+    const std::vector &);
+
+template
+class isin_strided_impl_krn;
+
+/*! @brief Membership test for strided hay, needles and output.
+ *
+ *  Unlike the contiguous variant, the typed pointers are left un-offset:
+ *  every offset and stride is carried by the indexers handed to
+ *  IsinFunctor. One work-item is launched per needle.
+ *
+ *  @return event of the submitted kernel
+ */
+template
+sycl::event isin_strided_impl(
+    sycl::queue &exec_q,
+    const bool invert,
+    const std::size_t hay_nelems,
+    const std::size_t needles_nelems,
+    const char *hay_cp,
+    const ssize_t hay_offset,
+    // hay is 1D, so hay_nelems, hay_offset, hay_stride describe strided array
+    const ssize_t hay_stride,
+    const char *needles_cp,
+    const ssize_t needles_offset,
+    char *out_cp,
+    const ssize_t out_offset,
+    const int needles_nd,
+    // packed_shape_strides is [needles_shape, needles_strides,
+    // out_strides] has length of 3*needles_nd
+    const ssize_t *packed_shape_strides,
+    const std::vector &depends)
+{
+    const T *hay_tp = reinterpret_cast(hay_cp);
+    const T *needles_tp = reinterpret_cast(needles_cp);
+
+    bool *out_tp = reinterpret_cast(out_cp);
+
+    sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
+        cgh.depends_on(depends);
+
+        sycl::range<1> gRange(needles_nelems);
+
+        using HayIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer;
+        const HayIndexerT hay_indexer(
+            /* offset */ hay_offset,
+            /* size   */ hay_nelems,
+            /* step   */ hay_stride);
+
+        // Needles use the leading [shape, strides] part of the packed array.
+        using NeedlesIndexerT = dpctl::tensor::offset_utils::StridedIndexer;
+        const ssize_t *needles_shape_strides = packed_shape_strides;
+        const NeedlesIndexerT needles_indexer(needles_nd, needles_offset,
+                                              needles_shape_strides);
+        using OutIndexerT = dpctl::tensor::offset_utils::UnpackedStridedIndexer;
+
+        // The output shares the needles' shape; its strides are the third
+        // needles_nd-long section of the packed array.
+        const ssize_t *out_shape = packed_shape_strides;
+        const ssize_t *out_strides = packed_shape_strides + 2 * needles_nd;
+        const OutIndexerT out_indexer(needles_nd, out_offset, out_shape,
+                                      out_strides);
+
+        const auto fnctr =
+            IsinFunctor(
+                invert, hay_tp, needles_tp, out_tp, hay_nelems, hay_indexer,
+                needles_indexer, out_indexer);
+        using KernelName = class isin_strided_impl_krn;
+
+        cgh.parallel_for(gRange, fnctr);
+    });
+
+    return comp_ev;
+}
+
+} // namespace dpctl::tensor::kernels
diff --git a/dpctl_ext/tensor/libtensor/include/kernels/sorting/merge_sort.hpp b/dpctl_ext/tensor/libtensor/include/kernels/sorting/merge_sort.hpp
new file mode 100644
index 000000000000..a047c172f7bc
--- /dev/null
+++ b/dpctl_ext/tensor/libtensor/include/kernels/sorting/merge_sort.hpp
@@ -0,0 +1,856 @@
+//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file defines kernels for tensor sort/argsort operations. 
+//===----------------------------------------------------------------------===//
+
+#pragma once
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include
+
+#include "kernels/dpctl_tensor_types.hpp"
+#include "kernels/sorting/search_sorted_detail.hpp"
+#include "kernels/sorting/sort_utils.hpp"
+
+namespace dpctl::tensor::kernels
+{
+
+namespace merge_sort_detail
+{
+
+using dpctl::tensor::ssize_t;
+using namespace dpctl::tensor::kernels::search_sorted_detail;
+
+/*! @brief Merge two contiguous sorted segments
+ *
+ *  The segments are in_acc[start_1, end_1) and in_acc[end_1, end_2); the
+ *  merged result is written to out_acc starting at start_out. A single call
+ *  handles only the slice of each segment that begins `offset` elements
+ *  into it and is at most `chunk` elements long; each of those elements is
+ *  written to its final merged position.
+ *
+ *  Ties are resolved in favour of the 1st segment: its elements are placed
+ *  with lower_bound, those of the 2nd segment with upper_bound.
+ */
+template
+void merge_impl(const std::size_t offset,
+                const InAcc in_acc,
+                OutAcc out_acc,
+                const std::size_t start_1,
+                const std::size_t end_1,
+                const std::size_t end_2,
+                const std::size_t start_out,
+                Compare comp,
+                const std::size_t chunk)
+{
+    const std::size_t start_2 = end_1;
+    // Borders of the sequences to merge within this call
+    const std::size_t local_start_1 = sycl::min(offset + start_1, end_1);
+    const std::size_t local_end_1 = sycl::min(local_start_1 + chunk, end_1);
+    const std::size_t local_start_2 = sycl::min(offset + start_2, end_2);
+    const std::size_t local_end_2 = sycl::min(local_start_2 + chunk, end_2);
+
+    const std::size_t local_size_1 = local_end_1 - local_start_1;
+    const std::size_t local_size_2 = local_end_2 - local_start_2;
+
+    const auto r_item_1 = in_acc[end_1 - 1];
+    // With an empty 2nd sequence, fall back to r_item_1 so that the first
+    // test below selects the plain copy.
+    const auto l_item_2 = (start_2 < end_2) ? in_acc[start_2] : r_item_1;
+
+    // Copy if the sequences are sorted with respect to each other or merge
+    // otherwise
+    if (!comp(l_item_2, r_item_1)) {
+        // 1st sequence entirely precedes the 2nd: keep the order.
+        const std::size_t out_shift_1 = start_out + local_start_1 - start_1;
+        const std::size_t out_shift_2 =
+            start_out + end_1 - start_1 + local_start_2 - start_2;
+
+        for (std::size_t i = 0; i < local_size_1; ++i) {
+            out_acc[out_shift_1 + i] = in_acc[local_start_1 + i];
+        }
+        for (std::size_t i = 0; i < local_size_2; ++i) {
+            out_acc[out_shift_2 + i] = in_acc[local_start_2 + i];
+        }
+    }
+    // 2nd sequence entirely (strictly) precedes the 1st: copy it in front.
+    // The test must compare the last item of the 2nd sequence with the first
+    // item of the 1st one. Re-testing r_item_1 against l_item_2 can never
+    // succeed here, because comp(l_item_2, r_item_1) is already known to be
+    // true, which would leave this branch unreachable. The 2nd sequence is
+    // non-empty at this point for the same reason.
+    else if (comp(in_acc[end_2 - 1], in_acc[start_1])) {
+        const std::size_t out_shift_1 =
+            start_out + end_2 - start_2 + local_start_1 - start_1;
+        const std::size_t out_shift_2 = start_out + local_start_2 - start_2;
+        for (std::size_t i = 0; i < local_size_1; ++i) {
+            out_acc[out_shift_1 + i] = in_acc[local_start_1 + i];
+        }
+        for (std::size_t i = 0; i < local_size_2; ++i) {
+            out_acc[out_shift_2 + i] = in_acc[local_start_2 + i];
+        }
+    }
+    // Perform merging
+    else {
+
+        // Process 1st sequence
+        if (local_start_1 < local_end_1) {
+            // Reduce the range for searching within the 2nd sequence and handle
+            // bound items find left border in 2nd sequence
+            const auto local_l_item_1 = in_acc[local_start_1];
+            std::size_t l_search_bound_2 =
+                lower_bound_impl(in_acc, start_2, end_2, local_l_item_1, comp);
+            const std::size_t l_shift_1 = local_start_1 - start_1;
+            const std::size_t l_shift_2 = l_search_bound_2 - start_2;
+
+            out_acc[start_out + l_shift_1 + l_shift_2] = local_l_item_1;
+
+            std::size_t r_search_bound_2{};
+            // find right border in 2nd sequence
+            if (local_size_1 > 1) {
+                const auto local_r_item_1 = in_acc[local_end_1 - 1];
+                r_search_bound_2 = lower_bound_impl(
+                    in_acc, l_search_bound_2, end_2, local_r_item_1, comp);
+                const auto r_shift_1 = local_end_1 - 1 - start_1;
+                const auto r_shift_2 = r_search_bound_2 - start_2;
+
+                out_acc[start_out + r_shift_1 + r_shift_2] = local_r_item_1;
+            }
+
+            // Handle intermediate items
+            if (r_search_bound_2 == l_search_bound_2) {
+                // Both borders map to the same place in the 2nd sequence,
+                // so every intermediate item shares that shift.
+                const std::size_t shift_2 = l_search_bound_2 - start_2;
+                for (std::size_t idx = local_start_1 + 1; idx < local_end_1 - 1;
+                     ++idx) {
+                    const auto intermediate_item_1 = in_acc[idx];
+                    const std::size_t shift_1 = idx - start_1;
+                    out_acc[start_out + shift_1 + shift_2] =
+                        intermediate_item_1;
+                }
+            }
+            else {
+                for (std::size_t idx = local_start_1 + 1; idx < local_end_1 - 1;
+                     ++idx) {
+                    const auto intermediate_item_1 = in_acc[idx];
+                    // we shouldn't seek in whole 2nd sequence. Just for the
+                    // part where the 1st sequence should be
+                    l_search_bound_2 = lower_bound_impl(
+                        in_acc, l_search_bound_2, r_search_bound_2,
+                        intermediate_item_1, comp);
+                    const std::size_t shift_1 = idx - start_1;
+                    const std::size_t shift_2 = l_search_bound_2 - start_2;
+
+                    out_acc[start_out + shift_1 + shift_2] =
+                        intermediate_item_1;
+                }
+            }
+        }
+        // Process 2nd sequence
+        if (local_start_2 < local_end_2) {
+            // Reduce the range for searching within the 1st sequence and handle
+            // bound items find left border in 1st sequence
+            const auto local_l_item_2 = in_acc[local_start_2];
+            std::size_t l_search_bound_1 =
+                upper_bound_impl(in_acc, start_1, end_1, local_l_item_2, comp);
+            const std::size_t l_shift_1 = l_search_bound_1 - start_1;
+            const std::size_t l_shift_2 = local_start_2 - start_2;
+
+            out_acc[start_out + l_shift_1 + l_shift_2] = local_l_item_2;
+
+            std::size_t r_search_bound_1{};
+            // find right border in 1st sequence
+            if (local_size_2 > 1) {
+                const auto local_r_item_2 = in_acc[local_end_2 - 1];
+                r_search_bound_1 = upper_bound_impl(
+                    in_acc, l_search_bound_1, end_1, local_r_item_2, comp);
+                const std::size_t r_shift_1 = r_search_bound_1 - start_1;
+                const std::size_t r_shift_2 = local_end_2 - 1 - start_2;
+
+                out_acc[start_out + r_shift_1 + r_shift_2] = local_r_item_2;
+            }
+
+            // Handle intermediate items
+            if (l_search_bound_1 == r_search_bound_1) {
+                const std::size_t shift_1 = l_search_bound_1 - start_1;
+                for (auto idx = local_start_2 + 1; idx < local_end_2 - 1; ++idx)
+                {
+                    const auto intermediate_item_2 = in_acc[idx];
+                    const std::size_t shift_2 = idx - start_2;
+                    out_acc[start_out + shift_1 + shift_2] =
+                        intermediate_item_2;
+                }
+            }
+            else {
+                for (auto idx = local_start_2 + 1; idx < local_end_2 - 1; ++idx)
+                {
+                    const auto intermediate_item_2 = in_acc[idx];
+                    // we shouldn't seek in whole 1st sequence. Just for the
+                    // part where the 2nd sequence should be
+                    l_search_bound_1 = upper_bound_impl(
+                        in_acc, l_search_bound_1, r_search_bound_1,
+                        intermediate_item_2, comp);
+                    const std::size_t shift_1 = l_search_bound_1 - start_1;
+                    const std::size_t shift_2 = idx - start_2;
+
+                    out_acc[start_out + shift_1 + shift_2] =
+                        intermediate_item_2;
+                }
+            }
+        }
+    }
+}
+
+/*! @brief In-place insertion sort of first[begin, end) ordered by comp.
+ *
+ *  An element is shifted only past entries it compares strictly less than,
+ *  so equal elements keep their relative order.
+ */
+template
+void insertion_sort_impl(Iter &&first,
+                         std::size_t begin,
+                         std::size_t end,
+                         Compare &&comp)
+{
+    for (std::size_t i = begin + 1; i < end; ++i) {
+        const auto val_i = first[i];
+        std::size_t j = i - 1;
+        // `j + 1 > begin` rather than `j >= begin`: j is unsigned, and when
+        // begin == 0 the decrement below wraps, making j + 1 equal to 0.
+        while ((j + 1 > begin) && (comp(val_i, first[j]))) {
+            first[j + 1] = first[j];
+            --j;
+        }
+        // Store only if at least one element was shifted.
+        if (j + 1 < i) {
+            first[j + 1] = val_i;
+        }
+    }
+}
+
+/*! @brief Sort a leaf segment first[begin, end); forwards to
+ *  insertion_sort_impl.
+ */
+template
+void leaf_sort_impl(Iter &&first,
+                    std::size_t begin,
+                    std::size_t end,
+                    Compare &&comp)
+{
+    return insertion_sort_impl(std::forward(first),
+                                              std::move(begin), std::move(end),
+                                              std::forward(comp));
+}
+
+// Trait yielding the element type of an iterator-like argument. The primary
+// template defers to std::iterator_traits.
+template
+struct GetValueType
+{
+    using value_type = typename std::iterator_traits::value_type;
+};
+
+// Specialization reporting ElementType as the value type.
+template
+struct GetValueType>
+{
+    using value_type = ElementType;
+};
+
+// Specialization for sycl::accessor, reporting ElementType.
+template
+struct GetValueType<
+    sycl::accessor>
+{
+    using value_type = ElementType;
+};
+
+// Specialization reporting ElementType as the value type.
+template
+struct GetValueType>
+{
+    using value_type = ElementType;
+};
+
+// Produces something readable inside a command group. The primary template
+// hands the argument back unchanged and ignores the handler.
+template
+struct GetReadOnlyAccess
+{
+    Iter operator()(const Iter &it, sycl::handler &)
+    {
+        return it;
+    }
+};
+
+// Buffer specialization: creates a read-only accessor registered with the
+// handler.
+template
+struct GetReadOnlyAccess>
+{
+    auto operator()(const sycl::buffer &buf,
+                    sycl::handler &cgh)
+    {
+        sycl::accessor acc(buf, cgh, sycl::read_only);
+        return acc;
+    }
+};
+
+template
+struct GetWriteDiscardAccess +{ + Iter operator()(Iter it, sycl::handler &) + { + return it; + } +}; + +template +struct GetWriteDiscardAccess> +{ + auto operator()(sycl::buffer &buf, + sycl::handler &cgh) + { + sycl::accessor acc(buf, cgh, sycl::write_only, sycl::no_init); + return acc; + } +}; + +template +struct GetReadWriteAccess +{ + Iter operator()(Iter &it, sycl::handler &) + { + return it; + } +}; + +template +struct GetReadWriteAccess> +{ + auto operator()(sycl::buffer &buf, + sycl::handler &cgh) + { + sycl::accessor acc(buf, cgh, sycl::read_write); + return acc; + } +}; + +template +class sort_base_step_contig_krn; + +template +sycl::event + sort_base_step_contig_impl(sycl::queue &q, + const std::size_t iter_nelems, + const std::size_t sort_nelems, + const InpAcc input, + OutAcc output, + const Comp &comp, + const std::size_t conseq_nelems_sorted, + const std::vector &depends = {}) +{ + + using inpT = typename GetValueType::value_type; + using outT = typename GetValueType::value_type; + using KernelName = sort_base_step_contig_krn; + + const std::size_t n_segments = + quotient_ceil(sort_nelems, conseq_nelems_sorted); + + sycl::event base_sort = q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + const sycl::range<1> gRange{iter_nelems * n_segments}; + + auto input_acc = GetReadOnlyAccess{}(input, cgh); + auto output_acc = GetWriteDiscardAccess{}(output, cgh); + + cgh.parallel_for(gRange, [=](sycl::id<1> id) { + const std::size_t iter_id = id[0] / n_segments; + const std::size_t segment_id = id[0] - iter_id * n_segments; + + const std::size_t iter_offset = iter_id * sort_nelems; + const std::size_t beg_id = + iter_offset + segment_id * conseq_nelems_sorted; + const std::size_t end_id = + iter_offset + + std::min((segment_id + 1) * conseq_nelems_sorted, sort_nelems); + for (std::size_t i = beg_id; i < end_id; ++i) { + output_acc[i] = input_acc[i]; + } + + leaf_sort_impl(output_acc, beg_id, end_id, comp); + }); + }); + + return base_sort; +} + 
+template +class sort_over_work_group_contig_krn; + +template +sycl::event sort_over_work_group_contig_impl( + sycl::queue &q, + std::size_t iter_nelems, + std::size_t sort_nelems, + const InpAcc input, + OutAcc output, + const Comp &comp, + std::size_t &nelems_wg_sorts, + const std::vector &depends = {}) +{ + using inpT = typename GetValueType::value_type; + using T = typename GetValueType::value_type; + using KernelName = sort_over_work_group_contig_krn; + + const auto &kernel_id = sycl::get_kernel_id(); + + auto const &ctx = q.get_context(); + auto const &dev = q.get_device(); + auto kb = sycl::get_kernel_bundle( + ctx, {dev}, {kernel_id}); + + auto krn = kb.get_kernel(kernel_id); + + const std::uint32_t max_sg_size = krn.template get_info< + sycl::info::kernel_device_specific::max_sub_group_size>(dev); + const std::uint64_t device_local_memory_size = + dev.get_info(); + + // leave 512 bytes of local memory for RT + const std::uint64_t safety_margin = 512; + + const std::uint64_t nelems_per_slm = + (device_local_memory_size - safety_margin) / (2 * sizeof(T)); + + static constexpr std::uint32_t sub_groups_per_work_group = 4; + const std::uint32_t elems_per_wi = dev.has(sycl::aspect::cpu) ? 8 : 2; + + const std::size_t lws = sub_groups_per_work_group * max_sg_size; + + nelems_wg_sorts = elems_per_wi * lws; + + if (nelems_wg_sorts > nelems_per_slm) { + nelems_wg_sorts = (q.get_device().has(sycl::aspect::cpu) ? 
16 : 4); + + return sort_base_step_contig_impl( + q, iter_nelems, sort_nelems, input, output, comp, nelems_wg_sorts, + depends); + } + + // This assumption permits doing away with using a loop + assert(nelems_wg_sorts % lws == 0); + + const std::size_t n_segments = quotient_ceil(sort_nelems, nelems_wg_sorts); + + sycl::event base_sort_ev = q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + cgh.use_kernel_bundle(kb); + + sycl::range<1> global_range{iter_nelems * n_segments * lws}; + sycl::range<1> local_range{lws}; + + sycl::range<1> slm_range{nelems_wg_sorts}; + sycl::local_accessor work_space(slm_range, cgh); + sycl::local_accessor scratch_space(slm_range, cgh); + + auto input_acc = GetReadOnlyAccess{}(input, cgh); + auto output_acc = GetWriteDiscardAccess{}(output, cgh); + + sycl::nd_range<1> ndRange(global_range, local_range); + + cgh.parallel_for(ndRange, [=](sycl::nd_item<1> it) { + const std::size_t group_id = it.get_group_linear_id(); + const std::size_t iter_id = group_id / n_segments; + const std::size_t segment_id = group_id - iter_id * n_segments; + const std::size_t lid = it.get_local_linear_id(); + + const std::size_t segment_start_idx = segment_id * nelems_wg_sorts; + const std::size_t segment_end_idx = + std::min(segment_start_idx + nelems_wg_sorts, sort_nelems); + const std::size_t wg_chunk_size = + segment_end_idx - segment_start_idx; + + // load input into SLM + for (std::size_t array_id = segment_start_idx + lid; + array_id < segment_end_idx; array_id += lws) + { + T v = (array_id < sort_nelems) + ? 
input_acc[iter_id * sort_nelems + array_id] + : T{}; + work_space[array_id - segment_start_idx] = v; + } + sycl::group_barrier(it.get_group()); + + const std::size_t chunk = quotient_ceil(nelems_wg_sorts, lws); + + const std::size_t chunk_start_idx = lid * chunk; + const std::size_t chunk_end_idx = + sycl::min(chunk_start_idx + chunk, wg_chunk_size); + + leaf_sort_impl(work_space, chunk_start_idx, chunk_end_idx, comp); + + sycl::group_barrier(it.get_group()); + + bool data_in_temp = false; + std::size_t n_chunks_merged = 1; + + // merge chunk while n_chunks_merged * chunk < wg_chunk_size + const std::size_t max_chunks_merged = + 1 + ((wg_chunk_size - 1) / chunk); + for (; n_chunks_merged < max_chunks_merged; + data_in_temp = !data_in_temp, n_chunks_merged *= 2) + { + const std::size_t nelems_sorted_so_far = + n_chunks_merged * chunk; + const std::size_t q = (lid / n_chunks_merged); + const std::size_t start_1 = + sycl::min(2 * nelems_sorted_so_far * q, wg_chunk_size); + const std::size_t end_1 = + sycl::min(start_1 + nelems_sorted_so_far, wg_chunk_size); + const std::size_t end_2 = + sycl::min(end_1 + nelems_sorted_so_far, wg_chunk_size); + const std::size_t offset = chunk * (lid - q * n_chunks_merged); + + if (data_in_temp) { + merge_impl(offset, scratch_space, work_space, start_1, + end_1, end_2, start_1, comp, chunk); + } + else { + merge_impl(offset, work_space, scratch_space, start_1, + end_1, end_2, start_1, comp, chunk); + } + sycl::group_barrier(it.get_group()); + } + + const auto &out_src = (data_in_temp) ? 
scratch_space : work_space; + for (std::size_t array_id = segment_start_idx + lid; + array_id < segment_end_idx; array_id += lws) + { + if (array_id < sort_nelems) { + output_acc[iter_id * sort_nelems + array_id] = + out_src[array_id - segment_start_idx]; + } + } + }); + }); + + return base_sort_ev; +} + +class vacuous_krn; + +inline sycl::event tie_events(sycl::queue &q, + const std::vector depends) +{ + if (depends.empty()) + return sycl::event(); + if (depends.size() == 1) + return depends[0]; + + sycl::event e = q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + using KernelName = vacuous_krn; + cgh.single_task([]() {}); + }); + + return e; +} + +template +class merge_adjacent_blocks_to_temp_krn; + +template +class merge_adjacent_blocks_from_temp_krn; + +template +sycl::event + merge_sorted_block_contig_impl(sycl::queue &q, + std::size_t iter_nelems, + std::size_t sort_nelems, + Acc output, + const Comp comp, + std::size_t sorted_block_size, + const std::vector &depends = {}) +{ + + if (sorted_block_size >= sort_nelems) + return tie_events(q, depends); + + // experimentally determined value + // size of segments worked upon by each work-item during merging + const sycl::device &dev = q.get_device(); + const std::size_t segment_size = (dev.has(sycl::aspect::cpu)) ? 32 : 4; + + const std::size_t chunk_size = + (sorted_block_size < segment_size) ? 
sorted_block_size : segment_size; + + assert(sorted_block_size % chunk_size == 0); + + using T = typename GetValueType::value_type; + + sycl::buffer temp_buf(sycl::range<1>{iter_nelems * sort_nelems}); + // T *allocated_mem = sycl::malloc_device(iter_nelems * sort_nelems, q); + + bool needs_copy = true; + bool used_depends = false; + + sycl::event dep_ev; + std::size_t chunks_merged = sorted_block_size / chunk_size; + + assert(!(chunks_merged & (chunks_merged - 1))); + + using ToTempKernelName = class merge_adjacent_blocks_to_temp_krn; + using FromTempKernelName = + class merge_adjacent_blocks_from_temp_krn; + + while (chunks_merged * chunk_size < sort_nelems) { + sycl::event local_dep = dep_ev; + + sycl::event merge_ev = q.submit([&](sycl::handler &cgh) { + if (used_depends) { + cgh.depends_on(local_dep); + } + else { + cgh.depends_on(depends); + used_depends = true; + } + + const std::size_t n_chunks = quotient_ceil(sort_nelems, chunk_size); + + if (needs_copy) { + sycl::accessor temp_acc{temp_buf, cgh, sycl::write_only, + sycl::no_init}; + auto output_acc = GetReadOnlyAccess{}(output, cgh); + cgh.parallel_for( + {iter_nelems * n_chunks}, [=](sycl::id<1> wid) { + auto flat_idx = wid[0]; + auto iter_idx = flat_idx / n_chunks; + auto idx = flat_idx - n_chunks * iter_idx; + + const std::size_t idx_mult = + (idx / chunks_merged) * chunks_merged; + const std::size_t idx_rem = (idx - idx_mult); + const std::size_t start_1 = + sycl::min(2 * idx_mult * chunk_size, sort_nelems); + const std::size_t end_1 = sycl::min( + start_1 + chunks_merged * chunk_size, sort_nelems); + const std::size_t end_2 = sycl::min( + end_1 + chunks_merged * chunk_size, sort_nelems); + const std::size_t offset = chunk_size * idx_rem; + + const std::size_t iter_offset = iter_idx * sort_nelems; + + merge_impl(offset, output_acc, temp_acc, + iter_offset + start_1, iter_offset + end_1, + iter_offset + end_2, iter_offset + start_1, + comp, chunk_size); + }); + } + else { + sycl::accessor 
temp_acc{temp_buf, cgh, sycl::read_only}; + auto output_acc = GetWriteDiscardAccess{}(output, cgh); + cgh.parallel_for( + {iter_nelems * n_chunks}, [=](sycl::id<1> wid) { + auto flat_idx = wid[0]; + auto iter_idx = flat_idx / n_chunks; + auto idx = flat_idx - n_chunks * iter_idx; + + const std::size_t idx_mult = + (idx / chunks_merged) * chunks_merged; + const std::size_t idx_rem = (idx - idx_mult); + const std::size_t start_1 = + sycl::min(2 * idx_mult * chunk_size, sort_nelems); + const std::size_t end_1 = sycl::min( + start_1 + chunks_merged * chunk_size, sort_nelems); + const std::size_t end_2 = sycl::min( + end_1 + chunks_merged * chunk_size, sort_nelems); + const std::size_t offset = chunk_size * idx_rem; + + const std::size_t iter_offset = iter_idx * sort_nelems; + + merge_impl(offset, temp_acc, output_acc, + iter_offset + start_1, iter_offset + end_1, + iter_offset + end_2, iter_offset + start_1, + comp, chunk_size); + }); + } + }); + + chunks_merged *= 2; + dep_ev = merge_ev; + + if (chunks_merged * chunk_size < sort_nelems) { + needs_copy = !needs_copy; + } + } + + if (needs_copy) { + sycl::event copy_ev = q.submit([&](sycl::handler &cgh) { + cgh.depends_on(dep_ev); + + sycl::accessor temp_acc{temp_buf, cgh, sycl::read_only}; + auto output_acc = GetWriteDiscardAccess{}(output, cgh); + + cgh.copy(temp_acc, output_acc); + }); + dep_ev = copy_ev; + } + + return dep_ev; +} + +} // namespace merge_sort_detail + +template > +sycl::event stable_sort_axis1_contig_impl( + sycl::queue &exec_q, + std::size_t iter_nelems, // number of sub-arrays to sort (num. of rows in a + // matrix when sorting over rows) + std::size_t sort_nelems, // size of each array to sort (length of rows, + // i.e. 
number of columns) + const char *arg_cp, + char *res_cp, + ssize_t iter_arg_offset, + ssize_t iter_res_offset, + ssize_t sort_arg_offset, + ssize_t sort_res_offset, + const std::vector &depends) +{ + const argTy *arg_tp = reinterpret_cast(arg_cp) + + iter_arg_offset + sort_arg_offset; + argTy *res_tp = + reinterpret_cast(res_cp) + iter_res_offset + sort_res_offset; + + auto comp = Comp{}; + + // constant chosen experimentally to ensure monotonicity of + // sorting performance, as measured on GPU Max, and Iris Xe + constexpr std::size_t sequential_sorting_threshold = 16; + + if (sort_nelems < sequential_sorting_threshold) { + // equal work-item sorts entire row + sycl::event sequential_sorting_ev = + merge_sort_detail::sort_base_step_contig_impl( + exec_q, iter_nelems, sort_nelems, arg_tp, res_tp, comp, + sort_nelems, depends); + + return sequential_sorting_ev; + } + else { + std::size_t sorted_block_size{}; + + // Sort segments of the array + sycl::event base_sort_ev = + merge_sort_detail::sort_over_work_group_contig_impl( + exec_q, iter_nelems, sort_nelems, arg_tp, res_tp, comp, + sorted_block_size, // modified in place with size of sorted + // block size + depends); + + // Merge segments in parallel until all elements are sorted + sycl::event merges_ev = + merge_sort_detail::merge_sorted_block_contig_impl( + exec_q, iter_nelems, sort_nelems, res_tp, comp, + sorted_block_size, {base_sort_ev}); + + return merges_ev; + } +} + +template +class populate_index_data_krn; + +template +class index_map_to_rows_krn; + +template +struct IndexComp +{ + IndexComp(const ValueT *data, const ValueComp &comp_op) + : ptr(data), value_comp(comp_op) + { + } + + bool operator()(const IndexT &i1, const IndexT &i2) const + { + return value_comp(ptr[i1], ptr[i2]); + } + +private: + const ValueT *ptr; + ValueComp value_comp; +}; + +template > +sycl::event stable_argsort_axis1_contig_impl( + sycl::queue &exec_q, + std::size_t iter_nelems, // number of sub-arrays to sort (num. 
of rows in a + // matrix when sorting over rows) + std::size_t sort_nelems, // size of each array to sort (length of rows, + // i.e. number of columns) + const char *arg_cp, + char *res_cp, + ssize_t iter_arg_offset, + ssize_t iter_res_offset, + ssize_t sort_arg_offset, + ssize_t sort_res_offset, + const std::vector &depends) +{ + const argTy *arg_tp = reinterpret_cast(arg_cp) + + iter_arg_offset + sort_arg_offset; + IndexTy *res_tp = + reinterpret_cast(res_cp) + iter_res_offset + sort_res_offset; + + const IndexComp index_comp{arg_tp, ValueComp{}}; + + static constexpr std::size_t determine_automatically = 0; + std::size_t sorted_block_size = determine_automatically; + + const std::size_t total_nelems = iter_nelems * sort_nelems; + + using dpctl::tensor::kernels::sort_utils_detail::iota_impl; + + using IotaKernelName = populate_index_data_krn; + + sycl::event populate_indexed_data_ev = iota_impl( + exec_q, res_tp, total_nelems, depends); + + // Sort segments of the array + sycl::event base_sort_ev = + merge_sort_detail::sort_over_work_group_contig_impl( + exec_q, iter_nelems, sort_nelems, res_tp, res_tp, index_comp, + sorted_block_size, // modified in place with size of sorted block + // size + {populate_indexed_data_ev}); + + // Merge segments in parallel until all elements are sorted + sycl::event merges_ev = merge_sort_detail::merge_sorted_block_contig_impl( + exec_q, iter_nelems, sort_nelems, res_tp, index_comp, sorted_block_size, + {base_sort_ev}); + + // no need to map back if iter_nelems == 1 + if (iter_nelems == 1u) { + return merges_ev; + } + + using MapBackKernelName = index_map_to_rows_krn; + using dpctl::tensor::kernels::sort_utils_detail::map_back_impl; + + sycl::event write_out_ev = map_back_impl( + exec_q, total_nelems, res_tp, res_tp, sort_nelems, {merges_ev}); + + return write_out_ev; +} + +} // namespace dpctl::tensor::kernels diff --git a/dpctl_ext/tensor/libtensor/include/kernels/sorting/radix_sort.hpp 
b/dpctl_ext/tensor/libtensor/include/kernels/sorting/radix_sort.hpp new file mode 100644 index 000000000000..545444101fc6 --- /dev/null +++ b/dpctl_ext/tensor/libtensor/include/kernels/sorting/radix_sort.hpp @@ -0,0 +1,1920 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. 
+//***************************************************************************** +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_sorting_impl +/// extension. +//===----------------------------------------------------------------------===// + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include "kernels/dpctl_tensor_types.hpp" +#include "kernels/sorting/sort_utils.hpp" +#include "utils/sycl_alloc_utils.hpp" + +namespace dpctl::tensor::kernels +{ + +namespace radix_sort_details +{ + +template +class radix_sort_count_kernel; + +template +class radix_sort_scan_kernel; + +template +class radix_sort_reorder_peer_kernel; + +template +class radix_sort_reorder_kernel; + +/*! @brief Computes smallest exponent such that `n <= (1 << exponent)` */ +template && + sizeof(SizeT) == sizeof(std::uint64_t), + int> = 0> +std::uint32_t ceil_log2(SizeT n) +{ + // if n > 2^b, n = q * 2^b + r for q > 0 and 0 <= r < 2^b + // floor_log2(q * 2^b + r) == floor_log2(q * 2^b) == q + floor_log2(n1) + // ceil_log2(n) == 1 + floor_log2(n-1) + if (n <= 1) + return std::uint32_t{1}; + + std::uint32_t exp{1}; + --n; + if (n >= (SizeT{1} << 32)) { + n >>= 32; + exp += 32; + } + if (n >= (SizeT{1} << 16)) { + n >>= 16; + exp += 16; + } + if (n >= (SizeT{1} << 8)) { + n >>= 8; + exp += 8; + } + if (n >= (SizeT{1} << 4)) { + n >>= 4; + exp += 4; + } + if (n >= (SizeT{1} << 2)) { + n >>= 2; + exp += 2; + } + if (n >= (SizeT{1} << 1)) { + n >>= 1; + ++exp; + } + return exp; +} + +//---------------------------------------------------------- +// bitwise order-preserving conversions to unsigned integers +//---------------------------------------------------------- + +template +bool order_preserving_cast(bool val) +{ + if constexpr (is_ascending) + return val; + else + return !val; +} + +template , int> = 0> +UIntT 
order_preserving_cast(UIntT val) +{ + if constexpr (is_ascending) { + return val; + } + else { + // bitwise invert + return (~val); + } +} + +template && std::is_signed_v, + int> = 0> +std::make_unsigned_t order_preserving_cast(IntT val) +{ + using UIntT = std::make_unsigned_t; + const UIntT uint_val = sycl::bit_cast(val); + + if constexpr (is_ascending) { + // ascending_mask: 100..0 + static constexpr UIntT ascending_mask = + (UIntT(1) << std::numeric_limits::digits); + return (uint_val ^ ascending_mask); + } + else { + // descending_mask: 011..1 + static constexpr UIntT descending_mask = + (std::numeric_limits::max() >> 1); + return (uint_val ^ descending_mask); + } +} + +template +std::uint16_t order_preserving_cast(sycl::half val) +{ + using UIntT = std::uint16_t; + + const UIntT uint_val = sycl::bit_cast( + (sycl::isnan(val)) ? std::numeric_limits::quiet_NaN() + : val); + UIntT mask; + + // test the sign bit of the original value + const bool zero_fp_sign_bit = (UIntT(0) == (uint_val >> 15)); + + static constexpr UIntT zero_mask = UIntT(0x8000u); + static constexpr UIntT nonzero_mask = UIntT(0xFFFFu); + + static constexpr UIntT inv_zero_mask = static_cast(~zero_mask); + static constexpr UIntT inv_nonzero_mask = static_cast(~nonzero_mask); + + if constexpr (is_ascending) { + mask = (zero_fp_sign_bit) ? zero_mask : nonzero_mask; + } + else { + mask = (zero_fp_sign_bit) ? (inv_zero_mask) : (inv_nonzero_mask); + } + + return (uint_val ^ mask); +} + +template && + sizeof(FloatT) == sizeof(std::uint32_t), + int> = 0> +std::uint32_t order_preserving_cast(FloatT val) +{ + using UIntT = std::uint32_t; + + UIntT uint_val = sycl::bit_cast( + (sycl::isnan(val)) ? 
std::numeric_limits::quiet_NaN() : val); + + UIntT mask; + + // test the sign bit of the original value + const bool zero_fp_sign_bit = (UIntT(0) == (uint_val >> 31)); + + static constexpr UIntT zero_mask = UIntT(0x80000000u); + static constexpr UIntT nonzero_mask = UIntT(0xFFFFFFFFu); + + if constexpr (is_ascending) + mask = (zero_fp_sign_bit) ? zero_mask : nonzero_mask; + else + mask = (zero_fp_sign_bit) ? (~zero_mask) : (~nonzero_mask); + + return (uint_val ^ mask); +} + +template && + sizeof(FloatT) == sizeof(std::uint64_t), + int> = 0> +std::uint64_t order_preserving_cast(FloatT val) +{ + using UIntT = std::uint64_t; + + UIntT uint_val = sycl::bit_cast( + (sycl::isnan(val)) ? std::numeric_limits::quiet_NaN() : val); + UIntT mask; + + // test the sign bit of the original value + const bool zero_fp_sign_bit = (UIntT(0) == (uint_val >> 63)); + + static constexpr UIntT zero_mask = UIntT(0x8000000000000000u); + static constexpr UIntT nonzero_mask = UIntT(0xFFFFFFFFFFFFFFFFu); + + if constexpr (is_ascending) + mask = (zero_fp_sign_bit) ? zero_mask : nonzero_mask; + else + mask = (zero_fp_sign_bit) ? 
(~zero_mask) : (~nonzero_mask); + + return (uint_val ^ mask); +} + +//----------------- +// bucket functions +//----------------- + +template +constexpr std::size_t number_of_bits_in_type() +{ + constexpr std::size_t type_bits = + (sizeof(T) * std::numeric_limits::digits); + return type_bits; +} + +// the number of buckets (size of radix bits) in T +template +constexpr std::uint32_t number_of_buckets_in_type(std::uint32_t radix_bits) +{ + constexpr std::size_t type_bits = number_of_bits_in_type(); + return (type_bits + radix_bits - 1) / radix_bits; +} + +// get bits value (bucket) in a certain radix position +template +std::uint32_t get_bucket_id(T val, std::uint32_t radix_offset) +{ + static_assert(std::is_unsigned_v); + + return (val >> radix_offset) & T(radix_mask); +} + +//-------------------------------- +// count kernel (single iteration) +//-------------------------------- + +template +sycl::event + radix_sort_count_submit(sycl::queue &exec_q, + std::size_t n_iters, + std::size_t n_segments, + std::size_t wg_size, + std::uint32_t radix_offset, + std::size_t n_values, + ValueT *vals_ptr, + std::size_t n_counts, + CountT *counts_ptr, + const Proj &proj_op, + const bool is_ascending, + const std::vector &dependency_events) +{ + // bin_count = radix_states used for an array storing bucket state counters + static constexpr std::uint32_t radix_states = + (std::uint32_t(1) << radix_bits); + static constexpr std::uint32_t radix_mask = radix_states - 1; + + // iteration space info + const std::size_t n = n_values; + // each segment is processed by a work-group + const std::size_t elems_per_segment = (n + n_segments - 1) / n_segments; + const std::size_t no_op_flag_id = n_counts - 1; + + assert(n_counts == (n_segments + 1) * radix_states + 1); + + sycl::event local_count_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(dependency_events); + + sycl::local_accessor counts_lacc(wg_size * radix_states, + cgh); + + sycl::nd_range<1> ndRange(n_iters * 
n_segments * wg_size, wg_size); + + cgh.parallel_for(ndRange, [=](sycl::nd_item<1> ndit) { + // 0 <= lid < wg_size + const std::size_t lid = ndit.get_local_id(0); + // 0 <= group_id < n_segments * n_iters + const std::size_t group_id = ndit.get_group(0); + const std::size_t iter_id = group_id / n_segments; + const std::size_t val_iter_offset = iter_id * n; + // 0 <= wgr_id < n_segments + const std::size_t wgr_id = group_id - iter_id * n_segments; + + const std::size_t seg_start = elems_per_segment * wgr_id; + + // count per work-item: create a private array for storing count + // values here bin_count = radix_states + std::array counts_arr = {CountT{0}}; + + // count per work-item: count values and write result to private + // count array + const std::size_t seg_end = + sycl::min(seg_start + elems_per_segment, n); + if (is_ascending) { + for (std::size_t val_id = seg_start + lid; val_id < seg_end; + val_id += wg_size) { + // get the bucket for the bit-ordered input value, + // applying the offset and mask for radix bits + const auto val = + order_preserving_cast( + proj_op(vals_ptr[val_iter_offset + val_id])); + const std::uint32_t bucket_id = + get_bucket_id(val, radix_offset); + + // increment counter for this bit bucket + ++counts_arr[bucket_id]; + } + } + else { + for (std::size_t val_id = seg_start + lid; val_id < seg_end; + val_id += wg_size) { + // get the bucket for the bit-ordered input value, + // applying the offset and mask for radix bits + const auto val = + order_preserving_cast( + proj_op(vals_ptr[val_iter_offset + val_id])); + const std::uint32_t bucket_id = + get_bucket_id(val, radix_offset); + + // increment counter for this bit bucket + ++counts_arr[bucket_id]; + } + } + + // count per work-item: write private count array to local count + // array counts_lacc is concatenation of private count arrays from + // each work-item in the order of their local ids + const std::uint32_t count_start_id = radix_states * lid; + for (std::uint32_t 
radix_state_id = 0; + radix_state_id < radix_states; ++radix_state_id) + { + counts_lacc[count_start_id + radix_state_id] = + counts_arr[radix_state_id]; + } + + sycl::group_barrier(ndit.get_group()); + + // count per work-group: reduce till count_lacc[] size > wg_size + // all work-items in the work-group do the work. + for (std::uint32_t i = 1; i < radix_states; ++i) { + // Since we interested in computing total count over work-group + // for each radix state, the correct result is only assured if + // wg_size >= radix_states + counts_lacc[lid] += counts_lacc[wg_size * i + lid]; + } + + sycl::group_barrier(ndit.get_group()); + + // count per work-group: reduce until count_lacc[] size > + // radix_states (n_witems /= 2 per iteration) + for (std::uint32_t n_witems = (wg_size >> 1); + n_witems >= radix_states; n_witems >>= 1) + { + if (lid < n_witems) + counts_lacc[lid] += counts_lacc[n_witems + lid]; + + sycl::group_barrier(ndit.get_group()); + } + + const std::size_t iter_counter_offset = iter_id * n_counts; + + // count per work-group: write local count array to global count + // array + if (lid < radix_states) { + // move buckets with the same id to adjacent positions, + // thus splitting count array into radix_states regions + counts_ptr[iter_counter_offset + (n_segments + 1) * lid + + wgr_id] = counts_lacc[lid]; + } + + // side work: reset 'no-operation-flag', signaling to skip re-order + // phase + if (wgr_id == 0 && lid == 0) { + CountT &no_op_flag = + counts_ptr[iter_counter_offset + no_op_flag_id]; + no_op_flag = 0; + } + }); + }); + + return local_count_ev; +} + +//----------------------------------------------------------------------- +// radix sort: scan kernel (single iteration) +//----------------------------------------------------------------------- + +template +sycl::event radix_sort_scan_submit(sycl::queue &exec_q, + std::size_t n_iters, + std::size_t n_segments, + std::size_t wg_size, + std::size_t n_values, + std::size_t n_counts, + CountT 
*counts_ptr, + const std::vector depends) +{ + const std::size_t no_op_flag_id = n_counts - 1; + + // Scan produces local offsets using count values. + // There are no local offsets for the first segment, but the rest segments + // should be scanned with respect to the count value in the first segment + // what requires n + 1 positions + const std::size_t scan_size = n_segments + 1; + wg_size = std::min(scan_size, wg_size); + + static constexpr std::uint32_t radix_states = std::uint32_t(1) + << radix_bits; + + // compilation of the kernel prevents out of resources issue, which may + // occur due to usage of collective algorithms such as joint_exclusive_scan + // even if local memory is not explicitly requested + sycl::event scan_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + sycl::nd_range<1> ndRange(n_iters * radix_states * wg_size, wg_size); + + cgh.parallel_for(ndRange, [=](sycl::nd_item<1> ndit) { + const std::size_t group_id = ndit.get_group(0); + const std::size_t iter_id = group_id / radix_states; + const std::size_t wgr_id = group_id - iter_id * radix_states; + // find borders of a region with a specific bucket id + auto begin_ptr = + counts_ptr + scan_size * wgr_id + iter_id * n_counts; + + sycl::joint_exclusive_scan(ndit.get_group(), begin_ptr, + begin_ptr + scan_size, begin_ptr, + CountT(0), sycl::plus{}); + + const auto lid = ndit.get_local_linear_id(); + + // NB: No race condition here, because the condition may ever be + // true for only on one WG, one WI. 
+ if ((lid == wg_size - 1) && (begin_ptr[scan_size - 1] == n_values)) + { + // set flag, since all the values got into one + // this is optimization, may happy often for + // higher radix offsets (all zeros) + auto &no_op_flag = + counts_ptr[iter_id * n_counts + no_op_flag_id]; + no_op_flag = 1; + } + }); + }); + + return scan_ev; +} + +//----------------------------------------------------------------------- +// radix sort: group level reorder algorithms +//----------------------------------------------------------------------- + +struct empty_storage +{ + template + empty_storage(T &&...) + { + } +}; + +// Number with `n` least significant bits of uint32_t +inline std::uint32_t n_ls_bits_set(std::uint32_t n) noexcept +{ + static constexpr std::uint32_t zero{}; + static constexpr std::uint32_t all_bits_set = ~zero; + + return ~(all_bits_set << n); +} + +enum class peer_prefix_algo +{ + subgroup_ballot, + atomic_fetch_or, + scan_then_broadcast +}; + +template +struct peer_prefix_helper; + +template +auto get_accessor_pointer(const AccT &acc) +{ + return acc.template get_multi_ptr().get(); +} + +template +struct peer_prefix_helper +{ + using AtomicT = sycl::atomic_ref; + using TempStorageT = sycl::local_accessor; + +private: + sycl::sub_group sgroup; + std::uint32_t lid; + std::uint32_t item_mask; + AtomicT atomic_peer_mask; + +public: + peer_prefix_helper(sycl::nd_item<1> ndit, TempStorageT lacc) + : sgroup(ndit.get_sub_group()), lid(ndit.get_local_linear_id()), + item_mask(n_ls_bits_set(lid)), atomic_peer_mask(lacc[0]) + { + } + + std::uint32_t peer_contribution(OffsetT &new_offset_id, + OffsetT offset_prefix, + bool wi_bit_set) const + { + // reset mask for each radix state + if (lid == 0) + atomic_peer_mask.store(std::uint32_t{0}); + sycl::group_barrier(sgroup); + + const std::uint32_t uint_contrib{wi_bit_set ? 
std::uint32_t{1} + : std::uint32_t{0}}; + + // set local id's bit to 1 if the bucket value matches the radix state + atomic_peer_mask.fetch_or(uint_contrib << lid); + sycl::group_barrier(sgroup); + std::uint32_t peer_mask_bits = atomic_peer_mask.load(); + std::uint32_t sg_total_offset = sycl::popcount(peer_mask_bits); + + // get the local offset index from the bits set in the peer mask with + // index less than the work item ID + peer_mask_bits &= item_mask; + new_offset_id |= wi_bit_set + ? (offset_prefix + sycl::popcount(peer_mask_bits)) + : OffsetT{0}; + return sg_total_offset; + } +}; + +template +struct peer_prefix_helper +{ + using TempStorageT = empty_storage; + using ItemType = sycl::nd_item<1>; + using SubGroupType = sycl::sub_group; + +private: + SubGroupType sgroup; + std::uint32_t sg_size; + +public: + peer_prefix_helper(sycl::nd_item<1> ndit, TempStorageT) + : sgroup(ndit.get_sub_group()), sg_size(sgroup.get_local_range()[0]) + { + } + + std::uint32_t peer_contribution(OffsetT &new_offset_id, + OffsetT offset_prefix, + bool wi_bit_set) const + { + const std::uint32_t contrib{wi_bit_set ? std::uint32_t{1} + : std::uint32_t{0}}; + + std::uint32_t sg_item_offset = sycl::exclusive_scan_over_group( + sgroup, contrib, sycl::plus{}); + + new_offset_id |= + (wi_bit_set ? 
(offset_prefix + sg_item_offset) : OffsetT(0)); + + // the last scanned value does not contain number of all copies, thus + // adding contribution + std::uint32_t sg_total_offset = sycl::group_broadcast( + sgroup, sg_item_offset + contrib, sg_size - 1); + + return sg_total_offset; + } +}; + +template +struct peer_prefix_helper +{ +private: + sycl::sub_group sgroup; + std::uint32_t lid; + sycl::ext::oneapi::sub_group_mask item_sg_mask; + + sycl::ext::oneapi::sub_group_mask mask_builder(std::uint32_t mask, + std::uint32_t sg_size) + { + return sycl::detail::Builder::createSubGroupMask< + sycl::ext::oneapi::sub_group_mask>(mask, sg_size); + } + +public: + using TempStorageT = empty_storage; + + peer_prefix_helper(sycl::nd_item<1> ndit, TempStorageT) + : sgroup(ndit.get_sub_group()), lid(ndit.get_local_linear_id()), + item_sg_mask( + mask_builder(n_ls_bits_set(lid), sgroup.get_local_linear_range())) + { + } + + std::uint32_t peer_contribution(OffsetT &new_offset_id, + OffsetT offset_prefix, + bool wi_bit_set) const + { + // set local id's bit to 1 if the bucket value matches the radix state + auto peer_mask = sycl::ext::oneapi::group_ballot(sgroup, wi_bit_set); + std::uint32_t peer_mask_bits{}; + + peer_mask.extract_bits(peer_mask_bits); + std::uint32_t sg_total_offset = sycl::popcount(peer_mask_bits); + + // get the local offset index from the bits set in the peer mask with + // index less than the work item ID + peer_mask &= item_sg_mask; + peer_mask.extract_bits(peer_mask_bits); + + new_offset_id |= wi_bit_set + ? 
(offset_prefix + sycl::popcount(peer_mask_bits)) + : OffsetT(0); + + return sg_total_offset; + } +}; + +template +void copy_func_for_radix_sort(const std::size_t n_segments, + const std::size_t elems_per_segment, + const std::size_t sg_size, + const std::uint32_t lid, + const std::size_t wgr_id, + const InputT *input_ptr, + const std::size_t n_values, + OutputT *output_ptr) +{ + // item info + const std::size_t seg_start = elems_per_segment * wgr_id; + + std::size_t seg_end = sycl::min(seg_start + elems_per_segment, n_values); + + // ensure that each work item in a subgroup does the same number of loop + // iterations + const std::uint16_t tail_size = (seg_end - seg_start) % sg_size; + seg_end -= tail_size; + + // find offsets for the same values within a segment and fill the resulting + // buffer + for (std::size_t val_id = seg_start + lid; val_id < seg_end; + val_id += sg_size) { + output_ptr[val_id] = std::move(input_ptr[val_id]); + } + + if (tail_size > 0 && lid < tail_size) { + const std::size_t val_id = seg_end + lid; + output_ptr[val_id] = std::move(input_ptr[val_id]); + } +} + +//----------------------------------------------------------------------- +// radix sort: reorder kernel (per iteration) +//----------------------------------------------------------------------- +template +sycl::event + radix_sort_reorder_submit(sycl::queue &exec_q, + std::size_t n_iters, + std::size_t n_segments, + std::uint32_t radix_offset, + std::size_t n_values, + const InputT *input_ptr, + OutputT *output_ptr, + std::size_t n_offsets, + OffsetT *offset_ptr, + const ProjT &proj_op, + const bool is_ascending, + const std::vector dependency_events) +{ + using ValueT = InputT; + using PeerHelper = peer_prefix_helper; + + static constexpr std::uint32_t radix_states = std::uint32_t{1} + << radix_bits; + static constexpr std::uint32_t radix_mask = radix_states - 1; + const std::size_t elems_per_segment = + (n_values + n_segments - 1) / n_segments; + + const std::size_t no_op_flag_id 
= n_offsets - 1; + + const auto &kernel_id = sycl::get_kernel_id(); + + auto const &ctx = exec_q.get_context(); + auto const &dev = exec_q.get_device(); + auto kb = sycl::get_kernel_bundle( + ctx, {dev}, {kernel_id}); + + auto krn = kb.get_kernel(kernel_id); + + const std::uint32_t sg_size = krn.template get_info< + sycl::info::kernel_device_specific::max_sub_group_size>(dev); + + sycl::event reorder_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(dependency_events); + cgh.use_kernel_bundle(kb); + + using StorageT = typename PeerHelper::TempStorageT; + + StorageT peer_temp(1, cgh); + + sycl::range<1> lRange{sg_size}; + sycl::range<1> gRange{n_iters * n_segments * sg_size}; + + sycl::nd_range<1> ndRange{gRange, lRange}; + + cgh.parallel_for(ndRange, [=](sycl::nd_item<1> ndit) { + const std::size_t group_id = ndit.get_group(0); + const std::size_t iter_id = group_id / n_segments; + const std::size_t segment_id = group_id - iter_id * n_segments; + + auto b_offset_ptr = offset_ptr + iter_id * n_offsets; + auto b_input_ptr = input_ptr + iter_id * n_values; + auto b_output_ptr = output_ptr + iter_id * n_values; + + const std::uint32_t lid = ndit.get_local_id(0); + + auto &no_op_flag = b_offset_ptr[no_op_flag_id]; + if (no_op_flag) { + // no reordering necessary, simply copy + copy_func_for_radix_sort( + n_segments, elems_per_segment, sg_size, lid, segment_id, + b_input_ptr, n_values, b_output_ptr); + return; + } + + // create a private array for storing offset values + // and add total offset and offset for compute unit + // for a certain radix state + std::array offset_arr{}; + const std::size_t scan_size = n_segments + 1; + + OffsetT scanned_bin = 0; + + /* find cumulative offset */ + static constexpr std::uint32_t zero_radix_state_id = 0; + offset_arr[zero_radix_state_id] = b_offset_ptr[segment_id]; + + for (std::uint32_t radix_state_id = 1; + radix_state_id < radix_states; ++radix_state_id) + { + const std::uint32_t local_offset_id = + segment_id + 
scan_size * radix_state_id; + + // scan bins serially + const std::size_t last_segment_bucket_id = + radix_state_id * scan_size - 1; + scanned_bin += b_offset_ptr[last_segment_bucket_id]; + + offset_arr[radix_state_id] = + scanned_bin + b_offset_ptr[local_offset_id]; + } + + const std::size_t seg_start = elems_per_segment * segment_id; + std::size_t seg_end = + sycl::min(seg_start + elems_per_segment, n_values); + // ensure that each work item in a subgroup does the same number of + // loop iterations + const std::uint32_t tail_size = (seg_end - seg_start) % sg_size; + seg_end -= tail_size; + + const PeerHelper peer_prefix_hlp(ndit, peer_temp); + + // find offsets for the same values within a segment and fill the + // resulting buffer + if (is_ascending) { + for (std::size_t val_id = seg_start + lid; val_id < seg_end; + val_id += sg_size) { + ValueT in_val = std::move(b_input_ptr[val_id]); + + // get the bucket for the bit-ordered input value, applying + // the offset and mask for radix bits + const auto mapped_val = + order_preserving_cast( + proj_op(in_val)); + std::uint32_t bucket_id = + get_bucket_id(mapped_val, radix_offset); + + OffsetT new_offset_id = 0; + for (std::uint32_t radix_state_id = 0; + radix_state_id < radix_states; ++radix_state_id) + { + bool is_current_bucket = (bucket_id == radix_state_id); + std::uint32_t sg_total_offset = + peer_prefix_hlp.peer_contribution( + /* modified by reference */ new_offset_id, + offset_arr[radix_state_id], + /* bit contribution from this work-item */ + is_current_bucket); + offset_arr[radix_state_id] += sg_total_offset; + } + b_output_ptr[new_offset_id] = std::move(in_val); + } + } + else { + for (std::size_t val_id = seg_start + lid; val_id < seg_end; + val_id += sg_size) { + ValueT in_val = std::move(b_input_ptr[val_id]); + + // get the bucket for the bit-ordered input value, applying + // the offset and mask for radix bits + const auto mapped_val = + order_preserving_cast( + proj_op(in_val)); + std::uint32_t 
bucket_id = + get_bucket_id(mapped_val, radix_offset); + + OffsetT new_offset_id = 0; + for (std::uint32_t radix_state_id = 0; + radix_state_id < radix_states; ++radix_state_id) + { + bool is_current_bucket = (bucket_id == radix_state_id); + std::uint32_t sg_total_offset = + peer_prefix_hlp.peer_contribution( + /* modified by reference */ new_offset_id, + offset_arr[radix_state_id], + /* bit contribution from this work-item */ + is_current_bucket); + offset_arr[radix_state_id] += sg_total_offset; + } + b_output_ptr[new_offset_id] = std::move(in_val); + } + } + if (tail_size > 0) { + ValueT in_val; + + // default: is greater than any actual radix state + std::uint32_t bucket_id = radix_states; + if (lid < tail_size) { + in_val = std::move(b_input_ptr[seg_end + lid]); + + const auto proj_val = proj_op(in_val); + const auto mapped_val = + (is_ascending) + ? order_preserving_cast( + proj_val) + : order_preserving_cast( + proj_val); + bucket_id = + get_bucket_id(mapped_val, radix_offset); + } + + OffsetT new_offset_id = 0; + for (std::uint32_t radix_state_id = 0; + radix_state_id < radix_states; ++radix_state_id) + { + bool is_current_bucket = (bucket_id == radix_state_id); + std::uint32_t sg_total_offset = + peer_prefix_hlp.peer_contribution( + new_offset_id, offset_arr[radix_state_id], + is_current_bucket); + + offset_arr[radix_state_id] += sg_total_offset; + } + + if (lid < tail_size) { + b_output_ptr[new_offset_id] = std::move(in_val); + } + } + }); + }); + + return reorder_ev; +} + +template +sizeT _slm_adjusted_work_group_size(sycl::queue &exec_q, + sizeT required_slm_bytes_per_wg, + sizeT wg_size) +{ + const auto &dev = exec_q.get_device(); + + if (wg_size == 0) + wg_size = + dev.template get_info(); + + const auto local_mem_sz = + dev.template get_info(); + + return sycl::min(local_mem_sz / required_slm_bytes_per_wg, wg_size); +} + +//----------------------------------------------------------------------- +// radix sort: one iteration 
+//----------------------------------------------------------------------- + +template +struct parallel_radix_sort_iteration_step +{ + template + using count_phase = radix_sort_count_kernel; + template + using local_scan_phase = radix_sort_scan_kernel; + template + using reorder_peer_phase = + radix_sort_reorder_peer_kernel; + template + using reorder_phase = radix_sort_reorder_kernel; + + template + static sycl::event submit(sycl::queue &exec_q, + std::size_t n_iters, + std::size_t n_segments, + std::uint32_t radix_iter, + std::size_t n_values, + const InputT *in_ptr, + OutputT *out_ptr, + std::size_t n_counts, + CountT *counts_ptr, + const ProjT &proj_op, + const bool is_ascending, + const std::vector &dependency_events) + { + using _RadixCountKernel = count_phase; + using _RadixLocalScanKernel = + local_scan_phase; + using _RadixReorderPeerKernel = + reorder_peer_phase; + using _RadixReorderKernel = + reorder_phase; + + const auto &supported_sub_group_sizes = + exec_q.get_device() + .template get_info(); + const std::size_t max_sg_size = + (supported_sub_group_sizes.empty() + ? 0 + : supported_sub_group_sizes.back()); + const std::size_t reorder_sg_size = max_sg_size; + const std::size_t scan_wg_size = + exec_q.get_device() + .template get_info(); + + static constexpr std::size_t two_mils = (std::size_t(1) << 21); + std::size_t count_wg_size = + ((max_sg_size > 0) && (n_values > two_mils) ? 
128 : max_sg_size); + + static constexpr std::uint32_t radix_states = std::uint32_t(1) + << radix_bits; + + // correct count_wg_size according to local memory limit in count phase + const auto max_count_wg_size = _slm_adjusted_work_group_size( + exec_q, sizeof(CountT) * radix_states, count_wg_size); + count_wg_size = + static_cast<::std::size_t>((max_count_wg_size / radix_states)) * + radix_states; + + // work-group size must be a power of 2 and not less than the number of + // states, for scanning to work correctly + + const std::size_t rounded_down_count_wg_size = + std::size_t{1} << (number_of_bits_in_type() - + sycl::clz(count_wg_size) - 1); + count_wg_size = + sycl::max(rounded_down_count_wg_size, std::size_t(radix_states)); + + // Compute the radix position for the given iteration + std::uint32_t radix_offset = radix_iter * radix_bits; + + // 1. Count Phase + sycl::event count_ev = + radix_sort_count_submit<_RadixCountKernel, radix_bits>( + exec_q, n_iters, n_segments, count_wg_size, radix_offset, + n_values, in_ptr, n_counts, counts_ptr, proj_op, is_ascending, + dependency_events); + + // 2. Scan Phase + sycl::event scan_ev = + radix_sort_scan_submit<_RadixLocalScanKernel, radix_bits>( + exec_q, n_iters, n_segments, scan_wg_size, n_values, n_counts, + counts_ptr, {count_ev}); + + // 3. 
Reorder Phase + sycl::event reorder_ev{}; + // subgroup_ballot-based peer algo uses extract_bits to populate + // uint32_t mask and hence relies on sub-group to be 32 or narrower + static constexpr std::size_t sg32_v = 32u; + static constexpr std::size_t sg16_v = 16u; + static constexpr std::size_t sg08_v = 8u; + if (sg32_v == reorder_sg_size || sg16_v == reorder_sg_size || + sg08_v == reorder_sg_size) + { + static constexpr auto peer_algorithm = + peer_prefix_algo::subgroup_ballot; + + reorder_ev = radix_sort_reorder_submit<_RadixReorderPeerKernel, + radix_bits, peer_algorithm>( + exec_q, n_iters, n_segments, radix_offset, n_values, in_ptr, + out_ptr, n_counts, counts_ptr, proj_op, is_ascending, + {scan_ev}); + } + else { + static constexpr auto peer_algorithm = + peer_prefix_algo::scan_then_broadcast; + + reorder_ev = radix_sort_reorder_submit<_RadixReorderKernel, + radix_bits, peer_algorithm>( + exec_q, n_iters, n_segments, radix_offset, n_values, in_ptr, + out_ptr, n_counts, counts_ptr, proj_op, is_ascending, + {scan_ev}); + } + + return reorder_ev; + } +}; // struct parallel_radix_sort_iteration + +template +class radix_sort_one_wg_krn; + +template +struct subgroup_radix_sort +{ +private: + class use_slm_tag + { + }; + class use_global_mem_tag + { + }; + +public: + template + sycl::event operator()(sycl::queue &exec_q, + std::size_t n_iters, + std::size_t n_to_sort, + ValueT *input_ptr, + OutputT *output_ptr, + ProjT proj_op, + const bool is_ascending, + const std::vector &depends) + { + static_assert(std::is_same_v, OutputT>); + + using _SortKernelLoc = + radix_sort_one_wg_krn; + using _SortKernelPartGlob = + radix_sort_one_wg_krn; + using _SortKernelGlob = + radix_sort_one_wg_krn; + + static constexpr std::size_t max_concurrent_work_groups = 128U; + + // Choose this to occupy the entire accelerator + const std::size_t n_work_groups = + std::min(n_iters, max_concurrent_work_groups); + + // determine which temporary allocation can be accommodated in SLM + 
const auto &SLM_availability = + check_slm_size(exec_q, n_to_sort); + + const std::size_t n_batch_size = n_work_groups; + + switch (SLM_availability) { + case temp_allocations::both_in_slm: + { + static constexpr auto storage_for_values = use_slm_tag{}; + static constexpr auto storage_for_counters = use_slm_tag{}; + + return one_group_submitter<_SortKernelLoc>()( + exec_q, n_iters, n_iters, n_to_sort, input_ptr, output_ptr, + proj_op, is_ascending, storage_for_values, storage_for_counters, + depends); + } + case temp_allocations::counters_in_slm: + { + static constexpr auto storage_for_values = use_global_mem_tag{}; + static constexpr auto storage_for_counters = use_slm_tag{}; + + return one_group_submitter<_SortKernelPartGlob>()( + exec_q, n_iters, n_batch_size, n_to_sort, input_ptr, output_ptr, + proj_op, is_ascending, storage_for_values, storage_for_counters, + depends); + } + default: + { + static constexpr auto storage_for_values = use_global_mem_tag{}; + static constexpr auto storage_for_counters = use_global_mem_tag{}; + + return one_group_submitter<_SortKernelGlob>()( + exec_q, n_iters, n_batch_size, n_to_sort, input_ptr, output_ptr, + proj_op, is_ascending, storage_for_values, storage_for_counters, + depends); + } + } + } + +private: + template + class TempBuf; + + template + class TempBuf + { + std::size_t buf_size; + + public: + TempBuf(std::size_t, std::size_t n) : buf_size(n) {} + auto get_acc(sycl::handler &cgh) + { + return sycl::local_accessor(buf_size, cgh); + } + + std::size_t get_iter_stride() const + { + return std::size_t{0}; + } + }; + + template + class TempBuf + { + sycl::buffer buf; + std::size_t iter_stride; + + public: + TempBuf(std::size_t n_iters, std::size_t n) + : buf(n_iters * n), iter_stride(n) + { + } + auto get_acc(sycl::handler &cgh) + { + return sycl::accessor(buf, cgh, sycl::read_write, sycl::no_init); + } + std::size_t get_iter_stride() const + { + return iter_stride; + } + }; + + static_assert(wg_size <= 1024); + static 
constexpr std::uint16_t bin_count = (1 << radix); + static constexpr std::uint16_t counter_buf_sz = wg_size * bin_count + 1; + + enum class temp_allocations + { + both_in_slm, + counters_in_slm, + both_in_global_mem + }; + + template + temp_allocations check_slm_size(const sycl::queue &exec_q, SizeT n) + { + // the kernel is designed for data size <= 64K + assert(n <= (SizeT(1) << 16)); + + static constexpr auto req_slm_size_counters = + counter_buf_sz * sizeof(std::uint16_t); + + const auto &dev = exec_q.get_device(); + + // Pessimistically only use half of the memory to take into account + // a SYCL group algorithm might use a portion of SLM + const std::size_t max_slm_size = + dev.template get_info() / 2; + + const auto n_uniform = 1 << ceil_log2(n); + const auto req_slm_size_val = sizeof(T) * n_uniform; + + return ((req_slm_size_val + req_slm_size_counters) <= max_slm_size) + ? + // the values and the counters are placed in SLM + temp_allocations::both_in_slm + : (req_slm_size_counters <= max_slm_size) + ? 
+ // the counters are placed in SLM, the values - in the + // global memory + temp_allocations::counters_in_slm + : + // the values and the counters are placed in the global + // memory + temp_allocations::both_in_global_mem; + } + + template + struct one_group_submitter + { + template + sycl::event operator()(sycl::queue &exec_q, + std::size_t n_iters, + std::size_t n_batch_size, + std::size_t n_values, + InputT *input_arr, + OutputT *output_arr, + const ProjT &proj_op, + const bool is_ascending, + SLM_value_tag, + SLM_counter_tag, + const std::vector &depends) + { + assert(!(n_values >> 16)); + + assert(n_values <= static_cast(block_size) * + static_cast(wg_size)); + + const std::uint16_t n = static_cast(n_values); + static_assert(std::is_same_v, OutputT>); + + using ValueT = OutputT; + + using KeyT = std::invoke_result_t; + + TempBuf buf_val( + n_batch_size, static_cast(block_size * wg_size)); + TempBuf buf_count( + n_batch_size, static_cast(counter_buf_sz)); + + sycl::range<1> lRange{wg_size}; + + sycl::event sort_ev; + std::vector deps{depends}; + + const std::size_t n_batches = + (n_iters + n_batch_size - 1) / n_batch_size; + + const auto &kernel_id = sycl::get_kernel_id(); + + auto const &ctx = exec_q.get_context(); + auto const &dev = exec_q.get_device(); + auto kb = sycl::get_kernel_bundle( + ctx, {dev}, {kernel_id}); + + const auto &krn = kb.get_kernel(kernel_id); + + const std::uint32_t krn_sg_size = krn.template get_info< + sycl::info::kernel_device_specific::max_sub_group_size>(dev); + + // due to a bug in CPU device implementation, an additional + // synchronization is necessary for short sub-group sizes + const bool work_around_needed = + exec_q.get_device().has(sycl::aspect::cpu) && + (krn_sg_size < 16); + + for (std::size_t batch_id = 0; batch_id < n_batches; ++batch_id) { + + const std::size_t block_start = batch_id * n_batch_size; + + // input_arr/output_arr each has shape (n_iters, n) + InputT *this_input_arr = input_arr + block_start * 
n_values; + OutputT *this_output_arr = output_arr + block_start * n_values; + + const std::size_t block_end = + std::min(block_start + n_batch_size, n_iters); + + sycl::range<1> gRange{(block_end - block_start) * wg_size}; + sycl::nd_range ndRange{gRange, lRange}; + + sort_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(deps); + cgh.use_kernel_bundle(kb); + + // allocation to use for value exchanges + auto exchange_acc = buf_val.get_acc(cgh); + const std::size_t exchange_acc_iter_stride = + buf_val.get_iter_stride(); + + // allocation for counters + auto counter_acc = buf_count.get_acc(cgh); + const std::size_t counter_acc_iter_stride = + buf_count.get_iter_stride(); + + cgh.parallel_for(ndRange, [=](sycl::nd_item<1> + ndit) { + ValueT values[block_size]; + + const std::size_t iter_id = ndit.get_group(0); + const std::size_t iter_val_offset = + iter_id * static_cast(n); + const std::size_t iter_counter_offset = + iter_id * counter_acc_iter_stride; + const std::size_t iter_exchange_offset = + iter_id * exchange_acc_iter_stride; + + std::uint16_t wi = ndit.get_local_linear_id(); + std::uint16_t begin_bit = 0; + + static constexpr std::uint16_t end_bit = + number_of_bits_in_type(); + + // copy from input array into values +#pragma unroll + for (std::uint16_t i = 0; i < block_size; ++i) { + const std::uint16_t id = wi * block_size + i; + values[i] = + (id < n) ? 
this_input_arr[iter_val_offset + id] + : ValueT{}; + } + + while (true) { + // indices for indirect access in the "re-order" + // phase + std::uint16_t indices[block_size]; + { + // pointers to bucket's counters + std::uint16_t *counters[block_size]; + + // counting phase + auto pcounter = + get_accessor_pointer(counter_acc) + + (wi + iter_counter_offset); + + // initialize counters +#pragma unroll + for (std::uint16_t i = 0; i < bin_count; ++i) + pcounter[i * wg_size] = std::uint16_t{0}; + + sycl::group_barrier(ndit.get_group()); + + if (is_ascending) { +#pragma unroll + for (std::uint16_t i = 0; i < block_size; + ++i) { + const std::uint16_t id = + wi * block_size + i; + static constexpr std::uint16_t + bin_mask = bin_count - 1; + + // bin used for padded elements (id >= n); + // bin_mask is the last valid bin + static constexpr std::uint16_t + default_out_of_range_bin_id = + bin_mask; + + const std::uint16_t bin = + (id < n) + ? get_bucket_id( + order_preserving_cast< + /* is_ascending */ + true>( + proj_op(values[i])), + begin_bit) + : default_out_of_range_bin_id; + + // counting and local offset calculation + counters[i] = &pcounter[bin * wg_size]; + indices[i] = *counters[i]; + *counters[i] = indices[i] + 1; + + if (work_around_needed) { + sycl::group_barrier( + ndit.get_group()); + } + } + } + else { +#pragma unroll + for (std::uint16_t i = 0; i < block_size; + ++i) { + const std::uint16_t id = + wi * block_size + i; + static constexpr std::uint16_t + bin_mask = bin_count - 1; + + // bin used for padded elements (id >= n); + // bin_mask is the last valid bin + static constexpr std::uint16_t + default_out_of_range_bin_id = + bin_mask; + + const std::uint16_t bin = + (id < n) + ? 
get_bucket_id( + order_preserving_cast< + /* is_ascending */ + false>( + proj_op(values[i])), + begin_bit) + : default_out_of_range_bin_id; + + // counting and local offset calculation + counters[i] = &pcounter[bin * wg_size]; + indices[i] = *counters[i]; + *counters[i] = indices[i] + 1; + + if (work_around_needed) { + sycl::group_barrier( + ndit.get_group()); + } + } + } + + sycl::group_barrier(ndit.get_group()); + + // exclusive scan phase + { + + // scan contiguous numbers + std::uint16_t bin_sum[bin_count]; + const std::size_t counter_offset0 = + iter_counter_offset + wi * bin_count; + bin_sum[0] = counter_acc[counter_offset0]; + +#pragma unroll + for (std::uint16_t i = 1; i < bin_count; + ++i) + bin_sum[i] = + bin_sum[i - 1] + + counter_acc[counter_offset0 + i]; + + sycl::group_barrier(ndit.get_group()); + + // exclusive scan local sum + std::uint16_t sum_scan = + sycl::exclusive_scan_over_group( + ndit.get_group(), + bin_sum[bin_count - 1], + sycl::plus()); + +// add to local sum, generate exclusive scan result +#pragma unroll + for (std::uint16_t i = 0; i < bin_count; + ++i) + counter_acc[counter_offset0 + i + 1] = + sum_scan + bin_sum[i]; + + if (wi == 0) + counter_acc[iter_counter_offset + 0] = + std::uint32_t{0}; + + sycl::group_barrier(ndit.get_group()); + } + +#pragma unroll + for (std::uint16_t i = 0; i < block_size; ++i) { + // a global index is a local offset plus a + // global base index + indices[i] += *counters[i]; + } + + sycl::group_barrier(ndit.get_group()); + } + + begin_bit += radix; + + // "re-order" phase + sycl::group_barrier(ndit.get_group()); + if (begin_bit >= end_bit) { + // the last iteration - writing out the result +#pragma unroll + for (std::uint16_t i = 0; i < block_size; ++i) { + const std::uint16_t r = indices[i]; + if (r < n) { + this_output_arr[iter_val_offset + r] = + values[i]; + } + } + + return; + } + + // data exchange +#pragma unroll + for (std::uint16_t i = 0; i < block_size; ++i) { + const std::uint16_t r = indices[i]; 
+ if (r < n) + exchange_acc[iter_exchange_offset + r] = + values[i]; + } + + sycl::group_barrier(ndit.get_group()); + +#pragma unroll + for (std::uint16_t i = 0; i < block_size; ++i) { + const std::uint16_t id = wi * block_size + i; + if (id < n) + values[i] = + exchange_acc[iter_exchange_offset + id]; + } + + sycl::group_barrier(ndit.get_group()); + } + }); + }); + + deps = {sort_ev}; + } + + return sort_ev; + } + }; +}; + +template +struct OneWorkGroupRadixSortKernel; + +//----------------------------------------------------------------------- +// radix sort: main function +//----------------------------------------------------------------------- +template +sycl::event parallel_radix_sort_impl(sycl::queue &exec_q, + std::size_t n_iters, + std::size_t n_to_sort, + const ValueT *input_arr, + ValueT *output_arr, + const ProjT &proj_op, + const bool is_ascending, + const std::vector &depends) +{ + assert(n_to_sort > 1); + + using KeyT = std::remove_cv_t< + std::remove_reference_t>>; + + // radix bits represent number of processed bits in each value during one + // iteration + static constexpr std::uint32_t radix_bits = 4; + + sycl::event sort_ev{}; + + const auto &dev = exec_q.get_device(); + const auto max_wg_size = + dev.template get_info(); + + static constexpr std::uint16_t ref_wg_size = 64; + if (n_to_sort <= 16384 && ref_wg_size * 8 <= max_wg_size) { + using _RadixSortKernel = OneWorkGroupRadixSortKernel; + + if (n_to_sort <= 64 && ref_wg_size <= max_wg_size) { + // wg_size * block_size == 64 * 1 * 1 == 64 + static constexpr std::uint16_t wg_size = ref_wg_size; + static constexpr std::uint16_t block_size = 1; + + sort_ev = subgroup_radix_sort<_RadixSortKernel, wg_size, block_size, + radix_bits>{}( + exec_q, n_iters, n_to_sort, input_arr, output_arr, proj_op, + is_ascending, depends); + } + else if (n_to_sort <= 128 && ref_wg_size * 2 <= max_wg_size) { + // wg_size * block_size == 64 * 2 * 1 == 128 + static constexpr std::uint16_t wg_size = ref_wg_size * 2; + 
static constexpr std::uint16_t block_size = 1; + + sort_ev = subgroup_radix_sort<_RadixSortKernel, wg_size, block_size, + radix_bits>{}( + exec_q, n_iters, n_to_sort, input_arr, output_arr, proj_op, + is_ascending, depends); + } + else if (n_to_sort <= 256 && ref_wg_size * 2 <= max_wg_size) { + // wg_size * block_size == 64 * 2 * 2 == 256 + static constexpr std::uint16_t wg_size = ref_wg_size * 2; + static constexpr std::uint16_t block_size = 2; + + sort_ev = subgroup_radix_sort<_RadixSortKernel, wg_size, block_size, + radix_bits>{}( + exec_q, n_iters, n_to_sort, input_arr, output_arr, proj_op, + is_ascending, depends); + } + else if (n_to_sort <= 512 && ref_wg_size * 2 <= max_wg_size) { + // wg_size * block_size == 64 * 2 * 4 == 512 + static constexpr std::uint16_t wg_size = ref_wg_size * 2; + static constexpr std::uint16_t block_size = 4; + + sort_ev = subgroup_radix_sort<_RadixSortKernel, wg_size, block_size, + radix_bits>{}( + exec_q, n_iters, n_to_sort, input_arr, output_arr, proj_op, + is_ascending, depends); + } + else if (n_to_sort <= 1024 && ref_wg_size * 2 <= max_wg_size) { + // wg_size * block_size == 64 * 2 * 8 == 1024 + static constexpr std::uint16_t wg_size = ref_wg_size * 2; + static constexpr std::uint16_t block_size = 8; + + sort_ev = subgroup_radix_sort<_RadixSortKernel, wg_size, block_size, + radix_bits>{}( + exec_q, n_iters, n_to_sort, input_arr, output_arr, proj_op, + is_ascending, depends); + } + else if (n_to_sort <= 2048 && ref_wg_size * 4 <= max_wg_size) { + // wg_size * block_size == 64 * 4 * 8 == 2048 + static constexpr std::uint16_t wg_size = ref_wg_size * 4; + static constexpr std::uint16_t block_size = 8; + + sort_ev = subgroup_radix_sort<_RadixSortKernel, wg_size, block_size, + radix_bits>{}( + exec_q, n_iters, n_to_sort, input_arr, output_arr, proj_op, + is_ascending, depends); + } + else if (n_to_sort <= 4096 && ref_wg_size * 4 <= max_wg_size) { + // wg_size * block_size == 64 * 4 * 16 == 4096 + static constexpr std::uint16_t 
wg_size = ref_wg_size * 4; + static constexpr std::uint16_t block_size = 16; + + sort_ev = subgroup_radix_sort<_RadixSortKernel, wg_size, block_size, + radix_bits>{}( + exec_q, n_iters, n_to_sort, input_arr, output_arr, proj_op, + is_ascending, depends); + } + else if (n_to_sort <= 8192 && ref_wg_size * 8 <= max_wg_size) { + // wg_size * block_size == 64 * 8 * 16 == 8192 + static constexpr std::uint16_t wg_size = ref_wg_size * 8; + static constexpr std::uint16_t block_size = 16; + + sort_ev = subgroup_radix_sort<_RadixSortKernel, wg_size, block_size, + radix_bits>{}( + exec_q, n_iters, n_to_sort, input_arr, output_arr, proj_op, + is_ascending, depends); + } + else { + // wg_size * block_size == 64 * 8 * 32 == 16384 + static constexpr std::uint16_t wg_size = ref_wg_size * 8; + static constexpr std::uint16_t block_size = 32; + + sort_ev = subgroup_radix_sort<_RadixSortKernel, wg_size, block_size, + radix_bits>{}( + exec_q, n_iters, n_to_sort, input_arr, output_arr, proj_op, + is_ascending, depends); + } + } + else { + static constexpr std::uint32_t radix_iters = + number_of_buckets_in_type(radix_bits); + static constexpr std::uint32_t radix_states = std::uint32_t(1) + << radix_bits; + + static constexpr std::size_t bound_512k = (std::size_t(1) << 19); + static constexpr std::size_t bound_2m = (std::size_t(1) << 21); + + const auto wg_sz_k = (n_to_sort < bound_512k) ? 8 + : (n_to_sort <= bound_2m) ? 
4 + : 1; + const std::size_t wg_size = max_wg_size / wg_sz_k; + + const std::size_t n_segments = (n_to_sort + wg_size - 1) / wg_size; + + // Additional radix_states elements are used for getting local offsets + // from count values + no_op flag; the 'no operation' flag specifies whether + // to skip the re-order phase if all the keys are the same (lie in one bin) + const std::size_t n_counts = + (n_segments + 1) * radix_states + 1 /*no_op flag*/; + + using CountT = std::uint32_t; + + // memory for storing count and offset values + auto count_owner = + dpctl::tensor::alloc_utils::smart_malloc_device( + n_iters * n_counts, exec_q); + + CountT *count_ptr = count_owner.get(); + + static constexpr std::uint32_t zero_radix_iter{0}; + + if constexpr (std::is_same_v) { + + sort_ev = parallel_radix_sort_iteration_step< + radix_bits, /*even=*/true>::submit(exec_q, n_iters, n_segments, + zero_radix_iter, n_to_sort, + input_arr, output_arr, + n_counts, count_ptr, proj_op, + is_ascending, depends); + + sort_ev = dpctl::tensor::alloc_utils::async_smart_free( + exec_q, {sort_ev}, count_owner); + + return sort_ev; + } + + auto tmp_arr_owner = + dpctl::tensor::alloc_utils::smart_malloc_device( + n_iters * n_to_sort, exec_q); + + ValueT *tmp_arr = tmp_arr_owner.get(); + + // one iteration per bucket of the key type + assert("Number of iterations must be even" && radix_iters % 2 == 0); + assert(radix_iters > 0); + + sort_ev = parallel_radix_sort_iteration_step< + radix_bits, /*even=*/true>::submit(exec_q, n_iters, n_segments, + zero_radix_iter, n_to_sort, + input_arr, tmp_arr, n_counts, + count_ptr, proj_op, is_ascending, + depends); + + for (std::uint32_t radix_iter = 1; radix_iter < radix_iters; + ++radix_iter) { + if (radix_iter % 2 == 0) { + sort_ev = parallel_radix_sort_iteration_step< + radix_bits, + /*even=*/true>::submit(exec_q, n_iters, n_segments, + radix_iter, n_to_sort, output_arr, + tmp_arr, n_counts, count_ptr, + proj_op, is_ascending, {sort_ev}); + } + else { + sort_ev = 
parallel_radix_sort_iteration_step< + radix_bits, + /*even=*/false>::submit(exec_q, n_iters, n_segments, + radix_iter, n_to_sort, tmp_arr, + output_arr, n_counts, count_ptr, + proj_op, is_ascending, {sort_ev}); + } + } + + sort_ev = dpctl::tensor::alloc_utils::async_smart_free( + exec_q, {sort_ev}, tmp_arr_owner, count_owner); + } + + return sort_ev; +} + +struct IdentityProj +{ + constexpr IdentityProj() {} + + template + constexpr T operator()(T val) const + { + return val; + } +}; + +template +struct ValueProj +{ + constexpr ValueProj() {} + + constexpr ValueT operator()(const std::pair &pair) const + { + return pair.first; + } +}; + +template +struct IndexedProj +{ + IndexedProj(const ValueT *arg_ptr) : ptr(arg_ptr), value_projector{} {} + + IndexedProj(const ValueT *arg_ptr, const ProjT &proj_op) + : ptr(arg_ptr), value_projector(proj_op) + { + } + + auto operator()(IndexT i) const + { + return value_projector(ptr[i]); + } + +private: + const ValueT *ptr; + ProjT value_projector; +}; + +} // namespace radix_sort_details + +using dpctl::tensor::ssize_t; + +template +sycl::event + radix_sort_axis1_contig_impl(sycl::queue &exec_q, + const bool sort_ascending, + // number of sub-arrays to sort (num. of rows + // in a matrix when sorting over rows) + std::size_t iter_nelems, + // size of each array to sort (length of rows, + // i.e. 
number of columns) + std::size_t sort_nelems, + const char *arg_cp, + char *res_cp, + ssize_t iter_arg_offset, + ssize_t iter_res_offset, + ssize_t sort_arg_offset, + ssize_t sort_res_offset, + const std::vector &depends) +{ + const argTy *arg_tp = reinterpret_cast(arg_cp) + + iter_arg_offset + sort_arg_offset; + argTy *res_tp = + reinterpret_cast(res_cp) + iter_res_offset + sort_res_offset; + + using Proj = radix_sort_details::IdentityProj; + static constexpr Proj proj_op{}; + + sycl::event radix_sort_ev = + radix_sort_details::parallel_radix_sort_impl( + exec_q, iter_nelems, sort_nelems, arg_tp, res_tp, proj_op, + sort_ascending, depends); + + return radix_sort_ev; +} + +template +class radix_argsort_index_write_out_krn; + +template +class radix_argsort_iota_krn; + +template +sycl::event + radix_argsort_axis1_contig_impl(sycl::queue &exec_q, + const bool sort_ascending, + // number of sub-arrays to sort (num. of + // rows in a matrix when sorting over rows) + std::size_t iter_nelems, + // size of each array to sort (length of + // rows, i.e. 
number of columns) + std::size_t sort_nelems, + const char *arg_cp, + char *res_cp, + ssize_t iter_arg_offset, + ssize_t iter_res_offset, + ssize_t sort_arg_offset, + ssize_t sort_res_offset, + const std::vector &depends) +{ + const argTy *arg_tp = reinterpret_cast(arg_cp) + + iter_arg_offset + sort_arg_offset; + IndexTy *res_tp = + reinterpret_cast(res_cp) + iter_res_offset + sort_res_offset; + + const std::size_t total_nelems = iter_nelems * sort_nelems; + auto workspace_owner = + dpctl::tensor::alloc_utils::smart_malloc_device(total_nelems, + exec_q); + + // get raw USM pointer + IndexTy *workspace = workspace_owner.get(); + + using IdentityProjT = radix_sort_details::IdentityProj; + using IndexedProjT = + radix_sort_details::IndexedProj; + const IndexedProjT proj_op{arg_tp}; + + using IotaKernelName = radix_argsort_iota_krn; + + using dpctl::tensor::kernels::sort_utils_detail::iota_impl; + + sycl::event iota_ev = iota_impl( + exec_q, workspace, total_nelems, depends); + + sycl::event radix_sort_ev = + radix_sort_details::parallel_radix_sort_impl( + exec_q, iter_nelems, sort_nelems, workspace, res_tp, proj_op, + sort_ascending, {iota_ev}); + + using MapBackKernelName = radix_argsort_index_write_out_krn; + using dpctl::tensor::kernels::sort_utils_detail::map_back_impl; + + sycl::event dep = radix_sort_ev; + + // no need to perform map_back ( id % sort_nelems) + // if total_nelems == sort_nelems + if (iter_nelems > 1u) { + dep = map_back_impl( + exec_q, total_nelems, res_tp, res_tp, sort_nelems, {dep}); + } + + sycl::event cleanup_ev = dpctl::tensor::alloc_utils::async_smart_free( + exec_q, {dep}, workspace_owner); + + return cleanup_ev; +} + +} // namespace dpctl::tensor::kernels diff --git a/dpctl_ext/tensor/libtensor/include/kernels/sorting/search_sorted_detail.hpp b/dpctl_ext/tensor/libtensor/include/kernels/sorting/search_sorted_detail.hpp new file mode 100644 index 000000000000..1f3576402511 --- /dev/null +++ 
b/dpctl_ext/tensor/libtensor/include/kernels/sorting/search_sorted_detail.hpp @@ -0,0 +1,119 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. 
+//***************************************************************************** +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file defines kernels for tensor sort/argsort operations. +//===----------------------------------------------------------------------===// + +#pragma once + +#include + +namespace dpctl::tensor::kernels::search_sorted_detail +{ + +template +T quotient_ceil(T n, T m) +{ + return (n + m - 1) / m; +} + +template +std::size_t lower_bound_impl(const Acc acc, + const std::size_t first, + const std::size_t last, + const Value &value, + const Compare &comp) +{ + std::size_t n = last - first; + std::size_t cur = n, start = first; + std::size_t it; + while (n > 0) { + it = start; + cur = n / 2; + it += cur; + if (comp(acc[it], value)) { + n -= cur + 1, start = ++it; + } + else + n = cur; + } + return start; +} + +template +std::size_t upper_bound_impl(const Acc acc, + const std::size_t first, + const std::size_t last, + const Value &value, + const Compare &comp) +{ + const auto &op_comp = [comp](auto x, auto y) { return !comp(y, x); }; + return lower_bound_impl(acc, first, last, value, op_comp); +} + +template +std::size_t lower_bound_indexed_impl(const Acc acc, + std::size_t first, + std::size_t last, + const Value &value, + const Compare &comp, + const IndexerT &acc_indexer) +{ + std::size_t n = last - first; + std::size_t cur = n, start = first; + std::size_t it; + while (n > 0) { + it = start; + cur = n / 2; + it += cur; + if (comp(acc[acc_indexer(it)], value)) { + n -= cur + 1, start = ++it; + } + else + n = cur; + } + return start; +} + +template +std::size_t upper_bound_indexed_impl(const Acc acc, + const std::size_t first, + const std::size_t last, + const Value &value, + const Compare &comp, + const IndexerT &acc_indexer) +{ + const auto &op_comp = [comp](auto x, auto y) { return !comp(y, x); }; + return lower_bound_indexed_impl(acc, first, last, value, op_comp, + 
acc_indexer); +} + +} // namespace dpctl::tensor::kernels::search_sorted_detail diff --git a/dpctl_ext/tensor/libtensor/include/kernels/sorting/searchsorted.hpp b/dpctl_ext/tensor/libtensor/include/kernels/sorting/searchsorted.hpp new file mode 100644 index 000000000000..bc400c9e569a --- /dev/null +++ b/dpctl_ext/tensor/libtensor/include/kernels/sorting/searchsorted.hpp @@ -0,0 +1,258 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file defines kernels for tensor sort/argsort operations. +//===----------------------------------------------------------------------===// + +#pragma once + +#include +#include + +#include + +#include "kernels/dpctl_tensor_types.hpp" +#include "kernels/sorting/search_sorted_detail.hpp" +#include "utils/offset_utils.hpp" + +namespace dpctl::tensor::kernels +{ + +using dpctl::tensor::ssize_t; + +template +struct SearchSortedFunctor +{ +private: + const argTy *hay_tp; + const argTy *needles_tp; + indTy *positions_tp; + std::size_t hay_nelems; + HayIndexerT hay_indexer; + NeedlesIndexerT needles_indexer; + PositionsIndexerT positions_indexer; + +public: + SearchSortedFunctor(const argTy *hay_, + const argTy *needles_, + indTy *positions_, + const std::size_t hay_nelems_, + const HayIndexerT &hay_indexer_, + const NeedlesIndexerT &needles_indexer_, + const PositionsIndexerT &positions_indexer_) + : hay_tp(hay_), needles_tp(needles_), positions_tp(positions_), + hay_nelems(hay_nelems_), hay_indexer(hay_indexer_), + needles_indexer(needles_indexer_), + positions_indexer(positions_indexer_) + { + } + + void operator()(sycl::id<1> id) const + { + const Compare comp{}; + + const std::size_t i = id[0]; + const argTy needle_v = 
needles_tp[needles_indexer(i)]; + + // position of the needle_v in the hay array + indTy pos{}; + + static constexpr std::size_t zero(0); + if constexpr (left_side) { + // search in hay in left-closed interval, give `pos` such that + // hay[pos - 1] < needle_v <= hay[pos] + + // lower_bound returns the first pos such that bool(hay[pos] < + // needle_v) is false, i.e. needle_v <= hay[pos] + pos = search_sorted_detail::lower_bound_indexed_impl( + hay_tp, zero, hay_nelems, needle_v, comp, hay_indexer); + } + else { + // search in hay in right-closed interval: hay[pos - 1] <= needle_v + // < hay[pos] + + // upper_bound returns the first pos such that bool(needle_v < + // hay[pos]) is true, i.e. needle_v < hay[pos] + pos = search_sorted_detail::upper_bound_indexed_impl( + hay_tp, zero, hay_nelems, needle_v, comp, hay_indexer); + } + + positions_tp[positions_indexer(i)] = pos; + } +}; + +typedef sycl::event (*searchsorted_contig_impl_fp_ptr_t)( + sycl::queue &, + const std::size_t, + const std::size_t, + const char *, + const ssize_t, + const char *, + const ssize_t, + char *, + const ssize_t, + const std::vector &); + +template +class searchsorted_contig_impl_krn; + +template +sycl::event searchsorted_contig_impl(sycl::queue &exec_q, + const std::size_t hay_nelems, + const std::size_t needles_nelems, + const char *hay_cp, + const ssize_t hay_offset, + const char *needles_cp, + const ssize_t needles_offset, + char *positions_cp, + const ssize_t positions_offset, + const std::vector &depends) +{ + const argTy *hay_tp = reinterpret_cast(hay_cp) + hay_offset; + const argTy *needles_tp = + reinterpret_cast(needles_cp) + needles_offset; + + indTy *positions_tp = + reinterpret_cast(positions_cp) + positions_offset; + + sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using KernelName = + class searchsorted_contig_impl_krn; + + sycl::range<1> gRange(needles_nelems); + + using TrivialIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; 
+ + static constexpr TrivialIndexerT hay_indexer{}; + static constexpr TrivialIndexerT needles_indexer{}; + static constexpr TrivialIndexerT positions_indexer{}; + + const auto fnctr = + SearchSortedFunctor( + hay_tp, needles_tp, positions_tp, hay_nelems, hay_indexer, + needles_indexer, positions_indexer); + + cgh.parallel_for(gRange, fnctr); + }); + + return comp_ev; +} + +typedef sycl::event (*searchsorted_strided_impl_fp_ptr_t)( + sycl::queue &, + const std::size_t, + const std::size_t, + const char *, + const ssize_t, + const ssize_t, + const char *, + const ssize_t, + char *, + const ssize_t, + int, + const ssize_t *, + const std::vector &); + +template +class searchsorted_strided_impl_krn; + +template +sycl::event searchsorted_strided_impl( + sycl::queue &exec_q, + const std::size_t hay_nelems, + const std::size_t needles_nelems, + const char *hay_cp, + const ssize_t hay_offset, + // hay is 1D, so hay_nelems, hay_offset, hay_stride describe strided array + const ssize_t hay_stride, + const char *needles_cp, + const ssize_t needles_offset, + char *positions_cp, + const ssize_t positions_offset, + const int needles_nd, + // packed_shape_strides is [needles_shape, needles_strides, + // positions_strides] has length of 3*needles_nd + const ssize_t *packed_shape_strides, + const std::vector &depends) +{ + const argTy *hay_tp = reinterpret_cast(hay_cp); + const argTy *needles_tp = reinterpret_cast(needles_cp); + + indTy *positions_tp = reinterpret_cast(positions_cp); + + sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + sycl::range<1> gRange(needles_nelems); + + using HayIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer; + const HayIndexerT hay_indexer( + /* offset */ hay_offset, + /* size */ hay_nelems, + /* step */ hay_stride); + + using NeedlesIndexerT = dpctl::tensor::offset_utils::StridedIndexer; + const ssize_t *needles_shape_strides = packed_shape_strides; + const NeedlesIndexerT needles_indexer(needles_nd, 
needles_offset, + needles_shape_strides); + using PositionsIndexerT = + dpctl::tensor::offset_utils::UnpackedStridedIndexer; + + const ssize_t *positions_shape = packed_shape_strides; + const ssize_t *positions_strides = + packed_shape_strides + 2 * needles_nd; + const PositionsIndexerT positions_indexer( + needles_nd, positions_offset, positions_shape, positions_strides); + + const auto fnctr = + SearchSortedFunctor( + hay_tp, needles_tp, positions_tp, hay_nelems, hay_indexer, + needles_indexer, positions_indexer); + using KernelName = + class searchsorted_strided_impl_krn; + + cgh.parallel_for(gRange, fnctr); + }); + + return comp_ev; +} + +} // namespace dpctl::tensor::kernels diff --git a/dpctl_ext/tensor/libtensor/include/kernels/sorting/sort_impl_fn_ptr_t.hpp b/dpctl_ext/tensor/libtensor/include/kernels/sorting/sort_impl_fn_ptr_t.hpp new file mode 100644 index 000000000000..7b48f310a445 --- /dev/null +++ b/dpctl_ext/tensor/libtensor/include/kernels/sorting/sort_impl_fn_ptr_t.hpp @@ -0,0 +1,61 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. 
+// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_sorting_impl +/// extension. 
+//===----------------------------------------------------------------------===// + +#pragma once + +#include +#include + +#include + +#include "kernels/dpctl_tensor_types.hpp" + +namespace dpctl::tensor::kernels +{ + +using dpctl::tensor::ssize_t; + +typedef sycl::event (*sort_contig_fn_ptr_t)(sycl::queue &, + std::size_t, + std::size_t, + const char *, + char *, + ssize_t, + ssize_t, + ssize_t, + ssize_t, + const std::vector &); + +} // namespace dpctl::tensor::kernels diff --git a/dpctl_ext/tensor/libtensor/include/kernels/sorting/sort_utils.hpp b/dpctl_ext/tensor/libtensor/include/kernels/sorting/sort_utils.hpp new file mode 100644 index 000000000000..fd32905b808e --- /dev/null +++ b/dpctl_ext/tensor/libtensor/include/kernels/sorting/sort_utils.hpp @@ -0,0 +1,144 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file defines kernels for tensor sort/argsort operations. +//===----------------------------------------------------------------------===// + +#pragma once + +#include +#include +#include +#include + +#include + +namespace dpctl::tensor::kernels::sort_utils_detail +{ + +namespace syclexp = sycl::ext::oneapi::experimental; + +template +sycl::event iota_impl(sycl::queue &exec_q, + T *data, + std::size_t nelems, + const std::vector &dependent_events) +{ + static constexpr std::uint32_t lws = 256; + static constexpr std::uint32_t n_wi = 4; + const std::size_t n_groups = (nelems + n_wi * lws - 1) / (n_wi * lws); + + sycl::range<1> gRange{n_groups * lws}; + sycl::range<1> lRange{lws}; + sycl::nd_range<1> ndRange{gRange, lRange}; + + sycl::event e = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(dependent_events); + cgh.parallel_for(ndRange, [=](sycl::nd_item<1> it) { + const std::size_t gid = it.get_global_linear_id(); + const auto &sg = it.get_sub_group(); + const std::uint32_t lane_id = sg.get_local_id()[0]; + + const std::size_t offset = (gid - lane_id) * n_wi; + const std::uint32_t max_sgSize = sg.get_max_local_range()[0]; + + std::array stripe{}; +#pragma unroll + for (std::uint32_t i = 0; i < n_wi; ++i) { + stripe[i] = T(offset + 
lane_id + i * max_sgSize); + } + + if (offset + n_wi * max_sgSize < nelems) { + static constexpr auto group_ls_props = + syclexp::properties{syclexp::data_placement_striped}; + + auto out_multi_ptr = sycl::address_space_cast< + sycl::access::address_space::global_space, + sycl::access::decorated::yes>(&data[offset]); + + syclexp::group_store(sg, sycl::span{&stripe[0], n_wi}, + out_multi_ptr, group_ls_props); + } + else { + for (std::size_t idx = offset + lane_id; idx < nelems; + idx += max_sgSize) { + data[idx] = T(idx); + } + } + }); + }); + + return e; +} + +template +sycl::event map_back_impl(sycl::queue &exec_q, + std::size_t nelems, + const IndexTy *flat_index_data, + IndexTy *reduced_index_data, + std::size_t row_size, + const std::vector &dependent_events) +{ + static constexpr std::uint32_t lws = 64; + static constexpr std::uint32_t n_wi = 4; + const std::size_t n_groups = (nelems + lws * n_wi - 1) / (n_wi * lws); + + sycl::range<1> lRange{lws}; + sycl::range<1> gRange{n_groups * lws}; + sycl::nd_range<1> ndRange{gRange, lRange}; + + sycl::event map_back_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(dependent_events); + + cgh.parallel_for(ndRange, [=](sycl::nd_item<1> it) { + const std::size_t gid = it.get_global_linear_id(); + const auto &sg = it.get_sub_group(); + const std::uint32_t lane_id = sg.get_local_id()[0]; + const std::uint32_t sg_size = sg.get_max_local_range()[0]; + + const std::size_t start_id = (gid - lane_id) * n_wi + lane_id; + +#pragma unroll + for (std::uint32_t i = 0; i < n_wi; ++i) { + const std::size_t data_id = start_id + i * sg_size; + + if (data_id < nelems) { + const IndexTy linear_index = flat_index_data[data_id]; + reduced_index_data[data_id] = (linear_index % row_size); + } + } + }); + }); + + return map_back_ev; +} + +} // namespace dpctl::tensor::kernels::sort_utils_detail diff --git a/dpctl_ext/tensor/libtensor/include/kernels/sorting/topk.hpp b/dpctl_ext/tensor/libtensor/include/kernels/sorting/topk.hpp new 
file mode 100644 index 000000000000..a1e57dc9ef30 --- /dev/null +++ b/dpctl_ext/tensor/libtensor/include/kernels/sorting/topk.hpp @@ -0,0 +1,511 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. 
+//***************************************************************************** +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file defines kernels for tensor topk operation. +//===----------------------------------------------------------------------===// + +#pragma once + +#include +#include +#include +#include +#include +#include + +#include + +#include "kernels/sorting/merge_sort.hpp" +#include "kernels/sorting/radix_sort.hpp" +#include "kernels/sorting/search_sorted_detail.hpp" +#include "kernels/sorting/sort_utils.hpp" +#include "utils/sycl_alloc_utils.hpp" + +namespace dpctl::tensor::kernels +{ + +namespace topk_detail +{ + +void scale_topk_params(const std::uint64_t nelems_per_slm, + const std::size_t sub_groups_per_work_group, + const std::uint32_t elems_per_wi, + const std::vector &sg_sizes, + std::size_t &lws, + std::size_t &nelems_wg_sorts) +{ + for (auto it = sg_sizes.rbegin(); it != sg_sizes.rend(); ++it) { + auto sg_size = *it; + lws = sub_groups_per_work_group * sg_size; + nelems_wg_sorts = elems_per_wi * lws; + if (nelems_wg_sorts < nelems_per_slm) { + return; + } + } + // should never reach + throw std::runtime_error("Could not construct top k kernel parameters"); +} + +template +sycl::event write_out_impl(sycl::queue &exec_q, + std::size_t iter_nelems, + std::size_t k, + const argTy *arg_tp, + const IndexTy *index_data, + std::size_t iter_index_stride, + std::size_t axis_nelems, + argTy *vals_tp, + IndexTy *inds_tp, + const std::vector &depends) +{ + static constexpr std::uint32_t lws = 64; + static constexpr std::uint32_t n_wi = 4; + const std::size_t nelems = iter_nelems * k; + const std::size_t n_groups = (nelems + lws * n_wi - 1) / (n_wi * lws); + + sycl::range<1> lRange{lws}; + sycl::range<1> gRange{n_groups * lws}; + sycl::nd_range<1> ndRange{gRange, lRange}; + + sycl::event write_out_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + 
cgh.parallel_for(ndRange, [=](sycl::nd_item<1> it) { + const std::size_t gid = it.get_global_linear_id(); + const auto &sg = it.get_sub_group(); + const std::uint32_t lane_id = sg.get_local_id()[0]; + const std::uint32_t sg_size = sg.get_max_local_range()[0]; + + const std::size_t start_id = (gid - lane_id) * n_wi + lane_id; + +#pragma unroll + for (std::uint32_t i = 0; i < n_wi; ++i) { + const std::size_t data_id = start_id + i * sg_size; + + if (data_id < nelems) { + const std::size_t iter_id = data_id / k; + + /* + const std::size_t axis_gid = data_id - (iter_gid * k); + const std::size_t src_idx = iter_gid * iter_index_stride + + axis_gid; + */ + const std::size_t src_idx = + data_id + iter_id * (iter_index_stride - k); + + const IndexTy res_ind = index_data[src_idx]; + const argTy v = arg_tp[res_ind]; + + const std::size_t dst_idx = data_id; + vals_tp[dst_idx] = v; + inds_tp[dst_idx] = (res_ind % axis_nelems); + } + } + }); + }); + + return write_out_ev; +} + +} // namespace topk_detail + +template +class topk_populate_index_data_krn; + +template +class topk_full_merge_map_back_krn; + +template +sycl::event + topk_full_merge_sort_impl(sycl::queue &exec_q, + std::size_t iter_nelems, // number of sub-arrays + std::size_t axis_nelems, // size of each sub-array + std::size_t k, + const argTy *arg_tp, + argTy *vals_tp, + IndexTy *inds_tp, + const CompT &comp, + const std::vector &depends) +{ + auto index_data_owner = + dpctl::tensor::alloc_utils::smart_malloc_device( + iter_nelems * axis_nelems, exec_q); + // extract USM pointer + IndexTy *index_data = index_data_owner.get(); + + using IotaKernelName = topk_populate_index_data_krn; + + using dpctl::tensor::kernels::sort_utils_detail::iota_impl; + + sycl::event populate_indexed_data_ev = iota_impl( + exec_q, index_data, iter_nelems * axis_nelems, depends); + + std::size_t sorted_block_size; + // Sort segments of the array + sycl::event base_sort_ev = + merge_sort_detail::sort_over_work_group_contig_impl( + exec_q, 
iter_nelems, axis_nelems, index_data, index_data, comp, + sorted_block_size, // modified in place with size of sorted block + // size + {populate_indexed_data_ev}); + + // Merge segments in parallel until all elements are sorted + sycl::event merges_ev = merge_sort_detail::merge_sorted_block_contig_impl( + exec_q, iter_nelems, axis_nelems, index_data, comp, sorted_block_size, + {base_sort_ev}); + + using WriteOutKernelName = topk_full_merge_map_back_krn; + + sycl::event write_out_ev = + topk_detail::write_out_impl( + exec_q, iter_nelems, k, arg_tp, index_data, axis_nelems, + axis_nelems, vals_tp, inds_tp, {merges_ev}); + + sycl::event cleanup_host_task_event = + dpctl::tensor::alloc_utils::async_smart_free(exec_q, {write_out_ev}, + index_data_owner); + + return cleanup_host_task_event; +}; + +template +class topk_partial_merge_map_back_krn; + +template +class topk_over_work_group_krn; + +template > +sycl::event topk_merge_impl( + sycl::queue &exec_q, + std::size_t iter_nelems, // number of sub-arrays to sort (num. of rows + // in a matrix when sorting over rows) + std::size_t axis_nelems, // size of each array to sort (length of + // rows, i.e. 
number of columns) + std::size_t k, + const char *arg_cp, + char *vals_cp, + char *inds_cp, + const std::vector &depends) +{ + if (axis_nelems < k) { + throw std::runtime_error("Invalid sort axis size for value of k"); + } + + const argTy *arg_tp = reinterpret_cast(arg_cp); + argTy *vals_tp = reinterpret_cast(vals_cp); + IndexTy *inds_tp = reinterpret_cast(inds_cp); + + using dpctl::tensor::kernels::IndexComp; + const IndexComp index_comp{arg_tp, ValueComp{}}; + + if (axis_nelems <= 512 || k >= 1024 || k > axis_nelems / 2) { + return topk_full_merge_sort_impl(exec_q, iter_nelems, axis_nelems, k, + arg_tp, vals_tp, inds_tp, index_comp, + depends); + } + else { + using PartialKernelName = + topk_over_work_group_krn; + + const auto &kernel_id = sycl::get_kernel_id(); + + auto const &ctx = exec_q.get_context(); + auto const &dev = exec_q.get_device(); + + auto kb = sycl::get_kernel_bundle( + ctx, {dev}, {kernel_id}); + + auto krn = kb.get_kernel(kernel_id); + + const std::uint32_t max_sg_size = krn.template get_info< + sycl::info::kernel_device_specific::max_sub_group_size>(dev); + const std::uint64_t device_local_memory_size = + dev.get_info(); + + // leave 512 bytes of local memory for RT + const std::uint64_t safety_margin = 512; + + const std::uint64_t nelems_per_slm = + (device_local_memory_size - safety_margin) / (2 * sizeof(IndexTy)); + + static constexpr std::uint32_t sub_groups_per_work_group = 4; + const std::uint32_t elems_per_wi = dev.has(sycl::aspect::cpu) ? 
8 : 2; + + std::size_t lws = sub_groups_per_work_group * max_sg_size; + + std::size_t sorted_block_size = elems_per_wi * lws; + if (sorted_block_size > nelems_per_slm) { + const std::vector sg_sizes = + dev.get_info(); + topk_detail::scale_topk_params( + nelems_per_slm, sub_groups_per_work_group, elems_per_wi, + sg_sizes, + lws, // modified by reference + sorted_block_size // modified by reference + ); + } + + // This assumption permits doing away with using a loop + assert(sorted_block_size % lws == 0); + + using search_sorted_detail::quotient_ceil; + const std::size_t n_segments = + quotient_ceil(axis_nelems, sorted_block_size); + + // round k up for the later merge kernel if necessary + const std::size_t round_k_to = dev.has(sycl::aspect::cpu) ? 32 : 4; + std::size_t k_rounded = + (k < round_k_to) + ? k + : quotient_ceil(k, round_k_to) * round_k_to; + + // get length of tail for alloc size + auto rem = axis_nelems % sorted_block_size; + auto alloc_len = (rem && rem < k_rounded) + ? rem + k_rounded * (n_segments - 1) + : k_rounded * n_segments; + + // if allocation would be sufficiently large or k is larger than + // elements processed, use full sort + if (k_rounded >= axis_nelems || k_rounded >= sorted_block_size || + alloc_len >= axis_nelems / 2) + { + return topk_full_merge_sort_impl(exec_q, iter_nelems, axis_nelems, + k, arg_tp, vals_tp, inds_tp, + index_comp, depends); + } + + auto index_data_owner = + dpctl::tensor::alloc_utils::smart_malloc_device( + iter_nelems * alloc_len, exec_q); + // get raw USM pointer + IndexTy *index_data = index_data_owner.get(); + + // no need to populate index data: SLM will be populated with default + // values + + sycl::event base_sort_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + cgh.use_kernel_bundle(kb); + + sycl::range<1> global_range{iter_nelems * n_segments * lws}; + sycl::range<1> local_range{lws}; + + sycl::range<1> slm_range{sorted_block_size}; + sycl::local_accessor work_space(slm_range, 
cgh); + sycl::local_accessor scratch_space(slm_range, cgh); + + sycl::nd_range<1> ndRange(global_range, local_range); + + cgh.parallel_for( + ndRange, [=](sycl::nd_item<1> it) { + const std::size_t group_id = it.get_group_linear_id(); + const std::size_t iter_id = group_id / n_segments; + const std::size_t segment_id = + group_id - iter_id * n_segments; + const std::size_t lid = it.get_local_linear_id(); + + const std::size_t segment_start_idx = + segment_id * sorted_block_size; + const std::size_t segment_end_idx = std::min( + segment_start_idx + sorted_block_size, axis_nelems); + const std::size_t wg_chunk_size = + segment_end_idx - segment_start_idx; + + // load input into SLM + for (std::size_t array_id = segment_start_idx + lid; + array_id < segment_end_idx; array_id += lws) + { + IndexTy v = (array_id < axis_nelems) + ? iter_id * axis_nelems + array_id + : IndexTy{}; + work_space[array_id - segment_start_idx] = v; + } + sycl::group_barrier(it.get_group()); + + const std::size_t chunk = + quotient_ceil(sorted_block_size, lws); + + const std::size_t chunk_start_idx = lid * chunk; + const std::size_t chunk_end_idx = + sycl::min(chunk_start_idx + chunk, wg_chunk_size); + + merge_sort_detail::leaf_sort_impl( + work_space, chunk_start_idx, chunk_end_idx, index_comp); + + sycl::group_barrier(it.get_group()); + + bool data_in_temp = false; + std::size_t n_chunks_merged = 1; + + // merge chunk while n_chunks_merged * chunk < wg_chunk_size + const std::size_t max_chunks_merged = + 1 + ((wg_chunk_size - 1) / chunk); + for (; n_chunks_merged < max_chunks_merged; + data_in_temp = !data_in_temp, n_chunks_merged *= 2) + { + const std::size_t nelems_sorted_so_far = + n_chunks_merged * chunk; + const std::size_t q = (lid / n_chunks_merged); + const std::size_t start_1 = sycl::min( + 2 * nelems_sorted_so_far * q, wg_chunk_size); + const std::size_t end_1 = sycl::min( + start_1 + nelems_sorted_so_far, wg_chunk_size); + const std::size_t end_2 = sycl::min( + end_1 + 
nelems_sorted_so_far, wg_chunk_size); + const std::size_t offset = + chunk * (lid - q * n_chunks_merged); + + if (data_in_temp) { + merge_sort_detail::merge_impl( + offset, scratch_space, work_space, start_1, + end_1, end_2, start_1, index_comp, chunk); + } + else { + merge_sort_detail::merge_impl( + offset, work_space, scratch_space, start_1, + end_1, end_2, start_1, index_comp, chunk); + } + sycl::group_barrier(it.get_group()); + } + + // output assumed to be structured as (iter_nelems, + // alloc_len) + const std::size_t k_segment_start_idx = + segment_id * k_rounded; + const std::size_t k_segment_end_idx = std::min( + k_segment_start_idx + k_rounded, alloc_len); + const auto &out_src = + (data_in_temp) ? scratch_space : work_space; + for (std::size_t array_id = k_segment_start_idx + lid; + array_id < k_segment_end_idx; array_id += lws) + { + if (lid < k_rounded) { + index_data[iter_id * alloc_len + array_id] = + out_src[array_id - k_segment_start_idx]; + } + } + }); + }); + + // Merge segments in parallel until all elements are sorted + sycl::event merges_ev = + merge_sort_detail::merge_sorted_block_contig_impl( + exec_q, iter_nelems, alloc_len, index_data, index_comp, + k_rounded, {base_sort_ev}); + + // Write out top k of the merge-sorted memory + using WriteOutKernelName = + topk_partial_merge_map_back_krn; + + sycl::event write_topk_ev = + topk_detail::write_out_impl( + exec_q, iter_nelems, k, arg_tp, index_data, alloc_len, + axis_nelems, vals_tp, inds_tp, {merges_ev}); + + sycl::event cleanup_host_task_event = + dpctl::tensor::alloc_utils::async_smart_free( + exec_q, {write_topk_ev}, index_data_owner); + + return cleanup_host_task_event; + } +} + +template +class topk_iota_krn; + +template +class topk_radix_map_back_krn; + +template +sycl::event topk_radix_impl(sycl::queue &exec_q, + std::size_t iter_nelems, // number of sub-arrays + std::size_t axis_nelems, // size of each sub-array + std::size_t k, + bool ascending, + const char *arg_cp, + char *vals_cp, 
+ char *inds_cp, + const std::vector &depends) +{ + if (axis_nelems < k) { + throw std::runtime_error("Invalid sort axis size for value of k"); + } + + const argTy *arg_tp = reinterpret_cast(arg_cp); + argTy *vals_tp = reinterpret_cast(vals_cp); + IndexTy *inds_tp = reinterpret_cast(inds_cp); + + const std::size_t total_nelems = iter_nelems * axis_nelems; + const std::size_t padded_total_nelems = ((total_nelems + 63) / 64) * 64; + auto workspace_owner = + dpctl::tensor::alloc_utils::smart_malloc_device( + padded_total_nelems + total_nelems, exec_q); + + // get raw USM pointer + IndexTy *workspace = workspace_owner.get(); + IndexTy *tmp_tp = workspace + padded_total_nelems; + + using IdentityProjT = radix_sort_details::IdentityProj; + using IndexedProjT = + radix_sort_details::IndexedProj; + const IndexedProjT proj_op{arg_tp}; + + using IotaKernelName = topk_iota_krn; + + using dpctl::tensor::kernels::sort_utils_detail::iota_impl; + + sycl::event iota_ev = iota_impl( + exec_q, workspace, total_nelems, depends); + + sycl::event radix_sort_ev = + radix_sort_details::parallel_radix_sort_impl( + exec_q, iter_nelems, axis_nelems, workspace, tmp_tp, proj_op, + ascending, {iota_ev}); + + // Write out top k of the temporary + using WriteOutKernelName = topk_radix_map_back_krn; + + sycl::event write_topk_ev = + topk_detail::write_out_impl( + exec_q, iter_nelems, k, arg_tp, tmp_tp, axis_nelems, axis_nelems, + vals_tp, inds_tp, {radix_sort_ev}); + + sycl::event cleanup_ev = dpctl::tensor::alloc_utils::async_smart_free( + exec_q, {write_topk_ev}, workspace_owner); + + return cleanup_ev; +} + +} // namespace dpctl::tensor::kernels diff --git a/dpctl_ext/tensor/libtensor/include/utils/rich_comparisons.hpp b/dpctl_ext/tensor/libtensor/include/utils/rich_comparisons.hpp new file mode 100644 index 000000000000..87cdfbfbd54f --- /dev/null +++ b/dpctl_ext/tensor/libtensor/include/utils/rich_comparisons.hpp @@ -0,0 +1,149 @@ 
+//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_sorting_impl +/// extension. 
+//===---------------------------------------------------------------------===//
+
+#pragma once
+
+#include <cmath>
+#include <complex>
+#include <functional>
+#include <type_traits>
+
+#include "sycl/sycl.hpp"
+
+namespace dpctl::tensor::rich_comparisons
+{
+
+namespace detail
+{
+/*! @brief Strict "less than" for real floating-point values in which NaN
+ *  compares after every non-NaN value. A NaN is never less than anything,
+ *  so two NaNs are equivalent. */
+template <typename fpT>
+struct ExtendedRealFPLess
+{
+    /* [R, nan] */
+    bool operator()(const fpT v1, const fpT v2) const
+    {
+        return (!std::isnan(v1) && (std::isnan(v2) || (v1 < v2)));
+    }
+};
+
+/*! @brief Mirror image of ExtendedRealFPLess: returns the same result as
+ *  ExtendedRealFPLess applied to (v2, v1), so NaN is treated as the greatest
+ *  value. */
+template <typename fpT>
+struct ExtendedRealFPGreater
+{
+    bool operator()(const fpT v1, const fpT v2) const
+    {
+        return (!std::isnan(v2) && (std::isnan(v1) || (v2 < v1)));
+    }
+};
+
+/*! @brief Strict "less than" for complex values with NaN-aware ordering.
+ *
+ *  Values are first ranked by which components are NaN, in the order listed
+ *  below. Within the same rank the non-NaN components are compared: the real
+ *  part first, then the imaginary part when the real parts are equal. A value
+ *  with both components NaN is never less than anything. */
+template <typename cT>
+struct ExtendedComplexFPLess
+{
+    /* [(R, R), (R, nan), (nan, R), (nan, nan)] */
+
+    bool operator()(const cT &v1, const cT &v2) const
+    {
+        using realT = typename cT::value_type;
+
+        const realT real1 = std::real(v1);
+        const realT real2 = std::real(v2);
+
+        const bool r1_nan = std::isnan(real1);
+        const bool r2_nan = std::isnan(real2);
+
+        const realT imag1 = std::imag(v1);
+        const realT imag2 = std::imag(v2);
+
+        const bool i1_nan = std::isnan(imag1);
+        const bool i2_nan = std::isnan(imag2);
+
+        // rank of the NaN pattern: 0 = (R, R), 1 = (R, nan), 2 = (nan, R),
+        // 3 = (nan, nan)
+        const int idx1 = ((r1_nan) ? 2 : 0) + ((i1_nan) ? 1 : 0);
+        const int idx2 = ((r2_nan) ? 2 : 0) + ((i2_nan) ? 1 : 0);
+
+        const bool res =
+            !(r1_nan && i1_nan) &&
+            ((idx1 < idx2) ||
+             ((idx1 == idx2) &&
+              ((r1_nan && !i1_nan && (imag1 < imag2)) ||
+               (!r1_nan && i1_nan && (real1 < real2)) ||
+               (!r1_nan && !i1_nan &&
+                ((real1 < real2) ||
+                 (!(real2 < real1) && (imag1 < imag2)))))));
+
+        return res;
+    }
+};
+
+/*! @brief "Greater than" for complex values, defined as
+ *  ExtendedComplexFPLess with the arguments swapped. */
+template <typename cT>
+struct ExtendedComplexFPGreater
+{
+    bool operator()(const cT &v1, const cT &v2) const
+    {
+        auto less_ = ExtendedComplexFPLess<cT>{};
+        return less_(v2, v1);
+    }
+};
+
+/*! @brief True for the real floating-point types that need the NaN-aware
+ *  comparators above. */
+// NOTE(review): the three types are reconstructed from usage -- confirm
+template <typename T>
+inline constexpr bool is_fp_v = (std::is_same_v<T, sycl::half> ||
+                                 std::is_same_v<T, float> ||
+                                 std::is_same_v<T, double>);
+
+} // namespace detail
+
+/*! @brief Selects the comparator type for ascending order: the NaN-aware
+ *  comparator when detail::is_fp_v holds, std::less otherwise. */
+template <typename argTy, typename = void>
+struct AscendingSorter
+{
+    using type = std::conditional_t<detail::is_fp_v<argTy>,
+                                    detail::ExtendedRealFPLess<argTy>,
+                                    std::less<argTy>>;
+};
+
+/*! @brief Ascending comparator for complex element types. */
+template <typename T>
+struct AscendingSorter<std::complex<T>>
+{
+    using type = detail::ExtendedComplexFPLess<std::complex<T>>;
+};
+
+/*! @brief Selects the comparator type for descending order: the NaN-aware
+ *  comparator when detail::is_fp_v holds, std::greater otherwise. */
+template <typename argTy, typename = void>
+struct DescendingSorter
+{
+    using type = std::conditional_t<detail::is_fp_v<argTy>,
+                                    detail::ExtendedRealFPGreater<argTy>,
+                                    std::greater<argTy>>;
+};
+
+/*! @brief Descending comparator for complex element types. */
+template <typename T>
+struct DescendingSorter<std::complex<T>>
+{
+    using type = detail::ExtendedComplexFPGreater<std::complex<T>>;
+};
+
+} // namespace dpctl::tensor::rich_comparisons
diff --git a/dpctl_ext/tensor/libtensor/source/accumulators/accumulate_over_axis.hpp b/dpctl_ext/tensor/libtensor/source/accumulators/accumulate_over_axis.hpp
new file mode 100644
index 000000000000..4dd00620a260
--- /dev/null
+++ b/dpctl_ext/tensor/libtensor/source/accumulators/accumulate_over_axis.hpp
@@ -0,0 +1,462 @@
+//*****************************************************************************
+// Copyright (c) 2026, Intel Corporation
+// All rights reserved.
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are met:
+// - Redistributions of source code must retain the above copyright notice,
+//   this list of conditions and the following disclaimer.
+// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. 
+//*****************************************************************************
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file
+/// This file defines functions of dpctl.tensor._tensor_accumulation_impl
+/// extensions
+//===----------------------------------------------------------------------===//
+
+#pragma once
+
+#include <cstddef>
+#include <cstdint>
+#include <exception>
+#include <iterator>
+#include <stdexcept>
+#include <tuple>
+#include <type_traits>
+#include <utility>
+#include <vector>
+
+#include <sycl/sycl.hpp>
+
+#include "dpnp4pybind11.hpp"
+#include <pybind11/pybind11.h>
+#include <pybind11/stl.h>
+
+#include "kernels/accumulators.hpp"
+#include "simplify_iteration_space.hpp"
+#include "utils/memory_overlap.hpp"
+#include "utils/offset_utils.hpp"
+#include "utils/output_validation.hpp"
+#include "utils/sycl_alloc_utils.hpp"
+#include "utils/type_dispatch.hpp"
+
+namespace dpctl::tensor::py_internal
+{
+
+namespace py = pybind11;
+namespace td_ns = dpctl::tensor::type_dispatch;
+
+/*! @brief Validates arguments and launches an accumulation of `src` into
+ *  `dst` over the trailing `trailing_dims_to_accumulate` axes.
+ *
+ *  The contiguous kernel is used only when both arrays are C-contiguous and
+ *  every axis is accumulated (no iteration axes remain); in all other cases
+ *  the strided kernel is used, with shapes and strides packed into a
+ *  temporary device allocation that is freed after the kernel completes.
+ *
+ *  @param src  input array
+ *  @param trailing_dims_to_accumulate  number of trailing axes to accumulate;
+ *         must be positive and not exceed the rank of `src`
+ *  @param dst  output array; must have the same rank and shape as `src`
+ *  @param exec_q  queue to submit work to
+ *  @param depends  events the computation must wait for
+ *  @param strided_dispatch_table  table of strided kernels indexed by
+ *         [source type id][destination type id]
+ *  @param contig_dispatch_table  table of contiguous kernels indexed the
+ *         same way
+ *  @return pair of (event returned by dpctl::utils::keep_args_alive,
+ *          accumulation event); two default-constructed events when there is
+ *          nothing to compute
+ *  @throws py::value_error on rank/shape mismatch, invalid
+ *          `trailing_dims_to_accumulate`, incompatible queues or overlapping
+ *          memory
+ *  @throws std::runtime_error when the dispatch table has no kernel for the
+ *          type pair
+ */
+template <typename strided_fnT, typename contig_fnT>
+std::pair<sycl::event, sycl::event>
+    py_accumulate_over_axis(const dpctl::tensor::usm_ndarray &src,
+                            const int trailing_dims_to_accumulate,
+                            const dpctl::tensor::usm_ndarray &dst,
+                            sycl::queue &exec_q,
+                            std::vector<sycl::event> const &depends,
+                            const strided_fnT &strided_dispatch_table,
+                            const contig_fnT &contig_dispatch_table)
+{
+    int src_nd = src.get_ndim();
+    int dst_nd = dst.get_ndim();
+    if (src_nd != dst_nd) {
+        throw py::value_error("The input and output arrays must have "
+                              "the same array ranks");
+    }
+    int iter_nd = src_nd - trailing_dims_to_accumulate;
+    if (trailing_dims_to_accumulate <= 0 || iter_nd < 0) {
+        throw py::value_error(
+            "trailing_dims_to_accumulate must be positive, but no "
+            "greater than rank of the input array");
+    }
+
+    const py::ssize_t *src_shape_ptr = src.get_shape_raw();
+    const py::ssize_t *dst_shape_ptr = dst.get_shape_raw();
+
+    // leading axes are iterated over, trailing axes are accumulated
+    bool same_shapes = true;
+    std::size_t iter_nelems(1);
+    for (int i = 0; same_shapes && (i < iter_nd); ++i) {
+        auto src_shape_i = src_shape_ptr[i];
+        same_shapes = same_shapes && (src_shape_i == dst_shape_ptr[i]);
+        iter_nelems *= static_cast<std::size_t>(src_shape_i);
+    }
+
+    std::size_t acc_nelems(1);
+    for (int i = iter_nd; same_shapes && (i < src_nd); ++i) {
+        auto dst_shape_i = dst_shape_ptr[i];
+        same_shapes = same_shapes && (src_shape_ptr[i] == dst_shape_i);
+        acc_nelems *= static_cast<std::size_t>(dst_shape_i);
+    }
+
+    if (!same_shapes) {
+        throw py::value_error(
+            "Destination shape does not match the input shape");
+    }
+
+    if (!dpctl::utils::queues_are_compatible(exec_q, {src, dst})) {
+        throw py::value_error(
+            "Execution queue is not compatible with allocation queues");
+    }
+
+    dpctl::tensor::validation::CheckWritable::throw_if_not_writable(dst);
+
+    // nothing to compute for empty arrays
+    if ((iter_nelems == 0) || (acc_nelems == 0)) {
+        return std::make_pair(sycl::event(), sycl::event());
+    }
+
+    auto const &overlap = dpctl::tensor::overlap::MemoryOverlap();
+    if (overlap(src, dst)) {
+        throw py::value_error("Arrays index overlapping segments of memory");
+    }
+
+    dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(
+        dst, acc_nelems * iter_nelems);
+
+    const char *src_data = src.get_data();
+    char *dst_data = dst.get_data();
+
+    int src_typenum = src.get_typenum();
+    int dst_typenum = dst.get_typenum();
+
+    const auto &array_types = td_ns::usm_ndarray_types();
+    int src_typeid = array_types.typenum_to_lookup_id(src_typenum);
+    int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum);
+
+    bool is_src_c_contig = src.is_c_contiguous();
+    bool is_dst_c_contig = dst.is_c_contiguous();
+
+    std::vector<sycl::event> host_task_events;
+
+    if ((is_src_c_contig && is_dst_c_contig) && iter_nd == 0) {
+        auto fn = contig_dispatch_table[src_typeid][dst_typeid];
+        if (fn == nullptr) {
+            throw std::runtime_error("Datatypes are not supported");
+        }
+
+        sycl::event acc_ev = fn(exec_q, acc_nelems, src_data, dst_data,
+                                host_task_events, depends);
+
+        return std::make_pair(
+            dpctl::utils::keep_args_alive(exec_q, {src, dst}, {acc_ev}),
+            acc_ev);
+    }
+
+    auto src_shape_vec = src.get_shape_vector();
+    auto src_strides_vec = src.get_strides_vector();
+    auto dst_strides_vec = dst.get_strides_vector();
+
+    int acc_nd = trailing_dims_to_accumulate;
+
+    using shT = std::vector<py::ssize_t>;
+    shT acc_shape(std::begin(src_shape_vec) + iter_nd, std::end(src_shape_vec));
+
+    shT acc_src_strides(std::begin(src_strides_vec) + iter_nd,
+                        std::end(src_strides_vec));
+
+    shT acc_dst_strides(std::begin(dst_strides_vec) + iter_nd,
+                        std::end(dst_strides_vec));
+
+    shT iter_src_strides(std::begin(src_strides_vec),
+                         std::begin(src_strides_vec) + iter_nd);
+
+    shT iter_dst_strides(std::begin(dst_strides_vec),
+                         std::begin(dst_strides_vec) + iter_nd);
+
+    shT simplified_iter_shape;
+    shT simplified_iter_src_strides;
+    shT simplified_iter_dst_strides;
+    py::ssize_t iter_src_offset(0);
+    py::ssize_t iter_dst_offset(0);
+
+    if (iter_nd == 0) {
+        // the strided kernel is given a single dummy iteration axis
+        iter_nd = 1;
+        simplified_iter_shape.push_back(1);
+        simplified_iter_src_strides.push_back(0);
+        simplified_iter_dst_strides.push_back(0);
+    }
+    else {
+        simplify_iteration_space(
+            iter_nd, src_shape_ptr, iter_src_strides, iter_dst_strides,
+            // output
+            simplified_iter_shape, simplified_iter_src_strides,
+            simplified_iter_dst_strides, iter_src_offset, iter_dst_offset);
+    }
+
+    // Strided implementation
+    auto strided_fn = strided_dispatch_table[src_typeid][dst_typeid];
+    if (strided_fn == nullptr) {
+        throw std::runtime_error("Datatypes are not supported");
+    }
+
+    using dpctl::tensor::offset_utils::device_allocate_and_pack;
+    auto ptr_size_event_tuple = device_allocate_and_pack<py::ssize_t>(
+        exec_q, host_task_events, simplified_iter_shape,
+        simplified_iter_src_strides, simplified_iter_dst_strides, acc_shape,
+        acc_src_strides, acc_dst_strides);
+    auto packed_shapes_and_strides_owner =
+        std::move(std::get<0>(ptr_size_event_tuple));
+    const auto &copy_shapes_strides_ev = std::get<2>(ptr_size_event_tuple);
+    const py::ssize_t *packed_shapes_and_strides =
+        packed_shapes_and_strides_owner.get();
+
+    // the packed buffer holds the iteration shape and two stride vectors
+    // first, followed by the accumulation shape and strides
+    const py::ssize_t *iter_shape_and_strides = packed_shapes_and_strides;
+    const py::ssize_t *acc_shapes_and_strides =
+        packed_shapes_and_strides + 3 * simplified_iter_shape.size();
+
+    std::vector<sycl::event> all_deps;
+    all_deps.reserve(depends.size() + 1);
+    all_deps.insert(all_deps.end(), copy_shapes_strides_ev);
+    all_deps.insert(all_deps.end(), depends.begin(), depends.end());
+
+    sycl::event acc_ev = strided_fn(
+        exec_q, iter_nelems, acc_nelems, src_data, iter_nd,
+        iter_shape_and_strides, iter_src_offset, iter_dst_offset, acc_nd,
+        acc_shapes_and_strides, dst_data, host_task_events, all_deps);
+
+    // release the packed metadata once the kernel has finished
+    sycl::event temp_cleanup_ev = dpctl::tensor::alloc_utils::async_smart_free(
+        exec_q, {acc_ev}, packed_shapes_and_strides_owner);
+    host_task_events.push_back(temp_cleanup_ev);
+
+    return std::make_pair(
+        dpctl::utils::keep_args_alive(exec_q, {src, dst}, host_task_events),
+        acc_ev);
+}
+
+/*! @brief Validates arguments and launches an accumulation over the final
+ *  axis of `src` where `dst` has one more element along that axis than
+ *  `src`.
+ *
+ *  Kernel selection and temporary handling follow the same scheme as
+ *  py_accumulate_over_axis with a single accumulated axis.
+ *
+ *  @param src  input array
+ *  @param dst  output array; same rank as `src`, same leading extents, and
+ *         a final extent equal to the final extent of `src` plus one
+ *  @param exec_q  queue to submit work to
+ *  @param depends  events the computation must wait for
+ *  @param strided_dispatch_table  table of strided kernels indexed by
+ *         [source type id][destination type id]
+ *  @param contig_dispatch_table  table of contiguous kernels indexed the
+ *         same way
+ *  @return pair of (event returned by dpctl::utils::keep_args_alive,
+ *          accumulation event); two default-constructed events when there is
+ *          nothing to compute
+ *  @throws py::value_error on rank/shape mismatch, zero-rank input,
+ *          incompatible queues or overlapping memory
+ *  @throws std::runtime_error when the dispatch table has no kernel for the
+ *          type pair
+ */
+template <typename strided_fnT, typename contig_fnT>
+std::pair<sycl::event, sycl::event> py_accumulate_final_axis_include_initial(
+    const dpctl::tensor::usm_ndarray &src,
+    const dpctl::tensor::usm_ndarray &dst,
+    sycl::queue &exec_q,
+    std::vector<sycl::event> const &depends,
+    const strided_fnT &strided_dispatch_table,
+    const contig_fnT &contig_dispatch_table)
+{
+    int src_nd = src.get_ndim();
+    int dst_nd = dst.get_ndim();
+
+    if (src_nd != dst_nd) {
+        throw py::value_error("The input and output arrays must have "
+                              "the same array ranks");
+    }
+
+    static constexpr int acc_nd = 1;
+
+    int iter_nd = src_nd - acc_nd;
+    if (iter_nd < 0) {
+        throw py::value_error("accumulation axis must not be greater than rank "
+                              "of the input array");
+    }
+
+    const py::ssize_t *src_shape_ptr = src.get_shape_raw();
+    const py::ssize_t *dst_shape_ptr = dst.get_shape_raw();
+
+    bool same_shapes = true;
+    std::size_t iter_nelems(1);
+    for (int i = 0; same_shapes && (i < iter_nd); ++i) {
+        auto src_shape_i = src_shape_ptr[i];
+        same_shapes = same_shapes && (src_shape_i == dst_shape_ptr[i]);
+        iter_nelems *= static_cast<std::size_t>(src_shape_i);
+    }
+
+    // the destination holds one extra element along the accumulated axis
+    std::size_t acc_nelems(1);
+    for (int i = iter_nd; same_shapes && (i < src_nd); ++i) {
+        auto dst_shape_i = dst_shape_ptr[i];
+        same_shapes = same_shapes && (src_shape_ptr[i] + 1 == dst_shape_i);
+        acc_nelems *= static_cast<std::size_t>(dst_shape_i);
+    }
+
+    if (!same_shapes) {
+        throw py::value_error(
+            "Destination shape does not match the input shape");
+    }
+
+    if (!dpctl::utils::queues_are_compatible(exec_q, {src, dst})) {
+        throw py::value_error(
+            "Execution queue is not compatible with allocation queues");
+    }
+
+    dpctl::tensor::validation::CheckWritable::throw_if_not_writable(dst);
+
+    if ((iter_nelems == 0) || (acc_nelems == 0)) {
+        return std::make_pair(sycl::event(), sycl::event());
+    }
+
+    auto const &overlap = dpctl::tensor::overlap::MemoryOverlap();
+    if (overlap(src, dst)) {
+        throw py::value_error("Arrays index overlapping segments of memory");
+    }
+
+    dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(
+        dst, acc_nelems * iter_nelems);
+
+    const char *src_data = src.get_data();
+    char *dst_data = dst.get_data();
+
+    int src_typenum = src.get_typenum();
+    int dst_typenum = dst.get_typenum();
+
+    const auto &array_types = td_ns::usm_ndarray_types();
+    int src_typeid = array_types.typenum_to_lookup_id(src_typenum);
+    int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum);
+
+    bool is_src_c_contig = src.is_c_contiguous();
+    bool is_dst_c_contig = dst.is_c_contiguous();
+
+    std::vector<sycl::event> host_task_events;
+
+    if ((is_src_c_contig && is_dst_c_contig) && iter_nd == 0) {
+        auto fn = contig_dispatch_table[src_typeid][dst_typeid];
+        if (fn == nullptr) {
+            throw std::runtime_error("Datatypes are not supported");
+        }
+
+        sycl::event acc_ev = fn(exec_q, acc_nelems, src_data, dst_data,
+                                host_task_events, depends);
+
+        return std::make_pair(
+            dpctl::utils::keep_args_alive(exec_q, {src, dst}, {acc_ev}),
+            acc_ev);
+    }
+
+    auto src_shape_vec = src.get_shape_vector();
+    auto src_strides_vec = src.get_strides_vector();
+    auto dst_strides_vec = dst.get_strides_vector();
+
+    using shT = std::vector<py::ssize_t>;
+    shT acc_shape(std::begin(src_shape_vec) + iter_nd, std::end(src_shape_vec));
+
+    shT acc_src_strides(std::begin(src_strides_vec) + iter_nd,
+                        std::end(src_strides_vec));
+
+    shT acc_dst_strides(std::begin(dst_strides_vec) + iter_nd,
+                        std::end(dst_strides_vec));
+
+    shT iter_src_strides(std::begin(src_strides_vec),
+                         std::begin(src_strides_vec) + iter_nd);
+
+    shT iter_dst_strides(std::begin(dst_strides_vec),
+                         std::begin(dst_strides_vec) + iter_nd);
+
+    shT simplified_iter_shape;
+    shT simplified_iter_src_strides;
+    shT simplified_iter_dst_strides;
+    py::ssize_t iter_src_offset(0);
+    py::ssize_t iter_dst_offset(0);
+
+    if (iter_nd == 0) {
+        // the strided kernel is given a single dummy iteration axis
+        iter_nd = 1;
+        simplified_iter_shape.push_back(1);
+        simplified_iter_src_strides.push_back(0);
+        simplified_iter_dst_strides.push_back(0);
+    }
+    else {
+        simplify_iteration_space(
+            iter_nd, src_shape_ptr, iter_src_strides, iter_dst_strides,
+            // output
+            simplified_iter_shape, simplified_iter_src_strides,
+            simplified_iter_dst_strides, iter_src_offset, iter_dst_offset);
+    }
+
+    // Strided implementation
+    auto strided_fn = strided_dispatch_table[src_typeid][dst_typeid];
+    if (strided_fn == nullptr) {
+        throw std::runtime_error("Datatypes are not supported");
+    }
+
+    using dpctl::tensor::offset_utils::device_allocate_and_pack;
+    auto ptr_size_event_tuple = device_allocate_and_pack<py::ssize_t>(
+        exec_q, host_task_events, simplified_iter_shape,
+        simplified_iter_src_strides, simplified_iter_dst_strides, acc_shape,
+        acc_src_strides, acc_dst_strides);
+    auto packed_shapes_and_strides_owner =
+        std::move(std::get<0>(ptr_size_event_tuple));
+    const auto &copy_shapes_strides_ev = std::get<2>(ptr_size_event_tuple);
+    const py::ssize_t *packed_shapes_and_strides =
+        packed_shapes_and_strides_owner.get();
+
+    // the packed buffer holds the iteration shape and two stride vectors
+    // first, followed by the accumulation shape and strides
+    const py::ssize_t *iter_shape_and_strides = packed_shapes_and_strides;
+    const py::ssize_t *acc_shapes_and_strides =
+        packed_shapes_and_strides + 3 * simplified_iter_shape.size();
+
+    std::vector<sycl::event> all_deps;
+    all_deps.reserve(depends.size() + 1);
+    all_deps.insert(all_deps.end(), copy_shapes_strides_ev);
+    all_deps.insert(all_deps.end(), depends.begin(), depends.end());
+
+    sycl::event acc_ev = strided_fn(
+        exec_q, iter_nelems, acc_nelems, src_data, iter_nd,
+        iter_shape_and_strides, iter_src_offset, iter_dst_offset, acc_nd,
+        acc_shapes_and_strides, dst_data, host_task_events, all_deps);
+
+    // release the packed metadata once the kernel has finished
+    sycl::event temp_cleanup_ev = dpctl::tensor::alloc_utils::async_smart_free(
+        exec_q, {acc_ev}, packed_shapes_and_strides_owner);
+    host_task_events.push_back(temp_cleanup_ev);
+
+    return std::make_pair(
+        dpctl::utils::keep_args_alive(exec_q, {src, dst}, host_task_events),
+        acc_ev);
+}
+
+/*! @brief Template implementing Python API for querying accumulation
+ *  type support
+ *
+ *  @param input_dtype  data type of the input array
+ *  @param output_dtype  data type of the output array
+ *  @param dispatch_table  table indexed by
+ *         [input type id][output type id]
+ *  @return true when the table holds a non-null entry for the type pair
+ *  @throws py::value_error when a type number cannot be mapped to a lookup id
+ *  @throws std::runtime_error when a lookup id falls outside the table
+ */
+template <typename fnT>
+bool py_accumulate_dtype_supported(const py::dtype &input_dtype,
+                                   const py::dtype &output_dtype,
+                                   const fnT &dispatch_table)
+{
+    int arg_tn =
+        input_dtype.num(); // NumPy type numbers are the same as in dpctl
+    int out_tn =
+        output_dtype.num(); // NumPy type numbers are the same as in dpctl
+    int arg_typeid = -1;
+    int out_typeid = -1;
+
+    auto array_types = td_ns::usm_ndarray_types();
+
+    try {
+        arg_typeid = array_types.typenum_to_lookup_id(arg_tn);
+        out_typeid = array_types.typenum_to_lookup_id(out_tn);
+    } catch (const std::exception &e) {
+        throw py::value_error(e.what());
+    }
+
+    if (arg_typeid < 0 || arg_typeid >= td_ns::num_types || out_typeid < 0 ||
+        out_typeid >= td_ns::num_types)
+    {
+        throw std::runtime_error(
+            "Accumulation type support check: lookup failed");
+    }
+
+    // remove_all_extents gets underlying type of table
+    using fn_ptrT = typename std::remove_all_extents<fnT>::type;
+    fn_ptrT fn = dispatch_table[arg_typeid][out_typeid];
+
+    return (fn != nullptr);
+}
+
+} // namespace dpctl::tensor::py_internal
diff --git a/dpctl_ext/tensor/libtensor/source/accumulators/accumulators_common.cpp 
b/dpctl_ext/tensor/libtensor/source/accumulators/accumulators_common.cpp new file mode 100644 index 000000000000..5e07e81b7ad5 --- /dev/null +++ b/dpctl_ext/tensor/libtensor/source/accumulators/accumulators_common.cpp @@ -0,0 +1,55 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. 
+//*****************************************************************************
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file
+/// This file defines functions of dpctl.tensor._tensor_accumulation_impl
+/// extensions
+//===----------------------------------------------------------------------===//
+
+#include <pybind11/pybind11.h>
+
+#include "cumulative_logsumexp.hpp"
+#include "cumulative_prod.hpp"
+#include "cumulative_sum.hpp"
+
+namespace py = pybind11;
+
+namespace dpctl::tensor::py_internal
+{
+
+/*! @brief Add accumulators to Python module
+ *
+ *  Calls the initializer of each cumulative operation (logsumexp, prod, sum)
+ *  with the module object `m`.
+ */
+void init_accumulator_functions(py::module_ m)
+{
+    init_cumulative_logsumexp(m);
+    init_cumulative_prod(m);
+    init_cumulative_sum(m);
+}
+
+} // namespace dpctl::tensor::py_internal
diff --git a/dpctl_ext/tensor/libtensor/source/accumulators/accumulators_common.hpp b/dpctl_ext/tensor/libtensor/source/accumulators/accumulators_common.hpp
new file mode 100644
index 000000000000..c33a040a7fa7
--- /dev/null
+++ b/dpctl_ext/tensor/libtensor/source/accumulators/accumulators_common.hpp
@@ -0,0 +1,46 @@
+//*****************************************************************************
+// Copyright (c) 2026, Intel Corporation
+// All rights reserved.
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are met:
+// - Redistributions of source code must retain the above copyright notice,
+//   this list of conditions and the following disclaimer.
+// - Redistributions in binary form must reproduce the above copyright notice,
+//   this list of conditions and the following disclaimer in the documentation
+//   and/or other materials provided with the distribution.
+// - Neither the name of the copyright holder nor the names of its contributors
+//   may be used to endorse or promote products derived from this software
+//   without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
+// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
+// THE POSSIBILITY OF SUCH DAMAGE.
+//*****************************************************************************
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file
+/// This file defines functions of dpctl.tensor._tensor_accumulation_impl
+/// extensions
+//===----------------------------------------------------------------------===//
+
+#pragma once
+#include <pybind11/pybind11.h>
+
+namespace py = pybind11;
+
+namespace dpctl::tensor::py_internal
+{
+
+/*! @brief Declaration of the function that adds accumulators to a Python
+ *  module; takes the module object by value. */
+extern void init_accumulator_functions(py::module_);
+
+} // namespace dpctl::tensor::py_internal
diff --git a/dpctl_ext/tensor/libtensor/source/accumulators/cumulative_logsumexp.cpp b/dpctl_ext/tensor/libtensor/source/accumulators/cumulative_logsumexp.cpp
new file mode 100644
index 000000000000..e24cf56ddd62
--- /dev/null
+++ b/dpctl_ext/tensor/libtensor/source/accumulators/cumulative_logsumexp.cpp
@@ -0,0 +1,347 @@
+//*****************************************************************************
+// Copyright (c) 2026, Intel Corporation
+// All rights reserved.
+// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. 
+//***************************************************************************** +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_accumulation_impl +// extensions +//===----------------------------------------------------------------------===// + +#include +#include +#include + +#include + +#include "dpnp4pybind11.hpp" +#include +#include +#include + +#include "accumulate_over_axis.hpp" +#include "kernels/accumulators.hpp" +#include "utils/sycl_utils.hpp" +#include "utils/type_dispatch_building.hpp" + +namespace py = pybind11; + +namespace dpctl::tensor::py_internal +{ + +namespace su_ns = dpctl::tensor::sycl_utils; +namespace td_ns = dpctl::tensor::type_dispatch; + +namespace impl +{ + +using dpctl::tensor::kernels::accumulators::accumulate_1d_contig_impl_fn_ptr_t; +static accumulate_1d_contig_impl_fn_ptr_t + cumlogsumexp_1d_contig_dispatch_table[td_ns::num_types][td_ns::num_types]; + +using dpctl::tensor::kernels::accumulators::accumulate_strided_impl_fn_ptr_t; +static accumulate_strided_impl_fn_ptr_t + cumlogsumexp_strided_dispatch_table[td_ns::num_types][td_ns::num_types]; + +static accumulate_1d_contig_impl_fn_ptr_t + cumlogsumexp_1d_include_initial_contig_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + +static accumulate_strided_impl_fn_ptr_t + cumlogsumexp_include_initial_strided_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + +template +struct TypePairSupportDataForLogSumExpAccumulation +{ + static constexpr bool is_defined = std::disjunction< + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input int8_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input uint8_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input int16_t + td_ns::TypePairDefinedEntry, + 
td_ns::TypePairDefinedEntry, + + // input uint16_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input int32_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input uint32_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input int64_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input uint64_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input half + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input float + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input double + td_ns::TypePairDefinedEntry, + + // fall-through + td_ns::NotDefinedEntry>::is_defined; +}; + +template +struct CumLogSumExp1DContigFactory +{ + fnT get() + { + if constexpr (TypePairSupportDataForLogSumExpAccumulation< + srcTy, dstTy>::is_defined) + { + using ScanOpT = su_ns::LogSumExp; + static constexpr bool include_initial = false; + if constexpr (std::is_same_v) { + using dpctl::tensor::kernels::accumulators::NoOpTransformer; + fnT fn = dpctl::tensor::kernels::accumulators:: + accumulate_1d_contig_impl, ScanOpT, + include_initial>; + return fn; + } + else { + using dpctl::tensor::kernels::accumulators::CastTransformer; + fnT fn = dpctl::tensor::kernels::accumulators:: + accumulate_1d_contig_impl, + ScanOpT, include_initial>; + return fn; + } + } + else { + return nullptr; + } + } +}; + +template +struct CumLogSumExp1DIncludeInitialContigFactory +{ + fnT get() + { + if constexpr (TypePairSupportDataForLogSumExpAccumulation< + srcTy, dstTy>::is_defined) + { + using ScanOpT = su_ns::LogSumExp; + static constexpr bool include_initial = true; + if constexpr (std::is_same_v) { + using dpctl::tensor::kernels::accumulators::NoOpTransformer; + fnT fn = dpctl::tensor::kernels::accumulators:: + accumulate_1d_contig_impl, ScanOpT, + include_initial>; + return fn; + } + else { + using 
dpctl::tensor::kernels::accumulators::CastTransformer; + fnT fn = dpctl::tensor::kernels::accumulators:: + accumulate_1d_contig_impl, + ScanOpT, include_initial>; + return fn; + } + } + else { + return nullptr; + } + } +}; + +template +struct CumLogSumExpStridedFactory +{ + fnT get() + { + if constexpr (TypePairSupportDataForLogSumExpAccumulation< + srcTy, dstTy>::is_defined) + { + using ScanOpT = su_ns::LogSumExp; + static constexpr bool include_initial = false; + if constexpr (std::is_same_v) { + using dpctl::tensor::kernels::accumulators::NoOpTransformer; + fnT fn = dpctl::tensor::kernels::accumulators:: + accumulate_strided_impl, ScanOpT, + include_initial>; + return fn; + } + else { + using dpctl::tensor::kernels::accumulators::CastTransformer; + fnT fn = dpctl::tensor::kernels::accumulators:: + accumulate_strided_impl, + ScanOpT, include_initial>; + return fn; + } + } + else { + return nullptr; + } + } +}; + +template +struct CumLogSumExpIncludeInitialStridedFactory +{ + fnT get() + { + if constexpr (TypePairSupportDataForLogSumExpAccumulation< + srcTy, dstTy>::is_defined) + { + using ScanOpT = su_ns::LogSumExp; + static constexpr bool include_initial = true; + if constexpr (std::is_same_v) { + using dpctl::tensor::kernels::accumulators::NoOpTransformer; + fnT fn = dpctl::tensor::kernels::accumulators:: + accumulate_strided_impl, ScanOpT, + include_initial>; + return fn; + } + else { + using dpctl::tensor::kernels::accumulators::CastTransformer; + fnT fn = dpctl::tensor::kernels::accumulators:: + accumulate_strided_impl, + ScanOpT, include_initial>; + return fn; + } + } + else { + return nullptr; + } + } +}; + +void populate_cumlogsumexp_dispatch_tables(void) +{ + td_ns::DispatchTableBuilder + dtb1; + dtb1.populate_dispatch_table(cumlogsumexp_1d_contig_dispatch_table); + + td_ns::DispatchTableBuilder + dtb2; + dtb2.populate_dispatch_table(cumlogsumexp_strided_dispatch_table); + + td_ns::DispatchTableBuilder + dtb3; + dtb3.populate_dispatch_table( + 
cumlogsumexp_1d_include_initial_contig_dispatch_table); + + td_ns::DispatchTableBuilder + dtb4; + dtb4.populate_dispatch_table( + cumlogsumexp_include_initial_strided_dispatch_table); + + return; +} + +} // namespace impl + +void init_cumulative_logsumexp(py::module_ m) +{ + using arrayT = dpctl::tensor::usm_ndarray; + using event_vecT = std::vector; + + using impl::populate_cumlogsumexp_dispatch_tables; + populate_cumlogsumexp_dispatch_tables(); + + using impl::cumlogsumexp_1d_contig_dispatch_table; + using impl::cumlogsumexp_strided_dispatch_table; + auto cumlogsumexp_pyapi = [&](const arrayT &src, + int trailing_dims_to_accumulate, + const arrayT &dst, sycl::queue &exec_q, + const event_vecT &depends = {}) { + return py_accumulate_over_axis(src, trailing_dims_to_accumulate, dst, + exec_q, depends, + cumlogsumexp_strided_dispatch_table, + cumlogsumexp_1d_contig_dispatch_table); + }; + m.def("_cumlogsumexp_over_axis", cumlogsumexp_pyapi, "", py::arg("src"), + py::arg("trailing_dims_to_accumulate"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); + + using impl::cumlogsumexp_1d_include_initial_contig_dispatch_table; + using impl::cumlogsumexp_include_initial_strided_dispatch_table; + auto cumlogsumexp_include_initial_pyapi = + [&](const arrayT &src, const arrayT &dst, sycl::queue &exec_q, + const event_vecT &depends = {}) { + return py_accumulate_final_axis_include_initial( + src, dst, exec_q, depends, + cumlogsumexp_include_initial_strided_dispatch_table, + cumlogsumexp_1d_include_initial_contig_dispatch_table); + }; + m.def("_cumlogsumexp_final_axis_include_initial", + cumlogsumexp_include_initial_pyapi, "", py::arg("src"), + py::arg("dst"), py::arg("sycl_queue"), + py::arg("depends") = py::list()); + + auto cumlogsumexp_dtype_supported = [&](const py::dtype &input_dtype, + const py::dtype &output_dtype) { + return py_accumulate_dtype_supported( + input_dtype, output_dtype, cumlogsumexp_strided_dispatch_table); + }; + 
m.def("_cumlogsumexp_dtype_supported", cumlogsumexp_dtype_supported, "", + py::arg("arg_dtype"), py::arg("out_dtype")); +} + +} // namespace dpctl::tensor::py_internal diff --git a/dpctl_ext/tensor/libtensor/source/accumulators/cumulative_logsumexp.hpp b/dpctl_ext/tensor/libtensor/source/accumulators/cumulative_logsumexp.hpp new file mode 100644 index 000000000000..f1292320bd0d --- /dev/null +++ b/dpctl_ext/tensor/libtensor/source/accumulators/cumulative_logsumexp.hpp @@ -0,0 +1,46 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_accumulation_impl +// extensions +//===----------------------------------------------------------------------===// + +#pragma once +#include + +namespace py = pybind11; + +namespace dpctl::tensor::py_internal +{ + +extern void init_cumulative_logsumexp(py::module_ m); + +} // namespace dpctl::tensor::py_internal diff --git a/dpctl_ext/tensor/libtensor/source/accumulators/cumulative_prod.cpp b/dpctl_ext/tensor/libtensor/source/accumulators/cumulative_prod.cpp new file mode 100644 index 000000000000..65f3c311eda1 --- /dev/null +++ b/dpctl_ext/tensor/libtensor/source/accumulators/cumulative_prod.cpp @@ -0,0 +1,356 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. 
+// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. 
+//***************************************************************************** +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_accumulation_impl +// extensions +//===----------------------------------------------------------------------===// + +#include +#include +#include +#include + +#include + +#include "dpnp4pybind11.hpp" +#include +#include +#include + +#include "accumulate_over_axis.hpp" +#include "kernels/accumulators.hpp" +#include "utils/type_dispatch_building.hpp" + +namespace py = pybind11; + +namespace dpctl::tensor::py_internal +{ + +namespace td_ns = dpctl::tensor::type_dispatch; + +namespace impl +{ + +using dpctl::tensor::kernels::accumulators::accumulate_1d_contig_impl_fn_ptr_t; +static accumulate_1d_contig_impl_fn_ptr_t + cumprod_1d_contig_dispatch_table[td_ns::num_types][td_ns::num_types]; + +using dpctl::tensor::kernels::accumulators::accumulate_strided_impl_fn_ptr_t; +static accumulate_strided_impl_fn_ptr_t + cumprod_strided_dispatch_table[td_ns::num_types][td_ns::num_types]; + +static accumulate_1d_contig_impl_fn_ptr_t + cumprod_1d_include_initial_contig_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + +static accumulate_strided_impl_fn_ptr_t + cumprod_include_initial_strided_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + +template +struct TypePairSupportDataForProdAccumulation +{ + static constexpr bool is_defined = std::disjunction< + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input int8_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input uint8_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input int16_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input uint16_t + td_ns::TypePairDefinedEntry, + 
td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input int32_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input uint32_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input int64_t + td_ns::TypePairDefinedEntry, + + // input uint64_t + td_ns::TypePairDefinedEntry, + + // input half + td_ns::TypePairDefinedEntry, + + // input float + td_ns::TypePairDefinedEntry, + + // input double + td_ns::TypePairDefinedEntry, + + // input std::complex + td_ns::TypePairDefinedEntry, + outTy, + std::complex>, + + td_ns::TypePairDefinedEntry, + outTy, + std::complex>, + + // fall-through + td_ns::NotDefinedEntry>::is_defined; +}; + +template +using CumProdScanOpT = std::conditional_t, + sycl::logical_and, + sycl::multiplies>; + +template +struct CumProd1DContigFactory +{ + fnT get() + { + if constexpr (TypePairSupportDataForProdAccumulation::is_defined) + { + using ScanOpT = CumProdScanOpT; + static constexpr bool include_initial = false; + if constexpr (std::is_same_v) { + using dpctl::tensor::kernels::accumulators::NoOpTransformer; + fnT fn = dpctl::tensor::kernels::accumulators:: + accumulate_1d_contig_impl, ScanOpT, + include_initial>; + return fn; + } + else { + using dpctl::tensor::kernels::accumulators::CastTransformer; + fnT fn = dpctl::tensor::kernels::accumulators:: + accumulate_1d_contig_impl, + ScanOpT, include_initial>; + return fn; + } + } + else { + return nullptr; + } + } +}; + +template +struct CumProd1DIncludeInitialContigFactory +{ + fnT get() + { + if constexpr (TypePairSupportDataForProdAccumulation::is_defined) + { + using ScanOpT = CumProdScanOpT; + static constexpr bool include_initial = true; + if constexpr (std::is_same_v) { + using dpctl::tensor::kernels::accumulators::NoOpTransformer; + fnT fn = dpctl::tensor::kernels::accumulators:: + accumulate_1d_contig_impl, ScanOpT, + include_initial>; + return fn; + } + else { + using dpctl::tensor::kernels::accumulators::CastTransformer; + fnT fn = 
dpctl::tensor::kernels::accumulators:: + accumulate_1d_contig_impl, + ScanOpT, include_initial>; + return fn; + } + } + else { + return nullptr; + } + } +}; + +template +struct CumProdStridedFactory +{ + fnT get() + { + if constexpr (TypePairSupportDataForProdAccumulation::is_defined) + { + using ScanOpT = CumProdScanOpT; + static constexpr bool include_initial = false; + if constexpr (std::is_same_v) { + using dpctl::tensor::kernels::accumulators::NoOpTransformer; + fnT fn = dpctl::tensor::kernels::accumulators:: + accumulate_strided_impl, ScanOpT, + include_initial>; + return fn; + } + else { + using dpctl::tensor::kernels::accumulators::CastTransformer; + fnT fn = dpctl::tensor::kernels::accumulators:: + accumulate_strided_impl, + ScanOpT, include_initial>; + return fn; + } + } + else { + return nullptr; + } + } +}; + +template +struct CumProdIncludeInitialStridedFactory +{ + fnT get() + { + if constexpr (TypePairSupportDataForProdAccumulation::is_defined) + { + using ScanOpT = CumProdScanOpT; + static constexpr bool include_initial = true; + if constexpr (std::is_same_v) { + using dpctl::tensor::kernels::accumulators::NoOpTransformer; + fnT fn = dpctl::tensor::kernels::accumulators:: + accumulate_strided_impl, ScanOpT, + include_initial>; + return fn; + } + else { + using dpctl::tensor::kernels::accumulators::CastTransformer; + fnT fn = dpctl::tensor::kernels::accumulators:: + accumulate_strided_impl, + ScanOpT, include_initial>; + return fn; + } + } + else { + return nullptr; + } + } +}; + +void populate_cumprod_dispatch_tables(void) +{ + td_ns::DispatchTableBuilder + dtb1; + dtb1.populate_dispatch_table(cumprod_1d_contig_dispatch_table); + + td_ns::DispatchTableBuilder + dtb2; + dtb2.populate_dispatch_table(cumprod_strided_dispatch_table); + + td_ns::DispatchTableBuilder + dtb3; + dtb3.populate_dispatch_table( + cumprod_1d_include_initial_contig_dispatch_table); + + td_ns::DispatchTableBuilder + dtb4; + dtb4.populate_dispatch_table( + 
cumprod_include_initial_strided_dispatch_table); + + return; +} + +} // namespace impl + +void init_cumulative_prod(py::module_ m) +{ + using arrayT = dpctl::tensor::usm_ndarray; + using event_vecT = std::vector; + + using impl::populate_cumprod_dispatch_tables; + populate_cumprod_dispatch_tables(); + + using impl::cumprod_1d_contig_dispatch_table; + using impl::cumprod_strided_dispatch_table; + auto cumprod_pyapi = [&](const arrayT &src, int trailing_dims_to_accumulate, + const arrayT &dst, sycl::queue &exec_q, + const event_vecT &depends = {}) { + return py_accumulate_over_axis( + src, trailing_dims_to_accumulate, dst, exec_q, depends, + cumprod_strided_dispatch_table, cumprod_1d_contig_dispatch_table); + }; + m.def("_cumprod_over_axis", cumprod_pyapi, "", py::arg("src"), + py::arg("trailing_dims_to_accumulate"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); + + using impl::cumprod_1d_include_initial_contig_dispatch_table; + using impl::cumprod_include_initial_strided_dispatch_table; + auto cumprod_include_initial_pyapi = + [&](const arrayT &src, const arrayT &dst, sycl::queue &exec_q, + const event_vecT &depends = {}) { + return py_accumulate_final_axis_include_initial( + src, dst, exec_q, depends, + cumprod_include_initial_strided_dispatch_table, + cumprod_1d_include_initial_contig_dispatch_table); + }; + m.def("_cumprod_final_axis_include_initial", cumprod_include_initial_pyapi, + "", py::arg("src"), py::arg("dst"), py::arg("sycl_queue"), + py::arg("depends") = py::list()); + + auto cumprod_dtype_supported = [&](const py::dtype &input_dtype, + const py::dtype &output_dtype) { + return py_accumulate_dtype_supported(input_dtype, output_dtype, + cumprod_strided_dispatch_table); + }; + m.def("_cumprod_dtype_supported", cumprod_dtype_supported, "", + py::arg("arg_dtype"), py::arg("out_dtype")); +} + +} // namespace dpctl::tensor::py_internal diff --git a/dpctl_ext/tensor/libtensor/source/accumulators/cumulative_prod.hpp 
b/dpctl_ext/tensor/libtensor/source/accumulators/cumulative_prod.hpp new file mode 100644 index 000000000000..e14bb2c44361 --- /dev/null +++ b/dpctl_ext/tensor/libtensor/source/accumulators/cumulative_prod.hpp @@ -0,0 +1,46 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. 
+//***************************************************************************** +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_accumulation_impl +// extensions +//===----------------------------------------------------------------------===// + +#pragma once +#include + +namespace py = pybind11; + +namespace dpctl::tensor::py_internal +{ + +extern void init_cumulative_prod(py::module_ m); + +} // namespace dpctl::tensor::py_internal diff --git a/dpctl_ext/tensor/libtensor/source/accumulators/cumulative_sum.cpp b/dpctl_ext/tensor/libtensor/source/accumulators/cumulative_sum.cpp new file mode 100644 index 000000000000..60b46946acc9 --- /dev/null +++ b/dpctl_ext/tensor/libtensor/source/accumulators/cumulative_sum.cpp @@ -0,0 +1,354 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_accumulation_impl +// extensions +//===----------------------------------------------------------------------===// + +#include +#include +#include +#include + +#include + +#include "dpnp4pybind11.hpp" +#include +#include +#include + +#include "accumulate_over_axis.hpp" +#include "kernels/accumulators.hpp" +#include "utils/type_dispatch_building.hpp" + +namespace py = pybind11; + +namespace dpctl::tensor::py_internal +{ + +namespace td_ns = dpctl::tensor::type_dispatch; + +namespace impl +{ + +using dpctl::tensor::kernels::accumulators::accumulate_1d_contig_impl_fn_ptr_t; +static accumulate_1d_contig_impl_fn_ptr_t + cumsum_1d_contig_dispatch_table[td_ns::num_types][td_ns::num_types]; + +using dpctl::tensor::kernels::accumulators::accumulate_strided_impl_fn_ptr_t; +static accumulate_strided_impl_fn_ptr_t + cumsum_strided_dispatch_table[td_ns::num_types][td_ns::num_types]; + +static accumulate_1d_contig_impl_fn_ptr_t + cumsum_1d_include_initial_contig_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + +static accumulate_strided_impl_fn_ptr_t + cumsum_include_initial_strided_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + +template +struct TypePairSupportDataForSumAccumulation +{ + 
static constexpr bool is_defined = std::disjunction< + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input int8_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input uint8_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input int16_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input uint16_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input int32_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input uint32_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input int64_t + td_ns::TypePairDefinedEntry, + + // input uint64_t + td_ns::TypePairDefinedEntry, + + // input half + td_ns::TypePairDefinedEntry, + + // input float + td_ns::TypePairDefinedEntry, + + // input double + td_ns::TypePairDefinedEntry, + + // input std::complex + td_ns::TypePairDefinedEntry, + outTy, + std::complex>, + + td_ns::TypePairDefinedEntry, + outTy, + std::complex>, + + // fall-through + td_ns::NotDefinedEntry>::is_defined; +}; + +template +using CumSumScanOpT = std:: + conditional_t, sycl::logical_or, sycl::plus>; + +template +struct CumSum1DContigFactory +{ + fnT get() + { + if constexpr (TypePairSupportDataForSumAccumulation::is_defined) + { + using ScanOpT = CumSumScanOpT; + static constexpr bool include_initial = false; + if constexpr (std::is_same_v) { + using dpctl::tensor::kernels::accumulators::NoOpTransformer; + fnT fn = dpctl::tensor::kernels::accumulators:: + accumulate_1d_contig_impl, ScanOpT, + include_initial>; + return fn; + } + else { + using dpctl::tensor::kernels::accumulators::CastTransformer; + fnT fn = dpctl::tensor::kernels::accumulators:: + accumulate_1d_contig_impl, + ScanOpT, include_initial>; + return fn; + } + } + else { + return nullptr; + } + } +}; + 
+template +struct CumSum1DIncludeInitialContigFactory +{ + fnT get() + { + if constexpr (TypePairSupportDataForSumAccumulation::is_defined) + { + using ScanOpT = CumSumScanOpT; + static constexpr bool include_initial = true; + if constexpr (std::is_same_v) { + using dpctl::tensor::kernels::accumulators::NoOpTransformer; + fnT fn = dpctl::tensor::kernels::accumulators:: + accumulate_1d_contig_impl, ScanOpT, + include_initial>; + return fn; + } + else { + using dpctl::tensor::kernels::accumulators::CastTransformer; + fnT fn = dpctl::tensor::kernels::accumulators:: + accumulate_1d_contig_impl, + ScanOpT, include_initial>; + return fn; + } + } + else { + return nullptr; + } + } +}; + +template +struct CumSumStridedFactory +{ + fnT get() + { + if constexpr (TypePairSupportDataForSumAccumulation::is_defined) + { + using ScanOpT = CumSumScanOpT; + static constexpr bool include_initial = false; + if constexpr (std::is_same_v) { + using dpctl::tensor::kernels::accumulators::NoOpTransformer; + fnT fn = dpctl::tensor::kernels::accumulators:: + accumulate_strided_impl, ScanOpT, + include_initial>; + return fn; + } + else { + using dpctl::tensor::kernels::accumulators::CastTransformer; + fnT fn = dpctl::tensor::kernels::accumulators:: + accumulate_strided_impl, + ScanOpT, include_initial>; + return fn; + } + } + else { + return nullptr; + } + } +}; + +template +struct CumSumIncludeInitialStridedFactory +{ + fnT get() + { + if constexpr (TypePairSupportDataForSumAccumulation::is_defined) + { + using ScanOpT = CumSumScanOpT; + static constexpr bool include_initial = true; + if constexpr (std::is_same_v) { + using dpctl::tensor::kernels::accumulators::NoOpTransformer; + fnT fn = dpctl::tensor::kernels::accumulators:: + accumulate_strided_impl, ScanOpT, + include_initial>; + return fn; + } + else { + using dpctl::tensor::kernels::accumulators::CastTransformer; + fnT fn = dpctl::tensor::kernels::accumulators:: + accumulate_strided_impl, + ScanOpT, include_initial>; + return fn; + } 
+ } + else { + return nullptr; + } + } +}; + +void populate_cumsum_dispatch_tables(void) +{ + td_ns::DispatchTableBuilder + dtb1; + dtb1.populate_dispatch_table(cumsum_1d_contig_dispatch_table); + + td_ns::DispatchTableBuilder + dtb2; + dtb2.populate_dispatch_table(cumsum_strided_dispatch_table); + + td_ns::DispatchTableBuilder + dtb3; + dtb3.populate_dispatch_table( + cumsum_1d_include_initial_contig_dispatch_table); + + td_ns::DispatchTableBuilder + dtb4; + dtb4.populate_dispatch_table(cumsum_include_initial_strided_dispatch_table); + + return; +} + +} // namespace impl + +void init_cumulative_sum(py::module_ m) +{ + using arrayT = dpctl::tensor::usm_ndarray; + using event_vecT = std::vector; + + using impl::populate_cumsum_dispatch_tables; + populate_cumsum_dispatch_tables(); + + using impl::cumsum_1d_contig_dispatch_table; + using impl::cumsum_strided_dispatch_table; + auto cumsum_pyapi = [&](const arrayT &src, int trailing_dims_to_accumulate, + const arrayT &dst, sycl::queue &exec_q, + const event_vecT &depends = {}) { + return py_accumulate_over_axis( + src, trailing_dims_to_accumulate, dst, exec_q, depends, + cumsum_strided_dispatch_table, cumsum_1d_contig_dispatch_table); + }; + m.def("_cumsum_over_axis", cumsum_pyapi, "", py::arg("src"), + py::arg("trailing_dims_to_accumulate"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); + + using impl::cumsum_1d_include_initial_contig_dispatch_table; + using impl::cumsum_include_initial_strided_dispatch_table; + auto cumsum_include_initial_pyapi = + [&](const arrayT &src, const arrayT &dst, sycl::queue &exec_q, + const event_vecT &depends = {}) { + return py_accumulate_final_axis_include_initial( + src, dst, exec_q, depends, + cumsum_include_initial_strided_dispatch_table, + cumsum_1d_include_initial_contig_dispatch_table); + }; + m.def("_cumsum_final_axis_include_initial", cumsum_include_initial_pyapi, + "", py::arg("src"), py::arg("dst"), py::arg("sycl_queue"), + py::arg("depends") = 
py::list()); + + auto cumsum_dtype_supported = [&](const py::dtype &input_dtype, + const py::dtype &output_dtype) { + return py_accumulate_dtype_supported(input_dtype, output_dtype, + cumsum_strided_dispatch_table); + }; + m.def("_cumsum_dtype_supported", cumsum_dtype_supported, "", + py::arg("arg_dtype"), py::arg("out_dtype")); +} + +} // namespace dpctl::tensor::py_internal diff --git a/dpctl_ext/tensor/libtensor/source/accumulators/cumulative_sum.hpp b/dpctl_ext/tensor/libtensor/source/accumulators/cumulative_sum.hpp new file mode 100644 index 000000000000..5e06b222a3bc --- /dev/null +++ b/dpctl_ext/tensor/libtensor/source/accumulators/cumulative_sum.hpp @@ -0,0 +1,46 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_accumulation_impl +// extensions +//===----------------------------------------------------------------------===// + +#pragma once +#include + +namespace py = pybind11; + +namespace dpctl::tensor::py_internal +{ + +extern void init_cumulative_sum(py::module_ m); + +} // namespace dpctl::tensor::py_internal diff --git a/dpctl_ext/tensor/libtensor/source/reductions/all.cpp b/dpctl_ext/tensor/libtensor/source/reductions/all.cpp new file mode 100644 index 000000000000..a901b9e1d9a3 --- /dev/null +++ b/dpctl_ext/tensor/libtensor/source/reductions/all.cpp @@ -0,0 +1,164 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. 
+// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** +// +//===---------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_reductions_impl +/// extension. 
+//===---------------------------------------------------------------------===// + +#include +#include + +#include + +#include "dpnp4pybind11.hpp" +#include +#include + +#include "kernels/reductions.hpp" +#include "reduction_atomic_support.hpp" +#include "reduction_over_axis.hpp" +#include "utils/type_dispatch.hpp" + +namespace dpctl::tensor::py_internal +{ + +namespace py = pybind11; +namespace td_ns = dpctl::tensor::type_dispatch; + +namespace impl +{ + +using dpctl::tensor::kernels::reduction_strided_impl_fn_ptr; +static reduction_strided_impl_fn_ptr + all_reduction_strided_dispatch_vector[td_ns::num_types]; + +using dpctl::tensor::kernels::reduction_contig_impl_fn_ptr; +static reduction_contig_impl_fn_ptr + all_reduction_axis1_contig_dispatch_vector[td_ns::num_types]; +static reduction_contig_impl_fn_ptr + all_reduction_axis0_contig_dispatch_vector[td_ns::num_types]; + +template +struct AllStridedFactory +{ + fnT get() const + { + using dstTy = std::int32_t; + using ReductionOpT = sycl::logical_and; + return dpctl::tensor::kernels:: + reduction_over_group_with_atomics_strided_impl; + } +}; + +template +struct AllAxis1ContigFactory +{ + fnT get() const + { + using dstTy = std::int32_t; + using ReductionOpT = sycl::logical_and; + return dpctl::tensor::kernels:: + reduction_axis1_over_group_with_atomics_contig_impl; + } +}; + +template +struct AllAxis0ContigFactory +{ + fnT get() const + { + using dstTy = std::int32_t; + using ReductionOpT = sycl::logical_and; + return dpctl::tensor::kernels:: + reduction_axis0_over_group_with_atomics_contig_impl; + } +}; + +void populate_all_dispatch_vectors(void) +{ + using td_ns::DispatchVectorBuilder; + + DispatchVectorBuilder + all_dvb1; + all_dvb1.populate_dispatch_vector(all_reduction_strided_dispatch_vector); + + DispatchVectorBuilder + all_dvb2; + all_dvb2.populate_dispatch_vector( + all_reduction_axis1_contig_dispatch_vector); + + DispatchVectorBuilder + all_dvb3; + all_dvb3.populate_dispatch_vector( + 
all_reduction_axis0_contig_dispatch_vector); +}; + +using atomic_support::atomic_support_fn_ptr_t; +using atomic_support::check_atomic_support; +static atomic_support_fn_ptr_t all_atomic_support = + check_atomic_support; + +} // namespace impl + +void init_all(py::module_ m) +{ + using arrayT = dpctl::tensor::usm_ndarray; + using event_vecT = std::vector; + { + impl::populate_all_dispatch_vectors(); + using impl::all_reduction_axis0_contig_dispatch_vector; + using impl::all_reduction_axis1_contig_dispatch_vector; + using impl::all_reduction_strided_dispatch_vector; + + using impl::all_atomic_support; + + auto all_pyapi = [&](const arrayT &src, int trailing_dims_to_reduce, + const arrayT &dst, sycl::queue &exec_q, + const event_vecT &depends = {}) { + return py_boolean_reduction( + src, trailing_dims_to_reduce, dst, exec_q, depends, + all_reduction_axis1_contig_dispatch_vector, + all_reduction_axis0_contig_dispatch_vector, + all_reduction_strided_dispatch_vector, all_atomic_support); + }; + m.def("_all", all_pyapi, "", py::arg("src"), + py::arg("trailing_dims_to_reduce"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); + } +} + +} // namespace dpctl::tensor::py_internal diff --git a/dpctl_ext/tensor/libtensor/source/reductions/all.hpp b/dpctl_ext/tensor/libtensor/source/reductions/all.hpp new file mode 100644 index 000000000000..5fb184e37c66 --- /dev/null +++ b/dpctl_ext/tensor/libtensor/source/reductions/all.hpp @@ -0,0 +1,46 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. 
+// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** +// +//===---------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_reductions_impl +/// extension. 
+//===---------------------------------------------------------------------===// + +#pragma once +#include + +namespace py = pybind11; + +namespace dpctl::tensor::py_internal +{ + +extern void init_all(py::module_ m); + +} // namespace dpctl::tensor::py_internal diff --git a/dpctl_ext/tensor/libtensor/source/reductions/any.cpp b/dpctl_ext/tensor/libtensor/source/reductions/any.cpp new file mode 100644 index 000000000000..6859e46cbc4a --- /dev/null +++ b/dpctl_ext/tensor/libtensor/source/reductions/any.cpp @@ -0,0 +1,164 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** +// +//===---------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_reductions_impl +/// extension. +//===---------------------------------------------------------------------===// + +#include +#include + +#include + +#include "dpnp4pybind11.hpp" +#include +#include + +#include "kernels/reductions.hpp" +#include "reduction_atomic_support.hpp" +#include "reduction_over_axis.hpp" +#include "utils/type_dispatch.hpp" + +namespace dpctl::tensor::py_internal +{ + +namespace py = pybind11; +namespace td_ns = dpctl::tensor::type_dispatch; + +namespace impl +{ + +using dpctl::tensor::kernels::reduction_strided_impl_fn_ptr; +static reduction_strided_impl_fn_ptr + any_reduction_strided_dispatch_vector[td_ns::num_types]; + +using dpctl::tensor::kernels::reduction_contig_impl_fn_ptr; +static reduction_contig_impl_fn_ptr + any_reduction_axis1_contig_dispatch_vector[td_ns::num_types]; +static reduction_contig_impl_fn_ptr + any_reduction_axis0_contig_dispatch_vector[td_ns::num_types]; + +template +struct AnyStridedFactory +{ + fnT get() const + { + using dstTy = std::int32_t; + using ReductionOpT = sycl::logical_or; + return dpctl::tensor::kernels:: + reduction_over_group_with_atomics_strided_impl; + } +}; + +template +struct AnyAxis1ContigFactory +{ + fnT get() const + { + using 
dstTy = std::int32_t; + using ReductionOpT = sycl::logical_or; + return dpctl::tensor::kernels:: + reduction_axis1_over_group_with_atomics_contig_impl; + } +}; + +template +struct AnyAxis0ContigFactory +{ + fnT get() const + { + using dstTy = std::int32_t; + using ReductionOpT = sycl::logical_or; + return dpctl::tensor::kernels:: + reduction_axis0_over_group_with_atomics_contig_impl; + } +}; + +void populate_any_dispatch_vectors(void) +{ + using td_ns::DispatchVectorBuilder; + + DispatchVectorBuilder + any_dvb1; + any_dvb1.populate_dispatch_vector(any_reduction_strided_dispatch_vector); + + DispatchVectorBuilder + any_dvb2; + any_dvb2.populate_dispatch_vector( + any_reduction_axis1_contig_dispatch_vector); + + DispatchVectorBuilder + any_dvb3; + any_dvb3.populate_dispatch_vector( + any_reduction_axis0_contig_dispatch_vector); +}; + +using atomic_support::atomic_support_fn_ptr_t; +using atomic_support::check_atomic_support; +static atomic_support_fn_ptr_t any_atomic_support = + check_atomic_support; + +} // namespace impl + +void init_any(py::module_ m) +{ + using arrayT = dpctl::tensor::usm_ndarray; + using event_vecT = std::vector; + { + impl::populate_any_dispatch_vectors(); + using impl::any_reduction_axis0_contig_dispatch_vector; + using impl::any_reduction_axis1_contig_dispatch_vector; + using impl::any_reduction_strided_dispatch_vector; + + using impl::any_atomic_support; + + auto any_pyapi = [&](const arrayT &src, int trailing_dims_to_reduce, + const arrayT &dst, sycl::queue &exec_q, + const event_vecT &depends = {}) { + return py_boolean_reduction( + src, trailing_dims_to_reduce, dst, exec_q, depends, + any_reduction_axis1_contig_dispatch_vector, + any_reduction_axis0_contig_dispatch_vector, + any_reduction_strided_dispatch_vector, any_atomic_support); + }; + m.def("_any", any_pyapi, "", py::arg("src"), + py::arg("trailing_dims_to_reduce"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); + } +} + +} // namespace 
dpctl::tensor::py_internal diff --git a/dpctl_ext/tensor/libtensor/source/reductions/any.hpp b/dpctl_ext/tensor/libtensor/source/reductions/any.hpp new file mode 100644 index 000000000000..4e368a674615 --- /dev/null +++ b/dpctl_ext/tensor/libtensor/source/reductions/any.hpp @@ -0,0 +1,46 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. 
+//***************************************************************************** +// +//===---------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_reductions_impl +/// extension. +//===---------------------------------------------------------------------===// + +#pragma once +#include + +namespace py = pybind11; + +namespace dpctl::tensor::py_internal +{ + +extern void init_any(py::module_ m); + +} // namespace dpctl::tensor::py_internal diff --git a/dpctl_ext/tensor/libtensor/source/reductions/argmax.cpp b/dpctl_ext/tensor/libtensor/source/reductions/argmax.cpp new file mode 100644 index 000000000000..10fc49759168 --- /dev/null +++ b/dpctl_ext/tensor/libtensor/source/reductions/argmax.cpp @@ -0,0 +1,279 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** +// +//===---------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_reductions_impl +/// extension. +//===---------------------------------------------------------------------===// + +#include +#include +#include +#include + +#include + +#include "dpnp4pybind11.hpp" +#include +#include + +#include "kernels/reductions.hpp" +#include "reduction_over_axis.hpp" +#include "utils/sycl_utils.hpp" +#include "utils/type_dispatch_building.hpp" + +namespace dpctl::tensor::py_internal +{ + +namespace py = pybind11; +namespace td_ns = dpctl::tensor::type_dispatch; +namespace su_ns = dpctl::tensor::sycl_utils; + +namespace impl +{ + +using dpctl::tensor::kernels::search_strided_impl_fn_ptr; +static search_strided_impl_fn_ptr + argmax_over_axis_strided_temps_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + +using dpctl::tensor::kernels::search_contig_impl_fn_ptr; +static search_contig_impl_fn_ptr + argmax_over_axis1_contig_temps_dispatch_table[td_ns::num_types] + [td_ns::num_types]; +using dpctl::tensor::kernels::search_contig_impl_fn_ptr; +static search_contig_impl_fn_ptr + argmax_over_axis0_contig_temps_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + +template +struct TypePairSupportForArgmaxReductionTemps +{ + + static constexpr bool is_defined = 
std::disjunction< + td_ns::TypePairDefinedEntry, + // input int8_t + td_ns::TypePairDefinedEntry, + + // input uint8_t + td_ns::TypePairDefinedEntry, + + // input int16_t + td_ns::TypePairDefinedEntry, + + // input uint16_t + td_ns::TypePairDefinedEntry, + + // input int32_t + td_ns::TypePairDefinedEntry, + // input uint32_t + td_ns::TypePairDefinedEntry, + + // input int64_t + td_ns::TypePairDefinedEntry, + + // input uint32_t + td_ns::TypePairDefinedEntry, + + // input half + td_ns::TypePairDefinedEntry, + + // input float + td_ns::TypePairDefinedEntry, + + // input double + td_ns::TypePairDefinedEntry, + + // input std::complex + td_ns::TypePairDefinedEntry, + outTy, + std::int64_t>, + + td_ns::TypePairDefinedEntry, + outTy, + std::int64_t>, + + // fall-through + td_ns::NotDefinedEntry>::is_defined; +}; + +template +struct ArgmaxOverAxisTempsStridedFactory +{ + fnT get() const + { + if constexpr (TypePairSupportForArgmaxReductionTemps::is_defined) + { + if constexpr (std::is_integral_v && + !std::is_same_v) { + // op for values + using ReductionOpT = sycl::maximum; + // op for indices + using IndexOpT = sycl::minimum; + return dpctl::tensor::kernels:: + search_over_group_temps_strided_impl< + srcTy, dstTy, ReductionOpT, IndexOpT>; + } + else { + // op for values + using ReductionOpT = su_ns::Maximum; + // op for indices + using IndexOpT = sycl::minimum; + return dpctl::tensor::kernels:: + search_over_group_temps_strided_impl< + srcTy, dstTy, ReductionOpT, IndexOpT>; + } + } + else { + return nullptr; + } + } +}; + +template +struct ArgmaxOverAxis1TempsContigFactory +{ + fnT get() const + { + if constexpr (TypePairSupportForArgmaxReductionTemps::is_defined) + { + if constexpr (std::is_integral_v && + !std::is_same_v) { + // op for values + using ReductionOpT = sycl::maximum; + // op for indices + using IndexOpT = sycl::minimum; + return dpctl::tensor::kernels:: + search_axis1_over_group_temps_contig_impl< + srcTy, dstTy, ReductionOpT, IndexOpT>; + } + else { + // 
op for values + using ReductionOpT = su_ns::Maximum; + // op for indices + using IndexOpT = sycl::minimum; + return dpctl::tensor::kernels:: + search_axis1_over_group_temps_contig_impl< + srcTy, dstTy, ReductionOpT, IndexOpT>; + } + } + else { + return nullptr; + } + } +}; + +template +struct ArgmaxOverAxis0TempsContigFactory +{ + fnT get() const + { + if constexpr (TypePairSupportForArgmaxReductionTemps::is_defined) + { + if constexpr (std::is_integral_v && + !std::is_same_v) { + // op for values + using ReductionOpT = sycl::maximum; + // op for indices + using IndexOpT = sycl::minimum; + return dpctl::tensor::kernels:: + search_axis0_over_group_temps_contig_impl< + srcTy, dstTy, ReductionOpT, IndexOpT>; + } + else { + // op for values + using ReductionOpT = su_ns::Maximum; + // op for indices + using IndexOpT = sycl::minimum; + return dpctl::tensor::kernels:: + search_axis0_over_group_temps_contig_impl< + srcTy, dstTy, ReductionOpT, IndexOpT>; + } + } + else { + return nullptr; + } + } +}; + +void populate_argmax_over_axis_dispatch_tables(void) +{ + using td_ns::DispatchTableBuilder; + + DispatchTableBuilder + dtb1; + dtb1.populate_dispatch_table(argmax_over_axis_strided_temps_dispatch_table); + + DispatchTableBuilder + dtb2; + dtb2.populate_dispatch_table(argmax_over_axis1_contig_temps_dispatch_table); + + DispatchTableBuilder + dtb3; + dtb3.populate_dispatch_table(argmax_over_axis0_contig_temps_dispatch_table); +} + +} // namespace impl + +void init_argmax(py::module_ m) +{ + using arrayT = dpctl::tensor::usm_ndarray; + using event_vecT = std::vector; + { + using impl::populate_argmax_over_axis_dispatch_tables; + populate_argmax_over_axis_dispatch_tables(); + using impl::argmax_over_axis0_contig_temps_dispatch_table; + using impl::argmax_over_axis1_contig_temps_dispatch_table; + using impl::argmax_over_axis_strided_temps_dispatch_table; + + auto argmax_pyapi = [&](const arrayT &src, int trailing_dims_to_reduce, + const arrayT &dst, sycl::queue &exec_q, + const 
event_vecT &depends = {}) { + return py_search_over_axis( + src, trailing_dims_to_reduce, dst, exec_q, depends, + argmax_over_axis_strided_temps_dispatch_table, + argmax_over_axis0_contig_temps_dispatch_table, + argmax_over_axis1_contig_temps_dispatch_table); + }; + m.def("_argmax_over_axis", argmax_pyapi, "", py::arg("src"), + py::arg("trailing_dims_to_reduce"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); + } +} + +} // namespace dpctl::tensor::py_internal diff --git a/dpctl_ext/tensor/libtensor/source/reductions/argmax.hpp b/dpctl_ext/tensor/libtensor/source/reductions/argmax.hpp new file mode 100644 index 000000000000..3274f8c7d0cb --- /dev/null +++ b/dpctl_ext/tensor/libtensor/source/reductions/argmax.hpp @@ -0,0 +1,46 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** +// +//===---------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_reductions_impl +/// extension. +//===---------------------------------------------------------------------===// + +#pragma once +#include + +namespace py = pybind11; + +namespace dpctl::tensor::py_internal +{ + +extern void init_argmax(py::module_ m); + +} // namespace dpctl::tensor::py_internal diff --git a/dpctl_ext/tensor/libtensor/source/reductions/argmin.cpp b/dpctl_ext/tensor/libtensor/source/reductions/argmin.cpp new file mode 100644 index 000000000000..ec4637b62d49 --- /dev/null +++ b/dpctl_ext/tensor/libtensor/source/reductions/argmin.cpp @@ -0,0 +1,279 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. 
+// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** +// +//===---------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_reductions_impl +/// extension. 
+//===---------------------------------------------------------------------===// + +#include +#include +#include +#include + +#include + +#include "dpnp4pybind11.hpp" +#include +#include + +#include "kernels/reductions.hpp" +#include "reduction_over_axis.hpp" +#include "utils/sycl_utils.hpp" +#include "utils/type_dispatch_building.hpp" + +namespace dpctl::tensor::py_internal +{ + +namespace py = pybind11; +namespace td_ns = dpctl::tensor::type_dispatch; +namespace su_ns = dpctl::tensor::sycl_utils; + +namespace impl +{ + +using dpctl::tensor::kernels::search_strided_impl_fn_ptr; +static search_strided_impl_fn_ptr + argmin_over_axis_strided_temps_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + +using dpctl::tensor::kernels::search_contig_impl_fn_ptr; +static search_contig_impl_fn_ptr + argmin_over_axis1_contig_temps_dispatch_table[td_ns::num_types] + [td_ns::num_types]; +using dpctl::tensor::kernels::search_contig_impl_fn_ptr; +static search_contig_impl_fn_ptr + argmin_over_axis0_contig_temps_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + +template +struct TypePairSupportForArgminReductionTemps +{ + + static constexpr bool is_defined = std::disjunction< + td_ns::TypePairDefinedEntry, + // input int8_t + td_ns::TypePairDefinedEntry, + + // input uint8_t + td_ns::TypePairDefinedEntry, + + // input int16_t + td_ns::TypePairDefinedEntry, + + // input uint16_t + td_ns::TypePairDefinedEntry, + + // input int32_t + td_ns::TypePairDefinedEntry, + // input uint32_t + td_ns::TypePairDefinedEntry, + + // input int64_t + td_ns::TypePairDefinedEntry, + + // input uint32_t + td_ns::TypePairDefinedEntry, + + // input half + td_ns::TypePairDefinedEntry, + + // input float + td_ns::TypePairDefinedEntry, + + // input double + td_ns::TypePairDefinedEntry, + + // input std::complex + td_ns::TypePairDefinedEntry, + outTy, + std::int64_t>, + + td_ns::TypePairDefinedEntry, + outTy, + std::int64_t>, + + // fall-through + td_ns::NotDefinedEntry>::is_defined; +}; + 
+template +struct ArgminOverAxisTempsStridedFactory +{ + fnT get() const + { + if constexpr (TypePairSupportForArgminReductionTemps::is_defined) + { + if constexpr (std::is_integral_v && + !std::is_same_v) { + // op for values + using ReductionOpT = sycl::minimum; + // op for indices + using IndexOpT = sycl::minimum; + return dpctl::tensor::kernels:: + search_over_group_temps_strided_impl< + srcTy, dstTy, ReductionOpT, IndexOpT>; + } + else { + // op for values + using ReductionOpT = su_ns::Minimum; + // op for indices + using IndexOpT = sycl::minimum; + return dpctl::tensor::kernels:: + search_over_group_temps_strided_impl< + srcTy, dstTy, ReductionOpT, IndexOpT>; + } + } + else { + return nullptr; + } + } +}; + +template +struct ArgminOverAxis1TempsContigFactory +{ + fnT get() const + { + if constexpr (TypePairSupportForArgminReductionTemps::is_defined) + { + if constexpr (std::is_integral_v && + !std::is_same_v) { + // op for values + using ReductionOpT = sycl::minimum; + // op for indices + using IndexOpT = sycl::minimum; + return dpctl::tensor::kernels:: + search_axis1_over_group_temps_contig_impl< + srcTy, dstTy, ReductionOpT, IndexOpT>; + } + else { + // op for values + using ReductionOpT = su_ns::Minimum; + // op for indices + using IndexOpT = sycl::minimum; + return dpctl::tensor::kernels:: + search_axis1_over_group_temps_contig_impl< + srcTy, dstTy, ReductionOpT, IndexOpT>; + } + } + else { + return nullptr; + } + } +}; + +template +struct ArgminOverAxis0TempsContigFactory +{ + fnT get() const + { + if constexpr (TypePairSupportForArgminReductionTemps::is_defined) + { + if constexpr (std::is_integral_v && + !std::is_same_v) { + // op for values + using ReductionOpT = sycl::minimum; + // op for indices + using IndexOpT = sycl::minimum; + return dpctl::tensor::kernels:: + search_axis0_over_group_temps_contig_impl< + srcTy, dstTy, ReductionOpT, IndexOpT>; + } + else { + // op for values + using ReductionOpT = su_ns::Minimum; + // op for indices + using 
IndexOpT = sycl::minimum; + return dpctl::tensor::kernels:: + search_axis0_over_group_temps_contig_impl< + srcTy, dstTy, ReductionOpT, IndexOpT>; + } + } + else { + return nullptr; + } + } +}; + +void populate_argmin_over_axis_dispatch_tables(void) +{ + using td_ns::DispatchTableBuilder; + + DispatchTableBuilder + dtb1; + dtb1.populate_dispatch_table(argmin_over_axis_strided_temps_dispatch_table); + + DispatchTableBuilder + dtb2; + dtb2.populate_dispatch_table(argmin_over_axis1_contig_temps_dispatch_table); + + DispatchTableBuilder + dtb3; + dtb3.populate_dispatch_table(argmin_over_axis0_contig_temps_dispatch_table); +} + +} // namespace impl + +void init_argmin(py::module_ m) +{ + using arrayT = dpctl::tensor::usm_ndarray; + using event_vecT = std::vector; + { + using impl::populate_argmin_over_axis_dispatch_tables; + populate_argmin_over_axis_dispatch_tables(); + using impl::argmin_over_axis0_contig_temps_dispatch_table; + using impl::argmin_over_axis1_contig_temps_dispatch_table; + using impl::argmin_over_axis_strided_temps_dispatch_table; + + auto argmin_pyapi = [&](const arrayT &src, int trailing_dims_to_reduce, + const arrayT &dst, sycl::queue &exec_q, + const event_vecT &depends = {}) { + return py_search_over_axis( + src, trailing_dims_to_reduce, dst, exec_q, depends, + argmin_over_axis_strided_temps_dispatch_table, + argmin_over_axis0_contig_temps_dispatch_table, + argmin_over_axis1_contig_temps_dispatch_table); + }; + m.def("_argmin_over_axis", argmin_pyapi, "", py::arg("src"), + py::arg("trailing_dims_to_reduce"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); + } +} + +} // namespace dpctl::tensor::py_internal diff --git a/dpctl_ext/tensor/libtensor/source/reductions/argmin.hpp b/dpctl_ext/tensor/libtensor/source/reductions/argmin.hpp new file mode 100644 index 000000000000..1865c258a527 --- /dev/null +++ b/dpctl_ext/tensor/libtensor/source/reductions/argmin.hpp @@ -0,0 +1,46 @@ 
+//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** +// +//===---------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_reductions_impl +/// extension. 
+//===---------------------------------------------------------------------===// + +#pragma once +#include + +namespace py = pybind11; + +namespace dpctl::tensor::py_internal +{ + +extern void init_argmin(py::module_ m); + +} // namespace dpctl::tensor::py_internal diff --git a/dpctl_ext/tensor/libtensor/source/reductions/logsumexp.cpp b/dpctl_ext/tensor/libtensor/source/reductions/logsumexp.cpp new file mode 100644 index 000000000000..08ed4f12dbda --- /dev/null +++ b/dpctl_ext/tensor/libtensor/source/reductions/logsumexp.cpp @@ -0,0 +1,257 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** +// +//===---------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_reductions_impl +/// extension. +//===---------------------------------------------------------------------===// + +#include +#include +#include + +#include + +#include "dpnp4pybind11.hpp" +#include +#include + +#include "kernels/reductions.hpp" +#include "reduction_over_axis.hpp" +#include "utils/sycl_utils.hpp" +#include "utils/type_dispatch_building.hpp" + +namespace dpctl::tensor::py_internal +{ + +namespace py = pybind11; +namespace td_ns = dpctl::tensor::type_dispatch; +namespace su_ns = dpctl::tensor::sycl_utils; + +namespace impl +{ + +using dpctl::tensor::kernels::reduction_strided_impl_fn_ptr; +static reduction_strided_impl_fn_ptr + logsumexp_over_axis_strided_temps_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + +using dpctl::tensor::kernels::reduction_contig_impl_fn_ptr; +static reduction_contig_impl_fn_ptr + logsumexp_over_axis1_contig_temps_dispatch_table[td_ns::num_types] + [td_ns::num_types]; +static reduction_contig_impl_fn_ptr + logsumexp_over_axis0_contig_temps_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + +template +struct TypePairSupportDataForLogSumExpReductionTemps +{ + + static constexpr bool is_defined = std::disjunction< +#if 1 + 
td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input int8_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input uint8_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input int16_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input uint16_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input int32_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input uint32_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input int64_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input uint64_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + // input half + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input float + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input double + td_ns::TypePairDefinedEntry, +#endif + + // fall-through + td_ns::NotDefinedEntry>::is_defined; +}; + +template +struct LogSumExpOverAxisTempsStridedFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForLogSumExpReductionTemps< + srcTy, dstTy>::is_defined) + { + using ReductionOpT = su_ns::LogSumExp; + return dpctl::tensor::kernels:: + reduction_over_group_temps_strided_impl; + } + else { + return nullptr; + } + } +}; + +template +struct LogSumExpOverAxis1TempsContigFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForLogSumExpReductionTemps< + srcTy, dstTy>::is_defined) + { + using ReductionOpT = su_ns::LogSumExp; + return dpctl::tensor::kernels:: + reduction_axis1_over_group_temps_contig_impl; + } + else { + return nullptr; + } + } +}; + +template +struct LogSumExpOverAxis0TempsContigFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForLogSumExpReductionTemps< + srcTy, dstTy>::is_defined) + { + using 
ReductionOpT = su_ns::LogSumExp; + return dpctl::tensor::kernels:: + reduction_axis0_over_group_temps_contig_impl; + } + else { + return nullptr; + } + } +}; + +void populate_logsumexp_over_axis_dispatch_tables(void) +{ + using dpctl::tensor::kernels::reduction_contig_impl_fn_ptr; + using dpctl::tensor::kernels::reduction_strided_impl_fn_ptr; + using namespace td_ns; + + DispatchTableBuilder + dtb1; + dtb1.populate_dispatch_table( + logsumexp_over_axis_strided_temps_dispatch_table); + + DispatchTableBuilder + dtb2; + dtb2.populate_dispatch_table( + logsumexp_over_axis1_contig_temps_dispatch_table); + + DispatchTableBuilder + dtb3; + dtb3.populate_dispatch_table( + logsumexp_over_axis0_contig_temps_dispatch_table); +} + +} // namespace impl + +void init_logsumexp(py::module_ m) +{ + using arrayT = dpctl::tensor::usm_ndarray; + using event_vecT = std::vector; + { + using impl::populate_logsumexp_over_axis_dispatch_tables; + populate_logsumexp_over_axis_dispatch_tables(); + using impl::logsumexp_over_axis0_contig_temps_dispatch_table; + using impl::logsumexp_over_axis1_contig_temps_dispatch_table; + using impl::logsumexp_over_axis_strided_temps_dispatch_table; + + using dpctl::tensor::kernels::reduction_contig_impl_fn_ptr; + using dpctl::tensor::kernels::reduction_strided_impl_fn_ptr; + + auto logsumexp_pyapi = [&](const arrayT &src, + int trailing_dims_to_reduce, + const arrayT &dst, sycl::queue &exec_q, + const event_vecT &depends = {}) { + return py_tree_reduction_over_axis( + src, trailing_dims_to_reduce, dst, exec_q, depends, + logsumexp_over_axis_strided_temps_dispatch_table, + logsumexp_over_axis0_contig_temps_dispatch_table, + logsumexp_over_axis1_contig_temps_dispatch_table); + }; + m.def("_logsumexp_over_axis", logsumexp_pyapi, "", py::arg("src"), + py::arg("trailing_dims_to_reduce"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); + + auto logsumexp_dtype_supported = [&](const py::dtype &input_dtype, + const py::dtype 
&output_dtype) { + return py_tree_reduction_dtype_supported( + input_dtype, output_dtype, + logsumexp_over_axis_strided_temps_dispatch_table); + }; + m.def("_logsumexp_over_axis_dtype_supported", logsumexp_dtype_supported, + "", py::arg("arg_dtype"), py::arg("out_dtype")); + } +} + +} // namespace dpctl::tensor::py_internal diff --git a/dpctl_ext/tensor/libtensor/source/reductions/logsumexp.hpp b/dpctl_ext/tensor/libtensor/source/reductions/logsumexp.hpp new file mode 100644 index 000000000000..2e2c19877db6 --- /dev/null +++ b/dpctl_ext/tensor/libtensor/source/reductions/logsumexp.hpp @@ -0,0 +1,46 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** +// +//===---------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_reductions_impl +/// extension. +//===---------------------------------------------------------------------===// + +#pragma once +#include + +namespace py = pybind11; + +namespace dpctl::tensor::py_internal +{ + +extern void init_logsumexp(py::module_ m); + +} // namespace dpctl::tensor::py_internal diff --git a/dpctl_ext/tensor/libtensor/source/reductions/max.cpp b/dpctl_ext/tensor/libtensor/source/reductions/max.cpp new file mode 100644 index 000000000000..d19ed226d3b4 --- /dev/null +++ b/dpctl_ext/tensor/libtensor/source/reductions/max.cpp @@ -0,0 +1,410 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. 
+// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** +// +//===---------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_reductions_impl +/// extension. 
+//===---------------------------------------------------------------------===// + +#include +#include +#include +#include + +#include + +#include "dpnp4pybind11.hpp" +#include +#include + +#include "kernels/reductions.hpp" +#include "utils/sycl_utils.hpp" +#include "utils/type_dispatch_building.hpp" + +#include "reduction_atomic_support.hpp" +#include "reduction_over_axis.hpp" + +namespace dpctl::tensor::py_internal +{ + +namespace py = pybind11; +namespace td_ns = dpctl::tensor::type_dispatch; +namespace su_ns = dpctl::tensor::sycl_utils; + +namespace impl +{ + +using dpctl::tensor::kernels::reduction_strided_impl_fn_ptr; +static reduction_strided_impl_fn_ptr + max_over_axis_strided_atomic_dispatch_table[td_ns::num_types] + [td_ns::num_types]; +static reduction_strided_impl_fn_ptr + max_over_axis_strided_temps_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + +using dpctl::tensor::kernels::reduction_contig_impl_fn_ptr; +static reduction_contig_impl_fn_ptr + max_over_axis1_contig_atomic_dispatch_table[td_ns::num_types] + [td_ns::num_types]; +static reduction_contig_impl_fn_ptr + max_over_axis0_contig_atomic_dispatch_table[td_ns::num_types] + [td_ns::num_types]; +static reduction_contig_impl_fn_ptr + max_over_axis1_contig_temps_dispatch_table[td_ns::num_types] + [td_ns::num_types]; +static reduction_contig_impl_fn_ptr + max_over_axis0_contig_temps_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + +/* @brief Types supported by max reduction code based on atomic_ref */ +template +struct TypePairSupportDataForMaxReductionAtomic +{ + /* value is true if a kernel for must be instantiated, false + * otherwise */ + static constexpr bool is_defined = std::disjunction< + // input int32 + td_ns::TypePairDefinedEntry, + // input uint32 + td_ns::TypePairDefinedEntry, + // input int64 + td_ns::TypePairDefinedEntry, + // input uint64 + td_ns::TypePairDefinedEntry, + // input float + td_ns::TypePairDefinedEntry, + // input double + td_ns::TypePairDefinedEntry, + // 
fall-through + td_ns::NotDefinedEntry>::is_defined; +}; + +template +struct TypePairSupportDataForMaxReductionTemps +{ + static constexpr bool is_defined = std::disjunction< + // input bool + td_ns::TypePairDefinedEntry, + // input int8_t + td_ns::TypePairDefinedEntry, + // input uint8_t + td_ns::TypePairDefinedEntry, + + // input int16_t + td_ns::TypePairDefinedEntry, + // input uint16_t + td_ns::TypePairDefinedEntry, + + // input int32_t + td_ns::TypePairDefinedEntry, + // input uint32_t + td_ns::TypePairDefinedEntry, + + // input int64_t + td_ns::TypePairDefinedEntry, + + // input uint32_t + td_ns::TypePairDefinedEntry, + + // input half + td_ns::TypePairDefinedEntry, + + // input float + td_ns::TypePairDefinedEntry, + + // input double + td_ns::TypePairDefinedEntry, + + // input std::complex + td_ns::TypePairDefinedEntry, + outTy, + std::complex>, + + td_ns::TypePairDefinedEntry, + outTy, + std::complex>, + + // fall-through + td_ns::NotDefinedEntry>::is_defined; +}; + +template +struct MaxOverAxisAtomicStridedFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForMaxReductionAtomic< + srcTy, dstTy>::is_defined) + { + if constexpr (std::is_floating_point::value) { + using ReductionOpT = su_ns::Maximum; + return dpctl::tensor::kernels:: + reduction_over_group_with_atomics_strided_impl< + srcTy, dstTy, ReductionOpT>; + } + else { + using ReductionOpT = sycl::maximum; + return dpctl::tensor::kernels:: + reduction_over_group_with_atomics_strided_impl< + srcTy, dstTy, ReductionOpT>; + } + } + else { + return nullptr; + } + } +}; + +template +struct MaxOverAxisTempsStridedFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForMaxReductionTemps< + srcTy, dstTy>::is_defined) { + if constexpr (std::is_integral_v && + !std::is_same_v) { + using ReductionOpT = sycl::maximum; + return dpctl::tensor::kernels:: + reduction_over_group_temps_strided_impl; + } + else { + using ReductionOpT = su_ns::Maximum; + return dpctl::tensor::kernels:: + 
reduction_over_group_temps_strided_impl; + } + } + else { + return nullptr; + } + } +}; + +template +struct MaxOverAxis1AtomicContigFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForMaxReductionAtomic< + srcTy, dstTy>::is_defined) + { + if constexpr (std::is_floating_point::value) { + using ReductionOpT = su_ns::Maximum; + return dpctl::tensor::kernels:: + reduction_axis1_over_group_with_atomics_contig_impl< + srcTy, dstTy, ReductionOpT>; + } + else { + using ReductionOpT = sycl::maximum; + return dpctl::tensor::kernels:: + reduction_axis1_over_group_with_atomics_contig_impl< + srcTy, dstTy, ReductionOpT>; + } + } + else { + return nullptr; + } + } +}; + +template +struct MaxOverAxis0AtomicContigFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForMaxReductionAtomic< + srcTy, dstTy>::is_defined) + { + if constexpr (std::is_floating_point::value) { + using ReductionOpT = su_ns::Maximum; + return dpctl::tensor::kernels:: + reduction_axis0_over_group_with_atomics_contig_impl< + srcTy, dstTy, ReductionOpT>; + } + else { + using ReductionOpT = sycl::maximum; + return dpctl::tensor::kernels:: + reduction_axis0_over_group_with_atomics_contig_impl< + srcTy, dstTy, ReductionOpT>; + } + } + else { + return nullptr; + } + } +}; + +template +struct MaxOverAxis1TempsContigFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForMaxReductionTemps< + srcTy, dstTy>::is_defined) { + if constexpr (std::is_integral_v && + !std::is_same_v) { + using ReductionOpT = sycl::maximum; + return dpctl::tensor::kernels:: + reduction_axis1_over_group_temps_contig_impl; + } + else { + using ReductionOpT = su_ns::Maximum; + return dpctl::tensor::kernels:: + reduction_axis1_over_group_temps_contig_impl; + } + } + else { + return nullptr; + } + } +}; + +template +struct MaxOverAxis0TempsContigFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForMaxReductionTemps< + srcTy, dstTy>::is_defined) { + if constexpr (std::is_integral_v 
&& + !std::is_same_v) { + using ReductionOpT = sycl::maximum; + return dpctl::tensor::kernels:: + reduction_axis0_over_group_temps_contig_impl; + } + else { + using ReductionOpT = su_ns::Maximum; + return dpctl::tensor::kernels:: + reduction_axis0_over_group_temps_contig_impl; + } + } + else { + return nullptr; + } + } +}; + +void populate_max_over_axis_dispatch_tables(void) +{ + using td_ns::DispatchTableBuilder; + + DispatchTableBuilder + dtb1; + dtb1.populate_dispatch_table(max_over_axis_strided_atomic_dispatch_table); + + DispatchTableBuilder + dtb2; + dtb2.populate_dispatch_table(max_over_axis_strided_temps_dispatch_table); + + DispatchTableBuilder + dtb3; + dtb3.populate_dispatch_table(max_over_axis1_contig_atomic_dispatch_table); + + DispatchTableBuilder + dtb4; + dtb4.populate_dispatch_table(max_over_axis0_contig_atomic_dispatch_table); + + DispatchTableBuilder + dtb5; + dtb5.populate_dispatch_table(max_over_axis1_contig_temps_dispatch_table); + + DispatchTableBuilder + dtb6; + dtb6.populate_dispatch_table(max_over_axis0_contig_temps_dispatch_table); +} + +using atomic_support::atomic_support_fn_ptr_t; +static atomic_support_fn_ptr_t max_atomic_support_vector[td_ns::num_types]; + +void populate_max_atomic_support_dispatch_vector(void) +{ + using td_ns::DispatchVectorBuilder; + + using atomic_support::MaxAtomicSupportFactory; + DispatchVectorBuilder + dvb; + dvb.populate_dispatch_vector(max_atomic_support_vector); +} + +} // namespace impl + +void init_max(py::module_ m) +{ + using arrayT = dpctl::tensor::usm_ndarray; + using event_vecT = std::vector; + { + using impl::populate_max_over_axis_dispatch_tables; + populate_max_over_axis_dispatch_tables(); + using impl::max_over_axis0_contig_atomic_dispatch_table; + using impl::max_over_axis0_contig_temps_dispatch_table; + using impl::max_over_axis1_contig_atomic_dispatch_table; + using impl::max_over_axis1_contig_temps_dispatch_table; + using impl::max_over_axis_strided_atomic_dispatch_table; + using 
impl::max_over_axis_strided_temps_dispatch_table; + + using impl::populate_max_atomic_support_dispatch_vector; + populate_max_atomic_support_dispatch_vector(); + using impl::max_atomic_support_vector; + + auto max_pyapi = [&](const arrayT &src, int trailing_dims_to_reduce, + const arrayT &dst, sycl::queue &exec_q, + const event_vecT &depends = {}) { + return py_reduction_over_axis( + src, trailing_dims_to_reduce, dst, exec_q, depends, + max_over_axis_strided_atomic_dispatch_table, + max_over_axis0_contig_atomic_dispatch_table, + max_over_axis1_contig_atomic_dispatch_table, + max_over_axis_strided_temps_dispatch_table, + max_over_axis0_contig_temps_dispatch_table, + max_over_axis1_contig_temps_dispatch_table, + max_atomic_support_vector); + }; + m.def("_max_over_axis", max_pyapi, "", py::arg("src"), + py::arg("trailing_dims_to_reduce"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); + } +} + +} // namespace dpctl::tensor::py_internal diff --git a/dpctl_ext/tensor/libtensor/source/reductions/max.hpp b/dpctl_ext/tensor/libtensor/source/reductions/max.hpp new file mode 100644 index 000000000000..bc242dc8d74b --- /dev/null +++ b/dpctl_ext/tensor/libtensor/source/reductions/max.hpp @@ -0,0 +1,46 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. 
+// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** +// +//===---------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_reductions_impl +/// extension. +//===---------------------------------------------------------------------===// + +#pragma once +#include + +namespace py = pybind11; + +namespace dpctl::tensor::py_internal +{ + +extern void init_max(py::module_ m); + +} // namespace dpctl::tensor::py_internal diff --git a/dpctl_ext/tensor/libtensor/source/reductions/min.cpp b/dpctl_ext/tensor/libtensor/source/reductions/min.cpp new file mode 100644 index 000000000000..97d3432b13ed --- /dev/null +++ b/dpctl_ext/tensor/libtensor/source/reductions/min.cpp @@ -0,0 +1,412 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. 
+// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** +// +//===---------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_reductions_impl +/// extension. 
+//===---------------------------------------------------------------------===// + +#include +#include +#include +#include + +#include + +#include "dpnp4pybind11.hpp" +#include +#include + +#include "kernels/reductions.hpp" +#include "utils/sycl_utils.hpp" +#include "utils/type_dispatch_building.hpp" + +#include "reduction_atomic_support.hpp" +#include "reduction_over_axis.hpp" + +namespace dpctl::tensor::py_internal +{ + +namespace py = pybind11; +namespace td_ns = dpctl::tensor::type_dispatch; +namespace su_ns = dpctl::tensor::sycl_utils; + +namespace impl +{ + +using dpctl::tensor::kernels::reduction_strided_impl_fn_ptr; +static reduction_strided_impl_fn_ptr + min_over_axis_strided_atomic_dispatch_table[td_ns::num_types] + [td_ns::num_types]; +static reduction_strided_impl_fn_ptr + min_over_axis_strided_temps_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + +using dpctl::tensor::kernels::reduction_contig_impl_fn_ptr; +static reduction_contig_impl_fn_ptr + min_over_axis1_contig_atomic_dispatch_table[td_ns::num_types] + [td_ns::num_types]; +static reduction_contig_impl_fn_ptr + min_over_axis0_contig_atomic_dispatch_table[td_ns::num_types] + [td_ns::num_types]; +static reduction_contig_impl_fn_ptr + min_over_axis1_contig_temps_dispatch_table[td_ns::num_types] + [td_ns::num_types]; +static reduction_contig_impl_fn_ptr + min_over_axis0_contig_temps_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + +/* @brief Types supported by min reduction code based on atomic_ref */ +template +struct TypePairSupportDataForMinReductionAtomic +{ + /* value is true if a kernel for must be instantiated, false + * otherwise */ + static constexpr bool is_defined = std::disjunction< + // input int32 + td_ns::TypePairDefinedEntry, + // input uint32 + td_ns::TypePairDefinedEntry, + // input int64 + td_ns::TypePairDefinedEntry, + // input uint64 + td_ns::TypePairDefinedEntry, + // input float + td_ns::TypePairDefinedEntry, + // input double + td_ns::TypePairDefinedEntry, + // 
fall-through + td_ns::NotDefinedEntry>::is_defined; +}; + +template +struct TypePairSupportDataForMinReductionTemps +{ + static constexpr bool is_defined = std::disjunction< + // input bool + td_ns::TypePairDefinedEntry, + // input int8_t + td_ns::TypePairDefinedEntry, + // input uint8_t + td_ns::TypePairDefinedEntry, + + // input int16_t + td_ns::TypePairDefinedEntry, + // input uint16_t + td_ns::TypePairDefinedEntry, + + // input int32_t + td_ns::TypePairDefinedEntry, + // input uint32_t + td_ns::TypePairDefinedEntry, + + // input int64_t + td_ns::TypePairDefinedEntry, + + // input uint64_t + td_ns::TypePairDefinedEntry, + + // input half + td_ns::TypePairDefinedEntry, + + // input float + td_ns::TypePairDefinedEntry, + + // input double + td_ns::TypePairDefinedEntry, + + // input std::complex + td_ns::TypePairDefinedEntry, + outTy, + std::complex>, + + td_ns::TypePairDefinedEntry, + outTy, + std::complex>, + + // fall-through + td_ns::NotDefinedEntry>::is_defined; +}; + +template +struct MinOverAxisAtomicStridedFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForMinReductionAtomic< + srcTy, dstTy>::is_defined) + { + if constexpr (std::is_floating_point::value) { + using ReductionOpT = su_ns::Minimum; + return dpctl::tensor::kernels:: + reduction_over_group_with_atomics_strided_impl< + srcTy, dstTy, ReductionOpT>; + } + else { + using ReductionOpT = sycl::minimum; + return dpctl::tensor::kernels:: + reduction_over_group_with_atomics_strided_impl< + srcTy, dstTy, ReductionOpT>; + } + } + else { + return nullptr; + } + } +}; + +template +struct MinOverAxisTempsStridedFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForMinReductionTemps< + srcTy, dstTy>::is_defined) { + if constexpr (std::is_integral_v && + !std::is_same_v) { + using ReductionOpT = sycl::minimum; + return dpctl::tensor::kernels:: + reduction_over_group_temps_strided_impl; + } + else { + using ReductionOpT = su_ns::Minimum; + return dpctl::tensor::kernels:: + 
reduction_over_group_temps_strided_impl; + } + } + else { + return nullptr; + } + } +}; + +template +struct MinOverAxis1AtomicContigFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForMinReductionAtomic< + srcTy, dstTy>::is_defined) + { + if constexpr (std::is_floating_point::value) { + using ReductionOpT = su_ns::Minimum; + return dpctl::tensor::kernels:: + reduction_axis1_over_group_with_atomics_contig_impl< + srcTy, dstTy, ReductionOpT>; + } + else { + using ReductionOpT = sycl::minimum; + return dpctl::tensor::kernels:: + reduction_axis1_over_group_with_atomics_contig_impl< + srcTy, dstTy, ReductionOpT>; + } + } + else { + return nullptr; + } + } +}; + +template +struct MinOverAxis0AtomicContigFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForMinReductionAtomic< + srcTy, dstTy>::is_defined) + { + if constexpr (std::is_floating_point::value) { + using ReductionOpT = su_ns::Minimum; + return dpctl::tensor::kernels:: + reduction_axis0_over_group_with_atomics_contig_impl< + srcTy, dstTy, ReductionOpT>; + } + else { + using ReductionOpT = sycl::minimum; + return dpctl::tensor::kernels:: + reduction_axis0_over_group_with_atomics_contig_impl< + srcTy, dstTy, ReductionOpT>; + } + } + else { + return nullptr; + } + } +}; + +template +struct MinOverAxis1TempsContigFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForMinReductionTemps< + srcTy, dstTy>::is_defined) { + if constexpr (std::is_integral_v && + !std::is_same_v) { + using ReductionOpT = sycl::minimum; + return dpctl::tensor::kernels:: + reduction_axis1_over_group_temps_contig_impl; + } + else { + using ReductionOpT = su_ns::Minimum; + return dpctl::tensor::kernels:: + reduction_axis1_over_group_temps_contig_impl; + } + } + else { + return nullptr; + } + } +}; + +template +struct MinOverAxis0TempsContigFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForMinReductionTemps< + srcTy, dstTy>::is_defined) { + if constexpr (std::is_integral_v 
&& + !std::is_same_v) { + using ReductionOpT = sycl::minimum; + return dpctl::tensor::kernels:: + reduction_axis0_over_group_temps_contig_impl; + } + else { + using ReductionOpT = su_ns::Minimum; + return dpctl::tensor::kernels:: + reduction_axis0_over_group_temps_contig_impl; + } + } + else { + return nullptr; + } + } +}; + +void populate_min_over_axis_dispatch_tables(void) +{ + using dpctl::tensor::kernels::reduction_contig_impl_fn_ptr; + using dpctl::tensor::kernels::reduction_strided_impl_fn_ptr; + using td_ns::DispatchTableBuilder; + + DispatchTableBuilder + dtb1; + dtb1.populate_dispatch_table(min_over_axis_strided_atomic_dispatch_table); + + DispatchTableBuilder + dtb2; + dtb2.populate_dispatch_table(min_over_axis_strided_temps_dispatch_table); + + DispatchTableBuilder + dtb3; + dtb3.populate_dispatch_table(min_over_axis1_contig_atomic_dispatch_table); + + DispatchTableBuilder + dtb4; + dtb4.populate_dispatch_table(min_over_axis0_contig_atomic_dispatch_table); + + DispatchTableBuilder + dtb5; + dtb5.populate_dispatch_table(min_over_axis1_contig_temps_dispatch_table); + + DispatchTableBuilder + dtb6; + dtb6.populate_dispatch_table(min_over_axis0_contig_temps_dispatch_table); +} + +using atomic_support::atomic_support_fn_ptr_t; +static atomic_support_fn_ptr_t min_atomic_support_vector[td_ns::num_types]; + +void populate_min_atomic_support_dispatch_vector(void) +{ + using td_ns::DispatchVectorBuilder; + + using atomic_support::MinAtomicSupportFactory; + DispatchVectorBuilder + dvb; + dvb.populate_dispatch_vector(min_atomic_support_vector); +} + +} // namespace impl + +void init_min(py::module_ m) +{ + using arrayT = dpctl::tensor::usm_ndarray; + using event_vecT = std::vector; + { + using impl::populate_min_over_axis_dispatch_tables; + populate_min_over_axis_dispatch_tables(); + using impl::min_over_axis0_contig_atomic_dispatch_table; + using impl::min_over_axis0_contig_temps_dispatch_table; + using impl::min_over_axis1_contig_atomic_dispatch_table; + using 
impl::min_over_axis1_contig_temps_dispatch_table; + using impl::min_over_axis_strided_atomic_dispatch_table; + using impl::min_over_axis_strided_temps_dispatch_table; + + using impl::populate_min_atomic_support_dispatch_vector; + populate_min_atomic_support_dispatch_vector(); + using impl::min_atomic_support_vector; + + auto min_pyapi = [&](const arrayT &src, int trailing_dims_to_reduce, + const arrayT &dst, sycl::queue &exec_q, + const event_vecT &depends = {}) { + return py_reduction_over_axis( + src, trailing_dims_to_reduce, dst, exec_q, depends, + min_over_axis_strided_atomic_dispatch_table, + min_over_axis0_contig_atomic_dispatch_table, + min_over_axis1_contig_atomic_dispatch_table, + min_over_axis_strided_temps_dispatch_table, + min_over_axis0_contig_temps_dispatch_table, + min_over_axis1_contig_temps_dispatch_table, + min_atomic_support_vector); + }; + m.def("_min_over_axis", min_pyapi, "", py::arg("src"), + py::arg("trailing_dims_to_reduce"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); + } +} + +} // namespace dpctl::tensor::py_internal diff --git a/dpctl_ext/tensor/libtensor/source/reductions/min.hpp b/dpctl_ext/tensor/libtensor/source/reductions/min.hpp new file mode 100644 index 000000000000..e054f44539f3 --- /dev/null +++ b/dpctl_ext/tensor/libtensor/source/reductions/min.hpp @@ -0,0 +1,46 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. 
+// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** +// +//===---------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_reductions_impl +/// extension. 
+//===---------------------------------------------------------------------===// + +#pragma once +#include + +namespace py = pybind11; + +namespace dpctl::tensor::py_internal +{ + +extern void init_min(py::module_ m); + +} // namespace dpctl::tensor::py_internal diff --git a/dpctl_ext/tensor/libtensor/source/reductions/prod.cpp b/dpctl_ext/tensor/libtensor/source/reductions/prod.cpp new file mode 100644 index 000000000000..e16e0cf25e1d --- /dev/null +++ b/dpctl_ext/tensor/libtensor/source/reductions/prod.cpp @@ -0,0 +1,464 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** +// +//===---------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_reductions_impl +/// extension. +//===---------------------------------------------------------------------===// + +#include +#include +#include +#include + +#include + +#include "dpnp4pybind11.hpp" +#include +#include + +#include "kernels/reductions.hpp" +#include "utils/type_dispatch_building.hpp" + +#include "reduction_atomic_support.hpp" +#include "reduction_over_axis.hpp" + +namespace dpctl::tensor::py_internal +{ + +namespace py = pybind11; +namespace td_ns = dpctl::tensor::type_dispatch; + +namespace impl +{ + +using dpctl::tensor::kernels::reduction_strided_impl_fn_ptr; +static reduction_strided_impl_fn_ptr + prod_over_axis_strided_atomic_dispatch_table[td_ns::num_types] + [td_ns::num_types]; +static reduction_strided_impl_fn_ptr + prod_over_axis_strided_temps_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + +using dpctl::tensor::kernels::reduction_contig_impl_fn_ptr; +static reduction_contig_impl_fn_ptr + prod_over_axis1_contig_atomic_dispatch_table[td_ns::num_types] + [td_ns::num_types]; +static reduction_contig_impl_fn_ptr + prod_over_axis0_contig_atomic_dispatch_table[td_ns::num_types] + [td_ns::num_types]; +static reduction_contig_impl_fn_ptr + 
prod_over_axis1_contig_temps_dispatch_table[td_ns::num_types] + [td_ns::num_types]; +static reduction_contig_impl_fn_ptr + prod_over_axis0_contig_temps_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + +/* @brief Types supported by product-reduction code based on atomic_ref */ +template +struct TypePairSupportDataForProductReductionAtomic +{ + + /* value is true if a kernel for this type pair must be instantiated, + * false otherwise */ + static constexpr bool is_defined = std::disjunction< + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + // input int8 + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + // input uint8 + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + // input int16 + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + // input uint16 + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + // input int32 + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + // input uint32 + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + // input int64 + td_ns::TypePairDefinedEntry, + // input uint64 + td_ns::TypePairDefinedEntry, + // fall-through + td_ns::NotDefinedEntry>::is_defined; +}; + +template +struct TypePairSupportDataForProductReductionTemps +{ + + static constexpr bool is_defined = std::disjunction< + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input int8_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + 
td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input uint8_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input int16_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input uint16_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input int32_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input uint32_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input int64_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input uint64_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input half + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns:: + TypePairDefinedEntry>, + td_ns::TypePairDefinedEntry>, + + // input float + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry>, + td_ns::TypePairDefinedEntry>, + + // input double + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry>, + + // input std::complex + td_ns::TypePairDefinedEntry, + outTy, + std::complex>, + td_ns::TypePairDefinedEntry, + outTy, + std::complex>, + + td_ns::TypePairDefinedEntry, + outTy, + std::complex>, + + // fall-through + td_ns::NotDefinedEntry>::is_defined; +}; + 
+template +struct ProductOverAxisAtomicStridedFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForProductReductionAtomic< + srcTy, dstTy>::is_defined) + { + using ReductionOpT = sycl::multiplies; + return dpctl::tensor::kernels:: + reduction_over_group_with_atomics_strided_impl; + } + else { + return nullptr; + } + } +}; + +template +struct ProductOverAxisTempsStridedFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForProductReductionTemps< + srcTy, dstTy>::is_defined) + { + using ReductionOpT = std::conditional_t, + sycl::logical_and, + sycl::multiplies>; + return dpctl::tensor::kernels:: + reduction_over_group_temps_strided_impl; + } + else { + return nullptr; + } + } +}; + +template +struct ProductOverAxis1AtomicContigFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForProductReductionAtomic< + srcTy, dstTy>::is_defined) + { + using ReductionOpT = sycl::multiplies; + return dpctl::tensor::kernels:: + reduction_axis1_over_group_with_atomics_contig_impl< + srcTy, dstTy, ReductionOpT>; + } + else { + return nullptr; + } + } +}; + +template +struct ProductOverAxis0AtomicContigFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForProductReductionAtomic< + srcTy, dstTy>::is_defined) + { + using ReductionOpT = sycl::multiplies; + return dpctl::tensor::kernels:: + reduction_axis0_over_group_with_atomics_contig_impl< + srcTy, dstTy, ReductionOpT>; + } + else { + return nullptr; + } + } +}; + +template +struct ProductOverAxis1TempsContigFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForProductReductionTemps< + srcTy, dstTy>::is_defined) + { + using ReductionOpT = std::conditional_t, + sycl::logical_and, + sycl::multiplies>; + return dpctl::tensor::kernels:: + reduction_axis1_over_group_temps_contig_impl; + } + else { + return nullptr; + } + } +}; + +template +struct ProductOverAxis0TempsContigFactory +{ + fnT get() const + { + if constexpr 
(TypePairSupportDataForProductReductionTemps< + srcTy, dstTy>::is_defined) + { + using ReductionOpT = std::conditional_t, + sycl::logical_and, + sycl::multiplies>; + return dpctl::tensor::kernels:: + reduction_axis0_over_group_temps_contig_impl; + } + else { + return nullptr; + } + } +}; + +void populate_prod_over_axis_dispatch_tables(void) +{ + using dpctl::tensor::kernels::reduction_contig_impl_fn_ptr; + using dpctl::tensor::kernels::reduction_strided_impl_fn_ptr; + using namespace td_ns; + + DispatchTableBuilder + dtb1; + dtb1.populate_dispatch_table(prod_over_axis_strided_atomic_dispatch_table); + + DispatchTableBuilder + dtb2; + dtb2.populate_dispatch_table(prod_over_axis_strided_temps_dispatch_table); + + DispatchTableBuilder + dtb3; + dtb3.populate_dispatch_table(prod_over_axis1_contig_atomic_dispatch_table); + + DispatchTableBuilder + dtb4; + dtb4.populate_dispatch_table(prod_over_axis0_contig_atomic_dispatch_table); + + DispatchTableBuilder + dtb5; + dtb5.populate_dispatch_table(prod_over_axis1_contig_temps_dispatch_table); + + DispatchTableBuilder + dtb6; + dtb6.populate_dispatch_table(prod_over_axis0_contig_temps_dispatch_table); +} + +using atomic_support::atomic_support_fn_ptr_t; +static atomic_support_fn_ptr_t prod_atomic_support_vector[td_ns::num_types]; + +void populate_prod_atomic_support_dispatch_vector(void) +{ + using td_ns::DispatchVectorBuilder; + + using atomic_support::ProductAtomicSupportFactory; + DispatchVectorBuilder + dvb; + dvb.populate_dispatch_vector(prod_atomic_support_vector); +} + +} // namespace impl + +void init_prod(py::module_ m) +{ + using arrayT = dpctl::tensor::usm_ndarray; + using event_vecT = std::vector; + { + using impl::populate_prod_over_axis_dispatch_tables; + populate_prod_over_axis_dispatch_tables(); + using impl::prod_over_axis0_contig_atomic_dispatch_table; + using impl::prod_over_axis0_contig_temps_dispatch_table; + using impl::prod_over_axis1_contig_atomic_dispatch_table; + using 
impl::prod_over_axis1_contig_temps_dispatch_table; + using impl::prod_over_axis_strided_atomic_dispatch_table; + using impl::prod_over_axis_strided_temps_dispatch_table; + + using impl::populate_prod_atomic_support_dispatch_vector; + populate_prod_atomic_support_dispatch_vector(); + using impl::prod_atomic_support_vector; + + auto prod_pyapi = [&](const arrayT &src, int trailing_dims_to_reduce, + const arrayT &dst, sycl::queue &exec_q, + const event_vecT &depends = {}) { + return py_reduction_over_axis( + src, trailing_dims_to_reduce, dst, exec_q, depends, + prod_over_axis_strided_atomic_dispatch_table, + prod_over_axis0_contig_atomic_dispatch_table, + prod_over_axis1_contig_atomic_dispatch_table, + prod_over_axis_strided_temps_dispatch_table, + prod_over_axis0_contig_temps_dispatch_table, + prod_over_axis1_contig_temps_dispatch_table, + prod_atomic_support_vector); + }; + m.def("_prod_over_axis", prod_pyapi, "", py::arg("src"), + py::arg("trailing_dims_to_reduce"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); + + auto prod_dtype_supported = + [&](const py::dtype &input_dtype, const py::dtype &output_dtype, + const std::string &dst_usm_type, sycl::queue &q) { + return py_reduction_dtype_supported( + input_dtype, output_dtype, dst_usm_type, q, + prod_over_axis_strided_atomic_dispatch_table, + prod_over_axis_strided_temps_dispatch_table, + prod_atomic_support_vector); + }; + m.def("_prod_over_axis_dtype_supported", prod_dtype_supported, "", + py::arg("arg_dtype"), py::arg("out_dtype"), + py::arg("dst_usm_type"), py::arg("sycl_queue")); + } +} + +} // namespace dpctl::tensor::py_internal diff --git a/dpctl_ext/tensor/libtensor/source/reductions/prod.hpp b/dpctl_ext/tensor/libtensor/source/reductions/prod.hpp new file mode 100644 index 000000000000..15b1c07e5ddd --- /dev/null +++ b/dpctl_ext/tensor/libtensor/source/reductions/prod.hpp @@ -0,0 +1,46 @@ +//***************************************************************************** +// 
Copyright (c) 2026, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** +// +//===---------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_reductions_impl +/// extension. 
+//===---------------------------------------------------------------------===// + +#pragma once +#include + +namespace py = pybind11; + +namespace dpctl::tensor::py_internal +{ + +extern void init_prod(py::module_ m); + +} // namespace dpctl::tensor::py_internal diff --git a/dpctl_ext/tensor/libtensor/source/reductions/reduce_hypot.cpp b/dpctl_ext/tensor/libtensor/source/reductions/reduce_hypot.cpp new file mode 100644 index 000000000000..3c343b702238 --- /dev/null +++ b/dpctl_ext/tensor/libtensor/source/reductions/reduce_hypot.cpp @@ -0,0 +1,253 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** +// +//===---------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_reductions_impl +/// extension. +//===---------------------------------------------------------------------===// + +#include +#include +#include + +#include + +#include "dpnp4pybind11.hpp" +#include +#include + +#include "kernels/reductions.hpp" +#include "utils/sycl_utils.hpp" +#include "utils/type_dispatch_building.hpp" + +#include "reduction_over_axis.hpp" + +namespace dpctl::tensor::py_internal +{ + +namespace py = pybind11; +namespace td_ns = dpctl::tensor::type_dispatch; +namespace su_ns = dpctl::tensor::sycl_utils; + +namespace impl +{ + +using dpctl::tensor::kernels::reduction_strided_impl_fn_ptr; +static reduction_strided_impl_fn_ptr + hypot_over_axis_strided_temps_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + +using dpctl::tensor::kernels::reduction_contig_impl_fn_ptr; +static reduction_contig_impl_fn_ptr + hypot_over_axis1_contig_temps_dispatch_table[td_ns::num_types] + [td_ns::num_types]; +static reduction_contig_impl_fn_ptr + hypot_over_axis0_contig_temps_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + +template +struct TypePairSupportDataForHypotReductionTemps +{ + + static constexpr bool is_defined = std::disjunction< + td_ns::TypePairDefinedEntry, + 
td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input int8_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input uint8_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input int16_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input uint16_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input int32_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input uint32_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input int64_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input uint64_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input half + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input float + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input double + td_ns::TypePairDefinedEntry, + + // fall-through + td_ns::NotDefinedEntry>::is_defined; +}; + +template +struct HypotOverAxisTempsStridedFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForHypotReductionTemps< + srcTy, dstTy>::is_defined) + { + using ReductionOpT = su_ns::Hypot; + return dpctl::tensor::kernels:: + reduction_over_group_temps_strided_impl; + } + else { + return nullptr; + } + } +}; + +template +struct HypotOverAxis1TempsContigFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForHypotReductionTemps< + srcTy, dstTy>::is_defined) + { + using ReductionOpT = su_ns::Hypot; + return dpctl::tensor::kernels:: + reduction_axis1_over_group_temps_contig_impl; + } + else { + return nullptr; + } + } +}; + +template +struct HypotOverAxis0TempsContigFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForHypotReductionTemps< + srcTy, dstTy>::is_defined) + { + using ReductionOpT = su_ns::Hypot; + return dpctl::tensor::kernels:: + 
reduction_axis0_over_group_temps_contig_impl; + } + else { + return nullptr; + } + } +}; + +void populate_hypot_over_axis_dispatch_tables(void) +{ + using dpctl::tensor::kernels::reduction_contig_impl_fn_ptr; + using dpctl::tensor::kernels::reduction_strided_impl_fn_ptr; + using namespace td_ns; + + DispatchTableBuilder + dtb1; + dtb1.populate_dispatch_table(hypot_over_axis_strided_temps_dispatch_table); + + DispatchTableBuilder + dtb2; + dtb2.populate_dispatch_table(hypot_over_axis1_contig_temps_dispatch_table); + + DispatchTableBuilder + dtb3; + dtb3.populate_dispatch_table(hypot_over_axis0_contig_temps_dispatch_table); +} + +} // namespace impl + +void init_reduce_hypot(py::module_ m) +{ + using arrayT = dpctl::tensor::usm_ndarray; + using event_vecT = std::vector; + { + using impl::populate_hypot_over_axis_dispatch_tables; + populate_hypot_over_axis_dispatch_tables(); + using impl::hypot_over_axis0_contig_temps_dispatch_table; + using impl::hypot_over_axis1_contig_temps_dispatch_table; + using impl::hypot_over_axis_strided_temps_dispatch_table; + + using dpctl::tensor::kernels::reduction_contig_impl_fn_ptr; + using dpctl::tensor::kernels::reduction_strided_impl_fn_ptr; + + auto hypot_pyapi = [&](const arrayT &src, int trailing_dims_to_reduce, + const arrayT &dst, sycl::queue &exec_q, + const event_vecT &depends = {}) { + return py_tree_reduction_over_axis( + src, trailing_dims_to_reduce, dst, exec_q, depends, + hypot_over_axis_strided_temps_dispatch_table, + hypot_over_axis0_contig_temps_dispatch_table, + hypot_over_axis1_contig_temps_dispatch_table); + }; + m.def("_hypot_over_axis", hypot_pyapi, "", py::arg("src"), + py::arg("trailing_dims_to_reduce"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); + + auto hypot_dtype_supported = [&](const py::dtype &input_dtype, + const py::dtype &output_dtype) { + return py_tree_reduction_dtype_supported( + input_dtype, output_dtype, + hypot_over_axis_strided_temps_dispatch_table); + }; + 
m.def("_hypot_over_axis_dtype_supported", hypot_dtype_supported, "", + py::arg("arg_dtype"), py::arg("out_dtype")); + } +} + +} // namespace dpctl::tensor::py_internal diff --git a/dpctl_ext/tensor/libtensor/source/reductions/reduce_hypot.hpp b/dpctl_ext/tensor/libtensor/source/reductions/reduce_hypot.hpp new file mode 100644 index 000000000000..c0a16345af75 --- /dev/null +++ b/dpctl_ext/tensor/libtensor/source/reductions/reduce_hypot.hpp @@ -0,0 +1,46 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** +// +//===---------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_reductions_impl +/// extension. +//===---------------------------------------------------------------------===// + +#pragma once +#include + +namespace py = pybind11; + +namespace dpctl::tensor::py_internal +{ + +extern void init_reduce_hypot(py::module_ m); + +} // namespace dpctl::tensor::py_internal diff --git a/dpctl_ext/tensor/libtensor/source/reductions/reduction_atomic_support.hpp b/dpctl_ext/tensor/libtensor/source/reductions/reduction_atomic_support.hpp new file mode 100644 index 000000000000..5f9cc32f1203 --- /dev/null +++ b/dpctl_ext/tensor/libtensor/source/reductions/reduction_atomic_support.hpp @@ -0,0 +1,147 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. 
+// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** +// +//===---------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_reductions_impl +/// extension. +//===---------------------------------------------------------------------===// + +#pragma once +#include + +#include + +#include "utils/type_utils.hpp" + +namespace dpctl::tensor::py_internal::atomic_support +{ + +typedef bool (*atomic_support_fn_ptr_t)(const sycl::queue &, sycl::usm::alloc); + +/*! @brief Function which returns a constant value for atomic support */ +template +bool fixed_decision(const sycl::queue &, sycl::usm::alloc) +{ + return return_value; +} + +/*! 
@brief Template for querying atomic support for a type on a device */ +template +bool check_atomic_support(const sycl::queue &exec_q, + sycl::usm::alloc usm_alloc_type) +{ + static constexpr bool atomic32 = (sizeof(T) == 4); + static constexpr bool atomic64 = (sizeof(T) == 8); + using dpctl::tensor::type_utils::is_complex; + if constexpr ((!atomic32 && !atomic64) || is_complex::value) { + return fixed_decision(exec_q, usm_alloc_type); + } + else { + bool supports_atomics = false; + const sycl::device &dev = exec_q.get_device(); + if constexpr (atomic64) { + if (!dev.has(sycl::aspect::atomic64)) { + return false; + } + } + switch (usm_alloc_type) { + case sycl::usm::alloc::shared: + supports_atomics = + dev.has(sycl::aspect::usm_atomic_shared_allocations); + break; + case sycl::usm::alloc::host: + supports_atomics = + dev.has(sycl::aspect::usm_atomic_host_allocations); + break; + case sycl::usm::alloc::device: + supports_atomics = true; + break; + default: + supports_atomics = false; + } + return supports_atomics; + } +} + +template +struct ArithmeticAtomicSupportFactory +{ + fnT get() + { + using dpctl::tensor::type_utils::is_complex; + if constexpr (std::is_floating_point_v || + std::is_same_v || is_complex::value) + { + // for real- and complex- floating point types, tree reduction has + // better round-off accumulation properties (round-off error is + // proportional to the log2(reduction_size), while naive elementwise + // summation used by atomic implementation has round-off error + // growing proportional to the reduction_size.), hence reduction + // over floating point types should always use tree_reduction + // algorithm, even though atomic implementation may be applicable + return fixed_decision; + } + else { + return check_atomic_support; + } + } +}; + +template +struct MinMaxAtomicSupportFactory +{ + fnT get() + { + return check_atomic_support; + } +}; + +template +struct MaxAtomicSupportFactory : public MinMaxAtomicSupportFactory +{ +}; + +template 
+struct MinAtomicSupportFactory : public MinMaxAtomicSupportFactory +{ +}; + +template +struct SumAtomicSupportFactory : public ArithmeticAtomicSupportFactory +{ +}; + +template +struct ProductAtomicSupportFactory + : public ArithmeticAtomicSupportFactory +{ +}; + +} // namespace dpctl::tensor::py_internal::atomic_support diff --git a/dpctl_ext/tensor/libtensor/source/reductions/reduction_common.cpp b/dpctl_ext/tensor/libtensor/source/reductions/reduction_common.cpp new file mode 100644 index 000000000000..fca5e09e2fe5 --- /dev/null +++ b/dpctl_ext/tensor/libtensor/source/reductions/reduction_common.cpp @@ -0,0 +1,69 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** +// +//===---------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_reductions_impl +/// extension. +//===---------------------------------------------------------------------===// + +#include + +#include "all.hpp" +#include "any.hpp" +#include "argmax.hpp" +#include "argmin.hpp" +#include "logsumexp.hpp" +#include "max.hpp" +#include "min.hpp" +#include "prod.hpp" +#include "reduce_hypot.hpp" +#include "sum.hpp" + +namespace py = pybind11; + +namespace dpctl::tensor::py_internal +{ + +/*! @brief Add reduction functions to Python module */ +void init_reduction_functions(py::module_ m) +{ + init_all(m); + init_any(m); + init_argmax(m); + init_argmin(m); + init_logsumexp(m); + init_max(m); + init_min(m); + init_prod(m); + init_reduce_hypot(m); + init_sum(m); +} + +} // namespace dpctl::tensor::py_internal diff --git a/dpctl_ext/tensor/libtensor/source/reductions/reduction_common.hpp b/dpctl_ext/tensor/libtensor/source/reductions/reduction_common.hpp new file mode 100644 index 000000000000..4df67c16bc4e --- /dev/null +++ b/dpctl_ext/tensor/libtensor/source/reductions/reduction_common.hpp @@ -0,0 +1,46 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. 
+// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** +// +//===---------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_reductions_impl +/// extension. 
+//===---------------------------------------------------------------------===// + +#pragma once +#include + +namespace py = pybind11; + +namespace dpctl::tensor::py_internal +{ + +extern void init_reduction_functions(py::module_); + +} // namespace dpctl::tensor::py_internal diff --git a/dpctl_ext/tensor/libtensor/source/reductions/reduction_over_axis.hpp b/dpctl_ext/tensor/libtensor/source/reductions/reduction_over_axis.hpp new file mode 100644 index 000000000000..130e61eb8e7d --- /dev/null +++ b/dpctl_ext/tensor/libtensor/source/reductions/reduction_over_axis.hpp @@ -0,0 +1,1317 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** +// +//===---------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_reductions_impl +/// extension, specifically functions for reductions. +//===---------------------------------------------------------------------===// + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include "dpnp4pybind11.hpp" +#include +#include + +#include "kernels/reductions.hpp" +#include "simplify_iteration_space.hpp" +#include "utils/memory_overlap.hpp" +#include "utils/offset_utils.hpp" +#include "utils/output_validation.hpp" +#include "utils/sycl_alloc_utils.hpp" +#include "utils/type_dispatch.hpp" + +namespace dpctl::tensor::py_internal +{ + +namespace py = pybind11; +namespace td_ns = dpctl::tensor::type_dispatch; + +/* ====================== dtype supported ======================== */ + +/*! 
@brief Template implementing Python API for querying type support by + * reduction which may support atomics */ +template +bool py_reduction_dtype_supported( + const py::dtype &input_dtype, + const py::dtype &output_dtype, + const std::string &dst_usm_type, + sycl::queue &q, + const fnT &atomic_dispatch_table, + const fnT &temps_dispatch_table, + const CheckAtomicSupportFnT &check_atomic_support) +{ + int arg_tn = + input_dtype.num(); // NumPy type numbers are the same as in dpctl + int out_tn = + output_dtype.num(); // NumPy type numbers are the same as in dpctl + int arg_typeid = -1; + int out_typeid = -1; + + auto array_types = td_ns::usm_ndarray_types(); + + try { + arg_typeid = array_types.typenum_to_lookup_id(arg_tn); + out_typeid = array_types.typenum_to_lookup_id(out_tn); + } catch (const std::exception &e) { + throw py::value_error(e.what()); + } + + if (arg_typeid < 0 || arg_typeid >= td_ns::num_types || out_typeid < 0 || + out_typeid >= td_ns::num_types) + { + throw std::runtime_error("Reduction type support check: lookup failed"); + } + + // remove_all_extents gets underlying type of table + using fn_ptrT = typename std::remove_all_extents::type; + fn_ptrT fn = nullptr; + + sycl::usm::alloc kind = sycl::usm::alloc::unknown; + + if (dst_usm_type == "device") { + kind = sycl::usm::alloc::device; + } + else if (dst_usm_type == "shared") { + kind = sycl::usm::alloc::shared; + } + else if (dst_usm_type == "host") { + kind = sycl::usm::alloc::host; + } + else { + throw py::value_error("Unrecognized `dst_usm_type` argument."); + } + + bool supports_atomics = check_atomic_support[out_typeid](q, kind); + + if (supports_atomics) { + fn = atomic_dispatch_table[arg_typeid][out_typeid]; + } + + if (fn == nullptr) { + // use slower reduction implementation using temporaries + fn = temps_dispatch_table[arg_typeid][out_typeid]; + } + + return (fn != nullptr); +} + +/*! 
@brief Template implementing Python API for querying type support by tree + * reduction */ +template +bool py_tree_reduction_dtype_supported(const py::dtype &input_dtype, + const py::dtype &output_dtype, + const fnT &temps_dispatch_table) +{ + int arg_tn = + input_dtype.num(); // NumPy type numbers are the same as in dpctl + int out_tn = + output_dtype.num(); // NumPy type numbers are the same as in dpctl + int arg_typeid = -1; + int out_typeid = -1; + + auto array_types = td_ns::usm_ndarray_types(); + + try { + arg_typeid = array_types.typenum_to_lookup_id(arg_tn); + out_typeid = array_types.typenum_to_lookup_id(out_tn); + } catch (const std::exception &e) { + throw py::value_error(e.what()); + } + + if (arg_typeid < 0 || arg_typeid >= td_ns::num_types || out_typeid < 0 || + out_typeid >= td_ns::num_types) + { + throw std::runtime_error("Reduction type support check: lookup failed"); + } + + auto fn = temps_dispatch_table[arg_typeid][out_typeid]; + + return (fn != nullptr); +} + +/* ==================== Generic reductions ====================== */ + +/*! 
@brief Template implementing Python API for reduction over axis which may + * support atomics */ +template +std::pair py_reduction_over_axis( + const dpctl::tensor::usm_ndarray &src, + int trailing_dims_to_reduce, // comp over this many trailing indexes + const dpctl::tensor::usm_ndarray &dst, + sycl::queue &exec_q, + const std::vector &depends, + const strided_fnT &atomic_dispatch_table, + const contig_fnT &axis0_atomic_dispatch_table, + const contig_fnT &axis1_atomic_dispatch_table, + const strided_fnT &temps_dispatch_table, + const contig_fnT &axis0_temps_dispatch_table, + const contig_fnT &axis1_temps_dispatch_table, + const SupportAtomicFnT &check_atomic_support) +{ + int src_nd = src.get_ndim(); + int iteration_nd = src_nd - trailing_dims_to_reduce; + if (trailing_dims_to_reduce <= 0 || iteration_nd < 0) { + throw py::value_error("Trailing_dim_to_reduce must be positive, but no " + "greater than rank of the array being reduced"); + } + + int dst_nd = dst.get_ndim(); + if (dst_nd != iteration_nd) { + throw py::value_error("Destination array rank does not match input " + "array rank and number of reduced dimensions"); + } + + const py::ssize_t *src_shape_ptr = src.get_shape_raw(); + const py::ssize_t *dst_shape_ptr = dst.get_shape_raw(); + + bool same_shapes = true; + for (int i = 0; same_shapes && (i < dst_nd); ++i) { + same_shapes = same_shapes && (src_shape_ptr[i] == dst_shape_ptr[i]); + } + + if (!same_shapes) { + throw py::value_error("Destination shape does not match unreduced " + "dimensions of the input shape"); + } + + if (!dpctl::utils::queues_are_compatible(exec_q, {src, dst})) { + throw py::value_error( + "Execution queue is not compatible with allocation queues"); + } + + dpctl::tensor::validation::CheckWritable::throw_if_not_writable(dst); + + std::size_t dst_nelems = dst.get_size(); + + if (dst_nelems == 0) { + return std::make_pair(sycl::event(), sycl::event()); + } + + std::size_t reduction_nelems(1); + for (int i = dst_nd; i < src_nd; ++i) { + 
reduction_nelems *= static_cast(src_shape_ptr[i]); + } + + // check that dst and src do not overlap + auto const &overlap = dpctl::tensor::overlap::MemoryOverlap(); + if (overlap(src, dst)) { + throw py::value_error("Arrays index overlapping segments of memory"); + } + + dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(dst, dst_nelems); + + int src_typenum = src.get_typenum(); + int dst_typenum = dst.get_typenum(); + + namespace td_ns = dpctl::tensor::type_dispatch; + const auto &array_types = td_ns::usm_ndarray_types(); + int src_typeid = array_types.typenum_to_lookup_id(src_typenum); + int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum); + + void *data_ptr = dst.get_data(); + const auto &ctx = exec_q.get_context(); + auto usm_type = sycl::get_pointer_type(data_ptr, ctx); + + bool supports_atomics = check_atomic_support[dst_typeid](exec_q, usm_type); + + // handle special case when both reduction and iteration are 1D contiguous + bool is_src_c_contig = src.is_c_contiguous(); + bool is_dst_c_contig = dst.is_c_contiguous(); + bool is_src_f_contig = src.is_f_contiguous(); + + if ((is_src_c_contig && is_dst_c_contig) || + (is_src_f_contig && dst_nelems == 1)) + { + // remove_all_extents gets underlying type of table + using contig_fn_ptr_T = + typename std::remove_all_extents::type; + contig_fn_ptr_T fn; + if (supports_atomics) { + fn = axis1_atomic_dispatch_table[src_typeid][dst_typeid]; + } + else { + fn = axis1_temps_dispatch_table[src_typeid][dst_typeid]; + } + if (fn != nullptr) { + std::size_t iter_nelems = dst_nelems; + + static constexpr py::ssize_t zero_offset = 0; + + sycl::event reduction_over_axis_contig_ev = + fn(exec_q, iter_nelems, reduction_nelems, src.get_data(), + dst.get_data(), + zero_offset, // iteration_src_offset + zero_offset, // iteration_dst_offset + zero_offset, // reduction_src_offset + depends); + + sycl::event keep_args_event = dpctl::utils::keep_args_alive( + exec_q, {src, dst}, {reduction_over_axis_contig_ev}); + + 
return std::make_pair(keep_args_event, + reduction_over_axis_contig_ev); + } + } + else if (is_src_f_contig && + ((is_dst_c_contig && dst_nd == 1) || dst.is_f_contiguous())) + { + // remove_all_extents gets underlying type of table + using contig_fn_ptr_T = + typename std::remove_all_extents::type; + contig_fn_ptr_T fn; + if (supports_atomics) { + fn = axis0_atomic_dispatch_table[src_typeid][dst_typeid]; + } + else { + fn = axis0_temps_dispatch_table[src_typeid][dst_typeid]; + } + if (fn != nullptr) { + std::size_t iter_nelems = dst_nelems; + + static constexpr py::ssize_t zero_offset = 0; + + sycl::event reduction_over_axis_contig_ev = + fn(exec_q, iter_nelems, reduction_nelems, src.get_data(), + dst.get_data(), + zero_offset, // iteration_src_offset + zero_offset, // iteration_dst_offset + zero_offset, // reduction_src_offset + depends); + + sycl::event keep_args_event = dpctl::utils::keep_args_alive( + exec_q, {src, dst}, {reduction_over_axis_contig_ev}); + + return std::make_pair(keep_args_event, + reduction_over_axis_contig_ev); + } + } + + // TODO: not used anywhere + auto const &src_shape_vecs = src.get_shape_vector(); + auto const &src_strides_vecs = src.get_strides_vector(); + auto const &dst_strides_vecs = dst.get_strides_vector(); + + int reduction_nd = trailing_dims_to_reduce; + const py::ssize_t *reduction_shape_ptr = src_shape_ptr + dst_nd; + using shT = std::vector; + shT reduction_src_strides(std::begin(src_strides_vecs) + dst_nd, + std::end(src_strides_vecs)); + + shT simplified_reduction_shape; + shT simplified_reduction_src_strides; + py::ssize_t reduction_src_offset(0); + + simplify_iteration_space_1( + reduction_nd, reduction_shape_ptr, reduction_src_strides, + // output + simplified_reduction_shape, simplified_reduction_src_strides, + reduction_src_offset); + + const py::ssize_t *iteration_shape_ptr = src_shape_ptr; + + shT iteration_src_strides(std::begin(src_strides_vecs), + std::begin(src_strides_vecs) + iteration_nd); + shT const 
&iteration_dst_strides = dst_strides_vecs; + + shT simplified_iteration_shape; + shT simplified_iteration_src_strides; + shT simplified_iteration_dst_strides; + py::ssize_t iteration_src_offset(0); + py::ssize_t iteration_dst_offset(0); + + if (iteration_nd == 0) { + if (dst_nelems != 1) { + throw std::runtime_error("iteration_nd == 0, but dst_nelems != 1"); + } + iteration_nd = 1; + simplified_iteration_shape.push_back(1); + simplified_iteration_src_strides.push_back(0); + simplified_iteration_dst_strides.push_back(0); + } + else { + simplify_iteration_space(iteration_nd, iteration_shape_ptr, + iteration_src_strides, iteration_dst_strides, + // output + simplified_iteration_shape, + simplified_iteration_src_strides, + simplified_iteration_dst_strides, + iteration_src_offset, iteration_dst_offset); + } + + if ((reduction_nd == 1) && (iteration_nd == 1)) { + bool mat_reduce_over_axis1 = false; + bool mat_reduce_over_axis0 = false; + bool array_reduce_all_elems = false; + std::size_t iter_nelems = dst_nelems; + + if (simplified_reduction_src_strides[0] == 1) { + array_reduce_all_elems = (simplified_iteration_shape[0] == 1); + mat_reduce_over_axis1 = + (simplified_iteration_dst_strides[0] == 1) && + (static_cast( + simplified_iteration_src_strides[0]) == reduction_nelems); + } + else if (static_cast( + simplified_reduction_src_strides[0]) == iter_nelems) + { + mat_reduce_over_axis0 = + (simplified_iteration_dst_strides[0] == 1) && + (simplified_iteration_src_strides[0] == 1); + } + + if (mat_reduce_over_axis1 || array_reduce_all_elems) { + using contig_fn_ptr_T = + typename std::remove_all_extents::type; + contig_fn_ptr_T fn; + if (supports_atomics) { + fn = axis1_atomic_dispatch_table[src_typeid][dst_typeid]; + } + else { + fn = axis1_temps_dispatch_table[src_typeid][dst_typeid]; + } + if (fn != nullptr) { + sycl::event reduction_over_axis1_contig_ev = + fn(exec_q, iter_nelems, reduction_nelems, src.get_data(), + dst.get_data(), iteration_src_offset, + 
iteration_dst_offset, reduction_src_offset, depends); + + sycl::event keep_args_event = dpctl::utils::keep_args_alive( + exec_q, {src, dst}, {reduction_over_axis1_contig_ev}); + + return std::make_pair(keep_args_event, + reduction_over_axis1_contig_ev); + } + } + else if (mat_reduce_over_axis0) { + using contig_fn_ptr_T = + typename std::remove_all_extents::type; + contig_fn_ptr_T fn; + if (supports_atomics) { + fn = axis0_atomic_dispatch_table[src_typeid][dst_typeid]; + } + else { + fn = axis0_temps_dispatch_table[src_typeid][dst_typeid]; + } + if (fn != nullptr) { + sycl::event reduction_over_axis0_contig_ev = + fn(exec_q, iter_nelems, reduction_nelems, src.get_data(), + dst.get_data(), iteration_src_offset, + iteration_dst_offset, reduction_src_offset, depends); + + sycl::event keep_args_event = dpctl::utils::keep_args_alive( + exec_q, {src, dst}, {reduction_over_axis0_contig_ev}); + + return std::make_pair(keep_args_event, + reduction_over_axis0_contig_ev); + } + } + } + + // remove_all_extents gets underlying type of table + using strided_fn_ptr_T = + typename std::remove_all_extents::type; + strided_fn_ptr_T fn = nullptr; + + if (supports_atomics) { + fn = atomic_dispatch_table[src_typeid][dst_typeid]; + } + + if (fn == nullptr) { + // use slower reduction implementation using temporaries + fn = temps_dispatch_table[src_typeid][dst_typeid]; + if (fn == nullptr) { + throw std::runtime_error("Datatypes are not supported"); + } + } + + std::vector host_task_events{}; + using dpctl::tensor::offset_utils::device_allocate_and_pack; + auto arrays_metainfo_packing_triple_ = + device_allocate_and_pack( + exec_q, host_task_events, + // iteration metadata + simplified_iteration_shape, simplified_iteration_src_strides, + simplified_iteration_dst_strides, + // reduction metadata + simplified_reduction_shape, simplified_reduction_src_strides); + auto tmp_alloc_owner = + std::move(std::get<0>(arrays_metainfo_packing_triple_)); + const auto ©_metadata_ev = 
std::get<2>(arrays_metainfo_packing_triple_); + const py::ssize_t *temp_allocation_ptr = tmp_alloc_owner.get(); + + const py::ssize_t *iter_shape_and_strides = temp_allocation_ptr; + const py::ssize_t *reduction_shape_stride = + temp_allocation_ptr + 3 * simplified_iteration_shape.size(); + + std::vector all_deps; + all_deps.reserve(depends.size() + 1); + all_deps.resize(depends.size()); + std::copy(depends.begin(), depends.end(), all_deps.begin()); + all_deps.push_back(copy_metadata_ev); + + auto reduction_ev = + fn(exec_q, dst_nelems, reduction_nelems, src.get_data(), dst.get_data(), + iteration_nd, iter_shape_and_strides, iteration_src_offset, + iteration_dst_offset, + reduction_nd, // number dimensions being reduced + reduction_shape_stride, reduction_src_offset, all_deps); + + sycl::event temp_cleanup_ev = dpctl::tensor::alloc_utils::async_smart_free( + exec_q, {reduction_ev}, tmp_alloc_owner); + host_task_events.push_back(temp_cleanup_ev); + + sycl::event keep_args_event = + dpctl::utils::keep_args_alive(exec_q, {src, dst}, host_task_events); + + return std::make_pair(keep_args_event, reduction_ev); +} + +/* ================= No atomic reductions ====================== */ + +/*! 
@brief Template implementing Python API for reduction over axis without + * atomics */ +template +std::pair py_tree_reduction_over_axis( + const dpctl::tensor::usm_ndarray &src, + int trailing_dims_to_reduce, // comp over this many trailing indexes + const dpctl::tensor::usm_ndarray &dst, + sycl::queue &exec_q, + const std::vector &depends, + const strided_fnT &temps_dispatch_table, + const contig_fnT &axis0_temps_dispatch_table, + const contig_fnT &axis1_temps_dispatch_table) +{ + int src_nd = src.get_ndim(); + int iteration_nd = src_nd - trailing_dims_to_reduce; + if (trailing_dims_to_reduce <= 0 || iteration_nd < 0) { + throw py::value_error("Trailing_dim_to_reduce must be positive, but no " + "greater than rank of the array being reduced"); + } + + int dst_nd = dst.get_ndim(); + if (dst_nd != iteration_nd) { + throw py::value_error("Destination array rank does not match input " + "array rank and number of reduced dimensions"); + } + + const py::ssize_t *src_shape_ptr = src.get_shape_raw(); + const py::ssize_t *dst_shape_ptr = dst.get_shape_raw(); + + bool same_shapes = true; + for (int i = 0; same_shapes && (i < dst_nd); ++i) { + same_shapes = same_shapes && (src_shape_ptr[i] == dst_shape_ptr[i]); + } + + if (!same_shapes) { + throw py::value_error("Destination shape does not match unreduced " + "dimensions of the input shape"); + } + + if (!dpctl::utils::queues_are_compatible(exec_q, {src, dst})) { + throw py::value_error( + "Execution queue is not compatible with allocation queues"); + } + + dpctl::tensor::validation::CheckWritable::throw_if_not_writable(dst); + + std::size_t dst_nelems = dst.get_size(); + + if (dst_nelems == 0) { + return std::make_pair(sycl::event(), sycl::event()); + } + + std::size_t reduction_nelems(1); + for (int i = dst_nd; i < src_nd; ++i) { + reduction_nelems *= static_cast(src_shape_ptr[i]); + } + + // check that dst and src do not overlap + auto const &overlap = dpctl::tensor::overlap::MemoryOverlap(); + if (overlap(src, dst)) { + 
throw py::value_error("Arrays index overlapping segments of memory"); + } + + dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(dst, dst_nelems); + + int src_typenum = src.get_typenum(); + int dst_typenum = dst.get_typenum(); + + namespace td_ns = dpctl::tensor::type_dispatch; + const auto &array_types = td_ns::usm_ndarray_types(); + int src_typeid = array_types.typenum_to_lookup_id(src_typenum); + int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum); + + // handle special case when both reduction and iteration are 1D contiguous + bool is_src_c_contig = src.is_c_contiguous(); + bool is_dst_c_contig = dst.is_c_contiguous(); + bool is_src_f_contig = src.is_f_contiguous(); + + if ((is_src_c_contig && is_dst_c_contig) || + (is_src_f_contig && dst_nelems == 1)) + { + auto fn = axis1_temps_dispatch_table[src_typeid][dst_typeid]; + if (fn != nullptr) { + std::size_t iter_nelems = dst_nelems; + + static constexpr py::ssize_t zero_offset = 0; + + sycl::event reduction_over_axis_contig_ev = + fn(exec_q, iter_nelems, reduction_nelems, src.get_data(), + dst.get_data(), + zero_offset, // iteration_src_offset + zero_offset, // iteration_dst_offset + zero_offset, // reduction_src_offset + depends); + + sycl::event keep_args_event = dpctl::utils::keep_args_alive( + exec_q, {src, dst}, {reduction_over_axis_contig_ev}); + + return std::make_pair(keep_args_event, + reduction_over_axis_contig_ev); + } + } + else if (is_src_f_contig && + ((is_dst_c_contig && dst_nd == 1) || dst.is_f_contiguous())) + { + auto fn = axis0_temps_dispatch_table[src_typeid][dst_typeid]; + if (fn != nullptr) { + std::size_t iter_nelems = dst_nelems; + + static constexpr py::ssize_t zero_offset = 0; + + sycl::event reduction_over_axis_contig_ev = + fn(exec_q, iter_nelems, reduction_nelems, src.get_data(), + dst.get_data(), + zero_offset, // iteration_src_offset + zero_offset, // iteration_dst_offset + zero_offset, // reduction_src_offset + depends); + + sycl::event keep_args_event = 
dpctl::utils::keep_args_alive( + exec_q, {src, dst}, {reduction_over_axis_contig_ev}); + + return std::make_pair(keep_args_event, + reduction_over_axis_contig_ev); + } + } + + auto const &src_shape_vecs = src.get_shape_vector(); + auto const &src_strides_vecs = src.get_strides_vector(); + auto const &dst_strides_vecs = dst.get_strides_vector(); + + int reduction_nd = trailing_dims_to_reduce; + const py::ssize_t *reduction_shape_ptr = src_shape_ptr + dst_nd; + using shT = std::vector; + shT reduction_src_strides(std::begin(src_strides_vecs) + dst_nd, + std::end(src_strides_vecs)); + + shT simplified_reduction_shape; + shT simplified_reduction_src_strides; + py::ssize_t reduction_src_offset(0); + + simplify_iteration_space_1( + reduction_nd, reduction_shape_ptr, reduction_src_strides, + // output + simplified_reduction_shape, simplified_reduction_src_strides, + reduction_src_offset); + + const py::ssize_t *iteration_shape_ptr = src_shape_ptr; + + shT iteration_src_strides(std::begin(src_strides_vecs), + std::begin(src_strides_vecs) + iteration_nd); + shT const &iteration_dst_strides = dst_strides_vecs; + + shT simplified_iteration_shape; + shT simplified_iteration_src_strides; + shT simplified_iteration_dst_strides; + py::ssize_t iteration_src_offset(0); + py::ssize_t iteration_dst_offset(0); + + if (iteration_nd == 0) { + if (dst_nelems != 1) { + throw std::runtime_error("iteration_nd == 0, but dst_nelems != 1"); + } + iteration_nd = 1; + simplified_iteration_shape.push_back(1); + simplified_iteration_src_strides.push_back(0); + simplified_iteration_dst_strides.push_back(0); + } + else { + simplify_iteration_space(iteration_nd, iteration_shape_ptr, + iteration_src_strides, iteration_dst_strides, + // output + simplified_iteration_shape, + simplified_iteration_src_strides, + simplified_iteration_dst_strides, + iteration_src_offset, iteration_dst_offset); + } + + if ((reduction_nd == 1) && (iteration_nd == 1)) { + bool mat_reduce_over_axis1 = false; + bool 
mat_reduce_over_axis0 = false; + bool array_reduce_all_elems = false; + std::size_t iter_nelems = dst_nelems; + + if (simplified_reduction_src_strides[0] == 1) { + array_reduce_all_elems = (simplified_iteration_shape[0] == 1); + mat_reduce_over_axis1 = + (simplified_iteration_dst_strides[0] == 1) && + (static_cast( + simplified_iteration_src_strides[0]) == reduction_nelems); + } + else if (static_cast( + simplified_reduction_src_strides[0]) == iter_nelems) + { + mat_reduce_over_axis0 = + (simplified_iteration_dst_strides[0] == 1) && + (simplified_iteration_src_strides[0] == 1); + } + + if (mat_reduce_over_axis1 || array_reduce_all_elems) { + auto fn = axis1_temps_dispatch_table[src_typeid][dst_typeid]; + if (fn != nullptr) { + sycl::event reduction_over_axis1_contig_ev = + fn(exec_q, iter_nelems, reduction_nelems, src.get_data(), + dst.get_data(), iteration_src_offset, + iteration_dst_offset, reduction_src_offset, depends); + + sycl::event keep_args_event = dpctl::utils::keep_args_alive( + exec_q, {src, dst}, {reduction_over_axis1_contig_ev}); + + return std::make_pair(keep_args_event, + reduction_over_axis1_contig_ev); + } + } + else if (mat_reduce_over_axis0) { + auto fn = axis0_temps_dispatch_table[src_typeid][dst_typeid]; + if (fn != nullptr) { + sycl::event reduction_over_axis0_contig_ev = + fn(exec_q, iter_nelems, reduction_nelems, src.get_data(), + dst.get_data(), iteration_src_offset, + iteration_dst_offset, reduction_src_offset, depends); + + sycl::event keep_args_event = dpctl::utils::keep_args_alive( + exec_q, {src, dst}, {reduction_over_axis0_contig_ev}); + + return std::make_pair(keep_args_event, + reduction_over_axis0_contig_ev); + } + } + } + + auto fn = temps_dispatch_table[src_typeid][dst_typeid]; + if (fn == nullptr) { + throw std::runtime_error("Datatypes are not supported"); + } + + std::vector host_task_events{}; + using dpctl::tensor::offset_utils::device_allocate_and_pack; + auto arrays_metainfo_packing_triple_ = + device_allocate_and_pack( + 
exec_q, host_task_events, + // iteration metadata + simplified_iteration_shape, simplified_iteration_src_strides, + simplified_iteration_dst_strides, + // reduction metadata + simplified_reduction_shape, simplified_reduction_src_strides); + auto tmp_owner = std::move(std::get<0>(arrays_metainfo_packing_triple_)); + const auto ©_metadata_ev = std::get<2>(arrays_metainfo_packing_triple_); + const py::ssize_t *temp_allocation_ptr = tmp_owner.get(); + + const py::ssize_t *iter_shape_and_strides = temp_allocation_ptr; + const py::ssize_t *reduction_shape_stride = + temp_allocation_ptr + 3 * simplified_iteration_shape.size(); + + std::vector all_deps; + all_deps.reserve(depends.size() + 1); + all_deps.resize(depends.size()); + std::copy(depends.begin(), depends.end(), all_deps.begin()); + all_deps.push_back(copy_metadata_ev); + + auto reduction_ev = + fn(exec_q, dst_nelems, reduction_nelems, src.get_data(), dst.get_data(), + iteration_nd, iter_shape_and_strides, iteration_src_offset, + iteration_dst_offset, + reduction_nd, // number dimensions being reduced + reduction_shape_stride, reduction_src_offset, all_deps); + + sycl::event temp_cleanup_ev = dpctl::tensor::alloc_utils::async_smart_free( + exec_q, {reduction_ev}, tmp_owner); + host_task_events.push_back(temp_cleanup_ev); + + sycl::event keep_args_event = + dpctl::utils::keep_args_alive(exec_q, {src, dst}, host_task_events); + + return std::make_pair(keep_args_event, reduction_ev); +} + +/*! 
@brief Template implementing Python API for searching over an axis */ +template +std::pair py_search_over_axis( + const dpctl::tensor::usm_ndarray &src, + int trailing_dims_to_reduce, // comp over this many trailing indexes + const dpctl::tensor::usm_ndarray &dst, + sycl::queue &exec_q, + const std::vector &depends, + const strided_fnT &strided_dispatch_table, + const contig_fnT &axis0_contig_dispatch_table, + const contig_fnT &axis1_contig_dispatch_table) +{ + int src_nd = src.get_ndim(); + int iteration_nd = src_nd - trailing_dims_to_reduce; + if (trailing_dims_to_reduce <= 0 || iteration_nd < 0) { + throw py::value_error("Trailing_dim_to_reduce must be positive, but no " + "greater than rank of the array being reduced"); + } + + int dst_nd = dst.get_ndim(); + if (dst_nd != iteration_nd) { + throw py::value_error("Destination array rank does not match input " + "array rank and number of reduced dimensions"); + } + + const py::ssize_t *src_shape_ptr = src.get_shape_raw(); + const py::ssize_t *dst_shape_ptr = dst.get_shape_raw(); + + bool same_shapes = true; + for (int i = 0; same_shapes && (i < dst_nd); ++i) { + same_shapes = same_shapes && (src_shape_ptr[i] == dst_shape_ptr[i]); + } + + if (!same_shapes) { + throw py::value_error("Destination shape does not match unreduced " + "dimensions of the input shape"); + } + + if (!dpctl::utils::queues_are_compatible(exec_q, {src, dst})) { + throw py::value_error( + "Execution queue is not compatible with allocation queues"); + } + + dpctl::tensor::validation::CheckWritable::throw_if_not_writable(dst); + + std::size_t dst_nelems = dst.get_size(); + + if (dst_nelems == 0) { + return std::make_pair(sycl::event(), sycl::event()); + } + + std::size_t reduction_nelems(1); + for (int i = dst_nd; i < src_nd; ++i) { + reduction_nelems *= static_cast(src_shape_ptr[i]); + } + + // check that dst and src do not overlap + auto const &overlap = dpctl::tensor::overlap::MemoryOverlap(); + if (overlap(src, dst)) { + throw 
py::value_error("Arrays index overlapping segments of memory"); + } + + dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(dst, dst_nelems); + + int src_typenum = src.get_typenum(); + int dst_typenum = dst.get_typenum(); + + namespace td_ns = dpctl::tensor::type_dispatch; + const auto &array_types = td_ns::usm_ndarray_types(); + int src_typeid = array_types.typenum_to_lookup_id(src_typenum); + int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum); + + // handle special case when both reduction and iteration are 1D contiguous + bool is_src_c_contig = src.is_c_contiguous(); + bool is_dst_c_contig = dst.is_c_contiguous(); + bool is_src_f_contig = src.is_f_contiguous(); + + if (is_src_c_contig && is_dst_c_contig) { + auto fn = axis1_contig_dispatch_table[src_typeid][dst_typeid]; + if (fn != nullptr) { + std::size_t iter_nelems = dst_nelems; + + static constexpr py::ssize_t zero_offset = 0; + + sycl::event reduction_over_axis_contig_ev = + fn(exec_q, iter_nelems, reduction_nelems, src.get_data(), + dst.get_data(), + zero_offset, // iteration_src_offset + zero_offset, // iteration_dst_offset + zero_offset, // reduction_src_offset + depends); + + sycl::event keep_args_event = dpctl::utils::keep_args_alive( + exec_q, {src, dst}, {reduction_over_axis_contig_ev}); + + return std::make_pair(keep_args_event, + reduction_over_axis_contig_ev); + } + } + else if (is_src_f_contig && dst_nd == 1) { + auto fn = axis0_contig_dispatch_table[src_typeid][dst_typeid]; + if (fn != nullptr) { + std::size_t iter_nelems = dst_nelems; + + static constexpr py::ssize_t zero_offset = 0; + + sycl::event reduction_over_axis_contig_ev = + fn(exec_q, iter_nelems, reduction_nelems, src.get_data(), + dst.get_data(), + zero_offset, // iteration_src_offset + zero_offset, // iteration_dst_offset + zero_offset, // reduction_src_offset + depends); + + sycl::event keep_args_event = dpctl::utils::keep_args_alive( + exec_q, {src, dst}, {reduction_over_axis_contig_ev}); + + return 
std::make_pair(keep_args_event, + reduction_over_axis_contig_ev); + } + } + + auto const &src_shape_vecs = src.get_shape_vector(); + auto const &src_strides_vecs = src.get_strides_vector(); + auto const &dst_strides_vecs = dst.get_strides_vector(); + + int reduction_nd = trailing_dims_to_reduce; + const py::ssize_t *reduction_shape_ptr = src_shape_ptr + dst_nd; + using shT = std::vector; + shT reduction_src_strides(std::begin(src_strides_vecs) + dst_nd, + std::end(src_strides_vecs)); + + shT compact_reduction_shape; + shT compact_reduction_src_strides; + py::ssize_t reduction_src_offset(0); + + // TODO: not used anywhere + compact_iteration_space( + reduction_nd, reduction_shape_ptr, reduction_src_strides, + // output + compact_reduction_shape, compact_reduction_src_strides); + + const py::ssize_t *iteration_shape_ptr = src_shape_ptr; + + shT iteration_src_strides(std::begin(src_strides_vecs), + std::begin(src_strides_vecs) + iteration_nd); + shT const &iteration_dst_strides = dst_strides_vecs; + + shT simplified_iteration_shape; + shT simplified_iteration_src_strides; + shT simplified_iteration_dst_strides; + py::ssize_t iteration_src_offset(0); + py::ssize_t iteration_dst_offset(0); + + if (iteration_nd == 0) { + if (dst_nelems != 1) { + throw std::runtime_error("iteration_nd == 0, but dst_nelems != 1"); + } + iteration_nd = 1; + simplified_iteration_shape.push_back(1); + simplified_iteration_src_strides.push_back(0); + simplified_iteration_dst_strides.push_back(0); + } + else { + simplify_iteration_space(iteration_nd, iteration_shape_ptr, + iteration_src_strides, iteration_dst_strides, + // output + simplified_iteration_shape, + simplified_iteration_src_strides, + simplified_iteration_dst_strides, + iteration_src_offset, iteration_dst_offset); + } + + if ((reduction_nd == 1) && (iteration_nd == 1)) { + bool mat_reduce_over_axis1 = false; + bool mat_reduce_over_axis0 = false; + std::size_t iter_nelems = dst_nelems; + + if (compact_reduction_src_strides[0] == 1) { 
+ mat_reduce_over_axis1 = + (simplified_iteration_dst_strides[0] == 1) && + (static_cast( + simplified_iteration_src_strides[0]) == reduction_nelems); + } + else if (static_cast(compact_reduction_src_strides[0]) == + iter_nelems) + { + mat_reduce_over_axis0 = + (simplified_iteration_dst_strides[0] == 1) && + (simplified_iteration_src_strides[0] == 1); + } + + if (mat_reduce_over_axis1) { + auto fn = axis1_contig_dispatch_table[src_typeid][dst_typeid]; + if (fn != nullptr) { + sycl::event reduction_over_axis1_contig_ev = + fn(exec_q, iter_nelems, reduction_nelems, src.get_data(), + dst.get_data(), iteration_src_offset, + iteration_dst_offset, reduction_src_offset, depends); + + sycl::event keep_args_event = dpctl::utils::keep_args_alive( + exec_q, {src, dst}, {reduction_over_axis1_contig_ev}); + + return std::make_pair(keep_args_event, + reduction_over_axis1_contig_ev); + } + } + else if (mat_reduce_over_axis0) { + auto fn = axis0_contig_dispatch_table[src_typeid][dst_typeid]; + if (fn != nullptr) { + sycl::event reduction_over_axis0_contig_ev = + fn(exec_q, iter_nelems, reduction_nelems, src.get_data(), + dst.get_data(), iteration_src_offset, + iteration_dst_offset, reduction_src_offset, depends); + + sycl::event keep_args_event = dpctl::utils::keep_args_alive( + exec_q, {src, dst}, {reduction_over_axis0_contig_ev}); + + return std::make_pair(keep_args_event, + reduction_over_axis0_contig_ev); + } + } + } + + auto fn = strided_dispatch_table[src_typeid][dst_typeid]; + if (fn == nullptr) { + throw std::runtime_error("Datatypes are not supported"); + } + + std::vector host_task_events{}; + + using dpctl::tensor::offset_utils::device_allocate_and_pack; + + auto arrays_metainfo_packing_triple_ = + device_allocate_and_pack( + exec_q, host_task_events, + // iteration metadata + simplified_iteration_shape, simplified_iteration_src_strides, + simplified_iteration_dst_strides, + // reduction metadata + compact_reduction_shape, compact_reduction_src_strides); + auto 
tmp_owner = std::move(std::get<0>(arrays_metainfo_packing_triple_)); + const auto ©_metadata_ev = std::get<2>(arrays_metainfo_packing_triple_); + const py::ssize_t *temp_allocation_ptr = tmp_owner.get(); + + const py::ssize_t *iter_shape_and_strides = temp_allocation_ptr; + const py::ssize_t *reduction_shape_stride = + temp_allocation_ptr + 3 * simplified_iteration_shape.size(); + + std::vector all_deps; + all_deps.reserve(depends.size() + 1); + all_deps.resize(depends.size()); + std::copy(depends.begin(), depends.end(), all_deps.begin()); + all_deps.push_back(copy_metadata_ev); + + auto comp_ev = fn(exec_q, dst_nelems, reduction_nelems, src.get_data(), + dst.get_data(), iteration_nd, iter_shape_and_strides, + iteration_src_offset, iteration_dst_offset, + reduction_nd, // number dimensions being reduced + reduction_shape_stride, reduction_src_offset, all_deps); + + sycl::event temp_cleanup_ev = dpctl::tensor::alloc_utils::async_smart_free( + exec_q, {comp_ev}, tmp_owner); + host_task_events.push_back(temp_cleanup_ev); + + sycl::event keep_args_event = + dpctl::utils::keep_args_alive(exec_q, {src, dst}, host_task_events); + + return std::make_pair(keep_args_event, comp_ev); +} + +/* ================= Atomic only reductions ====================== */ + +/*! 
@brief Template implementing Python API for boolean reductions over an axis + */ +template +std::pair + py_boolean_reduction(const dpctl::tensor::usm_ndarray &src, + int trailing_dims_to_reduce, + const dpctl::tensor::usm_ndarray &dst, + sycl::queue &exec_q, + const std::vector &depends, + const contig_dispatchT &axis1_contig_dispatch_vector, + const contig_dispatchT &axis0_contig_dispatch_vector, + const strided_dispatchT &strided_dispatch_vector, + const atomic_support_fnT check_atomic_support) +{ + int src_nd = src.get_ndim(); + int iter_nd = src_nd - trailing_dims_to_reduce; + if (trailing_dims_to_reduce <= 0 || iter_nd < 0) { + throw py::value_error("Trailing_dim_to_reduce must be positive, but no " + "greater than rank of the array being reduced"); + } + + int dst_nd = dst.get_ndim(); + if (dst_nd != iter_nd) { + throw py::value_error("Destination array rank does not match input " + "array rank and number of reduced dimensions"); + } + + const py::ssize_t *src_shape_ptr = src.get_shape_raw(); + const py::ssize_t *dst_shape_ptr = dst.get_shape_raw(); + + bool same_shapes = true; + for (int i = 0; same_shapes && (i < dst_nd); ++i) { + same_shapes = same_shapes && (src_shape_ptr[i] == dst_shape_ptr[i]); + } + + if (!same_shapes) { + throw py::value_error("Destination shape does not match unreduced " + "dimensions of the input shape"); + } + + if (!dpctl::utils::queues_are_compatible(exec_q, {src, dst})) { + throw py::value_error( + "Execution queue is not compatible with allocation queues"); + } + + dpctl::tensor::validation::CheckWritable::throw_if_not_writable(dst); + + std::size_t dst_nelems = dst.get_size(); + + std::size_t red_nelems(1); + for (int i = dst_nd; i < src_nd; ++i) { + red_nelems *= static_cast(src_shape_ptr[i]); + } + + auto const &overlap = dpctl::tensor::overlap::MemoryOverlap(); + if (overlap(dst, src)) { + throw py::value_error("Arrays are expected to have no memory overlap"); + } + + 
dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(dst, dst_nelems); + + const char *src_data = src.get_data(); + char *dst_data = dst.get_data(); + + int src_typenum = src.get_typenum(); + int dst_typenum = dst.get_typenum(); + + const auto &array_types = td_ns::usm_ndarray_types(); + int src_typeid = array_types.typenum_to_lookup_id(src_typenum); + int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum); + + static constexpr int int32_typeid = + static_cast(td_ns::typenum_t::INT32); + if (dst_typeid != int32_typeid) { + throw py::value_error( + "Unexpected data type of destination array, expecting 'int32'"); + } + + void *data_ptr = dst.get_data(); + const auto &ctx = exec_q.get_context(); + auto usm_type = sycl::get_pointer_type(data_ptr, ctx); + + bool supports_atomics = check_atomic_support(exec_q, usm_type); + if (!supports_atomics) { + throw py::value_error( + "This reduction is not supported for this device and usm_type."); + } + + bool is_src_c_contig = src.is_c_contiguous(); + bool is_src_f_contig = src.is_f_contiguous(); + bool is_dst_c_contig = dst.is_c_contiguous(); + + // TODO: should be dst_nelems == 0? 
+ if ((is_src_c_contig && is_dst_c_contig) || + (is_src_f_contig && dst_nelems == 0)) + { + auto fn = axis1_contig_dispatch_vector[src_typeid]; + static constexpr py::ssize_t zero_offset = 0; + + sycl::event red_ev = + fn(exec_q, dst_nelems, red_nelems, src_data, dst_data, zero_offset, + zero_offset, zero_offset, depends); + + sycl::event keep_args_event = + dpctl::utils::keep_args_alive(exec_q, {src, dst}, {red_ev}); + + return std::make_pair(keep_args_event, red_ev); + } + else if (is_src_f_contig && + ((is_dst_c_contig && dst_nd == 1) || dst.is_f_contiguous())) + { + auto fn = axis0_contig_dispatch_vector[src_typeid]; + static constexpr py::ssize_t zero_offset = 0; + + sycl::event red_ev = + fn(exec_q, dst_nelems, red_nelems, src_data, dst_data, zero_offset, + zero_offset, zero_offset, depends); + + sycl::event keep_args_event = + dpctl::utils::keep_args_alive(exec_q, {src, dst}, {red_ev}); + + return std::make_pair(keep_args_event, red_ev); + } + + auto src_shape_vecs = src.get_shape_vector(); + auto src_strides_vecs = src.get_strides_vector(); + auto dst_strides_vecs = dst.get_strides_vector(); + + int simplified_red_nd = trailing_dims_to_reduce; + + using shT = std::vector; + shT red_src_strides(std::begin(src_strides_vecs) + dst_nd, + std::end(src_strides_vecs)); + + shT simplified_red_shape; + shT simplified_red_src_strides; + py::ssize_t red_src_offset(0); + + simplify_iteration_space_1( + simplified_red_nd, src_shape_ptr + dst_nd, red_src_strides, + // output + simplified_red_shape, simplified_red_src_strides, red_src_offset); + + shT iter_src_strides(std::begin(src_strides_vecs), + std::begin(src_strides_vecs) + iter_nd); + shT const &iter_dst_strides = dst_strides_vecs; + + shT simplified_iter_shape; + shT simplified_iter_src_strides; + shT simplified_iter_dst_strides; + py::ssize_t iter_src_offset(0); + py::ssize_t iter_dst_offset(0); + + if (iter_nd == 0) { + if (dst_nelems != 1) { + throw std::runtime_error("iteration_nd == 0, but dst_nelems != 1"); 
+ } + iter_nd = 1; + simplified_iter_shape.push_back(1); + simplified_iter_src_strides.push_back(0); + simplified_iter_dst_strides.push_back(0); + } + else { + simplify_iteration_space( + iter_nd, src_shape_ptr, iter_src_strides, iter_dst_strides, + // output + simplified_iter_shape, simplified_iter_src_strides, + simplified_iter_dst_strides, iter_src_offset, iter_dst_offset); + } + + if (simplified_red_nd == 1 && iter_nd == 1) { + bool mat_reduce_over_axis1 = false; + bool mat_reduce_over_axis0 = false; + bool array_reduce_all_elems = false; + std::size_t iter_nelems = dst_nelems; + + if (simplified_red_src_strides[0] == 1) { + array_reduce_all_elems = (simplified_iter_shape[0] == 1); + mat_reduce_over_axis1 = + (simplified_iter_dst_strides[0] == 1) && + (static_cast(simplified_iter_src_strides[0]) == + red_nelems); + } + else if (static_cast(simplified_red_src_strides[0]) == + iter_nelems) { + mat_reduce_over_axis0 = (simplified_iter_dst_strides[0] == 1) && + (simplified_iter_src_strides[0] == 1); + } + if (mat_reduce_over_axis1 || array_reduce_all_elems) { + auto fn = axis1_contig_dispatch_vector[src_typeid]; + + sycl::event red_ev = + fn(exec_q, iter_nelems, red_nelems, src_data, dst_data, + iter_src_offset, iter_dst_offset, red_src_offset, depends); + + sycl::event keep_args_event = + dpctl::utils::keep_args_alive(exec_q, {src, dst}, {red_ev}); + + return std::make_pair(keep_args_event, red_ev); + } + else if (mat_reduce_over_axis0) { + auto fn = axis0_contig_dispatch_vector[src_typeid]; + + sycl::event red_ev = + fn(exec_q, iter_nelems, red_nelems, src_data, dst_data, + iter_src_offset, iter_dst_offset, red_src_offset, depends); + + sycl::event keep_args_event = + dpctl::utils::keep_args_alive(exec_q, {src, dst}, {red_ev}); + + return std::make_pair(keep_args_event, red_ev); + } + } + + auto fn = strided_dispatch_vector[src_typeid]; + + std::vector host_task_events{}; + auto iter_red_metadata_packing_triple_ = + 
dpctl::tensor::offset_utils::device_allocate_and_pack( + exec_q, host_task_events, simplified_iter_shape, + simplified_iter_src_strides, simplified_iter_dst_strides, + simplified_red_shape, simplified_red_src_strides); + auto packed_shapes_strides_owner = + std::move(std::get<0>(iter_red_metadata_packing_triple_)); + const auto ©_metadata_ev = + std::get<2>(iter_red_metadata_packing_triple_); + const py::ssize_t *packed_shapes_and_strides = + packed_shapes_strides_owner.get(); + + const py::ssize_t *iter_shape_and_strides = packed_shapes_and_strides; + const py::ssize_t *red_shape_stride = + packed_shapes_and_strides + 3 * simplified_iter_shape.size(); + + std::vector all_deps; + all_deps.reserve(depends.size() + 1); + all_deps.resize(depends.size()); + std::copy(depends.begin(), depends.end(), all_deps.begin()); + all_deps.push_back(copy_metadata_ev); + + auto red_ev = + fn(exec_q, dst_nelems, red_nelems, src_data, dst_data, iter_nd, + iter_shape_and_strides, iter_src_offset, iter_dst_offset, + simplified_red_nd, red_shape_stride, red_src_offset, all_deps); + + sycl::event temp_cleanup_ev = dpctl::tensor::alloc_utils::async_smart_free( + exec_q, {red_ev}, packed_shapes_strides_owner); + host_task_events.push_back(temp_cleanup_ev); + + sycl::event keep_args_event = + dpctl::utils::keep_args_alive(exec_q, {src, dst}, host_task_events); + + return std::make_pair(keep_args_event, red_ev); +} + +extern void init_reduction_functions(py::module_ m); + +} // namespace dpctl::tensor::py_internal diff --git a/dpctl_ext/tensor/libtensor/source/reductions/sum.cpp b/dpctl_ext/tensor/libtensor/source/reductions/sum.cpp new file mode 100644 index 000000000000..294adfc93a26 --- /dev/null +++ b/dpctl_ext/tensor/libtensor/source/reductions/sum.cpp @@ -0,0 +1,461 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. 
+// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** +// +//===---------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_reductions_impl +/// extension. 
+//===---------------------------------------------------------------------===// + +#include +#include +#include +#include + +#include + +#include "dpnp4pybind11.hpp" +#include +#include + +#include "kernels/reductions.hpp" +#include "utils/type_dispatch_building.hpp" + +#include "reduction_atomic_support.hpp" +#include "reduction_over_axis.hpp" + +namespace dpctl::tensor::py_internal +{ + +namespace py = pybind11; +namespace td_ns = dpctl::tensor::type_dispatch; + +namespace impl +{ + +using dpctl::tensor::kernels::reduction_strided_impl_fn_ptr; +static reduction_strided_impl_fn_ptr + sum_over_axis_strided_atomic_dispatch_table[td_ns::num_types] + [td_ns::num_types]; +static reduction_strided_impl_fn_ptr + sum_over_axis_strided_temps_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + +using dpctl::tensor::kernels::reduction_contig_impl_fn_ptr; +static reduction_contig_impl_fn_ptr + sum_over_axis1_contig_atomic_dispatch_table[td_ns::num_types] + [td_ns::num_types]; +static reduction_contig_impl_fn_ptr + sum_over_axis0_contig_atomic_dispatch_table[td_ns::num_types] + [td_ns::num_types]; +static reduction_contig_impl_fn_ptr + sum_over_axis1_contig_temps_dispatch_table[td_ns::num_types] + [td_ns::num_types]; +static reduction_contig_impl_fn_ptr + sum_over_axis0_contig_temps_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + +/* @brief Types supported by plus-reduction code based on atomic_ref */ +template +struct TypePairSupportDataForSumReductionAtomic +{ + + /* value if true a kernel for must be instantiated, false + * otherwise */ + static constexpr bool is_defined = std::disjunction< + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + // input int8 + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + // input uint8 + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + // input int16 + 
td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + // input uint16 + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + // input int32 + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + // input uint32 + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + // input int64 + td_ns::TypePairDefinedEntry, + // input uint64 + td_ns::TypePairDefinedEntry, + // fall-through + td_ns::NotDefinedEntry>::is_defined; +}; + +template +struct TypePairSupportDataForSumReductionTemps +{ + + static constexpr bool is_defined = std::disjunction< + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input int8_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input uint8_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input int16_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input uint16_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input int32_t + td_ns::TypePairDefinedEntry, + 
td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input uint32_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input int64_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input uint64_t + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + + // input half + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns:: + TypePairDefinedEntry>, + td_ns::TypePairDefinedEntry>, + + // input float + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry>, + td_ns::TypePairDefinedEntry>, + + // input double + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry>, + + // input std::complex + td_ns::TypePairDefinedEntry, + outTy, + std::complex>, + td_ns::TypePairDefinedEntry, + outTy, + std::complex>, + + td_ns::TypePairDefinedEntry, + outTy, + std::complex>, + + // fall-through + td_ns::NotDefinedEntry>::is_defined; +}; + +template +struct SumOverAxisAtomicStridedFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForSumReductionAtomic< + srcTy, dstTy>::is_defined) + { + using ReductionOpT = sycl::plus; + return dpctl::tensor::kernels:: + reduction_over_group_with_atomics_strided_impl; + } + else { + return nullptr; + } + } +}; + +template +struct SumOverAxisTempsStridedFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForSumReductionTemps< + srcTy, dstTy>::is_defined) { + using ReductionOpT = + std::conditional_t, + sycl::logical_or, sycl::plus>; + return dpctl::tensor::kernels:: + reduction_over_group_temps_strided_impl; + } + else { + return nullptr; + } + } +}; + +template +struct SumOverAxis1AtomicContigFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForSumReductionAtomic< + srcTy, dstTy>::is_defined) + { + 
using ReductionOpT = sycl::plus; + return dpctl::tensor::kernels:: + reduction_axis1_over_group_with_atomics_contig_impl< + srcTy, dstTy, ReductionOpT>; + } + else { + return nullptr; + } + } +}; + +template +struct SumOverAxis0AtomicContigFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForSumReductionAtomic< + srcTy, dstTy>::is_defined) + { + using ReductionOpT = sycl::plus; + return dpctl::tensor::kernels:: + reduction_axis0_over_group_with_atomics_contig_impl< + srcTy, dstTy, ReductionOpT>; + } + else { + return nullptr; + } + } +}; + +template +struct SumOverAxis1TempsContigFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForSumReductionTemps< + srcTy, dstTy>::is_defined) { + using ReductionOpT = + std::conditional_t, + sycl::logical_or, sycl::plus>; + return dpctl::tensor::kernels:: + reduction_axis1_over_group_temps_contig_impl; + } + else { + return nullptr; + } + } +}; + +template +struct SumOverAxis0TempsContigFactory +{ + fnT get() const + { + if constexpr (TypePairSupportDataForSumReductionTemps< + srcTy, dstTy>::is_defined) { + using ReductionOpT = + std::conditional_t, + sycl::logical_or, sycl::plus>; + return dpctl::tensor::kernels:: + reduction_axis0_over_group_temps_contig_impl; + } + else { + return nullptr; + } + } +}; + +void populate_sum_over_axis_dispatch_tables(void) +{ + using dpctl::tensor::kernels::reduction_contig_impl_fn_ptr; + using dpctl::tensor::kernels::reduction_strided_impl_fn_ptr; + using namespace td_ns; + + DispatchTableBuilder + dtb1; + dtb1.populate_dispatch_table(sum_over_axis_strided_atomic_dispatch_table); + + DispatchTableBuilder + dtb2; + dtb2.populate_dispatch_table(sum_over_axis_strided_temps_dispatch_table); + + DispatchTableBuilder + dtb3; + dtb3.populate_dispatch_table(sum_over_axis1_contig_atomic_dispatch_table); + + DispatchTableBuilder + dtb4; + dtb4.populate_dispatch_table(sum_over_axis0_contig_atomic_dispatch_table); + + DispatchTableBuilder + dtb5; + 
dtb5.populate_dispatch_table(sum_over_axis1_contig_temps_dispatch_table); + + DispatchTableBuilder + dtb6; + dtb6.populate_dispatch_table(sum_over_axis0_contig_temps_dispatch_table); +} + +using atomic_support::atomic_support_fn_ptr_t; +static atomic_support_fn_ptr_t sum_atomic_support_vector[td_ns::num_types]; + +void populate_sum_atomic_support_dispatch_vector(void) +{ + using td_ns::DispatchVectorBuilder; + + using atomic_support::SumAtomicSupportFactory; + DispatchVectorBuilder + dvb; + dvb.populate_dispatch_vector(sum_atomic_support_vector); +} + +} // namespace impl + +void init_sum(py::module_ m) +{ + using arrayT = dpctl::tensor::usm_ndarray; + using event_vecT = std::vector; + { + using impl::populate_sum_over_axis_dispatch_tables; + populate_sum_over_axis_dispatch_tables(); + using impl::sum_over_axis0_contig_atomic_dispatch_table; + using impl::sum_over_axis0_contig_temps_dispatch_table; + using impl::sum_over_axis1_contig_atomic_dispatch_table; + using impl::sum_over_axis1_contig_temps_dispatch_table; + using impl::sum_over_axis_strided_atomic_dispatch_table; + using impl::sum_over_axis_strided_temps_dispatch_table; + + using impl::populate_sum_atomic_support_dispatch_vector; + populate_sum_atomic_support_dispatch_vector(); + using impl::sum_atomic_support_vector; + + auto sum_pyapi = [&](const arrayT &src, int trailing_dims_to_reduce, + const arrayT &dst, sycl::queue &exec_q, + const event_vecT &depends = {}) { + return py_reduction_over_axis( + src, trailing_dims_to_reduce, dst, exec_q, depends, + sum_over_axis_strided_atomic_dispatch_table, + sum_over_axis0_contig_atomic_dispatch_table, + sum_over_axis1_contig_atomic_dispatch_table, + sum_over_axis_strided_temps_dispatch_table, + sum_over_axis0_contig_temps_dispatch_table, + sum_over_axis1_contig_temps_dispatch_table, + sum_atomic_support_vector); + }; + m.def("_sum_over_axis", sum_pyapi, "", py::arg("src"), + py::arg("trailing_dims_to_reduce"), py::arg("dst"), + py::arg("sycl_queue"), 
py::arg("depends") = py::list()); + + auto sum_dtype_supported = + [&](const py::dtype &input_dtype, const py::dtype &output_dtype, + const std::string &dst_usm_type, sycl::queue &q) { + return py_reduction_dtype_supported( + input_dtype, output_dtype, dst_usm_type, q, + sum_over_axis_strided_atomic_dispatch_table, + sum_over_axis_strided_temps_dispatch_table, + sum_atomic_support_vector); + }; + m.def("_sum_over_axis_dtype_supported", sum_dtype_supported, "", + py::arg("arg_dtype"), py::arg("out_dtype"), + py::arg("dst_usm_type"), py::arg("sycl_queue")); + } +} + +} // namespace dpctl::tensor::py_internal diff --git a/dpctl_ext/tensor/libtensor/source/reductions/sum.hpp b/dpctl_ext/tensor/libtensor/source/reductions/sum.hpp new file mode 100644 index 000000000000..08add902a049 --- /dev/null +++ b/dpctl_ext/tensor/libtensor/source/reductions/sum.hpp @@ -0,0 +1,46 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** +// +//===---------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_reductions_impl +/// extension. +//===---------------------------------------------------------------------===// + +#pragma once +#include + +namespace py = pybind11; + +namespace dpctl::tensor::py_internal +{ + +extern void init_sum(py::module_ m); + +} // namespace dpctl::tensor::py_internal diff --git a/dpctl_ext/tensor/libtensor/source/sorting/isin.cpp b/dpctl_ext/tensor/libtensor/source/sorting/isin.cpp new file mode 100644 index 000000000000..f1ae5863bbb9 --- /dev/null +++ b/dpctl_ext/tensor/libtensor/source/sorting/isin.cpp @@ -0,0 +1,325 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. 
+// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_sorting_impl +/// extension. 
+//===----------------------------------------------------------------------===// + +#include +#include +#include +#include +#include + +#include + +#include "dpnp4pybind11.hpp" +#include +#include + +#include "kernels/sorting/isin.hpp" +#include "utils/memory_overlap.hpp" +#include "utils/offset_utils.hpp" +#include "utils/output_validation.hpp" +#include "utils/sycl_alloc_utils.hpp" +#include "utils/type_dispatch.hpp" + +#include "simplify_iteration_space.hpp" + +namespace py = pybind11; +namespace td_ns = dpctl::tensor::type_dispatch; + +namespace dpctl::tensor::py_internal +{ +namespace detail +{ + +using dpctl::tensor::kernels::isin_contig_impl_fp_ptr_t; + +static isin_contig_impl_fp_ptr_t + isin_contig_impl_dispatch_vector[td_ns::num_types]; + +template +struct IsinContigFactory +{ + constexpr IsinContigFactory() {} + + fnT get() const + { + using dpctl::tensor::kernels::isin_contig_impl; + return isin_contig_impl; + } +}; + +using dpctl::tensor::kernels::isin_strided_impl_fp_ptr_t; + +static isin_strided_impl_fp_ptr_t + isin_strided_impl_dispatch_vector[td_ns::num_types]; + +template +struct IsinStridedFactory +{ + constexpr IsinStridedFactory() {} + + fnT get() const + { + using dpctl::tensor::kernels::isin_strided_impl; + return isin_strided_impl; + } +}; + +void init_isin_dispatch_vector(void) +{ + + // Contiguous input function dispatch + td_ns::DispatchVectorBuilder + dvb1; + dvb1.populate_dispatch_vector(isin_contig_impl_dispatch_vector); + + // Strided input function dispatch + td_ns::DispatchVectorBuilder + dvb2; + dvb2.populate_dispatch_vector(isin_strided_impl_dispatch_vector); +} + +} // namespace detail + +/*! 
@brief search for needle from needles in sorted hay */ +std::pair + py_isin(const dpctl::tensor::usm_ndarray &needles, + const dpctl::tensor::usm_ndarray &hay, + const dpctl::tensor::usm_ndarray &dst, + sycl::queue &exec_q, + const bool invert, + const std::vector &depends) +{ + const int hay_nd = hay.get_ndim(); + const int needles_nd = needles.get_ndim(); + const int dst_nd = dst.get_ndim(); + + if (hay_nd != 1 || needles_nd != dst_nd) { + throw py::value_error("Array dimensions mismatch"); + } + + // check that needle and dst have the same shape + std::size_t needles_nelems(1); + bool same_shape(true); + + const std::size_t hay_nelems = static_cast(hay.get_shape(0)); + + const py::ssize_t *needles_shape_ptr = needles.get_shape_raw(); + const py::ssize_t *dst_shape_ptr = dst.get_shape_raw(); + + for (int i = 0; (i < needles_nd) && same_shape; ++i) { + const auto needles_sh_i = needles_shape_ptr[i]; + const auto dst_sh_i = dst_shape_ptr[i]; + + same_shape = same_shape && (needles_sh_i == dst_sh_i); + needles_nelems *= static_cast(needles_sh_i); + } + + if (!same_shape) { + throw py::value_error( + "Array of values to search for and array of their " + "dst do not have the same shape"); + } + + // check that dst is ample enough + dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(dst, + needles_nelems); + + // check that dst is writable + dpctl::tensor::validation::CheckWritable::throw_if_not_writable(dst); + + // check that queues are compatible + if (!dpctl::utils::queues_are_compatible(exec_q, {hay, needles, dst})) { + throw py::value_error( + "Execution queue is not compatible with allocation queues"); + } + + // if output array overlaps with input arrays, race condition results + auto const &overlap = dpctl::tensor::overlap::MemoryOverlap(); + if (overlap(dst, hay) || overlap(dst, needles)) { + throw py::value_error("Destination array overlaps with input."); + } + + const int hay_typenum = hay.get_typenum(); + const int needles_typenum = 
needles.get_typenum(); + const int dst_typenum = dst.get_typenum(); + + auto const &array_types = td_ns::usm_ndarray_types(); + const int hay_typeid = array_types.typenum_to_lookup_id(hay_typenum); + const int needles_typeid = + array_types.typenum_to_lookup_id(needles_typenum); + const int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum); + + // check hay and needle have the same data-type + if (needles_typeid != hay_typeid) { + throw py::value_error( + "Hay array and needles array must have the same data types"); + } + // check that dst has boolean data type + const auto dst_typenum_t_v = static_cast(dst_typeid); + if (dst_typenum_t_v != td_ns::typenum_t::BOOL) { + throw py::value_error("dst array must have data-type bool"); + } + + if (needles_nelems == 0) { + // Nothing to do + return std::make_pair(sycl::event{}, sycl::event{}); + } + + // if all inputs are contiguous call contiguous implementations + // otherwise call strided implementation + const bool hay_is_c_contig = hay.is_c_contiguous(); + const bool hay_is_f_contig = hay.is_f_contiguous(); + + const bool needles_is_c_contig = needles.is_c_contiguous(); + const bool needles_is_f_contig = needles.is_f_contiguous(); + + const bool dst_is_c_contig = dst.is_c_contiguous(); + const bool dst_is_f_contig = dst.is_f_contiguous(); + + const bool all_c_contig = + (hay_is_c_contig && needles_is_c_contig && dst_is_c_contig); + const bool all_f_contig = + (hay_is_f_contig && needles_is_f_contig && dst_is_f_contig); + + const char *hay_data = hay.get_data(); + const char *needles_data = needles.get_data(); + + char *dst_data = dst.get_data(); + + if (all_c_contig || all_f_contig) { + auto fn = detail::isin_contig_impl_dispatch_vector[hay_typeid]; + + static constexpr py::ssize_t zero_offset(0); + + sycl::event comp_ev = fn(exec_q, invert, hay_nelems, needles_nelems, + hay_data, zero_offset, needles_data, + zero_offset, dst_data, zero_offset, depends); + + return std::make_pair(dpctl::utils::keep_args_alive( 
+ exec_q, {hay, needles, dst}, {comp_ev}), + comp_ev); + } + + // strided case + + const auto &needles_strides = needles.get_strides_vector(); + const auto &dst_strides = dst.get_strides_vector(); + + int simplified_nd = needles_nd; + + using shT = std::vector; + shT simplified_common_shape; + shT simplified_needles_strides; + shT simplified_dst_strides; + py::ssize_t needles_offset(0); + py::ssize_t dst_offset(0); + + if (simplified_nd == 0) { + // needles and dst have same nd + simplified_nd = 1; + simplified_common_shape.push_back(1); + simplified_needles_strides.push_back(0); + simplified_dst_strides.push_back(0); + } + else { + simplify_iteration_space( + // modified by reference + simplified_nd, + // read-only inputs + needles_shape_ptr, needles_strides, dst_strides, + // output, modified by reference + simplified_common_shape, simplified_needles_strides, + simplified_dst_strides, needles_offset, dst_offset); + } + std::vector host_task_events; + host_task_events.reserve(2); + + using dpctl::tensor::offset_utils::device_allocate_and_pack; + + auto ptr_size_event_tuple = device_allocate_and_pack( + exec_q, host_task_events, + // vectors being packed + simplified_common_shape, simplified_needles_strides, + simplified_dst_strides); + auto packed_shape_strides_owner = + std::move(std::get<0>(ptr_size_event_tuple)); + const sycl::event ©_shape_strides_ev = + std::get<2>(ptr_size_event_tuple); + const py::ssize_t *packed_shape_strides = packed_shape_strides_owner.get(); + + std::vector all_deps; + all_deps.reserve(depends.size() + 1); + all_deps.insert(all_deps.end(), depends.begin(), depends.end()); + all_deps.push_back(copy_shape_strides_ev); + + auto strided_fn = detail::isin_strided_impl_dispatch_vector[hay_typeid]; + + if (!strided_fn) { + throw std::runtime_error( + "No implementation for data types of input arrays"); + } + + static constexpr py::ssize_t zero_offset(0); + py::ssize_t hay_step = hay.get_strides_vector()[0]; + + const sycl::event &comp_ev = 
strided_fn( + exec_q, invert, hay_nelems, needles_nelems, hay_data, zero_offset, + hay_step, needles_data, needles_offset, dst_data, dst_offset, + simplified_nd, packed_shape_strides, all_deps); + + // free packed temporaries + sycl::event temporaries_cleanup_ev = + dpctl::tensor::alloc_utils::async_smart_free( + exec_q, {comp_ev}, packed_shape_strides_owner); + + host_task_events.push_back(temporaries_cleanup_ev); + const sycl::event &ht_ev = dpctl::utils::keep_args_alive( + exec_q, {hay, needles, dst}, host_task_events); + + return std::make_pair(ht_ev, comp_ev); +} + +void init_isin_functions(py::module_ m) +{ + detail::init_isin_dispatch_vector(); + + m.def("_isin", &py_isin, py::arg("needles"), py::arg("hay"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("invert"), + py::arg("depends") = py::list()); +} + +} // namespace dpctl::tensor::py_internal diff --git a/dpctl_ext/tensor/libtensor/source/sorting/isin.hpp b/dpctl_ext/tensor/libtensor/source/sorting/isin.hpp new file mode 100644 index 000000000000..236e8b5898c6 --- /dev/null +++ b/dpctl_ext/tensor/libtensor/source/sorting/isin.hpp @@ -0,0 +1,47 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. 
+// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_sorting_impl +/// extension. +//===----------------------------------------------------------------------===// + +#pragma once + +#include + +namespace py = pybind11; + +namespace dpctl::tensor::py_internal +{ + +extern void init_isin_functions(py::module_ m); + +} // namespace dpctl::tensor::py_internal diff --git a/dpctl_ext/tensor/libtensor/source/sorting/merge_argsort.cpp b/dpctl_ext/tensor/libtensor/source/sorting/merge_argsort.cpp new file mode 100644 index 000000000000..2b6dcc8bf447 --- /dev/null +++ b/dpctl_ext/tensor/libtensor/source/sorting/merge_argsort.cpp @@ -0,0 +1,157 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. 
+// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_sorting_impl +/// extension. 
+//===----------------------------------------------------------------------===// + +#include +#include +#include +#include + +#include + +#include "dpnp4pybind11.hpp" +#include +#include + +#include "utils/rich_comparisons.hpp" +#include "utils/type_dispatch.hpp" + +#include "kernels/sorting/merge_sort.hpp" +#include "kernels/sorting/sort_impl_fn_ptr_t.hpp" + +#include "merge_argsort.hpp" +#include "py_argsort_common.hpp" + +namespace dpctl::tensor::py_internal +{ + +namespace py = pybind11; +namespace td_ns = dpctl::tensor::type_dispatch; + +using dpctl::tensor::kernels::sort_contig_fn_ptr_t; +static sort_contig_fn_ptr_t + ascending_argsort_contig_dispatch_table[td_ns::num_types][td_ns::num_types]; +static sort_contig_fn_ptr_t + descending_argsort_contig_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + +template +struct AscendingArgSortContigFactory +{ + fnT get() + { + if constexpr (std::is_same_v || + std::is_same_v) + { + using dpctl::tensor::rich_comparisons::AscendingSorter; + using Comp = typename AscendingSorter::type; + + using dpctl::tensor::kernels::stable_argsort_axis1_contig_impl; + return stable_argsort_axis1_contig_impl; + } + else { + return nullptr; + } + } +}; + +template +struct DescendingArgSortContigFactory +{ + fnT get() + { + if constexpr (std::is_same_v || + std::is_same_v) + { + using dpctl::tensor::rich_comparisons::DescendingSorter; + using Comp = typename DescendingSorter::type; + + using dpctl::tensor::kernels::stable_argsort_axis1_contig_impl; + return stable_argsort_axis1_contig_impl; + } + else { + return nullptr; + } + } +}; + +void init_merge_argsort_dispatch_tables(void) +{ + using dpctl::tensor::kernels::sort_contig_fn_ptr_t; + + td_ns::DispatchTableBuilder + dtb1; + dtb1.populate_dispatch_table(ascending_argsort_contig_dispatch_table); + + td_ns::DispatchTableBuilder< + sort_contig_fn_ptr_t, DescendingArgSortContigFactory, td_ns::num_types> + dtb2; + 
dtb2.populate_dispatch_table(descending_argsort_contig_dispatch_table); +} + +void init_merge_argsort_functions(py::module_ m) +{ + init_merge_argsort_dispatch_tables(); + + auto py_argsort_ascending = [](const dpctl::tensor::usm_ndarray &src, + const int trailing_dims_to_sort, + const dpctl::tensor::usm_ndarray &dst, + sycl::queue &exec_q, + const std::vector &depends) + -> std::pair { + return py_argsort(src, trailing_dims_to_sort, dst, exec_q, depends, + ascending_argsort_contig_dispatch_table); + }; + m.def("_argsort_ascending", py_argsort_ascending, py::arg("src"), + py::arg("trailing_dims_to_sort"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); + + auto py_argsort_descending = [](const dpctl::tensor::usm_ndarray &src, + const int trailing_dims_to_sort, + const dpctl::tensor::usm_ndarray &dst, + sycl::queue &exec_q, + const std::vector &depends) + -> std::pair { + return py_argsort(src, trailing_dims_to_sort, dst, exec_q, depends, + descending_argsort_contig_dispatch_table); + }; + m.def("_argsort_descending", py_argsort_descending, py::arg("src"), + py::arg("trailing_dims_to_sort"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); + + return; +} + +} // namespace dpctl::tensor::py_internal diff --git a/dpctl_ext/tensor/libtensor/source/sorting/merge_argsort.hpp b/dpctl_ext/tensor/libtensor/source/sorting/merge_argsort.hpp new file mode 100644 index 000000000000..10777b4bc2fd --- /dev/null +++ b/dpctl_ext/tensor/libtensor/source/sorting/merge_argsort.hpp @@ -0,0 +1,47 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. 
+// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_sorting_impl +/// extension. 
+//===----------------------------------------------------------------------===// + +#pragma once + +#include + +namespace py = pybind11; + +namespace dpctl::tensor::py_internal +{ + +extern void init_merge_argsort_functions(py::module_); + +} // namespace dpctl::tensor::py_internal diff --git a/dpctl_ext/tensor/libtensor/source/sorting/merge_sort.cpp b/dpctl_ext/tensor/libtensor/source/sorting/merge_sort.cpp new file mode 100644 index 000000000000..fbd60621b3bb --- /dev/null +++ b/dpctl_ext/tensor/libtensor/source/sorting/merge_sort.cpp @@ -0,0 +1,139 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_sorting_impl +/// extension. +//===----------------------------------------------------------------------===// + +#include +#include + +#include + +#include "dpnp4pybind11.hpp" +#include +#include + +#include "utils/rich_comparisons.hpp" +#include "utils/type_dispatch.hpp" + +#include "kernels/sorting/merge_sort.hpp" +#include "kernels/sorting/sort_impl_fn_ptr_t.hpp" + +#include "merge_sort.hpp" +#include "py_sort_common.hpp" + +namespace td_ns = dpctl::tensor::type_dispatch; + +namespace dpctl::tensor::py_internal +{ + +using dpctl::tensor::kernels::sort_contig_fn_ptr_t; +static sort_contig_fn_ptr_t + ascending_sort_contig_dispatch_vector[td_ns::num_types]; +static sort_contig_fn_ptr_t + descending_sort_contig_dispatch_vector[td_ns::num_types]; + +template +struct AscendingSortContigFactory +{ + fnT get() + { + using dpctl::tensor::rich_comparisons::AscendingSorter; + using Comp = typename AscendingSorter::type; + + using dpctl::tensor::kernels::stable_sort_axis1_contig_impl; + return stable_sort_axis1_contig_impl; + } +}; + +template +struct DescendingSortContigFactory +{ + fnT get() + { + using dpctl::tensor::rich_comparisons::DescendingSorter; + using Comp = typename DescendingSorter::type; + 
+ using dpctl::tensor::kernels::stable_sort_axis1_contig_impl; + return stable_sort_axis1_contig_impl; + } +}; + +void init_merge_sort_dispatch_vectors(void) +{ + using dpctl::tensor::kernels::sort_contig_fn_ptr_t; + + td_ns::DispatchVectorBuilder + dtv1; + dtv1.populate_dispatch_vector(ascending_sort_contig_dispatch_vector); + + td_ns::DispatchVectorBuilder + dtv2; + dtv2.populate_dispatch_vector(descending_sort_contig_dispatch_vector); +} + +void init_merge_sort_functions(py::module_ m) +{ + init_merge_sort_dispatch_vectors(); + + auto py_sort_ascending = [](const dpctl::tensor::usm_ndarray &src, + const int trailing_dims_to_sort, + const dpctl::tensor::usm_ndarray &dst, + sycl::queue &exec_q, + const std::vector &depends) + -> std::pair { + return py_sort(src, trailing_dims_to_sort, dst, exec_q, depends, + ascending_sort_contig_dispatch_vector); + }; + m.def("_sort_ascending", py_sort_ascending, py::arg("src"), + py::arg("trailing_dims_to_sort"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); + + auto py_sort_descending = [](const dpctl::tensor::usm_ndarray &src, + const int trailing_dims_to_sort, + const dpctl::tensor::usm_ndarray &dst, + sycl::queue &exec_q, + const std::vector &depends) + -> std::pair { + return py_sort(src, trailing_dims_to_sort, dst, exec_q, depends, + descending_sort_contig_dispatch_vector); + }; + m.def("_sort_descending", py_sort_descending, py::arg("src"), + py::arg("trailing_dims_to_sort"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); + + return; +} + +} // namespace dpctl::tensor::py_internal diff --git a/dpctl_ext/tensor/libtensor/source/sorting/merge_sort.hpp b/dpctl_ext/tensor/libtensor/source/sorting/merge_sort.hpp new file mode 100644 index 000000000000..a6bdd0a4efe9 --- /dev/null +++ b/dpctl_ext/tensor/libtensor/source/sorting/merge_sort.hpp @@ -0,0 +1,47 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel 
Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_sorting_impl +/// extension. 
+//===----------------------------------------------------------------------===// + +#pragma once + +#include + +namespace py = pybind11; + +namespace dpctl::tensor::py_internal +{ + +extern void init_merge_sort_functions(py::module_); + +} // namespace dpctl::tensor::py_internal diff --git a/dpctl_ext/tensor/libtensor/source/sorting/py_argsort_common.hpp b/dpctl_ext/tensor/libtensor/source/sorting/py_argsort_common.hpp new file mode 100644 index 000000000000..6328b3339376 --- /dev/null +++ b/dpctl_ext/tensor/libtensor/source/sorting/py_argsort_common.hpp @@ -0,0 +1,184 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_sorting_impl +/// extension. +//===----------------------------------------------------------------------===// + +#pragma once + +#include +#include +#include +#include + +#include + +#include "dpnp4pybind11.hpp" +#include + +#include "utils/memory_overlap.hpp" +#include "utils/output_validation.hpp" +#include "utils/type_dispatch.hpp" + +namespace td_ns = dpctl::tensor::type_dispatch; + +namespace dpctl::tensor::py_internal +{ + +template +std::pair + py_argsort(const dpctl::tensor::usm_ndarray &src, + const int trailing_dims_to_sort, + const dpctl::tensor::usm_ndarray &dst, + sycl::queue &exec_q, + const std::vector &depends, + const sorting_contig_impl_fnT &sort_contig_fns) +{ + int src_nd = src.get_ndim(); + int dst_nd = dst.get_ndim(); + if (src_nd != dst_nd) { + throw py::value_error("The input and output arrays must have " + "the same array ranks"); + } + int iteration_nd = src_nd - trailing_dims_to_sort; + if (trailing_dims_to_sort <= 0 || iteration_nd < 0) { + throw py::value_error("Trailing_dim_to_sort must be positive, but no " + "greater than rank of the array being sorted"); + } + + const py::ssize_t *src_shape_ptr = src.get_shape_raw(); + const py::ssize_t *dst_shape_ptr = 
dst.get_shape_raw(); + + bool same_shapes = true; + std::size_t iter_nelems(1); + + for (int i = 0; same_shapes && (i < iteration_nd); ++i) { + auto src_shape_i = src_shape_ptr[i]; + same_shapes = same_shapes && (src_shape_i == dst_shape_ptr[i]); + iter_nelems *= static_cast(src_shape_i); + } + + std::size_t sort_nelems(1); + for (int i = iteration_nd; same_shapes && (i < src_nd); ++i) { + auto src_shape_i = src_shape_ptr[i]; + same_shapes = same_shapes && (src_shape_i == dst_shape_ptr[i]); + sort_nelems *= static_cast(src_shape_i); + } + + if (!same_shapes) { + throw py::value_error( + "Destination shape does not match the input shape"); + } + + if (!dpctl::utils::queues_are_compatible(exec_q, {src, dst})) { + throw py::value_error( + "Execution queue is not compatible with allocation queues"); + } + + dpctl::tensor::validation::CheckWritable::throw_if_not_writable(dst); + + if ((iter_nelems == 0) || (sort_nelems == 0)) { + // Nothing to do + return std::make_pair(sycl::event(), sycl::event()); + } + + // check that dst and src do not overlap + auto const &overlap = dpctl::tensor::overlap::MemoryOverlap(); + if (overlap(src, dst)) { + throw py::value_error("Arrays index overlapping segments of memory"); + } + + dpctl::tensor::validation::AmpleMemory::throw_if_not_ample( + dst, sort_nelems * iter_nelems); + + int src_typenum = src.get_typenum(); + int dst_typenum = dst.get_typenum(); + + const auto &array_types = td_ns::usm_ndarray_types(); + int src_typeid = array_types.typenum_to_lookup_id(src_typenum); + int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum); + + if ((dst_typeid != static_cast(td_ns::typenum_t::INT64)) && + (dst_typeid != static_cast(td_ns::typenum_t::INT32))) + { + throw py::value_error( + "Output index array must have data type int32 or int64"); + } + + bool is_src_c_contig = src.is_c_contiguous(); + bool is_dst_c_contig = dst.is_c_contiguous(); + + if (is_src_c_contig && is_dst_c_contig) { + if (sort_nelems > 1) { + static constexpr 
py::ssize_t zero_offset = py::ssize_t(0); + + auto fn = sort_contig_fns[src_typeid][dst_typeid]; + + if (fn == nullptr) { + throw py::value_error( + "Not implemented for dtypes of input arrays"); + } + + sycl::event comp_ev = + fn(exec_q, iter_nelems, sort_nelems, src.get_data(), + dst.get_data(), zero_offset, zero_offset, zero_offset, + zero_offset, depends); + + sycl::event keep_args_alive_ev = + dpctl::utils::keep_args_alive(exec_q, {src, dst}, {comp_ev}); + + return std::make_pair(keep_args_alive_ev, comp_ev); + } + else { + assert(dst.get_size() == iter_nelems); + int dst_elemsize = dst.get_elemsize(); + static constexpr int memset_val(0); + + sycl::event fill_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + cgh.memset(reinterpret_cast(dst.get_data()), memset_val, + iter_nelems * dst_elemsize); + }); + + sycl::event keep_args_alive_ev = + dpctl::utils::keep_args_alive(exec_q, {src, dst}, {fill_ev}); + + return std::make_pair(keep_args_alive_ev, fill_ev); + } + } + + throw py::value_error( + "Both source and destination arrays must be C-contiguous"); +} + +} // namespace dpctl::tensor::py_internal diff --git a/dpctl_ext/tensor/libtensor/source/sorting/py_sort_common.hpp b/dpctl_ext/tensor/libtensor/source/sorting/py_sort_common.hpp new file mode 100644 index 000000000000..ee8777f35077 --- /dev/null +++ b/dpctl_ext/tensor/libtensor/source/sorting/py_sort_common.hpp @@ -0,0 +1,178 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. 
+// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_sorting_impl +/// extension. 
+//===----------------------------------------------------------------------===// + +#pragma once + +#include +#include +#include +#include + +#include + +#include "dpnp4pybind11.hpp" +#include + +#include "utils/memory_overlap.hpp" +#include "utils/output_validation.hpp" +#include "utils/type_dispatch.hpp" + +namespace py = pybind11; +namespace td_ns = dpctl::tensor::type_dispatch; + +namespace dpctl::tensor::py_internal +{ + +template +std::pair + py_sort(const dpctl::tensor::usm_ndarray &src, + const int trailing_dims_to_sort, + const dpctl::tensor::usm_ndarray &dst, + sycl::queue &exec_q, + const std::vector &depends, + const sorting_contig_impl_fnT &sort_contig_fns) +{ + int src_nd = src.get_ndim(); + int dst_nd = dst.get_ndim(); + if (src_nd != dst_nd) { + throw py::value_error("The input and output arrays must have " + "the same array ranks"); + } + int iteration_nd = src_nd - trailing_dims_to_sort; + if (trailing_dims_to_sort <= 0 || iteration_nd < 0) { + throw py::value_error("Trailing_dim_to_sort must be positive, but no " + "greater than rank of the array being sorted"); + } + + const py::ssize_t *src_shape_ptr = src.get_shape_raw(); + const py::ssize_t *dst_shape_ptr = dst.get_shape_raw(); + + bool same_shapes = true; + std::size_t iter_nelems(1); + + for (int i = 0; same_shapes && (i < iteration_nd); ++i) { + auto src_shape_i = src_shape_ptr[i]; + same_shapes = same_shapes && (src_shape_i == dst_shape_ptr[i]); + iter_nelems *= static_cast(src_shape_i); + } + + std::size_t sort_nelems(1); + for (int i = iteration_nd; same_shapes && (i < src_nd); ++i) { + auto src_shape_i = src_shape_ptr[i]; + same_shapes = same_shapes && (src_shape_i == dst_shape_ptr[i]); + sort_nelems *= static_cast(src_shape_i); + } + + if (!same_shapes) { + throw py::value_error( + "Destination shape does not match the input shape"); + } + + if (!dpctl::utils::queues_are_compatible(exec_q, {src, dst})) { + throw py::value_error( + "Execution queue is not compatible with allocation 
queues"); + } + + dpctl::tensor::validation::CheckWritable::throw_if_not_writable(dst); + + if ((iter_nelems == 0) || (sort_nelems == 0)) { + // Nothing to do + return std::make_pair(sycl::event(), sycl::event()); + } + + // check that dst and src do not overlap + auto const &overlap = dpctl::tensor::overlap::MemoryOverlap(); + if (overlap(src, dst)) { + throw py::value_error("Arrays index overlapping segments of memory"); + } + + dpctl::tensor::validation::AmpleMemory::throw_if_not_ample( + dst, sort_nelems * iter_nelems); + + int src_typenum = src.get_typenum(); + int dst_typenum = dst.get_typenum(); + + const auto &array_types = td_ns::usm_ndarray_types(); + int src_typeid = array_types.typenum_to_lookup_id(src_typenum); + int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum); + + if (src_typeid != dst_typeid) { + throw py::value_error("Both input arrays must have " + "the same value data type"); + } + + bool is_src_c_contig = src.is_c_contiguous(); + bool is_dst_c_contig = dst.is_c_contiguous(); + + if (is_src_c_contig && is_dst_c_contig) { + if (sort_nelems > 1) { + static constexpr py::ssize_t zero_offset = py::ssize_t(0); + + auto fn = sort_contig_fns[src_typeid]; + + if (nullptr == fn) { + throw py::value_error( + "Not implemented for the dtype of input arrays"); + } + + sycl::event comp_ev = + fn(exec_q, iter_nelems, sort_nelems, src.get_data(), + dst.get_data(), zero_offset, zero_offset, zero_offset, + zero_offset, depends); + + sycl::event keep_args_alive_ev = + dpctl::utils::keep_args_alive(exec_q, {src, dst}, {comp_ev}); + + return std::make_pair(keep_args_alive_ev, comp_ev); + } + else { + assert(dst.get_size() == iter_nelems); + int src_elemsize = src.get_elemsize(); + + sycl::event copy_ev = + exec_q.copy(src.get_data(), dst.get_data(), + src_elemsize * iter_nelems, depends); + + return std::make_pair( + dpctl::utils::keep_args_alive(exec_q, {src, dst}, {copy_ev}), + copy_ev); + } + } + + throw py::value_error( + "Both source and 
destination arrays must be C-contiguous"); +} + +} // namespace dpctl::tensor::py_internal diff --git a/dpctl_ext/tensor/libtensor/source/sorting/radix_argsort.cpp b/dpctl_ext/tensor/libtensor/source/sorting/radix_argsort.cpp new file mode 100644 index 000000000000..e54b8f739a4b --- /dev/null +++ b/dpctl_ext/tensor/libtensor/source/sorting/radix_argsort.cpp @@ -0,0 +1,187 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_sorting_impl +/// extension. +//===----------------------------------------------------------------------===// + +#include +#include +#include +#include +#include + +#include + +#include "dpnp4pybind11.hpp" +#include +#include + +#include "utils/type_dispatch.hpp" + +#include "kernels/dpctl_tensor_types.hpp" +#include "kernels/sorting/radix_sort.hpp" +#include "kernels/sorting/sort_impl_fn_ptr_t.hpp" + +#include "py_argsort_common.hpp" +#include "radix_argsort.hpp" +#include "radix_sort_support.hpp" + +namespace dpctl::tensor::py_internal +{ + +namespace py = pybind11; +namespace td_ns = dpctl::tensor::type_dispatch; +namespace impl_ns = dpctl::tensor::kernels::radix_sort_details; + +using dpctl::tensor::ssize_t; +using dpctl::tensor::kernels::sort_contig_fn_ptr_t; + +static sort_contig_fn_ptr_t + ascending_radix_argsort_contig_dispatch_table[td_ns::num_types] + [td_ns::num_types]; +static sort_contig_fn_ptr_t + descending_radix_argsort_contig_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + +namespace +{ + +template +sycl::event argsort_axis1_contig_caller(sycl::queue &q, + std::size_t iter_nelems, + std::size_t sort_nelems, + const char *arg_cp, + char *res_cp, + ssize_t iter_arg_offset, 
+ ssize_t iter_res_offset, + ssize_t sort_arg_offset, + ssize_t sort_res_offset, + const std::vector &depends) +{ + using dpctl::tensor::kernels::radix_argsort_axis1_contig_impl; + + return radix_argsort_axis1_contig_impl( + q, is_ascending, iter_nelems, sort_nelems, arg_cp, res_cp, + iter_arg_offset, iter_res_offset, sort_arg_offset, sort_res_offset, + depends); +} + +} // end of anonymous namespace + +template +struct AscendingRadixArgSortContigFactory +{ + fnT get() + { + if constexpr (RadixSortSupportVector::is_defined && + (std::is_same_v || + std::is_same_v)) + { + return argsort_axis1_contig_caller< + /*ascending*/ true, argTy, IndexTy>; + } + else { + return nullptr; + } + } +}; + +template +struct DescendingRadixArgSortContigFactory +{ + fnT get() + { + if constexpr (RadixSortSupportVector::is_defined && + (std::is_same_v || + std::is_same_v)) + { + return argsort_axis1_contig_caller< + /*ascending*/ false, argTy, IndexTy>; + } + else { + return nullptr; + } + } +}; + +void init_radix_argsort_dispatch_tables(void) +{ + using dpctl::tensor::kernels::sort_contig_fn_ptr_t; + + td_ns::DispatchTableBuilder + dtb1; + dtb1.populate_dispatch_table(ascending_radix_argsort_contig_dispatch_table); + + td_ns::DispatchTableBuilder + dtb2; + dtb2.populate_dispatch_table( + descending_radix_argsort_contig_dispatch_table); +} + +void init_radix_argsort_functions(py::module_ m) +{ + init_radix_argsort_dispatch_tables(); + + auto py_radix_argsort_ascending = + [](const dpctl::tensor::usm_ndarray &src, + const int trailing_dims_to_sort, + const dpctl::tensor::usm_ndarray &dst, sycl::queue &exec_q, + const std::vector &depends) + -> std::pair { + return py_argsort(src, trailing_dims_to_sort, dst, exec_q, depends, + ascending_radix_argsort_contig_dispatch_table); + }; + m.def("_radix_argsort_ascending", py_radix_argsort_ascending, + py::arg("src"), py::arg("trailing_dims_to_sort"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); + + auto 
py_radix_argsort_descending = + [](const dpctl::tensor::usm_ndarray &src, + const int trailing_dims_to_sort, + const dpctl::tensor::usm_ndarray &dst, sycl::queue &exec_q, + const std::vector &depends) + -> std::pair { + return py_argsort(src, trailing_dims_to_sort, dst, exec_q, depends, + descending_radix_argsort_contig_dispatch_table); + }; + m.def("_radix_argsort_descending", py_radix_argsort_descending, + py::arg("src"), py::arg("trailing_dims_to_sort"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); + + return; +} + +} // namespace dpctl::tensor::py_internal diff --git a/dpctl_ext/tensor/libtensor/source/sorting/radix_argsort.hpp b/dpctl_ext/tensor/libtensor/source/sorting/radix_argsort.hpp new file mode 100644 index 000000000000..89013fbb1bdc --- /dev/null +++ b/dpctl_ext/tensor/libtensor/source/sorting/radix_argsort.hpp @@ -0,0 +1,47 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. 
+// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_sorting_impl +/// extension. +//===----------------------------------------------------------------------===// + +#pragma once + +#include + +namespace py = pybind11; + +namespace dpctl::tensor::py_internal +{ + +extern void init_radix_argsort_functions(py::module_); + +} // namespace dpctl::tensor::py_internal diff --git a/dpctl_ext/tensor/libtensor/source/sorting/radix_sort.cpp b/dpctl_ext/tensor/libtensor/source/sorting/radix_sort.cpp new file mode 100644 index 000000000000..35c71a0eb7d3 --- /dev/null +++ b/dpctl_ext/tensor/libtensor/source/sorting/radix_sort.cpp @@ -0,0 +1,188 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. 
+// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_sorting_impl +/// extension. 
+//===----------------------------------------------------------------------===// + +#include +#include +#include +#include + +#include + +#include "dpnp4pybind11.hpp" +#include +#include + +#include "utils/type_dispatch.hpp" + +#include "kernels/dpctl_tensor_types.hpp" +#include "kernels/sorting/radix_sort.hpp" +#include "kernels/sorting/sort_impl_fn_ptr_t.hpp" + +#include "py_sort_common.hpp" +#include "radix_sort.hpp" +#include "radix_sort_support.hpp" + +namespace dpctl::tensor::py_internal +{ + +namespace py = pybind11; +namespace td_ns = dpctl::tensor::type_dispatch; +namespace impl_ns = dpctl::tensor::kernels::radix_sort_details; + +using dpctl::tensor::ssize_t; +using dpctl::tensor::kernels::sort_contig_fn_ptr_t; +// Tables of contiguous radix sort implementations with td_ns::num_types +// entries each, indexed by type lookup id and filled in by +// init_radix_sort_dispatch_vectors() +static sort_contig_fn_ptr_t + ascending_radix_sort_contig_dispatch_vector[td_ns::num_types]; +static sort_contig_fn_ptr_t + descending_radix_sort_contig_dispatch_vector[td_ns::num_types]; + +namespace +{ + +/*! @brief Forwards its arguments to radix_sort_axis1_contig_impl, passing + * is_ascending right after the queue */ +template +sycl::event sort_axis1_contig_caller(sycl::queue &q, + std::size_t iter_nelems, + std::size_t sort_nelems, + const char *arg_cp, + char *res_cp, + ssize_t iter_arg_offset, + ssize_t iter_res_offset, + ssize_t sort_arg_offset, + ssize_t sort_res_offset, + const std::vector &depends) +{ + using dpctl::tensor::kernels::radix_sort_axis1_contig_impl; + + return radix_sort_axis1_contig_impl( + q, is_ascending, iter_nelems, sort_nelems, arg_cp, res_cp, + iter_arg_offset, iter_res_offset, sort_arg_offset, sort_res_offset, + depends); +} + +} // end of anonymous namespace + +/*! @brief Factory used to fill the ascending dispatch vector: get() returns + * sort_axis1_contig_caller if RadixSortSupportVector::is_defined holds and + * nullptr otherwise */ +template +struct AscendingRadixSortContigFactory +{ + fnT get() + { + if constexpr (RadixSortSupportVector::is_defined) { + return sort_axis1_contig_caller; + } + else { + return nullptr; + } + } +}; + +/*! @brief Descending-order counterpart of the factory above: get() returns + * sort_axis1_contig_caller if RadixSortSupportVector::is_defined holds and + * nullptr otherwise */ +template +struct DescendingRadixSortContigFactory +{ + fnT get() + { + if constexpr (RadixSortSupportVector::is_defined) { + return sort_axis1_contig_caller; + } + else { + return nullptr; + } + } +}; + +/*! @brief Populates the ascending and descending radix sort dispatch + * vectors */ +void init_radix_sort_dispatch_vectors(void) +{ + using
dpctl::tensor::kernels::sort_contig_fn_ptr_t; + + td_ns::DispatchVectorBuilder< + sort_contig_fn_ptr_t, AscendingRadixSortContigFactory, td_ns::num_types> + dtv1; + dtv1.populate_dispatch_vector(ascending_radix_sort_contig_dispatch_vector); + + td_ns::DispatchVectorBuilder + dtv2; + dtv2.populate_dispatch_vector(descending_radix_sort_contig_dispatch_vector); +} + +/*! @brief Returns true if the ascending dispatch vector holds a non-null + * entry for the given typenum; returns false if that entry is null or if + * the typenum lookup throws a std::exception */ +bool py_radix_sort_defined(int typenum) +{ + const auto &array_types = td_ns::usm_ndarray_types(); + + try { + int type_id = array_types.typenum_to_lookup_id(typenum); + return (nullptr != + ascending_radix_sort_contig_dispatch_vector[type_id]); + } catch (const std::exception &e) { + return false; + } +} + +/*! @brief Fills the dispatch vectors and registers _radix_sort_ascending, + * _radix_sort_descending and _radix_sort_dtype_supported in module m; both + * sort bindings delegate to py_sort with the matching dispatch vector */ +void init_radix_sort_functions(py::module_ m) +{ + init_radix_sort_dispatch_vectors(); + + auto py_radix_sort_ascending = [](const dpctl::tensor::usm_ndarray &src, + const int trailing_dims_to_sort, + const dpctl::tensor::usm_ndarray &dst, + sycl::queue &exec_q, + const std::vector &depends) + -> std::pair { + return py_sort(src, trailing_dims_to_sort, dst, exec_q, depends, + ascending_radix_sort_contig_dispatch_vector); + }; + m.def("_radix_sort_ascending", py_radix_sort_ascending, py::arg("src"), + py::arg("trailing_dims_to_sort"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); + + auto py_radix_sort_descending = [](const dpctl::tensor::usm_ndarray &src, + const int trailing_dims_to_sort, + const dpctl::tensor::usm_ndarray &dst, + sycl::queue &exec_q, + const std::vector &depends) + -> std::pair { + return py_sort(src, trailing_dims_to_sort, dst, exec_q, depends, + descending_radix_sort_contig_dispatch_vector); + }; + m.def("_radix_sort_descending", py_radix_sort_descending, py::arg("src"), + py::arg("trailing_dims_to_sort"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); + + m.def("_radix_sort_dtype_supported", py_radix_sort_defined); + + return; +} + +} // namespace dpctl::tensor::py_internal diff --git
a/dpctl_ext/tensor/libtensor/source/sorting/radix_sort.hpp b/dpctl_ext/tensor/libtensor/source/sorting/radix_sort.hpp new file mode 100644 index 000000000000..5f3c771b464b --- /dev/null +++ b/dpctl_ext/tensor/libtensor/source/sorting/radix_sort.hpp @@ -0,0 +1,47 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. 
+//***************************************************************************** +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_sorting_impl +/// extension. +//===----------------------------------------------------------------------===// + +#pragma once + +#include + +namespace py = pybind11; + +namespace dpctl::tensor::py_internal +{ + +extern void init_radix_sort_functions(py::module_); + +} // namespace dpctl::tensor::py_internal diff --git a/dpctl_ext/tensor/libtensor/source/sorting/radix_sort_support.hpp b/dpctl_ext/tensor/libtensor/source/sorting/radix_sort_support.hpp new file mode 100644 index 000000000000..8d7e55a5cd28 --- /dev/null +++ b/dpctl_ext/tensor/libtensor/source/sorting/radix_sort_support.hpp @@ -0,0 +1,78 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_sorting_impl +/// extension. +//===----------------------------------------------------------------------===// + +#pragma once + +#include +#include + +#include + +namespace dpctl::tensor::py_internal +{ + +/*! @brief Disjunction alternative carrying is_defined = true; its truth + * value is inherited from its std::bool_constant base */ +template +struct TypeDefinedEntry : std::bool_constant> +{ + static constexpr bool is_defined = true; +}; + +/*! @brief Always-true fallback alternative carrying is_defined = false; + * listed last in the disjunction below */ +struct NotDefinedEntry : std::true_type +{ + static constexpr bool is_defined = false; +}; + +/*! @brief Compile-time support query: resolver_t is the first alternative of + * the std::disjunction whose value is true (NotDefinedEntry if no + * TypeDefinedEntry matches) and is_defined is read from that alternative */ +template +struct RadixSortSupportVector +{ + using resolver_t = + typename std::disjunction, + TypeDefinedEntry, + TypeDefinedEntry, + TypeDefinedEntry, + TypeDefinedEntry, + TypeDefinedEntry, + TypeDefinedEntry, + TypeDefinedEntry, + TypeDefinedEntry, + TypeDefinedEntry, + TypeDefinedEntry, + TypeDefinedEntry, + NotDefinedEntry>; + + static constexpr bool is_defined = resolver_t::is_defined; +}; + +} // namespace dpctl::tensor::py_internal diff --git a/dpctl_ext/tensor/libtensor/source/sorting/searchsorted.cpp b/dpctl_ext/tensor/libtensor/source/sorting/searchsorted.cpp new file mode 100644 index 000000000000..8b1ce04a97d6 --- /dev/null +++ b/dpctl_ext/tensor/libtensor/source/sorting/searchsorted.cpp @@ -0,0 +1,478 @@
+//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_sorting_impl +/// extension. 
+//===----------------------------------------------------------------------===// + +#include +#include +#include +#include +#include +#include +#include + +#include + +#include "dpnp4pybind11.hpp" +#include +#include + +#include "kernels/sorting/searchsorted.hpp" +#include "utils/memory_overlap.hpp" +#include "utils/offset_utils.hpp" +#include "utils/output_validation.hpp" +#include "utils/rich_comparisons.hpp" +#include "utils/sycl_alloc_utils.hpp" +#include "utils/type_dispatch.hpp" + +#include "simplify_iteration_space.hpp" + +namespace py = pybind11; +namespace td_ns = dpctl::tensor::type_dispatch; + +namespace dpctl::tensor::py_internal +{ + +namespace detail +{ + +using dpctl::tensor::kernels::searchsorted_contig_impl_fp_ptr_t; + +// Dispatch tables of contiguous implementations, indexed as +// [hay type id][positions type id] +static searchsorted_contig_impl_fp_ptr_t + left_side_searchsorted_contig_impl[td_ns::num_types][td_ns::num_types]; + +static searchsorted_contig_impl_fp_ptr_t + right_side_searchsorted_contig_impl[td_ns::num_types][td_ns::num_types]; + +/*! @brief Factory for the left-side contiguous implementation: when the + * compile-time type check passes, get() declares left_side_search(true) and + * an AscendingSorter comparator type and returns searchsorted_contig_impl; + * otherwise it returns nullptr */ +template +struct LeftSideSearchSortedContigFactory +{ + constexpr LeftSideSearchSortedContigFactory() {} + + fnT get() const + { + if constexpr (std::is_same_v || + std::is_same_v) + { + static constexpr bool left_side_search(true); + using dpctl::tensor::kernels::searchsorted_contig_impl; + using dpctl::tensor::rich_comparisons::AscendingSorter; + + using Compare = typename AscendingSorter::type; + + return searchsorted_contig_impl; + } + else { + return nullptr; + } + } +}; + +/*! @brief Factory for the right-side contiguous implementation: same as the + * left-side factory except that it declares right_side_search(false) */ +template +struct RightSideSearchSortedContigFactory +{ + constexpr RightSideSearchSortedContigFactory() {} + + fnT get() const + { + if constexpr (std::is_same_v || + std::is_same_v) + { + static constexpr bool right_side_search(false); + + using dpctl::tensor::kernels::searchsorted_contig_impl; + using dpctl::tensor::rich_comparisons::AscendingSorter; + + using Compare = typename AscendingSorter::type; + + return searchsorted_contig_impl; + } + else { + return nullptr; + } + } +}; + +using
dpctl::tensor::kernels::searchsorted_strided_impl_fp_ptr_t; + +// Dispatch tables of strided implementations, indexed as +// [hay type id][positions type id] +static searchsorted_strided_impl_fp_ptr_t + left_side_searchsorted_strided_impl[td_ns::num_types][td_ns::num_types]; + +static searchsorted_strided_impl_fp_ptr_t + right_side_searchsorted_strided_impl[td_ns::num_types][td_ns::num_types]; + +/*! @brief Factory for the left-side strided implementation: when the + * compile-time type check passes, get() declares left_side_search(true) and + * an AscendingSorter comparator type and returns searchsorted_strided_impl; + * otherwise it returns nullptr */ +template +struct LeftSideSearchSortedStridedFactory +{ + constexpr LeftSideSearchSortedStridedFactory() {} + + fnT get() const + { + if constexpr (std::is_same_v || + std::is_same_v) + { + static constexpr bool left_side_search(true); + using dpctl::tensor::kernels::searchsorted_strided_impl; + using dpctl::tensor::rich_comparisons::AscendingSorter; + + using Compare = typename AscendingSorter::type; + + return searchsorted_strided_impl; + } + else { + return nullptr; + } + } +}; + +/*! @brief Factory for the right-side strided implementation: same as the + * left-side factory except that it declares right_side_search(false) */ +template +struct RightSideSearchSortedStridedFactory +{ + constexpr RightSideSearchSortedStridedFactory() {} + + fnT get() const + { + if constexpr (std::is_same_v || + std::is_same_v) + { + static constexpr bool right_side_search(false); + using dpctl::tensor::kernels::searchsorted_strided_impl; + using dpctl::tensor::rich_comparisons::AscendingSorter; + + using Compare = typename AscendingSorter::type; + + return searchsorted_strided_impl; + } + else { + return nullptr; + } + } +}; + +/*! @brief Populates the four dispatch tables (contiguous and strided, + * left-side and right-side) */ +void init_searchsorted_dispatch_table(void) +{ + + // Contiguous input function dispatch + td_ns::DispatchTableBuilder + dtb1; + dtb1.populate_dispatch_table(left_side_searchsorted_contig_impl); + + td_ns::DispatchTableBuilder + dtb2; + dtb2.populate_dispatch_table(right_side_searchsorted_contig_impl); + + // Strided input function dispatch + td_ns::DispatchTableBuilder + dtb3; + dtb3.populate_dispatch_table(left_side_searchsorted_strided_impl); + + td_ns::DispatchTableBuilder + dtb4; + dtb4.populate_dispatch_table(right_side_searchsorted_strided_impl); +} + +} // namespace detail + +/*!
@brief search for needle from needles in sorted hay + * + * Validates dimensions, shapes, output size and writability, queue + * compatibility, memory overlap and data types, then calls the contiguous + * implementation when all arrays are C- or all are F-contiguous and an + * implementation exists, and the strided implementation otherwise. + * Returns the pair (event from keep_args_alive, computation event); both + * are default-constructed events when needles has no elements. + * Throws py::value_error for invalid arguments and std::runtime_error when + * no strided implementation exists for the data types. */ +std::pair + py_searchsorted(const dpctl::tensor::usm_ndarray &hay, + const dpctl::tensor::usm_ndarray &needles, + const dpctl::tensor::usm_ndarray &positions, + sycl::queue &exec_q, + const bool search_left_side, + const std::vector &depends) +{ + const int hay_nd = hay.get_ndim(); + const int needles_nd = needles.get_ndim(); + const int positions_nd = positions.get_ndim(); + + if (hay_nd != 1 || needles_nd != positions_nd) { + throw py::value_error("Array dimensions mismatch"); + } + + // check that needle and positions have the same shape + std::size_t needles_nelems(1); + bool same_shape(true); + + const std::size_t hay_nelems = static_cast(hay.get_shape(0)); + + const py::ssize_t *needles_shape_ptr = needles.get_shape_raw(); + const py::ssize_t *positions_shape_ptr = positions.get_shape_raw(); + + for (int i = 0; (i < needles_nd) && same_shape; ++i) { + const auto needles_sh_i = needles_shape_ptr[i]; + const auto positions_sh_i = positions_shape_ptr[i]; + + same_shape = same_shape && (needles_sh_i == positions_sh_i); + needles_nelems *= static_cast(needles_sh_i); + } + + if (!same_shape) { + throw py::value_error( + "Array of values to search for and array of their " + "positions do not have the same shape"); + } + + // check that positions is ample enough + dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(positions, + needles_nelems); + + // check that positions is writable + dpctl::tensor::validation::CheckWritable::throw_if_not_writable(positions); + + // check that queues are compatible + if (!dpctl::utils::queues_are_compatible(exec_q, {hay, needles, positions})) + { + throw py::value_error( + "Execution queue is not compatible with allocation queues"); + } + + // if output array overlaps with input arrays, race condition results + auto const &overlap = dpctl::tensor::overlap::MemoryOverlap(); + if (overlap(positions, hay) || overlap(positions, needles)) { + throw
py::value_error("Destination array overlaps with input."); + } + + const int hay_typenum = hay.get_typenum(); + const int needles_typenum = needles.get_typenum(); + const int positions_typenum = positions.get_typenum(); + + auto const &array_types = td_ns::usm_ndarray_types(); + const int hay_typeid = array_types.typenum_to_lookup_id(hay_typenum); + const int needles_typeid = + array_types.typenum_to_lookup_id(needles_typenum); + const int positions_typeid = + array_types.typenum_to_lookup_id(positions_typenum); + + // check hay and needle have the same data-type + if (needles_typeid != hay_typeid) { + throw py::value_error( + "Hay array and needles array must have the same data types"); + } + // check that positions has indexing data-type (int32, or int64) + const auto positions_typenum_t_v = + static_cast(positions_typeid); + if (positions_typenum_t_v != td_ns::typenum_t::INT32 && + positions_typenum_t_v != td_ns::typenum_t::INT64) + { + throw py::value_error( + "Positions array must have data-type int32, or int64"); + } + + if (needles_nelems == 0) { + // Nothing to do + return std::make_pair(sycl::event{}, sycl::event{}); + } + + // if all inputs are contiguous call contiguous implementations + // otherwise call strided implementation + const bool hay_is_c_contig = hay.is_c_contiguous(); + const bool hay_is_f_contig = hay.is_f_contiguous(); + + const bool needles_is_c_contig = needles.is_c_contiguous(); + const bool needles_is_f_contig = needles.is_f_contiguous(); + + const bool positions_is_c_contig = positions.is_c_contiguous(); + const bool positions_is_f_contig = positions.is_f_contiguous(); + + const bool all_c_contig = + (hay_is_c_contig && needles_is_c_contig && positions_is_c_contig); + const bool all_f_contig = + (hay_is_f_contig && needles_is_f_contig && positions_is_f_contig); + + const char *hay_data = hay.get_data(); + const char *needles_data = needles.get_data(); + + char *positions_data = positions.get_data(); + + if (all_c_contig ||
all_f_contig) { + auto fn = + (search_left_side) + ? detail::left_side_searchsorted_contig_impl[hay_typeid] + [positions_typeid] + : detail::right_side_searchsorted_contig_impl[hay_typeid] + [positions_typeid]; + + if (fn) { + static constexpr py::ssize_t zero_offset(0); + + sycl::event comp_ev = + fn(exec_q, hay_nelems, needles_nelems, hay_data, zero_offset, + needles_data, zero_offset, positions_data, zero_offset, + depends); + + return std::make_pair( + dpctl::utils::keep_args_alive(exec_q, {hay, needles, positions}, + {comp_ev}), + comp_ev); + } + } + + // strided case + + const auto &needles_strides = needles.get_strides_vector(); + const auto &positions_strides = positions.get_strides_vector(); + + int simplified_nd = needles_nd; + + using shT = std::vector; + shT simplified_common_shape; + shT simplified_needles_strides; + shT simplified_positions_strides; + py::ssize_t needles_offset(0); + py::ssize_t positions_offset(0); + + if (simplified_nd == 0) { + // needles and positions have same nd + simplified_nd = 1; + simplified_common_shape.push_back(1); + simplified_needles_strides.push_back(0); + simplified_positions_strides.push_back(0); + } + else { + simplify_iteration_space( + // modified by reference + simplified_nd, + // read-only inputs + needles_shape_ptr, needles_strides, positions_strides, + // output, modified by reference + simplified_common_shape, simplified_needles_strides, + simplified_positions_strides, needles_offset, positions_offset); + } + std::vector host_task_events; + host_task_events.reserve(2); + + using dpctl::tensor::offset_utils::device_allocate_and_pack; + + auto ptr_size_event_tuple = device_allocate_and_pack( + exec_q, host_task_events, + // vectors being packed + simplified_common_shape, simplified_needles_strides, + simplified_positions_strides); + auto packed_shape_strides_owner = + std::move(std::get<0>(ptr_size_event_tuple)); + // event signalling that the packed shape and strides have been copied; + // added to the kernel dependencies below + const sycl::event &copy_shape_strides_ev = + std::get<2>(ptr_size_event_tuple); + const py::ssize_t
*packed_shape_strides = packed_shape_strides_owner.get(); + + std::vector all_deps; + all_deps.reserve(depends.size() + 1); + all_deps.insert(all_deps.end(), depends.begin(), depends.end()); + all_deps.push_back(copy_shape_strides_ev); + + auto strided_fn = + (search_left_side) + ? detail::left_side_searchsorted_strided_impl[hay_typeid] + [positions_typeid] + : detail::right_side_searchsorted_strided_impl[hay_typeid] + [positions_typeid]; + + if (!strided_fn) { + throw std::runtime_error( + "No implementation for data types of input arrays"); + } + + static constexpr py::ssize_t zero_offset(0); + py::ssize_t hay_step = hay.get_strides_vector()[0]; + + const sycl::event &comp_ev = strided_fn( + exec_q, hay_nelems, needles_nelems, hay_data, zero_offset, hay_step, + needles_data, needles_offset, positions_data, positions_offset, + simplified_nd, packed_shape_strides, all_deps); + + // free packed temporaries + sycl::event temporaries_cleanup_ev = + dpctl::tensor::alloc_utils::async_smart_free( + exec_q, {comp_ev}, packed_shape_strides_owner); + + host_task_events.push_back(temporaries_cleanup_ev); + const sycl::event &ht_ev = dpctl::utils::keep_args_alive( + exec_q, {hay, needles, positions}, host_task_events); + + return std::make_pair(ht_ev, comp_ev); +} + +/*! @brief search for needle from needles in sorted hay, + * hay[pos] <= needle < hay[pos + 1] + * NOTE(review): numpy.searchsorted documents side='left' as + * hay[pos - 1] < needle <= hay[pos]; confirm which convention the kernel + * implements + */ +std::pair + py_searchsorted_left(const dpctl::tensor::usm_ndarray &hay, + const dpctl::tensor::usm_ndarray &needles, + const dpctl::tensor::usm_ndarray &positions, + sycl::queue &exec_q, + const std::vector &depends) +{ + static constexpr bool side_left(true); + return py_searchsorted(hay, needles, positions, exec_q, side_left, depends); +} + +/*!
@brief search for needle from needles in sorted hay, + * hay[pos] < needle <= hay[pos + 1] + * NOTE(review): numpy.searchsorted documents side='right' as + * hay[pos - 1] <= needle < hay[pos]; confirm which convention the kernel + * implements + */ +std::pair + py_searchsorted_right(const dpctl::tensor::usm_ndarray &hay, + const dpctl::tensor::usm_ndarray &needles, + const dpctl::tensor::usm_ndarray &positions, + sycl::queue &exec_q, + const std::vector &depends) +{ + static constexpr bool side_right(false); + return py_searchsorted(hay, needles, positions, exec_q, side_right, + depends); +} + +/*! @brief Fills the dispatch tables and registers _searchsorted_left and + * _searchsorted_right in module m */ +void init_searchsorted_functions(py::module_ m) +{ + detail::init_searchsorted_dispatch_table(); + + m.def("_searchsorted_left", &py_searchsorted_left, py::arg("hay"), + py::arg("needles"), py::arg("positions"), py::arg("sycl_queue"), + py::arg("depends") = py::list()); + m.def("_searchsorted_right", &py_searchsorted_right, py::arg("hay"), + py::arg("needles"), py::arg("positions"), py::arg("sycl_queue"), + py::arg("depends") = py::list()); +} + +} // namespace dpctl::tensor::py_internal diff --git a/dpctl_ext/tensor/libtensor/source/sorting/searchsorted.hpp b/dpctl_ext/tensor/libtensor/source/sorting/searchsorted.hpp new file mode 100644 index 000000000000..b60dae1e0ec9 --- /dev/null +++ b/dpctl_ext/tensor/libtensor/source/sorting/searchsorted.hpp @@ -0,0 +1,47 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution.
+// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_sorting_impl +/// extension. +//===----------------------------------------------------------------------===// + +#pragma once + +#include + +namespace py = pybind11; + +namespace dpctl::tensor::py_internal +{ + +extern void init_searchsorted_functions(py::module_ m); + +} // namespace dpctl::tensor::py_internal diff --git a/dpctl_ext/tensor/libtensor/source/sorting/topk.cpp b/dpctl_ext/tensor/libtensor/source/sorting/topk.cpp new file mode 100644 index 000000000000..6b8344df12c8 --- /dev/null +++ b/dpctl_ext/tensor/libtensor/source/sorting/topk.cpp @@ -0,0 +1,303 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. 
+// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_sorting_impl +/// extension. 
+//===----------------------------------------------------------------------===// + +#include +#include +#include +#include +#include +#include + +#include + +#include "dpnp4pybind11.hpp" +#include +#include + +#include "kernels/sorting/topk.hpp" +#include "utils/memory_overlap.hpp" +#include "utils/output_validation.hpp" +#include "utils/rich_comparisons.hpp" +#include "utils/type_dispatch.hpp" + +#include "topk.hpp" + +namespace dpctl::tensor::py_internal +{ + +namespace td_ns = dpctl::tensor::type_dispatch; + +/*! @brief Signature of a top-k implementation; arguments are, in order: + * queue, iter_nelems, axis_nelems, k, largest, source data pointer, + * values output pointer, indices output pointer, dependencies */ +typedef sycl::event (*topk_impl_fn_ptr_t)(sycl::queue &, + std::size_t, + std::size_t, + std::size_t, + bool, + const char *, + char *, + char *, + const std::vector &); + +// Dispatch vector indexed by the type lookup id of the source array +static topk_impl_fn_ptr_t topk_dispatch_vector[td_ns::num_types]; + +namespace +{ + +// Trait choosing between the two top-k implementations: false by default, +// true for the types accepted by the partial specialization below +template +struct use_radix_sort : public std::false_type +{ +}; + +template +struct use_radix_sort< + T, + std::enable_if_t, + std::is_same, + std::is_same, + std::is_same, + std::is_same>::value>> + : public std::true_type +{ +}; + +/*! @brief Calls topk_radix_impl (with ascending = !largest) when + * use_radix_sort holds, and otherwise topk_merge_impl with a + * DescendingSorter comparator if largest is set and an AscendingSorter + * comparator if not */ +template +sycl::event topk_caller(sycl::queue &exec_q, + std::size_t iter_nelems, // number of sub-arrays + std::size_t axis_nelems, // size of each sub-array + std::size_t k, + bool largest, + const char *arg_cp, + char *vals_cp, + char *inds_cp, + const std::vector &depends) +{ + if constexpr (use_radix_sort::value) { + using dpctl::tensor::kernels::topk_radix_impl; + auto ascending = !largest; + return topk_radix_impl(exec_q, iter_nelems, axis_nelems, + k, ascending, arg_cp, vals_cp, + inds_cp, depends); + } + else { + using dpctl::tensor::kernels::topk_merge_impl; + if (largest) { + using CompTy = + typename dpctl::tensor::rich_comparisons::DescendingSorter< + argTy>::type; + return topk_merge_impl( + exec_q, iter_nelems, axis_nelems, k, arg_cp, vals_cp, inds_cp, + depends); + } + else { + using CompTy = + typename dpctl::tensor::rich_comparisons::AscendingSorter< + argTy>::type; + return topk_merge_impl( + exec_q, iter_nelems, axis_nelems, k, arg_cp, vals_cp, inds_cp,
+ depends); + } + } +} + +} // namespace + +/*! @brief Selects k elements of src and their indices into vals and inds + * + * With trailing_dims_to_search set, the leading dimensions are iterated + * over and the trailing ones are searched; without it the whole of src is + * searched and vals and inds must be one-dimensional. + * Validates ranks, shapes, k, queue compatibility, writability, memory + * overlap, output size and data types (vals must match src, inds must be + * int64), then calls the implementation from topk_dispatch_vector. + * Returns the pair (event from keep_args_alive, computation event); both + * are default-constructed events when there is nothing to search. + * Throws py::value_error for invalid arguments, including arrays that are + * not all C-contiguous. */ +std::pair + py_topk(const dpctl::tensor::usm_ndarray &src, + std::optional trailing_dims_to_search, + const std::size_t k, + const bool largest, + const dpctl::tensor::usm_ndarray &vals, + const dpctl::tensor::usm_ndarray &inds, + sycl::queue &exec_q, + const std::vector &depends) +{ + int src_nd = src.get_ndim(); + int vals_nd = vals.get_ndim(); + int inds_nd = inds.get_ndim(); + + const py::ssize_t *src_shape_ptr = src.get_shape_raw(); + const py::ssize_t *vals_shape_ptr = vals.get_shape_raw(); + const py::ssize_t *inds_shape_ptr = inds.get_shape_raw(); + + std::size_t axis_nelems(1); + std::size_t iter_nelems(1); + if (trailing_dims_to_search.has_value()) { + if (src_nd != vals_nd || src_nd != inds_nd) { + throw py::value_error("The input and output arrays must have " + "the same array ranks"); + } + + auto trailing_dims = trailing_dims_to_search.value(); + int iter_nd = src_nd - trailing_dims; + if (trailing_dims <= 0 || iter_nd < 0) { + throw py::value_error( + "trailing_dims_to_search must be positive, but no " + "greater than rank of the array being searched"); + } + + bool same_shapes = true; + for (int i = 0; same_shapes && (i < iter_nd); ++i) { + auto src_shape_i = src_shape_ptr[i]; + same_shapes = same_shapes && (src_shape_i == vals_shape_ptr[i] && + src_shape_i == inds_shape_ptr[i]); + iter_nelems *= static_cast(src_shape_i); + } + + if (!same_shapes) { + throw py::value_error( + "Destination shape does not match the input shape"); + } + + std::size_t vals_k(1); + std::size_t inds_k(1); + for (int i = iter_nd; i < src_nd; ++i) { + axis_nelems *= static_cast(src_shape_ptr[i]); + vals_k *= static_cast(vals_shape_ptr[i]); + inds_k *= static_cast(inds_shape_ptr[i]); + } + + bool valid_k = (vals_k == k && inds_k == k && axis_nelems >= k); + if (!valid_k) { + throw py::value_error("The value of k is invalid for the input and " + "destination arrays"); + } + } + else { + if (vals_nd != 1 || inds_nd != 1) { + throw
py::value_error("Output arrays must be one-dimensional"); + } + + for (int i = 0; i < src_nd; ++i) { + axis_nelems *= static_cast(src_shape_ptr[i]); + } + + bool valid_k = (axis_nelems >= k && + static_cast(vals_shape_ptr[0]) == k && + static_cast(inds_shape_ptr[0]) == k); + if (!valid_k) { + throw py::value_error("The value of k is invalid for the input and " + "destination arrays"); + } + } + + if (!dpctl::utils::queues_are_compatible(exec_q, {src, vals, inds})) { + throw py::value_error( + "Execution queue is not compatible with allocation queues"); + } + + dpctl::tensor::validation::CheckWritable::throw_if_not_writable(vals); + dpctl::tensor::validation::CheckWritable::throw_if_not_writable(inds); + + if ((iter_nelems == 0) || (axis_nelems == 0)) { + // Nothing to do + return std::make_pair(sycl::event(), sycl::event()); + } + + // both output arrays are written to, so neither may alias the input + // nor the other output + auto const &overlap = dpctl::tensor::overlap::MemoryOverlap(); + if (overlap(src, vals) || overlap(src, inds) || overlap(vals, inds)) { + throw py::value_error("Arrays index overlapping segments of memory"); + } + + dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(vals, + k * iter_nelems); + + dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(inds, + k * iter_nelems); + + int src_typenum = src.get_typenum(); + int vals_typenum = vals.get_typenum(); + int inds_typenum = inds.get_typenum(); + + const auto &array_types = td_ns::usm_ndarray_types(); + int src_typeid = array_types.typenum_to_lookup_id(src_typenum); + int vals_typeid = array_types.typenum_to_lookup_id(vals_typenum); + int inds_typeid = array_types.typenum_to_lookup_id(inds_typenum); + + if (src_typeid != vals_typeid) { + throw py::value_error("Input array and vals array must have " + "the same data type"); + } + + if (inds_typeid != static_cast(td_ns::typenum_t::INT64)) { + throw py::value_error("Inds array must have data type int64"); + } + + bool is_src_c_contig = src.is_c_contiguous(); + bool is_vals_c_contig = vals.is_c_contiguous(); + bool is_inds_c_contig =
inds.is_c_contiguous(); + + if (is_src_c_contig && is_vals_c_contig && is_inds_c_contig) { + auto fn = topk_dispatch_vector[src_typeid]; + + sycl::event comp_ev = + fn(exec_q, iter_nelems, axis_nelems, k, largest, src.get_data(), + vals.get_data(), inds.get_data(), depends); + + sycl::event keep_args_alive_ev = + dpctl::utils::keep_args_alive(exec_q, {src, vals, inds}, {comp_ev}); + + return std::make_pair(keep_args_alive_ev, comp_ev); + } + + // only the C-contiguous case is implemented; report it instead of + // returning empty events with vals and inds left unwritten + throw py::value_error( + "Input and output arrays are expected to be C-contiguous"); +} + +/*! @brief Factory for topk_dispatch_vector entries: get() declares + * IdxT = std::int64_t and returns topk_caller */ +template +struct TopKFactory +{ + fnT get() + { + using IdxT = std::int64_t; + return topk_caller; + } +}; + +/*! @brief Populates topk_dispatch_vector */ +void init_topk_dispatch_vectors(void) +{ + td_ns::DispatchVectorBuilder + dvb; + dvb.populate_dispatch_vector(topk_dispatch_vector); +} + +/*! @brief Fills the dispatch vector and registers _topk in module m */ +void init_topk_functions(py::module_ m) +{ + init_topk_dispatch_vectors(); + + m.def("_topk", &py_topk, py::arg("src"), py::arg("trailing_dims_to_search"), + py::arg("k"), py::arg("largest"), py::arg("vals"), py::arg("inds"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); +} + +} // namespace dpctl::tensor::py_internal diff --git a/dpctl_ext/tensor/libtensor/source/sorting/topk.hpp b/dpctl_ext/tensor/libtensor/source/sorting/topk.hpp new file mode 100644 index 000000000000..d39c0eefdb93 --- /dev/null +++ b/dpctl_ext/tensor/libtensor/source/sorting/topk.hpp @@ -0,0 +1,47 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer.
+// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_sorting_impl +/// extension. 
+//===----------------------------------------------------------------------===// + +#pragma once + +#include + +namespace py = pybind11; + +namespace dpctl::tensor::py_internal +{ + +extern void init_topk_functions(py::module_); + +} // namespace dpctl::tensor::py_internal diff --git a/dpctl_ext/tensor/libtensor/source/tensor_accumulation.cpp b/dpctl_ext/tensor/libtensor/source/tensor_accumulation.cpp new file mode 100644 index 000000000000..faa3fc8b52c6 --- /dev/null +++ b/dpctl_ext/tensor/libtensor/source/tensor_accumulation.cpp @@ -0,0 +1,43 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_accumulation_impl +/// extension. +//===----------------------------------------------------------------------===// + +#include + +#include "accumulators/accumulators_common.hpp" + +PYBIND11_MODULE(_tensor_accumulation_impl, m) +{ + dpctl::tensor::py_internal::init_accumulator_functions(m); +} diff --git a/dpctl_ext/tensor/libtensor/source/tensor_reductions.cpp b/dpctl_ext/tensor/libtensor/source/tensor_reductions.cpp new file mode 100644 index 000000000000..6e6a24f7b934 --- /dev/null +++ b/dpctl_ext/tensor/libtensor/source/tensor_reductions.cpp @@ -0,0 +1,43 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer.
+// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** +// +//===---------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_reductions_impl +/// extension. 
+//===---------------------------------------------------------------------===// + +#include + +#include "reductions/reduction_common.hpp" + +PYBIND11_MODULE(_tensor_reductions_impl, m) +{ + dpctl::tensor::py_internal::init_reduction_functions(m); +} diff --git a/dpctl_ext/tensor/libtensor/source/tensor_sorting.cpp b/dpctl_ext/tensor/libtensor/source/tensor_sorting.cpp new file mode 100644 index 000000000000..318c3559d77c --- /dev/null +++ b/dpctl_ext/tensor/libtensor/source/tensor_sorting.cpp @@ -0,0 +1,55 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_sorting_impl +/// extension. +//===----------------------------------------------------------------------===// + +#include + +#include "sorting/isin.hpp" +#include "sorting/merge_argsort.hpp" +#include "sorting/merge_sort.hpp" +#include "sorting/radix_argsort.hpp" +#include "sorting/radix_sort.hpp" +#include "sorting/searchsorted.hpp" +#include "sorting/topk.hpp" + +PYBIND11_MODULE(_tensor_sorting_impl, m) +{ + dpctl::tensor::py_internal::init_isin_functions(m); + dpctl::tensor::py_internal::init_merge_sort_functions(m); + dpctl::tensor::py_internal::init_merge_argsort_functions(m); + dpctl::tensor::py_internal::init_searchsorted_functions(m); + dpctl::tensor::py_internal::init_radix_sort_functions(m); + dpctl::tensor::py_internal::init_radix_argsort_functions(m); + dpctl::tensor::py_internal::init_topk_functions(m); +} diff --git a/dpnp/dpnp_iface_counting.py b/dpnp/dpnp_iface_counting.py index a4b85aa85294..a8ebafbcead7 100644 --- a/dpnp/dpnp_iface_counting.py +++ b/dpnp/dpnp_iface_counting.py @@ -39,8 +39,9 @@ """ -import dpctl.tensor as dpt - +# TODO: revert to `import dpctl.tensor...` +# when dpnp fully migrates dpctl/tensor +import dpctl_ext.tensor as dpt import dpnp diff --git 
a/dpnp/dpnp_iface_logic.py b/dpnp/dpnp_iface_logic.py index 1834f25a0485..a81416a28e43 100644 --- a/dpnp/dpnp_iface_logic.py +++ b/dpnp/dpnp_iface_logic.py @@ -44,11 +44,13 @@ # pylint: disable=no-name-in-module -import dpctl.tensor as dpt import dpctl.tensor._tensor_elementwise_impl as ti import dpctl.utils as dpu import numpy +# TODO: revert to `import dpctl.tensor...` +# when dpnp fully migrates dpctl/tensor +import dpctl_ext.tensor as dpt import dpnp import dpnp.backend.extensions.ufunc._ufunc_impl as ufi from dpnp.dpnp_algo.dpnp_elementwise_common import DPNPBinaryFunc, DPNPUnaryFunc diff --git a/dpnp/dpnp_iface_manipulation.py b/dpnp/dpnp_iface_manipulation.py index d188ae098cd9..b5afd9523d67 100644 --- a/dpnp/dpnp_iface_manipulation.py +++ b/dpnp/dpnp_iface_manipulation.py @@ -378,22 +378,24 @@ def _get_first_nan_index(usm_a): true_val = dpt_ext.asarray( True, sycl_queue=usm_a.sycl_queue, usm_type=usm_a.usm_type ) - return dpt.searchsorted(dpt.isnan(usm_a), true_val, side="left") - return dpt.searchsorted(usm_a, usm_a[-1], side="left") + return dpt_ext.searchsorted( + dpt.isnan(usm_a), true_val, side="left" + ) + return dpt_ext.searchsorted(usm_a, usm_a[-1], side="left") return None usm_ar = dpnp.get_usm_ndarray(ar) num_of_flags = (return_index, return_inverse, return_counts).count(True) if num_of_flags == 0: - usm_res = dpt.unique_values(usm_ar) + usm_res = dpt_ext.unique_values(usm_ar) usm_res = (usm_res,) # cast to a tuple to align with other cases elif num_of_flags == 1 and return_inverse: - usm_res = dpt.unique_inverse(usm_ar) + usm_res = dpt_ext.unique_inverse(usm_ar) elif num_of_flags == 1 and return_counts: - usm_res = dpt.unique_counts(usm_ar) + usm_res = dpt_ext.unique_counts(usm_ar) else: - usm_res = dpt.unique_all(usm_ar) + usm_res = dpt_ext.unique_all(usm_ar) first_nan = None if equal_nan: @@ -426,7 +428,9 @@ def _get_first_nan_index(usm_a): if first_nan is not None: # all NaNs are collapsed, so need to put a count of all NaNs # at the last 
index - dpt.sum(usm_res.counts[first_nan:], out=usm_res.counts[first_nan]) + dpt_ext.sum( + usm_res.counts[first_nan:], out=usm_res.counts[first_nan] + ) result += (usm_res.counts[: first_nan + 1],) else: result += (usm_res.counts,) diff --git a/dpnp/dpnp_iface_mathematical.py b/dpnp/dpnp_iface_mathematical.py index 06f4fe936253..cdcdd3af92e4 100644 --- a/dpnp/dpnp_iface_mathematical.py +++ b/dpnp/dpnp_iface_mathematical.py @@ -47,14 +47,13 @@ import builtins import warnings -import dpctl.tensor as dpt import dpctl.tensor._tensor_elementwise_impl as ti import dpctl.utils as dpu import numpy # TODO: revert to `import dpctl.tensor...` # when dpnp fully migrates dpctl/tensor -import dpctl_ext.tensor as dpt_ext +import dpctl_ext.tensor as dpt import dpctl_ext.tensor._type_utils as dtu import dpnp import dpnp.backend.extensions.ufunc._ufunc_impl as ufi @@ -730,7 +729,7 @@ def clip(a, /, min=None, max=None, *, out=None, order="K", **kwargs): usm_max = None if max is None else dpnp.get_usm_ndarray_or_scalar(max) usm_out = None if out is None else dpnp.get_usm_ndarray(out) - usm_res = dpt_ext.clip(usm_arr, usm_min, usm_max, out=usm_out, order=order) + usm_res = dpt.clip(usm_arr, usm_min, usm_max, out=usm_out, order=order) if out is not None and isinstance(out, dpnp_array): return out return dpnp_array._create_from_usm_ndarray(usm_res) diff --git a/dpnp/dpnp_iface_searching.py b/dpnp/dpnp_iface_searching.py index 15f52338ec7e..19279f81286a 100644 --- a/dpnp/dpnp_iface_searching.py +++ b/dpnp/dpnp_iface_searching.py @@ -39,12 +39,10 @@ """ -import dpctl.tensor as dpt - # pylint: disable=no-name-in-module # TODO: revert to `import dpctl.tensor...` # when dpnp fully migrates dpctl/tensor -import dpctl_ext.tensor as dpt_ext +import dpctl_ext.tensor as dpt import dpctl_ext.tensor._tensor_impl as dti import dpnp @@ -376,7 +374,7 @@ def searchsorted(a, v, side="left", sorter=None): usm_a = dpnp.get_usm_ndarray(a) if dpnp.isscalar(v): - usm_v = dpt_ext.asarray(v, 
sycl_queue=a.sycl_queue, usm_type=a.usm_type) + usm_v = dpt.asarray(v, sycl_queue=a.sycl_queue, usm_type=a.usm_type) else: usm_v = dpnp.get_usm_ndarray(v) @@ -474,7 +472,5 @@ def where(condition, x=None, y=None, /, *, order="K", out=None): usm_condition = dpnp.get_usm_ndarray(condition) usm_out = None if out is None else dpnp.get_usm_ndarray(out) - usm_res = dpt_ext.where( - usm_condition, usm_x, usm_y, order=order, out=usm_out - ) + usm_res = dpt.where(usm_condition, usm_x, usm_y, order=order, out=usm_out) return dpnp.get_result_array(usm_res, out) diff --git a/dpnp/dpnp_iface_sorting.py b/dpnp/dpnp_iface_sorting.py index 5f7a3829b3c9..e7abef1f4338 100644 --- a/dpnp/dpnp_iface_sorting.py +++ b/dpnp/dpnp_iface_sorting.py @@ -41,11 +41,9 @@ from collections.abc import Sequence -import dpctl.tensor as dpt - # TODO: revert to `import dpctl.tensor...` # when dpnp fully migrates dpctl/tensor -import dpctl_ext.tensor as dpt_ext +import dpctl_ext.tensor as dpt import dpnp from dpctl_ext.tensor._numpy_helper import normalize_axis_index @@ -87,7 +85,7 @@ def _wrap_sort_argsort( usm_a = dpnp.get_usm_ndarray(a) if axis is None: - usm_a = dpt_ext.reshape(usm_a, -1) + usm_a = dpt.reshape(usm_a, -1) axis = -1 axis = normalize_axis_index(axis, ndim=usm_a.ndim) diff --git a/dpnp/dpnp_iface_statistics.py b/dpnp/dpnp_iface_statistics.py index 9d3ccc40ecf5..75fe215837b9 100644 --- a/dpnp/dpnp_iface_statistics.py +++ b/dpnp/dpnp_iface_statistics.py @@ -1118,7 +1118,7 @@ def max(a, axis=None, out=None, keepdims=False, initial=None, where=True): return dpnp_wrap_reduction_call( usm_a, out, - dpt.max, + dpt_ext.max, a.dtype, axis=axis, keepdims=keepdims, @@ -1395,7 +1395,7 @@ def min(a, axis=None, out=None, keepdims=False, initial=None, where=True): return dpnp_wrap_reduction_call( usm_a, out, - dpt.min, + dpt_ext.min, a.dtype, axis=axis, keepdims=keepdims, diff --git a/dpnp/dpnp_iface_trigonometric.py b/dpnp/dpnp_iface_trigonometric.py index 9894bd304701..a17c7dfd9d9a 100644 --- 
a/dpnp/dpnp_iface_trigonometric.py +++ b/dpnp/dpnp_iface_trigonometric.py @@ -42,12 +42,11 @@ # pylint: disable=protected-access # pylint: disable=no-name-in-module - -import dpctl.tensor as dpt import dpctl.tensor._tensor_elementwise_impl as ti # TODO: revert to `import dpctl.tensor...` # when dpnp fully migrates dpctl/tensor +import dpctl_ext.tensor as dpt import dpctl_ext.tensor._type_utils as dtu import dpnp import dpnp.backend.extensions.ufunc._ufunc_impl as ufi diff --git a/pyproject.toml b/pyproject.toml index 67fb75cb5f54..09253467b8dc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -108,7 +108,7 @@ target-version = ['py310', 'py311', 'py312', 'py313', 'py314'] [tool.codespell] builtin = "clear,rare,informal,names" check-filenames = true -ignore-words-list = "amin,arange,elemt,fro,hist,ith,mone,nd,nin,sinc,vart,GroupT,AccessorT,IndexT" +ignore-words-list = "amin,arange,elemt,fro,hist,ith,mone,nd,nin,sinc,vart,GroupT,AccessorT,IndexT,fpT,OffsetT,inpT" quiet-level = 3 [tool.coverage.report]