From f310a5d2c13455f1d68f5654fa4258357bafeff6 Mon Sep 17 00:00:00 2001 From: Quinn Dawkins Date: Fri, 1 Dec 2023 15:05:29 -0500 Subject: [mlir][tensor] Add a tensor.concat operation (#72779) This adds an operation for concatenating ranked tensors along a static dimension, as well as a decomposition mirroring the existing lowering from TOSA to Tensor. This offers a convergence point for "input" like dialects that include various lowerings for concatenation operations, easing later analysis. In the future, this op can implement the necessary interfaces for tiling, as well as potentially add conversions to some kind of linalg and/or memref counterpart. This patch adds the op, the decomposition, and some basic folding/canonicalization. Replacing lowerings with the op (such as the TOSA lowering) will come as a follow up. See https://discourse.llvm.org/t/rfc-tensor-add-a-tensor-concatenate-operation/74858 --- mlir/test/Dialect/Tensor/canonicalize.mlir | 12 ++++++ mlir/test/Dialect/Tensor/decompose-concat.mlir | 57 ++++++++++++++++++++++++++ mlir/test/Dialect/Tensor/invalid.mlir | 48 ++++++++++++++++++++++ mlir/test/Dialect/Tensor/ops.mlir | 17 ++++++++ 4 files changed, 134 insertions(+) create mode 100644 mlir/test/Dialect/Tensor/decompose-concat.mlir (limited to 'mlir/test') diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir index 580c1db..84c44a0 100644 --- a/mlir/test/Dialect/Tensor/canonicalize.mlir +++ b/mlir/test/Dialect/Tensor/canonicalize.mlir @@ -87,6 +87,18 @@ func.func @tensor.cast_chain_invalid(%input: tensor<4x8xi32>) -> tensor<8x4xi32> // ----- +// CHECK-LABEL: fold_concat +// CHECK-SAME: %[[ARG0:.*]]: tensor<1x2x?xi32> +func.func @fold_concat(%arg0: tensor<1x2x?xi32>) -> (tensor<1x2x3xi32>, tensor<1x2x?xi32>) { + %0 = tensor.concat dim(2) %arg0 : (tensor<1x2x?xi32>) -> tensor<1x2x3xi32> + // CHECK-NEXT: %[[CAST:.*]] = tensor.cast %[[ARG0]] : tensor<1x2x?xi32> to tensor<1x2x3xi32> + %1 = tensor.concat dim(2) %arg0 : (tensor<1x2x?xi32>) -> tensor<1x2x?xi32> + // CHECK-NEXT: return %[[CAST]], %[[ARG0]] : tensor<1x2x3xi32>, tensor<1x2x?xi32> + return %0, %1 : tensor<1x2x3xi32>, tensor<1x2x?xi32> +} + +// ----- + // CHECK-LABEL: func @fold_extract func.func @fold_extract(%arg0 : index) -> (f32, f16, f16, i32, complex) { %const_0 = arith.constant 0 : index diff --git a/mlir/test/Dialect/Tensor/decompose-concat.mlir b/mlir/test/Dialect/Tensor/decompose-concat.mlir new file mode 100644 index 0000000..5712c77a --- /dev/null +++ b/mlir/test/Dialect/Tensor/decompose-concat.mlir @@ -0,0 +1,57 @@ +// RUN: mlir-opt -split-input-file -transform-interpreter -cse %s | FileCheck %s + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%func_op: !transform.op<"func.func"> {transform.readonly}) { + transform.apply_patterns to %func_op { + transform.apply_patterns.tensor.decompose_concat + } : !transform.op<"func.func"> + transform.yield + } +} + +func.func @decompose_dynamic_concat(%arg0 : tensor<8x4xf32>, %arg1 : tensor) -> tensor { + %0 = tensor.concat dim(1) %arg0, %arg1 : (tensor<8x4xf32>, tensor) -> tensor + return %0 : tensor +} +// CHECK-DAG: #[[$MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1)> +// CHECK-LABEL: func @decompose_dynamic_concat( +// CHECK-SAME: %[[ARG0:.+]]: tensor<8x4xf32> +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor + +// CHECK-DAG: %[[C8:.+]] = arith.constant 8 : index +// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index +// CHECK: %[[DIM:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor +// CHECK: %[[CONCAT_SIZE:.+]] = affine.apply #[[$MAP]]()[%[[C8]], %[[DIM]]] +// CHECK: %[[EMPTY:.+]] = tensor.empty(%[[C8]], %[[CONCAT_SIZE]]) : tensor +// CHECK: %[[SLICE0:.+]] = tensor.insert_slice %[[ARG0]] into %[[EMPTY]][0, 0] [8, 4] [1, 1] : tensor<8x4xf32> into tensor +// CHECK: %[[OFFSET:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor +// CHECK: %[[CONCAT:.+]] = tensor.insert_slice %[[ARG1]] into %[[SLICE0]][0, 4] [%[[OFFSET]], %[[DIM]]] [1, 1] : tensor into tensor +// CHECK: return %[[CONCAT]] : tensor + +// ----- + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%func_op: !transform.op<"func.func"> {transform.readonly}) { + transform.apply_patterns to %func_op { + transform.apply_patterns.tensor.decompose_concat + } : !transform.op<"func.func"> + transform.yield + } +} + +func.func @decompose_1d_concat(%arg0 : tensor<1xf32>, + %arg1 : tensor<2xf32>, + %arg2 : tensor<3xf32>, + %arg3: tensor<4xf32>) -> tensor<10xf32> { + %0 = tensor.concat dim(0) %arg0, %arg1, %arg2, %arg3 + : (tensor<1xf32>, tensor<2xf32>, tensor<3xf32>, tensor<4xf32>) -> tensor<10xf32> + return %0 : tensor<10xf32> +} +// CHECK-LABEL: func @decompose_1d_concat +// CHECK: tensor.empty() : tensor<10xf32> +// CHECK: tensor.insert_slice %{{.*}}[0] [1] [1] : tensor<1xf32> into tensor<10xf32> +// CHECK: tensor.insert_slice %{{.*}}[1] [2] [1] : tensor<2xf32> into tensor<10xf32> +// CHECK: tensor.insert_slice %{{.*}}[3] [3] [1] : tensor<3xf32> into tensor<10xf32> +// CHECK: %[[CONCAT:.+]] = tensor.insert_slice %{{.*}}[6] [4] [1] : tensor<4xf32> into tensor<10xf32> +// CHECK: return %[[CONCAT]] : tensor<10xf32> diff --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir index 389e7e6..9b6c232 100644 --- a/mlir/test/Dialect/Tensor/invalid.mlir +++ b/mlir/test/Dialect/Tensor/invalid.mlir @@ -16,6 +16,54 @@ func.func @tensor.cast_mismatching_constants(%arg0: tensor<1xf32>) { // ----- +func.func @concat_empty() { + // expected-error@+1 {{requires at least one input}} + %0 = tensor.concat dim(0) : () -> tensor<1x2x3xf32> + return +} + +// ----- + +func.func @concat_rank_mismatch(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) { + // expected-error@+1 {{rank of concatenated inputs must match result rank}} + %0 = tensor.concat dim(0) %arg0, %arg1 : (tensor<1xf32>, tensor<1xf32>) -> tensor<2x1xf32> + return +} + +// ----- + +func.func @concat_dim_out_of_range(%arg0: tensor<3xf32>) { + // expected-error@+1 {{concatenation dim must be less than the tensor rank}} + %0 = tensor.concat dim(1) %arg0 : (tensor<3xf32>) -> tensor<3xf32> + return +} + +// ----- + +func.func @concat_element_type_mismatch(%arg0: tensor<3xf32>, %arg1: tensor<3xi32>) { + // expected-error@+1 {{inputs and result element type must match}} + %0 = tensor.concat dim(0) %arg0, %arg1 : (tensor<3xf32>, tensor<3xi32>) -> tensor<3xf32> + return +} + +// ----- + +func.func @concat_incompatible_input_types(%arg0: tensor<3x4xf32>, %arg1: tensor<4x5xf32>) { + // expected-error@+1 {{static concatenation size mismatch along non-concatenated dimension 1}} + %0 = tensor.concat dim(0) %arg0, %arg1 : (tensor<3x4xf32>, tensor<4x5xf32>) -> tensor<7x5xf32> + return +} + +// ----- + +func.func @concat_static_shape_mismatch(%arg0: tensor<3xf32>) { + // expected-error@+1 {{result type 'tensor<7xf32>'does not match inferred shape 'tensor<6xf32>' static sizes}} + %0 = tensor.concat dim(0) %arg0, %arg0 : (tensor<3xf32>, tensor<3xf32>) -> tensor<7xf32> + return +} + +// ----- + func.func @extract_too_many_indices(%arg0: tensor) { // expected-error@+1 {{incorrect number of indices for extract_element}} %0 = tensor.extract %arg0[] : tensor diff --git a/mlir/test/Dialect/Tensor/ops.mlir b/mlir/test/Dialect/Tensor/ops.mlir index 71a0489..2282da3 100644 --- a/mlir/test/Dialect/Tensor/ops.mlir +++ b/mlir/test/Dialect/Tensor/ops.mlir @@ -15,6 +15,23 @@ func.func @cast(%arg0: tensor<*xf32>, %arg1 : tensor<4x4xf32>, %arg2: tensor, %arg1 : tensor<4x4x3xf32>, %arg2: tensor) { + // CHECK: tensor.concat dim(0) %{{.*}} : (tensor<4x7x3xf32>) -> tensor<4x7x3xf32> + %0 = tensor.concat dim(0) %arg0 : (tensor<4x7x3xf32>) -> tensor<4x7x3xf32> + // CHECK: tensor.concat dim(1) %{{.*}} : (tensor<4x7x3xf32>, tensor<4x4x3xf32>) -> tensor<4x11x3xf32> + %1 = tensor.concat dim(1) %arg0, %arg1 : (tensor<4x7x3xf32>, tensor<4x4x3xf32>) -> tensor<4x11x3xf32> + // CHECK: tensor.concat dim(2) %{{.*}} : (tensor<4x7x3xf32>, tensor) -> tensor + %2 = tensor.concat dim(2) %arg0, %arg2 : (tensor<4x7x3xf32>, tensor) -> tensor + // CHECK: tensor.concat dim(1) %{{.*}} : (tensor, tensor) -> tensor + %3 = tensor.concat dim(1) %arg2, %arg2 : (tensor, tensor) -> tensor + // CHECK: tensor.concat dim(1) %{{.*}} : (tensor, tensor<4x4x3xf32>, tensor<4x7x3xf32>) -> tensor<4x?x3xf32> + %4 = tensor.concat dim(1) %arg2, %arg1, %arg0 : (tensor, tensor<4x4x3xf32>, tensor<4x7x3xf32>) -> tensor<4x?x3xf32> + return +} + +// ----- + // CHECK-LABEL: func @empty( // CHECK-SAME: %[[sz:.*]]: index func.func @empty(%sz: index) -> tensor<5x?x6xf32> { -- cgit v1.1