aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib')
-rw-r--r--mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp325
-rw-r--r--mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp9
-rw-r--r--mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp4
-rw-r--r--mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp6
-rw-r--r--mlir/lib/Analysis/SliceWalk.cpp2
-rw-r--r--mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp175
-rw-r--r--mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp54
-rw-r--r--mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp10
-rw-r--r--mlir/lib/Dialect/Affine/IR/AffineOps.cpp50
-rw-r--r--mlir/lib/Dialect/Async/IR/Async.cpp11
-rw-r--r--mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp11
-rw-r--r--mlir/lib/Dialect/EmitC/IR/EmitC.cpp8
-rw-r--r--mlir/lib/Dialect/GPU/IR/GPUDialect.cpp2
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp3
-rw-r--r--mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp2
-rw-r--r--mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp23
-rw-r--r--mlir/lib/Dialect/SCF/IR/SCF.cpp52
-rw-r--r--mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp1
-rw-r--r--mlir/lib/Dialect/SCF/Transforms/ForallToFor.cpp1
-rw-r--r--mlir/lib/Dialect/Shape/IR/Shape.cpp2
-rw-r--r--mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp4
-rw-r--r--mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp15
-rw-r--r--mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.h6
-rw-r--r--mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp2
-rw-r--r--mlir/lib/Dialect/Transform/IR/TransformOps.cpp37
-rw-r--r--mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp5
-rw-r--r--mlir/lib/IR/Diagnostics.cpp4
-rw-r--r--mlir/lib/IR/Region.cpp15
-rw-r--r--mlir/lib/Interfaces/ControlFlowInterfaces.cpp305
-rw-r--r--mlir/lib/Transforms/RemoveDeadValues.cpp25
30 files changed, 825 insertions, 344 deletions
diff --git a/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp b/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp
index a84d10d..24cb123 100644
--- a/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp
+++ b/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp
@@ -16,19 +16,21 @@
#include "mlir/IR/Operation.h"
#include "mlir/IR/Region.h"
#include "mlir/IR/Value.h"
-#include "mlir/IR/ValueRange.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
#include "mlir/Support/LLVM.h"
#include "llvm/Support/Casting.h"
+#include "llvm/Support/DebugLog.h"
#include <cassert>
#include <optional>
#include <utility>
using namespace mlir;
+#define DEBUG_TYPE "local-alias-analysis"
+
//===----------------------------------------------------------------------===//
// Underlying Address Computation
//===----------------------------------------------------------------------===//
@@ -42,81 +44,47 @@ static void collectUnderlyingAddressValues(Value value, unsigned maxDepth,
DenseSet<Value> &visited,
SmallVectorImpl<Value> &output);
-/// Given a successor (`region`) of a RegionBranchOpInterface, collect all of
-/// the underlying values being addressed by one of the successor inputs. If the
-/// provided `region` is null, as per `RegionBranchOpInterface` this represents
-/// the parent operation.
-static void collectUnderlyingAddressValues(RegionBranchOpInterface branch,
- Region *region, Value inputValue,
- unsigned inputIndex,
- unsigned maxDepth,
- DenseSet<Value> &visited,
- SmallVectorImpl<Value> &output) {
- // Given the index of a region of the branch (`predIndex`), or std::nullopt to
- // represent the parent operation, try to return the index into the outputs of
- // this region predecessor that correspond to the input values of `region`. If
- // an index could not be found, std::nullopt is returned instead.
- auto getOperandIndexIfPred =
- [&](RegionBranchPoint pred) -> std::optional<unsigned> {
- SmallVector<RegionSuccessor, 2> successors;
- branch.getSuccessorRegions(pred, successors);
- for (RegionSuccessor &successor : successors) {
- if (successor.getSuccessor() != region)
- continue;
- // Check that the successor inputs map to the given input value.
- ValueRange inputs = successor.getSuccessorInputs();
- if (inputs.empty()) {
- output.push_back(inputValue);
- break;
- }
- unsigned firstInputIndex, lastInputIndex;
- if (region) {
- firstInputIndex = cast<BlockArgument>(inputs[0]).getArgNumber();
- lastInputIndex = cast<BlockArgument>(inputs.back()).getArgNumber();
- } else {
- firstInputIndex = cast<OpResult>(inputs[0]).getResultNumber();
- lastInputIndex = cast<OpResult>(inputs.back()).getResultNumber();
- }
- if (firstInputIndex > inputIndex || lastInputIndex < inputIndex) {
- output.push_back(inputValue);
- break;
- }
- return inputIndex - firstInputIndex;
- }
- return std::nullopt;
- };
-
- // Check branches from the parent operation.
- auto branchPoint = RegionBranchPoint::parent();
- if (region)
- branchPoint = region;
-
- if (std::optional<unsigned> operandIndex =
- getOperandIndexIfPred(/*predIndex=*/RegionBranchPoint::parent())) {
- collectUnderlyingAddressValues(
- branch.getEntrySuccessorOperands(branchPoint)[*operandIndex], maxDepth,
- visited, output);
+/// Given a RegionBranchOpInterface operation (`branch`), a Value`inputValue`
+/// which is an input for the provided successor (`initialSuccessor`), try to
+/// find the possible sources for the value along the control flow edges.
+static void collectUnderlyingAddressValues2(
+ RegionBranchOpInterface branch, RegionSuccessor initialSuccessor,
+ Value inputValue, unsigned inputIndex, unsigned maxDepth,
+ DenseSet<Value> &visited, SmallVectorImpl<Value> &output) {
+ LDBG() << "collectUnderlyingAddressValues2: "
+ << OpWithFlags(branch.getOperation(), OpPrintingFlags().skipRegions());
+ LDBG() << " with initialSuccessor " << initialSuccessor;
+ LDBG() << " inputValue: " << inputValue;
+ LDBG() << " inputIndex: " << inputIndex;
+ LDBG() << " maxDepth: " << maxDepth;
+ ValueRange inputs = initialSuccessor.getSuccessorInputs();
+ if (inputs.empty()) {
+ LDBG() << " input is empty, enqueue value";
+ output.push_back(inputValue);
+ return;
}
- // Check branches from each child region.
- Operation *op = branch.getOperation();
- for (Region &region : op->getRegions()) {
- if (std::optional<unsigned> operandIndex = getOperandIndexIfPred(region)) {
- for (Block &block : region) {
- // Try to determine possible region-branch successor operands for the
- // current region.
- if (auto term = dyn_cast<RegionBranchTerminatorOpInterface>(
- block.getTerminator())) {
- collectUnderlyingAddressValues(
- term.getSuccessorOperands(branchPoint)[*operandIndex], maxDepth,
- visited, output);
- } else if (block.getNumSuccessors()) {
- // Otherwise, if this terminator may exit the region we can't make
- // any assumptions about which values get passed.
- output.push_back(inputValue);
- return;
- }
- }
- }
+ unsigned firstInputIndex, lastInputIndex;
+ if (isa<BlockArgument>(inputs[0])) {
+ firstInputIndex = cast<BlockArgument>(inputs[0]).getArgNumber();
+ lastInputIndex = cast<BlockArgument>(inputs.back()).getArgNumber();
+ } else {
+ firstInputIndex = cast<OpResult>(inputs[0]).getResultNumber();
+ lastInputIndex = cast<OpResult>(inputs.back()).getResultNumber();
+ }
+ if (firstInputIndex > inputIndex || lastInputIndex < inputIndex) {
+ LDBG() << " !! Input index " << inputIndex << " out of range "
+ << firstInputIndex << " to " << lastInputIndex
+ << ", adding input value to output";
+ output.push_back(inputValue);
+ return;
+ }
+ SmallVector<Value> predecessorValues;
+ branch.getPredecessorValues(initialSuccessor, inputIndex - firstInputIndex,
+ predecessorValues);
+ LDBG() << " Found " << predecessorValues.size() << " predecessor values";
+ for (Value predecessorValue : predecessorValues) {
+ LDBG() << " Processing predecessor value: " << predecessorValue;
+ collectUnderlyingAddressValues(predecessorValue, maxDepth, visited, output);
}
}
@@ -124,22 +92,28 @@ static void collectUnderlyingAddressValues(RegionBranchOpInterface branch,
static void collectUnderlyingAddressValues(OpResult result, unsigned maxDepth,
DenseSet<Value> &visited,
SmallVectorImpl<Value> &output) {
+ LDBG() << "collectUnderlyingAddressValues (OpResult): " << result;
+ LDBG() << " maxDepth: " << maxDepth;
+
Operation *op = result.getOwner();
// If this is a view, unwrap to the source.
if (ViewLikeOpInterface view = dyn_cast<ViewLikeOpInterface>(op)) {
if (result == view.getViewDest()) {
+ LDBG() << " Unwrapping view to source: " << view.getViewSource();
return collectUnderlyingAddressValues(view.getViewSource(), maxDepth,
visited, output);
}
}
// Check to see if we can reason about the control flow of this op.
if (auto branch = dyn_cast<RegionBranchOpInterface>(op)) {
- return collectUnderlyingAddressValues(branch, /*region=*/nullptr, result,
- result.getResultNumber(), maxDepth,
- visited, output);
+ LDBG() << " Processing region branch operation";
+ return collectUnderlyingAddressValues2(
+ branch, RegionSuccessor(op, op->getResults()), result,
+ result.getResultNumber(), maxDepth, visited, output);
}
+ LDBG() << " Adding result to output: " << result;
output.push_back(result);
}
@@ -148,14 +122,23 @@ static void collectUnderlyingAddressValues(OpResult result, unsigned maxDepth,
static void collectUnderlyingAddressValues(BlockArgument arg, unsigned maxDepth,
DenseSet<Value> &visited,
SmallVectorImpl<Value> &output) {
+ LDBG() << "collectUnderlyingAddressValues (BlockArgument): " << arg;
+ LDBG() << " maxDepth: " << maxDepth;
+ LDBG() << " argNumber: " << arg.getArgNumber();
+ LDBG() << " isEntryBlock: " << arg.getOwner()->isEntryBlock();
+
Block *block = arg.getOwner();
unsigned argNumber = arg.getArgNumber();
// Handle the case of a non-entry block.
if (!block->isEntryBlock()) {
+ LDBG() << " Processing non-entry block with "
+ << std::distance(block->pred_begin(), block->pred_end())
+ << " predecessors";
for (auto it = block->pred_begin(), e = block->pred_end(); it != e; ++it) {
auto branch = dyn_cast<BranchOpInterface>((*it)->getTerminator());
if (!branch) {
+ LDBG() << " Cannot analyze control flow, adding argument to output";
// We can't analyze the control flow, so bail out early.
output.push_back(arg);
return;
@@ -165,10 +148,12 @@ static void collectUnderlyingAddressValues(BlockArgument arg, unsigned maxDepth,
unsigned index = it.getSuccessorIndex();
Value operand = branch.getSuccessorOperands(index)[argNumber];
if (!operand) {
+ LDBG() << " No operand found for argument, adding to output";
// We can't analyze the control flow, so bail out early.
output.push_back(arg);
return;
}
+ LDBG() << " Processing operand from predecessor: " << operand;
collectUnderlyingAddressValues(operand, maxDepth, visited, output);
}
return;
@@ -178,10 +163,35 @@ static void collectUnderlyingAddressValues(BlockArgument arg, unsigned maxDepth,
Region *region = block->getParent();
Operation *op = region->getParentOp();
if (auto branch = dyn_cast<RegionBranchOpInterface>(op)) {
- return collectUnderlyingAddressValues(branch, region, arg, argNumber,
- maxDepth, visited, output);
+ LDBG() << " Processing region branch operation for entry block";
+ // We have to find the successor matching the region, so that the input
+ // arguments are correctly set.
+ // TODO: this isn't comprehensive: the successor may not be reachable from
+ // the entry block.
+ SmallVector<RegionSuccessor> successors;
+ branch.getSuccessorRegions(RegionBranchPoint::parent(), successors);
+ RegionSuccessor regionSuccessor(region);
+ bool found = false;
+ for (RegionSuccessor &successor : successors) {
+ if (successor.getSuccessor() == region) {
+ LDBG() << " Found matching region successor: " << successor;
+ found = true;
+ regionSuccessor = successor;
+ break;
+ }
+ }
+ if (!found) {
+ LDBG()
+ << " No matching region successor found, adding argument to output";
+ output.push_back(arg);
+ return;
+ }
+ return collectUnderlyingAddressValues2(
+ branch, regionSuccessor, arg, argNumber, maxDepth, visited, output);
}
+ LDBG()
+ << " Cannot reason about underlying address, adding argument to output";
// We can't reason about the underlying address of this argument.
output.push_back(arg);
}
@@ -190,17 +200,26 @@ static void collectUnderlyingAddressValues(BlockArgument arg, unsigned maxDepth,
static void collectUnderlyingAddressValues(Value value, unsigned maxDepth,
DenseSet<Value> &visited,
SmallVectorImpl<Value> &output) {
+ LDBG() << "collectUnderlyingAddressValues: " << value;
+ LDBG() << " maxDepth: " << maxDepth;
+
// Check that we don't infinitely recurse.
- if (!visited.insert(value).second)
+ if (!visited.insert(value).second) {
+ LDBG() << " Value already visited, skipping";
return;
+ }
if (maxDepth == 0) {
+ LDBG() << " Max depth reached, adding value to output";
output.push_back(value);
return;
}
--maxDepth;
- if (BlockArgument arg = dyn_cast<BlockArgument>(value))
+ if (BlockArgument arg = dyn_cast<BlockArgument>(value)) {
+ LDBG() << " Processing as BlockArgument";
return collectUnderlyingAddressValues(arg, maxDepth, visited, output);
+ }
+ LDBG() << " Processing as OpResult";
collectUnderlyingAddressValues(cast<OpResult>(value), maxDepth, visited,
output);
}
@@ -208,9 +227,11 @@ static void collectUnderlyingAddressValues(Value value, unsigned maxDepth,
/// Given a value, collect all of the underlying values being addressed.
static void collectUnderlyingAddressValues(Value value,
SmallVectorImpl<Value> &output) {
+ LDBG() << "collectUnderlyingAddressValues: " << value;
DenseSet<Value> visited;
collectUnderlyingAddressValues(value, maxUnderlyingValueSearchDepth, visited,
output);
+ LDBG() << " Collected " << output.size() << " underlying values";
}
//===----------------------------------------------------------------------===//
@@ -227,19 +248,33 @@ static LogicalResult
getAllocEffectFor(Value value,
std::optional<MemoryEffects::EffectInstance> &effect,
Operation *&allocScopeOp) {
+ LDBG() << "getAllocEffectFor: " << value;
+
// Try to get a memory effect interface for the parent operation.
Operation *op;
- if (BlockArgument arg = dyn_cast<BlockArgument>(value))
+ if (BlockArgument arg = dyn_cast<BlockArgument>(value)) {
op = arg.getOwner()->getParentOp();
- else
+ LDBG() << " BlockArgument, parent op: "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions());
+ } else {
op = cast<OpResult>(value).getOwner();
+ LDBG() << " OpResult, owner op: "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions());
+ }
+
MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);
- if (!interface)
+ if (!interface) {
+ LDBG() << " No memory effect interface found";
return failure();
+ }
// Try to find an allocation effect on the resource.
- if (!(effect = interface.getEffectOnValue<MemoryEffects::Allocate>(value)))
+ if (!(effect = interface.getEffectOnValue<MemoryEffects::Allocate>(value))) {
+ LDBG() << " No allocation effect found on value";
return failure();
+ }
+
+ LDBG() << " Found allocation effect";
// If we found an allocation effect, try to find a scope for the allocation.
// If the resource of this allocation is automatically scoped, find the parent
@@ -247,6 +282,12 @@ getAllocEffectFor(Value value,
if (llvm::isa<SideEffects::AutomaticAllocationScopeResource>(
effect->getResource())) {
allocScopeOp = op->getParentWithTrait<OpTrait::AutomaticAllocationScope>();
+ if (allocScopeOp) {
+ LDBG() << " Automatic allocation scope found: "
+ << OpWithFlags(allocScopeOp, OpPrintingFlags().skipRegions());
+ } else {
+ LDBG() << " Automatic allocation scope found: null";
+ }
return success();
}
@@ -255,6 +296,12 @@ getAllocEffectFor(Value value,
// For now assume allocation scope to the function scope (we don't care if
// pointer escape outside function).
allocScopeOp = op->getParentOfType<FunctionOpInterface>();
+ if (allocScopeOp) {
+ LDBG() << " Function scope found: "
+ << OpWithFlags(allocScopeOp, OpPrintingFlags().skipRegions());
+ } else {
+ LDBG() << " Function scope found: null";
+ }
return success();
}
@@ -293,33 +340,44 @@ static std::optional<AliasResult> checkDistinctObjects(Value lhs, Value rhs) {
/// Given the two values, return their aliasing behavior.
AliasResult LocalAliasAnalysis::aliasImpl(Value lhs, Value rhs) {
- if (lhs == rhs)
+ LDBG() << "aliasImpl: " << lhs << " vs " << rhs;
+
+ if (lhs == rhs) {
+ LDBG() << " Same value, must alias";
return AliasResult::MustAlias;
+ }
+
Operation *lhsAllocScope = nullptr, *rhsAllocScope = nullptr;
std::optional<MemoryEffects::EffectInstance> lhsAlloc, rhsAlloc;
// Handle the case where lhs is a constant.
Attribute lhsAttr, rhsAttr;
if (matchPattern(lhs, m_Constant(&lhsAttr))) {
+ LDBG() << " lhs is constant";
// TODO: This is overly conservative. Two matching constants don't
// necessarily map to the same address. For example, if the two values
// correspond to different symbols that both represent a definition.
- if (matchPattern(rhs, m_Constant(&rhsAttr)))
+ if (matchPattern(rhs, m_Constant(&rhsAttr))) {
+ LDBG() << " rhs is also constant, may alias";
return AliasResult::MayAlias;
+ }
// Try to find an alloc effect on rhs. If an effect was found we can't
// alias, otherwise we might.
- return succeeded(getAllocEffectFor(rhs, rhsAlloc, rhsAllocScope))
- ? AliasResult::NoAlias
- : AliasResult::MayAlias;
+ bool rhsHasAlloc =
+ succeeded(getAllocEffectFor(rhs, rhsAlloc, rhsAllocScope));
+ LDBG() << " rhs has alloc effect: " << rhsHasAlloc;
+ return rhsHasAlloc ? AliasResult::NoAlias : AliasResult::MayAlias;
}
// Handle the case where rhs is a constant.
if (matchPattern(rhs, m_Constant(&rhsAttr))) {
+ LDBG() << " rhs is constant";
// Try to find an alloc effect on lhs. If an effect was found we can't
// alias, otherwise we might.
- return succeeded(getAllocEffectFor(lhs, lhsAlloc, lhsAllocScope))
- ? AliasResult::NoAlias
- : AliasResult::MayAlias;
+ bool lhsHasAlloc =
+ succeeded(getAllocEffectFor(lhs, lhsAlloc, lhsAllocScope));
+ LDBG() << " lhs has alloc effect: " << lhsHasAlloc;
+ return lhsHasAlloc ? AliasResult::NoAlias : AliasResult::MayAlias;
}
if (std::optional<AliasResult> result = checkDistinctObjects(lhs, rhs))
@@ -329,9 +387,14 @@ AliasResult LocalAliasAnalysis::aliasImpl(Value lhs, Value rhs) {
// an allocation effect.
bool lhsHasAlloc = succeeded(getAllocEffectFor(lhs, lhsAlloc, lhsAllocScope));
bool rhsHasAlloc = succeeded(getAllocEffectFor(rhs, rhsAlloc, rhsAllocScope));
+ LDBG() << " lhs has alloc effect: " << lhsHasAlloc;
+ LDBG() << " rhs has alloc effect: " << rhsHasAlloc;
+
if (lhsHasAlloc == rhsHasAlloc) {
// If both values have an allocation effect we know they don't alias, and if
// neither have an effect we can't make an assumptions.
+ LDBG() << " Both have same alloc status: "
+ << (lhsHasAlloc ? "NoAlias" : "MayAlias");
return lhsHasAlloc ? AliasResult::NoAlias : AliasResult::MayAlias;
}
@@ -339,6 +402,7 @@ AliasResult LocalAliasAnalysis::aliasImpl(Value lhs, Value rhs) {
// and one without. Move the one with the effect to the lhs to make the next
// checks simpler.
if (rhsHasAlloc) {
+ LDBG() << " Swapping lhs and rhs to put alloc effect on lhs";
std::swap(lhs, rhs);
lhsAlloc = rhsAlloc;
lhsAllocScope = rhsAllocScope;
@@ -347,49 +411,74 @@ AliasResult LocalAliasAnalysis::aliasImpl(Value lhs, Value rhs) {
// If the effect has a scoped allocation region, check to see if the
// non-effect value is defined above that scope.
if (lhsAllocScope) {
+ LDBG() << " Checking allocation scope: "
+ << OpWithFlags(lhsAllocScope, OpPrintingFlags().skipRegions());
// If the parent operation of rhs is an ancestor of the allocation scope, or
// if rhs is an entry block argument of the allocation scope we know the two
// values can't alias.
Operation *rhsParentOp = rhs.getParentRegion()->getParentOp();
- if (rhsParentOp->isProperAncestor(lhsAllocScope))
+ if (rhsParentOp->isProperAncestor(lhsAllocScope)) {
+ LDBG() << " rhs parent is ancestor of alloc scope, no alias";
return AliasResult::NoAlias;
+ }
if (rhsParentOp == lhsAllocScope) {
BlockArgument rhsArg = dyn_cast<BlockArgument>(rhs);
- if (rhsArg && rhs.getParentBlock()->isEntryBlock())
+ if (rhsArg && rhs.getParentBlock()->isEntryBlock()) {
+ LDBG() << " rhs is entry block arg of alloc scope, no alias";
return AliasResult::NoAlias;
+ }
}
}
// If we couldn't reason about the relationship between the two values,
// conservatively assume they might alias.
+ LDBG() << " Cannot reason about relationship, may alias";
return AliasResult::MayAlias;
}
/// Given the two values, return their aliasing behavior.
AliasResult LocalAliasAnalysis::alias(Value lhs, Value rhs) {
- if (lhs == rhs)
+ LDBG() << "alias: " << lhs << " vs " << rhs;
+
+ if (lhs == rhs) {
+ LDBG() << " Same value, must alias";
return AliasResult::MustAlias;
+ }
// Get the underlying values being addressed.
SmallVector<Value, 8> lhsValues, rhsValues;
collectUnderlyingAddressValues(lhs, lhsValues);
collectUnderlyingAddressValues(rhs, rhsValues);
+ LDBG() << " lhs underlying values: " << lhsValues.size();
+ LDBG() << " rhs underlying values: " << rhsValues.size();
+
// If we failed to collect for either of the values somehow, conservatively
// assume they may alias.
- if (lhsValues.empty() || rhsValues.empty())
+ if (lhsValues.empty() || rhsValues.empty()) {
+ LDBG() << " Failed to collect underlying values, may alias";
return AliasResult::MayAlias;
+ }
// Check the alias results against each of the underlying values.
std::optional<AliasResult> result;
for (Value lhsVal : lhsValues) {
for (Value rhsVal : rhsValues) {
+ LDBG() << " Checking underlying values: " << lhsVal << " vs " << rhsVal;
AliasResult nextResult = aliasImpl(lhsVal, rhsVal);
+ LDBG() << " Result: "
+ << (nextResult == AliasResult::MustAlias ? "MustAlias"
+ : nextResult == AliasResult::NoAlias ? "NoAlias"
+ : "MayAlias");
result = result ? result->merge(nextResult) : nextResult;
}
}
// We should always have a valid result here.
+ LDBG() << " Final result: "
+ << (result->isMust() ? "MustAlias"
+ : result->isNo() ? "NoAlias"
+ : "MayAlias");
return *result;
}
@@ -398,8 +487,12 @@ AliasResult LocalAliasAnalysis::alias(Value lhs, Value rhs) {
//===----------------------------------------------------------------------===//
ModRefResult LocalAliasAnalysis::getModRef(Operation *op, Value location) {
+ LDBG() << "getModRef: " << OpWithFlags(op, OpPrintingFlags().skipRegions())
+ << " on location " << location;
+
// Check to see if this operation relies on nested side effects.
if (op->hasTrait<OpTrait::HasRecursiveMemoryEffects>()) {
+ LDBG() << " Operation has recursive memory effects, returning ModAndRef";
// TODO: To check recursive operations we need to check all of the nested
// operations, which can result in a quadratic number of queries. We should
// introduce some caching of some kind to help alleviate this, especially as
@@ -410,38 +503,64 @@ ModRefResult LocalAliasAnalysis::getModRef(Operation *op, Value location) {
// Otherwise, check to see if this operation has a memory effect interface.
MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);
- if (!interface)
+ if (!interface) {
+ LDBG() << " No memory effect interface, returning ModAndRef";
return ModRefResult::getModAndRef();
+ }
// Build a ModRefResult by merging the behavior of the effects of this
// operation.
SmallVector<MemoryEffects::EffectInstance> effects;
interface.getEffects(effects);
+ LDBG() << " Found " << effects.size() << " memory effects";
ModRefResult result = ModRefResult::getNoModRef();
for (const MemoryEffects::EffectInstance &effect : effects) {
- if (isa<MemoryEffects::Allocate, MemoryEffects::Free>(effect.getEffect()))
+ if (isa<MemoryEffects::Allocate, MemoryEffects::Free>(effect.getEffect())) {
+ LDBG() << " Skipping alloc/free effect";
continue;
+ }
// Check for an alias between the effect and our memory location.
// TODO: Add support for checking an alias with a symbol reference.
AliasResult aliasResult = AliasResult::MayAlias;
- if (Value effectValue = effect.getValue())
+ if (Value effectValue = effect.getValue()) {
+ LDBG() << " Checking alias between effect value " << effectValue
+ << " and location " << location;
aliasResult = alias(effectValue, location);
+ LDBG() << " Alias result: "
+ << (aliasResult.isMust() ? "MustAlias"
+ : aliasResult.isNo() ? "NoAlias"
+ : "MayAlias");
+ } else {
+ LDBG() << " No effect value, assuming MayAlias";
+ }
// If we don't alias, ignore this effect.
- if (aliasResult.isNo())
+ if (aliasResult.isNo()) {
+ LDBG() << " No alias, ignoring effect";
continue;
+ }
// Merge in the corresponding mod or ref for this effect.
if (isa<MemoryEffects::Read>(effect.getEffect())) {
+ LDBG() << " Adding Ref to result";
result = result.merge(ModRefResult::getRef());
} else {
assert(isa<MemoryEffects::Write>(effect.getEffect()));
+ LDBG() << " Adding Mod to result";
result = result.merge(ModRefResult::getMod());
}
- if (result.isModAndRef())
+ if (result.isModAndRef()) {
+ LDBG() << " Result is now ModAndRef, breaking";
break;
+ }
}
+
+ LDBG() << " Final ModRef result: "
+ << (result.isModAndRef() ? "ModAndRef"
+ : result.isMod() ? "Mod"
+ : result.isRef() ? "Ref"
+ : "NoModRef");
return result;
}
diff --git a/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
index 377f7eb..0fc5b44 100644
--- a/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
@@ -501,11 +501,10 @@ void DeadCodeAnalysis::visitRegionTerminator(Operation *op,
return;
SmallVector<RegionSuccessor> successors;
- if (auto terminator = dyn_cast<RegionBranchTerminatorOpInterface>(op))
- terminator.getSuccessorRegions(*operands, successors);
- else
- branch.getSuccessorRegions(op->getParentRegion(), successors);
-
+ auto terminator = dyn_cast<RegionBranchTerminatorOpInterface>(op);
+ if (!terminator)
+ return;
+ terminator.getSuccessorRegions(*operands, successors);
visitRegionBranchEdges(branch, op, successors);
}
diff --git a/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
index daa3db5..0682e5f 100644
--- a/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
@@ -588,7 +588,9 @@ void AbstractDenseBackwardDataFlowAnalysis::visitBlock(Block *block) {
// flow, propagate the lattice back along the control flow edge.
if (auto branch = dyn_cast<RegionBranchOpInterface>(block->getParentOp())) {
LDBG() << " Exit block of region branch operation";
- visitRegionBranchOperation(point, branch, block->getParent(), before);
+ auto terminator =
+ cast<RegionBranchTerminatorOpInterface>(block->getTerminator());
+ visitRegionBranchOperation(point, branch, terminator, before);
return;
}
diff --git a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
index 0d2e2ed..8e63ae8 100644
--- a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
@@ -130,7 +130,7 @@ AbstractSparseForwardDataFlowAnalysis::visitOperation(Operation *op) {
// The results of a region branch operation are determined by control-flow.
if (auto branch = dyn_cast<RegionBranchOpInterface>(op)) {
visitRegionSuccessors(getProgramPointAfter(branch), branch,
- /*successor=*/RegionBranchPoint::parent(),
+ /*successor=*/{branch, branch->getResults()},
resultLattices);
return success();
}
@@ -279,7 +279,7 @@ void AbstractSparseForwardDataFlowAnalysis::visitCallableOperation(
void AbstractSparseForwardDataFlowAnalysis::visitRegionSuccessors(
ProgramPoint *point, RegionBranchOpInterface branch,
- RegionBranchPoint successor, ArrayRef<AbstractSparseLattice *> lattices) {
+ RegionSuccessor successor, ArrayRef<AbstractSparseLattice *> lattices) {
const auto *predecessors = getOrCreateFor<PredecessorState>(point, point);
assert(predecessors->allPredecessorsKnown() &&
"unexpected unresolved region successors");
@@ -314,7 +314,7 @@ void AbstractSparseForwardDataFlowAnalysis::visitRegionSuccessors(
visitNonControlFlowArgumentsImpl(
branch,
RegionSuccessor(
- branch->getResults().slice(firstIndex, inputs.size())),
+ branch, branch->getResults().slice(firstIndex, inputs.size())),
lattices, firstIndex);
} else {
if (!inputs.empty())
diff --git a/mlir/lib/Analysis/SliceWalk.cpp b/mlir/lib/Analysis/SliceWalk.cpp
index 817d71a..863f260 100644
--- a/mlir/lib/Analysis/SliceWalk.cpp
+++ b/mlir/lib/Analysis/SliceWalk.cpp
@@ -114,7 +114,7 @@ mlir::getControlFlowPredecessors(Value value) {
if (!regionOp)
return std::nullopt;
// Add the control flow predecessor operands to the work list.
- RegionSuccessor region(regionOp->getResults());
+ RegionSuccessor region(regionOp, regionOp->getResults());
SmallVector<Value> predecessorOperands = getRegionPredecessorOperands(
regionOp, region, opResult.getResultNumber());
return predecessorOperands;
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 478b6aa..1eca43d 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -989,21 +989,17 @@ mfmaOpToScaledIntrinsic(ScaledMFMAOp smfma, Chipset chipset) {
smfma.getN(), smfma.getK(), 1u, chipset);
}
-/// Return the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
-/// if one exists. This includes checking to ensure the intrinsic is supported
-/// on the architecture you are compiling for.
-static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
- Chipset chipset) {
- auto sourceVectorType = cast<VectorType>(wmma.getSourceA().getType());
- auto sourceBVectorType = cast<VectorType>(wmma.getSourceB().getType());
- auto destVectorType = cast<VectorType>(wmma.getDestC().getType());
- Type elemSourceType = sourceVectorType.getElementType();
- Type elemBSourceType = sourceBVectorType.getElementType();
- Type elemDestType = destVectorType.getElementType();
-
- const uint32_t k = wmma.getK();
-
+/// Returns the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
+/// for RDNA3/4 architectures.
+static std::optional<StringRef>
+wmmaOpToIntrinsicRDNA(Type elemSourceType, Type elemBSourceType,
+ Type elemDestType, uint32_t k, bool isRDNA3) {
+ using fp8 = Float8E4M3FNType;
+ using bf8 = Float8E5M2Type;
+
+ // Handle k == 16 for RDNA3/4.
if (k == 16) {
+ // Common patterns for RDNA3 and RDNA4.
if (elemSourceType.isF16() && elemDestType.isF32())
return ROCDL::wmma_f32_16x16x16_f16::getOperationName();
if (elemSourceType.isBF16() && elemDestType.isF32())
@@ -1014,39 +1010,160 @@ static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName();
if (elemSourceType.isInteger(8) && elemDestType.isInteger(32))
return ROCDL::wmma_i32_16x16x16_iu8::getOperationName();
- if (chipset.majorVersion == 11) {
+
+ // RDNA3 specific patterns.
+ if (isRDNA3) {
if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
+ return std::nullopt;
}
- }
- if (chipset.majorVersion < 12)
- return std::nullopt;
- // gfx12+
- if (k == 16) {
- if (isa<Float8E4M3FNType>(elemSourceType) &&
- isa<Float8E4M3FNType>(elemBSourceType) && elemDestType.isF32())
+ // RDNA4 specific patterns (fp8/bf8).
+ if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType) &&
+ elemDestType.isF32())
return ROCDL::wmma_f32_16x16x16_fp8_fp8::getOperationName();
- if (isa<Float8E4M3FNType>(elemSourceType) &&
- isa<Float8E5M2Type>(elemBSourceType) && elemDestType.isF32())
+ if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType) &&
+ elemDestType.isF32())
return ROCDL::wmma_f32_16x16x16_fp8_bf8::getOperationName();
- if (isa<Float8E5M2Type>(elemSourceType) &&
- isa<Float8E5M2Type>(elemBSourceType) && elemDestType.isF32())
+ if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType) &&
+ elemDestType.isF32())
return ROCDL::wmma_f32_16x16x16_bf8_bf8::getOperationName();
- if (isa<Float8E5M2Type>(elemSourceType) &&
- isa<Float8E4M3FNType>(elemBSourceType) && elemDestType.isF32())
+ if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType) &&
+ elemDestType.isF32())
return ROCDL::wmma_f32_16x16x16_bf8_fp8::getOperationName();
if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
return std::nullopt;
}
- if (k == 32) {
+
+ // Handle k == 32 for RDNA4.
+ if (k == 32 && !isRDNA3) {
if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
return ROCDL::wmma_i32_16x16x32_iu4::getOperationName();
+ }
+
+ llvm_unreachable("Unsupported k value");
+}
+
+/// Return the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
+/// for the gfx1250 architecture.
+static std::optional<StringRef> wmmaOpToIntrinsicGfx1250(Type elemSourceType,
+ Type elemBSourceType,
+ Type elemDestType,
+ uint32_t k) {
+ using fp8 = Float8E4M3FNType;
+ using bf8 = Float8E5M2Type;
+
+ if (k == 4) {
+ if (elemSourceType.isF32() && elemDestType.isF32())
+ return ROCDL::wmma_f32_16x16x4_f32::getOperationName();
+
return std::nullopt;
}
+ if (k == 32) {
+ if (elemSourceType.isF16() && elemDestType.isF32())
+ return ROCDL::wmma_f32_16x16x32_f16::getOperationName();
+ if (elemSourceType.isBF16() && elemDestType.isF32())
+ return ROCDL::wmma_f32_16x16x32_bf16::getOperationName();
+ if (elemSourceType.isF16() && elemDestType.isF16())
+ return ROCDL::wmma_f16_16x16x32_f16::getOperationName();
+ if (elemSourceType.isBF16() && elemDestType.isBF16())
+ return ROCDL::wmma_bf16_16x16x32_bf16::getOperationName();
+
+ return std::nullopt;
+ }
+
+ if (k == 64) {
+ if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
+ if (elemDestType.isF32())
+ return ROCDL::wmma_f32_16x16x64_fp8_fp8::getOperationName();
+ if (elemDestType.isF16())
+ return ROCDL::wmma_f16_16x16x64_fp8_fp8::getOperationName();
+ }
+ if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
+ if (elemDestType.isF32())
+ return ROCDL::wmma_f32_16x16x64_fp8_bf8::getOperationName();
+ if (elemDestType.isF16())
+ return ROCDL::wmma_f16_16x16x64_fp8_bf8::getOperationName();
+ }
+ if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
+ if (elemDestType.isF32())
+ return ROCDL::wmma_f32_16x16x64_bf8_bf8::getOperationName();
+ if (elemDestType.isF16())
+ return ROCDL::wmma_f16_16x16x64_bf8_bf8::getOperationName();
+ }
+ if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
+ if (elemDestType.isF32())
+ return ROCDL::wmma_f32_16x16x64_bf8_fp8::getOperationName();
+ if (elemDestType.isF16())
+ return ROCDL::wmma_f16_16x16x64_bf8_fp8::getOperationName();
+ }
+ if (elemSourceType.isInteger(8) && elemDestType.isInteger(32))
+ return ROCDL::wmma_i32_16x16x64_iu8::getOperationName();
+
+ return std::nullopt;
+ }
+
+ if (k == 128) {
+ if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
+ if (elemDestType.isF32())
+ return ROCDL::wmma_f32_16x16x128_fp8_fp8::getOperationName();
+ if (elemDestType.isF16())
+ return ROCDL::wmma_f16_16x16x128_fp8_fp8::getOperationName();
+ }
+ if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
+ if (elemDestType.isF32())
+ return ROCDL::wmma_f32_16x16x128_fp8_bf8::getOperationName();
+ if (elemDestType.isF16())
+ return ROCDL::wmma_f16_16x16x128_fp8_bf8::getOperationName();
+ }
+ if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
+ if (elemDestType.isF32())
+ return ROCDL::wmma_f32_16x16x128_bf8_bf8::getOperationName();
+ if (elemDestType.isF16())
+ return ROCDL::wmma_f16_16x16x128_bf8_bf8::getOperationName();
+ }
+ if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
+ if (elemDestType.isF32())
+ return ROCDL::wmma_f32_16x16x128_bf8_fp8::getOperationName();
+ if (elemDestType.isF16())
+ return ROCDL::wmma_f16_16x16x128_bf8_fp8::getOperationName();
+ }
+
+ return std::nullopt;
+ }
+
+ llvm_unreachable("Unsupported k value");
+}
+
+/// Returns the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
+/// if one exists. This includes checking to ensure the intrinsic is supported
+/// on the architecture you are compiling for.
+static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
+ Chipset chipset) {
+ auto sourceVectorType = cast<VectorType>(wmma.getSourceA().getType());
+ auto sourceBVectorType = cast<VectorType>(wmma.getSourceB().getType());
+ auto destVectorType = cast<VectorType>(wmma.getDestC().getType());
+ Type elemSourceType = sourceVectorType.getElementType();
+ Type elemBSourceType = sourceBVectorType.getElementType();
+ Type elemDestType = destVectorType.getElementType();
+
+ const uint32_t k = wmma.getK();
+ const bool isRDNA3 = chipset.majorVersion == 11;
+ const bool isRDNA4 = chipset.majorVersion == 12 && chipset.minorVersion == 0;
+
+ // Handle RDNA3 and RDNA4.
+ if (isRDNA3 || isRDNA4)
+ return wmmaOpToIntrinsicRDNA(elemSourceType, elemBSourceType, elemDestType,
+ k, isRDNA3);
+
+ // Handle gfx1250.
+ if (chipset == Chipset{12, 5, 0})
+ return wmmaOpToIntrinsicGfx1250(elemSourceType, elemBSourceType,
+ elemDestType, k);
+
llvm_unreachable("unhandled WMMA case");
}
diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
index 0fe7239..9e46b7d 100644
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -313,25 +313,53 @@ private:
struct ExpOpConversion : public OpConversionPattern<complex::ExpOp> {
using OpConversionPattern<complex::ExpOp>::OpConversionPattern;
+ // exp(x+I*y) = exp(x)*(cos(y)+I*sin(y))
+ // Handle special cases as StableHLO implementation does:
+ // 1. When b == 0, set imag(exp(z)) = 0
+ // 2. When exp(x) == inf, use exp(x/2)*(cos(y)+I*sin(y))*exp(x/2)
LogicalResult
matchAndRewrite(complex::ExpOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = op.getLoc();
auto type = cast<ComplexType>(adaptor.getComplex().getType());
- auto elementType = cast<FloatType>(type.getElementType());
- arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
-
- Value real =
- complex::ReOp::create(rewriter, loc, elementType, adaptor.getComplex());
- Value imag =
- complex::ImOp::create(rewriter, loc, elementType, adaptor.getComplex());
- Value expReal = math::ExpOp::create(rewriter, loc, real, fmf.getValue());
- Value cosImag = math::CosOp::create(rewriter, loc, imag, fmf.getValue());
+ auto ET = cast<FloatType>(type.getElementType());
+ arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
+ const auto &floatSemantics = ET.getFloatSemantics();
+ ImplicitLocOpBuilder b(loc, rewriter);
+
+ Value x = complex::ReOp::create(b, ET, adaptor.getComplex());
+ Value y = complex::ImOp::create(b, ET, adaptor.getComplex());
+ Value zero = arith::ConstantOp::create(b, ET, b.getZeroAttr(ET));
+ Value half = arith::ConstantOp::create(b, ET, b.getFloatAttr(ET, 0.5));
+ Value inf = arith::ConstantOp::create(
+ b, ET, b.getFloatAttr(ET, APFloat::getInf(floatSemantics)));
+
+ Value exp = math::ExpOp::create(b, x, fmf);
+ Value xHalf = arith::MulFOp::create(b, x, half, fmf);
+ Value expHalf = math::ExpOp::create(b, xHalf, fmf);
+ Value cos = math::CosOp::create(b, y, fmf);
+ Value sin = math::SinOp::create(b, y, fmf);
+
+ Value expIsInf =
+ arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, exp, inf, fmf);
+ Value yIsZero =
+ arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, y, zero);
+
+ // Real path: select between exp(x)*cos(y) and exp(x/2)*cos(y)*exp(x/2)
+ Value realNormal = arith::MulFOp::create(b, exp, cos, fmf);
+ Value expHalfCos = arith::MulFOp::create(b, expHalf, cos, fmf);
+ Value realOverflow = arith::MulFOp::create(b, expHalfCos, expHalf, fmf);
Value resultReal =
- arith::MulFOp::create(rewriter, loc, expReal, cosImag, fmf.getValue());
- Value sinImag = math::SinOp::create(rewriter, loc, imag, fmf.getValue());
- Value resultImag =
- arith::MulFOp::create(rewriter, loc, expReal, sinImag, fmf.getValue());
+ arith::SelectOp::create(b, expIsInf, realOverflow, realNormal);
+
+ // Imaginary part: if y == 0 return 0 else select between exp(x)*sin(y) and
+ // exp(x/2)*sin(y)*exp(x/2)
+ Value imagNormal = arith::MulFOp::create(b, exp, sin, fmf);
+ Value expHalfSin = arith::MulFOp::create(b, expHalf, sin, fmf);
+ Value imagOverflow = arith::MulFOp::create(b, expHalfSin, expHalf, fmf);
+ Value imagNonZero =
+ arith::SelectOp::create(b, expIsInf, imagOverflow, imagNormal);
+ Value resultImag = arith::SelectOp::create(b, yIsZero, zero, imagNonZero);
rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
resultImag);
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 585b6da..df955fc 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -399,13 +399,15 @@ LogicalResult WMMAOp::verify() {
if (!sourceAElemType.isFloat(8) && sourceAElemType != sourceBElemType) {
return emitOpError(
- "source element types much match (except for fp8) but have ")
+ "source element types must match (except for fp8/bf8) but have ")
<< sourceAType << " and " << sourceBType;
}
- if (!sourceAElemType.isInteger(4) && getK() != 16) {
- return emitOpError("K dimension must be 16 for source element type ")
- << sourceAElemType;
+ if (isSrcFloat) {
+ if (getClamp())
+ return emitOpError("clamp flag is not supported for float types");
+ if (getUnsignedA() || getUnsignedB())
+ return emitOpError("unsigned flags are not supported for float types");
}
return success();
}
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index e0a53cd..0c35921 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -2716,8 +2716,9 @@ LogicalResult AffineForOp::fold(FoldAdaptor adaptor,
return success(folded);
}
-OperandRange AffineForOp::getEntrySuccessorOperands(RegionBranchPoint point) {
- assert((point.isParent() || point == getRegion()) && "invalid region point");
+OperandRange AffineForOp::getEntrySuccessorOperands(RegionSuccessor successor) {
+ assert((successor.isParent() || successor.getSuccessor() == &getRegion()) &&
+ "invalid region point");
// The initial operands map to the loop arguments after the induction
// variable or are forwarded to the results when the trip count is zero.
@@ -2726,34 +2727,41 @@ OperandRange AffineForOp::getEntrySuccessorOperands(RegionBranchPoint point) {
void AffineForOp::getSuccessorRegions(
RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
- assert((point.isParent() || point == getRegion()) && "expected loop region");
+ assert((point.isParent() ||
+ point.getTerminatorPredecessorOrNull()->getParentRegion() ==
+ &getRegion()) &&
+ "expected loop region");
// The loop may typically branch back to its body or to the parent operation.
// If the predecessor is the parent op and the trip count is known to be at
// least one, branch into the body using the iterator arguments. And in cases
// we know the trip count is zero, it can only branch back to its parent.
std::optional<uint64_t> tripCount = getTrivialConstantTripCount(*this);
- if (point.isParent() && tripCount.has_value()) {
- if (tripCount.value() > 0) {
- regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
- return;
- }
- if (tripCount.value() == 0) {
- regions.push_back(RegionSuccessor(getResults()));
- return;
+ if (tripCount.has_value()) {
+ if (!point.isParent()) {
+ // From the loop body, if the trip count is one, we can only branch back
+ // to the parent.
+ if (tripCount == 1) {
+ regions.push_back(RegionSuccessor(getOperation(), getResults()));
+ return;
+ }
+ if (tripCount == 0)
+ return;
+ } else {
+ if (tripCount.value() > 0) {
+ regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
+ return;
+ }
+ if (tripCount.value() == 0) {
+ regions.push_back(RegionSuccessor(getOperation(), getResults()));
+ return;
+ }
}
}
- // From the loop body, if the trip count is one, we can only branch back to
- // the parent.
- if (!point.isParent() && tripCount == 1) {
- regions.push_back(RegionSuccessor(getResults()));
- return;
- }
-
// In all other cases, the loop may branch back to itself or the parent
// operation.
regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
- regions.push_back(RegionSuccessor(getResults()));
+ regions.push_back(RegionSuccessor(getOperation(), getResults()));
}
AffineBound AffineForOp::getLowerBound() {
@@ -3142,7 +3150,7 @@ void AffineIfOp::getSuccessorRegions(
RegionSuccessor(&getThenRegion(), getThenRegion().getArguments()));
// If the "else" region is empty, branch bach into parent.
if (getElseRegion().empty()) {
- regions.push_back(getResults());
+ regions.push_back(RegionSuccessor(getOperation(), getResults()));
} else {
regions.push_back(
RegionSuccessor(&getElseRegion(), getElseRegion().getArguments()));
@@ -3152,7 +3160,7 @@ void AffineIfOp::getSuccessorRegions(
// If the predecessor is the `else`/`then` region, then branching into parent
// op is valid.
- regions.push_back(RegionSuccessor(getResults()));
+ regions.push_back(RegionSuccessor(getOperation(), getResults()));
}
LogicalResult AffineIfOp::verify() {
diff --git a/mlir/lib/Dialect/Async/IR/Async.cpp b/mlir/lib/Dialect/Async/IR/Async.cpp
index dc7b07d..8e4a49d 100644
--- a/mlir/lib/Dialect/Async/IR/Async.cpp
+++ b/mlir/lib/Dialect/Async/IR/Async.cpp
@@ -36,8 +36,9 @@ void AsyncDialect::initialize() {
constexpr char kOperandSegmentSizesAttr[] = "operandSegmentSizes";
-OperandRange ExecuteOp::getEntrySuccessorOperands(RegionBranchPoint point) {
- assert(point == getBodyRegion() && "invalid region index");
+OperandRange ExecuteOp::getEntrySuccessorOperands(RegionSuccessor successor) {
+ assert(successor.getSuccessor() == &getBodyRegion() &&
+ "invalid region index");
return getBodyOperands();
}
@@ -53,8 +54,10 @@ bool ExecuteOp::areTypesCompatible(Type lhs, Type rhs) {
void ExecuteOp::getSuccessorRegions(RegionBranchPoint point,
SmallVectorImpl<RegionSuccessor> &regions) {
// The `body` region branch back to the parent operation.
- if (point == getBodyRegion()) {
- regions.push_back(RegionSuccessor(getBodyResults()));
+ if (!point.isParent() &&
+ point.getTerminatorPredecessorOrNull()->getParentRegion() ==
+ &getBodyRegion()) {
+ regions.push_back(RegionSuccessor(getOperation(), getBodyResults()));
return;
}
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp
index b593cca..36a759c 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp
@@ -562,8 +562,11 @@ LogicalResult
BufferDeallocation::updateFunctionSignature(FunctionOpInterface op) {
SmallVector<TypeRange> returnOperandTypes(llvm::map_range(
op.getFunctionBody().getOps<RegionBranchTerminatorOpInterface>(),
- [](RegionBranchTerminatorOpInterface op) {
- return op.getSuccessorOperands(RegionBranchPoint::parent()).getTypes();
+ [&](RegionBranchTerminatorOpInterface branchOp) {
+ return branchOp
+ .getSuccessorOperands(RegionSuccessor(
+ op.getOperation(), op.getOperation()->getResults()))
+ .getTypes();
}));
if (!llvm::all_equal(returnOperandTypes))
return op->emitError(
@@ -942,8 +945,8 @@ BufferDeallocation::handleInterface(RegionBranchTerminatorOpInterface op) {
// about, but we would need to check how many successors there are and under
// which condition they are taken, etc.
- MutableOperandRange operands =
- op.getMutableSuccessorOperands(RegionBranchPoint::parent());
+ MutableOperandRange operands = op.getMutableSuccessorOperands(
+ RegionSuccessor(op.getOperation(), op.getOperation()->getResults()));
SmallVector<Value> updatedOwnerships;
auto result = deallocation_impl::insertDeallocOpForReturnLike(
diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
index 4754f0b..0992ce14 100644
--- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -845,7 +845,8 @@ void IfOp::getSuccessorRegions(RegionBranchPoint point,
SmallVectorImpl<RegionSuccessor> &regions) {
// The `then` and the `else` region branch back to the parent operation.
if (!point.isParent()) {
- regions.push_back(RegionSuccessor());
+ regions.push_back(
+ RegionSuccessor(getOperation(), getOperation()->getResults()));
return;
}
@@ -854,7 +855,8 @@ void IfOp::getSuccessorRegions(RegionBranchPoint point,
// Don't consider the else region if it is empty.
Region *elseRegion = &this->getElseRegion();
if (elseRegion->empty())
- regions.push_back(RegionSuccessor());
+ regions.push_back(
+ RegionSuccessor(getOperation(), getOperation()->getResults()));
else
regions.push_back(RegionSuccessor(elseRegion));
}
@@ -871,7 +873,7 @@ void IfOp::getEntrySuccessorRegions(ArrayRef<Attribute> operands,
if (!getElseRegion().empty())
regions.emplace_back(&getElseRegion());
else
- regions.emplace_back();
+ regions.emplace_back(getOperation(), getOperation()->getResults());
}
}
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index b5f8dda..6c6d8d2 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -2399,7 +2399,7 @@ ParseResult WarpExecuteOnLane0Op::parse(OpAsmParser &parser,
void WarpExecuteOnLane0Op::getSuccessorRegions(
RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
if (!point.isParent()) {
- regions.push_back(RegionSuccessor(getResults()));
+ regions.push_back(RegionSuccessor(getOperation(), getResults()));
return;
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index eb2d825..bd25e94 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -495,13 +495,14 @@ FailureOr<PackResult> linalg::pack(RewriterBase &rewriter,
if (failed(maybePackedDimForEachOperand))
return failure();
packedOperandsDims.packedDimForEachOperand = *maybePackedDimForEachOperand;
- listOfPackedOperandsDim.pushBack(std::move(packedOperandsDims));
LDBG() << "++++ After pack size #" << i << ": " << packedSizes[i];
LDBG() << "maps: " << llvm::interleaved(indexingMaps);
LDBG() << "iterators: " << llvm::interleaved(iteratorTypes);
LDBG() << "packedDimForEachOperand: "
<< llvm::interleaved(packedOperandsDims.packedDimForEachOperand);
+
+ listOfPackedOperandsDim.pushBack(std::move(packedOperandsDims));
}
// Step 2. Propagate packing to all LinalgOp operands.
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index c551fba..1c21a2f 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -405,7 +405,7 @@ ParseResult AllocaScopeOp::parse(OpAsmParser &parser, OperationState &result) {
void AllocaScopeOp::getSuccessorRegions(
RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
if (!point.isParent()) {
- regions.push_back(RegionSuccessor(getResults()));
+ regions.push_back(RegionSuccessor(getOperation(), getResults()));
return;
}
diff --git a/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp
index 6fa8ce4..69afbca 100644
--- a/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp
@@ -98,6 +98,27 @@ struct RankOpInterface
}
};
+struct CollapseShapeOpInterface
+ : public ValueBoundsOpInterface::ExternalModel<CollapseShapeOpInterface,
+ memref::CollapseShapeOp> {
+ void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
+ ValueBoundsConstraintSet &cstr) const {
+ auto collapseOp = cast<memref::CollapseShapeOp>(op);
+ assert(value == collapseOp.getResult() && "invalid value");
+
+ // Multiply the expressions for the dimensions in the reassociation group.
+ const ReassociationIndices reassocIndices =
+ collapseOp.getReassociationIndices()[dim];
+ AffineExpr productExpr =
+ cstr.getExpr(collapseOp.getSrc(), reassocIndices[0]);
+ for (size_t i = 1; i < reassocIndices.size(); ++i) {
+ productExpr =
+ productExpr * cstr.getExpr(collapseOp.getSrc(), reassocIndices[i]);
+ }
+ cstr.bound(value)[dim] == productExpr;
+ }
+};
+
struct SubViewOpInterface
: public ValueBoundsOpInterface::ExternalModel<SubViewOpInterface,
SubViewOp> {
@@ -134,6 +155,8 @@ void mlir::memref::registerValueBoundsOpInterfaceExternalModels(
memref::AllocOpInterface<memref::AllocaOp>>(*ctx);
memref::CastOp::attachInterface<memref::CastOpInterface>(*ctx);
memref::DimOp::attachInterface<memref::DimOpInterface>(*ctx);
+ memref::CollapseShapeOp::attachInterface<memref::CollapseShapeOpInterface>(
+ *ctx);
memref::ExpandShapeOp::attachInterface<memref::ExpandShapeOpInterface>(
*ctx);
memref::GetGlobalOp::attachInterface<memref::GetGlobalOpInterface>(*ctx);
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 1ab01d8..2946b53 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -397,7 +397,7 @@ void ExecuteRegionOp::getSuccessorRegions(
}
// Otherwise, the region branches back to the parent operation.
- regions.push_back(RegionSuccessor(getResults()));
+ regions.push_back(RegionSuccessor(getOperation(), getResults()));
}
//===----------------------------------------------------------------------===//
@@ -405,10 +405,11 @@ void ExecuteRegionOp::getSuccessorRegions(
//===----------------------------------------------------------------------===//
MutableOperandRange
-ConditionOp::getMutableSuccessorOperands(RegionBranchPoint point) {
- assert((point.isParent() || point == getParentOp().getAfter()) &&
- "condition op can only exit the loop or branch to the after"
- "region");
+ConditionOp::getMutableSuccessorOperands(RegionSuccessor point) {
+ assert(
+ (point.isParent() || point.getSuccessor() == &getParentOp().getAfter()) &&
+ "condition op can only exit the loop or branch to the after"
+ "region");
// Pass all operands except the condition to the successor region.
return getArgsMutable();
}
@@ -426,7 +427,7 @@ void ConditionOp::getSuccessorRegions(
regions.emplace_back(&whileOp.getAfter(),
whileOp.getAfter().getArguments());
if (!boolAttr || !boolAttr.getValue())
- regions.emplace_back(whileOp.getResults());
+ regions.emplace_back(whileOp.getOperation(), whileOp.getResults());
}
//===----------------------------------------------------------------------===//
@@ -749,7 +750,7 @@ ForOp mlir::scf::getForInductionVarOwner(Value val) {
return dyn_cast_or_null<ForOp>(containingOp);
}
-OperandRange ForOp::getEntrySuccessorOperands(RegionBranchPoint point) {
+OperandRange ForOp::getEntrySuccessorOperands(RegionSuccessor successor) {
return getInitArgs();
}
@@ -759,7 +760,7 @@ void ForOp::getSuccessorRegions(RegionBranchPoint point,
// back into the operation itself. It is possible for loop not to enter the
// body.
regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
- regions.push_back(RegionSuccessor(getResults()));
+ regions.push_back(RegionSuccessor(getOperation(), getResults()));
}
SmallVector<Region *> ForallOp::getLoopRegions() { return {&getRegion()}; }
@@ -2053,9 +2054,10 @@ void ForallOp::getSuccessorRegions(RegionBranchPoint point,
// parallel by multiple threads. We should not expect to branch back into
// the forall body after the region's execution is complete.
if (point.isParent())
- regions.push_back(RegionSuccessor(&getRegion()));
+ regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
else
- regions.push_back(RegionSuccessor());
+ regions.push_back(
+ RegionSuccessor(getOperation(), getOperation()->getResults()));
}
//===----------------------------------------------------------------------===//
@@ -2333,9 +2335,10 @@ void IfOp::print(OpAsmPrinter &p) {
void IfOp::getSuccessorRegions(RegionBranchPoint point,
SmallVectorImpl<RegionSuccessor> &regions) {
- // The `then` and the `else` region branch back to the parent operation.
+ // The `then` and the `else` region branch back to the parent operation or one
+ // of the recursive parent operations (early exit case).
if (!point.isParent()) {
- regions.push_back(RegionSuccessor(getResults()));
+ regions.push_back(RegionSuccessor(getOperation(), getResults()));
return;
}
@@ -2344,7 +2347,8 @@ void IfOp::getSuccessorRegions(RegionBranchPoint point,
// Don't consider the else region if it is empty.
Region *elseRegion = &this->getElseRegion();
if (elseRegion->empty())
- regions.push_back(RegionSuccessor());
+ regions.push_back(
+ RegionSuccessor(getOperation(), getOperation()->getResults()));
else
regions.push_back(RegionSuccessor(elseRegion));
}
@@ -2361,7 +2365,7 @@ void IfOp::getEntrySuccessorRegions(ArrayRef<Attribute> operands,
if (!getElseRegion().empty())
regions.emplace_back(&getElseRegion());
else
- regions.emplace_back(getResults());
+ regions.emplace_back(getOperation(), getResults());
}
}
@@ -3385,7 +3389,8 @@ void ParallelOp::getSuccessorRegions(
// back into the operation itself. It is possible for loop not to enter the
// body.
regions.push_back(RegionSuccessor(&getRegion()));
- regions.push_back(RegionSuccessor());
+ regions.push_back(RegionSuccessor(
+ getOperation(), ResultRange{getResults().end(), getResults().end()}));
}
//===----------------------------------------------------------------------===//
@@ -3431,7 +3436,7 @@ LogicalResult ReduceOp::verifyRegions() {
}
MutableOperandRange
-ReduceOp::getMutableSuccessorOperands(RegionBranchPoint point) {
+ReduceOp::getMutableSuccessorOperands(RegionSuccessor point) {
// No operands are forwarded to the next iteration.
return MutableOperandRange(getOperation(), /*start=*/0, /*length=*/0);
}
@@ -3514,8 +3519,8 @@ Block::BlockArgListType WhileOp::getRegionIterArgs() {
return getBeforeArguments();
}
-OperandRange WhileOp::getEntrySuccessorOperands(RegionBranchPoint point) {
- assert(point == getBefore() &&
+OperandRange WhileOp::getEntrySuccessorOperands(RegionSuccessor successor) {
+ assert(successor.getSuccessor() == &getBefore() &&
"WhileOp is expected to branch only to the first region");
return getInits();
}
@@ -3528,15 +3533,18 @@ void WhileOp::getSuccessorRegions(RegionBranchPoint point,
return;
}
- assert(llvm::is_contained({&getAfter(), &getBefore()}, point) &&
+ assert(llvm::is_contained(
+ {&getAfter(), &getBefore()},
+ point.getTerminatorPredecessorOrNull()->getParentRegion()) &&
"there are only two regions in a WhileOp");
// The body region always branches back to the condition region.
- if (point == getAfter()) {
+ if (point.getTerminatorPredecessorOrNull()->getParentRegion() ==
+ &getAfter()) {
regions.emplace_back(&getBefore(), getBefore().getArguments());
return;
}
- regions.emplace_back(getResults());
+ regions.emplace_back(getOperation(), getResults());
regions.emplace_back(&getAfter(), getAfter().getArguments());
}
@@ -4445,7 +4453,7 @@ void IndexSwitchOp::getSuccessorRegions(
RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &successors) {
// All regions branch back to the parent op.
if (!point.isParent()) {
- successors.emplace_back(getResults());
+ successors.emplace_back(getOperation(), getResults());
return;
}
diff --git a/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp b/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp
index ae52af5..ddcbda8 100644
--- a/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp
@@ -23,7 +23,6 @@ namespace mlir {
#include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
} // namespace mlir
-using namespace llvm;
using namespace mlir;
using scf::ForOp;
using scf::WhileOp;
diff --git a/mlir/lib/Dialect/SCF/Transforms/ForallToFor.cpp b/mlir/lib/Dialect/SCF/Transforms/ForallToFor.cpp
index a2f03f1..00bef70 100644
--- a/mlir/lib/Dialect/SCF/Transforms/ForallToFor.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/ForallToFor.cpp
@@ -21,7 +21,6 @@ namespace mlir {
#include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
} // namespace mlir
-using namespace llvm;
using namespace mlir;
using scf::LoopNest;
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 5ba8289..f0f22e5 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -346,7 +346,7 @@ void AssumingOp::getSuccessorRegions(
// parent, so return the correct RegionSuccessor purely based on the index
// being None or 0.
if (!point.isParent()) {
- regions.push_back(RegionSuccessor(getResults()));
+ regions.push_back(RegionSuccessor(getOperation(), getResults()));
return;
}
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 1a9d9e1..3962e3e 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -2597,7 +2597,7 @@ std::optional<MutableArrayRef<OpOperand>> IterateOp::getYieldedValuesMutable() {
std::optional<ResultRange> IterateOp::getLoopResults() { return getResults(); }
-OperandRange IterateOp::getEntrySuccessorOperands(RegionBranchPoint point) {
+OperandRange IterateOp::getEntrySuccessorOperands(RegionSuccessor successor) {
return getInitArgs();
}
@@ -2607,7 +2607,7 @@ void IterateOp::getSuccessorRegions(RegionBranchPoint point,
// or back into the operation itself.
regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
// It is possible for loop not to enter the body.
- regions.push_back(RegionSuccessor(getResults()));
+ regions.push_back(RegionSuccessor(getOperation(), getResults()));
}
void CoIterateOp::build(OpBuilder &builder, OperationState &odsState,
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp
index f53d272..ffa8b40 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp
@@ -152,19 +152,20 @@ IterationGraphSorter IterationGraphSorter::fromGenericOp(
}
IterationGraphSorter::IterationGraphSorter(
- SmallVector<Value> &&ins, SmallVector<AffineMap> &&loop2InsLvl, Value out,
- AffineMap loop2OutLvl, SmallVector<utils::IteratorType> &&iterTypes,
+ SmallVector<Value> &&insArg, SmallVector<AffineMap> &&loop2InsLvlArg,
+ Value out, AffineMap loop2OutLvl,
+ SmallVector<utils::IteratorType> &&iterTypesArg,
sparse_tensor::LoopOrderingStrategy strategy)
- : ins(std::move(ins)), loop2InsLvl(std::move(loop2InsLvl)), out(out),
- loop2OutLvl(loop2OutLvl), iterTypes(std::move(iterTypes)),
+ : ins(std::move(insArg)), loop2InsLvl(std::move(loop2InsLvlArg)), out(out),
+ loop2OutLvl(loop2OutLvl), iterTypes(std::move(iterTypesArg)),
strategy(strategy) {
// One map per tensor.
- assert(this->loop2InsLvl.size() == this->ins.size());
+ assert(loop2InsLvl.size() == ins.size());
// All the affine maps have the same number of dimensions (loops).
assert(llvm::all_equal(llvm::map_range(
- this->loop2InsLvl, [](AffineMap m) { return m.getNumDims(); })));
+ loop2InsLvl, [](AffineMap m) { return m.getNumDims(); })));
// The number of results of the map should match the rank of the tensor.
- assert(llvm::all_of(llvm::zip(this->loop2InsLvl, this->ins), [](auto mvPair) {
+ assert(llvm::all_of(llvm::zip(loop2InsLvl, ins), [](auto mvPair) {
auto [m, v] = mvPair;
// For ranked types the rank must match.
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.h
index b2a16e9..35e58ed 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.h
@@ -59,10 +59,10 @@ public:
private:
// Private constructor.
- IterationGraphSorter(SmallVector<Value> &&ins,
- SmallVector<AffineMap> &&loop2InsLvl, Value out,
+ IterationGraphSorter(SmallVector<Value> &&insArg,
+ SmallVector<AffineMap> &&loop2InsLvlArg, Value out,
AffineMap loop2OutLvl,
- SmallVector<utils::IteratorType> &&iterTypes,
+ SmallVector<utils::IteratorType> &&iterTypesArg,
sparse_tensor::LoopOrderingStrategy strategy =
sparse_tensor::LoopOrderingStrategy::kDefault);
diff --git a/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp
index 1e3b377..549ac7a 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp
@@ -77,7 +77,7 @@ FailureOr<TilingResult> tensor::replaceInsertSlicesWithTiledConsumer(
dyn_cast<TilingInterface>(consumerOperands.front()->getOwner());
if (!consumerOp)
return failure();
- for (auto opOperand : consumerOperands.drop_front()) {
+ for (auto *opOperand : consumerOperands.drop_front()) {
if (opOperand->getOwner() != consumerOp) {
LLVM_DEBUG({
llvm::dbgs()
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 365afab..062606e 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -96,9 +96,9 @@ ensurePayloadIsSeparateFromTransform(transform::TransformOpInterface transform,
// AlternativesOp
//===----------------------------------------------------------------------===//
-OperandRange
-transform::AlternativesOp::getEntrySuccessorOperands(RegionBranchPoint point) {
- if (!point.isParent() && getOperation()->getNumOperands() == 1)
+OperandRange transform::AlternativesOp::getEntrySuccessorOperands(
+ RegionSuccessor successor) {
+ if (!successor.isParent() && getOperation()->getNumOperands() == 1)
return getOperation()->getOperands();
return OperandRange(getOperation()->operand_end(),
getOperation()->operand_end());
@@ -107,15 +107,18 @@ transform::AlternativesOp::getEntrySuccessorOperands(RegionBranchPoint point) {
void transform::AlternativesOp::getSuccessorRegions(
RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
for (Region &alternative : llvm::drop_begin(
- getAlternatives(),
- point.isParent() ? 0
- : point.getRegionOrNull()->getRegionNumber() + 1)) {
+ getAlternatives(), point.isParent()
+ ? 0
+ : point.getTerminatorPredecessorOrNull()
+ ->getParentRegion()
+ ->getRegionNumber() +
+ 1)) {
regions.emplace_back(&alternative, !getOperands().empty()
? alternative.getArguments()
: Block::BlockArgListType());
}
if (!point.isParent())
- regions.emplace_back(getOperation()->getResults());
+ regions.emplace_back(getOperation(), getOperation()->getResults());
}
void transform::AlternativesOp::getRegionInvocationBounds(
@@ -1740,16 +1743,18 @@ void transform::ForeachOp::getSuccessorRegions(
}
// Branch back to the region or the parent.
- assert(point == getBody() && "unexpected region index");
+ assert(point.getTerminatorPredecessorOrNull()->getParentRegion() ==
+ &getBody() &&
+ "unexpected region index");
regions.emplace_back(bodyRegion, bodyRegion->getArguments());
- regions.emplace_back();
+ regions.emplace_back(getOperation(), getOperation()->getResults());
}
OperandRange
-transform::ForeachOp::getEntrySuccessorOperands(RegionBranchPoint point) {
+transform::ForeachOp::getEntrySuccessorOperands(RegionSuccessor successor) {
// Each block argument handle is mapped to a subset (one op to be precise)
// of the payload of the corresponding `targets` operand of ForeachOp.
- assert(point == getBody() && "unexpected region index");
+ assert(successor.getSuccessor() == &getBody() && "unexpected region index");
return getOperation()->getOperands();
}
@@ -2948,8 +2953,8 @@ void transform::SequenceOp::getEffects(
}
OperandRange
-transform::SequenceOp::getEntrySuccessorOperands(RegionBranchPoint point) {
- assert(point == getBody() && "unexpected region index");
+transform::SequenceOp::getEntrySuccessorOperands(RegionSuccessor successor) {
+ assert(successor.getSuccessor() == &getBody() && "unexpected region index");
if (getOperation()->getNumOperands() > 0)
return getOperation()->getOperands();
return OperandRange(getOperation()->operand_end(),
@@ -2966,8 +2971,10 @@ void transform::SequenceOp::getSuccessorRegions(
return;
}
- assert(point == getBody() && "unexpected region index");
- regions.emplace_back(getOperation()->getResults());
+ assert(point.getTerminatorPredecessorOrNull()->getParentRegion() ==
+ &getBody() &&
+ "unexpected region index");
+ regions.emplace_back(getOperation(), getOperation()->getResults());
}
void transform::SequenceOp::getRegionInvocationBounds(
diff --git a/mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp b/mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp
index c627158..f727118 100644
--- a/mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp
+++ b/mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp
@@ -9,6 +9,7 @@
#include "mlir/Dialect/Transform/IR/TransformOps.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/IR/OpImplementation.h"
+#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.h"
@@ -112,7 +113,7 @@ static void printAlternativesOpSelectedRegion(OpAsmPrinter &printer,
}
OperandRange transform::tune::AlternativesOp::getEntrySuccessorOperands(
- RegionBranchPoint point) {
+ RegionSuccessor successor) {
// No operands will be forwarded to the region(s).
return getOperands().slice(0, 0);
}
@@ -128,7 +129,7 @@ void transform::tune::AlternativesOp::getSuccessorRegions(
for (Region &alternative : getAlternatives())
regions.emplace_back(&alternative, Block::BlockArgListType());
else
- regions.emplace_back(getOperation()->getResults());
+ regions.emplace_back(getOperation(), getOperation()->getResults());
}
void transform::tune::AlternativesOp::getRegionInvocationBounds(
diff --git a/mlir/lib/IR/Diagnostics.cpp b/mlir/lib/IR/Diagnostics.cpp
index 776b5c6..f4c9242 100644
--- a/mlir/lib/IR/Diagnostics.cpp
+++ b/mlir/lib/IR/Diagnostics.cpp
@@ -138,6 +138,10 @@ Diagnostic &Diagnostic::operator<<(Operation &op) {
return appendOp(op, OpPrintingFlags());
}
+Diagnostic &Diagnostic::operator<<(OpWithFlags op) {
+ return appendOp(*op.getOperation(), op.flags());
+}
+
Diagnostic &Diagnostic::appendOp(Operation &op, const OpPrintingFlags &flags) {
std::string str;
llvm::raw_string_ostream os(str);
diff --git a/mlir/lib/IR/Region.cpp b/mlir/lib/IR/Region.cpp
index 46b6298..15a941f 100644
--- a/mlir/lib/IR/Region.cpp
+++ b/mlir/lib/IR/Region.cpp
@@ -253,6 +253,21 @@ void Region::OpIterator::skipOverBlocksWithNoOps() {
operation = block->begin();
}
+llvm::raw_ostream &mlir::operator<<(llvm::raw_ostream &os, Region &region) {
+ if (!region.getParentOp()) {
+ os << "Region has no parent op";
+ } else {
+ os << "Region #" << region.getRegionNumber() << " in operation "
+ << region.getParentOp()->getName();
+ }
+ for (auto it : llvm::enumerate(region.getBlocks())) {
+ os << "\n Block #" << it.index() << ":";
+ for (Operation &op : it.value().getOperations())
+ os << "\n " << OpWithFlags(&op, OpPrintingFlags().skipRegions());
+ }
+ return os;
+}
+
//===----------------------------------------------------------------------===//
// RegionRange
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
index ca3f766..1e56810 100644
--- a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
+++ b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
@@ -9,7 +9,9 @@
#include <utility>
#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Operation.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
+#include "llvm/Support/DebugLog.h"
using namespace mlir;
@@ -38,20 +40,31 @@ SuccessorOperands::SuccessorOperands(unsigned int producedOperandCount,
std::optional<BlockArgument>
detail::getBranchSuccessorArgument(const SuccessorOperands &operands,
unsigned operandIndex, Block *successor) {
+ LDBG() << "Getting branch successor argument for operand index "
+ << operandIndex << " in successor block";
+
OperandRange forwardedOperands = operands.getForwardedOperands();
// Check that the operands are valid.
- if (forwardedOperands.empty())
+ if (forwardedOperands.empty()) {
+ LDBG() << "No forwarded operands, returning nullopt";
return std::nullopt;
+ }
// Check to ensure that this operand is within the range.
unsigned operandsStart = forwardedOperands.getBeginOperandIndex();
if (operandIndex < operandsStart ||
- operandIndex >= (operandsStart + forwardedOperands.size()))
+ operandIndex >= (operandsStart + forwardedOperands.size())) {
+ LDBG() << "Operand index " << operandIndex << " out of range ["
+ << operandsStart << ", "
+ << (operandsStart + forwardedOperands.size())
+ << "), returning nullopt";
return std::nullopt;
+ }
// Index the successor.
unsigned argIndex =
operands.getProducedOperandCount() + operandIndex - operandsStart;
+ LDBG() << "Computed argument index " << argIndex << " for successor block";
return successor->getArgument(argIndex);
}
@@ -59,9 +72,15 @@ detail::getBranchSuccessorArgument(const SuccessorOperands &operands,
LogicalResult
detail::verifyBranchSuccessorOperands(Operation *op, unsigned succNo,
const SuccessorOperands &operands) {
+ LDBG() << "Verifying branch successor operands for successor #" << succNo
+ << " in operation " << op->getName();
+
// Check the count.
unsigned operandCount = operands.size();
Block *destBB = op->getSuccessor(succNo);
+ LDBG() << "Branch has " << operandCount << " operands, target block has "
+ << destBB->getNumArguments() << " arguments";
+
if (operandCount != destBB->getNumArguments())
return op->emitError() << "branch has " << operandCount
<< " operands for successor #" << succNo
@@ -69,13 +88,22 @@ detail::verifyBranchSuccessorOperands(Operation *op, unsigned succNo,
<< destBB->getNumArguments();
// Check the types.
+ LDBG() << "Checking type compatibility for "
+ << (operandCount - operands.getProducedOperandCount())
+ << " forwarded operands";
for (unsigned i = operands.getProducedOperandCount(); i != operandCount;
++i) {
- if (!cast<BranchOpInterface>(op).areTypesCompatible(
- operands[i].getType(), destBB->getArgument(i).getType()))
+ Type operandType = operands[i].getType();
+ Type argType = destBB->getArgument(i).getType();
+ LDBG() << "Checking type compatibility: operand type " << operandType
+ << " vs argument type " << argType;
+
+ if (!cast<BranchOpInterface>(op).areTypesCompatible(operandType, argType))
return op->emitError() << "type mismatch for bb argument #" << i
<< " of successor #" << succNo;
}
+
+ LDBG() << "Branch successor operand verification successful";
return success();
}
@@ -126,15 +154,15 @@ LogicalResult detail::verifyRegionBranchWeights(Operation *op) {
static InFlightDiagnostic &printRegionEdgeName(InFlightDiagnostic &diag,
RegionBranchPoint sourceNo,
- RegionBranchPoint succRegionNo) {
+ RegionSuccessor succRegionNo) {
diag << "from ";
- if (Region *region = sourceNo.getRegionOrNull())
- diag << "Region #" << region->getRegionNumber();
+ if (Operation *op = sourceNo.getTerminatorPredecessorOrNull())
+ diag << "Operation " << op->getName();
else
diag << "parent operands";
diag << " to ";
- if (Region *region = succRegionNo.getRegionOrNull())
+ if (Region *region = succRegionNo.getSuccessor())
diag << "Region #" << region->getRegionNumber();
else
diag << "parent results";
@@ -145,13 +173,12 @@ static InFlightDiagnostic &printRegionEdgeName(InFlightDiagnostic &diag,
/// `sourcePoint`. `getInputsTypesForRegion` is a function that returns the
/// types of the inputs that flow to a successor region.
static LogicalResult
-verifyTypesAlongAllEdges(Operation *op, RegionBranchPoint sourcePoint,
- function_ref<FailureOr<TypeRange>(RegionBranchPoint)>
+verifyTypesAlongAllEdges(RegionBranchOpInterface branchOp,
+ RegionBranchPoint sourcePoint,
+ function_ref<FailureOr<TypeRange>(RegionSuccessor)>
getInputsTypesForRegion) {
- auto regionInterface = cast<RegionBranchOpInterface>(op);
-
SmallVector<RegionSuccessor, 2> successors;
- regionInterface.getSuccessorRegions(sourcePoint, successors);
+ branchOp.getSuccessorRegions(sourcePoint, successors);
for (RegionSuccessor &succ : successors) {
FailureOr<TypeRange> sourceTypes = getInputsTypesForRegion(succ);
@@ -160,10 +187,14 @@ verifyTypesAlongAllEdges(Operation *op, RegionBranchPoint sourcePoint,
TypeRange succInputsTypes = succ.getSuccessorInputs().getTypes();
if (sourceTypes->size() != succInputsTypes.size()) {
- InFlightDiagnostic diag = op->emitOpError("region control flow edge ");
+ InFlightDiagnostic diag =
+ branchOp->emitOpError("region control flow edge ");
+ std::string succStr;
+ llvm::raw_string_ostream os(succStr);
+ os << succ;
return printRegionEdgeName(diag, sourcePoint, succ)
<< ": source has " << sourceTypes->size()
- << " operands, but target successor needs "
+ << " operands, but target successor " << os.str() << " needs "
<< succInputsTypes.size();
}
@@ -171,8 +202,10 @@ verifyTypesAlongAllEdges(Operation *op, RegionBranchPoint sourcePoint,
llvm::enumerate(llvm::zip(*sourceTypes, succInputsTypes))) {
Type sourceType = std::get<0>(typesIdx.value());
Type inputType = std::get<1>(typesIdx.value());
- if (!regionInterface.areTypesCompatible(sourceType, inputType)) {
- InFlightDiagnostic diag = op->emitOpError("along control flow edge ");
+
+ if (!branchOp.areTypesCompatible(sourceType, inputType)) {
+ InFlightDiagnostic diag =
+ branchOp->emitOpError("along control flow edge ");
return printRegionEdgeName(diag, sourcePoint, succ)
<< ": source type #" << typesIdx.index() << " " << sourceType
<< " should match input type #" << typesIdx.index() << " "
@@ -180,6 +213,7 @@ verifyTypesAlongAllEdges(Operation *op, RegionBranchPoint sourcePoint,
}
}
}
+
return success();
}
@@ -187,34 +221,18 @@ verifyTypesAlongAllEdges(Operation *op, RegionBranchPoint sourcePoint,
LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) {
auto regionInterface = cast<RegionBranchOpInterface>(op);
- auto inputTypesFromParent = [&](RegionBranchPoint point) -> TypeRange {
- return regionInterface.getEntrySuccessorOperands(point).getTypes();
+ auto inputTypesFromParent = [&](RegionSuccessor successor) -> TypeRange {
+ return regionInterface.getEntrySuccessorOperands(successor).getTypes();
};
// Verify types along control flow edges originating from the parent.
- if (failed(verifyTypesAlongAllEdges(op, RegionBranchPoint::parent(),
- inputTypesFromParent)))
+ if (failed(verifyTypesAlongAllEdges(
+ regionInterface, RegionBranchPoint::parent(), inputTypesFromParent)))
return failure();
- auto areTypesCompatible = [&](TypeRange lhs, TypeRange rhs) {
- if (lhs.size() != rhs.size())
- return false;
- for (auto types : llvm::zip(lhs, rhs)) {
- if (!regionInterface.areTypesCompatible(std::get<0>(types),
- std::get<1>(types))) {
- return false;
- }
- }
- return true;
- };
-
// Verify types along control flow edges originating from each region.
for (Region &region : op->getRegions()) {
-
- // Since there can be multiple terminators implementing the
- // `RegionBranchTerminatorOpInterface`, all should have the same operand
- // types when passing them to the same region.
-
+ // Collect all return-like terminators in the region.
SmallVector<RegionBranchTerminatorOpInterface> regionReturnOps;
for (Block &block : region)
if (!block.empty())
@@ -227,33 +245,20 @@ LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) {
if (regionReturnOps.empty())
continue;
- auto inputTypesForRegion =
- [&](RegionBranchPoint point) -> FailureOr<TypeRange> {
- std::optional<OperandRange> regionReturnOperands;
- for (RegionBranchTerminatorOpInterface regionReturnOp : regionReturnOps) {
- auto terminatorOperands = regionReturnOp.getSuccessorOperands(point);
-
- if (!regionReturnOperands) {
- regionReturnOperands = terminatorOperands;
- continue;
- }
-
- // Found more than one ReturnLike terminator. Make sure the operand
- // types match with the first one.
- if (!areTypesCompatible(regionReturnOperands->getTypes(),
- terminatorOperands.getTypes())) {
- InFlightDiagnostic diag = op->emitOpError("along control flow edge");
- return printRegionEdgeName(diag, region, point)
- << " operands mismatch between return-like terminators";
- }
- }
-
- // All successors get the same set of operand types.
- return TypeRange(regionReturnOperands->getTypes());
- };
-
- if (failed(verifyTypesAlongAllEdges(op, region, inputTypesForRegion)))
- return failure();
+ // Verify types along control flow edges originating from each return-like
+ // terminator.
+ for (RegionBranchTerminatorOpInterface regionReturnOp : regionReturnOps) {
+
+ auto inputTypesForRegion =
+ [&](RegionSuccessor successor) -> FailureOr<TypeRange> {
+ OperandRange terminatorOperands =
+ regionReturnOp.getSuccessorOperands(successor);
+ return TypeRange(terminatorOperands.getTypes());
+ };
+ if (failed(verifyTypesAlongAllEdges(regionInterface, regionReturnOp,
+ inputTypesForRegion)))
+ return failure();
+ }
}
return success();
@@ -272,31 +277,74 @@ using StopConditionFn = function_ref<bool(Region *, ArrayRef<bool> visited)>;
static bool traverseRegionGraph(Region *begin,
StopConditionFn stopConditionFn) {
auto op = cast<RegionBranchOpInterface>(begin->getParentOp());
+ LDBG() << "Starting region graph traversal from region #"
+ << begin->getRegionNumber() << " in operation " << op->getName();
+
SmallVector<bool> visited(op->getNumRegions(), false);
visited[begin->getRegionNumber()] = true;
+ LDBG() << "Initialized visited array with " << op->getNumRegions()
+ << " regions";
// Retrieve all successors of the region and enqueue them in the worklist.
SmallVector<Region *> worklist;
auto enqueueAllSuccessors = [&](Region *region) {
- SmallVector<RegionSuccessor> successors;
- op.getSuccessorRegions(region, successors);
- for (RegionSuccessor successor : successors)
- if (!successor.isParent())
- worklist.push_back(successor.getSuccessor());
+ LDBG() << "Enqueuing successors for region #" << region->getRegionNumber();
+ SmallVector<Attribute> operandAttributes(op->getNumOperands());
+ for (Block &block : *region) {
+ if (block.empty())
+ continue;
+ auto terminator =
+ dyn_cast<RegionBranchTerminatorOpInterface>(block.back());
+ if (!terminator)
+ continue;
+ SmallVector<RegionSuccessor> successors;
+ operandAttributes.resize(terminator->getNumOperands());
+ terminator.getSuccessorRegions(operandAttributes, successors);
+ LDBG() << "Found " << successors.size()
+ << " successors from terminator in block";
+ for (RegionSuccessor successor : successors) {
+ if (!successor.isParent()) {
+ worklist.push_back(successor.getSuccessor());
+ LDBG() << "Added region #"
+ << successor.getSuccessor()->getRegionNumber()
+ << " to worklist";
+ } else {
+ LDBG() << "Skipping parent successor";
+ }
+ }
+ }
};
enqueueAllSuccessors(begin);
+ LDBG() << "Initial worklist size: " << worklist.size();
// Process all regions in the worklist via DFS.
while (!worklist.empty()) {
Region *nextRegion = worklist.pop_back_val();
- if (stopConditionFn(nextRegion, visited))
+ LDBG() << "Processing region #" << nextRegion->getRegionNumber()
+ << " from worklist (remaining: " << worklist.size() << ")";
+
+ if (stopConditionFn(nextRegion, visited)) {
+ LDBG() << "Stop condition met for region #"
+ << nextRegion->getRegionNumber() << ", returning true";
return true;
- if (visited[nextRegion->getRegionNumber()])
+ }
+ llvm::dbgs() << "Region: " << nextRegion << "\n";
+ if (!nextRegion->getParentOp()) {
+ llvm::errs() << "Region " << *nextRegion << " has no parent op\n";
+ return false;
+ }
+ if (visited[nextRegion->getRegionNumber()]) {
+ LDBG() << "Region #" << nextRegion->getRegionNumber()
+ << " already visited, skipping";
continue;
+ }
visited[nextRegion->getRegionNumber()] = true;
+ LDBG() << "Marking region #" << nextRegion->getRegionNumber()
+ << " as visited";
enqueueAllSuccessors(nextRegion);
}
+ LDBG() << "Traversal completed, returning false";
return false;
}
@@ -322,18 +370,26 @@ static bool isRegionReachable(Region *begin, Region *r) {
/// mutually exclusive if they are not reachable from each other as per
/// RegionBranchOpInterface::getSuccessorRegions.
bool mlir::insideMutuallyExclusiveRegions(Operation *a, Operation *b) {
+ LDBG() << "Checking if operations are in mutually exclusive regions: "
+ << a->getName() << " and " << b->getName();
+
assert(a && "expected non-empty operation");
assert(b && "expected non-empty operation");
auto branchOp = a->getParentOfType<RegionBranchOpInterface>();
while (branchOp) {
+ LDBG() << "Checking branch operation " << branchOp->getName();
+
// Check if b is inside branchOp. (We already know that a is.)
if (!branchOp->isProperAncestor(b)) {
+ LDBG() << "Operation b is not inside branchOp, checking next ancestor";
// Check next enclosing RegionBranchOpInterface.
branchOp = branchOp->getParentOfType<RegionBranchOpInterface>();
continue;
}
+ LDBG() << "Both operations are inside branchOp, finding their regions";
+
// b is contained in branchOp. Retrieve the regions in which `a` and `b`
// are contained.
Region *regionA = nullptr, *regionB = nullptr;
@@ -341,63 +397,136 @@ bool mlir::insideMutuallyExclusiveRegions(Operation *a, Operation *b) {
if (r.findAncestorOpInRegion(*a)) {
assert(!regionA && "already found a region for a");
regionA = &r;
+ LDBG() << "Found region #" << r.getRegionNumber() << " for operation a";
}
if (r.findAncestorOpInRegion(*b)) {
assert(!regionB && "already found a region for b");
regionB = &r;
+ LDBG() << "Found region #" << r.getRegionNumber() << " for operation b";
}
}
assert(regionA && regionB && "could not find region of op");
+ LDBG() << "Region A: #" << regionA->getRegionNumber() << ", Region B: #"
+ << regionB->getRegionNumber();
+
// `a` and `b` are in mutually exclusive regions if both regions are
// distinct and neither region is reachable from the other region.
- return regionA != regionB && !isRegionReachable(regionA, regionB) &&
- !isRegionReachable(regionB, regionA);
+ bool regionsAreDistinct = (regionA != regionB);
+ bool aNotReachableFromB = !isRegionReachable(regionA, regionB);
+ bool bNotReachableFromA = !isRegionReachable(regionB, regionA);
+
+ LDBG() << "Regions distinct: " << regionsAreDistinct
+ << ", A not reachable from B: " << aNotReachableFromB
+ << ", B not reachable from A: " << bNotReachableFromA;
+
+ bool mutuallyExclusive =
+ regionsAreDistinct && aNotReachableFromB && bNotReachableFromA;
+ LDBG() << "Operations are mutually exclusive: " << mutuallyExclusive;
+
+ return mutuallyExclusive;
}
// Could not find a common RegionBranchOpInterface among a's and b's
// ancestors.
+ LDBG() << "No common RegionBranchOpInterface found, operations are not "
+ "mutually exclusive";
return false;
}
bool RegionBranchOpInterface::isRepetitiveRegion(unsigned index) {
+ LDBG() << "Checking if region #" << index << " is repetitive in operation "
+ << getOperation()->getName();
+
Region *region = &getOperation()->getRegion(index);
- return isRegionReachable(region, region);
+ bool isRepetitive = isRegionReachable(region, region);
+
+ LDBG() << "Region #" << index << " is repetitive: " << isRepetitive;
+ return isRepetitive;
}
bool RegionBranchOpInterface::hasLoop() {
+ LDBG() << "Checking if operation " << getOperation()->getName()
+ << " has loops";
+
SmallVector<RegionSuccessor> entryRegions;
getSuccessorRegions(RegionBranchPoint::parent(), entryRegions);
- for (RegionSuccessor successor : entryRegions)
- if (!successor.isParent() &&
- traverseRegionGraph(successor.getSuccessor(),
- [](Region *nextRegion, ArrayRef<bool> visited) {
- // Interrupt traversal if the region was already
- // visited.
- return visited[nextRegion->getRegionNumber()];
- }))
- return true;
+ LDBG() << "Found " << entryRegions.size() << " entry regions";
+
+ for (RegionSuccessor successor : entryRegions) {
+ if (!successor.isParent()) {
+ LDBG() << "Checking entry region #"
+ << successor.getSuccessor()->getRegionNumber() << " for loops";
+
+ bool hasLoop =
+ traverseRegionGraph(successor.getSuccessor(),
+ [](Region *nextRegion, ArrayRef<bool> visited) {
+ // Interrupt traversal if the region was already
+ // visited.
+ return visited[nextRegion->getRegionNumber()];
+ });
+
+ if (hasLoop) {
+ LDBG() << "Found loop in entry region #"
+ << successor.getSuccessor()->getRegionNumber();
+ return true;
+ }
+ } else {
+ LDBG() << "Skipping parent successor";
+ }
+ }
+
+ LDBG() << "No loops found in operation";
return false;
}
Region *mlir::getEnclosingRepetitiveRegion(Operation *op) {
+ LDBG() << "Finding enclosing repetitive region for operation "
+ << op->getName();
+
while (Region *region = op->getParentRegion()) {
+ LDBG() << "Checking region #" << region->getRegionNumber()
+ << " in operation " << region->getParentOp()->getName();
+
op = region->getParentOp();
- if (auto branchOp = dyn_cast<RegionBranchOpInterface>(op))
- if (branchOp.isRepetitiveRegion(region->getRegionNumber()))
+ if (auto branchOp = dyn_cast<RegionBranchOpInterface>(op)) {
+ LDBG()
+ << "Found RegionBranchOpInterface, checking if region is repetitive";
+ if (branchOp.isRepetitiveRegion(region->getRegionNumber())) {
+ LDBG() << "Found repetitive region #" << region->getRegionNumber();
return region;
+ }
+ } else {
+ LDBG() << "Parent operation does not implement RegionBranchOpInterface";
+ }
}
+
+ LDBG() << "No enclosing repetitive region found";
return nullptr;
}
Region *mlir::getEnclosingRepetitiveRegion(Value value) {
+ LDBG() << "Finding enclosing repetitive region for value";
+
Region *region = value.getParentRegion();
while (region) {
+ LDBG() << "Checking region #" << region->getRegionNumber()
+ << " in operation " << region->getParentOp()->getName();
+
Operation *op = region->getParentOp();
- if (auto branchOp = dyn_cast<RegionBranchOpInterface>(op))
- if (branchOp.isRepetitiveRegion(region->getRegionNumber()))
+ if (auto branchOp = dyn_cast<RegionBranchOpInterface>(op)) {
+ LDBG()
+ << "Found RegionBranchOpInterface, checking if region is repetitive";
+ if (branchOp.isRepetitiveRegion(region->getRegionNumber())) {
+ LDBG() << "Found repetitive region #" << region->getRegionNumber();
return region;
+ }
+ } else {
+ LDBG() << "Parent operation does not implement RegionBranchOpInterface";
+ }
region = op->getParentRegion();
}
+
+ LDBG() << "No enclosing repetitive region found for value";
return nullptr;
}
diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp
index e0c65b0..41f3f9d 100644
--- a/mlir/lib/Transforms/RemoveDeadValues.cpp
+++ b/mlir/lib/Transforms/RemoveDeadValues.cpp
@@ -432,8 +432,7 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
// Return the successors of `region` if the latter is not null. Else return
// the successors of `regionBranchOp`.
- auto getSuccessors = [&](Region *region = nullptr) {
- auto point = region ? region : RegionBranchPoint::parent();
+ auto getSuccessors = [&](RegionBranchPoint point) {
SmallVector<RegionSuccessor> successors;
regionBranchOp.getSuccessorRegions(point, successors);
return successors;
@@ -456,7 +455,8 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
// `nonForwardedOperands`.
auto markNonForwardedOperands = [&](BitVector &nonForwardedOperands) {
nonForwardedOperands.resize(regionBranchOp->getNumOperands(), true);
- for (const RegionSuccessor &successor : getSuccessors()) {
+ for (const RegionSuccessor &successor :
+ getSuccessors(RegionBranchPoint::parent())) {
for (OpOperand *opOperand : getForwardedOpOperands(successor))
nonForwardedOperands.reset(opOperand->getOperandNumber());
}
@@ -469,10 +469,13 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
for (Region &region : regionBranchOp->getRegions()) {
if (region.empty())
continue;
+ // TODO: this isn't correct in face of multiple terminators.
Operation *terminator = region.front().getTerminator();
nonForwardedRets[terminator] =
BitVector(terminator->getNumOperands(), true);
- for (const RegionSuccessor &successor : getSuccessors(&region)) {
+ for (const RegionSuccessor &successor :
+ getSuccessors(RegionBranchPoint(
+ cast<RegionBranchTerminatorOpInterface>(terminator)))) {
for (OpOperand *opOperand :
getForwardedOpOperands(successor, terminator))
nonForwardedRets[terminator].reset(opOperand->getOperandNumber());
@@ -489,8 +492,13 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
DenseMap<Region *, BitVector> &argsToKeep, Region *region = nullptr) {
Operation *terminator =
region ? region->front().getTerminator() : nullptr;
+ RegionBranchPoint point =
+ terminator
+ ? RegionBranchPoint(
+ cast<RegionBranchTerminatorOpInterface>(terminator))
+ : RegionBranchPoint::parent();
- for (const RegionSuccessor &successor : getSuccessors(region)) {
+ for (const RegionSuccessor &successor : getSuccessors(point)) {
Region *successorRegion = successor.getSuccessor();
for (auto [opOperand, input] :
llvm::zip(getForwardedOpOperands(successor, terminator),
@@ -517,7 +525,8 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
resultsOrArgsToKeepChanged = false;
// Recompute `resultsToKeep` and `argsToKeep` based on `operandsToKeep`.
- for (const RegionSuccessor &successor : getSuccessors()) {
+ for (const RegionSuccessor &successor :
+ getSuccessors(RegionBranchPoint::parent())) {
Region *successorRegion = successor.getSuccessor();
for (auto [opOperand, input] :
llvm::zip(getForwardedOpOperands(successor),
@@ -551,7 +560,9 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
if (region.empty())
continue;
Operation *terminator = region.front().getTerminator();
- for (const RegionSuccessor &successor : getSuccessors(&region)) {
+ for (const RegionSuccessor &successor :
+ getSuccessors(RegionBranchPoint(
+ cast<RegionBranchTerminatorOpInterface>(terminator)))) {
Region *successorRegion = successor.getSuccessor();
for (auto [opOperand, input] :
llvm::zip(getForwardedOpOperands(successor, terminator),