aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/CodeGen/CodeGenPrepare.cpp
diff options
context:
space:
mode:
authorNoah Goldstein <goldstein.w.n@gmail.com>2024-06-25 19:51:39 +0800
committerNoah Goldstein <goldstein.w.n@gmail.com>2024-08-20 09:17:49 -0700
commite4c67ba67ee25e74bbcb719f90dba6e2e9ce41a0 (patch)
treead902cbc3b838a26012d97ae5fd9ca6c613daae8 /llvm/lib/CodeGen/CodeGenPrepare.cpp
parent9b25ad818c0b82fe4db8b43e9c9700805a2c7322 (diff)
downloadllvm-e4c67ba67ee25e74bbcb719f90dba6e2e9ce41a0.zip
llvm-e4c67ba67ee25e74bbcb719f90dba6e2e9ce41a0.tar.gz
llvm-e4c67ba67ee25e74bbcb719f90dba6e2e9ce41a0.tar.bz2
Recommit "[CodeGenPrepare] Folding `urem` with loop invariant value"
Was missing remainder on `Start` value. Also changed logic as as nikic suggested (getting loop from `PN` instead of `Rem`). The prior impl increased the complexity of the code and made debugging it more difficult. Closes #104877
Diffstat (limited to 'llvm/lib/CodeGen/CodeGenPrepare.cpp')
-rw-r--r--llvm/lib/CodeGen/CodeGenPrepare.cpp134
1 files changed, 134 insertions, 0 deletions
diff --git a/llvm/lib/CodeGen/CodeGenPrepare.cpp b/llvm/lib/CodeGen/CodeGenPrepare.cpp
index 48253a6..bf48c1f 100644
--- a/llvm/lib/CodeGen/CodeGenPrepare.cpp
+++ b/llvm/lib/CodeGen/CodeGenPrepare.cpp
@@ -472,6 +472,7 @@ private:
bool replaceMathCmpWithIntrinsic(BinaryOperator *BO, Value *Arg0, Value *Arg1,
CmpInst *Cmp, Intrinsic::ID IID);
bool optimizeCmp(CmpInst *Cmp, ModifyDT &ModifiedDT);
+ bool optimizeURem(Instruction *Rem);
bool combineToUSubWithOverflow(CmpInst *Cmp, ModifyDT &ModifiedDT);
bool combineToUAddWithOverflow(CmpInst *Cmp, ModifyDT &ModifiedDT);
void verifyBFIUpdates(Function &F);
@@ -1975,6 +1976,135 @@ static bool foldFCmpToFPClassTest(CmpInst *Cmp, const TargetLowering &TLI,
return true;
}
+static bool isRemOfLoopIncrementWithLoopInvariant(Instruction *Rem,
+ const LoopInfo *LI,
+ Value *&RemAmtOut,
+ PHINode *&LoopIncrPNOut) {
+ Value *Incr, *RemAmt;
+ // NB: If RemAmt is a power of 2 it *should* have been transformed by now.
+ if (!match(Rem, m_URem(m_Value(Incr), m_Value(RemAmt))))
+ return false;
+
+ // Find out loop increment PHI.
+ auto *PN = dyn_cast<PHINode>(Incr);
+ if (!PN)
+ return false;
+
+ // This isn't strictly necessary, what we really need is one increment and any
+ // amount of initial values all being the same.
+ if (PN->getNumIncomingValues() != 2)
+ return false;
+
+ // Only trivially analyzable loops.
+ Loop *L = LI->getLoopFor(PN->getParent());
+ if (!L || !L->getLoopPreheader() || !L->getLoopLatch())
+ return false;
+
+ // Req that the remainder is in the loop
+ if (!L->contains(Rem))
+ return false;
+
+ // Only works if the remainder amount is a loop invaraint
+ if (!L->isLoopInvariant(RemAmt))
+ return false;
+
+ // Is the PHI a loop increment?
+ auto LoopIncrInfo = getIVIncrement(PN, LI);
+ if (!LoopIncrInfo)
+ return false;
+
+ // We need remainder_amount % increment_amount to be zero. Increment of one
+ // satisfies that without any special logic and is overwhelmingly the common
+ // case.
+ if (!match(LoopIncrInfo->second, m_One()))
+ return false;
+
+ // Need the increment to not overflow.
+ if (!match(LoopIncrInfo->first, m_c_NUWAdd(m_Specific(PN), m_Value())))
+ return false;
+
+ // Set output variables.
+ RemAmtOut = RemAmt;
+ LoopIncrPNOut = PN;
+
+ return true;
+}
+
+// Try to transform:
+//
+// for(i = Start; i < End; ++i)
+// Rem = (i nuw+ IncrLoopInvariant) u% RemAmtLoopInvariant;
+//
+// ->
+//
+// Rem = (Start nuw+ IncrLoopInvariant) % RemAmtLoopInvariant;
+// for(i = Start; i < End; ++i, ++rem)
+// Rem = rem == RemAmtLoopInvariant ? 0 : Rem;
+//
+// Currently only implemented for `IncrLoopInvariant` being zero.
+static bool foldURemOfLoopIncrement(Instruction *Rem, const DataLayout *DL,
+ const LoopInfo *LI,
+ SmallSet<BasicBlock *, 32> &FreshBBs,
+ bool IsHuge) {
+ Value *RemAmt;
+ PHINode *LoopIncrPN;
+ if (!isRemOfLoopIncrementWithLoopInvariant(Rem, LI, RemAmt, LoopIncrPN))
+ return false;
+
+ // Only non-constant remainder as the extra IV is probably not profitable
+ // in that case.
+ //
+ // Potential TODO(1): `urem` of a const ends up as `mul` + `shift` + `add`. If
+ // we can rule out register pressure and ensure this `urem` is executed each
+ // iteration, its probably profitable to handle the const case as well.
+ //
+ // Potential TODO(2): Should we have a check for how "nested" this remainder
+ // operation is? The new code runs every iteration so if the remainder is
+ // guarded behind unlikely conditions this might not be worth it.
+ if (match(RemAmt, m_ImmConstant()))
+ return false;
+
+ Loop *L = LI->getLoopFor(LoopIncrPN->getParent());
+ Value *Start = LoopIncrPN->getIncomingValueForBlock(L->getLoopPreheader());
+ // If we can't fully optimize out the `rem`, skip this transform.
+ Start = simplifyURemInst(Start, RemAmt, *DL);
+ if (!Start)
+ return false;
+
+ // Create new remainder with induction variable.
+ Type *Ty = Rem->getType();
+ IRBuilder<> Builder(Rem->getContext());
+
+ Builder.SetInsertPoint(LoopIncrPN);
+ PHINode *NewRem = Builder.CreatePHI(Ty, 2);
+
+ Builder.SetInsertPoint(cast<Instruction>(
+ LoopIncrPN->getIncomingValueForBlock(L->getLoopLatch())));
+ // `(add (urem x, y), 1)` is always nuw.
+ Value *RemAdd = Builder.CreateNUWAdd(NewRem, ConstantInt::get(Ty, 1));
+ Value *RemCmp = Builder.CreateICmp(ICmpInst::ICMP_EQ, RemAdd, RemAmt);
+ Value *RemSel =
+ Builder.CreateSelect(RemCmp, Constant::getNullValue(Ty), RemAdd);
+
+ NewRem->addIncoming(Start, L->getLoopPreheader());
+ NewRem->addIncoming(RemSel, L->getLoopLatch());
+
+ // Insert all touched BBs.
+ FreshBBs.insert(LoopIncrPN->getParent());
+ FreshBBs.insert(L->getLoopLatch());
+ FreshBBs.insert(Rem->getParent());
+
+ replaceAllUsesWith(Rem, NewRem, FreshBBs, IsHuge);
+ Rem->eraseFromParent();
+ return true;
+}
+
+bool CodeGenPrepare::optimizeURem(Instruction *Rem) {
+ if (foldURemOfLoopIncrement(Rem, DL, LI, FreshBBs, IsHugeFunc))
+ return true;
+ return false;
+}
+
bool CodeGenPrepare::optimizeCmp(CmpInst *Cmp, ModifyDT &ModifiedDT) {
if (sinkCmpExpression(Cmp, *TLI))
return true;
@@ -8358,6 +8488,10 @@ bool CodeGenPrepare::optimizeInst(Instruction *I, ModifyDT &ModifiedDT) {
if (optimizeCmp(Cmp, ModifiedDT))
return true;
+ if (match(I, m_URem(m_Value(), m_Value())))
+ if (optimizeURem(I))
+ return true;
+
if (LoadInst *LI = dyn_cast<LoadInst>(I)) {
LI->setMetadata(LLVMContext::MD_invariant_group, nullptr);
bool Modified = optimizeLoadExt(LI);