From 8d8b1ad8310223b584fc9d0af57019a45dfbec77 Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Sun, 19 May 2024 11:49:47 +0200 Subject: [mlir][Transforms] Dialect conversion: Add option to disable folding This commit adds a new flag to `ConversionConfig` that deactivates op folding during a dialect conversion. Op folding is problematic beause op folders may assume that the IR is in a valid state. (See #89770 for an example.) However, the dialect conversion driver does not guarantee that the IR is valid during a dialect conversion; it only guarantees that the IR is valid at the end of a dialect conversion. E.g., IR may be invalid after a conversion pattern application because some IR modifications (e.g., op/block replacements) are applied in a delayed fashion at the end of a dialect conversion. This makes op folders generally unsafe to use with a dialect conversion. Note: For the same reason, it is also not safe to use non-conversion patterns with a dialect conversion. Conversion patterns can be used safely because they have an "adapter". (And conversion patterns cannot assume that the entire IR is valid in general.) --- mlir/include/mlir/Transforms/DialectConversion.h | 11 +++++- mlir/lib/Transforms/Utils/DialectConversion.cpp | 2 +- mlir/test/Transforms/test-legalizer-analysis.mlir | 2 +- mlir/test/Transforms/test-legalizer-full.mlir | 2 +- mlir/test/Transforms/test-legalizer-no-fold.mlir | 11 ++++++ mlir/test/lib/Dialect/Test/TestPatterns.cpp | 42 +++++++++++++---------- 6 files changed, 47 insertions(+), 23 deletions(-) create mode 100644 mlir/test/Transforms/test-legalizer-no-fold.mlir diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h index 83198c9..ea41e7c 100644 --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -247,7 +247,8 @@ public: /// Attempts a 1-1 type conversion, expecting the result type to be /// `TargetType`. Returns the converted type cast to `TargetType` on success, /// and a null type on conversion or cast failure. - template TargetType convertType(Type t) const { + template + TargetType convertType(Type t) const { return dyn_cast_or_null(convertType(t)); } @@ -1118,6 +1119,14 @@ struct ConversionConfig { // already been modified) and iterators into past IR state cannot be // represented at the moment. RewriterBase::Listener *listener = nullptr; + + /// If set to "true", the dialect conversion driver attempts to fold + /// operations throughout the conversion. This is problematic because op + /// folders may assume that the IR is in a valid state at the beginning of + /// the folding process. However, the dialect conversion does not guarantee + /// that because some IR modifications are delayed until the end of the + /// conversion. + bool foldOps = true; }; //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index d407d60..3c684e9 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -2030,7 +2030,7 @@ OperationLegalizer::legalize(Operation *op, // If the operation isn't legal, try to fold it in-place. // TODO: Should we always try to do this, even if the op is // already legal? - if (succeeded(legalizeWithFold(op, rewriter))) { + if (config.foldOps && succeeded(legalizeWithFold(op, rewriter))) { LLVM_DEBUG({ logSuccess(logger, "operation was folded"); logger.startLine() << logLineComment; diff --git a/mlir/test/Transforms/test-legalizer-analysis.mlir b/mlir/test/Transforms/test-legalizer-analysis.mlir index 19a1310..829415b 100644 --- a/mlir/test/Transforms/test-legalizer-analysis.mlir +++ b/mlir/test/Transforms/test-legalizer-analysis.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt -allow-unregistered-dialect -test-legalize-patterns -verify-diagnostics -test-legalize-mode=analysis %s | FileCheck %s +// RUN: mlir-opt -allow-unregistered-dialect -test-legalize-patterns="legalize-mode=analysis" -verify-diagnostics %s | FileCheck %s // expected-remark@-2 {{op 'builtin.module' is legalizable}} // expected-remark@+1 {{op 'func.func' is legalizable}} diff --git a/mlir/test/Transforms/test-legalizer-full.mlir b/mlir/test/Transforms/test-legalizer-full.mlir index 5f1148c..ea163a5 100644 --- a/mlir/test/Transforms/test-legalizer-full.mlir +++ b/mlir/test/Transforms/test-legalizer-full.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt -allow-unregistered-dialect -test-legalize-patterns -test-legalize-mode=full -split-input-file -verify-diagnostics %s | FileCheck %s +// RUN: mlir-opt -allow-unregistered-dialect -test-legalize-patterns="legalize-mode=full" -split-input-file -verify-diagnostics %s | FileCheck %s // CHECK-LABEL: func @multi_level_mapping func.func @multi_level_mapping() { diff --git a/mlir/test/Transforms/test-legalizer-no-fold.mlir b/mlir/test/Transforms/test-legalizer-no-fold.mlir new file mode 100644 index 0000000..61afd72 --- /dev/null +++ b/mlir/test/Transforms/test-legalizer-no-fold.mlir @@ -0,0 +1,11 @@ +// RUN: mlir-opt -allow-unregistered-dialect -test-legalize-patterns="fold-ops=0" %s | FileCheck %s + +// CHECK-LABEL: @remove_foldable_op( +func.func @remove_foldable_op(%arg0 : i32) -> (i32) { + // Check that op was not folded. + // CHECK: "test.op_with_region_fold" + %0 = "test.op_with_region_fold"(%arg0) ({ + "foo.op_with_region_terminator"() : () -> () + }) : (i32) -> (i32) + "test.return"(%0) : (i32) -> () +} diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp index f9f7d4e..97ef7a5 100644 --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -1102,7 +1102,9 @@ struct TestLegalizePatternDriver /// The mode of conversion to use with the driver. enum class ConversionMode { Analysis, Full, Partial }; - TestLegalizePatternDriver(ConversionMode mode) : mode(mode) {} + TestLegalizePatternDriver() = default; + TestLegalizePatternDriver(const TestLegalizePatternDriver &other) + : PassWrapper(other) {} void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); @@ -1179,6 +1181,7 @@ struct TestLegalizePatternDriver DumpNotifications dumpNotifications; config.listener = &dumpNotifications; config.unlegalizedOps = &unlegalizedOps; + config.foldOps = foldOps; if (failed(applyPartialConversion(getOperation(), target, std::move(patterns), config))) { getOperation()->emitRemark() << "applyPartialConversion failed"; @@ -1197,6 +1200,7 @@ struct TestLegalizePatternDriver }); ConversionConfig config; + config.foldOps = foldOps; DumpNotifications dumpNotifications; config.listener = &dumpNotifications; if (failed(applyFullConversion(getOperation(), target, @@ -1212,6 +1216,7 @@ struct TestLegalizePatternDriver // Analyze the convertible operations. DenseSet legalizedOps; ConversionConfig config; + config.foldOps = foldOps; config.legalizableOps = &legalizedOps; if (failed(applyAnalysisConversion(getOperation(), target, std::move(patterns), config))) @@ -1222,24 +1227,25 @@ struct TestLegalizePatternDriver op->emitRemark() << "op '" << op->getName() << "' is legalizable"; } - /// The mode of conversion to use. - ConversionMode mode; + Option foldOps{ + *this, "fold-ops", + llvm::cl::desc("Fold ops throughout the conversion process"), + llvm::cl::init(true)}; + + Option mode{ + *this, "legalize-mode", + llvm::cl::desc("The legalization mode to use with the test driver"), + llvm::cl::init(TestLegalizePatternDriver::ConversionMode::Partial), + llvm::cl::values( + clEnumValN(TestLegalizePatternDriver::ConversionMode::Analysis, + "analysis", "Perform an analysis conversion"), + clEnumValN(TestLegalizePatternDriver::ConversionMode::Full, "full", + "Perform a full conversion"), + clEnumValN(TestLegalizePatternDriver::ConversionMode::Partial, + "partial", "Perform a partial conversion"))}; }; } // namespace -static llvm::cl::opt - legalizerConversionMode( - "test-legalize-mode", - llvm::cl::desc("The legalization mode to use with the test driver"), - llvm::cl::init(TestLegalizePatternDriver::ConversionMode::Partial), - llvm::cl::values( - clEnumValN(TestLegalizePatternDriver::ConversionMode::Analysis, - "analysis", "Perform an analysis conversion"), - clEnumValN(TestLegalizePatternDriver::ConversionMode::Full, "full", - "Perform a full conversion"), - clEnumValN(TestLegalizePatternDriver::ConversionMode::Partial, - "partial", "Perform a partial conversion"))); - //===----------------------------------------------------------------------===// // ConversionPatternRewriter::getRemappedValue testing. This method is used // to get the remapped value of an original value that was replaced using @@ -1909,9 +1915,7 @@ void registerPatternsTestPass() { PassRegistration(); PassRegistration(); - PassRegistration([] { - return std::make_unique(legalizerConversionMode); - }); + PassRegistration(); PassRegistration(); -- cgit v1.1