From 169eafa190a0a82156ec76e1d583a7ce2967e059 Mon Sep 17 00:00:00 2001 From: actink Date: Tue, 19 May 2026 21:30:50 +0800 Subject: [PATCH 1/5] [TIR] Fixup max/min with input containing NAN fixed #19579 --- src/target/llvm/codegen_llvm.cc | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index 44308be5ba2f..e920b108015d 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -1598,13 +1598,21 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const ModNode* op) { llvm::Value* CodeGenLLVM::VisitExpr_(const MinNode* op) { llvm::Value* a = MakeValue(op->a); llvm::Value* b = MakeValue(op->b); - return builder_->CreateSelect(CreateLT(op->a.dtype(), a, b), a, b); + if (op->a.dtype().is_int() || op->a.dtype().is_uint()) { + return builder_->CreateSelect(CreateLT(op->a.dtype(), a, b), a, b); + } else { + return builder_->CreateMinimum(a, b); + } } llvm::Value* CodeGenLLVM::VisitExpr_(const MaxNode* op) { llvm::Value* a = MakeValue(op->a); llvm::Value* b = MakeValue(op->b); - return builder_->CreateSelect(CreateGT(op->a.dtype(), a, b), a, b); + if (op->a.dtype().is_int() || op->a.dtype().is_uint()) { + return builder_->CreateSelect(CreateGT(op->a.dtype(), a, b), a, b); + } else { + return builder_->CreateMaximum(a, b); + } } llvm::Value* CodeGenLLVM::VisitExpr_(const EQNode* op) { From 5559523e3b800f55678a7d50168f3ff1a371febe Mon Sep 17 00:00:00 2001 From: actink Date: Tue, 19 May 2026 22:18:26 +0800 Subject: [PATCH 2/5] add llvm version macro --- src/target/llvm/codegen_llvm.cc | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index e920b108015d..38a1ab21424b 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -1598,21 +1598,23 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const ModNode* op) { llvm::Value* CodeGenLLVM::VisitExpr_(const MinNode* op) { llvm::Value* a = MakeValue(op->a); llvm::Value* b = MakeValue(op->b); - if (op->a.dtype().is_int() || op->a.dtype().is_uint()) { - return builder_->CreateSelect(CreateLT(op->a.dtype(), a, b), a, b); - } else { + if (op->a.dtype().is_float()) { +#if TVM_LLVM_VERSION >= 120 return builder_->CreateMinimum(a, b); +#endif } + return builder_->CreateSelect(CreateLT(op->a.dtype(), a, b), a, b); } llvm::Value* CodeGenLLVM::VisitExpr_(const MaxNode* op) { llvm::Value* a = MakeValue(op->a); llvm::Value* b = MakeValue(op->b); - if (op->a.dtype().is_int() || op->a.dtype().is_uint()) { - return builder_->CreateSelect(CreateGT(op->a.dtype(), a, b), a, b); - } else { + if (op->a.dtype().is_float()) { +#if TVM_LLVM_VERSION >= 120 return builder_->CreateMaximum(a, b); +#endif } + return builder_->CreateSelect(CreateGT(op->a.dtype(), a, b), a, b); } llvm::Value* CodeGenLLVM::VisitExpr_(const EQNode* op) { From 2340aeb310bfd99cd72169800e0648e80f276b5c Mon Sep 17 00:00:00 2001 From: actink Date: Wed, 20 May 2026 08:41:32 +0800 Subject: [PATCH 3/5] del TVM_LLVM_VERSION >= 120, since bumped minimum llvm 150 --- src/target/llvm/codegen_llvm.cc | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index 38a1ab21424b..391110847d4c 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -1599,22 +1599,20 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const MinNode* op) { llvm::Value* a = MakeValue(op->a); llvm::Value* b = MakeValue(op->b); if (op->a.dtype().is_float()) { -#if TVM_LLVM_VERSION >= 120 return builder_->CreateMinimum(a, b); -#endif + } else { + return builder_->CreateSelect(CreateLT(op->a.dtype(), a, b), a, b); } - return builder_->CreateSelect(CreateLT(op->a.dtype(), a, b), a, b); } llvm::Value* CodeGenLLVM::VisitExpr_(const MaxNode* op) { llvm::Value* a = MakeValue(op->a); llvm::Value* b = MakeValue(op->b); if (op->a.dtype().is_float()) { -#if TVM_LLVM_VERSION >= 120 return builder_->CreateMaximum(a, b); -#endif + } else { + return builder_->CreateSelect(CreateGT(op->a.dtype(), a, b), a, b); } - return builder_->CreateSelect(CreateGT(op->a.dtype(), a, b), a, b); } llvm::Value* CodeGenLLVM::VisitExpr_(const EQNode* op) { From ed160664d181c7b1562eec4cb104997a295930e4 Mon Sep 17 00:00:00 2001 From: actink Date: Thu, 28 May 2026 20:43:50 +0800 Subject: [PATCH 4/5] use workaround Minimum/Maximum --- src/target/llvm/codegen_llvm.cc | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index 391110847d4c..dc83a40f6876 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -1599,7 +1599,12 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const MinNode* op) { llvm::Value* a = MakeValue(op->a); llvm::Value* b = MakeValue(op->b); if (op->a.dtype().is_float()) { - return builder_->CreateMinimum(a, b); + llvm::Value* nan_a = builder_->CreateFCmpUNO(a, a); + llvm::Value* nan_b = builder_->CreateFCmpUNO(b, b); + return builder_->CreateSelect( + nan_a, a, + builder_->CreateSelect(nan_b, b, + builder_->CreateSelect(CreateLT(op->a.dtype(), a, b), a, b))); } else { return builder_->CreateSelect(CreateLT(op->a.dtype(), a, b), a, b); } @@ -1609,7 +1614,12 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const MaxNode* op) { llvm::Value* a = MakeValue(op->a); llvm::Value* b = MakeValue(op->b); if (op->a.dtype().is_float()) { - return builder_->CreateMaximum(a, b); + llvm::Value* nan_a = builder_->CreateFCmpUNO(a, a); + llvm::Value* nan_b = builder_->CreateFCmpUNO(b, b); + return builder_->CreateSelect( + nan_a, a, + builder_->CreateSelect(nan_b, b, + builder_->CreateSelect(CreateGT(op->a.dtype(), a, b), a, b))); } else { return builder_->CreateSelect(CreateGT(op->a.dtype(), a, b), a, b); } From 70822b7053efe70646bf2ee92b9122cab960e97c Mon Sep 17 00:00:00 2001 From: actink Date: Thu, 28 May 2026 22:21:12 +0800 Subject: [PATCH 5/5] fixup testcase asm check --- tests/python/codegen/test_target_codegen_aarch64.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/python/codegen/test_target_codegen_aarch64.py b/tests/python/codegen/test_target_codegen_aarch64.py index 9191bea54934..9df0ab98baeb 100644 --- a/tests/python/codegen/test_target_codegen_aarch64.py +++ b/tests/python/codegen/test_target_codegen_aarch64.py @@ -226,7 +226,7 @@ def main(var_A: T.handle, var_B: T.handle, var_C: T.handle): ) assert len(loads) > 1 - assert (len(compare) > 1 and len(select) == len(compare)) or len(max_instr) > 1 + assert (len(compare) > 1 and len(select) == 3 * len(compare)) or len(max_instr) > 1 @pytest.mark.skipif( @@ -269,7 +269,7 @@ def main(var_A: T.handle, var_B: T.handle, var_C: T.handle): ) assert len(loads) > 1 - assert (len(compare) > 1 and len(select) == len(compare)) or len(min_instr) > 1 + assert (len(compare) > 1 and len(select) == 3 * len(compare)) or len(min_instr) > 1 @pytest.mark.skipif(