diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py b/python/tvm/relax/frontend/tflite/tflite_frontend.py index 2a4455eb30bb..fc3d61713dc6 100644 --- a/python/tvm/relax/frontend/tflite/tflite_frontend.py +++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py @@ -248,6 +248,7 @@ def __init__(self, model, subgraph, exp_tab, ctx, conversion_state=None): "HASHTABLE": self.convert_hashtable, "HASHTABLE_FIND": self.convert_hashtable_find, "HASHTABLE_IMPORT": self.convert_hashtable_import, + "HASHTABLE_LOOKUP": self.convert_hashtable_lookup, "HASHTABLE_SIZE": self.convert_hashtable_size, "IF": self.convert_if, "L2_NORMALIZATION": self.convert_l2_normalization, @@ -755,6 +756,88 @@ def convert_hashtable_find(self, op): "HASHTABLE_FIND requires TensorType.STRING support in Relax TFLite frontend" ) + def convert_hashtable_lookup(self, op): + """Convert TFLite HASHTABLE_LOOKUP for non-string value tensors.""" + from tflite.TensorType import TensorType + + input_tensors = self.get_input_tensors(op) + output_tensors = self.get_output_tensors(op) + if len(input_tensors) != 3 or len(output_tensors) != 2: + raise tvm.error.OpNotImplemented( + "HASHTABLE_LOOKUP expects lookup, key, and value inputs with two outputs" + ) + + lookup_tensor, key_tensor, value_tensor = input_tensors + output_tensor, hits_tensor = output_tensors + + if ( + lookup_tensor.tensor.Type() != TensorType.INT32 + or key_tensor.tensor.Type() != TensorType.INT32 + ): + raise tvm.error.OpNotImplemented( + "HASHTABLE_LOOKUP requires int32 lookup and key tensors" + ) + if self._is_tflite_string_type(value_tensor.tensor.Type()): + raise tvm.error.OpNotImplemented( + "HASHTABLE_LOOKUP with TensorType.STRING values is not supported" + ) + if value_tensor.tensor.Type() != output_tensor.tensor.Type(): + raise tvm.error.OpNotImplemented( + "HASHTABLE_LOOKUP output dtype must match the value tensor dtype" + ) + if hits_tensor.tensor.Type() != TensorType.UINT8: + raise tvm.error.OpNotImplemented("HASHTABLE_LOOKUP hits output must be uint8") + + lookup_shape = to_int_list(self.get_tensor_shape(lookup_tensor)) + key_shape = to_int_list(self.get_tensor_shape(key_tensor)) + value_shape = to_int_list(self.get_tensor_shape(value_tensor)) + output_shape = to_int_list(self.get_tensor_shape(output_tensor)) + hits_shape = to_int_list(self.get_tensor_shape(hits_tensor)) + + if len(lookup_shape) != 1 or len(key_shape) != 1 or len(value_shape) < 1: + raise tvm.error.OpNotImplemented( + "HASHTABLE_LOOKUP requires rank-1 lookup/key and rank>=1 value tensors" + ) + if key_shape[0] != value_shape[0]: + raise tvm.error.OpNotImplemented( + "HASHTABLE_LOOKUP requires key and value tensors to agree on row count" + ) + if key_shape[0] == 0: + raise tvm.error.OpNotImplemented( + "HASHTABLE_LOOKUP requires a non-empty key/value table" + ) + if output_shape != [lookup_shape[0]] + value_shape[1:]: + raise tvm.error.OpNotImplemented( + "HASHTABLE_LOOKUP output shape must match lookup count and value tail shape" + ) + if hits_shape != [lookup_shape[0]]: + raise tvm.error.OpNotImplemented( + "HASHTABLE_LOOKUP hits output shape must match lookup count" + ) + + lookup = self.get_tensor_expr(lookup_tensor) + key = self.get_tensor_expr(key_tensor) + value = self.get_tensor_expr(value_tensor) + + positions = relax.op.bucketize(lookup, key, out_int32=True, right=False) + candidate_keys = relax.op.take(key, positions, axis=0, mode="clip") + in_range = relax.op.less(positions, relax.const(key_shape[0], "int32")) + found = relax.op.logical_and(in_range, relax.op.equal(candidate_keys, lookup)) + + gathered_values = relax.op.take(value, positions, axis=0, mode="clip") + output_dtype = self.get_tensor_type_str(output_tensor.tensor.Type()) + zero_values = relax.op.zeros(output_shape, output_dtype) + + if len(value_shape) > 1: + found_values = relax.op.expand_dims(found, axis=list(range(1, len(value_shape)))) + found_values = relax.op.broadcast_to(found_values, output_shape) + else: + found_values = found + + output = relax.op.where(found_values, gathered_values, zero_values) + hits = relax.op.astype(found, "uint8") + return relax.Tuple([output, hits]) + def convert_hashtable_size(self, op): """Convert HASHTABLE_SIZE for a statically imported TFLite hashtable.""" input_tensors = self.get_input_tensors(op) diff --git a/tests/python/relax/test_frontend_tflite.py b/tests/python/relax/test_frontend_tflite.py index 7c3e526d99a2..c34da605de18 100644 --- a/tests/python/relax/test_frontend_tflite.py +++ b/tests/python/relax/test_frontend_tflite.py @@ -4053,6 +4053,18 @@ def _get_builtin_operator(builtin_name): return getattr(_tfl_builtin_operator, builtin_name) +def _run_module(mod, *inputs): + tgt = tvm.target.Target("c") + ex = tvm.compile(mod, tgt) + vm = relax.VirtualMachine(ex, tvm.cpu()) + vm.set_input("main", *inputs) + vm.invoke_stateful("main") + outputs = vm.get_outputs("main") + if hasattr(outputs, "numpy"): + return outputs.numpy() + return tuple(output.numpy() for output in outputs) + + def _build_tflite_call_model( call_subgraph_index=1, callee_inputs=None, @@ -5844,6 +5856,36 @@ def _build_tflite_hashtable_size_uninitialized_model(): ) +def _build_tflite_hashtable_lookup_model(*, value_shape, value_type=None): + """Build a model containing one HASHTABLE_LOOKUP operator.""" + builder = flatbuffers.Builder(1024) + + value_type = _tfl_tensor_type.FLOAT32 if value_type is None else value_type + + lookup_tensor = _build_tensor(builder, 0, [4], tensor_type=_tfl_tensor_type.INT32) + key_tensor = _build_tensor(builder, 1, [3], tensor_type=_tfl_tensor_type.INT32) + value_tensor = _build_tensor(builder, 2, value_shape, tensor_type=value_type) + output_tensor = _build_tensor(builder, 3, [4, *value_shape[1:]], tensor_type=value_type) + hits_tensor = _build_tensor(builder, 4, [4], tensor_type=_tfl_tensor_type.UINT8) + + hashtable_lookup = _build_operator(builder, 0, [0, 1, 2], [3, 4]) + main_subgraph = _build_subgraph( + builder, + tensors=[lookup_tensor, key_tensor, value_tensor, output_tensor, hits_tensor], + operators=[hashtable_lookup], + inputs=[0, 1, 2], + outputs=[3, 4], + ) + operator_codes = [_build_operator_code(builder, _get_builtin_operator("HASHTABLE_LOOKUP"))] + buffers = [_build_buffer(builder) for _ in range(5)] + return _finish_tflite_model( + builder, + subgraph=main_subgraph, + operator_codes=operator_codes, + buffers=buffers, + ) + + def test_resource_variable_call_once_init_read(): """Test reading a resource variable initialized by a supported CALL_ONCE subgraph.""" mod = _load_model_from_buffer(_build_tflite_resource_variable_model()) @@ -5908,6 +5950,53 @@ def test_hashtable_size_uninitialized_unsupported(): _load_model_from_buffer(_build_tflite_hashtable_size_uninitialized_model()) +def test_hashtable_lookup_1d_value(): + mod = _load_model_from_buffer(_build_tflite_hashtable_lookup_model(value_shape=[3])) + + output, hits = _run_module( + mod, + np.array([1234, -292, -11, 0], dtype=np.int32), + np.array([-11, 0, 1234], dtype=np.int32), + np.array([0.0, 0.1, 0.4], dtype=np.float32), + ) + + np.testing.assert_allclose(output, np.array([0.4, 0.0, 0.0, 0.1], dtype=np.float32)) + np.testing.assert_array_equal(hits, np.array([1, 0, 1, 1], dtype=np.uint8)) + + +def test_hashtable_lookup_2d_value(): + mod = _load_model_from_buffer(_build_tflite_hashtable_lookup_model(value_shape=[3, 2])) + + output, hits = _run_module( + mod, + np.array([1234, -292, -11, 0], dtype=np.int32), + np.array([-11, 0, 1234], dtype=np.int32), + np.array([[0.0, 0.1], [1.0, 1.1], [2.0, 2.1]], dtype=np.float32), + ) + + np.testing.assert_allclose( + output, + np.array( + [ + [2.0, 2.1], + [0.0, 0.0], + [0.0, 0.1], + [1.0, 1.1], + ], + dtype=np.float32, + ), + ) + np.testing.assert_array_equal(hits, np.array([1, 0, 1, 1], dtype=np.uint8)) + + +def test_hashtable_lookup_string_value_unsupported(): + string_type = _get_string_tensor_type() + with pytest.raises(ValueError, match="unknown dtype `string`"): + _load_model_from_buffer( + _build_tflite_hashtable_lookup_model(value_shape=[3], value_type=string_type) + ) + + def _get_stablehlo_builtin_operator(builtin_name): if not hasattr(_tfl_builtin_operator, builtin_name): pytest.skip(f"TFLite schema does not provide BuiltinOperator.{builtin_name}")