We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent ac4acc6 commit 01bd2dcCopy full SHA for 01bd2dc
1 file changed
onnx_array_api/reference/ops/op_cast_like.py
@@ -17,7 +17,8 @@
17
18
def _cast_like(x, y, saturate):
19
if bfloat16 is None:
20
- return (cast_to(x, y.dtype, saturate),)
+ to = np_dtype_to_tensor_dtype(y.dtype)
21
+ return (cast_to(x, to, saturate),)
22
if y.dtype == bfloat16 and y.dtype.descr[0][0] == "bfloat16":
23
# np.uint16 == np.uint16 is True as well as np.uint16 == bfloat16
24
to = TensorProto.BFLOAT16
0 commit comments