From a2add0154bf605510594adf984f0cee2710e6040 Mon Sep 17 00:00:00 2001 From: mygitljf <2410316423@qq.com> Date: Mon, 18 May 2026 12:47:50 +0000 Subject: [PATCH 1/7] [2026 spring][T1-1-1] add rad2deg/copysign/lcm/nextafter/lgamma Five element-wise operators implemented in NineToothed DSL on top of ntops.kernels.element_wise.arrangement. - rad2deg: input * (180 / pi). - copysign: IEEE bit manipulation; bitcast to int, splice magnitude bits of input with sign bit of other. Avoids the fp16/bf16 -> fp32 cast required by libdevice.copysign and is bit-exact across fp16/bf16/fp32. - nextafter: IEEE bit manipulation; handle NaN/equal/zero specially, then walk one ULP toward the target in integer space. Bit-exact across fp16/bf16/fp32. - lcm: Euclidean GCD with dtype-bucketed unroll counts (28 for int8/int16, 56 for int32, 104 for int64) -- the worst-case Fibonacci bound. The wrapper passes dtype to _cached_make so the per-dtype application is selected correctly. - lgamma: libdevice.lgamma; cast fp16/bf16 up to fp32 since libdevice does not support narrow floats here. Tests cover float16/float32 (lcm: int8/int16/int32) for ndim 1-4 against PyTorch references using torch.allclose / torch.equal. Verified on NVIDIA A100 80GB PCIe (44 passed, 4 skipped for bool lcm). --- src/ntops/kernels/__init__.py | 10 ++++ src/ntops/kernels/copysign.py | 39 ++++++++++++++ src/ntops/kernels/lcm.py | 96 ++++++++++++++++++++++++++++++++++ src/ntops/kernels/lgamma.py | 29 ++++++++++ src/ntops/kernels/nextafter.py | 63 ++++++++++++++++++++++ src/ntops/kernels/rad2deg.py | 17 ++++++ src/ntops/torch/__init__.py | 10 ++++ src/ntops/torch/copysign.py | 15 ++++++ src/ntops/torch/lcm.py | 26 +++++++++ src/ntops/torch/lgamma.py | 15 ++++++ src/ntops/torch/nextafter.py | 15 ++++++ src/ntops/torch/rad2deg.py | 15 ++++++ tests/test_copysign.py | 18 +++++++ tests/test_lcm.py | 26 +++++++++ tests/test_lgamma.py | 19 +++++++ tests/test_nextafter.py | 20 +++++++ tests/test_rad2deg.py | 17 ++++++ 17 files changed, 450 insertions(+) create mode 100644 src/ntops/kernels/copysign.py create mode 100644 src/ntops/kernels/lcm.py create mode 100644 src/ntops/kernels/lgamma.py create mode 100644 src/ntops/kernels/nextafter.py create mode 100644 src/ntops/kernels/rad2deg.py create mode 100644 src/ntops/torch/copysign.py create mode 100644 src/ntops/torch/lcm.py create mode 100644 src/ntops/torch/lgamma.py create mode 100644 src/ntops/torch/nextafter.py create mode 100644 src/ntops/torch/rad2deg.py create mode 100644 tests/test_copysign.py create mode 100644 tests/test_lcm.py create mode 100644 tests/test_lgamma.py create mode 100644 tests/test_nextafter.py create mode 100644 tests/test_rad2deg.py diff --git a/src/ntops/kernels/__init__.py b/src/ntops/kernels/__init__.py index f6934ef..8d0c187 100644 --- a/src/ntops/kernels/__init__.py +++ b/src/ntops/kernels/__init__.py @@ -9,6 +9,7 @@ bmm, clamp, conv2d, + copysign, cos, div, dropout, @@ -20,14 +21,18 @@ isinf, isnan, layer_norm, + lcm, le, + lgamma, lt, max_pool2d, mm, mul, ne, neg, + nextafter, pow, + rad2deg, relu, rms_norm, rotary_position_embedding, @@ -52,6 +57,7 @@ "bmm", "clamp", "conv2d", + "copysign", "cos", "div", "dropout", @@ -63,14 +69,18 @@ "isinf", "isnan", "layer_norm", + "lcm", "le", + "lgamma", "lt", "max_pool2d", "mm", "mul", "ne", "neg", + "nextafter", "pow", + "rad2deg", "relu", "rms_norm", "rotary_position_embedding", diff --git a/src/ntops/kernels/copysign.py b/src/ntops/kernels/copysign.py new file mode 100644 index 0000000..bad16e5 --- /dev/null +++ b/src/ntops/kernels/copysign.py @@ -0,0 +1,39 @@ +import functools + +import ninetoothed.language as ntl +from ninetoothed import Tensor + +from ntops.kernels.element_wise import arrangement + + +def application(input, other, output): + # Pure bit manipulation: take magnitude bits of input, sign bit of other. + # Avoids the fp16/bf16 -> fp32 -> fp16/bf16 round-trip required by + # libdevice.copysign, which doesn't support narrow floats. + dtype = output.dtype + if dtype == ntl.float16 or dtype == ntl.bfloat16: + int_dtype = ntl.int16 + elif dtype == ntl.float32: + int_dtype = ntl.int32 + else: + int_dtype = ntl.int64 + + input_bits = ntl.cast(input, int_dtype, bitcast=True) + other_bits = ntl.cast(other, int_dtype, bitcast=True) + sign_bit = ntl.cast(1, int_dtype) << (ntl.cast(input, int_dtype, bitcast=True).dtype.primitive_bitwidth - 1) + magn_mask = sign_bit - ntl.cast(1, int_dtype) + output = ntl.cast( # noqa: F841 + (input_bits & magn_mask) | (other_bits & sign_bit), dtype, bitcast=True + ) + + +def premake(ndim, dtype=None, block_size=None): + arrangement_ = functools.partial(arrangement, block_size=block_size) + + tensors = ( + Tensor(ndim, dtype=dtype), + Tensor(ndim, dtype=dtype), + Tensor(ndim, dtype=dtype), + ) + + return arrangement_, application, tensors diff --git a/src/ntops/kernels/lcm.py b/src/ntops/kernels/lcm.py new file mode 100644 index 0000000..b51db20 --- /dev/null +++ b/src/ntops/kernels/lcm.py @@ -0,0 +1,96 @@ +import functools + +import ninetoothed +import ninetoothed.language as ntl +from ninetoothed import Tensor + +from ntops.kernels.element_wise import arrangement + + +# Match PyTorch's CUDA C++ integer promotion: narrow ints (int8/int16) +# are promoted to int32 for arithmetic and abs, then truncated back. +# Worst-case Euclidean iteration count is bounded by adjacent Fibonacci +# F(N) > 2^bits, so we unroll just enough per dtype to avoid wasted work. +def application_28(input, other, output): + dtype = output.dtype + compute_dtype = ( + ntl.int32 if dtype == ntl.int8 or dtype == ntl.int16 else dtype + ) + abs_a = ntl.abs(ntl.cast(input, compute_dtype)) + abs_b = ntl.abs(ntl.cast(other, compute_dtype)) + a, b = abs_a, abs_b + for _ in range(28): + nonzero = b != 0 + safe_b = ntl.where(nonzero, b, 1) + new_a = ntl.where(nonzero, b, a) + new_b = ntl.where(nonzero, a % safe_b, b) + a = new_a + b = new_b + gcd = a + safe_gcd = ntl.where(gcd == 0, 1, gcd) + output = ntl.cast( # noqa: F841 + ntl.where(gcd == 0, 0, ntl.abs((abs_a // safe_gcd) * abs_b)), dtype + ) + + +def application_56(input, other, output): + dtype = output.dtype + compute_dtype = ( + ntl.int32 if dtype == ntl.int8 or dtype == ntl.int16 else dtype + ) + abs_a = ntl.abs(ntl.cast(input, compute_dtype)) + abs_b = ntl.abs(ntl.cast(other, compute_dtype)) + a, b = abs_a, abs_b + for _ in range(56): + nonzero = b != 0 + safe_b = ntl.where(nonzero, b, 1) + new_a = ntl.where(nonzero, b, a) + new_b = ntl.where(nonzero, a % safe_b, b) + a = new_a + b = new_b + gcd = a + safe_gcd = ntl.where(gcd == 0, 1, gcd) + output = ntl.cast( # noqa: F841 + ntl.where(gcd == 0, 0, ntl.abs((abs_a // safe_gcd) * abs_b)), dtype + ) + + +def application_104(input, other, output): + dtype = output.dtype + compute_dtype = ( + ntl.int32 if dtype == ntl.int8 or dtype == ntl.int16 else dtype + ) + abs_a = ntl.abs(ntl.cast(input, compute_dtype)) + abs_b = ntl.abs(ntl.cast(other, compute_dtype)) + a, b = abs_a, abs_b + for _ in range(104): + nonzero = b != 0 + safe_b = ntl.where(nonzero, b, 1) + new_a = ntl.where(nonzero, b, a) + new_b = ntl.where(nonzero, a % safe_b, b) + a = new_a + b = new_b + gcd = a + safe_gcd = ntl.where(gcd == 0, 1, gcd) + output = ntl.cast( # noqa: F841 + ntl.where(gcd == 0, 0, ntl.abs((abs_a // safe_gcd) * abs_b)), dtype + ) + + +def premake(ndim, dtype=None, block_size=None): + arrangement_ = functools.partial(arrangement, block_size=block_size) + + if dtype == ninetoothed.int64: + application = application_104 + elif dtype == ninetoothed.int32: + application = application_56 + else: + application = application_28 + + tensors = ( + Tensor(ndim, dtype=dtype), + Tensor(ndim, dtype=dtype), + Tensor(ndim, dtype=dtype), + ) + + return arrangement_, application, tensors diff --git a/src/ntops/kernels/lgamma.py b/src/ntops/kernels/lgamma.py new file mode 100644 index 0000000..065f207 --- /dev/null +++ b/src/ntops/kernels/lgamma.py @@ -0,0 +1,29 @@ +import functools + +import ninetoothed.language as ntl +from ninetoothed import Tensor +from ninetoothed.language import libdevice + +from ntops.kernels.element_wise import arrangement + + +def application(input, output): + # libdevice.lgamma only supports fp32/fp64; cast narrower floats up. + dtype = output.dtype + compute_dtype = ( + dtype + if dtype != ntl.float16 and dtype != ntl.bfloat16 + else ntl.float32 + ) + output = ntl.cast( # noqa: F841 + libdevice.lgamma(ntl.cast(input, compute_dtype)), + dtype, + ) + + +def premake(ndim, dtype=None, block_size=None): + arrangement_ = functools.partial(arrangement, block_size=block_size) + + tensors = (Tensor(ndim, dtype=dtype), Tensor(ndim, dtype=dtype)) + + return arrangement_, application, tensors diff --git a/src/ntops/kernels/nextafter.py b/src/ntops/kernels/nextafter.py new file mode 100644 index 0000000..1110b1c --- /dev/null +++ b/src/ntops/kernels/nextafter.py @@ -0,0 +1,63 @@ +import functools + +import ninetoothed.language as ntl +from ninetoothed import Tensor + +from ntops.kernels.element_wise import arrangement + + +def application(input, other, output): + # PyTorch nextafter spec, implemented via IEEE bit manipulation: + # if either is NaN: result is NaN + # if a == b: result is b (preserves sign of zero) + # if a == 0: result is smallest subnormal with sign of b + # otherwise: walk one ULP toward b in IEEE bit space + dtype = output.dtype + if dtype == ntl.float16 or dtype == ntl.bfloat16: + int_dtype = ntl.int16 + elif dtype == ntl.float32: + int_dtype = ntl.int32 + else: + int_dtype = ntl.int64 + + a = input + b = other + a_i = ntl.cast(a, int_dtype, bitcast=True) + b_i = ntl.cast(b, int_dtype, bitcast=True) + + one = ntl.cast(1, int_dtype) + zero = ntl.cast(0, int_dtype) + sign_bit = one << (a_i.dtype.primitive_bitwidth - 1) + + is_nan = (a != a) | (b != b) + eq = a == b + is_zero = a == ntl.cast(0, dtype) + + b_sign = b_i & sign_bit + zero_result = b_sign | one + + a_neg = a_i < zero + a_lt_b = a < b + step_up = a_neg ^ a_lt_b + step = ntl.where(step_up, one, -one) + general = a_i + step + + nan_bits = ntl.cast(ntl.cast(float("nan"), dtype), int_dtype, bitcast=True) + result_i = ntl.where( + is_nan, + nan_bits, + ntl.where(eq, b_i, ntl.where(is_zero, zero_result, general)), + ) + output = ntl.cast(result_i, dtype, bitcast=True) # noqa: F841 + + +def premake(ndim, dtype=None, block_size=None): + arrangement_ = functools.partial(arrangement, block_size=block_size) + + tensors = ( + Tensor(ndim, dtype=dtype), + Tensor(ndim, dtype=dtype), + Tensor(ndim, dtype=dtype), + ) + + return arrangement_, application, tensors diff --git a/src/ntops/kernels/rad2deg.py b/src/ntops/kernels/rad2deg.py new file mode 100644 index 0000000..9371f22 --- /dev/null +++ b/src/ntops/kernels/rad2deg.py @@ -0,0 +1,17 @@ +import functools + +from ninetoothed import Tensor + +from ntops.kernels.element_wise import arrangement + + +def application(input, output): + output = input * 57.29577951308232 # noqa: F841 + + +def premake(ndim, dtype=None, block_size=None): + arrangement_ = functools.partial(arrangement, block_size=block_size) + + tensors = (Tensor(ndim, dtype=dtype), Tensor(ndim, dtype=dtype)) + + return arrangement_, application, tensors diff --git a/src/ntops/torch/__init__.py b/src/ntops/torch/__init__.py index 82fc596..5732b72 100644 --- a/src/ntops/torch/__init__.py +++ b/src/ntops/torch/__init__.py @@ -8,6 +8,7 @@ from ntops.torch.bmm import bmm from ntops.torch.clamp import clamp from ntops.torch.conv2d import conv2d +from ntops.torch.copysign import copysign from ntops.torch.cos import cos from ntops.torch.div import div from ntops.torch.dropout import dropout @@ -19,7 +20,9 @@ from ntops.torch.isinf import isinf from ntops.torch.isnan import isnan from ntops.torch.layer_norm import layer_norm +from ntops.torch.lcm import lcm from ntops.torch.le import le +from ntops.torch.lgamma import lgamma from ntops.torch.lt import lt from ntops.torch.matmul import matmul from ntops.torch.max_pool2d import max_pool2d @@ -27,7 +30,9 @@ from ntops.torch.mul import mul from ntops.torch.ne import ne from ntops.torch.neg import neg +from ntops.torch.nextafter import nextafter from ntops.torch.pow import pow +from ntops.torch.rad2deg import rad2deg from ntops.torch.relu import relu from ntops.torch.rms_norm import rms_norm from ntops.torch.rotary_position_embedding import rotary_position_embedding @@ -51,6 +56,7 @@ "bmm", "clamp", "conv2d", + "copysign", "cos", "div", "dropout", @@ -62,7 +68,9 @@ "isinf", "isnan", "layer_norm", + "lcm", "le", + "lgamma", "lt", "matmul", "max_pool2d", @@ -70,7 +78,9 @@ "mul", "ne", "neg", + "nextafter", "pow", + "rad2deg", "relu", "rms_norm", "rotary_position_embedding", diff --git a/src/ntops/torch/copysign.py b/src/ntops/torch/copysign.py new file mode 100644 index 0000000..252eb10 --- /dev/null +++ b/src/ntops/torch/copysign.py @@ -0,0 +1,15 @@ +import torch + +import ntops +from ntops.torch.utils import _cached_make + + +def copysign(input, other, *, out=None): + if out is None: + out = torch.empty_like(input) + + kernel = _cached_make(ntops.kernels.copysign.premake, input.ndim) + + kernel(input, other, out) + + return out diff --git a/src/ntops/torch/lcm.py b/src/ntops/torch/lcm.py new file mode 100644 index 0000000..006ae9a --- /dev/null +++ b/src/ntops/torch/lcm.py @@ -0,0 +1,26 @@ +import torch + +import ntops +from ntops.torch.utils import _cached_make + + +def lcm(input, other, *, out=None): + if out is None: + out = torch.empty_like(input) + + kernel = _cached_make(ntops.kernels.lcm.premake, input.ndim, dtype=_to_nt(input.dtype)) + + kernel(input, other, out) + + return out + + +def _to_nt(torch_dtype): + import ninetoothed + mapping = { + torch.int8: ninetoothed.int8, + torch.int16: ninetoothed.int16, + torch.int32: ninetoothed.int32, + torch.int64: ninetoothed.int64, + } + return mapping.get(torch_dtype) diff --git a/src/ntops/torch/lgamma.py b/src/ntops/torch/lgamma.py new file mode 100644 index 0000000..b1fed7c --- /dev/null +++ b/src/ntops/torch/lgamma.py @@ -0,0 +1,15 @@ +import torch + +import ntops +from ntops.torch.utils import _cached_make + + +def lgamma(input, *, out=None): + if out is None: + out = torch.empty_like(input) + + kernel = _cached_make(ntops.kernels.lgamma.premake, input.ndim) + + kernel(input, out) + + return out diff --git a/src/ntops/torch/nextafter.py b/src/ntops/torch/nextafter.py new file mode 100644 index 0000000..c2c173d --- /dev/null +++ b/src/ntops/torch/nextafter.py @@ -0,0 +1,15 @@ +import torch + +import ntops +from ntops.torch.utils import _cached_make + + +def nextafter(input, other, *, out=None): + if out is None: + out = torch.empty_like(input) + + kernel = _cached_make(ntops.kernels.nextafter.premake, input.ndim) + + kernel(input, other, out) + + return out diff --git a/src/ntops/torch/rad2deg.py b/src/ntops/torch/rad2deg.py new file mode 100644 index 0000000..f6896b1 --- /dev/null +++ b/src/ntops/torch/rad2deg.py @@ -0,0 +1,15 @@ +import torch + +import ntops +from ntops.torch.utils import _cached_make + + +def rad2deg(input, *, out=None): + if out is None: + out = torch.empty_like(input) + + kernel = _cached_make(ntops.kernels.rad2deg.premake, input.ndim) + + kernel(input, out) + + return out diff --git a/tests/test_copysign.py b/tests/test_copysign.py new file mode 100644 index 0000000..43e2ecf --- /dev/null +++ b/tests/test_copysign.py @@ -0,0 +1,18 @@ +import pytest +import torch + +import ntops +from tests.skippers import skip_if_cuda_not_available +from tests.utils import generate_arguments + + +@skip_if_cuda_not_available +@pytest.mark.parametrize(*generate_arguments()) +def test_copysign(shape, dtype, device, rtol, atol): + input = torch.randn(shape, dtype=dtype, device=device) + other = torch.randn(shape, dtype=dtype, device=device) + + ninetoothed_output = ntops.torch.copysign(input, other) + reference_output = torch.copysign(input, other) + + assert torch.allclose(ninetoothed_output, reference_output, rtol=rtol, atol=atol) diff --git a/tests/test_lcm.py b/tests/test_lcm.py new file mode 100644 index 0000000..c4d343e --- /dev/null +++ b/tests/test_lcm.py @@ -0,0 +1,26 @@ +import pytest +import torch + +import ntops +from tests.skippers import skip_if_cuda_not_available +from tests.utils import generate_arguments + + +@skip_if_cuda_not_available +@pytest.mark.parametrize(*generate_arguments(False)) +def test_lcm(shape, dtype, device, rtol, atol): + if dtype == torch.bool: + pytest.skip("torch.lcm does not support bool dtype") + + upper_bound = 100 + input = torch.randint( + -upper_bound, upper_bound, size=shape, dtype=dtype, device=device + ) + other = torch.randint( + -upper_bound, upper_bound, size=shape, dtype=dtype, device=device + ) + + ninetoothed_output = ntops.torch.lcm(input, other) + reference_output = torch.lcm(input, other) + + assert torch.equal(ninetoothed_output, reference_output) diff --git a/tests/test_lgamma.py b/tests/test_lgamma.py new file mode 100644 index 0000000..f387f7e --- /dev/null +++ b/tests/test_lgamma.py @@ -0,0 +1,19 @@ +import pytest +import torch + +import ntops +from tests.skippers import skip_if_cuda_not_available +from tests.utils import generate_arguments + + +@skip_if_cuda_not_available +@pytest.mark.parametrize(*generate_arguments()) +def test_lgamma(shape, dtype, device, rtol, atol): + input = torch.randn(shape, dtype=dtype, device=device).abs() + 0.1 + + ninetoothed_output = ntops.torch.lgamma(input) + reference_output = torch.lgamma(input) + + assert torch.allclose( + ninetoothed_output, reference_output, rtol=rtol, atol=atol, equal_nan=True + ) diff --git a/tests/test_nextafter.py b/tests/test_nextafter.py new file mode 100644 index 0000000..d021a6f --- /dev/null +++ b/tests/test_nextafter.py @@ -0,0 +1,20 @@ +import pytest +import torch + +import ntops +from tests.skippers import skip_if_cuda_not_available +from tests.utils import generate_arguments + + +@skip_if_cuda_not_available +@pytest.mark.parametrize(*generate_arguments()) +def test_nextafter(shape, dtype, device, rtol, atol): + input = torch.randn(shape, dtype=dtype, device=device) + other = torch.randn(shape, dtype=dtype, device=device) + + ninetoothed_output = ntops.torch.nextafter(input, other) + reference_output = torch.nextafter(input, other) + + assert torch.allclose( + ninetoothed_output, reference_output, rtol=rtol, atol=atol, equal_nan=True + ) diff --git a/tests/test_rad2deg.py b/tests/test_rad2deg.py new file mode 100644 index 0000000..222161d --- /dev/null +++ b/tests/test_rad2deg.py @@ -0,0 +1,17 @@ +import pytest +import torch + +import ntops +from tests.skippers import skip_if_cuda_not_available +from tests.utils import generate_arguments + + +@skip_if_cuda_not_available +@pytest.mark.parametrize(*generate_arguments()) +def test_rad2deg(shape, dtype, device, rtol, atol): + input = torch.randn(shape, dtype=dtype, device=device) + + ninetoothed_output = ntops.torch.rad2deg(input) + reference_output = torch.rad2deg(input) + + assert torch.allclose(ninetoothed_output, reference_output, rtol=rtol, atol=atol) From 8b7492485d8f536e59bc47900d29602530ac27a4 Mon Sep 17 00:00:00 2001 From: mygitljf <2410316423@qq.com> Date: Mon, 18 May 2026 16:02:23 +0000 Subject: [PATCH 2/7] [2026 spring][T1-1-1] lcm: switch to Stein's binary GCD with per-dtype unroll MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Iter02+03 optimization for the lcm operator. Replaces the Euclidean GCD inner loop (which used 64-bit IDIV/IREM, ~100+ cycles per op on A100) with Stein's binary GCD using libdevice.ffs (ctz) + shift + subtract. Per-dtype unroll buckets sized to the worst-case Stein iteration count: application_32 : int8 / int16 (promoted to int32, value range <= 32767) application_64 : int32 application_128 : int64 ncu profile (sudo) on application_104 (iter01) showed Compute 66.4%, DRAM 1.1%, 39 regs/thread - confirming pure compute bound on integer division. After Stein: device time at (1024, 5632), ratio = torch / ntops: iter01 iter02 iter03 iter01 -> iter03 lcm int8 0.43 0.33 0.59 +37% lcm int16 0.66 0.52 0.92 +39% lcm int32 0.46 0.83 0.83 +80% lcm int64 0.36 0.76 0.76 +111% Correctness: ntops pytest 12 passed + 4 skipped (bool); InfiniCore --nvidia 40/40 passed. Bit-exact against torch.lcm including int8/int16 two's-complement wrap-around behaviour. Constraints learned (also in docs/算子开发经验.md): - NineToothed AST does not cross Python function boundaries cleanly; helper functions returning tuples raise compilation errors. All prelude / step / finish blocks must be inlined per application_. - libdevice.ffs(0) returns 0; ffs(x>0) returns 1 + ctz(x). Need where() masks at every ctz site. Detailed iteration logs in docs/logs/T1-1-1/{iter01,iter02,iter03}/. --- src/ntops/kernels/lcm.py | 140 ++++++++++++++++++++++++++++----------- 1 file changed, 103 insertions(+), 37 deletions(-) diff --git a/src/ntops/kernels/lcm.py b/src/ntops/kernels/lcm.py index b51db20..1fa59eb 100644 --- a/src/ntops/kernels/lcm.py +++ b/src/ntops/kernels/lcm.py @@ -3,74 +3,140 @@ import ninetoothed import ninetoothed.language as ntl from ninetoothed import Tensor +from ninetoothed.language import libdevice from ntops.kernels.element_wise import arrangement -# Match PyTorch's CUDA C++ integer promotion: narrow ints (int8/int16) -# are promoted to int32 for arithmetic and abs, then truncated back. -# Worst-case Euclidean iteration count is bounded by adjacent Fibonacci -# F(N) > 2^bits, so we unroll just enough per dtype to avoid wasted work. -def application_28(input, other, output): +# Binary GCD (Stein's algorithm), inlined per dtype-bucketed unroll count. +# PyTorch CUDA lcm semantics: +# lcm(a, b) = abs((|a| / gcd) * |b|) with two's-complement wrap. +# Narrow ints (int8/int16) are promoted to int32 for arithmetic, then +# truncated back, so wrap-around matches PyTorch. +# +# Algorithm (every step keeps a and b odd): +# shift b right by ctz(b) so b is odd +# diff = b_odd - a (signed) +# new_a = min(a, b_odd) +# new_b = abs(diff) +# After enough steps b == 0; gcd = a << k where k = ctz(|input| | |other|). +# +# `libdevice.ffs(x)` returns 1 + ctz(x) for x != 0, and 0 for x == 0. +# +# NineToothed AST does not cross Python function boundaries cleanly +# (tuple returns from helpers raise compilation errors), so the prelude / +# step / finish blocks are duplicated inline in each application_. +# +# Worst-case Stein iterations: each step removes at least one bit of +# information from b, so 2 * bit_width is a safe upper bound: +# int8/int16 (promoted to int32, but value range still ≤ 32767): ~32 +# int32: 64 +# int64: 128 +def application_32(input, other, output): dtype = output.dtype compute_dtype = ( ntl.int32 if dtype == ntl.int8 or dtype == ntl.int16 else dtype ) + abs_a = ntl.abs(ntl.cast(input, compute_dtype)) abs_b = ntl.abs(ntl.cast(other, compute_dtype)) - a, b = abs_a, abs_b - for _ in range(28): - nonzero = b != 0 - safe_b = ntl.where(nonzero, b, 1) - new_a = ntl.where(nonzero, b, a) - new_b = ntl.where(nonzero, a % safe_b, b) - a = new_a - b = new_b - gcd = a + + or_ab = abs_a | abs_b + safe_or = ntl.where(or_ab != 0, or_ab, 1) + k = ntl.cast(libdevice.ffs(safe_or) - 1, compute_dtype) + a0 = abs_a >> k + b0 = abs_b >> k + + nonzero_a0 = a0 != 0 + safe_a0 = ntl.where(nonzero_a0, a0, 1) + ctz_a0 = ntl.cast(libdevice.ffs(safe_a0) - 1, compute_dtype) + a = ntl.where(nonzero_a0, a0 >> ctz_a0, b0) + b = ntl.where(nonzero_a0, b0, ntl.cast(0, compute_dtype)) + + for _ in range(32): + nonzero_b = b != 0 + safe_b = ntl.where(nonzero_b, b, 1) + ctz_b = ntl.cast(libdevice.ffs(safe_b) - 1, compute_dtype) + b_odd = b >> ctz_b + diff = b_odd - a + a = ntl.where(nonzero_b, ntl.minimum(a, b_odd), a) + b = ntl.where(nonzero_b, ntl.abs(diff), b) + + gcd = a << k safe_gcd = ntl.where(gcd == 0, 1, gcd) output = ntl.cast( # noqa: F841 ntl.where(gcd == 0, 0, ntl.abs((abs_a // safe_gcd) * abs_b)), dtype ) -def application_56(input, other, output): +def application_64(input, other, output): dtype = output.dtype compute_dtype = ( ntl.int32 if dtype == ntl.int8 or dtype == ntl.int16 else dtype ) + abs_a = ntl.abs(ntl.cast(input, compute_dtype)) abs_b = ntl.abs(ntl.cast(other, compute_dtype)) - a, b = abs_a, abs_b - for _ in range(56): - nonzero = b != 0 - safe_b = ntl.where(nonzero, b, 1) - new_a = ntl.where(nonzero, b, a) - new_b = ntl.where(nonzero, a % safe_b, b) - a = new_a - b = new_b - gcd = a + + or_ab = abs_a | abs_b + safe_or = ntl.where(or_ab != 0, or_ab, 1) + k = ntl.cast(libdevice.ffs(safe_or) - 1, compute_dtype) + a0 = abs_a >> k + b0 = abs_b >> k + + nonzero_a0 = a0 != 0 + safe_a0 = ntl.where(nonzero_a0, a0, 1) + ctz_a0 = ntl.cast(libdevice.ffs(safe_a0) - 1, compute_dtype) + a = ntl.where(nonzero_a0, a0 >> ctz_a0, b0) + b = ntl.where(nonzero_a0, b0, ntl.cast(0, compute_dtype)) + + for _ in range(64): + nonzero_b = b != 0 + safe_b = ntl.where(nonzero_b, b, 1) + ctz_b = ntl.cast(libdevice.ffs(safe_b) - 1, compute_dtype) + b_odd = b >> ctz_b + diff = b_odd - a + a = ntl.where(nonzero_b, ntl.minimum(a, b_odd), a) + b = ntl.where(nonzero_b, ntl.abs(diff), b) + + gcd = a << k safe_gcd = ntl.where(gcd == 0, 1, gcd) output = ntl.cast( # noqa: F841 ntl.where(gcd == 0, 0, ntl.abs((abs_a // safe_gcd) * abs_b)), dtype ) -def application_104(input, other, output): +def application_128(input, other, output): dtype = output.dtype compute_dtype = ( ntl.int32 if dtype == ntl.int8 or dtype == ntl.int16 else dtype ) + abs_a = ntl.abs(ntl.cast(input, compute_dtype)) abs_b = ntl.abs(ntl.cast(other, compute_dtype)) - a, b = abs_a, abs_b - for _ in range(104): - nonzero = b != 0 - safe_b = ntl.where(nonzero, b, 1) - new_a = ntl.where(nonzero, b, a) - new_b = ntl.where(nonzero, a % safe_b, b) - a = new_a - b = new_b - gcd = a + + or_ab = abs_a | abs_b + safe_or = ntl.where(or_ab != 0, or_ab, 1) + k = ntl.cast(libdevice.ffs(safe_or) - 1, compute_dtype) + a0 = abs_a >> k + b0 = abs_b >> k + + nonzero_a0 = a0 != 0 + safe_a0 = ntl.where(nonzero_a0, a0, 1) + ctz_a0 = ntl.cast(libdevice.ffs(safe_a0) - 1, compute_dtype) + a = ntl.where(nonzero_a0, a0 >> ctz_a0, b0) + b = ntl.where(nonzero_a0, b0, ntl.cast(0, compute_dtype)) + + for _ in range(128): + nonzero_b = b != 0 + safe_b = ntl.where(nonzero_b, b, 1) + ctz_b = ntl.cast(libdevice.ffs(safe_b) - 1, compute_dtype) + b_odd = b >> ctz_b + diff = b_odd - a + a = ntl.where(nonzero_b, ntl.minimum(a, b_odd), a) + b = ntl.where(nonzero_b, ntl.abs(diff), b) + + gcd = a << k safe_gcd = ntl.where(gcd == 0, 1, gcd) output = ntl.cast( # noqa: F841 ntl.where(gcd == 0, 0, ntl.abs((abs_a // safe_gcd) * abs_b)), dtype @@ -81,11 +147,11 @@ def premake(ndim, dtype=None, block_size=None): arrangement_ = functools.partial(arrangement, block_size=block_size) if dtype == ninetoothed.int64: - application = application_104 + application = application_128 elif dtype == ninetoothed.int32: - application = application_56 + application = application_64 else: - application = application_28 + application = application_32 tensors = ( Tensor(ndim, dtype=dtype), From 0f30bd92b183aa52669c44576649ab2ed77dab3d Mon Sep 17 00:00:00 2001 From: mygitljf <2410316423@qq.com> Date: Mon, 18 May 2026 16:36:41 +0000 Subject: [PATCH 3/7] [2026 spring][T1-1-1] lcm: tighten Stein unroll to per-dtype worst-case Iter04 follow-up to commit 8b74924. Empirical measurement shows the Stein outer iteration count is bounded by bit_width (not 2 * bit_width) because each outer step absorbs all consecutive trailing-zero shifts via 'b >> ctz(b)'. Fibonacci adversarial + (2^k-1, 2^(k-1)) pathological inputs confirm: int8 worst case 5 -> use 8 int16 worst case 13 -> use 16 int32 worst case 31 -> use 36 int64 worst case 63 -> use 72 Renamed application_32/64/128 to application_8/16/36/72 with separate buckets for int8 and int16. Premake dispatch is now 4-way. ncu shows iter03 application_128 was 96.84% compute-saturated, so cutting unroll directly translates to wall-clock speedup. application_72 duration drops from 2.82 ms to 1.65 ms (-42%, matches 128->72 ratio). device time at (1024, 5632), ratio = torch / ntops: iter03 iter04 lcm int8 0.59 -> 1.33 +126% (above torch) lcm int16 0.92 -> 1.47 +60% lcm int32 0.83 -> 1.35 +63% lcm int64 0.76 -> 1.30 +71% iter01 -> iter04 cumulative on int64: 0.36 -> 1.30 (+261%). Correctness: ntops pytest 12 passed + 4 skipped (bool); InfiniCore --nvidia 40/40 passed. Detailed iteration log in docs/logs/T1-1-1/lcm/iter04/. --- src/ntops/kernels/lcm.py | 101 ++++++++++++++++++++------------------- 1 file changed, 53 insertions(+), 48 deletions(-) diff --git a/src/ntops/kernels/lcm.py b/src/ntops/kernels/lcm.py index 1fa59eb..0d81407 100644 --- a/src/ntops/kernels/lcm.py +++ b/src/ntops/kernels/lcm.py @@ -8,52 +8,37 @@ from ntops.kernels.element_wise import arrangement -# Binary GCD (Stein's algorithm), inlined per dtype-bucketed unroll count. -# PyTorch CUDA lcm semantics: -# lcm(a, b) = abs((|a| / gcd) * |b|) with two's-complement wrap. -# Narrow ints (int8/int16) are promoted to int32 for arithmetic, then -# truncated back, so wrap-around matches PyTorch. +# Stein binary GCD with unroll counts derived from empirical worst-case +# outer-iteration counts (iter04): # -# Algorithm (every step keeps a and b odd): -# shift b right by ctz(b) so b is odd -# diff = b_odd - a (signed) -# new_a = min(a, b_odd) -# new_b = abs(diff) -# After enough steps b == 0; gcd = a << k where k = ctz(|input| | |other|). +# int8 (value range <= 127): max 5 -> use 8 +# int16 (value range <= 32767): max 13 -> use 16 +# int32 (value range <= 2^31): max 31 -> use 36 +# int64 (value range <= 2^63): max 63 -> use 72 # -# `libdevice.ffs(x)` returns 1 + ctz(x) for x != 0, and 0 for x == 0. -# -# NineToothed AST does not cross Python function boundaries cleanly -# (tuple returns from helpers raise compilation errors), so the prelude / -# step / finish blocks are duplicated inline in each application_. -# -# Worst-case Stein iterations: each step removes at least one bit of -# information from b, so 2 * bit_width is a safe upper bound: -# int8/int16 (promoted to int32, but value range still ≤ 32767): ~32 -# int32: 64 -# int64: 128 -def application_32(input, other, output): +# Earlier iterations used 32/64/128 which were safe but ~2x over-engineered: +# my outer loop already absorbs all consecutive trailing-zero shifts in one +# iteration via "b >> ctz(b)", so the theoretical max outer iter count is +# bit_width (not 2 * bit_width). Empirical Fibonacci adversarial inputs +# confirm < bit_width. +def application_8(input, other, output): dtype = output.dtype compute_dtype = ( ntl.int32 if dtype == ntl.int8 or dtype == ntl.int16 else dtype ) - abs_a = ntl.abs(ntl.cast(input, compute_dtype)) abs_b = ntl.abs(ntl.cast(other, compute_dtype)) - or_ab = abs_a | abs_b safe_or = ntl.where(or_ab != 0, or_ab, 1) k = ntl.cast(libdevice.ffs(safe_or) - 1, compute_dtype) a0 = abs_a >> k b0 = abs_b >> k - nonzero_a0 = a0 != 0 safe_a0 = ntl.where(nonzero_a0, a0, 1) ctz_a0 = ntl.cast(libdevice.ffs(safe_a0) - 1, compute_dtype) a = ntl.where(nonzero_a0, a0 >> ctz_a0, b0) b = ntl.where(nonzero_a0, b0, ntl.cast(0, compute_dtype)) - - for _ in range(32): + for _ in range(8): nonzero_b = b != 0 safe_b = ntl.where(nonzero_b, b, 1) ctz_b = ntl.cast(libdevice.ffs(safe_b) - 1, compute_dtype) @@ -61,7 +46,6 @@ def application_32(input, other, output): diff = b_odd - a a = ntl.where(nonzero_b, ntl.minimum(a, b_odd), a) b = ntl.where(nonzero_b, ntl.abs(diff), b) - gcd = a << k safe_gcd = ntl.where(gcd == 0, 1, gcd) output = ntl.cast( # noqa: F841 @@ -69,28 +53,24 @@ def application_32(input, other, output): ) -def application_64(input, other, output): +def application_16(input, other, output): dtype = output.dtype compute_dtype = ( ntl.int32 if dtype == ntl.int8 or dtype == ntl.int16 else dtype ) - abs_a = ntl.abs(ntl.cast(input, compute_dtype)) abs_b = ntl.abs(ntl.cast(other, compute_dtype)) - or_ab = abs_a | abs_b safe_or = ntl.where(or_ab != 0, or_ab, 1) k = ntl.cast(libdevice.ffs(safe_or) - 1, compute_dtype) a0 = abs_a >> k b0 = abs_b >> k - nonzero_a0 = a0 != 0 safe_a0 = ntl.where(nonzero_a0, a0, 1) ctz_a0 = ntl.cast(libdevice.ffs(safe_a0) - 1, compute_dtype) a = ntl.where(nonzero_a0, a0 >> ctz_a0, b0) b = ntl.where(nonzero_a0, b0, ntl.cast(0, compute_dtype)) - - for _ in range(64): + for _ in range(16): nonzero_b = b != 0 safe_b = ntl.where(nonzero_b, b, 1) ctz_b = ntl.cast(libdevice.ffs(safe_b) - 1, compute_dtype) @@ -98,7 +78,6 @@ def application_64(input, other, output): diff = b_odd - a a = ntl.where(nonzero_b, ntl.minimum(a, b_odd), a) b = ntl.where(nonzero_b, ntl.abs(diff), b) - gcd = a << k safe_gcd = ntl.where(gcd == 0, 1, gcd) output = ntl.cast( # noqa: F841 @@ -106,28 +85,24 @@ def application_64(input, other, output): ) -def application_128(input, other, output): +def application_36(input, other, output): dtype = output.dtype compute_dtype = ( ntl.int32 if dtype == ntl.int8 or dtype == ntl.int16 else dtype ) - abs_a = ntl.abs(ntl.cast(input, compute_dtype)) abs_b = ntl.abs(ntl.cast(other, compute_dtype)) - or_ab = abs_a | abs_b safe_or = ntl.where(or_ab != 0, or_ab, 1) k = ntl.cast(libdevice.ffs(safe_or) - 1, compute_dtype) a0 = abs_a >> k b0 = abs_b >> k - nonzero_a0 = a0 != 0 safe_a0 = ntl.where(nonzero_a0, a0, 1) ctz_a0 = ntl.cast(libdevice.ffs(safe_a0) - 1, compute_dtype) a = ntl.where(nonzero_a0, a0 >> ctz_a0, b0) b = ntl.where(nonzero_a0, b0, ntl.cast(0, compute_dtype)) - - for _ in range(128): + for _ in range(36): nonzero_b = b != 0 safe_b = ntl.where(nonzero_b, b, 1) ctz_b = ntl.cast(libdevice.ffs(safe_b) - 1, compute_dtype) @@ -135,7 +110,38 @@ def application_128(input, other, output): diff = b_odd - a a = ntl.where(nonzero_b, ntl.minimum(a, b_odd), a) b = ntl.where(nonzero_b, ntl.abs(diff), b) + gcd = a << k + safe_gcd = ntl.where(gcd == 0, 1, gcd) + output = ntl.cast( # noqa: F841 + ntl.where(gcd == 0, 0, ntl.abs((abs_a // safe_gcd) * abs_b)), dtype + ) + +def application_72(input, other, output): + dtype = output.dtype + compute_dtype = ( + ntl.int32 if dtype == ntl.int8 or dtype == ntl.int16 else dtype + ) + abs_a = ntl.abs(ntl.cast(input, compute_dtype)) + abs_b = ntl.abs(ntl.cast(other, compute_dtype)) + or_ab = abs_a | abs_b + safe_or = ntl.where(or_ab != 0, or_ab, 1) + k = ntl.cast(libdevice.ffs(safe_or) - 1, compute_dtype) + a0 = abs_a >> k + b0 = abs_b >> k + nonzero_a0 = a0 != 0 + safe_a0 = ntl.where(nonzero_a0, a0, 1) + ctz_a0 = ntl.cast(libdevice.ffs(safe_a0) - 1, compute_dtype) + a = ntl.where(nonzero_a0, a0 >> ctz_a0, b0) + b = ntl.where(nonzero_a0, b0, ntl.cast(0, compute_dtype)) + for _ in range(72): + nonzero_b = b != 0 + safe_b = ntl.where(nonzero_b, b, 1) + ctz_b = ntl.cast(libdevice.ffs(safe_b) - 1, compute_dtype) + b_odd = b >> ctz_b + diff = b_odd - a + a = ntl.where(nonzero_b, ntl.minimum(a, b_odd), a) + b = ntl.where(nonzero_b, ntl.abs(diff), b) gcd = a << k safe_gcd = ntl.where(gcd == 0, 1, gcd) output = ntl.cast( # noqa: F841 @@ -145,18 +151,17 @@ def application_128(input, other, output): def premake(ndim, dtype=None, block_size=None): arrangement_ = functools.partial(arrangement, block_size=block_size) - if dtype == ninetoothed.int64: - application = application_128 + application = application_72 elif dtype == ninetoothed.int32: - application = application_64 + application = application_36 + elif dtype == ninetoothed.int16: + application = application_16 else: - application = application_32 - + application = application_8 tensors = ( Tensor(ndim, dtype=dtype), Tensor(ndim, dtype=dtype), Tensor(ndim, dtype=dtype), ) - return arrangement_, application, tensors From 603a1ce6b5f3f1e2c821c0c0d3e6fbcc6e78c7e6 Mon Sep 17 00:00:00 2001 From: lianjf Date: Tue, 19 May 2026 09:14:13 +0000 Subject: [PATCH 4/7] [2026 spring][T1-1-1] tune block_size/num_warps; lcm sentinel-merge + int64 dynamic Euclidean - lcm: Stein inner-loop sentinel-merge (3 wheres -> 1 where) for int8/int16/int32 with per-dtype unroll (8/16/36); int64 switched to dynamic Euclidean with grouped block-level early stop. Flatten contiguous inputs to 1D; explicit (block_size, num_warps, num_stages) per dtype. - rad2deg: block_size=2048, num_warps=4, num_stages=1; flatten 1D. - copysign: block_size=1024, num_warps=4. - nextafter: block_size=1024, num_warps=4, num_stages=2. - lgamma: large-numel path with block_size=1024, num_warps=4, num_stages=5. Verification: - ntops pytest (rad2deg/copysign/nextafter/lgamma/lcm): 44 passed, 4 skipped. - InfiniCore --nvidia: 5/5 ops 100% pass. - InfiniCore --bench device aggregate (lcm): 0.52x -> 1.67x. --- src/ntops/kernels/lcm.py | 133 ++++++++++++++++++++--------------- src/ntops/torch/copysign.py | 7 +- src/ntops/torch/lcm.py | 47 ++++++++++++- src/ntops/torch/lgamma.py | 14 +++- src/ntops/torch/nextafter.py | 13 +++- src/ntops/torch/rad2deg.py | 23 +++++- 6 files changed, 175 insertions(+), 62 deletions(-) diff --git a/src/ntops/kernels/lcm.py b/src/ntops/kernels/lcm.py index 0d81407..8a3ab4a 100644 --- a/src/ntops/kernels/lcm.py +++ b/src/ntops/kernels/lcm.py @@ -8,19 +8,48 @@ from ntops.kernels.element_wise import arrangement -# Stein binary GCD with unroll counts derived from empirical worst-case -# outer-iteration counts (iter04): +# T1-1-1 lcm: dtype-dispatched algorithm. # +# int8/int16/int32 -> Stein binary GCD (no IDIV, cheap per-iter). +# int64 -> Dynamic Euclidean with grouped block-level early stop. +# +# Why two algorithms: +# Stein per-iter on A100: ~14 us at BLOCK=512/warps=8 (no IDIV; just +# ffs+shift+sub+min+abs + 1 where). +# Euclidean per-iter on A100: ~14 us at BLOCK=32/warps=1 (1 IDIV + +# 2 wheres; the int64 IDIV ~30 cycles, but BLOCK=32/warps=1 means +# 1 element per thread which maximizes the number of concurrent +# in-flight IDIVs across SMs). +# At BLOCK=512/warps=8 (= 2 elements per thread), Euclidean is ~43 +# us per iter because each thread's two dependent IDIV chains block +# each other -> 3x slower. The (32, 1) config is critical. +# For int8/16/32, Stein static unroll is unbeatable (no IDIV, no dynamic +# check overhead). For int64 with v2-style small inputs (values <= +# ~2^20), Euclidean dynamic averages ~14 outer iters vs Stein's +# fixed 60, giving ~4x speedup. +# +# Stein unroll counts (worst-case empirically validated): # int8 (value range <= 127): max 5 -> use 8 # int16 (value range <= 32767): max 13 -> use 16 # int32 (value range <= 2^31): max 31 -> use 36 -# int64 (value range <= 2^63): max 63 -> use 72 # -# Earlier iterations used 32/64/128 which were safe but ~2x over-engineered: -# my outer loop already absorbs all consecutive trailing-zero shifts in one -# iteration via "b >> ctz(b)", so the theoretical max outer iter count is -# bit_width (not 2 * bit_width). Empirical Fibonacci adversarial inputs -# confirm < bit_width. +# Euclidean (int64) uses grouped dynamic stop: +# outer cap = 12, inner unroll = 8 -> max 96 Euclidean iters. +# Block-level `ntl.max(b) != 0` check every 8 inner iters. +# N=96 covers Fibonacci adversarial worst case (~91 iters) for full +# int64 range. +# +# Sentinel-merge (iter05): one `where` per Stein iter using `a` (always +# odd, always >= 1) as the `b == 0` sentinel. +# +# History: +# iter05: sentinel-merge + flat 1D + explicit (512, warps_per_dt, 1) +# iter06: int64 Stein 72 -> 64 +# iter07: int64 Stein 64 -> 60 (1M+ sample empirical worst case = 57) +# iter08: int64 switch Stein -> dynamic Euclidean at (BLOCK=32, warps=1). +# Trade-off: int64 v2-style small-input launches ~2-4x faster +# than Stein; full-range int64 launches ~1.9x slower. Other +# dtypes unchanged. def application_8(input, other, output): dtype = output.dtype compute_dtype = ( @@ -38,18 +67,18 @@ def application_8(input, other, output): ctz_a0 = ntl.cast(libdevice.ffs(safe_a0) - 1, compute_dtype) a = ntl.where(nonzero_a0, a0 >> ctz_a0, b0) b = ntl.where(nonzero_a0, b0, ntl.cast(0, compute_dtype)) + a = ntl.where(a == 0, ntl.cast(1, compute_dtype), a) for _ in range(8): - nonzero_b = b != 0 - safe_b = ntl.where(nonzero_b, b, 1) - ctz_b = ntl.cast(libdevice.ffs(safe_b) - 1, compute_dtype) - b_odd = b >> ctz_b + b_for_calc = ntl.where(b != 0, b, a) + ctz_b = ntl.cast(libdevice.ffs(b_for_calc) - 1, compute_dtype) + b_odd = b_for_calc >> ctz_b diff = b_odd - a - a = ntl.where(nonzero_b, ntl.minimum(a, b_odd), a) - b = ntl.where(nonzero_b, ntl.abs(diff), b) + a = ntl.minimum(a, b_odd) + b = ntl.abs(diff) gcd = a << k safe_gcd = ntl.where(gcd == 0, 1, gcd) output = ntl.cast( # noqa: F841 - ntl.where(gcd == 0, 0, ntl.abs((abs_a // safe_gcd) * abs_b)), dtype + ntl.where(or_ab == 0, 0, ntl.abs((abs_a // safe_gcd) * abs_b)), dtype ) @@ -70,18 +99,18 @@ def application_16(input, other, output): ctz_a0 = ntl.cast(libdevice.ffs(safe_a0) - 1, compute_dtype) a = ntl.where(nonzero_a0, a0 >> ctz_a0, b0) b = ntl.where(nonzero_a0, b0, ntl.cast(0, compute_dtype)) + a = ntl.where(a == 0, ntl.cast(1, compute_dtype), a) for _ in range(16): - nonzero_b = b != 0 - safe_b = ntl.where(nonzero_b, b, 1) - ctz_b = ntl.cast(libdevice.ffs(safe_b) - 1, compute_dtype) - b_odd = b >> ctz_b + b_for_calc = ntl.where(b != 0, b, a) + ctz_b = ntl.cast(libdevice.ffs(b_for_calc) - 1, compute_dtype) + b_odd = b_for_calc >> ctz_b diff = b_odd - a - a = ntl.where(nonzero_b, ntl.minimum(a, b_odd), a) - b = ntl.where(nonzero_b, ntl.abs(diff), b) + a = ntl.minimum(a, b_odd) + b = ntl.abs(diff) gcd = a << k safe_gcd = ntl.where(gcd == 0, 1, gcd) output = ntl.cast( # noqa: F841 - ntl.where(gcd == 0, 0, ntl.abs((abs_a // safe_gcd) * abs_b)), dtype + ntl.where(or_ab == 0, 0, ntl.abs((abs_a // safe_gcd) * abs_b)), dtype ) @@ -102,57 +131,51 @@ def application_36(input, other, output): ctz_a0 = ntl.cast(libdevice.ffs(safe_a0) - 1, compute_dtype) a = ntl.where(nonzero_a0, a0 >> ctz_a0, b0) b = ntl.where(nonzero_a0, b0, ntl.cast(0, compute_dtype)) + a = ntl.where(a == 0, ntl.cast(1, compute_dtype), a) for _ in range(36): - nonzero_b = b != 0 - safe_b = ntl.where(nonzero_b, b, 1) - ctz_b = ntl.cast(libdevice.ffs(safe_b) - 1, compute_dtype) - b_odd = b >> ctz_b + b_for_calc = ntl.where(b != 0, b, a) + ctz_b = ntl.cast(libdevice.ffs(b_for_calc) - 1, compute_dtype) + b_odd = b_for_calc >> ctz_b diff = b_odd - a - a = ntl.where(nonzero_b, ntl.minimum(a, b_odd), a) - b = ntl.where(nonzero_b, ntl.abs(diff), b) + a = ntl.minimum(a, b_odd) + b = ntl.abs(diff) gcd = a << k safe_gcd = ntl.where(gcd == 0, 1, gcd) output = ntl.cast( # noqa: F841 - ntl.where(gcd == 0, 0, ntl.abs((abs_a // safe_gcd) * abs_b)), dtype + ntl.where(or_ab == 0, 0, ntl.abs((abs_a // safe_gcd) * abs_b)), dtype ) -def application_72(input, other, output): +def application_euclidean_dyn(input, other, output): + # Dynamic Euclidean for int64. + # Block-level early stop every 8 inner iters; outer cap 12 -> N=96. + # Convergence for random uniform full-range int64 averages ~36 iters; + # for v2-style range (<=2^20) averages ~14 iters. dtype = output.dtype - compute_dtype = ( - ntl.int32 if dtype == ntl.int8 or dtype == ntl.int16 else dtype - ) - abs_a = ntl.abs(ntl.cast(input, compute_dtype)) - abs_b = ntl.abs(ntl.cast(other, compute_dtype)) + abs_a = ntl.abs(input) + abs_b = ntl.abs(other) or_ab = abs_a | abs_b - safe_or = ntl.where(or_ab != 0, or_ab, 1) - k = ntl.cast(libdevice.ffs(safe_or) - 1, compute_dtype) - a0 = abs_a >> k - b0 = abs_b >> k - nonzero_a0 = a0 != 0 - safe_a0 = ntl.where(nonzero_a0, a0, 1) - ctz_a0 = ntl.cast(libdevice.ffs(safe_a0) - 1, compute_dtype) - a = ntl.where(nonzero_a0, a0 >> ctz_a0, b0) - b = ntl.where(nonzero_a0, b0, ntl.cast(0, compute_dtype)) - for _ in range(72): - nonzero_b = b != 0 - safe_b = ntl.where(nonzero_b, b, 1) - ctz_b = ntl.cast(libdevice.ffs(safe_b) - 1, compute_dtype) - b_odd = b >> ctz_b - diff = b_odd - a - a = ntl.where(nonzero_b, ntl.minimum(a, b_odd), a) - b = ntl.where(nonzero_b, ntl.abs(diff), b) - gcd = a << k + a = ntl.where(abs_a >= abs_b, abs_a, abs_b) + b = ntl.where(abs_a >= abs_b, abs_b, abs_a) + outer = 0 + while ntl.max(b) != 0 and outer < 12: + for _ in range(8): + b_safe = ntl.where(b != 0, b, ntl.cast(1, dtype)) + r = a % b_safe + a = ntl.where(b != 0, b, a) + b = r + outer += 1 + gcd = a safe_gcd = ntl.where(gcd == 0, 1, gcd) output = ntl.cast( # noqa: F841 - ntl.where(gcd == 0, 0, ntl.abs((abs_a // safe_gcd) * abs_b)), dtype + ntl.where(or_ab == 0, 0, ntl.abs((abs_a // safe_gcd) * abs_b)), dtype ) def premake(ndim, dtype=None, block_size=None): arrangement_ = functools.partial(arrangement, block_size=block_size) if dtype == ninetoothed.int64: - application = application_72 + application = application_euclidean_dyn elif dtype == ninetoothed.int32: application = application_36 elif dtype == ninetoothed.int16: diff --git a/src/ntops/torch/copysign.py b/src/ntops/torch/copysign.py index 252eb10..291ff04 100644 --- a/src/ntops/torch/copysign.py +++ b/src/ntops/torch/copysign.py @@ -8,7 +8,12 @@ def copysign(input, other, *, out=None): if out is None: out = torch.empty_like(input) - kernel = _cached_make(ntops.kernels.copysign.premake, input.ndim) + kernel = _cached_make( + ntops.kernels.copysign.premake, + input.ndim, + block_size=1024, + num_warps=4, + ) kernel(input, other, out) diff --git a/src/ntops/torch/lcm.py b/src/ntops/torch/lcm.py index 006ae9a..d8d9b19 100644 --- a/src/ntops/torch/lcm.py +++ b/src/ntops/torch/lcm.py @@ -4,13 +4,56 @@ from ntops.torch.utils import _cached_make +# iter08: int64 dispatched to dynamic Euclidean kernel, which is fastest +# at (BLOCK=32, num_warps=1) -- 1 element per thread maximizes concurrent +# in-flight IDIVs across SMs. See kernels/lcm.py header comment. +# +# int8/int16/int32 still use Stein static unroll (no IDIV; (512, warps, 1) +# is best per iter05's explore_config.py scan). +_NUM_STAGES = 1 + + +def _block_size_for(torch_dtype): + if torch_dtype == torch.int64: + return 32 + return 512 + + +def _num_warps_for(torch_dtype): + if torch_dtype == torch.int64: + return 1 + return 4 + + def lcm(input, other, *, out=None): if out is None: out = torch.empty_like(input) - kernel = _cached_make(ntops.kernels.lcm.premake, input.ndim, dtype=_to_nt(input.dtype)) + if ( + input.ndim != 1 + and input.is_contiguous() + and other.is_contiguous() + and out.is_contiguous() + ): + n = input.numel() + in_view = input.view([n]) + other_view = other.view([n]) + out_view = out.view([n]) + else: + in_view = input + other_view = other + out_view = out + + kernel = _cached_make( + ntops.kernels.lcm.premake, + in_view.ndim, + dtype=_to_nt(input.dtype), + block_size=_block_size_for(input.dtype), + num_warps=_num_warps_for(input.dtype), + num_stages=_NUM_STAGES, + ) - kernel(input, other, out) + kernel(in_view, other_view, out_view) return out diff --git a/src/ntops/torch/lgamma.py b/src/ntops/torch/lgamma.py index b1fed7c..13dc432 100644 --- a/src/ntops/torch/lgamma.py +++ b/src/ntops/torch/lgamma.py @@ -4,11 +4,23 @@ from ntops.torch.utils import _cached_make +_LARGE_NUMEL_THRESHOLD = 2_000_000 + + def lgamma(input, *, out=None): if out is None: out = torch.empty_like(input) - kernel = _cached_make(ntops.kernels.lgamma.premake, input.ndim) + if input.numel() >= _LARGE_NUMEL_THRESHOLD: + kernel = _cached_make( + ntops.kernels.lgamma.premake, + input.ndim, + block_size=1024, + num_warps=4, + num_stages=5, + ) + else: + kernel = _cached_make(ntops.kernels.lgamma.premake, input.ndim) kernel(input, out) diff --git a/src/ntops/torch/nextafter.py b/src/ntops/torch/nextafter.py index c2c173d..eebd82b 100644 --- a/src/ntops/torch/nextafter.py +++ b/src/ntops/torch/nextafter.py @@ -4,11 +4,22 @@ from ntops.torch.utils import _cached_make +_BLOCK_SIZE = 1024 +_NUM_WARPS = 4 +_NUM_STAGES = 2 + + def nextafter(input, other, *, out=None): if out is None: out = torch.empty_like(input) - kernel = _cached_make(ntops.kernels.nextafter.premake, input.ndim) + kernel = _cached_make( + ntops.kernels.nextafter.premake, + input.ndim, + block_size=_BLOCK_SIZE, + num_warps=_NUM_WARPS, + num_stages=_NUM_STAGES, + ) kernel(input, other, out) diff --git a/src/ntops/torch/rad2deg.py b/src/ntops/torch/rad2deg.py index f6896b1..470c26e 100644 --- a/src/ntops/torch/rad2deg.py +++ b/src/ntops/torch/rad2deg.py @@ -4,12 +4,31 @@ from ntops.torch.utils import _cached_make +_BLOCK_SIZE = 2048 +_NUM_WARPS = 4 +_NUM_STAGES = 1 + + def rad2deg(input, *, out=None): if out is None: out = torch.empty_like(input) - kernel = _cached_make(ntops.kernels.rad2deg.premake, input.ndim) + if input.ndim != 1 and input.is_contiguous() and out.is_contiguous(): + n = input.numel() + in_view = input.view([n]) + out_view = out.view([n]) + else: + in_view = input + out_view = out + + kernel = _cached_make( + ntops.kernels.rad2deg.premake, + in_view.ndim, + block_size=_BLOCK_SIZE, + num_warps=_NUM_WARPS, + num_stages=_NUM_STAGES, + ) - kernel(input, out) + kernel(in_view, out_view) return out From 9b840d85622cf49c1a8a98ba1eb8f70bbf875b4a Mon Sep 17 00:00:00 2001 From: mygitljf <2410316423@qq.com> Date: Tue, 19 May 2026 17:05:45 +0000 Subject: [PATCH 5/7] [2026 spring][T1-1-1] dtype-dispatch kernels + Iluvatar/CoreX fp16 fallback Production-side changes used during MetaX C500 and Iluvatar MR-V100 evaluation runs. ntops.kernels.copysign / nextafter: - Replace runtime `int_dtype.primitive_bitwidth` with three explicit applications (int16 / int32 / int64) selected in `premake` from the output dtype. The previous expression relied on querying the primitive bit width through the result of a `cast(..., bitcast=True)` expression, which forced fp32->fp16 bitcast paths on backends that cannot lower 32-bit-to-16-bit bitcast stores. ntops.kernels.lcm: - Pin `compute_dtype` to `int32` for application_8/16/36 (int8/int16 always promote, int32 stays). int64 keeps its dynamic-Euclidean application unchanged. ntops.kernels.lgamma: - Split into `application_float32_compute` (fp16/bf16 -> compute in fp32 -> cast back) and `application_native` (fp32/fp64). `premake` selects per dtype. ntops.torch.{copysign, nextafter, lgamma}: - Pass `dtype=ninetoothed.` into `_cached_make` so the per-dtype application is selected at compile time. - copysign / nextafter: when running on Iluvatar MR-V100 / CoreX with fp16/bf16, fall back to torch. via a wrapper. Iluvatar CoreX Triton fails to compile the 16-bit bitcast store path; on every other backend (NVIDIA, MetaX, Moore) the native ntops kernel path is taken. Detection lives in `ntops.torch.utils._is_corex_compat_device` (controllable via `NTOPS_BACKEND` env). ntops.torch.utils: - Add `_is_corex_compat_device`, `_torch_binary_fallback`, and infinicore <-> torch tensor bridging helpers used by the fallback. Verification ------------ NVIDIA A100 80GB (--nvidia): ntops pytest: 44 passed, 4 skipped InfiniCore run.py: 172/172 (100.0%) --bench device aggregate: ~1.05x (lcm 1.49x, rad2deg 1.01x, nextafter 0.99x, copysign 0.95x, lgamma 0.79x) MetaX C500 (--metax): ntops pytest: 44 passed, 4 skipped InfiniCore run.py: 172/172 (100.0%) --bench device aggregate: 1.022x (lcm 1.272x, lgamma 0.987x, nextafter 0.997x, copysign 0.933x, rad2deg 0.901x) Iluvatar MR-V100 (--iluvatar): ntops pytest: 44 passed, 4 skipped InfiniCore run.py: 172/172 (100.0%) --bench device aggregate: 0.477x (the wrapper-fallback path on fp16/bf16 dominates the elementwise budget; fp32 paths still hit the native kernel) --- src/ntops/kernels/copysign.py | 45 +++++++++++++--- src/ntops/kernels/lcm.py | 12 ++--- src/ntops/kernels/lgamma.py | 21 +++++--- src/ntops/kernels/nextafter.py | 91 +++++++++++++++++++++++++++++--- src/ntops/torch/copysign.py | 25 ++++++++- src/ntops/torch/lgamma.py | 19 ++++++- src/ntops/torch/nextafter.py | 25 ++++++++- src/ntops/torch/utils.py | 95 ++++++++++++++++++++++++++++++++++ 8 files changed, 299 insertions(+), 34 deletions(-) diff --git a/src/ntops/kernels/copysign.py b/src/ntops/kernels/copysign.py index bad16e5..ef5d1bf 100644 --- a/src/ntops/kernels/copysign.py +++ b/src/ntops/kernels/copysign.py @@ -1,26 +1,48 @@ import functools +import ninetoothed import ninetoothed.language as ntl from ninetoothed import Tensor from ntops.kernels.element_wise import arrangement -def application(input, other, output): +def application_int16(input, other, output): # Pure bit manipulation: take magnitude bits of input, sign bit of other. # Avoids the fp16/bf16 -> fp32 -> fp16/bf16 round-trip required by # libdevice.copysign, which doesn't support narrow floats. dtype = output.dtype - if dtype == ntl.float16 or dtype == ntl.bfloat16: - int_dtype = ntl.int16 - elif dtype == ntl.float32: - int_dtype = ntl.int32 - else: - int_dtype = ntl.int64 + int_dtype = ntl.int16 + + input_bits = ntl.cast(input, int_dtype, bitcast=True) + other_bits = ntl.cast(other, int_dtype, bitcast=True) + sign_bit = ntl.cast(1, int_dtype) << 15 + magn_mask = sign_bit - ntl.cast(1, int_dtype) + output = ntl.cast( # noqa: F841 + (input_bits & magn_mask) | (other_bits & sign_bit), dtype, bitcast=True + ) + + +def application_int32(input, other, output): + dtype = output.dtype + int_dtype = ntl.int32 input_bits = ntl.cast(input, int_dtype, bitcast=True) other_bits = ntl.cast(other, int_dtype, bitcast=True) - sign_bit = ntl.cast(1, int_dtype) << (ntl.cast(input, int_dtype, bitcast=True).dtype.primitive_bitwidth - 1) + sign_bit = ntl.cast(1, int_dtype) << 31 + magn_mask = sign_bit - ntl.cast(1, int_dtype) + output = ntl.cast( # noqa: F841 + (input_bits & magn_mask) | (other_bits & sign_bit), dtype, bitcast=True + ) + + +def application_int64(input, other, output): + dtype = output.dtype + int_dtype = ntl.int64 + + input_bits = ntl.cast(input, int_dtype, bitcast=True) + other_bits = ntl.cast(other, int_dtype, bitcast=True) + sign_bit = ntl.cast(1, int_dtype) << 63 magn_mask = sign_bit - ntl.cast(1, int_dtype) output = ntl.cast( # noqa: F841 (input_bits & magn_mask) | (other_bits & sign_bit), dtype, bitcast=True @@ -30,6 +52,13 @@ def application(input, other, output): def premake(ndim, dtype=None, block_size=None): arrangement_ = functools.partial(arrangement, block_size=block_size) + if dtype in (ninetoothed.float16, ninetoothed.bfloat16): + application = application_int16 + elif dtype == ninetoothed.float32: + application = application_int32 + else: + application = application_int64 + tensors = ( Tensor(ndim, dtype=dtype), Tensor(ndim, dtype=dtype), diff --git a/src/ntops/kernels/lcm.py b/src/ntops/kernels/lcm.py index 8a3ab4a..3d05521 100644 --- a/src/ntops/kernels/lcm.py +++ b/src/ntops/kernels/lcm.py @@ -52,9 +52,7 @@ # dtypes unchanged. def application_8(input, other, output): dtype = output.dtype - compute_dtype = ( - ntl.int32 if dtype == ntl.int8 or dtype == ntl.int16 else dtype - ) + compute_dtype = ntl.int32 abs_a = ntl.abs(ntl.cast(input, compute_dtype)) abs_b = ntl.abs(ntl.cast(other, compute_dtype)) or_ab = abs_a | abs_b @@ -84,9 +82,7 @@ def application_8(input, other, output): def application_16(input, other, output): dtype = output.dtype - compute_dtype = ( - ntl.int32 if dtype == ntl.int8 or dtype == ntl.int16 else dtype - ) + compute_dtype = ntl.int32 abs_a = ntl.abs(ntl.cast(input, compute_dtype)) abs_b = ntl.abs(ntl.cast(other, compute_dtype)) or_ab = abs_a | abs_b @@ -116,9 +112,7 @@ def application_16(input, other, output): def application_36(input, other, output): dtype = output.dtype - compute_dtype = ( - ntl.int32 if dtype == ntl.int8 or dtype == ntl.int16 else dtype - ) + compute_dtype = ntl.int32 abs_a = ntl.abs(ntl.cast(input, compute_dtype)) abs_b = ntl.abs(ntl.cast(other, compute_dtype)) or_ab = abs_a | abs_b diff --git a/src/ntops/kernels/lgamma.py b/src/ntops/kernels/lgamma.py index 065f207..920635d 100644 --- a/src/ntops/kernels/lgamma.py +++ b/src/ntops/kernels/lgamma.py @@ -1,5 +1,6 @@ import functools +import ninetoothed import ninetoothed.language as ntl from ninetoothed import Tensor from ninetoothed.language import libdevice @@ -7,16 +8,19 @@ from ntops.kernels.element_wise import arrangement -def application(input, output): +def application_float32_compute(input, output): # libdevice.lgamma only supports fp32/fp64; cast narrower floats up. dtype = output.dtype - compute_dtype = ( - dtype - if dtype != ntl.float16 and dtype != ntl.bfloat16 - else ntl.float32 + output = ntl.cast( # noqa: F841 + libdevice.lgamma(ntl.cast(input, ntl.float32)), + dtype, ) + + +def application_native(input, output): + dtype = output.dtype output = ntl.cast( # noqa: F841 - libdevice.lgamma(ntl.cast(input, compute_dtype)), + libdevice.lgamma(input), dtype, ) @@ -24,6 +28,11 @@ def application(input, output): def premake(ndim, dtype=None, block_size=None): arrangement_ = functools.partial(arrangement, block_size=block_size) + if dtype in (ninetoothed.float16, ninetoothed.bfloat16): + application = application_float32_compute + else: + application = application_native + tensors = (Tensor(ndim, dtype=dtype), Tensor(ndim, dtype=dtype)) return arrangement_, application, tensors diff --git a/src/ntops/kernels/nextafter.py b/src/ntops/kernels/nextafter.py index 1110b1c..4194eb0 100644 --- a/src/ntops/kernels/nextafter.py +++ b/src/ntops/kernels/nextafter.py @@ -1,24 +1,57 @@ import functools +import ninetoothed import ninetoothed.language as ntl from ninetoothed import Tensor from ntops.kernels.element_wise import arrangement -def application(input, other, output): +def application_int16(input, other, output): # PyTorch nextafter spec, implemented via IEEE bit manipulation: # if either is NaN: result is NaN # if a == b: result is b (preserves sign of zero) # if a == 0: result is smallest subnormal with sign of b # otherwise: walk one ULP toward b in IEEE bit space dtype = output.dtype - if dtype == ntl.float16 or dtype == ntl.bfloat16: - int_dtype = ntl.int16 - elif dtype == ntl.float32: - int_dtype = ntl.int32 - else: - int_dtype = ntl.int64 + int_dtype = ntl.int16 + + a = input + b = other + a_cmp = ntl.cast(a, ntl.float32) + b_cmp = ntl.cast(b, ntl.float32) + a_i = ntl.cast(a, int_dtype, bitcast=True) + b_i = ntl.cast(b, int_dtype, bitcast=True) + + one = ntl.cast(1, int_dtype) + zero = ntl.cast(0, int_dtype) + sign_bit = one << 15 + + is_nan = (a_cmp != a_cmp) | (b_cmp != b_cmp) + eq = a_cmp == b_cmp + is_zero = a_cmp == ntl.cast(0, ntl.float32) + + b_sign = b_i & sign_bit + zero_result = b_sign | one + + a_neg = a_i < zero + a_lt_b = a_cmp < b_cmp + step_up = a_neg ^ a_lt_b + step = ntl.where(step_up, one, -one) + general = a_i + step + + nan_bits = ntl.cast(ntl.cast(float("nan"), dtype), int_dtype, bitcast=True) + result_i = ntl.where( + is_nan, + nan_bits, + ntl.where(eq, b_i, ntl.where(is_zero, zero_result, general)), + ) + output = ntl.cast(result_i, dtype, bitcast=True) # noqa: F841 + + +def application_int32(input, other, output): + dtype = output.dtype + int_dtype = ntl.int32 a = input b = other @@ -27,7 +60,42 @@ def application(input, other, output): one = ntl.cast(1, int_dtype) zero = ntl.cast(0, int_dtype) - sign_bit = one << (a_i.dtype.primitive_bitwidth - 1) + sign_bit = one << 31 + + is_nan = (a != a) | (b != b) + eq = a == b + is_zero = a == ntl.cast(0, dtype) + + b_sign = b_i & sign_bit + zero_result = b_sign | one + + a_neg = a_i < zero + a_lt_b = a < b + step_up = a_neg ^ a_lt_b + step = ntl.where(step_up, one, -one) + general = a_i + step + + nan_bits = ntl.cast(ntl.cast(float("nan"), dtype), int_dtype, bitcast=True) + result_i = ntl.where( + is_nan, + nan_bits, + ntl.where(eq, b_i, ntl.where(is_zero, zero_result, general)), + ) + output = ntl.cast(result_i, dtype, bitcast=True) # noqa: F841 + + +def application_int64(input, other, output): + dtype = output.dtype + int_dtype = ntl.int64 + + a = input + b = other + a_i = ntl.cast(a, int_dtype, bitcast=True) + b_i = ntl.cast(b, int_dtype, bitcast=True) + + one = ntl.cast(1, int_dtype) + zero = ntl.cast(0, int_dtype) + sign_bit = one << 63 is_nan = (a != a) | (b != b) eq = a == b @@ -54,6 +122,13 @@ def application(input, other, output): def premake(ndim, dtype=None, block_size=None): arrangement_ = functools.partial(arrangement, block_size=block_size) + if dtype in (ninetoothed.float16, ninetoothed.bfloat16): + application = application_int16 + elif dtype == ninetoothed.float32: + application = application_int32 + else: + application = application_int64 + tensors = ( Tensor(ndim, dtype=dtype), Tensor(ndim, dtype=dtype), diff --git a/src/ntops/torch/copysign.py b/src/ntops/torch/copysign.py index 291ff04..ae0bc40 100644 --- a/src/ntops/torch/copysign.py +++ b/src/ntops/torch/copysign.py @@ -1,16 +1,27 @@ import torch import ntops -from ntops.torch.utils import _cached_make +from ntops.torch.utils import ( + _cached_make, + _is_corex_compat_device, + _torch_binary_fallback, +) def copysign(input, other, *, out=None): if out is None: out = torch.empty_like(input) + if input.dtype in (torch.float16, torch.bfloat16) and _is_corex_compat_device( + input.device + ): + _torch_binary_fallback("copysign", input, other, out) + return out + kernel = _cached_make( ntops.kernels.copysign.premake, input.ndim, + dtype=_to_nt(input.dtype), block_size=1024, num_warps=4, ) @@ -18,3 +29,15 @@ def copysign(input, other, *, out=None): kernel(input, other, out) return out + + +def _to_nt(torch_dtype): + import ninetoothed + + mapping = { + torch.float16: ninetoothed.float16, + torch.bfloat16: ninetoothed.bfloat16, + torch.float32: ninetoothed.float32, + torch.float64: ninetoothed.float64, + } + return mapping.get(torch_dtype) diff --git a/src/ntops/torch/lgamma.py b/src/ntops/torch/lgamma.py index 13dc432..bebe35d 100644 --- a/src/ntops/torch/lgamma.py +++ b/src/ntops/torch/lgamma.py @@ -15,13 +15,30 @@ def lgamma(input, *, out=None): kernel = _cached_make( ntops.kernels.lgamma.premake, input.ndim, + dtype=_to_nt(input.dtype), block_size=1024, num_warps=4, num_stages=5, ) else: - kernel = _cached_make(ntops.kernels.lgamma.premake, input.ndim) + kernel = _cached_make( + ntops.kernels.lgamma.premake, + input.ndim, + dtype=_to_nt(input.dtype), + ) kernel(input, out) return out + + +def _to_nt(torch_dtype): + import ninetoothed + + mapping = { + torch.float16: ninetoothed.float16, + torch.bfloat16: ninetoothed.bfloat16, + torch.float32: ninetoothed.float32, + torch.float64: ninetoothed.float64, + } + return mapping.get(torch_dtype) diff --git a/src/ntops/torch/nextafter.py b/src/ntops/torch/nextafter.py index eebd82b..45151be 100644 --- a/src/ntops/torch/nextafter.py +++ b/src/ntops/torch/nextafter.py @@ -1,7 +1,11 @@ import torch import ntops -from ntops.torch.utils import _cached_make +from ntops.torch.utils import ( + _cached_make, + _is_corex_compat_device, + _torch_binary_fallback, +) _BLOCK_SIZE = 1024 @@ -13,9 +17,16 @@ def nextafter(input, other, *, out=None): if out is None: out = torch.empty_like(input) + if input.dtype in (torch.float16, torch.bfloat16) and _is_corex_compat_device( + input.device + ): + _torch_binary_fallback("nextafter", input, other, out) + return out + kernel = _cached_make( ntops.kernels.nextafter.premake, input.ndim, + dtype=_to_nt(input.dtype), block_size=_BLOCK_SIZE, num_warps=_NUM_WARPS, num_stages=_NUM_STAGES, @@ -24,3 +35,15 @@ def nextafter(input, other, *, out=None): kernel(input, other, out) return out + + +def _to_nt(torch_dtype): + import ninetoothed + + mapping = { + torch.float16: ninetoothed.float16, + torch.bfloat16: ninetoothed.bfloat16, + torch.float32: ninetoothed.float32, + torch.float64: ninetoothed.float64, + } + return mapping.get(torch_dtype) diff --git a/src/ntops/torch/utils.py b/src/ntops/torch/utils.py index e9b2dde..289d4a8 100644 --- a/src/ntops/torch/utils.py +++ b/src/ntops/torch/utils.py @@ -1,4 +1,5 @@ import functools +import os import ninetoothed import torch @@ -42,6 +43,100 @@ def set_default_max_num_configs(max_num_configs): _cached_make_default_config.max_num_configs = max_num_configs +def _is_corex_compat_device(device=None): + backend = os.getenv("NTOPS_BACKEND", "").strip().lower() + if backend in {"corex", "iluvatar", "tian", "mr-v100"}: + return True + if backend in {"cuda", "nvidia"}: + return False + + if not torch.cuda.is_available(): + return False + + if device is not None and getattr(device, "type", None) != "cuda": + return False + + index = getattr(device, "index", None) + if index is None: + index = torch.cuda.current_device() + + try: + name = torch.cuda.get_device_name(index).lower() + except Exception: + return False + + return "iluvatar" in name or "mr-v100" in name or "corex" in name + + +def _torch_binary_fallback(op_name, input, other, out): + if not _is_infinicore_tensor(input): + return getattr(torch, op_name)(input, other, out=out) + + input_torch = _infinicore_to_torch(input) + other_torch = _infinicore_to_torch(other) + result = getattr(torch, op_name)(input_torch, other_torch) + _copy_torch_to_infinicore(result, out) + return out + + +def _is_infinicore_tensor(value): + return hasattr(value, "_underlying") and hasattr(value, "copy_") + + +def _infinicore_to_torch(value): + if not _is_infinicore_tensor(value): + return value + + result = torch.empty_strided( + tuple(value.shape), + tuple(value.stride()), + dtype=_to_torch_dtype(value.dtype), + device=str(value.device), + ) + _infinicore_from_torch(result).copy_(value) + return result + + +def _copy_torch_to_infinicore(value, out): + if _is_infinicore_tensor(out): + out.copy_(_infinicore_from_torch(value)) + else: + out.copy_(value) + + +def _infinicore_from_torch(value): + infinicore = __import__("infinicore") + infini_device = infinicore.device(value.device.type, value.device.index or 0) + kwargs = {"dtype": _to_infinicore_dtype(value.dtype), "device": infini_device} + if value.is_contiguous(): + return infinicore.from_blob(value.data_ptr(), list(value.shape), **kwargs) + return infinicore.strided_from_blob( + value.data_ptr(), list(value.shape), list(value.stride()), **kwargs + ) + + +def _to_torch_dtype(dtype): + infinicore = __import__("infinicore") + mapping = { + infinicore.float16: torch.float16, + infinicore.bfloat16: torch.bfloat16, + infinicore.float32: torch.float32, + infinicore.float64: torch.float64, + } + return mapping.get(dtype, dtype) + + +def _to_infinicore_dtype(dtype): + infinicore = __import__("infinicore") + mapping = { + torch.float16: infinicore.float16, + torch.bfloat16: infinicore.bfloat16, + torch.float32: infinicore.float32, + torch.float64: infinicore.float64, + } + return mapping[dtype] + + @functools.cache def _cached_make( premake, *args, num_warps=None, num_stages=None, max_num_configs=None, **keywords From 721a63a7237353baaac953fd65f216282ec3acce Mon Sep 17 00:00:00 2001 From: mygitljf <2410316423@qq.com> Date: Tue, 19 May 2026 17:06:20 +0000 Subject: [PATCH 6/7] [2026 spring][T1-1-1] add ntops do_bench perf harness for the 5 ops Adds an opt-in performance harness that benchmarks ntops.torch. against torch. with `triton.testing.do_bench` (warmup=50, rep=200, median of 3). Tests are skipped by default and only execute when `NTOPS_RUN_PERF=1` is set, so they cost nothing in the standard correctness pipeline. Files ----- - tests/perf_utils.py: shared shapes/dtype lists, `bench_us`, `report_and_assert` (asserts ratio = torch_us / ntops_us >= 0.5), warmup_pair, input factories. - tests/test__perf.py for rad2deg, copysign, lcm, nextafter, lgamma: parametrized over a representative set of shapes and dtypes (fp16/bf16/fp32; int8/int16/int32/int64 for lcm). Usage ----- NTOPS_RUN_PERF=1 python -m pytest tests/test_rad2deg_perf.py -s Verification (NVIDIA A100 80GB): rad2deg 18 cases ntops 190.10us torch 190.59us 1.003x copysign 18 cases ntops 248.22us torch 235.79us 0.950x lcm 24 cases ntops 1018.18us torch 1177.45us 1.156x nextafter 18 cases ntops 254.00us torch 255.25us 1.005x lgamma 18 cases ntops 291.97us torch 289.02us 0.990x --- tests/perf_utils.py | 90 ++++++++++++++++++++++++++++++++++++ tests/test_copysign_perf.py | 40 ++++++++++++++++ tests/test_lcm_perf.py | 40 ++++++++++++++++ tests/test_lgamma_perf.py | 39 ++++++++++++++++ tests/test_nextafter_perf.py | 40 ++++++++++++++++ tests/test_rad2deg_perf.py | 39 ++++++++++++++++ 6 files changed, 288 insertions(+) create mode 100644 tests/perf_utils.py create mode 100644 tests/test_copysign_perf.py create mode 100644 tests/test_lcm_perf.py create mode 100644 tests/test_lgamma_perf.py create mode 100644 tests/test_nextafter_perf.py create mode 100644 tests/test_rad2deg_perf.py diff --git a/tests/perf_utils.py b/tests/perf_utils.py new file mode 100644 index 0000000..2faf53e --- /dev/null +++ b/tests/perf_utils.py @@ -0,0 +1,90 @@ +import os +import statistics + +import pytest +import torch +import triton.testing as tt + + +_RUN_PERF_ENV = "NTOPS_RUN_PERF" + + +def skip_unless_perf_enabled(): + if os.environ.get(_RUN_PERF_ENV, "0") != "1": + pytest.skip( + f"perf benchmark; set {_RUN_PERF_ENV}=1 to run", + allow_module_level=True, + ) + + +SHAPES = [ + (13, 4), + (8, 16), + (2, 3, 4), + (16, 5632), + (256, 5632), + (1024, 5632), +] + + +FLOAT_DTYPES = [torch.float16, torch.bfloat16, torch.float32] + + +INT_DTYPES = [torch.int8, torch.int16, torch.int32, torch.int64] + + +_DTYPE_NAMES = { + torch.float16: "fp16", + torch.bfloat16: "bf16", + torch.float32: "fp32", + torch.int8: "int8", + torch.int16: "int16", + torch.int32: "int32", + torch.int64: "int64", +} + + +def dtype_name(dtype): + return _DTYPE_NAMES.get(dtype, str(dtype)) + + +MIN_RATIO = 0.5 + + +def bench_us(fn, *, warmup=50, rep=200, repeat=3): + runs = [tt.do_bench(fn, warmup=warmup, rep=rep) * 1000 for _ in range(repeat)] + return statistics.median(runs) + + +def report_and_assert(op_name, shape, dtype, ntops_us, torch_us): + ratio = torch_us / ntops_us if ntops_us > 0 else 0.0 + print( + f"\n {op_name:9s} shape={str(tuple(shape)):14s} dtype={dtype_name(dtype):5s} " + f"ntops={ntops_us:8.2f}us torch={torch_us:8.2f}us ratio={ratio:.3f}", + end="", + ) + assert ratio >= MIN_RATIO, ( + f"{op_name} perf regression: ratio {ratio:.3f} < {MIN_RATIO} " + f"(ntops={ntops_us:.2f}us, torch={torch_us:.2f}us, " + f"shape={tuple(shape)}, dtype={dtype_name(dtype)})" + ) + + +def warmup_pair(ntops_fn, torch_fn, n=50): + for _ in range(n): + ntops_fn() + torch_fn() + torch.cuda.synchronize() + + +def make_float_input(shape, dtype, *, op_name=None): + if op_name == "lgamma": + return torch.rand(shape, dtype=dtype, device="cuda") * 5.0 + 0.5 + return torch.randn(shape, dtype=dtype, device="cuda") + + +def make_int_input(shape, dtype): + info = torch.iinfo(dtype) + lo = max(info.min, -32768) + hi = min(info.max, 32767) + return torch.randint(lo, hi, shape, dtype=dtype, device="cuda") diff --git a/tests/test_copysign_perf.py b/tests/test_copysign_perf.py new file mode 100644 index 0000000..9be5dda --- /dev/null +++ b/tests/test_copysign_perf.py @@ -0,0 +1,40 @@ +import itertools + +import pytest +import torch + +import ntops +from tests.perf_utils import ( + FLOAT_DTYPES, + SHAPES, + bench_us, + dtype_name, + make_float_input, + report_and_assert, + skip_unless_perf_enabled, + warmup_pair, +) +from tests.skippers import skip_if_cuda_not_available + + +skip_unless_perf_enabled() + + +_PARAMS = list(itertools.product(SHAPES, FLOAT_DTYPES)) +_IDS = [f"{tuple(s)}-{dtype_name(d)}" for s, d in _PARAMS] + + +@skip_if_cuda_not_available +@pytest.mark.parametrize("shape, dtype", _PARAMS, ids=_IDS) +def test_copysign_perf(shape, dtype): + a = make_float_input(shape, dtype) + b = make_float_input(shape, dtype) + out = torch.empty_like(a) + + ntops_fn = lambda: ntops.torch.copysign(a, b, out=out) + torch_fn = lambda: torch.copysign(a, b, out=out) + warmup_pair(ntops_fn, torch_fn) + + ntops_us = bench_us(ntops_fn) + torch_us = bench_us(torch_fn) + report_and_assert("copysign", shape, dtype, ntops_us, torch_us) diff --git a/tests/test_lcm_perf.py b/tests/test_lcm_perf.py new file mode 100644 index 0000000..36c70fc --- /dev/null +++ b/tests/test_lcm_perf.py @@ -0,0 +1,40 @@ +import itertools + +import pytest +import torch + +import ntops +from tests.perf_utils import ( + INT_DTYPES, + SHAPES, + bench_us, + dtype_name, + make_int_input, + report_and_assert, + skip_unless_perf_enabled, + warmup_pair, +) +from tests.skippers import skip_if_cuda_not_available + + +skip_unless_perf_enabled() + + +_PARAMS = list(itertools.product(SHAPES, INT_DTYPES)) +_IDS = [f"{tuple(s)}-{dtype_name(d)}" for s, d in _PARAMS] + + +@skip_if_cuda_not_available +@pytest.mark.parametrize("shape, dtype", _PARAMS, ids=_IDS) +def test_lcm_perf(shape, dtype): + a = make_int_input(shape, dtype) + b = make_int_input(shape, dtype) + out = torch.empty_like(a) + + ntops_fn = lambda: ntops.torch.lcm(a, b, out=out) + torch_fn = lambda: torch.lcm(a, b, out=out) + warmup_pair(ntops_fn, torch_fn) + + ntops_us = bench_us(ntops_fn) + torch_us = bench_us(torch_fn) + report_and_assert("lcm", shape, dtype, ntops_us, torch_us) diff --git a/tests/test_lgamma_perf.py b/tests/test_lgamma_perf.py new file mode 100644 index 0000000..11bdd80 --- /dev/null +++ b/tests/test_lgamma_perf.py @@ -0,0 +1,39 @@ +import itertools + +import pytest +import torch + +import ntops +from tests.perf_utils import ( + FLOAT_DTYPES, + SHAPES, + bench_us, + dtype_name, + make_float_input, + report_and_assert, + skip_unless_perf_enabled, + warmup_pair, +) +from tests.skippers import skip_if_cuda_not_available + + +skip_unless_perf_enabled() + + +_PARAMS = list(itertools.product(SHAPES, FLOAT_DTYPES)) +_IDS = [f"{tuple(s)}-{dtype_name(d)}" for s, d in _PARAMS] + + +@skip_if_cuda_not_available +@pytest.mark.parametrize("shape, dtype", _PARAMS, ids=_IDS) +def test_lgamma_perf(shape, dtype): + a = make_float_input(shape, dtype, op_name="lgamma") + out = torch.empty_like(a) + + ntops_fn = lambda: ntops.torch.lgamma(a, out=out) + torch_fn = lambda: torch.lgamma(a, out=out) + warmup_pair(ntops_fn, torch_fn) + + ntops_us = bench_us(ntops_fn) + torch_us = bench_us(torch_fn) + report_and_assert("lgamma", shape, dtype, ntops_us, torch_us) diff --git a/tests/test_nextafter_perf.py b/tests/test_nextafter_perf.py new file mode 100644 index 0000000..164fabc --- /dev/null +++ b/tests/test_nextafter_perf.py @@ -0,0 +1,40 @@ +import itertools + +import pytest +import torch + +import ntops +from tests.perf_utils import ( + FLOAT_DTYPES, + SHAPES, + bench_us, + dtype_name, + make_float_input, + report_and_assert, + skip_unless_perf_enabled, + warmup_pair, +) +from tests.skippers import skip_if_cuda_not_available + + +skip_unless_perf_enabled() + + +_PARAMS = list(itertools.product(SHAPES, FLOAT_DTYPES)) +_IDS = [f"{tuple(s)}-{dtype_name(d)}" for s, d in _PARAMS] + + +@skip_if_cuda_not_available +@pytest.mark.parametrize("shape, dtype", _PARAMS, ids=_IDS) +def test_nextafter_perf(shape, dtype): + a = make_float_input(shape, dtype) + b = make_float_input(shape, dtype) + out = torch.empty_like(a) + + ntops_fn = lambda: ntops.torch.nextafter(a, b, out=out) + torch_fn = lambda: torch.nextafter(a, b, out=out) + warmup_pair(ntops_fn, torch_fn) + + ntops_us = bench_us(ntops_fn) + torch_us = bench_us(torch_fn) + report_and_assert("nextafter", shape, dtype, ntops_us, torch_us) diff --git a/tests/test_rad2deg_perf.py b/tests/test_rad2deg_perf.py new file mode 100644 index 0000000..8670295 --- /dev/null +++ b/tests/test_rad2deg_perf.py @@ -0,0 +1,39 @@ +import itertools + +import pytest +import torch + +import ntops +from tests.perf_utils import ( + FLOAT_DTYPES, + SHAPES, + bench_us, + dtype_name, + make_float_input, + report_and_assert, + skip_unless_perf_enabled, + warmup_pair, +) +from tests.skippers import skip_if_cuda_not_available + + +skip_unless_perf_enabled() + + +_PARAMS = list(itertools.product(SHAPES, FLOAT_DTYPES)) +_IDS = [f"{tuple(s)}-{dtype_name(d)}" for s, d in _PARAMS] + + +@skip_if_cuda_not_available +@pytest.mark.parametrize("shape, dtype", _PARAMS, ids=_IDS) +def test_rad2deg_perf(shape, dtype): + a = make_float_input(shape, dtype) + out = torch.empty_like(a) + + ntops_fn = lambda: ntops.torch.rad2deg(a, out=out) + torch_fn = lambda: torch.rad2deg(a, out=out) + warmup_pair(ntops_fn, torch_fn) + + ntops_us = bench_us(ntops_fn) + torch_us = bench_us(torch_fn) + report_and_assert("rad2deg", shape, dtype, ntops_us, torch_us) From cfa42f35dcd9d151d543656c44582f07539ddcbf Mon Sep 17 00:00:00 2001 From: mygitljf <2410316423@qq.com> Date: Tue, 19 May 2026 17:08:47 +0000 Subject: [PATCH 7/7] [2026 spring][T1-1-1] add signed HONOR_CODE.md --- HONOR_CODE.md | 72 +++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 72 insertions(+) create mode 100644 HONOR_CODE.md diff --git a/HONOR_CODE.md b/HONOR_CODE.md new file mode 100644 index 0000000..a375b70 --- /dev/null +++ b/HONOR_CODE.md @@ -0,0 +1,72 @@ +# 2026 春季启元人工智能大赛诚信守则(Honor Code) + + +本人作为 2026 春季启元人工智能大赛(以下简称“比赛”)的参赛选手,郑重承诺严格遵守比赛规则及本诚信守则,秉持诚信、公正、廉洁的参赛原则,自觉维护比赛的公平性与严肃性。本人充分理解并认可,违反本准则将导致参赛资格被取消、比赛成绩作废等相应后果,且愿意承担由此产生的一切责任。 + +## 一、参赛诚信承诺 + +1. 本人保证所提交的赛题PR(Pull Request)中包含的算子实现代码及相关文档,均为本人(及参赛团队,如为团队参赛)在比赛期间独立完成或在明确标注参考来源的基础上进行开发,不存在任何欺诈、抄袭、作弊行为。 + +2. 本人承诺主动、全面、真实地披露赛题实现过程中所有参考的外部资源,尤其是开源代码资源,不隐瞒任何可能影响比赛公平性的信息。 + +3. 本人保证不采用任何不正当手段获取比赛优势,包括但不限于窃取其他参赛选手的代码成果、利用非比赛允许的工具或技术、与他人串通作弊等。 + +## 二、参考资源说明 + +本人确认已按比赛要求,将本次赛题实现过程中涉及的参考资源信息单独撰写至`REFERENCE.md`文件中,该文件将与本诚信守则一同作为PR附件提交。`REFERENCE.md`需根据实际参考情况,按以下要求完整填写,信息不完整或虚假填写将视为违反本准则: + +**情况1:无参考外部开源代码及核心实现思路** + +`REFERENCE.md`中需明确声明:“本次赛题提交的算子代码、核心算法逻辑及实现方案均为本人(及参赛团队)独立设计与开发,未参考任何外部开源项目、技术文档中的核心代码片段或实现思路,未接受任何第三方的技术指导或代码支持。” + +**情况2:有参考外部开源代码及相关资源** + +对每个参考资源提供以下信息陈述: +1. 参考开源项目/资源名称 + +2. 参考资源链接(GitHub/Gitee/论文/技术文档等) + +3. 参考的具体内容(请明确说明参考的代码片段、算法逻辑、实现思路等,需标注对应资源的具体位置,如文件路径、代码行数等) + +4. 本人对参考内容的修改与优化说明:(请详细说明在参考基础上,本人所做的独立开发、修改、优化工作,体现自身技术贡献) + +5. 若是开源项目,提供参考资源的开源协议类型:(如MIT、Apache 2.0、GPL等) + +6. 其他需要补充说明的信息 + + +## 三、禁止行为确认 + +本人明确知晓并承诺避免以下违反比赛公平性的行为,若存在以下任一情况,自愿接受比赛组委会的相应处罚: + +1. 未经授权复制、抄袭他人(包括其他参赛选手、开源项目、商业代码)的代码、算法或技术方案,且未进行明确标注; + +2. 隐瞒或虚假披露参考资源信息,包括遗漏重要参考来源、伪造参考内容说明等; + +3. 与其他参赛选手或第三方串通,进行代码共享、成果交换等违规协作; + +4. 利用比赛平台漏洞、技术缺陷或非比赛允许的工具获取不正当利益; + +5. 伪造比赛相关证明材料、提交虚假信息; + +6. 其他违反比赛规则及公序良俗的不诚信行为。 + + +## 四、责任与确认 + +1. 本人充分理解,比赛组委会将对所有提交的PR进行代码溯源、参考信息核查等公平性审查,若发现本人存在违反本准则的行为,有权随时取消本人的参赛资格、作废比赛成绩,情节严重的将在比赛相关平台进行公示。 + +2. 若因本人违反本准则导致比赛争议或第三方权益受损(如开源协议侵权等),本人将独立承担全部法律责任及相关损失,与比赛组委会无关。 + +3. 本人确认已仔细阅读并完全理解本诚信守则的全部内容,自愿签署本准则,接受比赛组委会的监督与审查。 + +## 五、签署信息 + +参赛选手姓名(团队参赛需填写所有成员姓名) + + 练锦烽 + +签署日期 + +___2026___年__5__月__18__日 +