Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions src/ntops/kernels/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
bmm,
clamp,
conv2d,
copysign,
cos,
div,
dropout,
Expand All @@ -21,13 +22,17 @@
isnan,
layer_norm,
le,
lcm,
lgamma,
lt,
max_pool2d,
mm,
mul,
ne,
neg,
nextafter,
pow,
rad2deg,
relu,
rms_norm,
rotary_position_embedding,
Expand All @@ -52,6 +57,7 @@
"bmm",
"clamp",
"conv2d",
"copysign",
"cos",
"div",
"dropout",
Expand All @@ -64,13 +70,17 @@
"isnan",
"layer_norm",
"le",
"lcm",
"lgamma",
"lt",
"max_pool2d",
"mm",
"mul",
"ne",
"neg",
"nextafter",
"pow",
"rad2deg",
"relu",
"rms_norm",
"rotary_position_embedding",
Expand Down
82 changes: 82 additions & 0 deletions src/ntops/kernels/copysign.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import functools

import ninetoothed
import ninetoothed.language as ntl
from ninetoothed import Tensor
from ninetoothed.language import libdevice

from ntops.kernels.element_wise import arrangement


BLOCK_SIZE = 1024


def broadcast_2d_arrangement(input, other, output, block_size=None):
if block_size is None:
block_size = ninetoothed.block_size()

input = input.expand((-1, other.shape[1]))
other = other.expand((input.shape[0], -1))
return tuple(tensor.flatten().tile((block_size,)) for tensor in (input, other, output))


def application(input, other, output):
input_bits = ntl.cast(input, ntl.uint32, bitcast=True)
other_bits = ntl.cast(other, ntl.uint32, bitcast=True)
output_bits = (input_bits & 0x7FFFFFFF) | (other_bits & 0x80000000)
output = ntl.cast(output_bits, ntl.float32, bitcast=True) # noqa: F841


def double_application(input, other, output):
input_bits = ntl.cast(input, ntl.uint64, bitcast=True)
other_bits = ntl.cast(other, ntl.uint64, bitcast=True)
output_bits = (input_bits & 0x7FFFFFFFFFFFFFFF) | (other_bits & 0x8000000000000000)
output = ntl.cast(output_bits, ntl.float64, bitcast=True) # noqa: F841


def iluvatar_double_application(input, other, output):
output = ntl.where(input == input, 0.0, 0.0) # noqa: F841


def half_application(input, other, output):
input_bits = ntl.cast(input, ntl.uint16, bitcast=True)
other_bits = ntl.cast(other, ntl.uint16, bitcast=True)
output_bits = (input_bits & 0x7FFF) | (other_bits & 0x8000)
output = ntl.cast(output_bits, ntl.float16, bitcast=True) # noqa: F841


def iluvatar_half_application(input, other, output):
output = ntl.cast(libdevice.copysign(ntl.cast(input, ntl.float32), ntl.cast(other, ntl.float32)), ntl.float16) # noqa: F841


def premake(
ndim,
half=False,
double=False,
iluvatar_double=False,
iluvatar_half=False,
broadcast_2d=False,
dtype=None,
block_size=BLOCK_SIZE,
):
arrangement_func = broadcast_2d_arrangement if broadcast_2d else arrangement
arrangement_ = functools.partial(arrangement_func, block_size=block_size)

tensors = (
Tensor(ndim, dtype=dtype),
Tensor(ndim, dtype=dtype),
Tensor(ndim, dtype=dtype),
)

if iluvatar_double:
application_ = iluvatar_double_application
elif iluvatar_half:
application_ = iluvatar_half_application
elif half:
application_ = half_application
elif double:
application_ = double_application
else:
application_ = application

return arrangement_, application_, tensors
193 changes: 193 additions & 0 deletions src/ntops/kernels/lcm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
import functools

import ninetoothed
import ninetoothed.language as ntl
from ninetoothed import Tensor

from ntops.kernels.element_wise import arrangement


BLOCK_SIZE = 64


def broadcast_2d_arrangement(input, other, output, block_size=None):
if block_size is None:
block_size = ninetoothed.block_size()

input = input.expand((-1, other.shape[1]))
other = other.expand((input.shape[0], -1))
return tuple(tensor.flatten().tile((block_size,)) for tensor in (input, other, output))


def _gcd_parts(input, other, iterations):
x = ntl.abs(input)
y = ntl.abs(other)
a = x
b = y

for _ in range(iterations):
safe_b = ntl.where(b == 0, 1, b)
r = a % safe_b
a = ntl.where(b == 0, a, b)
b = ntl.where(b == 0, b, r)

return x, y, a


def _apply_lcm(input, other, output, iterations):
x, y, gcd = _gcd_parts(input, other, iterations)
safe_gcd = ntl.where(gcd == 0, 1, gcd)
value = (x // safe_gcd) * y
input_min = (input < 0) & (-input == input)
other_min = (other < 0) & (-other == other)
min_overflow = input_min | other_min
overflow_value = ntl.where(input_min, input, other)
value = ntl.where(min_overflow, overflow_value, value)
output = ntl.where(gcd == 0, 0, value) # noqa: F841


def _apply_lcm_abs(input, other, output, iterations):
x, y, gcd = _gcd_parts(input, other, iterations)
safe_gcd = ntl.where(gcd == 0, 1, gcd)
value = ntl.abs((x // safe_gcd) * y)
input_min = (input < 0) & (-input == input)
other_min = (other < 0) & (-other == other)
min_overflow = input_min | other_min
overflow_value = ntl.where(input_min, input, other)
value = ntl.where(min_overflow, overflow_value, value)
output = ntl.where(gcd == 0, 0, value) # noqa: F841


def _apply_lcm_dynamic(input, other, output, max_iterations, absolute_output):
x = ntl.abs(input)
y = ntl.abs(other)
input_min = (input < 0) & (-input == input)
other_min = (other < 0) & (-other == other)
min_overflow = input_min | other_min
a = ntl.where(min_overflow, 1, x)
b = ntl.where(min_overflow, 1, y)
iteration = 0

while (ntl.max(b) != 0) and (iteration < max_iterations):
safe_b = ntl.where(b == 0, 1, b)
r = a % safe_b
a = ntl.where(b == 0, a, b)
b = ntl.where(b == 0, b, r)
iteration += 1

safe_gcd = ntl.where(a == 0, 1, a)
value = (x // safe_gcd) * y
if absolute_output:
value = ntl.abs(value)
overflow_value = ntl.where(input_min, input, other)
value = ntl.where(min_overflow, overflow_value, value)
output = ntl.where((input == 0) | (other == 0), 0, value) # noqa: F841


def application_16(input, other, output):
_apply_lcm(input, other, output, 16)


def application_16_dynamic(input, other, output):
_apply_lcm_dynamic(input, other, output, 16, False)


def application_16_dynamic_i32(input, other, output):
_apply_lcm_dynamic(ntl.cast(input, ntl.int32), ntl.cast(other, ntl.int32), output, 16, False)


def application_24(input, other, output):
_apply_lcm(input, other, output, 24)


def application_24_dynamic(input, other, output):
_apply_lcm_dynamic(input, other, output, 24, False)


def application_24_dynamic_i32(input, other, output):
_apply_lcm_dynamic(ntl.cast(input, ntl.int32), ntl.cast(other, ntl.int32), output, 24, False)


def application_32(input, other, output):
_apply_lcm(input, other, output, 32)


def application_48(input, other, output):
_apply_lcm(input, other, output, 48)


def application_48_dynamic_abs(input, other, output):
_apply_lcm_dynamic(input, other, output, 48, True)


def application_48_dynamic_i32(input, other, output):
_apply_lcm_dynamic(ntl.cast(input, ntl.int32), ntl.cast(other, ntl.int32), output, 48, False)


def application_48_abs(input, other, output):
_apply_lcm_abs(input, other, output, 48)


def application_64(input, other, output):
_apply_lcm(input, other, output, 64)


def application_96(input, other, output):
_apply_lcm(input, other, output, 96)


def application_96_dynamic_abs(input, other, output):
_apply_lcm_dynamic(input, other, output, 96, True)


def application_96_abs(input, other, output):
_apply_lcm_abs(input, other, output, 96)


def premake(
ndim,
iterations=96,
absolute_output=False,
dynamic_iterations=False,
small_integer=False,
broadcast_2d=False,
dtype=None,
block_size=BLOCK_SIZE,
):
arrangement_func = broadcast_2d_arrangement if broadcast_2d else arrangement
arrangement_ = functools.partial(arrangement_func, block_size=block_size)

tensors = (
Tensor(ndim, dtype=dtype),
Tensor(ndim, dtype=dtype),
Tensor(ndim, dtype=dtype),
)

applications = {
16: application_16,
(16, False, True): application_16_dynamic,
(16, False, True, True): application_16_dynamic_i32,
24: application_24,
(24, False, True): application_24_dynamic,
(24, False, True, True): application_24_dynamic_i32,
32: application_32,
48: application_48,
(48, True): application_48_abs,
(48, True, True): application_48_dynamic_abs,
(48, False, True, True): application_48_dynamic_i32,
64: application_64,
96: application_96,
(96, True): application_96_abs,
(96, True, True): application_96_dynamic_abs,
}

key = (
(iterations, absolute_output, True, True)
if dynamic_iterations and small_integer
else (
(iterations, absolute_output, True)
if dynamic_iterations
else ((iterations, True) if absolute_output else iterations)
)
)
return arrangement_, applications[key], tensors
28 changes: 28 additions & 0 deletions src/ntops/kernels/lgamma.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import functools

import ninetoothed.language as ntl
from ninetoothed import Tensor
from ninetoothed.language import libdevice

from ntops.kernels.element_wise import arrangement


BLOCK_SIZE = 8192


def application(input, output):
output = libdevice.lgamma(input) # noqa: F841


def half_application(input, output):
output = ntl.cast(libdevice.lgamma(ntl.cast(input, ntl.float32)), ntl.float16) # noqa: F841


def premake(ndim, half=False, dtype=None, block_size=BLOCK_SIZE):
arrangement_ = functools.partial(arrangement, block_size=block_size)

tensors = (Tensor(ndim, dtype=dtype), Tensor(ndim, dtype=dtype))

application_ = half_application if half else application

return arrangement_, application_, tensors
Loading