Diffstat (limited to 'mlir/lib/Conversion')
-rw-r--r--  mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp              |   3
-rw-r--r--  mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp  |   8
-rw-r--r--  mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp              |   2
-rw-r--r--  mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp          | 186
-rw-r--r--  mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp              |   2
5 files changed, 120 insertions(+), 81 deletions(-)
diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
index ba57155..03ed4d5 100644
--- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
+++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
@@ -240,8 +240,7 @@ struct CmpFOpLowering : public ConvertOpToLLVMPattern<arith::CmpFOp> {
struct SelectOpOneToNLowering : public ConvertOpToLLVMPattern<arith::SelectOp> {
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
- using Adaptor =
- typename ConvertOpToLLVMPattern<arith::SelectOp>::OneToNOpAdaptor;
+ using Adaptor = ConvertOpToLLVMPattern<arith::SelectOp>::OneToNOpAdaptor;
LogicalResult
matchAndRewrite(arith::SelectOp op, Adaptor adaptor,
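The removed typename was redundant: SelectOpOneToNLowering is not itself a template, so ConvertOpToLLVMPattern<arith::SelectOp> is a concrete specialization rather than a dependent base, and its nested OneToNOpAdaptor can be named directly. A minimal, self-contained sketch of the distinction, using hypothetical Pattern / SelectOp / Lowering stand-ins rather than the real MLIR classes:

#include <type_traits>

template <typename OpT>
struct Pattern {
  using OneToNOpAdaptor = OpT; // stand-in for the real adaptor type
};

struct SelectOp {};

// Non-template derived class: the base is a concrete type, so no 'typename'.
struct Lowering : Pattern<SelectOp> {
  using Adaptor = Pattern<SelectOp>::OneToNOpAdaptor;
};

// Template derived class: the base depends on OpT, so 'typename' is still
// required in C++17 (C++20 relaxes this).
template <typename OpT>
struct GenericLowering : Pattern<OpT> {
  using Adaptor = typename Pattern<OpT>::OneToNOpAdaptor;
};

static_assert(std::is_same_v<Lowering::Adaptor, SelectOp>, "adaptor resolves");

The same reasoning applies to the ControlFlowToLLVM and MathToROCDL hunks below, where the enclosing pattern classes are likewise non-templates.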
diff --git a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
index 798d8b0..b75968e 100644
--- a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
+++ b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
@@ -137,8 +137,7 @@ static SmallVector<Value> flattenValues(ArrayRef<ValueRange> values) {
/// op to llvm.br.
struct BranchOpLowering : public ConvertOpToLLVMPattern<cf::BranchOp> {
using ConvertOpToLLVMPattern<cf::BranchOp>::ConvertOpToLLVMPattern;
- using Adaptor =
- typename ConvertOpToLLVMPattern<cf::BranchOp>::OneToNOpAdaptor;
+ using Adaptor = ConvertOpToLLVMPattern<cf::BranchOp>::OneToNOpAdaptor;
LogicalResult
matchAndRewrite(cf::BranchOp op, Adaptor adaptor,
@@ -163,8 +162,7 @@ struct BranchOpLowering : public ConvertOpToLLVMPattern<cf::BranchOp> {
/// branch op to llvm.cond_br.
struct CondBranchOpLowering : public ConvertOpToLLVMPattern<cf::CondBranchOp> {
using ConvertOpToLLVMPattern<cf::CondBranchOp>::ConvertOpToLLVMPattern;
- using Adaptor =
- typename ConvertOpToLLVMPattern<cf::CondBranchOp>::OneToNOpAdaptor;
+ using Adaptor = ConvertOpToLLVMPattern<cf::CondBranchOp>::OneToNOpAdaptor;
LogicalResult
matchAndRewrite(cf::CondBranchOp op, Adaptor adaptor,
@@ -204,7 +202,7 @@ struct SwitchOpLowering : public ConvertOpToLLVMPattern<cf::SwitchOp> {
using ConvertOpToLLVMPattern<cf::SwitchOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(cf::SwitchOp op, typename cf::SwitchOp::Adaptor adaptor,
+ matchAndRewrite(cf::SwitchOp op, cf::SwitchOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Get or convert default block.
FailureOr<Block *> convertedDefaultBlock = getConvertedBlock(
diff --git a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
index a2dfc12..a922338 100644
--- a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
+++ b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
@@ -68,7 +68,7 @@ struct ClampFOpConversion final
return LLVM::detail::handleMultidimensionalVectors(
op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
[&](Type llvm1DVectorTy, ValueRange operands) -> Value {
- typename math::ClampFOp::Adaptor adaptor(operands);
+ math::ClampFOp::Adaptor adaptor(operands);
return ROCDL::FMed3Op::create(rewriter, op.getLoc(), llvm1DVectorTy,
adaptor.getValue(), adaptor.getMin(),
adaptor.getMax());
diff --git a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
index 91c1aa5..1b4d1a4 100644
--- a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
+++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
@@ -97,57 +97,23 @@ static LogicalResult transferPreconditions(PatternRewriter &rewriter,
return success();
}
-static xegpu::CreateNdDescOp
-createNdDescriptor(PatternRewriter &rewriter, Location loc,
- xegpu::TensorDescType descType, TypedValue<MemRefType> src,
- Operation::operand_range offsets) {
+static xegpu::CreateNdDescOp createNdDescriptor(PatternRewriter &rewriter,
+ Location loc,
+ xegpu::TensorDescType descType,
+ TypedValue<MemRefType> src) {
MemRefType srcTy = src.getType();
auto [strides, offset] = srcTy.getStridesAndOffset();
xegpu::CreateNdDescOp ndDesc;
if (srcTy.hasStaticShape()) {
- ndDesc = xegpu::CreateNdDescOp::create(rewriter, loc, descType, src,
- getAsOpFoldResult(offsets));
+ ndDesc = xegpu::CreateNdDescOp::create(rewriter, loc, descType, src);
} else {
// In case of any dynamic shapes, source's shape and strides have to be
// explicitly provided.
- SmallVector<Value> sourceDims;
- unsigned srcRank = srcTy.getRank();
- for (unsigned i = 0; i < srcRank; ++i)
- sourceDims.push_back(memref::DimOp::create(rewriter, loc, src, i));
-
- SmallVector<int64_t> constOffsets;
- SmallVector<Value> dynOffsets;
- for (Value offset : offsets) {
- std::optional<int64_t> staticVal = getConstantIntValue(offset);
- if (!staticVal)
- dynOffsets.push_back(offset);
- constOffsets.push_back(staticVal.value_or(ShapedType::kDynamic));
- }
-
- SmallVector<Value> dynShapes;
- for (auto [idx, shape] : llvm::enumerate(srcTy.getShape())) {
- if (shape == ShapedType::kDynamic)
- dynShapes.push_back(sourceDims[idx]);
- }
-
- // Compute strides in reverse order.
- SmallVector<Value> dynStrides;
- Value accStride = arith::ConstantIndexOp::create(rewriter, loc, 1);
- // Last stride is guaranteed to be static and unit.
- for (int i = static_cast<int>(strides.size()) - 2; i >= 0; --i) {
- accStride =
- arith::MulIOp::create(rewriter, loc, accStride, sourceDims[i + 1]);
- if (strides[i] == ShapedType::kDynamic)
- dynStrides.push_back(accStride);
- }
- std::reverse(dynStrides.begin(), dynStrides.end());
-
- ndDesc = xegpu::CreateNdDescOp::create(
- rewriter, loc, descType, src, dynOffsets, dynShapes, dynStrides,
- DenseI64ArrayAttr::get(rewriter.getContext(), constOffsets),
- DenseI64ArrayAttr::get(rewriter.getContext(), srcTy.getShape()),
- DenseI64ArrayAttr::get(rewriter.getContext(), strides));
+ auto meta = memref::ExtractStridedMetadataOp::create(rewriter, loc, src);
+ ndDesc = xegpu::CreateNdDescOp::create(rewriter, loc, descType, src,
+ meta.getConstifiedMixedSizes(),
+ meta.getConstifiedMixedStrides());
}
return ndDesc;
@@ -392,6 +358,62 @@ static Value computeOffsets(PatternRewriter &rewriter, OpType gatScatOp,
.getResult();
}
+// Collapses the shape of an nD memref to the target rank while applying offsets for
+// the collapsed dimensions. Returns the new memref value and the remaining
+// offsets for the last targetRank dimensions. For example:
+// input: %memref = memref<2x4x8x32xf32>, offsets=[%i0, %i1, %i2, %i3],
+// output: %memref[%i0, %i1, 0, 0] -> memref<8x32xf32>, offsets: [%i2, %i3]
+static std::pair<Value, SmallVector<OpFoldResult>>
+convertMemrefAndOffsetsToTargetRank(PatternRewriter &rewriter, Location loc,
+ Value memref,
+ SmallVector<OpFoldResult> offsets,
+ int64_t targetRank) {
+ auto memrefType = cast<MemRefType>(memref.getType());
+ unsigned rank = memrefType.getRank();
+
+ if (rank <= targetRank)
+ return {memref, offsets};
+
+ int64_t numCombinedDims = rank - targetRank;
+ SmallVector<OpFoldResult> subviewOffsets;
+ SmallVector<OpFoldResult> subviewSizes;
+ SmallVector<OpFoldResult> subviewStrides;
+
+ // For the combined dimensions: use the provided offsets, size=1, stride=1
+ for (unsigned i = 0; i < numCombinedDims; ++i) {
+ subviewOffsets.push_back(offsets[i]);
+ subviewSizes.push_back(rewriter.getI64IntegerAttr(1));
+ subviewStrides.push_back(rewriter.getI64IntegerAttr(1));
+ }
+
+ // For the last targetRank dimensions: offset=0, use full size, stride=1
+ SmallVector<int64_t> resultShape;
+ auto originalShape = memrefType.getShape();
+ auto meta = memref::ExtractStridedMetadataOp::create(rewriter, loc, memref);
+ for (unsigned i = numCombinedDims; i < rank; ++i) {
+ subviewOffsets.push_back(rewriter.getI64IntegerAttr(0));
+ if (ShapedType::isDynamic(originalShape[i])) {
+ subviewSizes.push_back(meta.getSizes()[i]);
+ resultShape.push_back(ShapedType::kDynamic);
+ } else {
+ subviewSizes.push_back(rewriter.getI64IntegerAttr(originalShape[i]));
+ resultShape.push_back(originalShape[i]);
+ }
+ subviewStrides.push_back(rewriter.getI64IntegerAttr(1));
+ }
+
+ auto resultType = memref::SubViewOp::inferRankReducedResultType(
+ resultShape, memrefType, subviewOffsets, subviewSizes, subviewStrides);
+ auto subviewOp =
+ memref::SubViewOp::create(rewriter, loc, resultType, memref,
+ subviewOffsets, subviewSizes, subviewStrides);
+
+ // Return the remaining offsets for the last targetRank dimensions
+ SmallVector<OpFoldResult> newOffsets(offsets.begin() + numCombinedDims,
+ offsets.end());
+ return {subviewOp.getResult(), newOffsets};
+}
+
template <
typename OpType,
typename = std::enable_if_t<llvm::is_one_of<
@@ -435,7 +457,8 @@ static LogicalResult lowerToScatteredLoadOp(vector::TransferReadOp readOp,
/*chunk_size=*/IntegerAttr{},
/*l1_hint=*/xegpu::CachePolicyAttr{},
/*l2_hint=*/xegpu::CachePolicyAttr{},
- /*l3_hint=*/xegpu::CachePolicyAttr{});
+ /*l3_hint=*/xegpu::CachePolicyAttr{},
+ /*layout=*/nullptr);
rewriter.replaceOp(readOp, gatherOp.getResult());
return success();
@@ -469,7 +492,8 @@ static LogicalResult lowerToScatteredStoreOp(vector::TransferWriteOp writeOp,
/*chunk_size=*/IntegerAttr{},
/*l1_hint=*/xegpu::CachePolicyAttr{},
/*l2_hint=*/xegpu::CachePolicyAttr{},
- /*l3_hint=*/xegpu::CachePolicyAttr{});
+ /*l3_hint=*/xegpu::CachePolicyAttr{},
+ /*layout=*/nullptr);
rewriter.eraseOp(writeOp);
return success();
}
@@ -523,18 +547,19 @@ struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
descShape, elementType, /*array_length=*/1,
/*boundary_check=*/isOutOfBounds, xegpu::MemorySpace::Global);
- xegpu::CreateNdDescOp ndDesc =
- createNdDescriptor(rewriter, loc, descType,
- dyn_cast<TypedValue<MemRefType>>(readOp.getBase()),
- readOp.getIndices());
-
DenseI64ArrayAttr transposeAttr =
!isTransposeLoad ? nullptr
: DenseI64ArrayAttr::get(rewriter.getContext(),
ArrayRef<int64_t>{1, 0});
+ auto [src, indices] = convertMemrefAndOffsetsToTargetRank(
+ rewriter, loc, readOp.getBase(), getAsOpFoldResult(readOp.getIndices()),
+ vecTy.getRank());
// By default, no specific caching policy is assigned.
xegpu::CachePolicyAttr hint = nullptr;
- auto loadOp = xegpu::LoadNdOp::create(rewriter, loc, vecTy, ndDesc,
+ xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
+ rewriter, loc, descType, dyn_cast<TypedValue<MemRefType>>(src));
+
+ auto loadOp = xegpu::LoadNdOp::create(rewriter, loc, vecTy, ndDesc, indices,
/*packed=*/nullptr, transposeAttr,
/*l1_hint=*/hint,
/*l2_hint=*/hint, /*l3_hint=*/hint);
@@ -575,21 +600,23 @@ struct TransferWriteLowering
if (!map.isMinorIdentity())
return rewriter.notifyMatchFailure(writeOp, "Expects identity map");
+ auto [src, indices] = convertMemrefAndOffsetsToTargetRank(
+ rewriter, loc, writeOp.getBase(),
+ getAsOpFoldResult(writeOp.getIndices()), vecTy.getRank());
+
auto descType = xegpu::TensorDescType::get(
vecTy.getShape(), vecTy.getElementType(),
/*array_length=*/1, /*boundary_check=*/writeOp.hasOutOfBoundsDim(),
xegpu::MemorySpace::Global);
- xegpu::CreateNdDescOp ndDesc =
- createNdDescriptor(rewriter, loc, descType,
- dyn_cast<TypedValue<MemRefType>>(writeOp.getBase()),
- writeOp.getIndices());
-
// By default, no specific caching policy is assigned.
xegpu::CachePolicyAttr hint = nullptr;
- auto storeOp =
- xegpu::StoreNdOp::create(rewriter, loc, writeOp.getVector(), ndDesc,
- /*l1_hint=*/hint,
- /*l2_hint=*/hint, /*l3_hint=*/hint);
+ xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
+ rewriter, loc, descType, dyn_cast<TypedValue<MemRefType>>(src));
+
+ auto storeOp = xegpu::StoreNdOp::create(rewriter, loc, writeOp.getVector(),
+ ndDesc, indices,
+ /*l1_hint=*/hint,
+ /*l2_hint=*/hint, /*l3_hint=*/hint);
rewriter.replaceOp(writeOp, storeOp);
return success();
@@ -621,7 +648,8 @@ struct GatherLowering : public OpRewritePattern<vector::GatherOp> {
/*chunk_size=*/IntegerAttr{},
/*l1_hint=*/xegpu::CachePolicyAttr{},
/*l2_hint=*/xegpu::CachePolicyAttr{},
- /*l3_hint=*/xegpu::CachePolicyAttr{});
+ /*l3_hint=*/xegpu::CachePolicyAttr{},
+ /*layout=*/nullptr);
auto selectOp =
arith::SelectOp::create(rewriter, loc, gatherOp.getMask(),
@@ -655,7 +683,8 @@ struct ScatterLowering : public OpRewritePattern<vector::ScatterOp> {
/*chunk_size=*/IntegerAttr{},
/*l1_hint=*/xegpu::CachePolicyAttr{},
/*l2_hint=*/xegpu::CachePolicyAttr{},
- /*l3_hint=*/xegpu::CachePolicyAttr{});
+ /*l3_hint=*/xegpu::CachePolicyAttr{},
+ /*layout=*/nullptr);
rewriter.eraseOp(scatterOp);
return success();
}
@@ -674,19 +703,24 @@ struct LoadLowering : public OpRewritePattern<vector::LoadOp> {
// Boundary check is available only for block instructions.
bool boundaryCheck = vecTy.getRank() > 1;
+ // By default, no specific caching policy is assigned.
+ xegpu::CachePolicyAttr hint = nullptr;
+
+ auto [src, indices] = convertMemrefAndOffsetsToTargetRank(
+ rewriter, loc, loadOp.getBase(), getAsOpFoldResult(loadOp.getIndices()),
+ vecTy.getRank());
auto descType = xegpu::TensorDescType::get(
vecTy.getShape(), vecTy.getElementType(), /*array_length=*/1,
boundaryCheck, xegpu::MemorySpace::Global);
- xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
- rewriter, loc, descType, loadOp.getBase(), loadOp.getIndices());
- // By default, no specific caching policy is assigned.
- xegpu::CachePolicyAttr hint = nullptr;
- auto loadNdOp = xegpu::LoadNdOp::create(
- rewriter, loc, vecTy, ndDesc, /*packed=*/nullptr, /*transpose=*/nullptr,
- /*l1_hint=*/hint,
- /*l2_hint=*/hint, /*l3_hint=*/hint);
+ xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
+ rewriter, loc, descType, dyn_cast<TypedValue<MemRefType>>(src));
+ auto loadNdOp =
+ xegpu::LoadNdOp::create(rewriter, loc, vecTy, ndDesc, indices,
+ /*packed=*/nullptr, /*transpose=*/nullptr,
+ /*l1_hint=*/hint,
+ /*l2_hint=*/hint, /*l3_hint=*/hint);
rewriter.replaceOp(loadOp, loadNdOp);
return success();
@@ -708,18 +742,24 @@ struct StoreLowering : public OpRewritePattern<vector::StoreOp> {
// Boundary check is available only for block instructions.
bool boundaryCheck = vecTy.getRank() > 1;
+ auto [src, indices] = convertMemrefAndOffsetsToTargetRank(
+ rewriter, loc, storeOp.getBase(),
+ getAsOpFoldResult(storeOp.getIndices()), vecTy.getRank());
+
auto descType = xegpu::TensorDescType::get(
vecTy.getShape(), vecTy.getElementType(),
/*array_length=*/1, boundaryCheck, xegpu::MemorySpace::Global);
- xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
- rewriter, loc, descType, storeOp.getBase(), storeOp.getIndices());
// By default, no specific caching policy is assigned.
xegpu::CachePolicyAttr hint = nullptr;
+ xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
+ rewriter, loc, descType, dyn_cast<TypedValue<MemRefType>>(src));
+
auto storeNdOp =
- xegpu::StoreNdOp::create(rewriter, loc, vector, ndDesc,
+ xegpu::StoreNdOp::create(rewriter, loc, vector, ndDesc, indices,
/*l1_hint=*/hint,
/*l2_hint=*/hint, /*l3_hint=*/hint);
+
rewriter.replaceOp(storeOp, storeNdOp);
return success();
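Taken together, the VectorToXeGPU patterns above now share one flow: collapse the leading dimensions of the source memref into a rank-reducing subview (convertMemrefAndOffsetsToTargetRank), build an offset-free descriptor (createNdDescriptor, which derives dynamic sizes and strides from memref.extract_strided_metadata), and pass the remaining offsets straight to the load/store op. Below is a condensed sketch of the read path, assuming the two static helpers from this patch are in scope; lowerReadSketch is a hypothetical name and the snippet is a restatement of what TransferReadLowering and LoadLowering now do, not a standalone program:

// Hypothetical condensed helper; mirrors the read-side patterns in this patch.
static Value lowerReadSketch(PatternRewriter &rewriter, Location loc,
                             Value base, ValueRange indices,
                             VectorType vecTy) {
  // e.g. memref<2x4x8x32xf32> with offsets [i0, i1, i2, i3] becomes a
  // memref<8x32xf32> subview at [i0, i1, 0, 0] plus remaining offsets [i2, i3].
  auto [src, offsets] = convertMemrefAndOffsetsToTargetRank(
      rewriter, loc, base, getAsOpFoldResult(indices), vecTy.getRank());

  auto descType = xegpu::TensorDescType::get(
      vecTy.getShape(), vecTy.getElementType(), /*array_length=*/1,
      /*boundary_check=*/vecTy.getRank() > 1, xegpu::MemorySpace::Global);

  // The descriptor no longer carries offsets.
  xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
      rewriter, loc, descType, dyn_cast<TypedValue<MemRefType>>(src));

  // The remaining offsets are applied by the load itself.
  xegpu::CachePolicyAttr hint = nullptr;
  auto loadOp = xegpu::LoadNdOp::create(rewriter, loc, vecTy, ndDesc, offsets,
                                        /*packed=*/nullptr,
                                        /*transpose=*/nullptr,
                                        /*l1_hint=*/hint,
                                        /*l2_hint=*/hint, /*l3_hint=*/hint);
  return loadOp.getResult();
}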
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index 33e8f2e..de552ce 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -562,6 +562,8 @@ class LoadStoreMatrixToXeVMPattern : public OpConversionPattern<OpType> {
VectorType valOrResVecTy = dyn_cast<VectorType>(data.getType());
if (!valOrResVecTy)
valOrResVecTy = VectorType::get(1, data.getType());
+ if (valOrResVecTy.getShape().size() != 1)
+ return rewriter.notifyMatchFailure(op, "Expected 1D data vector.");
int64_t elemBitWidth =
valOrResVecTy.getElementType().getIntOrFloatBitWidth();