diff options
author | Markus Böck <markus.boeck02@gmail.com> | 2025-03-26 09:12:16 +0100 |
---|---|---|
committer | Markus Böck <markus.boeck02@gmail.com> | 2025-03-26 09:12:16 +0100 |
commit | b30cdb12781fa603eb8e3803449f7f1530b7439d (patch) | |
tree | 76e0673d7a95e5f5027566698eeaeaa6bd20bedf | |
parent | 1b6eddfcd96bb749c05a6ab927c3ccd666b6d984 (diff) | |
download | llvm-users/zero9178/replace_op_with_multiple_overloads_suggestion.zip llvm-users/zero9178/replace_op_with_multiple_overloads_suggestion.tar.gz llvm-users/zero9178/replace_op_with_multiple_overloads_suggestion.tar.bz2 |
-rw-r--r-- | llvm/include/llvm/ADT/SmallVector.h | 5 | ||||
-rw-r--r-- | mlir/include/mlir/Transforms/DialectConversion.h | 22 | ||||
-rw-r--r-- | mlir/lib/Transforms/Utils/DialectConversion.cpp | 33 |
3 files changed, 38 insertions, 22 deletions
diff --git a/llvm/include/llvm/ADT/SmallVector.h b/llvm/include/llvm/ADT/SmallVector.h index bd3e887..917fbce 100644 --- a/llvm/include/llvm/ADT/SmallVector.h +++ b/llvm/include/llvm/ADT/SmallVector.h @@ -1238,6 +1238,11 @@ public: SmallVectorImpl<T>::operator=(RHS); } + SmallVector(const SmallVectorImpl<T> &RHS) : SmallVectorImpl<T>(N) { + if (!RHS.empty()) + SmallVectorImpl<T>::operator=(RHS); + } + SmallVector &operator=(const SmallVector &RHS) { SmallVectorImpl<T>::operator=(RHS); return *this; diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h index e4a785ea..7ec0ab2 100644 --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -898,28 +898,24 @@ public: /// Replace the given operation with the new value ranges. The number of op /// results and value ranges must match. The given operation is erased. void replaceOpWithMultiple(Operation *op, - ArrayRef<SmallVector<Value, 1>> newValues); + SmallVector<SmallVector<Value>> &&newValues); + // Note: This overload matches SmallVector<ValueRange>, // SmallVector<SmallVector<Value>>, etc. template <typename RangeRangeT> void replaceOpWithMultiple(Operation *op, RangeRangeT &&newValues) { - // Note: Prefer the ArrayRef<SmallVector<Value, 1>> overload because it - // does not copy the replacements vector. - auto vals = llvm::map_to_vector(newValues, [](const auto &r) { - return SmallVector<Value, 1>(std::begin(r), std::end(r)); - }); - replaceOpWithMultiple(op, ArrayRef(vals)); + replaceOpWithMultiple(op, llvm::map_to_vector(newValues, [](const auto &r) { + return llvm::to_vector(r); + })); } + // Note: This overload matches initializer list of ValueRange, // SmallVector<Value>, etc. template <typename RangeT = ValueRange> void replaceOpWithMultiple(Operation *op, ArrayRef<RangeT> newValues) { - // Note: Prefer the ArrayRef<SmallVector<Value, 1>> overload because it - // does not copy the replacements vector. - auto vals = llvm::map_to_vector(newValues, [](const RangeT &r) { - return SmallVector<Value, 1>(std::begin(r), std::end(r)); - }); - replaceOpWithMultiple(op, ArrayRef(vals)); + replaceOpWithMultiple(op, llvm::map_to_vector(newValues, [](const RangeT &r) { + return SmallVector<Value>(std::begin(r), std::end(r)); + })); } /// PatternRewriter hook for erasing a dead operation. The uses of this diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index c692aaf..a77a99b 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -137,8 +137,20 @@ struct ConversionValueMapping { /// as `lookupOrDefault`. ValueVector lookupOrNull(Value from, TypeRange desiredTypes = {}) const; - template <typename T> - struct IsValueVector : std::is_same<std::decay_t<T>, ValueVector> {}; + template <typename> + struct IsValueVector : std::false_type {}; + + template <typename T, size_t n> + struct IsValueVector<SmallVector<T, n>> : std::true_type {}; + + template <typename T, size_t n> + struct IsValueVector<SmallVector<T, n> &> : std::true_type {}; + + template <typename T, size_t n> + struct IsValueVector<SmallVector<T, n> &&> : std::true_type {}; + + template <typename T, size_t n> + struct IsValueVector<const SmallVector<T, n> &> : std::true_type {}; /// Map a value vector to the one provided. template <typename OldVal, typename NewVal> @@ -947,7 +959,9 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { OpBuilder::InsertPoint previous) override; /// Notifies that an op is about to be replaced with the given values. - void notifyOpReplaced(Operation *op, ArrayRef<ValueVector> newValues); + template <unsigned N> + void notifyOpReplaced(Operation *op, + SmallVector<SmallVector<Value, N>> &&newValues); /// Notifies that a block is about to be erased. void notifyBlockIsBeingErased(Block *block); @@ -1519,8 +1533,9 @@ void ConversionPatternRewriterImpl::notifyOperationInserted( appendRewrite<MoveOperationRewrite>(op, previous.getBlock(), prevOp); } +template <unsigned N> void ConversionPatternRewriterImpl::notifyOpReplaced( - Operation *op, ArrayRef<ValueVector> newValues) { + Operation *op, SmallVector<SmallVector<Value, N>> &&newValues) { assert(newValues.size() == op->getNumResults()); assert(!ignoredOps.contains(op) && "operation was already replaced"); @@ -1562,7 +1577,7 @@ void ConversionPatternRewriterImpl::notifyOpReplaced( // Remap result to replacement value. if (repl.empty()) continue; - mapping.map(result, repl); + mapping.map(result, std::move(repl)); } appendRewrite<ReplaceOperationRewrite>(op, currentTypeConverter); @@ -1644,18 +1659,18 @@ void ConversionPatternRewriter::replaceOp(Operation *op, ValueRange newValues) { llvm::map_to_vector(newValues, [](Value v) -> ValueVector { return v ? ValueVector{v} : ValueVector(); }); - impl->notifyOpReplaced(op, newVals); + impl->notifyOpReplaced(op, std::move(newVals)); } void ConversionPatternRewriter::replaceOpWithMultiple( - Operation *op, ArrayRef<SmallVector<Value, 1>> newValues) { + Operation *op, SmallVector<SmallVector<Value>> &&newValues) { assert(op->getNumResults() == newValues.size() && "incorrect # of replacement values"); LLVM_DEBUG({ impl->logger.startLine() << "** Replace : '" << op->getName() << "'(" << op << ")\n"; }); - impl->notifyOpReplaced(op, newValues); + impl->notifyOpReplaced(op, std::move(newValues)); } void ConversionPatternRewriter::eraseOp(Operation *op) { @@ -1664,7 +1679,7 @@ void ConversionPatternRewriter::eraseOp(Operation *op) { << "** Erase : '" << op->getName() << "'(" << op << ")\n"; }); SmallVector<ValueVector> nullRepls(op->getNumResults(), ValueVector()); - impl->notifyOpReplaced(op, nullRepls); + impl->notifyOpReplaced(op, std::move(nullRepls)); } void ConversionPatternRewriter::eraseBlock(Block *block) { |