diff options
author | Matthias Springer <mspringer@nvidia.com> | 2024-10-05 17:29:12 +0200 |
---|---|---|
committer | Matthias Springer <mspringer@nvidia.com> | 2024-10-12 10:45:42 +0200 |
commit | 16388fdda61e751c85a2dcb8beff8e2fa337b698 (patch) | |
tree | 83d4ff431204cc44e312006aba02322203b34de5 | |
parent | 9f24c145494ee238e65e25205a4dcb4451f009ae (diff) | |
download | llvm-users/matthias-springer/one_to_n_pattern.zip llvm-users/matthias-springer/one_to_n_pattern.tar.gz llvm-users/matthias-springer/one_to_n_pattern.tar.bz2 |
[WIP] 1:N conversion patternusers/matthias-springer/one_to_n_pattern
-rwxr-xr-x | mlir/artifacts/jq-linux64 | bin | 0 -> 3953824 bytes | |||
-rw-r--r-- | mlir/include/mlir/Conversion/LLVMCommon/Pattern.h | 35 | ||||
-rw-r--r-- | mlir/include/mlir/Transforms/DialectConversion.h | 57 | ||||
-rw-r--r-- | mlir/lib/Transforms/Utils/DialectConversion.cpp | 45 | ||||
-rw-r--r-- | mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp | 11 |
5 files changed, 133 insertions, 15 deletions
diff --git a/mlir/artifacts/jq-linux64 b/mlir/artifacts/jq-linux64 Binary files differnew file mode 100755 index 0000000..f48b0ca --- /dev/null +++ b/mlir/artifacts/jq-linux64 diff --git a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h index f3bf5b6..6751c3e 100644 --- a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h +++ b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h @@ -143,6 +143,8 @@ template <typename SourceOp> class ConvertOpToLLVMPattern : public ConvertToLLVMPattern { public: using OpAdaptor = typename SourceOp::Adaptor; + using OneToNOpAdaptor = + typename SourceOp::template GenericAdaptor<ArrayRef<ArrayRef<Value>>>; explicit ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit = 1) @@ -153,8 +155,13 @@ public: /// Wrappers around the RewritePattern methods that pass the derived op type. void rewrite(Operation *op, ArrayRef<Value> operands, ConversionPatternRewriter &rewriter) const final { - rewrite(cast<SourceOp>(op), OpAdaptor(operands, cast<SourceOp>(op)), - rewriter); + auto sourceOp = cast<SourceOp>(op); + rewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter); + } + void rewrite(Operation *op, ArrayRef<ArrayRef<Value>> operands, + ConversionPatternRewriter &rewriter) const final { + auto sourceOp = cast<SourceOp>(op); + rewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp), rewriter); } LogicalResult match(Operation *op) const final { return match(cast<SourceOp>(op)); @@ -162,8 +169,15 @@ public: LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands, ConversionPatternRewriter &rewriter) const final { - return matchAndRewrite(cast<SourceOp>(op), - OpAdaptor(operands, cast<SourceOp>(op)), rewriter); + auto sourceOp = cast<SourceOp>(op); + return matchAndRewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter); + } + LogicalResult + matchAndRewrite(Operation *op, ArrayRef<ArrayRef<Value>> operands, + ConversionPatternRewriter &rewriter) const final { + auto sourceOp = cast<SourceOp>(op); + return matchAndRewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp), + rewriter); } /// Rewrite and Match methods that operate on the SourceOp type. These must be @@ -175,6 +189,12 @@ public: ConversionPatternRewriter &rewriter) const { llvm_unreachable("must override rewrite or matchAndRewrite"); } + virtual void rewrite(SourceOp op, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + SmallVector<Value> oneToOneOperands = + getOneToOneAdaptorOperands(adaptor.getOperands()); + rewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter); + } virtual LogicalResult matchAndRewrite(SourceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { @@ -183,6 +203,13 @@ public: rewrite(op, adaptor, rewriter); return success(); } + virtual LogicalResult + matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + SmallVector<Value> oneToOneOperands = + getOneToOneAdaptorOperands(adaptor.getOperands()); + return matchAndRewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter); + } private: using ConvertToLLVMPattern::match; diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h index 65e279e..080129b 100644 --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -467,6 +467,10 @@ public: ConversionPatternRewriter &rewriter) const { llvm_unreachable("unimplemented rewrite"); } + virtual void rewrite(Operation *op, ArrayRef<ArrayRef<Value>> operands, + ConversionPatternRewriter &rewriter) const { + rewrite(op, getOneToOneAdaptorOperands(operands), rewriter); + } /// Hook for derived classes to implement combined matching and rewriting. virtual LogicalResult @@ -477,6 +481,11 @@ public: rewrite(op, operands, rewriter); return success(); } + virtual LogicalResult + matchAndRewrite(Operation *op, ArrayRef<ArrayRef<Value>> operands, + ConversionPatternRewriter &rewriter) const { + return matchAndRewrite(op, getOneToOneAdaptorOperands(operands), rewriter); + } /// Attempt to match and rewrite the IR root at the specified operation. LogicalResult matchAndRewrite(Operation *op, @@ -504,6 +513,9 @@ protected: : RewritePattern(std::forward<Args>(args)...), typeConverter(&typeConverter) {} + static SmallVector<Value> + getOneToOneAdaptorOperands(ArrayRef<ArrayRef<Value>> operands); + protected: /// An optional type converter for use by this pattern. const TypeConverter *typeConverter = nullptr; @@ -519,6 +531,8 @@ template <typename SourceOp> class OpConversionPattern : public ConversionPattern { public: using OpAdaptor = typename SourceOp::Adaptor; + using OneToNOpAdaptor = + typename SourceOp::template GenericAdaptor<ArrayRef<ArrayRef<Value>>>; OpConversionPattern(MLIRContext *context, PatternBenefit benefit = 1) : ConversionPattern(SourceOp::getOperationName(), benefit, context) {} @@ -537,12 +551,24 @@ public: auto sourceOp = cast<SourceOp>(op); rewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter); } + void rewrite(Operation *op, ArrayRef<ArrayRef<Value>> operands, + ConversionPatternRewriter &rewriter) const final { + auto sourceOp = cast<SourceOp>(op); + rewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp), rewriter); + } LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands, ConversionPatternRewriter &rewriter) const final { auto sourceOp = cast<SourceOp>(op); return matchAndRewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter); } + LogicalResult + matchAndRewrite(Operation *op, ArrayRef<ArrayRef<Value>> operands, + ConversionPatternRewriter &rewriter) const final { + auto sourceOp = cast<SourceOp>(op); + return matchAndRewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp), + rewriter); + } /// Rewrite and Match methods that operate on the SourceOp type. These must be /// overridden by the derived pattern class. @@ -553,6 +579,12 @@ public: ConversionPatternRewriter &rewriter) const { llvm_unreachable("must override matchAndRewrite or a rewrite method"); } + virtual void rewrite(SourceOp op, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + SmallVector<Value> oneToOneOperands = + getOneToOneAdaptorOperands(adaptor.getOperands()); + rewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter); + } virtual LogicalResult matchAndRewrite(SourceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { @@ -561,6 +593,13 @@ public: rewrite(op, adaptor, rewriter); return success(); } + virtual LogicalResult + matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + SmallVector<Value> oneToOneOperands = + getOneToOneAdaptorOperands(adaptor.getOperands()); + return matchAndRewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter); + } private: using ConversionPattern::matchAndRewrite; @@ -586,11 +625,20 @@ public: ConversionPatternRewriter &rewriter) const final { rewrite(cast<SourceOp>(op), operands, rewriter); } + void rewrite(Operation *op, ArrayRef<ArrayRef<Value>> operands, + ConversionPatternRewriter &rewriter) const final { + rewrite(cast<SourceOp>(op), operands, rewriter); + } LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands, ConversionPatternRewriter &rewriter) const final { return matchAndRewrite(cast<SourceOp>(op), operands, rewriter); } + LogicalResult + matchAndRewrite(Operation *op, ArrayRef<ArrayRef<Value>> operands, + ConversionPatternRewriter &rewriter) const final { + return matchAndRewrite(cast<SourceOp>(op), operands, rewriter); + } /// Rewrite and Match methods that operate on the SourceOp type. These must be /// overridden by the derived pattern class. @@ -598,6 +646,10 @@ public: ConversionPatternRewriter &rewriter) const { llvm_unreachable("must override matchAndRewrite or a rewrite method"); } + virtual void rewrite(SourceOp op, ArrayRef<ArrayRef<Value>> operands, + ConversionPatternRewriter &rewriter) const { + rewrite(op, getOneToOneAdaptorOperands(operands), rewriter); + } virtual LogicalResult matchAndRewrite(SourceOp op, ArrayRef<Value> operands, ConversionPatternRewriter &rewriter) const { @@ -606,6 +658,11 @@ public: rewrite(op, operands, rewriter); return success(); } + virtual LogicalResult + matchAndRewrite(SourceOp op, ArrayRef<ArrayRef<Value>> operands, + ConversionPatternRewriter &rewriter) const { + return matchAndRewrite(op, getOneToOneAdaptorOperands(operands), rewriter); + } private: using ConversionPattern::matchAndRewrite; diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index 97dd3ab..0d13eb5 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -769,7 +769,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { LogicalResult remapValues(StringRef valueDiagTag, std::optional<Location> inputLoc, PatternRewriter &rewriter, ValueRange values, - SmallVectorImpl<Value> &remapped); + SmallVector<SmallVector<Value, 1>> &remapped); /// Return "true" if the given operation is ignored, and does not need to be /// converted. @@ -1089,7 +1089,7 @@ void ConversionPatternRewriterImpl::undoRewrites(unsigned numRewritesToKeep) { LogicalResult ConversionPatternRewriterImpl::remapValues( StringRef valueDiagTag, std::optional<Location> inputLoc, PatternRewriter &rewriter, ValueRange values, - SmallVectorImpl<Value> &remapped) { + SmallVector<SmallVector<Value, 1>> &remapped) { remapped.reserve(llvm::size(values)); for (const auto &it : llvm::enumerate(values)) { @@ -1101,7 +1101,7 @@ LogicalResult ConversionPatternRewriterImpl::remapValues( // The current pattern does not have a type converter. I.e., it does not // distinguish between legal and illegal types. For each operand, simply // pass through the most recently mapped value. - remapped.push_back(mapping.lookupOrDefault(operand)); + remapped.push_back({mapping.lookupOrDefault(operand)}); continue; } @@ -1123,7 +1123,7 @@ LogicalResult ConversionPatternRewriterImpl::remapValues( // improvements to the `ConversionValueMapping` (to be able to store 1:N // mappings) and to the `ConversionPattern` adaptor handling (to be able // to pass multiple remapped values for a single operand to the adaptor). - remapped.push_back(mapping.lookupOrDefault(operand)); + remapped.push_back({mapping.lookupOrDefault(operand)}); continue; } @@ -1143,7 +1143,7 @@ LogicalResult ConversionPatternRewriterImpl::remapValues( mapping.map(newOperand, castValue); newOperand = castValue; } - remapped.push_back(newOperand); + remapped.push_back({newOperand}); } return success(); } @@ -1523,11 +1523,12 @@ void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from, } Value ConversionPatternRewriter::getRemappedValue(Value key) { - SmallVector<Value> remappedValues; + SmallVector<SmallVector<Value, 1>> remappedValues; if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, key, remappedValues))) return nullptr; - return remappedValues.front(); + assert(remappedValues.front().size() == 1 && "1:N conversion not supported"); + return remappedValues.front().front(); } LogicalResult @@ -1535,8 +1536,15 @@ ConversionPatternRewriter::getRemappedValues(ValueRange keys, SmallVectorImpl<Value> &results) { if (keys.empty()) return success(); - return impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, keys, - results); + SmallVector<SmallVector<Value, 1>> remapped; + if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, keys, + remapped))) + return failure(); + for (const auto &values : remapped) { + assert(values.size() == 1 && "1:N conversion not supported"); + results.push_back(values.front()); + } + return success(); } void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest, @@ -1630,6 +1638,16 @@ detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() { // ConversionPattern //===----------------------------------------------------------------------===// +SmallVector<Value> ConversionPattern::getOneToOneAdaptorOperands( + ArrayRef<ArrayRef<Value>> operands) { + SmallVector<Value> oneToOneOperands; + oneToOneOperands.reserve(operands.size()); + for (ArrayRef<Value> operand : operands) { + assert(operand.size() == 1 && "pattern does not support 1:N conversion"); + oneToOneOperands.push_back(operand.front()); + } +} + LogicalResult ConversionPattern::matchAndRewrite(Operation *op, PatternRewriter &rewriter) const { @@ -1641,11 +1659,16 @@ ConversionPattern::matchAndRewrite(Operation *op, getTypeConverter()); // Remap the operands of the operation. - SmallVector<Value, 4> operands; + SmallVector<SmallVector<Value, 1>> remapped; if (failed(rewriterImpl.remapValues("operand", op->getLoc(), rewriter, - op->getOperands(), operands))) { + op->getOperands(), remapped))) { return failure(); } + SmallVector<Value, 4> operands; + for (const auto &values : remapped) { + assert(values.size() == 1 && "1:N conversion not supported"); + operands.push_back(values.front()); + } return matchAndRewrite(op, operands, dialectRewriter); } diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp index 67df002..299eacc 100644 --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -4282,6 +4282,17 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter( } } + { + SmallVector<MethodParameter> paramList; + paramList.emplace_back("RangeT", "values"); + paramList.emplace_back("const " + op.getGenericAdaptorName() + "Base &", + "base"); + auto *constructor = + genericAdaptor.addConstructor<Method::Inline>(paramList); + constructor->addMemberInitializer("Base", "base"); + constructor->addMemberInitializer("odsOperands", "values"); + } + // Create constructors constructing the adaptor from an instance of the op. // This takes the attributes, properties and regions from the op instance // and the value range from the parameter. |