[Relax][Frontend][TFLite] Support STABLEHLO_RNG_BIT_GENERATOR#19651
[Relax][Frontend][TFLite] Support STABLEHLO_RNG_BIT_GENERATOR#19651Aharrypotter wants to merge 2 commits into
Conversation
Add Relax TFLite frontend support for the `STABLEHLO_RNG_BIT_GENERATOR`
builtin operator, lowering it to a bit-exact `call_tir` PRNG kernel.
The TFLite runtime kernel (tensorflow/lite/kernels/rng_bit_generator.cc)
is a real, deterministic counter-based PRNG, not a thin StableHLO shim:
- one uint64 1-D `initial_state` input, two outputs (uint64 `output_state`
plus the random-bit `output` in int32/int64/uint32/uint64);
- `algorithm` in {DEFAULT, PHILOX, THREEFRY}, where DEFAULT resolves to
PHILOX in the runtime;
- bit-exact Random123 Threefry2x32 (20 rounds) and Philox4x32 (10 rounds)
with fixed constants from rng_util.cc;
- state-length constraints: THREEFRY requires u64[2], PHILOX/DEFAULT
require u64[2] or u64[3].
Since TVM/Relax has no matching RNG primitive, the converter generates a
TIR kernel that mirrors the runtime exactly and emits it via `call_tir`.
The kernel reinterprets the uint64 state as uint32 words, advances a
64-bit block counter, runs the selected algorithm per block with all
round state materialized into local buffers (avoiding exponential
expression blow-up), and packs the generated words back into the output
dtype. The updated state keeps the key unchanged and only advances the
counter, which is the only state behaviour the runtime relies on. The
kernel is a `s_tir` PrimFunc wrapped in a single opaque structured block
so it stays a well-formed block-structured function for the Relax
pipeline (e.g. HasReshapePattern).
Unsupported cases raise a precise OpNotImplemented: non-uint64 state,
non-1-D state, mismatched output-state shape, unsupported output dtype,
unknown algorithm, and per-algorithm state-length violations.
Also extend `get_tensor_type_str` and the input `_decode_type` map with
uint32/uint64 so uint64 state tensors import correctly.
Tests build minimal RNG flatbuffers, compile, and execute them, checking
output and updated state against the verbatim expected vectors from the
TFLite runtime kernel test (rng_bit_generator_test.cc) for both
algorithms and all four output dtypes, plus DEFAULT==PHILOX, run-to-run
determinism, and the rejection paths.
Refs apache#19519 (item I: remaining StableHLO operators in TFLite).
There was a problem hiding this comment.
Code Review
This pull request adds support for the STABLEHLO_RNG_BIT_GENERATOR operator in the TVM Relax TFLite frontend, including support for UINT32 and UINT64 tensor types. It implements bit-exact TIR kernels for both the Threefry and Philox counter-based PRNG algorithms and adds corresponding unit tests. Feedback on the changes highlights a few issues in the generated TIR kernels, including an invalid import of tirx instead of tir, unsupported s_tir arguments in the @T.prim_func decorators, and an opportunity to avoid generating empty loops when the state length is 2.
| tensor. The updated state keeps the key unchanged and only advances the counter, | ||
| which is the only behaviour the runtime relies on. | ||
| """ | ||
| from tvm.script.parser import tirx as T |
|
|
||
| if algorithm == "threefry": | ||
|
|
||
| @T.prim_func(private=True, s_tir=True) |
|
|
||
| return kernel | ||
|
|
||
| @T.prim_func(private=True, s_tir=True) |
| for tail in T.serial(state_len - 2): | ||
| output_state[tail + 2] = initial_state[tail + 2] |
There was a problem hiding this comment.
When state_len is 2, state_len - 2 is 0, which generates a loop with 0 iterations (T.serial(0)). To avoid generating redundant/empty loops in the TIR kernel, we can conditionally generate this loop using a Python if statement.
| for tail in T.serial(state_len - 2): | |
| output_state[tail + 2] = initial_state[tail + 2] | |
| if state_len > 2: | |
| for tail in T.serial(state_len - 2): | |
| output_state[tail + 2] = initial_state[tail + 2] |
… empty loop Follow-up to the STABLEHLO_RNG_BIT_GENERATOR support, addressing review feedback. Constant uint64/uint32 state import ----------------------------------- The initial commit taught `get_tensor_type_str` and the graph-input `_decode_type` map about uint32/uint64, but missed `get_tensor_type_as_numpy`, which decodes constant tensor buffers. As a result, a valid TFLite model whose `initial_state` is a constant tensor (rather than a graph input) reached `get_tensor_expr` -> `get_tensor_value` -> `get_tensor_type_as_numpy` and raised a `KeyError` on the UINT64/UINT32 tensor type, so the model could not import. The original tests only exercised graph-input state and therefore missed this path. Add `UINT32 -> np.uint32` and `UINT64 -> np.uint64` to `get_tensor_type_as_numpy`, and add `test_stablehlo_rng_bit_generator_constant_state`, which embeds a constant `u64[2]` THREEFRY state (no graph input), imports/compiles/executes the model, and checks the output and updated state stay bit-exact against the TFLite runtime vectors. The test also asserts `mod["main"]` has zero params to confirm the state is folded in as a constant. Philox kernel: remove zero-trip loop ------------------------------------ When `state_len == 2`, `for tail in T.serial(state_len - 2)` generated an empty (zero iteration) loop. Fold the `output_state[2]` pass-through into the existing `if state_len == 3:` branch that already derives the high counter words from the same third state word, removing the redundant loop and keeping the pass-through next to the logic that consumes it. Tests: 14 rng_bit_generator and 102 stablehlo cases pass; ruff clean.
Summary
This PR adds Relax TFLite frontend support for the TFLite builtin
STABLEHLO_RNG_BIT_GENERATORoperator.Unlike most StableHLO builtins, the TFLite runtime
(
tensorflow/lite/kernels/rng_bit_generator.cc) implements this op as a real,deterministic counter-based PRNG, so the importer must reproduce it bit-exactly
rather than map it to an existing op:
initial_stateinput, two outputs — uint64output_stateandthe random-bit
output(int32 / int64 / uint32 / uint64);algorithmin{DEFAULT, PHILOX, THREEFRY}, whereDEFAULTresolves toPHILOX;constants from
rng_util.cc;THREEFRYrequiresu64[2],PHILOX/DEFAULTrequire
u64[2]oru64[3].Design
TVM/Relax has no matching RNG primitive, so the converter generates a TIR kernel
that mirrors the runtime and emits it through
relax.call_tirwith two outputs.The kernel:
counter (
final counter = initial_state[1] + num_blocks);local buffers, which keeps the generated IR linear instead of an exponentially
nested expression tree;
updated state (key unchanged, counter advanced, Philox
u64[3]tail passedthrough) — the only state behaviour the runtime relies on.
The kernel is an
s_tirPrimFunc wrapped in a single opaque structured block soit remains a well-formed block-structured function for the Relax pipeline
(e.g.
HasReshapePattern).get_tensor_type_strand the input_decode_typemap are extended with uint32/uint64 so the uint64 state imports correctly.
Unsupported inputs raise a precise
OpNotImplemented(non-uint64 / non-1-Dstate, mismatched output-state shape, unsupported output dtype, unknown
algorithm, per-algorithm state-length violations).
Operator Support
STABLEHLO_RNG_BIT_GENERATORStablehloRngBitGeneratorOptions.Algorithm()fromBuiltinOptions2call_tirto a generated bit-exact TIR kernelu64[2]) and PHILOX/DEFAULT (u64[2]/u64[3]); int32/int64/uint32/uint64 outputTests
Tests build minimal RNG flatbuffers, compile, and execute them, comparing the
output and updated state against the verbatim expected vectors from the TFLite
runtime kernel test (
rng_bit_generator_test.cc).test_stablehlo_rng_bit_generator_threefrytest_stablehlo_rng_bit_generator_philoxtest_stablehlo_rng_bit_generator_default_matches_philoxtest_stablehlo_rng_bit_generator_deterministictest_stablehlo_rng_bit_generator_unsupported_output_dtypetest_stablehlo_rng_bit_generator_threefry_invalid_state_unsupportedu64[2]state guardtest_stablehlo_rng_bit_generator_non_uint64_state_unsupportedLocal validation:
Result:
References
tensorflow/lite/kernels/rng_bit_generator.cc,rng_util.cc,rng_bit_generator_test.cc