From 234c650f4e8166d0706f69608bd71f99adefe6ad Mon Sep 17 00:00:00 2001 From: cchung100m Date: Wed, 27 May 2026 23:48:40 +0800 Subject: [PATCH 01/11] [Relax][ONNX] Fix Cast operator float->int NaN/Inf handling --- python/tvm/relax/frontend/onnx/onnx_frontend.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py index 1a224e431ba4..28bb2e6e88f6 100644 --- a/python/tvm/relax/frontend/onnx/onnx_frontend.py +++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py @@ -1104,6 +1104,20 @@ def _impl_v13(cls, bb, inputs, attr, params): return relax.const(output, to_type) if isinstance(inputs[0], relax.PrimValue): return relax.PrimValue(inputs[0].value.astype(to_type)) + + try: + np_dst = _np.dtype(str(to_type)) + except Exception: + return relax.op.astype(inputs[0], to_type) + + if np_dst.kind in ("i", "u"): + src = inputs[0] + src_dtype = getattr(getattr(src, "struct_info", None), "dtype", None) or getattr(src, "dtype", None) + if src_dtype is not None and _relax_dtype_is_floating_point(src_dtype): + bad = relax.op.logical_or(relax.op.isnan(src), relax.op.isinf(src)) + x = bb.emit(relax.op.where(bad, relax.const(0.0, src_dtype), src)) + return relax.op.astype(x, to_type) + return relax.op.astype(inputs[0], to_type) From 1329d38c1176f41a93eda3f15e651b2f657f542f Mon Sep 17 00:00:00 2001 From: cchung100m Date: Thu, 28 May 2026 08:16:51 +0800 Subject: [PATCH 02/11] Fix lint error --- python/tvm/relax/frontend/onnx/onnx_frontend.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py index 28bb2e6e88f6..d01a87f5f5fd 100644 --- a/python/tvm/relax/frontend/onnx/onnx_frontend.py +++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py @@ -1112,9 +1112,11 @@ def _impl_v13(cls, bb, inputs, attr, params): if np_dst.kind in ("i", "u"): src = inputs[0] - src_dtype = getattr(getattr(src, "struct_info", None), "dtype", None) or getattr(src, "dtype", None) + src_dtype = getattr(getattr(src, "struct_info", None), "dtype", None) or getattr( + src, "dtype", None + ) if src_dtype is not None and _relax_dtype_is_floating_point(src_dtype): - bad = relax.op.logical_or(relax.op.isnan(src), relax.op.isinf(src)) + bad = relax.op.logical_not(relax.op.isfinite(src)) x = bb.emit(relax.op.where(bad, relax.const(0.0, src_dtype), src)) return relax.op.astype(x, to_type) From ec3848462a4ee2a378175e4d6becff85d2ce55fa Mon Sep 17 00:00:00 2001 From: cchung100m Date: Thu, 28 May 2026 20:00:47 +0800 Subject: [PATCH 03/11] Add test case: test_cast_nan_inf_to_int8 --- tests/python/relax/test_frontend_onnx.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py index 427881243663..00c336ebeac9 100644 --- a/tests/python/relax/test_frontend_onnx.py +++ b/tests/python/relax/test_frontend_onnx.py @@ -863,6 +863,23 @@ def test_cast(from_type, to_type): check_correctness(model, opset=13) +def test_cast_nan_inf_to_int8(): + vals = np.array([300.0, np.nan, np.PINF, np.NINF, 50.0, -50.0], dtype=np.float32) + node = helper.make_node("Cast", inputs=["a"], outputs=["b"], to=TensorProto.INT8) + graph = helper.make_graph( + [node], + "cast_nan_inf_test", + inputs=[helper.make_tensor_value_info("a", TensorProto.FLOAT, list(vals.shape))], + outputs=[helper.make_tensor_value_info("b", TensorProto.INT8, list(vals.shape))], + ) + model = helper.make_model(graph, producer_name="cast_nan_inf_test") + tvm_output = run_in_tvm(model, inputs={"a": vals}, opset=13) + out_np = tvm_output.numpy() + expected = np.array([44, 0, 0, 0, 50, -50], dtype=np.int8) + assert out_np.dtype == np.int8 + np.testing.assert_array_equal(out_np, expected) + + def test_gather(): def _verify_gather(data_shape, indices, out_shape, axis=0): gather_node = helper.make_node("Gather", ["data", "indices"], ["y"], axis=axis) From 3db42ed45114f4514882674aae6cb3c27184484c Mon Sep 17 00:00:00 2001 From: cchung100m Date: Thu, 28 May 2026 20:58:20 +0800 Subject: [PATCH 04/11] replace np.PINF, np.NINF with np.inf, -np.inf --- tests/python/relax/test_frontend_onnx.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py index 00c336ebeac9..8d2eda1e002a 100644 --- a/tests/python/relax/test_frontend_onnx.py +++ b/tests/python/relax/test_frontend_onnx.py @@ -864,7 +864,7 @@ def test_cast(from_type, to_type): def test_cast_nan_inf_to_int8(): - vals = np.array([300.0, np.nan, np.PINF, np.NINF, 50.0, -50.0], dtype=np.float32) + vals = np.array([300.0, np.nan, np.inf, -np.inf, 50.0, -50.0], dtype=np.float32) node = helper.make_node("Cast", inputs=["a"], outputs=["b"], to=TensorProto.INT8) graph = helper.make_graph( [node], From 920449136bdbcb125902d5a7015de46d5662e02f Mon Sep 17 00:00:00 2001 From: cchung100m Date: Fri, 29 May 2026 22:14:52 +0800 Subject: [PATCH 05/11] Add saturate 300->44(int8) not 127 --- .../tvm/relax/frontend/onnx/onnx_frontend.py | 37 +++++++++++++++++-- 1 file changed, 34 insertions(+), 3 deletions(-) diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py index d01a87f5f5fd..f02ff1e42a7a 100644 --- a/python/tvm/relax/frontend/onnx/onnx_frontend.py +++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py @@ -1116,9 +1116,40 @@ def _impl_v13(cls, bb, inputs, attr, params): src, "dtype", None ) if src_dtype is not None and _relax_dtype_is_floating_point(src_dtype): - bad = relax.op.logical_not(relax.op.isfinite(src)) - x = bb.emit(relax.op.where(bad, relax.const(0.0, src_dtype), src)) - return relax.op.astype(x, to_type) + x_sanitized = bb.emit( + relax.op.where( + relax.op.logical_not(relax.op.isfinite(src)), + relax.const(0.0, src_dtype), + src, + ) + ) + dst_str = str(to_type) + if dst_str.startswith("uint"): + signed = False + bits = int(dst_str[4:]) + elif dst_str.startswith("int"): + signed = True + bits = int(dst_str[3:]) + else: + return relax.op.astype(x_sanitized, to_type) + + temp_dtype = "int64" if bits > 32 else "int32" + t = relax.op.astype(x_sanitized, temp_dtype) + mask_val = (1 << bits) - 1 + mask = relax.const(mask_val, temp_dtype) + uw = relax.op.bitwise_and(t, mask) + if signed: + half = 1 << (bits - 1) + half_c = relax.const(half, temp_dtype) + two_pow = relax.const(1 << bits, temp_dtype) + wrapped = relax.op.where( + relax.op.greater_equal(uw, half_c), + relax.op.subtract(uw, two_pow), + uw, + ) + else: + wrapped = uw + return relax.op.astype(wrapped, to_type) return relax.op.astype(inputs[0], to_type) From a276ae13e5272a595ed893247a9c4a1da0428a97 Mon Sep 17 00:00:00 2001 From: cchung100m Date: Sun, 31 May 2026 14:56:03 +0800 Subject: [PATCH 06/11] Fix overflow error --- python/tvm/relax/frontend/onnx/onnx_frontend.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py index f02ff1e42a7a..a7e71e9f7326 100644 --- a/python/tvm/relax/frontend/onnx/onnx_frontend.py +++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py @@ -1118,8 +1118,8 @@ def _impl_v13(cls, bb, inputs, attr, params): if src_dtype is not None and _relax_dtype_is_floating_point(src_dtype): x_sanitized = bb.emit( relax.op.where( - relax.op.logical_not(relax.op.isfinite(src)), - relax.const(0.0, src_dtype), + relax.op.logical_not(relax.op.isfinite(src)), + relax.const(0.0, src_dtype), src, ) ) @@ -1136,6 +1136,8 @@ def _impl_v13(cls, bb, inputs, attr, params): temp_dtype = "int64" if bits > 32 else "int32" t = relax.op.astype(x_sanitized, temp_dtype) mask_val = (1 << bits) - 1 + if temp_dtype == "int32" and mask_val > 0x7FFFFFFF: + temp_dtype = "int64" mask = relax.const(mask_val, temp_dtype) uw = relax.op.bitwise_and(t, mask) if signed: @@ -1143,7 +1145,7 @@ def _impl_v13(cls, bb, inputs, attr, params): half_c = relax.const(half, temp_dtype) two_pow = relax.const(1 << bits, temp_dtype) wrapped = relax.op.where( - relax.op.greater_equal(uw, half_c), + relax.op.greater_equal(uw, half_c), relax.op.subtract(uw, two_pow), uw, ) From f17efe0c574dfffe99214e8f67745fea4dd06b0a Mon Sep 17 00:00:00 2001 From: cchung100m Date: Mon, 1 Jun 2026 20:18:55 +0800 Subject: [PATCH 07/11] Fix Type Error: Binary operators must have the datatype for both operands --- python/tvm/relax/frontend/onnx/onnx_frontend.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py index a7e71e9f7326..1bc8cbace965 100644 --- a/python/tvm/relax/frontend/onnx/onnx_frontend.py +++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py @@ -1134,10 +1134,10 @@ def _impl_v13(cls, bb, inputs, attr, params): return relax.op.astype(x_sanitized, to_type) temp_dtype = "int64" if bits > 32 else "int32" - t = relax.op.astype(x_sanitized, temp_dtype) mask_val = (1 << bits) - 1 if temp_dtype == "int32" and mask_val > 0x7FFFFFFF: temp_dtype = "int64" + t = relax.op.astype(x_sanitxized, temp_dtype) mask = relax.const(mask_val, temp_dtype) uw = relax.op.bitwise_and(t, mask) if signed: From a5e79d4a4e47cc74f9a9b032d5fbbb703650be40 Mon Sep 17 00:00:00 2001 From: cchung100m Date: Mon, 1 Jun 2026 20:24:53 +0800 Subject: [PATCH 08/11] Fix typo --- python/tvm/relax/frontend/onnx/onnx_frontend.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py index 1bc8cbace965..5a14558bc7b4 100644 --- a/python/tvm/relax/frontend/onnx/onnx_frontend.py +++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py @@ -1137,7 +1137,7 @@ def _impl_v13(cls, bb, inputs, attr, params): mask_val = (1 << bits) - 1 if temp_dtype == "int32" and mask_val > 0x7FFFFFFF: temp_dtype = "int64" - t = relax.op.astype(x_sanitxized, temp_dtype) + t = relax.op.astype(x_sanitized, temp_dtype) mask = relax.const(mask_val, temp_dtype) uw = relax.op.bitwise_and(t, mask) if signed: From db182e946b7ebdf2bffa7adc9216fb4649723c60 Mon Sep 17 00:00:00 2001 From: cchung100m Date: Mon, 1 Jun 2026 21:48:54 +0800 Subject: [PATCH 09/11] Fix literal value 4294967296 exceeds maximum of int32 --- python/tvm/relax/frontend/onnx/onnx_frontend.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py index 5a14558bc7b4..f33cc747049e 100644 --- a/python/tvm/relax/frontend/onnx/onnx_frontend.py +++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py @@ -1133,10 +1133,8 @@ def _impl_v13(cls, bb, inputs, attr, params): else: return relax.op.astype(x_sanitized, to_type) - temp_dtype = "int64" if bits > 32 else "int32" + temp_dtype = "int64" if bits >= 32 else "int32" mask_val = (1 << bits) - 1 - if temp_dtype == "int32" and mask_val > 0x7FFFFFFF: - temp_dtype = "int64" t = relax.op.astype(x_sanitized, temp_dtype) mask = relax.const(mask_val, temp_dtype) uw = relax.op.bitwise_and(t, mask) From 2b73e78bc11c96476358514cdf820e794f69e0f7 Mon Sep 17 00:00:00 2001 From: cchung100m Date: Mon, 1 Jun 2026 23:02:22 +0800 Subject: [PATCH 10/11] Fix literal value 4294967296 exceeds maximum of int32 --- python/tvm/relax/frontend/onnx/onnx_frontend.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py index f33cc747049e..4cce15b2939a 100644 --- a/python/tvm/relax/frontend/onnx/onnx_frontend.py +++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py @@ -1141,7 +1141,7 @@ def _impl_v13(cls, bb, inputs, attr, params): if signed: half = 1 << (bits - 1) half_c = relax.const(half, temp_dtype) - two_pow = relax.const(1 << bits, temp_dtype) + two_pow = relax.op.add(mask, relax.const(1, temp_dtype)) wrapped = relax.op.where( relax.op.greater_equal(uw, half_c), relax.op.subtract(uw, two_pow), From 4e11f72aba53653ca44d15e9deb0a4d0a5918a4c Mon Sep 17 00:00:00 2001 From: cchung100m Date: Tue, 2 Jun 2026 00:18:52 +0800 Subject: [PATCH 11/11] Refactor two_pow --- python/tvm/relax/frontend/onnx/onnx_frontend.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py index 4cce15b2939a..f15c68bc822c 100644 --- a/python/tvm/relax/frontend/onnx/onnx_frontend.py +++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py @@ -1134,14 +1134,21 @@ def _impl_v13(cls, bb, inputs, attr, params): return relax.op.astype(x_sanitized, to_type) temp_dtype = "int64" if bits >= 32 else "int32" - mask_val = (1 << bits) - 1 t = relax.op.astype(x_sanitized, temp_dtype) - mask = relax.const(mask_val, temp_dtype) - uw = relax.op.bitwise_and(t, mask) + if bits == 32: + two_pow = relax.const(1 << bits, temp_dtype) + uw = relax.op.floor_mod(t, two_pow) + else: + mask_val = (1 << bits) - 1 + mask = relax.const(mask_val, temp_dtype) + uw = relax.op.bitwise_and(t, mask) if signed: half = 1 << (bits - 1) half_c = relax.const(half, temp_dtype) - two_pow = relax.op.add(mask, relax.const(1, temp_dtype)) + if bits == 32: + two_pow = relax.const(1 << bits, temp_dtype) + else: + two_pow = relax.op.add(mask, relax.const(1, temp_dtype)) wrapped = relax.op.where( relax.op.greater_equal(uw, half_c), relax.op.subtract(uw, two_pow),