diff options
author | Noah Goldstein <goldstein.w.n@gmail.com> | 2024-06-25 19:51:39 +0800 |
---|---|---|
committer | Noah Goldstein <goldstein.w.n@gmail.com> | 2024-08-18 15:58:24 -0700 |
commit | c64ce8bf283120fd145a57d0e61f9697f719139d (patch) | |
tree | 40168c956007111e0dea5d78f045323da3f95881 /llvm/lib/CodeGen/CodeGenPrepare.cpp | |
parent | f16125a13ce725b1e936468e08257c0fbb80c0fa (diff) | |
download | llvm-c64ce8bf283120fd145a57d0e61f9697f719139d.zip llvm-c64ce8bf283120fd145a57d0e61f9697f719139d.tar.gz llvm-c64ce8bf283120fd145a57d0e61f9697f719139d.tar.bz2 |
[CodeGenPrepare] Folding `urem` with loop invariant value
```
for(i = Start; i < End; ++i)
Rem = (i nuw+ IncrLoopInvariant) u% RemAmtLoopInvariant;
```
->
```
rem = (Start nuw+ IncrLoopInvariant) u% RemAmtLoopInvariant;
for(i = Start; i < End; ++i, ++rem) {
rem = rem == RemAmtLoopInvariant ? 0 : rem;
Rem = rem;
}
```
In its current state, the transform is only applied when `IncrLoopInvariant`
and `Start` are both zero.
Alive2 seemed unable to prove this (see:
https://alive2.llvm.org/ce/z/ATGDp3, which is clearly wrong but still
checks out...), so I wrote an exhaustive test here:
https://godbolt.org/z/WYa561388
Closes #96625
Diffstat (limited to 'llvm/lib/CodeGen/CodeGenPrepare.cpp')
-rw-r--r-- | llvm/lib/CodeGen/CodeGenPrepare.cpp | 149 |
1 files changed, 149 insertions, 0 deletions
```diff
diff --git a/llvm/lib/CodeGen/CodeGenPrepare.cpp b/llvm/lib/CodeGen/CodeGenPrepare.cpp
index 48253a6..72db1f4 100644
--- a/llvm/lib/CodeGen/CodeGenPrepare.cpp
+++ b/llvm/lib/CodeGen/CodeGenPrepare.cpp
@@ -472,6 +472,7 @@ private:
   bool replaceMathCmpWithIntrinsic(BinaryOperator *BO, Value *Arg0, Value *Arg1,
                                    CmpInst *Cmp, Intrinsic::ID IID);
   bool optimizeCmp(CmpInst *Cmp, ModifyDT &ModifiedDT);
+  bool optimizeURem(Instruction *Rem);
   bool combineToUSubWithOverflow(CmpInst *Cmp, ModifyDT &ModifiedDT);
   bool combineToUAddWithOverflow(CmpInst *Cmp, ModifyDT &ModifiedDT);
   void verifyBFIUpdates(Function &F);
@@ -1975,6 +1976,150 @@ static bool foldFCmpToFPClassTest(CmpInst *Cmp, const TargetLowering &TLI,
   return true;
 }
 
+// Returns true if `Rem` is `urem PN, RemAmt` where `PN` is a two-input PHI in
+// the header of the loop `LI` reports for `Rem`'s block, that loop has both a
+// preheader and a latch, `PN` is an induction variable stepped by one with a
+// `nuw` add, and `RemAmt` is invariant in that loop. On success the divisor
+// and the PHI are stored to `RemAmtOut` and `LoopIncrPNOut`; on failure
+// neither is written.
+static bool isRemOfLoopIncrementWithLoopInvariant(Instruction *Rem,
+                                                  const LoopInfo *LI,
+                                                  Value *&RemAmtOut,
+                                                  PHINode *&LoopIncrPNOut) {
+  Value *Incr, *RemAmt;
+  // NB: If RemAmt is a power of 2 it *should* have been transformed by now.
+  if (!match(Rem, m_URem(m_Value(Incr), m_Value(RemAmt))))
+    return false;
+
+  // Find out loop increment PHI.
+  auto *PN = dyn_cast<PHINode>(Incr);
+  if (!PN)
+    return false;
+
+  // This isn't strictly necessary, what we really need is one increment and any
+  // amount of initial values all being the same.
+  if (PN->getNumIncomingValues() != 2)
+    return false;
+
+  // Only trivially analyzable loops.
+  Loop *L = LI->getLoopFor(Rem->getParent());
+  if (!L || !L->getLoopPreheader() || !L->getLoopLatch())
+    return false;
+
+  // Only works if the remainder amount is a loop invariant.
+  if (!L->isLoopInvariant(RemAmt))
+    return false;
+
+  // Is the PHI a loop increment?
+  auto LoopIncrInfo = getIVIncrement(PN, LI);
+  if (!LoopIncrInfo)
+    return false;
+
+  // getIVIncrement finds the loop at PN->getParent(). This might be a different
+  // loop from the loop with Rem->getParent().
+  if (L->getHeader() != PN->getParent())
+    return false;
+
+  // We need remainder_amount % increment_amount to be zero. Increment of one
+  // satisfies that without any special logic and is overwhelmingly the common
+  // case.
+  if (!match(LoopIncrInfo->second, m_One()))
+    return false;
+
+  // Need the increment to not overflow.
+  if (!match(LoopIncrInfo->first, m_NUWAdd(m_Value(), m_Value())))
+    return false;
+
+  // Set output variables.
+  RemAmtOut = RemAmt;
+  LoopIncrPNOut = PN;
+
+  return true;
+}
+
+// Try to transform:
+//
+// for(i = Start; i < End; ++i)
+//    Rem = (i nuw+ IncrLoopInvariant) u% RemAmtLoopInvariant;
+//
+// ->
+//
+// NewRem = (Start nuw+ IncrLoopInvariant) u% RemAmtLoopInvariant;
+// for(i = Start; i < End; ++i) {
+//    Rem = NewRem;
+//    NewRem = (NewRem + 1 == RemAmtLoopInvariant) ? 0 : NewRem + 1;
+// }
+//
+// Currently only implemented for `IncrLoopInvariant` and `Start` both being
+// zero, so the initial `NewRem` is simply zero and no `urem` is left behind.
+//
+// Returns true if `Rem` was replaced and erased. `DL` is currently unused.
+static bool foldURemOfLoopIncrement(Instruction *Rem, const DataLayout *DL,
+                                    const LoopInfo *LI,
+                                    SmallSet<BasicBlock *, 32> &FreshBBs,
+                                    bool IsHuge) {
+  Value *RemAmt;
+  PHINode *LoopIncrPN;
+  if (!isRemOfLoopIncrementWithLoopInvariant(Rem, LI, RemAmt, LoopIncrPN))
+    return false;
+
+  // Only non-constant remainder as the extra IV is probably not profitable
+  // in that case.
+  //
+  // Potential TODO(1): `urem` of a const ends up as `mul` + `shift` + `add`. If
+  // we can rule out register pressure and ensure this `urem` is executed each
+  // iteration, it's probably profitable to handle the const case as well.
+  //
+  // Potential TODO(2): Should we have a check for how "nested" this remainder
+  // operation is? The new code runs every iteration so if the remainder is
+  // guarded behind unlikely conditions this might not be worth it.
+  if (match(RemAmt, m_ImmConstant()))
+    return false;
+
+  Loop *L = LI->getLoopFor(Rem->getParent());
+  Value *Start = LoopIncrPN->getIncomingValueForBlock(L->getLoopPreheader());
+
+  // The new PHI is seeded with `Start` itself rather than `Start u% RemAmt`,
+  // which is only correct when `Start u< RemAmt`. Zero is the only start value
+  // known to satisfy that without further analysis, so bail out on anything
+  // else rather than produce wrong remainders (e.g. Start = 5, RemAmt = 3).
+  if (!match(Start, m_Zero()))
+    return false;
+
+  // Create new remainder with induction variable.
+  Type *Ty = Rem->getType();
+  IRBuilder<> Builder(Rem->getContext());
+
+  Builder.SetInsertPoint(LoopIncrPN);
+  PHINode *NewRem = Builder.CreatePHI(Ty, 2);
+
+  Builder.SetInsertPoint(cast<Instruction>(
+      LoopIncrPN->getIncomingValueForBlock(L->getLoopLatch())));
+  // `(add (urem x, y), 1)` is always nuw.
+  Value *RemAdd = Builder.CreateNUWAdd(NewRem, ConstantInt::get(Ty, 1));
+  Value *RemCmp = Builder.CreateICmp(ICmpInst::ICMP_EQ, RemAdd, RemAmt);
+  Value *RemSel =
+      Builder.CreateSelect(RemCmp, Constant::getNullValue(Ty), RemAdd);
+
+  NewRem->addIncoming(Start, L->getLoopPreheader());
+  NewRem->addIncoming(RemSel, L->getLoopLatch());
+
+  // Insert all touched BBs.
+  FreshBBs.insert(LoopIncrPN->getParent());
+  FreshBBs.insert(L->getLoopLatch());
+  FreshBBs.insert(Rem->getParent());
+
+  replaceAllUsesWith(Rem, NewRem, FreshBBs, IsHuge);
+  Rem->eraseFromParent();
+  return true;
+}
+
+// Entry point for `urem` specific optimizations. Returns true if `Rem` was
+// replaced, in which case it has been erased and must not be used again.
+bool CodeGenPrepare::optimizeURem(Instruction *Rem) {
+  return foldURemOfLoopIncrement(Rem, DL, LI, FreshBBs, IsHugeFunc);
+}
+
 bool CodeGenPrepare::optimizeCmp(CmpInst *Cmp, ModifyDT &ModifiedDT) {
   if (sinkCmpExpression(Cmp, *TLI))
     return true;
@@ -8358,6 +8503,10 @@ bool CodeGenPrepare::optimizeInst(Instruction *I, ModifyDT &ModifiedDT) {
     if (optimizeCmp(Cmp, ModifiedDT))
       return true;
 
+  if (match(I, m_URem(m_Value(), m_Value())))
+    if (optimizeURem(I))
+      return true;
+
   if (LoadInst *LI = dyn_cast<LoadInst>(I)) {
     LI->setMetadata(LLVMContext::MD_invariant_group, nullptr);
     bool Modified = optimizeLoadExt(LI);
```