diff options
Diffstat (limited to 'mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp')
-rw-r--r-- | mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp | 146 |
1 files changed, 145 insertions, 1 deletions
diff --git a/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp b/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp index f449d90..f276984 100644 --- a/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp +++ b/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp @@ -715,6 +715,135 @@ class LLVMLoadStoreToOCLPattern : public OpConversionPattern<OpType> { }; //===----------------------------------------------------------------------===// +// GPU index id operations +//===----------------------------------------------------------------------===// +/* +// Launch Config ops +// dimidx - x, y, z - is fixed to i32 +// return type is set by XeVM type converter +// get_local_id +xevm::WorkitemIdXOp; +xevm::WorkitemIdYOp; +xevm::WorkitemIdZOp; +// get_local_size +xevm::WorkgroupDimXOp; +xevm::WorkgroupDimYOp; +xevm::WorkgroupDimZOp; +// get_group_id +xevm::WorkgroupIdXOp; +xevm::WorkgroupIdYOp; +xevm::WorkgroupIdZOp; +// get_num_groups +xevm::GridDimXOp; +xevm::GridDimYOp; +xevm::GridDimZOp; +// get_global_id : to be added if needed +*/ + +// Helpers to get the OpenCL function name and dimension argument for each op. 
+// Each overload is selected purely by the op's C++ type; the op value itself
+// is never read. The returned pair is {OpenCL function name, dimension index}
+// with x, y and z mapped to 0, 1 and 2.
+
+// WorkitemId{X,Y,Z}Op -> get_local_id
+static std::pair<StringRef, int64_t> getConfig(xevm::WorkitemIdXOp) {
+  return {"get_local_id", /*dim=*/0};
+}
+static std::pair<StringRef, int64_t> getConfig(xevm::WorkitemIdYOp) {
+  return {"get_local_id", /*dim=*/1};
+}
+static std::pair<StringRef, int64_t> getConfig(xevm::WorkitemIdZOp) {
+  return {"get_local_id", /*dim=*/2};
+}
+
+// WorkgroupDim{X,Y,Z}Op -> get_local_size
+static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupDimXOp) {
+  return {"get_local_size", /*dim=*/0};
+}
+static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupDimYOp) {
+  return {"get_local_size", /*dim=*/1};
+}
+static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupDimZOp) {
+  return {"get_local_size", /*dim=*/2};
+}
+
+// WorkgroupId{X,Y,Z}Op -> get_group_id
+static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupIdXOp) {
+  return {"get_group_id", /*dim=*/0};
+}
+static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupIdYOp) {
+  return {"get_group_id", /*dim=*/1};
+}
+static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupIdZOp) {
+  return {"get_group_id", /*dim=*/2};
+}
+
+// GridDim{X,Y,Z}Op -> get_num_groups
+static std::pair<StringRef, int64_t> getConfig(xevm::GridDimXOp) {
+  return {"get_num_groups", /*dim=*/0};
+}
+static std::pair<StringRef, int64_t> getConfig(xevm::GridDimYOp) {
+  return {"get_num_groups", /*dim=*/1};
+}
+static std::pair<StringRef, int64_t> getConfig(xevm::GridDimZOp) {
+  return {"get_num_groups", /*dim=*/2};
+}
+
+/// Rewrite a launch-config `xevm.*` op into an `llvm.call` to the OpenCL
+/// function selected by `getConfig`, passing the dimension (x, y or z) as a
+/// constant i32 argument.
+template <typename OpType>
+class LaunchConfigOpToOCLPattern : public OpConversionPattern<OpType> {
+  using OpConversionPattern<OpType>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Overload resolution on the op type yields the OpenCL function name and
+    // the dimension index to query.
+    const auto [builtinName, dimIndex] = getConfig(op);
+
+    // The dimension is handed to the callee as a constant i32 operand.
+    Type i32Ty = rewriter.getI32Type();
+    Value dimArg =
+        LLVM::ConstantOp::create(rewriter, op.getLoc(), i32Ty, dimIndex);
+
+    // NOTE(review): the `{true}` list presumably flags the i32 parameter as
+    // unsigned for name mangling - confirm against `mangle`.
+    std::string mangledName = mangle(builtinName, {i32Ty}, {true});
+    Type resultTy = op.getType();
+    auto callOp = createDeviceFunctionCall(
+        rewriter, mangledName, resultTy, {i32Ty}, {dimArg}, {},
+        noUnwindWillReturnAttrs, op.getOperation());
+
+    // Tag the call as NoModRef for all three memory-location kinds.
+    constexpr auto noAccess = LLVM::ModRefInfo::NoModRef;
+    callOp.setMemoryEffectsAttr(rewriter.getAttr<LLVM::MemoryEffectsAttr>(
+        /*other=*/noAccess,
+        /*argMem=*/noAccess, /*inaccessibleMem=*/noAccess));
+
+    rewriter.replaceOp(op, callOp);
+    return success();
+  }
+};
+
+/*
+// Subgroup ops
+// get_sub_group_local_id
+xevm::LaneIdOp;
+// get_sub_group_id
+xevm::SubgroupIdOp;
+// get_sub_group_size
+xevm::SubgroupSizeOp;
+// get_num_sub_groups : to be added if needed
+*/
+
+// Helpers to get the OpenCL function name for each op.
+static StringRef getConfig(xevm::LaneIdOp) { return "get_sub_group_local_id"; }
+static StringRef getConfig(xevm::SubgroupIdOp) { return "get_sub_group_id"; }
+static StringRef getConfig(xevm::SubgroupSizeOp) {
+  return "get_sub_group_size";
+}
+
+/// Replace a subgroup `xevm.*` op with an `llvm.call` to the corresponding
+/// OpenCL function, which is called without arguments.
+template <typename OpType>
+class SubgroupOpWorkitemOpToOCLPattern : public OpConversionPattern<OpType> {
+  using OpConversionPattern<OpType>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // `getConfig` already yields a StringRef, which `mangle` accepts directly
+    // (LaunchConfigOpToOCLPattern passes one as well), so no intermediate
+    // std::string is built here.
+    std::string func = mangle(getConfig(op), {});
+    Type resTy = op.getType();
+    auto call =
+        createDeviceFunctionCall(rewriter, func, resTy, {}, {}, {},
+                                 noUnwindWillReturnAttrs, op.getOperation());
+    // Tag the call as NoModRef for all three memory-location kinds.
+    constexpr auto noModRef = LLVM::ModRefInfo::NoModRef;
+    auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
+        /*other=*/noModRef,
+        /*argMem=*/noModRef, /*inaccessibleMem=*/noModRef);
+    call.setMemoryEffectsAttr(memAttr);
+    rewriter.replaceOp(op, call);
+    return success();
+  }
+};
+
+//===----------------------------------------------------------------------===//
 // Pass Definition
 //===----------------------------------------------------------------------===//
 
@@ -775,7 +904,22 @@ void ::mlir::populateXeVMToLLVMConversionPatterns(ConversionTarget &target,
                LLVMLoadStoreToOCLPattern<LLVM::LoadOp>,
                LLVMLoadStoreToOCLPattern<LLVM::StoreOp>,
                BlockLoadStore1DToOCLPattern<BlockLoadOp>,
-               BlockLoadStore1DToOCLPattern<BlockStoreOp>>(
+               BlockLoadStore1DToOCLPattern<BlockStoreOp>,
+               LaunchConfigOpToOCLPattern<WorkitemIdXOp>,
+               LaunchConfigOpToOCLPattern<WorkitemIdYOp>,
+               LaunchConfigOpToOCLPattern<WorkitemIdZOp>,
+               LaunchConfigOpToOCLPattern<WorkgroupDimXOp>,
+               LaunchConfigOpToOCLPattern<WorkgroupDimYOp>,
+               LaunchConfigOpToOCLPattern<WorkgroupDimZOp>,
+               LaunchConfigOpToOCLPattern<WorkgroupIdXOp>,
+               LaunchConfigOpToOCLPattern<WorkgroupIdYOp>,
+               LaunchConfigOpToOCLPattern<WorkgroupIdZOp>,
+               LaunchConfigOpToOCLPattern<GridDimXOp>,
+               LaunchConfigOpToOCLPattern<GridDimYOp>,
+               LaunchConfigOpToOCLPattern<GridDimZOp>,
+               SubgroupOpWorkitemOpToOCLPattern<LaneIdOp>,
+               SubgroupOpWorkitemOpToOCLPattern<SubgroupIdOp>,
+               SubgroupOpWorkitemOpToOCLPattern<SubgroupSizeOp>>(
       patterns.getContext());
 }