diff options
Diffstat (limited to 'mlir/lib/Transforms')
| -rw-r--r-- | mlir/lib/Transforms/RemoveDeadValues.cpp | 25 | ||||
| -rw-r--r-- | mlir/lib/Transforms/ViewOpGraph.cpp | 5 |
2 files changed, 21 insertions, 9 deletions
diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp index e0c65b0..41f3f9d 100644 --- a/mlir/lib/Transforms/RemoveDeadValues.cpp +++ b/mlir/lib/Transforms/RemoveDeadValues.cpp @@ -432,8 +432,7 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp, // Return the successors of `region` if the latter is not null. Else return // the successors of `regionBranchOp`. - auto getSuccessors = [&](Region *region = nullptr) { - auto point = region ? region : RegionBranchPoint::parent(); + auto getSuccessors = [&](RegionBranchPoint point) { SmallVector<RegionSuccessor> successors; regionBranchOp.getSuccessorRegions(point, successors); return successors; @@ -456,7 +455,8 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp, // `nonForwardedOperands`. auto markNonForwardedOperands = [&](BitVector &nonForwardedOperands) { nonForwardedOperands.resize(regionBranchOp->getNumOperands(), true); - for (const RegionSuccessor &successor : getSuccessors()) { + for (const RegionSuccessor &successor : + getSuccessors(RegionBranchPoint::parent())) { for (OpOperand *opOperand : getForwardedOpOperands(successor)) nonForwardedOperands.reset(opOperand->getOperandNumber()); } @@ -469,10 +469,13 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp, for (Region ®ion : regionBranchOp->getRegions()) { if (region.empty()) continue; + // TODO: this isn't correct in face of multiple terminators. Operation *terminator = region.front().getTerminator(); nonForwardedRets[terminator] = BitVector(terminator->getNumOperands(), true); - for (const RegionSuccessor &successor : getSuccessors(®ion)) { + for (const RegionSuccessor &successor : + getSuccessors(RegionBranchPoint( + cast<RegionBranchTerminatorOpInterface>(terminator)))) { for (OpOperand *opOperand : getForwardedOpOperands(successor, terminator)) nonForwardedRets[terminator].reset(opOperand->getOperandNumber()); @@ -489,8 +492,13 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp, DenseMap<Region *, BitVector> &argsToKeep, Region *region = nullptr) { Operation *terminator = region ? region->front().getTerminator() : nullptr; + RegionBranchPoint point = + terminator + ? RegionBranchPoint( + cast<RegionBranchTerminatorOpInterface>(terminator)) + : RegionBranchPoint::parent(); - for (const RegionSuccessor &successor : getSuccessors(region)) { + for (const RegionSuccessor &successor : getSuccessors(point)) { Region *successorRegion = successor.getSuccessor(); for (auto [opOperand, input] : llvm::zip(getForwardedOpOperands(successor, terminator), @@ -517,7 +525,8 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp, resultsOrArgsToKeepChanged = false; // Recompute `resultsToKeep` and `argsToKeep` based on `operandsToKeep`. - for (const RegionSuccessor &successor : getSuccessors()) { + for (const RegionSuccessor &successor : + getSuccessors(RegionBranchPoint::parent())) { Region *successorRegion = successor.getSuccessor(); for (auto [opOperand, input] : llvm::zip(getForwardedOpOperands(successor), @@ -551,7 +560,9 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp, if (region.empty()) continue; Operation *terminator = region.front().getTerminator(); - for (const RegionSuccessor &successor : getSuccessors(®ion)) { + for (const RegionSuccessor &successor : + getSuccessors(RegionBranchPoint( + cast<RegionBranchTerminatorOpInterface>(terminator)))) { Region *successorRegion = successor.getSuccessor(); for (auto [opOperand, input] : llvm::zip(getForwardedOpOperands(successor, terminator), diff --git a/mlir/lib/Transforms/ViewOpGraph.cpp b/mlir/lib/Transforms/ViewOpGraph.cpp index 08cac1f..5790a77 100644 --- a/mlir/lib/Transforms/ViewOpGraph.cpp +++ b/mlir/lib/Transforms/ViewOpGraph.cpp @@ -158,7 +158,8 @@ private: /// Emit a cluster (subgraph). The specified builder generates the body of the /// cluster. Return the anchor node of the cluster. - Node emitClusterStmt(function_ref<void()> builder, std::string label = "") { + Node emitClusterStmt(function_ref<void()> builder, + const std::string &label = "") { int clusterId = ++counter; os << "subgraph cluster_" << clusterId << " {\n"; os.indent(); @@ -269,7 +270,7 @@ private: } /// Emit a node statement. - Node emitNodeStmt(std::string label, StringRef shape = kShapeNode, + Node emitNodeStmt(const std::string &label, StringRef shape = kShapeNode, StringRef background = "") { int nodeId = ++counter; AttributeMap attrs; |
