aboutsummaryrefslogtreecommitdiff
path: root/mlir/test
diff options
context:
space:
mode:
authorQuinn Dawkins <quinn.dawkins@gmail.com>2023-12-01 15:05:29 -0500
committerGitHub <noreply@github.com>2023-12-01 15:05:29 -0500
commitf310a5d2c13455f1d68f5654fa4258357bafeff6 (patch)
treefacec77faa031c5aad059607d2d9d46d899e1a8e /mlir/test
parent4c44dcffd5f1557bde2c21773221081437308895 (diff)
downloadllvm-f310a5d2c13455f1d68f5654fa4258357bafeff6.zip
llvm-f310a5d2c13455f1d68f5654fa4258357bafeff6.tar.gz
llvm-f310a5d2c13455f1d68f5654fa4258357bafeff6.tar.bz2
[mlir][tensor] Add a tensor.concat operation (#72779)
This adds an operation for concatenating ranked tensors along a static dimension, as well as a decomposition mirroring the existing lowering from TOSA to Tensor. This offers a convergence point for "input" like dialects that include various lowerings for concatenation operations, easing later analysis. In the future, this op can implement the necessary interfaces for tiling, as well as potentially add conversions to some kind of linalg and/or memref counterpart. This patch adds the op, the decomposition, and some basic folding/canonicalization. Replacing lowerings with the op (such as the TOSA lowering) will come as a follow up. See https://discourse.llvm.org/t/rfc-tensor-add-a-tensor-concatenate-operation/74858
Diffstat (limited to 'mlir/test')
-rw-r--r--mlir/test/Dialect/Tensor/canonicalize.mlir12
-rw-r--r--mlir/test/Dialect/Tensor/decompose-concat.mlir57
-rw-r--r--mlir/test/Dialect/Tensor/invalid.mlir48
-rw-r--r--mlir/test/Dialect/Tensor/ops.mlir17
4 files changed, 134 insertions, 0 deletions
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 580c1db..84c44a0 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -87,6 +87,18 @@ func.func @tensor.cast_chain_invalid(%input: tensor<4x8xi32>) -> tensor<8x4xi32>
// -----
+// CHECK-LABEL: fold_concat
+// CHECK-SAME: %[[ARG0:.*]]: tensor<1x2x?xi32>
+func.func @fold_concat(%arg0: tensor<1x2x?xi32>) -> (tensor<1x2x3xi32>, tensor<1x2x?xi32>) {
+ %0 = tensor.concat dim(2) %arg0 : (tensor<1x2x?xi32>) -> tensor<1x2x3xi32>
+ // CHECK-NEXT: %[[CAST:.*]] = tensor.cast %[[ARG0]] : tensor<1x2x?xi32> to tensor<1x2x3xi32>
+ %1 = tensor.concat dim(2) %arg0 : (tensor<1x2x?xi32>) -> tensor<1x2x?xi32>
+ // CHECK-NEXT: return %[[CAST]], %[[ARG0]] : tensor<1x2x3xi32>, tensor<1x2x?xi32>
+ return %0, %1 : tensor<1x2x3xi32>, tensor<1x2x?xi32>
+}
+
+// -----
+
// CHECK-LABEL: func @fold_extract
func.func @fold_extract(%arg0 : index) -> (f32, f16, f16, i32, complex<f32>) {
%const_0 = arith.constant 0 : index
diff --git a/mlir/test/Dialect/Tensor/decompose-concat.mlir b/mlir/test/Dialect/Tensor/decompose-concat.mlir
new file mode 100644
index 0000000..5712c77a
--- /dev/null
+++ b/mlir/test/Dialect/Tensor/decompose-concat.mlir
@@ -0,0 +1,57 @@
+// RUN: mlir-opt -split-input-file -transform-interpreter -cse %s | FileCheck %s
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%func_op: !transform.op<"func.func"> {transform.readonly}) {
+ transform.apply_patterns to %func_op {
+ transform.apply_patterns.tensor.decompose_concat
+ } : !transform.op<"func.func">
+ transform.yield
+ }
+}
+
+func.func @decompose_dynamic_concat(%arg0 : tensor<8x4xf32>, %arg1 : tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %0 = tensor.concat dim(1) %arg0, %arg1 : (tensor<8x4xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+// CHECK-DAG: #[[$MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
+// CHECK-LABEL: func @decompose_dynamic_concat(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<8x4xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+
+// CHECK-DAG: %[[C8:.+]] = arith.constant 8 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[DIM:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32>
+// CHECK: %[[CONCAT_SIZE:.+]] = affine.apply #[[$MAP]]()[%[[C8]], %[[DIM]]]
+// CHECK: %[[EMPTY:.+]] = tensor.empty(%[[C8]], %[[CONCAT_SIZE]]) : tensor<?x?xf32>
+// CHECK: %[[SLICE0:.+]] = tensor.insert_slice %[[ARG0]] into %[[EMPTY]][0, 0] [8, 4] [1, 1] : tensor<8x4xf32> into tensor<?x?xf32>
+// CHECK: %[[OFFSET:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?xf32>
+// CHECK: %[[CONCAT:.+]] = tensor.insert_slice %[[ARG1]] into %[[SLICE0]][0, 4] [%[[OFFSET]], %[[DIM]]] [1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
+// CHECK: return %[[CONCAT]] : tensor<?x?xf32>
+
+// -----
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%func_op: !transform.op<"func.func"> {transform.readonly}) {
+ transform.apply_patterns to %func_op {
+ transform.apply_patterns.tensor.decompose_concat
+ } : !transform.op<"func.func">
+ transform.yield
+ }
+}
+
+func.func @decompose_1d_concat(%arg0 : tensor<1xf32>,
+ %arg1 : tensor<2xf32>,
+ %arg2 : tensor<3xf32>,
+ %arg3: tensor<4xf32>) -> tensor<10xf32> {
+ %0 = tensor.concat dim(0) %arg0, %arg1, %arg2, %arg3
+ : (tensor<1xf32>, tensor<2xf32>, tensor<3xf32>, tensor<4xf32>) -> tensor<10xf32>
+ return %0 : tensor<10xf32>
+}
+// CHECK-LABEL: func @decompose_1d_concat
+// CHECK: tensor.empty() : tensor<10xf32>
+// CHECK: tensor.insert_slice %{{.*}}[0] [1] [1] : tensor<1xf32> into tensor<10xf32>
+// CHECK: tensor.insert_slice %{{.*}}[1] [2] [1] : tensor<2xf32> into tensor<10xf32>
+// CHECK: tensor.insert_slice %{{.*}}[3] [3] [1] : tensor<3xf32> into tensor<10xf32>
+// CHECK: %[[CONCAT:.+]] = tensor.insert_slice %{{.*}}[6] [4] [1] : tensor<4xf32> into tensor<10xf32>
+// CHECK: return %[[CONCAT]] : tensor<10xf32>
diff --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir
index 389e7e6..9b6c232 100644
--- a/mlir/test/Dialect/Tensor/invalid.mlir
+++ b/mlir/test/Dialect/Tensor/invalid.mlir
@@ -16,6 +16,54 @@ func.func @tensor.cast_mismatching_constants(%arg0: tensor<1xf32>) {
// -----
+func.func @concat_empty() {
+ // expected-error@+1 {{requires at least one input}}
+ %0 = tensor.concat dim(0) : () -> tensor<1x2x3xf32>
+ return
+}
+
+// -----
+
+func.func @concat_rank_mismatch(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) {
+ // expected-error@+1 {{rank of concatenated inputs must match result rank}}
+ %0 = tensor.concat dim(0) %arg0, %arg1 : (tensor<1xf32>, tensor<1xf32>) -> tensor<2x1xf32>
+ return
+}
+
+// -----
+
+func.func @concat_dim_out_of_range(%arg0: tensor<3xf32>) {
+ // expected-error@+1 {{concatenation dim must be less than the tensor rank}}
+ %0 = tensor.concat dim(1) %arg0 : (tensor<3xf32>) -> tensor<3xf32>
+ return
+}
+
+// -----
+
+func.func @concat_element_type_mismatch(%arg0: tensor<3xf32>, %arg1: tensor<3xi32>) {
+ // expected-error@+1 {{inputs and result element type must match}}
+ %0 = tensor.concat dim(0) %arg0, %arg1 : (tensor<3xf32>, tensor<3xi32>) -> tensor<3xf32>
+ return
+}
+
+// -----
+
+func.func @concat_incompatible_input_types(%arg0: tensor<3x4xf32>, %arg1: tensor<4x5xf32>) {
+ // expected-error@+1 {{static concatenation size mismatch along non-concatenated dimension 1}}
+ %0 = tensor.concat dim(0) %arg0, %arg1 : (tensor<3x4xf32>, tensor<4x5xf32>) -> tensor<7x5xf32>
+ return
+}
+
+// -----
+
+func.func @concat_static_shape_mismatch(%arg0: tensor<3xf32>) {
+ // expected-error@+1 {{result type 'tensor<7xf32>'does not match inferred shape 'tensor<6xf32>' static sizes}}
+ %0 = tensor.concat dim(0) %arg0, %arg0 : (tensor<3xf32>, tensor<3xf32>) -> tensor<7xf32>
+ return
+}
+
+// -----
+
func.func @extract_too_many_indices(%arg0: tensor<?xf32>) {
// expected-error@+1 {{incorrect number of indices for extract_element}}
%0 = tensor.extract %arg0[] : tensor<?xf32>
diff --git a/mlir/test/Dialect/Tensor/ops.mlir b/mlir/test/Dialect/Tensor/ops.mlir
index 71a0489..2282da3 100644
--- a/mlir/test/Dialect/Tensor/ops.mlir
+++ b/mlir/test/Dialect/Tensor/ops.mlir
@@ -15,6 +15,23 @@ func.func @cast(%arg0: tensor<*xf32>, %arg1 : tensor<4x4xf32>, %arg2: tensor<?x?
// -----
+// CHECK-LABEL: func @concat(
+func.func @concat(%arg0: tensor<4x7x3xf32>, %arg1 : tensor<4x4x3xf32>, %arg2: tensor<?x?x?xf32>) {
+ // CHECK: tensor.concat dim(0) %{{.*}} : (tensor<4x7x3xf32>) -> tensor<4x7x3xf32>
+ %0 = tensor.concat dim(0) %arg0 : (tensor<4x7x3xf32>) -> tensor<4x7x3xf32>
+ // CHECK: tensor.concat dim(1) %{{.*}} : (tensor<4x7x3xf32>, tensor<4x4x3xf32>) -> tensor<4x11x3xf32>
+ %1 = tensor.concat dim(1) %arg0, %arg1 : (tensor<4x7x3xf32>, tensor<4x4x3xf32>) -> tensor<4x11x3xf32>
+ // CHECK: tensor.concat dim(2) %{{.*}} : (tensor<4x7x3xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+ %2 = tensor.concat dim(2) %arg0, %arg2 : (tensor<4x7x3xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+ // CHECK: tensor.concat dim(1) %{{.*}} : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x10x?xf32>
+ %3 = tensor.concat dim(1) %arg2, %arg2 : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x10x?xf32>
+ // CHECK: tensor.concat dim(1) %{{.*}} : (tensor<?x?x?xf32>, tensor<4x4x3xf32>, tensor<4x7x3xf32>) -> tensor<4x?x3xf32>
+ %4 = tensor.concat dim(1) %arg2, %arg1, %arg0 : (tensor<?x?x?xf32>, tensor<4x4x3xf32>, tensor<4x7x3xf32>) -> tensor<4x?x3xf32>
+ return
+}
+
+// -----
+
// CHECK-LABEL: func @empty(
// CHECK-SAME: %[[sz:.*]]: index
func.func @empty(%sz: index) -> tensor<5x?x6xf32> {