diff options
author | goldsteinn <35538541+goldsteinn@users.noreply.github.com> | 2024-10-31 07:14:33 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-10-31 09:14:33 -0500 |
commit | 1e072ae289d77c3c9704a9ae832c076a303c435b (patch) | |
tree | 96d774b6e3dec3845a9f5e5b2fcdc94a100b1a1e /llvm/lib/CodeGen/CodeGenPrepare.cpp | |
parent | 0ab44fd2464354dfdca0e7afacbb21a84bca46d9 (diff) | |
download | llvm-1e072ae289d77c3c9704a9ae832c076a303c435b.zip llvm-1e072ae289d77c3c9704a9ae832c076a303c435b.tar.gz llvm-1e072ae289d77c3c9704a9ae832c076a303c435b.tar.bz2 |
[CGP] [CodeGenPrepare] Folding `urem` with loop invariant value plus offset (#104724)
This extends the existing fold:
```
for(i = Start; i < End; ++i)
Rem = (i nuw+- IncrLoopInvariant) u% RemAmtLoopInvariant;
```
->
```
Rem = (Start nuw+- IncrLoopInvariant) % RemAmtLoopInvariant;
for(i = Start; i < End; ++i, ++rem)
Rem = rem == RemAmtLoopInvariant ? 0 : Rem;
```
To work with a non-zero `IncrLoopInvariant`.
This is a common usage in cases such as:
```
for(i = 0; i < N; ++i)
if ((i + 1) % X) == 0)
do_something_occasionally_but_not_first_iter();
```
Alive2 w/ i4/unrolled 6x (needs to be ran locally due to timeout):
https://alive2.llvm.org/ce/z/6tgyN3
Exhaust proof over all uint8_t combinations in C++:
https://godbolt.org/z/WYa561388
Diffstat (limited to 'llvm/lib/CodeGen/CodeGenPrepare.cpp')
-rw-r--r-- | llvm/lib/CodeGen/CodeGenPrepare.cpp | 58 |
1 files changed, 49 insertions, 9 deletions
diff --git a/llvm/lib/CodeGen/CodeGenPrepare.cpp b/llvm/lib/CodeGen/CodeGenPrepare.cpp index 5224a6c..f1ac3d9 100644 --- a/llvm/lib/CodeGen/CodeGenPrepare.cpp +++ b/llvm/lib/CodeGen/CodeGenPrepare.cpp @@ -1981,17 +1981,36 @@ static bool foldFCmpToFPClassTest(CmpInst *Cmp, const TargetLowering &TLI, return true; } -static bool isRemOfLoopIncrementWithLoopInvariant(Instruction *Rem, - const LoopInfo *LI, - Value *&RemAmtOut, - PHINode *&LoopIncrPNOut) { +static bool isRemOfLoopIncrementWithLoopInvariant( + Instruction *Rem, const LoopInfo *LI, Value *&RemAmtOut, Value *&AddInstOut, + Value *&AddOffsetOut, PHINode *&LoopIncrPNOut) { Value *Incr, *RemAmt; // NB: If RemAmt is a power of 2 it *should* have been transformed by now. if (!match(Rem, m_URem(m_Value(Incr), m_Value(RemAmt)))) return false; + Value *AddInst, *AddOffset; // Find out loop increment PHI. auto *PN = dyn_cast<PHINode>(Incr); + if (PN != nullptr) { + AddInst = nullptr; + AddOffset = nullptr; + } else { + // Search through a NUW add on top of the loop increment. + Value *V0, *V1; + if (!match(Incr, m_NUWAdd(m_Value(V0), m_Value(V1)))) + return false; + + AddInst = Incr; + PN = dyn_cast<PHINode>(V0); + if (PN != nullptr) { + AddOffset = V1; + } else { + PN = dyn_cast<PHINode>(V1); + AddOffset = V0; + } + } + if (!PN) return false; @@ -2031,6 +2050,8 @@ static bool isRemOfLoopIncrementWithLoopInvariant(Instruction *Rem, // Set output variables. RemAmtOut = RemAmt; LoopIncrPNOut = PN; + AddInstOut = AddInst; + AddOffsetOut = AddOffset; return true; } @@ -2045,15 +2066,14 @@ static bool isRemOfLoopIncrementWithLoopInvariant(Instruction *Rem, // Rem = (Start nuw+ IncrLoopInvariant) % RemAmtLoopInvariant; // for(i = Start; i < End; ++i, ++rem) // Rem = rem == RemAmtLoopInvariant ? 0 : Rem; -// -// Currently only implemented for `IncrLoopInvariant` being zero. static bool foldURemOfLoopIncrement(Instruction *Rem, const DataLayout *DL, const LoopInfo *LI, SmallSet<BasicBlock *, 32> &FreshBBs, bool IsHuge) { - Value *RemAmt; + Value *AddOffset, *RemAmt, *AddInst; PHINode *LoopIncrPN; - if (!isRemOfLoopIncrementWithLoopInvariant(Rem, LI, RemAmt, LoopIncrPN)) + if (!isRemOfLoopIncrementWithLoopInvariant(Rem, LI, RemAmt, AddInst, + AddOffset, LoopIncrPN)) return false; // Only non-constant remainder as the extra IV is probably not profitable @@ -2071,6 +2091,23 @@ static bool foldURemOfLoopIncrement(Instruction *Rem, const DataLayout *DL, Loop *L = LI->getLoopFor(LoopIncrPN->getParent()); Value *Start = LoopIncrPN->getIncomingValueForBlock(L->getLoopPreheader()); + // If we have add create initial value for remainder. + // The logic here is: + // (urem (add nuw Start, IncrLoopInvariant), RemAmtLoopInvariant + // + // Only proceed if the expression simplifies (otherwise we can't fully + // optimize out the urem). + if (AddInst) { + assert(AddOffset && "We found an add but missing values"); + // Without dom-condition/assumption cache we aren't likely to get much out + // of a context instruction. + Start = simplifyAddInst(Start, AddOffset, + match(AddInst, m_NSWAdd(m_Value(), m_Value())), + /*IsNUW=*/true, *DL); + if (!Start) + return false; + } + // If we can't fully optimize out the `rem`, skip this transform. Start = simplifyURemInst(Start, RemAmt, *DL); if (!Start) @@ -2098,9 +2135,12 @@ static bool foldURemOfLoopIncrement(Instruction *Rem, const DataLayout *DL, FreshBBs.insert(LoopIncrPN->getParent()); FreshBBs.insert(L->getLoopLatch()); FreshBBs.insert(Rem->getParent()); - + if (AddInst) + FreshBBs.insert(cast<Instruction>(AddInst)->getParent()); replaceAllUsesWith(Rem, NewRem, FreshBBs, IsHuge); Rem->eraseFromParent(); + if (AddInst && AddInst->use_empty()) + cast<Instruction>(AddInst)->eraseFromParent(); return true; } |