diff options
Diffstat (limited to 'mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp')
| -rw-r--r-- | mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp | 25 |
1 files changed, 21 insertions, 4 deletions
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp index 2c56a43..b4605cd 100644 --- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp +++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp @@ -246,11 +246,28 @@ xegpu::extractVectorsWithShapeFromValue(OpBuilder &builder, Location loc, if (!computeShapeRatio(srcShape, shape)) return {value}; + int64_t srcShapeRank = srcShape.size(); + int64_t targetShapeRank = shape.size(); + + SmallVector<int64_t> adjustedTargetShape(srcShape.size()); + int64_t rankDiff = srcShapeRank - targetShapeRank; + std::fill(adjustedTargetShape.begin(), adjustedTargetShape.begin() + rankDiff, + 1); + std::copy(shape.begin(), shape.end(), adjustedTargetShape.begin() + rankDiff); + SmallVector<Value> result; - for (SmallVector<int64_t> offsets : StaticTileOffsetRange(srcShape, shape)) { + for (SmallVector<int64_t> offsets : + StaticTileOffsetRange(srcShape, adjustedTargetShape)) { SmallVector<int64_t> staticStrides(offsets.size(), 1); - result.push_back(vector::ExtractStridedSliceOp::create( - builder, loc, value, offsets, shape, staticStrides)); + Value slice = vector::ExtractStridedSliceOp::create( + builder, loc, value, offsets, adjustedTargetShape, staticStrides); + + // Reshape to remove leading unit dims if needed + if (srcShapeRank > targetShapeRank) { + auto targetTy = VectorType::get(shape, vecTy.getElementType()); + slice = vector::ShapeCastOp::create(builder, loc, targetTy, slice); + } + result.push_back(slice); } return result; @@ -274,7 +291,7 @@ Value xegpu::createVectorWithShapeFromValues(OpBuilder &builder, Location loc, for (auto [src, offsets] : llvm::zip_equal(values, StaticTileOffsetRange(shape, tileShape))) { - SmallVector<int64_t> staticStrides(offsets.size(), 1); + SmallVector<int64_t> staticStrides(tileShape.size(), 1); result = vector::InsertStridedSliceOp::create(builder, loc, src, result, offsets, staticStrides); } |
