diff options
author | Matthias Springer <springerm@google.com> | 2024-04-23 22:41:44 +0000 |
---|---|---|
committer | Matthias Springer <springerm@google.com> | 2024-04-23 22:43:05 +0000 |
commit | 9bb74b5e1d3403ef83058a181a89763744966597 (patch) | |
tree | bdc91803392e5d03e8221047a499172015b02e59 | |
parent | 0c0c5c475857e9cd6a2fe82fd1e46abdb174a1c1 (diff) | |
download | llvm-users/matthias-springer/greedy_rewrite_cse_constants.zip llvm-users/matthias-springer/greedy_rewrite_cse_constants.tar.gz llvm-users/matthias-springer/greedy_rewrite_cse_constants.tar.bz2 |
[mlir][Transforms] GreedyPatternRewriteDriver: Add flag to control constant CSE'ingusers/matthias-springer/greedy_rewrite_cse_constants
By default, the greedy pattern rewrite driver CSE's constant ops. If an op is CSE'd with an op in a parent region, the op is effectively "hoisted". Over the last years, users have described situations where this is not desirable/necessary. This commit adds a new flag to `GreedyRewriteConfig` that controls CSE'ing of constants. For testing purposes, it is also exposed as a canonicalizer pass flag.
-rw-r--r-- | mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h | 4 | ||||
-rw-r--r-- | mlir/include/mlir/Transforms/Passes.td | 2 | ||||
-rw-r--r-- | mlir/lib/Transforms/Canonicalizer.cpp | 2 | ||||
-rw-r--r-- | mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp | 4 | ||||
-rw-r--r-- | mlir/test/Pass/run-reproducer.mlir | 2 | ||||
-rw-r--r-- | mlir/test/Transforms/test-canonicalize.mlir | 14 |
6 files changed, 25 insertions, 3 deletions
diff --git a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h index 763146a..880426c 100644 --- a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h +++ b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h @@ -47,6 +47,10 @@ public: /// Note: Only applicable when simplifying entire regions. bool enableRegionSimplification = true; + /// If set to "true", constants are CSE'd (even across multiple regions that + /// are in a parent-ancestor relationship). + bool cseConstants = true; + /// This specifies the maximum number of times the rewriter will iterate /// between applying patterns and simplifying regions. Use `kNoLimit` to /// disable this iteration limit. diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td index 1b40a87..549161c 100644 --- a/mlir/include/mlir/Transforms/Passes.td +++ b/mlir/include/mlir/Transforms/Passes.td @@ -35,6 +35,8 @@ def Canonicalizer : Pass<"canonicalize"> { Option<"enableRegionSimplification", "region-simplify", "bool", /*default=*/"true", "Perform control flow optimizations to the region tree">, + Option<"cseConstants", "cse-constants", "bool", /*default=*/"true", + "CSE constant operations">, Option<"maxIterations", "max-iterations", "int64_t", /*default=*/"10", "Max. iterations between applying patterns / simplifying regions">, diff --git a/mlir/lib/Transforms/Canonicalizer.cpp b/mlir/lib/Transforms/Canonicalizer.cpp index d50019b..2600df3 100644 --- a/mlir/lib/Transforms/Canonicalizer.cpp +++ b/mlir/lib/Transforms/Canonicalizer.cpp @@ -33,6 +33,7 @@ struct Canonicalizer : public impl::CanonicalizerBase<Canonicalizer> { : config(config) { this->topDownProcessingEnabled = config.useTopDownTraversal; this->enableRegionSimplification = config.enableRegionSimplification; + this->cseConstants = config.cseConstants; this->maxIterations = config.maxIterations; this->maxNumRewrites = config.maxNumRewrites; this->disabledPatterns = disabledPatterns; @@ -45,6 +46,7 @@ struct Canonicalizer : public impl::CanonicalizerBase<Canonicalizer> { // Set the config from possible pass options set in the meantime. config.useTopDownTraversal = topDownProcessingEnabled; config.enableRegionSimplification = enableRegionSimplification; + config.cseConstants = cseConstants; config.maxIterations = maxIterations; config.maxNumRewrites = maxNumRewrites; diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp index cfd4f9c0..cf4a192a 100644 --- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -848,13 +848,13 @@ LogicalResult RegionPatternRewriteDriver::simplify(bool *changed) && { if (!config.useTopDownTraversal) { // Add operations to the worklist in postorder. region.walk([&](Operation *op) { - if (!insertKnownConstant(op)) + if (!config.cseConstants || !insertKnownConstant(op)) addToWorklist(op); }); } else { // Add all nested operations to the worklist in preorder. region.walk<WalkOrder::PreOrder>([&](Operation *op) { - if (!insertKnownConstant(op)) { + if (!config.cseConstants || !insertKnownConstant(op)) { addToWorklist(op); return WalkResult::advance(); } diff --git a/mlir/test/Pass/run-reproducer.mlir b/mlir/test/Pass/run-reproducer.mlir index 57a58db..220ea24 100644 --- a/mlir/test/Pass/run-reproducer.mlir +++ b/mlir/test/Pass/run-reproducer.mlir @@ -14,7 +14,7 @@ func.func @bar() { external_resources: { mlir_reproducer: { verify_each: true, - // CHECK: builtin.module(func.func(cse,canonicalize{ max-iterations=1 max-num-rewrites=-1 region-simplify=false test-convergence=false top-down=false})) + // CHECK: builtin.module(func.func(cse,canonicalize{cse-constants=true max-iterations=1 max-num-rewrites=-1 region-simplify=false test-convergence=false top-down=false})) pipeline: "builtin.module(func.func(cse,canonicalize{max-iterations=1 max-num-rewrites=-1 region-simplify=false top-down=false}))", disable_threading: true } diff --git a/mlir/test/Transforms/test-canonicalize.mlir b/mlir/test/Transforms/test-canonicalize.mlir index 4f0095e..98eae14 100644 --- a/mlir/test/Transforms/test-canonicalize.mlir +++ b/mlir/test/Transforms/test-canonicalize.mlir @@ -1,5 +1,6 @@ // RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(canonicalize))' | FileCheck %s // RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(canonicalize{region-simplify=false}))' | FileCheck %s --check-prefixes=CHECK,NO-RS +// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(canonicalize{cse-constants=false}))' | FileCheck %s --check-prefixes=NO-CSE // CHECK-LABEL: func @remove_op_with_inner_ops_pattern func.func @remove_op_with_inner_ops_pattern() { @@ -89,3 +90,16 @@ func.func @test_region_simplify() { ^bb1: return } + +// CHECK-LABEL: do_not_cse_constant +// CHECK: %[[c0:.*]] = arith.constant 0 : index +// CHECK: return %[[c0]], %[[c0]] +// NO-CSE-LABEL: do_not_cse_constant +// NO-CSE: %[[c0:.*]] = arith.constant 0 : index +// NO-CSE: %[[c1:.*]] = arith.constant 0 : index +// NO-CSE: return %[[c0]], %[[c1]] +func.func @do_not_cse_constant() -> (index, index) { + %0 = arith.constant 0 : index + %1 = arith.constant 0 : index + return %0, %1 : index, index +}
\ No newline at end of file |