diff options
author | Georgios Pinitas <georgios.pinitas@arm.com> | 2023-12-06 12:48:44 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-12-06 12:48:44 +0000 |
commit | 3a772c3bfeb9d7641b2914672a55fe5838d748db (patch) | |
tree | 3d3f35b8a4c4f1a4c0fdcfc34b82d103d0f43436 /mlir | |
parent | a9673bd1ca217e46800f3c2b705c1bed01fdc457 (diff) | |
download | llvm-3a772c3bfeb9d7641b2914672a55fe5838d748db.zip llvm-3a772c3bfeb9d7641b2914672a55fe5838d748db.tar.gz llvm-3a772c3bfeb9d7641b2914672a55fe5838d748db.tar.bz2 |
[mlir][tosa] Add fp16 support to `tosa.resize` (#73019)
Diffstat (limited to 'mlir')
-rw-r--r-- | mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp | 23 | ||||
-rw-r--r-- | mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-resize.mlir | 30 |
2 files changed, 36 insertions, 17 deletions
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp index ca37bd2..beed71d 100644 --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -1502,6 +1502,9 @@ public: auto resultTy = cast<ShapedType>(op.getType()); auto resultETy = resultTy.getElementType(); + bool floatingPointMode = resultETy.isF16() || resultETy.isF32(); + auto floatTy = resultETy.isF16() ? b.getF16Type() : b.getF32Type(); + auto imageH = inputTy.getShape()[1]; auto imageW = inputTy.getShape()[2]; @@ -1535,16 +1538,13 @@ public: Value zeroI32 = b.create<arith::ConstantOp>(b.getZeroAttr(b.getI32Type())); - Value zeroFp32 = - b.create<arith::ConstantOp>(b.getZeroAttr(b.getF32Type())); + Value zeroFp = b.create<arith::ConstantOp>(b.getZeroAttr(floatTy)); Value hMax = b.create<arith::ConstantOp>(b.getI32IntegerAttr(imageH - 1)); Value wMax = b.create<arith::ConstantOp>(b.getI32IntegerAttr(imageW - 1)); Value inY = b.create<arith::IndexCastOp>(b.getI32Type(), y); Value inX = b.create<arith::IndexCastOp>(b.getI32Type(), x); - bool floatingPointMode = resultETy.isF32(); - ArrayRef<int64_t> offset = op.getOffset(); ArrayRef<int64_t> border = op.getBorder(); ArrayRef<int64_t> scale = op.getScale(); @@ -1567,16 +1567,16 @@ public: int size, ImplicitLocOpBuilder &b) { if (size == 1) { index = zeroI32; - delta = zeroFp32; + delta = zeroFp; return; } // x = x * scale_d + offset; // ix = floor(x / scale_n) // dx = x / scale_n - ix - Value val = b.create<arith::UIToFPOp>(b.getF32Type(), in); - scaleN = b.create<arith::UIToFPOp>(b.getF32Type(), scaleN); - scaleD = b.create<arith::UIToFPOp>(b.getF32Type(), scaleD); - offset = b.create<arith::SIToFPOp>(b.getF32Type(), offset); + Value val = b.create<arith::UIToFPOp>(floatTy, in); + scaleN = b.create<arith::UIToFPOp>(floatTy, scaleN); + scaleD = b.create<arith::UIToFPOp>(floatTy, scaleD); + offset = b.create<arith::SIToFPOp>(floatTy, offset); val = b.create<arith::MulFOp>(val, scaleD); val = b.create<arith::AddFOp>(val, offset); val = b.create<arith::DivFOp>(val, scaleN); @@ -1625,7 +1625,7 @@ public: Value pred; if (floatingPointMode) { - auto h = b.create<arith::ConstantOp>(b.getF32FloatAttr(0.5f)); + auto h = b.create<arith::ConstantOp>(b.getFloatAttr(floatTy, 0.5f)); pred = b.create<arith::CmpFOp>(arith::CmpFPredicate::OGE, dval, h); } else { Value dvalDouble = b.create<arith::ShLIOp>(dval, one); @@ -1681,7 +1681,8 @@ public: input, ValueRange{batch, y1, x1, channel}); if (floatingPointMode) { - auto oneVal = b.create<arith::ConstantOp>(b.getF32FloatAttr(1.0f)); + auto oneVal = + b.create<arith::ConstantOp>(b.getFloatAttr(floatTy, 1.0f)); auto interpolate = [&](Value val0, Value val1, Value delta, int inputSize, ImplicitLocOpBuilder &b) -> Value { diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-resize.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-resize.mlir index e7db61a..aedc6b7 100644 --- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-resize.mlir +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-resize.mlir @@ -1,7 +1,7 @@ // RUN: mlir-opt --split-input-file -pass-pipeline="builtin.module(func.func(tosa-to-linalg))" %s -o -| FileCheck %s -// CHECK-LABEL: @unary_resize_nearest_fp -func.func @unary_resize_nearest_fp(%arg0 : tensor<3x1x1x7xf32>) -> tensor<3x1x1x7xf32> { +// CHECK-LABEL: @unary_resize_nearest_fp32 +func.func @unary_resize_nearest_fp32(%arg0 : tensor<3x1x1x7xf32>) -> tensor<3x1x1x7xf32> { %resize = "tosa.resize"(%arg0) {mode = "NEAREST_NEIGHBOR", scale = array<i64: 2, 2, 1, 1>, offset = array<i64: 0, 0>, border = array<i64: 0, 0>} : (tensor<3x1x1x7xf32>) -> tensor<3x1x1x7xf32> // CHECK: return %arg0 return %resize : tensor<3x1x1x7xf32> @@ -9,8 +9,17 @@ func.func @unary_resize_nearest_fp(%arg0 : tensor<3x1x1x7xf32>) -> tensor<3x1x1x // ----- -// CHECK-LABEL: @unary_resize_bilinear_fp -func.func @unary_resize_bilinear_fp(%arg0 : tensor<3x1x1x7xf32>) -> tensor<3x1x1x7xf32> { +// CHECK-LABEL: @unary_resize_nearest_fp16 +func.func @unary_resize_nearest_fp16(%arg0 : tensor<3x1x1x7xf16>) -> tensor<3x1x1x7xf16> { + %resize = "tosa.resize"(%arg0) {mode = "NEAREST_NEIGHBOR", scale = array<i64: 2, 2, 1, 1>, offset = array<i64: 0, 0>, border = array<i64: 0, 0>} : (tensor<3x1x1x7xf16>) -> tensor<3x1x1x7xf16> + // CHECK: return %arg0 + return %resize : tensor<3x1x1x7xf16> +} + +// ----- + +// CHECK-LABEL: @unary_resize_bilinear_fp32 +func.func @unary_resize_bilinear_fp32(%arg0 : tensor<3x1x1x7xf32>) -> tensor<3x1x1x7xf32> { %resize = "tosa.resize"(%arg0) {mode = "BILINEAR", scale = array<i64: 2, 2, 1, 1>, offset = array<i64: 0, 0>, border = array<i64: 0, 0>} : (tensor<3x1x1x7xf32>) -> tensor<3x1x1x7xf32> // CHECK: return %arg0 return %resize : tensor<3x1x1x7xf32> @@ -18,6 +27,15 @@ func.func @unary_resize_bilinear_fp(%arg0 : tensor<3x1x1x7xf32>) -> tensor<3x1x1 // ----- +// CHECK-LABEL: @unary_resize_bilinear_fp16 +func.func @unary_resize_bilinear_fp16(%arg0 : tensor<3x1x1x7xf16>) -> tensor<3x1x1x7xf16> { + %resize = "tosa.resize"(%arg0) {mode = "BILINEAR", scale = array<i64: 2, 2, 1, 1>, offset = array<i64: 0, 0>, border = array<i64: 0, 0>} : (tensor<3x1x1x7xf16>) -> tensor<3x1x1x7xf16> + // CHECK: return %arg0 + return %resize : tensor<3x1x1x7xf16> +} + +// ----- + // CHECK-LABEL: @unary_resize_nearest_i8 func.func @unary_resize_nearest_i8(%arg0 : tensor<3x1x1x7xi8>) -> tensor<3x1x1x7xi8> { %resize = "tosa.resize"(%arg0) {mode = "NEAREST_NEIGHBOR", scale = array<i64: 2, 1, 3, 1>, offset = array<i64: 0, 0>, border = array<i64: 0, 0>} : (tensor<3x1x1x7xi8>) -> tensor<3x1x1x7xi8> @@ -285,8 +303,8 @@ func.func @resize_bilinear_int(%arg0: tensor<1x19x20x1xi8>) { // ----- -// CHECK-LABEL: @resize_nearest_fp -func.func @resize_nearest_fp(%input: tensor<1x50x48x1xf32>) -> () { +// CHECK-LABEL: @resize_nearest_fp32 +func.func @resize_nearest_fp32(%input: tensor<1x50x48x1xf32>) -> () { // CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x1600x1536x1xf32> // CHECK: %[[GENERIC:.+]] = linalg.generic // CHECK: %[[IDX0:.+]] = linalg.index 0 |