aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMatthias Springer <me@m-sp.org>2024-03-24 12:48:19 +0900
committerGitHub <noreply@github.com>2024-03-24 12:48:19 +0900
commita45e58af1b381cf3c0374332386b8291ec5310f4 (patch)
tree4b91770f2d80bfcce151c7d76ffeed94ac75abbc
parent74799f424063a2d751e0f9ea698db1f4efd0d8b2 (diff)
downloadllvm-a45e58af1b381cf3c0374332386b8291ec5310f4.zip
llvm-a45e58af1b381cf3c0374332386b8291ec5310f4.tar.gz
llvm-a45e58af1b381cf3c0374332386b8291ec5310f4.tar.bz2
[mlir][bufferization] Add `BufferViewFlowOpInterface` (#78718)
This commit adds the `BufferViewFlowOpInterface` to the bufferization dialect. This interface can be implemented by ops that operate on buffers to indicate that a buffer op result and/or region entry block argument may be the same buffer as a buffer operand (or a view thereof). This interface is queried by the `BufferViewFlowAnalysis`. The new interface has two interface methods: * `populateDependencies`: Implementations use the provided callback to declare dependencies between operands and op results/region entry block arguments. E.g., for `%r = arith.select %c, %m1, %m2 : memref<5xf32>`, the interface implementation should declare two dependencies: %m1 -> %r and %m2 -> %r. * `mayBeTerminalBuffer`: An SSA value is a terminal buffer if the buffer view flow analysis stops at the specified value. E.g., because the value is a newly allocated buffer or because no further information is available about the origin of the buffer. Ops that implement the `RegionBranchOpInterface` or `BranchOpInterface` do not have to implement the `BufferViewFlowOpInterface`. The buffer dependencies can be inferred from those two interfaces. This commit makes the `BufferViewFlowAnalysis` more accurate. For unknown ops, it conservatively used to declare all combinations of operands and op results/region entry block arguments as dependencies (false positives). This is no longer the case. While the analysis is still a "maybe" analysis with false positives (e.g., when analyzing ops such as `arith.select` or `scf.if` where the taken branch is not known at compile time), results and region entry block arguments of unknown ops are now marked as terminal buffers. This commit addresses a TODO in `BufferViewFlowAnalysis.cpp`: ``` // TODO: We should have an op interface instead of a hard-coded list of // interfaces/ops. ``` It is no longer needed to hard-code ops.
-rw-r--r--mlir/include/mlir/Dialect/Arith/Transforms/BufferViewFlowOpInterfaceImpl.h20
-rw-r--r--mlir/include/mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.h27
-rw-r--r--mlir/include/mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.td73
-rw-r--r--mlir/include/mlir/Dialect/Bufferization/IR/CMakeLists.txt1
-rw-r--r--mlir/include/mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h6
-rw-r--r--mlir/include/mlir/Dialect/MemRef/Transforms/BufferViewFlowOpInterfaceImpl.h20
-rw-r--r--mlir/include/mlir/InitAllDialects.h4
-rw-r--r--mlir/lib/Dialect/Arith/Transforms/BufferViewFlowOpInterfaceImpl.cpp44
-rw-r--r--mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt1
-rw-r--r--mlir/lib/Dialect/Bufferization/IR/BufferViewFlowOpInterface.cpp18
-rw-r--r--mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt1
-rw-r--r--mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp72
-rw-r--r--mlir/lib/Dialect/MemRef/Transforms/BufferViewFlowOpInterfaceImpl.cpp48
-rw-r--r--mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt2
-rw-r--r--utils/bazel/llvm-project-overlay/mlir/BUILD.bazel40
15 files changed, 363 insertions, 14 deletions
diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/BufferViewFlowOpInterfaceImpl.h b/mlir/include/mlir/Dialect/Arith/Transforms/BufferViewFlowOpInterfaceImpl.h
new file mode 100644
index 0000000..a2b3a9b
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/BufferViewFlowOpInterfaceImpl.h
@@ -0,0 +1,20 @@
+//===- BufferViewFlowOpInterfaceImpl.h - Buffer View Analysis ---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_ARITH_TRANSFORMS_BUFFERVIEWFLOWOPINTERFACEIMPL_H
+#define MLIR_DIALECT_ARITH_TRANSFORMS_BUFFERVIEWFLOWOPINTERFACEIMPL_H
+
+namespace mlir {
+class DialectRegistry;
+
+namespace arith {
+void registerBufferViewFlowOpInterfaceExternalModels(DialectRegistry &registry);
+} // namespace arith
+} // namespace mlir
+
+#endif // MLIR_DIALECT_ARITH_TRANSFORMS_BUFFERVIEWFLOWOPINTERFACEIMPL_H
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.h
new file mode 100644
index 0000000..84e67fe
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.h
@@ -0,0 +1,27 @@
+//===- BufferViewFlowOpInterface.h - Buffer View Flow Analysis --*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_BUFFERIZATION_IR_BUFFERVIEWFLOWOPINTERFACE_H_
+#define MLIR_DIALECT_BUFFERIZATION_IR_BUFFERVIEWFLOWOPINTERFACE_H_
+
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/Support/LLVM.h"
+
+namespace mlir {
+class ValueRange;
+
+namespace bufferization {
+
+using RegisterDependenciesFn = std::function<void(ValueRange, ValueRange)>;
+
+} // namespace bufferization
+} // namespace mlir
+
+#include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.h.inc"
+
+#endif // MLIR_DIALECT_BUFFERIZATION_IR_BUFFERVIEWFLOWOPINTERFACE_H_
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.td
new file mode 100644
index 0000000..58885d7
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.td
@@ -0,0 +1,73 @@
+//===-- BufferViewFlowOpInterface.td - Buffer View Flow ----*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef BUFFER_VIEW_FLOW_OP_INTERFACE
+#define BUFFER_VIEW_FLOW_OP_INTERFACE
+
+include "mlir/IR/OpBase.td"
+
+def BufferViewFlowOpInterface :
+ OpInterface<"BufferViewFlowOpInterface"> {
+ let description = [{
+ An op interface for the buffer view flow analysis. This interface describes
+ buffer dependencies between operands and op results/region entry block
+ arguments.
+ }];
+ let cppNamespace = "::mlir::bufferization";
+ let methods = [
+ InterfaceMethod<
+ /*desc=*/[{
+ Populate buffer dependencies between operands and op results/region
+ entry block arguments.
+
+ Implementations should register dependencies between an operand ("X")
+ and an op result/region entry block argument ("Y") if Y may depend
+ on X. Y depends on X if Y and X are the same buffer or if Y is a
+ subview of X.
+
+ Example:
+ ```
+ %r = arith.select %c, %m1, %m2 : memref<5xf32>
+ ```
+ In the above example, %0 may depend on %m1 or %m2 and a correct
+ interface implementation should call:
+ - "registerDependenciesFn(%m1, %r)".
+ - "registerDependenciesFn(%m2, %r)"
+ }],
+ /*retType=*/"void",
+ /*methodName=*/"populateDependencies",
+ /*args=*/(ins
+ "::mlir::bufferization::RegisterDependenciesFn"
+ :$registerDependenciesFn)
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Return "true" if the given value may be a terminal buffer. A buffer
+ value is "terminal" if it cannot be traced back any further in the
+ buffer view flow analysis.
+
+ Examples: A buffer could be terminal because:
+ - it is a newly allocated buffer (e.g., "memref.alloc"),
+ - or: because there is not enough compile-time information available
+ to make a definite decision (e.g., "memref.realloc" may reallocate
+ but we do not know for sure; another example are call ops where we
+ would have to analyze the body of the callee).
+
+ Implementations can assume that the given SSA value is an OpResult of
+ this operation or a region entry block argument of this operation.
+ }],
+ /*retType=*/"bool",
+ /*methodName=*/"mayBeTerminalBuffer",
+ /*args=*/(ins "Value":$value),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/"return false;"
+ >,
+ ];
+}
+
+#endif // BUFFER_VIEW_FLOW_OP_INTERFACE
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Bufferization/IR/CMakeLists.txt
index 31a553f..13a5bc3 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/CMakeLists.txt
@@ -3,6 +3,7 @@ add_mlir_doc(BufferizationOps BufferizationOps Dialects/ -gen-dialect-doc)
add_mlir_interface(AllocationOpInterface)
add_mlir_interface(BufferDeallocationOpInterface)
add_mlir_interface(BufferizableOpInterface)
+add_mlir_interface(BufferViewFlowOpInterface)
set(LLVM_TARGET_DEFINITIONS BufferizationEnums.td)
mlir_tablegen(BufferizationEnums.h.inc -gen-enum-decls)
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h
index 24825db..9e43265 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h
@@ -63,6 +63,9 @@ public:
/// results have to be changed.
void rename(Value from, Value to);
+ /// Returns "true" if the given value may be a terminal.
+ bool mayBeTerminalBuffer(Value value) const;
+
private:
/// This function constructs a mapping from values to its immediate
/// dependencies.
@@ -70,6 +73,9 @@ private:
/// Maps values to all immediate dependencies this value can have.
ValueMapT dependencies;
+
+ /// A set of all SSA values that may be terminal buffers.
+ DenseSet<Value> terminals;
};
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/BufferViewFlowOpInterfaceImpl.h b/mlir/include/mlir/Dialect/MemRef/Transforms/BufferViewFlowOpInterfaceImpl.h
new file mode 100644
index 0000000..714518a
--- /dev/null
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/BufferViewFlowOpInterfaceImpl.h
@@ -0,0 +1,20 @@
+//===- BufferViewFlowOpInterfaceImpl.h - Buffer View Analysis ---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MEMREF_TRANSFORMS_BUFFERVIEWFLOWOPINTERFACEIMPL_H
+#define MLIR_DIALECT_MEMREF_TRANSFORMS_BUFFERVIEWFLOWOPINTERFACEIMPL_H
+
+namespace mlir {
+class DialectRegistry;
+
+namespace memref {
+void registerBufferViewFlowOpInterfaceExternalModels(DialectRegistry &registry);
+} // namespace memref
+} // namespace mlir
+
+#endif // MLIR_DIALECT_MEMREF_TRANSFORMS_BUFFERVIEWFLOWOPINTERFACEIMPL_H
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index 9bbf12d..c558dc5 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -21,6 +21,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.h"
#include "mlir/Dialect/Arith/Transforms/BufferDeallocationOpInterfaceImpl.h"
+#include "mlir/Dialect/Arith/Transforms/BufferViewFlowOpInterfaceImpl.h"
#include "mlir/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
@@ -52,6 +53,7 @@
#include "mlir/Dialect/MemRef/IR/MemRefMemorySlot.h"
#include "mlir/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.h"
#include "mlir/Dialect/MemRef/Transforms/AllocationOpInterfaceImpl.h"
+#include "mlir/Dialect/MemRef/Transforms/BufferViewFlowOpInterfaceImpl.h"
#include "mlir/Dialect/MemRef/Transforms/RuntimeOpVerification.h"
#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
@@ -148,6 +150,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
affine::registerValueBoundsOpInterfaceExternalModels(registry);
arith::registerBufferDeallocationOpInterfaceExternalModels(registry);
arith::registerBufferizableOpInterfaceExternalModels(registry);
+ arith::registerBufferViewFlowOpInterfaceExternalModels(registry);
arith::registerValueBoundsOpInterfaceExternalModels(registry);
bufferization::func_ext::registerBufferizableOpInterfaceExternalModels(
registry);
@@ -157,6 +160,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
gpu::registerBufferDeallocationOpInterfaceExternalModels(registry);
linalg::registerAllDialectInterfaceImplementations(registry);
memref::registerAllocationOpInterfaceExternalModels(registry);
+ memref::registerBufferViewFlowOpInterfaceExternalModels(registry);
memref::registerRuntimeVerifiableOpInterfaceExternalModels(registry);
memref::registerValueBoundsOpInterfaceExternalModels(registry);
memref::registerMemorySlotExternalModels(registry);
diff --git a/mlir/lib/Dialect/Arith/Transforms/BufferViewFlowOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/BufferViewFlowOpInterfaceImpl.cpp
new file mode 100644
index 0000000..9df9df8
--- /dev/null
+++ b/mlir/lib/Dialect/Arith/Transforms/BufferViewFlowOpInterfaceImpl.cpp
@@ -0,0 +1,44 @@
+//===- BufferViewFlowOpInterfaceImpl.cpp - Buffer View Flow Analysis ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/Transforms/BufferViewFlowOpInterfaceImpl.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.h"
+
+using namespace mlir;
+using namespace mlir::bufferization;
+
+namespace mlir {
+namespace arith {
+namespace {
+
+struct SelectOpInterface
+ : public BufferViewFlowOpInterface::ExternalModel<SelectOpInterface,
+ SelectOp> {
+ void
+ populateDependencies(Operation *op,
+ RegisterDependenciesFn registerDependenciesFn) const {
+ auto selectOp = cast<SelectOp>(op);
+
+ // Either one of the true/false value may be selected at runtime.
+ registerDependenciesFn(selectOp.getTrueValue(), selectOp.getResult());
+ registerDependenciesFn(selectOp.getFalseValue(), selectOp.getResult());
+ }
+};
+
+} // namespace
+} // namespace arith
+} // namespace mlir
+
+void arith::registerBufferViewFlowOpInterfaceExternalModels(
+ DialectRegistry &registry) {
+ registry.addExtension(+[](MLIRContext *ctx, arith::ArithDialect *dialect) {
+ SelectOp::attachInterface<SelectOpInterface>(*ctx);
+ });
+}
diff --git a/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
index 0224060..12659ea 100644
--- a/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
@@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIRArithTransforms
BufferDeallocationOpInterfaceImpl.cpp
BufferizableOpInterfaceImpl.cpp
Bufferize.cpp
+ BufferViewFlowOpInterfaceImpl.cpp
EmulateUnsupportedFloats.cpp
EmulateWideInt.cpp
EmulateNarrowType.cpp
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferViewFlowOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferViewFlowOpInterface.cpp
new file mode 100644
index 0000000..ea726a4
--- /dev/null
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferViewFlowOpInterface.cpp
@@ -0,0 +1,18 @@
+//===- BufferViewFlowOpInterface.cpp - Buffer View Flow Analysis ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+
+namespace mlir {
+namespace bufferization {
+
+#include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.cpp.inc"
+
+} // namespace bufferization
+} // namespace mlir
diff --git a/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt b/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt
index 9895db9..63dcc1e 100644
--- a/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt
@@ -4,6 +4,7 @@ add_mlir_dialect_library(MLIRBufferizationDialect
BufferDeallocationOpInterface.cpp
BufferizationOps.cpp
BufferizationDialect.cpp
+ BufferViewFlowOpInterface.cpp
UnstructuredControlFlow.cpp
ADDITIONAL_HEADER_DIRS
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp
index 88ef1b6..9a36057 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp
@@ -8,12 +8,16 @@
#include "mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h"
+#include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.h"
+#include "mlir/Interfaces/CallInterfaces.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
#include "llvm/ADT/SetOperations.h"
#include "llvm/ADT/SetVector.h"
using namespace mlir;
+using namespace mlir::bufferization;
/// Constructs a new alias analysis using the op provided.
BufferViewFlowAnalysis::BufferViewFlowAnalysis(Operation *op) { build(op); }
@@ -65,18 +69,44 @@ void BufferViewFlowAnalysis::rename(Value from, Value to) {
void BufferViewFlowAnalysis::build(Operation *op) {
// Registers all dependencies of the given values.
auto registerDependencies = [&](ValueRange values, ValueRange dependencies) {
- for (auto [value, dep] : llvm::zip(values, dependencies))
+ for (auto [value, dep] : llvm::zip_equal(values, dependencies))
this->dependencies[value].insert(dep);
};
+ // Mark all buffer results and buffer region entry block arguments of the
+ // given op as terminals.
+ auto populateTerminalValues = [&](Operation *op) {
+ for (Value v : op->getResults())
+ if (isa<BaseMemRefType>(v.getType()))
+ this->terminals.insert(v);
+ for (Region &r : op->getRegions())
+ for (BlockArgument v : r.getArguments())
+ if (isa<BaseMemRefType>(v.getType()))
+ this->terminals.insert(v);
+ };
+
op->walk([&](Operation *op) {
- // TODO: We should have an op interface instead of a hard-coded list of
- // interfaces/ops.
+ // Query BufferViewFlowOpInterface. If the op does not implement that
+ // interface, try to infer the dependencies from other interfaces that the
+ // op may implement.
+ if (auto bufferViewFlowOp = dyn_cast<BufferViewFlowOpInterface>(op)) {
+ bufferViewFlowOp.populateDependencies(registerDependencies);
+ for (Value v : op->getResults())
+ if (isa<BaseMemRefType>(v.getType()) &&
+ bufferViewFlowOp.mayBeTerminalBuffer(v))
+ this->terminals.insert(v);
+ for (Region &r : op->getRegions())
+ for (BlockArgument v : r.getArguments())
+ if (isa<BaseMemRefType>(v.getType()) &&
+ bufferViewFlowOp.mayBeTerminalBuffer(v))
+ this->terminals.insert(v);
+ return WalkResult::advance();
+ }
// Add additional dependencies created by view changes to the alias list.
if (auto viewInterface = dyn_cast<ViewLikeOpInterface>(op)) {
- dependencies[viewInterface.getViewSource()].insert(
- viewInterface->getResult(0));
+ registerDependencies(viewInterface.getViewSource(),
+ viewInterface->getResult(0));
return WalkResult::advance();
}
@@ -131,16 +161,30 @@ void BufferViewFlowAnalysis::build(Operation *op) {
return WalkResult::advance();
}
- // Unknown op: Assume that all operands alias with all results.
- for (Value operand : op->getOperands()) {
- if (!isa<BaseMemRefType>(operand.getType()))
- continue;
- for (Value result : op->getResults()) {
- if (!isa<BaseMemRefType>(result.getType()))
- continue;
- registerDependencies({operand}, {result});
- }
+ // Region terminators are handled together with RegionBranchOpInterface.
+ if (isa<RegionBranchTerminatorOpInterface>(op))
+ return WalkResult::advance();
+
+ if (isa<CallOpInterface>(op)) {
+ // This is an intra-function analysis. We have no information about other
+ // functions. Conservatively assume that each operand may alias with each
+ // result. Also mark the results are terminals because the function could
+ // return newly allocated buffers.
+ populateTerminalValues(op);
+ for (Value operand : op->getOperands())
+ for (Value result : op->getResults())
+ registerDependencies({operand}, {result});
+ return WalkResult::advance();
}
+
+ // We have no information about unknown ops.
+ populateTerminalValues(op);
+
return WalkResult::advance();
});
}
+
+bool BufferViewFlowAnalysis::mayBeTerminalBuffer(Value value) const {
+ assert(isa<BaseMemRefType>(value.getType()) && "expected memref");
+ return terminals.contains(value);
+}
diff --git a/mlir/lib/Dialect/MemRef/Transforms/BufferViewFlowOpInterfaceImpl.cpp b/mlir/lib/Dialect/MemRef/Transforms/BufferViewFlowOpInterfaceImpl.cpp
new file mode 100644
index 0000000..bbb269b
--- /dev/null
+++ b/mlir/lib/Dialect/MemRef/Transforms/BufferViewFlowOpInterfaceImpl.cpp
@@ -0,0 +1,48 @@
+//===- BufferViewFlowOpInterfaceImpl.cpp - Buffer View Flow Analysis ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/MemRef/Transforms/BufferViewFlowOpInterfaceImpl.h"
+
+#include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+
+using namespace mlir;
+using namespace mlir::bufferization;
+
+namespace mlir {
+namespace memref {
+namespace {
+
+struct ReallocOpInterface
+ : public BufferViewFlowOpInterface::ExternalModel<ReallocOpInterface,
+ ReallocOp> {
+ void
+ populateDependencies(Operation *op,
+ RegisterDependenciesFn registerDependenciesFn) const {
+ auto reallocOp = cast<ReallocOp>(op);
+ // memref.realloc may return the source operand.
+ registerDependenciesFn(reallocOp.getSource(), reallocOp.getResult());
+ }
+
+ bool mayBeTerminalBuffer(Operation *op, Value value) const {
+ // The return value of memref.realloc is a terminal buffer because the op
+ // may return a newly allocated buffer.
+ return true;
+ }
+};
+
+} // namespace
+} // namespace memref
+} // namespace mlir
+
+void memref::registerBufferViewFlowOpInterfaceExternalModels(
+ DialectRegistry &registry) {
+ registry.addExtension(+[](MLIRContext *ctx, memref::MemRefDialect *dialect) {
+ ReallocOp::attachInterface<ReallocOpInterface>(*ctx);
+ });
+}
diff --git a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
index 08b7eab..f150ac7 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
@@ -1,5 +1,6 @@
add_mlir_dialect_library(MLIRMemRefTransforms
AllocationOpInterfaceImpl.cpp
+ BufferViewFlowOpInterfaceImpl.cpp
ComposeSubView.cpp
ExpandOps.cpp
ExpandRealloc.cpp
@@ -27,6 +28,7 @@ add_mlir_dialect_library(MLIRMemRefTransforms
MLIRArithDialect
MLIRArithTransforms
MLIRBufferizationDialect
+ MLIRBufferizationTransforms
MLIRDialectUtils
MLIRFuncDialect
MLIRGPUDialect
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 5b6e467..88b46bd 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -10829,6 +10829,36 @@ gentbl_cc_library(
)
td_library(
+ name = "BufferViewFlowOpInterfaceTdFiles",
+ srcs = [
+ "include/mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.td",
+ ],
+ includes = ["include"],
+ deps = [
+ ":OpBaseTdFiles",
+ ],
+)
+
+gentbl_cc_library(
+ name = "BufferViewFlowOpInterfaceIncGen",
+ tbl_outs = [
+ (
+ ["-gen-op-interface-decls"],
+ "include/mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.h.inc",
+ ),
+ (
+ ["-gen-op-interface-defs"],
+ "include/mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.cpp.inc",
+ ),
+ ],
+ tblgen = ":mlir-tblgen",
+ td_file = "include/mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.td",
+ deps = [
+ ":BufferViewFlowOpInterfaceTdFiles",
+ ],
+)
+
+td_library(
name = "SubsetOpInterfaceTdFiles",
srcs = [
"include/mlir/Interfaces/SubsetOpInterface.td",
@@ -12977,6 +13007,8 @@ cc_library(
":ArithTransforms",
":ArithUtils",
":BufferizationDialect",
+ ":BufferizationInterfaces",
+ ":BufferizationTransforms",
":ControlFlowDialect",
":DialectUtils",
":FuncDialect",
@@ -13369,6 +13401,7 @@ td_library(
includes = ["include"],
deps = [
":AllocationOpInterfaceTdFiles",
+ ":BufferViewFlowOpInterfaceTdFiles",
":BufferizableOpInterfaceTdFiles",
":CopyOpInterfaceTdFiles",
":DestinationStyleOpInterfaceTdFiles",
@@ -13515,11 +13548,13 @@ cc_library(
],
hdrs = [
"include/mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h",
+ "include/mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.h",
"include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h",
],
includes = ["include"],
deps = [
":BufferDeallocationOpInterfaceIncGen",
+ ":BufferViewFlowOpInterfaceIncGen",
":BufferizableOpInterfaceIncGen",
":BufferizationEnumsIncGen",
":IR",
@@ -13532,6 +13567,7 @@ cc_library(
name = "BufferizationDialect",
srcs = [
"lib/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp",
+ "lib/Dialect/Bufferization/IR/BufferViewFlowOpInterface.cpp",
"lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp",
"lib/Dialect/Bufferization/IR/BufferizationDialect.cpp",
"lib/Dialect/Bufferization/IR/BufferizationOps.cpp",
@@ -13549,10 +13585,12 @@ cc_library(
":Analysis",
":ArithDialect",
":BufferDeallocationOpInterfaceIncGen",
+ ":BufferViewFlowOpInterfaceIncGen",
":BufferizableOpInterfaceIncGen",
":BufferizationBaseIncGen",
":BufferizationInterfaces",
":BufferizationOpsIncGen",
+ ":CallOpInterfaces",
":ControlFlowInterfaces",
":CopyOpInterface",
":DestinationStyleOpInterface",
@@ -13602,9 +13640,11 @@ cc_library(
":BufferizationDialect",
":BufferizationInterfaces",
":BufferizationPassIncGen",
+ ":CallOpInterfaces",
":ControlFlowDialect",
":ControlFlowInterfaces",
":FuncDialect",
+ ":FunctionInterfaces",
":IR",
":LoopLikeInterface",
":MemRefDialect",