diff options
author | Igor Kirillov <igor.kirillov@arm.com> | 2023-06-23 09:55:34 +0000 |
---|---|---|
committer | Igor Kirillov <igor.kirillov@arm.com> | 2023-06-23 10:13:22 +0000 |
commit | 04a8070b46da2bcd47d0a134922409dc16bb9d57 (patch) | |
tree | 33ce3a20bc949144b0d2d419db4d350a1e6c51ab /llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp | |
parent | 2273741ea2547ecda28cb01d7679d0563b35ac16 (diff) | |
download | llvm-04a8070b46da2bcd47d0a134922409dc16bb9d57.zip llvm-04a8070b46da2bcd47d0a134922409dc16bb9d57.tar.gz llvm-04a8070b46da2bcd47d0a134922409dc16bb9d57.tar.bz2 |
Revert "Revert "[CodeGen] Extend reduction support in ComplexDeinterleaving pass to support predication""
Adds the capability to recognize SelectInst that appear in the IR.
These instructions are generated during scalable vectorization for reduction
and when the code contains conditions inside the loop body or when
"-prefer-predicate-over-epilogue=predicate-dont-vectorize" is set.
Differential Revision: https://reviews.llvm.org/D152558
This reverts commit ab09654832dba5cef8baa6400fdfd3e4d1495624.
Reason: Reapplying after removing unnecessary default case in switch expression.
Diffstat (limited to 'llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp')
-rw-r--r-- | llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp | 59 |
1 files changed, 59 insertions, 0 deletions
diff --git a/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp b/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp index a30edf5..9339ba3 100644 --- a/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp +++ b/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp @@ -371,6 +371,10 @@ private: NodePtr identifyPHINode(Instruction *Real, Instruction *Imag); + /// Identifies SelectInsts in a loop that has reduction with predication masks + /// and/or predicated tail folding + NodePtr identifySelectNode(Instruction *Real, Instruction *Imag); + Value *replaceNode(IRBuilderBase &Builder, RawNodePtr Node); /// Complete IR modifications after producing new reduction operation: @@ -889,6 +893,9 @@ ComplexDeinterleavingGraph::identifyNode(Instruction *Real, Instruction *Imag) { if (NodePtr CN = identifyPHINode(Real, Imag)) return CN; + if (NodePtr CN = identifySelectNode(Real, Imag)) + return CN; + auto *VTy = cast<VectorType>(Real->getType()); auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy); @@ -1713,6 +1720,45 @@ ComplexDeinterleavingGraph::identifyPHINode(Instruction *Real, return submitCompositeNode(PlaceholderNode); } +ComplexDeinterleavingGraph::NodePtr +ComplexDeinterleavingGraph::identifySelectNode(Instruction *Real, + Instruction *Imag) { + auto *SelectReal = dyn_cast<SelectInst>(Real); + auto *SelectImag = dyn_cast<SelectInst>(Imag); + if (!SelectReal || !SelectImag) + return nullptr; + + Instruction *MaskA, *MaskB; + Instruction *AR, *AI, *RA, *BI; + if (!match(Real, m_Select(m_Instruction(MaskA), m_Instruction(AR), + m_Instruction(RA))) || + !match(Imag, m_Select(m_Instruction(MaskB), m_Instruction(AI), + m_Instruction(BI)))) + return nullptr; + + if (MaskA != MaskB && !MaskA->isIdenticalTo(MaskB)) + return nullptr; + + if (!MaskA->getType()->isVectorTy()) + return nullptr; + + auto NodeA = identifyNode(AR, AI); + if (!NodeA) + return nullptr; + + auto NodeB = identifyNode(RA, BI); + if (!NodeB) + return nullptr; + + NodePtr PlaceholderNode = prepareCompositeNode( + ComplexDeinterleavingOperation::ReductionSelect, Real, Imag); + PlaceholderNode->addOperand(NodeA); + PlaceholderNode->addOperand(NodeB); + FinalInstructions.insert(MaskA); + FinalInstructions.insert(MaskB); + return submitCompositeNode(PlaceholderNode); +} + static Value *replaceSymmetricNode(IRBuilderBase &B, unsigned Opcode, FastMathFlags Flags, Value *InputA, Value *InputB) { @@ -1787,6 +1833,19 @@ Value *ComplexDeinterleavingGraph::replaceNode(IRBuilderBase &Builder, ReplacementNode = replaceNode(Builder, Node->Operands[0]); processReductionOperation(ReplacementNode, Node); break; + case ComplexDeinterleavingOperation::ReductionSelect: { + auto *MaskReal = Node->Real->getOperand(0); + auto *MaskImag = Node->Imag->getOperand(0); + auto *A = replaceNode(Builder, Node->Operands[0]); + auto *B = replaceNode(Builder, Node->Operands[1]); + auto *NewMaskTy = VectorType::getDoubleElementsVectorType( + cast<VectorType>(MaskReal->getType())); + auto *NewMask = + Builder.CreateIntrinsic(Intrinsic::experimental_vector_interleave2, + NewMaskTy, {MaskReal, MaskImag}); + ReplacementNode = Builder.CreateSelect(NewMask, A, B); + break; + } } assert(ReplacementNode && "Target failed to create Intrinsic call."); |