aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMatthias Springer <me@m-sp.org>2024-06-18 16:25:47 +0200
committerGitHub <noreply@github.com>2024-06-18 16:25:47 +0200
commit552f80a97982c30e3cb1bde2223e923a2e3a47e8 (patch)
treeefeefa4731236a87a4085212e92259b74d64ae62
parent41f6aee769a900558e58eba7b5f9044f8f8a05b7 (diff)
downloadllvm-552f80a97982c30e3cb1bde2223e923a2e3a47e8.zip
llvm-552f80a97982c30e3cb1bde2223e923a2e3a47e8.tar.gz
llvm-552f80a97982c30e3cb1bde2223e923a2e3a47e8.tar.bz2
[mlir][vector] Fix crash when folding 0D extract from splat/broadcast (#95918)
There was an assertion in the folder that caused a crash when extracting from a vector that is defined by an op with 0D semantics. This commit removes the assertion and adds test cases to ensure that 0D scenarios are handled correctly.
-rw-r--r--mlir/lib/Dialect/Vector/IR/VectorOps.cpp5
-rw-r--r--mlir/test/Dialect/Vector/canonicalize.mlir38
2 files changed, 38 insertions, 5 deletions
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 5e5d3e0..2bf4f16 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1631,11 +1631,6 @@ static Value foldExtractFromBroadcast(ExtractOp extractOp) {
if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
return Value();
- // 0-D vectors not supported.
- assert(!hasZeroDimVectors(extractOp) && "0-D vectors not supported");
- if (hasZeroDimVectors(defOp))
- return Value();
-
Value source = defOp->getOperand(0);
if (extractOp.getType() == source.getType())
return source;
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 61269e3..caccd1f 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -2604,3 +2604,41 @@ func.func @extract_from_0d_regression(%v: vector<f32>) -> f32 {
%0 = vector.extract %v[] : f32 from vector<f32>
return %0 : f32
}
+
+// -----
+
+// CHECK-LABEL: func @extract_from_0d_splat_broadcast_regression(
+// CHECK-SAME: %[[a:.*]]: f32, %[[b:.*]]: vector<f32>, %[[c:.*]]: vector<2xf32>)
+func.func @extract_from_0d_splat_broadcast_regression(%a: f32, %b: vector<f32>, %c: vector<2xf32>) -> (f32, f32, f32, f32, f32, vector<6x7xf32>, vector<3xf32>) {
+ // Splat scalar to 0D and extract scalar.
+ %0 = vector.splat %a : vector<f32>
+ %1 = vector.extract %0[] : f32 from vector<f32>
+
+ // Broadcast scalar to 0D and extract scalar.
+ %2 = vector.broadcast %a : f32 to vector<f32>
+ %3 = vector.extract %2[] : f32 from vector<f32>
+
+ // Broadcast 0D to 3D and extract scalar.
+ // CHECK: %[[extract1:.*]] = vector.extractelement %[[b]][] : vector<f32>
+ %4 = vector.broadcast %b : vector<f32> to vector<1x2x4xf32>
+ %5 = vector.extract %4[0, 0, 1] : f32 from vector<1x2x4xf32>
+
+ // Splat scalar to 2D and extract scalar.
+ %6 = vector.splat %a : vector<2x3xf32>
+ %7 = vector.extract %6[0, 1] : f32 from vector<2x3xf32>
+
+ // Broadcast scalar to 3D and extract scalar.
+ %8 = vector.broadcast %a : f32 to vector<5x6x7xf32>
+ %9 = vector.extract %8[2, 1, 5] : f32 from vector<5x6x7xf32>
+
+ // Extract 2D from 3D that was broadcasted from a scalar.
+ // CHECK: %[[extract2:.*]] = vector.broadcast %[[a]] : f32 to vector<6x7xf32>
+ %10 = vector.extract %8[2] : vector<6x7xf32> from vector<5x6x7xf32>
+
+ // Extract 1D from 2D that was splat'ed from a scalar.
+ // CHECK: %[[extract3:.*]] = vector.broadcast %[[a]] : f32 to vector<3xf32>
+ %11 = vector.extract %6[1] : vector<3xf32> from vector<2x3xf32>
+
+ // CHECK: return %[[a]], %[[a]], %[[extract1]], %[[a]], %[[a]], %[[extract2]], %[[extract3]]
+ return %1, %3, %5, %7, %9, %10, %11 : f32, f32, f32, f32, f32, vector<6x7xf32>, vector<3xf32>
+}