diff options
author | Cullen Rhodes <cullen.rhodes@arm.com> | 2024-06-20 08:07:43 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-06-20 08:07:43 +0100 |
commit | cc145f40530667d65220536a3e03eabe9fdd46cf (patch) | |
tree | e8bc5d4f9d25ab01e1066e15bcde17e434f7abdd /mlir | |
parent | fa08e97d03afd215caeb297a822895c4d0d93b7b (diff) | |
download | llvm-cc145f40530667d65220536a3e03eabe9fdd46cf.zip llvm-cc145f40530667d65220536a3e03eabe9fdd46cf.tar.gz llvm-cc145f40530667d65220536a3e03eabe9fdd46cf.tar.bz2 |
[mlir][vector] Disable Gather1DToConditionalLoads for scalable vectors (#96049)
Pattern scalarizes vector.gather operations and is incorrect for
scalable vectors.
Diffstat (limited to 'mlir')
-rw-r--r-- | mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp | 3 | ||||
-rw-r--r-- | mlir/test/Dialect/Vector/vector-gather-lowering.mlir | 10 |
2 files changed, 13 insertions, 0 deletions
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp index 9012812..dd027d1 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp @@ -189,6 +189,9 @@ struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> { if (resultTy.getRank() != 1) return rewriter.notifyMatchFailure(op, "unsupported rank"); + if (resultTy.isScalable()) + return rewriter.notifyMatchFailure(op, "not a fixed-width vector"); + Location loc = op.getLoc(); Type elemTy = resultTy.getElementType(); // Vector type with a single element. Used to generate `vector.loads`. diff --git a/mlir/test/Dialect/Vector/vector-gather-lowering.mlir b/mlir/test/Dialect/Vector/vector-gather-lowering.mlir index d047ac6..c2eb88a 100644 --- a/mlir/test/Dialect/Vector/vector-gather-lowering.mlir +++ b/mlir/test/Dialect/Vector/vector-gather-lowering.mlir @@ -206,3 +206,13 @@ func.func @strided_gather(%base : memref<100x3xf32>, // CHECK: scf.if %[[MASK_3]] -> (vector<4xf32>) // CHECK: %[[M_3:.*]] = vector.load %[[COLLAPSED]][%[[IDX_3]]] : memref<300xf32>, vector<1xf32> // CHECK: %[[V_3:.*]] = vector.extract %[[M_3]][0] : f32 from vector<1xf32> + +// CHECK-LABEL: @scalable_gather_1d +// CHECK-NOT: extract +// CHECK: vector.gather +// CHECK-NOT: extract +func.func @scalable_gather_1d(%base: tensor<?xf32>, %v: vector<[2]xindex>, %mask: vector<[2]xi1>, %pass_thru: vector<[2]xf32>) -> vector<[2]xf32> { + %c0 = arith.constant 0 : index + %0 = vector.gather %base[%c0][%v], %mask, %pass_thru : tensor<?xf32>, vector<[2]xindex>, vector<[2]xi1>, vector<[2]xf32> into vector<[2]xf32> + return %0 : vector<[2]xf32> +} |