[2026春季][T1-1-1] mygitljf#80
Open
mygitljf wants to merge 7 commits into
Open
Conversation
Five element-wise operators implemented in NineToothed DSL on top of
ntops.kernels.element_wise.arrangement.
- rad2deg: input * (180 / pi).
- copysign: IEEE bit manipulation; bitcast to int, splice magnitude bits
of input with sign bit of other. Avoids the fp16/bf16 -> fp32
cast required by libdevice.copysign and is bit-exact across
fp16/bf16/fp32.
- nextafter: IEEE bit manipulation; handle NaN/equal/zero specially, then
walk one ULP toward the target in integer space. Bit-exact
across fp16/bf16/fp32.
- lcm: Euclidean GCD with dtype-bucketed unroll counts (28 for
int8/int16, 56 for int32, 104 for int64) -- the worst-case
Fibonacci bound. The wrapper passes dtype to _cached_make so
the per-dtype application is selected correctly.
- lgamma: libdevice.lgamma; cast fp16/bf16 up to fp32 since libdevice
does not support narrow floats here.
Tests cover float16/float32 (lcm: int8/int16/int32) for ndim 1-4 against
PyTorch references using torch.allclose / torch.equal. Verified on NVIDIA
A100 80GB PCIe (44 passed, 4 skipped for bool lcm).
…e unroll
Iter02+03 optimization for the lcm operator. Replaces the Euclidean GCD
inner loop (which used 64-bit IDIV/IREM, ~100+ cycles per op on A100)
with Stein's binary GCD using libdevice.ffs (ctz) + shift + subtract.
Per-dtype unroll buckets sized to the worst-case Stein iteration count:
application_32 : int8 / int16 (promoted to int32, value range <= 32767)
application_64 : int32
application_128 : int64
ncu profile (sudo) on application_104 (iter01) showed Compute 66.4%,
DRAM 1.1%, 39 regs/thread - confirming pure compute bound on integer
division. After Stein:
device time at (1024, 5632), ratio = torch / ntops:
iter01 iter02 iter03 iter01 -> iter03
lcm int8 0.43 0.33 0.59 +37%
lcm int16 0.66 0.52 0.92 +39%
lcm int32 0.46 0.83 0.83 +80%
lcm int64 0.36 0.76 0.76 +111%
Correctness: ntops pytest 12 passed + 4 skipped (bool); InfiniCore
--nvidia 40/40 passed. Bit-exact against torch.lcm including int8/int16
two's-complement wrap-around behaviour.
Constraints learned (also in docs/算子开发经验.md):
- NineToothed AST does not cross Python function boundaries cleanly;
helper functions returning tuples raise compilation errors. All
prelude / step / finish blocks must be inlined per application_<N>.
- libdevice.ffs(0) returns 0; ffs(x>0) returns 1 + ctz(x). Need where()
masks at every ctz site.
Detailed iteration logs in docs/logs/T1-1-1/{iter01,iter02,iter03}/.
Iter04 follow-up to commit 8b74924. Empirical measurement shows the Stein outer iteration count is bounded by bit_width (not 2 * bit_width) because each outer step absorbs all consecutive trailing-zero shifts via 'b >> ctz(b)'. Fibonacci adversarial + (2^k-1, 2^(k-1)) pathological inputs confirm: int8 worst case 5 -> use 8 int16 worst case 13 -> use 16 int32 worst case 31 -> use 36 int64 worst case 63 -> use 72 Renamed application_32/64/128 to application_8/16/36/72 with separate buckets for int8 and int16. Premake dispatch is now 4-way. ncu shows iter03 application_128 was 96.84% compute-saturated, so cutting unroll directly translates to wall-clock speedup. application_72 duration drops from 2.82 ms to 1.65 ms (-42%, matches 128->72 ratio). device time at (1024, 5632), ratio = torch / ntops: iter03 iter04 lcm int8 0.59 -> 1.33 +126% (above torch) lcm int16 0.92 -> 1.47 +60% lcm int32 0.83 -> 1.35 +63% lcm int64 0.76 -> 1.30 +71% iter01 -> iter04 cumulative on int64: 0.36 -> 1.30 (+261%). Correctness: ntops pytest 12 passed + 4 skipped (bool); InfiniCore --nvidia 40/40 passed. Detailed iteration log in docs/logs/T1-1-1/lcm/iter04/.
… int64 dynamic Euclidean - lcm: Stein inner-loop sentinel-merge (3 wheres -> 1 where) for int8/int16/int32 with per-dtype unroll (8/16/36); int64 switched to dynamic Euclidean with grouped block-level early stop. Flatten contiguous inputs to 1D; explicit (block_size, num_warps, num_stages) per dtype. - rad2deg: block_size=2048, num_warps=4, num_stages=1; flatten 1D. - copysign: block_size=1024, num_warps=4. - nextafter: block_size=1024, num_warps=4, num_stages=2. - lgamma: large-numel path with block_size=1024, num_warps=4, num_stages=5. Verification: - ntops pytest (rad2deg/copysign/nextafter/lgamma/lcm): 44 passed, 4 skipped. - InfiniCore --nvidia: 5/5 ops 100% pass. - InfiniCore --bench device aggregate (lcm): 0.52x -> 1.67x.
…llback
Production-side changes used during MetaX C500 and Iluvatar MR-V100
evaluation runs.
ntops.kernels.copysign / nextafter:
- Replace runtime `int_dtype.primitive_bitwidth` with three explicit
applications (int16 / int32 / int64) selected in `premake` from the
output dtype. The previous expression relied on querying the
primitive bit width through the result of a `cast(..., bitcast=True)`
expression, which forced fp32->fp16 bitcast paths on backends that
cannot lower 32-bit-to-16-bit bitcast stores.
ntops.kernels.lcm:
- Pin `compute_dtype` to `int32` for application_8/16/36 (int8/int16
always promote, int32 stays). int64 keeps its dynamic-Euclidean
application unchanged.
ntops.kernels.lgamma:
- Split into `application_float32_compute` (fp16/bf16 -> compute in
fp32 -> cast back) and `application_native` (fp32/fp64). `premake`
selects per dtype.
ntops.torch.{copysign, nextafter, lgamma}:
- Pass `dtype=ninetoothed.<dtype>` into `_cached_make` so the
per-dtype application is selected at compile time.
- copysign / nextafter: when running on Iluvatar MR-V100 / CoreX
with fp16/bf16, fall back to torch.<op> via a wrapper. Iluvatar
CoreX Triton fails to compile the 16-bit bitcast store path; on
every other backend (NVIDIA, MetaX, Moore) the native ntops kernel
path is taken. Detection lives in
`ntops.torch.utils._is_corex_compat_device` (controllable via
`NTOPS_BACKEND` env).
ntops.torch.utils:
- Add `_is_corex_compat_device`, `_torch_binary_fallback`, and
infinicore <-> torch tensor bridging helpers used by the fallback.
Verification
------------
NVIDIA A100 80GB (--nvidia):
ntops pytest: 44 passed, 4 skipped
InfiniCore run.py: 172/172 (100.0%)
--bench device aggregate: ~1.05x (lcm 1.49x, rad2deg 1.01x,
nextafter 0.99x, copysign 0.95x, lgamma 0.79x)
MetaX C500 (--metax):
ntops pytest: 44 passed, 4 skipped
InfiniCore run.py: 172/172 (100.0%)
--bench device aggregate: 1.022x (lcm 1.272x, lgamma 0.987x,
nextafter 0.997x, copysign 0.933x, rad2deg 0.901x)
Iluvatar MR-V100 (--iluvatar):
ntops pytest: 44 passed, 4 skipped
InfiniCore run.py: 172/172 (100.0%)
--bench device aggregate: 0.477x (the wrapper-fallback path on
fp16/bf16 dominates the elementwise budget; fp32 paths still hit
the native kernel)
Adds an opt-in performance harness that benchmarks ntops.torch.<op>
against torch.<op> with `triton.testing.do_bench` (warmup=50, rep=200,
median of 3). Tests are skipped by default and only execute when
`NTOPS_RUN_PERF=1` is set, so they cost nothing in the standard
correctness pipeline.
Files
-----
- tests/perf_utils.py: shared shapes/dtype lists, `bench_us`,
`report_and_assert` (asserts ratio = torch_us / ntops_us >= 0.5),
warmup_pair, input factories.
- tests/test_<op>_perf.py for rad2deg, copysign, lcm, nextafter,
lgamma: parametrized over a representative set of shapes and
dtypes (fp16/bf16/fp32; int8/int16/int32/int64 for lcm).
Usage
-----
NTOPS_RUN_PERF=1 python -m pytest tests/test_rad2deg_perf.py -s
Verification (NVIDIA A100 80GB):
rad2deg 18 cases ntops 190.10us torch 190.59us 1.003x
copysign 18 cases ntops 248.22us torch 235.79us 0.950x
lcm 24 cases ntops 1018.18us torch 1177.45us 1.156x
nextafter 18 cases ntops 254.00us torch 255.25us 1.005x
lgamma 18 cases ntops 291.97us torch 289.02us 0.990x
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Platform Compatibility
NVIDIA A100, MetaX C500, Iluvatar MR-V100:
On Iluvatar / CoreX the
bitcast storelowering path for fp16 / bf16 is not available, socopysign/nextafterfall back totorch.<op>for fp16 / bf16 on that platform viantops.torch.utils._is_corex_compat_device. NVIDIA and MetaX still take the original ntops kernel path.Test Commands
ntops correctness
ntops performance
Set
NTOPS_RUN_PERF=1to enable.warmup=50,rep=200, median of 3.InfiniCore correctness + performance (switch the device flag per platform)
Test Results (100% pass on all three platforms)
Correctness
| Platform | Device flag | ntops pytest | InfiniCore run.py |
|---|---|
| NVIDIA A100 80GB PCIe |
--nvidia| 44 passed, 4 skipped | 172 / 172 (100.0%) || MetaX C500 |
--metax| 44 passed, 4 skipped | 172 / 172 (100.0%) || Iluvatar MR-V100 |
--iluvatar| 44 passed, 4 skipped | 172 / 172 (100.0%) |The
4 skippedcases arelcmonbooldtype (torch.lcmitself does not supportbool).Performance
ntops
InfiniCore
NVIDIA A100:

MetaX C500:
Iluvatar MR-V100: