aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
diff options
context:
space:
mode:
authorIgor Kirillov <igor.kirillov@arm.com>2023-06-23 09:55:34 +0000
committerIgor Kirillov <igor.kirillov@arm.com>2023-06-23 10:13:22 +0000
commit04a8070b46da2bcd47d0a134922409dc16bb9d57 (patch)
tree33ce3a20bc949144b0d2d419db4d350a1e6c51ab /llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
parent2273741ea2547ecda28cb01d7679d0563b35ac16 (diff)
downloadllvm-04a8070b46da2bcd47d0a134922409dc16bb9d57.zip
llvm-04a8070b46da2bcd47d0a134922409dc16bb9d57.tar.gz
llvm-04a8070b46da2bcd47d0a134922409dc16bb9d57.tar.bz2
Revert "Revert "[CodeGen] Extend reduction support in ComplexDeinterleaving pass to support predication""
Adds the capability to recognize SelectInst that appear in the IR. These instructions are generated during scalable vectorization for reduction and when the code contains conditions inside the loop body or when "-prefer-predicate-over-epilogue=predicate-dont-vectorize" is set. Differential Revision: https://reviews.llvm.org/D152558 This reverts commit ab09654832dba5cef8baa6400fdfd3e4d1495624. Reason: Reapplying after removing unnecessary default case in switch expression.
Diffstat (limited to 'llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp')
-rw-r--r--llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp59
1 files changed, 59 insertions, 0 deletions
diff --git a/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp b/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
index a30edf5..9339ba3 100644
--- a/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
+++ b/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
@@ -371,6 +371,10 @@ private:
NodePtr identifyPHINode(Instruction *Real, Instruction *Imag);
+ /// Identifies SelectInsts in a loop that has reduction with predication masks
+ /// and/or predicated tail folding
+ NodePtr identifySelectNode(Instruction *Real, Instruction *Imag);
+
Value *replaceNode(IRBuilderBase &Builder, RawNodePtr Node);
/// Complete IR modifications after producing new reduction operation:
@@ -889,6 +893,9 @@ ComplexDeinterleavingGraph::identifyNode(Instruction *Real, Instruction *Imag) {
if (NodePtr CN = identifyPHINode(Real, Imag))
return CN;
+ if (NodePtr CN = identifySelectNode(Real, Imag))
+ return CN;
+
auto *VTy = cast<VectorType>(Real->getType());
auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
@@ -1713,6 +1720,45 @@ ComplexDeinterleavingGraph::identifyPHINode(Instruction *Real,
return submitCompositeNode(PlaceholderNode);
}
+ComplexDeinterleavingGraph::NodePtr
+ComplexDeinterleavingGraph::identifySelectNode(Instruction *Real,
+ Instruction *Imag) {
+ auto *SelectReal = dyn_cast<SelectInst>(Real);
+ auto *SelectImag = dyn_cast<SelectInst>(Imag);
+ if (!SelectReal || !SelectImag)
+ return nullptr;
+
+ Instruction *MaskA, *MaskB;
+ Instruction *AR, *AI, *RA, *BI;
+ if (!match(Real, m_Select(m_Instruction(MaskA), m_Instruction(AR),
+ m_Instruction(RA))) ||
+ !match(Imag, m_Select(m_Instruction(MaskB), m_Instruction(AI),
+ m_Instruction(BI))))
+ return nullptr;
+
+ if (MaskA != MaskB && !MaskA->isIdenticalTo(MaskB))
+ return nullptr;
+
+ if (!MaskA->getType()->isVectorTy())
+ return nullptr;
+
+ auto NodeA = identifyNode(AR, AI);
+ if (!NodeA)
+ return nullptr;
+
+ auto NodeB = identifyNode(RA, BI);
+ if (!NodeB)
+ return nullptr;
+
+ NodePtr PlaceholderNode = prepareCompositeNode(
+ ComplexDeinterleavingOperation::ReductionSelect, Real, Imag);
+ PlaceholderNode->addOperand(NodeA);
+ PlaceholderNode->addOperand(NodeB);
+ FinalInstructions.insert(MaskA);
+ FinalInstructions.insert(MaskB);
+ return submitCompositeNode(PlaceholderNode);
+}
+
static Value *replaceSymmetricNode(IRBuilderBase &B, unsigned Opcode,
FastMathFlags Flags, Value *InputA,
Value *InputB) {
@@ -1787,6 +1833,19 @@ Value *ComplexDeinterleavingGraph::replaceNode(IRBuilderBase &Builder,
ReplacementNode = replaceNode(Builder, Node->Operands[0]);
processReductionOperation(ReplacementNode, Node);
break;
+ case ComplexDeinterleavingOperation::ReductionSelect: {
+ auto *MaskReal = Node->Real->getOperand(0);
+ auto *MaskImag = Node->Imag->getOperand(0);
+ auto *A = replaceNode(Builder, Node->Operands[0]);
+ auto *B = replaceNode(Builder, Node->Operands[1]);
+ auto *NewMaskTy = VectorType::getDoubleElementsVectorType(
+ cast<VectorType>(MaskReal->getType()));
+ auto *NewMask =
+ Builder.CreateIntrinsic(Intrinsic::experimental_vector_interleave2,
+ NewMaskTy, {MaskReal, MaskImag});
+ ReplacementNode = Builder.CreateSelect(NewMask, A, B);
+ break;
+ }
}
assert(ReplacementNode && "Target failed to create Intrinsic call.");