path: root/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
Diffstat (limited to 'mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp')
-rw-r--r--  mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp | 426
1 file changed, 277 insertions(+), 149 deletions(-)
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index 33e8f2e..0ecb50e 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -50,11 +50,10 @@ static constexpr int32_t executionSize{16};
// Offsets to individual fields of the 8xi32 layout nd tensor descriptor.
enum class NdTdescOffset : uint32_t {
- BasePtr = 0, // Base pointer (i64)
- BaseShapeW = 2, // Base shape width (i32)
- BaseShapeH = 3, // Base shape height (i32)
- TensorOffsetW = 4, // Tensor offset W (i32)
- TensorOffsetH = 5 // Tensor offset H (i32)
+ BasePtr = 0, // Base pointer (i64)
+ BaseShapeW = 2, // Base shape width (i32)
+ BaseShapeH = 3, // Base shape height (i32)
+ BasePitch = 4, // Base pitch (i32)
};
static int32_t getNumericXeVMAddrSpace(xegpu::MemorySpace xeGpuMemspace) {
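For illustration, here is a minimal standalone sketch of how the reworked 8xi32 payload is laid out (plain C++, independent of the MLIR builders below; the i64 base pointer occupying lanes 0-1 is inferred from BaseShapeW starting at lane 2, and the helper name is hypothetical):

#include <array>
#include <cstdint>
#include <cstring>

// Hypothetical packer mirroring NdTdescOffset: the 2D nd tensor descriptor
// payload now carries the base pitch instead of the removed tensor offsets.
std::array<int32_t, 8> packNdTdesc(uint64_t basePtr, int32_t baseShapeW,
                                   int32_t baseShapeH, int32_t basePitch) {
  std::array<int32_t, 8> payload{};
  // The i64 base pointer spans lanes 0 and 1 of the i32 payload.
  std::memcpy(payload.data(), &basePtr, sizeof(basePtr));
  payload[2] = baseShapeW; // NdTdescOffset::BaseShapeW
  payload[3] = baseShapeH; // NdTdescOffset::BaseShapeH
  payload[4] = basePitch;  // NdTdescOffset::BasePitch
  return payload;
}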
@@ -151,6 +150,14 @@ translateStoreXeGPUCacheHint(std::optional<xegpu::CachePolicy> L1hint,
}
}
+//
+// Note:
+// Block operations on tiles of sub-byte element types are handled by
+// emulating them with larger element types.
+// Tensor descriptors are kept intact; only the ops consuming them are
+// emulated.
+//
+
class CreateNdDescToXeVMPattern
: public OpConversionPattern<xegpu::CreateNdDescOp> {
using OpConversionPattern::OpConversionPattern;
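As a worked example of the emulation note above, a standalone sketch of the shape rescaling the pattern performs for 4-bit elements (executionSize = 16 comes from this file; systolicDepth = 8 is an assumption for illustration):

#include <cstdio>

int main() {
  constexpr int executionSize = 16; // from this file
  constexpr int systolicDepth = 8;  // assumed for illustration
  constexpr int elemBitSize = 4;
  constexpr int subByteFactor = 8 / elemBitSize; // i4 -> 2 elements per byte

  // Packed Matrix B case: a (systolicDepth*4) x (executionSize*subByteFactor)
  // i4 tile is emulated as an 8-bit load of executionSize columns.
  std::printf("B: %dx%d i4 -> %dx%d i8 (wScaleFactor=%d)\n",
              systolicDepth * 4, executionSize * subByteFactor,
              systolicDepth * 4, executionSize, subByteFactor);

  // Matrix A case: a width of executionSize*subByteFactor*2 i4 elements is
  // emulated as a 16-bit access of executionSize columns.
  int sub16BitFactor = subByteFactor * 2;
  std::printf("A: width %d i4 -> %d i16 (wScaleFactor=%d)\n",
              executionSize * sub16BitFactor, executionSize, sub16BitFactor);
  return 0;
}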
@@ -179,16 +186,12 @@ class CreateNdDescToXeVMPattern
Value baseAddr;
Value baseShapeW;
Value baseShapeH;
- Value offsetW;
- Value offsetH;
// Source can be a memref or a pointer (ui64, ui32, i64 or i32).
SmallVector<OpFoldResult> mixedSizes = op.getMixedSizes();
+ SmallVector<OpFoldResult> mixedStrides = op.getMixedStrides();
// Descriptor shape is expected to be 2D.
int64_t rank = mixedSizes.size();
- if (rank != 2)
- return rewriter.notifyMatchFailure(op, "Expected 2D shape.");
-
auto sourceTy = source.getType();
auto sourceMemrefTy = dyn_cast<MemRefType>(sourceTy);
// If source is a memref, we need to extract the aligned pointer as index.
@@ -197,10 +200,20 @@ class CreateNdDescToXeVMPattern
if (!sourceMemrefTy.hasRank()) {
return rewriter.notifyMatchFailure(op, "Expected ranked Memref.");
}
- baseAddr =
- memref::ExtractAlignedPointerAsIndexOp::create(rewriter, loc, source);
+ // Access adaptor after failure check to avoid rolling back generated code
+ // for materialization cast.
+ baseAddr = adaptor.getSource();
} else {
baseAddr = adaptor.getSource();
+ if (baseAddr.getType() != i64Ty) {
+ // Pointer type may be i32. Cast to i64 if needed.
+ baseAddr = arith::ExtUIOp::create(rewriter, loc, i64Ty, baseAddr);
+ }
+ }
+ // 1D tensor descriptor is just the base address.
+ if (rank == 1) {
+ rewriter.replaceOp(op, baseAddr);
+ return success();
}
// Utility for creating offset values from op fold result.
auto createOffset = [&](SmallVector<OpFoldResult> &ofrVec,
@@ -209,19 +222,11 @@ class CreateNdDescToXeVMPattern
val = getValueOrCreateCastToIndexLike(rewriter, loc, payloadElemTy, val);
return val;
};
- // Offsets are not supported (0 is used).
- offsetW = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0);
- offsetH = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0);
// Get shape values from op fold results.
baseShapeW = createOffset(mixedSizes, 1);
baseShapeH = createOffset(mixedSizes, 0);
- if (sourceMemrefTy) {
- // Cast index to i64.
- baseAddr = arith::IndexCastUIOp::create(rewriter, loc, i64Ty, baseAddr);
- } else if (baseAddr.getType() != i64Ty) {
- // Pointer type may be i32. Cast to i64 if needed.
- baseAddr = arith::ExtUIOp::create(rewriter, loc, i64Ty, baseAddr);
- }
+ // Get pitch value from op fold results.
+ Value basePitch = createOffset(mixedStrides, 0);
// Populate payload.
Value payLoadAsI64 =
vector::BitCastOp::create(rewriter, loc, payloadI64Ty, payload);
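To make the pitch concrete, a small sketch of what the new basePitch lane holds (values are illustrative, not from the patch): for a row-major source with sizes [H, W] and strides [W, 1], mixedStrides[0] is the row stride, which the consuming patterns later scale to bytes.

#include <cstdint>

// Illustrative: the pitch of a row-major 2D source is its leading stride,
// converted to bytes by the consuming load/store pattern.
int32_t pitchInBytes(int32_t strideDim0Elems, int32_t elemBitSize) {
  return strideDim0Elems * (elemBitSize / 8); // e.g. 64 * 2 = 128 for f16
}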
@@ -235,12 +240,9 @@ class CreateNdDescToXeVMPattern
payload =
vector::InsertOp::create(rewriter, loc, baseShapeH, payload,
static_cast<int>(NdTdescOffset::BaseShapeH));
- payload = vector::InsertOp::create(
- rewriter, loc, offsetW, payload,
- static_cast<int>(NdTdescOffset::TensorOffsetW));
- payload = vector::InsertOp::create(
- rewriter, loc, offsetH, payload,
- static_cast<int>(NdTdescOffset::TensorOffsetH));
+ payload =
+ vector::InsertOp::create(rewriter, loc, basePitch, payload,
+ static_cast<int>(NdTdescOffset::BasePitch));
rewriter.replaceOp(op, payload);
return success();
}
@@ -257,108 +259,240 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
ConversionPatternRewriter &rewriter) const override {
auto mixedOffsets = op.getMixedOffsets();
int64_t opOffsetsSize = mixedOffsets.size();
- if (opOffsetsSize != 2)
- return rewriter.notifyMatchFailure(op, "Expected 2D offsets.");
auto loc = op.getLoc();
auto ctxt = rewriter.getContext();
auto tdesc = adaptor.getTensorDesc();
auto tdescTy = op.getTensorDescType();
- if (tdescTy.getRank() != 2)
- return rewriter.notifyMatchFailure(op, "Expected 2D tensor descriptor.");
+ auto tileRank = tdescTy.getRank();
+ if (opOffsetsSize != tileRank)
+ return rewriter.notifyMatchFailure(
+ op, "Expected offset rank to match descriptor rank.");
auto elemType = tdescTy.getElementType();
auto elemBitSize = elemType.getIntOrFloatBitWidth();
- if (elemBitSize % 8 != 0)
+ bool isSubByte = elemBitSize < 8;
+ uint64_t wScaleFactor = 1;
+
+ if (!isSubByte && (elemBitSize % 8 != 0))
return rewriter.notifyMatchFailure(
op, "Expected element type bit width to be multiple of 8.");
+ auto tileW = tdescTy.getDimSize(tileRank - 1);
+    // For sub-byte types, only 4-bit elements are currently supported.
+ if (isSubByte) {
+ if (elemBitSize != 4)
+ return rewriter.notifyMatchFailure(
+            op, "Only sub-byte types of 4 bits are supported.");
+ if (tileRank != 2)
+ return rewriter.notifyMatchFailure(
+            op, "Sub-byte types are only supported for 2D tensor descriptors.");
+ auto subByteFactor = 8 / elemBitSize;
+ auto tileH = tdescTy.getDimSize(0);
+ // Handle special case for packed load.
+ if constexpr (std::is_same_v<OpType, xegpu::LoadNdOp>) {
+ if (op.getPacked().value_or(false)) {
+          // Packed load is emulated via packed loads of 8-bit elements.
+ if (tileH == systolicDepth * 4 &&
+ tileW == executionSize * subByteFactor) {
+            // Usage case: loading as Matrix B with a pack request.
+            // The source is assumed to be pre-packed into 8-bit elements.
+            // Emulate with 8-bit loads with a pack request.
+ // scaled_tileW = executionSize
+ elemType = rewriter.getIntegerType(8);
+ tileW = executionSize;
+ wScaleFactor = subByteFactor;
+ }
+ }
+ }
+      // If not handled by the packed-load case above, handle other cases.
+ if (wScaleFactor == 1) {
+ auto sub16BitFactor = subByteFactor * 2;
+ if (tileW == executionSize * sub16BitFactor) {
+          // Usage case: loading as the Matrix A operand.
+          // Emulate with 16-bit loads/stores.
+ // scaled_tileW = executionSize
+ elemType = rewriter.getIntegerType(16);
+ tileW = executionSize;
+ wScaleFactor = sub16BitFactor;
+ } else {
+ return rewriter.notifyMatchFailure(
+ op, "Unsupported tile shape for sub byte types.");
+ }
+ }
+      // Recompute the element bit size for emulation.
+ elemBitSize = elemType.getIntOrFloatBitWidth();
+ }
- VectorType payloadI64Ty = VectorType::get(4, rewriter.getI64Type());
- Value payLoadAsI64 =
- vector::BitCastOp::create(rewriter, loc, payloadI64Ty, tdesc);
- Value basePtr = vector::ExtractOp::create(
- rewriter, loc, payLoadAsI64, static_cast<int>(NdTdescOffset::BasePtr));
- Value baseShapeW = vector::ExtractOp::create(
- rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeW));
- Value baseShapeH = vector::ExtractOp::create(
- rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeH));
- // Offsets are provided by the op.
- // convert them to i32.
- Value offsetW =
- getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[1]);
- offsetW = getValueOrCreateCastToIndexLike(rewriter, loc,
- rewriter.getI32Type(), offsetW);
- Value offsetH =
- getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[0]);
- offsetH = getValueOrCreateCastToIndexLike(rewriter, loc,
- rewriter.getI32Type(), offsetH);
// Get address space from tensor descriptor memory space.
auto ptrTypeLLVM = LLVM::LLVMPointerType::get(
ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
- // Convert base pointer (i64) to LLVM pointer type.
- Value basePtrLLVM =
- LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtr);
- // Compute element byte size and surface width in bytes.
- Value elemByteSize = arith::ConstantIntOp::create(
- rewriter, loc, rewriter.getI32Type(), elemBitSize / 8);
- Value surfaceW =
- arith::MulIOp::create(rewriter, loc, baseShapeW, elemByteSize);
-
- // Get tile sizes and vblocks from the tensor descriptor type.
- auto tileW = tdescTy.getDimSize(1);
- auto tileH = tdescTy.getDimSize(0);
- int32_t vblocks = tdescTy.getArrayLength();
- if constexpr (std::is_same_v<OpType, xegpu::StoreNdOp>) {
- Value src = adaptor.getValue();
- // If store value is a scalar, get value from op instead of adaptor.
- // Adaptor might have optimized away single element vector
- if (src.getType().isIntOrFloat()) {
- src = op.getValue();
+ if (tileRank == 2) {
+ // Compute element byte size.
+ Value elemByteSize = arith::ConstantIntOp::create(
+ rewriter, loc, rewriter.getI32Type(), elemBitSize / 8);
+ VectorType payloadI64Ty = VectorType::get(4, rewriter.getI64Type());
+ Value payLoadAsI64 =
+ vector::BitCastOp::create(rewriter, loc, payloadI64Ty, tdesc);
+ Value basePtr =
+ vector::ExtractOp::create(rewriter, loc, payLoadAsI64,
+ static_cast<int>(NdTdescOffset::BasePtr));
+ Value baseShapeW = vector::ExtractOp::create(
+ rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeW));
+ Value baseShapeH = vector::ExtractOp::create(
+ rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeH));
+ Value basePitch = vector::ExtractOp::create(
+ rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BasePitch));
+      // Offsets are provided by the op; convert them to i32.
+ Value offsetW =
+ getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[1]);
+ offsetW = getValueOrCreateCastToIndexLike(rewriter, loc,
+ rewriter.getI32Type(), offsetW);
+ Value offsetH =
+ getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[0]);
+ offsetH = getValueOrCreateCastToIndexLike(rewriter, loc,
+ rewriter.getI32Type(), offsetH);
+ // Convert base pointer (i64) to LLVM pointer type.
+ Value basePtrLLVM =
+ LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtr);
+      // FIXME: width/pitch is not the same as baseShapeW; it should be the
+      // stride of the second-to-last dimension in row-major layout.
+ // Compute width in bytes.
+ Value baseShapeWInBytes =
+ arith::MulIOp::create(rewriter, loc, baseShapeW, elemByteSize);
+ // Compute pitch in bytes.
+ Value basePitchBytes =
+ arith::MulIOp::create(rewriter, loc, basePitch, elemByteSize);
+
+ if (wScaleFactor > 1) {
+        // Scale offsetW, baseShapeWInBytes, and basePitchBytes for sub-byte
+        // emulation. Note: tileW has already been scaled above.
+ Value wScaleFactorValLog2 = arith::ConstantIntOp::create(
+ rewriter, loc, rewriter.getI32Type(), llvm::Log2_64(wScaleFactor));
+ baseShapeWInBytes = arith::ShRSIOp::create(
+ rewriter, loc, baseShapeWInBytes, wScaleFactorValLog2);
+ basePitchBytes = arith::ShRSIOp::create(rewriter, loc, basePitchBytes,
+ wScaleFactorValLog2);
+ offsetW =
+ arith::ShRSIOp::create(rewriter, loc, offsetW, wScaleFactorValLog2);
}
- VectorType srcVecTy = dyn_cast<VectorType>(src.getType());
- if (!srcVecTy)
- return rewriter.notifyMatchFailure(
- op, "Expected store value to be a vector type.");
- // Get flat vector type of integer type with matching element bit size.
- VectorType newSrcVecTy =
- encodeVectorTypeTo(srcVecTy, rewriter.getIntegerType(elemBitSize));
- if (srcVecTy != newSrcVecTy)
- src = vector::BitCastOp::create(rewriter, loc, newSrcVecTy, src);
- auto storeCacheControl =
- translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
- xevm::BlockStore2dOp::create(
- rewriter, loc, basePtrLLVM, surfaceW, baseShapeH, surfaceW, offsetW,
- offsetH, elemBitSize, tileW, tileH, src,
- xevm::StoreCacheControlAttr::get(ctxt, storeCacheControl));
- rewriter.eraseOp(op);
- } else {
- auto loadCacheControl =
- translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
- if constexpr (std::is_same_v<OpType, xegpu::PrefetchNdOp>) {
- xevm::BlockPrefetch2dOp::create(
- rewriter, loc, basePtrLLVM, surfaceW, baseShapeH, surfaceW, offsetW,
- offsetH, elemBitSize, tileW, tileH, vblocks,
- xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
+ // Get tile height from the tensor descriptor type.
+ auto tileH = tdescTy.getDimSize(0);
+ // Get vblocks from the tensor descriptor type.
+ int32_t vblocks = tdescTy.getArrayLength();
+ if constexpr (std::is_same_v<OpType, xegpu::StoreNdOp>) {
+ Value src = adaptor.getValue();
+        // If the store value is a scalar, get it from the op instead of the
+        // adaptor, which might have optimized away a single-element vector.
+ if (src.getType().isIntOrFloat()) {
+ src = op.getValue();
+ }
+ VectorType srcVecTy = dyn_cast<VectorType>(src.getType());
+ if (!srcVecTy)
+ return rewriter.notifyMatchFailure(
+ op, "Expected store value to be a vector type.");
+ // Get flat vector type of integer type with matching element bit size.
+ VectorType newSrcVecTy =
+ encodeVectorTypeTo(srcVecTy, rewriter.getIntegerType(elemBitSize));
+ if (srcVecTy != newSrcVecTy)
+ src = vector::BitCastOp::create(rewriter, loc, newSrcVecTy, src);
+ auto storeCacheControl =
+ translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
+ xevm::BlockStore2dOp::create(
+ rewriter, loc, basePtrLLVM, baseShapeWInBytes, baseShapeH,
+ basePitchBytes, offsetW, offsetH, elemBitSize, tileW, tileH, src,
+ xevm::StoreCacheControlAttr::get(ctxt, storeCacheControl));
rewriter.eraseOp(op);
} else {
- VectorType dstVecTy = cast<VectorType>(op.getValue().getType());
- const bool vnni = op.getPacked().value_or(false);
- auto transposeValue = op.getTranspose();
- bool transpose =
- transposeValue.has_value() && transposeValue.value()[0] == 1;
- VectorType loadedTy = encodeVectorTypeTo(
- dstVecTy, vnni ? rewriter.getI32Type()
- : rewriter.getIntegerType(elemBitSize));
-
- Value resultFlatVec = xevm::BlockLoad2dOp::create(
- rewriter, loc, loadedTy, basePtrLLVM, surfaceW, baseShapeH,
- surfaceW, offsetW, offsetH, elemBitSize, tileW, tileH, vblocks,
- transpose, vnni,
+ auto loadCacheControl =
+ translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
+ if constexpr (std::is_same_v<OpType, xegpu::PrefetchNdOp>) {
+ xevm::BlockPrefetch2dOp::create(
+ rewriter, loc, basePtrLLVM, baseShapeWInBytes, baseShapeH,
+ basePitchBytes, offsetW, offsetH, elemBitSize, tileW, tileH,
+ vblocks, xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
+ rewriter.eraseOp(op);
+ } else {
+ VectorType dstVecTy = cast<VectorType>(op.getValue().getType());
+ const bool vnni = op.getPacked().value_or(false);
+ auto transposeValue = op.getTranspose();
+ bool transpose =
+ transposeValue.has_value() && transposeValue.value()[0] == 1;
+ VectorType loadedTy = encodeVectorTypeTo(
+ dstVecTy, vnni ? rewriter.getI32Type()
+ : rewriter.getIntegerType(elemBitSize));
+
+ Value resultFlatVec = xevm::BlockLoad2dOp::create(
+ rewriter, loc, loadedTy, basePtrLLVM, baseShapeWInBytes,
+ baseShapeH, basePitchBytes, offsetW, offsetH, elemBitSize, tileW,
+ tileH, vblocks, transpose, vnni,
+ xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
+ resultFlatVec = vector::BitCastOp::create(
+ rewriter, loc,
+ encodeVectorTypeTo(loadedTy, dstVecTy.getElementType()),
+ resultFlatVec);
+ rewriter.replaceOp(op, resultFlatVec);
+ }
+ }
+ } else {
+      // 1D tensor descriptor.
+      // `tdesc` represents the base address as i64. The offset is given in
+      // number of elements, so multiply by the element byte size:
+      //   byteOffset = offset * elemByteSize
+ Value offset =
+ getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[0]);
+ offset = getValueOrCreateCastToIndexLike(rewriter, loc,
+ rewriter.getI64Type(), offset);
+ // Compute element byte size.
+ Value elemByteSize = arith::ConstantIntOp::create(
+ rewriter, loc, rewriter.getI64Type(), elemBitSize / 8);
+ Value byteOffset =
+ rewriter.createOrFold<arith::MulIOp>(loc, offset, elemByteSize);
+ // Final address = basePtr + byteOffset
+ Value finalAddrI64 = rewriter.createOrFold<arith::AddIOp>(
+ loc, tdesc,
+ getValueOrCreateCastToIndexLike(rewriter, loc, rewriter.getI64Type(),
+ byteOffset));
+ // Convert base pointer (i64) to LLVM pointer type.
+ Value finalPtrLLVM =
+ LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, finalAddrI64);
+ if constexpr (std::is_same_v<OpType, xegpu::StoreNdOp>) {
+ Value src = adaptor.getValue();
+        // If the store value is a scalar, get it from the op instead of the
+        // adaptor, which might have optimized away a single-element vector.
+ if (src.getType().isIntOrFloat()) {
+ src = op.getValue();
+ }
+ VectorType srcVecTy = dyn_cast<VectorType>(src.getType());
+ if (!srcVecTy)
+ return rewriter.notifyMatchFailure(
+ op, "Expected store value to be a vector type.");
+ // Get flat vector type of integer type with matching element bit size.
+ VectorType newSrcVecTy =
+ encodeVectorTypeTo(srcVecTy, rewriter.getIntegerType(elemBitSize));
+ if (srcVecTy != newSrcVecTy)
+ src = vector::BitCastOp::create(rewriter, loc, newSrcVecTy, src);
+ auto storeCacheControl =
+ translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
+ rewriter.replaceOpWithNewOp<xevm::BlockStoreOp>(
+ op, finalPtrLLVM, src,
+ xevm::StoreCacheControlAttr::get(ctxt, storeCacheControl));
+ } else if constexpr (std::is_same_v<OpType, xegpu::LoadNdOp>) {
+ auto loadCacheControl =
+ translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
+ VectorType resTy = cast<VectorType>(op.getValue().getType());
+ VectorType loadedTy =
+ encodeVectorTypeTo(resTy, rewriter.getIntegerType(elemBitSize));
+ Value load = xevm::BlockLoadOp::create(
+ rewriter, loc, loadedTy, finalPtrLLVM,
xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
- resultFlatVec = vector::BitCastOp::create(
- rewriter, loc,
- encodeVectorTypeTo(loadedTy, dstVecTy.getElementType()),
- resultFlatVec);
- rewriter.replaceOp(op, resultFlatVec);
+ if (loadedTy != resTy)
+ load = vector::BitCastOp::create(rewriter, loc, resTy, load);
+ rewriter.replaceOp(op, load);
+ } else {
+ return rewriter.notifyMatchFailure(
+ op, "Unsupported operation: xegpu.prefetch_nd with tensor "
+ "descriptor rank == 1");
}
}
return success();
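A minimal standalone sketch of the rank-1 path above, mirroring the byteOffset/finalAddr arithmetic (plain C++; names are illustrative):

#include <cstdint>

// For a 1D descriptor the payload is just the base address; a load/store at
// `offsetElems` elements resolves to a flat address:
//   finalAddr = basePtr + offsetElems * (elemBitSize / 8)
uint64_t finalAddr1D(uint64_t basePtr, int64_t offsetElems,
                     int64_t elemBitSize) {
  return basePtr + static_cast<uint64_t>(offsetElems * (elemBitSize / 8));
}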
@@ -511,9 +645,6 @@ class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> {
}
};
-// Lower xegpu::CreateMemDescOp to memref::ViewOp. Since SLM access instructions
-// on Xe2 and Xe3 operate on 32-bit or 64-bit units, all data types smaller than
-// 32 bits will be converted to 32 bits.
class CreateMemDescOpPattern final
: public OpConversionPattern<xegpu::CreateMemDescOp> {
public:
@@ -522,16 +653,7 @@ public:
matchAndRewrite(xegpu::CreateMemDescOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- auto resTy = op.getMemDesc();
-
- // Create the result MemRefType with the same shape, element type, and
- // memory space
- auto newResTy = getTypeConverter()->convertType<MemRefType>(resTy);
-
- Value zero = arith::ConstantIndexOp::create(rewriter, op.getLoc(), 0);
- auto viewOp = memref::ViewOp::create(rewriter, op.getLoc(), newResTy,
- op.getSource(), zero, ValueRange());
- rewriter.replaceOp(op, viewOp);
+ rewriter.replaceOp(op, adaptor.getSource());
return success();
}
};
@@ -551,17 +673,27 @@ class LoadStoreMatrixToXeVMPattern : public OpConversionPattern<OpType> {
auto loc = op.getLoc();
auto ctxt = rewriter.getContext();
- Value basePtrStruct = adaptor.getMemDesc();
+ Value baseAddr32 = adaptor.getMemDesc();
Value mdescVal = op.getMemDesc();
// Load result or Store value Type can be vector or scalar.
- Value data;
- if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>)
- data = op.getResult();
- else
- data = adaptor.getData();
- VectorType valOrResVecTy = dyn_cast<VectorType>(data.getType());
+ Type dataTy;
+ if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) {
+ Type resType = op.getResult().getType();
+ // Some transforms may leave unit dimension in the 2D vector, adaptors do
+      // Some transforms may leave unit dimensions in the 2D result vector;
+      // adaptors do not catch this for results.
+ assert(llvm::count_if(vecType.getShape(),
+ [](int64_t d) { return d != 1; }) <= 1 &&
+ "Expected either 1D vector or nD with unit dimensions");
+ resType = VectorType::get({vecType.getNumElements()},
+ vecType.getElementType());
+ }
+ dataTy = resType;
+ } else
+ dataTy = adaptor.getData().getType();
+ VectorType valOrResVecTy = dyn_cast<VectorType>(dataTy);
if (!valOrResVecTy)
- valOrResVecTy = VectorType::get(1, data.getType());
+ valOrResVecTy = VectorType::get(1, dataTy);
int64_t elemBitWidth =
valOrResVecTy.getElementType().getIntOrFloatBitWidth();
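A short sketch of the unit-dimension flattening applied to load results above (MLIR C++ API; assumes the input satisfies the assert, i.e. at most one non-unit dimension, and the helper name is hypothetical):

#include "mlir/IR/BuiltinTypes.h"

// Flatten e.g. vector<1x16xf16> or vector<16x1xf16> to vector<16xf16>.
static mlir::VectorType flattenToRank1(mlir::VectorType vecType) {
  return mlir::VectorType::get({vecType.getNumElements()},
                               vecType.getElementType());
}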
@@ -577,21 +709,14 @@ class LoadStoreMatrixToXeVMPattern : public OpConversionPattern<OpType> {
auto mdescTy = cast<xegpu::MemDescType>(mdescVal.getType());
- Value basePtrLLVM = memref::ExtractAlignedPointerAsIndexOp::create(
- rewriter, loc, basePtrStruct);
-
- // Convert base pointer (ptr) to i32
- Value basePtrI32 = arith::IndexCastUIOp::create(
- rewriter, loc, rewriter.getI32Type(), basePtrLLVM);
-
Value linearOffset = mdescTy.getLinearOffsets(rewriter, loc, offsets);
linearOffset = arith::IndexCastUIOp::create(
rewriter, loc, rewriter.getI32Type(), linearOffset);
- basePtrI32 = addOffsetToBaseAddr(rewriter, loc, basePtrI32, linearOffset,
- elemByteSize);
+ Value basePtrI32 = addOffsetToBaseAddr(rewriter, loc, baseAddr32,
+ linearOffset, elemByteSize);
// convert base pointer (i32) to LLVM pointer type
- basePtrLLVM =
+ Value basePtrLLVM =
LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI32);
if (op.getSubgroupBlockIoAttr()) {
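A standalone sketch of the SLM addressing above (illustrative; assumes addOffsetToBaseAddr scales the linear element offset by the element byte size before adding, which matches its use here):

#include <cstdint>

// Mem descriptors now lower to a 32-bit SLM base address; accesses add the
// linearized element offset scaled to bytes.
int32_t slmByteAddress(int32_t baseAddr32, int32_t linearOffsetElems,
                       int32_t elemByteSize) {
  return baseAddr32 + linearOffsetElems * elemByteSize;
}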
@@ -927,20 +1052,22 @@ struct ConvertXeGPUToXeVMPass
return VectorType::get(sum, elemType);
});
typeConverter.addConversion([&](xegpu::TensorDescType type) -> Type {
+ // Scattered descriptors are not supported in XeVM lowering.
if (type.isScattered())
+ return {};
+ if (type.getRank() == 1)
return IntegerType::get(&getContext(), 64);
auto i32Type = IntegerType::get(&getContext(), 32);
return VectorType::get(8, i32Type);
});
- // Convert MemDescType into flattened MemRefType for SLM
+ // Convert MemDescType into i32 for SLM
typeConverter.addConversion([&](xegpu::MemDescType type) -> Type {
- Type elemTy = type.getElementType();
- int numElems = type.getNumElements();
- return MemRefType::get(numElems, elemTy, AffineMap(), 3);
+ return IntegerType::get(&getContext(), 32);
});
typeConverter.addConversion([&](MemRefType type) -> Type {
- // Convert MemRefType to i64 type.
+ if (type.getMemorySpaceAsInt() == 3)
+ return IntegerType::get(&getContext(), 32);
return IntegerType::get(&getContext(), 64);
});
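Summarizing the converter rules above as a hedged sketch (it mirrors the lambdas rather than being a drop-in replacement; the tensor-descriptor rule is shown, with mem descriptors and SLM memrefs mapping to i32 and other memrefs to i64):

#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/IR/BuiltinTypes.h"

// Scattered tensor descriptors are unsupported (null); rank-1 descriptors
// become a raw i64 address, higher ranks the 8xi32 payload.
static mlir::Type convertTensorDesc(mlir::xegpu::TensorDescType type,
                                    mlir::MLIRContext *ctx) {
  if (type.isScattered())
    return {};
  if (type.getRank() == 1)
    return mlir::IntegerType::get(ctx, 64);
  return mlir::VectorType::get(8, mlir::IntegerType::get(ctx, 32));
}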
@@ -1057,6 +1184,7 @@ struct ConvertXeGPUToXeVMPass
};
typeConverter.addSourceMaterialization(
singleElementVectorMaterializationCast);
+ typeConverter.addSourceMaterialization(vectorMaterializationCast);
typeConverter.addTargetMaterialization(memrefMaterializationCast);
typeConverter.addTargetMaterialization(ui32MaterializationCast);
typeConverter.addTargetMaterialization(ui64MaterializationCast);