diff options
Diffstat (limited to 'mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp')
-rw-r--r-- | mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp | 137 |
1 files changed, 69 insertions, 68 deletions
diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp index 167cabb..3cfbd89 100644 --- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp +++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp @@ -79,8 +79,8 @@ protected: uint64_t rank = type.getRank(); Value numElements = desc.size(rewriter, loc, /*pos=*/0); for (unsigned i = 1; i < rank; i++) - numElements = rewriter.create<LLVM::MulOp>( - loc, numElements, desc.size(rewriter, loc, /*pos=*/i)); + numElements = LLVM::MulOp::create(rewriter, loc, numElements, + desc.size(rewriter, loc, /*pos=*/i)); return numElements; } @@ -579,10 +579,10 @@ LLVM::CallOp FunctionCallBuilder::create(Location loc, OpBuilder &builder, auto function = [&] { if (auto function = module.lookupSymbol<LLVM::LLVMFuncOp>(functionName)) return function; - return OpBuilder::atBlockEnd(module.getBody()) - .create<LLVM::LLVMFuncOp>(loc, functionName, functionType); + auto builder = OpBuilder::atBlockEnd(module.getBody()); + return LLVM::LLVMFuncOp::create(builder, loc, functionName, functionType); }(); - return builder.create<LLVM::CallOp>(loc, function, arguments); + return LLVM::CallOp::create(builder, loc, function, arguments); } // Corresponding to cusparseIndexType_t defined in cusparse.h. @@ -780,13 +780,13 @@ LogicalResult ConvertAllocOpToGpuRuntimeCallPattern::matchAndRewrite( // Allocate the underlying buffer and store a pointer to it in the MemRef // descriptor. - auto nullPtr = rewriter.create<mlir::LLVM::ZeroOp>(loc, llvmPointerType); + auto nullPtr = mlir::LLVM::ZeroOp::create(rewriter, loc, llvmPointerType); Value stream = adaptor.getAsyncDependencies().empty() ? nullPtr : adaptor.getAsyncDependencies().front(); - auto isHostShared = rewriter.create<mlir::LLVM::ConstantOp>( - loc, llvmInt8Type, rewriter.getI8IntegerAttr(isShared)); + auto isHostShared = mlir::LLVM::ConstantOp::create( + rewriter, loc, llvmInt8Type, rewriter.getI8IntegerAttr(isShared)); Value allocatedPtr = allocCallBuilder.create(loc, rewriter, {sizeBytes, stream, isHostShared}) @@ -1012,8 +1012,8 @@ LogicalResult LegalizeLaunchFuncOpPattern::matchAndRewrite( uint64_t staticSize = static_cast<uint64_t>(bitwidth / 8) * static_cast<uint64_t>(memrefTy.getNumElements()); - Value sizeArg = rewriter.create<LLVM::ConstantOp>( - loc, getIndexType(), rewriter.getIndexAttr(staticSize)); + Value sizeArg = LLVM::ConstantOp::create( + rewriter, loc, getIndexType(), rewriter.getIndexAttr(staticSize)); llvmArgumentsWithSizes.push_back(llvmArg); // Presumably a bare pointer. llvmArgumentsWithSizes.push_back(sizeArg); } @@ -1025,8 +1025,8 @@ LogicalResult LegalizeLaunchFuncOpPattern::matchAndRewrite( gpu::KernelDim3{adaptor.getClusterSizeX(), adaptor.getClusterSizeY(), adaptor.getClusterSizeZ()}; } - rewriter.create<gpu::LaunchFuncOp>( - launchOp.getLoc(), launchOp.getKernelAttr(), + gpu::LaunchFuncOp::create( + rewriter, launchOp.getLoc(), launchOp.getKernelAttr(), gpu::KernelDim3{adaptor.getGridSizeX(), adaptor.getGridSizeY(), adaptor.getGridSizeZ()}, gpu::KernelDim3{adaptor.getBlockSizeX(), adaptor.getBlockSizeY(), @@ -1048,8 +1048,8 @@ static Value bitAndAddrspaceCast(Location loc, const LLVMTypeConverter &typeConverter) { auto sourceTy = cast<LLVM::LLVMPointerType>(sourcePtr.getType()); if (destinationType.getAddressSpace() != sourceTy.getAddressSpace()) - sourcePtr = rewriter.create<LLVM::AddrSpaceCastOp>( - loc, + sourcePtr = LLVM::AddrSpaceCastOp::create( + rewriter, loc, LLVM::LLVMPointerType::get(rewriter.getContext(), destinationType.getAddressSpace()), sourcePtr); @@ -1072,13 +1072,13 @@ LogicalResult ConvertMemcpyOpToGpuRuntimeCallPattern::matchAndRewrite( Value numElements = getNumElements(rewriter, loc, memRefType, srcDesc); Type elementPtrType = getElementPtrType(memRefType); - Value nullPtr = rewriter.create<LLVM::ZeroOp>(loc, elementPtrType); - Value gepPtr = rewriter.create<LLVM::GEPOp>( - loc, elementPtrType, + Value nullPtr = LLVM::ZeroOp::create(rewriter, loc, elementPtrType); + Value gepPtr = LLVM::GEPOp::create( + rewriter, loc, elementPtrType, typeConverter->convertType(memRefType.getElementType()), nullPtr, numElements); auto sizeBytes = - rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gepPtr); + LLVM::PtrToIntOp::create(rewriter, loc, getIndexType(), gepPtr); auto src = bitAndAddrspaceCast(loc, rewriter, llvmPointerType, srcDesc.alignedPtr(rewriter, loc), @@ -1123,7 +1123,7 @@ LogicalResult ConvertMemsetOpToGpuRuntimeCallPattern::matchAndRewrite( Value numElements = getNumElements(rewriter, loc, memRefType, dstDesc); auto value = - rewriter.create<LLVM::BitcastOp>(loc, bitCastType, adaptor.getValue()); + LLVM::BitcastOp::create(rewriter, loc, bitCastType, adaptor.getValue()); auto dst = bitAndAddrspaceCast(loc, rewriter, llvmPointerType, dstDesc.alignedPtr(rewriter, loc), *getTypeConverter()); @@ -1150,15 +1150,15 @@ LogicalResult ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern::matchAndRewrite( template <typename T> static Value genConstInt32From(OpBuilder &builder, Location loc, T tValue) { Type llvmInt32Type = builder.getIntegerType(32); - return builder.create<LLVM::ConstantOp>(loc, llvmInt32Type, - static_cast<int32_t>(tValue)); + return LLVM::ConstantOp::create(builder, loc, llvmInt32Type, + static_cast<int32_t>(tValue)); } template <typename T> static Value genConstFloat32From(OpBuilder &builder, Location loc, T tValue) { Type llvmFloat32Type = builder.getF32Type(); - return builder.create<LLVM::ConstantOp>( - loc, llvmFloat32Type, + return LLVM::ConstantOp::create( + builder, loc, llvmFloat32Type, builder.getF32FloatAttr(static_cast<float>(tValue))); } @@ -1189,11 +1189,11 @@ LogicalResult ConvertCreateDnTensorOpToGpuRuntimeCallPattern::matchAndRewrite( // the dnmat is used with spmat with 2:4 sparsity if (dims.size() == 2) { if (isSpMMCusparseLtOp(op.getDnTensor())) { - auto handleSz = rewriter.create<LLVM::ConstantOp>( - loc, getIndexType(), rewriter.getIndexAttr(11032)); - handle = rewriter.create<LLVM::AllocaOp>( - loc, llvmPointerType, llvmInt8Type, handleSz, /*alignment=*/16); - handle = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, handle); + auto handleSz = LLVM::ConstantOp::create(rewriter, loc, getIndexType(), + rewriter.getIndexAttr(11032)); + handle = LLVM::AllocaOp::create(rewriter, loc, llvmPointerType, + llvmInt8Type, handleSz, /*alignment=*/16); + handle = LLVM::BitcastOp::create(rewriter, loc, llvmPointerType, handle); createLtDnMatCallBuilder .create(loc, rewriter, @@ -1351,11 +1351,11 @@ LogicalResult ConvertCreate2To4SpMatOpToGpuRuntimeCallPattern::matchAndRewrite( auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType)); // CUDA runner asserts the size is 44104 bytes. - auto handleSz = rewriter.create<LLVM::ConstantOp>( - loc, getIndexType(), rewriter.getIndexAttr(44104)); - Value handle = rewriter.create<LLVM::AllocaOp>( - loc, llvmPointerType, llvmInt8Type, handleSz, /*alignment=*/16); - handle = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, handle); + auto handleSz = LLVM::ConstantOp::create(rewriter, loc, getIndexType(), + rewriter.getIndexAttr(44104)); + Value handle = LLVM::AllocaOp::create( + rewriter, loc, llvmPointerType, llvmInt8Type, handleSz, /*alignment=*/16); + handle = LLVM::BitcastOp::create(rewriter, loc, llvmPointerType, handle); create2To4SpMatCallBuilder .create(loc, rewriter, @@ -1441,10 +1441,11 @@ LogicalResult ConvertSpMMBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite( genConstInt32From(rewriter, loc, get2To4PruneFlag(op.getSpmatA())); auto computeType = genConstInt32From( rewriter, loc, getCuSparseLtDataTypeFrom(adaptor.getComputeType())); - auto three = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(), - rewriter.getIndexAttr(3)); - auto bufferSize = rewriter.create<LLVM::AllocaOp>( - loc, llvmPointerType, llvmPointerType, three, /*alignment=*/16); + auto three = LLVM::ConstantOp::create(rewriter, loc, getIndexType(), + rewriter.getIndexAttr(3)); + auto bufferSize = + LLVM::AllocaOp::create(rewriter, loc, llvmPointerType, llvmPointerType, + three, /*alignment=*/16); createCuSparseLtSpMMBufferSizeBuilder .create(loc, rewriter, {bufferSize, modeA, modeB, adaptor.getSpmatA(), @@ -1452,20 +1453,20 @@ LogicalResult ConvertSpMMBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite( pruneFlag, stream}) .getResult(); - auto bufferSizePtr1 = rewriter.create<LLVM::GEPOp>( - loc, llvmPointerType, llvmPointerType, bufferSize, - ValueRange{rewriter.create<LLVM::ConstantOp>( - loc, getIndexType(), rewriter.getIndexAttr(1))}); - auto bufferSizePtr2 = rewriter.create<LLVM::GEPOp>( - loc, llvmPointerType, llvmPointerType, bufferSize, - ValueRange{rewriter.create<LLVM::ConstantOp>( - loc, getIndexType(), rewriter.getIndexAttr(2))}); + auto bufferSizePtr1 = LLVM::GEPOp::create( + rewriter, loc, llvmPointerType, llvmPointerType, bufferSize, + ValueRange{LLVM::ConstantOp::create(rewriter, loc, getIndexType(), + rewriter.getIndexAttr(1))}); + auto bufferSizePtr2 = LLVM::GEPOp::create( + rewriter, loc, llvmPointerType, llvmPointerType, bufferSize, + ValueRange{LLVM::ConstantOp::create(rewriter, loc, getIndexType(), + rewriter.getIndexAttr(2))}); auto bufferSize0 = - rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, bufferSize); + LLVM::LoadOp::create(rewriter, loc, llvmInt64Type, bufferSize); auto bufferSize1 = - rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, bufferSizePtr1); + LLVM::LoadOp::create(rewriter, loc, llvmInt64Type, bufferSizePtr1); auto bufferSize2 = - rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, bufferSizePtr2); + LLVM::LoadOp::create(rewriter, loc, llvmInt64Type, bufferSizePtr2); rewriter.replaceOp(op, {bufferSize0, bufferSize1, bufferSize2, stream}); } else { @@ -1669,28 +1670,28 @@ LogicalResult ConvertSpMatGetSizeOpToGpuRuntimeCallPattern::matchAndRewrite( Location loc = op.getLoc(); auto stream = adaptor.getAsyncDependencies().front(); - auto three = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(), - rewriter.getIndexAttr(3)); - auto buffer = rewriter.create<LLVM::AllocaOp>( - loc, llvmPointerType, llvmInt64Type, three, /*alignment=*/16); - - auto rowsPtr = rewriter.create<LLVM::GEPOp>( - loc, llvmPointerType, llvmPointerType, buffer, - ValueRange{rewriter.create<LLVM::ConstantOp>(loc, getIndexType(), - rewriter.getIndexAttr(0))}); - auto colsPtr = rewriter.create<LLVM::GEPOp>( - loc, llvmPointerType, llvmPointerType, buffer, - ValueRange{rewriter.create<LLVM::ConstantOp>(loc, getIndexType(), - rewriter.getIndexAttr(1))}); - auto nnzsPtr = rewriter.create<LLVM::GEPOp>( - loc, llvmPointerType, llvmPointerType, buffer, - ValueRange{rewriter.create<LLVM::ConstantOp>(loc, getIndexType(), - rewriter.getIndexAttr(2))}); + auto three = LLVM::ConstantOp::create(rewriter, loc, getIndexType(), + rewriter.getIndexAttr(3)); + auto buffer = LLVM::AllocaOp::create(rewriter, loc, llvmPointerType, + llvmInt64Type, three, /*alignment=*/16); + + auto rowsPtr = LLVM::GEPOp::create( + rewriter, loc, llvmPointerType, llvmPointerType, buffer, + ValueRange{LLVM::ConstantOp::create(rewriter, loc, getIndexType(), + rewriter.getIndexAttr(0))}); + auto colsPtr = LLVM::GEPOp::create( + rewriter, loc, llvmPointerType, llvmPointerType, buffer, + ValueRange{LLVM::ConstantOp::create(rewriter, loc, getIndexType(), + rewriter.getIndexAttr(1))}); + auto nnzsPtr = LLVM::GEPOp::create( + rewriter, loc, llvmPointerType, llvmPointerType, buffer, + ValueRange{LLVM::ConstantOp::create(rewriter, loc, getIndexType(), + rewriter.getIndexAttr(2))}); createSpMatGetSizeBuilder.create( loc, rewriter, {adaptor.getSpmat(), rowsPtr, colsPtr, nnzsPtr, stream}); - auto rows = rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, rowsPtr); - auto cols = rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, colsPtr); - auto nnzs = rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, nnzsPtr); + auto rows = LLVM::LoadOp::create(rewriter, loc, llvmInt64Type, rowsPtr); + auto cols = LLVM::LoadOp::create(rewriter, loc, llvmInt64Type, colsPtr); + auto nnzs = LLVM::LoadOp::create(rewriter, loc, llvmInt64Type, nnzsPtr); rewriter.replaceOp(op, {rows, cols, nnzs, stream}); return success(); |