aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Dialect/SCF/Transforms/RotateWhileLoop.cpp
blob: 8707ec91328dc9a86a2228bfa66ea1a241aa5220 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
//===- RotateWhileLoop.cpp - scf.while loop rotation ----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Rotates `scf.while` loops.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SCF/Transforms/Patterns.h"

#include "mlir/Dialect/SCF/IR/SCF.h"

using namespace mlir;

namespace {
/// Rewrite pattern that rotates an `scf.while` op by delegating to
/// `scf::wrapWhileLoopInZeroTripCheck`.
struct RotateWhileLoopPattern : OpRewritePattern<scf::WhileOp> {
  using OpRewritePattern<scf::WhileOp>::OpRewritePattern;

  /// Attempts to wrap `whileOp` in a zero-trip check. Reports success only
  /// when the helper succeeded and handed back an op other than `whileOp`.
  LogicalResult matchAndRewrite(scf::WhileOp whileOp,
                                PatternRewriter &rewriter) const final {
    // The check is never forced: with forcing enabled, 'do-while' loops would
    // no longer be skipped and a greedy driver would recurse indefinitely.
    FailureOr<scf::WhileOp> rotated = scf::wrapWhileLoopInZeroTripCheck(
        whileOp, rewriter, /*forceCreateCheck=*/false);
    if (failed(rotated))
      return failure();

    // The helper has no failure mechanism yet and simply returns 'do-while'
    // loops unmodified. Getting the input op back therefore means nothing was
    // rewritten; reporting failure in that case is what stops the recursion.
    return success(*rotated != whileOp);
  }
};
} // namespace

namespace mlir {
namespace scf {
/// Adds the `scf.while` rotation pattern to `patterns`, constructing it with
/// the context owned by the pattern set.
void populateSCFRotateWhileLoopPatterns(RewritePatternSet &patterns) {
  auto *context = patterns.getContext();
  patterns.add<RotateWhileLoopPattern>(context);
}
} // namespace scf
} // namespace mlir