
[Relax][Frontend][TFLite] Support STABLEHLO_RNG_BIT_GENERATOR #19651

Draft
Aharrypotter wants to merge 2 commits into apache:main from Aharrypotter:relax-tflite-stablehlo-rng-bit-generator


@Aharrypotter

Summary

This PR adds Relax TFLite frontend support for the TFLite builtin
STABLEHLO_RNG_BIT_GENERATOR operator.

Unlike most StableHLO builtins, the TFLite runtime
(tensorflow/lite/kernels/rng_bit_generator.cc) implements this op as a real,
deterministic counter-based PRNG, so the importer must reproduce it bit-exactly
rather than map it to an existing op:

  • one uint64 1-D initial_state input, two outputs — uint64 output_state and
    the random-bit output (int32 / int64 / uint32 / uint64);
  • algorithm in {DEFAULT, PHILOX, THREEFRY}, where DEFAULT resolves to
    PHILOX;
  • Random123 Threefry2x32 (20 rounds) and Philox4x32 (10 rounds) with the fixed
    constants from rng_util.cc;
  • state-length constraints: THREEFRY requires u64[2], PHILOX/DEFAULT
    require u64[2] or u64[3].
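
As a reference point for the Threefry bullet above, here is a minimal pure-Python sketch of the standard Random123 Threefry2x32 block function with 20 rounds. The rotation schedule and parity constant are the published Random123 values. The names `threefry2x32`, `rotl32`, `ROTATIONS` and `PARITY` are illustrative only, and how the TFLite kernel maps its uint64 state onto `key` and `counter` is not shown here.

```python
MASK32 = 0xFFFFFFFF
# Threefry2x32 rotation schedule and key-schedule parity constant (Random123).
ROTATIONS = ((13, 15, 26, 6), (17, 29, 16, 24))
PARITY = 0x1BD11BDA


def rotl32(x, r):
    """Rotate a 32-bit word left by r bits."""
    return ((x << r) | (x >> (32 - r))) & MASK32


def threefry2x32(key, counter):
    """Threefry2x32 with 20 rounds: five groups of four mix rounds."""
    ks = (key[0], key[1], key[0] ^ key[1] ^ PARITY)
    x0 = (counter[0] + ks[0]) & MASK32
    x1 = (counter[1] + ks[1]) & MASK32
    for group in range(5):
        for r in ROTATIONS[group % 2]:
            x0 = (x0 + x1) & MASK32
            x1 = rotl32(x1, r) ^ x0
        # Key injection after every four mix rounds.
        x0 = (x0 + ks[(group + 1) % 3]) & MASK32
        x1 = (x1 + ks[(group + 2) % 3] + group + 1) & MASK32
    return x0, x1


if __name__ == "__main__":
    print([hex(w) for w in threefry2x32((0, 0), (0, 0))])
```

Each call produces two uint32 words per counter value, which is why the kernel works in blocks and then packs words into the requested output dtype.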

Design

TVM/Relax has no matching RNG primitive, so the converter generates a TIR kernel
that mirrors the runtime and emits it through relax.call_tir with two outputs.
The kernel:

  • reinterprets the uint64 state as uint32 words and advances a 64-bit block
    counter (final counter = initial_state[1] + num_blocks);
  • runs the selected algorithm per block with all round state materialized into
    local buffers, which keeps the generated IR linear instead of an exponentially
    nested expression tree;
  • packs the produced uint32 words back into the output dtype, and writes the
    updated state (key unchanged, counter advanced, Philox u64[3] tail passed
    through) — the only state behaviour the runtime relies on.
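
The word handling and state update described above can be modelled in plain Python. This is a sketch under stated assumptions, not the generated TIR: the low-word-first split, treating element 0 as the key, and the wrap-around at 2^64 are assumptions, and all function names are hypothetical. It also ignores how the Philox `u64[3]` tail contributes to the counter.

```python
MASK32 = 0xFFFFFFFF
MASK64 = 0xFFFFFFFFFFFFFFFF


def split_u64(value):
    """Return (low, high) uint32 words of a uint64 (low-first order is an assumption)."""
    return value & MASK32, (value >> 32) & MASK32


def pack_u64(low, high):
    """Inverse of split_u64: combine two uint32 words into one uint64."""
    return ((high & MASK32) << 32) | (low & MASK32)


def advance_state(initial_state, num_blocks):
    """Key (element 0, assumed) unchanged, counter (element 1) advanced, tail passed through."""
    out = list(initial_state)
    out[1] = (initial_state[1] + num_blocks) & MASK64  # assumed to wrap modulo 2**64
    return out


def pack_words(words, out_bits):
    """Pack generated uint32 words into 32-bit or 64-bit output elements."""
    if out_bits == 32:
        return list(words)
    return [pack_u64(words[i], words[i + 1]) for i in range(0, len(words) - 1, 2)]


if __name__ == "__main__":
    print(advance_state([7, 41, 9], 1))  # counter moves from 41 to 42
    print([hex(v) for v in pack_words([1, 2, 3, 4], 64)])
```

For the signed output dtypes the same bit patterns would be reinterpreted as int32 or int64, which this integer model does not show.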

The kernel is an s_tir PrimFunc wrapped in a single opaque structured block so
it remains a well-formed block-structured function for the Relax pipeline
(e.g. HasReshapePattern). get_tensor_type_str and the input _decode_type
map are extended with uint32/uint64 so the uint64 state imports correctly.

Unsupported inputs raise a precise OpNotImplemented (non-uint64 / non-1-D
state, mismatched output-state shape, unsupported output dtype, unknown
algorithm, per-algorithm state-length violations).
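
The guards listed above could look roughly like the following standalone sketch. The `OpNotImplemented` class here is a local stand-in for the frontend exception of the same name, and `check_rng_inputs` with its parameters is a hypothetical helper, not the converter's actual code.

```python
class OpNotImplemented(Exception):
    """Local stand-in for the frontend's exception of the same name."""


# Allowed state lengths per algorithm, as stated in the summary.
ALLOWED_STATE_LEN = {"THREEFRY": (2,), "PHILOX": (2, 3), "DEFAULT": (2, 3)}
ALLOWED_OUTPUT_DTYPES = ("int32", "int64", "uint32", "uint64")


def check_rng_inputs(algorithm, state_dtype, state_shape, out_state_shape, output_dtype):
    """Raise OpNotImplemented for any combination outside the supported subset."""
    if algorithm not in ALLOWED_STATE_LEN:
        raise OpNotImplemented("unknown algorithm: %r" % (algorithm,))
    if state_dtype != "uint64":
        raise OpNotImplemented("initial_state must be uint64, got %s" % state_dtype)
    if len(state_shape) != 1:
        raise OpNotImplemented("initial_state must be 1-D, got shape %r" % (state_shape,))
    if tuple(out_state_shape) != tuple(state_shape):
        raise OpNotImplemented("output_state shape must match initial_state shape")
    if output_dtype not in ALLOWED_OUTPUT_DTYPES:
        raise OpNotImplemented("unsupported output dtype: %s" % output_dtype)
    if state_shape[0] not in ALLOWED_STATE_LEN[algorithm]:
        raise OpNotImplemented(
            "%s requires state length in %r, got %d"
            % (algorithm, ALLOWED_STATE_LEN[algorithm], state_shape[0])
        )


if __name__ == "__main__":
    check_rng_inputs("PHILOX", "uint64", (3,), (3,), "int32")  # accepted
    try:
        check_rng_inputs("THREEFRY", "uint64", (3,), (3,), "uint32")
    except OpNotImplemented as err:
        print("rejected:", err)
```

Checking the cheap structural conditions first keeps each error message specific to a single violated constraint.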

Operator Support

| Operator | TFLite options | Relax lowering | Supported subset |
| --- | --- | --- | --- |
| STABLEHLO_RNG_BIT_GENERATOR | StablehloRngBitGeneratorOptions.Algorithm() from BuiltinOptions2 | call_tir to a generated bit-exact TIR kernel | THREEFRY (u64[2]) and PHILOX/DEFAULT (u64[2]/u64[3]); int32/int64/uint32/uint64 output |

Tests

Tests build minimal RNG flatbuffers, compile, and execute them, comparing the
output and updated state against the verbatim expected vectors from the TFLite
runtime kernel test (rng_bit_generator_test.cc).

| Test | Coverage |
| --- | --- |
| test_stablehlo_rng_bit_generator_threefry | THREEFRY bit-exact, all 4 output dtypes |
| test_stablehlo_rng_bit_generator_philox | PHILOX bit-exact, all 4 output dtypes |
| test_stablehlo_rng_bit_generator_default_matches_philox | DEFAULT resolves to PHILOX |
| test_stablehlo_rng_bit_generator_deterministic | run-to-run bit-identical output |
| test_stablehlo_rng_bit_generator_unsupported_output_dtype | output dtype guard |
| test_stablehlo_rng_bit_generator_threefry_invalid_state_unsupported | THREEFRY u64[2] state guard |
| test_stablehlo_rng_bit_generator_non_uint64_state_unsupported | uint64 state guard |

Local validation:

python -m ruff check \
  python/tvm/relax/frontend/tflite/tflite_frontend.py \
  tests/python/relax/test_frontend_tflite.py

python -m pytest \
  tests/python/relax/test_frontend_tflite.py \
  -k rng_bit_generator -q

python -m pytest \
  tests/python/relax/test_frontend_tflite.py \
  -k stablehlo -q

Result:

ruff check: All checks passed
rng_bit_generator tests: 13 passed
stablehlo tests: 96 passed

References

Add Relax TFLite frontend support for the `STABLEHLO_RNG_BIT_GENERATOR`
builtin operator, lowering it to a bit-exact `call_tir` PRNG kernel.

The TFLite runtime kernel (tensorflow/lite/kernels/rng_bit_generator.cc)
is a real, deterministic counter-based PRNG, not a thin StableHLO shim:

- one uint64 1-D `initial_state` input, two outputs (uint64 `output_state`
  plus the random-bit `output` in int32/int64/uint32/uint64);
- `algorithm` in {DEFAULT, PHILOX, THREEFRY}, where DEFAULT resolves to
  PHILOX in the runtime;
- bit-exact Random123 Threefry2x32 (20 rounds) and Philox4x32 (10 rounds)
  with fixed constants from rng_util.cc;
- state-length constraints: THREEFRY requires u64[2], PHILOX/DEFAULT
  require u64[2] or u64[3].
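
For the Philox side, a minimal pure-Python sketch of the standard Random123 Philox4x32 block function with 10 rounds follows. The multipliers and key increments are the published Random123 constants. The function names are illustrative, and the mapping from the TFLite `u64[2]` or `u64[3]` state onto `key` and `counter` is deliberately left out.

```python
MASK32 = 0xFFFFFFFF
M0, M1 = 0xD2511F53, 0xCD9E8D57  # round multipliers
W0, W1 = 0x9E3779B9, 0xBB67AE85  # per-round key increments


def philox_round(ctr, key):
    """One Philox4x32 round: two 32x32 -> 64-bit multiplies plus key mixing."""
    p0 = M0 * ctr[0]
    p1 = M1 * ctr[2]
    hi0, lo0 = p0 >> 32, p0 & MASK32
    hi1, lo1 = p1 >> 32, p1 & MASK32
    return (hi1 ^ ctr[1] ^ key[0], lo1, hi0 ^ ctr[3] ^ key[1], lo0)


def philox4x32(key, counter, rounds=10):
    """Apply `rounds` Philox rounds, bumping the key before every round but the first."""
    ctr = tuple(counter)
    k = tuple(key)
    for i in range(rounds):
        if i > 0:
            k = ((k[0] + W0) & MASK32, (k[1] + W1) & MASK32)
        ctr = philox_round(ctr, k)
    return ctr


if __name__ == "__main__":
    print([hex(w) for w in philox4x32((0, 0), (0, 0, 0, 0))])
```

Each block yields four uint32 words, twice as many as Threefry2x32, so the number of blocks needed for a given output size differs between the two algorithms.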

Since TVM/Relax has no matching RNG primitive, the converter generates a
TIR kernel that mirrors the runtime exactly and emits it via `call_tir`.
The kernel reinterprets the uint64 state as uint32 words, advances a
64-bit block counter, runs the selected algorithm per block with all
round state materialized into local buffers (avoiding exponential
expression blow-up), and packs the generated words back into the output
dtype. The updated state keeps the key unchanged and only advances the
counter, which is the only state behaviour the runtime relies on. The
kernel is an `s_tir` PrimFunc wrapped in a single opaque structured block
so it stays a well-formed block-structured function for the Relax
pipeline (e.g. HasReshapePattern).

Unsupported cases raise a precise OpNotImplemented: non-uint64 state,
non-1-D state, mismatched output-state shape, unsupported output dtype,
unknown algorithm, and per-algorithm state-length violations.

Also extend `get_tensor_type_str` and the input `_decode_type` map with
uint32/uint64 so uint64 state tensors import correctly.

Tests build minimal RNG flatbuffers, compile, and execute them, checking
output and updated state against the verbatim expected vectors from the
TFLite runtime kernel test (rng_bit_generator_test.cc) for both
algorithms and all four output dtypes, plus DEFAULT==PHILOX, run-to-run
determinism, and the rejection paths.

Refs apache#19519 (item I: remaining StableHLO operators in TFLite).

@gemini-code-assist (bot) left a comment


Code Review

This pull request adds support for the STABLEHLO_RNG_BIT_GENERATOR operator in the TVM Relax TFLite frontend, including support for UINT32 and UINT64 tensor types. It implements bit-exact TIR kernels for both the Threefry and Philox counter-based PRNG algorithms and adds corresponding unit tests. Feedback on the changes highlights a few issues in the generated TIR kernels, including an invalid import of tirx instead of tir, unsupported s_tir arguments in the @T.prim_func decorators, and an opportunity to avoid generating empty loops when the state length is 2.

tensor. The updated state keeps the key unchanged and only advances the counter,
which is the only behaviour the runtime relies on.
"""
from tvm.script.parser import tirx as T

high

There is a typo in the import statement. tirx does not exist in tvm.script.parser. It should be imported from tvm.script as tir.

Suggested change
- from tvm.script.parser import tirx as T
+ from tvm.script import tir as T


if algorithm == "threefry":

@T.prim_func(private=True, s_tir=True)

high

The @T.prim_func decorator does not accept an s_tir argument. This will raise a TypeError at runtime. Please remove s_tir=True.

Suggested change
- @T.prim_func(private=True, s_tir=True)
+ @T.prim_func(private=True)


return kernel

@T.prim_func(private=True, s_tir=True)

high

The @T.prim_func decorator does not accept an s_tir argument. This will raise a TypeError at runtime. Please remove s_tir=True.

Suggested change
- @T.prim_func(private=True, s_tir=True)
+ @T.prim_func(private=True)

Comment on lines +7534 to +7535
for tail in T.serial(state_len - 2):
    output_state[tail + 2] = initial_state[tail + 2]

medium

When state_len is 2, state_len - 2 is 0, which generates a loop with 0 iterations (T.serial(0)). To avoid generating redundant/empty loops in the TIR kernel, we can conditionally generate this loop using a Python if statement.

Suggested change
- for tail in T.serial(state_len - 2):
-     output_state[tail + 2] = initial_state[tail + 2]
+ if state_len > 2:
+     for tail in T.serial(state_len - 2):
+         output_state[tail + 2] = initial_state[tail + 2]

… empty loop

Follow-up to the STABLEHLO_RNG_BIT_GENERATOR support, addressing review
feedback.

Constant uint64/uint32 state import
-----------------------------------
The initial commit taught `get_tensor_type_str` and the graph-input
`_decode_type` map about uint32/uint64, but missed
`get_tensor_type_as_numpy`, which decodes constant tensor buffers. As a
result, a valid TFLite model whose `initial_state` is a constant tensor
(rather than a graph input) reached `get_tensor_expr` ->
`get_tensor_value` -> `get_tensor_type_as_numpy` and raised a `KeyError`
on the UINT64/UINT32 tensor type, so the model could not import. The
original tests only exercised graph-input state and therefore missed
this path.

Add `UINT32 -> np.uint32` and `UINT64 -> np.uint64` to
`get_tensor_type_as_numpy`, and add `test_stablehlo_rng_bit_generator_constant_state`,
which embeds a constant `u64[2]` THREEFRY state (no graph input),
imports/compiles/executes the model, and checks the output and updated
state stay bit-exact against the TFLite runtime vectors. The test also
asserts `mod["main"]` has zero params to confirm the state is folded in
as a constant.
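
The failure mode is easy to reproduce in isolation. The sketch below is a hypothetical stand-in for a type-keyed constant decoder, using only the standard library: `STRUCT_CODES` and `decode_constant` are illustrative names, not the frontend's API. It relies on FlatBuffers storing scalars little-endian.

```python
import struct

# Hypothetical tensor-type -> struct format table. Dropping the UINT32/UINT64
# entries reproduces the KeyError described above for constant state tensors.
STRUCT_CODES = {"INT32": "i", "INT64": "q", "UINT32": "I", "UINT64": "Q"}


def decode_constant(type_name, raw):
    """Decode a little-endian constant buffer into a list of Python ints."""
    code = STRUCT_CODES[type_name]  # KeyError if the type is not mapped
    size = struct.calcsize("<" + code)
    count = len(raw) // size
    return list(struct.unpack("<%d%s" % (count, code), raw))


if __name__ == "__main__":
    state_bytes = struct.pack("<2Q", 1, 2)  # a constant u64[2] state
    print(decode_constant("UINT64", state_bytes))
```

A lookup table keyed on the tensor type fails loudly on unmapped types, which is how the missing entries surfaced once a test exercised the constant path.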

Philox kernel: remove zero-trip loop
------------------------------------
When `state_len == 2`, `for tail in T.serial(state_len - 2)` generated an
empty (zero iteration) loop. Fold the `output_state[2]` pass-through into
the existing `if state_len == 3:` branch that already derives the high
counter words from the same third state word, removing the redundant
loop and keeping the pass-through next to the logic that consumes it.

Tests: 14 rng_bit_generator and 102 stablehlo cases pass; ruff clean.
@Aharrypotter Aharrypotter marked this pull request as draft June 1, 2026 06:03