diff options
author | Matthias Gehre <matthias.gehre@amd.com> | 2025-01-20 13:42:18 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2025-01-20 13:42:18 +0100 |
commit | 5ce271ef74dd3325993c827f496e460ced41af11 (patch) | |
tree | 31c77f466ec6577bdd67a486f242737be6234003 | |
parent | d70f54f248853f4d5f9e71a51dfda53a47f0b7d3 (diff) | |
download | llvm-5ce271ef74dd3325993c827f496e460ced41af11.zip llvm-5ce271ef74dd3325993c827f496e460ced41af11.tar.gz llvm-5ce271ef74dd3325993c827f496e460ced41af11.tar.bz2 |
[MLIR] TosaToLinalgNamed: Lower unsigned tosa.max_pool2d (#123290)
This PR allows lowering **unsigned** `tosa.max_pool2d` to linalg.
```
// CHECK-LABEL: @max_pool_ui8
func.func @max_pool_ui8(%arg0: tensor<1x6x34x62xui8>) -> tensor<1x4x32x62xui8> {
// CHECK: builtin.unrealized_conversion_cast {{.*}} : tensor<1x6x34x62xui8> to tensor<1x6x34x62xi8>
// CHECK: arith.constant 0
// CHECK: linalg.pooling_nhwc_max_unsigned {{.*}} : (tensor<1x4x32x62xi8>) -> tensor<1x4x32x62xi8>
// CHECK: builtin.unrealized_conversion_cast {{.*}} : tensor<1x4x32x62xi8> to tensor<1x4x32x62xui8>
%0 = tosa.max_pool2d %arg0 {pad = array<i64: 0, 0, 0, 0>, kernel = array<i64: 3, 3>, stride = array<i64: 1, 1>} : (tensor<1x6x34x62xui8>) -> tensor<1x4x32x62xui8>
return %0 : tensor<1x4x32x62xui8>
}
```
It does this by
- converting the MaxPool2dConverter from OpRewritePattern to
OpConversionPattern
- adjusting the padding value to the minimum unsigned value when the
max_pool is unsigned
- lowering to `linalg.pooling_nhwc_max_unsigned` (which uses
`arith.maxui`) when the max_pool is unsigned
4 files changed, 56 insertions, 18 deletions
diff --git a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h index 1822016..a1eb22e 100644 --- a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h +++ b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h @@ -52,7 +52,8 @@ void populateTosaToLinalgConversionPatterns(const TypeConverter &converter, /// Populates conversion passes from TOSA dialect to Linalg named operations. void populateTosaToLinalgNamedConversionPatterns( - RewritePatternSet *patterns, const TosaToLinalgNamedOptions &options); + const TypeConverter &converter, RewritePatternSet *patterns, + const TosaToLinalgNamedOptions &options); } // namespace tosa } // namespace mlir diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp index d537aef..b7af37d 100644 --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp @@ -695,17 +695,18 @@ public: } }; -class MaxPool2dConverter : public OpRewritePattern<tosa::MaxPool2dOp> { +class MaxPool2dConverter : public OpConversionPattern<tosa::MaxPool2dOp> { public: - using OpRewritePattern<tosa::MaxPool2dOp>::OpRewritePattern; + using OpConversionPattern::OpConversionPattern; // Compute the dynamic output sizes of the maxpool operation. 
static SmallVector<Value> - computeDynamicOutputSizes(tosa::MaxPool2dOp op, PatternRewriter &rewriter) { + computeDynamicOutputSizes(tosa::MaxPool2dOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) { TensorType resultTy = op.getType(); Location loc = op.getLoc(); - TypedValue<TensorType> input = op.getInput(); + Value input = adaptor.getInput(); ArrayRef<int64_t> kernel = op.getKernel(); ArrayRef<int64_t> pad = op.getPad(); ArrayRef<int64_t> stride = op.getStride(); @@ -744,16 +745,22 @@ public: return dynamicDims; } - LogicalResult matchAndRewrite(tosa::MaxPool2dOp op, - PatternRewriter &rewriter) const final { + LogicalResult + matchAndRewrite(tosa::MaxPool2dOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const final { Location loc = op.getLoc(); - TypedValue<TensorType> input = op.getInput(); - ShapedType inputTy = input.getType(); + Value input = adaptor.getInput(); + ShapedType inputTy = cast<ShapedType>(input.getType()); - ShapedType resultTy = op.getType(); + bool isUnsigned = op.getType().getElementType().isUnsignedInteger(); + ShapedType resultTy = + cast<ShapedType>(getTypeConverter()->convertType(op.getType())); + if (!resultTy) + return rewriter.notifyMatchFailure(op, "failed to convert type"); Type resultETy = inputTy.getElementType(); - SmallVector<Value> dynamicDims = computeDynamicOutputSizes(op, rewriter); + SmallVector<Value> dynamicDims = + computeDynamicOutputSizes(op, adaptor, rewriter); // Determine what the initial value needs to be for the max pool op. 
TypedAttr initialAttr; @@ -762,7 +769,10 @@ public: resultETy, APFloat::getLargest( cast<FloatType>(resultETy).getFloatSemantics(), true)); - if (isa<IntegerType>(resultETy)) + else if (isUnsigned) + initialAttr = rewriter.getIntegerAttr( + resultETy, APInt::getZero(resultETy.getIntOrFloatBitWidth())); + else if (isa<IntegerType>(resultETy)) initialAttr = rewriter.getIntegerAttr( resultETy, APInt::getSignedMinValue(resultETy.getIntOrFloatBitWidth())); @@ -798,9 +808,15 @@ public: Value fakeWindowDims = rewriter.create<tensor::EmptyOp>(loc, kernel, resultETy); - rewriter.replaceOpWithNewOp<linalg::PoolingNhwcMaxOp>( - op, ArrayRef<Type>{resultTy}, ValueRange{paddedInput, fakeWindowDims}, - filledEmptyTensor, strideAttr, dilationAttr); + if (isUnsigned) { + rewriter.replaceOpWithNewOp<linalg::PoolingNhwcMaxUnsignedOp>( + op, ArrayRef<Type>{resultTy}, ValueRange{paddedInput, fakeWindowDims}, + filledEmptyTensor, strideAttr, dilationAttr); + } else { + rewriter.replaceOpWithNewOp<linalg::PoolingNhwcMaxOp>( + op, ArrayRef<Type>{resultTy}, ValueRange{paddedInput, fakeWindowDims}, + filledEmptyTensor, strideAttr, dilationAttr); + } return success(); } }; @@ -1070,7 +1086,8 @@ public: } // namespace void mlir::tosa::populateTosaToLinalgNamedConversionPatterns( - RewritePatternSet *patterns, const TosaToLinalgNamedOptions &options) { + const TypeConverter &converter, RewritePatternSet *patterns, + const TosaToLinalgNamedOptions &options) { if (options.preferConv2DKernelLayoutHWCF) { patterns->add<ConvConverter<tosa::Conv2DOp, linalg::Conv2DNhwcHwcfOp, linalg::Conv2DNhwcHwcfQOp>>( @@ -1085,10 +1102,13 @@ void mlir::tosa::populateTosaToLinalgNamedConversionPatterns( ConvConverter<tosa::Conv3DOp, linalg::Conv3DNdhwcDhwcfOp, linalg::Conv3DNdhwcDhwcfQOp>, DepthwiseConvConverter, MatMulConverter, - MaxPool2dConverter, AvgPool2dConverter, FullyConnectedConverter, TransposeConverter >(patterns->getContext()); + + patterns->add< + MaxPool2dConverter + >(converter, 
patterns->getContext()); // clang-format on } diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp index 0969693..7d943b3 100644 --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp @@ -47,6 +47,9 @@ public: } void runOnOperation() override { + TypeConverter converter; + tosa::populateTosaTypeConversion(converter); + RewritePatternSet patterns(&getContext()); ConversionTarget target(getContext()); target.addLegalDialect<linalg::LinalgDialect, tosa::TosaDialect, @@ -67,7 +70,8 @@ public: FunctionOpInterface func = getOperation(); TosaToLinalgNamedOptions options; options.preferConv2DKernelLayoutHWCF = preferConv2DKernelLayoutHWCF; - tosa::populateTosaToLinalgNamedConversionPatterns(&patterns, options); + tosa::populateTosaToLinalgNamedConversionPatterns(converter, &patterns, + options); if (failed(applyFullConversion(func, target, std::move(patterns)))) signalPassFailure(); } diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir index 453a861..5eeaebb 100644 --- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir @@ -200,6 +200,19 @@ func.func @max_pool_i8(%arg0: tensor<1x6x34x62xi8>) -> () { return } +// CHECK-LABEL: @max_pool_ui8 +func.func @max_pool_ui8(%arg0: tensor<1x6x34x62xui8>) -> tensor<1x4x32x62xui8> { + // CHECK: builtin.unrealized_conversion_cast {{.*}} : tensor<1x6x34x62xui8> to tensor<1x6x34x62xi8> + // CHECK: arith.constant 0 + // CHECK: linalg.pooling_nhwc_max_unsigned + // CHECK-SAME: ins({{.*}} : tensor<1x6x34x62xi8>, tensor<3x3xi8>) + // CHECK-SAME: outs({{.*}} : tensor<1x4x32x62xi8>) + // CHECK-SAME: -> tensor<1x4x32x62xi8> + // CHECK: builtin.unrealized_conversion_cast {{.*}} : tensor<1x4x32x62xi8> to tensor<1x4x32x62xui8> + %0 = 
tosa.max_pool2d %arg0 {pad = array<i64: 0, 0, 0, 0>, kernel = array<i64: 3, 3>, stride = array<i64: 1, 1>} : (tensor<1x6x34x62xui8>) -> tensor<1x4x32x62xui8> + return %0 : tensor<1x4x32x62xui8> +} + // CHECK-LABEL: @max_pool_i16 func.func @max_pool_i16(%arg0: tensor<1x6x34x62xi16>) -> () { // CHECK: arith.constant -32768 |