aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp')
-rw-r--r--mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp146
1 files changed, 145 insertions, 1 deletions
diff --git a/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp b/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp
index f449d90..f276984 100644
--- a/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp
+++ b/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp
@@ -715,6 +715,135 @@ class LLVMLoadStoreToOCLPattern : public OpConversionPattern<OpType> {
};
//===----------------------------------------------------------------------===//
+// GPU index id operations
+//===----------------------------------------------------------------------===//
+/*
+// Launch Config ops
+// dimidx - x, y, z - is fixed to i32
+// return type is set by XeVM type converter
+// get_local_id
+xevm::WorkitemIdXOp;
+xevm::WorkitemIdYOp;
+xevm::WorkitemIdZOp;
+// get_local_size
+xevm::WorkgroupDimXOp;
+xevm::WorkgroupDimYOp;
+xevm::WorkgroupDimZOp;
+// get_group_id
+xevm::WorkgroupIdXOp;
+xevm::WorkgroupIdYOp;
+xevm::WorkgroupIdZOp;
+// get_num_groups
+xevm::GridDimXOp;
+xevm::GridDimYOp;
+xevm::GridDimZOp;
+// get_global_id : to be added if needed
+*/
+
+// Helpers to get the OpenCL function name and dimension argument for each op.
+static std::pair<StringRef, int64_t> getConfig(xevm::WorkitemIdXOp) {
+ return {"get_local_id", 0};
+}
+static std::pair<StringRef, int64_t> getConfig(xevm::WorkitemIdYOp) {
+ return {"get_local_id", 1};
+}
+static std::pair<StringRef, int64_t> getConfig(xevm::WorkitemIdZOp) {
+ return {"get_local_id", 2};
+}
+static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupDimXOp) {
+ return {"get_local_size", 0};
+}
+static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupDimYOp) {
+ return {"get_local_size", 1};
+}
+static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupDimZOp) {
+ return {"get_local_size", 2};
+}
+static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupIdXOp) {
+ return {"get_group_id", 0};
+}
+static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupIdYOp) {
+ return {"get_group_id", 1};
+}
+static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupIdZOp) {
+ return {"get_group_id", 2};
+}
+static std::pair<StringRef, int64_t> getConfig(xevm::GridDimXOp) {
+ return {"get_num_groups", 0};
+}
+static std::pair<StringRef, int64_t> getConfig(xevm::GridDimYOp) {
+ return {"get_num_groups", 1};
+}
+static std::pair<StringRef, int64_t> getConfig(xevm::GridDimZOp) {
+ return {"get_num_groups", 2};
+}
+/// Replace `xevm.*` with an `llvm.call` to the corresponding OpenCL func with
+/// a constant argument for the dimension - x, y or z.
+template <typename OpType>
+class LaunchConfigOpToOCLPattern : public OpConversionPattern<OpType> {
+ using OpConversionPattern<OpType>::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Location loc = op->getLoc();
+ auto [baseName, dim] = getConfig(op);
+ Type dimTy = rewriter.getI32Type();
+ Value dimVal = LLVM::ConstantOp::create(rewriter, loc, dimTy,
+ static_cast<int64_t>(dim));
+ std::string func = mangle(baseName, {dimTy}, {true});
+ Type resTy = op.getType();
+ auto call =
+ createDeviceFunctionCall(rewriter, func, resTy, {dimTy}, {dimVal}, {},
+ noUnwindWillReturnAttrs, op.getOperation());
+ constexpr auto noModRef = LLVM::ModRefInfo::NoModRef;
+ auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
+ /*other=*/noModRef,
+ /*argMem=*/noModRef, /*inaccessibleMem=*/noModRef);
+ call.setMemoryEffectsAttr(memAttr);
+ rewriter.replaceOp(op, call);
+ return success();
+ }
+};
+
+/*
+// Subgroup ops
+// get_sub_group_local_id
+xevm::LaneIdOp;
+// get_sub_group_id
+xevm::SubgroupIdOp;
+// get_sub_group_size
+xevm::SubgroupSizeOp;
+// get_num_sub_groups : to be added if needed
+*/
+
+// Helpers to get the OpenCL function name for each op.
+static StringRef getConfig(xevm::LaneIdOp) { return "get_sub_group_local_id"; }
+static StringRef getConfig(xevm::SubgroupIdOp) { return "get_sub_group_id"; }
+static StringRef getConfig(xevm::SubgroupSizeOp) {
+ return "get_sub_group_size";
+}
+template <typename OpType>
+class SubgroupOpWorkitemOpToOCLPattern : public OpConversionPattern<OpType> {
+ using OpConversionPattern<OpType>::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ std::string func = mangle(getConfig(op).str(), {});
+ Type resTy = op.getType();
+ auto call =
+ createDeviceFunctionCall(rewriter, func, resTy, {}, {}, {},
+ noUnwindWillReturnAttrs, op.getOperation());
+ constexpr auto noModRef = LLVM::ModRefInfo::NoModRef;
+ auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
+ /*other=*/noModRef,
+ /*argMem=*/noModRef, /*inaccessibleMem=*/noModRef);
+ call.setMemoryEffectsAttr(memAttr);
+ rewriter.replaceOp(op, call);
+ return success();
+ }
+};
+
+//===----------------------------------------------------------------------===//
// Pass Definition
//===----------------------------------------------------------------------===//
@@ -775,7 +904,22 @@ void ::mlir::populateXeVMToLLVMConversionPatterns(ConversionTarget &target,
LLVMLoadStoreToOCLPattern<LLVM::LoadOp>,
LLVMLoadStoreToOCLPattern<LLVM::StoreOp>,
BlockLoadStore1DToOCLPattern<BlockLoadOp>,
- BlockLoadStore1DToOCLPattern<BlockStoreOp>>(
+ BlockLoadStore1DToOCLPattern<BlockStoreOp>,
+ LaunchConfigOpToOCLPattern<WorkitemIdXOp>,
+ LaunchConfigOpToOCLPattern<WorkitemIdYOp>,
+ LaunchConfigOpToOCLPattern<WorkitemIdZOp>,
+ LaunchConfigOpToOCLPattern<WorkgroupDimXOp>,
+ LaunchConfigOpToOCLPattern<WorkgroupDimYOp>,
+ LaunchConfigOpToOCLPattern<WorkgroupDimZOp>,
+ LaunchConfigOpToOCLPattern<WorkgroupIdXOp>,
+ LaunchConfigOpToOCLPattern<WorkgroupIdYOp>,
+ LaunchConfigOpToOCLPattern<WorkgroupIdZOp>,
+ LaunchConfigOpToOCLPattern<GridDimXOp>,
+ LaunchConfigOpToOCLPattern<GridDimYOp>,
+ LaunchConfigOpToOCLPattern<GridDimZOp>,
+ SubgroupOpWorkitemOpToOCLPattern<LaneIdOp>,
+ SubgroupOpWorkitemOpToOCLPattern<SubgroupIdOp>,
+ SubgroupOpWorkitemOpToOCLPattern<SubgroupSizeOp>>(
patterns.getContext());
}