diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index 1a224e431ba4..f15c68bc822c 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -1104,6 +1104,54 @@ def _impl_v13(cls, bb, inputs, attr, params):
             return relax.const(output, to_type)
         if isinstance(inputs[0], relax.PrimValue):
             return relax.PrimValue(inputs[0].value.astype(to_type))
+
+        # NumPy is only used to classify the destination dtype; names it does not
+        # recognize fall back to a plain cast.
+        try:
+            np_dst = _np.dtype(str(to_type))
+        except (TypeError, ValueError):
+            return relax.op.astype(inputs[0], to_type)
+
+        if np_dst.kind in ("i", "u"):
+            src = inputs[0]
+            src_dtype = getattr(getattr(src, "struct_info", None), "dtype", None) or getattr(
+                src, "dtype", None
+            )
+            if src_dtype is not None and _relax_dtype_is_floating_point(src_dtype):
+                # NaN and +/-Inf have no integer value: replace them with 0 before casting.
+                x_sanitized = bb.emit(
+                    relax.op.where(
+                        relax.op.logical_not(relax.op.isfinite(src)),
+                        relax.const(0.0, src_dtype),
+                        src,
+                    )
+                )
+                bits = np_dst.itemsize * 8
+                if bits >= 64:
+                    # 2**64 - 1 and 2**63 do not fit in an int64 constant, and the wrap
+                    # below only goes through int32/int64, so 64-bit targets are cast
+                    # directly from the sanitized value.
+                    return relax.op.astype(x_sanitized, to_type)
+
+                # Cast through a wider signed integer, then reduce modulo 2**bits so
+                # out-of-range values wrap around.
+                temp_dtype = "int64" if bits >= 32 else "int32"
+                t = relax.op.astype(x_sanitized, temp_dtype)
+                modulus = 1 << bits
+                if bits == 32:
+                    uw = relax.op.floor_mod(t, relax.const(modulus, temp_dtype))
+                else:
+                    uw = relax.op.bitwise_and(t, relax.const(modulus - 1, temp_dtype))
+                if np_dst.kind == "u":
+                    return relax.op.astype(uw, to_type)
+                # Signed target: the upper half of the unsigned range maps to negatives.
+                wrapped = relax.op.where(
+                    relax.op.greater_equal(uw, relax.const(modulus >> 1, temp_dtype)),
+                    relax.op.subtract(uw, relax.const(modulus, temp_dtype)),
+                    uw,
+                )
+                return relax.op.astype(wrapped, to_type)
+
         return relax.op.astype(inputs[0], to_type)
 
 
diff --git a/tests/python/relax/test_frontend_onnx.py 
b/tests/python/relax/test_frontend_onnx.py
index 427881243663..8d2eda1e002a 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -863,6 +863,24 @@ def test_cast(from_type, to_type):
     check_correctness(model, opset=13)
 
 
+def test_cast_nan_inf_to_int8():
+    """Cast float32 -> int8: NaN/Inf become 0 and out-of-range values wrap."""
+    data = np.array([300.0, np.nan, np.inf, -np.inf, 50.0, -50.0], dtype=np.float32)
+    dims = list(data.shape)
+    input_info = helper.make_tensor_value_info("a", TensorProto.FLOAT, dims)
+    output_info = helper.make_tensor_value_info("b", TensorProto.INT8, dims)
+    cast_node = helper.make_node("Cast", inputs=["a"], outputs=["b"], to=TensorProto.INT8)
+    graph = helper.make_graph(
+        [cast_node], "cast_nan_inf_test", inputs=[input_info], outputs=[output_info]
+    )
+    model = helper.make_model(graph, producer_name="cast_nan_inf_test")
+
+    result = run_in_tvm(model, inputs={"a": data}, opset=13).numpy()
+
+    assert result.dtype == np.int8
+    np.testing.assert_array_equal(result, np.array([44, 0, 0, 0, 50, -50], dtype=np.int8))
+
+
 def test_gather():
     def _verify_gather(data_shape, indices, out_shape, axis=0):
         gather_node = helper.make_node("Gather", ["data", "indices"], ["y"], axis=axis)