-rw-r--r--  mlir/include/mlir/Transforms/BufferPlacement.h                                   172
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp                           13
-rw-r--r--  mlir/lib/Transforms/BufferPlacement.cpp                                           98
-rw-r--r--  mlir/test/Transforms/buffer-placement-preparation-allowed-memref-results.mlir    108
-rw-r--r--  mlir/test/Transforms/buffer-placement-preparation.mlir                             4
-rw-r--r--  mlir/test/lib/Transforms/TestBufferPlacement.cpp                                   53
-rw-r--r--  mlir/tools/mlir-opt/mlir-opt.cpp                                                    2
7 files changed, 302 insertions, 148 deletions
diff --git a/mlir/include/mlir/Transforms/BufferPlacement.h b/mlir/include/mlir/Transforms/BufferPlacement.h
index 89cb4b0..547db48 100644
--- a/mlir/include/mlir/Transforms/BufferPlacement.h
+++ b/mlir/include/mlir/Transforms/BufferPlacement.h
@@ -18,6 +18,7 @@
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Dominance.h"
+#include "mlir/IR/Function.h"
#include "mlir/IR/Operation.h"
#include "mlir/Transforms/DialectConversion.h"
@@ -88,12 +89,23 @@ public:
static bool isConvertedMemref(Type type, Type before);
};
-/// Converts the signature of the function using the type converter. It adds an
-/// extra argument for each function result type which is going to be a memref
-/// type after type conversion. The other function result types remain
-/// unchanged. `BufferAssignmentTypeConverter` is a helper `TypeConverter` for
-/// this purpose.
-class FunctionAndBlockSignatureConverter
+namespace detail {
+
+/// Converts the signature of the function based on whether the function is
+/// allowed to return memref typed results or not, as selected by the
+/// `allowMemrefFunctionResults` parameter. If this option is false, an extra
+/// output-buffer argument is added for each function result that becomes a
+/// memref type only after type conversion; the other function result types
+/// remain unchanged. If `allowMemrefFunctionResults` is true, the result
+/// types are converted in place. Any change in the function signature needs
+/// to be applied to the return and call operations as well.
+/// `BufferAssignmentReturnOpConverter` and `BufferAssignmentCallOpConverter`
+/// are two helper patterns that match the return and call operations with
+/// the new function signature. Furthermore, `BufferAssignmentTypeConverter`
+/// is a helper `TypeConverter` for converting tensor typed values to memref
+/// typed ones.
+template <bool allowMemrefFunctionResults>
+class BufferAssignmentFuncOpConverter
: public BufferAssignmentOpConversionPattern<FuncOp> {
public:
using BufferAssignmentOpConversionPattern<
@@ -101,17 +113,55 @@ public:
/// Performs the actual signature rewriting step.
LogicalResult
- matchAndRewrite(FuncOp funcOp, ArrayRef<Value> operands,
- ConversionPatternRewriter &rewriter) const final;
+ matchAndRewrite(mlir::FuncOp funcOp, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const final {
+ if (!converter)
+ return funcOp.emitError("The type converter has not been defined for "
+ "BufferAssignmentFuncOpConverter");
+ auto funcType = funcOp.getType();
+
+ // Convert function arguments using the provided TypeConverter.
+ TypeConverter::SignatureConversion conversion(funcType.getNumInputs());
+ for (auto argType : llvm::enumerate(funcType.getInputs()))
+ conversion.addInputs(argType.index(),
+ converter->convertType(argType.value()));
+
+ // If allowMemrefFunctionResults is false and a function result type is not
+ // a memref but it would be a memref after type conversion, a new argument
+ // should be appended to the function arguments list for this result.
+ // Otherwise, it remains unchanged as a function result.
+ SmallVector<Type, 2> newResultTypes;
+ newResultTypes.reserve(funcOp.getNumResults());
+ for (Type resType : funcType.getResults()) {
+ Type convertedType = converter->convertType(resType);
+ if (!allowMemrefFunctionResults &&
+ BufferAssignmentTypeConverter::isConvertedMemref(convertedType,
+ resType))
+ conversion.addInputs(convertedType);
+ else
+ newResultTypes.push_back(convertedType);
+ }
+
+ // Update the signature of the function.
+ rewriter.updateRootInPlace(funcOp, [&] {
+ funcOp.setType(rewriter.getFunctionType(conversion.getConvertedTypes(),
+ newResultTypes));
+ rewriter.applySignatureConversion(&funcOp.getBody(), conversion);
+ });
+ return success();
+ }
};
/// Rewrites the `ReturnOp` to conform with the changed function signature.
-/// Operands that correspond to return values that have been rewritten from
-/// tensor results to memref arguments are dropped. In their place, a
-/// corresponding copy operation from the operand to the new function argument
-/// is inserted.
+/// If allowMemrefFunctionResults is false, operands that correspond to return
+/// values rewritten from illegal result types to memref output arguments are
+/// dropped; in their place, a copy operation from the operand to the
+/// corresponding output function argument is inserted. Otherwise, the memref
+/// typed operands are returned unchanged.
+/// Note: If this pattern is used with BufferAssignmentFuncOpConverter,
+/// allowMemrefFunctionResults must be set identically for both.
template <typename ReturnOpSourceTy, typename ReturnOpTargetTy,
- typename CopyOpTy>
+ typename CopyOpTy, bool allowMemrefFunctionResults>
class BufferAssignmentReturnOpConverter
: public BufferAssignmentOpConversionPattern<ReturnOpSourceTy> {
public:
@@ -122,6 +172,13 @@ public:
LogicalResult
matchAndRewrite(ReturnOpSourceTy returnOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
+ // If the memref typed results can be returned as function results, the new
+ // `ReturnOp` should only return the type converted operands.
+ if (allowMemrefFunctionResults) {
+ rewriter.replaceOpWithNewOp<ReturnOpTargetTy>(returnOp, operands);
+ return success();
+ }
+
// Split the operands by their kinds whether they are converted memref or
// not.
SmallVector<Value, 2> needCopyOperands, newOperands;
@@ -158,20 +215,99 @@ public:
}
};
-/// Converts `CallOp` to match its operands and results with the
-/// the callee after rewriting the callee with
-/// FunctionAndBlockSignatureConverter.
+/// Rewrites the `CallOp` to match its operands and results with the signature
+/// of the callee after rewriting the callee with
+/// BufferAssignmentFuncOpConverter. If allowMemrefFunctionResults is false, an
+/// output buffer is allocated for each result that has been rewritten to a
+/// memref type, and the newly allocated buffer is passed through the operand
+/// list of the new `CallOp`.
+/// Note: If this pattern is used with BufferAssignmentFuncOpConverter,
+/// allowMemrefFunctionResults must be set identically for both.
+template <bool allowMemrefFunctionResults>
class BufferAssignmentCallOpConverter
: public BufferAssignmentOpConversionPattern<CallOp> {
public:
using BufferAssignmentOpConversionPattern<
CallOp>::BufferAssignmentOpConversionPattern;
- /// Performs the actual `CallOp` conversion step.
LogicalResult
matchAndRewrite(CallOp callOp, ArrayRef<Value> operands,
- ConversionPatternRewriter &rewriter) const final;
+ ConversionPatternRewriter &rewriter) const final {
+ if (!converter)
+ return callOp.emitError("The type converter has not been defined for "
+ "BufferAssignmentCallOpConverter");
+ Location loc = callOp.getLoc();
+
+ // If the memref typed results can be returned as function results, there is
+ // no need to create output buffers. It is only required to convert the type
+ // of operands and results in place for creating the new `CallOp`.
+ if (allowMemrefFunctionResults) {
+ SmallVector<Type, 2> resultTypes;
+ resultTypes.reserve(callOp.getNumResults());
+ for (Type type : callOp.getResultTypes())
+ resultTypes.push_back(converter->convertType(type));
+ rewriter.replaceOpWithNewOp<CallOp>(callOp, callOp.getCallee(),
+ resultTypes, operands);
+ return success();
+ }
+
+ SmallVector<Value, 2> newOperands, replacingValues;
+ SmallVector<Type, 2> newResultTypes;
+ unsigned numResults = callOp.getNumResults();
+ newOperands.reserve(numResults + operands.size());
+ newOperands.append(operands.begin(), operands.end());
+ newResultTypes.reserve(numResults);
+ replacingValues.reserve(numResults);
+
+ // For each memref result of `CallOp` which has not been a memref before
+ // the type conversion, a new buffer is allocated and passed to the operands
+ // list of the new `CallOp`. Otherwise, it remains as a caller result.
+ for (Value result : callOp.getResults()) {
+ Type currType = result.getType();
+ Type newType = converter->convertType(result.getType());
+ if (BufferAssignmentTypeConverter::isConvertedMemref(newType, currType)) {
+ OpBuilder::InsertionGuard guard(rewriter);
+ rewriter.restoreInsertionPoint(bufferAssignment->computeAllocPosition(
+ result.dyn_cast<OpResult>()));
+ Value alloc =
+ rewriter.create<AllocOp>(loc, newType.dyn_cast<MemRefType>());
+ newOperands.push_back(alloc);
+ replacingValues.push_back(alloc);
+ } else {
+ newResultTypes.push_back(currType);
+
+ // No replacing is required.
+ replacingValues.push_back(nullptr);
+ }
+ }
+
+ // Creating the new `CallOp`.
+ rewriter.create<CallOp>(loc, callOp.getCallee(), newResultTypes,
+ newOperands);
+
+ // Replacing the results of the old `CallOp`.
+ rewriter.replaceOp(callOp, replacingValues);
+ return success();
+ }
};
+} // end namespace detail
+
+/// Populates `patterns` with the conversion patterns of buffer
+/// assignment.
+template <typename ReturnOpSourceTy, typename ReturnOpTargetTy,
+ typename CopyOpTy, bool allowMemrefFunctionResults>
+static void populateWithBufferAssignmentOpConversionPatterns(
+ MLIRContext *context, BufferAssignmentPlacer *placer,
+ TypeConverter *converter, OwningRewritePatternList *patterns) {
+ // clang-format off
+ patterns->insert<
+ detail::BufferAssignmentCallOpConverter<allowMemrefFunctionResults>,
+ detail::BufferAssignmentFuncOpConverter<allowMemrefFunctionResults>,
+ detail::BufferAssignmentReturnOpConverter
+ <ReturnOpSourceTy, ReturnOpTargetTy, CopyOpTy, allowMemrefFunctionResults>
+ >(context, placer, converter);
+ // clang-format on
+}
} // end namespace mlir
#endif // MLIR_TRANSFORMS_BUFFERPLACEMENT_H
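To make the two modes of BufferAssignmentFuncOpConverter and BufferAssignmentReturnOpConverter concrete, here is a rough MLIR sketch of the rewriting (function name and shapes are illustrative; the exact output depends on the type converter and the chosen copy op, e.g. linalg.copy in the Linalg path):

    // Input:
    func @simple(%arg0: tensor<2xf32>) -> tensor<2xf32> {
      return %arg0 : tensor<2xf32>
    }

    // With allowMemrefFunctionResults = false, the result becomes an extra
    // output-buffer argument and the return is rewritten into a copy:
    func @simple(%arg0: memref<2xf32>, %arg1: memref<2xf32>) {
      linalg.copy(%arg0, %arg1) : memref<2xf32>, memref<2xf32>
      return
    }

    // With allowMemrefFunctionResults = true, the result type is converted in
    // place and returned directly:
    func @simple(%arg0: memref<2xf32>) -> memref<2xf32> {
      return %arg0 : memref<2xf32>
    }
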
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp b/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp
index c663eb6..1f983e8 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp
@@ -20,9 +20,6 @@
#include "mlir/Transforms/BufferPlacement.h"
using namespace mlir;
-using ReturnOpConverter =
- BufferAssignmentReturnOpConverter<mlir::ReturnOp, mlir::ReturnOp,
- linalg::CopyOp>;
namespace {
/// A pattern to convert Generic Linalg operations which work on tensors to
@@ -103,11 +100,11 @@ public:
static void populateConvertLinalgOnTensorsToBuffersPattern(
MLIRContext *context, BufferAssignmentPlacer *placer,
TypeConverter *converter, OwningRewritePatternList *patterns) {
- // clang-format off
- patterns->insert<FunctionAndBlockSignatureConverter,
- GenericOpConverter,
- ReturnOpConverter>(context, placer, converter);
- // clang-format on
+ populateWithBufferAssignmentOpConversionPatterns<
+ mlir::ReturnOp, mlir::ReturnOp, linalg::CopyOp,
+ /*allowMemrefFunctionResults=*/false>(context, placer, converter,
+ patterns);
+ patterns->insert<GenericOpConverter>(context, placer, converter);
}
/// Converts Linalg operations that work on tensor-type operands or results to
diff --git a/mlir/lib/Transforms/BufferPlacement.cpp b/mlir/lib/Transforms/BufferPlacement.cpp
index edbaf8e..0bca5cf 100644
--- a/mlir/lib/Transforms/BufferPlacement.cpp
+++ b/mlir/lib/Transforms/BufferPlacement.cpp
@@ -49,8 +49,6 @@
//===----------------------------------------------------------------------===//
#include "mlir/Transforms/BufferPlacement.h"
-#include "mlir/IR/Function.h"
-#include "mlir/IR/Operation.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/Passes.h"
@@ -424,102 +422,6 @@ BufferAssignmentPlacer::computeAllocPosition(OpResult result) {
}
//===----------------------------------------------------------------------===//
-// FunctionAndBlockSignatureConverter
-//===----------------------------------------------------------------------===//
-
-// Performs the actual signature rewriting step.
-LogicalResult FunctionAndBlockSignatureConverter::matchAndRewrite(
- FuncOp funcOp, ArrayRef<Value> operands,
- ConversionPatternRewriter &rewriter) const {
- if (!converter) {
- funcOp.emitError("The type converter has not been defined for "
- "FunctionAndBlockSignatureConverter");
- return failure();
- }
- auto funcType = funcOp.getType();
-
- // Convert function arguments using the provided TypeConverter.
- TypeConverter::SignatureConversion conversion(funcType.getNumInputs());
- for (auto argType : llvm::enumerate(funcType.getInputs()))
- conversion.addInputs(argType.index(),
- converter->convertType(argType.value()));
-
- // If a function result type is not a memref but it would be a memref after
- // type conversion, a new argument should be appended to the function
- // arguments list for this result. Otherwise, it remains unchanged as a
- // function result.
- SmallVector<Type, 2> newResultTypes;
- newResultTypes.reserve(funcOp.getNumResults());
- for (Type resType : funcType.getResults()) {
- Type convertedType = converter->convertType(resType);
- if (BufferAssignmentTypeConverter::isConvertedMemref(convertedType,
- resType))
- conversion.addInputs(convertedType);
- else
- newResultTypes.push_back(convertedType);
- }
-
- // Update the signature of the function.
- rewriter.updateRootInPlace(funcOp, [&] {
- funcOp.setType(rewriter.getFunctionType(conversion.getConvertedTypes(),
- newResultTypes));
- rewriter.applySignatureConversion(&funcOp.getBody(), conversion);
- });
- return success();
-}
-
-//===----------------------------------------------------------------------===//
-// BufferAssignmentCallOpConverter
-//===----------------------------------------------------------------------===//
-
-// Performs `CallOp` conversion to match its operands and results with the
-// signature of the callee after rewriting the callee with
-// FunctionAndBlockSignatureConverter.
-LogicalResult BufferAssignmentCallOpConverter::matchAndRewrite(
- CallOp callOp, ArrayRef<Value> operands,
- ConversionPatternRewriter &rewriter) const {
-
- Location loc = callOp.getLoc();
- SmallVector<Value, 2> newOperands, replacingValues;
- SmallVector<Type, 2> newResultTypes;
- unsigned numResults = callOp.getNumResults();
- newOperands.reserve(numResults + operands.size());
- newOperands.append(operands.begin(), operands.end());
- newResultTypes.reserve(numResults);
- replacingValues.reserve(numResults);
-
- // For each memref result of `CallOp` which has not been a memref before type
- // conversion, a new buffer is allocated and passed to the operands list of
- // the new `CallOp`. Otherwise, it remains as a caller result.
- for (Value result : callOp.getResults()) {
- Type currType = result.getType();
- Type newType = converter->convertType(result.getType());
- if (BufferAssignmentTypeConverter::isConvertedMemref(newType, currType)) {
- OpBuilder::InsertionGuard guard(rewriter);
- rewriter.restoreInsertionPoint(
- bufferAssignment->computeAllocPosition(result.dyn_cast<OpResult>()));
- Value alloc =
- rewriter.create<AllocOp>(loc, newType.dyn_cast<MemRefType>());
- newOperands.push_back(alloc);
- replacingValues.push_back(alloc);
- } else {
- newResultTypes.push_back(currType);
-
- // No replacing is required.
- replacingValues.push_back(nullptr);
- }
- }
-
- // Creating the new `CallOp`.
- rewriter.create<CallOp>(loc, callOp.getCallee(), newResultTypes, newOperands);
-
- // Replacing the results of the old `CallOp`.
- rewriter.replaceOp(callOp, replacingValues);
-
- return success();
-}
-
-//===----------------------------------------------------------------------===//
// BufferAssignmentTypeConverter
//===----------------------------------------------------------------------===//
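For the call-site rewriting handled by BufferAssignmentCallOpConverter (whose implementation moved from this file into the templated header pattern above), a rough MLIR sketch when allowMemrefFunctionResults is false; @callee is assumed to have been converted by BufferAssignmentFuncOpConverter accordingly, and linalg.copy stands in for whatever CopyOpTy the client pass selects:

    // Input:
    func @caller(%arg0: tensor<5xf32>) -> tensor<5xf32> {
      %0 = call @callee(%arg0) : (tensor<5xf32>) -> tensor<5xf32>
      return %0 : tensor<5xf32>
    }

    // After conversion, the rewritten call result is replaced by a freshly
    // allocated buffer passed to the callee as an output argument, and the
    // caller itself gains an output-buffer argument as well:
    func @caller(%arg0: memref<5xf32>, %arg1: memref<5xf32>) {
      %0 = alloc() : memref<5xf32>
      call @callee(%arg0, %0) : (memref<5xf32>, memref<5xf32>) -> ()
      linalg.copy(%0, %arg1) : memref<5xf32>, memref<5xf32>
      return
    }
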
diff --git a/mlir/test/Transforms/buffer-placement-preparation-allowed-memref-results.mlir b/mlir/test/Transforms/buffer-placement-preparation-allowed-memref-results.mlir
new file mode 100644
index 0000000..adf6e30
--- /dev/null
+++ b/mlir/test/Transforms/buffer-placement-preparation-allowed-memref-results.mlir
@@ -0,0 +1,108 @@
+// RUN: mlir-opt -test-buffer-placement-preparation-with-allowed-memref-results -split-input-file %s | FileCheck %s -dump-input-on-failure
+
+// Since allowMemrefFunctionResults is enabled for buffer placement in this
+// test pass, all tensor typed function results are converted to memrefs and
+// remain as function results. All memref typed function results will escape
+// from the deallocation phase of Buffer Placement.
+
+// CHECK-LABEL: func @void_function_signature_conversion
+func @void_function_signature_conversion(%arg0: tensor<4x8xf32>) {
+ return
+}
+// CHECK: ({{.*}}: memref<4x8xf32>)
+
+// -----
+
+#map0 = affine_map<(d0) -> (d0)>
+
+// CHECK-LABEL: func @complex_signature_conversion
+func @complex_signature_conversion(%arg0: tensor<5xf32>, %arg1: memref<10xf32>, %arg2: i1, %arg3: f16) -> (i1, tensor<5xf32>, memref<10xf32>, memref<15xf32>, f16) {
+ %0 = alloc() : memref<15xf32>
+ %1 = linalg.generic {
+ args_in = 1 : i64,
+ args_out = 1 : i64,
+ indexing_maps = [#map0, #map0],
+ iterator_types = ["parallel"]
+ } %arg0 {
+ ^bb0(%gen1_arg0: f32):
+ %tmp1 = exp %gen1_arg0 : f32
+ linalg.yield %tmp1 : f32
+ }: tensor<5xf32> -> tensor<5xf32>
+ return %arg2, %1, %arg1, %0, %arg3 : i1, tensor<5xf32>, memref<10xf32>, memref<15xf32>, f16
+}
+// CHECK: (%[[ARG0:.*]]: memref<5xf32>, %[[ARG1:.*]]: memref<10xf32>, %[[ARG2:.*]]: i1, %[[ARG3:.*]]: f16)
+// CHECK-SAME: (i1, memref<5xf32>, memref<10xf32>, memref<15xf32>, f16)
+// CHECK: %[[FIRST_ALLOC:.*]] = alloc()
+// CHECK: %[[LINALG_ALLOC:.*]] = alloc()
+// CHECK: return %[[ARG2]], %[[LINALG_ALLOC]], %[[ARG1]], %[[FIRST_ALLOC]], %[[ARG3]]
+
+// -----
+
+// CHECK-LABEL: func @no_signature_conversion_is_needed
+func @no_signature_conversion_is_needed(%arg0: memref<4x8xf32>) {
+ return
+}
+// CHECK: ({{.*}}: memref<4x8xf32>)
+
+// -----
+
+// CHECK-LABEL: func @no_signature_conversion_is_needed
+func @no_signature_conversion_is_needed(%arg0: i1, %arg1: f16) -> (i1, f16){
+ return %arg0, %arg1 : i1, f16
+}
+// CHECK: (%[[ARG0:.*]]: i1, %[[ARG1:.*]]: f16) -> (i1, f16)
+// CHECK: return %[[ARG0]], %[[ARG1]]
+
+// -----
+
+// CHECK-LABEL: func @simple_signature_conversion
+func @simple_signature_conversion(%arg0: tensor<4x8xf32>) -> tensor<4x8xf32> {
+ return %arg0 : tensor<4x8xf32>
+}
+// CHECK: (%[[ARG0:.*]]: [[TYPE:.*]]<[[RANK:.*]]>) -> [[TYPE]]<[[RANK]]>
+// CHECK-NEXT: return %[[ARG0]]
+
+// -----
+
+// CHECK-LABEL: func @func_and_block_signature_conversion
+func @func_and_block_signature_conversion(%arg0 : tensor<2xf32>, %cond : i1, %arg1: tensor<4x4xf32>) -> tensor<4x4xf32>{
+ cond_br %cond, ^bb1, ^bb2
+ ^bb1:
+ br ^exit(%arg0 : tensor<2xf32>)
+ ^bb2:
+ br ^exit(%arg0 : tensor<2xf32>)
+ ^exit(%arg2: tensor<2xf32>):
+ return %arg1 : tensor<4x4xf32>
+}
+// CHECK: (%[[ARG0:.*]]: [[ARG0_TYPE:.*]], %[[COND:.*]]: i1, %[[ARG1:.*]]: [[ARG1_TYPE:.*]]) -> [[RESULT_TYPE:.*]]
+// CHECK: br ^[[EXIT_BLOCK:.*]](%[[ARG0]] : [[ARG0_TYPE]])
+// CHECK: br ^[[EXIT_BLOCK]](%[[ARG0]] : [[ARG0_TYPE]])
+// CHECK: ^[[EXIT_BLOCK]](%{{.*}}: [[ARG0_TYPE]])
+// CHECK-NEXT: return %[[ARG1]]
+
+// -----
+
+// CHECK-LABEL: func @callee
+func @callee(%arg1: tensor<5xf32>) -> (tensor<5xf32>, memref<2xf32>) {
+ %buff = alloc() : memref<2xf32>
+ return %arg1, %buff : tensor<5xf32>, memref<2xf32>
+}
+// CHECK: (%[[CALLEE_ARG:.*]]: memref<5xf32>) -> (memref<5xf32>, memref<2xf32>)
+// CHECK: %[[ALLOC:.*]] = alloc()
+// CHECK: return %[[CALLEE_ARG]], %[[ALLOC]]
+
+// CHECK-LABEL: func @caller
+func @caller(%arg0: tensor<5xf32>) -> tensor<5xf32> {
+ %x:2 = call @callee(%arg0) : (tensor<5xf32>) -> (tensor<5xf32>, memref<2xf32>)
+ %y:2 = call @callee(%x#0) : (tensor<5xf32>) -> (tensor<5xf32>, memref<2xf32>)
+ return %y#0 : tensor<5xf32>
+}
+// CHECK: (%[[CALLER_ARG:.*]]: memref<5xf32>) -> memref<5xf32>
+// CHECK: %[[X:.*]]:2 = call @callee(%[[CALLER_ARG]])
+// CHECK: %[[Y:.*]]:2 = call @callee(%[[X]]#0)
+// CHECK: return %[[Y]]#0
+
+
+
+
+
diff --git a/mlir/test/Transforms/buffer-placement-preparation.mlir b/mlir/test/Transforms/buffer-placement-preparation.mlir
index 5cde928..cae2829 100644
--- a/mlir/test/Transforms/buffer-placement-preparation.mlir
+++ b/mlir/test/Transforms/buffer-placement-preparation.mlir
@@ -199,7 +199,7 @@ func @compute_allocs_position(%cond: i1, %arg0: tensor<2xf32>) -> tensor<2xf32>{
// -----
// Test case: Checking BufferAssignmentCallOpConverter and
-// FunctionAndBlockSignatureConverter and BufferAssignmentReturnOpConverter all
+// BufferAssignmentFuncOpConverter and BufferAssignmentReturnOpConverter all
// together. The signature of `callee` after signature conversion would be:
// func @callee(%arg0: memref<5xf32>,%arg1: memref<5xf32>) -> ()
@@ -246,7 +246,7 @@ func @caller(%arg0: tensor<5xf32>) -> tensor<5xf32> {
// -----
// Test case: Checking BufferAssignmentCallOpConverter and
-// FunctionAndBlockSignatureConverter and BufferAssignmentReturnOpConverter all
+// BufferAssignmentFuncOpConverter and BufferAssignmentReturnOpConverter all
// together on functions that also have memref typed results. The signature of
// `callee` after signature conversion would be:
diff --git a/mlir/test/lib/Transforms/TestBufferPlacement.cpp b/mlir/test/lib/Transforms/TestBufferPlacement.cpp
index aee12b37..3d0cc29 100644
--- a/mlir/test/lib/Transforms/TestBufferPlacement.cpp
+++ b/mlir/test/lib/Transforms/TestBufferPlacement.cpp
@@ -21,17 +21,22 @@
using namespace mlir;
namespace {
-/// This pass tests the computeAllocPosition helper method and two provided
-/// operation converters, FunctionAndBlockSignatureConverter and
-/// BufferAssignmentReturnOpConverter. Furthermore, this pass converts linalg
-/// operations on tensors to linalg operations on buffers to prepare them for
-/// the BufferPlacement pass that can be applied afterwards.
+/// This pass tests the computeAllocPosition helper method and the buffer
+/// assignment operation converters. Furthermore, this pass converts linalg
+/// operations on tensors to linalg operations on buffers to prepare them for
+/// the BufferPlacement pass that can be applied afterwards.
+/// `allowMemrefFunctionResults` tells buffer placement to allow functions with
+/// memref typed results; the buffer assignment operation converters are
+/// adapted accordingly. It also allows memref typed results to escape from
+/// deallocation.
+template <bool allowMemrefFunctionResults>
struct TestBufferPlacementPreparationPass
- : mlir::PassWrapper<TestBufferPlacementPreparationPass,
- OperationPass<ModuleOp>> {
+ : mlir::PassWrapper<
+ TestBufferPlacementPreparationPass<allowMemrefFunctionResults>,
+ OperationPass<ModuleOp>> {
- /// Converts tensor-type generic linalg operations to memref ones using buffer
- /// assignment.
+ /// Converts tensor-type generic linalg operations to memref ones using
+ /// buffer assignment.
class GenericOpConverter
: public BufferAssignmentOpConversionPattern<linalg::GenericOp> {
public:
@@ -104,19 +109,14 @@ struct TestBufferPlacementPreparationPass
void populateTensorLinalgToBufferLinalgConversionPattern(
MLIRContext *context, BufferAssignmentPlacer *placer,
TypeConverter *converter, OwningRewritePatternList *patterns) {
- // clang-format off
- patterns->insert<
- BufferAssignmentCallOpConverter,
- FunctionAndBlockSignatureConverter,
- GenericOpConverter,
- BufferAssignmentReturnOpConverter<
- ReturnOp, ReturnOp, linalg::CopyOp>
- >(context, placer, converter);
- // clang-format on
+ populateWithBufferAssignmentOpConversionPatterns<
+ mlir::ReturnOp, mlir::ReturnOp, linalg::CopyOp,
+ allowMemrefFunctionResults>(context, placer, converter, patterns);
+ patterns->insert<GenericOpConverter>(context, placer, converter);
}
void runOnOperation() override {
- MLIRContext &context = getContext();
+ MLIRContext &context = this->getContext();
ConversionTarget target(context);
BufferAssignmentTypeConverter converter;
@@ -150,7 +150,7 @@ struct TestBufferPlacementPreparationPass
});
// Walk over all the functions to apply buffer assignment.
- getOperation().walk([&](FuncOp function) -> WalkResult {
+ this->getOperation().walk([&](FuncOp function) -> WalkResult {
OwningRewritePatternList patterns;
BufferAssignmentPlacer placer(function);
populateTensorLinalgToBufferLinalgConversionPattern(
@@ -165,9 +165,18 @@ struct TestBufferPlacementPreparationPass
namespace mlir {
void registerTestBufferPlacementPreparationPass() {
- PassRegistration<TestBufferPlacementPreparationPass>(
+ PassRegistration<
+ TestBufferPlacementPreparationPass</*allowMemrefFunctionResults=*/false>>(
"test-buffer-placement-preparation",
"Tests buffer placement helper methods including its "
"operation-conversion patterns");
}
-} // end namespace mlir
\ No newline at end of file
+
+void registerTestPreparationPassWithAllowedMemrefResults() {
+ PassRegistration<
+ TestBufferPlacementPreparationPass</*allowMemrefFunctionResults=*/true>>(
+ "test-buffer-placement-preparation-with-allowed-memref-results",
+ "Tests the helper operation converters of buffer placement for allowing "
+ "functions to have memref typed results.");
+}
+} // end namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 2d286e1..067a215 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -61,6 +61,7 @@ void registerTestMemRefDependenceCheck();
void registerTestMemRefStrideCalculation();
void registerTestOpaqueLoc();
void registerTestParallelismDetection();
+void registerTestPreparationPassWithAllowedMemrefResults();
void registerTestGpuParallelLoopMappingPass();
void registerTestSCFUtilsPass();
void registerTestVectorConversions();
@@ -133,6 +134,7 @@ void registerTestPasses() {
registerTestMemRefStrideCalculation();
registerTestOpaqueLoc();
registerTestParallelismDetection();
+ registerTestPreparationPassWithAllowedMemrefResults();
registerTestGpuParallelLoopMappingPass();
registerTestSCFUtilsPass();
registerTestVectorConversions();