diff options
Diffstat (limited to 'mlir/lib/Dialect/SCF/IR/SCF.cpp')
| -rw-r--r-- | mlir/lib/Dialect/SCF/IR/SCF.cpp | 52 | 
1 files changed, 30 insertions, 22 deletions
| diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp index 1ab01d8..2946b53 100644 --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -397,7 +397,7 @@ void ExecuteRegionOp::getSuccessorRegions(    }    // Otherwise, the region branches back to the parent operation. -  regions.push_back(RegionSuccessor(getResults())); +  regions.push_back(RegionSuccessor(getOperation(), getResults()));  }  //===----------------------------------------------------------------------===// @@ -405,10 +405,11 @@ void ExecuteRegionOp::getSuccessorRegions(  //===----------------------------------------------------------------------===//  MutableOperandRange -ConditionOp::getMutableSuccessorOperands(RegionBranchPoint point) { -  assert((point.isParent() || point == getParentOp().getAfter()) && -         "condition op can only exit the loop or branch to the after" -         "region"); +ConditionOp::getMutableSuccessorOperands(RegionSuccessor point) { +  assert( +      (point.isParent() || point.getSuccessor() == &getParentOp().getAfter()) && +      "condition op can only exit the loop or branch to the after" +      "region");    // Pass all operands except the condition to the successor region.    return getArgsMutable();  } @@ -426,7 +427,7 @@ void ConditionOp::getSuccessorRegions(      regions.emplace_back(&whileOp.getAfter(),                           whileOp.getAfter().getArguments());    if (!boolAttr || !boolAttr.getValue()) -    regions.emplace_back(whileOp.getResults()); +    regions.emplace_back(whileOp.getOperation(), whileOp.getResults());  }  //===----------------------------------------------------------------------===// @@ -749,7 +750,7 @@ ForOp mlir::scf::getForInductionVarOwner(Value val) {    return dyn_cast_or_null<ForOp>(containingOp);  } -OperandRange ForOp::getEntrySuccessorOperands(RegionBranchPoint point) { +OperandRange ForOp::getEntrySuccessorOperands(RegionSuccessor successor) {    return getInitArgs();  } @@ -759,7 +760,7 @@ void ForOp::getSuccessorRegions(RegionBranchPoint point,    // back into the operation itself. It is possible for loop not to enter the    // body.    regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs())); -  regions.push_back(RegionSuccessor(getResults())); +  regions.push_back(RegionSuccessor(getOperation(), getResults()));  }  SmallVector<Region *> ForallOp::getLoopRegions() { return {&getRegion()}; } @@ -2053,9 +2054,10 @@ void ForallOp::getSuccessorRegions(RegionBranchPoint point,    // parallel by multiple threads. We should not expect to branch back into    // the forall body after the region's execution is complete.    if (point.isParent()) -    regions.push_back(RegionSuccessor(&getRegion())); +    regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));    else -    regions.push_back(RegionSuccessor()); +    regions.push_back( +        RegionSuccessor(getOperation(), getOperation()->getResults()));  }  //===----------------------------------------------------------------------===// @@ -2333,9 +2335,10 @@ void IfOp::print(OpAsmPrinter &p) {  void IfOp::getSuccessorRegions(RegionBranchPoint point,                                 SmallVectorImpl<RegionSuccessor> ®ions) { -  // The `then` and the `else` region branch back to the parent operation. +  // The `then` and the `else` region branch back to the parent operation or one +  // of the recursive parent operations (early exit case).    if (!point.isParent()) { -    regions.push_back(RegionSuccessor(getResults())); +    regions.push_back(RegionSuccessor(getOperation(), getResults()));      return;    } @@ -2344,7 +2347,8 @@ void IfOp::getSuccessorRegions(RegionBranchPoint point,    // Don't consider the else region if it is empty.    Region *elseRegion = &this->getElseRegion();    if (elseRegion->empty()) -    regions.push_back(RegionSuccessor()); +    regions.push_back( +        RegionSuccessor(getOperation(), getOperation()->getResults()));    else      regions.push_back(RegionSuccessor(elseRegion));  } @@ -2361,7 +2365,7 @@ void IfOp::getEntrySuccessorRegions(ArrayRef<Attribute> operands,      if (!getElseRegion().empty())        regions.emplace_back(&getElseRegion());      else -      regions.emplace_back(getResults()); +      regions.emplace_back(getOperation(), getResults());    }  } @@ -3385,7 +3389,8 @@ void ParallelOp::getSuccessorRegions(    // back into the operation itself. It is possible for loop not to enter the    // body.    regions.push_back(RegionSuccessor(&getRegion())); -  regions.push_back(RegionSuccessor()); +  regions.push_back(RegionSuccessor( +      getOperation(), ResultRange{getResults().end(), getResults().end()}));  }  //===----------------------------------------------------------------------===// @@ -3431,7 +3436,7 @@ LogicalResult ReduceOp::verifyRegions() {  }  MutableOperandRange -ReduceOp::getMutableSuccessorOperands(RegionBranchPoint point) { +ReduceOp::getMutableSuccessorOperands(RegionSuccessor point) {    // No operands are forwarded to the next iteration.    return MutableOperandRange(getOperation(), /*start=*/0, /*length=*/0);  } @@ -3514,8 +3519,8 @@ Block::BlockArgListType WhileOp::getRegionIterArgs() {    return getBeforeArguments();  } -OperandRange WhileOp::getEntrySuccessorOperands(RegionBranchPoint point) { -  assert(point == getBefore() && +OperandRange WhileOp::getEntrySuccessorOperands(RegionSuccessor successor) { +  assert(successor.getSuccessor() == &getBefore() &&           "WhileOp is expected to branch only to the first region");    return getInits();  } @@ -3528,15 +3533,18 @@ void WhileOp::getSuccessorRegions(RegionBranchPoint point,      return;    } -  assert(llvm::is_contained({&getAfter(), &getBefore()}, point) && +  assert(llvm::is_contained( +             {&getAfter(), &getBefore()}, +             point.getTerminatorPredecessorOrNull()->getParentRegion()) &&           "there are only two regions in a WhileOp");    // The body region always branches back to the condition region. -  if (point == getAfter()) { +  if (point.getTerminatorPredecessorOrNull()->getParentRegion() == +      &getAfter()) {      regions.emplace_back(&getBefore(), getBefore().getArguments());      return;    } -  regions.emplace_back(getResults()); +  regions.emplace_back(getOperation(), getResults());    regions.emplace_back(&getAfter(), getAfter().getArguments());  } @@ -4445,7 +4453,7 @@ void IndexSwitchOp::getSuccessorRegions(      RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &successors) {    // All regions branch back to the parent op.    if (!point.isParent()) { -    successors.emplace_back(getResults()); +    successors.emplace_back(getOperation(), getResults());      return;    } | 
