aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMatthias Springer <springerm@google.com>2022-09-02 14:32:04 +0200
committerMatthias Springer <springerm@google.com>2022-09-02 14:47:20 +0200
commitf7f0c7f7e3d711cd8f0d069fef7f664f074a57e4 (patch)
treec8d7d25406c5f35bea35be6baac5fc2982907659
parent3da23970ed750bb7d42b94af26728cba96162207 (diff)
downloadllvm-f7f0c7f7e3d711cd8f0d069fef7f664f074a57e4.zip
llvm-f7f0c7f7e3d711cd8f0d069fef7f664f074a57e4.tar.gz
llvm-f7f0c7f7e3d711cd8f0d069fef7f664f074a57e4.tar.bz2
[mlir][bufferize] Add isRepetitiveRegion to BufferizableOpInterface
This method allows regions to be declared as "repetitive" even if the parent op does not implement the RegionBranchOpInterface. This is needed to support loop-like ops that have parallel semantics but do not branch between regions. Differential Revision: https://reviews.llvm.org/D133113
-rw-r--r--mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h6
-rw-r--r--mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td23
-rw-r--r--mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp11
-rw-r--r--mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp34
-rw-r--r--mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp23
-rw-r--r--utils/bazel/llvm-project-overlay/mlir/BUILD.bazel1
6 files changed, 82 insertions, 16 deletions
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index f22fe00..dc2b12f 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -566,6 +566,12 @@ namespace detail {
FailureOr<BaseMemRefType>
defaultGetBufferType(Value value, const BufferizationOptions &options,
const DenseMap<Value, BaseMemRefType> &fixedTypes);
+
+/// This is the default implementation of
+/// BufferizableOpInterface::isRepetitiveRegion. Should not be called from other
+/// places.
+bool defaultIsRepetitiveRegion(BufferizableOpInterface bufferizableOp,
+ unsigned index);
} // namespace detail
} // namespace bufferization
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
index c28e3f6..ab8f3e2 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
@@ -360,6 +360,29 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
value, options, fixedTypes);
}]
>,
+ InterfaceMethod<
+ /*desc=*/[{
+ Return `true` if the given region of this op is repetitive. By default
+ this information is queried from the `RegionBranchOpInterface`. Ops
+ that do not implement this interface can override this method to
+ declare regions as repetitive.
+
+ The RaW conflict detection of One-Shot Analysis is more strict inside
+ repetitive regions: Op dominance cannot always be used to rule out
+ certain potential conflicts (e.g., a conflicting write happening after
+ a read), because there may not be a meaningful ordering of certain ops
+ that are executed multiple times. This is described in more detail in
+ the documentation of One-Shot Analysis.
+ }],
+ /*retType=*/"bool",
+ /*methodName=*/"isRepetitiveRegion",
+ /*args=*/(ins "unsigned":$index),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return mlir::bufferization::detail::defaultIsRepetitiveRegion(
+ cast<BufferizableOpInterface>($_op.getOperation()), index);
+ }]
+ >
];
let extraClassDeclaration = [{
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 265c417..5ec4135 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -17,6 +17,7 @@
#include "mlir/IR/Operation.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
+#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "llvm/Support/Debug.h"
//===----------------------------------------------------------------------===//
@@ -784,3 +785,13 @@ bufferization::getMemRefTypeWithStaticIdentityLayout(TensorType tensorType,
rankedTensorType.getElementType(), layout,
memorySpaceAttr);
}
+
+bool bufferization::detail::defaultIsRepetitiveRegion(
+    BufferizableOpInterface bufferizableOp, unsigned index) {
+  assert(index < bufferizableOp->getNumRegions() && "invalid region index");
+  // Ops without RegionBranchOpInterface have no repetitive regions by default.
+  if (auto branchOp =
+          dyn_cast<RegionBranchOpInterface>(bufferizableOp.getOperation()))
+    return branchOp.isRepetitiveRegion(index);
+  return false;
+}
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
index 375330c..46420f1 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
@@ -351,14 +351,40 @@ static bool happensBefore(Operation *a, Operation *b,
return false;
}
+/// Return the innermost repetitive region enclosing `op`, or nullptr if none.
+static Region *getEnclosingRepetitiveRegion(
+    Operation *op, const BufferizationOptions &options) {
+  for (Region *r = op->getParentRegion(); r; r = op->getParentRegion()) {
+    op = r->getParentOp();
+    auto parent = options.dynCastBufferizableOp(op);
+    if (parent && parent.isRepetitiveRegion(r->getRegionNumber()))
+      return r;
+  }
+  return nullptr;
+}
+
+/// Return the innermost repetitive region enclosing `value`, or nullptr.
+static Region *
+getEnclosingRepetitiveRegion(Value value, const BufferizationOptions &options) {
+  for (Region *r = value.getParentRegion(); r;) {
+    Operation *parentOp = r->getParentOp();
+    auto parent = options.dynCastBufferizableOp(parentOp);
+    if (parent && parent.isRepetitiveRegion(r->getRegionNumber()))
+      return r;
+    r = parentOp->getParentRegion();
+  }
+  return nullptr;
+}
+
/// For each given value, find the closest enclosing repetitive region. If this
/// is the same region for each value, return it. Otherwise return None.
/// Note: If there is no enclosing repetitive region, return nullptr.
static Optional<Region *>
-getCommonEnclosingRepetitiveRegion(ArrayRef<Value> values) {
+getCommonEnclosingRepetitiveRegion(ArrayRef<Value> values,
+ const BufferizationOptions &options) {
if (values.empty())
return None;
- Region *r = getEnclosingRepetitiveRegion(values.front());
+ Region *r = getEnclosingRepetitiveRegion(values.front(), options);
for (Value value : values.drop_front())
- if (getEnclosingRepetitiveRegion(value) != r)
+ if (getEnclosingRepetitiveRegion(value, options) != r)
return None;
@@ -432,7 +458,7 @@ static bool hasReadAfterWriteInterference(
// Find the inner-most enclosing repetitive region of each alias. If this is
// the same region for every alias, save it in `repetitiveRegionOfWrites`.
Optional<Region *> repetitiveRegionOfWrites =
- getCommonEnclosingRepetitiveRegion(writtenAliases);
+ getCommonEnclosingRepetitiveRegion(writtenAliases, options);
for (OpOperand *uRead : usesRead) {
Operation *readingOp = uRead->getOwner();
@@ -497,7 +523,7 @@ static bool hasReadAfterWriteInterference(
bool canUseOpDominance =
writtenAliases.empty() ||
repetitiveRegionOfWrites ==
- getEnclosingRepetitiveRegion(conflictingWritingOp);
+ getEnclosingRepetitiveRegion(conflictingWritingOp, options);
// No conflict if the readingOp dominates conflictingWritingOp, i.e., the
// write is not visible when reading.
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp b/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp
index 2b7025c..165f5bb 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp
@@ -48,15 +48,14 @@ resolveUsesInRepetitiveRegions(Operation *op,
AnalysisState state(options);
// Look for repetitive ops (loops).
- op->walk([&](RegionBranchOpInterface regionBranchOp) {
- // Skip non-bufferizable ops.
- auto bufferizableOp = options.dynCastBufferizableOp(regionBranchOp);
- if (!bufferizableOp)
+ op->walk([&](BufferizableOpInterface bufferizableOp) {
+ // Skip filtered ops.
+ if (!options.isOpAllowed(bufferizableOp.getOperation()))
return WalkResult::advance();
- // Find all operands that are also used inside of a repetitve region of this
- // op.
- for (OpOperand &opOperand : regionBranchOp->getOpOperands()) {
+ // Find all operands that are also used inside of a repetitive region of
+ // this op.
+ for (OpOperand &opOperand : bufferizableOp->getOpOperands()) {
Value operand = opOperand.get();
// Skip non-tensor operands.
if (!operand.getType().isa<TensorType>())
@@ -69,11 +68,11 @@ resolveUsesInRepetitiveRegions(Operation *op,
SmallVector<OpOperand *> usesInsideRegion;
for (OpOperand &use : operand.getUses()) {
Operation *owner = use.getOwner();
- if (!regionBranchOp->isProperAncestor(owner))
+ if (!bufferizableOp->isProperAncestor(owner))
continue;
- for (Region &r : regionBranchOp->getRegions()) {
+ for (Region &r : bufferizableOp->getRegions()) {
if (r.findAncestorOpInRegion(*owner) &&
- regionBranchOp.isRepetitiveRegion(r.getRegionNumber())) {
+ bufferizableOp.isRepetitiveRegion(r.getRegionNumber())) {
usesInsideRegion.push_back(&use);
break;
}
@@ -84,9 +83,9 @@ resolveUsesInRepetitiveRegions(Operation *op,
continue;
// Insert a tensor copy and replace all uses inside of repetitive regions.
- rewriter.setInsertionPoint(regionBranchOp);
+ rewriter.setInsertionPoint(bufferizableOp);
auto tensorCopy = rewriter.create<AllocTensorOp>(
- regionBranchOp->getLoc(), operand.getType().cast<TensorType>(),
+ bufferizableOp->getLoc(), operand.getType().cast<TensorType>(),
/*dynamicSizes=*/ValueRange(),
/*copy=*/operand, /*memory_space=*/IntegerAttr());
for (OpOperand *use : usesInsideRegion)
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index c5860596..64b53bd 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -9088,6 +9088,7 @@ cc_library(
":BufferizableOpInterfaceIncGen",
":BufferizationBaseIncGen",
":BufferizationOpsIncGen",
+ ":ControlFlowInterfaces",
":CopyOpInterface",
":FuncDialect",
":IR",