aboutsummaryrefslogtreecommitdiff
path: root/mlir/test/lib/Dialect/Test/TestPatterns.cpp
diff options
context:
space:
mode:
authorNAKAMURA Takumi <geek4civic@gmail.com>2025-01-09 18:43:11 +0900
committerNAKAMURA Takumi <geek4civic@gmail.com>2025-01-09 18:43:11 +0900
commit0e1a753549b29ff1f5a190aca83b803a33b51628 (patch)
treee5578f8810c65711304128d0c8add7fa1f77b9d8 /mlir/test/lib/Dialect/Test/TestPatterns.cpp
parent3c6252260ee11e3a453076b4d96ffffe20d49998 (diff)
parentbdcf47e4bcb92889665825654bb80a8bbe30379e (diff)
downloadllvm-users/chapuni/cov/single/if.zip
llvm-users/chapuni/cov/single/if.tar.gz
llvm-users/chapuni/cov/single/if.tar.bz2
Merge branch 'users/chapuni/cov/single/base' into users/chapuni/cov/single/ifusers/chapuni/cov/single/if
Conflicts: clang/lib/CodeGen/CoverageMappingGen.cpp
Diffstat (limited to 'mlir/test/lib/Dialect/Test/TestPatterns.cpp')
-rw-r--r--mlir/test/lib/Dialect/Test/TestPatterns.cpp44
1 files changed, 39 insertions, 5 deletions
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index a470497..5b7c36c 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -785,7 +785,7 @@ struct TestDetachedSignatureConversion : public ConversionPattern {
ConversionPatternRewriter &rewriter) const final {
if (op->getNumRegions() != 1)
return failure();
- OperationState state(op->getLoc(), "test.legal_op_with_region", operands,
+ OperationState state(op->getLoc(), "test.legal_op", operands,
op->getResultTypes(), {}, BlockRange());
Region *newRegion = state.addRegion();
rewriter.inlineRegionBefore(op->getRegion(0), *newRegion,
@@ -1234,6 +1234,41 @@ public:
}
};
+/// A pattern that tests two back-to-back 1 -> 2 op replacements.
+class TestMultiple1ToNReplacement : public ConversionPattern {
+public:
+ TestMultiple1ToNReplacement(MLIRContext *ctx, const TypeConverter &converter)
+ : ConversionPattern(converter, "test.multiple_1_to_n_replacement", 1,
+ ctx) {}
+ LogicalResult
+ matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
+ ConversionPatternRewriter &rewriter) const final {
+ // Helper function that replaces the given op with a new op of the given
+ // name and doubles each result (1 -> 2 replacement of each result).
+ auto replaceWithDoubleResults = [&](Operation *op, StringRef name) {
+ SmallVector<Type> types;
+ for (Type t : op->getResultTypes()) {
+ types.push_back(t);
+ types.push_back(t);
+ }
+ OperationState state(op->getLoc(), name,
+ /*operands=*/{}, types, op->getAttrs());
+ auto *newOp = rewriter.create(state);
+ SmallVector<ValueRange> repls;
+ for (size_t i = 0, e = op->getNumResults(); i < e; ++i)
+ repls.push_back(newOp->getResults().slice(2 * i, 2));
+ rewriter.replaceOpWithMultiple(op, repls);
+ return newOp;
+ };
+
+ // Replace test.multiple_1_to_n_replacement with test.step_1.
+ Operation *repl1 = replaceWithDoubleResults(op, "test.step_1");
+ // Now replace test.step_1 with test.legal_op.
+ replaceWithDoubleResults(repl1, "test.legal_op");
+ return success();
+ }
+};
+
} // namespace
namespace {
@@ -1241,7 +1276,6 @@ struct TestTypeConverter : public TypeConverter {
using TypeConverter::TypeConverter;
TestTypeConverter() {
addConversion(convertType);
- addArgumentMaterialization(materializeCast);
addSourceMaterialization(materializeCast);
}
@@ -1319,7 +1353,8 @@ struct TestLegalizePatternDriver
TestUndoPropertiesModification, TestEraseOp,
TestRepetitive1ToNConsumer>(&getContext());
patterns.add<TestDropOpSignatureConversion, TestDropAndReplaceInvalidOp,
- TestPassthroughInvalidOp>(&getContext(), converter);
+ TestPassthroughInvalidOp, TestMultiple1ToNReplacement>(
+ &getContext(), converter);
patterns.add<TestDuplicateBlockArgs>(converter, &getContext());
mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
converter);
@@ -1330,8 +1365,7 @@ struct TestLegalizePatternDriver
target.addLegalOp<ModuleOp>();
target.addLegalOp<LegalOpA, LegalOpB, LegalOpC, TestCastOp, TestValidOp,
TerminatorOp, OneRegionOp>();
- target.addLegalOp(
- OperationName("test.legal_op_with_region", &getContext()));
+ target.addLegalOp(OperationName("test.legal_op", &getContext()));
target
.addIllegalOp<ILLegalOpF, TestRegionBuilderOp, TestOpWithRegionFold>();
target.addDynamicallyLegalOp<TestReturnOp>([](TestReturnOp op) {