aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Transforms/Utils/DialectConversion.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Transforms/Utils/DialectConversion.cpp')
-rw-r--r--mlir/lib/Transforms/Utils/DialectConversion.cpp120
1 files changed, 85 insertions, 35 deletions
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 2fe0697..f8c38fa 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -92,6 +92,22 @@ static OpBuilder::InsertPoint computeInsertPoint(ArrayRef<Value> vals) {
return pt;
}
+namespace {
+enum OpConversionMode {
+ /// In this mode, the conversion will ignore failed conversions to allow
+ /// illegal operations to co-exist in the IR.
+ Partial,
+
+ /// In this mode, all operations must be legal for the given target for the
+ /// conversion to succeed.
+ Full,
+
+ /// In this mode, operations are analyzed for legality. No actual rewrites are
+ /// applied to the operations on success.
+ Analysis,
+};
+} // namespace
+
//===----------------------------------------------------------------------===//
// ConversionValueMapping
//===----------------------------------------------------------------------===//
@@ -866,8 +882,9 @@ namespace mlir {
namespace detail {
struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
explicit ConversionPatternRewriterImpl(ConversionPatternRewriter &rewriter,
- const ConversionConfig &config)
- : rewriter(rewriter), config(config),
+ const ConversionConfig &config,
+ OperationConverter &opConverter)
+ : rewriter(rewriter), config(config), opConverter(opConverter),
notifyingRewriter(rewriter.getContext(), config.listener) {}
//===--------------------------------------------------------------------===//
@@ -1124,6 +1141,9 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// Dialect conversion configuration.
const ConversionConfig &config;
+ /// The operation converter to use for recursive legalization.
+ OperationConverter &opConverter;
+
/// A set of erased operations. This set is utilized only if
/// `allowPatternRollback` is set to "false". Conceptually, this set is
/// similar to `replacedOps` (which is maintained when the flag is set to
@@ -2084,9 +2104,10 @@ void ConversionPatternRewriterImpl::notifyMatchFailure(
//===----------------------------------------------------------------------===//
ConversionPatternRewriter::ConversionPatternRewriter(
- MLIRContext *ctx, const ConversionConfig &config)
- : PatternRewriter(ctx),
- impl(new detail::ConversionPatternRewriterImpl(*this, config)) {
+ MLIRContext *ctx, const ConversionConfig &config,
+ OperationConverter &opConverter)
+ : PatternRewriter(ctx), impl(new detail::ConversionPatternRewriterImpl(
+ *this, config, opConverter)) {
setListener(impl.get());
}
@@ -2207,6 +2228,37 @@ ConversionPatternRewriter::getRemappedValues(ValueRange keys,
return success();
}
+LogicalResult ConversionPatternRewriter::legalize(Region *r) {
+ // Fast path: If the region is empty, there is nothing to legalize.
+ if (r->empty())
+ return success();
+
+ // Gather a list of all operations to legalize. This is done before
+ // converting the entry block signature because unrealized_conversion_cast
+ // ops should not be included.
+ SmallVector<Operation *> ops;
+ for (Block &b : *r)
+ for (Operation &op : b)
+ ops.push_back(&op);
+
+ // If the current pattern runs with a type converter, convert the entry block
+ // signature.
+ if (const TypeConverter *converter = impl->currentTypeConverter) {
+ std::optional<TypeConverter::SignatureConversion> conversion =
+ converter->convertBlockSignature(&r->front());
+ if (!conversion)
+ return failure();
+ applySignatureConversion(&r->front(), *conversion, converter);
+ }
+
+ // Legalize all operations in the region.
+ for (Operation *op : ops)
+ if (failed(legalize(op)))
+ return failure();
+
+ return success();
+}
+
void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest,
Block::iterator before,
ValueRange argValues) {
@@ -3192,22 +3244,6 @@ static void reconcileUnrealizedCasts(
// OperationConverter
//===----------------------------------------------------------------------===//
-namespace {
-enum OpConversionMode {
- /// In this mode, the conversion will ignore failed conversions to allow
- /// illegal operations to co-exist in the IR.
- Partial,
-
- /// In this mode, all operations must be legal for the given target for the
- /// conversion to succeed.
- Full,
-
- /// In this mode, operations are analyzed for legality. No actual rewrites are
- /// applied to the operations on success.
- Analysis,
-};
-} // namespace
-
namespace mlir {
// This class converts operations to a given conversion target via a set of
// rewrite patterns. The conversion behaves differently depending on the
@@ -3217,16 +3253,20 @@ struct OperationConverter {
const FrozenRewritePatternSet &patterns,
const ConversionConfig &config,
OpConversionMode mode)
- : rewriter(ctx, config), opLegalizer(rewriter, target, patterns),
+ : rewriter(ctx, config, *this), opLegalizer(rewriter, target, patterns),
mode(mode) {}
/// Converts the given operations to the conversion target.
LogicalResult convertOperations(ArrayRef<Operation *> ops);
-private:
- /// Converts an operation with the given rewriter.
- LogicalResult convert(Operation *op);
+ /// Converts a single operation. If `isRecursiveLegalization` is "true", the
+ /// conversion is a recursive legalization request, triggered from within a
+ /// pattern. In that case, do not emit errors because there will be another
+ /// attempt at legalizing the operation later (via the regular pre-order
+ /// legalization mechanism).
+ LogicalResult convert(Operation *op, bool isRecursiveLegalization = false);
+private:
/// The rewriter to use when converting operations.
ConversionPatternRewriter rewriter;
@@ -3238,32 +3278,42 @@ private:
};
} // namespace mlir
-LogicalResult OperationConverter::convert(Operation *op) {
+LogicalResult ConversionPatternRewriter::legalize(Operation *op) {
+ return impl->opConverter.convert(op, /*isRecursiveLegalization=*/true);
+}
+
+LogicalResult OperationConverter::convert(Operation *op,
+ bool isRecursiveLegalization) {
const ConversionConfig &config = rewriter.getConfig();
// Legalize the given operation.
if (failed(opLegalizer.legalize(op))) {
// Handle the case of a failed conversion for each of the different modes.
// Full conversions expect all operations to be converted.
- if (mode == OpConversionMode::Full)
- return op->emitError()
- << "failed to legalize operation '" << op->getName() << "'";
+ if (mode == OpConversionMode::Full) {
+ if (!isRecursiveLegalization)
+ op->emitError() << "failed to legalize operation '" << op->getName()
+ << "'";
+ return failure();
+ }
// Partial conversions allow conversions to fail iff the operation was not
// explicitly marked as illegal. If the user provided a `unlegalizedOps`
// set, non-legalizable ops are added to that set.
if (mode == OpConversionMode::Partial) {
- if (opLegalizer.isIllegal(op))
- return op->emitError()
- << "failed to legalize operation '" << op->getName()
- << "' that was explicitly marked illegal";
- if (config.unlegalizedOps)
+ if (opLegalizer.isIllegal(op)) {
+ if (!isRecursiveLegalization)
+ op->emitError() << "failed to legalize operation '" << op->getName()
+ << "' that was explicitly marked illegal";
+ return failure();
+ }
+ if (config.unlegalizedOps && !isRecursiveLegalization)
config.unlegalizedOps->insert(op);
}
} else if (mode == OpConversionMode::Analysis) {
// Analysis conversions don't fail if any operations fail to legalize,
// they are only interested in the operations that were successfully
// legalized.
- if (config.legalizableOps)
+ if (config.legalizableOps && !isRecursiveLegalization)
config.legalizableOps->insert(op);
}
return success();