aboutsummaryrefslogtreecommitdiff
path: root/mlir/test
diff options
context:
space:
mode:
authorAndrzej Warzyński <andrzej.warzynski@arm.com>2023-11-25 12:32:55 +0000
committerGitHub <noreply@github.com>2023-11-25 12:32:55 +0000
commit385304eb6f8c5b3137cd4937d6a93c5072bfc6b1 (patch)
treeb2ac91460c483b20a4a5ccb80200c8cfade5627b /mlir/test
parent9020026b88b7f49d9d0e1780a89df4a2bce811e8 (diff)
downloadllvm-385304eb6f8c5b3137cd4937d6a93c5072bfc6b1.zip
llvm-385304eb6f8c5b3137cd4937d6a93c5072bfc6b1.tar.gz
llvm-385304eb6f8c5b3137cd4937d6a93c5072bfc6b1.tar.bz2
[MLIR][Vector] Refactor tests for contract -> OP transforms (#73348)
This patch refactors tests for: * vector.contract -> vector.outerproduct transformations for matvec operations (b += Ax). Specifically, relevant tests from the following 2 files: * vector-contract-matvec-transforms.mlir * vector-contract-to-outerproduct-transforms.mlir are combined into one: * vector-contract-to-outerproduct-matvec-transforms.mlir All original tests are preserved and no new tests are added. This is a part of a larger effort to add cases with scalable vectors to tests for the Vector dialect. I am refactoring these test as a preparation for follow-up patches. Implements #72834.
Diffstat (limited to 'mlir/test')
-rw-r--r--mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir170
-rw-r--r--mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matvec-transforms.mlir531
-rw-r--r--mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir396
3 files changed, 531 insertions, 566 deletions
diff --git a/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir
deleted file mode 100644
index 811fb58..0000000
--- a/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir
+++ /dev/null
@@ -1,170 +0,0 @@
-// RUN: mlir-opt %s --transform-interpreter --split-input-file | FileCheck %s
-
-#matvec_accesses = [
- affine_map<(m, k) -> (m, k)>,
- affine_map<(m, k) -> (k)>,
- affine_map<(m, k) -> (m)>
-]
-#matvec_trait = {
- indexing_maps = #matvec_accesses,
- iterator_types = ["parallel", "reduction"]
-}
-#matvecmax_trait = {
- indexing_maps = #matvec_accesses,
- iterator_types = ["parallel", "reduction"],
- kind = #vector.kind<maxf>
-}
-
-#mattransvec_accesses = [
- affine_map<(m, k) -> (k, m)>,
- affine_map<(m, k) -> (k)>,
- affine_map<(m, k) -> (m)>
-]
-#mattransvec_trait = {
- indexing_maps = #mattransvec_accesses,
- iterator_types = ["parallel", "reduction"]
-}
-
-#vecmat_accesses = [
- affine_map<(m, k) -> (k)>,
- affine_map<(m, k) -> (m, k)>,
- affine_map<(m, k) -> (m)>
-]
-#vecmat_trait = {
- indexing_maps = #vecmat_accesses,
- iterator_types = ["parallel", "reduction"]
-}
-
-#vecmattrans_accesses = [
- affine_map<(m, k) -> (k)>,
- affine_map<(m, k) -> (k, m)>,
- affine_map<(m, k) -> (m)>
-]
-#vecmattrans_trait = {
- indexing_maps = #vecmattrans_accesses,
- iterator_types = ["parallel", "reduction"]
-}
-
-#redpar_vecmattrans_accesses = [
- affine_map<(m, k) -> (m)>,
- affine_map<(m, k) -> (m, k)>,
- affine_map<(m, k) -> (k)>
-]
-#redpar_vecmattrans_trait = {
- indexing_maps = #redpar_vecmattrans_accesses,
- iterator_types = ["reduction", "parallel"]
-}
-
-// CHECK-LABEL: func @matvec_mk_k_m
-// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
-// CHECK-SAME: %[[B:.*1]]: vector<2xf32>
-// CHECK-SAME: %[[C:.*2]]: vector<2xf32>
-// CHECK: %[[T3:.*]] = vector.transpose %[[A]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
-// CHECK: %[[T4:.*]] = vector.extract %[[T3]][0] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T5:.*]] = vector.extract %[[B]][0] : f32 from vector<2xf32>
-// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[C]] {kind = #vector.kind<add>} : vector<2xf32>, f32
-// CHECK: %[[T7:.*]] = vector.extract %[[T3]][1] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T8:.*]] = vector.extract %[[B]][1] : f32 from vector<2xf32>
-// CHECK: %[[T9:.*]] = vector.outerproduct %[[T7]], %[[T8]], %[[T6]] {kind = #vector.kind<add>} : vector<2xf32>, f32
-func.func @matvec_mk_k_m(%A: vector<2x2xf32>,
- %x: vector<2xf32>,
- %b: vector<2xf32>) -> vector<2xf32> {
- %0 = vector.contract #matvec_trait %A, %x, %b : vector<2x2xf32>, vector<2xf32> into vector<2xf32>
- return %0 : vector<2xf32>
-}
-
-// CHECK-LABEL: func @matvec_mk_k_m_max
-// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
-// CHECK-SAME: %[[B:.*1]]: vector<2xf32>
-// CHECK-SAME: %[[C:.*2]]: vector<2xf32>
-// CHECK: %[[T3:.*]] = vector.transpose %[[A]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
-// CHECK: %[[T4:.*]] = vector.extract %[[T3]][0] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T5:.*]] = vector.extract %[[B]][0] : f32 from vector<2xf32>
-// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[C]] {kind = #vector.kind<maxf>} : vector<2xf32>, f32
-// CHECK: %[[T7:.*]] = vector.extract %[[T3]][1] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T8:.*]] = vector.extract %[[B]][1] : f32 from vector<2xf32>
-// CHECK: %[[T9:.*]] = vector.outerproduct %[[T7]], %[[T8]], %[[T6]] {kind = #vector.kind<maxf>} : vector<2xf32>, f32
-func.func @matvec_mk_k_m_max(%A: vector<2x2xf32>,
- %x: vector<2xf32>,
- %b: vector<2xf32>) -> vector<2xf32> {
- %0 = vector.contract #matvecmax_trait %A, %x, %b : vector<2x2xf32>, vector<2xf32> into vector<2xf32>
- return %0 : vector<2xf32>
-}
-
-// CHECK-LABEL: func @matvec_km_k_m
-// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
-// CHECK-SAME: %[[B:.*1]]: vector<2xf32>
-// CHECK-SAME: %[[C:.*2]]: vector<2xf32>
-// CHECK: %[[T3:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T4:.*]] = vector.extract %[[B]][0] : f32 from vector<2xf32>
-// CHECK: %[[T5:.*]] = vector.outerproduct %[[T3]], %[[T4]], %[[C]] {kind = #vector.kind<add>} : vector<2xf32>, f32
-// CHECK: %[[T6:.*]] = vector.extract %[[A]][1] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T7:.*]] = vector.extract %[[B]][1] : f32 from vector<2xf32>
-// CHECK: %[[T8:.*]] = vector.outerproduct %[[T6]], %[[T7]], %[[T5]] {kind = #vector.kind<add>} : vector<2xf32>, f32
-func.func @matvec_km_k_m(%A: vector<2x2xf32>,
- %x: vector<2xf32>,
- %b: vector<2xf32>) -> vector<2xf32> {
- %0 = vector.contract #mattransvec_trait %A, %x, %b : vector<2x2xf32>, vector<2xf32> into vector<2xf32>
- return %0 : vector<2xf32>
-}
-
-// CHECK-LABEL: func @matvec_k_mk_m
-// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
-// CHECK-SAME: %[[B:.*1]]: vector<2xf32>
-// CHECK-SAME: %[[C:.*2]]: vector<2xf32>
-// CHECK: %[[T3:.*]] = vector.transpose %[[A]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
-// CHECK: %[[T4:.*]] = vector.extract %[[T3]][0] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T5:.*]] = vector.extract %[[B]][0] : f32 from vector<2xf32>
-// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[C]] {kind = #vector.kind<add>} : vector<2xf32>, f32
-// CHECK: %[[T7:.*]] = vector.extract %[[T3]][1] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T8:.*]] = vector.extract %[[B]][1] : f32 from vector<2xf32>
-// CHECK: %[[T9:.*]] = vector.outerproduct %[[T7]], %[[T8]], %[[T6]] {kind = #vector.kind<add>} : vector<2xf32>, f32
-func.func @matvec_k_mk_m(%A: vector<2x2xf32>,
- %x: vector<2xf32>,
- %b: vector<2xf32>) -> vector<2xf32> {
- %0 = vector.contract #vecmat_trait %x, %A, %b : vector<2xf32>, vector<2x2xf32> into vector<2xf32>
- return %0 : vector<2xf32>
-}
-
-// CHECK-LABEL: func @matvec_k_km_m
-// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
-// CHECK-SAME: %[[B:.*1]]: vector<2xf32>
-// CHECK-SAME: %[[C:.*2]]: vector<2xf32>
-// CHECK: %[[T3:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T4:.*]] = vector.extract %[[B]][0] : f32 from vector<2xf32>
-// CHECK: %[[T5:.*]] = vector.outerproduct %[[T3]], %[[T4]], %[[C]] {kind = #vector.kind<add>} : vector<2xf32>, f32
-// CHECK: %[[T6:.*]] = vector.extract %[[A]][1] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T7:.*]] = vector.extract %[[B]][1] : f32 from vector<2xf32>
-// CHECK: %[[T8:.*]] = vector.outerproduct %[[T6]], %[[T7]], %[[T5]] {kind = #vector.kind<add>} : vector<2xf32>, f32
-func.func @matvec_k_km_m(%A: vector<2x2xf32>,
- %x: vector<2xf32>,
- %b: vector<2xf32>) -> vector<2xf32> {
- %0 = vector.contract #vecmattrans_trait %x, %A, %b : vector<2xf32>, vector<2x2xf32> into vector<2xf32>
- return %0 : vector<2xf32>
-}
-
-// CHECK-LABEL: func @matvec_m_mk_k
-// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
-// CHECK-SAME: %[[B:.*1]]: vector<2xf32>
-// CHECK-SAME: %[[C:.*2]]: vector<2xf32>
-// CHECK: %[[T3:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T4:.*]] = vector.extract %[[B]][0] : f32 from vector<2xf32>
-// CHECK: %[[T5:.*]] = vector.outerproduct %[[T3]], %[[T4]], %[[C]] {kind = #vector.kind<add>} : vector<2xf32>, f32
-// CHECK: %[[T6:.*]] = vector.extract %[[A]][1] : vector<2xf32> from vector<2x2xf32>
-// CHECK: %[[T7:.*]] = vector.extract %[[B]][1] : f32 from vector<2xf32>
-// CHECK: %[[T8:.*]] = vector.outerproduct %[[T6]], %[[T7]], %[[T5]] {kind = #vector.kind<add>} : vector<2xf32>, f32
-func.func @matvec_m_mk_k(%A: vector<2x2xf32>,
- %x: vector<2xf32>,
- %b: vector<2xf32>) -> vector<2xf32> {
- %0 = vector.contract #redpar_vecmattrans_trait %x, %A, %b : vector<2xf32>, vector<2x2xf32> into vector<2xf32>
- return %0 : vector<2xf32>
-}
-
-module attributes {transform.with_named_sequence} {
- transform.named_sequence @__transform_main(%func_op: !transform.op<"func.func"> {transform.readonly}) {
- transform.apply_patterns to %func_op {
- transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct"
- } : !transform.op<"func.func">
- transform.yield
- }
-}
diff --git a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matvec-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matvec-transforms.mlir
new file mode 100644
index 0000000..3ca3d34
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matvec-transforms.mlir
@@ -0,0 +1,531 @@
+// RUN: mlir-opt %s --transform-interpreter --split-input-file | FileCheck %s
+
+#matvec_accesses_1 = [
+ affine_map<(m, k) -> (m, k)>,
+ affine_map<(m, k) -> (k)>,
+ affine_map<(m, k) -> (m)>
+]
+#matvec_trait_1 = {
+ indexing_maps = #matvec_accesses_1,
+ iterator_types = ["parallel", "reduction"]
+}
+
+#matvecmax_trait = {
+ indexing_maps = #matvec_accesses_1,
+ iterator_types = ["parallel", "reduction"],
+ kind = #vector.kind<maxf>
+}
+
+#matvec_accesses_2 = [
+ affine_map<(m, k) -> (k, m)>,
+ affine_map<(m, k) -> (k)>,
+ affine_map<(m, k) -> (m)>
+]
+#matvec_trait_2 = {
+ indexing_maps = #matvec_accesses_2,
+ iterator_types = ["parallel", "reduction"]
+}
+
+#matvec_accesses_3 = [
+ affine_map<(m, k) -> (k)>,
+ affine_map<(m, k) -> (m, k)>,
+ affine_map<(m, k) -> (m)>
+]
+#matvec_trait_3 = {
+ indexing_maps = #matvec_accesses_3,
+ iterator_types = ["parallel", "reduction"]
+}
+
+#matvec_accesses_4 = [
+ affine_map<(m, k) -> (k)>,
+ affine_map<(m, k) -> (k, m)>,
+ affine_map<(m, k) -> (m)>
+]
+#matvec_trait_4 = {
+ indexing_maps = #matvec_accesses_4,
+ iterator_types = ["parallel", "reduction"]
+}
+
+#redpar_vecmattrans_accesses = [
+ affine_map<(m, k) -> (m)>,
+ affine_map<(m, k) -> (m, k)>,
+ affine_map<(m, k) -> (k)>
+]
+#redpar_vecmattrans_trait = {
+ indexing_maps = #redpar_vecmattrans_accesses,
+ iterator_types = ["reduction", "parallel"]
+}
+
+// ============================================================================
+// Matvec 1 (plain + masked + scalable)
+// ============================================================================
+// CHECK-LABEL: func.func @masked_matvec_mk_k_m(
+// CHECK-SAME: %{{.*}}: vector<2x3xf32>,
+// CHECK-SAME: %{{.*}}: vector<3xf32>,
+// CHECK-SAME: %{{.*}}: vector<2xf32>,
+// CHECK-SAME: %[[IN_MASK:.*]]: vector<2x3xi1>) -> vector<2xf32>
+// CHECK: %[[T_MASK:.*]] = vector.transpose %[[IN_MASK]], [1, 0] : vector<2x3xi1> to vector<3x2xi1>
+// CHECK: %[[MASK0:.*]] = vector.extract %[[T_MASK]][0] : vector<2xi1> from vector<3x2xi1>
+// CHECK: vector.mask %[[MASK0]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<2xf32>, f32 } : vector<2xi1> -> vector<2xf32>
+
+// CHECK: %[[MASK1:.*]] = vector.extract %[[T_MASK]][1] : vector<2xi1> from vector<3x2xi1>
+// CHECK: vector.mask %[[MASK1]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<2xf32>, f32 } : vector<2xi1> -> vector<2xf32>
+
+// CHECK: %[[MASK2:.*]] = vector.extract %[[T_MASK]][2] : vector<2xi1> from vector<3x2xi1>
+// CHECK: vector.mask %[[MASK2]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<2xf32>, f32 } : vector<2xi1> -> vector<2xf32>
+
+func.func @masked_matvec_mk_k_m(%arg0: vector<2x3xf32>,
+ %arg1: vector<3xf32>,
+ %arg2: vector<2xf32>,
+ %m: vector<2x3xi1>) -> vector<2xf32> {
+ %0 = vector.mask %m { vector.contract #matvec_trait_1 %arg0, %arg1, %arg2
+ : vector<2x3xf32>, vector<3xf32> into vector<2xf32> } : vector<2x3xi1> -> vector<2xf32>
+ return %0 : vector<2xf32>
+}
+
+// CHECK-LABEL: func.func @masked_matvec_mk_k_m_scalable_parallel_dim(
+// CHECK-SAME: %{{.*}}: vector<[2]x3xf32>,
+// CHECK-SAME: %{{.*}}: vector<3xf32>,
+// CHECK-SAME: %{{.*}}: vector<[2]xf32>,
+// CHECK-SAME: %[[IN_MASK:.*]]: vector<[2]x3xi1>) -> vector<[2]xf32>
+// CHECK: %[[T_MASK:.*]] = vector.transpose %[[IN_MASK]], [1, 0] : vector<[2]x3xi1> to vector<3x[2]xi1>
+// CHECK: %[[MASK0:.*]] = vector.extract %[[T_MASK]][0] : vector<[2]xi1> from vector<3x[2]xi1>
+// CHECK: vector.mask %[[MASK0]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<[2]xf32>, f32 } : vector<[2]xi1> -> vector<[2]xf32>
+
+// CHECK: %[[MASK1:.*]] = vector.extract %[[T_MASK]][1] : vector<[2]xi1> from vector<3x[2]xi1>
+// CHECK: vector.mask %[[MASK1]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<[2]xf32>, f32 } : vector<[2]xi1> -> vector<[2]xf32>
+
+// CHECK: %[[MASK2:.*]] = vector.extract %[[T_MASK]][2] : vector<[2]xi1> from vector<3x[2]xi1>
+// CHECK: vector.mask %[[MASK2]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<[2]xf32>, f32 } : vector<[2]xi1> -> vector<[2]xf32>
+func.func @masked_matvec_mk_k_m_scalable_parallel_dim(%arg0: vector<[2]x3xf32>,
+ %arg1: vector<3xf32>,
+ %arg2: vector<[2]xf32>,
+ %m: vector<[2]x3xi1>) -> vector<[2]xf32> {
+ %0 = vector.mask %m { vector.contract #matvec_trait_1 %arg0, %arg1, %arg2
+ : vector<[2]x3xf32>, vector<3xf32> into vector<[2]xf32> } : vector<[2]x3xi1> -> vector<[2]xf32>
+ return %0 : vector<[2]xf32>
+}
+
+// CHECK-LABEL: func @matvec_mk_k_m
+// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
+// CHECK-SAME: %[[B:.*1]]: vector<2xf32>
+// CHECK-SAME: %[[C:.*2]]: vector<2xf32>
+// CHECK: %[[T3:.*]] = vector.transpose %[[A]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
+// CHECK: %[[T4:.*]] = vector.extract %[[T3]][0] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[T5:.*]] = vector.extract %[[B]][0] : f32 from vector<2xf32>
+// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[C]] {kind = #vector.kind<add>} : vector<2xf32>, f32
+// CHECK: %[[T7:.*]] = vector.extract %[[T3]][1] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[T8:.*]] = vector.extract %[[B]][1] : f32 from vector<2xf32>
+// CHECK: %[[T9:.*]] = vector.outerproduct %[[T7]], %[[T8]], %[[T6]] {kind = #vector.kind<add>} : vector<2xf32>, f32
+func.func @matvec_mk_k_m(%A: vector<2x2xf32>,
+ %x: vector<2xf32>,
+ %b: vector<2xf32>) -> vector<2xf32> {
+ %0 = vector.contract #matvec_trait_1 %A, %x, %b : vector<2x2xf32>, vector<2xf32> into vector<2xf32>
+ return %0 : vector<2xf32>
+}
+
+// ============================================================================
+// Matvec 1 - max (plain)
+// ============================================================================
+// CHECK-LABEL: func @matvec_mk_k_m_max
+// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
+// CHECK-SAME: %[[B:.*1]]: vector<2xf32>
+// CHECK-SAME: %[[C:.*2]]: vector<2xf32>
+// CHECK: %[[T3:.*]] = vector.transpose %[[A]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
+// CHECK: %[[T4:.*]] = vector.extract %[[T3]][0] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[T5:.*]] = vector.extract %[[B]][0] : f32 from vector<2xf32>
+// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[C]] {kind = #vector.kind<maxf>} : vector<2xf32>, f32
+// CHECK: %[[T7:.*]] = vector.extract %[[T3]][1] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[T8:.*]] = vector.extract %[[B]][1] : f32 from vector<2xf32>
+// CHECK: %[[T9:.*]] = vector.outerproduct %[[T7]], %[[T8]], %[[T6]] {kind = #vector.kind<maxf>} : vector<2xf32>, f32
+func.func @matvec_mk_k_m_max(%A: vector<2x2xf32>,
+ %x: vector<2xf32>,
+ %b: vector<2xf32>) -> vector<2xf32> {
+ %0 = vector.contract #matvecmax_trait %A, %x, %b : vector<2x2xf32>, vector<2xf32> into vector<2xf32>
+ return %0 : vector<2xf32>
+}
+
+// ============================================================================
+// Matvec 2 (plain + masked + scalable)
+// ============================================================================
+// CHECK-LABEL: @masked_matvec_km_k_m
+// CHECK-SAME: %[[MAT:.+]]: vector<2x4xf32>
+// CHECK-SAME: %[[VEC:.+]]: vector<2xf32>
+// CHECK-SAME: %[[INIT:.+]]: vector<4xf32>
+// CHECK-SAME: %[[MASK:.+]]: vector<4x2xi1>
+func.func @masked_matvec_km_k_m(%arg0: vector<2x4xf32>,
+ %arg1: vector<2xf32>,
+ %arg2: vector<4xf32>,
+ %mask: vector<4x2xi1>) -> vector<4xf32> {
+ // CHECK: vector.transpose %[[MASK]]
+ // CHECK-NOT: vector.transpose %[[MAT]]
+ // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
+ %res = vector.mask %mask {
+ vector.contract #matvec_trait_2 %arg0, %arg1, %arg2
+ : vector<2x4xf32>, vector<2xf32>, vector<4xf32> into vector<4xf32>
+ } : vector<4x2xi1> -> vector<4xf32>
+ return %res : vector<4xf32>
+}
+
+// CHECK-LABEL: @masked_matvec_km_k_m_scalable_parallel_dim
+// CHECK-SAME: %[[MAT:.+]]: vector<2x[4]xf32>
+// CHECK-SAME: %[[VEC:.+]]: vector<2xf32>
+// CHECK-SAME: %[[INIT:.+]]: vector<[4]xf32>
+// CHECK-SAME: %[[MASK:.+]]: vector<[4]x2xi1>
+func.func @masked_matvec_km_k_m_scalable_parallel_dim(%arg0: vector<2x[4]xf32>,
+ %arg1: vector<2xf32>,
+ %arg2: vector<[4]xf32>,
+ %mask: vector<[4]x2xi1>) -> vector<[4]xf32> {
+ // CHECK: vector.transpose %[[MASK]]
+ // CHECK-NOT: vector.transpose %[[MAT]]
+ // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
+ %res = vector.mask %mask {
+ vector.contract #matvec_trait_2 %arg0, %arg1, %arg2
+ : vector<2x[4]xf32>, vector<2xf32>, vector<[4]xf32> into vector<[4]xf32>
+ } : vector<[4]x2xi1> -> vector<[4]xf32>
+ return %res : vector<[4]xf32>
+}
+
+// CHECK-LABEL: func @matvec_km_k_m
+// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
+// CHECK-SAME: %[[B:.*1]]: vector<2xf32>
+// CHECK-SAME: %[[C:.*2]]: vector<2xf32>
+// CHECK: %[[T3:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[T4:.*]] = vector.extract %[[B]][0] : f32 from vector<2xf32>
+// CHECK: %[[T5:.*]] = vector.outerproduct %[[T3]], %[[T4]], %[[C]] {kind = #vector.kind<add>} : vector<2xf32>, f32
+// CHECK: %[[T6:.*]] = vector.extract %[[A]][1] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[T7:.*]] = vector.extract %[[B]][1] : f32 from vector<2xf32>
+// CHECK: %[[T8:.*]] = vector.outerproduct %[[T6]], %[[T7]], %[[T5]] {kind = #vector.kind<add>} : vector<2xf32>, f32
+func.func @matvec_km_k_m(%A: vector<2x2xf32>,
+ %x: vector<2xf32>,
+ %b: vector<2xf32>) -> vector<2xf32> {
+ %0 = vector.contract #matvec_trait_2 %A, %x, %b : vector<2x2xf32>, vector<2xf32> into vector<2xf32>
+ return %0 : vector<2xf32>
+}
+
+// ============================================================================
+// Matvec 3 (plain + masked + scalable)
+// ============================================================================
+// CHECK-LABEL: @masked_matvec_k_mk_m
+// CHECK-SAME: %[[MAT:.+]]: vector<4x2xf32>
+// CHECK-SAME: %[[VEC:.+]]: vector<2xf32>
+// CHECK-SAME: %[[INIT:.+]]: vector<4xf32>
+// CHECK-SAME: %[[MASK:.+]]: vector<4x2xi1>
+func.func @masked_matvec_k_mk_m(%arg0: vector<4x2xf32>,
+ %arg1: vector<2xf32>,
+ %arg2: vector<4xf32>,
+ %mask: vector<4x2xi1>) -> vector<4xf32> {
+ // CHECK: vector.transpose %[[MASK]]
+ // CHECK: vector.transpose %[[MAT]]
+ // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
+ %res = vector.mask %mask {
+ vector.contract #matvec_trait_3 %arg1, %arg0, %arg2
+ : vector<2xf32>, vector<4x2xf32>, vector<4xf32> into vector<4xf32>
+ } : vector<4x2xi1> -> vector<4xf32>
+ return %res : vector<4xf32>
+}
+
+// CHECK-LABEL: @masked_matvec_k_mk_m_scalable_parallel_dim
+// CHECK-SAME: %[[MAT:.+]]: vector<[4]x2xf32>
+// CHECK-SAME: %[[VEC:.+]]: vector<2xf32>
+// CHECK-SAME: %[[INIT:.+]]: vector<[4]xf32>
+// CHECK-SAME: %[[MASK:.+]]: vector<[4]x2xi1>
+func.func @masked_matvec_k_mk_m_scalable_parallel_dim(%arg0: vector<[4]x2xf32>,
+ %arg1: vector<2xf32>,
+ %arg2: vector<[4]xf32>,
+ %mask: vector<[4]x2xi1>) -> vector<[4]xf32> {
+ // CHECK: vector.transpose %[[MASK]]
+ // CHECK: vector.transpose %[[MAT]]
+ // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
+ %res = vector.mask %mask {
+ vector.contract #matvec_trait_3 %arg1, %arg0, %arg2
+ : vector<2xf32>, vector<[4]x2xf32>, vector<[4]xf32> into vector<[4]xf32>
+ } : vector<[4]x2xi1> -> vector<[4]xf32>
+ return %res : vector<[4]xf32>
+}
+
+
+// CHECK-LABEL: func @matvec_k_mk_m
+// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
+// CHECK-SAME: %[[B:.*1]]: vector<2xf32>
+// CHECK-SAME: %[[C:.*2]]: vector<2xf32>
+// CHECK: %[[T3:.*]] = vector.transpose %[[A]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
+// CHECK: %[[T4:.*]] = vector.extract %[[T3]][0] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[T5:.*]] = vector.extract %[[B]][0] : f32 from vector<2xf32>
+// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[C]] {kind = #vector.kind<add>} : vector<2xf32>, f32
+// CHECK: %[[T7:.*]] = vector.extract %[[T3]][1] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[T8:.*]] = vector.extract %[[B]][1] : f32 from vector<2xf32>
+// CHECK: %[[T9:.*]] = vector.outerproduct %[[T7]], %[[T8]], %[[T6]] {kind = #vector.kind<add>} : vector<2xf32>, f32
+func.func @matvec_k_mk_m(%A: vector<2x2xf32>,
+ %x: vector<2xf32>,
+ %b: vector<2xf32>) -> vector<2xf32> {
+ %0 = vector.contract #matvec_trait_3 %x, %A, %b : vector<2xf32>, vector<2x2xf32> into vector<2xf32>
+ return %0 : vector<2xf32>
+}
+
+// ============================================================================
+// Matvec 4 (plain + masked + scalable)
+// ============================================================================
+// CHECK-LABEL: @masked_matvec_k_km_m_scalable_parallel_dim
+// CHECK-SAME: %[[MAT:.+]]: vector<2x[4]xf32>
+// CHECK-SAME: %[[VEC:.+]]: vector<2xf32>
+// CHECK-SAME: %[[INIT:.+]]: vector<[4]xf32>
+// CHECK-SAME: %[[MASK:.+]]: vector<[4]x2xi1>
+func.func @masked_matvec_k_km_m_scalable_parallel_dim(%arg0: vector<2x[4]xf32>,
+ %arg1: vector<2xf32>,
+ %arg2: vector<[4]xf32>,
+ %mask: vector<[4]x2xi1>) -> vector<[4]xf32> {
+ // CHECK: vector.transpose %[[MASK]]
+ // CHECK-NOT: vector.transpose %[[MAT]]
+ // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
+ %res = vector.mask %mask {
+ vector.contract #matvec_trait_4 %arg1, %arg0, %arg2
+ : vector<2xf32>, vector<2x[4]xf32>, vector<[4]xf32> into vector<[4]xf32>
+ } : vector<[4]x2xi1> -> vector<[4]xf32>
+ return %res : vector<[4]xf32>
+}
+
+// CHECK-LABEL: @masked_matvec_k_km_m
+// CHECK-SAME: %[[MAT:.+]]: vector<2x4xf32>
+// CHECK-SAME: %[[VEC:.+]]: vector<2xf32>
+// CHECK-SAME: %[[INIT:.+]]: vector<4xf32>
+// CHECK-SAME: %[[MASK:.+]]: vector<4x2xi1>
+func.func @masked_matvec_k_km_m(%arg0: vector<2x4xf32>,
+ %arg1: vector<2xf32>,
+ %arg2: vector<4xf32>,
+ %mask: vector<4x2xi1>) -> vector<4xf32> {
+ // CHECK: vector.transpose %[[MASK]]
+ // CHECK-NOT: vector.transpose %[[MAT]]
+ // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
+ %res = vector.mask %mask {
+ vector.contract #matvec_trait_4 %arg1, %arg0, %arg2
+ : vector<2xf32>, vector<2x4xf32>, vector<4xf32> into vector<4xf32>
+ } : vector<4x2xi1> -> vector<4xf32>
+ return %res : vector<4xf32>
+}
+
+// CHECK-LABEL: func @matvec_k_km_m
+// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
+// CHECK-SAME: %[[B:.*1]]: vector<2xf32>
+// CHECK-SAME: %[[C:.*2]]: vector<2xf32>
+// CHECK: %[[T3:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[T4:.*]] = vector.extract %[[B]][0] : f32 from vector<2xf32>
+// CHECK: %[[T5:.*]] = vector.outerproduct %[[T3]], %[[T4]], %[[C]] {kind = #vector.kind<add>} : vector<2xf32>, f32
+// CHECK: %[[T6:.*]] = vector.extract %[[A]][1] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[T7:.*]] = vector.extract %[[B]][1] : f32 from vector<2xf32>
+// CHECK: %[[T8:.*]] = vector.outerproduct %[[T6]], %[[T7]], %[[T5]] {kind = #vector.kind<add>} : vector<2xf32>, f32
+func.func @matvec_k_km_m(%A: vector<2x2xf32>,
+ %x: vector<2xf32>,
+ %b: vector<2xf32>) -> vector<2xf32> {
+ %0 = vector.contract #matvec_trait_4 %x, %A, %b : vector<2xf32>, vector<2x2xf32> into vector<2xf32>
+ return %0 : vector<2xf32>
+}
+
+// CHECK-LABEL: func @matvec_m_mk_k
+// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32>
+// CHECK-SAME: %[[B:.*1]]: vector<2xf32>
+// CHECK-SAME: %[[C:.*2]]: vector<2xf32>
+// CHECK: %[[T3:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[T4:.*]] = vector.extract %[[B]][0] : f32 from vector<2xf32>
+// CHECK: %[[T5:.*]] = vector.outerproduct %[[T3]], %[[T4]], %[[C]] {kind = #vector.kind<add>} : vector<2xf32>, f32
+// CHECK: %[[T6:.*]] = vector.extract %[[A]][1] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[T7:.*]] = vector.extract %[[B]][1] : f32 from vector<2xf32>
+// CHECK: %[[T8:.*]] = vector.outerproduct %[[T6]], %[[T7]], %[[T5]] {kind = #vector.kind<add>} : vector<2xf32>, f32
+func.func @matvec_m_mk_k(%A: vector<2x2xf32>,
+ %x: vector<2xf32>,
+ %b: vector<2xf32>) -> vector<2xf32> {
+ %0 = vector.contract #matvec_trait_4 %x, %A, %b : vector<2xf32>, vector<2x2xf32> into vector<2xf32>
+ return %0 : vector<2xf32>
+}
+
+// ============================================================================
+// Matvec 5 (masked + scalable)
+// ============================================================================
+#matvec_accesses_5 = [
+ affine_map<(k, m) -> (m, k)>,
+ affine_map<(k, m) -> (k)>,
+ affine_map<(k, m) -> (m)>
+]
+#matvec_trait_5 = {
+ indexing_maps = #matvec_accesses_5,
+ iterator_types = ["reduction", "parallel"]
+}
+
+// CHECK-LABEL: @masked_tmatvec_mk_k_m
+// CHECK-SAME: %[[MAT:.+]]: vector<4x2xf32>
+// CHECK-SAME: %[[VEC:.+]]: vector<2xf32>
+// CHECK-SAME: %[[INIT:.+]]: vector<4xf32>
+// CHECK-SAME: %[[MASK:.+]]: vector<2x4xi1>
+func.func @masked_tmatvec_mk_k_m(%arg0: vector<4x2xf32>, %arg1: vector<2xf32>, %arg2: vector<4xf32>, %mask: vector<2x4xi1>) -> vector<4xf32> {
+ // CHECK: vector.transpose %[[MAT]]
+ // CHECK-NOT: vector.transpose %[[MASK]]
+ // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
+ %res = vector.mask %mask {
+ vector.contract #matvec_trait_5 %arg0, %arg1, %arg2
+ : vector<4x2xf32>, vector<2xf32>, vector<4xf32> into vector<4xf32>
+ } : vector<2x4xi1> -> vector<4xf32>
+ return %res : vector<4xf32>
+}
+
+// CHECK-LABEL: @masked_tmatvec_mk_k_m_scalable_parallel_dim
+// CHECK-SAME: %[[MAT:.+]]: vector<[4]x2xf32>
+// CHECK-SAME: %[[VEC:.+]]: vector<2xf32>
+// CHECK-SAME: %[[INIT:.+]]: vector<[4]xf32>
+// CHECK-SAME: %[[MASK:.+]]: vector<2x[4]xi1>
+func.func @masked_tmatvec_mk_k_m_scalable_parallel_dim(%arg0: vector<[4]x2xf32>, %arg1: vector<2xf32>, %arg2: vector<[4]xf32>, %mask: vector<2x[4]xi1>) -> vector<[4]xf32> {
+ // CHECK: vector.transpose %[[MAT]]
+ // CHECK-NOT: vector.transpose %[[MASK]]
+ // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
+ %res = vector.mask %mask {
+ vector.contract #matvec_trait_5 %arg0, %arg1, %arg2
+ : vector<[4]x2xf32>, vector<2xf32>, vector<[4]xf32> into vector<[4]xf32>
+ } : vector<2x[4]xi1> -> vector<[4]xf32>
+ return %res : vector<[4]xf32>
+}
+
+// ============================================================================
+// Matvec 6 (masked + scalable)
+// ============================================================================
+#matvec_accesses_6 = [
+ affine_map<(k, m) -> (k, m)>,
+ affine_map<(k, m) -> (k)>,
+ affine_map<(k, m) -> (m)>
+]
+#matvec_trait_6 = {
+ indexing_maps = #matvec_accesses_6,
+ iterator_types = ["reduction", "parallel"]
+}
+
+// CHECK-LABEL: @masked_tmatvec_km_k_m
+// CHECK-SAME: %[[MAT:.+]]: vector<2x4xf32>
+// CHECK-SAME: %[[VEC:.+]]: vector<2xf32>
+// CHECK-SAME: %[[INIT:.+]]: vector<4xf32>
+// CHECK-SAME: %[[MASK:.+]]: vector<2x4xi1>
+func.func @masked_tmatvec_km_k_m(%arg0: vector<2x4xf32>, %arg1: vector<2xf32>, %arg2: vector<4xf32>, %mask: vector<2x4xi1>) -> vector<4xf32> {
+ // CHECK-NOT: vector.transpose %[[MAT]]
+ // CHECK-NOT: vector.transpose %[[MASK]]
+ // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
+ %res = vector.mask %mask {
+ vector.contract #matvec_trait_6 %arg0, %arg1, %arg2
+ : vector<2x4xf32>, vector<2xf32>, vector<4xf32> into vector<4xf32>
+ } : vector<2x4xi1> -> vector<4xf32>
+ return %res : vector<4xf32>
+}
+
+// CHECK-LABEL: @masked_tmatvec_km_k_m_scalable_parallel_dim
+// CHECK-SAME: %[[MAT:.+]]: vector<2x[4]xf32>
+// CHECK-SAME: %[[VEC:.+]]: vector<2xf32>
+// CHECK-SAME: %[[INIT:.+]]: vector<[4]xf32>
+// CHECK-SAME: %[[MASK:.+]]: vector<2x[4]xi1>
+func.func @masked_tmatvec_km_k_m_scalable_parallel_dim(%arg0: vector<2x[4]xf32>, %arg1: vector<2xf32>, %arg2: vector<[4]xf32>, %mask: vector<2x[4]xi1>) -> vector<[4]xf32> {
+ // CHECK-NOT: vector.transpose %[[MAT]]
+ // CHECK-NOT: vector.transpose %[[MASK]]
+ // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
+ %res = vector.mask %mask {
+ vector.contract #matvec_trait_6 %arg0, %arg1, %arg2
+ : vector<2x[4]xf32>, vector<2xf32>, vector<[4]xf32> into vector<[4]xf32>
+ } : vector<2x[4]xi1> -> vector<[4]xf32>
+ return %res : vector<[4]xf32>
+}
+
+// ============================================================================
+// Matvec 7 (masked + scalable)
+// ============================================================================
+#matvec_accesses_7 = [
+ affine_map<(k, m) -> (k)>,
+ affine_map<(k, m) -> (m, k)>,
+ affine_map<(k, m) -> (m)>
+]
+#matvec_trait_7 = {
+ indexing_maps = #matvec_accesses_7,
+ iterator_types = ["reduction", "parallel"]
+}
+
+// CHECK-LABEL: @masked_tmatvec_k_mk_m
+// CHECK-SAME: %[[MAT:.+]]: vector<4x2xf32>
+// CHECK-SAME: %[[VEC:.+]]: vector<2xf32>
+// CHECK-SAME: %[[INIT:.+]]: vector<4xf32>
+// CHECK-SAME: %[[MASK:.+]]: vector<2x4xi1>
+func.func @masked_tmatvec_k_mk_m(%arg0: vector<4x2xf32>, %arg1: vector<2xf32>, %arg2: vector<4xf32>, %mask: vector<2x4xi1>) -> vector<4xf32> {
+ // CHECK: vector.transpose %[[MAT]]
+ // CHECK-NOT: vector.transpose %[[MASK]]
+ // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
+ %res = vector.mask %mask {
+ vector.contract #matvec_trait_7 %arg1, %arg0, %arg2
+ : vector<2xf32>, vector<4x2xf32>, vector<4xf32> into vector<4xf32>
+ } : vector<2x4xi1> -> vector<4xf32>
+ return %res : vector<4xf32>
+}
+
+// CHECK-LABEL: @masked_tmatvec_k_mk_m_scalable_parallel_dim
+// CHECK-SAME: %[[MAT:.+]]: vector<[4]x2xf32>
+// CHECK-SAME: %[[VEC:.+]]: vector<2xf32>
+// CHECK-SAME: %[[INIT:.+]]: vector<[4]xf32>
+// CHECK-SAME: %[[MASK:.+]]: vector<2x[4]xi1>
+func.func @masked_tmatvec_k_mk_m_scalable_parallel_dim(%arg0: vector<[4]x2xf32>, %arg1: vector<2xf32>, %arg2: vector<[4]xf32>, %mask: vector<2x[4]xi1>) -> vector<[4]xf32> {
+ // CHECK: vector.transpose %[[MAT]]
+ // CHECK-NOT: vector.transpose %[[MASK]]
+ // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
+ %res = vector.mask %mask {
+ vector.contract #matvec_trait_7 %arg1, %arg0, %arg2
+ : vector<2xf32>, vector<[4]x2xf32>, vector<[4]xf32> into vector<[4]xf32>
+ } : vector<2x[4]xi1> -> vector<[4]xf32>
+ return %res : vector<[4]xf32>
+}
+
+// ============================================================================
+// Matvec 8 (masked + scalable)
+// ============================================================================
+#matvec_accesses_8 = [
+ affine_map<(k, m) -> (k)>,
+ affine_map<(k, m) -> (k, m)>,
+ affine_map<(k, m) -> (m)>
+]
+#matvec_trait_8 = {
+ indexing_maps = #matvec_accesses_8,
+ iterator_types = ["reduction", "parallel"]
+}
+
+// CHECK-LABEL: @masked_tmatvec_k_km_m
+// CHECK-SAME: %[[MAT:.+]]: vector<2x4xf32>
+// CHECK-SAME: %[[VEC:.+]]: vector<2xf32>
+// CHECK-SAME: %[[INIT:.+]]: vector<4xf32>
+// CHECK-SAME: %[[MASK:.+]]: vector<2x4xi1>
+func.func @masked_tmatvec_k_km_m(%arg0: vector<2x4xf32>, %arg1: vector<2xf32>, %arg2: vector<4xf32>, %mask: vector<2x4xi1>) -> vector<4xf32> {
+ // CHECK-NOT: vector.transpose %[[MAT]]
+ // CHECK-NOT: vector.transpose %[[MASK]]
+ // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
+ %res = vector.mask %mask {
+ vector.contract #matvec_trait_8 %arg1, %arg0, %arg2
+ : vector<2xf32>, vector<2x4xf32>, vector<4xf32> into vector<4xf32>
+ } : vector<2x4xi1> -> vector<4xf32>
+ return %res : vector<4xf32>
+}
+
+// CHECK-LABEL: @masked_tmatvec_k_km_m_scalable_parallel_dim
+// CHECK-SAME: %[[MAT:.+]]: vector<2x[4]xf32>
+// CHECK-SAME: %[[VEC:.+]]: vector<2xf32>
+// CHECK-SAME: %[[INIT:.+]]: vector<[4]xf32>
+// CHECK-SAME: %[[MASK:.+]]: vector<2x[4]xi1>
+func.func @masked_tmatvec_k_km_m_scalable_parallel_dim(%arg0: vector<2x[4]xf32>, %arg1: vector<2xf32>, %arg2: vector<[4]xf32>, %mask: vector<2x[4]xi1>) -> vector<[4]xf32> {
+ // CHECK-NOT: vector.transpose %[[MAT]]
+ // CHECK-NOT: vector.transpose %[[MASK]]
+ // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
+ %res = vector.mask %mask {
+ vector.contract #matvec_trait_8 %arg1, %arg0, %arg2
+ : vector<2xf32>, vector<2x[4]xf32>, vector<[4]xf32> into vector<[4]xf32>
+ } : vector<2x[4]xi1> -> vector<[4]xf32>
+ return %res : vector<[4]xf32>
+}
+
+// ============================================================================
+// TD sequence
+// ============================================================================
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%func_op: !transform.op<"func.func"> {transform.readonly}) {
+ transform.apply_patterns to %func_op {
+ transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct"
+ } : !transform.op<"func.func">
+ transform.yield
+ }
+}
diff --git a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir
index 2c228a0..7588b73 100644
--- a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir
@@ -11,8 +11,6 @@
//
// (*) The conversion tested in this file unrolls along the reduction
// dimension, which is not supported for scalable vectors.
-//
-// TODO: Matvec without a mask
// ============================================================================
// Matmul 0 (plain + masked + mixed types)
@@ -388,400 +386,6 @@ func.func @matmul_4_scalable(%arg0: vector<[2]x1xf32>,
}
// ============================================================================
-// Matvec 1 (masked)
-// ============================================================================
-#matvec_accesses_1 = [
- affine_map<(m, k) -> (m, k)>,
- affine_map<(m, k) -> (k)>,
- affine_map<(m, k) -> (m)>
-]
-#matvec_trait_1 = {
- indexing_maps = #matvec_accesses_1,
- iterator_types = ["parallel", "reduction"]
-}
-
-// CHECK-LABEL: func.func @masked_matvec_mk_k_m(
-// CHECK-SAME: %{{.*}}: vector<2x3xf32>,
-// CHECK-SAME: %{{.*}}: vector<3xf32>,
-// CHECK-SAME: %{{.*}}: vector<2xf32>,
-// CHECK-SAME: %[[IN_MASK:.*]]: vector<2x3xi1>) -> vector<2xf32>
-// CHECK: %[[T_MASK:.*]] = vector.transpose %[[IN_MASK]], [1, 0] : vector<2x3xi1> to vector<3x2xi1>
-// CHECK: %[[MASK0:.*]] = vector.extract %[[T_MASK]][0] : vector<2xi1> from vector<3x2xi1>
-// CHECK: vector.mask %[[MASK0]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<2xf32>, f32 } : vector<2xi1> -> vector<2xf32>
-
-// CHECK: %[[MASK1:.*]] = vector.extract %[[T_MASK]][1] : vector<2xi1> from vector<3x2xi1>
-// CHECK: vector.mask %[[MASK1]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<2xf32>, f32 } : vector<2xi1> -> vector<2xf32>
-
-// CHECK: %[[MASK2:.*]] = vector.extract %[[T_MASK]][2] : vector<2xi1> from vector<3x2xi1>
-// CHECK: vector.mask %[[MASK2]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<2xf32>, f32 } : vector<2xi1> -> vector<2xf32>
-
-func.func @masked_matvec_mk_k_m(%arg0: vector<2x3xf32>,
- %arg1: vector<3xf32>,
- %arg2: vector<2xf32>,
- %m: vector<2x3xi1>) -> vector<2xf32> {
- %0 = vector.mask %m { vector.contract #matvec_trait_1 %arg0, %arg1, %arg2
- : vector<2x3xf32>, vector<3xf32> into vector<2xf32> } : vector<2x3xi1> -> vector<2xf32>
- return %0 : vector<2xf32>
-}
-
-// CHECK-LABEL: func.func @masked_matvec_mk_k_m_scalable_parallel_dim(
-// CHECK-SAME: %{{.*}}: vector<[2]x3xf32>,
-// CHECK-SAME: %{{.*}}: vector<3xf32>,
-// CHECK-SAME: %{{.*}}: vector<[2]xf32>,
-// CHECK-SAME: %[[IN_MASK:.*]]: vector<[2]x3xi1>) -> vector<[2]xf32>
-// CHECK: %[[T_MASK:.*]] = vector.transpose %[[IN_MASK]], [1, 0] : vector<[2]x3xi1> to vector<3x[2]xi1>
-// CHECK: %[[MASK0:.*]] = vector.extract %[[T_MASK]][0] : vector<[2]xi1> from vector<3x[2]xi1>
-// CHECK: vector.mask %[[MASK0]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<[2]xf32>, f32 } : vector<[2]xi1> -> vector<[2]xf32>
-
-// CHECK: %[[MASK1:.*]] = vector.extract %[[T_MASK]][1] : vector<[2]xi1> from vector<3x[2]xi1>
-// CHECK: vector.mask %[[MASK1]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<[2]xf32>, f32 } : vector<[2]xi1> -> vector<[2]xf32>
-
-// CHECK: %[[MASK2:.*]] = vector.extract %[[T_MASK]][2] : vector<[2]xi1> from vector<3x[2]xi1>
-// CHECK: vector.mask %[[MASK2]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<[2]xf32>, f32 } : vector<[2]xi1> -> vector<[2]xf32>
-func.func @masked_matvec_mk_k_m_scalable_parallel_dim(%arg0: vector<[2]x3xf32>,
- %arg1: vector<3xf32>,
- %arg2: vector<[2]xf32>,
- %m: vector<[2]x3xi1>) -> vector<[2]xf32> {
- %0 = vector.mask %m { vector.contract #matvec_trait_1 %arg0, %arg1, %arg2
- : vector<[2]x3xf32>, vector<3xf32> into vector<[2]xf32> } : vector<[2]x3xi1> -> vector<[2]xf32>
- return %0 : vector<[2]xf32>
-}
-
-// ============================================================================
-// Matvec 2 (masked)
-// ============================================================================
-#matvec_accesses_2 = [
- affine_map<(m, k) -> (k, m)>,
- affine_map<(m, k) -> (k)>,
- affine_map<(m, k) -> (m)>
-]
-#matvec_trait_2 = {
- indexing_maps = #matvec_accesses_2,
- iterator_types = ["parallel", "reduction"]
-}
-
-// CHECK-LABEL: @masked_matvec_km_k_m
-// CHECK-SAME: %[[MAT:.+]]: vector<2x4xf32>
-// CHECK-SAME: %[[VEC:.+]]: vector<2xf32>
-// CHECK-SAME: %[[INIT:.+]]: vector<4xf32>
-// CHECK-SAME: %[[MASK:.+]]: vector<4x2xi1>
-func.func @masked_matvec_km_k_m(%arg0: vector<2x4xf32>,
- %arg1: vector<2xf32>,
- %arg2: vector<4xf32>,
- %mask: vector<4x2xi1>) -> vector<4xf32> {
- // CHECK: vector.transpose %[[MASK]]
- // CHECK-NOT: vector.transpose %[[MAT]]
- // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
- %res = vector.mask %mask {
- vector.contract #matvec_trait_2 %arg0, %arg1, %arg2
- : vector<2x4xf32>, vector<2xf32>, vector<4xf32> into vector<4xf32>
- } : vector<4x2xi1> -> vector<4xf32>
- return %res : vector<4xf32>
-}
-
-// CHECK-LABEL: @masked_matvec_km_k_m_scalable_parallel_dim
-// CHECK-SAME: %[[MAT:.+]]: vector<2x[4]xf32>
-// CHECK-SAME: %[[VEC:.+]]: vector<2xf32>
-// CHECK-SAME: %[[INIT:.+]]: vector<[4]xf32>
-// CHECK-SAME: %[[MASK:.+]]: vector<[4]x2xi1>
-func.func @masked_matvec_km_k_m_scalable_parallel_dim(%arg0: vector<2x[4]xf32>,
- %arg1: vector<2xf32>,
- %arg2: vector<[4]xf32>,
- %mask: vector<[4]x2xi1>) -> vector<[4]xf32> {
- // CHECK: vector.transpose %[[MASK]]
- // CHECK-NOT: vector.transpose %[[MAT]]
- // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
- %res = vector.mask %mask {
- vector.contract #matvec_trait_2 %arg0, %arg1, %arg2
- : vector<2x[4]xf32>, vector<2xf32>, vector<[4]xf32> into vector<[4]xf32>
- } : vector<[4]x2xi1> -> vector<[4]xf32>
- return %res : vector<[4]xf32>
-}
-
-// ============================================================================
-// Matvec 3 (masked)
-// ============================================================================
-#matvec_accesses_3 = [
- affine_map<(m, k) -> (k)>,
- affine_map<(m, k) -> (m, k)>,
- affine_map<(m, k) -> (m)>
-]
-#matvec_trait_3 = {
- indexing_maps = #matvec_accesses_3,
- iterator_types = ["parallel", "reduction"]
-}
-
-// CHECK-LABEL: @masked_matvec_k_mk_m
-// CHECK-SAME: %[[MAT:.+]]: vector<4x2xf32>
-// CHECK-SAME: %[[VEC:.+]]: vector<2xf32>
-// CHECK-SAME: %[[INIT:.+]]: vector<4xf32>
-// CHECK-SAME: %[[MASK:.+]]: vector<4x2xi1>
-func.func @masked_matvec_k_mk_m(%arg0: vector<4x2xf32>,
- %arg1: vector<2xf32>,
- %arg2: vector<4xf32>,
- %mask: vector<4x2xi1>) -> vector<4xf32> {
- // CHECK: vector.transpose %[[MASK]]
- // CHECK: vector.transpose %[[MAT]]
- // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
- %res = vector.mask %mask {
- vector.contract #matvec_trait_3 %arg1, %arg0, %arg2
- : vector<2xf32>, vector<4x2xf32>, vector<4xf32> into vector<4xf32>
- } : vector<4x2xi1> -> vector<4xf32>
- return %res : vector<4xf32>
-}
-
-// CHECK-LABEL: @masked_matvec_k_mk_m_scalable_parallel_dim
-// CHECK-SAME: %[[MAT:.+]]: vector<[4]x2xf32>
-// CHECK-SAME: %[[VEC:.+]]: vector<2xf32>
-// CHECK-SAME: %[[INIT:.+]]: vector<[4]xf32>
-// CHECK-SAME: %[[MASK:.+]]: vector<[4]x2xi1>
-func.func @masked_matvec_k_mk_m_scalable_parallel_dim(%arg0: vector<[4]x2xf32>,
- %arg1: vector<2xf32>,
- %arg2: vector<[4]xf32>,
- %mask: vector<[4]x2xi1>) -> vector<[4]xf32> {
- // CHECK: vector.transpose %[[MASK]]
- // CHECK: vector.transpose %[[MAT]]
- // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
- %res = vector.mask %mask {
- vector.contract #matvec_trait_3 %arg1, %arg0, %arg2
- : vector<2xf32>, vector<[4]x2xf32>, vector<[4]xf32> into vector<[4]xf32>
- } : vector<[4]x2xi1> -> vector<[4]xf32>
- return %res : vector<[4]xf32>
-}
-
-// ============================================================================
-// Matvec 4 (masked)
-// ============================================================================
-#matvec_accesses_4 = [
- affine_map<(m, k) -> (k)>,
- affine_map<(m, k) -> (k, m)>,
- affine_map<(m, k) -> (m)>
-]
-#matvec_trait_4 = {
- indexing_maps = #matvec_accesses_4,
- iterator_types = ["parallel", "reduction"]
-}
-
-// CHECK-LABEL: @masked_matvec_k_km_m_scalable_parallel_dim
-// CHECK-SAME: %[[MAT:.+]]: vector<2x[4]xf32>
-// CHECK-SAME: %[[VEC:.+]]: vector<2xf32>
-// CHECK-SAME: %[[INIT:.+]]: vector<[4]xf32>
-// CHECK-SAME: %[[MASK:.+]]: vector<[4]x2xi1>
-func.func @masked_matvec_k_km_m_scalable_parallel_dim(%arg0: vector<2x[4]xf32>,
- %arg1: vector<2xf32>,
- %arg2: vector<[4]xf32>,
- %mask: vector<[4]x2xi1>) -> vector<[4]xf32> {
- // CHECK: vector.transpose %[[MASK]]
- // CHECK-NOT: vector.transpose %[[MAT]]
- // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
- %res = vector.mask %mask {
- vector.contract #matvec_trait_4 %arg1, %arg0, %arg2
- : vector<2xf32>, vector<2x[4]xf32>, vector<[4]xf32> into vector<[4]xf32>
- } : vector<[4]x2xi1> -> vector<[4]xf32>
- return %res : vector<[4]xf32>
-}
-
-// CHECK-LABEL: @masked_matvec_k_km_m
-// CHECK-SAME: %[[MAT:.+]]: vector<2x4xf32>
-// CHECK-SAME: %[[VEC:.+]]: vector<2xf32>
-// CHECK-SAME: %[[INIT:.+]]: vector<4xf32>
-// CHECK-SAME: %[[MASK:.+]]: vector<4x2xi1>
-func.func @masked_matvec_k_km_m(%arg0: vector<2x4xf32>,
- %arg1: vector<2xf32>,
- %arg2: vector<4xf32>,
- %mask: vector<4x2xi1>) -> vector<4xf32> {
- // CHECK: vector.transpose %[[MASK]]
- // CHECK-NOT: vector.transpose %[[MAT]]
- // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
- %res = vector.mask %mask {
- vector.contract #matvec_trait_4 %arg1, %arg0, %arg2
- : vector<2xf32>, vector<2x4xf32>, vector<4xf32> into vector<4xf32>
- } : vector<4x2xi1> -> vector<4xf32>
- return %res : vector<4xf32>
-}
-
-// ============================================================================
-// Matvec 5 (masked)
-// ============================================================================
-#matvec_accesses_5 = [
- affine_map<(k, m) -> (m, k)>,
- affine_map<(k, m) -> (k)>,
- affine_map<(k, m) -> (m)>
-]
-#matvec_trait_5 = {
- indexing_maps = #matvec_accesses_5,
- iterator_types = ["reduction", "parallel"]
-}
-
-// CHECK-LABEL: @masked_tmatvec_mk_k_m
-// CHECK-SAME: %[[MAT:.+]]: vector<4x2xf32>
-// CHECK-SAME: %[[VEC:.+]]: vector<2xf32>
-// CHECK-SAME: %[[INIT:.+]]: vector<4xf32>
-// CHECK-SAME: %[[MASK:.+]]: vector<2x4xi1>
-func.func @masked_tmatvec_mk_k_m(%arg0: vector<4x2xf32>, %arg1: vector<2xf32>, %arg2: vector<4xf32>, %mask: vector<2x4xi1>) -> vector<4xf32> {
- // CHECK: vector.transpose %[[MAT]]
- // CHECK-NOT: vector.transpose %[[MASK]]
- // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
- %res = vector.mask %mask {
- vector.contract #matvec_trait_5 %arg0, %arg1, %arg2
- : vector<4x2xf32>, vector<2xf32>, vector<4xf32> into vector<4xf32>
- } : vector<2x4xi1> -> vector<4xf32>
- return %res : vector<4xf32>
-}
-
-// CHECK-LABEL: @masked_tmatvec_mk_k_m_scalable_parallel_dim
-// CHECK-SAME: %[[MAT:.+]]: vector<[4]x2xf32>
-// CHECK-SAME: %[[VEC:.+]]: vector<2xf32>
-// CHECK-SAME: %[[INIT:.+]]: vector<[4]xf32>
-// CHECK-SAME: %[[MASK:.+]]: vector<2x[4]xi1>
-func.func @masked_tmatvec_mk_k_m_scalable_parallel_dim(%arg0: vector<[4]x2xf32>, %arg1: vector<2xf32>, %arg2: vector<[4]xf32>, %mask: vector<2x[4]xi1>) -> vector<[4]xf32> {
- // CHECK: vector.transpose %[[MAT]]
- // CHECK-NOT: vector.transpose %[[MASK]]
- // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
- %res = vector.mask %mask {
- vector.contract #matvec_trait_5 %arg0, %arg1, %arg2
- : vector<[4]x2xf32>, vector<2xf32>, vector<[4]xf32> into vector<[4]xf32>
- } : vector<2x[4]xi1> -> vector<[4]xf32>
- return %res : vector<[4]xf32>
-}
-
-// ============================================================================
-// Matvec 6 (masked)
-// ============================================================================
-#matvec_accesses_6 = [
- affine_map<(k, m) -> (k, m)>,
- affine_map<(k, m) -> (k)>,
- affine_map<(k, m) -> (m)>
-]
-#matvec_trait_6 = {
- indexing_maps = #matvec_accesses_6,
- iterator_types = ["reduction", "parallel"]
-}
-
-// CHECK-LABEL: @masked_tmatvec_km_k_m
-// CHECK-SAME: %[[MAT:.+]]: vector<2x4xf32>
-// CHECK-SAME: %[[VEC:.+]]: vector<2xf32>
-// CHECK-SAME: %[[INIT:.+]]: vector<4xf32>
-// CHECK-SAME: %[[MASK:.+]]: vector<2x4xi1>
-func.func @masked_tmatvec_km_k_m(%arg0: vector<2x4xf32>, %arg1: vector<2xf32>, %arg2: vector<4xf32>, %mask: vector<2x4xi1>) -> vector<4xf32> {
- // CHECK-NOT: vector.transpose %[[MAT]]
- // CHECK-NOT: vector.transpose %[[MASK]]
- // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
- %res = vector.mask %mask {
- vector.contract #matvec_trait_6 %arg0, %arg1, %arg2
- : vector<2x4xf32>, vector<2xf32>, vector<4xf32> into vector<4xf32>
- } : vector<2x4xi1> -> vector<4xf32>
- return %res : vector<4xf32>
-}
-
-// CHECK-LABEL: @masked_tmatvec_km_k_m_scalable_parallel_dim
-// CHECK-SAME: %[[MAT:.+]]: vector<2x[4]xf32>
-// CHECK-SAME: %[[VEC:.+]]: vector<2xf32>
-// CHECK-SAME: %[[INIT:.+]]: vector<[4]xf32>
-// CHECK-SAME: %[[MASK:.+]]: vector<2x[4]xi1>
-func.func @masked_tmatvec_km_k_m_scalable_parallel_dim(%arg0: vector<2x[4]xf32>, %arg1: vector<2xf32>, %arg2: vector<[4]xf32>, %mask: vector<2x[4]xi1>) -> vector<[4]xf32> {
- // CHECK-NOT: vector.transpose %[[MAT]]
- // CHECK-NOT: vector.transpose %[[MASK]]
- // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
- %res = vector.mask %mask {
- vector.contract #matvec_trait_6 %arg0, %arg1, %arg2
- : vector<2x[4]xf32>, vector<2xf32>, vector<[4]xf32> into vector<[4]xf32>
- } : vector<2x[4]xi1> -> vector<[4]xf32>
- return %res : vector<[4]xf32>
-}
-
-// ============================================================================
-// Matvec 7 (masked)
-// ============================================================================
-#matvec_accesses_7 = [
- affine_map<(k, m) -> (k)>,
- affine_map<(k, m) -> (m, k)>,
- affine_map<(k, m) -> (m)>
-]
-#matvec_trait_7 = {
- indexing_maps = #matvec_accesses_7,
- iterator_types = ["reduction", "parallel"]
-}
-
-// CHECK-LABEL: @masked_tmatvec_k_mk_m
-// CHECK-SAME: %[[MAT:.+]]: vector<4x2xf32>
-// CHECK-SAME: %[[VEC:.+]]: vector<2xf32>
-// CHECK-SAME: %[[INIT:.+]]: vector<4xf32>
-// CHECK-SAME: %[[MASK:.+]]: vector<2x4xi1>
-func.func @masked_tmatvec_k_mk_m(%arg0: vector<4x2xf32>, %arg1: vector<2xf32>, %arg2: vector<4xf32>, %mask: vector<2x4xi1>) -> vector<4xf32> {
- // CHECK: vector.transpose %[[MAT]]
- // CHECK-NOT: vector.transpose %[[MASK]]
- // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
- %res = vector.mask %mask {
- vector.contract #matvec_trait_7 %arg1, %arg0, %arg2
- : vector<2xf32>, vector<4x2xf32>, vector<4xf32> into vector<4xf32>
- } : vector<2x4xi1> -> vector<4xf32>
- return %res : vector<4xf32>
-}
-
-// CHECK-LABEL: @masked_tmatvec_k_mk_m_scalable_parallel_dim
-// CHECK-SAME: %[[MAT:.+]]: vector<[4]x2xf32>
-// CHECK-SAME: %[[VEC:.+]]: vector<2xf32>
-// CHECK-SAME: %[[INIT:.+]]: vector<[4]xf32>
-// CHECK-SAME: %[[MASK:.+]]: vector<2x[4]xi1>
-func.func @masked_tmatvec_k_mk_m_scalable_parallel_dim(%arg0: vector<[4]x2xf32>, %arg1: vector<2xf32>, %arg2: vector<[4]xf32>, %mask: vector<2x[4]xi1>) -> vector<[4]xf32> {
- // CHECK: vector.transpose %[[MAT]]
- // CHECK-NOT: vector.transpose %[[MASK]]
- // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
- %res = vector.mask %mask {
- vector.contract #matvec_trait_7 %arg1, %arg0, %arg2
- : vector<2xf32>, vector<[4]x2xf32>, vector<[4]xf32> into vector<[4]xf32>
- } : vector<2x[4]xi1> -> vector<[4]xf32>
- return %res : vector<[4]xf32>
-}
-
-// ============================================================================
-// Matvec 8 (masked)
-// ============================================================================
-#matvec_accesses_8 = [
- affine_map<(k, m) -> (k)>,
- affine_map<(k, m) -> (k, m)>,
- affine_map<(k, m) -> (m)>
-]
-#matvec_trait_8 = {
- indexing_maps = #matvec_accesses_8,
- iterator_types = ["reduction", "parallel"]
-}
-
-// CHECK-LABEL: @masked_tmatvec_k_km_m
-// CHECK-SAME: %[[MAT:.+]]: vector<2x4xf32>
-// CHECK-SAME: %[[VEC:.+]]: vector<2xf32>
-// CHECK-SAME: %[[INIT:.+]]: vector<4xf32>
-// CHECK-SAME: %[[MASK:.+]]: vector<2x4xi1>
-func.func @masked_tmatvec_k_km_m(%arg0: vector<2x4xf32>, %arg1: vector<2xf32>, %arg2: vector<4xf32>, %mask: vector<2x4xi1>) -> vector<4xf32> {
- // CHECK-NOT: vector.transpose %[[MAT]]
- // CHECK-NOT: vector.transpose %[[MASK]]
- // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
- %res = vector.mask %mask {
- vector.contract #matvec_trait_8 %arg1, %arg0, %arg2
- : vector<2xf32>, vector<2x4xf32>, vector<4xf32> into vector<4xf32>
- } : vector<2x4xi1> -> vector<4xf32>
- return %res : vector<4xf32>
-}
-
-// CHECK-LABEL: @masked_tmatvec_k_km_m_scalable_parallel_dim
-// CHECK-SAME: %[[MAT:.+]]: vector<2x[4]xf32>
-// CHECK-SAME: %[[VEC:.+]]: vector<2xf32>
-// CHECK-SAME: %[[INIT:.+]]: vector<[4]xf32>
-// CHECK-SAME: %[[MASK:.+]]: vector<2x[4]xi1>
-func.func @masked_tmatvec_k_km_m_scalable_parallel_dim(%arg0: vector<2x[4]xf32>, %arg1: vector<2xf32>, %arg2: vector<[4]xf32>, %mask: vector<2x[4]xi1>) -> vector<[4]xf32> {
- // CHECK-NOT: vector.transpose %[[MAT]]
- // CHECK-NOT: vector.transpose %[[MASK]]
- // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 }
- %res = vector.mask %mask {
- vector.contract #matvec_trait_8 %arg1, %arg0, %arg2
- : vector<2xf32>, vector<2x[4]xf32>, vector<[4]xf32> into vector<[4]xf32>
- } : vector<2x[4]xi1> -> vector<[4]xf32>
- return %res : vector<[4]xf32>
-}
-
-
-// ============================================================================
// TD sequence
// ============================================================================
module attributes {transform.with_named_sequence} {