diff options
author | Butygin <ivan.butygin@intel.com> | 2021-10-28 19:04:35 +0300 |
---|---|---|
committer | Butygin <ivan.butygin@intel.com> | 2021-12-20 21:53:03 +0300 |
commit | c7f96d5ab188bf371f8096ed0a98f91f18a5435a (patch) | |
tree | 7c2d3e29049f7e03ac8c04185ab8131bc9833cad | |
parent | de904900600f11a65cdf44023061600b668e9df5 (diff) | |
download | llvm-c7f96d5ab188bf371f8096ed0a98f91f18a5435a.zip llvm-c7f96d5ab188bf371f8096ed0a98f91f18a5435a.tar.gz llvm-c7f96d5ab188bf371f8096ed0a98f91f18a5435a.tar.bz2 |
[mlir][scf] Canonicalize nested scf.if's to scf.if + arith.and
Differential Revision: https://reviews.llvm.org/D115930
-rw-r--r-- | mlir/lib/Dialect/SCF/SCF.cpp | 54 | ||||
-rw-r--r-- | mlir/test/Dialect/SCF/canonicalize.mlir | 18 |
2 files changed, 68 insertions, 4 deletions
diff --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp index 95a4c9d..534c0f0 100644 --- a/mlir/lib/Dialect/SCF/SCF.cpp +++ b/mlir/lib/Dialect/SCF/SCF.cpp @@ -1596,14 +1596,60 @@ struct RemoveEmptyElseBranch : public OpRewritePattern<IfOp> { } }; +/// Convert nested `if`s into `arith.andi` + single `if`. +/// +/// scf.if %arg0 { +/// scf.if %arg1 { +/// ... +/// scf.yield +/// } +/// scf.yield +/// } +/// becomes +/// +/// %0 = arith.andi %arg0, %arg1 +/// scf.if %0 { +/// ... +/// scf.yield +/// } +struct CombineNestedIfs : public OpRewritePattern<IfOp> { + using OpRewritePattern<IfOp>::OpRewritePattern; + + LogicalResult matchAndRewrite(IfOp op, + PatternRewriter &rewriter) const override { + // Both `if` ops must not yield results and have only `then` block. + if (op->getNumResults() != 0 || op.elseBlock()) + return failure(); + + auto nestedOps = op.thenBlock()->without_terminator(); + // Nested `if` must be the only op in block. + if (!llvm::hasSingleElement(nestedOps)) + return failure(); + + auto nestedIf = dyn_cast<IfOp>(*nestedOps.begin()); + if (!nestedIf || nestedIf->getNumResults() != 0 || nestedIf.elseBlock()) + return failure(); + + Location loc = op.getLoc(); + Value newCondition = rewriter.create<arith::AndIOp>(loc, op.condition(), + nestedIf.condition()); + auto newIf = rewriter.create<IfOp>(loc, newCondition); + Block *newIfBlock = newIf.thenBlock(); + rewriter.eraseOp(newIfBlock->getTerminator()); + rewriter.mergeBlocks(nestedIf.thenBlock(), newIfBlock); + rewriter.eraseOp(op); + return success(); + } +}; + } // namespace void IfOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - results - .add<RemoveUnusedResults, RemoveStaticCondition, ConvertTrivialIfToSelect, - ConditionPropagation, ReplaceIfYieldWithConditionOrValue, CombineIfs, - RemoveEmptyElseBranch>(context); + results.add<CombineIfs, CombineNestedIfs, ConditionPropagation, + ConvertTrivialIfToSelect, RemoveEmptyElseBranch, + RemoveStaticCondition, RemoveUnusedResults, + ReplaceIfYieldWithConditionOrValue>(context); } Block *IfOp::thenBlock() { return &getThenRegion().back(); } diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir index 8d90303..d946c55 100644 --- a/mlir/test/Dialect/SCF/canonicalize.mlir +++ b/mlir/test/Dialect/SCF/canonicalize.mlir @@ -429,6 +429,24 @@ func @replace_false_if_with_values() { // ----- +// CHECK-LABEL: @merge_nested_if +// CHECK-SAME: (%[[ARG0:.*]]: i1, %[[ARG1:.*]]: i1) +func @merge_nested_if(%arg0: i1, %arg1: i1) { +// CHECK: %[[COND:.*]] = arith.andi %[[ARG0]], %[[ARG1]] +// CHECK: scf.if %[[COND]] { +// CHECK-NEXT: "test.op"() + scf.if %arg0 { + scf.if %arg1 { + "test.op"() : () -> () + scf.yield + } + scf.yield + } + return +} + +// ----- + // CHECK-LABEL: @remove_zero_iteration_loop func @remove_zero_iteration_loop() { %c42 = arith.constant 42 : index |