diff options
Diffstat (limited to 'llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp')
-rw-r--r-- | llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp | 59 |
1 files changed, 59 insertions, 0 deletions
diff --git a/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp b/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp index a30edf5..9339ba3 100644 --- a/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp +++ b/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp @@ -371,6 +371,10 @@ private: NodePtr identifyPHINode(Instruction *Real, Instruction *Imag); + /// Identifies SelectInsts in a loop that has reduction with predication masks + /// and/or predicated tail folding + NodePtr identifySelectNode(Instruction *Real, Instruction *Imag); + Value *replaceNode(IRBuilderBase &Builder, RawNodePtr Node); /// Complete IR modifications after producing new reduction operation: @@ -889,6 +893,9 @@ ComplexDeinterleavingGraph::identifyNode(Instruction *Real, Instruction *Imag) { if (NodePtr CN = identifyPHINode(Real, Imag)) return CN; + if (NodePtr CN = identifySelectNode(Real, Imag)) + return CN; + auto *VTy = cast<VectorType>(Real->getType()); auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy); @@ -1713,6 +1720,45 @@ ComplexDeinterleavingGraph::identifyPHINode(Instruction *Real, return submitCompositeNode(PlaceholderNode); } +ComplexDeinterleavingGraph::NodePtr +ComplexDeinterleavingGraph::identifySelectNode(Instruction *Real, + Instruction *Imag) { + auto *SelectReal = dyn_cast<SelectInst>(Real); + auto *SelectImag = dyn_cast<SelectInst>(Imag); + if (!SelectReal || !SelectImag) + return nullptr; + + Instruction *MaskA, *MaskB; + Instruction *AR, *AI, *RA, *BI; + if (!match(Real, m_Select(m_Instruction(MaskA), m_Instruction(AR), + m_Instruction(RA))) || + !match(Imag, m_Select(m_Instruction(MaskB), m_Instruction(AI), + m_Instruction(BI)))) + return nullptr; + + if (MaskA != MaskB && !MaskA->isIdenticalTo(MaskB)) + return nullptr; + + if (!MaskA->getType()->isVectorTy()) + return nullptr; + + auto NodeA = identifyNode(AR, AI); + if (!NodeA) + return nullptr; + + auto NodeB = identifyNode(RA, BI); + if (!NodeB) + return nullptr; + + NodePtr PlaceholderNode = prepareCompositeNode( + ComplexDeinterleavingOperation::ReductionSelect, Real, Imag); + PlaceholderNode->addOperand(NodeA); + PlaceholderNode->addOperand(NodeB); + FinalInstructions.insert(MaskA); + FinalInstructions.insert(MaskB); + return submitCompositeNode(PlaceholderNode); +} + static Value *replaceSymmetricNode(IRBuilderBase &B, unsigned Opcode, FastMathFlags Flags, Value *InputA, Value *InputB) { @@ -1787,6 +1833,19 @@ Value *ComplexDeinterleavingGraph::replaceNode(IRBuilderBase &Builder, ReplacementNode = replaceNode(Builder, Node->Operands[0]); processReductionOperation(ReplacementNode, Node); break; + case ComplexDeinterleavingOperation::ReductionSelect: { + auto *MaskReal = Node->Real->getOperand(0); + auto *MaskImag = Node->Imag->getOperand(0); + auto *A = replaceNode(Builder, Node->Operands[0]); + auto *B = replaceNode(Builder, Node->Operands[1]); + auto *NewMaskTy = VectorType::getDoubleElementsVectorType( + cast<VectorType>(MaskReal->getType())); + auto *NewMask = + Builder.CreateIntrinsic(Intrinsic::experimental_vector_interleave2, + NewMaskTy, {MaskReal, MaskImag}); + ReplacementNode = Builder.CreateSelect(NewMask, A, B); + break; + } } assert(ReplacementNode && "Target failed to create Intrinsic call."); |