//===-- XeVMToLLVM.cpp - XeVM to LLVM dialect conversion --------*- C++ -*-===// // // This file is licensed under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h" #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h" #include "mlir/Conversion/LLVMCommon/Pattern.h" #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/XeVMDialect.h" #include "mlir/Pass/Pass.h" #include "mlir/Support/LLVM.h" #include "llvm/Support/FormatVariadic.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Types.h" #include "llvm/ADT/TypeSwitch.h" namespace mlir { #define GEN_PASS_DEF_CONVERTXEVMTOLLVMPASS #include "mlir/Conversion/Passes.h.inc" } // namespace mlir using namespace mlir; using namespace xevm; namespace { struct LLVMFuncAttributeOptions { bool isConvergent = false; bool isNoUnwind = false; bool isWillReturn = false; LLVM::MemoryEffectsAttr memEffectsAttr{}; }; static constexpr LLVMFuncAttributeOptions noUnwindAttrs = { false, true, false, {}}; static constexpr LLVMFuncAttributeOptions noUnwindWillReturnAttrs = { false, true, true, {}}; static constexpr LLVMFuncAttributeOptions convergentNoUnwindWillReturnAttrs = { true, true, true, {}}; std::string getTypeMangling(Type ty, bool isUnsigned = false) { return TypeSwitch(ty) .Case([isUnsigned](VectorType ty) -> std::string { return "Dv" + std::to_string(ty.getNumElements()) + "_" + getTypeMangling(ty.getElementType(), isUnsigned); }) .Case([](Float16Type) -> std::string { return "Dh"; }) .Case([](Float32Type) -> std::string { return "f"; }) .Case([](Float64Type) -> std::string { return "d"; }) .Case([isUnsigned](IntegerType ty) -> std::string { switch (ty.getWidth()) { case 8: return isUnsigned ? "h" : "c"; case 16: return isUnsigned ? "t" : "s"; case 32: return isUnsigned ? "j" : "i"; case 64: return isUnsigned ? "m" : "l"; default: llvm_unreachable("unhandled integer type"); } }) .Default([](Type) -> std::string { llvm_unreachable("unhandled type for mangling"); }); } std::string mangle(StringRef baseName, ArrayRef types, ArrayRef isUnsigned = {}) { assert((isUnsigned.empty() || isUnsigned.size() == types.size()) && "Signedness info doesn't match"); std::string s; llvm::raw_string_ostream os(s); llvm::SmallDenseMap substitutions; os << "_Z" << baseName.size() << baseName; for (auto [idx, type] : llvm::enumerate(types)) { auto it = substitutions.find(type); if (it != substitutions.end()) { os << "S"; // First substitution is `S_`, second is `S0_`, and so on. if (unsigned firstIdx = it->getSecond(); firstIdx > 0) os << firstIdx - 1; os << "_"; } else { if (!type.isIntOrFloat()) substitutions[type] = substitutions.size(); os << getTypeMangling(type, isUnsigned.empty() ? false : isUnsigned[idx]); } } return os.str(); } template int32_t getL1CacheControl(OpType op) { int32_t control = 0; if constexpr (isLoad) { switch (*op.getCacheControl()) { case LoadCacheControl::L1UC_L2UC_L3UC: case LoadCacheControl::L1UC_L2UC_L3C: case LoadCacheControl::L1UC_L2C_L3UC: case LoadCacheControl::L1UC_L2C_L3C: control = 1; break; case LoadCacheControl::L1C_L2UC_L3UC: case LoadCacheControl::L1C_L2UC_L3C: case LoadCacheControl::L1C_L2C_L3UC: case LoadCacheControl::L1C_L2C_L3C: control = 2; break; case LoadCacheControl::L1S_L2UC_L3UC: case LoadCacheControl::L1S_L2UC_L3C: case LoadCacheControl::L1S_L2C_L3UC: case LoadCacheControl::L1S_L2C_L3C: control = 3; break; case LoadCacheControl::INVALIDATE_READ: control = 4; break; } } else { switch (*op.getCacheControl()) { case StoreCacheControl::L1UC_L2UC_L3UC: case StoreCacheControl::L1UC_L2UC_L3WB: case StoreCacheControl::L1UC_L2WB_L3UC: case StoreCacheControl::L1UC_L2WB_L3WB: control = 1; break; case StoreCacheControl::L1WT_L2UC_L3UC: case StoreCacheControl::L1WT_L2UC_L3WB: case StoreCacheControl::L1WT_L2WB_L3UC: case StoreCacheControl::L1WT_L2WB_L3WB: control = 2; break; case StoreCacheControl::L1S_L2UC_L3UC: case StoreCacheControl::L1S_L2UC_L3WB: case StoreCacheControl::L1S_L2WB_L3UC: case StoreCacheControl::L1S_L2WB_L3WB: control = 3; break; case StoreCacheControl::L1WB_L2UC_L3UC: case StoreCacheControl::L1WB_L2WB_L3UC: case StoreCacheControl::L1WB_L2UC_L3WB: control = 4; break; } } return control; } template int32_t getL3CacheControl(OpType op) { int32_t control = 0; if constexpr (isLoad) { switch (*op.getCacheControl()) { case LoadCacheControl::L1UC_L2UC_L3UC: case LoadCacheControl::L1UC_L2C_L3UC: case LoadCacheControl::L1C_L2UC_L3UC: case LoadCacheControl::L1C_L2C_L3UC: case LoadCacheControl::L1S_L2UC_L3UC: case LoadCacheControl::L1S_L2C_L3UC: control = 1; break; case LoadCacheControl::L1UC_L2UC_L3C: case LoadCacheControl::L1UC_L2C_L3C: case LoadCacheControl::L1C_L2UC_L3C: case LoadCacheControl::L1C_L2C_L3C: case LoadCacheControl::L1S_L2UC_L3C: case LoadCacheControl::L1S_L2C_L3C: control = 2; break; case LoadCacheControl::INVALIDATE_READ: control = 4; break; } } else { switch (*op.getCacheControl()) { case StoreCacheControl::L1UC_L2UC_L3UC: case StoreCacheControl::L1UC_L2WB_L3UC: case StoreCacheControl::L1WT_L2UC_L3UC: case StoreCacheControl::L1WT_L2WB_L3UC: case StoreCacheControl::L1S_L2UC_L3UC: case StoreCacheControl::L1S_L2WB_L3UC: case StoreCacheControl::L1WB_L2UC_L3UC: case StoreCacheControl::L1WB_L2WB_L3UC: control = 1; break; case StoreCacheControl::L1UC_L2UC_L3WB: case StoreCacheControl::L1UC_L2WB_L3WB: case StoreCacheControl::L1WT_L2UC_L3WB: case StoreCacheControl::L1WT_L2WB_L3WB: case StoreCacheControl::L1S_L2UC_L3WB: case StoreCacheControl::L1S_L2WB_L3WB: case StoreCacheControl::L1WB_L2UC_L3WB: control = 2; break; } } return control; } template static std::optional getCacheControlMetadata(ConversionPatternRewriter &rewriter, OpType op) { if (!op.getCacheControl()) return {}; constexpr int32_t decorationCacheControlArity{4}; constexpr int32_t loadCacheControlKey{6442}; constexpr int32_t storeCacheControlKey{6443}; const int32_t controlKey{isLoad ? loadCacheControlKey : storeCacheControlKey}; SmallVector decorationsL1{ controlKey, 0, getL1CacheControl(op), 0}; SmallVector decorationsL3{ controlKey, 1, getL3CacheControl(op), 0}; auto arrayAttrL1 = rewriter.getI32ArrayAttr(decorationsL1); auto arrayAttrL3 = rewriter.getI32ArrayAttr(decorationsL3); SmallVector combinedAttrs = {arrayAttrL1, arrayAttrL3}; return rewriter.getArrayAttr(combinedAttrs); } static LLVM::CallOp createDeviceFunctionCall( ConversionPatternRewriter &rewriter, StringRef funcName, Type retType, ArrayRef argTypes, ArrayRef args, mlir::ArrayRef> paramAttrs, LLVMFuncAttributeOptions funcAttributeOptions, Operation *op) { auto moduleOp = op->getParentWithTrait(); assert(moduleOp && "Expecting module"); Location loc = op->getLoc(); auto funcOpRes = LLVM::lookupOrCreateFn(rewriter, moduleOp, funcName, argTypes, retType); assert(!failed(funcOpRes)); LLVM::LLVMFuncOp funcOp = funcOpRes.value(); funcOp.setCConv(LLVM::cconv::CConv::SPIR_FUNC); funcOp.setConvergent(funcAttributeOptions.isConvergent); funcOp.setNoUnwind(funcAttributeOptions.isNoUnwind); funcOp.setWillReturn(funcAttributeOptions.isWillReturn); if (funcAttributeOptions.memEffectsAttr) funcOp.setMemoryEffectsAttr(funcAttributeOptions.memEffectsAttr); for (auto [idx, attrName] : paramAttrs) funcOp.setArgAttr(idx, attrName, rewriter.getUnitAttr()); auto callOp = LLVM::CallOp::create(rewriter, loc, funcOp, args); callOp->setAttrs(funcOp->getAttrs()); return callOp; } class MMAToOCLPattern : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(xevm::MMAOp op, xevm::MMAOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (!op.getC()) { return rewriter.notifyMatchFailure(op, "OCL requires C operand"); } auto precisionA = op.getTypes().getA(); auto precisionB = op.getTypes().getB(); auto precisionC = op.getTypes().getC(); auto precisionD = op.getTypes().getD(); if (precisionC != precisionD) { return rewriter.notifyMatchFailure(op, "type of C and D need to match"); } if (precisionC != xevm::ElemType::S32 && precisionC != xevm::ElemType::F32 && precisionC != xevm::ElemType::F16 && precisionC != xevm::ElemType::BF16) { return rewriter.notifyMatchFailure( op, "type of C and D must be S32, F32, F16 or BF16"); } if (precisionA == xevm::ElemType::S32 || precisionA == xevm::ElemType::F32) { return rewriter.notifyMatchFailure(op, "type of A cannot be S32 or F32"); } if (precisionB == xevm::ElemType::S32 || precisionB == xevm::ElemType::F32) { return rewriter.notifyMatchFailure(op, "type of B cannot be S32 or F32"); } constexpr uint32_t bitWidthPackedA{16}; constexpr uint32_t bitWidthPackedB{32}; auto loc = op.getLoc(); auto castIfNeeded = [&](Value val, Type packedType) -> Value { VectorType origTy = cast(val.getType()); const uint32_t vecBitSize = origTy.getNumElements() * origTy.getElementType().getIntOrFloatBitWidth(); VectorType newTy = VectorType::get( vecBitSize / packedType.getIntOrFloatBitWidth(), packedType); if (origTy != newTy) val = LLVM::BitcastOp::create(rewriter, loc, newTy, val); return val; }; Value a = op.getA(); Type packedAType = (op.getTypes().getA() == xevm::ElemType::TF32) ? cast(rewriter.getF32Type()) : rewriter.getIntegerType(bitWidthPackedA); a = castIfNeeded(a, packedAType); Value b = op.getB(); Type packedBType = (op.getTypes().getB() == xevm::ElemType::TF32) ? cast(rewriter.getF32Type()) : rewriter.getIntegerType(bitWidthPackedB); b = castIfNeeded(b, packedBType); Value c = op.getC(); VectorType cOrigTy = cast(c.getType()); VectorType resOrigTy = cast(op->getResultTypes()[0]); assert(cOrigTy == resOrigTy && "Accumulator and result type mismatch"); // OCL builtins encode bfloat16 as int16 VectorType cTy = cOrigTy.getElementType().isBF16() ? VectorType::get(cOrigTy.getShape(), rewriter.getIntegerType(16)) : cOrigTy; VectorType resTy = cTy; if (cOrigTy != cTy) c = LLVM::BitcastOp::create(rewriter, loc, cTy, c); constexpr int32_t systolicDepth{8}; std::string fnName = llvm::formatv("intel_sub_group_{0}_{1}_matrix_mad_k{2}", stringifyElemType(op.getTypes().getA()).str(), stringifyElemType(op.getTypes().getB()).str(), systolicDepth * getNumOperandsPerDword(op.getTypes().getA())) .str(); SmallVector argTypes{a.getType(), b.getType(), cTy}; fnName = mangle(fnName, argTypes); SmallVector args{a, b, c}; auto memAttr = rewriter.getAttr( /*other=*/LLVM::ModRefInfo::NoModRef, /*argMem=*/LLVM::ModRefInfo::NoModRef, /*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef); auto funcAttrs = convergentNoUnwindWillReturnAttrs; funcAttrs.memEffectsAttr = memAttr; Value result = createDeviceFunctionCall(rewriter, fnName, resTy, argTypes, args, {}, funcAttrs, op.getOperation()) ->getResult(0); if (resOrigTy != resTy) result = LLVM::BitcastOp::create(rewriter, loc, resOrigTy, result); rewriter.replaceOp(op, result); return success(); } private: static unsigned getNumOperandsPerDword(xevm::ElemType pTy) { switch (pTy) { case xevm::ElemType::TF32: return 1; case xevm::ElemType::BF16: case xevm::ElemType::F16: return 2; case xevm::ElemType::U8: case xevm::ElemType::S8: return 4; default: llvm_unreachable("unsupported xevm::ElemType"); } } }; class PrefetchToOCLPattern : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(PrefetchOp op, PrefetchOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); const std::string fnName{"_Z8prefetchPU3AS1Kcm"}; Value one = LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(), 1); SmallVector args{op.getPtr(), one}; SmallVector argTypes; for (auto arg : args) argTypes.push_back(arg.getType()); auto funcAttr = noUnwindAttrs; auto memAttr = rewriter.getAttr( /*other=*/LLVM::ModRefInfo::NoModRef, /*argMem=*/LLVM::ModRefInfo::Ref, /*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef); funcAttr.memEffectsAttr = memAttr; LLVM::CallOp call = createDeviceFunctionCall( rewriter, fnName, LLVM::LLVMVoidType::get(rewriter.getContext()), argTypes, args, {}, funcAttr, op.getOperation()); if (std::optional optCacheControls = getCacheControlMetadata(rewriter, op)) call->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls); rewriter.eraseOp(op); return success(); } }; class MemfenceToOCLPattern : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(MemfenceOp op, MemfenceOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); const std::string fnName{"atomic_work_item_fence"}; int memScope, addrSpace; switch (op.getAddrspace()) { case xevm::AddrSpace::SHARED: addrSpace = 1; // CLK_LOCAL_MEM_FENCE break; case xevm::AddrSpace::GLOBAL: addrSpace = 2; // CLK_GLOBAL_MEM_FENCE break; default: // GENERIC is not supported in OpenCL return rewriter.notifyMatchFailure( op, "Fence only supports global and shared address spaces."); } switch (op.getScope()) { case xevm::MemScope::WORKGROUP: memScope = 1; break; case xevm::MemScope::DEVICE: memScope = 2; break; default: // CLUSTER and SYSTEM are not supported in OpenCL return rewriter.notifyMatchFailure( op, "Fence only supports workgroup and device memory scopes."); } Type i32Type = rewriter.getI32Type(); Value acqRel = LLVM::ConstantOp::create(rewriter, loc, i32Type, 4); Value memScopeConst = LLVM::ConstantOp::create(rewriter, loc, i32Type, memScope); Value addrSpaceConst = LLVM::ConstantOp::create(rewriter, loc, i32Type, addrSpace); SmallVector args{addrSpaceConst, acqRel, memScopeConst}; SmallVector argTypes{3, i32Type}; createDeviceFunctionCall(rewriter, mangle(fnName, argTypes), LLVM::LLVMVoidType::get(rewriter.getContext()), argTypes, args, {}, noUnwindAttrs, op.getOperation()); rewriter.eraseOp(op); return success(); } }; template class LoadStorePrefetchToOCLPattern : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(OpType op, typename OpType::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { constexpr bool isLoad = std::is_same_v; constexpr bool isPrefetch = std::is_same_v; auto loc = op.getLoc(); VectorType vecType; bool packReg = false; bool transpose = false; if constexpr (isLoad) { vecType = op.getRes().getType(); packReg = op.getPackRegister(); transpose = op.getTranspose(); } else if constexpr (!isPrefetch) { vecType = op.getStoredVal().getType(); } auto i32Type = rewriter.getI32Type(); Value byteCoord = LLVM::UndefOp::create(rewriter, loc, VectorType::get(2, i32Type)); Value zero = LLVM::ConstantOp::create(rewriter, loc, i32Type, 0); Value one = LLVM::ConstantOp::create(rewriter, loc, i32Type, 1); byteCoord = LLVM::InsertElementOp::create( rewriter, loc, VectorType::get(2, i32Type), byteCoord, op.getX(), zero); byteCoord = LLVM::InsertElementOp::create( rewriter, loc, VectorType::get(2, i32Type), byteCoord, op.getY(), one); SmallVector args{op.getPtr(), op.getBaseWidth(), op.getBaseHeight(), op.getBasePitch(), byteCoord}; SmallVector retTypes; Value spvLoadDstPtr; std::string funcName{"intel_sub_group_2d_block_"}; std::string bitWidthId; LLVMFuncAttributeOptions funcAttr{noUnwindWillReturnAttrs}; SmallVector, 4> paramAttrs; if constexpr (isPrefetch) { // Prefetch funcName += "prefetch"; paramAttrs = {std::make_pair(0, LLVM::LLVMDialect::getNonNullAttrName())}; auto memAttr = rewriter.getAttr( /*other=*/LLVM::ModRefInfo::NoModRef, /*argMem=*/LLVM::ModRefInfo::Ref, /*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef); funcAttr = noUnwindAttrs; funcAttr.memEffectsAttr = memAttr; } else { auto vecElemType = vecType.getElementType(); auto vecElemBitWidth = vecElemType.getIntOrFloatBitWidth(); Value numElems = LLVM::ConstantOp::create(rewriter, loc, i32Type, vecType.getNumElements()); auto dstOrSrcPtr = LLVM::AllocaOp::create( rewriter, loc, LLVM::LLVMPointerType::get(rewriter.getContext()), vecElemType, numElems); args.push_back(dstOrSrcPtr); if constexpr (isLoad) { // Load funcName += "read"; bitWidthId = getTypeMangling(vecElemType, /*isUnsigned=*/true); if (packReg) funcName += "_transform"; else if (transpose) funcName += "_transpose"; spvLoadDstPtr = dstOrSrcPtr; retTypes.push_back(vecType); paramAttrs = { std::make_pair(0, LLVM::LLVMDialect::getNonNullAttrName()), std::make_pair(0, LLVM::LLVMDialect::getReadonlyAttrName()), std::make_pair(5, LLVM::LLVMDialect::getNonNullAttrName()), std::make_pair(5, LLVM::LLVMDialect::getWriteOnlyAttrName()), }; } else { // Store funcName += "write"; bitWidthId = (vecElemBitWidth == 32) ? "j" : ((vecElemBitWidth == 16) ? "t" : "h"); LLVM::StoreOp::create(rewriter, loc, op.getStoredVal(), dstOrSrcPtr); paramAttrs = { std::make_pair(0, LLVM::LLVMDialect::getNonNullAttrName()), std::make_pair(0, LLVM::LLVMDialect::getWriteOnlyAttrName()), std::make_pair(5, LLVM::LLVMDialect::getNonNullAttrName()), std::make_pair(5, LLVM::LLVMDialect::getReadonlyAttrName()), }; } } funcName = llvm::formatv("{0}_{1}b_{2}r{3}x{4}c", funcName, op.getElemSizeInBits(), op.getTileHeight(), op.getTileWidth(), op.getVBlocks()) .str(); std::string prefetchCode(""); if (!isPrefetch) prefetchCode += "P"; funcName = llvm::formatv("_Z{0}{1}PU3AS1viiiDv2_i{2}{3}", funcName.size(), funcName, prefetchCode, bitWidthId) .str(); SmallVector argTypes; for (auto arg : args) { argTypes.push_back(arg.getType()); } LLVM::CallOp call = createDeviceFunctionCall( rewriter, funcName, LLVM::LLVMVoidType::get(rewriter.getContext()), argTypes, args, paramAttrs, funcAttr, op.getOperation()); if (std::optional optCacheControls = getCacheControlMetadata < isLoad || isPrefetch > (rewriter, op)) { call->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls); } if constexpr (isLoad) rewriter.replaceOp( op, LLVM::LoadOp::create(rewriter, loc, vecType, spvLoadDstPtr)); else rewriter.eraseOp(op); return success(); } }; //===----------------------------------------------------------------------===// // Pass Definition //===----------------------------------------------------------------------===// struct ConvertXeVMToLLVMPass : public impl::ConvertXeVMToLLVMPassBase { using Base::Base; void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); } void runOnOperation() override { ConversionTarget target(getContext()); target.addLegalDialect(); target.addIllegalDialect(); RewritePatternSet patterns(&getContext()); populateXeVMToLLVMConversionPatterns(patterns); if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) signalPassFailure(); } }; } // namespace //===----------------------------------------------------------------------===// // ConvertToLLVMPatternInterface implementation //===----------------------------------------------------------------------===// namespace { /// Implement the interface to convert XeVM to LLVM. struct XeVMToLLVMDialectInterface : public ConvertToLLVMPatternInterface { using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface; void loadDependentDialects(MLIRContext *context) const final { context->loadDialect(); } /// Hook for derived dialect interface to provide conversion patterns /// and mark dialect legal for the conversion target. void populateConvertToLLVMConversionPatterns( ConversionTarget &target, LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) const final { populateXeVMToLLVMConversionPatterns(patterns); } }; } // namespace //===----------------------------------------------------------------------===// // Pattern Population //===----------------------------------------------------------------------===// void ::mlir::populateXeVMToLLVMConversionPatterns(RewritePatternSet &patterns) { patterns.add, LoadStorePrefetchToOCLPattern, LoadStorePrefetchToOCLPattern, MMAToOCLPattern, MemfenceToOCLPattern, PrefetchToOCLPattern>( patterns.getContext()); } void ::mlir::registerConvertXeVMToLLVMInterface(DialectRegistry ®istry) { registry.addExtension(+[](MLIRContext *ctx, XeVMDialect *dialect) { dialect->addInterfaces(); }); }