diff options
Diffstat (limited to 'mlir/lib/Interfaces/ControlFlowInterfaces.cpp')
| -rw-r--r-- | mlir/lib/Interfaces/ControlFlowInterfaces.cpp | 305 |
1 files changed, 217 insertions, 88 deletions
diff --git a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp index ca3f766..1e56810 100644 --- a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp +++ b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp @@ -9,7 +9,9 @@ #include <utility> #include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Operation.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" +#include "llvm/Support/DebugLog.h" using namespace mlir; @@ -38,20 +40,31 @@ SuccessorOperands::SuccessorOperands(unsigned int producedOperandCount, std::optional<BlockArgument> detail::getBranchSuccessorArgument(const SuccessorOperands &operands, unsigned operandIndex, Block *successor) { + LDBG() << "Getting branch successor argument for operand index " + << operandIndex << " in successor block"; + OperandRange forwardedOperands = operands.getForwardedOperands(); // Check that the operands are valid. - if (forwardedOperands.empty()) + if (forwardedOperands.empty()) { + LDBG() << "No forwarded operands, returning nullopt"; return std::nullopt; + } // Check to ensure that this operand is within the range. unsigned operandsStart = forwardedOperands.getBeginOperandIndex(); if (operandIndex < operandsStart || - operandIndex >= (operandsStart + forwardedOperands.size())) + operandIndex >= (operandsStart + forwardedOperands.size())) { + LDBG() << "Operand index " << operandIndex << " out of range [" + << operandsStart << ", " + << (operandsStart + forwardedOperands.size()) + << "), returning nullopt"; return std::nullopt; + } // Index the successor. unsigned argIndex = operands.getProducedOperandCount() + operandIndex - operandsStart; + LDBG() << "Computed argument index " << argIndex << " for successor block"; return successor->getArgument(argIndex); } @@ -59,9 +72,15 @@ detail::getBranchSuccessorArgument(const SuccessorOperands &operands, LogicalResult detail::verifyBranchSuccessorOperands(Operation *op, unsigned succNo, const SuccessorOperands &operands) { + LDBG() << "Verifying branch successor operands for successor #" << succNo + << " in operation " << op->getName(); + // Check the count. unsigned operandCount = operands.size(); Block *destBB = op->getSuccessor(succNo); + LDBG() << "Branch has " << operandCount << " operands, target block has " + << destBB->getNumArguments() << " arguments"; + if (operandCount != destBB->getNumArguments()) return op->emitError() << "branch has " << operandCount << " operands for successor #" << succNo @@ -69,13 +88,22 @@ detail::verifyBranchSuccessorOperands(Operation *op, unsigned succNo, << destBB->getNumArguments(); // Check the types. + LDBG() << "Checking type compatibility for " + << (operandCount - operands.getProducedOperandCount()) + << " forwarded operands"; for (unsigned i = operands.getProducedOperandCount(); i != operandCount; ++i) { - if (!cast<BranchOpInterface>(op).areTypesCompatible( - operands[i].getType(), destBB->getArgument(i).getType())) + Type operandType = operands[i].getType(); + Type argType = destBB->getArgument(i).getType(); + LDBG() << "Checking type compatibility: operand type " << operandType + << " vs argument type " << argType; + + if (!cast<BranchOpInterface>(op).areTypesCompatible(operandType, argType)) return op->emitError() << "type mismatch for bb argument #" << i << " of successor #" << succNo; } + + LDBG() << "Branch successor operand verification successful"; return success(); } @@ -126,15 +154,15 @@ LogicalResult detail::verifyRegionBranchWeights(Operation *op) { static InFlightDiagnostic &printRegionEdgeName(InFlightDiagnostic &diag, RegionBranchPoint sourceNo, - RegionBranchPoint succRegionNo) { + RegionSuccessor succRegionNo) { diag << "from "; - if (Region *region = sourceNo.getRegionOrNull()) - diag << "Region #" << region->getRegionNumber(); + if (Operation *op = sourceNo.getTerminatorPredecessorOrNull()) + diag << "Operation " << op->getName(); else diag << "parent operands"; diag << " to "; - if (Region *region = succRegionNo.getRegionOrNull()) + if (Region *region = succRegionNo.getSuccessor()) diag << "Region #" << region->getRegionNumber(); else diag << "parent results"; @@ -145,13 +173,12 @@ static InFlightDiagnostic &printRegionEdgeName(InFlightDiagnostic &diag, /// `sourcePoint`. `getInputsTypesForRegion` is a function that returns the /// types of the inputs that flow to a successor region. static LogicalResult -verifyTypesAlongAllEdges(Operation *op, RegionBranchPoint sourcePoint, - function_ref<FailureOr<TypeRange>(RegionBranchPoint)> +verifyTypesAlongAllEdges(RegionBranchOpInterface branchOp, + RegionBranchPoint sourcePoint, + function_ref<FailureOr<TypeRange>(RegionSuccessor)> getInputsTypesForRegion) { - auto regionInterface = cast<RegionBranchOpInterface>(op); - SmallVector<RegionSuccessor, 2> successors; - regionInterface.getSuccessorRegions(sourcePoint, successors); + branchOp.getSuccessorRegions(sourcePoint, successors); for (RegionSuccessor &succ : successors) { FailureOr<TypeRange> sourceTypes = getInputsTypesForRegion(succ); @@ -160,10 +187,14 @@ verifyTypesAlongAllEdges(Operation *op, RegionBranchPoint sourcePoint, TypeRange succInputsTypes = succ.getSuccessorInputs().getTypes(); if (sourceTypes->size() != succInputsTypes.size()) { - InFlightDiagnostic diag = op->emitOpError("region control flow edge "); + InFlightDiagnostic diag = + branchOp->emitOpError("region control flow edge "); + std::string succStr; + llvm::raw_string_ostream os(succStr); + os << succ; return printRegionEdgeName(diag, sourcePoint, succ) << ": source has " << sourceTypes->size() - << " operands, but target successor needs " + << " operands, but target successor " << os.str() << " needs " << succInputsTypes.size(); } @@ -171,8 +202,10 @@ verifyTypesAlongAllEdges(Operation *op, RegionBranchPoint sourcePoint, llvm::enumerate(llvm::zip(*sourceTypes, succInputsTypes))) { Type sourceType = std::get<0>(typesIdx.value()); Type inputType = std::get<1>(typesIdx.value()); - if (!regionInterface.areTypesCompatible(sourceType, inputType)) { - InFlightDiagnostic diag = op->emitOpError("along control flow edge "); + + if (!branchOp.areTypesCompatible(sourceType, inputType)) { + InFlightDiagnostic diag = + branchOp->emitOpError("along control flow edge "); return printRegionEdgeName(diag, sourcePoint, succ) << ": source type #" << typesIdx.index() << " " << sourceType << " should match input type #" << typesIdx.index() << " " @@ -180,6 +213,7 @@ verifyTypesAlongAllEdges(Operation *op, RegionBranchPoint sourcePoint, } } } + return success(); } @@ -187,34 +221,18 @@ verifyTypesAlongAllEdges(Operation *op, RegionBranchPoint sourcePoint, LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) { auto regionInterface = cast<RegionBranchOpInterface>(op); - auto inputTypesFromParent = [&](RegionBranchPoint point) -> TypeRange { - return regionInterface.getEntrySuccessorOperands(point).getTypes(); + auto inputTypesFromParent = [&](RegionSuccessor successor) -> TypeRange { + return regionInterface.getEntrySuccessorOperands(successor).getTypes(); }; // Verify types along control flow edges originating from the parent. - if (failed(verifyTypesAlongAllEdges(op, RegionBranchPoint::parent(), - inputTypesFromParent))) + if (failed(verifyTypesAlongAllEdges( + regionInterface, RegionBranchPoint::parent(), inputTypesFromParent))) return failure(); - auto areTypesCompatible = [&](TypeRange lhs, TypeRange rhs) { - if (lhs.size() != rhs.size()) - return false; - for (auto types : llvm::zip(lhs, rhs)) { - if (!regionInterface.areTypesCompatible(std::get<0>(types), - std::get<1>(types))) { - return false; - } - } - return true; - }; - // Verify types along control flow edges originating from each region. for (Region ®ion : op->getRegions()) { - - // Since there can be multiple terminators implementing the - // `RegionBranchTerminatorOpInterface`, all should have the same operand - // types when passing them to the same region. - + // Collect all return-like terminators in the region. SmallVector<RegionBranchTerminatorOpInterface> regionReturnOps; for (Block &block : region) if (!block.empty()) @@ -227,33 +245,20 @@ LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) { if (regionReturnOps.empty()) continue; - auto inputTypesForRegion = - [&](RegionBranchPoint point) -> FailureOr<TypeRange> { - std::optional<OperandRange> regionReturnOperands; - for (RegionBranchTerminatorOpInterface regionReturnOp : regionReturnOps) { - auto terminatorOperands = regionReturnOp.getSuccessorOperands(point); - - if (!regionReturnOperands) { - regionReturnOperands = terminatorOperands; - continue; - } - - // Found more than one ReturnLike terminator. Make sure the operand - // types match with the first one. - if (!areTypesCompatible(regionReturnOperands->getTypes(), - terminatorOperands.getTypes())) { - InFlightDiagnostic diag = op->emitOpError("along control flow edge"); - return printRegionEdgeName(diag, region, point) - << " operands mismatch between return-like terminators"; - } - } - - // All successors get the same set of operand types. - return TypeRange(regionReturnOperands->getTypes()); - }; - - if (failed(verifyTypesAlongAllEdges(op, region, inputTypesForRegion))) - return failure(); + // Verify types along control flow edges originating from each return-like + // terminator. + for (RegionBranchTerminatorOpInterface regionReturnOp : regionReturnOps) { + + auto inputTypesForRegion = + [&](RegionSuccessor successor) -> FailureOr<TypeRange> { + OperandRange terminatorOperands = + regionReturnOp.getSuccessorOperands(successor); + return TypeRange(terminatorOperands.getTypes()); + }; + if (failed(verifyTypesAlongAllEdges(regionInterface, regionReturnOp, + inputTypesForRegion))) + return failure(); + } } return success(); @@ -272,31 +277,74 @@ using StopConditionFn = function_ref<bool(Region *, ArrayRef<bool> visited)>; static bool traverseRegionGraph(Region *begin, StopConditionFn stopConditionFn) { auto op = cast<RegionBranchOpInterface>(begin->getParentOp()); + LDBG() << "Starting region graph traversal from region #" + << begin->getRegionNumber() << " in operation " << op->getName(); + SmallVector<bool> visited(op->getNumRegions(), false); visited[begin->getRegionNumber()] = true; + LDBG() << "Initialized visited array with " << op->getNumRegions() + << " regions"; // Retrieve all successors of the region and enqueue them in the worklist. SmallVector<Region *> worklist; auto enqueueAllSuccessors = [&](Region *region) { - SmallVector<RegionSuccessor> successors; - op.getSuccessorRegions(region, successors); - for (RegionSuccessor successor : successors) - if (!successor.isParent()) - worklist.push_back(successor.getSuccessor()); + LDBG() << "Enqueuing successors for region #" << region->getRegionNumber(); + SmallVector<Attribute> operandAttributes(op->getNumOperands()); + for (Block &block : *region) { + if (block.empty()) + continue; + auto terminator = + dyn_cast<RegionBranchTerminatorOpInterface>(block.back()); + if (!terminator) + continue; + SmallVector<RegionSuccessor> successors; + operandAttributes.resize(terminator->getNumOperands()); + terminator.getSuccessorRegions(operandAttributes, successors); + LDBG() << "Found " << successors.size() + << " successors from terminator in block"; + for (RegionSuccessor successor : successors) { + if (!successor.isParent()) { + worklist.push_back(successor.getSuccessor()); + LDBG() << "Added region #" + << successor.getSuccessor()->getRegionNumber() + << " to worklist"; + } else { + LDBG() << "Skipping parent successor"; + } + } + } }; enqueueAllSuccessors(begin); + LDBG() << "Initial worklist size: " << worklist.size(); // Process all regions in the worklist via DFS. while (!worklist.empty()) { Region *nextRegion = worklist.pop_back_val(); - if (stopConditionFn(nextRegion, visited)) + LDBG() << "Processing region #" << nextRegion->getRegionNumber() + << " from worklist (remaining: " << worklist.size() << ")"; + + if (stopConditionFn(nextRegion, visited)) { + LDBG() << "Stop condition met for region #" + << nextRegion->getRegionNumber() << ", returning true"; return true; - if (visited[nextRegion->getRegionNumber()]) + } + llvm::dbgs() << "Region: " << nextRegion << "\n"; + if (!nextRegion->getParentOp()) { + llvm::errs() << "Region " << *nextRegion << " has no parent op\n"; + return false; + } + if (visited[nextRegion->getRegionNumber()]) { + LDBG() << "Region #" << nextRegion->getRegionNumber() + << " already visited, skipping"; continue; + } visited[nextRegion->getRegionNumber()] = true; + LDBG() << "Marking region #" << nextRegion->getRegionNumber() + << " as visited"; enqueueAllSuccessors(nextRegion); } + LDBG() << "Traversal completed, returning false"; return false; } @@ -322,18 +370,26 @@ static bool isRegionReachable(Region *begin, Region *r) { /// mutually exclusive if they are not reachable from each other as per /// RegionBranchOpInterface::getSuccessorRegions. bool mlir::insideMutuallyExclusiveRegions(Operation *a, Operation *b) { + LDBG() << "Checking if operations are in mutually exclusive regions: " + << a->getName() << " and " << b->getName(); + assert(a && "expected non-empty operation"); assert(b && "expected non-empty operation"); auto branchOp = a->getParentOfType<RegionBranchOpInterface>(); while (branchOp) { + LDBG() << "Checking branch operation " << branchOp->getName(); + // Check if b is inside branchOp. (We already know that a is.) if (!branchOp->isProperAncestor(b)) { + LDBG() << "Operation b is not inside branchOp, checking next ancestor"; // Check next enclosing RegionBranchOpInterface. branchOp = branchOp->getParentOfType<RegionBranchOpInterface>(); continue; } + LDBG() << "Both operations are inside branchOp, finding their regions"; + // b is contained in branchOp. Retrieve the regions in which `a` and `b` // are contained. Region *regionA = nullptr, *regionB = nullptr; @@ -341,63 +397,136 @@ bool mlir::insideMutuallyExclusiveRegions(Operation *a, Operation *b) { if (r.findAncestorOpInRegion(*a)) { assert(!regionA && "already found a region for a"); regionA = &r; + LDBG() << "Found region #" << r.getRegionNumber() << " for operation a"; } if (r.findAncestorOpInRegion(*b)) { assert(!regionB && "already found a region for b"); regionB = &r; + LDBG() << "Found region #" << r.getRegionNumber() << " for operation b"; } } assert(regionA && regionB && "could not find region of op"); + LDBG() << "Region A: #" << regionA->getRegionNumber() << ", Region B: #" + << regionB->getRegionNumber(); + // `a` and `b` are in mutually exclusive regions if both regions are // distinct and neither region is reachable from the other region. - return regionA != regionB && !isRegionReachable(regionA, regionB) && - !isRegionReachable(regionB, regionA); + bool regionsAreDistinct = (regionA != regionB); + bool aNotReachableFromB = !isRegionReachable(regionA, regionB); + bool bNotReachableFromA = !isRegionReachable(regionB, regionA); + + LDBG() << "Regions distinct: " << regionsAreDistinct + << ", A not reachable from B: " << aNotReachableFromB + << ", B not reachable from A: " << bNotReachableFromA; + + bool mutuallyExclusive = + regionsAreDistinct && aNotReachableFromB && bNotReachableFromA; + LDBG() << "Operations are mutually exclusive: " << mutuallyExclusive; + + return mutuallyExclusive; } // Could not find a common RegionBranchOpInterface among a's and b's // ancestors. + LDBG() << "No common RegionBranchOpInterface found, operations are not " + "mutually exclusive"; return false; } bool RegionBranchOpInterface::isRepetitiveRegion(unsigned index) { + LDBG() << "Checking if region #" << index << " is repetitive in operation " + << getOperation()->getName(); + Region *region = &getOperation()->getRegion(index); - return isRegionReachable(region, region); + bool isRepetitive = isRegionReachable(region, region); + + LDBG() << "Region #" << index << " is repetitive: " << isRepetitive; + return isRepetitive; } bool RegionBranchOpInterface::hasLoop() { + LDBG() << "Checking if operation " << getOperation()->getName() + << " has loops"; + SmallVector<RegionSuccessor> entryRegions; getSuccessorRegions(RegionBranchPoint::parent(), entryRegions); - for (RegionSuccessor successor : entryRegions) - if (!successor.isParent() && - traverseRegionGraph(successor.getSuccessor(), - [](Region *nextRegion, ArrayRef<bool> visited) { - // Interrupt traversal if the region was already - // visited. - return visited[nextRegion->getRegionNumber()]; - })) - return true; + LDBG() << "Found " << entryRegions.size() << " entry regions"; + + for (RegionSuccessor successor : entryRegions) { + if (!successor.isParent()) { + LDBG() << "Checking entry region #" + << successor.getSuccessor()->getRegionNumber() << " for loops"; + + bool hasLoop = + traverseRegionGraph(successor.getSuccessor(), + [](Region *nextRegion, ArrayRef<bool> visited) { + // Interrupt traversal if the region was already + // visited. + return visited[nextRegion->getRegionNumber()]; + }); + + if (hasLoop) { + LDBG() << "Found loop in entry region #" + << successor.getSuccessor()->getRegionNumber(); + return true; + } + } else { + LDBG() << "Skipping parent successor"; + } + } + + LDBG() << "No loops found in operation"; return false; } Region *mlir::getEnclosingRepetitiveRegion(Operation *op) { + LDBG() << "Finding enclosing repetitive region for operation " + << op->getName(); + while (Region *region = op->getParentRegion()) { + LDBG() << "Checking region #" << region->getRegionNumber() + << " in operation " << region->getParentOp()->getName(); + op = region->getParentOp(); - if (auto branchOp = dyn_cast<RegionBranchOpInterface>(op)) - if (branchOp.isRepetitiveRegion(region->getRegionNumber())) + if (auto branchOp = dyn_cast<RegionBranchOpInterface>(op)) { + LDBG() + << "Found RegionBranchOpInterface, checking if region is repetitive"; + if (branchOp.isRepetitiveRegion(region->getRegionNumber())) { + LDBG() << "Found repetitive region #" << region->getRegionNumber(); return region; + } + } else { + LDBG() << "Parent operation does not implement RegionBranchOpInterface"; + } } + + LDBG() << "No enclosing repetitive region found"; return nullptr; } Region *mlir::getEnclosingRepetitiveRegion(Value value) { + LDBG() << "Finding enclosing repetitive region for value"; + Region *region = value.getParentRegion(); while (region) { + LDBG() << "Checking region #" << region->getRegionNumber() + << " in operation " << region->getParentOp()->getName(); + Operation *op = region->getParentOp(); - if (auto branchOp = dyn_cast<RegionBranchOpInterface>(op)) - if (branchOp.isRepetitiveRegion(region->getRegionNumber())) + if (auto branchOp = dyn_cast<RegionBranchOpInterface>(op)) { + LDBG() + << "Found RegionBranchOpInterface, checking if region is repetitive"; + if (branchOp.isRepetitiveRegion(region->getRegionNumber())) { + LDBG() << "Found repetitive region #" << region->getRegionNumber(); return region; + } + } else { + LDBG() << "Parent operation does not implement RegionBranchOpInterface"; + } region = op->getParentRegion(); } + + LDBG() << "No enclosing repetitive region found for value"; return nullptr; } |
