aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib
diff options
context:
space:
mode:
authorMatthias Springer <me@m-sp.org>2024-03-25 18:57:53 +0900
committerGitHub <noreply@github.com>2024-03-25 18:57:53 +0900
commitdbfc38ed6b3f2a9be0b1a86b2a074aad69eb58a6 (patch)
treeced53182b26727bb5f4842221cfd708b29882fed /mlir/lib
parent336bdf1a255571f8d894e8befe4be7e9141f7541 (diff)
downloadllvm-dbfc38ed6b3f2a9be0b1a86b2a074aad69eb58a6.zip
llvm-dbfc38ed6b3f2a9be0b1a86b2a074aad69eb58a6.tar.gz
llvm-dbfc38ed6b3f2a9be0b1a86b2a074aad69eb58a6.tar.bz2
[mlir][bufferization] Add `BufferOriginAnalysis` (#86461)
This commit adds the `BufferOriginAnalysis`, which can be queried to check if two buffer SSA values originate from the same allocation. This new analysis is used in the buffer deallocation pass to fold away or simplify `bufferization.dealloc` ops more aggressively.

The `BufferOriginAnalysis` is based on the `BufferViewFlowAnalysis`, which collects buffer SSA value "same buffer" dependencies. E.g., given IR such as:
```
%0 = memref.alloc()
%1 = memref.subview %0
%2 = memref.subview %1
```
The `BufferViewFlowAnalysis` will report the following "reverse" dependencies (`resolveReverse`) for `%2`: {`%2`, `%1`, `%0`}. I.e., all buffer SSA values in the reverse use-def chain that originate from the same allocation as `%2`.

The `BufferOriginAnalysis` is built on top of that. It handles only simple cases at the moment and may conservatively return "unknown" around certain IR with branches, memref globals and function arguments.

This analysis enables additional simplifications during `-buffer-deallocation-simplification`. In particular, "regular" scf.for loop nests that yield buffers (or reallocations thereof) in the same order as they appear in the iter_args are now handled much more efficiently. Such IR patterns are generated by the sparse compiler.
Diffstat (limited to 'mlir/lib')
-rw-r--r--mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp80
-rw-r--r--mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp160
2 files changed, 192 insertions, 48 deletions
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
index e307798..954485c 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
@@ -12,8 +12,8 @@
//
//===----------------------------------------------------------------------===//
-#include "mlir/Analysis/AliasAnalysis.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h"
#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
@@ -34,6 +34,14 @@ using namespace mlir::bufferization;
// Helpers
//===----------------------------------------------------------------------===//
+/// Follow `getViewSource()` through every defining ViewLikeOpInterface op of
+/// `value` and return the first value that is not defined by such an op.
+static Value getViewBase(Value value) {
+ while (auto viewLikeOp = value.getDefiningOp<ViewLikeOpInterface>())
+ value = viewLikeOp.getViewSource();
+ return value;
+}
+
static LogicalResult updateDeallocIfChanged(DeallocOp deallocOp,
ValueRange memrefs,
ValueRange conditions,
@@ -49,14 +57,6 @@ static LogicalResult updateDeallocIfChanged(DeallocOp deallocOp,
return success();
}
-/// Given a memref value, return the "base" value by skipping over all
-/// ViewLikeOpInterface ops (if any) in the reverse use-def chain.
-static Value getViewBase(Value value) {
- while (auto viewLikeOp = value.getDefiningOp<ViewLikeOpInterface>())
- value = viewLikeOp.getViewSource();
- return value;
-}
-
/// Return "true" if the given values are guaranteed to be different (and
/// non-aliasing) allocations based on the fact that one value is the result
/// of an allocation and the other value is a block argument of a parent block.
@@ -80,12 +80,14 @@ static bool distinctAllocAndBlockArgument(Value v1, Value v2) {
/// Checks if `memref` may potentially alias a MemRef in `otherList`. It is
/// often a requirement of optimization patterns that there cannot be any
/// aliasing memref in order to perform the desired simplification.
-static bool potentiallyAliasesMemref(AliasAnalysis &analysis,
+static bool potentiallyAliasesMemref(BufferOriginAnalysis &analysis,
ValueRange otherList, Value memref) {
for (auto other : otherList) {
if (distinctAllocAndBlockArgument(other, memref))
continue;
- if (!analysis.alias(other, memref).isNo())
+ std::optional<bool> analysisResult =
+ analysis.isSameAllocation(other, memref);
+ if (!analysisResult.has_value() || analysisResult == true)
return true;
}
return false;
@@ -129,8 +131,8 @@ namespace {
struct RemoveDeallocMemrefsContainedInRetained
: public OpRewritePattern<DeallocOp> {
RemoveDeallocMemrefsContainedInRetained(MLIRContext *context,
- AliasAnalysis &aliasAnalysis)
- : OpRewritePattern<DeallocOp>(context), aliasAnalysis(aliasAnalysis) {}
+ BufferOriginAnalysis &analysis)
+ : OpRewritePattern<DeallocOp>(context), analysis(analysis) {}
/// The passed 'memref' must not have a may-alias relation to any retained
/// memref, and at least one must-alias relation. If there is no must-aliasing
@@ -147,10 +149,11 @@ struct RemoveDeallocMemrefsContainedInRetained
// deallocated in some situations and can thus not be dropped).
bool atLeastOneMustAlias = false;
for (Value retained : deallocOp.getRetained()) {
- AliasResult analysisResult = aliasAnalysis.alias(retained, memref);
- if (analysisResult.isMay())
+ std::optional<bool> analysisResult =
+ analysis.isSameAllocation(retained, memref);
+ if (!analysisResult.has_value())
return failure();
- if (analysisResult.isMust() || analysisResult.isPartial())
+ if (analysisResult == true)
atLeastOneMustAlias = true;
}
if (!atLeastOneMustAlias)
@@ -161,8 +164,9 @@ struct RemoveDeallocMemrefsContainedInRetained
// we can remove that operand later on.
for (auto [i, retained] : llvm::enumerate(deallocOp.getRetained())) {
Value updatedCondition = deallocOp.getUpdatedConditions()[i];
- AliasResult analysisResult = aliasAnalysis.alias(retained, memref);
- if (analysisResult.isMust() || analysisResult.isPartial()) {
+ std::optional<bool> analysisResult =
+ analysis.isSameAllocation(retained, memref);
+ if (analysisResult == true) {
auto disjunction = rewriter.create<arith::OrIOp>(
deallocOp.getLoc(), updatedCondition, cond);
rewriter.replaceAllUsesExcept(updatedCondition, disjunction.getResult(),
@@ -206,7 +210,7 @@ struct RemoveDeallocMemrefsContainedInRetained
}
private:
- AliasAnalysis &aliasAnalysis;
+ BufferOriginAnalysis &analysis;
};
/// Remove memrefs from the `retained` list which are guaranteed to not alias
@@ -228,15 +232,15 @@ private:
struct RemoveRetainedMemrefsGuaranteedToNotAlias
: public OpRewritePattern<DeallocOp> {
RemoveRetainedMemrefsGuaranteedToNotAlias(MLIRContext *context,
- AliasAnalysis &aliasAnalysis)
- : OpRewritePattern<DeallocOp>(context), aliasAnalysis(aliasAnalysis) {}
+ BufferOriginAnalysis &analysis)
+ : OpRewritePattern<DeallocOp>(context), analysis(analysis) {}
LogicalResult matchAndRewrite(DeallocOp deallocOp,
PatternRewriter &rewriter) const override {
SmallVector<Value> newRetainedMemrefs, replacements;
for (auto retainedMemref : deallocOp.getRetained()) {
- if (potentiallyAliasesMemref(aliasAnalysis, deallocOp.getMemrefs(),
+ if (potentiallyAliasesMemref(analysis, deallocOp.getMemrefs(),
retainedMemref)) {
newRetainedMemrefs.push_back(retainedMemref);
replacements.push_back({});
@@ -264,7 +268,7 @@ struct RemoveRetainedMemrefsGuaranteedToNotAlias
}
private:
- AliasAnalysis &aliasAnalysis;
+ BufferOriginAnalysis &analysis;
};
/// Split off memrefs to separate dealloc operations to reduce the number of
@@ -297,8 +301,8 @@ private:
struct SplitDeallocWhenNotAliasingAnyOther
: public OpRewritePattern<DeallocOp> {
SplitDeallocWhenNotAliasingAnyOther(MLIRContext *context,
- AliasAnalysis &aliasAnalysis)
- : OpRewritePattern<DeallocOp>(context), aliasAnalysis(aliasAnalysis) {}
+ BufferOriginAnalysis &analysis)
+ : OpRewritePattern<DeallocOp>(context), analysis(analysis) {}
LogicalResult matchAndRewrite(DeallocOp deallocOp,
PatternRewriter &rewriter) const override {
@@ -314,7 +318,7 @@ struct SplitDeallocWhenNotAliasingAnyOther
SmallVector<Value> otherMemrefs(deallocOp.getMemrefs());
otherMemrefs.erase(otherMemrefs.begin() + i);
// Check if `memref` can split off into a separate bufferization.dealloc.
- if (potentiallyAliasesMemref(aliasAnalysis, otherMemrefs, memref)) {
+ if (potentiallyAliasesMemref(analysis, otherMemrefs, memref)) {
// `memref` alias with other memrefs, do not split off.
remainingMemrefs.push_back(memref);
remainingConditions.push_back(cond);
@@ -352,7 +356,7 @@ struct SplitDeallocWhenNotAliasingAnyOther
}
private:
- AliasAnalysis &aliasAnalysis;
+ BufferOriginAnalysis &analysis;
};
/// Check for every retained memref if a must-aliasing memref exists in the
@@ -381,8 +385,8 @@ private:
struct RetainedMemrefAliasingAlwaysDeallocatedMemref
: public OpRewritePattern<DeallocOp> {
RetainedMemrefAliasingAlwaysDeallocatedMemref(MLIRContext *context,
- AliasAnalysis &aliasAnalysis)
- : OpRewritePattern<DeallocOp>(context), aliasAnalysis(aliasAnalysis) {}
+ BufferOriginAnalysis &analysis)
+ : OpRewritePattern<DeallocOp>(context), analysis(analysis) {}
LogicalResult matchAndRewrite(DeallocOp deallocOp,
PatternRewriter &rewriter) const override {
@@ -396,8 +400,9 @@ struct RetainedMemrefAliasingAlwaysDeallocatedMemref
if (!matchPattern(cond, m_One()))
continue;
- AliasResult analysisResult = aliasAnalysis.alias(retained, memref);
- if (analysisResult.isMust() || analysisResult.isPartial()) {
+ std::optional<bool> analysisResult =
+ analysis.isSameAllocation(retained, memref);
+ if (analysisResult == true) {
rewriter.replaceAllUsesWith(res, cond);
aliasesWithConstTrueMemref[i] = true;
canDropMemref = true;
@@ -411,10 +416,9 @@ struct RetainedMemrefAliasingAlwaysDeallocatedMemref
if (!extractOp)
continue;
- AliasResult extractAnalysisResult =
- aliasAnalysis.alias(retained, extractOp.getOperand());
- if (extractAnalysisResult.isMust() ||
- extractAnalysisResult.isPartial()) {
+ std::optional<bool> extractAnalysisResult =
+ analysis.isSameAllocation(retained, extractOp.getOperand());
+ if (extractAnalysisResult == true) {
rewriter.replaceAllUsesWith(res, cond);
aliasesWithConstTrueMemref[i] = true;
canDropMemref = true;
@@ -434,7 +438,7 @@ struct RetainedMemrefAliasingAlwaysDeallocatedMemref
}
private:
- AliasAnalysis &aliasAnalysis;
+ BufferOriginAnalysis &analysis;
};
} // namespace
@@ -452,13 +456,13 @@ struct BufferDeallocationSimplificationPass
: public bufferization::impl::BufferDeallocationSimplificationBase<
BufferDeallocationSimplificationPass> {
void runOnOperation() override {
- AliasAnalysis &aliasAnalysis = getAnalysis<AliasAnalysis>();
+ BufferOriginAnalysis analysis(getOperation());
RewritePatternSet patterns(&getContext());
patterns.add<RemoveDeallocMemrefsContainedInRetained,
RemoveRetainedMemrefsGuaranteedToNotAlias,
SplitDeallocWhenNotAliasingAnyOther,
RetainedMemrefAliasingAlwaysDeallocatedMemref>(&getContext(),
- aliasAnalysis);
+ analysis);
populateDeallocOpCanonicalizationPatterns(patterns, &getContext());
if (failed(
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp
index 9a36057..72f47b8 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp
@@ -19,22 +19,23 @@
using namespace mlir;
using namespace mlir::bufferization;
+//===----------------------------------------------------------------------===//
+// BufferViewFlowAnalysis
+//===----------------------------------------------------------------------===//
+
/// Constructs a new alias analysis using the op provided.
BufferViewFlowAnalysis::BufferViewFlowAnalysis(Operation *op) { build(op); }
-/// Find all immediate and indirect dependent buffers this value could
-/// potentially have. Note that the resulting set will also contain the value
-/// provided as it is a dependent alias of itself.
-BufferViewFlowAnalysis::ValueSetT
-BufferViewFlowAnalysis::resolve(Value rootValue) const {
- ValueSetT result;
+static BufferViewFlowAnalysis::ValueSetT
+resolveValues(const BufferViewFlowAnalysis::ValueMapT &map, Value value) {
+ BufferViewFlowAnalysis::ValueSetT result;
SmallVector<Value, 8> queue;
- queue.push_back(rootValue);
+ queue.push_back(value);
while (!queue.empty()) {
Value currentValue = queue.pop_back_val();
if (result.insert(currentValue).second) {
- auto it = dependencies.find(currentValue);
- if (it != dependencies.end()) {
+ auto it = map.find(currentValue);
+ if (it != map.end()) {
for (Value aliasValue : it->second)
queue.push_back(aliasValue);
}
@@ -43,6 +44,19 @@ BufferViewFlowAnalysis::resolve(Value rootValue) const {
return result;
}
+/// Find all immediate and indirect dependent buffers of `rootValue` by
+/// following the `dependencies` map transitively. The resulting set also
+/// contains `rootValue` itself, as it is a dependent alias of itself.
+BufferViewFlowAnalysis::ValueSetT
+BufferViewFlowAnalysis::resolve(Value rootValue) const {
+ return resolveValues(dependencies, rootValue);
+}
+/// Same as `resolve`, but follows the `reverseDependencies` map instead.
+BufferViewFlowAnalysis::ValueSetT
+BufferViewFlowAnalysis::resolveReverse(Value rootValue) const {
+ return resolveValues(reverseDependencies, rootValue);
+}
+
/// Removes the given values from all alias sets.
void BufferViewFlowAnalysis::remove(const SetVector<Value> &aliasValues) {
for (auto &entry : dependencies)
@@ -69,8 +83,10 @@ void BufferViewFlowAnalysis::rename(Value from, Value to) {
void BufferViewFlowAnalysis::build(Operation *op) {
// Registers all dependencies of the given values.
auto registerDependencies = [&](ValueRange values, ValueRange dependencies) {
- for (auto [value, dep] : llvm::zip_equal(values, dependencies))
+ for (auto [value, dep] : llvm::zip_equal(values, dependencies)) {
this->dependencies[value].insert(dep);
+ this->reverseDependencies[dep].insert(value);
+ }
};
// Mark all buffer results and buffer region entry block arguments of the
@@ -188,3 +204,127 @@ bool BufferViewFlowAnalysis::mayBeTerminalBuffer(Value value) const {
assert(isa<BaseMemRefType>(value.getType()) && "expected memref");
return terminals.contains(value);
}
+
+//===----------------------------------------------------------------------===//
+// BufferOriginAnalysis
+//===----------------------------------------------------------------------===//
+
+/// Return "true" if `v` has a defining op with an Allocate effect on `v`.
+static bool hasAllocateSideEffect(Value v) {
+ Operation *op = v.getDefiningOp();
+ if (!op)
+ return false;
+ return hasEffect<MemoryEffects::Allocate>(op, v);
+}
+
+/// Return "true" if `v` is an entry block argument of a function-like op.
+static bool isFunctionArgument(Value v) {
+ auto bbArg = dyn_cast<BlockArgument>(v);
+ if (!bbArg)
+ return false;
+ Block *b = bbArg.getOwner();
+ auto funcOp = dyn_cast<FunctionOpInterface>(b->getParentOp());
+ if (!funcOp)
+ return false;
+ return bbArg.getOwner() == &funcOp.getFunctionBody().front();
+}
+
+/// Follow `getViewSource()` through every defining ViewLikeOpInterface op of
+/// `value` and return the first value that is not defined by such an op.
+static Value getViewBase(Value value) {
+ while (auto viewLikeOp = value.getDefiningOp<ViewLikeOpInterface>())
+ value = viewLikeOp.getViewSource();
+ return value;
+}
+
+BufferOriginAnalysis::BufferOriginAnalysis(Operation *op) : analysis(op) {}
+/// Returns true (same allocation), false (different), or nullopt (unknown).
+std::optional<bool> BufferOriginAnalysis::isSameAllocation(Value v1, Value v2) {
+ assert(isa<BaseMemRefType>(v1.getType()) && "expected buffer");
+ assert(isa<BaseMemRefType>(v2.getType()) && "expected buffer");
+
+ // Strip view-like ops so that both values refer to their view sources.
+ v1 = getViewBase(v1);
+ v2 = getViewBase(v2);
+
+ // Fast path: If both buffers are the same SSA value (after stripping
+ // views), we can be sure that they originate from the same allocation.
+ if (v1 == v2)
+ return true;
+
+ // Compute the SSA values from which the buffers `v1` and `v2` originate.
+ SmallPtrSet<Value, 16> origin1 = analysis.resolveReverse(v1);
+ SmallPtrSet<Value, 16> origin2 = analysis.resolveReverse(v2);
+
+ // Originating buffers are "terminal" if they could not be traced back any
+ // further by the `BufferViewFlowAnalysis`. Examples of terminal buffers:
+ // - function block arguments
+ // - values defined by allocation ops such as "memref.alloc"
+ // - values defined by ops that are unknown to the buffer view flow analysis
+ // - values that are marked as "terminal" in the `BufferViewFlowOpInterface`
+ SmallPtrSet<Value, 16> terminal1, terminal2;
+
+ // While gathering terminal buffers, keep track of whether all terminal
+ // buffers are newly allocated buffers or function entry arguments.
+ bool allAllocs1 = true, allAllocs2 = true;
+ bool allAllocsOrFuncEntryArgs1 = true, allAllocsOrFuncEntryArgs2 = true;
+
+ // Helper that gathers the terminal buffers in `origin` and updates flags.
+ auto gatherTerminalBuffers = [this](const SmallPtrSet<Value, 16> &origin,
+ SmallPtrSet<Value, 16> &terminal,
+ bool &allAllocs,
+ bool &allAllocsOrFuncEntryArgs) {
+ for (Value v : origin) {
+ if (isa<BaseMemRefType>(v.getType()) && analysis.mayBeTerminalBuffer(v)) {
+ terminal.insert(v);
+ allAllocs &= hasAllocateSideEffect(v);
+ allAllocsOrFuncEntryArgs &=
+ isFunctionArgument(v) || hasAllocateSideEffect(v);
+ }
+ }
+ assert(!terminal.empty() && "expected non-empty terminal set");
+ };
+
+ // Gather terminal buffers for `v1` and `v2`.
+ gatherTerminalBuffers(origin1, terminal1, allAllocs1,
+ allAllocsOrFuncEntryArgs1);
+ gatherTerminalBuffers(origin2, terminal2, allAllocs2,
+ allAllocsOrFuncEntryArgs2);
+
+ // If both `v1` and `v2` have a single matching terminal buffer, they are
+ // guaranteed to originate from the same buffer allocation.
+ if (llvm::hasSingleElement(terminal1) && llvm::hasSingleElement(terminal2) &&
+ *terminal1.begin() == *terminal2.begin())
+ return true;
+
+ // Otherwise: the single terminals differ or a value has multiple terminals.
+
+ // Check if there is overlap between the terminal buffers of `v1` and `v2`.
+ bool distinctTerminalSets = true;
+ for (Value v : terminal1)
+ distinctTerminalSets &= !terminal2.contains(v);
+ // If there is overlap between the terminal buffers of `v1` and `v2`, we
+ // cannot make an accurate decision without further analysis.
+ if (!distinctTerminalSets)
+ return std::nullopt;
+
+ // If `v1` originates from only allocs, and `v2` originates from only allocs
+ // or function entry arguments (all distinct from the terminals of `v1`, as
+ // checked above), we can be sure that `v1` and `v2` originate from different
+ // allocations. The same argument can be made when swapping `v1` and `v2`.
+ // Note: "allAllocs" implies "allAllocsOrFuncEntryArgs"; `||` is redundant.
+ bool isolatedAlloc1 = allAllocs1 && (allAllocs2 || allAllocsOrFuncEntryArgs2);
+ bool isolatedAlloc2 = (allAllocs1 || allAllocsOrFuncEntryArgs1) && allAllocs2;
+ if (isolatedAlloc1 || isolatedAlloc2)
+ return false;
+
+ // Otherwise: We do not know whether `v1` and `v2` originate from the same
+ // allocation or not.
+ // TODO: Function arguments are currently handled conservatively. We assume
+ // that they could be the same allocation.
+ // TODO: Terminals other than allocations and function arguments are
+ // currently handled conservatively. We assume that they could be the same
+ // allocation. E.g., we currently return "nullopt" for values that originate
+ // from different "memref.get_global" ops (with different symbols).
+ return std::nullopt;
+}