diff options
Diffstat (limited to 'llvm/lib/Transforms')
-rw-r--r-- | llvm/lib/Transforms/Scalar/LoopUnswitch.cpp | 160 |
1 files changed, 128 insertions, 32 deletions
diff --git a/llvm/lib/Transforms/Scalar/LoopUnswitch.cpp b/llvm/lib/Transforms/Scalar/LoopUnswitch.cpp index b1083a5..add63eb 100644 --- a/llvm/lib/Transforms/Scalar/LoopUnswitch.cpp +++ b/llvm/lib/Transforms/Scalar/LoopUnswitch.cpp @@ -374,9 +374,27 @@ Pass *llvm::createLoopUnswitchPass(bool Os) { return new LoopUnswitch(Os); } +/// Operator chain lattice. +enum OperatorChain { + OC_OpChainNone, ///< There is no operator. + OC_OpChainOr, ///< There are only ORs. + OC_OpChainAnd, ///< There are only ANDs. + OC_OpChainMixed ///< There are ANDs and ORs. +}; + /// Cond is a condition that occurs in L. If it is invariant in the loop, or has /// an invariant piece, return the invariant. Otherwise, return null. +// +/// NOTE: FindLIVLoopCondition will not return a partial LIV by walking up a +/// mixed operator chain, as we can not reliably find a value which will simplify +/// the operator chain. If the chain is AND-only or OR-only, we can use 0 or ~0 +/// to simplify the chain. +/// +/// NOTE: In case a partial LIV and a mixed operator chain, we may be able to +/// simplify the condition itself to a loop variant condition, but at the +/// cost of creating an entirely new loop. static Value *FindLIVLoopCondition(Value *Cond, Loop *L, bool &Changed, + OperatorChain &ParentChain, DenseMap<Value *, Value *> &Cache) { auto CacheIt = Cache.find(Cond); if (CacheIt != Cache.end()) @@ -400,21 +418,53 @@ static Value *FindLIVLoopCondition(Value *Cond, Loop *L, bool &Changed, return Cond; } + // Walk up the operator chain to find partial invariant conditions. if (BinaryOperator *BO = dyn_cast<BinaryOperator>(Cond)) if (BO->getOpcode() == Instruction::And || BO->getOpcode() == Instruction::Or) { - // If either the left or right side is invariant, we can unswitch on this, - // which will cause the branch to go away in one loop and the condition to - // simplify in the other one. - if (Value *LHS = - FindLIVLoopCondition(BO->getOperand(0), L, Changed, Cache)) { - Cache[Cond] = LHS; - return LHS; + // Given the previous operator, compute the current operator chain status. + OperatorChain NewChain; + switch (ParentChain) { + case OC_OpChainNone: + NewChain = BO->getOpcode() == Instruction::And ? OC_OpChainAnd : + OC_OpChainOr; + break; + case OC_OpChainOr: + NewChain = BO->getOpcode() == Instruction::Or ? OC_OpChainOr : + OC_OpChainMixed; + break; + case OC_OpChainAnd: + NewChain = BO->getOpcode() == Instruction::And ? OC_OpChainAnd : + OC_OpChainMixed; + break; + case OC_OpChainMixed: + NewChain = OC_OpChainMixed; + break; } - if (Value *RHS = - FindLIVLoopCondition(BO->getOperand(1), L, Changed, Cache)) { - Cache[Cond] = RHS; - return RHS; + + // If we reach a Mixed state, we do not want to keep walking up as we can not + // reliably find a value that will simplify the chain. With this check, we + // will return null on the first sight of mixed chain and the caller will + // either backtrack to find partial LIV in other operand or return null. + if (NewChain != OC_OpChainMixed) { + // Update the current operator chain type before we search up the chain. + ParentChain = NewChain; + // If either the left or right side is invariant, we can unswitch on this, + // which will cause the branch to go away in one loop and the condition to + // simplify in the other one. + if (Value *LHS = FindLIVLoopCondition(BO->getOperand(0), L, Changed, + ParentChain, Cache)) { + Cache[Cond] = LHS; + return LHS; + } + // We did not manage to find a partial LIV in operand(0). Backtrack and try + // operand(1). + ParentChain = NewChain; + if (Value *RHS = FindLIVLoopCondition(BO->getOperand(1), L, Changed, + ParentChain, Cache)) { + Cache[Cond] = RHS; + return RHS; + } } } @@ -422,9 +472,21 @@ static Value *FindLIVLoopCondition(Value *Cond, Loop *L, bool &Changed, return nullptr; } -static Value *FindLIVLoopCondition(Value *Cond, Loop *L, bool &Changed) { +/// Cond is a condition that occurs in L. If it is invariant in the loop, or has +/// an invariant piece, return the invariant along with the operator chain type. +/// Otherwise, return null. +static std::pair<Value *, OperatorChain> FindLIVLoopCondition(Value *Cond, + Loop *L, + bool &Changed) { DenseMap<Value *, Value *> Cache; - return FindLIVLoopCondition(Cond, L, Changed, Cache); + OperatorChain OpChain = OC_OpChainNone; + Value *FCond = FindLIVLoopCondition(Cond, L, Changed, OpChain, Cache); + + // In case we do find a LIV, it can not be obtained by walking up a mixed + // operator chain. + assert((!FCond || OpChain != OC_OpChainMixed) && + "Do not expect a partial LIV with mixed operator chain"); + return {FCond, OpChain}; } bool LoopUnswitch::runOnLoop(Loop *L, LPPassManager &LPM_Ref) { @@ -556,7 +618,7 @@ bool LoopUnswitch::processCurrentLoop() { for (IntrinsicInst *Guard : Guards) { Value *LoopCond = - FindLIVLoopCondition(Guard->getOperand(0), currentLoop, Changed); + FindLIVLoopCondition(Guard->getOperand(0), currentLoop, Changed).first; if (LoopCond && UnswitchIfProfitable(LoopCond, ConstantInt::getTrue(Context))) { // NB! Unswitching (if successful) could have erased some of the @@ -597,7 +659,7 @@ bool LoopUnswitch::processCurrentLoop() { // See if this, or some part of it, is loop invariant. If so, we can // unswitch on it if we desire. Value *LoopCond = FindLIVLoopCondition(BI->getCondition(), - currentLoop, Changed); + currentLoop, Changed).first; if (LoopCond && UnswitchIfProfitable(LoopCond, ConstantInt::getTrue(Context), TI)) { ++NumBranches; @@ -605,24 +667,49 @@ bool LoopUnswitch::processCurrentLoop() { } } } else if (SwitchInst *SI = dyn_cast<SwitchInst>(TI)) { - Value *LoopCond = FindLIVLoopCondition(SI->getCondition(), - currentLoop, Changed); + Value *SC = SI->getCondition(); + Value *LoopCond; + OperatorChain OpChain; + std::tie(LoopCond, OpChain) = + FindLIVLoopCondition(SC, currentLoop, Changed); + unsigned NumCases = SI->getNumCases(); if (LoopCond && NumCases) { // Find a value to unswitch on: // FIXME: this should chose the most expensive case! // FIXME: scan for a case with a non-critical edge? Constant *UnswitchVal = nullptr; - - // Do not process same value again and again. - // At this point we have some cases already unswitched and - // some not yet unswitched. Let's find the first not yet unswitched one. - for (SwitchInst::CaseIt i = SI->case_begin(), e = SI->case_end(); - i != e; ++i) { - Constant *UnswitchValCandidate = i.getCaseValue(); - if (!BranchesInfo.isUnswitched(SI, UnswitchValCandidate)) { - UnswitchVal = UnswitchValCandidate; - break; + // Find a case value such that at least one case value is unswitched + // out. + if (OpChain == OC_OpChainAnd) { + // If the chain only has ANDs and the switch has a case value of 0. + // Dropping in a 0 to the chain will unswitch out the 0-casevalue. + auto *AllZero = cast<ConstantInt>(Constant::getNullValue(SC->getType())); + if (BranchesInfo.isUnswitched(SI, AllZero)) + continue; + // We are unswitching 0 out. + UnswitchVal = AllZero; + } else if (OpChain == OC_OpChainOr) { + // If the chain only has ORs and the switch has a case value of ~0. + // Dropping in a ~0 to the chain will unswitch out the ~0-casevalue. + auto *AllOne = cast<ConstantInt>(Constant::getAllOnesValue(SC->getType())); + if (BranchesInfo.isUnswitched(SI, AllOne)) + continue; + // We are unswitching ~0 out. + UnswitchVal = AllOne; + } else { + assert(OpChain == OC_OpChainNone && + "Expect to unswitch on trivial chain"); + // Do not process same value again and again. + // At this point we have some cases already unswitched and + // some not yet unswitched. Let's find the first not yet unswitched one. + for (SwitchInst::CaseIt i = SI->case_begin(), e = SI->case_end(); + i != e; ++i) { + Constant *UnswitchValCandidate = i.getCaseValue(); + if (!BranchesInfo.isUnswitched(SI, UnswitchValCandidate)) { + UnswitchVal = UnswitchValCandidate; + break; + } } } @@ -631,6 +718,11 @@ bool LoopUnswitch::processCurrentLoop() { if (UnswitchIfProfitable(LoopCond, UnswitchVal)) { ++NumSwitches; + // In case of a full LIV, UnswitchVal is the value we unswitched out. + // In case of a partial LIV, we only unswitch when its an AND-chain + // or OR-chain. In both cases switch input value simplifies to + // UnswitchVal. + BranchesInfo.setUnswitched(SI, UnswitchVal); return true; } } @@ -641,7 +733,7 @@ bool LoopUnswitch::processCurrentLoop() { BBI != E; ++BBI) if (SelectInst *SI = dyn_cast<SelectInst>(BBI)) { Value *LoopCond = FindLIVLoopCondition(SI->getCondition(), - currentLoop, Changed); + currentLoop, Changed).first; if (LoopCond && UnswitchIfProfitable(LoopCond, ConstantInt::getTrue(Context))) { ++NumSelects; @@ -900,7 +992,7 @@ bool LoopUnswitch::TryTrivialLoopUnswitch(bool &Changed) { return false; Value *LoopCond = FindLIVLoopCondition(BI->getCondition(), - currentLoop, Changed); + currentLoop, Changed).first; // Unswitch only if the trivial condition itself is an LIV (not // partial LIV which could occur in and/or) @@ -931,7 +1023,7 @@ bool LoopUnswitch::TryTrivialLoopUnswitch(bool &Changed) { } else if (SwitchInst *SI = dyn_cast<SwitchInst>(CurrentTerm)) { // If this isn't switching on an invariant condition, we can't unswitch it. Value *LoopCond = FindLIVLoopCondition(SI->getCondition(), - currentLoop, Changed); + currentLoop, Changed).first; // Unswitch only if the trivial condition itself is an LIV (not // partial LIV which could occur in and/or) @@ -969,6 +1061,9 @@ bool LoopUnswitch::TryTrivialLoopUnswitch(bool &Changed) { UnswitchTrivialCondition(currentLoop, LoopCond, CondVal, LoopExitBB, nullptr); + + // We are only unswitching full LIV. + BranchesInfo.setUnswitched(SI, CondVal); ++NumSwitches; return true; } @@ -1250,6 +1345,9 @@ void LoopUnswitch::RewriteLoopBodyWithConditionConstant(Loop *L, Value *LIC, SwitchInst *SI = dyn_cast<SwitchInst>(UI); if (!SI || !isa<ConstantInt>(Val)) continue; + // NOTE: if a case value for the switch is unswitched out, we record it + // after the unswitch finishes. We can not record it here as the switch + // is not a direct user of the partial LIV. SwitchInst::CaseIt DeadCase = SI->findCaseValue(cast<ConstantInt>(Val)); // Default case is live for multiple values. if (DeadCase == SI->case_default()) continue; @@ -1262,8 +1360,6 @@ void LoopUnswitch::RewriteLoopBodyWithConditionConstant(Loop *L, Value *LIC, BasicBlock *SISucc = DeadCase.getCaseSuccessor(); BasicBlock *Latch = L->getLoopLatch(); - BranchesInfo.setUnswitched(SI, Val); - if (!SI->findCaseDest(SISucc)) continue; // Edge is critical. // If the DeadCase successor dominates the loop latch, then the // transformation isn't safe since it will delete the sole predecessor edge |