//===-- XeGPUToXeVM.cpp - XeGPU to XeVM dialect conversion ------*- C++ -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/XeGPUToXeVM/XeGPUToXeVM.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Dialect/LLVMIR/XeVMDialect.h"

#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Index/IR/IndexDialect.h"
#include "mlir/Dialect/Index/IR/IndexOps.h"
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/FormatVariadic.h"

#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Types.h"

#include "llvm/ADT/TypeSwitch.h"

#include <type_traits>

namespace mlir {
#define GEN_PASS_DEF_CONVERTXEGPUTOXEVMPASS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;

namespace {

// TODO: Below are uArch-dependent values; they should move away from
// hardcoding.
static constexpr int32_t systolicDepth{8};
static constexpr int32_t executionSize{16};

// Offsets to individual fields of the 8xi32 layout nd tensor descriptor.
enum class NdTdescOffset : uint32_t {
  BasePtr = 0,       // Base pointer (i64)
  BaseShapeW = 2,    // Base shape width (i32)
  BaseShapeH = 3,    // Base shape height (i32)
  TensorOffsetW = 4, // Tensor offset W (i32)
  TensorOffsetH = 5  // Tensor offset H (i32)
};

static int32_t getNumericXeVMAddrSpace(xegpu::MemorySpace xeGpuMemspace) {
  switch (xeGpuMemspace) {
  case xegpu::MemorySpace::Global:
    return static_cast<int32_t>(xevm::AddrSpace::GLOBAL);
  case xegpu::MemorySpace::SLM:
    return static_cast<int32_t>(xevm::AddrSpace::SHARED);
  }
  llvm_unreachable("Unknown XeGPU memory space.");
}
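
// For illustration, a sketch of the 8xi32 payload built by
// CreateNdDescToXeVMPattern below (lane indices follow the NdTdescOffset
// enum; lanes 6-7 are currently unused):
//
//   lanes 0-1 : BasePtr       (base pointer, i64, written via a 4xi64
//                              bitcast view of the payload)
//   lane  2   : BaseShapeW    (base shape width, i32)
//   lane  3   : BaseShapeH    (base shape height, i32)
//   lane  4   : TensorOffsetW (tensor offset W, i32)
//   lane  5   : TensorOffsetH (tensor offset H, i32)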
// Get a flat vector type of the same total bit width, with a new element type.
static VectorType encodeVectorTypeTo(VectorType currentVecType,
                                     Type toElemType) {
  auto elemType = currentVecType.getElementType();
  auto currentBitWidth = elemType.getIntOrFloatBitWidth();
  auto newBitWidth = toElemType.getIntOrFloatBitWidth();
  const int size =
      currentVecType.getNumElements() * currentBitWidth / newBitWidth;
  return VectorType::get(size, toElemType);
}

static xevm::LoadCacheControl
translateLoadXeGPUCacheHint(std::optional<xegpu::CachePolicy> L1hint,
                            std::optional<xegpu::CachePolicy> L3hint) {
  auto L1hintVal = L1hint.value_or(xegpu::CachePolicy::UNCACHED);
  auto L3hintVal = L3hint.value_or(xegpu::CachePolicy::UNCACHED);
  switch (L1hintVal) {
  case xegpu::CachePolicy::CACHED:
    if (L3hintVal == xegpu::CachePolicy::CACHED)
      return xevm::LoadCacheControl::L1C_L2UC_L3C;
    else if (L3hintVal == xegpu::CachePolicy::UNCACHED)
      return xevm::LoadCacheControl::L1C_L2UC_L3UC;
    else
      llvm_unreachable("Unsupported cache control.");
  case xegpu::CachePolicy::UNCACHED:
    if (L3hintVal == xegpu::CachePolicy::CACHED)
      return xevm::LoadCacheControl::L1UC_L2UC_L3C;
    else if (L3hintVal == xegpu::CachePolicy::UNCACHED)
      return xevm::LoadCacheControl::L1UC_L2UC_L3UC;
    else
      llvm_unreachable("Unsupported cache control.");
  case xegpu::CachePolicy::STREAMING:
    if (L3hintVal == xegpu::CachePolicy::CACHED)
      return xevm::LoadCacheControl::L1S_L2UC_L3C;
    else if (L3hintVal == xegpu::CachePolicy::UNCACHED)
      return xevm::LoadCacheControl::L1S_L2UC_L3UC;
    else
      llvm_unreachable("Unsupported cache control.");
  case xegpu::CachePolicy::READ_INVALIDATE:
    return xevm::LoadCacheControl::INVALIDATE_READ;
  default:
    llvm_unreachable("Unsupported cache control.");
  }
}

static xevm::StoreCacheControl
translateStoreXeGPUCacheHint(std::optional<xegpu::CachePolicy> L1hint,
                             std::optional<xegpu::CachePolicy> L3hint) {
  auto L1hintVal = L1hint.value_or(xegpu::CachePolicy::UNCACHED);
  auto L3hintVal = L3hint.value_or(xegpu::CachePolicy::UNCACHED);
  switch (L1hintVal) {
  case xegpu::CachePolicy::UNCACHED:
    if (L3hintVal == xegpu::CachePolicy::UNCACHED)
      return xevm::StoreCacheControl::L1UC_L2UC_L3UC;
    else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK)
      return xevm::StoreCacheControl::L1UC_L2UC_L3WB;
    else
      llvm_unreachable("Unsupported cache control.");
  case xegpu::CachePolicy::STREAMING:
    if (L3hintVal == xegpu::CachePolicy::UNCACHED)
      return xevm::StoreCacheControl::L1S_L2UC_L3UC;
    else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK)
      return xevm::StoreCacheControl::L1S_L2UC_L3WB;
    else
      llvm_unreachable("Unsupported cache control.");
  case xegpu::CachePolicy::WRITE_BACK:
    if (L3hintVal == xegpu::CachePolicy::UNCACHED)
      return xevm::StoreCacheControl::L1WB_L2UC_L3UC;
    else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK)
      return xevm::StoreCacheControl::L1WB_L2UC_L3WB;
    else
      llvm_unreachable("Unsupported cache control.");
  case xegpu::CachePolicy::WRITE_THROUGH:
    if (L3hintVal == xegpu::CachePolicy::UNCACHED)
      return xevm::StoreCacheControl::L1WT_L2UC_L3UC;
    else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK)
      return xevm::StoreCacheControl::L1WT_L2UC_L3WB;
    else
      llvm_unreachable("Unsupported cache control.");
  default:
    llvm_unreachable("Unsupported cache control.");
  }
}
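
// For illustration, a sketch of the hint translation above: a load with
// l1_hint = CACHED and l3_hint = UNCACHED maps to
// xevm::LoadCacheControl::L1C_L2UC_L3UC; absent hints default to UNCACHED on
// both levels, and L2 control is always UC here.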
class CreateNdDescToXeVMPattern
    : public OpConversionPattern<xegpu::CreateNdDescOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::CreateNdDescOp op,
                  xegpu::CreateNdDescOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    SmallVector<OpFoldResult> mixedOffsets = op.getMixedOffsets();
    if (mixedOffsets.size() != 0)
      return rewriter.notifyMatchFailure(op, "Offsets not supported.");
    auto loc = op.getLoc();
    auto source = op.getSource();
    // The op is lowered to a code sequence that populates the payload.
    // The payload is an 8xi32 vector; offsets to its individual fields are
    // defined by the NdTdescOffset enum.
    Type payloadElemTy = rewriter.getI32Type();
    VectorType payloadTy = VectorType::get(8, payloadElemTy);
    Type i64Ty = rewriter.getI64Type();
    // A 4xi64 view is used for inserting the base pointer.
    VectorType payloadI64Ty = VectorType::get(4, i64Ty);
    // Initialize payload to zero.
    Value payload = arith::ConstantOp::create(
        rewriter, loc,
        DenseElementsAttr::get(payloadTy, IntegerAttr::get(payloadElemTy, 0)));

    Value baseAddr;
    Value baseShapeW;
    Value baseShapeH;
    Value offsetW;
    Value offsetH;

    // Source can be a memref or a pointer (ui64, ui32, i64 or i32).
    SmallVector<OpFoldResult> mixedSizes = op.getMixedSizes();
    // Descriptor shape is expected to be 2D.
    int64_t rank = mixedSizes.size();
    if (rank != 2)
      return rewriter.notifyMatchFailure(op, "Expected 2D shape.");
    auto sourceTy = source.getType();
    auto sourceMemrefTy = dyn_cast<MemRefType>(sourceTy);
    // If the source is a memref, extract the aligned pointer as index.
    // A pointer type is passed as i32 or i64 by the type converter.
    if (sourceMemrefTy) {
      if (!sourceMemrefTy.hasStaticShape())
        return rewriter.notifyMatchFailure(op, "Expected static memref shape.");
      baseAddr =
          memref::ExtractAlignedPointerAsIndexOp::create(rewriter, loc, source);
    } else {
      baseAddr = adaptor.getSource();
    }
    // Utility for creating offset values from op fold results.
    auto createOffset = [&](SmallVector<OpFoldResult> &ofrVec,
                            unsigned idx) -> Value {
      Value val = getValueOrCreateConstantIntOp(rewriter, loc, ofrVec[idx]);
      val = getValueOrCreateCastToIndexLike(rewriter, loc, payloadElemTy, val);
      return val;
    };
    // Offsets are not supported (0 is used).
    offsetW = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0);
    offsetH = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0);
    // Get shape values from op fold results.
    baseShapeW = createOffset(mixedSizes, 1);
    baseShapeH = createOffset(mixedSizes, 0);
    if (sourceMemrefTy) {
      // Cast index to i64.
      baseAddr = arith::IndexCastUIOp::create(rewriter, loc, i64Ty, baseAddr);
    } else if (baseAddr.getType() != i64Ty) {
      // Pointer type may be i32. Cast to i64 if needed.
      baseAddr = arith::ExtUIOp::create(rewriter, loc, i64Ty, baseAddr);
    }
    // Populate payload.
    Value payLoadAsI64 =
        vector::BitCastOp::create(rewriter, loc, payloadI64Ty, payload);
    payLoadAsI64 =
        vector::InsertOp::create(rewriter, loc, baseAddr, payLoadAsI64,
                                 static_cast<int>(NdTdescOffset::BasePtr));
    payload = vector::BitCastOp::create(rewriter, loc, payloadTy, payLoadAsI64);
    payload =
        vector::InsertOp::create(rewriter, loc, baseShapeW, payload,
                                 static_cast<int>(NdTdescOffset::BaseShapeW));
    payload =
        vector::InsertOp::create(rewriter, loc, baseShapeH, payload,
                                 static_cast<int>(NdTdescOffset::BaseShapeH));
    payload = vector::InsertOp::create(
        rewriter, loc, offsetW, payload,
        static_cast<int>(NdTdescOffset::TensorOffsetW));
    payload = vector::InsertOp::create(
        rewriter, loc, offsetH, payload,
        static_cast<int>(NdTdescOffset::TensorOffsetH));
    rewriter.replaceOp(op, payload);
    return success();
  }
};
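
// For illustration, a hypothetical lowering sketch (types abbreviated):
//
//   %tdesc = xegpu.create_nd_tdesc %src : memref<16x32xf16>
//              -> !xegpu.tensor_desc<8x16xf16>
//
// becomes a vector<8xi32> payload holding the base pointer (lanes 0-1),
// base shape W = 32 (lane 2), base shape H = 16 (lane 3), and zero tensor
// offsets (lanes 4-5).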
template <
    typename OpType,
    typename = std::enable_if_t<llvm::is_one_of<
        OpType, xegpu::LoadNdOp, xegpu::StoreNdOp, xegpu::PrefetchNdOp>::value>>
class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
  using OpConversionPattern<OpType>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto mixedOffsets = op.getMixedOffsets();
    int64_t opOffsetsSize = mixedOffsets.size();
    if (opOffsetsSize != 2)
      return rewriter.notifyMatchFailure(op, "Expected 2D offsets.");
    auto loc = op.getLoc();
    auto ctxt = rewriter.getContext();

    auto tdesc = adaptor.getTensorDesc();
    auto tdescTy = op.getTensorDescType();
    if (tdescTy.getRank() != 2)
      return rewriter.notifyMatchFailure(op, "Expected 2D tensor descriptor.");
    auto elemType = tdescTy.getElementType();
    auto elemBitSize = elemType.getIntOrFloatBitWidth();
    if (elemBitSize % 8 != 0)
      return rewriter.notifyMatchFailure(
          op, "Expected element type bit width to be multiple of 8.");

    VectorType payloadI64Ty = VectorType::get(4, rewriter.getI64Type());
    Value payLoadAsI64 =
        vector::BitCastOp::create(rewriter, loc, payloadI64Ty, tdesc);
    Value basePtr = vector::ExtractOp::create(
        rewriter, loc, payLoadAsI64, static_cast<int>(NdTdescOffset::BasePtr));
    Value baseShapeW = vector::ExtractOp::create(
        rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeW));
    Value baseShapeH = vector::ExtractOp::create(
        rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeH));
    // Offsets are provided by the op; convert them to i32.
    Value offsetW =
        getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[1]);
    offsetW = getValueOrCreateCastToIndexLike(rewriter, loc,
                                              rewriter.getI32Type(), offsetW);
    Value offsetH =
        getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[0]);
    offsetH = getValueOrCreateCastToIndexLike(rewriter, loc,
                                              rewriter.getI32Type(), offsetH);
    // Get the address space from the tensor descriptor memory space.
    auto ptrTypeLLVM = LLVM::LLVMPointerType::get(
        ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
    // Convert base pointer (i64) to LLVM pointer type.
    Value basePtrLLVM =
        LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtr);
    // Compute element byte size and surface width in bytes.
    Value elemByteSize = arith::ConstantIntOp::create(
        rewriter, loc, rewriter.getI32Type(), elemBitSize / 8);
    Value surfaceW =
        arith::MulIOp::create(rewriter, loc, baseShapeW, elemByteSize);

    // Get tile sizes and vblocks from the tensor descriptor type.
    auto tileW = tdescTy.getDimSize(1);
    auto tileH = tdescTy.getDimSize(0);
    int32_t vblocks = tdescTy.getArrayLength();
    if constexpr (std::is_same_v<OpType, xegpu::StoreNdOp>) {
      Value src = adaptor.getValue();
      // If the store value is a scalar, get the value from the op instead of
      // the adaptor: the adaptor may have optimized away a single-element
      // vector.
      if (src.getType().isIntOrFloat())
        src = op.getValue();
      VectorType srcVecTy = dyn_cast<VectorType>(src.getType());
      if (!srcVecTy)
        return rewriter.notifyMatchFailure(
            op, "Expected store value to be a vector type.");
      // Get a flat vector of integer type with matching element bit size.
      VectorType newSrcVecTy =
          encodeVectorTypeTo(srcVecTy, rewriter.getIntegerType(elemBitSize));
      if (srcVecTy != newSrcVecTy)
        src = vector::BitCastOp::create(rewriter, loc, newSrcVecTy, src);
      auto storeCacheControl =
          translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
      xevm::BlockStore2dOp::create(
          rewriter, loc, basePtrLLVM, surfaceW, baseShapeH, surfaceW, offsetW,
          offsetH, elemBitSize, tileW, tileH, src,
          xevm::StoreCacheControlAttr::get(ctxt, storeCacheControl));
      rewriter.eraseOp(op);
    } else {
      auto loadCacheControl =
          translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
      if constexpr (std::is_same_v<OpType, xegpu::PrefetchNdOp>) {
        xevm::BlockPrefetch2dOp::create(
            rewriter, loc, basePtrLLVM, surfaceW, baseShapeH, surfaceW, offsetW,
            offsetH, elemBitSize, tileW, tileH, vblocks,
            xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
        rewriter.eraseOp(op);
      } else {
        VectorType dstVecTy = cast<VectorType>(op.getValue().getType());
        const bool vnni = op.getPacked().value_or(false);
        auto transposeValue = op.getTranspose();
        bool transpose =
            transposeValue.has_value() && transposeValue.value()[0] == 1;
        VectorType loadedTy = encodeVectorTypeTo(
            dstVecTy, vnni ? rewriter.getI32Type()
                           : rewriter.getIntegerType(elemBitSize));

        Value resultFlatVec = xevm::BlockLoad2dOp::create(
            rewriter, loc, loadedTy, basePtrLLVM, surfaceW, baseShapeH,
            surfaceW, offsetW, offsetH, elemBitSize, tileW, tileH, vblocks,
            transpose, vnni,
            xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
        resultFlatVec = vector::BitCastOp::create(
            rewriter, loc,
            encodeVectorTypeTo(loadedTy, dstVecTy.getElementType()),
            resultFlatVec);
        rewriter.replaceOp(op, resultFlatVec);
      }
    }
    return success();
  }
};

// Helper that computes: baseAddr + offset * elemByteSize.
static Value addOffset(ConversionPatternRewriter &rewriter, Location loc,
                       Value baseAddr, Value offset, int64_t elemByteSize) {
  Value byteSize = arith::ConstantIntOp::create(
      rewriter, loc, rewriter.getI64Type(), elemByteSize);
  Value byteOffset = arith::MulIOp::create(rewriter, loc, offset, byteSize);
  Value newAddr = arith::AddIOp::create(rewriter, loc, baseAddr, byteOffset);
  return newAddr;
}
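
// For illustration: with a base address of 0x1000, an element offset of 4,
// and 32-bit elements (elemByteSize = 4), addOffset produces
// 0x1000 + 4 * 4 = 0x1010.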
template <typename OpType,
          typename = std::enable_if_t<llvm::is_one_of<
              OpType, xegpu::LoadGatherOp, xegpu::StoreScatterOp>::value>>
class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> {
  using OpConversionPattern<OpType>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value offset = adaptor.getOffsets();
    if (!offset)
      return rewriter.notifyMatchFailure(op, "Expected offset to be provided.");
    auto loc = op.getLoc();
    auto ctxt = rewriter.getContext();
    auto tdescTy = op.getTensorDescType();
    Value basePtrI64;
    // Load result or store value type can be a vector or a scalar.
    Type valOrResTy;
    if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>)
      valOrResTy =
          this->getTypeConverter()->convertType(op.getResult().getType());
    else
      valOrResTy = adaptor.getValue().getType();
    VectorType valOrResVecTy = dyn_cast<VectorType>(valOrResTy);
    bool hasScalarVal = !valOrResVecTy;
    int64_t elemBitWidth =
        hasScalarVal ? valOrResTy.getIntOrFloatBitWidth()
                     : valOrResVecTy.getElementType().getIntOrFloatBitWidth();
    // Element type bit width must be a multiple of 8.
    if (elemBitWidth % 8 != 0)
      return rewriter.notifyMatchFailure(
          op, "Expected element type bit width to be multiple of 8.");
    int64_t elemByteSize = elemBitWidth / 8;
    // Default memory space is global.
    LLVM::LLVMPointerType ptrTypeLLVM = LLVM::LLVMPointerType::get(
        ctxt, getNumericXeVMAddrSpace(xegpu::MemorySpace::Global));
    // If a tensor descriptor is available, use its memory space.
    if (tdescTy)
      ptrTypeLLVM = LLVM::LLVMPointerType::get(
          ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
    // The base pointer can come from the source (load) or the dest (store).
    // If they are memrefs, use their memory space.
    if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>) {
      basePtrI64 = adaptor.getSource();
      if (auto memRefTy = dyn_cast<MemRefType>(op.getSource().getType())) {
        auto addrSpace = memRefTy.getMemorySpaceAsInt();
        if (addrSpace != 0)
          ptrTypeLLVM = LLVM::LLVMPointerType::get(ctxt, addrSpace);
      }
    } else {
      basePtrI64 = adaptor.getDest();
      if (auto memRefTy = dyn_cast<MemRefType>(op.getDest().getType())) {
        auto addrSpace = memRefTy.getMemorySpaceAsInt();
        if (addrSpace != 0)
          ptrTypeLLVM = LLVM::LLVMPointerType::get(ctxt, addrSpace);
      }
    }
    // The base pointer is passed as i32 or i64 by the adaptor; cast to i64 if
    // needed.
    if (basePtrI64.getType() != rewriter.getI64Type()) {
      basePtrI64 = arith::ExtUIOp::create(rewriter, loc, rewriter.getI64Type(),
                                          basePtrI64);
    }
    Value mask = adaptor.getMask();
    if (dyn_cast<VectorType>(offset.getType())) {
      // The offset needs to be a scalar. A single-element vector is converted
      // to a scalar by the type converter.
      return rewriter.notifyMatchFailure(op, "Expected offset to be a scalar.");
    } else {
      // The offset is in number of elements; multiply by the element byte
      // size and add it to the base pointer.
      basePtrI64 = addOffset(rewriter, loc, basePtrI64, offset, elemByteSize);
    }
    // Convert base pointer (i64) to LLVM pointer type.
    Value basePtrLLVM =
        LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI64);

    Value maskForLane;
    VectorType maskVecTy = dyn_cast<VectorType>(mask.getType());
    if (maskVecTy) {
      // The mask needs to be a scalar. A single-element vector is converted
      // to a scalar by the type converter.
      return rewriter.notifyMatchFailure(op, "Expected mask to be a scalar.");
    } else
      maskForLane = mask;
    if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>) {
      scf::IfOp ifOp = scf::IfOp::create(rewriter, loc, {valOrResTy},
                                         maskForLane, true, true);
      // If the mask is true (then clause), load from memory and yield.
      rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
      if (!hasScalarVal)
        valOrResTy = VectorType::get({valOrResVecTy.getNumElements()},
                                     valOrResVecTy.getElementType());
      Value loaded =
          LLVM::LoadOp::create(rewriter, loc, valOrResTy, basePtrLLVM);
      // Set the cache control attribute on the load operation.
      loaded.getDefiningOp()->setAttr(
          "cache_control", xevm::LoadCacheControlAttr::get(
                               ctxt, translateLoadXeGPUCacheHint(
                                         op.getL1Hint(), op.getL3Hint())));
      scf::YieldOp::create(rewriter, loc, ValueRange{loaded});
      rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
      // If the mask is false (else clause), yield a zero constant.
      auto eTy = hasScalarVal ? valOrResTy : valOrResVecTy.getElementType();
      TypedAttr eVal;
      if (eTy.isFloat())
        eVal = FloatAttr::get(eTy, 0.0);
      else
        eVal = IntegerAttr::get(eTy, 0);
      if (hasScalarVal)
        loaded = arith::ConstantOp::create(rewriter, loc, eVal);
      else
        loaded = arith::ConstantOp::create(
            rewriter, loc, DenseElementsAttr::get(valOrResVecTy, eVal));
      scf::YieldOp::create(rewriter, loc, ValueRange{loaded});
      rewriter.replaceOp(op, ifOp.getResult(0));
    } else {
      // If the mask is true, perform the store.
      scf::IfOp ifOp = scf::IfOp::create(rewriter, loc, maskForLane, false);
      auto body = ifOp.getBody();
      rewriter.setInsertionPointToStart(body);
      auto storeOp =
          LLVM::StoreOp::create(rewriter, loc, adaptor.getValue(), basePtrLLVM);
      // Set the cache control attribute on the store operation.
      storeOp.getOperation()->setAttr(
          "cache_control", xevm::StoreCacheControlAttr::get(
                               ctxt, translateStoreXeGPUCacheHint(
                                         op.getL1Hint(), op.getL3Hint())));
      rewriter.eraseOp(op);
    }
    return success();
  }
};
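
// For illustration, a sketch of the masked lowering above for a gather load
// (IR abbreviated and hypothetical):
//
//   %res = scf.if %mask -> (f32) {
//     %v = llvm.load %ptr {cache_control = ...} : !llvm.ptr -> f32
//     scf.yield %v : f32
//   } else {
//     %zero = arith.constant 0.0 : f32
//     scf.yield %zero : f32
//   }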
class PrefetchToXeVMPattern : public OpConversionPattern<xegpu::PrefetchOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::PrefetchOp op, xegpu::PrefetchOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto ctxt = rewriter.getContext();
    auto tdescTy = op.getTensorDescType();
    Value basePtrI64 = adaptor.getSource();
    // The base pointer is passed as i32 or i64 by the adaptor; cast to i64 if
    // needed.
    if (basePtrI64.getType() != rewriter.getI64Type())
      basePtrI64 = arith::ExtUIOp::create(rewriter, loc, rewriter.getI64Type(),
                                          basePtrI64);
    Value offsets = adaptor.getOffsets();
    if (offsets) {
      VectorType offsetsVecTy = dyn_cast<VectorType>(offsets.getType());
      if (offsetsVecTy) {
        // The offset needs to be a scalar.
        return rewriter.notifyMatchFailure(op,
                                           "Expected offsets to be a scalar.");
      } else {
        int64_t elemBitWidth{0};
        int64_t elemByteSize;
        // Element byte size can come from three sources:
        if (tdescTy) {
          // If a tensor descriptor is available, use its element type to
          // determine the element byte size.
          elemBitWidth = tdescTy.getElementType().getIntOrFloatBitWidth();
        } else if (auto memRefTy = dyn_cast<MemRefType>(op.getSourceType())) {
          // If a memref is available, use its element type to determine the
          // element byte size.
          elemBitWidth = memRefTy.getElementType().getIntOrFloatBitWidth();
        } else {
          // Otherwise, use the provided offset byte alignment.
          elemByteSize = *op.getOffsetAlignByte();
        }
        if (elemBitWidth != 0) {
          if (elemBitWidth % 8 != 0)
            return rewriter.notifyMatchFailure(
                op, "Expected element type bit width to be multiple of 8.");
          elemByteSize = elemBitWidth / 8;
        }
        basePtrI64 =
            addOffset(rewriter, loc, basePtrI64, offsets, elemByteSize);
      }
    }
    // Default memory space is global.
    LLVM::LLVMPointerType ptrTypeLLVM = LLVM::LLVMPointerType::get(
        ctxt, getNumericXeVMAddrSpace(xegpu::MemorySpace::Global));
    // If a tensor descriptor is available, use its memory space.
    if (tdescTy)
      ptrTypeLLVM = LLVM::LLVMPointerType::get(
          ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
    // If the source is a memref, use its memory space.
    if (auto memRefTy = dyn_cast<MemRefType>(op.getSource().getType())) {
      auto addrSpace = memRefTy.getMemorySpaceAsInt();
      if (addrSpace != 0)
        ptrTypeLLVM = LLVM::LLVMPointerType::get(ctxt, addrSpace);
    }
    // Convert base pointer (i64) to LLVM pointer type.
    Value ptrLLVM =
        LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI64);
    // Create the prefetch op with the cache control attribute.
    xevm::PrefetchOp::create(
        rewriter, loc, ptrLLVM,
        xevm::LoadCacheControlAttr::get(
            ctxt, translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint())));
    rewriter.eraseOp(op);
    return success();
  }
};

class FenceToXeVMPattern : public OpConversionPattern<xegpu::FenceOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::FenceOp op, xegpu::FenceOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    xevm::MemScope memScope{xevm::MemScope::WORKGROUP};
    switch (op.getFenceScope()) {
    case xegpu::FenceScope::Workgroup:
      memScope = xevm::MemScope::WORKGROUP;
      break;
    case xegpu::FenceScope::GPU:
      memScope = xevm::MemScope::DEVICE;
      break;
    }
    xevm::AddrSpace addrSpace{xevm::AddrSpace::GLOBAL};
    switch (op.getMemoryKind()) {
    case xegpu::MemorySpace::Global:
      addrSpace = xevm::AddrSpace::GLOBAL;
      break;
    case xegpu::MemorySpace::SLM:
      addrSpace = xevm::AddrSpace::SHARED;
      break;
    }
    xevm::MemfenceOp::create(rewriter, loc, memScope, addrSpace);
    rewriter.eraseOp(op);
    return success();
  }
};

class DpasToXeVMPattern : public OpConversionPattern<xegpu::DpasOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::DpasOp op, xegpu::DpasOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto ctxt = rewriter.getContext();
    auto aTy = cast<VectorType>(op.getLhs().getType());
    auto bTy = cast<VectorType>(op.getRhs().getType());
    auto resultType = cast<VectorType>(op.getResultType());

    auto encodePrecision = [&](Type type) -> xevm::ElemType {
      if (type == rewriter.getBF16Type())
        return xevm::ElemType::BF16;
      else if (type == rewriter.getF16Type())
        return xevm::ElemType::F16;
      else if (type == rewriter.getTF32Type())
        return xevm::ElemType::TF32;
      else if (type.isInteger(8)) {
        if (type.isUnsignedInteger())
          return xevm::ElemType::U8;
        return xevm::ElemType::S8;
      } else if (type == rewriter.getF32Type())
        return xevm::ElemType::F32;
      else if (type.isInteger(32))
        return xevm::ElemType::S32;
      llvm_unreachable("add more support for ElemType");
    };
    xevm::ElemType precATy = encodePrecision(aTy.getElementType());
    xevm::ElemType precBTy = encodePrecision(bTy.getElementType());
    Value c = op.getAcc();
    if (!c) {
      auto elementTy = resultType.getElementType();
      Attribute initValueAttr;
      if (isa<FloatType>(elementTy))
        initValueAttr = FloatAttr::get(elementTy, 0.0);
      else
        initValueAttr = IntegerAttr::get(elementTy, 0);
      c = arith::ConstantOp::create(
          rewriter, loc, DenseElementsAttr::get(resultType, initValueAttr));
    }

    Value aVec = op.getLhs();
    Value bVec = op.getRhs();
    auto cvecty = cast<VectorType>(c.getType());
    xevm::ElemType precCTy = encodePrecision(cvecty.getElementType());
    xevm::ElemType precDTy = encodePrecision(resultType.getElementType());
    VectorType cNty =
        VectorType::get(cvecty.getNumElements(), cvecty.getElementType());
    if (cvecty != cNty)
      c = vector::ShapeCastOp::create(rewriter, loc, cNty, c);
    Value dpasRes = xevm::MMAOp::create(
        rewriter, loc, cNty, aVec, bVec, c,
        xevm::MMAShapeAttr::get(ctxt, cvecty.getNumElements(), executionSize,
                                systolicDepth *
                                    getNumOperandsPerDword(precATy)),
        xevm::MMATypesAttr::get(ctxt, precDTy, precATy, precBTy, precCTy));
    if (cvecty != cNty)
      dpasRes = vector::ShapeCastOp::create(rewriter, loc, resultType, dpasRes);
    rewriter.replaceOp(op, dpasRes);
    return success();
  }

private:
  static unsigned getNumOperandsPerDword(xevm::ElemType pTy) {
    switch (pTy) {
    case xevm::ElemType::TF32:
      return 1;
    case xevm::ElemType::BF16:
    case xevm::ElemType::F16:
      return 2;
    case xevm::ElemType::U8:
    case xevm::ElemType::S8:
      return 4;
    default:
      llvm_unreachable("unsupported xevm::ElemType");
    }
  }
};
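
// For illustration: for f16 A/B operands, getNumOperandsPerDword returns 2,
// so the MMA shape built above combines the accumulator element count,
// executionSize = 16, and a K dimension of systolicDepth * 2 = 16.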
xevm::ElemType"); } } }; static std::optional matchSimpleAtomicOp(arith::AtomicRMWKind arithKind) { switch (arithKind) { case arith::AtomicRMWKind::addf: return LLVM::AtomicBinOp::fadd; case arith::AtomicRMWKind::addi: return LLVM::AtomicBinOp::add; case arith::AtomicRMWKind::assign: return LLVM::AtomicBinOp::xchg; case arith::AtomicRMWKind::maximumf: return LLVM::AtomicBinOp::fmax; case arith::AtomicRMWKind::maxs: return LLVM::AtomicBinOp::max; case arith::AtomicRMWKind::maxu: return LLVM::AtomicBinOp::umax; case arith::AtomicRMWKind::minimumf: return LLVM::AtomicBinOp::fmin; case arith::AtomicRMWKind::mins: return LLVM::AtomicBinOp::min; case arith::AtomicRMWKind::minu: return LLVM::AtomicBinOp::umin; case arith::AtomicRMWKind::ori: return LLVM::AtomicBinOp::_or; case arith::AtomicRMWKind::andi: return LLVM::AtomicBinOp::_and; default: return std::nullopt; } } class AtomicRMWToXeVMPattern : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(xegpu::AtomicRMWOp op, xegpu::AtomicRMWOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); auto ctxt = rewriter.getContext(); auto tdesc = op.getTensorDesc().getType(); auto ptrTypeLLVM = LLVM::LLVMPointerType::get( ctxt, getNumericXeVMAddrSpace(tdesc.getMemorySpace())); Value basePtrI64 = arith::IndexCastOp::create( rewriter, loc, rewriter.getI64Type(), adaptor.getTensorDesc()); Value basePtrLLVM = LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI64); VectorType srcOrDstVecTy = cast(op.getValue().getType()); VectorType srcOrDstFlatVecTy = VectorType::get( srcOrDstVecTy.getNumElements(), srcOrDstVecTy.getElementType()); Value srcFlatVec = vector::ShapeCastOp::create( rewriter, loc, srcOrDstFlatVecTy, op.getValue()); auto atomicKind = matchSimpleAtomicOp(op.getKind()); assert(atomicKind.has_value()); Value resVec = srcFlatVec; for (int i = 0; i < srcOrDstVecTy.getNumElements(); i++) { auto val = vector::ExtractOp::create(rewriter, loc, resVec, i); Value idx = LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(), rewriter.getIndexAttr(i)); Value currPtr = LLVM::GEPOp::create(rewriter, loc, ptrTypeLLVM, srcOrDstVecTy.getElementType(), basePtrLLVM, idx); Value newVal = LLVM::AtomicRMWOp::create(rewriter, loc, atomicKind.value(), currPtr, val, LLVM::AtomicOrdering::seq_cst); resVec = vector::InsertOp::create(rewriter, loc, newVal, resVec, i); } rewriter.replaceOp(op, resVec); return success(); } }; //===----------------------------------------------------------------------===// // Pass Definition //===----------------------------------------------------------------------===// struct ConvertXeGPUToXeVMPass : public impl::ConvertXeGPUToXeVMPassBase { using Base::Base; void runOnOperation() override { LLVMTypeConverter typeConverter(&getContext()); typeConverter.addConversion([&](VectorType type) -> Type { unsigned rank = type.getRank(); auto elemType = type.getElementType(); // If the element type is index, convert it to i64. if (llvm::isa(elemType)) elemType = IntegerType::get(&getContext(), 64); // If the vector is a scalar or has a single element, return the element if (rank < 1 || type.getNumElements() == 1) return elemType; // Otherwise, convert the vector to a flat vector type. 
//===----------------------------------------------------------------------===//
// Pass Definition
//===----------------------------------------------------------------------===//

struct ConvertXeGPUToXeVMPass
    : public impl::ConvertXeGPUToXeVMPassBase<ConvertXeGPUToXeVMPass> {
  using Base::Base;

  void runOnOperation() override {
    LLVMTypeConverter typeConverter(&getContext());
    typeConverter.addConversion([&](VectorType type) -> Type {
      unsigned rank = type.getRank();
      auto elemType = type.getElementType();
      // If the element type is index, convert it to i64.
      if (llvm::isa<IndexType>(elemType))
        elemType = IntegerType::get(&getContext(), 64);
      // If the vector is a scalar or has a single element, return the
      // element type.
      if (rank < 1 || type.getNumElements() == 1)
        return elemType;
      // Otherwise, convert the vector to a flat vector type.
      int64_t flatSize = llvm::product_of(type.getShape());
      return VectorType::get(flatSize, elemType);
    });
    typeConverter.addConversion([&](xegpu::TensorDescType type) -> Type {
      if (type.isScattered())
        return IntegerType::get(&getContext(), 64);
      auto i32Type = IntegerType::get(&getContext(), 32);
      return VectorType::get(8, i32Type);
    });
    typeConverter.addConversion([&](MemRefType type) -> Type {
      // Convert MemRefType to i64 type.
      return IntegerType::get(&getContext(), 64);
    });

    // The LLVM type converter inserts unrealized casts for the following
    // cases; add materialization casts to handle them.

    // Materialization to convert memref to i64.
    auto memrefMaterializationCast = [](OpBuilder &builder, Type type,
                                        ValueRange inputs,
                                        Location loc) -> Value {
      if (inputs.size() != 1)
        return {};
      auto input = inputs.front();
      if (auto memrefTy = dyn_cast<MemRefType>(input.getType())) {
        Value addr =
            memref::ExtractAlignedPointerAsIndexOp::create(builder, loc, input);
        return arith::IndexCastUIOp::create(builder, loc, type, addr)
            .getResult();
      }
      return {};
    };

    // Materialization to convert ui64 to i64.
    auto ui64MaterializationCast = [](OpBuilder &builder, Type type,
                                      ValueRange inputs,
                                      Location loc) -> Value {
      if (inputs.size() != 1)
        return {};
      auto input = inputs.front();
      if (input.getType() == builder.getIntegerType(64, false)) {
        Value cast =
            index::CastUOp::create(builder, loc, builder.getIndexType(), input)
                .getResult();
        return arith::IndexCastUIOp::create(builder, loc, type, cast)
            .getResult();
      }
      return {};
    };

    // Materialization to convert ui32 to i32.
    auto ui32MaterializationCast = [](OpBuilder &builder, Type type,
                                      ValueRange inputs,
                                      Location loc) -> Value {
      if (inputs.size() != 1)
        return {};
      auto input = inputs.front();
      if (input.getType() == builder.getIntegerType(32, false)) {
        Value cast =
            index::CastUOp::create(builder, loc, builder.getIndexType(), input)
                .getResult();
        return arith::IndexCastUIOp::create(builder, loc, type, cast)
            .getResult();
      }
      return {};
    };

    // Materialization to convert:
    //   - a single-element 1D vector to a scalar
    //   - a vector to a vector of the same rank via bitcast
    //   - a vector to a vector of different rank but the same element type
    //     via shape cast
    auto vectorMaterializationCast = [](OpBuilder &builder, Type type,
                                        ValueRange inputs,
                                        Location loc) -> Value {
      if (inputs.size() != 1)
        return {};
      auto input = inputs.front();
      if (auto vecTy = dyn_cast<VectorType>(input.getType())) {
        if (vecTy.getNumElements() == 1) {
          // If the vector has a single element, extract that element.
          Value cast =
              vector::ExtractOp::create(builder, loc, input, 0).getResult();
          if (vecTy.getElementType() == builder.getIndexType())
            cast = arith::IndexCastUIOp::create(builder, loc, type, cast)
                       .getResult();
          return cast;
        } else if (auto targetVecTy = dyn_cast<VectorType>(type)) {
          if (targetVecTy.getRank() == vecTy.getRank()) {
            // If the target type is a vector of the same rank, bitcast to
            // the target type.
            return vector::BitCastOp::create(builder, loc, targetVecTy, input)
                .getResult();
          } else if (targetVecTy.getElementType() == vecTy.getElementType()) {
            // If the target type is a vector of different rank but the same
            // element type, shape-cast to the target type.
            return vector::ShapeCastOp::create(builder, loc, targetVecTy, input)
                .getResult();
          }
        }
      }
      return {};
    };
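
    // For illustration: a memref-typed value reaching a converted op is
    // materialized by memrefMaterializationCast above as
    //   %idx = memref.extract_aligned_pointer_as_index %src
    //   %i64 = arith.index_castui %idx : index to i64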
    // If the result type of the original op is a single-element vector and
    // the lowered type is a scalar, this materialization cast creates a
    // single-element vector by broadcasting the scalar value.
    auto singleElementVectorMaterializationCast =
        [](OpBuilder &builder, Type type, ValueRange inputs,
           Location loc) -> Value {
      if (inputs.size() != 1)
        return {};
      auto input = inputs.front();
      if (input.getType().isIntOrIndexOrFloat()) {
        // If the input is a scalar and the target type is a single-element
        // vector, broadcast the scalar into it.
        if (auto vecTy = dyn_cast<VectorType>(type)) {
          if (vecTy.getNumElements() == 1)
            return vector::BroadcastOp::create(builder, loc, vecTy, input)
                .getResult();
        }
      }
      return {};
    };
    typeConverter.addSourceMaterialization(
        singleElementVectorMaterializationCast);
    typeConverter.addTargetMaterialization(memrefMaterializationCast);
    typeConverter.addTargetMaterialization(ui32MaterializationCast);
    typeConverter.addTargetMaterialization(ui64MaterializationCast);
    typeConverter.addTargetMaterialization(vectorMaterializationCast);

    ConversionTarget target(getContext());
    target.addLegalDialect<xevm::XeVMDialect, LLVM::LLVMDialect,
                           vector::VectorDialect, arith::ArithDialect,
                           memref::MemRefDialect, gpu::GPUDialect,
                           index::IndexDialect>();
    target.addIllegalDialect<xegpu::XeGPUDialect>();

    RewritePatternSet patterns(&getContext());
    populateXeGPUToXeVMConversionPatterns(typeConverter, patterns);
    scf::populateSCFStructuralTypeConversionsAndLegality(typeConverter,
                                                         patterns, target);
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace

//===----------------------------------------------------------------------===//
// Pattern Population
//===----------------------------------------------------------------------===//
void mlir::populateXeGPUToXeVMConversionPatterns(
    const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {
  patterns.add<CreateNdDescToXeVMPattern,
               LoadStorePrefetchNdToXeVMPattern<xegpu::LoadNdOp>,
               LoadStorePrefetchNdToXeVMPattern<xegpu::StoreNdOp>,
               LoadStorePrefetchNdToXeVMPattern<xegpu::PrefetchNdOp>>(
      typeConverter, patterns.getContext());
  patterns.add<AtomicRMWToXeVMPattern, PrefetchToXeVMPattern,
               LoadStoreToXeVMPattern<xegpu::LoadGatherOp>,
               LoadStoreToXeVMPattern<xegpu::StoreScatterOp>>(
      typeConverter, patterns.getContext());
  patterns.add<FenceToXeVMPattern, DpasToXeVMPattern>(typeConverter,
                                                      patterns.getContext());
}