aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib
diff options
context:
space:
mode:
authorCrefeda Rodrigues <65665931+cfRod@users.noreply.github.com>2024-03-22 10:08:03 +0000
committerGitHub <noreply@github.com>2024-03-22 10:08:03 +0000
commit465ea0bfa69aa48afef58666b084467a1c96c81b (patch)
treed8d30be0cd5e83ddc22c42c0355724584bfd9f3e /mlir/lib
parent99d8c25b3104fc07f46532bd681515c5f3c71133 (diff)
downloadllvm-465ea0bfa69aa48afef58666b084467a1c96c81b.zip
llvm-465ea0bfa69aa48afef58666b084467a1c96c81b.tar.gz
llvm-465ea0bfa69aa48afef58666b084467a1c96c81b.tar.bz2
[mlir][vector] Propagate scalability in TransferWriteNonPermutationLowering (#85632)
Updates `extendVectorRank` so that scalability in patterns that use it (in particular, `TransferWriteNonPermutationLowering`), is correctly propagated. Closed related previous PR https://github.com/llvm/llvm-project/pull/85270 --------- Signed-off-by: Crefeda Rodrigues <crefeda.rodrigues@arm.com> Co-authored-by: Benjamin Maxwell <macdue@dueutil.tech>
Diffstat (limited to 'mlir/lib')
-rw-r--r--mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp8
1 files changed, 6 insertions, 2 deletions
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
index 4a5e8fc..0693aa5 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
@@ -41,8 +41,12 @@ static Value extendVectorRank(OpBuilder &builder, Location loc, Value vec,
SmallVector<int64_t> newShape(addedRank, 1);
newShape.append(originalVecType.getShape().begin(),
originalVecType.getShape().end());
- VectorType newVecType =
- VectorType::get(newShape, originalVecType.getElementType());
+
+ SmallVector<bool> newScalableDims(addedRank, false);
+ newScalableDims.append(originalVecType.getScalableDims().begin(),
+ originalVecType.getScalableDims().end());
+ VectorType newVecType = VectorType::get(
+ newShape, originalVecType.getElementType(), newScalableDims);
return builder.create<vector::BroadcastOp>(loc, newVecType, vec);
}