diff options
Diffstat (limited to 'mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp')
-rw-r--r-- | mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp | 133 |
1 files changed, 68 insertions, 65 deletions
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp index 01ca5e9..1037e29 100644 --- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp +++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp @@ -28,7 +28,7 @@ LLVM::LLVMFuncOp mlir::getOrDefineFunction(gpu::GPUModuleOp moduleOp, if (!(ret = moduleOp.template lookupSymbol<LLVM::LLVMFuncOp>(name))) { OpBuilder::InsertionGuard guard(b); b.setInsertionPointToStart(moduleOp.getBody()); - ret = b.create<LLVM::LLVMFuncOp>(loc, name, type, LLVM::Linkage::External); + ret = LLVM::LLVMFuncOp::create(b, loc, name, type, LLVM::Linkage::External); } return ret; } @@ -68,9 +68,9 @@ mlir::getOrCreateStringConstant(OpBuilder &b, Location loc, OpBuilder::InsertionGuard guard(b); b.setInsertionPointToStart(moduleOp.getBody()); SmallString<16> name = getUniqueSymbolName(moduleOp, namePrefix); - return b.create<LLVM::GlobalOp>(loc, globalType, - /*isConstant=*/true, LLVM::Linkage::Internal, - name, attr, alignment, addrSpace); + return LLVM::GlobalOp::create(b, loc, globalType, + /*isConstant=*/true, LLVM::Linkage::Internal, + name, attr, alignment, addrSpace); } LogicalResult @@ -151,8 +151,8 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor, gpuFuncOp.getWorkgroupAttributionAttr( idx, LLVM::LLVMDialect::getAlignAttrName()))) alignment = alignAttr.getInt(); - auto globalOp = rewriter.create<LLVM::GlobalOp>( - gpuFuncOp.getLoc(), arrayType, /*isConstant=*/false, + auto globalOp = LLVM::GlobalOp::create( + rewriter, gpuFuncOp.getLoc(), arrayType, /*isConstant=*/false, LLVM::Linkage::Internal, name, /*value=*/Attribute(), alignment, workgroupAddrSpace); workgroupBuffers.push_back(globalOp); @@ -220,8 +220,8 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor, LLVM::CConv callingConvention = gpuFuncOp.isKernel() ? kernelCallingConvention : nonKernelCallingConvention; - auto llvmFuncOp = rewriter.create<LLVM::LLVMFuncOp>( - gpuFuncOp.getLoc(), gpuFuncOp.getName(), funcType, + auto llvmFuncOp = LLVM::LLVMFuncOp::create( + rewriter, gpuFuncOp.getLoc(), gpuFuncOp.getName(), funcType, LLVM::Linkage::External, /*dsoLocal=*/false, callingConvention, /*comdat=*/nullptr, attributes); @@ -266,11 +266,11 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor, for (const auto [idx, global] : llvm::enumerate(workgroupBuffers)) { auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext(), global.getAddrSpace()); - Value address = rewriter.create<LLVM::AddressOfOp>( - loc, ptrType, global.getSymNameAttr()); + Value address = LLVM::AddressOfOp::create(rewriter, loc, ptrType, + global.getSymNameAttr()); Value memory = - rewriter.create<LLVM::GEPOp>(loc, ptrType, global.getType(), - address, ArrayRef<LLVM::GEPArg>{0, 0}); + LLVM::GEPOp::create(rewriter, loc, ptrType, global.getType(), + address, ArrayRef<LLVM::GEPArg>{0, 0}); // Build a memref descriptor pointing to the buffer to plug with the // existing memref infrastructure. This may use more registers than @@ -298,15 +298,16 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor, Type elementType = typeConverter->convertType(type.getElementType()); auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext(), allocaAddrSpace); - Value numElements = rewriter.create<LLVM::ConstantOp>( - gpuFuncOp.getLoc(), int64Ty, type.getNumElements()); + Value numElements = LLVM::ConstantOp::create( + rewriter, gpuFuncOp.getLoc(), int64Ty, type.getNumElements()); uint64_t alignment = 0; if (auto alignAttr = dyn_cast_or_null<IntegerAttr>(gpuFuncOp.getPrivateAttributionAttr( idx, LLVM::LLVMDialect::getAlignAttrName()))) alignment = alignAttr.getInt(); - Value allocated = rewriter.create<LLVM::AllocaOp>( - gpuFuncOp.getLoc(), ptrType, elementType, numElements, alignment); + Value allocated = + LLVM::AllocaOp::create(rewriter, gpuFuncOp.getLoc(), ptrType, + elementType, numElements, alignment); Value descr = MemRefDescriptor::fromStaticShape( rewriter, loc, *getTypeConverter(), type, allocated); signatureConversion.remapInput( @@ -418,8 +419,9 @@ LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite( {llvmI64, ptrType, /*length (bytes)*/ llvmI64, /*isLast*/ llvmI32})); /// Start the printf hostcall - Value zeroI64 = rewriter.create<LLVM::ConstantOp>(loc, llvmI64, 0); - auto printfBeginCall = rewriter.create<LLVM::CallOp>(loc, ocklBegin, zeroI64); + Value zeroI64 = LLVM::ConstantOp::create(rewriter, loc, llvmI64, 0); + auto printfBeginCall = + LLVM::CallOp::create(rewriter, loc, ocklBegin, zeroI64); Value printfDesc = printfBeginCall.getResult(); // Create the global op or find an existing one. @@ -427,21 +429,21 @@ LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite( rewriter, loc, moduleOp, llvmI8, "printfFormat_", adaptor.getFormat()); // Get a pointer to the format string's first element and pass it to printf() - Value globalPtr = rewriter.create<LLVM::AddressOfOp>( - loc, + Value globalPtr = LLVM::AddressOfOp::create( + rewriter, loc, LLVM::LLVMPointerType::get(rewriter.getContext(), global.getAddrSpace()), global.getSymNameAttr()); Value stringStart = - rewriter.create<LLVM::GEPOp>(loc, ptrType, global.getGlobalType(), - globalPtr, ArrayRef<LLVM::GEPArg>{0, 0}); - Value stringLen = rewriter.create<LLVM::ConstantOp>( - loc, llvmI64, cast<StringAttr>(global.getValueAttr()).size()); + LLVM::GEPOp::create(rewriter, loc, ptrType, global.getGlobalType(), + globalPtr, ArrayRef<LLVM::GEPArg>{0, 0}); + Value stringLen = LLVM::ConstantOp::create( + rewriter, loc, llvmI64, cast<StringAttr>(global.getValueAttr()).size()); - Value oneI32 = rewriter.create<LLVM::ConstantOp>(loc, llvmI32, 1); - Value zeroI32 = rewriter.create<LLVM::ConstantOp>(loc, llvmI32, 0); + Value oneI32 = LLVM::ConstantOp::create(rewriter, loc, llvmI32, 1); + Value zeroI32 = LLVM::ConstantOp::create(rewriter, loc, llvmI32, 0); - auto appendFormatCall = rewriter.create<LLVM::CallOp>( - loc, ocklAppendStringN, + auto appendFormatCall = LLVM::CallOp::create( + rewriter, loc, ocklAppendStringN, ValueRange{printfDesc, stringStart, stringLen, adaptor.getArgs().empty() ? oneI32 : zeroI32}); printfDesc = appendFormatCall.getResult(); @@ -456,17 +458,18 @@ LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite( SmallVector<mlir::Value, 2 + argsPerAppend + 1> arguments; arguments.push_back(printfDesc); arguments.push_back( - rewriter.create<LLVM::ConstantOp>(loc, llvmI32, numArgsThisCall)); + LLVM::ConstantOp::create(rewriter, loc, llvmI32, numArgsThisCall)); for (size_t i = group; i < bound; ++i) { Value arg = adaptor.getArgs()[i]; if (auto floatType = dyn_cast<FloatType>(arg.getType())) { if (!floatType.isF64()) - arg = rewriter.create<LLVM::FPExtOp>( - loc, typeConverter->convertType(rewriter.getF64Type()), arg); - arg = rewriter.create<LLVM::BitcastOp>(loc, llvmI64, arg); + arg = LLVM::FPExtOp::create( + rewriter, loc, typeConverter->convertType(rewriter.getF64Type()), + arg); + arg = LLVM::BitcastOp::create(rewriter, loc, llvmI64, arg); } if (arg.getType().getIntOrFloatBitWidth() != 64) - arg = rewriter.create<LLVM::ZExtOp>(loc, llvmI64, arg); + arg = LLVM::ZExtOp::create(rewriter, loc, llvmI64, arg); arguments.push_back(arg); } @@ -477,7 +480,7 @@ LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite( auto isLast = (bound == nArgs) ? oneI32 : zeroI32; arguments.push_back(isLast); - auto call = rewriter.create<LLVM::CallOp>(loc, ocklAppendArgs, arguments); + auto call = LLVM::CallOp::create(rewriter, loc, ocklAppendArgs, arguments); printfDesc = call.getResult(); } rewriter.eraseOp(gpuPrintfOp); @@ -510,13 +513,13 @@ LogicalResult GPUPrintfOpToLLVMCallLowering::matchAndRewrite( /*alignment=*/0, addressSpace); // Get a pointer to the format string's first element - Value globalPtr = rewriter.create<LLVM::AddressOfOp>( - loc, + Value globalPtr = LLVM::AddressOfOp::create( + rewriter, loc, LLVM::LLVMPointerType::get(rewriter.getContext(), global.getAddrSpace()), global.getSymNameAttr()); Value stringStart = - rewriter.create<LLVM::GEPOp>(loc, ptrType, global.getGlobalType(), - globalPtr, ArrayRef<LLVM::GEPArg>{0, 0}); + LLVM::GEPOp::create(rewriter, loc, ptrType, global.getGlobalType(), + globalPtr, ArrayRef<LLVM::GEPArg>{0, 0}); // Construct arguments and function call auto argsRange = adaptor.getArgs(); @@ -525,7 +528,7 @@ LogicalResult GPUPrintfOpToLLVMCallLowering::matchAndRewrite( printfArgs.push_back(stringStart); printfArgs.append(argsRange.begin(), argsRange.end()); - rewriter.create<LLVM::CallOp>(loc, printfDecl, printfArgs); + LLVM::CallOp::create(rewriter, loc, printfDecl, printfArgs); rewriter.eraseOp(gpuPrintfOp); return success(); } @@ -559,10 +562,10 @@ LogicalResult GPUPrintfOpToVPrintfLowering::matchAndRewrite( "printfFormat_", adaptor.getFormat()); // Get a pointer to the format string's first element - Value globalPtr = rewriter.create<LLVM::AddressOfOp>(loc, global); + Value globalPtr = LLVM::AddressOfOp::create(rewriter, loc, global); Value stringStart = - rewriter.create<LLVM::GEPOp>(loc, ptrType, global.getGlobalType(), - globalPtr, ArrayRef<LLVM::GEPArg>{0, 0}); + LLVM::GEPOp::create(rewriter, loc, ptrType, global.getGlobalType(), + globalPtr, ArrayRef<LLVM::GEPArg>{0, 0}); SmallVector<Type> types; SmallVector<Value> args; // Promote and pack the arguments into a stack allocation. @@ -572,27 +575,27 @@ LogicalResult GPUPrintfOpToVPrintfLowering::matchAndRewrite( assert(type.isIntOrFloat()); if (isa<FloatType>(type)) { type = rewriter.getF64Type(); - promotedArg = rewriter.create<LLVM::FPExtOp>(loc, type, arg); + promotedArg = LLVM::FPExtOp::create(rewriter, loc, type, arg); } types.push_back(type); args.push_back(promotedArg); } Type structType = LLVM::LLVMStructType::getLiteral(gpuPrintfOp.getContext(), types); - Value one = rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI64Type(), - rewriter.getIndexAttr(1)); + Value one = LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(), + rewriter.getIndexAttr(1)); Value tempAlloc = - rewriter.create<LLVM::AllocaOp>(loc, ptrType, structType, one, - /*alignment=*/0); + LLVM::AllocaOp::create(rewriter, loc, ptrType, structType, one, + /*alignment=*/0); for (auto [index, arg] : llvm::enumerate(args)) { - Value ptr = rewriter.create<LLVM::GEPOp>( - loc, ptrType, structType, tempAlloc, + Value ptr = LLVM::GEPOp::create( + rewriter, loc, ptrType, structType, tempAlloc, ArrayRef<LLVM::GEPArg>{0, static_cast<int32_t>(index)}); - rewriter.create<LLVM::StoreOp>(loc, arg, ptr); + LLVM::StoreOp::create(rewriter, loc, arg, ptr); } std::array<Value, 2> printfArgs = {stringStart, tempAlloc}; - rewriter.create<LLVM::CallOp>(loc, vprintfDecl, printfArgs); + LLVM::CallOp::create(rewriter, loc, vprintfDecl, printfArgs); rewriter.eraseOp(gpuPrintfOp); return success(); } @@ -607,23 +610,23 @@ static Value scalarizeVectorOpHelper(Operation *op, ValueRange operands, TypeRange operandTypes(operands); VectorType vectorType = cast<VectorType>(llvm1DVectorTy); Location loc = op->getLoc(); - Value result = rewriter.create<LLVM::PoisonOp>(loc, vectorType); + Value result = LLVM::PoisonOp::create(rewriter, loc, vectorType); Type indexType = converter.convertType(rewriter.getIndexType()); StringAttr name = op->getName().getIdentifier(); Type elementType = vectorType.getElementType(); for (int64_t i = 0; i < vectorType.getNumElements(); ++i) { - Value index = rewriter.create<LLVM::ConstantOp>(loc, indexType, i); + Value index = LLVM::ConstantOp::create(rewriter, loc, indexType, i); auto extractElement = [&](Value operand) -> Value { if (!isa<VectorType>(operand.getType())) return operand; - return rewriter.create<LLVM::ExtractElementOp>(loc, operand, index); + return LLVM::ExtractElementOp::create(rewriter, loc, operand, index); }; auto scalarOperands = llvm::map_to_vector(operands, extractElement); Operation *scalarOp = rewriter.create(loc, name, scalarOperands, elementType, op->getAttrs()); - result = rewriter.create<LLVM::InsertElementOp>( - loc, result, scalarOp->getResult(0), index); + result = LLVM::InsertElementOp::create(rewriter, loc, result, + scalarOp->getResult(0), index); } return result; } @@ -705,10 +708,10 @@ LLVM::GlobalOp getDynamicSharedMemorySymbol( auto zeroSizedArrayType = LLVM::LLVMArrayType::get( typeConverter->convertType(memrefType.getElementType()), 0); - return rewriter.create<LLVM::GlobalOp>( - op->getLoc(), zeroSizedArrayType, /*isConstant=*/false, - LLVM::Linkage::Internal, symName, /*value=*/Attribute(), alignmentByte, - addressSpace.value()); + return LLVM::GlobalOp::create(rewriter, op->getLoc(), zeroSizedArrayType, + /*isConstant=*/false, LLVM::Linkage::Internal, + symName, /*value=*/Attribute(), alignmentByte, + addressSpace.value()); } LogicalResult GPUDynamicSharedMemoryOpLowering::matchAndRewrite( @@ -732,13 +735,13 @@ LogicalResult GPUDynamicSharedMemoryOpLowering::matchAndRewrite( // Step 3. Get address of the global symbol OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPoint(op); - auto basePtr = rewriter.create<LLVM::AddressOfOp>(loc, shmemOp); + auto basePtr = LLVM::AddressOfOp::create(rewriter, loc, shmemOp); Type baseType = basePtr->getResultTypes().front(); // Step 4. Generate GEP using offsets SmallVector<LLVM::GEPArg> gepArgs = {0}; - Value shmemPtr = rewriter.create<LLVM::GEPOp>(loc, baseType, elementType, - basePtr, gepArgs); + Value shmemPtr = LLVM::GEPOp::create(rewriter, loc, baseType, elementType, + basePtr, gepArgs); // Step 5. Create a memref descriptor SmallVector<Value> shape, strides; Value sizeBytes; @@ -799,9 +802,9 @@ LogicalResult GPUReturnOpLowering::matchAndRewrite( return rewriter.notifyMatchFailure(op, "could not convert result types"); } - Value packed = rewriter.create<LLVM::PoisonOp>(loc, packedType); + Value packed = LLVM::PoisonOp::create(rewriter, loc, packedType); for (auto [idx, operand] : llvm::enumerate(updatedOperands)) { - packed = rewriter.create<LLVM::InsertValueOp>(loc, packed, operand, idx); + packed = LLVM::InsertValueOp::create(rewriter, loc, packed, operand, idx); } rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, TypeRange(), packed, op->getAttrs()); |