From 306e7473083d6f6c618cb82085aa1cd89165a9cc Mon Sep 17 00:00:00 2001 From: LudovicoYIN Date: Mon, 1 Jun 2026 07:02:02 +0000 Subject: [PATCH 1/3] [Relax][Frontend][TFLite] Add EMBEDDING_LOOKUP_SPARSE converter Add Relax TFLite frontend support for EMBEDDING_LOOKUP_SPARSE. This change implements the converter for the SUM, MEAN, and SQRTN combiners and supports higher-rank sparse indices. The converter lowers the sparse aggregation through scatter_nd to match TFLite operator semantics for the supported cases. The patch also adds handcrafted TFLite frontend tests covering the SUM, MEAN, and SQRTN combiners, along with a 3D indices case. --- .../relax/frontend/tflite/tflite_frontend.py | 117 +++++++++++ tests/python/relax/test_frontend_tflite.py | 188 ++++++++++++++++++ 2 files changed, 305 insertions(+) diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py b/python/tvm/relax/frontend/tflite/tflite_frontend.py index 2a4455eb30bb..7a7a9a38dd06 100644 --- a/python/tvm/relax/frontend/tflite/tflite_frontend.py +++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py @@ -224,6 +224,7 @@ def __init__(self, model, subgraph, exp_tab, ctx, conversion_state=None): "DIV": functools.partial(self._convert_elemwise, relax_op=_op.divide), "ELU": self.convert_elu, "EMBEDDING_LOOKUP": self.convert_embedding_lookup, + "EMBEDDING_LOOKUP_SPARSE": self.convert_embedding_lookup_sparse, "EQUAL": functools.partial( self._convert_elemwise, relax_op=_op.equal, comparison_op=True ), @@ -6183,6 +6184,122 @@ def convert_embedding_lookup(self, op): indices = self.get_tensor_expr(indices_tensor) return relax.op.take(params, indices, axis=0) + def convert_embedding_lookup_sparse(self, op): + """Convert TFLite EMBEDDING_LOOKUP_SPARSE.""" + from tflite.CombinerType import CombinerType + from tflite.EmbeddingLookupSparseOptions import EmbeddingLookupSparseOptions + from tflite.TensorType import TensorType + + input_tensors = self.get_input_tensors(op) + assert len(input_tensors) == 5, "EMBEDDING_LOOKUP_SPARSE should have 5 input tensors" + output_tensors = self.get_output_tensors(op) + assert len(output_tensors) == 1, "EMBEDDING_LOOKUP_SPARSE should have 1 output tensor" + + ids_tensor, indices_tensor, dense_shape_tensor, weights_tensor, params_tensor = ( + input_tensors + ) + output_tensor = output_tensors[0] + + for tensor in input_tensors: + assert not tensor.qnn_params, "Quantized input is not expected." + + assert ids_tensor.tensor.Type() == TensorType.INT32 + assert indices_tensor.tensor.Type() == TensorType.INT32 + assert dense_shape_tensor.tensor.Type() == TensorType.INT32 + assert weights_tensor.tensor.Type() == TensorType.FLOAT32 + assert params_tensor.tensor.Type() == TensorType.FLOAT32 + assert output_tensor.tensor.Type() == TensorType.FLOAT32 + + ids_shape = to_int_list(self.get_tensor_shape(ids_tensor)) + indices_shape = to_int_list(self.get_tensor_shape(indices_tensor)) + dense_shape_shape = to_int_list(self.get_tensor_shape(dense_shape_tensor)) + weights_shape = to_int_list(self.get_tensor_shape(weights_tensor)) + params_shape = to_int_list(self.get_tensor_shape(params_tensor)) + + assert len(ids_shape) == 1, "EMBEDDING_LOOKUP_SPARSE ids must be rank 1" + assert len(indices_shape) == 2, "EMBEDDING_LOOKUP_SPARSE indices must be rank 2" + assert len(dense_shape_shape) == 1, "EMBEDDING_LOOKUP_SPARSE dense_shape must be rank 1" + assert len(weights_shape) == 1, "EMBEDDING_LOOKUP_SPARSE weights must be rank 1" + assert len(params_shape) >= 2, "EMBEDDING_LOOKUP_SPARSE params must be rank >= 2" + assert indices_shape[0] == ids_shape[0], ( + "EMBEDDING_LOOKUP_SPARSE ids and indices must agree on lookup count" + ) + assert weights_shape[0] == ids_shape[0], ( + "EMBEDDING_LOOKUP_SPARSE ids and weights must agree on lookup count" + ) + + if self.has_expr(dense_shape_tensor.tensor_idx): + raise tvm.error.OpNotImplemented( + "TFLite EMBEDDING_LOOKUP_SPARSE with runtime dense_shape is not supported." + ) + + dense_shape = to_int_list(self.get_tensor_value(dense_shape_tensor)) + lookup_rank = indices_shape[1] + assert len(dense_shape) == lookup_rank, ( + "EMBEDDING_LOOKUP_SPARSE dense_shape length must match indices width" + ) + assert lookup_rank >= 1, "EMBEDDING_LOOKUP_SPARSE indices width must be positive" + if not self.has_expr(ids_tensor.tensor_idx): + ids_value = self.get_tensor_value(ids_tensor) + if np.any(ids_value < 0): + raise tvm.error.OpNotImplemented( + "TFLite EMBEDDING_LOOKUP_SPARSE with negative ids is not supported." + ) + + params = self.get_tensor_expr(params_tensor) + ids = self.get_tensor_expr(ids_tensor) + weights = self.get_tensor_expr(weights_tensor) + indices = self.get_tensor_expr(indices_tensor) + + ids = relax.op.astype(ids, "int32") + lookup = relax.op.take(params, ids, axis=0) + + embedding_tail_shape = params_shape[1:] + output_prefix_shape = dense_shape[:-1] + output_shape = output_prefix_shape + embedding_tail_shape + + # Aggregation buckets are defined by every sparse index dimension except the last one. + bucket_indices = relax.op.strided_slice(indices, axes=[1], begin=[0], end=[lookup_rank - 1]) + + weight_expand_shape = [ids_shape[0]] + [1] * len(embedding_tail_shape) + weighted_lookup = relax.op.multiply(lookup, relax.op.reshape(weights, weight_expand_shape)) + + value_base = relax.const(np.zeros(output_shape, dtype=np.float32), "float32") + summed_lookup = relax.op.scatter_nd(value_base, bucket_indices, weighted_lookup, "add") + + op_options = op.BuiltinOptions() + sparse_options = EmbeddingLookupSparseOptions() + sparse_options.Init(op_options.Bytes, op_options.Pos) + combiner = sparse_options.Combiner() + if combiner == CombinerType.SUM: + return summed_lookup + + count_shape = output_prefix_shape + count_base = relax.const(np.zeros(count_shape, dtype=np.float32), "float32") + if combiner == CombinerType.MEAN: + denominator_updates = weights + elif combiner == CombinerType.SQRTN: + denominator_updates = relax.op.multiply(weights, weights) + else: + raise tvm.error.OpNotImplemented( + f"Unsupported TFLite EMBEDDING_LOOKUP_SPARSE combiner value {combiner}" + ) + + denominator = relax.op.scatter_nd(count_base, bucket_indices, denominator_updates, "add") + if combiner == CombinerType.SQRTN: + denominator = relax.op.sqrt(denominator) + + broadcast_shape = count_shape + [1] * len(embedding_tail_shape) + denominator = relax.op.reshape(denominator, broadcast_shape) + denominator = relax.op.broadcast_to(denominator, output_shape) + safe_denominator = relax.op.maximum( + denominator, relax.const(np.full(output_shape, 1e-12, dtype=np.float32), "float32") + ) + normalized = relax.op.divide(summed_lookup, safe_denominator) + return relax.op.where( + relax.op.greater(denominator, relax.const(0.0, "float32")), normalized, value_base + ) + def convert_batch_matmul(self, op): """batch_matmul implementation.""" diff --git a/tests/python/relax/test_frontend_tflite.py b/tests/python/relax/test_frontend_tflite.py index 7c3e526d99a2..b76050c13593 100644 --- a/tests/python/relax/test_frontend_tflite.py +++ b/tests/python/relax/test_frontend_tflite.py @@ -4037,6 +4037,17 @@ def _build_hashtable_options( return hashtable_options.HashtableOptionsEnd(builder) +def _build_embedding_lookup_sparse_options(builder, combiner): + try: + sparse_options = _get_tflite_schema_module("EmbeddingLookupSparseOptions") + except ModuleNotFoundError: + pytest.skip("TFLite schema does not provide EmbeddingLookupSparseOptions") + + sparse_options.EmbeddingLookupSparseOptionsStart(builder) + sparse_options.EmbeddingLookupSparseOptionsAddCombiner(builder, combiner) + return sparse_options.EmbeddingLookupSparseOptionsEnd(builder) + + def _load_model_from_buffer(model_bytes): if hasattr(tflite.Model, "Model"): tflite_model = tflite.Model.Model.GetRootAsModel(model_bytes, 0) @@ -4053,6 +4064,15 @@ def _get_builtin_operator(builtin_name): return getattr(_tfl_builtin_operator, builtin_name) +def _run_no_input_module(mod): + tgt = tvm.target.Target("c") + ex = tvm.compile(mod, tgt) + vm = relax.VirtualMachine(ex, tvm.cpu()) + vm.set_input("main") + vm.invoke_stateful("main") + return vm.get_outputs("main").numpy() + + def _build_tflite_call_model( call_subgraph_index=1, callee_inputs=None, @@ -5844,6 +5864,82 @@ def _build_tflite_hashtable_size_uninitialized_model(): ) +def _build_tflite_embedding_lookup_sparse_model(combiner, indices_data, dense_shape_data): + builder = flatbuffers.Builder(4096) + + ids_data = np.array([1, 3, 0], dtype=np.int32) + indices_data = np.array(indices_data, dtype=np.int32) + dense_shape_data = np.array(dense_shape_data, dtype=np.int32) + weights_data = np.array([1.0, 2.0, 4.0], dtype=np.float32) + params_data = np.array( + [ + [[0.00, 0.01], [0.10, 0.11], [0.20, 0.21]], + [[1.00, 1.01], [1.10, 1.11], [1.20, 1.21]], + [[2.00, 2.01], [2.10, 2.11], [2.20, 2.21]], + [[3.00, 3.01], [3.10, 3.11], [3.20, 3.21]], + ], + dtype=np.float32, + ) + + output_shape = dense_shape_data[:-1].tolist() + list(params_data.shape[1:]) + sparse_options = _build_embedding_lookup_sparse_options(builder, combiner) + + ids_tensor = _build_tensor(builder, 0, list(ids_data.shape), tensor_type=_tfl_tensor_type.INT32) + indices_tensor = _build_tensor( + builder, 1, list(indices_data.shape), tensor_type=_tfl_tensor_type.INT32 + ) + dense_shape_tensor = _build_tensor( + builder, 2, list(dense_shape_data.shape), tensor_type=_tfl_tensor_type.INT32 + ) + weights_tensor = _build_tensor( + builder, 3, list(weights_data.shape), tensor_type=_tfl_tensor_type.FLOAT32 + ) + params_tensor = _build_tensor( + builder, 4, list(params_data.shape), tensor_type=_tfl_tensor_type.FLOAT32 + ) + output_tensor = _build_tensor(builder, 5, output_shape, tensor_type=_tfl_tensor_type.FLOAT32) + + sparse_op = _build_operator( + builder, + 0, + [0, 1, 2, 3, 4], + [5], + builtin_options_type=_get_builtin_options_type("EmbeddingLookupSparseOptions"), + builtin_options=sparse_options, + ) + subgraph = _build_subgraph( + builder, + tensors=[ + ids_tensor, + indices_tensor, + dense_shape_tensor, + weights_tensor, + params_tensor, + output_tensor, + ], + operators=[sparse_op], + inputs=[], + outputs=[5], + ) + operator_codes = [ + _build_operator_code(builder, _get_builtin_operator("EMBEDDING_LOOKUP_SPARSE")) + ] + buffers = [ + _build_buffer(builder, ids_data.tobytes()), + _build_buffer(builder, indices_data.tobytes()), + _build_buffer(builder, dense_shape_data.tobytes()), + _build_buffer(builder, weights_data.tobytes()), + _build_buffer(builder, params_data.tobytes()), + _build_buffer(builder), + ] + return _finish_tflite_model( + builder, + subgraph=subgraph, + operator_codes=operator_codes, + buffers=buffers, + ) + + def test_resource_variable_call_once_init_read(): """Test reading a resource variable initialized by a supported CALL_ONCE subgraph.""" mod = _load_model_from_buffer(_build_tflite_resource_variable_model()) @@ -5908,6 +6004,98 @@ def test_hashtable_size_uninitialized_unsupported(): _load_model_from_buffer(_build_tflite_hashtable_size_uninitialized_model()) +def test_embedding_lookup_sparse_sum(): + from tflite.CombinerType import CombinerType + + mod = _load_model_from_buffer( + _build_tflite_embedding_lookup_sparse_model( + CombinerType.SUM, + indices_data=[[0, 0], [2, 0], [2, 1]], + dense_shape_data=[3, 2], + ) + ) + + out = _run_no_input_module(mod) + expected = np.array( + [ + [[1.00, 1.01], [1.10, 1.11], [1.20, 1.21]], + [[0.00, 0.00], [0.00, 0.00], [0.00, 0.00]], + [[6.00, 6.06], [6.60, 6.66], [7.20, 7.26]], + ], + dtype=np.float32, + ) + np.testing.assert_allclose(out, expected, rtol=1e-5, atol=1e-5) + + +def test_embedding_lookup_sparse_mean(): + from tflite.CombinerType import CombinerType + + mod = _load_model_from_buffer( + _build_tflite_embedding_lookup_sparse_model( + CombinerType.MEAN, + indices_data=[[0, 0], [2, 0], [2, 1]], + dense_shape_data=[3, 2], + ) + ) + + out = _run_no_input_module(mod) + expected = np.array( + [ + [[1.00, 1.01], [1.10, 1.11], [1.20, 1.21]], + [[0.00, 0.00], [0.00, 0.00], [0.00, 0.00]], + [[1.00, 1.01], [1.10, 1.11], [1.20, 1.21]], + ], + dtype=np.float32, + ) + np.testing.assert_allclose(out, expected, rtol=1e-5, atol=1e-5) + + +def test_embedding_lookup_sparse_sqrtn(): + from tflite.CombinerType import CombinerType + + mod = _load_model_from_buffer( + _build_tflite_embedding_lookup_sparse_model( + CombinerType.SQRTN, + indices_data=[[0, 0], [2, 0], [2, 1]], + dense_shape_data=[3, 2], + ) + ) + + out = _run_no_input_module(mod) + scale = np.sqrt(20.0).astype("float32") + expected = np.array( + [ + [[1.00, 1.01], [1.10, 1.11], [1.20, 1.21]], + [[0.00, 0.00], [0.00, 0.00], [0.00, 0.00]], + [ + [6.00 / scale, 6.06 / scale], + [6.60 / scale, 6.66 / scale], + [7.20 / scale, 7.26 / scale], + ], + ], + dtype=np.float32, + ) + np.testing.assert_allclose(out, expected, rtol=1e-5, atol=1e-5) + + +def test_embedding_lookup_sparse_indices_3d(): + from tflite.CombinerType import CombinerType + + mod = _load_model_from_buffer( + _build_tflite_embedding_lookup_sparse_model( + CombinerType.SUM, + indices_data=[[0, 0, 0], [2, 0, 0], [2, 0, 1]], + dense_shape_data=[3, 2, 2], + ) + ) + + out = _run_no_input_module(mod) + expected = np.zeros((3, 2, 3, 2), dtype=np.float32) + expected[0, 0] = np.array([[1.00, 1.01], [1.10, 1.11], [1.20, 1.21]], dtype=np.float32) + expected[2, 0] = np.array([[6.00, 6.06], [6.60, 6.66], [7.20, 7.26]], dtype=np.float32) + np.testing.assert_allclose(out, expected, rtol=1e-5, atol=1e-5) + + def _get_stablehlo_builtin_operator(builtin_name): if not hasattr(_tfl_builtin_operator, builtin_name): pytest.skip(f"TFLite schema does not provide BuiltinOperator.{builtin_name}") From bce556e937733c414a4430c79862ea6942923b55 Mon Sep 17 00:00:00 2001 From: YinHanke Date: Mon, 1 Jun 2026 15:15:00 +0800 Subject: [PATCH 2/3] Update python/tvm/relax/frontend/tflite/tflite_frontend.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- python/tvm/relax/frontend/tflite/tflite_frontend.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py b/python/tvm/relax/frontend/tflite/tflite_frontend.py index 7a7a9a38dd06..b764d0c171cf 100644 --- a/python/tvm/relax/frontend/tflite/tflite_frontend.py +++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py @@ -6293,7 +6293,7 @@ def convert_embedding_lookup_sparse(self, op): denominator = relax.op.reshape(denominator, broadcast_shape) denominator = relax.op.broadcast_to(denominator, output_shape) safe_denominator = relax.op.maximum( - denominator, relax.const(np.full(output_shape, 1e-12, dtype=np.float32), "float32") + denominator, relax.const(1e-12, "float32") ) normalized = relax.op.divide(summed_lookup, safe_denominator) return relax.op.where( From 0d5803a08891887a09fced60aed67ee8d27fb41b Mon Sep 17 00:00:00 2001 From: LudovicoYIN Date: Mon, 1 Jun 2026 07:29:21 +0000 Subject: [PATCH 3/3] [Relax][Frontend][TFLite] Simplify safe denominator constant Use a scalar float32 constant when clamping the EMBEDDING_LOOKUP_SPARSE denominator. This keeps the Relax expression smaller by relying on broadcasting in maximum instead of materializing a full-shaped constant tensor. --- python/tvm/relax/frontend/tflite/tflite_frontend.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py b/python/tvm/relax/frontend/tflite/tflite_frontend.py index b764d0c171cf..b35b8ed73753 100644 --- a/python/tvm/relax/frontend/tflite/tflite_frontend.py +++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py @@ -6292,9 +6292,7 @@ def convert_embedding_lookup_sparse(self, op): broadcast_shape = count_shape + [1] * len(embedding_tail_shape) denominator = relax.op.reshape(denominator, broadcast_shape) denominator = relax.op.broadcast_to(denominator, output_shape) - safe_denominator = relax.op.maximum( - denominator, relax.const(1e-12, "float32") - ) + safe_denominator = relax.op.maximum(denominator, relax.const(1e-12, "float32")) normalized = relax.op.divide(summed_lookup, safe_denominator) return relax.op.where( relax.op.greater(denominator, relax.const(0.0, "float32")), normalized, value_base