aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMatthias Springer <springerm@google.com>2024-04-23 22:41:44 +0000
committerMatthias Springer <springerm@google.com>2024-04-23 22:43:05 +0000
commit9bb74b5e1d3403ef83058a181a89763744966597 (patch)
treebdc91803392e5d03e8221047a499172015b02e59
parent0c0c5c475857e9cd6a2fe82fd1e46abdb174a1c1 (diff)
downloadllvm-users/matthias-springer/greedy_rewrite_cse_constants.zip
llvm-users/matthias-springer/greedy_rewrite_cse_constants.tar.gz
llvm-users/matthias-springer/greedy_rewrite_cse_constants.tar.bz2
[mlir][Transforms] GreedyPatternRewriteDriver: Add flag to control constant CSE'ingusers/matthias-springer/greedy_rewrite_cse_constants
By default, the greedy pattern rewrite driver CSE's constant ops. If an op is CSE'd with an op in a parent region, the op is effectively "hoisted". Over the last years, users have described situations where this is not desirable/necessary. This commit adds a new flag to `GreedyRewriteConfig` that controls CSE'ing of constants. For testing purposes, it is also exposed as a canonicalizer pass flag.
-rw-r--r--mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h4
-rw-r--r--mlir/include/mlir/Transforms/Passes.td2
-rw-r--r--mlir/lib/Transforms/Canonicalizer.cpp2
-rw-r--r--mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp4
-rw-r--r--mlir/test/Pass/run-reproducer.mlir2
-rw-r--r--mlir/test/Transforms/test-canonicalize.mlir14
6 files changed, 25 insertions, 3 deletions
diff --git a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
index 763146a..880426c 100644
--- a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
+++ b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
@@ -47,6 +47,10 @@ public:
/// Note: Only applicable when simplifying entire regions.
bool enableRegionSimplification = true;
+ /// If set to "true", constants are CSE'd (even across multiple regions that
+ /// are in a parent-ancestor relationship).
+ bool cseConstants = true;
+
/// This specifies the maximum number of times the rewriter will iterate
/// between applying patterns and simplifying regions. Use `kNoLimit` to
/// disable this iteration limit.
diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
index 1b40a87..549161c 100644
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -35,6 +35,8 @@ def Canonicalizer : Pass<"canonicalize"> {
Option<"enableRegionSimplification", "region-simplify", "bool",
/*default=*/"true",
"Perform control flow optimizations to the region tree">,
+ Option<"cseConstants", "cse-constants", "bool", /*default=*/"true",
+ "CSE constant operations">,
Option<"maxIterations", "max-iterations", "int64_t",
/*default=*/"10",
"Max. iterations between applying patterns / simplifying regions">,
diff --git a/mlir/lib/Transforms/Canonicalizer.cpp b/mlir/lib/Transforms/Canonicalizer.cpp
index d50019b..2600df3 100644
--- a/mlir/lib/Transforms/Canonicalizer.cpp
+++ b/mlir/lib/Transforms/Canonicalizer.cpp
@@ -33,6 +33,7 @@ struct Canonicalizer : public impl::CanonicalizerBase<Canonicalizer> {
: config(config) {
this->topDownProcessingEnabled = config.useTopDownTraversal;
this->enableRegionSimplification = config.enableRegionSimplification;
+ this->cseConstants = config.cseConstants;
this->maxIterations = config.maxIterations;
this->maxNumRewrites = config.maxNumRewrites;
this->disabledPatterns = disabledPatterns;
@@ -45,6 +46,7 @@ struct Canonicalizer : public impl::CanonicalizerBase<Canonicalizer> {
// Set the config from possible pass options set in the meantime.
config.useTopDownTraversal = topDownProcessingEnabled;
config.enableRegionSimplification = enableRegionSimplification;
+ config.cseConstants = cseConstants;
config.maxIterations = maxIterations;
config.maxNumRewrites = maxNumRewrites;
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index cfd4f9c0..cf4a192a 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -848,13 +848,13 @@ LogicalResult RegionPatternRewriteDriver::simplify(bool *changed) && {
if (!config.useTopDownTraversal) {
// Add operations to the worklist in postorder.
region.walk([&](Operation *op) {
- if (!insertKnownConstant(op))
+ if (!config.cseConstants || !insertKnownConstant(op))
addToWorklist(op);
});
} else {
// Add all nested operations to the worklist in preorder.
region.walk<WalkOrder::PreOrder>([&](Operation *op) {
- if (!insertKnownConstant(op)) {
+ if (!config.cseConstants || !insertKnownConstant(op)) {
addToWorklist(op);
return WalkResult::advance();
}
diff --git a/mlir/test/Pass/run-reproducer.mlir b/mlir/test/Pass/run-reproducer.mlir
index 57a58db..220ea24 100644
--- a/mlir/test/Pass/run-reproducer.mlir
+++ b/mlir/test/Pass/run-reproducer.mlir
@@ -14,7 +14,7 @@ func.func @bar() {
external_resources: {
mlir_reproducer: {
verify_each: true,
- // CHECK: builtin.module(func.func(cse,canonicalize{ max-iterations=1 max-num-rewrites=-1 region-simplify=false test-convergence=false top-down=false}))
+ // CHECK: builtin.module(func.func(cse,canonicalize{cse-constants=true max-iterations=1 max-num-rewrites=-1 region-simplify=false test-convergence=false top-down=false}))
pipeline: "builtin.module(func.func(cse,canonicalize{max-iterations=1 max-num-rewrites=-1 region-simplify=false top-down=false}))",
disable_threading: true
}
diff --git a/mlir/test/Transforms/test-canonicalize.mlir b/mlir/test/Transforms/test-canonicalize.mlir
index 4f0095e..98eae14 100644
--- a/mlir/test/Transforms/test-canonicalize.mlir
+++ b/mlir/test/Transforms/test-canonicalize.mlir
@@ -1,5 +1,6 @@
// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(canonicalize))' | FileCheck %s
// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(canonicalize{region-simplify=false}))' | FileCheck %s --check-prefixes=CHECK,NO-RS
+// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(canonicalize{cse-constants=false}))' | FileCheck %s --check-prefixes=NO-CSE
// CHECK-LABEL: func @remove_op_with_inner_ops_pattern
func.func @remove_op_with_inner_ops_pattern() {
@@ -89,3 +90,16 @@ func.func @test_region_simplify() {
^bb1:
return
}
+
+// CHECK-LABEL: do_not_cse_constant
+// CHECK: %[[c0:.*]] = arith.constant 0 : index
+// CHECK: return %[[c0]], %[[c0]]
+// NO-CSE-LABEL: do_not_cse_constant
+// NO-CSE: %[[c0:.*]] = arith.constant 0 : index
+// NO-CSE: %[[c1:.*]] = arith.constant 0 : index
+// NO-CSE: return %[[c0]], %[[c1]]
+func.func @do_not_cse_constant() -> (index, index) {
+ %0 = arith.constant 0 : index
+ %1 = arith.constant 0 : index
+ return %0, %1 : index, index
+} \ No newline at end of file