Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 83 additions & 0 deletions python/tvm/relax/frontend/tflite/tflite_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,7 @@ def __init__(self, model, subgraph, exp_tab, ctx, conversion_state=None):
"HASHTABLE": self.convert_hashtable,
"HASHTABLE_FIND": self.convert_hashtable_find,
"HASHTABLE_IMPORT": self.convert_hashtable_import,
"HASHTABLE_LOOKUP": self.convert_hashtable_lookup,
"HASHTABLE_SIZE": self.convert_hashtable_size,
"IF": self.convert_if,
"L2_NORMALIZATION": self.convert_l2_normalization,
Expand Down Expand Up @@ -755,6 +756,88 @@ def convert_hashtable_find(self, op):
"HASHTABLE_FIND requires TensorType.STRING support in Relax TFLite frontend"
)

def convert_hashtable_lookup(self, op):
"""Convert TFLite HASHTABLE_LOOKUP for non-string value tensors."""
from tflite.TensorType import TensorType

input_tensors = self.get_input_tensors(op)
output_tensors = self.get_output_tensors(op)
if len(input_tensors) != 3 or len(output_tensors) != 2:
raise tvm.error.OpNotImplemented(
"HASHTABLE_LOOKUP expects lookup, key, and value inputs with two outputs"
)

lookup_tensor, key_tensor, value_tensor = input_tensors
output_tensor, hits_tensor = output_tensors

if (
lookup_tensor.tensor.Type() != TensorType.INT32
or key_tensor.tensor.Type() != TensorType.INT32
):
raise tvm.error.OpNotImplemented(
"HASHTABLE_LOOKUP requires int32 lookup and key tensors"
)
if self._is_tflite_string_type(value_tensor.tensor.Type()):
raise tvm.error.OpNotImplemented(
"HASHTABLE_LOOKUP with TensorType.STRING values is not supported"
)
if value_tensor.tensor.Type() != output_tensor.tensor.Type():
raise tvm.error.OpNotImplemented(
"HASHTABLE_LOOKUP output dtype must match the value tensor dtype"
)
if hits_tensor.tensor.Type() != TensorType.UINT8:
raise tvm.error.OpNotImplemented("HASHTABLE_LOOKUP hits output must be uint8")

lookup_shape = to_int_list(self.get_tensor_shape(lookup_tensor))
key_shape = to_int_list(self.get_tensor_shape(key_tensor))
value_shape = to_int_list(self.get_tensor_shape(value_tensor))
output_shape = to_int_list(self.get_tensor_shape(output_tensor))
hits_shape = to_int_list(self.get_tensor_shape(hits_tensor))

if len(lookup_shape) != 1 or len(key_shape) != 1 or len(value_shape) < 1:
raise tvm.error.OpNotImplemented(
"HASHTABLE_LOOKUP requires rank-1 lookup/key and rank>=1 value tensors"
)
if key_shape[0] != value_shape[0]:
raise tvm.error.OpNotImplemented(
"HASHTABLE_LOOKUP requires key and value tensors to agree on row count"
)
if key_shape[0] == 0:
raise tvm.error.OpNotImplemented(
"HASHTABLE_LOOKUP requires a non-empty key/value table"
)
if output_shape != [lookup_shape[0]] + value_shape[1:]:
raise tvm.error.OpNotImplemented(
"HASHTABLE_LOOKUP output shape must match lookup count and value tail shape"
)
if hits_shape != [lookup_shape[0]]:
raise tvm.error.OpNotImplemented(
"HASHTABLE_LOOKUP hits output shape must match lookup count"
)

lookup = self.get_tensor_expr(lookup_tensor)
key = self.get_tensor_expr(key_tensor)
value = self.get_tensor_expr(value_tensor)

positions = relax.op.bucketize(lookup, key, out_int32=True, right=False)
candidate_keys = relax.op.take(key, positions, axis=0, mode="clip")
in_range = relax.op.less(positions, relax.const(key_shape[0], "int32"))
found = relax.op.logical_and(in_range, relax.op.equal(candidate_keys, lookup))
Comment on lines +822 to +825
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The current implementation uses relax.op.bucketize to find the lookup positions. However, bucketize assumes that the boundaries (the key tensor) are sorted. In TFLite, HASHTABLE_LOOKUP keys are not guaranteed to be sorted. If the keys are unsorted, bucketize will perform an incorrect binary search and return wrong indices.

To support unsorted keys correctly, we can use a broadcasted comparison to find the matching indices using relax.op.equal, relax.op.argmax, and relax.op.max.

Suggested change
positions = relax.op.bucketize(lookup, key, out_int32=True, right=False)
candidate_keys = relax.op.take(key, positions, axis=0, mode="clip")
in_range = relax.op.less(positions, relax.const(key_shape[0], "int32"))
found = relax.op.logical_and(in_range, relax.op.equal(candidate_keys, lookup))
lookup_expanded = relax.op.expand_dims(lookup, axis=1)
key_expanded = relax.op.expand_dims(key, axis=0)
match_matrix = relax.op.equal(lookup_expanded, key_expanded)
match_matrix_int = relax.op.astype(match_matrix, "int32")
positions = relax.op.astype(relax.op.argmax(match_matrix_int, axis=1), "int32")
found = relax.op.astype(relax.op.max(match_matrix_int, axis=1), "bool")


gathered_values = relax.op.take(value, positions, axis=0, mode="clip")
output_dtype = self.get_tensor_type_str(output_tensor.tensor.Type())
zero_values = relax.op.zeros(output_shape, output_dtype)

if len(value_shape) > 1:
found_values = relax.op.expand_dims(found, axis=list(range(1, len(value_shape))))
found_values = relax.op.broadcast_to(found_values, output_shape)
else:
found_values = found

output = relax.op.where(found_values, gathered_values, zero_values)
hits = relax.op.astype(found, "uint8")
return relax.Tuple([output, hits])

def convert_hashtable_size(self, op):
"""Convert HASHTABLE_SIZE for a statically imported TFLite hashtable."""
input_tensors = self.get_input_tensors(op)
Expand Down
89 changes: 89 additions & 0 deletions tests/python/relax/test_frontend_tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -4053,6 +4053,18 @@ def _get_builtin_operator(builtin_name):
return getattr(_tfl_builtin_operator, builtin_name)


def _run_module(mod, *inputs):
tgt = tvm.target.Target("c")
ex = tvm.compile(mod, tgt)
vm = relax.VirtualMachine(ex, tvm.cpu())
vm.set_input("main", *inputs)
vm.invoke_stateful("main")
outputs = vm.get_outputs("main")
if hasattr(outputs, "numpy"):
return outputs.numpy()
return tuple(output.numpy() for output in outputs)


def _build_tflite_call_model(
call_subgraph_index=1,
callee_inputs=None,
Expand Down Expand Up @@ -5844,6 +5856,36 @@ def _build_tflite_hashtable_size_uninitialized_model():
)


def _build_tflite_hashtable_lookup_model(*, value_shape, value_type=None):
"""Build a model containing one HASHTABLE_LOOKUP operator."""
builder = flatbuffers.Builder(1024)

value_type = _tfl_tensor_type.FLOAT32 if value_type is None else value_type

lookup_tensor = _build_tensor(builder, 0, [4], tensor_type=_tfl_tensor_type.INT32)
key_tensor = _build_tensor(builder, 1, [3], tensor_type=_tfl_tensor_type.INT32)
value_tensor = _build_tensor(builder, 2, value_shape, tensor_type=value_type)
output_tensor = _build_tensor(builder, 3, [4, *value_shape[1:]], tensor_type=value_type)
hits_tensor = _build_tensor(builder, 4, [4], tensor_type=_tfl_tensor_type.UINT8)

hashtable_lookup = _build_operator(builder, 0, [0, 1, 2], [3, 4])
main_subgraph = _build_subgraph(
builder,
tensors=[lookup_tensor, key_tensor, value_tensor, output_tensor, hits_tensor],
operators=[hashtable_lookup],
inputs=[0, 1, 2],
outputs=[3, 4],
)
operator_codes = [_build_operator_code(builder, _get_builtin_operator("HASHTABLE_LOOKUP"))]
buffers = [_build_buffer(builder) for _ in range(5)]
return _finish_tflite_model(
builder,
subgraph=main_subgraph,
operator_codes=operator_codes,
buffers=buffers,
)


def test_resource_variable_call_once_init_read():
"""Test reading a resource variable initialized by a supported CALL_ONCE subgraph."""
mod = _load_model_from_buffer(_build_tflite_resource_variable_model())
Expand Down Expand Up @@ -5908,6 +5950,53 @@ def test_hashtable_size_uninitialized_unsupported():
_load_model_from_buffer(_build_tflite_hashtable_size_uninitialized_model())


def test_hashtable_lookup_1d_value():
mod = _load_model_from_buffer(_build_tflite_hashtable_lookup_model(value_shape=[3]))

output, hits = _run_module(
mod,
np.array([1234, -292, -11, 0], dtype=np.int32),
np.array([-11, 0, 1234], dtype=np.int32),
np.array([0.0, 0.1, 0.4], dtype=np.float32),
)

np.testing.assert_allclose(output, np.array([0.4, 0.0, 0.0, 0.1], dtype=np.float32))
np.testing.assert_array_equal(hits, np.array([1, 0, 1, 1], dtype=np.uint8))
Comment on lines +5956 to +5964
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Since TFLite hashtable keys are not guaranteed to be sorted, we should update the test case to use unsorted keys to ensure correctness and prevent regressions.

Suggested change
output, hits = _run_module(
mod,
np.array([1234, -292, -11, 0], dtype=np.int32),
np.array([-11, 0, 1234], dtype=np.int32),
np.array([0.0, 0.1, 0.4], dtype=np.float32),
)
np.testing.assert_allclose(output, np.array([0.4, 0.0, 0.0, 0.1], dtype=np.float32))
np.testing.assert_array_equal(hits, np.array([1, 0, 1, 1], dtype=np.uint8))
output, hits = _run_module(
mod,
np.array([1234, -292, -11, 0], dtype=np.int32),
np.array([0, 1234, -11], dtype=np.int32),
np.array([0.1, 0.4, 0.0], dtype=np.float32),
)
np.testing.assert_allclose(output, np.array([0.4, 0.0, 0.0, 0.1], dtype=np.float32))
np.testing.assert_array_equal(hits, np.array([1, 0, 1, 1], dtype=np.uint8))



def test_hashtable_lookup_2d_value():
mod = _load_model_from_buffer(_build_tflite_hashtable_lookup_model(value_shape=[3, 2]))

output, hits = _run_module(
mod,
np.array([1234, -292, -11, 0], dtype=np.int32),
np.array([-11, 0, 1234], dtype=np.int32),
np.array([[0.0, 0.1], [1.0, 1.1], [2.0, 2.1]], dtype=np.float32),
)

np.testing.assert_allclose(
output,
np.array(
[
[2.0, 2.1],
[0.0, 0.0],
[0.0, 0.1],
[1.0, 1.1],
],
dtype=np.float32,
),
)
np.testing.assert_array_equal(hits, np.array([1, 0, 1, 1], dtype=np.uint8))


def test_hashtable_lookup_string_value_unsupported():
string_type = _get_string_tensor_type()
with pytest.raises(ValueError, match="unknown dtype `string`"):
_load_model_from_buffer(
_build_tflite_hashtable_lookup_model(value_shape=[3], value_type=string_type)
)


def _get_stablehlo_builtin_operator(builtin_name):
if not hasattr(_tfl_builtin_operator, builtin_name):
pytest.skip(f"TFLite schema does not provide BuiltinOperator.{builtin_name}")
Expand Down
Loading