diff options
author | Noah Goldstein <goldstein.w.n@gmail.com> | 2024-06-25 19:51:39 +0800 |
---|---|---|
committer | Noah Goldstein <goldstein.w.n@gmail.com> | 2024-08-20 09:17:49 -0700 |
commit | e4c67ba67ee25e74bbcb719f90dba6e2e9ce41a0 (patch) | |
tree | ad902cbc3b838a26012d97ae5fd9ca6c613daae8 /llvm/lib/CodeGen/CodeGenPrepare.cpp | |
parent | 9b25ad818c0b82fe4db8b43e9c9700805a2c7322 (diff) | |
download | llvm-e4c67ba67ee25e74bbcb719f90dba6e2e9ce41a0.zip llvm-e4c67ba67ee25e74bbcb719f90dba6e2e9ce41a0.tar.gz llvm-e4c67ba67ee25e74bbcb719f90dba6e2e9ce41a0.tar.bz2 |
Recommit "[CodeGenPrepare] Folding `urem` with loop invariant value"
Was missing remainder on `Start` value.
Also changed logic as nikic suggested (getting loop from `PN`
instead of `Rem`). The prior impl increased the complexity of the code
and made debugging it more difficult.
Closes #104877
Diffstat (limited to 'llvm/lib/CodeGen/CodeGenPrepare.cpp')
-rw-r--r-- | llvm/lib/CodeGen/CodeGenPrepare.cpp | 134 |
1 files changed, 134 insertions, 0 deletions
diff --git a/llvm/lib/CodeGen/CodeGenPrepare.cpp b/llvm/lib/CodeGen/CodeGenPrepare.cpp
index 48253a6..bf48c1f 100644
--- a/llvm/lib/CodeGen/CodeGenPrepare.cpp
+++ b/llvm/lib/CodeGen/CodeGenPrepare.cpp
@@ -472,6 +472,7 @@ private:
   bool replaceMathCmpWithIntrinsic(BinaryOperator *BO, Value *Arg0, Value *Arg1,
                                    CmpInst *Cmp, Intrinsic::ID IID);
   bool optimizeCmp(CmpInst *Cmp, ModifyDT &ModifiedDT);
+  bool optimizeURem(Instruction *Rem);
   bool combineToUSubWithOverflow(CmpInst *Cmp, ModifyDT &ModifiedDT);
   bool combineToUAddWithOverflow(CmpInst *Cmp, ModifyDT &ModifiedDT);
   void verifyBFIUpdates(Function &F);
@@ -1975,6 +1976,135 @@ static bool foldFCmpToFPClassTest(CmpInst *Cmp, const TargetLowering &TLI,
   return true;
 }
 
+static bool isRemOfLoopIncrementWithLoopInvariant(Instruction *Rem,
+                                                  const LoopInfo *LI,
+                                                  Value *&RemAmtOut,
+                                                  PHINode *&LoopIncrPNOut) {
+  Value *Incr, *RemAmt;
+  // NB: If RemAmt is a power of 2 it *should* have been transformed by now.
+  if (!match(Rem, m_URem(m_Value(Incr), m_Value(RemAmt))))
+    return false;
+
+  // Find the loop increment PHI: the dividend of the `urem` must be a PHI.
+  auto *PN = dyn_cast<PHINode>(Incr);
+  if (!PN)
+    return false;
+
+  // This isn't strictly necessary; what we really need is one increment and
+  // any number of initial values that are all the same.
+  if (PN->getNumIncomingValues() != 2)
+    return false;
+
+  // Only trivially analyzable loops (must have a preheader and a latch).
+  Loop *L = LI->getLoopFor(PN->getParent());
+  if (!L || !L->getLoopPreheader() || !L->getLoopLatch())
+    return false;
+
+  // Require that the remainder is in the loop.
+  if (!L->contains(Rem))
+    return false;
+
+  // Only works if the remainder amount is loop invariant.
+  if (!L->isLoopInvariant(RemAmt))
+    return false;
+
+  // Is the PHI a loop increment?
+  auto LoopIncrInfo = getIVIncrement(PN, LI);
+  if (!LoopIncrInfo)
+    return false;
+
+  // We need remainder_amount % increment_amount to be zero. Increment of one
+  // satisfies that without any special logic and is overwhelmingly the common
+  // case.
+  if (!match(LoopIncrInfo->second, m_One()))
+    return false;
+
+  // Need the increment to not overflow, i.e. it must be an `add nuw` of PN.
+  if (!match(LoopIncrInfo->first, m_c_NUWAdd(m_Specific(PN), m_Value())))
+    return false;
+
+  // Set output variables (only written once every check above has passed).
+  RemAmtOut = RemAmt;
+  LoopIncrPNOut = PN;
+
+  return true;
+}
+
+// Try to transform:
+//
+// for(i = Start; i < End; ++i)
+//    Rem = (i nuw+ IncrLoopInvariant) u% RemAmtLoopInvariant;
+//
+// ->
+//
+// rem = (Start nuw+ IncrLoopInvariant) u% RemAmtLoopInvariant;
+// for(i = Start; i < End; ++i, ++rem)
+//    Rem = rem == RemAmtLoopInvariant ? (rem = 0) : rem;
+//
+// Currently only implemented for `IncrLoopInvariant` being zero.
+static bool foldURemOfLoopIncrement(Instruction *Rem, const DataLayout *DL,
+                                    const LoopInfo *LI,
+                                    SmallSet<BasicBlock *, 32> &FreshBBs,
+                                    bool IsHuge) {
+  Value *RemAmt;
+  PHINode *LoopIncrPN;
+  if (!isRemOfLoopIncrementWithLoopInvariant(Rem, LI, RemAmt, LoopIncrPN))
+    return false;
+
+  // Only handle a non-constant remainder amount; for a constant amount the
+  // extra IV is probably not profitable.
+  //
+  // Potential TODO(1): `urem` of a const ends up as `mul` + `shift` + `add`. If
+  // we can rule out register pressure and ensure this `urem` is executed each
+  // iteration, it's probably profitable to handle the const case as well.
+  //
+  // Potential TODO(2): Should we have a check for how "nested" this remainder
+  // operation is? The new code runs every iteration so if the remainder is
+  // guarded behind unlikely conditions this might not be worth it.
+  if (match(RemAmt, m_ImmConstant()))
+    return false;
+
+  Loop *L = LI->getLoopFor(LoopIncrPN->getParent());
+  Value *Start = LoopIncrPN->getIncomingValueForBlock(L->getLoopPreheader());
+  // If `Start u% RemAmt` can't be fully simplified away, skip this transform.
+  Start = simplifyURemInst(Start, RemAmt, *DL);
+  if (!Start)
+    return false;
+
+  // Create the new remainder as an induction variable (PHI next to the IV).
+  Type *Ty = Rem->getType();
+  IRBuilder<> Builder(Rem->getContext());
+
+  Builder.SetInsertPoint(LoopIncrPN);
+  PHINode *NewRem = Builder.CreatePHI(Ty, 2);
+
+  Builder.SetInsertPoint(cast<Instruction>(
+      LoopIncrPN->getIncomingValueForBlock(L->getLoopLatch())));
+  // `(add (urem x, y), 1)` is always nuw.
+  Value *RemAdd = Builder.CreateNUWAdd(NewRem, ConstantInt::get(Ty, 1));
+  Value *RemCmp = Builder.CreateICmp(ICmpInst::ICMP_EQ, RemAdd, RemAmt);
+  Value *RemSel =
+      Builder.CreateSelect(RemCmp, Constant::getNullValue(Ty), RemAdd);
+
+  NewRem->addIncoming(Start, L->getLoopPreheader());
+  NewRem->addIncoming(RemSel, L->getLoopLatch());
+
+  // Record all touched BBs in FreshBBs.
+  FreshBBs.insert(LoopIncrPN->getParent());
+  FreshBBs.insert(L->getLoopLatch());
+  FreshBBs.insert(Rem->getParent());
+
+  replaceAllUsesWith(Rem, NewRem, FreshBBs, IsHuge);
+  Rem->eraseFromParent();
+  return true;
+}
+
+bool CodeGenPrepare::optimizeURem(Instruction *Rem) {
+  if (foldURemOfLoopIncrement(Rem, DL, LI, FreshBBs, IsHugeFunc))
+    return true;
+  return false;
+}
+
 bool CodeGenPrepare::optimizeCmp(CmpInst *Cmp, ModifyDT &ModifiedDT) {
   if (sinkCmpExpression(Cmp, *TLI))
     return true;
@@ -8358,6 +8488,10 @@ bool CodeGenPrepare::optimizeInst(Instruction *I, ModifyDT &ModifiedDT) {
     if (optimizeCmp(Cmp, ModifiedDT))
       return true;
 
+  if (match(I, m_URem(m_Value(), m_Value())))
+    if (optimizeURem(I))
+      return true;
+
   if (LoadInst *LI = dyn_cast<LoadInst>(I)) {
     LI->setMetadata(LLVMContext::MD_invariant_group, nullptr);
     bool Modified = optimizeLoadExt(LI); |