From 465ea0bfa69aa48afef58666b084467a1c96c81b Mon Sep 17 00:00:00 2001 From: Crefeda Rodrigues <65665931+cfRod@users.noreply.github.com> Date: Fri, 22 Mar 2024 10:08:03 +0000 Subject: [mlir][vector] Propagate scalability in TransferWriteNonPermutationLowering (#85632) Updates `extendVectorRank` so that scalability in patterns that use it (in particular, `TransferWriteNonPermutationLowering`), is correctly propagated. Closed related previous PR https://github.com/llvm/llvm-project/pull/85270 --------- Signed-off-by: Crefeda Rodrigues Co-authored-by: Benjamin Maxwell --- mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) (limited to 'mlir/lib') diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp index 4a5e8fc..0693aa5 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp @@ -41,8 +41,12 @@ static Value extendVectorRank(OpBuilder &builder, Location loc, Value vec, SmallVector newShape(addedRank, 1); newShape.append(originalVecType.getShape().begin(), originalVecType.getShape().end()); - VectorType newVecType = - VectorType::get(newShape, originalVecType.getElementType()); + + SmallVector newScalableDims(addedRank, false); + newScalableDims.append(originalVecType.getScalableDims().begin(), + originalVecType.getScalableDims().end()); + VectorType newVecType = VectorType::get( + newShape, originalVecType.getElementType(), newScalableDims); return builder.create(loc, newVecType, vec); } -- cgit v1.1