diff options
Diffstat (limited to 'mlir/lib/Conversion/GPUCommon')
-rw-r--r-- | mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp | 133 | ||||
-rw-r--r-- | mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp | 137 | ||||
-rw-r--r-- | mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h | 16 | ||||
-rw-r--r-- | mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h | 24 |
4 files changed, 159 insertions, 151 deletions
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp index 01ca5e9..1037e29 100644 --- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp +++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp @@ -28,7 +28,7 @@ LLVM::LLVMFuncOp mlir::getOrDefineFunction(gpu::GPUModuleOp moduleOp, if (!(ret = moduleOp.template lookupSymbol<LLVM::LLVMFuncOp>(name))) { OpBuilder::InsertionGuard guard(b); b.setInsertionPointToStart(moduleOp.getBody()); - ret = b.create<LLVM::LLVMFuncOp>(loc, name, type, LLVM::Linkage::External); + ret = LLVM::LLVMFuncOp::create(b, loc, name, type, LLVM::Linkage::External); } return ret; } @@ -68,9 +68,9 @@ mlir::getOrCreateStringConstant(OpBuilder &b, Location loc, OpBuilder::InsertionGuard guard(b); b.setInsertionPointToStart(moduleOp.getBody()); SmallString<16> name = getUniqueSymbolName(moduleOp, namePrefix); - return b.create<LLVM::GlobalOp>(loc, globalType, - /*isConstant=*/true, LLVM::Linkage::Internal, - name, attr, alignment, addrSpace); + return LLVM::GlobalOp::create(b, loc, globalType, + /*isConstant=*/true, LLVM::Linkage::Internal, + name, attr, alignment, addrSpace); } LogicalResult @@ -151,8 +151,8 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor, gpuFuncOp.getWorkgroupAttributionAttr( idx, LLVM::LLVMDialect::getAlignAttrName()))) alignment = alignAttr.getInt(); - auto globalOp = rewriter.create<LLVM::GlobalOp>( - gpuFuncOp.getLoc(), arrayType, /*isConstant=*/false, + auto globalOp = LLVM::GlobalOp::create( + rewriter, gpuFuncOp.getLoc(), arrayType, /*isConstant=*/false, LLVM::Linkage::Internal, name, /*value=*/Attribute(), alignment, workgroupAddrSpace); workgroupBuffers.push_back(globalOp); @@ -220,8 +220,8 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor, LLVM::CConv callingConvention = gpuFuncOp.isKernel() ? kernelCallingConvention : nonKernelCallingConvention; - auto llvmFuncOp = rewriter.create<LLVM::LLVMFuncOp>( - gpuFuncOp.getLoc(), gpuFuncOp.getName(), funcType, + auto llvmFuncOp = LLVM::LLVMFuncOp::create( + rewriter, gpuFuncOp.getLoc(), gpuFuncOp.getName(), funcType, LLVM::Linkage::External, /*dsoLocal=*/false, callingConvention, /*comdat=*/nullptr, attributes); @@ -266,11 +266,11 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor, for (const auto [idx, global] : llvm::enumerate(workgroupBuffers)) { auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext(), global.getAddrSpace()); - Value address = rewriter.create<LLVM::AddressOfOp>( - loc, ptrType, global.getSymNameAttr()); + Value address = LLVM::AddressOfOp::create(rewriter, loc, ptrType, + global.getSymNameAttr()); Value memory = - rewriter.create<LLVM::GEPOp>(loc, ptrType, global.getType(), - address, ArrayRef<LLVM::GEPArg>{0, 0}); + LLVM::GEPOp::create(rewriter, loc, ptrType, global.getType(), + address, ArrayRef<LLVM::GEPArg>{0, 0}); // Build a memref descriptor pointing to the buffer to plug with the // existing memref infrastructure. This may use more registers than @@ -298,15 +298,16 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor, Type elementType = typeConverter->convertType(type.getElementType()); auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext(), allocaAddrSpace); - Value numElements = rewriter.create<LLVM::ConstantOp>( - gpuFuncOp.getLoc(), int64Ty, type.getNumElements()); + Value numElements = LLVM::ConstantOp::create( + rewriter, gpuFuncOp.getLoc(), int64Ty, type.getNumElements()); uint64_t alignment = 0; if (auto alignAttr = dyn_cast_or_null<IntegerAttr>(gpuFuncOp.getPrivateAttributionAttr( idx, LLVM::LLVMDialect::getAlignAttrName()))) alignment = alignAttr.getInt(); - Value allocated = rewriter.create<LLVM::AllocaOp>( - gpuFuncOp.getLoc(), ptrType, elementType, numElements, alignment); + Value allocated = + LLVM::AllocaOp::create(rewriter, gpuFuncOp.getLoc(), ptrType, + elementType, numElements, alignment); Value descr = MemRefDescriptor::fromStaticShape( rewriter, loc, *getTypeConverter(), type, allocated); signatureConversion.remapInput( @@ -418,8 +419,9 @@ LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite( {llvmI64, ptrType, /*length (bytes)*/ llvmI64, /*isLast*/ llvmI32})); /// Start the printf hostcall - Value zeroI64 = rewriter.create<LLVM::ConstantOp>(loc, llvmI64, 0); - auto printfBeginCall = rewriter.create<LLVM::CallOp>(loc, ocklBegin, zeroI64); + Value zeroI64 = LLVM::ConstantOp::create(rewriter, loc, llvmI64, 0); + auto printfBeginCall = + LLVM::CallOp::create(rewriter, loc, ocklBegin, zeroI64); Value printfDesc = printfBeginCall.getResult(); // Create the global op or find an existing one. @@ -427,21 +429,21 @@ LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite( rewriter, loc, moduleOp, llvmI8, "printfFormat_", adaptor.getFormat()); // Get a pointer to the format string's first element and pass it to printf() - Value globalPtr = rewriter.create<LLVM::AddressOfOp>( - loc, + Value globalPtr = LLVM::AddressOfOp::create( + rewriter, loc, LLVM::LLVMPointerType::get(rewriter.getContext(), global.getAddrSpace()), global.getSymNameAttr()); Value stringStart = - rewriter.create<LLVM::GEPOp>(loc, ptrType, global.getGlobalType(), - globalPtr, ArrayRef<LLVM::GEPArg>{0, 0}); - Value stringLen = rewriter.create<LLVM::ConstantOp>( - loc, llvmI64, cast<StringAttr>(global.getValueAttr()).size()); + LLVM::GEPOp::create(rewriter, loc, ptrType, global.getGlobalType(), + globalPtr, ArrayRef<LLVM::GEPArg>{0, 0}); + Value stringLen = LLVM::ConstantOp::create( + rewriter, loc, llvmI64, cast<StringAttr>(global.getValueAttr()).size()); - Value oneI32 = rewriter.create<LLVM::ConstantOp>(loc, llvmI32, 1); - Value zeroI32 = rewriter.create<LLVM::ConstantOp>(loc, llvmI32, 0); + Value oneI32 = LLVM::ConstantOp::create(rewriter, loc, llvmI32, 1); + Value zeroI32 = LLVM::ConstantOp::create(rewriter, loc, llvmI32, 0); - auto appendFormatCall = rewriter.create<LLVM::CallOp>( - loc, ocklAppendStringN, + auto appendFormatCall = LLVM::CallOp::create( + rewriter, loc, ocklAppendStringN, ValueRange{printfDesc, stringStart, stringLen, adaptor.getArgs().empty() ? oneI32 : zeroI32}); printfDesc = appendFormatCall.getResult(); @@ -456,17 +458,18 @@ LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite( SmallVector<mlir::Value, 2 + argsPerAppend + 1> arguments; arguments.push_back(printfDesc); arguments.push_back( - rewriter.create<LLVM::ConstantOp>(loc, llvmI32, numArgsThisCall)); + LLVM::ConstantOp::create(rewriter, loc, llvmI32, numArgsThisCall)); for (size_t i = group; i < bound; ++i) { Value arg = adaptor.getArgs()[i]; if (auto floatType = dyn_cast<FloatType>(arg.getType())) { if (!floatType.isF64()) - arg = rewriter.create<LLVM::FPExtOp>( - loc, typeConverter->convertType(rewriter.getF64Type()), arg); - arg = rewriter.create<LLVM::BitcastOp>(loc, llvmI64, arg); + arg = LLVM::FPExtOp::create( + rewriter, loc, typeConverter->convertType(rewriter.getF64Type()), + arg); + arg = LLVM::BitcastOp::create(rewriter, loc, llvmI64, arg); } if (arg.getType().getIntOrFloatBitWidth() != 64) - arg = rewriter.create<LLVM::ZExtOp>(loc, llvmI64, arg); + arg = LLVM::ZExtOp::create(rewriter, loc, llvmI64, arg); arguments.push_back(arg); } @@ -477,7 +480,7 @@ LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite( auto isLast = (bound == nArgs) ? oneI32 : zeroI32; arguments.push_back(isLast); - auto call = rewriter.create<LLVM::CallOp>(loc, ocklAppendArgs, arguments); + auto call = LLVM::CallOp::create(rewriter, loc, ocklAppendArgs, arguments); printfDesc = call.getResult(); } rewriter.eraseOp(gpuPrintfOp); @@ -510,13 +513,13 @@ LogicalResult GPUPrintfOpToLLVMCallLowering::matchAndRewrite( /*alignment=*/0, addressSpace); // Get a pointer to the format string's first element - Value globalPtr = rewriter.create<LLVM::AddressOfOp>( - loc, + Value globalPtr = LLVM::AddressOfOp::create( + rewriter, loc, LLVM::LLVMPointerType::get(rewriter.getContext(), global.getAddrSpace()), global.getSymNameAttr()); Value stringStart = - rewriter.create<LLVM::GEPOp>(loc, ptrType, global.getGlobalType(), - globalPtr, ArrayRef<LLVM::GEPArg>{0, 0}); + LLVM::GEPOp::create(rewriter, loc, ptrType, global.getGlobalType(), + globalPtr, ArrayRef<LLVM::GEPArg>{0, 0}); // Construct arguments and function call auto argsRange = adaptor.getArgs(); @@ -525,7 +528,7 @@ LogicalResult GPUPrintfOpToLLVMCallLowering::matchAndRewrite( printfArgs.push_back(stringStart); printfArgs.append(argsRange.begin(), argsRange.end()); - rewriter.create<LLVM::CallOp>(loc, printfDecl, printfArgs); + LLVM::CallOp::create(rewriter, loc, printfDecl, printfArgs); rewriter.eraseOp(gpuPrintfOp); return success(); } @@ -559,10 +562,10 @@ LogicalResult GPUPrintfOpToVPrintfLowering::matchAndRewrite( "printfFormat_", adaptor.getFormat()); // Get a pointer to the format string's first element - Value globalPtr = rewriter.create<LLVM::AddressOfOp>(loc, global); + Value globalPtr = LLVM::AddressOfOp::create(rewriter, loc, global); Value stringStart = - rewriter.create<LLVM::GEPOp>(loc, ptrType, global.getGlobalType(), - globalPtr, ArrayRef<LLVM::GEPArg>{0, 0}); + LLVM::GEPOp::create(rewriter, loc, ptrType, global.getGlobalType(), + globalPtr, ArrayRef<LLVM::GEPArg>{0, 0}); SmallVector<Type> types; SmallVector<Value> args; // Promote and pack the arguments into a stack allocation. @@ -572,27 +575,27 @@ LogicalResult GPUPrintfOpToVPrintfLowering::matchAndRewrite( assert(type.isIntOrFloat()); if (isa<FloatType>(type)) { type = rewriter.getF64Type(); - promotedArg = rewriter.create<LLVM::FPExtOp>(loc, type, arg); + promotedArg = LLVM::FPExtOp::create(rewriter, loc, type, arg); } types.push_back(type); args.push_back(promotedArg); } Type structType = LLVM::LLVMStructType::getLiteral(gpuPrintfOp.getContext(), types); - Value one = rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI64Type(), - rewriter.getIndexAttr(1)); + Value one = LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(), + rewriter.getIndexAttr(1)); Value tempAlloc = - rewriter.create<LLVM::AllocaOp>(loc, ptrType, structType, one, - /*alignment=*/0); + LLVM::AllocaOp::create(rewriter, loc, ptrType, structType, one, + /*alignment=*/0); for (auto [index, arg] : llvm::enumerate(args)) { - Value ptr = rewriter.create<LLVM::GEPOp>( - loc, ptrType, structType, tempAlloc, + Value ptr = LLVM::GEPOp::create( + rewriter, loc, ptrType, structType, tempAlloc, ArrayRef<LLVM::GEPArg>{0, static_cast<int32_t>(index)}); - rewriter.create<LLVM::StoreOp>(loc, arg, ptr); + LLVM::StoreOp::create(rewriter, loc, arg, ptr); } std::array<Value, 2> printfArgs = {stringStart, tempAlloc}; - rewriter.create<LLVM::CallOp>(loc, vprintfDecl, printfArgs); + LLVM::CallOp::create(rewriter, loc, vprintfDecl, printfArgs); rewriter.eraseOp(gpuPrintfOp); return success(); } @@ -607,23 +610,23 @@ static Value scalarizeVectorOpHelper(Operation *op, ValueRange operands, TypeRange operandTypes(operands); VectorType vectorType = cast<VectorType>(llvm1DVectorTy); Location loc = op->getLoc(); - Value result = rewriter.create<LLVM::PoisonOp>(loc, vectorType); + Value result = LLVM::PoisonOp::create(rewriter, loc, vectorType); Type indexType = converter.convertType(rewriter.getIndexType()); StringAttr name = op->getName().getIdentifier(); Type elementType = vectorType.getElementType(); for (int64_t i = 0; i < vectorType.getNumElements(); ++i) { - Value index = rewriter.create<LLVM::ConstantOp>(loc, indexType, i); + Value index = LLVM::ConstantOp::create(rewriter, loc, indexType, i); auto extractElement = [&](Value operand) -> Value { if (!isa<VectorType>(operand.getType())) return operand; - return rewriter.create<LLVM::ExtractElementOp>(loc, operand, index); + return LLVM::ExtractElementOp::create(rewriter, loc, operand, index); }; auto scalarOperands = llvm::map_to_vector(operands, extractElement); Operation *scalarOp = rewriter.create(loc, name, scalarOperands, elementType, op->getAttrs()); - result = rewriter.create<LLVM::InsertElementOp>( - loc, result, scalarOp->getResult(0), index); + result = LLVM::InsertElementOp::create(rewriter, loc, result, + scalarOp->getResult(0), index); } return result; } @@ -705,10 +708,10 @@ LLVM::GlobalOp getDynamicSharedMemorySymbol( auto zeroSizedArrayType = LLVM::LLVMArrayType::get( typeConverter->convertType(memrefType.getElementType()), 0); - return rewriter.create<LLVM::GlobalOp>( - op->getLoc(), zeroSizedArrayType, /*isConstant=*/false, - LLVM::Linkage::Internal, symName, /*value=*/Attribute(), alignmentByte, - addressSpace.value()); + return LLVM::GlobalOp::create(rewriter, op->getLoc(), zeroSizedArrayType, + /*isConstant=*/false, LLVM::Linkage::Internal, + symName, /*value=*/Attribute(), alignmentByte, + addressSpace.value()); } LogicalResult GPUDynamicSharedMemoryOpLowering::matchAndRewrite( @@ -732,13 +735,13 @@ LogicalResult GPUDynamicSharedMemoryOpLowering::matchAndRewrite( // Step 3. Get address of the global symbol OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPoint(op); - auto basePtr = rewriter.create<LLVM::AddressOfOp>(loc, shmemOp); + auto basePtr = LLVM::AddressOfOp::create(rewriter, loc, shmemOp); Type baseType = basePtr->getResultTypes().front(); // Step 4. Generate GEP using offsets SmallVector<LLVM::GEPArg> gepArgs = {0}; - Value shmemPtr = rewriter.create<LLVM::GEPOp>(loc, baseType, elementType, - basePtr, gepArgs); + Value shmemPtr = LLVM::GEPOp::create(rewriter, loc, baseType, elementType, + basePtr, gepArgs); // Step 5. Create a memref descriptor SmallVector<Value> shape, strides; Value sizeBytes; @@ -799,9 +802,9 @@ LogicalResult GPUReturnOpLowering::matchAndRewrite( return rewriter.notifyMatchFailure(op, "could not convert result types"); } - Value packed = rewriter.create<LLVM::PoisonOp>(loc, packedType); + Value packed = LLVM::PoisonOp::create(rewriter, loc, packedType); for (auto [idx, operand] : llvm::enumerate(updatedOperands)) { - packed = rewriter.create<LLVM::InsertValueOp>(loc, packed, operand, idx); + packed = LLVM::InsertValueOp::create(rewriter, loc, packed, operand, idx); } rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, TypeRange(), packed, op->getAttrs()); diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp index 167cabb..3cfbd89 100644 --- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp +++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp @@ -79,8 +79,8 @@ protected: uint64_t rank = type.getRank(); Value numElements = desc.size(rewriter, loc, /*pos=*/0); for (unsigned i = 1; i < rank; i++) - numElements = rewriter.create<LLVM::MulOp>( - loc, numElements, desc.size(rewriter, loc, /*pos=*/i)); + numElements = LLVM::MulOp::create(rewriter, loc, numElements, + desc.size(rewriter, loc, /*pos=*/i)); return numElements; } @@ -579,10 +579,10 @@ LLVM::CallOp FunctionCallBuilder::create(Location loc, OpBuilder &builder, auto function = [&] { if (auto function = module.lookupSymbol<LLVM::LLVMFuncOp>(functionName)) return function; - return OpBuilder::atBlockEnd(module.getBody()) - .create<LLVM::LLVMFuncOp>(loc, functionName, functionType); + auto builder = OpBuilder::atBlockEnd(module.getBody()); + return LLVM::LLVMFuncOp::create(builder, loc, functionName, functionType); }(); - return builder.create<LLVM::CallOp>(loc, function, arguments); + return LLVM::CallOp::create(builder, loc, function, arguments); } // Corresponding to cusparseIndexType_t defined in cusparse.h. @@ -780,13 +780,13 @@ LogicalResult ConvertAllocOpToGpuRuntimeCallPattern::matchAndRewrite( // Allocate the underlying buffer and store a pointer to it in the MemRef // descriptor. - auto nullPtr = rewriter.create<mlir::LLVM::ZeroOp>(loc, llvmPointerType); + auto nullPtr = mlir::LLVM::ZeroOp::create(rewriter, loc, llvmPointerType); Value stream = adaptor.getAsyncDependencies().empty() ? nullPtr : adaptor.getAsyncDependencies().front(); - auto isHostShared = rewriter.create<mlir::LLVM::ConstantOp>( - loc, llvmInt8Type, rewriter.getI8IntegerAttr(isShared)); + auto isHostShared = mlir::LLVM::ConstantOp::create( + rewriter, loc, llvmInt8Type, rewriter.getI8IntegerAttr(isShared)); Value allocatedPtr = allocCallBuilder.create(loc, rewriter, {sizeBytes, stream, isHostShared}) @@ -1012,8 +1012,8 @@ LogicalResult LegalizeLaunchFuncOpPattern::matchAndRewrite( uint64_t staticSize = static_cast<uint64_t>(bitwidth / 8) * static_cast<uint64_t>(memrefTy.getNumElements()); - Value sizeArg = rewriter.create<LLVM::ConstantOp>( - loc, getIndexType(), rewriter.getIndexAttr(staticSize)); + Value sizeArg = LLVM::ConstantOp::create( + rewriter, loc, getIndexType(), rewriter.getIndexAttr(staticSize)); llvmArgumentsWithSizes.push_back(llvmArg); // Presumably a bare pointer. llvmArgumentsWithSizes.push_back(sizeArg); } @@ -1025,8 +1025,8 @@ LogicalResult LegalizeLaunchFuncOpPattern::matchAndRewrite( gpu::KernelDim3{adaptor.getClusterSizeX(), adaptor.getClusterSizeY(), adaptor.getClusterSizeZ()}; } - rewriter.create<gpu::LaunchFuncOp>( - launchOp.getLoc(), launchOp.getKernelAttr(), + gpu::LaunchFuncOp::create( + rewriter, launchOp.getLoc(), launchOp.getKernelAttr(), gpu::KernelDim3{adaptor.getGridSizeX(), adaptor.getGridSizeY(), adaptor.getGridSizeZ()}, gpu::KernelDim3{adaptor.getBlockSizeX(), adaptor.getBlockSizeY(), @@ -1048,8 +1048,8 @@ static Value bitAndAddrspaceCast(Location loc, const LLVMTypeConverter &typeConverter) { auto sourceTy = cast<LLVM::LLVMPointerType>(sourcePtr.getType()); if (destinationType.getAddressSpace() != sourceTy.getAddressSpace()) - sourcePtr = rewriter.create<LLVM::AddrSpaceCastOp>( - loc, + sourcePtr = LLVM::AddrSpaceCastOp::create( + rewriter, loc, LLVM::LLVMPointerType::get(rewriter.getContext(), destinationType.getAddressSpace()), sourcePtr); @@ -1072,13 +1072,13 @@ LogicalResult ConvertMemcpyOpToGpuRuntimeCallPattern::matchAndRewrite( Value numElements = getNumElements(rewriter, loc, memRefType, srcDesc); Type elementPtrType = getElementPtrType(memRefType); - Value nullPtr = rewriter.create<LLVM::ZeroOp>(loc, elementPtrType); - Value gepPtr = rewriter.create<LLVM::GEPOp>( - loc, elementPtrType, + Value nullPtr = LLVM::ZeroOp::create(rewriter, loc, elementPtrType); + Value gepPtr = LLVM::GEPOp::create( + rewriter, loc, elementPtrType, typeConverter->convertType(memRefType.getElementType()), nullPtr, numElements); auto sizeBytes = - rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gepPtr); + LLVM::PtrToIntOp::create(rewriter, loc, getIndexType(), gepPtr); auto src = bitAndAddrspaceCast(loc, rewriter, llvmPointerType, srcDesc.alignedPtr(rewriter, loc), @@ -1123,7 +1123,7 @@ LogicalResult ConvertMemsetOpToGpuRuntimeCallPattern::matchAndRewrite( Value numElements = getNumElements(rewriter, loc, memRefType, dstDesc); auto value = - rewriter.create<LLVM::BitcastOp>(loc, bitCastType, adaptor.getValue()); + LLVM::BitcastOp::create(rewriter, loc, bitCastType, adaptor.getValue()); auto dst = bitAndAddrspaceCast(loc, rewriter, llvmPointerType, dstDesc.alignedPtr(rewriter, loc), *getTypeConverter()); @@ -1150,15 +1150,15 @@ LogicalResult ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern::matchAndRewrite( template <typename T> static Value genConstInt32From(OpBuilder &builder, Location loc, T tValue) { Type llvmInt32Type = builder.getIntegerType(32); - return builder.create<LLVM::ConstantOp>(loc, llvmInt32Type, - static_cast<int32_t>(tValue)); + return LLVM::ConstantOp::create(builder, loc, llvmInt32Type, + static_cast<int32_t>(tValue)); } template <typename T> static Value genConstFloat32From(OpBuilder &builder, Location loc, T tValue) { Type llvmFloat32Type = builder.getF32Type(); - return builder.create<LLVM::ConstantOp>( - loc, llvmFloat32Type, + return LLVM::ConstantOp::create( + builder, loc, llvmFloat32Type, builder.getF32FloatAttr(static_cast<float>(tValue))); } @@ -1189,11 +1189,11 @@ LogicalResult ConvertCreateDnTensorOpToGpuRuntimeCallPattern::matchAndRewrite( // the dnmat is used with spmat with 2:4 sparsity if (dims.size() == 2) { if (isSpMMCusparseLtOp(op.getDnTensor())) { - auto handleSz = rewriter.create<LLVM::ConstantOp>( - loc, getIndexType(), rewriter.getIndexAttr(11032)); - handle = rewriter.create<LLVM::AllocaOp>( - loc, llvmPointerType, llvmInt8Type, handleSz, /*alignment=*/16); - handle = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, handle); + auto handleSz = LLVM::ConstantOp::create(rewriter, loc, getIndexType(), + rewriter.getIndexAttr(11032)); + handle = LLVM::AllocaOp::create(rewriter, loc, llvmPointerType, + llvmInt8Type, handleSz, /*alignment=*/16); + handle = LLVM::BitcastOp::create(rewriter, loc, llvmPointerType, handle); createLtDnMatCallBuilder .create(loc, rewriter, @@ -1351,11 +1351,11 @@ LogicalResult ConvertCreate2To4SpMatOpToGpuRuntimeCallPattern::matchAndRewrite( auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType)); // CUDA runner asserts the size is 44104 bytes. - auto handleSz = rewriter.create<LLVM::ConstantOp>( - loc, getIndexType(), rewriter.getIndexAttr(44104)); - Value handle = rewriter.create<LLVM::AllocaOp>( - loc, llvmPointerType, llvmInt8Type, handleSz, /*alignment=*/16); - handle = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, handle); + auto handleSz = LLVM::ConstantOp::create(rewriter, loc, getIndexType(), + rewriter.getIndexAttr(44104)); + Value handle = LLVM::AllocaOp::create( + rewriter, loc, llvmPointerType, llvmInt8Type, handleSz, /*alignment=*/16); + handle = LLVM::BitcastOp::create(rewriter, loc, llvmPointerType, handle); create2To4SpMatCallBuilder .create(loc, rewriter, @@ -1441,10 +1441,11 @@ LogicalResult ConvertSpMMBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite( genConstInt32From(rewriter, loc, get2To4PruneFlag(op.getSpmatA())); auto computeType = genConstInt32From( rewriter, loc, getCuSparseLtDataTypeFrom(adaptor.getComputeType())); - auto three = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(), - rewriter.getIndexAttr(3)); - auto bufferSize = rewriter.create<LLVM::AllocaOp>( - loc, llvmPointerType, llvmPointerType, three, /*alignment=*/16); + auto three = LLVM::ConstantOp::create(rewriter, loc, getIndexType(), + rewriter.getIndexAttr(3)); + auto bufferSize = + LLVM::AllocaOp::create(rewriter, loc, llvmPointerType, llvmPointerType, + three, /*alignment=*/16); createCuSparseLtSpMMBufferSizeBuilder .create(loc, rewriter, {bufferSize, modeA, modeB, adaptor.getSpmatA(), @@ -1452,20 +1453,20 @@ LogicalResult ConvertSpMMBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite( pruneFlag, stream}) .getResult(); - auto bufferSizePtr1 = rewriter.create<LLVM::GEPOp>( - loc, llvmPointerType, llvmPointerType, bufferSize, - ValueRange{rewriter.create<LLVM::ConstantOp>( - loc, getIndexType(), rewriter.getIndexAttr(1))}); - auto bufferSizePtr2 = rewriter.create<LLVM::GEPOp>( - loc, llvmPointerType, llvmPointerType, bufferSize, - ValueRange{rewriter.create<LLVM::ConstantOp>( - loc, getIndexType(), rewriter.getIndexAttr(2))}); + auto bufferSizePtr1 = LLVM::GEPOp::create( + rewriter, loc, llvmPointerType, llvmPointerType, bufferSize, + ValueRange{LLVM::ConstantOp::create(rewriter, loc, getIndexType(), + rewriter.getIndexAttr(1))}); + auto bufferSizePtr2 = LLVM::GEPOp::create( + rewriter, loc, llvmPointerType, llvmPointerType, bufferSize, + ValueRange{LLVM::ConstantOp::create(rewriter, loc, getIndexType(), + rewriter.getIndexAttr(2))}); auto bufferSize0 = - rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, bufferSize); + LLVM::LoadOp::create(rewriter, loc, llvmInt64Type, bufferSize); auto bufferSize1 = - rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, bufferSizePtr1); + LLVM::LoadOp::create(rewriter, loc, llvmInt64Type, bufferSizePtr1); auto bufferSize2 = - rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, bufferSizePtr2); + LLVM::LoadOp::create(rewriter, loc, llvmInt64Type, bufferSizePtr2); rewriter.replaceOp(op, {bufferSize0, bufferSize1, bufferSize2, stream}); } else { @@ -1669,28 +1670,28 @@ LogicalResult ConvertSpMatGetSizeOpToGpuRuntimeCallPattern::matchAndRewrite( Location loc = op.getLoc(); auto stream = adaptor.getAsyncDependencies().front(); - auto three = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(), - rewriter.getIndexAttr(3)); - auto buffer = rewriter.create<LLVM::AllocaOp>( - loc, llvmPointerType, llvmInt64Type, three, /*alignment=*/16); - - auto rowsPtr = rewriter.create<LLVM::GEPOp>( - loc, llvmPointerType, llvmPointerType, buffer, - ValueRange{rewriter.create<LLVM::ConstantOp>(loc, getIndexType(), - rewriter.getIndexAttr(0))}); - auto colsPtr = rewriter.create<LLVM::GEPOp>( - loc, llvmPointerType, llvmPointerType, buffer, - ValueRange{rewriter.create<LLVM::ConstantOp>(loc, getIndexType(), - rewriter.getIndexAttr(1))}); - auto nnzsPtr = rewriter.create<LLVM::GEPOp>( - loc, llvmPointerType, llvmPointerType, buffer, - ValueRange{rewriter.create<LLVM::ConstantOp>(loc, getIndexType(), - rewriter.getIndexAttr(2))}); + auto three = LLVM::ConstantOp::create(rewriter, loc, getIndexType(), + rewriter.getIndexAttr(3)); + auto buffer = LLVM::AllocaOp::create(rewriter, loc, llvmPointerType, + llvmInt64Type, three, /*alignment=*/16); + + auto rowsPtr = LLVM::GEPOp::create( + rewriter, loc, llvmPointerType, llvmPointerType, buffer, + ValueRange{LLVM::ConstantOp::create(rewriter, loc, getIndexType(), + rewriter.getIndexAttr(0))}); + auto colsPtr = LLVM::GEPOp::create( + rewriter, loc, llvmPointerType, llvmPointerType, buffer, + ValueRange{LLVM::ConstantOp::create(rewriter, loc, getIndexType(), + rewriter.getIndexAttr(1))}); + auto nnzsPtr = LLVM::GEPOp::create( + rewriter, loc, llvmPointerType, llvmPointerType, buffer, + ValueRange{LLVM::ConstantOp::create(rewriter, loc, getIndexType(), + rewriter.getIndexAttr(2))}); createSpMatGetSizeBuilder.create( loc, rewriter, {adaptor.getSpmat(), rowsPtr, colsPtr, nnzsPtr, stream}); - auto rows = rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, rowsPtr); - auto cols = rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, colsPtr); - auto nnzs = rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, nnzsPtr); + auto rows = LLVM::LoadOp::create(rewriter, loc, llvmInt64Type, rowsPtr); + auto cols = LLVM::LoadOp::create(rewriter, loc, llvmInt64Type, colsPtr); + auto nnzs = LLVM::LoadOp::create(rewriter, loc, llvmInt64Type, nnzsPtr); rewriter.replaceOp(op, {rows, cols, nnzs, stream}); return success(); diff --git a/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h b/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h index aab2409..91c43e8 100644 --- a/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h +++ b/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h @@ -59,13 +59,13 @@ public: Operation *newOp; switch (op.getDimension()) { case gpu::Dimension::x: - newOp = rewriter.create<XOp>(loc, IntegerType::get(context, 32)); + newOp = XOp::create(rewriter, loc, IntegerType::get(context, 32)); break; case gpu::Dimension::y: - newOp = rewriter.create<YOp>(loc, IntegerType::get(context, 32)); + newOp = YOp::create(rewriter, loc, IntegerType::get(context, 32)); break; case gpu::Dimension::z: - newOp = rewriter.create<ZOp>(loc, IntegerType::get(context, 32)); + newOp = ZOp::create(rewriter, loc, IntegerType::get(context, 32)); break; } @@ -124,11 +124,13 @@ public: rewriter.getContext(), 32, min, max)); } if (indexBitwidth > 32) { - newOp = rewriter.create<LLVM::SExtOp>( - loc, IntegerType::get(context, indexBitwidth), newOp->getResult(0)); + newOp = LLVM::SExtOp::create(rewriter, loc, + IntegerType::get(context, indexBitwidth), + newOp->getResult(0)); } else if (indexBitwidth < 32) { - newOp = rewriter.create<LLVM::TruncOp>( - loc, IntegerType::get(context, indexBitwidth), newOp->getResult(0)); + newOp = LLVM::TruncOp::create(rewriter, loc, + IntegerType::get(context, indexBitwidth), + newOp->getResult(0)); } rewriter.replaceOp(op, newOp->getResults()); diff --git a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h index 64cf09e..9f36e5c 100644 --- a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h +++ b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h @@ -103,7 +103,7 @@ public: LLVMFuncOp funcOp = appendOrGetFuncOp(funcName, funcType, op); auto callOp = - rewriter.create<LLVM::CallOp>(op->getLoc(), funcOp, castedOperands); + LLVM::CallOp::create(rewriter, op->getLoc(), funcOp, castedOperands); if (resultType == adaptor.getOperands().front().getType()) { rewriter.replaceOp(op, {callOp.getResult()}); @@ -115,19 +115,20 @@ public: // there is no guarantee of a specific value being used to indicate true, // compare for inequality with zero (rather than truncate or shift). if (isResultBool) { - Value zero = rewriter.create<LLVM::ConstantOp>( - op->getLoc(), rewriter.getIntegerType(32), - rewriter.getI32IntegerAttr(0)); - Value truncated = rewriter.create<LLVM::ICmpOp>( - op->getLoc(), LLVM::ICmpPredicate::ne, callOp.getResult(), zero); + Value zero = LLVM::ConstantOp::create(rewriter, op->getLoc(), + rewriter.getIntegerType(32), + rewriter.getI32IntegerAttr(0)); + Value truncated = + LLVM::ICmpOp::create(rewriter, op->getLoc(), LLVM::ICmpPredicate::ne, + callOp.getResult(), zero); rewriter.replaceOp(op, {truncated}); return success(); } assert(callOp.getResult().getType().isF32() && "only f32 types are supposed to be truncated back"); - Value truncated = rewriter.create<LLVM::FPTruncOp>( - op->getLoc(), adaptor.getOperands().front().getType(), + Value truncated = LLVM::FPTruncOp::create( + rewriter, op->getLoc(), adaptor.getOperands().front().getType(), callOp.getResult()); rewriter.replaceOp(op, {truncated}); return success(); @@ -142,8 +143,9 @@ public: if (!f16Func.empty() && isa<Float16Type>(type)) return operand; - return rewriter.create<LLVM::FPExtOp>( - operand.getLoc(), Float32Type::get(rewriter.getContext()), operand); + return LLVM::FPExtOp::create(rewriter, operand.getLoc(), + Float32Type::get(rewriter.getContext()), + operand); } Type getFunctionType(Type resultType, ValueRange operands) const { @@ -169,7 +171,7 @@ public: // location as debug info metadata inside of a function cannot be used // outside of that function. auto globalloc = op->getLoc()->findInstanceOfOrUnknown<FileLineColLoc>(); - return b.create<LLVMFuncOp>(globalloc, funcName, funcType); + return LLVMFuncOp::create(b, globalloc, funcName, funcType); } StringRef getFunctionName(Type type, SourceOp op) const { |