aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/Transforms/Scalar/LoopTermFold.cpp
blob: d11af1e10e38fa18f990dbea1490386852e09408 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
//===- LoopTermFold.cpp - Eliminate last use of IV in exit branch----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
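//===----------------------------------------------------------------------===//
//
// This file implements the loop terminator folding pass. When the latch exit
// test of an innermost loop is the last remaining use of an induction
// variable, the test is rewritten as an equality comparison of another
// induction variable against its final value (computed in the preheader), so
// that the original induction variable can be deleted.
//
//===----------------------------------------------------------------------===//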

#include "llvm/Transforms/Scalar/LoopTermFold.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/LoopAnalysisManager.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/LoopPass.h"
#include "llvm/Analysis/MemorySSA.h"
#include "llvm/Analysis/MemorySSAUpdater.h"
#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/Config/llvm-config.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Value.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/Local.h"
#include "llvm/Transforms/Utils/LoopUtils.h"
#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
#include <cassert>
#include <optional>

using namespace llvm;

// Debug/statistic category name for everything in this file.
#define DEBUG_TYPE "loop-term-fold"

// Incremented once for every loop whose terminating condition is folded
// (see RunTermFold).
STATISTIC(NumTermFold,
          "Number of terminating condition fold recognized and performed");

/// Decide whether the latch exit test of \p L can be rewritten in terms of a
/// header phi other than the one it currently uses.
///
/// The loop must be innermost, in loop-simplify form, and have a loop
/// invariant backedge-taken count. Its latch must end in a conditional branch
/// on a single-use icmp whose first operand is a simple recurrence rooted at a
/// header phi and whose second operand is loop invariant.
///
/// \returns std::nullopt when the fold is not possible. Otherwise a tuple
/// (ToFold, ToHelpFold, TermValueS, MustDropPoison) where
///   - ToFold is the header phi feeding the current exit test,
///   - ToHelpFold is the alternate header phi selected to drive the new test,
///   - TermValueS is the SCEV of ToHelpFold's post-increment value evaluated
///     at the backedge-taken count, and
///   - MustDropPoison tells the caller that the poison-generating flags on
///     ToHelpFold's latch incoming instruction have to be dropped.
static std::optional<std::tuple<PHINode *, PHINode *, const SCEV *, bool>>
canFoldTermCondOfLoop(Loop *L, ScalarEvolution &SE, DominatorTree &DT,
                      const LoopInfo &LI, const TargetTransformInfo &TTI) {
  if (!L->isInnermost()) {
    LLVM_DEBUG(dbgs() << "Cannot fold on non-innermost loop\n");
    return std::nullopt;
  }
  // Only inspect on simple loop structure
  if (!L->isLoopSimplifyForm()) {
    LLVM_DEBUG(dbgs() << "Cannot fold on non-simple loop\n");
    return std::nullopt;
  }

  if (!SE.hasLoopInvariantBackedgeTakenCount(L)) {
    LLVM_DEBUG(dbgs() << "Cannot fold on backedge that is loop variant\n");
    return std::nullopt;
  }

  BasicBlock *LoopLatch = L->getLoopLatch();
  BranchInst *BI = dyn_cast<BranchInst>(LoopLatch->getTerminator());
  if (!BI || BI->isUnconditional())
    return std::nullopt;
  auto *TermCond = dyn_cast<ICmpInst>(BI->getCondition());
  if (!TermCond) {
    LLVM_DEBUG(
        dbgs()
        << "Cannot fold on branching condition that is not an ICmpInst\n");
    return std::nullopt;
  }
  if (!TermCond->hasOneUse()) {
    LLVM_DEBUG(
        dbgs()
        << "Cannot replace terminating condition with more than one use\n");
    return std::nullopt;
  }

  BinaryOperator *LHS = dyn_cast<BinaryOperator>(TermCond->getOperand(0));
  Value *RHS = TermCond->getOperand(1);
  if (!LHS || !L->isLoopInvariant(RHS))
    // We could pattern match the inverse form of the icmp, but that is
    // non-canonical, and this pass is running *very* late in the pipeline.
    return std::nullopt;

  // Find the IV used by the current exit condition.
  PHINode *ToFold = nullptr;
  Value *ToFoldStart = nullptr, *ToFoldStep = nullptr;
  if (!matchSimpleRecurrence(LHS, ToFold, ToFoldStart, ToFoldStep))
    return std::nullopt;

  // Ensure the simple recurrence is a part of the current loop.
  if (ToFold->getParent() != L->getHeader())
    return std::nullopt;

  // If that IV isn't dead after we rewrite the exit condition in terms of
  // another IV, there's no point in doing the transform.
  if (!isAlmostDeadIV(ToFold, LoopLatch, TermCond))
    return std::nullopt;

  // Inserting instructions in the preheader has a runtime cost, scale
  // the allowed cost with the loop's trip count as best we can.
  const unsigned ExpansionBudget = [&]() {
    unsigned Budget = 2 * SCEVCheapExpansionBudget;
    if (unsigned SmallTC = SE.getSmallConstantMaxTripCount(L))
      return std::min(Budget, SmallTC);
    if (std::optional<unsigned> SmallTC = getLoopEstimatedTripCount(L))
      return std::min(Budget, *SmallTC);
    // Unknown trip count, assume long running by default.
    return Budget;
  }();

  const SCEV *BECount = SE.getBackedgeTakenCount(L);
  const DataLayout &DL = L->getHeader()->getDataLayout();
  SCEVExpander Expander(SE, DL, "lsr_fold_term_cond");

  PHINode *ToHelpFold = nullptr;
  const SCEV *TermValueS = nullptr;
  bool MustDropPoison = false;
  auto InsertPt = L->getLoopPreheader()->getTerminator();
  for (PHINode &PN : L->getHeader()->phis()) {
    if (ToFold == &PN)
      continue;

    if (!SE.isSCEVable(PN.getType())) {
      LLVM_DEBUG(dbgs() << "IV of phi '" << PN
                        << "' is not SCEV-able, not qualified for the "
                           "terminating condition folding.\n");
      continue;
    }
    const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(SE.getSCEV(&PN));
    // Only speculate on affine AddRec
    if (!AddRec || !AddRec->isAffine()) {
      LLVM_DEBUG(dbgs() << "SCEV of phi '" << PN
                        << "' is not an affine add recursion, not qualified "
                           "for the terminating condition folding.\n");
      continue;
    }

    // Check that we can compute the value of AddRec on the exiting iteration
    // without soundness problems.  evaluateAtIteration internally needs
    // to multiply the stride of the iteration number - which may wrap around.
    // The issue here is subtle because computing the result accounting for
    // wrap is insufficient. In order to use the result in an exit test, we
    // must also know that AddRec doesn't take the same value on any previous
    // iteration. The simplest case to consider is a candidate IV which is
    // narrower than the trip count (and thus original IV), but this can
    // also happen due to non-unit strides on the candidate IVs.
    if (!AddRec->hasNoSelfWrap() ||
        !SE.isKnownNonZero(AddRec->getStepRecurrence(SE)))
      continue;

    // The new exit test compares the post-increment value, so evaluate the
    // post-increment recurrence at the backedge-taken count.
    const SCEVAddRecExpr *PostInc = AddRec->getPostIncExpr(SE);
    const SCEV *TermValueSLocal = PostInc->evaluateAtIteration(BECount, SE);
    if (!Expander.isSafeToExpand(TermValueSLocal)) {
      LLVM_DEBUG(
          dbgs() << "Is not safe to expand terminating value for phi node" << PN
                 << "\n");
      continue;
    }

    if (Expander.isHighCostExpansion(TermValueSLocal, L, ExpansionBudget, &TTI,
                                     InsertPt)) {
      LLVM_DEBUG(
          dbgs() << "Is too expensive to expand terminating value for phi node"
                 << PN << "\n");
      continue;
    }

    // The candidate IV may have been otherwise dead and poison from the
    // very first iteration.  If we can't disprove that, we can't use the IV.
    if (!mustExecuteUBIfPoisonOnPathTo(&PN, LoopLatch->getTerminator(), &DT)) {
      LLVM_DEBUG(dbgs() << "Can not prove poison safety for IV " << PN << "\n");
      continue;
    }

    // The candidate IV may become poison on the last iteration.  If this
    // value is not branched on, this is a well defined program.  We're
    // about to add a new use to this IV, and we have to ensure we don't
    // insert UB which didn't previously exist.
    bool MustDropPoisonLocal = false;
    Instruction *PostIncV =
        cast<Instruction>(PN.getIncomingValueForBlock(LoopLatch));
    if (!mustExecuteUBIfPoisonOnPathTo(PostIncV, LoopLatch->getTerminator(),
                                       &DT)) {
      LLVM_DEBUG(dbgs() << "Can not prove poison safety to insert use" << PN
                        << "\n");

      // If this is a complex recurrence with multiple instructions computing
      // the backedge value, we might need to strip poison flags from all of
      // them.
      if (PostIncV->getOperand(0) != &PN)
        continue;

      // In order to perform the transform, we need to drop the poison
      // generating flags on this instruction (if any).
      MustDropPoisonLocal = PostIncV->hasPoisonGeneratingFlags();
    }

    // We pick the last legal alternate IV.  We could explore choosing an
    // optimal alternate IV if we had a decent heuristic to do so.
    ToHelpFold = &PN;
    TermValueS = TermValueSLocal;
    MustDropPoison = MustDropPoisonLocal;
  }

  // ToFold is known to be non-null here (it was dereferenced above), so the
  // outcome only depends on whether an alternate IV was found.
  if (!ToHelpFold) {
    LLVM_DEBUG(dbgs() << "Cannot find other AddRec IV to help folding\n");
    return std::nullopt;
  }

  LLVM_DEBUG(dbgs() << "\nFound loop that can fold terminating condition\n"
                    << "  BECount (SCEV): " << *BECount << "\n"
                    << "  TermCond: " << *TermCond << "\n"
                    << "  BranchInst: " << *BI << "\n"
                    << "  ToFold: " << *ToFold << "\n"
                    << "  ToHelpFold: " << *ToHelpFold << "\n");

  return std::make_tuple(ToFold, ToHelpFold, TermValueS, MustDropPoison);
}

/// Shared driver for the legacy and new pass manager entry points.
///
/// Asks canFoldTermCondOfLoop whether the latch exit test of \p L can be
/// expressed through another induction variable. If so, it expands that
/// variable's final value in the preheader, replaces the latch branch
/// condition with an equality test against it, and deletes the old compare
/// and any header phis that became dead.
///
/// \param MSSA may be null; when present it is kept up to date while dead
///        phis are deleted.
/// \returns true if the loop was modified.
static bool RunTermFold(Loop *L, ScalarEvolution &SE, DominatorTree &DT,
                        LoopInfo &LI, const TargetTransformInfo &TTI,
                        TargetLibraryInfo &TLI, MemorySSA *MSSA) {
  auto Opt = canFoldTermCondOfLoop(L, SE, DT, LI, TTI);
  if (!Opt)
    return false;

  auto [ToFold, ToHelpFold, TermValueS, MustDrop] = *Opt;

  NumTermFold++;

  BasicBlock *LoopPreheader = L->getLoopPreheader();
  BasicBlock *LoopLatch = L->getLoopLatch();

  // ToFold is only referenced from debug output.
  (void)ToFold;
  LLVM_DEBUG(dbgs() << "To fold phi-node:\n"
                    << *ToFold << "\n"
                    << "New term-cond phi-node:\n"
                    << *ToHelpFold << "\n");

  // StartValue is only referenced from debug output.
  Value *StartValue = ToHelpFold->getIncomingValueForBlock(LoopPreheader);
  (void)StartValue;
  Value *LoopValue = ToHelpFold->getIncomingValueForBlock(LoopLatch);

  // See comment in canFoldTermCondOfLoop on why this is sufficient.
  if (MustDrop)
    cast<Instruction>(LoopValue)->dropPoisonGeneratingFlags();

  // SCEVExpander for both use in preheader and latch
  const DataLayout &DL = L->getHeader()->getDataLayout();
  SCEVExpander Expander(SE, DL, "lsr_fold_term_cond");

  assert(Expander.isSafeToExpand(TermValueS) &&
         "Terminating value was checked safe in canFoldTermCondOfLoop");

  // Create new terminating value at loop preheader
  Value *TermValue = Expander.expandCodeFor(TermValueS, ToHelpFold->getType(),
                                            LoopPreheader->getTerminator());

  LLVM_DEBUG(dbgs() << "Start value of new term-cond phi-node:\n"
                    << *StartValue << "\n"
                    << "Terminating value of new term-cond phi-node:\n"
                    << *TermValue << "\n");

  // Create new terminating condition at loop latch
  BranchInst *BI = cast<BranchInst>(LoopLatch->getTerminator());
  ICmpInst *OldTermCond = cast<ICmpInst>(BI->getCondition());
  IRBuilder<> LatchBuilder(LoopLatch->getTerminator());
  Value *NewTermCond =
      LatchBuilder.CreateICmp(CmpInst::ICMP_EQ, LoopValue, TermValue,
                              "lsr_fold_term_cond.replaced_term_cond");
  // Swap successors to exit loop body if IV equals to new TermValue
  if (BI->getSuccessor(0) == L->getHeader())
    BI->swapSuccessors();

  LLVM_DEBUG(dbgs() << "Old term-cond:\n"
                    << *OldTermCond << "\n"
                    << "New term-cond:\n"
                    << *NewTermCond << "\n");

  BI->setCondition(NewTermCond);

  Expander.clear();
  OldTermCond->eraseFromParent();

  // The updater is only needed for the dead-phi cleanup, so it is created
  // here rather than up front for every loop that turns out not to fold.
  std::unique_ptr<MemorySSAUpdater> MSSAU;
  if (MSSA)
    MSSAU = std::make_unique<MemorySSAUpdater>(MSSA);
  DeleteDeadPHIs(L->getHeader(), &TLI, MSSAU.get());
  return true;
}

namespace {

/// Legacy pass manager wrapper around the terminator folding transform.
class LoopTermFold : public LoopPass {
  // LoopPass hooks; only the pass manager calls these, so they stay private.
  void getAnalysisUsage(AnalysisUsage &AU) const override;
  bool runOnLoop(Loop *L, LPPassManager &LPM) override;

public:
  LoopTermFold();

  /// Pass identification, replacement for typeid.
  static char ID;
};

} // end anonymous namespace

/// Construct the legacy pass and make sure it is registered with the global
/// pass registry.
LoopTermFold::LoopTermFold() : LoopPass(ID) {
  PassRegistry &Registry = *PassRegistry::getPassRegistry();
  initializeLoopTermFoldPass(Registry);
}

/// Declare the analyses the legacy pass needs and the ones it keeps valid.
void LoopTermFold::getAnalysisUsage(AnalysisUsage &AU) const {
  // Loop structure, loop-simplify form, dominators and scalar evolution are
  // each both required and preserved.
  AU.addRequired<LoopInfoWrapperPass>();
  AU.addPreserved<LoopInfoWrapperPass>();
  AU.addPreservedID(LoopSimplifyID);
  AU.addRequiredID(LoopSimplifyID);
  AU.addRequired<DominatorTreeWrapperPass>();
  AU.addPreserved<DominatorTreeWrapperPass>();
  AU.addRequired<ScalarEvolutionWrapperPass>();
  AU.addPreserved<ScalarEvolutionWrapperPass>();
  // Target information is required but not listed as preserved.
  AU.addRequired<TargetLibraryInfoWrapperPass>();
  AU.addRequired<TargetTransformInfoWrapperPass>();
  // MemorySSA is not required; it is only marked as preserved.
  AU.addPreserved<MemorySSAWrapperPass>();
}

/// Legacy pass manager entry point: collect the analyses and hand the loop to
/// RunTermFold. Returns true if the loop was changed.
bool LoopTermFold::runOnLoop(Loop *L, LPPassManager & /*LPM*/) {
  if (skipLoop(L))
    return false;

  // Function-level analyses are looked up through the enclosing function.
  auto &F = *L->getHeader()->getParent();

  auto &SE = getAnalysis<ScalarEvolutionWrapperPass>().getSE();
  auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
  auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
  const auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
  auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);

  // MemorySSA is optional: use it only if it is already available.
  auto *MSSAWrapper = getAnalysisIfAvailable<MemorySSAWrapperPass>();
  MemorySSA *MSSA = MSSAWrapper ? &MSSAWrapper->getMSSA() : nullptr;

  return RunTermFold(L, SE, DT, LI, TTI, TLI, MSSA);
}

/// New pass manager entry point. Reports every analysis as preserved when the
/// loop is left untouched; otherwise reports the standard loop-pass set, plus
/// MemorySSA when it was supplied.
PreservedAnalyses LoopTermFoldPass::run(Loop &L, LoopAnalysisManager &AM,
                                        LoopStandardAnalysisResults &AR,
                                        LPMUpdater &) {
  const bool Changed =
      RunTermFold(&L, AR.SE, AR.DT, AR.LI, AR.TTI, AR.TLI, AR.MSSA);
  if (!Changed)
    return PreservedAnalyses::all();

  PreservedAnalyses Preserved = getLoopPassPreservedAnalyses();
  if (AR.MSSA)
    Preserved.preserve<MemorySSAAnalysis>();
  return Preserved;
}

// Definition of the pass identifier declared in LoopTermFold; it is handed to
// the LoopPass base class by the constructor.
char LoopTermFold::ID = 0;

// Register the legacy pass under the name "loop-term-fold" and list the passes
// that must be initialized along with it. The list mirrors the analyses that
// LoopTermFold::getAnalysisUsage marks as required, including
// TargetLibraryInfoWrapperPass.
INITIALIZE_PASS_BEGIN(LoopTermFold, "loop-term-fold", "Loop Terminator Folding",
                      false, false)
INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass)
INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(LoopSimplify)
INITIALIZE_PASS_END(LoopTermFold, "loop-term-fold", "Loop Terminator Folding",
                    false, false)

/// Factory for the legacy pass manager; returns a newly allocated
/// LoopTermFold pass.
Pass *llvm::createLoopTermFoldPass() {
  LoopTermFold *NewPass = new LoopTermFold();
  return NewPass;
}