aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Dialect
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Dialect')
-rw-r--r--mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp68
-rw-r--r--mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp16
-rw-r--r--mlir/lib/Dialect/Tosa/IR/TosaOps.cpp159
-rw-r--r--mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp32
-rw-r--r--mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp2
-rw-r--r--mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp25
-rw-r--r--mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp6
-rw-r--r--mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp25
8 files changed, 262 insertions, 71 deletions
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 61166db..585b6da 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -360,45 +360,53 @@ LogicalResult ScaledExtPacked816Op::verify() {
//===----------------------------------------------------------------------===//
// WMMAOp
//===----------------------------------------------------------------------===//
-LogicalResult WMMAOp::verify() {
- Type sourceAType = getSourceA().getType();
- Type sourceBType = getSourceB().getType();
- Type destType = getDestC().getType();
- VectorType sourceVectorAType = dyn_cast<VectorType>(sourceAType);
- VectorType sourceVectorBType = dyn_cast<VectorType>(sourceBType);
- VectorType destVectorType = dyn_cast<VectorType>(destType);
+ParseResult mlir::amdgpu::parseMNKDimensionList(OpAsmParser &parser,
+ IntegerAttr &m, IntegerAttr &n,
+ IntegerAttr &k) {
+ SmallVector<int64_t, 3> dimensions;
+ if (parser.parseDimensionList(dimensions, false, false))
+ return failure();
+ if (dimensions.size() != 3)
+ return parser.emitError(parser.getCurrentLocation())
+ << "expected 3 dimensions in MNK dimension list";
- Type sourceAElemType = sourceVectorAType.getElementType();
- Type sourceBElemType = sourceVectorBType.getElementType();
- Type destElemType = destVectorType.getElementType();
+ m = parser.getBuilder().getI32IntegerAttr(dimensions[0]);
+ n = parser.getBuilder().getI32IntegerAttr(dimensions[1]);
+ k = parser.getBuilder().getI32IntegerAttr(dimensions[2]);
+ return success();
+}
- if (sourceVectorAType.getNumElements() !=
- sourceVectorBType.getNumElements()) {
+LogicalResult WMMAOp::verify() {
+ auto sourceAType = cast<VectorType>(getSourceA().getType());
+ auto sourceBType = cast<VectorType>(getSourceB().getType());
+ auto destType = cast<VectorType>(getDestC().getType());
+
+ Type sourceAElemType = sourceAType.getElementType();
+ Type sourceBElemType = sourceBType.getElementType();
+ if (sourceAType.getNumElements() != sourceBType.getNumElements()) {
return emitOpError("source vectors have different lengths: ")
- << sourceVectorAType << " vs. " << sourceVectorBType;
+ << sourceAType << " vs. " << sourceBType;
}
- bool isDestFloat = isa<Float32Type, Float16Type, BFloat16Type>(destElemType);
- bool isSrcFloat =
- isa<Float16Type, BFloat16Type, Float8E4M3FNType, Float8E5M2Type>(
- sourceAElemType);
-
- if (isDestFloat && !isSrcFloat) {
- return emitOpError("Expected float sources with float destination");
- }
+ bool isDestFloat = destType.getElementType().isFloat();
+ bool isSrcFloat = sourceAElemType.isFloat();
- if (!isDestFloat && isSrcFloat) {
- return emitOpError("Expected int sources with int destination");
- }
+ if (isDestFloat && !isSrcFloat)
+ return emitOpError("expected float sources with float destination");
+ if (!isDestFloat && isSrcFloat)
+ return emitOpError("expected int sources with int destination");
- if (sourceAElemType != sourceBElemType &&
- !(isa<Float8E5M2Type, Float8E4M3FNType>(sourceAElemType) &&
- isa<Float8E5M2Type, Float8E4M3FNType>(sourceBElemType))) {
+ if (!sourceAElemType.isFloat(8) && sourceAElemType != sourceBElemType) {
return emitOpError(
"source element types much match (except for fp8) but have ")
<< sourceAType << " and " << sourceBType;
}
+
+ if (!sourceAElemType.isInteger(4) && getK() != 16) {
+ return emitOpError("K dimension must be 16 for source element type ")
+ << sourceAElemType;
+ }
return success();
}
@@ -414,11 +422,11 @@ LogicalResult MFMAOp::verify() {
Type sourceElem = sourceType, destElem = destType;
uint32_t sourceLen = 1, destLen = 1;
- if (auto sourceVector = llvm::dyn_cast<VectorType>(sourceType)) {
+ if (auto sourceVector = dyn_cast<VectorType>(sourceType)) {
sourceLen = sourceVector.getNumElements();
sourceElem = sourceVector.getElementType();
}
- if (auto destVector = llvm::dyn_cast<VectorType>(destType)) {
+ if (auto destVector = dyn_cast<VectorType>(destType)) {
destLen = destVector.getNumElements();
destElem = destVector.getElementType();
}
@@ -443,7 +451,7 @@ LogicalResult MFMAOp::verify() {
return emitOpError("expected both non-small-float source operand types "
"to match exactly");
}
- // Normalize the wider integer types the compiler expects to i8
+ // Normalize the wider integer types the compiler expects to i8.
if (sourceElem.isInteger(32)) {
sourceLen *= 4;
sourceElem = b.getI8Type();
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp
index 73e0f3d..f53d272 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp
@@ -159,14 +159,22 @@ IterationGraphSorter::IterationGraphSorter(
loop2OutLvl(loop2OutLvl), iterTypes(std::move(iterTypes)),
strategy(strategy) {
// One map per tensor.
- assert(loop2InsLvl.size() == ins.size());
+ assert(this->loop2InsLvl.size() == this->ins.size());
// All the affine maps have the same number of dimensions (loops).
assert(llvm::all_equal(llvm::map_range(
- loop2InsLvl, [](AffineMap m) { return m.getNumDims(); })));
+ this->loop2InsLvl, [](AffineMap m) { return m.getNumDims(); })));
// The number of results of the map should match the rank of the tensor.
- assert(llvm::all_of(llvm::zip(loop2InsLvl, ins), [](auto mvPair) {
+ assert(llvm::all_of(llvm::zip(this->loop2InsLvl, this->ins), [](auto mvPair) {
auto [m, v] = mvPair;
- return m.getNumResults() == cast<ShapedType>(v.getType()).getRank();
+
+ // For ranked types the rank must match.
+ // Simply return true for UnrankedTensorType
+ if (auto shapedType = llvm::dyn_cast<ShapedType>(v.getType())) {
+ return !shapedType.hasRank() ||
+ (m.getNumResults() == shapedType.getRank());
+ }
+ // Non-shaped (scalar) types behave like rank-0.
+ return m.getNumResults() == 0;
}));
itGraph.resize(getNumLoops(), std::vector<bool>(getNumLoops(), false));
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 6cd0eae..0aff67f 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -370,7 +370,7 @@ ParseResult parseWithEnumHandling(OpAsmParser &parser, OperationState &result) {
result.operands)))
return failure();
- result.addTypes(fnTy.getResult(0));
+ result.addTypes(fnTy.getResults());
result.addAttributes(attrs);
return success();
@@ -532,6 +532,24 @@ void MatmulTBlockScaledOp::print(OpAsmPrinter &parser) {
printWithEnumHandling(parser, *this);
}
+ParseResult CastFromBlockScaledOp::parse(OpAsmParser &parser,
+ OperationState &result) {
+ return parseWithEnumHandling<tosa::BlockSize>(parser, result);
+}
+
+void CastFromBlockScaledOp::print(OpAsmPrinter &parser) {
+ printWithEnumHandling(parser, *this);
+}
+
+ParseResult CastToBlockScaledOp::parse(OpAsmParser &parser,
+ OperationState &result) {
+ return parseWithEnumHandling<tosa::BlockSize>(parser, result);
+}
+
+void CastToBlockScaledOp::print(OpAsmPrinter &parser) {
+ printWithEnumHandling(parser, *this);
+}
+
//===----------------------------------------------------------------------===//
// Tosa utilities.
//===----------------------------------------------------------------------===//
@@ -3944,6 +3962,145 @@ LogicalResult RescaleOp::inferReturnTypeComponents(
return success();
}
+LogicalResult CastFromBlockScaledOp::inferReturnTypeComponents(
+ MLIRContext *context, ::std::optional<Location> location,
+ CastFromBlockScaledOp::Adaptor adaptor,
+ SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+ const ShapeAdaptor inputShape(adaptor.getInputData().getType());
+ inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
+ return success();
+}
+
+LogicalResult CastFromBlockScaledOp::verify() {
+ const Type inputDataType = getInputData().getType();
+ const Type outputDataType = getResult().getType();
+ if (failed(verifyCompatibleShape(inputDataType, outputDataType)))
+ return emitOpError() << "require compatible shapes for input_data ("
+ << inputDataType << ") and "
+ << "output_data (" << outputDataType << ")";
+
+ const ShapeAdaptor inputDataShape = ShapeAdaptor(inputDataType);
+
+ if (inputDataShape.hasRank()) {
+ const unsigned int blockSize =
+ BlockSizeAttr::getBlockSizeValue(getBlockSize());
+ const int64_t inputDataLastDim =
+ inputDataShape.getDimSize(inputDataShape.getRank() - 1);
+ if (inputDataLastDim % blockSize != 0)
+ return emitOpError() << "expect last dimension of input_data ("
+ << inputDataLastDim
+ << ") to be divisible by block_size (" << blockSize
+ << ")";
+
+ const Type inputScaleType = getInputScale().getType();
+ const ShapeAdaptor inputScaleShape = ShapeAdaptor(inputScaleType);
+
+ if (inputScaleShape.hasRank()) {
+ SmallVector<int64_t> inputDataDims, inputScaleDims;
+ inputDataShape.getDims(inputDataDims);
+ inputScaleShape.getDims(inputScaleDims);
+
+ if (inputDataDims.size() != inputScaleDims.size() ||
+ failed(verifyCompatibleShape(
+ ArrayRef<int64_t>(inputDataDims).drop_back(1),
+ ArrayRef<int64_t>(inputScaleDims).drop_back(1))))
+ return emitOpError() << "require compatible shapes for input_data ("
+ << inputDataType << ") and "
+ << "input_scale (" << inputScaleType
+ << ") except for the last dimension";
+
+ const SmallVector<int64_t, 2> dimsToCheck{inputDataLastDim / blockSize,
+ inputScaleDims.back()};
+ if (ShapedType::isStatic(inputDataLastDim) &&
+ failed(verifyCompatibleDims(dimsToCheck)))
+ return emitOpError()
+ << "expect last dimension of input_scale ("
+ << inputScaleDims.back()
+ << ") to be equal to last dimension of input_data / block_size ("
+ << inputDataDims.back() / blockSize << ")";
+ }
+ }
+
+ return success();
+}
+
+LogicalResult CastToBlockScaledOp::inferReturnTypeComponents(
+ MLIRContext *context, ::std::optional<Location> location,
+ CastToBlockScaledOp::Adaptor adaptor,
+ SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+ const ShapeAdaptor inputShape(adaptor.getInputData().getType());
+ inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
+ if (!inputShape.hasRank())
+ return success();
+
+ // Calculate output_scale shape if ranked input provided
+ SmallVector<int64_t> outputScaleShape;
+ inputShape.getDims(outputScaleShape);
+ const int64_t lastDimLoc = inputShape.getRank() - 1;
+ const int64_t lastDimSize = inputShape.getDimSize(lastDimLoc);
+ if (ShapedType::isStatic(lastDimSize)) {
+ const unsigned int blockSize =
+ BlockSizeAttr::getBlockSizeValue(adaptor.getBlockSize());
+ outputScaleShape[lastDimLoc] = lastDimSize / blockSize;
+ }
+ inferredReturnShapes.push_back(ShapedTypeComponents(outputScaleShape));
+ return success();
+}
+
+LogicalResult CastToBlockScaledOp::verify() {
+ const Type inputDataType = getInputData().getType();
+ const Type outputDataType = getResult(0).getType();
+ if (failed(verifyCompatibleShape(inputDataType, outputDataType)))
+ return emitOpError() << "require compatible shapes for input_data ("
+ << inputDataType << ") and "
+ << "output_data (" << outputDataType << ")";
+
+ const unsigned int blockSize =
+ BlockSizeAttr::getBlockSizeValue(getBlockSize());
+ const ShapeAdaptor inputDataShape = ShapeAdaptor(inputDataType);
+ if (inputDataShape.hasRank()) {
+ const int64_t inputDataLastDim =
+ inputDataShape.getDimSize(inputDataShape.getRank() - 1);
+ if (ShapedType::isStatic(inputDataLastDim) &&
+ inputDataLastDim % blockSize != 0)
+ return emitOpError() << "expect last dimension of input_data ("
+ << inputDataLastDim
+ << ") to be divisible by block_size (" << blockSize
+ << ")";
+ }
+
+ const ShapeAdaptor outputDataShape = ShapeAdaptor(outputDataType);
+ const Type outputScaleType = getResult(1).getType();
+ const ShapeAdaptor outputScaleShape = ShapeAdaptor(outputScaleType);
+ if (outputDataShape.hasRank() && outputScaleShape.hasRank()) {
+ SmallVector<int64_t> outputDataDims, outputScaleDims;
+ outputDataShape.getDims(outputDataDims);
+ outputScaleShape.getDims(outputScaleDims);
+
+ if (outputDataDims.size() != outputScaleDims.size() ||
+ failed(verifyCompatibleShape(
+ ArrayRef<int64_t>(outputDataDims).drop_back(1),
+ ArrayRef<int64_t>(outputScaleDims).drop_back(1))))
+ return emitOpError() << "require compatible shapes for output_data ("
+ << outputDataType << ") and "
+ << "output_scale (" << outputScaleType
+ << ") except for the last dimension";
+
+ const int64_t outputDataLastDim = outputDataDims.back();
+ const SmallVector<int64_t, 2> dimsToCheck{outputDataLastDim / blockSize,
+ outputScaleDims.back()};
+ if (ShapedType::isStatic(outputDataLastDim) &&
+ failed(verifyCompatibleDims(dimsToCheck)))
+ return emitOpError()
+ << "expect last dimension of output_scale ("
+ << outputScaleDims.back()
+ << ") to be equal to last dimension of output_data / block_size ("
+ << outputDataDims.back() / blockSize << ")";
+ }
+
+ return success();
+}
+
LogicalResult IfOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
IfOp::Adaptor adaptor,
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
index 53afc5d..ab363ee 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
@@ -51,10 +51,11 @@ TosaProfileCompliance::getProfileComplianceMap() {
// Base populating function
LogicalResult ProfileInfoDepot::populateProfileInfo(ValueRange operands,
- Value output) {
- for (auto operand : operands)
+ ValueRange results) {
+ for (const auto &operand : operands)
addValue(operand);
- addValue(output);
+ for (const auto &result : results)
+ addValue(result);
return success();
}
@@ -177,23 +178,6 @@ LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::ResizeOp op) {
}
template <>
-LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::FFT2dOp op) {
- addValue(op.getInputReal());
- addValue(op.getInputImag());
- addValue(op.getOutputReal());
- addValue(op.getOutputImag());
- return success();
-}
-
-template <>
-LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::RFFT2dOp op) {
- addValue(op.getInputReal());
- addValue(op.getOutputReal());
- addValue(op.getOutputImag());
- return success();
-}
-
-template <>
LogicalResult ProfileInfoDepot::populateProfileInfo(tosa::SelectOp op) {
addValue(op.getOnTrue());
addValue(op.getOnFalse());
@@ -246,7 +230,7 @@ LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) {
// This helper function populates the info for all operands.
#define POPULATE_PROFILE_INFO_COMMON(tosaOp) \
if (isa<tosa::tosaOp##Op>(op)) { \
- return populateProfileInfo(op->getOperands(), op->getResult(0)); \
+ return populateProfileInfo(op->getOperands(), op->getResults()); \
}
// Skip irrelevant operands when they are independent and not tied to any
@@ -257,8 +241,6 @@ LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) {
POPULATE_PROFILE_INFO_CUSTOM(Conv3D)
POPULATE_PROFILE_INFO_CUSTOM(DepthwiseConv2D)
POPULATE_PROFILE_INFO_CUSTOM(Mul)
- POPULATE_PROFILE_INFO_CUSTOM(FFT2d)
- POPULATE_PROFILE_INFO_CUSTOM(RFFT2d)
POPULATE_PROFILE_INFO_CUSTOM(Concat)
POPULATE_PROFILE_INFO_CUSTOM(Pad)
POPULATE_PROFILE_INFO_CUSTOM(Reshape)
@@ -277,7 +259,11 @@ LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) {
// For the most of tosa operators, all operands are profile/extension related
// and hence are all considered in this profile-based compilance check.
POPULATE_PROFILE_INFO_COMMON(MatmulTBlockScaled)
+ POPULATE_PROFILE_INFO_COMMON(FFT2d)
+ POPULATE_PROFILE_INFO_COMMON(RFFT2d)
POPULATE_PROFILE_INFO_COMMON(Cast)
+ POPULATE_PROFILE_INFO_COMMON(CastFromBlockScaled)
+ POPULATE_PROFILE_INFO_COMMON(CastToBlockScaled)
POPULATE_PROFILE_INFO_COMMON(Const)
POPULATE_PROFILE_INFO_COMMON(ArgMax)
POPULATE_PROFILE_INFO_COMMON(Sub)
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index cb544ad..4d0b61a 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -635,6 +635,8 @@ LogicalResult TosaValidation::levelCheckRanksAndSizes(Operation *op) {
CHECK_RANKS_AND_SIZES(Transpose);
// Type Conversion
CHECK_RANKS_AND_SIZES(Cast);
+ CHECK_RANKS_AND_SIZES(CastFromBlockScaled);
+ CHECK_RANKS_AND_SIZES(CastToBlockScaled);
CHECK_RANKS_AND_SIZES(Rescale);
// Control Flow Operators
CHECK_RANKS_AND_SIZES(If);
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
index f77784a..2c37140 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
@@ -145,8 +145,26 @@ XeGPUBlockingPass::getTileShape(const T &operandOrResult) const {
xegpu::DistributeLayoutAttr layout =
xegpu::getDistributeLayoutAttr(operandOrResult);
if (layout && layout.isForSubgroup()) {
- if (!layout.getEffectiveInstDataAsInt().empty())
- return layout.getEffectiveInstDataAsInt();
+ if (!layout.getEffectiveInstDataAsInt().empty()) {
+ SmallVector<int64_t> instData = layout.getEffectiveInstDataAsInt();
+ // Remove leading unit dimensions from inst_data
+ // For example, if the inst_data is [1, 1, 32]
+ // it will pass [32] as the unroll/blocking size.
+ // Skip it for xegpu nd ops since it will be 2D
+ // TODO: For vectors ops, experiment with the
+ // upstream vector remove leading unit dims patterns,
+ // populateCastAwayVectorLeadingOneDimPatterns.
+ Operation *definingOp = value.getDefiningOp();
+ bool skipLeadingUnitDimRemoval =
+ definingOp &&
+ (isa<xegpu::CreateNdDescOp, xegpu::LoadNdOp, xegpu::DpasOp,
+ xegpu::StoreNdOp, xegpu::PrefetchNdOp>(definingOp));
+ if (!skipLeadingUnitDimRemoval) {
+ auto it = llvm::find_if(instData, [](auto val) { return val != 1; });
+ instData.erase(instData.begin(), it);
+ }
+ return instData;
+ }
if (auto type = dyn_cast<ShapedType>(value.getType()))
return llvm::to_vector(type.getShape());
@@ -354,7 +372,6 @@ void XeGPUBlockingPass::runOnOperation() {
// To create a new attribute with a different chunk_size:
auto newEncoding = xegpu::ScatterTensorDescAttr::get(
ctx, tdescTy.getMemorySpace(), blockedChunkSize);
-
encoding = newEncoding;
}
}
@@ -363,7 +380,7 @@ void XeGPUBlockingPass::runOnOperation() {
xegpu::TensorDescType::get(ctx, tileShape, elemTy, encoding,
tdescTy.getLayoutAttr().dropInstData());
} else {
- newTy = type.clone(tileShape, elemTy);
+ newTy = VectorType::get(tileShape, elemTy);
}
if (returnSingleType)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index aafa1b7..e6e71cc 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -66,8 +66,6 @@ protected:
Value unpack(ValueRange srcs, Type destTy, ArrayRef<int64_t> blockSize,
Location loc, PatternRewriter &rewriter) const {
if (auto vecTy = dyn_cast<VectorType>(destTy)) {
- assert(vecTy.getRank() == static_cast<int64_t>(blockSize.size()) &&
- "Expecting blockSize size to match the rank of destTy.");
auto shape = vecTy.getShape();
return xegpu::createVectorWithShapeFromValues(rewriter, loc, srcs, shape);
}
@@ -93,8 +91,6 @@ protected:
ArrayRef<int64_t> blockSize, Location loc,
PatternRewriter &rewriter) const {
if (auto vecTy = dyn_cast<VectorType>(src.getType())) {
- assert(vecTy.getRank() == static_cast<int64_t>(blockSize.size()) &&
- "Expecting blockSize size to match the rank of src.");
return xegpu::extractVectorsWithShapeFromValue(rewriter, loc, src,
blockSize);
}
@@ -635,7 +631,7 @@ struct UnrollLoadGatherOpWithOffset
VectorType maskTy = llvm::dyn_cast<VectorType>(mask.getType());
VectorType offsetsTy = llvm::dyn_cast<VectorType>(offsets.getType());
Type elemTy = valueTy.getElementType();
- VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy);
+ VectorType newValueTy = VectorType::get(*targetShape, elemTy);
SmallVector<Type> convertedMaskTypes;
SmallVector<Value> convertedMasks;
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
index 2c56a43..b4605cd 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -246,11 +246,28 @@ xegpu::extractVectorsWithShapeFromValue(OpBuilder &builder, Location loc,
if (!computeShapeRatio(srcShape, shape))
return {value};
+ int64_t srcShapeRank = srcShape.size();
+ int64_t targetShapeRank = shape.size();
+
+ SmallVector<int64_t> adjustedTargetShape(srcShape.size());
+ int64_t rankDiff = srcShapeRank - targetShapeRank;
+ std::fill(adjustedTargetShape.begin(), adjustedTargetShape.begin() + rankDiff,
+ 1);
+ std::copy(shape.begin(), shape.end(), adjustedTargetShape.begin() + rankDiff);
+
SmallVector<Value> result;
- for (SmallVector<int64_t> offsets : StaticTileOffsetRange(srcShape, shape)) {
+ for (SmallVector<int64_t> offsets :
+ StaticTileOffsetRange(srcShape, adjustedTargetShape)) {
SmallVector<int64_t> staticStrides(offsets.size(), 1);
- result.push_back(vector::ExtractStridedSliceOp::create(
- builder, loc, value, offsets, shape, staticStrides));
+ Value slice = vector::ExtractStridedSliceOp::create(
+ builder, loc, value, offsets, adjustedTargetShape, staticStrides);
+
+ // Reshape to remove leading unit dims if needed
+ if (srcShapeRank > targetShapeRank) {
+ auto targetTy = VectorType::get(shape, vecTy.getElementType());
+ slice = vector::ShapeCastOp::create(builder, loc, targetTy, slice);
+ }
+ result.push_back(slice);
}
return result;
@@ -274,7 +291,7 @@ Value xegpu::createVectorWithShapeFromValues(OpBuilder &builder, Location loc,
for (auto [src, offsets] :
llvm::zip_equal(values, StaticTileOffsetRange(shape, tileShape))) {
- SmallVector<int64_t> staticStrides(offsets.size(), 1);
+ SmallVector<int64_t> staticStrides(tileShape.size(), 1);
result = vector::InsertStridedSliceOp::create(builder, loc, src, result,
offsets, staticStrides);
}