diff options
author | qcolombet <quentin.colombet@gmail.com> | 2023-10-05 14:32:42 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-10-05 14:32:42 +0200 |
commit | 932dc9d8c4979f4f0aec33e2b9725a110ded8b61 (patch) | |
tree | 50889ce9b1c760c896db4fe0d2ac8ec20d29172c /mlir/test/Dialect/MemRef | |
parent | 253ee85f344e5f62d70e7b7bc796bd7bd5da3330 (diff) | |
download | llvm-932dc9d8c4979f4f0aec33e2b9725a110ded8b61.zip llvm-932dc9d8c4979f4f0aec33e2b9725a110ded8b61.tar.gz llvm-932dc9d8c4979f4f0aec33e2b9725a110ded8b61.tar.bz2 |
[mlir][MemRef] Add a pattern to simplify `extract_strided_metadata(ca… (#68291)
…st)`
`expand-strided-metadata` was missing a pattern to get rid of
`memref.cast`.
The pattern is straight foward:
Produce a new `extract_strided_metadata` with the source of the cast and
fold the static information (sizes, strides, offset) along the way.
Diffstat (limited to 'mlir/test/Dialect/MemRef')
-rw-r--r-- | mlir/test/Dialect/MemRef/expand-strided-metadata.mlir | 125 |
1 files changed, 125 insertions, 0 deletions
diff --git a/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir b/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir index a6303aa..ab0c78a 100644 --- a/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir +++ b/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir @@ -1369,3 +1369,128 @@ func.func @extract_strided_metadata_of_get_global_with_offset() return %base, %offset, %sizes#0, %sizes#1, %strides#0, %strides#1 : memref<i32>, index, index, index, index, index } + +// ----- + +// Check that we simplify extract_strided_metadata of cast +// when the source of the cast is compatible with what +// `extract_strided_metadata`s accept. +// +// When we apply the transformation the resulting offset, sizes and strides +// should come straight from the inputs of the cast. +// Additionally the folder on extract_strided_metadata should propagate the +// static information. +// +// CHECK-LABEL: func @extract_strided_metadata_of_cast +// CHECK-SAME: %[[ARG:.*]]: memref<3x?xi32, strided<[4, ?], offset: ?>>) +// +// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index +// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index +// CHECK: %[[BASE:.*]], %[[DYN_OFFSET:.*]], %[[DYN_SIZES:.*]]:2, %[[DYN_STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]] +// +// CHECK: return %[[BASE]], %[[DYN_OFFSET]], %[[C3]], %[[DYN_SIZES]]#1, %[[C4]], %[[DYN_STRIDES]]#1 +func.func @extract_strided_metadata_of_cast( + %arg : memref<3x?xi32, strided<[4, ?], offset:?>>) + -> (memref<i32>, index, + index, index, + index, index) { + + %cast = + memref.cast %arg : + memref<3x?xi32, strided<[4, ?], offset: ?>> to + memref<?x?xi32, strided<[?, ?], offset: ?>> + + %base, %base_offset, %sizes:2, %strides:2 = + memref.extract_strided_metadata %cast:memref<?x?xi32, strided<[?, ?], offset: ?>> + -> memref<i32>, index, + index, index, + index, index + + return %base, %base_offset, + %sizes#0, %sizes#1, + %strides#0, %strides#1 : + memref<i32>, index, + index, index, + index, index +} + +// ----- + +// Check that we simplify extract_strided_metadata of cast +// when the source of the cast is compatible with what +// `extract_strided_metadata`s accept. +// +// Same as extract_strided_metadata_of_cast but with constant sizes and strides +// in the destination type. +// +// CHECK-LABEL: func @extract_strided_metadata_of_cast_w_csts +// CHECK-SAME: %[[ARG:.*]]: memref<?x?xi32, strided<[?, ?], offset: ?>>) +// +// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index +// CHECK-DAG: %[[C18:.*]] = arith.constant 18 : index +// CHECK-DAG: %[[C25:.*]] = arith.constant 25 : index +// CHECK: %[[BASE:.*]], %[[DYN_OFFSET:.*]], %[[DYN_SIZES:.*]]:2, %[[DYN_STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]] +// +// CHECK: return %[[BASE]], %[[C25]], %[[C4]], %[[DYN_SIZES]]#1, %[[DYN_STRIDES]]#0, %[[C18]] +func.func @extract_strided_metadata_of_cast_w_csts( + %arg : memref<?x?xi32, strided<[?, ?], offset:?>>) + -> (memref<i32>, index, + index, index, + index, index) { + + %cast = + memref.cast %arg : + memref<?x?xi32, strided<[?, ?], offset: ?>> to + memref<4x?xi32, strided<[?, 18], offset: 25>> + + %base, %base_offset, %sizes:2, %strides:2 = + memref.extract_strided_metadata %cast:memref<4x?xi32, strided<[?, 18], offset: 25>> + -> memref<i32>, index, + index, index, + index, index + + return %base, %base_offset, + %sizes#0, %sizes#1, + %strides#0, %strides#1 : + memref<i32>, index, + index, index, + index, index +} +// ----- + +// Check that we don't simplify extract_strided_metadata of +// cast when the source of the cast is unranked. +// Unranked memrefs cannot feed into extract_strided_metadata operations. +// Note: Technically we could still fold the sizes and strides. +// +// CHECK-LABEL: func @extract_strided_metadata_of_cast_unranked +// CHECK-SAME: %[[ARG:.*]]: memref<*xi32>) +// +// CHECK: %[[CAST:.*]] = memref.cast %[[ARG]] : +// CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[CAST]] +// +// CHECK: return %[[BASE]], %[[OFFSET]], %[[SIZES]]#0, %[[SIZES]]#1, %[[STRIDES]]#0, %[[STRIDES]]#1 +func.func @extract_strided_metadata_of_cast_unranked( + %arg : memref<*xi32>) + -> (memref<i32>, index, + index, index, + index, index) { + + %cast = + memref.cast %arg : + memref<*xi32> to + memref<?x?xi32, strided<[?, ?], offset: ?>> + + %base, %base_offset, %sizes:2, %strides:2 = + memref.extract_strided_metadata %cast:memref<?x?xi32, strided<[?, ?], offset: ?>> + -> memref<i32>, index, + index, index, + index, index + + return %base, %base_offset, + %sizes#0, %sizes#1, + %strides#0, %strides#1 : + memref<i32>, index, + index, index, + index, index +} |