aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMatthias Springer <mspringer@nvidia.com>2024-10-05 17:29:12 +0200
committerMatthias Springer <mspringer@nvidia.com>2024-10-12 10:45:42 +0200
commit16388fdda61e751c85a2dcb8beff8e2fa337b698 (patch)
tree83d4ff431204cc44e312006aba02322203b34de5
parent9f24c145494ee238e65e25205a4dcb4451f009ae (diff)
downloadllvm-users/matthias-springer/one_to_n_pattern.zip
llvm-users/matthias-springer/one_to_n_pattern.tar.gz
llvm-users/matthias-springer/one_to_n_pattern.tar.bz2
[WIP] 1:N conversion patternusers/matthias-springer/one_to_n_pattern
-rwxr-xr-xmlir/artifacts/jq-linux64bin0 -> 3953824 bytes
-rw-r--r--mlir/include/mlir/Conversion/LLVMCommon/Pattern.h35
-rw-r--r--mlir/include/mlir/Transforms/DialectConversion.h57
-rw-r--r--mlir/lib/Transforms/Utils/DialectConversion.cpp45
-rw-r--r--mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp11
5 files changed, 133 insertions, 15 deletions
diff --git a/mlir/artifacts/jq-linux64 b/mlir/artifacts/jq-linux64
new file mode 100755
index 0000000..f48b0ca
--- /dev/null
+++ b/mlir/artifacts/jq-linux64
Binary files differ
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
index f3bf5b6..6751c3e 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
@@ -143,6 +143,8 @@ template <typename SourceOp>
class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
public:
using OpAdaptor = typename SourceOp::Adaptor;
+ using OneToNOpAdaptor =
+ typename SourceOp::template GenericAdaptor<ArrayRef<ArrayRef<Value>>>;
explicit ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter,
PatternBenefit benefit = 1)
@@ -153,8 +155,13 @@ public:
/// Wrappers around the RewritePattern methods that pass the derived op type.
void rewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
- rewrite(cast<SourceOp>(op), OpAdaptor(operands, cast<SourceOp>(op)),
- rewriter);
+ auto sourceOp = cast<SourceOp>(op);
+ rewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
+ }
+ void rewrite(Operation *op, ArrayRef<ArrayRef<Value>> operands,
+ ConversionPatternRewriter &rewriter) const final {
+ auto sourceOp = cast<SourceOp>(op);
+ rewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp), rewriter);
}
LogicalResult match(Operation *op) const final {
return match(cast<SourceOp>(op));
@@ -162,8 +169,15 @@ public:
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
- return matchAndRewrite(cast<SourceOp>(op),
- OpAdaptor(operands, cast<SourceOp>(op)), rewriter);
+ auto sourceOp = cast<SourceOp>(op);
+ return matchAndRewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
+ }
+ LogicalResult
+ matchAndRewrite(Operation *op, ArrayRef<ArrayRef<Value>> operands,
+ ConversionPatternRewriter &rewriter) const final {
+ auto sourceOp = cast<SourceOp>(op);
+ return matchAndRewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp),
+ rewriter);
}
/// Rewrite and Match methods that operate on the SourceOp type. These must be
@@ -175,6 +189,12 @@ public:
ConversionPatternRewriter &rewriter) const {
llvm_unreachable("must override rewrite or matchAndRewrite");
}
+ virtual void rewrite(SourceOp op, OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const {
+ SmallVector<Value> oneToOneOperands =
+ getOneToOneAdaptorOperands(adaptor.getOperands());
+ rewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
+ }
virtual LogicalResult
matchAndRewrite(SourceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
@@ -183,6 +203,13 @@ public:
rewrite(op, adaptor, rewriter);
return success();
}
+ virtual LogicalResult
+ matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const {
+ SmallVector<Value> oneToOneOperands =
+ getOneToOneAdaptorOperands(adaptor.getOperands());
+ return matchAndRewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
+ }
private:
using ConvertToLLVMPattern::match;
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 65e279e..080129b 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -467,6 +467,10 @@ public:
ConversionPatternRewriter &rewriter) const {
llvm_unreachable("unimplemented rewrite");
}
+ virtual void rewrite(Operation *op, ArrayRef<ArrayRef<Value>> operands,
+ ConversionPatternRewriter &rewriter) const {
+ rewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
+ }
/// Hook for derived classes to implement combined matching and rewriting.
virtual LogicalResult
@@ -477,6 +481,11 @@ public:
rewrite(op, operands, rewriter);
return success();
}
+ virtual LogicalResult
+ matchAndRewrite(Operation *op, ArrayRef<ArrayRef<Value>> operands,
+ ConversionPatternRewriter &rewriter) const {
+ return matchAndRewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
+ }
/// Attempt to match and rewrite the IR root at the specified operation.
LogicalResult matchAndRewrite(Operation *op,
@@ -504,6 +513,9 @@ protected:
: RewritePattern(std::forward<Args>(args)...),
typeConverter(&typeConverter) {}
+ static SmallVector<Value>
+ getOneToOneAdaptorOperands(ArrayRef<ArrayRef<Value>> operands);
+
protected:
/// An optional type converter for use by this pattern.
const TypeConverter *typeConverter = nullptr;
@@ -519,6 +531,8 @@ template <typename SourceOp>
class OpConversionPattern : public ConversionPattern {
public:
using OpAdaptor = typename SourceOp::Adaptor;
+ using OneToNOpAdaptor =
+ typename SourceOp::template GenericAdaptor<ArrayRef<ArrayRef<Value>>>;
OpConversionPattern(MLIRContext *context, PatternBenefit benefit = 1)
: ConversionPattern(SourceOp::getOperationName(), benefit, context) {}
@@ -537,12 +551,24 @@ public:
auto sourceOp = cast<SourceOp>(op);
rewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
}
+ void rewrite(Operation *op, ArrayRef<ArrayRef<Value>> operands,
+ ConversionPatternRewriter &rewriter) const final {
+ auto sourceOp = cast<SourceOp>(op);
+ rewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp), rewriter);
+ }
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
auto sourceOp = cast<SourceOp>(op);
return matchAndRewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
}
+ LogicalResult
+ matchAndRewrite(Operation *op, ArrayRef<ArrayRef<Value>> operands,
+ ConversionPatternRewriter &rewriter) const final {
+ auto sourceOp = cast<SourceOp>(op);
+ return matchAndRewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp),
+ rewriter);
+ }
/// Rewrite and Match methods that operate on the SourceOp type. These must be
/// overridden by the derived pattern class.
@@ -553,6 +579,12 @@ public:
ConversionPatternRewriter &rewriter) const {
llvm_unreachable("must override matchAndRewrite or a rewrite method");
}
+ virtual void rewrite(SourceOp op, OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const {
+ SmallVector<Value> oneToOneOperands =
+ getOneToOneAdaptorOperands(adaptor.getOperands());
+ rewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
+ }
virtual LogicalResult
matchAndRewrite(SourceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
@@ -561,6 +593,13 @@ public:
rewrite(op, adaptor, rewriter);
return success();
}
+ virtual LogicalResult
+ matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const {
+ SmallVector<Value> oneToOneOperands =
+ getOneToOneAdaptorOperands(adaptor.getOperands());
+ return matchAndRewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
+ }
private:
using ConversionPattern::matchAndRewrite;
@@ -586,11 +625,20 @@ public:
ConversionPatternRewriter &rewriter) const final {
rewrite(cast<SourceOp>(op), operands, rewriter);
}
+ void rewrite(Operation *op, ArrayRef<ArrayRef<Value>> operands,
+ ConversionPatternRewriter &rewriter) const final {
+ rewrite(cast<SourceOp>(op), operands, rewriter);
+ }
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
return matchAndRewrite(cast<SourceOp>(op), operands, rewriter);
}
+ LogicalResult
+ matchAndRewrite(Operation *op, ArrayRef<ArrayRef<Value>> operands,
+ ConversionPatternRewriter &rewriter) const final {
+ return matchAndRewrite(cast<SourceOp>(op), operands, rewriter);
+ }
/// Rewrite and Match methods that operate on the SourceOp type. These must be
/// overridden by the derived pattern class.
@@ -598,6 +646,10 @@ public:
ConversionPatternRewriter &rewriter) const {
llvm_unreachable("must override matchAndRewrite or a rewrite method");
}
+ virtual void rewrite(SourceOp op, ArrayRef<ArrayRef<Value>> operands,
+ ConversionPatternRewriter &rewriter) const {
+ rewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
+ }
virtual LogicalResult
matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
@@ -606,6 +658,11 @@ public:
rewrite(op, operands, rewriter);
return success();
}
+ virtual LogicalResult
+ matchAndRewrite(SourceOp op, ArrayRef<ArrayRef<Value>> operands,
+ ConversionPatternRewriter &rewriter) const {
+ return matchAndRewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
+ }
private:
using ConversionPattern::matchAndRewrite;
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 97dd3ab..0d13eb5 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -769,7 +769,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
LogicalResult remapValues(StringRef valueDiagTag,
std::optional<Location> inputLoc,
PatternRewriter &rewriter, ValueRange values,
- SmallVectorImpl<Value> &remapped);
+ SmallVector<SmallVector<Value, 1>> &remapped);
/// Return "true" if the given operation is ignored, and does not need to be
/// converted.
@@ -1089,7 +1089,7 @@ void ConversionPatternRewriterImpl::undoRewrites(unsigned numRewritesToKeep) {
LogicalResult ConversionPatternRewriterImpl::remapValues(
StringRef valueDiagTag, std::optional<Location> inputLoc,
PatternRewriter &rewriter, ValueRange values,
- SmallVectorImpl<Value> &remapped) {
+ SmallVector<SmallVector<Value, 1>> &remapped) {
remapped.reserve(llvm::size(values));
for (const auto &it : llvm::enumerate(values)) {
@@ -1101,7 +1101,7 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
// The current pattern does not have a type converter. I.e., it does not
// distinguish between legal and illegal types. For each operand, simply
// pass through the most recently mapped value.
- remapped.push_back(mapping.lookupOrDefault(operand));
+ remapped.push_back({mapping.lookupOrDefault(operand)});
continue;
}
@@ -1123,7 +1123,7 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
// improvements to the `ConversionValueMapping` (to be able to store 1:N
// mappings) and to the `ConversionPattern` adaptor handling (to be able
// to pass multiple remapped values for a single operand to the adaptor).
- remapped.push_back(mapping.lookupOrDefault(operand));
+ remapped.push_back({mapping.lookupOrDefault(operand)});
continue;
}
@@ -1143,7 +1143,7 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
mapping.map(newOperand, castValue);
newOperand = castValue;
}
- remapped.push_back(newOperand);
+ remapped.push_back({newOperand});
}
return success();
}
@@ -1523,11 +1523,12 @@ void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
}
Value ConversionPatternRewriter::getRemappedValue(Value key) {
- SmallVector<Value> remappedValues;
+ SmallVector<SmallVector<Value, 1>> remappedValues;
if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, key,
remappedValues)))
return nullptr;
- return remappedValues.front();
+ assert(remappedValues.front().size() == 1 && "1:N conversion not supported");
+ return remappedValues.front().front();
}
LogicalResult
@@ -1535,8 +1536,15 @@ ConversionPatternRewriter::getRemappedValues(ValueRange keys,
SmallVectorImpl<Value> &results) {
if (keys.empty())
return success();
- return impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, keys,
- results);
+ SmallVector<SmallVector<Value, 1>> remapped;
+ if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, keys,
+ remapped)))
+ return failure();
+ for (const auto &values : remapped) {
+ assert(values.size() == 1 && "1:N conversion not supported");
+ results.push_back(values.front());
+ }
+ return success();
}
void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest,
@@ -1630,6 +1638,16 @@ detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() {
// ConversionPattern
//===----------------------------------------------------------------------===//
+SmallVector<Value> ConversionPattern::getOneToOneAdaptorOperands(
+ ArrayRef<ArrayRef<Value>> operands) {
+ SmallVector<Value> oneToOneOperands;
+ oneToOneOperands.reserve(operands.size());
+ for (ArrayRef<Value> operand : operands) {
+ assert(operand.size() == 1 && "pattern does not support 1:N conversion");
+ oneToOneOperands.push_back(operand.front());
+ }
+}
+
LogicalResult
ConversionPattern::matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const {
@@ -1641,11 +1659,16 @@ ConversionPattern::matchAndRewrite(Operation *op,
getTypeConverter());
// Remap the operands of the operation.
- SmallVector<Value, 4> operands;
+ SmallVector<SmallVector<Value, 1>> remapped;
if (failed(rewriterImpl.remapValues("operand", op->getLoc(), rewriter,
- op->getOperands(), operands))) {
+ op->getOperands(), remapped))) {
return failure();
}
+ SmallVector<Value, 4> operands;
+ for (const auto &values : remapped) {
+ assert(values.size() == 1 && "1:N conversion not supported");
+ operands.push_back(values.front());
+ }
return matchAndRewrite(op, operands, dialectRewriter);
}
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index 67df002..299eacc 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -4282,6 +4282,17 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(
}
}
+ {
+ SmallVector<MethodParameter> paramList;
+ paramList.emplace_back("RangeT", "values");
+ paramList.emplace_back("const " + op.getGenericAdaptorName() + "Base &",
+ "base");
+ auto *constructor =
+ genericAdaptor.addConstructor<Method::Inline>(paramList);
+ constructor->addMemberInitializer("Base", "base");
+ constructor->addMemberInitializer("odsOperands", "values");
+ }
+
// Create constructors constructing the adaptor from an instance of the op.
// This takes the attributes, properties and regions from the op instance
// and the value range from the parameter.