Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 115 additions & 0 deletions python/tvm/relax/frontend/tflite/tflite_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,7 @@ def __init__(self, model, subgraph, exp_tab, ctx, conversion_state=None):
"DIV": functools.partial(self._convert_elemwise, relax_op=_op.divide),
"ELU": self.convert_elu,
"EMBEDDING_LOOKUP": self.convert_embedding_lookup,
"EMBEDDING_LOOKUP_SPARSE": self.convert_embedding_lookup_sparse,
"EQUAL": functools.partial(
self._convert_elemwise, relax_op=_op.equal, comparison_op=True
),
Expand Down Expand Up @@ -6183,6 +6184,120 @@ def convert_embedding_lookup(self, op):
indices = self.get_tensor_expr(indices_tensor)
return relax.op.take(params, indices, axis=0)

def convert_embedding_lookup_sparse(self, op):
    """Convert TFLite EMBEDDING_LOOKUP_SPARSE to Relax take/scatter_nd ops.

    The operator takes five inputs, in order: ``ids``, ``indices``,
    ``dense_shape``, ``weights`` and ``params``. Rows of ``params`` are
    gathered by ``ids``, scaled by ``weights`` and scatter-added into a
    zero tensor of shape ``dense_shape[:-1] + params.shape[1:]``. Every
    column of ``indices`` except the last selects the destination bucket.

    Combiner handling:

    * ``SUM``   - the scatter-added tensor is returned as is.
    * ``MEAN``  - each bucket is divided by the sum of its weights.
    * ``SQRTN`` - each bucket is divided by the square root of the sum of
      its squared weights.

    For ``MEAN`` and ``SQRTN`` a bucket whose denominator is not strictly
    positive yields zeros.

    Parameters
    ----------
    op :
        The TFLite operator being converted.

    Returns
    -------
    The Relax expression computing the operator output.

    Raises
    ------
    tvm.error.OpNotImplemented
        If ``dense_shape`` is only available as a runtime expression, if
        constant ``ids`` contain negative values, or if the combiner is
        not ``SUM``, ``MEAN`` or ``SQRTN``.
    AssertionError
        If tensor counts, dtypes, ranks or lookup counts are inconsistent,
        or an input is quantized.
    """
    from tflite.CombinerType import CombinerType
    from tflite.EmbeddingLookupSparseOptions import EmbeddingLookupSparseOptions
    from tflite.TensorType import TensorType

    def static_shape(wrapper):
        # Shape of a wrapped TFLite tensor as a list of Python ints.
        return to_int_list(self.get_tensor_shape(wrapper))

    inputs = self.get_input_tensors(op)
    assert len(inputs) == 5, "EMBEDDING_LOOKUP_SPARSE should have 5 input tensors"
    outputs = self.get_output_tensors(op)
    assert len(outputs) == 1, "EMBEDDING_LOOKUP_SPARSE should have 1 output tensor"

    ids_t, indices_t, dense_shape_t, weights_t, params_t = inputs
    out_t = outputs[0]

    for candidate in inputs:
        assert not candidate.qnn_params, "Quantized input is not expected."

    # Three int32 index-like inputs, float32 everywhere else.
    dtype_table = (
        (ids_t, TensorType.INT32),
        (indices_t, TensorType.INT32),
        (dense_shape_t, TensorType.INT32),
        (weights_t, TensorType.FLOAT32),
        (params_t, TensorType.FLOAT32),
        (out_t, TensorType.FLOAT32),
    )
    for wrapper, type_code in dtype_table:
        assert wrapper.tensor.Type() == type_code

    ids_dims = static_shape(ids_t)
    indices_dims = static_shape(indices_t)
    dense_shape_dims = static_shape(dense_shape_t)
    weights_dims = static_shape(weights_t)
    params_dims = static_shape(params_t)

    assert len(ids_dims) == 1, "EMBEDDING_LOOKUP_SPARSE ids must be rank 1"
    assert len(indices_dims) == 2, "EMBEDDING_LOOKUP_SPARSE indices must be rank 2"
    assert len(dense_shape_dims) == 1, "EMBEDDING_LOOKUP_SPARSE dense_shape must be rank 1"
    assert len(weights_dims) == 1, "EMBEDDING_LOOKUP_SPARSE weights must be rank 1"
    assert len(params_dims) >= 2, "EMBEDDING_LOOKUP_SPARSE params must be rank >= 2"
    assert indices_dims[0] == ids_dims[0], (
        "EMBEDDING_LOOKUP_SPARSE ids and indices must agree on lookup count"
    )
    assert weights_dims[0] == ids_dims[0], (
        "EMBEDDING_LOOKUP_SPARSE ids and weights must agree on lookup count"
    )

    # The output shape is derived from dense_shape at conversion time, so
    # its value has to be a constant rather than an already-bound expression.
    if self.has_expr(dense_shape_t.tensor_idx):
        raise tvm.error.OpNotImplemented(
            "TFLite EMBEDDING_LOOKUP_SPARSE with runtime dense_shape is not supported."
        )

    dense_dims = to_int_list(self.get_tensor_value(dense_shape_t))
    sparse_rank = indices_dims[1]
    assert len(dense_dims) == sparse_rank, (
        "EMBEDDING_LOOKUP_SPARSE dense_shape length must match indices width"
    )
    assert sparse_rank >= 1, "EMBEDDING_LOOKUP_SPARSE indices width must be positive"

    # Negative ids can only be detected here when ids are constant.
    ids_are_constant = not self.has_expr(ids_t.tensor_idx)
    if ids_are_constant and np.any(self.get_tensor_value(ids_t) < 0):
        raise tvm.error.OpNotImplemented(
            "TFLite EMBEDDING_LOOKUP_SPARSE with negative ids is not supported."
        )

    table_expr = self.get_tensor_expr(params_t)
    ids_expr = relax.op.astype(self.get_tensor_expr(ids_t), "int32")
    weights_expr = self.get_tensor_expr(weights_t)
    indices_expr = self.get_tensor_expr(indices_t)

    gathered = relax.op.take(table_expr, ids_expr, axis=0)

    row_shape = params_dims[1:]
    bucket_shape = dense_dims[:-1]
    result_shape = bucket_shape + row_shape
    unit_tail = [1] * len(row_shape)

    # Aggregation buckets are addressed by all sparse index columns but the last.
    bucket_coords = relax.op.strided_slice(
        indices_expr, axes=[1], begin=[0], end=[sparse_rank - 1]
    )

    # Reshape weights to (num_lookups, 1, ..., 1) so they broadcast over each row.
    column_weights = relax.op.reshape(weights_expr, [ids_dims[0]] + unit_tail)
    scaled = relax.op.multiply(gathered, column_weights)

    value_base = relax.const(np.zeros(result_shape, dtype=np.float32), "float32")
    bucket_sums = relax.op.scatter_nd(value_base, bucket_coords, scaled, "add")

    raw_options = op.BuiltinOptions()
    parsed_options = EmbeddingLookupSparseOptions()
    parsed_options.Init(raw_options.Bytes, raw_options.Pos)
    combiner = parsed_options.Combiner()
    if combiner == CombinerType.SUM:
        return bucket_sums

    norm_base = relax.const(np.zeros(bucket_shape, dtype=np.float32), "float32")
    use_sqrt = combiner == CombinerType.SQRTN
    if combiner == CombinerType.MEAN:
        norm_updates = weights_expr
    elif use_sqrt:
        norm_updates = relax.op.multiply(weights_expr, weights_expr)
    else:
        raise tvm.error.OpNotImplemented(
            f"Unsupported TFLite EMBEDDING_LOOKUP_SPARSE combiner value {combiner}"
        )

    # Per-bucket normaliser: sum of weights (MEAN) or sqrt of summed squares (SQRTN).
    denominator = relax.op.scatter_nd(norm_base, bucket_coords, norm_updates, "add")
    if use_sqrt:
        denominator = relax.op.sqrt(denominator)

    denominator = relax.op.reshape(denominator, bucket_shape + unit_tail)
    denominator = relax.op.broadcast_to(denominator, result_shape)

    # Clamp before dividing, then replace buckets with a non-positive
    # denominator by zeros.
    clamped = relax.op.maximum(denominator, relax.const(1e-12, "float32"))
    normalized = relax.op.divide(bucket_sums, clamped)
    is_positive = relax.op.greater(denominator, relax.const(0.0, "float32"))
    return relax.op.where(is_positive, normalized, value_base)

def convert_batch_matmul(self, op):
"""batch_matmul implementation."""

Expand Down
188 changes: 188 additions & 0 deletions tests/python/relax/test_frontend_tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -4037,6 +4037,17 @@ def _build_hashtable_options(
return hashtable_options.HashtableOptionsEnd(builder)


def _build_embedding_lookup_sparse_options(builder, combiner):
    """Write an EmbeddingLookupSparseOptions table into ``builder``.

    The table's combiner field is set to ``combiner``. The calling test is
    skipped when the schema module cannot be imported. Returns whatever the
    schema's ``EmbeddingLookupSparseOptionsEnd`` returns.
    """
    try:
        schema = _get_tflite_schema_module("EmbeddingLookupSparseOptions")
    except ModuleNotFoundError:
        pytest.skip("TFLite schema does not provide EmbeddingLookupSparseOptions")

    schema.EmbeddingLookupSparseOptionsStart(builder)
    schema.EmbeddingLookupSparseOptionsAddCombiner(builder, combiner)
    options_offset = schema.EmbeddingLookupSparseOptionsEnd(builder)
    return options_offset


def _load_model_from_buffer(model_bytes):
if hasattr(tflite.Model, "Model"):
tflite_model = tflite.Model.Model.GetRootAsModel(model_bytes, 0)
Expand All @@ -4053,6 +4064,15 @@ def _get_builtin_operator(builtin_name):
return getattr(_tfl_builtin_operator, builtin_name)


def _run_no_input_module(mod):
    """Compile ``mod``, run its ``main`` with no arguments, return a NumPy result.

    The module is compiled for the ``"c"`` target and executed on
    ``tvm.cpu()`` through the Relax virtual machine.
    """
    entry = "main"
    executable = tvm.compile(mod, tvm.target.Target("c"))
    machine = relax.VirtualMachine(executable, tvm.cpu())
    # No positional inputs are bound: the function is invoked argument-free.
    machine.set_input(entry)
    machine.invoke_stateful(entry)
    result = machine.get_outputs(entry)
    return result.numpy()


def _build_tflite_call_model(
call_subgraph_index=1,
callee_inputs=None,
Expand Down Expand Up @@ -5844,6 +5864,82 @@ def _build_tflite_hashtable_size_uninitialized_model():
)


def _build_tflite_embedding_lookup_sparse_model(combiner, indices_data, dense_shape_data):
    """Build a TFLite model with a single EMBEDDING_LOOKUP_SPARSE operator.

    All five operator inputs are backed by buffers filled here; the subgraph
    declares no inputs and one output (tensor 5). ``ids``, ``weights`` and
    ``params`` are fixed; ``indices_data`` and ``dense_shape_data`` come from
    the caller and are stored as int32. The output tensor shape is
    ``dense_shape_data[:-1] + params.shape[1:]``.

    Parameters
    ----------
    combiner :
        Combiner value stored in the operator's builtin options.
    indices_data :
        Array-like holding the sparse indices.
    dense_shape_data :
        Array-like holding the dense shape.

    Returns
    -------
    The result of ``_finish_tflite_model`` for the assembled model.
    """
    builder = flatbuffers.Builder(4096)

    ids_arr = np.array([1, 3, 0], dtype=np.int32)
    indices_arr = np.array(indices_data, dtype=np.int32)
    dense_shape_arr = np.array(dense_shape_data, dtype=np.int32)
    weights_arr = np.array([1.0, 2.0, 4.0], dtype=np.float32)
    params_arr = np.array(
        [
            [[0.00, 0.01], [0.10, 0.11], [0.20, 0.21]],
            [[1.00, 1.01], [1.10, 1.11], [1.20, 1.21]],
            [[2.00, 2.01], [2.10, 2.11], [2.20, 2.21]],
            [[3.00, 3.01], [3.10, 3.11], [3.20, 3.21]],
        ],
        dtype=np.float32,
    )

    result_shape = dense_shape_arr[:-1].tolist() + list(params_arr.shape[1:])
    options = _build_embedding_lookup_sparse_options(builder, combiner)

    int32_type = _tfl_tensor_type.INT32
    float32_type = _tfl_tensor_type.FLOAT32

    # Constant inputs in operator order; position doubles as tensor and buffer index.
    constant_inputs = [
        (ids_arr, int32_type),
        (indices_arr, int32_type),
        (dense_shape_arr, int32_type),
        (weights_arr, float32_type),
        (params_arr, float32_type),
    ]
    tensor_specs = [(list(arr.shape), dtype) for arr, dtype in constant_inputs]
    tensor_specs.append((result_shape, float32_type))

    tensors = [
        _build_tensor(builder, slot, shape, tensor_type=dtype)
        for slot, (shape, dtype) in enumerate(tensor_specs)
    ]
    output_slot = len(constant_inputs)

    operator = _build_operator(
        builder,
        0,
        list(range(len(constant_inputs))),
        [output_slot],
        builtin_options_type=_get_builtin_options_type("EmbeddingLookupSparseOptions"),
        builtin_options=options,
    )
    subgraph = _build_subgraph(
        builder,
        tensors=tensors,
        operators=[operator],
        inputs=[],
        outputs=[output_slot],
    )
    operator_codes = [
        _build_operator_code(builder, _get_builtin_operator("EMBEDDING_LOOKUP_SPARSE"))
    ]

    # One data buffer per constant input, plus a buffer built without data
    # for the output tensor.
    buffers = [_build_buffer(builder, arr.tobytes()) for arr, _ in constant_inputs]
    buffers.append(_build_buffer(builder))

    return _finish_tflite_model(
        builder,
        subgraph=subgraph,
        operator_codes=operator_codes,
        buffers=buffers,
    )


def test_resource_variable_call_once_init_read():
"""Test reading a resource variable initialized by a supported CALL_ONCE subgraph."""
mod = _load_model_from_buffer(_build_tflite_resource_variable_model())
Expand Down Expand Up @@ -5908,6 +6004,98 @@ def test_hashtable_size_uninitialized_unsupported():
_load_model_from_buffer(_build_tflite_hashtable_size_uninitialized_model())


def test_embedding_lookup_sparse_sum():
    """Check the SUM combiner of EMBEDDING_LOOKUP_SPARSE against literal values."""
    from tflite.CombinerType import CombinerType

    model_bytes = _build_tflite_embedding_lookup_sparse_model(
        CombinerType.SUM,
        indices_data=[[0, 0], [2, 0], [2, 1]],
        dense_shape_data=[3, 2],
    )
    actual = _run_no_input_module(_load_model_from_buffer(model_bytes))

    # One (3, 2) slab per leading output index; index 1 is expected to stay zero.
    slab_0 = [[1.00, 1.01], [1.10, 1.11], [1.20, 1.21]]
    slab_1 = [[0.00, 0.00], [0.00, 0.00], [0.00, 0.00]]
    slab_2 = [[6.00, 6.06], [6.60, 6.66], [7.20, 7.26]]
    reference = np.array([slab_0, slab_1, slab_2], dtype=np.float32)

    np.testing.assert_allclose(actual, reference, rtol=1e-5, atol=1e-5)


def test_embedding_lookup_sparse_mean():
    """Check the MEAN combiner of EMBEDDING_LOOKUP_SPARSE against literal values."""
    from tflite.CombinerType import CombinerType

    model_bytes = _build_tflite_embedding_lookup_sparse_model(
        CombinerType.MEAN,
        indices_data=[[0, 0], [2, 0], [2, 1]],
        dense_shape_data=[3, 2],
    )
    actual = _run_no_input_module(_load_model_from_buffer(model_bytes))

    # Leading indices 0 and 2 share the same expected slab; index 1 stays zero.
    filled_slab = [[1.00, 1.01], [1.10, 1.11], [1.20, 1.21]]
    zero_slab = [[0.00, 0.00], [0.00, 0.00], [0.00, 0.00]]
    reference = np.array([filled_slab, zero_slab, filled_slab], dtype=np.float32)

    np.testing.assert_allclose(actual, reference, rtol=1e-5, atol=1e-5)


def test_embedding_lookup_sparse_sqrtn():
    """Check the SQRTN combiner of EMBEDDING_LOOKUP_SPARSE against expected values."""
    from tflite.CombinerType import CombinerType

    model_bytes = _build_tflite_embedding_lookup_sparse_model(
        CombinerType.SQRTN,
        indices_data=[[0, 0], [2, 0], [2, 1]],
        dense_shape_data=[3, 2],
    )
    actual = _run_no_input_module(_load_model_from_buffer(model_bytes))

    # The last slab is the listed sums divided element by element by sqrt(20).
    scale = np.sqrt(20.0).astype("float32")
    unscaled_rows = ((6.00, 6.06), (6.60, 6.66), (7.20, 7.26))
    scaled_slab = [[value / scale for value in row] for row in unscaled_rows]

    reference = np.array(
        [
            [[1.00, 1.01], [1.10, 1.11], [1.20, 1.21]],
            [[0.00, 0.00], [0.00, 0.00], [0.00, 0.00]],
            scaled_slab,
        ],
        dtype=np.float32,
    )
    np.testing.assert_allclose(actual, reference, rtol=1e-5, atol=1e-5)


def test_embedding_lookup_sparse_indices_3d():
    """Check EMBEDDING_LOOKUP_SPARSE (SUM) with three-column sparse indices."""
    from tflite.CombinerType import CombinerType

    model_bytes = _build_tflite_embedding_lookup_sparse_model(
        CombinerType.SUM,
        indices_data=[[0, 0, 0], [2, 0, 0], [2, 0, 1]],
        dense_shape_data=[3, 2, 2],
    )
    actual = _run_no_input_module(_load_model_from_buffer(model_bytes))

    # Only two of the (3, 2) leading positions are expected to be non-zero.
    nonzero_slabs = {
        (0, 0): [[1.00, 1.01], [1.10, 1.11], [1.20, 1.21]],
        (2, 0): [[6.00, 6.06], [6.60, 6.66], [7.20, 7.26]],
    }
    reference = np.zeros((3, 2, 3, 2), dtype=np.float32)
    for position, rows in nonzero_slabs.items():
        reference[position] = np.array(rows, dtype=np.float32)

    np.testing.assert_allclose(actual, reference, rtol=1e-5, atol=1e-5)


def _get_stablehlo_builtin_operator(builtin_name):
if not hasattr(_tfl_builtin_operator, builtin_name):
pytest.skip(f"TFLite schema does not provide BuiltinOperator.{builtin_name}")
Expand Down
Loading