From a45e58af1b381cf3c0374332386b8291ec5310f4 Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Sun, 24 Mar 2024 12:48:19 +0900 Subject: [mlir][bufferization] Add `BufferViewFlowOpInterface` (#78718) This commit adds the `BufferViewFlowOpInterface` to the bufferization dialect. This interface can be implemented by ops that operate on buffers to indicate that a buffer op result and/or region entry block argument may be the same buffer as a buffer operand (or a view thereof). This interface is queried by the `BufferViewFlowAnalysis`. The new interface has two interface methods: * `populateDependencies`: Implementations use the provided callback to declare dependencies between operands and op results/region entry block arguments. E.g., for `%r = arith.select %c, %m1, %m2 : memref<5xf32>`, the interface implementation should declare two dependencies: %m1 -> %r and %m2 -> %r. * `mayBeTerminalBuffer`: An SSA value is a terminal buffer if the buffer view flow analysis stops at the specified value. E.g., because the value is a newly allocated buffer or because no further information is available about the origin of the buffer. Ops that implement the `RegionBranchOpInterface` or `BranchOpInterface` do not have to implement the `BufferViewFlowOpInterface`. The buffer dependencies can be inferred from those two interfaces. This commit makes the `BufferViewFlowAnalysis` more accurate. For unknown ops, it conservatively used to declare all combinations of operands and op results/region entry block arguments as dependencies (false positives). This is no longer the case. While the analysis is still a "maybe" analysis with false positives (e.g., when analyzing ops such as `arith.select` or `scf.if` where the taken branch is not known at compile time), results and region entry block arguments of unknown ops are now marked as terminal buffers. This commit addresses a TODO in `BufferViewFlowAnalysis.cpp`: ``` // TODO: We should have an op interface instead of a hard-coded list of // interfaces/ops. ``` It is no longer needed to hard-code ops. --- .../Transforms/BufferViewFlowOpInterfaceImpl.h | 20 ++++++ .../Bufferization/IR/BufferViewFlowOpInterface.h | 27 ++++++++ .../Bufferization/IR/BufferViewFlowOpInterface.td | 73 ++++++++++++++++++++++ .../mlir/Dialect/Bufferization/IR/CMakeLists.txt | 1 + .../Transforms/BufferViewFlowAnalysis.h | 6 ++ .../Transforms/BufferViewFlowOpInterfaceImpl.h | 20 ++++++ mlir/include/mlir/InitAllDialects.h | 4 ++ .../Transforms/BufferViewFlowOpInterfaceImpl.cpp | 44 +++++++++++++ mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt | 1 + .../Bufferization/IR/BufferViewFlowOpInterface.cpp | 18 ++++++ mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt | 1 + .../Transforms/BufferViewFlowAnalysis.cpp | 72 ++++++++++++++++----- .../Transforms/BufferViewFlowOpInterfaceImpl.cpp | 48 ++++++++++++++ mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt | 2 + utils/bazel/llvm-project-overlay/mlir/BUILD.bazel | 40 ++++++++++++ 15 files changed, 363 insertions(+), 14 deletions(-) create mode 100644 mlir/include/mlir/Dialect/Arith/Transforms/BufferViewFlowOpInterfaceImpl.h create mode 100644 mlir/include/mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.h create mode 100644 mlir/include/mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.td create mode 100644 mlir/include/mlir/Dialect/MemRef/Transforms/BufferViewFlowOpInterfaceImpl.h create mode 100644 mlir/lib/Dialect/Arith/Transforms/BufferViewFlowOpInterfaceImpl.cpp create mode 100644 mlir/lib/Dialect/Bufferization/IR/BufferViewFlowOpInterface.cpp create mode 100644 mlir/lib/Dialect/MemRef/Transforms/BufferViewFlowOpInterfaceImpl.cpp diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/BufferViewFlowOpInterfaceImpl.h b/mlir/include/mlir/Dialect/Arith/Transforms/BufferViewFlowOpInterfaceImpl.h new file mode 100644 index 0000000..a2b3a9b --- /dev/null +++ b/mlir/include/mlir/Dialect/Arith/Transforms/BufferViewFlowOpInterfaceImpl.h @@ -0,0 +1,20 @@ +//===- BufferViewFlowOpInterfaceImpl.h - Buffer View Analysis ---*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_ARITH_TRANSFORMS_BUFFERVIEWFLOWOPINTERFACEIMPL_H +#define MLIR_DIALECT_ARITH_TRANSFORMS_BUFFERVIEWFLOWOPINTERFACEIMPL_H + +namespace mlir { +class DialectRegistry; + +namespace arith { +void registerBufferViewFlowOpInterfaceExternalModels(DialectRegistry ®istry); +} // namespace arith +} // namespace mlir + +#endif // MLIR_DIALECT_ARITH_TRANSFORMS_BUFFERVIEWFLOWOPINTERFACEIMPL_H diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.h new file mode 100644 index 0000000..84e67fe --- /dev/null +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.h @@ -0,0 +1,27 @@ +//===- BufferViewFlowOpInterface.h - Buffer View Flow Analysis --*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_BUFFERIZATION_IR_BUFFERVIEWFLOWOPINTERFACE_H_ +#define MLIR_DIALECT_BUFFERIZATION_IR_BUFFERVIEWFLOWOPINTERFACE_H_ + +#include "mlir/IR/OpDefinition.h" +#include "mlir/Support/LLVM.h" + +namespace mlir { +class ValueRange; + +namespace bufferization { + +using RegisterDependenciesFn = std::function; + +} // namespace bufferization +} // namespace mlir + +#include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.h.inc" + +#endif // MLIR_DIALECT_BUFFERIZATION_IR_BUFFERVIEWFLOWOPINTERFACE_H_ diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.td new file mode 100644 index 0000000..58885d7 --- /dev/null +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.td @@ -0,0 +1,73 @@ +//===-- BufferViewFlowOpInterface.td - Buffer View Flow ----*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef BUFFER_VIEW_FLOW_OP_INTERFACE +#define BUFFER_VIEW_FLOW_OP_INTERFACE + +include "mlir/IR/OpBase.td" + +def BufferViewFlowOpInterface : + OpInterface<"BufferViewFlowOpInterface"> { + let description = [{ + An op interface for the buffer view flow analysis. This interface describes + buffer dependencies between operands and op results/region entry block + arguments. + }]; + let cppNamespace = "::mlir::bufferization"; + let methods = [ + InterfaceMethod< + /*desc=*/[{ + Populate buffer dependencies between operands and op results/region + entry block arguments. + + Implementations should register dependencies between an operand ("X") + and an op result/region entry block argument ("Y") if Y may depend + on X. Y depends on X if Y and X are the same buffer or if Y is a + subview of X. + + Example: + ``` + %r = arith.select %c, %m1, %m2 : memref<5xf32> + ``` + In the above example, %0 may depend on %m1 or %m2 and a correct + interface implementation should call: + - "registerDependenciesFn(%m1, %r)". + - "registerDependenciesFn(%m2, %r)" + }], + /*retType=*/"void", + /*methodName=*/"populateDependencies", + /*args=*/(ins + "::mlir::bufferization::RegisterDependenciesFn" + :$registerDependenciesFn) + >, + InterfaceMethod< + /*desc=*/[{ + Return "true" if the given value may be a terminal buffer. A buffer + value is "terminal" if it cannot be traced back any further in the + buffer view flow analysis. + + Examples: A buffer could be terminal because: + - it is a newly allocated buffer (e.g., "memref.alloc"), + - or: because there is not enough compile-time information available + to make a definite decision (e.g., "memref.realloc" may reallocate + but we do not know for sure; another example are call ops where we + would have to analyze the body of the callee). + + Implementations can assume that the given SSA value is an OpResult of + this operation or a region entry block argument of this operation. + }], + /*retType=*/"bool", + /*methodName=*/"mayBeTerminalBuffer", + /*args=*/(ins "Value":$value), + /*methodBody=*/"", + /*defaultImplementation=*/"return false;" + >, + ]; +} + +#endif // BUFFER_VIEW_FLOW_OP_INTERFACE diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Bufferization/IR/CMakeLists.txt index 31a553f..13a5bc3 100644 --- a/mlir/include/mlir/Dialect/Bufferization/IR/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/Bufferization/IR/CMakeLists.txt @@ -3,6 +3,7 @@ add_mlir_doc(BufferizationOps BufferizationOps Dialects/ -gen-dialect-doc) add_mlir_interface(AllocationOpInterface) add_mlir_interface(BufferDeallocationOpInterface) add_mlir_interface(BufferizableOpInterface) +add_mlir_interface(BufferViewFlowOpInterface) set(LLVM_TARGET_DEFINITIONS BufferizationEnums.td) mlir_tablegen(BufferizationEnums.h.inc -gen-enum-decls) diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h index 24825db..9e43265 100644 --- a/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h +++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h @@ -63,6 +63,9 @@ public: /// results have to be changed. void rename(Value from, Value to); + /// Returns "true" if the given value may be a terminal. + bool mayBeTerminalBuffer(Value value) const; + private: /// This function constructs a mapping from values to its immediate /// dependencies. @@ -70,6 +73,9 @@ private: /// Maps values to all immediate dependencies this value can have. ValueMapT dependencies; + + /// A set of all SSA values that may be terminal buffers. + DenseSet terminals; }; } // namespace mlir diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/BufferViewFlowOpInterfaceImpl.h b/mlir/include/mlir/Dialect/MemRef/Transforms/BufferViewFlowOpInterfaceImpl.h new file mode 100644 index 0000000..714518a --- /dev/null +++ b/mlir/include/mlir/Dialect/MemRef/Transforms/BufferViewFlowOpInterfaceImpl.h @@ -0,0 +1,20 @@ +//===- BufferViewFlowOpInterfaceImpl.h - Buffer View Analysis ---*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_MEMREF_TRANSFORMS_BUFFERVIEWFLOWOPINTERFACEIMPL_H +#define MLIR_DIALECT_MEMREF_TRANSFORMS_BUFFERVIEWFLOWOPINTERFACEIMPL_H + +namespace mlir { +class DialectRegistry; + +namespace memref { +void registerBufferViewFlowOpInterfaceExternalModels(DialectRegistry ®istry); +} // namespace memref +} // namespace mlir + +#endif // MLIR_DIALECT_MEMREF_TRANSFORMS_BUFFERVIEWFLOWOPINTERFACEIMPL_H diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h index 9bbf12d..c558dc5 100644 --- a/mlir/include/mlir/InitAllDialects.h +++ b/mlir/include/mlir/InitAllDialects.h @@ -21,6 +21,7 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.h" #include "mlir/Dialect/Arith/Transforms/BufferDeallocationOpInterfaceImpl.h" +#include "mlir/Dialect/Arith/Transforms/BufferViewFlowOpInterfaceImpl.h" #include "mlir/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.h" #include "mlir/Dialect/ArmNeon/ArmNeonDialect.h" #include "mlir/Dialect/ArmSME/IR/ArmSME.h" @@ -52,6 +53,7 @@ #include "mlir/Dialect/MemRef/IR/MemRefMemorySlot.h" #include "mlir/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.h" #include "mlir/Dialect/MemRef/Transforms/AllocationOpInterfaceImpl.h" +#include "mlir/Dialect/MemRef/Transforms/BufferViewFlowOpInterfaceImpl.h" #include "mlir/Dialect/MemRef/Transforms/RuntimeOpVerification.h" #include "mlir/Dialect/Mesh/IR/MeshDialect.h" #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h" @@ -148,6 +150,7 @@ inline void registerAllDialects(DialectRegistry ®istry) { affine::registerValueBoundsOpInterfaceExternalModels(registry); arith::registerBufferDeallocationOpInterfaceExternalModels(registry); arith::registerBufferizableOpInterfaceExternalModels(registry); + arith::registerBufferViewFlowOpInterfaceExternalModels(registry); arith::registerValueBoundsOpInterfaceExternalModels(registry); bufferization::func_ext::registerBufferizableOpInterfaceExternalModels( registry); @@ -157,6 +160,7 @@ inline void registerAllDialects(DialectRegistry ®istry) { gpu::registerBufferDeallocationOpInterfaceExternalModels(registry); linalg::registerAllDialectInterfaceImplementations(registry); memref::registerAllocationOpInterfaceExternalModels(registry); + memref::registerBufferViewFlowOpInterfaceExternalModels(registry); memref::registerRuntimeVerifiableOpInterfaceExternalModels(registry); memref::registerValueBoundsOpInterfaceExternalModels(registry); memref::registerMemorySlotExternalModels(registry); diff --git a/mlir/lib/Dialect/Arith/Transforms/BufferViewFlowOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/BufferViewFlowOpInterfaceImpl.cpp new file mode 100644 index 0000000..9df9df8 --- /dev/null +++ b/mlir/lib/Dialect/Arith/Transforms/BufferViewFlowOpInterfaceImpl.cpp @@ -0,0 +1,44 @@ +//===- BufferViewFlowOpInterfaceImpl.cpp - Buffer View Flow Analysis ------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Arith/Transforms/BufferViewFlowOpInterfaceImpl.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.h" + +using namespace mlir; +using namespace mlir::bufferization; + +namespace mlir { +namespace arith { +namespace { + +struct SelectOpInterface + : public BufferViewFlowOpInterface::ExternalModel { + void + populateDependencies(Operation *op, + RegisterDependenciesFn registerDependenciesFn) const { + auto selectOp = cast(op); + + // Either one of the true/false value may be selected at runtime. + registerDependenciesFn(selectOp.getTrueValue(), selectOp.getResult()); + registerDependenciesFn(selectOp.getFalseValue(), selectOp.getResult()); + } +}; + +} // namespace +} // namespace arith +} // namespace mlir + +void arith::registerBufferViewFlowOpInterfaceExternalModels( + DialectRegistry ®istry) { + registry.addExtension(+[](MLIRContext *ctx, arith::ArithDialect *dialect) { + SelectOp::attachInterface(*ctx); + }); +} diff --git a/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt index 0224060..12659ea 100644 --- a/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt @@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIRArithTransforms BufferDeallocationOpInterfaceImpl.cpp BufferizableOpInterfaceImpl.cpp Bufferize.cpp + BufferViewFlowOpInterfaceImpl.cpp EmulateUnsupportedFloats.cpp EmulateWideInt.cpp EmulateNarrowType.cpp diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferViewFlowOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferViewFlowOpInterface.cpp new file mode 100644 index 0000000..ea726a4 --- /dev/null +++ b/mlir/lib/Dialect/Bufferization/IR/BufferViewFlowOpInterface.cpp @@ -0,0 +1,18 @@ +//===- BufferViewFlowOpInterface.cpp - Buffer View Flow Analysis ----------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" + +namespace mlir { +namespace bufferization { + +#include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.cpp.inc" + +} // namespace bufferization +} // namespace mlir diff --git a/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt b/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt index 9895db9..63dcc1e 100644 --- a/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt @@ -4,6 +4,7 @@ add_mlir_dialect_library(MLIRBufferizationDialect BufferDeallocationOpInterface.cpp BufferizationOps.cpp BufferizationDialect.cpp + BufferViewFlowOpInterface.cpp UnstructuredControlFlow.cpp ADDITIONAL_HEADER_DIRS diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp index 88ef1b6..9a36057 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp @@ -8,12 +8,16 @@ #include "mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h" +#include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.h" +#include "mlir/Interfaces/CallInterfaces.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" +#include "mlir/Interfaces/FunctionInterfaces.h" #include "mlir/Interfaces/ViewLikeInterface.h" #include "llvm/ADT/SetOperations.h" #include "llvm/ADT/SetVector.h" using namespace mlir; +using namespace mlir::bufferization; /// Constructs a new alias analysis using the op provided. BufferViewFlowAnalysis::BufferViewFlowAnalysis(Operation *op) { build(op); } @@ -65,18 +69,44 @@ void BufferViewFlowAnalysis::rename(Value from, Value to) { void BufferViewFlowAnalysis::build(Operation *op) { // Registers all dependencies of the given values. auto registerDependencies = [&](ValueRange values, ValueRange dependencies) { - for (auto [value, dep] : llvm::zip(values, dependencies)) + for (auto [value, dep] : llvm::zip_equal(values, dependencies)) this->dependencies[value].insert(dep); }; + // Mark all buffer results and buffer region entry block arguments of the + // given op as terminals. + auto populateTerminalValues = [&](Operation *op) { + for (Value v : op->getResults()) + if (isa(v.getType())) + this->terminals.insert(v); + for (Region &r : op->getRegions()) + for (BlockArgument v : r.getArguments()) + if (isa(v.getType())) + this->terminals.insert(v); + }; + op->walk([&](Operation *op) { - // TODO: We should have an op interface instead of a hard-coded list of - // interfaces/ops. + // Query BufferViewFlowOpInterface. If the op does not implement that + // interface, try to infer the dependencies from other interfaces that the + // op may implement. + if (auto bufferViewFlowOp = dyn_cast(op)) { + bufferViewFlowOp.populateDependencies(registerDependencies); + for (Value v : op->getResults()) + if (isa(v.getType()) && + bufferViewFlowOp.mayBeTerminalBuffer(v)) + this->terminals.insert(v); + for (Region &r : op->getRegions()) + for (BlockArgument v : r.getArguments()) + if (isa(v.getType()) && + bufferViewFlowOp.mayBeTerminalBuffer(v)) + this->terminals.insert(v); + return WalkResult::advance(); + } // Add additional dependencies created by view changes to the alias list. if (auto viewInterface = dyn_cast(op)) { - dependencies[viewInterface.getViewSource()].insert( - viewInterface->getResult(0)); + registerDependencies(viewInterface.getViewSource(), + viewInterface->getResult(0)); return WalkResult::advance(); } @@ -131,16 +161,30 @@ void BufferViewFlowAnalysis::build(Operation *op) { return WalkResult::advance(); } - // Unknown op: Assume that all operands alias with all results. - for (Value operand : op->getOperands()) { - if (!isa(operand.getType())) - continue; - for (Value result : op->getResults()) { - if (!isa(result.getType())) - continue; - registerDependencies({operand}, {result}); - } + // Region terminators are handled together with RegionBranchOpInterface. + if (isa(op)) + return WalkResult::advance(); + + if (isa(op)) { + // This is an intra-function analysis. We have no information about other + // functions. Conservatively assume that each operand may alias with each + // result. Also mark the results are terminals because the function could + // return newly allocated buffers. + populateTerminalValues(op); + for (Value operand : op->getOperands()) + for (Value result : op->getResults()) + registerDependencies({operand}, {result}); + return WalkResult::advance(); } + + // We have no information about unknown ops. + populateTerminalValues(op); + return WalkResult::advance(); }); } + +bool BufferViewFlowAnalysis::mayBeTerminalBuffer(Value value) const { + assert(isa(value.getType()) && "expected memref"); + return terminals.contains(value); +} diff --git a/mlir/lib/Dialect/MemRef/Transforms/BufferViewFlowOpInterfaceImpl.cpp b/mlir/lib/Dialect/MemRef/Transforms/BufferViewFlowOpInterfaceImpl.cpp new file mode 100644 index 0000000..bbb269b --- /dev/null +++ b/mlir/lib/Dialect/MemRef/Transforms/BufferViewFlowOpInterfaceImpl.cpp @@ -0,0 +1,48 @@ +//===- BufferViewFlowOpInterfaceImpl.cpp - Buffer View Flow Analysis ------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/MemRef/Transforms/BufferViewFlowOpInterfaceImpl.h" + +#include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" + +using namespace mlir; +using namespace mlir::bufferization; + +namespace mlir { +namespace memref { +namespace { + +struct ReallocOpInterface + : public BufferViewFlowOpInterface::ExternalModel { + void + populateDependencies(Operation *op, + RegisterDependenciesFn registerDependenciesFn) const { + auto reallocOp = cast(op); + // memref.realloc may return the source operand. + registerDependenciesFn(reallocOp.getSource(), reallocOp.getResult()); + } + + bool mayBeTerminalBuffer(Operation *op, Value value) const { + // The return value of memref.realloc is a terminal buffer because the op + // may return a newly allocated buffer. + return true; + } +}; + +} // namespace +} // namespace memref +} // namespace mlir + +void memref::registerBufferViewFlowOpInterfaceExternalModels( + DialectRegistry ®istry) { + registry.addExtension(+[](MLIRContext *ctx, memref::MemRefDialect *dialect) { + ReallocOp::attachInterface(*ctx); + }); +} diff --git a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt index 08b7eab..f150ac7 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt @@ -1,5 +1,6 @@ add_mlir_dialect_library(MLIRMemRefTransforms AllocationOpInterfaceImpl.cpp + BufferViewFlowOpInterfaceImpl.cpp ComposeSubView.cpp ExpandOps.cpp ExpandRealloc.cpp @@ -27,6 +28,7 @@ add_mlir_dialect_library(MLIRMemRefTransforms MLIRArithDialect MLIRArithTransforms MLIRBufferizationDialect + MLIRBufferizationTransforms MLIRDialectUtils MLIRFuncDialect MLIRGPUDialect diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel index 5b6e467..88b46bd 100644 --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -10829,6 +10829,36 @@ gentbl_cc_library( ) td_library( + name = "BufferViewFlowOpInterfaceTdFiles", + srcs = [ + "include/mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.td", + ], + includes = ["include"], + deps = [ + ":OpBaseTdFiles", + ], +) + +gentbl_cc_library( + name = "BufferViewFlowOpInterfaceIncGen", + tbl_outs = [ + ( + ["-gen-op-interface-decls"], + "include/mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.h.inc", + ), + ( + ["-gen-op-interface-defs"], + "include/mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.cpp.inc", + ), + ], + tblgen = ":mlir-tblgen", + td_file = "include/mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.td", + deps = [ + ":BufferViewFlowOpInterfaceTdFiles", + ], +) + +td_library( name = "SubsetOpInterfaceTdFiles", srcs = [ "include/mlir/Interfaces/SubsetOpInterface.td", @@ -12977,6 +13007,8 @@ cc_library( ":ArithTransforms", ":ArithUtils", ":BufferizationDialect", + ":BufferizationInterfaces", + ":BufferizationTransforms", ":ControlFlowDialect", ":DialectUtils", ":FuncDialect", @@ -13369,6 +13401,7 @@ td_library( includes = ["include"], deps = [ ":AllocationOpInterfaceTdFiles", + ":BufferViewFlowOpInterfaceTdFiles", ":BufferizableOpInterfaceTdFiles", ":CopyOpInterfaceTdFiles", ":DestinationStyleOpInterfaceTdFiles", @@ -13515,11 +13548,13 @@ cc_library( ], hdrs = [ "include/mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h", + "include/mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.h", "include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h", ], includes = ["include"], deps = [ ":BufferDeallocationOpInterfaceIncGen", + ":BufferViewFlowOpInterfaceIncGen", ":BufferizableOpInterfaceIncGen", ":BufferizationEnumsIncGen", ":IR", @@ -13532,6 +13567,7 @@ cc_library( name = "BufferizationDialect", srcs = [ "lib/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp", + "lib/Dialect/Bufferization/IR/BufferViewFlowOpInterface.cpp", "lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp", "lib/Dialect/Bufferization/IR/BufferizationDialect.cpp", "lib/Dialect/Bufferization/IR/BufferizationOps.cpp", @@ -13549,10 +13585,12 @@ cc_library( ":Analysis", ":ArithDialect", ":BufferDeallocationOpInterfaceIncGen", + ":BufferViewFlowOpInterfaceIncGen", ":BufferizableOpInterfaceIncGen", ":BufferizationBaseIncGen", ":BufferizationInterfaces", ":BufferizationOpsIncGen", + ":CallOpInterfaces", ":ControlFlowInterfaces", ":CopyOpInterface", ":DestinationStyleOpInterface", @@ -13602,9 +13640,11 @@ cc_library( ":BufferizationDialect", ":BufferizationInterfaces", ":BufferizationPassIncGen", + ":CallOpInterfaces", ":ControlFlowDialect", ":ControlFlowInterfaces", ":FuncDialect", + ":FunctionInterfaces", ":IR", ":LoopLikeInterface", ":MemRefDialect", -- cgit v1.1