aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Conversion
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Conversion')
-rw-r--r--mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp2
-rw-r--r--mlir/lib/Conversion/MathToROCDL/CMakeLists.txt1
-rw-r--r--mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp76
-rw-r--r--mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp4
-rw-r--r--mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp7
-rw-r--r--mlir/lib/Conversion/XeGPUToXeVM/CMakeLists.txt1
-rw-r--r--mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp200
7 files changed, 265 insertions, 26 deletions
diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
index b215211..c03f3a5 100644
--- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
+++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
@@ -484,5 +484,5 @@ void mlir::populateGpuToROCDLConversionPatterns(
GPUSubgroupBroadcastOpToROCDL>(converter);
patterns.add<GPUSubgroupSizeOpToROCDL>(converter, chipset);
- populateMathToROCDLConversionPatterns(converter, patterns);
+ populateMathToROCDLConversionPatterns(converter, patterns, chipset);
}
diff --git a/mlir/lib/Conversion/MathToROCDL/CMakeLists.txt b/mlir/lib/Conversion/MathToROCDL/CMakeLists.txt
index 2771955a..8cc3fde 100644
--- a/mlir/lib/Conversion/MathToROCDL/CMakeLists.txt
+++ b/mlir/lib/Conversion/MathToROCDL/CMakeLists.txt
@@ -11,6 +11,7 @@ add_mlir_conversion_library(MLIRMathToROCDL
Core
LINK_LIBS PUBLIC
+ MLIRAMDGPUUtils
MLIRDialectUtils
MLIRFuncDialect
MLIRGPUToGPURuntimeTransforms
diff --git a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
index df219f3..a2dfc12 100644
--- a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
+++ b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
@@ -10,6 +10,8 @@
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
+#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
@@ -19,6 +21,7 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/Support/DebugLog.h"
#include "../GPUCommon/GPUOpsLowering.h"
#include "../GPUCommon/OpToFuncCallLowering.h"
@@ -42,8 +45,46 @@ static void populateOpPatterns(const LLVMTypeConverter &converter,
f32ApproxFunc, f16Func);
}
+struct ClampFOpConversion final
+ : public ConvertOpToLLVMPattern<math::ClampFOp> {
+ using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+ LogicalResult
+ matchAndRewrite(math::ClampFOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ // Only f16 and f32 types are supported by fmed3
+ Type opTy = op.getType();
+ Type resultType = getTypeConverter()->convertType(opTy);
+
+ if (auto vectorType = dyn_cast<VectorType>(opTy))
+ opTy = vectorType.getElementType();
+
+ if (!isa<Float16Type, Float32Type>(opTy))
+ return rewriter.notifyMatchFailure(
+ op, "fmed3 only supports f16 and f32 types");
+
+ // Handle multi-dimensional vectors (converted to LLVM arrays)
+ if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(resultType))
+ return LLVM::detail::handleMultidimensionalVectors(
+ op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
+ [&](Type llvm1DVectorTy, ValueRange operands) -> Value {
+ typename math::ClampFOp::Adaptor adaptor(operands);
+ return ROCDL::FMed3Op::create(rewriter, op.getLoc(), llvm1DVectorTy,
+ adaptor.getValue(), adaptor.getMin(),
+ adaptor.getMax());
+ },
+ rewriter);
+
+ // Handle 1D vectors and scalars directly
+ rewriter.replaceOpWithNewOp<ROCDL::FMed3Op>(op, op.getType(), op.getValue(),
+ op.getMin(), op.getMax());
+ return success();
+ }
+};
+
void mlir::populateMathToROCDLConversionPatterns(
- const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
+ const LLVMTypeConverter &converter, RewritePatternSet &patterns,
+ std::optional<amdgpu::Chipset> chipset) {
// Handled by mathToLLVM: math::AbsIOp
// Handled by mathToLLVM: math::AbsFOp
// Handled by mathToLLVM: math::CopySignOp
@@ -118,15 +159,21 @@ void mlir::populateMathToROCDLConversionPatterns(
// worth creating a separate pass for it.
populateOpPatterns<arith::RemFOp>(converter, patterns, "__ocml_fmod_f32",
"__ocml_fmod_f64", "__ocml_fmod_f16");
+
+ if (chipset.has_value() && chipset->majorVersion >= 9) {
+ patterns.add<ClampFOpConversion>(converter);
+ } else {
+ LDBG() << "Chipset dependent patterns were not added";
+ }
}
-namespace {
-struct ConvertMathToROCDLPass
- : public impl::ConvertMathToROCDLBase<ConvertMathToROCDLPass> {
- ConvertMathToROCDLPass() = default;
+struct ConvertMathToROCDLPass final
+ : impl::ConvertMathToROCDLBase<ConvertMathToROCDLPass> {
+ using impl::ConvertMathToROCDLBase<
+ ConvertMathToROCDLPass>::ConvertMathToROCDLBase;
+
void runOnOperation() override;
};
-} // namespace
void ConvertMathToROCDLPass::runOnOperation() {
auto m = getOperation();
@@ -135,10 +182,21 @@ void ConvertMathToROCDLPass::runOnOperation() {
RewritePatternSet patterns(&getContext());
LowerToLLVMOptions options(ctx, DataLayout(m));
LLVMTypeConverter converter(ctx, options);
- populateMathToROCDLConversionPatterns(converter, patterns);
+
+ FailureOr<amdgpu::Chipset> maybeChipset;
+ if (!chipset.empty()) {
+ maybeChipset = amdgpu::Chipset::parse(chipset);
+ if (failed(maybeChipset))
+ return signalPassFailure();
+ }
+ populateMathToROCDLConversionPatterns(
+ converter, patterns,
+ succeeded(maybeChipset) ? std::optional(*maybeChipset) : std::nullopt);
+
ConversionTarget target(getContext());
- target.addLegalDialect<BuiltinDialect, func::FuncDialect,
- vector::VectorDialect, LLVM::LLVMDialect>();
+ target
+ .addLegalDialect<BuiltinDialect, func::FuncDialect, vector::VectorDialect,
+ LLVM::LLVMDialect, ROCDL::ROCDLDialect>();
target.addIllegalOp<LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op, LLVM::FAbsOp,
LLVM::FCeilOp, LLVM::FFloorOp, LLVM::FRemOp, LLVM::LogOp,
LLVM::Log10Op, LLVM::Log2Op, LLVM::PowOp, LLVM::SinOp,
diff --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
index f0d8b78..610ce1f 100644
--- a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
+++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
@@ -407,11 +407,11 @@ struct PowFOpPattern final : public OpConversionPattern<math::PowFOp> {
if (auto vectorType = dyn_cast<VectorType>(operandType))
nanAttr = DenseElementsAttr::get(vectorType, nan);
- Value NanValue =
+ Value nanValue =
spirv::ConstantOp::create(rewriter, loc, operandType, nanAttr);
Value lhs =
spirv::SelectOp::create(rewriter, loc, cmpNegativeWithFractionalExp,
- NanValue, adaptor.getLhs());
+ nanValue, adaptor.getLhs());
Value abs = spirv::GLFAbsOp::create(rewriter, loc, lhs);
// TODO: The following just forcefully casts y into an integer value in
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 5355909..41d8d53 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1723,17 +1723,18 @@ struct VectorBroadcastScalarToLowRankLowering
return success();
}
- // For 1-d vector, we additionally do a `vectorshuffle`.
auto v =
LLVM::InsertElementOp::create(rewriter, broadcast.getLoc(), vectorType,
poison, adaptor.getSource(), zero);
+ // For 1-d vector, we additionally do a `shufflevector`.
int64_t width = cast<VectorType>(broadcast.getType()).getDimSize(0);
SmallVector<int32_t> zeroValues(width, 0);
// Shuffle the value across the desired number of elements.
- rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(broadcast, v, poison,
- zeroValues);
+ auto shuffle = rewriter.createOrFold<LLVM::ShuffleVectorOp>(
+ broadcast.getLoc(), v, poison, zeroValues);
+ rewriter.replaceOp(broadcast, shuffle);
return success();
}
};
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/CMakeLists.txt b/mlir/lib/Conversion/XeGPUToXeVM/CMakeLists.txt
index 84b2580..dd9edc4 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/CMakeLists.txt
+++ b/mlir/lib/Conversion/XeGPUToXeVM/CMakeLists.txt
@@ -21,6 +21,7 @@ add_mlir_conversion_library(MLIRXeGPUToXeVM
MLIRIndexDialect
MLIRSCFDialect
MLIRXeGPUDialect
+ MLIRXeGPUUtils
MLIRPass
MLIRTransforms
MLIRSCFTransforms
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index 71687b1..fcbf66d 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -20,7 +20,9 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/STLExtras.h"
@@ -62,6 +64,7 @@ static int32_t getNumericXeVMAddrSpace(xegpu::MemorySpace xeGpuMemspace) {
case xegpu::MemorySpace::SLM:
return static_cast<int>(xevm::AddrSpace::SHARED);
}
+ llvm_unreachable("Unknown XeGPU memory space");
}
// Get same bitwidth flat vector type of new element type.
@@ -185,6 +188,7 @@ class CreateNdDescToXeVMPattern
int64_t rank = mixedSizes.size();
if (rank != 2)
return rewriter.notifyMatchFailure(op, "Expected 2D shape.");
+
auto sourceTy = source.getType();
auto sourceMemrefTy = dyn_cast<MemRefType>(sourceTy);
// If source is a memref, we need to extract the aligned pointer as index.
@@ -363,10 +367,11 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
// Add a builder that creates
// offset * elemByteSize + baseAddr
-static Value addOffset(ConversionPatternRewriter &rewriter, Location loc,
- Value baseAddr, Value offset, int64_t elemByteSize) {
+static Value addOffsetToBaseAddr(ConversionPatternRewriter &rewriter,
+ Location loc, Value baseAddr, Value offset,
+ int64_t elemByteSize) {
Value byteSize = arith::ConstantIntOp::create(
- rewriter, loc, rewriter.getI64Type(), elemByteSize);
+ rewriter, loc, baseAddr.getType(), elemByteSize);
Value byteOffset = arith::MulIOp::create(rewriter, loc, offset, byteSize);
Value newAddr = arith::AddIOp::create(rewriter, loc, baseAddr, byteOffset);
return newAddr;
@@ -390,7 +395,8 @@ class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> {
// Load result or Store valye Type can be vector or scalar.
Type valOrResTy;
if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>)
- valOrResTy = op.getResult().getType();
+ valOrResTy =
+ this->getTypeConverter()->convertType(op.getResult().getType());
else
valOrResTy = adaptor.getValue().getType();
VectorType valOrResVecTy = dyn_cast<VectorType>(valOrResTy);
@@ -441,7 +447,8 @@ class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> {
// If offset is provided, we add them to the base pointer.
// Offset is in number of elements, we need to multiply by
// element byte size.
- basePtrI64 = addOffset(rewriter, loc, basePtrI64, offset, elemByteSize);
+ basePtrI64 =
+ addOffsetToBaseAddr(rewriter, loc, basePtrI64, offset, elemByteSize);
}
// Convert base pointer (i64) to LLVM pointer type.
Value basePtrLLVM =
@@ -504,6 +511,147 @@ class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> {
}
};
+// Lower xegpu::CreateMemDescOp to memref::ViewOp. Since SLM access instructions
+// on Xe2 and Xe3 operate on 32-bit or 64-bit units, all data types smaller than
+// 32 bits will be converted to 32 bits.
+class CreateMemDescOpPattern final
+ : public OpConversionPattern<xegpu::CreateMemDescOp> {
+public:
+ using OpConversionPattern<xegpu::CreateMemDescOp>::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(xegpu::CreateMemDescOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+
+ auto resTy = op.getMemDesc();
+
+ // Create the result MemRefType with the same shape, element type, and
+ // memory space
+ auto newResTy = getTypeConverter()->convertType<MemRefType>(resTy);
+
+ Value zero = arith::ConstantIndexOp::create(rewriter, op.getLoc(), 0);
+ auto viewOp = memref::ViewOp::create(rewriter, op.getLoc(), newResTy,
+ op.getSource(), zero, ValueRange());
+ rewriter.replaceOp(op, viewOp);
+ return success();
+ }
+};
+
+template <typename OpType,
+ typename = std::enable_if_t<llvm::is_one_of<
+ OpType, xegpu::LoadMatrixOp, xegpu::StoreMatrixOp>::value>>
+class LoadStoreMatrixToXeVMPattern : public OpConversionPattern<OpType> {
+ using OpConversionPattern<OpType>::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+
+ SmallVector<OpFoldResult> offsets = op.getMixedOffsets();
+ if (offsets.empty())
+ return rewriter.notifyMatchFailure(op, "Expected offset to be provided.");
+
+ auto loc = op.getLoc();
+ auto ctxt = rewriter.getContext();
+ Value basePtrStruct = adaptor.getMemDesc();
+ Value mdescVal = op.getMemDesc();
+ // Load result or Store value Type can be vector or scalar.
+ Value data;
+ if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>)
+ data = op.getResult();
+ else
+ data = adaptor.getData();
+ VectorType valOrResVecTy = dyn_cast<VectorType>(data.getType());
+ if (!valOrResVecTy)
+ valOrResVecTy = VectorType::get(1, data.getType());
+
+ int64_t elemBitWidth =
+ valOrResVecTy.getElementType().getIntOrFloatBitWidth();
+ // Element type must be multiple of 8 bits.
+ if (elemBitWidth % 8 != 0)
+ return rewriter.notifyMatchFailure(
+ op, "Expected element type bit width to be multiple of 8.");
+ int64_t elemByteSize = elemBitWidth / 8;
+
+ // Default memory space is SLM.
+ LLVM::LLVMPointerType ptrTypeLLVM = LLVM::LLVMPointerType::get(
+ ctxt, getNumericXeVMAddrSpace(xegpu::MemorySpace::SLM));
+
+ auto mdescTy = cast<xegpu::MemDescType>(mdescVal.getType());
+
+ Value basePtrLLVM = memref::ExtractAlignedPointerAsIndexOp::create(
+ rewriter, loc, basePtrStruct);
+
+ // Convert base pointer (ptr) to i32
+ Value basePtrI32 = arith::IndexCastUIOp::create(
+ rewriter, loc, rewriter.getI32Type(), basePtrLLVM);
+
+ Value linearOffset = mdescTy.getLinearOffsets(rewriter, loc, offsets);
+ linearOffset = arith::IndexCastUIOp::create(
+ rewriter, loc, rewriter.getI32Type(), linearOffset);
+ basePtrI32 = addOffsetToBaseAddr(rewriter, loc, basePtrI32, linearOffset,
+ elemByteSize);
+
+ // convert base pointer (i32) to LLVM pointer type
+ basePtrLLVM =
+ LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI32);
+
+ if (op.getSubgroupBlockIoAttr()) {
+ // if the attribute 'subgroup_block_io' is set to true, it lowers to
+ // xevm.blockload
+
+ Type intElemTy = rewriter.getIntegerType(elemBitWidth);
+ VectorType intVecTy =
+ VectorType::get(valOrResVecTy.getShape(), intElemTy);
+
+ if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) {
+ Value loadOp =
+ xevm::BlockLoadOp::create(rewriter, loc, intVecTy, basePtrLLVM);
+ if (intVecTy != valOrResVecTy) {
+ loadOp =
+ vector::BitCastOp::create(rewriter, loc, valOrResVecTy, loadOp);
+ }
+ rewriter.replaceOp(op, loadOp);
+ } else {
+ Value dataToStore = adaptor.getData();
+ if (valOrResVecTy != intVecTy) {
+ dataToStore =
+ vector::BitCastOp::create(rewriter, loc, intVecTy, dataToStore);
+ }
+ xevm::BlockStoreOp::create(rewriter, loc, basePtrLLVM, dataToStore,
+ nullptr);
+ rewriter.eraseOp(op);
+ }
+ return success();
+ }
+
+ if (valOrResVecTy.getNumElements() >= 1) {
+ auto chipOpt = xegpu::getChipStr(op);
+ if (!chipOpt || (*chipOpt != "pvc" && *chipOpt != "bmg")) {
+ // the lowering for chunk load only works for pvc and bmg
+ return rewriter.notifyMatchFailure(
+ op, "The lowering is specific to pvc or bmg.");
+ }
+ }
+
+ if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) {
+ // if the size of valOrResVecTy is 1, it lowers to a scalar load/store
+ // operation. LLVM load/store does not support vector of size 1, so we
+ // need to handle this case separately.
+ auto scalarTy = valOrResVecTy.getElementType();
+ LLVM::LoadOp loadOp;
+ if (valOrResVecTy.getNumElements() == 1)
+ loadOp = LLVM::LoadOp::create(rewriter, loc, scalarTy, basePtrLLVM);
+ else
+ loadOp =
+ LLVM::LoadOp::create(rewriter, loc, valOrResVecTy, basePtrLLVM);
+ rewriter.replaceOp(op, loadOp);
+ } else {
+ LLVM::StoreOp::create(rewriter, loc, adaptor.getData(), basePtrLLVM);
+ rewriter.eraseOp(op);
+ }
+ return success();
+ }
+};
+
class PrefetchToXeVMPattern : public OpConversionPattern<xegpu::PrefetchOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
@@ -546,8 +694,8 @@ class PrefetchToXeVMPattern : public OpConversionPattern<xegpu::PrefetchOp> {
op, "Expected element type bit width to be multiple of 8.");
elemByteSize = elemBitWidth / 8;
}
- basePtrI64 =
- addOffset(rewriter, loc, basePtrI64, offsets, elemByteSize);
+ basePtrI64 = addOffsetToBaseAddr(rewriter, loc, basePtrI64, offsets,
+ elemByteSize);
}
}
// Default memory space is global.
@@ -784,6 +932,13 @@ struct ConvertXeGPUToXeVMPass
auto i32Type = IntegerType::get(&getContext(), 32);
return VectorType::get(8, i32Type);
});
+ // Convert MemDescType into flattened MemRefType for SLM
+ typeConverter.addConversion([&](xegpu::MemDescType type) -> Type {
+ Type elemTy = type.getElementType();
+ int numElems = type.getNumElements();
+ return MemRefType::get(numElems, elemTy, AffineMap(), 3);
+ });
+
typeConverter.addConversion([&](MemRefType type) -> Type {
// Convert MemRefType to i64 type.
return IntegerType::get(&getContext(), 64);
@@ -878,10 +1033,30 @@ struct ConvertXeGPUToXeVMPass
}
return {};
};
- typeConverter.addSourceMaterialization(memrefMaterializationCast);
- typeConverter.addSourceMaterialization(ui64MaterializationCast);
- typeConverter.addSourceMaterialization(ui32MaterializationCast);
- typeConverter.addSourceMaterialization(vectorMaterializationCast);
+
+ // If result type of original op is single element vector and lowered type
+ // is scalar. This materialization cast creates a single element vector by
+ // broadcasting the scalar value.
+ auto singleElementVectorMaterializationCast =
+ [](OpBuilder &builder, Type type, ValueRange inputs,
+ Location loc) -> Value {
+ if (inputs.size() != 1)
+ return {};
+ auto input = inputs.front();
+ if (input.getType().isIntOrIndexOrFloat()) {
+ // If the input is a scalar, and the target type is a vector of single
+ // element, create a single element vector by broadcasting.
+ if (auto vecTy = dyn_cast<VectorType>(type)) {
+ if (vecTy.getNumElements() == 1) {
+ return vector::BroadcastOp::create(builder, loc, vecTy, input)
+ .getResult();
+ }
+ }
+ }
+ return {};
+ };
+ typeConverter.addSourceMaterialization(
+ singleElementVectorMaterializationCast);
typeConverter.addTargetMaterialization(memrefMaterializationCast);
typeConverter.addTargetMaterialization(ui32MaterializationCast);
typeConverter.addTargetMaterialization(ui64MaterializationCast);
@@ -918,6 +1093,9 @@ void mlir::populateXeGPUToXeVMConversionPatterns(
LoadStoreToXeVMPattern<xegpu::LoadGatherOp>,
LoadStoreToXeVMPattern<xegpu::StoreScatterOp>>(
typeConverter, patterns.getContext());
+ patterns.add<LoadStoreMatrixToXeVMPattern<xegpu::LoadMatrixOp>,
+ LoadStoreMatrixToXeVMPattern<xegpu::StoreMatrixOp>,
+ CreateMemDescOpPattern>(typeConverter, patterns.getContext());
patterns.add<FenceToXeVMPattern, DpasToXeVMPattern>(typeConverter,
patterns.getContext());
}