aboutsummaryrefslogtreecommitdiff
path: root/mlir
diff options
context:
space:
mode:
authorAndrzej WarzyƄski <andrzej.warzynski@arm.com>2024-06-21 12:48:26 +0100
committerGitHub <noreply@github.com>2024-06-21 12:48:26 +0100
commit1c85c711aadb65943f5187524274fc96d1151b02 (patch)
treee7b9ce1d5a8d55242fa21c84a70efbd1116cc17d /mlir
parentf1075a34ab30f67915deb9a519dd98e025c5c998 (diff)
downloadllvm-1c85c711aadb65943f5187524274fc96d1151b02.zip
llvm-1c85c711aadb65943f5187524274fc96d1151b02.tar.gz
llvm-1c85c711aadb65943f5187524274fc96d1151b02.tar.bz2
[mlir][vector] Refactor vector-transfer-flatten.mlir (nfc) (2/n) (#95744)
The main goal of this and subsequent PRs is to unify and categorize tests in: * vector-transfer-flatten.mlir This should make it easier to identify the edge cases being tested (and how they differ), remove duplicates and to add tests for scalable vectors. Below are the main contributions of this PR 1. Two tests duplicated `@transfer_{read|write}_dims_mismatch_non_contiguous_slice`: * `@transfer_{read|write}_dims_mismatch_non_contiguous` and * `@transfer_read_flattenable_negative` duplicated `@transfer_{read|write}_dims_mismatch_non_contiguous_slice`. These tests are removed (the original test is preserved). 2. `@transfer_read_flattenable_negative2` is replaced with two tests with more descriptive names: * `@transfer_read_non_contiguous_src` (for `xfer_read`) and * `@transfer_write_non_contiguous_src` (for `xfer_write`)
Diffstat (limited to 'mlir')
-rw-r--r--mlir/test/Dialect/Vector/vector-transfer-flatten.mlir116
1 files changed, 44 insertions, 72 deletions
diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index 40a8b7e..3a5041f 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -131,25 +131,6 @@ func.func @transfer_read_dims_mismatch_non_contiguous_non_zero_indices(
// -----
-func.func @transfer_read_dims_mismatch_non_contiguous(
- %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<2x1x2x2xi8> {
-
- %c0 = arith.constant 0 : index
- %cst = arith.constant 0 : i8
- %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
- memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<2x1x2x2xi8>
- return %v : vector<2x1x2x2xi8>
-}
-
-// CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_contiguous
-// CHECK-NOT: memref.collapse_shape
-// CHECK-NOT: vector.shape_cast
-
-// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_contiguous(
-// CHECK-128B-NOT: memref.collapse_shape
-
-// -----
-
// The input memref has a dynamic trailing shape and hence is not flattened.
// TODO: This case could be supported via memref.dim
@@ -214,6 +195,28 @@ func.func @transfer_read_0d(
// -----
+// Strides make the input memref non-contiguous, hence non-flattenable.
+
+func.func @transfer_read_non_contiguous_src(
+ %arg : memref<5x4x3x2xi8, strided<[24, 8, 2, 1], offset: ?>>) -> vector<5x4x3x2xi8> {
+
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0 : i8
+ %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
+ memref<5x4x3x2xi8, strided<[24, 8, 2, 1], offset: ?>>, vector<5x4x3x2xi8>
+ return %v : vector<5x4x3x2xi8>
+}
+
+// CHECK-LABEL: func.func @transfer_read_non_contiguous_src
+// CHECK-NOT: memref.collapse_shape
+// CHECK-NOT: vector.shape_cast
+
+// CHECK-128B-LABEL: func @transfer_read_non_contiguous_src
+// CHECK-128B-NOT: memref.collapse_shape
+// CHECK-128B-NOT: vector.shape_cast
+
+// -----
+
///----------------------------------------------------------------------------------------
/// vector.transfer_write
/// [Pattern: FlattenContiguousRowMajorTransferWritePattern]
@@ -342,25 +345,6 @@ func.func @transfer_write_dims_mismatch_non_contiguous_non_zero_indices(
// -----
-func.func @transfer_write_dims_mismatch_non_contiguous(
- %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>,
- %vec : vector<2x1x2x2xi8>) {
-
- %c0 = arith.constant 0 : index
- vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] :
- vector<2x1x2x2xi8>, memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>
- return
-}
-
-// CHECK-LABEL: func.func @transfer_write_dims_mismatch_non_contiguous
-// CHECK-NOT: memref.collapse_shape
-// CHECK-NOT: vector.shape_cast
-
-// CHECK-128B-LABEL: func @transfer_write_dims_mismatch_non_contiguous(
-// CHECK-128B-NOT: memref.collapse_shape
-
-// -----
-
// The input memref has a dynamic trailing shape and hence is not flattened.
// TODO: This case could be supported via memref.dim
@@ -427,6 +411,28 @@ func.func @transfer_write_0d(
// -----
+// The strides make the input memref non-contiguous, hence non-flattenable.
+
+func.func @transfer_write_non_contiguous_src(
+ %arg : memref<5x4x3x2xi8, strided<[24, 8, 2, 1], offset: ?>>,
+ %vec : vector<5x4x3x2xi8>) {
+
+ %c0 = arith.constant 0 : index
+ vector.transfer_write %vec, %arg[%c0, %c0, %c0, %c0] :
+ vector<5x4x3x2xi8>, memref<5x4x3x2xi8, strided<[24, 8, 2, 1], offset: ?>>
+ return
+}
+
+// CHECK-LABEL: func.func @transfer_write_non_contiguous_src
+// CHECK-NOT: memref.collapse_shape
+// CHECK-NOT: vector.shape_cast
+
+// CHECK-128B-LABEL: func @transfer_write_non_contiguous_src
+// CHECK-128B-NOT: memref.collapse_shape
+// CHECK-128B-NOT: vector.shape_cast
+
+// -----
+
///----------------------------------------------------------------------------------------
/// TODO: Categorize + re-format
///----------------------------------------------------------------------------------------
@@ -478,40 +484,6 @@ func.func @transfer_write_flattenable_with_dynamic_dims_and_indices(%vec : vecto
// -----
-func.func @transfer_read_flattenable_negative(
- %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<2x2x2x2xi8> {
- %c0 = arith.constant 0 : index
- %cst = arith.constant 0 : i8
- %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
- memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<2x2x2x2xi8>
- return %v : vector<2x2x2x2xi8>
-}
-
-// CHECK-LABEL: func @transfer_read_flattenable_negative
-// CHECK: vector.transfer_read {{.*}} vector<2x2x2x2xi8>
-
-// CHECK-128B-LABEL: func @transfer_read_flattenable_negative(
-// CHECK-128B-NOT: memref.collapse_shape
-
-// -----
-
-func.func @transfer_read_flattenable_negative2(
- %arg : memref<5x4x3x2xi8, strided<[24, 8, 2, 1], offset: ?>>) -> vector<5x4x3x2xi8> {
- %c0 = arith.constant 0 : index
- %cst = arith.constant 0 : i8
- %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
- memref<5x4x3x2xi8, strided<[24, 8, 2, 1], offset: ?>>, vector<5x4x3x2xi8>
- return %v : vector<5x4x3x2xi8>
-}
-
-// CHECK-LABEL: func @transfer_read_flattenable_negative2
-// CHECK: vector.transfer_read {{.*}} vector<5x4x3x2xi8>
-
-// CHECK-128B-LABEL: func @transfer_read_flattenable_negative2(
-// CHECK-128B-NOT: memref.collapse_shape
-
-// -----
-
func.func @fold_unit_dim_add_basic(%arg0 : vector<1x8xi32>) -> vector<1x8xi32> {
%add = arith.addi %arg0, %arg0 : vector<1x8xi32>
return %add : vector<1x8xi32>