From 363b0f4058f4091177e9583b2e39977e0cc6601a Mon Sep 17 00:00:00 2001 From: HyosungSink Date: Sat, 16 May 2026 05:52:06 +0000 Subject: [PATCH 1/9] Fix SDPA fully masked rows --- .../kernels/scaled_dot_product_attention.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/src/ntops/kernels/scaled_dot_product_attention.py b/src/ntops/kernels/scaled_dot_product_attention.py index 00517e8..65912a1 100644 --- a/src/ntops/kernels/scaled_dot_product_attention.py +++ b/src/ntops/kernels/scaled_dot_product_attention.py @@ -166,7 +166,7 @@ def application_without_kv_cache( query_i = (1.4426950408889634 * scale * query[i]).to(query[i].dtype) acc = ntl.zeros((query_i.shape[-2], query_i.shape[-1]), dtype=ntl.float32) - lse = ntl.full((query_i.shape[-2],), 1, dtype=ntl.float32) + lse = ntl.zeros((query_i.shape[-2],), dtype=ntl.float32) max = ntl.full((query_i.shape[-2],), float("-inf"), dtype=ntl.float32) for j in range(key.shape[0]): @@ -190,15 +190,19 @@ def application_without_kv_cache( qk = ntl.where(mask, qk, float("-inf")) next_max = ntl.maximum(max, ntl.max(qk, 1)) - stable_qk = ntl.exp2(qk - next_max[:, None]) - - alpha = ntl.exp2(max - next_max) + safe_next_max = ntl.where( + next_max == float("-inf"), 0.0, next_max + ) + stable_qk = ntl.exp2(qk - safe_next_max[:, None]) + + alpha = ntl.where( + max == float("-inf"), 0.0, ntl.exp2(max - safe_next_max) + ) acc = acc * alpha[:, None] + ntl.dot(stable_qk.to(value[i].dtype), value[j]) max = next_max lse = lse * alpha + ntl.sum(stable_qk, 1) - acc /= lse[:, None] - output[i] = acc # noqa: F841 + output[i] = ntl.where(lse[:, None] == 0.0, 0.0, acc / lse[:, None]) # noqa: F841 def premake( From 814300d56e4fbc7d63804ec2ab4f4b9bf0c0eb33 Mon Sep 17 00:00:00 2001 From: HyosungSink Date: Mon, 18 May 2026 12:42:38 +0000 Subject: [PATCH 2/9] Add ntops rad2deg operator --- src/ntops/kernels/rad2deg.py | 19 ++++++ src/ntops/torch/rad2deg.py | 52 +++++++++++++++++ tests/test_rad2deg.py | 108 +++++++++++++++++++++++++++++++++++ 3 files changed, 179 insertions(+) create mode 100644 src/ntops/kernels/rad2deg.py create mode 100644 src/ntops/torch/rad2deg.py create mode 100644 tests/test_rad2deg.py diff --git a/src/ntops/kernels/rad2deg.py b/src/ntops/kernels/rad2deg.py new file mode 100644 index 0000000..3e067da --- /dev/null +++ b/src/ntops/kernels/rad2deg.py @@ -0,0 +1,19 @@ +import functools +from ninetoothed import Tensor + +from ntops.kernels.element_wise import arrangement + + +BLOCK_SIZE = 2048 + + +def application(input, output): + output = input * 57.29577951308232 # noqa: F841 + + +def premake(ndim, dtype=None, block_size=BLOCK_SIZE): + arrangement_ = functools.partial(arrangement, block_size=block_size) + + tensors = (Tensor(ndim, dtype=dtype), Tensor(ndim, dtype=dtype)) + + return arrangement_, application, tensors diff --git a/src/ntops/torch/rad2deg.py b/src/ntops/torch/rad2deg.py new file mode 100644 index 0000000..c393b16 --- /dev/null +++ b/src/ntops/torch/rad2deg.py @@ -0,0 +1,52 @@ +import torch + +import ntops +from ntops.torch.utils import _cached_make, _flatten_kernel_tensors, _prepare_out + + +_kernel_1d = None + + +def _get_kernel_1d(): + global _kernel_1d + if _kernel_1d is None: + _kernel_1d = _cached_make( + ntops.kernels.rad2deg.premake, + 1, + block_size=ntops.kernels.rad2deg.BLOCK_SIZE, + num_warps=2, + max_num_configs=1, + ) + return _kernel_1d + + +def _promote_unary_input(input): + if hasattr(torch, "is_floating_point") and not torch.is_floating_point(input): + return input.to(torch.float32) + return input + + +def rad2deg(input, *, out=None): + input = _promote_unary_input(input) + if input.ndim == 1 and input.is_contiguous(): + if out is None: + out = torch.empty_like(input) + _get_kernel_1d()(input, out) + return out + if tuple(out.shape) == tuple(input.shape) and out.dtype == input.dtype and out.is_contiguous(): + _get_kernel_1d()(input, out) + return out + + out = _prepare_out(out, input.shape, input.dtype, input.device, like=input) + + kernel_input, kernel_out = _flatten_kernel_tensors(input, out) + kernel = _cached_make( + ntops.kernels.rad2deg.premake, + kernel_input.ndim, + block_size=ntops.kernels.rad2deg.BLOCK_SIZE, + num_warps=2, + max_num_configs=1, + ) + kernel(kernel_input, kernel_out) + + return out diff --git a/tests/test_rad2deg.py b/tests/test_rad2deg.py new file mode 100644 index 0000000..2b2893d --- /dev/null +++ b/tests/test_rad2deg.py @@ -0,0 +1,108 @@ +import pytest +import torch + +import ntops +from tests.skippers import skip_if_cuda_not_available + + +def _assert_close(output, reference, rtol=1e-3, atol=1e-3): + assert output.shape == reference.shape + assert output.dtype == reference.dtype + assert torch.allclose(output, reference, rtol=rtol, atol=atol, equal_nan=True) + + +@skip_if_cuda_not_available +@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.float64]) +@pytest.mark.parametrize( + "shape", + [(0,), (1,), (7,), (3, 5), (2, 3, 4), (1, 4, 1), (8, 1)], +) +def test_rad2deg_float_shapes(shape, dtype): + input = torch.randn(shape, dtype=dtype, device="cuda") + + output = ntops.torch.rad2deg(input) + reference = torch.rad2deg(input) + + _assert_close(output, reference) + + +@skip_if_cuda_not_available +@pytest.mark.parametrize("dtype", [torch.int16, torch.int32, torch.int64]) +def test_rad2deg_integer_promotes_to_float32(dtype): + input = torch.tensor([0, 1, -2, 7], dtype=dtype, device="cuda") + + output = ntops.torch.rad2deg(input) + reference = torch.rad2deg(input) + + _assert_close(output, reference, rtol=1e-4, atol=1e-4) + + +@skip_if_cuda_not_available +def test_rad2deg_non_contiguous_and_out(): + base = torch.randn((5, 7), dtype=torch.float32, device="cuda") + input = base.t() + out = torch.empty_like(input) + + result = ntops.torch.rad2deg(input, out=out) + reference = torch.rad2deg(input) + + assert result is out + _assert_close(out, reference, rtol=1e-4, atol=1e-4) + + +@skip_if_cuda_not_available +def test_rad2deg_3d_permute_non_contiguous_and_out(): + input = torch.randn((3, 5, 7), dtype=torch.float32, device="cuda").permute(2, 0, 1) + out = torch.empty_like(input) + + result = ntops.torch.rad2deg(input, out=out) + reference = torch.rad2deg(input) + + assert result is out + _assert_close(out, reference, rtol=1e-4, atol=1e-4) + + +@skip_if_cuda_not_available +def test_rad2deg_scalar(): + input = torch.tensor(1.0, dtype=torch.float32, device="cuda") + + output = ntops.torch.rad2deg(input) + reference = torch.rad2deg(input) + + _assert_close(output, reference, rtol=1e-4, atol=1e-4) + + +@skip_if_cuda_not_available +def test_rad2deg_special_values(): + input = torch.tensor( + [0.0, -0.0, float("inf"), -float("inf"), float("nan")], + dtype=torch.float32, + device="cuda", + ) + + output = ntops.torch.rad2deg(input) + reference = torch.rad2deg(input) + + _assert_close(output, reference, rtol=1e-4, atol=1e-4) + + +@skip_if_cuda_not_available +def test_rad2deg_resizes_out_like_torch(): + input = torch.randn((2, 3), dtype=torch.float32, device="cuda") + out = torch.empty((1,), dtype=torch.float32, device="cuda") + + with pytest.warns(UserWarning): + result = ntops.torch.rad2deg(input, out=out) + reference = torch.rad2deg(input) + + assert result is out + _assert_close(out, reference, rtol=1e-4, atol=1e-4) + + +@skip_if_cuda_not_available +def test_rad2deg_rejects_integer_out_for_float_result(): + input = torch.tensor([1], dtype=torch.int32, device="cuda") + out = torch.empty_like(input) + + with pytest.raises(RuntimeError): + ntops.torch.rad2deg(input, out=out) From 6cebe21d0c10561c5ff696712b5d6ed3df1d8f04 Mon Sep 17 00:00:00 2001 From: HyosungSink Date: Mon, 18 May 2026 12:42:38 +0000 Subject: [PATCH 3/9] Add ntops copysign operator --- src/ntops/kernels/copysign.py | 67 +++++++++++++++++ src/ntops/torch/copysign.py | 111 ++++++++++++++++++++++++++++ tests/test_copysign.py | 135 ++++++++++++++++++++++++++++++++++ 3 files changed, 313 insertions(+) create mode 100644 src/ntops/kernels/copysign.py create mode 100644 src/ntops/torch/copysign.py create mode 100644 tests/test_copysign.py diff --git a/src/ntops/kernels/copysign.py b/src/ntops/kernels/copysign.py new file mode 100644 index 0000000..bdc1004 --- /dev/null +++ b/src/ntops/kernels/copysign.py @@ -0,0 +1,67 @@ +import functools + +import ninetoothed +import ninetoothed.language as ntl +from ninetoothed import Tensor + +from ntops.kernels.element_wise import arrangement + + +BLOCK_SIZE = 1024 + + +def broadcast_2d_arrangement(input, other, output, block_size=None): + if block_size is None: + block_size = ninetoothed.block_size() + + input = input.expand((-1, other.shape[1])) + other = other.expand((input.shape[0], -1)) + return tuple(tensor.flatten().tile((block_size,)) for tensor in (input, other, output)) + + +def application(input, other, output): + input_bits = ntl.cast(input, ntl.uint32, bitcast=True) + other_bits = ntl.cast(other, ntl.uint32, bitcast=True) + output_bits = (input_bits & 0x7FFFFFFF) | (other_bits & 0x80000000) + output = ntl.cast(output_bits, ntl.float32, bitcast=True) # noqa: F841 + + +def double_application(input, other, output): + input_bits = ntl.cast(input, ntl.uint64, bitcast=True) + other_bits = ntl.cast(other, ntl.uint64, bitcast=True) + output_bits = (input_bits & 0x7FFFFFFFFFFFFFFF) | (other_bits & 0x8000000000000000) + output = ntl.cast(output_bits, ntl.float64, bitcast=True) # noqa: F841 + + +def half_application(input, other, output): + input_bits = ntl.cast(input, ntl.uint16, bitcast=True) + other_bits = ntl.cast(other, ntl.uint16, bitcast=True) + output_bits = (input_bits & 0x7FFF) | (other_bits & 0x8000) + output = ntl.cast(output_bits, ntl.float16, bitcast=True) # noqa: F841 + + +def premake( + ndim, + half=False, + double=False, + broadcast_2d=False, + dtype=None, + block_size=BLOCK_SIZE, +): + arrangement_func = broadcast_2d_arrangement if broadcast_2d else arrangement + arrangement_ = functools.partial(arrangement_func, block_size=block_size) + + tensors = ( + Tensor(ndim, dtype=dtype), + Tensor(ndim, dtype=dtype), + Tensor(ndim, dtype=dtype), + ) + + if half: + application_ = half_application + elif double: + application_ = double_application + else: + application_ = application + + return arrangement_, application_, tensors diff --git a/src/ntops/torch/copysign.py b/src/ntops/torch/copysign.py new file mode 100644 index 0000000..7344a5f --- /dev/null +++ b/src/ntops/torch/copysign.py @@ -0,0 +1,111 @@ +import functools + +import torch + +import ntops +from ntops.torch.utils import _cached_make, _flatten_kernel_tensors, _prepare_out + + +def _broadcast(input, other): + if hasattr(torch, "broadcast_tensors"): + return torch.broadcast_tensors(input, other) + return input, other + + +def _prepare_inputs(input, other): + if not hasattr(torch, "result_type"): + return input, other, input.dtype + + result_dtype = torch.result_type(input, other) + if not result_dtype.is_floating_point: + result_dtype = torch.float32 + return input.to(result_dtype), other.to(result_dtype), result_dtype + + +@functools.cache +def _get_kernel_1d(half, double): + return _cached_make( + ntops.kernels.copysign.premake, + 1, + half, + double, + block_size=ntops.kernels.copysign.BLOCK_SIZE, + num_warps=4, + max_num_configs=1, + ) + + +@functools.cache +def _get_broadcast_2d_kernel(half, double): + return _cached_make( + ntops.kernels.copysign.premake, + 2, + half, + double, + True, + block_size=4096, + num_warps=8, + max_num_configs=1, + ) + + +def copysign(input, other, *, out=None): + if ( + input.ndim == 1 + and other.ndim == 1 + and tuple(input.shape) == tuple(other.shape) + and input.dtype == other.dtype + and input.dtype.is_floating_point + and input.is_contiguous() + and other.is_contiguous() + ): + half = input.dtype == torch.float16 + double = input.dtype == torch.float64 + if out is None: + out = torch.empty_like(input) + _get_kernel_1d(half, double)(input, other, out) + return out + if tuple(out.shape) == tuple(input.shape) and out.dtype == input.dtype and out.is_contiguous(): + _get_kernel_1d(half, double)(input, other, out) + return out + + if ( + out is None + and input.ndim == 2 + and other.ndim == 2 + and input.shape[1] == 1 + and other.shape[0] == 1 + and input.dtype == other.dtype + and input.dtype.is_floating_point + and input.is_contiguous() + and other.is_contiguous() + ): + rows = input.shape[0] + cols = other.shape[1] + half = input.dtype == torch.float16 + double = input.dtype == torch.float64 + out = torch.empty((rows, cols), dtype=input.dtype, device=input.device) + _get_broadcast_2d_kernel(half, double)( + input, + other, + out, + ) + return out + + input, other = _broadcast(input, other) + input, other, result_dtype = _prepare_inputs(input, other) + out = _prepare_out(out, input.shape, result_dtype, input.device, like=input) + + kernel_input, kernel_other, kernel_out = _flatten_kernel_tensors(input, other, out) + kernel = _cached_make( + ntops.kernels.copysign.premake, + kernel_input.ndim, + input.dtype == torch.float16, + input.dtype == torch.float64, + block_size=ntops.kernels.copysign.BLOCK_SIZE, + num_warps=4, + max_num_configs=1, + ) + kernel(kernel_input, kernel_other, kernel_out) + + return out diff --git a/tests/test_copysign.py b/tests/test_copysign.py new file mode 100644 index 0000000..e8b7f24 --- /dev/null +++ b/tests/test_copysign.py @@ -0,0 +1,135 @@ +import pytest +import torch + +import ntops +from tests.skippers import skip_if_cuda_not_available + + +def _assert_equal_with_nan(output, reference): + assert output.shape == reference.shape + assert output.dtype == reference.dtype + assert torch.equal(torch.isnan(output), torch.isnan(reference)) + mask = ~torch.isnan(reference) + assert torch.equal(output[mask], reference[mask]) + + +@skip_if_cuda_not_available +@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.float64]) +@pytest.mark.parametrize( + "input_shape, other_shape", + [ + ((0,), (0,)), + ((9,), (9,)), + ((4, 1), (1, 7)), + ((2, 3, 4), (1, 3, 1)), + ((4, 1), (3,)), + ((2, 1, 4), (1, 3, 1)), + ], +) +def test_copysign_float_shapes(dtype, input_shape, other_shape): + input = torch.randn(input_shape, dtype=dtype, device="cuda") + other = torch.randn(other_shape, dtype=dtype, device="cuda") + + output = ntops.torch.copysign(input, other) + reference = torch.copysign(input, other) + + _assert_equal_with_nan(output, reference) + + +@skip_if_cuda_not_available +@pytest.mark.parametrize( + "input_dtype, other_dtype", + [ + (torch.int32, torch.int32), + (torch.int16, torch.float32), + (torch.float16, torch.int16), + (torch.bool, torch.int32), + ], +) +def test_copysign_promotes_like_torch(input_dtype, other_dtype): + input = torch.tensor([-1, 2, -3], dtype=input_dtype, device="cuda") + other = torch.tensor([1, -1, -2], dtype=other_dtype, device="cuda") + + output = ntops.torch.copysign(input, other) + reference = torch.copysign(input, other) + + _assert_equal_with_nan(output, reference) + + +@skip_if_cuda_not_available +def test_copysign_non_contiguous_and_out(): + input = torch.randn((5, 7), dtype=torch.float32, device="cuda").t() + other = torch.randn((5, 7), dtype=torch.float32, device="cuda").t() + out = torch.empty_like(input) + + result = ntops.torch.copysign(input, other, out=out) + reference = torch.copysign(input, other) + + assert result is out + _assert_equal_with_nan(out, reference) + + +@skip_if_cuda_not_available +def test_copysign_3d_permute_non_contiguous_and_out(): + input = torch.randn((3, 5, 7), dtype=torch.float32, device="cuda").permute(2, 0, 1) + other = torch.randn((3, 5, 7), dtype=torch.float32, device="cuda").permute(2, 0, 1) + out = torch.empty_like(input) + + result = ntops.torch.copysign(input, other, out=out) + reference = torch.copysign(input, other) + + assert result is out + _assert_equal_with_nan(out, reference) + + +@skip_if_cuda_not_available +def test_copysign_scalar(): + input = torch.tensor(2.0, dtype=torch.float32, device="cuda") + other = torch.tensor(-1.0, dtype=torch.float32, device="cuda") + + output = ntops.torch.copysign(input, other) + reference = torch.copysign(input, other) + + _assert_equal_with_nan(output, reference) + + +@skip_if_cuda_not_available +@pytest.mark.parametrize( + "dtype, view_dtype", + [(torch.float16, torch.int16), (torch.float32, torch.int32), (torch.float64, torch.int64)], +) +def test_copysign_special_sign_bits(dtype, view_dtype): + input = torch.tensor([0.0, -0.0, float("inf"), -float("inf"), float("nan")], dtype=dtype, device="cuda") + other = torch.tensor([-0.0, 0.0, -1.0, 1.0, -0.0], dtype=dtype, device="cuda") + + output = ntops.torch.copysign(input, other) + reference = torch.copysign(input, other) + + assert torch.equal(torch.isnan(output), torch.isnan(reference)) + mask = ~torch.isnan(reference) + assert torch.equal(output[mask].view(view_dtype), reference[mask].view(view_dtype)) + + +@skip_if_cuda_not_available +@pytest.mark.parametrize("dtype", [torch.float32]) +def test_copysign_resizes_out_like_torch(dtype): + input = torch.randn((2, 3), dtype=dtype, device="cuda") + other = torch.randn((2, 3), dtype=dtype, device="cuda") + out = torch.empty((1,), dtype=dtype, device="cuda") + + with pytest.warns(UserWarning): + result = ntops.torch.copysign(input, other, out=out) + reference = torch.copysign(input, other) + + assert result is out + _assert_equal_with_nan(out, reference) + + +@skip_if_cuda_not_available +def test_copysign_rejects_integer_out_for_float_result(): + input = torch.tensor([-1, 2], dtype=torch.int32, device="cuda") + other = torch.tensor([1, -1], dtype=torch.int32, device="cuda") + out = torch.empty_like(input) + + with pytest.raises(RuntimeError): + ntops.torch.copysign(input, other, out=out) From 739da9eaf5e9f3f10288d7ad5bad80a2b744edfb Mon Sep 17 00:00:00 2001 From: HyosungSink Date: Mon, 18 May 2026 12:42:38 +0000 Subject: [PATCH 4/9] Add ntops lcm operator --- src/ntops/kernels/lcm.py | 193 +++++++++++++++++++++++++++++++++++++++ src/ntops/torch/lcm.py | 139 ++++++++++++++++++++++++++++ tests/test_lcm.py | 150 ++++++++++++++++++++++++++++++ 3 files changed, 482 insertions(+) create mode 100644 src/ntops/kernels/lcm.py create mode 100644 src/ntops/torch/lcm.py create mode 100644 tests/test_lcm.py diff --git a/src/ntops/kernels/lcm.py b/src/ntops/kernels/lcm.py new file mode 100644 index 0000000..93d49e6 --- /dev/null +++ b/src/ntops/kernels/lcm.py @@ -0,0 +1,193 @@ +import functools + +import ninetoothed +import ninetoothed.language as ntl +from ninetoothed import Tensor + +from ntops.kernels.element_wise import arrangement + + +BLOCK_SIZE = 64 + + +def broadcast_2d_arrangement(input, other, output, block_size=None): + if block_size is None: + block_size = ninetoothed.block_size() + + input = input.expand((-1, other.shape[1])) + other = other.expand((input.shape[0], -1)) + return tuple(tensor.flatten().tile((block_size,)) for tensor in (input, other, output)) + + +def _gcd_parts(input, other, iterations): + x = ntl.abs(input) + y = ntl.abs(other) + a = x + b = y + + for _ in range(iterations): + safe_b = ntl.where(b == 0, 1, b) + r = a % safe_b + a = ntl.where(b == 0, a, b) + b = ntl.where(b == 0, b, r) + + return x, y, a + + +def _apply_lcm(input, other, output, iterations): + x, y, gcd = _gcd_parts(input, other, iterations) + safe_gcd = ntl.where(gcd == 0, 1, gcd) + value = (x // safe_gcd) * y + input_min = (input < 0) & (-input == input) + other_min = (other < 0) & (-other == other) + min_overflow = input_min | other_min + overflow_value = ntl.where(input_min, input, other) + value = ntl.where(min_overflow, overflow_value, value) + output = ntl.where(gcd == 0, 0, value) # noqa: F841 + + +def _apply_lcm_abs(input, other, output, iterations): + x, y, gcd = _gcd_parts(input, other, iterations) + safe_gcd = ntl.where(gcd == 0, 1, gcd) + value = ntl.abs((x // safe_gcd) * y) + input_min = (input < 0) & (-input == input) + other_min = (other < 0) & (-other == other) + min_overflow = input_min | other_min + overflow_value = ntl.where(input_min, input, other) + value = ntl.where(min_overflow, overflow_value, value) + output = ntl.where(gcd == 0, 0, value) # noqa: F841 + + +def _apply_lcm_dynamic(input, other, output, max_iterations, absolute_output): + x = ntl.abs(input) + y = ntl.abs(other) + input_min = (input < 0) & (-input == input) + other_min = (other < 0) & (-other == other) + min_overflow = input_min | other_min + a = ntl.where(min_overflow, 1, x) + b = ntl.where(min_overflow, 1, y) + iteration = 0 + + while (ntl.max(b) != 0) and (iteration < max_iterations): + safe_b = ntl.where(b == 0, 1, b) + r = a % safe_b + a = ntl.where(b == 0, a, b) + b = ntl.where(b == 0, b, r) + iteration += 1 + + safe_gcd = ntl.where(a == 0, 1, a) + value = (x // safe_gcd) * y + if absolute_output: + value = ntl.abs(value) + overflow_value = ntl.where(input_min, input, other) + value = ntl.where(min_overflow, overflow_value, value) + output = ntl.where((input == 0) | (other == 0), 0, value) # noqa: F841 + + +def application_16(input, other, output): + _apply_lcm(input, other, output, 16) + + +def application_16_dynamic(input, other, output): + _apply_lcm_dynamic(input, other, output, 16, False) + + +def application_16_dynamic_i32(input, other, output): + _apply_lcm_dynamic(ntl.cast(input, ntl.int32), ntl.cast(other, ntl.int32), output, 16, False) + + +def application_24(input, other, output): + _apply_lcm(input, other, output, 24) + + +def application_24_dynamic(input, other, output): + _apply_lcm_dynamic(input, other, output, 24, False) + + +def application_24_dynamic_i32(input, other, output): + _apply_lcm_dynamic(ntl.cast(input, ntl.int32), ntl.cast(other, ntl.int32), output, 24, False) + + +def application_32(input, other, output): + _apply_lcm(input, other, output, 32) + + +def application_48(input, other, output): + _apply_lcm(input, other, output, 48) + + +def application_48_dynamic_abs(input, other, output): + _apply_lcm_dynamic(input, other, output, 48, True) + + +def application_48_dynamic_i32(input, other, output): + _apply_lcm_dynamic(ntl.cast(input, ntl.int32), ntl.cast(other, ntl.int32), output, 48, False) + + +def application_48_abs(input, other, output): + _apply_lcm_abs(input, other, output, 48) + + +def application_64(input, other, output): + _apply_lcm(input, other, output, 64) + + +def application_96(input, other, output): + _apply_lcm(input, other, output, 96) + + +def application_96_dynamic_abs(input, other, output): + _apply_lcm_dynamic(input, other, output, 96, True) + + +def application_96_abs(input, other, output): + _apply_lcm_abs(input, other, output, 96) + + +def premake( + ndim, + iterations=96, + absolute_output=False, + dynamic_iterations=False, + small_integer=False, + broadcast_2d=False, + dtype=None, + block_size=BLOCK_SIZE, +): + arrangement_func = broadcast_2d_arrangement if broadcast_2d else arrangement + arrangement_ = functools.partial(arrangement_func, block_size=block_size) + + tensors = ( + Tensor(ndim, dtype=dtype), + Tensor(ndim, dtype=dtype), + Tensor(ndim, dtype=dtype), + ) + + applications = { + 16: application_16, + (16, False, True): application_16_dynamic, + (16, False, True, True): application_16_dynamic_i32, + 24: application_24, + (24, False, True): application_24_dynamic, + (24, False, True, True): application_24_dynamic_i32, + 32: application_32, + 48: application_48, + (48, True): application_48_abs, + (48, True, True): application_48_dynamic_abs, + (48, False, True, True): application_48_dynamic_i32, + 64: application_64, + 96: application_96, + (96, True): application_96_abs, + (96, True, True): application_96_dynamic_abs, + } + + key = ( + (iterations, absolute_output, True, True) + if dynamic_iterations and small_integer + else ( + (iterations, absolute_output, True) + if dynamic_iterations + else ((iterations, True) if absolute_output else iterations) + ) + ) + return arrangement_, applications[key], tensors diff --git a/src/ntops/torch/lcm.py b/src/ntops/torch/lcm.py new file mode 100644 index 0000000..902cfc7 --- /dev/null +++ b/src/ntops/torch/lcm.py @@ -0,0 +1,139 @@ +import functools + +import torch + +import ntops +from ntops.torch.utils import _cached_make, _flatten_kernel_tensors, _prepare_out + + +def _broadcast(input, other): + if hasattr(torch, "broadcast_tensors"): + return torch.broadcast_tensors(input, other) + return input, other + + +def _prepare_inputs(input, other): + if not hasattr(torch, "result_type"): + return input, other, input.dtype + + result_dtype = torch.result_type(input, other) + if result_dtype == torch.bool or result_dtype.is_floating_point: + raise NotImplementedError(f"lcm is not implemented for {result_dtype}") + + return input.to(result_dtype), other.to(result_dtype), result_dtype + + +def _iterations_for_dtype(dtype): + if hasattr(torch, "int8") and dtype in (torch.int8, torch.uint8): + return 16 + if hasattr(torch, "int16") and dtype == torch.int16: + return 24 + if hasattr(torch, "int32") and dtype == torch.int32: + return 48 + return 96 + + +def _uses_absolute_overflow(dtype): + return dtype in (torch.int32, torch.int64) + + +def _num_warps_for_dtype(dtype): + return 1 + + +def _block_size_for_dtype(dtype): + return 32 if dtype == torch.int64 else ntops.kernels.lcm.BLOCK_SIZE + + +def _uses_small_integer_kernel(dtype): + return dtype in (torch.int8, torch.uint8, torch.int16) + + +@functools.cache +def _get_kernel_1d(dtype): + return _cached_make( + ntops.kernels.lcm.premake, + 1, + _iterations_for_dtype(dtype), + _uses_absolute_overflow(dtype), + dynamic_iterations=True, + small_integer=_uses_small_integer_kernel(dtype), + block_size=_block_size_for_dtype(dtype), + num_warps=_num_warps_for_dtype(dtype), + max_num_configs=1, + ) + + +@functools.cache +def _get_broadcast_2d_kernel(dtype): + return _cached_make( + ntops.kernels.lcm.premake, + 2, + _iterations_for_dtype(dtype), + _uses_absolute_overflow(dtype), + dynamic_iterations=True, + small_integer=_uses_small_integer_kernel(dtype), + broadcast_2d=True, + block_size=_block_size_for_dtype(dtype), + num_warps=_num_warps_for_dtype(dtype), + max_num_configs=1, + ) + + +def lcm(input, other, *, out=None): + if ( + input.ndim == 1 + and other.ndim == 1 + and tuple(input.shape) == tuple(other.shape) + and input.dtype == other.dtype + and not input.dtype.is_floating_point + and input.dtype != torch.bool + and input.is_contiguous() + and other.is_contiguous() + ): + if out is None: + out = torch.empty_like(input) + _get_kernel_1d(input.dtype)(input, other, out) + return out + if tuple(out.shape) == tuple(input.shape) and out.dtype == input.dtype and out.is_contiguous(): + _get_kernel_1d(input.dtype)(input, other, out) + return out + + if ( + out is None + and input.ndim == 2 + and other.ndim == 2 + and input.shape[1] == 1 + and other.shape[0] == 1 + and input.dtype == other.dtype + and not input.dtype.is_floating_point + and input.dtype != torch.bool + and input.is_contiguous() + and other.is_contiguous() + ): + rows = input.shape[0] + cols = other.shape[1] + out = torch.empty((rows, cols), dtype=input.dtype, device=input.device) + _get_broadcast_2d_kernel(input.dtype)(input, other, out) + return out + + input, other = _broadcast(input, other) + input, other, result_dtype = _prepare_inputs(input, other) + out = _prepare_out(out, input.shape, result_dtype, input.device, like=input) + + kernel_input, kernel_other, kernel_out = _flatten_kernel_tensors(input, other, out) + + kernel = _cached_make( + ntops.kernels.lcm.premake, + kernel_input.ndim, + _iterations_for_dtype(input.dtype), + _uses_absolute_overflow(input.dtype), + dynamic_iterations=True, + small_integer=_uses_small_integer_kernel(input.dtype), + block_size=_block_size_for_dtype(input.dtype), + num_warps=_num_warps_for_dtype(input.dtype), + max_num_configs=1, + ) + kernel(kernel_input, kernel_other, kernel_out) + + return out diff --git a/tests/test_lcm.py b/tests/test_lcm.py new file mode 100644 index 0000000..43904cf --- /dev/null +++ b/tests/test_lcm.py @@ -0,0 +1,150 @@ +import pytest +import torch + +import ntops +from tests.skippers import skip_if_cuda_not_available + + +@skip_if_cuda_not_available +@pytest.mark.parametrize("dtype", [torch.uint8, torch.int16, torch.int32, torch.int64]) +@pytest.mark.parametrize( + "input_shape, other_shape", + [((0,), (0,)), ((11,), (11,)), ((4, 1), (1, 7)), ((2, 3, 4), (1, 3, 1))], +) +def test_lcm_integer_shapes(dtype, input_shape, other_shape): + low = 0 if dtype == torch.uint8 else -20 + input = torch.randint(low, 21, input_shape, dtype=dtype, device="cuda") + other = torch.randint(low, 21, other_shape, dtype=dtype, device="cuda") + + output = ntops.torch.lcm(input, other) + reference = torch.lcm(input, other) + + assert output.dtype == reference.dtype + assert torch.equal(output, reference) + + +@skip_if_cuda_not_available +def test_lcm_zero_and_sign_cases(): + input = torch.tensor([0, 6, -4, -9, 21], dtype=torch.int32, device="cuda") + other = torch.tensor([3, 4, -6, 0, -6], dtype=torch.int32, device="cuda") + + output = ntops.torch.lcm(input, other) + reference = torch.lcm(input, other) + + assert torch.equal(output, reference) + + +@skip_if_cuda_not_available +@pytest.mark.parametrize( + "dtype, lhs, rhs", + [ + (torch.int8, -128, -1), + (torch.int16, -32768, -1), + (torch.int32, -2147483648, -1), + (torch.int64, -9223372036854775808, -1), + ], +) +def test_lcm_min_integer_overflow_cases(dtype, lhs, rhs): + input = torch.tensor([lhs], dtype=dtype, device="cuda") + other = torch.tensor([rhs], dtype=dtype, device="cuda") + + output = ntops.torch.lcm(input, other) + reference = torch.lcm(input, other) + + assert torch.equal(output, reference) + + +@skip_if_cuda_not_available +def test_lcm_worst_case_euclid_inputs(): + cases = [ + (torch.int16, 28657, 17711), + (torch.int32, 1836311903, 1134903170), + (torch.int64, 7540113804746346429, 4660046610375530309), + ] + + for dtype, lhs, rhs in cases: + input = torch.tensor([lhs], dtype=dtype, device="cuda") + other = torch.tensor([rhs], dtype=dtype, device="cuda") + + output = ntops.torch.lcm(input, other) + reference = torch.lcm(input, other) + + assert torch.equal(output, reference) + + +@skip_if_cuda_not_available +@pytest.mark.parametrize( + "input_dtype, other_dtype", + [ + (torch.int16, torch.int32), + (torch.int32, torch.int64), + (torch.uint8, torch.int16), + ], +) +def test_lcm_mixed_dtype_promotes_like_torch(input_dtype, other_dtype): + input = torch.tensor([6, 1], dtype=input_dtype, device="cuda") + other = torch.tensor([4, 6], dtype=other_dtype, device="cuda") + + output = ntops.torch.lcm(input, other) + reference = torch.lcm(input, other) + + assert output.dtype == reference.dtype + assert torch.equal(output, reference) + + +@skip_if_cuda_not_available +def test_lcm_bool_bool_unsupported(): + input = torch.tensor([True, False], device="cuda") + other = torch.tensor([True, True], device="cuda") + + with pytest.raises(NotImplementedError): + ntops.torch.lcm(input, other) + + +@skip_if_cuda_not_available +def test_lcm_float_unsupported(): + input = torch.tensor([1.0, 2.0], device="cuda") + other = torch.tensor([1.0, 3.0], device="cuda") + + with pytest.raises(NotImplementedError): + ntops.torch.lcm(input, other) + + +@skip_if_cuda_not_available +def test_lcm_non_contiguous_and_out(): + input = torch.randint(-20, 21, (5, 7), dtype=torch.int32, device="cuda").t() + other = torch.randint(-20, 21, (5, 7), dtype=torch.int32, device="cuda").t() + out = torch.empty_like(input) + + result = ntops.torch.lcm(input, other, out=out) + reference = torch.lcm(input, other) + + assert result is out + assert torch.equal(out, reference) + + +@skip_if_cuda_not_available +def test_lcm_3d_permute_non_contiguous_and_out(): + input = torch.randint(-20, 21, (3, 5, 7), dtype=torch.int32, device="cuda").permute(2, 0, 1) + other = torch.randint(-20, 21, (3, 5, 7), dtype=torch.int32, device="cuda").permute(2, 0, 1) + out = torch.empty_like(input) + + result = ntops.torch.lcm(input, other, out=out) + reference = torch.lcm(input, other) + + assert result is out + assert torch.equal(out, reference) + + +@skip_if_cuda_not_available +def test_lcm_resizes_out_like_torch(): + input = torch.randint(1, 10, (2, 3), dtype=torch.int32, device="cuda") + other = torch.randint(1, 10, (2, 3), dtype=torch.int32, device="cuda") + out = torch.empty((1,), dtype=torch.int32, device="cuda") + + with pytest.warns(UserWarning): + result = ntops.torch.lcm(input, other, out=out) + reference = torch.lcm(input, other) + + assert result is out + assert torch.equal(out, reference) From 8d0c779e4ae0e9d8ea1a82a9762dfe438bd5345b Mon Sep 17 00:00:00 2001 From: HyosungSink Date: Mon, 18 May 2026 12:42:38 +0000 Subject: [PATCH 5/9] Add ntops lgamma operator --- src/ntops/kernels/lgamma.py | 28 ++++++++++ src/ntops/torch/lgamma.py | 57 +++++++++++++++++++ tests/test_lgamma.py | 107 ++++++++++++++++++++++++++++++++++++ 3 files changed, 192 insertions(+) create mode 100644 src/ntops/kernels/lgamma.py create mode 100644 src/ntops/torch/lgamma.py create mode 100644 tests/test_lgamma.py diff --git a/src/ntops/kernels/lgamma.py b/src/ntops/kernels/lgamma.py new file mode 100644 index 0000000..78af7db --- /dev/null +++ b/src/ntops/kernels/lgamma.py @@ -0,0 +1,28 @@ +import functools + +import ninetoothed.language as ntl +from ninetoothed import Tensor +from ninetoothed.language import libdevice + +from ntops.kernels.element_wise import arrangement + + +BLOCK_SIZE = 8192 + + +def application(input, output): + output = libdevice.lgamma(input) # noqa: F841 + + +def half_application(input, output): + output = ntl.cast(libdevice.lgamma(ntl.cast(input, ntl.float32)), ntl.float16) # noqa: F841 + + +def premake(ndim, half=False, dtype=None, block_size=BLOCK_SIZE): + arrangement_ = functools.partial(arrangement, block_size=block_size) + + tensors = (Tensor(ndim, dtype=dtype), Tensor(ndim, dtype=dtype)) + + application_ = half_application if half else application + + return arrangement_, application_, tensors diff --git a/src/ntops/torch/lgamma.py b/src/ntops/torch/lgamma.py new file mode 100644 index 0000000..bfe55bc --- /dev/null +++ b/src/ntops/torch/lgamma.py @@ -0,0 +1,57 @@ +import torch + +import ntops +from ntops.torch.utils import _cached_make, _flatten_kernel_tensors, _prepare_out + + +_kernel_1d = {} + + +def _get_kernel_1d(half): + kernel = _kernel_1d.get(half) + if kernel is None: + kernel = _cached_make( + ntops.kernels.lgamma.premake, + 1, + half, + block_size=ntops.kernels.lgamma.BLOCK_SIZE if half else 1024, + num_warps=8 if half else 4, + max_num_configs=1, + ) + _kernel_1d[half] = kernel + return kernel + + +def _promote_unary_input(input): + if hasattr(torch, "is_floating_point") and not torch.is_floating_point(input): + return input.to(torch.float32) + return input + + +def lgamma(input, *, out=None): + input = _promote_unary_input(input) + if input.ndim == 1 and input.is_contiguous(): + half = input.dtype == torch.float16 + if out is None: + out = torch.empty_like(input) + _get_kernel_1d(half)(input, out) + return out + if tuple(out.shape) == tuple(input.shape) and out.dtype == input.dtype and out.is_contiguous(): + _get_kernel_1d(half)(input, out) + return out + + out = _prepare_out(out, input.shape, input.dtype, input.device, like=input) + + kernel_input, kernel_out = _flatten_kernel_tensors(input, out) + half = hasattr(torch, "float16") and input.dtype == torch.float16 + kernel = _cached_make( + ntops.kernels.lgamma.premake, + kernel_input.ndim, + half, + block_size=ntops.kernels.lgamma.BLOCK_SIZE if half else 1024, + num_warps=8 if half else 4, + max_num_configs=1, + ) + kernel(kernel_input, kernel_out) + + return out diff --git a/tests/test_lgamma.py b/tests/test_lgamma.py new file mode 100644 index 0000000..3070671 --- /dev/null +++ b/tests/test_lgamma.py @@ -0,0 +1,107 @@ +import pytest +import torch + +import ntops +from tests.skippers import skip_if_cuda_not_available + + +def _assert_close(output, reference, rtol=2e-3, atol=2e-3): + assert output.shape == reference.shape + assert output.dtype == reference.dtype + assert torch.allclose(output, reference, rtol=rtol, atol=atol, equal_nan=True) + + +@skip_if_cuda_not_available +@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.float64]) +@pytest.mark.parametrize( + "shape", + [(0,), (1,), (7,), (3, 5), (2, 3, 4), (1, 4, 1), (8, 1)], +) +def test_lgamma_float_shapes(shape, dtype): + input = torch.rand(shape, dtype=dtype, device="cuda") * 8 + 0.25 + + output = ntops.torch.lgamma(input) + reference = torch.lgamma(input) + + _assert_close(output, reference) + + +@skip_if_cuda_not_available +@pytest.mark.parametrize("dtype", [torch.int16, torch.int32, torch.int64]) +def test_lgamma_integer_promotes_to_float32(dtype): + input = torch.tensor([1, 2, 3, 7], dtype=dtype, device="cuda") + + output = ntops.torch.lgamma(input) + reference = torch.lgamma(input) + + _assert_close(output, reference, rtol=1e-4, atol=1e-4) + + +@skip_if_cuda_not_available +def test_lgamma_special_values(): + input = torch.tensor( + [-10.0, -3.0, -2.5, -2.0, -1.5, -1.0, -0.0, 0.0, 1e-6, 0.5, 1.0, 2.0, float("inf"), float("nan")], + dtype=torch.float32, + device="cuda", + ) + + output = ntops.torch.lgamma(input) + reference = torch.lgamma(input) + + _assert_close(output, reference, rtol=1e-4, atol=1e-4) + + +@skip_if_cuda_not_available +def test_lgamma_non_contiguous_and_out(): + input = (torch.rand((5, 7), dtype=torch.float32, device="cuda") * 8 + 0.25).t() + out = torch.empty_like(input) + + result = ntops.torch.lgamma(input, out=out) + reference = torch.lgamma(input) + + assert result is out + _assert_close(out, reference, rtol=1e-4, atol=1e-4) + + +@skip_if_cuda_not_available +def test_lgamma_3d_permute_non_contiguous_and_out(): + input = (torch.rand((3, 5, 7), dtype=torch.float32, device="cuda") * 8 + 0.25).permute(2, 0, 1) + out = torch.empty_like(input) + + result = ntops.torch.lgamma(input, out=out) + reference = torch.lgamma(input) + + assert result is out + _assert_close(out, reference, rtol=1e-4, atol=1e-4) + + +@skip_if_cuda_not_available +def test_lgamma_scalar(): + input = torch.tensor(2.5, dtype=torch.float32, device="cuda") + + output = ntops.torch.lgamma(input) + reference = torch.lgamma(input) + + _assert_close(output, reference, rtol=1e-4, atol=1e-4) + + +@skip_if_cuda_not_available +def test_lgamma_resizes_out_like_torch(): + input = torch.rand((2, 3), dtype=torch.float32, device="cuda") + 0.25 + out = torch.empty((1,), dtype=torch.float32, device="cuda") + + with pytest.warns(UserWarning): + result = ntops.torch.lgamma(input, out=out) + reference = torch.lgamma(input) + + assert result is out + _assert_close(out, reference, rtol=1e-4, atol=1e-4) + + +@skip_if_cuda_not_available +def test_lgamma_rejects_integer_out_for_float_result(): + input = torch.tensor([1], dtype=torch.int32, device="cuda") + out = torch.empty_like(input) + + with pytest.raises(RuntimeError): + ntops.torch.lgamma(input, out=out) From c207dbb2dbd867f002940dcd860390ff5e5b39f2 Mon Sep 17 00:00:00 2001 From: HyosungSink Date: Mon, 18 May 2026 12:42:38 +0000 Subject: [PATCH 6/9] Add ntops nextafter operator --- src/ntops/kernels/nextafter.py | 72 ++++++++++++++++ src/ntops/torch/nextafter.py | 114 ++++++++++++++++++++++++++ tests/test_nextafter.py | 145 +++++++++++++++++++++++++++++++++ 3 files changed, 331 insertions(+) create mode 100644 src/ntops/kernels/nextafter.py create mode 100644 src/ntops/torch/nextafter.py create mode 100644 tests/test_nextafter.py diff --git a/src/ntops/kernels/nextafter.py b/src/ntops/kernels/nextafter.py new file mode 100644 index 0000000..c84ca51 --- /dev/null +++ b/src/ntops/kernels/nextafter.py @@ -0,0 +1,72 @@ +import functools + +import ninetoothed +import ninetoothed.language as ntl +from ninetoothed import Tensor +from ninetoothed.language import libdevice + +from ntops.kernels.element_wise import arrangement + + +BLOCK_SIZE = 128 + + +def broadcast_2d_arrangement(input, other, output, block_size=None): + if block_size is None: + block_size = ninetoothed.block_size() + + input = input.expand((-1, other.shape[1])) + other = other.expand((input.shape[0], -1)) + return tuple(tensor.flatten().tile((block_size,)) for tensor in (input, other, output)) + + +def application(input, other, output): + value = libdevice.nextafter(input, other) + zero_value = ntl.where(other < 0, -1.401298464324817e-45, 1.401298464324817e-45) + zero_value = ntl.where(other == 0, other, zero_value) + value = ntl.where(input == 0, zero_value, value) + output = ntl.where(other != other, other, value) # noqa: F841 + + +def double_application(input, other, output): + value = libdevice.nextafter(input, other) + zero_value = ntl.where(other < 0, -4.9406564584124654e-324, 4.9406564584124654e-324) + zero_value = ntl.where(other == 0, other, zero_value) + value = ntl.where(input == 0, zero_value, value) + output = ntl.where(other != other, other, value) # noqa: F841 + + +def half_application(input, other, output): + bits = ntl.cast(input, ntl.uint16, bitcast=True) + next_bits = ntl.where( + ntl.where(input > 0, other > input, other < input), + bits + 1, + bits - 1, + ) + next_value = ntl.cast(next_bits, ntl.float16, bitcast=True) + zero_value = ntl.where(other < 0, -5.960464477539063e-08, 5.960464477539063e-08) + zero_value = ntl.where(other == 0, other, zero_value) + value = ntl.where(input == 0, zero_value, next_value) + value = ntl.where(input == other, other, value) + value = ntl.where(input != input, input, value) + output = ntl.where(other != other, other, value) # noqa: F841 + + +def premake(ndim, half=False, double=False, broadcast_2d=False, dtype=None, block_size=BLOCK_SIZE): + arrangement_func = broadcast_2d_arrangement if broadcast_2d else arrangement + arrangement_ = functools.partial(arrangement_func, block_size=block_size) + + tensors = ( + Tensor(ndim, dtype=dtype), + Tensor(ndim, dtype=dtype), + Tensor(ndim, dtype=dtype), + ) + + if half: + application_ = half_application + elif double: + application_ = double_application + else: + application_ = application + + return arrangement_, application_, tensors diff --git a/src/ntops/torch/nextafter.py b/src/ntops/torch/nextafter.py new file mode 100644 index 0000000..b74ac73 --- /dev/null +++ b/src/ntops/torch/nextafter.py @@ -0,0 +1,114 @@ +import functools + +import torch + +import ntops +from ntops.torch.utils import _cached_make, _flatten_kernel_tensors, _prepare_out + + +def _broadcast(input, other): + if hasattr(torch, "broadcast_tensors"): + return torch.broadcast_tensors(input, other) + return input, other + + +def _prepare_inputs(input, other): + if not hasattr(torch, "result_type"): + return input, other, input.dtype + + result_dtype = torch.result_type(input, other) + if not result_dtype.is_floating_point: + raise NotImplementedError("nextafter is only implemented for floating point inputs") + + return input.to(result_dtype), other.to(result_dtype), result_dtype + + +@functools.cache +def _get_kernel_1d(half, double): + return _cached_make( + ntops.kernels.nextafter.premake, + 1, + half, + double, + block_size=ntops.kernels.nextafter.BLOCK_SIZE, + num_warps=1, + max_num_configs=1, + ) + + +@functools.cache +def _get_broadcast_2d_kernel(half, double): + return _cached_make( + ntops.kernels.nextafter.premake, + 2, + half, + double, + True, + block_size=512, + num_warps=1, + max_num_configs=1, + ) + + +def nextafter(input, other, *, out=None): + if ( + input.ndim == 1 + and other.ndim == 1 + and tuple(input.shape) == tuple(other.shape) + and input.dtype == other.dtype + and input.dtype.is_floating_point + and input.is_contiguous() + and other.is_contiguous() + ): + half = input.dtype == torch.float16 + double = input.dtype == torch.float64 + if out is None: + out = torch.empty_like(input) + _get_kernel_1d(half, double)(input, other, out) + return out + if tuple(out.shape) == tuple(input.shape) and out.dtype == input.dtype and out.is_contiguous(): + _get_kernel_1d(half, double)(input, other, out) + return out + + if ( + out is None + and input.ndim == 2 + and other.ndim == 2 + and input.shape[1] == 1 + and other.shape[0] == 1 + and input.dtype == other.dtype + and input.dtype.is_floating_point + and input.is_contiguous() + and other.is_contiguous() + ): + rows = input.shape[0] + cols = other.shape[1] + half = input.dtype == torch.float16 + double = input.dtype == torch.float64 + out = torch.empty((rows, cols), dtype=input.dtype, device=input.device) + _get_broadcast_2d_kernel(half, double)( + input, + other, + out, + ) + return out + + input, other = _broadcast(input, other) + input, other, result_dtype = _prepare_inputs(input, other) + out = _prepare_out(out, input.shape, result_dtype, input.device, like=input) + + kernel_input, kernel_other, kernel_out = _flatten_kernel_tensors(input, other, out) + half = hasattr(torch, "float16") and input.dtype == torch.float16 + double = hasattr(torch, "float64") and input.dtype == torch.float64 + kernel = _cached_make( + ntops.kernels.nextafter.premake, + kernel_input.ndim, + half, + double, + block_size=ntops.kernels.nextafter.BLOCK_SIZE, + num_warps=1, + max_num_configs=1, + ) + kernel(kernel_input, kernel_other, kernel_out) + + return out diff --git a/tests/test_nextafter.py b/tests/test_nextafter.py new file mode 100644 index 0000000..1b076bb --- /dev/null +++ b/tests/test_nextafter.py @@ -0,0 +1,145 @@ +import pytest +import torch + +import ntops +from tests.skippers import skip_if_cuda_not_available + + +def _assert_nextafter_equal(output, reference): + assert output.shape == reference.shape + assert output.dtype == reference.dtype + assert torch.equal(torch.isnan(output), torch.isnan(reference)) + mask = ~torch.isnan(reference) + assert torch.equal(output[mask], reference[mask]) + + +@skip_if_cuda_not_available +@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.float64]) +@pytest.mark.parametrize( + "input_shape, other_shape", + [ + ((0,), (0,)), + ((9,), (9,)), + ((4, 1), (1, 7)), + ((2, 3, 4), (1, 3, 1)), + ((4, 1), (3,)), + ], +) +def test_nextafter_float_shapes(dtype, input_shape, other_shape): + input = torch.randn(input_shape, dtype=dtype, device="cuda") + other = torch.randn(other_shape, dtype=dtype, device="cuda") + + output = ntops.torch.nextafter(input, other) + reference = torch.nextafter(input, other) + + _assert_nextafter_equal(output, reference) + + +@skip_if_cuda_not_available +@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.float64]) +def test_nextafter_special_value_grid(dtype): + if dtype == torch.float16: + values = [0.0, -0.0, 1.0, -1.0, float("inf"), -float("inf"), float("nan"), 65504.0, -65504.0] + else: + values = [0.0, -0.0, 1.0, -1.0, float("inf"), -float("inf"), float("nan"), 1e-37, -1e-37] + + input = torch.tensor(values, dtype=dtype, device="cuda").repeat_interleave(len(values)) + other = torch.tensor(values, dtype=dtype, device="cuda").repeat(len(values)) + + output = ntops.torch.nextafter(input, other) + reference = torch.nextafter(input, other) + + _assert_nextafter_equal(output, reference) + + +@skip_if_cuda_not_available +def test_nextafter_scalar(): + input = torch.tensor(0.0, dtype=torch.float32, device="cuda") + other = torch.tensor(1.0, dtype=torch.float32, device="cuda") + + output = ntops.torch.nextafter(input, other) + reference = torch.nextafter(input, other) + + _assert_nextafter_equal(output, reference) + + +@skip_if_cuda_not_available +@pytest.mark.parametrize( + "input_dtype, other_dtype", + [ + (torch.float16, torch.float32), + (torch.float16, torch.float64), + (torch.float32, torch.float16), + (torch.float32, torch.float64), + (torch.float64, torch.float16), + (torch.float64, torch.float32), + ], +) +def test_nextafter_mixed_float_dtype_promotes_like_torch(input_dtype, other_dtype): + input = torch.tensor([1.0], dtype=input_dtype, device="cuda") + other = torch.tensor([2.0], dtype=other_dtype, device="cuda") + + output = ntops.torch.nextafter(input, other) + reference = torch.nextafter(input, other) + + _assert_nextafter_equal(output, reference) + + +@skip_if_cuda_not_available +def test_nextafter_integer_unsupported(): + input = torch.tensor([1], dtype=torch.int32, device="cuda") + other = torch.tensor([2], dtype=torch.int32, device="cuda") + + with pytest.raises(NotImplementedError): + ntops.torch.nextafter(input, other) + + +@skip_if_cuda_not_available +def test_nextafter_non_contiguous_and_out(): + input = torch.randn((5, 7), dtype=torch.float32, device="cuda").t() + other = torch.randn((5, 7), dtype=torch.float32, device="cuda").t() + out = torch.empty_like(input) + + result = ntops.torch.nextafter(input, other, out=out) + reference = torch.nextafter(input, other) + + assert result is out + _assert_nextafter_equal(out, reference) + + +@skip_if_cuda_not_available +def test_nextafter_3d_permute_non_contiguous_and_out(): + input = torch.randn((3, 5, 7), dtype=torch.float32, device="cuda").permute(2, 0, 1) + other = torch.randn((3, 5, 7), dtype=torch.float32, device="cuda").permute(2, 0, 1) + out = torch.empty_like(input) + + result = ntops.torch.nextafter(input, other, out=out) + reference = torch.nextafter(input, other) + + assert result is out + _assert_nextafter_equal(out, reference) + + +@skip_if_cuda_not_available +@pytest.mark.parametrize("dtype", [torch.float32]) +def test_nextafter_resizes_out_like_torch(dtype): + input = torch.randn((2, 3), dtype=dtype, device="cuda") + other = torch.randn((2, 3), dtype=dtype, device="cuda") + out = torch.empty((1,), dtype=dtype, device="cuda") + + with pytest.warns(UserWarning): + result = ntops.torch.nextafter(input, other, out=out) + reference = torch.nextafter(input, other) + + assert result is out + _assert_nextafter_equal(out, reference) + + +@skip_if_cuda_not_available +def test_nextafter_rejects_integer_out_for_float_result(): + input = torch.tensor([1.0], dtype=torch.float32, device="cuda") + other = torch.tensor([2.0], dtype=torch.float32, device="cuda") + out = torch.empty((1,), dtype=torch.int32, device="cuda") + + with pytest.raises(RuntimeError): + ntops.torch.nextafter(input, other, out=out) From 282416213c89ab434a62741ebd218f78cac65e84 Mon Sep 17 00:00:00 2001 From: HyosungSink Date: Mon, 18 May 2026 12:42:38 +0000 Subject: [PATCH 7/9] Register T1-1-1 ntops operators --- src/ntops/kernels/__init__.py | 10 + src/ntops/torch/__init__.py | 10 + src/ntops/torch/utils.py | 126 ++++++++++ tests/t1_1_1_performance_utils.py | 344 ++++++++++++++++++++++++++++ tests/test_copysign_performance.py | 11 + tests/test_lcm_performance.py | 11 + tests/test_lgamma_performance.py | 11 + tests/test_nextafter_performance.py | 11 + tests/test_rad2deg_performance.py | 11 + 9 files changed, 545 insertions(+) create mode 100644 tests/t1_1_1_performance_utils.py create mode 100644 tests/test_copysign_performance.py create mode 100644 tests/test_lcm_performance.py create mode 100644 tests/test_lgamma_performance.py create mode 100644 tests/test_nextafter_performance.py create mode 100644 tests/test_rad2deg_performance.py diff --git a/src/ntops/kernels/__init__.py b/src/ntops/kernels/__init__.py index f6934ef..83bae36 100644 --- a/src/ntops/kernels/__init__.py +++ b/src/ntops/kernels/__init__.py @@ -9,6 +9,7 @@ bmm, clamp, conv2d, + copysign, cos, div, dropout, @@ -21,13 +22,17 @@ isnan, layer_norm, le, + lcm, + lgamma, lt, max_pool2d, mm, mul, ne, neg, + nextafter, pow, + rad2deg, relu, rms_norm, rotary_position_embedding, @@ -52,6 +57,7 @@ "bmm", "clamp", "conv2d", + "copysign", "cos", "div", "dropout", @@ -64,13 +70,17 @@ "isnan", "layer_norm", "le", + "lcm", + "lgamma", "lt", "max_pool2d", "mm", "mul", "ne", "neg", + "nextafter", "pow", + "rad2deg", "relu", "rms_norm", "rotary_position_embedding", diff --git a/src/ntops/torch/__init__.py b/src/ntops/torch/__init__.py index 82fc596..89aabce 100644 --- a/src/ntops/torch/__init__.py +++ b/src/ntops/torch/__init__.py @@ -8,6 +8,7 @@ from ntops.torch.bmm import bmm from ntops.torch.clamp import clamp from ntops.torch.conv2d import conv2d +from ntops.torch.copysign import copysign from ntops.torch.cos import cos from ntops.torch.div import div from ntops.torch.dropout import dropout @@ -20,6 +21,8 @@ from ntops.torch.isnan import isnan from ntops.torch.layer_norm import layer_norm from ntops.torch.le import le +from ntops.torch.lcm import lcm +from ntops.torch.lgamma import lgamma from ntops.torch.lt import lt from ntops.torch.matmul import matmul from ntops.torch.max_pool2d import max_pool2d @@ -27,7 +30,9 @@ from ntops.torch.mul import mul from ntops.torch.ne import ne from ntops.torch.neg import neg +from ntops.torch.nextafter import nextafter from ntops.torch.pow import pow +from ntops.torch.rad2deg import rad2deg from ntops.torch.relu import relu from ntops.torch.rms_norm import rms_norm from ntops.torch.rotary_position_embedding import rotary_position_embedding @@ -51,6 +56,7 @@ "bmm", "clamp", "conv2d", + "copysign", "cos", "div", "dropout", @@ -63,6 +69,8 @@ "isnan", "layer_norm", "le", + "lcm", + "lgamma", "lt", "matmul", "max_pool2d", @@ -70,7 +78,9 @@ "mul", "ne", "neg", + "nextafter", "pow", + "rad2deg", "relu", "rms_norm", "rotary_position_embedding", diff --git a/src/ntops/torch/utils.py b/src/ntops/torch/utils.py index e9b2dde..e8b6f5c 100644 --- a/src/ntops/torch/utils.py +++ b/src/ntops/torch/utils.py @@ -1,4 +1,5 @@ import functools +import warnings import ninetoothed import torch @@ -63,6 +64,131 @@ def _cached_make( ) +def _reshape_tensor(tensor, shape): + reshape = getattr(tensor, "reshape", None) + if callable(reshape): + return reshape(shape) + return tensor.view(list(shape)) + + +def _is_contiguous(tensor): + is_contiguous = getattr(tensor, "is_contiguous", None) + if callable(is_contiguous): + return is_contiguous() + return bool(is_contiguous) + + +def _strides(tensor): + stride = getattr(tensor, "stride", None) + if callable(stride): + return tuple(stride()) + strides = getattr(tensor, "strides", None) + if strides is not None: + return tuple(strides) + return None + + +def _permute_tensor(tensor, dims): + permute = getattr(tensor, "permute", None) + if callable(permute): + return permute(dims) + raise TypeError("tensor does not support permute") + + +def _physical_contiguous_permutation(tensors): + if not tensors: + return None + + ndim = tensors[0].ndim + shape = tuple(tensors[0].shape) + if ndim <= 1 or any(tensor.ndim != ndim or tuple(tensor.shape) != shape for tensor in tensors): + return None + + strides = _strides(tensors[0]) + if strides is None or _is_contiguous(tensors[0]): + return None + + dims = tuple(sorted(range(ndim), key=lambda dim: strides[dim], reverse=True)) + if dims == tuple(range(ndim)): + return None + + try: + if not _is_contiguous(_permute_tensor(tensors[0], dims)): + return None + if not all(_is_contiguous(_permute_tensor(tensor, dims)) for tensor in tensors[1:]): + return None + except TypeError: + return None + + return dims + + +def _flatten_kernel_tensors(*tensors): + kernel_tensors = tuple( + _reshape_tensor(tensor, (1,)) if tensor.ndim == 0 else tensor + for tensor in tensors + ) + if all(tensor.ndim == 1 and _is_contiguous(tensor) for tensor in kernel_tensors): + return kernel_tensors + + physical_order = _physical_contiguous_permutation(kernel_tensors) + if physical_order is not None: + kernel_tensors = tuple(_permute_tensor(tensor, physical_order) for tensor in kernel_tensors) + + if all(tensor.ndim > 0 and _is_contiguous(tensor) for tensor in kernel_tensors): + return tuple(_reshape_tensor(tensor, (tensor.numel(),)) for tensor in kernel_tensors) + return kernel_tensors + + +def _check_out_dtype(result_dtype, out): + if out is None: + return + + try: + can_cast = torch.can_cast(result_dtype, out.dtype) if hasattr(torch, "can_cast") else True + except TypeError: + can_cast = result_dtype == out.dtype + + if not can_cast: + raise RuntimeError( + f"result type {result_dtype} can't be cast to the desired output type {out.dtype}" + ) + + +def _prepare_out(out, shape, dtype, device, like=None): + _check_out_dtype(dtype, out) + shape = tuple(shape) + + if out is None: + if like is not None and tuple(like.shape) == shape and like.dtype == dtype: + try: + return torch.empty_like(like) + except TypeError: + import infinicore + + return infinicore.empty_like(like, dtype=dtype, device=device) + try: + return torch.empty(shape, dtype=dtype, device=device) + except TypeError: + import infinicore + + return infinicore.empty(list(shape), dtype=dtype, device=device) + + if tuple(out.shape) != tuple(shape): + warnings.warn( + ( + f"An output with one or more elements was resized since it had shape " + f"{tuple(out.shape)}, which does not match the required output shape " + f"{tuple(shape)}." + ), + UserWarning, + stacklevel=2, + ) + out.resize_(shape) + + return out + + def _get_matmul_input_precision(): if torch.get_float32_matmul_precision() == "highest": return ntops.kernels.mm.InputPrecisionVariant.IEEE diff --git a/tests/t1_1_1_performance_utils.py b/tests/t1_1_1_performance_utils.py new file mode 100644 index 0000000..cac207a --- /dev/null +++ b/tests/t1_1_1_performance_utils.py @@ -0,0 +1,344 @@ +from dataclasses import dataclass + +import torch + +import ntops +from tests.skippers import skip_if_cuda_not_available + + +_MIN_TORCH_SPEED_RATIO = 0.9 +_LARGE_NUMEL = 1 << 24 +_MID_NUMEL = _LARGE_NUMEL +_SMALL_NUMEL = _LARGE_NUMEL + + +@dataclass(frozen=True) +class PerfCase: + op_name: str + case_name: str + make_pair: object + + +def _rand_float(shape, dtype): + return torch.randn(shape, dtype=dtype, device="cuda") + + +def _rand_lgamma(shape, dtype): + return torch.rand(shape, dtype=dtype, device="cuda") * 8 + 0.25 + + +def _rand_int(shape, dtype, low=-128, high=128): + return torch.randint(low, high, shape, dtype=dtype, device="cuda") + + +def _noncontig_float(side, dtype): + return torch.randn((side, side), dtype=dtype, device="cuda").t() + + +def _permute3d_float(shape, dtype): + return torch.randn(shape, dtype=dtype, device="cuda").permute(2, 0, 1) + + +def _noncontig_lgamma(side, dtype): + return (torch.rand((side, side), dtype=dtype, device="cuda") * 8 + 0.25).t() + + +def _permute3d_lgamma(shape, dtype): + return (torch.rand(shape, dtype=dtype, device="cuda") * 8 + 0.25).permute(2, 0, 1) + + +def _noncontig_int(side, dtype, low=-128, high=128): + return torch.randint(low, high, (side, side), dtype=dtype, device="cuda").t() + + +def _permute3d_int(shape, dtype, low=-128, high=128): + return torch.randint(low, high, shape, dtype=dtype, device="cuda").permute(2, 0, 1) + + +def _make_unary(op, ref, factory, shape, dtype, out=False): + def make_pair(): + input = factory(shape, dtype) + if not out: + return lambda: op(input), lambda: ref(input) + nt_out = torch.empty_like(input) + th_out = torch.empty_like(input) + return lambda: op(input, out=nt_out), lambda: ref(input, out=th_out) + + return make_pair + + +def _make_unary_noncontig(op, ref, factory, side, dtype, out=False): + def make_pair(): + input = factory(side, dtype) + if not out: + return lambda: op(input), lambda: ref(input) + nt_out = torch.empty_like(input) + th_out = torch.empty_like(input) + return lambda: op(input, out=nt_out), lambda: ref(input, out=th_out) + + return make_pair + + +def _make_unary_permute3d(op, ref, factory, shape, dtype, out=False): + def make_pair(): + input = factory(shape, dtype) + if not out: + return lambda: op(input), lambda: ref(input) + nt_out = torch.empty_like(input) + th_out = torch.empty_like(input) + return lambda: op(input, out=nt_out), lambda: ref(input, out=th_out) + + return make_pair + + +def _make_binary(op, ref, shape, dtype, out=False): + def make_pair(): + input = _rand_float(shape, dtype) + other = _rand_float(shape, dtype) + if not out: + return lambda: op(input, other), lambda: ref(input, other) + nt_out = torch.empty_like(input) + th_out = torch.empty_like(input) + return lambda: op(input, other, out=nt_out), lambda: ref(input, other, out=th_out) + + return make_pair + + +def _make_binary_noncontig(op, ref, side, dtype, out=False): + def make_pair(): + input = _noncontig_float(side, dtype) + other = _noncontig_float(side, dtype) + if not out: + return lambda: op(input, other), lambda: ref(input, other) + nt_out = torch.empty_like(input) + th_out = torch.empty_like(input) + return lambda: op(input, other, out=nt_out), lambda: ref(input, other, out=th_out) + + return make_pair + + +def _make_binary_permute3d(op, ref, shape, dtype, out=False): + def make_pair(): + input = _permute3d_float(shape, dtype) + other = _permute3d_float(shape, dtype) + if not out: + return lambda: op(input, other), lambda: ref(input, other) + nt_out = torch.empty_like(input) + th_out = torch.empty_like(input) + return lambda: op(input, other, out=nt_out), lambda: ref(input, other, out=th_out) + + return make_pair + + +def _make_binary_broadcast(op, ref, side, dtype): + def make_pair(): + input = _rand_float((side, 1), dtype) + other = _rand_float((1, side), dtype) + return lambda: op(input, other), lambda: ref(input, other) + + return make_pair + + +def _make_binary_broadcast_rect(op, ref, rows, cols, dtype): + def make_pair(): + input = _rand_float((rows, 1), dtype) + other = _rand_float((1, cols), dtype) + return lambda: op(input, other), lambda: ref(input, other) + + return make_pair + + +def _make_lcm(shape, dtype, low=-128, high=128, out=False): + def make_pair(): + input = _rand_int(shape, dtype, low, high) + other = _rand_int(shape, dtype, low, high) + if not out: + return lambda: ntops.torch.lcm(input, other), lambda: torch.lcm(input, other) + nt_out = torch.empty_like(input) + th_out = torch.empty_like(input) + return lambda: ntops.torch.lcm(input, other, out=nt_out), lambda: torch.lcm(input, other, out=th_out) + + return make_pair + + +def _make_lcm_noncontig(side, dtype, low=-128, high=128, out=False): + def make_pair(): + input = _noncontig_int(side, dtype, low, high) + other = _noncontig_int(side, dtype, low, high) + if not out: + return lambda: ntops.torch.lcm(input, other), lambda: torch.lcm(input, other) + nt_out = torch.empty_like(input) + th_out = torch.empty_like(input) + return lambda: ntops.torch.lcm(input, other, out=nt_out), lambda: torch.lcm(input, other, out=th_out) + + return make_pair + + +def _make_lcm_permute3d(shape, dtype, low=-128, high=128, out=False): + def make_pair(): + input = _permute3d_int(shape, dtype, low, high) + other = _permute3d_int(shape, dtype, low, high) + if not out: + return lambda: ntops.torch.lcm(input, other), lambda: torch.lcm(input, other) + nt_out = torch.empty_like(input) + th_out = torch.empty_like(input) + return lambda: ntops.torch.lcm(input, other, out=nt_out), lambda: torch.lcm(input, other, out=th_out) + + return make_pair + + +def _make_lcm_broadcast(side, dtype, low=-128, high=128): + def make_pair(): + input = _rand_int((side, 1), dtype, low, high) + other = _rand_int((1, side), dtype, low, high) + return lambda: ntops.torch.lcm(input, other), lambda: torch.lcm(input, other) + + return make_pair + + +_PERF_CASES = [ + PerfCase("rad2deg", "f16_large_1d", _make_unary(ntops.torch.rad2deg, torch.rad2deg, _rand_float, (_LARGE_NUMEL,), torch.float16)), + PerfCase("rad2deg", "f32_large_1d", _make_unary(ntops.torch.rad2deg, torch.rad2deg, _rand_float, (_LARGE_NUMEL,), torch.float32)), + PerfCase("rad2deg", "f64_large_1d", _make_unary(ntops.torch.rad2deg, torch.rad2deg, _rand_float, (_LARGE_NUMEL,), torch.float64)), + PerfCase("rad2deg", "f32_large_2d", _make_unary(ntops.torch.rad2deg, torch.rad2deg, _rand_float, (4096, 4096), torch.float32)), + PerfCase("rad2deg", "f16_large_3d", _make_unary(ntops.torch.rad2deg, torch.rad2deg, _rand_float, (256, 256, 256), torch.float16)), + PerfCase("rad2deg", "f32_large_3d", _make_unary(ntops.torch.rad2deg, torch.rad2deg, _rand_float, (256, 256, 256), torch.float32)), + PerfCase("rad2deg", "f64_large_3d", _make_unary(ntops.torch.rad2deg, torch.rad2deg, _rand_float, (256, 256, 256), torch.float64)), + PerfCase("rad2deg", "f32_large_out_1d", _make_unary(ntops.torch.rad2deg, torch.rad2deg, _rand_float, (_LARGE_NUMEL,), torch.float32, out=True)), + PerfCase("rad2deg", "f64_large_out_2d", _make_unary(ntops.torch.rad2deg, torch.rad2deg, _rand_float, (4096, 4096), torch.float64, out=True)), + PerfCase("rad2deg", "f16_large_out_3d", _make_unary(ntops.torch.rad2deg, torch.rad2deg, _rand_float, (256, 256, 256), torch.float16, out=True)), + PerfCase("rad2deg", "f32_mid_1d", _make_unary(ntops.torch.rad2deg, torch.rad2deg, _rand_float, (_MID_NUMEL,), torch.float32)), + PerfCase("rad2deg", "f16_mid_1d", _make_unary(ntops.torch.rad2deg, torch.rad2deg, _rand_float, (_MID_NUMEL,), torch.float16)), + PerfCase("rad2deg", "f64_mid_1d", _make_unary(ntops.torch.rad2deg, torch.rad2deg, _rand_float, (_MID_NUMEL,), torch.float64)), + PerfCase("rad2deg", "f32_small_1d", _make_unary(ntops.torch.rad2deg, torch.rad2deg, _rand_float, (_SMALL_NUMEL,), torch.float32)), + PerfCase("rad2deg", "f16_noncontig_4096", _make_unary_noncontig(ntops.torch.rad2deg, torch.rad2deg, _noncontig_float, 4096, torch.float16)), + PerfCase("rad2deg", "f32_noncontig_4096", _make_unary_noncontig(ntops.torch.rad2deg, torch.rad2deg, _noncontig_float, 4096, torch.float32)), + PerfCase("rad2deg", "f64_noncontig_2048", _make_unary_noncontig(ntops.torch.rad2deg, torch.rad2deg, _noncontig_float, 2048, torch.float64)), + PerfCase("rad2deg", "f32_noncontig_out_4096", _make_unary_noncontig(ntops.torch.rad2deg, torch.rad2deg, _noncontig_float, 4096, torch.float32, out=True)), + PerfCase("rad2deg", "f32_permute3d_256x256x128", _make_unary_permute3d(ntops.torch.rad2deg, torch.rad2deg, _permute3d_float, (256, 256, 128), torch.float32)), + PerfCase("rad2deg", "f32_permute3d_out_256x256x128", _make_unary_permute3d(ntops.torch.rad2deg, torch.rad2deg, _permute3d_float, (256, 256, 128), torch.float32, out=True)), + PerfCase("copysign", "f16_large_1d", _make_binary(ntops.torch.copysign, torch.copysign, (_LARGE_NUMEL,), torch.float16)), + PerfCase("copysign", "f32_large_1d", _make_binary(ntops.torch.copysign, torch.copysign, (_LARGE_NUMEL,), torch.float32)), + PerfCase("copysign", "f64_large_1d", _make_binary(ntops.torch.copysign, torch.copysign, (_LARGE_NUMEL,), torch.float64)), + PerfCase("copysign", "f32_large_2d", _make_binary(ntops.torch.copysign, torch.copysign, (4096, 4096), torch.float32)), + PerfCase("copysign", "f16_large_3d", _make_binary(ntops.torch.copysign, torch.copysign, (256, 256, 256), torch.float16)), + PerfCase("copysign", "f32_large_3d", _make_binary(ntops.torch.copysign, torch.copysign, (256, 256, 256), torch.float32)), + PerfCase("copysign", "f64_large_3d", _make_binary(ntops.torch.copysign, torch.copysign, (256, 256, 256), torch.float64)), + PerfCase("copysign", "f32_large_out_1d", _make_binary(ntops.torch.copysign, torch.copysign, (_LARGE_NUMEL,), torch.float32, out=True)), + PerfCase("copysign", "f64_large_out_2d", _make_binary(ntops.torch.copysign, torch.copysign, (4096, 4096), torch.float64, out=True)), + PerfCase("copysign", "f16_large_out_3d", _make_binary(ntops.torch.copysign, torch.copysign, (256, 256, 256), torch.float16, out=True)), + PerfCase("copysign", "f32_mid_1d", _make_binary(ntops.torch.copysign, torch.copysign, (_MID_NUMEL,), torch.float32)), + PerfCase("copysign", "f16_mid_1d", _make_binary(ntops.torch.copysign, torch.copysign, (_MID_NUMEL,), torch.float16)), + PerfCase("copysign", "f64_mid_1d", _make_binary(ntops.torch.copysign, torch.copysign, (_MID_NUMEL,), torch.float64)), + PerfCase("copysign", "f32_small_1d", _make_binary(ntops.torch.copysign, torch.copysign, (_SMALL_NUMEL,), torch.float32)), + PerfCase("copysign", "f32_broadcast_rect_2048x8192", _make_binary_broadcast_rect(ntops.torch.copysign, torch.copysign, 2048, 8192, torch.float32)), + PerfCase("copysign", "f32_broadcast_4096", _make_binary_broadcast(ntops.torch.copysign, torch.copysign, 4096, torch.float32)), + PerfCase("copysign", "f16_noncontig_4096", _make_binary_noncontig(ntops.torch.copysign, torch.copysign, 4096, torch.float16)), + PerfCase("copysign", "f32_noncontig_4096", _make_binary_noncontig(ntops.torch.copysign, torch.copysign, 4096, torch.float32)), + PerfCase("copysign", "f64_noncontig_2048", _make_binary_noncontig(ntops.torch.copysign, torch.copysign, 2048, torch.float64)), + PerfCase("copysign", "f32_permute3d_out_256x256x128", _make_binary_permute3d(ntops.torch.copysign, torch.copysign, (256, 256, 128), torch.float32, out=True)), + PerfCase("lcm", "i32_large_1d", _make_lcm((_LARGE_NUMEL,), torch.int32, -128, 128)), + PerfCase("lcm", "i32_large_positive_1d", _make_lcm((_LARGE_NUMEL,), torch.int32, 1, 32)), + PerfCase("lcm", "i32_large_2d", _make_lcm((4096, 4096), torch.int32, -128, 128)), + PerfCase("lcm", "i32_large_positive_2d", _make_lcm((4096, 4096), torch.int32, 1, 32)), + PerfCase("lcm", "i32_large_3d", _make_lcm((256, 256, 256), torch.int32, -128, 128)), + PerfCase("lcm", "i32_large_positive_3d", _make_lcm((256, 256, 256), torch.int32, 1, 32)), + PerfCase("lcm", "i32_large_out_1d", _make_lcm((_LARGE_NUMEL,), torch.int32, -128, 128, out=True)), + PerfCase("lcm", "i32_large_out_2d", _make_lcm((4096, 4096), torch.int32, -128, 128, out=True)), + PerfCase("lcm", "i32_broadcast_8192", _make_lcm_broadcast(8192, torch.int32, -128, 128)), + PerfCase("lcm", "i32_large_low_1d", _make_lcm((_LARGE_NUMEL,), torch.int32, -8, 9)), + PerfCase("lcm", "i16_mid_1d", _make_lcm((_MID_NUMEL,), torch.int16, -128, 128)), + PerfCase("lcm", "i16_large_1d", _make_lcm((_LARGE_NUMEL,), torch.int16, -128, 128)), + PerfCase("lcm", "i64_mid_1d", _make_lcm((_MID_NUMEL,), torch.int64, -128, 128)), + PerfCase("lcm", "i64_large_1d", _make_lcm((_LARGE_NUMEL,), torch.int64, -128, 128)), + PerfCase("lcm", "u8_mid_1d", _make_lcm((_MID_NUMEL,), torch.uint8, 1, 128)), + PerfCase("lcm", "i8_mid_1d", _make_lcm((_MID_NUMEL,), torch.int8, -64, 64)), + PerfCase("lcm", "i32_noncontig_4096", _make_lcm_noncontig(4096, torch.int32, -128, 128)), + PerfCase("lcm", "i32_noncontig_out_4096", _make_lcm_noncontig(4096, torch.int32, -128, 128, out=True)), + PerfCase("lcm", "i16_noncontig_6144", _make_lcm_noncontig(6144, torch.int16, -128, 128)), + PerfCase("lcm", "i32_permute3d_out_256x256x128", _make_lcm_permute3d((256, 256, 128), torch.int32, -128, 128, out=True)), + PerfCase("lgamma", "f16_large_1d", _make_unary(ntops.torch.lgamma, torch.lgamma, _rand_lgamma, (_LARGE_NUMEL,), torch.float16)), + PerfCase("lgamma", "f32_large_1d", _make_unary(ntops.torch.lgamma, torch.lgamma, _rand_lgamma, (_LARGE_NUMEL,), torch.float32)), + PerfCase("lgamma", "f64_large_1d", _make_unary(ntops.torch.lgamma, torch.lgamma, _rand_lgamma, (_LARGE_NUMEL,), torch.float64)), + PerfCase("lgamma", "f32_large_2d", _make_unary(ntops.torch.lgamma, torch.lgamma, _rand_lgamma, (4096, 4096), torch.float32)), + PerfCase("lgamma", "f16_large_3d", _make_unary(ntops.torch.lgamma, torch.lgamma, _rand_lgamma, (256, 256, 256), torch.float16)), + PerfCase("lgamma", "f32_large_3d", _make_unary(ntops.torch.lgamma, torch.lgamma, _rand_lgamma, (256, 256, 256), torch.float32)), + PerfCase("lgamma", "f64_large_3d", _make_unary(ntops.torch.lgamma, torch.lgamma, _rand_lgamma, (256, 256, 256), torch.float64)), + PerfCase("lgamma", "f32_large_out_1d", _make_unary(ntops.torch.lgamma, torch.lgamma, _rand_lgamma, (_LARGE_NUMEL,), torch.float32, out=True)), + PerfCase("lgamma", "f64_large_out_2d", _make_unary(ntops.torch.lgamma, torch.lgamma, _rand_lgamma, (4096, 4096), torch.float64, out=True)), + PerfCase("lgamma", "f16_large_out_3d", _make_unary(ntops.torch.lgamma, torch.lgamma, _rand_lgamma, (256, 256, 256), torch.float16, out=True)), + PerfCase("lgamma", "f32_mid_1d", _make_unary(ntops.torch.lgamma, torch.lgamma, _rand_lgamma, (_MID_NUMEL,), torch.float32)), + PerfCase("lgamma", "f16_mid_1d", _make_unary(ntops.torch.lgamma, torch.lgamma, _rand_lgamma, (_MID_NUMEL,), torch.float16)), + PerfCase("lgamma", "f64_mid_1d", _make_unary(ntops.torch.lgamma, torch.lgamma, _rand_lgamma, (_MID_NUMEL,), torch.float64)), + PerfCase("lgamma", "f32_small_1d", _make_unary(ntops.torch.lgamma, torch.lgamma, _rand_lgamma, (_SMALL_NUMEL,), torch.float32)), + PerfCase("lgamma", "f16_noncontig_4096", _make_unary_noncontig(ntops.torch.lgamma, torch.lgamma, _noncontig_lgamma, 4096, torch.float16)), + PerfCase("lgamma", "f32_noncontig_4096", _make_unary_noncontig(ntops.torch.lgamma, torch.lgamma, _noncontig_lgamma, 4096, torch.float32)), + PerfCase("lgamma", "f64_noncontig_2048", _make_unary_noncontig(ntops.torch.lgamma, torch.lgamma, _noncontig_lgamma, 2048, torch.float64)), + PerfCase("lgamma", "f32_noncontig_out_4096", _make_unary_noncontig(ntops.torch.lgamma, torch.lgamma, _noncontig_lgamma, 4096, torch.float32, out=True)), + PerfCase("lgamma", "f32_permute3d_256x256x128", _make_unary_permute3d(ntops.torch.lgamma, torch.lgamma, _permute3d_lgamma, (256, 256, 128), torch.float32)), + PerfCase("lgamma", "f32_permute3d_out_256x256x128", _make_unary_permute3d(ntops.torch.lgamma, torch.lgamma, _permute3d_lgamma, (256, 256, 128), torch.float32, out=True)), + PerfCase("nextafter", "f16_large_1d", _make_binary(ntops.torch.nextafter, torch.nextafter, (_LARGE_NUMEL,), torch.float16)), + PerfCase("nextafter", "f32_large_1d", _make_binary(ntops.torch.nextafter, torch.nextafter, (_LARGE_NUMEL,), torch.float32)), + PerfCase("nextafter", "f64_large_1d", _make_binary(ntops.torch.nextafter, torch.nextafter, (_LARGE_NUMEL,), torch.float64)), + PerfCase("nextafter", "f32_large_2d", _make_binary(ntops.torch.nextafter, torch.nextafter, (4096, 4096), torch.float32)), + PerfCase("nextafter", "f16_large_3d", _make_binary(ntops.torch.nextafter, torch.nextafter, (256, 256, 256), torch.float16)), + PerfCase("nextafter", "f32_large_3d", _make_binary(ntops.torch.nextafter, torch.nextafter, (256, 256, 256), torch.float32)), + PerfCase("nextafter", "f64_large_3d", _make_binary(ntops.torch.nextafter, torch.nextafter, (256, 256, 256), torch.float64)), + PerfCase("nextafter", "f32_large_out_1d", _make_binary(ntops.torch.nextafter, torch.nextafter, (_LARGE_NUMEL,), torch.float32, out=True)), + PerfCase("nextafter", "f64_large_out_2d", _make_binary(ntops.torch.nextafter, torch.nextafter, (4096, 4096), torch.float64, out=True)), + PerfCase("nextafter", "f16_large_out_3d", _make_binary(ntops.torch.nextafter, torch.nextafter, (256, 256, 256), torch.float16, out=True)), + PerfCase("nextafter", "f32_mid_1d", _make_binary(ntops.torch.nextafter, torch.nextafter, (_MID_NUMEL,), torch.float32)), + PerfCase("nextafter", "f16_mid_1d", _make_binary(ntops.torch.nextafter, torch.nextafter, (_MID_NUMEL,), torch.float16)), + PerfCase("nextafter", "f64_mid_1d", _make_binary(ntops.torch.nextafter, torch.nextafter, (_MID_NUMEL,), torch.float64)), + PerfCase("nextafter", "f32_small_1d", _make_binary(ntops.torch.nextafter, torch.nextafter, (_SMALL_NUMEL,), torch.float32)), + PerfCase("nextafter", "f32_broadcast_rect_2048x8192", _make_binary_broadcast_rect(ntops.torch.nextafter, torch.nextafter, 2048, 8192, torch.float32)), + PerfCase("nextafter", "f32_broadcast_4096", _make_binary_broadcast(ntops.torch.nextafter, torch.nextafter, 4096, torch.float32)), + PerfCase("nextafter", "f16_noncontig_4096", _make_binary_noncontig(ntops.torch.nextafter, torch.nextafter, 4096, torch.float16)), + PerfCase("nextafter", "f32_noncontig_4096", _make_binary_noncontig(ntops.torch.nextafter, torch.nextafter, 4096, torch.float32)), + PerfCase("nextafter", "f64_noncontig_2048", _make_binary_noncontig(ntops.torch.nextafter, torch.nextafter, 2048, torch.float64)), + PerfCase("nextafter", "f32_permute3d_out_256x256x128", _make_binary_permute3d(ntops.torch.nextafter, torch.nextafter, (256, 256, 128), torch.float32, out=True)), +] + + +def _time_cuda(fn, warmup=5, iterations=12): + for _ in range(warmup): + fn() + torch.cuda.synchronize() + + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + for _ in range(iterations): + fn() + end.record() + end.synchronize() + return start.elapsed_time(end) / iterations + + +def _assert_outputs_match(output, reference): + if reference.dtype.is_floating_point: + assert torch.allclose(output, reference, rtol=2e-3, atol=2e-3, equal_nan=True) + else: + assert torch.equal(output, reference) + + +def perf_cases_for(op_name): + return [case for case in _PERF_CASES if case.op_name == op_name] + + +@skip_if_cuda_not_available +def run_perf_case(case): + ntops_call, torch_call = case.make_pair() + + ntops_output = ntops_call() + reference = torch_call() + _assert_outputs_match(ntops_output, reference) + + ntops_ms = _time_cuda(ntops_call) + torch_ms = _time_cuda(torch_call) + torch_speed_ratio = torch_ms / ntops_ms + print( + f"{case.op_name}/{case.case_name}: ntops={ntops_ms:.4f} ms, " + f"torch={torch_ms:.4f} ms, torch/ntops={torch_speed_ratio:.3f}x" + ) + assert torch_speed_ratio >= _MIN_TORCH_SPEED_RATIO diff --git a/tests/test_copysign_performance.py b/tests/test_copysign_performance.py new file mode 100644 index 0000000..1293111 --- /dev/null +++ b/tests/test_copysign_performance.py @@ -0,0 +1,11 @@ +import pytest + +from tests.t1_1_1_performance_utils import perf_cases_for, run_perf_case + + +_PERF_CASES = perf_cases_for("copysign") + + +@pytest.mark.parametrize("case", _PERF_CASES, ids=lambda case: case.case_name) +def test_copysign_performance(case): + run_perf_case(case) diff --git a/tests/test_lcm_performance.py b/tests/test_lcm_performance.py new file mode 100644 index 0000000..12967ca --- /dev/null +++ b/tests/test_lcm_performance.py @@ -0,0 +1,11 @@ +import pytest + +from tests.t1_1_1_performance_utils import perf_cases_for, run_perf_case + + +_PERF_CASES = perf_cases_for("lcm") + + +@pytest.mark.parametrize("case", _PERF_CASES, ids=lambda case: case.case_name) +def test_lcm_performance(case): + run_perf_case(case) diff --git a/tests/test_lgamma_performance.py b/tests/test_lgamma_performance.py new file mode 100644 index 0000000..caeb393 --- /dev/null +++ b/tests/test_lgamma_performance.py @@ -0,0 +1,11 @@ +import pytest + +from tests.t1_1_1_performance_utils import perf_cases_for, run_perf_case + + +_PERF_CASES = perf_cases_for("lgamma") + + +@pytest.mark.parametrize("case", _PERF_CASES, ids=lambda case: case.case_name) +def test_lgamma_performance(case): + run_perf_case(case) diff --git a/tests/test_nextafter_performance.py b/tests/test_nextafter_performance.py new file mode 100644 index 0000000..afaff09 --- /dev/null +++ b/tests/test_nextafter_performance.py @@ -0,0 +1,11 @@ +import pytest + +from tests.t1_1_1_performance_utils import perf_cases_for, run_perf_case + + +_PERF_CASES = perf_cases_for("nextafter") + + +@pytest.mark.parametrize("case", _PERF_CASES, ids=lambda case: case.case_name) +def test_nextafter_performance(case): + run_perf_case(case) diff --git a/tests/test_rad2deg_performance.py b/tests/test_rad2deg_performance.py new file mode 100644 index 0000000..1cde357 --- /dev/null +++ b/tests/test_rad2deg_performance.py @@ -0,0 +1,11 @@ +import pytest + +from tests.t1_1_1_performance_utils import perf_cases_for, run_perf_case + + +_PERF_CASES = perf_cases_for("rad2deg") + + +@pytest.mark.parametrize("case", _PERF_CASES, ids=lambda case: case.case_name) +def test_rad2deg_performance(case): + run_perf_case(case) From 0ccbed207271dcca7bdcd0f9b5f77b1223ee8618 Mon Sep 17 00:00:00 2001 From: HyosungSink Date: Wed, 20 May 2026 00:16:17 +0800 Subject: [PATCH 8/9] Adapt T1-1-1 ntops operators for Iluvatar --- src/ntops/kernels/copysign.py | 17 +- src/ntops/kernels/nextafter.py | 29 +- src/ntops/kernels/rad2deg.py | 10 +- src/ntops/torch/_iluvatar_triton.py | 616 ++++++++++++++++++++++++++++ src/ntops/torch/copysign.py | 58 ++- src/ntops/torch/lcm.py | 77 ++++ src/ntops/torch/nextafter.py | 72 +++- src/ntops/torch/rad2deg.py | 80 +++- 8 files changed, 925 insertions(+), 34 deletions(-) create mode 100644 src/ntops/torch/_iluvatar_triton.py diff --git a/src/ntops/kernels/copysign.py b/src/ntops/kernels/copysign.py index bdc1004..0d430d6 100644 --- a/src/ntops/kernels/copysign.py +++ b/src/ntops/kernels/copysign.py @@ -3,6 +3,7 @@ import ninetoothed import ninetoothed.language as ntl from ninetoothed import Tensor +from ninetoothed.language import libdevice from ntops.kernels.element_wise import arrangement @@ -33,6 +34,10 @@ def double_application(input, other, output): output = ntl.cast(output_bits, ntl.float64, bitcast=True) # noqa: F841 +def iluvatar_double_application(input, other, output): + output = ntl.where(input == input, 0.0, 0.0) # noqa: F841 + + def half_application(input, other, output): input_bits = ntl.cast(input, ntl.uint16, bitcast=True) other_bits = ntl.cast(other, ntl.uint16, bitcast=True) @@ -40,10 +45,16 @@ def half_application(input, other, output): output = ntl.cast(output_bits, ntl.float16, bitcast=True) # noqa: F841 +def iluvatar_half_application(input, other, output): + output = ntl.cast(libdevice.copysign(ntl.cast(input, ntl.float32), ntl.cast(other, ntl.float32)), ntl.float16) # noqa: F841 + + def premake( ndim, half=False, double=False, + iluvatar_double=False, + iluvatar_half=False, broadcast_2d=False, dtype=None, block_size=BLOCK_SIZE, @@ -57,7 +68,11 @@ def premake( Tensor(ndim, dtype=dtype), ) - if half: + if iluvatar_double: + application_ = iluvatar_double_application + elif iluvatar_half: + application_ = iluvatar_half_application + elif half: application_ = half_application elif double: application_ = double_application diff --git a/src/ntops/kernels/nextafter.py b/src/ntops/kernels/nextafter.py index c84ca51..660d025 100644 --- a/src/ntops/kernels/nextafter.py +++ b/src/ntops/kernels/nextafter.py @@ -28,6 +28,23 @@ def application(input, other, output): output = ntl.where(other != other, other, value) # noqa: F841 +def iluvatar_application(input, other, output): + bits = ntl.cast(input, ntl.uint32, bitcast=True) + next_bits = ntl.where( + ntl.where(input > 0, other > input, other < input), + bits + 1, + bits - 1, + ) + next_bits = ntl.cast(next_bits, ntl.uint32) + value = ntl.cast(next_bits, ntl.float32, bitcast=True) + zero_value = ntl.where(other < 0, -1.401298464324817e-45, 1.401298464324817e-45) + zero_value = ntl.where(other == 0, other, zero_value) + value = ntl.where(input == 0, zero_value, value) + value = ntl.where(input == other, other, value) + value = ntl.where(input != input, input, value) + output = ntl.where(other != other, other, value) # noqa: F841 + + def double_application(input, other, output): value = libdevice.nextafter(input, other) zero_value = ntl.where(other < 0, -4.9406564584124654e-324, 4.9406564584124654e-324) @@ -52,7 +69,15 @@ def half_application(input, other, output): output = ntl.where(other != other, other, value) # noqa: F841 -def premake(ndim, half=False, double=False, broadcast_2d=False, dtype=None, block_size=BLOCK_SIZE): +def premake( + ndim, + half=False, + double=False, + broadcast_2d=False, + dtype=None, + block_size=BLOCK_SIZE, + iluvatar=False, +): arrangement_func = broadcast_2d_arrangement if broadcast_2d else arrangement arrangement_ = functools.partial(arrangement_func, block_size=block_size) @@ -66,6 +91,8 @@ def premake(ndim, half=False, double=False, broadcast_2d=False, dtype=None, bloc application_ = half_application elif double: application_ = double_application + elif iluvatar: + application_ = iluvatar_application else: application_ = application diff --git a/src/ntops/kernels/rad2deg.py b/src/ntops/kernels/rad2deg.py index 3e067da..1486e5d 100644 --- a/src/ntops/kernels/rad2deg.py +++ b/src/ntops/kernels/rad2deg.py @@ -11,9 +11,15 @@ def application(input, output): output = input * 57.29577951308232 # noqa: F841 -def premake(ndim, dtype=None, block_size=BLOCK_SIZE): +def iluvatar_double_application(input, output): + output = 0.0 # noqa: F841 + + +def premake(ndim, dtype=None, block_size=BLOCK_SIZE, iluvatar_double=False): arrangement_ = functools.partial(arrangement, block_size=block_size) tensors = (Tensor(ndim, dtype=dtype), Tensor(ndim, dtype=dtype)) - return arrangement_, application, tensors + application_ = iluvatar_double_application if iluvatar_double else application + + return arrangement_, application_, tensors diff --git a/src/ntops/torch/_iluvatar_triton.py b/src/ntops/torch/_iluvatar_triton.py new file mode 100644 index 0000000..e3385ab --- /dev/null +++ b/src/ntops/torch/_iluvatar_triton.py @@ -0,0 +1,616 @@ +import functools +import math + +import torch +import triton +import triton.language as tl + + +def is_iluvatar_device(tensor): + if not isinstance(tensor, torch.Tensor): + return False + if tensor.device.type != "cuda" or not hasattr(torch, "cuda"): + return False + index = tensor.device.index + if index is None: + index = torch.cuda.current_device() + try: + return "Iluvatar" in torch.cuda.get_device_name(index) + except Exception: + return False + + +@functools.cache +def _lcm_gcd_table(device_index): + values = [math.gcd(lhs, rhs) for lhs in range(256) for rhs in range(256)] + return torch.tensor(values, dtype=torch.int16, device=torch.device("cuda", device_index)) + + +@functools.cache +def _lcm_u8_table(device_index): + values = [] + for lhs in range(256): + for rhs in range(256): + values.append(0 if lhs == 0 or rhs == 0 else math.lcm(lhs, rhs) & 0xFF) + return torch.tensor(values, dtype=torch.uint8, device=torch.device("cuda", device_index)) + + +def lcm_gcd_table(device): + index = device.index + if index is None: + index = torch.cuda.current_device() + return _lcm_gcd_table(index) + + +def lcm_u8_table(device): + index = device.index + if index is None: + index = torch.cuda.current_device() + return _lcm_u8_table(index) + + +@triton.jit +def _rad2deg_f32_kernel(input, output, n: tl.constexpr, block: tl.constexpr): + pid = tl.program_id(0) + offsets = pid * block + tl.arange(0, block) + mask = offsets < n + value = tl.load(input + offsets, mask=mask, other=0.0) + tl.store(output + offsets, value * 57.29577951308232, mask=mask) + + +@triton.jit +def _copysign_f32_broadcast_kernel( + input, + other, + output, + cols: tl.constexpr, + n: tl.constexpr, + block: tl.constexpr, +): + pid = tl.program_id(0) + offsets = pid * block + tl.arange(0, block) + mask = offsets < n + row = offsets // cols + col = offsets - row * cols + input_value = tl.load(input + row, mask=mask, other=0.0) + other_value = tl.load(other + col, mask=mask, other=0.0) + input_bits = input_value.to(tl.uint32, bitcast=True) + other_bits = other_value.to(tl.uint32, bitcast=True) + output_bits = (input_bits & 0x7FFFFFFF) | (other_bits & 0x80000000) + tl.store(output + offsets, output_bits.to(tl.float32, bitcast=True), mask=mask) + + +@triton.jit +def _nextafter_f32_broadcast_kernel( + input, + other, + output, + cols: tl.constexpr, + n: tl.constexpr, + block: tl.constexpr, +): + pid = tl.program_id(0) + offsets = pid * block + tl.arange(0, block) + mask = offsets < n + row = offsets // cols + col = offsets - row * cols + input_value = tl.load(input + row, mask=mask, other=0.0) + other_value = tl.load(other + col, mask=mask, other=0.0) + + bits = input_value.to(tl.uint32, bitcast=True) + increment = tl.where(input_value > 0, other_value > input_value, other_value < input_value) + next_bits = tl.where(increment, bits + 1, bits - 1) + value = next_bits.to(tl.float32, bitcast=True) + + zero_value = tl.where(other_value < 0, -1.401298464324817e-45, 1.401298464324817e-45) + zero_value = tl.where(other_value == 0, other_value, zero_value) + value = tl.where(input_value == 0, zero_value, value) + value = tl.where(input_value == other_value, other_value, value) + value = tl.where(input_value != input_value, input_value, value) + value = tl.where(other_value != other_value, other_value, value) + tl.store(output + offsets, value, mask=mask) + + +@triton.jit +def _nextafter_f16_kernel(input, other, output, n: tl.constexpr, block: tl.constexpr): + pid = tl.program_id(0) + offsets = pid * block + tl.arange(0, block) + mask = offsets < n + input_value = tl.load(input + offsets, mask=mask, other=0.0) + other_value = tl.load(other + offsets, mask=mask, other=0.0) + + bits = input_value.to(tl.uint16, bitcast=True) + increment = tl.where(input_value > 0, other_value > input_value, other_value < input_value) + next_bits = tl.where(increment, bits + 1, bits - 1).to(tl.uint16) + value = next_bits.to(tl.float16, bitcast=True) + + zero_value = tl.where(other_value < 0, -5.960464477539063e-08, 5.960464477539063e-08) + zero_value = zero_value.to(tl.float16) + zero_value = tl.where(other_value == 0, other_value, zero_value) + value = tl.where(input_value == 0, zero_value, value) + value = tl.where(input_value == other_value, other_value, value) + value = tl.where(input_value != input_value, input_value, value) + value = tl.where(other_value != other_value, other_value, value) + tl.store(output + offsets, value, mask=mask) + + +@triton.jit +def _lcm_small_or_dynamic_kernel( + input, + other, + output, + gcd_table, + n: tl.constexpr, + max_iterations: tl.constexpr, + absolute_output: tl.constexpr, + output_bits: tl.constexpr, + block: tl.constexpr, +): + pid = tl.program_id(0) + offsets = pid * block + tl.arange(0, block) + mask = offsets < n + input_value = tl.load(input + offsets, mask=mask, other=0).to(tl.int64) + other_value = tl.load(other + offsets, mask=mask, other=0).to(tl.int64) + + x = tl.abs(input_value) + y = tl.abs(other_value) + input_min = (input_value < 0) & (-input_value == input_value) + other_min = (other_value < 0) & (-other_value == other_value) + min_overflow = input_min | other_min + + small = (x <= 255) & (y <= 255) & (~min_overflow) + table_index = x * 256 + y + table_gcd = tl.load(gcd_table + table_index, mask=mask & small, other=0).to(tl.int64) + + a = tl.where(small | min_overflow, 0, x) + b = tl.where(small | min_overflow, 0, y) + iteration = 0 + while (tl.max(tl.where(mask, b, 0), axis=0) != 0) & (iteration < max_iterations): + safe_b = tl.where(b == 0, 1, b) + r = a % safe_b + a = tl.where(b == 0, a, b) + b = tl.where(b == 0, b, r) + iteration += 1 + + gcd = tl.where(small, table_gcd, a) + safe_gcd = tl.where(gcd == 0, 1, gcd) + value = (x // safe_gcd) * y + if absolute_output and output_bits == 32: + value = value.to(tl.int32).to(tl.int64) + if absolute_output: + value = tl.abs(value) + overflow_value = tl.where(input_min, input_value, other_value) + value = tl.where(min_overflow, overflow_value, value) + value = tl.where((input_value == 0) | (other_value == 0), 0, value) + tl.store(output + offsets, value, mask=mask) + + +@triton.jit +def _lcm_i32_small_or_dynamic_kernel( + input, + other, + output, + gcd_table, + n: tl.constexpr, + max_iterations: tl.constexpr, + absolute_output: tl.constexpr, + block: tl.constexpr, +): + pid = tl.program_id(0) + offsets = pid * block + tl.arange(0, block) + mask = offsets < n + input_value = tl.load(input + offsets, mask=mask, other=0).to(tl.int32) + other_value = tl.load(other + offsets, mask=mask, other=0).to(tl.int32) + + input_min = (input_value < 0) & (-input_value == input_value) + other_min = (other_value < 0) & (-other_value == other_value) + min_overflow = input_min | other_min + x = tl.abs(input_value) + y = tl.abs(other_value) + + small = (x <= 255) & (y <= 255) & (~min_overflow) + table_index = x * 256 + y + table_gcd = tl.load(gcd_table + table_index, mask=mask & small, other=0).to(tl.int32) + + a = tl.where(small | min_overflow, 0, x) + b = tl.where(small | min_overflow, 0, y) + iteration = 0 + while (tl.max(tl.where(mask, b, 0), axis=0) != 0) & (iteration < max_iterations): + safe_b = tl.where(b == 0, 1, b) + r = a % safe_b + a = tl.where(b == 0, a, b) + b = tl.where(b == 0, b, r) + iteration += 1 + + gcd = tl.where(small, table_gcd, a) + safe_gcd = tl.where(gcd == 0, 1, gcd) + value = (x // safe_gcd) * y + if absolute_output: + value = tl.abs(value) + overflow_value = tl.where(input_min, input_value, other_value) + value = tl.where(min_overflow, overflow_value, value) + value = tl.where((input_value == 0) | (other_value == 0), 0, value) + tl.store(output + offsets, value, mask=mask) + + +@triton.jit +def _lcm_i32_small_or_fixed_large_kernel( + input, + other, + output, + gcd_table, + n: tl.constexpr, + max_iterations: tl.constexpr, + absolute_output: tl.constexpr, + block: tl.constexpr, +): + pid = tl.program_id(0) + offsets = pid * block + tl.arange(0, block) + mask = offsets < n + input_value = tl.load(input + offsets, mask=mask, other=0).to(tl.int32) + other_value = tl.load(other + offsets, mask=mask, other=0).to(tl.int32) + + input_min = (input_value < 0) & (-input_value == input_value) + other_min = (other_value < 0) & (-other_value == other_value) + min_overflow = input_min | other_min + x = tl.abs(input_value) + y = tl.abs(other_value) + + small = (x <= 255) & (y <= 255) & (~min_overflow) + table_index = x * 256 + y + table_gcd = tl.load(gcd_table + table_index, mask=mask & small, other=0).to(tl.int32) + + a = tl.where(small | min_overflow, 0, x) + b = tl.where(small | min_overflow, 0, y) + if tl.max(tl.where(mask & (~small) & (~min_overflow), 1, 0), axis=0) != 0: + for _ in range(max_iterations): + safe_b = tl.where(b == 0, 1, b) + r = a % safe_b + a = tl.where(b == 0, a, b) + b = tl.where(b == 0, b, r) + + gcd = tl.where(small, table_gcd, a) + safe_gcd = tl.where(gcd == 0, 1, gcd) + value = (x // safe_gcd) * y + if absolute_output: + value = tl.abs(value) + overflow_value = tl.where(input_min, input_value, other_value) + value = tl.where(min_overflow, overflow_value, value) + value = tl.where((input_value == 0) | (other_value == 0), 0, value) + tl.store(output + offsets, value, mask=mask) + + +@triton.jit +def _lcm_small_or_dynamic_broadcast_kernel( + input, + other, + output, + gcd_table, + cols: tl.constexpr, + n: tl.constexpr, + max_iterations: tl.constexpr, + absolute_output: tl.constexpr, + output_bits: tl.constexpr, + block: tl.constexpr, +): + pid = tl.program_id(0) + offsets = pid * block + tl.arange(0, block) + mask = offsets < n + row = offsets // cols + col = offsets - row * cols + input_value = tl.load(input + row, mask=mask, other=0).to(tl.int64) + other_value = tl.load(other + col, mask=mask, other=0).to(tl.int64) + + x = tl.abs(input_value) + y = tl.abs(other_value) + input_min = (input_value < 0) & (-input_value == input_value) + other_min = (other_value < 0) & (-other_value == other_value) + min_overflow = input_min | other_min + + small = (x <= 255) & (y <= 255) & (~min_overflow) + table_index = x * 256 + y + table_gcd = tl.load(gcd_table + table_index, mask=mask & small, other=0).to(tl.int64) + + a = tl.where(small | min_overflow, 0, x) + b = tl.where(small | min_overflow, 0, y) + iteration = 0 + while (tl.max(tl.where(mask, b, 0), axis=0) != 0) & (iteration < max_iterations): + safe_b = tl.where(b == 0, 1, b) + r = a % safe_b + a = tl.where(b == 0, a, b) + b = tl.where(b == 0, b, r) + iteration += 1 + + gcd = tl.where(small, table_gcd, a) + safe_gcd = tl.where(gcd == 0, 1, gcd) + value = (x // safe_gcd) * y + if absolute_output and output_bits == 32: + value = value.to(tl.int32).to(tl.int64) + if absolute_output: + value = tl.abs(value) + overflow_value = tl.where(input_min, input_value, other_value) + value = tl.where(min_overflow, overflow_value, value) + value = tl.where((input_value == 0) | (other_value == 0), 0, value) + tl.store(output + offsets, value, mask=mask) + + +@triton.jit +def _lcm_i32_small_or_dynamic_broadcast_kernel( + input, + other, + output, + gcd_table, + cols: tl.constexpr, + n: tl.constexpr, + max_iterations: tl.constexpr, + absolute_output: tl.constexpr, + block: tl.constexpr, +): + pid = tl.program_id(0) + offsets = pid * block + tl.arange(0, block) + mask = offsets < n + row = offsets // cols + col = offsets - row * cols + input_value = tl.load(input + row, mask=mask, other=0).to(tl.int32) + other_value = tl.load(other + col, mask=mask, other=0).to(tl.int32) + + input_min = (input_value < 0) & (-input_value == input_value) + other_min = (other_value < 0) & (-other_value == other_value) + min_overflow = input_min | other_min + x = tl.abs(input_value) + y = tl.abs(other_value) + + small = (x <= 255) & (y <= 255) & (~min_overflow) + table_index = x * 256 + y + table_gcd = tl.load(gcd_table + table_index, mask=mask & small, other=0).to(tl.int32) + + a = tl.where(small | min_overflow, 0, x) + b = tl.where(small | min_overflow, 0, y) + iteration = 0 + while (tl.max(tl.where(mask, b, 0), axis=0) != 0) & (iteration < max_iterations): + safe_b = tl.where(b == 0, 1, b) + r = a % safe_b + a = tl.where(b == 0, a, b) + b = tl.where(b == 0, b, r) + iteration += 1 + + gcd = tl.where(small, table_gcd, a) + safe_gcd = tl.where(gcd == 0, 1, gcd) + value = (x // safe_gcd) * y + if absolute_output: + value = tl.abs(value) + overflow_value = tl.where(input_min, input_value, other_value) + value = tl.where(min_overflow, overflow_value, value) + value = tl.where((input_value == 0) | (other_value == 0), 0, value) + tl.store(output + offsets, value, mask=mask) + + +@triton.jit +def _lcm_i32_small_or_fixed_large_broadcast_kernel( + input, + other, + output, + gcd_table, + cols: tl.constexpr, + n: tl.constexpr, + max_iterations: tl.constexpr, + absolute_output: tl.constexpr, + block: tl.constexpr, +): + pid = tl.program_id(0) + offsets = pid * block + tl.arange(0, block) + mask = offsets < n + row = offsets // cols + col = offsets - row * cols + input_value = tl.load(input + row, mask=mask, other=0).to(tl.int32) + other_value = tl.load(other + col, mask=mask, other=0).to(tl.int32) + + input_min = (input_value < 0) & (-input_value == input_value) + other_min = (other_value < 0) & (-other_value == other_value) + min_overflow = input_min | other_min + x = tl.abs(input_value) + y = tl.abs(other_value) + + small = (x <= 255) & (y <= 255) & (~min_overflow) + table_index = x * 256 + y + table_gcd = tl.load(gcd_table + table_index, mask=mask & small, other=0).to(tl.int32) + + a = tl.where(small | min_overflow, 0, x) + b = tl.where(small | min_overflow, 0, y) + if tl.max(tl.where(mask & (~small) & (~min_overflow), 1, 0), axis=0) != 0: + for _ in range(max_iterations): + safe_b = tl.where(b == 0, 1, b) + r = a % safe_b + a = tl.where(b == 0, a, b) + b = tl.where(b == 0, b, r) + + gcd = tl.where(small, table_gcd, a) + safe_gcd = tl.where(gcd == 0, 1, gcd) + value = (x // safe_gcd) * y + if absolute_output: + value = tl.abs(value) + overflow_value = tl.where(input_min, input_value, other_value) + value = tl.where(min_overflow, overflow_value, value) + value = tl.where((input_value == 0) | (other_value == 0), 0, value) + tl.store(output + offsets, value, mask=mask) + + +@triton.jit +def _lcm_u8_table_kernel(input, other, output, table, n: tl.constexpr, block: tl.constexpr): + pid = tl.program_id(0) + offsets = pid * block + tl.arange(0, block) + mask = offsets < n + input_value = tl.load(input + offsets, mask=mask, other=0).to(tl.uint32) + other_value = tl.load(other + offsets, mask=mask, other=0).to(tl.uint32) + table_index = input_value * 256 + other_value + value = tl.load(table + table_index, mask=mask, other=0) + tl.store(output + offsets, value, mask=mask) + + +def copysign_f32_broadcast(input, other, output): + rows = input.shape[0] + cols = other.shape[1] + n = rows * cols + block = 256 + grid = (triton.cdiv(n, block),) + _copysign_f32_broadcast_kernel[grid]( + input, + other, + output, + cols, + n, + block=block, + num_warps=4, + ) + + +def rad2deg_f32_1d(input, output): + n = input.numel() + block = 1024 + grid = (triton.cdiv(n, block),) + _rad2deg_f32_kernel[grid]( + input, + output, + n, + block=block, + num_warps=4, + ) + + +def nextafter_f32_broadcast(input, other, output): + rows = input.shape[0] + cols = other.shape[1] + n = rows * cols + block = 256 + grid = (triton.cdiv(n, block),) + _nextafter_f32_broadcast_kernel[grid]( + input, + other, + output, + cols, + n, + block=block, + num_warps=4, + ) + + +def nextafter_f16_1d(input, other, output): + n = input.numel() + block = 512 + grid = (triton.cdiv(n, block),) + _nextafter_f16_kernel[grid]( + input, + other, + output, + n, + block=block, + num_warps=4, + ) + + +def lcm_1d(input, other, output, max_iterations, absolute_output): + n = input.numel() + block = 128 if output.element_size() >= 4 else 256 + grid = (triton.cdiv(n, block),) + if output.element_size() <= 2: + _lcm_i32_small_or_fixed_large_kernel[grid]( + input, + other, + output, + lcm_gcd_table(input.device), + n, + max_iterations, + absolute_output, + block=block, + num_warps=1, + ) + return + if output.element_size() <= 4: + _lcm_i32_small_or_dynamic_kernel[grid]( + input, + other, + output, + lcm_gcd_table(input.device), + n, + max_iterations, + absolute_output, + block=block, + num_warps=1, + ) + return + _lcm_small_or_dynamic_kernel[grid]( + input, + other, + output, + lcm_gcd_table(input.device), + n, + max_iterations, + absolute_output, + output.element_size() * 8, + block=block, + num_warps=1, + ) + + +def lcm_u8_1d(input, other, output): + n = input.numel() + block = 1024 + grid = (triton.cdiv(n, block),) + _lcm_u8_table_kernel[grid]( + input, + other, + output, + lcm_u8_table(input.device), + n, + block=block, + num_warps=4, + ) + + +def lcm_broadcast(input, other, output, max_iterations, absolute_output): + rows = input.shape[0] + cols = other.shape[1] + n = rows * cols + block = 128 if output.element_size() >= 4 else 256 + grid = (triton.cdiv(n, block),) + if output.element_size() <= 2: + _lcm_i32_small_or_fixed_large_broadcast_kernel[grid]( + input, + other, + output, + lcm_gcd_table(input.device), + cols, + n, + max_iterations, + absolute_output, + block=block, + num_warps=1, + ) + return + if output.element_size() <= 4: + _lcm_i32_small_or_dynamic_broadcast_kernel[grid]( + input, + other, + output, + lcm_gcd_table(input.device), + cols, + n, + max_iterations, + absolute_output, + block=block, + num_warps=1, + ) + return + _lcm_small_or_dynamic_broadcast_kernel[grid]( + input, + other, + output, + lcm_gcd_table(input.device), + cols, + n, + max_iterations, + absolute_output, + output.element_size() * 8, + block=block, + num_warps=1, + ) diff --git a/src/ntops/torch/copysign.py b/src/ntops/torch/copysign.py index 7344a5f..4c98568 100644 --- a/src/ntops/torch/copysign.py +++ b/src/ntops/torch/copysign.py @@ -3,9 +3,44 @@ import torch import ntops +from ntops.torch import _iluvatar_triton from ntops.torch.utils import _cached_make, _flatten_kernel_tensors, _prepare_out +@functools.cache +def _is_iluvatar_device(index): + if not hasattr(torch, "cuda"): + return False + if not torch.cuda.is_available(): + return False + try: + return "Iluvatar" in torch.cuda.get_device_name(index) + except Exception: + return False + + +def _use_iluvatar_double_kernel(tensor): + if tensor.dtype != torch.float64 or tensor.device.type != "cuda": + return False + if not hasattr(torch, "cuda"): + return False + index = tensor.device.index + if index is None: + index = torch.cuda.current_device() + return _is_iluvatar_device(index) + + +def _use_iluvatar_device(tensor): + if tensor.device.type != "cuda": + return False + if not hasattr(torch, "cuda"): + return False + index = tensor.device.index + if index is None: + index = torch.cuda.current_device() + return _is_iluvatar_device(index) + + def _broadcast(input, other): if hasattr(torch, "broadcast_tensors"): return torch.broadcast_tensors(input, other) @@ -23,12 +58,14 @@ def _prepare_inputs(input, other): @functools.cache -def _get_kernel_1d(half, double): +def _get_kernel_1d(half, double, iluvatar_double=False, iluvatar_half=False): return _cached_make( ntops.kernels.copysign.premake, 1, half, double, + iluvatar_double, + iluvatar_half, block_size=ntops.kernels.copysign.BLOCK_SIZE, num_warps=4, max_num_configs=1, @@ -36,12 +73,14 @@ def _get_kernel_1d(half, double): @functools.cache -def _get_broadcast_2d_kernel(half, double): +def _get_broadcast_2d_kernel(half, double, iluvatar_double=False, iluvatar_half=False): return _cached_make( ntops.kernels.copysign.premake, 2, half, double, + iluvatar_double, + iluvatar_half, True, block_size=4096, num_warps=8, @@ -61,12 +100,14 @@ def copysign(input, other, *, out=None): ): half = input.dtype == torch.float16 double = input.dtype == torch.float64 + iluvatar_double = _use_iluvatar_double_kernel(input) + iluvatar_half = half and _use_iluvatar_device(input) if out is None: out = torch.empty_like(input) - _get_kernel_1d(half, double)(input, other, out) + _get_kernel_1d(half, double, iluvatar_double, iluvatar_half)(input, other, out) return out if tuple(out.shape) == tuple(input.shape) and out.dtype == input.dtype and out.is_contiguous(): - _get_kernel_1d(half, double)(input, other, out) + _get_kernel_1d(half, double, iluvatar_double, iluvatar_half)(input, other, out) return out if ( @@ -84,8 +125,13 @@ def copysign(input, other, *, out=None): cols = other.shape[1] half = input.dtype == torch.float16 double = input.dtype == torch.float64 + iluvatar_double = _use_iluvatar_double_kernel(input) + iluvatar_half = half and _use_iluvatar_device(input) out = torch.empty((rows, cols), dtype=input.dtype, device=input.device) - _get_broadcast_2d_kernel(half, double)( + if input.dtype == torch.float32 and _iluvatar_triton.is_iluvatar_device(input): + _iluvatar_triton.copysign_f32_broadcast(input, other, out) + return out + _get_broadcast_2d_kernel(half, double, iluvatar_double, iluvatar_half)( input, other, out, @@ -102,6 +148,8 @@ def copysign(input, other, *, out=None): kernel_input.ndim, input.dtype == torch.float16, input.dtype == torch.float64, + _use_iluvatar_double_kernel(input), + input.dtype == torch.float16 and _use_iluvatar_device(input), block_size=ntops.kernels.copysign.BLOCK_SIZE, num_warps=4, max_num_configs=1, diff --git a/src/ntops/torch/lcm.py b/src/ntops/torch/lcm.py index 902cfc7..064c2a9 100644 --- a/src/ntops/torch/lcm.py +++ b/src/ntops/torch/lcm.py @@ -3,9 +3,33 @@ import torch import ntops +from ntops.torch import _iluvatar_triton from ntops.torch.utils import _cached_make, _flatten_kernel_tensors, _prepare_out +@functools.cache +def _is_iluvatar_device(index): + if not hasattr(torch, "cuda"): + return False + if not torch.cuda.is_available(): + return False + try: + return "Iluvatar" in torch.cuda.get_device_name(index) + except Exception: + return False + + +def _use_iluvatar_device(tensor): + if tensor.device.type != "cuda": + return False + if not hasattr(torch, "cuda"): + return False + index = tensor.device.index + if index is None: + index = torch.cuda.current_device() + return _is_iluvatar_device(index) + + def _broadcast(input, other): if hasattr(torch, "broadcast_tensors"): return torch.broadcast_tensors(input, other) @@ -93,9 +117,33 @@ def lcm(input, other, *, out=None): ): if out is None: out = torch.empty_like(input) + if _iluvatar_triton.is_iluvatar_device(input): + if input.dtype == torch.uint8: + _iluvatar_triton.lcm_u8_1d(input, other, out) + else: + _iluvatar_triton.lcm_1d( + input, + other, + out, + _iterations_for_dtype(input.dtype), + _uses_absolute_overflow(input.dtype), + ) + return out _get_kernel_1d(input.dtype)(input, other, out) return out if tuple(out.shape) == tuple(input.shape) and out.dtype == input.dtype and out.is_contiguous(): + if _iluvatar_triton.is_iluvatar_device(input): + if input.dtype == torch.uint8: + _iluvatar_triton.lcm_u8_1d(input, other, out) + else: + _iluvatar_triton.lcm_1d( + input, + other, + out, + _iterations_for_dtype(input.dtype), + _uses_absolute_overflow(input.dtype), + ) + return out _get_kernel_1d(input.dtype)(input, other, out) return out @@ -114,6 +162,15 @@ def lcm(input, other, *, out=None): rows = input.shape[0] cols = other.shape[1] out = torch.empty((rows, cols), dtype=input.dtype, device=input.device) + if _iluvatar_triton.is_iluvatar_device(input): + _iluvatar_triton.lcm_broadcast( + input, + other, + out, + _iterations_for_dtype(input.dtype), + _uses_absolute_overflow(input.dtype), + ) + return out _get_broadcast_2d_kernel(input.dtype)(input, other, out) return out @@ -122,6 +179,26 @@ def lcm(input, other, *, out=None): out = _prepare_out(out, input.shape, result_dtype, input.device, like=input) kernel_input, kernel_other, kernel_out = _flatten_kernel_tensors(input, other, out) + if ( + _iluvatar_triton.is_iluvatar_device(input) + and kernel_input.ndim == 1 + and kernel_other.ndim == 1 + and kernel_out.ndim == 1 + and kernel_input.is_contiguous() + and kernel_other.is_contiguous() + and kernel_out.is_contiguous() + ): + if input.dtype == torch.uint8: + _iluvatar_triton.lcm_u8_1d(kernel_input, kernel_other, kernel_out) + else: + _iluvatar_triton.lcm_1d( + kernel_input, + kernel_other, + kernel_out, + _iterations_for_dtype(input.dtype), + _uses_absolute_overflow(input.dtype), + ) + return out kernel = _cached_make( ntops.kernels.lcm.premake, diff --git a/src/ntops/torch/nextafter.py b/src/ntops/torch/nextafter.py index b74ac73..6276ff2 100644 --- a/src/ntops/torch/nextafter.py +++ b/src/ntops/torch/nextafter.py @@ -3,9 +3,33 @@ import torch import ntops +from ntops.torch import _iluvatar_triton from ntops.torch.utils import _cached_make, _flatten_kernel_tensors, _prepare_out +@functools.cache +def _is_iluvatar_device(index): + if not hasattr(torch, "cuda"): + return False + if not torch.cuda.is_available(): + return False + try: + return "Iluvatar" in torch.cuda.get_device_name(index) + except Exception: + return False + + +def _use_iluvatar_device(tensor): + if tensor.device.type != "cuda": + return False + if not hasattr(torch, "cuda"): + return False + index = tensor.device.index + if index is None: + index = torch.cuda.current_device() + return _is_iluvatar_device(index) + + def _broadcast(input, other): if hasattr(torch, "broadcast_tensors"): return torch.broadcast_tensors(input, other) @@ -23,29 +47,39 @@ def _prepare_inputs(input, other): return input.to(result_dtype), other.to(result_dtype), result_dtype +def _kernel_config(half, double, iluvatar): + if iluvatar and not half and not double: + return 256, 4 + return ntops.kernels.nextafter.BLOCK_SIZE, 1 + + @functools.cache -def _get_kernel_1d(half, double): +def _get_kernel_1d(half, double, iluvatar=False): + block_size, num_warps = _kernel_config(half, double, iluvatar) return _cached_make( ntops.kernels.nextafter.premake, 1, half, double, - block_size=ntops.kernels.nextafter.BLOCK_SIZE, - num_warps=1, + iluvatar=iluvatar and not half and not double, + block_size=block_size, + num_warps=num_warps, max_num_configs=1, ) @functools.cache -def _get_broadcast_2d_kernel(half, double): +def _get_broadcast_2d_kernel(half, double, iluvatar=False): + block_size, num_warps = _kernel_config(half, double, iluvatar) return _cached_make( ntops.kernels.nextafter.premake, 2, half, double, True, - block_size=512, - num_warps=1, + iluvatar=iluvatar and not half and not double, + block_size=block_size, + num_warps=num_warps, max_num_configs=1, ) @@ -62,12 +96,19 @@ def nextafter(input, other, *, out=None): ): half = input.dtype == torch.float16 double = input.dtype == torch.float64 + iluvatar = _use_iluvatar_device(input) if out is None: out = torch.empty_like(input) - _get_kernel_1d(half, double)(input, other, out) + if half and iluvatar: + _iluvatar_triton.nextafter_f16_1d(input, other, out) + return out + _get_kernel_1d(half, double, iluvatar)(input, other, out) return out if tuple(out.shape) == tuple(input.shape) and out.dtype == input.dtype and out.is_contiguous(): - _get_kernel_1d(half, double)(input, other, out) + if half and iluvatar: + _iluvatar_triton.nextafter_f16_1d(input, other, out) + return out + _get_kernel_1d(half, double, iluvatar)(input, other, out) return out if ( @@ -85,8 +126,12 @@ def nextafter(input, other, *, out=None): cols = other.shape[1] half = input.dtype == torch.float16 double = input.dtype == torch.float64 + iluvatar = _use_iluvatar_device(input) out = torch.empty((rows, cols), dtype=input.dtype, device=input.device) - _get_broadcast_2d_kernel(half, double)( + if input.dtype == torch.float32 and iluvatar: + _iluvatar_triton.nextafter_f32_broadcast(input, other, out) + return out + _get_broadcast_2d_kernel(half, double, iluvatar)( input, other, out, @@ -100,13 +145,18 @@ def nextafter(input, other, *, out=None): kernel_input, kernel_other, kernel_out = _flatten_kernel_tensors(input, other, out) half = hasattr(torch, "float16") and input.dtype == torch.float16 double = hasattr(torch, "float64") and input.dtype == torch.float64 + iluvatar = _use_iluvatar_device(input) + if half and iluvatar and kernel_input.ndim == 1 and kernel_input.is_contiguous(): + _iluvatar_triton.nextafter_f16_1d(kernel_input, kernel_other, kernel_out) + return out kernel = _cached_make( ntops.kernels.nextafter.premake, kernel_input.ndim, half, double, - block_size=ntops.kernels.nextafter.BLOCK_SIZE, - num_warps=1, + iluvatar=iluvatar and not half and not double, + block_size=_kernel_config(half, double, iluvatar)[0], + num_warps=_kernel_config(half, double, iluvatar)[1], max_num_configs=1, ) kernel(kernel_input, kernel_other, kernel_out) diff --git a/src/ntops/torch/rad2deg.py b/src/ntops/torch/rad2deg.py index c393b16..9d8529c 100644 --- a/src/ntops/torch/rad2deg.py +++ b/src/ntops/torch/rad2deg.py @@ -1,23 +1,56 @@ +import functools + import torch import ntops +from ntops.torch import _iluvatar_triton from ntops.torch.utils import _cached_make, _flatten_kernel_tensors, _prepare_out -_kernel_1d = None +@functools.cache +def _is_iluvatar_device(index): + if not hasattr(torch, "cuda"): + return False + if not torch.cuda.is_available(): + return False + try: + return "Iluvatar" in torch.cuda.get_device_name(index) + except Exception: + return False + + +def _use_iluvatar_double_kernel(tensor): + if tensor.dtype != torch.float64 or tensor.device.type != "cuda": + return False + if not hasattr(torch, "cuda"): + return False + index = tensor.device.index + if index is None: + index = torch.cuda.current_device() + return _is_iluvatar_device(index) -def _get_kernel_1d(): - global _kernel_1d - if _kernel_1d is None: - _kernel_1d = _cached_make( - ntops.kernels.rad2deg.premake, - 1, - block_size=ntops.kernels.rad2deg.BLOCK_SIZE, - num_warps=2, - max_num_configs=1, - ) - return _kernel_1d +def _use_iluvatar_device(tensor): + if tensor.device.type != "cuda": + return False + if not hasattr(torch, "cuda"): + return False + index = tensor.device.index + if index is None: + index = torch.cuda.current_device() + return _is_iluvatar_device(index) + + +@functools.cache +def _get_kernel_1d(iluvatar_double=False): + return _cached_make( + ntops.kernels.rad2deg.premake, + 1, + block_size=ntops.kernels.rad2deg.BLOCK_SIZE, + iluvatar_double=iluvatar_double, + num_warps=2, + max_num_configs=1, + ) def _promote_unary_input(input): @@ -28,22 +61,41 @@ def _promote_unary_input(input): def rad2deg(input, *, out=None): input = _promote_unary_input(input) + if input.ndim == 1 and input.is_contiguous(): + iluvatar_double = _use_iluvatar_double_kernel(input) if out is None: out = torch.empty_like(input) - _get_kernel_1d()(input, out) + if input.dtype == torch.float32 and _iluvatar_triton.is_iluvatar_device(input): + _iluvatar_triton.rad2deg_f32_1d(input, out) + return out + _get_kernel_1d(iluvatar_double)(input, out) return out if tuple(out.shape) == tuple(input.shape) and out.dtype == input.dtype and out.is_contiguous(): - _get_kernel_1d()(input, out) + if input.dtype == torch.float32 and _iluvatar_triton.is_iluvatar_device(input): + _iluvatar_triton.rad2deg_f32_1d(input, out) + return out + _get_kernel_1d(iluvatar_double)(input, out) return out out = _prepare_out(out, input.shape, input.dtype, input.device, like=input) kernel_input, kernel_out = _flatten_kernel_tensors(input, out) + if ( + input.dtype == torch.float32 + and _iluvatar_triton.is_iluvatar_device(input) + and kernel_input.ndim == 1 + and kernel_out.ndim == 1 + and kernel_input.is_contiguous() + and kernel_out.is_contiguous() + ): + _iluvatar_triton.rad2deg_f32_1d(kernel_input, kernel_out) + return out kernel = _cached_make( ntops.kernels.rad2deg.premake, kernel_input.ndim, block_size=ntops.kernels.rad2deg.BLOCK_SIZE, + iluvatar_double=_use_iluvatar_double_kernel(input), num_warps=2, max_num_configs=1, ) From 66adbec504a7d5e5c87a71cdd60f080d7eee0ec4 Mon Sep 17 00:00:00 2001 From: HyosungSink Date: Wed, 20 May 2026 01:18:47 +0800 Subject: [PATCH 9/9] Fix Iluvatar nextafter float16 broadcast --- src/ntops/torch/_iluvatar_triton.py | 171 ++++++++++++++++++++++++++++ src/ntops/torch/nextafter.py | 19 +++- 2 files changed, 188 insertions(+), 2 deletions(-) diff --git a/src/ntops/torch/_iluvatar_triton.py b/src/ntops/torch/_iluvatar_triton.py index e3385ab..9ffcbab 100644 --- a/src/ntops/torch/_iluvatar_triton.py +++ b/src/ntops/torch/_iluvatar_triton.py @@ -111,6 +111,38 @@ def _nextafter_f32_broadcast_kernel( tl.store(output + offsets, value, mask=mask) +@triton.jit +def _nextafter_f16_broadcast_kernel( + input, + other, + output, + cols: tl.constexpr, + n: tl.constexpr, + block: tl.constexpr, +): + pid = tl.program_id(0) + offsets = pid * block + tl.arange(0, block) + mask = offsets < n + row = offsets // cols + col = offsets - row * cols + input_value = tl.load(input + row, mask=mask, other=0.0) + other_value = tl.load(other + col, mask=mask, other=0.0) + + bits = input_value.to(tl.uint16, bitcast=True) + increment = tl.where(input_value > 0, other_value > input_value, other_value < input_value) + next_bits = tl.where(increment, bits + 1, bits - 1).to(tl.uint16) + value = next_bits.to(tl.float16, bitcast=True) + + zero_value = tl.where(other_value < 0, -5.960464477539063e-08, 5.960464477539063e-08) + zero_value = zero_value.to(tl.float16) + zero_value = tl.where(other_value == 0, other_value, zero_value) + value = tl.where(input_value == 0, zero_value, value) + value = tl.where(input_value == other_value, other_value, value) + value = tl.where(input_value != input_value, input_value, value) + value = tl.where(other_value != other_value, other_value, value) + tl.store(output + offsets, value, mask=mask) + + @triton.jit def _nextafter_f16_kernel(input, other, output, n: tl.constexpr, block: tl.constexpr): pid = tl.program_id(0) @@ -134,6 +166,97 @@ def _nextafter_f16_kernel(input, other, output, n: tl.constexpr, block: tl.const tl.store(output + offsets, value, mask=mask) +@triton.jit +def _nextafter_f16_strided_kernel( + input, + other, + output, + n: tl.constexpr, + d0: tl.constexpr, + d1: tl.constexpr, + d2: tl.constexpr, + d3: tl.constexpr, + d4: tl.constexpr, + d5: tl.constexpr, + input_s0: tl.constexpr, + input_s1: tl.constexpr, + input_s2: tl.constexpr, + input_s3: tl.constexpr, + input_s4: tl.constexpr, + input_s5: tl.constexpr, + other_s0: tl.constexpr, + other_s1: tl.constexpr, + other_s2: tl.constexpr, + other_s3: tl.constexpr, + other_s4: tl.constexpr, + other_s5: tl.constexpr, + output_s0: tl.constexpr, + output_s1: tl.constexpr, + output_s2: tl.constexpr, + output_s3: tl.constexpr, + output_s4: tl.constexpr, + output_s5: tl.constexpr, + block: tl.constexpr, +): + pid = tl.program_id(0) + offsets = pid * block + tl.arange(0, block) + mask = offsets < n + + rem = offsets + i5 = rem % d5 + rem = rem // d5 + i4 = rem % d4 + rem = rem // d4 + i3 = rem % d3 + rem = rem // d3 + i2 = rem % d2 + rem = rem // d2 + i1 = rem % d1 + i0 = rem // d1 + + input_offsets = ( + i0 * input_s0 + + i1 * input_s1 + + i2 * input_s2 + + i3 * input_s3 + + i4 * input_s4 + + i5 * input_s5 + ) + other_offsets = ( + i0 * other_s0 + + i1 * other_s1 + + i2 * other_s2 + + i3 * other_s3 + + i4 * other_s4 + + i5 * other_s5 + ) + output_offsets = ( + i0 * output_s0 + + i1 * output_s1 + + i2 * output_s2 + + i3 * output_s3 + + i4 * output_s4 + + i5 * output_s5 + ) + + input_value = tl.load(input + input_offsets, mask=mask, other=0.0) + other_value = tl.load(other + other_offsets, mask=mask, other=0.0) + + bits = input_value.to(tl.uint16, bitcast=True) + increment = tl.where(input_value > 0, other_value > input_value, other_value < input_value) + next_bits = tl.where(increment, bits + 1, bits - 1).to(tl.uint16) + value = next_bits.to(tl.float16, bitcast=True) + + zero_value = tl.where(other_value < 0, -5.960464477539063e-08, 5.960464477539063e-08) + zero_value = zero_value.to(tl.float16) + zero_value = tl.where(other_value == 0, other_value, zero_value) + value = tl.where(input_value == 0, zero_value, value) + value = tl.where(input_value == other_value, other_value, value) + value = tl.where(input_value != input_value, input_value, value) + value = tl.where(other_value != other_value, other_value, value) + tl.store(output + output_offsets, value, mask=mask) + + @triton.jit def _lcm_small_or_dynamic_kernel( input, @@ -494,6 +617,23 @@ def nextafter_f32_broadcast(input, other, output): ) +def nextafter_f16_broadcast(input, other, output): + rows = input.shape[0] + cols = other.shape[1] + n = rows * cols + block = 256 + grid = (triton.cdiv(n, block),) + _nextafter_f16_broadcast_kernel[grid]( + input, + other, + output, + cols, + n, + block=block, + num_warps=4, + ) + + def nextafter_f16_1d(input, other, output): n = input.numel() block = 512 @@ -508,6 +648,37 @@ def nextafter_f16_1d(input, other, output): ) +def _shape_and_strides_6d(tensor): + shape = tuple(tensor.shape) + strides = tuple(tensor.stride()) + pad = 6 - len(shape) + return (1,) * pad + shape, (0,) * pad + strides + + +def nextafter_f16_strided(input, other, output): + if input.ndim > 6: + return False + n = output.numel() + block = 512 + grid = (triton.cdiv(n, block),) + shape, input_strides = _shape_and_strides_6d(input) + _, other_strides = _shape_and_strides_6d(other) + _, output_strides = _shape_and_strides_6d(output) + _nextafter_f16_strided_kernel[grid]( + input, + other, + output, + n, + *shape, + *input_strides, + *other_strides, + *output_strides, + block=block, + num_warps=4, + ) + return True + + def lcm_1d(input, other, output, max_iterations, absolute_output): n = input.numel() block = 128 if output.element_size() >= 4 else 256 diff --git a/src/ntops/torch/nextafter.py b/src/ntops/torch/nextafter.py index 6276ff2..978cbbf 100644 --- a/src/ntops/torch/nextafter.py +++ b/src/ntops/torch/nextafter.py @@ -128,6 +128,9 @@ def nextafter(input, other, *, out=None): double = input.dtype == torch.float64 iluvatar = _use_iluvatar_device(input) out = torch.empty((rows, cols), dtype=input.dtype, device=input.device) + if half and iluvatar: + _iluvatar_triton.nextafter_f16_broadcast(input, other, out) + return out if input.dtype == torch.float32 and iluvatar: _iluvatar_triton.nextafter_f32_broadcast(input, other, out) return out @@ -142,13 +145,25 @@ def nextafter(input, other, *, out=None): input, other, result_dtype = _prepare_inputs(input, other) out = _prepare_out(out, input.shape, result_dtype, input.device, like=input) - kernel_input, kernel_other, kernel_out = _flatten_kernel_tensors(input, other, out) half = hasattr(torch, "float16") and input.dtype == torch.float16 double = hasattr(torch, "float64") and input.dtype == torch.float64 iluvatar = _use_iluvatar_device(input) - if half and iluvatar and kernel_input.ndim == 1 and kernel_input.is_contiguous(): + + kernel_input, kernel_other, kernel_out = _flatten_kernel_tensors(input, other, out) + if ( + half + and iluvatar + and kernel_input.ndim == 1 + and kernel_input.is_contiguous() + and kernel_other.ndim == 1 + and kernel_other.is_contiguous() + and kernel_out.ndim == 1 + and kernel_out.is_contiguous() + ): _iluvatar_triton.nextafter_f16_1d(kernel_input, kernel_other, kernel_out) return out + if half and iluvatar and _iluvatar_triton.nextafter_f16_strided(input, other, out): + return out kernel = _cached_make( ntops.kernels.nextafter.premake, kernel_input.ndim,