From ddaf040ea924b1bdd4e093f583018c262da3cc7f Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Fri, 8 Mar 2024 10:06:24 +0900 Subject: [mlir][Transforms][NFC] Make signature conversion more efficient (#83922) During block signature conversion, a new block is inserted and ops are moved from the old block to the new block. This commit changes the implementation such that ops are moved in bulk (`splice`) instead of one-by-one; that's what `splitBlock` is doing. This also makes it possible to pass the new block argument types directly to `createBlock` instead of using `addArgument` (which bypasses the rewriter). This doesn't change anything from a technical point of view (there is no rewriter API for adding arguments at the moment), but the implementation reads a bit nicer. --- mlir/lib/Transforms/Utils/DialectConversion.cpp | 27 ++++++++++++++----------- 1 file changed, 15 insertions(+), 12 deletions(-) diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index d7dc902..8b2d714 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -1281,7 +1281,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion( ConversionPatternRewriter &rewriter, Block *block, const TypeConverter *converter, TypeConverter::SignatureConversion &signatureConversion) { - MLIRContext *ctx = rewriter.getContext(); + OpBuilder::InsertionGuard g(rewriter); // If no arguments are being changed or added, there is nothing to do. unsigned origArgCount = block->getNumArguments(); @@ -1289,14 +1289,9 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion( if (llvm::equal(block->getArgumentTypes(), convertedTypes)) return block; - // Split the block at the beginning to get a new block to use for the updated - // signature. - Block *newBlock = rewriter.splitBlock(block, block->begin()); - block->replaceAllUsesWith(newBlock); - - // Map all new arguments to the location of the argument they originate from. + // Compute the locations of all block arguments in the new block. SmallVector newLocs(convertedTypes.size(), - Builder(ctx).getUnknownLoc()); + rewriter.getUnknownLoc()); for (unsigned i = 0; i < origArgCount; ++i) { auto inputMap = signatureConversion.getInputMapping(i); if (!inputMap || inputMap->replacementValue) @@ -1306,9 +1301,16 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion( newLocs[inputMap->inputNo + j] = origLoc; } - SmallVector newArgRange( - newBlock->addArguments(convertedTypes, newLocs)); - ArrayRef newArgs(newArgRange); + // Insert a new block with the converted block argument types and move all ops + // from the old block to the new block. + Block *newBlock = + rewriter.createBlock(block->getParent(), std::next(block->getIterator()), + convertedTypes, newLocs); + appendRewrite(newBlock, block, newBlock->end()); + newBlock->getOperations().splice(newBlock->end(), block->getOperations()); + + // Replace all uses of the old block with the new block. + block->replaceAllUsesWith(newBlock); // Remap each of the original arguments as determined by the signature // conversion. @@ -1333,7 +1335,8 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion( } // Otherwise, this is a 1->1+ mapping. - auto replArgs = newArgs.slice(inputMap->inputNo, inputMap->size); + auto replArgs = + newBlock->getArguments().slice(inputMap->inputNo, inputMap->size); Value newArg; // If this is a 1->1 mapping and the types of new and replacement arguments -- cgit v1.1