diff options
| author | Matthias Springer <springerm@google.com> | 2022-03-03 19:50:32 +0900 |
|---|---|---|
| committer | Matthias Springer <springerm@google.com> | 2022-03-03 20:12:37 +0900 |
| commit | 16cbe883b57ceda7880b65bbeab83bff2493820a (patch) | |
| tree | a7fdf6031fe8b31529a58016e14985b8b3723b5d | |
| parent | 65c0e45a3790c391a0e87d9150c993dbb6537dee (diff) | |
| download | llvm-16cbe883b57ceda7880b65bbeab83bff2493820a.zip llvm-16cbe883b57ceda7880b65bbeab83bff2493820a.tar.gz llvm-16cbe883b57ceda7880b65bbeab83bff2493820a.tar.bz2 | |
[mlir][linalg][bufferize] Migrate --linalg-bufferize to BufferizableOpInterface-based bufferization
This commit deletes the old dialect conversion-based bufferization patterns, which are now obsolete.
Differential Revision: https://reviews.llvm.org/D120883
5 files changed, 46 insertions, 211 deletions
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h index ae03c4d..593057d 100644 --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h @@ -404,6 +404,17 @@ public: void replaceOpWithBufferizedValues(RewriterBase &rewriter, Operation *op, ValueRange values); +/// Lookup the buffer for the given value. If the value was not bufferized yet, +/// wrap it in a ToMemrefOp. Otherwise, it is the result of a ToTensorOp, from +/// which the memref operand is returned. +/// +/// Note: Use `BufferizationState::getBuffer` during bufferization. +/// `lookupBuffer` is just for compatibility and gradual migration of +/// bufferization patterns to BufferizableOpInterface-based bufferization. It +/// does not insert any buffer copies. +Value lookupBuffer(RewriterBase &rewriter, Value tensor, + const BufferizationOptions &options); + /// Replace an op with a new op. The new op must have the same number of /// results as the replaced op. The new op may not return any tensor values. template <typename OpTy, typename... Args> diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h index 54c1fa9..87f5cc1 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -113,17 +113,6 @@ void populateFusePadTensorWithProducerLinalgOpPatterns( /// canonicalizations of named ops into another named op. void populateLinalgNamedOpConversionPatterns(RewritePatternSet &patterns); -/// Populate the given list with patterns to bufferize linalg ops. 
-void populateLinalgBufferizePatterns( - bufferization::BufferizeTypeConverter &converter, - RewritePatternSet &patterns); - -/// Create linalg op on buffers given the original tensor-based operation and -/// the buffers for the outputs. -LinalgOp createLinalgOpOnBuffers(ConversionPatternRewriter &rewriter, - LinalgOp linalgOp, ValueRange inputs, - ValueRange outputs); - /// Patterns to fold unit-extent dimensions in operands/results of linalg ops on /// tensors. void populateFoldUnitExtentDimsPatterns(RewritePatternSet &patterns); diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp index ee5a34b..e5d9487 100644 --- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp @@ -212,8 +212,8 @@ static void ensureToMemrefOpIsValid(Value tensor, Type memrefType) { #endif } -static Value lookupBuffer(RewriterBase &rewriter, Value tensor, - const BufferizationOptions &options) { +Value mlir::bufferization::lookupBuffer(RewriterBase &rewriter, Value tensor, + const BufferizationOptions &options) { auto tensorType = tensor.getType().dyn_cast<TensorType>(); assert(tensorType && "unexpected non-tensor type"); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp index 510e14c..84b16bc 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp @@ -1,4 +1,4 @@ -//===- Bufferize.cpp - Bufferization of linalg ops ------------------===// +//===- Bufferize.cpp - Bufferization of linalg ops ------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. 
@@ -8,208 +8,40 @@ #include "PassDetail.h" -#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" -#include "mlir/Dialect/Arithmetic/Utils/Utils.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/Passes.h" +#include "mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" -#include "mlir/Dialect/Linalg/Utils/Utils.h" -#include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/BuiltinDialect.h" #include "mlir/IR/Operation.h" #include "mlir/Pass/Pass.h" -using namespace ::mlir; -using namespace ::mlir::linalg; - -static Value cloneMemref(Location loc, Value memref, OpBuilder &b) { - auto memrefType = memref.getType().cast<MemRefType>(); - auto alloc = b.create<memref::AllocOp>(loc, memrefType, - getDynOperands(loc, memref, b)); - b.create<memref::CopyOp>(loc, memref, alloc); - return alloc; -} - -static LogicalResult -allocateBuffersForResults(Location loc, LinalgOp linalgOp, ValueRange outputs, - SmallVectorImpl<Value> &resultBuffers, OpBuilder &b) { - // Lazily compute loopRanges. - SmallVector<Range, 4> loopRanges; - - // Allocate a buffer for every tensor result. 
- assert(linalgOp.getNumOutputs() == linalgOp->getNumResults()); - for (const auto &en : llvm::enumerate(linalgOp->getResultTypes())) { - size_t resultIndex = en.index(); - Type resultType = en.value(); - - auto tensorType = resultType.dyn_cast<RankedTensorType>(); - if (tensorType == nullptr) { - linalgOp.emitOpError() - << "tensor to buffer conversion expects ranked tensor results"; - return failure(); - } - auto tensorShape = tensorType.getShape(); - auto memrefType = MemRefType::get(tensorShape, tensorType.getElementType()); - Value resultTensor = outputs[resultIndex]; - - // Clone output buffers whose value is actually used. - OpOperand *tiedOpOperand = linalgOp.getOutputOperand(resultIndex); - if (linalgOp.payloadUsesValueFromOperand(tiedOpOperand)) { - resultBuffers.push_back(cloneMemref(loc, resultTensor, b)); - continue; - } - - // Allocate buffers for statically-shaped results. - if (memrefType.hasStaticShape()) { - resultBuffers.push_back(b.create<memref::AllocOp>(loc, memrefType)); - continue; - } - - resultBuffers.push_back(b.create<memref::AllocOp>( - loc, memrefType, getDynOperands(loc, resultTensor, b))); - } - return success(); -} - -/// Create linalg op on buffers given the original tensor-based operation and -/// the buffers for the outputs. 
-LinalgOp -mlir::linalg::createLinalgOpOnBuffers(ConversionPatternRewriter &rewriter, - LinalgOp linalgOp, ValueRange inputs, - ValueRange outputs) { - SmallVector<Value, 8> newOperands = inputs; - newOperands.append(outputs.begin(), outputs.end()); - auto *newOp = linalgOp.cloneWithoutRegions(rewriter, linalgOp.getLoc(), - /*resultTypes=*/ArrayRef<Type>{}, - newOperands); - for (auto regions : llvm::zip(linalgOp->getRegions(), newOp->getRegions())) { - auto &oldRegion = std::get<0>(regions); - auto &newRegion = std::get<1>(regions); - rewriter.inlineRegionBefore(oldRegion, newRegion, newRegion.begin()); - } - return newOp; -} - -//===----------------------------------------------------------------------===// -// Bufferization patterns. -//===----------------------------------------------------------------------===// - -namespace { - -/// Conversion pattern that replaces `linalg.init_tensor` with allocation. -class BufferizeInitTensorOp : public OpConversionPattern<InitTensorOp> { -public: - using OpConversionPattern<InitTensorOp>::OpConversionPattern; - - LogicalResult - matchAndRewrite(InitTensorOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const final { - rewriter.replaceOpWithNewOp<memref::AllocOp>( - op, getTypeConverter()->convertType(op.getType()).cast<MemRefType>(), - adaptor.sizes()); - return success(); - } -}; - -/// Conversion pattern that bufferizes `linalg.fill` operation. 
-class BufferizeFillOp : public OpConversionPattern<FillOp> { -public: - using OpConversionPattern<FillOp>::OpConversionPattern; - - LogicalResult - matchAndRewrite(FillOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const final { - if (!op.output().getType().isa<TensorType>()) - return rewriter.notifyMatchFailure(op, - "operand must be of a tensor type"); - - rewriter.create<FillOp>(op.getLoc(), adaptor.value(), adaptor.output()); - rewriter.replaceOp(op, adaptor.output()); - - return success(); - } -}; - -/// Generic conversion pattern that matches any LinalgOp. This avoids template -/// instantiating one pattern for each LinalgOp. -class BufferizeAnyLinalgOp : public OpInterfaceConversionPattern<LinalgOp> { -public: - using OpInterfaceConversionPattern<LinalgOp>::OpInterfaceConversionPattern; - - LogicalResult - matchAndRewrite(LinalgOp op, ArrayRef<Value> operands, - ConversionPatternRewriter &rewriter) const final { - // GenericOpAdaptor below expects an `operand_segment_sizes` attribute. - if (!op->hasAttr("operand_segment_sizes")) - return failure(); - - // We abuse the GenericOpAdaptor here. - // TODO: Manually create an Adaptor that captures inputs and outputs for all - // linalg::LinalgOp interface ops. - linalg::GenericOpAdaptor adaptor(operands, op->getAttrDictionary()); - - Location loc = op.getLoc(); - SmallVector<Value, 2> newOutputBuffers; - - if (failed(allocateBuffersForResults(loc, op, adaptor.outputs(), - newOutputBuffers, rewriter))) { - return op.emitOpError() - << "Failed to allocate buffers for tensor results."; - } - createLinalgOpOnBuffers(rewriter, op, adaptor.inputs(), newOutputBuffers); - // Replace the results of the old op with the new output buffers. - rewriter.replaceOp(op, newOutputBuffers); - return success(); - } -}; -} // namespace +using namespace mlir; +using namespace bufferization; namespace { /// Converts Linalg operations that work on tensor-type operands or results to /// work on buffers. 
struct LinalgBufferizePass : public LinalgBufferizeBase<LinalgBufferizePass> { void runOnOperation() override { - MLIRContext &context = getContext(); - ConversionTarget target(context); - bufferization::BufferizeTypeConverter typeConverter; - - // Mark certain operations legal. - target.addLegalDialect<arith::ArithmeticDialect, AffineDialect, - memref::MemRefDialect, tensor::TensorDialect>(); - target.addIllegalOp<InitTensorOp>(); + BufferizationOptions options = getPartialBufferizationOptions(); + options.allowDialectInFilter<linalg::LinalgDialect>(); - // Mark all Linalg operations illegal as long as they work on tensors. - auto isLegalOperation = [&](Operation *op) { - return typeConverter.isLegal(op); - }; - target.addDynamicallyLegalDialect<linalg::LinalgDialect>(isLegalOperation); - - RewritePatternSet patterns(&context); - populateLinalgBufferizePatterns(typeConverter, patterns); - if (failed(applyPartialConversion(getOperation(), target, - std::move(patterns)))) + if (failed(bufferizeOp(getOperation(), options))) signalPassFailure(); } + + void getDependentDialects(DialectRegistry &registry) const override { + registry.insert<bufferization::BufferizationDialect, memref::MemRefDialect, + tensor::TensorDialect, linalg::LinalgDialect>(); + linalg::registerBufferizableOpInterfaceExternalModels(registry); + } }; } // namespace std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgBufferizePass() { return std::make_unique<LinalgBufferizePass>(); } - -void mlir::linalg::populateLinalgBufferizePatterns( - bufferization::BufferizeTypeConverter &typeConverter, - RewritePatternSet &patterns) { - // TODO: Drop this once tensor constants work in standard. 
- // clang-format off - patterns.add< - BufferizeAnyLinalgOp, - BufferizeFillOp, - BufferizeInitTensorOp - >(typeConverter, patterns.getContext()); - // clang-format on -} diff --git a/mlir/test/Dialect/Linalg/bufferize.mlir b/mlir/test/Dialect/Linalg/bufferize.mlir index 2edc104..e6d4f92 100644 --- a/mlir/test/Dialect/Linalg/bufferize.mlir +++ b/mlir/test/Dialect/Linalg/bufferize.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt -linalg-bufferize -canonicalize -cse -split-input-file %s | FileCheck %s +// RUN: mlir-opt -linalg-bufferize -canonicalize -cse -split-input-file %s | FileCheck %s #map0 = affine_map<(d0) -> (d0)> @@ -12,8 +12,8 @@ // CHECK: #map = affine_map<(d0) -> (d0)> // CHECK-LABEL: func @basic( // CHECK-SAME: %[[TENSOR:.*]]: tensor<4xf32>) -> tensor<4xf32> { -// CHECK: %[[MEMREF:.*]] = bufferization.to_memref %[[TENSOR]] : memref<4xf32> -// CHECK: %[[RESULT_MEMREF:.*]] = memref.alloc() : memref<4xf32> +// CHECK-DAG: %[[MEMREF:.*]] = bufferization.to_memref %[[TENSOR]] : memref<4xf32> +// CHECK-DAG: %[[RESULT_MEMREF:.*]] = memref.alloc() {{.*}} : memref<4xf32> // CHECK: linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} // CHECK-SAME: ins(%[[MEMREF]] : memref<4xf32>) // CHECK-SAME: outs(%[[RESULT_MEMREF]] : memref<4xf32>) { @@ -46,8 +46,8 @@ func @basic(%arg0: tensor<4xf32>) -> tensor<4xf32> { // CHECK: #map = affine_map<(d0) -> (d0)> // CHECK-LABEL: func @init_tensor( // CHECK-SAME: %[[IN:.*]]: tensor<?xf32>, %[[SIZE:.*]]: index) -// CHECK: %[[MEMREF:.*]] = bufferization.to_memref %[[IN]] : memref<?xf32> -// CHECK: %[[OUT_BUF:.*]] = memref.alloc(%[[SIZE]]) : memref<?xf32> +// CHECK-DAG: %[[MEMREF:.*]] = bufferization.to_memref %[[IN]] : memref<?xf32> +// CHECK-DAG: %[[OUT_BUF:.*]] = memref.alloc(%[[SIZE]]) {{.*}} : memref<?xf32> // CHECK: linalg.generic // CHECK-SAME: ins(%[[MEMREF]] : memref<?xf32>) // CHECK-SAME: outs(%[[OUT_BUF]] : memref<?xf32>) { @@ -71,8 +71,8 @@ func @init_tensor(%in : tensor<?xf32>, %size: index) -> 
tensor<?xf32> { #map0 = affine_map<(d0) -> (d0)> // CHECK-LABEL: func @multiple_results -// CHECK: %[[RESULT0:.*]] = memref.alloc() : memref<4xf32> -// CHECK: %[[RESULT1:.*]] = memref.alloc() : memref<4xf32> +// CHECK: %[[RESULT1:.*]] = memref.alloc() {{.*}} : memref<4xf32> +// CHECK: %[[RESULT0:.*]] = memref.alloc() {{.*}} : memref<4xf32> // CHECK: linalg.generic // CHECK-SAME: ins(%{{.*}} : memref<4xf32>) // CHECK-SAME: outs(%[[RESULT0]], %[[RESULT1]] : memref<4xf32>, memref<4xf32>) @@ -101,11 +101,11 @@ func @multiple_results(%arg0: tensor<4xf32>) -> (tensor<4xf32>, tensor<4xf32>) { // CHECK-SAME: %[[ARG:.*]]: tensor<?x?xf32> // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index -// CHECK: %[[MEMREF_ARG:.*]] = bufferization.to_memref %[[ARG]] : memref<?x?xf32> // CHECK: %[[DIM0:.*]] = tensor.dim %[[ARG]], %[[C0]] : tensor<?x?xf32> // CHECK: %[[DIM1:.*]] = tensor.dim %[[ARG]], %[[C1]] : tensor<?x?xf32> -// CHECK: %[[RESULT0:.*]] = memref.alloc(%[[DIM0]], %[[DIM1]]) : memref<?x?xf32> -// CHECK: %[[RESULT1:.*]] = memref.alloc(%[[DIM0]], %[[DIM1]]) : memref<?x?xf32> +// CHECK: %[[RESULT1:.*]] = memref.alloc(%[[DIM0]], %[[DIM1]]) {{.*}} : memref<?x?xf32> +// CHECK: %[[RESULT0:.*]] = memref.alloc(%[[DIM0]], %[[DIM1]]) {{.*}} : memref<?x?xf32> +// CHECK: %[[MEMREF_ARG:.*]] = bufferization.to_memref %[[ARG]] : memref<?x?xf32> // CHECK: linalg.generic // CHECK-SAME: ins(%[[MEMREF_ARG]] : memref<?x?xf32>) // CHECK-SAME: outs(%[[RESULT0]], %[[RESULT1]] : memref<?x?xf32>, memref<?x?xf32>) @@ -140,9 +140,9 @@ func @dynamic_results(%arg0: tensor<?x?xf32>) // CHECK-LABEL: func @generic_with_init_tensor( // CHECK-SAME: %[[ARG0_TENSOR:.*]]: tensor<2x3x4xvector<3x4xi4>>, // CHECK-SAME: %[[ARG1_TENSOR:.*]]: tensor<3x2xf32>) -> tensor<3x2xf32> { +// CHECK: %[[INIT_BUFFER:.*]] = memref.alloc() {{.*}} : memref<3x2xf32> // CHECK-DAG: %[[ARG0_MEMREF:.*]] = bufferization.to_memref %[[ARG0_TENSOR]] : memref<2x3x4xvector<3x4xi4>> // 
CHECK-DAG: %[[ARG1_MEMREF:.*]] = bufferization.to_memref %[[ARG1_TENSOR]] : memref<3x2xf32> -// CHECK: %[[INIT_BUFFER:.*]] = memref.alloc() : memref<3x2xf32> // CHECK: memref.copy %[[ARG1_MEMREF]], %[[INIT_BUFFER]] : memref<3x2xf32> to memref<3x2xf32> // CHECK: linalg.generic // CHECK-SAME: ins(%[[ARG0_MEMREF]] : memref<2x3x4xvector<3x4xi4>>) @@ -166,9 +166,9 @@ func @generic_with_init_tensor(%arg0: tensor<2x3x4xvector<3x4xi4>>, // CHECK-SAME: %[[IN:.*]]: tensor<?xf32> func @bufferize_fill(%arg0: tensor<?xf32>) -> tensor<?xf32> { %c0 = arith.constant 0.0 : f32 - // CHECK: %[[MEMREF:.*]] = bufferization.to_memref %[[IN]] : memref<?xf32> - // CHECK: linalg.fill(%cst, %[[MEMREF]]) : f32, memref<?xf32> - // CHECK: %[[TENSOR:.*]] = bufferization.to_tensor %[[MEMREF]] : memref<?xf32> + // CHECK: %[[ALLOC:.*]] = memref.alloc + // CHECK: linalg.fill(%cst, %[[ALLOC]]) : f32, memref<?xf32> + // CHECK: %[[TENSOR:.*]] = bufferization.to_tensor %[[ALLOC]] : memref<?xf32> // CHECK: return %[[TENSOR]] %0 = linalg.fill(%c0, %arg0) : f32, tensor<?xf32> -> tensor<?xf32> return %0 : tensor<?xf32> @@ -179,10 +179,13 @@ func @bufferize_fill(%arg0: tensor<?xf32>) -> tensor<?xf32> { // CHECK-LABEL: func @bufferize_dot func @bufferize_dot(%in: tensor<4xf32>, %out: tensor<f32>) -> tensor<f32> { %dot = linalg.dot ins(%in, %in : tensor<4xf32>, tensor<4xf32>) - outs(%out : tensor<f32>) -> tensor<f32> + outs(%out : tensor<f32>) -> tensor<f32> return %dot : tensor<f32> + // CHECK: %[[ALLOC:.*]] = memref.alloc + // TODO: The copy is not necessary. + // CHECK: memref.copy {{.*}}, %[[ALLOC]] // CHECK: linalg.dot ins(%{{.*}}, %{{.*}} : memref<4xf32>, memref<4xf32>) - // CHECK-SAME: outs(%[[OUT:.*]] : memref<f32>) - // CHECK: %[[OUT_TENSOR:.*]] = bufferization.to_tensor %[[OUT]] : memref<f32> + // CHECK-SAME: outs(%[[ALLOC:.*]] : memref<f32>) + // CHECK: %[[OUT_TENSOR:.*]] = bufferization.to_tensor %[[ALLOC]] : memref<f32> // CHECK: return %[[OUT_TENSOR]] } |
