aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp')
-rw-r--r--mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp25
1 files changed, 21 insertions, 4 deletions
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
index 2c56a43..b4605cd 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -246,11 +246,28 @@ xegpu::extractVectorsWithShapeFromValue(OpBuilder &builder, Location loc,
if (!computeShapeRatio(srcShape, shape))
return {value};
+ int64_t srcShapeRank = srcShape.size();
+ int64_t targetShapeRank = shape.size();
+
+ SmallVector<int64_t> adjustedTargetShape(srcShape.size());
+ int64_t rankDiff = srcShapeRank - targetShapeRank;
+ std::fill(adjustedTargetShape.begin(), adjustedTargetShape.begin() + rankDiff,
+ 1);
+ std::copy(shape.begin(), shape.end(), adjustedTargetShape.begin() + rankDiff);
+
SmallVector<Value> result;
- for (SmallVector<int64_t> offsets : StaticTileOffsetRange(srcShape, shape)) {
+ for (SmallVector<int64_t> offsets :
+ StaticTileOffsetRange(srcShape, adjustedTargetShape)) {
SmallVector<int64_t> staticStrides(offsets.size(), 1);
- result.push_back(vector::ExtractStridedSliceOp::create(
- builder, loc, value, offsets, shape, staticStrides));
+ Value slice = vector::ExtractStridedSliceOp::create(
+ builder, loc, value, offsets, adjustedTargetShape, staticStrides);
+
+ // Reshape to remove leading unit dims if needed
+ if (srcShapeRank > targetShapeRank) {
+ auto targetTy = VectorType::get(shape, vecTy.getElementType());
+ slice = vector::ShapeCastOp::create(builder, loc, targetTy, slice);
+ }
+ result.push_back(slice);
}
return result;
@@ -274,7 +291,7 @@ Value xegpu::createVectorWithShapeFromValues(OpBuilder &builder, Location loc,
for (auto [src, offsets] :
llvm::zip_equal(values, StaticTileOffsetRange(shape, tileShape))) {
- SmallVector<int64_t> staticStrides(offsets.size(), 1);
+ SmallVector<int64_t> staticStrides(tileShape.size(), 1);
result = vector::InsertStridedSliceOp::create(builder, loc, src, result,
offsets, staticStrides);
}