aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMarkus Böck <markus.boeck02@gmail.com>2024-12-21 19:01:19 +0100
committerMarkus Böck <markus.boeck02@gmail.com>2024-12-21 19:01:19 +0100
commit5e857ac6348e7455f1c72fd49eb887fc1406949c (patch)
tree9e493022ffd529803c61fe43b8793e49e167120d
parent53f97f5a68033bd46ffd5a982435d64afe9048dd (diff)
downloadllvm-users/zero9178/1n_conversion_value_mapping_review.zip
llvm-users/zero9178/1n_conversion_value_mapping_review.tar.gz
llvm-users/zero9178/1n_conversion_value_mapping_review.tar.bz2
use universal references for `map`users/zero9178/1n_conversion_value_mapping_review
-rw-r--r--mlir/lib/Transforms/Utils/DialectConversion.cpp36
1 files changed, 27 insertions, 9 deletions
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index ea169a1..207b0073 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -169,10 +169,15 @@ struct ConversionValueMapping {
ValueVector lookupOrNull(const ValueVector &from,
TypeRange desiredTypes = {}) const;
+ template <typename T>
+ struct IsValueVector : std::is_same<std::decay_t<T>, ValueVector> {};
+
/// Map a value to the one provided.
- void map(const ValueVector &oldVal, const ValueVector &newVal) {
+ template <typename OldVal, typename NewVal>
+ std::enable_if_t<IsValueVector<OldVal>{} && IsValueVector<NewVal>{}>
+ map(OldVal &&oldVal, NewVal &&newVal) {
LLVM_DEBUG({
- ValueVector next = newVal;
+ ValueVector next(newVal);
while (true) {
assert(next != oldVal && "inserting cyclic mapping");
auto it = mapping.find(next);
@@ -181,9 +186,22 @@ struct ConversionValueMapping {
next = it->second;
}
});
- mapping[oldVal] = newVal;
for (Value v : newVal)
mappedTo.insert(v);
+
+ mapping[std::forward<OldVal>(oldVal)] = std::forward<NewVal>(newVal);
+ }
+
+ template <typename OldVal, typename NewVal>
+ std::enable_if_t<!IsValueVector<OldVal>{} || !IsValueVector<NewVal>{}>
+ map(OldVal &&oldVal, NewVal &&newVal) {
+ if constexpr (IsValueVector<OldVal>{}) {
+ map(std::forward<OldVal>(oldVal), ValueVector{newVal});
+ } else if constexpr (IsValueVector<NewVal>{}) {
+ map(ValueVector{oldVal}, std::forward<NewVal>(newVal));
+ } else {
+ map(ValueVector{oldVal}, ValueVector{newVal});
+ }
}
/// Drop the last mapping for the given values.
@@ -1405,7 +1423,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
assert(inputMap->size == 0 &&
"invalid to provide a replacement value when the argument isn't "
"dropped");
- mapping.map({origArg}, {repl});
+ mapping.map(origArg, repl);
appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
continue;
}
@@ -1418,7 +1436,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
newBlock->getArguments().slice(inputMap->inputNo, inputMap->size);
ValueVector replArgVals = llvm::map_to_vector<1>(
replArgs, [](BlockArgument arg) -> Value { return arg; });
- mapping.map({origArg}, replArgVals);
+ mapping.map(origArg, std::move(replArgVals));
appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
}
@@ -1448,7 +1466,7 @@ ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
// Avoid materializing an unnecessary cast.
if (TypeRange(inputs) == outputTypes) {
if (!valuesToMap.empty())
- mapping.map(valuesToMap, inputs);
+ mapping.map(std::move(valuesToMap), inputs);
return inputs;
}
@@ -1499,7 +1517,7 @@ Value ConversionPatternRewriterImpl::findOrBuildReplacementValue(
MaterializationKind::Source, computeInsertPoint(repl), value.getLoc(),
/*valuesToMap=*/{value}, /*inputs=*/repl, /*outputType=*/value.getType(),
/*originalType=*/Type(), converter)[0];
- mapping.map({value}, {castValue});
+ mapping.map(value, castValue);
return castValue;
}
@@ -1569,7 +1587,7 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(
// Remap result to replacement value.
if (repl.empty())
continue;
- mapping.map({result}, repl);
+ mapping.map(result, repl);
}
appendRewrite<ReplaceOperationRewrite>(op, currentTypeConverter);
@@ -1722,7 +1740,7 @@ void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
});
impl->appendRewrite<ReplaceBlockArgRewrite>(from.getOwner(), from,
impl->currentTypeConverter);
- impl->mapping.map(impl->mapping.lookupOrDefault({from}), {to});
+ impl->mapping.map(impl->mapping.lookupOrDefault({from}), to);
}
Value ConversionPatternRewriter::getRemappedValue(Value key) {