diff options
Diffstat (limited to 'mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp')
-rw-r--r-- | mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp | 45 |
1 files changed, 22 insertions, 23 deletions
diff --git a/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp b/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp index a8380b9..4dfcb2b 100644 --- a/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp +++ b/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp @@ -10,7 +10,6 @@ #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h" #include "mlir/Conversion/LLVMCommon/Pattern.h" -#include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/XeVMDialect.h" @@ -251,7 +250,7 @@ static LLVM::CallOp createDeviceFunctionCall( for (auto [idx, attrName] : paramAttrs) funcOp.setArgAttr(idx, attrName, rewriter.getUnitAttr()); - auto callOp = rewriter.create<LLVM::CallOp>(loc, funcOp, args); + auto callOp = LLVM::CallOp::create(rewriter, loc, funcOp, args); callOp->setAttrs(funcOp->getAttrs()); return callOp; @@ -299,7 +298,7 @@ class MMAToOCLPattern : public OpConversionPattern<xevm::MMAOp> { VectorType newTy = VectorType::get( vecBitSize / packedType.getIntOrFloatBitWidth(), packedType); if (origTy != newTy) - val = rewriter.create<LLVM::BitcastOp>(loc, newTy, val); + val = LLVM::BitcastOp::create(rewriter, loc, newTy, val); return val; }; @@ -326,7 +325,7 @@ class MMAToOCLPattern : public OpConversionPattern<xevm::MMAOp> { : cOrigTy; VectorType resTy = cTy; if (cOrigTy != cTy) - c = rewriter.create<LLVM::BitcastOp>(loc, cTy, c); + c = LLVM::BitcastOp::create(rewriter, loc, cTy, c); constexpr int32_t systolicDepth{8}; std::string fnName = @@ -352,7 +351,7 @@ class MMAToOCLPattern : public OpConversionPattern<xevm::MMAOp> { ->getResult(0); if (resOrigTy != resTy) - result = rewriter.create<LLVM::BitcastOp>(loc, resOrigTy, result); + result = LLVM::BitcastOp::create(rewriter, loc, resOrigTy, result); rewriter.replaceOp(op, result); return success(); @@ -383,7 +382,7 @@ class PrefetchToOCLPattern : public OpConversionPattern<PrefetchOp> { auto loc = op.getLoc(); const std::string fnName{"_Z8prefetchPU3AS1Kcm"}; Value one = - rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI64Type(), 1); + LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(), 1); SmallVector<Value> args{op.getPtr(), one}; SmallVector<Type> argTypes; for (auto arg : args) @@ -439,11 +438,11 @@ class MemfenceToOCLPattern : public OpConversionPattern<MemfenceOp> { op, "Fence only supports workgroup and device memory scopes."); } Type i32Type = rewriter.getI32Type(); - Value acqRel = rewriter.create<LLVM::ConstantOp>(loc, i32Type, 4); + Value acqRel = LLVM::ConstantOp::create(rewriter, loc, i32Type, 4); Value memScopeConst = - rewriter.create<LLVM::ConstantOp>(loc, i32Type, memScope); + LLVM::ConstantOp::create(rewriter, loc, i32Type, memScope); Value addrSpaceConst = - rewriter.create<LLVM::ConstantOp>(loc, i32Type, addrSpace); + LLVM::ConstantOp::create(rewriter, loc, i32Type, addrSpace); SmallVector<Value> args{addrSpaceConst, acqRel, memScopeConst}; SmallVector<Type> argTypes{3, i32Type}; createDeviceFunctionCall(rewriter, mangle(fnName, argTypes), @@ -477,13 +476,13 @@ class LoadStorePrefetchToOCLPattern : public OpConversionPattern<OpType> { auto i32Type = rewriter.getI32Type(); Value byteCoord = - rewriter.create<LLVM::UndefOp>(loc, VectorType::get(2, i32Type)); - Value zero = rewriter.create<LLVM::ConstantOp>(loc, i32Type, 0); - Value one = rewriter.create<LLVM::ConstantOp>(loc, i32Type, 1); - byteCoord = rewriter.create<LLVM::InsertElementOp>( - loc, VectorType::get(2, i32Type), byteCoord, op.getX(), zero); - byteCoord = rewriter.create<LLVM::InsertElementOp>( - loc, VectorType::get(2, i32Type), byteCoord, op.getY(), one); + LLVM::UndefOp::create(rewriter, loc, VectorType::get(2, i32Type)); + Value zero = LLVM::ConstantOp::create(rewriter, loc, i32Type, 0); + Value one = LLVM::ConstantOp::create(rewriter, loc, i32Type, 1); + byteCoord = LLVM::InsertElementOp::create( + rewriter, loc, VectorType::get(2, i32Type), byteCoord, op.getX(), zero); + byteCoord = LLVM::InsertElementOp::create( + rewriter, loc, VectorType::get(2, i32Type), byteCoord, op.getY(), one); SmallVector<Value> args{op.getPtr(), op.getBaseWidth(), op.getBaseHeight(), op.getBasePitch(), byteCoord}; SmallVector<Type> retTypes; @@ -504,11 +503,11 @@ class LoadStorePrefetchToOCLPattern : public OpConversionPattern<OpType> { } else { auto vecElemType = vecType.getElementType(); auto vecElemBitWidth = vecElemType.getIntOrFloatBitWidth(); - Value numElems = rewriter.create<LLVM::ConstantOp>( - loc, i32Type, vecType.getNumElements()); - auto dstOrSrcPtr = rewriter.create<LLVM::AllocaOp>( - loc, LLVM::LLVMPointerType::get(rewriter.getContext()), vecElemType, - numElems); + Value numElems = LLVM::ConstantOp::create(rewriter, loc, i32Type, + vecType.getNumElements()); + auto dstOrSrcPtr = LLVM::AllocaOp::create( + rewriter, loc, LLVM::LLVMPointerType::get(rewriter.getContext()), + vecElemType, numElems); args.push_back(dstOrSrcPtr); if constexpr (isLoad) { // Load funcName += "read"; @@ -530,7 +529,7 @@ class LoadStorePrefetchToOCLPattern : public OpConversionPattern<OpType> { bitWidthId = (vecElemBitWidth == 32) ? "j" : ((vecElemBitWidth == 16) ? "t" : "h"); - rewriter.create<LLVM::StoreOp>(loc, op.getStoredVal(), dstOrSrcPtr); + LLVM::StoreOp::create(rewriter, loc, op.getStoredVal(), dstOrSrcPtr); paramAttrs = { std::make_pair(0, LLVM::LLVMDialect::getNonNullAttrName()), std::make_pair(0, LLVM::LLVMDialect::getWriteOnlyAttrName()), @@ -563,7 +562,7 @@ class LoadStorePrefetchToOCLPattern : public OpConversionPattern<OpType> { } if constexpr (isLoad) rewriter.replaceOp( - op, rewriter.create<LLVM::LoadOp>(loc, vecType, spvLoadDstPtr)); + op, LLVM::LoadOp::create(rewriter, loc, vecType, spvLoadDstPtr)); else rewriter.eraseOp(op); return success(); |