Diffstat (limited to 'mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp')
| -rw-r--r-- | mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp | 426 |
1 file changed, 277 insertions, 149 deletions
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index 33e8f2e..0ecb50e 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -50,11 +50,10 @@ static constexpr int32_t executionSize{16};
 // Offsets to individual fields of the 8xi32 layout nd tensor descriptor.
 enum class NdTdescOffset : uint32_t {
-  BasePtr = 0,       // Base pointer (i64)
-  BaseShapeW = 2,    // Base shape width (i32)
-  BaseShapeH = 3,    // Base shape height (i32)
-  TensorOffsetW = 4, // Tensor offset W (i32)
-  TensorOffsetH = 5  // Tensor offset H (i32)
+  BasePtr = 0,    // Base pointer (i64)
+  BaseShapeW = 2, // Base shape width (i32)
+  BaseShapeH = 3, // Base shape height (i32)
+  BasePitch = 4,  // Base pitch (i32)
 };
 
 static int32_t getNumericXeVMAddrSpace(xegpu::MemorySpace xeGpuMemspace) {
@@ -151,6 +150,14 @@ translateStoreXeGPUCacheHint(std::optional<xegpu::CachePolicy> L1hint,
   }
 }
 
+//
+// Note:
+// Block operations on tiles of sub-byte element types are handled by
+// emulating them with larger element types.
+// Tensor descriptors are kept intact; only the ops consuming them are
+// emulated.
+//
+
 class CreateNdDescToXeVMPattern
     : public OpConversionPattern<xegpu::CreateNdDescOp> {
   using OpConversionPattern::OpConversionPattern;
@@ -179,16 +186,12 @@ class CreateNdDescToXeVMPattern
     Value baseAddr;
     Value baseShapeW;
     Value baseShapeH;
-    Value offsetW;
-    Value offsetH;
     // Source can be a memref or a pointer (ui64, ui32, i64 or i32).
     SmallVector<OpFoldResult> mixedSizes = op.getMixedSizes();
+    SmallVector<OpFoldResult> mixedStrides = op.getMixedStrides();
     // Descriptor shape is expected to be 2D.
     int64_t rank = mixedSizes.size();
-    if (rank != 2)
-      return rewriter.notifyMatchFailure(op, "Expected 2D shape.");
-
     auto sourceTy = source.getType();
     auto sourceMemrefTy = dyn_cast<MemRefType>(sourceTy);
     // If source is a memref, we need to extract the aligned pointer as index.
@@ -197,10 +200,20 @@ class CreateNdDescToXeVMPattern
       if (!sourceMemrefTy.hasRank()) {
         return rewriter.notifyMatchFailure(op, "Expected ranked Memref.");
       }
-      baseAddr =
-          memref::ExtractAlignedPointerAsIndexOp::create(rewriter, loc, source);
+      // Access adaptor after failure check to avoid rolling back generated code
+      // for materialization cast.
+      baseAddr = adaptor.getSource();
     } else {
       baseAddr = adaptor.getSource();
+      if (baseAddr.getType() != i64Ty) {
+        // Pointer type may be i32. Cast to i64 if needed.
+        baseAddr = arith::ExtUIOp::create(rewriter, loc, i64Ty, baseAddr);
+      }
+    }
+    // 1D tensor descriptor is just the base address.
+    if (rank == 1) {
+      rewriter.replaceOp(op, baseAddr);
+      return success();
     }
     // Utility for creating offset values from op fold result.
     auto createOffset = [&](SmallVector<OpFoldResult> &ofrVec,
@@ -209,19 +222,11 @@ class CreateNdDescToXeVMPattern
       val = getValueOrCreateCastToIndexLike(rewriter, loc, payloadElemTy, val);
       return val;
     };
-    // Offsets are not supported (0 is used).
-    offsetW = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0);
-    offsetH = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0);
     // Get shape values from op fold results.
     baseShapeW = createOffset(mixedSizes, 1);
    baseShapeH = createOffset(mixedSizes, 0);
-    if (sourceMemrefTy) {
-      // Cast index to i64.
-      baseAddr = arith::IndexCastUIOp::create(rewriter, loc, i64Ty, baseAddr);
-    } else if (baseAddr.getType() != i64Ty) {
-      // Pointer type may be i32. Cast to i64 if needed.
-      baseAddr = arith::ExtUIOp::create(rewriter, loc, i64Ty, baseAddr);
-    }
+    // Get pitch value from op fold results.
+    Value basePitch = createOffset(mixedStrides, 0);
     // Populate payload.
     Value payLoadAsI64 =
         vector::BitCastOp::create(rewriter, loc, payloadI64Ty, payload);
@@ -235,12 +240,9 @@ class CreateNdDescToXeVMPattern
     payload =
         vector::InsertOp::create(rewriter, loc, baseShapeH, payload,
                                  static_cast<int>(NdTdescOffset::BaseShapeH));
-    payload = vector::InsertOp::create(
-        rewriter, loc, offsetW, payload,
-        static_cast<int>(NdTdescOffset::TensorOffsetW));
-    payload = vector::InsertOp::create(
-        rewriter, loc, offsetH, payload,
-        static_cast<int>(NdTdescOffset::TensorOffsetH));
+    payload =
+        vector::InsertOp::create(rewriter, loc, basePitch, payload,
+                                 static_cast<int>(NdTdescOffset::BasePitch));
     rewriter.replaceOp(op, payload);
     return success();
   }
@@ -257,108 +259,240 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
                   ConversionPatternRewriter &rewriter) const override {
     auto mixedOffsets = op.getMixedOffsets();
     int64_t opOffsetsSize = mixedOffsets.size();
-    if (opOffsetsSize != 2)
-      return rewriter.notifyMatchFailure(op, "Expected 2D offsets.");
     auto loc = op.getLoc();
     auto ctxt = rewriter.getContext();
     auto tdesc = adaptor.getTensorDesc();
     auto tdescTy = op.getTensorDescType();
-    if (tdescTy.getRank() != 2)
-      return rewriter.notifyMatchFailure(op, "Expected 2D tensor descriptor.");
+    auto tileRank = tdescTy.getRank();
+    if (opOffsetsSize != tileRank)
+      return rewriter.notifyMatchFailure(
+          op, "Expected offset rank to match descriptor rank.");
     auto elemType = tdescTy.getElementType();
     auto elemBitSize = elemType.getIntOrFloatBitWidth();
-    if (elemBitSize % 8 != 0)
+    bool isSubByte = elemBitSize < 8;
+    uint64_t wScaleFactor = 1;
+
+    if (!isSubByte && (elemBitSize % 8 != 0))
       return rewriter.notifyMatchFailure(
           op, "Expected element type bit width to be multiple of 8.");
+    auto tileW = tdescTy.getDimSize(tileRank - 1);
+    // For sub-byte types, only 4-bit elements are currently supported.
+    if (isSubByte) {
+      if (elemBitSize != 4)
+        return rewriter.notifyMatchFailure(
+            op, "Only sub byte types of 4bits are supported.");
+      if (tileRank != 2)
+        return rewriter.notifyMatchFailure(
+            op, "Sub byte types are only supported for 2D tensor descriptors.");
+      auto subByteFactor = 8 / elemBitSize;
+      auto tileH = tdescTy.getDimSize(0);
+      // Handle special case for packed load.
+      if constexpr (std::is_same_v<OpType, xegpu::LoadNdOp>) {
+        if (op.getPacked().value_or(false)) {
+          // A packed load is implemented as packed loads of 8-bit elements.
+          if (tileH == systolicDepth * 4 &&
+              tileW == executionSize * subByteFactor) {
+            // Usage case for loading as Matrix B with pack request.
+            // Source is assumed to be pre-packed into 8-bit elements.
+            // Emulate with 8-bit loads with pack request.
+            // scaled_tileW = executionSize
+            elemType = rewriter.getIntegerType(8);
+            tileW = executionSize;
+            wScaleFactor = subByteFactor;
+          }
+        }
+      }
+      // If not handled by the packed-load case above, handle the other cases.
+      if (wScaleFactor == 1) {
+        auto sub16BitFactor = subByteFactor * 2;
+        if (tileW == executionSize * sub16BitFactor) {
+          // Usage case for loading as Matrix A operand.
+          // Emulate with 16-bit loads/stores.
+          // scaled_tileW = executionSize
+          elemType = rewriter.getIntegerType(16);
+          tileW = executionSize;
+          wScaleFactor = sub16BitFactor;
+        } else {
+          return rewriter.notifyMatchFailure(
+              op, "Unsupported tile shape for sub byte types.");
+        }
+      }
+      // Recompute element bit size for emulation.
+      elemBitSize = elemType.getIntOrFloatBitWidth();
+    }
 
-    VectorType payloadI64Ty = VectorType::get(4, rewriter.getI64Type());
-    Value payLoadAsI64 =
-        vector::BitCastOp::create(rewriter, loc, payloadI64Ty, tdesc);
-    Value basePtr = vector::ExtractOp::create(
-        rewriter, loc, payLoadAsI64, static_cast<int>(NdTdescOffset::BasePtr));
-    Value baseShapeW = vector::ExtractOp::create(
-        rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeW));
-    Value baseShapeH = vector::ExtractOp::create(
-        rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeH));
-    // Offsets are provided by the op.
-    // convert them to i32.
-    Value offsetW =
-        getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[1]);
-    offsetW = getValueOrCreateCastToIndexLike(rewriter, loc,
-                                              rewriter.getI32Type(), offsetW);
-    Value offsetH =
-        getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[0]);
-    offsetH = getValueOrCreateCastToIndexLike(rewriter, loc,
-                                              rewriter.getI32Type(), offsetH);
     // Get address space from tensor descriptor memory space.
     auto ptrTypeLLVM = LLVM::LLVMPointerType::get(
         ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
-    // Convert base pointer (i64) to LLVM pointer type.
-    Value basePtrLLVM =
-        LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtr);
-    // Compute element byte size and surface width in bytes.
-    Value elemByteSize = arith::ConstantIntOp::create(
-        rewriter, loc, rewriter.getI32Type(), elemBitSize / 8);
-    Value surfaceW =
-        arith::MulIOp::create(rewriter, loc, baseShapeW, elemByteSize);
-
-    // Get tile sizes and vblocks from the tensor descriptor type.
-    auto tileW = tdescTy.getDimSize(1);
-    auto tileH = tdescTy.getDimSize(0);
-    int32_t vblocks = tdescTy.getArrayLength();
-    if constexpr (std::is_same_v<OpType, xegpu::StoreNdOp>) {
-      Value src = adaptor.getValue();
-      // If store value is a scalar, get value from op instead of adaptor.
-      // Adaptor might have optimized away single element vector
-      if (src.getType().isIntOrFloat()) {
-        src = op.getValue();
+    if (tileRank == 2) {
+      // Compute element byte size.
+      Value elemByteSize = arith::ConstantIntOp::create(
+          rewriter, loc, rewriter.getI32Type(), elemBitSize / 8);
+      VectorType payloadI64Ty = VectorType::get(4, rewriter.getI64Type());
+      Value payLoadAsI64 =
+          vector::BitCastOp::create(rewriter, loc, payloadI64Ty, tdesc);
+      Value basePtr =
+          vector::ExtractOp::create(rewriter, loc, payLoadAsI64,
                                    static_cast<int>(NdTdescOffset::BasePtr));
+      Value baseShapeW = vector::ExtractOp::create(
+          rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeW));
+      Value baseShapeH = vector::ExtractOp::create(
+          rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeH));
+      Value basePitch = vector::ExtractOp::create(
+          rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BasePitch));
+      // Offsets are provided by the op.
+      // convert them to i32.
+      Value offsetW =
+          getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[1]);
+      offsetW = getValueOrCreateCastToIndexLike(rewriter, loc,
+                                                rewriter.getI32Type(), offsetW);
+      Value offsetH =
+          getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[0]);
+      offsetH = getValueOrCreateCastToIndexLike(rewriter, loc,
+                                                rewriter.getI32Type(), offsetH);
+      // Convert base pointer (i64) to LLVM pointer type.
+      Value basePtrLLVM =
+          LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtr);
+      // FIXME: width or pitch is not the same as baseShapeW; it should be the
+      // stride of the second-to-last dimension in row-major layout.
+      // Compute width in bytes.
+      Value baseShapeWInBytes =
+          arith::MulIOp::create(rewriter, loc, baseShapeW, elemByteSize);
+      // Compute pitch in bytes.
+      Value basePitchBytes =
+          arith::MulIOp::create(rewriter, loc, basePitch, elemByteSize);
+
+      if (wScaleFactor > 1) {
+        // Scale offsetW, baseShapeWInBytes for sub-byte emulation.
+        // Note: tileW is already scaled above.
+        Value wScaleFactorValLog2 = arith::ConstantIntOp::create(
+            rewriter, loc, rewriter.getI32Type(), llvm::Log2_64(wScaleFactor));
+        baseShapeWInBytes = arith::ShRSIOp::create(
+            rewriter, loc, baseShapeWInBytes, wScaleFactorValLog2);
+        basePitchBytes = arith::ShRSIOp::create(rewriter, loc, basePitchBytes,
+                                                wScaleFactorValLog2);
+        offsetW =
+            arith::ShRSIOp::create(rewriter, loc, offsetW, wScaleFactorValLog2);
       }
-      VectorType srcVecTy = dyn_cast<VectorType>(src.getType());
-      if (!srcVecTy)
-        return rewriter.notifyMatchFailure(
-            op, "Expected store value to be a vector type.");
-      // Get flat vector type of integer type with matching element bit size.
-      VectorType newSrcVecTy =
-          encodeVectorTypeTo(srcVecTy, rewriter.getIntegerType(elemBitSize));
-      if (srcVecTy != newSrcVecTy)
-        src = vector::BitCastOp::create(rewriter, loc, newSrcVecTy, src);
-      auto storeCacheControl =
-          translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
-      xevm::BlockStore2dOp::create(
-          rewriter, loc, basePtrLLVM, surfaceW, baseShapeH, surfaceW, offsetW,
-          offsetH, elemBitSize, tileW, tileH, src,
-          xevm::StoreCacheControlAttr::get(ctxt, storeCacheControl));
-      rewriter.eraseOp(op);
-    } else {
-      auto loadCacheControl =
-          translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
-      if constexpr (std::is_same_v<OpType, xegpu::PrefetchNdOp>) {
-        xevm::BlockPrefetch2dOp::create(
-            rewriter, loc, basePtrLLVM, surfaceW, baseShapeH, surfaceW, offsetW,
-            offsetH, elemBitSize, tileW, tileH, vblocks,
-            xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
+      // Get tile height from the tensor descriptor type.
+      auto tileH = tdescTy.getDimSize(0);
+      // Get vblocks from the tensor descriptor type.
+      int32_t vblocks = tdescTy.getArrayLength();
+      if constexpr (std::is_same_v<OpType, xegpu::StoreNdOp>) {
+        Value src = adaptor.getValue();
+        // If store value is a scalar, get value from op instead of adaptor.
+        // Adaptor might have optimized away single element vector
+        if (src.getType().isIntOrFloat()) {
+          src = op.getValue();
+        }
+        VectorType srcVecTy = dyn_cast<VectorType>(src.getType());
+        if (!srcVecTy)
+          return rewriter.notifyMatchFailure(
+              op, "Expected store value to be a vector type.");
+        // Get flat vector type of integer type with matching element bit size.
+        VectorType newSrcVecTy =
+            encodeVectorTypeTo(srcVecTy, rewriter.getIntegerType(elemBitSize));
+        if (srcVecTy != newSrcVecTy)
+          src = vector::BitCastOp::create(rewriter, loc, newSrcVecTy, src);
+        auto storeCacheControl =
+            translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
+        xevm::BlockStore2dOp::create(
+            rewriter, loc, basePtrLLVM, baseShapeWInBytes, baseShapeH,
+            basePitchBytes, offsetW, offsetH, elemBitSize, tileW, tileH, src,
+            xevm::StoreCacheControlAttr::get(ctxt, storeCacheControl));
         rewriter.eraseOp(op);
       } else {
-        VectorType dstVecTy = cast<VectorType>(op.getValue().getType());
-        const bool vnni = op.getPacked().value_or(false);
-        auto transposeValue = op.getTranspose();
-        bool transpose =
-            transposeValue.has_value() && transposeValue.value()[0] == 1;
-        VectorType loadedTy = encodeVectorTypeTo(
-            dstVecTy, vnni ? rewriter.getI32Type()
-                           : rewriter.getIntegerType(elemBitSize));
-
-        Value resultFlatVec = xevm::BlockLoad2dOp::create(
-            rewriter, loc, loadedTy, basePtrLLVM, surfaceW, baseShapeH,
-            surfaceW, offsetW, offsetH, elemBitSize, tileW, tileH, vblocks,
-            transpose, vnni,
+        auto loadCacheControl =
+            translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
+        if constexpr (std::is_same_v<OpType, xegpu::PrefetchNdOp>) {
+          xevm::BlockPrefetch2dOp::create(
+              rewriter, loc, basePtrLLVM, baseShapeWInBytes, baseShapeH,
+              basePitchBytes, offsetW, offsetH, elemBitSize, tileW, tileH,
+              vblocks, xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
+          rewriter.eraseOp(op);
+        } else {
+          VectorType dstVecTy = cast<VectorType>(op.getValue().getType());
+          const bool vnni = op.getPacked().value_or(false);
+          auto transposeValue = op.getTranspose();
+          bool transpose =
+              transposeValue.has_value() && transposeValue.value()[0] == 1;
+          VectorType loadedTy = encodeVectorTypeTo(
+              dstVecTy, vnni ? rewriter.getI32Type()
+                             : rewriter.getIntegerType(elemBitSize));
+
+          Value resultFlatVec = xevm::BlockLoad2dOp::create(
+              rewriter, loc, loadedTy, basePtrLLVM, baseShapeWInBytes,
+              baseShapeH, basePitchBytes, offsetW, offsetH, elemBitSize, tileW,
+              tileH, vblocks, transpose, vnni,
+              xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
+          resultFlatVec = vector::BitCastOp::create(
+              rewriter, loc,
+              encodeVectorTypeTo(loadedTy, dstVecTy.getElementType()),
+              resultFlatVec);
+          rewriter.replaceOp(op, resultFlatVec);
+        }
+      }
+    } else {
+      // 1D tensor descriptor.
+      // `tdesc` represents base address as i64.
+      // Offset is in number of elements, so multiply by element byte size.
+      // Compute byte offset.
+      // byteOffset = offset * elementByteSize
+      Value offset =
+          getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[0]);
+      offset = getValueOrCreateCastToIndexLike(rewriter, loc,
+                                               rewriter.getI64Type(), offset);
+      // Compute element byte size.
+      Value elemByteSize = arith::ConstantIntOp::create(
+          rewriter, loc, rewriter.getI64Type(), elemBitSize / 8);
+      Value byteOffset =
+          rewriter.createOrFold<arith::MulIOp>(loc, offset, elemByteSize);
+      // Final address = basePtr + byteOffset
+      Value finalAddrI64 = rewriter.createOrFold<arith::AddIOp>(
+          loc, tdesc,
+          getValueOrCreateCastToIndexLike(rewriter, loc, rewriter.getI64Type(),
+                                          byteOffset));
+      // Convert base pointer (i64) to LLVM pointer type.
+      Value finalPtrLLVM =
+          LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, finalAddrI64);
+      if constexpr (std::is_same_v<OpType, xegpu::StoreNdOp>) {
+        Value src = adaptor.getValue();
+        // If store value is a scalar, get value from op instead of adaptor.
+        // Adaptor might have optimized away single element vector
+        if (src.getType().isIntOrFloat()) {
+          src = op.getValue();
+        }
+        VectorType srcVecTy = dyn_cast<VectorType>(src.getType());
+        if (!srcVecTy)
+          return rewriter.notifyMatchFailure(
+              op, "Expected store value to be a vector type.");
+        // Get flat vector type of integer type with matching element bit size.
+        VectorType newSrcVecTy =
+            encodeVectorTypeTo(srcVecTy, rewriter.getIntegerType(elemBitSize));
+        if (srcVecTy != newSrcVecTy)
+          src = vector::BitCastOp::create(rewriter, loc, newSrcVecTy, src);
+        auto storeCacheControl =
+            translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
+        rewriter.replaceOpWithNewOp<xevm::BlockStoreOp>(
+            op, finalPtrLLVM, src,
+            xevm::StoreCacheControlAttr::get(ctxt, storeCacheControl));
+      } else if constexpr (std::is_same_v<OpType, xegpu::LoadNdOp>) {
+        auto loadCacheControl =
+            translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
+        VectorType resTy = cast<VectorType>(op.getValue().getType());
+        VectorType loadedTy =
+            encodeVectorTypeTo(resTy, rewriter.getIntegerType(elemBitSize));
+        Value load = xevm::BlockLoadOp::create(
+            rewriter, loc, loadedTy, finalPtrLLVM,
             xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
-        resultFlatVec = vector::BitCastOp::create(
-            rewriter, loc,
-            encodeVectorTypeTo(loadedTy, dstVecTy.getElementType()),
-            resultFlatVec);
-        rewriter.replaceOp(op, resultFlatVec);
+        if (loadedTy != resTy)
+          load = vector::BitCastOp::create(rewriter, loc, resTy, load);
+        rewriter.replaceOp(op, load);
+      } else {
+        return rewriter.notifyMatchFailure(
+            op, "Unsupported operation: xegpu.prefetch_nd with tensor "
+                "descriptor rank == 1");
       }
     }
     return success();
@@ -511,9 +645,6 @@ class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> {
   }
 };
 
-// Lower xegpu::CreateMemDescOp to memref::ViewOp. Since SLM access instructions
-// on Xe2 and Xe3 operate on 32-bit or 64-bit units, all data types smaller than
-// 32 bits will be converted to 32 bits.
 class CreateMemDescOpPattern final
     : public OpConversionPattern<xegpu::CreateMemDescOp> {
 public:
@@ -522,16 +653,7 @@ public:
   matchAndRewrite(xegpu::CreateMemDescOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
 
-    auto resTy = op.getMemDesc();
-
-    // Create the result MemRefType with the same shape, element type, and
-    // memory space
-    auto newResTy = getTypeConverter()->convertType<MemRefType>(resTy);
-
-    Value zero = arith::ConstantIndexOp::create(rewriter, op.getLoc(), 0);
-    auto viewOp = memref::ViewOp::create(rewriter, op.getLoc(), newResTy,
-                                         op.getSource(), zero, ValueRange());
-    rewriter.replaceOp(op, viewOp);
+    rewriter.replaceOp(op, adaptor.getSource());
     return success();
   }
 };
@@ -551,17 +673,27 @@ class LoadStoreMatrixToXeVMPattern : public OpConversionPattern<OpType> {
     auto loc = op.getLoc();
     auto ctxt = rewriter.getContext();
 
-    Value basePtrStruct = adaptor.getMemDesc();
+    Value baseAddr32 = adaptor.getMemDesc();
     Value mdescVal = op.getMemDesc();
     // Load result or Store value Type can be vector or scalar.
-    Value data;
-    if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>)
-      data = op.getResult();
-    else
-      data = adaptor.getData();
-    VectorType valOrResVecTy = dyn_cast<VectorType>(data.getType());
+    Type dataTy;
+    if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) {
+      Type resType = op.getResult().getType();
+      // Some transforms may leave unit dimensions in the 2D vector; adaptors
+      // do not catch this for results.
+      if (auto vecType = dyn_cast<VectorType>(resType)) {
+        assert(llvm::count_if(vecType.getShape(),
                              [](int64_t d) { return d != 1; }) <= 1 &&
               "Expected either 1D vector or nD with unit dimensions");
+        resType = VectorType::get({vecType.getNumElements()},
                                  vecType.getElementType());
+      }
+      dataTy = resType;
+    } else
+      dataTy = adaptor.getData().getType();
+    VectorType valOrResVecTy = dyn_cast<VectorType>(dataTy);
     if (!valOrResVecTy)
-      valOrResVecTy = VectorType::get(1, data.getType());
+      valOrResVecTy = VectorType::get(1, dataTy);
 
     int64_t elemBitWidth =
         valOrResVecTy.getElementType().getIntOrFloatBitWidth();
@@ -577,21 +709,14 @@ class LoadStoreMatrixToXeVMPattern : public OpConversionPattern<OpType> {
 
     auto mdescTy = cast<xegpu::MemDescType>(mdescVal.getType());
 
-    Value basePtrLLVM = memref::ExtractAlignedPointerAsIndexOp::create(
-        rewriter, loc, basePtrStruct);
-
-    // Convert base pointer (ptr) to i32
-    Value basePtrI32 = arith::IndexCastUIOp::create(
-        rewriter, loc, rewriter.getI32Type(), basePtrLLVM);
-
     Value linearOffset = mdescTy.getLinearOffsets(rewriter, loc, offsets);
     linearOffset = arith::IndexCastUIOp::create(
         rewriter, loc, rewriter.getI32Type(), linearOffset);
-    basePtrI32 = addOffsetToBaseAddr(rewriter, loc, basePtrI32, linearOffset,
-                                     elemByteSize);
+    Value basePtrI32 = addOffsetToBaseAddr(rewriter, loc, baseAddr32,
+                                           linearOffset, elemByteSize);
 
     // convert base pointer (i32) to LLVM pointer type
-    basePtrLLVM =
+    Value basePtrLLVM =
         LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI32);
 
     if (op.getSubgroupBlockIoAttr()) {
@@ -927,20 +1052,22 @@ struct ConvertXeGPUToXeVMPass
       return VectorType::get(sum, elemType);
     });
     typeConverter.addConversion([&](xegpu::TensorDescType type) -> Type {
+      // Scattered descriptors are not supported in XeVM lowering.
       if (type.isScattered())
+        return {};
+      if (type.getRank() == 1)
        return IntegerType::get(&getContext(), 64);
      auto i32Type = IntegerType::get(&getContext(), 32);
      return VectorType::get(8, i32Type);
     });
-    // Convert MemDescType into flattened MemRefType for SLM
+    // Convert MemDescType into i32 for SLM
     typeConverter.addConversion([&](xegpu::MemDescType type) -> Type {
-      Type elemTy = type.getElementType();
-      int numElems = type.getNumElements();
-      return MemRefType::get(numElems, elemTy, AffineMap(), 3);
+      return IntegerType::get(&getContext(), 32);
     });
     typeConverter.addConversion([&](MemRefType type) -> Type {
-      // Convert MemRefType to i64 type.
+      if (type.getMemorySpaceAsInt() == 3)
+        return IntegerType::get(&getContext(), 32);
       return IntegerType::get(&getContext(), 64);
     });
@@ -1057,6 +1184,7 @@ struct ConvertXeGPUToXeVMPass
     };
     typeConverter.addSourceMaterialization(
         singleElementVectorMaterializationCast);
+    typeConverter.addSourceMaterialization(vectorMaterializationCast);
     typeConverter.addTargetMaterialization(memrefMaterializationCast);
     typeConverter.addTargetMaterialization(ui32MaterializationCast);
     typeConverter.addTargetMaterialization(ui64MaterializationCast);

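For the sub-byte emulation path, the following is a minimal sketch of how width-related values are rescaled when an i4 tile is emulated with a wider element type. executionSize = 16 comes from the hunk header; systolicDepth = 8 and the tile/surface sizes are assumptions made for the example, and this is plain arithmetic, not the MLIR conversion pattern itself.

#include <cstdint>
#include <cstdio>

int main() {
  constexpr int64_t executionSize = 16;
  constexpr int64_t systolicDepth = 8; // assumed; not shown in this diff
  const int64_t elemBitSize = 4;       // i4
  const int64_t subByteFactor = 8 / elemBitSize; // two i4 elements per byte

  // Matrix B usage with a packed load request: a 32x32xi4 tile, i.e.
  // (systolicDepth * 4) x (executionSize * subByteFactor).
  int64_t tileH = systolicDepth * 4;
  int64_t tileW = executionSize * subByteFactor;
  int64_t emulBits = elemBitSize;
  int64_t wScaleFactor = 1;
  if (tileH == systolicDepth * 4 && tileW == executionSize * subByteFactor) {
    emulBits = 8;          // emulate with packed i8 loads
    tileW = executionSize; // 16 i8 columns instead of 32 i4 columns
    wScaleFactor = subByteFactor;
  }
  // (The Matrix A case instead emulates with i16 and
  //  wScaleFactor = subByteFactor * 2.)

  // Width, pitch, and offsetW stay in i4 elements in the descriptor and are
  // rescaled at the access site by shifting with log2(wScaleFactor), which is
  // what the ShRSIOp sequence in the diff does.
  int64_t baseShapeW = 256, basePitch = 256, offsetW = 64; // in i4 elements
  const int64_t emulBytes = emulBits / 8;
  int64_t shift = 0;
  for (int64_t f = wScaleFactor; f > 1; f >>= 1)
    ++shift;
  int64_t widthBytes = (baseShapeW * emulBytes) >> shift; // 128 bytes
  int64_t pitchBytes = (basePitch * emulBytes) >> shift;  // 128 bytes
  int64_t offsetWEmul = offsetW >> shift;                 // 32 i8 elements
  std::printf("tileW=%lld widthBytes=%lld pitchBytes=%lld offsetW=%lld\n",
              (long long)tileW, (long long)widthBytes, (long long)pitchBytes,
              (long long)offsetWEmul);
  return 0;
}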