diff options
Diffstat (limited to 'mlir/lib/Conversion/XeVMToLLVM')
-rw-r--r-- | mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp | 85 |
1 files changed, 84 insertions, 1 deletions
diff --git a/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp b/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp index 57877b8..f449d90 100644 --- a/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp +++ b/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp @@ -214,6 +214,10 @@ static std::optional<LoadCacheControl> getCacheControl(BlockLoad2dOp op) { return op.getCacheControl(); } +static std::optional<LoadCacheControl> getCacheControl(BlockLoadOp op) { + return op.getCacheControl(); +} + static std::optional<LoadCacheControl> getCacheControl(BlockPrefetch2dOp op) { return op.getCacheControl(); } @@ -222,6 +226,10 @@ static std::optional<StoreCacheControl> getCacheControl(BlockStore2dOp op) { return op.getCacheControl(); } +static std::optional<StoreCacheControl> getCacheControl(BlockStoreOp op) { + return op.getCacheControl(); +} + static std::optional<LoadCacheControl> getCacheControl(LLVM::LoadOp op) { if (op->hasAttr("cache_control")) { auto attr = op->getAttrOfType<xevm::LoadCacheControlAttr>("cache_control"); @@ -263,6 +271,7 @@ getCacheControlMetadata(ConversionPatternRewriter &rewriter, OpType op) { constexpr bool isLoad = std::is_same_v<OpType, BlockLoad2dOp> || std::is_same_v<OpType, BlockPrefetch2dOp> || std::is_same_v<OpType, LLVM::LoadOp> || + std::is_same_v<OpType, BlockLoadOp> || std::is_same_v<OpType, PrefetchOp>; const int32_t controlKey{isLoad ? loadCacheControlKey : storeCacheControlKey}; SmallVector<int32_t, decorationCacheControlArity> decorationsL1{ @@ -618,6 +627,77 @@ class LoadStorePrefetchToOCLPattern : public OpConversionPattern<OpType> { return success(); } }; + +template <typename OpType> +class BlockLoadStore1DToOCLPattern : public OpConversionPattern<OpType> { + using OpConversionPattern<OpType>::OpConversionPattern; + LogicalResult + matchAndRewrite(OpType op, typename OpType::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + constexpr bool isStore = std::is_same_v<OpType, xevm::BlockStoreOp>; + // Get OpenCL function name + // https://registry.khronos.org/OpenCL/extensions/ + // intel/cl_intel_subgroup_local_block_io.html + std::string funcName{"intel_sub_group_block_"}; + // Value or Result type can be vector or scalar + Type valOrResTy; + if constexpr (isStore) { + funcName += "write_u"; + valOrResTy = op.getVal().getType(); + } else { + funcName += "read_u"; + valOrResTy = op.getType(); + } + // Get element type of the vector/scalar + VectorType vecTy = dyn_cast<VectorType>(valOrResTy); + Type elemType = vecTy ? vecTy.getElementType() : valOrResTy; + funcName += getTypeMangling(elemType); + if (vecTy) + funcName += std::to_string(vecTy.getNumElements()); + SmallVector<Type, 2> argTypes{}; + // XeVM BlockLoad/StoreOp always use signless integer types + // but OpenCL builtins expect unsigned types + // use unsigned types for mangling + SmallVector<bool, 2> isUnsigned{}; + // arg0: pointer to the src/dst address + // arg1 - only if store : vector to store + // Prepare arguments + SmallVector<Value, 2> args{}; + args.push_back(op.getPtr()); + argTypes.push_back(op.getPtr().getType()); + isUnsigned.push_back(true); + Type retType; + if constexpr (isStore) { + args.push_back(op.getVal()); + argTypes.push_back(op.getVal().getType()); + isUnsigned.push_back(true); + retType = LLVM::LLVMVoidType::get(rewriter.getContext()); + } else { + retType = valOrResTy; + } + funcName = std::string("_Z") + std::to_string(funcName.size()) + funcName + + "PU3AS" + + std::to_string(op.getPtr().getType().getAddressSpace()); + funcName += getTypeMangling(elemType, /*isUnsigned=*/true); + if constexpr (isStore) + funcName += getTypeMangling(valOrResTy, /*isUnsigned=*/true); + LLVMFuncAttributeOptions funcAttr{noUnwindWillReturnAttrs}; + + LLVM::CallOp call = + createDeviceFunctionCall(rewriter, funcName, retType, argTypes, args, + {}, funcAttr, op.getOperation()); + if (std::optional<ArrayAttr> optCacheControls = + getCacheControlMetadata(rewriter, op)) { + call->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls); + } + if constexpr (isStore) + rewriter.eraseOp(op); + else + rewriter.replaceOp(op, call->getResult(0)); + return success(); + } +}; + template <typename OpType> class LLVMLoadStoreToOCLPattern : public OpConversionPattern<OpType> { using OpConversionPattern<OpType>::OpConversionPattern; @@ -693,7 +773,10 @@ void ::mlir::populateXeVMToLLVMConversionPatterns(ConversionTarget &target, LoadStorePrefetchToOCLPattern<BlockPrefetch2dOp>, MMAToOCLPattern, MemfenceToOCLPattern, PrefetchToOCLPattern, LLVMLoadStoreToOCLPattern<LLVM::LoadOp>, - LLVMLoadStoreToOCLPattern<LLVM::StoreOp>>(patterns.getContext()); + LLVMLoadStoreToOCLPattern<LLVM::StoreOp>, + BlockLoadStore1DToOCLPattern<BlockLoadOp>, + BlockLoadStore1DToOCLPattern<BlockStoreOp>>( + patterns.getContext()); } void ::mlir::registerConvertXeVMToLLVMInterface(DialectRegistry ®istry) { |