aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Interfaces/ControlFlowInterfaces.cpp')
-rw-r--r--mlir/lib/Interfaces/ControlFlowInterfaces.cpp305
1 files changed, 217 insertions, 88 deletions
diff --git a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
index ca3f766..1e56810 100644
--- a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
+++ b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
@@ -9,7 +9,9 @@
#include <utility>
#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Operation.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
+#include "llvm/Support/DebugLog.h"
using namespace mlir;
@@ -38,20 +40,31 @@ SuccessorOperands::SuccessorOperands(unsigned int producedOperandCount,
std::optional<BlockArgument>
detail::getBranchSuccessorArgument(const SuccessorOperands &operands,
unsigned operandIndex, Block *successor) {
+ LDBG() << "Getting branch successor argument for operand index "
+ << operandIndex << " in successor block";
+
OperandRange forwardedOperands = operands.getForwardedOperands();
// Check that the operands are valid.
- if (forwardedOperands.empty())
+ if (forwardedOperands.empty()) {
+ LDBG() << "No forwarded operands, returning nullopt";
return std::nullopt;
+ }
// Check to ensure that this operand is within the range.
unsigned operandsStart = forwardedOperands.getBeginOperandIndex();
if (operandIndex < operandsStart ||
- operandIndex >= (operandsStart + forwardedOperands.size()))
+ operandIndex >= (operandsStart + forwardedOperands.size())) {
+ LDBG() << "Operand index " << operandIndex << " out of range ["
+ << operandsStart << ", "
+ << (operandsStart + forwardedOperands.size())
+ << "), returning nullopt";
return std::nullopt;
+ }
// Index the successor.
unsigned argIndex =
operands.getProducedOperandCount() + operandIndex - operandsStart;
+ LDBG() << "Computed argument index " << argIndex << " for successor block";
return successor->getArgument(argIndex);
}
@@ -59,9 +72,15 @@ detail::getBranchSuccessorArgument(const SuccessorOperands &operands,
LogicalResult
detail::verifyBranchSuccessorOperands(Operation *op, unsigned succNo,
const SuccessorOperands &operands) {
+ LDBG() << "Verifying branch successor operands for successor #" << succNo
+ << " in operation " << op->getName();
+
// Check the count.
unsigned operandCount = operands.size();
Block *destBB = op->getSuccessor(succNo);
+ LDBG() << "Branch has " << operandCount << " operands, target block has "
+ << destBB->getNumArguments() << " arguments";
+
if (operandCount != destBB->getNumArguments())
return op->emitError() << "branch has " << operandCount
<< " operands for successor #" << succNo
@@ -69,13 +88,22 @@ detail::verifyBranchSuccessorOperands(Operation *op, unsigned succNo,
<< destBB->getNumArguments();
// Check the types.
+ LDBG() << "Checking type compatibility for "
+ << (operandCount - operands.getProducedOperandCount())
+ << " forwarded operands";
for (unsigned i = operands.getProducedOperandCount(); i != operandCount;
++i) {
- if (!cast<BranchOpInterface>(op).areTypesCompatible(
- operands[i].getType(), destBB->getArgument(i).getType()))
+ Type operandType = operands[i].getType();
+ Type argType = destBB->getArgument(i).getType();
+ LDBG() << "Checking type compatibility: operand type " << operandType
+ << " vs argument type " << argType;
+
+ if (!cast<BranchOpInterface>(op).areTypesCompatible(operandType, argType))
return op->emitError() << "type mismatch for bb argument #" << i
<< " of successor #" << succNo;
}
+
+ LDBG() << "Branch successor operand verification successful";
return success();
}
@@ -126,15 +154,15 @@ LogicalResult detail::verifyRegionBranchWeights(Operation *op) {
static InFlightDiagnostic &printRegionEdgeName(InFlightDiagnostic &diag,
RegionBranchPoint sourceNo,
- RegionBranchPoint succRegionNo) {
+ RegionSuccessor succRegionNo) {
diag << "from ";
- if (Region *region = sourceNo.getRegionOrNull())
- diag << "Region #" << region->getRegionNumber();
+ if (Operation *op = sourceNo.getTerminatorPredecessorOrNull())
+ diag << "Operation " << op->getName();
else
diag << "parent operands";
diag << " to ";
- if (Region *region = succRegionNo.getRegionOrNull())
+ if (Region *region = succRegionNo.getSuccessor())
diag << "Region #" << region->getRegionNumber();
else
diag << "parent results";
@@ -145,13 +173,12 @@ static InFlightDiagnostic &printRegionEdgeName(InFlightDiagnostic &diag,
/// `sourcePoint`. `getInputsTypesForRegion` is a function that returns the
/// types of the inputs that flow to a successor region.
static LogicalResult
-verifyTypesAlongAllEdges(Operation *op, RegionBranchPoint sourcePoint,
- function_ref<FailureOr<TypeRange>(RegionBranchPoint)>
+verifyTypesAlongAllEdges(RegionBranchOpInterface branchOp,
+ RegionBranchPoint sourcePoint,
+ function_ref<FailureOr<TypeRange>(RegionSuccessor)>
getInputsTypesForRegion) {
- auto regionInterface = cast<RegionBranchOpInterface>(op);
-
SmallVector<RegionSuccessor, 2> successors;
- regionInterface.getSuccessorRegions(sourcePoint, successors);
+ branchOp.getSuccessorRegions(sourcePoint, successors);
for (RegionSuccessor &succ : successors) {
FailureOr<TypeRange> sourceTypes = getInputsTypesForRegion(succ);
@@ -160,10 +187,14 @@ verifyTypesAlongAllEdges(Operation *op, RegionBranchPoint sourcePoint,
TypeRange succInputsTypes = succ.getSuccessorInputs().getTypes();
if (sourceTypes->size() != succInputsTypes.size()) {
- InFlightDiagnostic diag = op->emitOpError("region control flow edge ");
+ InFlightDiagnostic diag =
+ branchOp->emitOpError("region control flow edge ");
+ std::string succStr;
+ llvm::raw_string_ostream os(succStr);
+ os << succ;
return printRegionEdgeName(diag, sourcePoint, succ)
<< ": source has " << sourceTypes->size()
- << " operands, but target successor needs "
+ << " operands, but target successor " << os.str() << " needs "
<< succInputsTypes.size();
}
@@ -171,8 +202,10 @@ verifyTypesAlongAllEdges(Operation *op, RegionBranchPoint sourcePoint,
llvm::enumerate(llvm::zip(*sourceTypes, succInputsTypes))) {
Type sourceType = std::get<0>(typesIdx.value());
Type inputType = std::get<1>(typesIdx.value());
- if (!regionInterface.areTypesCompatible(sourceType, inputType)) {
- InFlightDiagnostic diag = op->emitOpError("along control flow edge ");
+
+ if (!branchOp.areTypesCompatible(sourceType, inputType)) {
+ InFlightDiagnostic diag =
+ branchOp->emitOpError("along control flow edge ");
return printRegionEdgeName(diag, sourcePoint, succ)
<< ": source type #" << typesIdx.index() << " " << sourceType
<< " should match input type #" << typesIdx.index() << " "
@@ -180,6 +213,7 @@ verifyTypesAlongAllEdges(Operation *op, RegionBranchPoint sourcePoint,
}
}
}
+
return success();
}
@@ -187,34 +221,18 @@ verifyTypesAlongAllEdges(Operation *op, RegionBranchPoint sourcePoint,
LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) {
auto regionInterface = cast<RegionBranchOpInterface>(op);
- auto inputTypesFromParent = [&](RegionBranchPoint point) -> TypeRange {
- return regionInterface.getEntrySuccessorOperands(point).getTypes();
+ auto inputTypesFromParent = [&](RegionSuccessor successor) -> TypeRange {
+ return regionInterface.getEntrySuccessorOperands(successor).getTypes();
};
// Verify types along control flow edges originating from the parent.
- if (failed(verifyTypesAlongAllEdges(op, RegionBranchPoint::parent(),
- inputTypesFromParent)))
+ if (failed(verifyTypesAlongAllEdges(
+ regionInterface, RegionBranchPoint::parent(), inputTypesFromParent)))
return failure();
- auto areTypesCompatible = [&](TypeRange lhs, TypeRange rhs) {
- if (lhs.size() != rhs.size())
- return false;
- for (auto types : llvm::zip(lhs, rhs)) {
- if (!regionInterface.areTypesCompatible(std::get<0>(types),
- std::get<1>(types))) {
- return false;
- }
- }
- return true;
- };
-
// Verify types along control flow edges originating from each region.
for (Region &region : op->getRegions()) {
-
- // Since there can be multiple terminators implementing the
- // `RegionBranchTerminatorOpInterface`, all should have the same operand
- // types when passing them to the same region.
-
+ // Collect all return-like terminators in the region.
SmallVector<RegionBranchTerminatorOpInterface> regionReturnOps;
for (Block &block : region)
if (!block.empty())
@@ -227,33 +245,20 @@ LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) {
if (regionReturnOps.empty())
continue;
- auto inputTypesForRegion =
- [&](RegionBranchPoint point) -> FailureOr<TypeRange> {
- std::optional<OperandRange> regionReturnOperands;
- for (RegionBranchTerminatorOpInterface regionReturnOp : regionReturnOps) {
- auto terminatorOperands = regionReturnOp.getSuccessorOperands(point);
-
- if (!regionReturnOperands) {
- regionReturnOperands = terminatorOperands;
- continue;
- }
-
- // Found more than one ReturnLike terminator. Make sure the operand
- // types match with the first one.
- if (!areTypesCompatible(regionReturnOperands->getTypes(),
- terminatorOperands.getTypes())) {
- InFlightDiagnostic diag = op->emitOpError("along control flow edge");
- return printRegionEdgeName(diag, region, point)
- << " operands mismatch between return-like terminators";
- }
- }
-
- // All successors get the same set of operand types.
- return TypeRange(regionReturnOperands->getTypes());
- };
-
- if (failed(verifyTypesAlongAllEdges(op, region, inputTypesForRegion)))
- return failure();
+ // Verify types along control flow edges originating from each return-like
+ // terminator.
+ for (RegionBranchTerminatorOpInterface regionReturnOp : regionReturnOps) {
+
+ auto inputTypesForRegion =
+ [&](RegionSuccessor successor) -> FailureOr<TypeRange> {
+ OperandRange terminatorOperands =
+ regionReturnOp.getSuccessorOperands(successor);
+ return TypeRange(terminatorOperands.getTypes());
+ };
+ if (failed(verifyTypesAlongAllEdges(regionInterface, regionReturnOp,
+ inputTypesForRegion)))
+ return failure();
+ }
}
return success();
@@ -272,31 +277,74 @@ using StopConditionFn = function_ref<bool(Region *, ArrayRef<bool> visited)>;
static bool traverseRegionGraph(Region *begin,
StopConditionFn stopConditionFn) {
auto op = cast<RegionBranchOpInterface>(begin->getParentOp());
+ LDBG() << "Starting region graph traversal from region #"
+ << begin->getRegionNumber() << " in operation " << op->getName();
+
SmallVector<bool> visited(op->getNumRegions(), false);
visited[begin->getRegionNumber()] = true;
+ LDBG() << "Initialized visited array with " << op->getNumRegions()
+ << " regions";
// Retrieve all successors of the region and enqueue them in the worklist.
SmallVector<Region *> worklist;
auto enqueueAllSuccessors = [&](Region *region) {
- SmallVector<RegionSuccessor> successors;
- op.getSuccessorRegions(region, successors);
- for (RegionSuccessor successor : successors)
- if (!successor.isParent())
- worklist.push_back(successor.getSuccessor());
+ LDBG() << "Enqueuing successors for region #" << region->getRegionNumber();
+ SmallVector<Attribute> operandAttributes(op->getNumOperands());
+ for (Block &block : *region) {
+ if (block.empty())
+ continue;
+ auto terminator =
+ dyn_cast<RegionBranchTerminatorOpInterface>(block.back());
+ if (!terminator)
+ continue;
+ SmallVector<RegionSuccessor> successors;
+ operandAttributes.resize(terminator->getNumOperands());
+ terminator.getSuccessorRegions(operandAttributes, successors);
+ LDBG() << "Found " << successors.size()
+ << " successors from terminator in block";
+ for (RegionSuccessor successor : successors) {
+ if (!successor.isParent()) {
+ worklist.push_back(successor.getSuccessor());
+ LDBG() << "Added region #"
+ << successor.getSuccessor()->getRegionNumber()
+ << " to worklist";
+ } else {
+ LDBG() << "Skipping parent successor";
+ }
+ }
+ }
};
enqueueAllSuccessors(begin);
+ LDBG() << "Initial worklist size: " << worklist.size();
// Process all regions in the worklist via DFS.
while (!worklist.empty()) {
Region *nextRegion = worklist.pop_back_val();
- if (stopConditionFn(nextRegion, visited))
+ LDBG() << "Processing region #" << nextRegion->getRegionNumber()
+ << " from worklist (remaining: " << worklist.size() << ")";
+
+ if (stopConditionFn(nextRegion, visited)) {
+ LDBG() << "Stop condition met for region #"
+ << nextRegion->getRegionNumber() << ", returning true";
return true;
- if (visited[nextRegion->getRegionNumber()])
+ }
+ llvm::dbgs() << "Region: " << nextRegion << "\n";
+ if (!nextRegion->getParentOp()) {
+ llvm::errs() << "Region " << *nextRegion << " has no parent op\n";
+ return false;
+ }
+ if (visited[nextRegion->getRegionNumber()]) {
+ LDBG() << "Region #" << nextRegion->getRegionNumber()
+ << " already visited, skipping";
continue;
+ }
visited[nextRegion->getRegionNumber()] = true;
+ LDBG() << "Marking region #" << nextRegion->getRegionNumber()
+ << " as visited";
enqueueAllSuccessors(nextRegion);
}
+ LDBG() << "Traversal completed, returning false";
return false;
}
@@ -322,18 +370,26 @@ static bool isRegionReachable(Region *begin, Region *r) {
/// mutually exclusive if they are not reachable from each other as per
/// RegionBranchOpInterface::getSuccessorRegions.
bool mlir::insideMutuallyExclusiveRegions(Operation *a, Operation *b) {
+ LDBG() << "Checking if operations are in mutually exclusive regions: "
+ << a->getName() << " and " << b->getName();
+
assert(a && "expected non-empty operation");
assert(b && "expected non-empty operation");
auto branchOp = a->getParentOfType<RegionBranchOpInterface>();
while (branchOp) {
+ LDBG() << "Checking branch operation " << branchOp->getName();
+
// Check if b is inside branchOp. (We already know that a is.)
if (!branchOp->isProperAncestor(b)) {
+ LDBG() << "Operation b is not inside branchOp, checking next ancestor";
// Check next enclosing RegionBranchOpInterface.
branchOp = branchOp->getParentOfType<RegionBranchOpInterface>();
continue;
}
+ LDBG() << "Both operations are inside branchOp, finding their regions";
+
// b is contained in branchOp. Retrieve the regions in which `a` and `b`
// are contained.
Region *regionA = nullptr, *regionB = nullptr;
@@ -341,63 +397,136 @@ bool mlir::insideMutuallyExclusiveRegions(Operation *a, Operation *b) {
if (r.findAncestorOpInRegion(*a)) {
assert(!regionA && "already found a region for a");
regionA = &r;
+ LDBG() << "Found region #" << r.getRegionNumber() << " for operation a";
}
if (r.findAncestorOpInRegion(*b)) {
assert(!regionB && "already found a region for b");
regionB = &r;
+ LDBG() << "Found region #" << r.getRegionNumber() << " for operation b";
}
}
assert(regionA && regionB && "could not find region of op");
+ LDBG() << "Region A: #" << regionA->getRegionNumber() << ", Region B: #"
+ << regionB->getRegionNumber();
+
// `a` and `b` are in mutually exclusive regions if both regions are
// distinct and neither region is reachable from the other region.
- return regionA != regionB && !isRegionReachable(regionA, regionB) &&
- !isRegionReachable(regionB, regionA);
+ bool regionsAreDistinct = (regionA != regionB);
+ bool aNotReachableFromB = !isRegionReachable(regionA, regionB);
+ bool bNotReachableFromA = !isRegionReachable(regionB, regionA);
+
+ LDBG() << "Regions distinct: " << regionsAreDistinct
+ << ", A not reachable from B: " << aNotReachableFromB
+ << ", B not reachable from A: " << bNotReachableFromA;
+
+ bool mutuallyExclusive =
+ regionsAreDistinct && aNotReachableFromB && bNotReachableFromA;
+ LDBG() << "Operations are mutually exclusive: " << mutuallyExclusive;
+
+ return mutuallyExclusive;
}
// Could not find a common RegionBranchOpInterface among a's and b's
// ancestors.
+ LDBG() << "No common RegionBranchOpInterface found, operations are not "
+ "mutually exclusive";
return false;
}
bool RegionBranchOpInterface::isRepetitiveRegion(unsigned index) {
+ LDBG() << "Checking if region #" << index << " is repetitive in operation "
+ << getOperation()->getName();
+
Region *region = &getOperation()->getRegion(index);
- return isRegionReachable(region, region);
+ bool isRepetitive = isRegionReachable(region, region);
+
+ LDBG() << "Region #" << index << " is repetitive: " << isRepetitive;
+ return isRepetitive;
}
bool RegionBranchOpInterface::hasLoop() {
+ LDBG() << "Checking if operation " << getOperation()->getName()
+ << " has loops";
+
SmallVector<RegionSuccessor> entryRegions;
getSuccessorRegions(RegionBranchPoint::parent(), entryRegions);
- for (RegionSuccessor successor : entryRegions)
- if (!successor.isParent() &&
- traverseRegionGraph(successor.getSuccessor(),
- [](Region *nextRegion, ArrayRef<bool> visited) {
- // Interrupt traversal if the region was already
- // visited.
- return visited[nextRegion->getRegionNumber()];
- }))
- return true;
+ LDBG() << "Found " << entryRegions.size() << " entry regions";
+
+ for (RegionSuccessor successor : entryRegions) {
+ if (!successor.isParent()) {
+ LDBG() << "Checking entry region #"
+ << successor.getSuccessor()->getRegionNumber() << " for loops";
+
+ bool hasLoop =
+ traverseRegionGraph(successor.getSuccessor(),
+ [](Region *nextRegion, ArrayRef<bool> visited) {
+ // Interrupt traversal if the region was already
+ // visited.
+ return visited[nextRegion->getRegionNumber()];
+ });
+
+ if (hasLoop) {
+ LDBG() << "Found loop in entry region #"
+ << successor.getSuccessor()->getRegionNumber();
+ return true;
+ }
+ } else {
+ LDBG() << "Skipping parent successor";
+ }
+ }
+
+ LDBG() << "No loops found in operation";
return false;
}
Region *mlir::getEnclosingRepetitiveRegion(Operation *op) {
+ LDBG() << "Finding enclosing repetitive region for operation "
+ << op->getName();
+
while (Region *region = op->getParentRegion()) {
+ LDBG() << "Checking region #" << region->getRegionNumber()
+ << " in operation " << region->getParentOp()->getName();
+
op = region->getParentOp();
- if (auto branchOp = dyn_cast<RegionBranchOpInterface>(op))
- if (branchOp.isRepetitiveRegion(region->getRegionNumber()))
+ if (auto branchOp = dyn_cast<RegionBranchOpInterface>(op)) {
+ LDBG()
+ << "Found RegionBranchOpInterface, checking if region is repetitive";
+ if (branchOp.isRepetitiveRegion(region->getRegionNumber())) {
+ LDBG() << "Found repetitive region #" << region->getRegionNumber();
return region;
+ }
+ } else {
+ LDBG() << "Parent operation does not implement RegionBranchOpInterface";
+ }
}
+
+ LDBG() << "No enclosing repetitive region found";
return nullptr;
}
Region *mlir::getEnclosingRepetitiveRegion(Value value) {
+ LDBG() << "Finding enclosing repetitive region for value";
+
Region *region = value.getParentRegion();
while (region) {
+ LDBG() << "Checking region #" << region->getRegionNumber()
+ << " in operation " << region->getParentOp()->getName();
+
Operation *op = region->getParentOp();
- if (auto branchOp = dyn_cast<RegionBranchOpInterface>(op))
- if (branchOp.isRepetitiveRegion(region->getRegionNumber()))
+ if (auto branchOp = dyn_cast<RegionBranchOpInterface>(op)) {
+ LDBG()
+ << "Found RegionBranchOpInterface, checking if region is repetitive";
+ if (branchOp.isRepetitiveRegion(region->getRegionNumber())) {
+ LDBG() << "Found repetitive region #" << region->getRegionNumber();
return region;
+ }
+ } else {
+ LDBG() << "Parent operation does not implement RegionBranchOpInterface";
+ }
region = op->getParentRegion();
}
+
+ LDBG() << "No enclosing repetitive region found for value";
return nullptr;
}