diff options
Diffstat (limited to 'mlir/lib/Transforms/Utils')
| -rw-r--r-- | mlir/lib/Transforms/Utils/DialectConversion.cpp | 120 |
1 files changed, 85 insertions, 35 deletions
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index 2fe0697..f8c38fa 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -92,6 +92,22 @@ static OpBuilder::InsertPoint computeInsertPoint(ArrayRef<Value> vals) { return pt; } +namespace { +enum OpConversionMode { + /// In this mode, the conversion will ignore failed conversions to allow + /// illegal operations to co-exist in the IR. + Partial, + + /// In this mode, all operations must be legal for the given target for the + /// conversion to succeed. + Full, + + /// In this mode, operations are analyzed for legality. No actual rewrites are + /// applied to the operations on success. + Analysis, +}; +} // namespace + //===----------------------------------------------------------------------===// // ConversionValueMapping //===----------------------------------------------------------------------===// @@ -866,8 +882,9 @@ namespace mlir { namespace detail { struct ConversionPatternRewriterImpl : public RewriterBase::Listener { explicit ConversionPatternRewriterImpl(ConversionPatternRewriter &rewriter, - const ConversionConfig &config) - : rewriter(rewriter), config(config), + const ConversionConfig &config, + OperationConverter &opConverter) + : rewriter(rewriter), config(config), opConverter(opConverter), notifyingRewriter(rewriter.getContext(), config.listener) {} //===--------------------------------------------------------------------===// @@ -1124,6 +1141,9 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { /// Dialect conversion configuration. const ConversionConfig &config; + /// The operation converter to use for recursive legalization. + OperationConverter &opConverter; + /// A set of erased operations. This set is utilized only if /// `allowPatternRollback` is set to "false". Conceptually, this set is /// similar to `replacedOps` (which is maintained when the flag is set to @@ -2084,9 +2104,10 @@ void ConversionPatternRewriterImpl::notifyMatchFailure( //===----------------------------------------------------------------------===// ConversionPatternRewriter::ConversionPatternRewriter( - MLIRContext *ctx, const ConversionConfig &config) - : PatternRewriter(ctx), - impl(new detail::ConversionPatternRewriterImpl(*this, config)) { + MLIRContext *ctx, const ConversionConfig &config, + OperationConverter &opConverter) + : PatternRewriter(ctx), impl(new detail::ConversionPatternRewriterImpl( + *this, config, opConverter)) { setListener(impl.get()); } @@ -2207,6 +2228,37 @@ ConversionPatternRewriter::getRemappedValues(ValueRange keys, return success(); } +LogicalResult ConversionPatternRewriter::legalize(Region *r) { + // Fast path: If the region is empty, there is nothing to legalize. + if (r->empty()) + return success(); + + // Gather a list of all operations to legalize. This is done before + // converting the entry block signature because unrealized_conversion_cast + // ops should not be included. + SmallVector<Operation *> ops; + for (Block &b : *r) + for (Operation &op : b) + ops.push_back(&op); + + // If the current pattern runs with a type converter, convert the entry block + // signature. + if (const TypeConverter *converter = impl->currentTypeConverter) { + std::optional<TypeConverter::SignatureConversion> conversion = + converter->convertBlockSignature(&r->front()); + if (!conversion) + return failure(); + applySignatureConversion(&r->front(), *conversion, converter); + } + + // Legalize all operations in the region. + for (Operation *op : ops) + if (failed(legalize(op))) + return failure(); + + return success(); +} + void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues) { @@ -3192,22 +3244,6 @@ static void reconcileUnrealizedCasts( // OperationConverter //===----------------------------------------------------------------------===// -namespace { -enum OpConversionMode { - /// In this mode, the conversion will ignore failed conversions to allow - /// illegal operations to co-exist in the IR. - Partial, - - /// In this mode, all operations must be legal for the given target for the - /// conversion to succeed. - Full, - - /// In this mode, operations are analyzed for legality. No actual rewrites are - /// applied to the operations on success. - Analysis, -}; -} // namespace - namespace mlir { // This class converts operations to a given conversion target via a set of // rewrite patterns. The conversion behaves differently depending on the @@ -3217,16 +3253,20 @@ struct OperationConverter { const FrozenRewritePatternSet &patterns, const ConversionConfig &config, OpConversionMode mode) - : rewriter(ctx, config), opLegalizer(rewriter, target, patterns), + : rewriter(ctx, config, *this), opLegalizer(rewriter, target, patterns), mode(mode) {} /// Converts the given operations to the conversion target. LogicalResult convertOperations(ArrayRef<Operation *> ops); -private: - /// Converts an operation with the given rewriter. - LogicalResult convert(Operation *op); + /// Converts a single operation. If `isRecursiveLegalization` is "true", the + /// conversion is a recursive legalization request, triggered from within a + /// pattern. In that case, do not emit errors because there will be another + /// attempt at legalizing the operation later (via the regular pre-order + /// legalization mechanism). + LogicalResult convert(Operation *op, bool isRecursiveLegalization = false); +private: /// The rewriter to use when converting operations. ConversionPatternRewriter rewriter; @@ -3238,32 +3278,42 @@ private: }; } // namespace mlir -LogicalResult OperationConverter::convert(Operation *op) { +LogicalResult ConversionPatternRewriter::legalize(Operation *op) { + return impl->opConverter.convert(op, /*isRecursiveLegalization=*/true); +} + +LogicalResult OperationConverter::convert(Operation *op, + bool isRecursiveLegalization) { const ConversionConfig &config = rewriter.getConfig(); // Legalize the given operation. if (failed(opLegalizer.legalize(op))) { // Handle the case of a failed conversion for each of the different modes. // Full conversions expect all operations to be converted. - if (mode == OpConversionMode::Full) - return op->emitError() - << "failed to legalize operation '" << op->getName() << "'"; + if (mode == OpConversionMode::Full) { + if (!isRecursiveLegalization) + op->emitError() << "failed to legalize operation '" << op->getName() + << "'"; + return failure(); + } // Partial conversions allow conversions to fail iff the operation was not // explicitly marked as illegal. If the user provided a `unlegalizedOps` // set, non-legalizable ops are added to that set. if (mode == OpConversionMode::Partial) { - if (opLegalizer.isIllegal(op)) - return op->emitError() - << "failed to legalize operation '" << op->getName() - << "' that was explicitly marked illegal"; - if (config.unlegalizedOps) + if (opLegalizer.isIllegal(op)) { + if (!isRecursiveLegalization) + op->emitError() << "failed to legalize operation '" << op->getName() + << "' that was explicitly marked illegal"; + return failure(); + } + if (config.unlegalizedOps && !isRecursiveLegalization) config.unlegalizedOps->insert(op); } } else if (mode == OpConversionMode::Analysis) { // Analysis conversions don't fail if any operations fail to legalize, // they are only interested in the operations that were successfully // legalized. - if (config.legalizableOps) + if (config.legalizableOps && !isRecursiveLegalization) config.legalizableOps->insert(op); } return success(); |
