diff options
Diffstat (limited to 'mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir')
-rw-r--r-- | mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir | 27 |
1 files changed, 19 insertions, 8 deletions
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir index 9d43f89..7b8fc24 100644 --- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir +++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir @@ -357,6 +357,17 @@ func.func @test_accepts_unranked_scalar_tensor(%arg0: tensor<1x2x2xf32>, %arg1: // ----- +// CHECK-LABEL: @test_unranked_scalar_i8_tensor +func.func @test_unranked_scalar_i8_tensor(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>, %arg2: tensor<1xi8>) -> tensor<4xi32> { + // CHECK: %[[SHIFT:.*]] = tosa.cast %arg2 : (tensor<1xi8>) -> tensor<1xi8> + %shift = tosa.cast %arg2 : (tensor<1xi8>) -> tensor<*xi8> + // CHECK: tosa.mul %arg0, %arg1, %[[SHIFT]] : (tensor<4xi32>, tensor<4xi32>, tensor<1xi8>) -> tensor<4xi32> + %0 = tosa.mul %arg0, %arg1, %shift : (tensor<4xi32>, tensor<4xi32>, tensor<*xi8>) -> tensor<4xi32> + return %0 : tensor<4xi32> +} + +// ----- + // CHECK-LABEL: @test_table_static func.func @test_table_static(%arg0 : tensor<4x5xi16>, %arg1 : tensor<513xi16>) -> () { // CHECK:tosa.table %arg0, %arg1 : (tensor<4x5xi16>, tensor<513xi16>) -> tensor<4x5xi16> @@ -1166,8 +1177,8 @@ func.func @if_test_simple(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tens %b = tosa.log %arg1 : (tensor<f32>) -> tensor<f32> // CHECK: tosa.cond_if - // CHECK: -> (tensor<f32>) - %0 = tosa.cond_if %arg2 -> (tensor<f32>) { + // CHECK: -> tensor<f32> + %0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<f32> { tosa.yield %a : tensor<f32> } else { tosa.yield %b : tensor<f32> @@ -1180,8 +1191,8 @@ func.func @if_test_simple(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tens // CHECK-LABEL: @if_test_dynamic func.func @if_test_dynamic(%arg0 : tensor<2xf32>, %arg1 : tensor<3xf32>, %arg2 : tensor<i1>) -> () { // CHECK: tosa.cond_if - // CHECK: -> (tensor<?xf32>) - %0 = tosa.cond_if %arg2 -> (tensor<?xf32>) { + // CHECK: -> tensor<?xf32> + %0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<?xf32> { tosa.yield %arg0 : tensor<2xf32> } else { tosa.yield %arg1 : tensor<3xf32> @@ -1194,8 +1205,8 @@ func.func @if_test_dynamic(%arg0 : tensor<2xf32>, %arg1 : tensor<3xf32>, %arg2 : // CHECK-LABEL: @if_test_unranked func.func @if_test_unranked(%arg0 : tensor<f32>, %arg1 : tensor<3xf32>, %arg2 : tensor<i1>) -> () { // CHECK: tosa.cond_if - // CHECK: -> (tensor<*xf32>) - %0 = tosa.cond_if %arg2 -> (tensor<*xf32>) { + // CHECK: -> tensor<*xf32> + %0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<*xf32> { tosa.yield %arg0 : tensor<f32> } else { tosa.yield %arg1 : tensor<3xf32> @@ -1208,8 +1219,8 @@ func.func @if_test_unranked(%arg0 : tensor<f32>, %arg1 : tensor<3xf32>, %arg2 : // CHECK-LABEL: @if_test_propagate func.func @if_test_propagate(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tensor<i1>) -> () { // CHECK: tosa.cond_if - // CHECK: -> (tensor<f32>) - %0 = tosa.cond_if %arg2 -> (tensor<f32>) { + // CHECK: -> tensor<f32> + %0 = tosa.cond_if %arg2 : tensor<i1> -> tensor<f32> { %1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32> tosa.yield %1 : tensor<f32> } else { |