diff options
Diffstat (limited to 'mlir/lib/Transforms/Utils/DialectConversion.cpp')
| -rw-r--r-- | mlir/lib/Transforms/Utils/DialectConversion.cpp | 112 |
1 files changed, 20 insertions, 92 deletions
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index 3a23bbf..2fe0697 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -1105,10 +1105,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { /// A set of operations that were modified by the current pattern. SetVector<Operation *> patternModifiedOps; - /// A set of blocks that were inserted (newly-created blocks or moved blocks) - /// by the current pattern. - SetVector<Block *> patternInsertedBlocks; - /// A list of unresolved materializations that were created by the current /// pattern. DenseSet<UnrealizedConversionCastOp> patternMaterializations; @@ -2046,8 +2042,6 @@ void ConversionPatternRewriterImpl::notifyBlockInserted( if (!config.allowPatternRollback && config.listener) config.listener->notifyBlockInserted(block, previous, previousIt); - patternInsertedBlocks.insert(block); - if (wasDetached) { // If the block was detached, it is most likely a newly created block. if (config.allowPatternRollback) { @@ -2399,17 +2393,12 @@ private: bool canApplyPattern(Operation *op, const Pattern &pattern); /// Legalize the resultant IR after successfully applying the given pattern. - LogicalResult legalizePatternResult(Operation *op, const Pattern &pattern, - const RewriterState &curState, - const SetVector<Operation *> &newOps, - const SetVector<Operation *> &modifiedOps, - const SetVector<Block *> &insertedBlocks); - - /// Legalizes the actions registered during the execution of a pattern. LogicalResult - legalizePatternBlockRewrites(Operation *op, - const SetVector<Block *> &insertedBlocks, - const SetVector<Operation *> &newOps); + legalizePatternResult(Operation *op, const Pattern &pattern, + const RewriterState &curState, + const SetVector<Operation *> &newOps, + const SetVector<Operation *> &modifiedOps); + LogicalResult legalizePatternCreatedOperations(const SetVector<Operation *> &newOps); LogicalResult @@ -2608,7 +2597,6 @@ LogicalResult OperationLegalizer::legalizeWithFold(Operation *op) { auto cleanup = llvm::make_scope_exit([&]() { rewriterImpl.patternNewOps.clear(); rewriterImpl.patternModifiedOps.clear(); - rewriterImpl.patternInsertedBlocks.clear(); }); // Upon failure, undo all changes made by the folder. @@ -2662,24 +2650,16 @@ LogicalResult OperationLegalizer::legalizeWithFold(Operation *op) { static void reportNewIrLegalizationFatalError(const Pattern &pattern, const SetVector<Operation *> &newOps, - const SetVector<Operation *> &modifiedOps, - const SetVector<Block *> &insertedBlocks) { + const SetVector<Operation *> &modifiedOps) { auto newOpNames = llvm::map_range( newOps, [](Operation *op) { return op->getName().getStringRef(); }); auto modifiedOpNames = llvm::map_range( modifiedOps, [](Operation *op) { return op->getName().getStringRef(); }); - StringRef detachedBlockStr = "(detached block)"; - auto insertedBlockNames = llvm::map_range(insertedBlocks, [&](Block *block) { - if (block->getParentOp()) - return block->getParentOp()->getName().getStringRef(); - return detachedBlockStr; - }); - llvm::report_fatal_error( - "pattern '" + pattern.getDebugName() + - "' produced IR that could not be legalized. " + "new ops: {" + - llvm::join(newOpNames, ", ") + "}, " + "modified ops: {" + - llvm::join(modifiedOpNames, ", ") + "}, " + "inserted block into ops: {" + - llvm::join(insertedBlockNames, ", ") + "}"); + llvm::report_fatal_error("pattern '" + pattern.getDebugName() + + "' produced IR that could not be legalized. " + + "new ops: {" + llvm::join(newOpNames, ", ") + "}, " + + "modified ops: {" + + llvm::join(modifiedOpNames, ", ") + "}"); } LogicalResult OperationLegalizer::legalizeWithPattern(Operation *op) { @@ -2743,7 +2723,6 @@ LogicalResult OperationLegalizer::legalizeWithPattern(Operation *op) { } rewriterImpl.patternNewOps.clear(); rewriterImpl.patternModifiedOps.clear(); - rewriterImpl.patternInsertedBlocks.clear(); LLVM_DEBUG({ logFailure(rewriterImpl.logger, "pattern failed to match"); if (rewriterImpl.config.notifyCallback) { @@ -2777,15 +2756,12 @@ LogicalResult OperationLegalizer::legalizeWithPattern(Operation *op) { SetVector<Operation *> newOps = moveAndReset(rewriterImpl.patternNewOps); SetVector<Operation *> modifiedOps = moveAndReset(rewriterImpl.patternModifiedOps); - SetVector<Block *> insertedBlocks = - moveAndReset(rewriterImpl.patternInsertedBlocks); - auto result = legalizePatternResult(op, pattern, curState, newOps, - modifiedOps, insertedBlocks); + auto result = + legalizePatternResult(op, pattern, curState, newOps, modifiedOps); appliedPatterns.erase(&pattern); if (failed(result)) { if (!rewriterImpl.config.allowPatternRollback) - reportNewIrLegalizationFatalError(pattern, newOps, modifiedOps, - insertedBlocks); + reportNewIrLegalizationFatalError(pattern, newOps, modifiedOps); rewriterImpl.resetState(curState, pattern.getDebugName()); } if (config.listener) @@ -2823,8 +2799,7 @@ bool OperationLegalizer::canApplyPattern(Operation *op, LogicalResult OperationLegalizer::legalizePatternResult( Operation *op, const Pattern &pattern, const RewriterState &curState, const SetVector<Operation *> &newOps, - const SetVector<Operation *> &modifiedOps, - const SetVector<Block *> &insertedBlocks) { + const SetVector<Operation *> &modifiedOps) { [[maybe_unused]] auto &impl = rewriter.getImpl(); assert(impl.pendingRootUpdates.empty() && "dangling root updates"); @@ -2843,8 +2818,7 @@ LogicalResult OperationLegalizer::legalizePatternResult( #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS // Legalize each of the actions registered during application. - if (failed(legalizePatternBlockRewrites(op, insertedBlocks, newOps)) || - failed(legalizePatternRootUpdates(modifiedOps)) || + if (failed(legalizePatternRootUpdates(modifiedOps)) || failed(legalizePatternCreatedOperations(newOps))) { return failure(); } @@ -2853,53 +2827,6 @@ LogicalResult OperationLegalizer::legalizePatternResult( return success(); } -LogicalResult OperationLegalizer::legalizePatternBlockRewrites( - Operation *op, const SetVector<Block *> &insertedBlocks, - const SetVector<Operation *> &newOps) { - ConversionPatternRewriterImpl &impl = rewriter.getImpl(); - SmallPtrSet<Operation *, 16> alreadyLegalized; - - // If the pattern moved or created any blocks, make sure the types of block - // arguments get legalized. - for (Block *block : insertedBlocks) { - if (impl.erasedBlocks.contains(block)) - continue; - - // Only check blocks outside of the current operation. - Operation *parentOp = block->getParentOp(); - if (!parentOp || parentOp == op || block->getNumArguments() == 0) - continue; - - // If the region of the block has a type converter, try to convert the block - // directly. - if (auto *converter = impl.regionToConverter.lookup(block->getParent())) { - std::optional<TypeConverter::SignatureConversion> conversion = - converter->convertBlockSignature(block); - if (!conversion) { - LLVM_DEBUG(logFailure(impl.logger, "failed to convert types of moved " - "block")); - return failure(); - } - impl.applySignatureConversion(block, converter, *conversion); - continue; - } - - // Otherwise, try to legalize the parent operation if it was not generated - // by this pattern. This is because we will attempt to legalize the parent - // operation, and blocks in regions created by this pattern will already be - // legalized later on. - if (!newOps.count(parentOp) && alreadyLegalized.insert(parentOp).second) { - if (failed(legalize(parentOp))) { - LLVM_DEBUG(logFailure( - impl.logger, "operation '{0}'({1}) became illegal after rewrite", - parentOp->getName(), parentOp)); - return failure(); - } - } - } - return success(); -} - LogicalResult OperationLegalizer::legalizePatternCreatedOperations( const SetVector<Operation *> &newOps) { for (Operation *op : newOps) { @@ -3800,10 +3727,11 @@ static LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp, TypeConverter::SignatureConversion result(type.getNumInputs()); SmallVector<Type, 1> newResults; if (failed(typeConverter.convertSignatureArgs(type.getInputs(), result)) || - failed(typeConverter.convertTypes(type.getResults(), newResults)) || - failed(rewriter.convertRegionTypes(&funcOp.getFunctionBody(), - typeConverter, &result))) + failed(typeConverter.convertTypes(type.getResults(), newResults))) return failure(); + if (!funcOp.getFunctionBody().empty()) + rewriter.applySignatureConversion(&funcOp.getFunctionBody().front(), result, + &typeConverter); // Update the function signature in-place. auto newType = FunctionType::get(rewriter.getContext(), |
