aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
diff options
context:
space:
mode:
authorIgor Kirillov <igor.kirillov@arm.com>2023-03-27 16:32:40 +0000
committerIgor Kirillov <igor.kirillov@arm.com>2023-04-18 13:05:49 +0000
commitc692e87ab8e7d3c7d8e2365572ffb41f6ec9ac1d (patch)
tree364a2c30e11ca8ce1dc0ae7ec3badc0ca70b6e0a /llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
parentdc86900ff31e35e61ec9c5adca0488bf33d11833 (diff)
downloadllvm-c692e87ab8e7d3c7d8e2365572ffb41f6ec9ac1d.zip
llvm-c692e87ab8e7d3c7d8e2365572ffb41f6ec9ac1d.tar.gz
llvm-c692e87ab8e7d3c7d8e2365572ffb41f6ec9ac1d.tar.bz2
[CodeGen] Enable processing of interconnected complex number operations
With this patch, ComplexDeinterleavingPass now has the ability to handle any number of interconnected operations involving complex numbers. For example, the patch enables the processing of code like the following: for (int i = 0; i < 1000; ++i) { a[i] = w[i] * v[i]; b[i] = w[i] * u[i]; } This code has multiple arrays containing complex numbers and a common subexpression `w` that appears in two expressions. Differential Revision: https://reviews.llvm.org/D146988
Diffstat (limited to 'llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp')
-rw-r--r--llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp189
1 files changed, 118 insertions, 71 deletions
diff --git a/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp b/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
index ff0c5d5..3cfe935 100644
--- a/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
+++ b/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
@@ -137,19 +137,12 @@ public:
Instruction *Real;
Instruction *Imag;
- // Instructions that should only exist within this node, there should be no
- // users of these instructions outside the node. An example of these would be
- // the multiply instructions of a partial multiply operation.
- SmallVector<Instruction *> InternalInstructions;
ComplexDeinterleavingRotation Rotation;
SmallVector<RawNodePtr> Operands;
Value *ReplacementNode = nullptr;
- void addInstruction(Instruction *I) { InternalInstructions.push_back(I); }
void addOperand(NodePtr Node) { Operands.push_back(Node.get()); }
- bool hasAllInternalUses(SmallPtrSet<Instruction *, 16> &AllInstructions);
-
void dump() { dump(dbgs()); }
void dump(raw_ostream &OS) {
auto PrintValue = [&](Value *V) {
@@ -181,12 +174,6 @@ public:
OS << " - ";
PrintNodeRef(Op);
}
- OS << " InternalInstructions:\n";
- for (const auto &I : InternalInstructions) {
- OS << " - \"";
- I->print(OS, true);
- OS << "\"\n";
- }
}
};
@@ -194,14 +181,22 @@ class ComplexDeinterleavingGraph {
public:
using NodePtr = ComplexDeinterleavingCompositeNode::NodePtr;
using RawNodePtr = ComplexDeinterleavingCompositeNode::RawNodePtr;
- explicit ComplexDeinterleavingGraph(const TargetLowering *tl) : TL(tl) {}
+ explicit ComplexDeinterleavingGraph(const TargetLowering *TL,
+ const TargetLibraryInfo *TLI)
+ : TL(TL), TLI(TLI) {}
private:
const TargetLowering *TL = nullptr;
- Instruction *RootValue = nullptr;
- NodePtr RootNode;
+ const TargetLibraryInfo *TLI = nullptr;
SmallVector<NodePtr> CompositeNodes;
- SmallPtrSet<Instruction *, 16> AllInstructions;
+
+ SmallPtrSet<Instruction *, 16> FinalInstructions;
+
+ /// Root instructions are instructions from which complex computation starts
+ std::map<Instruction *, NodePtr> RootToNode;
+
+ /// Topologically sorted root instructions
+ SmallVector<Instruction *, 1> OrderedRoots;
NodePtr prepareCompositeNode(ComplexDeinterleavingOperation Operation,
Instruction *R, Instruction *I) {
@@ -211,10 +206,6 @@ private:
NodePtr submitCompositeNode(NodePtr Node) {
CompositeNodes.push_back(Node);
- AllInstructions.insert(Node->Real);
- AllInstructions.insert(Node->Imag);
- for (auto *I : Node->InternalInstructions)
- AllInstructions.insert(I);
return Node;
}
@@ -271,6 +262,10 @@ public:
/// current graph.
bool identifyNodes(Instruction *RootI);
+ /// Check that every instruction, from the roots to the leaves, has internal
+ /// uses.
+ bool checkNodes();
+
/// Perform the actual replacement of the underlying instruction graph.
void replaceNodes();
};
@@ -368,9 +363,7 @@ static bool isDeinterleavingMask(ArrayRef<int> Mask) {
}
bool ComplexDeinterleaving::evaluateBasicBlock(BasicBlock *B) {
- bool Changed = false;
-
- SmallVector<Instruction *> DeadInstrRoots;
+ ComplexDeinterleavingGraph Graph(TL, TLI);
for (auto &I : *B) {
auto *SVI = dyn_cast<ShuffleVectorInst>(&I);
@@ -382,22 +375,15 @@ bool ComplexDeinterleaving::evaluateBasicBlock(BasicBlock *B) {
if (!isInterleavingMask(SVI->getShuffleMask()))
continue;
- ComplexDeinterleavingGraph Graph(TL);
- if (!Graph.identifyNodes(SVI))
- continue;
-
- Graph.replaceNodes();
- DeadInstrRoots.push_back(SVI);
- Changed = true;
+ Graph.identifyNodes(SVI);
}
- for (const auto &I : DeadInstrRoots) {
- if (!I || I->getParent() == nullptr)
- continue;
- llvm::RecursivelyDeleteTriviallyDeadInstructions(I, TLI);
+ if (Graph.checkNodes()) {
+ Graph.replaceNodes();
+ return true;
}
- return Changed;
+ return false;
}
ComplexDeinterleavingGraph::NodePtr
@@ -511,7 +497,6 @@ ComplexDeinterleavingGraph::identifyNodeWithImplicitAdd(
Node->Rotation = Rotation;
Node->addOperand(CommonNode);
Node->addOperand(UncommonNode);
- Node->InternalInstructions.append(FNegs);
return submitCompositeNode(Node);
}
@@ -627,8 +612,6 @@ ComplexDeinterleavingGraph::identifyPartialMul(Instruction *Real,
NodePtr Node = prepareCompositeNode(
ComplexDeinterleavingOperation::CMulPartial, Real, Imag);
- Node->addInstruction(RealMulI);
- Node->addInstruction(ImagMulI);
Node->Rotation = Rotation;
Node->addOperand(CommonRes);
Node->addOperand(UncommonRes);
@@ -846,6 +829,8 @@ ComplexDeinterleavingGraph::identifyNode(Instruction *Real, Instruction *Imag) {
prepareCompositeNode(llvm::ComplexDeinterleavingOperation::Shuffle,
RealShuffle, ImagShuffle);
PlaceholderNode->ReplacementNode = RealShuffle->getOperand(0);
+ FinalInstructions.insert(RealShuffle);
+ FinalInstructions.insert(ImagShuffle);
return submitCompositeNode(PlaceholderNode);
}
if (RealShuffle || ImagShuffle) {
@@ -881,9 +866,7 @@ bool ComplexDeinterleavingGraph::identifyNodes(Instruction *RootI) {
if (!match(RootI, m_Shuffle(m_Instruction(Real), m_Instruction(Imag))))
return false;
- RootValue = RootI;
- AllInstructions.insert(RootI);
- RootNode = identifyNode(Real, Imag);
+ auto RootNode = identifyNode(Real, Imag);
LLVM_DEBUG({
Function *F = RootI->getFunction();
@@ -894,14 +877,86 @@ bool ComplexDeinterleavingGraph::identifyNodes(Instruction *RootI) {
dbgs() << "\n";
});
- // Check all instructions have internal uses
- for (const auto &Node : CompositeNodes) {
- if (!Node->hasAllInternalUses(AllInstructions)) {
- LLVM_DEBUG(dbgs() << " - Invalid internal uses\n");
- return false;
+ if (RootNode) {
+ RootToNode[RootI] = RootNode;
+ OrderedRoots.push_back(RootI);
+ return true;
+ }
+
+ return false;
+}
+
+bool ComplexDeinterleavingGraph::checkNodes() {
+ // Collect all instructions from roots to leaves
+ SmallPtrSet<Instruction *, 16> AllInstructions;
+ SmallVector<Instruction *, 8> Worklist;
+ for (auto *I : OrderedRoots)
+ Worklist.push_back(I);
+
+ // Extract all instructions that are used by all XCMLA/XCADD/ADD/SUB/NEG
+ // chains
+ while (!Worklist.empty()) {
+ auto *I = Worklist.back();
+ Worklist.pop_back();
+
+ if (!AllInstructions.insert(I).second)
+ continue;
+
+ for (Value *Op : I->operands()) {
+ if (auto *OpI = dyn_cast<Instruction>(Op)) {
+ if (!FinalInstructions.count(I))
+ Worklist.emplace_back(OpI);
+ }
}
}
- return RootNode != nullptr;
+
+ // Find instructions that have users outside of chain
+ SmallVector<Instruction *, 2> OuterInstructions;
+ for (auto *I : AllInstructions) {
+ // Skip root nodes
+ if (RootToNode.count(I))
+ continue;
+
+ for (User *U : I->users()) {
+ if (AllInstructions.count(cast<Instruction>(U)))
+ continue;
+
+ // Found an instruction that is not used by XCMLA/XCADD chain
+ Worklist.emplace_back(I);
+ break;
+ }
+ }
+
+ // If any instructions are found to be used outside, find and remove roots
+ // that somehow connect to those instructions.
+ SmallPtrSet<Instruction *, 16> Visited;
+ while (!Worklist.empty()) {
+ auto *I = Worklist.back();
+ Worklist.pop_back();
+ if (!Visited.insert(I).second)
+ continue;
+
+ // Found an impacted root node. Removing it from the nodes to be
+ // deinterleaved
+ if (RootToNode.count(I)) {
+ LLVM_DEBUG(dbgs() << "Instruction " << *I
+ << " could be deinterleaved but its chain of complex "
+ "operations have an outside user\n");
+ RootToNode.erase(I);
+ }
+
+ if (!AllInstructions.count(I) || FinalInstructions.count(I))
+ continue;
+
+ for (User *U : I->users())
+ Worklist.emplace_back(cast<Instruction>(U));
+
+ for (Value *Op : I->operands()) {
+ if (auto *OpI = dyn_cast<Instruction>(Op))
+ Worklist.emplace_back(OpI);
+ }
+ }
+ return !RootToNode.empty();
}
static Value *replaceSymmetricNode(ComplexDeinterleavingGraph::RawNodePtr Node,
@@ -958,29 +1013,21 @@ Value *ComplexDeinterleavingGraph::replaceNode(
}
void ComplexDeinterleavingGraph::replaceNodes() {
- Value *R = replaceNode(RootNode.get());
- assert(R && "Unable to find replacement for RootValue");
- RootValue->replaceAllUsesWith(R);
-}
-
-bool ComplexDeinterleavingCompositeNode::hasAllInternalUses(
- SmallPtrSet<Instruction *, 16> &AllInstructions) {
- if (Operation == ComplexDeinterleavingOperation::Shuffle)
- return true;
+ SmallVector<Instruction *, 16> DeadInstrRoots;
+ for (auto *RootInstruction : OrderedRoots) {
+ // Check if this potential root went through check process and we can
+ // deinterleave it
+ if (!RootToNode.count(RootInstruction))
+ continue;
- for (auto *User : Real->users()) {
- if (!AllInstructions.contains(cast<Instruction>(User)))
- return false;
+ IRBuilder<> Builder(RootInstruction);
+ auto RootNode = RootToNode[RootInstruction];
+ Value *R = replaceNode(RootNode.get());
+ assert(R && "Unable to find replacement for RootInstruction");
+ DeadInstrRoots.push_back(RootInstruction);
+ RootInstruction->replaceAllUsesWith(R);
}
- for (auto *User : Imag->users()) {
- if (!AllInstructions.contains(cast<Instruction>(User)))
- return false;
- }
- for (auto *I : InternalInstructions) {
- for (auto *User : I->users()) {
- if (!AllInstructions.contains(cast<Instruction>(User)))
- return false;
- }
- }
- return true;
+
+ for (auto *I : DeadInstrRoots)
+ RecursivelyDeleteTriviallyDeadInstructions(I, TLI);
}