diff options
author | Andrzej WarzyĆski <andrzej.warzynski@arm.com> | 2023-11-29 14:24:42 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-11-29 14:24:42 +0000 |
commit | 9619a2420eac885060bae6c45e85f85abfc7d6a9 (patch) | |
tree | da3ee8ba546d383d521dd386dfdaa91b502d508c /mlir/test | |
parent | 23b82bfdc2b6771231cc6a51e6aad9b48318cc1f (diff) | |
download | llvm-9619a2420eac885060bae6c45e85f85abfc7d6a9.zip llvm-9619a2420eac885060bae6c45e85f85abfc7d6a9.tar.gz llvm-9619a2420eac885060bae6c45e85f85abfc7d6a9.tar.bz2 |
[MLIR][Vector] Refactor tests for contract -> OP transforms (3/N) (#73447)
This patch refactors tests for:
vector.contract -> vector.outerproduct
for matvec operations (b += Ax). Summary of changes:
* names of LIT variables are unified,
* "plain" tests (i.e. without masking and with fixed-width vectors)
are moved to the top of their respective sections,
* missing "plain" cases are added.
This is a part of a larger effort to add cases with scalable vectors to
tests for the Vector dialect. I am refactoring these tests so that it's
easier to identify what cases are tested and where to add tests for
scalable vectors.
Implements #72834.
Diffstat (limited to 'mlir/test')
-rw-r--r-- | mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matvec-transforms.mlir | 471 |
1 files changed, 279 insertions, 192 deletions
diff --git a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matvec-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matvec-transforms.mlir index 7456e12..e84a43f 100644 --- a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matvec-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matvec-transforms.mlir @@ -1,5 +1,17 @@ // RUN: mlir-opt %s --transform-interpreter --split-input-file | FileCheck %s +/// Tests for `vector.contract` -> `vector.outerproduct` transformations for +/// Matvec operations: +/// b += A * x. +/// (b and x are 1-d vectors, A is a 2-d matrix). ATM three different variants +/// are tested: +/// * plain (no mask, fixed-wdith vectors), +/// * masked (fixed-width vectors, +/// * scalable (mask + scalable vectors). +/// +/// TODO: These tests were extracted from 2 different files. If you find the +/// formatting inconsistent, please update accordingly. + #matvec_accesses_1 = [ affine_map<(m, k) -> (m, k)>, affine_map<(m, k) -> (k)>, @@ -46,6 +58,36 @@ iterator_types = ["parallel", "reduction"] } +#matvec_accesses_5 = [ + affine_map<(k, m) -> (m, k)>, + affine_map<(k, m) -> (k)>, + affine_map<(k, m) -> (m)> +] +#matvec_trait_5 = { + indexing_maps = #matvec_accesses_5, + iterator_types = ["reduction", "parallel"] +} + +#matvec_accesses_6 = [ + affine_map<(k, m) -> (k, m)>, + affine_map<(k, m) -> (k)>, + affine_map<(k, m) -> (m)> +] +#matvec_trait_6 = { + indexing_maps = #matvec_accesses_6, + iterator_types = ["reduction", "parallel"] +} + +#matvec_accesses_7 = [ + affine_map<(k, m) -> (k)>, + affine_map<(k, m) -> (m, k)>, + affine_map<(k, m) -> (m)> +] +#matvec_trait_7 = { + indexing_maps = #matvec_accesses_7, + iterator_types = ["reduction", "parallel"] +} + #matvec_accesses_8 = [ affine_map<(k, m) -> (k)>, affine_map<(k, m) -> (k, m)>, @@ -59,6 +101,24 @@ // ============================================================================ // Matvec 1 (plain + masked + scalable) // ============================================================================ +// CHECK-LABEL: func @matvec_mk_k_m +// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32> +// CHECK-SAME: %[[X:.*1]]: vector<2xf32> +// CHECK-SAME: %[[B:.*2]]: vector<2xf32> +// CHECK: %[[T3:.*]] = vector.transpose %[[A]], [1, 0] : vector<2x2xf32> to vector<2x2xf32> +// CHECK: %[[T4:.*]] = vector.extract %[[T3]][0] : vector<2xf32> from vector<2x2xf32> +// CHECK: %[[T5:.*]] = vector.extract %[[X]][0] : f32 from vector<2xf32> +// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[B]] {kind = #vector.kind<add>} : vector<2xf32>, f32 +// CHECK: %[[T7:.*]] = vector.extract %[[T3]][1] : vector<2xf32> from vector<2x2xf32> +// CHECK: %[[T8:.*]] = vector.extract %[[X]][1] : f32 from vector<2xf32> +// CHECK: %[[T9:.*]] = vector.outerproduct %[[T7]], %[[T8]], %[[T6]] {kind = #vector.kind<add>} : vector<2xf32>, f32 +func.func @matvec_mk_k_m(%A: vector<2x2xf32>, + %x: vector<2xf32>, + %b: vector<2xf32>) -> vector<2xf32> { + %0 = vector.contract #matvec_trait_1 %A, %x, %b : vector<2x2xf32>, vector<2xf32> into vector<2xf32> + return %0 : vector<2xf32> +} + // CHECK-LABEL: func.func @masked_matvec_mk_k_m( // CHECK-SAME: %{{.*}}: vector<2x3xf32>, // CHECK-SAME: %{{.*}}: vector<3xf32>, @@ -73,12 +133,11 @@ // CHECK: %[[MASK2:.*]] = vector.extract %[[T_MASK]][2] : vector<2xi1> from vector<3x2xi1> // CHECK: vector.mask %[[MASK2]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<2xf32>, f32 } : vector<2xi1> -> vector<2xf32> - -func.func @masked_matvec_mk_k_m(%arg0: vector<2x3xf32>, - %arg1: vector<3xf32>, - %arg2: vector<2xf32>, +func.func @masked_matvec_mk_k_m(%A: vector<2x3xf32>, + %x: vector<3xf32>, + %b: vector<2xf32>, %m: vector<2x3xi1>) -> vector<2xf32> { - %0 = vector.mask %m { vector.contract #matvec_trait_1 %arg0, %arg1, %arg2 + %0 = vector.mask %m { vector.contract #matvec_trait_1 %A, %x, %b : vector<2x3xf32>, vector<3xf32> into vector<2xf32> } : vector<2x3xi1> -> vector<2xf32> return %0 : vector<2xf32> } @@ -97,46 +156,28 @@ func.func @masked_matvec_mk_k_m(%arg0: vector<2x3xf32>, // CHECK: %[[MASK2:.*]] = vector.extract %[[T_MASK]][2] : vector<[2]xi1> from vector<3x[2]xi1> // CHECK: vector.mask %[[MASK2]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<[2]xf32>, f32 } : vector<[2]xi1> -> vector<[2]xf32> -func.func @masked_matvec_mk_k_m_scalable_parallel_dim(%arg0: vector<[2]x3xf32>, - %arg1: vector<3xf32>, - %arg2: vector<[2]xf32>, +func.func @masked_matvec_mk_k_m_scalable_parallel_dim(%A: vector<[2]x3xf32>, + %x: vector<3xf32>, + %b: vector<[2]xf32>, %m: vector<[2]x3xi1>) -> vector<[2]xf32> { - %0 = vector.mask %m { vector.contract #matvec_trait_1 %arg0, %arg1, %arg2 + %0 = vector.mask %m { vector.contract #matvec_trait_1 %A, %x, %b : vector<[2]x3xf32>, vector<3xf32> into vector<[2]xf32> } : vector<[2]x3xi1> -> vector<[2]xf32> return %0 : vector<[2]xf32> } -// CHECK-LABEL: func @matvec_mk_k_m -// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32> -// CHECK-SAME: %[[B:.*1]]: vector<2xf32> -// CHECK-SAME: %[[C:.*2]]: vector<2xf32> -// CHECK: %[[T3:.*]] = vector.transpose %[[A]], [1, 0] : vector<2x2xf32> to vector<2x2xf32> -// CHECK: %[[T4:.*]] = vector.extract %[[T3]][0] : vector<2xf32> from vector<2x2xf32> -// CHECK: %[[T5:.*]] = vector.extract %[[B]][0] : f32 from vector<2xf32> -// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[C]] {kind = #vector.kind<add>} : vector<2xf32>, f32 -// CHECK: %[[T7:.*]] = vector.extract %[[T3]][1] : vector<2xf32> from vector<2x2xf32> -// CHECK: %[[T8:.*]] = vector.extract %[[B]][1] : f32 from vector<2xf32> -// CHECK: %[[T9:.*]] = vector.outerproduct %[[T7]], %[[T8]], %[[T6]] {kind = #vector.kind<add>} : vector<2xf32>, f32 -func.func @matvec_mk_k_m(%A: vector<2x2xf32>, - %x: vector<2xf32>, - %b: vector<2xf32>) -> vector<2xf32> { - %0 = vector.contract #matvec_trait_1 %A, %x, %b : vector<2x2xf32>, vector<2xf32> into vector<2xf32> - return %0 : vector<2xf32> -} - // ============================================================================ // Matvec 1 - max (plain) // ============================================================================ // CHECK-LABEL: func @matvec_mk_k_m_max // CHECK-SAME: %[[A:.*0]]: vector<2x2xf32> -// CHECK-SAME: %[[B:.*1]]: vector<2xf32> -// CHECK-SAME: %[[C:.*2]]: vector<2xf32> +// CHECK-SAME: %[[X:.*1]]: vector<2xf32> +// CHECK-SAME: %[[B:.*2]]: vector<2xf32> // CHECK: %[[T3:.*]] = vector.transpose %[[A]], [1, 0] : vector<2x2xf32> to vector<2x2xf32> // CHECK: %[[T4:.*]] = vector.extract %[[T3]][0] : vector<2xf32> from vector<2x2xf32> -// CHECK: %[[T5:.*]] = vector.extract %[[B]][0] : f32 from vector<2xf32> -// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[C]] {kind = #vector.kind<maxf>} : vector<2xf32>, f32 +// CHECK: %[[T5:.*]] = vector.extract %[[X]][0] : f32 from vector<2xf32> +// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[B]] {kind = #vector.kind<maxf>} : vector<2xf32>, f32 // CHECK: %[[T7:.*]] = vector.extract %[[T3]][1] : vector<2xf32> from vector<2x2xf32> -// CHECK: %[[T8:.*]] = vector.extract %[[B]][1] : f32 from vector<2xf32> +// CHECK: %[[T8:.*]] = vector.extract %[[X]][1] : f32 from vector<2xf32> // CHECK: %[[T9:.*]] = vector.outerproduct %[[T7]], %[[T8]], %[[T6]] {kind = #vector.kind<maxf>} : vector<2xf32>, f32 func.func @matvec_mk_k_m_max(%A: vector<2x2xf32>, %x: vector<2xf32>, @@ -149,38 +190,38 @@ func.func @matvec_mk_k_m_max(%A: vector<2x2xf32>, // Matvec 2 (plain + masked + scalable) // ============================================================================ // CHECK-LABEL: @masked_matvec_km_k_m -// CHECK-SAME: %[[MAT:.+]]: vector<2x4xf32> -// CHECK-SAME: %[[VEC:.+]]: vector<2xf32> -// CHECK-SAME: %[[INIT:.+]]: vector<4xf32> +// CHECK-SAME: %[[A:.+]]: vector<2x4xf32> +// CHECK-SAME: %[[X:.+]]: vector<2xf32> +// CHECK-SAME: %[[B:.+]]: vector<4xf32> // CHECK-SAME: %[[MASK:.+]]: vector<4x2xi1> -func.func @masked_matvec_km_k_m(%arg0: vector<2x4xf32>, - %arg1: vector<2xf32>, - %arg2: vector<4xf32>, +func.func @masked_matvec_km_k_m(%A: vector<2x4xf32>, + %x: vector<2xf32>, + %b: vector<4xf32>, %mask: vector<4x2xi1>) -> vector<4xf32> { // CHECK: vector.transpose %[[MASK]] - // CHECK-NOT: vector.transpose %[[MAT]] + // CHECK-NOT: vector.transpose %[[A]] // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 } %res = vector.mask %mask { - vector.contract #matvec_trait_2 %arg0, %arg1, %arg2 + vector.contract #matvec_trait_2 %A, %x, %b : vector<2x4xf32>, vector<2xf32>, vector<4xf32> into vector<4xf32> } : vector<4x2xi1> -> vector<4xf32> return %res : vector<4xf32> } // CHECK-LABEL: @masked_matvec_km_k_m_scalable_parallel_dim -// CHECK-SAME: %[[MAT:.+]]: vector<2x[4]xf32> -// CHECK-SAME: %[[VEC:.+]]: vector<2xf32> -// CHECK-SAME: %[[INIT:.+]]: vector<[4]xf32> +// CHECK-SAME: %[[A:.+]]: vector<2x[4]xf32> +// CHECK-SAME: %[[X:.+]]: vector<2xf32> +// CHECK-SAME: %[[B:.+]]: vector<[4]xf32> // CHECK-SAME: %[[MASK:.+]]: vector<[4]x2xi1> -func.func @masked_matvec_km_k_m_scalable_parallel_dim(%arg0: vector<2x[4]xf32>, - %arg1: vector<2xf32>, - %arg2: vector<[4]xf32>, +func.func @masked_matvec_km_k_m_scalable_parallel_dim(%A: vector<2x[4]xf32>, + %x: vector<2xf32>, + %b: vector<[4]xf32>, %mask: vector<[4]x2xi1>) -> vector<[4]xf32> { // CHECK: vector.transpose %[[MASK]] - // CHECK-NOT: vector.transpose %[[MAT]] + // CHECK-NOT: vector.transpose %[[A]] // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 } %res = vector.mask %mask { - vector.contract #matvec_trait_2 %arg0, %arg1, %arg2 + vector.contract #matvec_trait_2 %A, %x, %b : vector<2x[4]xf32>, vector<2xf32>, vector<[4]xf32> into vector<[4]xf32> } : vector<[4]x2xi1> -> vector<[4]xf32> return %res : vector<[4]xf32> @@ -188,13 +229,13 @@ func.func @masked_matvec_km_k_m_scalable_parallel_dim(%arg0: vector<2x[4]xf32>, // CHECK-LABEL: func @matvec_km_k_m // CHECK-SAME: %[[A:.*0]]: vector<2x2xf32> -// CHECK-SAME: %[[B:.*1]]: vector<2xf32> -// CHECK-SAME: %[[C:.*2]]: vector<2xf32> +// CHECK-SAME: %[[X:.*1]]: vector<2xf32> +// CHECK-SAME: %[[B:.*2]]: vector<2xf32> // CHECK: %[[T3:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<2x2xf32> -// CHECK: %[[T4:.*]] = vector.extract %[[B]][0] : f32 from vector<2xf32> -// CHECK: %[[T5:.*]] = vector.outerproduct %[[T3]], %[[T4]], %[[C]] {kind = #vector.kind<add>} : vector<2xf32>, f32 +// CHECK: %[[T4:.*]] = vector.extract %[[X]][0] : f32 from vector<2xf32> +// CHECK: %[[T5:.*]] = vector.outerproduct %[[T3]], %[[T4]], %[[B]] {kind = #vector.kind<add>} : vector<2xf32>, f32 // CHECK: %[[T6:.*]] = vector.extract %[[A]][1] : vector<2xf32> from vector<2x2xf32> -// CHECK: %[[T7:.*]] = vector.extract %[[B]][1] : f32 from vector<2xf32> +// CHECK: %[[T7:.*]] = vector.extract %[[X]][1] : f32 from vector<2xf32> // CHECK: %[[T8:.*]] = vector.outerproduct %[[T6]], %[[T7]], %[[T5]] {kind = #vector.kind<add>} : vector<2xf32>, f32 func.func @matvec_km_k_m(%A: vector<2x2xf32>, %x: vector<2xf32>, @@ -207,54 +248,53 @@ func.func @matvec_km_k_m(%A: vector<2x2xf32>, // Matvec 3 (plain + masked + scalable) // ============================================================================ // CHECK-LABEL: @masked_matvec_k_mk_m -// CHECK-SAME: %[[MAT:.+]]: vector<4x2xf32> -// CHECK-SAME: %[[VEC:.+]]: vector<2xf32> -// CHECK-SAME: %[[INIT:.+]]: vector<4xf32> +// CHECK-SAME: %[[A:.+]]: vector<4x2xf32> +// CHECK-SAME: %[[X:.+]]: vector<2xf32> +// CHECK-SAME: %[[B:.+]]: vector<4xf32> // CHECK-SAME: %[[MASK:.+]]: vector<4x2xi1> -func.func @masked_matvec_k_mk_m(%arg0: vector<4x2xf32>, - %arg1: vector<2xf32>, - %arg2: vector<4xf32>, +func.func @masked_matvec_k_mk_m(%A: vector<4x2xf32>, + %x: vector<2xf32>, + %b: vector<4xf32>, %mask: vector<4x2xi1>) -> vector<4xf32> { // CHECK: vector.transpose %[[MASK]] - // CHECK: vector.transpose %[[MAT]] + // CHECK: vector.transpose %[[A]] // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 } %res = vector.mask %mask { - vector.contract #matvec_trait_3 %arg1, %arg0, %arg2 + vector.contract #matvec_trait_3 %x, %A, %b : vector<2xf32>, vector<4x2xf32>, vector<4xf32> into vector<4xf32> } : vector<4x2xi1> -> vector<4xf32> return %res : vector<4xf32> } // CHECK-LABEL: @masked_matvec_k_mk_m_scalable_parallel_dim -// CHECK-SAME: %[[MAT:.+]]: vector<[4]x2xf32> -// CHECK-SAME: %[[VEC:.+]]: vector<2xf32> -// CHECK-SAME: %[[INIT:.+]]: vector<[4]xf32> +// CHECK-SAME: %[[A:.+]]: vector<[4]x2xf32> +// CHECK-SAME: %[[X:.+]]: vector<2xf32> +// CHECK-SAME: %[[B:.+]]: vector<[4]xf32> // CHECK-SAME: %[[MASK:.+]]: vector<[4]x2xi1> -func.func @masked_matvec_k_mk_m_scalable_parallel_dim(%arg0: vector<[4]x2xf32>, - %arg1: vector<2xf32>, - %arg2: vector<[4]xf32>, +func.func @masked_matvec_k_mk_m_scalable_parallel_dim(%A: vector<[4]x2xf32>, + %x: vector<2xf32>, + %b: vector<[4]xf32>, %mask: vector<[4]x2xi1>) -> vector<[4]xf32> { // CHECK: vector.transpose %[[MASK]] - // CHECK: vector.transpose %[[MAT]] + // CHECK: vector.transpose %[[A]] // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 } %res = vector.mask %mask { - vector.contract #matvec_trait_3 %arg1, %arg0, %arg2 + vector.contract #matvec_trait_3 %x, %A, %b : vector<2xf32>, vector<[4]x2xf32>, vector<[4]xf32> into vector<[4]xf32> } : vector<[4]x2xi1> -> vector<[4]xf32> return %res : vector<[4]xf32> } - // CHECK-LABEL: func @matvec_k_mk_m // CHECK-SAME: %[[A:.*0]]: vector<2x2xf32> -// CHECK-SAME: %[[B:.*1]]: vector<2xf32> -// CHECK-SAME: %[[C:.*2]]: vector<2xf32> +// CHECK-SAME: %[[X:.*1]]: vector<2xf32> +// CHECK-SAME: %[[B:.*2]]: vector<2xf32> // CHECK: %[[T3:.*]] = vector.transpose %[[A]], [1, 0] : vector<2x2xf32> to vector<2x2xf32> // CHECK: %[[T4:.*]] = vector.extract %[[T3]][0] : vector<2xf32> from vector<2x2xf32> -// CHECK: %[[T5:.*]] = vector.extract %[[B]][0] : f32 from vector<2xf32> -// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[C]] {kind = #vector.kind<add>} : vector<2xf32>, f32 +// CHECK: %[[T5:.*]] = vector.extract %[[X]][0] : f32 from vector<2xf32> +// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[B]] {kind = #vector.kind<add>} : vector<2xf32>, f32 // CHECK: %[[T7:.*]] = vector.extract %[[T3]][1] : vector<2xf32> from vector<2x2xf32> -// CHECK: %[[T8:.*]] = vector.extract %[[B]][1] : f32 from vector<2xf32> +// CHECK: %[[T8:.*]] = vector.extract %[[X]][1] : f32 from vector<2xf32> // CHECK: %[[T9:.*]] = vector.outerproduct %[[T7]], %[[T8]], %[[T6]] {kind = #vector.kind<add>} : vector<2xf32>, f32 func.func @matvec_k_mk_m(%A: vector<2x2xf32>, %x: vector<2xf32>, @@ -266,191 +306,232 @@ func.func @matvec_k_mk_m(%A: vector<2x2xf32>, // ============================================================================ // Matvec 4 (plain + masked + scalable) // ============================================================================ +// CHECK-LABEL: func @matvec_k_km_m +// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32> +// CHECK-SAME: %[[X:.*1]]: vector<2xf32> +// CHECK-SAME: %[[B:.*2]]: vector<2xf32> +// CHECK: %[[T3:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<2x2xf32> +// CHECK: %[[T4:.*]] = vector.extract %[[X]][0] : f32 from vector<2xf32> +// CHECK: %[[T5:.*]] = vector.outerproduct %[[T3]], %[[T4]], %[[B]] {kind = #vector.kind<add>} : vector<2xf32>, f32 +// CHECK: %[[T6:.*]] = vector.extract %[[A]][1] : vector<2xf32> from vector<2x2xf32> +// CHECK: %[[T7:.*]] = vector.extract %[[X]][1] : f32 from vector<2xf32> +// CHECK: %[[T8:.*]] = vector.outerproduct %[[T6]], %[[T7]], %[[T5]] {kind = #vector.kind<add>} : vector<2xf32>, f32 +func.func @matvec_k_km_m(%A: vector<2x2xf32>, + %x: vector<2xf32>, + %b: vector<2xf32>) -> vector<2xf32> { + %0 = vector.contract #matvec_trait_4 %x, %A, %b : vector<2xf32>, vector<2x2xf32> into vector<2xf32> + return %0 : vector<2xf32> +} + // CHECK-LABEL: @masked_matvec_k_km_m_scalable_parallel_dim -// CHECK-SAME: %[[MAT:.+]]: vector<2x[4]xf32> -// CHECK-SAME: %[[VEC:.+]]: vector<2xf32> -// CHECK-SAME: %[[INIT:.+]]: vector<[4]xf32> +// CHECK-SAME: %[[A:.+]]: vector<2x[4]xf32> +// CHECK-SAME: %[[X:.+]]: vector<2xf32> +// CHECK-SAME: %[[B:.+]]: vector<[4]xf32> // CHECK-SAME: %[[MASK:.+]]: vector<[4]x2xi1> -func.func @masked_matvec_k_km_m_scalable_parallel_dim(%arg0: vector<2x[4]xf32>, - %arg1: vector<2xf32>, - %arg2: vector<[4]xf32>, +func.func @masked_matvec_k_km_m_scalable_parallel_dim(%A: vector<2x[4]xf32>, + %x: vector<2xf32>, + %b: vector<[4]xf32>, %mask: vector<[4]x2xi1>) -> vector<[4]xf32> { // CHECK: vector.transpose %[[MASK]] - // CHECK-NOT: vector.transpose %[[MAT]] + // CHECK-NOT: vector.transpose %[[A]] // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 } %res = vector.mask %mask { - vector.contract #matvec_trait_4 %arg1, %arg0, %arg2 + vector.contract #matvec_trait_4 %x, %A, %b : vector<2xf32>, vector<2x[4]xf32>, vector<[4]xf32> into vector<[4]xf32> } : vector<[4]x2xi1> -> vector<[4]xf32> return %res : vector<[4]xf32> } // CHECK-LABEL: @masked_matvec_k_km_m -// CHECK-SAME: %[[MAT:.+]]: vector<2x4xf32> -// CHECK-SAME: %[[VEC:.+]]: vector<2xf32> -// CHECK-SAME: %[[INIT:.+]]: vector<4xf32> +// CHECK-SAME: %[[A:.+]]: vector<2x4xf32> +// CHECK-SAME: %[[X:.+]]: vector<2xf32> +// CHECK-SAME: %[[B:.+]]: vector<4xf32> // CHECK-SAME: %[[MASK:.+]]: vector<4x2xi1> -func.func @masked_matvec_k_km_m(%arg0: vector<2x4xf32>, - %arg1: vector<2xf32>, - %arg2: vector<4xf32>, +func.func @masked_matvec_k_km_m(%A: vector<2x4xf32>, + %x: vector<2xf32>, + %b: vector<4xf32>, %mask: vector<4x2xi1>) -> vector<4xf32> { // CHECK: vector.transpose %[[MASK]] - // CHECK-NOT: vector.transpose %[[MAT]] + // CHECK-NOT: vector.transpose %[[A]] // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 } %res = vector.mask %mask { - vector.contract #matvec_trait_4 %arg1, %arg0, %arg2 + vector.contract #matvec_trait_4 %x, %A, %b : vector<2xf32>, vector<2x4xf32>, vector<4xf32> into vector<4xf32> } : vector<4x2xi1> -> vector<4xf32> return %res : vector<4xf32> } -// CHECK-LABEL: func @matvec_k_km_m -// CHECK-SAME: %[[A:.*0]]: vector<2x2xf32> -// CHECK-SAME: %[[B:.*1]]: vector<2xf32> -// CHECK-SAME: %[[C:.*2]]: vector<2xf32> -// CHECK: %[[T3:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<2x2xf32> -// CHECK: %[[T4:.*]] = vector.extract %[[B]][0] : f32 from vector<2xf32> -// CHECK: %[[T5:.*]] = vector.outerproduct %[[T3]], %[[T4]], %[[C]] {kind = #vector.kind<add>} : vector<2xf32>, f32 -// CHECK: %[[T6:.*]] = vector.extract %[[A]][1] : vector<2xf32> from vector<2x2xf32> -// CHECK: %[[T7:.*]] = vector.extract %[[B]][1] : f32 from vector<2xf32> -// CHECK: %[[T8:.*]] = vector.outerproduct %[[T6]], %[[T7]], %[[T5]] {kind = #vector.kind<add>} : vector<2xf32>, f32 -func.func @matvec_k_km_m(%A: vector<2x2xf32>, - %x: vector<2xf32>, - %b: vector<2xf32>) -> vector<2xf32> { - %0 = vector.contract #matvec_trait_4 %x, %A, %b : vector<2xf32>, vector<2x2xf32> into vector<2xf32> - return %0 : vector<2xf32> -} - // ============================================================================ -// Matvec 5 (masked + scalable) +// Matvec 5 (plain + masked + scalable) // ============================================================================ -#matvec_accesses_5 = [ - affine_map<(k, m) -> (m, k)>, - affine_map<(k, m) -> (k)>, - affine_map<(k, m) -> (m)> -] -#matvec_trait_5 = { - indexing_maps = #matvec_accesses_5, - iterator_types = ["reduction", "parallel"] +// CHECK-LABEL: func.func @tmatvec_mk_k_m( +// CHECK-SAME: %[[A:.*]]: vector<2x2xf32>, +// CHECK-SAME: %[[X:.*]]: vector<2xf32>, +// CHECK-SAME: %[[B:.*]]: vector<2xf32>) -> vector<2xf32> { +// CHECK: %[[VAL_3:.*]] = vector.transpose %[[A]], [1, 0] : vector<2x2xf32> to vector<2x2xf32> +// CHECK: %[[VAL_4:.*]] = vector.extract %[[VAL_3]][0] : vector<2xf32> from vector<2x2xf32> +// CHECK: %[[VAL_5:.*]] = vector.extract %[[X]][0] : f32 from vector<2xf32> +// CHECK: %[[VAL_6:.*]] = vector.outerproduct %[[VAL_4]], %[[VAL_5]], %[[B]] {kind = #vector.kind<add>} : vector<2xf32>, f32 +// CHECK: %[[VAL_7:.*]] = vector.extract %[[VAL_3]][1] : vector<2xf32> from vector<2x2xf32> +// CHECK: %[[VAL_8:.*]] = vector.extract %[[X]][1] : f32 from vector<2xf32> +// CHECK: %[[VAL_9:.*]] = vector.outerproduct %[[VAL_7]], %[[VAL_8]], %[[VAL_6]] {kind = #vector.kind<add>} : vector<2xf32>, f32 +func.func @tmatvec_mk_k_m(%A: vector<2x2xf32>, + %x: vector<2xf32>, + %b: vector<2xf32>) -> vector<2xf32> { + %0 = vector.contract #matvec_trait_5 %A, %x, %b : vector<2x2xf32>, vector<2xf32> into vector<2xf32> + return %0 : vector<2xf32> } // CHECK-LABEL: @masked_tmatvec_mk_k_m -// CHECK-SAME: %[[MAT:.+]]: vector<4x2xf32> -// CHECK-SAME: %[[VEC:.+]]: vector<2xf32> -// CHECK-SAME: %[[INIT:.+]]: vector<4xf32> +// CHECK-SAME: %[[A:.+]]: vector<4x2xf32> +// CHECK-SAME: %[[X:.+]]: vector<2xf32> +// CHECK-SAME: %[[B:.+]]: vector<4xf32> // CHECK-SAME: %[[MASK:.+]]: vector<2x4xi1> -func.func @masked_tmatvec_mk_k_m(%arg0: vector<4x2xf32>, %arg1: vector<2xf32>, %arg2: vector<4xf32>, %mask: vector<2x4xi1>) -> vector<4xf32> { - // CHECK: vector.transpose %[[MAT]] +func.func @masked_tmatvec_mk_k_m(%A: vector<4x2xf32>, + %x: vector<2xf32>, + %b: vector<4xf32>, + %mask: vector<2x4xi1>) -> vector<4xf32> { + // CHECK: vector.transpose %[[A]] // CHECK-NOT: vector.transpose %[[MASK]] // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 } %res = vector.mask %mask { - vector.contract #matvec_trait_5 %arg0, %arg1, %arg2 + vector.contract #matvec_trait_5 %A, %x, %b : vector<4x2xf32>, vector<2xf32>, vector<4xf32> into vector<4xf32> } : vector<2x4xi1> -> vector<4xf32> return %res : vector<4xf32> } // CHECK-LABEL: @masked_tmatvec_mk_k_m_scalable_parallel_dim -// CHECK-SAME: %[[MAT:.+]]: vector<[4]x2xf32> -// CHECK-SAME: %[[VEC:.+]]: vector<2xf32> -// CHECK-SAME: %[[INIT:.+]]: vector<[4]xf32> +// CHECK-SAME: %[[A:.+]]: vector<[4]x2xf32> +// CHECK-SAME: %[[X:.+]]: vector<2xf32> +// CHECK-SAME: %[[B:.+]]: vector<[4]xf32> // CHECK-SAME: %[[MASK:.+]]: vector<2x[4]xi1> -func.func @masked_tmatvec_mk_k_m_scalable_parallel_dim(%arg0: vector<[4]x2xf32>, %arg1: vector<2xf32>, %arg2: vector<[4]xf32>, %mask: vector<2x[4]xi1>) -> vector<[4]xf32> { - // CHECK: vector.transpose %[[MAT]] +func.func @masked_tmatvec_mk_k_m_scalable_parallel_dim(%A: vector<[4]x2xf32>, + %x: vector<2xf32>, + %b: vector<[4]xf32>, + %mask: vector<2x[4]xi1>) -> vector<[4]xf32> { + // CHECK: vector.transpose %[[A]] // CHECK-NOT: vector.transpose %[[MASK]] // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 } %res = vector.mask %mask { - vector.contract #matvec_trait_5 %arg0, %arg1, %arg2 + vector.contract #matvec_trait_5 %A, %x, %b : vector<[4]x2xf32>, vector<2xf32>, vector<[4]xf32> into vector<[4]xf32> } : vector<2x[4]xi1> -> vector<[4]xf32> return %res : vector<[4]xf32> } // ============================================================================ -// Matvec 6 (masked + scalable) +// Matvec 6 (plain + masked + scalable) // ============================================================================ -#matvec_accesses_6 = [ - affine_map<(k, m) -> (k, m)>, - affine_map<(k, m) -> (k)>, - affine_map<(k, m) -> (m)> -] -#matvec_trait_6 = { - indexing_maps = #matvec_accesses_6, - iterator_types = ["reduction", "parallel"] +// CHECK-LABEL: func.func @tmatvec_km_k_m( +// CHECK-SAME: %[[A:.*]]: vector<2x2xf32>, +// CHECK-SAME: %[[X:.*]]: vector<2xf32>, +// CHECK-SAME: %[[B:.*]]: vector<2xf32>) -> vector<2xf32> { +// CHECK: %[[VAL_3:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<2x2xf32> +// CHECK: %[[VAL_4:.*]] = vector.extract %[[X]][0] : f32 from vector<2xf32> +// CHECK: %[[VAL_5:.*]] = vector.outerproduct %[[VAL_3]], %[[VAL_4]], %[[B]] {kind = #vector.kind<add>} : vector<2xf32>, f32 +// CHECK: %[[VAL_6:.*]] = vector.extract %[[A]][1] : vector<2xf32> from vector<2x2xf32> +// CHECK: %[[VAL_7:.*]] = vector.extract %[[X]][1] : f32 from vector<2xf32> +// CHECK: %[[VAL_8:.*]] = vector.outerproduct %[[VAL_6]], %[[VAL_7]], %[[VAL_5]] {kind = #vector.kind<add>} : vector<2xf32>, f32 +func.func @tmatvec_km_k_m(%A: vector<2x2xf32>, + %x: vector<2xf32>, + %b: vector<2xf32>) -> vector<2xf32> { + %0 = vector.contract #matvec_trait_6 %A, %x, %b : vector<2x2xf32>, vector<2xf32> into vector<2xf32> + return %0 : vector<2xf32> } // CHECK-LABEL: @masked_tmatvec_km_k_m -// CHECK-SAME: %[[MAT:.+]]: vector<2x4xf32> -// CHECK-SAME: %[[VEC:.+]]: vector<2xf32> -// CHECK-SAME: %[[INIT:.+]]: vector<4xf32> +// CHECK-SAME: %[[A:.+]]: vector<2x4xf32> +// CHECK-SAME: %[[X:.+]]: vector<2xf32> +// CHECK-SAME: %[[B:.+]]: vector<4xf32> // CHECK-SAME: %[[MASK:.+]]: vector<2x4xi1> -func.func @masked_tmatvec_km_k_m(%arg0: vector<2x4xf32>, %arg1: vector<2xf32>, %arg2: vector<4xf32>, %mask: vector<2x4xi1>) -> vector<4xf32> { - // CHECK-NOT: vector.transpose %[[MAT]] +func.func @masked_tmatvec_km_k_m(%A: vector<2x4xf32>, + %x: vector<2xf32>, + %b: vector<4xf32>, + %mask: vector<2x4xi1>) -> vector<4xf32> { + // CHECK-NOT: vector.transpose %[[A]] // CHECK-NOT: vector.transpose %[[MASK]] // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 } %res = vector.mask %mask { - vector.contract #matvec_trait_6 %arg0, %arg1, %arg2 + vector.contract #matvec_trait_6 %A, %x, %b : vector<2x4xf32>, vector<2xf32>, vector<4xf32> into vector<4xf32> } : vector<2x4xi1> -> vector<4xf32> return %res : vector<4xf32> } // CHECK-LABEL: @masked_tmatvec_km_k_m_scalable_parallel_dim -// CHECK-SAME: %[[MAT:.+]]: vector<2x[4]xf32> -// CHECK-SAME: %[[VEC:.+]]: vector<2xf32> -// CHECK-SAME: %[[INIT:.+]]: vector<[4]xf32> +// CHECK-SAME: %[[A:.+]]: vector<2x[4]xf32> +// CHECK-SAME: %[[X:.+]]: vector<2xf32> +// CHECK-SAME: %[[B:.+]]: vector<[4]xf32> // CHECK-SAME: %[[MASK:.+]]: vector<2x[4]xi1> -func.func @masked_tmatvec_km_k_m_scalable_parallel_dim(%arg0: vector<2x[4]xf32>, %arg1: vector<2xf32>, %arg2: vector<[4]xf32>, %mask: vector<2x[4]xi1>) -> vector<[4]xf32> { - // CHECK-NOT: vector.transpose %[[MAT]] +func.func @masked_tmatvec_km_k_m_scalable_parallel_dim(%A: vector<2x[4]xf32>, + %x: vector<2xf32>, + %b: vector<[4]xf32>, + %mask: vector<2x[4]xi1>) -> vector<[4]xf32> { + // CHECK-NOT: vector.transpose %[[A]] // CHECK-NOT: vector.transpose %[[MASK]] // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 } %res = vector.mask %mask { - vector.contract #matvec_trait_6 %arg0, %arg1, %arg2 + vector.contract #matvec_trait_6 %A, %x, %b : vector<2x[4]xf32>, vector<2xf32>, vector<[4]xf32> into vector<[4]xf32> } : vector<2x[4]xi1> -> vector<[4]xf32> return %res : vector<[4]xf32> } // ============================================================================ -// Matvec 7 (masked + scalable) +// Matvec 7 (plain + masked + scalable) // ============================================================================ -#matvec_accesses_7 = [ - affine_map<(k, m) -> (k)>, - affine_map<(k, m) -> (m, k)>, - affine_map<(k, m) -> (m)> -] -#matvec_trait_7 = { - indexing_maps = #matvec_accesses_7, - iterator_types = ["reduction", "parallel"] +// CHECK-LABEL: func.func @tmatvec_k_mk_m( +// CHECK-SAME: %[[A:.*]]: vector<2x2xf32>, +// CHECK-SAME: %[[X:.*]]: vector<2xf32>, +// CHECK-SAME: %[[B:.*]]: vector<2xf32>) -> vector<2xf32> { +// CHECK: %[[VAL_3:.*]] = vector.transpose %[[A]], [1, 0] : vector<2x2xf32> to vector<2x2xf32> +// CHECK: %[[VAL_4:.*]] = vector.extract %[[VAL_3]][0] : vector<2xf32> from vector<2x2xf32> +// CHECK: %[[VAL_5:.*]] = vector.extract %[[X]][0] : f32 from vector<2xf32> +// CHECK: %[[VAL_6:.*]] = vector.outerproduct %[[VAL_4]], %[[VAL_5]], %[[B]] {kind = #vector.kind<add>} : vector<2xf32>, f32 +// CHECK: %[[VAL_7:.*]] = vector.extract %[[VAL_3]][1] : vector<2xf32> from vector<2x2xf32> +// CHECK: %[[VAL_8:.*]] = vector.extract %[[X]][1] : f32 from vector<2xf32> +// CHECK: %[[VAL_9:.*]] = vector.outerproduct %[[VAL_7]], %[[VAL_8]], %[[VAL_6]] {kind = #vector.kind<add>} : vector<2xf32>, f32 +func.func @tmatvec_k_mk_m(%A: vector<2x2xf32>, + %x: vector<2xf32>, + %b: vector<2xf32>) -> vector<2xf32> { + %0 = vector.contract #matvec_trait_7 %x, %A, %b : vector<2xf32>, vector<2x2xf32> into vector<2xf32> + return %0 : vector<2xf32> } // CHECK-LABEL: @masked_tmatvec_k_mk_m -// CHECK-SAME: %[[MAT:.+]]: vector<4x2xf32> -// CHECK-SAME: %[[VEC:.+]]: vector<2xf32> -// CHECK-SAME: %[[INIT:.+]]: vector<4xf32> +// CHECK-SAME: %[[A:.+]]: vector<4x2xf32> +// CHECK-SAME: %[[X:.+]]: vector<2xf32> +// CHECK-SAME: %[[B:.+]]: vector<4xf32> // CHECK-SAME: %[[MASK:.+]]: vector<2x4xi1> -func.func @masked_tmatvec_k_mk_m(%arg0: vector<4x2xf32>, %arg1: vector<2xf32>, %arg2: vector<4xf32>, %mask: vector<2x4xi1>) -> vector<4xf32> { - // CHECK: vector.transpose %[[MAT]] +func.func @masked_tmatvec_k_mk_m(%A: vector<4x2xf32>, + %x: vector<2xf32>, + %b: vector<4xf32>, + %mask: vector<2x4xi1>) -> vector<4xf32> { + // CHECK: vector.transpose %[[A]] // CHECK-NOT: vector.transpose %[[MASK]] // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 } %res = vector.mask %mask { - vector.contract #matvec_trait_7 %arg1, %arg0, %arg2 + vector.contract #matvec_trait_7 %x, %A, %b : vector<2xf32>, vector<4x2xf32>, vector<4xf32> into vector<4xf32> } : vector<2x4xi1> -> vector<4xf32> return %res : vector<4xf32> } // CHECK-LABEL: @masked_tmatvec_k_mk_m_scalable_parallel_dim -// CHECK-SAME: %[[MAT:.+]]: vector<[4]x2xf32> -// CHECK-SAME: %[[VEC:.+]]: vector<2xf32> -// CHECK-SAME: %[[INIT:.+]]: vector<[4]xf32> +// CHECK-SAME: %[[A:.+]]: vector<[4]x2xf32> +// CHECK-SAME: %[[X:.+]]: vector<2xf32> +// CHECK-SAME: %[[B:.+]]: vector<[4]xf32> // CHECK-SAME: %[[MASK:.+]]: vector<2x[4]xi1> -func.func @masked_tmatvec_k_mk_m_scalable_parallel_dim(%arg0: vector<[4]x2xf32>, %arg1: vector<2xf32>, %arg2: vector<[4]xf32>, %mask: vector<2x[4]xi1>) -> vector<[4]xf32> { - // CHECK: vector.transpose %[[MAT]] +func.func @masked_tmatvec_k_mk_m_scalable_parallel_dim(%A: vector<[4]x2xf32>, + %x: vector<2xf32>, + %b: vector<[4]xf32>, + %mask: vector<2x[4]xi1>) -> vector<[4]xf32> { + // CHECK: vector.transpose %[[A]] // CHECK-NOT: vector.transpose %[[MASK]] // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 } %res = vector.mask %mask { - vector.contract #matvec_trait_7 %arg1, %arg0, %arg2 + vector.contract #matvec_trait_7 %x, %A, %b : vector<2xf32>, vector<[4]x2xf32>, vector<[4]xf32> into vector<[4]xf32> } : vector<2x[4]xi1> -> vector<[4]xf32> return %res : vector<[4]xf32> @@ -459,50 +540,56 @@ func.func @masked_tmatvec_k_mk_m_scalable_parallel_dim(%arg0: vector<[4]x2xf32>, // ============================================================================ // Matvec 8 (plain + masked + scalable) // ============================================================================ -// CHECK-LABEL: func @matvec_m_mk_k +// CHECK-LABEL: func @tmatvec_m_mk_k // CHECK-SAME: %[[A:.*0]]: vector<2x2xf32> -// CHECK-SAME: %[[B:.*1]]: vector<2xf32> -// CHECK-SAME: %[[C:.*2]]: vector<2xf32> +// CHECK-SAME: %[[X:.*1]]: vector<2xf32> +// CHECK-SAME: %[[B:.*2]]: vector<2xf32> // CHECK: %[[T3:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<2x2xf32> -// CHECK: %[[T4:.*]] = vector.extract %[[B]][0] : f32 from vector<2xf32> -// CHECK: %[[T5:.*]] = vector.outerproduct %[[T3]], %[[T4]], %[[C]] {kind = #vector.kind<add>} : vector<2xf32>, f32 +// CHECK: %[[T4:.*]] = vector.extract %[[X]][0] : f32 from vector<2xf32> +// CHECK: %[[T5:.*]] = vector.outerproduct %[[T3]], %[[T4]], %[[B]] {kind = #vector.kind<add>} : vector<2xf32>, f32 // CHECK: %[[T6:.*]] = vector.extract %[[A]][1] : vector<2xf32> from vector<2x2xf32> -// CHECK: %[[T7:.*]] = vector.extract %[[B]][1] : f32 from vector<2xf32> +// CHECK: %[[T7:.*]] = vector.extract %[[X]][1] : f32 from vector<2xf32> // CHECK: %[[T8:.*]] = vector.outerproduct %[[T6]], %[[T7]], %[[T5]] {kind = #vector.kind<add>} : vector<2xf32>, f32 -func.func @matvec_m_mk_k(%A: vector<2x2xf32>, - %x: vector<2xf32>, - %b: vector<2xf32>) -> vector<2xf32> { +func.func @tmatvec_m_mk_k(%A: vector<2x2xf32>, + %x: vector<2xf32>, + %b: vector<2xf32>) -> vector<2xf32> { %0 = vector.contract #matvec_trait_8 %x, %A, %b : vector<2xf32>, vector<2x2xf32> into vector<2xf32> return %0 : vector<2xf32> } // CHECK-LABEL: @masked_tmatvec_k_km_m -// CHECK-SAME: %[[MAT:.+]]: vector<2x4xf32> -// CHECK-SAME: %[[VEC:.+]]: vector<2xf32> -// CHECK-SAME: %[[INIT:.+]]: vector<4xf32> +// CHECK-SAME: %[[A:.+]]: vector<2x4xf32> +// CHECK-SAME: %[[X:.+]]: vector<2xf32> +// CHECK-SAME: %[[B:.+]]: vector<4xf32> // CHECK-SAME: %[[MASK:.+]]: vector<2x4xi1> -func.func @masked_tmatvec_k_km_m(%arg0: vector<2x4xf32>, %arg1: vector<2xf32>, %arg2: vector<4xf32>, %mask: vector<2x4xi1>) -> vector<4xf32> { - // CHECK-NOT: vector.transpose %[[MAT]] +func.func @masked_tmatvec_k_km_m(%A: vector<2x4xf32>, + %x: vector<2xf32>, + %b: vector<4xf32>, + %mask: vector<2x4xi1>) -> vector<4xf32> { + // CHECK-NOT: vector.transpose %[[A]] // CHECK-NOT: vector.transpose %[[MASK]] // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 } %res = vector.mask %mask { - vector.contract #matvec_trait_8 %arg1, %arg0, %arg2 + vector.contract #matvec_trait_8 %x, %A, %b : vector<2xf32>, vector<2x4xf32>, vector<4xf32> into vector<4xf32> } : vector<2x4xi1> -> vector<4xf32> return %res : vector<4xf32> } // CHECK-LABEL: @masked_tmatvec_k_km_m_scalable_parallel_dim -// CHECK-SAME: %[[MAT:.+]]: vector<2x[4]xf32> -// CHECK-SAME: %[[VEC:.+]]: vector<2xf32> -// CHECK-SAME: %[[INIT:.+]]: vector<[4]xf32> +// CHECK-SAME: %[[A:.+]]: vector<2x[4]xf32> +// CHECK-SAME: %[[X:.+]]: vector<2xf32> +// CHECK-SAME: %[[B:.+]]: vector<[4]xf32> // CHECK-SAME: %[[MASK:.+]]: vector<2x[4]xi1> -func.func @masked_tmatvec_k_km_m_scalable_parallel_dim(%arg0: vector<2x[4]xf32>, %arg1: vector<2xf32>, %arg2: vector<[4]xf32>, %mask: vector<2x[4]xi1>) -> vector<[4]xf32> { - // CHECK-NOT: vector.transpose %[[MAT]] +func.func @masked_tmatvec_k_km_m_scalable_parallel_dim(%A: vector<2x[4]xf32>, + %x: vector<2xf32>, + %b: vector<[4]xf32>, + %mask: vector<2x[4]xi1>) -> vector<[4]xf32> { + // CHECK-NOT: vector.transpose %[[A]] // CHECK-NOT: vector.transpose %[[MASK]] // CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<[4]xf32>, f32 } %res = vector.mask %mask { - vector.contract #matvec_trait_8 %arg1, %arg0, %arg2 + vector.contract #matvec_trait_8 %x, %A, %b : vector<2xf32>, vector<2x[4]xf32>, vector<[4]xf32> into vector<[4]xf32> } : vector<2x[4]xi1> -> vector<[4]xf32> return %res : vector<[4]xf32> |