aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAndrzej WarzyƄski <andrzej.warzynski@arm.com>2024-02-13 16:42:50 +0000
committerGitHub <noreply@github.com>2024-02-13 16:42:50 +0000
commit7a471133ef56bf6059b3e35125f86420ebbf3a33 (patch)
tree30ef90f155c6b5e9ac68df2946ec28914a298218
parent7a5c1a4abc750fef335c2ee5191d59ebe9e4bf18 (diff)
downloadllvm-7a471133ef56bf6059b3e35125f86420ebbf3a33.zip
llvm-7a471133ef56bf6059b3e35125f86420ebbf3a33.tar.gz
llvm-7a471133ef56bf6059b3e35125f86420ebbf3a33.tar.bz2
[mlir][nfc] Add tests for linalg.mmt4d (#81422)
linalg.mmt4d was added a while back (https://reviews.llvm.org/D105244), but there are virtually no tests in-tree. In the spirit of documenting through test, this PR adds a few basic examples.
-rw-r--r--mlir/test/Dialect/Linalg/invalid.mlir26
-rw-r--r--mlir/test/Dialect/Linalg/named-ops.mlir11
-rw-r--r--mlir/test/Dialect/Linalg/transform-op-mmt4d-to-fma.mlir70
-rw-r--r--mlir/test/Dialect/Linalg/vectorization.mlir25
4 files changed, 132 insertions, 0 deletions
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index 56890df..916c04f 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -744,3 +744,29 @@ func.func @illegal_softmax_output_shape(%arg0: tensor<2x16x32xf32>) -> tensor<2x
-> tensor<2x16xf32>
return %1 : tensor<2x16xf32>
}
+
+// -----
+
+func.func @mmt4d_dims_mismatch(%A: tensor<16x16x8x1xf32>,
+ %B: tensor<16x16x8x1xf32>,
+ %C_in: tensor<16x16x8x1xf32>) -> tensor<16x16x8x1xf32> {
+ // expected-error @+1 {{inferred input/output operand #2 has shape's dimension #3 to be 8, but found 1}}
+ %res = linalg.mmt4d
+ ins(%A, %B: tensor<16x16x8x1xf32>, tensor<16x16x8x1xf32>)
+ outs(%C_in: tensor<16x16x8x1xf32>)
+ -> tensor<16x16x8x1xf32>
+ return %res : tensor<16x16x8x1xf32>
+}
+
+// -----
+
+func.func @mmt4d_rank_mismatch(%A: tensor<16x16x8x1xf32>,
+ %B: tensor<16x16x8x1xf32>,
+ %C_in: tensor<8x8xf32>) -> tensor<8x8xf32> {
+ // expected-error @+1 {{expected operand rank (2) to match the result rank of indexing_map #2 (4)}}
+ %res = linalg.mmt4d
+ ins(%A, %B: tensor<16x16x8x1xf32>, tensor<16x16x8x1xf32>)
+ outs(%C_in: tensor<8x8xf32>)
+ -> tensor<8x8xf32>
+ return %res : tensor<8x8xf32>
+}
diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
index 29977a7..7064e1b 100644
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -1219,6 +1219,17 @@ func.func @batchmatmul_transpose_b(%arg0: memref<2x3x5xf32>, %arg1: memref<2x7x5
// -----
+// CHECK-LABEL: func @mmt4d
+func.func @mmt4d(%A: tensor<10x32x8x1xf32>, %B: tensor<80x32x4x1xf32>, %C: tensor<10x80x8x4xf32>) -> tensor<10x80x8x4xf32> {
+ // CHECK: %{{.+}} = linalg.mmt4d
+ // CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<10x32x8x1xf32>, tensor<80x32x4x1xf32>)
+ // CHECK-SAME: outs(%{{.+}} : tensor<10x80x8x4xf32>) -> tensor<10x80x8x4xf32>
+ %0 = linalg.mmt4d ins(%A, %B : tensor<10x32x8x1xf32>, tensor<80x32x4x1xf32>) outs(%C: tensor<10x80x8x4xf32>) -> tensor<10x80x8x4xf32>
+ return %0: tensor<10x80x8x4xf32>
+}
+
+// -----
+
// CHECK-LABEL: func @batch_mmt4d
func.func @batch_mmt4d(%arg0: tensor<128x10x32x8x1xf32>, %arg1: tensor<128x80x32x4x1xf32>, %arg2: tensor<128x10x80x8x4xf32>) -> tensor<128x10x80x8x4xf32> {
// CHECK: %{{.+}} = linalg.batch_mmt4d
diff --git a/mlir/test/Dialect/Linalg/transform-op-mmt4d-to-fma.mlir b/mlir/test/Dialect/Linalg/transform-op-mmt4d-to-fma.mlir
new file mode 100644
index 0000000..61e13d1
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transform-op-mmt4d-to-fma.mlir
@@ -0,0 +1,70 @@
+// RUN: mlir-opt %s -transform-interpreter | FileCheck %s
+
+func.func @mmt4d_to_fma(%A: tensor<16x16x8x1xf32>, %B: tensor<16x16x8x1xf32>, %C_in: tensor<16x16x8x8xf32>) -> tensor<16x16x8x8xf32> {
+ %res = linalg.mmt4d
+ ins(%A, %B: tensor<16x16x8x1xf32>, tensor<16x16x8x1xf32>)
+ outs(%C_in: tensor<16x16x8x8xf32>)
+ -> tensor<16x16x8x8xf32>
+ return %res : tensor<16x16x8x8xf32>
+}
+
+
+// CHECK-LABEL: @mmt4d_to_fma
+// CHECK-COUNT-8: vector.fma
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%module: !transform.any_op {transform.readonly}) {
+ %func = transform.structured.match ops{["func.func"]} in %module : (!transform.any_op) -> !transform.op<"func.func">
+
+ %mmt4d = transform.structured.match ops{["linalg.mmt4d"]} in %func
+
+ // Step 1: Tile
+ : (!transform.op<"func.func">) -> !transform.any_op
+ // Tile parallel dims
+ %tiled_linalg_op_p, %loops:4 = transform.structured.tile_using_for %mmt4d[1, 1, 0, 8, 8, 0]
+ : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+ // Tile reduction dims
+ %tiled_linalg_op_r, %loops2:2 = transform.structured.tile_using_for %tiled_linalg_op_p[0, 0, 1, 0, 0, 1]
+ : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+
+ // Step 2: Vectorize
+ transform.structured.vectorize %tiled_linalg_op_r : !transform.any_op
+
+ // Step 3: Simplify
+ // vector.multi_reduction --> vector.contract
+ // Generates a 6-dim vector.contract with the dim matching the original MMT4D Op
+ // and with the following split into parallel and reduction dims:
+ // * parallel, parallel, reduction, parallel, parallel, reduction
+ transform.apply_patterns to %func {
+ transform.apply_patterns.vector.reduction_to_contract
+ // Reduce the rank of xfer ops. This transforms vector.contract to be
+ // more matmul-like and to enable the lowering to outer product Ops.
+ transform.apply_patterns.vector.transfer_permutation_patterns
+ } : !transform.op<"func.func">
+
+ // Hoisting and LICM - not strictly required
+ %func_h = transform.structured.hoist_redundant_vector_transfers %func
+ : (!transform.op<"func.func">) -> !transform.op<"func.func">
+ %all_loops = transform.structured.match interface{LoopLikeInterface} in %func_h
+ : (!transform.op<"func.func">) -> !transform.any_op
+ transform.apply_licm to %all_loops : !transform.any_op
+ transform.loop.hoist_loop_invariant_subsets %all_loops : !transform.any_op
+
+ // Simplify the 6-dim vector.contract into a 3-dim matmul-like
+ // vector.contract with the following split into parallel and reduction
+ // dims:
+ // * parallel, parallel, reduction
+ transform.apply_patterns to %func_h {
+ transform.apply_patterns.vector.reduction_to_contract
+ transform.apply_patterns.vector.cast_away_vector_leading_one_dim
+ transform.apply_patterns.canonicalization
+ } : !transform.op<"func.func">
+
+ // Step 4: Lower vector.contract to vector.fma via vector.outerproduct
+ transform.apply_patterns to %func_h {
+ transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct"
+ transform.apply_patterns.vector.lower_outerproduct
+ } : !transform.op<"func.func">
+ transform.yield
+ }
+}
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index 5d1bef4..0272ac5 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -639,6 +639,31 @@ module attributes {transform.with_named_sequence} {
// -----
+func.func @mmt4d(%A: memref<16x16x8x1xf32>, %B: memref<16x16x8x1xf32>, %C_in: memref<16x16x8x8xf32>) {
+ linalg.mmt4d ins(%A, %B: memref<16x16x8x1xf32>, memref<16x16x8x1xf32>)
+ outs(%C_in: memref<16x16x8x8xf32>)
+ return
+}
+
+// CHECK-LABEL: func.func @mmt4d(
+// CHECK-SAME: %[[A:.*]]: memref<16x16x8x1xf32>, %[[B:.*]]: memref<16x16x8x1xf32>, %[[C:.*]]: memref<16x16x8x8xf32>) {
+// CHECK: %[[VEC_A:.*]] = vector.transfer_read %[[A]]{{.*}} : memref<16x16x8x1xf32>, vector<16x16x16x8x8x1xf32>
+// CHECK: %[[VEC_B:.*]] = vector.transfer_read %[[B]]{{.*}} : memref<16x16x8x1xf32>, vector<16x16x16x8x8x1xf32>
+// CHECK: %[[VEC_C:.*]] = vector.transfer_read %[[C]]{{.*}} : memref<16x16x8x8xf32>, vector<16x16x8x8xf32>
+// CHECK: %[[MUL:.*]] = arith.mulf %[[VEC_A]], %[[VEC_B]] : vector<16x16x16x8x8x1xf32>
+// CHECK: %[[RED:.*]] = vector.multi_reduction <add>, %[[MUL]], %[[VEC_C]] [2, 5] : vector<16x16x16x8x8x1xf32> to vector<16x16x8x8xf32>
+// CHECK: vector.transfer_write %[[RED]], %[[C]]{{.*}} : vector<16x16x8x8xf32>, memref<16x16x8x8xf32>
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %mmt4d = transform.structured.match ops{["linalg.mmt4d"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.structured.vectorize %mmt4d : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
func.func @matmul_scalable(%A: memref<?x?xf32>, %B: memref<?x?xf32>, %C: memref<?x?xf32>) {
linalg.matmul ins(%A, %B: memref<?x?xf32>, memref<?x?xf32>)
outs(%C: memref<?x?xf32>)