From f1e0657d144f5a3cfef4b625d0f875f4dacd21d1 Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Tue, 25 Jun 2024 08:43:28 +0200 Subject: [mlir][Transforms] Dialect conversion: Simplify handling of dropped arguments (#96207) This commit simplifies the handling of dropped arguments and updates some dialect conversion documentation that is outdated. When converting a block signature, a `BlockTypeConversionRewrite` object and potentially multiple `ReplaceBlockArgRewrite` are created. During the "commit" phase, uses of the old block arguments are replaced with the new block arguments, but the old implementation was written in an inconsistent way: some block arguments were replaced in `BlockTypeConversionRewrite::commit` and some were replaced in `ReplaceBlockArgRewrite::commit`. The new `BlockTypeConversionRewrite::commit` implementation is much simpler and no longer modifies any IR; that is done only in `ReplaceBlockArgRewrite` now. The `ConvertedArgInfo` data structure is no longer needed. To that end, materializations of dropped arguments are now built in `applySignatureConversion` instead of `materializeLiveConversions`; the latter function no longer has to deal with dropped arguments. Other minor improvements: - Improve variable name: `origOutputType` -> `origArgType`. Add an assertion to check that this field is only used for argument materializations. - Add more comments to `applySignatureConversion`. Note: Error messages around failed materializations for dropped basic block arguments changed slightly. That is because those materializations are now built in `legalizeUnresolvedMaterialization` instead of `legalizeConvertedArgumentTypes`. This commit is in preparation of decoupling argument/source/target materializations from the dialect conversion. --- mlir/docs/DialectConversion.md | 44 +++-- mlir/include/mlir/Transforms/DialectConversion.h | 17 +- mlir/lib/Transforms/Utils/DialectConversion.cpp | 208 ++++++++------------- .../Transforms/test-legalize-type-conversion.mlir | 6 +- 4 files changed, 119 insertions(+), 156 deletions(-) diff --git a/mlir/docs/DialectConversion.md b/mlir/docs/DialectConversion.md index 69781bb..3002ac00 100644 --- a/mlir/docs/DialectConversion.md +++ b/mlir/docs/DialectConversion.md @@ -246,6 +246,13 @@ depending on the situation. - An argument materialization is used when converting the type of a block argument during a [signature conversion](#region-signature-conversion). + The new block argument types are specified in a `SignatureConversion` + object. An original block argument can be converted into multiple + block arguments, which is not supported everywhere in the dialect + conversion. (E.g., adaptors support only a single replacement value for + each original value.) Therefore, an argument materialization is used to + convert potentially multiple new block arguments back into a single SSA + value. * Source Materialization @@ -259,6 +266,9 @@ depending on the situation. * When a block argument has been converted to a different type, but the original argument still has users that will remain live after the conversion process has finished. + * When a block argument has been dropped, but the argument still has + users that will remain live after the conversion process has + finished. * When the result type of an operation has been converted to a different type, but the original result still has users that will remain live after the conversion process is finished. @@ -328,19 +338,22 @@ class TypeConverter { registerConversion(wrapCallback(std::forward(callback))); } - /// Register a materialization function, which must be convertible to the - /// following form: - /// `Optional (OpBuilder &, T, ValueRange, Location)`, - /// where `T` is any subclass of `Type`. - /// This function is responsible for creating an operation, using the - /// OpBuilder and Location provided, that "converts" a range of values into a - /// single value of the given type `T`. It must return a Value of the - /// converted type on success, an `std::nullopt` if it failed but other - /// materialization can be attempted, and `nullptr` on unrecoverable failure. - /// It will only be called for (sub)types of `T`. - /// + /// All of the following materializations require function objects that are + /// convertible to the following form: + /// `std::optional(OpBuilder &, T, ValueRange, Location)`, + /// where `T` is any subclass of `Type`. This function is responsible for + /// creating an operation, using the OpBuilder and Location provided, that + /// "casts" a range of values into a single value of the given type `T`. It + /// must return a Value of the converted type on success, an `std::nullopt` if + /// it failed but other materialization can be attempted, and `nullptr` on + /// unrecoverable failure. It will only be called for (sub)types of `T`. + /// Materialization functions must be provided when a type conversion may + /// persist after the conversion has finished. + /// This method registers a materialization that will be called when - /// converting an illegal block argument type, to a legal type. + /// converting (potentially multiple) block arguments that were the result of + /// a signature conversion of a single block argument, to a single SSA value + /// of a legal type. template ::template arg_t<1>> void addArgumentMaterialization(FnT &&callback) { @@ -348,8 +361,9 @@ class TypeConverter { wrapMaterialization(std::forward(callback))); } /// This method registers a materialization that will be called when - /// converting a legal type to an illegal source type. This is used when - /// conversions to an illegal type must persist beyond the main conversion. + /// converting a legal replacement value back to an illegal source type. + /// This is used when some uses of the original, illegal value must persist + /// beyond the main conversion. template ::template arg_t<1>> void addSourceMaterialization(FnT &&callback) { @@ -357,7 +371,7 @@ class TypeConverter { wrapMaterialization(std::forward(callback))); } /// This method registers a materialization that will be called when - /// converting type from an illegal, or source, type to a legal type. + /// converting an illegal (source) value to a legal (target) type. template ::template arg_t<1>> void addTargetMaterialization(FnT &&callback) { diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h index f83f3a3..e17f6f6 100644 --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -168,8 +168,8 @@ public: registerConversion(wrapCallback(std::forward(callback))); } - /// Register a materialization function, which must be convertible to the - /// following form: + /// All of the following materializations require function objects that are + /// convertible to the following form: /// `std::optional(OpBuilder &, T, ValueRange, Location)`, /// where `T` is any subclass of `Type`. This function is responsible for /// creating an operation, using the OpBuilder and Location provided, that @@ -179,9 +179,11 @@ public: /// unrecoverable failure. It will only be called for (sub)types of `T`. /// Materialization functions must be provided when a type conversion may /// persist after the conversion has finished. - /// + /// This method registers a materialization that will be called when - /// converting an illegal block argument type, to a legal type. + /// converting (potentially multiple) block arguments that were the result of + /// a signature conversion of a single block argument, to a single SSA value + /// of a legal type. template >::template arg_t<1>> void addArgumentMaterialization(FnT &&callback) { @@ -189,8 +191,9 @@ public: wrapMaterialization(std::forward(callback))); } /// This method registers a materialization that will be called when - /// converting a legal type to an illegal source type. This is used when - /// conversions to an illegal type must persist beyond the main conversion. + /// converting a legal replacement value back to an illegal source type. + /// This is used when some uses of the original, illegal value must persist + /// beyond the main conversion. template >::template arg_t<1>> void addSourceMaterialization(FnT &&callback) { @@ -198,7 +201,7 @@ public: wrapMaterialization(std::forward(callback))); } /// This method registers a materialization that will be called when - /// converting type from an illegal, or source, type to a legal type. + /// converting an illegal (source) value to a legal (target) type. template >::template arg_t<1>> void addTargetMaterialization(FnT &&callback) { diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index e6c0ee2..07ebd68 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -432,34 +432,14 @@ private: Block *insertBeforeBlock; }; -/// This structure contains the information pertaining to an argument that has -/// been converted. -struct ConvertedArgInfo { - ConvertedArgInfo(unsigned newArgIdx, unsigned newArgSize, - Value castValue = nullptr) - : newArgIdx(newArgIdx), newArgSize(newArgSize), castValue(castValue) {} - - /// The start index of in the new argument list that contains arguments that - /// replace the original. - unsigned newArgIdx; - - /// The number of arguments that replaced the original argument. - unsigned newArgSize; - - /// The cast value that was created to cast from the new arguments to the - /// old. This only used if 'newArgSize' > 1. - Value castValue; -}; - /// Block type conversion. This rewrite is partially reflected in the IR. class BlockTypeConversionRewrite : public BlockRewrite { public: - BlockTypeConversionRewrite( - ConversionPatternRewriterImpl &rewriterImpl, Block *block, - Block *origBlock, SmallVector, 1> argInfo, - const TypeConverter *converter) + BlockTypeConversionRewrite(ConversionPatternRewriterImpl &rewriterImpl, + Block *block, Block *origBlock, + const TypeConverter *converter) : BlockRewrite(Kind::BlockTypeConversion, rewriterImpl, block), - origBlock(origBlock), argInfo(argInfo), converter(converter) {} + origBlock(origBlock), converter(converter) {} static bool classof(const IRRewrite *rewrite) { return rewrite->getKind() == Kind::BlockTypeConversion; @@ -479,10 +459,6 @@ private: /// The original block that was requested to have its signature converted. Block *origBlock; - /// The conversion information for each of the arguments. The information is - /// std::nullopt if the argument was dropped during conversion. - SmallVector, 1> argInfo; - /// The type converter used to convert the arguments. const TypeConverter *converter; }; @@ -696,7 +672,11 @@ enum MaterializationKind { /// This materialization materializes a conversion from an illegal type to a /// legal one. - Target + Target, + + /// This materialization materializes a conversion from a legal type back to + /// an illegal one. + Source }; /// An unresolved materialization, i.e., a "builtin.unrealized_conversion_cast" @@ -708,9 +688,13 @@ public: ConversionPatternRewriterImpl &rewriterImpl, UnrealizedConversionCastOp op, const TypeConverter *converter = nullptr, MaterializationKind kind = MaterializationKind::Target, - Type origOutputType = nullptr) + Type origArgType = nullptr) : OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op), - converterAndKind(converter, kind), origOutputType(origOutputType) {} + converterAndKind(converter, kind), origArgType(origArgType) { + assert(kind == MaterializationKind::Argument || + !origArgType && "orginal argument type make sense only for argument " + "materializations"); + } static bool classof(const IRRewrite *rewrite) { return rewrite->getKind() == Kind::UnresolvedMaterialization; @@ -734,17 +718,17 @@ public: return converterAndKind.getInt(); } - /// Return the original illegal output type of the input values. - Type getOrigOutputType() const { return origOutputType; } + /// Return the original type of the block argument. + Type getOrigArgType() const { return origArgType; } private: /// The corresponding type converter to use when resolving this /// materialization, and the kind of this materialization. - llvm::PointerIntPair + llvm::PointerIntPair converterAndKind; /// The original output type. This is only used for argument conversions. - Type origOutputType; + Type origArgType; }; } // namespace @@ -862,13 +846,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { ValueRange inputs, Type outputType, Type origOutputType, const TypeConverter *converter); - - Value buildUnresolvedArgumentMaterialization(Block *block, Location loc, - ValueRange inputs, - Type origOutputType, - Type outputType, - const TypeConverter *converter); - Value buildUnresolvedTargetMaterialization(Location loc, Value input, Type outputType, const TypeConverter *converter); @@ -998,28 +975,6 @@ void BlockTypeConversionRewrite::commit(RewriterBase &rewriter) { dyn_cast_or_null(rewriter.getListener())) for (Operation *op : block->getUsers()) listener->notifyOperationModified(op); - - // Process the remapping for each of the original arguments. - for (auto [origArg, info] : - llvm::zip_equal(origBlock->getArguments(), argInfo)) { - // Handle the case of a 1->0 value mapping. - if (!info) { - if (Value newArg = - rewriterImpl.mapping.lookupOrNull(origArg, origArg.getType())) - rewriter.replaceAllUsesWith(origArg, newArg); - continue; - } - - // Otherwise this is a 1->1+ value mapping. - Value castValue = info->castValue; - assert(info->newArgSize >= 1 && castValue && "expected 1->1+ mapping"); - - // If the argument is still used, replace it with the generated cast. - if (!origArg.use_empty()) { - rewriter.replaceAllUsesWith(origArg, rewriterImpl.mapping.lookupOrDefault( - castValue, origArg.getType())); - } - } } void BlockTypeConversionRewrite::rollback() { @@ -1043,15 +998,13 @@ LogicalResult BlockTypeConversionRewrite::materializeLiveConversions( if (!liveUser) continue; - Value replacementValue = rewriterImpl.mapping.lookupOrDefault(origArg); - bool isDroppedArg = replacementValue == origArg; - if (!isDroppedArg) - builder.setInsertionPointAfterValue(replacementValue); + Value replacementValue = rewriterImpl.mapping.lookupOrNull(origArg); + assert(replacementValue && "replacement value not found"); Value newArg; if (converter) { + builder.setInsertionPointAfterValue(replacementValue); newArg = converter->materializeSourceConversion( - builder, origArg.getLoc(), origArg.getType(), - isDroppedArg ? ValueRange() : ValueRange(replacementValue)); + builder, origArg.getLoc(), origArg.getType(), replacementValue); assert((!newArg || newArg.getType() == origArg.getType()) && "materialization hook did not provide a value of the expected " "type"); @@ -1062,8 +1015,6 @@ LogicalResult BlockTypeConversionRewrite::materializeLiveConversions( << "failed to materialize conversion for block argument #" << it.index() << " that remained live after conversion, type was " << origArg.getType(); - if (!isDroppedArg) - diag << ", with target type " << replacementValue.getType(); diag.attachNote(liveUser->getLoc()) << "see existing live user here: " << *liveUser; return failure(); @@ -1349,65 +1300,65 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion( // Replace all uses of the old block with the new block. block->replaceAllUsesWith(newBlock); - // Remap each of the original arguments as determined by the signature - // conversion. - SmallVector, 1> argInfo; - argInfo.resize(origArgCount); - for (unsigned i = 0; i != origArgCount; ++i) { - auto inputMap = signatureConversion.getInputMapping(i); - if (!inputMap) - continue; BlockArgument origArg = block->getArgument(i); + Type origArgType = origArg.getType(); - // If inputMap->replacementValue is not nullptr, then the argument is - // dropped and a replacement value is provided to be the remappedValue. - if (inputMap->replacementValue) { - assert(inputMap->size == 0 && - "invalid to provide a replacement value when the argument isn't " - "dropped"); - mapping.map(origArg, inputMap->replacementValue); - appendRewrite(block, origArg); - continue; - } - - // Otherwise, this is a 1->1+ mapping. - auto replArgs = - newBlock->getArguments().slice(inputMap->inputNo, inputMap->size); - Value newArg; - - // If this is a 1->1 mapping and the types of new and replacement arguments - // match (i.e. it's an identity map), then the argument is mapped to its - // original type. + // Helper function that tries to legalize the given type. Returns the given + // type if it could not be legalized. // FIXME: We simply pass through the replacement argument if there wasn't a // converter, which isn't great as it allows implicit type conversions to // appear. We should properly restructure this code to handle cases where a // converter isn't provided and also to properly handle the case where an // argument materialization is actually a temporary source materialization // (e.g. in the case of 1->N). - if (replArgs.size() == 1 && - (!converter || replArgs[0].getType() == origArg.getType())) { - newArg = replArgs.front(); - } else { - Type origOutputType = origArg.getType(); + auto tryLegalizeType = [&](Type type) { + if (converter) + if (Type t = converter->convertType(type)) + return t; + return type; + }; - // Legalize the argument output type. - Type outputType = origOutputType; - if (Type legalOutputType = converter->convertType(outputType)) - outputType = legalOutputType; + std::optional inputMap = + signatureConversion.getInputMapping(i); + if (!inputMap) { + // This block argument was dropped and no replacement value was provided. + // Materialize a replacement value "out of thin air". + Value repl = buildUnresolvedMaterialization( + MaterializationKind::Source, newBlock, newBlock->begin(), + origArg.getLoc(), /*inputs=*/ValueRange(), + /*outputType=*/origArgType, /*origArgType=*/{}, converter); + mapping.map(origArg, repl); + appendRewrite(block, origArg); + continue; + } - newArg = buildUnresolvedArgumentMaterialization( - newBlock, origArg.getLoc(), replArgs, origOutputType, outputType, - converter); + if (Value repl = inputMap->replacementValue) { + // This block argument was dropped and a replacement value was provided. + assert(inputMap->size == 0 && + "invalid to provide a replacement value when the argument isn't " + "dropped"); + mapping.map(origArg, repl); + appendRewrite(block, origArg); + continue; } - mapping.map(origArg, newArg); + // This is a 1->1+ mapping. 1->N mappings are not fully supported in the + // dialect conversion. Therefore, we need an argument materialization to + // turn the replacement block arguments into a single SSA value that can be + // used as a replacement. The type of this SSA value is the legalized + // version of the original block argument type. + auto replArgs = + newBlock->getArguments().slice(inputMap->inputNo, inputMap->size); + Value repl = buildUnresolvedMaterialization( + MaterializationKind::Argument, newBlock, newBlock->begin(), + origArg.getLoc(), /*inputs=*/replArgs, + /*outputType=*/tryLegalizeType(origArgType), origArgType, converter); + mapping.map(origArg, repl); appendRewrite(block, origArg); - argInfo[i] = ConvertedArgInfo(inputMap->inputNo, inputMap->size, newArg); } - appendRewrite(newBlock, block, argInfo, - converter); + appendRewrite(newBlock, block, converter); // Erase the old block. (It is just unlinked for now and will be erased during // cleanup.) @@ -1424,7 +1375,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion( /// of input operands. Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization( MaterializationKind kind, Block *insertBlock, Block::iterator insertPt, - Location loc, ValueRange inputs, Type outputType, Type origOutputType, + Location loc, ValueRange inputs, Type outputType, Type origArgType, const TypeConverter *converter) { // Avoid materializing an unnecessary cast. if (inputs.size() == 1 && inputs.front().getType() == outputType) @@ -1436,16 +1387,9 @@ Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization( auto convertOp = builder.create(loc, outputType, inputs); appendRewrite(convertOp, converter, kind, - origOutputType); + origArgType); return convertOp.getResult(0); } -Value ConversionPatternRewriterImpl::buildUnresolvedArgumentMaterialization( - Block *block, Location loc, ValueRange inputs, Type origOutputType, - Type outputType, const TypeConverter *converter) { - return buildUnresolvedMaterialization(MaterializationKind::Argument, block, - block->begin(), loc, inputs, outputType, - origOutputType, converter); -} Value ConversionPatternRewriterImpl::buildUnresolvedTargetMaterialization( Location loc, Value input, Type outputType, const TypeConverter *converter) { @@ -1454,9 +1398,9 @@ Value ConversionPatternRewriterImpl::buildUnresolvedTargetMaterialization( if (OpResult inputRes = dyn_cast(input)) insertPt = ++inputRes.getOwner()->getIterator(); - return buildUnresolvedMaterialization(MaterializationKind::Target, - insertBlock, insertPt, loc, input, - outputType, outputType, converter); + return buildUnresolvedMaterialization( + MaterializationKind::Target, insertBlock, insertPt, loc, input, + outputType, /*origArgType=*/{}, converter); } //===----------------------------------------------------------------------===// @@ -2852,7 +2796,7 @@ static LogicalResult legalizeUnresolvedMaterialization( // easily misunderstood. We should clean up the argument hooks to better // represent the desired invariants we actually care about. newMaterialization = converter->materializeArgumentConversion( - rewriter, op->getLoc(), mat.getOrigOutputType(), inputOperands); + rewriter, op->getLoc(), mat.getOrigArgType(), inputOperands); if (newMaterialization) break; @@ -2863,6 +2807,10 @@ static LogicalResult legalizeUnresolvedMaterialization( newMaterialization = converter->materializeTargetConversion( rewriter, op->getLoc(), outputType, inputOperands); break; + case MaterializationKind::Source: + newMaterialization = converter->materializeSourceConversion( + rewriter, op->getLoc(), outputType, inputOperands); + break; } if (newMaterialization) { replaceMaterialization(rewriterImpl, opResult, newMaterialization, @@ -2873,8 +2821,8 @@ static LogicalResult legalizeUnresolvedMaterialization( InFlightDiagnostic diag = op->emitError() << "failed to legalize unresolved materialization " - "from " - << inputOperands.getTypes() << " to " << outputType + "from (" + << inputOperands.getTypes() << ") to " << outputType << " that remained live after conversion"; if (Operation *liveUser = findLiveUser(op->getUsers())) { diag.attachNote(liveUser->getLoc()) diff --git a/mlir/test/Transforms/test-legalize-type-conversion.mlir b/mlir/test/Transforms/test-legalize-type-conversion.mlir index b35cda8..8254be6 100644 --- a/mlir/test/Transforms/test-legalize-type-conversion.mlir +++ b/mlir/test/Transforms/test-legalize-type-conversion.mlir @@ -2,9 +2,8 @@ func.func @test_invalid_arg_materialization( - // expected-error@below {{failed to materialize conversion for block argument #0 that remained live after conversion, type was 'i16'}} + // expected-error@below {{failed to legalize unresolved materialization from () to 'i16' that remained live after conversion}} %arg0: i16) { - // expected-note@below {{see existing live user here}} "foo.return"(%arg0) : (i16) -> () } @@ -104,9 +103,8 @@ func.func @test_block_argument_not_converted() { // Make sure argument type changes aren't implicitly forwarded. func.func @test_signature_conversion_no_converter() { "test.signature_conversion_no_converter"() ({ - // expected-error@below {{failed to materialize conversion for block argument #0 that remained live after conversion}} + // expected-error@below {{failed to legalize unresolved materialization from ('f64') to 'f32' that remained live after conversion}} ^bb0(%arg0: f32): - // expected-note@below {{see existing live user here}} "test.type_consumer"(%arg0) : (f32) -> () "test.return"(%arg0) : (f32) -> () }) : () -> () -- cgit v1.1