aboutsummaryrefslogtreecommitdiff
path: root/mlir/test/lib/Dialect/Test
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/test/lib/Dialect/Test')
-rw-r--r--mlir/test/lib/Dialect/Test/CMakeLists.txt1
-rw-r--r--mlir/test/lib/Dialect/Test/TestOpDefs.cpp26
-rw-r--r--mlir/test/lib/Dialect/Test/TestOps.td2
-rw-r--r--mlir/test/lib/Dialect/Test/TestPatterns.cpp7
4 files changed, 23 insertions, 13 deletions
diff --git a/mlir/test/lib/Dialect/Test/CMakeLists.txt b/mlir/test/lib/Dialect/Test/CMakeLists.txt
index f099d01..9354a85 100644
--- a/mlir/test/lib/Dialect/Test/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Test/CMakeLists.txt
@@ -71,6 +71,7 @@ add_mlir_library(MLIRTestDialect
)
mlir_target_link_libraries(MLIRTestDialect PUBLIC
MLIRControlFlowInterfaces
+ MLIRControlFlowTransforms
MLIRDataLayoutInterfaces
MLIRDerivedAttributeOpInterface
MLIRDestinationStyleOpInterface
diff --git a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
index b211e24..4d4ec02 100644
--- a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
+++ b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
@@ -633,8 +633,9 @@ ParseResult RegionIfOp::parse(OpAsmParser &parser, OperationState &result) {
parser.getCurrentLocation(), result.operands);
}
-OperandRange RegionIfOp::getEntrySuccessorOperands(RegionBranchPoint point) {
- assert(llvm::is_contained({&getThenRegion(), &getElseRegion()}, point) &&
+OperandRange RegionIfOp::getEntrySuccessorOperands(RegionSuccessor successor) {
+ assert(llvm::is_contained({&getThenRegion(), &getElseRegion()},
+ successor.getSuccessor()) &&
"invalid region index");
return getOperands();
}
@@ -643,10 +644,11 @@ void RegionIfOp::getSuccessorRegions(
RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
// We always branch to the join region.
if (!point.isParent()) {
- if (point != getJoinRegion())
+ if (point.getTerminatorPredecessorOrNull()->getParentRegion() !=
+ &getJoinRegion())
regions.push_back(RegionSuccessor(&getJoinRegion(), getJoinArgs()));
else
- regions.push_back(RegionSuccessor(getResults()));
+ regions.push_back(RegionSuccessor(getOperation(), getResults()));
return;
}
@@ -673,7 +675,7 @@ void AnyCondOp::getSuccessorRegions(RegionBranchPoint point,
if (point.isParent())
regions.emplace_back(&getRegion());
else
- regions.emplace_back(getResults());
+ regions.emplace_back(getOperation(), getResults());
}
void AnyCondOp::getRegionInvocationBounds(
@@ -1107,11 +1109,11 @@ void LoopBlockOp::getSuccessorRegions(
if (point.isParent())
return;
- regions.emplace_back((*this)->getResults());
+ regions.emplace_back(getOperation(), getOperation()->getResults());
}
-OperandRange LoopBlockOp::getEntrySuccessorOperands(RegionBranchPoint point) {
- assert(point == getBody());
+OperandRange LoopBlockOp::getEntrySuccessorOperands(RegionSuccessor successor) {
+ assert(successor.getSuccessor() == &getBody());
return MutableOperandRange(getInitMutable());
}
@@ -1120,8 +1122,8 @@ OperandRange LoopBlockOp::getEntrySuccessorOperands(RegionBranchPoint point) {
//===----------------------------------------------------------------------===//
MutableOperandRange
-LoopBlockTerminatorOp::getMutableSuccessorOperands(RegionBranchPoint point) {
- if (point.isParent())
+LoopBlockTerminatorOp::getMutableSuccessorOperands(RegionSuccessor successor) {
+ if (successor.isParent())
return getExitArgMutable();
return getNextIterArgMutable();
}
@@ -1213,7 +1215,7 @@ void TestStoreWithARegion::getSuccessorRegions(
if (point.isParent())
regions.emplace_back(&getBody(), getBody().front().getArguments());
else
- regions.emplace_back();
+ regions.emplace_back(getOperation(), getOperation()->getResults());
}
//===----------------------------------------------------------------------===//
@@ -1227,7 +1229,7 @@ void TestStoreWithALoopRegion::getSuccessorRegions(
// enter the body.
regions.emplace_back(
RegionSuccessor(&getBody(), getBody().front().getArguments()));
- regions.emplace_back();
+ regions.emplace_back(getOperation(), getOperation()->getResults());
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 05a33cf..a3430ba 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -2581,7 +2581,7 @@ def LoopBlockTerminatorOp : TEST_Op<"loop_block_term",
def TestNoTerminatorOp : TEST_Op<"switch_with_no_break", [
NoTerminator,
- DeclareOpInterfaceMethods<RegionBranchOpInterface, ["getSuccessorRegions"]>
+ DeclareOpInterfaceMethods<RegionBranchOpInterface>
]> {
let arguments = (ins Index:$arg, DenseI64ArrayAttr:$cases);
let regions = (region VariadicRegion<SizedRegion<1>>:$caseRegions);
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index efbdbfb..fd2b943 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -11,6 +11,7 @@
#include "TestTypes.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/CommonFolders.h"
+#include "mlir/Dialect/ControlFlow/Transforms/StructuralTypeConversions.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
@@ -2042,6 +2043,10 @@ struct TestTypeConversionDriver
});
converter.addConversion([](IndexType type) { return type; });
converter.addConversion([](IntegerType type, SmallVectorImpl<Type> &types) {
+ if (type.isInteger(1)) {
+ // i1 is legal.
+ types.push_back(type);
+ }
if (type.isInteger(38)) {
// i38 is legal.
types.push_back(type);
@@ -2175,6 +2180,8 @@ struct TestTypeConversionDriver
converter);
mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
converter, patterns, target);
+ mlir::cf::populateCFStructuralTypeConversionsAndLegality(converter,
+ patterns, target);
ConversionConfig config;
config.allowPatternRollback = allowPatternRollback;