diff options
-rw-r--r-- | mlir/lib/Transforms/Utils/DialectConversion.cpp | 36 |
1 files changed, 27 insertions, 9 deletions
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index ea169a1..207b0073 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -169,10 +169,15 @@ struct ConversionValueMapping { ValueVector lookupOrNull(const ValueVector &from, TypeRange desiredTypes = {}) const; + template <typename T> + struct IsValueVector : std::is_same<std::decay_t<T>, ValueVector> {}; + /// Map a value to the one provided. - void map(const ValueVector &oldVal, const ValueVector &newVal) { + template <typename OldVal, typename NewVal> + std::enable_if_t<IsValueVector<OldVal>{} && IsValueVector<NewVal>{}> + map(OldVal &&oldVal, NewVal &&newVal) { LLVM_DEBUG({ - ValueVector next = newVal; + ValueVector next(newVal); while (true) { assert(next != oldVal && "inserting cyclic mapping"); auto it = mapping.find(next); @@ -181,9 +186,22 @@ struct ConversionValueMapping { next = it->second; } }); - mapping[oldVal] = newVal; for (Value v : newVal) mappedTo.insert(v); + + mapping[std::forward<OldVal>(oldVal)] = std::forward<NewVal>(newVal); + } + + template <typename OldVal, typename NewVal> + std::enable_if_t<!IsValueVector<OldVal>{} || !IsValueVector<NewVal>{}> + map(OldVal &&oldVal, NewVal &&newVal) { + if constexpr (IsValueVector<OldVal>{}) { + map(std::forward<OldVal>(oldVal), ValueVector{newVal}); + } else if constexpr (IsValueVector<NewVal>{}) { + map(ValueVector{oldVal}, std::forward<NewVal>(newVal)); + } else { + map(ValueVector{oldVal}, ValueVector{newVal}); + } } /// Drop the last mapping for the given values. @@ -1405,7 +1423,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion( assert(inputMap->size == 0 && "invalid to provide a replacement value when the argument isn't " "dropped"); - mapping.map({origArg}, {repl}); + mapping.map(origArg, repl); appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter); continue; } @@ -1418,7 +1436,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion( newBlock->getArguments().slice(inputMap->inputNo, inputMap->size); ValueVector replArgVals = llvm::map_to_vector<1>( replArgs, [](BlockArgument arg) -> Value { return arg; }); - mapping.map({origArg}, replArgVals); + mapping.map(origArg, std::move(replArgVals)); appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter); } @@ -1448,7 +1466,7 @@ ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization( // Avoid materializing an unnecessary cast. if (TypeRange(inputs) == outputTypes) { if (!valuesToMap.empty()) - mapping.map(valuesToMap, inputs); + mapping.map(std::move(valuesToMap), inputs); return inputs; } @@ -1499,7 +1517,7 @@ Value ConversionPatternRewriterImpl::findOrBuildReplacementValue( MaterializationKind::Source, computeInsertPoint(repl), value.getLoc(), /*valuesToMap=*/{value}, /*inputs=*/repl, /*outputType=*/value.getType(), /*originalType=*/Type(), converter)[0]; - mapping.map({value}, {castValue}); + mapping.map(value, castValue); return castValue; } @@ -1569,7 +1587,7 @@ void ConversionPatternRewriterImpl::notifyOpReplaced( // Remap result to replacement value. if (repl.empty()) continue; - mapping.map({result}, repl); + mapping.map(result, repl); } appendRewrite<ReplaceOperationRewrite>(op, currentTypeConverter); @@ -1722,7 +1740,7 @@ void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from, }); impl->appendRewrite<ReplaceBlockArgRewrite>(from.getOwner(), from, impl->currentTypeConverter); - impl->mapping.map(impl->mapping.lookupOrDefault({from}), {to}); + impl->mapping.map(impl->mapping.lookupOrDefault({from}), to); } Value ConversionPatternRewriter::getRemappedValue(Value key) { |