diff --git a/dpctl_ext/tensor/CMakeLists.txt b/dpctl_ext/tensor/CMakeLists.txt index eff5e7552648..056b7c425544 100644 --- a/dpctl_ext/tensor/CMakeLists.txt +++ b/dpctl_ext/tensor/CMakeLists.txt @@ -69,10 +69,23 @@ set(_accumulator_sources ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/accumulators/cumulative_prod.cpp ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/accumulators/cumulative_sum.cpp ) +set(_sorting_sources + ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/isin.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/merge_sort.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/merge_argsort.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/radix_sort.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/radix_argsort.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/searchsorted.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/topk.cpp +) set(_tensor_accumulation_impl_sources ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/tensor_accumulation.cpp ${_accumulator_sources} ) +set(_tensor_sorting_impl_sources + ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/tensor_sorting.cpp + ${_sorting_sources} +) set(_static_lib_trgt simplify_iteration_space) @@ -101,6 +114,12 @@ add_sycl_to_target(TARGET ${python_module_name} SOURCES ${_tensor_accumulation_i target_link_libraries(${python_module_name} PRIVATE ${_static_lib_trgt}) list(APPEND _py_trgts ${python_module_name}) +set(python_module_name _tensor_sorting_impl) +pybind11_add_module(${python_module_name} MODULE ${_tensor_sorting_impl_sources}) +add_sycl_to_target(TARGET ${python_module_name} SOURCES ${_tensor_sorting_impl_sources}) +target_link_libraries(${python_module_name} PRIVATE ${_static_lib_trgt}) +list(APPEND _py_trgts ${python_module_name}) + set(_clang_prefix "") if(WIN32) set(_clang_prefix "/clang:") @@ -117,7 +136,7 @@ list( APPEND _no_fast_math_sources # ${_elementwise_sources} # ${_reduction_sources} - # ${_sorting_sources} + ${_sorting_sources} # ${_linalg_sources} ${_accumulator_sources} ) diff --git a/dpctl_ext/tensor/__init__.py b/dpctl_ext/tensor/__init__.py index 72c7536ed473..cba7c417d559 100644 --- a/dpctl_ext/tensor/__init__.py +++ b/dpctl_ext/tensor/__init__.py @@ -80,10 +80,20 @@ ) from ._reshape import reshape from ._search_functions import where +from ._searchsorted import searchsorted +from ._set_functions import ( + isin, + unique_all, + unique_counts, + unique_inverse, + unique_values, +) +from ._sorting import argsort, sort, top_k from ._type_utils import can_cast, finfo, iinfo, isdtype, result_type __all__ = [ "arange", + "argsort", "asarray", "asnumpy", "astype", @@ -108,6 +118,7 @@ "full_like", "iinfo", "isdtype", + "isin", "linspace", "meshgrid", "moveaxis", @@ -122,15 +133,22 @@ "reshape", "result_type", "roll", + "searchsorted", + "sort", "squeeze", "stack", "swapaxes", "take", "take_along_axis", "tile", + "top_k", "to_numpy", "tril", "triu", + "unique_all", + "unique_counts", + "unique_inverse", + "unique_values", "unstack", "where", "zeros", diff --git a/dpctl_ext/tensor/_searchsorted.py b/dpctl_ext/tensor/_searchsorted.py new file mode 100644 index 000000000000..2d4807fb0d0c --- /dev/null +++ b/dpctl_ext/tensor/_searchsorted.py @@ -0,0 +1,189 @@ +# ***************************************************************************** +# Copyright (c) 2026, Intel Corporation +# All rights reserved. 
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+# - Redistributions of source code must retain the above copyright notice,
+# this list of conditions and the following disclaimer.
+# - Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+# - Neither the name of the copyright holder nor the names of its contributors
+# may be used to endorse or promote products derived from this software
+# without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
+# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
+# THE POSSIBILITY OF SUCH DAMAGE.
+# *****************************************************************************
+
+
+from typing import Literal, Union
+
+import dpctl
+import dpctl.utils as du
+
+# TODO: revert to `from ._usmarray import...`
+# when dpnp fully migrates dpctl/tensor
+from dpctl.tensor._usmarray import usm_ndarray
+
+from ._copy_utils import _empty_like_orderK
+from ._ctors import empty
+from ._tensor_impl import _copy_usm_ndarray_into_usm_ndarray as ti_copy
+from ._tensor_impl import _take as ti_take
+from ._tensor_impl import (
+    default_device_index_type as ti_default_device_index_type,
+)
+from ._tensor_sorting_impl import _searchsorted_left, _searchsorted_right
+from ._type_utils import isdtype, result_type
+
+
+def searchsorted(
+    x1: usm_ndarray,
+    x2: usm_ndarray,
+    /,
+    *,
+    side: Literal["left", "right"] = "left",
+    sorter: Union[usm_ndarray, None] = None,
+) -> usm_ndarray:
+    """searchsorted(x1, x2, side='left', sorter=None)
+
+    Finds the indices into `x1` such that, if the corresponding elements
+    in `x2` were inserted before the indices, the order of `x1`, when sorted
+    in ascending order, would be preserved.
+
+    Args:
+        x1 (usm_ndarray):
+            input array. Must be a one-dimensional array. If `sorter` is
+            `None`, must be sorted in ascending order; otherwise, `sorter` must
+            be an array of indices that sort `x1` in ascending order.
+        x2 (usm_ndarray):
+            array containing search values.
+        side (Literal["left", "right"]):
+            argument controlling which index is returned if a value lands
+            exactly on an edge. If `x2` is an array of rank `N` where
+            `v = x2[n, m, ..., j]`, the element `ret[n, m, ..., j]` in the
+            return array `ret` contains the position `i` such that
+            if `side="left"`, it is the first index such that
+            `x1[i-1] < v <= x1[i]`, `0` if `v <= x1[0]`, and `x1.size`
+            if `v > x1[-1]`;
+            and if `side="right"`, it is the first position `i` such that
+            `x1[i-1] <= v < x1[i]`, `0` if `v < x1[0]`, and `x1.size`
+            if `v >= x1[-1]`. Default: `"left"`.
+        sorter (Optional[usm_ndarray]):
+            array of indices that sort `x1` in ascending order. 
The array must
+            have the same shape as `x1` and have an integral data type.
+            Out-of-bounds index values of the `sorter` array are treated
+            using the `"wrap"` mode documented in
+            :py:func:`dpctl.tensor.take`. Default: `None`.
+
+    Returns:
+        usm_ndarray:
+            an array of positions with the same shape as `x2`. The returned
+            array has the default array index data type.
+    """
+    if not isinstance(x1, usm_ndarray):
+        raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x1)}")
+    if not isinstance(x2, usm_ndarray):
+        raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x2)}")
+    if sorter is not None and not isinstance(sorter, usm_ndarray):
+        raise TypeError(
+            f"Expected dpctl.tensor.usm_ndarray, got {type(sorter)}"
+        )
+
+    if side not in ["left", "right"]:
+        raise ValueError(
+            "Unrecognized value of 'side' keyword argument. "
+            "Expected either 'left' or 'right'"
+        )
+
+    if sorter is None:
+        q = du.get_execution_queue([x1.sycl_queue, x2.sycl_queue])
+    else:
+        q = du.get_execution_queue(
+            [x1.sycl_queue, x2.sycl_queue, sorter.sycl_queue]
+        )
+    if q is None:
+        raise du.ExecutionPlacementError(
+            "Execution placement can not be unambiguously "
+            "inferred from input arguments."
+        )
+
+    if x1.ndim != 1:
+        raise ValueError("First argument array must be one-dimensional")
+
+    x1_dt = x1.dtype
+    x2_dt = x2.dtype
+
+    _manager = du.SequentialOrderManager[q]
+    dep_evs = _manager.submitted_events
+    ev = dpctl.SyclEvent()
+    if sorter is not None:
+        if not isdtype(sorter.dtype, "integral"):
+            raise ValueError(
+                f"Sorter array must have integral data type, got {sorter.dtype}"
+            )
+        if x1.shape != sorter.shape:
+            raise ValueError(
+                "Sorter array must be one-dimensional with the same "
+                "shape as the first argument array"
+            )
+        res = empty(x1.shape, dtype=x1_dt, usm_type=x1.usm_type, sycl_queue=q)
+        ind = (sorter,)
+        axis = 0
+        wrap_out_of_bound_indices_mode = 0
+        ht_ev, ev = ti_take(
+            x1,
+            ind,
+            res,
+            axis,
+            wrap_out_of_bound_indices_mode,
+            sycl_queue=q,
+            depends=dep_evs,
+        )
+        x1 = res
+        _manager.add_event_pair(ht_ev, ev)
+
+    if x1_dt != x2_dt:
+        dt = result_type(x1, x2)
+        if x1_dt != dt:
+            x1_buf = _empty_like_orderK(x1, dt)
+            dep_evs = _manager.submitted_events
+            ht_ev, ev = ti_copy(
+                src=x1, dst=x1_buf, sycl_queue=q, depends=dep_evs
+            )
+            _manager.add_event_pair(ht_ev, ev)
+            x1 = x1_buf
+        if x2_dt != dt:
+            x2_buf = _empty_like_orderK(x2, dt)
+            dep_evs = _manager.submitted_events
+            ht_ev, ev = ti_copy(
+                src=x2, dst=x2_buf, sycl_queue=q, depends=dep_evs
+            )
+            _manager.add_event_pair(ht_ev, ev)
+            x2 = x2_buf
+
+    dst_usm_type = du.get_coerced_usm_type([x1.usm_type, x2.usm_type])
+    index_dt = ti_default_device_index_type(q)
+
+    dst = _empty_like_orderK(x2, index_dt, usm_type=dst_usm_type)
+
+    dep_evs = _manager.submitted_events
+    if side == "left":
+        ht_ev, s_ev = _searchsorted_left(
+            hay=x1,
+            needles=x2,
+            positions=dst,
+            sycl_queue=q,
+            depends=dep_evs,
+        )
+    else:
+        ht_ev, s_ev = _searchsorted_right(
+            hay=x1, needles=x2, positions=dst, sycl_queue=q, depends=dep_evs
+        )
+    _manager.add_event_pair(ht_ev, s_ev)
+    return dst
diff --git a/dpctl_ext/tensor/_set_functions.py b/dpctl_ext/tensor/_set_functions.py
new file mode 100644
index 000000000000..93f81f044fd2
--- /dev/null
+++ b/dpctl_ext/tensor/_set_functions.py
@@ -0,0 +1,803 @@
+# *****************************************************************************
+# Copyright (c) 2026, Intel Corporation
+# All rights reserved.
+# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# - Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# - Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# - Neither the name of the copyright holder nor the names of its contributors +# may be used to endorse or promote products derived from this software +# without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +# THE POSSIBILITY OF SUCH DAMAGE. +# ***************************************************************************** + +from typing import NamedTuple, Optional, Union + +import dpctl.tensor as dpt +import dpctl.utils as du +from dpctl.tensor._tensor_elementwise_impl import _not_equal, _subtract + +# TODO: revert to `import dpctl.tensor...` +# when dpnp fully migrates dpctl/tensor +import dpctl_ext.tensor as dpt_ext + +from ._copy_utils import _empty_like_orderK +from ._scalar_utils import ( + _get_dtype, + _get_queue_usm_type, + _get_shape, + _validate_dtype, +) +from ._tensor_impl import ( + _copy_usm_ndarray_into_usm_ndarray, + _extract, + _full_usm_ndarray, + _linspace_step, + _take, + default_device_index_type, + mask_positions, +) +from ._tensor_sorting_impl import ( + _argsort_ascending, + _isin, + _searchsorted_left, + _sort_ascending, +) +from ._type_utils import ( + _resolve_weak_types_all_py_ints, + _to_device_supported_dtype, +) + +__all__ = [ + "isin", + "unique_values", + "unique_counts", + "unique_inverse", + "unique_all", + "UniqueAllResult", + "UniqueCountsResult", + "UniqueInverseResult", +] + + +class UniqueAllResult(NamedTuple): + values: dpt.usm_ndarray + indices: dpt.usm_ndarray + inverse_indices: dpt.usm_ndarray + counts: dpt.usm_ndarray + + +class UniqueCountsResult(NamedTuple): + values: dpt.usm_ndarray + counts: dpt.usm_ndarray + + +class UniqueInverseResult(NamedTuple): + values: dpt.usm_ndarray + inverse_indices: dpt.usm_ndarray + + +def unique_values(x: dpt.usm_ndarray) -> dpt.usm_ndarray: + """unique_values(x) + + Returns the unique elements of an input array `x`. + + Args: + x (usm_ndarray): + input array. Inputs with more than one dimension are flattened. + Returns: + usm_ndarray + an array containing the set of unique elements in `x`. The + returned array has the same data type as `x`. 
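+
+    Example:
+        An illustrative sketch, assuming a default-selected SYCL device is
+        available (results shown as comments)::
+
+            import dpctl.tensor as dpt
+            import dpctl_ext.tensor as dpt_ext
+
+            x = dpt.asarray([1, 2, 2, 3, 1])
+            dpt_ext.unique_values(x)
+            # -> usm_ndarray([1, 2, 3])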
+ """ + if not isinstance(x, dpt.usm_ndarray): + raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x)}") + array_api_dev = x.device + exec_q = array_api_dev.sycl_queue + if x.ndim == 1: + fx = x + else: + fx = dpt_ext.reshape(x, (x.size,), order="C") + if fx.size == 0: + return fx + s = dpt_ext.empty_like(fx, order="C") + _manager = du.SequentialOrderManager[exec_q] + dep_evs = _manager.submitted_events + if fx.flags.c_contiguous: + ht_ev, sort_ev = _sort_ascending( + src=fx, + trailing_dims_to_sort=1, + dst=s, + sycl_queue=exec_q, + depends=dep_evs, + ) + _manager.add_event_pair(ht_ev, sort_ev) + else: + tmp = dpt_ext.empty_like(fx, order="C") + ht_ev, copy_ev = _copy_usm_ndarray_into_usm_ndarray( + src=fx, dst=tmp, sycl_queue=exec_q, depends=dep_evs + ) + _manager.add_event_pair(ht_ev, copy_ev) + ht_ev, sort_ev = _sort_ascending( + src=tmp, + trailing_dims_to_sort=1, + dst=s, + sycl_queue=exec_q, + depends=[copy_ev], + ) + _manager.add_event_pair(ht_ev, sort_ev) + unique_mask = dpt_ext.empty(fx.shape, dtype="?", sycl_queue=exec_q) + ht_ev, uneq_ev = _not_equal( + src1=s[:-1], + src2=s[1:], + dst=unique_mask[1:], + sycl_queue=exec_q, + depends=[sort_ev], + ) + _manager.add_event_pair(ht_ev, uneq_ev) + # writing into new allocation, no dependencies + ht_ev, one_ev = _full_usm_ndarray( + fill_value=True, dst=unique_mask[0], sycl_queue=exec_q + ) + _manager.add_event_pair(ht_ev, one_ev) + cumsum = dpt_ext.empty(s.shape, dtype=dpt.int64, sycl_queue=exec_q) + # synchronizing call + n_uniques = mask_positions( + unique_mask, cumsum, sycl_queue=exec_q, depends=[one_ev, uneq_ev] + ) + if n_uniques == fx.size: + return s + unique_vals = dpt_ext.empty( + n_uniques, dtype=x.dtype, usm_type=x.usm_type, sycl_queue=exec_q + ) + ht_ev, ex_e = _extract( + src=s, + cumsum=cumsum, + axis_start=0, + axis_end=1, + dst=unique_vals, + sycl_queue=exec_q, + ) + _manager.add_event_pair(ht_ev, ex_e) + return unique_vals + + +def unique_counts(x: dpt.usm_ndarray) -> UniqueCountsResult: + """unique_counts(x) + + Returns the unique elements of an input array `x` and the corresponding + counts for each unique element in `x`. + + Args: + x (usm_ndarray): + input array. Inputs with more than one dimension are flattened. + Returns: + tuple[usm_ndarray, usm_ndarray] + a namedtuple `(values, counts)` whose + + * first element is the field name `values` and is an array + containing the unique elements of `x`. This array has the + same data type as `x`. + * second element has the field name `counts` and is an array + containing the number of times each unique element occurs in `x`. + This array has the same shape as `values` and has the default + array index data type. 
+ """ + if not isinstance(x, dpt.usm_ndarray): + raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x)}") + array_api_dev = x.device + exec_q = array_api_dev.sycl_queue + x_usm_type = x.usm_type + if x.ndim == 1: + fx = x + else: + fx = dpt_ext.reshape(x, (x.size,), order="C") + ind_dt = default_device_index_type(exec_q) + if fx.size == 0: + return UniqueCountsResult(fx, dpt_ext.empty_like(fx, dtype=ind_dt)) + s = dpt_ext.empty_like(fx, order="C") + + _manager = du.SequentialOrderManager[exec_q] + dep_evs = _manager.submitted_events + if fx.flags.c_contiguous: + ht_ev, sort_ev = _sort_ascending( + src=fx, + trailing_dims_to_sort=1, + dst=s, + sycl_queue=exec_q, + depends=dep_evs, + ) + _manager.add_event_pair(ht_ev, sort_ev) + else: + tmp = dpt_ext.empty_like(fx, order="C") + ht_ev, copy_ev = _copy_usm_ndarray_into_usm_ndarray( + src=fx, dst=tmp, sycl_queue=exec_q, depends=dep_evs + ) + _manager.add_event_pair(ht_ev, copy_ev) + ht_ev, sort_ev = _sort_ascending( + src=tmp, + dst=s, + trailing_dims_to_sort=1, + sycl_queue=exec_q, + depends=[copy_ev], + ) + _manager.add_event_pair(ht_ev, sort_ev) + unique_mask = dpt_ext.empty(s.shape, dtype="?", sycl_queue=exec_q) + ht_ev, uneq_ev = _not_equal( + src1=s[:-1], + src2=s[1:], + dst=unique_mask[1:], + sycl_queue=exec_q, + depends=[sort_ev], + ) + _manager.add_event_pair(ht_ev, uneq_ev) + # no dependency, since we write into new allocation + ht_ev, one_ev = _full_usm_ndarray( + fill_value=True, dst=unique_mask[0], sycl_queue=exec_q + ) + _manager.add_event_pair(ht_ev, one_ev) + cumsum = dpt_ext.empty( + unique_mask.shape, dtype=dpt.int64, sycl_queue=exec_q + ) + # synchronizing call + n_uniques = mask_positions( + unique_mask, cumsum, sycl_queue=exec_q, depends=[one_ev, uneq_ev] + ) + if n_uniques == fx.size: + return UniqueCountsResult( + s, + dpt_ext.ones( + n_uniques, dtype=ind_dt, usm_type=x_usm_type, sycl_queue=exec_q + ), + ) + unique_vals = dpt_ext.empty( + n_uniques, dtype=x.dtype, usm_type=x_usm_type, sycl_queue=exec_q + ) + # populate unique values + ht_ev, ex_e = _extract( + src=s, + cumsum=cumsum, + axis_start=0, + axis_end=1, + dst=unique_vals, + sycl_queue=exec_q, + ) + _manager.add_event_pair(ht_ev, ex_e) + unique_counts = dpt_ext.empty( + n_uniques + 1, dtype=ind_dt, usm_type=x_usm_type, sycl_queue=exec_q + ) + idx = dpt_ext.empty(x.size, dtype=ind_dt, sycl_queue=exec_q) + # writing into new allocation, no dependency + ht_ev, id_ev = _linspace_step(start=0, dt=1, dst=idx, sycl_queue=exec_q) + _manager.add_event_pair(ht_ev, id_ev) + ht_ev, extr_ev = _extract( + src=idx, + cumsum=cumsum, + axis_start=0, + axis_end=1, + dst=unique_counts[:-1], + sycl_queue=exec_q, + depends=[id_ev], + ) + _manager.add_event_pair(ht_ev, extr_ev) + # no dependency, writing into disjoint segmenent of new allocation + ht_ev, set_ev = _full_usm_ndarray( + x.size, dst=unique_counts[-1], sycl_queue=exec_q + ) + _manager.add_event_pair(ht_ev, set_ev) + _counts = dpt_ext.empty_like(unique_counts[1:]) + ht_ev, sub_ev = _subtract( + src1=unique_counts[1:], + src2=unique_counts[:-1], + dst=_counts, + sycl_queue=exec_q, + depends=[set_ev, extr_ev], + ) + _manager.add_event_pair(ht_ev, sub_ev) + return UniqueCountsResult(unique_vals, _counts) + + +def unique_inverse(x): + """unique_inverse + + Returns the unique elements of an input array x and the indices from the + set of unique elements that reconstruct `x`. + + Args: + x (usm_ndarray): + input array. Inputs with more than one dimension are flattened. 
+ Returns: + tuple[usm_ndarray, usm_ndarray] + a namedtuple `(values, inverse_indices)` whose + + * first element has the field name `values` and is an array + containing the unique elements of `x`. The array has the same + data type as `x`. + * second element has the field name `inverse_indices` and is an + array containing the indices of values that reconstruct `x`. + The array has the same shape as `x` and has the default array + index data type. + """ + if not isinstance(x, dpt.usm_ndarray): + raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x)}") + array_api_dev = x.device + exec_q = array_api_dev.sycl_queue + x_usm_type = x.usm_type + ind_dt = default_device_index_type(exec_q) + if x.ndim == 1: + fx = x + else: + fx = dpt_ext.reshape(x, (x.size,), order="C") + sorting_ids = dpt_ext.empty_like(fx, dtype=ind_dt, order="C") + unsorting_ids = dpt_ext.empty_like(sorting_ids, dtype=ind_dt, order="C") + if fx.size == 0: + return UniqueInverseResult(fx, dpt_ext.reshape(unsorting_ids, x.shape)) + + _manager = du.SequentialOrderManager[exec_q] + dep_evs = _manager.submitted_events + if fx.flags.c_contiguous: + ht_ev, sort_ev = _argsort_ascending( + src=fx, + trailing_dims_to_sort=1, + dst=sorting_ids, + sycl_queue=exec_q, + depends=dep_evs, + ) + _manager.add_event_pair(ht_ev, sort_ev) + else: + tmp = dpt_ext.empty_like(fx, order="C") + ht_ev, copy_ev = _copy_usm_ndarray_into_usm_ndarray( + src=fx, dst=tmp, sycl_queue=exec_q, depends=dep_evs + ) + _manager.add_event_pair(ht_ev, copy_ev) + ht_ev, sort_ev = _argsort_ascending( + src=tmp, + trailing_dims_to_sort=1, + dst=sorting_ids, + sycl_queue=exec_q, + depends=[copy_ev], + ) + _manager.add_event_pair(ht_ev, sort_ev) + ht_ev, argsort_ev = _argsort_ascending( + src=sorting_ids, + trailing_dims_to_sort=1, + dst=unsorting_ids, + sycl_queue=exec_q, + depends=[sort_ev], + ) + _manager.add_event_pair(ht_ev, argsort_ev) + s = dpt_ext.empty_like(fx) + # s = fx[sorting_ids] + ht_ev, take_ev = _take( + src=fx, + ind=(sorting_ids,), + dst=s, + axis_start=0, + mode=0, + sycl_queue=exec_q, + depends=[sort_ev], + ) + _manager.add_event_pair(ht_ev, take_ev) + unique_mask = dpt_ext.empty(fx.shape, dtype="?", sycl_queue=exec_q) + ht_ev, uneq_ev = _not_equal( + src1=s[:-1], + src2=s[1:], + dst=unique_mask[1:], + sycl_queue=exec_q, + depends=[take_ev], + ) + _manager.add_event_pair(ht_ev, uneq_ev) + # no dependency + ht_ev, one_ev = _full_usm_ndarray( + fill_value=True, dst=unique_mask[0], sycl_queue=exec_q + ) + _manager.add_event_pair(ht_ev, one_ev) + cumsum = dpt_ext.empty( + unique_mask.shape, dtype=dpt.int64, sycl_queue=exec_q + ) + # synchronizing call + n_uniques = mask_positions( + unique_mask, cumsum, sycl_queue=exec_q, depends=[uneq_ev, one_ev] + ) + if n_uniques == fx.size: + return UniqueInverseResult(s, dpt_ext.reshape(unsorting_ids, x.shape)) + unique_vals = dpt_ext.empty( + n_uniques, dtype=x.dtype, usm_type=x_usm_type, sycl_queue=exec_q + ) + ht_ev, uv_ev = _extract( + src=s, + cumsum=cumsum, + axis_start=0, + axis_end=1, + dst=unique_vals, + sycl_queue=exec_q, + ) + _manager.add_event_pair(ht_ev, uv_ev) + cum_unique_counts = dpt_ext.empty( + n_uniques + 1, dtype=ind_dt, usm_type=x_usm_type, sycl_queue=exec_q + ) + idx = dpt_ext.empty(x.size, dtype=ind_dt, sycl_queue=exec_q) + ht_ev, id_ev = _linspace_step(start=0, dt=1, dst=idx, sycl_queue=exec_q) + _manager.add_event_pair(ht_ev, id_ev) + ht_ev, extr_ev = _extract( + src=idx, + cumsum=cumsum, + axis_start=0, + axis_end=1, + dst=cum_unique_counts[:-1], + sycl_queue=exec_q, + 
depends=[id_ev],
+    )
+    _manager.add_event_pair(ht_ev, extr_ev)
+    ht_ev, set_ev = _full_usm_ndarray(
+        x.size, dst=cum_unique_counts[-1], sycl_queue=exec_q
+    )
+    _manager.add_event_pair(ht_ev, set_ev)
+    _counts = dpt_ext.empty_like(cum_unique_counts[1:])
+    ht_ev, sub_ev = _subtract(
+        src1=cum_unique_counts[1:],
+        src2=cum_unique_counts[:-1],
+        dst=_counts,
+        sycl_queue=exec_q,
+        depends=[set_ev, extr_ev],
+    )
+    _manager.add_event_pair(ht_ev, sub_ev)
+
+    inv = dpt_ext.empty_like(x, dtype=ind_dt, order="C")
+    ht_ev, ssl_ev = _searchsorted_left(
+        hay=unique_vals,
+        needles=x,
+        positions=inv,
+        sycl_queue=exec_q,
+        depends=[
+            uv_ev,
+        ],
+    )
+    _manager.add_event_pair(ht_ev, ssl_ev)
+
+    return UniqueInverseResult(unique_vals, inv)
+
+
+def unique_all(x: dpt.usm_ndarray) -> UniqueAllResult:
+    """unique_all(x)
+
+    Returns the unique elements of an input array `x`, the first occurring
+    indices for each unique element in `x`, the indices from the set of unique
+    elements that reconstruct `x`, and the corresponding counts for each
+    unique element in `x`.
+
+    Args:
+        x (usm_ndarray):
+            input array. Inputs with more than one dimension are flattened.
+    Returns:
+        tuple[usm_ndarray, usm_ndarray, usm_ndarray, usm_ndarray]
+            a namedtuple `(values, indices, inverse_indices, counts)` whose
+
+            * first element has the field name `values` and is an array
+              containing the unique elements of `x`. The array has the same
+              data type as `x`.
+            * second element has the field name `indices` and is an array
+              containing the indices (of first occurrences) of `x` that
+              result in `values`. The array has the same shape as `values`
+              and has the default array index data type.
+            * third element has the field name `inverse_indices` and is an
+              array containing the indices of values that reconstruct `x`.
+              The array has the same shape as `x` and has the default array
+              index data type.
+            * fourth element has the field name `counts` and is an array
+              containing the number of times each unique element occurs in `x`.
+              This array has the same shape as `values` and has the default
+              array index data type.
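+
+    Example:
+        An illustrative sketch, assuming a default-selected SYCL device is
+        available (results shown as comments)::
+
+            import dpctl.tensor as dpt
+            import dpctl_ext.tensor as dpt_ext
+
+            x = dpt.asarray([2, 1, 2, 3])
+            r = dpt_ext.unique_all(x)
+            # r.values          -> usm_ndarray([1, 2, 3])
+            # r.indices         -> usm_ndarray([1, 0, 3])
+            # r.inverse_indices -> usm_ndarray([1, 0, 1, 2])
+            # r.counts          -> usm_ndarray([1, 2, 1])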
+ """ + if not isinstance(x, dpt.usm_ndarray): + raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x)}") + array_api_dev = x.device + exec_q = array_api_dev.sycl_queue + x_usm_type = x.usm_type + ind_dt = default_device_index_type(exec_q) + if x.ndim == 1: + fx = x + else: + fx = dpt_ext.reshape(x, (x.size,), order="C") + sorting_ids = dpt_ext.empty_like(fx, dtype=ind_dt, order="C") + unsorting_ids = dpt_ext.empty_like(sorting_ids, dtype=ind_dt, order="C") + if fx.size == 0: + # original array contains no data + # so it can be safely returned as values + return UniqueAllResult( + fx, + sorting_ids, + dpt_ext.reshape(unsorting_ids, x.shape), + dpt_ext.empty_like(fx, dtype=ind_dt), + ) + _manager = du.SequentialOrderManager[exec_q] + dep_evs = _manager.submitted_events + if fx.flags.c_contiguous: + ht_ev, sort_ev = _argsort_ascending( + src=fx, + trailing_dims_to_sort=1, + dst=sorting_ids, + sycl_queue=exec_q, + depends=dep_evs, + ) + _manager.add_event_pair(ht_ev, sort_ev) + else: + tmp = dpt_ext.empty_like(fx, order="C") + ht_ev, copy_ev = _copy_usm_ndarray_into_usm_ndarray( + src=fx, dst=tmp, sycl_queue=exec_q, depends=dep_evs + ) + _manager.add_event_pair(ht_ev, copy_ev) + ht_ev, sort_ev = _argsort_ascending( + src=tmp, + trailing_dims_to_sort=1, + dst=sorting_ids, + sycl_queue=exec_q, + depends=[copy_ev], + ) + _manager.add_event_pair(ht_ev, sort_ev) + ht_ev, args_ev = _argsort_ascending( + src=sorting_ids, + trailing_dims_to_sort=1, + dst=unsorting_ids, + sycl_queue=exec_q, + depends=[sort_ev], + ) + _manager.add_event_pair(ht_ev, args_ev) + s = dpt_ext.empty_like(fx) + # s = fx[sorting_ids] + ht_ev, take_ev = _take( + src=fx, + ind=(sorting_ids,), + dst=s, + axis_start=0, + mode=0, + sycl_queue=exec_q, + depends=[sort_ev], + ) + _manager.add_event_pair(ht_ev, take_ev) + unique_mask = dpt_ext.empty(fx.shape, dtype="?", sycl_queue=exec_q) + ht_ev, uneq_ev = _not_equal( + src1=s[:-1], + src2=s[1:], + dst=unique_mask[1:], + sycl_queue=exec_q, + depends=[take_ev], + ) + _manager.add_event_pair(ht_ev, uneq_ev) + ht_ev, one_ev = _full_usm_ndarray( + fill_value=True, dst=unique_mask[0], sycl_queue=exec_q + ) + _manager.add_event_pair(ht_ev, one_ev) + cumsum = dpt_ext.empty( + unique_mask.shape, dtype=dpt.int64, sycl_queue=exec_q + ) + # synchronizing call + n_uniques = mask_positions( + unique_mask, cumsum, sycl_queue=exec_q, depends=[uneq_ev, one_ev] + ) + if n_uniques == fx.size: + _counts = dpt_ext.ones( + n_uniques, dtype=ind_dt, usm_type=x_usm_type, sycl_queue=exec_q + ) + return UniqueAllResult( + s, + sorting_ids, + dpt_ext.reshape(unsorting_ids, x.shape), + _counts, + ) + unique_vals = dpt_ext.empty( + n_uniques, dtype=x.dtype, usm_type=x_usm_type, sycl_queue=exec_q + ) + ht_ev, uv_ev = _extract( + src=s, + cumsum=cumsum, + axis_start=0, + axis_end=1, + dst=unique_vals, + sycl_queue=exec_q, + ) + _manager.add_event_pair(ht_ev, uv_ev) + cum_unique_counts = dpt_ext.empty( + n_uniques + 1, dtype=ind_dt, usm_type=x_usm_type, sycl_queue=exec_q + ) + idx = dpt_ext.empty(x.size, dtype=ind_dt, sycl_queue=exec_q) + ht_ev, id_ev = _linspace_step(start=0, dt=1, dst=idx, sycl_queue=exec_q) + _manager.add_event_pair(ht_ev, id_ev) + ht_ev, extr_ev = _extract( + src=idx, + cumsum=cumsum, + axis_start=0, + axis_end=1, + dst=cum_unique_counts[:-1], + sycl_queue=exec_q, + depends=[id_ev], + ) + _manager.add_event_pair(ht_ev, extr_ev) + ht_ev, set_ev = _full_usm_ndarray( + x.size, dst=cum_unique_counts[-1], sycl_queue=exec_q + ) + _manager.add_event_pair(ht_ev, set_ev) + _counts = 
dpt_ext.empty_like(cum_unique_counts[1:]) + ht_ev, sub_ev = _subtract( + src1=cum_unique_counts[1:], + src2=cum_unique_counts[:-1], + dst=_counts, + sycl_queue=exec_q, + depends=[set_ev, extr_ev], + ) + _manager.add_event_pair(ht_ev, sub_ev) + + inv = dpt_ext.empty_like(x, dtype=ind_dt, order="C") + ht_ev, ssl_ev = _searchsorted_left( + hay=unique_vals, + needles=x, + positions=inv, + sycl_queue=exec_q, + depends=[ + uv_ev, + ], + ) + _manager.add_event_pair(ht_ev, ssl_ev) + return UniqueAllResult( + unique_vals, + sorting_ids[cum_unique_counts[:-1]], + inv, + _counts, + ) + + +def isin( + x: Union[dpt.usm_ndarray, int, float, complex, bool], + test_elements: Union[dpt.usm_ndarray, int, float, complex, bool], + /, + *, + invert: Optional[bool] = False, +) -> dpt.usm_ndarray: + """isin(x, test_elements, /, *, invert=False) + + Tests `x in test_elements` for each element of `x`. Returns a boolean array + with the same shape as `x` that is `True` where the element is in + `test_elements`, `False` otherwise. + + Args: + x (Union[usm_ndarray, bool, int, float, complex]): + input element or elements. + test_elements (Union[usm_ndarray, bool, int, float, complex]): + elements against which to test each value of `x`. + invert (Optional[bool]): + if `True`, the output results are inverted, i.e., are equivalent to + testing `x not in test_elements` for each element of `x`. + Default: `False`. + + Returns: + usm_ndarray: + an array of the inclusion test results. The returned array has a + boolean data type and the same shape as `x`. + """ + q1, x_usm_type = _get_queue_usm_type(x) + q2, test_usm_type = _get_queue_usm_type(test_elements) + if q1 is None and q2 is None: + raise du.ExecutionPlacementError( + "Execution placement can not be unambiguously inferred " + "from input arguments. " + "One of the arguments must represent USM allocation and " + "expose `__sycl_usm_array_interface__` property" + ) + if q1 is None: + exec_q = q2 + res_usm_type = test_usm_type + elif q2 is None: + exec_q = q1 + res_usm_type = x_usm_type + else: + exec_q = du.get_execution_queue((q1, q2)) + if exec_q is None: + raise du.ExecutionPlacementError( + "Execution placement can not be unambiguously inferred " + "from input arguments." 
+ ) + res_usm_type = du.get_coerced_usm_type( + ( + x_usm_type, + test_usm_type, + ) + ) + du.validate_usm_type(res_usm_type, allow_none=False) + sycl_dev = exec_q.sycl_device + + if not isinstance(invert, bool): + raise TypeError( + "`invert` keyword argument must be of boolean type, " + f"got {type(invert)}" + ) + + x_dt = _get_dtype(x, sycl_dev) + test_dt = _get_dtype(test_elements, sycl_dev) + if not all(_validate_dtype(dt) for dt in (x_dt, test_dt)): + raise ValueError("Operands have unsupported data types") + + x_sh = _get_shape(x) + if isinstance(test_elements, dpt.usm_ndarray) and test_elements.size == 0: + if invert: + return dpt_ext.ones( + x_sh, dtype=dpt.bool, usm_type=res_usm_type, sycl_queue=exec_q + ) + else: + return dpt_ext.zeros( + x_sh, dtype=dpt.bool, usm_type=res_usm_type, sycl_queue=exec_q + ) + + dt1, dt2 = _resolve_weak_types_all_py_ints(x_dt, test_dt, sycl_dev) + dt = _to_device_supported_dtype(dpt_ext.result_type(dt1, dt2), sycl_dev) + + if not isinstance(x, dpt.usm_ndarray): + x_arr = dpt_ext.asarray( + x, dtype=dt1, usm_type=res_usm_type, sycl_queue=exec_q + ) + else: + x_arr = x + + if not isinstance(test_elements, dpt.usm_ndarray): + test_arr = dpt_ext.asarray( + test_elements, dtype=dt2, usm_type=res_usm_type, sycl_queue=exec_q + ) + else: + test_arr = test_elements + + _manager = du.SequentialOrderManager[exec_q] + dep_evs = _manager.submitted_events + + if x_dt != dt: + x_buf = _empty_like_orderK(x_arr, dt, res_usm_type, exec_q) + ht_ev, ev = _copy_usm_ndarray_into_usm_ndarray( + src=x_arr, dst=x_buf, sycl_queue=exec_q, depends=dep_evs + ) + _manager.add_event_pair(ht_ev, ev) + else: + x_buf = x_arr + + if test_dt != dt: + # copy into C-contiguous memory, because the array will be flattened + test_buf = dpt_ext.empty_like( + test_arr, dtype=dt, order="C", usm_type=res_usm_type + ) + ht_ev, ev = _copy_usm_ndarray_into_usm_ndarray( + src=test_arr, dst=test_buf, sycl_queue=exec_q, depends=dep_evs + ) + _manager.add_event_pair(ht_ev, ev) + else: + test_buf = test_arr + + test_buf = dpt_ext.reshape(test_buf, -1) + test_buf = dpt_ext.sort(test_buf) + + dst = dpt_ext.empty_like( + x_buf, dtype=dpt.bool, usm_type=res_usm_type, order="C" + ) + + dep_evs = _manager.submitted_events + ht_ev, s_ev = _isin( + needles=x_buf, + hay=test_buf, + dst=dst, + sycl_queue=exec_q, + invert=invert, + depends=dep_evs, + ) + _manager.add_event_pair(ht_ev, s_ev) + return dst diff --git a/dpctl_ext/tensor/_sorting.py b/dpctl_ext/tensor/_sorting.py new file mode 100644 index 000000000000..24693a408889 --- /dev/null +++ b/dpctl_ext/tensor/_sorting.py @@ -0,0 +1,450 @@ +# ***************************************************************************** +# Copyright (c) 2026, Intel Corporation +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# - Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# - Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# - Neither the name of the copyright holder nor the names of its contributors +# may be used to endorse or promote products derived from this software +# without specific prior written permission. 
+# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +# THE POSSIBILITY OF SUCH DAMAGE. +# ***************************************************************************** + +import operator +from typing import NamedTuple + +import dpctl.tensor as dpt +import dpctl.utils as du + +# TODO: revert to `import dpctl.tensor...` +# when dpnp fully migrates dpctl/tensor +import dpctl_ext.tensor as dpt_ext +import dpctl_ext.tensor._tensor_impl as ti + +from ._numpy_helper import normalize_axis_index +from ._tensor_sorting_impl import ( + _argsort_ascending, + _argsort_descending, + _radix_argsort_ascending, + _radix_argsort_descending, + _radix_sort_ascending, + _radix_sort_descending, + _radix_sort_dtype_supported, + _sort_ascending, + _sort_descending, + _topk, +) + +__all__ = ["sort", "argsort", "top_k"] + + +def _get_mergesort_impl_fn(descending): + return _sort_descending if descending else _sort_ascending + + +def _get_radixsort_impl_fn(descending): + return _radix_sort_descending if descending else _radix_sort_ascending + + +def sort(x, /, *, axis=-1, descending=False, stable=True, kind=None): + """sort(x, axis=-1, descending=False, stable=True) + + Returns a sorted copy of an input array `x`. + + Args: + x (usm_ndarray): + input array. + axis (Optional[int]): + axis along which to sort. If set to `-1`, the function + must sort along the last axis. Default: `-1`. + descending (Optional[bool]): + sort order. If `True`, the array must be sorted in descending + order (by value). If `False`, the array must be sorted in + ascending order (by value). Default: `False`. + stable (Optional[bool]): + sort stability. If `True`, the returned array must maintain the + relative order of `x` values which compare as equal. If `False`, + the returned array may or may not maintain the relative order of + `x` values which compare as equal. Default: `True`. + kind (Optional[Literal["stable", "mergesort", "radixsort"]]): + Sorting algorithm. The default is `"stable"`, which uses parallel + merge-sort or parallel radix-sort algorithms depending on the + array data type. + Returns: + usm_ndarray: + a sorted array. The returned array has the same data type and + the same shape as the input array `x`. 
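+
+    Example:
+        An illustrative sketch, assuming a default-selected SYCL device is
+        available (results shown as comments)::
+
+            import dpctl.tensor as dpt
+            import dpctl_ext.tensor as dpt_ext
+
+            x = dpt.asarray([3, 1, 2])
+            dpt_ext.sort(x)
+            # -> usm_ndarray([1, 2, 3])
+            dpt_ext.sort(x, descending=True)
+            # -> usm_ndarray([3, 2, 1])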
+ """ + if not isinstance(x, dpt.usm_ndarray): + raise TypeError( + f"Expected type dpctl.tensor.usm_ndarray, got {type(x)}" + ) + nd = x.ndim + if nd == 0: + axis = normalize_axis_index(axis, ndim=1, msg_prefix="axis") + return dpt_ext.copy(x, order="C") + else: + axis = normalize_axis_index(axis, ndim=nd, msg_prefix="axis") + a1 = axis + 1 + if a1 == nd: + perm = list(range(nd)) + arr = x + else: + perm = [i for i in range(nd) if i != axis] + [ + axis, + ] + arr = dpt_ext.permute_dims(x, perm) + if kind is None: + kind = "stable" + if not isinstance(kind, str) or kind not in [ + "stable", + "radixsort", + "mergesort", + ]: + raise ValueError( + "Unsupported kind value. Expected 'stable', 'mergesort', " + f"or 'radixsort', but got '{kind}'" + ) + if kind == "mergesort": + impl_fn = _get_mergesort_impl_fn(descending) + elif kind == "radixsort": + if _radix_sort_dtype_supported(x.dtype.num): + impl_fn = _get_radixsort_impl_fn(descending) + else: + raise ValueError(f"Radix sort is not supported for {x.dtype}") + else: + dt = x.dtype + if dt in [dpt.bool, dpt.uint8, dpt.int8, dpt.int16, dpt.uint16]: + impl_fn = _get_radixsort_impl_fn(descending) + else: + impl_fn = _get_mergesort_impl_fn(descending) + exec_q = x.sycl_queue + _manager = du.SequentialOrderManager[exec_q] + dep_evs = _manager.submitted_events + if arr.flags.c_contiguous: + res = dpt_ext.empty_like(arr, order="C") + ht_ev, impl_ev = impl_fn( + src=arr, + trailing_dims_to_sort=1, + dst=res, + sycl_queue=exec_q, + depends=dep_evs, + ) + _manager.add_event_pair(ht_ev, impl_ev) + else: + tmp = dpt_ext.empty_like(arr, order="C") + ht_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( + src=arr, dst=tmp, sycl_queue=exec_q, depends=dep_evs + ) + _manager.add_event_pair(ht_ev, copy_ev) + res = dpt_ext.empty_like(arr, order="C") + ht_ev, impl_ev = impl_fn( + src=tmp, + trailing_dims_to_sort=1, + dst=res, + sycl_queue=exec_q, + depends=[copy_ev], + ) + _manager.add_event_pair(ht_ev, impl_ev) + if a1 != nd: + inv_perm = sorted(range(nd), key=lambda d: perm[d]) + res = dpt_ext.permute_dims(res, inv_perm) + return res + + +def _get_mergeargsort_impl_fn(descending): + return _argsort_descending if descending else _argsort_ascending + + +def _get_radixargsort_impl_fn(descending): + return _radix_argsort_descending if descending else _radix_argsort_ascending + + +def argsort(x, axis=-1, descending=False, stable=True, kind=None): + """argsort(x, axis=-1, descending=False, stable=True) + + Returns the indices that sort an array `x` along a specified axis. + + Args: + x (usm_ndarray): + input array. + axis (Optional[int]): + axis along which to sort. If set to `-1`, the function + must sort along the last axis. Default: `-1`. + descending (Optional[bool]): + sort order. If `True`, the array must be sorted in descending + order (by value). If `False`, the array must be sorted in + ascending order (by value). Default: `False`. + stable (Optional[bool]): + sort stability. If `True`, the returned array must maintain the + relative order of `x` values which compare as equal. If `False`, + the returned array may or may not maintain the relative order of + `x` values which compare as equal. Default: `True`. + kind (Optional[Literal["stable", "mergesort", "radixsort"]]): + Sorting algorithm. The default is `"stable"`, which uses parallel + merge-sort or parallel radix-sort algorithms depending on the + array data type. + + Returns: + usm_ndarray: + an array of indices. The returned array has the same shape as + the input array `x`. 
The returned array has the default
+            array index data type.
+    """
+    if not isinstance(x, dpt.usm_ndarray):
+        raise TypeError(
+            f"Expected type dpctl.tensor.usm_ndarray, got {type(x)}"
+        )
+    nd = x.ndim
+    if nd == 0:
+        axis = normalize_axis_index(axis, ndim=1, msg_prefix="axis")
+        return dpt_ext.zeros_like(
+            x, dtype=ti.default_device_index_type(x.sycl_queue), order="C"
+        )
+    else:
+        axis = normalize_axis_index(axis, ndim=nd, msg_prefix="axis")
+    a1 = axis + 1
+    if a1 == nd:
+        perm = list(range(nd))
+        arr = x
+    else:
+        perm = [i for i in range(nd) if i != axis] + [
+            axis,
+        ]
+        arr = dpt_ext.permute_dims(x, perm)
+    if kind is None:
+        kind = "stable"
+    if not isinstance(kind, str) or kind not in [
+        "stable",
+        "radixsort",
+        "mergesort",
+    ]:
+        raise ValueError(
+            "Unsupported kind value. Expected 'stable', 'mergesort', "
+            f"or 'radixsort', but got '{kind}'"
+        )
+    if kind == "mergesort":
+        impl_fn = _get_mergeargsort_impl_fn(descending)
+    elif kind == "radixsort":
+        if _radix_sort_dtype_supported(x.dtype.num):
+            impl_fn = _get_radixargsort_impl_fn(descending)
+        else:
+            raise ValueError(f"Radix sort is not supported for {x.dtype}")
+    else:
+        dt = x.dtype
+        if dt in [dpt.bool, dpt.uint8, dpt.int8, dpt.int16, dpt.uint16]:
+            impl_fn = _get_radixargsort_impl_fn(descending)
+        else:
+            impl_fn = _get_mergeargsort_impl_fn(descending)
+    exec_q = x.sycl_queue
+    _manager = du.SequentialOrderManager[exec_q]
+    dep_evs = _manager.submitted_events
+    index_dt = ti.default_device_index_type(exec_q)
+    if arr.flags.c_contiguous:
+        res = dpt_ext.empty_like(arr, dtype=index_dt, order="C")
+        ht_ev, impl_ev = impl_fn(
+            src=arr,
+            trailing_dims_to_sort=1,
+            dst=res,
+            sycl_queue=exec_q,
+            depends=dep_evs,
+        )
+        _manager.add_event_pair(ht_ev, impl_ev)
+    else:
+        tmp = dpt_ext.empty_like(arr, order="C")
+        ht_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
+            src=arr, dst=tmp, sycl_queue=exec_q, depends=dep_evs
+        )
+        _manager.add_event_pair(ht_ev, copy_ev)
+        res = dpt_ext.empty_like(arr, dtype=index_dt, order="C")
+        ht_ev, impl_ev = impl_fn(
+            src=tmp,
+            trailing_dims_to_sort=1,
+            dst=res,
+            sycl_queue=exec_q,
+            depends=[copy_ev],
+        )
+        _manager.add_event_pair(ht_ev, impl_ev)
+    if a1 != nd:
+        inv_perm = sorted(range(nd), key=lambda d: perm[d])
+        res = dpt_ext.permute_dims(res, inv_perm)
+    return res
+
+
+def _get_top_k_largest(mode):
+    modes = {"largest": True, "smallest": False}
+    try:
+        return modes[mode]
+    except KeyError:
+        raise ValueError(
+            f"`mode` must be `largest` or `smallest`. Got `{mode}`."
+        )
+
+
+class TopKResult(NamedTuple):
+    values: dpt.usm_ndarray
+    indices: dpt.usm_ndarray
+
+
+def top_k(x, k, /, *, axis=None, mode="largest"):
+    """top_k(x, k, axis=None, mode="largest")
+
+    Returns the `k` largest or smallest values and their indices in the input
+    array `x` along the specified axis `axis`.
+
+    Args:
+        x (usm_ndarray):
+            input array.
+        k (int):
+            number of elements to find. Must be a positive integer value.
+        axis (Optional[int]):
+            axis along which to search. If `None`, the search will be performed
+            over the flattened array. Default: `None`.
+        mode (Literal["largest", "smallest"]):
+            search mode. Must be one of the following modes:
+
+            - `"largest"`: return the `k` largest elements.
+            - `"smallest"`: return the `k` smallest elements.
+
+            Default: `"largest"`.
+
+    Returns:
+        tuple[usm_ndarray, usm_ndarray]
+            a namedtuple `(values, indices)` whose
+
+            * first element `values` will be an array containing the `k`
+              largest or smallest elements of `x`. The array has the same data
+              type as `x`. 
If `axis` was `None`, `values` will be a + one-dimensional array with shape `(k,)` and otherwise, `values` + will have shape `x.shape[:axis] + (k,) + x.shape[axis+1:]` + * second element `indices` will be an array containing indices of + `x` that result in `values`. The array will have the same shape + as `values` and will have the default array index data type. + """ + largest = _get_top_k_largest(mode) + if not isinstance(x, dpt.usm_ndarray): + raise TypeError( + f"Expected type dpctl.tensor.usm_ndarray, got {type(x)}" + ) + + k = operator.index(k) + if k < 0: + raise ValueError("`k` must be a positive integer value") + + nd = x.ndim + if axis is None: + sz = x.size + if nd == 0: + if k > 1: + raise ValueError(f"`k`={k} is out of bounds 1") + return TopKResult( + dpt_ext.copy(x, order="C"), + dpt_ext.zeros_like( + x, dtype=ti.default_device_index_type(x.sycl_queue) + ), + ) + arr = x + n_search_dims = None + res_sh = k + else: + axis = normalize_axis_index(axis, ndim=nd, msg_prefix="axis") + sz = x.shape[axis] + a1 = axis + 1 + if a1 == nd: + perm = list(range(nd)) + arr = x + else: + perm = [i for i in range(nd) if i != axis] + [ + axis, + ] + arr = dpt_ext.permute_dims(x, perm) + n_search_dims = 1 + res_sh = arr.shape[: nd - 1] + (k,) + + if k > sz: + raise ValueError(f"`k`={k} is out of bounds {sz}") + + exec_q = x.sycl_queue + _manager = du.SequentialOrderManager[exec_q] + dep_evs = _manager.submitted_events + + res_usm_type = arr.usm_type + if arr.flags.c_contiguous: + vals = dpt_ext.empty( + res_sh, + dtype=arr.dtype, + usm_type=res_usm_type, + order="C", + sycl_queue=exec_q, + ) + inds = dpt_ext.empty( + res_sh, + dtype=ti.default_device_index_type(exec_q), + usm_type=res_usm_type, + order="C", + sycl_queue=exec_q, + ) + ht_ev, impl_ev = _topk( + src=arr, + trailing_dims_to_search=n_search_dims, + k=k, + largest=largest, + vals=vals, + inds=inds, + sycl_queue=exec_q, + depends=dep_evs, + ) + _manager.add_event_pair(ht_ev, impl_ev) + else: + tmp = dpt_ext.empty_like(arr, order="C") + ht_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( + src=arr, dst=tmp, sycl_queue=exec_q, depends=dep_evs + ) + _manager.add_event_pair(ht_ev, copy_ev) + vals = dpt_ext.empty( + res_sh, + dtype=arr.dtype, + usm_type=res_usm_type, + order="C", + sycl_queue=exec_q, + ) + inds = dpt_ext.empty( + res_sh, + dtype=ti.default_device_index_type(exec_q), + usm_type=res_usm_type, + order="C", + sycl_queue=exec_q, + ) + ht_ev, impl_ev = _topk( + src=tmp, + trailing_dims_to_search=n_search_dims, + k=k, + largest=largest, + vals=vals, + inds=inds, + sycl_queue=exec_q, + depends=[copy_ev], + ) + _manager.add_event_pair(ht_ev, impl_ev) + if axis is not None and a1 != nd: + inv_perm = sorted(range(nd), key=lambda d: perm[d]) + vals = dpt_ext.permute_dims(vals, inv_perm) + inds = dpt_ext.permute_dims(inds, inv_perm) + + return TopKResult(vals, inds) diff --git a/dpctl_ext/tensor/libtensor/include/kernels/sorting/isin.hpp b/dpctl_ext/tensor/libtensor/include/kernels/sorting/isin.hpp new file mode 100644 index 000000000000..847fa96ecdff --- /dev/null +++ b/dpctl_ext/tensor/libtensor/include/kernels/sorting/isin.hpp @@ -0,0 +1,245 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. 
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are met:
+// - Redistributions of source code must retain the above copyright notice,
+// this list of conditions and the following disclaimer.
+// - Redistributions in binary form must reproduce the above copyright notice,
+// this list of conditions and the following disclaimer in the documentation
+// and/or other materials provided with the distribution.
+// - Neither the name of the copyright holder nor the names of its contributors
+// may be used to endorse or promote products derived from this software
+// without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
+// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
+// THE POSSIBILITY OF SUCH DAMAGE.
+//*****************************************************************************
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file
+/// This file defines kernels for tensor membership operations.
+//===----------------------------------------------------------------------===//
+
+#pragma once
+
+#include <cstddef>
+#include <vector>
+
+#include <sycl/sycl.hpp>
+
+#include "kernels/dpctl_tensor_types.hpp"
+#include "kernels/sorting/search_sorted_detail.hpp"
+#include "utils/offset_utils.hpp"
+#include "utils/rich_comparisons.hpp"
+
+namespace dpctl::tensor::kernels
+{
+
+using dpctl::tensor::ssize_t;
+
+template <typename T,
+          typename HayIndexerT,
+          typename NeedlesIndexerT,
+          typename OutIndexerT>
+struct IsinFunctor
+{
+private:
+    bool invert;
+    const T *hay_tp;
+    const T *needles_tp;
+    bool *out_tp;
+    std::size_t hay_nelems;
+    HayIndexerT hay_indexer;
+    NeedlesIndexerT needles_indexer;
+    OutIndexerT out_indexer;
+
+public:
+    IsinFunctor(const bool invert_,
+                const T *hay_,
+                const T *needles_,
+                bool *out_,
+                const std::size_t hay_nelems_,
+                const HayIndexerT &hay_indexer_,
+                const NeedlesIndexerT &needles_indexer_,
+                const OutIndexerT &out_indexer_)
+        : invert(invert_), hay_tp(hay_), needles_tp(needles_), out_tp(out_),
+          hay_nelems(hay_nelems_), hay_indexer(hay_indexer_),
+          needles_indexer(needles_indexer_), out_indexer(out_indexer_)
+    {
+    }
+
+    void operator()(sycl::id<1> id) const
+    {
+        using Compare = typename dpctl::tensor::rich_comparisons::
+            AscendingSorter<T>::type;
+        static constexpr Compare comp{};
+
+        const std::size_t i = id[0];
+        const T needle_v = needles_tp[needles_indexer(i)];
+
+        // position of the needle_v in the hay array
+        std::size_t pos{};
+
+        static constexpr std::size_t zero(0);
+        // search in hay in left-closed interval, give `pos` such that
+        // hay[pos - 1] < needle_v <= hay[pos]
+
+        // lower_bound returns the first pos such that bool(hay[pos] <
+        // needle_v) is false, i.e. needle_v <= hay[pos]
+        pos = search_sorted_detail::lower_bound_indexed_impl(
+            hay_tp, zero, hay_nelems, needle_v, comp, hay_indexer);
+        bool out = (pos == hay_nelems ? 
false : hay_tp[pos] == needle_v);
+        out_tp[out_indexer(i)] = (invert) ? !out : out;
+    }
+};
+
+typedef sycl::event (*isin_contig_impl_fp_ptr_t)(
+    sycl::queue &,
+    const bool,
+    const std::size_t,
+    const std::size_t,
+    const char *,
+    const ssize_t,
+    const char *,
+    const ssize_t,
+    char *,
+    const ssize_t,
+    const std::vector<sycl::event> &);
+
+template <typename T> class isin_contig_impl_krn;
+
+template <typename T>
+sycl::event isin_contig_impl(sycl::queue &exec_q,
+                             const bool invert,
+                             const std::size_t hay_nelems,
+                             const std::size_t needles_nelems,
+                             const char *hay_cp,
+                             const ssize_t hay_offset,
+                             const char *needles_cp,
+                             const ssize_t needles_offset,
+                             char *out_cp,
+                             const ssize_t out_offset,
+                             const std::vector<sycl::event> &depends)
+{
+    const T *hay_tp = reinterpret_cast<const T *>(hay_cp) + hay_offset;
+    const T *needles_tp =
+        reinterpret_cast<const T *>(needles_cp) + needles_offset;
+
+    bool *out_tp = reinterpret_cast<bool *>(out_cp) + out_offset;
+
+    sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
+        cgh.depends_on(depends);
+
+        using KernelName = class isin_contig_impl_krn<T>;
+
+        sycl::range<1> gRange(needles_nelems);
+
+        using TrivialIndexerT = dpctl::tensor::offset_utils::NoOpIndexer;
+
+        static constexpr TrivialIndexerT hay_indexer{};
+        static constexpr TrivialIndexerT needles_indexer{};
+        static constexpr TrivialIndexerT out_indexer{};
+
+        const auto fnctr =
+            IsinFunctor<T, TrivialIndexerT, TrivialIndexerT, TrivialIndexerT>(
+                invert, hay_tp, needles_tp, out_tp, hay_nelems, hay_indexer,
+                needles_indexer, out_indexer);
+
+        cgh.parallel_for<KernelName>(gRange, fnctr);
+    });
+
+    return comp_ev;
+}
+
+typedef sycl::event (*isin_strided_impl_fp_ptr_t)(
+    sycl::queue &,
+    const bool,
+    const std::size_t,
+    const std::size_t,
+    const char *,
+    const ssize_t,
+    const ssize_t,
+    const char *,
+    const ssize_t,
+    char *,
+    const ssize_t,
+    int,
+    const ssize_t *,
+    const std::vector<sycl::event> &);
+
+template <typename T> class isin_strided_impl_krn;
+
+template <typename T>
+sycl::event isin_strided_impl(
+    sycl::queue &exec_q,
+    const bool invert,
+    const std::size_t hay_nelems,
+    const std::size_t needles_nelems,
+    const char *hay_cp,
+    const ssize_t hay_offset,
+    // hay is 1D, so hay_nelems, hay_offset, hay_stride describe strided array
+    const ssize_t hay_stride,
+    const char *needles_cp,
+    const ssize_t needles_offset,
+    char *out_cp,
+    const ssize_t out_offset,
+    const int needles_nd,
+    // packed_shape_strides is [needles_shape, needles_strides,
+    // out_strides] has length of 3*needles_nd
+    const ssize_t *packed_shape_strides,
+    const std::vector<sycl::event> &depends)
+{
+    const T *hay_tp = reinterpret_cast<const T *>(hay_cp);
+    const T *needles_tp = reinterpret_cast<const T *>(needles_cp);
+
+    bool *out_tp = reinterpret_cast<bool *>(out_cp);
+
+    sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
+        cgh.depends_on(depends);
+
+        sycl::range<1> gRange(needles_nelems);
+
+        using HayIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer;
+        const HayIndexerT hay_indexer(
+            /* offset */ hay_offset,
+            /* size */ hay_nelems,
+            /* step */ hay_stride);
+
+        using NeedlesIndexerT = dpctl::tensor::offset_utils::StridedIndexer;
+        const ssize_t *needles_shape_strides = packed_shape_strides;
+        const NeedlesIndexerT needles_indexer(needles_nd, needles_offset,
+                                              needles_shape_strides);
+        using OutIndexerT = dpctl::tensor::offset_utils::UnpackedStridedIndexer;
+
+        const ssize_t *out_shape = packed_shape_strides;
+        const ssize_t *out_strides = packed_shape_strides + 2 * needles_nd;
+        const OutIndexerT out_indexer(needles_nd, out_offset, out_shape,
+                                      out_strides);
+
+        const auto fnctr =
+            IsinFunctor<T, HayIndexerT, NeedlesIndexerT, OutIndexerT>(
+                invert, hay_tp, needles_tp, out_tp, hay_nelems, hay_indexer,
+                needles_indexer, 
out_indexer);
+        using KernelName = class isin_strided_impl_krn<T>;
+
+        cgh.parallel_for<KernelName>(gRange, fnctr);
+    });
+
+    return comp_ev;
+}
+
+} // namespace dpctl::tensor::kernels
diff --git a/dpctl_ext/tensor/libtensor/include/kernels/sorting/merge_sort.hpp b/dpctl_ext/tensor/libtensor/include/kernels/sorting/merge_sort.hpp
new file mode 100644
index 000000000000..a047c172f7bc
--- /dev/null
+++ b/dpctl_ext/tensor/libtensor/include/kernels/sorting/merge_sort.hpp
@@ -0,0 +1,856 @@
+//*****************************************************************************
+// Copyright (c) 2026, Intel Corporation
+// All rights reserved.
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are met:
+// - Redistributions of source code must retain the above copyright notice,
+// this list of conditions and the following disclaimer.
+// - Redistributions in binary form must reproduce the above copyright notice,
+// this list of conditions and the following disclaimer in the documentation
+// and/or other materials provided with the distribution.
+// - Neither the name of the copyright holder nor the names of its contributors
+// may be used to endorse or promote products derived from this software
+// without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
+// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
+// THE POSSIBILITY OF SUCH DAMAGE.
+//*****************************************************************************
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file
+/// This file defines kernels for tensor sort/argsort operations.
+//===----------------------------------------------------------------------===//
+
+#pragma once
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include <sycl/sycl.hpp>
+
+#include "kernels/dpctl_tensor_types.hpp"
+#include "kernels/sorting/search_sorted_detail.hpp"
+#include "kernels/sorting/sort_utils.hpp"
+
+namespace dpctl::tensor::kernels
+{
+
+namespace merge_sort_detail
+{
+
+using dpctl::tensor::ssize_t;
+using namespace dpctl::tensor::kernels::search_sorted_detail;
+
+/*! 
@brief Merge two contiguous sorted segments */ +template +void merge_impl(const std::size_t offset, + const InAcc in_acc, + OutAcc out_acc, + const std::size_t start_1, + const std::size_t end_1, + const std::size_t end_2, + const std::size_t start_out, + Compare comp, + const std::size_t chunk) +{ + const std::size_t start_2 = end_1; + // Borders of the sequences to merge within this call + const std::size_t local_start_1 = sycl::min(offset + start_1, end_1); + const std::size_t local_end_1 = sycl::min(local_start_1 + chunk, end_1); + const std::size_t local_start_2 = sycl::min(offset + start_2, end_2); + const std::size_t local_end_2 = sycl::min(local_start_2 + chunk, end_2); + + const std::size_t local_size_1 = local_end_1 - local_start_1; + const std::size_t local_size_2 = local_end_2 - local_start_2; + + const auto r_item_1 = in_acc[end_1 - 1]; + const auto l_item_2 = (start_2 < end_2) ? in_acc[start_2] : r_item_1; + + // Copy if the sequences are sorted with respect to each other or merge + // otherwise + if (!comp(l_item_2, r_item_1)) { + const std::size_t out_shift_1 = start_out + local_start_1 - start_1; + const std::size_t out_shift_2 = + start_out + end_1 - start_1 + local_start_2 - start_2; + + for (std::size_t i = 0; i < local_size_1; ++i) { + out_acc[out_shift_1 + i] = in_acc[local_start_1 + i]; + } + for (std::size_t i = 0; i < local_size_2; ++i) { + out_acc[out_shift_2 + i] = in_acc[local_start_2 + i]; + } + } + else if (comp(r_item_1, l_item_2)) { + const std::size_t out_shift_1 = + start_out + end_2 - start_2 + local_start_1 - start_1; + const std::size_t out_shift_2 = start_out + local_start_2 - start_2; + for (std::size_t i = 0; i < local_size_1; ++i) { + out_acc[out_shift_1 + i] = in_acc[local_start_1 + i]; + } + for (std::size_t i = 0; i < local_size_2; ++i) { + out_acc[out_shift_2 + i] = in_acc[local_start_2 + i]; + } + } + // Perform merging + else { + + // Process 1st sequence + if (local_start_1 < local_end_1) { + // Reduce the range for searching within the 2nd sequence and handle + // bound items find left border in 2nd sequence + const auto local_l_item_1 = in_acc[local_start_1]; + std::size_t l_search_bound_2 = + lower_bound_impl(in_acc, start_2, end_2, local_l_item_1, comp); + const std::size_t l_shift_1 = local_start_1 - start_1; + const std::size_t l_shift_2 = l_search_bound_2 - start_2; + + out_acc[start_out + l_shift_1 + l_shift_2] = local_l_item_1; + + std::size_t r_search_bound_2{}; + // find right border in 2nd sequence + if (local_size_1 > 1) { + const auto local_r_item_1 = in_acc[local_end_1 - 1]; + r_search_bound_2 = lower_bound_impl( + in_acc, l_search_bound_2, end_2, local_r_item_1, comp); + const auto r_shift_1 = local_end_1 - 1 - start_1; + const auto r_shift_2 = r_search_bound_2 - start_2; + + out_acc[start_out + r_shift_1 + r_shift_2] = local_r_item_1; + } + + // Handle intermediate items + if (r_search_bound_2 == l_search_bound_2) { + const std::size_t shift_2 = l_search_bound_2 - start_2; + for (std::size_t idx = local_start_1 + 1; idx < local_end_1 - 1; + ++idx) { + const auto intermediate_item_1 = in_acc[idx]; + const std::size_t shift_1 = idx - start_1; + out_acc[start_out + shift_1 + shift_2] = + intermediate_item_1; + } + } + else { + for (std::size_t idx = local_start_1 + 1; idx < local_end_1 - 1; + ++idx) { + const auto intermediate_item_1 = in_acc[idx]; + // we shouldn't seek in whole 2nd sequence. 
Just for the + // part where the 1st sequence should be + l_search_bound_2 = lower_bound_impl( + in_acc, l_search_bound_2, r_search_bound_2, + intermediate_item_1, comp); + const std::size_t shift_1 = idx - start_1; + const std::size_t shift_2 = l_search_bound_2 - start_2; + + out_acc[start_out + shift_1 + shift_2] = + intermediate_item_1; + } + } + } + // Process 2nd sequence + if (local_start_2 < local_end_2) { + // Reduce the range for searching within the 1st sequence and handle + // bound items find left border in 1st sequence + const auto local_l_item_2 = in_acc[local_start_2]; + std::size_t l_search_bound_1 = + upper_bound_impl(in_acc, start_1, end_1, local_l_item_2, comp); + const std::size_t l_shift_1 = l_search_bound_1 - start_1; + const std::size_t l_shift_2 = local_start_2 - start_2; + + out_acc[start_out + l_shift_1 + l_shift_2] = local_l_item_2; + + std::size_t r_search_bound_1{}; + // find right border in 1st sequence + if (local_size_2 > 1) { + const auto local_r_item_2 = in_acc[local_end_2 - 1]; + r_search_bound_1 = upper_bound_impl( + in_acc, l_search_bound_1, end_1, local_r_item_2, comp); + const std::size_t r_shift_1 = r_search_bound_1 - start_1; + const std::size_t r_shift_2 = local_end_2 - 1 - start_2; + + out_acc[start_out + r_shift_1 + r_shift_2] = local_r_item_2; + } + + // Handle intermediate items + if (l_search_bound_1 == r_search_bound_1) { + const std::size_t shift_1 = l_search_bound_1 - start_1; + for (auto idx = local_start_2 + 1; idx < local_end_2 - 1; ++idx) + { + const auto intermediate_item_2 = in_acc[idx]; + const std::size_t shift_2 = idx - start_2; + out_acc[start_out + shift_1 + shift_2] = + intermediate_item_2; + } + } + else { + for (auto idx = local_start_2 + 1; idx < local_end_2 - 1; ++idx) + { + const auto intermediate_item_2 = in_acc[idx]; + // we shouldn't seek in whole 1st sequence. 
Just for the + // part where the 2nd sequence should be + l_search_bound_1 = upper_bound_impl( + in_acc, l_search_bound_1, r_search_bound_1, + intermediate_item_2, comp); + const std::size_t shift_1 = l_search_bound_1 - start_1; + const std::size_t shift_2 = idx - start_2; + + out_acc[start_out + shift_1 + shift_2] = + intermediate_item_2; + } + } + } + } +} + +template +void insertion_sort_impl(Iter &&first, + std::size_t begin, + std::size_t end, + Compare &&comp) +{ + for (std::size_t i = begin + 1; i < end; ++i) { + const auto val_i = first[i]; + std::size_t j = i - 1; + while ((j + 1 > begin) && (comp(val_i, first[j]))) { + first[j + 1] = first[j]; + --j; + } + if (j + 1 < i) { + first[j + 1] = val_i; + } + } +} + +template +void leaf_sort_impl(Iter &&first, + std::size_t begin, + std::size_t end, + Compare &&comp) +{ + return insertion_sort_impl(std::forward(first), + std::move(begin), std::move(end), + std::forward(comp)); +} + +template +struct GetValueType +{ + using value_type = typename std::iterator_traits::value_type; +}; + +template +struct GetValueType> +{ + using value_type = ElementType; +}; + +template +struct GetValueType< + sycl::accessor> +{ + using value_type = ElementType; +}; + +template +struct GetValueType> +{ + using value_type = ElementType; +}; + +template +struct GetReadOnlyAccess +{ + Iter operator()(const Iter &it, sycl::handler &) + { + return it; + } +}; + +template +struct GetReadOnlyAccess> +{ + auto operator()(const sycl::buffer &buf, + sycl::handler &cgh) + { + sycl::accessor acc(buf, cgh, sycl::read_only); + return acc; + } +}; + +template +struct GetWriteDiscardAccess +{ + Iter operator()(Iter it, sycl::handler &) + { + return it; + } +}; + +template +struct GetWriteDiscardAccess> +{ + auto operator()(sycl::buffer &buf, + sycl::handler &cgh) + { + sycl::accessor acc(buf, cgh, sycl::write_only, sycl::no_init); + return acc; + } +}; + +template +struct GetReadWriteAccess +{ + Iter operator()(Iter &it, sycl::handler &) + { + return it; + } +}; + +template +struct GetReadWriteAccess> +{ + auto operator()(sycl::buffer &buf, + sycl::handler &cgh) + { + sycl::accessor acc(buf, cgh, sycl::read_write); + return acc; + } +}; + +template +class sort_base_step_contig_krn; + +template +sycl::event + sort_base_step_contig_impl(sycl::queue &q, + const std::size_t iter_nelems, + const std::size_t sort_nelems, + const InpAcc input, + OutAcc output, + const Comp &comp, + const std::size_t conseq_nelems_sorted, + const std::vector &depends = {}) +{ + + using inpT = typename GetValueType::value_type; + using outT = typename GetValueType::value_type; + using KernelName = sort_base_step_contig_krn; + + const std::size_t n_segments = + quotient_ceil(sort_nelems, conseq_nelems_sorted); + + sycl::event base_sort = q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + const sycl::range<1> gRange{iter_nelems * n_segments}; + + auto input_acc = GetReadOnlyAccess{}(input, cgh); + auto output_acc = GetWriteDiscardAccess{}(output, cgh); + + cgh.parallel_for(gRange, [=](sycl::id<1> id) { + const std::size_t iter_id = id[0] / n_segments; + const std::size_t segment_id = id[0] - iter_id * n_segments; + + const std::size_t iter_offset = iter_id * sort_nelems; + const std::size_t beg_id = + iter_offset + segment_id * conseq_nelems_sorted; + const std::size_t end_id = + iter_offset + + std::min((segment_id + 1) * conseq_nelems_sorted, sort_nelems); + for (std::size_t i = beg_id; i < end_id; ++i) { + output_acc[i] = input_acc[i]; + } + + leaf_sort_impl(output_acc, beg_id, 
end_id, comp); + }); + }); + + return base_sort; +} + +template +class sort_over_work_group_contig_krn; + +template +sycl::event sort_over_work_group_contig_impl( + sycl::queue &q, + std::size_t iter_nelems, + std::size_t sort_nelems, + const InpAcc input, + OutAcc output, + const Comp &comp, + std::size_t &nelems_wg_sorts, + const std::vector &depends = {}) +{ + using inpT = typename GetValueType::value_type; + using T = typename GetValueType::value_type; + using KernelName = sort_over_work_group_contig_krn; + + const auto &kernel_id = sycl::get_kernel_id(); + + auto const &ctx = q.get_context(); + auto const &dev = q.get_device(); + auto kb = sycl::get_kernel_bundle( + ctx, {dev}, {kernel_id}); + + auto krn = kb.get_kernel(kernel_id); + + const std::uint32_t max_sg_size = krn.template get_info< + sycl::info::kernel_device_specific::max_sub_group_size>(dev); + const std::uint64_t device_local_memory_size = + dev.get_info(); + + // leave 512 bytes of local memory for RT + const std::uint64_t safety_margin = 512; + + const std::uint64_t nelems_per_slm = + (device_local_memory_size - safety_margin) / (2 * sizeof(T)); + + static constexpr std::uint32_t sub_groups_per_work_group = 4; + const std::uint32_t elems_per_wi = dev.has(sycl::aspect::cpu) ? 8 : 2; + + const std::size_t lws = sub_groups_per_work_group * max_sg_size; + + nelems_wg_sorts = elems_per_wi * lws; + + if (nelems_wg_sorts > nelems_per_slm) { + nelems_wg_sorts = (q.get_device().has(sycl::aspect::cpu) ? 16 : 4); + + return sort_base_step_contig_impl( + q, iter_nelems, sort_nelems, input, output, comp, nelems_wg_sorts, + depends); + } + + // This assumption permits doing away with using a loop + assert(nelems_wg_sorts % lws == 0); + + const std::size_t n_segments = quotient_ceil(sort_nelems, nelems_wg_sorts); + + sycl::event base_sort_ev = q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + cgh.use_kernel_bundle(kb); + + sycl::range<1> global_range{iter_nelems * n_segments * lws}; + sycl::range<1> local_range{lws}; + + sycl::range<1> slm_range{nelems_wg_sorts}; + sycl::local_accessor work_space(slm_range, cgh); + sycl::local_accessor scratch_space(slm_range, cgh); + + auto input_acc = GetReadOnlyAccess{}(input, cgh); + auto output_acc = GetWriteDiscardAccess{}(output, cgh); + + sycl::nd_range<1> ndRange(global_range, local_range); + + cgh.parallel_for(ndRange, [=](sycl::nd_item<1> it) { + const std::size_t group_id = it.get_group_linear_id(); + const std::size_t iter_id = group_id / n_segments; + const std::size_t segment_id = group_id - iter_id * n_segments; + const std::size_t lid = it.get_local_linear_id(); + + const std::size_t segment_start_idx = segment_id * nelems_wg_sorts; + const std::size_t segment_end_idx = + std::min(segment_start_idx + nelems_wg_sorts, sort_nelems); + const std::size_t wg_chunk_size = + segment_end_idx - segment_start_idx; + + // load input into SLM + for (std::size_t array_id = segment_start_idx + lid; + array_id < segment_end_idx; array_id += lws) + { + T v = (array_id < sort_nelems) + ? 
input_acc[iter_id * sort_nelems + array_id] + : T{}; + work_space[array_id - segment_start_idx] = v; + } + sycl::group_barrier(it.get_group()); + + const std::size_t chunk = quotient_ceil(nelems_wg_sorts, lws); + + const std::size_t chunk_start_idx = lid * chunk; + const std::size_t chunk_end_idx = + sycl::min(chunk_start_idx + chunk, wg_chunk_size); + + leaf_sort_impl(work_space, chunk_start_idx, chunk_end_idx, comp); + + sycl::group_barrier(it.get_group()); + + bool data_in_temp = false; + std::size_t n_chunks_merged = 1; + + // merge chunk while n_chunks_merged * chunk < wg_chunk_size + const std::size_t max_chunks_merged = + 1 + ((wg_chunk_size - 1) / chunk); + for (; n_chunks_merged < max_chunks_merged; + data_in_temp = !data_in_temp, n_chunks_merged *= 2) + { + const std::size_t nelems_sorted_so_far = + n_chunks_merged * chunk; + const std::size_t q = (lid / n_chunks_merged); + const std::size_t start_1 = + sycl::min(2 * nelems_sorted_so_far * q, wg_chunk_size); + const std::size_t end_1 = + sycl::min(start_1 + nelems_sorted_so_far, wg_chunk_size); + const std::size_t end_2 = + sycl::min(end_1 + nelems_sorted_so_far, wg_chunk_size); + const std::size_t offset = chunk * (lid - q * n_chunks_merged); + + if (data_in_temp) { + merge_impl(offset, scratch_space, work_space, start_1, + end_1, end_2, start_1, comp, chunk); + } + else { + merge_impl(offset, work_space, scratch_space, start_1, + end_1, end_2, start_1, comp, chunk); + } + sycl::group_barrier(it.get_group()); + } + + const auto &out_src = (data_in_temp) ? scratch_space : work_space; + for (std::size_t array_id = segment_start_idx + lid; + array_id < segment_end_idx; array_id += lws) + { + if (array_id < sort_nelems) { + output_acc[iter_id * sort_nelems + array_id] = + out_src[array_id - segment_start_idx]; + } + } + }); + }); + + return base_sort_ev; +} + +class vacuous_krn; + +inline sycl::event tie_events(sycl::queue &q, + const std::vector depends) +{ + if (depends.empty()) + return sycl::event(); + if (depends.size() == 1) + return depends[0]; + + sycl::event e = q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + using KernelName = vacuous_krn; + cgh.single_task([]() {}); + }); + + return e; +} + +template +class merge_adjacent_blocks_to_temp_krn; + +template +class merge_adjacent_blocks_from_temp_krn; + +template +sycl::event + merge_sorted_block_contig_impl(sycl::queue &q, + std::size_t iter_nelems, + std::size_t sort_nelems, + Acc output, + const Comp comp, + std::size_t sorted_block_size, + const std::vector &depends = {}) +{ + + if (sorted_block_size >= sort_nelems) + return tie_events(q, depends); + + // experimentally determined value + // size of segments worked upon by each work-item during merging + const sycl::device &dev = q.get_device(); + const std::size_t segment_size = (dev.has(sycl::aspect::cpu)) ? 32 : 4; + + const std::size_t chunk_size = + (sorted_block_size < segment_size) ? 
sorted_block_size : segment_size; + + assert(sorted_block_size % chunk_size == 0); + + using T = typename GetValueType::value_type; + + sycl::buffer temp_buf(sycl::range<1>{iter_nelems * sort_nelems}); + // T *allocated_mem = sycl::malloc_device(iter_nelems * sort_nelems, q); + + bool needs_copy = true; + bool used_depends = false; + + sycl::event dep_ev; + std::size_t chunks_merged = sorted_block_size / chunk_size; + + assert(!(chunks_merged & (chunks_merged - 1))); + + using ToTempKernelName = class merge_adjacent_blocks_to_temp_krn; + using FromTempKernelName = + class merge_adjacent_blocks_from_temp_krn; + + while (chunks_merged * chunk_size < sort_nelems) { + sycl::event local_dep = dep_ev; + + sycl::event merge_ev = q.submit([&](sycl::handler &cgh) { + if (used_depends) { + cgh.depends_on(local_dep); + } + else { + cgh.depends_on(depends); + used_depends = true; + } + + const std::size_t n_chunks = quotient_ceil(sort_nelems, chunk_size); + + if (needs_copy) { + sycl::accessor temp_acc{temp_buf, cgh, sycl::write_only, + sycl::no_init}; + auto output_acc = GetReadOnlyAccess{}(output, cgh); + cgh.parallel_for( + {iter_nelems * n_chunks}, [=](sycl::id<1> wid) { + auto flat_idx = wid[0]; + auto iter_idx = flat_idx / n_chunks; + auto idx = flat_idx - n_chunks * iter_idx; + + const std::size_t idx_mult = + (idx / chunks_merged) * chunks_merged; + const std::size_t idx_rem = (idx - idx_mult); + const std::size_t start_1 = + sycl::min(2 * idx_mult * chunk_size, sort_nelems); + const std::size_t end_1 = sycl::min( + start_1 + chunks_merged * chunk_size, sort_nelems); + const std::size_t end_2 = sycl::min( + end_1 + chunks_merged * chunk_size, sort_nelems); + const std::size_t offset = chunk_size * idx_rem; + + const std::size_t iter_offset = iter_idx * sort_nelems; + + merge_impl(offset, output_acc, temp_acc, + iter_offset + start_1, iter_offset + end_1, + iter_offset + end_2, iter_offset + start_1, + comp, chunk_size); + }); + } + else { + sycl::accessor temp_acc{temp_buf, cgh, sycl::read_only}; + auto output_acc = GetWriteDiscardAccess{}(output, cgh); + cgh.parallel_for( + {iter_nelems * n_chunks}, [=](sycl::id<1> wid) { + auto flat_idx = wid[0]; + auto iter_idx = flat_idx / n_chunks; + auto idx = flat_idx - n_chunks * iter_idx; + + const std::size_t idx_mult = + (idx / chunks_merged) * chunks_merged; + const std::size_t idx_rem = (idx - idx_mult); + const std::size_t start_1 = + sycl::min(2 * idx_mult * chunk_size, sort_nelems); + const std::size_t end_1 = sycl::min( + start_1 + chunks_merged * chunk_size, sort_nelems); + const std::size_t end_2 = sycl::min( + end_1 + chunks_merged * chunk_size, sort_nelems); + const std::size_t offset = chunk_size * idx_rem; + + const std::size_t iter_offset = iter_idx * sort_nelems; + + merge_impl(offset, temp_acc, output_acc, + iter_offset + start_1, iter_offset + end_1, + iter_offset + end_2, iter_offset + start_1, + comp, chunk_size); + }); + } + }); + + chunks_merged *= 2; + dep_ev = merge_ev; + + if (chunks_merged * chunk_size < sort_nelems) { + needs_copy = !needs_copy; + } + } + + if (needs_copy) { + sycl::event copy_ev = q.submit([&](sycl::handler &cgh) { + cgh.depends_on(dep_ev); + + sycl::accessor temp_acc{temp_buf, cgh, sycl::read_only}; + auto output_acc = GetWriteDiscardAccess{}(output, cgh); + + cgh.copy(temp_acc, output_acc); + }); + dep_ev = copy_ev; + } + + return dep_ev; +} + +} // namespace merge_sort_detail + +template > +sycl::event stable_sort_axis1_contig_impl( + sycl::queue &exec_q, + std::size_t iter_nelems, // number of 
sub-arrays to sort (num. of rows in a
+                             // matrix when sorting over rows)
+    std::size_t sort_nelems, // size of each array to sort (length of rows,
+                             // i.e. number of columns)
+    const char *arg_cp,
+    char *res_cp,
+    ssize_t iter_arg_offset,
+    ssize_t iter_res_offset,
+    ssize_t sort_arg_offset,
+    ssize_t sort_res_offset,
+    const std::vector<sycl::event> &depends)
+{
+    const argTy *arg_tp = reinterpret_cast<const argTy *>(arg_cp) +
+                          iter_arg_offset + sort_arg_offset;
+    argTy *res_tp =
+        reinterpret_cast<argTy *>(res_cp) + iter_res_offset + sort_res_offset;
+
+    auto comp = Comp{};
+
+    // constant chosen experimentally to ensure monotonicity of
+    // sorting performance, as measured on GPU Max and Iris Xe
+    constexpr std::size_t sequential_sorting_threshold = 16;
+
+    if (sort_nelems < sequential_sorting_threshold) {
+        // each work-item sorts an entire row
+        sycl::event sequential_sorting_ev =
+            merge_sort_detail::sort_base_step_contig_impl(
+                exec_q, iter_nelems, sort_nelems, arg_tp, res_tp, comp,
+                sort_nelems, depends);
+
+        return sequential_sorting_ev;
+    }
+    else {
+        std::size_t sorted_block_size{};
+
+        // Sort segments of the array
+        sycl::event base_sort_ev =
+            merge_sort_detail::sort_over_work_group_contig_impl(
+                exec_q, iter_nelems, sort_nelems, arg_tp, res_tp, comp,
+                sorted_block_size, // modified in place with the size of the
+                                   // sorted block
+                depends);
+
+        // Merge segments in parallel until all elements are sorted
+        sycl::event merges_ev =
+            merge_sort_detail::merge_sorted_block_contig_impl(
+                exec_q, iter_nelems, sort_nelems, res_tp, comp,
+                sorted_block_size, {base_sort_ev});
+
+        return merges_ev;
+    }
+}
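+// Editor's note: an illustrative host-side reduction of the argsort scheme
+// used below (sketch only, hypothetical names): sorting an index array with
+// a comparator that dereferences into the values yields the permutation that
+// sorts the values, which is what IndexComp enables on device.
+//
+//   #include <algorithm>
+//   #include <cstddef>
+//   #include <vector>
+//
+//   std::vector<std::size_t> argsort(const std::vector<int> &vals)
+//   {
+//       std::vector<std::size_t> idx(vals.size());
+//       for (std::size_t i = 0; i < idx.size(); ++i)
+//           idx[i] = i; // iota, mirroring iota_impl below
+//       std::sort(idx.begin(), idx.end(), [&](std::size_t i, std::size_t j) {
+//           return vals[i] < vals[j];
+//       });
+//       return idx; // e.g. vals {30, 10, 20} gives idx {1, 2, 0}
+//   }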
+template <typename T1, typename T2, typename T3>
+class populate_index_data_krn;
+
+template <typename T1, typename T2, typename T3>
+class index_map_to_rows_krn;
+
+template <typename IndexT, typename ValueT, typename ValueComp>
+struct IndexComp
+{
+    IndexComp(const ValueT *data, const ValueComp &comp_op)
+        : ptr(data), value_comp(comp_op)
+    {
+    }
+
+    bool operator()(const IndexT &i1, const IndexT &i2) const
+    {
+        return value_comp(ptr[i1], ptr[i2]);
+    }
+
+private:
+    const ValueT *ptr;
+    ValueComp value_comp;
+};
+
+template <typename argTy,
+          typename IndexTy,
+          typename ValueComp = std::less<argTy>>
+sycl::event stable_argsort_axis1_contig_impl(
+    sycl::queue &exec_q,
+    std::size_t iter_nelems, // number of sub-arrays to sort (num. of rows in a
+                             // matrix when sorting over rows)
+    std::size_t sort_nelems, // size of each array to sort (length of rows,
+                             // i.e. number of columns)
+    const char *arg_cp,
+    char *res_cp,
+    ssize_t iter_arg_offset,
+    ssize_t iter_res_offset,
+    ssize_t sort_arg_offset,
+    ssize_t sort_res_offset,
+    const std::vector<sycl::event> &depends)
+{
+    const argTy *arg_tp = reinterpret_cast<const argTy *>(arg_cp) +
+                          iter_arg_offset + sort_arg_offset;
+    IndexTy *res_tp =
+        reinterpret_cast<IndexTy *>(res_cp) + iter_res_offset + sort_res_offset;
+
+    const IndexComp<IndexTy, argTy, ValueComp> index_comp{arg_tp, ValueComp{}};
+
+    static constexpr std::size_t determine_automatically = 0;
+    std::size_t sorted_block_size = determine_automatically;
+
+    const std::size_t total_nelems = iter_nelems * sort_nelems;
+
+    using dpctl::tensor::kernels::sort_utils_detail::iota_impl;
+
+    using IotaKernelName = populate_index_data_krn<argTy, IndexTy, ValueComp>;
+
+    sycl::event populate_indexed_data_ev = iota_impl<IotaKernelName, IndexTy>(
+        exec_q, res_tp, total_nelems, depends);
+
+    // Sort segments of the array
+    sycl::event base_sort_ev =
+        merge_sort_detail::sort_over_work_group_contig_impl(
+            exec_q, iter_nelems, sort_nelems, res_tp, res_tp, index_comp,
+            sorted_block_size, // modified in place with the size of the
+                               // sorted block
+            {populate_indexed_data_ev});
+
+    // Merge segments in parallel until all elements are sorted
+    sycl::event merges_ev = merge_sort_detail::merge_sorted_block_contig_impl(
+        exec_q, iter_nelems, sort_nelems, res_tp, index_comp,
+        sorted_block_size, {base_sort_ev});
+
+    // no need to map back if iter_nelems == 1
+    if (iter_nelems == 1u) {
+        return merges_ev;
+    }
+
+    using MapBackKernelName = index_map_to_rows_krn<argTy, IndexTy, ValueComp>;
+    using dpctl::tensor::kernels::sort_utils_detail::map_back_impl;
+
+    sycl::event write_out_ev = map_back_impl<MapBackKernelName, IndexTy>(
+        exec_q, total_nelems, res_tp, res_tp, sort_nelems, {merges_ev});
+
+    return write_out_ev;
+}
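+// Editor's note (sketch): for batched argsort the indices above are seeded by
+// iota over the flattened buffer, so after merging, row r holds global indices
+// in [r * sort_nelems, (r + 1) * sort_nelems). The map_back_impl call converts
+// them to row-local indices; a scalar equivalent (assumed semantics):
+//
+//   for (std::size_t i = 0; i < total_nelems; ++i)
+//       res[i] = res[i] % sort_nelems;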
+} // namespace dpctl::tensor::kernels
diff --git a/dpctl_ext/tensor/libtensor/include/kernels/sorting/radix_sort.hpp b/dpctl_ext/tensor/libtensor/include/kernels/sorting/radix_sort.hpp
new file mode 100644
index 000000000000..545444101fc6
--- /dev/null
+++ b/dpctl_ext/tensor/libtensor/include/kernels/sorting/radix_sort.hpp
@@ -0,0 +1,1920 @@
+//*****************************************************************************
+// Copyright (c) 2026, Intel Corporation
+// All rights reserved.
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are met:
+// - Redistributions of source code must retain the above copyright notice,
+// this list of conditions and the following disclaimer.
+// - Redistributions in binary form must reproduce the above copyright notice,
+// this list of conditions and the following disclaimer in the documentation
+// and/or other materials provided with the distribution.
+// - Neither the name of the copyright holder nor the names of its contributors
+// may be used to endorse or promote products derived from this software
+// without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
+// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
+// THE POSSIBILITY OF SUCH DAMAGE.
+//*****************************************************************************
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file
+/// This file defines functions of dpctl.tensor._tensor_sorting_impl
+/// extension.
+//===----------------------------------------------------------------------===//
+
+#pragma once
+
+#include <array>
+#include <cassert>
+#include <cstddef>
+#include <cstdint>
+#include <limits>
+#include <type_traits>
+#include <utility>
+#include <vector>
+
+#include <sycl/sycl.hpp>
+
+#include "kernels/dpctl_tensor_types.hpp"
+#include "kernels/sorting/sort_utils.hpp"
+#include "utils/sycl_alloc_utils.hpp"
+
+namespace dpctl::tensor::kernels
+{
+
+namespace radix_sort_details
+{
+
+template <typename... TrailingNames>
+class radix_sort_count_kernel;
+
+template <typename... TrailingNames>
+class radix_sort_scan_kernel;
+
+template <typename... TrailingNames>
+class radix_sort_reorder_peer_kernel;
+
+template <typename... TrailingNames>
+class radix_sort_reorder_kernel;
+
+/*! @brief Computes the smallest exponent such that `n <= (1 << exponent)` */
+template <typename SizeT,
+          std::enable_if_t<std::is_unsigned_v<SizeT> &&
+                               sizeof(SizeT) == sizeof(std::uint64_t),
+                           int> = 0>
+std::uint32_t ceil_log2(SizeT n)
+{
+    // if n > 2^b, n = q * 2^b + r for q > 0 and 0 <= r < 2^b
+    // floor_log2(q * 2^b + r) == floor_log2(q * 2^b) == b + floor_log2(q)
+    // ceil_log2(n) == 1 + floor_log2(n-1)
+    if (n <= 1)
+        return std::uint32_t{1};
+
+    std::uint32_t exp{1};
+    --n;
+    if (n >= (SizeT{1} << 32)) {
+        n >>= 32;
+        exp += 32;
+    }
+    if (n >= (SizeT{1} << 16)) {
+        n >>= 16;
+        exp += 16;
+    }
+    if (n >= (SizeT{1} << 8)) {
+        n >>= 8;
+        exp += 8;
+    }
+    if (n >= (SizeT{1} << 4)) {
+        n >>= 4;
+        exp += 4;
+    }
+    if (n >= (SizeT{1} << 2)) {
+        n >>= 2;
+        exp += 2;
+    }
+    if (n >= (SizeT{1} << 1)) {
+        n >>= 1;
+        ++exp;
+    }
+    return exp;
+}
+
+//----------------------------------------------------------
+// bitwise order-preserving conversions to unsigned integers
+//----------------------------------------------------------
+
+template <bool is_ascending>
+bool order_preserving_cast(bool val)
+{
+    if constexpr (is_ascending)
+        return val;
+    else
+        return !val;
+}
+
+template <bool is_ascending,
+          typename UIntT,
+          std::enable_if_t<std::is_unsigned_v<UIntT>, int> = 0>
+UIntT order_preserving_cast(UIntT val)
+{
+    if constexpr (is_ascending) {
+        return val;
+    }
+    else {
+        // bitwise invert
+        return (~val);
+    }
+}
+
+template <bool is_ascending,
+          typename IntT,
+          std::enable_if_t<std::is_integral_v<IntT> && std::is_signed_v<IntT>,
+                           int> = 0>
+std::make_unsigned_t<IntT> order_preserving_cast(IntT val)
+{
+    using UIntT = std::make_unsigned_t<IntT>;
+    const UIntT uint_val = sycl::bit_cast<UIntT>(val);
+
+    if constexpr (is_ascending) {
+        // ascending_mask: 100..0
+        static constexpr UIntT ascending_mask =
+            (UIntT(1) << std::numeric_limits<IntT>::digits);
+        return (uint_val ^ ascending_mask);
+    }
+    else {
+        // descending_mask: 011..1
+        static constexpr UIntT descending_mask =
+            (std::numeric_limits<UIntT>::max() >> 1);
+        return (uint_val ^ descending_mask);
+    }
+}
+
+template <bool is_ascending>
+std::uint16_t order_preserving_cast(sycl::half val)
+{
+    using UIntT = std::uint16_t;
+
+    const UIntT uint_val = sycl::bit_cast<UIntT>(
+        (sycl::isnan(val)) ?
std::numeric_limits::quiet_NaN() + : val); + UIntT mask; + + // test the sign bit of the original value + const bool zero_fp_sign_bit = (UIntT(0) == (uint_val >> 15)); + + static constexpr UIntT zero_mask = UIntT(0x8000u); + static constexpr UIntT nonzero_mask = UIntT(0xFFFFu); + + static constexpr UIntT inv_zero_mask = static_cast(~zero_mask); + static constexpr UIntT inv_nonzero_mask = static_cast(~nonzero_mask); + + if constexpr (is_ascending) { + mask = (zero_fp_sign_bit) ? zero_mask : nonzero_mask; + } + else { + mask = (zero_fp_sign_bit) ? (inv_zero_mask) : (inv_nonzero_mask); + } + + return (uint_val ^ mask); +} + +template && + sizeof(FloatT) == sizeof(std::uint32_t), + int> = 0> +std::uint32_t order_preserving_cast(FloatT val) +{ + using UIntT = std::uint32_t; + + UIntT uint_val = sycl::bit_cast( + (sycl::isnan(val)) ? std::numeric_limits::quiet_NaN() : val); + + UIntT mask; + + // test the sign bit of the original value + const bool zero_fp_sign_bit = (UIntT(0) == (uint_val >> 31)); + + static constexpr UIntT zero_mask = UIntT(0x80000000u); + static constexpr UIntT nonzero_mask = UIntT(0xFFFFFFFFu); + + if constexpr (is_ascending) + mask = (zero_fp_sign_bit) ? zero_mask : nonzero_mask; + else + mask = (zero_fp_sign_bit) ? (~zero_mask) : (~nonzero_mask); + + return (uint_val ^ mask); +} + +template && + sizeof(FloatT) == sizeof(std::uint64_t), + int> = 0> +std::uint64_t order_preserving_cast(FloatT val) +{ + using UIntT = std::uint64_t; + + UIntT uint_val = sycl::bit_cast( + (sycl::isnan(val)) ? std::numeric_limits::quiet_NaN() : val); + UIntT mask; + + // test the sign bit of the original value + const bool zero_fp_sign_bit = (UIntT(0) == (uint_val >> 63)); + + static constexpr UIntT zero_mask = UIntT(0x8000000000000000u); + static constexpr UIntT nonzero_mask = UIntT(0xFFFFFFFFFFFFFFFFu); + + if constexpr (is_ascending) + mask = (zero_fp_sign_bit) ? zero_mask : nonzero_mask; + else + mask = (zero_fp_sign_bit) ? 
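+        // editor's illustration of the order-preserving bit trick (ascending
+        // doubles): -2.0 (0xC000...0) -> 0x3FFF...F, -1.0 (0xBFF0...0) ->
+        // 0x400F...F, +1.0 (0x3FF0...0) -> 0xBFF0...0; negative values have
+        // all bits flipped, non-negative values get the sign bit set, so an
+        // unsigned comparison of the images matches the floating-point order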
(~zero_mask) : (~nonzero_mask); + + return (uint_val ^ mask); +} + +//----------------- +// bucket functions +//----------------- + +template +constexpr std::size_t number_of_bits_in_type() +{ + constexpr std::size_t type_bits = + (sizeof(T) * std::numeric_limits::digits); + return type_bits; +} + +// the number of buckets (size of radix bits) in T +template +constexpr std::uint32_t number_of_buckets_in_type(std::uint32_t radix_bits) +{ + constexpr std::size_t type_bits = number_of_bits_in_type(); + return (type_bits + radix_bits - 1) / radix_bits; +} + +// get bits value (bucket) in a certain radix position +template +std::uint32_t get_bucket_id(T val, std::uint32_t radix_offset) +{ + static_assert(std::is_unsigned_v); + + return (val >> radix_offset) & T(radix_mask); +} + +//-------------------------------- +// count kernel (single iteration) +//-------------------------------- + +template +sycl::event + radix_sort_count_submit(sycl::queue &exec_q, + std::size_t n_iters, + std::size_t n_segments, + std::size_t wg_size, + std::uint32_t radix_offset, + std::size_t n_values, + ValueT *vals_ptr, + std::size_t n_counts, + CountT *counts_ptr, + const Proj &proj_op, + const bool is_ascending, + const std::vector &dependency_events) +{ + // bin_count = radix_states used for an array storing bucket state counters + static constexpr std::uint32_t radix_states = + (std::uint32_t(1) << radix_bits); + static constexpr std::uint32_t radix_mask = radix_states - 1; + + // iteration space info + const std::size_t n = n_values; + // each segment is processed by a work-group + const std::size_t elems_per_segment = (n + n_segments - 1) / n_segments; + const std::size_t no_op_flag_id = n_counts - 1; + + assert(n_counts == (n_segments + 1) * radix_states + 1); + + sycl::event local_count_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(dependency_events); + + sycl::local_accessor counts_lacc(wg_size * radix_states, + cgh); + + sycl::nd_range<1> ndRange(n_iters * n_segments * wg_size, wg_size); + + cgh.parallel_for(ndRange, [=](sycl::nd_item<1> ndit) { + // 0 <= lid < wg_size + const std::size_t lid = ndit.get_local_id(0); + // 0 <= group_id < n_segments * n_iters + const std::size_t group_id = ndit.get_group(0); + const std::size_t iter_id = group_id / n_segments; + const std::size_t val_iter_offset = iter_id * n; + // 0 <= wgr_id < n_segments + const std::size_t wgr_id = group_id - iter_id * n_segments; + + const std::size_t seg_start = elems_per_segment * wgr_id; + + // count per work-item: create a private array for storing count + // values here bin_count = radix_states + std::array counts_arr = {CountT{0}}; + + // count per work-item: count values and write result to private + // count array + const std::size_t seg_end = + sycl::min(seg_start + elems_per_segment, n); + if (is_ascending) { + for (std::size_t val_id = seg_start + lid; val_id < seg_end; + val_id += wg_size) { + // get the bucket for the bit-ordered input value, + // applying the offset and mask for radix bits + const auto val = + order_preserving_cast( + proj_op(vals_ptr[val_iter_offset + val_id])); + const std::uint32_t bucket_id = + get_bucket_id(val, radix_offset); + + // increment counter for this bit bucket + ++counts_arr[bucket_id]; + } + } + else { + for (std::size_t val_id = seg_start + lid; val_id < seg_end; + val_id += wg_size) { + // get the bucket for the bit-ordered input value, + // applying the offset and mask for radix bits + const auto val = + order_preserving_cast( + proj_op(vals_ptr[val_iter_offset + 
val_id]));
+                    const std::uint32_t bucket_id =
+                        get_bucket_id<radix_mask>(val, radix_offset);
+
+                    // increment counter for this bit bucket
+                    ++counts_arr[bucket_id];
+                }
+            }
+
+            // count per work-item: write the private count array to the
+            // local count array; counts_lacc is the concatenation of the
+            // private count arrays from each work-item, in the order of
+            // their local ids
+            const std::uint32_t count_start_id = radix_states * lid;
+            for (std::uint32_t radix_state_id = 0;
+                 radix_state_id < radix_states; ++radix_state_id)
+            {
+                counts_lacc[count_start_id + radix_state_id] =
+                    counts_arr[radix_state_id];
+            }
+
+            sycl::group_barrier(ndit.get_group());
+
+            // count per work-group: reduce while the counts_lacc[] content
+            // spans more than wg_size entries;
+            // all work-items in the work-group do the work
+            for (std::uint32_t i = 1; i < radix_states; ++i) {
+                // Since we are interested in computing the total count over
+                // the work-group for each radix state, the correct result is
+                // only assured if wg_size >= radix_states
+                counts_lacc[lid] += counts_lacc[wg_size * i + lid];
+            }
+
+            sycl::group_barrier(ndit.get_group());
+
+            // count per work-group: reduce until only radix_states entries
+            // remain (n_witems /= 2 per iteration)
+            for (std::uint32_t n_witems = (wg_size >> 1);
+                 n_witems >= radix_states; n_witems >>= 1)
+            {
+                if (lid < n_witems)
+                    counts_lacc[lid] += counts_lacc[n_witems + lid];
+
+                sycl::group_barrier(ndit.get_group());
+            }
+
+            const std::size_t iter_counter_offset = iter_id * n_counts;
+
+            // count per work-group: write the local count array to the
+            // global count array
+            if (lid < radix_states) {
+                // move buckets with the same id to adjacent positions,
+                // thus splitting the count array into radix_states regions
+                counts_ptr[iter_counter_offset + (n_segments + 1) * lid +
+                           wgr_id] = counts_lacc[lid];
+            }
+
+            // side work: reset the 'no-operation-flag', signaling to skip
+            // the re-order phase
+            if (wgr_id == 0 && lid == 0) {
+                CountT &no_op_flag =
+                    counts_ptr[iter_counter_offset + no_op_flag_id];
+                no_op_flag = 0;
+            }
+        });
+    });
+
+    return local_count_ev;
+}
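+// Editor's note: a worked illustration (hypothetical sizes) of the count
+// buffer layout that the scan kernel below consumes. With radix_states = 2
+// and n_segments = 2, n_counts = (2 + 1) * 2 + 1 = 7, and if the segments
+// count {2, 1} values in bucket 0 and {1, 2} in bucket 1, the layout is
+//
+//     [2, 1, 0 | 1, 2, 0 | no_op_flag]
+//
+// The scan kernel runs an exclusive scan over each bucket's region:
+//
+//     [0, 2, 3 | 0, 1, 3 | no_op_flag]
+//
+// so entry b * (n_segments + 1) + s is the in-bucket offset of segment s,
+// and the trailing entry of each region is the bucket's total, which the
+// reorder kernel accumulates across buckets.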
+//-----------------------------------------------------------------------
+// radix sort: scan kernel (single iteration)
+//-----------------------------------------------------------------------
+
+template <typename KernelNameT, std::uint32_t radix_bits, typename CountT>
+sycl::event radix_sort_scan_submit(sycl::queue &exec_q,
+                                   std::size_t n_iters,
+                                   std::size_t n_segments,
+                                   std::size_t wg_size,
+                                   std::size_t n_values,
+                                   std::size_t n_counts,
+                                   CountT *counts_ptr,
+                                   const std::vector<sycl::event> depends)
+{
+    const std::size_t no_op_flag_id = n_counts - 1;
+
+    // Scan produces local offsets using the count values.
+    // There are no local offsets for the first segment, but the remaining
+    // segments must be scanned with respect to the count value of the first
+    // segment, which requires n + 1 positions
+    const std::size_t scan_size = n_segments + 1;
+    wg_size = std::min(scan_size, wg_size);
+
+    static constexpr std::uint32_t radix_states = std::uint32_t(1)
+                                                  << radix_bits;
+
+    // pre-compiling the kernel helps avoid an out-of-resources issue, which
+    // may occur due to the use of collective algorithms such as
+    // joint_exclusive_scan even if local memory is not explicitly requested
+    sycl::event scan_ev = exec_q.submit([&](sycl::handler &cgh) {
+        cgh.depends_on(depends);
+
+        sycl::nd_range<1> ndRange(n_iters * radix_states * wg_size, wg_size);
+
+        cgh.parallel_for<KernelNameT>(ndRange, [=](sycl::nd_item<1> ndit) {
+            const std::size_t group_id = ndit.get_group(0);
+            const std::size_t iter_id = group_id / radix_states;
+            const std::size_t wgr_id = group_id - iter_id * radix_states;
+            // find the borders of the region with a specific bucket id
+            auto begin_ptr =
+                counts_ptr + scan_size * wgr_id + iter_id * n_counts;
+
+            sycl::joint_exclusive_scan(ndit.get_group(), begin_ptr,
+                                       begin_ptr + scan_size, begin_ptr,
+                                       CountT(0), sycl::plus<CountT>{});
+
+            const auto lid = ndit.get_local_linear_id();
+
+            // NB: no race condition here, because the condition can only
+            // ever be true for one WG, one WI
+            if ((lid == wg_size - 1) && (begin_ptr[scan_size - 1] == n_values))
+            {
+                // set the flag, since all the values fell into one bucket;
+                // this is an optimization and may happen often for
+                // higher radix offsets (all zeros)
+                auto &no_op_flag =
+                    counts_ptr[iter_id * n_counts + no_op_flag_id];
+                no_op_flag = 1;
+            }
+        });
+    });
+
+    return scan_ev;
+}
+
+//-----------------------------------------------------------------------
+// radix sort: group level reorder algorithms
+//-----------------------------------------------------------------------
+
+struct empty_storage
+{
+    template <typename... T>
+    empty_storage(T &&...)
+    {
+    }
+};
+
+// Number with the `n` least significant bits of uint32_t set
+inline std::uint32_t n_ls_bits_set(std::uint32_t n) noexcept
+{
+    static constexpr std::uint32_t zero{};
+    static constexpr std::uint32_t all_bits_set = ~zero;
+
+    return ~(all_bits_set << n);
+}
+
+enum class peer_prefix_algo
+{
+    subgroup_ballot,
+    atomic_fetch_or,
+    scan_then_broadcast
+};
+
+template <typename OffsetT, peer_prefix_algo Algo>
+struct peer_prefix_helper;
+
+template <typename AccT>
+auto get_accessor_pointer(const AccT &acc)
+{
+    return acc.template get_multi_ptr<sycl::access::decorated::no>().get();
+}
+
+template <typename OffsetT>
+struct peer_prefix_helper<OffsetT, peer_prefix_algo::atomic_fetch_or>
+{
+    using AtomicT = sycl::atomic_ref<std::uint32_t,
+                                     sycl::memory_order::relaxed,
+                                     sycl::memory_scope::work_group,
+                                     sycl::access::address_space::local_space>;
+    using TempStorageT = sycl::local_accessor<std::uint32_t, 1>;
+
+private:
+    sycl::sub_group sgroup;
+    std::uint32_t lid;
+    std::uint32_t item_mask;
+    AtomicT atomic_peer_mask;
+
+public:
+    peer_prefix_helper(sycl::nd_item<1> ndit, TempStorageT lacc)
+        : sgroup(ndit.get_sub_group()), lid(ndit.get_local_linear_id()),
+          item_mask(n_ls_bits_set(lid)), atomic_peer_mask(lacc[0])
+    {
+    }
+
+    std::uint32_t peer_contribution(OffsetT &new_offset_id,
+                                    OffsetT offset_prefix,
+                                    bool wi_bit_set) const
+    {
+        // reset the mask for each radix state
+        if (lid == 0)
+            atomic_peer_mask.store(std::uint32_t{0});
+        sycl::group_barrier(sgroup);
+
+        const std::uint32_t uint_contrib{wi_bit_set ?
std::uint32_t{1} + : std::uint32_t{0}}; + + // set local id's bit to 1 if the bucket value matches the radix state + atomic_peer_mask.fetch_or(uint_contrib << lid); + sycl::group_barrier(sgroup); + std::uint32_t peer_mask_bits = atomic_peer_mask.load(); + std::uint32_t sg_total_offset = sycl::popcount(peer_mask_bits); + + // get the local offset index from the bits set in the peer mask with + // index less than the work item ID + peer_mask_bits &= item_mask; + new_offset_id |= wi_bit_set + ? (offset_prefix + sycl::popcount(peer_mask_bits)) + : OffsetT{0}; + return sg_total_offset; + } +}; + +template +struct peer_prefix_helper +{ + using TempStorageT = empty_storage; + using ItemType = sycl::nd_item<1>; + using SubGroupType = sycl::sub_group; + +private: + SubGroupType sgroup; + std::uint32_t sg_size; + +public: + peer_prefix_helper(sycl::nd_item<1> ndit, TempStorageT) + : sgroup(ndit.get_sub_group()), sg_size(sgroup.get_local_range()[0]) + { + } + + std::uint32_t peer_contribution(OffsetT &new_offset_id, + OffsetT offset_prefix, + bool wi_bit_set) const + { + const std::uint32_t contrib{wi_bit_set ? std::uint32_t{1} + : std::uint32_t{0}}; + + std::uint32_t sg_item_offset = sycl::exclusive_scan_over_group( + sgroup, contrib, sycl::plus{}); + + new_offset_id |= + (wi_bit_set ? (offset_prefix + sg_item_offset) : OffsetT(0)); + + // the last scanned value does not contain number of all copies, thus + // adding contribution + std::uint32_t sg_total_offset = sycl::group_broadcast( + sgroup, sg_item_offset + contrib, sg_size - 1); + + return sg_total_offset; + } +}; + +template +struct peer_prefix_helper +{ +private: + sycl::sub_group sgroup; + std::uint32_t lid; + sycl::ext::oneapi::sub_group_mask item_sg_mask; + + sycl::ext::oneapi::sub_group_mask mask_builder(std::uint32_t mask, + std::uint32_t sg_size) + { + return sycl::detail::Builder::createSubGroupMask< + sycl::ext::oneapi::sub_group_mask>(mask, sg_size); + } + +public: + using TempStorageT = empty_storage; + + peer_prefix_helper(sycl::nd_item<1> ndit, TempStorageT) + : sgroup(ndit.get_sub_group()), lid(ndit.get_local_linear_id()), + item_sg_mask( + mask_builder(n_ls_bits_set(lid), sgroup.get_local_linear_range())) + { + } + + std::uint32_t peer_contribution(OffsetT &new_offset_id, + OffsetT offset_prefix, + bool wi_bit_set) const + { + // set local id's bit to 1 if the bucket value matches the radix state + auto peer_mask = sycl::ext::oneapi::group_ballot(sgroup, wi_bit_set); + std::uint32_t peer_mask_bits{}; + + peer_mask.extract_bits(peer_mask_bits); + std::uint32_t sg_total_offset = sycl::popcount(peer_mask_bits); + + // get the local offset index from the bits set in the peer mask with + // index less than the work item ID + peer_mask &= item_sg_mask; + peer_mask.extract_bits(peer_mask_bits); + + new_offset_id |= wi_bit_set + ? 
(offset_prefix + sycl::popcount(peer_mask_bits)) + : OffsetT(0); + + return sg_total_offset; + } +}; + +template +void copy_func_for_radix_sort(const std::size_t n_segments, + const std::size_t elems_per_segment, + const std::size_t sg_size, + const std::uint32_t lid, + const std::size_t wgr_id, + const InputT *input_ptr, + const std::size_t n_values, + OutputT *output_ptr) +{ + // item info + const std::size_t seg_start = elems_per_segment * wgr_id; + + std::size_t seg_end = sycl::min(seg_start + elems_per_segment, n_values); + + // ensure that each work item in a subgroup does the same number of loop + // iterations + const std::uint16_t tail_size = (seg_end - seg_start) % sg_size; + seg_end -= tail_size; + + // find offsets for the same values within a segment and fill the resulting + // buffer + for (std::size_t val_id = seg_start + lid; val_id < seg_end; + val_id += sg_size) { + output_ptr[val_id] = std::move(input_ptr[val_id]); + } + + if (tail_size > 0 && lid < tail_size) { + const std::size_t val_id = seg_end + lid; + output_ptr[val_id] = std::move(input_ptr[val_id]); + } +} + +//----------------------------------------------------------------------- +// radix sort: reorder kernel (per iteration) +//----------------------------------------------------------------------- +template +sycl::event + radix_sort_reorder_submit(sycl::queue &exec_q, + std::size_t n_iters, + std::size_t n_segments, + std::uint32_t radix_offset, + std::size_t n_values, + const InputT *input_ptr, + OutputT *output_ptr, + std::size_t n_offsets, + OffsetT *offset_ptr, + const ProjT &proj_op, + const bool is_ascending, + const std::vector dependency_events) +{ + using ValueT = InputT; + using PeerHelper = peer_prefix_helper; + + static constexpr std::uint32_t radix_states = std::uint32_t{1} + << radix_bits; + static constexpr std::uint32_t radix_mask = radix_states - 1; + const std::size_t elems_per_segment = + (n_values + n_segments - 1) / n_segments; + + const std::size_t no_op_flag_id = n_offsets - 1; + + const auto &kernel_id = sycl::get_kernel_id(); + + auto const &ctx = exec_q.get_context(); + auto const &dev = exec_q.get_device(); + auto kb = sycl::get_kernel_bundle( + ctx, {dev}, {kernel_id}); + + auto krn = kb.get_kernel(kernel_id); + + const std::uint32_t sg_size = krn.template get_info< + sycl::info::kernel_device_specific::max_sub_group_size>(dev); + + sycl::event reorder_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(dependency_events); + cgh.use_kernel_bundle(kb); + + using StorageT = typename PeerHelper::TempStorageT; + + StorageT peer_temp(1, cgh); + + sycl::range<1> lRange{sg_size}; + sycl::range<1> gRange{n_iters * n_segments * sg_size}; + + sycl::nd_range<1> ndRange{gRange, lRange}; + + cgh.parallel_for(ndRange, [=](sycl::nd_item<1> ndit) { + const std::size_t group_id = ndit.get_group(0); + const std::size_t iter_id = group_id / n_segments; + const std::size_t segment_id = group_id - iter_id * n_segments; + + auto b_offset_ptr = offset_ptr + iter_id * n_offsets; + auto b_input_ptr = input_ptr + iter_id * n_values; + auto b_output_ptr = output_ptr + iter_id * n_values; + + const std::uint32_t lid = ndit.get_local_id(0); + + auto &no_op_flag = b_offset_ptr[no_op_flag_id]; + if (no_op_flag) { + // no reordering necessary, simply copy + copy_func_for_radix_sort( + n_segments, elems_per_segment, sg_size, lid, segment_id, + b_input_ptr, n_values, b_output_ptr); + return; + } + + // create a private array for storing offset values + // and add total offset and offset for compute unit + 
// for a certain radix state + std::array offset_arr{}; + const std::size_t scan_size = n_segments + 1; + + OffsetT scanned_bin = 0; + + /* find cumulative offset */ + static constexpr std::uint32_t zero_radix_state_id = 0; + offset_arr[zero_radix_state_id] = b_offset_ptr[segment_id]; + + for (std::uint32_t radix_state_id = 1; + radix_state_id < radix_states; ++radix_state_id) + { + const std::uint32_t local_offset_id = + segment_id + scan_size * radix_state_id; + + // scan bins serially + const std::size_t last_segment_bucket_id = + radix_state_id * scan_size - 1; + scanned_bin += b_offset_ptr[last_segment_bucket_id]; + + offset_arr[radix_state_id] = + scanned_bin + b_offset_ptr[local_offset_id]; + } + + const std::size_t seg_start = elems_per_segment * segment_id; + std::size_t seg_end = + sycl::min(seg_start + elems_per_segment, n_values); + // ensure that each work item in a subgroup does the same number of + // loop iterations + const std::uint32_t tail_size = (seg_end - seg_start) % sg_size; + seg_end -= tail_size; + + const PeerHelper peer_prefix_hlp(ndit, peer_temp); + + // find offsets for the same values within a segment and fill the + // resulting buffer + if (is_ascending) { + for (std::size_t val_id = seg_start + lid; val_id < seg_end; + val_id += sg_size) { + ValueT in_val = std::move(b_input_ptr[val_id]); + + // get the bucket for the bit-ordered input value, applying + // the offset and mask for radix bits + const auto mapped_val = + order_preserving_cast( + proj_op(in_val)); + std::uint32_t bucket_id = + get_bucket_id(mapped_val, radix_offset); + + OffsetT new_offset_id = 0; + for (std::uint32_t radix_state_id = 0; + radix_state_id < radix_states; ++radix_state_id) + { + bool is_current_bucket = (bucket_id == radix_state_id); + std::uint32_t sg_total_offset = + peer_prefix_hlp.peer_contribution( + /* modified by reference */ new_offset_id, + offset_arr[radix_state_id], + /* bit contribution from this work-item */ + is_current_bucket); + offset_arr[radix_state_id] += sg_total_offset; + } + b_output_ptr[new_offset_id] = std::move(in_val); + } + } + else { + for (std::size_t val_id = seg_start + lid; val_id < seg_end; + val_id += sg_size) { + ValueT in_val = std::move(b_input_ptr[val_id]); + + // get the bucket for the bit-ordered input value, applying + // the offset and mask for radix bits + const auto mapped_val = + order_preserving_cast( + proj_op(in_val)); + std::uint32_t bucket_id = + get_bucket_id(mapped_val, radix_offset); + + OffsetT new_offset_id = 0; + for (std::uint32_t radix_state_id = 0; + radix_state_id < radix_states; ++radix_state_id) + { + bool is_current_bucket = (bucket_id == radix_state_id); + std::uint32_t sg_total_offset = + peer_prefix_hlp.peer_contribution( + /* modified by reference */ new_offset_id, + offset_arr[radix_state_id], + /* bit contribution from this work-item */ + is_current_bucket); + offset_arr[radix_state_id] += sg_total_offset; + } + b_output_ptr[new_offset_id] = std::move(in_val); + } + } + if (tail_size > 0) { + ValueT in_val; + + // default: is greater than any actual radix state + std::uint32_t bucket_id = radix_states; + if (lid < tail_size) { + in_val = std::move(b_input_ptr[seg_end + lid]); + + const auto proj_val = proj_op(in_val); + const auto mapped_val = + (is_ascending) + ? 
order_preserving_cast( + proj_val) + : order_preserving_cast( + proj_val); + bucket_id = + get_bucket_id(mapped_val, radix_offset); + } + + OffsetT new_offset_id = 0; + for (std::uint32_t radix_state_id = 0; + radix_state_id < radix_states; ++radix_state_id) + { + bool is_current_bucket = (bucket_id == radix_state_id); + std::uint32_t sg_total_offset = + peer_prefix_hlp.peer_contribution( + new_offset_id, offset_arr[radix_state_id], + is_current_bucket); + + offset_arr[radix_state_id] += sg_total_offset; + } + + if (lid < tail_size) { + b_output_ptr[new_offset_id] = std::move(in_val); + } + } + }); + }); + + return reorder_ev; +} + +template +sizeT _slm_adjusted_work_group_size(sycl::queue &exec_q, + sizeT required_slm_bytes_per_wg, + sizeT wg_size) +{ + const auto &dev = exec_q.get_device(); + + if (wg_size == 0) + wg_size = + dev.template get_info(); + + const auto local_mem_sz = + dev.template get_info(); + + return sycl::min(local_mem_sz / required_slm_bytes_per_wg, wg_size); +} + +//----------------------------------------------------------------------- +// radix sort: one iteration +//----------------------------------------------------------------------- + +template +struct parallel_radix_sort_iteration_step +{ + template + using count_phase = radix_sort_count_kernel; + template + using local_scan_phase = radix_sort_scan_kernel; + template + using reorder_peer_phase = + radix_sort_reorder_peer_kernel; + template + using reorder_phase = radix_sort_reorder_kernel; + + template + static sycl::event submit(sycl::queue &exec_q, + std::size_t n_iters, + std::size_t n_segments, + std::uint32_t radix_iter, + std::size_t n_values, + const InputT *in_ptr, + OutputT *out_ptr, + std::size_t n_counts, + CountT *counts_ptr, + const ProjT &proj_op, + const bool is_ascending, + const std::vector &dependency_events) + { + using _RadixCountKernel = count_phase; + using _RadixLocalScanKernel = + local_scan_phase; + using _RadixReorderPeerKernel = + reorder_peer_phase; + using _RadixReorderKernel = + reorder_phase; + + const auto &supported_sub_group_sizes = + exec_q.get_device() + .template get_info(); + const std::size_t max_sg_size = + (supported_sub_group_sizes.empty() + ? 0 + : supported_sub_group_sizes.back()); + const std::size_t reorder_sg_size = max_sg_size; + const std::size_t scan_wg_size = + exec_q.get_device() + .template get_info(); + + static constexpr std::size_t two_mils = (std::size_t(1) << 21); + std::size_t count_wg_size = + ((max_sg_size > 0) && (n_values > two_mils) ? 128 : max_sg_size); + + static constexpr std::uint32_t radix_states = std::uint32_t(1) + << radix_bits; + + // correct count_wg_size according to local memory limit in count phase + const auto max_count_wg_size = _slm_adjusted_work_group_size( + exec_q, sizeof(CountT) * radix_states, count_wg_size); + count_wg_size = + static_cast<::std::size_t>((max_count_wg_size / radix_states)) * + radix_states; + + // work-group size must be a power of 2 and not less than the number of + // states, for scanning to work correctly + + const std::size_t rounded_down_count_wg_size = + std::size_t{1} << (number_of_bits_in_type() - + sycl::clz(count_wg_size) - 1); + count_wg_size = + sycl::max(rounded_down_count_wg_size, std::size_t(radix_states)); + + // Compute the radix position for the given iteration + std::uint32_t radix_offset = radix_iter * radix_bits; + + // 1. 
Count Phase + sycl::event count_ev = + radix_sort_count_submit<_RadixCountKernel, radix_bits>( + exec_q, n_iters, n_segments, count_wg_size, radix_offset, + n_values, in_ptr, n_counts, counts_ptr, proj_op, is_ascending, + dependency_events); + + // 2. Scan Phase + sycl::event scan_ev = + radix_sort_scan_submit<_RadixLocalScanKernel, radix_bits>( + exec_q, n_iters, n_segments, scan_wg_size, n_values, n_counts, + counts_ptr, {count_ev}); + + // 3. Reorder Phase + sycl::event reorder_ev{}; + // subgroup_ballot-based peer algo uses extract_bits to populate + // uint32_t mask and hence relies on sub-group to be 32 or narrower + static constexpr std::size_t sg32_v = 32u; + static constexpr std::size_t sg16_v = 16u; + static constexpr std::size_t sg08_v = 8u; + if (sg32_v == reorder_sg_size || sg16_v == reorder_sg_size || + sg08_v == reorder_sg_size) + { + static constexpr auto peer_algorithm = + peer_prefix_algo::subgroup_ballot; + + reorder_ev = radix_sort_reorder_submit<_RadixReorderPeerKernel, + radix_bits, peer_algorithm>( + exec_q, n_iters, n_segments, radix_offset, n_values, in_ptr, + out_ptr, n_counts, counts_ptr, proj_op, is_ascending, + {scan_ev}); + } + else { + static constexpr auto peer_algorithm = + peer_prefix_algo::scan_then_broadcast; + + reorder_ev = radix_sort_reorder_submit<_RadixReorderKernel, + radix_bits, peer_algorithm>( + exec_q, n_iters, n_segments, radix_offset, n_values, in_ptr, + out_ptr, n_counts, counts_ptr, proj_op, is_ascending, + {scan_ev}); + } + + return reorder_ev; + } +}; // struct parallel_radix_sort_iteration + +template +class radix_sort_one_wg_krn; + +template +struct subgroup_radix_sort +{ +private: + class use_slm_tag + { + }; + class use_global_mem_tag + { + }; + +public: + template + sycl::event operator()(sycl::queue &exec_q, + std::size_t n_iters, + std::size_t n_to_sort, + ValueT *input_ptr, + OutputT *output_ptr, + ProjT proj_op, + const bool is_ascending, + const std::vector &depends) + { + static_assert(std::is_same_v, OutputT>); + + using _SortKernelLoc = + radix_sort_one_wg_krn; + using _SortKernelPartGlob = + radix_sort_one_wg_krn; + using _SortKernelGlob = + radix_sort_one_wg_krn; + + static constexpr std::size_t max_concurrent_work_groups = 128U; + + // Choose this to occupy the entire accelerator + const std::size_t n_work_groups = + std::min(n_iters, max_concurrent_work_groups); + + // determine which temporary allocation can be accommodated in SLM + const auto &SLM_availability = + check_slm_size(exec_q, n_to_sort); + + const std::size_t n_batch_size = n_work_groups; + + switch (SLM_availability) { + case temp_allocations::both_in_slm: + { + static constexpr auto storage_for_values = use_slm_tag{}; + static constexpr auto storage_for_counters = use_slm_tag{}; + + return one_group_submitter<_SortKernelLoc>()( + exec_q, n_iters, n_iters, n_to_sort, input_ptr, output_ptr, + proj_op, is_ascending, storage_for_values, storage_for_counters, + depends); + } + case temp_allocations::counters_in_slm: + { + static constexpr auto storage_for_values = use_global_mem_tag{}; + static constexpr auto storage_for_counters = use_slm_tag{}; + + return one_group_submitter<_SortKernelPartGlob>()( + exec_q, n_iters, n_batch_size, n_to_sort, input_ptr, output_ptr, + proj_op, is_ascending, storage_for_values, storage_for_counters, + depends); + } + default: + { + static constexpr auto storage_for_values = use_global_mem_tag{}; + static constexpr auto storage_for_counters = use_global_mem_tag{}; + + return one_group_submitter<_SortKernelGlob>()( + exec_q, 
n_iters, n_batch_size, n_to_sort, input_ptr, output_ptr, + proj_op, is_ascending, storage_for_values, storage_for_counters, + depends); + } + } + } + +private: + template + class TempBuf; + + template + class TempBuf + { + std::size_t buf_size; + + public: + TempBuf(std::size_t, std::size_t n) : buf_size(n) {} + auto get_acc(sycl::handler &cgh) + { + return sycl::local_accessor(buf_size, cgh); + } + + std::size_t get_iter_stride() const + { + return std::size_t{0}; + } + }; + + template + class TempBuf + { + sycl::buffer buf; + std::size_t iter_stride; + + public: + TempBuf(std::size_t n_iters, std::size_t n) + : buf(n_iters * n), iter_stride(n) + { + } + auto get_acc(sycl::handler &cgh) + { + return sycl::accessor(buf, cgh, sycl::read_write, sycl::no_init); + } + std::size_t get_iter_stride() const + { + return iter_stride; + } + }; + + static_assert(wg_size <= 1024); + static constexpr std::uint16_t bin_count = (1 << radix); + static constexpr std::uint16_t counter_buf_sz = wg_size * bin_count + 1; + + enum class temp_allocations + { + both_in_slm, + counters_in_slm, + both_in_global_mem + }; + + template + temp_allocations check_slm_size(const sycl::queue &exec_q, SizeT n) + { + // the kernel is designed for data size <= 64K + assert(n <= (SizeT(1) << 16)); + + static constexpr auto req_slm_size_counters = + counter_buf_sz * sizeof(std::uint16_t); + + const auto &dev = exec_q.get_device(); + + // Pessimistically only use half of the memory to take into account + // a SYCL group algorithm might use a portion of SLM + const std::size_t max_slm_size = + dev.template get_info() / 2; + + const auto n_uniform = 1 << ceil_log2(n); + const auto req_slm_size_val = sizeof(T) * n_uniform; + + return ((req_slm_size_val + req_slm_size_counters) <= max_slm_size) + ? + // the values and the counters are placed in SLM + temp_allocations::both_in_slm + : (req_slm_size_counters <= max_slm_size) + ? 
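+                   // editor's illustration (hypothetical numbers): with
+                   // wg_size = 256, radix = 4, and 128 KiB of device SLM,
+                   // max_slm_size = 64 KiB and counter_buf_sz =
+                   // 256 * 16 + 1 = 4097 uint16 values (~8 KiB), so counters
+                   // always fit; float values with n = 40000 round up to
+                   // n_uniform = 65536, need 4 * 65536 = 256 KiB, and spill
+                   // to global memory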
+ // the counters are placed in SLM, the values - in the + // global memory + temp_allocations::counters_in_slm + : + // the values and the counters are placed in the global + // memory + temp_allocations::both_in_global_mem; + } + + template + struct one_group_submitter + { + template + sycl::event operator()(sycl::queue &exec_q, + std::size_t n_iters, + std::size_t n_batch_size, + std::size_t n_values, + InputT *input_arr, + OutputT *output_arr, + const ProjT &proj_op, + const bool is_ascending, + SLM_value_tag, + SLM_counter_tag, + const std::vector &depends) + { + assert(!(n_values >> 16)); + + assert(n_values <= static_cast(block_size) * + static_cast(wg_size)); + + const std::uint16_t n = static_cast(n_values); + static_assert(std::is_same_v, OutputT>); + + using ValueT = OutputT; + + using KeyT = std::invoke_result_t; + + TempBuf buf_val( + n_batch_size, static_cast(block_size * wg_size)); + TempBuf buf_count( + n_batch_size, static_cast(counter_buf_sz)); + + sycl::range<1> lRange{wg_size}; + + sycl::event sort_ev; + std::vector deps{depends}; + + const std::size_t n_batches = + (n_iters + n_batch_size - 1) / n_batch_size; + + const auto &kernel_id = sycl::get_kernel_id(); + + auto const &ctx = exec_q.get_context(); + auto const &dev = exec_q.get_device(); + auto kb = sycl::get_kernel_bundle( + ctx, {dev}, {kernel_id}); + + const auto &krn = kb.get_kernel(kernel_id); + + const std::uint32_t krn_sg_size = krn.template get_info< + sycl::info::kernel_device_specific::max_sub_group_size>(dev); + + // due to a bug in CPU device implementation, an additional + // synchronization is necessary for short sub-group sizes + const bool work_around_needed = + exec_q.get_device().has(sycl::aspect::cpu) && + (krn_sg_size < 16); + + for (std::size_t batch_id = 0; batch_id < n_batches; ++batch_id) { + + const std::size_t block_start = batch_id * n_batch_size; + + // input_arr/output_arr each has shape (n_iters, n) + InputT *this_input_arr = input_arr + block_start * n_values; + OutputT *this_output_arr = output_arr + block_start * n_values; + + const std::size_t block_end = + std::min(block_start + n_batch_size, n_iters); + + sycl::range<1> gRange{(block_end - block_start) * wg_size}; + sycl::nd_range ndRange{gRange, lRange}; + + sort_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(deps); + cgh.use_kernel_bundle(kb); + + // allocation to use for value exchanges + auto exchange_acc = buf_val.get_acc(cgh); + const std::size_t exchange_acc_iter_stride = + buf_val.get_iter_stride(); + + // allocation for counters + auto counter_acc = buf_count.get_acc(cgh); + const std::size_t counter_acc_iter_stride = + buf_count.get_iter_stride(); + + cgh.parallel_for(ndRange, [=](sycl::nd_item<1> + ndit) { + ValueT values[block_size]; + + const std::size_t iter_id = ndit.get_group(0); + const std::size_t iter_val_offset = + iter_id * static_cast(n); + const std::size_t iter_counter_offset = + iter_id * counter_acc_iter_stride; + const std::size_t iter_exchange_offset = + iter_id * exchange_acc_iter_stride; + + std::uint16_t wi = ndit.get_local_linear_id(); + std::uint16_t begin_bit = 0; + + static constexpr std::uint16_t end_bit = + number_of_bits_in_type(); + + // copy from input array into values +#pragma unroll + for (std::uint16_t i = 0; i < block_size; ++i) { + const std::uint16_t id = wi * block_size + i; + values[i] = + (id < n) ? 
this_input_arr[iter_val_offset + id] + : ValueT{}; + } + + while (true) { + // indices for indirect access in the "re-order" + // phase + std::uint16_t indices[block_size]; + { + // pointers to bucket's counters + std::uint16_t *counters[block_size]; + + // counting phase + auto pcounter = + get_accessor_pointer(counter_acc) + + (wi + iter_counter_offset); + + // initialize counters +#pragma unroll + for (std::uint16_t i = 0; i < bin_count; ++i) + pcounter[i * wg_size] = std::uint16_t{0}; + + sycl::group_barrier(ndit.get_group()); + + if (is_ascending) { +#pragma unroll + for (std::uint16_t i = 0; i < block_size; + ++i) { + const std::uint16_t id = + wi * block_size + i; + static constexpr std::uint16_t + bin_mask = bin_count - 1; + + // points to the padded element, i.e. id + // is in-range + static constexpr std::uint16_t + default_out_of_range_bin_id = + bin_mask; + + const std::uint16_t bin = + (id < n) + ? get_bucket_id( + order_preserving_cast< + /* is_ascending */ + true>( + proj_op(values[i])), + begin_bit) + : default_out_of_range_bin_id; + + // counting and local offset calculation + counters[i] = &pcounter[bin * wg_size]; + indices[i] = *counters[i]; + *counters[i] = indices[i] + 1; + + if (work_around_needed) { + sycl::group_barrier( + ndit.get_group()); + } + } + } + else { +#pragma unroll + for (std::uint16_t i = 0; i < block_size; + ++i) { + const std::uint16_t id = + wi * block_size + i; + static constexpr std::uint16_t + bin_mask = bin_count - 1; + + // points to the padded element, i.e. id + // is in-range + static constexpr std::uint16_t + default_out_of_range_bin_id = + bin_mask; + + const std::uint16_t bin = + (id < n) + ? get_bucket_id( + order_preserving_cast< + /* is_ascending */ + false>( + proj_op(values[i])), + begin_bit) + : default_out_of_range_bin_id; + + // counting and local offset calculation + counters[i] = &pcounter[bin * wg_size]; + indices[i] = *counters[i]; + *counters[i] = indices[i] + 1; + + if (work_around_needed) { + sycl::group_barrier( + ndit.get_group()); + } + } + } + + sycl::group_barrier(ndit.get_group()); + + // exclusive scan phase + { + + // scan contiguous numbers + std::uint16_t bin_sum[bin_count]; + const std::size_t counter_offset0 = + iter_counter_offset + wi * bin_count; + bin_sum[0] = counter_acc[counter_offset0]; + +#pragma unroll + for (std::uint16_t i = 1; i < bin_count; + ++i) + bin_sum[i] = + bin_sum[i - 1] + + counter_acc[counter_offset0 + i]; + + sycl::group_barrier(ndit.get_group()); + + // exclusive scan local sum + std::uint16_t sum_scan = + sycl::exclusive_scan_over_group( + ndit.get_group(), + bin_sum[bin_count - 1], + sycl::plus()); + +// add to local sum, generate exclusive scan result +#pragma unroll + for (std::uint16_t i = 0; i < bin_count; + ++i) + counter_acc[counter_offset0 + i + 1] = + sum_scan + bin_sum[i]; + + if (wi == 0) + counter_acc[iter_counter_offset + 0] = + std::uint32_t{0}; + + sycl::group_barrier(ndit.get_group()); + } + +#pragma unroll + for (std::uint16_t i = 0; i < block_size; ++i) { + // a global index is a local offset plus a + // global base index + indices[i] += *counters[i]; + } + + sycl::group_barrier(ndit.get_group()); + } + + begin_bit += radix; + + // "re-order" phase + sycl::group_barrier(ndit.get_group()); + if (begin_bit >= end_bit) { + // the last iteration - writing out the result +#pragma unroll + for (std::uint16_t i = 0; i < block_size; ++i) { + const std::uint16_t r = indices[i]; + if (r < n) { + this_output_arr[iter_val_offset + r] = + values[i]; + } + } + + return; + } + + // 
data exchange +#pragma unroll + for (std::uint16_t i = 0; i < block_size; ++i) { + const std::uint16_t r = indices[i]; + if (r < n) + exchange_acc[iter_exchange_offset + r] = + values[i]; + } + + sycl::group_barrier(ndit.get_group()); + +#pragma unroll + for (std::uint16_t i = 0; i < block_size; ++i) { + const std::uint16_t id = wi * block_size + i; + if (id < n) + values[i] = + exchange_acc[iter_exchange_offset + id]; + } + + sycl::group_barrier(ndit.get_group()); + } + }); + }); + + deps = {sort_ev}; + } + + return sort_ev; + } + }; +}; + +template +struct OneWorkGroupRadixSortKernel; + +//----------------------------------------------------------------------- +// radix sort: main function +//----------------------------------------------------------------------- +template +sycl::event parallel_radix_sort_impl(sycl::queue &exec_q, + std::size_t n_iters, + std::size_t n_to_sort, + const ValueT *input_arr, + ValueT *output_arr, + const ProjT &proj_op, + const bool is_ascending, + const std::vector &depends) +{ + assert(n_to_sort > 1); + + using KeyT = std::remove_cv_t< + std::remove_reference_t>>; + + // radix bits represent number of processed bits in each value during one + // iteration + static constexpr std::uint32_t radix_bits = 4; + + sycl::event sort_ev{}; + + const auto &dev = exec_q.get_device(); + const auto max_wg_size = + dev.template get_info(); + + static constexpr std::uint16_t ref_wg_size = 64; + if (n_to_sort <= 16384 && ref_wg_size * 8 <= max_wg_size) { + using _RadixSortKernel = OneWorkGroupRadixSortKernel; + + if (n_to_sort <= 64 && ref_wg_size <= max_wg_size) { + // wg_size * block_size == 64 * 1 * 1 == 64 + static constexpr std::uint16_t wg_size = ref_wg_size; + static constexpr std::uint16_t block_size = 1; + + sort_ev = subgroup_radix_sort<_RadixSortKernel, wg_size, block_size, + radix_bits>{}( + exec_q, n_iters, n_to_sort, input_arr, output_arr, proj_op, + is_ascending, depends); + } + else if (n_to_sort <= 128 && ref_wg_size * 2 <= max_wg_size) { + // wg_size * block_size == 64 * 2 * 1 == 128 + static constexpr std::uint16_t wg_size = ref_wg_size * 2; + static constexpr std::uint16_t block_size = 1; + + sort_ev = subgroup_radix_sort<_RadixSortKernel, wg_size, block_size, + radix_bits>{}( + exec_q, n_iters, n_to_sort, input_arr, output_arr, proj_op, + is_ascending, depends); + } + else if (n_to_sort <= 256 && ref_wg_size * 2 <= max_wg_size) { + // wg_size * block_size == 64 * 2 * 2 == 256 + static constexpr std::uint16_t wg_size = ref_wg_size * 2; + static constexpr std::uint16_t block_size = 2; + + sort_ev = subgroup_radix_sort<_RadixSortKernel, wg_size, block_size, + radix_bits>{}( + exec_q, n_iters, n_to_sort, input_arr, output_arr, proj_op, + is_ascending, depends); + } + else if (n_to_sort <= 512 && ref_wg_size * 2 <= max_wg_size) { + // wg_size * block_size == 64 * 2 * 4 == 512 + static constexpr std::uint16_t wg_size = ref_wg_size * 2; + static constexpr std::uint16_t block_size = 4; + + sort_ev = subgroup_radix_sort<_RadixSortKernel, wg_size, block_size, + radix_bits>{}( + exec_q, n_iters, n_to_sort, input_arr, output_arr, proj_op, + is_ascending, depends); + } + else if (n_to_sort <= 1024 && ref_wg_size * 2 <= max_wg_size) { + // wg_size * block_size == 64 * 2 * 8 == 1024 + static constexpr std::uint16_t wg_size = ref_wg_size * 2; + static constexpr std::uint16_t block_size = 8; + + sort_ev = subgroup_radix_sort<_RadixSortKernel, wg_size, block_size, + radix_bits>{}( + exec_q, n_iters, n_to_sort, input_arr, output_arr, proj_op, + is_ascending, depends); + 
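// Reviewer note, not part of the patch: every branch in this ladder picks
// (wg_size, block_size) so that a single work-group covers the whole
// sub-array, i.e. wg_size * block_size is the next capacity step at or
// above n_to_sort. For the branch just above: 128 work-items holding
// 8 elements each cover 1024 keys, so no cross-group merging is needed.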
} + else if (n_to_sort <= 2048 && ref_wg_size * 4 <= max_wg_size) { + // wg_size * block_size == 64 * 4 * 8 == 2048 + static constexpr std::uint16_t wg_size = ref_wg_size * 4; + static constexpr std::uint16_t block_size = 8; + + sort_ev = subgroup_radix_sort<_RadixSortKernel, wg_size, block_size, + radix_bits>{}( + exec_q, n_iters, n_to_sort, input_arr, output_arr, proj_op, + is_ascending, depends); + } + else if (n_to_sort <= 4096 && ref_wg_size * 4 <= max_wg_size) { + // wg_size * block_size == 64 * 4 * 16 == 4096 + static constexpr std::uint16_t wg_size = ref_wg_size * 4; + static constexpr std::uint16_t block_size = 16; + + sort_ev = subgroup_radix_sort<_RadixSortKernel, wg_size, block_size, + radix_bits>{}( + exec_q, n_iters, n_to_sort, input_arr, output_arr, proj_op, + is_ascending, depends); + } + else if (n_to_sort <= 8192 && ref_wg_size * 8 <= max_wg_size) { + // wg_size * block_size == 64 * 8 * 16 == 8192 + static constexpr std::uint16_t wg_size = ref_wg_size * 8; + static constexpr std::uint16_t block_size = 16; + + sort_ev = subgroup_radix_sort<_RadixSortKernel, wg_size, block_size, + radix_bits>{}( + exec_q, n_iters, n_to_sort, input_arr, output_arr, proj_op, + is_ascending, depends); + } + else { + // wg_size * block_size == 64 * 8 * 32 == 16384 + static constexpr std::uint16_t wg_size = ref_wg_size * 8; + static constexpr std::uint16_t block_size = 32; + + sort_ev = subgroup_radix_sort<_RadixSortKernel, wg_size, block_size, + radix_bits>{}( + exec_q, n_iters, n_to_sort, input_arr, output_arr, proj_op, + is_ascending, depends); + } + } + else { + static constexpr std::uint32_t radix_iters = + number_of_buckets_in_type(radix_bits); + static constexpr std::uint32_t radix_states = std::uint32_t(1) + << radix_bits; + + static constexpr std::size_t bound_512k = (std::size_t(1) << 19); + static constexpr std::size_t bound_2m = (std::size_t(1) << 21); + + const auto wg_sz_k = (n_to_sort < bound_512k) ? 8 + : (n_to_sort <= bound_2m) ? 
4 + : 1; + const std::size_t wg_size = max_wg_size / wg_sz_k; + + const std::size_t n_segments = (n_to_sort + wg_size - 1) / wg_size; + + // Additional radix_states elements are used for getting local offsets + // from count values + no_op flag; 'No operation' flag specifies whether + // to skip re-order phase if the all keys are the same (lie in one bin) + const std::size_t n_counts = + (n_segments + 1) * radix_states + 1 /*no_op flag*/; + + using CountT = std::uint32_t; + + // memory for storing count and offset values + auto count_owner = + dpctl::tensor::alloc_utils::smart_malloc_device( + n_iters * n_counts, exec_q); + + CountT *count_ptr = count_owner.get(); + + static constexpr std::uint32_t zero_radix_iter{0}; + + if constexpr (std::is_same_v) { + + sort_ev = parallel_radix_sort_iteration_step< + radix_bits, /*even=*/true>::submit(exec_q, n_iters, n_segments, + zero_radix_iter, n_to_sort, + input_arr, output_arr, + n_counts, count_ptr, proj_op, + is_ascending, depends); + + sort_ev = dpctl::tensor::alloc_utils::async_smart_free( + exec_q, {sort_ev}, count_owner); + + return sort_ev; + } + + auto tmp_arr_owner = + dpctl::tensor::alloc_utils::smart_malloc_device( + n_iters * n_to_sort, exec_q); + + ValueT *tmp_arr = tmp_arr_owner.get(); + + // iterations per each bucket + assert("Number of iterations must be even" && radix_iters % 2 == 0); + assert(radix_iters > 0); + + sort_ev = parallel_radix_sort_iteration_step< + radix_bits, /*even=*/true>::submit(exec_q, n_iters, n_segments, + zero_radix_iter, n_to_sort, + input_arr, tmp_arr, n_counts, + count_ptr, proj_op, is_ascending, + depends); + + for (std::uint32_t radix_iter = 1; radix_iter < radix_iters; + ++radix_iter) { + if (radix_iter % 2 == 0) { + sort_ev = parallel_radix_sort_iteration_step< + radix_bits, + /*even=*/true>::submit(exec_q, n_iters, n_segments, + radix_iter, n_to_sort, output_arr, + tmp_arr, n_counts, count_ptr, + proj_op, is_ascending, {sort_ev}); + } + else { + sort_ev = parallel_radix_sort_iteration_step< + radix_bits, + /*even=*/false>::submit(exec_q, n_iters, n_segments, + radix_iter, n_to_sort, tmp_arr, + output_arr, n_counts, count_ptr, + proj_op, is_ascending, {sort_ev}); + } + } + + sort_ev = dpctl::tensor::alloc_utils::async_smart_free( + exec_q, {sort_ev}, tmp_arr_owner, count_owner); + } + + return sort_ev; +} + +struct IdentityProj +{ + constexpr IdentityProj() {} + + template + constexpr T operator()(T val) const + { + return val; + } +}; + +template +struct ValueProj +{ + constexpr ValueProj() {} + + constexpr ValueT operator()(const std::pair &pair) const + { + return pair.first; + } +}; + +template +struct IndexedProj +{ + IndexedProj(const ValueT *arg_ptr) : ptr(arg_ptr), value_projector{} {} + + IndexedProj(const ValueT *arg_ptr, const ProjT &proj_op) + : ptr(arg_ptr), value_projector(proj_op) + { + } + + auto operator()(IndexT i) const + { + return value_projector(ptr[i]); + } + +private: + const ValueT *ptr; + ProjT value_projector; +}; + +} // namespace radix_sort_details + +using dpctl::tensor::ssize_t; + +template +sycl::event + radix_sort_axis1_contig_impl(sycl::queue &exec_q, + const bool sort_ascending, + // number of sub-arrays to sort (num. of rows + // in a matrix when sorting over rows) + std::size_t iter_nelems, + // size of each array to sort (length of rows, + // i.e. 
number of columns) + std::size_t sort_nelems, + const char *arg_cp, + char *res_cp, + ssize_t iter_arg_offset, + ssize_t iter_res_offset, + ssize_t sort_arg_offset, + ssize_t sort_res_offset, + const std::vector &depends) +{ + const argTy *arg_tp = reinterpret_cast(arg_cp) + + iter_arg_offset + sort_arg_offset; + argTy *res_tp = + reinterpret_cast(res_cp) + iter_res_offset + sort_res_offset; + + using Proj = radix_sort_details::IdentityProj; + static constexpr Proj proj_op{}; + + sycl::event radix_sort_ev = + radix_sort_details::parallel_radix_sort_impl( + exec_q, iter_nelems, sort_nelems, arg_tp, res_tp, proj_op, + sort_ascending, depends); + + return radix_sort_ev; +} + +template +class radix_argsort_index_write_out_krn; + +template +class radix_argsort_iota_krn; + +template +sycl::event + radix_argsort_axis1_contig_impl(sycl::queue &exec_q, + const bool sort_ascending, + // number of sub-arrays to sort (num. of + // rows in a matrix when sorting over rows) + std::size_t iter_nelems, + // size of each array to sort (length of + // rows, i.e. number of columns) + std::size_t sort_nelems, + const char *arg_cp, + char *res_cp, + ssize_t iter_arg_offset, + ssize_t iter_res_offset, + ssize_t sort_arg_offset, + ssize_t sort_res_offset, + const std::vector &depends) +{ + const argTy *arg_tp = reinterpret_cast(arg_cp) + + iter_arg_offset + sort_arg_offset; + IndexTy *res_tp = + reinterpret_cast(res_cp) + iter_res_offset + sort_res_offset; + + const std::size_t total_nelems = iter_nelems * sort_nelems; + auto workspace_owner = + dpctl::tensor::alloc_utils::smart_malloc_device(total_nelems, + exec_q); + + // get raw USM pointer + IndexTy *workspace = workspace_owner.get(); + + using IdentityProjT = radix_sort_details::IdentityProj; + using IndexedProjT = + radix_sort_details::IndexedProj; + const IndexedProjT proj_op{arg_tp}; + + using IotaKernelName = radix_argsort_iota_krn; + + using dpctl::tensor::kernels::sort_utils_detail::iota_impl; + + sycl::event iota_ev = iota_impl( + exec_q, workspace, total_nelems, depends); + + sycl::event radix_sort_ev = + radix_sort_details::parallel_radix_sort_impl( + exec_q, iter_nelems, sort_nelems, workspace, res_tp, proj_op, + sort_ascending, {iota_ev}); + + using MapBackKernelName = radix_argsort_index_write_out_krn; + using dpctl::tensor::kernels::sort_utils_detail::map_back_impl; + + sycl::event dep = radix_sort_ev; + + // no need to perform map_back ( id % sort_nelems) + // if total_nelems == sort_nelems + if (iter_nelems > 1u) { + dep = map_back_impl( + exec_q, total_nelems, res_tp, res_tp, sort_nelems, {dep}); + } + + sycl::event cleanup_ev = dpctl::tensor::alloc_utils::async_smart_free( + exec_q, {dep}, workspace_owner); + + return cleanup_ev; +} + +} // namespace dpctl::tensor::kernels diff --git a/dpctl_ext/tensor/libtensor/include/kernels/sorting/search_sorted_detail.hpp b/dpctl_ext/tensor/libtensor/include/kernels/sorting/search_sorted_detail.hpp new file mode 100644 index 000000000000..d184261aeabf --- /dev/null +++ b/dpctl_ext/tensor/libtensor/include/kernels/sorting/search_sorted_detail.hpp @@ -0,0 +1,124 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. 
+// - Redistributions in binary form must reproduce the above copyright notice,
+// this list of conditions and the following disclaimer in the documentation
+// and/or other materials provided with the distribution.
+// - Neither the name of the copyright holder nor the names of its contributors
+// may be used to endorse or promote products derived from this software
+// without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
+// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
+// THE POSSIBILITY OF SUCH DAMAGE.
+//*****************************************************************************
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file
+/// This file defines binary-search helpers shared by the tensor
+/// searchsorted kernels.
+//===----------------------------------------------------------------------===//
+
+#pragma once
+
+#include <cstddef>
+
+namespace dpctl::tensor::kernels
+{
+
+namespace search_sorted_detail
+{
+
+template <typename T>
+T quotient_ceil(T n, T m)
+{
+    return (n + m - 1) / m;
+}
+
+template <typename Acc, typename Value, typename Compare>
+std::size_t lower_bound_impl(const Acc acc,
+                             const std::size_t first,
+                             const std::size_t last,
+                             const Value &value,
+                             const Compare &comp)
+{
+    std::size_t n = last - first;
+    std::size_t cur = n, start = first;
+    std::size_t it;
+    while (n > 0) {
+        it = start;
+        cur = n / 2;
+        it += cur;
+        if (comp(acc[it], value)) {
+            n -= cur + 1, start = ++it;
+        }
+        else
+            n = cur;
+    }
+    return start;
+}
+
+template <typename Acc, typename Value, typename Compare>
+std::size_t upper_bound_impl(const Acc acc,
+                             const std::size_t first,
+                             const std::size_t last,
+                             const Value &value,
+                             const Compare &comp)
+{
+    const auto &op_comp = [comp](auto x, auto y) { return !comp(y, x); };
+    return lower_bound_impl(acc, first, last, value, op_comp);
+}
+
+template <typename Acc, typename Value, typename Compare, typename IndexerT>
+std::size_t lower_bound_indexed_impl(const Acc acc,
+                                     std::size_t first,
+                                     std::size_t last,
+                                     const Value &value,
+                                     const Compare &comp,
+                                     const IndexerT &acc_indexer)
+{
+    std::size_t n = last - first;
+    std::size_t cur = n, start = first;
+    std::size_t it;
+    while (n > 0) {
+        it = start;
+        cur = n / 2;
+        it += cur;
+        if (comp(acc[acc_indexer(it)], value)) {
+            n -= cur + 1, start = ++it;
+        }
+        else
+            n = cur;
+    }
+    return start;
+}
+
+template <typename Acc, typename Value, typename Compare, typename IndexerT>
+std::size_t upper_bound_indexed_impl(const Acc acc,
+                                     const std::size_t first,
+                                     const std::size_t last,
+                                     const Value &value,
+                                     const Compare &comp,
+                                     const IndexerT &acc_indexer)
+{
+    const auto &op_comp = [comp](auto x, auto y) { return !comp(y, x); };
+    return lower_bound_indexed_impl(acc, first, last, value, op_comp,
+                                    acc_indexer);
+}
+
+} // namespace search_sorted_detail
+
+} // namespace dpctl::tensor::kernels
diff --git a/dpctl_ext/tensor/libtensor/include/kernels/sorting/searchsorted.hpp b/dpctl_ext/tensor/libtensor/include/kernels/sorting/searchsorted.hpp
new file mode 100644
index 000000000000..b89ad922e102
--- /dev/null
+++ 
b/dpctl_ext/tensor/libtensor/include/kernels/sorting/searchsorted.hpp @@ -0,0 +1,259 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file defines kernels for tensor sort/argsort operations. 
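An aside on the side="left"/side="right" semantics these kernels implement: lower_bound returns the first position whose element is not less than the needle, upper_bound the first position whose element is strictly greater. A minimal host-side analogue using the standard library (illustrative only; the device-side equivalents are the helpers in search_sorted_detail.hpp above):

#include <algorithm>
#include <cassert>
#include <vector>

int main()
{
    const std::vector<int> hay{1, 2, 2, 3};
    // side="left": first i with hay[i-1] < v <= hay[i]
    assert(std::lower_bound(hay.begin(), hay.end(), 2) - hay.begin() == 1);
    // side="right": first i with hay[i-1] <= v < hay[i]
    assert(std::upper_bound(hay.begin(), hay.end(), 2) - hay.begin() == 3);
    return 0;
}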
+//===----------------------------------------------------------------------===// + +#pragma once + +#include +#include +#include + +#include + +#include "kernels/dpctl_tensor_types.hpp" +#include "kernels/sorting/search_sorted_detail.hpp" +#include "utils/offset_utils.hpp" + +namespace dpctl::tensor::kernels +{ + +using dpctl::tensor::ssize_t; + +template +struct SearchSortedFunctor +{ +private: + const argTy *hay_tp; + const argTy *needles_tp; + indTy *positions_tp; + std::size_t hay_nelems; + HayIndexerT hay_indexer; + NeedlesIndexerT needles_indexer; + PositionsIndexerT positions_indexer; + +public: + SearchSortedFunctor(const argTy *hay_, + const argTy *needles_, + indTy *positions_, + const std::size_t hay_nelems_, + const HayIndexerT &hay_indexer_, + const NeedlesIndexerT &needles_indexer_, + const PositionsIndexerT &positions_indexer_) + : hay_tp(hay_), needles_tp(needles_), positions_tp(positions_), + hay_nelems(hay_nelems_), hay_indexer(hay_indexer_), + needles_indexer(needles_indexer_), + positions_indexer(positions_indexer_) + { + } + + void operator()(sycl::id<1> id) const + { + const Compare comp{}; + + const std::size_t i = id[0]; + const argTy needle_v = needles_tp[needles_indexer(i)]; + + // position of the needle_v in the hay array + indTy pos{}; + + static constexpr std::size_t zero(0); + if constexpr (left_side) { + // search in hay in left-closed interval, give `pos` such that + // hay[pos - 1] < needle_v <= hay[pos] + + // lower_bound returns the first pos such that bool(hay[pos] < + // needle_v) is false, i.e. needle_v <= hay[pos] + pos = search_sorted_detail::lower_bound_indexed_impl( + hay_tp, zero, hay_nelems, needle_v, comp, hay_indexer); + } + else { + // search in hay in right-closed interval: hay[pos - 1] <= needle_v + // < hay[pos] + + // upper_bound returns the first pos such that bool(needle_v < + // hay[pos]) is true, i.e. 
needle_v < hay[pos] + pos = search_sorted_detail::upper_bound_indexed_impl( + hay_tp, zero, hay_nelems, needle_v, comp, hay_indexer); + } + + positions_tp[positions_indexer(i)] = pos; + } +}; + +typedef sycl::event (*searchsorted_contig_impl_fp_ptr_t)( + sycl::queue &, + const std::size_t, + const std::size_t, + const char *, + const ssize_t, + const char *, + const ssize_t, + char *, + const ssize_t, + const std::vector &); + +template +class searchsorted_contig_impl_krn; + +template +sycl::event searchsorted_contig_impl(sycl::queue &exec_q, + const std::size_t hay_nelems, + const std::size_t needles_nelems, + const char *hay_cp, + const ssize_t hay_offset, + const char *needles_cp, + const ssize_t needles_offset, + char *positions_cp, + const ssize_t positions_offset, + const std::vector &depends) +{ + const argTy *hay_tp = reinterpret_cast(hay_cp) + hay_offset; + const argTy *needles_tp = + reinterpret_cast(needles_cp) + needles_offset; + + indTy *positions_tp = + reinterpret_cast(positions_cp) + positions_offset; + + sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using KernelName = + class searchsorted_contig_impl_krn; + + sycl::range<1> gRange(needles_nelems); + + using TrivialIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + + static constexpr TrivialIndexerT hay_indexer{}; + static constexpr TrivialIndexerT needles_indexer{}; + static constexpr TrivialIndexerT positions_indexer{}; + + const auto fnctr = + SearchSortedFunctor( + hay_tp, needles_tp, positions_tp, hay_nelems, hay_indexer, + needles_indexer, positions_indexer); + + cgh.parallel_for(gRange, fnctr); + }); + + return comp_ev; +} + +typedef sycl::event (*searchsorted_strided_impl_fp_ptr_t)( + sycl::queue &, + const std::size_t, + const std::size_t, + const char *, + const ssize_t, + const ssize_t, + const char *, + const ssize_t, + char *, + const ssize_t, + int, + const ssize_t *, + const std::vector &); + +template +class searchsorted_strided_impl_krn; + +template +sycl::event searchsorted_strided_impl( + sycl::queue &exec_q, + const std::size_t hay_nelems, + const std::size_t needles_nelems, + const char *hay_cp, + const ssize_t hay_offset, + // hay is 1D, so hay_nelems, hay_offset, hay_stride describe strided array + const ssize_t hay_stride, + const char *needles_cp, + const ssize_t needles_offset, + char *positions_cp, + const ssize_t positions_offset, + const int needles_nd, + // packed_shape_strides is [needles_shape, needles_strides, + // positions_strides] has length of 3*needles_nd + const ssize_t *packed_shape_strides, + const std::vector &depends) +{ + const argTy *hay_tp = reinterpret_cast(hay_cp); + const argTy *needles_tp = reinterpret_cast(needles_cp); + + indTy *positions_tp = reinterpret_cast(positions_cp); + + sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + sycl::range<1> gRange(needles_nelems); + + using HayIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer; + const HayIndexerT hay_indexer( + /* offset */ hay_offset, + /* size */ hay_nelems, + /* step */ hay_stride); + + using NeedlesIndexerT = dpctl::tensor::offset_utils::StridedIndexer; + const ssize_t *needles_shape_strides = packed_shape_strides; + const NeedlesIndexerT needles_indexer(needles_nd, needles_offset, + needles_shape_strides); + using PositionsIndexerT = + dpctl::tensor::offset_utils::UnpackedStridedIndexer; + + const ssize_t *positions_shape = packed_shape_strides; + const ssize_t *positions_strides = + packed_shape_strides + 2 * 
needles_nd; + const PositionsIndexerT positions_indexer( + needles_nd, positions_offset, positions_shape, positions_strides); + + const auto fnctr = + SearchSortedFunctor( + hay_tp, needles_tp, positions_tp, hay_nelems, hay_indexer, + needles_indexer, positions_indexer); + using KernelName = + class searchsorted_strided_impl_krn; + + cgh.parallel_for(gRange, fnctr); + }); + + return comp_ev; +} + +} // namespace dpctl::tensor::kernels diff --git a/dpctl_ext/tensor/libtensor/include/kernels/sorting/sort_impl_fn_ptr_t.hpp b/dpctl_ext/tensor/libtensor/include/kernels/sorting/sort_impl_fn_ptr_t.hpp new file mode 100644 index 000000000000..7b48f310a445 --- /dev/null +++ b/dpctl_ext/tensor/libtensor/include/kernels/sorting/sort_impl_fn_ptr_t.hpp @@ -0,0 +1,61 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_sorting_impl +/// extension. 
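The typedef that follows exists so each dtype-specialized sort kernel can be stored in a flat per-typenum table, mirroring the DispatchVectorBuilder pattern visible in isin.cpp later in this patch. A minimal sketch of that lookup, with a deliberately simplified signature and hypothetical names (slot count and typenum handling are illustrative, not the patch's exact API):

#include <array>
#include <cstddef>
#include <stdexcept>

// Simplified stand-in for sort_contig_fn_ptr_t; the real typedef returns a
// sycl::event and takes queue/extents/pointers/offsets/dependencies.
using sort_fn_t = void (*)(const char *src, char *dst, std::size_t n);

constexpr std::size_t num_types = 14; // assumption: one slot per dtype
std::array<sort_fn_t, num_types> dispatch_vector{}; // populated at module init

void sort_contig(std::size_t typenum, const char *src, char *dst, std::size_t n)
{
    sort_fn_t fn = dispatch_vector.at(typenum); // O(1) dtype -> kernel lookup
    if (!fn)
        throw std::runtime_error("dtype not supported");
    fn(src, dst, n);
}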
+//===----------------------------------------------------------------------===//
+
+#pragma once
+
+#include <cstddef>
+#include <vector>
+
+#include <sycl/sycl.hpp>
+
+#include "kernels/dpctl_tensor_types.hpp"
+
+namespace dpctl::tensor::kernels
+{
+
+using dpctl::tensor::ssize_t;
+
+typedef sycl::event (*sort_contig_fn_ptr_t)(sycl::queue &,
+                                            std::size_t,
+                                            std::size_t,
+                                            const char *,
+                                            char *,
+                                            ssize_t,
+                                            ssize_t,
+                                            ssize_t,
+                                            ssize_t,
+                                            const std::vector<sycl::event> &);
+
+} // namespace dpctl::tensor::kernels
diff --git a/dpctl_ext/tensor/libtensor/include/kernels/sorting/sort_utils.hpp b/dpctl_ext/tensor/libtensor/include/kernels/sorting/sort_utils.hpp
new file mode 100644
index 000000000000..4bbcacd56fcc
--- /dev/null
+++ b/dpctl_ext/tensor/libtensor/include/kernels/sorting/sort_utils.hpp
@@ -0,0 +1,147 @@
+//*****************************************************************************
+// Copyright (c) 2026, Intel Corporation
+// All rights reserved.
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are met:
+// - Redistributions of source code must retain the above copyright notice,
+// this list of conditions and the following disclaimer.
+// - Redistributions in binary form must reproduce the above copyright notice,
+// this list of conditions and the following disclaimer in the documentation
+// and/or other materials provided with the distribution.
+// - Neither the name of the copyright holder nor the names of its contributors
+// may be used to endorse or promote products derived from this software
+// without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
+// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
+// THE POSSIBILITY OF SUCH DAMAGE.
+//*****************************************************************************
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file
+/// This file defines utility kernels (iota, index map-back) shared by the
+/// tensor sort/argsort implementations.
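The map-back utility defined below reduces flat argsort indices over a (rows, row_size) batch to per-row indices; the arithmetic is a single modulus. A host-side check of the intent (illustrative only):

#include <cassert>
#include <cstddef>

int main()
{
    const std::size_t row_size = 4;                    // length of each row
    const std::size_t linear_index = 2 * row_size + 3; // element (row 2, col 3)
    // map_back drops the row component; only the in-row index survives
    assert(linear_index % row_size == 3);
    return 0;
}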
+//===----------------------------------------------------------------------===// + +#pragma once + +#include +#include +#include +#include +#include + +#include + +namespace dpctl::tensor::kernels::sort_utils_detail +{ + +namespace syclexp = sycl::ext::oneapi::experimental; + +template +sycl::event iota_impl(sycl::queue &exec_q, + T *data, + std::size_t nelems, + const std::vector &dependent_events) +{ + static constexpr std::uint32_t lws = 256; + static constexpr std::uint32_t n_wi = 4; + const std::size_t n_groups = (nelems + n_wi * lws - 1) / (n_wi * lws); + + sycl::range<1> gRange{n_groups * lws}; + sycl::range<1> lRange{lws}; + sycl::nd_range<1> ndRange{gRange, lRange}; + + sycl::event e = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(dependent_events); + cgh.parallel_for(ndRange, [=](sycl::nd_item<1> it) { + const std::size_t gid = it.get_global_linear_id(); + const auto &sg = it.get_sub_group(); + const std::uint32_t lane_id = sg.get_local_id()[0]; + + const std::size_t offset = (gid - lane_id) * n_wi; + const std::uint32_t max_sgSize = sg.get_max_local_range()[0]; + + std::array stripe{}; +#pragma unroll + for (std::uint32_t i = 0; i < n_wi; ++i) { + stripe[i] = T(offset + lane_id + i * max_sgSize); + } + + if (offset + n_wi * max_sgSize < nelems) { + static constexpr auto group_ls_props = syclexp::properties{ + syclexp::data_placement_striped + // , syclexp::full_group + }; + + auto out_multi_ptr = sycl::address_space_cast< + sycl::access::address_space::global_space, + sycl::access::decorated::yes>(&data[offset]); + + syclexp::group_store(sg, sycl::span{&stripe[0], n_wi}, + out_multi_ptr, group_ls_props); + } + else { + for (std::size_t idx = offset + lane_id; idx < nelems; + idx += max_sgSize) { + data[idx] = T(idx); + } + } + }); + }); + + return e; +} + +template +sycl::event map_back_impl(sycl::queue &exec_q, + std::size_t nelems, + const IndexTy *flat_index_data, + IndexTy *reduced_index_data, + std::size_t row_size, + const std::vector &dependent_events) +{ + static constexpr std::uint32_t lws = 64; + static constexpr std::uint32_t n_wi = 4; + const std::size_t n_groups = (nelems + lws * n_wi - 1) / (n_wi * lws); + + sycl::range<1> lRange{lws}; + sycl::range<1> gRange{n_groups * lws}; + sycl::nd_range<1> ndRange{gRange, lRange}; + + sycl::event map_back_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(dependent_events); + + cgh.parallel_for(ndRange, [=](sycl::nd_item<1> it) { + const std::size_t gid = it.get_global_linear_id(); + const auto &sg = it.get_sub_group(); + const std::uint32_t lane_id = sg.get_local_id()[0]; + const std::uint32_t sg_size = sg.get_max_local_range()[0]; + + const std::size_t start_id = (gid - lane_id) * n_wi + lane_id; + +#pragma unroll + for (std::uint32_t i = 0; i < n_wi; ++i) { + const std::size_t data_id = start_id + i * sg_size; + + if (data_id < nelems) { + const IndexTy linear_index = flat_index_data[data_id]; + reduced_index_data[data_id] = (linear_index % row_size); + } + } + }); + }); + + return map_back_ev; +} + +} // namespace dpctl::tensor::kernels::sort_utils_detail diff --git a/dpctl_ext/tensor/libtensor/include/kernels/sorting/topk.hpp b/dpctl_ext/tensor/libtensor/include/kernels/sorting/topk.hpp new file mode 100644 index 000000000000..90e06a1e9eb8 --- /dev/null +++ b/dpctl_ext/tensor/libtensor/include/kernels/sorting/topk.hpp @@ -0,0 +1,510 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. 
+// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file defines kernels for tensor topk operation. +//===----------------------------------------------------------------------===// + +#pragma once + +#include +#include +#include +#include + +#include +#include + +#include "kernels/sorting/merge_sort.hpp" +#include "kernels/sorting/radix_sort.hpp" +#include "kernels/sorting/search_sorted_detail.hpp" +#include "kernels/sorting/sort_utils.hpp" +#include "utils/sycl_alloc_utils.hpp" + +namespace dpctl::tensor::kernels +{ + +namespace topk_detail +{ + +void scale_topk_params(const std::uint64_t nelems_per_slm, + const std::size_t sub_groups_per_work_group, + const std::uint32_t elems_per_wi, + const std::vector &sg_sizes, + std::size_t &lws, + std::size_t &nelems_wg_sorts) +{ + for (auto it = sg_sizes.rbegin(); it != sg_sizes.rend(); ++it) { + auto sg_size = *it; + lws = sub_groups_per_work_group * sg_size; + nelems_wg_sorts = elems_per_wi * lws; + if (nelems_wg_sorts < nelems_per_slm) { + return; + } + } + // should never reach + throw std::runtime_error("Could not construct top k kernel parameters"); +} + +template +sycl::event write_out_impl(sycl::queue &exec_q, + std::size_t iter_nelems, + std::size_t k, + const argTy *arg_tp, + const IndexTy *index_data, + std::size_t iter_index_stride, + std::size_t axis_nelems, + argTy *vals_tp, + IndexTy *inds_tp, + const std::vector &depends) +{ + static constexpr std::uint32_t lws = 64; + static constexpr std::uint32_t n_wi = 4; + const std::size_t nelems = iter_nelems * k; + const std::size_t n_groups = (nelems + lws * n_wi - 1) / (n_wi * lws); + + sycl::range<1> lRange{lws}; + sycl::range<1> gRange{n_groups * lws}; + sycl::nd_range<1> ndRange{gRange, lRange}; + + sycl::event write_out_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + cgh.parallel_for(ndRange, [=](sycl::nd_item<1> it) { + const 
std::size_t gid = it.get_global_linear_id(); + const auto &sg = it.get_sub_group(); + const std::uint32_t lane_id = sg.get_local_id()[0]; + const std::uint32_t sg_size = sg.get_max_local_range()[0]; + + const std::size_t start_id = (gid - lane_id) * n_wi + lane_id; + +#pragma unroll + for (std::uint32_t i = 0; i < n_wi; ++i) { + const std::size_t data_id = start_id + i * sg_size; + + if (data_id < nelems) { + const std::size_t iter_id = data_id / k; + + /* + const std::size_t axis_gid = data_id - (iter_gid * k); + const std::size_t src_idx = iter_gid * iter_index_stride + + axis_gid; + */ + const std::size_t src_idx = + data_id + iter_id * (iter_index_stride - k); + + const IndexTy res_ind = index_data[src_idx]; + const argTy v = arg_tp[res_ind]; + + const std::size_t dst_idx = data_id; + vals_tp[dst_idx] = v; + inds_tp[dst_idx] = (res_ind % axis_nelems); + } + } + }); + }); + + return write_out_ev; +} + +} // namespace topk_detail + +template +class topk_populate_index_data_krn; + +template +class topk_full_merge_map_back_krn; + +template +sycl::event + topk_full_merge_sort_impl(sycl::queue &exec_q, + std::size_t iter_nelems, // number of sub-arrays + std::size_t axis_nelems, // size of each sub-array + std::size_t k, + const argTy *arg_tp, + argTy *vals_tp, + IndexTy *inds_tp, + const CompT &comp, + const std::vector &depends) +{ + auto index_data_owner = + dpctl::tensor::alloc_utils::smart_malloc_device( + iter_nelems * axis_nelems, exec_q); + // extract USM pointer + IndexTy *index_data = index_data_owner.get(); + + using IotaKernelName = topk_populate_index_data_krn; + + using dpctl::tensor::kernels::sort_utils_detail::iota_impl; + + sycl::event populate_indexed_data_ev = iota_impl( + exec_q, index_data, iter_nelems * axis_nelems, depends); + + std::size_t sorted_block_size; + // Sort segments of the array + sycl::event base_sort_ev = + merge_sort_detail::sort_over_work_group_contig_impl( + exec_q, iter_nelems, axis_nelems, index_data, index_data, comp, + sorted_block_size, // modified in place with size of sorted block + // size + {populate_indexed_data_ev}); + + // Merge segments in parallel until all elements are sorted + sycl::event merges_ev = merge_sort_detail::merge_sorted_block_contig_impl( + exec_q, iter_nelems, axis_nelems, index_data, comp, sorted_block_size, + {base_sort_ev}); + + using WriteOutKernelName = topk_full_merge_map_back_krn; + + sycl::event write_out_ev = + topk_detail::write_out_impl( + exec_q, iter_nelems, k, arg_tp, index_data, axis_nelems, + axis_nelems, vals_tp, inds_tp, {merges_ev}); + + sycl::event cleanup_host_task_event = + dpctl::tensor::alloc_utils::async_smart_free(exec_q, {write_out_ev}, + index_data_owner); + + return cleanup_host_task_event; +}; + +template +class topk_partial_merge_map_back_krn; + +template +class topk_over_work_group_krn; + +template > +sycl::event topk_merge_impl( + sycl::queue &exec_q, + std::size_t iter_nelems, // number of sub-arrays to sort (num. of rows + // in a matrix when sorting over rows) + std::size_t axis_nelems, // size of each array to sort (length of + // rows, i.e. 
number of columns) + std::size_t k, + const char *arg_cp, + char *vals_cp, + char *inds_cp, + const std::vector &depends) +{ + if (axis_nelems < k) { + throw std::runtime_error("Invalid sort axis size for value of k"); + } + + const argTy *arg_tp = reinterpret_cast(arg_cp); + argTy *vals_tp = reinterpret_cast(vals_cp); + IndexTy *inds_tp = reinterpret_cast(inds_cp); + + using dpctl::tensor::kernels::IndexComp; + const IndexComp index_comp{arg_tp, ValueComp{}}; + + if (axis_nelems <= 512 || k >= 1024 || k > axis_nelems / 2) { + return topk_full_merge_sort_impl(exec_q, iter_nelems, axis_nelems, k, + arg_tp, vals_tp, inds_tp, index_comp, + depends); + } + else { + using PartialKernelName = + topk_over_work_group_krn; + + const auto &kernel_id = sycl::get_kernel_id(); + + auto const &ctx = exec_q.get_context(); + auto const &dev = exec_q.get_device(); + + auto kb = sycl::get_kernel_bundle( + ctx, {dev}, {kernel_id}); + + auto krn = kb.get_kernel(kernel_id); + + const std::uint32_t max_sg_size = krn.template get_info< + sycl::info::kernel_device_specific::max_sub_group_size>(dev); + const std::uint64_t device_local_memory_size = + dev.get_info(); + + // leave 512 bytes of local memory for RT + const std::uint64_t safety_margin = 512; + + const std::uint64_t nelems_per_slm = + (device_local_memory_size - safety_margin) / (2 * sizeof(IndexTy)); + + static constexpr std::uint32_t sub_groups_per_work_group = 4; + const std::uint32_t elems_per_wi = dev.has(sycl::aspect::cpu) ? 8 : 2; + + std::size_t lws = sub_groups_per_work_group * max_sg_size; + + std::size_t sorted_block_size = elems_per_wi * lws; + if (sorted_block_size > nelems_per_slm) { + const std::vector sg_sizes = + dev.get_info(); + topk_detail::scale_topk_params( + nelems_per_slm, sub_groups_per_work_group, elems_per_wi, + sg_sizes, + lws, // modified by reference + sorted_block_size // modified by reference + ); + } + + // This assumption permits doing away with using a loop + assert(sorted_block_size % lws == 0); + + using search_sorted_detail::quotient_ceil; + const std::size_t n_segments = + quotient_ceil(axis_nelems, sorted_block_size); + + // round k up for the later merge kernel if necessary + const std::size_t round_k_to = dev.has(sycl::aspect::cpu) ? 32 : 4; + std::size_t k_rounded = + (k < round_k_to) + ? k + : quotient_ceil(k, round_k_to) * round_k_to; + + // get length of tail for alloc size + auto rem = axis_nelems % sorted_block_size; + auto alloc_len = (rem && rem < k_rounded) + ? 
rem + k_rounded * (n_segments - 1) + : k_rounded * n_segments; + + // if allocation would be sufficiently large or k is larger than + // elements processed, use full sort + if (k_rounded >= axis_nelems || k_rounded >= sorted_block_size || + alloc_len >= axis_nelems / 2) + { + return topk_full_merge_sort_impl(exec_q, iter_nelems, axis_nelems, + k, arg_tp, vals_tp, inds_tp, + index_comp, depends); + } + + auto index_data_owner = + dpctl::tensor::alloc_utils::smart_malloc_device( + iter_nelems * alloc_len, exec_q); + // get raw USM pointer + IndexTy *index_data = index_data_owner.get(); + + // no need to populate index data: SLM will be populated with default + // values + + sycl::event base_sort_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + cgh.use_kernel_bundle(kb); + + sycl::range<1> global_range{iter_nelems * n_segments * lws}; + sycl::range<1> local_range{lws}; + + sycl::range<1> slm_range{sorted_block_size}; + sycl::local_accessor work_space(slm_range, cgh); + sycl::local_accessor scratch_space(slm_range, cgh); + + sycl::nd_range<1> ndRange(global_range, local_range); + + cgh.parallel_for( + ndRange, [=](sycl::nd_item<1> it) { + const std::size_t group_id = it.get_group_linear_id(); + const std::size_t iter_id = group_id / n_segments; + const std::size_t segment_id = + group_id - iter_id * n_segments; + const std::size_t lid = it.get_local_linear_id(); + + const std::size_t segment_start_idx = + segment_id * sorted_block_size; + const std::size_t segment_end_idx = std::min( + segment_start_idx + sorted_block_size, axis_nelems); + const std::size_t wg_chunk_size = + segment_end_idx - segment_start_idx; + + // load input into SLM + for (std::size_t array_id = segment_start_idx + lid; + array_id < segment_end_idx; array_id += lws) + { + IndexTy v = (array_id < axis_nelems) + ? 
iter_id * axis_nelems + array_id + : IndexTy{}; + work_space[array_id - segment_start_idx] = v; + } + sycl::group_barrier(it.get_group()); + + const std::size_t chunk = + quotient_ceil(sorted_block_size, lws); + + const std::size_t chunk_start_idx = lid * chunk; + const std::size_t chunk_end_idx = + sycl::min(chunk_start_idx + chunk, wg_chunk_size); + + merge_sort_detail::leaf_sort_impl( + work_space, chunk_start_idx, chunk_end_idx, index_comp); + + sycl::group_barrier(it.get_group()); + + bool data_in_temp = false; + std::size_t n_chunks_merged = 1; + + // merge chunk while n_chunks_merged * chunk < wg_chunk_size + const std::size_t max_chunks_merged = + 1 + ((wg_chunk_size - 1) / chunk); + for (; n_chunks_merged < max_chunks_merged; + data_in_temp = !data_in_temp, n_chunks_merged *= 2) + { + const std::size_t nelems_sorted_so_far = + n_chunks_merged * chunk; + const std::size_t q = (lid / n_chunks_merged); + const std::size_t start_1 = sycl::min( + 2 * nelems_sorted_so_far * q, wg_chunk_size); + const std::size_t end_1 = sycl::min( + start_1 + nelems_sorted_so_far, wg_chunk_size); + const std::size_t end_2 = sycl::min( + end_1 + nelems_sorted_so_far, wg_chunk_size); + const std::size_t offset = + chunk * (lid - q * n_chunks_merged); + + if (data_in_temp) { + merge_sort_detail::merge_impl( + offset, scratch_space, work_space, start_1, + end_1, end_2, start_1, index_comp, chunk); + } + else { + merge_sort_detail::merge_impl( + offset, work_space, scratch_space, start_1, + end_1, end_2, start_1, index_comp, chunk); + } + sycl::group_barrier(it.get_group()); + } + + // output assumed to be structured as (iter_nelems, + // alloc_len) + const std::size_t k_segment_start_idx = + segment_id * k_rounded; + const std::size_t k_segment_end_idx = std::min( + k_segment_start_idx + k_rounded, alloc_len); + const auto &out_src = + (data_in_temp) ? 
scratch_space : work_space; + for (std::size_t array_id = k_segment_start_idx + lid; + array_id < k_segment_end_idx; array_id += lws) + { + if (lid < k_rounded) { + index_data[iter_id * alloc_len + array_id] = + out_src[array_id - k_segment_start_idx]; + } + } + }); + }); + + // Merge segments in parallel until all elements are sorted + sycl::event merges_ev = + merge_sort_detail::merge_sorted_block_contig_impl( + exec_q, iter_nelems, alloc_len, index_data, index_comp, + k_rounded, {base_sort_ev}); + + // Write out top k of the merge-sorted memory + using WriteOutKernelName = + topk_partial_merge_map_back_krn; + + sycl::event write_topk_ev = + topk_detail::write_out_impl( + exec_q, iter_nelems, k, arg_tp, index_data, alloc_len, + axis_nelems, vals_tp, inds_tp, {merges_ev}); + + sycl::event cleanup_host_task_event = + dpctl::tensor::alloc_utils::async_smart_free( + exec_q, {write_topk_ev}, index_data_owner); + + return cleanup_host_task_event; + } +} + +template +class topk_iota_krn; + +template +class topk_radix_map_back_krn; + +template +sycl::event topk_radix_impl(sycl::queue &exec_q, + std::size_t iter_nelems, // number of sub-arrays + std::size_t axis_nelems, // size of each sub-array + std::size_t k, + bool ascending, + const char *arg_cp, + char *vals_cp, + char *inds_cp, + const std::vector &depends) +{ + if (axis_nelems < k) { + throw std::runtime_error("Invalid sort axis size for value of k"); + } + + const argTy *arg_tp = reinterpret_cast(arg_cp); + argTy *vals_tp = reinterpret_cast(vals_cp); + IndexTy *inds_tp = reinterpret_cast(inds_cp); + + const std::size_t total_nelems = iter_nelems * axis_nelems; + const std::size_t padded_total_nelems = ((total_nelems + 63) / 64) * 64; + auto workspace_owner = + dpctl::tensor::alloc_utils::smart_malloc_device( + padded_total_nelems + total_nelems, exec_q); + + // get raw USM pointer + IndexTy *workspace = workspace_owner.get(); + IndexTy *tmp_tp = workspace + padded_total_nelems; + + using IdentityProjT = radix_sort_details::IdentityProj; + using IndexedProjT = + radix_sort_details::IndexedProj; + const IndexedProjT proj_op{arg_tp}; + + using IotaKernelName = topk_iota_krn; + + using dpctl::tensor::kernels::sort_utils_detail::iota_impl; + + sycl::event iota_ev = iota_impl( + exec_q, workspace, total_nelems, depends); + + sycl::event radix_sort_ev = + radix_sort_details::parallel_radix_sort_impl( + exec_q, iter_nelems, axis_nelems, workspace, tmp_tp, proj_op, + ascending, {iota_ev}); + + // Write out top k of the temporary + using WriteOutKernelName = topk_radix_map_back_krn; + + sycl::event write_topk_ev = + topk_detail::write_out_impl( + exec_q, iter_nelems, k, arg_tp, tmp_tp, axis_nelems, axis_nelems, + vals_tp, inds_tp, {radix_sort_ev}); + + sycl::event cleanup_ev = dpctl::tensor::alloc_utils::async_smart_free( + exec_q, {write_topk_ev}, workspace_owner); + + return cleanup_ev; +} + +} // namespace dpctl::tensor::kernels diff --git a/dpctl_ext/tensor/libtensor/include/utils/rich_comparisons.hpp b/dpctl_ext/tensor/libtensor/include/utils/rich_comparisons.hpp new file mode 100644 index 000000000000..87cdfbfbd54f --- /dev/null +++ b/dpctl_ext/tensor/libtensor/include/utils/rich_comparisons.hpp @@ -0,0 +1,149 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. 
+// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_sorting_impl +/// extension. +//===---------------------------------------------------------------------===// + +#pragma once + +#include +#include +#include +#include + +#include "sycl/sycl.hpp" + +namespace dpctl::tensor::rich_comparisons +{ + +namespace detail +{ +template +struct ExtendedRealFPLess +{ + /* [R, nan] */ + bool operator()(const fpT v1, const fpT v2) const + { + return (!std::isnan(v1) && (std::isnan(v2) || (v1 < v2))); + } +}; + +template +struct ExtendedRealFPGreater +{ + bool operator()(const fpT v1, const fpT v2) const + { + return (!std::isnan(v2) && (std::isnan(v1) || (v2 < v1))); + } +}; + +template +struct ExtendedComplexFPLess +{ + /* [(R, R), (R, nan), (nan, R), (nan, nan)] */ + + bool operator()(const cT &v1, const cT &v2) const + { + using realT = typename cT::value_type; + + const realT real1 = std::real(v1); + const realT real2 = std::real(v2); + + const bool r1_nan = std::isnan(real1); + const bool r2_nan = std::isnan(real2); + + const realT imag1 = std::imag(v1); + const realT imag2 = std::imag(v2); + + const bool i1_nan = std::isnan(imag1); + const bool i2_nan = std::isnan(imag2); + + const int idx1 = ((r1_nan) ? 2 : 0) + ((i1_nan) ? 1 : 0); + const int idx2 = ((r2_nan) ? 2 : 0) + ((i2_nan) ? 
1 : 0); + + const bool res = + !(r1_nan && i1_nan) && + ((idx1 < idx2) || + ((idx1 == idx2) && + ((r1_nan && !i1_nan && (imag1 < imag2)) || + (!r1_nan && i1_nan && (real1 < real2)) || + (!r1_nan && !i1_nan && + ((real1 < real2) || (!(real2 < real1) && (imag1 < imag2))))))); + + return res; + } +}; + +template +struct ExtendedComplexFPGreater +{ + bool operator()(const cT &v1, const cT &v2) const + { + auto less_ = ExtendedComplexFPLess{}; + return less_(v2, v1); + } +}; + +template +inline constexpr bool is_fp_v = (std::is_same_v || + std::is_same_v || + std::is_same_v); + +} // namespace detail + +template +struct AscendingSorter +{ + using type = std::conditional_t, + detail::ExtendedRealFPLess, + std::less>; +}; + +template +struct AscendingSorter> +{ + using type = detail::ExtendedComplexFPLess>; +}; + +template +struct DescendingSorter +{ + using type = std::conditional_t, + detail::ExtendedRealFPGreater, + std::greater>; +}; + +template +struct DescendingSorter> +{ + using type = detail::ExtendedComplexFPGreater>; +}; + +} // namespace dpctl::tensor::rich_comparisons diff --git a/dpctl_ext/tensor/libtensor/source/sorting/isin.cpp b/dpctl_ext/tensor/libtensor/source/sorting/isin.cpp new file mode 100644 index 000000000000..728a7ca3b733 --- /dev/null +++ b/dpctl_ext/tensor/libtensor/source/sorting/isin.cpp @@ -0,0 +1,324 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_sorting_impl +/// extension. 
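The comparators defined above encode the ordering the sorting kernels need for IEEE types: ExtendedRealFPLess treats every non-NaN value as smaller than NaN, so NaNs always sort to the end, and the complex variant first ranks operands by which of their parts are NaN (the four classes [(R, R), (R, nan), (nan, R), (nan, nan)]) and only then compares lexicographically. A standalone sketch of the real-valued rule, usable with std::sort; the comparator name is illustrative:

#include <algorithm>
#include <cmath>
#include <iostream>
#include <vector>

// Sketch of the NaN-last ordering implemented by ExtendedRealFPLess:
// any non-NaN value compares less than NaN, so NaNs sort to the end.
struct NanLastLess
{
    bool operator()(double a, double b) const
    {
        return !std::isnan(a) && (std::isnan(b) || a < b);
    }
};

int main()
{
    std::vector<double> v{2.0, std::nan(""), -1.0, 0.5, std::nan("")};
    std::sort(v.begin(), v.end(), NanLastLess{});
    for (double x : v)
        std::cout << x << ' '; // -1 0.5 2 nan nan
    std::cout << '\n';
}

This is still a strict weak ordering (all NaNs are equivalent to one another), which is what makes it safe to hand to std::sort or to the merge-sort kernels.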
+//===----------------------------------------------------------------------===// + +#include +#include +#include +#include + +#include + +#include "dpnp4pybind11.hpp" +#include +#include + +#include "kernels/sorting/isin.hpp" +#include "utils/memory_overlap.hpp" +#include "utils/output_validation.hpp" +#include "utils/sycl_alloc_utils.hpp" +#include "utils/type_dispatch.hpp" + +#include "simplify_iteration_space.hpp" + +namespace py = pybind11; +namespace td_ns = dpctl::tensor::type_dispatch; + +namespace dpctl::tensor::py_internal +{ +namespace detail +{ + +using dpctl::tensor::kernels::isin_contig_impl_fp_ptr_t; + +static isin_contig_impl_fp_ptr_t + isin_contig_impl_dispatch_vector[td_ns::num_types]; + +template +struct IsinContigFactory +{ + constexpr IsinContigFactory() {} + + fnT get() const + { + using dpctl::tensor::kernels::isin_contig_impl; + return isin_contig_impl; + } +}; + +using dpctl::tensor::kernels::isin_strided_impl_fp_ptr_t; + +static isin_strided_impl_fp_ptr_t + isin_strided_impl_dispatch_vector[td_ns::num_types]; + +template +struct IsinStridedFactory +{ + constexpr IsinStridedFactory() {} + + fnT get() const + { + using dpctl::tensor::kernels::isin_strided_impl; + return isin_strided_impl; + } +}; + +void init_isin_dispatch_vector(void) +{ + + // Contiguous input function dispatch + td_ns::DispatchVectorBuilder + dvb1; + dvb1.populate_dispatch_vector(isin_contig_impl_dispatch_vector); + + // Strided input function dispatch + td_ns::DispatchVectorBuilder + dvb2; + dvb2.populate_dispatch_vector(isin_strided_impl_dispatch_vector); +} + +} // namespace detail + +/*! @brief search for needle from needles in sorted hay */ +std::pair + py_isin(const dpctl::tensor::usm_ndarray &needles, + const dpctl::tensor::usm_ndarray &hay, + const dpctl::tensor::usm_ndarray &dst, + sycl::queue &exec_q, + const bool invert, + const std::vector &depends) +{ + const int hay_nd = hay.get_ndim(); + const int needles_nd = needles.get_ndim(); + const int dst_nd = dst.get_ndim(); + + if (hay_nd != 1 || needles_nd != dst_nd) { + throw py::value_error("Array dimensions mismatch"); + } + + // check that needle and dst have the same shape + std::size_t needles_nelems(1); + bool same_shape(true); + + const std::size_t hay_nelems = static_cast(hay.get_shape(0)); + + const py::ssize_t *needles_shape_ptr = needles.get_shape_raw(); + const py::ssize_t *dst_shape_ptr = dst.get_shape_raw(); + + for (int i = 0; (i < needles_nd) && same_shape; ++i) { + const auto needles_sh_i = needles_shape_ptr[i]; + const auto dst_sh_i = dst_shape_ptr[i]; + + same_shape = same_shape && (needles_sh_i == dst_sh_i); + needles_nelems *= static_cast(needles_sh_i); + } + + if (!same_shape) { + throw py::value_error( + "Array of values to search for and array of their " + "dst do not have the same shape"); + } + + // check that dst is ample enough + dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(dst, + needles_nelems); + + // check that dst is writable + dpctl::tensor::validation::CheckWritable::throw_if_not_writable(dst); + + // check that queues are compatible + if (!dpctl::utils::queues_are_compatible(exec_q, {hay, needles, dst})) { + throw py::value_error( + "Execution queue is not compatible with allocation queues"); + } + + // if output array overlaps with input arrays, race condition results + auto const &overlap = dpctl::tensor::overlap::MemoryOverlap(); + if (overlap(dst, hay) || overlap(dst, needles)) { + throw py::value_error("Destination array overlaps with input."); + } + + const int hay_typenum = 
hay.get_typenum(); + const int needles_typenum = needles.get_typenum(); + const int dst_typenum = dst.get_typenum(); + + auto const &array_types = td_ns::usm_ndarray_types(); + const int hay_typeid = array_types.typenum_to_lookup_id(hay_typenum); + const int needles_typeid = + array_types.typenum_to_lookup_id(needles_typenum); + const int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum); + + // check hay and needle have the same data-type + if (needles_typeid != hay_typeid) { + throw py::value_error( + "Hay array and needles array must have the same data types"); + } + // check that dst has boolean data type + const auto dst_typenum_t_v = static_cast(dst_typeid); + if (dst_typenum_t_v != td_ns::typenum_t::BOOL) { + throw py::value_error("dst array must have data-type bool"); + } + + if (needles_nelems == 0) { + // Nothing to do + return std::make_pair(sycl::event{}, sycl::event{}); + } + + // if all inputs are contiguous call contiguous implementations + // otherwise call strided implementation + const bool hay_is_c_contig = hay.is_c_contiguous(); + const bool hay_is_f_contig = hay.is_f_contiguous(); + + const bool needles_is_c_contig = needles.is_c_contiguous(); + const bool needles_is_f_contig = needles.is_f_contiguous(); + + const bool dst_is_c_contig = dst.is_c_contiguous(); + const bool dst_is_f_contig = dst.is_f_contiguous(); + + const bool all_c_contig = + (hay_is_c_contig && needles_is_c_contig && dst_is_c_contig); + const bool all_f_contig = + (hay_is_f_contig && needles_is_f_contig && dst_is_f_contig); + + const char *hay_data = hay.get_data(); + const char *needles_data = needles.get_data(); + + char *dst_data = dst.get_data(); + + if (all_c_contig || all_f_contig) { + auto fn = detail::isin_contig_impl_dispatch_vector[hay_typeid]; + + static constexpr py::ssize_t zero_offset(0); + + sycl::event comp_ev = fn(exec_q, invert, hay_nelems, needles_nelems, + hay_data, zero_offset, needles_data, + zero_offset, dst_data, zero_offset, depends); + + return std::make_pair(dpctl::utils::keep_args_alive( + exec_q, {hay, needles, dst}, {comp_ev}), + comp_ev); + } + + // strided case + + const auto &needles_strides = needles.get_strides_vector(); + const auto &dst_strides = dst.get_strides_vector(); + + int simplified_nd = needles_nd; + + using shT = std::vector; + shT simplified_common_shape; + shT simplified_needles_strides; + shT simplified_dst_strides; + py::ssize_t needles_offset(0); + py::ssize_t dst_offset(0); + + if (simplified_nd == 0) { + // needles and dst have same nd + simplified_nd = 1; + simplified_common_shape.push_back(1); + simplified_needles_strides.push_back(0); + simplified_dst_strides.push_back(0); + } + else { + dpctl::tensor::py_internal::simplify_iteration_space( + // modified by reference + simplified_nd, + // read-only inputs + needles_shape_ptr, needles_strides, dst_strides, + // output, modified by reference + simplified_common_shape, simplified_needles_strides, + simplified_dst_strides, needles_offset, dst_offset); + } + std::vector host_task_events; + host_task_events.reserve(2); + + using dpctl::tensor::offset_utils::device_allocate_and_pack; + + auto ptr_size_event_tuple = device_allocate_and_pack( + exec_q, host_task_events, + // vectors being packed + simplified_common_shape, simplified_needles_strides, + simplified_dst_strides); + auto packed_shape_strides_owner = + std::move(std::get<0>(ptr_size_event_tuple)); + const sycl::event ©_shape_strides_ev = + std::get<2>(ptr_size_event_tuple); + const py::ssize_t *packed_shape_strides = 
packed_shape_strides_owner.get(); + + std::vector all_deps; + all_deps.reserve(depends.size() + 1); + all_deps.insert(all_deps.end(), depends.begin(), depends.end()); + all_deps.push_back(copy_shape_strides_ev); + + auto strided_fn = detail::isin_strided_impl_dispatch_vector[hay_typeid]; + + if (!strided_fn) { + throw std::runtime_error( + "No implementation for data types of input arrays"); + } + + static constexpr py::ssize_t zero_offset(0); + py::ssize_t hay_step = hay.get_strides_vector()[0]; + + const sycl::event &comp_ev = strided_fn( + exec_q, invert, hay_nelems, needles_nelems, hay_data, zero_offset, + hay_step, needles_data, needles_offset, dst_data, dst_offset, + simplified_nd, packed_shape_strides, all_deps); + + // free packed temporaries + sycl::event temporaries_cleanup_ev = + dpctl::tensor::alloc_utils::async_smart_free( + exec_q, {comp_ev}, packed_shape_strides_owner); + + host_task_events.push_back(temporaries_cleanup_ev); + const sycl::event &ht_ev = dpctl::utils::keep_args_alive( + exec_q, {hay, needles, dst}, host_task_events); + + return std::make_pair(ht_ev, comp_ev); +} + +void init_isin_functions(py::module_ m) +{ + dpctl::tensor::py_internal::detail::init_isin_dispatch_vector(); + + using dpctl::tensor::py_internal::py_isin; + m.def("_isin", &py_isin, py::arg("needles"), py::arg("hay"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("invert"), + py::arg("depends") = py::list()); +} + +} // namespace dpctl::tensor::py_internal diff --git a/dpctl_ext/tensor/libtensor/source/sorting/isin.hpp b/dpctl_ext/tensor/libtensor/source/sorting/isin.hpp new file mode 100644 index 000000000000..236e8b5898c6 --- /dev/null +++ b/dpctl_ext/tensor/libtensor/source/sorting/isin.hpp @@ -0,0 +1,47 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. 
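py_isin above only validates and dispatches; the kernels it selects implement a membership test of each needle against the sorted hay array, with invert flipping the result. A host-side sketch of that semantic using a binary search per needle (plain C++; the device kernels additionally handle strided needle/dst layouts through the packed shape-and-strides buffer, and the helper name here is illustrative):

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <vector>

// Host-side sketch of the membership test behind py_isin: hay is sorted, so
// each needle resolves with a binary search; `invert` flips the result the
// way the invert argument of _isin does.
template <typename T>
std::vector<bool> isin_sketch(const std::vector<T> &needles,
                              const std::vector<T> &sorted_hay, bool invert)
{
    std::vector<bool> dst(needles.size());
    for (std::size_t i = 0; i < needles.size(); ++i) {
        const bool found =
            std::binary_search(sorted_hay.begin(), sorted_hay.end(), needles[i]);
        dst[i] = invert ? !found : found;
    }
    return dst;
}

int main()
{
    const std::vector<int> hay{1, 3, 5, 7};
    const std::vector<int> needles{0, 3, 4, 7};
    for (bool b : isin_sketch(needles, hay, /*invert=*/false))
        std::cout << b << ' '; // 0 1 0 1
    std::cout << '\n';
}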
+//***************************************************************************** +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_sorting_impl +/// extension. +//===----------------------------------------------------------------------===// + +#pragma once + +#include + +namespace py = pybind11; + +namespace dpctl::tensor::py_internal +{ + +extern void init_isin_functions(py::module_ m); + +} // namespace dpctl::tensor::py_internal diff --git a/dpctl_ext/tensor/libtensor/source/sorting/merge_argsort.cpp b/dpctl_ext/tensor/libtensor/source/sorting/merge_argsort.cpp new file mode 100644 index 000000000000..f76a74e1f8bf --- /dev/null +++ b/dpctl_ext/tensor/libtensor/source/sorting/merge_argsort.cpp @@ -0,0 +1,159 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_sorting_impl +/// extension. 
+//===----------------------------------------------------------------------===// + +#include +#include + +#include + +#include "dpnp4pybind11.hpp" +#include +#include + +#include "utils/rich_comparisons.hpp" +#include "utils/type_dispatch.hpp" + +#include "kernels/sorting/merge_sort.hpp" +#include "kernels/sorting/sort_impl_fn_ptr_t.hpp" + +#include "merge_argsort.hpp" +#include "py_argsort_common.hpp" + +namespace dpctl::tensor::py_internal +{ + +namespace py = pybind11; +namespace td_ns = dpctl::tensor::type_dispatch; + +using dpctl::tensor::kernels::sort_contig_fn_ptr_t; +static sort_contig_fn_ptr_t + ascending_argsort_contig_dispatch_table[td_ns::num_types][td_ns::num_types]; +static sort_contig_fn_ptr_t + descending_argsort_contig_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + +template +struct AscendingArgSortContigFactory +{ + fnT get() + { + if constexpr (std::is_same_v || + std::is_same_v) + { + using dpctl::tensor::rich_comparisons::AscendingSorter; + using Comp = typename AscendingSorter::type; + + using dpctl::tensor::kernels::stable_argsort_axis1_contig_impl; + return stable_argsort_axis1_contig_impl; + } + else { + return nullptr; + } + } +}; + +template +struct DescendingArgSortContigFactory +{ + fnT get() + { + if constexpr (std::is_same_v || + std::is_same_v) + { + using dpctl::tensor::rich_comparisons::DescendingSorter; + using Comp = typename DescendingSorter::type; + + using dpctl::tensor::kernels::stable_argsort_axis1_contig_impl; + return stable_argsort_axis1_contig_impl; + } + else { + return nullptr; + } + } +}; + +void init_merge_argsort_dispatch_tables(void) +{ + using dpctl::tensor::kernels::sort_contig_fn_ptr_t; + + td_ns::DispatchTableBuilder + dtb1; + dtb1.populate_dispatch_table(ascending_argsort_contig_dispatch_table); + + td_ns::DispatchTableBuilder< + sort_contig_fn_ptr_t, DescendingArgSortContigFactory, td_ns::num_types> + dtb2; + dtb2.populate_dispatch_table(descending_argsort_contig_dispatch_table); +} + +void init_merge_argsort_functions(py::module_ m) +{ + dpctl::tensor::py_internal::init_merge_argsort_dispatch_tables(); + + auto py_argsort_ascending = [](const dpctl::tensor::usm_ndarray &src, + const int trailing_dims_to_sort, + const dpctl::tensor::usm_ndarray &dst, + sycl::queue &exec_q, + const std::vector &depends) + -> std::pair { + return dpctl::tensor::py_internal::py_argsort( + src, trailing_dims_to_sort, dst, exec_q, depends, + dpctl::tensor::py_internal:: + ascending_argsort_contig_dispatch_table); + }; + m.def("_argsort_ascending", py_argsort_ascending, py::arg("src"), + py::arg("trailing_dims_to_sort"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); + + auto py_argsort_descending = [](const dpctl::tensor::usm_ndarray &src, + const int trailing_dims_to_sort, + const dpctl::tensor::usm_ndarray &dst, + sycl::queue &exec_q, + const std::vector &depends) + -> std::pair { + return dpctl::tensor::py_internal::py_argsort( + src, trailing_dims_to_sort, dst, exec_q, depends, + dpctl::tensor::py_internal:: + descending_argsort_contig_dispatch_table); + }; + m.def("_argsort_descending", py_argsort_descending, py::arg("src"), + py::arg("trailing_dims_to_sort"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); + + return; +} + +} // namespace dpctl::tensor::py_internal diff --git a/dpctl_ext/tensor/libtensor/source/sorting/merge_argsort.hpp b/dpctl_ext/tensor/libtensor/source/sorting/merge_argsort.hpp new file mode 100644 index 000000000000..10777b4bc2fd --- /dev/null +++ 
b/dpctl_ext/tensor/libtensor/source/sorting/merge_argsort.hpp @@ -0,0 +1,47 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_sorting_impl +/// extension. +//===----------------------------------------------------------------------===// + +#pragma once + +#include + +namespace py = pybind11; + +namespace dpctl::tensor::py_internal +{ + +extern void init_merge_argsort_functions(py::module_); + +} // namespace dpctl::tensor::py_internal diff --git a/dpctl_ext/tensor/libtensor/source/sorting/merge_sort.cpp b/dpctl_ext/tensor/libtensor/source/sorting/merge_sort.cpp new file mode 100644 index 000000000000..7bb0e37c6aed --- /dev/null +++ b/dpctl_ext/tensor/libtensor/source/sorting/merge_sort.cpp @@ -0,0 +1,138 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. 
+// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_sorting_impl +/// extension. +//===----------------------------------------------------------------------===// + +#include + +#include "dpnp4pybind11.hpp" +#include +#include + +#include "utils/rich_comparisons.hpp" +#include "utils/type_dispatch.hpp" + +#include "kernels/sorting/merge_sort.hpp" +#include "kernels/sorting/sort_impl_fn_ptr_t.hpp" + +#include "merge_sort.hpp" +#include "py_sort_common.hpp" + +namespace td_ns = dpctl::tensor::type_dispatch; + +namespace dpctl::tensor::py_internal +{ + +using dpctl::tensor::kernels::sort_contig_fn_ptr_t; +static sort_contig_fn_ptr_t + ascending_sort_contig_dispatch_vector[td_ns::num_types]; +static sort_contig_fn_ptr_t + descending_sort_contig_dispatch_vector[td_ns::num_types]; + +template +struct AscendingSortContigFactory +{ + fnT get() + { + using dpctl::tensor::rich_comparisons::AscendingSorter; + using Comp = typename AscendingSorter::type; + + using dpctl::tensor::kernels::stable_sort_axis1_contig_impl; + return stable_sort_axis1_contig_impl; + } +}; + +template +struct DescendingSortContigFactory +{ + fnT get() + { + using dpctl::tensor::rich_comparisons::DescendingSorter; + using Comp = typename DescendingSorter::type; + + using dpctl::tensor::kernels::stable_sort_axis1_contig_impl; + return stable_sort_axis1_contig_impl; + } +}; + +void init_merge_sort_dispatch_vectors(void) +{ + using dpctl::tensor::kernels::sort_contig_fn_ptr_t; + + td_ns::DispatchVectorBuilder + dtv1; + dtv1.populate_dispatch_vector(ascending_sort_contig_dispatch_vector); + + td_ns::DispatchVectorBuilder + dtv2; + dtv2.populate_dispatch_vector(descending_sort_contig_dispatch_vector); +} + +void init_merge_sort_functions(py::module_ m) +{ + dpctl::tensor::py_internal::init_merge_sort_dispatch_vectors(); + + auto py_sort_ascending = [](const dpctl::tensor::usm_ndarray &src, + const int trailing_dims_to_sort, + const dpctl::tensor::usm_ndarray &dst, + sycl::queue &exec_q, + const std::vector &depends) + -> std::pair { + return dpctl::tensor::py_internal::py_sort( + src, trailing_dims_to_sort, dst, exec_q, depends, + dpctl::tensor::py_internal::ascending_sort_contig_dispatch_vector); + }; + m.def("_sort_ascending", py_sort_ascending, py::arg("src"), + py::arg("trailing_dims_to_sort"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); + + auto py_sort_descending = [](const dpctl::tensor::usm_ndarray &src, + const int trailing_dims_to_sort, + const dpctl::tensor::usm_ndarray &dst, + sycl::queue &exec_q, + const std::vector 
&depends) + -> std::pair { + return dpctl::tensor::py_internal::py_sort( + src, trailing_dims_to_sort, dst, exec_q, depends, + dpctl::tensor::py_internal::descending_sort_contig_dispatch_vector); + }; + m.def("_sort_descending", py_sort_descending, py::arg("src"), + py::arg("trailing_dims_to_sort"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); + + return; +} + +} // namespace dpctl::tensor::py_internal diff --git a/dpctl_ext/tensor/libtensor/source/sorting/merge_sort.hpp b/dpctl_ext/tensor/libtensor/source/sorting/merge_sort.hpp new file mode 100644 index 000000000000..a6bdd0a4efe9 --- /dev/null +++ b/dpctl_ext/tensor/libtensor/source/sorting/merge_sort.hpp @@ -0,0 +1,47 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_sorting_impl +/// extension. +//===----------------------------------------------------------------------===// + +#pragma once + +#include + +namespace py = pybind11; + +namespace dpctl::tensor::py_internal +{ + +extern void init_merge_sort_functions(py::module_); + +} // namespace dpctl::tensor::py_internal diff --git a/dpctl_ext/tensor/libtensor/source/sorting/py_argsort_common.hpp b/dpctl_ext/tensor/libtensor/source/sorting/py_argsort_common.hpp new file mode 100644 index 000000000000..d204b492b2ff --- /dev/null +++ b/dpctl_ext/tensor/libtensor/source/sorting/py_argsort_common.hpp @@ -0,0 +1,185 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. 
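The sort factories above pick the comparator type at compile time: AscendingSorter<argTy>::type resolves to the NaN-aware comparator for floating-point element types and to std::less otherwise. A compact sketch of that selection with std::conditional_t; the comparator below is a stand-in for rich_comparisons::ExtendedRealFPLess, not the extension's own type:

#include <cmath>
#include <functional>
#include <iostream>
#include <type_traits>

// Stand-in for rich_comparisons::ExtendedRealFPLess: non-NaN sorts before NaN.
template <typename fpT> struct ExtendedRealFPLessSketch
{
    bool operator()(const fpT v1, const fpT v2) const
    {
        return (!std::isnan(v1) && (std::isnan(v2) || (v1 < v2)));
    }
};

// Mirror of AscendingSorter's selection: NaN-aware comparator for
// floating-point element types, std::less for everything else.
template <typename T>
using ascending_comp_t =
    std::conditional_t<std::is_floating_point_v<T>,
                       ExtendedRealFPLessSketch<T>, std::less<T>>;

int main()
{
    static_assert(std::is_same_v<ascending_comp_t<int>, std::less<int>>);
    static_assert(std::is_same_v<ascending_comp_t<double>,
                                 ExtendedRealFPLessSketch<double>>);
    std::cout << "comparator types resolved at compile time\n";
}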
+// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_sorting_impl +/// extension. 
+//===----------------------------------------------------------------------===// + +#pragma once + +#include +#include +#include +#include + +#include + +#include "dpnp4pybind11.hpp" +#include +#include + +#include "utils/memory_overlap.hpp" +#include "utils/output_validation.hpp" +#include "utils/type_dispatch.hpp" + +namespace td_ns = dpctl::tensor::type_dispatch; + +namespace dpctl::tensor::py_internal +{ + +template +std::pair + py_argsort(const dpctl::tensor::usm_ndarray &src, + const int trailing_dims_to_sort, + const dpctl::tensor::usm_ndarray &dst, + sycl::queue &exec_q, + const std::vector &depends, + const sorting_contig_impl_fnT &sort_contig_fns) +{ + int src_nd = src.get_ndim(); + int dst_nd = dst.get_ndim(); + if (src_nd != dst_nd) { + throw py::value_error("The input and output arrays must have " + "the same array ranks"); + } + int iteration_nd = src_nd - trailing_dims_to_sort; + if (trailing_dims_to_sort <= 0 || iteration_nd < 0) { + throw py::value_error("trailing_dims_to_sort must be positive and no " + "greater than the rank of the array being sorted"); + } + + const py::ssize_t *src_shape_ptr = src.get_shape_raw(); + const py::ssize_t *dst_shape_ptr = dst.get_shape_raw(); + + bool same_shapes = true; + std::size_t iter_nelems(1); + + for (int i = 0; same_shapes && (i < iteration_nd); ++i) { + auto src_shape_i = src_shape_ptr[i]; + same_shapes = same_shapes && (src_shape_i == dst_shape_ptr[i]); + iter_nelems *= static_cast(src_shape_i); + } + + std::size_t sort_nelems(1); + for (int i = iteration_nd; same_shapes && (i < src_nd); ++i) { + auto src_shape_i = src_shape_ptr[i]; + same_shapes = same_shapes && (src_shape_i == dst_shape_ptr[i]); + sort_nelems *= static_cast(src_shape_i); + } + + if (!same_shapes) { + throw py::value_error( + "Destination shape does not match the input shape"); + } + + if (!dpctl::utils::queues_are_compatible(exec_q, {src, dst})) { + throw py::value_error( + "Execution queue is not compatible with allocation queues"); + } + + dpctl::tensor::validation::CheckWritable::throw_if_not_writable(dst); + + if ((iter_nelems == 0) || (sort_nelems == 0)) { + // Nothing to do + return std::make_pair(sycl::event(), sycl::event()); + } + + // check that dst and src do not overlap + auto const &overlap = dpctl::tensor::overlap::MemoryOverlap(); + if (overlap(src, dst)) { + throw py::value_error("Arrays index overlapping segments of memory"); + } + + dpctl::tensor::validation::AmpleMemory::throw_if_not_ample( + dst, sort_nelems * iter_nelems); + + int src_typenum = src.get_typenum(); + int dst_typenum = dst.get_typenum(); + + const auto &array_types = td_ns::usm_ndarray_types(); + int src_typeid = array_types.typenum_to_lookup_id(src_typenum); + int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum); + + if ((dst_typeid != static_cast(td_ns::typenum_t::INT64)) && + (dst_typeid != static_cast(td_ns::typenum_t::INT32))) + { + throw py::value_error( + "Output index array must have data type int32 or int64"); + } + + bool is_src_c_contig = src.is_c_contiguous(); + bool is_dst_c_contig = dst.is_c_contiguous(); + + if (is_src_c_contig && is_dst_c_contig) { + if (sort_nelems > 1) { + static constexpr py::ssize_t zero_offset = py::ssize_t(0); + + auto fn = sort_contig_fns[src_typeid][dst_typeid]; + + if (fn == nullptr) { + throw py::value_error( + "Not implemented for dtypes of input arrays"); + } + + sycl::event comp_ev = + fn(exec_q, iter_nelems, sort_nelems, src.get_data(), + dst.get_data(), zero_offset, zero_offset, zero_offset, + zero_offset, depends); + +
sycl::event keep_args_alive_ev = + dpctl::utils::keep_args_alive(exec_q, {src, dst}, {comp_ev}); + + return std::make_pair(keep_args_alive_ev, comp_ev); + } + else { + assert(dst.get_size() == iter_nelems); + int dst_elemsize = dst.get_elemsize(); + static constexpr int memset_val(0); + + sycl::event fill_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + cgh.memset(reinterpret_cast(dst.get_data()), memset_val, + iter_nelems * dst_elemsize); + }); + + sycl::event keep_args_alive_ev = + dpctl::utils::keep_args_alive(exec_q, {src, dst}, {fill_ev}); + + return std::make_pair(keep_args_alive_ev, fill_ev); + } + } + + throw py::value_error( + "Both source and destination arrays must be C-contiguous"); +} + +} // namespace dpctl::tensor::py_internal diff --git a/dpctl_ext/tensor/libtensor/source/sorting/py_sort_common.hpp b/dpctl_ext/tensor/libtensor/source/sorting/py_sort_common.hpp new file mode 100644 index 000000000000..0fdf7bec80fd --- /dev/null +++ b/dpctl_ext/tensor/libtensor/source/sorting/py_sort_common.hpp @@ -0,0 +1,179 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_sorting_impl +/// extension. 
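A small worked example of the shape bookkeeping used by py_argsort above (and by py_sort below): the leading nd - trailing_dims_to_sort axes become independent iteration segments, and the trailing axes form the contiguous run that actually gets sorted. Sketch, with illustrative helper names:

#include <cstddef>
#include <iostream>
#include <vector>

// Sketch of the iter_nelems/sort_nelems split in py_argsort/py_sort.
// Assumes 0 < trailing_dims_to_sort <= shape.size(), as those entry
// points validate before reaching this computation.
struct Split
{
    std::size_t iter_nelems;
    std::size_t sort_nelems;
};

Split split_shape(const std::vector<std::size_t> &shape, int trailing_dims_to_sort)
{
    Split s{1, 1};
    const std::size_t nd = shape.size();
    const std::size_t iteration_nd = nd - trailing_dims_to_sort;
    for (std::size_t i = 0; i < nd; ++i)
        (i < iteration_nd ? s.iter_nelems : s.sort_nelems) *= shape[i];
    return s;
}

int main()
{
    // shape (3, 4, 5), sorting the last axis: 12 independent rows of 5
    auto s = split_shape({3, 4, 5}, 1);
    std::cout << s.iter_nelems << " segments of " << s.sort_nelems << "\n"; // 12, 5
}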
+//===----------------------------------------------------------------------===// + +#pragma once + +#include +#include +#include +#include + +#include + +#include "dpnp4pybind11.hpp" +#include +#include + +#include "utils/memory_overlap.hpp" +#include "utils/output_validation.hpp" +#include "utils/type_dispatch.hpp" + +namespace py = pybind11; +namespace td_ns = dpctl::tensor::type_dispatch; + +namespace dpctl::tensor::py_internal +{ + +template +std::pair + py_sort(const dpctl::tensor::usm_ndarray &src, + const int trailing_dims_to_sort, + const dpctl::tensor::usm_ndarray &dst, + sycl::queue &exec_q, + const std::vector &depends, + const sorting_contig_impl_fnT &sort_contig_fns) +{ + int src_nd = src.get_ndim(); + int dst_nd = dst.get_ndim(); + if (src_nd != dst_nd) { + throw py::value_error("The input and output arrays must have " + "the same array ranks"); + } + int iteration_nd = src_nd - trailing_dims_to_sort; + if (trailing_dims_to_sort <= 0 || iteration_nd < 0) { + throw py::value_error("trailing_dims_to_sort must be positive and no " + "greater than the rank of the array being sorted"); + } + + const py::ssize_t *src_shape_ptr = src.get_shape_raw(); + const py::ssize_t *dst_shape_ptr = dst.get_shape_raw(); + + bool same_shapes = true; + std::size_t iter_nelems(1); + + for (int i = 0; same_shapes && (i < iteration_nd); ++i) { + auto src_shape_i = src_shape_ptr[i]; + same_shapes = same_shapes && (src_shape_i == dst_shape_ptr[i]); + iter_nelems *= static_cast(src_shape_i); + } + + std::size_t sort_nelems(1); + for (int i = iteration_nd; same_shapes && (i < src_nd); ++i) { + auto src_shape_i = src_shape_ptr[i]; + same_shapes = same_shapes && (src_shape_i == dst_shape_ptr[i]); + sort_nelems *= static_cast(src_shape_i); + } + + if (!same_shapes) { + throw py::value_error( + "Destination shape does not match the input shape"); + } + + if (!dpctl::utils::queues_are_compatible(exec_q, {src, dst})) { + throw py::value_error( + "Execution queue is not compatible with allocation queues"); + } + + dpctl::tensor::validation::CheckWritable::throw_if_not_writable(dst); + + if ((iter_nelems == 0) || (sort_nelems == 0)) { + // Nothing to do + return std::make_pair(sycl::event(), sycl::event()); + } + + // check that dst and src do not overlap + auto const &overlap = dpctl::tensor::overlap::MemoryOverlap(); + if (overlap(src, dst)) { + throw py::value_error("Arrays index overlapping segments of memory"); + } + + dpctl::tensor::validation::AmpleMemory::throw_if_not_ample( + dst, sort_nelems * iter_nelems); + + int src_typenum = src.get_typenum(); + int dst_typenum = dst.get_typenum(); + + const auto &array_types = td_ns::usm_ndarray_types(); + int src_typeid = array_types.typenum_to_lookup_id(src_typenum); + int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum); + + if (src_typeid != dst_typeid) { + throw py::value_error("Both input arrays must have " + "the same value data type"); + } + + bool is_src_c_contig = src.is_c_contiguous(); + bool is_dst_c_contig = dst.is_c_contiguous(); + + if (is_src_c_contig && is_dst_c_contig) { + if (sort_nelems > 1) { + static constexpr py::ssize_t zero_offset = py::ssize_t(0); + + auto fn = sort_contig_fns[src_typeid]; + + if (nullptr == fn) { + throw py::value_error( + "Not implemented for the dtype of input arrays"); + } + + sycl::event comp_ev = + fn(exec_q, iter_nelems, sort_nelems, src.get_data(), + dst.get_data(), zero_offset, zero_offset, zero_offset, + zero_offset, depends); + + sycl::event keep_args_alive_ev = + dpctl::utils::keep_args_alive(exec_q,
{src, dst}, {comp_ev}); + + return std::make_pair(keep_args_alive_ev, comp_ev); + } + else { + assert(dst.get_size() == iter_nelems); + int src_elemsize = src.get_elemsize(); + + sycl::event copy_ev = + exec_q.copy(src.get_data(), dst.get_data(), + src_elemsize * iter_nelems, depends); + + return std::make_pair( + dpctl::utils::keep_args_alive(exec_q, {src, dst}, {copy_ev}), + copy_ev); + } + } + + throw py::value_error( + "Both source and destination arrays must be C-contiguous"); +} + +} // namespace dpctl::tensor::py_internal diff --git a/dpctl_ext/tensor/libtensor/source/sorting/radix_argsort.cpp b/dpctl_ext/tensor/libtensor/source/sorting/radix_argsort.cpp new file mode 100644 index 000000000000..ba9bbfd61b7c --- /dev/null +++ b/dpctl_ext/tensor/libtensor/source/sorting/radix_argsort.cpp @@ -0,0 +1,190 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_sorting_impl +/// extension. 
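One detail of the single-element branches worth spelling out: when every segment to sort has length one, sorting is the identity permutation, so py_sort above degenerates to a bulk copy of src into dst, while py_argsort earlier degenerates to zero-filling the index array. A trivial host-side illustration:

#include <cstring>
#include <iostream>

// Why the sort_nelems == 1 branch in py_sort is just a copy: sorting
// one-element segments changes nothing, so the sorted output equals the
// input bytes; py_argsort's counterpart writes the identity index, which
// for a single element is always 0 (hence the memset).
int main()
{
    const double src[4] = {4.0, 2.0, 8.0, 6.0}; // 4 segments, 1 element each
    double dst[4];
    std::memcpy(dst, src, sizeof(src)); // the entire "sort"
    long idx[4] = {};                   // the entire "argsort": all zeros
    std::cout << dst[2] << ' ' << idx[2] << '\n'; // 8 0
}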
+//===----------------------------------------------------------------------===// + +#include +#include +#include +#include +#include + +#include + +#include "dpnp4pybind11.hpp" +#include +#include + +#include "utils/type_dispatch.hpp" + +#include "kernels/dpctl_tensor_types.hpp" +#include "kernels/sorting/radix_sort.hpp" +#include "kernels/sorting/sort_impl_fn_ptr_t.hpp" + +#include "py_argsort_common.hpp" +#include "radix_argsort.hpp" +#include "radix_sort_support.hpp" + +namespace dpctl::tensor::py_internal +{ + +namespace py = pybind11; +namespace td_ns = dpctl::tensor::type_dispatch; +namespace impl_ns = dpctl::tensor::kernels::radix_sort_details; + +using dpctl::tensor::kernels::sort_contig_fn_ptr_t; + +static sort_contig_fn_ptr_t + ascending_radix_argsort_contig_dispatch_table[td_ns::num_types] + [td_ns::num_types]; +static sort_contig_fn_ptr_t + descending_radix_argsort_contig_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + +namespace +{ + +template +sycl::event argsort_axis1_contig_caller(sycl::queue &q, + std::size_t iter_nelems, + std::size_t sort_nelems, + const char *arg_cp, + char *res_cp, + ssize_t iter_arg_offset, + ssize_t iter_res_offset, + ssize_t sort_arg_offset, + ssize_t sort_res_offset, + const std::vector &depends) +{ + using dpctl::tensor::kernels::radix_argsort_axis1_contig_impl; + + return radix_argsort_axis1_contig_impl( + q, is_ascending, iter_nelems, sort_nelems, arg_cp, res_cp, + iter_arg_offset, iter_res_offset, sort_arg_offset, sort_res_offset, + depends); +} + +} // end of anonymous namespace + +template +struct AscendingRadixArgSortContigFactory +{ + fnT get() + { + if constexpr (RadixSortSupportVector::is_defined && + (std::is_same_v || + std::is_same_v)) + { + return argsort_axis1_contig_caller< + /*ascending*/ true, argTy, IndexTy>; + } + else { + return nullptr; + } + } +}; + +template +struct DescendingRadixArgSortContigFactory +{ + fnT get() + { + if constexpr (RadixSortSupportVector::is_defined && + (std::is_same_v || + std::is_same_v)) + { + return argsort_axis1_contig_caller< + /*ascending*/ false, argTy, IndexTy>; + } + else { + return nullptr; + } + } +}; + +void init_radix_argsort_dispatch_tables(void) +{ + using dpctl::tensor::kernels::sort_contig_fn_ptr_t; + + td_ns::DispatchTableBuilder + dtb1; + dtb1.populate_dispatch_table(ascending_radix_argsort_contig_dispatch_table); + + td_ns::DispatchTableBuilder + dtb2; + dtb2.populate_dispatch_table( + descending_radix_argsort_contig_dispatch_table); +} + +void init_radix_argsort_functions(py::module_ m) +{ + dpctl::tensor::py_internal::init_radix_argsort_dispatch_tables(); + + auto py_radix_argsort_ascending = + [](const dpctl::tensor::usm_ndarray &src, + const int trailing_dims_to_sort, + const dpctl::tensor::usm_ndarray &dst, sycl::queue &exec_q, + const std::vector &depends) + -> std::pair { + return dpctl::tensor::py_internal::py_argsort( + src, trailing_dims_to_sort, dst, exec_q, depends, + dpctl::tensor::py_internal:: + ascending_radix_argsort_contig_dispatch_table); + }; + m.def("_radix_argsort_ascending", py_radix_argsort_ascending, + py::arg("src"), py::arg("trailing_dims_to_sort"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); + + auto py_radix_argsort_descending = + [](const dpctl::tensor::usm_ndarray &src, + const int trailing_dims_to_sort, + const dpctl::tensor::usm_ndarray &dst, sycl::queue &exec_q, + const std::vector &depends) + -> std::pair { + return dpctl::tensor::py_internal::py_argsort( + src, trailing_dims_to_sort, dst, exec_q, depends, + 
dpctl::tensor::py_internal:: + descending_radix_argsort_contig_dispatch_table); + }; + m.def("_radix_argsort_descending", py_radix_argsort_descending, + py::arg("src"), py::arg("trailing_dims_to_sort"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); + + return; +} + +} // namespace dpctl::tensor::py_internal diff --git a/dpctl_ext/tensor/libtensor/source/sorting/radix_argsort.hpp b/dpctl_ext/tensor/libtensor/source/sorting/radix_argsort.hpp new file mode 100644 index 000000000000..89013fbb1bdc --- /dev/null +++ b/dpctl_ext/tensor/libtensor/source/sorting/radix_argsort.hpp @@ -0,0 +1,47 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_sorting_impl +/// extension. +//===----------------------------------------------------------------------===// + +#pragma once + +#include + +namespace py = pybind11; + +namespace dpctl::tensor::py_internal +{ + +extern void init_radix_argsort_functions(py::module_); + +} // namespace dpctl::tensor::py_internal diff --git a/dpctl_ext/tensor/libtensor/source/sorting/radix_sort.cpp b/dpctl_ext/tensor/libtensor/source/sorting/radix_sort.cpp new file mode 100644 index 000000000000..fb3cab487df8 --- /dev/null +++ b/dpctl_ext/tensor/libtensor/source/sorting/radix_sort.cpp @@ -0,0 +1,192 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. 
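argsort_axis1_contig_caller above bakes the sort direction in as a non-type template parameter, so ascending and descending get separate instantiations chosen once while the dispatch tables are built, instead of a runtime branch on every kernel submission. The pattern in miniature, with illustrative function names:

#include <iostream>

// Sketch of the compile-time direction dispatch used by
// argsort_axis1_contig_caller: one template, two instantiations, selected
// once at dispatch-table build time.
template <bool is_ascending> int extreme(int a, int b)
{
    if constexpr (is_ascending)
        return a < b ? a : b; // ascending order puts the minimum first
    else
        return a > b ? a : b;
}

int main()
{
    using fn_t = int (*)(int, int);
    fn_t table[2] = {&extreme<false>, &extreme<true>}; // built at init time
    std::cout << table[1](3, 7) << ' ' << table[0](3, 7) << '\n'; // 3 7
}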
+// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_sorting_impl +/// extension. +//===----------------------------------------------------------------------===// + +#include +#include +#include +#include +#include + +#include + +#include "dpnp4pybind11.hpp" +#include +#include + +#include "utils/type_dispatch.hpp" + +#include "kernels/dpctl_tensor_types.hpp" +#include "kernels/sorting/radix_sort.hpp" +#include "kernels/sorting/sort_impl_fn_ptr_t.hpp" + +#include "py_sort_common.hpp" +#include "radix_sort.hpp" +#include "radix_sort_support.hpp" + +namespace dpctl::tensor::py_internal +{ + +namespace py = pybind11; +namespace td_ns = dpctl::tensor::type_dispatch; +namespace impl_ns = dpctl::tensor::kernels::radix_sort_details; + +using dpctl::tensor::kernels::sort_contig_fn_ptr_t; +static sort_contig_fn_ptr_t + ascending_radix_sort_contig_dispatch_vector[td_ns::num_types]; +static sort_contig_fn_ptr_t + descending_radix_sort_contig_dispatch_vector[td_ns::num_types]; + +namespace +{ + +template +sycl::event sort_axis1_contig_caller(sycl::queue &q, + std::size_t iter_nelems, + std::size_t sort_nelems, + const char *arg_cp, + char *res_cp, + ssize_t iter_arg_offset, + ssize_t iter_res_offset, + ssize_t sort_arg_offset, + ssize_t sort_res_offset, + const std::vector &depends) +{ + using dpctl::tensor::kernels::radix_sort_axis1_contig_impl; + + return radix_sort_axis1_contig_impl( + q, is_ascending, iter_nelems, sort_nelems, arg_cp, res_cp, + iter_arg_offset, iter_res_offset, sort_arg_offset, sort_res_offset, + depends); +} + +} // end of anonymous namespace + +template +struct AscendingRadixSortContigFactory +{ + fnT get() + { + if constexpr (RadixSortSupportVector::is_defined) { + return sort_axis1_contig_caller; + } + else { + return nullptr; + } + } +}; + +template +struct DescendingRadixSortContigFactory +{ + fnT get() + { + if constexpr (RadixSortSupportVector::is_defined) { + return sort_axis1_contig_caller; + } + else { + return nullptr; + } + } +}; + +void init_radix_sort_dispatch_vectors(void) +{ + using dpctl::tensor::kernels::sort_contig_fn_ptr_t; + + td_ns::DispatchVectorBuilder< + 
sort_contig_fn_ptr_t, AscendingRadixSortContigFactory, td_ns::num_types> + dtv1; + dtv1.populate_dispatch_vector(ascending_radix_sort_contig_dispatch_vector); + + td_ns::DispatchVectorBuilder + dtv2; + dtv2.populate_dispatch_vector(descending_radix_sort_contig_dispatch_vector); +} + +bool py_radix_sort_defined(int typenum) +{ + const auto &array_types = td_ns::usm_ndarray_types(); + + try { + int type_id = array_types.typenum_to_lookup_id(typenum); + return (nullptr != + ascending_radix_sort_contig_dispatch_vector[type_id]); + } catch (const std::exception &e) { + return false; + } +} + +void init_radix_sort_functions(py::module_ m) +{ + dpctl::tensor::py_internal::init_radix_sort_dispatch_vectors(); + + auto py_radix_sort_ascending = [](const dpctl::tensor::usm_ndarray &src, + const int trailing_dims_to_sort, + const dpctl::tensor::usm_ndarray &dst, + sycl::queue &exec_q, + const std::vector &depends) + -> std::pair { + return dpctl::tensor::py_internal::py_sort( + src, trailing_dims_to_sort, dst, exec_q, depends, + dpctl::tensor::py_internal:: + ascending_radix_sort_contig_dispatch_vector); + }; + m.def("_radix_sort_ascending", py_radix_sort_ascending, py::arg("src"), + py::arg("trailing_dims_to_sort"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); + + auto py_radix_sort_descending = [](const dpctl::tensor::usm_ndarray &src, + const int trailing_dims_to_sort, + const dpctl::tensor::usm_ndarray &dst, + sycl::queue &exec_q, + const std::vector &depends) + -> std::pair { + return dpctl::tensor::py_internal::py_sort( + src, trailing_dims_to_sort, dst, exec_q, depends, + dpctl::tensor::py_internal:: + descending_radix_sort_contig_dispatch_vector); + }; + m.def("_radix_sort_descending", py_radix_sort_descending, py::arg("src"), + py::arg("trailing_dims_to_sort"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); + + m.def("_radix_sort_dtype_supported", py_radix_sort_defined); + + return; +} + +} // namespace dpctl::tensor::py_internal diff --git a/dpctl_ext/tensor/libtensor/source/sorting/radix_sort.hpp b/dpctl_ext/tensor/libtensor/source/sorting/radix_sort.hpp new file mode 100644 index 000000000000..5f3c771b464b --- /dev/null +++ b/dpctl_ext/tensor/libtensor/source/sorting/radix_sort.hpp @@ -0,0 +1,47 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. 
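_radix_sort_dtype_supported, bound above, follows a capability-query pattern: translate the typenum to a lookup id and report whether the corresponding dispatch slot is populated, turning a failed lookup into false rather than an exception surfacing in Python. A standalone sketch of the same shape; table contents and the helper are illustrative:

#include <exception>
#include <iostream>
#include <stdexcept>

using fn_t = void (*)();

// A toy dispatch vector: slot 0 intentionally left unsupported.
static fn_t dispatch_vector[3] = {nullptr, [] {}, [] {}};

int typenum_to_lookup_id(int typenum)
{
    if (typenum < 0 || typenum >= 3)
        throw std::runtime_error("unknown typenum");
    return typenum;
}

// Mirrors py_radix_sort_defined: a populated slot means "supported", and an
// unrecognized typenum reports false instead of propagating the exception.
bool radix_sort_defined(int typenum)
{
    try {
        return dispatch_vector[typenum_to_lookup_id(typenum)] != nullptr;
    } catch (const std::exception &) {
        return false;
    }
}

int main()
{
    std::cout << radix_sort_defined(0) << radix_sort_defined(1)
              << radix_sort_defined(42) << '\n'; // 010
}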
diff --git a/dpctl_ext/tensor/libtensor/source/sorting/radix_sort.hpp b/dpctl_ext/tensor/libtensor/source/sorting/radix_sort.hpp
new file mode 100644
index 000000000000..5f3c771b464b
--- /dev/null
+++ b/dpctl_ext/tensor/libtensor/source/sorting/radix_sort.hpp
@@ -0,0 +1,47 @@
+//*****************************************************************************
+// Copyright (c) 2026, Intel Corporation
+// All rights reserved.
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are met:
+// - Redistributions of source code must retain the above copyright notice,
+// this list of conditions and the following disclaimer.
+// - Redistributions in binary form must reproduce the above copyright notice,
+// this list of conditions and the following disclaimer in the documentation
+// and/or other materials provided with the distribution.
+// - Neither the name of the copyright holder nor the names of its contributors
+// may be used to endorse or promote products derived from this software
+// without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
+// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
+// THE POSSIBILITY OF SUCH DAMAGE.
+//*****************************************************************************
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file
+/// This file defines functions of dpctl.tensor._tensor_sorting_impl
+/// extension.
+//===----------------------------------------------------------------------===//
+
+#pragma once
+
+#include <pybind11/pybind11.h>
+
+namespace py = pybind11;
+
+namespace dpctl::tensor::py_internal
+{
+
+extern void init_radix_sort_functions(py::module_);
+
+} // namespace dpctl::tensor::py_internal
diff --git a/dpctl_ext/tensor/libtensor/source/sorting/radix_sort_support.hpp b/dpctl_ext/tensor/libtensor/source/sorting/radix_sort_support.hpp
new file mode 100644
index 000000000000..8d7e55a5cd28
--- /dev/null
+++ b/dpctl_ext/tensor/libtensor/source/sorting/radix_sort_support.hpp
@@ -0,0 +1,78 @@
+//*****************************************************************************
+// Copyright (c) 2026, Intel Corporation
+// All rights reserved.
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are met:
+// - Redistributions of source code must retain the above copyright notice,
+// this list of conditions and the following disclaimer.
+// - Redistributions in binary form must reproduce the above copyright notice,
+// this list of conditions and the following disclaimer in the documentation
+// and/or other materials provided with the distribution.
+// - Neither the name of the copyright holder nor the names of its contributors
+// may be used to endorse or promote products derived from this software
+// without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
+// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
+// THE POSSIBILITY OF SUCH DAMAGE.
+//*****************************************************************************
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file
+/// This file defines functions of dpctl.tensor._tensor_sorting_impl
+/// extension.
+//===----------------------------------------------------------------------===//
+
+#pragma once
+
+#include <cstdint>
+#include <type_traits>
+
+#include <sycl/sycl.hpp>
+
+namespace dpctl::tensor::py_internal
+{
+
+template <typename T, typename ArgTy>
+struct TypeDefinedEntry : std::bool_constant<std::is_same_v<T, ArgTy>>
+{
+    static constexpr bool is_defined = true;
+};
+
+struct NotDefinedEntry : std::true_type
+{
+    static constexpr bool is_defined = false;
+};
+
+template <typename T>
+struct RadixSortSupportVector
+{
+    using resolver_t =
+        typename std::disjunction<TypeDefinedEntry<T, bool>,
+                                  TypeDefinedEntry<T, std::int8_t>,
+                                  TypeDefinedEntry<T, std::uint8_t>,
+                                  TypeDefinedEntry<T, std::int16_t>,
+                                  TypeDefinedEntry<T, std::uint16_t>,
+                                  TypeDefinedEntry<T, std::int32_t>,
+                                  TypeDefinedEntry<T, std::uint32_t>,
+                                  TypeDefinedEntry<T, std::int64_t>,
+                                  TypeDefinedEntry<T, std::uint64_t>,
+                                  TypeDefinedEntry<T, sycl::half>,
+                                  TypeDefinedEntry<T, float>,
+                                  TypeDefinedEntry<T, double>,
+                                  NotDefinedEntry>;
+
+    static constexpr bool is_defined = resolver_t::is_defined;
+};
+
+} // namespace dpctl::tensor::py_internal
diff --git a/dpctl_ext/tensor/libtensor/source/sorting/searchsorted.cpp b/dpctl_ext/tensor/libtensor/source/sorting/searchsorted.cpp
new file mode 100644
index 000000000000..5cebb74be2d0
--- /dev/null
+++ b/dpctl_ext/tensor/libtensor/source/sorting/searchsorted.cpp
@@ -0,0 +1,478 @@
+//*****************************************************************************
+// Copyright (c) 2026, Intel Corporation
+// All rights reserved.
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are met:
+// - Redistributions of source code must retain the above copyright notice,
+// this list of conditions and the following disclaimer.
+// - Redistributions in binary form must reproduce the above copyright notice,
+// this list of conditions and the following disclaimer in the documentation
+// and/or other materials provided with the distribution.
+// - Neither the name of the copyright holder nor the names of its contributors
+// may be used to endorse or promote products derived from this software
+// without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
+// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
+// THE POSSIBILITY OF SUCH DAMAGE.
+//*****************************************************************************
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file
+/// This file defines functions of dpctl.tensor._tensor_sorting_impl
+/// extension.
+//===----------------------------------------------------------------------===//
+
+#include <cstddef>
+#include <cstdint>
+#include <stdexcept>
+#include <utility>
+#include <vector>
+
+#include <sycl/sycl.hpp>
+
+#include "dpnp4pybind11.hpp"
+#include <pybind11/pybind11.h>
+#include <pybind11/stl.h>
+
+#include "kernels/sorting/searchsorted.hpp"
+#include "utils/memory_overlap.hpp"
+#include "utils/output_validation.hpp"
+#include "utils/rich_comparisons.hpp"
+#include "utils/sycl_alloc_utils.hpp"
+#include "utils/type_dispatch.hpp"
+
+#include "simplify_iteration_space.hpp"
+
+namespace py = pybind11;
+namespace td_ns = dpctl::tensor::type_dispatch;
+
+namespace dpctl::tensor::py_internal
+{
+
+namespace detail
+{
+
+using dpctl::tensor::kernels::searchsorted_contig_impl_fp_ptr_t;
+
+static searchsorted_contig_impl_fp_ptr_t
+    left_side_searchsorted_contig_impl[td_ns::num_types][td_ns::num_types];
+
+static searchsorted_contig_impl_fp_ptr_t
+    right_side_searchsorted_contig_impl[td_ns::num_types][td_ns::num_types];
+
+template <typename fnT, typename argTy, typename indTy>
+struct LeftSideSearchSortedContigFactory
+{
+    constexpr LeftSideSearchSortedContigFactory() {}
+
+    fnT get() const
+    {
+        if constexpr (std::is_same_v<indTy, std::int32_t> ||
+                      std::is_same_v<indTy, std::int64_t>)
+        {
+            static constexpr bool left_side_search(true);
+            using dpctl::tensor::kernels::searchsorted_contig_impl;
+            using dpctl::tensor::rich_comparisons::AscendingSorter;
+
+            using Compare = typename AscendingSorter<argTy>::type;
+
+            return searchsorted_contig_impl<argTy, indTy, left_side_search,
+                                            Compare>;
+        }
+        else {
+            return nullptr;
+        }
+    }
+};
+
+template <typename fnT, typename argTy, typename indTy>
+struct RightSideSearchSortedContigFactory
+{
+    constexpr RightSideSearchSortedContigFactory() {}
+
+    fnT get() const
+    {
+        if constexpr (std::is_same_v<indTy, std::int32_t> ||
+                      std::is_same_v<indTy, std::int64_t>)
+        {
+            static constexpr bool right_side_search(false);
+
+            using dpctl::tensor::kernels::searchsorted_contig_impl;
+            using dpctl::tensor::rich_comparisons::AscendingSorter;
+
+            using Compare = typename AscendingSorter<argTy>::type;
+
+            return searchsorted_contig_impl<argTy, indTy, right_side_search,
+                                            Compare>;
+        }
+        else {
+            return nullptr;
+        }
+    }
+};
+
+using dpctl::tensor::kernels::searchsorted_strided_impl_fp_ptr_t;
+
+static searchsorted_strided_impl_fp_ptr_t
+    left_side_searchsorted_strided_impl[td_ns::num_types][td_ns::num_types];
+
+static searchsorted_strided_impl_fp_ptr_t
+    right_side_searchsorted_strided_impl[td_ns::num_types][td_ns::num_types];
+
+template <typename fnT, typename argTy, typename indTy>
+struct LeftSideSearchSortedStridedFactory
+{
+    constexpr LeftSideSearchSortedStridedFactory() {}
+
+    fnT get() const
+    {
+        if constexpr (std::is_same_v<indTy, std::int32_t> ||
+                      std::is_same_v<indTy, std::int64_t>)
+        {
+            static constexpr bool left_side_search(true);
+            using dpctl::tensor::kernels::searchsorted_strided_impl;
+            using dpctl::tensor::rich_comparisons::AscendingSorter;
+
+            using Compare = typename AscendingSorter<argTy>::type;
+
+            return searchsorted_strided_impl<argTy, indTy, left_side_search,
+                                             Compare>;
+        }
+        else {
+            return nullptr;
+        }
+    }
+};
+
+template <typename fnT, typename argTy, typename indTy>
+struct RightSideSearchSortedStridedFactory
+{
+    constexpr RightSideSearchSortedStridedFactory() {}
+
+    fnT get() const
+    {
+        if constexpr (std::is_same_v<indTy, std::int32_t> ||
+                      std::is_same_v<indTy, std::int64_t>)
+        {
+            static constexpr bool right_side_search(false);
+            using dpctl::tensor::kernels::searchsorted_strided_impl;
+            using dpctl::tensor::rich_comparisons::AscendingSorter;
+
+            using Compare = typename AscendingSorter<argTy>::type;
+
+            return searchsorted_strided_impl<argTy, indTy, right_side_search,
+                                             Compare>;
+        }
+        else {
+            return nullptr;
+        }
+    }
+};
+
+void init_searchsorted_dispatch_table(void)
+{
+
+    // Contiguous input function dispatch
+    td_ns::DispatchTableBuilder<searchsorted_contig_impl_fp_ptr_t,
+                                LeftSideSearchSortedContigFactory,
+                                td_ns::num_types>
+        dtb1;
+    dtb1.populate_dispatch_table(left_side_searchsorted_contig_impl);
+
+    td_ns::DispatchTableBuilder<searchsorted_contig_impl_fp_ptr_t,
+                                RightSideSearchSortedContigFactory,
+                                td_ns::num_types>
+        dtb2;
+    dtb2.populate_dispatch_table(right_side_searchsorted_contig_impl);
+
+    // Strided input function dispatch
+    td_ns::DispatchTableBuilder<searchsorted_strided_impl_fp_ptr_t,
+                                LeftSideSearchSortedStridedFactory,
+                                td_ns::num_types>
+        dtb3;
+    dtb3.populate_dispatch_table(left_side_searchsorted_strided_impl);
+
+    td_ns::DispatchTableBuilder<searchsorted_strided_impl_fp_ptr_t,
+                                RightSideSearchSortedStridedFactory,
+                                td_ns::num_types>
+        dtb4;
+    dtb4.populate_dispatch_table(right_side_searchsorted_strided_impl);
+}
+
+} // namespace detail
+
+/*! @brief search for needle from needles in sorted hay */
+std::pair<sycl::event, sycl::event>
+    py_searchsorted(const dpctl::tensor::usm_ndarray &hay,
+                    const dpctl::tensor::usm_ndarray &needles,
+                    const dpctl::tensor::usm_ndarray &positions,
+                    sycl::queue &exec_q,
+                    const bool search_left_side,
+                    const std::vector<sycl::event> &depends)
+{
+    const int hay_nd = hay.get_ndim();
+    const int needles_nd = needles.get_ndim();
+    const int positions_nd = positions.get_ndim();
+
+    if (hay_nd != 1 || needles_nd != positions_nd) {
+        throw py::value_error("Array dimensions mismatch");
+    }
+
+    // check that needles and positions have the same shape
+    std::size_t needles_nelems(1);
+    bool same_shape(true);
+
+    const std::size_t hay_nelems = static_cast<std::size_t>(hay.get_shape(0));
+
+    const py::ssize_t *needles_shape_ptr = needles.get_shape_raw();
+    const py::ssize_t *positions_shape_ptr = positions.get_shape_raw();
+
+    for (int i = 0; (i < needles_nd) && same_shape; ++i) {
+        const auto needles_sh_i = needles_shape_ptr[i];
+        const auto positions_sh_i = positions_shape_ptr[i];
+
+        same_shape = same_shape && (needles_sh_i == positions_sh_i);
+        needles_nelems *= static_cast<std::size_t>(needles_sh_i);
+    }
+
+    if (!same_shape) {
+        throw py::value_error(
+            "Array of values to search for and array of their "
+            "positions do not have the same shape");
+    }
+
+    // check that positions is ample enough
+    dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(positions,
+                                                               needles_nelems);
+
+    // check that positions is writable
+    dpctl::tensor::validation::CheckWritable::throw_if_not_writable(positions);
+
+    // check that queues are compatible
+    if (!dpctl::utils::queues_are_compatible(exec_q, {hay, needles, positions}))
+    {
+        throw py::value_error(
+            "Execution queue is not compatible with allocation queues");
+    }
+
+    // if output array overlaps with input arrays, race condition results
+    auto const &overlap = dpctl::tensor::overlap::MemoryOverlap();
+    if (overlap(positions, hay) || overlap(positions, needles)) {
+        throw py::value_error("Destination array overlaps with input.");
+    }
+
+    const int hay_typenum = hay.get_typenum();
+    const int needles_typenum = needles.get_typenum();
+    const int positions_typenum = positions.get_typenum();
+
+    auto const &array_types = td_ns::usm_ndarray_types();
+    const int hay_typeid = array_types.typenum_to_lookup_id(hay_typenum);
+    const int needles_typeid =
+        array_types.typenum_to_lookup_id(needles_typenum);
+    const int positions_typeid =
+        array_types.typenum_to_lookup_id(positions_typenum);
+
+    // check that hay and needles have the same data-type
+    if (needles_typeid != hay_typeid) {
+        throw py::value_error(
+            "Hay array and needles array must have the same data types");
+    }
+    // check that positions has indexing data-type (int32, or int64)
+    const auto positions_typenum_t_v =
+        static_cast<td_ns::typenum_t>(positions_typeid);
+    if (positions_typenum_t_v != td_ns::typenum_t::INT32 &&
+        positions_typenum_t_v != td_ns::typenum_t::INT64)
+    {
+        throw py::value_error(
+            "Positions array must have data-type int32, or int64");
+    }
+
+    if (needles_nelems == 0) {
+        // Nothing to do
+        return std::make_pair(sycl::event{}, sycl::event{});
+    }
+
+    // if all inputs are contiguous call contiguous implementations
+    // otherwise call strided implementation
+    const bool hay_is_c_contig = hay.is_c_contiguous();
+    const bool hay_is_f_contig = hay.is_f_contiguous();
+
+    const bool needles_is_c_contig = needles.is_c_contiguous();
+    const bool needles_is_f_contig = needles.is_f_contiguous();
+
+    const bool positions_is_c_contig = positions.is_c_contiguous();
+    const bool positions_is_f_contig = positions.is_f_contiguous();
+
+    const bool all_c_contig =
+        (hay_is_c_contig && needles_is_c_contig && positions_is_c_contig);
+    const bool all_f_contig =
+        (hay_is_f_contig && needles_is_f_contig && positions_is_f_contig);
+
+    const char *hay_data = hay.get_data();
+    const char *needles_data = needles.get_data();
+
+    char *positions_data = positions.get_data();
+
+    if (all_c_contig || all_f_contig) {
+        auto fn =
+            (search_left_side)
+                ? detail::left_side_searchsorted_contig_impl[hay_typeid]
+                                                            [positions_typeid]
+                : detail::right_side_searchsorted_contig_impl[hay_typeid]
+                                                             [positions_typeid];
+
+        if (fn) {
+            static constexpr py::ssize_t zero_offset(0);
+
+            sycl::event comp_ev =
+                fn(exec_q, hay_nelems, needles_nelems, hay_data, zero_offset,
+                   needles_data, zero_offset, positions_data, zero_offset,
+                   depends);
+
+            return std::make_pair(
+                dpctl::utils::keep_args_alive(exec_q,
+                                              {hay, needles, positions},
+                                              {comp_ev}),
+                comp_ev);
+        }
+    }
+
+    // strided case
+
+    const auto &needles_strides = needles.get_strides_vector();
+    const auto &positions_strides = positions.get_strides_vector();
+
+    int simplified_nd = needles_nd;
+
+    using shT = std::vector<py::ssize_t>;
+    shT simplified_common_shape;
+    shT simplified_needles_strides;
+    shT simplified_positions_strides;
+    py::ssize_t needles_offset(0);
+    py::ssize_t positions_offset(0);
+
+    if (simplified_nd == 0) {
+        // needles and positions have same nd
+        simplified_nd = 1;
+        simplified_common_shape.push_back(1);
+        simplified_needles_strides.push_back(0);
+        simplified_positions_strides.push_back(0);
+    }
+    else {
+        dpctl::tensor::py_internal::simplify_iteration_space(
+            // modified by reference
+            simplified_nd,
+            // read-only inputs
+            needles_shape_ptr, needles_strides, positions_strides,
+            // output, modified by reference
+            simplified_common_shape, simplified_needles_strides,
+            simplified_positions_strides, needles_offset, positions_offset);
+    }
+    std::vector<sycl::event> host_task_events;
+    host_task_events.reserve(2);
+
+    using dpctl::tensor::offset_utils::device_allocate_and_pack;
+
+    auto ptr_size_event_tuple = device_allocate_and_pack<py::ssize_t>(
+        exec_q, host_task_events,
+        // vectors being packed
+        simplified_common_shape, simplified_needles_strides,
+        simplified_positions_strides);
+    auto packed_shape_strides_owner =
+        std::move(std::get<0>(ptr_size_event_tuple));
+    const sycl::event &copy_shape_strides_ev =
+        std::get<2>(ptr_size_event_tuple);
+    const py::ssize_t *packed_shape_strides = packed_shape_strides_owner.get();
+
+    std::vector<sycl::event> all_deps;
+    all_deps.reserve(depends.size() + 1);
+    all_deps.insert(all_deps.end(), depends.begin(), depends.end());
+    all_deps.push_back(copy_shape_strides_ev);
+
+    auto strided_fn =
+        (search_left_side)
+            ? detail::left_side_searchsorted_strided_impl[hay_typeid]
+                                                         [positions_typeid]
+            : detail::right_side_searchsorted_strided_impl[hay_typeid]
+                                                          [positions_typeid];
+
+    if (!strided_fn) {
+        throw std::runtime_error(
+            "No implementation for data types of input arrays");
+    }
+
+    static constexpr py::ssize_t zero_offset(0);
+    py::ssize_t hay_step = hay.get_strides_vector()[0];
+
+    const sycl::event &comp_ev = strided_fn(
+        exec_q, hay_nelems, needles_nelems, hay_data, zero_offset, hay_step,
+        needles_data, needles_offset, positions_data, positions_offset,
+        simplified_nd, packed_shape_strides, all_deps);
+
+    // free packed temporaries
+    sycl::event temporaries_cleanup_ev =
+        dpctl::tensor::alloc_utils::async_smart_free(
+            exec_q, {comp_ev}, packed_shape_strides_owner);
+
+    host_task_events.push_back(temporaries_cleanup_ev);
+    const sycl::event &ht_ev = dpctl::utils::keep_args_alive(
+        exec_q, {hay, needles, positions}, host_task_events);
+
+    return std::make_pair(ht_ev, comp_ev);
+}
+
+/*! @brief search for needle from needles in sorted hay,
+ *   hay[pos] <= needle < hay[pos + 1]
+ */
+std::pair<sycl::event, sycl::event>
+    py_searchsorted_left(const dpctl::tensor::usm_ndarray &hay,
+                         const dpctl::tensor::usm_ndarray &needles,
+                         const dpctl::tensor::usm_ndarray &positions,
+                         sycl::queue &exec_q,
+                         const std::vector<sycl::event> &depends)
+{
+    static constexpr bool side_left(true);
+    return py_searchsorted(hay, needles, positions, exec_q, side_left,
+                           depends);
+}
+
+/*! @brief search for needle from needles in sorted hay,
+ *   hay[pos] < needle <= hay[pos + 1]
+ */
+std::pair<sycl::event, sycl::event>
+    py_searchsorted_right(const dpctl::tensor::usm_ndarray &hay,
+                          const dpctl::tensor::usm_ndarray &needles,
+                          const dpctl::tensor::usm_ndarray &positions,
+                          sycl::queue &exec_q,
+                          const std::vector<sycl::event> &depends)
+{
+    static constexpr bool side_right(false);
+    return py_searchsorted(hay, needles, positions, exec_q, side_right,
+                           depends);
+}
+
+void init_searchsorted_functions(py::module_ m)
+{
+    dpctl::tensor::py_internal::detail::init_searchsorted_dispatch_table();
+
+    using dpctl::tensor::py_internal::py_searchsorted_left;
+    using dpctl::tensor::py_internal::py_searchsorted_right;
+
+    m.def("_searchsorted_left", &py_searchsorted_left, py::arg("hay"),
+          py::arg("needles"), py::arg("positions"), py::arg("sycl_queue"),
+          py::arg("depends") = py::list());
+    m.def("_searchsorted_right", &py_searchsorted_right, py::arg("hay"),
+          py::arg("needles"), py::arg("positions"), py::arg("sycl_queue"),
+          py::arg("depends") = py::list());
+}
+
+} // namespace dpctl::tensor::py_internal
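As a sanity check on the side semantics implemented above (side="left": first i with hay[i-1] < v <= hay[i]; side="right": first i with hay[i-1] <= v < hay[i]), a minimal sketch using the public wrapper this patch adds, assuming a default SYCL device is available:

    import dpctl_ext.tensor as dpt_ext

    hay = dpt_ext.asarray([1, 2, 2, 3])  # 1-D, sorted ascending
    needles = dpt_ext.asarray([2])

    print(dpt_ext.searchsorted(hay, needles, side="left"))   # expected: [1]
    print(dpt_ext.searchsorted(hay, needles, side="right"))  # expected: [3]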
diff --git a/dpctl_ext/tensor/libtensor/source/sorting/searchsorted.hpp b/dpctl_ext/tensor/libtensor/source/sorting/searchsorted.hpp
new file mode 100644
index 000000000000..b60dae1e0ec9
--- /dev/null
+++ b/dpctl_ext/tensor/libtensor/source/sorting/searchsorted.hpp
@@ -0,0 +1,47 @@
+//*****************************************************************************
+// Copyright (c) 2026, Intel Corporation
+// All rights reserved.
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are met:
+// - Redistributions of source code must retain the above copyright notice,
+// this list of conditions and the following disclaimer.
+// - Redistributions in binary form must reproduce the above copyright notice,
+// this list of conditions and the following disclaimer in the documentation
+// and/or other materials provided with the distribution.
+// - Neither the name of the copyright holder nor the names of its contributors
+// may be used to endorse or promote products derived from this software
+// without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
+// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
+// THE POSSIBILITY OF SUCH DAMAGE.
+//*****************************************************************************
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file
+/// This file defines functions of dpctl.tensor._tensor_sorting_impl
+/// extension.
+//===----------------------------------------------------------------------===//
+
+#pragma once
+
+#include <pybind11/pybind11.h>
+
+namespace py = pybind11;
+
+namespace dpctl::tensor::py_internal
+{
+
+extern void init_searchsorted_functions(py::module_ m);
+
+} // namespace dpctl::tensor::py_internal
diff --git a/dpctl_ext/tensor/libtensor/source/sorting/topk.cpp b/dpctl_ext/tensor/libtensor/source/sorting/topk.cpp
new file mode 100644
index 000000000000..23a8d2a4328f
--- /dev/null
+++ b/dpctl_ext/tensor/libtensor/source/sorting/topk.cpp
@@ -0,0 +1,303 @@
+//*****************************************************************************
+// Copyright (c) 2026, Intel Corporation
+// All rights reserved.
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are met:
+// - Redistributions of source code must retain the above copyright notice,
+// this list of conditions and the following disclaimer.
+// - Redistributions in binary form must reproduce the above copyright notice,
+// this list of conditions and the following disclaimer in the documentation
+// and/or other materials provided with the distribution.
+// - Neither the name of the copyright holder nor the names of its contributors
+// may be used to endorse or promote products derived from this software
+// without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
+// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
+// THE POSSIBILITY OF SUCH DAMAGE.
+//*****************************************************************************
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file
+/// This file defines functions of dpctl.tensor._tensor_sorting_impl
+/// extension.
+//===----------------------------------------------------------------------===//
+
+#include <cstddef>
+#include <cstdint>
+#include <optional>
+#include <stdexcept>
+#include <utility>
+#include <vector>
+
+#include <sycl/sycl.hpp>
+
+#include "dpnp4pybind11.hpp"
+#include <pybind11/pybind11.h>
+#include <pybind11/stl.h>
+
+#include "kernels/sorting/topk.hpp"
+#include "utils/memory_overlap.hpp"
+#include "utils/output_validation.hpp"
+#include "utils/rich_comparisons.hpp"
+#include "utils/type_dispatch.hpp"
+
+#include "topk.hpp"
+
+namespace dpctl::tensor::py_internal
+{
+
+namespace td_ns = dpctl::tensor::type_dispatch;
+
+typedef sycl::event (*topk_impl_fn_ptr_t)(sycl::queue &,
+                                          std::size_t,
+                                          std::size_t,
+                                          std::size_t,
+                                          bool,
+                                          const char *,
+                                          char *,
+                                          char *,
+                                          const std::vector<sycl::event> &);
+
+static topk_impl_fn_ptr_t topk_dispatch_vector[td_ns::num_types];
+
+namespace
+{
+
+template <typename T, typename = void>
+struct use_radix_sort : public std::false_type
+{
+};
+
+template <typename T>
+struct use_radix_sort<
+    T,
+    std::enable_if_t<std::disjunction<std::is_same<T, bool>,
+                                      std::is_same<T, std::int8_t>,
+                                      std::is_same<T, std::uint8_t>,
+                                      std::is_same<T, std::int16_t>,
+                                      std::is_same<T, std::uint16_t>>::value>>
+    : public std::true_type
+{
+};
+
+template <typename argTy, typename IndexTy>
+sycl::event topk_caller(sycl::queue &exec_q,
+                        std::size_t iter_nelems, // number of sub-arrays
+                        std::size_t axis_nelems, // size of each sub-array
+                        std::size_t k,
+                        bool largest,
+                        const char *arg_cp,
+                        char *vals_cp,
+                        char *inds_cp,
+                        const std::vector<sycl::event> &depends)
+{
+    if constexpr (use_radix_sort<argTy>::value) {
+        using dpctl::tensor::kernels::topk_radix_impl;
+        auto ascending = !largest;
+        return topk_radix_impl<argTy, IndexTy>(exec_q, iter_nelems,
+                                               axis_nelems, k, ascending,
+                                               arg_cp, vals_cp, inds_cp,
+                                               depends);
+    }
+    else {
+        using dpctl::tensor::kernels::topk_merge_impl;
+        if (largest) {
+            using CompTy =
+                typename dpctl::tensor::rich_comparisons::DescendingSorter<
+                    argTy>::type;
+            return topk_merge_impl<argTy, IndexTy, CompTy>(
+                exec_q, iter_nelems, axis_nelems, k, arg_cp, vals_cp, inds_cp,
+                depends);
+        }
+        else {
+            using CompTy =
+                typename dpctl::tensor::rich_comparisons::AscendingSorter<
+                    argTy>::type;
+            return topk_merge_impl<argTy, IndexTy, CompTy>(
+                exec_q, iter_nelems, axis_nelems, k, arg_cp, vals_cp, inds_cp,
+                depends);
+        }
+    }
+}
+
+} // namespace
+
+std::pair<sycl::event, sycl::event>
+    py_topk(const dpctl::tensor::usm_ndarray &src,
+            std::optional<const int> trailing_dims_to_search,
+            const std::size_t k,
+            const bool largest,
+            const dpctl::tensor::usm_ndarray &vals,
+            const dpctl::tensor::usm_ndarray &inds,
+            sycl::queue &exec_q,
+            const std::vector<sycl::event> &depends)
+{
+    int src_nd = src.get_ndim();
+    int vals_nd = vals.get_ndim();
+    int inds_nd = inds.get_ndim();
+
+    const py::ssize_t *src_shape_ptr = src.get_shape_raw();
+    const py::ssize_t *vals_shape_ptr = vals.get_shape_raw();
+    const py::ssize_t *inds_shape_ptr = inds.get_shape_raw();
+
+    std::size_t axis_nelems(1);
+    std::size_t iter_nelems(1);
+    if (trailing_dims_to_search.has_value()) {
+        if (src_nd != vals_nd || src_nd != inds_nd) {
+            throw py::value_error("The input and output arrays must have "
+                                  "the same array ranks");
+        }
+
+        auto trailing_dims = trailing_dims_to_search.value();
+        int iter_nd = src_nd - trailing_dims;
+        if (trailing_dims <= 0 || iter_nd < 0) {
+            throw py::value_error(
+                "trailing_dims_to_search must be positive, but no "
+                "greater than rank of the array being searched");
+        }
+
+        bool same_shapes = true;
+        for (int i = 0; same_shapes && (i < iter_nd); ++i) {
+            auto src_shape_i = src_shape_ptr[i];
+            same_shapes = same_shapes && (src_shape_i == vals_shape_ptr[i] &&
+                                          src_shape_i == inds_shape_ptr[i]);
+            iter_nelems *= static_cast<std::size_t>(src_shape_i);
+        }
+
+        if (!same_shapes) {
+            throw py::value_error(
+                "Destination shape does not match the input shape");
+        }
+
+        std::size_t vals_k(1);
+        std::size_t inds_k(1);
+        for (int i = iter_nd; i < src_nd; ++i) {
+            axis_nelems *= static_cast<std::size_t>(src_shape_ptr[i]);
+            vals_k *= static_cast<std::size_t>(vals_shape_ptr[i]);
+            inds_k *= static_cast<std::size_t>(inds_shape_ptr[i]);
+        }
+
+        bool valid_k = (vals_k == k && inds_k == k && axis_nelems >= k);
+        if (!valid_k) {
+            throw py::value_error("The value of k is invalid for the input "
+                                  "and destination arrays");
+        }
+    }
+    else {
+        if (vals_nd != 1 || inds_nd != 1) {
+            throw py::value_error("Output arrays must be one-dimensional");
+        }
+
+        for (int i = 0; i < src_nd; ++i) {
+            axis_nelems *= static_cast<std::size_t>(src_shape_ptr[i]);
+        }
+
+        bool valid_k = (axis_nelems >= k &&
+                        static_cast<std::size_t>(vals_shape_ptr[0]) == k &&
+                        static_cast<std::size_t>(inds_shape_ptr[0]) == k);
+        if (!valid_k) {
+            throw py::value_error("The value of k is invalid for the input "
+                                  "and destination arrays");
+        }
+    }
+
+    if (!dpctl::utils::queues_are_compatible(exec_q, {src, vals, inds})) {
+        throw py::value_error(
+            "Execution queue is not compatible with allocation queues");
+    }
+
+    dpctl::tensor::validation::CheckWritable::throw_if_not_writable(vals);
+    dpctl::tensor::validation::CheckWritable::throw_if_not_writable(inds);
+
+    if ((iter_nelems == 0) || (axis_nelems == 0)) {
+        // Nothing to do
+        return std::make_pair(sycl::event(), sycl::event());
+    }
+
+    auto const &overlap = dpctl::tensor::overlap::MemoryOverlap();
+    if (overlap(src, vals) || overlap(src, inds)) {
+        throw py::value_error("Arrays index overlapping segments of memory");
+    }
+
+    dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(vals,
+                                                               k * iter_nelems);
+
+    dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(inds,
+                                                               k * iter_nelems);
+
+    int src_typenum = src.get_typenum();
+    int vals_typenum = vals.get_typenum();
+    int inds_typenum = inds.get_typenum();
+
+    const auto &array_types = td_ns::usm_ndarray_types();
+    int src_typeid = array_types.typenum_to_lookup_id(src_typenum);
+    int vals_typeid = array_types.typenum_to_lookup_id(vals_typenum);
+    int inds_typeid = array_types.typenum_to_lookup_id(inds_typenum);
+
+    if (src_typeid != vals_typeid) {
+        throw py::value_error("Input array and vals array must have "
+                              "the same data type");
+    }
+
+    if (inds_typeid != static_cast<int>(td_ns::typenum_t::INT64)) {
+        throw py::value_error("Inds array must have data type int64");
+    }
+
+    bool is_src_c_contig = src.is_c_contiguous();
+    bool is_vals_c_contig = vals.is_c_contiguous();
+    bool is_inds_c_contig = inds.is_c_contiguous();
+
+    if (is_src_c_contig && is_vals_c_contig && is_inds_c_contig) {
+        auto fn = topk_dispatch_vector[src_typeid];
+
+        sycl::event comp_ev =
+            fn(exec_q, iter_nelems, axis_nelems, k, largest, src.get_data(),
+               vals.get_data(), inds.get_data(), depends);
+
+        sycl::event keep_args_alive_ev =
+            dpctl::utils::keep_args_alive(exec_q, {src, vals, inds},
+                                          {comp_ev});
+
+        return std::make_pair(keep_args_alive_ev, comp_ev);
+    }
+
+    return std::make_pair(sycl::event(), sycl::event());
+}
+
+template <typename fnT, typename argTy>
+struct TopKFactory
+{
+    fnT get()
+    {
+        using IdxT = std::int64_t;
+        return topk_caller<argTy, IdxT>;
+    }
+};
+
+void init_topk_dispatch_vectors(void)
+{
+    td_ns::DispatchVectorBuilder<topk_impl_fn_ptr_t, TopKFactory,
+                                 td_ns::num_types>
+        dvb;
+    dvb.populate_dispatch_vector(topk_dispatch_vector);
+}
+
+void init_topk_functions(py::module_ m)
+{
+    dpctl::tensor::py_internal::init_topk_dispatch_vectors();
+
+    m.def("_topk", &py_topk, py::arg("src"), py::arg("trailing_dims_to_search"),
+          py::arg("k"), py::arg("largest"), py::arg("vals"), py::arg("inds"),
+          py::arg("sycl_queue"), py::arg("depends") = py::list());
+}
+
+} // namespace dpctl::tensor::py_internal
diff --git a/dpctl_ext/tensor/libtensor/source/sorting/topk.hpp b/dpctl_ext/tensor/libtensor/source/sorting/topk.hpp
new file mode 100644
index 000000000000..d39c0eefdb93
--- /dev/null
+++ b/dpctl_ext/tensor/libtensor/source/sorting/topk.hpp
@@ -0,0 +1,47 @@
+//*****************************************************************************
+// Copyright (c) 2026, Intel Corporation
+// All rights reserved.
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are met:
+// - Redistributions of source code must retain the above copyright notice,
+// this list of conditions and the following disclaimer.
+// - Redistributions in binary form must reproduce the above copyright notice,
+// this list of conditions and the following disclaimer in the documentation
+// and/or other materials provided with the distribution.
+// - Neither the name of the copyright holder nor the names of its contributors
+// may be used to endorse or promote products derived from this software
+// without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
+// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
+// THE POSSIBILITY OF SUCH DAMAGE.
+//*****************************************************************************
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file
+/// This file defines functions of dpctl.tensor._tensor_sorting_impl
+/// extension.
+//===----------------------------------------------------------------------===//
+
+#pragma once
+
+#include <pybind11/pybind11.h>
+
+namespace py = pybind11;
+
+namespace dpctl::tensor::py_internal
+{
+
+extern void init_topk_functions(py::module_);
+
+} // namespace dpctl::tensor::py_internal
diff --git a/dpctl_ext/tensor/libtensor/source/tensor_sorting.cpp b/dpctl_ext/tensor/libtensor/source/tensor_sorting.cpp
new file mode 100644
index 000000000000..98a2493a7c48
--- /dev/null
+++ b/dpctl_ext/tensor/libtensor/source/tensor_sorting.cpp
@@ -0,0 +1,57 @@
+//*****************************************************************************
+// Copyright (c) 2026, Intel Corporation
+// All rights reserved.
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are met:
+// - Redistributions of source code must retain the above copyright notice,
+// this list of conditions and the following disclaimer.
+// - Redistributions in binary form must reproduce the above copyright notice,
+// this list of conditions and the following disclaimer in the documentation
+// and/or other materials provided with the distribution.
+// - Neither the name of the copyright holder nor the names of its contributors
+// may be used to endorse or promote products derived from this software
+// without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
+// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
+// THE POSSIBILITY OF SUCH DAMAGE.
+//*****************************************************************************
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file
+/// This file defines functions of dpctl.tensor._tensor_sorting_impl
+/// extension.
+//===----------------------------------------------------------------------===//
+
+#include <pybind11/pybind11.h>
+
+#include "sorting/isin.hpp"
+#include "sorting/merge_argsort.hpp"
+#include "sorting/merge_sort.hpp"
+#include "sorting/radix_argsort.hpp"
+#include "sorting/radix_sort.hpp"
+#include "sorting/searchsorted.hpp"
+#include "sorting/topk.hpp"
+
+namespace py = pybind11;
+
+PYBIND11_MODULE(_tensor_sorting_impl, m)
+{
+    dpctl::tensor::py_internal::init_isin_functions(m);
+    dpctl::tensor::py_internal::init_merge_sort_functions(m);
+    dpctl::tensor::py_internal::init_merge_argsort_functions(m);
+    dpctl::tensor::py_internal::init_searchsorted_functions(m);
+    dpctl::tensor::py_internal::init_radix_sort_functions(m);
+    dpctl::tensor::py_internal::init_radix_argsort_functions(m);
+    dpctl::tensor::py_internal::init_topk_functions(m);
+}
diff --git a/dpnp/dpnp_iface_logic.py b/dpnp/dpnp_iface_logic.py
index 1834f25a0485..3e3501b14c7c 100644
--- a/dpnp/dpnp_iface_logic.py
+++ b/dpnp/dpnp_iface_logic.py
@@ -49,6 +49,9 @@
 import dpctl.utils as dpu
 import numpy
 
+# TODO: revert to `import dpctl.tensor...`
+# when dpnp fully migrates dpctl/tensor
+import dpctl_ext.tensor as dpt_ext
 import dpnp
 import dpnp.backend.extensions.ufunc._ufunc_impl as ufi
 from dpnp.dpnp_algo.dpnp_elementwise_common import DPNPBinaryFunc, DPNPUnaryFunc
@@ -1273,7 +1276,7 @@ def isin(
     usm_element = dpnp.get_usm_ndarray(element)
     usm_test = dpnp.get_usm_ndarray(test_elements)
     return dpnp_array._create_from_usm_ndarray(
-        dpt.isin(
+        dpt_ext.isin(
             usm_element,
             usm_test,
             invert=invert,
diff --git a/dpnp/dpnp_iface_manipulation.py b/dpnp/dpnp_iface_manipulation.py
index d188ae098cd9..7eb44f79ae38 100644
--- a/dpnp/dpnp_iface_manipulation.py
+++ b/dpnp/dpnp_iface_manipulation.py
@@ -378,22 +378,24 @@ def _get_first_nan_index(usm_a):
             true_val = dpt_ext.asarray(
                 True, sycl_queue=usm_a.sycl_queue, usm_type=usm_a.usm_type
             )
-            return dpt.searchsorted(dpt.isnan(usm_a), true_val, side="left")
-        return dpt.searchsorted(usm_a, usm_a[-1], side="left")
+            return dpt_ext.searchsorted(
+                dpt.isnan(usm_a), true_val, side="left"
+            )
+        return dpt_ext.searchsorted(usm_a, usm_a[-1], side="left")
         return None
 
     usm_ar = dpnp.get_usm_ndarray(ar)
 
     num_of_flags = (return_index, return_inverse, return_counts).count(True)
     if num_of_flags == 0:
-        usm_res = dpt.unique_values(usm_ar)
+        usm_res = dpt_ext.unique_values(usm_ar)
         usm_res = (usm_res,)  # cast to a tuple to align with other cases
     elif num_of_flags == 1 and return_inverse:
-        usm_res = dpt.unique_inverse(usm_ar)
+        usm_res = dpt_ext.unique_inverse(usm_ar)
     elif num_of_flags == 1 and return_counts:
-        usm_res = dpt.unique_counts(usm_ar)
+        usm_res = dpt_ext.unique_counts(usm_ar)
     else:
-        usm_res = dpt.unique_all(usm_ar)
+        usm_res = dpt_ext.unique_all(usm_ar)
 
     first_nan = None
     if equal_nan:
diff --git a/dpnp/dpnp_iface_searching.py b/dpnp/dpnp_iface_searching.py
index 15f52338ec7e..055aaa999c3a 100644
--- a/dpnp/dpnp_iface_searching.py
+++ b/dpnp/dpnp_iface_searching.py
@@ -382,7 +382,7 @@ def searchsorted(a, v, side="left", sorter=None):
     usm_sorter = None if sorter is None else dpnp.get_usm_ndarray(sorter)
 
     return dpnp_array._create_from_usm_ndarray(
-        dpt.searchsorted(usm_a, usm_v, side=side, sorter=usm_sorter)
+        dpt_ext.searchsorted(usm_a, usm_v, side=side, sorter=usm_sorter)
     )
diff --git a/dpnp/dpnp_iface_sorting.py b/dpnp/dpnp_iface_sorting.py
index 5f7a3829b3c9..e7abef1f4338 100644
--- a/dpnp/dpnp_iface_sorting.py
+++ b/dpnp/dpnp_iface_sorting.py
@@ -41,11 +41,9 @@
 
 from collections.abc import Sequence
 
-import dpctl.tensor as dpt
-
 # TODO: revert to `import dpctl.tensor...`
 # when dpnp fully migrates dpctl/tensor
-import dpctl_ext.tensor as dpt_ext
+import dpctl_ext.tensor as dpt
 import dpnp
 from dpctl_ext.tensor._numpy_helper import normalize_axis_index
 
@@ -87,7 +85,7 @@ def _wrap_sort_argsort(
     usm_a = dpnp.get_usm_ndarray(a)
     if axis is None:
-        usm_a = dpt_ext.reshape(usm_a, -1)
+        usm_a = dpt.reshape(usm_a, -1)
         axis = -1
 
     axis = normalize_axis_index(axis, ndim=usm_a.ndim)
diff --git a/pyproject.toml b/pyproject.toml
index 67fb75cb5f54..09253467b8dc 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -108,7 +108,7 @@ target-version = ['py310', 'py311', 'py312', 'py313', 'py314']
 
 [tool.codespell]
 builtin = "clear,rare,informal,names"
 check-filenames = true
-ignore-words-list = "amin,arange,elemt,fro,hist,ith,mone,nd,nin,sinc,vart,GroupT,AccessorT,IndexT"
+ignore-words-list = "amin,arange,elemt,fro,hist,ith,mone,nd,nin,sinc,vart,GroupT,AccessorT,IndexT,fpT,OffsetT,inpT"
 quiet-level = 3
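Taken together, the migration is exercised end to end through the re-exported `dpctl_ext.tensor` namespace. A short sketch of the surface this patch wires up, assuming a default SYCL device is available:

    import dpctl_ext.tensor as dpt

    x = dpt.asarray([3, 1, 2, 3, 1])

    print(dpt.sort(x))           # expected: [1 1 2 3 3]
    print(dpt.argsort(x))        # indices producing the sorted order, e.g. [1 4 2 0 3]
    print(dpt.unique_values(x))  # expected: [1 2 3]
    print(dpt.isin(x, dpt.asarray([1, 2])))  # expected: [False True True False True]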