diff options
Diffstat (limited to 'mlir/lib')
33 files changed, 854 insertions, 375 deletions
diff --git a/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp b/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp index a84d10d..24cb123 100644 --- a/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp +++ b/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp @@ -16,19 +16,21 @@ #include "mlir/IR/Operation.h" #include "mlir/IR/Region.h" #include "mlir/IR/Value.h" -#include "mlir/IR/ValueRange.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Interfaces/FunctionInterfaces.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Interfaces/ViewLikeInterface.h" #include "mlir/Support/LLVM.h" #include "llvm/Support/Casting.h" +#include "llvm/Support/DebugLog.h" #include <cassert> #include <optional> #include <utility> using namespace mlir; +#define DEBUG_TYPE "local-alias-analysis" + //===----------------------------------------------------------------------===// // Underlying Address Computation //===----------------------------------------------------------------------===// @@ -42,81 +44,47 @@ static void collectUnderlyingAddressValues(Value value, unsigned maxDepth, DenseSet<Value> &visited, SmallVectorImpl<Value> &output); -/// Given a successor (`region`) of a RegionBranchOpInterface, collect all of -/// the underlying values being addressed by one of the successor inputs. If the -/// provided `region` is null, as per `RegionBranchOpInterface` this represents -/// the parent operation. -static void collectUnderlyingAddressValues(RegionBranchOpInterface branch, - Region *region, Value inputValue, - unsigned inputIndex, - unsigned maxDepth, - DenseSet<Value> &visited, - SmallVectorImpl<Value> &output) { - // Given the index of a region of the branch (`predIndex`), or std::nullopt to - // represent the parent operation, try to return the index into the outputs of - // this region predecessor that correspond to the input values of `region`. If - // an index could not be found, std::nullopt is returned instead. 
- auto getOperandIndexIfPred = - [&](RegionBranchPoint pred) -> std::optional<unsigned> { - SmallVector<RegionSuccessor, 2> successors; - branch.getSuccessorRegions(pred, successors); - for (RegionSuccessor &successor : successors) { - if (successor.getSuccessor() != region) - continue; - // Check that the successor inputs map to the given input value. - ValueRange inputs = successor.getSuccessorInputs(); - if (inputs.empty()) { - output.push_back(inputValue); - break; - } - unsigned firstInputIndex, lastInputIndex; - if (region) { - firstInputIndex = cast<BlockArgument>(inputs[0]).getArgNumber(); - lastInputIndex = cast<BlockArgument>(inputs.back()).getArgNumber(); - } else { - firstInputIndex = cast<OpResult>(inputs[0]).getResultNumber(); - lastInputIndex = cast<OpResult>(inputs.back()).getResultNumber(); - } - if (firstInputIndex > inputIndex || lastInputIndex < inputIndex) { - output.push_back(inputValue); - break; - } - return inputIndex - firstInputIndex; - } - return std::nullopt; - }; - - // Check branches from the parent operation. - auto branchPoint = RegionBranchPoint::parent(); - if (region) - branchPoint = region; - - if (std::optional<unsigned> operandIndex = - getOperandIndexIfPred(/*predIndex=*/RegionBranchPoint::parent())) { - collectUnderlyingAddressValues( - branch.getEntrySuccessorOperands(branchPoint)[*operandIndex], maxDepth, - visited, output); +/// Given a RegionBranchOpInterface operation (`branch`), a Value`inputValue` +/// which is an input for the provided successor (`initialSuccessor`), try to +/// find the possible sources for the value along the control flow edges. 
+static void collectUnderlyingAddressValues2( + RegionBranchOpInterface branch, RegionSuccessor initialSuccessor, + Value inputValue, unsigned inputIndex, unsigned maxDepth, + DenseSet<Value> &visited, SmallVectorImpl<Value> &output) { + LDBG() << "collectUnderlyingAddressValues2: " + << OpWithFlags(branch.getOperation(), OpPrintingFlags().skipRegions()); + LDBG() << " with initialSuccessor " << initialSuccessor; + LDBG() << " inputValue: " << inputValue; + LDBG() << " inputIndex: " << inputIndex; + LDBG() << " maxDepth: " << maxDepth; + ValueRange inputs = initialSuccessor.getSuccessorInputs(); + if (inputs.empty()) { + LDBG() << " input is empty, enqueue value"; + output.push_back(inputValue); + return; } - // Check branches from each child region. - Operation *op = branch.getOperation(); - for (Region &region : op->getRegions()) { - if (std::optional<unsigned> operandIndex = getOperandIndexIfPred(region)) { - for (Block &block : region) { - // Try to determine possible region-branch successor operands for the - // current region. - if (auto term = dyn_cast<RegionBranchTerminatorOpInterface>( - block.getTerminator())) { - collectUnderlyingAddressValues( - term.getSuccessorOperands(branchPoint)[*operandIndex], maxDepth, - visited, output); - } else if (block.getNumSuccessors()) { - // Otherwise, if this terminator may exit the region we can't make - // any assumptions about which values get passed. - output.push_back(inputValue); - return; - } - } - } + unsigned firstInputIndex, lastInputIndex; + if (isa<BlockArgument>(inputs[0])) { + firstInputIndex = cast<BlockArgument>(inputs[0]).getArgNumber(); + lastInputIndex = cast<BlockArgument>(inputs.back()).getArgNumber(); + } else { + firstInputIndex = cast<OpResult>(inputs[0]).getResultNumber(); + lastInputIndex = cast<OpResult>(inputs.back()).getResultNumber(); + } + if (firstInputIndex > inputIndex || lastInputIndex < inputIndex) { + LDBG() << " !! 
Input index " << inputIndex << " out of range " + << firstInputIndex << " to " << lastInputIndex + << ", adding input value to output"; + output.push_back(inputValue); + return; + } + SmallVector<Value> predecessorValues; + branch.getPredecessorValues(initialSuccessor, inputIndex - firstInputIndex, + predecessorValues); + LDBG() << " Found " << predecessorValues.size() << " predecessor values"; + for (Value predecessorValue : predecessorValues) { + LDBG() << " Processing predecessor value: " << predecessorValue; + collectUnderlyingAddressValues(predecessorValue, maxDepth, visited, output); } } @@ -124,22 +92,28 @@ static void collectUnderlyingAddressValues(RegionBranchOpInterface branch, static void collectUnderlyingAddressValues(OpResult result, unsigned maxDepth, DenseSet<Value> &visited, SmallVectorImpl<Value> &output) { + LDBG() << "collectUnderlyingAddressValues (OpResult): " << result; + LDBG() << " maxDepth: " << maxDepth; + Operation *op = result.getOwner(); // If this is a view, unwrap to the source. if (ViewLikeOpInterface view = dyn_cast<ViewLikeOpInterface>(op)) { if (result == view.getViewDest()) { + LDBG() << " Unwrapping view to source: " << view.getViewSource(); return collectUnderlyingAddressValues(view.getViewSource(), maxDepth, visited, output); } } // Check to see if we can reason about the control flow of this op. 
if (auto branch = dyn_cast<RegionBranchOpInterface>(op)) { - return collectUnderlyingAddressValues(branch, /*region=*/nullptr, result, - result.getResultNumber(), maxDepth, - visited, output); + LDBG() << " Processing region branch operation"; + return collectUnderlyingAddressValues2( + branch, RegionSuccessor(op, op->getResults()), result, + result.getResultNumber(), maxDepth, visited, output); } + LDBG() << " Adding result to output: " << result; output.push_back(result); } @@ -148,14 +122,23 @@ static void collectUnderlyingAddressValues(OpResult result, unsigned maxDepth, static void collectUnderlyingAddressValues(BlockArgument arg, unsigned maxDepth, DenseSet<Value> &visited, SmallVectorImpl<Value> &output) { + LDBG() << "collectUnderlyingAddressValues (BlockArgument): " << arg; + LDBG() << " maxDepth: " << maxDepth; + LDBG() << " argNumber: " << arg.getArgNumber(); + LDBG() << " isEntryBlock: " << arg.getOwner()->isEntryBlock(); + Block *block = arg.getOwner(); unsigned argNumber = arg.getArgNumber(); // Handle the case of a non-entry block. if (!block->isEntryBlock()) { + LDBG() << " Processing non-entry block with " + << std::distance(block->pred_begin(), block->pred_end()) + << " predecessors"; for (auto it = block->pred_begin(), e = block->pred_end(); it != e; ++it) { auto branch = dyn_cast<BranchOpInterface>((*it)->getTerminator()); if (!branch) { + LDBG() << " Cannot analyze control flow, adding argument to output"; // We can't analyze the control flow, so bail out early. output.push_back(arg); return; @@ -165,10 +148,12 @@ static void collectUnderlyingAddressValues(BlockArgument arg, unsigned maxDepth, unsigned index = it.getSuccessorIndex(); Value operand = branch.getSuccessorOperands(index)[argNumber]; if (!operand) { + LDBG() << " No operand found for argument, adding to output"; // We can't analyze the control flow, so bail out early. 
output.push_back(arg); return; } + LDBG() << " Processing operand from predecessor: " << operand; collectUnderlyingAddressValues(operand, maxDepth, visited, output); } return; @@ -178,10 +163,35 @@ static void collectUnderlyingAddressValues(BlockArgument arg, unsigned maxDepth, Region *region = block->getParent(); Operation *op = region->getParentOp(); if (auto branch = dyn_cast<RegionBranchOpInterface>(op)) { - return collectUnderlyingAddressValues(branch, region, arg, argNumber, - maxDepth, visited, output); + LDBG() << " Processing region branch operation for entry block"; + // We have to find the successor matching the region, so that the input + // arguments are correctly set. + // TODO: this isn't comprehensive: the successor may not be reachable from + // the entry block. + SmallVector<RegionSuccessor> successors; + branch.getSuccessorRegions(RegionBranchPoint::parent(), successors); + RegionSuccessor regionSuccessor(region); + bool found = false; + for (RegionSuccessor &successor : successors) { + if (successor.getSuccessor() == region) { + LDBG() << " Found matching region successor: " << successor; + found = true; + regionSuccessor = successor; + break; + } + } + if (!found) { + LDBG() + << " No matching region successor found, adding argument to output"; + output.push_back(arg); + return; + } + return collectUnderlyingAddressValues2( + branch, regionSuccessor, arg, argNumber, maxDepth, visited, output); } + LDBG() + << " Cannot reason about underlying address, adding argument to output"; // We can't reason about the underlying address of this argument. output.push_back(arg); } @@ -190,17 +200,26 @@ static void collectUnderlyingAddressValues(BlockArgument arg, unsigned maxDepth, static void collectUnderlyingAddressValues(Value value, unsigned maxDepth, DenseSet<Value> &visited, SmallVectorImpl<Value> &output) { + LDBG() << "collectUnderlyingAddressValues: " << value; + LDBG() << " maxDepth: " << maxDepth; + // Check that we don't infinitely recurse. 
- if (!visited.insert(value).second) + if (!visited.insert(value).second) { + LDBG() << " Value already visited, skipping"; return; + } if (maxDepth == 0) { + LDBG() << " Max depth reached, adding value to output"; output.push_back(value); return; } --maxDepth; - if (BlockArgument arg = dyn_cast<BlockArgument>(value)) + if (BlockArgument arg = dyn_cast<BlockArgument>(value)) { + LDBG() << " Processing as BlockArgument"; return collectUnderlyingAddressValues(arg, maxDepth, visited, output); + } + LDBG() << " Processing as OpResult"; collectUnderlyingAddressValues(cast<OpResult>(value), maxDepth, visited, output); } @@ -208,9 +227,11 @@ static void collectUnderlyingAddressValues(Value value, unsigned maxDepth, /// Given a value, collect all of the underlying values being addressed. static void collectUnderlyingAddressValues(Value value, SmallVectorImpl<Value> &output) { + LDBG() << "collectUnderlyingAddressValues: " << value; DenseSet<Value> visited; collectUnderlyingAddressValues(value, maxUnderlyingValueSearchDepth, visited, output); + LDBG() << " Collected " << output.size() << " underlying values"; } //===----------------------------------------------------------------------===// @@ -227,19 +248,33 @@ static LogicalResult getAllocEffectFor(Value value, std::optional<MemoryEffects::EffectInstance> &effect, Operation *&allocScopeOp) { + LDBG() << "getAllocEffectFor: " << value; + // Try to get a memory effect interface for the parent operation. 
Operation *op; - if (BlockArgument arg = dyn_cast<BlockArgument>(value)) + if (BlockArgument arg = dyn_cast<BlockArgument>(value)) { op = arg.getOwner()->getParentOp(); - else + LDBG() << " BlockArgument, parent op: " + << OpWithFlags(op, OpPrintingFlags().skipRegions()); + } else { op = cast<OpResult>(value).getOwner(); + LDBG() << " OpResult, owner op: " + << OpWithFlags(op, OpPrintingFlags().skipRegions()); + } + MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op); - if (!interface) + if (!interface) { + LDBG() << " No memory effect interface found"; return failure(); + } // Try to find an allocation effect on the resource. - if (!(effect = interface.getEffectOnValue<MemoryEffects::Allocate>(value))) + if (!(effect = interface.getEffectOnValue<MemoryEffects::Allocate>(value))) { + LDBG() << " No allocation effect found on value"; return failure(); + } + + LDBG() << " Found allocation effect"; // If we found an allocation effect, try to find a scope for the allocation. // If the resource of this allocation is automatically scoped, find the parent @@ -247,6 +282,12 @@ getAllocEffectFor(Value value, if (llvm::isa<SideEffects::AutomaticAllocationScopeResource>( effect->getResource())) { allocScopeOp = op->getParentWithTrait<OpTrait::AutomaticAllocationScope>(); + if (allocScopeOp) { + LDBG() << " Automatic allocation scope found: " + << OpWithFlags(allocScopeOp, OpPrintingFlags().skipRegions()); + } else { + LDBG() << " Automatic allocation scope found: null"; + } return success(); } @@ -255,6 +296,12 @@ getAllocEffectFor(Value value, // For now assume allocation scope to the function scope (we don't care if // pointer escape outside function). 
allocScopeOp = op->getParentOfType<FunctionOpInterface>(); + if (allocScopeOp) { + LDBG() << " Function scope found: " + << OpWithFlags(allocScopeOp, OpPrintingFlags().skipRegions()); + } else { + LDBG() << " Function scope found: null"; + } return success(); } @@ -293,33 +340,44 @@ static std::optional<AliasResult> checkDistinctObjects(Value lhs, Value rhs) { /// Given the two values, return their aliasing behavior. AliasResult LocalAliasAnalysis::aliasImpl(Value lhs, Value rhs) { - if (lhs == rhs) + LDBG() << "aliasImpl: " << lhs << " vs " << rhs; + + if (lhs == rhs) { + LDBG() << " Same value, must alias"; return AliasResult::MustAlias; + } + Operation *lhsAllocScope = nullptr, *rhsAllocScope = nullptr; std::optional<MemoryEffects::EffectInstance> lhsAlloc, rhsAlloc; // Handle the case where lhs is a constant. Attribute lhsAttr, rhsAttr; if (matchPattern(lhs, m_Constant(&lhsAttr))) { + LDBG() << " lhs is constant"; // TODO: This is overly conservative. Two matching constants don't // necessarily map to the same address. For example, if the two values // correspond to different symbols that both represent a definition. - if (matchPattern(rhs, m_Constant(&rhsAttr))) + if (matchPattern(rhs, m_Constant(&rhsAttr))) { + LDBG() << " rhs is also constant, may alias"; return AliasResult::MayAlias; + } // Try to find an alloc effect on rhs. If an effect was found we can't // alias, otherwise we might. - return succeeded(getAllocEffectFor(rhs, rhsAlloc, rhsAllocScope)) - ? AliasResult::NoAlias - : AliasResult::MayAlias; + bool rhsHasAlloc = + succeeded(getAllocEffectFor(rhs, rhsAlloc, rhsAllocScope)); + LDBG() << " rhs has alloc effect: " << rhsHasAlloc; + return rhsHasAlloc ? AliasResult::NoAlias : AliasResult::MayAlias; } // Handle the case where rhs is a constant. if (matchPattern(rhs, m_Constant(&rhsAttr))) { + LDBG() << " rhs is constant"; // Try to find an alloc effect on lhs. If an effect was found we can't // alias, otherwise we might. 
- return succeeded(getAllocEffectFor(lhs, lhsAlloc, lhsAllocScope)) - ? AliasResult::NoAlias - : AliasResult::MayAlias; + bool lhsHasAlloc = + succeeded(getAllocEffectFor(lhs, lhsAlloc, lhsAllocScope)); + LDBG() << " lhs has alloc effect: " << lhsHasAlloc; + return lhsHasAlloc ? AliasResult::NoAlias : AliasResult::MayAlias; } if (std::optional<AliasResult> result = checkDistinctObjects(lhs, rhs)) @@ -329,9 +387,14 @@ AliasResult LocalAliasAnalysis::aliasImpl(Value lhs, Value rhs) { // an allocation effect. bool lhsHasAlloc = succeeded(getAllocEffectFor(lhs, lhsAlloc, lhsAllocScope)); bool rhsHasAlloc = succeeded(getAllocEffectFor(rhs, rhsAlloc, rhsAllocScope)); + LDBG() << " lhs has alloc effect: " << lhsHasAlloc; + LDBG() << " rhs has alloc effect: " << rhsHasAlloc; + if (lhsHasAlloc == rhsHasAlloc) { // If both values have an allocation effect we know they don't alias, and if // neither have an effect we can't make an assumptions. + LDBG() << " Both have same alloc status: " + << (lhsHasAlloc ? "NoAlias" : "MayAlias"); return lhsHasAlloc ? AliasResult::NoAlias : AliasResult::MayAlias; } @@ -339,6 +402,7 @@ AliasResult LocalAliasAnalysis::aliasImpl(Value lhs, Value rhs) { // and one without. Move the one with the effect to the lhs to make the next // checks simpler. if (rhsHasAlloc) { + LDBG() << " Swapping lhs and rhs to put alloc effect on lhs"; std::swap(lhs, rhs); lhsAlloc = rhsAlloc; lhsAllocScope = rhsAllocScope; @@ -347,49 +411,74 @@ AliasResult LocalAliasAnalysis::aliasImpl(Value lhs, Value rhs) { // If the effect has a scoped allocation region, check to see if the // non-effect value is defined above that scope. if (lhsAllocScope) { + LDBG() << " Checking allocation scope: " + << OpWithFlags(lhsAllocScope, OpPrintingFlags().skipRegions()); // If the parent operation of rhs is an ancestor of the allocation scope, or // if rhs is an entry block argument of the allocation scope we know the two // values can't alias. 
Operation *rhsParentOp = rhs.getParentRegion()->getParentOp(); - if (rhsParentOp->isProperAncestor(lhsAllocScope)) + if (rhsParentOp->isProperAncestor(lhsAllocScope)) { + LDBG() << " rhs parent is ancestor of alloc scope, no alias"; return AliasResult::NoAlias; + } if (rhsParentOp == lhsAllocScope) { BlockArgument rhsArg = dyn_cast<BlockArgument>(rhs); - if (rhsArg && rhs.getParentBlock()->isEntryBlock()) + if (rhsArg && rhs.getParentBlock()->isEntryBlock()) { + LDBG() << " rhs is entry block arg of alloc scope, no alias"; return AliasResult::NoAlias; + } } } // If we couldn't reason about the relationship between the two values, // conservatively assume they might alias. + LDBG() << " Cannot reason about relationship, may alias"; return AliasResult::MayAlias; } /// Given the two values, return their aliasing behavior. AliasResult LocalAliasAnalysis::alias(Value lhs, Value rhs) { - if (lhs == rhs) + LDBG() << "alias: " << lhs << " vs " << rhs; + + if (lhs == rhs) { + LDBG() << " Same value, must alias"; return AliasResult::MustAlias; + } // Get the underlying values being addressed. SmallVector<Value, 8> lhsValues, rhsValues; collectUnderlyingAddressValues(lhs, lhsValues); collectUnderlyingAddressValues(rhs, rhsValues); + LDBG() << " lhs underlying values: " << lhsValues.size(); + LDBG() << " rhs underlying values: " << rhsValues.size(); + // If we failed to collect for either of the values somehow, conservatively // assume they may alias. - if (lhsValues.empty() || rhsValues.empty()) + if (lhsValues.empty() || rhsValues.empty()) { + LDBG() << " Failed to collect underlying values, may alias"; return AliasResult::MayAlias; + } // Check the alias results against each of the underlying values. 
std::optional<AliasResult> result; for (Value lhsVal : lhsValues) { for (Value rhsVal : rhsValues) { + LDBG() << " Checking underlying values: " << lhsVal << " vs " << rhsVal; AliasResult nextResult = aliasImpl(lhsVal, rhsVal); + LDBG() << " Result: " + << (nextResult == AliasResult::MustAlias ? "MustAlias" + : nextResult == AliasResult::NoAlias ? "NoAlias" + : "MayAlias"); result = result ? result->merge(nextResult) : nextResult; } } // We should always have a valid result here. + LDBG() << " Final result: " + << (result->isMust() ? "MustAlias" + : result->isNo() ? "NoAlias" + : "MayAlias"); return *result; } @@ -398,8 +487,12 @@ AliasResult LocalAliasAnalysis::alias(Value lhs, Value rhs) { //===----------------------------------------------------------------------===// ModRefResult LocalAliasAnalysis::getModRef(Operation *op, Value location) { + LDBG() << "getModRef: " << OpWithFlags(op, OpPrintingFlags().skipRegions()) + << " on location " << location; + // Check to see if this operation relies on nested side effects. if (op->hasTrait<OpTrait::HasRecursiveMemoryEffects>()) { + LDBG() << " Operation has recursive memory effects, returning ModAndRef"; // TODO: To check recursive operations we need to check all of the nested // operations, which can result in a quadratic number of queries. We should // introduce some caching of some kind to help alleviate this, especially as @@ -410,38 +503,64 @@ ModRefResult LocalAliasAnalysis::getModRef(Operation *op, Value location) { // Otherwise, check to see if this operation has a memory effect interface. MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op); - if (!interface) + if (!interface) { + LDBG() << " No memory effect interface, returning ModAndRef"; return ModRefResult::getModAndRef(); + } // Build a ModRefResult by merging the behavior of the effects of this // operation. 
SmallVector<MemoryEffects::EffectInstance> effects; interface.getEffects(effects); + LDBG() << " Found " << effects.size() << " memory effects"; ModRefResult result = ModRefResult::getNoModRef(); for (const MemoryEffects::EffectInstance &effect : effects) { - if (isa<MemoryEffects::Allocate, MemoryEffects::Free>(effect.getEffect())) + if (isa<MemoryEffects::Allocate, MemoryEffects::Free>(effect.getEffect())) { + LDBG() << " Skipping alloc/free effect"; continue; + } // Check for an alias between the effect and our memory location. // TODO: Add support for checking an alias with a symbol reference. AliasResult aliasResult = AliasResult::MayAlias; - if (Value effectValue = effect.getValue()) + if (Value effectValue = effect.getValue()) { + LDBG() << " Checking alias between effect value " << effectValue + << " and location " << location; aliasResult = alias(effectValue, location); + LDBG() << " Alias result: " + << (aliasResult.isMust() ? "MustAlias" + : aliasResult.isNo() ? "NoAlias" + : "MayAlias"); + } else { + LDBG() << " No effect value, assuming MayAlias"; + } // If we don't alias, ignore this effect. - if (aliasResult.isNo()) + if (aliasResult.isNo()) { + LDBG() << " No alias, ignoring effect"; continue; + } // Merge in the corresponding mod or ref for this effect. if (isa<MemoryEffects::Read>(effect.getEffect())) { + LDBG() << " Adding Ref to result"; result = result.merge(ModRefResult::getRef()); } else { assert(isa<MemoryEffects::Write>(effect.getEffect())); + LDBG() << " Adding Mod to result"; result = result.merge(ModRefResult::getMod()); } - if (result.isModAndRef()) + if (result.isModAndRef()) { + LDBG() << " Result is now ModAndRef, breaking"; break; + } } + + LDBG() << " Final ModRef result: " + << (result.isModAndRef() ? "ModAndRef" + : result.isMod() ? "Mod" + : result.isRef() ? 
"Ref" + : "NoModRef"); return result; } diff --git a/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp index 377f7eb..0fc5b44 100644 --- a/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp +++ b/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp @@ -501,11 +501,10 @@ void DeadCodeAnalysis::visitRegionTerminator(Operation *op, return; SmallVector<RegionSuccessor> successors; - if (auto terminator = dyn_cast<RegionBranchTerminatorOpInterface>(op)) - terminator.getSuccessorRegions(*operands, successors); - else - branch.getSuccessorRegions(op->getParentRegion(), successors); - + auto terminator = dyn_cast<RegionBranchTerminatorOpInterface>(op); + if (!terminator) + return; + terminator.getSuccessorRegions(*operands, successors); visitRegionBranchEdges(branch, op, successors); } diff --git a/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp index daa3db5..0682e5f 100644 --- a/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp +++ b/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp @@ -588,7 +588,9 @@ void AbstractDenseBackwardDataFlowAnalysis::visitBlock(Block *block) { // flow, propagate the lattice back along the control flow edge. if (auto branch = dyn_cast<RegionBranchOpInterface>(block->getParentOp())) { LDBG() << " Exit block of region branch operation"; - visitRegionBranchOperation(point, branch, block->getParent(), before); + auto terminator = + cast<RegionBranchTerminatorOpInterface>(block->getTerminator()); + visitRegionBranchOperation(point, branch, terminator, before); return; } diff --git a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp index 0d2e2ed..8e63ae8 100644 --- a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp +++ b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp @@ -130,7 +130,7 @@ AbstractSparseForwardDataFlowAnalysis::visitOperation(Operation *op) { // The results of a region branch operation are determined by control-flow. 
if (auto branch = dyn_cast<RegionBranchOpInterface>(op)) { visitRegionSuccessors(getProgramPointAfter(branch), branch, - /*successor=*/RegionBranchPoint::parent(), + /*successor=*/{branch, branch->getResults()}, resultLattices); return success(); } @@ -279,7 +279,7 @@ void AbstractSparseForwardDataFlowAnalysis::visitCallableOperation( void AbstractSparseForwardDataFlowAnalysis::visitRegionSuccessors( ProgramPoint *point, RegionBranchOpInterface branch, - RegionBranchPoint successor, ArrayRef<AbstractSparseLattice *> lattices) { + RegionSuccessor successor, ArrayRef<AbstractSparseLattice *> lattices) { const auto *predecessors = getOrCreateFor<PredecessorState>(point, point); assert(predecessors->allPredecessorsKnown() && "unexpected unresolved region successors"); @@ -314,7 +314,7 @@ void AbstractSparseForwardDataFlowAnalysis::visitRegionSuccessors( visitNonControlFlowArgumentsImpl( branch, RegionSuccessor( - branch->getResults().slice(firstIndex, inputs.size())), + branch, branch->getResults().slice(firstIndex, inputs.size())), lattices, firstIndex); } else { if (!inputs.empty()) diff --git a/mlir/lib/Analysis/SliceWalk.cpp b/mlir/lib/Analysis/SliceWalk.cpp index 817d71a..863f260 100644 --- a/mlir/lib/Analysis/SliceWalk.cpp +++ b/mlir/lib/Analysis/SliceWalk.cpp @@ -114,7 +114,7 @@ mlir::getControlFlowPredecessors(Value value) { if (!regionOp) return std::nullopt; // Add the control flow predecessor operands to the work list. 
- RegionSuccessor region(regionOp->getResults()); + RegionSuccessor region(regionOp, regionOp->getResults()); SmallVector<Value> predecessorOperands = getRegionPredecessorOperands( regionOp, region, opResult.getResultNumber()); return predecessorOperands; diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp index 478b6aa..1eca43d 100644 --- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp +++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp @@ -989,21 +989,17 @@ mfmaOpToScaledIntrinsic(ScaledMFMAOp smfma, Chipset chipset) { smfma.getN(), smfma.getK(), 1u, chipset); } -/// Return the `rocdl` intrinsic corresponding to a WMMA operation `wmma` -/// if one exists. This includes checking to ensure the intrinsic is supported -/// on the architecture you are compiling for. -static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma, - Chipset chipset) { - auto sourceVectorType = cast<VectorType>(wmma.getSourceA().getType()); - auto sourceBVectorType = cast<VectorType>(wmma.getSourceB().getType()); - auto destVectorType = cast<VectorType>(wmma.getDestC().getType()); - Type elemSourceType = sourceVectorType.getElementType(); - Type elemBSourceType = sourceBVectorType.getElementType(); - Type elemDestType = destVectorType.getElementType(); - - const uint32_t k = wmma.getK(); - +/// Returns the `rocdl` intrinsic corresponding to a WMMA operation `wmma` +/// for RDNA3/4 architectures. +static std::optional<StringRef> +wmmaOpToIntrinsicRDNA(Type elemSourceType, Type elemBSourceType, + Type elemDestType, uint32_t k, bool isRDNA3) { + using fp8 = Float8E4M3FNType; + using bf8 = Float8E5M2Type; + + // Handle k == 16 for RDNA3/4. if (k == 16) { + // Common patterns for RDNA3 and RDNA4. 
if (elemSourceType.isF16() && elemDestType.isF32()) return ROCDL::wmma_f32_16x16x16_f16::getOperationName(); if (elemSourceType.isBF16() && elemDestType.isF32()) @@ -1014,39 +1010,160 @@ static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma, return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName(); if (elemSourceType.isInteger(8) && elemDestType.isInteger(32)) return ROCDL::wmma_i32_16x16x16_iu8::getOperationName(); - if (chipset.majorVersion == 11) { + + // RDNA3 specific patterns. + if (isRDNA3) { if (elemSourceType.isInteger(4) && elemDestType.isInteger(32)) return ROCDL::wmma_i32_16x16x16_iu4::getOperationName(); + return std::nullopt; } - } - if (chipset.majorVersion < 12) - return std::nullopt; - // gfx12+ - if (k == 16) { - if (isa<Float8E4M3FNType>(elemSourceType) && - isa<Float8E4M3FNType>(elemBSourceType) && elemDestType.isF32()) + // RDNA4 specific patterns (fp8/bf8). + if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType) && + elemDestType.isF32()) return ROCDL::wmma_f32_16x16x16_fp8_fp8::getOperationName(); - if (isa<Float8E4M3FNType>(elemSourceType) && - isa<Float8E5M2Type>(elemBSourceType) && elemDestType.isF32()) + if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType) && + elemDestType.isF32()) return ROCDL::wmma_f32_16x16x16_fp8_bf8::getOperationName(); - if (isa<Float8E5M2Type>(elemSourceType) && - isa<Float8E5M2Type>(elemBSourceType) && elemDestType.isF32()) + if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType) && + elemDestType.isF32()) return ROCDL::wmma_f32_16x16x16_bf8_bf8::getOperationName(); - if (isa<Float8E5M2Type>(elemSourceType) && - isa<Float8E4M3FNType>(elemBSourceType) && elemDestType.isF32()) + if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType) && + elemDestType.isF32()) return ROCDL::wmma_f32_16x16x16_bf8_fp8::getOperationName(); if (elemSourceType.isInteger(4) && elemDestType.isInteger(32)) return ROCDL::wmma_i32_16x16x16_iu4::getOperationName(); return std::nullopt; } - if (k == 32) { + + // Handle k 
== 32 for RDNA4. + if (k == 32 && !isRDNA3) { if (elemSourceType.isInteger(4) && elemDestType.isInteger(32)) return ROCDL::wmma_i32_16x16x32_iu4::getOperationName(); + } + + llvm_unreachable("Unsupported k value"); +} + +/// Return the `rocdl` intrinsic corresponding to a WMMA operation `wmma` +/// for the gfx1250 architecture. +static std::optional<StringRef> wmmaOpToIntrinsicGfx1250(Type elemSourceType, + Type elemBSourceType, + Type elemDestType, + uint32_t k) { + using fp8 = Float8E4M3FNType; + using bf8 = Float8E5M2Type; + + if (k == 4) { + if (elemSourceType.isF32() && elemDestType.isF32()) + return ROCDL::wmma_f32_16x16x4_f32::getOperationName(); + return std::nullopt; } + if (k == 32) { + if (elemSourceType.isF16() && elemDestType.isF32()) + return ROCDL::wmma_f32_16x16x32_f16::getOperationName(); + if (elemSourceType.isBF16() && elemDestType.isF32()) + return ROCDL::wmma_f32_16x16x32_bf16::getOperationName(); + if (elemSourceType.isF16() && elemDestType.isF16()) + return ROCDL::wmma_f16_16x16x32_f16::getOperationName(); + if (elemSourceType.isBF16() && elemDestType.isBF16()) + return ROCDL::wmma_bf16_16x16x32_bf16::getOperationName(); + + return std::nullopt; + } + + if (k == 64) { + if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType)) { + if (elemDestType.isF32()) + return ROCDL::wmma_f32_16x16x64_fp8_fp8::getOperationName(); + if (elemDestType.isF16()) + return ROCDL::wmma_f16_16x16x64_fp8_fp8::getOperationName(); + } + if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType)) { + if (elemDestType.isF32()) + return ROCDL::wmma_f32_16x16x64_fp8_bf8::getOperationName(); + if (elemDestType.isF16()) + return ROCDL::wmma_f16_16x16x64_fp8_bf8::getOperationName(); + } + if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType)) { + if (elemDestType.isF32()) + return ROCDL::wmma_f32_16x16x64_bf8_bf8::getOperationName(); + if (elemDestType.isF16()) + return ROCDL::wmma_f16_16x16x64_bf8_bf8::getOperationName(); + } + if (isa<bf8>(elemSourceType) && 
isa<fp8>(elemBSourceType)) { + if (elemDestType.isF32()) + return ROCDL::wmma_f32_16x16x64_bf8_fp8::getOperationName(); + if (elemDestType.isF16()) + return ROCDL::wmma_f16_16x16x64_bf8_fp8::getOperationName(); + } + if (elemSourceType.isInteger(8) && elemDestType.isInteger(32)) + return ROCDL::wmma_i32_16x16x64_iu8::getOperationName(); + + return std::nullopt; + } + + if (k == 128) { + if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType)) { + if (elemDestType.isF32()) + return ROCDL::wmma_f32_16x16x128_fp8_fp8::getOperationName(); + if (elemDestType.isF16()) + return ROCDL::wmma_f16_16x16x128_fp8_fp8::getOperationName(); + } + if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType)) { + if (elemDestType.isF32()) + return ROCDL::wmma_f32_16x16x128_fp8_bf8::getOperationName(); + if (elemDestType.isF16()) + return ROCDL::wmma_f16_16x16x128_fp8_bf8::getOperationName(); + } + if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType)) { + if (elemDestType.isF32()) + return ROCDL::wmma_f32_16x16x128_bf8_bf8::getOperationName(); + if (elemDestType.isF16()) + return ROCDL::wmma_f16_16x16x128_bf8_bf8::getOperationName(); + } + if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType)) { + if (elemDestType.isF32()) + return ROCDL::wmma_f32_16x16x128_bf8_fp8::getOperationName(); + if (elemDestType.isF16()) + return ROCDL::wmma_f16_16x16x128_bf8_fp8::getOperationName(); + } + + return std::nullopt; + } + + llvm_unreachable("Unsupported k value"); +} + +/// Returns the `rocdl` intrinsic corresponding to a WMMA operation `wmma` +/// if one exists. This includes checking to ensure the intrinsic is supported +/// on the architecture you are compiling for. 
+static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma, + Chipset chipset) { + auto sourceVectorType = cast<VectorType>(wmma.getSourceA().getType()); + auto sourceBVectorType = cast<VectorType>(wmma.getSourceB().getType()); + auto destVectorType = cast<VectorType>(wmma.getDestC().getType()); + Type elemSourceType = sourceVectorType.getElementType(); + Type elemBSourceType = sourceBVectorType.getElementType(); + Type elemDestType = destVectorType.getElementType(); + + const uint32_t k = wmma.getK(); + const bool isRDNA3 = chipset.majorVersion == 11; + const bool isRDNA4 = chipset.majorVersion == 12 && chipset.minorVersion == 0; + + // Handle RDNA3 and RDNA4. + if (isRDNA3 || isRDNA4) + return wmmaOpToIntrinsicRDNA(elemSourceType, elemBSourceType, elemDestType, + k, isRDNA3); + + // Handle gfx1250. + if (chipset == Chipset{12, 5, 0}) + return wmmaOpToIntrinsicGfx1250(elemSourceType, elemBSourceType, + elemDestType, k); + llvm_unreachable("unhandled WMMA case"); } diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp index 0fe7239..9e46b7d 100644 --- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp +++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp @@ -313,25 +313,53 @@ private: struct ExpOpConversion : public OpConversionPattern<complex::ExpOp> { using OpConversionPattern<complex::ExpOp>::OpConversionPattern; + // exp(x+I*y) = exp(x)*(cos(y)+I*sin(y)) + // Handle special cases as StableHLO implementation does: + // 1. When y == 0, set imag(exp(z)) = 0 + // 2.
When exp(x) == inf, use exp(x/2)*(cos(y)+I*sin(y))*exp(x/2) LogicalResult matchAndRewrite(complex::ExpOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); auto type = cast<ComplexType>(adaptor.getComplex().getType()); - auto elementType = cast<FloatType>(type.getElementType()); - arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr(); - - Value real = - complex::ReOp::create(rewriter, loc, elementType, adaptor.getComplex()); - Value imag = - complex::ImOp::create(rewriter, loc, elementType, adaptor.getComplex()); - Value expReal = math::ExpOp::create(rewriter, loc, real, fmf.getValue()); - Value cosImag = math::CosOp::create(rewriter, loc, imag, fmf.getValue()); + auto ET = cast<FloatType>(type.getElementType()); + arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue(); + const auto &floatSemantics = ET.getFloatSemantics(); + ImplicitLocOpBuilder b(loc, rewriter); + + Value x = complex::ReOp::create(b, ET, adaptor.getComplex()); + Value y = complex::ImOp::create(b, ET, adaptor.getComplex()); + Value zero = arith::ConstantOp::create(b, ET, b.getZeroAttr(ET)); + Value half = arith::ConstantOp::create(b, ET, b.getFloatAttr(ET, 0.5)); + Value inf = arith::ConstantOp::create( + b, ET, b.getFloatAttr(ET, APFloat::getInf(floatSemantics))); + + Value exp = math::ExpOp::create(b, x, fmf); + Value xHalf = arith::MulFOp::create(b, x, half, fmf); + Value expHalf = math::ExpOp::create(b, xHalf, fmf); + Value cos = math::CosOp::create(b, y, fmf); + Value sin = math::SinOp::create(b, y, fmf); + + Value expIsInf = + arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, exp, inf, fmf); + Value yIsZero = + arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, y, zero); + + // Real path: select between exp(x)*cos(y) and exp(x/2)*cos(y)*exp(x/2) + Value realNormal = arith::MulFOp::create(b, exp, cos, fmf); + Value expHalfCos = arith::MulFOp::create(b, expHalf, cos, fmf); + Value realOverflow = arith::MulFOp::create(b, 
expHalfCos, expHalf, fmf); Value resultReal = - arith::MulFOp::create(rewriter, loc, expReal, cosImag, fmf.getValue()); - Value sinImag = math::SinOp::create(rewriter, loc, imag, fmf.getValue()); - Value resultImag = - arith::MulFOp::create(rewriter, loc, expReal, sinImag, fmf.getValue()); + arith::SelectOp::create(b, expIsInf, realOverflow, realNormal); + + // Imaginary part: if y == 0 return 0 else select between exp(x)*sin(y) and + // exp(x/2)*sin(y)*exp(x/2) + Value imagNormal = arith::MulFOp::create(b, exp, sin, fmf); + Value expHalfSin = arith::MulFOp::create(b, expHalf, sin, fmf); + Value imagOverflow = arith::MulFOp::create(b, expHalfSin, expHalf, fmf); + Value imagNonZero = + arith::SelectOp::create(b, expIsInf, imagOverflow, imagNormal); + Value resultImag = arith::SelectOp::create(b, yIsZero, zero, imagNonZero); rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal, resultImag); diff --git a/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp b/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp index b711e33..a4c66e1 100644 --- a/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp +++ b/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp @@ -692,7 +692,7 @@ SymbolRefAttr PatternLowering::generateRewriter( llvm::map_range(rewriter.getExternalArgs(), mapRewriteValue); args.append(mappedArgs.begin(), mappedArgs.end()); pdl_interp::ApplyRewriteOp::create(builder, rewriter.getLoc(), - /*resultTypes=*/TypeRange(), rewriteName, + /*results=*/TypeRange(), rewriteName, args); } else { // Otherwise this is a dag rewriter defined using PDL operations. 
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp index 585b6da..df955fc 100644 --- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp +++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp @@ -399,13 +399,15 @@ LogicalResult WMMAOp::verify() { if (!sourceAElemType.isFloat(8) && sourceAElemType != sourceBElemType) { return emitOpError( - "source element types much match (except for fp8) but have ") + "source element types must match (except for fp8/bf8) but have ") << sourceAType << " and " << sourceBType; } - if (!sourceAElemType.isInteger(4) && getK() != 16) { - return emitOpError("K dimension must be 16 for source element type ") - << sourceAElemType; + if (isSrcFloat) { + if (getClamp()) + return emitOpError("clamp flag is not supported for float types"); + if (getUnsignedA() || getUnsignedB()) + return emitOpError("unsigned flags are not supported for float types"); } return success(); } diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp index e0a53cd..0c35921 100644 --- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp +++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp @@ -2716,8 +2716,9 @@ LogicalResult AffineForOp::fold(FoldAdaptor adaptor, return success(folded); } -OperandRange AffineForOp::getEntrySuccessorOperands(RegionBranchPoint point) { - assert((point.isParent() || point == getRegion()) && "invalid region point"); +OperandRange AffineForOp::getEntrySuccessorOperands(RegionSuccessor successor) { + assert((successor.isParent() || successor.getSuccessor() == &getRegion()) && + "invalid region point"); // The initial operands map to the loop arguments after the induction // variable or are forwarded to the results when the trip count is zero. 
@@ -2726,34 +2727,41 @@ OperandRange AffineForOp::getEntrySuccessorOperands(RegionBranchPoint point) { void AffineForOp::getSuccessorRegions( RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) { - assert((point.isParent() || point == getRegion()) && "expected loop region"); + assert((point.isParent() || + point.getTerminatorPredecessorOrNull()->getParentRegion() == + &getRegion()) && + "expected loop region"); // The loop may typically branch back to its body or to the parent operation. // If the predecessor is the parent op and the trip count is known to be at // least one, branch into the body using the iterator arguments. And in cases // we know the trip count is zero, it can only branch back to its parent. std::optional<uint64_t> tripCount = getTrivialConstantTripCount(*this); - if (point.isParent() && tripCount.has_value()) { - if (tripCount.value() > 0) { - regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs())); - return; - } - if (tripCount.value() == 0) { - regions.push_back(RegionSuccessor(getResults())); - return; + if (tripCount.has_value()) { + if (!point.isParent()) { + // From the loop body, if the trip count is one, we can only branch back + // to the parent. + if (tripCount == 1) { + regions.push_back(RegionSuccessor(getOperation(), getResults())); + return; + } + if (tripCount == 0) + return; + } else { + if (tripCount.value() > 0) { + regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs())); + return; + } + if (tripCount.value() == 0) { + regions.push_back(RegionSuccessor(getOperation(), getResults())); + return; + } } } - // From the loop body, if the trip count is one, we can only branch back to - // the parent. - if (!point.isParent() && tripCount == 1) { - regions.push_back(RegionSuccessor(getResults())); - return; - } - // In all other cases, the loop may branch back to itself or the parent // operation.
regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs())); - regions.push_back(RegionSuccessor(getResults())); + regions.push_back(RegionSuccessor(getOperation(), getResults())); } AffineBound AffineForOp::getLowerBound() { @@ -3142,7 +3150,7 @@ void AffineIfOp::getSuccessorRegions( RegionSuccessor(&getThenRegion(), getThenRegion().getArguments())); // If the "else" region is empty, branch bach into parent. if (getElseRegion().empty()) { - regions.push_back(getResults()); + regions.push_back(RegionSuccessor(getOperation(), getResults())); } else { regions.push_back( RegionSuccessor(&getElseRegion(), getElseRegion().getArguments())); @@ -3152,7 +3160,7 @@ void AffineIfOp::getSuccessorRegions( // If the predecessor is the `else`/`then` region, then branching into parent // op is valid. - regions.push_back(RegionSuccessor(getResults())); + regions.push_back(RegionSuccessor(getOperation(), getResults())); } LogicalResult AffineIfOp::verify() { diff --git a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp index d925c19..a651710 100644 --- a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp +++ b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp @@ -216,8 +216,8 @@ void splitCondBranches(IRRewriter &rewriter, FunctionOpInterface function) { for (auto condBranch : worklist) { auto loc = condBranch.getLoc(); Block *block = condBranch->getBlock(); - auto newTrueBranch = rewriter.splitBlock(block, block->end()); - auto newFalseBranch = rewriter.splitBlock(block, block->end()); + auto *newTrueBranch = rewriter.splitBlock(block, block->end()); + auto *newFalseBranch = rewriter.splitBlock(block, block->end()); insertJump(loc, newTrueBranch, condBranch.getTrueDest(), condBranch.getTrueDestOperands()); insertJump(loc, newFalseBranch, condBranch.getFalseDest(), @@ -382,7 +382,7 @@ gatherTileLiveRanges(DenseMap<Operation *, unsigned> const &operationToIndexMap, // Find or create a live range for 
`value`. auto [it, _] = liveRanges.try_emplace(value, liveRangeAllocator); LiveRange &valueLiveRange = it->second; - auto lastUseInBlock = livenessInfo.getEndOperation(value, firstUseOrDef); + auto *lastUseInBlock = livenessInfo.getEndOperation(value, firstUseOrDef); // Add the interval [firstUseOrDef, lastUseInBlock) to the live range. unsigned startOpIdx = operationToIndexMap.at(firstUseOrDef) + (liveAtBlockEntry ? -1 : 0); diff --git a/mlir/lib/Dialect/Async/IR/Async.cpp b/mlir/lib/Dialect/Async/IR/Async.cpp index dc7b07d..8e4a49d 100644 --- a/mlir/lib/Dialect/Async/IR/Async.cpp +++ b/mlir/lib/Dialect/Async/IR/Async.cpp @@ -36,8 +36,9 @@ void AsyncDialect::initialize() { constexpr char kOperandSegmentSizesAttr[] = "operandSegmentSizes"; -OperandRange ExecuteOp::getEntrySuccessorOperands(RegionBranchPoint point) { - assert(point == getBodyRegion() && "invalid region index"); +OperandRange ExecuteOp::getEntrySuccessorOperands(RegionSuccessor successor) { + assert(successor.getSuccessor() == &getBodyRegion() && + "invalid region index"); return getBodyOperands(); } @@ -53,8 +54,10 @@ bool ExecuteOp::areTypesCompatible(Type lhs, Type rhs) { void ExecuteOp::getSuccessorRegions(RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) { // The `body` region branch back to the parent operation.
- if (point == getBodyRegion()) { - regions.push_back(RegionSuccessor(getBodyResults())); + if (!point.isParent() && + point.getTerminatorPredecessorOrNull()->getParentRegion() == + &getBodyRegion()) { + regions.push_back(RegionSuccessor(getOperation(), getBodyResults())); return; } diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp index b593cca..36a759c 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp @@ -562,8 +562,11 @@ LogicalResult BufferDeallocation::updateFunctionSignature(FunctionOpInterface op) { SmallVector<TypeRange> returnOperandTypes(llvm::map_range( op.getFunctionBody().getOps<RegionBranchTerminatorOpInterface>(), - [](RegionBranchTerminatorOpInterface op) { - return op.getSuccessorOperands(RegionBranchPoint::parent()).getTypes(); + [&](RegionBranchTerminatorOpInterface branchOp) { + return branchOp + .getSuccessorOperands(RegionSuccessor( + op.getOperation(), op.getOperation()->getResults())) + .getTypes(); })); if (!llvm::all_equal(returnOperandTypes)) return op->emitError( @@ -942,8 +945,8 @@ BufferDeallocation::handleInterface(RegionBranchTerminatorOpInterface op) { // about, but we would need to check how many successors there are and under // which condition they are taken, etc. 
- MutableOperandRange operands = - op.getMutableSuccessorOperands(RegionBranchPoint::parent()); + MutableOperandRange operands = op.getMutableSuccessorOperands( + RegionSuccessor(op.getOperation(), op.getOperation()->getResults())); SmallVector<Value> updatedOwnerships; auto result = deallocation_impl::insertDeallocOpForReturnLike( diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp index 4754f0b..0992ce14 100644 --- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp +++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp @@ -845,7 +845,8 @@ void IfOp::getSuccessorRegions(RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) { // The `then` and the `else` region branch back to the parent operation. if (!point.isParent()) { - regions.push_back(RegionSuccessor()); + regions.push_back( + RegionSuccessor(getOperation(), getOperation()->getResults())); return; } @@ -854,7 +855,8 @@ void IfOp::getSuccessorRegions(RegionBranchPoint point, // Don't consider the else region if it is empty.
Region *elseRegion = &this->getElseRegion(); if (elseRegion->empty()) - regions.push_back(RegionSuccessor()); + regions.push_back( + RegionSuccessor(getOperation(), getOperation()->getResults())); else regions.push_back(RegionSuccessor(elseRegion)); } @@ -871,7 +873,7 @@ void IfOp::getEntrySuccessorRegions(ArrayRef<Attribute> operands, if (!getElseRegion().empty()) regions.emplace_back(&getElseRegion()); else - regions.emplace_back(); + regions.emplace_back(getOperation(), getOperation()->getResults()); } } diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp index b5f8dda..6c6d8d2 100644 --- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp +++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp @@ -2399,7 +2399,7 @@ ParseResult WarpExecuteOnLane0Op::parse(OpAsmParser &parser, void WarpExecuteOnLane0Op::getSuccessorRegions( RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) { if (!point.isParent()) { - regions.push_back(RegionSuccessor(getResults())); + regions.push_back(RegionSuccessor(getOperation(), getResults())); return; } diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp index eb2d825..bd25e94 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -495,13 +495,14 @@ FailureOr<PackResult> linalg::pack(RewriterBase &rewriter, if (failed(maybePackedDimForEachOperand)) return failure(); packedOperandsDims.packedDimForEachOperand = *maybePackedDimForEachOperand; - listOfPackedOperandsDim.pushBack(std::move(packedOperandsDims)); LDBG() << "++++ After pack size #" << i << ": " << packedSizes[i]; LDBG() << "maps: " << llvm::interleaved(indexingMaps); LDBG() << "iterators: " << llvm::interleaved(iteratorTypes); LDBG() << "packedDimForEachOperand: " << llvm::interleaved(packedOperandsDims.packedDimForEachOperand); + + listOfPackedOperandsDim.pushBack(std::move(packedOperandsDims)); } // Step 2.
Propagate packing to all LinalgOp operands. diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp index c551fba..1c21a2f 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -405,7 +405,7 @@ ParseResult AllocaScopeOp::parse(OpAsmParser &parser, OperationState &result) { void AllocaScopeOp::getSuccessorRegions( RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) { if (!point.isParent()) { - regions.push_back(RegionSuccessor(getResults())); + regions.push_back(RegionSuccessor(getOperation(), getResults())); return; } diff --git a/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp index 6fa8ce4..69afbca 100644 --- a/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp @@ -98,6 +98,27 @@ struct RankOpInterface } }; +struct CollapseShapeOpInterface + : public ValueBoundsOpInterface::ExternalModel<CollapseShapeOpInterface, + memref::CollapseShapeOp> { + void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim, + ValueBoundsConstraintSet &cstr) const { + auto collapseOp = cast<memref::CollapseShapeOp>(op); + assert(value == collapseOp.getResult() && "invalid value"); + + // Multiply the expressions for the dimensions in the reassociation group.
+ const ReassociationIndices reassocIndices = + collapseOp.getReassociationIndices()[dim]; + AffineExpr productExpr = + cstr.getExpr(collapseOp.getSrc(), reassocIndices[0]); + for (size_t i = 1; i < reassocIndices.size(); ++i) { + productExpr = + productExpr * cstr.getExpr(collapseOp.getSrc(), reassocIndices[i]); + } + cstr.bound(value)[dim] == productExpr; + } +}; + struct SubViewOpInterface : public ValueBoundsOpInterface::ExternalModel<SubViewOpInterface, SubViewOp> { @@ -134,6 +155,8 @@ void mlir::memref::registerValueBoundsOpInterfaceExternalModels( memref::AllocOpInterface<memref::AllocaOp>>(*ctx); memref::CastOp::attachInterface<memref::CastOpInterface>(*ctx); memref::DimOp::attachInterface<memref::DimOpInterface>(*ctx); + memref::CollapseShapeOp::attachInterface<memref::CollapseShapeOpInterface>( + *ctx); memref::ExpandShapeOp::attachInterface<memref::ExpandShapeOpInterface>( *ctx); memref::GetGlobalOp::attachInterface<memref::GetGlobalOpInterface>(*ctx); diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp index 1ab01d8..2946b53 100644 --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -397,7 +397,7 @@ void ExecuteRegionOp::getSuccessorRegions( } // Otherwise, the region branches back to the parent operation. 
- regions.push_back(RegionSuccessor(getResults())); + regions.push_back(RegionSuccessor(getOperation(), getResults())); } //===----------------------------------------------------------------------===// @@ -405,10 +405,11 @@ void ExecuteRegionOp::getSuccessorRegions( //===----------------------------------------------------------------------===// MutableOperandRange -ConditionOp::getMutableSuccessorOperands(RegionBranchPoint point) { - assert((point.isParent() || point == getParentOp().getAfter()) && - "condition op can only exit the loop or branch to the after" - "region"); +ConditionOp::getMutableSuccessorOperands(RegionSuccessor point) { + assert( + (point.isParent() || point.getSuccessor() == &getParentOp().getAfter()) && + "condition op can only exit the loop or branch to the after" + "region"); // Pass all operands except the condition to the successor region. return getArgsMutable(); } @@ -426,7 +427,7 @@ void ConditionOp::getSuccessorRegions( regions.emplace_back(&whileOp.getAfter(), whileOp.getAfter().getArguments()); if (!boolAttr || !boolAttr.getValue()) - regions.emplace_back(whileOp.getResults()); + regions.emplace_back(whileOp.getOperation(), whileOp.getResults()); } //===----------------------------------------------------------------------===// @@ -749,7 +750,7 @@ ForOp mlir::scf::getForInductionVarOwner(Value val) { return dyn_cast_or_null<ForOp>(containingOp); } -OperandRange ForOp::getEntrySuccessorOperands(RegionBranchPoint point) { +OperandRange ForOp::getEntrySuccessorOperands(RegionSuccessor successor) { return getInitArgs(); } @@ -759,7 +760,7 @@ void ForOp::getSuccessorRegions(RegionBranchPoint point, // back into the operation itself. It is possible for loop not to enter the // body. 
regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs())); - regions.push_back(RegionSuccessor(getResults())); + regions.push_back(RegionSuccessor(getOperation(), getResults())); } SmallVector<Region *> ForallOp::getLoopRegions() { return {&getRegion()}; } @@ -2053,9 +2054,10 @@ void ForallOp::getSuccessorRegions(RegionBranchPoint point, // parallel by multiple threads. We should not expect to branch back into // the forall body after the region's execution is complete. if (point.isParent()) - regions.push_back(RegionSuccessor(&getRegion())); + regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs())); else - regions.push_back(RegionSuccessor()); + regions.push_back( + RegionSuccessor(getOperation(), getOperation()->getResults())); } //===----------------------------------------------------------------------===// @@ -2333,9 +2335,10 @@ void IfOp::print(OpAsmPrinter &p) { void IfOp::getSuccessorRegions(RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) { - // The `then` and the `else` region branch back to the parent operation. + // The `then` and the `else` region branch back to the parent operation or one + // of the recursive parent operations (early exit case). if (!point.isParent()) { - regions.push_back(RegionSuccessor(getResults())); + regions.push_back(RegionSuccessor(getOperation(), getResults())); return; } @@ -2344,7 +2347,8 @@ void IfOp::getSuccessorRegions(RegionBranchPoint point, // Don't consider the else region if it is empty.
Region *elseRegion = &this->getElseRegion(); if (elseRegion->empty()) - regions.push_back(RegionSuccessor()); + regions.push_back( + RegionSuccessor(getOperation(), getOperation()->getResults())); else regions.push_back(RegionSuccessor(elseRegion)); } @@ -2361,7 +2365,7 @@ void IfOp::getEntrySuccessorRegions(ArrayRef<Attribute> operands, if (!getElseRegion().empty()) regions.emplace_back(&getElseRegion()); else - regions.emplace_back(getResults()); + regions.emplace_back(getOperation(), getResults()); } } @@ -3385,7 +3389,8 @@ void ParallelOp::getSuccessorRegions( // back into the operation itself. It is possible for loop not to enter the // body. regions.push_back(RegionSuccessor(&getRegion())); - regions.push_back(RegionSuccessor()); + regions.push_back(RegionSuccessor( + getOperation(), ResultRange{getResults().end(), getResults().end()})); } //===----------------------------------------------------------------------===// @@ -3431,7 +3436,7 @@ LogicalResult ReduceOp::verifyRegions() { } MutableOperandRange -ReduceOp::getMutableSuccessorOperands(RegionBranchPoint point) { +ReduceOp::getMutableSuccessorOperands(RegionSuccessor point) { // No operands are forwarded to the next iteration. 
return MutableOperandRange(getOperation(), /*start=*/0, /*length=*/0); } @@ -3514,8 +3519,8 @@ Block::BlockArgListType WhileOp::getRegionIterArgs() { return getBeforeArguments(); } -OperandRange WhileOp::getEntrySuccessorOperands(RegionBranchPoint point) { - assert(point == getBefore() && +OperandRange WhileOp::getEntrySuccessorOperands(RegionSuccessor successor) { + assert(successor.getSuccessor() == &getBefore() && "WhileOp is expected to branch only to the first region"); return getInits(); } @@ -3528,15 +3533,18 @@ void WhileOp::getSuccessorRegions(RegionBranchPoint point, return; } - assert(llvm::is_contained({&getAfter(), &getBefore()}, point) && + assert(llvm::is_contained( + {&getAfter(), &getBefore()}, + point.getTerminatorPredecessorOrNull()->getParentRegion()) && "there are only two regions in a WhileOp"); // The body region always branches back to the condition region. - if (point == getAfter()) { + if (point.getTerminatorPredecessorOrNull()->getParentRegion() == + &getAfter()) { regions.emplace_back(&getBefore(), getBefore().getArguments()); return; } - regions.emplace_back(getResults()); + regions.emplace_back(getOperation(), getResults()); regions.emplace_back(&getAfter(), getAfter().getArguments()); } @@ -4445,7 +4453,7 @@ void IndexSwitchOp::getSuccessorRegions( RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &successors) { // All regions branch back to the parent op. 
if (!point.isParent()) { - successors.emplace_back(getResults()); + successors.emplace_back(getOperation(), getResults()); return; } diff --git a/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp b/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp index ae52af5..ddcbda8 100644 --- a/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp @@ -23,7 +23,6 @@ namespace mlir { #include "mlir/Dialect/SCF/Transforms/Passes.h.inc" } // namespace mlir -using namespace llvm; using namespace mlir; using scf::ForOp; using scf::WhileOp; diff --git a/mlir/lib/Dialect/SCF/Transforms/ForallToFor.cpp b/mlir/lib/Dialect/SCF/Transforms/ForallToFor.cpp index a2f03f1..00bef70 100644 --- a/mlir/lib/Dialect/SCF/Transforms/ForallToFor.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/ForallToFor.cpp @@ -21,7 +21,6 @@ namespace mlir { #include "mlir/Dialect/SCF/Transforms/Passes.h.inc" } // namespace mlir -using namespace llvm; using namespace mlir; using scf::LoopNest; diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp index 5ba8289..f0f22e5 100644 --- a/mlir/lib/Dialect/Shape/IR/Shape.cpp +++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp @@ -346,7 +346,7 @@ void AssumingOp::getSuccessorRegions( // parent, so return the correct RegionSuccessor purely based on the index // being None or 0. 
if (!point.isParent()) { - regions.push_back(RegionSuccessor(getResults())); + regions.push_back(RegionSuccessor(getOperation(), getResults())); return; } diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp index 1a9d9e1..3962e3e 100644 --- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp +++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp @@ -2597,7 +2597,7 @@ std::optional<MutableArrayRef<OpOperand>> IterateOp::getYieldedValuesMutable() { std::optional<ResultRange> IterateOp::getLoopResults() { return getResults(); } -OperandRange IterateOp::getEntrySuccessorOperands(RegionBranchPoint point) { +OperandRange IterateOp::getEntrySuccessorOperands(RegionSuccessor successor) { return getInitArgs(); } @@ -2607,7 +2607,7 @@ void IterateOp::getSuccessorRegions(RegionBranchPoint point, // or back into the operation itself. regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs())); // It is possible for loop not to enter the body. 
- regions.push_back(RegionSuccessor(getResults())); + regions.push_back(RegionSuccessor(getOperation(), getResults())); } void CoIterateOp::build(OpBuilder &builder, OperationState &odsState, diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp index f53d272..ffa8b40 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp @@ -152,19 +152,20 @@ IterationGraphSorter IterationGraphSorter::fromGenericOp( } IterationGraphSorter::IterationGraphSorter( - SmallVector<Value> &&ins, SmallVector<AffineMap> &&loop2InsLvl, Value out, - AffineMap loop2OutLvl, SmallVector<utils::IteratorType> &&iterTypes, + SmallVector<Value> &&insArg, SmallVector<AffineMap> &&loop2InsLvlArg, + Value out, AffineMap loop2OutLvl, + SmallVector<utils::IteratorType> &&iterTypesArg, sparse_tensor::LoopOrderingStrategy strategy) - : ins(std::move(ins)), loop2InsLvl(std::move(loop2InsLvl)), out(out), - loop2OutLvl(loop2OutLvl), iterTypes(std::move(iterTypes)), + : ins(std::move(insArg)), loop2InsLvl(std::move(loop2InsLvlArg)), out(out), + loop2OutLvl(loop2OutLvl), iterTypes(std::move(iterTypesArg)), strategy(strategy) { // One map per tensor. - assert(this->loop2InsLvl.size() == this->ins.size()); + assert(loop2InsLvl.size() == ins.size()); // All the affine maps have the same number of dimensions (loops). assert(llvm::all_equal(llvm::map_range( - this->loop2InsLvl, [](AffineMap m) { return m.getNumDims(); }))); + loop2InsLvl, [](AffineMap m) { return m.getNumDims(); }))); // The number of results of the map should match the rank of the tensor. - assert(llvm::all_of(llvm::zip(this->loop2InsLvl, this->ins), [](auto mvPair) { + assert(llvm::all_of(llvm::zip(loop2InsLvl, ins), [](auto mvPair) { auto [m, v] = mvPair; // For ranked types the rank must match. 
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.h index b2a16e9..35e58ed 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.h +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.h @@ -59,10 +59,10 @@ public: private: // Private constructor. - IterationGraphSorter(SmallVector<Value> &&ins, - SmallVector<AffineMap> &&loop2InsLvl, Value out, + IterationGraphSorter(SmallVector<Value> &&insArg, + SmallVector<AffineMap> &&loop2InsLvlArg, Value out, AffineMap loop2OutLvl, - SmallVector<utils::IteratorType> &&iterTypes, + SmallVector<utils::IteratorType> &&iterTypesArg, sparse_tensor::LoopOrderingStrategy strategy = sparse_tensor::LoopOrderingStrategy::kDefault); diff --git a/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp index 1e3b377..549ac7a 100644 --- a/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp @@ -77,7 +77,7 @@ FailureOr<TilingResult> tensor::replaceInsertSlicesWithTiledConsumer( dyn_cast<TilingInterface>(consumerOperands.front()->getOwner()); if (!consumerOp) return failure(); - for (auto opOperand : consumerOperands.drop_front()) { + for (auto *opOperand : consumerOperands.drop_front()) { if (opOperand->getOwner() != consumerOp) { LLVM_DEBUG({ llvm::dbgs() diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp index 365afab..062606e 100644 --- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp @@ -96,9 +96,9 @@ ensurePayloadIsSeparateFromTransform(transform::TransformOpInterface transform, // AlternativesOp //===----------------------------------------------------------------------===// 
-OperandRange -transform::AlternativesOp::getEntrySuccessorOperands(RegionBranchPoint point) { - if (!point.isParent() && getOperation()->getNumOperands() == 1) +OperandRange transform::AlternativesOp::getEntrySuccessorOperands( + RegionSuccessor successor) { + if (!successor.isParent() && getOperation()->getNumOperands() == 1) return getOperation()->getOperands(); return OperandRange(getOperation()->operand_end(), getOperation()->operand_end()); @@ -107,15 +107,18 @@ transform::AlternativesOp::getEntrySuccessorOperands(RegionBranchPoint point) { void transform::AlternativesOp::getSuccessorRegions( RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) { for (Region &alternative : llvm::drop_begin( - getAlternatives(), - point.isParent() ? 0 - : point.getRegionOrNull()->getRegionNumber() + 1)) { + getAlternatives(), point.isParent() + ? 0 + : point.getTerminatorPredecessorOrNull() + ->getParentRegion() + ->getRegionNumber() + + 1)) { regions.emplace_back(&alternative, !getOperands().empty() ? alternative.getArguments() : Block::BlockArgListType()); } if (!point.isParent()) - regions.emplace_back(getOperation()->getResults()); + regions.emplace_back(getOperation(), getOperation()->getResults()); } void transform::AlternativesOp::getRegionInvocationBounds( @@ -1740,16 +1743,18 @@ void transform::ForeachOp::getSuccessorRegions( } // Branch back to the region or the parent. 
- assert(point == getBody() && "unexpected region index"); + assert(point.getTerminatorPredecessorOrNull()->getParentRegion() == + &getBody() && + "unexpected region index"); regions.emplace_back(bodyRegion, bodyRegion->getArguments()); - regions.emplace_back(); + regions.emplace_back(getOperation(), getOperation()->getResults()); } OperandRange -transform::ForeachOp::getEntrySuccessorOperands(RegionBranchPoint point) { +transform::ForeachOp::getEntrySuccessorOperands(RegionSuccessor successor) { // Each block argument handle is mapped to a subset (one op to be precise) // of the payload of the corresponding `targets` operand of ForeachOp. - assert(point == getBody() && "unexpected region index"); + assert(successor.getSuccessor() == &getBody() && "unexpected region index"); return getOperation()->getOperands(); } @@ -2948,8 +2953,8 @@ void transform::SequenceOp::getEffects( } OperandRange -transform::SequenceOp::getEntrySuccessorOperands(RegionBranchPoint point) { - assert(point == getBody() && "unexpected region index"); +transform::SequenceOp::getEntrySuccessorOperands(RegionSuccessor successor) { + assert(successor.getSuccessor() == &getBody() && "unexpected region index"); if (getOperation()->getNumOperands() > 0) return getOperation()->getOperands(); return OperandRange(getOperation()->operand_end(), @@ -2966,8 +2971,10 @@ void transform::SequenceOp::getSuccessorRegions( return; } - assert(point == getBody() && "unexpected region index"); - regions.emplace_back(getOperation()->getResults()); + assert(point.getTerminatorPredecessorOrNull()->getParentRegion() == + &getBody() && + "unexpected region index"); + regions.emplace_back(getOperation(), getOperation()->getResults()); } void transform::SequenceOp::getRegionInvocationBounds( diff --git a/mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp b/mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp index c627158..f727118 100644 --- 
a/mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp +++ b/mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp @@ -9,6 +9,7 @@ #include "mlir/Dialect/Transform/IR/TransformOps.h" #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" #include "mlir/IR/OpImplementation.h" +#include "mlir/Interfaces/ControlFlowInterfaces.h" #include "llvm/Support/Debug.h" #include "mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.h" @@ -112,7 +113,7 @@ static void printAlternativesOpSelectedRegion(OpAsmPrinter &printer, } OperandRange transform::tune::AlternativesOp::getEntrySuccessorOperands( - RegionBranchPoint point) { + RegionSuccessor successor) { // No operands will be forwarded to the region(s). return getOperands().slice(0, 0); } @@ -128,7 +129,7 @@ void transform::tune::AlternativesOp::getSuccessorRegions( for (Region &alternative : getAlternatives()) regions.emplace_back(&alternative, Block::BlockArgListType()); else - regions.emplace_back(getOperation()->getResults()); + regions.emplace_back(getOperation(), getOperation()->getResults()); } void transform::tune::AlternativesOp::getRegionInvocationBounds( diff --git a/mlir/lib/IR/Diagnostics.cpp b/mlir/lib/IR/Diagnostics.cpp index 776b5c6..f4c9242 100644 --- a/mlir/lib/IR/Diagnostics.cpp +++ b/mlir/lib/IR/Diagnostics.cpp @@ -138,6 +138,10 @@ Diagnostic &Diagnostic::operator<<(Operation &op) { return appendOp(op, OpPrintingFlags()); } +Diagnostic &Diagnostic::operator<<(OpWithFlags op) { + return appendOp(*op.getOperation(), op.flags()); +} + Diagnostic &Diagnostic::appendOp(Operation &op, const OpPrintingFlags &flags) { std::string str; llvm::raw_string_ostream os(str); diff --git a/mlir/lib/IR/Region.cpp b/mlir/lib/IR/Region.cpp index 46b6298..15a941f 100644 --- a/mlir/lib/IR/Region.cpp +++ b/mlir/lib/IR/Region.cpp @@ -253,6 +253,21 @@ void Region::OpIterator::skipOverBlocksWithNoOps() { operation = block->begin(); } +llvm::raw_ostream &mlir::operator<<(llvm::raw_ostream &os, Region 
&region) { + if (!region.getParentOp()) { + os << "Region has no parent op"; + } else { + os << "Region #" << region.getRegionNumber() << " in operation " + << region.getParentOp()->getName(); + } + for (auto it : llvm::enumerate(region.getBlocks())) { + os << "\n Block #" << it.index() << ":"; + for (Operation &op : it.value().getOperations()) + os << "\n " << OpWithFlags(&op, OpPrintingFlags().skipRegions()); + } + return os; +} + //===----------------------------------------------------------------------===// // RegionRange //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp index ca3f766..1e56810 100644 --- a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp +++ b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp @@ -9,7 +9,9 @@ #include <utility> #include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Operation.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" +#include "llvm/Support/DebugLog.h" using namespace mlir; @@ -38,20 +40,31 @@ SuccessorOperands::SuccessorOperands(unsigned int producedOperandCount, std::optional<BlockArgument> detail::getBranchSuccessorArgument(const SuccessorOperands &operands, unsigned operandIndex, Block *successor) { + LDBG() << "Getting branch successor argument for operand index " + << operandIndex << " in successor block"; + OperandRange forwardedOperands = operands.getForwardedOperands(); // Check that the operands are valid. - if (forwardedOperands.empty()) + if (forwardedOperands.empty()) { + LDBG() << "No forwarded operands, returning nullopt"; return std::nullopt; + } // Check to ensure that this operand is within the range. 
unsigned operandsStart = forwardedOperands.getBeginOperandIndex(); if (operandIndex < operandsStart || - operandIndex >= (operandsStart + forwardedOperands.size())) + operandIndex >= (operandsStart + forwardedOperands.size())) { + LDBG() << "Operand index " << operandIndex << " out of range [" + << operandsStart << ", " + << (operandsStart + forwardedOperands.size()) + << "), returning nullopt"; return std::nullopt; + } // Index the successor. unsigned argIndex = operands.getProducedOperandCount() + operandIndex - operandsStart; + LDBG() << "Computed argument index " << argIndex << " for successor block"; return successor->getArgument(argIndex); } @@ -59,9 +72,15 @@ detail::getBranchSuccessorArgument(const SuccessorOperands &operands, LogicalResult detail::verifyBranchSuccessorOperands(Operation *op, unsigned succNo, const SuccessorOperands &operands) { + LDBG() << "Verifying branch successor operands for successor #" << succNo + << " in operation " << op->getName(); + // Check the count. unsigned operandCount = operands.size(); Block *destBB = op->getSuccessor(succNo); + LDBG() << "Branch has " << operandCount << " operands, target block has " + << destBB->getNumArguments() << " arguments"; + if (operandCount != destBB->getNumArguments()) return op->emitError() << "branch has " << operandCount << " operands for successor #" << succNo @@ -69,13 +88,22 @@ detail::verifyBranchSuccessorOperands(Operation *op, unsigned succNo, << destBB->getNumArguments(); // Check the types. 
+ LDBG() << "Checking type compatibility for " + << (operandCount - operands.getProducedOperandCount()) + << " forwarded operands"; for (unsigned i = operands.getProducedOperandCount(); i != operandCount; ++i) { - if (!cast<BranchOpInterface>(op).areTypesCompatible( - operands[i].getType(), destBB->getArgument(i).getType())) + Type operandType = operands[i].getType(); + Type argType = destBB->getArgument(i).getType(); + LDBG() << "Checking type compatibility: operand type " << operandType + << " vs argument type " << argType; + + if (!cast<BranchOpInterface>(op).areTypesCompatible(operandType, argType)) return op->emitError() << "type mismatch for bb argument #" << i << " of successor #" << succNo; } + + LDBG() << "Branch successor operand verification successful"; return success(); } @@ -126,15 +154,15 @@ LogicalResult detail::verifyRegionBranchWeights(Operation *op) { static InFlightDiagnostic &printRegionEdgeName(InFlightDiagnostic &diag, RegionBranchPoint sourceNo, - RegionBranchPoint succRegionNo) { + RegionSuccessor succRegionNo) { diag << "from "; - if (Region *region = sourceNo.getRegionOrNull()) - diag << "Region #" << region->getRegionNumber(); + if (Operation *op = sourceNo.getTerminatorPredecessorOrNull()) + diag << "Operation " << op->getName(); else diag << "parent operands"; diag << " to "; - if (Region *region = succRegionNo.getRegionOrNull()) + if (Region *region = succRegionNo.getSuccessor()) diag << "Region #" << region->getRegionNumber(); else diag << "parent results"; @@ -145,13 +173,12 @@ static InFlightDiagnostic &printRegionEdgeName(InFlightDiagnostic &diag, /// `sourcePoint`. `getInputsTypesForRegion` is a function that returns the /// types of the inputs that flow to a successor region. 
static LogicalResult -verifyTypesAlongAllEdges(Operation *op, RegionBranchPoint sourcePoint, - function_ref<FailureOr<TypeRange>(RegionBranchPoint)> +verifyTypesAlongAllEdges(RegionBranchOpInterface branchOp, + RegionBranchPoint sourcePoint, + function_ref<FailureOr<TypeRange>(RegionSuccessor)> getInputsTypesForRegion) { - auto regionInterface = cast<RegionBranchOpInterface>(op); - SmallVector<RegionSuccessor, 2> successors; - regionInterface.getSuccessorRegions(sourcePoint, successors); + branchOp.getSuccessorRegions(sourcePoint, successors); for (RegionSuccessor &succ : successors) { FailureOr<TypeRange> sourceTypes = getInputsTypesForRegion(succ); @@ -160,10 +187,14 @@ verifyTypesAlongAllEdges(Operation *op, RegionBranchPoint sourcePoint, TypeRange succInputsTypes = succ.getSuccessorInputs().getTypes(); if (sourceTypes->size() != succInputsTypes.size()) { - InFlightDiagnostic diag = op->emitOpError("region control flow edge "); + InFlightDiagnostic diag = + branchOp->emitOpError("region control flow edge "); + std::string succStr; + llvm::raw_string_ostream os(succStr); + os << succ; return printRegionEdgeName(diag, sourcePoint, succ) << ": source has " << sourceTypes->size() - << " operands, but target successor needs " + << " operands, but target successor " << os.str() << " needs " << succInputsTypes.size(); } @@ -171,8 +202,10 @@ verifyTypesAlongAllEdges(Operation *op, RegionBranchPoint sourcePoint, llvm::enumerate(llvm::zip(*sourceTypes, succInputsTypes))) { Type sourceType = std::get<0>(typesIdx.value()); Type inputType = std::get<1>(typesIdx.value()); - if (!regionInterface.areTypesCompatible(sourceType, inputType)) { - InFlightDiagnostic diag = op->emitOpError("along control flow edge "); + + if (!branchOp.areTypesCompatible(sourceType, inputType)) { + InFlightDiagnostic diag = + branchOp->emitOpError("along control flow edge "); return printRegionEdgeName(diag, sourcePoint, succ) << ": source type #" << typesIdx.index() << " " << sourceType << " should 
match input type #" << typesIdx.index() << " " @@ -180,6 +213,7 @@ verifyTypesAlongAllEdges(Operation *op, RegionBranchPoint sourcePoint, } } } + return success(); } @@ -187,34 +221,18 @@ verifyTypesAlongAllEdges(Operation *op, RegionBranchPoint sourcePoint, LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) { auto regionInterface = cast<RegionBranchOpInterface>(op); - auto inputTypesFromParent = [&](RegionBranchPoint point) -> TypeRange { - return regionInterface.getEntrySuccessorOperands(point).getTypes(); + auto inputTypesFromParent = [&](RegionSuccessor successor) -> TypeRange { + return regionInterface.getEntrySuccessorOperands(successor).getTypes(); }; // Verify types along control flow edges originating from the parent. - if (failed(verifyTypesAlongAllEdges(op, RegionBranchPoint::parent(), - inputTypesFromParent))) + if (failed(verifyTypesAlongAllEdges( + regionInterface, RegionBranchPoint::parent(), inputTypesFromParent))) return failure(); - auto areTypesCompatible = [&](TypeRange lhs, TypeRange rhs) { - if (lhs.size() != rhs.size()) - return false; - for (auto types : llvm::zip(lhs, rhs)) { - if (!regionInterface.areTypesCompatible(std::get<0>(types), - std::get<1>(types))) { - return false; - } - } - return true; - }; - // Verify types along control flow edges originating from each region. for (Region &region : op->getRegions()) { - - // Since there can be multiple terminators implementing the - // `RegionBranchTerminatorOpInterface`, all should have the same operand - // types when passing them to the same region. - + // Collect all return-like terminators in the region. 
SmallVector<RegionBranchTerminatorOpInterface> regionReturnOps; for (Block &block : region) if (!block.empty()) @@ -227,33 +245,20 @@ LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) { if (regionReturnOps.empty()) continue; - auto inputTypesForRegion = - [&](RegionBranchPoint point) -> FailureOr<TypeRange> { - std::optional<OperandRange> regionReturnOperands; - for (RegionBranchTerminatorOpInterface regionReturnOp : regionReturnOps) { - auto terminatorOperands = regionReturnOp.getSuccessorOperands(point); - - if (!regionReturnOperands) { - regionReturnOperands = terminatorOperands; - continue; - } - - // Found more than one ReturnLike terminator. Make sure the operand - // types match with the first one. - if (!areTypesCompatible(regionReturnOperands->getTypes(), - terminatorOperands.getTypes())) { - InFlightDiagnostic diag = op->emitOpError("along control flow edge"); - return printRegionEdgeName(diag, region, point) - << " operands mismatch between return-like terminators"; - } - } - - // All successors get the same set of operand types. - return TypeRange(regionReturnOperands->getTypes()); - }; - - if (failed(verifyTypesAlongAllEdges(op, region, inputTypesForRegion))) - return failure(); + // Verify types along control flow edges originating from each return-like + // terminator. 
+ for (RegionBranchTerminatorOpInterface regionReturnOp : regionReturnOps) { + + auto inputTypesForRegion = + [&](RegionSuccessor successor) -> FailureOr<TypeRange> { + OperandRange terminatorOperands = + regionReturnOp.getSuccessorOperands(successor); + return TypeRange(terminatorOperands.getTypes()); + }; + if (failed(verifyTypesAlongAllEdges(regionInterface, regionReturnOp, + inputTypesForRegion))) + return failure(); + } } return success(); @@ -272,31 +277,74 @@ using StopConditionFn = function_ref<bool(Region *, ArrayRef<bool> visited)>; static bool traverseRegionGraph(Region *begin, StopConditionFn stopConditionFn) { auto op = cast<RegionBranchOpInterface>(begin->getParentOp()); + LDBG() << "Starting region graph traversal from region #" + << begin->getRegionNumber() << " in operation " << op->getName(); + SmallVector<bool> visited(op->getNumRegions(), false); visited[begin->getRegionNumber()] = true; + LDBG() << "Initialized visited array with " << op->getNumRegions() + << " regions"; // Retrieve all successors of the region and enqueue them in the worklist. 
SmallVector<Region *> worklist; auto enqueueAllSuccessors = [&](Region *region) { - SmallVector<RegionSuccessor> successors; - op.getSuccessorRegions(region, successors); - for (RegionSuccessor successor : successors) - if (!successor.isParent()) - worklist.push_back(successor.getSuccessor()); + LDBG() << "Enqueuing successors for region #" << region->getRegionNumber(); + SmallVector<Attribute> operandAttributes(op->getNumOperands()); + for (Block &block : *region) { + if (block.empty()) + continue; + auto terminator = + dyn_cast<RegionBranchTerminatorOpInterface>(block.back()); + if (!terminator) + continue; + SmallVector<RegionSuccessor> successors; + operandAttributes.resize(terminator->getNumOperands()); + terminator.getSuccessorRegions(operandAttributes, successors); + LDBG() << "Found " << successors.size() + << " successors from terminator in block"; + for (RegionSuccessor successor : successors) { + if (!successor.isParent()) { + worklist.push_back(successor.getSuccessor()); + LDBG() << "Added region #" + << successor.getSuccessor()->getRegionNumber() + << " to worklist"; + } else { + LDBG() << "Skipping parent successor"; + } + } + } }; enqueueAllSuccessors(begin); + LDBG() << "Initial worklist size: " << worklist.size(); // Process all regions in the worklist via DFS. 
while (!worklist.empty()) { Region *nextRegion = worklist.pop_back_val(); - if (stopConditionFn(nextRegion, visited)) + LDBG() << "Processing region #" << nextRegion->getRegionNumber() + << " from worklist (remaining: " << worklist.size() << ")"; + + if (stopConditionFn(nextRegion, visited)) { + LDBG() << "Stop condition met for region #" + << nextRegion->getRegionNumber() << ", returning true"; return true; - if (visited[nextRegion->getRegionNumber()]) + } + llvm::dbgs() << "Region: " << nextRegion << "\n"; + if (!nextRegion->getParentOp()) { + llvm::errs() << "Region " << *nextRegion << " has no parent op\n"; + return false; + } + if (visited[nextRegion->getRegionNumber()]) { + LDBG() << "Region #" << nextRegion->getRegionNumber() + << " already visited, skipping"; continue; + } visited[nextRegion->getRegionNumber()] = true; + LDBG() << "Marking region #" << nextRegion->getRegionNumber() + << " as visited"; enqueueAllSuccessors(nextRegion); } + LDBG() << "Traversal completed, returning false"; return false; } @@ -322,18 +370,26 @@ static bool isRegionReachable(Region *begin, Region *r) { /// mutually exclusive if they are not reachable from each other as per /// RegionBranchOpInterface::getSuccessorRegions. bool mlir::insideMutuallyExclusiveRegions(Operation *a, Operation *b) { + LDBG() << "Checking if operations are in mutually exclusive regions: " + << a->getName() << " and " << b->getName(); + assert(a && "expected non-empty operation"); assert(b && "expected non-empty operation"); auto branchOp = a->getParentOfType<RegionBranchOpInterface>(); while (branchOp) { + LDBG() << "Checking branch operation " << branchOp->getName(); + // Check if b is inside branchOp. (We already know that a is.) if (!branchOp->isProperAncestor(b)) { + LDBG() << "Operation b is not inside branchOp, checking next ancestor"; // Check next enclosing RegionBranchOpInterface. 
branchOp = branchOp->getParentOfType<RegionBranchOpInterface>(); continue; } + LDBG() << "Both operations are inside branchOp, finding their regions"; + // b is contained in branchOp. Retrieve the regions in which `a` and `b` // are contained. Region *regionA = nullptr, *regionB = nullptr; @@ -341,63 +397,136 @@ bool mlir::insideMutuallyExclusiveRegions(Operation *a, Operation *b) { if (r.findAncestorOpInRegion(*a)) { assert(!regionA && "already found a region for a"); regionA = &r; + LDBG() << "Found region #" << r.getRegionNumber() << " for operation a"; } if (r.findAncestorOpInRegion(*b)) { assert(!regionB && "already found a region for b"); regionB = &r; + LDBG() << "Found region #" << r.getRegionNumber() << " for operation b"; } } assert(regionA && regionB && "could not find region of op"); + LDBG() << "Region A: #" << regionA->getRegionNumber() << ", Region B: #" + << regionB->getRegionNumber(); + // `a` and `b` are in mutually exclusive regions if both regions are // distinct and neither region is reachable from the other region. - return regionA != regionB && !isRegionReachable(regionA, regionB) && - !isRegionReachable(regionB, regionA); + bool regionsAreDistinct = (regionA != regionB); + bool aNotReachableFromB = !isRegionReachable(regionA, regionB); + bool bNotReachableFromA = !isRegionReachable(regionB, regionA); + + LDBG() << "Regions distinct: " << regionsAreDistinct + << ", A not reachable from B: " << aNotReachableFromB + << ", B not reachable from A: " << bNotReachableFromA; + + bool mutuallyExclusive = + regionsAreDistinct && aNotReachableFromB && bNotReachableFromA; + LDBG() << "Operations are mutually exclusive: " << mutuallyExclusive; + + return mutuallyExclusive; } // Could not find a common RegionBranchOpInterface among a's and b's // ancestors. 
+ LDBG() << "No common RegionBranchOpInterface found, operations are not " + "mutually exclusive"; return false; } bool RegionBranchOpInterface::isRepetitiveRegion(unsigned index) { + LDBG() << "Checking if region #" << index << " is repetitive in operation " + << getOperation()->getName(); + Region *region = &getOperation()->getRegion(index); - return isRegionReachable(region, region); + bool isRepetitive = isRegionReachable(region, region); + + LDBG() << "Region #" << index << " is repetitive: " << isRepetitive; + return isRepetitive; } bool RegionBranchOpInterface::hasLoop() { + LDBG() << "Checking if operation " << getOperation()->getName() + << " has loops"; + SmallVector<RegionSuccessor> entryRegions; getSuccessorRegions(RegionBranchPoint::parent(), entryRegions); - for (RegionSuccessor successor : entryRegions) - if (!successor.isParent() && - traverseRegionGraph(successor.getSuccessor(), - [](Region *nextRegion, ArrayRef<bool> visited) { - // Interrupt traversal if the region was already - // visited. - return visited[nextRegion->getRegionNumber()]; - })) - return true; + LDBG() << "Found " << entryRegions.size() << " entry regions"; + + for (RegionSuccessor successor : entryRegions) { + if (!successor.isParent()) { + LDBG() << "Checking entry region #" + << successor.getSuccessor()->getRegionNumber() << " for loops"; + + bool hasLoop = + traverseRegionGraph(successor.getSuccessor(), + [](Region *nextRegion, ArrayRef<bool> visited) { + // Interrupt traversal if the region was already + // visited. 
+ return visited[nextRegion->getRegionNumber()]; + }); + + if (hasLoop) { + LDBG() << "Found loop in entry region #" + << successor.getSuccessor()->getRegionNumber(); + return true; + } + } else { + LDBG() << "Skipping parent successor"; + } + } + + LDBG() << "No loops found in operation"; return false; } Region *mlir::getEnclosingRepetitiveRegion(Operation *op) { + LDBG() << "Finding enclosing repetitive region for operation " + << op->getName(); + while (Region *region = op->getParentRegion()) { + LDBG() << "Checking region #" << region->getRegionNumber() + << " in operation " << region->getParentOp()->getName(); + op = region->getParentOp(); - if (auto branchOp = dyn_cast<RegionBranchOpInterface>(op)) - if (branchOp.isRepetitiveRegion(region->getRegionNumber())) + if (auto branchOp = dyn_cast<RegionBranchOpInterface>(op)) { + LDBG() + << "Found RegionBranchOpInterface, checking if region is repetitive"; + if (branchOp.isRepetitiveRegion(region->getRegionNumber())) { + LDBG() << "Found repetitive region #" << region->getRegionNumber(); return region; + } + } else { + LDBG() << "Parent operation does not implement RegionBranchOpInterface"; + } } + + LDBG() << "No enclosing repetitive region found"; return nullptr; } Region *mlir::getEnclosingRepetitiveRegion(Value value) { + LDBG() << "Finding enclosing repetitive region for value"; + Region *region = value.getParentRegion(); while (region) { + LDBG() << "Checking region #" << region->getRegionNumber() + << " in operation " << region->getParentOp()->getName(); + Operation *op = region->getParentOp(); - if (auto branchOp = dyn_cast<RegionBranchOpInterface>(op)) - if (branchOp.isRepetitiveRegion(region->getRegionNumber())) + if (auto branchOp = dyn_cast<RegionBranchOpInterface>(op)) { + LDBG() + << "Found RegionBranchOpInterface, checking if region is repetitive"; + if (branchOp.isRepetitiveRegion(region->getRegionNumber())) { + LDBG() << "Found repetitive region #" << region->getRegionNumber(); return region; + } + 
} else { + LDBG() << "Parent operation does not implement RegionBranchOpInterface"; + } region = op->getParentRegion(); } + + LDBG() << "No enclosing repetitive region found for value"; return nullptr; } diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp index 2acbd03..64e3c5f 100644 --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -649,40 +649,38 @@ llvm::Constant *mlir::LLVM::detail::getLLVMConstant( auto *arrayType = llvm::ArrayType::get(elementType, numElements); if (child->isZeroValue() && !elementType->isFPOrFPVectorTy()) { return llvm::ConstantAggregateZero::get(arrayType); - } else { - if (llvm::ConstantDataSequential::isElementTypeCompatible( - elementType)) { - // TODO: Handle all compatible types. This code only handles integer. - if (isa<llvm::IntegerType>(elementType)) { - if (llvm::ConstantInt *ci = dyn_cast<llvm::ConstantInt>(child)) { - if (ci->getBitWidth() == 8) { - SmallVector<int8_t> constants(numElements, ci->getZExtValue()); - return llvm::ConstantDataArray::get(elementType->getContext(), - constants); - } - if (ci->getBitWidth() == 16) { - SmallVector<int16_t> constants(numElements, ci->getZExtValue()); - return llvm::ConstantDataArray::get(elementType->getContext(), - constants); - } - if (ci->getBitWidth() == 32) { - SmallVector<int32_t> constants(numElements, ci->getZExtValue()); - return llvm::ConstantDataArray::get(elementType->getContext(), - constants); - } - if (ci->getBitWidth() == 64) { - SmallVector<int64_t> constants(numElements, ci->getZExtValue()); - return llvm::ConstantDataArray::get(elementType->getContext(), - constants); - } + } + if (llvm::ConstantDataSequential::isElementTypeCompatible(elementType)) { + // TODO: Handle all compatible types. This code only handles integer. 
+ if (isa<llvm::IntegerType>(elementType)) { + if (llvm::ConstantInt *ci = dyn_cast<llvm::ConstantInt>(child)) { + if (ci->getBitWidth() == 8) { + SmallVector<int8_t> constants(numElements, ci->getZExtValue()); + return llvm::ConstantDataArray::get(elementType->getContext(), + constants); + } + if (ci->getBitWidth() == 16) { + SmallVector<int16_t> constants(numElements, ci->getZExtValue()); + return llvm::ConstantDataArray::get(elementType->getContext(), + constants); + } + if (ci->getBitWidth() == 32) { + SmallVector<int32_t> constants(numElements, ci->getZExtValue()); + return llvm::ConstantDataArray::get(elementType->getContext(), + constants); + } + if (ci->getBitWidth() == 64) { + SmallVector<int64_t> constants(numElements, ci->getZExtValue()); + return llvm::ConstantDataArray::get(elementType->getContext(), + constants); } } } + } // std::vector is used here to accomodate large number of elements that // exceed SmallVector capacity. std::vector<llvm::Constant *> constants(numElements, child); return llvm::ConstantArray::get(arrayType, constants); - } } } diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp index e0c65b0..41f3f9d 100644 --- a/mlir/lib/Transforms/RemoveDeadValues.cpp +++ b/mlir/lib/Transforms/RemoveDeadValues.cpp @@ -432,8 +432,7 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp, // Return the successors of `region` if the latter is not null. Else return // the successors of `regionBranchOp`. - auto getSuccessors = [&](Region *region = nullptr) { - auto point = region ? region : RegionBranchPoint::parent(); + auto getSuccessors = [&](RegionBranchPoint point) { SmallVector<RegionSuccessor> successors; regionBranchOp.getSuccessorRegions(point, successors); return successors; @@ -456,7 +455,8 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp, // `nonForwardedOperands`. 
auto markNonForwardedOperands = [&](BitVector &nonForwardedOperands) { nonForwardedOperands.resize(regionBranchOp->getNumOperands(), true); - for (const RegionSuccessor &successor : getSuccessors()) { + for (const RegionSuccessor &successor : + getSuccessors(RegionBranchPoint::parent())) { for (OpOperand *opOperand : getForwardedOpOperands(successor)) nonForwardedOperands.reset(opOperand->getOperandNumber()); } @@ -469,10 +469,13 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp, for (Region &region : regionBranchOp->getRegions()) { if (region.empty()) continue; + // TODO: this isn't correct in face of multiple terminators. Operation *terminator = region.front().getTerminator(); nonForwardedRets[terminator] = BitVector(terminator->getNumOperands(), true); - for (const RegionSuccessor &successor : getSuccessors(&region)) { + for (const RegionSuccessor &successor : + getSuccessors(RegionBranchPoint( + cast<RegionBranchTerminatorOpInterface>(terminator)))) { for (OpOperand *opOperand : getForwardedOpOperands(successor, terminator)) nonForwardedRets[terminator].reset(opOperand->getOperandNumber()); @@ -489,8 +492,13 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp, DenseMap<Region *, BitVector> &argsToKeep, Region *region = nullptr) { Operation *terminator = region ? region->front().getTerminator() : nullptr; + RegionBranchPoint point = + terminator + ? 
RegionBranchPoint( + cast<RegionBranchTerminatorOpInterface>(terminator)) + : RegionBranchPoint::parent(); - for (const RegionSuccessor &successor : getSuccessors(region)) { + for (const RegionSuccessor &successor : getSuccessors(point)) { Region *successorRegion = successor.getSuccessor(); for (auto [opOperand, input] : llvm::zip(getForwardedOpOperands(successor, terminator), @@ -517,7 +525,8 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp, resultsOrArgsToKeepChanged = false; // Recompute `resultsToKeep` and `argsToKeep` based on `operandsToKeep`. - for (const RegionSuccessor &successor : getSuccessors()) { + for (const RegionSuccessor &successor : + getSuccessors(RegionBranchPoint::parent())) { Region *successorRegion = successor.getSuccessor(); for (auto [opOperand, input] : llvm::zip(getForwardedOpOperands(successor), @@ -551,7 +560,9 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp, if (region.empty()) continue; Operation *terminator = region.front().getTerminator(); - for (const RegionSuccessor &successor : getSuccessors(&region)) { + for (const RegionSuccessor &successor : + getSuccessors(RegionBranchPoint( + cast<RegionBranchTerminatorOpInterface>(terminator)))) { Region *successorRegion = successor.getSuccessor(); for (auto [opOperand, input] : llvm::zip(getForwardedOpOperands(successor, terminator), |
