diff options
author | Matthias Springer <me@m-sp.org> | 2024-06-18 16:25:47 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-06-18 16:25:47 +0200 |
commit | 552f80a97982c30e3cb1bde2223e923a2e3a47e8 (patch) | |
tree | efeefa4731236a87a4085212e92259b74d64ae62 /mlir | |
parent | 41f6aee769a900558e58eba7b5f9044f8f8a05b7 (diff) | |
download | llvm-552f80a97982c30e3cb1bde2223e923a2e3a47e8.zip llvm-552f80a97982c30e3cb1bde2223e923a2e3a47e8.tar.gz llvm-552f80a97982c30e3cb1bde2223e923a2e3a47e8.tar.bz2 |
[mlir][vector] Fix crash when folding 0D extract from splat/broadcast (#95918)
There was an assertion in the folder that caused a crash when extracting
from a vector that is defined by an op with 0D semantics. This commit
removes the assertion and adds test cases to ensure that 0D scenarios
are handled correctly.
Diffstat (limited to 'mlir')
-rw-r--r-- | mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 5 | ||||
-rw-r--r-- | mlir/test/Dialect/Vector/canonicalize.mlir | 38 |
2 files changed, 38 insertions, 5 deletions
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 5e5d3e0..2bf4f16 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -1631,11 +1631,6 @@ static Value foldExtractFromBroadcast(ExtractOp extractOp) { if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp)) return Value(); - // 0-D vectors not supported. - assert(!hasZeroDimVectors(extractOp) && "0-D vectors not supported"); - if (hasZeroDimVectors(defOp)) - return Value(); - Value source = defOp->getOperand(0); if (extractOp.getType() == source.getType()) return source; diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir index 61269e3..caccd1f 100644 --- a/mlir/test/Dialect/Vector/canonicalize.mlir +++ b/mlir/test/Dialect/Vector/canonicalize.mlir @@ -2604,3 +2604,41 @@ func.func @extract_from_0d_regression(%v: vector<f32>) -> f32 { %0 = vector.extract %v[] : f32 from vector<f32> return %0 : f32 } + +// ----- + +// CHECK-LABEL: func @extract_from_0d_splat_broadcast_regression( +// CHECK-SAME: %[[a:.*]]: f32, %[[b:.*]]: vector<f32>, %[[c:.*]]: vector<2xf32>) +func.func @extract_from_0d_splat_broadcast_regression(%a: f32, %b: vector<f32>, %c: vector<2xf32>) -> (f32, f32, f32, f32, f32, vector<6x7xf32>, vector<3xf32>) { + // Splat scalar to 0D and extract scalar. + %0 = vector.splat %a : vector<f32> + %1 = vector.extract %0[] : f32 from vector<f32> + + // Broadcast scalar to 0D and extract scalar. + %2 = vector.broadcast %a : f32 to vector<f32> + %3 = vector.extract %2[] : f32 from vector<f32> + + // Broadcast 0D to 3D and extract scalar. + // CHECK: %[[extract1:.*]] = vector.extractelement %[[b]][] : vector<f32> + %4 = vector.broadcast %b : vector<f32> to vector<1x2x4xf32> + %5 = vector.extract %4[0, 0, 1] : f32 from vector<1x2x4xf32> + + // Splat scalar to 2D and extract scalar. + %6 = vector.splat %a : vector<2x3xf32> + %7 = vector.extract %6[0, 1] : f32 from vector<2x3xf32> + + // Broadcast scalar to 3D and extract scalar. + %8 = vector.broadcast %a : f32 to vector<5x6x7xf32> + %9 = vector.extract %8[2, 1, 5] : f32 from vector<5x6x7xf32> + + // Extract 2D from 3D that was broadcasted from a scalar. + // CHECK: %[[extract2:.*]] = vector.broadcast %[[a]] : f32 to vector<6x7xf32> + %10 = vector.extract %8[2] : vector<6x7xf32> from vector<5x6x7xf32> + + // Extract 1D from 2D that was splat'ed from a scalar. + // CHECK: %[[extract3:.*]] = vector.broadcast %[[a]] : f32 to vector<3xf32> + %11 = vector.extract %6[1] : vector<3xf32> from vector<2x3xf32> + + // CHECK: return %[[a]], %[[a]], %[[extract1]], %[[a]], %[[a]], %[[extract2]], %[[extract3]] + return %1, %3, %5, %7, %9, %10, %11 : f32, f32, f32, f32, f32, vector<6x7xf32>, vector<3xf32> +} |