diff options
Diffstat (limited to 'llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp')
-rw-r--r-- | llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp | 319 |
1 files changed, 295 insertions, 24 deletions
diff --git a/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp b/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp index ec7abb2..1c61179 100644 --- a/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp +++ b/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp @@ -228,8 +228,53 @@ private: /// Topologically sorted root instructions SmallVector<Instruction *, 1> OrderedRoots; + /// When examining a basic block for complex deinterleaving, if it is a simple + /// one-block loop, then the only incoming block is 'Incoming' and the + /// 'BackEdge' block is the block itself." + BasicBlock *BackEdge = nullptr; + BasicBlock *Incoming = nullptr; + + /// ReductionInfo maps from %ReductionOp to %PHInode and Instruction + /// %OutsideUser as it is shown in the IR: + /// + /// vector.body: + /// %PHInode = phi <vector type> [ zeroinitializer, %entry ], + /// [ %ReductionOp, %vector.body ] + /// ... + /// %ReductionOp = fadd i64 ... + /// ... + /// br i1 %condition, label %vector.body, %middle.block + /// + /// middle.block: + /// %OutsideUser = llvm.vector.reduce.fadd(..., %ReductionOp) + /// + /// %OutsideUser can be `llvm.vector.reduce.fadd` or `fadd` preceding + /// `llvm.vector.reduce.fadd` when unroll factor isn't one. + std::map<Instruction *, std::pair<PHINode *, Instruction *>> ReductionInfo; + + /// In the process of detecting a reduction, we consider a pair of + /// %ReductionOP, which we refer to as real and imag (or vice versa), and + /// traverse the use-tree to detect complex operations. As this is a reduction + /// operation, it will eventually reach RealPHI and ImagPHI, which corresponds + /// to the %ReductionOPs that we suspect to be complex. + /// RealPHI and ImagPHI are used by the identifyPHINode method. + PHINode *RealPHI = nullptr; + PHINode *ImagPHI = nullptr; + + /// OldToNewPHI maps the original real PHINode to a new, double-sized PHINode. + /// The new PHINode corresponds to a vector of deinterleaved complex numbers. + /// This mapping is populated during + /// ComplexDeinterleavingOperation::ReductionPHI node replacement. It is then + /// used in the ComplexDeinterleavingOperation::ReductionOperation node + /// replacement process. + std::map<PHINode *, PHINode *> OldToNewPHI; + NodePtr prepareCompositeNode(ComplexDeinterleavingOperation Operation, Instruction *R, Instruction *I) { + assert(((Operation != ComplexDeinterleavingOperation::ReductionPHI && + Operation != ComplexDeinterleavingOperation::ReductionOperation) || + (R && I)) && + "Reduction related nodes must have Real and Imaginary parts"); return std::make_shared<ComplexDeinterleavingCompositeNode>(Operation, R, I); } @@ -324,8 +369,17 @@ private: /// intrinsic (for both fixed and scalable vectors) NodePtr identifyDeinterleave(Instruction *Real, Instruction *Imag); + NodePtr identifyPHINode(Instruction *Real, Instruction *Imag); + Value *replaceNode(IRBuilderBase &Builder, RawNodePtr Node); + /// Complete IR modifications after producing new reduction operation: + /// * Populate the PHINode generated for + /// ComplexDeinterleavingOperation::ReductionPHI + /// * Deinterleave the final value outside of the loop and repurpose original + /// reduction users + void processReductionOperation(Value *OperationReplacement, RawNodePtr Node); + public: void dump() { dump(dbgs()); } void dump(raw_ostream &OS) { @@ -337,6 +391,13 @@ public: /// current graph. bool identifyNodes(Instruction *RootI); + /// In case \pB is one-block loop, this function seeks potential reductions + /// and populates ReductionInfo. Returns true if any reductions were + /// identified. + bool collectPotentialReductions(BasicBlock *B); + + void identifyReductionNodes(); + /// Check that every instruction, from the roots to the leaves, has internal /// uses. bool checkNodes(); @@ -439,6 +500,9 @@ static bool isDeinterleavingMask(ArrayRef<int> Mask) { bool ComplexDeinterleaving::evaluateBasicBlock(BasicBlock *B) { ComplexDeinterleavingGraph Graph(TL, TLI); + if (Graph.collectPotentialReductions(B)) + Graph.identifyReductionNodes(); + for (auto &I : *B) Graph.identifyNodes(&I); @@ -822,6 +886,9 @@ ComplexDeinterleavingGraph::identifyNode(Instruction *Real, Instruction *Imag) { if (NodePtr CN = identifyDeinterleave(Real, Imag)) return CN; + if (NodePtr CN = identifyPHINode(Real, Imag)) + return CN; + auto *VTy = cast<VectorType>(Real->getType()); auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy); @@ -1293,6 +1360,16 @@ ComplexDeinterleavingGraph::extractPositiveAddend( } bool ComplexDeinterleavingGraph::identifyNodes(Instruction *RootI) { + // This potential root instruction might already have been recognized as + // reduction. Because RootToNode maps both Real and Imaginary parts to + // CompositeNode we should choose only one either Real or Imag instruction to + // use as an anchor for generating complex instruction. + auto It = RootToNode.find(RootI); + if (It != RootToNode.end() && It->second->Real == RootI) { + OrderedRoots.push_back(RootI); + return true; + } + auto RootNode = identifyRoot(RootI); if (!RootNode) return false; @@ -1310,12 +1387,113 @@ bool ComplexDeinterleavingGraph::identifyNodes(Instruction *RootI) { return true; } +bool ComplexDeinterleavingGraph::collectPotentialReductions(BasicBlock *B) { + bool FoundPotentialReduction = false; + + auto *Br = dyn_cast<BranchInst>(B->getTerminator()); + if (!Br || Br->getNumSuccessors() != 2) + return false; + + // Identify simple one-block loop + if (Br->getSuccessor(0) != B && Br->getSuccessor(1) != B) + return false; + + SmallVector<PHINode *> PHIs; + for (auto &PHI : B->phis()) { + if (PHI.getNumIncomingValues() != 2) + continue; + + if (!PHI.getType()->isVectorTy()) + continue; + + auto *ReductionOp = dyn_cast<Instruction>(PHI.getIncomingValueForBlock(B)); + if (!ReductionOp) + continue; + + // Check if final instruction is reduced outside of current block + Instruction *FinalReduction = nullptr; + auto NumUsers = 0u; + for (auto *U : ReductionOp->users()) { + ++NumUsers; + if (U == &PHI) + continue; + FinalReduction = dyn_cast<Instruction>(U); + } + + if (NumUsers != 2 || !FinalReduction || FinalReduction->getParent() == B) + continue; + + ReductionInfo[ReductionOp] = {&PHI, FinalReduction}; + BackEdge = B; + auto BackEdgeIdx = PHI.getBasicBlockIndex(B); + auto IncomingIdx = BackEdgeIdx == 0 ? 1 : 0; + Incoming = PHI.getIncomingBlock(IncomingIdx); + FoundPotentialReduction = true; + + // If the initial value of PHINode is an Instruction, consider it a leaf + // value of a complex deinterleaving graph. + if (auto *InitPHI = + dyn_cast<Instruction>(PHI.getIncomingValueForBlock(Incoming))) + FinalInstructions.insert(InitPHI); + } + return FoundPotentialReduction; +} + +void ComplexDeinterleavingGraph::identifyReductionNodes() { + SmallVector<bool> Processed(ReductionInfo.size(), false); + SmallVector<Instruction *> OperationInstruction; + for (auto &P : ReductionInfo) + OperationInstruction.push_back(P.first); + + // Identify a complex computation by evaluating two reduction operations that + // potentially could be involved + for (size_t i = 0; i < OperationInstruction.size(); ++i) { + if (Processed[i]) + continue; + for (size_t j = i + 1; j < OperationInstruction.size(); ++j) { + if (Processed[j]) + continue; + + auto *Real = OperationInstruction[i]; + auto *Imag = OperationInstruction[j]; + + RealPHI = ReductionInfo[Real].first; + ImagPHI = ReductionInfo[Imag].first; + auto Node = identifyNode(Real, Imag); + if (!Node) { + std::swap(Real, Imag); + std::swap(RealPHI, ImagPHI); + Node = identifyNode(Real, Imag); + } + + // If a node is identified, mark its operation instructions as used to + // prevent re-identification and attach the node to the real part + if (Node) { + LLVM_DEBUG(dbgs() << "Identified reduction starting from instructions: " + << *Real << " / " << *Imag << "\n"); + Processed[i] = true; + Processed[j] = true; + auto RootNode = prepareCompositeNode( + ComplexDeinterleavingOperation::ReductionOperation, Real, Imag); + RootNode->addOperand(Node); + RootToNode[Real] = RootNode; + RootToNode[Imag] = RootNode; + submitCompositeNode(RootNode); + break; + } + } + } + + RealPHI = nullptr; + ImagPHI = nullptr; +} + bool ComplexDeinterleavingGraph::checkNodes() { // Collect all instructions from roots to leaves SmallPtrSet<Instruction *, 16> AllInstructions; SmallVector<Instruction *, 8> Worklist; - for (auto *I : OrderedRoots) - Worklist.push_back(I); + for (auto &Pair : RootToNode) + Worklist.push_back(Pair.first); // Extract all instructions that are used by all XCMLA/XCADD/ADD/SUB/NEG // chains @@ -1524,6 +1702,17 @@ ComplexDeinterleavingGraph::identifyDeinterleave(Instruction *Real, return submitCompositeNode(PlaceholderNode); } +ComplexDeinterleavingGraph::NodePtr +ComplexDeinterleavingGraph::identifyPHINode(Instruction *Real, + Instruction *Imag) { + if (Real != RealPHI || Imag != ImagPHI) + return nullptr; + + NodePtr PlaceholderNode = prepareCompositeNode( + ComplexDeinterleavingOperation::ReductionPHI, Real, Imag); + return submitCompositeNode(PlaceholderNode); +} + static Value *replaceSymmetricNode(IRBuilderBase &B, unsigned Opcode, FastMathFlags Flags, Value *InputA, Value *InputB) { @@ -1553,27 +1742,100 @@ Value *ComplexDeinterleavingGraph::replaceNode(IRBuilderBase &Builder, if (Node->ReplacementNode) return Node->ReplacementNode; - Value *Input0 = replaceNode(Builder, Node->Operands[0]); - Value *Input1 = Node->Operands.size() > 1 - ? replaceNode(Builder, Node->Operands[1]) - : nullptr; - Value *Accumulator = Node->Operands.size() > 2 - ? replaceNode(Builder, Node->Operands[2]) - : nullptr; - if (Input1) - assert(Input0->getType() == Input1->getType() && - "Node inputs need to be of the same type"); - - if (Node->Operation == ComplexDeinterleavingOperation::Symmetric) - Node->ReplacementNode = replaceSymmetricNode(Builder, Node->Opcode, - Node->Flags, Input0, Input1); - else - Node->ReplacementNode = TL->createComplexDeinterleavingIR( - Builder, Node->Operation, Node->Rotation, Input0, Input1, Accumulator); + auto ReplaceOperandIfExist = [&](RawNodePtr &Node, unsigned Idx) -> Value * { + return Node->Operands.size() > Idx + ? replaceNode(Builder, Node->Operands[Idx]) + : nullptr; + }; + + Value *ReplacementNode; + switch (Node->Operation) { + case ComplexDeinterleavingOperation::CAdd: + case ComplexDeinterleavingOperation::CMulPartial: + case ComplexDeinterleavingOperation::Symmetric: { + Value *Input0 = ReplaceOperandIfExist(Node, 0); + Value *Input1 = ReplaceOperandIfExist(Node, 1); + Value *Accumulator = ReplaceOperandIfExist(Node, 2); + assert(!Input1 || (Input0->getType() == Input1->getType() && + "Node inputs need to be of the same type")); + assert(!Accumulator || + (Input0->getType() == Accumulator->getType() && + "Accumulator and input need to be of the same type")); + if (Node->Operation == ComplexDeinterleavingOperation::Symmetric) + ReplacementNode = replaceSymmetricNode(Builder, Node->Opcode, Node->Flags, + Input0, Input1); + else + ReplacementNode = TL->createComplexDeinterleavingIR( + Builder, Node->Operation, Node->Rotation, Input0, Input1, + Accumulator); + break; + } + case ComplexDeinterleavingOperation::Deinterleave: + llvm_unreachable("Deinterleave node should already have ReplacementNode"); + break; + case ComplexDeinterleavingOperation::ReductionPHI: { + // If Operation is ReductionPHI, a new empty PHINode is created. + // It is filled later when the ReductionOperation is processed. + auto *VTy = cast<VectorType>(Node->Real->getType()); + auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy); + auto *NewPHI = PHINode::Create(NewVTy, 0, "", BackEdge->getFirstNonPHI()); + OldToNewPHI[dyn_cast<PHINode>(Node->Real)] = NewPHI; + ReplacementNode = NewPHI; + break; + } + case ComplexDeinterleavingOperation::ReductionOperation: + ReplacementNode = replaceNode(Builder, Node->Operands[0]); + processReductionOperation(ReplacementNode, Node); + break; + default: + llvm_unreachable( + "Unhandled case in ComplexDeinterleavingGraph::replaceNode"); + break; + } - assert(Node->ReplacementNode && "Target failed to create Intrinsic call."); + assert(ReplacementNode && "Target failed to create Intrinsic call."); NumComplexTransformations += 1; - return Node->ReplacementNode; + Node->ReplacementNode = ReplacementNode; + return ReplacementNode; +} + +void ComplexDeinterleavingGraph::processReductionOperation( + Value *OperationReplacement, RawNodePtr Node) { + auto *OldPHIReal = ReductionInfo[Node->Real].first; + auto *OldPHIImag = ReductionInfo[Node->Imag].first; + auto *NewPHI = OldToNewPHI[OldPHIReal]; + + auto *VTy = cast<VectorType>(Node->Real->getType()); + auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy); + + // We have to interleave initial origin values coming from IncomingBlock + Value *InitReal = OldPHIReal->getIncomingValueForBlock(Incoming); + Value *InitImag = OldPHIImag->getIncomingValueForBlock(Incoming); + + IRBuilder<> Builder(Incoming->getTerminator()); + auto *NewInit = Builder.CreateIntrinsic( + Intrinsic::experimental_vector_interleave2, NewVTy, {InitReal, InitImag}); + + NewPHI->addIncoming(NewInit, Incoming); + NewPHI->addIncoming(OperationReplacement, BackEdge); + + // Deinterleave complex vector outside of loop so that it can be finally + // reduced + auto *FinalReductionReal = ReductionInfo[Node->Real].second; + auto *FinalReductionImag = ReductionInfo[Node->Imag].second; + + Builder.SetInsertPoint( + &*FinalReductionReal->getParent()->getFirstInsertionPt()); + auto *Deinterleave = Builder.CreateIntrinsic( + Intrinsic::experimental_vector_deinterleave2, + OperationReplacement->getType(), OperationReplacement); + + auto *NewReal = Builder.CreateExtractValue(Deinterleave, (uint64_t)0); + FinalReductionReal->replaceUsesOfWith(Node->Real, NewReal); + + Builder.SetInsertPoint(FinalReductionImag); + auto *NewImag = Builder.CreateExtractValue(Deinterleave, 1); + FinalReductionImag->replaceUsesOfWith(Node->Imag, NewImag); } void ComplexDeinterleavingGraph::replaceNodes() { @@ -1587,9 +1849,18 @@ void ComplexDeinterleavingGraph::replaceNodes() { IRBuilder<> Builder(RootInstruction); auto RootNode = RootToNode[RootInstruction]; Value *R = replaceNode(Builder, RootNode.get()); - assert(R && "Unable to find replacement for RootInstruction"); - DeadInstrRoots.push_back(RootInstruction); - RootInstruction->replaceAllUsesWith(R); + + if (RootNode->Operation == + ComplexDeinterleavingOperation::ReductionOperation) { + ReductionInfo[RootNode->Real].first->removeIncomingValue(BackEdge); + ReductionInfo[RootNode->Imag].first->removeIncomingValue(BackEdge); + DeadInstrRoots.push_back(RootNode->Real); + DeadInstrRoots.push_back(RootNode->Imag); + } else { + assert(R && "Unable to find replacement for RootInstruction"); + DeadInstrRoots.push_back(RootInstruction); + RootInstruction->replaceAllUsesWith(R); + } } for (auto *I : DeadInstrRoots) |