diff options
author | Crefeda Rodrigues <65665931+cfRod@users.noreply.github.com> | 2024-03-22 10:08:03 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-03-22 10:08:03 +0000 |
commit | 465ea0bfa69aa48afef58666b084467a1c96c81b (patch) | |
tree | d8d30be0cd5e83ddc22c42c0355724584bfd9f3e /mlir/lib | |
parent | 99d8c25b3104fc07f46532bd681515c5f3c71133 (diff) | |
download | llvm-465ea0bfa69aa48afef58666b084467a1c96c81b.zip llvm-465ea0bfa69aa48afef58666b084467a1c96c81b.tar.gz llvm-465ea0bfa69aa48afef58666b084467a1c96c81b.tar.bz2 |
[mlir][vector] Propagate scalability in TransferWriteNonPermutationLowering (#85632)
Updates `extendVectorRank` so that scalability in patterns
that use it (in particular, `TransferWriteNonPermutationLowering`),
is correctly propagated.
Closed related previous PR
https://github.com/llvm/llvm-project/pull/85270
---------
Signed-off-by: Crefeda Rodrigues <crefeda.rodrigues@arm.com>
Co-authored-by: Benjamin Maxwell <macdue@dueutil.tech>
Diffstat (limited to 'mlir/lib')
-rw-r--r-- | mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp | 8 |
1 files changed, 6 insertions, 2 deletions
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp index 4a5e8fc..0693aa5 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp @@ -41,8 +41,12 @@ static Value extendVectorRank(OpBuilder &builder, Location loc, Value vec, SmallVector<int64_t> newShape(addedRank, 1); newShape.append(originalVecType.getShape().begin(), originalVecType.getShape().end()); - VectorType newVecType = - VectorType::get(newShape, originalVecType.getElementType()); + + SmallVector<bool> newScalableDims(addedRank, false); + newScalableDims.append(originalVecType.getScalableDims().begin(), + originalVecType.getScalableDims().end()); + VectorType newVecType = VectorType::get( + newShape, originalVecType.getElementType(), newScalableDims); return builder.create<vector::BroadcastOp>(loc, newVecType, vec); } |