diff options
Diffstat (limited to 'mlir/test/lib/Dialect/Test')
| -rw-r--r-- | mlir/test/lib/Dialect/Test/CMakeLists.txt | 1 | ||||
| -rw-r--r-- | mlir/test/lib/Dialect/Test/TestOpDefs.cpp | 26 | ||||
| -rw-r--r-- | mlir/test/lib/Dialect/Test/TestOps.td | 2 | ||||
| -rw-r--r-- | mlir/test/lib/Dialect/Test/TestPatterns.cpp | 7 |
4 files changed, 23 insertions, 13 deletions
diff --git a/mlir/test/lib/Dialect/Test/CMakeLists.txt b/mlir/test/lib/Dialect/Test/CMakeLists.txt index f099d01..9354a85 100644 --- a/mlir/test/lib/Dialect/Test/CMakeLists.txt +++ b/mlir/test/lib/Dialect/Test/CMakeLists.txt @@ -71,6 +71,7 @@ add_mlir_library(MLIRTestDialect ) mlir_target_link_libraries(MLIRTestDialect PUBLIC MLIRControlFlowInterfaces + MLIRControlFlowTransforms MLIRDataLayoutInterfaces MLIRDerivedAttributeOpInterface MLIRDestinationStyleOpInterface diff --git a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp index b211e24..4d4ec02 100644 --- a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp +++ b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp @@ -633,8 +633,9 @@ ParseResult RegionIfOp::parse(OpAsmParser &parser, OperationState &result) { parser.getCurrentLocation(), result.operands); } -OperandRange RegionIfOp::getEntrySuccessorOperands(RegionBranchPoint point) { - assert(llvm::is_contained({&getThenRegion(), &getElseRegion()}, point) && +OperandRange RegionIfOp::getEntrySuccessorOperands(RegionSuccessor successor) { + assert(llvm::is_contained({&getThenRegion(), &getElseRegion()}, + successor.getSuccessor()) && "invalid region index"); return getOperands(); } @@ -643,10 +644,11 @@ void RegionIfOp::getSuccessorRegions( RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) { // We always branch to the join region. if (!point.isParent()) { - if (point != getJoinRegion()) + if (point.getTerminatorPredecessorOrNull()->getParentRegion() != + &getJoinRegion()) regions.push_back(RegionSuccessor(&getJoinRegion(), getJoinArgs())); else - regions.push_back(RegionSuccessor(getResults())); + regions.push_back(RegionSuccessor(getOperation(), getResults())); return; } @@ -673,7 +675,7 @@ void AnyCondOp::getSuccessorRegions(RegionBranchPoint point, if (point.isParent()) regions.emplace_back(&getRegion()); else - regions.emplace_back(getResults()); + regions.emplace_back(getOperation(), getResults()); } void AnyCondOp::getRegionInvocationBounds( @@ -1107,11 +1109,11 @@ void LoopBlockOp::getSuccessorRegions( if (point.isParent()) return; - regions.emplace_back((*this)->getResults()); + regions.emplace_back(getOperation(), getOperation()->getResults()); } -OperandRange LoopBlockOp::getEntrySuccessorOperands(RegionBranchPoint point) { - assert(point == getBody()); +OperandRange LoopBlockOp::getEntrySuccessorOperands(RegionSuccessor successor) { + assert(successor.getSuccessor() == &getBody()); return MutableOperandRange(getInitMutable()); } @@ -1120,8 +1122,8 @@ OperandRange LoopBlockOp::getEntrySuccessorOperands(RegionBranchPoint point) { //===----------------------------------------------------------------------===// MutableOperandRange -LoopBlockTerminatorOp::getMutableSuccessorOperands(RegionBranchPoint point) { - if (point.isParent()) +LoopBlockTerminatorOp::getMutableSuccessorOperands(RegionSuccessor successor) { + if (successor.isParent()) return getExitArgMutable(); return getNextIterArgMutable(); } @@ -1213,7 +1215,7 @@ void TestStoreWithARegion::getSuccessorRegions( if (point.isParent()) regions.emplace_back(&getBody(), getBody().front().getArguments()); else - regions.emplace_back(); + regions.emplace_back(getOperation(), getOperation()->getResults()); } //===----------------------------------------------------------------------===// @@ -1227,7 +1229,7 @@ void TestStoreWithALoopRegion::getSuccessorRegions( // enter the body. regions.emplace_back( RegionSuccessor(&getBody(), getBody().front().getArguments())); - regions.emplace_back(); + regions.emplace_back(getOperation(), getOperation()->getResults()); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td index 05a33cf..a3430ba 100644 --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -2581,7 +2581,7 @@ def LoopBlockTerminatorOp : TEST_Op<"loop_block_term", def TestNoTerminatorOp : TEST_Op<"switch_with_no_break", [ NoTerminator, - DeclareOpInterfaceMethods<RegionBranchOpInterface, ["getSuccessorRegions"]> + DeclareOpInterfaceMethods<RegionBranchOpInterface> ]> { let arguments = (ins Index:$arg, DenseI64ArrayAttr:$cases); let regions = (region VariadicRegion<SizedRegion<1>>:$caseRegions); diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp index efbdbfb..fd2b943 100644 --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -11,6 +11,7 @@ #include "TestTypes.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/CommonFolders.h" +#include "mlir/Dialect/ControlFlow/Transforms/StructuralTypeConversions.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Func/Transforms/FuncConversions.h" #include "mlir/Dialect/SCF/Transforms/Patterns.h" @@ -2042,6 +2043,10 @@ struct TestTypeConversionDriver }); converter.addConversion([](IndexType type) { return type; }); converter.addConversion([](IntegerType type, SmallVectorImpl<Type> &types) { + if (type.isInteger(1)) { + // i1 is legal. + types.push_back(type); + } if (type.isInteger(38)) { // i38 is legal. types.push_back(type); @@ -2175,6 +2180,8 @@ struct TestTypeConversionDriver converter); mlir::scf::populateSCFStructuralTypeConversionsAndLegality( converter, patterns, target); + mlir::cf::populateCFStructuralTypeConversionsAndLegality(converter, + patterns, target); ConversionConfig config; config.allowPatternRollback = allowPatternRollback; |
