diff --git a/dpctl_ext/tensor/CMakeLists.txt b/dpctl_ext/tensor/CMakeLists.txt index ef3565f9827..afc7dca4db3 100644 --- a/dpctl_ext/tensor/CMakeLists.txt +++ b/dpctl_ext/tensor/CMakeLists.txt @@ -166,6 +166,10 @@ set(_sorting_sources ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/searchsorted.cpp ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/topk.cpp ) +set(_linalg_sources + ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/elementwise_functions_type_utils.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/linalg_functions/dot.cpp +) set(_tensor_accumulation_impl_sources ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/tensor_accumulation.cpp ${_accumulator_sources} @@ -182,6 +186,10 @@ set(_tensor_sorting_impl_sources ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/tensor_sorting.cpp ${_sorting_sources} ) +set(_tensor_linalg_impl_sources + ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/tensor_linalg.cpp + ${_linalg_sources} +) set(_static_lib_trgt simplify_iteration_space) @@ -228,6 +236,12 @@ add_sycl_to_target(TARGET ${python_module_name} SOURCES ${_tensor_sorting_impl_s target_link_libraries(${python_module_name} PRIVATE ${_static_lib_trgt}) list(APPEND _py_trgts ${python_module_name}) +set(python_module_name _tensor_linalg_impl) +pybind11_add_module(${python_module_name} MODULE ${_tensor_linalg_impl_sources}) +add_sycl_to_target(TARGET ${python_module_name} SOURCES ${_tensor_linalg_impl_sources}) +target_link_libraries(${python_module_name} PRIVATE ${_static_lib_trgt}) +list(APPEND _py_trgts ${python_module_name}) + set(_clang_prefix "") if(WIN32) set(_clang_prefix "/clang:") @@ -245,7 +259,7 @@ list( ${_elementwise_sources} ${_reduction_sources} ${_sorting_sources} - # ${_linalg_sources} + ${_linalg_sources} ${_accumulator_sources} ) diff --git a/dpctl_ext/tensor/__init__.py b/dpctl_ext/tensor/__init__.py index 70352687c5d..a6127f1fc27 100644 --- a/dpctl_ext/tensor/__init__.py +++ b/dpctl_ext/tensor/__init__.py @@ -107,6 +107,12 @@ take, 
take_along_axis, ) +from ._linear_algebra_functions import ( + matmul, + matrix_transpose, + tensordot, + vecdot, +) from ._manipulation_functions import ( broadcast_arrays, broadcast_to, @@ -216,6 +222,8 @@ "min", "moveaxis", "permute_dims", + "matmul", + "matrix_transpose", "negative", "nonzero", "ones", @@ -251,6 +259,7 @@ "take_along_axis", "tan", "tanh", + "tensordot", "tile", "top_k", "to_numpy", @@ -262,6 +271,7 @@ "unique_inverse", "unique_values", "unstack", + "vecdot", "where", "zeros", "zeros_like", diff --git a/dpctl_ext/tensor/_linear_algebra_functions.py b/dpctl_ext/tensor/_linear_algebra_functions.py new file mode 100644 index 00000000000..5f6edecf5e5 --- /dev/null +++ b/dpctl_ext/tensor/_linear_algebra_functions.py @@ -0,0 +1,1019 @@ +# ***************************************************************************** +# Copyright (c) 2026, Intel Corporation +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# - Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# - Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# - Neither the name of the copyright holder nor the names of its contributors +# may be used to endorse or promote products derived from this software +# without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. 
import operator

import dpctl
import dpctl.tensor as dpt
from dpctl.utils import ExecutionPlacementError, SequentialOrderManager

# TODO: revert to `import dpctl.tensor...`
# when dpnp fully migrates dpctl/tensor
import dpctl_ext.tensor as dpt_ext
import dpctl_ext.tensor._tensor_elementwise_impl as tei
import dpctl_ext.tensor._tensor_impl as ti
import dpctl_ext.tensor._tensor_linalg_impl as tli

from ._copy_utils import _empty_like_orderK, _empty_like_pair_orderK
from ._manipulation_functions import _broadcast_shape_impl
from ._numpy_helper import normalize_axis_index, normalize_axis_tuple
from ._type_utils import (
    _acceptance_fn_default_binary,
    _find_buf_dtype2,
    _to_device_supported_dtype,
)


def matrix_transpose(x):
    r"""matrix_transpose(x)

    Transposes the innermost two dimensions of `x`, where `x` is a
    2-dimensional matrix or a stack of 2-dimensional matrices.

    To convert from a 1-dimensional array to a 2-dimensional column
    vector, use x[:, dpt.newaxis].

    Args:
        x (usm_ndarray):
            Input array with shape (..., m, n).

    Returns:
        usm_ndarray:
            Array with shape (..., n, m).

    Raises:
        TypeError: if `x` is not a `dpt.usm_ndarray`.
        ValueError: if `x` has fewer than 2 dimensions.
    """
    if not isinstance(x, dpt.usm_ndarray):
        raise TypeError(
            "Expected instance of `dpt.usm_ndarray`, got `{}`.".format(type(x))
        )
    if x.ndim < 2:
        # BUGFIX: the adjacent string literals previously lacked a separating
        # space, producing the garbled message "...array to haveat least 2...".
        raise ValueError(
            "dpctl.tensor.matrix_transpose requires array to have "
            "at least 2 dimensions"
        )

    # `.mT` is the usm_ndarray property that swaps the last two axes.
    return x.mT
def tensordot(x1, x2, axes=2):
    r"""tensordot(x1, x2, axes=2)

    Returns a tensor contraction of `x1` and `x2` over specific axes.

    Args:
        x1 (usm_ndarray):
            first input array, expected to have numeric data type.
        x2 (usm_ndarray):
            second input array, expected to have numeric data type.
            Corresponding contracted axes of `x1` and `x2` must be equal.
        axes (Union[int, Tuple[Sequence[int], Sequence[int]]]):
            number of axes to contract or explicit sequences of axes for
            `x1` and `x2`, respectively. If `axes` is an integer equal to `N`,
            then the contraction is performed over last `N` axes of `x1` and
            the first `N` axis of `x2` in order. The size of each corresponding
            axis must match and must be non-negative.

            * if `N` equals `0`, the result is the tensor outer product
            * if `N` equals `1`, the result is the tensor dot product
            * if `N` equals `2`, the result is the tensor double
              contraction (default).

            If `axes` is a tuple of two sequences `(x1_axes, x2_axes)`, the
            first sequence applies to `x1` and the second sequence applies
            to `x2`. Both sequences must have equal length, and each axis
            `x1_axes[i]` for `x1` must have the same size as the respective
            axis `x2_axes[i]` for `x2`. Each sequence must consist of unique
            integers that specify valid axes for each respective array.
            For example, if `x1` has rank `N`, a valid axis must reside on the
            half-open interval `[-N, N)`.
    Returns:
        usm_ndarray:
            an array containing the tensor contraction whose shape consists of
            the non-contracted axes of the first array `x1`, followed by the
            non-contracted axes of the second array `x2`. The returned array
            must have a data type determined by Type Promotion Rules.
    """
    if not isinstance(x1, dpt.usm_ndarray):
        raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x1)}")
    if not isinstance(x2, dpt.usm_ndarray):
        raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x2)}")
    # Both inputs must be allocated on queues that share an execution context.
    q1, x1_usm_type = x1.sycl_queue, x1.usm_type
    q2, x2_usm_type = x2.sycl_queue, x2.usm_type
    exec_q = dpctl.utils.get_execution_queue((q1, q2))
    if exec_q is None:
        raise ExecutionPlacementError(
            "Execution placement can not be unambiguously inferred "
            "from input arguments."
        )
    res_usm_type = dpctl.utils.get_coerced_usm_type(
        (
            x1_usm_type,
            x2_usm_type,
        )
    )
    dpctl.utils.validate_usm_type(res_usm_type, allow_none=False)
    # handle axes and shapes validation
    x1_nd = x1.ndim
    x2_nd = x2.ndim
    x1_shape = x1.shape
    x2_shape = x2.shape
    if isinstance(axes, int):
        if axes < 0:
            raise ValueError("`axes` integer is expected to be non-negative")
        n_axes1 = axes
        n_axes2 = axes
        # integer form contracts the last `axes` dims of x1 against the
        # first `axes` dims of x2, in order
        axes1 = normalize_axis_tuple(tuple(range(-axes, 0)), x1_nd)
        axes2 = tuple(range(0, axes))
    elif isinstance(axes, tuple):
        if len(axes) != 2:
            raise ValueError(
                "`axes` tuple is expected to contain two sequences"
            )
        axes1 = tuple(axes[0])
        axes2 = tuple(axes[1])
        n_axes1 = len(axes1)
        n_axes2 = len(axes2)
    else:
        raise TypeError("`axes` must be an integer or a tuple of sequences")
    if n_axes1 != n_axes2:
        raise ValueError(
            "number of axes contracted must be the same for each array"
        )
    if n_axes1 == 0:
        # outer product: append/prepend a length-1 axis and contract over it
        arr1 = x1[..., dpt.newaxis]
        arr2 = x2[dpt.newaxis, ...]
        n_axes1 = 1
        n_axes2 = 1
    else:
        same_shapes = True
        for i in range(n_axes1):
            axis1 = axes1[i]
            axis2 = axes2[i]
            same_shapes = same_shapes and (x1_shape[axis1] == x2_shape[axis2])
        if not same_shapes:
            raise ValueError("shape mismatch in contracted `tensordot` axes")
        axes1 = normalize_axis_tuple(axes1, x1_nd)
        axes2 = normalize_axis_tuple(axes2, x2_nd)
        # move contracted axes of x1 to the end and contracted axes of x2
        # to the front, so the kernel contracts over adjacent trailing dims
        perm1 = [i for i in range(x1_nd) if i not in axes1] + list(axes1)
        perm2 = list(axes2) + [i for i in range(x2_nd) if i not in axes2]
        arr1 = dpt_ext.permute_dims(x1, perm1)
        arr2 = dpt_ext.permute_dims(x2, perm2)
    arr1_outer_nd = arr1.ndim - n_axes1
    arr2_outer_nd = arr2.ndim - n_axes2
    res_shape = arr1.shape[:arr1_outer_nd] + arr2.shape[n_axes2:]
    # dtype validation
    sycl_dev = exec_q.sycl_device
    x1_dtype = x1.dtype
    x2_dtype = x2.dtype
    buf1_dt, buf2_dt, res_dt = _find_buf_dtype2(
        x1_dtype,
        x2_dtype,
        tli._dot_result_type,
        sycl_dev,
        acceptance_fn=_acceptance_fn_default_binary,
    )
    if res_dt is None:
        raise TypeError(
            "function 'tensordot' does not support input types "
            f"({x1_dtype}, {x2_dtype}), "
            "and the inputs could not be safely coerced to any "
            "supported types according to the casting rule ''safe''."
        )

    # Four cases below: neither, one, or both inputs need a temporary copy
    # in the promoted dtype before the dot kernel runs. Host-task/compute
    # event pairs are registered with the order manager so later operations
    # on this queue are correctly sequenced.
    _manager = SequentialOrderManager[exec_q]
    if buf1_dt is None and buf2_dt is None:
        out = dpt_ext.empty(
            res_shape,
            dtype=res_dt,
            usm_type=res_usm_type,
            sycl_queue=exec_q,
            order="C",
        )
        dep_evs = _manager.submitted_events
        ht_dot_ev, dot_ev = tli._dot(
            x1=arr1,
            x2=arr2,
            batch_dims=0,
            x1_outer_dims=arr1_outer_nd,
            x2_outer_dims=arr2_outer_nd,
            inner_dims=n_axes1,
            dst=out,
            sycl_queue=exec_q,
            depends=dep_evs,
        )
        _manager.add_event_pair(ht_dot_ev, dot_ev)

        return out

    elif buf1_dt is None:
        # only x2 needs promotion; copy it into a buffer of the buf dtype
        buf2 = _empty_like_orderK(arr2, buf2_dt)

        dep_evs = _manager.submitted_events
        ht_copy_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
            src=arr2, dst=buf2, sycl_queue=exec_q, depends=dep_evs
        )
        _manager.add_event_pair(ht_copy_ev, copy_ev)
        out = dpt_ext.empty(
            res_shape,
            dtype=res_dt,
            usm_type=res_usm_type,
            sycl_queue=exec_q,
            order="C",
        )
        ht_dot_ev, dot_ev = tli._dot(
            x1=arr1,
            x2=buf2,
            batch_dims=0,
            x1_outer_dims=arr1_outer_nd,
            x2_outer_dims=arr2_outer_nd,
            inner_dims=n_axes1,
            dst=out,
            sycl_queue=exec_q,
            depends=[copy_ev],
        )
        _manager.add_event_pair(ht_dot_ev, dot_ev)

        return out

    elif buf2_dt is None:
        # only x1 needs promotion
        buf1 = _empty_like_orderK(arr1, buf1_dt)
        dep_evs = _manager.submitted_events
        ht_copy_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
            src=arr1, dst=buf1, sycl_queue=exec_q, depends=dep_evs
        )
        _manager.add_event_pair(ht_copy_ev, copy_ev)
        out = dpt_ext.empty(
            res_shape,
            dtype=res_dt,
            usm_type=res_usm_type,
            sycl_queue=exec_q,
            order="C",
        )
        ht_dot_ev, dot_ev = tli._dot(
            x1=buf1,
            x2=arr2,
            batch_dims=0,
            x1_outer_dims=arr1_outer_nd,
            x2_outer_dims=arr2_outer_nd,
            inner_dims=n_axes1,
            dst=out,
            sycl_queue=exec_q,
            depends=[copy_ev],
        )
        _manager.add_event_pair(ht_dot_ev, dot_ev)

        return out

    # both inputs need promotion; dot depends on both copy events
    buf1 = _empty_like_orderK(arr1, buf1_dt)
    deps_ev = _manager.submitted_events
    ht_copy1_ev, copy1_ev = ti._copy_usm_ndarray_into_usm_ndarray(
        src=arr1, dst=buf1, sycl_queue=exec_q, depends=deps_ev
    )
    _manager.add_event_pair(ht_copy1_ev, copy1_ev)
    buf2 = _empty_like_orderK(arr2, buf2_dt)
    ht_copy2_ev, copy2_ev = ti._copy_usm_ndarray_into_usm_ndarray(
        src=arr2, dst=buf2, sycl_queue=exec_q, depends=deps_ev
    )
    _manager.add_event_pair(ht_copy2_ev, copy2_ev)
    out = dpt_ext.empty(
        res_shape,
        dtype=res_dt,
        usm_type=res_usm_type,
        sycl_queue=exec_q,
        order="C",
    )
    ht_, dot_ev = tli._dot(
        x1=buf1,
        x2=buf2,
        batch_dims=0,
        x1_outer_dims=arr1_outer_nd,
        x2_outer_dims=arr2_outer_nd,
        inner_dims=n_axes1,
        dst=out,
        sycl_queue=exec_q,
        depends=[copy1_ev, copy2_ev],
    )
    _manager.add_event_pair(ht_, dot_ev)

    return out
def vecdot(x1, x2, axis=-1):
    r"""vecdot(x1, x2, axis=-1)

    Computes the (vector) dot product of two arrays.

    Args:
        x1 (usm_ndarray):
            first input array.
        x2 (usm_ndarray):
            second input array. Input arrays must have compatible
            shapes along non-contract axes according to broadcasting
            rules, and must have the same size along the contracted
            axis. Input arrays should be of numeric type.
        axis (Optional[int]):
            axis over which to compute the dot product. The axis must
            be an integer on the interval `[-N, -1]`, where `N` is
            ``min(x1.ndim, x2.ndim)``. The axis along which dot product
            is performed is counted backward from the last axes
            (that is, `-1` refers to the last axis). By default,
            dot product is computed over the last axis.
            Default: `-1`.

    Returns:
        usm_ndarray:
            if `x1` and `x2` are both one-dimensional arrays, a
            zero-dimensional array containing the dot product value
            is returned; otherwise, a non-zero-dimensional array containing
            the dot products and having rank `N-1`, where `N` is the rank
            of the shape of input arrays after broadcasting rules are applied
            to non-contracted axes.
    """
    if not isinstance(x1, dpt.usm_ndarray):
        raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x1)}")
    if not isinstance(x2, dpt.usm_ndarray):
        raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x2)}")
    # queue/usm-type coercion: both inputs must share an execution context
    q1, x1_usm_type = x1.sycl_queue, x1.usm_type
    q2, x2_usm_type = x2.sycl_queue, x2.usm_type
    exec_q = dpctl.utils.get_execution_queue((q1, q2))
    if exec_q is None:
        raise ExecutionPlacementError(
            "Execution placement can not be unambiguously inferred "
            "from input arguments."
        )
    res_usm_type = dpctl.utils.get_coerced_usm_type(
        (
            x1_usm_type,
            x2_usm_type,
        )
    )
    dpctl.utils.validate_usm_type(res_usm_type, allow_none=False)
    # axis and shape validation
    x1_nd = x1.ndim
    x2_nd = x2.ndim
    x1_shape = x1.shape
    x2_shape = x2.shape
    # only negative axis values are accepted (counted from the last axis),
    # per the docstring's `[-N, -1]` contract
    if axis >= 0:
        raise ValueError("`axis` must be negative")
    axis = operator.index(axis)
    x1_axis = normalize_axis_index(axis, x1_nd)
    x2_axis = normalize_axis_index(axis, x2_nd)
    if x1_shape[x1_axis] != x2_shape[x2_axis]:
        raise ValueError(
            "given axis must have the same shape for `x1` and `x2`"
        )
    # pad the lower-rank shape with leading ones before broadcasting
    if x1_nd > x2_nd:
        x2_shape = (1,) * (x1_nd - x2_nd) + x2_shape
    elif x2_nd > x1_nd:
        x1_shape = (1,) * (x2_nd - x1_nd) + x1_shape
    try:
        broadcast_sh = _broadcast_shape_impl(
            [
                x1_shape,
                x2_shape,
            ]
        )
    except ValueError:
        raise ValueError("mismatch in `vecdot` dimensions")
    broadcast_nd = len(broadcast_sh)
    contracted_axis = normalize_axis_index(axis, broadcast_nd)
    # result shape drops the contracted axis from the broadcast shape
    res_sh = tuple(
        [broadcast_sh[i] for i in range(broadcast_nd) if i != contracted_axis]
    )
    # dtype validation
    sycl_dev = exec_q.sycl_device
    x1_dtype = x1.dtype
    x2_dtype = x2.dtype
    buf1_dt, buf2_dt, res_dt = _find_buf_dtype2(
        x1_dtype,
        x2_dtype,
        tli._dot_result_type,
        sycl_dev,
        acceptance_fn=_acceptance_fn_default_binary,
    )
    if res_dt is None:
        raise TypeError(
            "function 'vecdot' does not support input types "
            f"({x1_dtype}, {x2_dtype}), "
            "and the inputs could not be safely coerced to any "
            "supported types according to the casting rule ''safe''."
        )

    # Four cases below mirror tensordot: neither/one/both inputs need a
    # promoted temporary copy. In every case the first argument is complex
    # conjugated when it has a complex dtype (`kind == "c"`) before the dot.
    _manager = SequentialOrderManager[exec_q]
    if buf1_dt is None and buf2_dt is None:
        if x1.dtype.kind == "c":
            x1_tmp = _empty_like_orderK(x1, x1.dtype)
            dep_evs = _manager.submitted_events
            ht_conj_ev, conj_ev = tei._conj(
                src=x1, dst=x1_tmp, sycl_queue=exec_q, depends=dep_evs
            )
            _manager.add_event_pair(ht_conj_ev, conj_ev)
            x1 = x1_tmp
        if x1.shape != broadcast_sh:
            x1 = dpt_ext.broadcast_to(x1, broadcast_sh)
        if x2.shape != broadcast_sh:
            x2 = dpt_ext.broadcast_to(x2, broadcast_sh)
        # move the contracted axis to the end so the kernel reduces over
        # the single trailing inner dimension
        x1 = dpt_ext.moveaxis(x1, contracted_axis, -1)
        x2 = dpt_ext.moveaxis(x2, contracted_axis, -1)
        out = dpt_ext.empty(
            res_sh,
            dtype=res_dt,
            usm_type=res_usm_type,
            sycl_queue=exec_q,
            order="C",
        )
        dep_evs = _manager.submitted_events
        ht_dot_ev, dot_ev = tli._dot(
            x1=x1,
            x2=x2,
            batch_dims=len(res_sh),
            x1_outer_dims=0,
            x2_outer_dims=0,
            inner_dims=1,
            dst=out,
            sycl_queue=exec_q,
            depends=dep_evs,
        )
        _manager.add_event_pair(ht_dot_ev, dot_ev)
        return dpt_ext.reshape(out, res_sh)

    elif buf1_dt is None:
        if x1.dtype.kind == "c":
            x1_tmp = _empty_like_orderK(x1, x1.dtype)
            deps_ev = _manager.submitted_events
            ht_conj_ev, conj_e = tei._conj(
                src=x1, dst=x1_tmp, sycl_queue=exec_q, depends=deps_ev
            )
            _manager.add_event_pair(ht_conj_ev, conj_e)
            x1 = x1_tmp
        buf2 = _empty_like_orderK(x2, buf2_dt)
        deps_ev = _manager.submitted_events
        ht_copy_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
            src=x2, dst=buf2, sycl_queue=exec_q, depends=deps_ev
        )
        _manager.add_event_pair(ht_copy_ev, copy_ev)
        if x1.shape != broadcast_sh:
            x1 = dpt_ext.broadcast_to(x1, broadcast_sh)
        if buf2.shape != broadcast_sh:
            buf2 = dpt_ext.broadcast_to(buf2, broadcast_sh)
        x1 = dpt_ext.moveaxis(x1, contracted_axis, -1)
        buf2 = dpt_ext.moveaxis(buf2, contracted_axis, -1)
        out = dpt_ext.empty(
            res_sh,
            dtype=res_dt,
            usm_type=res_usm_type,
            sycl_queue=exec_q,
            order="C",
        )
        ht_dot_ev, dot_ev = tli._dot(
            x1=x1,
            x2=buf2,
            batch_dims=len(res_sh),
            x1_outer_dims=0,
            x2_outer_dims=0,
            inner_dims=1,
            dst=out,
            sycl_queue=exec_q,
            depends=[copy_ev],
        )
        _manager.add_event_pair(ht_dot_ev, dot_ev)
        return dpt_ext.reshape(out, res_sh)

    elif buf2_dt is None:
        # x1 needs promotion; conjugation (if complex) is done in-place on
        # the promoted buffer after the copy completes
        buf1 = _empty_like_orderK(x1, buf1_dt)
        deps_ev = _manager.submitted_events
        ht_copy_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
            src=x1, dst=buf1, sycl_queue=exec_q, depends=deps_ev
        )
        _manager.add_event_pair(ht_copy_ev, copy_ev)
        if buf1.dtype.kind == "c":
            ht_conj_ev, conj_ev = tei._conj(
                src=buf1, dst=buf1, sycl_queue=exec_q, depends=[copy_ev]
            )
            _manager.add_event_pair(ht_conj_ev, conj_ev)
        if buf1.shape != broadcast_sh:
            buf1 = dpt_ext.broadcast_to(buf1, broadcast_sh)
        if x2.shape != broadcast_sh:
            x2 = dpt_ext.broadcast_to(x2, broadcast_sh)
        buf1 = dpt_ext.moveaxis(buf1, contracted_axis, -1)
        x2 = dpt_ext.moveaxis(x2, contracted_axis, -1)
        out = dpt_ext.empty(
            res_sh,
            dtype=res_dt,
            usm_type=res_usm_type,
            sycl_queue=exec_q,
            order="C",
        )
        deps_ev = _manager.submitted_events
        ht_dot_ev, dot_ev = tli._dot(
            x1=buf1,
            x2=x2,
            batch_dims=len(res_sh),
            x1_outer_dims=0,
            x2_outer_dims=0,
            inner_dims=1,
            dst=out,
            sycl_queue=exec_q,
            depends=deps_ev,
        )
        _manager.add_event_pair(ht_dot_ev, dot_ev)
        return dpt_ext.reshape(out, res_sh)

    # both inputs need promoted copies
    buf1 = _empty_like_orderK(x1, buf1_dt)
    deps_ev = _manager.submitted_events
    ht_copy1_ev, copy1_ev = ti._copy_usm_ndarray_into_usm_ndarray(
        src=x1, dst=buf1, sycl_queue=exec_q, depends=deps_ev
    )
    _manager.add_event_pair(ht_copy1_ev, copy1_ev)
    if buf1.dtype.kind == "c":
        ht_conj_ev, conj_ev = tei._conj(
            src=buf1, dst=buf1, sycl_queue=exec_q, depends=[copy1_ev]
        )
        _manager.add_event_pair(ht_conj_ev, conj_ev)
    buf2 = _empty_like_orderK(x2, buf2_dt)
    ht_copy2_ev, copy2_ev = ti._copy_usm_ndarray_into_usm_ndarray(
        src=x2, dst=buf2, sycl_queue=exec_q, depends=deps_ev
    )
    _manager.add_event_pair(ht_copy2_ev, copy2_ev)
    if buf1.shape != broadcast_sh:
        buf1 = dpt_ext.broadcast_to(buf1, broadcast_sh)
    if buf2.shape != broadcast_sh:
        buf2 = dpt_ext.broadcast_to(buf2, broadcast_sh)
    buf1 = dpt_ext.moveaxis(buf1, contracted_axis, -1)
    buf2 = dpt_ext.moveaxis(buf2, contracted_axis, -1)
    out = dpt_ext.empty(
        res_sh,
        dtype=res_dt,
        usm_type=res_usm_type,
        sycl_queue=exec_q,
        order="C",
    )
    deps_ev = _manager.submitted_events
    ht_dot_ev, dot_ev = tli._dot(
        x1=buf1,
        x2=buf2,
        batch_dims=len(res_sh),
        x1_outer_dims=0,
        x2_outer_dims=0,
        inner_dims=1,
        dst=out,
        sycl_queue=exec_q,
        depends=deps_ev,
    )
    _manager.add_event_pair(ht_dot_ev, dot_ev)
    return out
def matmul(x1, x2, out=None, dtype=None, order="K"):
    r"""matmul(x1, x2, out=None, dtype=None, order="K")

    Computes the matrix product. Implements the same semantics
    as the built-in operator `@`.

    Args:
        x1 (usm_ndarray):
            first input array. Expected to have numeric data type, and
            at least one dimension. If `x1` is one-dimensional having
            shape `(M,)`, and `x2` has more than one dimension, `x1` is
            effectively treated as a two-dimensional array with shape `(1, M)`,
            although the prepended dimension is removed from the output array.
            If `x1` has shape `(..., M, K)`, the innermost two dimensions form
            matrices on which to perform matrix multiplication.
        x2 (usm_ndarray):
            second input array. Expected to have numeric data type, and
            at least one dimension. If `x2` is one-dimensional having
            shape `(N,)`, and `x1` has more than one dimension, `x2` is
            effectively treated as a two-dimensional array with shape `(N, 1)`,
            although the appended dimension is removed from the output array.
            If `x2` has shape `(..., K, N)`, the innermost two dimensions form
            matrices on which to perform matrix multiplication.
        out (Optional[usm_ndarray]):
            the array into which the result of the matrix product is written.
            The data type of `out` must match the expected data type of the
            result or (if provided) `dtype`.
            If `None` then a new array is returned. Default: `None`.
        dtype (Optional[dtype]):
            data type of the returned array. If `None`, the data type of the
            returned array is determined by the Type Promotion Rules.
            Default: `None`.
        order (["K", "C", "F", "A"]):
            memory layout of the output array, if `out` is `None`, otherwise
            the `order` parameter value is not used. Default: `K`.
    Returns:
        usm_ndarray:
            * if both `x1` and `x2` are one-dimensional arrays with shape
              `(N,)`, returned array is a zero-dimensional array containing
              inner product as its only element.
            * if `x1` is two-dimensional array with shape `(M, K)` and `x2` is
              a two-dimensional array with shape `(K, N)`, returned array is a
              two-dimensional array with shape `(M, N)` and contains the
              conventional matrix product.
            * if `x1` is a one-dimensional array with shape `(K,)` and `x2` is
              an array with shape `(..., K, N)`, returned array contains the
              conventional matrix product and has shape `(..., N)`.
            * if `x1` is an array with shape `(..., M, K)` and `x2` is a
              one-dimensional array with shape `(K,)`, returned array has shape
              `(..., M)` and contains the conventional matrix product.
            * if `x1` is a two-dimensional array with shape `(M, K)` and `x2`
              is an array with shape `(..., K, N)`, returned array contains
              conventional matrix product for each stacked matrix and has shape
              `(..., M, N)`.
            * if `x1` has shape `(..., M, K)` and `x2` is a two-dimensional
              array with shape `(K, N)`, returned array contains conventional
              matrix product for each stacked matrix and has shape
              `(..., M, N)`.
            * if both `x1` and `x2` have more than two dimensions, returned
              array contains conventional matrix product for each stacked
              matrix and has shape determined by broadcasting rules for
              `x1.shape[:-2]` and `x2.shape[:-2]`.

            The data type of the returned array is determined by the Type
            Promotion Rules. If either `x1` or `x2` has a complex floating
            point type, neither argument is complex conjugated or transposed.
    """
    if not isinstance(x1, dpt.usm_ndarray):
        raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x1)}")
    if not isinstance(x2, dpt.usm_ndarray):
        raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x2)}")
    # invalid `order` values are silently coerced to "K"
    if order not in ["K", "C", "F", "A"]:
        order = "K"
    q1, x1_usm_type = x1.sycl_queue, x1.usm_type
    q2, x2_usm_type = x2.sycl_queue, x2.usm_type
    exec_q = dpctl.utils.get_execution_queue((q1, q2))
    if exec_q is None:
        raise ExecutionPlacementError(
            "Execution placement can not be unambiguously inferred "
            "from input arguments."
        )
    res_usm_type = dpctl.utils.get_coerced_usm_type(
        (
            x1_usm_type,
            x2_usm_type,
        )
    )
    dpctl.utils.validate_usm_type(res_usm_type, allow_none=False)

    x1_nd = x1.ndim
    x2_nd = x2.ndim
    if x1_nd == 0 or x2_nd == 0:
        raise ValueError("one or more operands to `matmul` is 0 dimensional")
    x1_shape = x1.shape
    x2_shape = x2.shape
    # 1-D operands are promoted to matrices; `appended_axes` remembers which
    # temporary axes to squeeze out of the result at the end
    appended_axes = []
    if x1_nd == 1:
        x1 = x1[dpt.newaxis, :]
        x1_shape = x1.shape
        appended_axes.append(-2)
    if x2_nd == 1:
        x2 = x2[:, dpt.newaxis]
        x2_shape = x2.shape
        appended_axes.append(-1)
    if x1_shape[-1] != x2_shape[-2]:
        raise ValueError("mismatch in `matmul` inner dimension")
    # broadcast the batch (leading) dimensions of the two stacks
    x1_outer_sh = x1_shape[:-2]
    x2_outer_sh = x2_shape[:-2]
    try:
        res_outer_sh = _broadcast_shape_impl(
            [
                x1_outer_sh,
                x2_outer_sh,
            ]
        )
    except ValueError:
        raise ValueError("mismatch in `matmul` batching dimensions")
    x1_broadcast_shape = res_outer_sh + x1_shape[-2:]
    x2_broadcast_shape = res_outer_sh + x2_shape[-2:]
    res_shape = res_outer_sh + x1_shape[-2:-1] + x2_shape[-1:]

    # dtype resolution: either via Type Promotion Rules, or an explicit
    # `dtype` with "same_kind" casting of the inputs
    sycl_dev = exec_q.sycl_device
    x1_dtype = x1.dtype
    x2_dtype = x2.dtype
    if dtype is None:
        buf1_dt, buf2_dt, res_dt = _find_buf_dtype2(
            x1_dtype,
            x2_dtype,
            tli._dot_result_type,
            sycl_dev,
            acceptance_fn=_acceptance_fn_default_binary,
        )
        if res_dt is None:
            # NOTE(review): tensordot/vecdot raise TypeError for the same
            # condition; ValueError here looks inconsistent — confirm intended.
            raise ValueError(
                "function 'matmul' does not support input types "
                f"({x1_dtype}, {x2_dtype}), "
                "and the inputs could not be safely coerced to any "
                "supported types according to the casting rule ''safe''."
            )
    else:
        res_dt = dpt.dtype(dtype)
        res_dt = _to_device_supported_dtype(res_dt, sycl_dev)
        buf1_dt, buf2_dt = None, None
        if x1_dtype != res_dt:
            if dpt_ext.can_cast(x1_dtype, res_dt, casting="same_kind"):
                buf1_dt = res_dt
            else:
                raise ValueError(
                    r"`matmul` input `x1` cannot be cast from "
                    f"{x1_dtype} to "
                    f"requested type {res_dt} according to the casting rule "
                    "''same_kind''."
                )
        if x2_dtype != res_dt:
            if dpt_ext.can_cast(x2_dtype, res_dt, casting="same_kind"):
                buf2_dt = res_dt
            else:
                raise ValueError(
                    r"`matmul` input `x2` cannot be cast from "
                    f"{x2_dtype} to "
                    f"requested type {res_dt} according to the casting rule "
                    "''same_kind''."
                )

    # `out` validation: shape (with appended axes removed), dtype, queue
    # compatibility, and memory overlap with the inputs
    orig_out = out
    if out is not None:
        if not isinstance(out, dpt.usm_ndarray):
            raise TypeError(
                f"output array must be of usm_ndarray type, got {type(out)}"
            )

        if not out.flags.writable:
            raise ValueError("provided `out` array is read-only")

        final_res_shape = tuple(
            res_shape[i]
            for i in range(-len(res_shape), 0)
            if i not in appended_axes
        )
        if out.shape != final_res_shape:
            raise ValueError(
                "The shape of input and output arrays are inconsistent. "
                f"Expected output shape is {final_res_shape}, got {out.shape}"
            )

        if appended_axes:
            out = dpt_ext.expand_dims(out, axis=appended_axes)
            orig_out = out

        if res_dt != out.dtype:
            raise ValueError(
                f"Output array of type {res_dt} is needed, got {out.dtype}"
            )

        if dpctl.utils.get_execution_queue((exec_q, out.sycl_queue)) is None:
            raise ExecutionPlacementError(
                "Input and output allocation queues are not compatible"
            )

        # if an input aliases `out` and would be read directly (no promotion
        # copy), compute into a fresh buffer and copy back at the end
        if ti._array_overlap(x1, out) and buf1_dt is None:
            out = dpt_ext.empty_like(out)

        if ti._array_overlap(x2, out) and buf2_dt is None:
            # should not reach if out is reallocated
            # after being checked against x1
            out = dpt_ext.empty_like(out)

    if order == "A":
        order = (
            "F"
            if all(
                arr.flags.f_contiguous
                for arr in (
                    x1,
                    x2,
                )
            )
            else "C"
        )

    # Four cases: neither/one/both inputs need a promoted temporary copy.
    _manager = SequentialOrderManager[exec_q]
    if buf1_dt is None and buf2_dt is None:
        if out is None:
            if order == "K":
                out = _empty_like_pair_orderK(
                    x1, x2, res_dt, res_shape, res_usm_type, exec_q
                )
            else:
                out = dpt_ext.empty(
                    res_shape,
                    dtype=res_dt,
                    usm_type=res_usm_type,
                    sycl_queue=exec_q,
                    order=order,
                )
        if x1.shape != x1_broadcast_shape:
            x1 = dpt_ext.broadcast_to(x1, x1_broadcast_shape)
        if x2.shape != x2_broadcast_shape:
            x2 = dpt_ext.broadcast_to(x2, x2_broadcast_shape)
        deps_evs = _manager.submitted_events
        ht_dot_ev, dot_ev = tli._dot(
            x1=x1,
            x2=x2,
            batch_dims=len(res_shape[:-2]),
            x1_outer_dims=1,
            x2_outer_dims=1,
            inner_dims=1,
            dst=out,
            sycl_queue=exec_q,
            depends=deps_evs,
        )
        _manager.add_event_pair(ht_dot_ev, dot_ev)
        if not (orig_out is None or orig_out is out):
            # Copy the out data from temporary buffer to original memory
            ht_copy_out_ev, cpy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
                src=out,
                dst=orig_out,
                sycl_queue=exec_q,
                depends=[dot_ev],
            )
            _manager.add_event_pair(ht_copy_out_ev, cpy_ev)
            out = orig_out
        if appended_axes:
            out = dpt_ext.squeeze(out, tuple(appended_axes))
        return out
    elif buf1_dt is None:
        if order == "K":
            buf2 = _empty_like_orderK(x2, buf2_dt)
        else:
            buf2 = dpt_ext.empty_like(x2, dtype=buf2_dt, order=order)
        deps_evs = _manager.submitted_events
        ht_copy_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
            src=x2, dst=buf2, sycl_queue=exec_q, depends=deps_evs
        )
        _manager.add_event_pair(ht_copy_ev, copy_ev)
        if out is None:
            if order == "K":
                out = _empty_like_pair_orderK(
                    x1, buf2, res_dt, res_shape, res_usm_type, exec_q
                )
            else:
                out = dpt_ext.empty(
                    res_shape,
                    dtype=res_dt,
                    usm_type=res_usm_type,
                    sycl_queue=exec_q,
                    order=order,
                )

        if x1.shape != x1_broadcast_shape:
            x1 = dpt_ext.broadcast_to(x1, x1_broadcast_shape)
        if buf2.shape != x2_broadcast_shape:
            buf2 = dpt_ext.broadcast_to(buf2, x2_broadcast_shape)
        ht_dot_ev, dot_ev = tli._dot(
            x1=x1,
            x2=buf2,
            batch_dims=len(res_shape[:-2]),
            x1_outer_dims=1,
            x2_outer_dims=1,
            inner_dims=1,
            dst=out,
            sycl_queue=exec_q,
            depends=[copy_ev],
        )
        _manager.add_event_pair(ht_dot_ev, dot_ev)
        if not (orig_out is None or orig_out is out):
            # Copy the out data from temporary buffer to original memory
            ht_copy_out_ev, cpy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
                src=out,
                dst=orig_out,
                sycl_queue=exec_q,
                depends=[dot_ev],
            )
            _manager.add_event_pair(ht_copy_out_ev, cpy_ev)
            out = orig_out
        if appended_axes:
            out = dpt_ext.squeeze(out, tuple(appended_axes))
        return out

    elif buf2_dt is None:
        if order == "K":
            buf1 = _empty_like_orderK(x1, buf1_dt)
        else:
            buf1 = dpt_ext.empty_like(x1, dtype=buf1_dt, order=order)
        deps_ev = _manager.submitted_events
        ht_copy_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
            src=x1, dst=buf1, sycl_queue=exec_q, depends=deps_ev
        )
        _manager.add_event_pair(ht_copy_ev, copy_ev)
        if out is None:
            if order == "K":
                out = _empty_like_pair_orderK(
                    buf1, x2, res_dt, res_shape, res_usm_type, exec_q
                )
            else:
                out = dpt_ext.empty(
                    res_shape,
                    dtype=res_dt,
                    usm_type=res_usm_type,
                    sycl_queue=exec_q,
                    order=order,
                )

        if buf1.shape != x1_broadcast_shape:
            buf1 = dpt_ext.broadcast_to(buf1, x1_broadcast_shape)
        if x2.shape != x2_broadcast_shape:
            x2 = dpt_ext.broadcast_to(x2, x2_broadcast_shape)
        ht_dot_ev, dot_ev = tli._dot(
            x1=buf1,
            x2=x2,
            batch_dims=len(res_shape[:-2]),
            x1_outer_dims=1,
            x2_outer_dims=1,
            inner_dims=1,
            dst=out,
            sycl_queue=exec_q,
            depends=[copy_ev],
        )
        _manager.add_event_pair(ht_dot_ev, dot_ev)
        if not (orig_out is None or orig_out is out):
            # Copy the out data from temporary buffer to original memory
            ht_copy_out_ev, cpy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
                src=out,
                dst=orig_out,
                sycl_queue=exec_q,
                depends=[dot_ev],
            )
            _manager.add_event_pair(ht_copy_out_ev, cpy_ev)
            out = orig_out
        if appended_axes:
            out = dpt_ext.squeeze(out, tuple(appended_axes))
        return out

    # both inputs need promoted copies; resolve "K" to a concrete layout
    # when both inputs already agree on contiguity
    if order == "K":
        if x1.flags.c_contiguous and x2.flags.c_contiguous:
            order = "C"
        elif x1.flags.f_contiguous and x2.flags.f_contiguous:
            order = "F"
    if order == "K":
        buf1 = _empty_like_orderK(x1, buf1_dt)
    else:
        buf1 = dpt_ext.empty_like(x1, dtype=buf1_dt, order=order)
    deps_ev = _manager.submitted_events
    ht_copy1_ev, copy1_ev = ti._copy_usm_ndarray_into_usm_ndarray(
        src=x1, dst=buf1, sycl_queue=exec_q, depends=deps_ev
    )
    _manager.add_event_pair(ht_copy1_ev, copy1_ev)
    if order == "K":
        buf2 = _empty_like_orderK(x2, buf2_dt)
    else:
        buf2 = dpt_ext.empty_like(x2, dtype=buf2_dt, order=order)
    ht_copy2_ev, copy2_ev = ti._copy_usm_ndarray_into_usm_ndarray(
        src=x2, dst=buf2, sycl_queue=exec_q, depends=deps_ev
    )
    _manager.add_event_pair(ht_copy2_ev, copy2_ev)
    if out is None:
        if order == "K":
            out = _empty_like_pair_orderK(
                buf1, buf2, res_dt, res_shape, res_usm_type, exec_q
            )
        else:
            out = dpt_ext.empty(
                res_shape,
                dtype=res_dt,
                usm_type=res_usm_type,
                sycl_queue=exec_q,
                order=order,
            )

    if buf1.shape != x1_broadcast_shape:
        buf1 = dpt_ext.broadcast_to(buf1, x1_broadcast_shape)
    if buf2.shape != x2_broadcast_shape:
        buf2 = dpt_ext.broadcast_to(buf2, x2_broadcast_shape)
    ht_, dot_ev = tli._dot(
        x1=buf1,
        x2=buf2,
        batch_dims=len(res_shape[:-2]),
        x1_outer_dims=1,
        x2_outer_dims=1,
        inner_dims=1,
        dst=out,
        sycl_queue=exec_q,
        depends=[copy1_ev, copy2_ev],
    )
    _manager.add_event_pair(ht_, dot_ev)
    if appended_axes:
        out = dpt_ext.squeeze(out, tuple(appended_axes))
    return out
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file defines kernels for the vector dot product. +//===----------------------------------------------------------------------===// + +#pragma once + +#include +#include +#include +#include +#include + +#include "kernels/dpctl_tensor_types.hpp" +#include "kernels/reductions.hpp" +#include "utils/offset_utils.hpp" +#include "utils/sycl_alloc_utils.hpp" +#include "utils/sycl_utils.hpp" +#include "utils/type_utils.hpp" + +namespace dpctl::tensor::kernels +{ + +using dpctl::tensor::ssize_t; +namespace su_ns = dpctl::tensor::sycl_utils; + +template +struct SequentialDotProduct +{ +private: + const lhsT *lhs_ = nullptr; + const rhsT *rhs_ = nullptr; + outT *out_ = nullptr; + BatchIndexerT batch_indexer_; + RedIndexerT reduced_dims_indexer_; + std::size_t reduction_max_gid_ = 0; + +public: + SequentialDotProduct(const lhsT *lhs, + const rhsT *rhs, + outT *out, + BatchIndexerT batch_indexer, + RedIndexerT reduced_dims_indexer, + std::size_t reduction_size) + : lhs_(lhs), rhs_(rhs), out_(out), batch_indexer_(batch_indexer), + reduced_dims_indexer_(reduced_dims_indexer), + reduction_max_gid_(reduction_size) + { + } + + void operator()(sycl::id<1> id) const + { + + auto const &batch_offsets = batch_indexer_(id[0]); + const ssize_t &lhs_batch_offset = 
batch_offsets.get_first_offset(); + const ssize_t &rhs_batch_offset = batch_offsets.get_second_offset(); + const ssize_t &out_batch_offset = batch_offsets.get_third_offset(); + + outT red_val(0); + for (std::size_t m = 0; m < reduction_max_gid_; ++m) { + auto reduction_offsets = reduced_dims_indexer_(m); + auto lhs_reduction_offset = reduction_offsets.get_first_offset(); + auto rhs_reduction_offset = reduction_offsets.get_second_offset(); + + using dpctl::tensor::type_utils::convert_impl; + red_val += convert_impl( + lhs_[lhs_batch_offset + lhs_reduction_offset]) * + convert_impl( + rhs_[rhs_batch_offset + rhs_reduction_offset]); + } + + out_[out_batch_offset] = red_val; + } +}; + +template +struct DotProductFunctor +{ +private: + const lhsT *lhs_ = nullptr; + const rhsT *rhs_ = nullptr; + outT *out_ = nullptr; + ReductionOpT reduction_op_; + BatchIndexerT batch_indexer_; + RedIndexerT reduced_dims_indexer_; + std::size_t reduction_max_gid_ = 0; + std::size_t batches_ = 1; + std::size_t reductions_per_wi = 16; + +public: + DotProductFunctor(const lhsT *lhs, + const rhsT *rhs, + outT *res, + const ReductionOpT &reduction_op, + const BatchIndexerT &batch_indexer, + const RedIndexerT &arg_reduced_dims_indexer, + std::size_t reduction_size, + std::size_t iteration_size, + std::size_t reduction_size_per_wi) + : lhs_(lhs), rhs_(rhs), out_(res), reduction_op_(reduction_op), + batch_indexer_(batch_indexer), + reduced_dims_indexer_(arg_reduced_dims_indexer), + reduction_max_gid_(reduction_size), batches_(iteration_size), + reductions_per_wi(reduction_size_per_wi) + { + } + + void operator()(sycl::nd_item<1> it) const + { + const std::size_t batch_id = it.get_group(0) % batches_; + const std::size_t reduction_batch_id = it.get_group(0) / batches_; + + const std::size_t reduction_lid = it.get_local_id(0); + const std::size_t wg = + it.get_local_range(0); // 0 <= reduction_lid < wg + + // work-items operate over input with indices + // inp_data_id = reduction_batch_id * wg * 
reductions_per_wi + m * wg + // + reduction_lid + // for 0 <= m < reductions_per_wi + // for each input + + const auto &batch_offsets_ = batch_indexer_(batch_id); + const auto &lhs_batch_offset = batch_offsets_.get_first_offset(); + const auto &rhs_batch_offset = batch_offsets_.get_second_offset(); + const auto &out_batch_offset = batch_offsets_.get_third_offset(); + + outT local_red_val(0); + std::size_t arg_reduce_gid0 = + reduction_lid + reduction_batch_id * wg * reductions_per_wi; + std::size_t arg_reduce_gid_max = std::min( + reduction_max_gid_, arg_reduce_gid0 + reductions_per_wi * wg); + + for (std::size_t arg_reduce_gid = arg_reduce_gid0; + arg_reduce_gid < arg_reduce_gid_max; arg_reduce_gid += wg) + { + auto reduction_offsets_ = reduced_dims_indexer_(arg_reduce_gid); + const auto &lhs_reduction_offset = + reduction_offsets_.get_first_offset(); + const auto &rhs_reduction_offset = + reduction_offsets_.get_second_offset(); + + using dpctl::tensor::type_utils::convert_impl; + outT val = convert_impl( + lhs_[lhs_batch_offset + lhs_reduction_offset]) * + convert_impl( + rhs_[rhs_batch_offset + rhs_reduction_offset]); + + local_red_val += val; + } + + auto work_group = it.get_group(); + outT red_val_over_wg = sycl::reduce_over_group( + work_group, local_red_val, outT(0), reduction_op_); + + if (work_group.leader()) { + sycl::atomic_ref + res_ref(out_[out_batch_offset]); + res_ref += red_val_over_wg; + } + } +}; + +template +struct DotProductCustomFunctor +{ +private: + const lhsT *lhs_ = nullptr; + const rhsT *rhs_ = nullptr; + outT *out_ = nullptr; + ReductionOpT reduction_op_; + BatchIndexerT batch_indexer_; + RedIndexerT reduced_dims_indexer_; + SlmT local_mem_; + std::size_t reduction_max_gid_ = 0; + std::size_t batches_ = 1; + std::size_t reductions_per_wi = 16; + +public: + DotProductCustomFunctor(const lhsT *lhs, + const rhsT *rhs, + outT *res, + const ReductionOpT &reduction_op, + const BatchIndexerT &batch_indexer, + const RedIndexerT 
&arg_reduced_dims_indexer, + SlmT local_mem, + std::size_t reduction_size, + std::size_t iteration_size, + std::size_t reduction_size_per_wi) + : lhs_(lhs), rhs_(rhs), out_(res), reduction_op_(reduction_op), + batch_indexer_(batch_indexer), + reduced_dims_indexer_(arg_reduced_dims_indexer), + local_mem_(local_mem), reduction_max_gid_(reduction_size), + batches_(iteration_size), reductions_per_wi(reduction_size_per_wi) + { + } + + void operator()(sycl::nd_item<1> it) const + { + const std::size_t batch_id = it.get_group(0) % batches_; + const std::size_t reduction_batch_id = it.get_group(0) / batches_; + + const std::size_t reduction_lid = it.get_local_id(0); + const std::size_t wg = + it.get_local_range(0); // 0 <= reduction_lid < wg + + // work-items operate over input with indices + // inp_data_id = reduction_batch_id * wg * reductions_per_wi + m * wg + // + reduction_lid + // for 0 <= m < reductions_per_wi + // for each input + + const auto &batch_offsets_ = batch_indexer_(batch_id); + const auto &lhs_batch_offset = batch_offsets_.get_first_offset(); + const auto &rhs_batch_offset = batch_offsets_.get_second_offset(); + const auto &out_batch_offset = batch_offsets_.get_third_offset(); + + outT local_red_val(0); + std::size_t arg_reduce_gid0 = + reduction_lid + reduction_batch_id * wg * reductions_per_wi; + std::size_t arg_reduce_gid_max = std::min( + reduction_max_gid_, arg_reduce_gid0 + reductions_per_wi * wg); + + for (std::size_t arg_reduce_gid = arg_reduce_gid0; + arg_reduce_gid < arg_reduce_gid_max; arg_reduce_gid += wg) + { + auto reduction_offsets_ = reduced_dims_indexer_(arg_reduce_gid); + const auto &lhs_reduction_offset = + reduction_offsets_.get_first_offset(); + const auto &rhs_reduction_offset = + reduction_offsets_.get_second_offset(); + + using dpctl::tensor::type_utils::convert_impl; + outT val = convert_impl( + lhs_[lhs_batch_offset + lhs_reduction_offset]) * + convert_impl( + rhs_[rhs_batch_offset + rhs_reduction_offset]); + + local_red_val += 
val; + } + + auto work_group = it.get_group(); + outT red_val_over_wg = su_ns::custom_reduce_over_group( + work_group, local_mem_, local_red_val, reduction_op_); + + if (work_group.leader()) { + sycl::atomic_ref + res_ref(out_[out_batch_offset]); + res_ref += red_val_over_wg; + } + } +}; + +template < + typename lhsTy, + typename rhsTy, + typename resTy, + typename BatchIndexerT, + typename RedIndexerT, + template + class kernel_name_token> +sycl::event sequential_dot_product(sycl::queue &exec_q, + const lhsTy *lhs, + const rhsTy *rhs, + resTy *res, + std::size_t batches, + std::size_t reduction_nelems, + const BatchIndexerT &batch_indexer, + const RedIndexerT &reduction_indexer, + const std::vector &depends) +{ + sycl::event dot_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + cgh.parallel_for< + kernel_name_token>( + sycl::range<1>(batches), + SequentialDotProduct(lhs, rhs, res, batch_indexer, + reduction_indexer, + reduction_nelems)); + }); + + return dot_ev; +} + +template + class kernel_name_token> +sycl::event submit_atomic_dot_product(sycl::queue &exec_q, + const lhsTy *lhs, + const rhsTy *rhs, + resTy *res, + std::size_t wg, + std::size_t batches, + std::size_t reduction_nelems, + std::size_t reductions_per_wi, + std::size_t reduction_groups, + const BatchIndexerT &batch_indexer, + const RedIndexerT &reduction_indexer, + const std::vector &depends) +{ + sycl::event dot_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + auto globalRange = sycl::range<1>{batches * reduction_groups * wg}; + auto localRange = sycl::range<1>{wg}; + auto ndRange = sycl::nd_range<1>(globalRange, localRange); + + if constexpr (can_use_reduce_over_group::value) { + using KernelName = + class kernel_name_token; + + cgh.parallel_for( + ndRange, DotProductFunctor( + lhs, rhs, res, ReductionOpT(), batch_indexer, + reduction_indexer, reduction_nelems, batches, + reductions_per_wi)); + } + else { + using SlmT = sycl::local_accessor; + 
SlmT local_memory = SlmT(localRange, cgh); + + using KernelName = class custom_reduction_wrapper>; + + cgh.parallel_for( + ndRange, + DotProductCustomFunctor( + lhs, rhs, res, ReductionOpT(), batch_indexer, + reduction_indexer, local_memory, reduction_nelems, batches, + reductions_per_wi)); + } + }); + return dot_ev; +} + +template +class dot_product_seq_krn; + +template +class dot_product_init_krn; + +template +class dot_product_krn; + +typedef sycl::event (*dot_product_impl_fn_ptr_t)( + sycl::queue &, + std::size_t, + std::size_t, + const char *, + const char *, + char *, + int, + const ssize_t *, + ssize_t, + ssize_t, + ssize_t, + int, + const ssize_t *, + ssize_t, + ssize_t, + const std::vector &); + +template +sycl::event dot_product_impl(sycl::queue &exec_q, + std::size_t batches, + std::size_t reduction_nelems, + const char *lhs_cp, + const char *rhs_cp, + char *res_cp, + int batch_nd, + const ssize_t *batch_shape_and_strides, + ssize_t batch_lhs_offset, + ssize_t batch_rhs_offset, + ssize_t batch_res_offset, + int red_nd, + const ssize_t *reduction_shape_stride, + ssize_t reduction_lhs_offset, + ssize_t reduction_rhs_offset, + const std::vector &depends = {}) +{ + const lhsTy *lhs_tp = reinterpret_cast(lhs_cp); + const rhsTy *rhs_tp = reinterpret_cast(rhs_cp); + resTy *res_tp = reinterpret_cast(res_cp); + + const sycl::device &d = exec_q.get_device(); + const auto &sg_sizes = d.get_info(); + std::size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes); + + if (reduction_nelems < wg) { + using InputOutputBatchIndexerT = + dpctl::tensor::offset_utils::ThreeOffsets_StridedIndexer; + using ReductionIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_StridedIndexer; + + const InputOutputBatchIndexerT inp_out_batch_indexer{ + batch_nd, batch_lhs_offset, batch_rhs_offset, batch_res_offset, + batch_shape_and_strides}; + const ReductionIndexerT reduction_indexer{red_nd, reduction_lhs_offset, + reduction_rhs_offset, + reduction_shape_stride}; + + 
sycl::event dot_ev = + sequential_dot_product( + exec_q, lhs_tp, rhs_tp, res_tp, batches, reduction_nelems, + inp_out_batch_indexer, reduction_indexer, depends); + + return dot_ev; + } + else { + sycl::event res_init_ev = exec_q.submit([&](sycl::handler &cgh) { + using IndexerT = + dpctl::tensor::offset_utils::UnpackedStridedIndexer; + + const ssize_t *const &res_shape = batch_shape_and_strides; + const ssize_t *const &res_strides = + batch_shape_and_strides + 3 * batch_nd; + const IndexerT res_indexer(batch_nd, batch_res_offset, res_shape, + res_strides); + using InitKernelName = + class dot_product_init_krn; + cgh.depends_on(depends); + + cgh.parallel_for( + sycl::range<1>(batches), [=](sycl::id<1> id) { + auto res_offset = res_indexer(id[0]); + res_tp[res_offset] = 0; + }); + }); + + using ReductionOpT = sycl::plus; + + using BatchIndexerT = + dpctl::tensor::offset_utils::ThreeOffsets_StridedIndexer; + using ReductionIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_StridedIndexer; + + const BatchIndexerT batch_indexer{batch_nd, batch_lhs_offset, + batch_rhs_offset, batch_res_offset, + batch_shape_and_strides}; + const ReductionIndexerT reduction_indexer{red_nd, reduction_lhs_offset, + reduction_rhs_offset, + reduction_shape_stride}; + + static constexpr std::size_t preferred_reductions_per_wi = + 4; // determined experimentally + std::size_t reductions_per_wi = + (reduction_nelems < preferred_reductions_per_wi * wg) + ? 
std::max(1, (reduction_nelems + wg - 1) / wg) + : preferred_reductions_per_wi; + + std::size_t reduction_groups = + (reduction_nelems + reductions_per_wi * wg - 1) / + (reductions_per_wi * wg); + + sycl::event dot_ev = + submit_atomic_dot_product( + exec_q, lhs_tp, rhs_tp, res_tp, wg, batches, reduction_nelems, + reductions_per_wi, reduction_groups, batch_indexer, + reduction_indexer, {res_init_ev}); + + return dot_ev; + } +} + +typedef sycl::event (*dot_product_contig_impl_fn_ptr_t)( + sycl::queue &, + std::size_t, + std::size_t, + const char *, + const char *, + char *, + ssize_t, + ssize_t, + ssize_t, + ssize_t, + ssize_t, + const std::vector &); + +template +sycl::event + dot_product_contig_impl(sycl::queue &exec_q, + std::size_t batches, + std::size_t reduction_nelems, + const char *lhs_cp, + const char *rhs_cp, + char *res_cp, + ssize_t batch_lhs_offset, + ssize_t batch_rhs_offset, + ssize_t batch_res_offset, + ssize_t reduction_lhs_offset, + ssize_t reduction_rhs_offset, + const std::vector &depends = {}) +{ + const lhsTy *lhs_tp = reinterpret_cast(lhs_cp) + + batch_lhs_offset + reduction_lhs_offset; + const rhsTy *rhs_tp = reinterpret_cast(rhs_cp) + + batch_rhs_offset + reduction_rhs_offset; + resTy *res_tp = reinterpret_cast(res_cp) + batch_res_offset; + + const sycl::device &d = exec_q.get_device(); + const auto &sg_sizes = d.get_info(); + std::size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes); + + if (reduction_nelems < wg) { + using InputBatchIndexerT = + dpctl::tensor::offset_utils::Strided1DIndexer; + using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputBatchIndexerT = + dpctl::tensor::offset_utils::ThreeOffsets_CombinedIndexer< + InputBatchIndexerT, InputBatchIndexerT, NoOpIndexerT>; + using ReductionIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + NoOpIndexerT, NoOpIndexerT>; + + const InputBatchIndexerT inp_batch_indexer{/* size */ batches, + /* step */ reduction_nelems}; + const 
InputOutputBatchIndexerT inp_out_batch_indexer{ + inp_batch_indexer, inp_batch_indexer, NoOpIndexerT{}}; + static constexpr ReductionIndexerT reduction_indexer{NoOpIndexerT{}, + NoOpIndexerT{}}; + + sycl::event dot_ev = + sequential_dot_product( + exec_q, lhs_tp, rhs_tp, res_tp, batches, reduction_nelems, + inp_out_batch_indexer, reduction_indexer, depends); + + return dot_ev; + } + else { + sycl::event res_init_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + cgh.fill(res_tp, resTy(0), batches); + }); + + using ReductionOpT = sycl::plus; + + using InputBatchIndexerT = + dpctl::tensor::offset_utils::Strided1DIndexer; + using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputBatchIndexerT = + dpctl::tensor::offset_utils::ThreeOffsets_CombinedIndexer< + InputBatchIndexerT, InputBatchIndexerT, NoOpIndexerT>; + using ReductionIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + NoOpIndexerT, NoOpIndexerT>; + + const InputBatchIndexerT inp_batch_indexer{/* size */ batches, + /* step */ reduction_nelems}; + const InputOutputBatchIndexerT inp_out_batch_indexer{ + inp_batch_indexer, inp_batch_indexer, NoOpIndexerT{}}; + static constexpr ReductionIndexerT reduction_indexer{NoOpIndexerT{}, + NoOpIndexerT{}}; + + static constexpr std::size_t preferred_reductions_per_wi = + 4; // determined experimentally + std::size_t reductions_per_wi = + (reduction_nelems < preferred_reductions_per_wi * wg) + ? 
std::max(1, (reduction_nelems + wg - 1) / wg) + : preferred_reductions_per_wi; + + std::size_t reduction_groups = + (reduction_nelems + reductions_per_wi * wg - 1) / + (reductions_per_wi * wg); + + sycl::event dot_ev = + submit_atomic_dot_product( + exec_q, lhs_tp, rhs_tp, res_tp, wg, batches, reduction_nelems, + reductions_per_wi, reduction_groups, inp_out_batch_indexer, + reduction_indexer, {res_init_ev}); + + return dot_ev; + } +} + +template +struct DotProductNoAtomicFunctor +{ +private: + const lhsT *lhs_ = nullptr; + const rhsT *rhs_ = nullptr; + outT *out_ = nullptr; + ReductionOpT reduction_op_; + BatchIndexerT batch_indexer_; + RedIndexerT reduced_dims_indexer_; + std::size_t reduction_max_gid_ = 0; + std::size_t batches_ = 1; + std::size_t reductions_per_wi = 16; + +public: + DotProductNoAtomicFunctor(const lhsT *lhs, + const rhsT *rhs, + outT *res, + const ReductionOpT &reduction_op, + const BatchIndexerT &batch_indexer, + const RedIndexerT &arg_reduced_dims_indexer, + std::size_t reduction_size, + std::size_t iteration_size, + std::size_t reduction_size_per_wi) + : lhs_(lhs), rhs_(rhs), out_(res), reduction_op_(reduction_op), + batch_indexer_(batch_indexer), + reduced_dims_indexer_(arg_reduced_dims_indexer), + reduction_max_gid_(reduction_size), batches_(iteration_size), + reductions_per_wi(reduction_size_per_wi) + { + } + + void operator()(sycl::nd_item<1> it) const + { + const std::size_t reduction_lid = it.get_local_id(0); + const std::size_t wg = + it.get_local_range(0); // 0 <= reduction_lid < wg + + const std::size_t batch_id = it.get_group(0) % batches_; + const std::size_t reduction_batch_id = it.get_group(0) / batches_; + const std::size_t n_reduction_groups = it.get_group_range(0) / batches_; + + // work-items operate over input with indices + // inp_data_id = reduction_batch_id * wg * reductions_per_wi + m * wg + // + reduction_lid + // for 0 <= m < reductions_per_wi + // for each input + + const auto &batch_offsets_ = 
batch_indexer_(batch_id); + const auto &lhs_batch_offset = batch_offsets_.get_first_offset(); + const auto &rhs_batch_offset = batch_offsets_.get_second_offset(); + const auto &out_batch_offset = batch_offsets_.get_third_offset(); + + outT local_red_val(0); + std::size_t arg_reduce_gid0 = + reduction_lid + reduction_batch_id * wg * reductions_per_wi; + std::size_t arg_reduce_gid_max = std::min( + reduction_max_gid_, arg_reduce_gid0 + reductions_per_wi * wg); + + for (std::size_t arg_reduce_gid = arg_reduce_gid0; + arg_reduce_gid < arg_reduce_gid_max; arg_reduce_gid += wg) + { + auto reduction_offsets_ = reduced_dims_indexer_(arg_reduce_gid); + const auto &lhs_reduction_offset = + reduction_offsets_.get_first_offset(); + const auto &rhs_reduction_offset = + reduction_offsets_.get_second_offset(); + + using dpctl::tensor::type_utils::convert_impl; + outT val = convert_impl( + lhs_[lhs_batch_offset + lhs_reduction_offset]) * + convert_impl( + rhs_[rhs_batch_offset + rhs_reduction_offset]); + + local_red_val += val; + } + + auto work_group = it.get_group(); + + using RedOpT = typename std::conditional, + sycl::logical_or, + sycl::plus>::type; + outT red_val_over_wg = sycl::reduce_over_group( + work_group, local_red_val, outT(0), RedOpT()); + + if (work_group.leader()) { + // each group writes to a different memory location + out_[out_batch_offset * n_reduction_groups + reduction_batch_id] = + red_val_over_wg; + } + } +}; + +template +struct DotProductNoAtomicCustomFunctor +{ +private: + const lhsT *lhs_ = nullptr; + const rhsT *rhs_ = nullptr; + outT *out_ = nullptr; + ReductionOpT reduction_op_; + BatchIndexerT batch_indexer_; + RedIndexerT reduced_dims_indexer_; + SlmT local_mem_; + std::size_t reduction_max_gid_ = 0; + std::size_t batches_ = 1; + std::size_t reductions_per_wi = 16; + +public: + DotProductNoAtomicCustomFunctor(const lhsT *lhs, + const rhsT *rhs, + outT *res, + const ReductionOpT &reduction_op, + const BatchIndexerT &batch_indexer, + const RedIndexerT 
&arg_reduced_dims_indexer, + SlmT local_mem, + std::size_t reduction_size, + std::size_t iteration_size, + std::size_t reduction_size_per_wi) + : lhs_(lhs), rhs_(rhs), out_(res), reduction_op_(reduction_op), + batch_indexer_(batch_indexer), + reduced_dims_indexer_(arg_reduced_dims_indexer), + local_mem_(local_mem), reduction_max_gid_(reduction_size), + batches_(iteration_size), reductions_per_wi(reduction_size_per_wi) + { + } + + void operator()(sycl::nd_item<1> it) const + { + const std::size_t reduction_lid = it.get_local_id(0); + const std::size_t wg = + it.get_local_range(0); // 0 <= reduction_lid < wg + + const std::size_t batch_id = it.get_group(0) % batches_; + const std::size_t reduction_batch_id = it.get_group(0) / batches_; + const std::size_t n_reduction_groups = it.get_group_range(0) / batches_; + + // work-items operate over input with indices + // inp_data_id = reduction_batch_id * wg * reductions_per_wi + m * wg + // + reduction_lid + // for 0 <= m < reductions_per_wi + // for each input + + const auto &batch_offsets_ = batch_indexer_(batch_id); + const auto &lhs_batch_offset = batch_offsets_.get_first_offset(); + const auto &rhs_batch_offset = batch_offsets_.get_second_offset(); + const auto &out_batch_offset = batch_offsets_.get_third_offset(); + + outT local_red_val(0); + std::size_t arg_reduce_gid0 = + reduction_lid + reduction_batch_id * wg * reductions_per_wi; + std::size_t arg_reduce_gid_max = std::min( + reduction_max_gid_, arg_reduce_gid0 + reductions_per_wi * wg); + + for (std::size_t arg_reduce_gid = arg_reduce_gid0; + arg_reduce_gid < arg_reduce_gid_max; arg_reduce_gid += wg) + { + auto reduction_offsets_ = reduced_dims_indexer_(arg_reduce_gid); + const auto &lhs_reduction_offset = + reduction_offsets_.get_first_offset(); + const auto &rhs_reduction_offset = + reduction_offsets_.get_second_offset(); + + using dpctl::tensor::type_utils::convert_impl; + outT val = convert_impl( + lhs_[lhs_batch_offset + lhs_reduction_offset]) * + 
convert_impl( + rhs_[rhs_batch_offset + rhs_reduction_offset]); + + local_red_val += val; + } + + auto work_group = it.get_group(); + + outT red_val_over_wg = su_ns::custom_reduce_over_group( + work_group, local_mem_, local_red_val, reduction_op_); + + if (work_group.leader()) { + // each group writes to a different memory location + out_[out_batch_offset * n_reduction_groups + reduction_batch_id] = + red_val_over_wg; + } + } +}; + +template + class kernel_name_token> +sycl::event + submit_no_atomic_dot_product(sycl::queue &exec_q, + const lhsTy *lhs, + const rhsTy *rhs, + resTy *res, + std::size_t wg, + std::size_t batches, + std::size_t reduction_nelems, + std::size_t reductions_per_wi, + std::size_t reduction_groups, + const BatchIndexerT &batch_indexer, + const RedIndexerT &reduction_indexer, + const std::vector &depends) +{ + sycl::event dot_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + auto globalRange = sycl::range<1>{batches * reduction_groups * wg}; + auto localRange = sycl::range<1>{wg}; + auto ndRange = sycl::nd_range<1>(globalRange, localRange); + + if constexpr (can_use_reduce_over_group::value) { + using KernelName = + class kernel_name_token; + + cgh.parallel_for( + ndRange, + DotProductNoAtomicFunctor( + lhs, rhs, res, ReductionOpT(), batch_indexer, + reduction_indexer, reduction_nelems, batches, + reductions_per_wi)); + } + else { + using SlmT = sycl::local_accessor; + SlmT local_memory = SlmT(localRange, cgh); + + using KernelName = class custom_reduction_wrapper>; + + cgh.parallel_for( + ndRange, + DotProductNoAtomicCustomFunctor( + lhs, rhs, res, ReductionOpT(), batch_indexer, + reduction_indexer, local_memory, reduction_nelems, batches, + reductions_per_wi)); + } + }); + return dot_ev; +} + +template +class dot_product_tree_krn; + +template +class dot_product_tree_reduction_krn; + +template +sycl::event dot_product_tree_impl(sycl::queue &exec_q, + std::size_t batches, + std::size_t reduction_nelems, + const char 
*lhs_cp, + const char *rhs_cp, + char *res_cp, + int batch_nd, + const ssize_t *batch_shape_and_strides, + ssize_t batch_lhs_offset, + ssize_t batch_rhs_offset, + ssize_t batch_res_offset, + int red_nd, + const ssize_t *reduction_shape_stride, + ssize_t reduction_lhs_offset, + ssize_t reduction_rhs_offset, + const std::vector &depends = {}) +{ + const lhsTy *lhs_tp = reinterpret_cast(lhs_cp); + const rhsTy *rhs_tp = reinterpret_cast(rhs_cp); + resTy *res_tp = reinterpret_cast(res_cp); + + const sycl::device &d = exec_q.get_device(); + const auto &sg_sizes = d.get_info(); + std::size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes); + + if (reduction_nelems < wg) { + using InputOutputBatchIndexerT = + dpctl::tensor::offset_utils::ThreeOffsets_StridedIndexer; + using ReductionIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_StridedIndexer; + + const InputOutputBatchIndexerT inp_out_batch_indexer{ + batch_nd, batch_lhs_offset, batch_rhs_offset, batch_res_offset, + batch_shape_and_strides}; + const ReductionIndexerT reduction_indexer{red_nd, reduction_lhs_offset, + reduction_rhs_offset, + reduction_shape_stride}; + + sycl::event dot_ev = + sequential_dot_product( + exec_q, lhs_tp, rhs_tp, res_tp, batches, reduction_nelems, + inp_out_batch_indexer, reduction_indexer, depends); + + return dot_ev; + } + + static constexpr std::size_t preferred_reductions_per_wi = 8; + // prevents running out of resources on CPU + std::size_t max_wg = reduction_detail::get_work_group_size(d); + + using ReductionOpT = typename std::conditional, + sycl::logical_or, + sycl::plus>::type; + + std::size_t reductions_per_wi(preferred_reductions_per_wi); + if (reduction_nelems <= preferred_reductions_per_wi * max_wg) { + using BatchIndexerT = + dpctl::tensor::offset_utils::ThreeOffsets_StridedIndexer; + using ReductionIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_StridedIndexer; + + const BatchIndexerT batch_indexer{batch_nd, batch_lhs_offset, + batch_rhs_offset, 
batch_res_offset, + batch_shape_and_strides}; + const ReductionIndexerT reduction_indexer{red_nd, reduction_lhs_offset, + reduction_rhs_offset, + reduction_shape_stride}; + + if (batches == 1) { + // increase GPU occupancy + wg = max_wg; + } + reductions_per_wi = + std::max(1, (reduction_nelems + wg - 1) / wg); + + std::size_t reduction_groups = + (reduction_nelems + reductions_per_wi * wg - 1) / + (reductions_per_wi * wg); + assert(reduction_groups == 1); + + sycl::event dot_ev = + submit_no_atomic_dot_product( + exec_q, lhs_tp, rhs_tp, res_tp, wg, batches, reduction_nelems, + reductions_per_wi, reduction_groups, batch_indexer, + reduction_indexer, depends); + + return dot_ev; + } + else { + static constexpr resTy identity_val = + sycl::known_identity::value; + + // more than one work-groups is needed, requires a temporary + std::size_t reduction_groups = + (reduction_nelems + preferred_reductions_per_wi * wg - 1) / + (preferred_reductions_per_wi * wg); + assert(reduction_groups > 1); + + std::size_t second_iter_reduction_groups_ = + (reduction_groups + preferred_reductions_per_wi * wg - 1) / + (preferred_reductions_per_wi * wg); + + // returns unique_ptr + auto partially_reduced_tmp_owner = + dpctl::tensor::alloc_utils::smart_malloc_device( + batches * (reduction_groups + second_iter_reduction_groups_), + exec_q); + + resTy *partially_reduced_tmp = partially_reduced_tmp_owner.get(); + resTy *partially_reduced_tmp2 = + partially_reduced_tmp + reduction_groups * batches; + + sycl::event first_reduction_ev; + { + using LhsIndexerT = dpctl::tensor::offset_utils::StridedIndexer; + using RhsIndexerT = + dpctl::tensor::offset_utils::UnpackedStridedIndexer; + using ResIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputBatchIndexerT = + dpctl::tensor::offset_utils::ThreeOffsets_CombinedIndexer< + LhsIndexerT, RhsIndexerT, ResIndexerT>; + using ReductionIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_StridedIndexer; + + const LhsIndexerT 
lhs_indexer(batch_nd, batch_lhs_offset, + batch_shape_and_strides); + const RhsIndexerT rhs_indexer( + batch_nd, batch_rhs_offset, batch_shape_and_strides, + batch_shape_and_strides + 2 * batch_nd); + static constexpr ResIndexerT noop_tmp_indexer{}; + + const InputOutputBatchIndexerT in_out_iter_indexer{ + lhs_indexer, rhs_indexer, noop_tmp_indexer}; + const ReductionIndexerT reduction_indexer{ + red_nd, reduction_lhs_offset, reduction_rhs_offset, + reduction_shape_stride}; + + first_reduction_ev = submit_no_atomic_dot_product< + lhsTy, rhsTy, resTy, ReductionOpT, InputOutputBatchIndexerT, + ReductionIndexerT, dot_product_tree_krn>( + exec_q, lhs_tp, rhs_tp, partially_reduced_tmp, wg, batches, + reduction_nelems, preferred_reductions_per_wi, reduction_groups, + in_out_iter_indexer, reduction_indexer, depends); + } + + std::size_t remaining_reduction_nelems = reduction_groups; + + resTy *temp_arg = partially_reduced_tmp; + resTy *temp2_arg = partially_reduced_tmp2; + sycl::event dependent_ev = first_reduction_ev; + + while (remaining_reduction_nelems > + preferred_reductions_per_wi * max_wg) { + std::size_t reduction_groups_ = + (remaining_reduction_nelems + preferred_reductions_per_wi * wg - + 1) / + (preferred_reductions_per_wi * wg); + assert(reduction_groups_ > 1); + + using InputIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer; + using ResIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + InputIndexerT, ResIndexerT>; + using ReductionIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + + const InputIndexerT inp_indexer{/* size */ batches, + /* step */ reduction_groups_}; + static constexpr ResIndexerT res_iter_indexer{}; + + const InputOutputIterIndexerT in_out_iter_indexer{inp_indexer, + res_iter_indexer}; + static constexpr ReductionIndexerT reduction_indexer{}; + + sycl::event partial_reduction_ev = + dpctl::tensor::kernels::submit_no_atomic_reduction< 
+ resTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT, dot_product_tree_reduction_krn>( + exec_q, temp_arg, temp2_arg, identity_val, wg, batches, + remaining_reduction_nelems, preferred_reductions_per_wi, + reduction_groups_, in_out_iter_indexer, reduction_indexer, + {dependent_ev}); + + remaining_reduction_nelems = reduction_groups_; + std::swap(temp_arg, temp2_arg); + dependent_ev = std::move(partial_reduction_ev); + } + + // final reduction to res + using InputIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer; + using ResIndexerT = dpctl::tensor::offset_utils::UnpackedStridedIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + InputIndexerT, ResIndexerT>; + using ReductionIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + + const InputIndexerT inp_indexer{/* size */ batches, + /* step */ remaining_reduction_nelems}; + const ResIndexerT res_iter_indexer{ + batch_nd, batch_res_offset, + /* shape */ batch_shape_and_strides, + /* strides */ batch_shape_and_strides + 2 * batch_nd}; + + const InputOutputIterIndexerT in_out_iter_indexer{inp_indexer, + res_iter_indexer}; + static constexpr ReductionIndexerT reduction_indexer{}; + + wg = max_wg; + reductions_per_wi = std::max( + 1, (remaining_reduction_nelems + wg - 1) / wg); + + reduction_groups = + (remaining_reduction_nelems + reductions_per_wi * wg - 1) / + (reductions_per_wi * wg); + assert(reduction_groups == 1); + + sycl::event final_reduction_ev = + dpctl::tensor::kernels::submit_no_atomic_reduction< + resTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT, dot_product_tree_reduction_krn>( + exec_q, temp_arg, res_tp, identity_val, wg, batches, + remaining_reduction_nelems, reductions_per_wi, reduction_groups, + in_out_iter_indexer, reduction_indexer, {dependent_ev}); + + // transfer ownership of USM allocation to host_task + sycl::event cleanup_host_task_event = + dpctl::tensor::alloc_utils::async_smart_free( + 
exec_q, {final_reduction_ev}, partially_reduced_tmp_owner); + + return cleanup_host_task_event; + } +} + +template +sycl::event + dot_product_contig_tree_impl(sycl::queue &exec_q, + std::size_t batches, + std::size_t reduction_nelems, + const char *lhs_cp, + const char *rhs_cp, + char *res_cp, + ssize_t batch_lhs_offset, + ssize_t batch_rhs_offset, + ssize_t batch_res_offset, + ssize_t reduction_lhs_offset, + ssize_t reduction_rhs_offset, + const std::vector &depends = {}) +{ + const lhsTy *lhs_tp = reinterpret_cast(lhs_cp) + + batch_lhs_offset + reduction_lhs_offset; + const rhsTy *rhs_tp = reinterpret_cast(rhs_cp) + + batch_rhs_offset + reduction_rhs_offset; + resTy *res_tp = reinterpret_cast(res_cp) + batch_res_offset; + + const sycl::device &d = exec_q.get_device(); + const auto &sg_sizes = d.get_info(); + std::size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes); + + if (reduction_nelems < wg) { + using InputBatchIndexerT = + dpctl::tensor::offset_utils::Strided1DIndexer; + using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputBatchIndexerT = + dpctl::tensor::offset_utils::ThreeOffsets_CombinedIndexer< + InputBatchIndexerT, InputBatchIndexerT, NoOpIndexerT>; + using ReductionIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + NoOpIndexerT, NoOpIndexerT>; + + const InputBatchIndexerT inp_batch_indexer{/* size */ batches, + /* step */ reduction_nelems}; + const InputOutputBatchIndexerT inp_out_batch_indexer{ + inp_batch_indexer, inp_batch_indexer, NoOpIndexerT{}}; + static constexpr ReductionIndexerT reduction_indexer{NoOpIndexerT{}, + NoOpIndexerT{}}; + + sycl::event dot_ev = + sequential_dot_product( + exec_q, lhs_tp, rhs_tp, res_tp, batches, reduction_nelems, + inp_out_batch_indexer, reduction_indexer, depends); + + return dot_ev; + } + + static constexpr std::size_t preferred_reductions_per_wi = 8; + // prevents running out of resources on CPU + std::size_t max_wg = 
reduction_detail::get_work_group_size(d); + + using ReductionOpT = typename std::conditional, + sycl::logical_or, + sycl::plus>::type; + + std::size_t reductions_per_wi(preferred_reductions_per_wi); + if (reduction_nelems <= preferred_reductions_per_wi * max_wg) { + using InputBatchIndexerT = + dpctl::tensor::offset_utils::Strided1DIndexer; + using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputBatchIndexerT = + dpctl::tensor::offset_utils::ThreeOffsets_CombinedIndexer< + InputBatchIndexerT, InputBatchIndexerT, NoOpIndexerT>; + using ReductionIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + NoOpIndexerT, NoOpIndexerT>; + + const InputBatchIndexerT inp_batch_indexer{/* size */ batches, + /* step */ reduction_nelems}; + const InputOutputBatchIndexerT inp_out_batch_indexer{ + inp_batch_indexer, inp_batch_indexer, NoOpIndexerT{}}; + static constexpr ReductionIndexerT reduction_indexer{NoOpIndexerT{}, + NoOpIndexerT{}}; + + if (batches == 1) { + // increase GPU occupancy + wg = max_wg; + } + reductions_per_wi = + std::max(1, (reduction_nelems + wg - 1) / wg); + + std::size_t reduction_groups = + (reduction_nelems + reductions_per_wi * wg - 1) / + (reductions_per_wi * wg); + assert(reduction_groups == 1); + + sycl::event dot_ev = submit_no_atomic_dot_product< + lhsTy, rhsTy, resTy, ReductionOpT, InputOutputBatchIndexerT, + ReductionIndexerT, dot_product_tree_krn>( + exec_q, lhs_tp, rhs_tp, res_tp, wg, batches, reduction_nelems, + reductions_per_wi, reduction_groups, inp_out_batch_indexer, + reduction_indexer, depends); + + return dot_ev; + } + else { + static constexpr resTy identity_val = + sycl::known_identity::value; + + // more than one work-groups is needed, requires a temporary + std::size_t reduction_groups = + (reduction_nelems + preferred_reductions_per_wi * wg - 1) / + (preferred_reductions_per_wi * wg); + assert(reduction_groups > 1); + + std::size_t second_iter_reduction_groups_ = + (reduction_groups + 
preferred_reductions_per_wi * wg - 1) / + (preferred_reductions_per_wi * wg); + + // unique_ptr that owns temporary allocation for partial reductions + auto partially_reduced_tmp_owner = + dpctl::tensor::alloc_utils::smart_malloc_device( + batches * (reduction_groups + second_iter_reduction_groups_), + exec_q); + // get raw pointers + resTy *partially_reduced_tmp = partially_reduced_tmp_owner.get(); + resTy *partially_reduced_tmp2 = + partially_reduced_tmp + reduction_groups * batches; + + sycl::event first_reduction_ev; + { + using InputBatchIndexerT = + dpctl::tensor::offset_utils::Strided1DIndexer; + using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputBatchIndexerT = + dpctl::tensor::offset_utils::ThreeOffsets_CombinedIndexer< + InputBatchIndexerT, InputBatchIndexerT, NoOpIndexerT>; + using ReductionIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + NoOpIndexerT, NoOpIndexerT>; + + const InputBatchIndexerT inp_batch_indexer{ + /* size */ batches, + /* step */ reduction_nelems}; + const InputOutputBatchIndexerT inp_out_batch_indexer{ + inp_batch_indexer, inp_batch_indexer, NoOpIndexerT{}}; + static constexpr ReductionIndexerT reduction_indexer{ + NoOpIndexerT{}, NoOpIndexerT{}}; + + first_reduction_ev = submit_no_atomic_dot_product< + lhsTy, rhsTy, resTy, ReductionOpT, InputOutputBatchIndexerT, + ReductionIndexerT, dot_product_tree_krn>( + exec_q, lhs_tp, rhs_tp, partially_reduced_tmp, wg, batches, + reduction_nelems, preferred_reductions_per_wi, reduction_groups, + inp_out_batch_indexer, reduction_indexer, depends); + } + + std::size_t remaining_reduction_nelems = reduction_groups; + + resTy *temp_arg = partially_reduced_tmp; + resTy *temp2_arg = partially_reduced_tmp2; + sycl::event dependent_ev = first_reduction_ev; + + while (remaining_reduction_nelems > + preferred_reductions_per_wi * max_wg) { + std::size_t reduction_groups_ = + (remaining_reduction_nelems + preferred_reductions_per_wi * wg - + 1) / + 
(preferred_reductions_per_wi * wg); + assert(reduction_groups_ > 1); + + using InputIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer; + using ResIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + InputIndexerT, ResIndexerT>; + using ReductionIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + + const InputIndexerT inp_indexer{/* size */ batches, + /* step */ reduction_groups_}; + static constexpr ResIndexerT res_iter_indexer{}; + + const InputOutputIterIndexerT in_out_iter_indexer{inp_indexer, + res_iter_indexer}; + static constexpr ReductionIndexerT reduction_indexer{}; + + sycl::event partial_reduction_ev = + dpctl::tensor::kernels::submit_no_atomic_reduction< + resTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT, dot_product_tree_reduction_krn>( + exec_q, temp_arg, temp2_arg, identity_val, wg, batches, + remaining_reduction_nelems, preferred_reductions_per_wi, + reduction_groups_, in_out_iter_indexer, reduction_indexer, + {dependent_ev}); + + remaining_reduction_nelems = reduction_groups_; + std::swap(temp_arg, temp2_arg); + dependent_ev = std::move(partial_reduction_ev); + } + + // final reduction to res + using InputIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer; + using ResIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + InputIndexerT, ResIndexerT>; + using ReductionIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + + const InputIndexerT inp_indexer{/* size */ batches, + /* step */ remaining_reduction_nelems}; + static constexpr ResIndexerT res_iter_indexer{}; + + const InputOutputIterIndexerT in_out_iter_indexer{inp_indexer, + res_iter_indexer}; + static constexpr ReductionIndexerT reduction_indexer{}; + + wg = max_wg; + reductions_per_wi = std::max( + 1, (remaining_reduction_nelems + wg - 1) / wg); + + reduction_groups 
= + (remaining_reduction_nelems + reductions_per_wi * wg - 1) / + (reductions_per_wi * wg); + assert(reduction_groups == 1); + + sycl::event final_reduction_ev = + dpctl::tensor::kernels::submit_no_atomic_reduction< + resTy, resTy, ReductionOpT, InputOutputIterIndexerT, + ReductionIndexerT, dot_product_tree_reduction_krn>( + exec_q, temp_arg, res_tp, identity_val, wg, batches, + remaining_reduction_nelems, reductions_per_wi, reduction_groups, + in_out_iter_indexer, reduction_indexer, {dependent_ev}); + + sycl::event cleanup_host_task_event = + dpctl::tensor::alloc_utils::async_smart_free( + exec_q, {final_reduction_ev}, partially_reduced_tmp_owner); + + return cleanup_host_task_event; + } +} + +} // namespace dpctl::tensor::kernels diff --git a/dpctl_ext/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp b/dpctl_ext/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp new file mode 100644 index 00000000000..5d42ada4e84 --- /dev/null +++ b/dpctl_ext/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp @@ -0,0 +1,4239 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. 
+// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file defines kernels for general matrix multiplication (GEMM). 
+//===---------------------------------------------------------------------===// + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "kernels/dpctl_tensor_types.hpp" +#include "kernels/reductions.hpp" +#include "utils/offset_utils.hpp" +#include "utils/sycl_alloc_utils.hpp" +#include "utils/sycl_utils.hpp" +#include "utils/type_utils.hpp" + +namespace dpctl::tensor::kernels +{ + +using dpctl::tensor::ssize_t; + +namespace gemm_detail +{ + +template +void scale_gemm_k_parameters(const std::size_t &local_mem_size, + const std::size_t &reserved_slm_size, + const std::size_t delta_k, + std::size_t &n_wi, + std::size_t &delta_n) +{ + static constexpr std::size_t slm_elem_size = sizeof(T) * m_groups; + + while (slm_elem_size * (n_wi + delta_n) * delta_k + reserved_slm_size >= + local_mem_size) + { + n_wi = n_wi / 2; + delta_n = delta_n / 2; + if (delta_n == 0) + throw std::runtime_error("Insufficient resources"); + } +} + +template +void scale_gemm_nm_parameters(const std::size_t &local_mem_size, + const std::size_t &reserved_slm_size, + const std::size_t &wi_delta_n, + std::size_t &wi_delta_k, + std::size_t &wg_delta_n, + std::size_t &wg_delta_m) +{ + static constexpr std::size_t slm_A_elem_size = sizeof(T); + static constexpr std::size_t slm_B_elem_size = sizeof(T) * wi_delta_m; + + while ((wi_delta_n * wg_delta_n * wi_delta_k * slm_A_elem_size) + + (wi_delta_k * wg_delta_m * slm_B_elem_size) + + reserved_slm_size >= + local_mem_size) + { + wg_delta_n /= 2; + wg_delta_m /= 2; + wi_delta_k /= 2; + if (wg_delta_n == 0) + throw std::runtime_error("Insufficient resources"); + } +} +} // namespace gemm_detail + +using dpctl::tensor::sycl_utils::choose_workgroup_size; + +template +class gemm_seq_reduction_krn; + +template +class gemm_tree_reduction_krn; + +template +sycl::event single_reduction_for_gemm(sycl::queue &exec_q, + T *tmp_tp, + T *res_tp, + T identity_val, + std::size_t iter_nelems, + std::size_t 
reduction_nelems, + std::size_t reduction_groups, + std::size_t wg, + std::size_t max_wg, + std::size_t preferred_reductions_per_wi, + std::size_t reductions_per_wi, + int res_nd, + ssize_t res_offset, + const ssize_t *res_shapes_strides, + const std::vector &depends) +{ + sycl::event red_ev; + if (reduction_nelems < wg) { + using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using ResIndexerT = dpctl::tensor::offset_utils::StridedIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + NoOpIndexerT, ResIndexerT>; + using ReductionIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer; + + const ResIndexerT res_iter_indexer{res_nd, 0, res_shapes_strides}; + const InputOutputIterIndexerT in_out_iter_indexer{NoOpIndexerT{}, + res_iter_indexer}; + const ReductionIndexerT reduction_indexer{/* size */ reduction_nelems, + /* step */ iter_nelems}; + + red_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + sycl::range<1> iter_range{iter_nelems}; + + cgh.parallel_for>( + iter_range, + SequentialReduction( + tmp_tp, res_tp, ReductionOpT(), identity_val, + in_out_iter_indexer, reduction_indexer, reduction_nelems)); + }); + } + else { + using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using ResIndexerT = dpctl::tensor::offset_utils::StridedIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + NoOpIndexerT, ResIndexerT>; + using ReductionIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer; + + const ResIndexerT res_iter_indexer{res_nd, 0, res_shapes_strides}; + const InputOutputIterIndexerT in_out_iter_indexer{NoOpIndexerT{}, + res_iter_indexer}; + const ReductionIndexerT reduction_indexer{/* size */ reduction_nelems, + /* step */ iter_nelems}; + + if (iter_nelems == 1) { + // increase GPU occupancy + wg = max_wg; + } + reductions_per_wi = + std::max(1, (reduction_nelems + wg - 1) / wg); + + std::size_t reduction_groups = 
+ (reduction_nelems + reductions_per_wi * wg - 1) / + (reductions_per_wi * wg); + assert(reduction_groups == 1); + + red_ev = dpctl::tensor::kernels::submit_no_atomic_reduction< + T, T, ReductionOpT, InputOutputIterIndexerT, ReductionIndexerT, + gemm_tree_reduction_krn>( + exec_q, tmp_tp, res_tp, identity_val, wg, iter_nelems, + reduction_nelems, reductions_per_wi, reduction_groups, + in_out_iter_indexer, reduction_indexer, depends); + } + return red_ev; +} + +template +sycl::event + single_reduction_for_gemm_contig(sycl::queue &exec_q, + T *tmp_tp, + T *res_tp, + T identity_val, + std::size_t iter_nelems, + std::size_t reduction_nelems, + std::size_t reduction_groups, + std::size_t wg, + std::size_t max_wg, + std::size_t preferred_reductions_per_wi, + std::size_t reductions_per_wi, + const std::vector &depends) +{ + sycl::event red_ev; + if (reduction_nelems < wg) { + using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + NoOpIndexerT, NoOpIndexerT>; + using ReductionIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer; + + static constexpr InputOutputIterIndexerT in_out_iter_indexer{ + NoOpIndexerT{}, NoOpIndexerT{}}; + // tmp allocation is a C-contiguous matrix (reduction_nelems, + // iter_nelems) and we are reducing by axis 0 + const ReductionIndexerT reduction_indexer{/* size */ reduction_nelems, + /* step */ iter_nelems}; + + red_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + sycl::range<1> iter_range{iter_nelems}; + + cgh.parallel_for>( + iter_range, + SequentialReduction( + tmp_tp, res_tp, ReductionOpT(), identity_val, + in_out_iter_indexer, reduction_indexer, reduction_nelems)); + }); + } + else { + using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + NoOpIndexerT, NoOpIndexerT>; + using ReductionIndexerT = 
dpctl::tensor::offset_utils::Strided1DIndexer; + + static constexpr InputOutputIterIndexerT in_out_iter_indexer{ + NoOpIndexerT{}, NoOpIndexerT{}}; + // tmp allocation is a C-contiguous matrix + // (reduction_nelems, iter_nelems). Reducing along axis 0 + const ReductionIndexerT reduction_indexer{/* size */ reduction_nelems, + /* step */ iter_nelems}; + + if (iter_nelems == 1) { + // increase GPU occupancy + wg = max_wg; + } + reductions_per_wi = + std::max(1, (reduction_nelems + wg - 1) / wg); + + std::size_t reduction_groups = + (reduction_nelems + reductions_per_wi * wg - 1) / + (reductions_per_wi * wg); + assert(reduction_groups == 1); + + red_ev = dpctl::tensor::kernels::submit_no_atomic_reduction< + T, T, ReductionOpT, InputOutputIterIndexerT, ReductionIndexerT, + gemm_tree_reduction_krn>( + exec_q, tmp_tp, res_tp, identity_val, wg, iter_nelems, + reduction_nelems, reductions_per_wi, reduction_groups, + in_out_iter_indexer, reduction_indexer, depends); + } + return red_ev; +} + +template +sycl::event tree_reduction_for_gemm(sycl::queue &exec_q, + T *partially_reduced_tmp, + T *partially_reduced_tmp2, + T *res_tp, + T identity_val, + std::size_t iter_nelems, + std::size_t reduction_nelems, + std::size_t reduction_groups, + std::size_t wg, + std::size_t max_wg, + std::size_t preferred_reductions_per_wi, + std::size_t reductions_per_wi, + int res_nd, + ssize_t res_offset, + const ssize_t *res_shape_strides, + const std::vector &depends) +{ + sycl::event first_reduction_ev; + { + using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + NoOpIndexerT, NoOpIndexerT>; + using ReductionIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer; + + static constexpr InputOutputIterIndexerT in_out_iter_indexer{ + NoOpIndexerT{}, NoOpIndexerT{}}; + // partially_reduced_tmp is C-contig matrix with shape + // (reduction_nelems, iter_nelems). Reducing along axis 0. 
+ const ReductionIndexerT reduction_indexer{/* size */ reduction_nelems, + /* step */ iter_nelems}; + + first_reduction_ev = dpctl::tensor::kernels::submit_no_atomic_reduction< + T, T, ReductionOpT, InputOutputIterIndexerT, ReductionIndexerT, + gemm_tree_reduction_krn>( + exec_q, partially_reduced_tmp, partially_reduced_tmp2, identity_val, + wg, iter_nelems, reduction_nelems, reductions_per_wi, + reduction_groups, in_out_iter_indexer, reduction_indexer, depends); + } + + std::size_t remaining_reduction_nelems = reduction_groups; + + T *temp_arg = partially_reduced_tmp2; + T *temp2_arg = partially_reduced_tmp; + sycl::event dependent_ev = first_reduction_ev; + + while (remaining_reduction_nelems > preferred_reductions_per_wi * max_wg) { + std::size_t reduction_groups_ = (remaining_reduction_nelems + + preferred_reductions_per_wi * wg - 1) / + (preferred_reductions_per_wi * wg); + assert(reduction_groups_ > 1); + + // keep reducing + using InputIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer; + using ResIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + InputIndexerT, ResIndexerT>; + using ReductionIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + + const InputIndexerT inp_indexer{/* size */ iter_nelems, + /* step */ reduction_groups_}; + static constexpr ResIndexerT res_iter_indexer{}; + + const InputOutputIterIndexerT in_out_iter_indexer{inp_indexer, + res_iter_indexer}; + + static constexpr ReductionIndexerT reduction_indexer{}; + + sycl::event partial_reduction_ev = + dpctl::tensor::kernels::submit_no_atomic_reduction< + T, T, ReductionOpT, InputOutputIterIndexerT, ReductionIndexerT, + gemm_tree_reduction_krn>( + exec_q, temp_arg, temp2_arg, identity_val, wg, iter_nelems, + remaining_reduction_nelems, reductions_per_wi, + reduction_groups_, in_out_iter_indexer, reduction_indexer, + {dependent_ev}); + + remaining_reduction_nelems = reduction_groups_; + 
std::swap(temp_arg, temp2_arg); + dependent_ev = std::move(partial_reduction_ev); + } + + // final reduction to res + using InputIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer; + using ResIndexerT = dpctl::tensor::offset_utils::StridedIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer; + using ReductionIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + + const InputIndexerT inp_indexer{/* size */ iter_nelems, + /* step */ remaining_reduction_nelems}; + const ResIndexerT res_iter_indexer{ + /* ndim */ res_nd, + /* offset */ static_cast(res_offset), + /* packed shape_strides*/ res_shape_strides}; + + const InputOutputIterIndexerT in_out_iter_indexer{inp_indexer, + res_iter_indexer}; + static constexpr ReductionIndexerT reduction_indexer{}; + + wg = max_wg; + reductions_per_wi = + std::max(1, (remaining_reduction_nelems + wg - 1) / wg); + + reduction_groups = + (remaining_reduction_nelems + reductions_per_wi * wg - 1) / + (reductions_per_wi * wg); + assert(reduction_groups == 1); + + sycl::event final_reduction_ev = + dpctl::tensor::kernels::submit_no_atomic_reduction< + T, T, ReductionOpT, InputOutputIterIndexerT, ReductionIndexerT, + gemm_tree_reduction_krn>( + exec_q, temp_arg, res_tp, identity_val, wg, iter_nelems, + remaining_reduction_nelems, reductions_per_wi, reduction_groups, + in_out_iter_indexer, reduction_indexer, {dependent_ev}); + + return final_reduction_ev; +} + +template +class gemm_reduction_over_group_temps_contig_krn; + +template +sycl::event + tree_reduction_for_gemm_contig(sycl::queue &exec_q, + T *partially_reduced_tmp, + T *partially_reduced_tmp2, + T *res_tp, + T identity_val, + std::size_t iter_nelems, + std::size_t reduction_nelems, + std::size_t reduction_groups, + std::size_t wg, + std::size_t max_wg, + std::size_t preferred_reductions_per_wi, + std::size_t reductions_per_wi, + const std::vector &depends) +{ + using NoOpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + 
using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer; + using ReductionIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer; + + static constexpr InputOutputIterIndexerT in_out_iter_indexer{ + NoOpIndexerT{}, NoOpIndexerT{}}; + const ReductionIndexerT reduction_indexer{/* size */ reduction_nelems, + /* step */ iter_nelems}; + + const sycl::event &first_reduction_ev = + dpctl::tensor::kernels::submit_no_atomic_reduction< + T, T, ReductionOpT, InputOutputIterIndexerT, ReductionIndexerT, + gemm_reduction_over_group_temps_contig_krn>( + exec_q, partially_reduced_tmp, partially_reduced_tmp2, identity_val, + wg, iter_nelems, reduction_nelems, reductions_per_wi, + reduction_groups, in_out_iter_indexer, reduction_indexer, depends); + + std::size_t remaining_reduction_nelems = reduction_groups; + + T *temp_arg = partially_reduced_tmp2; + T *temp2_arg = partially_reduced_tmp; + sycl::event dependent_ev = first_reduction_ev; + + while (remaining_reduction_nelems > preferred_reductions_per_wi * max_wg) { + std::size_t reduction_groups_ = (remaining_reduction_nelems + + preferred_reductions_per_wi * wg - 1) / + (preferred_reductions_per_wi * wg); + assert(reduction_groups_ > 1); + + // keep reducing + using InputIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer; + using ResIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + InputIndexerT, ResIndexerT>; + using ReductionIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + + // n * m = iter_nelems because essentially, this process + // creates a stack of reduction_nelems 2D matrices and we reduce + // along the stack axis + const InputIndexerT inp_indexer{/* size */ iter_nelems, + /* step */ reduction_groups_}; + static constexpr ResIndexerT res_iter_indexer{}; + + const InputOutputIterIndexerT in_out_iter_indexer{inp_indexer, + res_iter_indexer}; + + static constexpr ReductionIndexerT 
reduction_indexer{}; + + sycl::event partial_reduction_ev = + dpctl::tensor::kernels::submit_no_atomic_reduction< + T, T, ReductionOpT, InputOutputIterIndexerT, ReductionIndexerT, + gemm_reduction_over_group_temps_contig_krn>( + exec_q, temp_arg, temp2_arg, identity_val, wg, iter_nelems, + remaining_reduction_nelems, reductions_per_wi, + reduction_groups_, in_out_iter_indexer, reduction_indexer, + {dependent_ev}); + + remaining_reduction_nelems = reduction_groups_; + std::swap(temp_arg, temp2_arg); + dependent_ev = std::move(partial_reduction_ev); + } + + // final reduction to res + { + using InputIndexerT = dpctl::tensor::offset_utils::Strided1DIndexer; + using ResIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + using InputOutputIterIndexerT = + dpctl::tensor::offset_utils::TwoOffsets_CombinedIndexer< + InputIndexerT, ResIndexerT>; + using ReductionIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + + const InputIndexerT inp_indexer{ + /* size */ iter_nelems, + /* step */ remaining_reduction_nelems}; + static constexpr ResIndexerT res_iter_indexer{}; + + const InputOutputIterIndexerT in_out_iter_indexer{inp_indexer, + res_iter_indexer}; + static constexpr ReductionIndexerT reduction_indexer{}; + + wg = max_wg; + reductions_per_wi = std::max( + 1, (remaining_reduction_nelems + wg - 1) / wg); + + std::size_t reduction_groups = + (remaining_reduction_nelems + reductions_per_wi * wg - 1) / + (reductions_per_wi * wg); + assert(reduction_groups == 1); + + sycl::event final_reduction_ev = + dpctl::tensor::kernels::submit_no_atomic_reduction< + T, T, ReductionOpT, InputOutputIterIndexerT, ReductionIndexerT, + gemm_reduction_over_group_temps_contig_krn>( + exec_q, temp_arg, res_tp, identity_val, wg, iter_nelems, + remaining_reduction_nelems, reductions_per_wi, reduction_groups, + in_out_iter_indexer, reduction_indexer, {dependent_ev}); + + return final_reduction_ev; + } +} + +template +class GemmBatchFunctorThreadK +{ +private: + const lhsT *lhs = nullptr; + const 
rhsT *rhs = nullptr; + resT *res = nullptr; + LocAccT workspace; + LocAccT local_B_block; + std::size_t n = 0; + std::size_t n_blocks = 0; + std::size_t delta_n = 0; + std::size_t k = 0; + std::size_t k_blocks = 0; + std::size_t delta_k = 0; + std::size_t n_wi = 0; + std::size_t m = 0; + std::size_t batch_nelems = 0; + BatchDimsIndexerT batch_indexer; + OuterInnerDimsIndexerT lhs_indexer; + OuterInnerDimsIndexerT rhs_indexer; + OuterInnerDimsIndexerT res_indexer; + +public: + GemmBatchFunctorThreadK(const lhsT *lhs_, + const rhsT *rhs_, + resT *res_, + LocAccT workspace_, + LocAccT local_B_block_, + std::size_t n_, + std::size_t n_blocks_, + std::size_t delta_n_, + std::size_t k_, + std::size_t k_blocks_, + std::size_t delta_k_, + std::size_t n_wi_, + std::size_t m_, + std::size_t batch_nelems_, + const BatchDimsIndexerT &batch_indexer_, + const OuterInnerDimsIndexerT &lhs_indexer_, + const OuterInnerDimsIndexerT &rhs_indexer_, + const OuterInnerDimsIndexerT &res_indexer_) + : lhs(lhs_), rhs(rhs_), res(res_), workspace(workspace_), + local_B_block(local_B_block_), n(n_), n_blocks(n_blocks_), + delta_n(delta_n_), k(k_), k_blocks(k_blocks_), delta_k(delta_k_), + n_wi(n_wi_), m(m_), batch_nelems(batch_nelems_), + batch_indexer(batch_indexer_), lhs_indexer(lhs_indexer_), + rhs_indexer(rhs_indexer_), res_indexer(res_indexer_) + { + } + + void operator()(sycl::nd_item<1> it) const + { + // for batching: + // (current matrix in batch) m_id = global_id / (global_range / + // batch_nelems) for lhs, offset = m_id * (n * k) for rhs, offset = + // m_id + // * (k * m) for res, offset = m_id * (n * m) + const std::size_t n_groups_per_batch = + it.get_group_range(0) / batch_nelems; + const std::size_t m_id = it.get_group_linear_id() / n_groups_per_batch; + const std::size_t gr_id = + it.get_group_linear_id() - m_id * n_groups_per_batch; + const std::size_t lid = it.get_local_linear_id(); + + const auto &three_offsets_ = batch_indexer(static_cast(m_id)); + + const auto &lhs_offset 
= three_offsets_.get_first_offset(); + const auto &rhs_offset = three_offsets_.get_second_offset(); + const auto &res_offset = three_offsets_.get_third_offset(); + + // lift gr_id -> (block_i, block_j, block_s) + // block_i moves fastest, then block_s, then block_j + + const std::size_t r_size = (n_blocks * k_blocks); + // 0 <= block_j < m_blocks, + const std::size_t block_j = gr_id / r_size; + // 0 <= block_r < n_blocks * k_blocks + const std::size_t block_r = gr_id - block_j * r_size; + // 0 <= block_s < k_blocks + const std::size_t block_s = block_r / n_blocks; + // 0 <= block_i < n_blocks + const std::size_t block_i = block_r - block_s * n_blocks; + + // 0 <= local_i < delta_n + const std::size_t local_i = lid / (delta_k); + // 0 <= local_s < delta_k + const std::size_t local_s = lid - local_i * (delta_k); + + std::size_t i = block_i * delta_n + local_i; + std::size_t j = m_groups * block_j; + std::size_t s = block_s * delta_k * n_wi + local_s; + + using accV_t = typename LocAccT::value_type; + + static constexpr resT identity_ = resT(0); + if (local_i == 0) { + for (std::size_t q = 0; q < n_wi * delta_k; q += delta_k) { + const std::size_t sq = s + q; + const std::size_t sqmj = sq * m + j; + + if constexpr (m_groups == 1 && std::is_same_v) { + local_B_block[local_s + q] = + (sq < k && j < m) + ? static_cast( + rhs[rhs_offset + rhs_indexer(sqmj)]) + : identity_; + } + else { + accV_t local_B_vec; +#pragma unroll + for (std::size_t vec_idx = 0; vec_idx < m_groups; ++vec_idx) + { + local_B_vec[vec_idx] = + (sq < k && j + vec_idx < m) + ? 
static_cast( + rhs[rhs_offset + + rhs_indexer(sqmj + vec_idx)]) + : identity_; + } + local_B_block[local_s + q] = local_B_vec; + } + } + } + + it.barrier(sycl::access::fence_space::local_space); + + std::size_t t_shift = block_s * delta_k * n_wi; + std::size_t global_s_offset = i * k + t_shift; + + accV_t private_sum(identity_); + static constexpr accV_t vec_identity_(identity_); + for (std::size_t t = local_s; t < local_B_block.size(); t += delta_k) { + private_sum += + ((i < n) && (t + t_shift < k)) + ? (static_cast( + lhs[lhs_offset + lhs_indexer(global_s_offset + t)]) * + local_B_block[t]) + : vec_identity_; + } + + std::size_t workspace_i_shift = local_i * delta_k; + workspace[workspace_i_shift + local_s] = private_sum; + + it.barrier(sycl::access::fence_space::local_space); + + if (local_s == 0 && i < n) { + accV_t local_sum(workspace[workspace_i_shift]); + for (std::size_t t = 1; t < delta_k; ++t) { + local_sum += workspace[workspace_i_shift + t]; + } + + sycl::atomic_ref + aout0(res[res_offset + res_indexer(i * m + j)]); + + if constexpr (m_groups == 1 && std::is_same_v) { + aout0 += local_sum; + } + else { + aout0 += local_sum[0]; + +#pragma unroll + for (std::size_t vec_id = 1; vec_id < m_groups; ++vec_id) { + if (j + vec_id < m) { + sycl::atomic_ref< + resT, sycl::memory_order::relaxed, + sycl::memory_scope::device, + sycl::access::address_space::global_space> + aout1(res[res_offset + + res_indexer(i * m + j + vec_id)]); + + aout1 += local_sum[vec_id]; + } + } + } + } + } +}; + +template +class gemm_init_krn; + +template +class gemm_k_krn; + +template +class gemm_nm_krn; + +template +class gemm_batch_k_krn; + +template +class gemm_batch_nm_krn; + +namespace gemm_detail +{ + +template +sycl::event _gemm_k_impl(sycl::queue &exec_q, + const lhsTy *lhs_tp, + const rhsTy *rhs_tp, + resTy *res_tp, + const std::size_t batch_nelems, + const std::size_t n, + const std::size_t k, + const std::size_t m, + const BatchIndexerT &batch_indexer, + const LhsIndexerT 
&lhs_indexer, + const RhsIndexerT &rhs_indexer, + const ResIndexerT &res_indexer, + const std::vector &depends) +{ + static constexpr std::size_t m_groups = 4; + const std::size_t delta_k(4); + std::size_t n_wi(64); + std::size_t delta_n(32); + + static_assert(std::is_same_v); + static_assert(std::is_same_v); + + const sycl::device &dev = exec_q.get_device(); + const std::size_t local_mem_size = + dev.get_info(); + const std::size_t reserved_slm_size = 512; + + gemm_detail::scale_gemm_k_parameters( + local_mem_size, reserved_slm_size, delta_k, + n_wi, // modified by reference + delta_n // modified by reference + ); + + std::size_t n_blocks = (n + delta_n - 1) / delta_n; + std::size_t m_blocks = (m + m_groups - 1) / m_groups; + std::size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); + + std::size_t lws = delta_n * delta_k; + + auto gRange = + sycl::range<1>(batch_nelems * n_blocks * m_blocks * k_blocks * lws); + auto lRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gRange, lRange); + + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using LocAccT = sycl::local_accessor, 1>; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + + using KernelName = + class gemm_batch_k_krn; + cgh.parallel_for( + ndRange, + GemmBatchFunctorThreadK( + lhs_tp, rhs_tp, res_tp, std::move(workspace), + std::move(local_B_block), n, n_blocks, delta_n, k, k_blocks, + delta_k, n_wi, m, batch_nelems, batch_indexer, lhs_indexer, + rhs_indexer, res_indexer)); + }); + return gemm_ev; +} + +template +sycl::event _gemm_small_m_impl(sycl::queue &exec_q, + const lhsTy *lhs_tp, + const rhsTy *rhs_tp, + resTy *res_tp, + const std::size_t batch_nelems, + const std::size_t n, + const std::size_t k, + const std::size_t m, + const BatchIndexerT &batch_indexer, + const LhsIndexerT &lhs_indexer, + const RhsIndexerT &rhs_indexer, + const ResIndexerT &res_indexer, + const std::vector &depends) +{ + 
static constexpr std::size_t m_groups = 1; + const std::size_t delta_k(4); + std::size_t n_wi(64); + std::size_t delta_n(32); + + static_assert(std::is_same_v); + static_assert(std::is_same_v); + + const sycl::device &dev = exec_q.get_device(); + const std::size_t local_mem_size = + dev.get_info(); + const std::size_t reserved_slm_size = 512; + + gemm_detail::scale_gemm_k_parameters( + local_mem_size, reserved_slm_size, delta_k, + n_wi, // modified by reference + delta_n // modified by reference + ); + + std::size_t n_blocks = (n + delta_n - 1) / delta_n; + std::size_t m_blocks = (m + m_groups - 1) / m_groups; + std::size_t k_blocks = (k + n_wi * delta_k - 1) / (n_wi * delta_k); + + std::size_t lws = delta_n * delta_k; + + auto gRange = + sycl::range<1>(batch_nelems * n_blocks * m_blocks * k_blocks * lws); + auto lRange = sycl::range<1>(lws); + + auto ndRange = sycl::nd_range<1>(gRange, lRange); + + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using LocAccT = sycl::local_accessor; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + + using KernelName = + class gemm_batch_k_krn; + cgh.parallel_for( + ndRange, + GemmBatchFunctorThreadK( + lhs_tp, rhs_tp, res_tp, std::move(workspace), + std::move(local_B_block), n, n_blocks, delta_n, k, k_blocks, + delta_k, n_wi, m, batch_nelems, batch_indexer, lhs_indexer, + rhs_indexer, res_indexer)); + }); + + return gemm_ev; +} + +} // end of namespace gemm_detail + +template +class GemmBatchFunctorThreadNM_vecm +{ +private: + const lhsT *lhs = nullptr; + const rhsT *rhs = nullptr; + resT *res = nullptr; + LocAccT1 local_lhs_block; + LocAccT2 local_rhs_block; + std::size_t batch_nelems; + std::size_t n = 0; + std::size_t k = 0; + std::size_t m = 0; + std::size_t n_groups = 0; + std::uint32_t wg_delta_n = 0; + std::uint32_t wg_delta_m = 0; + std::uint32_t wi_delta_k = 0; + BatchDimsIndexerT batch_indexer; + LhsIndexerT lhs_indexer; + 
RhsIndexerT rhs_indexer; + ResIndexerT res_indexer; + +public: + /*! @brief */ + GemmBatchFunctorThreadNM_vecm(const lhsT *lhs_, + const rhsT *rhs_, + resT *res_, + LocAccT1 local_lhs_block_, + LocAccT2 local_rhs_block_, + std::size_t batch_nelems_, + std::size_t n_, + std::size_t k_, + std::size_t m_, + std::size_t n_groups_, + std::size_t wg_delta_n_, + std::size_t wg_delta_m_, + std::size_t wi_delta_k_, + const BatchDimsIndexerT &batch_indexer_, + const LhsIndexerT &lhs_indexer_, + const RhsIndexerT &rhs_indexer_, + const ResIndexerT &res_indexer_) + : lhs(lhs_), rhs(rhs_), res(res_), local_lhs_block(local_lhs_block_), + local_rhs_block(local_rhs_block_), batch_nelems(batch_nelems_), n(n_), + k(k_), m(m_), n_groups(n_groups_), wg_delta_n(wg_delta_n_), + wg_delta_m(wg_delta_m_), wi_delta_k(wi_delta_k_), + batch_indexer(batch_indexer_), lhs_indexer(lhs_indexer_), + rhs_indexer(rhs_indexer_), res_indexer(res_indexer_) + { + } + + void operator()(sycl::nd_item<1> it) const + { + static constexpr resT zero_(0); + static constexpr std::uint32_t wi_total_delta_m = + wi_delta_m_vecs * m_vec_size; + + const std::size_t gws_per_batch = it.get_group_range(0) / batch_nelems; + const std::size_t batch_id = it.get_group_linear_id() / gws_per_batch; + const std::size_t gr_id = + it.get_group_linear_id() - batch_id * gws_per_batch; + + const auto &three_offsets_ = + batch_indexer(static_cast(batch_id)); + + const auto &lhs_offset = three_offsets_.get_first_offset(); + const auto &rhs_offset = three_offsets_.get_second_offset(); + const auto &res_offset = three_offsets_.get_third_offset(); + + // 0 <= block_j < m_groups + const std::size_t block_j = gr_id / n_groups; + // 0 <= block_i < n_groups + const std::size_t block_i = gr_id - block_j * n_groups; + + // Assumption: lws == wg_delta_n * wg_delta_m + const std::uint32_t lid = it.get_local_linear_id(); + // 0 <= local_j < (lws / wg_delta_n == wg_delta_m) + const std::uint32_t local_j = lid / wg_delta_n; + // sub-group lanes 
map to adjacent local_i + const std::uint32_t local_i = lid - local_j * wg_delta_n; + + // Coordinates of the block of C the work-group works on + std::size_t i = block_i * wg_delta_n * wi_delta_n; + std::size_t j = block_j * wg_delta_m * wi_total_delta_m; + + using slmA_t = typename LocAccT1::value_type; + using slmB_t = typename LocAccT2::value_type; + + const std::size_t a_st0 = k; + const std::size_t a_st1 = 1; + + const std::size_t b_st0 = m; + const std::size_t b_st1 = 1; + + const std::size_t c_st0 = m; + const std::size_t c_st1 = 1; + + // allocate/initialize private matrix C + // size ( wi_total_delta_n, wi_total_delta_m ) + static constexpr std::uint32_t C_size = wi_delta_n * wi_delta_m_vecs; + std::array private_C{slmB_t{zero_}}; + + for (std::size_t s = 0; s < k; s += wi_delta_k) { + // populate local_lhs_block ( wg_delta_n * wi_delta_n, + // wi_delta_k) + for (std::uint32_t vid = lid; vid < local_lhs_block.size(); + vid += it.get_local_range()[0]) + { + // 0 <= v_i < wg_delta_n * wi_delta_n + const std::uint32_t v_i = vid / wi_delta_k; + // 0 <= v_s < wi_delta_k + const std::uint32_t v_s = vid - v_i * wi_delta_k; + + const std::size_t g_i = i + v_i; + const std::size_t g_s = s + v_s; + + const std::uint32_t mapped_vid = + wg_delta_n * wi_delta_n * v_s + v_i; + local_lhs_block[mapped_vid] = + (g_i < n && g_s < k) + ? 
static_cast( + lhs[lhs_offset + + lhs_indexer(g_i * a_st0 + g_s * a_st1)]) + : zero_; + } + + // populate local_rhs_block> ( wg_delta_m * + // wi_delta_m_vecs, wi_delta_k ) + for (std::uint32_t vid = lid; vid < local_rhs_block.size(); + vid += it.get_local_range()[0]) + { + // 0 <= v_j < wg_delta_m * wi_delta_m_vecs + const std::uint32_t v_j = vid / wi_delta_k; + // 0 <= v_s < wi_delta_k + const std::uint32_t v_s = vid - v_j * wi_delta_k; + + const std::size_t g_j = j + v_j * m_vec_size; + const std::size_t g_s = s + v_s; + const std::uint32_t mapped_vid = + wg_delta_m * wi_delta_m_vecs * v_s + v_j; + + if constexpr (m_vec_size == 1) { + local_rhs_block[mapped_vid] = + (g_j < m && g_s < k) + ? static_cast( + rhs[rhs_offset + + rhs_indexer(g_s * b_st0 + g_j * b_st1)]) + : zero_; + } + else { + slmB_t vec{}; +#pragma unroll + for (std::uint32_t lane_id = 0; lane_id < m_vec_size; + ++lane_id) { + const std::size_t g_j1 = g_j + lane_id; + vec[lane_id] = (g_j1 < m && g_s < k) + ? static_cast( + rhs[rhs_offset + + rhs_indexer(g_s * b_st0 + + g_j1 * b_st1)]) + : zero_; + }; + + local_rhs_block[mapped_vid] = vec; + } + } + + it.barrier(sycl::access::fence_space::local_space); + + const std::uint32_t lo_lhs_st_k = (wg_delta_n * wi_delta_n); + const std::uint32_t lo_rhs_rk_k = (wg_delta_m * wi_delta_m_vecs); + for (std::uint32_t pr_k = 0; pr_k < wi_delta_k; ++pr_k) { + std::array pr_lhs{}; +#pragma unroll + for (std::uint32_t pr_i = 0; pr_i < wi_delta_n; ++pr_i) { + pr_lhs[pr_i] = + local_lhs_block[pr_k * lo_lhs_st_k + + (local_i + pr_i * wg_delta_n)]; + } + + std::array pr_rhs{}; +#pragma unroll + for (std::uint32_t pr_j = 0; pr_j < wi_delta_m_vecs; ++pr_j) { + pr_rhs[pr_j] = + local_rhs_block[pr_k * lo_rhs_rk_k + + (local_j + pr_j * wg_delta_m)]; + } + +#pragma unroll + for (std::uint32_t pr_i = 0; pr_i < wi_delta_n; ++pr_i) { +#pragma unroll + for (std::uint32_t pr_j = 0; pr_j < wi_delta_m_vecs; ++pr_j) + { + private_C[pr_i * wi_delta_m_vecs + pr_j] += + pr_lhs[pr_i] * 
pr_rhs[pr_j]; + } + } + } + + it.barrier(sycl::access::fence_space::local_space); + } + + if constexpr (m_vec_size == 1) { +#pragma unroll + for (std::uint32_t pr_i = 0; pr_i < wi_delta_n; ++pr_i) { + std::size_t out_i = i + local_i + pr_i * wg_delta_n; + if (out_i < n) { +#pragma unroll + for (std::uint32_t pr_j = 0; pr_j < wi_delta_m_vecs; ++pr_j) + { + const std::size_t out_j = + j + (local_j + pr_j * wg_delta_m) * m_vec_size; + const std::size_t out_flat_id = + out_i * c_st0 + out_j * c_st1; + if (out_j < m) { + res[res_offset + res_indexer(out_flat_id)] = + private_C[pr_i * wi_delta_m_vecs + pr_j]; + } + } + } + } + } + else { +#pragma unroll + for (std::uint32_t pr_i = 0; pr_i < wi_delta_n; ++pr_i) { + std::size_t out_i = i + local_i + pr_i * wg_delta_n; + if (out_i < n) { + // could be unrolled + for (std::uint32_t pr_j = 0; pr_j < wi_delta_m_vecs; ++pr_j) + { + std::size_t out_j = + j + (local_j + pr_j * wg_delta_m) * m_vec_size; +#pragma unroll + for (std::uint32_t lane_id = 0; lane_id < m_vec_size; + ++lane_id) { + const std::size_t out_flat_id = + out_i * c_st0 + (out_j + lane_id) * c_st1; + if (out_j + lane_id < m) { + res[res_offset + res_indexer(out_flat_id)] = + private_C[pr_i * wi_delta_m_vecs + pr_j] + [lane_id]; + } + } + } + } + } + } + } +}; + +struct GemmBatchFunctorThreadNM_vecm_HyperParameters +{ +private: + std::uint32_t wi_delta_n = 2; + std::uint32_t wi_delta_m_vecs = 4; + std::uint32_t m_vec_size = 1; + +public: + constexpr GemmBatchFunctorThreadNM_vecm_HyperParameters(); + constexpr GemmBatchFunctorThreadNM_vecm_HyperParameters( + std::uint32_t wi_delta_n_, + std::uint32_t wi_delta_m_vecs_, + std::uint32_t m_vec_size_) + : wi_delta_n(wi_delta_n_), wi_delta_m_vecs(wi_delta_m_vecs_), + m_vec_size(m_vec_size_) + { + } + + constexpr std::uint32_t get_wi_delta_n() const + { + return wi_delta_n; + } + constexpr std::uint32_t get_wi_delta_m_vecs() const + { + return wi_delta_m_vecs; + } + constexpr std::uint32_t get_m_vec_size() const + { + 
return m_vec_size; + } +}; + +template +struct GemmBatchFunctorThreadNM_vecm_HyperParametersSelector +{ + constexpr GemmBatchFunctorThreadNM_vecm_HyperParametersSelector() {} + + constexpr GemmBatchFunctorThreadNM_vecm_HyperParameters get() const + { + if constexpr (sizeof(resT) == 1) { + // 1 * 8 * 2 * 4 == 64 + return GemmBatchFunctorThreadNM_vecm_HyperParameters(8, 2, 4); + } + else if constexpr (sizeof(resT) == 2) { + // 2 * 4 * 2 * 4 == 64 + return GemmBatchFunctorThreadNM_vecm_HyperParameters(4, 2, 4); + } + else if constexpr (sizeof(resT) == 4) { + // 4 * 4 * 1 * 4 == 64 + return GemmBatchFunctorThreadNM_vecm_HyperParameters(4, 1, 4); + } + else if constexpr (sizeof(resT) == 8) { + // 8 * 2 * 1 * 4 == 64 + if constexpr (std::is_same_v>) { + return GemmBatchFunctorThreadNM_vecm_HyperParameters(2, 4, 1); + } + else { + return GemmBatchFunctorThreadNM_vecm_HyperParameters(2, 1, 4); + } + } + else if constexpr (std::is_same_v>) { + // 16 * 2 * 2 * 1 == 64 + return GemmBatchFunctorThreadNM_vecm_HyperParameters(2, 2, 1); + } + else { + return GemmBatchFunctorThreadNM_vecm_HyperParameters(2, 2, 1); + } + } +}; + +template +class gemm_batch_nm_vecm_krn; + +namespace gemm_detail +{ + +template +std::tuple + get_wg_delta_m_and_wi_delta_k(const std::size_t slm_byte_size, + const std::uint32_t wg_delta_n, + const std::uint32_t suggested_wg_delta_m) +{ + std::uint32_t wg_delta_m = suggested_wg_delta_m; + + const std::size_t slm_max_rows = + slm_byte_size / + ((wg_delta_n * wi_delta_n + wg_delta_m * wi_delta_m) * sizeof(T)); + + std::uint32_t wi_delta_k = + (slm_max_rows >= 64) + ? 64 + : 32 * static_cast(slm_max_rows / 32); + + for (std::uint32_t it = 0; !wi_delta_k && (it < 4); ++it) { + wg_delta_m /= 2; + + const std::size_t slm_max_rows = + slm_byte_size / + ((wg_delta_n * wi_delta_n + wg_delta_m * wi_delta_m) * sizeof(T)); + + wi_delta_k = + (slm_max_rows >= 64) + ? 64 + : ((slm_max_rows >= 32) + ? 32 + : (slm_max_rows >= 16 ? 
16 + : 8 * static_cast( + slm_max_rows / 8))); + } + + if (!wi_delta_k) { + throw std::runtime_error("Insufficient resources"); + } + + return std::make_tuple(wg_delta_m, wi_delta_k); +} + +template +sycl::event _gemm_batch_nm_impl(sycl::queue &exec_q, + const lhsTy *lhs_tp, + const rhsTy *rhs_tp, + resTy *res_tp, + const std::size_t batch_nelems, + const std::size_t n, + const std::size_t k, + const std::size_t m, + const BatchIndexerT &batch_indexer, + const LhsIndexerT &lhs_indexer, + const RhsIndexerT &rhs_indexer, + const ResIndexerT &res_indexer, + std::vector const &depends) +{ + static constexpr GemmBatchFunctorThreadNM_vecm_HyperParametersSelector< + resTy> + selector{}; + static constexpr auto hyper_params = selector.get(); + + static constexpr std::uint32_t wi_delta_n = hyper_params.get_wi_delta_n(); + static constexpr std::uint32_t wi_delta_m_vecs = + hyper_params.get_wi_delta_m_vecs(); + static constexpr std::uint32_t m_vec_size = hyper_params.get_m_vec_size(); + + static constexpr std::uint32_t wi_total_delta_m = + wi_delta_m_vecs * m_vec_size; + + using KernelName = + class gemm_batch_nm_vecm_krn; + + const auto &kernel_id = sycl::get_kernel_id(); + + auto const &ctx = exec_q.get_context(); + auto const &dev = exec_q.get_device(); + auto kb = sycl::get_kernel_bundle( + ctx, {dev}, {kernel_id}); + + auto krn = kb.get_kernel(kernel_id); + + const std::uint32_t max_sg_size = krn.template get_info< + sycl::info::kernel_device_specific::max_sub_group_size>(dev); + + const std::size_t k_wg_sz = krn.template get_info< + sycl::info::kernel_device_specific::work_group_size>(dev); + + // Limit work-group size + static constexpr std::size_t wg_sz_limit(2048); + const std::size_t max_wg_sz = std::min(wg_sz_limit, k_wg_sz); + + const std::uint32_t max_subgroups_per_wg = + static_cast(max_wg_sz / max_sg_size); + + const std::size_t reserved_slm_byte_size = 512; + const std::size_t slm_byte_size = + dev.get_info(); + + const std::uint32_t wg_delta_n = max_sg_size; 
+ std::uint32_t wg_delta_m = 0; + std::uint32_t wi_delta_k = 0; + + std::tie(wg_delta_m, wi_delta_k) = + get_wg_delta_m_and_wi_delta_k( + slm_byte_size - reserved_slm_byte_size, wg_delta_n, + max_subgroups_per_wg); + + const std::uint32_t lws = wg_delta_n * wg_delta_m; + + const std::size_t n_groups = + (n + wg_delta_n * wi_delta_n - 1) / (wg_delta_n * wi_delta_n); + const std::size_t m_groups = (m + wg_delta_m * wi_total_delta_m - 1) / + (wg_delta_m * wi_total_delta_m); + + const std::size_t gws = lws * batch_nelems * n_groups * m_groups; + + sycl::range<1> lRange(lws); + sycl::range<1> gRange(gws); + sycl::nd_range<1> ndRange(gRange, lRange); + + using slmB_t = + typename std::conditional>::type; + + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + cgh.use_kernel_bundle(kb); + + using LocAccT1 = sycl::local_accessor; + LocAccT1 local_A_block(wg_delta_n * wi_delta_n * wi_delta_k, cgh); + + using LocAccT2 = sycl::local_accessor; + LocAccT2 local_B_block(wg_delta_m * wi_delta_m_vecs * wi_delta_k, cgh); + + using Impl_FunctorT = GemmBatchFunctorThreadNM_vecm< + lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, BatchIndexerT, LhsIndexerT, + RhsIndexerT, ResIndexerT, wi_delta_n, wi_delta_m_vecs, m_vec_size>; + + cgh.parallel_for( + ndRange, Impl_FunctorT( + lhs_tp, rhs_tp, res_tp, std::move(local_A_block), + std::move(local_B_block), batch_nelems, n, k, m, + n_groups, wg_delta_n, wg_delta_m, wi_delta_k, + batch_indexer, lhs_indexer, rhs_indexer, res_indexer)); + }); + return gemm_ev; +} + +} // namespace gemm_detail + +typedef sycl::event (*gemm_impl_fn_ptr_t)( + sycl::queue &, + const char *, // lhs + const char *, // rhs + char *, // res + std::size_t, // lhs_outer_nelems (n) + std::size_t, // inner_nelems (k) + std::size_t, // rhs_outer_nelems (m) + int, // inner nd + int, // lhs outer nd + const ssize_t *, // lhs shape and strides + int, // rhs outer nd + const ssize_t *, // rhs shape and strides + int, // res outer nd + const 
ssize_t *, // res shape and strides + std::vector const &); + +template +sycl::event gemm_impl(sycl::queue &exec_q, + const char *lhs_cp, + const char *rhs_cp, + char *res_cp, + std::size_t n, + std::size_t k, + std::size_t m, + int inner_nd, + int lhs_outer_nd, + const ssize_t *lhs_shape_strides, + int rhs_outer_nd, + const ssize_t *rhs_shape_strides, + int res_outer_nd, + const ssize_t *res_shape_strides, + std::vector const &depends = {}) +{ + const lhsTy *lhs_tp = reinterpret_cast(lhs_cp); + const rhsTy *rhs_tp = reinterpret_cast(rhs_cp); + resTy *res_tp = reinterpret_cast(res_cp); + + using OuterInnerIndexerT = dpctl::tensor::offset_utils::StridedIndexer; + const OuterInnerIndexerT lhs_indexer(inner_nd + lhs_outer_nd, 0, + lhs_shape_strides); + const OuterInnerIndexerT rhs_indexer(inner_nd + rhs_outer_nd, 0, + rhs_shape_strides); + const OuterInnerIndexerT res_indexer(res_outer_nd, 0, res_shape_strides); + + using BatchIndexerT = dpctl::tensor::offset_utils::ThreeZeroOffsets_Indexer; + static constexpr BatchIndexerT batch_indexer{}; + + static constexpr std::size_t single_batch_nelems = 1; + + const std::size_t min_nm = std::min(n, m); + const std::size_t max_nm = std::max(n, m); + + if (min_nm > 0 && (max_nm >= ((64 * 1024) / min_nm))) { + return gemm_detail::_gemm_batch_nm_impl< + lhsTy, rhsTy, resTy, BatchIndexerT, OuterInnerIndexerT, + OuterInnerIndexerT, OuterInnerIndexerT>( + exec_q, lhs_tp, rhs_tp, res_tp, single_batch_nelems, n, k, m, + batch_indexer, lhs_indexer, rhs_indexer, res_indexer, depends); + } + + sycl::event res_init_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using IndexerT = dpctl::tensor::offset_utils::StridedIndexer; + const IndexerT res_indexer(res_outer_nd, 0, res_shape_strides); + using InitKernelName = class gemm_init_krn; + cgh.parallel_for( + sycl::range<1>(n * m), [=](sycl::id<1> id) { + auto res_offset = res_indexer(id[0]); + res_tp[res_offset] = resTy(0); + }); + }); + + if (k == 0) { + return 
res_init_ev; + } + + if ((max_nm < 64)) { + if (m < 4) { + return gemm_detail::_gemm_small_m_impl< + lhsTy, rhsTy, resTy, BatchIndexerT, OuterInnerIndexerT, + OuterInnerIndexerT, OuterInnerIndexerT>( + exec_q, lhs_tp, rhs_tp, res_tp, single_batch_nelems, n, k, m, + batch_indexer, lhs_indexer, rhs_indexer, res_indexer, + {res_init_ev}); + } + return gemm_detail::_gemm_k_impl( + exec_q, lhs_tp, rhs_tp, res_tp, single_batch_nelems, n, k, m, + batch_indexer, lhs_indexer, rhs_indexer, res_indexer, + {res_init_ev}); + } + + return gemm_detail::_gemm_batch_nm_impl< + lhsTy, rhsTy, resTy, BatchIndexerT, OuterInnerIndexerT, + OuterInnerIndexerT, OuterInnerIndexerT>( + exec_q, lhs_tp, rhs_tp, res_tp, single_batch_nelems, n, k, m, + batch_indexer, lhs_indexer, rhs_indexer, res_indexer, {res_init_ev}); +} + +typedef sycl::event (*gemm_contig_impl_fn_ptr_t)( + sycl::queue &, + const char *, // lhs + const char *, // rhs + char *, // res + std::size_t, // n + std::size_t, // k + std::size_t, // m + std::vector const &); + +template +sycl::event gemm_contig_impl(sycl::queue &exec_q, + const char *lhs_cp, + const char *rhs_cp, + char *res_cp, + std::size_t n, + std::size_t k, + std::size_t m, + std::vector const &depends = {}) +{ + const lhsTy *lhs_tp = reinterpret_cast(lhs_cp); + const rhsTy *rhs_tp = reinterpret_cast(rhs_cp); + resTy *res_tp = reinterpret_cast(res_cp); + + using OuterInnerIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + static constexpr OuterInnerIndexerT lhs_indexer{}; + static constexpr OuterInnerIndexerT rhs_indexer{}; + static constexpr OuterInnerIndexerT res_indexer{}; + + using BatchIndexerT = dpctl::tensor::offset_utils::ThreeZeroOffsets_Indexer; + static constexpr BatchIndexerT batch_indexer{}; + + static constexpr std::size_t single_batch_nelems = 1; + + const std::size_t min_nm = std::min(n, m); + const std::size_t max_nm = std::max(n, m); + if (min_nm > 0 && (max_nm >= ((64 * 1024) / min_nm))) { + return gemm_detail::_gemm_batch_nm_impl< + 
lhsTy, rhsTy, resTy, BatchIndexerT, OuterInnerIndexerT, + OuterInnerIndexerT, OuterInnerIndexerT>( + exec_q, lhs_tp, rhs_tp, res_tp, single_batch_nelems, n, k, m, + batch_indexer, lhs_indexer, rhs_indexer, res_indexer, depends); + } + + sycl::event res_init_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + cgh.fill(res_tp, resTy(0), n * m); + }); + + if (k == 0) { + return res_init_ev; + } + + if (max_nm < 64) { + if (m < 4) { + return gemm_detail::_gemm_small_m_impl< + lhsTy, rhsTy, resTy, BatchIndexerT, OuterInnerIndexerT, + OuterInnerIndexerT, OuterInnerIndexerT>( + exec_q, lhs_tp, rhs_tp, res_tp, single_batch_nelems, n, k, m, + batch_indexer, lhs_indexer, rhs_indexer, res_indexer, + {res_init_ev}); + } + return gemm_detail::_gemm_k_impl( + exec_q, lhs_tp, rhs_tp, res_tp, single_batch_nelems, n, k, m, + batch_indexer, lhs_indexer, rhs_indexer, res_indexer, + {res_init_ev}); + } + + return gemm_detail::_gemm_batch_nm_impl< + lhsTy, rhsTy, resTy, BatchIndexerT, OuterInnerIndexerT, + OuterInnerIndexerT, OuterInnerIndexerT>( + exec_q, lhs_tp, rhs_tp, res_tp, single_batch_nelems, n, k, m, + batch_indexer, lhs_indexer, rhs_indexer, res_indexer, {res_init_ev}); +} + +template +class gemm_batch_init_krn; + +typedef sycl::event (*gemm_batch_impl_fn_ptr_t)( + sycl::queue &, + const char *, // lhs + const char *, // rhs + char *, // res + std::size_t, // batch nelems + std::size_t, // lhs outer nelems (n) + std::size_t, // inner nelems (k) + std::size_t, // rhs outer nelems (m) + int, // batching nd + const ssize_t *, // batch shape strides + ssize_t, // lhs batch offset + ssize_t, // rhs batch offset + ssize_t, // res batch offset + int, // inner dims + int, // lhs outer dims + const ssize_t *, // lhs outer and inner shape and strides + int, // rhs outer dims + const ssize_t *, // rhs outer and inner shape and strides + int, // res outer dims + const ssize_t *, // res outer and inner shape and strides + const ssize_t *, // res full shape and 
strides + std::vector const &); + +template +sycl::event gemm_batch_impl(sycl::queue &exec_q, + const char *lhs_cp, + const char *rhs_cp, + char *res_cp, + std::size_t batch_nelems, + std::size_t n, + std::size_t k, + std::size_t m, + int batch_nd, + const ssize_t *batch_shape_strides, + ssize_t lhs_batch_offset, + ssize_t rhs_batch_offset, + ssize_t res_batch_offset, + int inner_nd, + int lhs_outer_nd, + const ssize_t *lhs_outer_inner_shapes_strides, + int rhs_outer_nd, + const ssize_t *rhs_outer_inner_shapes_strides, + int res_outer_nd, + const ssize_t *res_outer_shapes_strides, + const ssize_t *res_shape_strides, + std::vector const &depends = {}) +{ + const lhsTy *lhs_tp = reinterpret_cast(lhs_cp); + const rhsTy *rhs_tp = reinterpret_cast(rhs_cp); + resTy *res_tp = reinterpret_cast(res_cp); + + using OuterInnerDimsIndexerT = dpctl::tensor::offset_utils::StridedIndexer; + const OuterInnerDimsIndexerT lhs_indexer(inner_nd + lhs_outer_nd, 0, + lhs_outer_inner_shapes_strides); + const OuterInnerDimsIndexerT rhs_indexer(inner_nd + rhs_outer_nd, 0, + rhs_outer_inner_shapes_strides); + const OuterInnerDimsIndexerT res_indexer(res_outer_nd, 0, + res_outer_shapes_strides); + using BatchDimsIndexerT = + dpctl::tensor::offset_utils::ThreeOffsets_StridedIndexer; + const BatchDimsIndexerT batch_indexer(batch_nd, lhs_batch_offset, + rhs_batch_offset, res_batch_offset, + batch_shape_strides); + + const std::size_t min_nm = std::min(n, m); + const std::size_t max_nm = std::max(n, m); + + if (min_nm > 0 && (max_nm >= ((64 * 1024) / min_nm))) { + return gemm_detail::_gemm_batch_nm_impl< + lhsTy, rhsTy, resTy, BatchDimsIndexerT, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT>( + exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, + batch_indexer, lhs_indexer, rhs_indexer, res_indexer, depends); + } + + sycl::event res_init_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using IndexerT = 
dpctl::tensor::offset_utils::StridedIndexer; + const IndexerT res_indexer(batch_nd + res_outer_nd, res_batch_offset, + res_shape_strides); + using InitKernelName = class gemm_batch_init_krn; + cgh.parallel_for( + sycl::range<1>(n * m * batch_nelems), [=](sycl::id<1> id) { + auto res_offset = res_indexer(id[0]); + res_tp[res_offset] = resTy(0); + }); + }); + + if (k == 0) { + return res_init_ev; + } + + if (m < 4) { + return gemm_detail::_gemm_small_m_impl< + lhsTy, rhsTy, resTy, BatchDimsIndexerT, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT>( + exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, + batch_indexer, lhs_indexer, rhs_indexer, res_indexer, + {res_init_ev}); + } + else if (k > n && k > m) { + return gemm_detail::_gemm_k_impl< + lhsTy, rhsTy, resTy, BatchDimsIndexerT, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT>( + exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, + batch_indexer, lhs_indexer, rhs_indexer, res_indexer, + {res_init_ev}); + } + else { + return gemm_detail::_gemm_batch_nm_impl< + lhsTy, rhsTy, resTy, BatchDimsIndexerT, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT>( + exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, + batch_indexer, lhs_indexer, rhs_indexer, res_indexer, + {res_init_ev}); + } +} + +typedef sycl::event (*gemm_batch_contig_impl_fn_ptr_t)( + sycl::queue &, + const char *, // lhs + const char *, // rhs + char *, // res + std::size_t, // batch nelems + std::size_t, // n + std::size_t, // k + std::size_t, // m + ssize_t, // lhs batch offset + ssize_t, // rhs batch offset + ssize_t, // res batch offset + std::vector const &); + +template +sycl::event gemm_batch_contig_impl(sycl::queue &exec_q, + const char *lhs_cp, + const char *rhs_cp, + char *res_cp, + std::size_t batch_nelems, + std::size_t n, + std::size_t k, + std::size_t m, + ssize_t lhs_batch_offset, + ssize_t rhs_batch_offset, + ssize_t res_batch_offset, + std::vector const 
&depends = {}) +{ + const lhsTy *lhs_tp = + reinterpret_cast(lhs_cp) + lhs_batch_offset; + const rhsTy *rhs_tp = + reinterpret_cast(rhs_cp) + rhs_batch_offset; + resTy *res_tp = reinterpret_cast(res_cp) + res_batch_offset; + + using OuterInnerDimsIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + static constexpr OuterInnerDimsIndexerT lhs_indexer{}; + static constexpr OuterInnerDimsIndexerT rhs_indexer{}; + static constexpr OuterInnerDimsIndexerT res_indexer{}; + + using dpctl::tensor::offset_utils::Strided1DIndexer; + using dpctl::tensor::offset_utils::ThreeOffsets_CombinedIndexer; + using BatchDimsIndexerT = + ThreeOffsets_CombinedIndexer; + + const BatchDimsIndexerT batch_indexer( + Strided1DIndexer{/* size */ batch_nelems, + /* step */ n * k}, + Strided1DIndexer{/* size */ batch_nelems, + /* step */ k * m}, + Strided1DIndexer{/* size */ batch_nelems, + /* step */ n * m}); + + const std::size_t min_nm = std::min(n, m); + const std::size_t max_nm = std::max(n, m); + + if (min_nm > 0 && (max_nm >= ((64 * 1024) / min_nm))) { + return gemm_detail::_gemm_batch_nm_impl< + lhsTy, rhsTy, resTy, BatchDimsIndexerT, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT>( + exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, + batch_indexer, lhs_indexer, rhs_indexer, res_indexer, depends); + } + + sycl::event res_init_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + cgh.fill(res_tp, resTy(0), n * m * batch_nelems); + }); + + if (k == 0) { + return res_init_ev; + } + + if (max_nm < 64) { + if (m < 4) { + return gemm_detail::_gemm_small_m_impl< + lhsTy, rhsTy, resTy, BatchDimsIndexerT, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT>( + exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, + batch_indexer, lhs_indexer, rhs_indexer, res_indexer, + {res_init_ev}); + } + return gemm_detail::_gemm_k_impl< + lhsTy, rhsTy, resTy, BatchDimsIndexerT, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, 
OuterInnerDimsIndexerT>( + exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, + batch_indexer, lhs_indexer, rhs_indexer, res_indexer, + {res_init_ev}); + } + + return gemm_detail::_gemm_batch_nm_impl< + lhsTy, rhsTy, resTy, BatchDimsIndexerT, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT>( + exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, batch_indexer, + lhs_indexer, rhs_indexer, res_indexer, {res_init_ev}); +} + +// ========== Gemm Tree + +template +class GemmBatchNoAtomicFunctorThreadNM +{ +private: + const lhsT *lhs = nullptr; + const rhsT *rhs = nullptr; + resT *res = nullptr; + LocAccT1 local_A_block; + LocAccT2 local_B_block; + std::size_t n = 0; + std::size_t wg_delta_n = 0; + std::size_t k = 0; + std::size_t k_blocks = 0; + std::size_t wi_delta_k = 0; + std::size_t m = 0; + std::size_t m_blocks = 0; + std::size_t wg_delta_m = 0; + std::size_t batch_nelems; + BatchDimsIndexerT batch_indexer; + OuterInnerDimsIndexerT lhs_indexer; + OuterInnerDimsIndexerT rhs_indexer; + ResIndexerT res_indexer; + +public: + GemmBatchNoAtomicFunctorThreadNM(const lhsT *lhs_, + const rhsT *rhs_, + resT *res_, + LocAccT1 local_A_block_, + LocAccT2 local_B_block_, + std::size_t n_, + std::size_t wg_delta_n_, + std::size_t k_, + std::size_t k_blocks_, + std::size_t wi_delta_k_, + std::size_t m_, + std::size_t m_blocks_, + std::size_t wg_delta_m_, + std::size_t batch_nelems_, + const BatchDimsIndexerT batch_indexer_, + const OuterInnerDimsIndexerT lhs_indexer_, + const OuterInnerDimsIndexerT rhs_indexer_, + const ResIndexerT res_indexer_) + : lhs(lhs_), rhs(rhs_), res(res_), local_A_block(local_A_block_), + local_B_block(local_B_block_), n(n_), wg_delta_n(wg_delta_n_), k(k_), + k_blocks(k_blocks_), wi_delta_k(wi_delta_k_), m(m_), + m_blocks(m_blocks_), wg_delta_m(wg_delta_m_), + batch_nelems(batch_nelems_), batch_indexer(batch_indexer_), + lhs_indexer(lhs_indexer_), rhs_indexer(rhs_indexer_), + res_indexer(res_indexer_) + { + } + + void 
operator()(sycl::nd_item<1> it) const + { + const std::size_t n_groups_per_batch = + it.get_group_range(0) / batch_nelems; + const std::size_t m_id = it.get_group_linear_id() / n_groups_per_batch; + const std::size_t gr_id = + it.get_group_linear_id() - m_id * n_groups_per_batch; + + const auto &three_offsets_ = batch_indexer(static_cast(m_id)); + + // lift group_id to (block_i, block_j, block_s), + // 0 <= block_i < n_blocks, 0 <= block_j < m_blocks, 0 <= block_s + // < k_blocks + + const auto &lhs_offset = three_offsets_.get_first_offset(); + const auto &rhs_offset = three_offsets_.get_second_offset(); + const auto &res_offset = three_offsets_.get_third_offset(); + + std::size_t block_i = gr_id / (m_blocks * k_blocks); + std::size_t block_r = gr_id - block_i * (m_blocks * k_blocks); + std::size_t block_j = block_r / k_blocks; + std::size_t block_s = block_r - block_j * k_blocks; + + std::size_t lid = it.get_local_linear_id(); + std::size_t local_i = lid / wg_delta_m; // 0<= local_i < wg_delta_n + std::size_t local_j = + lid - local_i * wg_delta_m; // 0<= local_j < wg_delta_m + + // load A block and B blocks into SLM + + std::size_t i = block_i * wi_delta_n * wg_delta_n; + std::size_t j = block_j * wi_delta_m * wg_delta_m; + std::size_t s = block_s * wi_delta_k; + + const std::int64_t a_st0 = k; + const std::int64_t a_st1 = 1; + + const std::int64_t b_st0 = m; + const std::int64_t b_st1 = 1; + + const std::int64_t c_st0 = m; + const std::int64_t c_st1 = 1; + + std::size_t lws = it.get_local_range(0); + + for (std::size_t vid = lid; vid < local_A_block.size(); vid += lws) { + std::size_t v_i = + vid / wi_delta_k; // 0<= v_i < wg_delta_n * wi_delta_n + std::size_t v_s = vid - v_i * wi_delta_k; // 0<= v_s < wi_delta_k + + std::size_t g_i = i + v_i; + std::size_t g_s = s + v_s; + + local_A_block[vid] = + (g_i < n && g_s < k) + ? 
static_cast( + lhs[lhs_offset + + lhs_indexer(g_i * a_st0 + g_s * a_st1)]) + : resT(0); + } + + using slmB_t = typename LocAccT2::value_type; + + for (std::size_t vid = lid; vid < local_B_block.size(); vid += lws) { + std::size_t v_j = vid / wi_delta_k; // 0<= v_i < wg_delta_m + std::size_t v_s = vid - v_j * wi_delta_k; // 0<= v_s < wi_delta_k + + std::size_t g_j = j + v_j * wi_delta_m; + std::size_t g_s = s + v_s; + + if constexpr (wi_delta_m == 1 && std::is_same_v) { + local_B_block[vid] = + (g_j < m && g_s < k) + ? static_cast( + rhs[rhs_offset + + rhs_indexer(g_s * b_st0 + g_j * b_st1)]) + : resT(0); + } + else { + slmB_t vec{}; +#pragma unroll + for (std::uint8_t lane_id = 0; lane_id < wi_delta_m; ++lane_id) + { + std::size_t g_j1 = g_j + lane_id; + vec[lane_id] = + (g_j1 < m && g_s < k) + ? static_cast( + rhs[rhs_offset + + rhs_indexer(g_s * b_st0 + g_j1 * b_st1)]) + : resT(0); + } + + local_B_block[vid] = vec; + } + } + + it.barrier(sycl::access::fence_space::local_space); + + i += local_i * wi_delta_n; + j += local_j * wi_delta_m; + + const std::size_t a_offset = local_i * wi_delta_k * wi_delta_n; + const std::size_t b_offset = local_j * wi_delta_k; + + static constexpr resT identity_(0); + + for (std::uint8_t private_i = 0; private_i < wi_delta_n; ++private_i) { + const std::size_t a_pr_offset = private_i * wi_delta_k; + + slmB_t local_sum(identity_); + for (std::size_t private_s = 0; private_s < wi_delta_k; ++private_s) + { + local_sum = local_sum + + (local_A_block[a_offset + a_pr_offset + private_s] * + local_B_block[b_offset + private_s]); + } + + const std::size_t gl_i = i + private_i; + + if constexpr (wi_delta_m == 1 && std::is_same_v) { + const std::size_t gl_j = j; + if (gl_i < n && gl_j < m) { + res[res_offset + res_indexer(gl_i * c_st0 + gl_j * c_st1) + + (block_s * n * m * batch_nelems)] = local_sum; + } + } + else { +#pragma unroll + for (std::uint8_t lane_id = 0; lane_id < wi_delta_m; ++lane_id) + { + const std::size_t gl_j = j + lane_id; + + 
if (gl_i < n && gl_j < m) { + res[res_offset + + res_indexer(gl_i * c_st0 + gl_j * c_st1) + + (block_s * n * m * batch_nelems)] = + local_sum[lane_id]; + } + } + } + } + } +}; + +template +class GemmBatchNoAtomicFunctorThreadK +{ +private: + const lhsT *lhs = nullptr; + const rhsT *rhs = nullptr; + resT *res = nullptr; + LocAccT workspace; + LocAccT local_B_block; + std::size_t n = 0; + std::size_t n_blocks = 0; + std::size_t delta_n = 0; + std::size_t k = 0; + std::size_t k_blocks = 0; + std::size_t delta_k = 0; + std::size_t n_wi = 0; + std::size_t m = 0; + std::size_t batch_nelems = 0; + BatchDimsIndexerT batch_indexer; + OuterInnerDimsIndexerT lhs_indexer; + OuterInnerDimsIndexerT rhs_indexer; + ResIndexerT res_indexer; + +public: + GemmBatchNoAtomicFunctorThreadK(const lhsT *lhs_, + const rhsT *rhs_, + resT *res_, + LocAccT workspace_, + LocAccT local_B_block_, + std::size_t n_, + std::size_t n_blocks_, + std::size_t delta_n_, + std::size_t k_, + std::size_t k_blocks_, + std::size_t delta_k_, + std::size_t n_wi_, + std::size_t m_, + std::size_t batch_nelems_, + const BatchDimsIndexerT &batch_indexer_, + const OuterInnerDimsIndexerT &lhs_indexer_, + const OuterInnerDimsIndexerT &rhs_indexer_, + const ResIndexerT &res_indexer_) + : lhs(lhs_), rhs(rhs_), res(res_), workspace(workspace_), + local_B_block(local_B_block_), n(n_), n_blocks(n_blocks_), + delta_n(delta_n_), k(k_), k_blocks(k_blocks_), delta_k(delta_k_), + n_wi(n_wi_), m(m_), batch_nelems(batch_nelems_), + batch_indexer(batch_indexer_), lhs_indexer(lhs_indexer_), + rhs_indexer(rhs_indexer_), res_indexer(res_indexer_) + { + } + + void operator()(sycl::nd_item<1> it) const + { + const std::size_t n_groups_per_batch = + it.get_group_range(0) / batch_nelems; + const std::size_t m_id = it.get_group_linear_id() / n_groups_per_batch; + const std::size_t gr_id = + it.get_group_linear_id() - m_id * n_groups_per_batch; + std::size_t lid = it.get_local_linear_id(); + + const auto &three_offsets_ = 
batch_indexer(static_cast(m_id)); + const auto &lhs_offset = three_offsets_.get_first_offset(); + const auto &rhs_offset = three_offsets_.get_second_offset(); + const auto &res_offset = three_offsets_.get_third_offset(); + + // lift gr_id -> (block_i, block_j, block_s) + // block_i moves fastest, then block_s, then block_j + + const std::size_t r_size = (n_blocks * k_blocks); + // 0 <= block_j < m_blocks + std::size_t block_j = gr_id / r_size; + // 0 <= block_r < n_blocks * k_blocks + std::size_t block_r = gr_id - block_j * r_size; + // 0 <= block_s < k_blocks + std::size_t block_s = block_r / n_blocks; + // 0 <= block_i < n_blocks + std::size_t block_i = block_r - block_s * n_blocks; + + std::size_t local_i = lid / (delta_k); // 0 <= local_i < delta_n + std::size_t local_s = + lid - local_i * (delta_k); // 0 <= local_s < delta_k + + std::size_t i = block_i * delta_n + local_i; + std::size_t j = m_groups * block_j; + std::size_t s = block_s * delta_k * n_wi + local_s; + + using accV_t = typename LocAccT::value_type; + + static constexpr resT identity_ = resT(0); + if (local_i == 0) { + for (std::size_t q = 0; q < n_wi * delta_k; q += delta_k) { + std::size_t sq = s + q; + std::size_t sqmj = sq * m + j; + + if constexpr (m_groups == 1 && std::is_same_v) { + local_B_block[local_s + q] = + (sq < k && j < m) + ? static_cast( + rhs[rhs_offset + rhs_indexer(sqmj)]) + : identity_; + } + else { + accV_t local_B_vec; +#pragma unroll + for (std::size_t vec_idx = 0; vec_idx < m_groups; ++vec_idx) + { + local_B_vec[vec_idx] = + (sq < k && j + vec_idx < m) + ? 
static_cast( + rhs[rhs_offset + + rhs_indexer(sqmj + vec_idx)]) + : identity_; + } + local_B_block[local_s + q] = local_B_vec; + } + } + } + + it.barrier(sycl::access::fence_space::local_space); + + std::size_t t_shift = block_s * delta_k * n_wi; + std::size_t global_s_offset = i * k + t_shift; + + accV_t private_sum(identity_); + static constexpr accV_t vec_identity_(identity_); + for (std::size_t t = local_s; t < local_B_block.size(); t += delta_k) { + private_sum += + ((i < n) && (t + t_shift < k)) + ? (static_cast( + lhs[lhs_offset + lhs_indexer(global_s_offset + t)]) * + local_B_block[t]) + : vec_identity_; + } + + std::size_t workspace_i_shift = local_i * delta_k; + workspace[workspace_i_shift + local_s] = private_sum; + + it.barrier(sycl::access::fence_space::local_space); + + if (local_s == 0 && i < n) { + accV_t local_sum(workspace[workspace_i_shift]); + for (std::size_t t = 1; t < delta_k; ++t) { + local_sum += workspace[workspace_i_shift + t]; + } + + const std::size_t total_offset = + res_offset + (block_s * n * m * batch_nelems); + + if constexpr (m_groups == 1 && std::is_same_v) { + res[total_offset + res_indexer(i * m + j)] = local_sum; + } + else { + res[total_offset + res_indexer(i * m + j)] = local_sum[0]; + +#pragma unroll + for (std::size_t vec_id = 1; vec_id < m_groups; ++vec_id) { + if (j + vec_id < m) { + res[total_offset + res_indexer(i * m + j + vec_id)] = + local_sum[vec_id]; + } + } + } + } + } +}; + +template +class gemm_batch_tree_k_krn; + +template +class gemm_batch_tree_nm_krn; + +namespace gemm_detail +{ + +template +sycl::event _gemm_tree_k_step(sycl::queue &exec_q, + const lhsTy *lhs_tp, + const rhsTy *rhs_tp, + resTy *res_tp, + const std::size_t batch_nelems, + const std::size_t n, + const std::size_t k, + const std::size_t m, + const std::size_t delta_n, + const std::size_t n_wi, + const std::size_t delta_k, + const BatchIndexerT &batch_indexer, + const LhsIndexerT &lhs_indexer, + const RhsIndexerT &rhs_indexer, + const 
ResIndexerT &res_indexer, + const std::vector &depends) +{ + static_assert(std::is_same_v); + + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + const std::size_t n_blocks = (n + delta_n - 1) / delta_n; + const std::size_t k_blocks = + (k + n_wi * delta_k - 1) / (n_wi * delta_k); + const std::size_t m_blocks = (m + m_groups - 1) / m_groups; + + const std::size_t lws = delta_n * delta_k; + const std::size_t gws = + batch_nelems * n_blocks * m_blocks * k_blocks * lws; + + auto gRange = sycl::range<1>(gws); + auto lRange = sycl::range<1>(lws); + auto ndRange = sycl::nd_range<1>(gRange, lRange); + + using slmB_t = + typename std::conditional>::type; + + using LocAccT = sycl::local_accessor; + LocAccT local_B_block(n_wi * delta_k, cgh); + LocAccT workspace(delta_n * delta_k, cgh); + + using KernelName = + class gemm_batch_tree_k_krn; + + cgh.parallel_for( + ndRange, + GemmBatchNoAtomicFunctorThreadK( + lhs_tp, rhs_tp, res_tp, std::move(workspace), + std::move(local_B_block), n, n_blocks, delta_n, k, k_blocks, + delta_k, n_wi, m, batch_nelems, batch_indexer, lhs_indexer, + rhs_indexer, res_indexer)); + }); + return gemm_ev; +} + +} // end of namespace gemm_detail + +template +sycl::event + gemm_batch_tree_k_impl(sycl::queue &exec_q, + const lhsTy *lhs_tp, + const rhsTy *rhs_tp, + resTy *res_tp, + std::size_t batch_nelems, + std::size_t n, + std::size_t k, + std::size_t m, + int batch_nd, + const ssize_t *batch_shape_strides, + ssize_t lhs_batch_offset, + ssize_t rhs_batch_offset, + ssize_t res_batch_offset, + int inner_nd, + int lhs_outer_nd, + const ssize_t *lhs_outer_inner_shapes_strides, + int rhs_outer_nd, + const ssize_t *rhs_outer_inner_shapes_strides, + int res_outer_nd, + const ssize_t *res_outer_shapes_strides, + const ssize_t *res_shape_strides, + std::vector const &depends) +{ + std::size_t delta_k(4); + std::size_t n_wi(64); + std::size_t delta_n(32); + + const sycl::device &dev = exec_q.get_device(); + const 
std::size_t local_mem_size = + dev.get_info(); + const std::size_t reserved_slm_size = 512; + + gemm_detail::scale_gemm_k_parameters( + local_mem_size, reserved_slm_size, delta_k, + n_wi, // modified by reference + delta_n // modified by reference + ); + + if (k <= (delta_k * n_wi)) { + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::StridedIndexer; + const OuterInnerDimsIndexerT lhs_indexer( + inner_nd + lhs_outer_nd, 0, lhs_outer_inner_shapes_strides); + const OuterInnerDimsIndexerT rhs_indexer( + inner_nd + rhs_outer_nd, 0, rhs_outer_inner_shapes_strides); + const OuterInnerDimsIndexerT res_indexer(res_outer_nd, 0, + res_outer_shapes_strides); + using BatchDimsIndexerT = + dpctl::tensor::offset_utils::ThreeOffsets_StridedIndexer; + const BatchDimsIndexerT batch_indexer( + batch_nd, lhs_batch_offset, rhs_batch_offset, res_batch_offset, + batch_shape_strides); + + return gemm_detail::_gemm_tree_k_step< + lhsTy, rhsTy, resTy, BatchDimsIndexerT, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, m_groups>( + exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, delta_n, + n_wi, delta_k, batch_indexer, lhs_indexer, rhs_indexer, res_indexer, + depends); + } + else { + using ReductionOpT = + typename std::conditional, + sycl::logical_or, + sycl::plus>::type; + static constexpr resTy identity_val = + sycl::known_identity::value; + + std::size_t iter_nelems = batch_nelems * n * m; + std::size_t reduction_nelems = + (k + delta_k * n_wi - 1) / (delta_k * n_wi); + + // more than one work-group is needed, requires a + // temporary delta_k * n_wi elements processed along k, + // so if more to process use multiple + const auto &sg_sizes = + dev.get_info(); + std::size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes); + + static constexpr std::size_t preferred_reductions_per_wi = 4; + std::size_t reductions_per_wi(preferred_reductions_per_wi); + + std::size_t reduction_groups = + (reduction_nelems + preferred_reductions_per_wi * 
wg - 1) / + (preferred_reductions_per_wi * wg); + + // max_max_wg prevents running out of resources on CPU + static constexpr std::size_t max_max_wg = 2048; + std::size_t max_wg = std::min( + max_max_wg, + dev.get_info() / 2); + + if (reduction_nelems <= preferred_reductions_per_wi * max_wg) { + auto tmp_owner = + dpctl::tensor::alloc_utils::smart_malloc_device( + iter_nelems * reduction_nelems, exec_q); + resTy *tmp = tmp_owner.get(); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::StridedIndexer; + using TmpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + const OuterInnerDimsIndexerT lhs_indexer( + inner_nd + lhs_outer_nd, 0, lhs_outer_inner_shapes_strides); + const OuterInnerDimsIndexerT rhs_indexer( + inner_nd + rhs_outer_nd, 0, rhs_outer_inner_shapes_strides); + static constexpr TmpIndexerT res_indexer{}; + + using dpctl::tensor::offset_utils::Strided1DIndexer; + using dpctl::tensor::offset_utils::StridedIndexer; + using dpctl::tensor::offset_utils::ThreeOffsets_CombinedIndexer; + using dpctl::tensor::offset_utils::UnpackedStridedIndexer; + using BatchDimsIndexerT = ThreeOffsets_CombinedIndexer< + StridedIndexer, UnpackedStridedIndexer, Strided1DIndexer>; + const StridedIndexer lhs_batch_indexer(batch_nd, lhs_batch_offset, + batch_shape_strides); + const UnpackedStridedIndexer rhs_batch_indexer( + batch_nd, rhs_batch_offset, batch_shape_strides, + batch_shape_strides + 2 * batch_nd); + const Strided1DIndexer tmp_batch_indexer( + /* size */ batch_nelems, + /* step */ n * m); + const BatchDimsIndexerT batch_indexer( + lhs_batch_indexer, rhs_batch_indexer, tmp_batch_indexer); + + sycl::event gemm_ev = gemm_detail::_gemm_tree_k_step< + lhsTy, rhsTy, resTy, BatchDimsIndexerT, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, TmpIndexerT, m_groups>( + exec_q, lhs_tp, rhs_tp, tmp, batch_nelems, n, k, m, delta_n, + n_wi, delta_k, batch_indexer, lhs_indexer, rhs_indexer, + res_indexer, depends); + + sycl::event red_ev = 
single_reduction_for_gemm( + exec_q, tmp, res_tp, identity_val, iter_nelems, + reduction_nelems, reduction_groups, wg, max_wg, + preferred_reductions_per_wi, reductions_per_wi, + batch_nd + res_outer_nd, res_batch_offset, res_shape_strides, + {gemm_ev}); + + sycl::event cleanup_host_task_event = + dpctl::tensor::alloc_utils::async_smart_free(exec_q, {red_ev}, + tmp_owner); + + return cleanup_host_task_event; + } + else { + assert(reduction_groups > 1); + + const std::size_t tmp_alloc_size = + iter_nelems * ( + /* temp */ reduction_nelems + + /* first reduction temp */ reduction_groups); + + // get unique_ptr owning the temporary allocation + auto tmp_owner = + dpctl::tensor::alloc_utils::smart_malloc_device( + tmp_alloc_size, exec_q); + // get raw USM pointer + resTy *partially_reduced_tmp = tmp_owner.get(); + resTy *partially_reduced_tmp2 = + partially_reduced_tmp + reduction_nelems * iter_nelems; + ; + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::StridedIndexer; + using TmpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + const OuterInnerDimsIndexerT lhs_indexer( + inner_nd + lhs_outer_nd, 0, lhs_outer_inner_shapes_strides); + const OuterInnerDimsIndexerT rhs_indexer( + inner_nd + rhs_outer_nd, 0, rhs_outer_inner_shapes_strides); + static constexpr TmpIndexerT res_indexer{}; + using dpctl::tensor::offset_utils::Strided1DIndexer; + using dpctl::tensor::offset_utils::StridedIndexer; + using dpctl::tensor::offset_utils::ThreeOffsets_CombinedIndexer; + using BatchDimsIndexerT = + ThreeOffsets_CombinedIndexer; + const StridedIndexer lhs_batch_indexer(batch_nd, lhs_batch_offset, + batch_shape_strides); + const StridedIndexer rhs_batch_indexer( + batch_nd, rhs_batch_offset, batch_shape_strides + 2 * batch_nd); + const Strided1DIndexer tmp_batch_indexer( + /* size */ batch_nelems, + /* step */ n * m); + const BatchDimsIndexerT batch_indexer( + lhs_batch_indexer, rhs_batch_indexer, tmp_batch_indexer); + + sycl::event gemm_ev = 
gemm_detail::_gemm_tree_k_step< + lhsTy, rhsTy, resTy, BatchDimsIndexerT, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, TmpIndexerT, m_groups>( + exec_q, lhs_tp, rhs_tp, partially_reduced_tmp, batch_nelems, n, + k, m, delta_n, n_wi, delta_k, batch_indexer, lhs_indexer, + rhs_indexer, res_indexer, depends); + + sycl::event red_ev = tree_reduction_for_gemm( + exec_q, partially_reduced_tmp, partially_reduced_tmp2, res_tp, + identity_val, iter_nelems, reduction_nelems, reduction_groups, + wg, max_wg, preferred_reductions_per_wi, reductions_per_wi, + batch_nd + res_outer_nd, res_batch_offset, res_shape_strides, + {gemm_ev}); + + sycl::event cleanup_host_task_event = + dpctl::tensor::alloc_utils::async_smart_free(exec_q, {red_ev}, + tmp_owner); + + return cleanup_host_task_event; + } + } +} + +namespace gemm_detail +{ + +template +sycl::event _gemm_tree_nm_step(sycl::queue &exec_q, + const lhsTy *lhs_tp, + const rhsTy *rhs_tp, + resTy *res_tp, + const std::size_t batch_nelems, + const std::size_t n, + const std::size_t k, + const std::size_t m, + const std::uint32_t wg_delta_n, + const std::uint32_t wg_delta_m, + const std::uint32_t wi_delta_k, + const BatchIndexerT &batch_indexer, + const LhsIndexerT &lhs_indexer, + const RhsIndexerT &rhs_indexer, + const ResIndexerT &res_indexer, + const std::vector &depends) +{ + static_assert(std::is_same_v); + + sycl::event gemm_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + const std::size_t lws = wg_delta_n * wg_delta_m; + + const std::size_t n_blocks = + ((n + wi_delta_n * wg_delta_n - 1) / (wi_delta_n * wg_delta_n)); + const std::size_t k_blocks = ((k + wi_delta_k - 1) / wi_delta_k); + const std::size_t m_blocks = + ((m + wi_delta_m * wg_delta_m - 1) / (wi_delta_m * wg_delta_m)); + + const std::size_t gws = + batch_nelems * n_blocks * m_blocks * k_blocks * lws; + + auto gwsRange = sycl::range<1>(gws); + auto lwsRange = sycl::range<1>(lws); + auto ndRange = sycl::nd_range<1>(gwsRange, lwsRange); + 
+ using slmB_t = + typename std::conditional>::type; + using LocAccT1 = sycl::local_accessor; + using LocAccT2 = sycl::local_accessor; + + const sycl::range<1> local_A_size((wi_delta_n * wg_delta_n) * + wi_delta_k); + const sycl::range<1> local_B_size(wi_delta_k * wg_delta_m); + + LocAccT1 local_A_block(local_A_size, cgh); + LocAccT2 local_B_block(local_B_size, cgh); + + using KernelName = + class gemm_batch_tree_nm_krn; + cgh.parallel_for( + ndRange, GemmBatchNoAtomicFunctorThreadNM< + lhsTy, rhsTy, resTy, LocAccT1, LocAccT2, LhsIndexerT, + ResIndexerT, BatchIndexerT, wi_delta_n, wi_delta_m>( + lhs_tp, rhs_tp, res_tp, std::move(local_A_block), + std::move(local_B_block), n, wg_delta_n, k, k_blocks, + wi_delta_k, m, m_blocks, wg_delta_m, batch_nelems, + batch_indexer, lhs_indexer, rhs_indexer, res_indexer)); + }); + return gemm_ev; +} + +} // end namespace gemm_detail + +template +sycl::event + gemm_batch_tree_nm_impl(sycl::queue &exec_q, + const lhsTy *lhs_tp, + const rhsTy *rhs_tp, + resTy *res_tp, + std::size_t batch_nelems, + std::size_t n, + std::size_t k, + std::size_t m, + int batch_nd, + const ssize_t *batch_shape_strides, + ssize_t lhs_batch_offset, + ssize_t rhs_batch_offset, + ssize_t res_batch_offset, + int inner_nd, + int lhs_outer_nd, + const ssize_t *lhs_outer_inner_shapes_strides, + int rhs_outer_nd, + const ssize_t *rhs_outer_inner_shapes_strides, + int res_outer_nd, + const ssize_t *res_outer_shapes_strides, + const ssize_t *res_shape_strides, + std::vector const &depends) +{ + static constexpr int wi_delta_n = 2; + std::size_t wg_delta_n(16); // rows of A processed in WG + std::size_t wg_delta_m(16); // rows of B processed in WG + std::size_t wi_delta_k(64); // Elements in K dimension processed by WI + + const sycl::device &dev = exec_q.get_device(); + const std::size_t local_mem_size = + dev.get_info(); + const std::size_t reserved_slm_size = 512; + + gemm_detail::scale_gemm_nm_parameters( + local_mem_size, reserved_slm_size, wi_delta_n, + 
wi_delta_k, // modified by reference + wg_delta_n, // modified by reference + wg_delta_m // modified by reference + ); + + // each group processes delta_k * n_wi + // items in a column, so no need for allocating + // temp memory if only one group is needed + if (k <= wi_delta_k) { + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::StridedIndexer; + const OuterInnerDimsIndexerT lhs_indexer( + inner_nd + lhs_outer_nd, 0, lhs_outer_inner_shapes_strides); + const OuterInnerDimsIndexerT rhs_indexer( + inner_nd + rhs_outer_nd, 0, rhs_outer_inner_shapes_strides); + const OuterInnerDimsIndexerT res_indexer(res_outer_nd, 0, + res_outer_shapes_strides); + using BatchDimsIndexerT = + dpctl::tensor::offset_utils::ThreeOffsets_StridedIndexer; + const BatchDimsIndexerT batch_indexer( + batch_nd, lhs_batch_offset, rhs_batch_offset, res_batch_offset, + batch_shape_strides); + + return gemm_detail::_gemm_tree_nm_step< + lhsTy, rhsTy, resTy, BatchDimsIndexerT, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, wi_delta_n, + wi_delta_m>(exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, + wg_delta_n, wg_delta_m, wi_delta_k, batch_indexer, + lhs_indexer, rhs_indexer, res_indexer, depends); + } + else { + using ReductionOpT = + typename std::conditional, + sycl::logical_or, + sycl::plus>::type; + static constexpr resTy identity_val = + sycl::known_identity::value; + std::size_t iter_nelems = batch_nelems * n * m; + std::size_t reduction_nelems = (k + wi_delta_k - 1) / wi_delta_k; + + // more than one work-group is needed, requires a temporary + // delta_k * n_wi elements processed along k, so if more to + // process use multiple + const auto &sg_sizes = + dev.get_info(); + std::size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes); + + static constexpr std::size_t preferred_reductions_per_wi = 4; + std::size_t reductions_per_wi(preferred_reductions_per_wi); + + std::size_t reduction_groups = + (reduction_nelems + 
preferred_reductions_per_wi * wg - 1) / + (preferred_reductions_per_wi * wg); + + std::size_t max_wg = reduction_detail::get_work_group_size(dev); + + if (reduction_nelems <= preferred_reductions_per_wi * max_wg) { + auto tmp_owner = + dpctl::tensor::alloc_utils::smart_malloc_device( + iter_nelems * reduction_nelems, exec_q); + resTy *tmp = tmp_owner.get(); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::StridedIndexer; + using TmpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + const OuterInnerDimsIndexerT lhs_indexer( + inner_nd + lhs_outer_nd, 0, lhs_outer_inner_shapes_strides); + const OuterInnerDimsIndexerT rhs_indexer( + inner_nd + rhs_outer_nd, 0, rhs_outer_inner_shapes_strides); + static constexpr TmpIndexerT res_indexer{}; + + using dpctl::tensor::offset_utils::Strided1DIndexer; + using dpctl::tensor::offset_utils::StridedIndexer; + using dpctl::tensor::offset_utils::ThreeOffsets_CombinedIndexer; + using dpctl::tensor::offset_utils::UnpackedStridedIndexer; + using BatchDimsIndexerT = ThreeOffsets_CombinedIndexer< + StridedIndexer, UnpackedStridedIndexer, Strided1DIndexer>; + const StridedIndexer lhs_batch_indexer(batch_nd, lhs_batch_offset, + batch_shape_strides); + const UnpackedStridedIndexer rhs_batch_indexer( + batch_nd, rhs_batch_offset, batch_shape_strides, + batch_shape_strides + 2 * batch_nd); + const Strided1DIndexer tmp_batch_indexer( + /* size */ batch_nelems, + /* step */ n * m); + const BatchDimsIndexerT batch_indexer( + lhs_batch_indexer, rhs_batch_indexer, tmp_batch_indexer); + + sycl::event gemm_ev = gemm_detail::_gemm_tree_nm_step< + lhsTy, rhsTy, resTy, BatchDimsIndexerT, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, TmpIndexerT, wi_delta_n, wi_delta_m>( + exec_q, lhs_tp, rhs_tp, tmp, batch_nelems, n, k, m, wg_delta_n, + wg_delta_m, wi_delta_k, batch_indexer, lhs_indexer, rhs_indexer, + res_indexer, depends); + + sycl::event red_ev = single_reduction_for_gemm( + exec_q, tmp, res_tp, identity_val, iter_nelems, + 
reduction_nelems, reduction_groups, wg, max_wg, + preferred_reductions_per_wi, reductions_per_wi, + batch_nd + res_outer_nd, res_batch_offset, res_shape_strides, + {gemm_ev}); + + sycl::event cleanup_host_task_event = + dpctl::tensor::alloc_utils::async_smart_free(exec_q, {red_ev}, + tmp_owner); + + return cleanup_host_task_event; + } + else { + assert(reduction_groups > 1); + + const std::size_t tmp_alloc_size = + iter_nelems * (/* temp */ reduction_nelems + + /* first reduction temp */ reduction_groups); + + auto tmp_owner = + dpctl::tensor::alloc_utils::smart_malloc_device( + tmp_alloc_size, exec_q); + + resTy *partially_reduced_tmp = tmp_owner.get(); + resTy *partially_reduced_tmp2 = + partially_reduced_tmp + reduction_nelems * iter_nelems; + ; + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::StridedIndexer; + using TmpIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + + const OuterInnerDimsIndexerT lhs_indexer( + inner_nd + lhs_outer_nd, 0, lhs_outer_inner_shapes_strides); + const OuterInnerDimsIndexerT rhs_indexer( + inner_nd + rhs_outer_nd, 0, rhs_outer_inner_shapes_strides); + static constexpr TmpIndexerT res_indexer{}; + + using dpctl::tensor::offset_utils::Strided1DIndexer; + using dpctl::tensor::offset_utils::StridedIndexer; + using dpctl::tensor::offset_utils::ThreeOffsets_CombinedIndexer; + using dpctl::tensor::offset_utils::UnpackedStridedIndexer; + using BatchDimsIndexerT = ThreeOffsets_CombinedIndexer< + StridedIndexer, UnpackedStridedIndexer, Strided1DIndexer>; + + const StridedIndexer lhs_batch_indexer(batch_nd, lhs_batch_offset, + batch_shape_strides); + const UnpackedStridedIndexer rhs_batch_indexer( + batch_nd, rhs_batch_offset, batch_shape_strides, + batch_shape_strides + 2 * batch_nd); + const Strided1DIndexer tmp_batch_indexer( + /* size */ batch_nelems, + /* step */ n * m); + const BatchDimsIndexerT batch_indexer( + lhs_batch_indexer, rhs_batch_indexer, tmp_batch_indexer); + + sycl::event gemm_ev = 
gemm_detail::_gemm_tree_nm_step< + lhsTy, rhsTy, resTy, BatchDimsIndexerT, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, TmpIndexerT, wi_delta_n, wi_delta_m>( + exec_q, lhs_tp, rhs_tp, partially_reduced_tmp, batch_nelems, n, + k, m, wg_delta_n, wg_delta_m, wi_delta_k, batch_indexer, + lhs_indexer, rhs_indexer, res_indexer, depends); + + sycl::event red_ev = tree_reduction_for_gemm( + exec_q, partially_reduced_tmp, partially_reduced_tmp2, res_tp, + identity_val, iter_nelems, reduction_nelems, reduction_groups, + wg, max_wg, preferred_reductions_per_wi, reductions_per_wi, + batch_nd + res_outer_nd, res_batch_offset, res_shape_strides, + {gemm_ev}); + + sycl::event cleanup_host_task_event = + dpctl::tensor::alloc_utils::async_smart_free(exec_q, {red_ev}, + tmp_owner); + + return cleanup_host_task_event; + } + } +} + +template +sycl::event gemm_batch_nm_impl(sycl::queue &exec_q, + const lhsTy *lhs_tp, + const rhsTy *rhs_tp, + resTy *res_tp, + std::size_t batch_nelems, + std::size_t n, + std::size_t k, + std::size_t m, + int batch_nd, + const ssize_t *batch_shape_strides, + ssize_t lhs_batch_offset, + ssize_t rhs_batch_offset, + ssize_t res_batch_offset, + int inner_nd, + int lhs_outer_nd, + const ssize_t *lhs_outer_inner_shapes_strides, + int rhs_outer_nd, + const ssize_t *rhs_outer_inner_shapes_strides, + int res_outer_nd, + const ssize_t *res_outer_shapes_strides, + const ssize_t *res_shape_strides, + std::vector const &depends = {}) +{ + + using OuterInnerDimsIndexerT = dpctl::tensor::offset_utils::StridedIndexer; + const OuterInnerDimsIndexerT lhs_indexer(inner_nd + lhs_outer_nd, 0, + lhs_outer_inner_shapes_strides); + const OuterInnerDimsIndexerT rhs_indexer(inner_nd + rhs_outer_nd, 0, + rhs_outer_inner_shapes_strides); + const OuterInnerDimsIndexerT res_indexer(res_outer_nd, 0, + res_outer_shapes_strides); + + using BatchDimsIndexerT = + dpctl::tensor::offset_utils::ThreeOffsets_StridedIndexer; + const BatchDimsIndexerT batch_indexer(batch_nd, 
lhs_batch_offset, + rhs_batch_offset, res_batch_offset, + batch_shape_strides); + + sycl::event gemm_ev = gemm_detail::_gemm_batch_nm_impl< + lhsTy, rhsTy, resTy, BatchDimsIndexerT, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT>( + exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, batch_indexer, + lhs_indexer, rhs_indexer, res_indexer, depends); + + return gemm_ev; +} + +template +class gemm_batch_tree_empty_krn; + +template +sycl::event gemm_batch_tree_impl(sycl::queue &exec_q, + const char *lhs_cp, + const char *rhs_cp, + char *res_cp, + std::size_t batch_nelems, + std::size_t n, + std::size_t k, + std::size_t m, + int batch_nd, + const ssize_t *batch_shape_strides, + ssize_t lhs_batch_offset, + ssize_t rhs_batch_offset, + ssize_t res_batch_offset, + int inner_nd, + int lhs_outer_nd, + const ssize_t *lhs_outer_inner_shapes_strides, + int rhs_outer_nd, + const ssize_t *rhs_outer_inner_shapes_strides, + int res_outer_nd, + const ssize_t *res_outer_shapes_strides, + const ssize_t *res_shape_strides, + std::vector const &depends = {}) +{ + const lhsTy *lhs_tp = reinterpret_cast(lhs_cp); + const rhsTy *rhs_tp = reinterpret_cast(rhs_cp); + resTy *res_tp = reinterpret_cast(res_cp); + + const std::size_t min_nm = std::min(n, m); + const std::size_t max_nm = std::max(n, m); + + if (min_nm > 0 && (max_nm >= ((64 * 1024) / min_nm))) { + return gemm_batch_nm_impl( + exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, batch_nd, + batch_shape_strides, lhs_batch_offset, rhs_batch_offset, + res_batch_offset, inner_nd, lhs_outer_nd, + lhs_outer_inner_shapes_strides, rhs_outer_nd, + rhs_outer_inner_shapes_strides, res_outer_nd, + res_outer_shapes_strides, res_shape_strides, depends); + } + + if (k == 0) { + sycl::event gemm_batch_no_reduction_ev = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using IndexerT = dpctl::tensor::offset_utils::StridedIndexer; + const IndexerT res_indexer(batch_nd + res_outer_nd, + 
res_batch_offset, res_shape_strides); + using InitKernelName = + class gemm_batch_tree_empty_krn; + cgh.parallel_for( + sycl::range<1>(n * m * batch_nelems), [=](sycl::id<1> id) { + auto res_offset = res_indexer(id[0]); + res_tp[res_offset] = resTy(0); + }); + }); + return gemm_batch_no_reduction_ev; + } + + if (max_nm < 64) { + using dpctl::tensor::type_utils::is_complex; + if constexpr (!is_complex::value) { + if (m < 4) { + static constexpr std::uint32_t m_groups_one = 1; + return gemm_batch_tree_k_impl( + exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, + batch_nd, batch_shape_strides, lhs_batch_offset, + rhs_batch_offset, res_batch_offset, inner_nd, lhs_outer_nd, + lhs_outer_inner_shapes_strides, rhs_outer_nd, + rhs_outer_inner_shapes_strides, res_outer_nd, + res_outer_shapes_strides, res_shape_strides, depends); + } + else { + static constexpr std::uint32_t m_groups_four = 4; + return gemm_batch_tree_k_impl( + exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, + batch_nd, batch_shape_strides, lhs_batch_offset, + rhs_batch_offset, res_batch_offset, inner_nd, lhs_outer_nd, + lhs_outer_inner_shapes_strides, rhs_outer_nd, + rhs_outer_inner_shapes_strides, res_outer_nd, + res_outer_shapes_strides, res_shape_strides, depends); + } + } + else { + static constexpr std::uint32_t m_groups_one = 1; + return gemm_batch_tree_k_impl( + exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, batch_nd, + batch_shape_strides, lhs_batch_offset, rhs_batch_offset, + res_batch_offset, inner_nd, lhs_outer_nd, + lhs_outer_inner_shapes_strides, rhs_outer_nd, + rhs_outer_inner_shapes_strides, res_outer_nd, + res_outer_shapes_strides, res_shape_strides, depends); + } + } + else { // m > 1, n > k or m > k + using dpctl::tensor::type_utils::is_complex; + if constexpr (!is_complex::value) { + static constexpr std::uint32_t m_groups_four = 4; + return gemm_batch_tree_nm_impl( + exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, batch_nd, + batch_shape_strides, 
lhs_batch_offset, rhs_batch_offset, + res_batch_offset, inner_nd, lhs_outer_nd, + lhs_outer_inner_shapes_strides, rhs_outer_nd, + rhs_outer_inner_shapes_strides, res_outer_nd, + res_outer_shapes_strides, res_shape_strides, depends); + } + else { // m > 1, n > k or m > k, resTy complex + static constexpr std::uint32_t m_groups_one = 1; + return gemm_batch_tree_nm_impl( + exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, batch_nd, + batch_shape_strides, lhs_batch_offset, rhs_batch_offset, + res_batch_offset, inner_nd, lhs_outer_nd, + lhs_outer_inner_shapes_strides, rhs_outer_nd, + rhs_outer_inner_shapes_strides, res_outer_nd, + res_outer_shapes_strides, res_shape_strides, depends); + } + } +} + +template +sycl::event + gemm_batch_contig_tree_k_impl(sycl::queue &exec_q, + const lhsTy *lhs_tp, + const rhsTy *rhs_tp, + resTy *res_tp, + std::size_t batch_nelems, + std::size_t n, + std::size_t k, + std::size_t m, + std::vector const &depends) +{ + std::size_t delta_k(4); + std::size_t n_wi(64); + std::size_t delta_n(32); + + const sycl::device &dev = exec_q.get_device(); + const std::size_t local_mem_size = + dev.get_info(); + const std::size_t reserved_slm_size = 512; + + gemm_detail::scale_gemm_k_parameters( + local_mem_size, reserved_slm_size, delta_k, + n_wi, // modified by reference + delta_n // modified by reference + ); + + if (k <= (delta_k * n_wi)) { + using OuterInnerDimsIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + static constexpr OuterInnerDimsIndexerT lhs_indexer{}; + static constexpr OuterInnerDimsIndexerT rhs_indexer{}; + static constexpr OuterInnerDimsIndexerT res_indexer{}; + + using dpctl::tensor::offset_utils::Strided1DIndexer; + using dpctl::tensor::offset_utils::ThreeOffsets_CombinedIndexer; + using BatchDimsIndexerT = + ThreeOffsets_CombinedIndexer; + + using dpctl::tensor::offset_utils::Strided1DIndexer; + const BatchDimsIndexerT batch_indexer( + Strided1DIndexer{/* size */ batch_nelems, + /* step */ n * k}, + Strided1DIndexer{/* size 
*/ batch_nelems, + /* step */ k * m}, + Strided1DIndexer{/* size */ batch_nelems, + /* step */ n * m}); + + return gemm_detail::_gemm_tree_k_step< + lhsTy, rhsTy, resTy, BatchDimsIndexerT, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, m_groups>( + exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, delta_n, + n_wi, delta_k, batch_indexer, lhs_indexer, rhs_indexer, res_indexer, + depends); + } + else { + using ReductionOpT = + typename std::conditional, + sycl::logical_or, + sycl::plus>::type; + static constexpr resTy identity_val = + sycl::known_identity::value; + + std::size_t iter_nelems = batch_nelems * n * m; + std::size_t reduction_nelems = + (k + delta_k * n_wi - 1) / (delta_k * n_wi); + + // more than one work-group is needed, requires a + // temporary delta_k * n_wi elements processed along k, + // so if more to process use multiple + const auto &sg_sizes = + dev.get_info(); + std::size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes); + + static constexpr std::size_t preferred_reductions_per_wi = 4; + std::size_t reductions_per_wi(preferred_reductions_per_wi); + + std::size_t reduction_groups = + (reduction_nelems + preferred_reductions_per_wi * wg - 1) / + (preferred_reductions_per_wi * wg); + + std::size_t max_wg = reduction_detail::get_work_group_size(dev); + + if (reduction_nelems <= preferred_reductions_per_wi * max_wg) { + auto tmp_owner = + dpctl::tensor::alloc_utils::smart_malloc_device( + iter_nelems * reduction_nelems, exec_q); + resTy *tmp = tmp_owner.get(); + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + static constexpr OuterInnerDimsIndexerT lhs_indexer{}; + static constexpr OuterInnerDimsIndexerT rhs_indexer{}; + static constexpr OuterInnerDimsIndexerT tmp_indexer{}; + using dpctl::tensor::offset_utils::Strided1DIndexer; + using dpctl::tensor::offset_utils::ThreeOffsets_CombinedIndexer; + using BatchDimsIndexerT = + ThreeOffsets_CombinedIndexer; + + const 
BatchDimsIndexerT batch_indexer( + Strided1DIndexer{/* size */ batch_nelems, + /* step */ n * k}, + Strided1DIndexer{/* size */ batch_nelems, + /* step */ k * m}, + Strided1DIndexer{/* size */ batch_nelems, + /* step */ n * m}); + + sycl::event gemm_ev = gemm_detail::_gemm_tree_k_step< + lhsTy, rhsTy, resTy, BatchDimsIndexerT, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, m_groups>( + exec_q, lhs_tp, rhs_tp, tmp, batch_nelems, n, k, m, delta_n, + n_wi, delta_k, batch_indexer, lhs_indexer, rhs_indexer, + tmp_indexer, depends); + + sycl::event red_ev = + single_reduction_for_gemm_contig( + exec_q, tmp, res_tp, identity_val, iter_nelems, + reduction_nelems, reduction_groups, wg, max_wg, + preferred_reductions_per_wi, reductions_per_wi, {gemm_ev}); + + sycl::event cleanup_host_task_event = + dpctl::tensor::alloc_utils::async_smart_free(exec_q, {red_ev}, + tmp_owner); + + return cleanup_host_task_event; + } + else { + assert(reduction_groups > 1); + + const std::size_t tmp_alloc_size = + iter_nelems * (/* temp */ reduction_nelems + + /* first reduction temp */ reduction_groups); + + auto tmp_owner = + dpctl::tensor::alloc_utils::smart_malloc_device( + tmp_alloc_size, exec_q); + + resTy *partially_reduced_tmp = tmp_owner.get(); + resTy *partially_reduced_tmp2 = + partially_reduced_tmp + reduction_nelems * iter_nelems; + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + static constexpr OuterInnerDimsIndexerT lhs_indexer{}; + static constexpr OuterInnerDimsIndexerT rhs_indexer{}; + static constexpr OuterInnerDimsIndexerT tmp_indexer{}; + using dpctl::tensor::offset_utils::Strided1DIndexer; + using dpctl::tensor::offset_utils::ThreeOffsets_CombinedIndexer; + using BatchDimsIndexerT = + ThreeOffsets_CombinedIndexer; + + const BatchDimsIndexerT batch_indexer( + Strided1DIndexer{/* size */ batch_nelems, + /* step */ n * k}, + Strided1DIndexer{/* size */ batch_nelems, + /* step */ k * m}, + Strided1DIndexer{/* size */ 
batch_nelems, + /* step */ n * m}); + + sycl::event gemm_ev = gemm_detail::_gemm_tree_k_step< + lhsTy, rhsTy, resTy, BatchDimsIndexerT, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, m_groups>( + exec_q, lhs_tp, rhs_tp, partially_reduced_tmp, batch_nelems, n, + k, m, delta_n, n_wi, delta_k, batch_indexer, lhs_indexer, + rhs_indexer, tmp_indexer, depends); + + sycl::event red_ev = + tree_reduction_for_gemm_contig( + exec_q, partially_reduced_tmp, partially_reduced_tmp2, + res_tp, identity_val, iter_nelems, reduction_nelems, + reduction_groups, wg, max_wg, preferred_reductions_per_wi, + reductions_per_wi, {gemm_ev}); + + sycl::event cleanup_host_task_event = + dpctl::tensor::alloc_utils::async_smart_free(exec_q, {red_ev}, + tmp_owner); + + return cleanup_host_task_event; + } + } +} + +template +sycl::event + gemm_batch_contig_tree_nm_impl(sycl::queue &exec_q, + const lhsTy *lhs_tp, + const rhsTy *rhs_tp, + resTy *res_tp, + std::size_t batch_nelems, + std::size_t n, + std::size_t k, + std::size_t m, + std::vector const &depends) +{ + static constexpr int wi_delta_n = 2; + std::size_t wg_delta_n(16); // rows of A processed in WG + std::size_t wg_delta_m(16); // rows of B processed in WG + std::size_t wi_delta_k(64); // Elements in K dimension processed by WI + + const sycl::device &dev = exec_q.get_device(); + const std::size_t local_mem_size = + dev.get_info(); + const std::size_t reserved_slm_size = 512; + + gemm_detail::scale_gemm_nm_parameters( + local_mem_size, reserved_slm_size, wi_delta_n, + wi_delta_k, // modified by reference + wg_delta_n, // modified by reference + wg_delta_m // modified by reference + ); + + // each group processes delta_k * n_wi + // items in a column, so no need for allocating + // temp memory if only one group is needed + if (k <= wi_delta_k) { + using OuterInnerDimsIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + static constexpr OuterInnerDimsIndexerT lhs_indexer{}; + static constexpr 
OuterInnerDimsIndexerT rhs_indexer{}; + static constexpr OuterInnerDimsIndexerT res_indexer{}; + + using dpctl::tensor::offset_utils::Strided1DIndexer; + using dpctl::tensor::offset_utils::ThreeOffsets_CombinedIndexer; + using BatchDimsIndexerT = + ThreeOffsets_CombinedIndexer; + + const BatchDimsIndexerT batch_indexer( + Strided1DIndexer{/* size */ batch_nelems, + /* step */ n * k}, + Strided1DIndexer{/* size */ batch_nelems, + /* step */ k * m}, + Strided1DIndexer{/* size */ batch_nelems, + /* step */ n * m}); + + return gemm_detail::_gemm_tree_nm_step< + lhsTy, rhsTy, resTy, BatchDimsIndexerT, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, wi_delta_n, + wi_delta_m>(exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, + wg_delta_n, wg_delta_m, wi_delta_k, batch_indexer, + lhs_indexer, rhs_indexer, res_indexer, depends); + } + else { + using ReductionOpT = + typename std::conditional, + sycl::logical_or, + sycl::plus>::type; + static constexpr resTy identity_val = + sycl::known_identity::value; + std::size_t iter_nelems = batch_nelems * n * m; + std::size_t reduction_nelems = (k + wi_delta_k - 1) / wi_delta_k; + + // more than one work-group is needed, requires a temporary + // delta_k * n_wi elements processed along k, so if more to + // process use multiple + const auto &sg_sizes = + dev.get_info(); + std::size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes); + + static constexpr std::size_t preferred_reductions_per_wi = 4; + std::size_t reductions_per_wi(preferred_reductions_per_wi); + + std::size_t reduction_groups = + (reduction_nelems + preferred_reductions_per_wi * wg - 1) / + (preferred_reductions_per_wi * wg); + + std::size_t max_wg = reduction_detail::get_work_group_size(dev); + + if (reduction_nelems <= preferred_reductions_per_wi * max_wg) { + auto tmp_owner = + dpctl::tensor::alloc_utils::smart_malloc_device( + iter_nelems * reduction_nelems, exec_q); + + resTy *tmp = tmp_owner.get(); + + using 
OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + static constexpr OuterInnerDimsIndexerT lhs_indexer{}; + static constexpr OuterInnerDimsIndexerT rhs_indexer{}; + static constexpr OuterInnerDimsIndexerT tmp_indexer{}; + + using dpctl::tensor::offset_utils::Strided1DIndexer; + using dpctl::tensor::offset_utils::ThreeOffsets_CombinedIndexer; + using BatchDimsIndexerT = + ThreeOffsets_CombinedIndexer; + + const BatchDimsIndexerT batch_indexer( + Strided1DIndexer{/* size */ batch_nelems, + /* step */ n * k}, + Strided1DIndexer{/* size */ batch_nelems, + /* step */ k * m}, + Strided1DIndexer{/* size */ batch_nelems, + /* step */ n * m}); + + sycl::event gemm_ev = gemm_detail::_gemm_tree_nm_step< + lhsTy, rhsTy, resTy, BatchDimsIndexerT, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, wi_delta_n, + wi_delta_m>(exec_q, lhs_tp, rhs_tp, tmp, batch_nelems, n, k, m, + wg_delta_n, wg_delta_m, wi_delta_k, batch_indexer, + lhs_indexer, rhs_indexer, tmp_indexer, depends); + + sycl::event red_ev = + single_reduction_for_gemm_contig( + exec_q, tmp, res_tp, identity_val, iter_nelems, + reduction_nelems, reduction_groups, wg, max_wg, + preferred_reductions_per_wi, reductions_per_wi, {gemm_ev}); + + sycl::event cleanup_host_task_event = + dpctl::tensor::alloc_utils::async_smart_free(exec_q, {red_ev}, + tmp_owner); + + return cleanup_host_task_event; + } + else { + assert(reduction_groups > 1); + + const std::size_t tmp_alloc_size = + iter_nelems * (/* temp */ reduction_nelems + + /* first reduction temp */ reduction_groups); + + auto tmp_owner = + dpctl::tensor::alloc_utils::smart_malloc_device( + tmp_alloc_size, exec_q); + + resTy *partially_reduced_tmp = tmp_owner.get(); + resTy *partially_reduced_tmp2 = + partially_reduced_tmp + reduction_nelems * iter_nelems; + + using OuterInnerDimsIndexerT = + dpctl::tensor::offset_utils::NoOpIndexer; + static constexpr OuterInnerDimsIndexerT lhs_indexer{}; + static constexpr OuterInnerDimsIndexerT 
rhs_indexer{}; + static constexpr OuterInnerDimsIndexerT tmp_indexer{}; + + using dpctl::tensor::offset_utils::Strided1DIndexer; + using dpctl::tensor::offset_utils::ThreeOffsets_CombinedIndexer; + using BatchDimsIndexerT = + ThreeOffsets_CombinedIndexer; + + const BatchDimsIndexerT batch_indexer( + Strided1DIndexer{/* size */ batch_nelems, + /* step */ n * k}, + Strided1DIndexer{/* size */ batch_nelems, + /* step */ k * m}, + Strided1DIndexer{/* size */ batch_nelems, + /* step */ n * m}); + + sycl::event gemm_ev = gemm_detail::_gemm_tree_nm_step< + lhsTy, rhsTy, resTy, BatchDimsIndexerT, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, wi_delta_n, + wi_delta_m>(exec_q, lhs_tp, rhs_tp, partially_reduced_tmp, + batch_nelems, n, k, m, wg_delta_n, wg_delta_m, + wi_delta_k, batch_indexer, lhs_indexer, rhs_indexer, + tmp_indexer, depends); + + sycl::event red_ev = + tree_reduction_for_gemm_contig( + exec_q, partially_reduced_tmp, partially_reduced_tmp2, + res_tp, identity_val, iter_nelems, reduction_nelems, + reduction_groups, wg, max_wg, preferred_reductions_per_wi, + reductions_per_wi, {gemm_ev}); + + sycl::event cleanup_host_task_event = + dpctl::tensor::alloc_utils::async_smart_free(exec_q, {red_ev}, + tmp_owner); + + return cleanup_host_task_event; + } + } +} + +template +sycl::event gemm_nm_impl(sycl::queue &exec_q, + const lhsTy *lhs_tp, + const rhsTy *rhs_tp, + resTy *res_tp, + std::size_t n, + std::size_t k, + std::size_t m, + int inner_nd, + int lhs_outer_nd, + const ssize_t *lhs_shape_strides, + int rhs_outer_nd, + const ssize_t *rhs_shape_strides, + int res_outer_nd, + const ssize_t *res_shape_strides, + std::vector const &depends = {}) +{ + using OuterInnerDimsIndexerT = dpctl::tensor::offset_utils::StridedIndexer; + const OuterInnerDimsIndexerT lhs_indexer(inner_nd + lhs_outer_nd, 0, + lhs_shape_strides); + const OuterInnerDimsIndexerT rhs_indexer(inner_nd + rhs_outer_nd, 0, + rhs_shape_strides); + const OuterInnerDimsIndexerT 
res_indexer(res_outer_nd, 0, + res_shape_strides); + + using BatchDimsIndexerT = + dpctl::tensor::offset_utils::ThreeZeroOffsets_Indexer; + static constexpr BatchDimsIndexerT batch_indexer{}; + + static constexpr std::size_t single_batch_nelems = 1; + + sycl::event gemm_ev = gemm_detail::_gemm_batch_nm_impl< + lhsTy, rhsTy, resTy, BatchDimsIndexerT, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT>( + exec_q, lhs_tp, rhs_tp, res_tp, single_batch_nelems, n, k, m, + batch_indexer, lhs_indexer, rhs_indexer, res_indexer, depends); + + return gemm_ev; +} + +template +sycl::event + gemm_batch_nm_contig_impl(sycl::queue &exec_q, + const lhsTy *lhs_tp, + const rhsTy *rhs_tp, + resTy *res_tp, + std::size_t batch_nelems, + std::size_t n, + std::size_t k, + std::size_t m, + std::vector const &depends = {}) +{ + using OuterInnerDimsIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + static constexpr OuterInnerDimsIndexerT lhs_indexer{}; + static constexpr OuterInnerDimsIndexerT rhs_indexer{}; + static constexpr OuterInnerDimsIndexerT res_indexer{}; + + static constexpr std::size_t single_batch_nelems = 1; + if (batch_nelems == single_batch_nelems) { + using BatchDimsIndexerT = + dpctl::tensor::offset_utils::ThreeZeroOffsets_Indexer; + static constexpr BatchDimsIndexerT batch_indexer{}; + + sycl::event gemm_ev = gemm_detail::_gemm_batch_nm_impl< + lhsTy, rhsTy, resTy, BatchDimsIndexerT, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT>( + exec_q, lhs_tp, rhs_tp, res_tp, single_batch_nelems, n, k, m, + batch_indexer, lhs_indexer, rhs_indexer, res_indexer, depends); + + return gemm_ev; + } + else { + using dpctl::tensor::offset_utils::Strided1DIndexer; + using dpctl::tensor::offset_utils::ThreeOffsets_CombinedIndexer; + using BatchDimsIndexerT = + ThreeOffsets_CombinedIndexer; + + using dpctl::tensor::offset_utils::Strided1DIndexer; + + const BatchDimsIndexerT batch_indexer( + Strided1DIndexer{/* size */ batch_nelems, + /* step 
*/ n * k}, + Strided1DIndexer{/* size */ batch_nelems, + /* step */ k * m}, + Strided1DIndexer{/* size */ batch_nelems, + /* step */ n * m}); + + sycl::event gemm_ev = gemm_detail::_gemm_batch_nm_impl< + lhsTy, rhsTy, resTy, BatchDimsIndexerT, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT>( + exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, + batch_indexer, lhs_indexer, rhs_indexer, res_indexer, depends); + + return gemm_ev; + } +} + +template +sycl::event + gemm_batch_contig_tree_impl(sycl::queue &exec_q, + const char *lhs_cp, + const char *rhs_cp, + char *res_cp, + std::size_t batch_nelems, + std::size_t n, + std::size_t k, + std::size_t m, + ssize_t lhs_batch_offset, + ssize_t rhs_batch_offset, + ssize_t res_batch_offset, + std::vector const &depends = {}) +{ + const lhsTy *lhs_tp = + reinterpret_cast(lhs_cp) + lhs_batch_offset; + const rhsTy *rhs_tp = + reinterpret_cast(rhs_cp) + rhs_batch_offset; + resTy *res_tp = reinterpret_cast(res_cp) + res_batch_offset; + + const std::size_t min_nm = std::min(n, m); + const std::size_t max_nm = std::max(n, m); + + if (min_nm > 0 && (max_nm >= ((64 * 1024) / min_nm))) { + return gemm_batch_nm_contig_impl( + exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, depends); + } + + if (k == 0) { + sycl::event gemm_batch_no_reduction_ev = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + cgh.fill(res_tp, resTy(0), n * m * batch_nelems); + }); + return gemm_batch_no_reduction_ev; + } + + if (max_nm < 64) { + using dpctl::tensor::type_utils::is_complex; + if constexpr (!is_complex::value) { + if (m < 4) { + return gemm_batch_contig_tree_k_impl( + exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, + depends); + } + else { + return gemm_batch_contig_tree_k_impl( + exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, + depends); + } + } + else { + return gemm_batch_contig_tree_k_impl( + exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, depends); + } + } + else { // m 
> 1, n > k or m > k + using dpctl::tensor::type_utils::is_complex; + if constexpr (!is_complex::value) { + return gemm_batch_contig_tree_nm_impl( + exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, depends); + } + else { // m > 1, n > k or m > k, resTy complex + return gemm_batch_contig_tree_nm_impl( + exec_q, lhs_tp, rhs_tp, res_tp, batch_nelems, n, k, m, depends); + } + } +} + +// Gemm tree non-batched + +template +class gemm_tree_nm_krn; + +template +class gemm_tree_k_krn; + +template +sycl::event gemm_tree_k_impl(sycl::queue &exec_q, + const lhsTy *lhs_tp, + const rhsTy *rhs_tp, + resTy *res_tp, + std::size_t n, + std::size_t k, + std::size_t m, + int inner_nd, + int lhs_outer_nd, + const ssize_t *lhs_outer_inner_shapes_strides, + int rhs_outer_nd, + const ssize_t *rhs_outer_inner_shapes_strides, + int res_nd, + const ssize_t *res_shapes_strides, + const std::vector &depends) +{ + std::size_t delta_k(4); + std::size_t n_wi(64); + std::size_t delta_n(32); + + const sycl::device &dev = exec_q.get_device(); + const std::size_t local_mem_size = + dev.get_info(); + const std::size_t reserved_slm_size = 512; + + gemm_detail::scale_gemm_k_parameters( + local_mem_size, reserved_slm_size, delta_k, + n_wi, // modified by reference + delta_n // modified by reference + ); + + using BatchIndexerT = dpctl::tensor::offset_utils::ThreeZeroOffsets_Indexer; + static constexpr BatchIndexerT batch_indexer{}; + + static constexpr std::size_t single_batch_nelems = 1; + + using OuterInnerDimsIndexerT = dpctl::tensor::offset_utils::StridedIndexer; + const OuterInnerDimsIndexerT lhs_indexer(inner_nd + lhs_outer_nd, 0, + lhs_outer_inner_shapes_strides); + const OuterInnerDimsIndexerT rhs_indexer(inner_nd + rhs_outer_nd, 0, + rhs_outer_inner_shapes_strides); + + sycl::event gemm_ev; + if (k <= (delta_k * n_wi)) { + const OuterInnerDimsIndexerT res_indexer(res_nd, 0, res_shapes_strides); + + return gemm_detail::_gemm_tree_k_step< + lhsTy, rhsTy, resTy, BatchIndexerT, 
OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, m_groups>( + exec_q, lhs_tp, rhs_tp, res_tp, single_batch_nelems, n, k, m, + delta_n, n_wi, delta_k, batch_indexer, lhs_indexer, rhs_indexer, + res_indexer, depends); + } + else { + using ReductionOpT = + typename std::conditional, + sycl::logical_or, + sycl::plus>::type; + static constexpr resTy identity_val = + sycl::known_identity::value; + + std::size_t iter_nelems = n * m; + std::size_t reduction_nelems = + (k + delta_k * n_wi - 1) / (delta_k * n_wi); + + // more than one work-groups is needed, requires a temporary + // delta_k * n_wi elements processed along k, so if more to + // process use multiple + const auto &sg_sizes = + dev.get_info(); + std::size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes); + + static constexpr std::size_t preferred_reductions_per_wi = 8; + std::size_t reductions_per_wi(preferred_reductions_per_wi); + + std::size_t reduction_groups = + (reduction_nelems + preferred_reductions_per_wi * wg - 1) / + (preferred_reductions_per_wi * wg); + + std::size_t max_wg = reduction_detail::get_work_group_size(dev); + + if (reduction_nelems <= preferred_reductions_per_wi * max_wg) { + + auto tmp_owner = + dpctl::tensor::alloc_utils::smart_malloc_device( + iter_nelems * reduction_nelems, exec_q); + resTy *tmp = tmp_owner.get(); + + using ResIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + static constexpr ResIndexerT res_indexer{}; + + sycl::event gemm_ev = gemm_detail::_gemm_tree_k_step< + lhsTy, rhsTy, resTy, BatchIndexerT, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, ResIndexerT, m_groups>( + exec_q, lhs_tp, rhs_tp, tmp, single_batch_nelems, n, k, m, + delta_n, n_wi, delta_k, batch_indexer, lhs_indexer, rhs_indexer, + res_indexer, depends); + + sycl::event red_ev = single_reduction_for_gemm( + exec_q, tmp, res_tp, identity_val, iter_nelems, + reduction_nelems, reduction_groups, wg, max_wg, + preferred_reductions_per_wi, reductions_per_wi, res_nd, 0, 
+ res_shapes_strides, {gemm_ev}); + + sycl::event cleanup_host_task_event = + dpctl::tensor::alloc_utils::async_smart_free(exec_q, {red_ev}, + tmp_owner); + return cleanup_host_task_event; + } + else { + assert(reduction_groups > 1); + + const std::size_t tmp_alloc_size = + iter_nelems * (/* temp */ reduction_nelems + + /* first reduction temp */ reduction_groups); + + auto tmp_owner = + dpctl::tensor::alloc_utils::smart_malloc_device( + tmp_alloc_size, exec_q); + + resTy *partially_reduced_tmp = tmp_owner.get(); + resTy *partially_reduced_tmp2 = + partially_reduced_tmp + reduction_nelems * iter_nelems; + + using ResIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + static constexpr ResIndexerT res_indexer{}; + + sycl::event gemm_ev = gemm_detail::_gemm_tree_k_step< + lhsTy, rhsTy, resTy, BatchIndexerT, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, ResIndexerT, m_groups>( + exec_q, lhs_tp, rhs_tp, partially_reduced_tmp, + single_batch_nelems, n, k, m, delta_n, n_wi, delta_k, + batch_indexer, lhs_indexer, rhs_indexer, res_indexer, depends); + + // tree_reduction_for_gemm returns sycl::event for reduction + sycl::event red_ev = tree_reduction_for_gemm( + exec_q, partially_reduced_tmp, partially_reduced_tmp2, res_tp, + identity_val, iter_nelems, reduction_nelems, reduction_groups, + wg, max_wg, preferred_reductions_per_wi, reductions_per_wi, + res_nd, 0, res_shapes_strides, {gemm_ev}); + + sycl::event cleanup_host_task_event = + dpctl::tensor::alloc_utils::async_smart_free(exec_q, {red_ev}, + tmp_owner); + + return cleanup_host_task_event; + } + } +} + +template +sycl::event gemm_tree_nm_impl(sycl::queue &exec_q, + const lhsTy *lhs_tp, + const rhsTy *rhs_tp, + resTy *res_tp, + std::size_t n, + std::size_t k, + std::size_t m, + int inner_nd, + int lhs_outer_nd, + const ssize_t *lhs_outer_inner_shapes_strides, + int rhs_outer_nd, + const ssize_t *rhs_outer_inner_shapes_strides, + int res_nd, + const ssize_t *res_shapes_strides, + const std::vector &depends) +{ + 
static constexpr int wi_delta_n = 2; + std::size_t wg_delta_n(16); // rows of A processed in WG + std::size_t wg_delta_m(16); // rows of B processed in WG + std::size_t wi_delta_k(64); // Elements in K dimension processed by WI + + const sycl::device &dev = exec_q.get_device(); + const std::size_t local_mem_size = + dev.get_info(); + const std::size_t reserved_slm_size = 512; + + gemm_detail::scale_gemm_nm_parameters( + local_mem_size, reserved_slm_size, wi_delta_n, + wi_delta_k, // modified by reference + wg_delta_n, // modified by reference + wg_delta_m // modified by reference + ); + + using BatchIndexerT = dpctl::tensor::offset_utils::ThreeZeroOffsets_Indexer; + static constexpr BatchIndexerT batch_indexer{}; + + static constexpr std::size_t single_batch_nelems = 1; + + using OuterInnerDimsIndexerT = dpctl::tensor::offset_utils::StridedIndexer; + const OuterInnerDimsIndexerT lhs_indexer(inner_nd + lhs_outer_nd, 0, + lhs_outer_inner_shapes_strides); + const OuterInnerDimsIndexerT rhs_indexer(inner_nd + rhs_outer_nd, 0, + rhs_outer_inner_shapes_strides); + + // each group processes delta_k items in a column, + // so no need to allocate temp memory if one group needed + if (k <= wi_delta_k) { + const OuterInnerDimsIndexerT res_indexer(res_nd, 0, res_shapes_strides); + + return gemm_detail::_gemm_tree_nm_step< + lhsTy, rhsTy, resTy, BatchIndexerT, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, wi_delta_n, + wi_delta_m>(exec_q, lhs_tp, rhs_tp, res_tp, single_batch_nelems, n, + k, m, wg_delta_n, wg_delta_m, wi_delta_k, batch_indexer, + lhs_indexer, rhs_indexer, res_indexer, depends); + } + else { + using ReductionOpT = + typename std::conditional, + sycl::logical_or, + sycl::plus>::type; + static constexpr resTy identity_val = + sycl::known_identity::value; + + std::size_t iter_nelems = n * m; + std::size_t reduction_nelems = (k + wi_delta_k - 1) / wi_delta_k; + + // more than one work-groups is needed, requires a temporary + // wi_delta_k 
elements processed along k, so if more to + // process use multiple + const auto &sg_sizes = + dev.get_info(); + std::size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes); + + static constexpr std::size_t preferred_reductions_per_wi = 8; + std::size_t reductions_per_wi(preferred_reductions_per_wi); + + std::size_t reduction_groups = + (reduction_nelems + preferred_reductions_per_wi * wg - 1) / + (preferred_reductions_per_wi * wg); + + std::size_t max_wg = reduction_detail::get_work_group_size(dev); + + if (reduction_nelems <= preferred_reductions_per_wi * max_wg) { + auto tmp_owner = + dpctl::tensor::alloc_utils::smart_malloc_device( + iter_nelems * reduction_nelems, exec_q); + resTy *tmp = tmp_owner.get(); + + using ResIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + static constexpr ResIndexerT res_indexer{}; + + sycl::event gemm_ev = gemm_detail::_gemm_tree_nm_step< + lhsTy, rhsTy, resTy, BatchIndexerT, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, ResIndexerT, wi_delta_n, wi_delta_m>( + exec_q, lhs_tp, rhs_tp, tmp, single_batch_nelems, n, k, m, + wg_delta_n, wg_delta_m, wi_delta_k, batch_indexer, lhs_indexer, + rhs_indexer, res_indexer, depends); + + sycl::event red_ev = single_reduction_for_gemm( + exec_q, tmp, res_tp, identity_val, iter_nelems, + reduction_nelems, reduction_groups, wg, max_wg, + preferred_reductions_per_wi, reductions_per_wi, res_nd, 0, + res_shapes_strides, {gemm_ev}); + + sycl::event cleanup_host_task_event = + dpctl::tensor::alloc_utils::async_smart_free(exec_q, {red_ev}, + tmp_owner); + + return cleanup_host_task_event; + } + else { + assert(reduction_groups > 1); + + const std::size_t tmp_alloc_size = + iter_nelems * (/* temp */ reduction_nelems + + /* first reduction temp */ reduction_groups); + auto tmp_owner = + dpctl::tensor::alloc_utils::smart_malloc_device( + tmp_alloc_size, exec_q); + + resTy *partially_reduced_tmp = tmp_owner.get(); + resTy *partially_reduced_tmp2 = + partially_reduced_tmp + reduction_nelems 
* iter_nelems; + + using ResIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + static constexpr ResIndexerT res_indexer{}; + + sycl::event gemm_ev = gemm_detail::_gemm_tree_nm_step< + lhsTy, rhsTy, resTy, BatchIndexerT, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, ResIndexerT, wi_delta_n, wi_delta_m>( + exec_q, lhs_tp, rhs_tp, partially_reduced_tmp, + single_batch_nelems, n, k, m, wg_delta_n, wg_delta_m, + wi_delta_k, batch_indexer, lhs_indexer, rhs_indexer, + res_indexer, depends); + + sycl::event red_ev = tree_reduction_for_gemm( + exec_q, partially_reduced_tmp, partially_reduced_tmp2, res_tp, + identity_val, iter_nelems, reduction_nelems, reduction_groups, + wg, max_wg, preferred_reductions_per_wi, reductions_per_wi, + res_nd, 0, res_shapes_strides, {gemm_ev}); + + sycl::event cleanup_host_task_event = + dpctl::tensor::alloc_utils::async_smart_free(exec_q, {red_ev}, + tmp_owner); + + return cleanup_host_task_event; + } + } +} + +template +class gemm_tree_empty_krn; + +template +sycl::event gemm_tree_impl(sycl::queue &exec_q, + const char *lhs_cp, + const char *rhs_cp, + char *res_cp, + std::size_t n, + std::size_t k, + std::size_t m, + int inner_nd, + int lhs_outer_nd, + const ssize_t *lhs_outer_inner_shapes_strides, + int rhs_outer_nd, + const ssize_t *rhs_outer_inner_shapes_strides, + int res_nd, + const ssize_t *res_shapes_strides, + std::vector const &depends = {}) +{ + const lhsTy *lhs_tp = reinterpret_cast(lhs_cp); + const rhsTy *rhs_tp = reinterpret_cast(rhs_cp); + resTy *res_tp = reinterpret_cast(res_cp); + + const std::size_t min_nm = std::min(n, m); + const std::size_t max_nm = std::max(n, m); + + if (min_nm > 0 && (max_nm >= ((64 * 1024) / min_nm))) { + return gemm_nm_impl( + exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, inner_nd, lhs_outer_nd, + lhs_outer_inner_shapes_strides, rhs_outer_nd, + rhs_outer_inner_shapes_strides, res_nd, res_shapes_strides, + depends); + } + + if (k == 0) { + sycl::event gemm_no_reduction_ev = + 
exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + using IndexerT = dpctl::tensor::offset_utils::StridedIndexer; + const IndexerT res_indexer(res_nd, 0, res_shapes_strides); + using InitKernelName = + class gemm_tree_empty_krn; + cgh.parallel_for( + sycl::range<1>(n * m), [=](sycl::id<1> id) { + auto res_offset = res_indexer(id[0]); + res_tp[res_offset] = resTy(0); + }); + }); + return gemm_no_reduction_ev; + } + + if (max_nm < 64) { + using dpctl::tensor::type_utils::is_complex; + if constexpr (!is_complex::value) { + if (m < 4) { + return gemm_tree_k_impl( + exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, inner_nd, + lhs_outer_nd, lhs_outer_inner_shapes_strides, rhs_outer_nd, + rhs_outer_inner_shapes_strides, res_nd, res_shapes_strides, + depends); + } + else { + return gemm_tree_k_impl( + exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, inner_nd, + lhs_outer_nd, lhs_outer_inner_shapes_strides, rhs_outer_nd, + rhs_outer_inner_shapes_strides, res_nd, res_shapes_strides, + depends); + } + } + else { + return gemm_tree_k_impl( + exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, inner_nd, lhs_outer_nd, + lhs_outer_inner_shapes_strides, rhs_outer_nd, + rhs_outer_inner_shapes_strides, res_nd, res_shapes_strides, + depends); + } + } + else { // m > 1, n > k or m > k + using dpctl::tensor::type_utils::is_complex; + if constexpr (!is_complex::value) { + return gemm_tree_nm_impl( + exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, inner_nd, lhs_outer_nd, + lhs_outer_inner_shapes_strides, rhs_outer_nd, + rhs_outer_inner_shapes_strides, res_nd, res_shapes_strides, + depends); + } + else { + return gemm_tree_nm_impl( + exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, inner_nd, lhs_outer_nd, + lhs_outer_inner_shapes_strides, rhs_outer_nd, + rhs_outer_inner_shapes_strides, res_nd, res_shapes_strides, + depends); + } + } +} + +template +sycl::event gemm_contig_tree_k_impl(sycl::queue &exec_q, + const lhsTy *lhs_tp, + const rhsTy *rhs_tp, + resTy *res_tp, + std::size_t n, + std::size_t k, + 
std::size_t m, + std::vector const &depends) +{ + std::size_t delta_k(4); + std::size_t n_wi(64); + std::size_t delta_n(32); + + const sycl::device &dev = exec_q.get_device(); + const std::size_t local_mem_size = + dev.get_info(); + const std::size_t reserved_slm_size = 512; + + gemm_detail::scale_gemm_k_parameters( + local_mem_size, reserved_slm_size, delta_k, + n_wi, // modified by reference + delta_n // modified by reference + ); + + using OuterInnerDimsIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + static constexpr OuterInnerDimsIndexerT lhs_indexer{}; + static constexpr OuterInnerDimsIndexerT rhs_indexer{}; + static constexpr OuterInnerDimsIndexerT res_indexer{}; + + using BatchIndexerT = dpctl::tensor::offset_utils::ThreeZeroOffsets_Indexer; + static constexpr BatchIndexerT batch_indexer{}; + + static constexpr std::size_t single_batch_nelems = 1; + + sycl::event gemm_ev; + if (k <= (delta_k * n_wi)) { + return gemm_detail::_gemm_tree_k_step< + lhsTy, rhsTy, resTy, BatchIndexerT, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, m_groups>( + exec_q, lhs_tp, rhs_tp, res_tp, single_batch_nelems, n, k, m, + delta_n, n_wi, delta_k, batch_indexer, lhs_indexer, rhs_indexer, + res_indexer, depends); + } + else { + using ReductionOpT = + typename std::conditional, + sycl::logical_or, + sycl::plus>::type; + static constexpr resTy identity_val = + sycl::known_identity::value; + + std::size_t iter_nelems = n * m; + std::size_t reduction_nelems = + (k + delta_k * n_wi - 1) / (delta_k * n_wi); + + // more than one work-groups is needed, requires a + // temporary delta_k * n_wi elements processed along k, + // so if more to process use multiple + const auto &sg_sizes = + dev.get_info(); + std::size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes); + + static constexpr std::size_t preferred_reductions_per_wi = 8; + std::size_t reductions_per_wi(preferred_reductions_per_wi); + + std::size_t reduction_groups = + (reduction_nelems + 
preferred_reductions_per_wi * wg - 1) / + (preferred_reductions_per_wi * wg); + + std::size_t max_wg = reduction_detail::get_work_group_size(dev); + + if (reduction_nelems <= preferred_reductions_per_wi * max_wg) { + auto tmp_owner = + dpctl::tensor::alloc_utils::smart_malloc_device( + iter_nelems * reduction_nelems, exec_q); + resTy *tmp = tmp_owner.get(); + + sycl::event gemm_ev = gemm_detail::_gemm_tree_k_step< + lhsTy, rhsTy, resTy, BatchIndexerT, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, m_groups>( + exec_q, lhs_tp, rhs_tp, tmp, single_batch_nelems, n, k, m, + delta_n, n_wi, delta_k, batch_indexer, lhs_indexer, rhs_indexer, + res_indexer, depends); + + sycl::event red_ev = + single_reduction_for_gemm_contig( + exec_q, tmp, res_tp, identity_val, iter_nelems, + reduction_nelems, reduction_groups, wg, max_wg, + preferred_reductions_per_wi, reductions_per_wi, {gemm_ev}); + + sycl::event cleanup_host_task_event = + dpctl::tensor::alloc_utils::async_smart_free(exec_q, {red_ev}, + tmp_owner); + return cleanup_host_task_event; + } + else { + assert(reduction_groups > 1); + + const std::size_t tmp_alloc_size = + iter_nelems * (/* temp */ reduction_nelems + + /* first reduction temp */ reduction_groups); + auto tmp_owner = + dpctl::tensor::alloc_utils::smart_malloc_device( + tmp_alloc_size, exec_q); + + resTy *partially_reduced_tmp = tmp_owner.get(); + resTy *partially_reduced_tmp2 = + partially_reduced_tmp + reduction_nelems * iter_nelems; + + sycl::event gemm_ev = gemm_detail::_gemm_tree_k_step< + lhsTy, rhsTy, resTy, BatchIndexerT, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, m_groups>( + exec_q, lhs_tp, rhs_tp, partially_reduced_tmp, + single_batch_nelems, n, k, m, delta_n, n_wi, delta_k, + batch_indexer, lhs_indexer, rhs_indexer, res_indexer, depends); + + // tree_reduction_for_gemm_contig returns sycl::event + // for reduction + sycl::event red_ev = + tree_reduction_for_gemm_contig( + exec_q, 
partially_reduced_tmp, partially_reduced_tmp2, + res_tp, identity_val, iter_nelems, reduction_nelems, + reduction_groups, wg, max_wg, preferred_reductions_per_wi, + reductions_per_wi, {gemm_ev}); + + sycl::event cleanup_host_task_event = + dpctl::tensor::alloc_utils::async_smart_free(exec_q, {red_ev}, + tmp_owner); + + return cleanup_host_task_event; + } + } +} + +template +sycl::event gemm_contig_tree_nm_impl(sycl::queue &exec_q, + const lhsTy *lhs_tp, + const rhsTy *rhs_tp, + resTy *res_tp, + std::size_t n, + std::size_t k, + std::size_t m, + std::vector const &depends) +{ + static constexpr int wi_delta_n = 2; + std::size_t wg_delta_n(16); // rows of A processed in WG + std::size_t wg_delta_m(16); // rows of B processed in WG + std::size_t wi_delta_k(64); // Elements in K dimension processed by WI + + const sycl::device &dev = exec_q.get_device(); + const std::size_t local_mem_size = + dev.get_info(); + const std::size_t reserved_slm_size = 512; + + gemm_detail::scale_gemm_nm_parameters( + local_mem_size, reserved_slm_size, wi_delta_n, + wi_delta_k, // modified by reference + wg_delta_n, // modified by reference + wg_delta_m // modified by reference + ); + + using OuterInnerDimsIndexerT = dpctl::tensor::offset_utils::NoOpIndexer; + static constexpr OuterInnerDimsIndexerT lhs_indexer{}; + static constexpr OuterInnerDimsIndexerT rhs_indexer{}; + static constexpr OuterInnerDimsIndexerT res_indexer{}; + + using BatchIndexerT = dpctl::tensor::offset_utils::ThreeZeroOffsets_Indexer; + static constexpr BatchIndexerT batch_indexer{}; + + static constexpr std::size_t single_batch_nelems = 1; + + // each group processes delta_k items in a column, + // so no need to allocate temp memory if one group needed + if (k <= wi_delta_k) { + + return gemm_detail::_gemm_tree_nm_step< + lhsTy, rhsTy, resTy, BatchIndexerT, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, wi_delta_n, + wi_delta_m>(exec_q, lhs_tp, rhs_tp, res_tp, single_batch_nelems, n, + k, m, 
wg_delta_n, wg_delta_m, wi_delta_k, batch_indexer, + lhs_indexer, rhs_indexer, res_indexer, depends); + } + else { + using ReductionOpT = + typename std::conditional, + sycl::logical_or, + sycl::plus>::type; + static constexpr resTy identity_val = + sycl::known_identity::value; + + std::size_t iter_nelems = n * m; + std::size_t reduction_nelems = (k + wi_delta_k - 1) / wi_delta_k; + + // more than one work-groups is needed, requires a temporary + // wi_delta_k elements processed along k, so if more to + // process use multiple + const auto &sg_sizes = + dev.get_info(); + std::size_t wg = choose_workgroup_size<4>(reduction_nelems, sg_sizes); + + static constexpr std::size_t preferred_reductions_per_wi = 8; + std::size_t reductions_per_wi(preferred_reductions_per_wi); + + std::size_t reduction_groups = + (reduction_nelems + preferred_reductions_per_wi * wg - 1) / + (preferred_reductions_per_wi * wg); + + std::size_t max_wg = reduction_detail::get_work_group_size(dev); + + if (reduction_nelems <= preferred_reductions_per_wi * max_wg) { + auto tmp_owner = + dpctl::tensor::alloc_utils::smart_malloc_device( + iter_nelems * reduction_nelems, exec_q); + resTy *tmp = tmp_owner.get(); + + sycl::event gemm_ev = gemm_detail::_gemm_tree_nm_step< + lhsTy, rhsTy, resTy, BatchIndexerT, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, wi_delta_n, + wi_delta_m>(exec_q, lhs_tp, rhs_tp, tmp, single_batch_nelems, n, + k, m, wg_delta_n, wg_delta_m, wi_delta_k, + batch_indexer, lhs_indexer, rhs_indexer, + res_indexer, depends); + + sycl::event red_ev = + single_reduction_for_gemm_contig( + exec_q, tmp, res_tp, identity_val, iter_nelems, + reduction_nelems, reduction_groups, wg, max_wg, + preferred_reductions_per_wi, reductions_per_wi, {gemm_ev}); + + sycl::event cleanup_host_task_event = + dpctl::tensor::alloc_utils::async_smart_free(exec_q, {red_ev}, + tmp_owner); + return cleanup_host_task_event; + } + else { + assert(reduction_groups > 1); + + const 
std::size_t tmp_alloc_size = + iter_nelems * (/* temp */ reduction_nelems + + /* first reduction temp */ reduction_groups); + + auto tmp_owner = + dpctl::tensor::alloc_utils::smart_malloc_device( + tmp_alloc_size, exec_q); + resTy *partially_reduced_tmp = tmp_owner.get(); + resTy *partially_reduced_tmp2 = + partially_reduced_tmp + reduction_nelems * iter_nelems; + + sycl::event gemm_ev = gemm_detail::_gemm_tree_nm_step< + lhsTy, rhsTy, resTy, BatchIndexerT, OuterInnerDimsIndexerT, + OuterInnerDimsIndexerT, OuterInnerDimsIndexerT, wi_delta_n, + wi_delta_m>(exec_q, lhs_tp, rhs_tp, partially_reduced_tmp, + single_batch_nelems, n, k, m, wg_delta_n, + wg_delta_m, wi_delta_k, batch_indexer, lhs_indexer, + rhs_indexer, res_indexer, depends); + + sycl::event red_ev = + tree_reduction_for_gemm_contig( + exec_q, partially_reduced_tmp, partially_reduced_tmp2, + res_tp, identity_val, iter_nelems, reduction_nelems, + reduction_groups, wg, max_wg, preferred_reductions_per_wi, + reductions_per_wi, {gemm_ev}); + + sycl::event cleanup_host_task_event = + dpctl::tensor::alloc_utils::async_smart_free(exec_q, {red_ev}, + tmp_owner); + + return cleanup_host_task_event; + } + } +} + +template +sycl::event gemm_contig_tree_impl(sycl::queue &exec_q, + const char *lhs_cp, + const char *rhs_cp, + char *res_cp, + std::size_t n, + std::size_t k, + std::size_t m, + std::vector const &depends = {}) +{ + const lhsTy *lhs_tp = reinterpret_cast(lhs_cp); + const rhsTy *rhs_tp = reinterpret_cast(rhs_cp); + resTy *res_tp = reinterpret_cast(res_cp); + + const std::size_t min_nm = std::min(n, m); + const std::size_t max_nm = std::max(n, m); + + if (min_nm > 0 && (max_nm >= ((64 * 1024) / min_nm))) { + static constexpr std::size_t single_batch_nelems = 1; + return gemm_batch_nm_contig_impl( + exec_q, lhs_tp, rhs_tp, res_tp, single_batch_nelems, n, k, m, + depends); + } + + if (k == 0) { + sycl::event gemm_no_reduction_ev = + exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + 
cgh.fill(res_tp, resTy(0), n * m); + }); + return gemm_no_reduction_ev; + } + + if (max_nm < 64) { + using dpctl::tensor::type_utils::is_complex; + if constexpr (!is_complex::value) { + if (m < 4) { + return gemm_contig_tree_k_impl( + exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, depends); + } + else { + return gemm_contig_tree_k_impl( + exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, depends); + } + } + else { + return gemm_contig_tree_k_impl( + exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, depends); + } + } + else { // m > 1, n > k or m > k + using dpctl::tensor::type_utils::is_complex; + if constexpr (!is_complex::value) { + return gemm_contig_tree_nm_impl( + exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, depends); + } + else { + return gemm_contig_tree_nm_impl( + exec_q, lhs_tp, rhs_tp, res_tp, n, k, m, depends); + } + } +} + +} // namespace dpctl::tensor::kernels diff --git a/dpctl_ext/tensor/libtensor/source/linalg_functions/dot.cpp b/dpctl_ext/tensor/libtensor/source/linalg_functions/dot.cpp new file mode 100644 index 00000000000..5851382f846 --- /dev/null +++ b/dpctl_ext/tensor/libtensor/source/linalg_functions/dot.cpp @@ -0,0 +1,834 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. 
+// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** +// +//===---------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_impl extensions +//===--------------------------------------------------------------------===// + +#include +#include +#include +#include +#include +#include +#include + +#include "dpnp4pybind11.hpp" +#include +#include +#include + +#include "dot.hpp" +#include "dot_atomic_support.hpp" +#include "dot_dispatch.hpp" +#include "elementwise_functions/elementwise_functions_type_utils.hpp" +#include "kernels/linalg_functions/dot_product.hpp" +#include "kernels/linalg_functions/gemm.hpp" +#include "reductions/reduction_atomic_support.hpp" +#include "simplify_iteration_space.hpp" +#include "utils/memory_overlap.hpp" +#include "utils/offset_utils.hpp" +#include "utils/output_validation.hpp" +#include "utils/sycl_alloc_utils.hpp" + +namespace dpctl::tensor::py_internal +{ + +namespace td_ns = dpctl::tensor::type_dispatch; + +static int dot_output_id_table[td_ns::num_types][td_ns::num_types]; + +using dpctl::tensor::kernels::dot_product_impl_fn_ptr_t; 
+static dot_product_impl_fn_ptr_t dot_product_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + +static dot_product_impl_fn_ptr_t + dot_product_temps_dispatch_table[td_ns::num_types][td_ns::num_types]; + +using dpctl::tensor::kernels::dot_product_contig_impl_fn_ptr_t; +static dot_product_contig_impl_fn_ptr_t + dot_product_contig_dispatch_table[td_ns::num_types][td_ns::num_types]; + +static dot_product_contig_impl_fn_ptr_t + dot_product_contig_temps_dispatch_table[td_ns::num_types][td_ns::num_types]; + +using dpctl::tensor::kernels::gemm_impl_fn_ptr_t; +static gemm_impl_fn_ptr_t gemm_atomic_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + +static gemm_impl_fn_ptr_t gemm_temps_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + +using dpctl::tensor::kernels::gemm_contig_impl_fn_ptr_t; +static gemm_contig_impl_fn_ptr_t + gemm_contig_atomic_dispatch_table[td_ns::num_types][td_ns::num_types]; + +static gemm_contig_impl_fn_ptr_t + gemm_contig_temps_dispatch_table[td_ns::num_types][td_ns::num_types]; + +using dpctl::tensor::kernels::gemm_batch_impl_fn_ptr_t; +static gemm_batch_impl_fn_ptr_t + gemm_batch_atomic_dispatch_table[td_ns::num_types][td_ns::num_types]; + +static gemm_batch_impl_fn_ptr_t + gemm_batch_temps_dispatch_table[td_ns::num_types][td_ns::num_types]; + +using dpctl::tensor::kernels::gemm_batch_contig_impl_fn_ptr_t; +static gemm_batch_contig_impl_fn_ptr_t + gemm_batch_contig_atomic_dispatch_table[td_ns::num_types][td_ns::num_types]; + +static gemm_batch_contig_impl_fn_ptr_t + gemm_batch_contig_temps_dispatch_table[td_ns::num_types][td_ns::num_types]; + +void init_dot_dispatch_tables(void) +{ + td_ns::DispatchTableBuilder dtb1; + dtb1.populate_dispatch_table(dot_output_id_table); + + td_ns::DispatchTableBuilder + dtb2; + dtb2.populate_dispatch_table(gemm_batch_atomic_dispatch_table); + + td_ns::DispatchTableBuilder + dtb3; + dtb3.populate_dispatch_table(gemm_batch_contig_atomic_dispatch_table); + + td_ns::DispatchTableBuilder + dtb4; + 
dtb4.populate_dispatch_table(gemm_atomic_dispatch_table); + + td_ns::DispatchTableBuilder + dtb5; + dtb5.populate_dispatch_table(gemm_contig_atomic_dispatch_table); + + td_ns::DispatchTableBuilder + dtb6; + dtb6.populate_dispatch_table(gemm_batch_temps_dispatch_table); + + td_ns::DispatchTableBuilder + dtb7; + dtb7.populate_dispatch_table(gemm_batch_contig_temps_dispatch_table); + + td_ns::DispatchTableBuilder + dtb8; + dtb8.populate_dispatch_table(gemm_temps_dispatch_table); + + td_ns::DispatchTableBuilder + dtb9; + dtb9.populate_dispatch_table(gemm_contig_temps_dispatch_table); + + td_ns::DispatchTableBuilder + dtb10; + dtb10.populate_dispatch_table(dot_product_dispatch_table); + + td_ns::DispatchTableBuilder + dtb11; + dtb11.populate_dispatch_table(dot_product_temps_dispatch_table); + + td_ns::DispatchTableBuilder + dtb12; + dtb12.populate_dispatch_table(dot_product_contig_dispatch_table); + + td_ns::DispatchTableBuilder + dtb13; + dtb13.populate_dispatch_table(dot_product_contig_temps_dispatch_table); +} + +using atomic_support::atomic_support_fn_ptr_t; +static atomic_support_fn_ptr_t dot_atomic_support_vector[td_ns::num_types]; + +void init_dot_atomic_support_vector(void) +{ + + using atomic_support::DotAtomicSupportFactory; + td_ns::DispatchVectorBuilder + dvb; + dvb.populate_dispatch_vector(dot_atomic_support_vector); +} + +std::pair + py_dot(const dpctl::tensor::usm_ndarray &x1, + const dpctl::tensor::usm_ndarray &x2, + int batch_dims, + int x1_outer_dims, + int x2_outer_dims, + int inner_dims, + const dpctl::tensor::usm_ndarray &dst, + sycl::queue &exec_q, + const std::vector &depends) +{ + if (!dpctl::utils::queues_are_compatible(exec_q, {x1, x2, dst})) { + throw py::value_error( + "Execution queue is not compatible with allocation queues"); + } + + dpctl::tensor::validation::CheckWritable::throw_if_not_writable(dst); + + if (inner_dims == 0) { + throw py::value_error("No inner dimension for dot"); + } + + int x1_nd = x1.get_ndim(); + int x2_nd = 
x2.get_ndim(); + if (x1_nd != (batch_dims + x1_outer_dims + inner_dims) || + x2_nd != (batch_dims + x2_outer_dims + inner_dims)) + { + throw py::value_error("Input arrays do not have dimensions consistent " + "with input dimensions"); + } + + int dst_nd = dst.get_ndim(); + if (dst_nd != (batch_dims + x1_outer_dims + x2_outer_dims)) { + throw py::value_error("Destination array rank does not match input " + "array rank and number of input dimensions"); + } + + const py::ssize_t *x1_shape_ptr = x1.get_shape_raw(); + const py::ssize_t *x2_shape_ptr = x2.get_shape_raw(); + const py::ssize_t *dst_shape_ptr = dst.get_shape_raw(); + + bool same_shapes = true; + std::size_t batches(1); + for (int i = 0; same_shapes && (i < batch_dims); ++i) { + same_shapes = same_shapes && (x1_shape_ptr[i] == dst_shape_ptr[i]) && + (x2_shape_ptr[i] == dst_shape_ptr[i]); + batches *= x1_shape_ptr[i]; + } + std::size_t x1_outer_nelems(1); + for (int i = batch_dims; same_shapes && (i < (batch_dims + x1_outer_dims)); + ++i) { + same_shapes = same_shapes && (x1_shape_ptr[i] == dst_shape_ptr[i]); + x1_outer_nelems *= x1_shape_ptr[i]; + } + std::size_t inner_nelems(1); + for (int i = batch_dims; i < (batch_dims + inner_dims); ++i) { + auto x1_shape_idx = x1_outer_dims + i; + same_shapes = + same_shapes && (x1_shape_ptr[x1_shape_idx] == x2_shape_ptr[i]); + inner_nelems *= x1_shape_ptr[x1_shape_idx]; + } + std::size_t x2_outer_nelems(1); + for (int i = 0; same_shapes && (i < x2_outer_dims); ++i) { + auto x2_shape_idx = batch_dims + inner_dims + i; + same_shapes = + same_shapes && (x2_shape_ptr[x2_shape_idx] == + dst_shape_ptr[batch_dims + x1_outer_dims + i]); + x2_outer_nelems *= x2_shape_ptr[x2_shape_idx]; + } + if (!same_shapes) { + throw py::value_error("Input arrays to tensor dot product do not have " + "appropriate shapes"); + } + + std::size_t dst_nelems = batches * x1_outer_nelems * x2_outer_nelems; + if (dst_nelems == 0) { + return std::make_pair(sycl::event(), sycl::event()); + } + + if 
(static_cast(dst.get_size()) != dst_nelems) { + throw py::value_error("dst shape and size mismatch"); + } + + dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(dst, dst_nelems); + + auto const &overlap = dpctl::tensor::overlap::MemoryOverlap(); + // check that dst does not intersect with x1 or x2 + if (overlap(dst, x1) || overlap(dst, x2)) { + throw py::value_error("Result array overlaps with inputs"); + } + + int x1_typenum = x1.get_typenum(); + int x2_typenum = x2.get_typenum(); + int dst_typenum = dst.get_typenum(); + + auto const &array_types = td_ns::usm_ndarray_types(); + int x1_typeid = array_types.typenum_to_lookup_id(x1_typenum); + int x2_typeid = array_types.typenum_to_lookup_id(x2_typenum); + int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum); + + int output_typeid = dot_output_id_table[x1_typeid][x2_typeid]; + + if (output_typeid != dst_typeid) { + throw py::value_error( + "Result array has unexpected elemental data type."); + } + + void *data_ptr = dst.get_data(); + const auto &ctx = exec_q.get_context(); + auto usm_type = sycl::get_pointer_type(data_ptr, ctx); + bool supports_atomics = + dot_atomic_support_vector[output_typeid](exec_q, usm_type); + + const char *x1_data = x1.get_data(); + const char *x2_data = x2.get_data(); + char *dst_data = dst.get_data(); + + const auto &x1_shape_vec = x1.get_shape_vector(); + const auto &x1_strides_vec = x1.get_strides_vector(); + + const auto &x2_shape_vec = x2.get_shape_vector(); + const auto &x2_strides_vec = x2.get_strides_vector(); + + const auto &dst_shape_vec = dst.get_shape_vector(); + const auto &dst_strides_vec = dst.get_strides_vector(); + + bool is_x1_c_contig = x1.is_c_contiguous(); + bool is_x1_f_contig = x1.is_f_contiguous(); + bool is_x2_c_contig = x2.is_c_contiguous(); + bool is_x2_f_contig = x2.is_f_contiguous(); + bool is_dst_c_contig = dst.is_c_contiguous(); + + bool call_vecdot = ((x1_outer_dims == 0 && x1_outer_nelems == 1) && + (x2_outer_dims == 0 && x2_outer_nelems == 
1)); + + bool call_batched = (batch_dims != 0 || batches > 1); + std::vector host_task_events{}; + sycl::event dot_ev; + if (call_vecdot) { + if ((is_x1_c_contig && is_x2_c_contig && is_dst_c_contig) || + ((is_x1_f_contig && is_x2_f_contig) && !call_batched)) + { + dot_product_contig_impl_fn_ptr_t fn = nullptr; + if (supports_atomics) { + fn = dot_product_contig_dispatch_table[x1_typeid][x2_typeid]; + } + else { + fn = dot_product_contig_temps_dispatch_table[x1_typeid] + [x2_typeid]; + } + if (fn != nullptr) { + static constexpr py::ssize_t zero_offset = 0; + dot_ev = fn(exec_q, dst_nelems, inner_nelems, x1.get_data(), + x2.get_data(), dst.get_data(), + zero_offset, // lhs batch offset + zero_offset, // rhs batch offset + zero_offset, // res batch offset + zero_offset, // lhs reduction offset + zero_offset, // rhs reduction offset + depends); + return std::make_pair(dpctl::utils::keep_args_alive( + exec_q, {x1, x2, dst}, {dot_ev}), + dot_ev); + } + } + int inner_nd = inner_dims; + const py::ssize_t *inner_shape_ptr = x1_shape_ptr + batch_dims; + using shT = std::vector; + const shT inner_x1_strides(std::begin(x1_strides_vec) + batch_dims, + std::end(x1_strides_vec)); + const shT inner_x2_strides(std::begin(x2_strides_vec) + batch_dims, + std::end(x2_strides_vec)); + + shT simplified_inner_shape; + shT simplified_inner_x1_strides; + shT simplified_inner_x2_strides; + py::ssize_t inner_x1_offset(0); + py::ssize_t inner_x2_offset(0); + + simplify_iteration_space( + inner_nd, inner_shape_ptr, inner_x1_strides, inner_x2_strides, + // output + simplified_inner_shape, simplified_inner_x1_strides, + simplified_inner_x2_strides, inner_x1_offset, inner_x2_offset); + + const py::ssize_t *batch_shape_ptr = x1_shape_ptr; + + const shT batch_x1_strides(std::begin(x1_strides_vec), + std::begin(x1_strides_vec) + batch_dims); + const shT batch_x2_strides(std::begin(x2_strides_vec), + std::begin(x2_strides_vec) + batch_dims); + shT const &batch_dst_strides = dst_strides_vec; + + shT 
simplified_batch_shape; + shT simplified_batch_x1_strides; + shT simplified_batch_x2_strides; + shT simplified_batch_dst_strides; + py::ssize_t batch_x1_offset(0); + py::ssize_t batch_x2_offset(0); + py::ssize_t batch_dst_offset(0); + + if (batch_dims == 0) { + if (dst_nelems != 1) { + throw std::runtime_error( + "batch_dims == 0, but dst_nelems != 1"); + } + batch_dims = 1; + simplified_batch_shape.push_back(1); + simplified_batch_x1_strides.push_back(0); + simplified_batch_x2_strides.push_back(0); + simplified_batch_dst_strides.push_back(0); + } + else { + simplify_iteration_space_3( + batch_dims, batch_shape_ptr, batch_x1_strides, batch_x2_strides, + batch_dst_strides, + // output + simplified_batch_shape, simplified_batch_x1_strides, + simplified_batch_x2_strides, simplified_batch_dst_strides, + batch_x1_offset, batch_x2_offset, batch_dst_offset); + } + + if (inner_nd == 1 && batch_dims == 1) { + bool dot_product_c_contig = false; + bool reduce_all_elems = false; + + if (simplified_inner_x1_strides[0] == 1 && + simplified_inner_x2_strides[0] == 1) { + reduce_all_elems = (simplified_batch_shape[0] == 1); + dot_product_c_contig = + (simplified_batch_dst_strides[0] == 1) && + (static_cast(simplified_batch_x1_strides[0]) == + inner_nelems) && + (static_cast(simplified_batch_x2_strides[0]) == + inner_nelems); + } + + if (dot_product_c_contig || reduce_all_elems) { + dot_product_contig_impl_fn_ptr_t fn = nullptr; + if (supports_atomics) { + fn = + dot_product_contig_dispatch_table[x1_typeid][x2_typeid]; + } + else { + fn = dot_product_contig_temps_dispatch_table[x1_typeid] + [x2_typeid]; + } + if (fn != nullptr) { + dot_ev = fn(exec_q, dst_nelems, inner_nelems, x1.get_data(), + x2.get_data(), dst.get_data(), + batch_x1_offset, // lhs batch offset + batch_x2_offset, // rhs batch offset + batch_dst_offset, // res batch offset + inner_x1_offset, // lhs reduction offset + inner_x2_offset, // rhs reduction offset + depends); + return 
std::make_pair(dpctl::utils::keep_args_alive( + exec_q, {x1, x2, dst}, {dot_ev}), + dot_ev); + } + } + } + + dot_product_impl_fn_ptr_t fn = nullptr; + if (supports_atomics) { + fn = dot_product_dispatch_table[x1_typeid][x2_typeid]; + } + if (fn == nullptr) { + fn = dot_product_temps_dispatch_table[x1_typeid][x2_typeid]; + if (fn == nullptr) { + throw std::runtime_error( + "Implementation is missing for x1_typeid=" + + std::to_string(x1_typeid) + + " and x2_typeid=" + std::to_string(x2_typeid)); + } + } + + using dpctl::tensor::offset_utils::device_allocate_and_pack; + auto arrays_metainfo_packing_triple_ = + device_allocate_and_pack( + exec_q, host_task_events, + // iteration metadata + simplified_batch_shape, simplified_batch_x1_strides, + simplified_batch_x2_strides, simplified_batch_dst_strides, + // reduction metadata + simplified_inner_shape, simplified_inner_x1_strides, + simplified_inner_x2_strides); + auto tmp_alloc_owner = + std::move(std::get<0>(arrays_metainfo_packing_triple_)); + const auto ©_metadata_ev = + std::get<2>(arrays_metainfo_packing_triple_); + const py::ssize_t *temp_allocation_ptr = tmp_alloc_owner.get(); + + const py::ssize_t *iter_shape_and_strides = temp_allocation_ptr; + const py::ssize_t *inner_shape_stride = + temp_allocation_ptr + 4 * simplified_batch_shape.size(); + + std::vector all_deps; + all_deps.reserve(depends.size() + 1); + all_deps.resize(depends.size()); + std::copy(depends.begin(), depends.end(), all_deps.begin()); + all_deps.push_back(copy_metadata_ev); + + dot_ev = + fn(exec_q, dst_nelems, inner_nelems, x1.get_data(), x2.get_data(), + dst.get_data(), batch_dims, iter_shape_and_strides, + batch_x1_offset, batch_x2_offset, batch_dst_offset, + inner_nd, // number dimensions being reduced + inner_shape_stride, inner_x1_offset, inner_x2_offset, all_deps); + + sycl::event temp_cleanup_ev = + dpctl::tensor::alloc_utils::async_smart_free(exec_q, {dot_ev}, + tmp_alloc_owner); + host_task_events.push_back(temp_cleanup_ev); + } + 
else { // if (!call_vecdot) + if (!call_batched) { + if ((is_x1_c_contig && is_x2_c_contig && is_dst_c_contig)) { + gemm_contig_impl_fn_ptr_t fn = nullptr; + if (supports_atomics) { + fn = + gemm_contig_atomic_dispatch_table[x1_typeid][x2_typeid]; + } + else { + fn = gemm_contig_temps_dispatch_table[x1_typeid][x2_typeid]; + } + if (fn != nullptr) { + dot_ev = fn(exec_q, x1_data, x2_data, dst_data, + x1_outer_nelems, // n + inner_nelems, // k + x2_outer_nelems, // m + depends); + return std::make_pair(dpctl::utils::keep_args_alive( + exec_q, {x1, x2, dst}, {dot_ev}), + dot_ev); + } + } + gemm_impl_fn_ptr_t fn = nullptr; + if (supports_atomics) { + fn = gemm_atomic_dispatch_table[x1_typeid][x2_typeid]; + } + if (fn == nullptr) { + fn = gemm_temps_dispatch_table[x1_typeid][x2_typeid]; + if (fn == nullptr) { + throw std::runtime_error( + "Implementation is missing for x1_typeid=" + + std::to_string(x1_typeid) + + " and x2_typeid=" + std::to_string(x2_typeid)); + } + } + using dpctl::tensor::offset_utils::device_allocate_and_pack; + auto ptr_size_event_tuple1 = device_allocate_and_pack( + exec_q, host_task_events, x1_shape_vec, x1_strides_vec, + x2_shape_vec, x2_strides_vec, dst_shape_vec, dst_strides_vec); + auto packed_shapes_strides_owner = + std::move(std::get<0>(ptr_size_event_tuple1)); + sycl::event copy_shapes_strides_ev = + std::get<2>(ptr_size_event_tuple1); + const py::ssize_t *packed_shapes_strides = + packed_shapes_strides_owner.get(); + + const py::ssize_t *x1_shape_strides = packed_shapes_strides; + const py::ssize_t *x2_shape_strides = + packed_shapes_strides + 2 * (x1_nd); + const py::ssize_t *dst_shape_strides = + packed_shapes_strides + 2 * (x1_nd + x2_nd); + + std::vector all_deps; + all_deps.reserve(depends.size() + 1); + all_deps.insert(all_deps.end(), depends.begin(), depends.end()); + all_deps.push_back(copy_shapes_strides_ev); + + // change gemm calls to pass inner dims and outer dims separately + dot_ev = + fn(exec_q, x1_data, x2_data, dst_data, 
x1_outer_nelems, + inner_nelems, x2_outer_nelems, inner_dims, x1_outer_dims, + x1_shape_strides, x2_outer_dims, x2_shape_strides, + x1_outer_dims + x2_outer_dims, dst_shape_strides, all_deps); + + sycl::event cleanup_tmp_allocations_ev = + dpctl::tensor::alloc_utils::async_smart_free( + exec_q, {dot_ev}, packed_shapes_strides_owner); + host_task_events.push_back(cleanup_tmp_allocations_ev); + } + else { // if (call_batched) + using shT = std::vector; + // temporary asserts for matmul + assert(x1_outer_dims == 1); + assert(x2_outer_dims == 1); + assert(inner_dims == 1); + + if ((is_x1_c_contig && is_x2_c_contig && is_dst_c_contig)) { + gemm_batch_contig_impl_fn_ptr_t fn = nullptr; + if (supports_atomics) { + fn = gemm_batch_contig_atomic_dispatch_table[x1_typeid] + [x2_typeid]; + } + else { + fn = gemm_batch_contig_temps_dispatch_table[x1_typeid] + [x2_typeid]; + } + if (fn != nullptr) { + static constexpr py::ssize_t zero_offset = 0; + dot_ev = fn(exec_q, x1_data, x2_data, dst_data, batches, + x1_outer_nelems, // n + inner_nelems, // k + x2_outer_nelems, // m + zero_offset, zero_offset, zero_offset, depends); + return std::make_pair(dpctl::utils::keep_args_alive( + exec_q, {x1, x2, dst}, {dot_ev}), + dot_ev); + } + } + + auto x1_outer_inner_dims = x1_nd - batch_dims; + auto x2_outer_inner_dims = x2_nd - batch_dims; + auto dst_outer_inner_dims = dst_nd - batch_dims; + + shT batch_x1_shape; + shT outer_inner_x1_shape; + shT batch_x1_strides; + shT outer_inner_x1_strides; + split_iteration_space(x1_shape_vec, x1_strides_vec, batch_dims, + batch_dims + x1_outer_inner_dims, + // 4 vectors modified + batch_x1_shape, outer_inner_x1_shape, + batch_x1_strides, outer_inner_x1_strides); + + shT batch_x2_shape; + shT outer_inner_x2_shape; + shT batch_x2_strides; + shT outer_inner_x2_strides; + split_iteration_space(x2_shape_vec, x2_strides_vec, batch_dims, + batch_dims + x2_outer_inner_dims, + // 4 vectors modified + batch_x2_shape, outer_inner_x2_shape, + batch_x2_strides, 
outer_inner_x2_strides); + + shT batch_dst_shape; + shT outer_inner_dst_shape; + shT batch_dst_strides; + shT outer_inner_dst_strides; + split_iteration_space(dst_shape_vec, dst_strides_vec, batch_dims, + batch_dims + dst_outer_inner_dims, + // 4 vectors modified + batch_dst_shape, outer_inner_dst_shape, + batch_dst_strides, outer_inner_dst_strides); + + using shT = std::vector; + shT simplified_batch_shape; + shT simplified_batch_x1_strides; + shT simplified_batch_x2_strides; + shT simplified_batch_dst_strides; + py::ssize_t x1_batch_offset(0); + py::ssize_t x2_batch_offset(0); + py::ssize_t dst_batch_offset(0); + + const py::ssize_t *shape = x1_shape_ptr; + + simplify_iteration_space_3( + batch_dims, shape, batch_x1_strides, batch_x2_strides, + batch_dst_strides, + // outputs + simplified_batch_shape, simplified_batch_x1_strides, + simplified_batch_x2_strides, simplified_batch_dst_strides, + x1_batch_offset, x2_batch_offset, dst_batch_offset); + + if (batch_dims == 1 && x1_outer_dims == 1 && x2_outer_dims == 1 && + inner_dims == 1) + { + bool gemm_batch_c_contig = false; + + if ((static_cast(outer_inner_x1_strides[0]) == + inner_nelems && + outer_inner_x1_strides[1] == 1) && + (static_cast(outer_inner_x2_strides[0]) == + inner_nelems && + outer_inner_x2_strides[1] == 1) && + (static_cast(outer_inner_dst_strides[0]) == + x2_outer_nelems && + outer_inner_dst_strides[1] == 1)) + { + gemm_batch_c_contig = + (static_cast( + simplified_batch_x1_strides[0]) == + x1_outer_nelems * inner_nelems) && + (static_cast( + simplified_batch_x2_strides[0]) == + x2_outer_nelems * inner_nelems) && + (static_cast( + simplified_batch_dst_strides[0]) == + x1_outer_nelems * x2_outer_nelems); + } + + if (gemm_batch_c_contig) { + gemm_batch_contig_impl_fn_ptr_t fn = nullptr; + if (supports_atomics) { + fn = gemm_batch_contig_atomic_dispatch_table[x1_typeid] + [x2_typeid]; + } + else { + fn = gemm_batch_contig_temps_dispatch_table[x1_typeid] + [x2_typeid]; + } + if (fn != nullptr) { + 
dot_ev = fn(exec_q, x1_data, x2_data, dst_data, batches, + x1_outer_nelems, // n + inner_nelems, // k + x2_outer_nelems, // m + x1_batch_offset, x2_batch_offset, + dst_batch_offset, depends); + return std::make_pair( + dpctl::utils::keep_args_alive(exec_q, {x1, x2, dst}, + {dot_ev}), + dot_ev); + } + } + } + + gemm_batch_impl_fn_ptr_t fn = nullptr; + if (supports_atomics) { + fn = gemm_batch_atomic_dispatch_table[x1_typeid][x2_typeid]; + } + if (fn == nullptr) { + fn = gemm_batch_temps_dispatch_table[x1_typeid][x2_typeid]; + if (fn == nullptr) { + throw std::runtime_error( + "Implementation is missing for x1_typeid=" + + std::to_string(x1_typeid) + + " and x2_typeid=" + std::to_string(x2_typeid)); + } + } + using dpctl::tensor::offset_utils::device_allocate_and_pack; + auto ptr_size_event_tuple1 = device_allocate_and_pack( + exec_q, host_task_events, simplified_batch_shape, + simplified_batch_x1_strides, simplified_batch_x2_strides, + simplified_batch_dst_strides, outer_inner_x1_shape, + outer_inner_x1_strides, outer_inner_x2_shape, + outer_inner_x2_strides, outer_inner_dst_shape, + outer_inner_dst_strides, + // full shape and strides of the result array + // necessary for reduction and initialization + simplified_batch_shape, outer_inner_dst_shape, + simplified_batch_dst_strides, outer_inner_dst_strides); + auto packed_shapes_strides_owner = + std::move(std::get<0>(ptr_size_event_tuple1)); + sycl::event copy_shapes_strides_ev = + std::get<2>(ptr_size_event_tuple1); + const py::ssize_t *packed_shapes_strides = + packed_shapes_strides_owner.get(); + + const auto batch_shape_strides = packed_shapes_strides; + const auto x1_outer_inner_shapes_strides = + packed_shapes_strides + 4 * batch_dims; + const auto x2_outer_inner_shapes_strides = + packed_shapes_strides + 4 * batch_dims + + 2 * (x1_outer_inner_dims); + const auto dst_outer_shapes_strides = + packed_shapes_strides + 4 * batch_dims + + 2 * (x1_outer_inner_dims) + 2 * (x2_outer_inner_dims); + const auto 
dst_full_shape_strides = + packed_shapes_strides + 4 * batch_dims + + 2 * (x1_outer_inner_dims) + 2 * (x2_outer_inner_dims) + + 2 * (dst_outer_inner_dims); + + std::vector all_deps; + all_deps.reserve(depends.size() + 1); + all_deps.insert(all_deps.end(), depends.begin(), depends.end()); + all_deps.push_back(copy_shapes_strides_ev); + + dot_ev = fn( + exec_q, x1_data, x2_data, dst_data, batches, x1_outer_nelems, + inner_nelems, x2_outer_nelems, batch_dims, batch_shape_strides, + x1_batch_offset, x2_batch_offset, dst_batch_offset, inner_dims, + x1_outer_dims, x1_outer_inner_shapes_strides, x2_outer_dims, + x2_outer_inner_shapes_strides, x1_outer_dims + x2_outer_dims, + dst_outer_shapes_strides, dst_full_shape_strides, all_deps); + + sycl::event cleanup_tmp_allocations_ev = + dpctl::tensor::alloc_utils::async_smart_free( + exec_q, {dot_ev}, packed_shapes_strides_owner); + host_task_events.push_back(cleanup_tmp_allocations_ev); + } + } + return std::make_pair( + dpctl::utils::keep_args_alive(exec_q, {x1, x2, dst}, host_task_events), + dot_ev); +} + +template +py::object py_dot_result_type(const py::dtype &input1_dtype, + const py::dtype &input2_dtype, + const output_typesT &output_types_table) +{ + int tn1 = input1_dtype.num(); // NumPy type numbers are the same as in dpctl + int tn2 = input2_dtype.num(); // NumPy type numbers are the same as in dpctl + int src1_typeid = -1; + int src2_typeid = -1; + + auto array_types = td_ns::usm_ndarray_types(); + + try { + src1_typeid = array_types.typenum_to_lookup_id(tn1); + src2_typeid = array_types.typenum_to_lookup_id(tn2); + } catch (const std::exception &e) { + throw py::value_error(e.what()); + } + + if (src1_typeid < 0 || src1_typeid >= td_ns::num_types || src2_typeid < 0 || + src2_typeid >= td_ns::num_types) + { + throw std::runtime_error("binary output type lookup failed"); + } + int dst_typeid = output_types_table[src1_typeid][src2_typeid]; + + if (dst_typeid < 0) { + auto res = py::none(); + return py::cast(res); + } 
+ else { + auto dst_typenum_t = static_cast(dst_typeid); + auto dt = type_utils::_dtype_from_typenum(dst_typenum_t); + + return py::cast(dt); + } +} + +void init_dot(py::module_ m) +{ + init_dot_atomic_support_vector(); + init_dot_dispatch_tables(); + + m.def("_dot", &py_dot, "", py::arg("x1"), py::arg("x2"), + py::arg("batch_dims"), py::arg("x1_outer_dims"), + py::arg("x2_outer_dims"), py::arg("inner_dims"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); + + auto dot_result_type_pyapi = [&](const py::dtype &dtype1, + const py::dtype &dtype2) { + return py_dot_result_type(dtype1, dtype2, dot_output_id_table); + }; + m.def("_dot_result_type", dot_result_type_pyapi, ""); +} + +} // namespace dpctl::tensor::py_internal diff --git a/dpctl_ext/tensor/libtensor/source/linalg_functions/dot.hpp b/dpctl_ext/tensor/libtensor/source/linalg_functions/dot.hpp new file mode 100644 index 00000000000..f6a23ace5cd --- /dev/null +++ b/dpctl_ext/tensor/libtensor/source/linalg_functions/dot.hpp @@ -0,0 +1,45 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. 
+// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** +// +//===---------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_impl extensions +//===--------------------------------------------------------------------===// + +#pragma once +#include + +namespace py = pybind11; + +namespace dpctl::tensor::py_internal +{ + +extern void init_dot(py::module_ m); + +} // namespace dpctl::tensor::py_internal diff --git a/dpctl_ext/tensor/libtensor/source/linalg_functions/dot_atomic_support.hpp b/dpctl_ext/tensor/libtensor/source/linalg_functions/dot_atomic_support.hpp new file mode 100644 index 00000000000..49c26459695 --- /dev/null +++ b/dpctl_ext/tensor/libtensor/source/linalg_functions/dot_atomic_support.hpp @@ -0,0 +1,60 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. 
+// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. 
+//***************************************************************************** +// +//===---------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_impl extensions +//===--------------------------------------------------------------------===// + +#pragma once + +#include + +#include "reductions/reduction_atomic_support.hpp" +#include "utils/type_utils.hpp" + +namespace dpctl::tensor::py_internal::atomic_support +{ + +template +struct DotAtomicSupportFactory +{ + fnT get() + { + using dpctl::tensor::type_utils::is_complex; + if constexpr (is_complex::value) { + return atomic_support::fixed_decision; + } + else { + return atomic_support::check_atomic_support; + } + } +}; + +} // namespace dpctl::tensor::py_internal::atomic_support diff --git a/dpctl_ext/tensor/libtensor/source/linalg_functions/dot_dispatch.hpp b/dpctl_ext/tensor/libtensor/source/linalg_functions/dot_dispatch.hpp new file mode 100644 index 00000000000..790439abd0f --- /dev/null +++ b/dpctl_ext/tensor/libtensor/source/linalg_functions/dot_dispatch.hpp @@ -0,0 +1,404 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. 
+// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** +// +//===---------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_impl extensions +//===--------------------------------------------------------------------===// + +#pragma once + +#include +#include + +#include "kernels/linalg_functions/dot_product.hpp" +#include "kernels/linalg_functions/gemm.hpp" +#include "utils/type_dispatch_building.hpp" + +namespace dpctl::tensor::py_internal +{ + +namespace td_ns = dpctl::tensor::type_dispatch; + +template +struct DotAtomicOutputType +{ + using value_type = typename std::disjunction< + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::DefaultResultEntry>::result_type; + + static constexpr bool is_defined = !std::is_same_v; +}; + +// add separate type support lists for atomic vs. 
temps +// gemm, gevm, and dot product share output type struct +template +struct DotNoAtomicOutputType +{ + using value_type = typename std::disjunction< + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + T2, + std::complex, + std::complex>, + td_ns::BinaryTypeMapResultEntry, + T2, + std::complex, + std::complex>, + td_ns::BinaryTypeMapResultEntry, + T2, + std::complex, + std::complex>, + td_ns::DefaultResultEntry>::result_type; + + static constexpr bool is_defined = !std::is_same_v; +}; + +template +struct DotTypeMapFactory +{ + /*! 
@brief get typeid for output type of kernels called by py_dot */ + std::enable_if_t::value, int> get() + { + using rT1 = typename DotNoAtomicOutputType::value_type; + using rT2 = typename DotAtomicOutputType::value_type; + static_assert(std::is_same_v || std::is_same_v); + return td_ns::GetTypeid{}.get(); + } +}; + +template +struct GemmBatchAtomicFactory +{ + fnT get() + { + if constexpr (!DotAtomicOutputType::is_defined) { + fnT fn = nullptr; + return fn; + } + else { + using dpctl::tensor::kernels::gemm_batch_impl; + using T3 = typename DotAtomicOutputType::value_type; + fnT fn = gemm_batch_impl; + return fn; + } + } +}; + +template +struct GemmBatchContigAtomicFactory +{ + fnT get() + { + if constexpr (!DotAtomicOutputType::is_defined) { + fnT fn = nullptr; + return fn; + } + else { + using dpctl::tensor::kernels::gemm_batch_contig_impl; + using T3 = typename DotAtomicOutputType::value_type; + fnT fn = gemm_batch_contig_impl; + return fn; + } + } +}; + +template +struct GemmAtomicFactory +{ + fnT get() + { + if constexpr (!DotAtomicOutputType::is_defined) { + fnT fn = nullptr; + return fn; + } + else { + using dpctl::tensor::kernels::gemm_impl; + using T3 = typename DotAtomicOutputType::value_type; + fnT fn = gemm_impl; + return fn; + } + } +}; + +template +struct GemmContigAtomicFactory +{ + fnT get() + { + if constexpr (!DotAtomicOutputType::is_defined) { + fnT fn = nullptr; + return fn; + } + else { + using dpctl::tensor::kernels::gemm_contig_impl; + using T3 = typename DotAtomicOutputType::value_type; + fnT fn = gemm_contig_impl; + return fn; + } + } +}; + +template +struct GemmTempsFactory +{ + fnT get() + { + if constexpr (!DotNoAtomicOutputType::is_defined) { + fnT fn = nullptr; + return fn; + } + else { + using dpctl::tensor::kernels::gemm_tree_impl; + using T3 = typename DotNoAtomicOutputType::value_type; + fnT fn = gemm_tree_impl; + return fn; + } + } +}; + +template +struct GemmContigTempsFactory +{ + fnT get() + { + if constexpr 
(!DotNoAtomicOutputType::is_defined) { + fnT fn = nullptr; + return fn; + } + else { + using dpctl::tensor::kernels::gemm_contig_tree_impl; + using T3 = typename DotNoAtomicOutputType::value_type; + fnT fn = gemm_contig_tree_impl; + return fn; + } + } +}; + +template +struct GemmBatchTempsFactory +{ + fnT get() + { + if constexpr (!DotNoAtomicOutputType::is_defined) { + fnT fn = nullptr; + return fn; + } + else { + using dpctl::tensor::kernels::gemm_batch_tree_impl; + using T3 = typename DotNoAtomicOutputType::value_type; + fnT fn = gemm_batch_tree_impl; + return fn; + } + } +}; + +template +struct GemmBatchContigTempsFactory +{ + fnT get() + { + if constexpr (!DotNoAtomicOutputType::is_defined) { + fnT fn = nullptr; + return fn; + } + else { + using dpctl::tensor::kernels::gemm_batch_contig_tree_impl; + using T3 = typename DotNoAtomicOutputType::value_type; + fnT fn = gemm_batch_contig_tree_impl; + return fn; + } + } +}; + +template +struct DotProductAtomicFactory +{ + fnT get() + { + if constexpr (!DotAtomicOutputType::is_defined) { + fnT fn = nullptr; + return fn; + } + else { + using dpctl::tensor::kernels::dot_product_impl; + using T3 = typename DotAtomicOutputType::value_type; + fnT fn = dot_product_impl; + return fn; + } + } +}; + +template +struct DotProductNoAtomicFactory +{ + fnT get() + { + if constexpr (!DotNoAtomicOutputType::is_defined) { + fnT fn = nullptr; + return fn; + } + else { + using dpctl::tensor::kernels::dot_product_tree_impl; + using T3 = typename DotNoAtomicOutputType::value_type; + fnT fn = dot_product_tree_impl; + return fn; + } + } +}; + +template +struct DotProductContigAtomicFactory +{ + fnT get() + { + if constexpr (!DotAtomicOutputType::is_defined) { + fnT fn = nullptr; + return fn; + } + else { + using dpctl::tensor::kernels::dot_product_contig_impl; + using T3 = typename DotAtomicOutputType::value_type; + fnT fn = dot_product_contig_impl; + return fn; + } + } +}; + +template +struct DotProductContigNoAtomicFactory +{ + fnT get() 
+ { + if constexpr (!DotNoAtomicOutputType::is_defined) { + fnT fn = nullptr; + return fn; + } + else { + using dpctl::tensor::kernels::dot_product_contig_tree_impl; + using T3 = typename DotNoAtomicOutputType::value_type; + fnT fn = dot_product_contig_tree_impl; + return fn; + } + } +}; + +} // namespace dpctl::tensor::py_internal diff --git a/dpctl_ext/tensor/libtensor/source/tensor_linalg.cpp b/dpctl_ext/tensor/libtensor/source/tensor_linalg.cpp new file mode 100644 index 00000000000..4a1b5fb79b9 --- /dev/null +++ b/dpctl_ext/tensor/libtensor/source/tensor_linalg.cpp @@ -0,0 +1,41 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** +// +//===---------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_impl extensions +//===----------------------------------------------------------------------===// + +#include "linalg_functions/dot.hpp" +#include + +PYBIND11_MODULE(_tensor_linalg_impl, m) +{ + dpctl::tensor::py_internal::init_dot(m); +} diff --git a/dpnp/dpnp_iface_manipulation.py b/dpnp/dpnp_iface_manipulation.py index b5afd9523d6..2ff08cc6ec8 100644 --- a/dpnp/dpnp_iface_manipulation.py +++ b/dpnp/dpnp_iface_manipulation.py @@ -2370,7 +2370,7 @@ def matrix_transpose(x, /): f"but it is {usm_x.ndim}" ) - usm_res = dpt.matrix_transpose(usm_x) + usm_res = dpt_ext.matrix_transpose(usm_x) return dpnp_array._create_from_usm_ndarray(usm_res) diff --git a/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py b/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py index b01f57eaecd..28ed40ab5f6 100644 --- a/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py +++ b/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py @@ -26,8 +26,6 @@ # THE POSSIBILITY OF SUCH DAMAGE. 
# ***************************************************************************** -import dpctl -import dpctl.tensor as dpt import dpctl.utils as dpu import numpy from dpctl.utils import ExecutionPlacementError @@ -35,6 +33,7 @@ # pylint: disable=no-name-in-module # TODO: revert to `import dpctl.tensor...` # when dpnp fully migrates dpctl/tensor +import dpctl_ext.tensor as dpt import dpctl_ext.tensor._tensor_impl as ti import dpnp import dpnp.backend.extensions.blas._blas_impl as bi @@ -696,7 +695,7 @@ def _validate_out_array(out, exec_q): """Validate out is supported array and has correct queue.""" if out is not None: dpnp.check_supported_arrays_type(out) - if dpctl.utils.get_execution_queue((exec_q, out.sycl_queue)) is None: + if dpu.get_execution_queue((exec_q, out.sycl_queue)) is None: raise ExecutionPlacementError( "Input and output allocation queues are not compatible" )