diff options
author | Andrzej WarzyĆski <andrzej.warzynski@arm.com> | 2024-06-21 12:48:26 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-06-21 12:48:26 +0100 |
commit | 1c85c711aadb65943f5187524274fc96d1151b02 (patch) | |
tree | e7b9ce1d5a8d55242fa21c84a70efbd1116cc17d /mlir | |
parent | f1075a34ab30f67915deb9a519dd98e025c5c998 (diff) | |
download | llvm-1c85c711aadb65943f5187524274fc96d1151b02.zip llvm-1c85c711aadb65943f5187524274fc96d1151b02.tar.gz llvm-1c85c711aadb65943f5187524274fc96d1151b02.tar.bz2 |
[mlir][vector] Refactor vector-transfer-flatten.mlir (nfc) (2/n) (#95744)
The main goal of this and subsequent PRs is to unify and categorize
tests in:
* vector-transfer-flatten.mlir
This should make it easier to identify the edge cases being tested (and
how they differ), remove duplicates and to add tests for scalable
vectors.
Below are the main contributions of this PR
1. Two tests duplicated
`@transfer_{read|write}_dims_mismatch_non_contiguous_slice`:
* `@transfer_{read|write}_dims_mismatch_non_contiguous` and
* `@transfer_read_flattenable_negative` duplicated
`@transfer_{read|write}_dims_mismatch_non_contiguous_slice`.
These tests are removed (the original test is preserved).
2. `@transfer_read_flattenable_negative2` is replaced with
two tests with more descriptive names:
* `@transfer_read_non_contiguous_src` (for `xfer_read`) and
* `@transfer_write_non_contiguous_src` (for `xfer_write`)
Diffstat (limited to 'mlir')
-rw-r--r-- | mlir/test/Dialect/Vector/vector-transfer-flatten.mlir | 116 |
1 files changed, 44 insertions, 72 deletions
diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir index 40a8b7e..3a5041f 100644 --- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir +++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir @@ -131,25 +131,6 @@ func.func @transfer_read_dims_mismatch_non_contiguous_non_zero_indices( // ----- -func.func @transfer_read_dims_mismatch_non_contiguous( - %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<2x1x2x2xi8> { - - %c0 = arith.constant 0 : index - %cst = arith.constant 0 : i8 - %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst : - memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<2x1x2x2xi8> - return %v : vector<2x1x2x2xi8> -} - -// CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_contiguous -// CHECK-NOT: memref.collapse_shape -// CHECK-NOT: vector.shape_cast - -// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_contiguous( -// CHECK-128B-NOT: memref.collapse_shape - -// ----- - // The input memref has a dynamic trailing shape and hence is not flattened. // TODO: This case could be supported via memref.dim @@ -214,6 +195,28 @@ func.func @transfer_read_0d( // ----- +// Strides make the input memref non-contiguous, hence non-flattenable. + +func.func @transfer_read_non_contiguous_src( + %arg : memref<5x4x3x2xi8, strided<[24, 8, 2, 1], offset: ?>>) -> vector<5x4x3x2xi8> { + + %c0 = arith.constant 0 : index + %cst = arith.constant 0 : i8 + %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst : + memref<5x4x3x2xi8, strided<[24, 8, 2, 1], offset: ?>>, vector<5x4x3x2xi8> + return %v : vector<5x4x3x2xi8> +} + +// CHECK-LABEL: func.func @transfer_read_non_contiguous_src +// CHECK-NOT: memref.collapse_shape +// CHECK-NOT: vector.shape_cast + +// CHECK-128B-LABEL: func @transfer_read_non_contiguous_src +// CHECK-128B-NOT: memref.collapse_shape +// CHECK-128B-NOT: vector.shape_cast + +// ----- + ///---------------------------------------------------------------------------------------- /// vector.transfer_write /// [Pattern: FlattenContiguousRowMajorTransferWritePattern] @@ -342,25 +345,6 @@ func.func @transfer_write_dims_mismatch_non_contiguous_non_zero_indices( // ----- -func.func @transfer_write_dims_mismatch_non_contiguous( - %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, - %vec : vector<2x1x2x2xi8>) { - - %c0 = arith.constant 0 : index - vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] : - vector<2x1x2x2xi8>, memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>> - return -} - -// CHECK-LABEL: func.func @transfer_write_dims_mismatch_non_contiguous -// CHECK-NOT: memref.collapse_shape -// CHECK-NOT: vector.shape_cast - -// CHECK-128B-LABEL: func @transfer_write_dims_mismatch_non_contiguous( -// CHECK-128B-NOT: memref.collapse_shape - -// ----- - // The input memref has a dynamic trailing shape and hence is not flattened. // TODO: This case could be supported via memref.dim @@ -427,6 +411,28 @@ func.func @transfer_write_0d( // ----- +// The strides make the input memref non-contiguous, hence non-flattenable. + +func.func @transfer_write_non_contiguous_src( + %arg : memref<5x4x3x2xi8, strided<[24, 8, 2, 1], offset: ?>>, + %vec : vector<5x4x3x2xi8>) { + + %c0 = arith.constant 0 : index + vector.transfer_write %vec, %arg[%c0, %c0, %c0, %c0] : + vector<5x4x3x2xi8>, memref<5x4x3x2xi8, strided<[24, 8, 2, 1], offset: ?>> + return +} + +// CHECK-LABEL: func.func @transfer_write_non_contiguous_src +// CHECK-NOT: memref.collapse_shape +// CHECK-NOT: vector.shape_cast + +// CHECK-128B-LABEL: func @transfer_write_non_contiguous_src +// CHECK-128B-NOT: memref.collapse_shape +// CHECK-128B-NOT: vector.shape_cast + +// ----- + ///---------------------------------------------------------------------------------------- /// TODO: Categorize + re-format ///---------------------------------------------------------------------------------------- @@ -478,40 +484,6 @@ func.func @transfer_write_flattenable_with_dynamic_dims_and_indices(%vec : vecto // ----- -func.func @transfer_read_flattenable_negative( - %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<2x2x2x2xi8> { - %c0 = arith.constant 0 : index - %cst = arith.constant 0 : i8 - %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst : - memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<2x2x2x2xi8> - return %v : vector<2x2x2x2xi8> -} - -// CHECK-LABEL: func @transfer_read_flattenable_negative -// CHECK: vector.transfer_read {{.*}} vector<2x2x2x2xi8> - -// CHECK-128B-LABEL: func @transfer_read_flattenable_negative( -// CHECK-128B-NOT: memref.collapse_shape - -// ----- - -func.func @transfer_read_flattenable_negative2( - %arg : memref<5x4x3x2xi8, strided<[24, 8, 2, 1], offset: ?>>) -> vector<5x4x3x2xi8> { - %c0 = arith.constant 0 : index - %cst = arith.constant 0 : i8 - %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst : - memref<5x4x3x2xi8, strided<[24, 8, 2, 1], offset: ?>>, vector<5x4x3x2xi8> - return %v : vector<5x4x3x2xi8> -} - -// CHECK-LABEL: func @transfer_read_flattenable_negative2 -// CHECK: vector.transfer_read {{.*}} vector<5x4x3x2xi8> - -// CHECK-128B-LABEL: func @transfer_read_flattenable_negative2( -// CHECK-128B-NOT: memref.collapse_shape - -// ----- - func.func @fold_unit_dim_add_basic(%arg0 : vector<1x8xi32>) -> vector<1x8xi32> { %add = arith.addi %arg0, %arg0 : vector<1x8xi32> return %add : vector<1x8xi32> |