From 8bd9ade6284a793c898da133723121c3bcc49ef7 Mon Sep 17 00:00:00 2001 From: Yingwei Zheng Date: Sat, 3 Aug 2024 13:35:22 +0800 Subject: [InstCombine] Fold `fcmp pred sqrt(X), 0.0 -> fcmp pred2 X, 0.0` (#101626) Proof (Please run alive-tv with larger smt-to): https://alive2.llvm.org/ce/z/-aqixk FMF propagation: https://alive2.llvm.org/ce/z/zyKK_p ``` sqrt(X) < 0.0 --> false sqrt(X) u>= 0.0 --> true sqrt(X) u< 0.0 --> X u< 0.0 sqrt(X) u<= 0.0 --> X u<= 0.0 sqrt(X) > 0.0 --> X > 0.0 sqrt(X) >= 0.0 --> X >= 0.0 sqrt(X) == 0.0 --> X == 0.0 sqrt(X) u!= 0.0 --> X u!= 0.0 sqrt(X) <= 0.0 --> X == 0.0 sqrt(X) u> 0.0 --> X u!= 0.0 sqrt(X) u== 0.0 --> X u<= 0.0 sqrt(X) != 0.0 --> X > 0.0 !isnan(sqrt(X)) --> X >= 0.0 isnan(sqrt(X)) --> X u< 0.0 ``` In most cases, `sqrt` cannot be eliminated since it has multiple uses. But this patch will break data dependencies and allow optimizer to sink expensive `sqrt` calls into successor blocks. --- .../Transforms/InstCombine/InstCombineCompares.cpp | 64 ++++++ llvm/test/Transforms/InstCombine/fcmp.ll | 233 +++++++++++++++++++++ .../test/Transforms/InstCombine/known-never-nan.ll | 4 +- 3 files changed, 298 insertions(+), 3 deletions(-) diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp index 3b6df27..94786f0 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp @@ -7980,6 +7980,67 @@ static Instruction *foldFabsWithFcmpZero(FCmpInst &I, InstCombinerImpl &IC) { } } +/// Optimize sqrt(X) compared with zero. +static Instruction *foldSqrtWithFcmpZero(FCmpInst &I, InstCombinerImpl &IC) { + Value *X; + if (!match(I.getOperand(0), m_Sqrt(m_Value(X)))) + return nullptr; + + if (!match(I.getOperand(1), m_PosZeroFP())) + return nullptr; + + auto ReplacePredAndOp0 = [&](FCmpInst::Predicate P) { + I.setPredicate(P); + return IC.replaceOperand(I, 0, X); + }; + + // Clear ninf flag if sqrt doesn't have it. + if (!cast(I.getOperand(0))->hasNoInfs()) + I.setHasNoInfs(false); + + switch (I.getPredicate()) { + case FCmpInst::FCMP_OLT: + case FCmpInst::FCMP_UGE: + // sqrt(X) < 0.0 --> false + // sqrt(X) u>= 0.0 --> true + llvm_unreachable("fcmp should have simplified"); + case FCmpInst::FCMP_ULT: + case FCmpInst::FCMP_ULE: + case FCmpInst::FCMP_OGT: + case FCmpInst::FCMP_OGE: + case FCmpInst::FCMP_OEQ: + case FCmpInst::FCMP_UNE: + // sqrt(X) u< 0.0 --> X u< 0.0 + // sqrt(X) u<= 0.0 --> X u<= 0.0 + // sqrt(X) > 0.0 --> X > 0.0 + // sqrt(X) >= 0.0 --> X >= 0.0 + // sqrt(X) == 0.0 --> X == 0.0 + // sqrt(X) u!= 0.0 --> X u!= 0.0 + return IC.replaceOperand(I, 0, X); + + case FCmpInst::FCMP_OLE: + // sqrt(X) <= 0.0 --> X == 0.0 + return ReplacePredAndOp0(FCmpInst::FCMP_OEQ); + case FCmpInst::FCMP_UGT: + // sqrt(X) u> 0.0 --> X u!= 0.0 + return ReplacePredAndOp0(FCmpInst::FCMP_UNE); + case FCmpInst::FCMP_UEQ: + // sqrt(X) u== 0.0 --> X u<= 0.0 + return ReplacePredAndOp0(FCmpInst::FCMP_ULE); + case FCmpInst::FCMP_ONE: + // sqrt(X) != 0.0 --> X > 0.0 + return ReplacePredAndOp0(FCmpInst::FCMP_OGT); + case FCmpInst::FCMP_ORD: + // !isnan(sqrt(X)) --> X >= 0.0 + return ReplacePredAndOp0(FCmpInst::FCMP_OGE); + case FCmpInst::FCMP_UNO: + // isnan(sqrt(X)) --> X u< 0.0 + return ReplacePredAndOp0(FCmpInst::FCMP_ULT); + default: + llvm_unreachable("Unexpected predicate!"); + } +} + static Instruction *foldFCmpFNegCommonOp(FCmpInst &I) { CmpInst::Predicate Pred = I.getPredicate(); Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); @@ -8247,6 +8308,9 @@ Instruction *InstCombinerImpl::visitFCmpInst(FCmpInst &I) { if (Instruction *R = foldFabsWithFcmpZero(I, *this)) return R; + if (Instruction *R = foldSqrtWithFcmpZero(I, *this)) + return R; + if (match(Op0, m_FNeg(m_Value(X)))) { // fcmp pred (fneg X), C --> fcmp swap(pred) X, -C Constant *C; diff --git a/llvm/test/Transforms/InstCombine/fcmp.ll b/llvm/test/Transforms/InstCombine/fcmp.ll index 656b3d2..8afb646 100644 --- a/llvm/test/Transforms/InstCombine/fcmp.ll +++ b/llvm/test/Transforms/InstCombine/fcmp.ll @@ -2117,3 +2117,236 @@ define <8 x i1> @fcmp_ogt_fsub_const_vec_denormal_preserve-sign(<8 x float> %x, %cmp = fcmp ogt <8 x float> %fs, zeroinitializer ret <8 x i1> %cmp } + +define i1 @fcmp_sqrt_zero_olt(half %x) { +; CHECK-LABEL: @fcmp_sqrt_zero_olt( +; CHECK-NEXT: ret i1 false +; + %sqrt = call half @llvm.sqrt.f16(half %x) + %cmp = fcmp olt half %sqrt, 0.0 + ret i1 %cmp +} + +define i1 @fcmp_sqrt_zero_ult(half %x) { +; CHECK-LABEL: @fcmp_sqrt_zero_ult( +; CHECK-NEXT: [[CMP:%.*]] = fcmp ult half [[X:%.*]], 0xH0000 +; CHECK-NEXT: ret i1 [[CMP]] +; + %sqrt = call half @llvm.sqrt.f16(half %x) + %cmp = fcmp ult half %sqrt, 0.0 + ret i1 %cmp +} + +define i1 @fcmp_sqrt_zero_ult_fmf(half %x) { +; CHECK-LABEL: @fcmp_sqrt_zero_ult_fmf( +; CHECK-NEXT: [[CMP:%.*]] = fcmp nsz ult half [[X:%.*]], 0xH0000 +; CHECK-NEXT: ret i1 [[CMP]] +; + %sqrt = call half @llvm.sqrt.f16(half %x) + %cmp = fcmp ninf nsz ult half %sqrt, 0.0 + ret i1 %cmp +} + +define i1 @fcmp_sqrt_zero_ult_fmf_sqrt_ninf(half %x) { +; CHECK-LABEL: @fcmp_sqrt_zero_ult_fmf_sqrt_ninf( +; CHECK-NEXT: [[CMP:%.*]] = fcmp ninf nsz ult half [[X:%.*]], 0xH0000 +; CHECK-NEXT: ret i1 [[CMP]] +; + %sqrt = call ninf half @llvm.sqrt.f16(half %x) + %cmp = fcmp ninf nsz ult half %sqrt, 0.0 + ret i1 %cmp +} + +define i1 @fcmp_sqrt_zero_ult_nzero(half %x) { +; CHECK-LABEL: @fcmp_sqrt_zero_ult_nzero( +; CHECK-NEXT: [[CMP:%.*]] = fcmp ult half [[X:%.*]], 0xH0000 +; CHECK-NEXT: ret i1 [[CMP]] +; + %sqrt = call half @llvm.sqrt.f16(half %x) + %cmp = fcmp ult half %sqrt, -0.0 + ret i1 %cmp +} + +define <2 x i1> @fcmp_sqrt_zero_ult_vec(<2 x half> %x) { +; CHECK-LABEL: @fcmp_sqrt_zero_ult_vec( +; CHECK-NEXT: [[CMP:%.*]] = fcmp ult <2 x half> [[X:%.*]], zeroinitializer +; CHECK-NEXT: ret <2 x i1> [[CMP]] +; + %sqrt = call <2 x half> @llvm.sqrt.v2f16(<2 x half> %x) + %cmp = fcmp ult <2 x half> %sqrt, zeroinitializer + ret <2 x i1> %cmp +} + +define <2 x i1> @fcmp_sqrt_zero_ult_vec_mixed_zero(<2 x half> %x) { +; CHECK-LABEL: @fcmp_sqrt_zero_ult_vec_mixed_zero( +; CHECK-NEXT: [[CMP:%.*]] = fcmp ult <2 x half> [[X:%.*]], zeroinitializer +; CHECK-NEXT: ret <2 x i1> [[CMP]] +; + %sqrt = call <2 x half> @llvm.sqrt.v2f16(<2 x half> %x) + %cmp = fcmp ult <2 x half> %sqrt, + ret <2 x i1> %cmp +} + +define i1 @fcmp_sqrt_zero_ole(half %x) { +; CHECK-LABEL: @fcmp_sqrt_zero_ole( +; CHECK-NEXT: [[CMP:%.*]] = fcmp oeq half [[X:%.*]], 0xH0000 +; CHECK-NEXT: ret i1 [[CMP]] +; + %sqrt = call half @llvm.sqrt.f16(half %x) + %cmp = fcmp ole half %sqrt, 0.0 + ret i1 %cmp +} + +define i1 @fcmp_sqrt_zero_ule(half %x) { +; CHECK-LABEL: @fcmp_sqrt_zero_ule( +; CHECK-NEXT: [[CMP:%.*]] = fcmp ule half [[X:%.*]], 0xH0000 +; CHECK-NEXT: ret i1 [[CMP]] +; + %sqrt = call half @llvm.sqrt.f16(half %x) + %cmp = fcmp ule half %sqrt, 0.0 + ret i1 %cmp +} + +define i1 @fcmp_sqrt_zero_ogt(half %x) { +; CHECK-LABEL: @fcmp_sqrt_zero_ogt( +; CHECK-NEXT: [[CMP:%.*]] = fcmp ogt half [[X:%.*]], 0xH0000 +; CHECK-NEXT: ret i1 [[CMP]] +; + %sqrt = call half @llvm.sqrt.f16(half %x) + %cmp = fcmp ogt half %sqrt, 0.0 + ret i1 %cmp +} + +define i1 @fcmp_sqrt_zero_ugt(half %x) { +; CHECK-LABEL: @fcmp_sqrt_zero_ugt( +; CHECK-NEXT: [[CMP:%.*]] = fcmp une half [[X:%.*]], 0xH0000 +; CHECK-NEXT: ret i1 [[CMP]] +; + %sqrt = call half @llvm.sqrt.f16(half %x) + %cmp = fcmp ugt half %sqrt, 0.0 + ret i1 %cmp +} + +define i1 @fcmp_sqrt_zero_oge(half %x) { +; CHECK-LABEL: @fcmp_sqrt_zero_oge( +; CHECK-NEXT: [[CMP:%.*]] = fcmp oge half [[X:%.*]], 0xH0000 +; CHECK-NEXT: ret i1 [[CMP]] +; + %sqrt = call half @llvm.sqrt.f16(half %x) + %cmp = fcmp oge half %sqrt, 0.0 + ret i1 %cmp +} + +define i1 @fcmp_sqrt_zero_uge(half %x) { +; CHECK-LABEL: @fcmp_sqrt_zero_uge( +; CHECK-NEXT: ret i1 true +; + %sqrt = call half @llvm.sqrt.f16(half %x) + %cmp = fcmp uge half %sqrt, 0.0 + ret i1 %cmp +} + +define i1 @fcmp_sqrt_zero_oeq(half %x) { +; CHECK-LABEL: @fcmp_sqrt_zero_oeq( +; CHECK-NEXT: [[CMP:%.*]] = fcmp oeq half [[X:%.*]], 0xH0000 +; CHECK-NEXT: ret i1 [[CMP]] +; + %sqrt = call half @llvm.sqrt.f16(half %x) + %cmp = fcmp oeq half %sqrt, 0.0 + ret i1 %cmp +} + +define i1 @fcmp_sqrt_zero_ueq(half %x) { +; CHECK-LABEL: @fcmp_sqrt_zero_ueq( +; CHECK-NEXT: [[CMP:%.*]] = fcmp ule half [[X:%.*]], 0xH0000 +; CHECK-NEXT: ret i1 [[CMP]] +; + %sqrt = call half @llvm.sqrt.f16(half %x) + %cmp = fcmp ueq half %sqrt, 0.0 + ret i1 %cmp +} + +define i1 @fcmp_sqrt_zero_one(half %x) { +; CHECK-LABEL: @fcmp_sqrt_zero_one( +; CHECK-NEXT: [[CMP:%.*]] = fcmp ogt half [[X:%.*]], 0xH0000 +; CHECK-NEXT: ret i1 [[CMP]] +; + %sqrt = call half @llvm.sqrt.f16(half %x) + %cmp = fcmp one half %sqrt, 0.0 + ret i1 %cmp +} + +define i1 @fcmp_sqrt_zero_une(half %x) { +; CHECK-LABEL: @fcmp_sqrt_zero_une( +; CHECK-NEXT: [[CMP:%.*]] = fcmp une half [[X:%.*]], 0xH0000 +; CHECK-NEXT: ret i1 [[CMP]] +; + %sqrt = call half @llvm.sqrt.f16(half %x) + %cmp = fcmp une half %sqrt, 0.0 + ret i1 %cmp +} + +define i1 @fcmp_sqrt_zero_ord(half %x) { +; CHECK-LABEL: @fcmp_sqrt_zero_ord( +; CHECK-NEXT: [[CMP:%.*]] = fcmp oge half [[X:%.*]], 0xH0000 +; CHECK-NEXT: ret i1 [[CMP]] +; + %sqrt = call half @llvm.sqrt.f16(half %x) + %cmp = fcmp ord half %sqrt, 0.0 + ret i1 %cmp +} + +define i1 @fcmp_sqrt_zero_uno(half %x) { +; CHECK-LABEL: @fcmp_sqrt_zero_uno( +; CHECK-NEXT: [[CMP:%.*]] = fcmp ult half [[X:%.*]], 0xH0000 +; CHECK-NEXT: ret i1 [[CMP]] +; + %sqrt = call half @llvm.sqrt.f16(half %x) + %cmp = fcmp uno half %sqrt, 0.0 + ret i1 %cmp +} + +; Make sure that ninf is cleared. +define i1 @fcmp_sqrt_zero_uno_fmf(half %x) { +; CHECK-LABEL: @fcmp_sqrt_zero_uno_fmf( +; CHECK-NEXT: [[CMP:%.*]] = fcmp ult half [[X:%.*]], 0xH0000 +; CHECK-NEXT: ret i1 [[CMP]] +; + %sqrt = call half @llvm.sqrt.f16(half %x) + %cmp = fcmp ninf uno half %sqrt, 0.0 + ret i1 %cmp +} + +define i1 @fcmp_sqrt_zero_uno_fmf_sqrt_ninf(half %x) { +; CHECK-LABEL: @fcmp_sqrt_zero_uno_fmf_sqrt_ninf( +; CHECK-NEXT: [[CMP:%.*]] = fcmp ninf ult half [[X:%.*]], 0xH0000 +; CHECK-NEXT: ret i1 [[CMP]] +; + %sqrt = call ninf half @llvm.sqrt.f16(half %x) + %cmp = fcmp ninf uno half %sqrt, 0.0 + ret i1 %cmp +} + +; negative tests + +define i1 @fcmp_sqrt_zero_ult_var(half %x, half %y) { +; CHECK-LABEL: @fcmp_sqrt_zero_ult_var( +; CHECK-NEXT: [[SQRT:%.*]] = call half @llvm.sqrt.f16(half [[X:%.*]]) +; CHECK-NEXT: [[CMP:%.*]] = fcmp ult half [[SQRT]], [[Y:%.*]] +; CHECK-NEXT: ret i1 [[CMP]] +; + %sqrt = call half @llvm.sqrt.f16(half %x) + %cmp = fcmp ult half %sqrt, %y + ret i1 %cmp +} + +define i1 @fcmp_sqrt_zero_ult_nonzero(half %x) { +; CHECK-LABEL: @fcmp_sqrt_zero_ult_nonzero( +; CHECK-NEXT: [[SQRT:%.*]] = call half @llvm.sqrt.f16(half [[X:%.*]]) +; CHECK-NEXT: [[CMP:%.*]] = fcmp ult half [[SQRT]], 0xH3C00 +; CHECK-NEXT: ret i1 [[CMP]] +; + %sqrt = call half @llvm.sqrt.f16(half %x) + %cmp = fcmp ult half %sqrt, 1.000000e+00 + ret i1 %cmp +} diff --git a/llvm/test/Transforms/InstCombine/known-never-nan.ll b/llvm/test/Transforms/InstCombine/known-never-nan.ll index a1cabc2..82075b3 100644 --- a/llvm/test/Transforms/InstCombine/known-never-nan.ll +++ b/llvm/test/Transforms/InstCombine/known-never-nan.ll @@ -9,9 +9,7 @@ define i1 @fabs_sqrt_src_maybe_nan(double %arg0, double %arg1) { ; CHECK-LABEL: @fabs_sqrt_src_maybe_nan( -; CHECK-NEXT: [[FABS:%.*]] = call double @llvm.fabs.f64(double [[ARG0:%.*]]) -; CHECK-NEXT: [[OP:%.*]] = call double @llvm.sqrt.f64(double [[FABS]]) -; CHECK-NEXT: [[TMP:%.*]] = fcmp ord double [[OP]], 0.000000e+00 +; CHECK-NEXT: [[TMP:%.*]] = fcmp ord double [[ARG0:%.*]], 0.000000e+00 ; CHECK-NEXT: ret i1 [[TMP]] ; %fabs = call double @llvm.fabs.f64(double %arg0) -- cgit v1.1