aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Dialect/SCF/IR/SCF.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Dialect/SCF/IR/SCF.cpp')
-rw-r--r--mlir/lib/Dialect/SCF/IR/SCF.cpp52
1 files changed, 30 insertions, 22 deletions
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 1ab01d8..2946b53 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -397,7 +397,7 @@ void ExecuteRegionOp::getSuccessorRegions(
}
// Otherwise, the region branches back to the parent operation.
- regions.push_back(RegionSuccessor(getResults()));
+ regions.push_back(RegionSuccessor(getOperation(), getResults()));
}
//===----------------------------------------------------------------------===//
@@ -405,10 +405,11 @@ void ExecuteRegionOp::getSuccessorRegions(
//===----------------------------------------------------------------------===//
MutableOperandRange
-ConditionOp::getMutableSuccessorOperands(RegionBranchPoint point) {
- assert((point.isParent() || point == getParentOp().getAfter()) &&
- "condition op can only exit the loop or branch to the after"
- "region");
+ConditionOp::getMutableSuccessorOperands(RegionSuccessor point) {
+ assert(
+ (point.isParent() || point.getSuccessor() == &getParentOp().getAfter()) &&
+ "condition op can only exit the loop or branch to the after"
+ "region");
// Pass all operands except the condition to the successor region.
return getArgsMutable();
}
@@ -426,7 +427,7 @@ void ConditionOp::getSuccessorRegions(
regions.emplace_back(&whileOp.getAfter(),
whileOp.getAfter().getArguments());
if (!boolAttr || !boolAttr.getValue())
- regions.emplace_back(whileOp.getResults());
+ regions.emplace_back(whileOp.getOperation(), whileOp.getResults());
}
//===----------------------------------------------------------------------===//
@@ -749,7 +750,7 @@ ForOp mlir::scf::getForInductionVarOwner(Value val) {
return dyn_cast_or_null<ForOp>(containingOp);
}
-OperandRange ForOp::getEntrySuccessorOperands(RegionBranchPoint point) {
+OperandRange ForOp::getEntrySuccessorOperands(RegionSuccessor successor) {
return getInitArgs();
}
@@ -759,7 +760,7 @@ void ForOp::getSuccessorRegions(RegionBranchPoint point,
// back into the operation itself. It is possible for loop not to enter the
// body.
regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
- regions.push_back(RegionSuccessor(getResults()));
+ regions.push_back(RegionSuccessor(getOperation(), getResults()));
}
SmallVector<Region *> ForallOp::getLoopRegions() { return {&getRegion()}; }
@@ -2053,9 +2054,10 @@ void ForallOp::getSuccessorRegions(RegionBranchPoint point,
// parallel by multiple threads. We should not expect to branch back into
// the forall body after the region's execution is complete.
if (point.isParent())
- regions.push_back(RegionSuccessor(&getRegion()));
+ regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
else
- regions.push_back(RegionSuccessor());
+ regions.push_back(
+ RegionSuccessor(getOperation(), getOperation()->getResults()));
}
//===----------------------------------------------------------------------===//
@@ -2333,9 +2335,10 @@ void IfOp::print(OpAsmPrinter &p) {
void IfOp::getSuccessorRegions(RegionBranchPoint point,
SmallVectorImpl<RegionSuccessor> &regions) {
- // The `then` and the `else` region branch back to the parent operation.
+ // The `then` and the `else` region branch back to the parent operation or one
+ // of the recursive parent operations (early exit case).
if (!point.isParent()) {
- regions.push_back(RegionSuccessor(getResults()));
+ regions.push_back(RegionSuccessor(getOperation(), getResults()));
return;
}
@@ -2344,7 +2347,8 @@ void IfOp::getSuccessorRegions(RegionBranchPoint point,
// Don't consider the else region if it is empty.
Region *elseRegion = &this->getElseRegion();
if (elseRegion->empty())
- regions.push_back(RegionSuccessor());
+ regions.push_back(
+ RegionSuccessor(getOperation(), getOperation()->getResults()));
else
regions.push_back(RegionSuccessor(elseRegion));
}
@@ -2361,7 +2365,7 @@ void IfOp::getEntrySuccessorRegions(ArrayRef<Attribute> operands,
if (!getElseRegion().empty())
regions.emplace_back(&getElseRegion());
else
- regions.emplace_back(getResults());
+ regions.emplace_back(getOperation(), getResults());
}
}
@@ -3385,7 +3389,8 @@ void ParallelOp::getSuccessorRegions(
// back into the operation itself. It is possible for loop not to enter the
// body.
regions.push_back(RegionSuccessor(&getRegion()));
- regions.push_back(RegionSuccessor());
+ regions.push_back(RegionSuccessor(
+ getOperation(), ResultRange{getResults().end(), getResults().end()}));
}
//===----------------------------------------------------------------------===//
@@ -3431,7 +3436,7 @@ LogicalResult ReduceOp::verifyRegions() {
}
MutableOperandRange
-ReduceOp::getMutableSuccessorOperands(RegionBranchPoint point) {
+ReduceOp::getMutableSuccessorOperands(RegionSuccessor point) {
// No operands are forwarded to the next iteration.
return MutableOperandRange(getOperation(), /*start=*/0, /*length=*/0);
}
@@ -3514,8 +3519,8 @@ Block::BlockArgListType WhileOp::getRegionIterArgs() {
return getBeforeArguments();
}
-OperandRange WhileOp::getEntrySuccessorOperands(RegionBranchPoint point) {
- assert(point == getBefore() &&
+OperandRange WhileOp::getEntrySuccessorOperands(RegionSuccessor successor) {
+ assert(successor.getSuccessor() == &getBefore() &&
"WhileOp is expected to branch only to the first region");
return getInits();
}
@@ -3528,15 +3533,18 @@ void WhileOp::getSuccessorRegions(RegionBranchPoint point,
return;
}
- assert(llvm::is_contained({&getAfter(), &getBefore()}, point) &&
+ assert(llvm::is_contained(
+ {&getAfter(), &getBefore()},
+ point.getTerminatorPredecessorOrNull()->getParentRegion()) &&
"there are only two regions in a WhileOp");
// The body region always branches back to the condition region.
- if (point == getAfter()) {
+ if (point.getTerminatorPredecessorOrNull()->getParentRegion() ==
+ &getAfter()) {
regions.emplace_back(&getBefore(), getBefore().getArguments());
return;
}
- regions.emplace_back(getResults());
+ regions.emplace_back(getOperation(), getResults());
regions.emplace_back(&getAfter(), getAfter().getArguments());
}
@@ -4445,7 +4453,7 @@ void IndexSwitchOp::getSuccessorRegions(
RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &successors) {
// All regions branch back to the parent op.
if (!point.isParent()) {
- successors.emplace_back(getResults());
+ successors.emplace_back(getOperation(), getResults());
return;
}