aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMarkus Böck <markus.boeck02@gmail.com>2025-03-26 09:12:16 +0100
committerMarkus Böck <markus.boeck02@gmail.com>2025-03-26 09:12:16 +0100
commitb30cdb12781fa603eb8e3803449f7f1530b7439d (patch)
tree76e0673d7a95e5f5027566698eeaeaa6bd20bedf
parent1b6eddfcd96bb749c05a6ab927c3ccd666b6d984 (diff)
downloadllvm-users/zero9178/replace_op_with_multiple_overloads_suggestion.zip
llvm-users/zero9178/replace_op_with_multiple_overloads_suggestion.tar.gz
llvm-users/zero9178/replace_op_with_multiple_overloads_suggestion.tar.bz2
-rw-r--r--llvm/include/llvm/ADT/SmallVector.h5
-rw-r--r--mlir/include/mlir/Transforms/DialectConversion.h22
-rw-r--r--mlir/lib/Transforms/Utils/DialectConversion.cpp33
3 files changed, 38 insertions, 22 deletions
diff --git a/llvm/include/llvm/ADT/SmallVector.h b/llvm/include/llvm/ADT/SmallVector.h
index bd3e887..917fbce 100644
--- a/llvm/include/llvm/ADT/SmallVector.h
+++ b/llvm/include/llvm/ADT/SmallVector.h
@@ -1238,6 +1238,11 @@ public:
SmallVectorImpl<T>::operator=(RHS);
}
+ SmallVector(const SmallVectorImpl<T> &RHS) : SmallVectorImpl<T>(N) {
+ if (!RHS.empty())
+ SmallVectorImpl<T>::operator=(RHS);
+ }
+
SmallVector &operator=(const SmallVector &RHS) {
SmallVectorImpl<T>::operator=(RHS);
return *this;
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index e4a785ea..7ec0ab2 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -898,28 +898,24 @@ public:
/// Replace the given operation with the new value ranges. The number of op
/// results and value ranges must match. The given operation is erased.
void replaceOpWithMultiple(Operation *op,
- ArrayRef<SmallVector<Value, 1>> newValues);
+ SmallVector<SmallVector<Value>> &&newValues);
+
// Note: This overload matches SmallVector<ValueRange>,
// SmallVector<SmallVector<Value>>, etc.
template <typename RangeRangeT>
void replaceOpWithMultiple(Operation *op, RangeRangeT &&newValues) {
- // Note: Prefer the ArrayRef<SmallVector<Value, 1>> overload because it
- // does not copy the replacements vector.
- auto vals = llvm::map_to_vector(newValues, [](const auto &r) {
- return SmallVector<Value, 1>(std::begin(r), std::end(r));
- });
- replaceOpWithMultiple(op, ArrayRef(vals));
+ replaceOpWithMultiple(op, llvm::map_to_vector(newValues, [](const auto &r) {
+ return llvm::to_vector(r);
+ }));
}
+
// Note: This overload matches initializer list of ValueRange,
// SmallVector<Value>, etc.
template <typename RangeT = ValueRange>
void replaceOpWithMultiple(Operation *op, ArrayRef<RangeT> newValues) {
- // Note: Prefer the ArrayRef<SmallVector<Value, 1>> overload because it
- // does not copy the replacements vector.
- auto vals = llvm::map_to_vector(newValues, [](const RangeT &r) {
- return SmallVector<Value, 1>(std::begin(r), std::end(r));
- });
- replaceOpWithMultiple(op, ArrayRef(vals));
+ replaceOpWithMultiple(op, llvm::map_to_vector(newValues, [](const RangeT &r) {
+ return SmallVector<Value>(std::begin(r), std::end(r));
+ }));
}
/// PatternRewriter hook for erasing a dead operation. The uses of this
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index c692aaf..a77a99b 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -137,8 +137,20 @@ struct ConversionValueMapping {
/// as `lookupOrDefault`.
ValueVector lookupOrNull(Value from, TypeRange desiredTypes = {}) const;
- template <typename T>
- struct IsValueVector : std::is_same<std::decay_t<T>, ValueVector> {};
+ template <typename>
+ struct IsValueVector : std::false_type {};
+
+ template <typename T, size_t n>
+ struct IsValueVector<SmallVector<T, n>> : std::true_type {};
+
+ template <typename T, size_t n>
+ struct IsValueVector<SmallVector<T, n> &> : std::true_type {};
+
+ template <typename T, size_t n>
+ struct IsValueVector<SmallVector<T, n> &&> : std::true_type {};
+
+ template <typename T, size_t n>
+ struct IsValueVector<const SmallVector<T, n> &> : std::true_type {};
/// Map a value vector to the one provided.
template <typename OldVal, typename NewVal>
@@ -947,7 +959,9 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
OpBuilder::InsertPoint previous) override;
/// Notifies that an op is about to be replaced with the given values.
- void notifyOpReplaced(Operation *op, ArrayRef<ValueVector> newValues);
+ template <unsigned N>
+ void notifyOpReplaced(Operation *op,
+ SmallVector<SmallVector<Value, N>> &&newValues);
/// Notifies that a block is about to be erased.
void notifyBlockIsBeingErased(Block *block);
@@ -1519,8 +1533,9 @@ void ConversionPatternRewriterImpl::notifyOperationInserted(
appendRewrite<MoveOperationRewrite>(op, previous.getBlock(), prevOp);
}
+template <unsigned N>
void ConversionPatternRewriterImpl::notifyOpReplaced(
- Operation *op, ArrayRef<ValueVector> newValues) {
+ Operation *op, SmallVector<SmallVector<Value, N>> &&newValues) {
assert(newValues.size() == op->getNumResults());
assert(!ignoredOps.contains(op) && "operation was already replaced");
@@ -1562,7 +1577,7 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(
// Remap result to replacement value.
if (repl.empty())
continue;
- mapping.map(result, repl);
+ mapping.map(result, std::move(repl));
}
appendRewrite<ReplaceOperationRewrite>(op, currentTypeConverter);
@@ -1644,18 +1659,18 @@ void ConversionPatternRewriter::replaceOp(Operation *op, ValueRange newValues) {
llvm::map_to_vector(newValues, [](Value v) -> ValueVector {
return v ? ValueVector{v} : ValueVector();
});
- impl->notifyOpReplaced(op, newVals);
+ impl->notifyOpReplaced(op, std::move(newVals));
}
void ConversionPatternRewriter::replaceOpWithMultiple(
- Operation *op, ArrayRef<SmallVector<Value, 1>> newValues) {
+ Operation *op, SmallVector<SmallVector<Value>> &&newValues) {
assert(op->getNumResults() == newValues.size() &&
"incorrect # of replacement values");
LLVM_DEBUG({
impl->logger.startLine()
<< "** Replace : '" << op->getName() << "'(" << op << ")\n";
});
- impl->notifyOpReplaced(op, newValues);
+ impl->notifyOpReplaced(op, std::move(newValues));
}
void ConversionPatternRewriter::eraseOp(Operation *op) {
@@ -1664,7 +1679,7 @@ void ConversionPatternRewriter::eraseOp(Operation *op) {
<< "** Erase : '" << op->getName() << "'(" << op << ")\n";
});
SmallVector<ValueVector> nullRepls(op->getNumResults(), ValueVector());
- impl->notifyOpReplaced(op, nullRepls);
+ impl->notifyOpReplaced(op, std::move(nullRepls));
}
void ConversionPatternRewriter::eraseBlock(Block *block) {