diff --git a/dpctl_ext/tensor/CMakeLists.txt b/dpctl_ext/tensor/CMakeLists.txt index afc7dca4db3..b032dc34bdb 100644 --- a/dpctl_ext/tensor/CMakeLists.txt +++ b/dpctl_ext/tensor/CMakeLists.txt @@ -75,19 +75,19 @@ set(_elementwise_sources ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/abs.cpp ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/acos.cpp ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/acosh.cpp - #${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/add.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/add.cpp ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/angle.cpp ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/asin.cpp ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/asinh.cpp ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/atan.cpp - #${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/atan2.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/atan2.cpp ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/atanh.cpp - #${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/bitwise_and.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/bitwise_and.cpp ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/bitwise_invert.cpp - #${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/bitwise_left_shift.cpp - #${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/bitwise_or.cpp - #${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/bitwise_right_shift.cpp - #${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/bitwise_xor.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/bitwise_left_shift.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/bitwise_or.cpp + 
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/bitwise_right_shift.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/bitwise_xor.cpp ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/cbrt.cpp ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/ceil.cpp ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/conj.cpp diff --git a/dpctl_ext/tensor/__init__.py b/dpctl_ext/tensor/__init__.py index a6127f1fc27..5172d426334 100644 --- a/dpctl_ext/tensor/__init__.py +++ b/dpctl_ext/tensor/__init__.py @@ -57,12 +57,19 @@ abs, acos, acosh, + add, angle, asin, asinh, atan, + atan2, atanh, + bitwise_and, bitwise_invert, + bitwise_left_shift, + bitwise_or, + bitwise_right_shift, + bitwise_xor, cbrt, ceil, conj, @@ -158,6 +165,7 @@ "abs", "acos", "acosh", + "add", "all", "angle", "any", @@ -172,7 +180,13 @@ "astype", "atan", "atanh", + "atan2", + "bitwise_and", "bitwise_invert", + "bitwise_left_shift", + "bitwise_or", + "bitwise_right_shift", + "bitwise_xor", "broadcast_arrays", "broadcast_to", "can_cast", diff --git a/dpctl_ext/tensor/_elementwise_common.py b/dpctl_ext/tensor/_elementwise_common.py index 7811c01d9ce..7fd9dabf961 100644 --- a/dpctl_ext/tensor/_elementwise_common.py +++ b/dpctl_ext/tensor/_elementwise_common.py @@ -35,11 +35,22 @@ import dpctl_ext.tensor as dpt_ext import dpctl_ext.tensor._tensor_impl as ti -from ._copy_utils import _empty_like_orderK +from ._copy_utils import _empty_like_orderK, _empty_like_pair_orderK +from ._manipulation_functions import _broadcast_shape_impl +from ._scalar_utils import ( + _get_dtype, + _get_queue_usm_type, + _get_shape, + _validate_dtype, +) from ._type_utils import ( + _acceptance_fn_default_binary, _acceptance_fn_default_unary, _all_data_types, _find_buf_dtype, + _find_buf_dtype2, + _find_buf_dtype_in_place_op, + _resolve_weak_types, ) @@ -283,3 +294,705 @@ def __call__(self, x, /, *, out=None, order="K"): _manager.add_event_pair(ht, 
uf_ev)
+        return out
+
+
+class BinaryElementwiseFunc:
+    """
+    Class that implements binary element-wise functions.
+
+    Args:
+        name (str):
+            Name of the binary function
+        result_type_resolver_fn (callable):
+            Function that takes dtypes of the input and
+            returns the dtype of the result if the
+            implementation function supports it, or
+            returns `None` otherwise.
+        binary_dp_impl_fn (callable):
+            Data-parallel implementation function with signature
+            `impl_fn(src1: usm_ndarray, src2: usm_ndarray, dst: usm_ndarray,
+            sycl_queue: SyclQueue, depends: Optional[List[SyclEvent]])`
+            where the `src1` and `src2` are the argument arrays, `dst` is the
+            array to be populated with function values,
+            i.e. `dst=func(src1, src2)`.
+            The `impl_fn` is expected to return a 2-tuple of `SyclEvent`s.
+            The first event corresponds to data-management host tasks,
+            including lifetime management of argument Python objects to ensure
+            that their associated USM allocation is not freed before offloaded
+            computational tasks complete execution, while the second event
+            corresponds to computational tasks associated with function
+            evaluation.
+        docs (str):
+            Documentation string for the binary function.
+        binary_inplace_fn (callable, optional):
+            Data-parallel implementation function with signature
+            `impl_fn(src: usm_ndarray, dst: usm_ndarray,
+            sycl_queue: SyclQueue, depends: Optional[List[SyclEvent]])`
+            where the `src` is the argument array, `dst` is the
+            array to be populated with function values,
+            i.e. `dst=func(dst, src)`.
+            The `impl_fn` is expected to return a 2-tuple of `SyclEvent`s.
+            The first event corresponds to data-management host tasks,
+            including async lifetime management of Python arguments,
+            while the second event corresponds to computational tasks
+            associated with function evaluation.
+        acceptance_fn (callable, optional):
+            Function to influence type promotion behavior of this binary
+            function.
The function takes 6 arguments: + arg1_dtype - Data type of the first argument + arg2_dtype - Data type of the second argument + ret_buf1_dtype - Data type the first argument would be cast to + ret_buf2_dtype - Data type the second argument would be cast to + res_dtype - Data type of the output array with function values + sycl_dev - The :class:`dpctl.SyclDevice` where the function + evaluation is carried out. + The function is only called when both arguments of the binary + function require casting, e.g. both arguments of + `dpctl.tensor.logaddexp` are arrays with integral data type. + """ + + def __init__( + self, + name, + result_type_resolver_fn, + binary_dp_impl_fn, + docs, + binary_inplace_fn=None, + acceptance_fn=None, + weak_type_resolver=None, + ): + self.__name__ = "BinaryElementwiseFunc" + self.name_ = name + self.result_type_resolver_fn_ = result_type_resolver_fn + self.types_ = None + self.binary_fn_ = binary_dp_impl_fn + self.binary_inplace_fn_ = binary_inplace_fn + self.__doc__ = docs + if callable(acceptance_fn): + self.acceptance_fn_ = acceptance_fn + else: + self.acceptance_fn_ = _acceptance_fn_default_binary + if callable(weak_type_resolver): + self.weak_type_resolver_ = weak_type_resolver + else: + self.weak_type_resolver_ = _resolve_weak_types + + def __str__(self): + return f"<{self.__name__} '{self.name_}'>" + + def __repr__(self): + return f"<{self.__name__} '{self.name_}'>" + + def get_implementation_function(self): + """Returns the out-of-place implementation + function for this elementwise binary function. + + """ + return self.binary_fn_ + + def get_implementation_inplace_function(self): + """Returns the in-place implementation + function for this elementwise binary function. + + """ + return self.binary_inplace_fn_ + + def get_type_result_resolver_function(self): + """Returns the type resolver function for this + elementwise binary function. 
+ """ + return self.result_type_resolver_fn_ + + def get_type_promotion_path_acceptance_function(self): + """Returns the acceptance function for this + elementwise binary function. + + Acceptance function influences the type promotion + behavior of this binary function. + The function takes 6 arguments: + arg1_dtype - Data type of the first argument + arg2_dtype - Data type of the second argument + ret_buf1_dtype - Data type the first argument would be cast to + ret_buf2_dtype - Data type the second argument would be cast to + res_dtype - Data type of the output array with function values + sycl_dev - :class:`dpctl.SyclDevice` on which function evaluation + is carried out. + + The acceptance function is only invoked if both input arrays must be + cast to intermediary data types, as would happen during call of + `dpctl.tensor.hypot` with both arrays being of integral data type. + """ + return self.acceptance_fn_ + + def get_array_dtype_scalar_type_resolver_function(self): + """Returns the function which determines how to treat + Python scalar types for this elementwise binary function. + + Resolver influences what type the scalar will be + treated as prior to type promotion behavior. + The function takes 3 arguments: + + Args: + o1_dtype (object, dtype): + A class representing a Python scalar type or a ``dtype`` + o2_dtype (object, dtype): + A class representing a Python scalar type or a ``dtype`` + sycl_dev (:class:`dpctl.SyclDevice`): + Device on which function evaluation is carried out. + + One of ``o1_dtype`` and ``o2_dtype`` must be a ``dtype`` instance. + """ + return self.weak_type_resolver_ + + @property + def nin(self): + """Returns the number of arguments treated as inputs.""" + return 2 + + @property + def nout(self): + """Returns the number of arguments treated as outputs.""" + return 1 + + @property + def types(self): + """Returns information about types supported by + implementation function, using NumPy's character + encoding for data types, e.g. 
+ + :Example: + .. code-block:: python + + dpctl.tensor.divide.types + # Outputs: ['ee->e', 'ff->f', 'fF->F', 'dd->d', 'dD->D', + # 'Ff->F', 'FF->F', 'Dd->D', 'DD->D'] + """ + types = self.types_ + if not types: + types = [] + _all_dtypes = _all_data_types(True, True) + for dt1 in _all_dtypes: + for dt2 in _all_dtypes: + dt3 = self.result_type_resolver_fn_(dt1, dt2) + if dt3: + types.append(f"{dt1.char}{dt2.char}->{dt3.char}") + self.types_ = types + return types + + def __call__(self, o1, o2, /, *, out=None, order="K"): + if order not in ["K", "C", "F", "A"]: + order = "K" + q1, o1_usm_type = _get_queue_usm_type(o1) + q2, o2_usm_type = _get_queue_usm_type(o2) + if q1 is None and q2 is None: + raise ExecutionPlacementError( + "Execution placement can not be unambiguously inferred " + "from input arguments. " + "One of the arguments must represent USM allocation and " + "expose `__sycl_usm_array_interface__` property" + ) + if q1 is None: + exec_q = q2 + res_usm_type = o2_usm_type + elif q2 is None: + exec_q = q1 + res_usm_type = o1_usm_type + else: + exec_q = dpctl.utils.get_execution_queue((q1, q2)) + if exec_q is None: + raise ExecutionPlacementError( + "Execution placement can not be unambiguously inferred " + "from input arguments." + ) + res_usm_type = dpctl.utils.get_coerced_usm_type( + ( + o1_usm_type, + o2_usm_type, + ) + ) + dpctl.utils.validate_usm_type(res_usm_type, allow_none=False) + o1_shape = _get_shape(o1) + o2_shape = _get_shape(o2) + if not all( + isinstance(s, (tuple, list)) + for s in ( + o1_shape, + o2_shape, + ) + ): + raise TypeError( + "Shape of arguments can not be inferred. 
" + "Arguments are expected to be " + "lists, tuples, or both" + ) + try: + res_shape = _broadcast_shape_impl( + [ + o1_shape, + o2_shape, + ] + ) + except ValueError: + raise ValueError( + "operands could not be broadcast together with shapes " + f"{o1_shape} and {o2_shape}" + ) + sycl_dev = exec_q.sycl_device + o1_dtype = _get_dtype(o1, sycl_dev) + o2_dtype = _get_dtype(o2, sycl_dev) + if not all(_validate_dtype(o) for o in (o1_dtype, o2_dtype)): + raise ValueError("Operands have unsupported data types") + + o1_dtype, o2_dtype = self.weak_type_resolver_( + o1_dtype, o2_dtype, sycl_dev + ) + + buf1_dt, buf2_dt, res_dt = _find_buf_dtype2( + o1_dtype, + o2_dtype, + self.result_type_resolver_fn_, + sycl_dev, + acceptance_fn=self.acceptance_fn_, + ) + + if res_dt is None: + raise ValueError( + f"function '{self.name_}' does not support input types " + f"({o1_dtype}, {o2_dtype}), " + "and the inputs could not be safely coerced to any " + "supported types according to the casting rule ''safe''." + ) + + orig_out = out + _manager = SequentialOrderManager[exec_q] + if out is not None: + if not isinstance(out, dpt.usm_ndarray): + raise TypeError( + f"output array must be of usm_ndarray type, got {type(out)}" + ) + + if not out.flags.writable: + raise ValueError("provided `out` array is read-only") + + if out.shape != res_shape: + raise ValueError( + "The shape of input and output arrays are inconsistent. 
" + f"Expected output shape is {res_shape}, got {out.shape}" + ) + + if res_dt != out.dtype: + raise ValueError( + f"Output array of type {res_dt} is needed, " + f"got {out.dtype}" + ) + + if ( + dpctl.utils.get_execution_queue((exec_q, out.sycl_queue)) + is None + ): + raise ExecutionPlacementError( + "Input and output allocation queues are not compatible" + ) + + if isinstance(o1, dpt.usm_ndarray): + if ti._array_overlap(o1, out) and buf1_dt is None: + if not ti._same_logical_tensors(o1, out): + out = dpt_ext.empty_like(out) + elif self.binary_inplace_fn_ is not None: + # if there is a dedicated in-place kernel + # it can be called here, otherwise continues + if isinstance(o2, dpt.usm_ndarray): + src2 = o2 + if ( + ti._array_overlap(o2, out) + and not ti._same_logical_tensors(o2, out) + and buf2_dt is None + ): + buf2_dt = o2_dtype + else: + src2 = dpt_ext.asarray( + o2, dtype=o2_dtype, sycl_queue=exec_q + ) + if buf2_dt is None: + if src2.shape != res_shape: + src2 = dpt_ext.broadcast_to(src2, res_shape) + dep_evs = _manager.submitted_events + ht_, comp_ev = self.binary_inplace_fn_( + lhs=o1, + rhs=src2, + sycl_queue=exec_q, + depends=dep_evs, + ) + _manager.add_event_pair(ht_, comp_ev) + else: + buf2 = dpt_ext.empty_like(src2, dtype=buf2_dt) + dep_evs = _manager.submitted_events + ( + ht_copy_ev, + copy_ev, + ) = ti._copy_usm_ndarray_into_usm_ndarray( + src=src2, + dst=buf2, + sycl_queue=exec_q, + depends=dep_evs, + ) + _manager.add_event_pair(ht_copy_ev, copy_ev) + + buf2 = dpt_ext.broadcast_to(buf2, res_shape) + ht_, bf_ev = self.binary_inplace_fn_( + lhs=o1, + rhs=buf2, + sycl_queue=exec_q, + depends=[copy_ev], + ) + _manager.add_event_pair(ht_, bf_ev) + + return out + + if isinstance(o2, dpt.usm_ndarray): + if ( + ti._array_overlap(o2, out) + and not ti._same_logical_tensors(o2, out) + and buf2_dt is None + ): + # should not reach if out is reallocated + # after being checked against o1 + out = dpt_ext.empty_like(out) + + if isinstance(o1, dpt.usm_ndarray): 
+ src1 = o1 + else: + src1 = dpt_ext.asarray(o1, dtype=o1_dtype, sycl_queue=exec_q) + if isinstance(o2, dpt.usm_ndarray): + src2 = o2 + else: + src2 = dpt_ext.asarray(o2, dtype=o2_dtype, sycl_queue=exec_q) + + if order == "A": + order = ( + "F" + if all( + arr.flags.f_contiguous + for arr in ( + src1, + src2, + ) + ) + else "C" + ) + + if buf1_dt is None and buf2_dt is None: + if out is None: + if order == "K": + out = _empty_like_pair_orderK( + src1, src2, res_dt, res_shape, res_usm_type, exec_q + ) + else: + out = dpt_ext.empty( + res_shape, + dtype=res_dt, + usm_type=res_usm_type, + sycl_queue=exec_q, + order=order, + ) + if src1.shape != res_shape: + src1 = dpt_ext.broadcast_to(src1, res_shape) + if src2.shape != res_shape: + src2 = dpt_ext.broadcast_to(src2, res_shape) + deps_ev = _manager.submitted_events + ht_binary_ev, binary_ev = self.binary_fn_( + src1=src1, + src2=src2, + dst=out, + sycl_queue=exec_q, + depends=deps_ev, + ) + _manager.add_event_pair(ht_binary_ev, binary_ev) + if not (orig_out is None or orig_out is out): + # Copy the out data from temporary buffer to original memory + ht_copy_out_ev, cpy_ev = ti._copy_usm_ndarray_into_usm_ndarray( + src=out, + dst=orig_out, + sycl_queue=exec_q, + depends=[binary_ev], + ) + _manager.add_event_pair(ht_copy_out_ev, cpy_ev) + out = orig_out + return out + elif buf1_dt is None: + if order == "K": + buf2 = _empty_like_orderK(src2, buf2_dt) + else: + buf2 = dpt_ext.empty_like(src2, dtype=buf2_dt, order=order) + dep_evs = _manager.submitted_events + ht_copy_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( + src=src2, dst=buf2, sycl_queue=exec_q, depends=dep_evs + ) + _manager.add_event_pair(ht_copy_ev, copy_ev) + if out is None: + if order == "K": + out = _empty_like_pair_orderK( + src1, buf2, res_dt, res_shape, res_usm_type, exec_q + ) + else: + out = dpt_ext.empty( + res_shape, + dtype=res_dt, + usm_type=res_usm_type, + sycl_queue=exec_q, + order=order, + ) + + if src1.shape != res_shape: + src1 = 
dpt_ext.broadcast_to(src1, res_shape) + buf2 = dpt_ext.broadcast_to(buf2, res_shape) + ht_binary_ev, binary_ev = self.binary_fn_( + src1=src1, + src2=buf2, + dst=out, + sycl_queue=exec_q, + depends=[copy_ev], + ) + _manager.add_event_pair(ht_binary_ev, binary_ev) + if not (orig_out is None or orig_out is out): + # Copy the out data from temporary buffer to original memory + ht_copy_out_ev, cpy_ev = ti._copy_usm_ndarray_into_usm_ndarray( + src=out, + dst=orig_out, + sycl_queue=exec_q, + depends=[binary_ev], + ) + _manager.add_event_pair(ht_copy_out_ev, cpy_ev) + out = orig_out + return out + elif buf2_dt is None: + if order == "K": + buf1 = _empty_like_orderK(src1, buf1_dt) + else: + buf1 = dpt_ext.empty_like(src1, dtype=buf1_dt, order=order) + dep_evs = _manager.submitted_events + ht_copy_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( + src=src1, dst=buf1, sycl_queue=exec_q, depends=dep_evs + ) + _manager.add_event_pair(ht_copy_ev, copy_ev) + if out is None: + if order == "K": + out = _empty_like_pair_orderK( + buf1, src2, res_dt, res_shape, res_usm_type, exec_q + ) + else: + out = dpt_ext.empty( + res_shape, + dtype=res_dt, + usm_type=res_usm_type, + sycl_queue=exec_q, + order=order, + ) + + buf1 = dpt_ext.broadcast_to(buf1, res_shape) + if src2.shape != res_shape: + src2 = dpt_ext.broadcast_to(src2, res_shape) + ht_binary_ev, binary_ev = self.binary_fn_( + src1=buf1, + src2=src2, + dst=out, + sycl_queue=exec_q, + depends=[copy_ev], + ) + _manager.add_event_pair(ht_binary_ev, binary_ev) + if not (orig_out is None or orig_out is out): + # Copy the out data from temporary buffer to original memory + ht_copy_out_ev, cpy_ev = ti._copy_usm_ndarray_into_usm_ndarray( + src=out, + dst=orig_out, + sycl_queue=exec_q, + depends=[binary_ev], + ) + _manager.add_event_pair(ht_copy_out_ev, cpy_ev) + out = orig_out + return out + + if order == "K": + if src1.flags.c_contiguous and src2.flags.c_contiguous: + order = "C" + elif src1.flags.f_contiguous and 
src2.flags.f_contiguous: + order = "F" + if order == "K": + buf1 = _empty_like_orderK(src1, buf1_dt) + else: + buf1 = dpt_ext.empty_like(src1, dtype=buf1_dt, order=order) + dep_evs = _manager.submitted_events + ht_copy1_ev, copy1_ev = ti._copy_usm_ndarray_into_usm_ndarray( + src=src1, dst=buf1, sycl_queue=exec_q, depends=dep_evs + ) + _manager.add_event_pair(ht_copy1_ev, copy1_ev) + if order == "K": + buf2 = _empty_like_orderK(src2, buf2_dt) + else: + buf2 = dpt_ext.empty_like(src2, dtype=buf2_dt, order=order) + ht_copy2_ev, copy2_ev = ti._copy_usm_ndarray_into_usm_ndarray( + src=src2, dst=buf2, sycl_queue=exec_q, depends=dep_evs + ) + _manager.add_event_pair(ht_copy2_ev, copy2_ev) + if out is None: + if order == "K": + out = _empty_like_pair_orderK( + buf1, buf2, res_dt, res_shape, res_usm_type, exec_q + ) + else: + out = dpt_ext.empty( + res_shape, + dtype=res_dt, + usm_type=res_usm_type, + sycl_queue=exec_q, + order=order, + ) + + buf1 = dpt_ext.broadcast_to(buf1, res_shape) + buf2 = dpt_ext.broadcast_to(buf2, res_shape) + ht_, bf_ev = self.binary_fn_( + src1=buf1, + src2=buf2, + dst=out, + sycl_queue=exec_q, + depends=[copy1_ev, copy2_ev], + ) + _manager.add_event_pair(ht_, bf_ev) + return out + + def _inplace_op(self, o1, o2): + if self.binary_inplace_fn_ is None: + raise ValueError( + "binary function does not have a dedicated in-place " + "implementation" + ) + if not isinstance(o1, dpt.usm_ndarray): + raise TypeError( + "Expected first argument to be " + f"dpctl.tensor.usm_ndarray, got {type(o1)}" + ) + if not o1.flags.writable: + raise ValueError("provided left-hand side array is read-only") + q1, o1_usm_type = o1.sycl_queue, o1.usm_type + q2, o2_usm_type = _get_queue_usm_type(o2) + if q2 is None: + exec_q = q1 + res_usm_type = o1_usm_type + else: + exec_q = dpctl.utils.get_execution_queue((q1, q2)) + if exec_q is None: + raise ExecutionPlacementError( + "Execution placement can not be unambiguously inferred " + "from input arguments." 
+ ) + res_usm_type = dpctl.utils.get_coerced_usm_type( + ( + o1_usm_type, + o2_usm_type, + ) + ) + dpctl.utils.validate_usm_type(res_usm_type, allow_none=False) + o1_shape = o1.shape + o2_shape = _get_shape(o2) + if not isinstance(o2_shape, (tuple, list)): + raise TypeError( + "Shape of second argument can not be inferred. " + "Expected list or tuple." + ) + try: + res_shape = _broadcast_shape_impl( + [ + o1_shape, + o2_shape, + ] + ) + except ValueError: + raise ValueError( + "operands could not be broadcast together with shapes " + f"{o1_shape} and {o2_shape}" + ) + + if res_shape != o1_shape: + raise ValueError( + "The shape of the non-broadcastable left-hand " + f"side {o1_shape} is inconsistent with the " + f"broadcast shape {res_shape}." + ) + + sycl_dev = exec_q.sycl_device + o1_dtype = o1.dtype + o2_dtype = _get_dtype(o2, sycl_dev) + if not _validate_dtype(o2_dtype): + raise ValueError("Operand has an unsupported data type") + + o1_dtype, o2_dtype = self.weak_type_resolver_( + o1_dtype, o2_dtype, sycl_dev + ) + + buf_dt, res_dt = _find_buf_dtype_in_place_op( + o1_dtype, + o2_dtype, + self.result_type_resolver_fn_, + sycl_dev, + ) + + if res_dt is None: + raise ValueError( + f"function '{self.name_}' does not support input types " + f"({o1_dtype}, {o2_dtype}), " + "and the inputs could not be safely coerced to any " + "supported types according to the casting rule " + "''same_kind''." 
+ ) + + if res_dt != o1_dtype: + raise ValueError( + f"Output array of type {res_dt} is needed, " f"got {o1_dtype}" + ) + + _manager = SequentialOrderManager[exec_q] + if isinstance(o2, dpt.usm_ndarray): + src2 = o2 + if ( + ti._array_overlap(o2, o1) + and not ti._same_logical_tensors(o2, o1) + and buf_dt is None + ): + buf_dt = o2_dtype + else: + src2 = dpt_ext.asarray(o2, dtype=o2_dtype, sycl_queue=exec_q) + if buf_dt is None: + if src2.shape != res_shape: + src2 = dpt_ext.broadcast_to(src2, res_shape) + dep_evs = _manager.submitted_events + ht_, comp_ev = self.binary_inplace_fn_( + lhs=o1, + rhs=src2, + sycl_queue=exec_q, + depends=dep_evs, + ) + _manager.add_event_pair(ht_, comp_ev) + else: + buf = dpt_ext.empty_like(src2, dtype=buf_dt) + dep_evs = _manager.submitted_events + ( + ht_copy_ev, + copy_ev, + ) = ti._copy_usm_ndarray_into_usm_ndarray( + src=src2, + dst=buf, + sycl_queue=exec_q, + depends=dep_evs, + ) + _manager.add_event_pair(ht_copy_ev, copy_ev) + + buf = dpt_ext.broadcast_to(buf, res_shape) + ht_, bf_ev = self.binary_inplace_fn_( + lhs=o1, + rhs=buf, + sycl_queue=exec_q, + depends=[copy_ev], + ) + _manager.add_event_pair(ht_, bf_ev) + + return o1 diff --git a/dpctl_ext/tensor/_elementwise_funcs.py b/dpctl_ext/tensor/_elementwise_funcs.py index ae0ef8aa349..08d59d8289a 100644 --- a/dpctl_ext/tensor/_elementwise_funcs.py +++ b/dpctl_ext/tensor/_elementwise_funcs.py @@ -30,7 +30,7 @@ # when dpnp fully migrates dpctl/tensor import dpctl_ext.tensor._tensor_elementwise_impl as ti -from ._elementwise_common import UnaryElementwiseFunc +from ._elementwise_common import BinaryElementwiseFunc, UnaryElementwiseFunc from ._type_utils import ( _acceptance_fn_negative, _acceptance_fn_reciprocal, @@ -124,6 +124,41 @@ ) del _acosh_docstring +# B01: ===== ADD (x1, x2) + +_add_docstring_ = r""" +add(x1, x2, /, \*, out=None, order='K') + +Calculates the sum for each element `x1_i` of the input array `x1` with +the respective element `x2_i` of the input array `x2`. 
+ +Args: + x1 (usm_ndarray): + First input array. May have any data type. + x2 (usm_ndarray): + Second input array. May have any data type. + out (Union[usm_ndarray, None], optional): + Output array to populate. + Array must have the correct shape and the expected data type. + order ("C","F","A","K", optional): + Memory layout of the new output array, if parameter + `out` is ``None``. + Default: "K". + +Returns: + usm_ndarray: + An array containing the element-wise sums. The data type of the + returned array is determined by the Type Promotion Rules. +""" +add = BinaryElementwiseFunc( + "add", + ti._add_result_type, + ti._add, + _add_docstring_, + binary_inplace_fn=ti._add_inplace, +) +del _add_docstring_ + # U04: ===== ASIN (x) _asin_docstring = r""" asin(x, /, \*, out=None, order='K') @@ -211,6 +246,41 @@ ) del _atan_docstring +# B02: ===== ATAN2 (x1, x2) +_atan2_docstring_ = r""" +atan2(x1, x2, /, \*, out=None, order='K') + +Calculates the inverse tangent of the quotient `x1_i/x2_i` for each element +`x1_i` of the input array `x1` with the respective element `x2_i` of the +input array `x2`. Each element-wise result is expressed in radians. + +Args: + x1 (usm_ndarray): + First input array, expected to have a real-valued floating-point + data type. + x2 (usm_ndarray): + Second input array, also expected to have a real-valued + floating-point data type. + out (Union[usm_ndarray, None], optional): + Output array to populate. + Array must have the correct shape and the expected data type. + order ("C","F","A","K", optional): + Memory layout of the new output array, if parameter + `out` is ``None``. + Default: "K". + +Returns: + usm_ndarray: + An array containing the inverse tangent of the quotient `x1`/`x2`. + The returned array must have a real-valued floating-point data type + determined by Type Promotion Rules. 
+""" + +atan2 = BinaryElementwiseFunc( + "atan2", ti._atan2_result_type, ti._atan2, _atan2_docstring_ +) +del _atan2_docstring_ + # U07: ===== ATANH (x) _atanh_docstring = r""" atanh(x, /, \*, out=None, order='K') @@ -240,6 +310,80 @@ ) del _atanh_docstring +# B03: ===== BITWISE_AND (x1, x2) +_bitwise_and_docstring_ = r""" +bitwise_and(x1, x2, /, \*, out=None, order='K') + +Computes the bitwise AND of the underlying binary representation of each +element `x1_i` of the input array `x1` with the respective element `x2_i` +of the input array `x2`. + +Args: + x1 (usm_ndarray): + First input array, expected to have integer or boolean data type. + x2 (usm_ndarray): + Second input array, also expected to have integer or boolean data + type. + out (Union[usm_ndarray, None], optional): + Output array to populate. + Array must have the correct shape and the expected data type. + order ("C","F","A","K", optional): + Memory layout of the new output array, if parameter + `out` is ``None``. + Default: "K". + +Returns: + usm_ndarray: + An array containing the element-wise results. The data type + of the returned array is determined by the Type Promotion Rules. +""" + +bitwise_and = BinaryElementwiseFunc( + "bitwise_and", + ti._bitwise_and_result_type, + ti._bitwise_and, + _bitwise_and_docstring_, + binary_inplace_fn=ti._bitwise_and_inplace, +) +del _bitwise_and_docstring_ + +# B04: ===== BITWISE_LEFT_SHIFT (x1, x2) +_bitwise_left_shift_docstring_ = r""" +bitwise_left_shift(x1, x2, /, \*, out=None, order='K') + +Shifts the bits of each element `x1_i` of the input array x1 to the left by +appending `x2_i` (i.e., the respective element in the input array `x2`) zeros to +the right of `x1_i`. + +Args: + x1 (usm_ndarray): + First input array, expected to have integer data type. + x2 (usm_ndarray): + Second input array, also expected to have integer data type. + Each element must be greater than or equal to 0. + out (Union[usm_ndarray, None], optional): + Output array to populate. 
+ Array must have the correct shape and the expected data type. + order ("C","F","A","K", optional): + Memory layout of the new output array, if parameter + `out` is ``None``. + Default: "K". + +Returns: + usm_ndarray: + An array containing the element-wise results. The data type + of the returned array is determined by the Type Promotion Rules. +""" + +bitwise_left_shift = BinaryElementwiseFunc( + "bitwise_left_shift", + ti._bitwise_left_shift_result_type, + ti._bitwise_left_shift, + _bitwise_left_shift_docstring_, + binary_inplace_fn=ti._bitwise_left_shift_inplace, +) +del _bitwise_left_shift_docstring_ + # U08: ===== BITWISE_INVERT (x) _bitwise_invert_docstring = r""" bitwise_invert(x, /, \*, out=None, order='K') @@ -272,6 +416,117 @@ ) del _bitwise_invert_docstring +# B05: ===== BITWISE_OR (x1, x2) +_bitwise_or_docstring_ = r""" +bitwise_or(x1, x2, /, \*, out=None, order='K') + +Computes the bitwise OR of the underlying binary representation of each +element `x1_i` of the input array `x1` with the respective element `x2_i` +of the input array `x2`. + +Args: + x1 (usm_ndarray): + First input array, expected to have integer or boolean data type. + x2 (usm_ndarray): + Second input array, also expected to have integer or boolean data + type. + out (Union[usm_ndarray, None], optional): + Output array to populate. + Array must have the correct shape and the expected data type. + order ("C","F","A","K", optional): + Memory layout of the new output array, if parameter + `out` is ``None``. + Default: "K". + +Returns: + usm_ndarray: + An array containing the element-wise results. The data type + of the returned array is determined by the Type Promotion Rules. 
+""" + +bitwise_or = BinaryElementwiseFunc( + "bitwise_or", + ti._bitwise_or_result_type, + ti._bitwise_or, + _bitwise_or_docstring_, + binary_inplace_fn=ti._bitwise_or_inplace, +) +del _bitwise_or_docstring_ + +# B06: ===== BITWISE_RIGHT_SHIFT (x1, x2) +_bitwise_right_shift_docstring_ = r""" +bitwise_right_shift(x1, x2, /, \*, out=None, order='K') + +Shifts the bits of each element `x1_i` of the input array `x1` to the right +according to the respective element `x2_i` of the input array `x2`. + +Args: + x1 (usm_ndarray): + First input array, expected to have integer data type. + x2 (usm_ndarray): + Second input array, also expected to have integer data type. + Each element must be greater than or equal to 0. + out (Union[usm_ndarray, None], optional): + Output array to populate. + Array must have the correct shape and the expected data type. + order ("C","F","A","K", optional): + Memory layout of the new output array, if parameter + `out` is ``None``. + Default: "K". + +Returns: + usm_ndarray: + An array containing the element-wise results. The data type + of the returned array is determined by the Type Promotion Rules. +""" + +bitwise_right_shift = BinaryElementwiseFunc( + "bitwise_right_shift", + ti._bitwise_right_shift_result_type, + ti._bitwise_right_shift, + _bitwise_right_shift_docstring_, + binary_inplace_fn=ti._bitwise_right_shift_inplace, +) +del _bitwise_right_shift_docstring_ + + +# B07: ===== BITWISE_XOR (x1, x2) +_bitwise_xor_docstring_ = r""" +bitwise_xor(x1, x2, /, \*, out=None, order='K') + +Computes the bitwise XOR of the underlying binary representation of each +element `x1_i` of the input array `x1` with the respective element `x2_i` +of the input array `x2`. + +Args: + x1 (usm_ndarray): + First input array, expected to have integer or boolean data type. + x2 (usm_ndarray): + Second input array, also expected to have integer or boolean data + type. + out (Union[usm_ndarray, None], optional): + Output array to populate. 
+ Array must have the correct shape and the expected data type. + order ("C","F","A","K", optional): + Memory layout of the new output array, if parameter + `out` is ``None``. + Default: "K". + +Returns: + usm_ndarray: + An array containing the element-wise results. The data type + of the returned array is determined by the Type Promotion Rules. +""" + +bitwise_xor = BinaryElementwiseFunc( + "bitwise_xor", + ti._bitwise_xor_result_type, + ti._bitwise_xor, + _bitwise_xor_docstring_, + binary_inplace_fn=ti._bitwise_xor_inplace, +) +del _bitwise_xor_docstring_ + # U09: ==== CEIL (x) _ceil_docstring = r""" ceil(x, /, \*, out=None, order='K') diff --git a/dpctl_ext/tensor/libtensor/include/kernels/elementwise_functions/add.hpp b/dpctl_ext/tensor/libtensor/include/kernels/elementwise_functions/add.hpp new file mode 100644 index 00000000000..1b7440304f0 --- /dev/null +++ b/dpctl_ext/tensor/libtensor/include/kernels/elementwise_functions/add.hpp @@ -0,0 +1,688 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. 
+// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** +// +//===---------------------------------------------------------------------===// +/// +/// \file +/// This file defines kernels for elementwise evaluation of ADD(x1, x2) +/// function. 
+//===---------------------------------------------------------------------===// + +#pragma once +#include +#include +#include +#include +#include + +#include + +#include "sycl_complex.hpp" +#include "vec_size_util.hpp" + +#include "kernels/dpctl_tensor_types.hpp" +#include "kernels/elementwise_functions/common.hpp" +#include "kernels/elementwise_functions/common_inplace.hpp" + +#include "utils/type_dispatch_building.hpp" +#include "utils/type_utils.hpp" + +namespace dpctl::tensor::kernels::add +{ + +using dpctl::tensor::ssize_t; +namespace td_ns = dpctl::tensor::type_dispatch; +namespace tu_ns = dpctl::tensor::type_utils; + +template +struct AddFunctor +{ + + using supports_sg_loadstore = std::negation< + std::disjunction, tu_ns::is_complex>>; + using supports_vec = std::negation< + std::disjunction, tu_ns::is_complex>>; + + resT operator()(const argT1 &in1, const argT2 &in2) const + { + if constexpr (tu_ns::is_complex::value && + tu_ns::is_complex::value) + { + using rT1 = typename argT1::value_type; + using rT2 = typename argT2::value_type; + + return exprm_ns::complex(in1) + exprm_ns::complex(in2); + } + else if constexpr (tu_ns::is_complex::value && + !tu_ns::is_complex::value) + { + using rT1 = typename argT1::value_type; + + return exprm_ns::complex(in1) + in2; + } + else if constexpr (!tu_ns::is_complex::value && + tu_ns::is_complex::value) + { + using rT2 = typename argT2::value_type; + + return in1 + exprm_ns::complex(in2); + } + else { + return in1 + in2; + } + } + + template + sycl::vec + operator()(const sycl::vec &in1, + const sycl::vec &in2) const + { + auto tmp = in1 + in2; + if constexpr (std::is_same_v) { + return tmp; + } + else { + using dpctl::tensor::type_utils::vec_cast; + + return vec_cast( + tmp); + } + } +}; + +template +using AddContigFunctor = + elementwise_common::BinaryContigFunctor, + vec_sz, + n_vecs, + enable_sg_loadstore>; + +template +using AddStridedFunctor = + elementwise_common::BinaryStridedFunctor>; + +template +struct 
AddOutputType +{ + using value_type = typename std::disjunction< + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + T2, + std::complex, + std::complex>, + td_ns::BinaryTypeMapResultEntry, + T2, + std::complex, + std::complex>, + td_ns::DefaultResultEntry>::result_type; + + static constexpr bool is_defined = !std::is_same_v; +}; + +namespace hyperparam_detail +{ + +namespace vsu_ns = dpctl::tensor::kernels::vec_size_utils; + +using vsu_ns::BinaryContigHyperparameterSetEntry; +using vsu_ns::ContigHyperparameterSetDefault; + +template +struct AddContigHyperparameterSet +{ + using value_type = typename std::disjunction< + BinaryContigHyperparameterSetEntry, + BinaryContigHyperparameterSetEntry, + BinaryContigHyperparameterSetEntry, + BinaryContigHyperparameterSetEntry, + BinaryContigHyperparameterSetEntry, + BinaryContigHyperparameterSetEntry, + ContigHyperparameterSetDefault<4u, 2u>>; + + constexpr static auto vec_sz = value_type::vec_sz; + constexpr static auto n_vecs = value_type::n_vecs; +}; + +} // end of namespace hyperparam_detail + +template +class add_contig_kernel; + +template +sycl::event add_contig_impl(sycl::queue &exec_q, + std::size_t nelems, + const char *arg1_p, + ssize_t arg1_offset, + const char *arg2_p, + ssize_t arg2_offset, + char *res_p, + ssize_t res_offset, + const std::vector &depends = {}) +{ + using AddHS = hyperparam_detail::AddContigHyperparameterSet; + static constexpr auto vec_sz = AddHS::vec_sz; + static constexpr auto n_vecs = AddHS::n_vecs; + + return elementwise_common::binary_contig_impl< + argTy1, argTy2, AddOutputType, 
AddContigFunctor, add_contig_kernel, + vec_sz, n_vecs>(exec_q, nelems, arg1_p, arg1_offset, arg2_p, + arg2_offset, res_p, res_offset, depends); +} + +template +struct AddContigFactory +{ + fnT get() + { + if constexpr (!AddOutputType::is_defined) { + fnT fn = nullptr; + return fn; + } + else { + fnT fn = add_contig_impl; + return fn; + } + } +}; + +template +struct AddTypeMapFactory +{ + /*! @brief get typeid for output type of std::add(T1 x, T2 y) */ + std::enable_if_t::value, int> get() + { + using rT = typename AddOutputType::value_type; + return td_ns::GetTypeid{}.get(); + } +}; + +template +class add_strided_kernel; + +template +sycl::event add_strided_impl(sycl::queue &exec_q, + std::size_t nelems, + int nd, + const ssize_t *shape_and_strides, + const char *arg1_p, + ssize_t arg1_offset, + const char *arg2_p, + ssize_t arg2_offset, + char *res_p, + ssize_t res_offset, + const std::vector &depends, + const std::vector &additional_depends) +{ + return elementwise_common::binary_strided_impl< + argTy1, argTy2, AddOutputType, AddStridedFunctor, add_strided_kernel>( + exec_q, nelems, nd, shape_and_strides, arg1_p, arg1_offset, arg2_p, + arg2_offset, res_p, res_offset, depends, additional_depends); +} + +template +struct AddStridedFactory +{ + fnT get() + { + if constexpr (!AddOutputType::is_defined) { + fnT fn = nullptr; + return fn; + } + else { + fnT fn = add_strided_impl; + return fn; + } + } +}; + +template +class add_matrix_row_broadcast_sg_krn; + +template +using AddContigMatrixContigRowBroadcastingFunctor = + elementwise_common::BinaryContigMatrixContigRowBroadcastingFunctor< + argT1, + argT2, + resT, + AddFunctor>; + +template +sycl::event add_contig_matrix_contig_row_broadcast_impl( + sycl::queue &exec_q, + std::vector &host_tasks, + std::size_t n0, + std::size_t n1, + const char *mat_p, // typeless pointer to (n0, n1) C-contiguous matrix + ssize_t mat_offset, + const char *vec_p, // typeless pointer to (n1,) contiguous row + ssize_t vec_offset, + char 
*res_p, // typeless pointer to (n0, n1) result C-contig. matrix, + // res[i,j] = mat[i,j] + vec[j] + ssize_t res_offset, + const std::vector &depends = {}) +{ + return elementwise_common::binary_contig_matrix_contig_row_broadcast_impl< + argT1, argT2, resT, AddContigMatrixContigRowBroadcastingFunctor, + add_matrix_row_broadcast_sg_krn>(exec_q, host_tasks, n0, n1, mat_p, + mat_offset, vec_p, vec_offset, res_p, + res_offset, depends); +} + +template +struct AddContigMatrixContigRowBroadcastFactory +{ + fnT get() + { + if constexpr (!AddOutputType::is_defined) { + fnT fn = nullptr; + return fn; + } + else { + using resT = typename AddOutputType::value_type; + if constexpr (dpctl::tensor::type_utils::is_complex::value || + dpctl::tensor::type_utils::is_complex::value || + dpctl::tensor::type_utils::is_complex::value) + { + fnT fn = nullptr; + return fn; + } + else { + fnT fn = + add_contig_matrix_contig_row_broadcast_impl; + return fn; + } + } + } +}; + +template +sycl::event add_contig_row_contig_matrix_broadcast_impl( + sycl::queue &exec_q, + std::vector &host_tasks, + std::size_t n0, + std::size_t n1, + const char *vec_p, // typeless pointer to (n1,) contiguous row + ssize_t vec_offset, + const char *mat_p, // typeless pointer to (n0, n1) C-contiguous matrix + ssize_t mat_offset, + char *res_p, // typeless pointer to (n0, n1) result C-contig. 
matrix, + // res[i,j] = mat[i,j] + vec[j] + ssize_t res_offset, + const std::vector &depends = {}) +{ + return add_contig_matrix_contig_row_broadcast_impl( + exec_q, host_tasks, n0, n1, mat_p, mat_offset, vec_p, vec_offset, res_p, + res_offset, depends); +}; + +template +struct AddContigRowContigMatrixBroadcastFactory +{ + fnT get() + { + if constexpr (!AddOutputType::is_defined) { + fnT fn = nullptr; + return fn; + } + else { + using resT = typename AddOutputType::value_type; + if constexpr (dpctl::tensor::type_utils::is_complex::value || + dpctl::tensor::type_utils::is_complex::value || + dpctl::tensor::type_utils::is_complex::value) + { + fnT fn = nullptr; + return fn; + } + else { + fnT fn = + add_contig_row_contig_matrix_broadcast_impl; + return fn; + } + } + } +}; + +template +struct AddInplaceFunctor +{ + + using supports_sg_loadstore = std::negation< + std::disjunction, tu_ns::is_complex>>; + using supports_vec = std::negation< + std::disjunction, tu_ns::is_complex>>; + + void operator()(resT &res, const argT &in) + { + res += in; + } + + template + void operator()(sycl::vec &res, + const sycl::vec &in) + { + res += in; + } +}; + +template +using AddInplaceContigFunctor = elementwise_common::BinaryInplaceContigFunctor< + argT, + resT, + AddInplaceFunctor, + vec_sz, + n_vecs, + enable_sg_loadstore>; + +template +using AddInplaceStridedFunctor = + elementwise_common::BinaryInplaceStridedFunctor< + argT, + resT, + IndexerT, + AddInplaceFunctor>; + +template +class add_inplace_contig_kernel; + +/* @brief Types supported by in-place add */ +template +struct AddInplaceTypePairSupport +{ + /* value if true a kernel for must be instantiated */ + static constexpr bool is_defined = std::disjunction< + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + 
td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + resTy, + std::complex>, + td_ns::TypePairDefinedEntry, + resTy, + std::complex>, + // fall-through + td_ns::NotDefinedEntry>::is_defined; +}; + +template +struct AddInplaceTypeMapFactory +{ + /*! @brief get typeid for output type of x += y */ + std::enable_if_t::value, int> get() + { + if constexpr (AddInplaceTypePairSupport::is_defined) { + return td_ns::GetTypeid{}.get(); + } + else { + return td_ns::GetTypeid{}.get(); + } + } +}; + +template +sycl::event + add_inplace_contig_impl(sycl::queue &exec_q, + std::size_t nelems, + const char *arg_p, + ssize_t arg_offset, + char *res_p, + ssize_t res_offset, + const std::vector &depends = {}) +{ + static constexpr auto vec_sz = + hyperparam_detail::AddContigHyperparameterSet::vec_sz; + static constexpr auto n_vecs = + hyperparam_detail::AddContigHyperparameterSet::n_vecs; + + return elementwise_common::binary_inplace_contig_impl< + argTy, resTy, AddInplaceContigFunctor, add_inplace_contig_kernel, + vec_sz, n_vecs>(exec_q, nelems, arg_p, arg_offset, res_p, res_offset, + depends); +} + +template +struct AddInplaceContigFactory +{ + fnT get() + { + if constexpr (!AddInplaceTypePairSupport::is_defined) { + fnT fn = nullptr; + return fn; + } + else { + fnT fn = add_inplace_contig_impl; + return fn; + } + } +}; + +template +class add_inplace_strided_kernel; + +template +sycl::event + add_inplace_strided_impl(sycl::queue &exec_q, + std::size_t nelems, + int nd, + const ssize_t *shape_and_strides, + const char *arg_p, + ssize_t arg_offset, + char *res_p, + ssize_t res_offset, + const std::vector &depends, + const std::vector &additional_depends) +{ + return elementwise_common::binary_inplace_strided_impl< + argTy, resTy, AddInplaceStridedFunctor, add_inplace_strided_kernel>( + exec_q, nelems, nd, shape_and_strides, arg_p, arg_offset, res_p, + res_offset, depends, 
additional_depends); +} + +template +struct AddInplaceStridedFactory +{ + fnT get() + { + if constexpr (!AddInplaceTypePairSupport::is_defined) { + fnT fn = nullptr; + return fn; + } + else { + fnT fn = add_inplace_strided_impl; + return fn; + } + } +}; + +template +class add_inplace_row_matrix_broadcast_sg_krn; + +template +using AddInplaceRowMatrixBroadcastingFunctor = + elementwise_common::BinaryInplaceRowMatrixBroadcastingFunctor< + argT, + resT, + AddInplaceFunctor>; + +template +sycl::event add_inplace_row_matrix_broadcast_impl( + sycl::queue &exec_q, + std::vector &host_tasks, + std::size_t n0, + std::size_t n1, + const char *vec_p, // typeless pointer to (n1,) contiguous row + ssize_t vec_offset, + char *mat_p, // typeless pointer to (n0, n1) C-contiguous matrix + ssize_t mat_offset, + const std::vector &depends = {}) +{ + return elementwise_common::binary_inplace_row_matrix_broadcast_impl< + argT, resT, AddInplaceRowMatrixBroadcastingFunctor, + add_inplace_row_matrix_broadcast_sg_krn>(exec_q, host_tasks, n0, n1, + vec_p, vec_offset, mat_p, + mat_offset, depends); +} + +template +struct AddInplaceRowMatrixBroadcastFactory +{ + fnT get() + { + if constexpr (!AddInplaceTypePairSupport::is_defined) { + fnT fn = nullptr; + return fn; + } + else { + if constexpr (dpctl::tensor::type_utils::is_complex::value || + dpctl::tensor::type_utils::is_complex::value) + { + fnT fn = nullptr; + return fn; + } + else { + fnT fn = add_inplace_row_matrix_broadcast_impl; + return fn; + } + } + } +}; + +} // namespace dpctl::tensor::kernels::add diff --git a/dpctl_ext/tensor/libtensor/include/kernels/elementwise_functions/atan2.hpp b/dpctl_ext/tensor/libtensor/include/kernels/elementwise_functions/atan2.hpp new file mode 100644 index 00000000000..3da1b828d0e --- /dev/null +++ b/dpctl_ext/tensor/libtensor/include/kernels/elementwise_functions/atan2.hpp @@ -0,0 +1,232 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel 
Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** +// +//===---------------------------------------------------------------------===// +/// +/// \file +/// This file defines kernels for elementwise evaluation of ATAN2(x1, x2) +/// function. 
+//===---------------------------------------------------------------------===// + +#pragma once +#include +#include +#include +#include + +#include + +#include "vec_size_util.hpp" + +#include "kernels/dpctl_tensor_types.hpp" +#include "kernels/elementwise_functions/common.hpp" + +#include "utils/type_dispatch_building.hpp" + +namespace dpctl::tensor::kernels::atan2 +{ + +using dpctl::tensor::ssize_t; +namespace td_ns = dpctl::tensor::type_dispatch; + +template +struct Atan2Functor +{ + + using supports_sg_loadstore = std::true_type; + using supports_vec = std::false_type; + + resT operator()(const argT1 &in1, const argT2 &in2) const + { + if (std::isinf(in2) && !sycl::signbit(in2)) { + if (std::isfinite(in1)) { + return sycl::copysign(resT(0), in1); + } + } + return sycl::atan2(in1, in2); + } +}; + +template +using Atan2ContigFunctor = + elementwise_common::BinaryContigFunctor, + vec_sz, + n_vecs, + enable_sg_loadstore>; + +template +using Atan2StridedFunctor = + elementwise_common::BinaryStridedFunctor>; + +template +struct Atan2OutputType +{ + using value_type = typename std::disjunction< + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::DefaultResultEntry>::result_type; + + static constexpr bool is_defined = !std::is_same_v; +}; + +namespace hyperparam_detail +{ + +namespace vsu_ns = dpctl::tensor::kernels::vec_size_utils; + +using vsu_ns::BinaryContigHyperparameterSetEntry; +using vsu_ns::ContigHyperparameterSetDefault; + +template +struct Atan2ContigHyperparameterSet +{ + using value_type = + typename std::disjunction>; + + constexpr static auto vec_sz = value_type::vec_sz; + constexpr static auto n_vecs = value_type::n_vecs; +}; + +} // end of namespace hyperparam_detail + +template +class atan2_contig_kernel; + +template +sycl::event atan2_contig_impl(sycl::queue &exec_q, + std::size_t nelems, + const char *arg1_p, + ssize_t arg1_offset, + const char *arg2_p, + ssize_t arg2_offset, + char 
*res_p, + ssize_t res_offset, + const std::vector &depends = {}) +{ + using Atan2HS = + hyperparam_detail::Atan2ContigHyperparameterSet; + static constexpr std::uint8_t vec_sz = Atan2HS::vec_sz; + static constexpr std::uint8_t n_vecs = Atan2HS::n_vecs; + + return elementwise_common::binary_contig_impl< + argTy1, argTy2, Atan2OutputType, Atan2ContigFunctor, + atan2_contig_kernel, vec_sz, n_vecs>(exec_q, nelems, arg1_p, + arg1_offset, arg2_p, arg2_offset, + res_p, res_offset, depends); +} + +template +struct Atan2ContigFactory +{ + fnT get() + { + if constexpr (!Atan2OutputType::is_defined) { + fnT fn = nullptr; + return fn; + } + else { + fnT fn = atan2_contig_impl; + return fn; + } + } +}; + +template +struct Atan2TypeMapFactory +{ + /*! @brief get typeid for output type of sycl::atan2(T1 x, T2 y) */ + std::enable_if_t::value, int> get() + { + using rT = typename Atan2OutputType::value_type; + return td_ns::GetTypeid{}.get(); + } +}; + +template +class atan2_strided_kernel; + +template +sycl::event + atan2_strided_impl(sycl::queue &exec_q, + std::size_t nelems, + int nd, + const ssize_t *shape_and_strides, + const char *arg1_p, + ssize_t arg1_offset, + const char *arg2_p, + ssize_t arg2_offset, + char *res_p, + ssize_t res_offset, + const std::vector &depends, + const std::vector &additional_depends) +{ + return elementwise_common::binary_strided_impl< + argTy1, argTy2, Atan2OutputType, Atan2StridedFunctor, + atan2_strided_kernel>(exec_q, nelems, nd, shape_and_strides, arg1_p, + arg1_offset, arg2_p, arg2_offset, res_p, + res_offset, depends, additional_depends); +} + +template +struct Atan2StridedFactory +{ + fnT get() + { + if constexpr (!Atan2OutputType::is_defined) { + fnT fn = nullptr; + return fn; + } + else { + fnT fn = atan2_strided_impl; + return fn; + } + } +}; + +} // namespace dpctl::tensor::kernels::atan2 diff --git a/dpctl_ext/tensor/libtensor/include/kernels/elementwise_functions/bitwise_and.hpp 
b/dpctl_ext/tensor/libtensor/include/kernels/elementwise_functions/bitwise_and.hpp new file mode 100644 index 00000000000..f3f25c45e36 --- /dev/null +++ b/dpctl_ext/tensor/libtensor/include/kernels/elementwise_functions/bitwise_and.hpp @@ -0,0 +1,466 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. 
+//***************************************************************************** +// +//===---------------------------------------------------------------------===// +/// +/// \file +/// This file defines kernels for elementwise bitwise_and(ar1, ar2) operation. +//===---------------------------------------------------------------------===// + +#pragma once +#include +#include +#include +#include + +#include + +#include "vec_size_util.hpp" + +#include "kernels/dpctl_tensor_types.hpp" +#include "kernels/elementwise_functions/common.hpp" +#include "kernels/elementwise_functions/common_inplace.hpp" + +#include "utils/type_dispatch_building.hpp" +#include "utils/type_utils.hpp" + +namespace dpctl::tensor::kernels::bitwise_and +{ + +using dpctl::tensor::ssize_t; +namespace td_ns = dpctl::tensor::type_dispatch; +namespace tu_ns = dpctl::tensor::type_utils; + +template +struct BitwiseAndFunctor +{ + static_assert(std::is_same_v); + static_assert(std::is_same_v); + + using supports_sg_loadstore = typename std::true_type; + using supports_vec = typename std::true_type; + + resT operator()(const argT1 &in1, const argT2 &in2) const + { + using tu_ns::convert_impl; + + if constexpr (std::is_same_v) { + return in1 && in2; + } + else { + return (in1 & in2); + } + } + + template + sycl::vec + operator()(const sycl::vec &in1, + const sycl::vec &in2) const + { + + if constexpr (std::is_same_v) { + using dpctl::tensor::type_utils::vec_cast; + + auto tmp = (in1 && in2); + return vec_cast( + tmp); + } + else { + return (in1 & in2); + } + } +}; + +template +using BitwiseAndContigFunctor = elementwise_common::BinaryContigFunctor< + argT1, + argT2, + resT, + BitwiseAndFunctor, + vec_sz, + n_vecs, + enable_sg_loadstore>; + +template +using BitwiseAndStridedFunctor = elementwise_common::BinaryStridedFunctor< + argT1, + argT2, + resT, + IndexerT, + BitwiseAndFunctor>; + +template +struct BitwiseAndOutputType +{ + using value_type = typename std::disjunction< + 
td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::DefaultResultEntry>::result_type; + + static constexpr bool is_defined = !std::is_same_v; +}; + +namespace hyperparam_detail +{ + +namespace vsu_ns = dpctl::tensor::kernels::vec_size_utils; + +using vsu_ns::BinaryContigHyperparameterSetEntry; +using vsu_ns::ContigHyperparameterSetDefault; + +template +struct BitwiseAndContigHyperparameterSet +{ + using value_type = + typename std::disjunction>; + + constexpr static auto vec_sz = value_type::vec_sz; + constexpr static auto n_vecs = value_type::n_vecs; +}; +} // end of namespace hyperparam_detail + +template +class bitwise_and_contig_kernel; + +template +sycl::event + bitwise_and_contig_impl(sycl::queue &exec_q, + std::size_t nelems, + const char *arg1_p, + ssize_t arg1_offset, + const char *arg2_p, + ssize_t arg2_offset, + char *res_p, + ssize_t res_offset, + const std::vector &depends = {}) +{ + using BitwiseAndHS = + hyperparam_detail::BitwiseAndContigHyperparameterSet; + static constexpr std::uint8_t vec_sz = BitwiseAndHS::vec_sz; + static constexpr std::uint8_t n_vec = BitwiseAndHS::n_vecs; + + return elementwise_common::binary_contig_impl< + argTy1, argTy2, BitwiseAndOutputType, BitwiseAndContigFunctor, + bitwise_and_contig_kernel, vec_sz, n_vec>( + exec_q, nelems, arg1_p, arg1_offset, arg2_p, arg2_offset, res_p, + res_offset, depends); +} + +template +struct BitwiseAndContigFactory +{ + fnT get() + { + if constexpr (!BitwiseAndOutputType::is_defined) { + fnT fn = nullptr; + return fn; + } + else { + fnT fn = bitwise_and_contig_impl; + return fn; + } + } +}; + +template +struct BitwiseAndTypeMapFactory +{ + /*! 
@brief get typeid for output type of operator()>(x, y), always bool + */ + std::enable_if_t::value, int> get() + { + using rT = typename BitwiseAndOutputType::value_type; + return td_ns::GetTypeid{}.get(); + } +}; + +template +class bitwise_and_strided_kernel; + +template +sycl::event + bitwise_and_strided_impl(sycl::queue &exec_q, + std::size_t nelems, + int nd, + const ssize_t *shape_and_strides, + const char *arg1_p, + ssize_t arg1_offset, + const char *arg2_p, + ssize_t arg2_offset, + char *res_p, + ssize_t res_offset, + const std::vector &depends, + const std::vector &additional_depends) +{ + return elementwise_common::binary_strided_impl< + argTy1, argTy2, BitwiseAndOutputType, BitwiseAndStridedFunctor, + bitwise_and_strided_kernel>( + exec_q, nelems, nd, shape_and_strides, arg1_p, arg1_offset, arg2_p, + arg2_offset, res_p, res_offset, depends, additional_depends); +} + +template +struct BitwiseAndStridedFactory +{ + fnT get() + { + if constexpr (!BitwiseAndOutputType::is_defined) { + fnT fn = nullptr; + return fn; + } + else { + fnT fn = bitwise_and_strided_impl; + return fn; + } + } +}; + +template +struct BitwiseAndInplaceFunctor +{ + using supports_sg_loadstore = typename std::true_type; + using supports_vec = typename std::true_type; + + void operator()(resT &res, const argT &in) const + { + using tu_ns::convert_impl; + + if constexpr (std::is_same_v) { + res = res && in; + } + else { + res &= in; + } + } + + template + void operator()(sycl::vec &res, + const sycl::vec &in) const + { + + if constexpr (std::is_same_v) { + using dpctl::tensor::type_utils::vec_cast; + + auto tmp = (res && in); + res = vec_cast( + tmp); + } + else { + res &= in; + } + } +}; + +template +using BitwiseAndInplaceContigFunctor = + elementwise_common::BinaryInplaceContigFunctor< + argT, + resT, + BitwiseAndInplaceFunctor, + vec_sz, + n_vecs, + enable_sg_loadstore>; + +template +using BitwiseAndInplaceStridedFunctor = + elementwise_common::BinaryInplaceStridedFunctor< + argT, + 
resT, + IndexerT, + BitwiseAndInplaceFunctor>; + +template +class bitwise_and_inplace_contig_kernel; + +/* @brief Types supported by in-place bitwise AND */ +template +struct BitwiseAndInplaceTypePairSupport +{ + /* value if true a kernel for must be instantiated */ + static constexpr bool is_defined = std::disjunction< + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + // fall-through + td_ns::NotDefinedEntry>::is_defined; +}; + +template +struct BitwiseAndInplaceTypeMapFactory +{ + /*! @brief get typeid for output type of x &= y */ + std::enable_if_t::value, int> get() + { + if constexpr (BitwiseAndInplaceTypePairSupport::is_defined) + { + return td_ns::GetTypeid{}.get(); + } + else { + return td_ns::GetTypeid{}.get(); + } + } +}; + +template +sycl::event bitwise_and_inplace_contig_impl( + sycl::queue &exec_q, + std::size_t nelems, + const char *arg_p, + ssize_t arg_offset, + char *res_p, + ssize_t res_offset, + const std::vector &depends = {}) +{ + using BitwiseAndHS = + hyperparam_detail::BitwiseAndContigHyperparameterSet; + static constexpr std::uint8_t vec_sz = BitwiseAndHS::vec_sz; + static constexpr std::uint8_t n_vecs = BitwiseAndHS::n_vecs; + + return elementwise_common::binary_inplace_contig_impl< + argTy, resTy, BitwiseAndInplaceContigFunctor, + bitwise_and_inplace_contig_kernel, vec_sz, n_vecs>( + exec_q, nelems, arg_p, arg_offset, res_p, res_offset, depends); +} + +template +struct BitwiseAndInplaceContigFactory +{ + fnT get() + { + if constexpr (!BitwiseAndInplaceTypePairSupport::is_defined) { + fnT fn = nullptr; + return fn; + } + else { + fnT fn = bitwise_and_inplace_contig_impl; + return fn; + } + } +}; + +template +class bitwise_and_inplace_strided_kernel; + +template +sycl::event bitwise_and_inplace_strided_impl( + 
sycl::queue &exec_q, + std::size_t nelems, + int nd, + const ssize_t *shape_and_strides, + const char *arg_p, + ssize_t arg_offset, + char *res_p, + ssize_t res_offset, + const std::vector &depends, + const std::vector &additional_depends) +{ + return elementwise_common::binary_inplace_strided_impl< + argTy, resTy, BitwiseAndInplaceStridedFunctor, + bitwise_and_inplace_strided_kernel>( + exec_q, nelems, nd, shape_and_strides, arg_p, arg_offset, res_p, + res_offset, depends, additional_depends); +} + +template +struct BitwiseAndInplaceStridedFactory +{ + fnT get() + { + if constexpr (!BitwiseAndInplaceTypePairSupport::is_defined) { + fnT fn = nullptr; + return fn; + } + else { + fnT fn = bitwise_and_inplace_strided_impl; + return fn; + } + } +}; + +} // namespace dpctl::tensor::kernels::bitwise_and diff --git a/dpctl_ext/tensor/libtensor/include/kernels/elementwise_functions/bitwise_left_shift.hpp b/dpctl_ext/tensor/libtensor/include/kernels/elementwise_functions/bitwise_left_shift.hpp new file mode 100644 index 00000000000..549a220fbab --- /dev/null +++ b/dpctl_ext/tensor/libtensor/include/kernels/elementwise_functions/bitwise_left_shift.hpp @@ -0,0 +1,485 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. 
+// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** +// +//===---------------------------------------------------------------------===// +/// +/// \file +/// This file defines kernels for elementwise bitwise_left_shift(ar1, ar2) +/// operation. 
+//===---------------------------------------------------------------------===// + +#pragma once +#include +#include +#include +#include + +#include + +#include "vec_size_util.hpp" + +#include "kernels/dpctl_tensor_types.hpp" +#include "kernels/elementwise_functions/common.hpp" +#include "kernels/elementwise_functions/common_inplace.hpp" + +#include "utils/type_dispatch_building.hpp" + +namespace dpctl::tensor::kernels::bitwise_left_shift +{ + +using dpctl::tensor::ssize_t; +namespace td_ns = dpctl::tensor::type_dispatch; + +template +struct BitwiseLeftShiftFunctor +{ + static_assert(std::is_integral_v); + static_assert(std::is_integral_v); + static_assert(!std::is_same_v); + static_assert(!std::is_same_v); + + using supports_sg_loadstore = typename std::true_type; + using supports_vec = typename std::true_type; + + resT operator()(const argT1 &in1, const argT2 &in2) const + { + return impl(in1, in2); + } + + template + sycl::vec + operator()(const sycl::vec &in1, + const sycl::vec &in2) const + { + sycl::vec res; +#pragma unroll + for (int i = 0; i < vec_sz; ++i) { + res[i] = impl(in1[i], in2[i]); + } + return res; + } + +private: + resT impl(const argT1 &in1, const argT2 &in2) const + { + static constexpr argT2 in1_bitsize = + static_cast(sizeof(argT1) * 8); + static constexpr resT zero = resT(0); + + // bitshift op with second operand negative, or >= bitwidth(argT1) is UB + // array API spec mandates 0 + if constexpr (std::is_unsigned_v) { + return (in2 < in1_bitsize) ? (in1 << in2) : zero; + } + else { + return (in2 < argT2(0)) + ? zero + : ((in2 < in1_bitsize) ? 
(in1 << in2) : zero); + } + } +}; + +template +using BitwiseLeftShiftContigFunctor = elementwise_common::BinaryContigFunctor< + argT1, + argT2, + resT, + BitwiseLeftShiftFunctor, + vec_sz, + n_vecs, + enable_sg_loadstore>; + +template +using BitwiseLeftShiftStridedFunctor = elementwise_common::BinaryStridedFunctor< + argT1, + argT2, + resT, + IndexerT, + BitwiseLeftShiftFunctor>; + +template +struct BitwiseLeftShiftOutputType +{ + using ResT = T1; + using value_type = typename std::disjunction< + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::DefaultResultEntry>::result_type; + + static constexpr bool is_defined = !std::is_same_v; +}; + +namespace hyperparam_detail +{ + +namespace vsu_ns = dpctl::tensor::kernels::vec_size_utils; + +using vsu_ns::BinaryContigHyperparameterSetEntry; +using vsu_ns::ContigHyperparameterSetDefault; + +template +struct BitwiseLeftShiftContigHyperparameterSet +{ + using value_type = + typename std::disjunction>; + + constexpr static auto vec_sz = value_type::vec_sz; + constexpr static auto n_vecs = value_type::n_vecs; +}; + +} // end of namespace hyperparam_detail + +template +class bitwise_left_shift_contig_kernel; + +template +sycl::event + bitwise_left_shift_contig_impl(sycl::queue &exec_q, + std::size_t nelems, + const char *arg1_p, + ssize_t arg1_offset, + const char *arg2_p, + ssize_t arg2_offset, + char *res_p, + ssize_t res_offset, + const std::vector &depends = {}) +{ + using BitwiseLSHS = + hyperparam_detail::BitwiseLeftShiftContigHyperparameterSet; + static constexpr std::uint8_t vec_sz = BitwiseLSHS::vec_sz; + static constexpr std::uint8_t n_vecs = BitwiseLSHS::n_vecs; + + return elementwise_common::binary_contig_impl< + argTy1, argTy2, BitwiseLeftShiftOutputType, + 
BitwiseLeftShiftContigFunctor, bitwise_left_shift_contig_kernel, vec_sz, + n_vecs>(exec_q, nelems, arg1_p, arg1_offset, arg2_p, arg2_offset, res_p, + res_offset, depends); +} + +template +struct BitwiseLeftShiftContigFactory +{ + fnT get() + { + if constexpr (!BitwiseLeftShiftOutputType::is_defined) { + fnT fn = nullptr; + return fn; + } + else { + fnT fn = bitwise_left_shift_contig_impl; + return fn; + } + } +}; + +template +struct BitwiseLeftShiftTypeMapFactory +{ + /*! @brief get typeid for output type of operator()>(x, y), always bool + */ + std::enable_if_t::value, int> get() + { + using rT = typename BitwiseLeftShiftOutputType::value_type; + return td_ns::GetTypeid{}.get(); + } +}; + +template +class bitwise_left_shift_strided_kernel; + +template +sycl::event bitwise_left_shift_strided_impl( + sycl::queue &exec_q, + std::size_t nelems, + int nd, + const ssize_t *shape_and_strides, + const char *arg1_p, + ssize_t arg1_offset, + const char *arg2_p, + ssize_t arg2_offset, + char *res_p, + ssize_t res_offset, + const std::vector &depends, + const std::vector &additional_depends) +{ + return elementwise_common::binary_strided_impl< + argTy1, argTy2, BitwiseLeftShiftOutputType, + BitwiseLeftShiftStridedFunctor, bitwise_left_shift_strided_kernel>( + exec_q, nelems, nd, shape_and_strides, arg1_p, arg1_offset, arg2_p, + arg2_offset, res_p, res_offset, depends, additional_depends); +} + +template +struct BitwiseLeftShiftStridedFactory +{ + fnT get() + { + if constexpr (!BitwiseLeftShiftOutputType::is_defined) { + fnT fn = nullptr; + return fn; + } + else { + fnT fn = bitwise_left_shift_strided_impl; + return fn; + } + } +}; + +template +struct BitwiseLeftShiftInplaceFunctor +{ + static_assert(std::is_integral_v); + static_assert(!std::is_same_v); + + using supports_sg_loadstore = typename std::true_type; + using supports_vec = typename std::true_type; + + void operator()(resT &res, const argT &in) const + { + impl(res, in); + } + + template + void operator()(sycl::vec 
&res, + const sycl::vec &in) const + { +#pragma unroll + for (int i = 0; i < vec_sz; ++i) { + impl(res[i], in[i]); + } + } + +private: + void impl(resT &res, const argT &in) const + { + static constexpr argT res_bitsize = static_cast(sizeof(resT) * 8); + static constexpr resT zero = resT(0); + + // bitshift op with second operand negative, or >= bitwidth(argT1) is UB + // array API spec mandates 0 + if constexpr (std::is_unsigned_v) { + (in < res_bitsize) ? (res <<= in) : res = zero; + } + else { + (in < argT(0)) ? res = zero + : ((in < res_bitsize) ? (res <<= in) : res = zero); + } + } +}; + +template +using BitwiseLeftShiftInplaceContigFunctor = + elementwise_common::BinaryInplaceContigFunctor< + argT, + resT, + BitwiseLeftShiftInplaceFunctor, + vec_sz, + n_vecs, + enable_sg_loadstore>; + +template +using BitwiseLeftShiftInplaceStridedFunctor = + elementwise_common::BinaryInplaceStridedFunctor< + argT, + resT, + IndexerT, + BitwiseLeftShiftInplaceFunctor>; + +template +class bitwise_left_shift_inplace_contig_kernel; + +/* @brief Types supported by in-place bitwise left shift */ +template +struct BitwiseLeftShiftInplaceTypePairSupport +{ + /* value if true a kernel for must be instantiated */ + static constexpr bool is_defined = std::disjunction< + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + // fall-through + td_ns::NotDefinedEntry>::is_defined; +}; + +template +struct BitwiseLeftShiftInplaceTypeMapFactory +{ + /*! 
@brief get typeid for output type of x <<= y */ + std::enable_if_t::value, int> get() + { + if constexpr (BitwiseLeftShiftInplaceTypePairSupport::is_defined) + { + return td_ns::GetTypeid{}.get(); + } + else { + return td_ns::GetTypeid{}.get(); + } + } +}; + +template +sycl::event bitwise_left_shift_inplace_contig_impl( + sycl::queue &exec_q, + std::size_t nelems, + const char *arg_p, + ssize_t arg_offset, + char *res_p, + ssize_t res_offset, + const std::vector &depends = {}) +{ + using BitwiseLSHS = + hyperparam_detail::BitwiseLeftShiftContigHyperparameterSet; + static constexpr std::uint8_t vec_sz = BitwiseLSHS::vec_sz; + static constexpr std::uint8_t n_vecs = BitwiseLSHS::n_vecs; + + return elementwise_common::binary_inplace_contig_impl< + argTy, resTy, BitwiseLeftShiftInplaceContigFunctor, + bitwise_left_shift_inplace_contig_kernel, vec_sz, n_vecs>( + exec_q, nelems, arg_p, arg_offset, res_p, res_offset, depends); +} + +template +struct BitwiseLeftShiftInplaceContigFactory +{ + fnT get() + { + if constexpr (!BitwiseLeftShiftInplaceTypePairSupport::is_defined) { + fnT fn = nullptr; + return fn; + } + else { + fnT fn = bitwise_left_shift_inplace_contig_impl; + return fn; + } + } +}; + +template +class bitwise_left_shift_inplace_strided_kernel; + +template +sycl::event bitwise_left_shift_inplace_strided_impl( + sycl::queue &exec_q, + std::size_t nelems, + int nd, + const ssize_t *shape_and_strides, + const char *arg_p, + ssize_t arg_offset, + char *res_p, + ssize_t res_offset, + const std::vector &depends, + const std::vector &additional_depends) +{ + return elementwise_common::binary_inplace_strided_impl< + argTy, resTy, BitwiseLeftShiftInplaceStridedFunctor, + bitwise_left_shift_inplace_strided_kernel>( + exec_q, nelems, nd, shape_and_strides, arg_p, arg_offset, res_p, + res_offset, depends, additional_depends); +} + +template +struct BitwiseLeftShiftInplaceStridedFactory +{ + fnT get() + { + if constexpr (!BitwiseLeftShiftInplaceTypePairSupport::is_defined) { 
+ fnT fn = nullptr; + return fn; + } + else { + fnT fn = bitwise_left_shift_inplace_strided_impl; + return fn; + } + } +}; + +} // namespace dpctl::tensor::kernels::bitwise_left_shift diff --git a/dpctl_ext/tensor/libtensor/include/kernels/elementwise_functions/bitwise_or.hpp b/dpctl_ext/tensor/libtensor/include/kernels/elementwise_functions/bitwise_or.hpp new file mode 100644 index 00000000000..82532e82531 --- /dev/null +++ b/dpctl_ext/tensor/libtensor/include/kernels/elementwise_functions/bitwise_or.hpp @@ -0,0 +1,466 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** +// +//===---------------------------------------------------------------------===// +/// +/// \file +/// This file defines kernels for elementwise bitwise_or(ar1, ar2) operation. +//===---------------------------------------------------------------------===// + +#pragma once +#include +#include +#include +#include + +#include + +#include "vec_size_util.hpp" + +#include "kernels/dpctl_tensor_types.hpp" +#include "kernels/elementwise_functions/common.hpp" +#include "kernels/elementwise_functions/common_inplace.hpp" + +#include "utils/type_dispatch_building.hpp" +#include "utils/type_utils.hpp" + +namespace dpctl::tensor::kernels::bitwise_or +{ + +using dpctl::tensor::ssize_t; +namespace td_ns = dpctl::tensor::type_dispatch; +namespace tu_ns = dpctl::tensor::type_utils; + +template +struct BitwiseOrFunctor +{ + static_assert(std::is_same_v); + static_assert(std::is_same_v); + + using supports_sg_loadstore = typename std::true_type; + using supports_vec = typename std::true_type; + + resT operator()(const argT1 &in1, const argT2 &in2) const + { + using tu_ns::convert_impl; + + if constexpr (std::is_same_v) { + return in1 || in2; + } + else { + return (in1 | in2); + } + } + + template + sycl::vec + operator()(const sycl::vec &in1, + const sycl::vec &in2) const + { + + if constexpr (std::is_same_v) { + using dpctl::tensor::type_utils::vec_cast; + + auto tmp 
= (in1 || in2); + return vec_cast( + tmp); + } + else { + return (in1 | in2); + } + } +}; + +template +using BitwiseOrContigFunctor = elementwise_common::BinaryContigFunctor< + argT1, + argT2, + resT, + BitwiseOrFunctor, + vec_sz, + n_vecs, + enable_sg_loadstore>; + +template +using BitwiseOrStridedFunctor = elementwise_common::BinaryStridedFunctor< + argT1, + argT2, + resT, + IndexerT, + BitwiseOrFunctor>; + +template +struct BitwiseOrOutputType +{ + using value_type = typename std::disjunction< + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::DefaultResultEntry>::result_type; + + static constexpr bool is_defined = !std::is_same_v; +}; + +namespace hyperparam_detail +{ + +namespace vsu_ns = dpctl::tensor::kernels::vec_size_utils; + +using vsu_ns::BinaryContigHyperparameterSetEntry; +using vsu_ns::ContigHyperparameterSetDefault; + +template +struct BitwiseOrContigHyperparameterSet +{ + using value_type = + typename std::disjunction>; + + constexpr static auto vec_sz = value_type::vec_sz; + constexpr static auto n_vecs = value_type::n_vecs; +}; + +} // end of namespace hyperparam_detail + +template +class bitwise_or_contig_kernel; + +template +sycl::event bitwise_or_contig_impl(sycl::queue &exec_q, + std::size_t nelems, + const char *arg1_p, + ssize_t arg1_offset, + const char *arg2_p, + ssize_t arg2_offset, + char *res_p, + ssize_t res_offset, + const std::vector &depends = {}) +{ + using BitwiseOrHS = + hyperparam_detail::BitwiseOrContigHyperparameterSet; + static constexpr std::uint8_t vec_sz = BitwiseOrHS::vec_sz; + static constexpr std::uint8_t n_vecs = BitwiseOrHS::n_vecs; + + return elementwise_common::binary_contig_impl< + argTy1, argTy2, BitwiseOrOutputType, BitwiseOrContigFunctor, + 
bitwise_or_contig_kernel, vec_sz, n_vecs>( + exec_q, nelems, arg1_p, arg1_offset, arg2_p, arg2_offset, res_p, + res_offset, depends); +} + +template +struct BitwiseOrContigFactory +{ + fnT get() + { + if constexpr (!BitwiseOrOutputType::is_defined) { + fnT fn = nullptr; + return fn; + } + else { + fnT fn = bitwise_or_contig_impl; + return fn; + } + } +}; + +template +struct BitwiseOrTypeMapFactory +{ + /*! @brief get typeid for output type of operator()>(x, y), always bool + */ + std::enable_if_t::value, int> get() + { + using rT = typename BitwiseOrOutputType::value_type; + return td_ns::GetTypeid{}.get(); + } +}; + +template +class bitwise_or_strided_kernel; + +template +sycl::event + bitwise_or_strided_impl(sycl::queue &exec_q, + std::size_t nelems, + int nd, + const ssize_t *shape_and_strides, + const char *arg1_p, + ssize_t arg1_offset, + const char *arg2_p, + ssize_t arg2_offset, + char *res_p, + ssize_t res_offset, + const std::vector &depends, + const std::vector &additional_depends) +{ + return elementwise_common::binary_strided_impl< + argTy1, argTy2, BitwiseOrOutputType, BitwiseOrStridedFunctor, + bitwise_or_strided_kernel>( + exec_q, nelems, nd, shape_and_strides, arg1_p, arg1_offset, arg2_p, + arg2_offset, res_p, res_offset, depends, additional_depends); +} + +template +struct BitwiseOrStridedFactory +{ + fnT get() + { + if constexpr (!BitwiseOrOutputType::is_defined) { + fnT fn = nullptr; + return fn; + } + else { + fnT fn = bitwise_or_strided_impl; + return fn; + } + } +}; + +template +struct BitwiseOrInplaceFunctor +{ + using supports_sg_loadstore = typename std::true_type; + using supports_vec = typename std::true_type; + + void operator()(resT &res, const argT &in) const + { + using tu_ns::convert_impl; + + if constexpr (std::is_same_v) { + res = res || in; + } + else { + res |= in; + } + } + + template + void operator()(sycl::vec &res, + const sycl::vec &in) const + { + + if constexpr (std::is_same_v) { + using 
dpctl::tensor::type_utils::vec_cast; + + auto tmp = (res || in); + res = vec_cast( + tmp); + } + else { + res |= in; + } + } +}; + +template +using BitwiseOrInplaceContigFunctor = + elementwise_common::BinaryInplaceContigFunctor< + argT, + resT, + BitwiseOrInplaceFunctor, + vec_sz, + n_vecs, + enable_sg_loadstore>; + +template +using BitwiseOrInplaceStridedFunctor = + elementwise_common::BinaryInplaceStridedFunctor< + argT, + resT, + IndexerT, + BitwiseOrInplaceFunctor>; + +template +class bitwise_or_inplace_contig_kernel; + +/* @brief Types supported by in-place bitwise OR */ +template +struct BitwiseOrInplaceTypePairSupport +{ + /* value if true a kernel for must be instantiated */ + static constexpr bool is_defined = std::disjunction< + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + // fall-through + td_ns::NotDefinedEntry>::is_defined; +}; + +template +struct BitwiseOrInplaceTypeMapFactory +{ + /*! 
@brief get typeid for output type of x |= y */ + std::enable_if_t::value, int> get() + { + if constexpr (BitwiseOrInplaceTypePairSupport::is_defined) { + return td_ns::GetTypeid{}.get(); + } + else { + return td_ns::GetTypeid{}.get(); + } + } +}; + +template +sycl::event + bitwise_or_inplace_contig_impl(sycl::queue &exec_q, + std::size_t nelems, + const char *arg_p, + ssize_t arg_offset, + char *res_p, + ssize_t res_offset, + const std::vector &depends = {}) +{ + using BitwiseOrHS = + hyperparam_detail::BitwiseOrContigHyperparameterSet; + + static constexpr std::uint8_t vec_sz = BitwiseOrHS::vec_sz; + static constexpr std::uint8_t n_vecs = BitwiseOrHS::n_vecs; + + return elementwise_common::binary_inplace_contig_impl< + argTy, resTy, BitwiseOrInplaceContigFunctor, + bitwise_or_inplace_contig_kernel, vec_sz, n_vecs>( + exec_q, nelems, arg_p, arg_offset, res_p, res_offset, depends); +} + +template +struct BitwiseOrInplaceContigFactory +{ + fnT get() + { + if constexpr (!BitwiseOrInplaceTypePairSupport::is_defined) { + fnT fn = nullptr; + return fn; + } + else { + fnT fn = bitwise_or_inplace_contig_impl; + return fn; + } + } +}; + +template +class bitwise_or_inplace_strided_kernel; + +template +sycl::event bitwise_or_inplace_strided_impl( + sycl::queue &exec_q, + std::size_t nelems, + int nd, + const ssize_t *shape_and_strides, + const char *arg_p, + ssize_t arg_offset, + char *res_p, + ssize_t res_offset, + const std::vector &depends, + const std::vector &additional_depends) +{ + return elementwise_common::binary_inplace_strided_impl< + argTy, resTy, BitwiseOrInplaceStridedFunctor, + bitwise_or_inplace_strided_kernel>( + exec_q, nelems, nd, shape_and_strides, arg_p, arg_offset, res_p, + res_offset, depends, additional_depends); +} + +template +struct BitwiseOrInplaceStridedFactory +{ + fnT get() + { + if constexpr (!BitwiseOrInplaceTypePairSupport::is_defined) { + fnT fn = nullptr; + return fn; + } + else { + fnT fn = bitwise_or_inplace_strided_impl; + return fn; + } 
+ } +}; + +} // namespace dpctl::tensor::kernels::bitwise_or diff --git a/dpctl_ext/tensor/libtensor/include/kernels/elementwise_functions/bitwise_right_shift.hpp b/dpctl_ext/tensor/libtensor/include/kernels/elementwise_functions/bitwise_right_shift.hpp new file mode 100644 index 00000000000..49e05ac43f9 --- /dev/null +++ b/dpctl_ext/tensor/libtensor/include/kernels/elementwise_functions/bitwise_right_shift.hpp @@ -0,0 +1,493 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** +// +//===---------------------------------------------------------------------===// +/// +/// \file +/// This file defines kernels for elementwise bitwise_right_shift(ar1, ar2) +/// operation. +//===---------------------------------------------------------------------===// + +#pragma once +#include +#include +#include +#include + +#include + +#include "vec_size_util.hpp" + +#include "kernels/dpctl_tensor_types.hpp" +#include "kernels/elementwise_functions/common.hpp" +#include "kernels/elementwise_functions/common_inplace.hpp" + +#include "utils/type_dispatch_building.hpp" + +namespace dpctl::tensor::kernels::bitwise_right_shift +{ + +using dpctl::tensor::ssize_t; +namespace td_ns = dpctl::tensor::type_dispatch; + +template +struct BitwiseRightShiftFunctor +{ + static_assert(std::is_same_v); + static_assert(std::is_integral_v); + static_assert(std::is_integral_v); + + using supports_sg_loadstore = typename std::true_type; + using supports_vec = typename std::true_type; + + resT operator()(const argT1 &in1, const argT2 &in2) const + { + return impl(in1, in2); + } + + template + sycl::vec + operator()(const sycl::vec &in1, + const sycl::vec &in2) const + { + sycl::vec res; +#pragma unroll + for (int i = 0; i < vec_sz; ++i) { + res[i] = impl(in1[i], in2[i]); + } + return res; + } + +private: + resT impl(const argT1 &in1, const argT2 &in2) const + { + 
static constexpr argT2 in1_bitsize = + static_cast(sizeof(argT1) * 8); + static constexpr resT zero = resT(0); + + // bitshift op with second operand negative, or >= bitwidth(argT1) is UB + // array API spec mandates 0 + if constexpr (std::is_unsigned_v) { + return (in2 < in1_bitsize) ? (in1 >> in2) : zero; + } + else { + return (in2 < argT2(0)) + ? zero + : ((in2 < in1_bitsize) + ? (in1 >> in2) + : (in1 < argT1(0) ? resT(-1) : zero)); + } + } +}; + +template +using BitwiseRightShiftContigFunctor = elementwise_common::BinaryContigFunctor< + argT1, + argT2, + resT, + BitwiseRightShiftFunctor, + vec_sz, + n_vecs, + enable_sg_loadstore>; + +template +using BitwiseRightShiftStridedFunctor = + elementwise_common::BinaryStridedFunctor< + argT1, + argT2, + resT, + IndexerT, + BitwiseRightShiftFunctor>; + +template +struct BitwiseRightShiftOutputType +{ + using ResT = T1; + using value_type = typename std::disjunction< + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::DefaultResultEntry>::result_type; + + static constexpr bool is_defined = !std::is_same_v; +}; + +namespace hyperparam_detail +{ + +namespace vsu_ns = dpctl::tensor::kernels::vec_size_utils; + +using vsu_ns::BinaryContigHyperparameterSetEntry; +using vsu_ns::ContigHyperparameterSetDefault; + +template +struct BitwiseRightShiftContigHyperparameterSet +{ + using value_type = + typename std::disjunction>; + + constexpr static auto vec_sz = value_type::vec_sz; + constexpr static auto n_vecs = value_type::n_vecs; +}; + +} // namespace hyperparam_detail + +template +class bitwise_right_shift_contig_kernel; + +template +sycl::event bitwise_right_shift_contig_impl( + sycl::queue &exec_q, + std::size_t nelems, + const char *arg1_p, + ssize_t arg1_offset, + const char *arg2_p, + 
ssize_t arg2_offset, + char *res_p, + ssize_t res_offset, + const std::vector &depends = {}) +{ + using BitwiseRSHS = + hyperparam_detail::BitwiseRightShiftContigHyperparameterSet; + constexpr std::uint8_t vec_sz = BitwiseRSHS::vec_sz; + constexpr std::uint8_t n_vecs = BitwiseRSHS::n_vecs; + + return elementwise_common::binary_contig_impl< + argTy1, argTy2, BitwiseRightShiftOutputType, + BitwiseRightShiftContigFunctor, bitwise_right_shift_contig_kernel, + vec_sz, n_vecs>(exec_q, nelems, arg1_p, arg1_offset, arg2_p, + arg2_offset, res_p, res_offset, depends); +} + +template +struct BitwiseRightShiftContigFactory +{ + fnT get() + { + if constexpr (!BitwiseRightShiftOutputType::is_defined) { + fnT fn = nullptr; + return fn; + } + else { + fnT fn = bitwise_right_shift_contig_impl; + return fn; + } + } +}; + +template +struct BitwiseRightShiftTypeMapFactory +{ + /*! @brief get typeid for output type of operator()>(x, y), always bool + */ + std::enable_if_t::value, int> get() + { + using rT = typename BitwiseRightShiftOutputType::value_type; + return td_ns::GetTypeid{}.get(); + } +}; + +template +class bitwise_right_shift_strided_kernel; + +template +sycl::event bitwise_right_shift_strided_impl( + sycl::queue &exec_q, + std::size_t nelems, + int nd, + const ssize_t *shape_and_strides, + const char *arg1_p, + ssize_t arg1_offset, + const char *arg2_p, + ssize_t arg2_offset, + char *res_p, + ssize_t res_offset, + const std::vector &depends, + const std::vector &additional_depends) +{ + return elementwise_common::binary_strided_impl< + argTy1, argTy2, BitwiseRightShiftOutputType, + BitwiseRightShiftStridedFunctor, bitwise_right_shift_strided_kernel>( + exec_q, nelems, nd, shape_and_strides, arg1_p, arg1_offset, arg2_p, + arg2_offset, res_p, res_offset, depends, additional_depends); +} + +template +struct BitwiseRightShiftStridedFactory +{ + fnT get() + { + if constexpr (!BitwiseRightShiftOutputType::is_defined) { + fnT fn = nullptr; + return fn; + } + else { + fnT fn = 
bitwise_right_shift_strided_impl; + return fn; + } + } +}; + +template +struct BitwiseRightShiftInplaceFunctor +{ + static_assert(std::is_integral_v); + static_assert(!std::is_same_v); + + using supports_sg_loadstore = typename std::true_type; + using supports_vec = typename std::true_type; + + void operator()(resT &res, const argT &in) const + { + impl(res, in); + } + + template + void operator()(sycl::vec &res, + const sycl::vec &in) const + { +#pragma unroll + for (int i = 0; i < vec_sz; ++i) { + impl(res[i], in[i]); + } + } + +private: + void impl(resT &res, const argT &in) const + { + static constexpr argT res_bitsize = static_cast(sizeof(resT) * 8); + static constexpr resT zero = resT(0); + + // bitshift op with second operand negative, or >= bitwidth(argT1) is UB + // array API spec mandates 0 + if constexpr (std::is_unsigned_v) { + (in < res_bitsize) ? (res >>= in) : res = zero; + } + else { + (in < argT(0)) ? res = zero + : ((in < res_bitsize) ? (res >>= in) + : (res < resT(0)) ? 
res = resT(-1) + : res = zero); + } + } +}; + +template +using BitwiseRightShiftInplaceContigFunctor = + elementwise_common::BinaryInplaceContigFunctor< + argT, + resT, + BitwiseRightShiftInplaceFunctor, + vec_sz, + n_vecs, + enable_sg_loadstore>; + +template +using BitwiseRightShiftInplaceStridedFunctor = + elementwise_common::BinaryInplaceStridedFunctor< + argT, + resT, + IndexerT, + BitwiseRightShiftInplaceFunctor>; + +template +class bitwise_right_shift_inplace_contig_kernel; + +/* @brief Types supported by in-place bitwise right shift */ +template +struct BitwiseRightShiftInplaceTypePairSupport +{ + /* value if true a kernel for must be instantiated */ + static constexpr bool is_defined = std::disjunction< + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + // fall-through + td_ns::NotDefinedEntry>::is_defined; +}; + +template +struct BitwiseRightShiftInplaceTypeMapFactory +{ + /*! 
@brief get typeid for output type of x >>= y */ + std::enable_if_t::value, int> get() + { + if constexpr (BitwiseRightShiftInplaceTypePairSupport::is_defined) + { + return td_ns::GetTypeid{}.get(); + } + else { + return td_ns::GetTypeid{}.get(); + } + } +}; + +template +sycl::event bitwise_right_shift_inplace_contig_impl( + sycl::queue &exec_q, + std::size_t nelems, + const char *arg_p, + ssize_t arg_offset, + char *res_p, + ssize_t res_offset, + const std::vector &depends = {}) +{ + using BitwiseRSHS = + hyperparam_detail::BitwiseRightShiftContigHyperparameterSet; + + // res = OP(res, arg) + static constexpr std::uint8_t vec_sz = BitwiseRSHS::vec_sz; + static constexpr std::uint8_t n_vecs = BitwiseRSHS::n_vecs; + + return elementwise_common::binary_inplace_contig_impl< + argTy, resTy, BitwiseRightShiftInplaceContigFunctor, + bitwise_right_shift_inplace_contig_kernel, vec_sz, n_vecs>( + exec_q, nelems, arg_p, arg_offset, res_p, res_offset, depends); +} + +template +struct BitwiseRightShiftInplaceContigFactory +{ + fnT get() + { + if constexpr (!BitwiseRightShiftInplaceTypePairSupport::is_defined) + { + fnT fn = nullptr; + return fn; + } + else { + fnT fn = bitwise_right_shift_inplace_contig_impl; + return fn; + } + } +}; + +template +class bitwise_right_shift_inplace_strided_kernel; + +template +sycl::event bitwise_right_shift_inplace_strided_impl( + sycl::queue &exec_q, + std::size_t nelems, + int nd, + const ssize_t *shape_and_strides, + const char *arg_p, + ssize_t arg_offset, + char *res_p, + ssize_t res_offset, + const std::vector &depends, + const std::vector &additional_depends) +{ + return elementwise_common::binary_inplace_strided_impl< + argTy, resTy, BitwiseRightShiftInplaceStridedFunctor, + bitwise_right_shift_inplace_strided_kernel>( + exec_q, nelems, nd, shape_and_strides, arg_p, arg_offset, res_p, + res_offset, depends, additional_depends); +} + +template +struct BitwiseRightShiftInplaceStridedFactory +{ + fnT get() + { + if constexpr 
(!BitwiseRightShiftInplaceTypePairSupport::is_defined) + { + fnT fn = nullptr; + return fn; + } + else { + fnT fn = bitwise_right_shift_inplace_strided_impl; + return fn; + } + } +}; + +} // namespace dpctl::tensor::kernels::bitwise_right_shift diff --git a/dpctl_ext/tensor/libtensor/include/kernels/elementwise_functions/bitwise_xor.hpp b/dpctl_ext/tensor/libtensor/include/kernels/elementwise_functions/bitwise_xor.hpp new file mode 100644 index 00000000000..5ff8c678c68 --- /dev/null +++ b/dpctl_ext/tensor/libtensor/include/kernels/elementwise_functions/bitwise_xor.hpp @@ -0,0 +1,468 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** +// +//===---------------------------------------------------------------------===// +/// +/// \file +/// This file defines kernels for elementwise bitwise_xor(ar1, ar2) operation. +//===---------------------------------------------------------------------===// + +#pragma once +#include +#include +#include +#include + +#include + +#include "vec_size_util.hpp" + +#include "kernels/dpctl_tensor_types.hpp" +#include "kernels/elementwise_functions/common.hpp" +#include "kernels/elementwise_functions/common_inplace.hpp" + +#include "utils/type_dispatch_building.hpp" +#include "utils/type_utils.hpp" + +namespace dpctl::tensor::kernels::bitwise_xor +{ + +using dpctl::tensor::ssize_t; +namespace td_ns = dpctl::tensor::type_dispatch; +namespace tu_ns = dpctl::tensor::type_utils; + +template +struct BitwiseXorFunctor +{ + static_assert(std::is_same_v); + static_assert(std::is_same_v); + + using supports_sg_loadstore = typename std::true_type; + using supports_vec = typename std::true_type; + + resT operator()(const argT1 &in1, const argT2 &in2) const + { + if constexpr (std::is_same_v) { + // (false != false) -> false, (false != true) -> true + // (true != false) -> true, (true != true) -> false + return (in1 != in2); + } + else { + return (in1 ^ in2); + } + } + + template + sycl::vec + operator()(const sycl::vec &in1, + const sycl::vec &in2) const + { + + if 
constexpr (std::is_same_v) { + using dpctl::tensor::type_utils::vec_cast; + + auto tmp = (in1 != in2); + return vec_cast( + tmp); + } + else { + return (in1 ^ in2); + } + } +}; + +template +using BitwiseXorContigFunctor = elementwise_common::BinaryContigFunctor< + argT1, + argT2, + resT, + BitwiseXorFunctor, + vec_sz, + n_vecs, + enable_sg_loadstore>; + +template +using BitwiseXorStridedFunctor = elementwise_common::BinaryStridedFunctor< + argT1, + argT2, + resT, + IndexerT, + BitwiseXorFunctor>; + +template +struct BitwiseXorOutputType +{ + using value_type = typename std::disjunction< + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::BinaryTypeMapResultEntry, + td_ns::DefaultResultEntry>::result_type; + + static constexpr bool is_defined = !std::is_same_v; +}; + +namespace hyperparam_detail +{ + +namespace vsu_ns = dpctl::tensor::kernels::vec_size_utils; + +using vsu_ns::BinaryContigHyperparameterSetEntry; +using vsu_ns::ContigHyperparameterSetDefault; + +template +struct BitwiseXorContigHyperparameterSet +{ + using value_type = + typename std::disjunction>; + + constexpr static auto vec_sz = value_type::vec_sz; + constexpr static auto n_vecs = value_type::n_vecs; +}; + +} // end of namespace hyperparam_detail + +template +class bitwise_xor_contig_kernel; + +template +sycl::event + bitwise_xor_contig_impl(sycl::queue &exec_q, + std::size_t nelems, + const char *arg1_p, + ssize_t arg1_offset, + const char *arg2_p, + ssize_t arg2_offset, + char *res_p, + ssize_t res_offset, + const std::vector &depends = {}) +{ + using BitwiseXorHS = + hyperparam_detail::BitwiseXorContigHyperparameterSet; + static constexpr std::uint8_t vec_sz = BitwiseXorHS::vec_sz; + static constexpr std::uint8_t n_vecs = BitwiseXorHS::n_vecs; + + return 
elementwise_common::binary_contig_impl< + argTy1, argTy2, BitwiseXorOutputType, BitwiseXorContigFunctor, + bitwise_xor_contig_kernel, vec_sz, n_vecs>( + exec_q, nelems, arg1_p, arg1_offset, arg2_p, arg2_offset, res_p, + res_offset, depends); +} + +template +struct BitwiseXorContigFactory +{ + fnT get() + { + if constexpr (!BitwiseXorOutputType::is_defined) { + fnT fn = nullptr; + return fn; + } + else { + fnT fn = bitwise_xor_contig_impl; + return fn; + } + } +}; + +template +struct BitwiseXorTypeMapFactory +{ + /*! @brief get typeid for output type of operator()>(x, y), always bool + */ + std::enable_if_t::value, int> get() + { + using rT = typename BitwiseXorOutputType::value_type; + return td_ns::GetTypeid{}.get(); + } +}; + +template +class bitwise_xor_strided_kernel; + +template +sycl::event + bitwise_xor_strided_impl(sycl::queue &exec_q, + std::size_t nelems, + int nd, + const ssize_t *shape_and_strides, + const char *arg1_p, + ssize_t arg1_offset, + const char *arg2_p, + ssize_t arg2_offset, + char *res_p, + ssize_t res_offset, + const std::vector &depends, + const std::vector &additional_depends) +{ + return elementwise_common::binary_strided_impl< + argTy1, argTy2, BitwiseXorOutputType, BitwiseXorStridedFunctor, + bitwise_xor_strided_kernel>( + exec_q, nelems, nd, shape_and_strides, arg1_p, arg1_offset, arg2_p, + arg2_offset, res_p, res_offset, depends, additional_depends); +} + +template +struct BitwiseXorStridedFactory +{ + fnT get() + { + if constexpr (!BitwiseXorOutputType::is_defined) { + fnT fn = nullptr; + return fn; + } + else { + fnT fn = bitwise_xor_strided_impl; + return fn; + } + } +}; + +template +struct BitwiseXorInplaceFunctor +{ + using supports_sg_loadstore = typename std::true_type; + using supports_vec = typename std::true_type; + + void operator()(resT &res, const argT &in) const + { + using tu_ns::convert_impl; + + if constexpr (std::is_same_v) { + res = (res != in); + } + else { + res ^= in; + } + } + + template + void 
operator()(sycl::vec &res, + const sycl::vec &in) const + { + + if constexpr (std::is_same_v) { + using dpctl::tensor::type_utils::vec_cast; + + auto tmp = (res != in); + res = vec_cast( + tmp); + } + else { + res ^= in; + } + } +}; + +template +using BitwiseXorInplaceContigFunctor = + elementwise_common::BinaryInplaceContigFunctor< + argT, + resT, + BitwiseXorInplaceFunctor, + vec_sz, + n_vecs, + enable_sg_loadstore>; + +template +using BitwiseXorInplaceStridedFunctor = + elementwise_common::BinaryInplaceStridedFunctor< + argT, + resT, + IndexerT, + BitwiseXorInplaceFunctor>; + +template +class bitwise_xor_inplace_contig_kernel; + +/* @brief Types supported by in-place bitwise XOR */ +template +struct BitwiseXorInplaceTypePairSupport +{ + /* value if true a kernel for must be instantiated */ + static constexpr bool is_defined = std::disjunction< + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + td_ns::TypePairDefinedEntry, + // fall-through + td_ns::NotDefinedEntry>::is_defined; +}; + +template +struct BitwiseXorInplaceTypeMapFactory +{ + /*! 
@brief get typeid for output type of x ^= y */ + std::enable_if_t::value, int> get() + { + if constexpr (BitwiseXorInplaceTypePairSupport::is_defined) + { + return td_ns::GetTypeid{}.get(); + } + else { + return td_ns::GetTypeid{}.get(); + } + } +}; + +template +sycl::event bitwise_xor_inplace_contig_impl( + sycl::queue &exec_q, + std::size_t nelems, + const char *arg_p, + ssize_t arg_offset, + char *res_p, + ssize_t res_offset, + const std::vector &depends = {}) +{ + using BitwiseXorHS = + hyperparam_detail::BitwiseXorContigHyperparameterSet; + + static constexpr std::uint8_t vec_sz = BitwiseXorHS::vec_sz; + static constexpr std::uint8_t n_vecs = BitwiseXorHS::n_vecs; + + return elementwise_common::binary_inplace_contig_impl< + argTy, resTy, BitwiseXorInplaceContigFunctor, + bitwise_xor_inplace_contig_kernel, vec_sz, n_vecs>( + exec_q, nelems, arg_p, arg_offset, res_p, res_offset, depends); +} + +template +struct BitwiseXorInplaceContigFactory +{ + fnT get() + { + if constexpr (!BitwiseXorInplaceTypePairSupport::is_defined) { + fnT fn = nullptr; + return fn; + } + else { + fnT fn = bitwise_xor_inplace_contig_impl; + return fn; + } + } +}; + +template +class bitwise_xor_inplace_strided_kernel; + +template +sycl::event bitwise_xor_inplace_strided_impl( + sycl::queue &exec_q, + std::size_t nelems, + int nd, + const ssize_t *shape_and_strides, + const char *arg_p, + ssize_t arg_offset, + char *res_p, + ssize_t res_offset, + const std::vector &depends, + const std::vector &additional_depends) +{ + return elementwise_common::binary_inplace_strided_impl< + argTy, resTy, BitwiseXorInplaceStridedFunctor, + bitwise_xor_inplace_strided_kernel>( + exec_q, nelems, nd, shape_and_strides, arg_p, arg_offset, res_p, + res_offset, depends, additional_depends); +} + +template +struct BitwiseXorInplaceStridedFactory +{ + fnT get() + { + if constexpr (!BitwiseXorInplaceTypePairSupport::is_defined) { + fnT fn = nullptr; + return fn; + } + else { + fnT fn = 
bitwise_xor_inplace_strided_impl; + return fn; + } + } +}; + +} // namespace dpctl::tensor::kernels::bitwise_xor diff --git a/dpctl_ext/tensor/libtensor/include/kernels/elementwise_functions/common_inplace.hpp b/dpctl_ext/tensor/libtensor/include/kernels/elementwise_functions/common_inplace.hpp new file mode 100644 index 00000000000..2c028bc3015 --- /dev/null +++ b/dpctl_ext/tensor/libtensor/include/kernels/elementwise_functions/common_inplace.hpp @@ -0,0 +1,478 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** +// +//===---------------------------------------------------------------------===// +/// +/// \file +/// This file defines common code for in-place elementwise tensor operations. +//===---------------------------------------------------------------------===// + +#pragma once +#include +#include +#include +#include +#include + +#include + +#include "utils/offset_utils.hpp" +#include "utils/sycl_alloc_utils.hpp" +#include "utils/sycl_utils.hpp" + +#include "kernels/alignment.hpp" +#include "kernels/dpctl_tensor_types.hpp" +#include "kernels/elementwise_functions/common_detail.hpp" + +namespace dpctl::tensor::kernels::elementwise_common +{ + +using dpctl::tensor::ssize_t; +using dpctl::tensor::kernels::alignment_utils:: + disabled_sg_loadstore_wrapper_krn; +using dpctl::tensor::kernels::alignment_utils::is_aligned; +using dpctl::tensor::kernels::alignment_utils::required_alignment; + +using dpctl::tensor::sycl_utils::sub_group_load; +using dpctl::tensor::sycl_utils::sub_group_store; + +template +struct BinaryInplaceContigFunctor +{ +private: + const argT *rhs = nullptr; + resT *lhs = nullptr; + std::size_t nelems_; + +public: + BinaryInplaceContigFunctor(const argT *rhs_tp, + resT *lhs_tp, + const std::size_t n_elems) + : rhs(rhs_tp), lhs(lhs_tp), nelems_(n_elems) + { + } + + void operator()(sycl::nd_item<1> ndit) const + { + BinaryInplaceOperatorT op{}; 
+ static constexpr std::uint8_t elems_per_wi = vec_sz * n_vecs; + /* Each work-item processes vec_sz elements, contiguous in memory */ + /* NB: Workgroup size must be divisible by sub-group size */ + + if constexpr (enable_sg_loadstore && + BinaryInplaceOperatorT::supports_sg_loadstore::value && + BinaryInplaceOperatorT::supports_vec::value && + (vec_sz > 1)) + { + auto sg = ndit.get_sub_group(); + std::uint16_t sgSize = sg.get_max_local_range()[0]; + + std::size_t base = + elems_per_wi * (ndit.get_group(0) * ndit.get_local_range(0) + + sg.get_group_id()[0] * sgSize); + + if (base + elems_per_wi * sgSize < nelems_) { + +#pragma unroll + for (std::uint8_t it = 0; it < elems_per_wi; it += vec_sz) { + const std::size_t offset = base + it * sgSize; + auto rhs_multi_ptr = sycl::address_space_cast< + sycl::access::address_space::global_space, + sycl::access::decorated::yes>(&rhs[offset]); + auto lhs_multi_ptr = sycl::address_space_cast< + sycl::access::address_space::global_space, + sycl::access::decorated::yes>(&lhs[offset]); + + const sycl::vec &arg_vec = + sub_group_load(sg, rhs_multi_ptr); + sycl::vec res_vec = + sub_group_load(sg, lhs_multi_ptr); + op(res_vec, arg_vec); + + sub_group_store(sg, res_vec, lhs_multi_ptr); + } + } + else { + const std::size_t lane_id = sg.get_local_id()[0]; + for (std::size_t k = base + lane_id; k < nelems_; k += sgSize) { + op(lhs[k], rhs[k]); + } + } + } + else if constexpr (enable_sg_loadstore && + BinaryInplaceOperatorT::supports_sg_loadstore::value) + { + auto sg = ndit.get_sub_group(); + std::uint16_t sgSize = sg.get_max_local_range()[0]; + + std::size_t base = + elems_per_wi * (ndit.get_group(0) * ndit.get_local_range(0) + + sg.get_group_id()[0] * sgSize); + + if (base + elems_per_wi * sgSize < nelems_) { +#pragma unroll + for (std::uint8_t it = 0; it < elems_per_wi; it += vec_sz) { + const std::size_t offset = base + it * sgSize; + auto rhs_multi_ptr = sycl::address_space_cast< + sycl::access::address_space::global_space, + 
sycl::access::decorated::yes>(&rhs[offset]); + auto lhs_multi_ptr = sycl::address_space_cast< + sycl::access::address_space::global_space, + sycl::access::decorated::yes>(&lhs[offset]); + + const sycl::vec arg_vec = + sub_group_load(sg, rhs_multi_ptr); + sycl::vec res_vec = + sub_group_load(sg, lhs_multi_ptr); +#pragma unroll + for (std::uint8_t vec_id = 0; vec_id < vec_sz; ++vec_id) { + op(res_vec[vec_id], arg_vec[vec_id]); + } + sub_group_store(sg, res_vec, lhs_multi_ptr); + } + } + else { + const std::size_t lane_id = sg.get_local_id()[0]; + for (std::size_t k = base + lane_id; k < nelems_; k += sgSize) { + op(lhs[k], rhs[k]); + } + } + } + else { + const std::size_t sgSize = + ndit.get_sub_group().get_local_range()[0]; + const std::size_t gid = ndit.get_global_linear_id(); + const std::size_t elems_per_sg = elems_per_wi * sgSize; + + const std::size_t start = + (gid / sgSize) * (elems_per_sg - sgSize) + gid; + const std::size_t end = std::min(nelems_, start + elems_per_sg); + for (std::size_t offset = start; offset < end; offset += sgSize) { + op(lhs[offset], rhs[offset]); + } + } + } +}; + +template +struct BinaryInplaceStridedFunctor +{ +private: + const argT *rhs = nullptr; + resT *lhs = nullptr; + TwoOffsets_IndexerT two_offsets_indexer_; + +public: + BinaryInplaceStridedFunctor(const argT *rhs_tp, + resT *lhs_tp, + const TwoOffsets_IndexerT &inp_res_indexer) + : rhs(rhs_tp), lhs(lhs_tp), two_offsets_indexer_(inp_res_indexer) + { + } + + void operator()(sycl::id<1> wid) const + { + const auto &two_offsets_ = + two_offsets_indexer_(static_cast(wid.get(0))); + + const auto &inp_offset = two_offsets_.get_first_offset(); + const auto &lhs_offset = two_offsets_.get_second_offset(); + + BinaryInplaceOperatorT op{}; + op(lhs[lhs_offset], rhs[inp_offset]); + } +}; + +template +struct BinaryInplaceRowMatrixBroadcastingFunctor +{ +private: + const argT *padded_vec; + resT *mat; + std::size_t n_elems; + std::size_t n1; + +public: + 
BinaryInplaceRowMatrixBroadcastingFunctor(const argT *row_tp, + resT *mat_tp, + std::size_t n_elems_in_mat, + std::size_t n_elems_in_row) + : padded_vec(row_tp), mat(mat_tp), n_elems(n_elems_in_mat), + n1(n_elems_in_row) + { + } + + void operator()(sycl::nd_item<1> ndit) const + { + /* Workgroup size is expected to be a multiple of sub-group size */ + BinaryOperatorT op{}; + static_assert(BinaryOperatorT::supports_sg_loadstore::value); + + auto sg = ndit.get_sub_group(); + const std::size_t gid = ndit.get_global_linear_id(); + + std::uint8_t sgSize = sg.get_max_local_range()[0]; + std::size_t base = gid - sg.get_local_id()[0]; + + if (base + sgSize < n_elems) { + auto in_multi_ptr = sycl::address_space_cast< + sycl::access::address_space::global_space, + sycl::access::decorated::yes>(&padded_vec[base % n1]); + + auto out_multi_ptr = sycl::address_space_cast< + sycl::access::address_space::global_space, + sycl::access::decorated::yes>(&mat[base]); + + const argT vec_el = sub_group_load(sg, in_multi_ptr); + resT mat_el = sub_group_load(sg, out_multi_ptr); + + op(mat_el, vec_el); + + sub_group_store(sg, mat_el, out_multi_ptr); + } + else { + const std::size_t start = base + sg.get_local_id()[0]; + for (std::size_t k = start; k < n_elems; k += sgSize) { + op(mat[k], padded_vec[k % n1]); + } + } + } +}; + +// Typedefs for function pointers + +typedef sycl::event (*binary_inplace_contig_impl_fn_ptr_t)( + sycl::queue &, + std::size_t, + const char *, + ssize_t, + char *, + ssize_t, + const std::vector &); + +typedef sycl::event (*binary_inplace_strided_impl_fn_ptr_t)( + sycl::queue &, + std::size_t, + int, + const ssize_t *, + const char *, + ssize_t, + char *, + ssize_t, + const std::vector &, + const std::vector &); + +typedef sycl::event (*binary_inplace_row_matrix_broadcast_impl_fn_ptr_t)( + sycl::queue &, + std::vector &, + std::size_t, + std::size_t, + const char *, + ssize_t, + char *, + ssize_t, + const std::vector &); + +template + class 
BinaryInplaceContigFunctorT, + template + class kernel_name, + std::uint8_t vec_sz = 4u, + std::uint8_t n_vecs = 2u> +sycl::event + binary_inplace_contig_impl(sycl::queue &exec_q, + std::size_t nelems, + const char *rhs_p, + ssize_t rhs_offset, + char *lhs_p, + ssize_t lhs_offset, + const std::vector &depends = {}) +{ + sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + + const std::size_t lws = 128; + const std::size_t n_groups = + ((nelems + lws * n_vecs * vec_sz - 1) / (lws * n_vecs * vec_sz)); + const auto gws_range = sycl::range<1>(n_groups * lws); + const auto lws_range = sycl::range<1>(lws); + + const argTy *arg_tp = + reinterpret_cast(rhs_p) + rhs_offset; + resTy *res_tp = reinterpret_cast(lhs_p) + lhs_offset; + + if (is_aligned(arg_tp) && + is_aligned(res_tp)) + { + static constexpr bool enable_sg_loadstore = true; + using KernelName = kernel_name; + using Impl = + BinaryInplaceContigFunctorT; + + cgh.parallel_for( + sycl::nd_range<1>(gws_range, lws_range), + Impl(arg_tp, res_tp, nelems)); + } + else { + static constexpr bool disable_sg_loadstore = true; + using InnerKernelName = kernel_name; + using KernelName = + disabled_sg_loadstore_wrapper_krn; + using Impl = + BinaryInplaceContigFunctorT; + + cgh.parallel_for( + sycl::nd_range<1>(gws_range, lws_range), + Impl(arg_tp, res_tp, nelems)); + } + }); + return comp_ev; +} + +template + class BinaryInplaceStridedFunctorT, + template + class kernel_name> +sycl::event binary_inplace_strided_impl( + sycl::queue &exec_q, + std::size_t nelems, + int nd, + const ssize_t *shape_and_strides, + const char *rhs_p, + ssize_t rhs_offset, + char *lhs_p, + ssize_t lhs_offset, + const std::vector &depends, + const std::vector &additional_depends) +{ + sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(depends); + cgh.depends_on(additional_depends); + + using IndexerT = + typename dpctl::tensor::offset_utils::TwoOffsets_StridedIndexer; + + const IndexerT 
indexer{nd, rhs_offset, lhs_offset, shape_and_strides}; + + const argTy *arg_tp = reinterpret_cast(rhs_p); + resTy *res_tp = reinterpret_cast(lhs_p); + + using Impl = BinaryInplaceStridedFunctorT; + + cgh.parallel_for>( + {nelems}, Impl(arg_tp, res_tp, indexer)); + }); + return comp_ev; +} + +template + class BinaryInplaceRowMatrixBroadcastFunctorT, + template + class kernel_name> +sycl::event binary_inplace_row_matrix_broadcast_impl( + sycl::queue &exec_q, + std::vector &host_tasks, + std::size_t n0, + std::size_t n1, + const char *vec_p, // typeless pointer to (n1,) contiguous row + ssize_t vec_offset, + char *mat_p, // typeless pointer to (n0, n1) C-contiguous matrix + ssize_t mat_offset, + const std::vector &depends = {}) +{ + const argT *vec = reinterpret_cast(vec_p) + vec_offset; + resT *mat = reinterpret_cast(mat_p) + mat_offset; + + const auto &dev = exec_q.get_device(); + const auto &sg_sizes = dev.get_info(); + // Get device-specific kernel info max_sub_group_size + std::size_t max_sgSize = + *(std::max_element(std::begin(sg_sizes), std::end(sg_sizes))); + + std::size_t n1_padded = n1 + max_sgSize; + auto padded_vec_owner = + dpctl::tensor::alloc_utils::smart_malloc_device(n1_padded, + exec_q); + argT *padded_vec = padded_vec_owner.get(); + + sycl::event make_padded_vec_ev = + dpctl::tensor::kernels::elementwise_detail::populate_padded_vector< + argT>(exec_q, vec, n1, padded_vec, n1_padded, depends); + + // sub-group spans work-items [I, I + sgSize) + // base = ndit.get_global_linear_id() - sg.get_local_id()[0] + // Generically, sub_group_load( &mat[base]) may load arrays from + // different rows of mat. The start corresponds to row (base / n0) + // We read sub_group_load(&padded_vec[(base / n0)]). 
The vector is + // padded to ensure that reads are accessible + + const std::size_t lws = 128; + + sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) { + cgh.depends_on(make_padded_vec_ev); + + auto lwsRange = sycl::range<1>(lws); + std::size_t n_elems = n0 * n1; + std::size_t n_groups = (n_elems + lws - 1) / lws; + auto gwsRange = sycl::range<1>(n_groups * lws); + + using Impl = BinaryInplaceRowMatrixBroadcastFunctorT; + + cgh.parallel_for>( + sycl::nd_range<1>(gwsRange, lwsRange), + Impl(padded_vec, mat, n_elems, n1)); + }); + + sycl::event tmp_cleanup_ev = dpctl::tensor::alloc_utils::async_smart_free( + exec_q, {comp_ev}, padded_vec_owner); + host_tasks.push_back(tmp_cleanup_ev); + + return comp_ev; +} + +} // namespace dpctl::tensor::kernels::elementwise_common diff --git a/dpctl_ext/tensor/libtensor/source/elementwise_functions/add.cpp b/dpctl_ext/tensor/libtensor/source/elementwise_functions/add.cpp new file mode 100644 index 00000000000..fb8cf8bbfbf --- /dev/null +++ b/dpctl_ext/tensor/libtensor/source/elementwise_functions/add.cpp @@ -0,0 +1,242 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. 
+// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** +// +//===---------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_elementwise_impl +/// extension, specifically functions for elementwise operations. 
+//===---------------------------------------------------------------------===// + +#include + +#include + +#include "dpnp4pybind11.hpp" +#include +#include +#include + +#include "add.hpp" +#include "elementwise_functions.hpp" +#include "utils/type_dispatch.hpp" + +#include "kernels/elementwise_functions/add.hpp" +#include "kernels/elementwise_functions/common.hpp" +#include "kernels/elementwise_functions/common_inplace.hpp" + +namespace dpctl::tensor::py_internal +{ + +namespace py = pybind11; +namespace td_ns = dpctl::tensor::type_dispatch; + +namespace ew_cmn_ns = dpctl::tensor::kernels::elementwise_common; +using ew_cmn_ns::binary_contig_impl_fn_ptr_t; +using ew_cmn_ns::binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t; +using ew_cmn_ns::binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t; +using ew_cmn_ns::binary_strided_impl_fn_ptr_t; + +using ew_cmn_ns::binary_inplace_contig_impl_fn_ptr_t; +using ew_cmn_ns::binary_inplace_row_matrix_broadcast_impl_fn_ptr_t; +using ew_cmn_ns::binary_inplace_strided_impl_fn_ptr_t; + +// B01: ===== ADD (x1, x2) +namespace impl +{ + +namespace add_fn_ns = dpctl::tensor::kernels::add; + +static binary_contig_impl_fn_ptr_t add_contig_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + +static int add_output_id_table[td_ns::num_types][td_ns::num_types]; +static int add_inplace_output_id_table[td_ns::num_types][td_ns::num_types]; + +static binary_strided_impl_fn_ptr_t + add_strided_dispatch_table[td_ns::num_types][td_ns::num_types]; + +// add(matrix, row) +static binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t + add_contig_matrix_contig_row_broadcast_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + +// add(row, matrix) +static binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t + add_contig_row_contig_matrix_broadcast_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + +static binary_inplace_contig_impl_fn_ptr_t + add_inplace_contig_dispatch_table[td_ns::num_types][td_ns::num_types]; +static 
binary_inplace_strided_impl_fn_ptr_t + add_inplace_strided_dispatch_table[td_ns::num_types][td_ns::num_types]; +static binary_inplace_row_matrix_broadcast_impl_fn_ptr_t + add_inplace_row_matrix_dispatch_table[td_ns::num_types][td_ns::num_types]; + +void populate_add_dispatch_tables(void) +{ + using namespace td_ns; + namespace fn_ns = add_fn_ns; + + // which input types are supported, and what is the type of the result + using fn_ns::AddTypeMapFactory; + DispatchTableBuilder dtb1; + dtb1.populate_dispatch_table(add_output_id_table); + + // function pointers for operation on general strided arrays + using fn_ns::AddStridedFactory; + DispatchTableBuilder + dtb2; + dtb2.populate_dispatch_table(add_strided_dispatch_table); + + // function pointers for operation on contiguous inputs and output + using fn_ns::AddContigFactory; + DispatchTableBuilder + dtb3; + dtb3.populate_dispatch_table(add_contig_dispatch_table); + + // function pointers for operation on contiguous matrix, contiguous row + // with contiguous matrix output + using fn_ns::AddContigMatrixContigRowBroadcastFactory; + DispatchTableBuilder< + binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t, + AddContigMatrixContigRowBroadcastFactory, num_types> + dtb4; + dtb4.populate_dispatch_table( + add_contig_matrix_contig_row_broadcast_dispatch_table); + + // function pointers for operation on contiguous row, contiguous matrix + // with contiguous matrix output + using fn_ns::AddContigRowContigMatrixBroadcastFactory; + DispatchTableBuilder< + binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t, + AddContigRowContigMatrixBroadcastFactory, num_types> + dtb5; + dtb5.populate_dispatch_table( + add_contig_row_contig_matrix_broadcast_dispatch_table); + + // function pointers for inplace operation on general strided arrays + using fn_ns::AddInplaceStridedFactory; + DispatchTableBuilder + dtb6; + dtb6.populate_dispatch_table(add_inplace_strided_dispatch_table); + + // function pointers for inplace operation on 
contiguous inputs and output + using fn_ns::AddInplaceContigFactory; + DispatchTableBuilder + dtb7; + dtb7.populate_dispatch_table(add_inplace_contig_dispatch_table); + + // function pointers for inplace operation on contiguous matrix + // and contiguous row + using fn_ns::AddInplaceRowMatrixBroadcastFactory; + DispatchTableBuilder + dtb8; + dtb8.populate_dispatch_table(add_inplace_row_matrix_dispatch_table); + + // which types are supported by the in-place kernels + using fn_ns::AddInplaceTypeMapFactory; + DispatchTableBuilder dtb9; + dtb9.populate_dispatch_table(add_inplace_output_id_table); +}; + +} // namespace impl + +void init_add(py::module_ m) +{ + using arrayT = dpctl::tensor::usm_ndarray; + using event_vecT = std::vector; + { + impl::populate_add_dispatch_tables(); + using impl::add_contig_dispatch_table; + using impl::add_contig_matrix_contig_row_broadcast_dispatch_table; + using impl::add_contig_row_contig_matrix_broadcast_dispatch_table; + using impl::add_output_id_table; + using impl::add_strided_dispatch_table; + + auto add_pyapi = [&](const arrayT &src1, const arrayT &src2, + const arrayT &dst, sycl::queue &exec_q, + const event_vecT &depends = {}) { + return py_binary_ufunc( + src1, src2, dst, exec_q, depends, add_output_id_table, + // function pointers to handle operation on contiguous arrays + // (pointers may be nullptr) + add_contig_dispatch_table, + // function pointers to handle operation on strided arrays (most + // general case) + add_strided_dispatch_table, + // function pointers to handle operation of c-contig matrix and + // c-contig row with broadcasting (may be nullptr) + add_contig_matrix_contig_row_broadcast_dispatch_table, + // function pointers to handle operation of c-contig matrix and + // c-contig row with broadcasting (may be nullptr) + add_contig_row_contig_matrix_broadcast_dispatch_table); + }; + auto add_result_type_pyapi = [&](const py::dtype &dtype1, + const py::dtype &dtype2) { + return py_binary_ufunc_result_type(dtype1, 
dtype2, + add_output_id_table); + }; + m.def("_add", add_pyapi, "", py::arg("src1"), py::arg("src2"), + py::arg("dst"), py::arg("sycl_queue"), + py::arg("depends") = py::list()); + m.def("_add_result_type", add_result_type_pyapi, ""); + + using impl::add_inplace_contig_dispatch_table; + using impl::add_inplace_output_id_table; + using impl::add_inplace_row_matrix_dispatch_table; + using impl::add_inplace_strided_dispatch_table; + + auto add_inplace_pyapi = [&](const arrayT &src, const arrayT &dst, + sycl::queue &exec_q, + const event_vecT &depends = {}) { + return py_binary_inplace_ufunc( + src, dst, exec_q, depends, add_inplace_output_id_table, + // function pointers to handle inplace operation on + // contiguous arrays (pointers may be nullptr) + add_inplace_contig_dispatch_table, + // function pointers to handle inplace operation on strided + // arrays (most general case) + add_inplace_strided_dispatch_table, + // function pointers to handle inplace operation on + // c-contig matrix with c-contig row with broadcasting + // (may be nullptr) + add_inplace_row_matrix_dispatch_table); + }; + m.def("_add_inplace", add_inplace_pyapi, "", py::arg("lhs"), + py::arg("rhs"), py::arg("sycl_queue"), + py::arg("depends") = py::list()); + } +} + +} // namespace dpctl::tensor::py_internal diff --git a/dpctl_ext/tensor/libtensor/source/elementwise_functions/add.hpp b/dpctl_ext/tensor/libtensor/source/elementwise_functions/add.hpp new file mode 100644 index 00000000000..0797adb79dd --- /dev/null +++ b/dpctl_ext/tensor/libtensor/source/elementwise_functions/add.hpp @@ -0,0 +1,46 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. 
+// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** +// +//===---------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_elementwise_impl +/// extension, specifically functions for elementwise operations. 
+//===---------------------------------------------------------------------===// + +#pragma once +#include + +namespace py = pybind11; + +namespace dpctl::tensor::py_internal +{ + +extern void init_add(py::module_ m); + +} // namespace dpctl::tensor::py_internal diff --git a/dpctl_ext/tensor/libtensor/source/elementwise_functions/atan2.cpp b/dpctl_ext/tensor/libtensor/source/elementwise_functions/atan2.cpp new file mode 100644 index 00000000000..de1999e3819 --- /dev/null +++ b/dpctl_ext/tensor/libtensor/source/elementwise_functions/atan2.cpp @@ -0,0 +1,145 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** +// +//===---------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_elementwise_impl +/// extension, specifically functions for elementwise operations. +//===---------------------------------------------------------------------===// + +#include + +#include + +#include "dpnp4pybind11.hpp" +#include +#include +#include + +#include "atan2.hpp" +#include "elementwise_functions.hpp" +#include "utils/type_dispatch.hpp" + +#include "kernels/elementwise_functions/atan2.hpp" +#include "kernels/elementwise_functions/common.hpp" + +namespace dpctl::tensor::py_internal +{ + +namespace py = pybind11; +namespace td_ns = dpctl::tensor::type_dispatch; + +namespace ew_cmn_ns = dpctl::tensor::kernels::elementwise_common; +using ew_cmn_ns::binary_contig_impl_fn_ptr_t; +using ew_cmn_ns::binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t; +using ew_cmn_ns::binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t; +using ew_cmn_ns::binary_strided_impl_fn_ptr_t; + +// B02: ===== ATAN2 (x1, x2) +namespace impl +{ +namespace atan2_fn_ns = dpctl::tensor::kernels::atan2; + +static binary_contig_impl_fn_ptr_t + atan2_contig_dispatch_table[td_ns::num_types][td_ns::num_types]; +static int atan2_output_id_table[td_ns::num_types][td_ns::num_types]; + +static binary_strided_impl_fn_ptr_t + 
atan2_strided_dispatch_table[td_ns::num_types][td_ns::num_types]; + +void populate_atan2_dispatch_tables(void) +{ + using namespace td_ns; + namespace fn_ns = atan2_fn_ns; + + // which input types are supported, and what is the type of the result + using fn_ns::Atan2TypeMapFactory; + DispatchTableBuilder dtb1; + dtb1.populate_dispatch_table(atan2_output_id_table); + + // function pointers for operation on general strided arrays + using fn_ns::Atan2StridedFactory; + DispatchTableBuilder + dtb2; + dtb2.populate_dispatch_table(atan2_strided_dispatch_table); + + // function pointers for operation on contiguous inputs and output + using fn_ns::Atan2ContigFactory; + DispatchTableBuilder + dtb3; + dtb3.populate_dispatch_table(atan2_contig_dispatch_table); +}; + +} // namespace impl + +void init_atan2(py::module_ m) +{ + using arrayT = dpctl::tensor::usm_ndarray; + using event_vecT = std::vector; + { + impl::populate_atan2_dispatch_tables(); + using impl::atan2_contig_dispatch_table; + using impl::atan2_output_id_table; + using impl::atan2_strided_dispatch_table; + + auto atan2_pyapi = [&](const arrayT &src1, const arrayT &src2, + const arrayT &dst, sycl::queue &exec_q, + const event_vecT &depends = {}) { + return py_binary_ufunc( + src1, src2, dst, exec_q, depends, atan2_output_id_table, + // function pointers to handle operation on contiguous arrays + // (pointers may be nullptr) + atan2_contig_dispatch_table, + // function pointers to handle operation on strided arrays (most + // general case) + atan2_strided_dispatch_table, + // function pointers to handle operation of c-contig matrix and + // c-contig row with broadcasting (may be nullptr) + td_ns::NullPtrTable< + binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t>{}, + // function pointers to handle operation of c-contig matrix and + // c-contig row with broadcasting (may be nullptr) + td_ns::NullPtrTable< + binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t>{}); + }; + auto atan2_result_type_pyapi = 
[&](const py::dtype &dtype1, + const py::dtype &dtype2) { + return py_binary_ufunc_result_type(dtype1, dtype2, + atan2_output_id_table); + }; + m.def("_atan2", atan2_pyapi, "", py::arg("src1"), py::arg("src2"), + py::arg("dst"), py::arg("sycl_queue"), + py::arg("depends") = py::list()); + m.def("_atan2_result_type", atan2_result_type_pyapi, ""); + } +} + +} // namespace dpctl::tensor::py_internal diff --git a/dpctl_ext/tensor/libtensor/source/elementwise_functions/atan2.hpp b/dpctl_ext/tensor/libtensor/source/elementwise_functions/atan2.hpp new file mode 100644 index 00000000000..5bdf9b74db2 --- /dev/null +++ b/dpctl_ext/tensor/libtensor/source/elementwise_functions/atan2.hpp @@ -0,0 +1,46 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** +// +//===---------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_elementwise_impl +/// extension, specifically functions for elementwise operations. +//===---------------------------------------------------------------------===// + +#pragma once +#include + +namespace py = pybind11; + +namespace dpctl::tensor::py_internal +{ + +extern void init_atan2(py::module_ m); + +} // namespace dpctl::tensor::py_internal diff --git a/dpctl_ext/tensor/libtensor/source/elementwise_functions/bitwise_and.cpp b/dpctl_ext/tensor/libtensor/source/elementwise_functions/bitwise_and.cpp new file mode 100644 index 00000000000..ec347355938 --- /dev/null +++ b/dpctl_ext/tensor/libtensor/source/elementwise_functions/bitwise_and.cpp @@ -0,0 +1,205 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. 
+// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** +// +//===---------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_elementwise_impl +/// extension, specifically functions for elementwise operations. 
+//===---------------------------------------------------------------------===// + +#include + +#include + +#include "dpnp4pybind11.hpp" +#include +#include +#include + +#include "bitwise_and.hpp" +#include "elementwise_functions.hpp" +#include "utils/type_dispatch.hpp" + +#include "kernels/elementwise_functions/bitwise_and.hpp" +#include "kernels/elementwise_functions/common.hpp" +#include "kernels/elementwise_functions/common_inplace.hpp" + +namespace dpctl::tensor::py_internal +{ + +namespace py = pybind11; +namespace td_ns = dpctl::tensor::type_dispatch; + +namespace ew_cmn_ns = dpctl::tensor::kernels::elementwise_common; +using ew_cmn_ns::binary_contig_impl_fn_ptr_t; +using ew_cmn_ns::binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t; +using ew_cmn_ns::binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t; +using ew_cmn_ns::binary_strided_impl_fn_ptr_t; + +using ew_cmn_ns::binary_inplace_contig_impl_fn_ptr_t; +using ew_cmn_ns::binary_inplace_row_matrix_broadcast_impl_fn_ptr_t; +using ew_cmn_ns::binary_inplace_strided_impl_fn_ptr_t; + +// B03: ===== BITWISE_AND (x1, x2) +namespace impl +{ +namespace bitwise_and_fn_ns = dpctl::tensor::kernels::bitwise_and; + +static binary_contig_impl_fn_ptr_t + bitwise_and_contig_dispatch_table[td_ns::num_types][td_ns::num_types]; + +static int bitwise_and_output_id_table[td_ns::num_types][td_ns::num_types]; +static int bitwise_and_inplace_output_id_table[td_ns::num_types] + [td_ns::num_types]; + +static binary_strided_impl_fn_ptr_t + bitwise_and_strided_dispatch_table[td_ns::num_types][td_ns::num_types]; + +static binary_inplace_contig_impl_fn_ptr_t + bitwise_and_inplace_contig_dispatch_table[td_ns::num_types] + [td_ns::num_types]; +static binary_inplace_strided_impl_fn_ptr_t + bitwise_and_inplace_strided_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + +void populate_bitwise_and_dispatch_tables(void) +{ + using namespace td_ns; + namespace fn_ns = bitwise_and_fn_ns; + + // which input types are supported, and 
what is the type of the result + using fn_ns::BitwiseAndTypeMapFactory; + DispatchTableBuilder dtb1; + dtb1.populate_dispatch_table(bitwise_and_output_id_table); + + // function pointers for operation on general strided arrays + using fn_ns::BitwiseAndStridedFactory; + DispatchTableBuilder + dtb2; + dtb2.populate_dispatch_table(bitwise_and_strided_dispatch_table); + + // function pointers for operation on contiguous inputs and output + using fn_ns::BitwiseAndContigFactory; + DispatchTableBuilder + dtb3; + dtb3.populate_dispatch_table(bitwise_and_contig_dispatch_table); + + // function pointers for inplace operation on general strided arrays + using fn_ns::BitwiseAndInplaceStridedFactory; + DispatchTableBuilder + dtb4; + dtb4.populate_dispatch_table(bitwise_and_inplace_strided_dispatch_table); + + // function pointers for inplace operation on contiguous inputs and output + using fn_ns::BitwiseAndInplaceContigFactory; + DispatchTableBuilder + dtb5; + dtb5.populate_dispatch_table(bitwise_and_inplace_contig_dispatch_table); + + // which types are supported by the in-place kernels + using fn_ns::BitwiseAndInplaceTypeMapFactory; + DispatchTableBuilder dtb6; + dtb6.populate_dispatch_table(bitwise_and_inplace_output_id_table); +}; + +} // namespace impl + +void init_bitwise_and(py::module_ m) +{ + using arrayT = dpctl::tensor::usm_ndarray; + using event_vecT = std::vector; + { + impl::populate_bitwise_and_dispatch_tables(); + using impl::bitwise_and_contig_dispatch_table; + using impl::bitwise_and_output_id_table; + using impl::bitwise_and_strided_dispatch_table; + + auto bitwise_and_pyapi = [&](const arrayT &src1, const arrayT &src2, + const arrayT &dst, sycl::queue &exec_q, + const event_vecT &depends = {}) { + return py_binary_ufunc( + src1, src2, dst, exec_q, depends, bitwise_and_output_id_table, + // function pointers to handle operation on contiguous arrays + // (pointers may be nullptr) + bitwise_and_contig_dispatch_table, + // function pointers to handle operation 
on strided arrays (most + // general case) + bitwise_and_strided_dispatch_table, + // function pointers to handle operation of c-contig matrix and + // c-contig row with broadcasting (may be nullptr) + td_ns::NullPtrTable< + binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t>{}, + // function pointers to handle operation of c-contig matrix and + // c-contig row with broadcasting (may be nullptr) + td_ns::NullPtrTable< + binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t>{}); + }; + auto bitwise_and_result_type_pyapi = [&](const py::dtype &dtype1, + const py::dtype &dtype2) { + return py_binary_ufunc_result_type(dtype1, dtype2, + bitwise_and_output_id_table); + }; + m.def("_bitwise_and", bitwise_and_pyapi, "", py::arg("src1"), + py::arg("src2"), py::arg("dst"), py::arg("sycl_queue"), + py::arg("depends") = py::list()); + m.def("_bitwise_and_result_type", bitwise_and_result_type_pyapi, ""); + + using impl::bitwise_and_inplace_contig_dispatch_table; + using impl::bitwise_and_inplace_output_id_table; + using impl::bitwise_and_inplace_strided_dispatch_table; + + auto bitwise_and_inplace_pyapi = [&](const arrayT &src, + const arrayT &dst, + sycl::queue &exec_q, + const event_vecT &depends = {}) { + return py_binary_inplace_ufunc( + src, dst, exec_q, depends, bitwise_and_inplace_output_id_table, + // function pointers to handle inplace operation on + // contiguous arrays (pointers may be nullptr) + bitwise_and_inplace_contig_dispatch_table, + // function pointers to handle inplace operation on strided + // arrays (most general case) + bitwise_and_inplace_strided_dispatch_table, + // function pointers to handle inplace operation on + // c-contig matrix with c-contig row with broadcasting + // (may be nullptr) + td_ns::NullPtrTable< + binary_inplace_row_matrix_broadcast_impl_fn_ptr_t>{}); + }; + m.def("_bitwise_and_inplace", bitwise_and_inplace_pyapi, "", + py::arg("lhs"), py::arg("rhs"), py::arg("sycl_queue"), + py::arg("depends") = py::list()); + } +} + +} // 
namespace dpctl::tensor::py_internal diff --git a/dpctl_ext/tensor/libtensor/source/elementwise_functions/bitwise_and.hpp b/dpctl_ext/tensor/libtensor/source/elementwise_functions/bitwise_and.hpp new file mode 100644 index 00000000000..19f29ae8822 --- /dev/null +++ b/dpctl_ext/tensor/libtensor/source/elementwise_functions/bitwise_and.hpp @@ -0,0 +1,46 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** +// +//===---------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_elementwise_impl +/// extension, specifically functions for elementwise operations. +//===---------------------------------------------------------------------===// + +#pragma once +#include + +namespace py = pybind11; + +namespace dpctl::tensor::py_internal +{ + +extern void init_bitwise_and(py::module_ m); + +} // namespace dpctl::tensor::py_internal diff --git a/dpctl_ext/tensor/libtensor/source/elementwise_functions/bitwise_left_shift.cpp b/dpctl_ext/tensor/libtensor/source/elementwise_functions/bitwise_left_shift.cpp new file mode 100644 index 00000000000..eb0f98d2bb4 --- /dev/null +++ b/dpctl_ext/tensor/libtensor/source/elementwise_functions/bitwise_left_shift.cpp @@ -0,0 +1,215 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. 
+// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** +// +//===---------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_elementwise_impl +/// extension, specifically functions for elementwise operations. 
+//===---------------------------------------------------------------------===// + +#include + +#include + +#include "dpnp4pybind11.hpp" +#include +#include +#include + +#include "bitwise_left_shift.hpp" +#include "elementwise_functions.hpp" +#include "utils/type_dispatch.hpp" + +#include "kernels/elementwise_functions/bitwise_left_shift.hpp" +#include "kernels/elementwise_functions/common.hpp" +#include "kernels/elementwise_functions/common_inplace.hpp" + +namespace dpctl::tensor::py_internal +{ + +namespace py = pybind11; +namespace td_ns = dpctl::tensor::type_dispatch; + +namespace ew_cmn_ns = dpctl::tensor::kernels::elementwise_common; +using ew_cmn_ns::binary_contig_impl_fn_ptr_t; +using ew_cmn_ns::binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t; +using ew_cmn_ns::binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t; +using ew_cmn_ns::binary_strided_impl_fn_ptr_t; + +using ew_cmn_ns::binary_inplace_contig_impl_fn_ptr_t; +using ew_cmn_ns::binary_inplace_row_matrix_broadcast_impl_fn_ptr_t; +using ew_cmn_ns::binary_inplace_strided_impl_fn_ptr_t; + +// B04: ===== BITWISE_LEFT_SHIFT (x1, x2) +namespace impl +{ +namespace bitwise_left_shift_fn_ns = dpctl::tensor::kernels::bitwise_left_shift; + +static binary_contig_impl_fn_ptr_t + bitwise_left_shift_contig_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + +static int bitwise_left_shift_output_id_table[td_ns::num_types] + [td_ns::num_types]; +static int bitwise_left_shift_inplace_output_id_table[td_ns::num_types] + [td_ns::num_types]; + +static binary_strided_impl_fn_ptr_t + bitwise_left_shift_strided_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + +static binary_inplace_contig_impl_fn_ptr_t + bitwise_left_shift_inplace_contig_dispatch_table[td_ns::num_types] + [td_ns::num_types]; +static binary_inplace_strided_impl_fn_ptr_t + bitwise_left_shift_inplace_strided_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + +void populate_bitwise_left_shift_dispatch_tables(void) +{ + using 
namespace td_ns; + namespace fn_ns = bitwise_left_shift_fn_ns; + + // which input types are supported, and what is the type of the result + using fn_ns::BitwiseLeftShiftTypeMapFactory; + DispatchTableBuilder dtb1; + dtb1.populate_dispatch_table(bitwise_left_shift_output_id_table); + + // function pointers for operation on general strided arrays + using fn_ns::BitwiseLeftShiftStridedFactory; + DispatchTableBuilder + dtb2; + dtb2.populate_dispatch_table(bitwise_left_shift_strided_dispatch_table); + + // function pointers for operation on contiguous inputs and output + using fn_ns::BitwiseLeftShiftContigFactory; + DispatchTableBuilder + dtb3; + dtb3.populate_dispatch_table(bitwise_left_shift_contig_dispatch_table); + + // function pointers for inplace operation on general strided arrays + using fn_ns::BitwiseLeftShiftInplaceStridedFactory; + DispatchTableBuilder + dtb4; + dtb4.populate_dispatch_table( + bitwise_left_shift_inplace_strided_dispatch_table); + + // function pointers for inplace operation on contiguous inputs and output + using fn_ns::BitwiseLeftShiftInplaceContigFactory; + DispatchTableBuilder + dtb5; + dtb5.populate_dispatch_table( + bitwise_left_shift_inplace_contig_dispatch_table); + + // which types are supported by the in-place kernels + using fn_ns::BitwiseLeftShiftInplaceTypeMapFactory; + DispatchTableBuilder + dtb6; + dtb6.populate_dispatch_table(bitwise_left_shift_inplace_output_id_table); +}; + +} // namespace impl + +void init_bitwise_left_shift(py::module_ m) +{ + using arrayT = dpctl::tensor::usm_ndarray; + using event_vecT = std::vector; + { + impl::populate_bitwise_left_shift_dispatch_tables(); + using impl::bitwise_left_shift_contig_dispatch_table; + using impl::bitwise_left_shift_output_id_table; + using impl::bitwise_left_shift_strided_dispatch_table; + + auto bitwise_left_shift_pyapi = [&](const arrayT &src1, + const arrayT &src2, + const arrayT &dst, + sycl::queue &exec_q, + const event_vecT &depends = {}) { + return py_binary_ufunc( + 
src1, src2, dst, exec_q, depends, + bitwise_left_shift_output_id_table, + // function pointers to handle operation on contiguous arrays + // (pointers may be nullptr) + bitwise_left_shift_contig_dispatch_table, + // function pointers to handle operation on strided arrays (most + // general case) + bitwise_left_shift_strided_dispatch_table, + // function pointers to handle operation of c-contig matrix and + // c-contig row with broadcasting (may be nullptr) + td_ns::NullPtrTable< + binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t>{}, + // function pointers to handle operation of c-contig matrix and + // c-contig row with broadcasting (may be nullptr) + td_ns::NullPtrTable< + binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t>{}); + }; + auto bitwise_left_shift_result_type_pyapi = + [&](const py::dtype &dtype1, const py::dtype &dtype2) { + return py_binary_ufunc_result_type( + dtype1, dtype2, bitwise_left_shift_output_id_table); + }; + m.def("_bitwise_left_shift", bitwise_left_shift_pyapi, "", + py::arg("src1"), py::arg("src2"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); + m.def("_bitwise_left_shift_result_type", + bitwise_left_shift_result_type_pyapi, ""); + + using impl::bitwise_left_shift_inplace_contig_dispatch_table; + using impl::bitwise_left_shift_inplace_output_id_table; + using impl::bitwise_left_shift_inplace_strided_dispatch_table; + + auto bitwise_left_shift_inplace_pyapi = + [&](const arrayT &src, const arrayT &dst, sycl::queue &exec_q, + const event_vecT &depends = {}) { + return py_binary_inplace_ufunc( + src, dst, exec_q, depends, + bitwise_left_shift_inplace_output_id_table, + // function pointers to handle inplace operation on + // contiguous arrays (pointers may be nullptr) + bitwise_left_shift_inplace_contig_dispatch_table, + // function pointers to handle inplace operation on strided + // arrays (most general case) + bitwise_left_shift_inplace_strided_dispatch_table, + // function pointers to handle 
inplace operation on + // c-contig matrix with c-contig row with broadcasting + // (may be nullptr) + td_ns::NullPtrTable< + binary_inplace_row_matrix_broadcast_impl_fn_ptr_t>{}); + }; + m.def("_bitwise_left_shift_inplace", bitwise_left_shift_inplace_pyapi, + "", py::arg("lhs"), py::arg("rhs"), py::arg("sycl_queue"), + py::arg("depends") = py::list()); + } +} + +} // namespace dpctl::tensor::py_internal diff --git a/dpctl_ext/tensor/libtensor/source/elementwise_functions/bitwise_left_shift.hpp b/dpctl_ext/tensor/libtensor/source/elementwise_functions/bitwise_left_shift.hpp new file mode 100644 index 00000000000..49a7947d98c --- /dev/null +++ b/dpctl_ext/tensor/libtensor/source/elementwise_functions/bitwise_left_shift.hpp @@ -0,0 +1,46 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** +// +//===---------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_elementwise_impl +/// extension, specifically functions for elementwise operations. +//===---------------------------------------------------------------------===// + +#pragma once +#include + +namespace py = pybind11; + +namespace dpctl::tensor::py_internal +{ + +extern void init_bitwise_left_shift(py::module_ m); + +} // namespace dpctl::tensor::py_internal diff --git a/dpctl_ext/tensor/libtensor/source/elementwise_functions/bitwise_or.cpp b/dpctl_ext/tensor/libtensor/source/elementwise_functions/bitwise_or.cpp new file mode 100644 index 00000000000..a9bd8820c15 --- /dev/null +++ b/dpctl_ext/tensor/libtensor/source/elementwise_functions/bitwise_or.cpp @@ -0,0 +1,205 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. 
+// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** +// +//===---------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_elementwise_impl +/// extension, specifically functions for elementwise operations. 
+//===---------------------------------------------------------------------===// + +#include + +#include + +#include "dpnp4pybind11.hpp" +#include +#include +#include + +#include "bitwise_or.hpp" +#include "elementwise_functions.hpp" +#include "utils/type_dispatch.hpp" + +#include "kernels/elementwise_functions/bitwise_or.hpp" +#include "kernels/elementwise_functions/common.hpp" +#include "kernels/elementwise_functions/common_inplace.hpp" + +namespace dpctl::tensor::py_internal +{ + +namespace py = pybind11; +namespace td_ns = dpctl::tensor::type_dispatch; + +namespace ew_cmn_ns = dpctl::tensor::kernels::elementwise_common; +using ew_cmn_ns::binary_contig_impl_fn_ptr_t; +using ew_cmn_ns::binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t; +using ew_cmn_ns::binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t; +using ew_cmn_ns::binary_strided_impl_fn_ptr_t; + +using ew_cmn_ns::binary_inplace_contig_impl_fn_ptr_t; +using ew_cmn_ns::binary_inplace_row_matrix_broadcast_impl_fn_ptr_t; +using ew_cmn_ns::binary_inplace_strided_impl_fn_ptr_t; + +// B05: ===== BITWISE_OR (x1, x2) +namespace impl +{ +namespace bitwise_or_fn_ns = dpctl::tensor::kernels::bitwise_or; + +static binary_contig_impl_fn_ptr_t + bitwise_or_contig_dispatch_table[td_ns::num_types][td_ns::num_types]; + +static int bitwise_or_output_id_table[td_ns::num_types][td_ns::num_types]; +static int bitwise_or_inplace_output_id_table[td_ns::num_types] + [td_ns::num_types]; + +static binary_strided_impl_fn_ptr_t + bitwise_or_strided_dispatch_table[td_ns::num_types][td_ns::num_types]; + +static binary_inplace_contig_impl_fn_ptr_t + bitwise_or_inplace_contig_dispatch_table[td_ns::num_types] + [td_ns::num_types]; +static binary_inplace_strided_impl_fn_ptr_t + bitwise_or_inplace_strided_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + +void populate_bitwise_or_dispatch_tables(void) +{ + using namespace td_ns; + namespace fn_ns = bitwise_or_fn_ns; + + // which input types are supported, and what is the 
type of the result + using fn_ns::BitwiseOrTypeMapFactory; + DispatchTableBuilder dtb1; + dtb1.populate_dispatch_table(bitwise_or_output_id_table); + + // function pointers for operation on general strided arrays + using fn_ns::BitwiseOrStridedFactory; + DispatchTableBuilder + dtb2; + dtb2.populate_dispatch_table(bitwise_or_strided_dispatch_table); + + // function pointers for operation on contiguous inputs and output + using fn_ns::BitwiseOrContigFactory; + DispatchTableBuilder + dtb3; + dtb3.populate_dispatch_table(bitwise_or_contig_dispatch_table); + + // function pointers for inplace operation on general strided arrays + using fn_ns::BitwiseOrInplaceStridedFactory; + DispatchTableBuilder + dtb4; + dtb4.populate_dispatch_table(bitwise_or_inplace_strided_dispatch_table); + + // function pointers for inplace operation on contiguous inputs and output + using fn_ns::BitwiseOrInplaceContigFactory; + DispatchTableBuilder + dtb5; + dtb5.populate_dispatch_table(bitwise_or_inplace_contig_dispatch_table); + + // which types are supported by the in-place kernels + using fn_ns::BitwiseOrInplaceTypeMapFactory; + DispatchTableBuilder dtb6; + dtb6.populate_dispatch_table(bitwise_or_inplace_output_id_table); +}; + +} // namespace impl + +void init_bitwise_or(py::module_ m) +{ + using arrayT = dpctl::tensor::usm_ndarray; + using event_vecT = std::vector; + { + impl::populate_bitwise_or_dispatch_tables(); + using impl::bitwise_or_contig_dispatch_table; + using impl::bitwise_or_output_id_table; + using impl::bitwise_or_strided_dispatch_table; + + auto bitwise_or_pyapi = [&](const arrayT &src1, const arrayT &src2, + const arrayT &dst, sycl::queue &exec_q, + const event_vecT &depends = {}) { + return py_binary_ufunc( + src1, src2, dst, exec_q, depends, bitwise_or_output_id_table, + // function pointers to handle operation on contiguous arrays + // (pointers may be nullptr) + bitwise_or_contig_dispatch_table, + // function pointers to handle operation on strided arrays (most + // 
general case) + bitwise_or_strided_dispatch_table, + // function pointers to handle operation of c-contig matrix and + // c-contig row with broadcasting (may be nullptr) + td_ns::NullPtrTable< + binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t>{}, + // function pointers to handle operation of c-contig matrix and + // c-contig row with broadcasting (may be nullptr) + td_ns::NullPtrTable< + binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t>{}); + }; + auto bitwise_or_result_type_pyapi = [&](const py::dtype &dtype1, + const py::dtype &dtype2) { + return py_binary_ufunc_result_type(dtype1, dtype2, + bitwise_or_output_id_table); + }; + m.def("_bitwise_or", bitwise_or_pyapi, "", py::arg("src1"), + py::arg("src2"), py::arg("dst"), py::arg("sycl_queue"), + py::arg("depends") = py::list()); + m.def("_bitwise_or_result_type", bitwise_or_result_type_pyapi, ""); + + using impl::bitwise_or_inplace_contig_dispatch_table; + using impl::bitwise_or_inplace_output_id_table; + using impl::bitwise_or_inplace_strided_dispatch_table; + + auto bitwise_or_inplace_pyapi = [&](const arrayT &src, + const arrayT &dst, + sycl::queue &exec_q, + const event_vecT &depends = {}) { + return py_binary_inplace_ufunc( + src, dst, exec_q, depends, bitwise_or_inplace_output_id_table, + // function pointers to handle inplace operation on + // contiguous arrays (pointers may be nullptr) + bitwise_or_inplace_contig_dispatch_table, + // function pointers to handle inplace operation on strided + // arrays (most general case) + bitwise_or_inplace_strided_dispatch_table, + // function pointers to handle inplace operation on + // c-contig matrix with c-contig row with broadcasting + // (may be nullptr) + td_ns::NullPtrTable< + binary_inplace_row_matrix_broadcast_impl_fn_ptr_t>{}); + }; + m.def("_bitwise_or_inplace", bitwise_or_inplace_pyapi, "", + py::arg("lhs"), py::arg("rhs"), py::arg("sycl_queue"), + py::arg("depends") = py::list()); + } +} + +} // namespace dpctl::tensor::py_internal diff --git 
a/dpctl_ext/tensor/libtensor/source/elementwise_functions/bitwise_or.hpp b/dpctl_ext/tensor/libtensor/source/elementwise_functions/bitwise_or.hpp new file mode 100644 index 00000000000..1e24caa5442 --- /dev/null +++ b/dpctl_ext/tensor/libtensor/source/elementwise_functions/bitwise_or.hpp @@ -0,0 +1,46 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. 
+//***************************************************************************** +// +//===---------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_elementwise_impl +/// extension, specifically functions for elementwise operations. +//===---------------------------------------------------------------------===// + +#pragma once +#include + +namespace py = pybind11; + +namespace dpctl::tensor::py_internal +{ + +extern void init_bitwise_or(py::module_ m); + +} // namespace dpctl::tensor::py_internal diff --git a/dpctl_ext/tensor/libtensor/source/elementwise_functions/bitwise_right_shift.cpp b/dpctl_ext/tensor/libtensor/source/elementwise_functions/bitwise_right_shift.cpp new file mode 100644 index 00000000000..09c66d9f8b5 --- /dev/null +++ b/dpctl_ext/tensor/libtensor/source/elementwise_functions/bitwise_right_shift.cpp @@ -0,0 +1,216 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. 
+// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** +// +//===---------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_elementwise_impl +/// extension, specifically functions for elementwise operations. 
+//===---------------------------------------------------------------------===// + +#include + +#include + +#include "dpnp4pybind11.hpp" +#include +#include +#include + +#include "bitwise_right_shift.hpp" +#include "elementwise_functions.hpp" +#include "utils/type_dispatch.hpp" + +#include "kernels/elementwise_functions/bitwise_right_shift.hpp" +#include "kernels/elementwise_functions/common.hpp" +#include "kernels/elementwise_functions/common_inplace.hpp" + +namespace dpctl::tensor::py_internal +{ + +namespace py = pybind11; +namespace td_ns = dpctl::tensor::type_dispatch; + +namespace ew_cmn_ns = dpctl::tensor::kernels::elementwise_common; +using ew_cmn_ns::binary_contig_impl_fn_ptr_t; +using ew_cmn_ns::binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t; +using ew_cmn_ns::binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t; +using ew_cmn_ns::binary_strided_impl_fn_ptr_t; + +using ew_cmn_ns::binary_inplace_contig_impl_fn_ptr_t; +using ew_cmn_ns::binary_inplace_row_matrix_broadcast_impl_fn_ptr_t; +using ew_cmn_ns::binary_inplace_strided_impl_fn_ptr_t; + +// B06: ===== BITWISE_RIGHT_SHIFT (x1, x2) +namespace impl +{ +namespace bitwise_right_shift_fn_ns = + dpctl::tensor::kernels::bitwise_right_shift; + +static binary_contig_impl_fn_ptr_t + bitwise_right_shift_contig_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + +static int bitwise_right_shift_output_id_table[td_ns::num_types] + [td_ns::num_types]; +static int bitwise_right_shift_inplace_output_id_table[td_ns::num_types] + [td_ns::num_types]; + +static binary_strided_impl_fn_ptr_t + bitwise_right_shift_strided_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + +static binary_inplace_contig_impl_fn_ptr_t + bitwise_right_shift_inplace_contig_dispatch_table[td_ns::num_types] + [td_ns::num_types]; +static binary_inplace_strided_impl_fn_ptr_t + bitwise_right_shift_inplace_strided_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + +void populate_bitwise_right_shift_dispatch_tables(void) +{ + 
using namespace td_ns; + namespace fn_ns = bitwise_right_shift_fn_ns; + + // which input types are supported, and what is the type of the result + using fn_ns::BitwiseRightShiftTypeMapFactory; + DispatchTableBuilder dtb1; + dtb1.populate_dispatch_table(bitwise_right_shift_output_id_table); + + // function pointers for operation on general strided arrays + using fn_ns::BitwiseRightShiftStridedFactory; + DispatchTableBuilder + dtb2; + dtb2.populate_dispatch_table(bitwise_right_shift_strided_dispatch_table); + + // function pointers for operation on contiguous inputs and output + using fn_ns::BitwiseRightShiftContigFactory; + DispatchTableBuilder + dtb3; + dtb3.populate_dispatch_table(bitwise_right_shift_contig_dispatch_table); + + // function pointers for inplace operation on general strided arrays + using fn_ns::BitwiseRightShiftInplaceStridedFactory; + DispatchTableBuilder + dtb4; + dtb4.populate_dispatch_table( + bitwise_right_shift_inplace_strided_dispatch_table); + + // function pointers for inplace operation on contiguous inputs and output + using fn_ns::BitwiseRightShiftInplaceContigFactory; + DispatchTableBuilder + dtb5; + dtb5.populate_dispatch_table( + bitwise_right_shift_inplace_contig_dispatch_table); + + // which types are supported by the in-place kernels + using fn_ns::BitwiseRightShiftInplaceTypeMapFactory; + DispatchTableBuilder + dtb6; + dtb6.populate_dispatch_table(bitwise_right_shift_inplace_output_id_table); +}; + +} // namespace impl + +void init_bitwise_right_shift(py::module_ m) +{ + using arrayT = dpctl::tensor::usm_ndarray; + using event_vecT = std::vector; + { + impl::populate_bitwise_right_shift_dispatch_tables(); + using impl::bitwise_right_shift_contig_dispatch_table; + using impl::bitwise_right_shift_output_id_table; + using impl::bitwise_right_shift_strided_dispatch_table; + + auto bitwise_right_shift_pyapi = [&](const arrayT &src1, + const arrayT &src2, + const arrayT &dst, + sycl::queue &exec_q, + const event_vecT &depends = {}) { + 
return py_binary_ufunc( + src1, src2, dst, exec_q, depends, + bitwise_right_shift_output_id_table, + // function pointers to handle operation on contiguous arrays + // (pointers may be nullptr) + bitwise_right_shift_contig_dispatch_table, + // function pointers to handle operation on strided arrays (most + // general case) + bitwise_right_shift_strided_dispatch_table, + // function pointers to handle operation of c-contig matrix and + // c-contig row with broadcasting (may be nullptr) + td_ns::NullPtrTable< + binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t>{}, + // function pointers to handle operation of c-contig matrix and + // c-contig row with broadcasting (may be nullptr) + td_ns::NullPtrTable< + binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t>{}); + }; + auto bitwise_right_shift_result_type_pyapi = + [&](const py::dtype &dtype1, const py::dtype &dtype2) { + return py_binary_ufunc_result_type( + dtype1, dtype2, bitwise_right_shift_output_id_table); + }; + m.def("_bitwise_right_shift", bitwise_right_shift_pyapi, "", + py::arg("src1"), py::arg("src2"), py::arg("dst"), + py::arg("sycl_queue"), py::arg("depends") = py::list()); + m.def("_bitwise_right_shift_result_type", + bitwise_right_shift_result_type_pyapi, ""); + + using impl::bitwise_right_shift_inplace_contig_dispatch_table; + using impl::bitwise_right_shift_inplace_output_id_table; + using impl::bitwise_right_shift_inplace_strided_dispatch_table; + + auto bitwise_right_shift_inplace_pyapi = + [&](const arrayT &src, const arrayT &dst, sycl::queue &exec_q, + const event_vecT &depends = {}) { + return py_binary_inplace_ufunc( + src, dst, exec_q, depends, + bitwise_right_shift_inplace_output_id_table, + // function pointers to handle inplace operation on + // contiguous arrays (pointers may be nullptr) + bitwise_right_shift_inplace_contig_dispatch_table, + // function pointers to handle inplace operation on strided + // arrays (most general case) + 
bitwise_right_shift_inplace_strided_dispatch_table, + // function pointers to handle inplace operation on + // c-contig matrix with c-contig row with broadcasting + // (may be nullptr) + td_ns::NullPtrTable< + binary_inplace_row_matrix_broadcast_impl_fn_ptr_t>{}); + }; + m.def("_bitwise_right_shift_inplace", bitwise_right_shift_inplace_pyapi, + "", py::arg("lhs"), py::arg("rhs"), py::arg("sycl_queue"), + py::arg("depends") = py::list()); + } +} + +} // namespace dpctl::tensor::py_internal diff --git a/dpctl_ext/tensor/libtensor/source/elementwise_functions/bitwise_right_shift.hpp b/dpctl_ext/tensor/libtensor/source/elementwise_functions/bitwise_right_shift.hpp new file mode 100644 index 00000000000..aeb24d73b2f --- /dev/null +++ b/dpctl_ext/tensor/libtensor/source/elementwise_functions/bitwise_right_shift.hpp @@ -0,0 +1,46 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** +// +//===---------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_elementwise_impl +/// extension, specifically functions for elementwise operations. +//===---------------------------------------------------------------------===// + +#pragma once +#include + +namespace py = pybind11; + +namespace dpctl::tensor::py_internal +{ + +extern void init_bitwise_right_shift(py::module_ m); + +} // namespace dpctl::tensor::py_internal diff --git a/dpctl_ext/tensor/libtensor/source/elementwise_functions/bitwise_xor.cpp b/dpctl_ext/tensor/libtensor/source/elementwise_functions/bitwise_xor.cpp new file mode 100644 index 00000000000..0f9447a82b5 --- /dev/null +++ b/dpctl_ext/tensor/libtensor/source/elementwise_functions/bitwise_xor.cpp @@ -0,0 +1,205 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. 
+// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** +// +//===---------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_elementwise_impl +/// extension, specifically functions for elementwise operations. 
+//===---------------------------------------------------------------------===// + +#include + +#include + +#include "dpnp4pybind11.hpp" +#include +#include +#include + +#include "bitwise_xor.hpp" +#include "elementwise_functions.hpp" +#include "utils/type_dispatch.hpp" + +#include "kernels/elementwise_functions/bitwise_xor.hpp" +#include "kernels/elementwise_functions/common.hpp" +#include "kernels/elementwise_functions/common_inplace.hpp" + +namespace dpctl::tensor::py_internal +{ + +namespace py = pybind11; +namespace td_ns = dpctl::tensor::type_dispatch; + +namespace ew_cmn_ns = dpctl::tensor::kernels::elementwise_common; +using ew_cmn_ns::binary_contig_impl_fn_ptr_t; +using ew_cmn_ns::binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t; +using ew_cmn_ns::binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t; +using ew_cmn_ns::binary_strided_impl_fn_ptr_t; + +using ew_cmn_ns::binary_inplace_contig_impl_fn_ptr_t; +using ew_cmn_ns::binary_inplace_row_matrix_broadcast_impl_fn_ptr_t; +using ew_cmn_ns::binary_inplace_strided_impl_fn_ptr_t; + +// B07: ===== BITWISE_XOR (x1, x2) +namespace impl +{ +namespace bitwise_xor_fn_ns = dpctl::tensor::kernels::bitwise_xor; + +static binary_contig_impl_fn_ptr_t + bitwise_xor_contig_dispatch_table[td_ns::num_types][td_ns::num_types]; + +static int bitwise_xor_output_id_table[td_ns::num_types][td_ns::num_types]; +static int bitwise_xor_inplace_output_id_table[td_ns::num_types] + [td_ns::num_types]; + +static binary_strided_impl_fn_ptr_t + bitwise_xor_strided_dispatch_table[td_ns::num_types][td_ns::num_types]; + +static binary_inplace_contig_impl_fn_ptr_t + bitwise_xor_inplace_contig_dispatch_table[td_ns::num_types] + [td_ns::num_types]; +static binary_inplace_strided_impl_fn_ptr_t + bitwise_xor_inplace_strided_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + +void populate_bitwise_xor_dispatch_tables(void) +{ + using namespace td_ns; + namespace fn_ns = bitwise_xor_fn_ns; + + // which input types are supported, and 
what is the type of the result + using fn_ns::BitwiseXorTypeMapFactory; + DispatchTableBuilder dtb1; + dtb1.populate_dispatch_table(bitwise_xor_output_id_table); + + // function pointers for operation on general strided arrays + using fn_ns::BitwiseXorStridedFactory; + DispatchTableBuilder + dtb2; + dtb2.populate_dispatch_table(bitwise_xor_strided_dispatch_table); + + // function pointers for operation on contiguous inputs and output + using fn_ns::BitwiseXorContigFactory; + DispatchTableBuilder + dtb3; + dtb3.populate_dispatch_table(bitwise_xor_contig_dispatch_table); + + // function pointers for inplace operation on general strided arrays + using fn_ns::BitwiseXorInplaceStridedFactory; + DispatchTableBuilder + dtb4; + dtb4.populate_dispatch_table(bitwise_xor_inplace_strided_dispatch_table); + + // function pointers for inplace operation on contiguous inputs and output + using fn_ns::BitwiseXorInplaceContigFactory; + DispatchTableBuilder + dtb5; + dtb5.populate_dispatch_table(bitwise_xor_inplace_contig_dispatch_table); + + // which types are supported by the in-place kernels + using fn_ns::BitwiseXorInplaceTypeMapFactory; + DispatchTableBuilder dtb6; + dtb6.populate_dispatch_table(bitwise_xor_inplace_output_id_table); +}; + +} // namespace impl + +void init_bitwise_xor(py::module_ m) +{ + using arrayT = dpctl::tensor::usm_ndarray; + using event_vecT = std::vector; + { + impl::populate_bitwise_xor_dispatch_tables(); + using impl::bitwise_xor_contig_dispatch_table; + using impl::bitwise_xor_output_id_table; + using impl::bitwise_xor_strided_dispatch_table; + + auto bitwise_xor_pyapi = [&](const arrayT &src1, const arrayT &src2, + const arrayT &dst, sycl::queue &exec_q, + const event_vecT &depends = {}) { + return py_binary_ufunc( + src1, src2, dst, exec_q, depends, bitwise_xor_output_id_table, + // function pointers to handle operation on contiguous arrays + // (pointers may be nullptr) + bitwise_xor_contig_dispatch_table, + // function pointers to handle operation 
on strided arrays (most + // general case) + bitwise_xor_strided_dispatch_table, + // function pointers to handle operation of c-contig matrix and + // c-contig row with broadcasting (may be nullptr) + td_ns::NullPtrTable< + binary_contig_matrix_contig_row_broadcast_impl_fn_ptr_t>{}, + // function pointers to handle operation of c-contig matrix and + // c-contig row with broadcasting (may be nullptr) + td_ns::NullPtrTable< + binary_contig_row_contig_matrix_broadcast_impl_fn_ptr_t>{}); + }; + auto bitwise_xor_result_type_pyapi = [&](const py::dtype &dtype1, + const py::dtype &dtype2) { + return py_binary_ufunc_result_type(dtype1, dtype2, + bitwise_xor_output_id_table); + }; + m.def("_bitwise_xor", bitwise_xor_pyapi, "", py::arg("src1"), + py::arg("src2"), py::arg("dst"), py::arg("sycl_queue"), + py::arg("depends") = py::list()); + m.def("_bitwise_xor_result_type", bitwise_xor_result_type_pyapi, ""); + + using impl::bitwise_xor_inplace_contig_dispatch_table; + using impl::bitwise_xor_inplace_output_id_table; + using impl::bitwise_xor_inplace_strided_dispatch_table; + + auto bitwise_xor_inplace_pyapi = [&](const arrayT &src, + const arrayT &dst, + sycl::queue &exec_q, + const event_vecT &depends = {}) { + return py_binary_inplace_ufunc( + src, dst, exec_q, depends, bitwise_xor_inplace_output_id_table, + // function pointers to handle inplace operation on + // contiguous arrays (pointers may be nullptr) + bitwise_xor_inplace_contig_dispatch_table, + // function pointers to handle inplace operation on strided + // arrays (most general case) + bitwise_xor_inplace_strided_dispatch_table, + // function pointers to handle inplace operation on + // c-contig matrix with c-contig row with broadcasting + // (may be nullptr) + td_ns::NullPtrTable< + binary_inplace_row_matrix_broadcast_impl_fn_ptr_t>{}); + }; + m.def("_bitwise_xor_inplace", bitwise_xor_inplace_pyapi, "", + py::arg("lhs"), py::arg("rhs"), py::arg("sycl_queue"), + py::arg("depends") = py::list()); + } +} + +} // 
namespace dpctl::tensor::py_internal diff --git a/dpctl_ext/tensor/libtensor/source/elementwise_functions/bitwise_xor.hpp b/dpctl_ext/tensor/libtensor/source/elementwise_functions/bitwise_xor.hpp new file mode 100644 index 00000000000..4029574cdd7 --- /dev/null +++ b/dpctl_ext/tensor/libtensor/source/elementwise_functions/bitwise_xor.hpp @@ -0,0 +1,46 @@ +//***************************************************************************** +// Copyright (c) 2026, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// - Neither the name of the copyright holder nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. +//***************************************************************************** +// +//===---------------------------------------------------------------------===// +/// +/// \file +/// This file defines functions of dpctl.tensor._tensor_elementwise_impl +/// extension, specifically functions for elementwise operations. +//===---------------------------------------------------------------------===// + +#pragma once +#include + +namespace py = pybind11; + +namespace dpctl::tensor::py_internal +{ + +extern void init_bitwise_xor(py::module_ m); + +} // namespace dpctl::tensor::py_internal diff --git a/dpctl_ext/tensor/libtensor/source/elementwise_functions/elementwise_common.cpp b/dpctl_ext/tensor/libtensor/source/elementwise_functions/elementwise_common.cpp index 144e39be252..e4e730a1da6 100644 --- a/dpctl_ext/tensor/libtensor/source/elementwise_functions/elementwise_common.cpp +++ b/dpctl_ext/tensor/libtensor/source/elementwise_functions/elementwise_common.cpp @@ -38,19 +38,19 @@ #include "abs.hpp" #include "acos.hpp" #include "acosh.hpp" -// #include "add.hpp" +#include "add.hpp" #include "angle.hpp" #include "asin.hpp" #include "asinh.hpp" #include "atan.hpp" -// #include "atan2.hpp" +#include "atan2.hpp" #include "atanh.hpp" -// #include "bitwise_and.hpp" +#include "bitwise_and.hpp" #include "bitwise_invert.hpp" -// #include "bitwise_left_shift.hpp" -// #include "bitwise_or.hpp" -// #include "bitwise_right_shift.hpp" -// #include 
"bitwise_xor.hpp" +#include "bitwise_left_shift.hpp" +#include "bitwise_or.hpp" +#include "bitwise_right_shift.hpp" +#include "bitwise_xor.hpp" #include "cbrt.hpp" #include "ceil.hpp" #include "conj.hpp" @@ -118,19 +118,19 @@ void init_elementwise_functions(py::module_ m) init_abs(m); init_acos(m); init_acosh(m); - // init_add(m); + init_add(m); init_angle(m); init_asin(m); init_asinh(m); init_atan(m); - // init_atan2(m); + init_atan2(m); init_atanh(m); - // init_bitwise_and(m); + init_bitwise_and(m); init_bitwise_invert(m); - // init_bitwise_left_shift(m); - // init_bitwise_or(m); - // init_bitwise_right_shift(m); - // init_bitwise_xor(m); + init_bitwise_left_shift(m); + init_bitwise_or(m); + init_bitwise_right_shift(m); + init_bitwise_xor(m); init_cbrt(m); init_ceil(m); init_conj(m); diff --git a/dpctl_ext/tensor/libtensor/source/elementwise_functions/elementwise_functions.hpp b/dpctl_ext/tensor/libtensor/source/elementwise_functions/elementwise_functions.hpp index 0b5d1e65c72..4f09a647c3f 100644 --- a/dpctl_ext/tensor/libtensor/source/elementwise_functions/elementwise_functions.hpp +++ b/dpctl_ext/tensor/libtensor/source/elementwise_functions/elementwise_functions.hpp @@ -287,4 +287,533 @@ py::object py_unary_ufunc_result_type(const py::dtype &input_dtype, } } +// ======================== Binary functions =========================== + +namespace +{ +template +bool isEqual(Container const &c, std::initializer_list const &l) +{ + return std::equal(std::begin(c), std::end(c), std::begin(l), std::end(l)); +} +} // namespace + +/*! 
@brief Template implementing Python API for binary elementwise + * functions */ +template +std::pair py_binary_ufunc( + const dpctl::tensor::usm_ndarray &src1, + const dpctl::tensor::usm_ndarray &src2, + const dpctl::tensor::usm_ndarray &dst, // dst = op(src1, src2), elementwise + sycl::queue &exec_q, + const std::vector depends, + // + const output_typesT &output_type_table, + const contig_dispatchT &contig_dispatch_table, + const strided_dispatchT &strided_dispatch_table, + const contig_matrix_row_dispatchT + &contig_matrix_row_broadcast_dispatch_table, + const contig_row_matrix_dispatchT + &contig_row_matrix_broadcast_dispatch_table) +{ + // check type_nums + int src1_typenum = src1.get_typenum(); + int src2_typenum = src2.get_typenum(); + int dst_typenum = dst.get_typenum(); + + auto array_types = td_ns::usm_ndarray_types(); + int src1_typeid = array_types.typenum_to_lookup_id(src1_typenum); + int src2_typeid = array_types.typenum_to_lookup_id(src2_typenum); + int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum); + + int output_typeid = output_type_table[src1_typeid][src2_typeid]; + + if (output_typeid != dst_typeid) { + throw py::value_error( + "Destination array has unexpected elemental data type."); + } + + // check that queues are compatible + if (!dpctl::utils::queues_are_compatible(exec_q, {src1, src2, dst})) { + throw py::value_error( + "Execution queue is not compatible with allocation queues"); + } + + dpctl::tensor::validation::CheckWritable::throw_if_not_writable(dst); + + // check shapes, broadcasting is assumed done by caller + // check that dimensions are the same + int dst_nd = dst.get_ndim(); + if (dst_nd != src1.get_ndim() || dst_nd != src2.get_ndim()) { + throw py::value_error("Array dimensions are not the same."); + } + + // check that shapes are the same + const py::ssize_t *src1_shape = src1.get_shape_raw(); + const py::ssize_t *src2_shape = src2.get_shape_raw(); + const py::ssize_t *dst_shape = dst.get_shape_raw(); + bool 
shapes_equal(true); + std::size_t src_nelems(1); + + for (int i = 0; i < dst_nd; ++i) { + src_nelems *= static_cast(src1_shape[i]); + shapes_equal = shapes_equal && (src1_shape[i] == dst_shape[i] && + src2_shape[i] == dst_shape[i]); + } + if (!shapes_equal) { + throw py::value_error("Array shapes are not the same."); + } + + // if nelems is zero, return + if (src_nelems == 0) { + return std::make_pair(sycl::event(), sycl::event()); + } + + dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(dst, src_nelems); + + auto const &overlap = dpctl::tensor::overlap::MemoryOverlap(); + auto const &same_logical_tensors = + dpctl::tensor::overlap::SameLogicalTensors(); + if ((overlap(src1, dst) && !same_logical_tensors(src1, dst)) || + (overlap(src2, dst) && !same_logical_tensors(src2, dst))) + { + throw py::value_error("Arrays index overlapping segments of memory"); + } + // check memory overlap + const char *src1_data = src1.get_data(); + const char *src2_data = src2.get_data(); + char *dst_data = dst.get_data(); + + // handle contiguous inputs + bool is_src1_c_contig = src1.is_c_contiguous(); + bool is_src1_f_contig = src1.is_f_contiguous(); + + bool is_src2_c_contig = src2.is_c_contiguous(); + bool is_src2_f_contig = src2.is_f_contiguous(); + + bool is_dst_c_contig = dst.is_c_contiguous(); + bool is_dst_f_contig = dst.is_f_contiguous(); + + bool all_c_contig = + (is_src1_c_contig && is_src2_c_contig && is_dst_c_contig); + bool all_f_contig = + (is_src1_f_contig && is_src2_f_contig && is_dst_f_contig); + + // dispatch for contiguous inputs + if (all_c_contig || all_f_contig) { + auto contig_fn = contig_dispatch_table[src1_typeid][src2_typeid]; + + if (contig_fn != nullptr) { + auto comp_ev = contig_fn(exec_q, src_nelems, src1_data, 0, + src2_data, 0, dst_data, 0, depends); + sycl::event ht_ev = dpctl::utils::keep_args_alive( + exec_q, {src1, src2, dst}, {comp_ev}); + + return std::make_pair(ht_ev, comp_ev); + } + } + + // simplify strides + auto const &src1_strides = 
src1.get_strides_vector(); + auto const &src2_strides = src2.get_strides_vector(); + auto const &dst_strides = dst.get_strides_vector(); + + using shT = std::vector; + shT simplified_shape; + shT simplified_src1_strides; + shT simplified_src2_strides; + shT simplified_dst_strides; + py::ssize_t src1_offset(0); + py::ssize_t src2_offset(0); + py::ssize_t dst_offset(0); + + int nd = dst_nd; + const py::ssize_t *shape = src1_shape; + + dpctl::tensor::py_internal::simplify_iteration_space_3( + nd, shape, src1_strides, src2_strides, dst_strides, + // outputs + simplified_shape, simplified_src1_strides, simplified_src2_strides, + simplified_dst_strides, src1_offset, src2_offset, dst_offset); + + std::vector host_tasks{}; + if (nd < 3) { + static constexpr auto unit_stride = + std::initializer_list{1}; + + if ((nd == 1) && isEqual(simplified_src1_strides, unit_stride) && + isEqual(simplified_src2_strides, unit_stride) && + isEqual(simplified_dst_strides, unit_stride)) + { + auto contig_fn = contig_dispatch_table[src1_typeid][src2_typeid]; + + if (contig_fn != nullptr) { + auto comp_ev = contig_fn(exec_q, src_nelems, src1_data, + src1_offset, src2_data, src2_offset, + dst_data, dst_offset, depends); + sycl::event ht_ev = dpctl::utils::keep_args_alive( + exec_q, {src1, src2, dst}, {comp_ev}); + + return std::make_pair(ht_ev, comp_ev); + } + } + if (nd == 2) { + static constexpr auto zero_one_strides = + std::initializer_list{0, 1}; + static constexpr auto one_zero_strides = + std::initializer_list{1, 0}; + static constexpr py::ssize_t one{1}; + // special case of C-contiguous matrix and a row + if (isEqual(simplified_src2_strides, zero_one_strides) && + isEqual(simplified_src1_strides, {simplified_shape[1], one}) && + isEqual(simplified_dst_strides, {simplified_shape[1], one})) + { + auto matrix_row_broadcast_fn = + contig_matrix_row_broadcast_dispatch_table[src1_typeid] + [src2_typeid]; + if (matrix_row_broadcast_fn != nullptr) { + int src1_itemsize = src1.get_elemsize(); 
+ int src2_itemsize = src2.get_elemsize(); + int dst_itemsize = dst.get_elemsize(); + + if (is_aligned( + src1_data + src1_offset * src1_itemsize) && + is_aligned( + src2_data + src2_offset * src2_itemsize) && + is_aligned( + dst_data + dst_offset * dst_itemsize)) + { + std::size_t n0 = simplified_shape[0]; + std::size_t n1 = simplified_shape[1]; + sycl::event comp_ev = matrix_row_broadcast_fn( + exec_q, host_tasks, n0, n1, src1_data, src1_offset, + src2_data, src2_offset, dst_data, dst_offset, + depends); + + return std::make_pair( + dpctl::utils::keep_args_alive( + exec_q, {src1, src2, dst}, host_tasks), + comp_ev); + } + } + } + if (isEqual(simplified_src1_strides, one_zero_strides) && + isEqual(simplified_src2_strides, {one, simplified_shape[0]}) && + isEqual(simplified_dst_strides, {one, simplified_shape[0]})) + { + auto row_matrix_broadcast_fn = + contig_row_matrix_broadcast_dispatch_table[src1_typeid] + [src2_typeid]; + if (row_matrix_broadcast_fn != nullptr) { + + int src1_itemsize = src1.get_elemsize(); + int src2_itemsize = src2.get_elemsize(); + int dst_itemsize = dst.get_elemsize(); + + if (is_aligned( + src1_data + src1_offset * src1_itemsize) && + is_aligned( + src2_data + src2_offset * src2_itemsize) && + is_aligned( + dst_data + dst_offset * dst_itemsize)) + { + std::size_t n0 = simplified_shape[1]; + std::size_t n1 = simplified_shape[0]; + sycl::event comp_ev = row_matrix_broadcast_fn( + exec_q, host_tasks, n0, n1, src1_data, src1_offset, + src2_data, src2_offset, dst_data, dst_offset, + depends); + + return std::make_pair( + dpctl::utils::keep_args_alive( + exec_q, {src1, src2, dst}, host_tasks), + comp_ev); + } + } + } + } + } + + // dispatch to strided code + auto strided_fn = strided_dispatch_table[src1_typeid][src2_typeid]; + + if (strided_fn == nullptr) { + throw std::runtime_error( + "Strided implementation is missing for src1_typeid=" + + std::to_string(src1_typeid) + + " and src2_typeid=" + std::to_string(src2_typeid)); + } + + using 
dpctl::tensor::offset_utils::device_allocate_and_pack; + auto ptr_sz_event_triple_ = device_allocate_and_pack( + exec_q, host_tasks, simplified_shape, simplified_src1_strides, + simplified_src2_strides, simplified_dst_strides); + auto shape_strides_owner = std::move(std::get<0>(ptr_sz_event_triple_)); + auto ©_shape_ev = std::get<2>(ptr_sz_event_triple_); + + const py::ssize_t *shape_strides = shape_strides_owner.get(); + + sycl::event strided_fn_ev = strided_fn( + exec_q, src_nelems, nd, shape_strides, src1_data, src1_offset, + src2_data, src2_offset, dst_data, dst_offset, depends, {copy_shape_ev}); + + // async free of shape_strides temporary + sycl::event tmp_cleanup_ev = dpctl::tensor::alloc_utils::async_smart_free( + exec_q, {strided_fn_ev}, shape_strides_owner); + host_tasks.push_back(tmp_cleanup_ev); + + return std::make_pair( + dpctl::utils::keep_args_alive(exec_q, {src1, src2, dst}, host_tasks), + strided_fn_ev); +} + +/*! @brief Type querying for binary elementwise functions */ +template +py::object py_binary_ufunc_result_type(const py::dtype &input1_dtype, + const py::dtype &input2_dtype, + const output_typesT &output_types_table) +{ + int tn1 = input1_dtype.num(); // NumPy type numbers are the same as in dpctl + int tn2 = input2_dtype.num(); // NumPy type numbers are the same as in dpctl + int src1_typeid = -1; + int src2_typeid = -1; + + auto array_types = td_ns::usm_ndarray_types(); + + try { + src1_typeid = array_types.typenum_to_lookup_id(tn1); + src2_typeid = array_types.typenum_to_lookup_id(tn2); + } catch (const std::exception &e) { + throw py::value_error(e.what()); + } + + if (src1_typeid < 0 || src1_typeid >= td_ns::num_types || src2_typeid < 0 || + src2_typeid >= td_ns::num_types) + { + throw std::runtime_error("binary output type lookup failed"); + } + int dst_typeid = output_types_table[src1_typeid][src2_typeid]; + + if (dst_typeid < 0) { + auto res = py::none(); + return py::cast(res); + } + else { + using 
dpctl::tensor::py_internal::type_utils::_dtype_from_typenum; + + auto dst_typenum_t = static_cast(dst_typeid); + auto dt = _dtype_from_typenum(dst_typenum_t); + + return py::cast(dt); + } +} + +// ==================== Inplace binary functions ======================= + +template +std::pair + py_binary_inplace_ufunc(const dpctl::tensor::usm_ndarray &lhs, + const dpctl::tensor::usm_ndarray &rhs, + sycl::queue &exec_q, + const std::vector depends, + // + const output_typesT &output_type_table, + const contig_dispatchT &contig_dispatch_table, + const strided_dispatchT &strided_dispatch_table, + const contig_row_matrix_dispatchT + &contig_row_matrix_broadcast_dispatch_table) +{ + dpctl::tensor::validation::CheckWritable::throw_if_not_writable(lhs); + + // check type_nums + int rhs_typenum = rhs.get_typenum(); + int lhs_typenum = lhs.get_typenum(); + + auto array_types = td_ns::usm_ndarray_types(); + int rhs_typeid = array_types.typenum_to_lookup_id(rhs_typenum); + int lhs_typeid = array_types.typenum_to_lookup_id(lhs_typenum); + + int output_typeid = output_type_table[rhs_typeid][lhs_typeid]; + + if (output_typeid != lhs_typeid) { + throw py::value_error( + "Left-hand side array has unexpected elemental data type."); + } + + // check that queues are compatible + if (!dpctl::utils::queues_are_compatible(exec_q, {rhs, lhs})) { + throw py::value_error( + "Execution queue is not compatible with allocation queues"); + } + + // check shapes, broadcasting is assumed done by caller + // check that dimensions are the same + int lhs_nd = lhs.get_ndim(); + if (lhs_nd != rhs.get_ndim()) { + throw py::value_error("Array dimensions are not the same."); + } + + // check that shapes are the same + const py::ssize_t *rhs_shape = rhs.get_shape_raw(); + const py::ssize_t *lhs_shape = lhs.get_shape_raw(); + bool shapes_equal(true); + std::size_t rhs_nelems(1); + + for (int i = 0; i < lhs_nd; ++i) { + rhs_nelems *= static_cast(rhs_shape[i]); + shapes_equal = shapes_equal && (rhs_shape[i] == 
lhs_shape[i]); + } + if (!shapes_equal) { + throw py::value_error("Array shapes are not the same."); + } + + // if nelems is zero, return + if (rhs_nelems == 0) { + return std::make_pair(sycl::event(), sycl::event()); + } + + dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(lhs, rhs_nelems); + + // check memory overlap + auto const &same_logical_tensors = + dpctl::tensor::overlap::SameLogicalTensors(); + auto const &overlap = dpctl::tensor::overlap::MemoryOverlap(); + if (overlap(rhs, lhs) && !same_logical_tensors(rhs, lhs)) { + throw py::value_error("Arrays index overlapping segments of memory"); + } + // check memory overlap + const char *rhs_data = rhs.get_data(); + char *lhs_data = lhs.get_data(); + + // handle contiguous inputs + bool is_rhs_c_contig = rhs.is_c_contiguous(); + bool is_rhs_f_contig = rhs.is_f_contiguous(); + + bool is_lhs_c_contig = lhs.is_c_contiguous(); + bool is_lhs_f_contig = lhs.is_f_contiguous(); + + bool both_c_contig = (is_rhs_c_contig && is_lhs_c_contig); + bool both_f_contig = (is_rhs_f_contig && is_lhs_f_contig); + + // dispatch for contiguous inputs + if (both_c_contig || both_f_contig) { + auto contig_fn = contig_dispatch_table[rhs_typeid][lhs_typeid]; + + if (contig_fn != nullptr) { + auto comp_ev = contig_fn(exec_q, rhs_nelems, rhs_data, 0, lhs_data, + 0, depends); + sycl::event ht_ev = + dpctl::utils::keep_args_alive(exec_q, {rhs, lhs}, {comp_ev}); + + return std::make_pair(ht_ev, comp_ev); + } + } + + // simplify strides + auto const &rhs_strides = rhs.get_strides_vector(); + auto const &lhs_strides = lhs.get_strides_vector(); + + using shT = std::vector; + shT simplified_shape; + shT simplified_rhs_strides; + shT simplified_lhs_strides; + py::ssize_t rhs_offset(0); + py::ssize_t lhs_offset(0); + + int nd = lhs_nd; + const py::ssize_t *shape = rhs_shape; + + dpctl::tensor::py_internal::simplify_iteration_space( + nd, shape, rhs_strides, lhs_strides, + // outputs + simplified_shape, simplified_rhs_strides, 
simplified_lhs_strides, + rhs_offset, lhs_offset); + + std::vector host_tasks{}; + if (nd < 3) { + static constexpr auto unit_stride = + std::initializer_list{1}; + + if ((nd == 1) && isEqual(simplified_rhs_strides, unit_stride) && + isEqual(simplified_lhs_strides, unit_stride)) + { + auto contig_fn = contig_dispatch_table[rhs_typeid][lhs_typeid]; + + if (contig_fn != nullptr) { + auto comp_ev = + contig_fn(exec_q, rhs_nelems, rhs_data, rhs_offset, + lhs_data, lhs_offset, depends); + sycl::event ht_ev = dpctl::utils::keep_args_alive( + exec_q, {rhs, lhs}, {comp_ev}); + + return std::make_pair(ht_ev, comp_ev); + } + } + if (nd == 2) { + static constexpr auto one_zero_strides = + std::initializer_list{1, 0}; + static constexpr py::ssize_t one{1}; + // special case of C-contiguous matrix and a row + if (isEqual(simplified_rhs_strides, one_zero_strides) && + isEqual(simplified_lhs_strides, {one, simplified_shape[0]})) + { + auto row_matrix_broadcast_fn = + contig_row_matrix_broadcast_dispatch_table[rhs_typeid] + [lhs_typeid]; + if (row_matrix_broadcast_fn != nullptr) { + std::size_t n0 = simplified_shape[1]; + std::size_t n1 = simplified_shape[0]; + sycl::event comp_ev = row_matrix_broadcast_fn( + exec_q, host_tasks, n0, n1, rhs_data, rhs_offset, + lhs_data, lhs_offset, depends); + + return std::make_pair(dpctl::utils::keep_args_alive( + exec_q, {lhs, rhs}, host_tasks), + comp_ev); + } + } + } + } + + // dispatch to strided code + auto strided_fn = strided_dispatch_table[rhs_typeid][lhs_typeid]; + + if (strided_fn == nullptr) { + throw std::runtime_error( + "Strided implementation is missing for rhs_typeid=" + + std::to_string(rhs_typeid) + + " and lhs_typeid=" + std::to_string(lhs_typeid)); + } + + using dpctl::tensor::offset_utils::device_allocate_and_pack; + auto ptr_sz_event_triple_ = device_allocate_and_pack( + exec_q, host_tasks, simplified_shape, simplified_rhs_strides, + simplified_lhs_strides); + auto shape_strides_owner = 
std::move(std::get<0>(ptr_sz_event_triple_)); + auto copy_shape_ev = std::get<2>(ptr_sz_event_triple_); + + const py::ssize_t *shape_strides = shape_strides_owner.get(); + + sycl::event strided_fn_ev = + strided_fn(exec_q, rhs_nelems, nd, shape_strides, rhs_data, rhs_offset, + lhs_data, lhs_offset, depends, {copy_shape_ev}); + + // async free of shape_strides temporary + sycl::event tmp_cleanup_ev = dpctl::tensor::alloc_utils::async_smart_free( + exec_q, {strided_fn_ev}, shape_strides_owner); + + host_tasks.push_back(tmp_cleanup_ev); + + return std::make_pair( + dpctl::utils::keep_args_alive(exec_q, {rhs, lhs}, host_tasks), + strided_fn_ev); +} + } // namespace dpctl::tensor::py_internal diff --git a/dpnp/dpnp_iface_bitwise.py b/dpnp/dpnp_iface_bitwise.py index edf68b2f658..bff5c4e3aed 100644 --- a/dpnp/dpnp_iface_bitwise.py +++ b/dpnp/dpnp_iface_bitwise.py @@ -43,12 +43,11 @@ # pylint: disable=no-name-in-module # pylint: disable=protected-access -import dpctl.tensor._tensor_elementwise_impl as ti import numpy # TODO: revert to `import dpctl.tensor...` # when dpnp fully migrates dpctl/tensor -import dpctl_ext.tensor._tensor_elementwise_impl as ti_ext +import dpctl_ext.tensor._tensor_elementwise_impl as ti import dpnp.backend.extensions.ufunc._ufunc_impl as ufi from dpnp.dpnp_algo.dpnp_elementwise_common import DPNPBinaryFunc, DPNPUnaryFunc @@ -517,8 +516,8 @@ def binary_repr(num, width=None): invert = DPNPUnaryFunc( "invert", - ti_ext._bitwise_invert_result_type, - ti_ext._bitwise_invert, + ti._bitwise_invert_result_type, + ti._bitwise_invert, _INVERT_DOCSTRING, ) diff --git a/dpnp/dpnp_iface_mathematical.py b/dpnp/dpnp_iface_mathematical.py index c84b61dad4b..d1bdbdcfc96 100644 --- a/dpnp/dpnp_iface_mathematical.py +++ b/dpnp/dpnp_iface_mathematical.py @@ -469,8 +469,8 @@ def _validate_interp_param(param, name, exec_q, usm_type, dtype=None): add = DPNPBinaryFunc( "add", - ti._add_result_type, - ti._add, + ti_ext._add_result_type, + ti_ext._add, _ADD_DOCSTRING, 
mkl_fn_to_call="_mkl_add_to_call", mkl_impl_fn="_add", diff --git a/dpnp/dpnp_iface_trigonometric.py b/dpnp/dpnp_iface_trigonometric.py index 6deab3a8876..186ae47b095 100644 --- a/dpnp/dpnp_iface_trigonometric.py +++ b/dpnp/dpnp_iface_trigonometric.py @@ -572,8 +572,8 @@ def _get_accumulation_res_dt(a, dtype): atan2 = DPNPBinaryFunc( "atan2", - ti._atan2_result_type, - ti._atan2, + ti_ext._atan2_result_type, + ti_ext._atan2, _ATAN2_DOCSTRING, mkl_fn_to_call="_mkl_atan2_to_call", mkl_impl_fn="_atan2",