aboutsummaryrefslogtreecommitdiff
path: root/mlir
diff options
context:
space:
mode:
authorMatthias Springer <me@m-sp.org>2024-06-18 16:25:47 +0200
committerGitHub <noreply@github.com>2024-06-18 16:25:47 +0200
commit552f80a97982c30e3cb1bde2223e923a2e3a47e8 (patch)
treeefeefa4731236a87a4085212e92259b74d64ae62 /mlir
parent41f6aee769a900558e58eba7b5f9044f8f8a05b7 (diff)
downloadllvm-552f80a97982c30e3cb1bde2223e923a2e3a47e8.zip
llvm-552f80a97982c30e3cb1bde2223e923a2e3a47e8.tar.gz
llvm-552f80a97982c30e3cb1bde2223e923a2e3a47e8.tar.bz2
[mlir][vector] Fix crash when folding 0D extract from splat/broadcast (#95918)
There was an assertion in the folder that caused a crash when extracting from a vector that is defined by an op with 0D semantics. This commit removes the assertion and adds test cases to ensure that 0D scenarios are handled correctly.
Diffstat (limited to 'mlir')
-rw-r--r--mlir/lib/Dialect/Vector/IR/VectorOps.cpp5
-rw-r--r--mlir/test/Dialect/Vector/canonicalize.mlir38
2 files changed, 38 insertions, 5 deletions
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 5e5d3e0..2bf4f16 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1631,11 +1631,6 @@ static Value foldExtractFromBroadcast(ExtractOp extractOp) {
if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
return Value();
- // 0-D vectors not supported.
- assert(!hasZeroDimVectors(extractOp) && "0-D vectors not supported");
- if (hasZeroDimVectors(defOp))
- return Value();
-
Value source = defOp->getOperand(0);
if (extractOp.getType() == source.getType())
return source;
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 61269e3..caccd1f 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -2604,3 +2604,41 @@ func.func @extract_from_0d_regression(%v: vector<f32>) -> f32 {
%0 = vector.extract %v[] : f32 from vector<f32>
return %0 : f32
}
+
+// -----
+
+// CHECK-LABEL: func @extract_from_0d_splat_broadcast_regression(
+// CHECK-SAME: %[[a:.*]]: f32, %[[b:.*]]: vector<f32>, %[[c:.*]]: vector<2xf32>)
+func.func @extract_from_0d_splat_broadcast_regression(%a: f32, %b: vector<f32>, %c: vector<2xf32>) -> (f32, f32, f32, f32, f32, vector<6x7xf32>, vector<3xf32>) {
+ // Splat scalar to 0D and extract scalar.
+ %0 = vector.splat %a : vector<f32>
+ %1 = vector.extract %0[] : f32 from vector<f32>
+
+ // Broadcast scalar to 0D and extract scalar.
+ %2 = vector.broadcast %a : f32 to vector<f32>
+ %3 = vector.extract %2[] : f32 from vector<f32>
+
+ // Broadcast 0D to 3D and extract scalar.
+ // CHECK: %[[extract1:.*]] = vector.extractelement %[[b]][] : vector<f32>
+ %4 = vector.broadcast %b : vector<f32> to vector<1x2x4xf32>
+ %5 = vector.extract %4[0, 0, 1] : f32 from vector<1x2x4xf32>
+
+ // Splat scalar to 2D and extract scalar.
+ %6 = vector.splat %a : vector<2x3xf32>
+ %7 = vector.extract %6[0, 1] : f32 from vector<2x3xf32>
+
+ // Broadcast scalar to 3D and extract scalar.
+ %8 = vector.broadcast %a : f32 to vector<5x6x7xf32>
+ %9 = vector.extract %8[2, 1, 5] : f32 from vector<5x6x7xf32>
+
+ // Extract 2D from 3D that was broadcasted from a scalar.
+ // CHECK: %[[extract2:.*]] = vector.broadcast %[[a]] : f32 to vector<6x7xf32>
+ %10 = vector.extract %8[2] : vector<6x7xf32> from vector<5x6x7xf32>
+
+ // Extract 1D from 2D that was splat'ed from a scalar.
+ // CHECK: %[[extract3:.*]] = vector.broadcast %[[a]] : f32 to vector<3xf32>
+ %11 = vector.extract %6[1] : vector<3xf32> from vector<2x3xf32>
+
+ // CHECK: return %[[a]], %[[a]], %[[extract1]], %[[a]], %[[a]], %[[extract2]], %[[extract3]]
+ return %1, %3, %5, %7, %9, %10, %11 : f32, f32, f32, f32, f32, vector<6x7xf32>, vector<3xf32>
+}