aboutsummaryrefslogtreecommitdiff
path: root/mlir
diff options
context:
space:
mode:
authorGeorgios Pinitas <georgios.pinitas@arm.com>2023-12-06 12:48:44 +0000
committerGitHub <noreply@github.com>2023-12-06 12:48:44 +0000
commit3a772c3bfeb9d7641b2914672a55fe5838d748db (patch)
tree3d3f35b8a4c4f1a4c0fdcfc34b82d103d0f43436 /mlir
parenta9673bd1ca217e46800f3c2b705c1bed01fdc457 (diff)
downloadllvm-3a772c3bfeb9d7641b2914672a55fe5838d748db.zip
llvm-3a772c3bfeb9d7641b2914672a55fe5838d748db.tar.gz
llvm-3a772c3bfeb9d7641b2914672a55fe5838d748db.tar.bz2
[mlir][tosa] Add fp16 support to `tosa.resize` (#73019)
Diffstat (limited to 'mlir')
-rw-r--r--mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp23
-rw-r--r--mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-resize.mlir30
2 files changed, 36 insertions, 17 deletions
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index ca37bd2..beed71d 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -1502,6 +1502,9 @@ public:
auto resultTy = cast<ShapedType>(op.getType());
auto resultETy = resultTy.getElementType();
+ bool floatingPointMode = resultETy.isF16() || resultETy.isF32();
+ auto floatTy = resultETy.isF16() ? b.getF16Type() : b.getF32Type();
+
auto imageH = inputTy.getShape()[1];
auto imageW = inputTy.getShape()[2];
@@ -1535,16 +1538,13 @@ public:
Value zeroI32 =
b.create<arith::ConstantOp>(b.getZeroAttr(b.getI32Type()));
- Value zeroFp32 =
- b.create<arith::ConstantOp>(b.getZeroAttr(b.getF32Type()));
+ Value zeroFp = b.create<arith::ConstantOp>(b.getZeroAttr(floatTy));
Value hMax = b.create<arith::ConstantOp>(b.getI32IntegerAttr(imageH - 1));
Value wMax = b.create<arith::ConstantOp>(b.getI32IntegerAttr(imageW - 1));
Value inY = b.create<arith::IndexCastOp>(b.getI32Type(), y);
Value inX = b.create<arith::IndexCastOp>(b.getI32Type(), x);
- bool floatingPointMode = resultETy.isF32();
-
ArrayRef<int64_t> offset = op.getOffset();
ArrayRef<int64_t> border = op.getBorder();
ArrayRef<int64_t> scale = op.getScale();
@@ -1567,16 +1567,16 @@ public:
int size, ImplicitLocOpBuilder &b) {
if (size == 1) {
index = zeroI32;
- delta = zeroFp32;
+ delta = zeroFp;
return;
}
// x = x * scale_d + offset;
// ix = floor(x / scale_n)
// dx = x / scale_n - ix
- Value val = b.create<arith::UIToFPOp>(b.getF32Type(), in);
- scaleN = b.create<arith::UIToFPOp>(b.getF32Type(), scaleN);
- scaleD = b.create<arith::UIToFPOp>(b.getF32Type(), scaleD);
- offset = b.create<arith::SIToFPOp>(b.getF32Type(), offset);
+ Value val = b.create<arith::UIToFPOp>(floatTy, in);
+ scaleN = b.create<arith::UIToFPOp>(floatTy, scaleN);
+ scaleD = b.create<arith::UIToFPOp>(floatTy, scaleD);
+ offset = b.create<arith::SIToFPOp>(floatTy, offset);
val = b.create<arith::MulFOp>(val, scaleD);
val = b.create<arith::AddFOp>(val, offset);
val = b.create<arith::DivFOp>(val, scaleN);
@@ -1625,7 +1625,7 @@ public:
Value pred;
if (floatingPointMode) {
- auto h = b.create<arith::ConstantOp>(b.getF32FloatAttr(0.5f));
+ auto h = b.create<arith::ConstantOp>(b.getFloatAttr(floatTy, 0.5f));
pred = b.create<arith::CmpFOp>(arith::CmpFPredicate::OGE, dval, h);
} else {
Value dvalDouble = b.create<arith::ShLIOp>(dval, one);
@@ -1681,7 +1681,8 @@ public:
input, ValueRange{batch, y1, x1, channel});
if (floatingPointMode) {
- auto oneVal = b.create<arith::ConstantOp>(b.getF32FloatAttr(1.0f));
+ auto oneVal =
+ b.create<arith::ConstantOp>(b.getFloatAttr(floatTy, 1.0f));
auto interpolate = [&](Value val0, Value val1, Value delta,
int inputSize,
ImplicitLocOpBuilder &b) -> Value {
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-resize.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-resize.mlir
index e7db61a..aedc6b7 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-resize.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-resize.mlir
@@ -1,7 +1,7 @@
// RUN: mlir-opt --split-input-file -pass-pipeline="builtin.module(func.func(tosa-to-linalg))" %s -o -| FileCheck %s
-// CHECK-LABEL: @unary_resize_nearest_fp
-func.func @unary_resize_nearest_fp(%arg0 : tensor<3x1x1x7xf32>) -> tensor<3x1x1x7xf32> {
+// CHECK-LABEL: @unary_resize_nearest_fp32
+func.func @unary_resize_nearest_fp32(%arg0 : tensor<3x1x1x7xf32>) -> tensor<3x1x1x7xf32> {
%resize = "tosa.resize"(%arg0) {mode = "NEAREST_NEIGHBOR", scale = array<i64: 2, 2, 1, 1>, offset = array<i64: 0, 0>, border = array<i64: 0, 0>} : (tensor<3x1x1x7xf32>) -> tensor<3x1x1x7xf32>
// CHECK: return %arg0
return %resize : tensor<3x1x1x7xf32>
@@ -9,8 +9,17 @@ func.func @unary_resize_nearest_fp(%arg0 : tensor<3x1x1x7xf32>) -> tensor<3x1x1x
// -----
-// CHECK-LABEL: @unary_resize_bilinear_fp
-func.func @unary_resize_bilinear_fp(%arg0 : tensor<3x1x1x7xf32>) -> tensor<3x1x1x7xf32> {
+// CHECK-LABEL: @unary_resize_nearest_fp16
+func.func @unary_resize_nearest_fp16(%arg0 : tensor<3x1x1x7xf16>) -> tensor<3x1x1x7xf16> {
+ %resize = "tosa.resize"(%arg0) {mode = "NEAREST_NEIGHBOR", scale = array<i64: 2, 2, 1, 1>, offset = array<i64: 0, 0>, border = array<i64: 0, 0>} : (tensor<3x1x1x7xf16>) -> tensor<3x1x1x7xf16>
+ // CHECK: return %arg0
+ return %resize : tensor<3x1x1x7xf16>
+}
+
+// -----
+
+// CHECK-LABEL: @unary_resize_bilinear_fp32
+func.func @unary_resize_bilinear_fp32(%arg0 : tensor<3x1x1x7xf32>) -> tensor<3x1x1x7xf32> {
%resize = "tosa.resize"(%arg0) {mode = "BILINEAR", scale = array<i64: 2, 2, 1, 1>, offset = array<i64: 0, 0>, border = array<i64: 0, 0>} : (tensor<3x1x1x7xf32>) -> tensor<3x1x1x7xf32>
// CHECK: return %arg0
return %resize : tensor<3x1x1x7xf32>
@@ -18,6 +27,15 @@ func.func @unary_resize_bilinear_fp(%arg0 : tensor<3x1x1x7xf32>) -> tensor<3x1x1
// -----
+// CHECK-LABEL: @unary_resize_bilinear_fp16
+func.func @unary_resize_bilinear_fp16(%arg0 : tensor<3x1x1x7xf16>) -> tensor<3x1x1x7xf16> {
+ %resize = "tosa.resize"(%arg0) {mode = "BILINEAR", scale = array<i64: 2, 2, 1, 1>, offset = array<i64: 0, 0>, border = array<i64: 0, 0>} : (tensor<3x1x1x7xf16>) -> tensor<3x1x1x7xf16>
+ // CHECK: return %arg0
+ return %resize : tensor<3x1x1x7xf16>
+}
+
+// -----
+
// CHECK-LABEL: @unary_resize_nearest_i8
func.func @unary_resize_nearest_i8(%arg0 : tensor<3x1x1x7xi8>) -> tensor<3x1x1x7xi8> {
%resize = "tosa.resize"(%arg0) {mode = "NEAREST_NEIGHBOR", scale = array<i64: 2, 1, 3, 1>, offset = array<i64: 0, 0>, border = array<i64: 0, 0>} : (tensor<3x1x1x7xi8>) -> tensor<3x1x1x7xi8>
@@ -285,8 +303,8 @@ func.func @resize_bilinear_int(%arg0: tensor<1x19x20x1xi8>) {
// -----
-// CHECK-LABEL: @resize_nearest_fp
-func.func @resize_nearest_fp(%input: tensor<1x50x48x1xf32>) -> () {
+// CHECK-LABEL: @resize_nearest_fp32
+func.func @resize_nearest_fp32(%input: tensor<1x50x48x1xf32>) -> () {
// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x1600x1536x1xf32>
// CHECK: %[[GENERIC:.+]] = linalg.generic
// CHECK: %[[IDX0:.+]] = linalg.index 0