diff options
author | Andrzej WarzyĆski <andrzej.warzynski@arm.com> | 2024-02-13 16:42:50 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-02-13 16:42:50 +0000 |
commit | 7a471133ef56bf6059b3e35125f86420ebbf3a33 (patch) | |
tree | 30ef90f155c6b5e9ac68df2946ec28914a298218 | |
parent | 7a5c1a4abc750fef335c2ee5191d59ebe9e4bf18 (diff) | |
download | llvm-7a471133ef56bf6059b3e35125f86420ebbf3a33.zip llvm-7a471133ef56bf6059b3e35125f86420ebbf3a33.tar.gz llvm-7a471133ef56bf6059b3e35125f86420ebbf3a33.tar.bz2 |
[mlir][nfc] Add tests for linalg.mmt4d (#81422)
linalg.mmt4d was added a while back (https://reviews.llvm.org/D105244),
but there are virtually no tests in-tree. In the spirit of documenting
through test, this PR adds a few basic examples.
-rw-r--r-- | mlir/test/Dialect/Linalg/invalid.mlir | 26 | ||||
-rw-r--r-- | mlir/test/Dialect/Linalg/named-ops.mlir | 11 | ||||
-rw-r--r-- | mlir/test/Dialect/Linalg/transform-op-mmt4d-to-fma.mlir | 70 | ||||
-rw-r--r-- | mlir/test/Dialect/Linalg/vectorization.mlir | 25 |
4 files changed, 132 insertions, 0 deletions
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir index 56890df..916c04f 100644 --- a/mlir/test/Dialect/Linalg/invalid.mlir +++ b/mlir/test/Dialect/Linalg/invalid.mlir @@ -744,3 +744,29 @@ func.func @illegal_softmax_output_shape(%arg0: tensor<2x16x32xf32>) -> tensor<2x -> tensor<2x16xf32> return %1 : tensor<2x16xf32> } + +// ----- + +func.func @mmt4d_dims_mismatch(%A: tensor<16x16x8x1xf32>, + %B: tensor<16x16x8x1xf32>, + %C_in: tensor<16x16x8x1xf32>) -> tensor<16x16x8x1xf32> { + // expected-error @+1 {{inferred input/output operand #2 has shape's dimension #3 to be 8, but found 1}} + %res = linalg.mmt4d + ins(%A, %B: tensor<16x16x8x1xf32>, tensor<16x16x8x1xf32>) + outs(%C_in: tensor<16x16x8x1xf32>) + -> tensor<16x16x8x1xf32> + return %res : tensor<16x16x8x1xf32> +} + +// ----- + +func.func @mmt4d_rank_mismatch(%A: tensor<16x16x8x1xf32>, + %B: tensor<16x16x8x1xf32>, + %C_in: tensor<8x8xf32>) -> tensor<8x8xf32> { + // expected-error @+1 {{expected operand rank (2) to match the result rank of indexing_map #2 (4)}} + %res = linalg.mmt4d + ins(%A, %B: tensor<16x16x8x1xf32>, tensor<16x16x8x1xf32>) + outs(%C_in: tensor<8x8xf32>) + -> tensor<8x8xf32> + return %res : tensor<8x8xf32> +} diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir index 29977a7..7064e1b 100644 --- a/mlir/test/Dialect/Linalg/named-ops.mlir +++ b/mlir/test/Dialect/Linalg/named-ops.mlir @@ -1219,6 +1219,17 @@ func.func @batchmatmul_transpose_b(%arg0: memref<2x3x5xf32>, %arg1: memref<2x7x5 // ----- +// CHECK-LABEL: func @mmt4d +func.func @mmt4d(%A: tensor<10x32x8x1xf32>, %B: tensor<80x32x4x1xf32>, %C: tensor<10x80x8x4xf32>) -> tensor<10x80x8x4xf32> { + // CHECK: %{{.+}} = linalg.mmt4d + // CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<10x32x8x1xf32>, tensor<80x32x4x1xf32>) + // CHECK-SAME: outs(%{{.+}} : tensor<10x80x8x4xf32>) -> tensor<10x80x8x4xf32> + %0 = linalg.mmt4d ins(%A, %B : tensor<10x32x8x1xf32>, tensor<80x32x4x1xf32>) outs(%C: tensor<10x80x8x4xf32>) -> tensor<10x80x8x4xf32> + return %0: tensor<10x80x8x4xf32> +} + +// ----- + // CHECK-LABEL: func @batch_mmt4d func.func @batch_mmt4d(%arg0: tensor<128x10x32x8x1xf32>, %arg1: tensor<128x80x32x4x1xf32>, %arg2: tensor<128x10x80x8x4xf32>) -> tensor<128x10x80x8x4xf32> { // CHECK: %{{.+}} = linalg.batch_mmt4d diff --git a/mlir/test/Dialect/Linalg/transform-op-mmt4d-to-fma.mlir b/mlir/test/Dialect/Linalg/transform-op-mmt4d-to-fma.mlir new file mode 100644 index 0000000..61e13d1 --- /dev/null +++ b/mlir/test/Dialect/Linalg/transform-op-mmt4d-to-fma.mlir @@ -0,0 +1,70 @@ +// RUN: mlir-opt %s -transform-interpreter | FileCheck %s + +func.func @mmt4d_to_fma(%A: tensor<16x16x8x1xf32>, %B: tensor<16x16x8x1xf32>, %C_in: tensor<16x16x8x8xf32>) -> tensor<16x16x8x8xf32> { + %res = linalg.mmt4d + ins(%A, %B: tensor<16x16x8x1xf32>, tensor<16x16x8x1xf32>) + outs(%C_in: tensor<16x16x8x8xf32>) + -> tensor<16x16x8x8xf32> + return %res : tensor<16x16x8x8xf32> +} + + +// CHECK-LABEL: @mmt4d_to_fma +// CHECK-COUNT-8: vector.fma + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%module: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %module : (!transform.any_op) -> !transform.op<"func.func"> + + %mmt4d = transform.structured.match ops{["linalg.mmt4d"]} in %func + + // Step 1: Tile + : (!transform.op<"func.func">) -> !transform.any_op + // Tile parallel dims + %tiled_linalg_op_p, %loops:4 = transform.structured.tile_using_for %mmt4d[1, 1, 0, 8, 8, 0] + : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op) + // Tile reduction dims + %tiled_linalg_op_r, %loops2:2 = transform.structured.tile_using_for %tiled_linalg_op_p[0, 0, 1, 0, 0, 1] + : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) + + // Step 2: Vectorize + transform.structured.vectorize %tiled_linalg_op_r : !transform.any_op + + // Step 3: Simplify + // vector.multi_reduction --> vector.contract + // Generates a 6-dim vector.contract with the dim matching the original MMT4D Op + // and with the following split into parallel and reduction dims: + // * parallel, parallel, reduction, parallel, parallel, reduction + transform.apply_patterns to %func { + transform.apply_patterns.vector.reduction_to_contract + // Reduce the rank of xfer ops. This transforms vector.contract to be + // more matmul-like and to enable the lowering to outer product Ops. + transform.apply_patterns.vector.transfer_permutation_patterns + } : !transform.op<"func.func"> + + // Hoisting and LICM - not strictly required + %func_h = transform.structured.hoist_redundant_vector_transfers %func + : (!transform.op<"func.func">) -> !transform.op<"func.func"> + %all_loops = transform.structured.match interface{LoopLikeInterface} in %func_h + : (!transform.op<"func.func">) -> !transform.any_op + transform.apply_licm to %all_loops : !transform.any_op + transform.loop.hoist_loop_invariant_subsets %all_loops : !transform.any_op + + // Simplify the 6-dim vector.contract into a 3-dim matmul-like + // vector.contract with the following split into parallel and reduction + // dims: + // * parallel, parallel, reduction + transform.apply_patterns to %func_h { + transform.apply_patterns.vector.reduction_to_contract + transform.apply_patterns.vector.cast_away_vector_leading_one_dim + transform.apply_patterns.canonicalization + } : !transform.op<"func.func"> + + // Step 4: Lower vector.contract to vector.fma via vector.outerproduct + transform.apply_patterns to %func_h { + transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct" + transform.apply_patterns.vector.lower_outerproduct + } : !transform.op<"func.func"> + transform.yield + } +} diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir index 5d1bef4..0272ac5 100644 --- a/mlir/test/Dialect/Linalg/vectorization.mlir +++ b/mlir/test/Dialect/Linalg/vectorization.mlir @@ -639,6 +639,31 @@ module attributes {transform.with_named_sequence} { // ----- +func.func @mmt4d(%A: memref<16x16x8x1xf32>, %B: memref<16x16x8x1xf32>, %C_in: memref<16x16x8x8xf32>) { + linalg.mmt4d ins(%A, %B: memref<16x16x8x1xf32>, memref<16x16x8x1xf32>) + outs(%C_in: memref<16x16x8x8xf32>) + return +} + +// CHECK-LABEL: func.func @mmt4d( +// CHECK-SAME: %[[A:.*]]: memref<16x16x8x1xf32>, %[[B:.*]]: memref<16x16x8x1xf32>, %[[C:.*]]: memref<16x16x8x8xf32>) { +// CHECK: %[[VEC_A:.*]] = vector.transfer_read %[[A]]{{.*}} : memref<16x16x8x1xf32>, vector<16x16x16x8x8x1xf32> +// CHECK: %[[VEC_B:.*]] = vector.transfer_read %[[B]]{{.*}} : memref<16x16x8x1xf32>, vector<16x16x16x8x8x1xf32> +// CHECK: %[[VEC_C:.*]] = vector.transfer_read %[[C]]{{.*}} : memref<16x16x8x8xf32>, vector<16x16x8x8xf32> +// CHECK: %[[MUL:.*]] = arith.mulf %[[VEC_A]], %[[VEC_B]] : vector<16x16x16x8x8x1xf32> +// CHECK: %[[RED:.*]] = vector.multi_reduction <add>, %[[MUL]], %[[VEC_C]] [2, 5] : vector<16x16x16x8x8x1xf32> to vector<16x16x8x8xf32> +// CHECK: vector.transfer_write %[[RED]], %[[C]]{{.*}} : vector<16x16x8x8xf32>, memref<16x16x8x8xf32> + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %mmt4d = transform.structured.match ops{["linalg.mmt4d"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.structured.vectorize %mmt4d : !transform.any_op + transform.yield + } +} + +// ----- + func.func @matmul_scalable(%A: memref<?x?xf32>, %B: memref<?x?xf32>, %C: memref<?x?xf32>) { linalg.matmul ins(%A, %B: memref<?x?xf32>, memref<?x?xf32>) outs(%C: memref<?x?xf32>) |