From 3088dc6a22af5315d9574730ed2a8f0fb5d81c06 Mon Sep 17 00:00:00 2001 From: LudovicoYIN Date: Mon, 1 Jun 2026 09:21:57 +0000 Subject: [PATCH] [Relax][Frontend][TFLite] Add HASHTABLE_LOOKUP converter Add Relax TFLite frontend support for HASHTABLE_LOOKUP. This change implements HASHTABLE_LOOKUP for non-string value tensors. The converter lowers the lookup through bucketize, take, and where so that missing keys return zero-filled values and a uint8 hits mask that matches TFLite semantics for the supported cases. The patch also adds handcrafted TFLite frontend tests covering 1D and 2D float value tensors, along with the current unsupported string-value case. --- .../relax/frontend/tflite/tflite_frontend.py | 83 +++++++++++++++++ tests/python/relax/test_frontend_tflite.py | 89 +++++++++++++++++++ 2 files changed, 172 insertions(+) diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py b/python/tvm/relax/frontend/tflite/tflite_frontend.py index 2a4455eb30bb..fc3d61713dc6 100644 --- a/python/tvm/relax/frontend/tflite/tflite_frontend.py +++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py @@ -248,6 +248,7 @@ def __init__(self, model, subgraph, exp_tab, ctx, conversion_state=None): "HASHTABLE": self.convert_hashtable, "HASHTABLE_FIND": self.convert_hashtable_find, "HASHTABLE_IMPORT": self.convert_hashtable_import, + "HASHTABLE_LOOKUP": self.convert_hashtable_lookup, "HASHTABLE_SIZE": self.convert_hashtable_size, "IF": self.convert_if, "L2_NORMALIZATION": self.convert_l2_normalization, @@ -755,6 +756,88 @@ def convert_hashtable_find(self, op): "HASHTABLE_FIND requires TensorType.STRING support in Relax TFLite frontend" ) + def convert_hashtable_lookup(self, op): + """Convert TFLite HASHTABLE_LOOKUP for non-string value tensors.""" + from tflite.TensorType import TensorType + + input_tensors = self.get_input_tensors(op) + output_tensors = self.get_output_tensors(op) + if len(input_tensors) != 3 or len(output_tensors) != 2: + raise tvm.error.OpNotImplemented( + "HASHTABLE_LOOKUP expects lookup, key, and value inputs with two outputs" + ) + + lookup_tensor, key_tensor, value_tensor = input_tensors + output_tensor, hits_tensor = output_tensors + + if ( + lookup_tensor.tensor.Type() != TensorType.INT32 + or key_tensor.tensor.Type() != TensorType.INT32 + ): + raise tvm.error.OpNotImplemented( + "HASHTABLE_LOOKUP requires int32 lookup and key tensors" + ) + if self._is_tflite_string_type(value_tensor.tensor.Type()): + raise tvm.error.OpNotImplemented( + "HASHTABLE_LOOKUP with TensorType.STRING values is not supported" + ) + if value_tensor.tensor.Type() != output_tensor.tensor.Type(): + raise tvm.error.OpNotImplemented( + "HASHTABLE_LOOKUP output dtype must match the value tensor dtype" + ) + if hits_tensor.tensor.Type() != TensorType.UINT8: + raise tvm.error.OpNotImplemented("HASHTABLE_LOOKUP hits output must be uint8") + + lookup_shape = to_int_list(self.get_tensor_shape(lookup_tensor)) + key_shape = to_int_list(self.get_tensor_shape(key_tensor)) + value_shape = to_int_list(self.get_tensor_shape(value_tensor)) + output_shape = to_int_list(self.get_tensor_shape(output_tensor)) + hits_shape = to_int_list(self.get_tensor_shape(hits_tensor)) + + if len(lookup_shape) != 1 or len(key_shape) != 1 or len(value_shape) < 1: + raise tvm.error.OpNotImplemented( + "HASHTABLE_LOOKUP requires rank-1 lookup/key and rank>=1 value tensors" + ) + if key_shape[0] != value_shape[0]: + raise tvm.error.OpNotImplemented( + "HASHTABLE_LOOKUP requires key and value tensors to agree on row count" + ) + if key_shape[0] == 0: + raise tvm.error.OpNotImplemented( + "HASHTABLE_LOOKUP requires a non-empty key/value table" + ) + if output_shape != [lookup_shape[0]] + value_shape[1:]: + raise tvm.error.OpNotImplemented( + "HASHTABLE_LOOKUP output shape must match lookup count and value tail shape" + ) + if hits_shape != [lookup_shape[0]]: + raise tvm.error.OpNotImplemented( + "HASHTABLE_LOOKUP hits output shape must match lookup count" + ) + + lookup = self.get_tensor_expr(lookup_tensor) + key = self.get_tensor_expr(key_tensor) + value = self.get_tensor_expr(value_tensor) + + positions = relax.op.bucketize(lookup, key, out_int32=True, right=False) + candidate_keys = relax.op.take(key, positions, axis=0, mode="clip") + in_range = relax.op.less(positions, relax.const(key_shape[0], "int32")) + found = relax.op.logical_and(in_range, relax.op.equal(candidate_keys, lookup)) + + gathered_values = relax.op.take(value, positions, axis=0, mode="clip") + output_dtype = self.get_tensor_type_str(output_tensor.tensor.Type()) + zero_values = relax.op.zeros(output_shape, output_dtype) + + if len(value_shape) > 1: + found_values = relax.op.expand_dims(found, axis=list(range(1, len(value_shape)))) + found_values = relax.op.broadcast_to(found_values, output_shape) + else: + found_values = found + + output = relax.op.where(found_values, gathered_values, zero_values) + hits = relax.op.astype(found, "uint8") + return relax.Tuple([output, hits]) + def convert_hashtable_size(self, op): """Convert HASHTABLE_SIZE for a statically imported TFLite hashtable.""" input_tensors = self.get_input_tensors(op) diff --git a/tests/python/relax/test_frontend_tflite.py b/tests/python/relax/test_frontend_tflite.py index 7c3e526d99a2..c34da605de18 100644 --- a/tests/python/relax/test_frontend_tflite.py +++ b/tests/python/relax/test_frontend_tflite.py @@ -4053,6 +4053,18 @@ def _get_builtin_operator(builtin_name): return getattr(_tfl_builtin_operator, builtin_name) +def _run_module(mod, *inputs): + tgt = tvm.target.Target("c") + ex = tvm.compile(mod, tgt) + vm = relax.VirtualMachine(ex, tvm.cpu()) + vm.set_input("main", *inputs) + vm.invoke_stateful("main") + outputs = vm.get_outputs("main") + if hasattr(outputs, "numpy"): + return outputs.numpy() + return tuple(output.numpy() for output in outputs) + + def _build_tflite_call_model( call_subgraph_index=1, callee_inputs=None, @@ -5844,6 +5856,36 @@ def _build_tflite_hashtable_size_uninitialized_model(): ) +def _build_tflite_hashtable_lookup_model(*, value_shape, value_type=None): + """Build a model containing one HASHTABLE_LOOKUP operator.""" + builder = flatbuffers.Builder(1024) + + value_type = _tfl_tensor_type.FLOAT32 if value_type is None else value_type + + lookup_tensor = _build_tensor(builder, 0, [4], tensor_type=_tfl_tensor_type.INT32) + key_tensor = _build_tensor(builder, 1, [3], tensor_type=_tfl_tensor_type.INT32) + value_tensor = _build_tensor(builder, 2, value_shape, tensor_type=value_type) + output_tensor = _build_tensor(builder, 3, [4, *value_shape[1:]], tensor_type=value_type) + hits_tensor = _build_tensor(builder, 4, [4], tensor_type=_tfl_tensor_type.UINT8) + + hashtable_lookup = _build_operator(builder, 0, [0, 1, 2], [3, 4]) + main_subgraph = _build_subgraph( + builder, + tensors=[lookup_tensor, key_tensor, value_tensor, output_tensor, hits_tensor], + operators=[hashtable_lookup], + inputs=[0, 1, 2], + outputs=[3, 4], + ) + operator_codes = [_build_operator_code(builder, _get_builtin_operator("HASHTABLE_LOOKUP"))] + buffers = [_build_buffer(builder) for _ in range(5)] + return _finish_tflite_model( + builder, + subgraph=main_subgraph, + operator_codes=operator_codes, + buffers=buffers, + ) + + def test_resource_variable_call_once_init_read(): """Test reading a resource variable initialized by a supported CALL_ONCE subgraph.""" mod = _load_model_from_buffer(_build_tflite_resource_variable_model()) @@ -5908,6 +5950,53 @@ def test_hashtable_size_uninitialized_unsupported(): _load_model_from_buffer(_build_tflite_hashtable_size_uninitialized_model()) +def test_hashtable_lookup_1d_value(): + mod = _load_model_from_buffer(_build_tflite_hashtable_lookup_model(value_shape=[3])) + + output, hits = _run_module( + mod, + np.array([1234, -292, -11, 0], dtype=np.int32), + np.array([-11, 0, 1234], dtype=np.int32), + np.array([0.0, 0.1, 0.4], dtype=np.float32), + ) + + np.testing.assert_allclose(output, np.array([0.4, 0.0, 0.0, 0.1], dtype=np.float32)) + np.testing.assert_array_equal(hits, np.array([1, 0, 1, 1], dtype=np.uint8)) + + +def test_hashtable_lookup_2d_value(): + mod = _load_model_from_buffer(_build_tflite_hashtable_lookup_model(value_shape=[3, 2])) + + output, hits = _run_module( + mod, + np.array([1234, -292, -11, 0], dtype=np.int32), + np.array([-11, 0, 1234], dtype=np.int32), + np.array([[0.0, 0.1], [1.0, 1.1], [2.0, 2.1]], dtype=np.float32), + ) + + np.testing.assert_allclose( + output, + np.array( + [ + [2.0, 2.1], + [0.0, 0.0], + [0.0, 0.1], + [1.0, 1.1], + ], + dtype=np.float32, + ), + ) + np.testing.assert_array_equal(hits, np.array([1, 0, 1, 1], dtype=np.uint8)) + + +def test_hashtable_lookup_string_value_unsupported(): + string_type = _get_string_tensor_type() + with pytest.raises(ValueError, match="unknown dtype `string`"): + _load_model_from_buffer( + _build_tflite_hashtable_lookup_model(value_shape=[3], value_type=string_type) + ) + + def _get_stablehlo_builtin_operator(builtin_name): if not hasattr(_tfl_builtin_operator, builtin_name): pytest.skip(f"TFLite schema does not provide BuiltinOperator.{builtin_name}")