diff options
author | Arthur Eubanks <aeubanks@google.com> | 2020-09-15 15:02:23 -0700 |
---|---|---|
committer | Arthur Eubanks <aeubanks@google.com> | 2020-09-15 18:18:31 -0700 |
commit | f7aa1563eb5ff00416fba373073ba19832b6fc34 (patch) | |
tree | 069af1b9456780e461c996e2996a541196b0c94f /llvm/lib/Transforms/Utils/LowerSwitch.cpp | |
parent | 7bc77c8526b6b2f0a2b2b780151bafc5e4094130 (diff) | |
download | llvm-f7aa1563eb5ff00416fba373073ba19832b6fc34.zip llvm-f7aa1563eb5ff00416fba373073ba19832b6fc34.tar.gz llvm-f7aa1563eb5ff00416fba373073ba19832b6fc34.tar.bz2 |
[LowerSwitch][NewPM] Port lowerswitch to NPM
Reviewed By: ychen
Differential Revision: https://reviews.llvm.org/D87726
Diffstat (limited to 'llvm/lib/Transforms/Utils/LowerSwitch.cpp')
-rw-r--r-- | llvm/lib/Transforms/Utils/LowerSwitch.cpp | 393 |
1 files changed, 191 insertions, 202 deletions
diff --git a/llvm/lib/Transforms/Utils/LowerSwitch.cpp b/llvm/lib/Transforms/Utils/LowerSwitch.cpp index 34e836d..10a4420 100644 --- a/llvm/lib/Transforms/Utils/LowerSwitch.cpp +++ b/llvm/lib/Transforms/Utils/LowerSwitch.cpp @@ -12,6 +12,7 @@ // //===----------------------------------------------------------------------===// +#include "llvm/Transforms/Utils/LowerSwitch.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallPtrSet.h" @@ -26,6 +27,7 @@ #include "llvm/IR/Function.h" #include "llvm/IR/InstrTypes.h" #include "llvm/IR/Instructions.h" +#include "llvm/IR/PassManager.h" #include "llvm/IR/Value.h" #include "llvm/InitializePasses.h" #include "llvm/Pass.h" @@ -55,9 +57,9 @@ namespace { } // end anonymous namespace +namespace { // Return true iff R is covered by Ranges. -static bool IsInRanges(const IntRange &R, - const std::vector<IntRange> &Ranges) { +bool IsInRanges(const IntRange &R, const std::vector<IntRange> &Ranges) { // Note: Ranges must be sorted, non-overlapping and non-adjacent. // Find the first range whose High field is >= R.High, @@ -68,120 +70,34 @@ static bool IsInRanges(const IntRange &R, return I != Ranges.end() && I->Low <= R.Low; } -namespace { - - /// Replace all SwitchInst instructions with chained branch instructions. - class LowerSwitch : public FunctionPass { - public: - // Pass identification, replacement for typeid - static char ID; - - LowerSwitch() : FunctionPass(ID) { - initializeLowerSwitchPass(*PassRegistry::getPassRegistry()); - } - - bool runOnFunction(Function &F) override; - - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.addRequired<LazyValueInfoWrapperPass>(); - } - - struct CaseRange { - ConstantInt* Low; - ConstantInt* High; - BasicBlock* BB; - - CaseRange(ConstantInt *low, ConstantInt *high, BasicBlock *bb) - : Low(low), High(high), BB(bb) {} - }; - - using CaseVector = std::vector<CaseRange>; - using CaseItr = std::vector<CaseRange>::iterator; - - private: - void processSwitchInst(SwitchInst *SI, - SmallPtrSetImpl<BasicBlock *> &DeleteList, - AssumptionCache *AC, LazyValueInfo *LVI); - - BasicBlock *switchConvert(CaseItr Begin, CaseItr End, - ConstantInt *LowerBound, ConstantInt *UpperBound, - Value *Val, BasicBlock *Predecessor, - BasicBlock *OrigBlock, BasicBlock *Default, - const std::vector<IntRange> &UnreachableRanges); - BasicBlock *newLeafBlock(CaseRange &Leaf, Value *Val, - ConstantInt *LowerBound, ConstantInt *UpperBound, - BasicBlock *OrigBlock, BasicBlock *Default); - unsigned Clusterify(CaseVector &Cases, SwitchInst *SI); - }; - - /// The comparison function for sorting the switch case values in the vector. - /// WARNING: Case ranges should be disjoint! - struct CaseCmp { - bool operator()(const LowerSwitch::CaseRange& C1, - const LowerSwitch::CaseRange& C2) { - const ConstantInt* CI1 = cast<const ConstantInt>(C1.Low); - const ConstantInt* CI2 = cast<const ConstantInt>(C2.High); - return CI1->getValue().slt(CI2->getValue()); - } - }; - -} // end anonymous namespace - -char LowerSwitch::ID = 0; - -// Publicly exposed interface to pass... -char &llvm::LowerSwitchID = LowerSwitch::ID; - -INITIALIZE_PASS_BEGIN(LowerSwitch, "lowerswitch", - "Lower SwitchInst's to branches", false, false) -INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) -INITIALIZE_PASS_DEPENDENCY(LazyValueInfoWrapperPass) -INITIALIZE_PASS_END(LowerSwitch, "lowerswitch", - "Lower SwitchInst's to branches", false, false) - -// createLowerSwitchPass - Interface to this file... -FunctionPass *llvm::createLowerSwitchPass() { - return new LowerSwitch(); -} - -bool LowerSwitch::runOnFunction(Function &F) { - LazyValueInfo *LVI = &getAnalysis<LazyValueInfoWrapperPass>().getLVI(); - auto *ACT = getAnalysisIfAvailable<AssumptionCacheTracker>(); - AssumptionCache *AC = ACT ? &ACT->getAssumptionCache(F) : nullptr; - - bool Changed = false; - SmallPtrSet<BasicBlock*, 8> DeleteList; - - for (Function::iterator I = F.begin(), E = F.end(); I != E; ) { - BasicBlock *Cur = &*I++; // Advance over block so we don't traverse new blocks - - // If the block is a dead Default block that will be deleted later, don't - // waste time processing it. - if (DeleteList.count(Cur)) - continue; - - if (SwitchInst *SI = dyn_cast<SwitchInst>(Cur->getTerminator())) { - Changed = true; - processSwitchInst(SI, DeleteList, AC, LVI); - } - } - - for (BasicBlock* BB: DeleteList) { - LVI->eraseBlock(BB); - DeleteDeadBlock(BB); +struct CaseRange { + ConstantInt *Low; + ConstantInt *High; + BasicBlock *BB; + + CaseRange(ConstantInt *low, ConstantInt *high, BasicBlock *bb) + : Low(low), High(high), BB(bb) {} +}; + +using CaseVector = std::vector<CaseRange>; +using CaseItr = std::vector<CaseRange>::iterator; + +/// The comparison function for sorting the switch case values in the vector. +/// WARNING: Case ranges should be disjoint! +struct CaseCmp { + bool operator()(const CaseRange &C1, const CaseRange &C2) { + const ConstantInt *CI1 = cast<const ConstantInt>(C1.Low); + const ConstantInt *CI2 = cast<const ConstantInt>(C2.High); + return CI1->getValue().slt(CI2->getValue()); } - - return Changed; -} +}; /// Used for debugging purposes. LLVM_ATTRIBUTE_USED -static raw_ostream &operator<<(raw_ostream &O, - const LowerSwitch::CaseVector &C) { +raw_ostream &operator<<(raw_ostream &O, const CaseVector &C) { O << "["; - for (LowerSwitch::CaseVector::const_iterator B = C.begin(), E = C.end(); - B != E;) { + for (CaseVector::const_iterator B = C.begin(), E = C.end(); B != E;) { O << "[" << B->Low->getValue() << ", " << B->High->getValue() << "]"; if (++B != E) O << ", "; @@ -200,9 +116,9 @@ static raw_ostream &operator<<(raw_ostream &O, /// 2) Removed if subsequent incoming values now share the same case, i.e., /// multiple outcome edges are condensed into one. This is necessary to keep the /// number of phi values equal to the number of branches to SuccBB. -static void -fixPhis(BasicBlock *SuccBB, BasicBlock *OrigBB, BasicBlock *NewBB, - const unsigned NumMergedCases = std::numeric_limits<unsigned>::max()) { +void FixPhis( + BasicBlock *SuccBB, BasicBlock *OrigBB, BasicBlock *NewBB, + const unsigned NumMergedCases = std::numeric_limits<unsigned>::max()) { for (BasicBlock::iterator I = SuccBB->begin(), IE = SuccBB->getFirstNonPHI()->getIterator(); I != IE; ++I) { @@ -233,17 +149,80 @@ fixPhis(BasicBlock *SuccBB, BasicBlock *OrigBB, BasicBlock *NewBB, } } +/// Create a new leaf block for the binary lookup tree. It checks if the +/// switch's value == the case's value. If not, then it jumps to the default +/// branch. At this point in the tree, the value can't be another valid case +/// value, so the jump to the "default" branch is warranted. +BasicBlock *NewLeafBlock(CaseRange &Leaf, Value *Val, ConstantInt *LowerBound, + ConstantInt *UpperBound, BasicBlock *OrigBlock, + BasicBlock *Default) { + Function *F = OrigBlock->getParent(); + BasicBlock *NewLeaf = BasicBlock::Create(Val->getContext(), "LeafBlock"); + F->getBasicBlockList().insert(++OrigBlock->getIterator(), NewLeaf); + + // Emit comparison + ICmpInst *Comp = nullptr; + if (Leaf.Low == Leaf.High) { + // Make the seteq instruction... + Comp = + new ICmpInst(*NewLeaf, ICmpInst::ICMP_EQ, Val, Leaf.Low, "SwitchLeaf"); + } else { + // Make range comparison + if (Leaf.Low == LowerBound) { + // Val >= Min && Val <= Hi --> Val <= Hi + Comp = new ICmpInst(*NewLeaf, ICmpInst::ICMP_SLE, Val, Leaf.High, + "SwitchLeaf"); + } else if (Leaf.High == UpperBound) { + // Val <= Max && Val >= Lo --> Val >= Lo + Comp = new ICmpInst(*NewLeaf, ICmpInst::ICMP_SGE, Val, Leaf.Low, + "SwitchLeaf"); + } else if (Leaf.Low->isZero()) { + // Val >= 0 && Val <= Hi --> Val <=u Hi + Comp = new ICmpInst(*NewLeaf, ICmpInst::ICMP_ULE, Val, Leaf.High, + "SwitchLeaf"); + } else { + // Emit V-Lo <=u Hi-Lo + Constant *NegLo = ConstantExpr::getNeg(Leaf.Low); + Instruction *Add = BinaryOperator::CreateAdd( + Val, NegLo, Val->getName() + ".off", NewLeaf); + Constant *UpperBound = ConstantExpr::getAdd(NegLo, Leaf.High); + Comp = new ICmpInst(*NewLeaf, ICmpInst::ICMP_ULE, Add, UpperBound, + "SwitchLeaf"); + } + } + + // Make the conditional branch... + BasicBlock *Succ = Leaf.BB; + BranchInst::Create(Succ, Default, Comp, NewLeaf); + + // If there were any PHI nodes in this successor, rewrite one entry + // from OrigBlock to come from NewLeaf. + for (BasicBlock::iterator I = Succ->begin(); isa<PHINode>(I); ++I) { + PHINode *PN = cast<PHINode>(I); + // Remove all but one incoming entries from the cluster + uint64_t Range = Leaf.High->getSExtValue() - Leaf.Low->getSExtValue(); + for (uint64_t j = 0; j < Range; ++j) { + PN->removeIncomingValue(OrigBlock); + } + + int BlockIdx = PN->getBasicBlockIndex(OrigBlock); + assert(BlockIdx != -1 && "Switch didn't go to this successor??"); + PN->setIncomingBlock((unsigned)BlockIdx, NewLeaf); + } + + return NewLeaf; +} + /// Convert the switch statement into a binary lookup of the case values. /// The function recursively builds this tree. LowerBound and UpperBound are /// used to keep track of the bounds for Val that have already been checked by /// a block emitted by one of the previous calls to switchConvert in the call /// stack. -BasicBlock * -LowerSwitch::switchConvert(CaseItr Begin, CaseItr End, ConstantInt *LowerBound, - ConstantInt *UpperBound, Value *Val, - BasicBlock *Predecessor, BasicBlock *OrigBlock, - BasicBlock *Default, - const std::vector<IntRange> &UnreachableRanges) { +BasicBlock *SwitchConvert(CaseItr Begin, CaseItr End, ConstantInt *LowerBound, + ConstantInt *UpperBound, Value *Val, + BasicBlock *Predecessor, BasicBlock *OrigBlock, + BasicBlock *Default, + const std::vector<IntRange> &UnreachableRanges) { assert(LowerBound && UpperBound && "Bounds must be initialized"); unsigned Size = End - Begin; @@ -255,10 +234,10 @@ LowerSwitch::switchConvert(CaseItr Begin, CaseItr End, ConstantInt *LowerBound, if (Begin->Low == LowerBound && Begin->High == UpperBound) { unsigned NumMergedCases = 0; NumMergedCases = UpperBound->getSExtValue() - LowerBound->getSExtValue(); - fixPhis(Begin->BB, OrigBlock, Predecessor, NumMergedCases); + FixPhis(Begin->BB, OrigBlock, Predecessor, NumMergedCases); return Begin->BB; } - return newLeafBlock(*Begin, Val, LowerBound, UpperBound, OrigBlock, + return NewLeafBlock(*Begin, Val, LowerBound, UpperBound, OrigBlock, Default); } @@ -305,12 +284,12 @@ LowerSwitch::switchConvert(CaseItr Begin, CaseItr End, ConstantInt *LowerBound, ICmpInst* Comp = new ICmpInst(ICmpInst::ICMP_SLT, Val, Pivot.Low, "Pivot"); - BasicBlock *LBranch = switchConvert(LHS.begin(), LHS.end(), LowerBound, - NewUpperBound, Val, NewNode, OrigBlock, - Default, UnreachableRanges); - BasicBlock *RBranch = switchConvert(RHS.begin(), RHS.end(), NewLowerBound, - UpperBound, Val, NewNode, OrigBlock, - Default, UnreachableRanges); + BasicBlock *LBranch = + SwitchConvert(LHS.begin(), LHS.end(), LowerBound, NewUpperBound, Val, + NewNode, OrigBlock, Default, UnreachableRanges); + BasicBlock *RBranch = + SwitchConvert(RHS.begin(), RHS.end(), NewLowerBound, UpperBound, Val, + NewNode, OrigBlock, Default, UnreachableRanges); F->getBasicBlockList().insert(++OrigBlock->getIterator(), NewNode); NewNode->getInstList().push_back(Comp); @@ -319,78 +298,10 @@ LowerSwitch::switchConvert(CaseItr Begin, CaseItr End, ConstantInt *LowerBound, return NewNode; } -/// Create a new leaf block for the binary lookup tree. It checks if the -/// switch's value == the case's value. If not, then it jumps to the default -/// branch. At this point in the tree, the value can't be another valid case -/// value, so the jump to the "default" branch is warranted. -BasicBlock *LowerSwitch::newLeafBlock(CaseRange &Leaf, Value *Val, - ConstantInt *LowerBound, - ConstantInt *UpperBound, - BasicBlock *OrigBlock, - BasicBlock *Default) { - Function* F = OrigBlock->getParent(); - BasicBlock* NewLeaf = BasicBlock::Create(Val->getContext(), "LeafBlock"); - F->getBasicBlockList().insert(++OrigBlock->getIterator(), NewLeaf); - - // Emit comparison - ICmpInst* Comp = nullptr; - if (Leaf.Low == Leaf.High) { - // Make the seteq instruction... - Comp = new ICmpInst(*NewLeaf, ICmpInst::ICMP_EQ, Val, - Leaf.Low, "SwitchLeaf"); - } else { - // Make range comparison - if (Leaf.Low == LowerBound) { - // Val >= Min && Val <= Hi --> Val <= Hi - Comp = new ICmpInst(*NewLeaf, ICmpInst::ICMP_SLE, Val, Leaf.High, - "SwitchLeaf"); - } else if (Leaf.High == UpperBound) { - // Val <= Max && Val >= Lo --> Val >= Lo - Comp = new ICmpInst(*NewLeaf, ICmpInst::ICMP_SGE, Val, Leaf.Low, - "SwitchLeaf"); - } else if (Leaf.Low->isZero()) { - // Val >= 0 && Val <= Hi --> Val <=u Hi - Comp = new ICmpInst(*NewLeaf, ICmpInst::ICMP_ULE, Val, Leaf.High, - "SwitchLeaf"); - } else { - // Emit V-Lo <=u Hi-Lo - Constant* NegLo = ConstantExpr::getNeg(Leaf.Low); - Instruction* Add = BinaryOperator::CreateAdd(Val, NegLo, - Val->getName()+".off", - NewLeaf); - Constant *UpperBound = ConstantExpr::getAdd(NegLo, Leaf.High); - Comp = new ICmpInst(*NewLeaf, ICmpInst::ICMP_ULE, Add, UpperBound, - "SwitchLeaf"); - } - } - - // Make the conditional branch... - BasicBlock* Succ = Leaf.BB; - BranchInst::Create(Succ, Default, Comp, NewLeaf); - - // If there were any PHI nodes in this successor, rewrite one entry - // from OrigBlock to come from NewLeaf. - for (BasicBlock::iterator I = Succ->begin(); isa<PHINode>(I); ++I) { - PHINode* PN = cast<PHINode>(I); - // Remove all but one incoming entries from the cluster - uint64_t Range = Leaf.High->getSExtValue() - - Leaf.Low->getSExtValue(); - for (uint64_t j = 0; j < Range; ++j) { - PN->removeIncomingValue(OrigBlock); - } - - int BlockIdx = PN->getBasicBlockIndex(OrigBlock); - assert(BlockIdx != -1 && "Switch didn't go to this successor??"); - PN->setIncomingBlock((unsigned)BlockIdx, NewLeaf); - } - - return NewLeaf; -} - /// Transform simple list of \p SI's cases into list of CaseRange's \p Cases. /// \post \p Cases wouldn't contain references to \p SI's default BB. /// \returns Number of \p SI's cases that do not reference \p SI's default BB. -unsigned LowerSwitch::Clusterify(CaseVector& Cases, SwitchInst *SI) { +unsigned Clusterify(CaseVector &Cases, SwitchInst *SI) { unsigned NumSimpleCases = 0; // Start with "simple" cases @@ -431,9 +342,9 @@ unsigned LowerSwitch::Clusterify(CaseVector& Cases, SwitchInst *SI) { /// Replace the specified switch instruction with a sequence of chained if-then /// insts in a balanced binary search. -void LowerSwitch::processSwitchInst(SwitchInst *SI, - SmallPtrSetImpl<BasicBlock *> &DeleteList, - AssumptionCache *AC, LazyValueInfo *LVI) { +void ProcessSwitchInst(SwitchInst *SI, + SmallPtrSetImpl<BasicBlock *> &DeleteList, + AssumptionCache *AC, LazyValueInfo *LVI) { BasicBlock *OrigBlock = SI->getParent(); Function *F = OrigBlock->getParent(); Value *Val = SI->getCondition(); // The value we are switching on... @@ -458,7 +369,7 @@ void LowerSwitch::processSwitchInst(SwitchInst *SI, if (Cases.empty()) { BranchInst::Create(Default, OrigBlock); // Remove all the references from Default's PHIs to OrigBlock, but one. - fixPhis(Default, OrigBlock, OrigBlock); + FixPhis(Default, OrigBlock, OrigBlock); SI->eraseFromParent(); return; } @@ -592,12 +503,12 @@ void LowerSwitch::processSwitchInst(SwitchInst *SI, BranchInst::Create(Default, NewDefault); BasicBlock *SwitchBlock = - switchConvert(Cases.begin(), Cases.end(), LowerBound, UpperBound, Val, + SwitchConvert(Cases.begin(), Cases.end(), LowerBound, UpperBound, Val, OrigBlock, OrigBlock, NewDefault, UnreachableRanges); // If there are entries in any PHI nodes for the default edge, make sure // to update them as well. - fixPhis(Default, OrigBlock, NewDefault); + FixPhis(Default, OrigBlock, NewDefault); // Branch to our shiny new if-then stuff... BranchInst::Create(SwitchBlock, OrigBlock); @@ -610,3 +521,81 @@ void LowerSwitch::processSwitchInst(SwitchInst *SI, if (pred_begin(OldDefault) == pred_end(OldDefault)) DeleteList.insert(OldDefault); } + +bool LowerSwitch(Function &F, LazyValueInfo *LVI, AssumptionCache *AC) { + bool Changed = false; + SmallPtrSet<BasicBlock *, 8> DeleteList; + + for (Function::iterator I = F.begin(), E = F.end(); I != E;) { + BasicBlock *Cur = + &*I++; // Advance over block so we don't traverse new blocks + + // If the block is a dead Default block that will be deleted later, don't + // waste time processing it. + if (DeleteList.count(Cur)) + continue; + + if (SwitchInst *SI = dyn_cast<SwitchInst>(Cur->getTerminator())) { + Changed = true; + ProcessSwitchInst(SI, DeleteList, AC, LVI); + } + } + + for (BasicBlock *BB : DeleteList) { + LVI->eraseBlock(BB); + DeleteDeadBlock(BB); + } + + return Changed; +} + +/// Replace all SwitchInst instructions with chained branch instructions. +class LowerSwitchLegacyPass : public FunctionPass { +public: + // Pass identification, replacement for typeid + static char ID; + + LowerSwitchLegacyPass() : FunctionPass(ID) { + initializeLowerSwitchLegacyPassPass(*PassRegistry::getPassRegistry()); + } + + bool runOnFunction(Function &F) override; + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired<LazyValueInfoWrapperPass>(); + } +}; + +} // end anonymous namespace + +char LowerSwitchLegacyPass::ID = 0; + +// Publicly exposed interface to pass... +char &llvm::LowerSwitchID = LowerSwitchLegacyPass::ID; + +INITIALIZE_PASS_BEGIN(LowerSwitchLegacyPass, "lowerswitch", + "Lower SwitchInst's to branches", false, false) +INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) +INITIALIZE_PASS_DEPENDENCY(LazyValueInfoWrapperPass) +INITIALIZE_PASS_END(LowerSwitchLegacyPass, "lowerswitch", + "Lower SwitchInst's to branches", false, false) + +// createLowerSwitchPass - Interface to this file... +FunctionPass *llvm::createLowerSwitchPass() { + return new LowerSwitchLegacyPass(); +} + +bool LowerSwitchLegacyPass::runOnFunction(Function &F) { + LazyValueInfo *LVI = &getAnalysis<LazyValueInfoWrapperPass>().getLVI(); + auto *ACT = getAnalysisIfAvailable<AssumptionCacheTracker>(); + AssumptionCache *AC = ACT ? &ACT->getAssumptionCache(F) : nullptr; + return LowerSwitch(F, LVI, AC); +} + +PreservedAnalyses LowerSwitchPass::run(Function &F, + FunctionAnalysisManager &AM) { + LazyValueInfo *LVI = &AM.getResult<LazyValueAnalysis>(F); + AssumptionCache *AC = AM.getCachedResult<AssumptionAnalysis>(F); + return LowerSwitch(F, LVI, AC) ? PreservedAnalyses::none() + : PreservedAnalyses::all(); +} |