Skip to content

[Relax][Frontend][TFLite] Add HASHTABLE_LOOKUP converter#19654

Open
LudovicoYIN wants to merge 1 commit into
apache:mainfrom
LudovicoYIN:lukeyin/tflite-hashtable-lookup
Open

[Relax][Frontend][TFLite] Add HASHTABLE_LOOKUP converter#19654
LudovicoYIN wants to merge 1 commit into
apache:mainfrom
LudovicoYIN:lukeyin/tflite-hashtable-lookup

Conversation

@LudovicoYIN
Copy link
Copy Markdown
Contributor

Summary

Add Relax TFLite frontend support for HASHTABLE_LOOKUP.

This PR adds a converter for HASHTABLE_LOOKUP in the Relax TFLite frontend. The implementation supports non-string value tensors and lowers the lookup through bucketize, take, and where so that missing keys return zero-filled values together with a uint8 hits mask matching TFLite semantics for the supported cases.

The PR also adds handcrafted TFLite frontend tests covering:

  • 1D float value tensors
  • 2D float value tensors
  • the current unsupported string-value case

Testing

Ran tests/python/relax/test_frontend_tflite.py -k 'hashtable_lookup'.

Part of #19519

Add Relax TFLite frontend support for HASHTABLE_LOOKUP.

This change implements HASHTABLE_LOOKUP for non-string value tensors.
The converter lowers the lookup through bucketize, take, and where so
that missing keys return zero-filled values and a uint8 hits mask that
matches TFLite semantics for the supported cases.

The patch also adds handcrafted TFLite frontend tests covering 1D and
2D float value tensors, along with the current unsupported string-value
case.
Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request adds support for the HASHTABLE_LOOKUP operator in the Relax TFLite frontend, along with corresponding unit tests. The review feedback highlights a critical issue where the implementation uses relax.op.bucketize, which incorrectly assumes that the lookup keys are sorted. To correctly support unsorted keys, it is recommended to use broadcasted comparison operators instead and update the test cases to use unsorted keys.

Comment on lines +822 to +825
positions = relax.op.bucketize(lookup, key, out_int32=True, right=False)
candidate_keys = relax.op.take(key, positions, axis=0, mode="clip")
in_range = relax.op.less(positions, relax.const(key_shape[0], "int32"))
found = relax.op.logical_and(in_range, relax.op.equal(candidate_keys, lookup))
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The current implementation uses relax.op.bucketize to find the lookup positions. However, bucketize assumes that the boundaries (the key tensor) are sorted. In TFLite, HASHTABLE_LOOKUP keys are not guaranteed to be sorted. If the keys are unsorted, bucketize will perform an incorrect binary search and return wrong indices.

To support unsorted keys correctly, we can use a broadcasted comparison to find the matching indices using relax.op.equal, relax.op.argmax, and relax.op.max.

Suggested change
positions = relax.op.bucketize(lookup, key, out_int32=True, right=False)
candidate_keys = relax.op.take(key, positions, axis=0, mode="clip")
in_range = relax.op.less(positions, relax.const(key_shape[0], "int32"))
found = relax.op.logical_and(in_range, relax.op.equal(candidate_keys, lookup))
lookup_expanded = relax.op.expand_dims(lookup, axis=1)
key_expanded = relax.op.expand_dims(key, axis=0)
match_matrix = relax.op.equal(lookup_expanded, key_expanded)
match_matrix_int = relax.op.astype(match_matrix, "int32")
positions = relax.op.astype(relax.op.argmax(match_matrix_int, axis=1), "int32")
found = relax.op.astype(relax.op.max(match_matrix_int, axis=1), "bool")

Comment on lines +5956 to +5964
output, hits = _run_module(
mod,
np.array([1234, -292, -11, 0], dtype=np.int32),
np.array([-11, 0, 1234], dtype=np.int32),
np.array([0.0, 0.1, 0.4], dtype=np.float32),
)

np.testing.assert_allclose(output, np.array([0.4, 0.0, 0.0, 0.1], dtype=np.float32))
np.testing.assert_array_equal(hits, np.array([1, 0, 1, 1], dtype=np.uint8))
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Since TFLite hashtable keys are not guaranteed to be sorted, we should update the test case to use unsorted keys to ensure correctness and prevent regressions.

Suggested change
output, hits = _run_module(
mod,
np.array([1234, -292, -11, 0], dtype=np.int32),
np.array([-11, 0, 1234], dtype=np.int32),
np.array([0.0, 0.1, 0.4], dtype=np.float32),
)
np.testing.assert_allclose(output, np.array([0.4, 0.0, 0.0, 0.1], dtype=np.float32))
np.testing.assert_array_equal(hits, np.array([1, 0, 1, 1], dtype=np.uint8))
output, hits = _run_module(
mod,
np.array([1234, -292, -11, 0], dtype=np.int32),
np.array([0, 1234, -11], dtype=np.int32),
np.array([0.1, 0.4, 0.0], dtype=np.float32),
)
np.testing.assert_allclose(output, np.array([0.4, 0.0, 0.0, 0.1], dtype=np.float32))
np.testing.assert_array_equal(hits, np.array([1, 0, 1, 1], dtype=np.uint8))

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant