diff options
Diffstat (limited to 'mlir/lib/Dialect/Affine/IR')
| -rw-r--r-- | mlir/lib/Dialect/Affine/IR/AffineOps.cpp | 50 | 
1 files changed, 29 insertions, 21 deletions
| diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp index e0a53cd..0c35921 100644 --- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp +++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp @@ -2716,8 +2716,9 @@ LogicalResult AffineForOp::fold(FoldAdaptor adaptor,    return success(folded);  } -OperandRange AffineForOp::getEntrySuccessorOperands(RegionBranchPoint point) { -  assert((point.isParent() || point == getRegion()) && "invalid region point"); +OperandRange AffineForOp::getEntrySuccessorOperands(RegionSuccessor successor) { +  assert((successor.isParent() || successor.getSuccessor() == &getRegion()) && +         "invalid region point");    // The initial operands map to the loop arguments after the induction    // variable or are forwarded to the results when the trip count is zero. @@ -2726,34 +2727,41 @@ OperandRange AffineForOp::getEntrySuccessorOperands(RegionBranchPoint point) {  void AffineForOp::getSuccessorRegions(      RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) { -  assert((point.isParent() || point == getRegion()) && "expected loop region"); +  assert((point.isParent() || +          point.getTerminatorPredecessorOrNull()->getParentRegion() == +              &getRegion()) && +         "expected loop region");    // The loop may typically branch back to its body or to the parent operation.    // If the predecessor is the parent op and the trip count is known to be at    // least one, branch into the body using the iterator arguments. And in cases    // we know the trip count is zero, it can only branch back to its parent.    std::optional<uint64_t> tripCount = getTrivialConstantTripCount(*this); -  if (point.isParent() && tripCount.has_value()) { -    if (tripCount.value() > 0) { -      regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs())); -      return; -    } -    if (tripCount.value() == 0) { -      regions.push_back(RegionSuccessor(getResults())); -      return; +  if (tripCount.has_value()) { +    if (!point.isParent()) { +      // From the loop body, if the trip count is one, we can only branch back +      // to the parent. +      if (tripCount == 1) { +        regions.push_back(RegionSuccessor(getOperation(), getResults())); +        return; +      } +      if (tripCount == 0) +        return; +    } else { +      if (tripCount.value() > 0) { +        regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs())); +        return; +      } +      if (tripCount.value() == 0) { +        regions.push_back(RegionSuccessor(getOperation(), getResults())); +        return; +      }      }    } -  // From the loop body, if the trip count is one, we can only branch back to -  // the parent. -  if (!point.isParent() && tripCount == 1) { -    regions.push_back(RegionSuccessor(getResults())); -    return; -  } -    // In all other cases, the loop may branch back to itself or the parent    // operation.    regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs())); -  regions.push_back(RegionSuccessor(getResults())); +  regions.push_back(RegionSuccessor(getOperation(), getResults()));  }  AffineBound AffineForOp::getLowerBound() { @@ -3142,7 +3150,7 @@ void AffineIfOp::getSuccessorRegions(          RegionSuccessor(&getThenRegion(), getThenRegion().getArguments()));      // If the "else" region is empty, branch bach into parent.      if (getElseRegion().empty()) { -      regions.push_back(getResults()); +      regions.push_back(RegionSuccessor(getOperation(), getResults()));      } else {        regions.push_back(            RegionSuccessor(&getElseRegion(), getElseRegion().getArguments())); @@ -3152,7 +3160,7 @@ void AffineIfOp::getSuccessorRegions(    // If the predecessor is the `else`/`then` region, then branching into parent    // op is valid. -  regions.push_back(RegionSuccessor(getResults())); +  regions.push_back(RegionSuccessor(getOperation(), getResults()));  }  LogicalResult AffineIfOp::verify() { | 
