aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp')
-rw-r--r--mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp45
1 files changed, 22 insertions, 23 deletions
diff --git a/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp b/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp
index a8380b9..4dfcb2b 100644
--- a/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp
+++ b/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp
@@ -10,7 +10,6 @@
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
-#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/XeVMDialect.h"
@@ -251,7 +250,7 @@ static LLVM::CallOp createDeviceFunctionCall(
for (auto [idx, attrName] : paramAttrs)
funcOp.setArgAttr(idx, attrName, rewriter.getUnitAttr());
- auto callOp = rewriter.create<LLVM::CallOp>(loc, funcOp, args);
+ auto callOp = LLVM::CallOp::create(rewriter, loc, funcOp, args);
callOp->setAttrs(funcOp->getAttrs());
return callOp;
@@ -299,7 +298,7 @@ class MMAToOCLPattern : public OpConversionPattern<xevm::MMAOp> {
VectorType newTy = VectorType::get(
vecBitSize / packedType.getIntOrFloatBitWidth(), packedType);
if (origTy != newTy)
- val = rewriter.create<LLVM::BitcastOp>(loc, newTy, val);
+ val = LLVM::BitcastOp::create(rewriter, loc, newTy, val);
return val;
};
@@ -326,7 +325,7 @@ class MMAToOCLPattern : public OpConversionPattern<xevm::MMAOp> {
: cOrigTy;
VectorType resTy = cTy;
if (cOrigTy != cTy)
- c = rewriter.create<LLVM::BitcastOp>(loc, cTy, c);
+ c = LLVM::BitcastOp::create(rewriter, loc, cTy, c);
constexpr int32_t systolicDepth{8};
std::string fnName =
@@ -352,7 +351,7 @@ class MMAToOCLPattern : public OpConversionPattern<xevm::MMAOp> {
->getResult(0);
if (resOrigTy != resTy)
- result = rewriter.create<LLVM::BitcastOp>(loc, resOrigTy, result);
+ result = LLVM::BitcastOp::create(rewriter, loc, resOrigTy, result);
rewriter.replaceOp(op, result);
return success();
@@ -383,7 +382,7 @@ class PrefetchToOCLPattern : public OpConversionPattern<PrefetchOp> {
auto loc = op.getLoc();
const std::string fnName{"_Z8prefetchPU3AS1Kcm"};
Value one =
- rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI64Type(), 1);
+ LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(), 1);
SmallVector<Value> args{op.getPtr(), one};
SmallVector<Type> argTypes;
for (auto arg : args)
@@ -439,11 +438,11 @@ class MemfenceToOCLPattern : public OpConversionPattern<MemfenceOp> {
op, "Fence only supports workgroup and device memory scopes.");
}
Type i32Type = rewriter.getI32Type();
- Value acqRel = rewriter.create<LLVM::ConstantOp>(loc, i32Type, 4);
+ Value acqRel = LLVM::ConstantOp::create(rewriter, loc, i32Type, 4);
Value memScopeConst =
- rewriter.create<LLVM::ConstantOp>(loc, i32Type, memScope);
+ LLVM::ConstantOp::create(rewriter, loc, i32Type, memScope);
Value addrSpaceConst =
- rewriter.create<LLVM::ConstantOp>(loc, i32Type, addrSpace);
+ LLVM::ConstantOp::create(rewriter, loc, i32Type, addrSpace);
SmallVector<Value> args{addrSpaceConst, acqRel, memScopeConst};
SmallVector<Type> argTypes{3, i32Type};
createDeviceFunctionCall(rewriter, mangle(fnName, argTypes),
@@ -477,13 +476,13 @@ class LoadStorePrefetchToOCLPattern : public OpConversionPattern<OpType> {
auto i32Type = rewriter.getI32Type();
Value byteCoord =
- rewriter.create<LLVM::UndefOp>(loc, VectorType::get(2, i32Type));
- Value zero = rewriter.create<LLVM::ConstantOp>(loc, i32Type, 0);
- Value one = rewriter.create<LLVM::ConstantOp>(loc, i32Type, 1);
- byteCoord = rewriter.create<LLVM::InsertElementOp>(
- loc, VectorType::get(2, i32Type), byteCoord, op.getX(), zero);
- byteCoord = rewriter.create<LLVM::InsertElementOp>(
- loc, VectorType::get(2, i32Type), byteCoord, op.getY(), one);
+ LLVM::UndefOp::create(rewriter, loc, VectorType::get(2, i32Type));
+ Value zero = LLVM::ConstantOp::create(rewriter, loc, i32Type, 0);
+ Value one = LLVM::ConstantOp::create(rewriter, loc, i32Type, 1);
+ byteCoord = LLVM::InsertElementOp::create(
+ rewriter, loc, VectorType::get(2, i32Type), byteCoord, op.getX(), zero);
+ byteCoord = LLVM::InsertElementOp::create(
+ rewriter, loc, VectorType::get(2, i32Type), byteCoord, op.getY(), one);
SmallVector<Value> args{op.getPtr(), op.getBaseWidth(), op.getBaseHeight(),
op.getBasePitch(), byteCoord};
SmallVector<Type> retTypes;
@@ -504,11 +503,11 @@ class LoadStorePrefetchToOCLPattern : public OpConversionPattern<OpType> {
} else {
auto vecElemType = vecType.getElementType();
auto vecElemBitWidth = vecElemType.getIntOrFloatBitWidth();
- Value numElems = rewriter.create<LLVM::ConstantOp>(
- loc, i32Type, vecType.getNumElements());
- auto dstOrSrcPtr = rewriter.create<LLVM::AllocaOp>(
- loc, LLVM::LLVMPointerType::get(rewriter.getContext()), vecElemType,
- numElems);
+ Value numElems = LLVM::ConstantOp::create(rewriter, loc, i32Type,
+ vecType.getNumElements());
+ auto dstOrSrcPtr = LLVM::AllocaOp::create(
+ rewriter, loc, LLVM::LLVMPointerType::get(rewriter.getContext()),
+ vecElemType, numElems);
args.push_back(dstOrSrcPtr);
if constexpr (isLoad) { // Load
funcName += "read";
@@ -530,7 +529,7 @@ class LoadStorePrefetchToOCLPattern : public OpConversionPattern<OpType> {
bitWidthId = (vecElemBitWidth == 32)
? "j"
: ((vecElemBitWidth == 16) ? "t" : "h");
- rewriter.create<LLVM::StoreOp>(loc, op.getStoredVal(), dstOrSrcPtr);
+ LLVM::StoreOp::create(rewriter, loc, op.getStoredVal(), dstOrSrcPtr);
paramAttrs = {
std::make_pair(0, LLVM::LLVMDialect::getNonNullAttrName()),
std::make_pair(0, LLVM::LLVMDialect::getWriteOnlyAttrName()),
@@ -563,7 +562,7 @@ class LoadStorePrefetchToOCLPattern : public OpConversionPattern<OpType> {
}
if constexpr (isLoad)
rewriter.replaceOp(
- op, rewriter.create<LLVM::LoadOp>(loc, vecType, spvLoadDstPtr));
+ op, LLVM::LoadOp::create(rewriter, loc, vecType, spvLoadDstPtr));
else
rewriter.eraseOp(op);
return success();