aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMatthias Springer <me@m-sp.org>2024-06-10 21:49:52 +0200
committerGitHub <noreply@github.com>2024-06-10 21:49:52 +0200
commit52050f3ff388773b9345d421d968a7d1ee880531 (patch)
treecc81dd362789591e2ba149bc954425db1a6ad8f0
parent65310f34d7edf7924ca4cbe7df836770669f70dc (diff)
downloadllvm-52050f3ff388773b9345d421d968a7d1ee880531.zip
llvm-52050f3ff388773b9345d421d968a7d1ee880531.tar.gz
llvm-52050f3ff388773b9345d421d968a7d1ee880531.tar.bz2
[mlir][Transforms] Dialect Conversion: Simplify block conversion API (#94866)
This commit simplifies and improves documentation for the part of the `ConversionPatternRewriter` API that deals with signature conversions. There are now two public functions for signature conversion: * `applySignatureConversion` converts a single block signature. This function used to take a `Region *` (but converted only the entry block). It now takes a `Block *`. * `convertRegionTypes` converts all block signatures of a region. `convertNonEntryRegionTypes` is removed because it is not widely used and can easily be expressed with a call to `applySignatureConversion` inside a loop. (See `Detensorize.cpp` for an example.) Note: For consistency, `convertRegionTypes` could be renamed to `applySignatureConversion` (overload) in the future. (Or `applySignatureConversion` renamed to `convertBlockTypes`.) Also clarify when a type converter and/or signature conversion object is needed and for what purpose. Internal code refactoring (NFC) of `ConversionPatternRewriterImpl` (the part that deals with signature conversions). This part of the codebase was quite convoluted and unintuitive. From a functional perspective, this change is NFC. However, the public API changes, thus not marking as NFC. Note for LLVM integration: When you see `applySignatureConversion(region, ...)`, replace with `applySignatureConversion(region->front(), ...)`. In the unlikely case that you see `convertNonEntryRegionTypes`, apply the same changes as this commit did to `Detensorize.cpp`. --------- Co-authored-by: Markus Böck <markus.boeck02@gmail.com>
-rw-r--r--mlir/docs/DialectConversion.md30
-rw-r--r--mlir/include/mlir/Transforms/DialectConversion.h49
-rw-r--r--mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp2
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp20
-rw-r--r--mlir/lib/Transforms/Utils/DialectConversion.cpp123
-rw-r--r--mlir/test/lib/Dialect/Test/TestPatterns.cpp5
6 files changed, 81 insertions, 148 deletions
diff --git a/mlir/docs/DialectConversion.md b/mlir/docs/DialectConversion.md
index a355d5a..69781bb 100644
--- a/mlir/docs/DialectConversion.md
+++ b/mlir/docs/DialectConversion.md
@@ -372,19 +372,23 @@ class TypeConverter {
From the perspective of type conversion, the types of block arguments are a bit
special. Throughout the conversion process, blocks may move between regions of
different operations. Given this, the conversion of the types for blocks must be
-done explicitly via a conversion pattern. To convert the types of block
-arguments within a Region, a custom hook on the `ConversionPatternRewriter` must
-be invoked; `convertRegionTypes`. This hook uses a provided type converter to
-apply type conversions to all blocks within a given region, and all blocks that
-move into that region. As noted above, the conversions performed by this method
-use the argument materialization hook on the `TypeConverter`. This hook also
-takes an optional `TypeConverter::SignatureConversion` parameter that applies a
-custom conversion to the entry block of the region. The types of the entry block
-arguments are often tied semantically to details on the operation, e.g. func::FuncOp,
-AffineForOp, etc. To convert the signature of just the region entry block, and
-not any other blocks within the region, the `applySignatureConversion` hook may
-be used instead. A signature conversion, `TypeConverter::SignatureConversion`,
-can be built programmatically:
+done explicitly via a conversion pattern.
+
+To convert the types of block arguments within a Region, a custom hook on the
+`ConversionPatternRewriter` must be invoked; `convertRegionTypes`. This hook
+uses a provided type converter to apply type conversions to all blocks of a
+given region. As noted above, the conversions performed by this method use the
+argument materialization hook on the `TypeConverter`. This hook also takes an
+optional `TypeConverter::SignatureConversion` parameter that applies a custom
+conversion to the entry block of the region. The types of the entry block
+arguments are often tied semantically to the operation, e.g.,
+`func::FuncOp`, `AffineForOp`, etc.
+
+To convert the signature of just one given block, the
+`applySignatureConversion` hook can be used.
+
+A signature conversion, `TypeConverter::SignatureConversion`, can be built
+programmatically:
```c++
class SignatureConversion {
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index f6c5149..f83f3a3 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -247,7 +247,8 @@ public:
/// Attempts a 1-1 type conversion, expecting the result type to be
/// `TargetType`. Returns the converted type cast to `TargetType` on success,
/// and a null type on conversion or cast failure.
- template <typename TargetType> TargetType convertType(Type t) const {
+ template <typename TargetType>
+ TargetType convertType(Type t) const {
return dyn_cast_or_null<TargetType>(convertType(t));
}
@@ -661,42 +662,42 @@ class ConversionPatternRewriter final : public PatternRewriter {
public:
~ConversionPatternRewriter() override;
- /// Apply a signature conversion to the entry block of the given region. This
- /// replaces the entry block with a new block containing the updated
- /// signature. The new entry block to the region is returned for convenience.
- /// If no block argument types are changing, the entry original block will be
+ /// Apply a signature conversion to given block. This replaces the block with
+ /// a new block containing the updated signature. The operations of the given
+ /// block are inlined into the newly-created block, which is returned.
+ ///
+ /// If no block argument types are changing, the original block will be
/// left in place and returned.
///
- /// If provided, `converter` will be used for any materializations.
+ /// A signature converison must be provided. (Type converters can construct
+ /// a signature conversion with `convertBlockSignature`.)
+ ///
+ /// Optionally, a type converter can be provided to build materializations.
+ /// Note: If no type converter was provided or the type converter does not
+ /// specify any suitable argument/target materialization rules, the dialect
+ /// conversion may fail to legalize unresolved materializations.
Block *
- applySignatureConversion(Region *region,
+ applySignatureConversion(Block *block,
TypeConverter::SignatureConversion &conversion,
const TypeConverter *converter = nullptr);
- /// Convert the types of block arguments within the given region. This
+ /// Apply a signature conversion to each block in the given region. This
/// replaces each block with a new block containing the updated signature. If
/// an updated signature would match the current signature, the respective
- /// block is left in place as is.
+ /// block is left in place as is. (See `applySignatureConversion` for
+ /// details.) The new entry block of the region is returned.
+ ///
+ /// SignatureConversions are computed with the specified type converter.
+ /// This function returns "failure" if the type converter failed to compute
+ /// a SignatureConversion for at least one block.
///
- /// The entry block may have a special conversion if `entryConversion` is
- /// provided. On success, the new entry block to the region is returned for
- /// convenience. Otherwise, failure is returned.
+ /// Optionally, a special SignatureConversion can be specified for the entry
+ /// block. This is because the types of the entry block arguments are often
+ /// tied semantically to the operation.
FailureOr<Block *> convertRegionTypes(
Region *region, const TypeConverter &converter,
TypeConverter::SignatureConversion *entryConversion = nullptr);
- /// Convert the types of block arguments within the given region except for
- /// the entry region. This replaces each non-entry block with a new block
- /// containing the updated signature. If an updated signature would match the
- /// current signature, the respective block is left in place as is.
- ///
- /// If special conversion behavior is needed for the non-entry blocks (for
- /// example, we need to convert only a subset of a BB arguments), such
- /// behavior can be specified in blockConversions.
- LogicalResult convertNonEntryRegionTypes(
- Region *region, const TypeConverter &converter,
- ArrayRef<TypeConverter::SignatureConversion> blockConversions);
-
/// Replace all the uses of the block argument `from` with value `to`.
void replaceUsesOfBlockArgument(BlockArgument from, Value to);
diff --git a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
index d90cf93..f62de1f 100644
--- a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
+++ b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
@@ -162,7 +162,7 @@ struct ForOpConversion final : SCFToSPIRVPattern<scf::ForOp> {
signatureConverter.remapInput(0, newIndVar);
for (unsigned i = 1, e = body->getNumArguments(); i < e; i++)
signatureConverter.remapInput(i, header->getArgument(i));
- body = rewriter.applySignatureConversion(&forOp.getRegion(),
+ body = rewriter.applySignatureConversion(&forOp.getRegion().front(),
signatureConverter);
// Move the blocks from the forOp into the loopOp. This is the body of the
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
index 2296809..af38485 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
@@ -106,27 +106,23 @@ struct FunctionNonEntryBlockConversion
ConversionPatternRewriter &rewriter) const override {
rewriter.startOpModification(op);
Region &region = op.getFunctionBody();
- SmallVector<TypeConverter::SignatureConversion, 2> conversions;
- for (Block &block : llvm::drop_begin(region, 1)) {
- conversions.emplace_back(block.getNumArguments());
- TypeConverter::SignatureConversion &back = conversions.back();
+ for (Block &block :
+ llvm::make_early_inc_range(llvm::drop_begin(region, 1))) {
+ TypeConverter::SignatureConversion conversion(
+ /*numOrigInputs=*/block.getNumArguments());
for (BlockArgument blockArgument : block.getArguments()) {
int idx = blockArgument.getArgNumber();
if (blockArgsToDetensor.count(blockArgument))
- back.addInputs(idx, {getTypeConverter()->convertType(
- block.getArgumentTypes()[idx])});
+ conversion.addInputs(idx, {getTypeConverter()->convertType(
+ block.getArgumentTypes()[idx])});
else
- back.addInputs(idx, {block.getArgumentTypes()[idx]});
+ conversion.addInputs(idx, {block.getArgumentTypes()[idx]});
}
- }
- if (failed(rewriter.convertNonEntryRegionTypes(&region, *typeConverter,
- conversions))) {
- rewriter.cancelOpModification(op);
- return failure();
+ rewriter.applySignatureConversion(&block, conversion, getTypeConverter());
}
rewriter.finalizeOpModification(op);
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index d407d60..2f0efe1 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -839,27 +839,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
// Type Conversion
//===--------------------------------------------------------------------===//
- /// Attempt to convert the signature of the given block, if successful a new
- /// block is returned containing the new arguments. Returns `block` if it did
- /// not require conversion.
- FailureOr<Block *> convertBlockSignature(
- ConversionPatternRewriter &rewriter, Block *block,
- const TypeConverter *converter,
- TypeConverter::SignatureConversion *conversion = nullptr);
-
- /// Convert the types of non-entry block arguments within the given region.
- LogicalResult convertNonEntryRegionTypes(
- ConversionPatternRewriter &rewriter, Region *region,
- const TypeConverter &converter,
- ArrayRef<TypeConverter::SignatureConversion> blockConversions = {});
-
- /// Apply a signature conversion on the given region, using `converter` for
- /// materializations if not null.
- Block *
- applySignatureConversion(ConversionPatternRewriter &rewriter, Region *region,
- TypeConverter::SignatureConversion &conversion,
- const TypeConverter *converter);
-
/// Convert the types of block arguments within the given region.
FailureOr<Block *>
convertRegionTypes(ConversionPatternRewriter &rewriter, Region *region,
@@ -1294,34 +1273,6 @@ bool ConversionPatternRewriterImpl::wasOpReplaced(Operation *op) const {
//===----------------------------------------------------------------------===//
// Type Conversion
-FailureOr<Block *> ConversionPatternRewriterImpl::convertBlockSignature(
- ConversionPatternRewriter &rewriter, Block *block,
- const TypeConverter *converter,
- TypeConverter::SignatureConversion *conversion) {
- if (conversion)
- return applySignatureConversion(rewriter, block, converter, *conversion);
-
- // If a converter wasn't provided, and the block wasn't already converted,
- // there is nothing we can do.
- if (!converter)
- return failure();
-
- // Try to convert the signature for the block with the provided converter.
- if (auto conversion = converter->convertBlockSignature(block))
- return applySignatureConversion(rewriter, block, converter, *conversion);
- return failure();
-}
-
-Block *ConversionPatternRewriterImpl::applySignatureConversion(
- ConversionPatternRewriter &rewriter, Region *region,
- TypeConverter::SignatureConversion &conversion,
- const TypeConverter *converter) {
- if (!region->empty())
- return *convertBlockSignature(rewriter, &region->front(), converter,
- &conversion);
- return nullptr;
-}
-
FailureOr<Block *> ConversionPatternRewriterImpl::convertRegionTypes(
ConversionPatternRewriter &rewriter, Region *region,
const TypeConverter &converter,
@@ -1330,42 +1281,29 @@ FailureOr<Block *> ConversionPatternRewriterImpl::convertRegionTypes(
if (region->empty())
return nullptr;
- if (failed(convertNonEntryRegionTypes(rewriter, region, converter)))
- return failure();
-
- FailureOr<Block *> newEntry = convertBlockSignature(
- rewriter, &region->front(), &converter, entryConversion);
- return newEntry;
-}
-
-LogicalResult ConversionPatternRewriterImpl::convertNonEntryRegionTypes(
- ConversionPatternRewriter &rewriter, Region *region,
- const TypeConverter &converter,
- ArrayRef<TypeConverter::SignatureConversion> blockConversions) {
- regionToConverter[region] = &converter;
- if (region->empty())
- return success();
-
- // Convert the arguments of each block within the region.
- int blockIdx = 0;
- assert((blockConversions.empty() ||
- blockConversions.size() == region->getBlocks().size() - 1) &&
- "expected either to provide no SignatureConversions at all or to "
- "provide a SignatureConversion for each non-entry block");
-
+ // Convert the arguments of each non-entry block within the region.
for (Block &block :
llvm::make_early_inc_range(llvm::drop_begin(*region, 1))) {
- TypeConverter::SignatureConversion *blockConversion =
- blockConversions.empty()
- ? nullptr
- : const_cast<TypeConverter::SignatureConversion *>(
- &blockConversions[blockIdx++]);
-
- if (failed(convertBlockSignature(rewriter, &block, &converter,
- blockConversion)))
+ // Compute the signature for the block with the provided converter.
+ std::optional<TypeConverter::SignatureConversion> conversion =
+ converter.convertBlockSignature(&block);
+ if (!conversion)
return failure();
- }
- return success();
+ // Convert the block with the computed signature.
+ applySignatureConversion(rewriter, &block, &converter, *conversion);
+ }
+
+ // Convert the entry block. If an entry signature conversion was provided,
+ // use that one. Otherwise, compute the signature with the type converter.
+ if (entryConversion)
+ return applySignatureConversion(rewriter, &region->front(), &converter,
+ *entryConversion);
+ std::optional<TypeConverter::SignatureConversion> conversion =
+ converter.convertBlockSignature(&region->front());
+ if (!conversion)
+ return failure();
+ return applySignatureConversion(rewriter, &region->front(), &converter,
+ *conversion);
}
Block *ConversionPatternRewriterImpl::applySignatureConversion(
@@ -1676,12 +1614,12 @@ void ConversionPatternRewriter::eraseBlock(Block *block) {
}
Block *ConversionPatternRewriter::applySignatureConversion(
- Region *region, TypeConverter::SignatureConversion &conversion,
+ Block *block, TypeConverter::SignatureConversion &conversion,
const TypeConverter *converter) {
- assert(!impl->wasOpReplaced(region->getParentOp()) &&
+ assert(!impl->wasOpReplaced(block->getParentOp()) &&
"attempting to apply a signature conversion to a block within a "
"replaced/erased op");
- return impl->applySignatureConversion(*this, region, conversion, converter);
+ return impl->applySignatureConversion(*this, block, converter, conversion);
}
FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes(
@@ -1693,16 +1631,6 @@ FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes(
return impl->convertRegionTypes(*this, region, converter, entryConversion);
}
-LogicalResult ConversionPatternRewriter::convertNonEntryRegionTypes(
- Region *region, const TypeConverter &converter,
- ArrayRef<TypeConverter::SignatureConversion> blockConversions) {
- assert(!impl->wasOpReplaced(region->getParentOp()) &&
- "attempting to apply a signature conversion to a block within a "
- "replaced/erased op");
- return impl->convertNonEntryRegionTypes(*this, region, converter,
- blockConversions);
-}
-
void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
Value to) {
LLVM_DEBUG({
@@ -2231,11 +2159,14 @@ LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
// If the region of the block has a type converter, try to convert the block
// directly.
if (auto *converter = impl.regionToConverter.lookup(block->getParent())) {
- if (failed(impl.convertBlockSignature(rewriter, block, converter))) {
+ std::optional<TypeConverter::SignatureConversion> conversion =
+ converter->convertBlockSignature(block);
+ if (!conversion) {
LLVM_DEBUG(logFailure(impl.logger, "failed to convert types of moved "
"block"));
return failure();
}
+ impl.applySignatureConversion(rewriter, block, converter, *conversion);
continue;
}
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index f9f7d4e..a14a5da 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -1516,8 +1516,9 @@ struct TestTestSignatureConversionNoConverter
if (failed(
converter.convertSignatureArgs(entry->getArgumentTypes(), result)))
return failure();
- rewriter.modifyOpInPlace(
- op, [&] { rewriter.applySignatureConversion(&region, result); });
+ rewriter.modifyOpInPlace(op, [&] {
+ rewriter.applySignatureConversion(&region.front(), result);
+ });
return success();
}