aboutsummaryrefslogtreecommitdiff
path: root/mlir/test/lib/Dialect/Test/TestPatterns.cpp
diff options
context:
space:
mode:
authorNAKAMURA Takumi <geek4civic@gmail.com>2025-01-09 18:49:54 +0900
committerNAKAMURA Takumi <geek4civic@gmail.com>2025-01-09 18:49:54 +0900
commite2810c9a248f4c7fbfae84bb32b6f7e01027458b (patch)
treeae0b02a8491b969a1cee94ea16ffe42c559143c5 /mlir/test/lib/Dialect/Test/TestPatterns.cpp
parentfa04eb4af95c1ca7377279728cb004bcd2324d01 (diff)
parentbdcf47e4bcb92889665825654bb80a8bbe30379e (diff)
downloadllvm-users/chapuni/cov/single/switch.zip
llvm-users/chapuni/cov/single/switch.tar.gz
llvm-users/chapuni/cov/single/switch.tar.bz2
Merge branch 'users/chapuni/cov/single/base' into users/chapuni/cov/single/switchusers/chapuni/cov/single/switch
Diffstat (limited to 'mlir/test/lib/Dialect/Test/TestPatterns.cpp')
-rw-r--r--mlir/test/lib/Dialect/Test/TestPatterns.cpp44
1 files changed, 39 insertions, 5 deletions
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index a470497..5b7c36c 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -785,7 +785,7 @@ struct TestDetachedSignatureConversion : public ConversionPattern {
ConversionPatternRewriter &rewriter) const final {
if (op->getNumRegions() != 1)
return failure();
- OperationState state(op->getLoc(), "test.legal_op_with_region", operands,
+ OperationState state(op->getLoc(), "test.legal_op", operands,
op->getResultTypes(), {}, BlockRange());
Region *newRegion = state.addRegion();
rewriter.inlineRegionBefore(op->getRegion(0), *newRegion,
@@ -1234,6 +1234,41 @@ public:
}
};
+/// A pattern that tests two back-to-back 1 -> 2 op replacements.
+class TestMultiple1ToNReplacement : public ConversionPattern {
+public:
+ TestMultiple1ToNReplacement(MLIRContext *ctx, const TypeConverter &converter)
+ : ConversionPattern(converter, "test.multiple_1_to_n_replacement", 1,
+ ctx) {}
+ LogicalResult
+ matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
+ ConversionPatternRewriter &rewriter) const final {
+ // Helper function that replaces the given op with a new op of the given
+ // name and doubles each result (1 -> 2 replacement of each result).
+ auto replaceWithDoubleResults = [&](Operation *op, StringRef name) {
+ SmallVector<Type> types;
+ for (Type t : op->getResultTypes()) {
+ types.push_back(t);
+ types.push_back(t);
+ }
+ OperationState state(op->getLoc(), name,
+ /*operands=*/{}, types, op->getAttrs());
+ auto *newOp = rewriter.create(state);
+ SmallVector<ValueRange> repls;
+ for (size_t i = 0, e = op->getNumResults(); i < e; ++i)
+ repls.push_back(newOp->getResults().slice(2 * i, 2));
+ rewriter.replaceOpWithMultiple(op, repls);
+ return newOp;
+ };
+
+ // Replace test.multiple_1_to_n_replacement with test.step_1.
+ Operation *repl1 = replaceWithDoubleResults(op, "test.step_1");
+ // Now replace test.step_1 with test.legal_op.
+ replaceWithDoubleResults(repl1, "test.legal_op");
+ return success();
+ }
+};
+
} // namespace
namespace {
@@ -1241,7 +1276,6 @@ struct TestTypeConverter : public TypeConverter {
using TypeConverter::TypeConverter;
TestTypeConverter() {
addConversion(convertType);
- addArgumentMaterialization(materializeCast);
addSourceMaterialization(materializeCast);
}
@@ -1319,7 +1353,8 @@ struct TestLegalizePatternDriver
TestUndoPropertiesModification, TestEraseOp,
TestRepetitive1ToNConsumer>(&getContext());
patterns.add<TestDropOpSignatureConversion, TestDropAndReplaceInvalidOp,
- TestPassthroughInvalidOp>(&getContext(), converter);
+ TestPassthroughInvalidOp, TestMultiple1ToNReplacement>(
+ &getContext(), converter);
patterns.add<TestDuplicateBlockArgs>(converter, &getContext());
mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
converter);
@@ -1330,8 +1365,7 @@ struct TestLegalizePatternDriver
target.addLegalOp<ModuleOp>();
target.addLegalOp<LegalOpA, LegalOpB, LegalOpC, TestCastOp, TestValidOp,
TerminatorOp, OneRegionOp>();
- target.addLegalOp(
- OperationName("test.legal_op_with_region", &getContext()));
+ target.addLegalOp(OperationName("test.legal_op", &getContext()));
target
.addIllegalOp<ILLegalOpF, TestRegionBuilderOp, TestOpWithRegionFold>();
target.addDynamicallyLegalOp<TestReturnOp>([](TestReturnOp op) {