From 7ad667c4aea8cac9a01634d2ca031d0db140120c Mon Sep 17 00:00:00 2001 From: Aharrypotter Date: Sun, 31 May 2026 19:33:21 +0800 Subject: [PATCH 1/2] [Relax][Frontend][TFLite] Support STABLEHLO_RNG_BIT_GENERATOR Add Relax TFLite frontend support for the `STABLEHLO_RNG_BIT_GENERATOR` builtin operator, lowering it to a bit-exact `call_tir` PRNG kernel. The TFLite runtime kernel (tensorflow/lite/kernels/rng_bit_generator.cc) is a real, deterministic counter-based PRNG, not a thin StableHLO shim: - one uint64 1-D `initial_state` input, two outputs (uint64 `output_state` plus the random-bit `output` in int32/int64/uint32/uint64); - `algorithm` in {DEFAULT, PHILOX, THREEFRY}, where DEFAULT resolves to PHILOX in the runtime; - bit-exact Random123 Threefry2x32 (20 rounds) and Philox4x32 (10 rounds) with fixed constants from rng_util.cc; - state-length constraints: THREEFRY requires u64[2], PHILOX/DEFAULT require u64[2] or u64[3]. Since TVM/Relax has no matching RNG primitive, the converter generates a TIR kernel that mirrors the runtime exactly and emits it via `call_tir`. The kernel reinterprets the uint64 state as uint32 words, advances a 64-bit block counter, runs the selected algorithm per block with all round state materialized into local buffers (avoiding exponential expression blow-up), and packs the generated words back into the output dtype. The updated state keeps the key unchanged and only advances the counter, which is the only state behaviour the runtime relies on. The kernel is a `s_tir` PrimFunc wrapped in a single opaque structured block so it stays a well-formed block-structured function for the Relax pipeline (e.g. HasReshapePattern). Unsupported cases raise a precise OpNotImplemented: non-uint64 state, non-1-D state, mismatched output-state shape, unsupported output dtype, unknown algorithm, and per-algorithm state-length violations. Also extend `get_tensor_type_str` and the input `_decode_type` map with uint32/uint64 so uint64 state tensors import correctly. Tests build minimal RNG flatbuffers, compile, and execute them, checking output and updated state against the verbatim expected vectors from the TFLite runtime kernel test (rng_bit_generator_test.cc) for both algorithms and all four output dtypes, plus DEFAULT==PHILOX, run-to-run determinism, and the rejection paths. Refs #19519 (item I: remaining StableHLO operators in TFLite). --- .../relax/frontend/tflite/tflite_frontend.py | 228 ++++++++++++++++++ tests/python/relax/test_frontend_tflite.py | 211 ++++++++++++++++ 2 files changed, 439 insertions(+) diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py b/python/tvm/relax/frontend/tflite/tflite_frontend.py index 2a4455eb30bb..ca1ff8d6c19f 100644 --- a/python/tvm/relax/frontend/tflite/tflite_frontend.py +++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py @@ -375,6 +375,7 @@ def __init__(self, model, subgraph, exp_tab, ctx, conversion_state=None): "STABLEHLO_REDUCE": self._convert_stablehlo_reduce, "STABLEHLO_REDUCE_WINDOW": self._convert_stablehlo_reduce_window, "STABLEHLO_REMAINDER": self._convert_stablehlo_remainder, + "STABLEHLO_RNG_BIT_GENERATOR": self._convert_stablehlo_rng_bit_generator, "STABLEHLO_RSQRT": functools.partial(self._convert_stablehlo_unary, relax_op=_op.rsqrt), "STABLEHLO_SCATTER": self._convert_stablehlo_scatter, "STABLEHLO_SELECT": functools.partial( @@ -958,6 +959,10 @@ def get_tensor_type_str(self, tensor_type): return "int32" if tensor_type == TensorType.INT64: return "int64" + if tensor_type == TensorType.UINT32: + return "uint32" + if tensor_type == TensorType.UINT64: + return "uint64" if tensor_type == TensorType.BOOL: return "bool" raise NotImplementedError(f"Tensor type {tensor_type!s} is currently not supported") @@ -2206,6 +2211,72 @@ def _convert_stablehlo_custom_call(self, op): target = call_target_name or "" raise tvm.error.OpNotImplemented(f"STABLEHLO_CUSTOM_CALL target {target} is not supported") + def _convert_stablehlo_rng_bit_generator(self, op): + """Convert STABLEHLO_RNG_BIT_GENERATOR to a bit-exact call_tir kernel.""" + from tflite.RngAlgorithm import RngAlgorithm + from tflite.StablehloRngBitGeneratorOptions import StablehloRngBitGeneratorOptions + + op_name = "STABLEHLO_RNG_BIT_GENERATOR" + input_tensors = self.get_input_tensors(op) + output_tensors = self.get_output_tensors(op) + if len(input_tensors) != 1 or len(output_tensors) != 2: + raise tvm.error.OpNotImplemented(f"{op_name} expects one input and two outputs") + + opts = self._get_stablehlo_options(op, StablehloRngBitGeneratorOptions) + algorithm_enum = opts.Algorithm() + # DEFAULT resolves to PHILOX in the TFLite runtime kernel. + if algorithm_enum == RngAlgorithm.THREEFRY: + algorithm = "threefry" + elif algorithm_enum in (RngAlgorithm.PHILOX, RngAlgorithm.DEFAULT): + algorithm = "philox" + else: + raise tvm.error.OpNotImplemented( + f"{op_name} algorithm {algorithm_enum} is not supported" + ) + + state_tensor = input_tensors[0] + if self.get_tensor_type_str(state_tensor.tensor.Type()) != "uint64": + raise tvm.error.OpNotImplemented(f"{op_name} requires a uint64 initial state") + state_shape = self._get_static_tensor_shape(state_tensor, op_name) + if len(state_shape) != 1: + raise tvm.error.OpNotImplemented(f"{op_name} requires a 1-D initial state") + state_len = int(state_shape[0]) + # State-length constraints mirror the TFLite runtime kernel. + if algorithm == "threefry" and state_len != 2: + raise tvm.error.OpNotImplemented(f"{op_name} THREEFRY requires a u64[2] state") + if algorithm == "philox" and state_len not in (2, 3): + raise tvm.error.OpNotImplemented(f"{op_name} PHILOX requires a u64[2] or u64[3] state") + + out_state_tensor, out_tensor = output_tensors + if self.get_tensor_type_str(out_state_tensor.tensor.Type()) != "uint64": + raise tvm.error.OpNotImplemented(f"{op_name} output state must be uint64") + out_state_shape = self._get_static_tensor_shape(out_state_tensor, op_name) + if list(out_state_shape) != list(state_shape): + raise tvm.error.OpNotImplemented( + f"{op_name} output state shape must match the initial state" + ) + out_dtype = self.get_tensor_type_str(out_tensor.tensor.Type()) + if out_dtype not in ("int32", "int64", "uint32", "uint64"): + raise tvm.error.OpNotImplemented(f"{op_name} output dtype {out_dtype} is not supported") + out_shape = tuple(self._get_static_tensor_shape(out_tensor, op_name)) + + prim_func = _build_stablehlo_rng_bit_generator_primfunc( + algorithm, state_len, out_dtype, out_shape + ) + module_builder = self.conversion_state["module_builder"] + func_name = f"tflite_stablehlo_rng_{algorithm}_{out_state_tensor.tensor_idx}" + gv = module_builder.add_func(prim_func, func_name) + state_expr = self.get_tensor_expr(state_tensor) + call = relax.call_tir( + gv, + [state_expr], + [ + relax.TensorStructInfo(tuple(state_shape), "uint64"), + relax.TensorStructInfo(out_shape, out_dtype), + ], + ) + return self.bb.normalize(call) + def _convert_stablehlo_while(self, op): """Convert STABLEHLO_WHILE to a recursive Relax private function.""" from tflite.StablehloWhileOptions import StablehloWhileOptions @@ -7347,6 +7418,161 @@ def get_tensor_shape(self, tensor_wrapper): ) +# Constants for the Random123 counter-based PRNGs used by STABLEHLO_RNG_BIT_GENERATOR, +# matching tensorflow/lite/kernels/rng_util.cc. +_STABLEHLO_RNG_THREEFRY_PARITY = 0x1BD11BDA +_STABLEHLO_RNG_PHILOX_MUL_A = 0xD2511F53 +_STABLEHLO_RNG_PHILOX_MUL_B = 0xCD9E8D57 +_STABLEHLO_RNG_PHILOX_WEYL_A = 0x9E3779B9 +_STABLEHLO_RNG_PHILOX_WEYL_B = 0xBB67AE85 + + +def _build_stablehlo_rng_bit_generator_primfunc(algorithm, state_len, out_dtype, out_shape): + """Build a bit-exact TIR kernel for STABLEHLO_RNG_BIT_GENERATOR. + + Mirrors the TFLite runtime kernel (tensorflow/lite/kernels/rng_bit_generator.cc), + implementing the Random123 Threefry2x32 (20 rounds) and Philox4x32 (10 rounds) + counter-based PRNGs. The kernel reinterprets the uint64 state as uint32 words, + advances a 64-bit block counter, and packs the generated words into the output + tensor. The updated state keeps the key unchanged and only advances the counter, + which is the only behaviour the runtime relies on. + """ + from tvm.script.parser import tirx as T + + total = 1 + for dim in out_shape: + total *= int(dim) + is_64bit = out_dtype in ("int64", "uint64") + block_words = 2 if algorithm == "threefry" else 4 + out_word_count = total * (2 if is_64bit else 1) + num_blocks = (out_word_count + block_words - 1) // block_words + writes_per_block = block_words // (2 if is_64bit else 1) + parity = _STABLEHLO_RNG_THREEFRY_PARITY + mul_a, mul_b = _STABLEHLO_RNG_PHILOX_MUL_A, _STABLEHLO_RNG_PHILOX_MUL_B + weyl_a, weyl_b = _STABLEHLO_RNG_PHILOX_WEYL_A, _STABLEHLO_RNG_PHILOX_WEYL_B + + def _u32(value): + return T.Cast("uint32", value) + + def _u64(value): + return T.Cast("uint64", value) + + def _store_value(words, write_index): + # Pack the generated uint32 words into one output element, reinterpreting + # the bit pattern into the (possibly signed) output dtype. + if is_64bit: + low = _u64(words[2 * write_index]) + high = _u64(words[2 * write_index + 1]) + return T.reinterpret(out_dtype, low | (high << T.uint64(32))) + return T.reinterpret(out_dtype, words[write_index]) + + if algorithm == "threefry": + + @T.prim_func(private=True, s_tir=True) + def kernel( + initial_state: T.Buffer((state_len,), "uint64"), + output_state: T.Buffer((state_len,), "uint64"), + output: T.Buffer(out_shape, out_dtype), + ): + # A single opaque structured block keeps the imperative kernel as a + # well-formed block-structured PrimFunc, as required by the Relax + # pipeline (e.g. HasReshapePattern). + with T.sblock("rng_bit_generator"): + state_key = initial_state[0] + state_counter = initial_state[1] + key_0 = _u32(state_key & T.uint64(0xFFFFFFFF)) + key_1 = _u32(state_key >> T.uint64(32)) + output_state[0] = state_key + output_state[1] = state_counter + T.uint64(num_blocks) + out_flat = T.decl_buffer((total,), out_dtype, data=output.data) + keys = T.decl_buffer((3,), "uint32", scope="local") + rotations = T.decl_buffer((8,), "uint32", scope="local") + ctr = T.decl_buffer((2,), "uint32", scope="local") + keys[0] = key_0 + keys[1] = key_1 + keys[2] = key_0 ^ key_1 ^ T.uint32(parity) + rotations[0] = T.uint32(13) + rotations[1] = T.uint32(15) + rotations[2] = T.uint32(26) + rotations[3] = T.uint32(6) + rotations[4] = T.uint32(17) + rotations[5] = T.uint32(29) + rotations[6] = T.uint32(16) + rotations[7] = T.uint32(24) + for block in T.serial(num_blocks): + counter = state_counter + _u64(block) + ctr[0] = _u32(counter & T.uint64(0xFFFFFFFF)) + key_0 + ctr[1] = _u32(counter >> T.uint64(32)) + key_1 + for group in T.serial(5): + for step in T.serial(4): + rot = rotations[(group * 4 + step) % 8] + ctr[0] = ctr[0] + ctr[1] + ctr[1] = (ctr[1] << rot) | (ctr[1] >> (T.uint32(32) - rot)) + ctr[1] = ctr[1] ^ ctr[0] + ctr[0] = ctr[0] + keys[(group + 1) % 3] + ctr[1] = ctr[1] + keys[(group + 2) % 3] + _u32(group + 1) + for write_index in T.serial(writes_per_block): + element = block * writes_per_block + write_index + if element < total: + out_flat[element] = _store_value(ctr, write_index) + + return kernel + + @T.prim_func(private=True, s_tir=True) + def kernel( + initial_state: T.Buffer((state_len,), "uint64"), + output_state: T.Buffer((state_len,), "uint64"), + output: T.Buffer(out_shape, out_dtype), + ): + with T.sblock("rng_bit_generator"): + state_key = initial_state[0] + state_counter = initial_state[1] + key_0 = _u32(state_key & T.uint64(0xFFFFFFFF)) + key_1 = _u32(state_key >> T.uint64(32)) + output_state[0] = state_key + output_state[1] = state_counter + T.uint64(num_blocks) + for tail in T.serial(state_len - 2): + output_state[tail + 2] = initial_state[tail + 2] + out_flat = T.decl_buffer((total,), out_dtype, data=output.data) + ctr = T.decl_buffer((4,), "uint32", scope="local") + keys = T.decl_buffer((2,), "uint32", scope="local") + high_ctr = T.decl_buffer((2,), "uint32", scope="local") + if state_len == 3: + high_state = initial_state[2] + high_ctr[0] = _u32(high_state & T.uint64(0xFFFFFFFF)) + high_ctr[1] = _u32(high_state >> T.uint64(32)) + else: + high_ctr[0] = key_0 + high_ctr[1] = key_1 + for block in T.serial(num_blocks): + counter = state_counter + _u64(block) + ctr[0] = _u32(counter & T.uint64(0xFFFFFFFF)) + ctr[1] = _u32(counter >> T.uint64(32)) + ctr[2] = high_ctr[0] + ctr[3] = high_ctr[1] + keys[0] = key_0 + keys[1] = key_1 + for _round in T.serial(10): + prod_0 = T.uint64(mul_a) * _u64(ctr[0]) + prod_1 = T.uint64(mul_b) * _u64(ctr[2]) + new_0 = _u32(prod_1 >> T.uint64(32)) ^ ctr[1] ^ keys[0] + new_1 = _u32(prod_1 & T.uint64(0xFFFFFFFF)) + new_2 = _u32(prod_0 >> T.uint64(32)) ^ ctr[3] ^ keys[1] + new_3 = _u32(prod_0 & T.uint64(0xFFFFFFFF)) + ctr[0] = new_0 + ctr[1] = new_1 + ctr[2] = new_2 + ctr[3] = new_3 + keys[0] = keys[0] + T.uint32(weyl_a) + keys[1] = keys[1] + T.uint32(weyl_b) + for write_index in T.serial(writes_per_block): + element = block * writes_per_block + write_index + if element < total: + out_flat[element] = _store_value(ctr, write_index) + + return kernel + + # pylint: disable=no-else-return def prepare_dense_matrix_from_sparse(sparse_tensor, sparse_tensor_value, sparse_tensor_type): """Prepare sparse indices and dense matrix from TFLite sparse parameters.""" @@ -7593,6 +7819,8 @@ def _decode_type(n): 7: "int16", 8: "complex64", 9: "int8", + 12: "uint64", + 15: "uint32", } return _tflite_m[n] diff --git a/tests/python/relax/test_frontend_tflite.py b/tests/python/relax/test_frontend_tflite.py index 7c3e526d99a2..4a4f363dd807 100644 --- a/tests/python/relax/test_frontend_tflite.py +++ b/tests/python/relax/test_frontend_tflite.py @@ -3697,6 +3697,7 @@ def _get_tflite_schema_enum(enum_name): _tfl_stablehlo_scatter_opts = _get_tflite_schema_module("StablehloScatterOptions") _tfl_stablehlo_sort_opts = _get_tflite_schema_module("StablehloSortOptions") _tfl_stablehlo_while_opts = _get_tflite_schema_module("StablehloWhileOptions") +_tfl_stablehlo_rng_opts = _get_tflite_schema_module("StablehloRngBitGeneratorOptions") _tfl_call_options = _get_tflite_schema_module("CallOptions") _tfl_call_once_options = _get_tflite_schema_module("CallOnceOptions") _tfl_dimension_metadata = _get_tflite_schema_module("DimensionMetadata") @@ -3721,6 +3722,7 @@ def _get_tflite_schema_enum(enum_name): _tfl_padding = _get_tflite_schema_enum("Padding") _tfl_sparse_index_vector = _get_tflite_schema_enum("SparseIndexVector") _tfl_tensor_type = _get_tflite_schema_enum("TensorType") +_tfl_rng_algorithm = _get_tflite_schema_enum("RngAlgorithm") _tfl_lstm_options = _get_tflite_schema_module("LSTMOptions") _tfl_sequence_rnn_options = _get_tflite_schema_module("SequenceRNNOptions") @@ -6926,6 +6928,215 @@ def test_stablehlo_options_missing_payload_unsupported(): _load_model_from_buffer(buf) +def _build_stablehlo_rng_model(algorithm, state_len, out_shape, out_tensor_type): + """Build a STABLEHLO_RNG_BIT_GENERATOR model with a uint64 state input.""" + builder = flatbuffers.Builder(1024) + + _tfl_stablehlo_rng_opts.StablehloRngBitGeneratorOptionsStart(builder) + _tfl_stablehlo_rng_opts.StablehloRngBitGeneratorOptionsAddAlgorithm(builder, algorithm) + rng_opts = _tfl_stablehlo_rng_opts.StablehloRngBitGeneratorOptionsEnd(builder) + + rng_builtin = _get_stablehlo_builtin_operator("STABLEHLO_RNG_BIT_GENERATOR") + rng_code = _build_operator_code(builder, rng_builtin) + + main_tensors = [ + _build_tensor(builder, 0, [state_len], tensor_type=_tfl_tensor_type.UINT64), + _build_tensor(builder, 1, [state_len], tensor_type=_tfl_tensor_type.UINT64), + _build_tensor(builder, 2, list(out_shape), tensor_type=out_tensor_type), + ] + rng_op = _build_operator( + builder, + 0, + [0], + [1, 2], + builtin_options2_type=_tfl_builtin_options2.StablehloRngBitGeneratorOptions, + builtin_options2=rng_opts, + ) + main_subgraph = _build_subgraph( + builder, + tensors=main_tensors, + operators=[rng_op], + inputs=[0], + outputs=[1, 2], + ) + + buffers = [_build_buffer(builder) for _ in range(3)] + return _finish_tflite_model( + builder, + subgraph=main_subgraph, + operator_codes=[rng_code], + buffers=buffers, + ) + + +def _run_stablehlo_rng_model(algorithm, state_len, out_shape, out_tensor_type, init_state): + """Import, compile, and execute an RNG model, returning (output_state, output).""" + buf = _build_stablehlo_rng_model(algorithm, state_len, out_shape, out_tensor_type) + mod = _load_model_from_buffer(buf) + ex = tvm.compile(mod, tvm.target.Target("llvm")) + vm = relax.VirtualMachine(ex, tvm.cpu()) + result = vm["main"](tvm.runtime.tensor(np.array(init_state, dtype="uint64"))) + return result[0].numpy(), result[1].numpy() + + +# Expected vectors are taken verbatim from the TFLite runtime kernel test +# (tensorflow/lite/kernels/rng_bit_generator_test.cc), guaranteeing bit-exact parity. +_RNG_THREEFRY_EXPECTED = { + "int32": [43444564, -2144348869, -315321645, -549236733, 1672743891, -54463903], + "uint32": [43444564, 2150618427, 3979645651, 3745730563, 1672743891, 4240503393], + "int64": [ + -9209908263526143660, + -2358953802017238317, + -233920680524772397, + 2658481902456610144, + -2022031683723149139, + -2324041912354448873, + ], + "uint64": [ + 9236835810183407956, + 16087790271692313299, + 18212823393184779219, + 2658481902456610144, + 16424712389986402477, + 16122702161355102743, + ], +} +_RNG_THREEFRY_STATE = {"int32": [1, 5], "uint32": [1, 5], "int64": [1, 8], "uint64": [1, 8]} +_RNG_PHILOX_EXPECTED = { + "int32": [-263854262, 1366700262, 495645701, -1243243882, 89414891, 1917262711], + "uint32": [4031113034, 1366700262, 495645701, 3051723414, 89414891, 1917262711], + "int64": [ + 5869932932755744586, + -5339691813646437371, + 8234580641674714347, + 2641225993340350124, + 1962472297844690804, + -3580856229565614135, + ], + "uint64": [ + 5869932932755744586, + 13107052260063114245, + 8234580641674714347, + 2641225993340350124, + 1962472297844690804, + 14865887844143937481, + ], +} +_RNG_PHILOX_STATE = { + "int32": [1, 4, 3], + "uint32": [1, 4, 3], + "int64": [1, 5, 3], + "uint64": [1, 5, 3], +} + + +@pytest.mark.parametrize( + "out_dtype,out_tensor_type", + [ + ("int32", _tfl_tensor_type.INT32), + ("uint32", _tfl_tensor_type.UINT32), + ("int64", _tfl_tensor_type.INT64), + ("uint64", _tfl_tensor_type.UINT64), + ], +) +def test_stablehlo_rng_bit_generator_threefry(out_dtype, out_tensor_type): + """TFLite STABLEHLO_RNG_BIT_GENERATOR THREEFRY matches the runtime kernel bit-exactly.""" + state, output = _run_stablehlo_rng_model( + _tfl_rng_algorithm.THREEFRY, 2, [2, 3], out_tensor_type, [1, 2] + ) + assert output.flatten().tolist() == _RNG_THREEFRY_EXPECTED[out_dtype] + assert state.tolist() == _RNG_THREEFRY_STATE[out_dtype] + + +@pytest.mark.parametrize( + "out_dtype,out_tensor_type", + [ + ("int32", _tfl_tensor_type.INT32), + ("uint32", _tfl_tensor_type.UINT32), + ("int64", _tfl_tensor_type.INT64), + ("uint64", _tfl_tensor_type.UINT64), + ], +) +def test_stablehlo_rng_bit_generator_philox(out_dtype, out_tensor_type): + """TFLite STABLEHLO_RNG_BIT_GENERATOR PHILOX matches the runtime kernel bit-exactly.""" + state, output = _run_stablehlo_rng_model( + _tfl_rng_algorithm.PHILOX, 3, [2, 3], out_tensor_type, [1, 2, 3] + ) + assert output.flatten().tolist() == _RNG_PHILOX_EXPECTED[out_dtype] + assert state.tolist() == _RNG_PHILOX_STATE[out_dtype] + + +def test_stablehlo_rng_bit_generator_default_matches_philox(): + """TFLite STABLEHLO_RNG_BIT_GENERATOR DEFAULT resolves to the PHILOX algorithm.""" + state, output = _run_stablehlo_rng_model( + _tfl_rng_algorithm.DEFAULT, 3, [2, 3], _tfl_tensor_type.INT32, [1, 2, 3] + ) + assert output.flatten().tolist() == _RNG_PHILOX_EXPECTED["int32"] + assert state.tolist() == _RNG_PHILOX_STATE["int32"] + + +def test_stablehlo_rng_bit_generator_deterministic(): + """Re-running the imported RNG kernel yields identical bit-exact output.""" + buf = _build_stablehlo_rng_model(_tfl_rng_algorithm.PHILOX, 3, [3, 3], _tfl_tensor_type.INT32) + mod = _load_model_from_buffer(buf) + ex = tvm.compile(mod, tvm.target.Target("llvm")) + vm = relax.VirtualMachine(ex, tvm.cpu()) + init = tvm.runtime.tensor(np.array([7, 8, 9], dtype="uint64")) + first = vm["main"](init) + second = vm["main"](init) + np.testing.assert_equal(first[1].numpy(), second[1].numpy()) + np.testing.assert_equal(first[0].numpy(), second[0].numpy()) + + +def test_stablehlo_rng_bit_generator_unsupported_output_dtype(): + """TFLite STABLEHLO_RNG_BIT_GENERATOR rejects non-integer output dtypes.""" + buf = _build_stablehlo_rng_model(_tfl_rng_algorithm.PHILOX, 3, [2, 3], _tfl_tensor_type.FLOAT32) + with pytest.raises(tvm.error.OpNotImplemented, match="output dtype float32 is not supported"): + _load_model_from_buffer(buf) + + +def test_stablehlo_rng_bit_generator_threefry_invalid_state_unsupported(): + """TFLite STABLEHLO_RNG_BIT_GENERATOR rejects a u64[3] state for THREEFRY.""" + buf = _build_stablehlo_rng_model(_tfl_rng_algorithm.THREEFRY, 3, [2, 3], _tfl_tensor_type.INT32) + with pytest.raises(tvm.error.OpNotImplemented, match="THREEFRY requires a u64.2. state"): + _load_model_from_buffer(buf) + + +def test_stablehlo_rng_bit_generator_non_uint64_state_unsupported(): + """TFLite STABLEHLO_RNG_BIT_GENERATOR rejects a non-uint64 initial state.""" + builder = flatbuffers.Builder(1024) + _tfl_stablehlo_rng_opts.StablehloRngBitGeneratorOptionsStart(builder) + _tfl_stablehlo_rng_opts.StablehloRngBitGeneratorOptionsAddAlgorithm( + builder, _tfl_rng_algorithm.PHILOX + ) + rng_opts = _tfl_stablehlo_rng_opts.StablehloRngBitGeneratorOptionsEnd(builder) + rng_code = _build_operator_code( + builder, _get_stablehlo_builtin_operator("STABLEHLO_RNG_BIT_GENERATOR") + ) + tensors = [ + _build_tensor(builder, 0, [2], tensor_type=_tfl_tensor_type.INT64), + _build_tensor(builder, 1, [2], tensor_type=_tfl_tensor_type.INT64), + _build_tensor(builder, 2, [2, 3], tensor_type=_tfl_tensor_type.INT32), + ] + rng_op = _build_operator( + builder, + 0, + [0], + [1, 2], + builtin_options2_type=_tfl_builtin_options2.StablehloRngBitGeneratorOptions, + builtin_options2=rng_opts, + ) + subgraph = _build_subgraph( + builder, tensors=tensors, operators=[rng_op], inputs=[0], outputs=[1, 2] + ) + buffers = [_build_buffer(builder) for _ in range(3)] + buf = _finish_tflite_model( + builder, subgraph=subgraph, operator_codes=[rng_code], buffers=buffers + ) + with pytest.raises(tvm.error.OpNotImplemented, match="requires a uint64 initial state"): + _load_model_from_buffer(buf) + + def test_stablehlo_while(): """TFLite STABLEHLO_WHILE lowers to a recursive Relax private function.""" mod = _load_model_from_buffer(_build_stablehlo_while_model()) From d2ac81ac567268161f99453a072fb84424e17c6b Mon Sep 17 00:00:00 2001 From: Aharrypotter Date: Mon, 1 Jun 2026 13:08:07 +0800 Subject: [PATCH 2/2] [Relax][Frontend][TFLite] Fix constant uint RNG state import and drop empty loop Follow-up to the STABLEHLO_RNG_BIT_GENERATOR support, addressing review feedback. Constant uint64/uint32 state import ----------------------------------- The initial commit taught `get_tensor_type_str` and the graph-input `_decode_type` map about uint32/uint64, but missed `get_tensor_type_as_numpy`, which decodes constant tensor buffers. As a result, a valid TFLite model whose `initial_state` is a constant tensor (rather than a graph input) reached `get_tensor_expr` -> `get_tensor_value` -> `get_tensor_type_as_numpy` and raised a `KeyError` on the UINT64/UINT32 tensor type, so the model could not import. The original tests only exercised graph-input state and therefore missed this path. Add `UINT32 -> np.uint32` and `UINT64 -> np.uint64` to `get_tensor_type_as_numpy`, and add `test_stablehlo_rng_bit_generator_constant_state`, which embeds a constant `u64[2]` THREEFRY state (no graph input), imports/compiles/executes the model, and checks the output and updated state stay bit-exact against the TFLite runtime vectors. The test also asserts `mod["main"]` has zero params to confirm the state is folded in as a constant. Philox kernel: remove zero-trip loop ------------------------------------ When `state_len == 2`, `for tail in T.serial(state_len - 2)` generated an empty (zero iteration) loop. Fold the `output_state[2]` pass-through into the existing `if state_len == 3:` branch that already derives the high counter words from the same third state word, removing the redundant loop and keeping the pass-through next to the logic that consumes it. Tests: 14 rng_bit_generator and 102 stablehlo cases pass; ruff clean. --- .../relax/frontend/tflite/tflite_frontend.py | 7 ++-- tests/python/relax/test_frontend_tflite.py | 33 ++++++++++++++++--- 2 files changed, 34 insertions(+), 6 deletions(-) diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py b/python/tvm/relax/frontend/tflite/tflite_frontend.py index ca1ff8d6c19f..23a34c728921 100644 --- a/python/tvm/relax/frontend/tflite/tflite_frontend.py +++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py @@ -919,6 +919,8 @@ def get_tensor_type_as_numpy(self, tensor_wrapper): TensorType.FLOAT32: np.float32, TensorType.INT32: np.int32, TensorType.INT64: np.int64, + TensorType.UINT32: np.uint32, + TensorType.UINT64: np.uint64, TensorType.BOOL: np.bool_, }[tensor_wrapper.tensor.Type()] @@ -7531,14 +7533,15 @@ def kernel( key_1 = _u32(state_key >> T.uint64(32)) output_state[0] = state_key output_state[1] = state_counter + T.uint64(num_blocks) - for tail in T.serial(state_len - 2): - output_state[tail + 2] = initial_state[tail + 2] out_flat = T.decl_buffer((total,), out_dtype, data=output.data) ctr = T.decl_buffer((4,), "uint32", scope="local") keys = T.decl_buffer((2,), "uint32", scope="local") high_ctr = T.decl_buffer((2,), "uint32", scope="local") if state_len == 3: + # PHILOX u64[3]: the third state word feeds the high counter and + # is passed through to the output state unchanged. high_state = initial_state[2] + output_state[2] = high_state high_ctr[0] = _u32(high_state & T.uint64(0xFFFFFFFF)) high_ctr[1] = _u32(high_state >> T.uint64(32)) else: diff --git a/tests/python/relax/test_frontend_tflite.py b/tests/python/relax/test_frontend_tflite.py index 4a4f363dd807..7d1ac4193181 100644 --- a/tests/python/relax/test_frontend_tflite.py +++ b/tests/python/relax/test_frontend_tflite.py @@ -6928,8 +6928,12 @@ def test_stablehlo_options_missing_payload_unsupported(): _load_model_from_buffer(buf) -def _build_stablehlo_rng_model(algorithm, state_len, out_shape, out_tensor_type): - """Build a STABLEHLO_RNG_BIT_GENERATOR model with a uint64 state input.""" +def _build_stablehlo_rng_model(algorithm, state_len, out_shape, out_tensor_type, const_state=None): + """Build a STABLEHLO_RNG_BIT_GENERATOR model. + + When ``const_state`` is provided, the uint64 initial state is embedded as a + constant tensor (no graph input); otherwise it is a graph input. + """ builder = flatbuffers.Builder(1024) _tfl_stablehlo_rng_opts.StablehloRngBitGeneratorOptionsStart(builder) @@ -6956,11 +6960,18 @@ def _build_stablehlo_rng_model(algorithm, state_len, out_shape, out_tensor_type) builder, tensors=main_tensors, operators=[rng_op], - inputs=[0], + inputs=[] if const_state is not None else [0], outputs=[1, 2], ) - buffers = [_build_buffer(builder) for _ in range(3)] + state_data = None + if const_state is not None: + state_data = np.array(const_state, dtype="uint64").tobytes() + buffers = [ + _build_buffer(builder, data=state_data), + _build_buffer(builder), + _build_buffer(builder), + ] return _finish_tflite_model( builder, subgraph=main_subgraph, @@ -7088,6 +7099,20 @@ def test_stablehlo_rng_bit_generator_deterministic(): np.testing.assert_equal(first[0].numpy(), second[0].numpy()) +def test_stablehlo_rng_bit_generator_constant_state(): + """A constant uint64 initial state imports and stays bit-exact (no graph input).""" + buf = _build_stablehlo_rng_model( + _tfl_rng_algorithm.THREEFRY, 2, [2, 3], _tfl_tensor_type.INT32, const_state=[1, 2] + ) + mod = _load_model_from_buffer(buf) + assert len(mod["main"].params) == 0 + ex = tvm.compile(mod, tvm.target.Target("llvm")) + vm = relax.VirtualMachine(ex, tvm.cpu()) + result = vm["main"]() + assert result[1].numpy().flatten().tolist() == _RNG_THREEFRY_EXPECTED["int32"] + assert result[0].numpy().tolist() == _RNG_THREEFRY_STATE["int32"] + + def test_stablehlo_rng_bit_generator_unsupported_output_dtype(): """TFLite STABLEHLO_RNG_BIT_GENERATOR rejects non-integer output dtypes.""" buf = _build_stablehlo_rng_model(_tfl_rng_algorithm.PHILOX, 3, [2, 3], _tfl_tensor_type.FLOAT32)