From 810bc8011814fb99f4b8393271b69430fb52e044 Mon Sep 17 00:00:00 2001 From: RoomWithOutRoof Date: Mon, 20 Apr 2026 01:43:17 +0800 Subject: [PATCH 1/2] Add MLX Op Handler for aten.isinf Implement isinf op handler for the MLX delegate backend. isinf(x) is decomposed as abs(x) == inf, using existing AbsNode and EqualNode which are already supported in the MLX graph schema. The handler also includes a corresponding test case with _inf_input_fn that generates inputs containing both positive and negative infinity. Fixes: #18922 --- backends/mlx/ops.py | 35 +++++++++++++++++++++++++++++++++++ backends/mlx/test/test_ops.py | 17 ++++++++++++++++- 2 files changed, 51 insertions(+), 1 deletion(-) diff --git a/backends/mlx/ops.py b/backends/mlx/ops.py index 3f7da88a793..e210eedade7 100644 --- a/backends/mlx/ops.py +++ b/backends/mlx/ops.py @@ -444,6 +444,41 @@ def _isnan_handler(P: MLXProgramBuilder, n: Node) -> Slot: return out +@REGISTRY.register(target=[torch.ops.aten.isinf.default]) +def _isinf_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """Handle aten.isinf - check for infinite values element-wise. + + isinf(x) is equivalent to abs(x) == inf. + """ + args = P.args(n) + require_args(args, 1, 1, "aten.isinf") + require_kwargs(P.kwargs(n), set(), "aten.isinf") + x = args[0] + + # Create abs(x) + _, abs_tmp = P.make_tmp_slot() + P.emit( + AbsNode( + x=P.slot_to_tid(x), + out=P.slot_to_tid(abs_tmp), + ) + ) + + # Create inf constant + inf_slot = emit_lifted_constant(P, float('inf'), torch.float32) + + # Compare abs(x) == inf + out = P.make_or_get_slot(n) + P.emit( + EqualNode( + a=P.slot_to_tid(abs_tmp), + b=P.slot_to_tid(inf_slot), + out=P.slot_to_tid(out), + ) + ) + return out + + _BINARY_OPS: List[Tuple[List[Any], Any, str, bool]] = [ ( [torch.ops.aten.mul.Tensor, torch.ops.aten.mul.Scalar], diff --git a/backends/mlx/test/test_ops.py b/backends/mlx/test/test_ops.py index 7ba3902e436..618c7305dcc 100644 --- a/backends/mlx/test/test_ops.py +++ b/backends/mlx/test/test_ops.py @@ -4020,6 +4020,21 @@ def fn(shape, dtype): return fn +def _inf_input_fn(): + """Return a callable(shape, dtype) that generates inputs with some inf values.""" + + def fn(shape, dtype): + x = torch.randn(shape, dtype=dtype) + # Insert some inf values + mask_pos = torch.rand(shape) > 0.8 + mask_neg = torch.rand(shape) > 0.9 + x[mask_pos] = float('inf') + x[mask_neg] = float('-inf') + return (x,) + + return fn + + # Standard shape and dtype configs used by unary tests. _SHAPES_3 = [(16,), (4, 4), (2, 3, 4)] _SHAPES_2 = [(16,), (4, 4)] @@ -4112,7 +4127,7 @@ def create_model(self) -> nn.Module: {"op_name": "neg", "op_fn": torch.neg}, {"op_name": "logical_not","op_fn": torch.logical_not, "shapes": [(2, 3, 4), (10,), (4, 8)], "dtypes": [torch.bool], "input_fn": _bool_input_fn()}, {"op_name": "isnan", "op_fn": torch.isnan, "shapes": _SHAPES_3, "dtypes": [torch.float32, torch.float16, torch.bfloat16], "input_fn": _nan_input_fn()}, - # activations + {"op_name": "isinf", "op_fn": torch.isinf, "shapes": _SHAPES_3, "dtypes": [torch.float32, torch.float16, torch.bfloat16], "input_fn": _inf_input_fn()}, # activations {"op_name": "relu", "op_fn": torch.relu, "shapes": [(2, 3, 4), (10,), (4, 8), (2, 8, 16), (1, 128, 64)], "dtypes": [torch.float32], "input_fn": _input_fn(scale=2, offset=-1)}, {"op_name": "sigmoid", "op_fn": torch.sigmoid, "shapes": [(2, 3, 4), (10,), (4, 8), (2, 8, 16), (1, 1, 128)], "dtypes": [torch.float32], "input_fn": _input_fn(scale=2)}, {"op_name": "tanh", "op_fn": torch.tanh, "shapes": [(2, 3, 4), (10,), (4, 8), (2, 8, 16), (1, 1, 128)], "dtypes": [torch.float32], "input_fn": _input_fn(scale=3)}, From 5322330b08de842456fb01855c0f43c49fabf17a Mon Sep 17 00:00:00 2001 From: Jah-yee <166608075+Jah-yee@users.noreply.github.com> Date: Thu, 25 Jun 2026 14:06:30 +0800 Subject: [PATCH 2/2] Fix review feedback: formatting, mask overlap, dtype comment, lint double quotes --- backends/mlx/ops.py | 4 ++-- backends/mlx/test/test_ops.py | 13 +++++++------ 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/backends/mlx/ops.py b/backends/mlx/ops.py index e210eedade7..3961e5acea5 100644 --- a/backends/mlx/ops.py +++ b/backends/mlx/ops.py @@ -464,8 +464,8 @@ def _isinf_handler(P: MLXProgramBuilder, n: Node) -> Slot: ) ) - # Create inf constant - inf_slot = emit_lifted_constant(P, float('inf'), torch.float32) + # Create inf constant (float32; EqualNode handles type promotion to match input dtype) + inf_slot = emit_lifted_constant(P, float("inf"), torch.float32) # Compare abs(x) == inf out = P.make_or_get_slot(n) diff --git a/backends/mlx/test/test_ops.py b/backends/mlx/test/test_ops.py index 618c7305dcc..980ca9c3831 100644 --- a/backends/mlx/test/test_ops.py +++ b/backends/mlx/test/test_ops.py @@ -4025,11 +4025,11 @@ def _inf_input_fn(): def fn(shape, dtype): x = torch.randn(shape, dtype=dtype) - # Insert some inf values - mask_pos = torch.rand(shape) > 0.8 - mask_neg = torch.rand(shape) > 0.9 - x[mask_pos] = float('inf') - x[mask_neg] = float('-inf') + # Insert ~20% +inf and ~10% -inf using non-overlapping masks + mask_pos = torch.rand(shape) > 0.8 # ~20% -> +inf + mask_neg = (~mask_pos) & (torch.rand(shape) > 0.9) # ~10% of remaining -> -inf + x[mask_pos] = float("inf") + x[mask_neg] = float("-inf") return (x,) return fn @@ -4127,7 +4127,8 @@ def create_model(self) -> nn.Module: {"op_name": "neg", "op_fn": torch.neg}, {"op_name": "logical_not","op_fn": torch.logical_not, "shapes": [(2, 3, 4), (10,), (4, 8)], "dtypes": [torch.bool], "input_fn": _bool_input_fn()}, {"op_name": "isnan", "op_fn": torch.isnan, "shapes": _SHAPES_3, "dtypes": [torch.float32, torch.float16, torch.bfloat16], "input_fn": _nan_input_fn()}, - {"op_name": "isinf", "op_fn": torch.isinf, "shapes": _SHAPES_3, "dtypes": [torch.float32, torch.float16, torch.bfloat16], "input_fn": _inf_input_fn()}, # activations + {"op_name": "isinf", "op_fn": torch.isinf, "shapes": _SHAPES_3, "dtypes": [torch.float32, torch.float16, torch.bfloat16], "input_fn": _inf_input_fn()}, + # activations {"op_name": "relu", "op_fn": torch.relu, "shapes": [(2, 3, 4), (10,), (4, 8), (2, 8, 16), (1, 128, 64)], "dtypes": [torch.float32], "input_fn": _input_fn(scale=2, offset=-1)}, {"op_name": "sigmoid", "op_fn": torch.sigmoid, "shapes": [(2, 3, 4), (10,), (4, 8), (2, 8, 16), (1, 1, 128)], "dtypes": [torch.float32], "input_fn": _input_fn(scale=2)}, {"op_name": "tanh", "op_fn": torch.tanh, "shapes": [(2, 3, 4), (10,), (4, 8), (2, 8, 16), (1, 1, 128)], "dtypes": [torch.float32], "input_fn": _input_fn(scale=3)},