diff options
Diffstat (limited to 'mlir/lib/Transforms/Utils/DialectConversion.cpp')
-rw-r--r-- | mlir/lib/Transforms/Utils/DialectConversion.cpp | 455 |
1 files changed, 214 insertions, 241 deletions
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index 255b0ba..403321d 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -11,6 +11,7 @@ #include "mlir/IR/Block.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Dominance.h" #include "mlir/IR/IRMapping.h" #include "mlir/IR/Iterators.h" #include "mlir/Interfaces/FunctionInterfaces.h" @@ -63,11 +64,55 @@ static OpBuilder::InsertPoint computeInsertPoint(Value value) { return OpBuilder::InsertPoint(insertBlock, insertPt); } +/// Helper function that computes an insertion point where the given values are +/// defined and can be used without a dominance violation. +static OpBuilder::InsertPoint computeInsertPoint(ArrayRef<Value> vals) { + assert(!vals.empty() && "expected at least one value"); + DominanceInfo domInfo; + OpBuilder::InsertPoint pt = computeInsertPoint(vals.front()); + for (Value v : vals.drop_front()) { + // Choose the "later" insertion point. + OpBuilder::InsertPoint nextPt = computeInsertPoint(v); + if (domInfo.dominates(pt.getBlock(), pt.getPoint(), nextPt.getBlock(), + nextPt.getPoint())) { + // pt is before nextPt => choose nextPt. + pt = nextPt; + } else { +#ifndef NDEBUG + // nextPt should be before pt => choose pt. + // If pt, nextPt are no dominance relationship, then there is no valid + // insertion point at which all given values are defined. + bool dom = domInfo.dominates(nextPt.getBlock(), nextPt.getPoint(), + pt.getBlock(), pt.getPoint()); + assert(dom && "unable to find valid insertion point"); +#endif // NDEBUG + } + } + return pt; +} + //===----------------------------------------------------------------------===// // ConversionValueMapping //===----------------------------------------------------------------------===// +/// A vector of SSA values, optimized for the most common case of a single +/// value. +using ValueVector = SmallVector<Value, 1>; + namespace { + +/// Helper class to make it possible to use `ValueVector` as a key in DenseMap. +struct ValueVectorMapInfo { + static ValueVector getEmptyKey() { return ValueVector{Value()}; } + static ValueVector getTombstoneKey() { return ValueVector{Value(), Value()}; } + static ::llvm::hash_code getHashValue(const ValueVector &val) { + return ::llvm::hash_combine_range(val.begin(), val.end()); + } + static bool isEqual(const ValueVector &LHS, const ValueVector &RHS) { + return LHS == RHS; + } +}; + /// This class wraps a IRMapping to provide recursive lookup /// functionality, i.e. we will traverse if the mapped value also has a mapping. struct ConversionValueMapping { @@ -75,68 +120,128 @@ struct ConversionValueMapping { /// false positives. bool isMappedTo(Value value) const { return mappedTo.contains(value); } - /// Lookup the most recently mapped value with the desired type in the + /// Lookup the most recently mapped values with the desired types in the /// mapping. /// /// Special cases: - /// - If the desired type is "null", simply return the most recently mapped - /// value. - /// - If there is no mapping to the desired type, also return the most - /// recently mapped value. - /// - If there is no mapping for the given value at all, return the given + /// - If the desired type range is empty, simply return the most recently + /// mapped values. + /// - If there is no mapping to the desired types, also return the most + /// recently mapped values. + /// - If there is no mapping for the given values at all, return the given /// value. - Value lookupOrDefault(Value from, Type desiredType = nullptr) const; + ValueVector lookupOrDefault(Value from, TypeRange desiredTypes = {}) const; - /// Lookup a mapped value within the map, or return null if a mapping does not - /// exist. If a mapping exists, this follows the same behavior of - /// `lookupOrDefault`. - Value lookupOrNull(Value from, Type desiredType = nullptr) const; + /// Lookup the given value within the map, or return an empty vector if the + /// value is not mapped. If it is mapped, this follows the same behavior + /// as `lookupOrDefault`. + ValueVector lookupOrNull(Value from, TypeRange desiredTypes = {}) const; - /// Map a value to the one provided. - void map(Value oldVal, Value newVal) { + template <typename T> + struct IsValueVector : std::is_same<std::decay_t<T>, ValueVector> {}; + + /// Map a value vector to the one provided. + template <typename OldVal, typename NewVal> + std::enable_if_t<IsValueVector<OldVal>::value && IsValueVector<NewVal>::value> + map(OldVal &&oldVal, NewVal &&newVal) { LLVM_DEBUG({ - for (Value it = newVal; it; it = mapping.lookupOrNull(it)) - assert(it != oldVal && "inserting cyclic mapping"); + ValueVector next(newVal); + while (true) { + assert(next != oldVal && "inserting cyclic mapping"); + auto it = mapping.find(next); + if (it == mapping.end()) + break; + next = it->second; + } }); - mapping.map(oldVal, newVal); - mappedTo.insert(newVal); + for (Value v : newVal) + mappedTo.insert(v); + + mapping[std::forward<OldVal>(oldVal)] = std::forward<NewVal>(newVal); } - /// Drop the last mapping for the given value. - void erase(Value value) { mapping.erase(value); } + /// Map a value vector or single value to the one provided. + template <typename OldVal, typename NewVal> + std::enable_if_t<!IsValueVector<OldVal>::value || + !IsValueVector<NewVal>::value> + map(OldVal &&oldVal, NewVal &&newVal) { + if constexpr (IsValueVector<OldVal>{}) { + map(std::forward<OldVal>(oldVal), ValueVector{newVal}); + } else if constexpr (IsValueVector<NewVal>{}) { + map(ValueVector{oldVal}, std::forward<NewVal>(newVal)); + } else { + map(ValueVector{oldVal}, ValueVector{newVal}); + } + } + + /// Drop the last mapping for the given values. + void erase(const ValueVector &value) { mapping.erase(value); } private: /// Current value mappings. - IRMapping mapping; + DenseMap<ValueVector, ValueVector, ValueVectorMapInfo> mapping; /// All SSA values that are mapped to. May contain false positives. DenseSet<Value> mappedTo; }; } // namespace -Value ConversionValueMapping::lookupOrDefault(Value from, - Type desiredType) const { - // Try to find the deepest value that has the desired type. If there is no - // such value, simply return the deepest value. - Value desiredValue; +ValueVector +ConversionValueMapping::lookupOrDefault(Value from, + TypeRange desiredTypes) const { + // Try to find the deepest values that have the desired types. If there is no + // such mapping, simply return the deepest values. + ValueVector desiredValue; + ValueVector current{from}; do { - if (!desiredType || from.getType() == desiredType) - desiredValue = from; + // Store the current value if the types match. + if (TypeRange(ValueRange(current)) == desiredTypes) + desiredValue = current; + + // If possible, Replace each value with (one or multiple) mapped values. + ValueVector next; + for (Value v : current) { + auto it = mapping.find({v}); + if (it != mapping.end()) { + llvm::append_range(next, it->second); + } else { + next.push_back(v); + } + } + if (next != current) { + // If at least one value was replaced, continue the lookup from there. + current = std::move(next); + continue; + } - Value mappedValue = mapping.lookupOrNull(from); - if (!mappedValue) + // Otherwise: Check if there is a mapping for the entire vector. Such + // mappings are materializations. (N:M mapping are not supported for value + // replacements.) + // + // Note: From a correctness point of view, materializations do not have to + // be stored (and looked up) in the mapping. But for performance reasons, + // we choose to reuse existing IR (when possible) instead of creating it + // multiple times. + auto it = mapping.find(current); + if (it == mapping.end()) { + // No mapping found: The lookup stops here. break; - from = mappedValue; + } + current = it->second; } while (true); - // If the desired value was found use it, otherwise default to the leaf value. - return desiredValue ? desiredValue : from; + // If the desired values were found use them, otherwise default to the leaf + // values. + // Note: If `desiredTypes` is empty, this function always returns `current`. + return !desiredValue.empty() ? std::move(desiredValue) : std::move(current); } -Value ConversionValueMapping::lookupOrNull(Value from, Type desiredType) const { - Value result = lookupOrDefault(from, desiredType); - if (result == from || (desiredType && result.getType() != desiredType)) - return nullptr; +ValueVector ConversionValueMapping::lookupOrNull(Value from, + TypeRange desiredTypes) const { + ValueVector result = lookupOrDefault(from, desiredTypes); + if (result == ValueVector{from} || + (!desiredTypes.empty() && TypeRange(ValueRange(result)) != desiredTypes)) + return {}; return result; } @@ -651,10 +756,6 @@ public: /// The type of materialization. enum MaterializationKind { - /// This materialization materializes a conversion for an illegal block - /// argument type, to the original one. - Argument, - /// This materialization materializes a conversion from an illegal type to a /// legal one. Target, @@ -673,7 +774,7 @@ public: UnrealizedConversionCastOp op, const TypeConverter *converter, MaterializationKind kind, Type originalType, - Value mappedValue); + ValueVector mappedValues); static bool classof(const IRRewrite *rewrite) { return rewrite->getKind() == Kind::UnresolvedMaterialization; @@ -708,9 +809,9 @@ private: /// materializations. Type originalType; - /// The value in the conversion value mapping that is being replaced by the + /// The values in the conversion value mapping that are being replaced by the /// results of this unresolved materialization. - Value mappedValue; + ValueVector mappedValues; }; } // namespace @@ -779,7 +880,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { LogicalResult remapValues(StringRef valueDiagTag, std::optional<Location> inputLoc, PatternRewriter &rewriter, ValueRange values, - SmallVector<SmallVector<Value>> &remapped); + SmallVector<ValueVector> &remapped); /// Return "true" if the given operation is ignored, and does not need to be /// converted. @@ -820,39 +921,14 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { /// If a cast op was built, it can optionally be returned with the `castOp` /// output argument. /// - /// If `valueToMap` is set to a non-null Value, then that value is mapped to + /// If `valuesToMap` is set to a non-null Value, then that value is mapped to /// the results of the unresolved materialization in the conversion value /// mapping. ValueRange buildUnresolvedMaterialization( MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc, - Value valueToMap, ValueRange inputs, TypeRange outputTypes, + ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes, Type originalType, const TypeConverter *converter, UnrealizedConversionCastOp *castOp = nullptr); - Value buildUnresolvedMaterialization( - MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc, - Value valueToMap, ValueRange inputs, Type outputType, Type originalType, - const TypeConverter *converter, - UnrealizedConversionCastOp *castOp = nullptr) { - return buildUnresolvedMaterialization(kind, ip, loc, valueToMap, inputs, - TypeRange(outputType), originalType, - converter, castOp) - .front(); - } - - /// Build an N:1 materialization for the given original value that was - /// replaced with the given replacement values. - /// - /// This is a workaround around incomplete 1:N support in the dialect - /// conversion driver. The conversion mapping can store only 1:1 replacements - /// and the conversion patterns only support single Value replacements in the - /// adaptor, so N values must be converted back to a single value. This - /// function will be deleted when full 1:N support has been added. - /// - /// This function inserts an argument materialization back to the original - /// type. - void insertNTo1Materialization(OpBuilder::InsertPoint ip, Location loc, - ValueRange replacements, Value originalValue, - const TypeConverter *converter); /// Find a replacement value for the given SSA value in the conversion value /// mapping. The replacement value must have the same type as the given SSA @@ -862,16 +938,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { Value findOrBuildReplacementValue(Value value, const TypeConverter *converter); - /// Unpack an N:1 materialization and return the inputs of the - /// materialization. This function unpacks only those materializations that - /// were built with `insertNTo1Materialization`. - /// - /// This is a workaround around incomplete 1:N support in the dialect - /// conversion driver. It allows us to write 1:N conversion patterns while - /// 1:N support is still missing in the conversion value mapping. This - /// function will be deleted when full 1:N support has been added. - SmallVector<Value> unpackNTo1Materialization(Value value); - //===--------------------------------------------------------------------===// // Rewriter Notification Hooks //===--------------------------------------------------------------------===// @@ -974,10 +1040,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationRewrite *> unresolvedMaterializations; - /// A set of all N:1 materializations that were added to work around - /// incomplete 1:N support in the dialect conversion driver. - DenseSet<UnrealizedConversionCastOp> nTo1TempMaterializations; - /// The current type converter, or nullptr if no type converter is currently /// active. const TypeConverter *currentTypeConverter = nullptr; @@ -1041,7 +1103,7 @@ void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) { }); } -void ReplaceBlockArgRewrite::rollback() { rewriterImpl.mapping.erase(arg); } +void ReplaceBlockArgRewrite::rollback() { rewriterImpl.mapping.erase({arg}); } void ReplaceOperationRewrite::commit(RewriterBase &rewriter) { auto *listener = @@ -1082,7 +1144,7 @@ void ReplaceOperationRewrite::commit(RewriterBase &rewriter) { void ReplaceOperationRewrite::rollback() { for (auto result : op->getResults()) - rewriterImpl.mapping.erase(result); + rewriterImpl.mapping.erase({result}); } void ReplaceOperationRewrite::cleanup(RewriterBase &rewriter) { @@ -1101,20 +1163,19 @@ void CreateOperationRewrite::rollback() { UnresolvedMaterializationRewrite::UnresolvedMaterializationRewrite( ConversionPatternRewriterImpl &rewriterImpl, UnrealizedConversionCastOp op, const TypeConverter *converter, MaterializationKind kind, Type originalType, - Value mappedValue) + ValueVector mappedValues) : OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op), converterAndKind(converter, kind), originalType(originalType), - mappedValue(mappedValue) { + mappedValues(std::move(mappedValues)) { assert((!originalType || kind == MaterializationKind::Target) && "original type is valid only for target materializations"); rewriterImpl.unresolvedMaterializations[op] = this; } void UnresolvedMaterializationRewrite::rollback() { - if (mappedValue) - rewriterImpl.mapping.erase(mappedValue); + if (!mappedValues.empty()) + rewriterImpl.mapping.erase(mappedValues); rewriterImpl.unresolvedMaterializations.erase(getOperation()); - rewriterImpl.nTo1TempMaterializations.erase(getOperation()); op->erase(); } @@ -1160,7 +1221,7 @@ void ConversionPatternRewriterImpl::undoRewrites(unsigned numRewritesToKeep) { LogicalResult ConversionPatternRewriterImpl::remapValues( StringRef valueDiagTag, std::optional<Location> inputLoc, PatternRewriter &rewriter, ValueRange values, - SmallVector<SmallVector<Value>> &remapped) { + SmallVector<ValueVector> &remapped) { remapped.reserve(llvm::size(values)); for (const auto &it : llvm::enumerate(values)) { @@ -1168,18 +1229,11 @@ LogicalResult ConversionPatternRewriterImpl::remapValues( Type origType = operand.getType(); Location operandLoc = inputLoc ? *inputLoc : operand.getLoc(); - // Find the most recently mapped value. Unpack all temporary N:1 - // materializations. Such conversions are a workaround around missing - // 1:N support in the ConversionValueMapping. (The conversion patterns - // already support 1:N replacements.) - Value repl = mapping.lookupOrDefault(operand); - SmallVector<Value> unpacked = unpackNTo1Materialization(repl); - if (!currentTypeConverter) { // The current pattern does not have a type converter. I.e., it does not // distinguish between legal and illegal types. For each operand, simply - // pass through the most recently mapped value. - remapped.push_back(std::move(unpacked)); + // pass through the most recently mapped values. + remapped.push_back(mapping.lookupOrDefault(operand)); continue; } @@ -1192,51 +1246,28 @@ LogicalResult ConversionPatternRewriterImpl::remapValues( }); return failure(); } - // If a type is converted to 0 types, there is nothing to do. if (legalTypes.empty()) { remapped.push_back({}); continue; } - if (legalTypes.size() != 1) { - // TODO: This is a 1:N conversion. The conversion value mapping does not - // store such materializations yet. If the types of the most recently - // mapped values do not match, build a target materialization. - ValueRange unpackedRange(unpacked); - if (TypeRange(unpackedRange) == legalTypes) { - remapped.push_back(std::move(unpacked)); - continue; - } - - // Insert a target materialization if the current pattern expects - // different legalized types. - ValueRange targetMat = buildUnresolvedMaterialization( - MaterializationKind::Target, computeInsertPoint(repl), operandLoc, - /*valueToMap=*/Value(), /*inputs=*/unpacked, - /*outputType=*/legalTypes, /*originalType=*/origType, - currentTypeConverter); - remapped.push_back(targetMat); + ValueVector repl = mapping.lookupOrDefault(operand, legalTypes); + if (!repl.empty() && TypeRange(ValueRange(repl)) == legalTypes) { + // Mapped values have the correct type or there is an existing + // materialization. Or the operand is not mapped at all and has the + // correct type. + remapped.push_back(std::move(repl)); continue; } - // Handle 1->1 type conversions. - Type desiredType = legalTypes.front(); - // Try to find a mapped value with the desired type. (Or the operand itself - // if the value is not mapped at all.) - Value newOperand = mapping.lookupOrDefault(operand, desiredType); - if (newOperand.getType() != desiredType) { - // If the looked up value's type does not have the desired type, it means - // that the value was replaced with a value of different type and no - // target materialization was created yet. - Value castValue = buildUnresolvedMaterialization( - MaterializationKind::Target, computeInsertPoint(newOperand), - operandLoc, /*valueToMap=*/newOperand, /*inputs=*/unpacked, - /*outputType=*/desiredType, /*originalType=*/origType, - currentTypeConverter); - newOperand = castValue; - } - remapped.push_back({newOperand}); + // Create a materialization for the most recently mapped values. + repl = mapping.lookupOrDefault(operand); + ValueRange castValues = buildUnresolvedMaterialization( + MaterializationKind::Target, computeInsertPoint(repl), operandLoc, + /*valuesToMap=*/repl, /*inputs=*/repl, /*outputTypes=*/legalTypes, + /*originalType=*/origType, currentTypeConverter); + remapped.push_back(castValues); } return success(); } @@ -1353,7 +1384,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion( buildUnresolvedMaterialization( MaterializationKind::Source, OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(), - /*valueToMap=*/origArg, /*inputs=*/ValueRange(), + /*valuesToMap=*/{origArg}, /*inputs=*/ValueRange(), /*outputType=*/origArgType, /*originalType=*/Type(), converter); appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter); continue; @@ -1369,19 +1400,11 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion( continue; } - // This is a 1->1+ mapping. 1->N mappings are not fully supported in the - // dialect conversion. Therefore, we need an argument materialization to - // turn the replacement block arguments into a single SSA value that can be - // used as a replacement. + // This is a 1->1+ mapping. auto replArgs = newBlock->getArguments().slice(inputMap->inputNo, inputMap->size); - if (replArgs.size() == 1) { - mapping.map(origArg, replArgs.front()); - } else { - insertNTo1Materialization( - OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(), - /*replacements=*/replArgs, /*outputValue=*/origArg, converter); - } + ValueVector replArgVals = llvm::to_vector_of<Value, 1>(replArgs); + mapping.map(origArg, std::move(replArgVals)); appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter); } @@ -1402,20 +1425,13 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion( /// of input operands. ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization( MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc, - Value valueToMap, ValueRange inputs, TypeRange outputTypes, + ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes, Type originalType, const TypeConverter *converter, UnrealizedConversionCastOp *castOp) { assert((!originalType || kind == MaterializationKind::Target) && "original type is valid only for target materializations"); - - // Avoid materializing an unnecessary cast. - if (TypeRange(inputs) == outputTypes) { - if (valueToMap) { - assert(inputs.size() == 1 && "1:N mapping is not supported"); - mapping.map(valueToMap, inputs.front()); - } - return inputs; - } + assert(TypeRange(inputs) != outputTypes && + "materialization is not necessary"); // Create an unresolved materialization. We use a new OpBuilder to avoid // tracking the materialization like we do for other operations. @@ -1423,37 +1439,23 @@ ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization( builder.setInsertionPoint(ip.getBlock(), ip.getPoint()); auto convertOp = builder.create<UnrealizedConversionCastOp>(loc, outputTypes, inputs); - if (valueToMap) { - assert(outputTypes.size() == 1 && "1:N mapping is not supported"); - mapping.map(valueToMap, convertOp.getResult(0)); - } + if (!valuesToMap.empty()) + mapping.map(valuesToMap, convertOp.getResults()); if (castOp) *castOp = convertOp; - appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind, - originalType, valueToMap); + appendRewrite<UnresolvedMaterializationRewrite>( + convertOp, converter, kind, originalType, std::move(valuesToMap)); return convertOp.getResults(); } -void ConversionPatternRewriterImpl::insertNTo1Materialization( - OpBuilder::InsertPoint ip, Location loc, ValueRange replacements, - Value originalValue, const TypeConverter *converter) { - // Insert argument materialization back to the original type. - Type originalType = originalValue.getType(); - UnrealizedConversionCastOp argCastOp; - buildUnresolvedMaterialization( - MaterializationKind::Argument, ip, loc, /*valueToMap=*/originalValue, - /*inputs=*/replacements, originalType, - /*originalType=*/Type(), converter, &argCastOp); - if (argCastOp) - nTo1TempMaterializations.insert(argCastOp); -} - Value ConversionPatternRewriterImpl::findOrBuildReplacementValue( Value value, const TypeConverter *converter) { - // Find a replacement value with the same type. - Value repl = mapping.lookupOrNull(value, value.getType()); - if (repl) - return repl; + // Try to find a replacement value with the same type in the conversion value + // mapping. This includes cached materializations. We try to reuse those + // instead of generating duplicate IR. + ValueVector repl = mapping.lookupOrNull(value, value.getType()); + if (!repl.empty()) + return repl.front(); // Check if the value is dead. No replacement value is needed in that case. // This is an approximate check that may have false negatives but does not @@ -1468,7 +1470,7 @@ Value ConversionPatternRewriterImpl::findOrBuildReplacementValue( // (regardless of the type) and build a source materialization to the // original type. repl = mapping.lookupOrNull(value); - if (!repl) { + if (repl.empty()) { // No replacement value is registered in the mapping. This means that the // value is dropped and no longer needed. (If the value were still needed, // a source materialization producing a replacement value "out of thin air" @@ -1476,34 +1478,22 @@ Value ConversionPatternRewriterImpl::findOrBuildReplacementValue( // `applySignatureConversion`.) return Value(); } - Value castValue = buildUnresolvedMaterialization( - MaterializationKind::Source, computeInsertPoint(repl), value.getLoc(), - /*valueToMap=*/value, /*inputs=*/repl, /*outputType=*/value.getType(), - /*originalType=*/Type(), converter); - mapping.map(value, castValue); - return castValue; -} -SmallVector<Value> -ConversionPatternRewriterImpl::unpackNTo1Materialization(Value value) { - // Unpack unrealized_conversion_cast ops that were inserted as a N:1 - // workaround. - auto castOp = value.getDefiningOp<UnrealizedConversionCastOp>(); - if (!castOp) - return {value}; - if (!nTo1TempMaterializations.contains(castOp)) - return {value}; - assert(castOp->getNumResults() == 1 && "expected single result"); - - SmallVector<Value> result; - for (Value v : castOp.getOperands()) { - // Keep unpacking if possible. This is needed because during block - // signature conversions and 1:N op replacements, the driver may have - // inserted two materializations back-to-back: first an argument - // materialization, then a target materialization. - llvm::append_range(result, unpackNTo1Materialization(v)); - } - return result; + // Note: `computeInsertPoint` computes the "earliest" insertion point at + // which all values in `repl` are defined. It is important to emit the + // materialization at that location because the same materialization may be + // reused in a different context. (That's because materializations are cached + // in the conversion value mapping.) The insertion point of the + // materialization must be valid for all future users that may be created + // later in the conversion process. + Value castValue = + buildUnresolvedMaterialization(MaterializationKind::Source, + computeInsertPoint(repl), value.getLoc(), + /*valuesToMap=*/repl, /*inputs=*/repl, + /*outputType=*/value.getType(), + /*originalType=*/Type(), converter) + .front(); + return castValue; } //===----------------------------------------------------------------------===// @@ -1554,7 +1544,7 @@ void ConversionPatternRewriterImpl::notifyOpReplaced( // Materialize a replacement value "out of thin air". buildUnresolvedMaterialization( MaterializationKind::Source, computeInsertPoint(result), - result.getLoc(), /*valueToMap=*/result, /*inputs=*/ValueRange(), + result.getLoc(), /*valuesToMap=*/{result}, /*inputs=*/ValueRange(), /*outputType=*/result.getType(), /*originalType=*/Type(), currentTypeConverter); continue; @@ -1572,16 +1562,7 @@ void ConversionPatternRewriterImpl::notifyOpReplaced( // Remap result to replacement value. if (repl.empty()) continue; - - if (repl.size() == 1) { - // Single replacement value: replace directly. - mapping.map(result, repl.front()); - } else { - // Multiple replacement values: insert N:1 materialization. - insertNTo1Materialization(computeInsertPoint(result), result.getLoc(), - /*replacements=*/repl, /*outputValue=*/result, - currentTypeConverter); - } + mapping.map(result, repl); } appendRewrite<ReplaceOperationRewrite>(op, currentTypeConverter); @@ -1660,8 +1641,13 @@ void ConversionPatternRewriter::replaceOp(Operation *op, ValueRange newValues) { << "** Replace : '" << op->getName() << "'(" << op << ")\n"; }); SmallVector<ValueRange> newVals; - for (size_t i = 0; i < newValues.size(); ++i) - newVals.push_back(newValues.slice(i, 1)); + for (size_t i = 0; i < newValues.size(); ++i) { + if (newValues[i]) { + newVals.push_back(newValues.slice(i, 1)); + } else { + newVals.push_back(ValueRange()); + } + } impl->notifyOpReplaced(op, newVals); } @@ -1733,7 +1719,7 @@ void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from, } Value ConversionPatternRewriter::getRemappedValue(Value key) { - SmallVector<SmallVector<Value>> remappedValues; + SmallVector<ValueVector> remappedValues; if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, key, remappedValues))) return nullptr; @@ -1746,7 +1732,7 @@ ConversionPatternRewriter::getRemappedValues(ValueRange keys, SmallVectorImpl<Value> &results) { if (keys.empty()) return success(); - SmallVector<SmallVector<Value>> remapped; + SmallVector<ValueVector> remapped; if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, keys, remapped))) return failure(); @@ -1872,7 +1858,7 @@ ConversionPattern::matchAndRewrite(Operation *op, getTypeConverter()); // Remap the operands of the operation. - SmallVector<SmallVector<Value>> remapped; + SmallVector<ValueVector> remapped; if (failed(rewriterImpl.remapValues("operand", op->getLoc(), rewriter, op->getOperands(), remapped))) { return failure(); @@ -2625,19 +2611,6 @@ legalizeUnresolvedMaterialization(RewriterBase &rewriter, rewriter.setInsertionPoint(op); SmallVector<Value> newMaterialization; switch (rewrite->getMaterializationKind()) { - case MaterializationKind::Argument: { - // Try to materialize an argument conversion. - assert(op->getNumResults() == 1 && "expected single result"); - Value argMat = converter->materializeArgumentConversion( - rewriter, op->getLoc(), op.getResultTypes().front(), inputOperands); - if (argMat) { - newMaterialization.push_back(argMat); - break; - } - } - // If an argument materialization failed, fallback to trying a target - // materialization. - [[fallthrough]]; case MaterializationKind::Target: newMaterialization = converter->materializeTargetConversion( rewriter, op->getLoc(), op.getResultTypes(), inputOperands, |