aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorButygin <ivan.butygin@intel.com>2021-10-28 19:04:35 +0300
committerButygin <ivan.butygin@intel.com>2021-12-20 21:53:03 +0300
commitc7f96d5ab188bf371f8096ed0a98f91f18a5435a (patch)
tree7c2d3e29049f7e03ac8c04185ab8131bc9833cad
parentde904900600f11a65cdf44023061600b668e9df5 (diff)
downloadllvm-c7f96d5ab188bf371f8096ed0a98f91f18a5435a.zip
llvm-c7f96d5ab188bf371f8096ed0a98f91f18a5435a.tar.gz
llvm-c7f96d5ab188bf371f8096ed0a98f91f18a5435a.tar.bz2
[mlir][scf] Canonicalize nested scf.if's to scf.if + arith.and
Differential Revision: https://reviews.llvm.org/D115930
-rw-r--r--mlir/lib/Dialect/SCF/SCF.cpp54
-rw-r--r--mlir/test/Dialect/SCF/canonicalize.mlir18
2 files changed, 68 insertions, 4 deletions
diff --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp
index 95a4c9d..534c0f0 100644
--- a/mlir/lib/Dialect/SCF/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/SCF.cpp
@@ -1596,14 +1596,60 @@ struct RemoveEmptyElseBranch : public OpRewritePattern<IfOp> {
}
};
+/// Convert nested `if`s into `arith.andi` + single `if`.
+///
+/// scf.if %arg0 {
+/// scf.if %arg1 {
+/// ...
+/// scf.yield
+/// }
+/// scf.yield
+/// }
+/// becomes
+///
+/// %0 = arith.andi %arg0, %arg1
+/// scf.if %0 {
+/// ...
+/// scf.yield
+/// }
+struct CombineNestedIfs : public OpRewritePattern<IfOp> {
+ using OpRewritePattern<IfOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(IfOp op,
+ PatternRewriter &rewriter) const override {
+ // Both `if` ops must not yield results and have only `then` block.
+ if (op->getNumResults() != 0 || op.elseBlock())
+ return failure();
+
+ auto nestedOps = op.thenBlock()->without_terminator();
+ // Nested `if` must be the only op in block.
+ if (!llvm::hasSingleElement(nestedOps))
+ return failure();
+
+ auto nestedIf = dyn_cast<IfOp>(*nestedOps.begin());
+ if (!nestedIf || nestedIf->getNumResults() != 0 || nestedIf.elseBlock())
+ return failure();
+
+ Location loc = op.getLoc();
+ Value newCondition = rewriter.create<arith::AndIOp>(loc, op.condition(),
+ nestedIf.condition());
+ auto newIf = rewriter.create<IfOp>(loc, newCondition);
+ Block *newIfBlock = newIf.thenBlock();
+ rewriter.eraseOp(newIfBlock->getTerminator());
+ rewriter.mergeBlocks(nestedIf.thenBlock(), newIfBlock);
+ rewriter.eraseOp(op);
+ return success();
+ }
+};
+
} // namespace
void IfOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results
- .add<RemoveUnusedResults, RemoveStaticCondition, ConvertTrivialIfToSelect,
- ConditionPropagation, ReplaceIfYieldWithConditionOrValue, CombineIfs,
- RemoveEmptyElseBranch>(context);
+ results.add<CombineIfs, CombineNestedIfs, ConditionPropagation,
+ ConvertTrivialIfToSelect, RemoveEmptyElseBranch,
+ RemoveStaticCondition, RemoveUnusedResults,
+ ReplaceIfYieldWithConditionOrValue>(context);
}
Block *IfOp::thenBlock() { return &getThenRegion().back(); }
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index 8d90303..d946c55 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -429,6 +429,24 @@ func @replace_false_if_with_values() {
// -----
+// CHECK-LABEL: @merge_nested_if
+// CHECK-SAME: (%[[ARG0:.*]]: i1, %[[ARG1:.*]]: i1)
+func @merge_nested_if(%arg0: i1, %arg1: i1) {
+// CHECK: %[[COND:.*]] = arith.andi %[[ARG0]], %[[ARG1]]
+// CHECK: scf.if %[[COND]] {
+// CHECK-NEXT: "test.op"()
+ scf.if %arg0 {
+ scf.if %arg1 {
+ "test.op"() : () -> ()
+ scf.yield
+ }
+ scf.yield
+ }
+ return
+}
+
+// -----
+
// CHECK-LABEL: @remove_zero_iteration_loop
func @remove_zero_iteration_loop() {
%c42 = arith.constant 42 : index