diff options
Diffstat (limited to 'mlir/test/lib/Dialect/Test')
| -rw-r--r-- | mlir/test/lib/Dialect/Test/CMakeLists.txt | 1 | ||||
| -rw-r--r-- | mlir/test/lib/Dialect/Test/TestOpDefs.cpp | 26 | ||||
| -rw-r--r-- | mlir/test/lib/Dialect/Test/TestOps.td | 2 | ||||
| -rw-r--r-- | mlir/test/lib/Dialect/Test/TestPatterns.cpp | 7 | 
4 files changed, 23 insertions, 13 deletions
| diff --git a/mlir/test/lib/Dialect/Test/CMakeLists.txt b/mlir/test/lib/Dialect/Test/CMakeLists.txt index f099d01..9354a85 100644 --- a/mlir/test/lib/Dialect/Test/CMakeLists.txt +++ b/mlir/test/lib/Dialect/Test/CMakeLists.txt @@ -71,6 +71,7 @@ add_mlir_library(MLIRTestDialect    )  mlir_target_link_libraries(MLIRTestDialect PUBLIC    MLIRControlFlowInterfaces +  MLIRControlFlowTransforms    MLIRDataLayoutInterfaces    MLIRDerivedAttributeOpInterface    MLIRDestinationStyleOpInterface diff --git a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp index b211e24..4d4ec02 100644 --- a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp +++ b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp @@ -633,8 +633,9 @@ ParseResult RegionIfOp::parse(OpAsmParser &parser, OperationState &result) {                                  parser.getCurrentLocation(), result.operands);  } -OperandRange RegionIfOp::getEntrySuccessorOperands(RegionBranchPoint point) { -  assert(llvm::is_contained({&getThenRegion(), &getElseRegion()}, point) && +OperandRange RegionIfOp::getEntrySuccessorOperands(RegionSuccessor successor) { +  assert(llvm::is_contained({&getThenRegion(), &getElseRegion()}, +                            successor.getSuccessor()) &&           "invalid region index");    return getOperands();  } @@ -643,10 +644,11 @@ void RegionIfOp::getSuccessorRegions(      RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) {    // We always branch to the join region.    if (!point.isParent()) { -    if (point != getJoinRegion()) +    if (point.getTerminatorPredecessorOrNull()->getParentRegion() != +        &getJoinRegion())        regions.push_back(RegionSuccessor(&getJoinRegion(), getJoinArgs()));      else -      regions.push_back(RegionSuccessor(getResults())); +      regions.push_back(RegionSuccessor(getOperation(), getResults()));      return;    } @@ -673,7 +675,7 @@ void AnyCondOp::getSuccessorRegions(RegionBranchPoint point,    if (point.isParent())      regions.emplace_back(&getRegion());    else -    regions.emplace_back(getResults()); +    regions.emplace_back(getOperation(), getResults());  }  void AnyCondOp::getRegionInvocationBounds( @@ -1107,11 +1109,11 @@ void LoopBlockOp::getSuccessorRegions(    if (point.isParent())      return; -  regions.emplace_back((*this)->getResults()); +  regions.emplace_back(getOperation(), getOperation()->getResults());  } -OperandRange LoopBlockOp::getEntrySuccessorOperands(RegionBranchPoint point) { -  assert(point == getBody()); +OperandRange LoopBlockOp::getEntrySuccessorOperands(RegionSuccessor successor) { +  assert(successor.getSuccessor() == &getBody());    return MutableOperandRange(getInitMutable());  } @@ -1120,8 +1122,8 @@ OperandRange LoopBlockOp::getEntrySuccessorOperands(RegionBranchPoint point) {  //===----------------------------------------------------------------------===//  MutableOperandRange -LoopBlockTerminatorOp::getMutableSuccessorOperands(RegionBranchPoint point) { -  if (point.isParent()) +LoopBlockTerminatorOp::getMutableSuccessorOperands(RegionSuccessor successor) { +  if (successor.isParent())      return getExitArgMutable();    return getNextIterArgMutable();  } @@ -1213,7 +1215,7 @@ void TestStoreWithARegion::getSuccessorRegions(    if (point.isParent())      regions.emplace_back(&getBody(), getBody().front().getArguments());    else -    regions.emplace_back(); +    regions.emplace_back(getOperation(), getOperation()->getResults());  }  //===----------------------------------------------------------------------===// @@ -1227,7 +1229,7 @@ void TestStoreWithALoopRegion::getSuccessorRegions(    // enter the body.    regions.emplace_back(        RegionSuccessor(&getBody(), getBody().front().getArguments())); -  regions.emplace_back(); +  regions.emplace_back(getOperation(), getOperation()->getResults());  }  //===----------------------------------------------------------------------===// diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td index 05a33cf..a3430ba 100644 --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -2581,7 +2581,7 @@ def LoopBlockTerminatorOp : TEST_Op<"loop_block_term",  def TestNoTerminatorOp : TEST_Op<"switch_with_no_break", [      NoTerminator, -    DeclareOpInterfaceMethods<RegionBranchOpInterface, ["getSuccessorRegions"]> +    DeclareOpInterfaceMethods<RegionBranchOpInterface>    ]> {    let arguments = (ins Index:$arg, DenseI64ArrayAttr:$cases);    let regions = (region VariadicRegion<SizedRegion<1>>:$caseRegions); diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp index efbdbfb..fd2b943 100644 --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -11,6 +11,7 @@  #include "TestTypes.h"  #include "mlir/Dialect/Arith/IR/Arith.h"  #include "mlir/Dialect/CommonFolders.h" +#include "mlir/Dialect/ControlFlow/Transforms/StructuralTypeConversions.h"  #include "mlir/Dialect/Func/IR/FuncOps.h"  #include "mlir/Dialect/Func/Transforms/FuncConversions.h"  #include "mlir/Dialect/SCF/Transforms/Patterns.h" @@ -2042,6 +2043,10 @@ struct TestTypeConversionDriver      });      converter.addConversion([](IndexType type) { return type; });      converter.addConversion([](IntegerType type, SmallVectorImpl<Type> &types) { +      if (type.isInteger(1)) { +        // i1 is legal. +        types.push_back(type); +      }        if (type.isInteger(38)) {          // i38 is legal.          types.push_back(type); @@ -2175,6 +2180,8 @@ struct TestTypeConversionDriver                                                                converter);      mlir::scf::populateSCFStructuralTypeConversionsAndLegality(          converter, patterns, target); +    mlir::cf::populateCFStructuralTypeConversionsAndLegality(converter, +                                                             patterns, target);      ConversionConfig config;      config.allowPatternRollback = allowPatternRollback; | 
