diff options
author | Andrzej WarzyĆski <andrzej.warzynski@arm.com> | 2023-09-15 16:35:47 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-09-15 16:35:47 +0100 |
commit | 57cf6896cd5a48d6978372b9e0fa93fa5381bbba (patch) | |
tree | 2d69cafa35a1002c53c444084e9e67f1293d29ed | |
parent | cadabb58f1a4d37074d75b0e1b8bc0f8cdacab98 (diff) | |
download | llvm-57cf6896cd5a48d6978372b9e0fa93fa5381bbba.zip llvm-57cf6896cd5a48d6978372b9e0fa93fa5381bbba.tar.gz llvm-57cf6896cd5a48d6978372b9e0fa93fa5381bbba.tar.bz2 |
[mlir][vector] Fix vector.broadcast lowering for scalable vectors (#66344)
This patch makes sure that the following case is lowered correctly
("duplication"):
```
func.func @broadcast_scalable_duplication(%arg0: vector<[32]xf32>) -> vector<1x[32]xf32> {
%res = vector.broadcast %arg0 : vector<[32]xf32> to vector<1x[32]xf32>
return %res : vector<1x[32]xf32>
}
```
-rw-r--r-- | mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp | 3 | ||||
-rw-r--r-- | mlir/test/Dialect/Vector/vector-broadcast-lowering-transforms.mlir | 11 |
2 files changed, 12 insertions, 2 deletions
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp index 7c606e0..44e3f76 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp @@ -84,8 +84,7 @@ public: // %x = [%b,%b,%b,%b] : n-D if (srcRank < dstRank) { // Duplication. - VectorType resType = - VectorType::get(dstType.getShape().drop_front(), eltType); + VectorType resType = VectorType::Builder(dstType).dropDim(0); Value bcst = rewriter.create<vector::BroadcastOp>(loc, resType, op.getSource()); Value result = rewriter.create<arith::ConstantOp>( diff --git a/mlir/test/Dialect/Vector/vector-broadcast-lowering-transforms.mlir b/mlir/test/Dialect/Vector/vector-broadcast-lowering-transforms.mlir index 2d3c88d..386102cf 100644 --- a/mlir/test/Dialect/Vector/vector-broadcast-lowering-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-broadcast-lowering-transforms.mlir @@ -162,6 +162,17 @@ func.func @broadcast_stretch_in_middle(%arg0: vector<4x1x2xf32>) -> vector<4x3x2 return %0 : vector<4x3x2xf32> } +// CHECK-LABEL: func.func @broadcast_scalable_duplication +// CHECK-SAME: %[[ARG0:.*]]: vector<[32]xf32>) +// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<1x[32]xf32> +// CHECK: %[[RES:.*]] = vector.insert %[[ARG0]], %[[CST]] [0] : vector<[32]xf32> into vector<1x[32]xf32> +// CHECK: return %[[RES]] : vector<1x[32]xf32> + +func.func @broadcast_scalable_duplication(%arg0: vector<[32]xf32>) -> vector<1x[32]xf32> { + %res = vector.broadcast %arg0 : vector<[32]xf32> to vector<1x[32]xf32> + return %res : vector<1x[32]xf32> +} + transform.sequence failures(propagate) { ^bb1(%module_op: !transform.any_op): %f = transform.structured.match ops{["func.func"]} in %module_op |