aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorZain Jaffal <zain@jjaffal.com>2024-03-09 17:15:14 +0000
committerGitHub <noreply@github.com>2024-03-09 17:15:14 +0000
commitf5811494b0cde306e98caa339e4dc1c06cb5e8e9 (patch)
treee230c9b62fec528ded123e4b9bcb78a799a2a63d
parent57a2229a2f62746d5616f0bd82a03b23eb459cf3 (diff)
downloadllvm-f5811494b0cde306e98caa339e4dc1c06cb5e8e9.zip
llvm-f5811494b0cde306e98caa339e4dc1c06cb5e8e9.tar.gz
llvm-f5811494b0cde306e98caa339e4dc1c06cb5e8e9.tar.bz2
check if operand is div in fold FDivSqrtDivisor (#81970)
This patch fixes the issues introduced in https://github.com/llvm/llvm-project/commit/bb5c3899d1936ebdf7ebf5ca4347ee2e057bee7f. I moved the check for the instruction to be div before I check for the fast math flags which resolves the crash in ``` float a, b; double sqrt(); void c() { b = a / sqrt(a); } ``` --------- Co-authored-by: Matt Arsenault <arsenm2@gmail.com>
-rw-r--r--llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp31
-rw-r--r--llvm/test/Transforms/InstCombine/fdiv-sqrt.ll34
2 files changed, 56 insertions, 9 deletions
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
index 3ebf6b3..278be62 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
@@ -1709,6 +1709,34 @@ static Instruction *foldFDivPowDivisor(BinaryOperator &I,
return BinaryOperator::CreateFMulFMF(Op0, Pow, &I);
}
+/// Convert div to mul if we have an sqrt divisor iff sqrt's operand is a fdiv
+/// instruction.
+static Instruction *foldFDivSqrtDivisor(BinaryOperator &I,
+ InstCombiner::BuilderTy &Builder) {
+ // X / sqrt(Y / Z) --> X * sqrt(Z / Y)
+ if (!I.hasAllowReassoc() || !I.hasAllowReciprocal())
+ return nullptr;
+ Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
+ auto *II = dyn_cast<IntrinsicInst>(Op1);
+ if (!II || II->getIntrinsicID() != Intrinsic::sqrt || !II->hasOneUse() ||
+ !II->hasAllowReassoc() || !II->hasAllowReciprocal())
+ return nullptr;
+
+ Value *Y, *Z;
+ auto *DivOp = dyn_cast<Instruction>(II->getOperand(0));
+ if (!DivOp)
+ return nullptr;
+ if (!match(DivOp, m_FDiv(m_Value(Y), m_Value(Z))))
+ return nullptr;
+ if (!DivOp->hasAllowReassoc() || !I.hasAllowReciprocal() ||
+ !DivOp->hasOneUse())
+ return nullptr;
+ Value *SwapDiv = Builder.CreateFDivFMF(Z, Y, DivOp);
+ Value *NewSqrt =
+ Builder.CreateUnaryIntrinsic(II->getIntrinsicID(), SwapDiv, II);
+ return BinaryOperator::CreateFMulFMF(Op0, NewSqrt, &I);
+}
+
Instruction *InstCombinerImpl::visitFDiv(BinaryOperator &I) {
Module *M = I.getModule();
@@ -1816,6 +1844,9 @@ Instruction *InstCombinerImpl::visitFDiv(BinaryOperator &I) {
if (Instruction *Mul = foldFDivPowDivisor(I, Builder))
return Mul;
+ if (Instruction *Mul = foldFDivSqrtDivisor(I, Builder))
+ return Mul;
+
// pow(X, Y) / X --> pow(X, Y-1)
if (I.hasAllowReassoc() &&
match(Op0, m_OneUse(m_Intrinsic<Intrinsic::pow>(m_Specific(Op1),
diff --git a/llvm/test/Transforms/InstCombine/fdiv-sqrt.ll b/llvm/test/Transforms/InstCombine/fdiv-sqrt.ll
index 346271b..9f030c5 100644
--- a/llvm/test/Transforms/InstCombine/fdiv-sqrt.ll
+++ b/llvm/test/Transforms/InstCombine/fdiv-sqrt.ll
@@ -6,9 +6,9 @@ declare double @llvm.sqrt.f64(double)
define double @sqrt_div_fast(double %x, double %y, double %z) {
; CHECK-LABEL: @sqrt_div_fast(
; CHECK-NEXT: entry:
-; CHECK-NEXT: [[DIV:%.*]] = fdiv fast double [[Y:%.*]], [[Z:%.*]]
-; CHECK-NEXT: [[SQRT:%.*]] = call fast double @llvm.sqrt.f64(double [[DIV]])
-; CHECK-NEXT: [[DIV1:%.*]] = fdiv fast double [[X:%.*]], [[SQRT]]
+; CHECK-NEXT: [[TMP0:%.*]] = fdiv fast double [[Z:%.*]], [[Y:%.*]]
+; CHECK-NEXT: [[TMP1:%.*]] = call fast double @llvm.sqrt.f64(double [[TMP0]])
+; CHECK-NEXT: [[DIV1:%.*]] = fmul fast double [[TMP1]], [[X:%.*]]
; CHECK-NEXT: ret double [[DIV1]]
;
entry:
@@ -36,9 +36,9 @@ entry:
define double @sqrt_div_reassoc_arcp(double %x, double %y, double %z) {
; CHECK-LABEL: @sqrt_div_reassoc_arcp(
; CHECK-NEXT: entry:
-; CHECK-NEXT: [[DIV:%.*]] = fdiv reassoc arcp double [[Y:%.*]], [[Z:%.*]]
-; CHECK-NEXT: [[SQRT:%.*]] = call reassoc arcp double @llvm.sqrt.f64(double [[DIV]])
-; CHECK-NEXT: [[DIV1:%.*]] = fdiv reassoc arcp double [[X:%.*]], [[SQRT]]
+; CHECK-NEXT: [[TMP0:%.*]] = fdiv reassoc arcp double [[Z:%.*]], [[Y:%.*]]
+; CHECK-NEXT: [[TMP1:%.*]] = call reassoc arcp double @llvm.sqrt.f64(double [[TMP0]])
+; CHECK-NEXT: [[DIV1:%.*]] = fmul reassoc arcp double [[TMP1]], [[X:%.*]]
; CHECK-NEXT: ret double [[DIV1]]
;
entry:
@@ -96,9 +96,9 @@ entry:
define double @sqrt_div_arcp_missing(double %x, double %y, double %z) {
; CHECK-LABEL: @sqrt_div_arcp_missing(
; CHECK-NEXT: entry:
-; CHECK-NEXT: [[DIV:%.*]] = fdiv reassoc double [[Y:%.*]], [[Z:%.*]]
-; CHECK-NEXT: [[SQRT:%.*]] = call reassoc arcp double @llvm.sqrt.f64(double [[DIV]])
-; CHECK-NEXT: [[DIV1:%.*]] = fdiv reassoc arcp double [[X:%.*]], [[SQRT]]
+; CHECK-NEXT: [[TMP0:%.*]] = fdiv reassoc double [[Z:%.*]], [[Y:%.*]]
+; CHECK-NEXT: [[TMP1:%.*]] = call reassoc arcp double @llvm.sqrt.f64(double [[TMP0]])
+; CHECK-NEXT: [[DIV1:%.*]] = fmul reassoc arcp double [[TMP1]], [[X:%.*]]
; CHECK-NEXT: ret double [[DIV1]]
;
entry:
@@ -173,3 +173,19 @@ entry:
ret double %div1
}
+define float @sqrt_non_div_operator(float %a) {
+; CHECK-LABEL: @sqrt_non_div_operator(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: [[CONV:%.*]] = fpext float [[A:%.*]] to double
+; CHECK-NEXT: [[SQRT:%.*]] = call fast double @llvm.sqrt.f64(double [[CONV]])
+; CHECK-NEXT: [[DIV:%.*]] = fdiv fast double [[CONV]], [[SQRT]]
+; CHECK-NEXT: [[CONV2:%.*]] = fptrunc double [[DIV]] to float
+; CHECK-NEXT: ret float [[CONV2]]
+;
+entry:
+ %conv = fpext float %a to double
+ %sqrt = call fast double @llvm.sqrt.f64(double %conv)
+ %div = fdiv fast double %conv, %sqrt
+ %conv2 = fptrunc double %div to float
+ ret float %conv2
+}