aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Transforms/Utils/DialectConversion.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Transforms/Utils/DialectConversion.cpp')
-rw-r--r--mlir/lib/Transforms/Utils/DialectConversion.cpp112
1 files changed, 20 insertions, 92 deletions
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 3a23bbf..2fe0697 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -1105,10 +1105,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// A set of operations that were modified by the current pattern.
SetVector<Operation *> patternModifiedOps;
- /// A set of blocks that were inserted (newly-created blocks or moved blocks)
- /// by the current pattern.
- SetVector<Block *> patternInsertedBlocks;
-
/// A list of unresolved materializations that were created by the current
/// pattern.
DenseSet<UnrealizedConversionCastOp> patternMaterializations;
@@ -2046,8 +2042,6 @@ void ConversionPatternRewriterImpl::notifyBlockInserted(
if (!config.allowPatternRollback && config.listener)
config.listener->notifyBlockInserted(block, previous, previousIt);
- patternInsertedBlocks.insert(block);
-
if (wasDetached) {
// If the block was detached, it is most likely a newly created block.
if (config.allowPatternRollback) {
@@ -2399,17 +2393,12 @@ private:
bool canApplyPattern(Operation *op, const Pattern &pattern);
/// Legalize the resultant IR after successfully applying the given pattern.
- LogicalResult legalizePatternResult(Operation *op, const Pattern &pattern,
- const RewriterState &curState,
- const SetVector<Operation *> &newOps,
- const SetVector<Operation *> &modifiedOps,
- const SetVector<Block *> &insertedBlocks);
-
- /// Legalizes the actions registered during the execution of a pattern.
LogicalResult
- legalizePatternBlockRewrites(Operation *op,
- const SetVector<Block *> &insertedBlocks,
- const SetVector<Operation *> &newOps);
+ legalizePatternResult(Operation *op, const Pattern &pattern,
+ const RewriterState &curState,
+ const SetVector<Operation *> &newOps,
+ const SetVector<Operation *> &modifiedOps);
+
LogicalResult
legalizePatternCreatedOperations(const SetVector<Operation *> &newOps);
LogicalResult
@@ -2608,7 +2597,6 @@ LogicalResult OperationLegalizer::legalizeWithFold(Operation *op) {
auto cleanup = llvm::make_scope_exit([&]() {
rewriterImpl.patternNewOps.clear();
rewriterImpl.patternModifiedOps.clear();
- rewriterImpl.patternInsertedBlocks.clear();
});
// Upon failure, undo all changes made by the folder.
@@ -2662,24 +2650,16 @@ LogicalResult OperationLegalizer::legalizeWithFold(Operation *op) {
static void
reportNewIrLegalizationFatalError(const Pattern &pattern,
const SetVector<Operation *> &newOps,
- const SetVector<Operation *> &modifiedOps,
- const SetVector<Block *> &insertedBlocks) {
+ const SetVector<Operation *> &modifiedOps) {
auto newOpNames = llvm::map_range(
newOps, [](Operation *op) { return op->getName().getStringRef(); });
auto modifiedOpNames = llvm::map_range(
modifiedOps, [](Operation *op) { return op->getName().getStringRef(); });
- StringRef detachedBlockStr = "(detached block)";
- auto insertedBlockNames = llvm::map_range(insertedBlocks, [&](Block *block) {
- if (block->getParentOp())
- return block->getParentOp()->getName().getStringRef();
- return detachedBlockStr;
- });
- llvm::report_fatal_error(
- "pattern '" + pattern.getDebugName() +
- "' produced IR that could not be legalized. " + "new ops: {" +
- llvm::join(newOpNames, ", ") + "}, " + "modified ops: {" +
- llvm::join(modifiedOpNames, ", ") + "}, " + "inserted block into ops: {" +
- llvm::join(insertedBlockNames, ", ") + "}");
+ llvm::report_fatal_error("pattern '" + pattern.getDebugName() +
+ "' produced IR that could not be legalized. " +
+ "new ops: {" + llvm::join(newOpNames, ", ") + "}, " +
+ "modified ops: {" +
+ llvm::join(modifiedOpNames, ", ") + "}");
}
LogicalResult OperationLegalizer::legalizeWithPattern(Operation *op) {
@@ -2743,7 +2723,6 @@ LogicalResult OperationLegalizer::legalizeWithPattern(Operation *op) {
}
rewriterImpl.patternNewOps.clear();
rewriterImpl.patternModifiedOps.clear();
- rewriterImpl.patternInsertedBlocks.clear();
LLVM_DEBUG({
logFailure(rewriterImpl.logger, "pattern failed to match");
if (rewriterImpl.config.notifyCallback) {
@@ -2777,15 +2756,12 @@ LogicalResult OperationLegalizer::legalizeWithPattern(Operation *op) {
SetVector<Operation *> newOps = moveAndReset(rewriterImpl.patternNewOps);
SetVector<Operation *> modifiedOps =
moveAndReset(rewriterImpl.patternModifiedOps);
- SetVector<Block *> insertedBlocks =
- moveAndReset(rewriterImpl.patternInsertedBlocks);
- auto result = legalizePatternResult(op, pattern, curState, newOps,
- modifiedOps, insertedBlocks);
+ auto result =
+ legalizePatternResult(op, pattern, curState, newOps, modifiedOps);
appliedPatterns.erase(&pattern);
if (failed(result)) {
if (!rewriterImpl.config.allowPatternRollback)
- reportNewIrLegalizationFatalError(pattern, newOps, modifiedOps,
- insertedBlocks);
+ reportNewIrLegalizationFatalError(pattern, newOps, modifiedOps);
rewriterImpl.resetState(curState, pattern.getDebugName());
}
if (config.listener)
@@ -2823,8 +2799,7 @@ bool OperationLegalizer::canApplyPattern(Operation *op,
LogicalResult OperationLegalizer::legalizePatternResult(
Operation *op, const Pattern &pattern, const RewriterState &curState,
const SetVector<Operation *> &newOps,
- const SetVector<Operation *> &modifiedOps,
- const SetVector<Block *> &insertedBlocks) {
+ const SetVector<Operation *> &modifiedOps) {
[[maybe_unused]] auto &impl = rewriter.getImpl();
assert(impl.pendingRootUpdates.empty() && "dangling root updates");
@@ -2843,8 +2818,7 @@ LogicalResult OperationLegalizer::legalizePatternResult(
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
// Legalize each of the actions registered during application.
- if (failed(legalizePatternBlockRewrites(op, insertedBlocks, newOps)) ||
- failed(legalizePatternRootUpdates(modifiedOps)) ||
+ if (failed(legalizePatternRootUpdates(modifiedOps)) ||
failed(legalizePatternCreatedOperations(newOps))) {
return failure();
}
@@ -2853,53 +2827,6 @@ LogicalResult OperationLegalizer::legalizePatternResult(
return success();
}
-LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
- Operation *op, const SetVector<Block *> &insertedBlocks,
- const SetVector<Operation *> &newOps) {
- ConversionPatternRewriterImpl &impl = rewriter.getImpl();
- SmallPtrSet<Operation *, 16> alreadyLegalized;
-
- // If the pattern moved or created any blocks, make sure the types of block
- // arguments get legalized.
- for (Block *block : insertedBlocks) {
- if (impl.erasedBlocks.contains(block))
- continue;
-
- // Only check blocks outside of the current operation.
- Operation *parentOp = block->getParentOp();
- if (!parentOp || parentOp == op || block->getNumArguments() == 0)
- continue;
-
- // If the region of the block has a type converter, try to convert the block
- // directly.
- if (auto *converter = impl.regionToConverter.lookup(block->getParent())) {
- std::optional<TypeConverter::SignatureConversion> conversion =
- converter->convertBlockSignature(block);
- if (!conversion) {
- LLVM_DEBUG(logFailure(impl.logger, "failed to convert types of moved "
- "block"));
- return failure();
- }
- impl.applySignatureConversion(block, converter, *conversion);
- continue;
- }
-
- // Otherwise, try to legalize the parent operation if it was not generated
- // by this pattern. This is because we will attempt to legalize the parent
- // operation, and blocks in regions created by this pattern will already be
- // legalized later on.
- if (!newOps.count(parentOp) && alreadyLegalized.insert(parentOp).second) {
- if (failed(legalize(parentOp))) {
- LLVM_DEBUG(logFailure(
- impl.logger, "operation '{0}'({1}) became illegal after rewrite",
- parentOp->getName(), parentOp));
- return failure();
- }
- }
- }
- return success();
-}
-
LogicalResult OperationLegalizer::legalizePatternCreatedOperations(
const SetVector<Operation *> &newOps) {
for (Operation *op : newOps) {
@@ -3800,10 +3727,11 @@ static LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp,
TypeConverter::SignatureConversion result(type.getNumInputs());
SmallVector<Type, 1> newResults;
if (failed(typeConverter.convertSignatureArgs(type.getInputs(), result)) ||
- failed(typeConverter.convertTypes(type.getResults(), newResults)) ||
- failed(rewriter.convertRegionTypes(&funcOp.getFunctionBody(),
- typeConverter, &result)))
+ failed(typeConverter.convertTypes(type.getResults(), newResults)))
return failure();
+ if (!funcOp.getFunctionBody().empty())
+ rewriter.applySignatureConversion(&funcOp.getFunctionBody().front(), result,
+ &typeConverter);
// Update the function signature in-place.
auto newType = FunctionType::get(rewriter.getContext(),