diff options
author | Oleksandr "Alex" Zinenko <git@ozinenko.com> | 2025-07-23 21:02:40 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2025-07-23 21:02:40 +0200 |
commit | 9d11accf95db0ed08bd3181c25dd75fc793d089d (patch) | |
tree | 9c831c87b6f18fa8802243f9df49c06cb9651d03 | |
parent | bc1f85d234a8e8e4d1bcfb2126e7c8ec8c8f5b3d (diff) | |
download | llvm-9d11accf95db0ed08bd3181c25dd75fc793d089d.zip llvm-9d11accf95db0ed08bd3181c25dd75fc793d089d.tar.gz llvm-9d11accf95db0ed08bd3181c25dd75fc793d089d.tar.bz2 |
[mlir] move if-condition propagation to a standalone pass (#150278)
This offers a significant speedup over running this as a
canonicalizaiton pattern, up to 10x improvement when running on large
(>100k operations) inputs coming from Polygeist.
It is also not clear whether this transformation is a reasonable
canonicalization as it performs non-local rewrites.
-rw-r--r-- | mlir/include/mlir/Dialect/SCF/Transforms/Passes.td | 6 | ||||
-rw-r--r-- | mlir/lib/Dialect/SCF/IR/SCF.cpp | 64 | ||||
-rw-r--r-- | mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt | 1 | ||||
-rw-r--r-- | mlir/lib/Dialect/SCF/Transforms/IfConditionPropagation.cpp | 98 | ||||
-rw-r--r-- | mlir/test/Dialect/SCF/canonicalize.mlir | 35 | ||||
-rw-r--r-- | mlir/test/Dialect/SCF/if-cond-prop.mlir | 34 |
6 files changed, 141 insertions, 97 deletions
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td index 6e5ef96..ca2510b 100644 --- a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td @@ -41,6 +41,12 @@ def SCFForLoopSpecialization : Pass<"scf-for-loop-specialization"> { let constructor = "mlir::createForLoopSpecializationPass()"; } +def SCFIfConditionPropagation : Pass<"scf-if-condition-propagation"> { + let summary = "Replace usages of if condition with true/false constants in " + "the conditional regions"; + let dependentDialects = ["arith::ArithDialect"]; +} + def SCFParallelLoopFusion : Pass<"scf-parallel-loop-fusion"> { let summary = "Fuse adjacent parallel loops"; let constructor = "mlir::createParallelLoopFusionPass()"; diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp index df41eba..72ab4b1 100644 --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -2414,65 +2414,6 @@ struct ConvertTrivialIfToSelect : public OpRewritePattern<IfOp> { } }; -/// Allow the true region of an if to assume the condition is true -/// and vice versa. For example: -/// -/// scf.if %cmp { -/// print(%cmp) -/// } -/// -/// becomes -/// -/// scf.if %cmp { -/// print(true) -/// } -/// -struct ConditionPropagation : public OpRewritePattern<IfOp> { - using OpRewritePattern<IfOp>::OpRewritePattern; - - LogicalResult matchAndRewrite(IfOp op, - PatternRewriter &rewriter) const override { - // Early exit if the condition is constant since replacing a constant - // in the body with another constant isn't a simplification. - if (matchPattern(op.getCondition(), m_Constant())) - return failure(); - - bool changed = false; - mlir::Type i1Ty = rewriter.getI1Type(); - - // These variables serve to prevent creating duplicate constants - // and hold constant true or false values. - Value constantTrue = nullptr; - Value constantFalse = nullptr; - - for (OpOperand &use : - llvm::make_early_inc_range(op.getCondition().getUses())) { - if (op.getThenRegion().isAncestor(use.getOwner()->getParentRegion())) { - changed = true; - - if (!constantTrue) - constantTrue = rewriter.create<arith::ConstantOp>( - op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 1)); - - rewriter.modifyOpInPlace(use.getOwner(), - [&]() { use.set(constantTrue); }); - } else if (op.getElseRegion().isAncestor( - use.getOwner()->getParentRegion())) { - changed = true; - - if (!constantFalse) - constantFalse = rewriter.create<arith::ConstantOp>( - op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 0)); - - rewriter.modifyOpInPlace(use.getOwner(), - [&]() { use.set(constantFalse); }); - } - } - - return success(changed); - } -}; - /// Remove any statements from an if that are equivalent to the condition /// or its negation. For example: /// @@ -2854,9 +2795,8 @@ struct CombineNestedIfs : public OpRewritePattern<IfOp> { void IfOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - results.add<CombineIfs, CombineNestedIfs, ConditionPropagation, - ConvertTrivialIfToSelect, RemoveEmptyElseBranch, - RemoveStaticCondition, RemoveUnusedResults, + results.add<CombineIfs, CombineNestedIfs, ConvertTrivialIfToSelect, + RemoveEmptyElseBranch, RemoveStaticCondition, RemoveUnusedResults, ReplaceIfYieldWithConditionOrValue>(context); } diff --git a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt index 84dd992..6d3bafb 100644 --- a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt @@ -4,6 +4,7 @@ add_mlir_dialect_library(MLIRSCFTransforms ForallToFor.cpp ForallToParallel.cpp ForToWhile.cpp + IfConditionPropagation.cpp LoopCanonicalization.cpp LoopPipelining.cpp LoopRangeFolding.cpp diff --git a/mlir/lib/Dialect/SCF/Transforms/IfConditionPropagation.cpp b/mlir/lib/Dialect/SCF/Transforms/IfConditionPropagation.cpp new file mode 100644 index 0000000..bdc5129 --- /dev/null +++ b/mlir/lib/Dialect/SCF/Transforms/IfConditionPropagation.cpp @@ -0,0 +1,98 @@ +//===- IfConditionPropagation.cpp -----------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains a pass for constant propagation of the condition of an +// `scf.if` into its then and else regions as true and false respectively. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/SCF/Transforms/Passes.h" + +using namespace mlir; + +namespace mlir { +#define GEN_PASS_DEF_SCFIFCONDITIONPROPAGATION +#include "mlir/Dialect/SCF/Transforms/Passes.h.inc" +} // namespace mlir + +/// Traverses the IR recursively (on region tree) and updates the uses of a +/// value also as the condition of an `scf.if` to either `true` or `false` +/// constants in the `then` and `else regions. This is done as a single +/// post-order sweep over the IR (without `walk`) for efficiency reasons. While +/// traversing, the function maintains the set of visited regions to quickly +/// identify whether the value belong to a region that is known to be nested in +/// the `then` or `else` branch of a specific loop. +static void propagateIfConditionsImpl(Operation *root, + llvm::SmallPtrSet<Region *, 8> &visited) { + if (auto scfIf = dyn_cast<scf::IfOp>(root)) { + llvm::SmallPtrSet<Region *, 8> thenChildren, elseChildren; + // Visit the "then" region, collect children. + for (Block &block : scfIf.getThenRegion()) { + for (Operation &op : block) { + propagateIfConditionsImpl(&op, thenChildren); + } + } + + // Visit the "else" region, collect children. + for (Block &block : scfIf.getElseRegion()) { + for (Operation &op : block) { + propagateIfConditionsImpl(&op, elseChildren); + } + } + + // Update uses to point to constants instead. + OpBuilder builder(scfIf); + Value trueValue = arith::ConstantIntOp::create(builder, scfIf.getLoc(), + /*value=*/true, /*width=*/1); + Value falseValue = + arith::ConstantIntOp::create(builder, scfIf.getLoc(), + /*value=*/false, /*width=*/1); + + for (OpOperand &use : scfIf.getCondition().getUses()) { + if (thenChildren.contains(use.getOwner()->getParentRegion())) + use.set(trueValue); + else if (elseChildren.contains(use.getOwner()->getParentRegion())) + use.set(falseValue); + } + if (trueValue.getUses().empty()) + trueValue.getDefiningOp()->erase(); + if (falseValue.getUses().empty()) + falseValue.getDefiningOp()->erase(); + + // Append the two lists of children and return them. + visited.insert_range(thenChildren); + visited.insert_range(elseChildren); + return; + } + + for (Region ®ion : root->getRegions()) { + for (Block &block : region) { + for (Operation &op : block) { + propagateIfConditionsImpl(&op, visited); + } + } + } +} + +/// Traverses the IR recursively (on region tree) and updates the uses of a +/// value also as the condition of an `scf.if` to either `true` or `false` +/// constants in the `then` and `else regions +static void propagateIfConditions(Operation *root) { + llvm::SmallPtrSet<Region *, 8> visited; + propagateIfConditionsImpl(root, visited); +} + +namespace { +/// Pass entrypoint. +struct SCFIfConditionPropagationPass + : impl::SCFIfConditionPropagationBase<SCFIfConditionPropagationPass> { + void runOnOperation() override { propagateIfConditions(getOperation()); } +}; +} // namespace diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir index 8ba8013..12d30e17 100644 --- a/mlir/test/Dialect/SCF/canonicalize.mlir +++ b/mlir/test/Dialect/SCF/canonicalize.mlir @@ -867,41 +867,6 @@ func.func @matmul_on_tensors(%t0: tensor<32x1024xf32>) -> tensor<?x?xf32> { // ----- -// CHECK-LABEL: @cond_prop -func.func @cond_prop(%arg0 : i1) -> index { - %res = scf.if %arg0 -> index { - %res1 = scf.if %arg0 -> index { - %v1 = "test.get_some_value1"() : () -> index - scf.yield %v1 : index - } else { - %v2 = "test.get_some_value2"() : () -> index - scf.yield %v2 : index - } - scf.yield %res1 : index - } else { - %res2 = scf.if %arg0 -> index { - %v3 = "test.get_some_value3"() : () -> index - scf.yield %v3 : index - } else { - %v4 = "test.get_some_value4"() : () -> index - scf.yield %v4 : index - } - scf.yield %res2 : index - } - return %res : index -} -// CHECK-NEXT: %[[if:.+]] = scf.if %arg0 -> (index) { -// CHECK-NEXT: %[[c1:.+]] = "test.get_some_value1"() : () -> index -// CHECK-NEXT: scf.yield %[[c1]] : index -// CHECK-NEXT: } else { -// CHECK-NEXT: %[[c4:.+]] = "test.get_some_value4"() : () -> index -// CHECK-NEXT: scf.yield %[[c4]] : index -// CHECK-NEXT: } -// CHECK-NEXT: return %[[if]] : index -// CHECK-NEXT:} - -// ----- - // CHECK-LABEL: @replace_if_with_cond1 func.func @replace_if_with_cond1(%arg0 : i1) -> (i32, i1) { %true = arith.constant true diff --git a/mlir/test/Dialect/SCF/if-cond-prop.mlir b/mlir/test/Dialect/SCF/if-cond-prop.mlir new file mode 100644 index 0000000..99d113f --- /dev/null +++ b/mlir/test/Dialect/SCF/if-cond-prop.mlir @@ -0,0 +1,34 @@ +// RUN: mlir-opt %s --scf-if-condition-propagation --allow-unregistered-dialect | FileCheck %s + +// CHECK-LABEL: @cond_prop +func.func @cond_prop(%arg0 : i1) -> index { + %res = scf.if %arg0 -> index { + %res1 = scf.if %arg0 -> index { + %v1 = "test.get_some_value1"() : () -> index + scf.yield %v1 : index + } else { + %v2 = "test.get_some_value2"() : () -> index + scf.yield %v2 : index + } + scf.yield %res1 : index + } else { + %res2 = scf.if %arg0 -> index { + %v3 = "test.get_some_value3"() : () -> index + scf.yield %v3 : index + } else { + %v4 = "test.get_some_value4"() : () -> index + scf.yield %v4 : index + } + scf.yield %res2 : index + } + return %res : index +} +// CHECK: %[[if:.+]] = scf.if %arg0 -> (index) { +// CHECK: %[[c1:.+]] = "test.get_some_value1"() : () -> index +// CHECK: scf.yield %[[c1]] : index +// CHECK: } else { +// CHECK: %[[c4:.+]] = "test.get_some_value4"() : () -> index +// CHECK: scf.yield %[[c4]] : index +// CHECK: } +// CHECK: return %[[if]] : index +// CHECK:} |