aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Transforms
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Transforms')
-rw-r--r--mlir/lib/Transforms/RemoveDeadValues.cpp25
1 files changed, 18 insertions, 7 deletions
diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp
index e0c65b0..41f3f9d 100644
--- a/mlir/lib/Transforms/RemoveDeadValues.cpp
+++ b/mlir/lib/Transforms/RemoveDeadValues.cpp
@@ -432,8 +432,7 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
// Return the successors of `region` if the latter is not null. Else return
// the successors of `regionBranchOp`.
- auto getSuccessors = [&](Region *region = nullptr) {
- auto point = region ? region : RegionBranchPoint::parent();
+ auto getSuccessors = [&](RegionBranchPoint point) {
SmallVector<RegionSuccessor> successors;
regionBranchOp.getSuccessorRegions(point, successors);
return successors;
@@ -456,7 +455,8 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
// `nonForwardedOperands`.
auto markNonForwardedOperands = [&](BitVector &nonForwardedOperands) {
nonForwardedOperands.resize(regionBranchOp->getNumOperands(), true);
- for (const RegionSuccessor &successor : getSuccessors()) {
+ for (const RegionSuccessor &successor :
+ getSuccessors(RegionBranchPoint::parent())) {
for (OpOperand *opOperand : getForwardedOpOperands(successor))
nonForwardedOperands.reset(opOperand->getOperandNumber());
}
@@ -469,10 +469,13 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
for (Region &region : regionBranchOp->getRegions()) {
if (region.empty())
continue;
+ // TODO: this isn't correct in face of multiple terminators.
Operation *terminator = region.front().getTerminator();
nonForwardedRets[terminator] =
BitVector(terminator->getNumOperands(), true);
- for (const RegionSuccessor &successor : getSuccessors(&region)) {
+ for (const RegionSuccessor &successor :
+ getSuccessors(RegionBranchPoint(
+ cast<RegionBranchTerminatorOpInterface>(terminator)))) {
for (OpOperand *opOperand :
getForwardedOpOperands(successor, terminator))
nonForwardedRets[terminator].reset(opOperand->getOperandNumber());
@@ -489,8 +492,13 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
DenseMap<Region *, BitVector> &argsToKeep, Region *region = nullptr) {
Operation *terminator =
region ? region->front().getTerminator() : nullptr;
+ RegionBranchPoint point =
+ terminator
+ ? RegionBranchPoint(
+ cast<RegionBranchTerminatorOpInterface>(terminator))
+ : RegionBranchPoint::parent();
- for (const RegionSuccessor &successor : getSuccessors(region)) {
+ for (const RegionSuccessor &successor : getSuccessors(point)) {
Region *successorRegion = successor.getSuccessor();
for (auto [opOperand, input] :
llvm::zip(getForwardedOpOperands(successor, terminator),
@@ -517,7 +525,8 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
resultsOrArgsToKeepChanged = false;
// Recompute `resultsToKeep` and `argsToKeep` based on `operandsToKeep`.
- for (const RegionSuccessor &successor : getSuccessors()) {
+ for (const RegionSuccessor &successor :
+ getSuccessors(RegionBranchPoint::parent())) {
Region *successorRegion = successor.getSuccessor();
for (auto [opOperand, input] :
llvm::zip(getForwardedOpOperands(successor),
@@ -551,7 +560,9 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
if (region.empty())
continue;
Operation *terminator = region.front().getTerminator();
- for (const RegionSuccessor &successor : getSuccessors(&region)) {
+ for (const RegionSuccessor &successor :
+ getSuccessors(RegionBranchPoint(
+ cast<RegionBranchTerminatorOpInterface>(terminator)))) {
Region *successorRegion = successor.getSuccessor();
for (auto [opOperand, input] :
llvm::zip(getForwardedOpOperands(successor, terminator),