diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py b/python/tvm/relax/frontend/tflite/tflite_frontend.py index 2a4455eb30bb..23a34c728921 100644 --- a/python/tvm/relax/frontend/tflite/tflite_frontend.py +++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py @@ -375,6 +375,7 @@ def __init__(self, model, subgraph, exp_tab, ctx, conversion_state=None): "STABLEHLO_REDUCE": self._convert_stablehlo_reduce, "STABLEHLO_REDUCE_WINDOW": self._convert_stablehlo_reduce_window, "STABLEHLO_REMAINDER": self._convert_stablehlo_remainder, + "STABLEHLO_RNG_BIT_GENERATOR": self._convert_stablehlo_rng_bit_generator, "STABLEHLO_RSQRT": functools.partial(self._convert_stablehlo_unary, relax_op=_op.rsqrt), "STABLEHLO_SCATTER": self._convert_stablehlo_scatter, "STABLEHLO_SELECT": functools.partial( @@ -918,6 +919,8 @@ def get_tensor_type_as_numpy(self, tensor_wrapper): TensorType.FLOAT32: np.float32, TensorType.INT32: np.int32, TensorType.INT64: np.int64, + TensorType.UINT32: np.uint32, + TensorType.UINT64: np.uint64, TensorType.BOOL: np.bool_, }[tensor_wrapper.tensor.Type()] @@ -958,6 +961,10 @@ def get_tensor_type_str(self, tensor_type): return "int32" if tensor_type == TensorType.INT64: return "int64" + if tensor_type == TensorType.UINT32: + return "uint32" + if tensor_type == TensorType.UINT64: + return "uint64" if tensor_type == TensorType.BOOL: return "bool" raise NotImplementedError(f"Tensor type {tensor_type!s} is currently not supported") @@ -2206,6 +2213,72 @@ def _convert_stablehlo_custom_call(self, op): target = call_target_name or "" raise tvm.error.OpNotImplemented(f"STABLEHLO_CUSTOM_CALL target {target} is not supported") + def _convert_stablehlo_rng_bit_generator(self, op): + """Convert STABLEHLO_RNG_BIT_GENERATOR to a bit-exact call_tir kernel.""" + from tflite.RngAlgorithm import RngAlgorithm + from tflite.StablehloRngBitGeneratorOptions import StablehloRngBitGeneratorOptions + + op_name = "STABLEHLO_RNG_BIT_GENERATOR" + input_tensors = self.get_input_tensors(op) + output_tensors = self.get_output_tensors(op) + if len(input_tensors) != 1 or len(output_tensors) != 2: + raise tvm.error.OpNotImplemented(f"{op_name} expects one input and two outputs") + + opts = self._get_stablehlo_options(op, StablehloRngBitGeneratorOptions) + algorithm_enum = opts.Algorithm() + # DEFAULT resolves to PHILOX in the TFLite runtime kernel. + if algorithm_enum == RngAlgorithm.THREEFRY: + algorithm = "threefry" + elif algorithm_enum in (RngAlgorithm.PHILOX, RngAlgorithm.DEFAULT): + algorithm = "philox" + else: + raise tvm.error.OpNotImplemented( + f"{op_name} algorithm {algorithm_enum} is not supported" + ) + + state_tensor = input_tensors[0] + if self.get_tensor_type_str(state_tensor.tensor.Type()) != "uint64": + raise tvm.error.OpNotImplemented(f"{op_name} requires a uint64 initial state") + state_shape = self._get_static_tensor_shape(state_tensor, op_name) + if len(state_shape) != 1: + raise tvm.error.OpNotImplemented(f"{op_name} requires a 1-D initial state") + state_len = int(state_shape[0]) + # State-length constraints mirror the TFLite runtime kernel. + if algorithm == "threefry" and state_len != 2: + raise tvm.error.OpNotImplemented(f"{op_name} THREEFRY requires a u64[2] state") + if algorithm == "philox" and state_len not in (2, 3): + raise tvm.error.OpNotImplemented(f"{op_name} PHILOX requires a u64[2] or u64[3] state") + + out_state_tensor, out_tensor = output_tensors + if self.get_tensor_type_str(out_state_tensor.tensor.Type()) != "uint64": + raise tvm.error.OpNotImplemented(f"{op_name} output state must be uint64") + out_state_shape = self._get_static_tensor_shape(out_state_tensor, op_name) + if list(out_state_shape) != list(state_shape): + raise tvm.error.OpNotImplemented( + f"{op_name} output state shape must match the initial state" + ) + out_dtype = self.get_tensor_type_str(out_tensor.tensor.Type()) + if out_dtype not in ("int32", "int64", "uint32", "uint64"): + raise tvm.error.OpNotImplemented(f"{op_name} output dtype {out_dtype} is not supported") + out_shape = tuple(self._get_static_tensor_shape(out_tensor, op_name)) + + prim_func = _build_stablehlo_rng_bit_generator_primfunc( + algorithm, state_len, out_dtype, out_shape + ) + module_builder = self.conversion_state["module_builder"] + func_name = f"tflite_stablehlo_rng_{algorithm}_{out_state_tensor.tensor_idx}" + gv = module_builder.add_func(prim_func, func_name) + state_expr = self.get_tensor_expr(state_tensor) + call = relax.call_tir( + gv, + [state_expr], + [ + relax.TensorStructInfo(tuple(state_shape), "uint64"), + relax.TensorStructInfo(out_shape, out_dtype), + ], + ) + return self.bb.normalize(call) + def _convert_stablehlo_while(self, op): """Convert STABLEHLO_WHILE to a recursive Relax private function.""" from tflite.StablehloWhileOptions import StablehloWhileOptions @@ -7347,6 +7420,162 @@ def get_tensor_shape(self, tensor_wrapper): ) +# Constants for the Random123 counter-based PRNGs used by STABLEHLO_RNG_BIT_GENERATOR, +# matching tensorflow/lite/kernels/rng_util.cc. +_STABLEHLO_RNG_THREEFRY_PARITY = 0x1BD11BDA +_STABLEHLO_RNG_PHILOX_MUL_A = 0xD2511F53 +_STABLEHLO_RNG_PHILOX_MUL_B = 0xCD9E8D57 +_STABLEHLO_RNG_PHILOX_WEYL_A = 0x9E3779B9 +_STABLEHLO_RNG_PHILOX_WEYL_B = 0xBB67AE85 + + +def _build_stablehlo_rng_bit_generator_primfunc(algorithm, state_len, out_dtype, out_shape): + """Build a bit-exact TIR kernel for STABLEHLO_RNG_BIT_GENERATOR. + + Mirrors the TFLite runtime kernel (tensorflow/lite/kernels/rng_bit_generator.cc), + implementing the Random123 Threefry2x32 (20 rounds) and Philox4x32 (10 rounds) + counter-based PRNGs. The kernel reinterprets the uint64 state as uint32 words, + advances a 64-bit block counter, and packs the generated words into the output + tensor. The updated state keeps the key unchanged and only advances the counter, + which is the only behaviour the runtime relies on. + """ + from tvm.script.parser import tirx as T + + total = 1 + for dim in out_shape: + total *= int(dim) + is_64bit = out_dtype in ("int64", "uint64") + block_words = 2 if algorithm == "threefry" else 4 + out_word_count = total * (2 if is_64bit else 1) + num_blocks = (out_word_count + block_words - 1) // block_words + writes_per_block = block_words // (2 if is_64bit else 1) + parity = _STABLEHLO_RNG_THREEFRY_PARITY + mul_a, mul_b = _STABLEHLO_RNG_PHILOX_MUL_A, _STABLEHLO_RNG_PHILOX_MUL_B + weyl_a, weyl_b = _STABLEHLO_RNG_PHILOX_WEYL_A, _STABLEHLO_RNG_PHILOX_WEYL_B + + def _u32(value): + return T.Cast("uint32", value) + + def _u64(value): + return T.Cast("uint64", value) + + def _store_value(words, write_index): + # Pack the generated uint32 words into one output element, reinterpreting + # the bit pattern into the (possibly signed) output dtype. + if is_64bit: + low = _u64(words[2 * write_index]) + high = _u64(words[2 * write_index + 1]) + return T.reinterpret(out_dtype, low | (high << T.uint64(32))) + return T.reinterpret(out_dtype, words[write_index]) + + if algorithm == "threefry": + + @T.prim_func(private=True, s_tir=True) + def kernel( + initial_state: T.Buffer((state_len,), "uint64"), + output_state: T.Buffer((state_len,), "uint64"), + output: T.Buffer(out_shape, out_dtype), + ): + # A single opaque structured block keeps the imperative kernel as a + # well-formed block-structured PrimFunc, as required by the Relax + # pipeline (e.g. HasReshapePattern). + with T.sblock("rng_bit_generator"): + state_key = initial_state[0] + state_counter = initial_state[1] + key_0 = _u32(state_key & T.uint64(0xFFFFFFFF)) + key_1 = _u32(state_key >> T.uint64(32)) + output_state[0] = state_key + output_state[1] = state_counter + T.uint64(num_blocks) + out_flat = T.decl_buffer((total,), out_dtype, data=output.data) + keys = T.decl_buffer((3,), "uint32", scope="local") + rotations = T.decl_buffer((8,), "uint32", scope="local") + ctr = T.decl_buffer((2,), "uint32", scope="local") + keys[0] = key_0 + keys[1] = key_1 + keys[2] = key_0 ^ key_1 ^ T.uint32(parity) + rotations[0] = T.uint32(13) + rotations[1] = T.uint32(15) + rotations[2] = T.uint32(26) + rotations[3] = T.uint32(6) + rotations[4] = T.uint32(17) + rotations[5] = T.uint32(29) + rotations[6] = T.uint32(16) + rotations[7] = T.uint32(24) + for block in T.serial(num_blocks): + counter = state_counter + _u64(block) + ctr[0] = _u32(counter & T.uint64(0xFFFFFFFF)) + key_0 + ctr[1] = _u32(counter >> T.uint64(32)) + key_1 + for group in T.serial(5): + for step in T.serial(4): + rot = rotations[(group * 4 + step) % 8] + ctr[0] = ctr[0] + ctr[1] + ctr[1] = (ctr[1] << rot) | (ctr[1] >> (T.uint32(32) - rot)) + ctr[1] = ctr[1] ^ ctr[0] + ctr[0] = ctr[0] + keys[(group + 1) % 3] + ctr[1] = ctr[1] + keys[(group + 2) % 3] + _u32(group + 1) + for write_index in T.serial(writes_per_block): + element = block * writes_per_block + write_index + if element < total: + out_flat[element] = _store_value(ctr, write_index) + + return kernel + + @T.prim_func(private=True, s_tir=True) + def kernel( + initial_state: T.Buffer((state_len,), "uint64"), + output_state: T.Buffer((state_len,), "uint64"), + output: T.Buffer(out_shape, out_dtype), + ): + with T.sblock("rng_bit_generator"): + state_key = initial_state[0] + state_counter = initial_state[1] + key_0 = _u32(state_key & T.uint64(0xFFFFFFFF)) + key_1 = _u32(state_key >> T.uint64(32)) + output_state[0] = state_key + output_state[1] = state_counter + T.uint64(num_blocks) + out_flat = T.decl_buffer((total,), out_dtype, data=output.data) + ctr = T.decl_buffer((4,), "uint32", scope="local") + keys = T.decl_buffer((2,), "uint32", scope="local") + high_ctr = T.decl_buffer((2,), "uint32", scope="local") + if state_len == 3: + # PHILOX u64[3]: the third state word feeds the high counter and + # is passed through to the output state unchanged. + high_state = initial_state[2] + output_state[2] = high_state + high_ctr[0] = _u32(high_state & T.uint64(0xFFFFFFFF)) + high_ctr[1] = _u32(high_state >> T.uint64(32)) + else: + high_ctr[0] = key_0 + high_ctr[1] = key_1 + for block in T.serial(num_blocks): + counter = state_counter + _u64(block) + ctr[0] = _u32(counter & T.uint64(0xFFFFFFFF)) + ctr[1] = _u32(counter >> T.uint64(32)) + ctr[2] = high_ctr[0] + ctr[3] = high_ctr[1] + keys[0] = key_0 + keys[1] = key_1 + for _round in T.serial(10): + prod_0 = T.uint64(mul_a) * _u64(ctr[0]) + prod_1 = T.uint64(mul_b) * _u64(ctr[2]) + new_0 = _u32(prod_1 >> T.uint64(32)) ^ ctr[1] ^ keys[0] + new_1 = _u32(prod_1 & T.uint64(0xFFFFFFFF)) + new_2 = _u32(prod_0 >> T.uint64(32)) ^ ctr[3] ^ keys[1] + new_3 = _u32(prod_0 & T.uint64(0xFFFFFFFF)) + ctr[0] = new_0 + ctr[1] = new_1 + ctr[2] = new_2 + ctr[3] = new_3 + keys[0] = keys[0] + T.uint32(weyl_a) + keys[1] = keys[1] + T.uint32(weyl_b) + for write_index in T.serial(writes_per_block): + element = block * writes_per_block + write_index + if element < total: + out_flat[element] = _store_value(ctr, write_index) + + return kernel + + # pylint: disable=no-else-return def prepare_dense_matrix_from_sparse(sparse_tensor, sparse_tensor_value, sparse_tensor_type): """Prepare sparse indices and dense matrix from TFLite sparse parameters.""" @@ -7593,6 +7822,8 @@ def _decode_type(n): 7: "int16", 8: "complex64", 9: "int8", + 12: "uint64", + 15: "uint32", } return _tflite_m[n] diff --git a/tests/python/relax/test_frontend_tflite.py b/tests/python/relax/test_frontend_tflite.py index 7c3e526d99a2..7d1ac4193181 100644 --- a/tests/python/relax/test_frontend_tflite.py +++ b/tests/python/relax/test_frontend_tflite.py @@ -3697,6 +3697,7 @@ def _get_tflite_schema_enum(enum_name): _tfl_stablehlo_scatter_opts = _get_tflite_schema_module("StablehloScatterOptions") _tfl_stablehlo_sort_opts = _get_tflite_schema_module("StablehloSortOptions") _tfl_stablehlo_while_opts = _get_tflite_schema_module("StablehloWhileOptions") +_tfl_stablehlo_rng_opts = _get_tflite_schema_module("StablehloRngBitGeneratorOptions") _tfl_call_options = _get_tflite_schema_module("CallOptions") _tfl_call_once_options = _get_tflite_schema_module("CallOnceOptions") _tfl_dimension_metadata = _get_tflite_schema_module("DimensionMetadata") @@ -3721,6 +3722,7 @@ def _get_tflite_schema_enum(enum_name): _tfl_padding = _get_tflite_schema_enum("Padding") _tfl_sparse_index_vector = _get_tflite_schema_enum("SparseIndexVector") _tfl_tensor_type = _get_tflite_schema_enum("TensorType") +_tfl_rng_algorithm = _get_tflite_schema_enum("RngAlgorithm") _tfl_lstm_options = _get_tflite_schema_module("LSTMOptions") _tfl_sequence_rnn_options = _get_tflite_schema_module("SequenceRNNOptions") @@ -6926,6 +6928,240 @@ def test_stablehlo_options_missing_payload_unsupported(): _load_model_from_buffer(buf) +def _build_stablehlo_rng_model(algorithm, state_len, out_shape, out_tensor_type, const_state=None): + """Build a STABLEHLO_RNG_BIT_GENERATOR model. + + When ``const_state`` is provided, the uint64 initial state is embedded as a + constant tensor (no graph input); otherwise it is a graph input. + """ + builder = flatbuffers.Builder(1024) + + _tfl_stablehlo_rng_opts.StablehloRngBitGeneratorOptionsStart(builder) + _tfl_stablehlo_rng_opts.StablehloRngBitGeneratorOptionsAddAlgorithm(builder, algorithm) + rng_opts = _tfl_stablehlo_rng_opts.StablehloRngBitGeneratorOptionsEnd(builder) + + rng_builtin = _get_stablehlo_builtin_operator("STABLEHLO_RNG_BIT_GENERATOR") + rng_code = _build_operator_code(builder, rng_builtin) + + main_tensors = [ + _build_tensor(builder, 0, [state_len], tensor_type=_tfl_tensor_type.UINT64), + _build_tensor(builder, 1, [state_len], tensor_type=_tfl_tensor_type.UINT64), + _build_tensor(builder, 2, list(out_shape), tensor_type=out_tensor_type), + ] + rng_op = _build_operator( + builder, + 0, + [0], + [1, 2], + builtin_options2_type=_tfl_builtin_options2.StablehloRngBitGeneratorOptions, + builtin_options2=rng_opts, + ) + main_subgraph = _build_subgraph( + builder, + tensors=main_tensors, + operators=[rng_op], + inputs=[] if const_state is not None else [0], + outputs=[1, 2], + ) + + state_data = None + if const_state is not None: + state_data = np.array(const_state, dtype="uint64").tobytes() + buffers = [ + _build_buffer(builder, data=state_data), + _build_buffer(builder), + _build_buffer(builder), + ] + return _finish_tflite_model( + builder, + subgraph=main_subgraph, + operator_codes=[rng_code], + buffers=buffers, + ) + + +def _run_stablehlo_rng_model(algorithm, state_len, out_shape, out_tensor_type, init_state): + """Import, compile, and execute an RNG model, returning (output_state, output).""" + buf = _build_stablehlo_rng_model(algorithm, state_len, out_shape, out_tensor_type) + mod = _load_model_from_buffer(buf) + ex = tvm.compile(mod, tvm.target.Target("llvm")) + vm = relax.VirtualMachine(ex, tvm.cpu()) + result = vm["main"](tvm.runtime.tensor(np.array(init_state, dtype="uint64"))) + return result[0].numpy(), result[1].numpy() + + +# Expected vectors are taken verbatim from the TFLite runtime kernel test +# (tensorflow/lite/kernels/rng_bit_generator_test.cc), guaranteeing bit-exact parity. +_RNG_THREEFRY_EXPECTED = { + "int32": [43444564, -2144348869, -315321645, -549236733, 1672743891, -54463903], + "uint32": [43444564, 2150618427, 3979645651, 3745730563, 1672743891, 4240503393], + "int64": [ + -9209908263526143660, + -2358953802017238317, + -233920680524772397, + 2658481902456610144, + -2022031683723149139, + -2324041912354448873, + ], + "uint64": [ + 9236835810183407956, + 16087790271692313299, + 18212823393184779219, + 2658481902456610144, + 16424712389986402477, + 16122702161355102743, + ], +} +_RNG_THREEFRY_STATE = {"int32": [1, 5], "uint32": [1, 5], "int64": [1, 8], "uint64": [1, 8]} +_RNG_PHILOX_EXPECTED = { + "int32": [-263854262, 1366700262, 495645701, -1243243882, 89414891, 1917262711], + "uint32": [4031113034, 1366700262, 495645701, 3051723414, 89414891, 1917262711], + "int64": [ + 5869932932755744586, + -5339691813646437371, + 8234580641674714347, + 2641225993340350124, + 1962472297844690804, + -3580856229565614135, + ], + "uint64": [ + 5869932932755744586, + 13107052260063114245, + 8234580641674714347, + 2641225993340350124, + 1962472297844690804, + 14865887844143937481, + ], +} +_RNG_PHILOX_STATE = { + "int32": [1, 4, 3], + "uint32": [1, 4, 3], + "int64": [1, 5, 3], + "uint64": [1, 5, 3], +} + + +@pytest.mark.parametrize( + "out_dtype,out_tensor_type", + [ + ("int32", _tfl_tensor_type.INT32), + ("uint32", _tfl_tensor_type.UINT32), + ("int64", _tfl_tensor_type.INT64), + ("uint64", _tfl_tensor_type.UINT64), + ], +) +def test_stablehlo_rng_bit_generator_threefry(out_dtype, out_tensor_type): + """TFLite STABLEHLO_RNG_BIT_GENERATOR THREEFRY matches the runtime kernel bit-exactly.""" + state, output = _run_stablehlo_rng_model( + _tfl_rng_algorithm.THREEFRY, 2, [2, 3], out_tensor_type, [1, 2] + ) + assert output.flatten().tolist() == _RNG_THREEFRY_EXPECTED[out_dtype] + assert state.tolist() == _RNG_THREEFRY_STATE[out_dtype] + + +@pytest.mark.parametrize( + "out_dtype,out_tensor_type", + [ + ("int32", _tfl_tensor_type.INT32), + ("uint32", _tfl_tensor_type.UINT32), + ("int64", _tfl_tensor_type.INT64), + ("uint64", _tfl_tensor_type.UINT64), + ], +) +def test_stablehlo_rng_bit_generator_philox(out_dtype, out_tensor_type): + """TFLite STABLEHLO_RNG_BIT_GENERATOR PHILOX matches the runtime kernel bit-exactly.""" + state, output = _run_stablehlo_rng_model( + _tfl_rng_algorithm.PHILOX, 3, [2, 3], out_tensor_type, [1, 2, 3] + ) + assert output.flatten().tolist() == _RNG_PHILOX_EXPECTED[out_dtype] + assert state.tolist() == _RNG_PHILOX_STATE[out_dtype] + + +def test_stablehlo_rng_bit_generator_default_matches_philox(): + """TFLite STABLEHLO_RNG_BIT_GENERATOR DEFAULT resolves to the PHILOX algorithm.""" + state, output = _run_stablehlo_rng_model( + _tfl_rng_algorithm.DEFAULT, 3, [2, 3], _tfl_tensor_type.INT32, [1, 2, 3] + ) + assert output.flatten().tolist() == _RNG_PHILOX_EXPECTED["int32"] + assert state.tolist() == _RNG_PHILOX_STATE["int32"] + + +def test_stablehlo_rng_bit_generator_deterministic(): + """Re-running the imported RNG kernel yields identical bit-exact output.""" + buf = _build_stablehlo_rng_model(_tfl_rng_algorithm.PHILOX, 3, [3, 3], _tfl_tensor_type.INT32) + mod = _load_model_from_buffer(buf) + ex = tvm.compile(mod, tvm.target.Target("llvm")) + vm = relax.VirtualMachine(ex, tvm.cpu()) + init = tvm.runtime.tensor(np.array([7, 8, 9], dtype="uint64")) + first = vm["main"](init) + second = vm["main"](init) + np.testing.assert_equal(first[1].numpy(), second[1].numpy()) + np.testing.assert_equal(first[0].numpy(), second[0].numpy()) + + +def test_stablehlo_rng_bit_generator_constant_state(): + """A constant uint64 initial state imports and stays bit-exact (no graph input).""" + buf = _build_stablehlo_rng_model( + _tfl_rng_algorithm.THREEFRY, 2, [2, 3], _tfl_tensor_type.INT32, const_state=[1, 2] + ) + mod = _load_model_from_buffer(buf) + assert len(mod["main"].params) == 0 + ex = tvm.compile(mod, tvm.target.Target("llvm")) + vm = relax.VirtualMachine(ex, tvm.cpu()) + result = vm["main"]() + assert result[1].numpy().flatten().tolist() == _RNG_THREEFRY_EXPECTED["int32"] + assert result[0].numpy().tolist() == _RNG_THREEFRY_STATE["int32"] + + +def test_stablehlo_rng_bit_generator_unsupported_output_dtype(): + """TFLite STABLEHLO_RNG_BIT_GENERATOR rejects non-integer output dtypes.""" + buf = _build_stablehlo_rng_model(_tfl_rng_algorithm.PHILOX, 3, [2, 3], _tfl_tensor_type.FLOAT32) + with pytest.raises(tvm.error.OpNotImplemented, match="output dtype float32 is not supported"): + _load_model_from_buffer(buf) + + +def test_stablehlo_rng_bit_generator_threefry_invalid_state_unsupported(): + """TFLite STABLEHLO_RNG_BIT_GENERATOR rejects a u64[3] state for THREEFRY.""" + buf = _build_stablehlo_rng_model(_tfl_rng_algorithm.THREEFRY, 3, [2, 3], _tfl_tensor_type.INT32) + with pytest.raises(tvm.error.OpNotImplemented, match="THREEFRY requires a u64.2. state"): + _load_model_from_buffer(buf) + + +def test_stablehlo_rng_bit_generator_non_uint64_state_unsupported(): + """TFLite STABLEHLO_RNG_BIT_GENERATOR rejects a non-uint64 initial state.""" + builder = flatbuffers.Builder(1024) + _tfl_stablehlo_rng_opts.StablehloRngBitGeneratorOptionsStart(builder) + _tfl_stablehlo_rng_opts.StablehloRngBitGeneratorOptionsAddAlgorithm( + builder, _tfl_rng_algorithm.PHILOX + ) + rng_opts = _tfl_stablehlo_rng_opts.StablehloRngBitGeneratorOptionsEnd(builder) + rng_code = _build_operator_code( + builder, _get_stablehlo_builtin_operator("STABLEHLO_RNG_BIT_GENERATOR") + ) + tensors = [ + _build_tensor(builder, 0, [2], tensor_type=_tfl_tensor_type.INT64), + _build_tensor(builder, 1, [2], tensor_type=_tfl_tensor_type.INT64), + _build_tensor(builder, 2, [2, 3], tensor_type=_tfl_tensor_type.INT32), + ] + rng_op = _build_operator( + builder, + 0, + [0], + [1, 2], + builtin_options2_type=_tfl_builtin_options2.StablehloRngBitGeneratorOptions, + builtin_options2=rng_opts, + ) + subgraph = _build_subgraph( + builder, tensors=tensors, operators=[rng_op], inputs=[0], outputs=[1, 2] + ) + buffers = [_build_buffer(builder) for _ in range(3)] + buf = _finish_tflite_model( + builder, subgraph=subgraph, operator_codes=[rng_code], buffers=buffers + ) + with pytest.raises(tvm.error.OpNotImplemented, match="requires a uint64 initial state"): + _load_model_from_buffer(buf) + + def test_stablehlo_while(): """TFLite STABLEHLO_WHILE lowers to a recursive Relax private function.""" mod = _load_model_from_buffer(_build_stablehlo_while_model())