[Relax][Frontend][TFLite] Add HASHTABLE_LOOKUP converter #19654
LudovicoYIN wants to merge 1 commit
Add Relax TFLite frontend support for HASHTABLE_LOOKUP. This change implements HASHTABLE_LOOKUP for non-string value tensors. The converter lowers the lookup through bucketize, take, and where so that missing keys return zero-filled values and a uint8 hits mask that matches TFLite semantics for the supported cases. The patch also adds handcrafted TFLite frontend tests covering 1D and 2D float value tensors, along with the current unsupported string-value case.
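As a point of comparison for the lowering, the behavior described above (zero-filled rows for missing keys plus a uint8 hits mask) can be written as a plain dictionary lookup. This is a minimal NumPy sketch under that reading of the description; `hashtable_lookup_reference` is a hypothetical helper, not code from the patch, and it makes no claim about how the TFLite kernel itself searches the key tensor.

```python
import numpy as np


def hashtable_lookup_reference(lookup, keys, values):
    """Hypothetical reference: dictionary lookup with zero fill and a uint8 hits mask."""
    # Map each key to its row index in `values` (a later duplicate key wins here).
    index_of = {int(k): i for i, k in enumerate(keys)}
    # One output row per lookup entry; rows stay zero when the key is missing.
    output = np.zeros((lookup.shape[0],) + values.shape[1:], dtype=values.dtype)
    hits = np.zeros(lookup.shape[0], dtype=np.uint8)
    for row, item in enumerate(lookup):
        idx = index_of.get(int(item))
        if idx is not None:
            output[row] = values[idx]
            hits[row] = 1
    return output, hits


lookup = np.array([1234, -292, -11, 0], dtype=np.int32)
keys = np.array([-11, 0, 1234], dtype=np.int32)
values = np.array([[0.0, 0.1], [1.0, 1.1], [2.0, 2.1]], dtype=np.float32)

output, hits = hashtable_lookup_reference(lookup, keys, values)
print(output)
print(hits)
```

The 2D value tensor here is made up for illustration; the key and lookup tensors reuse the values from the PR's test.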
Code Review
This pull request adds support for the HASHTABLE_LOOKUP operator in the Relax TFLite frontend, along with corresponding unit tests. The review feedback highlights a critical issue where the implementation uses relax.op.bucketize, which incorrectly assumes that the lookup keys are sorted. To correctly support unsorted keys, it is recommended to use broadcasted comparison operators instead and update the test cases to use unsorted keys.
    positions = relax.op.bucketize(lookup, key, out_int32=True, right=False)
    candidate_keys = relax.op.take(key, positions, axis=0, mode="clip")
    in_range = relax.op.less(positions, relax.const(key_shape[0], "int32"))
    found = relax.op.logical_and(in_range, relax.op.equal(candidate_keys, lookup))
The current implementation uses relax.op.bucketize to find the lookup positions. However, bucketize assumes that the boundaries (the key tensor) are sorted. In TFLite, HASHTABLE_LOOKUP keys are not guaranteed to be sorted. If the keys are unsorted, bucketize will perform an incorrect binary search and return wrong indices.
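The failure mode can be reproduced outside TVM with a plain binary search. The sketch below is an illustration only: it uses Python's `bisect_left` as a stand-in for `bucketize(..., right=False)` (an assumption about equivalent boundary handling), and `lookup_via_binary_search` is a hypothetical helper that handles 1D values on plain lists.

```python
from bisect import bisect_left


def lookup_via_binary_search(lookup, keys, values):
    """Hypothetical emulation of the bucketize/take/where lowering for 1D values."""
    output, hits = [], []
    for item in lookup:
        # Stand-in for bucketize(..., right=False): binary search over `keys`.
        position = bisect_left(keys, item)
        # Stand-in for take(..., mode="clip"): clamp the index into range.
        clipped = min(position, len(keys) - 1)
        found = position < len(keys) and keys[clipped] == item
        # Stand-in for where(found, value, 0).
        output.append(values[clipped] if found else 0.0)
        hits.append(1 if found else 0)
    return output, hits


lookup = [1234, -292, -11, 0]

# Sorted keys: the binary search finds every key that is present.
sorted_out, sorted_hits = lookup_via_binary_search(
    lookup, [-11, 0, 1234], [0.0, 0.1, 0.4]
)
# Same table with the rows permuted: the key -11 is present but not found.
unsorted_out, unsorted_hits = lookup_via_binary_search(
    lookup, [0, 1234, -11], [0.1, 0.4, 0.0]
)

print(sorted_hits)    # [1, 0, 1, 1]
print(unsorted_hits)  # [1, 0, 0, 1]
```

With this particular data the output values are identical in both runs, because the value stored for -11 happens to be 0.0; only the hits mask exposes the missed key.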
To support unsorted keys correctly, we can use a broadcasted comparison to find the matching indices using relax.op.equal, relax.op.argmax, and relax.op.max.
-positions = relax.op.bucketize(lookup, key, out_int32=True, right=False)
-candidate_keys = relax.op.take(key, positions, axis=0, mode="clip")
-in_range = relax.op.less(positions, relax.const(key_shape[0], "int32"))
-found = relax.op.logical_and(in_range, relax.op.equal(candidate_keys, lookup))
+lookup_expanded = relax.op.expand_dims(lookup, axis=1)
+key_expanded = relax.op.expand_dims(key, axis=0)
+match_matrix = relax.op.equal(lookup_expanded, key_expanded)
+match_matrix_int = relax.op.astype(match_matrix, "int32")
+positions = relax.op.astype(relax.op.argmax(match_matrix_int, axis=1), "int32")
+found = relax.op.astype(relax.op.max(match_matrix_int, axis=1), "bool")
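The same broadcasted comparison can be checked in NumPy before wiring it into the converter. This is a rough sketch that mirrors the suggested ops one to one (`expand_dims`, `equal`, `argmax`, `max`), followed by a gather and a zero fill that are assumed to match the converter's remaining `take` and `where` steps; `lookup_via_broadcast` is a hypothetical name.

```python
import numpy as np


def lookup_via_broadcast(lookup, keys, values):
    """Hypothetical NumPy mirror of the suggested equal/argmax/max lowering."""
    # (num_lookups, 1) == (1, num_keys) -> (num_lookups, num_keys) match matrix.
    match = (lookup[:, None] == keys[None, :]).astype(np.int32)
    # Column of the first match per row; 0 when the row has no match at all.
    positions = match.argmax(axis=1)
    # A row is a hit when any column matched.
    found = match.max(axis=1).astype(bool)
    gathered = values[positions]
    # Reshape the mask so it broadcasts over trailing value dimensions.
    mask = found.reshape((-1,) + (1,) * (values.ndim - 1))
    output = np.where(mask, gathered, np.zeros_like(gathered))
    return output, found.astype(np.uint8)


lookup = np.array([1234, -292, -11, 0], dtype=np.int32)
keys = np.array([0, 1234, -11], dtype=np.int32)  # deliberately unsorted
values = np.array([0.1, 0.4, 0.0], dtype=np.float32)

output, hits = lookup_via_broadcast(lookup, keys, values)
print(output)
print(hits)
```

The trade-off is an intermediate of size num_lookups × num_keys instead of a logarithmic search per lookup, and the sketch assumes a non-empty key tensor, since `argmax` over an empty axis raises in NumPy.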
    output, hits = _run_module(
        mod,
        np.array([1234, -292, -11, 0], dtype=np.int32),
        np.array([-11, 0, 1234], dtype=np.int32),
        np.array([0.0, 0.1, 0.4], dtype=np.float32),
    )

    np.testing.assert_allclose(output, np.array([0.4, 0.0, 0.0, 0.1], dtype=np.float32))
    np.testing.assert_array_equal(hits, np.array([1, 0, 1, 1], dtype=np.uint8))
Since TFLite hashtable keys are not guaranteed to be sorted, we should update the test case to use unsorted keys to ensure correctness and prevent regressions.
-output, hits = _run_module(
-    mod,
-    np.array([1234, -292, -11, 0], dtype=np.int32),
-    np.array([-11, 0, 1234], dtype=np.int32),
-    np.array([0.0, 0.1, 0.4], dtype=np.float32),
-)
-np.testing.assert_allclose(output, np.array([0.4, 0.0, 0.0, 0.1], dtype=np.float32))
-np.testing.assert_array_equal(hits, np.array([1, 0, 1, 1], dtype=np.uint8))
+output, hits = _run_module(
+    mod,
+    np.array([1234, -292, -11, 0], dtype=np.int32),
+    np.array([0, 1234, -11], dtype=np.int32),
+    np.array([0.1, 0.4, 0.0], dtype=np.float32),
+)
+np.testing.assert_allclose(output, np.array([0.4, 0.0, 0.0, 0.1], dtype=np.float32))
+np.testing.assert_array_equal(hits, np.array([1, 0, 1, 1], dtype=np.uint8))
Summary
Add Relax TFLite frontend support for HASHTABLE_LOOKUP.
This PR adds a converter for HASHTABLE_LOOKUP in the Relax TFLite frontend. The implementation supports non-string value tensors and lowers the lookup through bucketize, take, and where so that missing keys return zero-filled values together with a uint8 hits mask matching TFLite semantics for the supported cases.
The PR also adds handcrafted TFLite frontend tests covering 1D and 2D float value tensors, along with the currently unsupported string-value case.
Testing
Ran tests/python/relax/test_frontend_tflite.py -k 'hashtable_lookup'.
Part of #19519