aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
diff options
context:
space:
mode:
authorIgor Kirillov <igor.kirillov@arm.com>2023-06-02 19:14:07 +0000
committerIgor Kirillov <igor.kirillov@arm.com>2023-06-14 17:27:26 +0000
commit2cbc265cc947c40372b841f80649276fbf9d183f (patch)
tree1a6ba6fad125201442975498834657623269b130 /llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
parente0f7b0e0f704dc3759925602e474b9e669270fcb (diff)
downloadllvm-2cbc265cc947c40372b841f80649276fbf9d183f.zip
llvm-2cbc265cc947c40372b841f80649276fbf9d183f.tar.gz
llvm-2cbc265cc947c40372b841f80649276fbf9d183f.tar.bz2
[CodeGen] Add support for reductions in ComplexDeinterleaving pass
This commit enhances the ComplexDeinterleaving pass to handle unordered reductions in simple one-block vectorized loops, supporting both SVE and Neon architectures. Differential Revision: https://reviews.llvm.org/D152022
Diffstat (limited to 'llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp')
-rw-r--r--llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp319
1 files changed, 295 insertions, 24 deletions
diff --git a/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp b/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
index ec7abb2..1c61179 100644
--- a/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
+++ b/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
@@ -228,8 +228,53 @@ private:
/// Topologically sorted root instructions
SmallVector<Instruction *, 1> OrderedRoots;
+ /// When examining a basic block for complex deinterleaving, if it is a simple
+ /// one-block loop, then the only incoming block is 'Incoming' and the
+ /// 'BackEdge' block is the block itself.
+ BasicBlock *BackEdge = nullptr;
+ BasicBlock *Incoming = nullptr;
+
+ /// ReductionInfo maps from %ReductionOp to %PHInode and Instruction
+ /// %OutsideUser as it is shown in the IR:
+ ///
+ /// vector.body:
+ /// %PHInode = phi <vector type> [ zeroinitializer, %entry ],
+ /// [ %ReductionOp, %vector.body ]
+ /// ...
+ /// %ReductionOp = fadd <vector type> ...
+ /// ...
+ /// br i1 %condition, label %vector.body, label %middle.block
+ ///
+ /// middle.block:
+ /// %OutsideUser = llvm.vector.reduce.fadd(..., %ReductionOp)
+ ///
+ /// %OutsideUser can be `llvm.vector.reduce.fadd` or `fadd` preceding
+ /// `llvm.vector.reduce.fadd` when unroll factor isn't one.
+ std::map<Instruction *, std::pair<PHINode *, Instruction *>> ReductionInfo;
+
+ /// In the process of detecting a reduction, we consider a pair of
+ /// %ReductionOps, which we refer to as real and imag (or vice versa), and
+ /// traverse the use-tree to detect complex operations. As this is a reduction
+ /// operation, it will eventually reach RealPHI and ImagPHI, which correspond
+ /// to the %ReductionOps that we suspect to be complex.
+ /// RealPHI and ImagPHI are used by the identifyPHINode method.
+ PHINode *RealPHI = nullptr;
+ PHINode *ImagPHI = nullptr;
+
+ /// OldToNewPHI maps the original real PHINode to a new, double-sized PHINode.
+ /// The new PHINode corresponds to a vector of deinterleaved complex numbers.
+ /// This mapping is populated during
+ /// ComplexDeinterleavingOperation::ReductionPHI node replacement. It is then
+ /// used in the ComplexDeinterleavingOperation::ReductionOperation node
+ /// replacement process.
+ std::map<PHINode *, PHINode *> OldToNewPHI;
+
NodePtr prepareCompositeNode(ComplexDeinterleavingOperation Operation,
Instruction *R, Instruction *I) {
+ assert(((Operation != ComplexDeinterleavingOperation::ReductionPHI &&
+ Operation != ComplexDeinterleavingOperation::ReductionOperation) ||
+ (R && I)) && // R/I are not constrained here for other operations
+ "Reduction related nodes must have Real and Imaginary parts");
return std::make_shared<ComplexDeinterleavingCompositeNode>(Operation, R,
I);
}
@@ -324,8 +369,17 @@ private:
/// intrinsic (for both fixed and scalable vectors)
NodePtr identifyDeinterleave(Instruction *Real, Instruction *Imag);
+ NodePtr identifyPHINode(Instruction *Real, Instruction *Imag);
+
Value *replaceNode(IRBuilderBase &Builder, RawNodePtr Node);
+ /// Complete IR modifications after producing new reduction operation:
+ /// * Populate the PHINode generated for
+ /// ComplexDeinterleavingOperation::ReductionPHI
+ /// * Deinterleave the final value outside of the loop and repurpose original
+ /// reduction users
+ void processReductionOperation(Value *OperationReplacement, RawNodePtr Node);
+
public:
void dump() { dump(dbgs()); }
void dump(raw_ostream &OS) {
@@ -337,6 +391,13 @@ public:
/// current graph.
bool identifyNodes(Instruction *RootI);
+ /// In case \p B is a one-block loop, this function seeks potential reductions
+ /// and populates ReductionInfo. Returns true if any reductions were
+ /// identified.
+ bool collectPotentialReductions(BasicBlock *B);
+
+ void identifyReductionNodes();
+
/// Check that every instruction, from the roots to the leaves, has internal
/// uses.
bool checkNodes();
@@ -439,6 +500,9 @@ static bool isDeinterleavingMask(ArrayRef<int> Mask) {
bool ComplexDeinterleaving::evaluateBasicBlock(BasicBlock *B) {
ComplexDeinterleavingGraph Graph(TL, TLI);
+ if (Graph.collectPotentialReductions(B))
+ Graph.identifyReductionNodes();
+
for (auto &I : *B)
Graph.identifyNodes(&I);
@@ -822,6 +886,9 @@ ComplexDeinterleavingGraph::identifyNode(Instruction *Real, Instruction *Imag) {
if (NodePtr CN = identifyDeinterleave(Real, Imag))
return CN;
+ if (NodePtr CN = identifyPHINode(Real, Imag))
+ return CN;
+
auto *VTy = cast<VectorType>(Real->getType());
auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
@@ -1293,6 +1360,16 @@ ComplexDeinterleavingGraph::extractPositiveAddend(
}
bool ComplexDeinterleavingGraph::identifyNodes(Instruction *RootI) {
+ // This potential root instruction might already have been recognized as a
+ // reduction. Because RootToNode maps both the Real and Imaginary parts to
+ // the same CompositeNode, only one of them (the Real instruction) is used as
+ // the anchor for generating the complex instruction.
+ auto It = RootToNode.find(RootI);
+ if (It != RootToNode.end() && It->second->Real == RootI) {
+ OrderedRoots.push_back(RootI);
+ return true;
+ }
+
auto RootNode = identifyRoot(RootI);
if (!RootNode)
return false;
@@ -1310,12 +1387,113 @@ bool ComplexDeinterleavingGraph::identifyNodes(Instruction *RootI) {
return true;
}
+bool ComplexDeinterleavingGraph::collectPotentialReductions(BasicBlock *B) {
+ bool FoundPotentialReduction = false;
+
+ auto *Br = dyn_cast<BranchInst>(B->getTerminator());
+ if (!Br || Br->getNumSuccessors() != 2)
+ return false;
+
+ // Only a simple one-block loop qualifies: B must branch back to itself
+ if (Br->getSuccessor(0) != B && Br->getSuccessor(1) != B)
+ return false;
+
+ SmallVector<PHINode *> PHIs; // NOTE(review): never used in this function
+ for (auto &PHI : B->phis()) {
+ if (PHI.getNumIncomingValues() != 2)
+ continue;
+
+ if (!PHI.getType()->isVectorTy())
+ continue;
+
+ auto *ReductionOp = dyn_cast<Instruction>(PHI.getIncomingValueForBlock(B));
+ if (!ReductionOp)
+ continue;
+
+ // Find the user of ReductionOp other than this PHI; it must be outside B
+ Instruction *FinalReduction = nullptr;
+ auto NumUsers = 0u;
+ for (auto *U : ReductionOp->users()) {
+ ++NumUsers;
+ if (U == &PHI)
+ continue;
+ FinalReduction = dyn_cast<Instruction>(U);
+ }
+
+ if (NumUsers != 2 || !FinalReduction || FinalReduction->getParent() == B)
+ continue;
+
+ ReductionInfo[ReductionOp] = {&PHI, FinalReduction};
+ BackEdge = B;
+ auto BackEdgeIdx = PHI.getBasicBlockIndex(B);
+ auto IncomingIdx = BackEdgeIdx == 0 ? 1 : 0; // PHI has exactly two inputs
+ Incoming = PHI.getIncomingBlock(IncomingIdx);
+ FoundPotentialReduction = true;
+
+ // If the initial value of the PHINode is an Instruction, consider it a
+ // leaf value of the complex deinterleaving graph.
+ if (auto *InitPHI =
+ dyn_cast<Instruction>(PHI.getIncomingValueForBlock(Incoming)))
+ FinalInstructions.insert(InitPHI);
+ }
+ return FoundPotentialReduction;
+}
+
+void ComplexDeinterleavingGraph::identifyReductionNodes() {
+ SmallVector<bool> Processed(ReductionInfo.size(), false);
+ SmallVector<Instruction *> OperationInstruction;
+ for (auto &P : ReductionInfo)
+ OperationInstruction.push_back(P.first);
+
+ // Try every unordered pair of not-yet-processed reduction operations and
+ // check whether the two together form a complex computation
+ for (size_t i = 0; i < OperationInstruction.size(); ++i) {
+ if (Processed[i])
+ continue;
+ for (size_t j = i + 1; j < OperationInstruction.size(); ++j) {
+ if (Processed[j])
+ continue;
+
+ auto *Real = OperationInstruction[i];
+ auto *Imag = OperationInstruction[j];
+
+ RealPHI = ReductionInfo[Real].first;
+ ImagPHI = ReductionInfo[Imag].first;
+ auto Node = identifyNode(Real, Imag);
+ if (!Node) { // retry with the real/imag roles exchanged
+ std::swap(Real, Imag);
+ std::swap(RealPHI, ImagPHI);
+ Node = identifyNode(Real, Imag);
+ }
+
+ // If a node is identified, mark both operation instructions as processed
+ // to prevent re-identification and register the new root under both
+ if (Node) {
+ LLVM_DEBUG(dbgs() << "Identified reduction starting from instructions: "
+ << *Real << " / " << *Imag << "\n");
+ Processed[i] = true;
+ Processed[j] = true;
+ auto RootNode = prepareCompositeNode(
+ ComplexDeinterleavingOperation::ReductionOperation, Real, Imag);
+ RootNode->addOperand(Node);
+ RootToNode[Real] = RootNode;
+ RootToNode[Imag] = RootNode;
+ submitCompositeNode(RootNode);
+ break;
+ }
+ }
+ }
+
+ RealPHI = nullptr;
+ ImagPHI = nullptr;
+}
+
bool ComplexDeinterleavingGraph::checkNodes() {
// Collect all instructions from roots to leaves
SmallPtrSet<Instruction *, 16> AllInstructions;
SmallVector<Instruction *, 8> Worklist;
- for (auto *I : OrderedRoots)
- Worklist.push_back(I);
+ for (auto &Pair : RootToNode)
+ Worklist.push_back(Pair.first);
// Extract all instructions that are used by all XCMLA/XCADD/ADD/SUB/NEG
// chains
@@ -1524,6 +1702,17 @@ ComplexDeinterleavingGraph::identifyDeinterleave(Instruction *Real,
return submitCompositeNode(PlaceholderNode);
}
+ComplexDeinterleavingGraph::NodePtr
+ComplexDeinterleavingGraph::identifyPHINode(Instruction *Real,
+ Instruction *Imag) {
+ if (Real != RealPHI || Imag != ImagPHI) // must be exactly the tracked pair
+ return nullptr;
+
+ NodePtr PlaceholderNode = prepareCompositeNode(
+ ComplexDeinterleavingOperation::ReductionPHI, Real, Imag);
+ return submitCompositeNode(PlaceholderNode);
+}
+
static Value *replaceSymmetricNode(IRBuilderBase &B, unsigned Opcode,
FastMathFlags Flags, Value *InputA,
Value *InputB) {
@@ -1553,27 +1742,100 @@ Value *ComplexDeinterleavingGraph::replaceNode(IRBuilderBase &Builder,
if (Node->ReplacementNode)
return Node->ReplacementNode;
- Value *Input0 = replaceNode(Builder, Node->Operands[0]);
- Value *Input1 = Node->Operands.size() > 1
- ? replaceNode(Builder, Node->Operands[1])
- : nullptr;
- Value *Accumulator = Node->Operands.size() > 2
- ? replaceNode(Builder, Node->Operands[2])
- : nullptr;
- if (Input1)
- assert(Input0->getType() == Input1->getType() &&
- "Node inputs need to be of the same type");
-
- if (Node->Operation == ComplexDeinterleavingOperation::Symmetric)
- Node->ReplacementNode = replaceSymmetricNode(Builder, Node->Opcode,
- Node->Flags, Input0, Input1);
- else
- Node->ReplacementNode = TL->createComplexDeinterleavingIR(
- Builder, Node->Operation, Node->Rotation, Input0, Input1, Accumulator);
+ auto ReplaceOperandIfExist = [&](RawNodePtr &Node, unsigned Idx) -> Value * {
+ return Node->Operands.size() > Idx
+ ? replaceNode(Builder, Node->Operands[Idx])
+ : nullptr;
+ };
+
+ Value *ReplacementNode;
+ switch (Node->Operation) {
+ case ComplexDeinterleavingOperation::CAdd:
+ case ComplexDeinterleavingOperation::CMulPartial:
+ case ComplexDeinterleavingOperation::Symmetric: {
+ Value *Input0 = ReplaceOperandIfExist(Node, 0);
+ Value *Input1 = ReplaceOperandIfExist(Node, 1);
+ Value *Accumulator = ReplaceOperandIfExist(Node, 2);
+ assert(!Input1 || (Input0->getType() == Input1->getType() &&
+ "Node inputs need to be of the same type"));
+ assert(!Accumulator ||
+ (Input0->getType() == Accumulator->getType() &&
+ "Accumulator and input need to be of the same type"));
+ if (Node->Operation == ComplexDeinterleavingOperation::Symmetric)
+ ReplacementNode = replaceSymmetricNode(Builder, Node->Opcode, Node->Flags,
+ Input0, Input1);
+ else
+ ReplacementNode = TL->createComplexDeinterleavingIR(
+ Builder, Node->Operation, Node->Rotation, Input0, Input1,
+ Accumulator);
+ break;
+ }
+ case ComplexDeinterleavingOperation::Deinterleave:
+ llvm_unreachable("Deinterleave node should already have ReplacementNode");
+ break;
+ case ComplexDeinterleavingOperation::ReductionPHI: {
+ // If Operation is ReductionPHI, a new empty PHINode is created.
+ // It is filled later when the ReductionOperation is processed.
+ auto *VTy = cast<VectorType>(Node->Real->getType());
+ auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
+ auto *NewPHI = PHINode::Create(NewVTy, 0, "", BackEdge->getFirstNonPHI());
+ OldToNewPHI[dyn_cast<PHINode>(Node->Real)] = NewPHI;
+ ReplacementNode = NewPHI;
+ break;
+ }
+ case ComplexDeinterleavingOperation::ReductionOperation:
+ ReplacementNode = replaceNode(Builder, Node->Operands[0]);
+ processReductionOperation(ReplacementNode, Node);
+ break;
+ default:
+ llvm_unreachable(
+ "Unhandled case in ComplexDeinterleavingGraph::replaceNode");
+ break;
+ }
- assert(Node->ReplacementNode && "Target failed to create Intrinsic call.");
+ assert(ReplacementNode && "Target failed to create Intrinsic call.");
NumComplexTransformations += 1;
- return Node->ReplacementNode;
+ Node->ReplacementNode = ReplacementNode;
+ return ReplacementNode;
+}
+
+void ComplexDeinterleavingGraph::processReductionOperation(
+ Value *OperationReplacement, RawNodePtr Node) {
+ auto *OldPHIReal = ReductionInfo[Node->Real].first;
+ auto *OldPHIImag = ReductionInfo[Node->Imag].first;
+ auto *NewPHI = OldToNewPHI[OldPHIReal]; // keyed by the old real PHINode
+
+ auto *VTy = cast<VectorType>(Node->Real->getType());
+ auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
+
+ // Interleave the initial values that the old PHIs receive from Incoming
+ Value *InitReal = OldPHIReal->getIncomingValueForBlock(Incoming);
+ Value *InitImag = OldPHIImag->getIncomingValueForBlock(Incoming);
+
+ IRBuilder<> Builder(Incoming->getTerminator());
+ auto *NewInit = Builder.CreateIntrinsic(
+ Intrinsic::experimental_vector_interleave2, NewVTy, {InitReal, InitImag});
+
+ NewPHI->addIncoming(NewInit, Incoming);
+ NewPHI->addIncoming(OperationReplacement, BackEdge);
+
+ // Deinterleave the complex vector outside of the loop so that the original
+ // final-reduction users can consume the real and imaginary halves
+ auto *FinalReductionReal = ReductionInfo[Node->Real].second;
+ auto *FinalReductionImag = ReductionInfo[Node->Imag].second;
+
+ Builder.SetInsertPoint(
+ &*FinalReductionReal->getParent()->getFirstInsertionPt());
+ auto *Deinterleave = Builder.CreateIntrinsic(
+ Intrinsic::experimental_vector_deinterleave2,
+ OperationReplacement->getType(), OperationReplacement);
+
+ auto *NewReal = Builder.CreateExtractValue(Deinterleave, (uint64_t)0);
+ FinalReductionReal->replaceUsesOfWith(Node->Real, NewReal);
+
+ Builder.SetInsertPoint(FinalReductionImag); // extract right before its user
+ auto *NewImag = Builder.CreateExtractValue(Deinterleave, 1);
+ FinalReductionImag->replaceUsesOfWith(Node->Imag, NewImag);
}
void ComplexDeinterleavingGraph::replaceNodes() {
@@ -1587,9 +1849,18 @@ void ComplexDeinterleavingGraph::replaceNodes() {
IRBuilder<> Builder(RootInstruction);
auto RootNode = RootToNode[RootInstruction];
Value *R = replaceNode(Builder, RootNode.get());
- assert(R && "Unable to find replacement for RootInstruction");
- DeadInstrRoots.push_back(RootInstruction);
- RootInstruction->replaceAllUsesWith(R);
+
+ if (RootNode->Operation ==
+ ComplexDeinterleavingOperation::ReductionOperation) {
+ ReductionInfo[RootNode->Real].first->removeIncomingValue(BackEdge);
+ ReductionInfo[RootNode->Imag].first->removeIncomingValue(BackEdge);
+ DeadInstrRoots.push_back(RootNode->Real);
+ DeadInstrRoots.push_back(RootNode->Imag);
+ } else {
+ assert(R && "Unable to find replacement for RootInstruction");
+ DeadInstrRoots.push_back(RootInstruction);
+ RootInstruction->replaceAllUsesWith(R);
+ }
}
for (auto *I : DeadInstrRoots)