diff options
author | Matthias Springer <me@m-sp.org> | 2024-03-24 12:48:19 +0900 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-03-24 12:48:19 +0900 |
commit | a45e58af1b381cf3c0374332386b8291ec5310f4 (patch) | |
tree | 4b91770f2d80bfcce151c7d76ffeed94ac75abbc /mlir/lib | |
parent | 74799f424063a2d751e0f9ea698db1f4efd0d8b2 (diff) | |
download | llvm-a45e58af1b381cf3c0374332386b8291ec5310f4.zip llvm-a45e58af1b381cf3c0374332386b8291ec5310f4.tar.gz llvm-a45e58af1b381cf3c0374332386b8291ec5310f4.tar.bz2 |
[mlir][bufferization] Add `BufferViewFlowOpInterface` (#78718)
This commit adds the `BufferViewFlowOpInterface` to the bufferization
dialect. This interface can be implemented by ops that operate on
buffers to indicate that a buffer op result and/or region entry block
argument may be the same buffer as a buffer operand (or a view thereof).
This interface is queried by the `BufferViewFlowAnalysis`.
The new interface has two interface methods:
* `populateDependencies`: Implementations use the provided callback to
declare dependencies between operands and op results/region entry block
arguments. E.g., for `%r = arith.select %c, %m1, %m2 : memref<5xf32>`,
the interface implementation should declare two dependencies: %m1 -> %r
and %m2 -> %r.
* `mayBeTerminalBuffer`: An SSA value is a terminal buffer if the buffer
view flow analysis stops at the specified value. E.g., because the value
is a newly allocated buffer or because no further information is
available about the origin of the buffer.
Ops that implement the `RegionBranchOpInterface` or `BranchOpInterface`
do not have to implement the `BufferViewFlowOpInterface`. The buffer
dependencies can be inferred from those two interfaces.
This commit makes the `BufferViewFlowAnalysis` more accurate. For
unknown ops, it conservatively used to declare all combinations of
operands and op results/region entry block arguments as dependencies
(false positives). This is no longer the case. While the analysis is
still a "maybe" analysis with false positives (e.g., when analyzing ops
such as `arith.select` or `scf.if` where the taken branch is not known
at compile time), results and region entry block arguments of unknown
ops are now marked as terminal buffers.
This commit addresses a TODO in `BufferViewFlowAnalysis.cpp`:
```
// TODO: We should have an op interface instead of a hard-coded list of
// interfaces/ops.
```
It is no longer needed to hard-code ops.
Diffstat (limited to 'mlir/lib')
7 files changed, 172 insertions, 14 deletions
diff --git a/mlir/lib/Dialect/Arith/Transforms/BufferViewFlowOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/BufferViewFlowOpInterfaceImpl.cpp new file mode 100644 index 0000000..9df9df8 --- /dev/null +++ b/mlir/lib/Dialect/Arith/Transforms/BufferViewFlowOpInterfaceImpl.cpp @@ -0,0 +1,44 @@ +//===- BufferViewFlowOpInterfaceImpl.cpp - Buffer View Flow Analysis ------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Arith/Transforms/BufferViewFlowOpInterfaceImpl.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.h" + +using namespace mlir; +using namespace mlir::bufferization; + +namespace mlir { +namespace arith { +namespace { + +struct SelectOpInterface + : public BufferViewFlowOpInterface::ExternalModel<SelectOpInterface, + SelectOp> { + void + populateDependencies(Operation *op, + RegisterDependenciesFn registerDependenciesFn) const { + auto selectOp = cast<SelectOp>(op); + + // Either one of the true/false value may be selected at runtime. + registerDependenciesFn(selectOp.getTrueValue(), selectOp.getResult()); + registerDependenciesFn(selectOp.getFalseValue(), selectOp.getResult()); + } +}; + +} // namespace +} // namespace arith +} // namespace mlir + +void arith::registerBufferViewFlowOpInterfaceExternalModels( + DialectRegistry ®istry) { + registry.addExtension(+[](MLIRContext *ctx, arith::ArithDialect *dialect) { + SelectOp::attachInterface<SelectOpInterface>(*ctx); + }); +} diff --git a/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt index 0224060..12659ea 100644 --- a/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt @@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIRArithTransforms BufferDeallocationOpInterfaceImpl.cpp BufferizableOpInterfaceImpl.cpp Bufferize.cpp + BufferViewFlowOpInterfaceImpl.cpp EmulateUnsupportedFloats.cpp EmulateWideInt.cpp EmulateNarrowType.cpp diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferViewFlowOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferViewFlowOpInterface.cpp new file mode 100644 index 0000000..ea726a4 --- /dev/null +++ b/mlir/lib/Dialect/Bufferization/IR/BufferViewFlowOpInterface.cpp @@ -0,0 +1,18 @@ +//===- BufferViewFlowOpInterface.cpp - Buffer View Flow Analysis ----------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" + +namespace mlir { +namespace bufferization { + +#include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.cpp.inc" + +} // namespace bufferization +} // namespace mlir diff --git a/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt b/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt index 9895db9..63dcc1e 100644 --- a/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt @@ -4,6 +4,7 @@ add_mlir_dialect_library(MLIRBufferizationDialect BufferDeallocationOpInterface.cpp BufferizationOps.cpp BufferizationDialect.cpp + BufferViewFlowOpInterface.cpp UnstructuredControlFlow.cpp ADDITIONAL_HEADER_DIRS diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp index 88ef1b6..9a36057 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp @@ -8,12 +8,16 @@ #include "mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h" +#include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.h" +#include "mlir/Interfaces/CallInterfaces.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" +#include "mlir/Interfaces/FunctionInterfaces.h" #include "mlir/Interfaces/ViewLikeInterface.h" #include "llvm/ADT/SetOperations.h" #include "llvm/ADT/SetVector.h" using namespace mlir; +using namespace mlir::bufferization; /// Constructs a new alias analysis using the op provided. BufferViewFlowAnalysis::BufferViewFlowAnalysis(Operation *op) { build(op); } @@ -65,18 +69,44 @@ void BufferViewFlowAnalysis::rename(Value from, Value to) { void BufferViewFlowAnalysis::build(Operation *op) { // Registers all dependencies of the given values. auto registerDependencies = [&](ValueRange values, ValueRange dependencies) { - for (auto [value, dep] : llvm::zip(values, dependencies)) + for (auto [value, dep] : llvm::zip_equal(values, dependencies)) this->dependencies[value].insert(dep); }; + // Mark all buffer results and buffer region entry block arguments of the + // given op as terminals. + auto populateTerminalValues = [&](Operation *op) { + for (Value v : op->getResults()) + if (isa<BaseMemRefType>(v.getType())) + this->terminals.insert(v); + for (Region &r : op->getRegions()) + for (BlockArgument v : r.getArguments()) + if (isa<BaseMemRefType>(v.getType())) + this->terminals.insert(v); + }; + op->walk([&](Operation *op) { - // TODO: We should have an op interface instead of a hard-coded list of - // interfaces/ops. + // Query BufferViewFlowOpInterface. If the op does not implement that + // interface, try to infer the dependencies from other interfaces that the + // op may implement. + if (auto bufferViewFlowOp = dyn_cast<BufferViewFlowOpInterface>(op)) { + bufferViewFlowOp.populateDependencies(registerDependencies); + for (Value v : op->getResults()) + if (isa<BaseMemRefType>(v.getType()) && + bufferViewFlowOp.mayBeTerminalBuffer(v)) + this->terminals.insert(v); + for (Region &r : op->getRegions()) + for (BlockArgument v : r.getArguments()) + if (isa<BaseMemRefType>(v.getType()) && + bufferViewFlowOp.mayBeTerminalBuffer(v)) + this->terminals.insert(v); + return WalkResult::advance(); + } // Add additional dependencies created by view changes to the alias list. if (auto viewInterface = dyn_cast<ViewLikeOpInterface>(op)) { - dependencies[viewInterface.getViewSource()].insert( - viewInterface->getResult(0)); + registerDependencies(viewInterface.getViewSource(), + viewInterface->getResult(0)); return WalkResult::advance(); } @@ -131,16 +161,30 @@ void BufferViewFlowAnalysis::build(Operation *op) { return WalkResult::advance(); } - // Unknown op: Assume that all operands alias with all results. - for (Value operand : op->getOperands()) { - if (!isa<BaseMemRefType>(operand.getType())) - continue; - for (Value result : op->getResults()) { - if (!isa<BaseMemRefType>(result.getType())) - continue; - registerDependencies({operand}, {result}); - } + // Region terminators are handled together with RegionBranchOpInterface. + if (isa<RegionBranchTerminatorOpInterface>(op)) + return WalkResult::advance(); + + if (isa<CallOpInterface>(op)) { + // This is an intra-function analysis. We have no information about other + // functions. Conservatively assume that each operand may alias with each + // result. Also mark the results are terminals because the function could + // return newly allocated buffers. + populateTerminalValues(op); + for (Value operand : op->getOperands()) + for (Value result : op->getResults()) + registerDependencies({operand}, {result}); + return WalkResult::advance(); } + + // We have no information about unknown ops. + populateTerminalValues(op); + return WalkResult::advance(); }); } + +bool BufferViewFlowAnalysis::mayBeTerminalBuffer(Value value) const { + assert(isa<BaseMemRefType>(value.getType()) && "expected memref"); + return terminals.contains(value); +} diff --git a/mlir/lib/Dialect/MemRef/Transforms/BufferViewFlowOpInterfaceImpl.cpp b/mlir/lib/Dialect/MemRef/Transforms/BufferViewFlowOpInterfaceImpl.cpp new file mode 100644 index 0000000..bbb269b --- /dev/null +++ b/mlir/lib/Dialect/MemRef/Transforms/BufferViewFlowOpInterfaceImpl.cpp @@ -0,0 +1,48 @@ +//===- BufferViewFlowOpInterfaceImpl.cpp - Buffer View Flow Analysis ------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/MemRef/Transforms/BufferViewFlowOpInterfaceImpl.h" + +#include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" + +using namespace mlir; +using namespace mlir::bufferization; + +namespace mlir { +namespace memref { +namespace { + +struct ReallocOpInterface + : public BufferViewFlowOpInterface::ExternalModel<ReallocOpInterface, + ReallocOp> { + void + populateDependencies(Operation *op, + RegisterDependenciesFn registerDependenciesFn) const { + auto reallocOp = cast<ReallocOp>(op); + // memref.realloc may return the source operand. + registerDependenciesFn(reallocOp.getSource(), reallocOp.getResult()); + } + + bool mayBeTerminalBuffer(Operation *op, Value value) const { + // The return value of memref.realloc is a terminal buffer because the op + // may return a newly allocated buffer. + return true; + } +}; + +} // namespace +} // namespace memref +} // namespace mlir + +void memref::registerBufferViewFlowOpInterfaceExternalModels( + DialectRegistry ®istry) { + registry.addExtension(+[](MLIRContext *ctx, memref::MemRefDialect *dialect) { + ReallocOp::attachInterface<ReallocOpInterface>(*ctx); + }); +} diff --git a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt index 08b7eab..f150ac7 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt @@ -1,5 +1,6 @@ add_mlir_dialect_library(MLIRMemRefTransforms AllocationOpInterfaceImpl.cpp + BufferViewFlowOpInterfaceImpl.cpp ComposeSubView.cpp ExpandOps.cpp ExpandRealloc.cpp @@ -27,6 +28,7 @@ add_mlir_dialect_library(MLIRMemRefTransforms MLIRArithDialect MLIRArithTransforms MLIRBufferizationDialect + MLIRBufferizationTransforms MLIRDialectUtils MLIRFuncDialect MLIRGPUDialect |