aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/CodeGen/CodeGenPrepare.cpp
diff options
context:
space:
mode:
authorNoah Goldstein <goldstein.w.n@gmail.com>2024-06-25 19:51:39 +0800
committerNoah Goldstein <goldstein.w.n@gmail.com>2024-08-18 15:58:24 -0700
commitc64ce8bf283120fd145a57d0e61f9697f719139d (patch)
tree40168c956007111e0dea5d78f045323da3f95881 /llvm/lib/CodeGen/CodeGenPrepare.cpp
parentf16125a13ce725b1e936468e08257c0fbb80c0fa (diff)
downloadllvm-c64ce8bf283120fd145a57d0e61f9697f719139d.zip
llvm-c64ce8bf283120fd145a57d0e61f9697f719139d.tar.gz
llvm-c64ce8bf283120fd145a57d0e61f9697f719139d.tar.bz2
[CodeGenPrepare] Folding `urem` with loop invariant value
``` for(i = Start; i < End; ++i) Rem = (i nuw+ IncrLoopInvariant) u% RemAmtLoopInvariant; ``` -> ``` Rem = (Start nuw+ IncrLoopInvariant) % RemAmtLoopInvariant; for(i = Start; i < End; ++i, ++rem) Rem = rem == RemAmtLoopInvariant ? 0 : Rem; ``` In its current state, only if `IncrLoopInvariant` and `Start` both being zero. Alive2 seemed unable to prove this (see: https://alive2.llvm.org/ce/z/ATGDp3 which is clearly wrong but still checks out...) so wrote an exhaustive test here: https://godbolt.org/z/WYa561388 Closes #96625
Diffstat (limited to 'llvm/lib/CodeGen/CodeGenPrepare.cpp')
-rw-r--r--llvm/lib/CodeGen/CodeGenPrepare.cpp131
1 files changed, 131 insertions, 0 deletions
diff --git a/llvm/lib/CodeGen/CodeGenPrepare.cpp b/llvm/lib/CodeGen/CodeGenPrepare.cpp
index 48253a6..72db1f4 100644
--- a/llvm/lib/CodeGen/CodeGenPrepare.cpp
+++ b/llvm/lib/CodeGen/CodeGenPrepare.cpp
@@ -472,6 +472,7 @@ private:
bool replaceMathCmpWithIntrinsic(BinaryOperator *BO, Value *Arg0, Value *Arg1,
CmpInst *Cmp, Intrinsic::ID IID);
bool optimizeCmp(CmpInst *Cmp, ModifyDT &ModifiedDT);
+ bool optimizeURem(Instruction *Rem);
bool combineToUSubWithOverflow(CmpInst *Cmp, ModifyDT &ModifiedDT);
bool combineToUAddWithOverflow(CmpInst *Cmp, ModifyDT &ModifiedDT);
void verifyBFIUpdates(Function &F);
@@ -1975,6 +1976,132 @@ static bool foldFCmpToFPClassTest(CmpInst *Cmp, const TargetLowering &TLI,
return true;
}
+static bool isRemOfLoopIncrementWithLoopInvariant(Instruction *Rem,
+ const LoopInfo *LI,
+ Value *&RemAmtOut,
+ PHINode *&LoopIncrPNOut) {
+ Value *Incr, *RemAmt;
+ // NB: If RemAmt is a power of 2 it *should* have been transformed by now.
+ if (!match(Rem, m_URem(m_Value(Incr), m_Value(RemAmt))))
+ return false;
+
+ // Find out loop increment PHI.
+ auto *PN = dyn_cast<PHINode>(Incr);
+ if (!PN)
+ return false;
+
+ // This isn't strictly necessary, what we really need is one increment and any
+ // amount of initial values all being the same.
+ if (PN->getNumIncomingValues() != 2)
+ return false;
+
+ // Only trivially analyzable loops.
+ Loop *L = LI->getLoopFor(Rem->getParent());
+ if (!L || !L->getLoopPreheader() || !L->getLoopLatch())
+ return false;
+
+ // Only works if the remainder amount is a loop invaraint
+ if (!L->isLoopInvariant(RemAmt))
+ return false;
+
+ // Is the PHI a loop increment?
+ auto LoopIncrInfo = getIVIncrement(PN, LI);
+ if (!LoopIncrInfo)
+ return false;
+
+ // getIVIncrement finds the loop at PN->getParent(). This might be a different
+ // loop from the loop with Rem->getParent().
+ if (L->getHeader() != PN->getParent())
+ return false;
+
+ // We need remainder_amount % increment_amount to be zero. Increment of one
+ // satisfies that without any special logic and is overwhelmingly the common
+ // case.
+ if (!match(LoopIncrInfo->second, m_One()))
+ return false;
+
+ // Need the increment to not overflow.
+ if (!match(LoopIncrInfo->first, m_NUWAdd(m_Value(), m_Value())))
+ return false;
+
+ // Set output variables.
+ RemAmtOut = RemAmt;
+ LoopIncrPNOut = PN;
+
+ return true;
+}
+
+// Try to transform:
+//
+// for(i = Start; i < End; ++i)
+// Rem = (i nuw+ IncrLoopInvariant) u% RemAmtLoopInvariant;
+//
+// ->
+//
+// Rem = (Start nuw+ IncrLoopInvariant) % RemAmtLoopInvariant;
+// for(i = Start; i < End; ++i, ++rem)
+// Rem = rem == RemAmtLoopInvariant ? 0 : Rem;
+//
+// Currently only implemented for `IncrLoopInvariant` being zero.
+static bool foldURemOfLoopIncrement(Instruction *Rem, const DataLayout *DL,
+ const LoopInfo *LI,
+ SmallSet<BasicBlock *, 32> &FreshBBs,
+ bool IsHuge) {
+ Value *RemAmt;
+ PHINode *LoopIncrPN;
+ if (!isRemOfLoopIncrementWithLoopInvariant(Rem, LI, RemAmt, LoopIncrPN))
+ return false;
+
+ // Only non-constant remainder as the extra IV is probably not profitable
+ // in that case.
+ //
+ // Potential TODO(1): `urem` of a const ends up as `mul` + `shift` + `add`. If
+ // we can rule out register pressure and ensure this `urem` is executed each
+ // iteration, its probably profitable to handle the const case as well.
+ //
+ // Potential TODO(2): Should we have a check for how "nested" this remainder
+ // operation is? The new code runs every iteration so if the remainder is
+ // guarded behind unlikely conditions this might not be worth it.
+ if (match(RemAmt, m_ImmConstant()))
+ return false;
+ Loop *L = LI->getLoopFor(Rem->getParent());
+
+ Value *Start = LoopIncrPN->getIncomingValueForBlock(L->getLoopPreheader());
+
+ // Create new remainder with induction variable.
+ Type *Ty = Rem->getType();
+ IRBuilder<> Builder(Rem->getContext());
+
+ Builder.SetInsertPoint(LoopIncrPN);
+ PHINode *NewRem = Builder.CreatePHI(Ty, 2);
+
+ Builder.SetInsertPoint(cast<Instruction>(
+ LoopIncrPN->getIncomingValueForBlock(L->getLoopLatch())));
+ // `(add (urem x, y), 1)` is always nuw.
+ Value *RemAdd = Builder.CreateNUWAdd(NewRem, ConstantInt::get(Ty, 1));
+ Value *RemCmp = Builder.CreateICmp(ICmpInst::ICMP_EQ, RemAdd, RemAmt);
+ Value *RemSel =
+ Builder.CreateSelect(RemCmp, Constant::getNullValue(Ty), RemAdd);
+
+ NewRem->addIncoming(Start, L->getLoopPreheader());
+ NewRem->addIncoming(RemSel, L->getLoopLatch());
+
+ // Insert all touched BBs.
+ FreshBBs.insert(LoopIncrPN->getParent());
+ FreshBBs.insert(L->getLoopLatch());
+ FreshBBs.insert(Rem->getParent());
+
+ replaceAllUsesWith(Rem, NewRem, FreshBBs, IsHuge);
+ Rem->eraseFromParent();
+ return true;
+}
+
+bool CodeGenPrepare::optimizeURem(Instruction *Rem) {
+ if (foldURemOfLoopIncrement(Rem, DL, LI, FreshBBs, IsHugeFunc))
+ return true;
+ return false;
+}
+
bool CodeGenPrepare::optimizeCmp(CmpInst *Cmp, ModifyDT &ModifiedDT) {
if (sinkCmpExpression(Cmp, *TLI))
return true;
@@ -8358,6 +8485,10 @@ bool CodeGenPrepare::optimizeInst(Instruction *I, ModifyDT &ModifiedDT) {
if (optimizeCmp(Cmp, ModifiedDT))
return true;
+ if (match(I, m_URem(m_Value(), m_Value())))
+ if (optimizeURem(I))
+ return true;
+
if (LoadInst *LI = dyn_cast<LoadInst>(I)) {
LI->setMetadata(LLVMContext::MD_invariant_group, nullptr);
bool Modified = optimizeLoadExt(LI);