Skip to content
54 changes: 54 additions & 0 deletions python/tvm/relax/frontend/onnx/onnx_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -1104,6 +1104,60 @@ def _impl_v13(cls, bb, inputs, attr, params):
return relax.const(output, to_type)
if isinstance(inputs[0], relax.PrimValue):
return relax.PrimValue(inputs[0].value.astype(to_type))

try:
np_dst = _np.dtype(str(to_type))
except Exception:
return relax.op.astype(inputs[0], to_type)

if np_dst.kind in ("i", "u"):
src = inputs[0]
src_dtype = getattr(getattr(src, "struct_info", None), "dtype", None) or getattr(
src, "dtype", None
)
if src_dtype is not None and _relax_dtype_is_floating_point(src_dtype):
x_sanitized = bb.emit(
relax.op.where(
relax.op.logical_not(relax.op.isfinite(src)),
relax.const(0.0, src_dtype),
src,
)
)
dst_str = str(to_type)
if dst_str.startswith("uint"):
signed = False
bits = int(dst_str[4:])
elif dst_str.startswith("int"):
signed = True
bits = int(dst_str[3:])
else:
return relax.op.astype(x_sanitized, to_type)

temp_dtype = "int64" if bits >= 32 else "int32"
t = relax.op.astype(x_sanitized, temp_dtype)
if bits == 32:
two_pow = relax.const(1 << bits, temp_dtype)
uw = relax.op.floor_mod(t, two_pow)
else:
mask_val = (1 << bits) - 1
mask = relax.const(mask_val, temp_dtype)
uw = relax.op.bitwise_and(t, mask)
if signed:
half = 1 << (bits - 1)
half_c = relax.const(half, temp_dtype)
if bits == 32:
two_pow = relax.const(1 << bits, temp_dtype)
else:
two_pow = relax.op.add(mask, relax.const(1, temp_dtype))
wrapped = relax.op.where(
relax.op.greater_equal(uw, half_c),
relax.op.subtract(uw, two_pow),
uw,
)
else:
wrapped = uw
return relax.op.astype(wrapped, to_type)

return relax.op.astype(inputs[0], to_type)


Expand Down
17 changes: 17 additions & 0 deletions tests/python/relax/test_frontend_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -863,6 +863,23 @@ def test_cast(from_type, to_type):
check_correctness(model, opset=13)


def test_cast_nan_inf_to_int8():
    """Check the ONNX ``Cast`` import for float32 -> int8 with non-finite inputs.

    A single-node ``Cast`` model is run through TVM on a vector containing an
    out-of-range value, NaN, +Inf, -Inf and two in-range values. The result
    must be int8, with NaN/+Inf/-Inf mapped to 0, 300.0 mapped to 44
    (i.e. 300 - 256) and the in-range values left unchanged.
    """
    data = np.array([300.0, np.nan, np.inf, -np.inf, 50.0, -50.0], dtype=np.float32)
    dims = list(data.shape)

    # Build the one-operator graph: a --Cast(to=INT8)--> b
    cast_node = helper.make_node("Cast", inputs=["a"], outputs=["b"], to=TensorProto.INT8)
    graph_inputs = [helper.make_tensor_value_info("a", TensorProto.FLOAT, dims)]
    graph_outputs = [helper.make_tensor_value_info("b", TensorProto.INT8, dims)]
    graph = helper.make_graph(
        [cast_node],
        "cast_nan_inf_test",
        inputs=graph_inputs,
        outputs=graph_outputs,
    )
    model = helper.make_model(graph, producer_name="cast_nan_inf_test")

    result = run_in_tvm(model, inputs={"a": data}, opset=13).numpy()

    assert result.dtype == np.int8
    np.testing.assert_array_equal(
        result,
        np.array([44, 0, 0, 0, 50, -50], dtype=np.int8),
    )


def test_gather():
def _verify_gather(data_shape, indices, out_shape, axis=0):
gather_node = helper.make_node("Gather", ["data", "indices"], ["y"], axis=axis)
Expand Down
Loading