aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Conversion/XeVMToLLVM
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Conversion/XeVMToLLVM')
-rw-r--r--mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp85
1 files changed, 84 insertions, 1 deletions
diff --git a/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp b/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp
index 57877b8..f449d90 100644
--- a/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp
+++ b/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp
@@ -214,6 +214,10 @@ static std::optional<LoadCacheControl> getCacheControl(BlockLoad2dOp op) {
return op.getCacheControl();
}
+static std::optional<LoadCacheControl> getCacheControl(BlockLoadOp op) {
+ return op.getCacheControl();
+}
+
static std::optional<LoadCacheControl> getCacheControl(BlockPrefetch2dOp op) {
return op.getCacheControl();
}
@@ -222,6 +226,10 @@ static std::optional<StoreCacheControl> getCacheControl(BlockStore2dOp op) {
return op.getCacheControl();
}
+static std::optional<StoreCacheControl> getCacheControl(BlockStoreOp op) {
+ return op.getCacheControl();
+}
+
static std::optional<LoadCacheControl> getCacheControl(LLVM::LoadOp op) {
if (op->hasAttr("cache_control")) {
auto attr = op->getAttrOfType<xevm::LoadCacheControlAttr>("cache_control");
@@ -263,6 +271,7 @@ getCacheControlMetadata(ConversionPatternRewriter &rewriter, OpType op) {
constexpr bool isLoad = std::is_same_v<OpType, BlockLoad2dOp> ||
std::is_same_v<OpType, BlockPrefetch2dOp> ||
std::is_same_v<OpType, LLVM::LoadOp> ||
+ std::is_same_v<OpType, BlockLoadOp> ||
std::is_same_v<OpType, PrefetchOp>;
const int32_t controlKey{isLoad ? loadCacheControlKey : storeCacheControlKey};
SmallVector<int32_t, decorationCacheControlArity> decorationsL1{
@@ -618,6 +627,77 @@ class LoadStorePrefetchToOCLPattern : public OpConversionPattern<OpType> {
return success();
}
};
+
+template <typename OpType>
+class BlockLoadStore1DToOCLPattern : public OpConversionPattern<OpType> {
+ using OpConversionPattern<OpType>::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ constexpr bool isStore = std::is_same_v<OpType, xevm::BlockStoreOp>;
+ // Get OpenCL function name
+ // https://registry.khronos.org/OpenCL/extensions/
+ // intel/cl_intel_subgroup_local_block_io.html
+ std::string funcName{"intel_sub_group_block_"};
+ // Value or Result type can be vector or scalar
+ Type valOrResTy;
+ if constexpr (isStore) {
+ funcName += "write_u";
+ valOrResTy = op.getVal().getType();
+ } else {
+ funcName += "read_u";
+ valOrResTy = op.getType();
+ }
+ // Get element type of the vector/scalar
+ VectorType vecTy = dyn_cast<VectorType>(valOrResTy);
+ Type elemType = vecTy ? vecTy.getElementType() : valOrResTy;
+ funcName += getTypeMangling(elemType);
+ if (vecTy)
+ funcName += std::to_string(vecTy.getNumElements());
+ SmallVector<Type, 2> argTypes{};
+ // XeVM BlockLoad/StoreOp always use signless integer types
+ // but OpenCL builtins expect unsigned types
+ // use unsigned types for mangling
+ SmallVector<bool, 2> isUnsigned{};
+ // arg0: pointer to the src/dst address
+ // arg1 - only if store : vector to store
+ // Prepare arguments
+ SmallVector<Value, 2> args{};
+ args.push_back(op.getPtr());
+ argTypes.push_back(op.getPtr().getType());
+ isUnsigned.push_back(true);
+ Type retType;
+ if constexpr (isStore) {
+ args.push_back(op.getVal());
+ argTypes.push_back(op.getVal().getType());
+ isUnsigned.push_back(true);
+ retType = LLVM::LLVMVoidType::get(rewriter.getContext());
+ } else {
+ retType = valOrResTy;
+ }
+ funcName = std::string("_Z") + std::to_string(funcName.size()) + funcName +
+ "PU3AS" +
+ std::to_string(op.getPtr().getType().getAddressSpace());
+ funcName += getTypeMangling(elemType, /*isUnsigned=*/true);
+ if constexpr (isStore)
+ funcName += getTypeMangling(valOrResTy, /*isUnsigned=*/true);
+ LLVMFuncAttributeOptions funcAttr{noUnwindWillReturnAttrs};
+
+ LLVM::CallOp call =
+ createDeviceFunctionCall(rewriter, funcName, retType, argTypes, args,
+ {}, funcAttr, op.getOperation());
+ if (std::optional<ArrayAttr> optCacheControls =
+ getCacheControlMetadata(rewriter, op)) {
+ call->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls);
+ }
+ if constexpr (isStore)
+ rewriter.eraseOp(op);
+ else
+ rewriter.replaceOp(op, call->getResult(0));
+ return success();
+ }
+};
+
template <typename OpType>
class LLVMLoadStoreToOCLPattern : public OpConversionPattern<OpType> {
using OpConversionPattern<OpType>::OpConversionPattern;
@@ -693,7 +773,10 @@ void ::mlir::populateXeVMToLLVMConversionPatterns(ConversionTarget &target,
LoadStorePrefetchToOCLPattern<BlockPrefetch2dOp>,
MMAToOCLPattern, MemfenceToOCLPattern, PrefetchToOCLPattern,
LLVMLoadStoreToOCLPattern<LLVM::LoadOp>,
- LLVMLoadStoreToOCLPattern<LLVM::StoreOp>>(patterns.getContext());
+ LLVMLoadStoreToOCLPattern<LLVM::StoreOp>,
+ BlockLoadStore1DToOCLPattern<BlockLoadOp>,
+ BlockLoadStore1DToOCLPattern<BlockStoreOp>>(
+ patterns.getContext());
}
void ::mlir::registerConvertXeVMToLLVMInterface(DialectRegistry &registry) {