aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAndrzej WarzyƄski <andrzej.warzynski@arm.com>2023-09-15 16:35:47 +0100
committerGitHub <noreply@github.com>2023-09-15 16:35:47 +0100
commit57cf6896cd5a48d6978372b9e0fa93fa5381bbba (patch)
tree2d69cafa35a1002c53c444084e9e67f1293d29ed
parentcadabb58f1a4d37074d75b0e1b8bc0f8cdacab98 (diff)
downloadllvm-57cf6896cd5a48d6978372b9e0fa93fa5381bbba.zip
llvm-57cf6896cd5a48d6978372b9e0fa93fa5381bbba.tar.gz
llvm-57cf6896cd5a48d6978372b9e0fa93fa5381bbba.tar.bz2
[mlir][vector] Fix vector.broadcast lowering for scalable vectors (#66344)
This patch makes sure that the following case is lowered correctly ("duplication"): ``` func.func @broadcast_scalable_duplication(%arg0: vector<[32]xf32>) -> vector<1x[32]xf32> { %res = vector.broadcast %arg0 : vector<[32]xf32> to vector<1x[32]xf32> return %res : vector<1x[32]xf32> } ```
-rw-r--r--mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp3
-rw-r--r--mlir/test/Dialect/Vector/vector-broadcast-lowering-transforms.mlir11
2 files changed, 12 insertions, 2 deletions
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
index 7c606e0..44e3f76 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
@@ -84,8 +84,7 @@ public:
// %x = [%b,%b,%b,%b] : n-D
if (srcRank < dstRank) {
// Duplication.
- VectorType resType =
- VectorType::get(dstType.getShape().drop_front(), eltType);
+ VectorType resType = VectorType::Builder(dstType).dropDim(0);
Value bcst =
rewriter.create<vector::BroadcastOp>(loc, resType, op.getSource());
Value result = rewriter.create<arith::ConstantOp>(
diff --git a/mlir/test/Dialect/Vector/vector-broadcast-lowering-transforms.mlir b/mlir/test/Dialect/Vector/vector-broadcast-lowering-transforms.mlir
index 2d3c88d..386102cf 100644
--- a/mlir/test/Dialect/Vector/vector-broadcast-lowering-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-broadcast-lowering-transforms.mlir
@@ -162,6 +162,17 @@ func.func @broadcast_stretch_in_middle(%arg0: vector<4x1x2xf32>) -> vector<4x3x2
return %0 : vector<4x3x2xf32>
}
+// CHECK-LABEL: func.func @broadcast_scalable_duplication
+// CHECK-SAME: %[[ARG0:.*]]: vector<[32]xf32>)
+// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<1x[32]xf32>
+// CHECK: %[[RES:.*]] = vector.insert %[[ARG0]], %[[CST]] [0] : vector<[32]xf32> into vector<1x[32]xf32>
+// CHECK: return %[[RES]] : vector<1x[32]xf32>
+
+func.func @broadcast_scalable_duplication(%arg0: vector<[32]xf32>) -> vector<1x[32]xf32> {
+ %res = vector.broadcast %arg0 : vector<[32]xf32> to vector<1x[32]xf32>
+ return %res : vector<1x[32]xf32>
+}
+
transform.sequence failures(propagate) {
^bb1(%module_op: !transform.any_op):
%f = transform.structured.match ops{["func.func"]} in %module_op