diff options
Diffstat (limited to 'mlir/test/lib/Dialect/Test/TestPatterns.cpp')
-rw-r--r-- | mlir/test/lib/Dialect/Test/TestPatterns.cpp | 44 |
1 files changed, 39 insertions, 5 deletions
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp index a470497..5b7c36c 100644 --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -785,7 +785,7 @@ struct TestDetachedSignatureConversion : public ConversionPattern { ConversionPatternRewriter &rewriter) const final { if (op->getNumRegions() != 1) return failure(); - OperationState state(op->getLoc(), "test.legal_op_with_region", operands, + OperationState state(op->getLoc(), "test.legal_op", operands, op->getResultTypes(), {}, BlockRange()); Region *newRegion = state.addRegion(); rewriter.inlineRegionBefore(op->getRegion(0), *newRegion, @@ -1234,6 +1234,41 @@ public: } }; +/// A pattern that tests two back-to-back 1 -> 2 op replacements. +class TestMultiple1ToNReplacement : public ConversionPattern { +public: + TestMultiple1ToNReplacement(MLIRContext *ctx, const TypeConverter &converter) + : ConversionPattern(converter, "test.multiple_1_to_n_replacement", 1, + ctx) {} + LogicalResult + matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands, + ConversionPatternRewriter &rewriter) const final { + // Helper function that replaces the given op with a new op of the given + // name and doubles each result (1 -> 2 replacement of each result). + auto replaceWithDoubleResults = [&](Operation *op, StringRef name) { + SmallVector<Type> types; + for (Type t : op->getResultTypes()) { + types.push_back(t); + types.push_back(t); + } + OperationState state(op->getLoc(), name, + /*operands=*/{}, types, op->getAttrs()); + auto *newOp = rewriter.create(state); + SmallVector<ValueRange> repls; + for (size_t i = 0, e = op->getNumResults(); i < e; ++i) + repls.push_back(newOp->getResults().slice(2 * i, 2)); + rewriter.replaceOpWithMultiple(op, repls); + return newOp; + }; + + // Replace test.multiple_1_to_n_replacement with test.step_1. + Operation *repl1 = replaceWithDoubleResults(op, "test.step_1"); + // Now replace test.step_1 with test.legal_op. + replaceWithDoubleResults(repl1, "test.legal_op"); + return success(); + } +}; + } // namespace namespace { @@ -1241,7 +1276,6 @@ struct TestTypeConverter : public TypeConverter { using TypeConverter::TypeConverter; TestTypeConverter() { addConversion(convertType); - addArgumentMaterialization(materializeCast); addSourceMaterialization(materializeCast); } @@ -1319,7 +1353,8 @@ struct TestLegalizePatternDriver TestUndoPropertiesModification, TestEraseOp, TestRepetitive1ToNConsumer>(&getContext()); patterns.add<TestDropOpSignatureConversion, TestDropAndReplaceInvalidOp, - TestPassthroughInvalidOp>(&getContext(), converter); + TestPassthroughInvalidOp, TestMultiple1ToNReplacement>( + &getContext(), converter); patterns.add<TestDuplicateBlockArgs>(converter, &getContext()); mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns, converter); @@ -1330,8 +1365,7 @@ struct TestLegalizePatternDriver target.addLegalOp<ModuleOp>(); target.addLegalOp<LegalOpA, LegalOpB, LegalOpC, TestCastOp, TestValidOp, TerminatorOp, OneRegionOp>(); - target.addLegalOp( - OperationName("test.legal_op_with_region", &getContext())); + target.addLegalOp(OperationName("test.legal_op", &getContext())); target .addIllegalOp<ILLegalOpF, TestRegionBuilderOp, TestOpWithRegionFold>(); target.addDynamicallyLegalOp<TestReturnOp>([](TestReturnOp op) { |