diff options
author | Quinn Dawkins <quinn.dawkins@gmail.com> | 2023-12-01 15:05:29 -0500 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-12-01 15:05:29 -0500 |
commit | f310a5d2c13455f1d68f5654fa4258357bafeff6 (patch) | |
tree | facec77faa031c5aad059607d2d9d46d899e1a8e /mlir/test | |
parent | 4c44dcffd5f1557bde2c21773221081437308895 (diff) | |
download | llvm-f310a5d2c13455f1d68f5654fa4258357bafeff6.zip llvm-f310a5d2c13455f1d68f5654fa4258357bafeff6.tar.gz llvm-f310a5d2c13455f1d68f5654fa4258357bafeff6.tar.bz2 |
[mlir][tensor] Add a tensor.concat operation (#72779)
This adds an operation for concatenating ranked tensors along a static
dimension, as well as a decomposition mirroring the existing lowering
from TOSA to Tensor. This offers a convergence point for "input" like
dialects that include various lowerings for concatenation operations,
easing later analysis. In the future, this op can implement the
necessary interfaces for tiling, as well as potentially add conversions
to some kind of linalg and/or memref counterpart.
This patch adds the op, the decomposition, and some basic
folding/canonicalization. Replacing lowerings with the op (such as the
TOSA lowering) will come as a follow up.
See
https://discourse.llvm.org/t/rfc-tensor-add-a-tensor-concatenate-operation/74858
Diffstat (limited to 'mlir/test')
-rw-r--r-- | mlir/test/Dialect/Tensor/canonicalize.mlir | 12 | ||||
-rw-r--r-- | mlir/test/Dialect/Tensor/decompose-concat.mlir | 57 | ||||
-rw-r--r-- | mlir/test/Dialect/Tensor/invalid.mlir | 48 | ||||
-rw-r--r-- | mlir/test/Dialect/Tensor/ops.mlir | 17 |
4 files changed, 134 insertions, 0 deletions
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir index 580c1db..84c44a0 100644 --- a/mlir/test/Dialect/Tensor/canonicalize.mlir +++ b/mlir/test/Dialect/Tensor/canonicalize.mlir @@ -87,6 +87,18 @@ func.func @tensor.cast_chain_invalid(%input: tensor<4x8xi32>) -> tensor<8x4xi32> // ----- +// CHECK-LABEL: fold_concat +// CHECK-SAME: %[[ARG0:.*]]: tensor<1x2x?xi32> +func.func @fold_concat(%arg0: tensor<1x2x?xi32>) -> (tensor<1x2x3xi32>, tensor<1x2x?xi32>) { + %0 = tensor.concat dim(2) %arg0 : (tensor<1x2x?xi32>) -> tensor<1x2x3xi32> + // CHECK-NEXT: %[[CAST:.*]] = tensor.cast %[[ARG0]] : tensor<1x2x?xi32> to tensor<1x2x3xi32> + %1 = tensor.concat dim(2) %arg0 : (tensor<1x2x?xi32>) -> tensor<1x2x?xi32> + // CHECK-NEXT: return %[[CAST]], %[[ARG0]] : tensor<1x2x3xi32>, tensor<1x2x?xi32> + return %0, %1 : tensor<1x2x3xi32>, tensor<1x2x?xi32> +} + +// ----- + // CHECK-LABEL: func @fold_extract func.func @fold_extract(%arg0 : index) -> (f32, f16, f16, i32, complex<f32>) { %const_0 = arith.constant 0 : index diff --git a/mlir/test/Dialect/Tensor/decompose-concat.mlir b/mlir/test/Dialect/Tensor/decompose-concat.mlir new file mode 100644 index 0000000..5712c77a --- /dev/null +++ b/mlir/test/Dialect/Tensor/decompose-concat.mlir @@ -0,0 +1,57 @@ +// RUN: mlir-opt -split-input-file -transform-interpreter -cse %s | FileCheck %s + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%func_op: !transform.op<"func.func"> {transform.readonly}) { + transform.apply_patterns to %func_op { + transform.apply_patterns.tensor.decompose_concat + } : !transform.op<"func.func"> + transform.yield + } +} + +func.func @decompose_dynamic_concat(%arg0 : tensor<8x4xf32>, %arg1 : tensor<?x?xf32>) -> tensor<?x?xf32> { + %0 = tensor.concat dim(1) %arg0, %arg1 : (tensor<8x4xf32>, tensor<?x?xf32>) -> tensor<?x?xf32> + return %0 : tensor<?x?xf32> +} +// CHECK-DAG: #[[$MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1)> +// CHECK-LABEL: func @decompose_dynamic_concat( +// CHECK-SAME: %[[ARG0:.+]]: tensor<8x4xf32> +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32> + +// CHECK-DAG: %[[C8:.+]] = arith.constant 8 : index +// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index +// CHECK: %[[DIM:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32> +// CHECK: %[[CONCAT_SIZE:.+]] = affine.apply #[[$MAP]]()[%[[C8]], %[[DIM]]] +// CHECK: %[[EMPTY:.+]] = tensor.empty(%[[C8]], %[[CONCAT_SIZE]]) : tensor<?x?xf32> +// CHECK: %[[SLICE0:.+]] = tensor.insert_slice %[[ARG0]] into %[[EMPTY]][0, 0] [8, 4] [1, 1] : tensor<8x4xf32> into tensor<?x?xf32> +// CHECK: %[[OFFSET:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?xf32> +// CHECK: %[[CONCAT:.+]] = tensor.insert_slice %[[ARG1]] into %[[SLICE0]][0, 4] [%[[OFFSET]], %[[DIM]]] [1, 1] : tensor<?x?xf32> into tensor<?x?xf32> +// CHECK: return %[[CONCAT]] : tensor<?x?xf32> + +// ----- + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%func_op: !transform.op<"func.func"> {transform.readonly}) { + transform.apply_patterns to %func_op { + transform.apply_patterns.tensor.decompose_concat + } : !transform.op<"func.func"> + transform.yield + } +} + +func.func @decompose_1d_concat(%arg0 : tensor<1xf32>, + %arg1 : tensor<2xf32>, + %arg2 : tensor<3xf32>, + %arg3: tensor<4xf32>) -> tensor<10xf32> { + %0 = tensor.concat dim(0) %arg0, %arg1, %arg2, %arg3 + : (tensor<1xf32>, tensor<2xf32>, tensor<3xf32>, tensor<4xf32>) -> tensor<10xf32> + return %0 : tensor<10xf32> +} +// CHECK-LABEL: func @decompose_1d_concat +// CHECK: tensor.empty() : tensor<10xf32> +// CHECK: tensor.insert_slice %{{.*}}[0] [1] [1] : tensor<1xf32> into tensor<10xf32> +// CHECK: tensor.insert_slice %{{.*}}[1] [2] [1] : tensor<2xf32> into tensor<10xf32> +// CHECK: tensor.insert_slice %{{.*}}[3] [3] [1] : tensor<3xf32> into tensor<10xf32> +// CHECK: %[[CONCAT:.+]] = tensor.insert_slice %{{.*}}[6] [4] [1] : tensor<4xf32> into tensor<10xf32> +// CHECK: return %[[CONCAT]] : tensor<10xf32> diff --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir index 389e7e6..9b6c232 100644 --- a/mlir/test/Dialect/Tensor/invalid.mlir +++ b/mlir/test/Dialect/Tensor/invalid.mlir @@ -16,6 +16,54 @@ func.func @tensor.cast_mismatching_constants(%arg0: tensor<1xf32>) { // ----- +func.func @concat_empty() { + // expected-error@+1 {{requires at least one input}} + %0 = tensor.concat dim(0) : () -> tensor<1x2x3xf32> + return +} + +// ----- + +func.func @concat_rank_mismatch(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) { + // expected-error@+1 {{rank of concatenated inputs must match result rank}} + %0 = tensor.concat dim(0) %arg0, %arg1 : (tensor<1xf32>, tensor<1xf32>) -> tensor<2x1xf32> + return +} + +// ----- + +func.func @concat_dim_out_of_range(%arg0: tensor<3xf32>) { + // expected-error@+1 {{concatenation dim must be less than the tensor rank}} + %0 = tensor.concat dim(1) %arg0 : (tensor<3xf32>) -> tensor<3xf32> + return +} + +// ----- + +func.func @concat_element_type_mismatch(%arg0: tensor<3xf32>, %arg1: tensor<3xi32>) { + // expected-error@+1 {{inputs and result element type must match}} + %0 = tensor.concat dim(0) %arg0, %arg1 : (tensor<3xf32>, tensor<3xi32>) -> tensor<3xf32> + return +} + +// ----- + +func.func @concat_incompatible_input_types(%arg0: tensor<3x4xf32>, %arg1: tensor<4x5xf32>) { + // expected-error@+1 {{static concatenation size mismatch along non-concatenated dimension 1}} + %0 = tensor.concat dim(0) %arg0, %arg1 : (tensor<3x4xf32>, tensor<4x5xf32>) -> tensor<7x5xf32> + return +} + +// ----- + +func.func @concat_static_shape_mismatch(%arg0: tensor<3xf32>) { + // expected-error@+1 {{result type 'tensor<7xf32>'does not match inferred shape 'tensor<6xf32>' static sizes}} + %0 = tensor.concat dim(0) %arg0, %arg0 : (tensor<3xf32>, tensor<3xf32>) -> tensor<7xf32> + return +} + +// ----- + func.func @extract_too_many_indices(%arg0: tensor<?xf32>) { // expected-error@+1 {{incorrect number of indices for extract_element}} %0 = tensor.extract %arg0[] : tensor<?xf32> diff --git a/mlir/test/Dialect/Tensor/ops.mlir b/mlir/test/Dialect/Tensor/ops.mlir index 71a0489..2282da3 100644 --- a/mlir/test/Dialect/Tensor/ops.mlir +++ b/mlir/test/Dialect/Tensor/ops.mlir @@ -15,6 +15,23 @@ func.func @cast(%arg0: tensor<*xf32>, %arg1 : tensor<4x4xf32>, %arg2: tensor<?x? // ----- +// CHECK-LABEL: func @concat( +func.func @concat(%arg0: tensor<4x7x3xf32>, %arg1 : tensor<4x4x3xf32>, %arg2: tensor<?x?x?xf32>) { + // CHECK: tensor.concat dim(0) %{{.*}} : (tensor<4x7x3xf32>) -> tensor<4x7x3xf32> + %0 = tensor.concat dim(0) %arg0 : (tensor<4x7x3xf32>) -> tensor<4x7x3xf32> + // CHECK: tensor.concat dim(1) %{{.*}} : (tensor<4x7x3xf32>, tensor<4x4x3xf32>) -> tensor<4x11x3xf32> + %1 = tensor.concat dim(1) %arg0, %arg1 : (tensor<4x7x3xf32>, tensor<4x4x3xf32>) -> tensor<4x11x3xf32> + // CHECK: tensor.concat dim(2) %{{.*}} : (tensor<4x7x3xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32> + %2 = tensor.concat dim(2) %arg0, %arg2 : (tensor<4x7x3xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32> + // CHECK: tensor.concat dim(1) %{{.*}} : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x10x?xf32> + %3 = tensor.concat dim(1) %arg2, %arg2 : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x10x?xf32> + // CHECK: tensor.concat dim(1) %{{.*}} : (tensor<?x?x?xf32>, tensor<4x4x3xf32>, tensor<4x7x3xf32>) -> tensor<4x?x3xf32> + %4 = tensor.concat dim(1) %arg2, %arg1, %arg0 : (tensor<?x?x?xf32>, tensor<4x4x3xf32>, tensor<4x7x3xf32>) -> tensor<4x?x3xf32> + return +} + +// ----- + // CHECK-LABEL: func @empty( // CHECK-SAME: %[[sz:.*]]: index func.func @empty(%sz: index) -> tensor<5x?x6xf32> { |