Diffstat (limited to 'mlir/lib/Dialect')
-rw-r--r--  mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp                              68
-rw-r--r--  mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp   16
-rw-r--r--  mlir/lib/Dialect/Tosa/IR/TosaOps.cpp                                     159
-rw-r--r--  mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp                32
-rw-r--r--  mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp                        2
-rw-r--r--  mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp                       25
-rw-r--r--  mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp                          6
-rw-r--r--  mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp                               25
8 files changed, 262 insertions, 71 deletions
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 61166db..585b6da 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -360,45 +360,53 @@ LogicalResult ScaledExtPacked816Op::verify() {
 //===----------------------------------------------------------------------===//
 // WMMAOp
 //===----------------------------------------------------------------------===//
-LogicalResult WMMAOp::verify() {
-  Type sourceAType = getSourceA().getType();
-  Type sourceBType = getSourceB().getType();
-  Type destType = getDestC().getType();
-  VectorType sourceVectorAType = dyn_cast<VectorType>(sourceAType);
-  VectorType sourceVectorBType = dyn_cast<VectorType>(sourceBType);
-  VectorType destVectorType = dyn_cast<VectorType>(destType);
+ParseResult mlir::amdgpu::parseMNKDimensionList(OpAsmParser &parser,
+                                                IntegerAttr &m, IntegerAttr &n,
+                                                IntegerAttr &k) {
+  SmallVector<int64_t, 3> dimensions;
+  if (parser.parseDimensionList(dimensions, false, false))
+    return failure();
+  if (dimensions.size() != 3)
+    return parser.emitError(parser.getCurrentLocation())
+           << "expected 3 dimensions in MNK dimension list";
 
-  Type sourceAElemType = sourceVectorAType.getElementType();
-  Type sourceBElemType = sourceVectorBType.getElementType();
-  Type destElemType = destVectorType.getElementType();
+  m = parser.getBuilder().getI32IntegerAttr(dimensions[0]);
+  n = parser.getBuilder().getI32IntegerAttr(dimensions[1]);
+  k = parser.getBuilder().getI32IntegerAttr(dimensions[2]);
+  return success();
+}
 
-  if (sourceVectorAType.getNumElements() !=
-      sourceVectorBType.getNumElements()) {
+LogicalResult WMMAOp::verify() {
+  auto sourceAType = cast<VectorType>(getSourceA().getType());
+  auto sourceBType = cast<VectorType>(getSourceB().getType());
+  auto destType = cast<VectorType>(getDestC().getType());
+
+  Type sourceAElemType = sourceAType.getElementType();
+  Type sourceBElemType = sourceBType.getElementType();
+  if (sourceAType.getNumElements() != sourceBType.getNumElements()) {
     return emitOpError("source vectors have different lengths: ")
-           << sourceVectorAType << " vs. " << sourceVectorBType;
+           << sourceAType << " vs. " << sourceBType;
   }
-  bool isDestFloat = isa<Float32Type, Float16Type, BFloat16Type>(destElemType);
-  bool isSrcFloat =
-      isa<Float16Type, BFloat16Type, Float8E4M3FNType, Float8E5M2Type>(
-          sourceAElemType);
-
-  if (isDestFloat && !isSrcFloat) {
-    return emitOpError("Expected float sources with float destination");
-  }
+  bool isDestFloat = destType.getElementType().isFloat();
+  bool isSrcFloat = sourceAElemType.isFloat();
 
-  if (!isDestFloat && isSrcFloat) {
-    return emitOpError("Expected int sources with int destination");
-  }
+  if (isDestFloat && !isSrcFloat)
+    return emitOpError("expected float sources with float destination");
+  if (!isDestFloat && isSrcFloat)
+    return emitOpError("expected int sources with int destination");
 
-  if (sourceAElemType != sourceBElemType &&
-      !(isa<Float8E5M2Type, Float8E4M3FNType>(sourceAElemType) &&
-        isa<Float8E5M2Type, Float8E4M3FNType>(sourceBElemType))) {
+  if (!sourceAElemType.isFloat(8) && sourceAElemType != sourceBElemType) {
     return emitOpError(
                "source element types much match (except for fp8) but have ")
            << sourceAType << " and " << sourceBType;
   }
+
+  if (!sourceAElemType.isInteger(4) && getK() != 16) {
+    return emitOpError("K dimension must be 16 for source element type ")
+           << sourceAElemType;
+  }
   return success();
 }
@@ -414,11 +422,11 @@ LogicalResult MFMAOp::verify() {
   Type sourceElem = sourceType, destElem = destType;
   uint32_t sourceLen = 1, destLen = 1;
-  if (auto sourceVector = llvm::dyn_cast<VectorType>(sourceType)) {
+  if (auto sourceVector = dyn_cast<VectorType>(sourceType)) {
     sourceLen = sourceVector.getNumElements();
     sourceElem = sourceVector.getElementType();
   }
-  if (auto destVector = llvm::dyn_cast<VectorType>(destType)) {
+  if (auto destVector = dyn_cast<VectorType>(destType)) {
     destLen = destVector.getNumElements();
     destElem = destVector.getElementType();
   }
@@ -443,7 +451,7 @@ LogicalResult MFMAOp::verify() {
     return emitOpError("expected both non-small-float source operand types "
                        "to match exactly");
   }
-  // Normalize the wider integer types the compiler expects to i8
+  // Normalize the wider integer types the compiler expects to i8.
   if (sourceElem.isInteger(32)) {
     sourceLen *= 4;
     sourceElem = b.getI8Type();
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp
index 73e0f3d..f53d272 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp
@@ -159,14 +159,22 @@ IterationGraphSorter::IterationGraphSorter(
       loop2OutLvl(loop2OutLvl), iterTypes(std::move(iterTypes)),
       strategy(strategy) {
   // One map per tensor.
-  assert(loop2InsLvl.size() == ins.size());
+  assert(this->loop2InsLvl.size() == this->ins.size());
   // All the affine maps have the same number of dimensions (loops).
   assert(llvm::all_equal(llvm::map_range(
-      loop2InsLvl, [](AffineMap m) { return m.getNumDims(); })));
+      this->loop2InsLvl, [](AffineMap m) { return m.getNumDims(); })));
   // The number of results of the map should match the rank of the tensor.
-  assert(llvm::all_of(llvm::zip(loop2InsLvl, ins), [](auto mvPair) {
+  assert(llvm::all_of(llvm::zip(this->loop2InsLvl, this->ins), [](auto mvPair) {
     auto [m, v] = mvPair;
-    return m.getNumResults() == cast<ShapedType>(v.getType()).getRank();
+
+    // For ranked types the rank must match.
+    // Simply return true for UnrankedTensorType.
+    if (auto shapedType = llvm::dyn_cast<ShapedType>(v.getType())) {
+      return !shapedType.hasRank() ||
+             (m.getNumResults() == shapedType.getRank());
+    }
+    // Non-shaped (scalar) types behave like rank-0.
+    return m.getNumResults() == 0;
   }));
 
   itGraph.resize(getNumLoops(), std::vector<bool>(getNumLoops(), false));
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 6cd0eae..0aff67f 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -370,7 +370,7 @@ ParseResult parseWithEnumHandling(OpAsmParser &parser, OperationState &result) {
                              result.operands)))
     return failure();
 
-  result.addTypes(fnTy.getResult(0));
+  result.addTypes(fnTy.getResults());
   result.addAttributes(attrs);
 
   return success();
@@ -532,6 +532,24 @@ void MatmulTBlockScaledOp::print(OpAsmPrinter &parser) {
   printWithEnumHandling(parser, *this);
 }
 
+ParseResult CastFromBlockScaledOp::parse(OpAsmParser &parser,
+                                         OperationState &result) {
+  return parseWithEnumHandling<tosa::BlockSize>(parser, result);
+}
+
+void CastFromBlockScaledOp::print(OpAsmPrinter &parser) {
+  printWithEnumHandling(parser, *this);
+}
+
+ParseResult CastToBlockScaledOp::parse(OpAsmParser &parser,
+                                       OperationState &result) {
+  return parseWithEnumHandling<tosa::BlockSize>(parser, result);
+}
+
+void CastToBlockScaledOp::print(OpAsmPrinter &parser) {
+  printWithEnumHandling(parser, *this);
+}
+
 //===----------------------------------------------------------------------===//
 // Tosa utilities.
 //===----------------------------------------------------------------------===//
@@ -3944,6 +3962,145 @@ LogicalResult RescaleOp::inferReturnTypeComponents(
   return success();
 }
 
+LogicalResult CastFromBlockScaledOp::inferReturnTypeComponents(
+    MLIRContext *context, ::std::optional<Location> location,
+    CastFromBlockScaledOp::Adaptor adaptor,
+    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+  const ShapeAdaptor inputShape(adaptor.getInputData().getType());
+  inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
+  return success();
+}
+
+LogicalResult CastFromBlockScaledOp::verify() {
+  const Type inputDataType = getInputData().getType();
+  const Type outputDataType = getResult().getType();
+  if (failed(verifyCompatibleShape(inputDataType, outputDataType)))
+    return emitOpError() << "require compatible shapes for input_data ("
+                         << inputDataType << ") and "
+                         << "output_data (" << outputDataType << ")";
+
+  const ShapeAdaptor inputDataShape = ShapeAdaptor(inputDataType);
+
+  if (inputDataShape.hasRank()) {
+    const unsigned int blockSize =
+        BlockSizeAttr::getBlockSizeValue(getBlockSize());
+    const int64_t inputDataLastDim =
+        inputDataShape.getDimSize(inputDataShape.getRank() - 1);
+    if (inputDataLastDim % blockSize != 0)
+      return emitOpError() << "expect last dimension of input_data ("
+                           << inputDataLastDim
+                           << ") to be divisible by block_size (" << blockSize
+                           << ")";
+
+    const Type inputScaleType = getInputScale().getType();
+    const ShapeAdaptor inputScaleShape = ShapeAdaptor(inputScaleType);
+
+    if (inputScaleShape.hasRank()) {
+      SmallVector<int64_t> inputDataDims, inputScaleDims;
+      inputDataShape.getDims(inputDataDims);
+      inputScaleShape.getDims(inputScaleDims);
+
+      if (inputDataDims.size() != inputScaleDims.size() ||
+          failed(verifyCompatibleShape(
+              ArrayRef<int64_t>(inputDataDims).drop_back(1),
+              ArrayRef<int64_t>(inputScaleDims).drop_back(1))))
+        return emitOpError() << "require compatible shapes for input_data ("
+                             << inputDataType << ") and "
+                             << "input_scale (" << inputScaleType
+                             << ") except for the last dimension";
+
+      const SmallVector<int64_t, 2> dimsToCheck{inputDataLastDim / blockSize,
+                                                inputScaleDims.back()};
+      if (ShapedType::isStatic(inputDataLastDim) &&
+          failed(verifyCompatibleDims(dimsToCheck)))
+        return emitOpError()
+               << "expect last dimension of input_scale ("
+               << inputScaleDims.back()
+               << ") to be equal to last dimension of input_data / block_size ("
+               << inputDataDims.back() / blockSize << ")";
+    }
+  }
+
+  return success();
+}
+
+LogicalResult CastToBlockScaledOp::inferReturnTypeComponents(
+    MLIRContext *context, ::std::optional<Location> location,
+    CastToBlockScaledOp::Adaptor adaptor,
+    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+  const ShapeAdaptor inputShape(adaptor.getInputData().getType());
+  inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
+  if (!inputShape.hasRank())
+    return success();
+
+  // Calculate output_scale shape if ranked input provided
+  SmallVector<int64_t> outputScaleShape;
+  inputShape.getDims(outputScaleShape);
+  const int64_t lastDimLoc = inputShape.getRank() - 1;
+  const int64_t lastDimSize = inputShape.getDimSize(lastDimLoc);
+  if (ShapedType::isStatic(lastDimSize)) {
+    const unsigned int blockSize =
+        BlockSizeAttr::getBlockSizeValue(adaptor.getBlockSize());
+    outputScaleShape[lastDimLoc] = lastDimSize / blockSize;
+  }
+  inferredReturnShapes.push_back(ShapedTypeComponents(outputScaleShape));
+  return success();
+}
+
+LogicalResult CastToBlockScaledOp::verify() {
+  const Type inputDataType = getInputData().getType();
+  const Type outputDataType = getResult(0).getType();
+  if (failed(verifyCompatibleShape(inputDataType, outputDataType)))
+    return emitOpError() << "require compatible shapes for input_data ("
+                         << inputDataType << ") and "
+                         << "output_data (" << outputDataType << ")";
+
+  const unsigned int blockSize =
+      BlockSizeAttr::getBlockSizeValue(getBlockSize());
+  const ShapeAdaptor inputDataShape = ShapeAdaptor(inputDataType);
+  if (inputDataShape.hasRank()) {
+    const int64_t inputDataLastDim =
+        inputDataShape.getDimSize(inputDataShape.getRank() - 1);
+    if (ShapedType::isStatic(inputDataLastDim) &&
+        inputDataLastDim % blockSize != 0)
+      return emitOpError() << "expect last dimension of input_data ("
+                           << inputDataLastDim
+                           << ") to be divisible by block_size (" << blockSize
+                           << ")";
+  }
+
+  const ShapeAdaptor outputDataShape = ShapeAdaptor(outputDataType);
+  const Type outputScaleType = getResult(1).getType();
+  const ShapeAdaptor outputScaleShape = ShapeAdaptor(outputScaleType);
+  if (outputDataShape.hasRank() && outputScaleShape.hasRank()) {
+    SmallVector<int64_t> outputDataDims, outputScaleDims;
+    outputDataShape.getDims(outputDataDims);
+    outputScaleShape.getDims(outputScaleDims);
+
+    if (outputDataDims.size() != outputScaleDims.size() ||
+        failed(verifyCompatibleShape(
+            ArrayRef<int64_t>(outputDataDims).drop_back(1),
+            ArrayRef<int64_t>(outputScaleDims).drop_back(1))))
+      return emitOpError() << "require compatible shapes for output_data ("
+                           << outputDataType << ") and "
+                           << "output_scale (" << outputScaleType
+                           << ") except for the last dimension";
+
+    const int64_t outputDataLastDim = outputDataDims.back();
+    const SmallVector<int64_t, 2> dimsToCheck{outputDataLastDim / blockSize,
+                                              outputScaleDims.back()};
+    if (ShapedType::isStatic(outputDataLastDim) &&
+        failed(verifyCompatibleDims(dimsToCheck)))
+      return emitOpError()
+             << "expect last dimension of output_scale ("
+             << outputScaleDims.back()
+             << ") to be equal to last dimension of output_data / block_size ("
+             << outputDataDims.back() / blockSize << ")";
+  }
+
+  return success();
+}
+
 LogicalResult IfOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
     IfOp::Adaptor adaptor,
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
index 53afc5d..ab363ee 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
@@ -51,10 +51,11 @@ TosaProfileCompliance::getProfileComplianceMap() {
 
 // Base populating function
 LogicalResult ProfileInfoDepot::populateProfileInfo(ValueRange operands,
-                                                    Value output) {
-  for (auto operand : operands)
+                                                    ValueRange results) {
+  for (const auto &operand : operands)
     addValue(operand);
-  addValue(output);
+  for (const auto &result : results)
+    addValue(result);
   return success();
 }
 
@@ -177,23 +178,6 @@ LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ResizeOp op) {
 }
 
 template <>
-LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::FFT2dOp op) {
-  addValue(op.getInputReal());
-  addValue(op.getInputImag());
-  addValue(op.getOutputReal());
-  addValue(op.getOutputImag());
-  return success();
-}
-
-template <>
-LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::RFFT2dOp op) {
-  addValue(op.getInputReal());
-  addValue(op.getOutputReal());
-  addValue(op.getOutputImag());
-  return success();
-}
-
-template <>
 LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::SelectOp op) {
   addValue(op.getOnTrue());
   addValue(op.getOnFalse());
@@ -246,7 +230,7 @@ LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) {
 // This helper function populates the info for all operands.
 #define POPULATE_PROFILE_INFO_COMMON(tosaOp)                                   \
   if (isa<tosa::tosaOp##Op>(op)) {                                             \
-    return populateProfileInfo(op->getOperands(), op->getResult(0));           \
+    return populateProfileInfo(op->getOperands(), op->getResults());           \
   }
 
   // Skip irrelevant operands when they are independent and not tied to any
@@ -257,8 +241,6 @@ LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) {
   POPULATE_PROFILE_INFO_CUSTOM(Conv3D)
   POPULATE_PROFILE_INFO_CUSTOM(DepthwiseConv2D)
   POPULATE_PROFILE_INFO_CUSTOM(Mul)
-  POPULATE_PROFILE_INFO_CUSTOM(FFT2d)
-  POPULATE_PROFILE_INFO_CUSTOM(RFFT2d)
   POPULATE_PROFILE_INFO_CUSTOM(Concat)
   POPULATE_PROFILE_INFO_CUSTOM(Pad)
   POPULATE_PROFILE_INFO_CUSTOM(Reshape)
@@ -277,7 +259,11 @@ LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) {
   // For the most of tosa operators, all operands are profile/extension related
   // and hence are all considered in this profile-based compilance check.
   POPULATE_PROFILE_INFO_COMMON(MatmulTBlockScaled)
+  POPULATE_PROFILE_INFO_COMMON(FFT2d)
+  POPULATE_PROFILE_INFO_COMMON(RFFT2d)
   POPULATE_PROFILE_INFO_COMMON(Cast)
+  POPULATE_PROFILE_INFO_COMMON(CastFromBlockScaled)
+  POPULATE_PROFILE_INFO_COMMON(CastToBlockScaled)
   POPULATE_PROFILE_INFO_COMMON(Const)
   POPULATE_PROFILE_INFO_COMMON(ArgMax)
   POPULATE_PROFILE_INFO_COMMON(Sub)
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index cb544ad..4d0b61a 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -635,6 +635,8 @@ LogicalResult TosaValidation::levelCheckRanksAndSizes(Operation *op) {
   CHECK_RANKS_AND_SIZES(Transpose);
   // Type Conversion
   CHECK_RANKS_AND_SIZES(Cast);
+  CHECK_RANKS_AND_SIZES(CastFromBlockScaled);
+  CHECK_RANKS_AND_SIZES(CastToBlockScaled);
   CHECK_RANKS_AND_SIZES(Rescale);
   // Control Flow Operators
   CHECK_RANKS_AND_SIZES(If);
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
index f77784a..2c37140 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
@@ -145,8 +145,26 @@ XeGPUBlockingPass::getTileShape(const T &operandOrResult) const {
   xegpu::DistributeLayoutAttr layout =
       xegpu::getDistributeLayoutAttr(operandOrResult);
   if (layout && layout.isForSubgroup()) {
-    if (!layout.getEffectiveInstDataAsInt().empty())
-      return layout.getEffectiveInstDataAsInt();
+    if (!layout.getEffectiveInstDataAsInt().empty()) {
+      SmallVector<int64_t> instData = layout.getEffectiveInstDataAsInt();
+      // Remove leading unit dimensions from inst_data.
+      // For example, if the inst_data is [1, 1, 32]
+      // it will pass [32] as the unroll/blocking size.
+      // Skip it for xegpu nd ops since it will be 2D.
+      // TODO: For the vector ops, experiment with the
+      // upstream vector remove leading unit dims patterns,
+      // populateCastAwayVectorLeadingOneDimPatterns.
+      Operation *definingOp = value.getDefiningOp();
+      bool skipLeadingUnitDimRemoval =
+          definingOp &&
+          (isa<xegpu::CreateNdDescOp, xegpu::LoadNdOp, xegpu::DpasOp,
+               xegpu::StoreNdOp, xegpu::PrefetchNdOp>(definingOp));
+      if (!skipLeadingUnitDimRemoval) {
+        auto it = llvm::find_if(instData, [](auto val) { return val != 1; });
+        instData.erase(instData.begin(), it);
+      }
+      return instData;
+    }
 
     if (auto type = dyn_cast<ShapedType>(value.getType()))
       return llvm::to_vector(type.getShape());
@@ -354,7 +372,6 @@ void XeGPUBlockingPass::runOnOperation() {
         // To create a new attribute with a different chunk_size:
         auto newEncoding = xegpu::ScatterTensorDescAttr::get(
             ctx, tdescTy.getMemorySpace(), blockedChunkSize);
-
         encoding = newEncoding;
       }
     }
@@ -363,7 +380,7 @@ void XeGPUBlockingPass::runOnOperation() {
       newTy = xegpu::TensorDescType::get(ctx, tileShape, elemTy, encoding,
                                          tdescTy.getLayoutAttr().dropInstData());
     } else {
-      newTy = type.clone(tileShape, elemTy);
+      newTy = VectorType::get(tileShape, elemTy);
     }
 
     if (returnSingleType)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index aafa1b7..e6e71cc 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -66,8 +66,6 @@ protected:
   Value unpack(ValueRange srcs, Type destTy, ArrayRef<int64_t> blockSize,
                Location loc, PatternRewriter &rewriter) const {
     if (auto vecTy = dyn_cast<VectorType>(destTy)) {
-      assert(vecTy.getRank() == static_cast<int64_t>(blockSize.size()) &&
-             "Expecting blockSize size to match the rank of destTy.");
       auto shape = vecTy.getShape();
       return xegpu::createVectorWithShapeFromValues(rewriter, loc, srcs, shape);
     }
@@ -93,8 +91,6 @@ protected:
                                  ArrayRef<int64_t> blockSize, Location loc,
                                  PatternRewriter &rewriter) const {
     if (auto vecTy = dyn_cast<VectorType>(src.getType())) {
-      assert(vecTy.getRank() == static_cast<int64_t>(blockSize.size()) &&
-             "Expecting blockSize size to match the rank of src.");
       return xegpu::extractVectorsWithShapeFromValue(rewriter, loc, src,
                                                      blockSize);
     }
@@ -635,7 +631,7 @@ struct UnrollLoadGatherOpWithOffset
     VectorType maskTy = llvm::dyn_cast<VectorType>(mask.getType());
     VectorType offsetsTy = llvm::dyn_cast<VectorType>(offsets.getType());
     Type elemTy = valueTy.getElementType();
-    VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy);
+    VectorType newValueTy = VectorType::get(*targetShape, elemTy);
 
     SmallVector<Type> convertedMaskTypes;
     SmallVector<Value> convertedMasks;
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
index 2c56a43..b4605cd 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -246,11 +246,28 @@ xegpu::extractVectorsWithShapeFromValue(OpBuilder &builder, Location loc,
   if (!computeShapeRatio(srcShape, shape))
     return {value};
 
+  int64_t srcShapeRank = srcShape.size();
+  int64_t targetShapeRank = shape.size();
+
+  SmallVector<int64_t> adjustedTargetShape(srcShape.size());
+  int64_t rankDiff = srcShapeRank - targetShapeRank;
+  std::fill(adjustedTargetShape.begin(), adjustedTargetShape.begin() + rankDiff,
+            1);
+  std::copy(shape.begin(), shape.end(), adjustedTargetShape.begin() + rankDiff);
+
   SmallVector<Value> result;
-  for (SmallVector<int64_t> offsets : StaticTileOffsetRange(srcShape, shape)) {
+  for (SmallVector<int64_t> offsets :
+       StaticTileOffsetRange(srcShape, adjustedTargetShape)) {
     SmallVector<int64_t> staticStrides(offsets.size(), 1);
-    result.push_back(vector::ExtractStridedSliceOp::create(
-        builder, loc, value, offsets, shape, staticStrides));
+    Value slice = vector::ExtractStridedSliceOp::create(
+        builder, loc, value, offsets, adjustedTargetShape, staticStrides);
+
+    // Reshape to remove leading unit dims if needed
+    if (srcShapeRank > targetShapeRank) {
+      auto targetTy = VectorType::get(shape, vecTy.getElementType());
+      slice = vector::ShapeCastOp::create(builder, loc, targetTy, slice);
+    }
+    result.push_back(slice);
   }
 
   return result;
@@ -274,7 +291,7 @@ Value xegpu::createVectorWithShapeFromValues(OpBuilder &builder, Location loc,
 
   for (auto [src, offsets] :
        llvm::zip_equal(values, StaticTileOffsetRange(shape, tileShape))) {
-    SmallVector<int64_t> staticStrides(offsets.size(), 1);
+    SmallVector<int64_t> staticStrides(tileShape.size(), 1);
     result = vector::InsertStridedSliceOp::create(builder, loc, src, result,
                                                   offsets, staticStrides);
   }
