aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Dialect/Async/IR/Async.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Dialect/Async/IR/Async.cpp')
-rw-r--r--mlir/lib/Dialect/Async/IR/Async.cpp11
1 files changed, 7 insertions, 4 deletions
diff --git a/mlir/lib/Dialect/Async/IR/Async.cpp b/mlir/lib/Dialect/Async/IR/Async.cpp
index dc7b07d..8e4a49d 100644
--- a/mlir/lib/Dialect/Async/IR/Async.cpp
+++ b/mlir/lib/Dialect/Async/IR/Async.cpp
@@ -36,8 +36,9 @@ void AsyncDialect::initialize() {
constexpr char kOperandSegmentSizesAttr[] = "operandSegmentSizes";
-OperandRange ExecuteOp::getEntrySuccessorOperands(RegionBranchPoint point) {
- assert(point == getBodyRegion() && "invalid region index");
+OperandRange ExecuteOp::getEntrySuccessorOperands(RegionSuccessor successor) {
+ assert(successor.getSuccessor() == &getBodyRegion() &&
+ "invalid region index");
return getBodyOperands();
}
@@ -53,8 +54,10 @@ bool ExecuteOp::areTypesCompatible(Type lhs, Type rhs) {
void ExecuteOp::getSuccessorRegions(RegionBranchPoint point,
SmallVectorImpl<RegionSuccessor> &regions) {
// The `body` region branch back to the parent operation.
- if (point == getBodyRegion()) {
- regions.push_back(RegionSuccessor(getBodyResults()));
+ if (!point.isParent() &&
+ point.getTerminatorPredecessorOrNull()->getParentRegion() ==
+ &getBodyRegion()) {
+ regions.push_back(RegionSuccessor(getOperation(), getBodyResults()));
return;
}