about summary refs log tree commit diff
path: root/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp')
-rw-r--r-- llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp 189
1 files changed, 118 insertions, 71 deletions
diff --git a/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp b/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
index ff0c5d5..3cfe935 100644
--- a/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
+++ b/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
@@ -137,19 +137,12 @@ public:
Instruction *Real;
Instruction *Imag;
- // Instructions that should only exist within this node, there should be no
- // users of these instructions outside the node. An example of these would be
- // the multiply instructions of a partial multiply operation.
- SmallVector<Instruction *> InternalInstructions;
ComplexDeinterleavingRotation Rotation;
SmallVector<RawNodePtr> Operands;
Value *ReplacementNode = nullptr;
- void addInstruction(Instruction *I) { InternalInstructions.push_back(I); }
void addOperand(NodePtr Node) { Operands.push_back(Node.get()); }
- bool hasAllInternalUses(SmallPtrSet<Instruction *, 16> &AllInstructions);
-
void dump() { dump(dbgs()); }
void dump(raw_ostream &OS) {
auto PrintValue = [&](Value *V) {
@@ -181,12 +174,6 @@ public:
OS << " - ";
PrintNodeRef(Op);
}
- OS << " InternalInstructions:\n";
- for (const auto &I : InternalInstructions) {
- OS << " - \"";
- I->print(OS, true);
- OS << "\"\n";
- }
}
};
@@ -194,14 +181,22 @@ class ComplexDeinterleavingGraph {
public:
using NodePtr = ComplexDeinterleavingCompositeNode::NodePtr;
using RawNodePtr = ComplexDeinterleavingCompositeNode::RawNodePtr;
- explicit ComplexDeinterleavingGraph(const TargetLowering *tl) : TL(tl) {}
+ explicit ComplexDeinterleavingGraph(const TargetLowering *TL,
+ const TargetLibraryInfo *TLI)
+ : TL(TL), TLI(TLI) {}
private:
const TargetLowering *TL = nullptr;
- Instruction *RootValue = nullptr;
- NodePtr RootNode;
+ const TargetLibraryInfo *TLI = nullptr;
SmallVector<NodePtr> CompositeNodes;
- SmallPtrSet<Instruction *, 16> AllInstructions;
+
+ SmallPtrSet<Instruction *, 16> FinalInstructions;
+
+ /// Root instructions are instructions from which complex computation starts
+ std::map<Instruction *, NodePtr> RootToNode;
+
+ /// Topologically sorted root instructions
+ SmallVector<Instruction *, 1> OrderedRoots;
NodePtr prepareCompositeNode(ComplexDeinterleavingOperation Operation,
Instruction *R, Instruction *I) {
@@ -211,10 +206,6 @@ private:
NodePtr submitCompositeNode(NodePtr Node) {
CompositeNodes.push_back(Node);
- AllInstructions.insert(Node->Real);
- AllInstructions.insert(Node->Imag);
- for (auto *I : Node->InternalInstructions)
- AllInstructions.insert(I);
return Node;
}
@@ -271,6 +262,10 @@ public:
/// current graph.
bool identifyNodes(Instruction *RootI);
+ /// Check that every instruction, from the roots to the leaves, has internal
+ /// uses.
+ bool checkNodes();
+
/// Perform the actual replacement of the underlying instruction graph.
void replaceNodes();
};
@@ -368,9 +363,7 @@ static bool isDeinterleavingMask(ArrayRef<int> Mask) {
}
bool ComplexDeinterleaving::evaluateBasicBlock(BasicBlock *B) {
- bool Changed = false;
-
- SmallVector<Instruction *> DeadInstrRoots;
+ ComplexDeinterleavingGraph Graph(TL, TLI);
for (auto &I : *B) {
auto *SVI = dyn_cast<ShuffleVectorInst>(&I);
@@ -382,22 +375,15 @@ bool ComplexDeinterleaving::evaluateBasicBlock(BasicBlock *B) {
if (!isInterleavingMask(SVI->getShuffleMask()))
continue;
- ComplexDeinterleavingGraph Graph(TL);
- if (!Graph.identifyNodes(SVI))
- continue;
-
- Graph.replaceNodes();
- DeadInstrRoots.push_back(SVI);
- Changed = true;
+ Graph.identifyNodes(SVI);
}
- for (const auto &I : DeadInstrRoots) {
- if (!I || I->getParent() == nullptr)
- continue;
- llvm::RecursivelyDeleteTriviallyDeadInstructions(I, TLI);
+ if (Graph.checkNodes()) {
+ Graph.replaceNodes();
+ return true;
}
- return Changed;
+ return false;
}
ComplexDeinterleavingGraph::NodePtr
@@ -511,7 +497,6 @@ ComplexDeinterleavingGraph::identifyNodeWithImplicitAdd(
Node->Rotation = Rotation;
Node->addOperand(CommonNode);
Node->addOperand(UncommonNode);
- Node->InternalInstructions.append(FNegs);
return submitCompositeNode(Node);
}
@@ -627,8 +612,6 @@ ComplexDeinterleavingGraph::identifyPartialMul(Instruction *Real,
NodePtr Node = prepareCompositeNode(
ComplexDeinterleavingOperation::CMulPartial, Real, Imag);
- Node->addInstruction(RealMulI);
- Node->addInstruction(ImagMulI);
Node->Rotation = Rotation;
Node->addOperand(CommonRes);
Node->addOperand(UncommonRes);
@@ -846,6 +829,8 @@ ComplexDeinterleavingGraph::identifyNode(Instruction *Real, Instruction *Imag) {
prepareCompositeNode(llvm::ComplexDeinterleavingOperation::Shuffle,
RealShuffle, ImagShuffle);
PlaceholderNode->ReplacementNode = RealShuffle->getOperand(0);
+ FinalInstructions.insert(RealShuffle);
+ FinalInstructions.insert(ImagShuffle);
return submitCompositeNode(PlaceholderNode);
}
if (RealShuffle || ImagShuffle) {
@@ -881,9 +866,7 @@ bool ComplexDeinterleavingGraph::identifyNodes(Instruction *RootI) {
if (!match(RootI, m_Shuffle(m_Instruction(Real), m_Instruction(Imag))))
return false;
- RootValue = RootI;
- AllInstructions.insert(RootI);
- RootNode = identifyNode(Real, Imag);
+ auto RootNode = identifyNode(Real, Imag);
LLVM_DEBUG({
Function *F = RootI->getFunction();
@@ -894,14 +877,86 @@ bool ComplexDeinterleavingGraph::identifyNodes(Instruction *RootI) {
dbgs() << "\n";
});
- // Check all instructions have internal uses
- for (const auto &Node : CompositeNodes) {
- if (!Node->hasAllInternalUses(AllInstructions)) {
- LLVM_DEBUG(dbgs() << " - Invalid internal uses\n");
- return false;
+ if (RootNode) {
+ RootToNode[RootI] = RootNode;
+ OrderedRoots.push_back(RootI);
+ return true;
+ }
+
+ return false;
+}
+
+// Returns true if at least one root can still be deinterleaved after pruning.
+bool ComplexDeinterleavingGraph::checkNodes() {
+  // Collect all instructions from roots to leaves
+  SmallPtrSet<Instruction *, 16> AllInstructions;
+  SmallVector<Instruction *, 8> Worklist;
+  Worklist.append(OrderedRoots.begin(), OrderedRoots.end());
+
+  // Extract all instructions that are used by all XCMLA/XCADD/ADD/SUB/NEG
+  // chains
+  while (!Worklist.empty()) {
+    auto *I = Worklist.back();
+    Worklist.pop_back();
+
+    // Skip instructions that were already visited, and do not look through
+    // final instructions: their operands are not part of the chain.
+    if (!AllInstructions.insert(I).second || FinalInstructions.count(I))
+      continue;
+
+    for (Value *Op : I->operands()) {
+      if (auto *OpI = dyn_cast<Instruction>(Op))
+        Worklist.emplace_back(OpI);
}
}
- return RootNode != nullptr;
+
+  // Find instructions that have users outside of chain. The worklist is empty
+  // at this point, so it is reused to seed the traversal further below.
+  for (auto *I : AllInstructions) {
+    // Skip root nodes
+    if (RootToNode.count(I))
+      continue;
+
+    for (User *U : I->users()) {
+      if (AllInstructions.count(cast<Instruction>(U)))
+        continue;
+
+      // Found an instruction that is not used by XCMLA/XCADD chain
+      Worklist.emplace_back(I);
+      break;
+    }
+  }
+
+  // If any instructions are found to be used outside, find and remove roots
+  // that somehow connect to those instructions.
+  SmallPtrSet<Instruction *, 16> Visited;
+  while (!Worklist.empty()) {
+    auto *I = Worklist.back();
+    Worklist.pop_back();
+    if (!Visited.insert(I).second)
+      continue;
+
+    // Found an impacted root node. Removing it from the nodes to be
+    // deinterleaved
+    if (RootToNode.count(I)) {
+      LLVM_DEBUG(dbgs() << "Instruction " << *I
+                        << " could be deinterleaved but its chain of complex "
+                           "operations have an outside user\n");
+      RootToNode.erase(I);
+    }
+
+    if (!AllInstructions.count(I) || FinalInstructions.count(I))
+      continue;
+
+    for (User *U : I->users())
+      Worklist.emplace_back(cast<Instruction>(U));
+
+    for (Value *Op : I->operands()) {
+      if (auto *OpI = dyn_cast<Instruction>(Op))
+        Worklist.emplace_back(OpI);
+    }
+  }
+  return !RootToNode.empty();
}
static Value *replaceSymmetricNode(ComplexDeinterleavingGraph::RawNodePtr Node,
@@ -958,29 +1013,21 @@ Value *ComplexDeinterleavingGraph::replaceNode(
}
void ComplexDeinterleavingGraph::replaceNodes() {
- Value *R = replaceNode(RootNode.get());
- assert(R && "Unable to find replacement for RootValue");
- RootValue->replaceAllUsesWith(R);
-}
-
-bool ComplexDeinterleavingCompositeNode::hasAllInternalUses(
- SmallPtrSet<Instruction *, 16> &AllInstructions) {
- if (Operation == ComplexDeinterleavingOperation::Shuffle)
- return true;
+  SmallVector<Instruction *, 16> DeadInstrRoots;
+  for (auto *RootInstruction : OrderedRoots) {
+    // Skip potential roots that were discarded by the check process and so
+    // must not be deinterleaved. A single lookup also yields the root node.
+    auto RootIt = RootToNode.find(RootInstruction);
+    if (RootIt == RootToNode.end())
+      continue;
- for (auto *User : Real->users()) {
- if (!AllInstructions.contains(cast<Instruction>(User)))
- return false;
+    Value *R = replaceNode(RootIt->second.get());
+    assert(R && "Unable to find replacement for RootInstruction");
+    DeadInstrRoots.push_back(RootInstruction);
+    RootInstruction->replaceAllUsesWith(R);
}
- for (auto *User : Imag->users()) {
- if (!AllInstructions.contains(cast<Instruction>(User)))
- return false;
- }
- for (auto *I : InternalInstructions) {
- for (auto *User : I->users()) {
- if (!AllInstructions.contains(cast<Instruction>(User)))
- return false;
- }
- }
- return true;
+
+  // The replaced roots have no users left; delete them and their dead chains.
+  for (auto *I : DeadInstrRoots)
+    RecursivelyDeleteTriviallyDeadInstructions(I, TLI);
}