diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index 44308be5ba2f..dc83a40f6876 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -1598,13 +1598,31 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const ModNode* op) { llvm::Value* CodeGenLLVM::VisitExpr_(const MinNode* op) { llvm::Value* a = MakeValue(op->a); llvm::Value* b = MakeValue(op->b); - return builder_->CreateSelect(CreateLT(op->a.dtype(), a, b), a, b); + if (op->a.dtype().is_float()) { + llvm::Value* nan_a = builder_->CreateFCmpUNO(a, a); + llvm::Value* nan_b = builder_->CreateFCmpUNO(b, b); + return builder_->CreateSelect( + nan_a, a, + builder_->CreateSelect(nan_b, b, + builder_->CreateSelect(CreateLT(op->a.dtype(), a, b), a, b))); + } else { + return builder_->CreateSelect(CreateLT(op->a.dtype(), a, b), a, b); + } } llvm::Value* CodeGenLLVM::VisitExpr_(const MaxNode* op) { llvm::Value* a = MakeValue(op->a); llvm::Value* b = MakeValue(op->b); - return builder_->CreateSelect(CreateGT(op->a.dtype(), a, b), a, b); + if (op->a.dtype().is_float()) { + llvm::Value* nan_a = builder_->CreateFCmpUNO(a, a); + llvm::Value* nan_b = builder_->CreateFCmpUNO(b, b); + return builder_->CreateSelect( + nan_a, a, + builder_->CreateSelect(nan_b, b, + builder_->CreateSelect(CreateGT(op->a.dtype(), a, b), a, b))); + } else { + return builder_->CreateSelect(CreateGT(op->a.dtype(), a, b), a, b); + } } llvm::Value* CodeGenLLVM::VisitExpr_(const EQNode* op) { diff --git a/tests/python/codegen/test_target_codegen_aarch64.py b/tests/python/codegen/test_target_codegen_aarch64.py index 9191bea54934..9df0ab98baeb 100644 --- a/tests/python/codegen/test_target_codegen_aarch64.py +++ b/tests/python/codegen/test_target_codegen_aarch64.py @@ -226,7 +226,7 @@ def main(var_A: T.handle, var_B: T.handle, var_C: T.handle): ) assert len(loads) > 1 - assert (len(compare) > 1 and len(select) == len(compare)) or len(max_instr) > 1 + assert (len(compare) > 1 and len(select) == 3 * len(compare)) or len(max_instr) > 1 @pytest.mark.skipif( @@ -269,7 +269,7 @@ def main(var_A: T.handle, var_B: T.handle, var_C: T.handle): ) assert len(loads) > 1 - assert (len(compare) > 1 and len(select) == len(compare)) or len(min_instr) > 1 + assert (len(compare) > 1 and len(select) == 3 * len(compare)) or len(min_instr) > 1 @pytest.mark.skipif(