diff options
Diffstat (limited to 'mlir/lib/Dialect/Affine/IR/AffineOps.cpp')
| -rw-r--r-- | mlir/lib/Dialect/Affine/IR/AffineOps.cpp | 50 |
1 files changed, 29 insertions, 21 deletions
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp index e0a53cd..0c35921 100644 --- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp +++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp @@ -2716,8 +2716,9 @@ LogicalResult AffineForOp::fold(FoldAdaptor adaptor, return success(folded); } -OperandRange AffineForOp::getEntrySuccessorOperands(RegionBranchPoint point) { - assert((point.isParent() || point == getRegion()) && "invalid region point"); +OperandRange AffineForOp::getEntrySuccessorOperands(RegionSuccessor successor) { + assert((successor.isParent() || successor.getSuccessor() == &getRegion()) && + "invalid region point"); // The initial operands map to the loop arguments after the induction // variable or are forwarded to the results when the trip count is zero. @@ -2726,34 +2727,41 @@ OperandRange AffineForOp::getEntrySuccessorOperands(RegionBranchPoint point) { void AffineForOp::getSuccessorRegions( RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) { - assert((point.isParent() || point == getRegion()) && "expected loop region"); + assert((point.isParent() || + point.getTerminatorPredecessorOrNull()->getParentRegion() == + &getRegion()) && + "expected loop region"); // The loop may typically branch back to its body or to the parent operation. // If the predecessor is the parent op and the trip count is known to be at // least one, branch into the body using the iterator arguments. And in cases // we know the trip count is zero, it can only branch back to its parent. std::optional<uint64_t> tripCount = getTrivialConstantTripCount(*this); - if (point.isParent() && tripCount.has_value()) { - if (tripCount.value() > 0) { - regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs())); - return; - } - if (tripCount.value() == 0) { - regions.push_back(RegionSuccessor(getResults())); - return; + if (tripCount.has_value()) { + if (!point.isParent()) { + // From the loop body, if the trip count is one, we can only branch back + // to the parent. + if (tripCount == 1) { + regions.push_back(RegionSuccessor(getOperation(), getResults())); + return; + } + if (tripCount == 0) + return; + } else { + if (tripCount.value() > 0) { + regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs())); + return; + } + if (tripCount.value() == 0) { + regions.push_back(RegionSuccessor(getOperation(), getResults())); + return; + } } } - // From the loop body, if the trip count is one, we can only branch back to - // the parent. - if (!point.isParent() && tripCount == 1) { - regions.push_back(RegionSuccessor(getResults())); - return; - } - // In all other cases, the loop may branch back to itself or the parent // operation. regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs())); - regions.push_back(RegionSuccessor(getResults())); + regions.push_back(RegionSuccessor(getOperation(), getResults())); } AffineBound AffineForOp::getLowerBound() { @@ -3142,7 +3150,7 @@ void AffineIfOp::getSuccessorRegions( RegionSuccessor(&getThenRegion(), getThenRegion().getArguments())); // If the "else" region is empty, branch bach into parent. if (getElseRegion().empty()) { - regions.push_back(getResults()); + regions.push_back(RegionSuccessor(getOperation(), getResults())); } else { regions.push_back( RegionSuccessor(&getElseRegion(), getElseRegion().getArguments())); @@ -3152,7 +3160,7 @@ void AffineIfOp::getSuccessorRegions( // If the predecessor is the `else`/`then` region, then branching into parent // op is valid. - regions.push_back(RegionSuccessor(getResults())); + regions.push_back(RegionSuccessor(getOperation(), getResults())); } LogicalResult AffineIfOp::verify() { |
