aboutsummaryrefslogtreecommitdiff
path: root/llvm
diff options
context:
space:
mode:
authorSanjay Patel <spatel@rotateright.com>2020-07-06 18:03:55 -0400
committerSanjay Patel <spatel@rotateright.com>2020-07-06 19:12:21 -0400
commitea71ba11ab1187af03a790dc20967ddd62f68bfe (patch)
tree1da086fd66de7b65f0a4c902838bef883a821695 /llvm
parent4029f8ede42f69f5fb5affb3eb008e03d448f407 (diff)
downloadllvm-ea71ba11ab1187af03a790dc20967ddd62f68bfe.zip
llvm-ea71ba11ab1187af03a790dc20967ddd62f68bfe.tar.gz
llvm-ea71ba11ab1187af03a790dc20967ddd62f68bfe.tar.bz2
[DAGCombiner] reassociate reciprocal sqrt expression to eliminate FP division
X / (fabs(A) * sqrt(Z)) --> X / sqrt(A*A*Z) --> X * rsqrt(A*A*Z) In the motivating case from PR46406: https://bugs.llvm.org/show_bug.cgi?id=46406 ...this is restoring the sequence that was originally in the source code. We extracted a term from within the sqrt because we do not know in instcombine whether a target will expand a sqrt call. Note: we could say that the transform in IR should be restricted, but that would not solve the problem if the source was originally in the pattern shown here. This is a gray area for fast-math-flag requirements. I think we should at least check fast-math-flags on the fdiv and fmul because I view this transform as 2 pieces: reassociate the fmul operands and form reciprocal from the fdiv (as with the existing transform). We could argue that the sqrt also needs FMF, but that was not required before, so we should change that in a follow-up patch if that seems better. We don't currently have a way to check that the target will produce a sqrt or recip estimate without actually creating nodes (the APIs are SDValue getSqrtEstimate() and SDValue getRecipEstimate()), so we clean up speculatively created nodes if we are not able to create an estimate. The x86 test with doubles verifies that we are not changing a test with no estimate sequence. Differential Revision: https://reviews.llvm.org/D82716
Diffstat (limited to 'llvm')
-rw-r--r--llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp18
-rw-r--r--llvm/test/CodeGen/X86/sqrt-fastmath.ll119
2 files changed, 81 insertions, 56 deletions
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 015d78a3..c94bbeb 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -13232,6 +13232,24 @@ SDValue DAGCombiner::visitFDIV(SDNode *N) {
Y = N1.getOperand(0);
}
if (Sqrt.getNode()) {
+ // If the other multiply operand is known positive, pull it into the
+ // sqrt. That will eliminate the division if we convert to an estimate:
+ // X / (fabs(A) * sqrt(Z)) --> X / sqrt(A*A*Z) --> X * rsqrt(A*A*Z)
+ // TODO: Also fold the case where A == Z (fabs is missing).
+ if (Flags.hasAllowReassociation() && N1.hasOneUse() &&
+ N1->getFlags().hasAllowReassociation() && Sqrt.hasOneUse() &&
+ Y.getOpcode() == ISD::FABS && Y.hasOneUse()) {
+ SDValue AA = DAG.getNode(ISD::FMUL, DL, VT, Y.getOperand(0),
+ Y.getOperand(0), Flags);
+ SDValue AAZ =
+ DAG.getNode(ISD::FMUL, DL, VT, AA, Sqrt.getOperand(0), Flags);
+ if (SDValue Rsqrt = buildRsqrtEstimate(AAZ, Flags))
+ return DAG.getNode(ISD::FMUL, DL, VT, N0, Rsqrt, Flags);
+
+ // Estimate creation failed. Clean up speculatively created nodes.
+ recursivelyDeleteUnusedNodes(AAZ.getNode());
+ }
+
// We found a FSQRT, so try to make this fold:
// X / (Y * sqrt(Z)) -> X * (rsqrt(Z) / Y)
if (SDValue Rsqrt = buildRsqrtEstimate(Sqrt.getOperand(0), Flags)) {
diff --git a/llvm/test/CodeGen/X86/sqrt-fastmath.ll b/llvm/test/CodeGen/X86/sqrt-fastmath.ll
index b1582d7..29d21f6 100644
--- a/llvm/test/CodeGen/X86/sqrt-fastmath.ll
+++ b/llvm/test/CodeGen/X86/sqrt-fastmath.ll
@@ -618,46 +618,47 @@ define <16 x float> @v16f32_estimate(<16 x float> %x) #1 {
ret <16 x float> %div
}
-; x / (fabs(y) * sqrt(z))
+; x / (fabs(y) * sqrt(z)) --> x * rsqrt(y*y*z)
define float @div_sqrt_fabs_f32(float %x, float %y, float %z) {
; SSE-LABEL: div_sqrt_fabs_f32:
; SSE: # %bb.0:
-; SSE-NEXT: rsqrtss %xmm2, %xmm3
-; SSE-NEXT: mulss %xmm3, %xmm2
-; SSE-NEXT: mulss %xmm3, %xmm2
-; SSE-NEXT: addss {{.*}}(%rip), %xmm2
-; SSE-NEXT: mulss {{.*}}(%rip), %xmm3
-; SSE-NEXT: mulss %xmm2, %xmm3
-; SSE-NEXT: andps {{.*}}(%rip), %xmm1
-; SSE-NEXT: divss %xmm1, %xmm3
-; SSE-NEXT: mulss %xmm3, %xmm0
+; SSE-NEXT: mulss %xmm1, %xmm1
+; SSE-NEXT: mulss %xmm2, %xmm1
+; SSE-NEXT: xorps %xmm2, %xmm2
+; SSE-NEXT: rsqrtss %xmm1, %xmm2
+; SSE-NEXT: mulss %xmm2, %xmm1
+; SSE-NEXT: mulss %xmm2, %xmm1
+; SSE-NEXT: addss {{.*}}(%rip), %xmm1
+; SSE-NEXT: mulss {{.*}}(%rip), %xmm2
+; SSE-NEXT: mulss %xmm0, %xmm2
+; SSE-NEXT: mulss %xmm1, %xmm2
+; SSE-NEXT: movaps %xmm2, %xmm0
; SSE-NEXT: retq
;
; AVX1-LABEL: div_sqrt_fabs_f32:
; AVX1: # %bb.0:
-; AVX1-NEXT: vrsqrtss %xmm2, %xmm2, %xmm3
-; AVX1-NEXT: vmulss %xmm3, %xmm2, %xmm2
-; AVX1-NEXT: vmulss %xmm3, %xmm2, %xmm2
-; AVX1-NEXT: vaddss {{.*}}(%rip), %xmm2, %xmm2
-; AVX1-NEXT: vmulss {{.*}}(%rip), %xmm3, %xmm3
-; AVX1-NEXT: vmulss %xmm2, %xmm3, %xmm2
-; AVX1-NEXT: vandps {{.*}}(%rip), %xmm1, %xmm1
-; AVX1-NEXT: vdivss %xmm1, %xmm2, %xmm1
-; AVX1-NEXT: vmulss %xmm1, %xmm0, %xmm0
+; AVX1-NEXT: vmulss %xmm1, %xmm1, %xmm1
+; AVX1-NEXT: vmulss %xmm2, %xmm1, %xmm1
+; AVX1-NEXT: vrsqrtss %xmm1, %xmm1, %xmm2
+; AVX1-NEXT: vmulss %xmm2, %xmm1, %xmm1
+; AVX1-NEXT: vmulss %xmm2, %xmm1, %xmm1
+; AVX1-NEXT: vaddss {{.*}}(%rip), %xmm1, %xmm1
+; AVX1-NEXT: vmulss {{.*}}(%rip), %xmm2, %xmm2
+; AVX1-NEXT: vmulss %xmm0, %xmm2, %xmm0
+; AVX1-NEXT: vmulss %xmm0, %xmm1, %xmm0
; AVX1-NEXT: retq
;
; AVX512-LABEL: div_sqrt_fabs_f32:
; AVX512: # %bb.0:
-; AVX512-NEXT: vrsqrtss %xmm2, %xmm2, %xmm3
-; AVX512-NEXT: vmulss %xmm3, %xmm2, %xmm2
-; AVX512-NEXT: vfmadd213ss {{.*#+}} xmm2 = (xmm3 * xmm2) + mem
-; AVX512-NEXT: vmulss {{.*}}(%rip), %xmm3, %xmm3
-; AVX512-NEXT: vbroadcastss {{.*#+}} xmm4 = [NaN,NaN,NaN,NaN]
-; AVX512-NEXT: vmulss %xmm2, %xmm3, %xmm2
-; AVX512-NEXT: vandps %xmm4, %xmm1, %xmm1
-; AVX512-NEXT: vdivss %xmm1, %xmm2, %xmm1
-; AVX512-NEXT: vmulss %xmm1, %xmm0, %xmm0
+; AVX512-NEXT: vmulss %xmm1, %xmm1, %xmm1
+; AVX512-NEXT: vmulss %xmm2, %xmm1, %xmm1
+; AVX512-NEXT: vrsqrtss %xmm1, %xmm1, %xmm2
+; AVX512-NEXT: vmulss %xmm2, %xmm1, %xmm1
+; AVX512-NEXT: vfmadd213ss {{.*#+}} xmm1 = (xmm2 * xmm1) + mem
+; AVX512-NEXT: vmulss {{.*}}(%rip), %xmm2, %xmm2
+; AVX512-NEXT: vmulss %xmm0, %xmm2, %xmm0
+; AVX512-NEXT: vmulss %xmm0, %xmm1, %xmm0
; AVX512-NEXT: retq
%s = call fast float @llvm.sqrt.f32(float %z)
%a = call fast float @llvm.fabs.f32(float %y)
@@ -666,47 +667,46 @@ define float @div_sqrt_fabs_f32(float %x, float %y, float %z) {
ret float %d
}
-; x / (fabs(y) * sqrt(z))
+; x / (fabs(y) * sqrt(z)) --> x * rsqrt(y*y*z)
define <4 x float> @div_sqrt_fabs_v4f32(<4 x float> %x, <4 x float> %y, <4 x float> %z) {
; SSE-LABEL: div_sqrt_fabs_v4f32:
; SSE: # %bb.0:
-; SSE-NEXT: rsqrtps %xmm2, %xmm3
-; SSE-NEXT: mulps %xmm3, %xmm2
-; SSE-NEXT: mulps %xmm3, %xmm2
-; SSE-NEXT: addps {{.*}}(%rip), %xmm2
-; SSE-NEXT: mulps {{.*}}(%rip), %xmm3
-; SSE-NEXT: mulps %xmm2, %xmm3
-; SSE-NEXT: andps {{.*}}(%rip), %xmm1
-; SSE-NEXT: divps %xmm1, %xmm3
-; SSE-NEXT: mulps %xmm3, %xmm0
+; SSE-NEXT: mulps %xmm1, %xmm1
+; SSE-NEXT: mulps %xmm2, %xmm1
+; SSE-NEXT: rsqrtps %xmm1, %xmm2
+; SSE-NEXT: mulps %xmm2, %xmm1
+; SSE-NEXT: mulps %xmm2, %xmm1
+; SSE-NEXT: addps {{.*}}(%rip), %xmm1
+; SSE-NEXT: mulps {{.*}}(%rip), %xmm2
+; SSE-NEXT: mulps %xmm1, %xmm2
+; SSE-NEXT: mulps %xmm2, %xmm0
; SSE-NEXT: retq
;
; AVX1-LABEL: div_sqrt_fabs_v4f32:
; AVX1: # %bb.0:
-; AVX1-NEXT: vrsqrtps %xmm2, %xmm3
-; AVX1-NEXT: vmulps %xmm3, %xmm2, %xmm2
-; AVX1-NEXT: vmulps %xmm3, %xmm2, %xmm2
-; AVX1-NEXT: vaddps {{.*}}(%rip), %xmm2, %xmm2
-; AVX1-NEXT: vmulps {{.*}}(%rip), %xmm3, %xmm3
-; AVX1-NEXT: vmulps %xmm2, %xmm3, %xmm2
-; AVX1-NEXT: vandps {{.*}}(%rip), %xmm1, %xmm1
-; AVX1-NEXT: vdivps %xmm1, %xmm2, %xmm1
+; AVX1-NEXT: vmulps %xmm1, %xmm1, %xmm1
+; AVX1-NEXT: vmulps %xmm2, %xmm1, %xmm1
+; AVX1-NEXT: vrsqrtps %xmm1, %xmm2
+; AVX1-NEXT: vmulps %xmm2, %xmm1, %xmm1
+; AVX1-NEXT: vmulps %xmm2, %xmm1, %xmm1
+; AVX1-NEXT: vaddps {{.*}}(%rip), %xmm1, %xmm1
+; AVX1-NEXT: vmulps {{.*}}(%rip), %xmm2, %xmm2
+; AVX1-NEXT: vmulps %xmm1, %xmm2, %xmm1
; AVX1-NEXT: vmulps %xmm1, %xmm0, %xmm0
; AVX1-NEXT: retq
;
; AVX512-LABEL: div_sqrt_fabs_v4f32:
; AVX512: # %bb.0:
-; AVX512-NEXT: vrsqrtps %xmm2, %xmm3
-; AVX512-NEXT: vmulps %xmm3, %xmm2, %xmm2
-; AVX512-NEXT: vbroadcastss {{.*#+}} xmm4 = [-3.0E+0,-3.0E+0,-3.0E+0,-3.0E+0]
-; AVX512-NEXT: vfmadd231ps {{.*#+}} xmm4 = (xmm3 * xmm2) + xmm4
-; AVX512-NEXT: vbroadcastss {{.*#+}} xmm2 = [-5.0E-1,-5.0E-1,-5.0E-1,-5.0E-1]
-; AVX512-NEXT: vmulps %xmm2, %xmm3, %xmm2
-; AVX512-NEXT: vbroadcastss {{.*#+}} xmm3 = [NaN,NaN,NaN,NaN]
-; AVX512-NEXT: vmulps %xmm4, %xmm2, %xmm2
-; AVX512-NEXT: vandps %xmm3, %xmm1, %xmm1
-; AVX512-NEXT: vdivps %xmm1, %xmm2, %xmm1
+; AVX512-NEXT: vmulps %xmm1, %xmm1, %xmm1
+; AVX512-NEXT: vmulps %xmm2, %xmm1, %xmm1
+; AVX512-NEXT: vrsqrtps %xmm1, %xmm2
+; AVX512-NEXT: vmulps %xmm2, %xmm1, %xmm1
+; AVX512-NEXT: vbroadcastss {{.*#+}} xmm3 = [-3.0E+0,-3.0E+0,-3.0E+0,-3.0E+0]
+; AVX512-NEXT: vfmadd231ps {{.*#+}} xmm3 = (xmm2 * xmm1) + xmm3
+; AVX512-NEXT: vbroadcastss {{.*#+}} xmm1 = [-5.0E-1,-5.0E-1,-5.0E-1,-5.0E-1]
+; AVX512-NEXT: vmulps %xmm1, %xmm2, %xmm1
+; AVX512-NEXT: vmulps %xmm3, %xmm1, %xmm1
; AVX512-NEXT: vmulps %xmm1, %xmm0, %xmm0
; AVX512-NEXT: retq
%s = call <4 x float> @llvm.sqrt.v4f32(<4 x float> %z)
@@ -716,6 +716,11 @@ define <4 x float> @div_sqrt_fabs_v4f32(<4 x float> %x, <4 x float> %y, <4 x flo
ret <4 x float> %d
}
+; This has 'arcp' but does not have 'reassoc' FMF.
+; We allow converting the sqrt to an estimate, but
+; do not pull the divisor into the estimate.
+; x / (fabs(y) * sqrt(z)) --> x * rsqrt(z) / fabs(y)
+
define <4 x float> @div_sqrt_fabs_v4f32_fmf(<4 x float> %x, <4 x float> %y, <4 x float> %z) {
; SSE-LABEL: div_sqrt_fabs_v4f32_fmf:
; SSE: # %bb.0:
@@ -765,6 +770,8 @@ define <4 x float> @div_sqrt_fabs_v4f32_fmf(<4 x float> %x, <4 x float> %y, <4 x
ret <4 x float> %d
}
+; No estimates for f64, so do not convert fabs into an fmul.
+
define double @div_sqrt_fabs_f64(double %x, double %y, double %z) {
; SSE-LABEL: div_sqrt_fabs_f64:
; SSE: # %bb.0: