Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
231 changes: 231 additions & 0 deletions python/tvm/relax/frontend/tflite/tflite_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,7 @@ def __init__(self, model, subgraph, exp_tab, ctx, conversion_state=None):
"STABLEHLO_REDUCE": self._convert_stablehlo_reduce,
"STABLEHLO_REDUCE_WINDOW": self._convert_stablehlo_reduce_window,
"STABLEHLO_REMAINDER": self._convert_stablehlo_remainder,
"STABLEHLO_RNG_BIT_GENERATOR": self._convert_stablehlo_rng_bit_generator,
"STABLEHLO_RSQRT": functools.partial(self._convert_stablehlo_unary, relax_op=_op.rsqrt),
"STABLEHLO_SCATTER": self._convert_stablehlo_scatter,
"STABLEHLO_SELECT": functools.partial(
Expand Down Expand Up @@ -918,6 +919,8 @@ def get_tensor_type_as_numpy(self, tensor_wrapper):
TensorType.FLOAT32: np.float32,
TensorType.INT32: np.int32,
TensorType.INT64: np.int64,
TensorType.UINT32: np.uint32,
TensorType.UINT64: np.uint64,
TensorType.BOOL: np.bool_,
}[tensor_wrapper.tensor.Type()]

Expand Down Expand Up @@ -958,6 +961,10 @@ def get_tensor_type_str(self, tensor_type):
return "int32"
if tensor_type == TensorType.INT64:
return "int64"
if tensor_type == TensorType.UINT32:
return "uint32"
if tensor_type == TensorType.UINT64:
return "uint64"
if tensor_type == TensorType.BOOL:
return "bool"
raise NotImplementedError(f"Tensor type {tensor_type!s} is currently not supported")
Expand Down Expand Up @@ -2206,6 +2213,72 @@ def _convert_stablehlo_custom_call(self, op):
target = call_target_name or "<empty>"
raise tvm.error.OpNotImplemented(f"STABLEHLO_CUSTOM_CALL target {target} is not supported")

def _convert_stablehlo_rng_bit_generator(self, op):
    """Lower STABLEHLO_RNG_BIT_GENERATOR to a ``call_tir`` on a generated TIR kernel.

    The operator must have exactly one input (the initial RNG state) and two
    outputs (the updated state followed by the random bits). The state has to
    be a static 1-D uint64 tensor whose length is valid for the selected
    algorithm, and the random output must be a static-shape int32/int64/
    uint32/uint64 tensor.

    Parameters
    ----------
    op : the TFLite operator being converted.

    Returns
    -------
    The normalized ``call_tir`` expression whose struct info is the pair
    (output state, random output).

    Raises
    ------
    tvm.error.OpNotImplemented
        If the operand count, algorithm, state dtype/rank/length, output
        state dtype/shape, or output dtype is not supported.
    """
    from tflite.RngAlgorithm import RngAlgorithm
    from tflite.StablehloRngBitGeneratorOptions import StablehloRngBitGeneratorOptions

    op_name = "STABLEHLO_RNG_BIT_GENERATOR"
    inputs = self.get_input_tensors(op)
    outputs = self.get_output_tensors(op)
    if not (len(inputs) == 1 and len(outputs) == 2):
        raise tvm.error.OpNotImplemented(f"{op_name} expects one input and two outputs")

    algorithm_enum = self._get_stablehlo_options(op, StablehloRngBitGeneratorOptions).Algorithm()
    # DEFAULT is mapped to PHILOX, following the TFLite runtime kernel.
    algorithm = {
        RngAlgorithm.THREEFRY: "threefry",
        RngAlgorithm.PHILOX: "philox",
        RngAlgorithm.DEFAULT: "philox",
    }.get(algorithm_enum)
    if algorithm is None:
        raise tvm.error.OpNotImplemented(
            f"{op_name} algorithm {algorithm_enum} is not supported"
        )

    # --- Validate the initial state operand. ---
    state_tensor = inputs[0]
    state_dtype = self.get_tensor_type_str(state_tensor.tensor.Type())
    if state_dtype != "uint64":
        raise tvm.error.OpNotImplemented(f"{op_name} requires a uint64 initial state")
    state_shape = self._get_static_tensor_shape(state_tensor, op_name)
    if len(state_shape) != 1:
        raise tvm.error.OpNotImplemented(f"{op_name} requires a 1-D initial state")
    state_len = int(state_shape[0])
    # The accepted state lengths follow the TFLite runtime kernel.
    if algorithm == "threefry":
        if state_len != 2:
            raise tvm.error.OpNotImplemented(f"{op_name} THREEFRY requires a u64[2] state")
    elif state_len not in (2, 3):
        raise tvm.error.OpNotImplemented(f"{op_name} PHILOX requires a u64[2] or u64[3] state")

    # --- Validate both outputs: updated state first, random bits second. ---
    out_state_tensor = outputs[0]
    out_tensor = outputs[1]
    out_state_dtype = self.get_tensor_type_str(out_state_tensor.tensor.Type())
    if out_state_dtype != "uint64":
        raise tvm.error.OpNotImplemented(f"{op_name} output state must be uint64")
    out_state_shape = self._get_static_tensor_shape(out_state_tensor, op_name)
    if tuple(out_state_shape) != tuple(state_shape):
        raise tvm.error.OpNotImplemented(
            f"{op_name} output state shape must match the initial state"
        )
    out_dtype = self.get_tensor_type_str(out_tensor.tensor.Type())
    if out_dtype not in {"int32", "int64", "uint32", "uint64"}:
        raise tvm.error.OpNotImplemented(f"{op_name} output dtype {out_dtype} is not supported")
    out_shape = tuple(self._get_static_tensor_shape(out_tensor, op_name))

    # --- Generate the kernel, register it, and emit the call. ---
    prim_func = _build_stablehlo_rng_bit_generator_primfunc(
        algorithm, state_len, out_dtype, out_shape
    )
    module_builder = self.conversion_state["module_builder"]
    # The output-state tensor index is part of the registered function name.
    gv = module_builder.add_func(
        prim_func, f"tflite_stablehlo_rng_{algorithm}_{out_state_tensor.tensor_idx}"
    )
    state_expr = self.get_tensor_expr(state_tensor)
    result_sinfo = [
        relax.TensorStructInfo(tuple(state_shape), "uint64"),
        relax.TensorStructInfo(out_shape, out_dtype),
    ]
    return self.bb.normalize(relax.call_tir(gv, [state_expr], result_sinfo))

def _convert_stablehlo_while(self, op):
"""Convert STABLEHLO_WHILE to a recursive Relax private function."""
from tflite.StablehloWhileOptions import StablehloWhileOptions
Expand Down Expand Up @@ -7347,6 +7420,162 @@ def get_tensor_shape(self, tensor_wrapper):
)


# Constants for the Random123 counter-based PRNGs used by STABLEHLO_RNG_BIT_GENERATOR,
# matching tensorflow/lite/kernels/rng_util.cc.
# Threefry: XORed with the two 32-bit key words to form the third key word.
_STABLEHLO_RNG_THREEFRY_PARITY = 0x1BD11BDA
# Philox: multipliers applied to counter words 0 (A) and 2 (B) in every round.
_STABLEHLO_RNG_PHILOX_MUL_A = 0xD2511F53
_STABLEHLO_RNG_PHILOX_MUL_B = 0xCD9E8D57
# Philox: increments added to key words 0 (A) and 1 (B) after every round.
_STABLEHLO_RNG_PHILOX_WEYL_A = 0x9E3779B9
_STABLEHLO_RNG_PHILOX_WEYL_B = 0xBB67AE85


def _build_stablehlo_rng_bit_generator_primfunc(algorithm, state_len, out_dtype, out_shape):
"""Build a bit-exact TIR kernel for STABLEHLO_RNG_BIT_GENERATOR.

Mirrors the TFLite runtime kernel (tensorflow/lite/kernels/rng_bit_generator.cc),
implementing the Random123 Threefry2x32 (20 rounds) and Philox4x32 (10 rounds)
counter-based PRNGs. The kernel reinterprets the uint64 state as uint32 words,
advances a 64-bit block counter, and packs the generated words into the output
tensor. The updated state keeps the key unchanged and only advances the counter,
which is the only behaviour the runtime relies on.
"""
# Parameters:
#   algorithm : "threefry" selects the Threefry kernel; any other value takes
#               the Philox path further down.
#   state_len : length of the 1-D uint64 input and output state buffers.
#   out_dtype : dtype of the random output buffer; "int64"/"uint64" consume two
#               generated 32-bit words per element, anything else consumes one.
#   out_shape : shape of the random output buffer.
# Returns:
#   The `kernel` object produced by the `T.prim_func` decorator for the
#   selected algorithm.
# NOTE(review): the review text captured right after the import below disputes
# the `tirx` module path; which spelling is valid depends on the TVM revision
# being targeted -- confirm against that tree before changing it.
from tvm.script.parser import tirx as T
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

There is a typo in the import statement. tirx does not exist in tvm.script.parser. It should be imported from tvm.script as tir.

Suggested change
from tvm.script.parser import tirx as T
from tvm.script import tir as T


# Number of output elements: the product of all output dimensions.
total = 1
for dim in out_shape:
total *= int(dim)
is_64bit = out_dtype in ("int64", "uint64")
# 32-bit words produced per generator block: 2 for Threefry, 4 for Philox.
block_words = 2 if algorithm == "threefry" else 4
# Total 32-bit words needed to fill the output.
out_word_count = total * (2 if is_64bit else 1)
# Ceiling division: the last block may be only partially consumed.
num_blocks = (out_word_count + block_words - 1) // block_words
# Output elements written from each block.
writes_per_block = block_words // (2 if is_64bit else 1)
parity = _STABLEHLO_RNG_THREEFRY_PARITY
mul_a, mul_b = _STABLEHLO_RNG_PHILOX_MUL_A, _STABLEHLO_RNG_PHILOX_MUL_B
weyl_a, weyl_b = _STABLEHLO_RNG_PHILOX_WEYL_A, _STABLEHLO_RNG_PHILOX_WEYL_B

# Shorthand for casting an expression to uint32.
def _u32(value):
return T.Cast("uint32", value)

# Shorthand for casting an expression to uint64.
def _u64(value):
return T.Cast("uint64", value)

# Builds the expression for output element `write_index` of the current block
# from the uint32 word buffer `words`.
def _store_value(words, write_index):
# Pack the generated uint32 words into one output element, reinterpreting
# the bit pattern into the (possibly signed) output dtype.
if is_64bit:
# Even-indexed word is the low half, odd-indexed word the high half.
low = _u64(words[2 * write_index])
high = _u64(words[2 * write_index + 1])
return T.reinterpret(out_dtype, low | (high << T.uint64(32)))
return T.reinterpret(out_dtype, words[write_index])

if algorithm == "threefry":

# NOTE(review): the review text captured after this decorator claims that
# `s_tir` is not an accepted argument; that cannot be confirmed from this file
# and depends on the TVM revision -- verify before removing it.
@T.prim_func(private=True, s_tir=True)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The @T.prim_func decorator does not accept an s_tir argument. This will raise a TypeError at runtime. Please remove s_tir=True.

Suggested change
@T.prim_func(private=True, s_tir=True)
@T.prim_func(private=True)

def kernel(
initial_state: T.Buffer((state_len,), "uint64"),
output_state: T.Buffer((state_len,), "uint64"),
output: T.Buffer(out_shape, out_dtype),
):
# A single opaque structured block keeps the imperative kernel as a
# well-formed block-structured PrimFunc, as required by the Relax
# pipeline (e.g. HasReshapePattern).
with T.sblock("rng_bit_generator"):
# State word 0 is the 64-bit key, word 1 the 64-bit block counter.
state_key = initial_state[0]
state_counter = initial_state[1]
# Split the key into its low and high 32-bit halves.
key_0 = _u32(state_key & T.uint64(0xFFFFFFFF))
key_1 = _u32(state_key >> T.uint64(32))
# Output state: key unchanged, counter advanced by the blocks consumed.
output_state[0] = state_key
output_state[1] = state_counter + T.uint64(num_blocks)
# Flat 1-D view over the output buffer's data, indexed by element number.
out_flat = T.decl_buffer((total,), out_dtype, data=output.data)
# Scratch buffers declared with local scope.
keys = T.decl_buffer((3,), "uint32", scope="local")
rotations = T.decl_buffer((8,), "uint32", scope="local")
ctr = T.decl_buffer((2,), "uint32", scope="local")
keys[0] = key_0
keys[1] = key_1
# Third key word: both key halves XORed with the parity constant.
keys[2] = key_0 ^ key_1 ^ T.uint32(parity)
# Rotation amounts, used cyclically (index taken modulo 8 below).
rotations[0] = T.uint32(13)
rotations[1] = T.uint32(15)
rotations[2] = T.uint32(26)
rotations[3] = T.uint32(6)
rotations[4] = T.uint32(17)
rotations[5] = T.uint32(29)
rotations[6] = T.uint32(16)
rotations[7] = T.uint32(24)
for block in T.serial(num_blocks):
# Per-block counter, split into two words with the key halves added in.
counter = state_counter + _u64(block)
ctr[0] = _u32(counter & T.uint64(0xFFFFFFFF)) + key_0
ctr[1] = _u32(counter >> T.uint64(32)) + key_1
# 5 groups of 4 mixing steps = 20 rounds.
for group in T.serial(5):
for step in T.serial(4):
rot = rotations[(group * 4 + step) % 8]
# Mix step: add, rotate ctr[1] left by `rot` bits, then XOR.
ctr[0] = ctr[0] + ctr[1]
ctr[1] = (ctr[1] << rot) | (ctr[1] >> (T.uint32(32) - rot))
ctr[1] = ctr[1] ^ ctr[0]
# Key injection after each group; ctr[1] also gets the 1-based group number.
ctr[0] = ctr[0] + keys[(group + 1) % 3]
ctr[1] = ctr[1] + keys[(group + 2) % 3] + _u32(group + 1)
for write_index in T.serial(writes_per_block):
element = block * writes_per_block + write_index
# Guard: the final block may produce more values than remain to be written.
if element < total:
out_flat[element] = _store_value(ctr, write_index)

return kernel

# Philox path: reached for every algorithm value other than "threefry".
# NOTE(review): as with the Threefry decorator, the review text captured after
# this line disputes `s_tir=True`; verify against the targeted TVM revision.
@T.prim_func(private=True, s_tir=True)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The @T.prim_func decorator does not accept an s_tir argument. This will raise a TypeError at runtime. Please remove s_tir=True.

Suggested change
@T.prim_func(private=True, s_tir=True)
@T.prim_func(private=True)

def kernel(
initial_state: T.Buffer((state_len,), "uint64"),
output_state: T.Buffer((state_len,), "uint64"),
output: T.Buffer(out_shape, out_dtype),
):
with T.sblock("rng_bit_generator"):
# State word 0 is the 64-bit key, word 1 the 64-bit block counter.
state_key = initial_state[0]
state_counter = initial_state[1]
# Split the key into its low and high 32-bit halves.
key_0 = _u32(state_key & T.uint64(0xFFFFFFFF))
key_1 = _u32(state_key >> T.uint64(32))
# Output state: key unchanged, counter advanced by the blocks consumed.
output_state[0] = state_key
output_state[1] = state_counter + T.uint64(num_blocks)
# Flat 1-D view over the output buffer's data, indexed by element number.
out_flat = T.decl_buffer((total,), out_dtype, data=output.data)
# Scratch buffers declared with local scope.
ctr = T.decl_buffer((4,), "uint32", scope="local")
keys = T.decl_buffer((2,), "uint32", scope="local")
high_ctr = T.decl_buffer((2,), "uint32", scope="local")
# NOTE(review): state_len is a Python int captured from the builder, so this
# branch is presumably resolved when the script is parsed -- confirm.
if state_len == 3:
# PHILOX u64[3]: the third state word feeds the high counter and
# is passed through to the output state unchanged.
high_state = initial_state[2]
output_state[2] = high_state
high_ctr[0] = _u32(high_state & T.uint64(0xFFFFFFFF))
high_ctr[1] = _u32(high_state >> T.uint64(32))
else:
# Two-word state: the high counter words are seeded from the key halves.
high_ctr[0] = key_0
high_ctr[1] = key_1
for block in T.serial(num_blocks):
# Per-block counter: low two words from the 64-bit counter, high two from high_ctr.
counter = state_counter + _u64(block)
ctr[0] = _u32(counter & T.uint64(0xFFFFFFFF))
ctr[1] = _u32(counter >> T.uint64(32))
ctr[2] = high_ctr[0]
ctr[3] = high_ctr[1]
# The round keys are reset from the state key at the start of every block.
keys[0] = key_0
keys[1] = key_1
for _round in T.serial(10):
# 64-bit products of the multipliers with counter words 0 and 2.
prod_0 = T.uint64(mul_a) * _u64(ctr[0])
prod_1 = T.uint64(mul_b) * _u64(ctr[2])
# High product halves are XORed with a counter word and a key word;
# low product halves pass through unchanged.
new_0 = _u32(prod_1 >> T.uint64(32)) ^ ctr[1] ^ keys[0]
new_1 = _u32(prod_1 & T.uint64(0xFFFFFFFF))
new_2 = _u32(prod_0 >> T.uint64(32)) ^ ctr[3] ^ keys[1]
new_3 = _u32(prod_0 & T.uint64(0xFFFFFFFF))
ctr[0] = new_0
ctr[1] = new_1
ctr[2] = new_2
ctr[3] = new_3
# Advance both key words by their constants for the next round.
keys[0] = keys[0] + T.uint32(weyl_a)
keys[1] = keys[1] + T.uint32(weyl_b)
for write_index in T.serial(writes_per_block):
element = block * writes_per_block + write_index
# Guard: the final block may produce more values than remain to be written.
if element < total:
out_flat[element] = _store_value(ctr, write_index)

return kernel


# pylint: disable=no-else-return
def prepare_dense_matrix_from_sparse(sparse_tensor, sparse_tensor_value, sparse_tensor_type):
"""Prepare sparse indices and dense matrix from TFLite sparse parameters."""
Expand Down Expand Up @@ -7593,6 +7822,8 @@ def _decode_type(n):
7: "int16",
8: "complex64",
9: "int8",
12: "uint64",
15: "uint32",
}
return _tflite_m[n]

Expand Down
Loading
Loading