aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp')
-rw-r--r--llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp59
1 files changed, 59 insertions, 0 deletions
diff --git a/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp b/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
index a30edf5..9339ba3 100644
--- a/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
+++ b/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
@@ -371,6 +371,10 @@ private:
NodePtr identifyPHINode(Instruction *Real, Instruction *Imag);
+ /// Identifies SelectInsts in a loop that has reduction with predication masks
+ /// and/or predicated tail folding
+ NodePtr identifySelectNode(Instruction *Real, Instruction *Imag);
+
Value *replaceNode(IRBuilderBase &Builder, RawNodePtr Node);
/// Complete IR modifications after producing new reduction operation:
@@ -889,6 +893,9 @@ ComplexDeinterleavingGraph::identifyNode(Instruction *Real, Instruction *Imag) {
if (NodePtr CN = identifyPHINode(Real, Imag))
return CN;
+ if (NodePtr CN = identifySelectNode(Real, Imag))
+ return CN;
+
auto *VTy = cast<VectorType>(Real->getType());
auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
@@ -1713,6 +1720,45 @@ ComplexDeinterleavingGraph::identifyPHINode(Instruction *Real,
return submitCompositeNode(PlaceholderNode);
}
+ComplexDeinterleavingGraph::NodePtr
+ComplexDeinterleavingGraph::identifySelectNode(Instruction *Real,
+ Instruction *Imag) {
+ auto *SelectReal = dyn_cast<SelectInst>(Real);
+ auto *SelectImag = dyn_cast<SelectInst>(Imag);
+ if (!SelectReal || !SelectImag)
+ return nullptr;
+
+ Instruction *MaskA, *MaskB;
+ Instruction *AR, *AI, *RA, *BI;
+ if (!match(Real, m_Select(m_Instruction(MaskA), m_Instruction(AR),
+ m_Instruction(RA))) ||
+ !match(Imag, m_Select(m_Instruction(MaskB), m_Instruction(AI),
+ m_Instruction(BI))))
+ return nullptr;
+
+ if (MaskA != MaskB && !MaskA->isIdenticalTo(MaskB))
+ return nullptr;
+
+ if (!MaskA->getType()->isVectorTy())
+ return nullptr;
+
+ auto NodeA = identifyNode(AR, AI);
+ if (!NodeA)
+ return nullptr;
+
+ auto NodeB = identifyNode(RA, BI);
+ if (!NodeB)
+ return nullptr;
+
+ NodePtr PlaceholderNode = prepareCompositeNode(
+ ComplexDeinterleavingOperation::ReductionSelect, Real, Imag);
+ PlaceholderNode->addOperand(NodeA);
+ PlaceholderNode->addOperand(NodeB);
+ FinalInstructions.insert(MaskA);
+ FinalInstructions.insert(MaskB);
+ return submitCompositeNode(PlaceholderNode);
+}
+
static Value *replaceSymmetricNode(IRBuilderBase &B, unsigned Opcode,
FastMathFlags Flags, Value *InputA,
Value *InputB) {
@@ -1787,6 +1833,19 @@ Value *ComplexDeinterleavingGraph::replaceNode(IRBuilderBase &Builder,
ReplacementNode = replaceNode(Builder, Node->Operands[0]);
processReductionOperation(ReplacementNode, Node);
break;
+ case ComplexDeinterleavingOperation::ReductionSelect: {
+ auto *MaskReal = Node->Real->getOperand(0);
+ auto *MaskImag = Node->Imag->getOperand(0);
+ auto *A = replaceNode(Builder, Node->Operands[0]);
+ auto *B = replaceNode(Builder, Node->Operands[1]);
+ auto *NewMaskTy = VectorType::getDoubleElementsVectorType(
+ cast<VectorType>(MaskReal->getType()));
+ auto *NewMask =
+ Builder.CreateIntrinsic(Intrinsic::experimental_vector_interleave2,
+ NewMaskTy, {MaskReal, MaskImag});
+ ReplacementNode = Builder.CreateSelect(NewMask, A, B);
+ break;
+ }
}
assert(ReplacementNode && "Target failed to create Intrinsic call.");