aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMatthias Springer <mspringer@nvidia.com>2024-05-19 11:49:47 +0200
committerMatthias Springer <mspringer@nvidia.com>2024-05-19 11:49:47 +0200
commit8d8b1ad8310223b584fc9d0af57019a45dfbec77 (patch)
tree2a8ec058a2c665e8e6761344098bea6560d2abb7
parent8b8a38a7b426fc724804602d7635134a0c63f08c (diff)
downloadllvm-users/matthias-springer/dialect_conv_no_fold.zip
llvm-users/matthias-springer/dialect_conv_no_fold.tar.gz
llvm-users/matthias-springer/dialect_conv_no_fold.tar.bz2
[mlir][Transforms] Dialect conversion: Add option to disable foldingusers/matthias-springer/dialect_conv_no_fold
This commit adds a new flag to `ConversionConfig` that deactivates op folding during a dialect conversion. Op folding is problematic beause op folders may assume that the IR is in a valid state. (See #89770 for an example.) However, the dialect conversion driver does not guarantee that the IR is valid during a dialect conversion; it only guarantees that the IR is valid at the end of a dialect conversion. E.g., IR may be invalid after a conversion pattern application because some IR modifications (e.g., op/block replacements) are applied in a delayed fashion at the end of a dialect conversion. This makes op folders generally unsafe to use with a dialect conversion. Note: For the same reason, it is also not safe to use non-conversion patterns with a dialect conversion. Conversion patterns can be used safely because they have an "adapter". (And conversion patterns cannot assume that the entire IR is valid in general.)
-rw-r--r--mlir/include/mlir/Transforms/DialectConversion.h11
-rw-r--r--mlir/lib/Transforms/Utils/DialectConversion.cpp2
-rw-r--r--mlir/test/Transforms/test-legalizer-analysis.mlir2
-rw-r--r--mlir/test/Transforms/test-legalizer-full.mlir2
-rw-r--r--mlir/test/Transforms/test-legalizer-no-fold.mlir11
-rw-r--r--mlir/test/lib/Dialect/Test/TestPatterns.cpp42
6 files changed, 47 insertions, 23 deletions
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 83198c9..ea41e7c 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -247,7 +247,8 @@ public:
/// Attempts a 1-1 type conversion, expecting the result type to be
/// `TargetType`. Returns the converted type cast to `TargetType` on success,
/// and a null type on conversion or cast failure.
- template <typename TargetType> TargetType convertType(Type t) const {
+ template <typename TargetType>
+ TargetType convertType(Type t) const {
return dyn_cast_or_null<TargetType>(convertType(t));
}
@@ -1118,6 +1119,14 @@ struct ConversionConfig {
// already been modified) and iterators into past IR state cannot be
// represented at the moment.
RewriterBase::Listener *listener = nullptr;
+
+ /// If set to "true", the dialect conversion driver attempts to fold
+ /// operations throughout the conversion. This is problematic because op
+ /// folders may assume that the IR is in a valid state at the beginning of
+ /// the folding process. However, the dialect conversion does not guarantee
+ /// that because some IR modifications are delayed until the end of the
+ /// conversion.
+ bool foldOps = true;
};
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index d407d60..3c684e9 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -2030,7 +2030,7 @@ OperationLegalizer::legalize(Operation *op,
// If the operation isn't legal, try to fold it in-place.
// TODO: Should we always try to do this, even if the op is
// already legal?
- if (succeeded(legalizeWithFold(op, rewriter))) {
+ if (config.foldOps && succeeded(legalizeWithFold(op, rewriter))) {
LLVM_DEBUG({
logSuccess(logger, "operation was folded");
logger.startLine() << logLineComment;
diff --git a/mlir/test/Transforms/test-legalizer-analysis.mlir b/mlir/test/Transforms/test-legalizer-analysis.mlir
index 19a1310..829415b 100644
--- a/mlir/test/Transforms/test-legalizer-analysis.mlir
+++ b/mlir/test/Transforms/test-legalizer-analysis.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -allow-unregistered-dialect -test-legalize-patterns -verify-diagnostics -test-legalize-mode=analysis %s | FileCheck %s
+// RUN: mlir-opt -allow-unregistered-dialect -test-legalize-patterns="legalize-mode=analysis" -verify-diagnostics %s | FileCheck %s
// expected-remark@-2 {{op 'builtin.module' is legalizable}}
// expected-remark@+1 {{op 'func.func' is legalizable}}
diff --git a/mlir/test/Transforms/test-legalizer-full.mlir b/mlir/test/Transforms/test-legalizer-full.mlir
index 5f1148c..ea163a5 100644
--- a/mlir/test/Transforms/test-legalizer-full.mlir
+++ b/mlir/test/Transforms/test-legalizer-full.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -allow-unregistered-dialect -test-legalize-patterns -test-legalize-mode=full -split-input-file -verify-diagnostics %s | FileCheck %s
+// RUN: mlir-opt -allow-unregistered-dialect -test-legalize-patterns="legalize-mode=full" -split-input-file -verify-diagnostics %s | FileCheck %s
// CHECK-LABEL: func @multi_level_mapping
func.func @multi_level_mapping() {
diff --git a/mlir/test/Transforms/test-legalizer-no-fold.mlir b/mlir/test/Transforms/test-legalizer-no-fold.mlir
new file mode 100644
index 0000000..61afd72
--- /dev/null
+++ b/mlir/test/Transforms/test-legalizer-no-fold.mlir
@@ -0,0 +1,11 @@
+// RUN: mlir-opt -allow-unregistered-dialect -test-legalize-patterns="fold-ops=0" %s | FileCheck %s
+
+// CHECK-LABEL: @remove_foldable_op(
+func.func @remove_foldable_op(%arg0 : i32) -> (i32) {
+ // Check that op was not folded.
+ // CHECK: "test.op_with_region_fold"
+ %0 = "test.op_with_region_fold"(%arg0) ({
+ "foo.op_with_region_terminator"() : () -> ()
+ }) : (i32) -> (i32)
+ "test.return"(%0) : (i32) -> ()
+}
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index f9f7d4e..97ef7a5 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -1102,7 +1102,9 @@ struct TestLegalizePatternDriver
/// The mode of conversion to use with the driver.
enum class ConversionMode { Analysis, Full, Partial };
- TestLegalizePatternDriver(ConversionMode mode) : mode(mode) {}
+ TestLegalizePatternDriver() = default;
+ TestLegalizePatternDriver(const TestLegalizePatternDriver &other)
+ : PassWrapper(other) {}
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<func::FuncDialect, test::TestDialect>();
@@ -1179,6 +1181,7 @@ struct TestLegalizePatternDriver
DumpNotifications dumpNotifications;
config.listener = &dumpNotifications;
config.unlegalizedOps = &unlegalizedOps;
+ config.foldOps = foldOps;
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns), config))) {
getOperation()->emitRemark() << "applyPartialConversion failed";
@@ -1197,6 +1200,7 @@ struct TestLegalizePatternDriver
});
ConversionConfig config;
+ config.foldOps = foldOps;
DumpNotifications dumpNotifications;
config.listener = &dumpNotifications;
if (failed(applyFullConversion(getOperation(), target,
@@ -1212,6 +1216,7 @@ struct TestLegalizePatternDriver
// Analyze the convertible operations.
DenseSet<Operation *> legalizedOps;
ConversionConfig config;
+ config.foldOps = foldOps;
config.legalizableOps = &legalizedOps;
if (failed(applyAnalysisConversion(getOperation(), target,
std::move(patterns), config)))
@@ -1222,24 +1227,25 @@ struct TestLegalizePatternDriver
op->emitRemark() << "op '" << op->getName() << "' is legalizable";
}
- /// The mode of conversion to use.
- ConversionMode mode;
+ Option<bool> foldOps{
+ *this, "fold-ops",
+ llvm::cl::desc("Fold ops throughout the conversion process"),
+ llvm::cl::init(true)};
+
+ Option<TestLegalizePatternDriver::ConversionMode> mode{
+ *this, "legalize-mode",
+ llvm::cl::desc("The legalization mode to use with the test driver"),
+ llvm::cl::init(TestLegalizePatternDriver::ConversionMode::Partial),
+ llvm::cl::values(
+ clEnumValN(TestLegalizePatternDriver::ConversionMode::Analysis,
+ "analysis", "Perform an analysis conversion"),
+ clEnumValN(TestLegalizePatternDriver::ConversionMode::Full, "full",
+ "Perform a full conversion"),
+ clEnumValN(TestLegalizePatternDriver::ConversionMode::Partial,
+ "partial", "Perform a partial conversion"))};
};
} // namespace
-static llvm::cl::opt<TestLegalizePatternDriver::ConversionMode>
- legalizerConversionMode(
- "test-legalize-mode",
- llvm::cl::desc("The legalization mode to use with the test driver"),
- llvm::cl::init(TestLegalizePatternDriver::ConversionMode::Partial),
- llvm::cl::values(
- clEnumValN(TestLegalizePatternDriver::ConversionMode::Analysis,
- "analysis", "Perform an analysis conversion"),
- clEnumValN(TestLegalizePatternDriver::ConversionMode::Full, "full",
- "Perform a full conversion"),
- clEnumValN(TestLegalizePatternDriver::ConversionMode::Partial,
- "partial", "Perform a partial conversion")));
-
//===----------------------------------------------------------------------===//
// ConversionPatternRewriter::getRemappedValue testing. This method is used
// to get the remapped value of an original value that was replaced using
@@ -1909,9 +1915,7 @@ void registerPatternsTestPass() {
PassRegistration<TestPatternDriver>();
PassRegistration<TestStrictPatternDriver>();
- PassRegistration<TestLegalizePatternDriver>([] {
- return std::make_unique<TestLegalizePatternDriver>(legalizerConversionMode);
- });
+ PassRegistration<TestLegalizePatternDriver>();
PassRegistration<TestRemappedValue>();